From 2c3c1048746a4622d8c89a29670120dc8fab93c4 Mon Sep 17 00:00:00 2001
From: Daniel Baumann
Date: Sun, 7 Apr 2024 20:49:45 +0200
Subject: Adding upstream version 6.1.76.

Signed-off-by: Daniel Baumann
---
 net/6lowpan/6lowpan_i.h | 34 + net/6lowpan/Kconfig | 105 + net/6lowpan/Makefile | 22 + net/6lowpan/core.c | 182 + net/6lowpan/debugfs.c | 278 + net/6lowpan/iphc.c | 1313 ++ net/6lowpan/ndisc.c | 233 + net/6lowpan/nhc.c | 169 + net/6lowpan/nhc.h | 133 + net/6lowpan/nhc_dest.c | 17 + net/6lowpan/nhc_fragment.c | 16 + net/6lowpan/nhc_ghc_ext_dest.c | 16 + net/6lowpan/nhc_ghc_ext_frag.c | 17 + net/6lowpan/nhc_ghc_ext_hop.c | 16 + net/6lowpan/nhc_ghc_ext_route.c | 16 + net/6lowpan/nhc_ghc_icmpv6.c | 16 + net/6lowpan/nhc_ghc_udp.c | 16 + net/6lowpan/nhc_hop.c | 16 + net/6lowpan/nhc_ipv6.c | 16 + net/6lowpan/nhc_mobility.c | 16 + net/6lowpan/nhc_routing.c | 16 + net/6lowpan/nhc_udp.c | 176 + net/802/Kconfig | 11 + net/802/Makefile | 14 + net/802/fc.c | 106 + net/802/fddi.c | 178 + net/802/garp.c | 649 + net/802/hippi.c | 193 + net/802/mrp.c | 943 ++ net/802/p8022.c | 63 + net/802/psnap.c | 163 + net/802/stp.c | 101 + net/8021q/Kconfig | 41 + net/8021q/Makefile | 11 + net/8021q/vlan.c | 742 + net/8021q/vlan.h | 206 + net/8021q/vlan_core.c | 558 + net/8021q/vlan_dev.c | 876 ++ net/8021q/vlan_gvrp.c | 67 + net/8021q/vlan_mvrp.c | 73 + net/8021q/vlan_netlink.c | 316 + net/8021q/vlanproc.c | 291 + net/8021q/vlanproc.h | 21 + net/9p/Kconfig | 52 + net/9p/Makefile | 25 + net/9p/client.c | 2276 +++ net/9p/error.c | 230 + net/9p/mod.c | 212 + net/9p/protocol.c | 801 + net/9p/protocol.h | 19 + net/9p/trans_common.c | 24 + net/9p/trans_common.h | 7 + net/9p/trans_fd.c | 1205 ++ net/9p/trans_rdma.c | 782 + net/9p/trans_virtio.c | 839 ++ net/9p/trans_xen.c | 570 + net/Kconfig | 474 + net/Kconfig.debug | 26 + net/Makefile | 81 + net/appletalk/Makefile | 10 + net/appletalk/aarp.c | 1049 ++ net/appletalk/atalk_proc.c | 242 + net/appletalk/ddp.c | 2040 +++ net/appletalk/dev.c | 46 + net/appletalk/sysctl_net_atalk.c | 59 + net/atm/Kconfig | 74 + net/atm/Makefile | 16 + net/atm/addr.c | 162 + net/atm/addr.h | 21 + net/atm/atm_misc.c | 102 + net/atm/atm_sysfs.c | 185 + net/atm/br2684.c | 872 ++ net/atm/clip.c | 929 ++ net/atm/common.c | 895 ++ net/atm/common.h | 56 + net/atm/ioctl.c | 365 + net/atm/lec.c | 2237 +++ net/atm/lec.h | 155 + net/atm/lec_arpc.h | 97 + net/atm/mpc.c | 1535 ++ net/atm/mpc.h | 65 + net/atm/mpoa_caches.c | 565 + net/atm/mpoa_caches.h | 99 + net/atm/mpoa_proc.c | 307 + net/atm/pppoatm.c | 491 + net/atm/proc.c | 400 + net/atm/protocols.h | 14 + net/atm/pvc.c | 163 + net/atm/raw.c | 94 + net/atm/resources.c | 419 + net/atm/resources.h | 49 + net/atm/signaling.c | 244 + net/atm/signaling.h | 31 + net/atm/svc.c | 692 + net/ax25/Kconfig | 122 + net/ax25/Makefile | 12 + net/ax25/af_ax25.c | 2083 +++ net/ax25/ax25_addr.c | 303 + net/ax25/ax25_dev.c | 215 + net/ax25/ax25_ds_in.c | 298 + net/ax25/ax25_ds_subr.c | 204 + net/ax25/ax25_ds_timer.c | 235 + net/ax25/ax25_iface.c | 214 + net/ax25/ax25_in.c | 451 + net/ax25/ax25_ip.c | 246 + net/ax25/ax25_out.c | 386 + net/ax25/ax25_route.c | 488 + net/ax25/ax25_std_in.c | 443 + net/ax25/ax25_std_subr.c | 83 + net/ax25/ax25_std_timer.c | 175 + net/ax25/ax25_subr.c | 296 + net/ax25/ax25_timer.c | 224 + net/ax25/ax25_uid.c | 204 + net/ax25/sysctl_net_ax25.c | 181 + net/batman-adv/Kconfig | 96 + net/batman-adv/Makefile | 34 + net/batman-adv/bat_algo.c | 209 + net/batman-adv/bat_algo.h | 25 + net/batman-adv/bat_iv_ogm.c | 2551 ++++ net/batman-adv/bat_iv_ogm.h | 14
+ net/batman-adv/bat_v.c | 910 ++ net/batman-adv/bat_v.h | 41 + net/batman-adv/bat_v_elp.c | 551 + net/batman-adv/bat_v_elp.h | 24 + net/batman-adv/bat_v_ogm.c | 1088 ++ net/batman-adv/bat_v_ogm.h | 27 + net/batman-adv/bitarray.c | 89 + net/batman-adv/bitarray.h | 56 + net/batman-adv/bridge_loop_avoidance.c | 2502 ++++ net/batman-adv/bridge_loop_avoidance.h | 132 + net/batman-adv/distributed-arp-table.c | 1831 +++ net/batman-adv/distributed-arp-table.h | 186 + net/batman-adv/fragmentation.c | 562 + net/batman-adv/fragmentation.h | 44 + net/batman-adv/gateway_client.c | 769 + net/batman-adv/gateway_client.h | 55 + net/batman-adv/gateway_common.c | 274 + net/batman-adv/gateway_common.h | 38 + net/batman-adv/hard-interface.c | 1022 ++ net/batman-adv/hard-interface.h | 122 + net/batman-adv/hash.c | 84 + net/batman-adv/hash.h | 157 + net/batman-adv/log.c | 36 + net/batman-adv/log.h | 143 + net/batman-adv/main.c | 731 + net/batman-adv/main.h | 384 + net/batman-adv/multicast.c | 2317 +++ net/batman-adv/multicast.h | 123 + net/batman-adv/netlink.c | 1522 ++ net/batman-adv/netlink.h | 32 + net/batman-adv/network-coding.c | 1873 +++ net/batman-adv/network-coding.h | 106 + net/batman-adv/originator.c | 1349 ++ net/batman-adv/originator.h | 167 + net/batman-adv/routing.c | 1273 ++ net/batman-adv/routing.h | 46 + net/batman-adv/send.c | 1135 ++ net/batman-adv/send.h | 116 + net/batman-adv/soft-interface.c | 1132 ++ net/batman-adv/soft-interface.h | 42 + net/batman-adv/tp_meter.c | 1490 ++ net/batman-adv/tp_meter.h | 22 + net/batman-adv/trace.c | 8 + net/batman-adv/trace.h | 64 + net/batman-adv/translation-table.c | 4290 ++++++ net/batman-adv/translation-table.h | 74 + net/batman-adv/tvlv.c | 636 + net/batman-adv/tvlv.h | 49 + net/batman-adv/types.h | 2378 +++ net/bluetooth/6lowpan.c | 1283 ++ net/bluetooth/Kconfig | 152 + net/bluetooth/Makefile | 27 + net/bluetooth/a2mp.c | 1054 ++ net/bluetooth/a2mp.h | 154 + net/bluetooth/af_bluetooth.c | 812 + net/bluetooth/amp.c | 591 + net/bluetooth/amp.h | 61 + net/bluetooth/aosp.c | 210 + net/bluetooth/aosp.h | 29 + net/bluetooth/bnep/Kconfig | 25 + net/bluetooth/bnep/Makefile | 8 + net/bluetooth/bnep/bnep.h | 173 + net/bluetooth/bnep/core.c | 769 + net/bluetooth/bnep/netdev.c | 230 + net/bluetooth/bnep/sock.c | 268 + net/bluetooth/cmtp/Kconfig | 12 + net/bluetooth/cmtp/Makefile | 8 + net/bluetooth/cmtp/capi.c | 595 + net/bluetooth/cmtp/cmtp.h | 129 + net/bluetooth/cmtp/core.c | 519 + net/bluetooth/cmtp/sock.c | 271 + net/bluetooth/ecdh_helper.c | 228 + net/bluetooth/ecdh_helper.h | 30 + net/bluetooth/eir.c | 398 + net/bluetooth/eir.h | 99 + net/bluetooth/hci_codec.c | 253 + net/bluetooth/hci_codec.h | 7 + net/bluetooth/hci_conn.c | 2917 ++++ net/bluetooth/hci_core.c | 4158 ++++++ net/bluetooth/hci_debugfs.c | 1379 ++ net/bluetooth/hci_debugfs.h | 53 + net/bluetooth/hci_event.c | 7576 ++++++++++ net/bluetooth/hci_request.c | 922 ++ net/bluetooth/hci_request.h | 75 + net/bluetooth/hci_sock.c | 2208 +++ net/bluetooth/hci_sync.c | 6316 ++++++++ net/bluetooth/hci_sysfs.c | 124 + net/bluetooth/hidp/Kconfig | 13 + net/bluetooth/hidp/Makefile | 8 + net/bluetooth/hidp/core.c | 1480 ++ net/bluetooth/hidp/hidp.h | 192 + net/bluetooth/hidp/sock.c | 320 + net/bluetooth/iso.c | 1884 +++ net/bluetooth/l2cap_core.c | 8657 +++++++++++ net/bluetooth/l2cap_sock.c | 1981 +++ net/bluetooth/leds.c | 100 + net/bluetooth/leds.h | 23 + net/bluetooth/lib.c | 320 + net/bluetooth/mgmt.c | 10571 +++++++++++++ net/bluetooth/mgmt_config.c | 346 + net/bluetooth/mgmt_config.h | 17 + 
net/bluetooth/mgmt_util.c | 390 + net/bluetooth/mgmt_util.h | 79 + net/bluetooth/msft.c | 837 ++ net/bluetooth/msft.h | 78 + net/bluetooth/rfcomm/Kconfig | 19 + net/bluetooth/rfcomm/Makefile | 9 + net/bluetooth/rfcomm/core.c | 2281 +++ net/bluetooth/rfcomm/sock.c | 1093 ++ net/bluetooth/rfcomm/tty.c | 1162 ++ net/bluetooth/sco.c | 1509 ++ net/bluetooth/selftest.c | 309 + net/bluetooth/selftest.h | 45 + net/bluetooth/smp.c | 3848 +++++ net/bluetooth/smp.h | 214 + net/bpf/Makefile | 5 + net/bpf/bpf_dummy_struct_ops.c | 216 + net/bpf/test_run.c | 1676 +++ net/bpfilter/.gitignore | 2 + net/bpfilter/Kconfig | 23 + net/bpfilter/Makefile | 20 + net/bpfilter/bpfilter_kern.c | 136 + net/bpfilter/bpfilter_umh_blob.S | 7 + net/bpfilter/main.c | 64 + net/bpfilter/msgfmt.h | 17 + net/bridge/Kconfig | 86 + net/bridge/Makefile | 31 + net/bridge/br.c | 470 + net/bridge/br_arp_nd_proxy.c | 485 + net/bridge/br_cfm.c | 867 ++ net/bridge/br_cfm_netlink.c | 726 + net/bridge/br_device.c | 534 + net/bridge/br_fdb.c | 1468 ++ net/bridge/br_forward.c | 343 + net/bridge/br_if.c | 777 + net/bridge/br_input.c | 441 + net/bridge/br_ioctl.c | 440 + net/bridge/br_mdb.c | 1164 ++ net/bridge/br_mrp.c | 1260 ++ net/bridge/br_mrp_netlink.c | 571 + net/bridge/br_mrp_switchdev.c | 241 + net/bridge/br_mst.c | 357 + net/bridge/br_multicast.c | 4946 ++++++ net/bridge/br_multicast_eht.c | 819 + net/bridge/br_netfilter_hooks.c | 1247 ++ net/bridge/br_netfilter_ipv6.c | 248 + net/bridge/br_netlink.c | 1875 +++ net/bridge/br_netlink_tunnel.c | 339 + net/bridge/br_nf_core.c | 91 + net/bridge/br_private.h | 2175 +++ net/bridge/br_private_cfm.h | 147 + net/bridge/br_private_mcast_eht.h | 94 + net/bridge/br_private_mrp.h | 148 + net/bridge/br_private_stp.h | 66 + net/bridge/br_private_tunnel.h | 85 + net/bridge/br_stp.c | 713 + net/bridge/br_stp_bpdu.c | 247 + net/bridge/br_stp_if.c | 351 + net/bridge/br_stp_timer.c | 161 + net/bridge/br_switchdev.c | 825 + net/bridge/br_sysfs_br.c | 1088 ++ net/bridge/br_sysfs_if.c | 412 + net/bridge/br_vlan.c | 2310 +++ net/bridge/br_vlan_options.c | 697 + net/bridge/br_vlan_tunnel.c | 209 + net/bridge/netfilter/Kconfig | 242 + net/bridge/netfilter/Makefile | 40 + net/bridge/netfilter/ebt_802_3.c | 79 + net/bridge/netfilter/ebt_among.c | 281 + net/bridge/netfilter/ebt_arp.c | 138 + net/bridge/netfilter/ebt_arpreply.c | 103 + net/bridge/netfilter/ebt_dnat.c | 106 + net/bridge/netfilter/ebt_ip.c | 169 + net/bridge/netfilter/ebt_ip6.c | 164 + net/bridge/netfilter/ebt_limit.c | 129 + net/bridge/netfilter/ebt_log.c | 226 + net/bridge/netfilter/ebt_mark.c | 111 + net/bridge/netfilter/ebt_mark_m.c | 99 + net/bridge/netfilter/ebt_nflog.c | 75 + net/bridge/netfilter/ebt_pkttype.c | 57 + net/bridge/netfilter/ebt_redirect.c | 81 + net/bridge/netfilter/ebt_snat.c | 88 + net/bridge/netfilter/ebt_stp.c | 194 + net/bridge/netfilter/ebt_vlan.c | 173 + net/bridge/netfilter/ebtable_broute.c | 137 + net/bridge/netfilter/ebtable_filter.c | 118 + net/bridge/netfilter/ebtable_nat.c | 118 + net/bridge/netfilter/ebtables.c | 2597 ++++ net/bridge/netfilter/nf_conntrack_bridge.c | 448 + net/bridge/netfilter/nft_meta_bridge.c | 178 + net/bridge/netfilter/nft_reject_bridge.c | 216 + net/caif/Kconfig | 54 + net/caif/Makefile | 16 + net/caif/caif_dev.c | 585 + net/caif/caif_socket.c | 1119 ++ net/caif/caif_usb.c | 215 + net/caif/cfcnfg.c | 612 + net/caif/cfctrl.c | 639 + net/caif/cfdbgl.c | 55 + net/caif/cfdgml.c | 113 + net/caif/cffrml.c | 197 + net/caif/cfmuxl.c | 267 + net/caif/cfpkt_skbuff.c | 382 + net/caif/cfrfml.c | 299 + 
net/caif/cfserl.c | 192 + net/caif/cfsrvl.c | 220 + net/caif/cfutill.c | 104 + net/caif/cfveil.c | 101 + net/caif/cfvidl.c | 65 + net/caif/chnl_net.c | 531 + net/can/Kconfig | 73 + net/can/Makefile | 22 + net/can/af_can.c | 918 ++ net/can/af_can.h | 103 + net/can/bcm.c | 1788 +++ net/can/gw.c | 1329 ++ net/can/isotp.c | 1691 +++ net/can/j1939/Kconfig | 15 + net/can/j1939/Makefile | 10 + net/can/j1939/address-claim.c | 270 + net/can/j1939/bus.c | 333 + net/can/j1939/j1939-priv.h | 343 + net/can/j1939/main.c | 429 + net/can/j1939/socket.c | 1326 ++ net/can/j1939/transport.c | 2206 +++ net/can/proc.c | 498 + net/can/raw.c | 1024 ++ net/ceph/Kconfig | 47 + net/ceph/Makefile | 18 + net/ceph/armor.c | 106 + net/ceph/auth.c | 659 + net/ceph/auth_none.c | 146 + net/ceph/auth_none.h | 27 + net/ceph/auth_x.c | 1122 ++ net/ceph/auth_x.h | 54 + net/ceph/auth_x_protocol.h | 99 + net/ceph/buffer.c | 59 + net/ceph/ceph_common.c | 916 ++ net/ceph/ceph_hash.c | 131 + net/ceph/ceph_strings.c | 90 + net/ceph/cls_lock_client.c | 431 + net/ceph/crush/crush.c | 142 + net/ceph/crush/crush_ln_table.h | 164 + net/ceph/crush/hash.c | 152 + net/ceph/crush/mapper.c | 1096 ++ net/ceph/crypto.c | 361 + net/ceph/crypto.h | 39 + net/ceph/debugfs.c | 478 + net/ceph/decode.c | 193 + net/ceph/messenger.c | 2135 +++ net/ceph/messenger_v1.c | 1548 ++ net/ceph/messenger_v2.c | 3588 +++++ net/ceph/mon_client.c | 1586 ++ net/ceph/msgpool.c | 94 + net/ceph/osd_client.c | 5658 +++++++ net/ceph/osdmap.c | 3106 ++++ net/ceph/pagelist.c | 171 + net/ceph/pagevec.c | 166 + net/ceph/snapshot.c | 63 + net/ceph/string_table.c | 106 + net/ceph/striper.c | 278 + net/compat.c | 518 + net/core/Makefile | 41 + net/core/bpf_sk_storage.c | 965 ++ net/core/datagram.c | 842 ++ net/core/dev.c | 11486 ++++++++++++++ net/core/dev.h | 118 + net/core/dev_addr_lists.c | 1050 ++ net/core/dev_addr_lists_test.c | 236 + net/core/dev_ioctl.c | 619 + net/core/drop_monitor.c | 1772 +++ net/core/dst.c | 351 + net/core/dst_cache.c | 183 + net/core/failover.c | 315 + net/core/fib_notifier.c | 199 + net/core/fib_rules.c | 1319 ++ net/core/filter.c | 11669 +++++++++++++++ net/core/flow_dissector.c | 1960 +++ net/core/flow_offload.c | 624 + net/core/gen_estimator.c | 278 + net/core/gen_stats.c | 485 + net/core/gro.c | 815 + net/core/gro_cells.c | 138 + net/core/hwbm.c | 85 + net/core/link_watch.c | 280 + net/core/lwt_bpf.c | 657 + net/core/lwtunnel.c | 427 + net/core/neighbour.c | 3887 +++++ net/core/net-procfs.c | 407 + net/core/net-sysfs.c | 2079 +++ net/core/net-sysfs.h | 14 + net/core/net-traces.c | 63 + net/core/net_namespace.c | 1388 ++ net/core/netclassid_cgroup.c | 146 + net/core/netevent.c | 63 + net/core/netpoll.c | 878 ++ net/core/netprio_cgroup.c | 295 + net/core/of_net.c | 170 + net/core/page_pool.c | 932 ++ net/core/pktgen.c | 4039 +++++ net/core/ptp_classifier.c | 228 + net/core/request_sock.c | 129 + net/core/rtnetlink.c | 6248 ++++++++ net/core/scm.c | 375 + net/core/secure_seq.c | 200 + net/core/selftests.c | 412 + net/core/skbuff.c | 6687 +++++++++ net/core/skmsg.c | 1246 ++ net/core/sock.c | 4131 +++++ net/core/sock_destructor.h | 12 + net/core/sock_diag.c | 340 + net/core/sock_map.c | 1701 +++ net/core/sock_reuseport.c | 749 + net/core/stream.c | 220 + net/core/sysctl_net_core.c | 687 + net/core/timestamping.c | 71 + net/core/tso.c | 97 + net/core/utils.c | 486 + net/core/xdp.c | 711 + net/dcb/Kconfig | 23 + net/dcb/Makefile | 2 + net/dcb/dcbevent.c | 30 + net/dcb/dcbnl.c | 2127 +++ net/dccp/Kconfig | 46 + net/dccp/Makefile | 30 + 
net/dccp/ackvec.c | 407 + net/dccp/ackvec.h | 136 + net/dccp/ccid.c | 219 + net/dccp/ccid.h | 262 + net/dccp/ccids/Kconfig | 55 + net/dccp/ccids/ccid2.c | 793 + net/dccp/ccids/ccid2.h | 121 + net/dccp/ccids/ccid3.c | 866 ++ net/dccp/ccids/ccid3.h | 148 + net/dccp/ccids/lib/loss_interval.c | 184 + net/dccp/ccids/lib/loss_interval.h | 69 + net/dccp/ccids/lib/packet_history.c | 439 + net/dccp/ccids/lib/packet_history.h | 142 + net/dccp/ccids/lib/tfrc.c | 46 + net/dccp/ccids/lib/tfrc.h | 73 + net/dccp/ccids/lib/tfrc_equation.c | 702 + net/dccp/dccp.h | 483 + net/dccp/diag.c | 84 + net/dccp/feat.c | 1577 ++ net/dccp/feat.h | 134 + net/dccp/input.c | 739 + net/dccp/ipv4.c | 1100 ++ net/dccp/ipv6.c | 1181 ++ net/dccp/ipv6.h | 31 + net/dccp/minisocks.c | 272 + net/dccp/options.c | 609 + net/dccp/output.c | 709 + net/dccp/proto.c | 1290 ++ net/dccp/qpolicy.c | 136 + net/dccp/sysctl.c | 113 + net/dccp/timer.c | 272 + net/dccp/trace.h | 82 + net/devlink/Makefile | 3 + net/devlink/leftover.c | 12550 ++++++++++++++++ net/devres.c | 95 + net/dns_resolver/Kconfig | 28 + net/dns_resolver/Makefile | 8 + net/dns_resolver/dns_key.c | 391 + net/dns_resolver/dns_query.c | 172 + net/dns_resolver/internal.h | 51 + net/dsa/Kconfig | 169 + net/dsa/Makefile | 31 + net/dsa/dsa.c | 570 + net/dsa/dsa2.c | 1829 +++ net/dsa/dsa_priv.h | 588 + net/dsa/master.c | 478 + net/dsa/netlink.c | 63 + net/dsa/port.c | 2083 +++ net/dsa/slave.c | 3491 +++++ net/dsa/switch.c | 1030 ++ net/dsa/tag_8021q.c | 507 + net/dsa/tag_ar9331.c | 92 + net/dsa/tag_brcm.c | 334 + net/dsa/tag_dsa.c | 406 + net/dsa/tag_gswip.c | 111 + net/dsa/tag_hellcreek.c | 71 + net/dsa/tag_ksz.c | 264 + net/dsa/tag_lan9303.c | 123 + net/dsa/tag_mtk.c | 104 + net/dsa/tag_ocelot.c | 216 + net/dsa/tag_ocelot_8021q.c | 135 + net/dsa/tag_qca.c | 123 + net/dsa/tag_rtl4_a.c | 124 + net/dsa/tag_rtl8_4.c | 258 + net/dsa/tag_rzn1_a5psw.c | 113 + net/dsa/tag_sja1105.c | 804 + net/dsa/tag_trailer.c | 63 + net/dsa/tag_xrs700x.c | 64 + net/ethernet/Makefile | 6 + net/ethernet/eth.c | 650 + net/ethtool/Makefile | 11 + net/ethtool/bitset.c | 833 ++ net/ethtool/bitset.h | 34 + net/ethtool/cabletest.c | 433 + net/ethtool/channels.c | 224 + net/ethtool/coalesce.c | 341 + net/ethtool/common.c | 599 + net/ethtool/common.h | 56 + net/ethtool/debug.c | 128 + net/ethtool/eee.c | 190 + net/ethtool/eeprom.c | 241 + net/ethtool/features.c | 293 + net/ethtool/fec.c | 310 + net/ethtool/ioctl.c | 3341 +++++ net/ethtool/linkinfo.c | 154 + net/ethtool/linkmodes.c | 369 + net/ethtool/linkstate.c | 182 + net/ethtool/module.c | 180 + net/ethtool/netlink.c | 1077 ++ net/ethtool/netlink.h | 416 + net/ethtool/pause.c | 186 + net/ethtool/phc_vclocks.c | 94 + net/ethtool/privflags.c | 201 + net/ethtool/pse-pd.c | 185 + net/ethtool/rings.c | 235 + net/ethtool/stats.c | 412 + net/ethtool/strset.c | 480 + net/ethtool/tsinfo.c | 136 + net/ethtool/tunnels.c | 300 + net/ethtool/wol.c | 170 + net/hsr/Kconfig | 40 + net/hsr/Makefile | 10 + net/hsr/hsr_debugfs.c | 126 + net/hsr/hsr_device.c | 562 + net/hsr/hsr_device.h | 23 + net/hsr/hsr_forward.c | 638 + net/hsr/hsr_forward.h | 31 + net/hsr/hsr_framereg.c | 640 + net/hsr/hsr_framereg.h | 89 + net/hsr/hsr_main.c | 170 + net/hsr/hsr_main.h | 287 + net/hsr/hsr_netlink.c | 556 + net/hsr/hsr_netlink.h | 29 + net/hsr/hsr_slave.c | 223 + net/hsr/hsr_slave.h | 37 + net/ieee802154/6lowpan/6lowpan_i.h | 48 + net/ieee802154/6lowpan/Kconfig | 6 + net/ieee802154/6lowpan/Makefile | 4 + net/ieee802154/6lowpan/core.c | 284 + net/ieee802154/6lowpan/reassembly.c | 551 + 
net/ieee802154/6lowpan/rx.c | 323 + net/ieee802154/6lowpan/tx.c | 309 + net/ieee802154/Kconfig | 31 + net/ieee802154/Makefile | 10 + net/ieee802154/core.c | 388 + net/ieee802154/core.h | 50 + net/ieee802154/header_ops.c | 318 + net/ieee802154/ieee802154.h | 75 + net/ieee802154/netlink.c | 148 + net/ieee802154/nl-mac.c | 1344 ++ net/ieee802154/nl-phy.c | 346 + net/ieee802154/nl802154.c | 2517 ++++ net/ieee802154/nl802154.h | 8 + net/ieee802154/nl_policy.c | 74 + net/ieee802154/rdev-ops.h | 321 + net/ieee802154/socket.c | 1143 ++ net/ieee802154/sysfs.c | 111 + net/ieee802154/sysfs.h | 10 + net/ieee802154/trace.c | 7 + net/ieee802154/trace.h | 319 + net/ife/Kconfig | 16 + net/ife/Makefile | 6 + net/ife/ife.c | 177 + net/ipv4/Kconfig | 753 + net/ipv4/Makefile | 73 + net/ipv4/af_inet.c | 2125 +++ net/ipv4/ah4.c | 604 + net/ipv4/arp.c | 1472 ++ net/ipv4/bpf_tcp_ca.c | 281 + net/ipv4/bpfilter/Makefile | 2 + net/ipv4/bpfilter/sockopt.c | 80 + net/ipv4/cipso_ipv4.c | 2295 +++ net/ipv4/datagram.c | 129 + net/ipv4/devinet.c | 2794 ++++ net/ipv4/esp4.c | 1251 ++ net/ipv4/esp4_offload.c | 385 + net/ipv4/fib_frontend.c | 1663 +++ net/ipv4/fib_lookup.h | 63 + net/ipv4/fib_notifier.c | 72 + net/ipv4/fib_rules.c | 436 + net/ipv4/fib_semantics.c | 2275 +++ net/ipv4/fib_trie.c | 3067 ++++ net/ipv4/fou.c | 1294 ++ net/ipv4/gre_demux.c | 221 + net/ipv4/gre_offload.c | 286 + net/ipv4/icmp.c | 1507 ++ net/ipv4/igmp.c | 3110 ++++ net/ipv4/inet_connection_sock.c | 1501 ++ net/ipv4/inet_diag.c | 1485 ++ net/ipv4/inet_fragment.c | 602 + net/ipv4/inet_hashtables.c | 1257 ++ net/ipv4/inet_timewait_sock.c | 342 + net/ipv4/inetpeer.c | 308 + net/ipv4/ip_forward.c | 182 + net/ipv4/ip_fragment.c | 753 + net/ipv4/ip_gre.c | 1800 +++ net/ipv4/ip_input.c | 675 + net/ipv4/ip_options.c | 641 + net/ipv4/ip_output.c | 1773 +++ net/ipv4/ip_sockglue.c | 1832 +++ net/ipv4/ip_tunnel.c | 1283 ++ net/ipv4/ip_tunnel_core.c | 1148 ++ net/ipv4/ip_vti.c | 726 + net/ipv4/ipcomp.c | 204 + net/ipv4/ipconfig.c | 1845 +++ net/ipv4/ipip.c | 662 + net/ipv4/ipmr.c | 3167 ++++ net/ipv4/ipmr_base.c | 448 + net/ipv4/metrics.c | 94 + net/ipv4/netfilter.c | 95 + net/ipv4/netfilter/Kconfig | 358 + net/ipv4/netfilter/Makefile | 54 + net/ipv4/netfilter/arp_tables.c | 1667 +++ net/ipv4/netfilter/arpt_mangle.c | 92 + net/ipv4/netfilter/arptable_filter.c | 91 + net/ipv4/netfilter/ip_tables.c | 1952 +++ net/ipv4/netfilter/ipt_CLUSTERIP.c | 929 ++ net/ipv4/netfilter/ipt_ECN.c | 133 + net/ipv4/netfilter/ipt_REJECT.c | 111 + net/ipv4/netfilter/ipt_SYNPROXY.c | 121 + net/ipv4/netfilter/ipt_ah.c | 88 + net/ipv4/netfilter/ipt_rpfilter.c | 126 + net/ipv4/netfilter/iptable_filter.c | 110 + net/ipv4/netfilter/iptable_mangle.c | 143 + net/ipv4/netfilter/iptable_nat.c | 172 + net/ipv4/netfilter/iptable_raw.c | 110 + net/ipv4/netfilter/iptable_security.c | 98 + net/ipv4/netfilter/nf_defrag_ipv4.c | 173 + net/ipv4/netfilter/nf_dup_ipv4.c | 99 + net/ipv4/netfilter/nf_nat_h323.c | 567 + net/ipv4/netfilter/nf_nat_pptp.c | 320 + net/ipv4/netfilter/nf_nat_snmp_basic.asn1 | 177 + net/ipv4/netfilter/nf_nat_snmp_basic_main.c | 231 + net/ipv4/netfilter/nf_reject_ipv4.c | 340 + net/ipv4/netfilter/nf_socket_ipv4.c | 153 + net/ipv4/netfilter/nf_tproxy_ipv4.c | 152 + net/ipv4/netfilter/nft_dup_ipv4.c | 111 + net/ipv4/netfilter/nft_fib_ipv4.c | 223 + net/ipv4/netfilter/nft_reject_ipv4.c | 76 + net/ipv4/netlink.c | 33 + net/ipv4/nexthop.c | 3775 +++++ net/ipv4/ping.c | 1211 ++ net/ipv4/proc.c | 557 + net/ipv4/protocol.c | 70 + net/ipv4/raw.c | 1112 ++ net/ipv4/raw_diag.c | 261 + 
net/ipv4/route.c | 3789 +++++ net/ipv4/syncookies.c | 449 + net/ipv4/sysctl_net_ipv4.c | 1479 ++ net/ipv4/tcp.c | 4864 ++++++ net/ipv4/tcp_bbr.c | 1202 ++ net/ipv4/tcp_bic.c | 229 + net/ipv4/tcp_bpf.c | 745 + net/ipv4/tcp_cdg.c | 428 + net/ipv4/tcp_cong.c | 488 + net/ipv4/tcp_cubic.c | 557 + net/ipv4/tcp_dctcp.c | 285 + net/ipv4/tcp_dctcp.h | 40 + net/ipv4/tcp_diag.c | 250 + net/ipv4/tcp_fastopen.c | 595 + net/ipv4/tcp_highspeed.c | 186 + net/ipv4/tcp_htcp.c | 317 + net/ipv4/tcp_hybla.c | 194 + net/ipv4/tcp_illinois.c | 360 + net/ipv4/tcp_input.c | 7075 +++++++++ net/ipv4/tcp_ipv4.c | 3349 +++++ net/ipv4/tcp_lp.c | 354 + net/ipv4/tcp_metrics.c | 1055 ++ net/ipv4/tcp_minisocks.c | 880 ++ net/ipv4/tcp_nv.c | 501 + net/ipv4/tcp_offload.c | 358 + net/ipv4/tcp_output.c | 4199 ++++++ net/ipv4/tcp_rate.c | 209 + net/ipv4/tcp_recovery.c | 237 + net/ipv4/tcp_scalable.c | 65 + net/ipv4/tcp_timer.c | 819 + net/ipv4/tcp_ulp.c | 168 + net/ipv4/tcp_vegas.c | 340 + net/ipv4/tcp_vegas.h | 26 + net/ipv4/tcp_veno.c | 238 + net/ipv4/tcp_westwood.c | 309 + net/ipv4/tcp_yeah.c | 239 + net/ipv4/tunnel4.c | 297 + net/ipv4/udp.c | 3350 +++++ net/ipv4/udp_bpf.c | 157 + net/ipv4/udp_diag.c | 300 + net/ipv4/udp_impl.h | 29 + net/ipv4/udp_offload.c | 739 + net/ipv4/udp_tunnel_core.c | 207 + net/ipv4/udp_tunnel_nic.c | 973 ++ net/ipv4/udp_tunnel_stub.c | 7 + net/ipv4/udplite.c | 136 + net/ipv4/xfrm4_input.c | 173 + net/ipv4/xfrm4_output.c | 46 + net/ipv4/xfrm4_policy.c | 259 + net/ipv4/xfrm4_protocol.c | 306 + net/ipv4/xfrm4_state.c | 24 + net/ipv4/xfrm4_tunnel.c | 118 + net/ipv6/Kconfig | 343 + net/ipv6/Makefile | 55 + net/ipv6/addrconf.c | 7402 +++++++++ net/ipv6/addrconf_core.c | 273 + net/ipv6/addrlabel.c | 652 + net/ipv6/af_inet6.c | 1334 ++ net/ipv6/ah6.c | 807 + net/ipv6/anycast.c | 619 + net/ipv6/calipso.c | 1459 ++ net/ipv6/datagram.c | 1071 ++ net/ipv6/esp6.c | 1305 ++ net/ipv6/esp6_offload.c | 420 + net/ipv6/exthdrs.c | 1401 ++ net/ipv6/exthdrs_core.c | 282 + net/ipv6/exthdrs_offload.c | 37 + net/ipv6/fib6_notifier.c | 64 + net/ipv6/fib6_rules.c | 522 + net/ipv6/fou6.c | 227 + net/ipv6/icmp.c | 1208 ++ net/ipv6/ila/Makefile | 8 + net/ipv6/ila/ila.h | 125 + net/ipv6/ila/ila_common.c | 155 + net/ipv6/ila/ila_lwt.c | 325 + net/ipv6/ila/ila_main.c | 124 + net/ipv6/ila/ila_xlat.c | 660 + net/ipv6/inet6_connection_sock.c | 154 + net/ipv6/inet6_hashtables.c | 338 + net/ipv6/ioam6.c | 979 ++ net/ipv6/ioam6_iptunnel.c | 477 + net/ipv6/ip6_checksum.c | 137 + net/ipv6/ip6_fib.c | 2718 ++++ net/ipv6/ip6_flowlabel.c | 909 ++ net/ipv6/ip6_gre.c | 2436 +++ net/ipv6/ip6_icmp.c | 84 + net/ipv6/ip6_input.c | 593 + net/ipv6/ip6_offload.c | 488 + net/ipv6/ip6_offload.h | 15 + net/ipv6/ip6_output.c | 2084 +++ net/ipv6/ip6_tunnel.c | 2367 +++ net/ipv6/ip6_udp_tunnel.c | 114 + net/ipv6/ip6_vti.c | 1335 ++ net/ipv6/ip6mr.c | 2636 ++++ net/ipv6/ipcomp6.c | 222 + net/ipv6/ipv6_sockglue.c | 1520 ++ net/ipv6/mcast.c | 3212 ++++ net/ipv6/mcast_snoop.c | 190 + net/ipv6/mip6.c | 410 + net/ipv6/ndisc.c | 2061 +++ net/ipv6/netfilter.c | 273 + net/ipv6/netfilter/Kconfig | 288 + net/ipv6/netfilter/Makefile | 45 + net/ipv6/netfilter/ip6_tables.c | 1960 +++ net/ipv6/netfilter/ip6t_NPT.c | 191 + net/ipv6/netfilter/ip6t_REJECT.c | 121 + net/ipv6/netfilter/ip6t_SYNPROXY.c | 124 + net/ipv6/netfilter/ip6t_ah.c | 117 + net/ipv6/netfilter/ip6t_eui64.c | 71 + net/ipv6/netfilter/ip6t_frag.c | 132 + net/ipv6/netfilter/ip6t_hbh.c | 211 + net/ipv6/netfilter/ip6t_ipv6header.c | 153 + net/ipv6/netfilter/ip6t_mh.c | 90 + net/ipv6/netfilter/ip6t_rpfilter.c | 
150 + net/ipv6/netfilter/ip6t_rt.c | 191 + net/ipv6/netfilter/ip6t_srh.c | 320 + net/ipv6/netfilter/ip6table_filter.c | 109 + net/ipv6/netfilter/ip6table_mangle.c | 136 + net/ipv6/netfilter/ip6table_nat.c | 172 + net/ipv6/netfilter/ip6table_raw.c | 108 + net/ipv6/netfilter/ip6table_security.c | 97 + net/ipv6/netfilter/nf_conntrack_reasm.c | 568 + net/ipv6/netfilter/nf_defrag_ipv6_hooks.c | 173 + net/ipv6/netfilter/nf_dup_ipv6.c | 78 + net/ipv6/netfilter/nf_reject_ipv6.c | 419 + net/ipv6/netfilter/nf_socket_ipv6.c | 146 + net/ipv6/netfilter/nf_tproxy_ipv6.c | 153 + net/ipv6/netfilter/nft_dup_ipv6.c | 109 + net/ipv6/netfilter/nft_fib_ipv6.c | 283 + net/ipv6/netfilter/nft_reject_ipv6.c | 77 + net/ipv6/output_core.c | 165 + net/ipv6/ping.c | 307 + net/ipv6/proc.c | 319 + net/ipv6/protocol.c | 70 + net/ipv6/raw.c | 1319 ++ net/ipv6/reassembly.c | 612 + net/ipv6/route.c | 6792 +++++++++ net/ipv6/rpl.c | 125 + net/ipv6/rpl_iptunnel.c | 376 + net/ipv6/seg6.c | 569 + net/ipv6/seg6_hmac.c | 439 + net/ipv6/seg6_iptunnel.c | 745 + net/ipv6/seg6_local.c | 2310 +++ net/ipv6/sit.c | 1961 +++ net/ipv6/syncookies.c | 267 + net/ipv6/sysctl_net_ipv6.c | 357 + net/ipv6/tcp_ipv6.c | 2271 +++ net/ipv6/tcpv6_offload.c | 79 + net/ipv6/tunnel6.c | 305 + net/ipv6/udp.c | 1839 +++ net/ipv6/udp_impl.h | 31 + net/ipv6/udp_offload.c | 206 + net/ipv6/udplite.c | 136 + net/ipv6/xfrm6_input.c | 258 + net/ipv6/xfrm6_output.c | 112 + net/ipv6/xfrm6_policy.c | 310 + net/ipv6/xfrm6_protocol.c | 327 + net/ipv6/xfrm6_state.c | 33 + net/ipv6/xfrm6_tunnel.c | 405 + net/iucv/Kconfig | 18 + net/iucv/Makefile | 7 + net/iucv/af_iucv.c | 2332 +++ net/iucv/iucv.c | 1908 +++ net/kcm/Kconfig | 11 + net/kcm/Makefile | 4 + net/kcm/kcmproc.c | 387 + net/kcm/kcmsock.c | 2071 +++ net/key/Makefile | 6 + net/key/af_key.c | 3930 +++++ net/l2tp/Kconfig | 110 + net/l2tp/Makefile | 18 + net/l2tp/l2tp_core.c | 1723 +++ net/l2tp/l2tp_core.h | 346 + net/l2tp/l2tp_debugfs.c | 348 + net/l2tp/l2tp_eth.c | 370 + net/l2tp/l2tp_ip.c | 684 + net/l2tp/l2tp_ip6.c | 813 + net/l2tp/l2tp_netlink.c | 1048 ++ net/l2tp/l2tp_ppp.c | 1743 +++ net/l2tp/trace.h | 211 + net/l3mdev/Kconfig | 11 + net/l3mdev/Makefile | 6 + net/l3mdev/l3mdev.c | 301 + net/lapb/Kconfig | 22 + net/lapb/Makefile | 8 + net/lapb/lapb_iface.c | 553 + net/lapb/lapb_in.c | 556 + net/lapb/lapb_out.c | 205 + net/lapb/lapb_subr.c | 299 + net/lapb/lapb_timer.c | 206 + net/llc/Kconfig | 10 + net/llc/Makefile | 25 + net/llc/af_llc.c | 1309 ++ net/llc/llc_c_ac.c | 1451 ++ net/llc/llc_c_ev.c | 748 + net/llc/llc_c_st.c | 4946 ++++++ net/llc/llc_conn.c | 1020 ++ net/llc/llc_core.c | 159 + net/llc/llc_if.c | 157 + net/llc/llc_input.c | 230 + net/llc/llc_output.c | 74 + net/llc/llc_pdu.c | 372 + net/llc/llc_proc.c | 251 + net/llc/llc_s_ac.c | 217 + net/llc/llc_s_ev.c | 115 + net/llc/llc_s_st.c | 183 + net/llc/llc_sap.c | 439 + net/llc/llc_station.c | 126 + net/llc/sysctl_net_llc.c | 79 + net/mac80211/Kconfig | 306 + net/mac80211/Makefile | 68 + net/mac80211/aead_api.c | 113 + net/mac80211/aead_api.h | 23 + net/mac80211/aes_ccm.h | 45 + net/mac80211/aes_cmac.c | 92 + net/mac80211/aes_cmac.h | 20 + net/mac80211/aes_gcm.h | 43 + net/mac80211/aes_gmac.c | 92 + net/mac80211/aes_gmac.h | 21 + net/mac80211/agg-rx.c | 562 + net/mac80211/agg-tx.c | 1044 ++ net/mac80211/airtime.c | 711 + net/mac80211/cfg.c | 4989 +++++++ net/mac80211/chan.c | 2064 +++ net/mac80211/debug.h | 234 + net/mac80211/debugfs.c | 711 + net/mac80211/debugfs.h | 17 + net/mac80211/debugfs_key.c | 472 + net/mac80211/debugfs_key.h | 44 + 
net/mac80211/debugfs_netdev.c | 862 ++ net/mac80211/debugfs_netdev.h | 25 + net/mac80211/debugfs_sta.c | 1079 ++ net/mac80211/debugfs_sta.h | 15 + net/mac80211/driver-ops.c | 523 + net/mac80211/driver-ops.h | 1482 ++ net/mac80211/eht.c | 79 + net/mac80211/ethtool.c | 246 + net/mac80211/fils_aead.c | 333 + net/mac80211/fils_aead.h | 16 + net/mac80211/he.c | 245 + net/mac80211/ht.c | 617 + net/mac80211/ibss.c | 1888 +++ net/mac80211/ieee80211_i.h | 2559 ++++ net/mac80211/iface.c | 2404 +++ net/mac80211/key.c | 1481 ++ net/mac80211/key.h | 179 + net/mac80211/led.c | 376 + net/mac80211/led.h | 90 + net/mac80211/link.c | 476 + net/mac80211/main.c | 1575 ++ net/mac80211/mesh.c | 1650 ++ net/mac80211/mesh.h | 350 + net/mac80211/mesh_hwmp.c | 1332 ++ net/mac80211/mesh_pathtbl.c | 790 + net/mac80211/mesh_plink.c | 1237 ++ net/mac80211/mesh_ps.c | 607 + net/mac80211/mesh_sync.c | 213 + net/mac80211/michael.c | 83 + net/mac80211/michael.h | 22 + net/mac80211/mlme.c | 7444 ++++++++++ net/mac80211/ocb.c | 246 + net/mac80211/offchannel.c | 1030 ++ net/mac80211/pm.c | 201 + net/mac80211/rate.c | 1018 ++ net/mac80211/rate.h | 112 + net/mac80211/rc80211_minstrel_ht.c | 2061 +++ net/mac80211/rc80211_minstrel_ht.h | 203 + net/mac80211/rc80211_minstrel_ht_debugfs.c | 336 + net/mac80211/rx.c | 5309 +++++++ net/mac80211/s1g.c | 201 + net/mac80211/scan.c | 1471 ++ net/mac80211/spectmgmt.c | 249 + net/mac80211/sta_info.c | 2935 ++++ net/mac80211/sta_info.h | 993 ++ net/mac80211/status.c | 1296 ++ net/mac80211/tdls.c | 2010 +++ net/mac80211/tkip.c | 323 + net/mac80211/tkip.h | 30 + net/mac80211/trace.c | 77 + net/mac80211/trace.h | 3035 ++++ net/mac80211/trace_msg.h | 57 + net/mac80211/tx.c | 6021 ++++++++ net/mac80211/util.c | 4860 ++++++ net/mac80211/vht.c | 778 + net/mac80211/wep.c | 306 + net/mac80211/wep.h | 30 + net/mac80211/wme.c | 293 + net/mac80211/wme.h | 23 + net/mac80211/wpa.c | 1118 ++ net/mac80211/wpa.h | 49 + net/mac802154/Kconfig | 22 + net/mac802154/Makefile | 6 + net/mac802154/cfg.c | 487 + net/mac802154/cfg.h | 10 + net/mac802154/driver-ops.h | 285 + net/mac802154/ieee802154_i.h | 182 + net/mac802154/iface.c | 745 + net/mac802154/llsec.c | 1050 ++ net/mac802154/llsec.h | 99 + net/mac802154/mac_cmd.c | 144 + net/mac802154/main.c | 285 + net/mac802154/mib.c | 219 + net/mac802154/rx.c | 304 + net/mac802154/trace.c | 10 + net/mac802154/trace.h | 273 + net/mac802154/tx.c | 133 + net/mac802154/util.c | 116 + net/mctp/Kconfig | 23 + net/mctp/Makefile | 6 + net/mctp/af_mctp.c | 710 + net/mctp/device.c | 548 + net/mctp/neigh.c | 343 + net/mctp/route.c | 1432 ++ net/mctp/test/route-test.c | 684 + net/mctp/test/utils.c | 66 + net/mctp/test/utils.h | 20 + net/mpls/Kconfig | 39 + net/mpls/Makefile | 9 + net/mpls/af_mpls.c | 2799 ++++ net/mpls/internal.h | 202 + net/mpls/mpls_gso.c | 109 + net/mpls/mpls_iptunnel.c | 305 + net/mptcp/Kconfig | 39 + net/mptcp/Makefile | 14 + net/mptcp/bpf.c | 21 + net/mptcp/crypto.c | 83 + net/mptcp/crypto_test.c | 72 + net/mptcp/ctrl.c | 230 + net/mptcp/diag.c | 104 + net/mptcp/mib.c | 105 + net/mptcp/mib.h | 80 + net/mptcp/mptcp_diag.c | 248 + net/mptcp/options.c | 1633 ++ net/mptcp/pm.c | 513 + net/mptcp/pm_netlink.c | 2365 +++ net/mptcp/pm_userspace.c | 505 + net/mptcp/protocol.c | 4099 +++++ net/mptcp/protocol.h | 1038 ++ net/mptcp/sockopt.c | 1338 ++ net/mptcp/subflow.c | 2000 +++ net/mptcp/syncookies.c | 133 + net/mptcp/token.c | 416 + net/mptcp/token_test.c | 142 + net/ncsi/Kconfig | 25 + net/ncsi/Makefile | 5 + net/ncsi/internal.h | 414 + net/ncsi/ncsi-aen.c | 246 + 
net/ncsi/ncsi-cmd.c | 406 + net/ncsi/ncsi-manage.c | 1969 +++ net/ncsi/ncsi-netlink.c | 778 + net/ncsi/ncsi-netlink.h | 25 + net/ncsi/ncsi-pkt.h | 456 + net/ncsi/ncsi-rsp.c | 1232 ++ net/netfilter/Kconfig | 1666 +++ net/netfilter/Makefile | 228 + net/netfilter/core.c | 793 + net/netfilter/ipset/Kconfig | 178 + net/netfilter/ipset/Makefile | 31 + net/netfilter/ipset/ip_set_bitmap_gen.h | 306 + net/netfilter/ipset/ip_set_bitmap_ip.c | 387 + net/netfilter/ipset/ip_set_bitmap_ipmac.c | 423 + net/netfilter/ipset/ip_set_bitmap_port.c | 332 + net/netfilter/ipset/ip_set_core.c | 2422 +++ net/netfilter/ipset/ip_set_getport.c | 150 + net/netfilter/ipset/ip_set_hash_gen.h | 1582 ++ net/netfilter/ipset/ip_set_hash_ip.c | 329 + net/netfilter/ipset/ip_set_hash_ipmac.c | 311 + net/netfilter/ipset/ip_set_hash_ipmark.c | 331 + net/netfilter/ipset/ip_set_hash_ipport.c | 395 + net/netfilter/ipset/ip_set_hash_ipportip.c | 410 + net/netfilter/ipset/ip_set_hash_ipportnet.c | 572 + net/netfilter/ipset/ip_set_hash_mac.c | 169 + net/netfilter/ipset/ip_set_hash_net.c | 407 + net/netfilter/ipset/ip_set_hash_netiface.c | 525 + net/netfilter/ipset/ip_set_hash_netnet.c | 514 + net/netfilter/ipset/ip_set_hash_netport.c | 515 + net/netfilter/ipset/ip_set_hash_netportnet.c | 628 + net/netfilter/ipset/ip_set_list_set.c | 681 + net/netfilter/ipset/pfxlen.c | 189 + net/netfilter/ipvs/Kconfig | 352 + net/netfilter/ipvs/Makefile | 45 + net/netfilter/ipvs/ip_vs_app.c | 617 + net/netfilter/ipvs/ip_vs_conn.c | 1538 ++ net/netfilter/ipvs/ip_vs_core.c | 2456 +++ net/netfilter/ipvs/ip_vs_ctl.c | 4288 ++++++ net/netfilter/ipvs/ip_vs_dh.c | 272 + net/netfilter/ipvs/ip_vs_est.c | 204 + net/netfilter/ipvs/ip_vs_fo.c | 74 + net/netfilter/ipvs/ip_vs_ftp.c | 637 + net/netfilter/ipvs/ip_vs_lblc.c | 630 + net/netfilter/ipvs/ip_vs_lblcr.c | 815 + net/netfilter/ipvs/ip_vs_lc.c | 88 + net/netfilter/ipvs/ip_vs_mh.c | 539 + net/netfilter/ipvs/ip_vs_nfct.c | 280 + net/netfilter/ipvs/ip_vs_nq.c | 138 + net/netfilter/ipvs/ip_vs_ovf.c | 81 + net/netfilter/ipvs/ip_vs_pe.c | 112 + net/netfilter/ipvs/ip_vs_pe_sip.c | 187 + net/netfilter/ipvs/ip_vs_proto.c | 384 + net/netfilter/ipvs/ip_vs_proto_ah_esp.c | 157 + net/netfilter/ipvs/ip_vs_proto_sctp.c | 595 + net/netfilter/ipvs/ip_vs_proto_tcp.c | 746 + net/netfilter/ipvs/ip_vs_proto_udp.c | 503 + net/netfilter/ipvs/ip_vs_rr.c | 125 + net/netfilter/ipvs/ip_vs_sched.c | 250 + net/netfilter/ipvs/ip_vs_sed.c | 139 + net/netfilter/ipvs/ip_vs_sh.c | 378 + net/netfilter/ipvs/ip_vs_sync.c | 2051 +++ net/netfilter/ipvs/ip_vs_twos.c | 139 + net/netfilter/ipvs/ip_vs_wlc.c | 111 + net/netfilter/ipvs/ip_vs_wrr.c | 265 + net/netfilter/ipvs/ip_vs_xmit.c | 1686 +++ net/netfilter/nf_conncount.c | 636 + net/netfilter/nf_conntrack_acct.c | 28 + net/netfilter/nf_conntrack_amanda.c | 240 + net/netfilter/nf_conntrack_bpf.c | 515 + net/netfilter/nf_conntrack_broadcast.c | 84 + net/netfilter/nf_conntrack_core.c | 2875 ++++ net/netfilter/nf_conntrack_ecache.c | 358 + net/netfilter/nf_conntrack_expect.c | 745 + net/netfilter/nf_conntrack_extend.c | 159 + net/netfilter/nf_conntrack_ftp.c | 604 + net/netfilter/nf_conntrack_h323_asn1.c | 939 ++ net/netfilter/nf_conntrack_h323_main.c | 1803 +++ net/netfilter/nf_conntrack_h323_types.c | 1921 +++ net/netfilter/nf_conntrack_helper.c | 520 + net/netfilter/nf_conntrack_irc.c | 314 + net/netfilter/nf_conntrack_labels.c | 82 + net/netfilter/nf_conntrack_netbios_ns.c | 71 + net/netfilter/nf_conntrack_netlink.c | 3918 +++++ net/netfilter/nf_conntrack_pptp.c | 613 + 
net/netfilter/nf_conntrack_proto.c | 726 + net/netfilter/nf_conntrack_proto_dccp.c | 824 + net/netfilter/nf_conntrack_proto_generic.c | 79 + net/netfilter/nf_conntrack_proto_gre.c | 317 + net/netfilter/nf_conntrack_proto_icmp.c | 383 + net/netfilter/nf_conntrack_proto_icmpv6.c | 359 + net/netfilter/nf_conntrack_proto_sctp.c | 743 + net/netfilter/nf_conntrack_proto_tcp.c | 1616 ++ net/netfilter/nf_conntrack_proto_udp.c | 322 + net/netfilter/nf_conntrack_sane.c | 213 + net/netfilter/nf_conntrack_seqadj.c | 234 + net/netfilter/nf_conntrack_sip.c | 1707 +++ net/netfilter/nf_conntrack_snmp.c | 75 + net/netfilter/nf_conntrack_standalone.c | 1254 ++ net/netfilter/nf_conntrack_tftp.c | 141 + net/netfilter/nf_conntrack_timeout.c | 146 + net/netfilter/nf_conntrack_timestamp.c | 25 + net/netfilter/nf_dup_netdev.c | 94 + net/netfilter/nf_flow_table_core.c | 699 + net/netfilter/nf_flow_table_inet.c | 118 + net/netfilter/nf_flow_table_ip.c | 689 + net/netfilter/nf_flow_table_offload.c | 1241 ++ net/netfilter/nf_flow_table_procfs.c | 80 + net/netfilter/nf_hooks_lwtunnel.c | 53 + net/netfilter/nf_internals.h | 37 + net/netfilter/nf_log.c | 568 + net/netfilter/nf_log_syslog.c | 1083 ++ net/netfilter/nf_nat_amanda.c | 80 + net/netfilter/nf_nat_bpf.c | 79 + net/netfilter/nf_nat_core.c | 1184 ++ net/netfilter/nf_nat_ftp.c | 138 + net/netfilter/nf_nat_helper.c | 231 + net/netfilter/nf_nat_irc.c | 110 + net/netfilter/nf_nat_masquerade.c | 368 + net/netfilter/nf_nat_proto.c | 1103 ++ net/netfilter/nf_nat_redirect.c | 138 + net/netfilter/nf_nat_sip.c | 676 + net/netfilter/nf_nat_tftp.c | 56 + net/netfilter/nf_queue.c | 356 + net/netfilter/nf_sockopt.c | 120 + net/netfilter/nf_synproxy_core.c | 1220 ++ net/netfilter/nf_tables_api.c | 10973 ++++++++++++++ net/netfilter/nf_tables_core.c | 393 + net/netfilter/nf_tables_offload.c | 691 + net/netfilter/nf_tables_trace.c | 295 + net/netfilter/nfnetlink.c | 811 + net/netfilter/nfnetlink_acct.c | 560 + net/netfilter/nfnetlink_cthelper.c | 804 + net/netfilter/nfnetlink_cttimeout.c | 681 + net/netfilter/nfnetlink_hook.c | 393 + net/netfilter/nfnetlink_log.c | 1209 ++ net/netfilter/nfnetlink_osf.c | 450 + net/netfilter/nfnetlink_queue.c | 1617 ++ net/netfilter/nft_bitwise.c | 533 + net/netfilter/nft_byteorder.c | 197 + net/netfilter/nft_chain_filter.c | 456 + net/netfilter/nft_chain_nat.c | 148 + net/netfilter/nft_chain_route.c | 169 + net/netfilter/nft_cmp.c | 433 + net/netfilter/nft_compat.c | 958 ++ net/netfilter/nft_connlimit.c | 303 + net/netfilter/nft_counter.c | 308 + net/netfilter/nft_ct.c | 1401 ++ net/netfilter/nft_dup_netdev.c | 112 + net/netfilter/nft_dynset.c | 433 + net/netfilter/nft_exthdr.c | 840 ++ net/netfilter/nft_fib.c | 210 + net/netfilter/nft_fib_inet.c | 80 + net/netfilter/nft_fib_netdev.c | 89 + net/netfilter/nft_flow_offload.c | 527 + net/netfilter/nft_fwd_netdev.c | 271 + net/netfilter/nft_hash.c | 286 + net/netfilter/nft_immediate.c | 355 + net/netfilter/nft_last.c | 137 + net/netfilter/nft_limit.c | 481 + net/netfilter/nft_log.c | 320 + net/netfilter/nft_lookup.c | 258 + net/netfilter/nft_masq.c | 307 + net/netfilter/nft_meta.c | 948 ++ net/netfilter/nft_nat.c | 408 + net/netfilter/nft_numgen.c | 255 + net/netfilter/nft_objref.c | 262 + net/netfilter/nft_osf.c | 194 + net/netfilter/nft_payload.c | 891 ++ net/netfilter/nft_queue.c | 248 + net/netfilter/nft_quota.c | 304 + net/netfilter/nft_range.c | 149 + net/netfilter/nft_redir.c | 272 + net/netfilter/nft_reject.c | 133 + net/netfilter/nft_reject_inet.c | 111 + net/netfilter/nft_reject_netdev.c | 
191 + net/netfilter/nft_rt.c | 208 + net/netfilter/nft_set_bitmap.c | 316 + net/netfilter/nft_set_hash.c | 788 + net/netfilter/nft_set_pipapo.c | 2331 +++ net/netfilter/nft_set_pipapo.h | 280 + net/netfilter/nft_set_pipapo_avx2.c | 1228 ++ net/netfilter/nft_set_pipapo_avx2.h | 12 + net/netfilter/nft_set_rbtree.c | 773 + net/netfilter/nft_socket.c | 291 + net/netfilter/nft_synproxy.c | 397 + net/netfilter/nft_tproxy.c | 363 + net/netfilter/nft_tunnel.c | 750 + net/netfilter/nft_xfrm.c | 324 + net/netfilter/utils.c | 217 + net/netfilter/x_tables.c | 2017 +++ net/netfilter/xt_AUDIT.c | 158 + net/netfilter/xt_CHECKSUM.c | 87 + net/netfilter/xt_CLASSIFY.c | 70 + net/netfilter/xt_CONNSECMARK.c | 139 + net/netfilter/xt_CT.c | 405 + net/netfilter/xt_DSCP.c | 161 + net/netfilter/xt_HL.c | 159 + net/netfilter/xt_HMARK.c | 368 + net/netfilter/xt_IDLETIMER.c | 542 + net/netfilter/xt_LED.c | 202 + net/netfilter/xt_LOG.c | 119 + net/netfilter/xt_MASQUERADE.c | 128 + net/netfilter/xt_NETMAP.c | 169 + net/netfilter/xt_NFLOG.c | 90 + net/netfilter/xt_NFQUEUE.c | 158 + net/netfilter/xt_RATEEST.c | 233 + net/netfilter/xt_REDIRECT.c | 124 + net/netfilter/xt_SECMARK.c | 191 + net/netfilter/xt_TCPMSS.c | 345 + net/netfilter/xt_TCPOPTSTRIP.c | 153 + net/netfilter/xt_TEE.c | 231 + net/netfilter/xt_TPROXY.c | 269 + net/netfilter/xt_TRACE.c | 55 + net/netfilter/xt_addrtype.c | 232 + net/netfilter/xt_bpf.c | 153 + net/netfilter/xt_cgroup.c | 219 + net/netfilter/xt_cluster.c | 175 + net/netfilter/xt_comment.c | 46 + net/netfilter/xt_connbytes.c | 157 + net/netfilter/xt_connlabel.c | 102 + net/netfilter/xt_connlimit.c | 137 + net/netfilter/xt_connmark.c | 208 + net/netfilter/xt_conntrack.c | 327 + net/netfilter/xt_cpu.c | 61 + net/netfilter/xt_dccp.c | 185 + net/netfilter/xt_devgroup.c | 79 + net/netfilter/xt_dscp.c | 110 + net/netfilter/xt_ecn.c | 176 + net/netfilter/xt_esp.c | 104 + net/netfilter/xt_hashlimit.c | 1332 ++ net/netfilter/xt_helper.c | 96 + net/netfilter/xt_hl.c | 93 + net/netfilter/xt_ipcomp.c | 109 + net/netfilter/xt_iprange.c | 137 + net/netfilter/xt_ipvs.c | 191 + net/netfilter/xt_l2tp.c | 355 + net/netfilter/xt_length.c | 66 + net/netfilter/xt_limit.c | 215 + net/netfilter/xt_mac.c | 63 + net/netfilter/xt_mark.c | 82 + net/netfilter/xt_multiport.c | 176 + net/netfilter/xt_nat.c | 247 + net/netfilter/xt_nfacct.c | 93 + net/netfilter/xt_osf.c | 73 + net/netfilter/xt_owner.c | 158 + net/netfilter/xt_physdev.c | 139 + net/netfilter/xt_pkttype.c | 61 + net/netfilter/xt_policy.c | 189 + net/netfilter/xt_quota.c | 92 + net/netfilter/xt_rateest.c | 153 + net/netfilter/xt_realm.c | 51 + net/netfilter/xt_recent.c | 763 + net/netfilter/xt_repldata.h | 48 + net/netfilter/xt_sctp.c | 201 + net/netfilter/xt_set.c | 712 + net/netfilter/xt_socket.c | 336 + net/netfilter/xt_state.c | 75 + net/netfilter/xt_statistic.c | 99 + net/netfilter/xt_string.c | 93 + net/netfilter/xt_tcpmss.c | 107 + net/netfilter/xt_tcpudp.c | 232 + net/netfilter/xt_time.c | 300 + net/netfilter/xt_u32.c | 145 + net/netlabel/Kconfig | 19 + net/netlabel/Makefile | 16 + net/netlabel/netlabel_addrlist.c | 369 + net/netlabel/netlabel_addrlist.h | 194 + net/netlabel/netlabel_calipso.c | 735 + net/netlabel/netlabel_calipso.h | 137 + net/netlabel/netlabel_cipso_v4.c | 788 + net/netlabel/netlabel_cipso_v4.h | 155 + net/netlabel/netlabel_domainhash.c | 972 ++ net/netlabel/netlabel_domainhash.h | 106 + net/netlabel/netlabel_kapi.c | 1526 ++ net/netlabel/netlabel_mgmt.c | 847 ++ net/netlabel/netlabel_mgmt.h | 225 + net/netlabel/netlabel_unlabeled.c | 
1557 ++ net/netlabel/netlabel_unlabeled.h | 231 + net/netlabel/netlabel_user.c | 110 + net/netlabel/netlabel_user.h | 49 + net/netlink/Kconfig | 11 + net/netlink/Makefile | 9 + net/netlink/af_netlink.c | 2917 ++++ net/netlink/af_netlink.h | 78 + net/netlink/diag.c | 261 + net/netlink/genetlink.c | 1565 ++ net/netlink/policy.c | 479 + net/netrom/Makefile | 10 + net/netrom/af_netrom.c | 1537 ++ net/netrom/nr_dev.c | 178 + net/netrom/nr_in.c | 301 + net/netrom/nr_loopback.c | 73 + net/netrom/nr_out.c | 270 + net/netrom/nr_route.c | 983 ++ net/netrom/nr_subr.c | 279 + net/netrom/nr_timer.c | 248 + net/netrom/sysctl_net_netrom.c | 157 + net/nfc/Kconfig | 34 + net/nfc/Makefile | 14 + net/nfc/af_nfc.c | 88 + net/nfc/core.c | 1247 ++ net/nfc/digital.h | 171 + net/nfc/digital_core.c | 861 ++ net/nfc/digital_dep.c | 1633 ++ net/nfc/digital_technology.c | 1300 ++ net/nfc/hci/Kconfig | 18 + net/nfc/hci/Makefile | 9 + net/nfc/hci/command.c | 344 + net/nfc/hci/core.c | 1105 ++ net/nfc/hci/hci.h | 120 + net/nfc/hci/hcp.c | 136 + net/nfc/hci/llc.c | 150 + net/nfc/hci/llc.h | 56 + net/nfc/hci/llc_nop.c | 86 + net/nfc/hci/llc_shdlc.c | 818 + net/nfc/llcp.h | 252 + net/nfc/llcp_commands.c | 815 + net/nfc/llcp_core.c | 1695 +++ net/nfc/llcp_sock.c | 1061 ++ net/nfc/nci/Kconfig | 29 + net/nfc/nci/Makefile | 14 + net/nfc/nci/core.c | 1574 ++ net/nfc/nci/data.c | 301 + net/nfc/nci/hci.c | 791 + net/nfc/nci/lib.c | 73 + net/nfc/nci/ntf.c | 824 + net/nfc/nci/rsp.c | 428 + net/nfc/nci/spi.c | 322 + net/nfc/nci/uart.c | 461 + net/nfc/netlink.c | 1932 +++ net/nfc/nfc.h | 151 + net/nfc/rawsock.c | 418 + net/nsh/Kconfig | 10 + net/nsh/Makefile | 2 + net/nsh/nsh.c | 150 + net/openvswitch/Kconfig | 74 + net/openvswitch/Makefile | 29 + net/openvswitch/actions.c | 1629 ++ net/openvswitch/conntrack.c | 2326 +++ net/openvswitch/conntrack.h | 106 + net/openvswitch/datapath.c | 2762 ++++ net/openvswitch/datapath.h | 285 + net/openvswitch/dp_notify.c | 86 + net/openvswitch/flow.c | 1115 ++ net/openvswitch/flow.h | 298 + net/openvswitch/flow_netlink.c | 3783 +++++ net/openvswitch/flow_netlink.h | 73 + net/openvswitch/flow_table.c | 1222 ++ net/openvswitch/flow_table.h | 115 + net/openvswitch/meter.c | 768 + net/openvswitch/meter.h | 63 + net/openvswitch/openvswitch_trace.c | 10 + net/openvswitch/openvswitch_trace.h | 158 + net/openvswitch/vport-geneve.c | 140 + net/openvswitch/vport-gre.c | 103 + net/openvswitch/vport-internal_dev.c | 255 + net/openvswitch/vport-internal_dev.h | 17 + net/openvswitch/vport-netdev.c | 209 + net/openvswitch/vport-netdev.h | 23 + net/openvswitch/vport-vxlan.c | 169 + net/openvswitch/vport.c | 516 + net/openvswitch/vport.h | 193 + net/packet/Kconfig | 25 + net/packet/Makefile | 8 + net/packet/af_packet.c | 4768 ++++++ net/packet/diag.c | 265 + net/packet/internal.h | 167 + net/phonet/Kconfig | 17 + net/phonet/Makefile | 12 + net/phonet/af_phonet.c | 540 + net/phonet/datagram.c | 199 + net/phonet/pep-gprs.c | 306 + net/phonet/pep.c | 1366 ++ net/phonet/pn_dev.c | 419 + net/phonet/pn_netlink.c | 306 + net/phonet/socket.c | 774 + net/phonet/sysctl.c | 96 + net/psample/Kconfig | 15 + net/psample/Makefile | 6 + net/psample/psample.c | 513 + net/qrtr/Kconfig | 38 + net/qrtr/Makefile | 10 + net/qrtr/af_qrtr.c | 1324 ++ net/qrtr/mhi.c | 137 + net/qrtr/ns.c | 830 ++ net/qrtr/qrtr.h | 36 + net/qrtr/smd.c | 111 + net/qrtr/tun.c | 176 + net/rds/Kconfig | 28 + net/rds/Makefile | 17 + net/rds/af_rds.c | 963 ++ net/rds/bind.c | 283 + net/rds/cong.c | 428 + net/rds/connection.c | 948 ++ net/rds/ib.c | 607 + 
net/rds/ib.h | 458 + net/rds/ib_cm.c | 1287 ++ net/rds/ib_frmr.c | 446 + net/rds/ib_mr.h | 143 + net/rds/ib_rdma.c | 701 + net/rds/ib_recv.c | 1094 ++ net/rds/ib_ring.c | 168 + net/rds/ib_send.c | 1017 ++ net/rds/ib_stats.c | 107 + net/rds/ib_sysctl.c | 121 + net/rds/info.c | 242 + net/rds/info.h | 31 + net/rds/loop.c | 254 + net/rds/loop.h | 12 + net/rds/message.c | 521 + net/rds/page.c | 167 + net/rds/rdma.c | 958 ++ net/rds/rdma_transport.c | 322 + net/rds/rdma_transport.h | 31 + net/rds/rds.h | 1019 ++ net/rds/rds_single_path.h | 31 + net/rds/recv.c | 831 ++ net/rds/send.c | 1515 ++ net/rds/stats.c | 155 + net/rds/sysctl.c | 110 + net/rds/tcp.c | 754 + net/rds/tcp.h | 98 + net/rds/tcp_connect.c | 229 + net/rds/tcp_listen.c | 348 + net/rds/tcp_recv.c | 349 + net/rds/tcp_send.c | 225 + net/rds/tcp_stats.c | 74 + net/rds/threads.c | 311 + net/rds/transport.c | 169 + net/rfkill/Kconfig | 34 + net/rfkill/Makefile | 9 + net/rfkill/core.c | 1444 ++ net/rfkill/input.c | 343 + net/rfkill/rfkill-gpio.c | 181 + net/rfkill/rfkill.h | 23 + net/rose/Makefile | 10 + net/rose/af_rose.c | 1674 +++ net/rose/rose_dev.c | 141 + net/rose/rose_in.c | 294 + net/rose/rose_link.c | 289 + net/rose/rose_loopback.c | 133 + net/rose/rose_out.c | 122 + net/rose/rose_route.c | 1326 ++ net/rose/rose_subr.c | 556 + net/rose/rose_timer.c | 212 + net/rose/sysctl_net_rose.c | 126 + net/rxrpc/Kconfig | 61 + net/rxrpc/Makefile | 37 + net/rxrpc/af_rxrpc.c | 1078 ++ net/rxrpc/ar-internal.h | 1258 ++ net/rxrpc/call_accept.c | 491 + net/rxrpc/call_event.c | 447 + net/rxrpc/call_object.c | 737 + net/rxrpc/conn_client.c | 1131 ++ net/rxrpc/conn_event.c | 499 + net/rxrpc/conn_object.c | 487 + net/rxrpc/conn_service.c | 203 + net/rxrpc/input.c | 1502 ++ net/rxrpc/insecure.c | 102 + net/rxrpc/key.c | 695 + net/rxrpc/local_event.c | 115 + net/rxrpc/local_object.c | 474 + net/rxrpc/misc.c | 76 + net/rxrpc/net_ns.c | 135 + net/rxrpc/output.c | 679 + net/rxrpc/peer_event.c | 636 + net/rxrpc/peer_object.c | 528 + net/rxrpc/proc.c | 399 + net/rxrpc/protocol.h | 186 + net/rxrpc/recvmsg.c | 773 + net/rxrpc/rtt.c | 195 + net/rxrpc/rxkad.c | 1405 ++ net/rxrpc/security.c | 190 + net/rxrpc/sendmsg.c | 882 ++ net/rxrpc/server_key.c | 146 + net/rxrpc/skbuff.c | 95 + net/rxrpc/sysctl.c | 138 + net/rxrpc/utils.c | 44 + net/sched/Kconfig | 988 ++ net/sched/Makefile | 86 + net/sched/act_api.c | 2189 +++ net/sched/act_bpf.c | 437 + net/sched/act_connmark.c | 245 + net/sched/act_csum.c | 745 + net/sched/act_ct.c | 1763 +++ net/sched/act_ctinfo.c | 398 + net/sched/act_gact.c | 338 + net/sched/act_gate.c | 676 + net/sched/act_ife.c | 925 ++ net/sched/act_ipt.c | 452 + net/sched/act_meta_mark.c | 73 + net/sched/act_meta_skbprio.c | 71 + net/sched/act_meta_skbtcindex.c | 73 + net/sched/act_mirred.c | 551 + net/sched/act_mpls.c | 489 + net/sched/act_nat.c | 334 + net/sched/act_pedit.c | 627 + net/sched/act_police.c | 533 + net/sched/act_sample.c | 352 + net/sched/act_simple.c | 248 + net/sched/act_skbedit.c | 452 + net/sched/act_skbmod.c | 323 + net/sched/act_tunnel_key.c | 873 ++ net/sched/act_vlan.c | 463 + net/sched/cls_api.c | 3785 +++++ net/sched/cls_basic.c | 342 + net/sched/cls_bpf.c | 706 + net/sched/cls_cgroup.c | 223 + net/sched/cls_flow.c | 719 + net/sched/cls_flower.c | 3474 +++++ net/sched/cls_fw.c | 447 + net/sched/cls_matchall.c | 419 + net/sched/cls_route.c | 680 + net/sched/cls_u32.c | 1490 ++ net/sched/em_canid.c | 230 + net/sched/em_cmp.c | 95 + net/sched/em_ipset.c | 134 + net/sched/em_ipt.c | 297 + net/sched/em_meta.c | 1014 ++ 
net/sched/em_nbyte.c | 76 + net/sched/em_text.c | 155 + net/sched/em_u32.c | 60 + net/sched/ematch.c | 550 + net/sched/sch_api.c | 2389 +++ net/sched/sch_atm.c | 706 + net/sched/sch_blackhole.c | 41 + net/sched/sch_cake.c | 3120 ++++ net/sched/sch_cbq.c | 1727 +++ net/sched/sch_cbs.c | 576 + net/sched/sch_choke.c | 515 + net/sched/sch_codel.c | 307 + net/sched/sch_drr.c | 496 + net/sched/sch_dsmark.c | 518 + net/sched/sch_etf.c | 515 + net/sched/sch_ets.c | 828 ++ net/sched/sch_fifo.c | 271 + net/sched/sch_fq.c | 1079 ++ net/sched/sch_fq_codel.c | 735 + net/sched/sch_fq_pie.c | 582 + net/sched/sch_frag.c | 152 + net/sched/sch_generic.c | 1604 ++ net/sched/sch_gred.c | 947 ++ net/sched/sch_hfsc.c | 1695 +++ net/sched/sch_hhf.c | 721 + net/sched/sch_htb.c | 2178 +++ net/sched/sch_ingress.c | 319 + net/sched/sch_mq.c | 275 + net/sched/sch_mqprio.c | 664 + net/sched/sch_multiq.c | 412 + net/sched/sch_netem.c | 1289 ++ net/sched/sch_pie.c | 576 + net/sched/sch_plug.c | 228 + net/sched/sch_prio.c | 435 + net/sched/sch_qfq.c | 1536 ++ net/sched/sch_red.c | 565 + net/sched/sch_sfb.c | 729 + net/sched/sch_sfq.c | 939 ++ net/sched/sch_skbprio.c | 309 + net/sched/sch_taprio.c | 2171 +++ net/sched/sch_tbf.c | 622 + net/sched/sch_teql.c | 525 + net/sctp/Kconfig | 97 + net/sctp/Makefile | 24 + net/sctp/associola.c | 1729 +++ net/sctp/auth.c | 1089 ++ net/sctp/bind_addr.c | 575 + net/sctp/chunk.c | 353 + net/sctp/debug.c | 170 + net/sctp/diag.c | 529 + net/sctp/endpointola.c | 417 + net/sctp/input.c | 1349 ++ net/sctp/inqueue.c | 237 + net/sctp/ipv6.c | 1217 ++ net/sctp/objcnt.c | 105 + net/sctp/offload.c | 120 + net/sctp/output.c | 865 ++ net/sctp/outqueue.c | 1922 +++ net/sctp/primitive.c | 201 + net/sctp/proc.c | 399 + net/sctp/protocol.c | 1728 +++ net/sctp/sm_make_chunk.c | 3937 +++++ net/sctp/sm_sideeffect.c | 1825 +++ net/sctp/sm_statefuns.c | 6677 +++++++++ net/sctp/sm_statetable.c | 1041 ++ net/sctp/socket.c | 9745 ++++++++++++ net/sctp/stream.c | 1087 ++ net/sctp/stream_interleave.c | 1359 ++ net/sctp/stream_sched.c | 275 + net/sctp/stream_sched_prio.c | 346 + net/sctp/stream_sched_rr.c | 196 + net/sctp/sysctl.c | 633 + net/sctp/transport.c | 854 ++ net/sctp/tsnmap.c | 364 + net/sctp/ulpevent.c | 1190 ++ net/sctp/ulpqueue.c | 1132 ++ net/smc/Kconfig | 21 + net/smc/Makefile | 8 + net/smc/af_smc.c | 3563 +++++ net/smc/smc.h | 385 + net/smc/smc_cdc.c | 493 + net/smc/smc_cdc.h | 305 + net/smc/smc_clc.c | 1177 ++ net/smc/smc_clc.h | 401 + net/smc/smc_close.c | 506 + net/smc/smc_close.h | 30 + net/smc/smc_core.c | 2617 ++++ net/smc/smc_core.h | 566 + net/smc/smc_diag.c | 270 + net/smc/smc_ib.c | 1018 ++ net/smc/smc_ib.h | 120 + net/smc/smc_ism.c | 534 + net/smc/smc_ism.h | 58 + net/smc/smc_llc.c | 2348 +++ net/smc/smc_llc.h | 120 + net/smc/smc_netlink.c | 157 + net/smc/smc_netlink.h | 34 + net/smc/smc_netns.h | 21 + net/smc/smc_pnet.c | 1206 ++ net/smc/smc_pnet.h | 70 + net/smc/smc_rx.c | 511 + net/smc/smc_rx.h | 31 + net/smc/smc_stats.c | 413 + net/smc/smc_stats.h | 271 + net/smc/smc_sysctl.c | 117 + net/smc/smc_sysctl.h | 33 + net/smc/smc_tracepoint.c | 9 + net/smc/smc_tracepoint.h | 125 + net/smc/smc_tx.c | 779 + net/smc/smc_tx.h | 42 + net/smc/smc_wr.c | 918 ++ net/smc/smc_wr.h | 140 + net/socket.c | 3693 +++++ net/strparser/Kconfig | 3 + net/strparser/Makefile | 2 + net/strparser/strparser.c | 545 + net/sunrpc/Kconfig | 80 + net/sunrpc/Makefile | 21 + net/sunrpc/addr.c | 354 + net/sunrpc/auth.c | 891 ++ net/sunrpc/auth_gss/Makefile | 15 + net/sunrpc/auth_gss/auth_gss.c | 2281 +++ 
net/sunrpc/auth_gss/auth_gss_internal.h | 45 + net/sunrpc/auth_gss/gss_generic_token.c | 231 + net/sunrpc/auth_gss/gss_krb5_crypto.c | 810 + net/sunrpc/auth_gss/gss_krb5_keys.c | 322 + net/sunrpc/auth_gss/gss_krb5_mech.c | 654 + net/sunrpc/auth_gss/gss_krb5_seal.c | 222 + net/sunrpc/auth_gss/gss_krb5_seqnum.c | 104 + net/sunrpc/auth_gss/gss_krb5_unseal.c | 226 + net/sunrpc/auth_gss/gss_krb5_wrap.c | 596 + net/sunrpc/auth_gss/gss_mech_switch.c | 448 + net/sunrpc/auth_gss/gss_rpc_upcall.c | 403 + net/sunrpc/auth_gss/gss_rpc_upcall.h | 36 + net/sunrpc/auth_gss/gss_rpc_xdr.c | 838 ++ net/sunrpc/auth_gss/gss_rpc_xdr.h | 252 + net/sunrpc/auth_gss/svcauth_gss.c | 2017 +++ net/sunrpc/auth_gss/trace.c | 14 + net/sunrpc/auth_null.c | 143 + net/sunrpc/auth_unix.c | 243 + net/sunrpc/backchannel_rqst.c | 376 + net/sunrpc/cache.c | 1918 +++ net/sunrpc/clnt.c | 3363 +++++ net/sunrpc/debugfs.c | 294 + net/sunrpc/fail.h | 25 + net/sunrpc/netns.h | 43 + net/sunrpc/rpc_pipe.c | 1517 ++ net/sunrpc/rpcb_clnt.c | 1098 ++ net/sunrpc/sched.c | 1361 ++ net/sunrpc/socklib.c | 324 + net/sunrpc/socklib.h | 15 + net/sunrpc/stats.c | 343 + net/sunrpc/sunrpc.h | 42 + net/sunrpc/sunrpc_syms.c | 153 + net/sunrpc/svc.c | 1698 +++ net/sunrpc/svc_xprt.c | 1484 ++ net/sunrpc/svcauth.c | 234 + net/sunrpc/svcauth_unix.c | 987 ++ net/sunrpc/svcsock.c | 1521 ++ net/sunrpc/sysctl.c | 193 + net/sunrpc/sysfs.c | 626 + net/sunrpc/sysfs.h | 42 + net/sunrpc/timer.c | 123 + net/sunrpc/xdr.c | 2272 +++ net/sunrpc/xprt.c | 2192 +++ net/sunrpc/xprtmultipath.c | 655 + net/sunrpc/xprtrdma/Makefile | 8 + net/sunrpc/xprtrdma/backchannel.c | 282 + net/sunrpc/xprtrdma/frwr_ops.c | 696 + net/sunrpc/xprtrdma/module.c | 52 + net/sunrpc/xprtrdma/rpc_rdma.c | 1510 ++ net/sunrpc/xprtrdma/svc_rdma.c | 300 + net/sunrpc/xprtrdma/svc_rdma_backchannel.c | 293 + net/sunrpc/xprtrdma/svc_rdma_pcl.c | 306 + net/sunrpc/xprtrdma/svc_rdma_recvfrom.c | 868 ++ net/sunrpc/xprtrdma/svc_rdma_rw.c | 1165 ++ net/sunrpc/xprtrdma/svc_rdma_sendto.c | 1044 ++ net/sunrpc/xprtrdma/svc_rdma_transport.c | 610 + net/sunrpc/xprtrdma/transport.c | 805 + net/sunrpc/xprtrdma/verbs.c | 1396 ++ net/sunrpc/xprtrdma/xprt_rdma.h | 605 + net/sunrpc/xprtsock.c | 3257 ++++ net/switchdev/Kconfig | 14 + net/switchdev/Makefile | 6 + net/switchdev/switchdev.c | 864 ++ net/sysctl_net.c | 177 + net/tipc/Kconfig | 61 + net/tipc/Makefile | 22 + net/tipc/addr.c | 124 + net/tipc/addr.h | 136 + net/tipc/bcast.c | 864 ++ net/tipc/bcast.h | 127 + net/tipc/bearer.c | 1372 ++ net/tipc/bearer.h | 268 + net/tipc/core.c | 229 + net/tipc/core.h | 228 + net/tipc/crypto.c | 2475 +++ net/tipc/crypto.h | 200 + net/tipc/diag.c | 116 + net/tipc/discover.c | 420 + net/tipc/discover.h | 51 + net/tipc/eth_media.c | 98 + net/tipc/group.c | 959 ++ net/tipc/group.h | 77 + net/tipc/ib_media.c | 104 + net/tipc/link.c | 3009 ++++ net/tipc/link.h | 160 + net/tipc/monitor.c | 874 ++ net/tipc/monitor.h | 83 + net/tipc/msg.c | 851 ++ net/tipc/msg.h | 1310 ++ net/tipc/name_distr.c | 411 + net/tipc/name_distr.h | 80 + net/tipc/name_table.c | 1204 ++ net/tipc/name_table.h | 155 + net/tipc/net.c | 345 + net/tipc/net.h | 53 + net/tipc/netlink.c | 315 + net/tipc/netlink.h | 64 + net/tipc/netlink_compat.c | 1380 ++ net/tipc/node.c | 3166 ++++ net/tipc/node.h | 131 + net/tipc/socket.c | 4019 +++++ net/tipc/socket.h | 80 + net/tipc/subscr.c | 183 + net/tipc/subscr.h | 122 + net/tipc/sysctl.c | 108 + net/tipc/topsrv.c | 726 + net/tipc/topsrv.h | 54 + net/tipc/trace.c | 206 + net/tipc/trace.h | 434 + net/tipc/udp_media.c | 859 ++ 
net/tipc/udp_media.h | 60 + net/tls/Kconfig | 39 + net/tls/Makefile | 13 + net/tls/tls.h | 328 + net/tls/tls_device.c | 1487 ++ net/tls/tls_device_fallback.c | 513 + net/tls/tls_main.c | 1246 ++ net/tls/tls_proc.c | 56 + net/tls/tls_strp.c | 633 + net/tls/tls_sw.c | 2818 ++++ net/tls/tls_toe.c | 141 + net/tls/trace.c | 10 + net/tls/trace.h | 202 + net/unix/Kconfig | 39 + net/unix/Makefile | 15 + net/unix/af_unix.c | 3785 +++++ net/unix/diag.c | 342 + net/unix/garbage.c | 335 + net/unix/scm.c | 154 + net/unix/scm.h | 10 + net/unix/sysctl_net_unix.c | 60 + net/unix/unix_bpf.c | 198 + net/vmw_vsock/Kconfig | 83 + net/vmw_vsock/Makefile | 21 + net/vmw_vsock/af_vsock.c | 2462 +++ net/vmw_vsock/af_vsock_tap.c | 110 + net/vmw_vsock/diag.c | 178 + net/vmw_vsock/hyperv_transport.c | 956 ++ net/vmw_vsock/virtio_transport.c | 821 + net/vmw_vsock/virtio_transport_common.c | 1415 ++ net/vmw_vsock/vmci_transport.c | 2147 +++ net/vmw_vsock/vmci_transport.h | 133 + net/vmw_vsock/vmci_transport_notify.c | 672 + net/vmw_vsock/vmci_transport_notify.h | 75 + net/vmw_vsock/vmci_transport_notify_qstate.c | 430 + net/vmw_vsock/vsock_addr.c | 69 + net/vmw_vsock/vsock_loopback.c | 165 + net/wireless/.gitignore | 3 + net/wireless/Kconfig | 237 + net/wireless/Makefile | 60 + net/wireless/ap.c | 88 + net/wireless/certs/sforshee.hex | 86 + net/wireless/certs/wens.hex | 87 + net/wireless/chan.c | 1462 ++ net/wireless/core.c | 1763 +++ net/wireless/core.h | 582 + net/wireless/debugfs.c | 111 + net/wireless/debugfs.h | 12 + net/wireless/ethtool.c | 29 + net/wireless/ibss.c | 542 + net/wireless/lib80211.c | 257 + net/wireless/lib80211_crypt_ccmp.c | 448 + net/wireless/lib80211_crypt_tkip.c | 738 + net/wireless/lib80211_crypt_wep.c | 256 + net/wireless/mesh.c | 302 + net/wireless/mlme.c | 1160 ++ net/wireless/nl80211.c | 19837 +++++++++++++++++++++++++ net/wireless/nl80211.h | 124 + net/wireless/ocb.c | 92 + net/wireless/of.c | 138 + net/wireless/pmsr.c | 661 + net/wireless/radiotap.c | 370 + net/wireless/rdev-ops.h | 1497 ++ net/wireless/reg.c | 4410 ++++++ net/wireless/reg.h | 189 + net/wireless/scan.c | 3177 ++++ net/wireless/sme.c | 1598 ++ net/wireless/sysfs.c | 182 + net/wireless/sysfs.h | 10 + net/wireless/trace.c | 7 + net/wireless/trace.h | 3910 +++++ net/wireless/util.c | 2517 ++++ net/wireless/wext-compat.c | 1665 +++ net/wireless/wext-compat.h | 64 + net/wireless/wext-core.c | 1198 ++ net/wireless/wext-priv.c | 249 + net/wireless/wext-proc.c | 142 + net/wireless/wext-sme.c | 410 + net/wireless/wext-spy.c | 232 + net/x25/Kconfig | 34 + net/x25/Makefile | 11 + net/x25/af_x25.c | 1856 +++ net/x25/sysctl_net_x25.c | 88 + net/x25/x25_dev.c | 216 + net/x25/x25_facilities.c | 350 + net/x25/x25_forward.c | 157 + net/x25/x25_in.c | 456 + net/x25/x25_link.c | 421 + net/x25/x25_out.c | 226 + net/x25/x25_proc.c | 206 + net/x25/x25_route.c | 204 + net/x25/x25_subr.c | 384 + net/x25/x25_timer.c | 169 + net/xdp/Kconfig | 16 + net/xdp/Makefile | 4 + net/xdp/xdp_umem.c | 259 + net/xdp/xdp_umem.h | 15 + net/xdp/xsk.c | 1524 ++ net/xdp/xsk.h | 48 + net/xdp/xsk_buff_pool.c | 681 + net/xdp/xsk_diag.c | 214 + net/xdp/xsk_queue.c | 57 + net/xdp/xsk_queue.h | 432 + net/xdp/xskmap.c | 272 + net/xfrm/Kconfig | 140 + net/xfrm/Makefile | 17 + net/xfrm/espintcp.c | 579 + net/xfrm/xfrm_algo.c | 866 ++ net/xfrm/xfrm_compat.c | 675 + net/xfrm/xfrm_device.c | 444 + net/xfrm/xfrm_hash.c | 40 + net/xfrm/xfrm_hash.h | 199 + net/xfrm/xfrm_inout.h | 70 + net/xfrm/xfrm_input.c | 834 ++ net/xfrm/xfrm_interface_core.c | 1242 ++ net/xfrm/xfrm_ipcomp.c 
| 378 + net/xfrm/xfrm_output.c | 909 ++ net/xfrm/xfrm_policy.c | 4490 ++++++ net/xfrm/xfrm_proc.c | 75 + net/xfrm/xfrm_replay.c | 795 + net/xfrm/xfrm_state.c | 2933 ++++ net/xfrm/xfrm_sysctl.c | 87 + net/xfrm/xfrm_user.c | 3830 +++++ 1892 files changed, 1241503 insertions(+) create mode 100644 net/6lowpan/6lowpan_i.h create mode 100644 net/6lowpan/Kconfig create mode 100644 net/6lowpan/Makefile create mode 100644 net/6lowpan/core.c create mode 100644 net/6lowpan/debugfs.c create mode 100644 net/6lowpan/iphc.c create mode 100644 net/6lowpan/ndisc.c create mode 100644 net/6lowpan/nhc.c create mode 100644 net/6lowpan/nhc.h create mode 100644 net/6lowpan/nhc_dest.c create mode 100644 net/6lowpan/nhc_fragment.c create mode 100644 net/6lowpan/nhc_ghc_ext_dest.c create mode 100644 net/6lowpan/nhc_ghc_ext_frag.c create mode 100644 net/6lowpan/nhc_ghc_ext_hop.c create mode 100644 net/6lowpan/nhc_ghc_ext_route.c create mode 100644 net/6lowpan/nhc_ghc_icmpv6.c create mode 100644 net/6lowpan/nhc_ghc_udp.c create mode 100644 net/6lowpan/nhc_hop.c create mode 100644 net/6lowpan/nhc_ipv6.c create mode 100644 net/6lowpan/nhc_mobility.c create mode 100644 net/6lowpan/nhc_routing.c create mode 100644 net/6lowpan/nhc_udp.c create mode 100644 net/802/Kconfig create mode 100644 net/802/Makefile create mode 100644 net/802/fc.c create mode 100644 net/802/fddi.c create mode 100644 net/802/garp.c create mode 100644 net/802/hippi.c create mode 100644 net/802/mrp.c create mode 100644 net/802/p8022.c create mode 100644 net/802/psnap.c create mode 100644 net/802/stp.c create mode 100644 net/8021q/Kconfig create mode 100644 net/8021q/Makefile create mode 100644 net/8021q/vlan.c create mode 100644 net/8021q/vlan.h create mode 100644 net/8021q/vlan_core.c create mode 100644 net/8021q/vlan_dev.c create mode 100644 net/8021q/vlan_gvrp.c create mode 100644 net/8021q/vlan_mvrp.c create mode 100644 net/8021q/vlan_netlink.c create mode 100644 net/8021q/vlanproc.c create mode 100644 net/8021q/vlanproc.h create mode 100644 net/9p/Kconfig create mode 100644 net/9p/Makefile create mode 100644 net/9p/client.c create mode 100644 net/9p/error.c create mode 100644 net/9p/mod.c create mode 100644 net/9p/protocol.c create mode 100644 net/9p/protocol.h create mode 100644 net/9p/trans_common.c create mode 100644 net/9p/trans_common.h create mode 100644 net/9p/trans_fd.c create mode 100644 net/9p/trans_rdma.c create mode 100644 net/9p/trans_virtio.c create mode 100644 net/9p/trans_xen.c create mode 100644 net/Kconfig create mode 100644 net/Kconfig.debug create mode 100644 net/Makefile create mode 100644 net/appletalk/Makefile create mode 100644 net/appletalk/aarp.c create mode 100644 net/appletalk/atalk_proc.c create mode 100644 net/appletalk/ddp.c create mode 100644 net/appletalk/dev.c create mode 100644 net/appletalk/sysctl_net_atalk.c create mode 100644 net/atm/Kconfig create mode 100644 net/atm/Makefile create mode 100644 net/atm/addr.c create mode 100644 net/atm/addr.h create mode 100644 net/atm/atm_misc.c create mode 100644 net/atm/atm_sysfs.c create mode 100644 net/atm/br2684.c create mode 100644 net/atm/clip.c create mode 100644 net/atm/common.c create mode 100644 net/atm/common.h create mode 100644 net/atm/ioctl.c create mode 100644 net/atm/lec.c create mode 100644 net/atm/lec.h create mode 100644 net/atm/lec_arpc.h create mode 100644 net/atm/mpc.c create mode 100644 net/atm/mpc.h create mode 100644 net/atm/mpoa_caches.c create mode 100644 net/atm/mpoa_caches.h create mode 100644 net/atm/mpoa_proc.c create mode 100644 
net/atm/pppoatm.c create mode 100644 net/atm/proc.c create mode 100644 net/atm/protocols.h create mode 100644 net/atm/pvc.c create mode 100644 net/atm/raw.c create mode 100644 net/atm/resources.c create mode 100644 net/atm/resources.h create mode 100644 net/atm/signaling.c create mode 100644 net/atm/signaling.h create mode 100644 net/atm/svc.c create mode 100644 net/ax25/Kconfig create mode 100644 net/ax25/Makefile create mode 100644 net/ax25/af_ax25.c create mode 100644 net/ax25/ax25_addr.c create mode 100644 net/ax25/ax25_dev.c create mode 100644 net/ax25/ax25_ds_in.c create mode 100644 net/ax25/ax25_ds_subr.c create mode 100644 net/ax25/ax25_ds_timer.c create mode 100644 net/ax25/ax25_iface.c create mode 100644 net/ax25/ax25_in.c create mode 100644 net/ax25/ax25_ip.c create mode 100644 net/ax25/ax25_out.c create mode 100644 net/ax25/ax25_route.c create mode 100644 net/ax25/ax25_std_in.c create mode 100644 net/ax25/ax25_std_subr.c create mode 100644 net/ax25/ax25_std_timer.c create mode 100644 net/ax25/ax25_subr.c create mode 100644 net/ax25/ax25_timer.c create mode 100644 net/ax25/ax25_uid.c create mode 100644 net/ax25/sysctl_net_ax25.c create mode 100644 net/batman-adv/Kconfig create mode 100644 net/batman-adv/Makefile create mode 100644 net/batman-adv/bat_algo.c create mode 100644 net/batman-adv/bat_algo.h create mode 100644 net/batman-adv/bat_iv_ogm.c create mode 100644 net/batman-adv/bat_iv_ogm.h create mode 100644 net/batman-adv/bat_v.c create mode 100644 net/batman-adv/bat_v.h create mode 100644 net/batman-adv/bat_v_elp.c create mode 100644 net/batman-adv/bat_v_elp.h create mode 100644 net/batman-adv/bat_v_ogm.c create mode 100644 net/batman-adv/bat_v_ogm.h create mode 100644 net/batman-adv/bitarray.c create mode 100644 net/batman-adv/bitarray.h create mode 100644 net/batman-adv/bridge_loop_avoidance.c create mode 100644 net/batman-adv/bridge_loop_avoidance.h create mode 100644 net/batman-adv/distributed-arp-table.c create mode 100644 net/batman-adv/distributed-arp-table.h create mode 100644 net/batman-adv/fragmentation.c create mode 100644 net/batman-adv/fragmentation.h create mode 100644 net/batman-adv/gateway_client.c create mode 100644 net/batman-adv/gateway_client.h create mode 100644 net/batman-adv/gateway_common.c create mode 100644 net/batman-adv/gateway_common.h create mode 100644 net/batman-adv/hard-interface.c create mode 100644 net/batman-adv/hard-interface.h create mode 100644 net/batman-adv/hash.c create mode 100644 net/batman-adv/hash.h create mode 100644 net/batman-adv/log.c create mode 100644 net/batman-adv/log.h create mode 100644 net/batman-adv/main.c create mode 100644 net/batman-adv/main.h create mode 100644 net/batman-adv/multicast.c create mode 100644 net/batman-adv/multicast.h create mode 100644 net/batman-adv/netlink.c create mode 100644 net/batman-adv/netlink.h create mode 100644 net/batman-adv/network-coding.c create mode 100644 net/batman-adv/network-coding.h create mode 100644 net/batman-adv/originator.c create mode 100644 net/batman-adv/originator.h create mode 100644 net/batman-adv/routing.c create mode 100644 net/batman-adv/routing.h create mode 100644 net/batman-adv/send.c create mode 100644 net/batman-adv/send.h create mode 100644 net/batman-adv/soft-interface.c create mode 100644 net/batman-adv/soft-interface.h create mode 100644 net/batman-adv/tp_meter.c create mode 100644 net/batman-adv/tp_meter.h create mode 100644 net/batman-adv/trace.c create mode 100644 net/batman-adv/trace.h create mode 100644 net/batman-adv/translation-table.c create mode 
100644 net/batman-adv/translation-table.h create mode 100644 net/batman-adv/tvlv.c create mode 100644 net/batman-adv/tvlv.h create mode 100644 net/batman-adv/types.h create mode 100644 net/bluetooth/6lowpan.c create mode 100644 net/bluetooth/Kconfig create mode 100644 net/bluetooth/Makefile create mode 100644 net/bluetooth/a2mp.c create mode 100644 net/bluetooth/a2mp.h create mode 100644 net/bluetooth/af_bluetooth.c create mode 100644 net/bluetooth/amp.c create mode 100644 net/bluetooth/amp.h create mode 100644 net/bluetooth/aosp.c create mode 100644 net/bluetooth/aosp.h create mode 100644 net/bluetooth/bnep/Kconfig create mode 100644 net/bluetooth/bnep/Makefile create mode 100644 net/bluetooth/bnep/bnep.h create mode 100644 net/bluetooth/bnep/core.c create mode 100644 net/bluetooth/bnep/netdev.c create mode 100644 net/bluetooth/bnep/sock.c create mode 100644 net/bluetooth/cmtp/Kconfig create mode 100644 net/bluetooth/cmtp/Makefile create mode 100644 net/bluetooth/cmtp/capi.c create mode 100644 net/bluetooth/cmtp/cmtp.h create mode 100644 net/bluetooth/cmtp/core.c create mode 100644 net/bluetooth/cmtp/sock.c create mode 100644 net/bluetooth/ecdh_helper.c create mode 100644 net/bluetooth/ecdh_helper.h create mode 100644 net/bluetooth/eir.c create mode 100644 net/bluetooth/eir.h create mode 100644 net/bluetooth/hci_codec.c create mode 100644 net/bluetooth/hci_codec.h create mode 100644 net/bluetooth/hci_conn.c create mode 100644 net/bluetooth/hci_core.c create mode 100644 net/bluetooth/hci_debugfs.c create mode 100644 net/bluetooth/hci_debugfs.h create mode 100644 net/bluetooth/hci_event.c create mode 100644 net/bluetooth/hci_request.c create mode 100644 net/bluetooth/hci_request.h create mode 100644 net/bluetooth/hci_sock.c create mode 100644 net/bluetooth/hci_sync.c create mode 100644 net/bluetooth/hci_sysfs.c create mode 100644 net/bluetooth/hidp/Kconfig create mode 100644 net/bluetooth/hidp/Makefile create mode 100644 net/bluetooth/hidp/core.c create mode 100644 net/bluetooth/hidp/hidp.h create mode 100644 net/bluetooth/hidp/sock.c create mode 100644 net/bluetooth/iso.c create mode 100644 net/bluetooth/l2cap_core.c create mode 100644 net/bluetooth/l2cap_sock.c create mode 100644 net/bluetooth/leds.c create mode 100644 net/bluetooth/leds.h create mode 100644 net/bluetooth/lib.c create mode 100644 net/bluetooth/mgmt.c create mode 100644 net/bluetooth/mgmt_config.c create mode 100644 net/bluetooth/mgmt_config.h create mode 100644 net/bluetooth/mgmt_util.c create mode 100644 net/bluetooth/mgmt_util.h create mode 100644 net/bluetooth/msft.c create mode 100644 net/bluetooth/msft.h create mode 100644 net/bluetooth/rfcomm/Kconfig create mode 100644 net/bluetooth/rfcomm/Makefile create mode 100644 net/bluetooth/rfcomm/core.c create mode 100644 net/bluetooth/rfcomm/sock.c create mode 100644 net/bluetooth/rfcomm/tty.c create mode 100644 net/bluetooth/sco.c create mode 100644 net/bluetooth/selftest.c create mode 100644 net/bluetooth/selftest.h create mode 100644 net/bluetooth/smp.c create mode 100644 net/bluetooth/smp.h create mode 100644 net/bpf/Makefile create mode 100644 net/bpf/bpf_dummy_struct_ops.c create mode 100644 net/bpf/test_run.c create mode 100644 net/bpfilter/.gitignore create mode 100644 net/bpfilter/Kconfig create mode 100644 net/bpfilter/Makefile create mode 100644 net/bpfilter/bpfilter_kern.c create mode 100644 net/bpfilter/bpfilter_umh_blob.S create mode 100644 net/bpfilter/main.c create mode 100644 net/bpfilter/msgfmt.h create mode 100644 net/bridge/Kconfig create mode 100644 
net/bridge/Makefile create mode 100644 net/bridge/br.c create mode 100644 net/bridge/br_arp_nd_proxy.c create mode 100644 net/bridge/br_cfm.c create mode 100644 net/bridge/br_cfm_netlink.c create mode 100644 net/bridge/br_device.c create mode 100644 net/bridge/br_fdb.c create mode 100644 net/bridge/br_forward.c create mode 100644 net/bridge/br_if.c create mode 100644 net/bridge/br_input.c create mode 100644 net/bridge/br_ioctl.c create mode 100644 net/bridge/br_mdb.c create mode 100644 net/bridge/br_mrp.c create mode 100644 net/bridge/br_mrp_netlink.c create mode 100644 net/bridge/br_mrp_switchdev.c create mode 100644 net/bridge/br_mst.c create mode 100644 net/bridge/br_multicast.c create mode 100644 net/bridge/br_multicast_eht.c create mode 100644 net/bridge/br_netfilter_hooks.c create mode 100644 net/bridge/br_netfilter_ipv6.c create mode 100644 net/bridge/br_netlink.c create mode 100644 net/bridge/br_netlink_tunnel.c create mode 100644 net/bridge/br_nf_core.c create mode 100644 net/bridge/br_private.h create mode 100644 net/bridge/br_private_cfm.h create mode 100644 net/bridge/br_private_mcast_eht.h create mode 100644 net/bridge/br_private_mrp.h create mode 100644 net/bridge/br_private_stp.h create mode 100644 net/bridge/br_private_tunnel.h create mode 100644 net/bridge/br_stp.c create mode 100644 net/bridge/br_stp_bpdu.c create mode 100644 net/bridge/br_stp_if.c create mode 100644 net/bridge/br_stp_timer.c create mode 100644 net/bridge/br_switchdev.c create mode 100644 net/bridge/br_sysfs_br.c create mode 100644 net/bridge/br_sysfs_if.c create mode 100644 net/bridge/br_vlan.c create mode 100644 net/bridge/br_vlan_options.c create mode 100644 net/bridge/br_vlan_tunnel.c create mode 100644 net/bridge/netfilter/Kconfig create mode 100644 net/bridge/netfilter/Makefile create mode 100644 net/bridge/netfilter/ebt_802_3.c create mode 100644 net/bridge/netfilter/ebt_among.c create mode 100644 net/bridge/netfilter/ebt_arp.c create mode 100644 net/bridge/netfilter/ebt_arpreply.c create mode 100644 net/bridge/netfilter/ebt_dnat.c create mode 100644 net/bridge/netfilter/ebt_ip.c create mode 100644 net/bridge/netfilter/ebt_ip6.c create mode 100644 net/bridge/netfilter/ebt_limit.c create mode 100644 net/bridge/netfilter/ebt_log.c create mode 100644 net/bridge/netfilter/ebt_mark.c create mode 100644 net/bridge/netfilter/ebt_mark_m.c create mode 100644 net/bridge/netfilter/ebt_nflog.c create mode 100644 net/bridge/netfilter/ebt_pkttype.c create mode 100644 net/bridge/netfilter/ebt_redirect.c create mode 100644 net/bridge/netfilter/ebt_snat.c create mode 100644 net/bridge/netfilter/ebt_stp.c create mode 100644 net/bridge/netfilter/ebt_vlan.c create mode 100644 net/bridge/netfilter/ebtable_broute.c create mode 100644 net/bridge/netfilter/ebtable_filter.c create mode 100644 net/bridge/netfilter/ebtable_nat.c create mode 100644 net/bridge/netfilter/ebtables.c create mode 100644 net/bridge/netfilter/nf_conntrack_bridge.c create mode 100644 net/bridge/netfilter/nft_meta_bridge.c create mode 100644 net/bridge/netfilter/nft_reject_bridge.c create mode 100644 net/caif/Kconfig create mode 100644 net/caif/Makefile create mode 100644 net/caif/caif_dev.c create mode 100644 net/caif/caif_socket.c create mode 100644 net/caif/caif_usb.c create mode 100644 net/caif/cfcnfg.c create mode 100644 net/caif/cfctrl.c create mode 100644 net/caif/cfdbgl.c create mode 100644 net/caif/cfdgml.c create mode 100644 net/caif/cffrml.c create mode 100644 net/caif/cfmuxl.c create mode 100644 net/caif/cfpkt_skbuff.c create mode 100644 
net/caif/cfrfml.c create mode 100644 net/caif/cfserl.c create mode 100644 net/caif/cfsrvl.c create mode 100644 net/caif/cfutill.c create mode 100644 net/caif/cfveil.c create mode 100644 net/caif/cfvidl.c create mode 100644 net/caif/chnl_net.c create mode 100644 net/can/Kconfig create mode 100644 net/can/Makefile create mode 100644 net/can/af_can.c create mode 100644 net/can/af_can.h create mode 100644 net/can/bcm.c create mode 100644 net/can/gw.c create mode 100644 net/can/isotp.c create mode 100644 net/can/j1939/Kconfig create mode 100644 net/can/j1939/Makefile create mode 100644 net/can/j1939/address-claim.c create mode 100644 net/can/j1939/bus.c create mode 100644 net/can/j1939/j1939-priv.h create mode 100644 net/can/j1939/main.c create mode 100644 net/can/j1939/socket.c create mode 100644 net/can/j1939/transport.c create mode 100644 net/can/proc.c create mode 100644 net/can/raw.c create mode 100644 net/ceph/Kconfig create mode 100644 net/ceph/Makefile create mode 100644 net/ceph/armor.c create mode 100644 net/ceph/auth.c create mode 100644 net/ceph/auth_none.c create mode 100644 net/ceph/auth_none.h create mode 100644 net/ceph/auth_x.c create mode 100644 net/ceph/auth_x.h create mode 100644 net/ceph/auth_x_protocol.h create mode 100644 net/ceph/buffer.c create mode 100644 net/ceph/ceph_common.c create mode 100644 net/ceph/ceph_hash.c create mode 100644 net/ceph/ceph_strings.c create mode 100644 net/ceph/cls_lock_client.c create mode 100644 net/ceph/crush/crush.c create mode 100644 net/ceph/crush/crush_ln_table.h create mode 100644 net/ceph/crush/hash.c create mode 100644 net/ceph/crush/mapper.c create mode 100644 net/ceph/crypto.c create mode 100644 net/ceph/crypto.h create mode 100644 net/ceph/debugfs.c create mode 100644 net/ceph/decode.c create mode 100644 net/ceph/messenger.c create mode 100644 net/ceph/messenger_v1.c create mode 100644 net/ceph/messenger_v2.c create mode 100644 net/ceph/mon_client.c create mode 100644 net/ceph/msgpool.c create mode 100644 net/ceph/osd_client.c create mode 100644 net/ceph/osdmap.c create mode 100644 net/ceph/pagelist.c create mode 100644 net/ceph/pagevec.c create mode 100644 net/ceph/snapshot.c create mode 100644 net/ceph/string_table.c create mode 100644 net/ceph/striper.c create mode 100644 net/compat.c create mode 100644 net/core/Makefile create mode 100644 net/core/bpf_sk_storage.c create mode 100644 net/core/datagram.c create mode 100644 net/core/dev.c create mode 100644 net/core/dev.h create mode 100644 net/core/dev_addr_lists.c create mode 100644 net/core/dev_addr_lists_test.c create mode 100644 net/core/dev_ioctl.c create mode 100644 net/core/drop_monitor.c create mode 100644 net/core/dst.c create mode 100644 net/core/dst_cache.c create mode 100644 net/core/failover.c create mode 100644 net/core/fib_notifier.c create mode 100644 net/core/fib_rules.c create mode 100644 net/core/filter.c create mode 100644 net/core/flow_dissector.c create mode 100644 net/core/flow_offload.c create mode 100644 net/core/gen_estimator.c create mode 100644 net/core/gen_stats.c create mode 100644 net/core/gro.c create mode 100644 net/core/gro_cells.c create mode 100644 net/core/hwbm.c create mode 100644 net/core/link_watch.c create mode 100644 net/core/lwt_bpf.c create mode 100644 net/core/lwtunnel.c create mode 100644 net/core/neighbour.c create mode 100644 net/core/net-procfs.c create mode 100644 net/core/net-sysfs.c create mode 100644 net/core/net-sysfs.h create mode 100644 net/core/net-traces.c create mode 100644 net/core/net_namespace.c create mode 100644 
net/core/netclassid_cgroup.c create mode 100644 net/core/netevent.c create mode 100644 net/core/netpoll.c create mode 100644 net/core/netprio_cgroup.c create mode 100644 net/core/of_net.c create mode 100644 net/core/page_pool.c create mode 100644 net/core/pktgen.c create mode 100644 net/core/ptp_classifier.c create mode 100644 net/core/request_sock.c create mode 100644 net/core/rtnetlink.c create mode 100644 net/core/scm.c create mode 100644 net/core/secure_seq.c create mode 100644 net/core/selftests.c create mode 100644 net/core/skbuff.c create mode 100644 net/core/skmsg.c create mode 100644 net/core/sock.c create mode 100644 net/core/sock_destructor.h create mode 100644 net/core/sock_diag.c create mode 100644 net/core/sock_map.c create mode 100644 net/core/sock_reuseport.c create mode 100644 net/core/stream.c create mode 100644 net/core/sysctl_net_core.c create mode 100644 net/core/timestamping.c create mode 100644 net/core/tso.c create mode 100644 net/core/utils.c create mode 100644 net/core/xdp.c create mode 100644 net/dcb/Kconfig create mode 100644 net/dcb/Makefile create mode 100644 net/dcb/dcbevent.c create mode 100644 net/dcb/dcbnl.c create mode 100644 net/dccp/Kconfig create mode 100644 net/dccp/Makefile create mode 100644 net/dccp/ackvec.c create mode 100644 net/dccp/ackvec.h create mode 100644 net/dccp/ccid.c create mode 100644 net/dccp/ccid.h create mode 100644 net/dccp/ccids/Kconfig create mode 100644 net/dccp/ccids/ccid2.c create mode 100644 net/dccp/ccids/ccid2.h create mode 100644 net/dccp/ccids/ccid3.c create mode 100644 net/dccp/ccids/ccid3.h create mode 100644 net/dccp/ccids/lib/loss_interval.c create mode 100644 net/dccp/ccids/lib/loss_interval.h create mode 100644 net/dccp/ccids/lib/packet_history.c create mode 100644 net/dccp/ccids/lib/packet_history.h create mode 100644 net/dccp/ccids/lib/tfrc.c create mode 100644 net/dccp/ccids/lib/tfrc.h create mode 100644 net/dccp/ccids/lib/tfrc_equation.c create mode 100644 net/dccp/dccp.h create mode 100644 net/dccp/diag.c create mode 100644 net/dccp/feat.c create mode 100644 net/dccp/feat.h create mode 100644 net/dccp/input.c create mode 100644 net/dccp/ipv4.c create mode 100644 net/dccp/ipv6.c create mode 100644 net/dccp/ipv6.h create mode 100644 net/dccp/minisocks.c create mode 100644 net/dccp/options.c create mode 100644 net/dccp/output.c create mode 100644 net/dccp/proto.c create mode 100644 net/dccp/qpolicy.c create mode 100644 net/dccp/sysctl.c create mode 100644 net/dccp/timer.c create mode 100644 net/dccp/trace.h create mode 100644 net/devlink/Makefile create mode 100644 net/devlink/leftover.c create mode 100644 net/devres.c create mode 100644 net/dns_resolver/Kconfig create mode 100644 net/dns_resolver/Makefile create mode 100644 net/dns_resolver/dns_key.c create mode 100644 net/dns_resolver/dns_query.c create mode 100644 net/dns_resolver/internal.h create mode 100644 net/dsa/Kconfig create mode 100644 net/dsa/Makefile create mode 100644 net/dsa/dsa.c create mode 100644 net/dsa/dsa2.c create mode 100644 net/dsa/dsa_priv.h create mode 100644 net/dsa/master.c create mode 100644 net/dsa/netlink.c create mode 100644 net/dsa/port.c create mode 100644 net/dsa/slave.c create mode 100644 net/dsa/switch.c create mode 100644 net/dsa/tag_8021q.c create mode 100644 net/dsa/tag_ar9331.c create mode 100644 net/dsa/tag_brcm.c create mode 100644 net/dsa/tag_dsa.c create mode 100644 net/dsa/tag_gswip.c create mode 100644 net/dsa/tag_hellcreek.c create mode 100644 net/dsa/tag_ksz.c create mode 100644 net/dsa/tag_lan9303.c create mode 
100644 net/dsa/tag_mtk.c create mode 100644 net/dsa/tag_ocelot.c create mode 100644 net/dsa/tag_ocelot_8021q.c create mode 100644 net/dsa/tag_qca.c create mode 100644 net/dsa/tag_rtl4_a.c create mode 100644 net/dsa/tag_rtl8_4.c create mode 100644 net/dsa/tag_rzn1_a5psw.c create mode 100644 net/dsa/tag_sja1105.c create mode 100644 net/dsa/tag_trailer.c create mode 100644 net/dsa/tag_xrs700x.c create mode 100644 net/ethernet/Makefile create mode 100644 net/ethernet/eth.c create mode 100644 net/ethtool/Makefile create mode 100644 net/ethtool/bitset.c create mode 100644 net/ethtool/bitset.h create mode 100644 net/ethtool/cabletest.c create mode 100644 net/ethtool/channels.c create mode 100644 net/ethtool/coalesce.c create mode 100644 net/ethtool/common.c create mode 100644 net/ethtool/common.h create mode 100644 net/ethtool/debug.c create mode 100644 net/ethtool/eee.c create mode 100644 net/ethtool/eeprom.c create mode 100644 net/ethtool/features.c create mode 100644 net/ethtool/fec.c create mode 100644 net/ethtool/ioctl.c create mode 100644 net/ethtool/linkinfo.c create mode 100644 net/ethtool/linkmodes.c create mode 100644 net/ethtool/linkstate.c create mode 100644 net/ethtool/module.c create mode 100644 net/ethtool/netlink.c create mode 100644 net/ethtool/netlink.h create mode 100644 net/ethtool/pause.c create mode 100644 net/ethtool/phc_vclocks.c create mode 100644 net/ethtool/privflags.c create mode 100644 net/ethtool/pse-pd.c create mode 100644 net/ethtool/rings.c create mode 100644 net/ethtool/stats.c create mode 100644 net/ethtool/strset.c create mode 100644 net/ethtool/tsinfo.c create mode 100644 net/ethtool/tunnels.c create mode 100644 net/ethtool/wol.c create mode 100644 net/hsr/Kconfig create mode 100644 net/hsr/Makefile create mode 100644 net/hsr/hsr_debugfs.c create mode 100644 net/hsr/hsr_device.c create mode 100644 net/hsr/hsr_device.h create mode 100644 net/hsr/hsr_forward.c create mode 100644 net/hsr/hsr_forward.h create mode 100644 net/hsr/hsr_framereg.c create mode 100644 net/hsr/hsr_framereg.h create mode 100644 net/hsr/hsr_main.c create mode 100644 net/hsr/hsr_main.h create mode 100644 net/hsr/hsr_netlink.c create mode 100644 net/hsr/hsr_netlink.h create mode 100644 net/hsr/hsr_slave.c create mode 100644 net/hsr/hsr_slave.h create mode 100644 net/ieee802154/6lowpan/6lowpan_i.h create mode 100644 net/ieee802154/6lowpan/Kconfig create mode 100644 net/ieee802154/6lowpan/Makefile create mode 100644 net/ieee802154/6lowpan/core.c create mode 100644 net/ieee802154/6lowpan/reassembly.c create mode 100644 net/ieee802154/6lowpan/rx.c create mode 100644 net/ieee802154/6lowpan/tx.c create mode 100644 net/ieee802154/Kconfig create mode 100644 net/ieee802154/Makefile create mode 100644 net/ieee802154/core.c create mode 100644 net/ieee802154/core.h create mode 100644 net/ieee802154/header_ops.c create mode 100644 net/ieee802154/ieee802154.h create mode 100644 net/ieee802154/netlink.c create mode 100644 net/ieee802154/nl-mac.c create mode 100644 net/ieee802154/nl-phy.c create mode 100644 net/ieee802154/nl802154.c create mode 100644 net/ieee802154/nl802154.h create mode 100644 net/ieee802154/nl_policy.c create mode 100644 net/ieee802154/rdev-ops.h create mode 100644 net/ieee802154/socket.c create mode 100644 net/ieee802154/sysfs.c create mode 100644 net/ieee802154/sysfs.h create mode 100644 net/ieee802154/trace.c create mode 100644 net/ieee802154/trace.h create mode 100644 net/ife/Kconfig create mode 100644 net/ife/Makefile create mode 100644 net/ife/ife.c create mode 100644 
net/ipv4/Kconfig create mode 100644 net/ipv4/Makefile create mode 100644 net/ipv4/af_inet.c create mode 100644 net/ipv4/ah4.c create mode 100644 net/ipv4/arp.c create mode 100644 net/ipv4/bpf_tcp_ca.c create mode 100644 net/ipv4/bpfilter/Makefile create mode 100644 net/ipv4/bpfilter/sockopt.c create mode 100644 net/ipv4/cipso_ipv4.c create mode 100644 net/ipv4/datagram.c create mode 100644 net/ipv4/devinet.c create mode 100644 net/ipv4/esp4.c create mode 100644 net/ipv4/esp4_offload.c create mode 100644 net/ipv4/fib_frontend.c create mode 100644 net/ipv4/fib_lookup.h create mode 100644 net/ipv4/fib_notifier.c create mode 100644 net/ipv4/fib_rules.c create mode 100644 net/ipv4/fib_semantics.c create mode 100644 net/ipv4/fib_trie.c create mode 100644 net/ipv4/fou.c create mode 100644 net/ipv4/gre_demux.c create mode 100644 net/ipv4/gre_offload.c create mode 100644 net/ipv4/icmp.c create mode 100644 net/ipv4/igmp.c create mode 100644 net/ipv4/inet_connection_sock.c create mode 100644 net/ipv4/inet_diag.c create mode 100644 net/ipv4/inet_fragment.c create mode 100644 net/ipv4/inet_hashtables.c create mode 100644 net/ipv4/inet_timewait_sock.c create mode 100644 net/ipv4/inetpeer.c create mode 100644 net/ipv4/ip_forward.c create mode 100644 net/ipv4/ip_fragment.c create mode 100644 net/ipv4/ip_gre.c create mode 100644 net/ipv4/ip_input.c create mode 100644 net/ipv4/ip_options.c create mode 100644 net/ipv4/ip_output.c create mode 100644 net/ipv4/ip_sockglue.c create mode 100644 net/ipv4/ip_tunnel.c create mode 100644 net/ipv4/ip_tunnel_core.c create mode 100644 net/ipv4/ip_vti.c create mode 100644 net/ipv4/ipcomp.c create mode 100644 net/ipv4/ipconfig.c create mode 100644 net/ipv4/ipip.c create mode 100644 net/ipv4/ipmr.c create mode 100644 net/ipv4/ipmr_base.c create mode 100644 net/ipv4/metrics.c create mode 100644 net/ipv4/netfilter.c create mode 100644 net/ipv4/netfilter/Kconfig create mode 100644 net/ipv4/netfilter/Makefile create mode 100644 net/ipv4/netfilter/arp_tables.c create mode 100644 net/ipv4/netfilter/arpt_mangle.c create mode 100644 net/ipv4/netfilter/arptable_filter.c create mode 100644 net/ipv4/netfilter/ip_tables.c create mode 100644 net/ipv4/netfilter/ipt_CLUSTERIP.c create mode 100644 net/ipv4/netfilter/ipt_ECN.c create mode 100644 net/ipv4/netfilter/ipt_REJECT.c create mode 100644 net/ipv4/netfilter/ipt_SYNPROXY.c create mode 100644 net/ipv4/netfilter/ipt_ah.c create mode 100644 net/ipv4/netfilter/ipt_rpfilter.c create mode 100644 net/ipv4/netfilter/iptable_filter.c create mode 100644 net/ipv4/netfilter/iptable_mangle.c create mode 100644 net/ipv4/netfilter/iptable_nat.c create mode 100644 net/ipv4/netfilter/iptable_raw.c create mode 100644 net/ipv4/netfilter/iptable_security.c create mode 100644 net/ipv4/netfilter/nf_defrag_ipv4.c create mode 100644 net/ipv4/netfilter/nf_dup_ipv4.c create mode 100644 net/ipv4/netfilter/nf_nat_h323.c create mode 100644 net/ipv4/netfilter/nf_nat_pptp.c create mode 100644 net/ipv4/netfilter/nf_nat_snmp_basic.asn1 create mode 100644 net/ipv4/netfilter/nf_nat_snmp_basic_main.c create mode 100644 net/ipv4/netfilter/nf_reject_ipv4.c create mode 100644 net/ipv4/netfilter/nf_socket_ipv4.c create mode 100644 net/ipv4/netfilter/nf_tproxy_ipv4.c create mode 100644 net/ipv4/netfilter/nft_dup_ipv4.c create mode 100644 net/ipv4/netfilter/nft_fib_ipv4.c create mode 100644 net/ipv4/netfilter/nft_reject_ipv4.c create mode 100644 net/ipv4/netlink.c create mode 100644 net/ipv4/nexthop.c create mode 100644 net/ipv4/ping.c create mode 100644 net/ipv4/proc.c 
create mode 100644 net/ipv4/protocol.c create mode 100644 net/ipv4/raw.c create mode 100644 net/ipv4/raw_diag.c create mode 100644 net/ipv4/route.c create mode 100644 net/ipv4/syncookies.c create mode 100644 net/ipv4/sysctl_net_ipv4.c create mode 100644 net/ipv4/tcp.c create mode 100644 net/ipv4/tcp_bbr.c create mode 100644 net/ipv4/tcp_bic.c create mode 100644 net/ipv4/tcp_bpf.c create mode 100644 net/ipv4/tcp_cdg.c create mode 100644 net/ipv4/tcp_cong.c create mode 100644 net/ipv4/tcp_cubic.c create mode 100644 net/ipv4/tcp_dctcp.c create mode 100644 net/ipv4/tcp_dctcp.h create mode 100644 net/ipv4/tcp_diag.c create mode 100644 net/ipv4/tcp_fastopen.c create mode 100644 net/ipv4/tcp_highspeed.c create mode 100644 net/ipv4/tcp_htcp.c create mode 100644 net/ipv4/tcp_hybla.c create mode 100644 net/ipv4/tcp_illinois.c create mode 100644 net/ipv4/tcp_input.c create mode 100644 net/ipv4/tcp_ipv4.c create mode 100644 net/ipv4/tcp_lp.c create mode 100644 net/ipv4/tcp_metrics.c create mode 100644 net/ipv4/tcp_minisocks.c create mode 100644 net/ipv4/tcp_nv.c create mode 100644 net/ipv4/tcp_offload.c create mode 100644 net/ipv4/tcp_output.c create mode 100644 net/ipv4/tcp_rate.c create mode 100644 net/ipv4/tcp_recovery.c create mode 100644 net/ipv4/tcp_scalable.c create mode 100644 net/ipv4/tcp_timer.c create mode 100644 net/ipv4/tcp_ulp.c create mode 100644 net/ipv4/tcp_vegas.c create mode 100644 net/ipv4/tcp_vegas.h create mode 100644 net/ipv4/tcp_veno.c create mode 100644 net/ipv4/tcp_westwood.c create mode 100644 net/ipv4/tcp_yeah.c create mode 100644 net/ipv4/tunnel4.c create mode 100644 net/ipv4/udp.c create mode 100644 net/ipv4/udp_bpf.c create mode 100644 net/ipv4/udp_diag.c create mode 100644 net/ipv4/udp_impl.h create mode 100644 net/ipv4/udp_offload.c create mode 100644 net/ipv4/udp_tunnel_core.c create mode 100644 net/ipv4/udp_tunnel_nic.c create mode 100644 net/ipv4/udp_tunnel_stub.c create mode 100644 net/ipv4/udplite.c create mode 100644 net/ipv4/xfrm4_input.c create mode 100644 net/ipv4/xfrm4_output.c create mode 100644 net/ipv4/xfrm4_policy.c create mode 100644 net/ipv4/xfrm4_protocol.c create mode 100644 net/ipv4/xfrm4_state.c create mode 100644 net/ipv4/xfrm4_tunnel.c create mode 100644 net/ipv6/Kconfig create mode 100644 net/ipv6/Makefile create mode 100644 net/ipv6/addrconf.c create mode 100644 net/ipv6/addrconf_core.c create mode 100644 net/ipv6/addrlabel.c create mode 100644 net/ipv6/af_inet6.c create mode 100644 net/ipv6/ah6.c create mode 100644 net/ipv6/anycast.c create mode 100644 net/ipv6/calipso.c create mode 100644 net/ipv6/datagram.c create mode 100644 net/ipv6/esp6.c create mode 100644 net/ipv6/esp6_offload.c create mode 100644 net/ipv6/exthdrs.c create mode 100644 net/ipv6/exthdrs_core.c create mode 100644 net/ipv6/exthdrs_offload.c create mode 100644 net/ipv6/fib6_notifier.c create mode 100644 net/ipv6/fib6_rules.c create mode 100644 net/ipv6/fou6.c create mode 100644 net/ipv6/icmp.c create mode 100644 net/ipv6/ila/Makefile create mode 100644 net/ipv6/ila/ila.h create mode 100644 net/ipv6/ila/ila_common.c create mode 100644 net/ipv6/ila/ila_lwt.c create mode 100644 net/ipv6/ila/ila_main.c create mode 100644 net/ipv6/ila/ila_xlat.c create mode 100644 net/ipv6/inet6_connection_sock.c create mode 100644 net/ipv6/inet6_hashtables.c create mode 100644 net/ipv6/ioam6.c create mode 100644 net/ipv6/ioam6_iptunnel.c create mode 100644 net/ipv6/ip6_checksum.c create mode 100644 net/ipv6/ip6_fib.c create mode 100644 net/ipv6/ip6_flowlabel.c create mode 100644 
net/ipv6/ip6_gre.c create mode 100644 net/ipv6/ip6_icmp.c create mode 100644 net/ipv6/ip6_input.c create mode 100644 net/ipv6/ip6_offload.c create mode 100644 net/ipv6/ip6_offload.h create mode 100644 net/ipv6/ip6_output.c create mode 100644 net/ipv6/ip6_tunnel.c create mode 100644 net/ipv6/ip6_udp_tunnel.c create mode 100644 net/ipv6/ip6_vti.c create mode 100644 net/ipv6/ip6mr.c create mode 100644 net/ipv6/ipcomp6.c create mode 100644 net/ipv6/ipv6_sockglue.c create mode 100644 net/ipv6/mcast.c create mode 100644 net/ipv6/mcast_snoop.c create mode 100644 net/ipv6/mip6.c create mode 100644 net/ipv6/ndisc.c create mode 100644 net/ipv6/netfilter.c create mode 100644 net/ipv6/netfilter/Kconfig create mode 100644 net/ipv6/netfilter/Makefile create mode 100644 net/ipv6/netfilter/ip6_tables.c create mode 100644 net/ipv6/netfilter/ip6t_NPT.c create mode 100644 net/ipv6/netfilter/ip6t_REJECT.c create mode 100644 net/ipv6/netfilter/ip6t_SYNPROXY.c create mode 100644 net/ipv6/netfilter/ip6t_ah.c create mode 100644 net/ipv6/netfilter/ip6t_eui64.c create mode 100644 net/ipv6/netfilter/ip6t_frag.c create mode 100644 net/ipv6/netfilter/ip6t_hbh.c create mode 100644 net/ipv6/netfilter/ip6t_ipv6header.c create mode 100644 net/ipv6/netfilter/ip6t_mh.c create mode 100644 net/ipv6/netfilter/ip6t_rpfilter.c create mode 100644 net/ipv6/netfilter/ip6t_rt.c create mode 100644 net/ipv6/netfilter/ip6t_srh.c create mode 100644 net/ipv6/netfilter/ip6table_filter.c create mode 100644 net/ipv6/netfilter/ip6table_mangle.c create mode 100644 net/ipv6/netfilter/ip6table_nat.c create mode 100644 net/ipv6/netfilter/ip6table_raw.c create mode 100644 net/ipv6/netfilter/ip6table_security.c create mode 100644 net/ipv6/netfilter/nf_conntrack_reasm.c create mode 100644 net/ipv6/netfilter/nf_defrag_ipv6_hooks.c create mode 100644 net/ipv6/netfilter/nf_dup_ipv6.c create mode 100644 net/ipv6/netfilter/nf_reject_ipv6.c create mode 100644 net/ipv6/netfilter/nf_socket_ipv6.c create mode 100644 net/ipv6/netfilter/nf_tproxy_ipv6.c create mode 100644 net/ipv6/netfilter/nft_dup_ipv6.c create mode 100644 net/ipv6/netfilter/nft_fib_ipv6.c create mode 100644 net/ipv6/netfilter/nft_reject_ipv6.c create mode 100644 net/ipv6/output_core.c create mode 100644 net/ipv6/ping.c create mode 100644 net/ipv6/proc.c create mode 100644 net/ipv6/protocol.c create mode 100644 net/ipv6/raw.c create mode 100644 net/ipv6/reassembly.c create mode 100644 net/ipv6/route.c create mode 100644 net/ipv6/rpl.c create mode 100644 net/ipv6/rpl_iptunnel.c create mode 100644 net/ipv6/seg6.c create mode 100644 net/ipv6/seg6_hmac.c create mode 100644 net/ipv6/seg6_iptunnel.c create mode 100644 net/ipv6/seg6_local.c create mode 100644 net/ipv6/sit.c create mode 100644 net/ipv6/syncookies.c create mode 100644 net/ipv6/sysctl_net_ipv6.c create mode 100644 net/ipv6/tcp_ipv6.c create mode 100644 net/ipv6/tcpv6_offload.c create mode 100644 net/ipv6/tunnel6.c create mode 100644 net/ipv6/udp.c create mode 100644 net/ipv6/udp_impl.h create mode 100644 net/ipv6/udp_offload.c create mode 100644 net/ipv6/udplite.c create mode 100644 net/ipv6/xfrm6_input.c create mode 100644 net/ipv6/xfrm6_output.c create mode 100644 net/ipv6/xfrm6_policy.c create mode 100644 net/ipv6/xfrm6_protocol.c create mode 100644 net/ipv6/xfrm6_state.c create mode 100644 net/ipv6/xfrm6_tunnel.c create mode 100644 net/iucv/Kconfig create mode 100644 net/iucv/Makefile create mode 100644 net/iucv/af_iucv.c create mode 100644 net/iucv/iucv.c create mode 100644 net/kcm/Kconfig create mode 100644 net/kcm/Makefile 
create mode 100644 net/kcm/kcmproc.c create mode 100644 net/kcm/kcmsock.c create mode 100644 net/key/Makefile create mode 100644 net/key/af_key.c create mode 100644 net/l2tp/Kconfig create mode 100644 net/l2tp/Makefile create mode 100644 net/l2tp/l2tp_core.c create mode 100644 net/l2tp/l2tp_core.h create mode 100644 net/l2tp/l2tp_debugfs.c create mode 100644 net/l2tp/l2tp_eth.c create mode 100644 net/l2tp/l2tp_ip.c create mode 100644 net/l2tp/l2tp_ip6.c create mode 100644 net/l2tp/l2tp_netlink.c create mode 100644 net/l2tp/l2tp_ppp.c create mode 100644 net/l2tp/trace.h create mode 100644 net/l3mdev/Kconfig create mode 100644 net/l3mdev/Makefile create mode 100644 net/l3mdev/l3mdev.c create mode 100644 net/lapb/Kconfig create mode 100644 net/lapb/Makefile create mode 100644 net/lapb/lapb_iface.c create mode 100644 net/lapb/lapb_in.c create mode 100644 net/lapb/lapb_out.c create mode 100644 net/lapb/lapb_subr.c create mode 100644 net/lapb/lapb_timer.c create mode 100644 net/llc/Kconfig create mode 100644 net/llc/Makefile create mode 100644 net/llc/af_llc.c create mode 100644 net/llc/llc_c_ac.c create mode 100644 net/llc/llc_c_ev.c create mode 100644 net/llc/llc_c_st.c create mode 100644 net/llc/llc_conn.c create mode 100644 net/llc/llc_core.c create mode 100644 net/llc/llc_if.c create mode 100644 net/llc/llc_input.c create mode 100644 net/llc/llc_output.c create mode 100644 net/llc/llc_pdu.c create mode 100644 net/llc/llc_proc.c create mode 100644 net/llc/llc_s_ac.c create mode 100644 net/llc/llc_s_ev.c create mode 100644 net/llc/llc_s_st.c create mode 100644 net/llc/llc_sap.c create mode 100644 net/llc/llc_station.c create mode 100644 net/llc/sysctl_net_llc.c create mode 100644 net/mac80211/Kconfig create mode 100644 net/mac80211/Makefile create mode 100644 net/mac80211/aead_api.c create mode 100644 net/mac80211/aead_api.h create mode 100644 net/mac80211/aes_ccm.h create mode 100644 net/mac80211/aes_cmac.c create mode 100644 net/mac80211/aes_cmac.h create mode 100644 net/mac80211/aes_gcm.h create mode 100644 net/mac80211/aes_gmac.c create mode 100644 net/mac80211/aes_gmac.h create mode 100644 net/mac80211/agg-rx.c create mode 100644 net/mac80211/agg-tx.c create mode 100644 net/mac80211/airtime.c create mode 100644 net/mac80211/cfg.c create mode 100644 net/mac80211/chan.c create mode 100644 net/mac80211/debug.h create mode 100644 net/mac80211/debugfs.c create mode 100644 net/mac80211/debugfs.h create mode 100644 net/mac80211/debugfs_key.c create mode 100644 net/mac80211/debugfs_key.h create mode 100644 net/mac80211/debugfs_netdev.c create mode 100644 net/mac80211/debugfs_netdev.h create mode 100644 net/mac80211/debugfs_sta.c create mode 100644 net/mac80211/debugfs_sta.h create mode 100644 net/mac80211/driver-ops.c create mode 100644 net/mac80211/driver-ops.h create mode 100644 net/mac80211/eht.c create mode 100644 net/mac80211/ethtool.c create mode 100644 net/mac80211/fils_aead.c create mode 100644 net/mac80211/fils_aead.h create mode 100644 net/mac80211/he.c create mode 100644 net/mac80211/ht.c create mode 100644 net/mac80211/ibss.c create mode 100644 net/mac80211/ieee80211_i.h create mode 100644 net/mac80211/iface.c create mode 100644 net/mac80211/key.c create mode 100644 net/mac80211/key.h create mode 100644 net/mac80211/led.c create mode 100644 net/mac80211/led.h create mode 100644 net/mac80211/link.c create mode 100644 net/mac80211/main.c create mode 100644 net/mac80211/mesh.c create mode 100644 net/mac80211/mesh.h create mode 100644 net/mac80211/mesh_hwmp.c create mode 100644 
net/mac80211/mesh_pathtbl.c create mode 100644 net/mac80211/mesh_plink.c create mode 100644 net/mac80211/mesh_ps.c create mode 100644 net/mac80211/mesh_sync.c create mode 100644 net/mac80211/michael.c create mode 100644 net/mac80211/michael.h create mode 100644 net/mac80211/mlme.c create mode 100644 net/mac80211/ocb.c create mode 100644 net/mac80211/offchannel.c create mode 100644 net/mac80211/pm.c create mode 100644 net/mac80211/rate.c create mode 100644 net/mac80211/rate.h create mode 100644 net/mac80211/rc80211_minstrel_ht.c create mode 100644 net/mac80211/rc80211_minstrel_ht.h create mode 100644 net/mac80211/rc80211_minstrel_ht_debugfs.c create mode 100644 net/mac80211/rx.c create mode 100644 net/mac80211/s1g.c create mode 100644 net/mac80211/scan.c create mode 100644 net/mac80211/spectmgmt.c create mode 100644 net/mac80211/sta_info.c create mode 100644 net/mac80211/sta_info.h create mode 100644 net/mac80211/status.c create mode 100644 net/mac80211/tdls.c create mode 100644 net/mac80211/tkip.c create mode 100644 net/mac80211/tkip.h create mode 100644 net/mac80211/trace.c create mode 100644 net/mac80211/trace.h create mode 100644 net/mac80211/trace_msg.h create mode 100644 net/mac80211/tx.c create mode 100644 net/mac80211/util.c create mode 100644 net/mac80211/vht.c create mode 100644 net/mac80211/wep.c create mode 100644 net/mac80211/wep.h create mode 100644 net/mac80211/wme.c create mode 100644 net/mac80211/wme.h create mode 100644 net/mac80211/wpa.c create mode 100644 net/mac80211/wpa.h create mode 100644 net/mac802154/Kconfig create mode 100644 net/mac802154/Makefile create mode 100644 net/mac802154/cfg.c create mode 100644 net/mac802154/cfg.h create mode 100644 net/mac802154/driver-ops.h create mode 100644 net/mac802154/ieee802154_i.h create mode 100644 net/mac802154/iface.c create mode 100644 net/mac802154/llsec.c create mode 100644 net/mac802154/llsec.h create mode 100644 net/mac802154/mac_cmd.c create mode 100644 net/mac802154/main.c create mode 100644 net/mac802154/mib.c create mode 100644 net/mac802154/rx.c create mode 100644 net/mac802154/trace.c create mode 100644 net/mac802154/trace.h create mode 100644 net/mac802154/tx.c create mode 100644 net/mac802154/util.c create mode 100644 net/mctp/Kconfig create mode 100644 net/mctp/Makefile create mode 100644 net/mctp/af_mctp.c create mode 100644 net/mctp/device.c create mode 100644 net/mctp/neigh.c create mode 100644 net/mctp/route.c create mode 100644 net/mctp/test/route-test.c create mode 100644 net/mctp/test/utils.c create mode 100644 net/mctp/test/utils.h create mode 100644 net/mpls/Kconfig create mode 100644 net/mpls/Makefile create mode 100644 net/mpls/af_mpls.c create mode 100644 net/mpls/internal.h create mode 100644 net/mpls/mpls_gso.c create mode 100644 net/mpls/mpls_iptunnel.c create mode 100644 net/mptcp/Kconfig create mode 100644 net/mptcp/Makefile create mode 100644 net/mptcp/bpf.c create mode 100644 net/mptcp/crypto.c create mode 100644 net/mptcp/crypto_test.c create mode 100644 net/mptcp/ctrl.c create mode 100644 net/mptcp/diag.c create mode 100644 net/mptcp/mib.c create mode 100644 net/mptcp/mib.h create mode 100644 net/mptcp/mptcp_diag.c create mode 100644 net/mptcp/options.c create mode 100644 net/mptcp/pm.c create mode 100644 net/mptcp/pm_netlink.c create mode 100644 net/mptcp/pm_userspace.c create mode 100644 net/mptcp/protocol.c create mode 100644 net/mptcp/protocol.h create mode 100644 net/mptcp/sockopt.c create mode 100644 net/mptcp/subflow.c create mode 100644 net/mptcp/syncookies.c create mode 100644 
net/mptcp/token.c create mode 100644 net/mptcp/token_test.c create mode 100644 net/ncsi/Kconfig create mode 100644 net/ncsi/Makefile create mode 100644 net/ncsi/internal.h create mode 100644 net/ncsi/ncsi-aen.c create mode 100644 net/ncsi/ncsi-cmd.c create mode 100644 net/ncsi/ncsi-manage.c create mode 100644 net/ncsi/ncsi-netlink.c create mode 100644 net/ncsi/ncsi-netlink.h create mode 100644 net/ncsi/ncsi-pkt.h create mode 100644 net/ncsi/ncsi-rsp.c create mode 100644 net/netfilter/Kconfig create mode 100644 net/netfilter/Makefile create mode 100644 net/netfilter/core.c create mode 100644 net/netfilter/ipset/Kconfig create mode 100644 net/netfilter/ipset/Makefile create mode 100644 net/netfilter/ipset/ip_set_bitmap_gen.h create mode 100644 net/netfilter/ipset/ip_set_bitmap_ip.c create mode 100644 net/netfilter/ipset/ip_set_bitmap_ipmac.c create mode 100644 net/netfilter/ipset/ip_set_bitmap_port.c create mode 100644 net/netfilter/ipset/ip_set_core.c create mode 100644 net/netfilter/ipset/ip_set_getport.c create mode 100644 net/netfilter/ipset/ip_set_hash_gen.h create mode 100644 net/netfilter/ipset/ip_set_hash_ip.c create mode 100644 net/netfilter/ipset/ip_set_hash_ipmac.c create mode 100644 net/netfilter/ipset/ip_set_hash_ipmark.c create mode 100644 net/netfilter/ipset/ip_set_hash_ipport.c create mode 100644 net/netfilter/ipset/ip_set_hash_ipportip.c create mode 100644 net/netfilter/ipset/ip_set_hash_ipportnet.c create mode 100644 net/netfilter/ipset/ip_set_hash_mac.c create mode 100644 net/netfilter/ipset/ip_set_hash_net.c create mode 100644 net/netfilter/ipset/ip_set_hash_netiface.c create mode 100644 net/netfilter/ipset/ip_set_hash_netnet.c create mode 100644 net/netfilter/ipset/ip_set_hash_netport.c create mode 100644 net/netfilter/ipset/ip_set_hash_netportnet.c create mode 100644 net/netfilter/ipset/ip_set_list_set.c create mode 100644 net/netfilter/ipset/pfxlen.c create mode 100644 net/netfilter/ipvs/Kconfig create mode 100644 net/netfilter/ipvs/Makefile create mode 100644 net/netfilter/ipvs/ip_vs_app.c create mode 100644 net/netfilter/ipvs/ip_vs_conn.c create mode 100644 net/netfilter/ipvs/ip_vs_core.c create mode 100644 net/netfilter/ipvs/ip_vs_ctl.c create mode 100644 net/netfilter/ipvs/ip_vs_dh.c create mode 100644 net/netfilter/ipvs/ip_vs_est.c create mode 100644 net/netfilter/ipvs/ip_vs_fo.c create mode 100644 net/netfilter/ipvs/ip_vs_ftp.c create mode 100644 net/netfilter/ipvs/ip_vs_lblc.c create mode 100644 net/netfilter/ipvs/ip_vs_lblcr.c create mode 100644 net/netfilter/ipvs/ip_vs_lc.c create mode 100644 net/netfilter/ipvs/ip_vs_mh.c create mode 100644 net/netfilter/ipvs/ip_vs_nfct.c create mode 100644 net/netfilter/ipvs/ip_vs_nq.c create mode 100644 net/netfilter/ipvs/ip_vs_ovf.c create mode 100644 net/netfilter/ipvs/ip_vs_pe.c create mode 100644 net/netfilter/ipvs/ip_vs_pe_sip.c create mode 100644 net/netfilter/ipvs/ip_vs_proto.c create mode 100644 net/netfilter/ipvs/ip_vs_proto_ah_esp.c create mode 100644 net/netfilter/ipvs/ip_vs_proto_sctp.c create mode 100644 net/netfilter/ipvs/ip_vs_proto_tcp.c create mode 100644 net/netfilter/ipvs/ip_vs_proto_udp.c create mode 100644 net/netfilter/ipvs/ip_vs_rr.c create mode 100644 net/netfilter/ipvs/ip_vs_sched.c create mode 100644 net/netfilter/ipvs/ip_vs_sed.c create mode 100644 net/netfilter/ipvs/ip_vs_sh.c create mode 100644 net/netfilter/ipvs/ip_vs_sync.c create mode 100644 net/netfilter/ipvs/ip_vs_twos.c create mode 100644 net/netfilter/ipvs/ip_vs_wlc.c create mode 100644 net/netfilter/ipvs/ip_vs_wrr.c create mode 100644 
net/netfilter/ipvs/ip_vs_xmit.c create mode 100644 net/netfilter/nf_conncount.c create mode 100644 net/netfilter/nf_conntrack_acct.c create mode 100644 net/netfilter/nf_conntrack_amanda.c create mode 100644 net/netfilter/nf_conntrack_bpf.c create mode 100644 net/netfilter/nf_conntrack_broadcast.c create mode 100644 net/netfilter/nf_conntrack_core.c create mode 100644 net/netfilter/nf_conntrack_ecache.c create mode 100644 net/netfilter/nf_conntrack_expect.c create mode 100644 net/netfilter/nf_conntrack_extend.c create mode 100644 net/netfilter/nf_conntrack_ftp.c create mode 100644 net/netfilter/nf_conntrack_h323_asn1.c create mode 100644 net/netfilter/nf_conntrack_h323_main.c create mode 100644 net/netfilter/nf_conntrack_h323_types.c create mode 100644 net/netfilter/nf_conntrack_helper.c create mode 100644 net/netfilter/nf_conntrack_irc.c create mode 100644 net/netfilter/nf_conntrack_labels.c create mode 100644 net/netfilter/nf_conntrack_netbios_ns.c create mode 100644 net/netfilter/nf_conntrack_netlink.c create mode 100644 net/netfilter/nf_conntrack_pptp.c create mode 100644 net/netfilter/nf_conntrack_proto.c create mode 100644 net/netfilter/nf_conntrack_proto_dccp.c create mode 100644 net/netfilter/nf_conntrack_proto_generic.c create mode 100644 net/netfilter/nf_conntrack_proto_gre.c create mode 100644 net/netfilter/nf_conntrack_proto_icmp.c create mode 100644 net/netfilter/nf_conntrack_proto_icmpv6.c create mode 100644 net/netfilter/nf_conntrack_proto_sctp.c create mode 100644 net/netfilter/nf_conntrack_proto_tcp.c create mode 100644 net/netfilter/nf_conntrack_proto_udp.c create mode 100644 net/netfilter/nf_conntrack_sane.c create mode 100644 net/netfilter/nf_conntrack_seqadj.c create mode 100644 net/netfilter/nf_conntrack_sip.c create mode 100644 net/netfilter/nf_conntrack_snmp.c create mode 100644 net/netfilter/nf_conntrack_standalone.c create mode 100644 net/netfilter/nf_conntrack_tftp.c create mode 100644 net/netfilter/nf_conntrack_timeout.c create mode 100644 net/netfilter/nf_conntrack_timestamp.c create mode 100644 net/netfilter/nf_dup_netdev.c create mode 100644 net/netfilter/nf_flow_table_core.c create mode 100644 net/netfilter/nf_flow_table_inet.c create mode 100644 net/netfilter/nf_flow_table_ip.c create mode 100644 net/netfilter/nf_flow_table_offload.c create mode 100644 net/netfilter/nf_flow_table_procfs.c create mode 100644 net/netfilter/nf_hooks_lwtunnel.c create mode 100644 net/netfilter/nf_internals.h create mode 100644 net/netfilter/nf_log.c create mode 100644 net/netfilter/nf_log_syslog.c create mode 100644 net/netfilter/nf_nat_amanda.c create mode 100644 net/netfilter/nf_nat_bpf.c create mode 100644 net/netfilter/nf_nat_core.c create mode 100644 net/netfilter/nf_nat_ftp.c create mode 100644 net/netfilter/nf_nat_helper.c create mode 100644 net/netfilter/nf_nat_irc.c create mode 100644 net/netfilter/nf_nat_masquerade.c create mode 100644 net/netfilter/nf_nat_proto.c create mode 100644 net/netfilter/nf_nat_redirect.c create mode 100644 net/netfilter/nf_nat_sip.c create mode 100644 net/netfilter/nf_nat_tftp.c create mode 100644 net/netfilter/nf_queue.c create mode 100644 net/netfilter/nf_sockopt.c create mode 100644 net/netfilter/nf_synproxy_core.c create mode 100644 net/netfilter/nf_tables_api.c create mode 100644 net/netfilter/nf_tables_core.c create mode 100644 net/netfilter/nf_tables_offload.c create mode 100644 net/netfilter/nf_tables_trace.c create mode 100644 net/netfilter/nfnetlink.c create mode 100644 net/netfilter/nfnetlink_acct.c create mode 100644 
net/netfilter/nfnetlink_cthelper.c create mode 100644 net/netfilter/nfnetlink_cttimeout.c create mode 100644 net/netfilter/nfnetlink_hook.c create mode 100644 net/netfilter/nfnetlink_log.c create mode 100644 net/netfilter/nfnetlink_osf.c create mode 100644 net/netfilter/nfnetlink_queue.c create mode 100644 net/netfilter/nft_bitwise.c create mode 100644 net/netfilter/nft_byteorder.c create mode 100644 net/netfilter/nft_chain_filter.c create mode 100644 net/netfilter/nft_chain_nat.c create mode 100644 net/netfilter/nft_chain_route.c create mode 100644 net/netfilter/nft_cmp.c create mode 100644 net/netfilter/nft_compat.c create mode 100644 net/netfilter/nft_connlimit.c create mode 100644 net/netfilter/nft_counter.c create mode 100644 net/netfilter/nft_ct.c create mode 100644 net/netfilter/nft_dup_netdev.c create mode 100644 net/netfilter/nft_dynset.c create mode 100644 net/netfilter/nft_exthdr.c create mode 100644 net/netfilter/nft_fib.c create mode 100644 net/netfilter/nft_fib_inet.c create mode 100644 net/netfilter/nft_fib_netdev.c create mode 100644 net/netfilter/nft_flow_offload.c create mode 100644 net/netfilter/nft_fwd_netdev.c create mode 100644 net/netfilter/nft_hash.c create mode 100644 net/netfilter/nft_immediate.c create mode 100644 net/netfilter/nft_last.c create mode 100644 net/netfilter/nft_limit.c create mode 100644 net/netfilter/nft_log.c create mode 100644 net/netfilter/nft_lookup.c create mode 100644 net/netfilter/nft_masq.c create mode 100644 net/netfilter/nft_meta.c create mode 100644 net/netfilter/nft_nat.c create mode 100644 net/netfilter/nft_numgen.c create mode 100644 net/netfilter/nft_objref.c create mode 100644 net/netfilter/nft_osf.c create mode 100644 net/netfilter/nft_payload.c create mode 100644 net/netfilter/nft_queue.c create mode 100644 net/netfilter/nft_quota.c create mode 100644 net/netfilter/nft_range.c create mode 100644 net/netfilter/nft_redir.c create mode 100644 net/netfilter/nft_reject.c create mode 100644 net/netfilter/nft_reject_inet.c create mode 100644 net/netfilter/nft_reject_netdev.c create mode 100644 net/netfilter/nft_rt.c create mode 100644 net/netfilter/nft_set_bitmap.c create mode 100644 net/netfilter/nft_set_hash.c create mode 100644 net/netfilter/nft_set_pipapo.c create mode 100644 net/netfilter/nft_set_pipapo.h create mode 100644 net/netfilter/nft_set_pipapo_avx2.c create mode 100644 net/netfilter/nft_set_pipapo_avx2.h create mode 100644 net/netfilter/nft_set_rbtree.c create mode 100644 net/netfilter/nft_socket.c create mode 100644 net/netfilter/nft_synproxy.c create mode 100644 net/netfilter/nft_tproxy.c create mode 100644 net/netfilter/nft_tunnel.c create mode 100644 net/netfilter/nft_xfrm.c create mode 100644 net/netfilter/utils.c create mode 100644 net/netfilter/x_tables.c create mode 100644 net/netfilter/xt_AUDIT.c create mode 100644 net/netfilter/xt_CHECKSUM.c create mode 100644 net/netfilter/xt_CLASSIFY.c create mode 100644 net/netfilter/xt_CONNSECMARK.c create mode 100644 net/netfilter/xt_CT.c create mode 100644 net/netfilter/xt_DSCP.c create mode 100644 net/netfilter/xt_HL.c create mode 100644 net/netfilter/xt_HMARK.c create mode 100644 net/netfilter/xt_IDLETIMER.c create mode 100644 net/netfilter/xt_LED.c create mode 100644 net/netfilter/xt_LOG.c create mode 100644 net/netfilter/xt_MASQUERADE.c create mode 100644 net/netfilter/xt_NETMAP.c create mode 100644 net/netfilter/xt_NFLOG.c create mode 100644 net/netfilter/xt_NFQUEUE.c create mode 100644 net/netfilter/xt_RATEEST.c create mode 100644 net/netfilter/xt_REDIRECT.c create 
mode 100644 net/netfilter/xt_SECMARK.c create mode 100644 net/netfilter/xt_TCPMSS.c create mode 100644 net/netfilter/xt_TCPOPTSTRIP.c create mode 100644 net/netfilter/xt_TEE.c create mode 100644 net/netfilter/xt_TPROXY.c create mode 100644 net/netfilter/xt_TRACE.c create mode 100644 net/netfilter/xt_addrtype.c create mode 100644 net/netfilter/xt_bpf.c create mode 100644 net/netfilter/xt_cgroup.c create mode 100644 net/netfilter/xt_cluster.c create mode 100644 net/netfilter/xt_comment.c create mode 100644 net/netfilter/xt_connbytes.c create mode 100644 net/netfilter/xt_connlabel.c create mode 100644 net/netfilter/xt_connlimit.c create mode 100644 net/netfilter/xt_connmark.c create mode 100644 net/netfilter/xt_conntrack.c create mode 100644 net/netfilter/xt_cpu.c create mode 100644 net/netfilter/xt_dccp.c create mode 100644 net/netfilter/xt_devgroup.c create mode 100644 net/netfilter/xt_dscp.c create mode 100644 net/netfilter/xt_ecn.c create mode 100644 net/netfilter/xt_esp.c create mode 100644 net/netfilter/xt_hashlimit.c create mode 100644 net/netfilter/xt_helper.c create mode 100644 net/netfilter/xt_hl.c create mode 100644 net/netfilter/xt_ipcomp.c create mode 100644 net/netfilter/xt_iprange.c create mode 100644 net/netfilter/xt_ipvs.c create mode 100644 net/netfilter/xt_l2tp.c create mode 100644 net/netfilter/xt_length.c create mode 100644 net/netfilter/xt_limit.c create mode 100644 net/netfilter/xt_mac.c create mode 100644 net/netfilter/xt_mark.c create mode 100644 net/netfilter/xt_multiport.c create mode 100644 net/netfilter/xt_nat.c create mode 100644 net/netfilter/xt_nfacct.c create mode 100644 net/netfilter/xt_osf.c create mode 100644 net/netfilter/xt_owner.c create mode 100644 net/netfilter/xt_physdev.c create mode 100644 net/netfilter/xt_pkttype.c create mode 100644 net/netfilter/xt_policy.c create mode 100644 net/netfilter/xt_quota.c create mode 100644 net/netfilter/xt_rateest.c create mode 100644 net/netfilter/xt_realm.c create mode 100644 net/netfilter/xt_recent.c create mode 100644 net/netfilter/xt_repldata.h create mode 100644 net/netfilter/xt_sctp.c create mode 100644 net/netfilter/xt_set.c create mode 100644 net/netfilter/xt_socket.c create mode 100644 net/netfilter/xt_state.c create mode 100644 net/netfilter/xt_statistic.c create mode 100644 net/netfilter/xt_string.c create mode 100644 net/netfilter/xt_tcpmss.c create mode 100644 net/netfilter/xt_tcpudp.c create mode 100644 net/netfilter/xt_time.c create mode 100644 net/netfilter/xt_u32.c create mode 100644 net/netlabel/Kconfig create mode 100644 net/netlabel/Makefile create mode 100644 net/netlabel/netlabel_addrlist.c create mode 100644 net/netlabel/netlabel_addrlist.h create mode 100644 net/netlabel/netlabel_calipso.c create mode 100644 net/netlabel/netlabel_calipso.h create mode 100644 net/netlabel/netlabel_cipso_v4.c create mode 100644 net/netlabel/netlabel_cipso_v4.h create mode 100644 net/netlabel/netlabel_domainhash.c create mode 100644 net/netlabel/netlabel_domainhash.h create mode 100644 net/netlabel/netlabel_kapi.c create mode 100644 net/netlabel/netlabel_mgmt.c create mode 100644 net/netlabel/netlabel_mgmt.h create mode 100644 net/netlabel/netlabel_unlabeled.c create mode 100644 net/netlabel/netlabel_unlabeled.h create mode 100644 net/netlabel/netlabel_user.c create mode 100644 net/netlabel/netlabel_user.h create mode 100644 net/netlink/Kconfig create mode 100644 net/netlink/Makefile create mode 100644 net/netlink/af_netlink.c create mode 100644 net/netlink/af_netlink.h create mode 100644 net/netlink/diag.c 
create mode 100644 net/netlink/genetlink.c create mode 100644 net/netlink/policy.c create mode 100644 net/netrom/Makefile create mode 100644 net/netrom/af_netrom.c create mode 100644 net/netrom/nr_dev.c create mode 100644 net/netrom/nr_in.c create mode 100644 net/netrom/nr_loopback.c create mode 100644 net/netrom/nr_out.c create mode 100644 net/netrom/nr_route.c create mode 100644 net/netrom/nr_subr.c create mode 100644 net/netrom/nr_timer.c create mode 100644 net/netrom/sysctl_net_netrom.c create mode 100644 net/nfc/Kconfig create mode 100644 net/nfc/Makefile create mode 100644 net/nfc/af_nfc.c create mode 100644 net/nfc/core.c create mode 100644 net/nfc/digital.h create mode 100644 net/nfc/digital_core.c create mode 100644 net/nfc/digital_dep.c create mode 100644 net/nfc/digital_technology.c create mode 100644 net/nfc/hci/Kconfig create mode 100644 net/nfc/hci/Makefile create mode 100644 net/nfc/hci/command.c create mode 100644 net/nfc/hci/core.c create mode 100644 net/nfc/hci/hci.h create mode 100644 net/nfc/hci/hcp.c create mode 100644 net/nfc/hci/llc.c create mode 100644 net/nfc/hci/llc.h create mode 100644 net/nfc/hci/llc_nop.c create mode 100644 net/nfc/hci/llc_shdlc.c create mode 100644 net/nfc/llcp.h create mode 100644 net/nfc/llcp_commands.c create mode 100644 net/nfc/llcp_core.c create mode 100644 net/nfc/llcp_sock.c create mode 100644 net/nfc/nci/Kconfig create mode 100644 net/nfc/nci/Makefile create mode 100644 net/nfc/nci/core.c create mode 100644 net/nfc/nci/data.c create mode 100644 net/nfc/nci/hci.c create mode 100644 net/nfc/nci/lib.c create mode 100644 net/nfc/nci/ntf.c create mode 100644 net/nfc/nci/rsp.c create mode 100644 net/nfc/nci/spi.c create mode 100644 net/nfc/nci/uart.c create mode 100644 net/nfc/netlink.c create mode 100644 net/nfc/nfc.h create mode 100644 net/nfc/rawsock.c create mode 100644 net/nsh/Kconfig create mode 100644 net/nsh/Makefile create mode 100644 net/nsh/nsh.c create mode 100644 net/openvswitch/Kconfig create mode 100644 net/openvswitch/Makefile create mode 100644 net/openvswitch/actions.c create mode 100644 net/openvswitch/conntrack.c create mode 100644 net/openvswitch/conntrack.h create mode 100644 net/openvswitch/datapath.c create mode 100644 net/openvswitch/datapath.h create mode 100644 net/openvswitch/dp_notify.c create mode 100644 net/openvswitch/flow.c create mode 100644 net/openvswitch/flow.h create mode 100644 net/openvswitch/flow_netlink.c create mode 100644 net/openvswitch/flow_netlink.h create mode 100644 net/openvswitch/flow_table.c create mode 100644 net/openvswitch/flow_table.h create mode 100644 net/openvswitch/meter.c create mode 100644 net/openvswitch/meter.h create mode 100644 net/openvswitch/openvswitch_trace.c create mode 100644 net/openvswitch/openvswitch_trace.h create mode 100644 net/openvswitch/vport-geneve.c create mode 100644 net/openvswitch/vport-gre.c create mode 100644 net/openvswitch/vport-internal_dev.c create mode 100644 net/openvswitch/vport-internal_dev.h create mode 100644 net/openvswitch/vport-netdev.c create mode 100644 net/openvswitch/vport-netdev.h create mode 100644 net/openvswitch/vport-vxlan.c create mode 100644 net/openvswitch/vport.c create mode 100644 net/openvswitch/vport.h create mode 100644 net/packet/Kconfig create mode 100644 net/packet/Makefile create mode 100644 net/packet/af_packet.c create mode 100644 net/packet/diag.c create mode 100644 net/packet/internal.h create mode 100644 net/phonet/Kconfig create mode 100644 net/phonet/Makefile create mode 100644 net/phonet/af_phonet.c create mode 
100644 net/phonet/datagram.c create mode 100644 net/phonet/pep-gprs.c create mode 100644 net/phonet/pep.c create mode 100644 net/phonet/pn_dev.c create mode 100644 net/phonet/pn_netlink.c create mode 100644 net/phonet/socket.c create mode 100644 net/phonet/sysctl.c create mode 100644 net/psample/Kconfig create mode 100644 net/psample/Makefile create mode 100644 net/psample/psample.c create mode 100644 net/qrtr/Kconfig create mode 100644 net/qrtr/Makefile create mode 100644 net/qrtr/af_qrtr.c create mode 100644 net/qrtr/mhi.c create mode 100644 net/qrtr/ns.c create mode 100644 net/qrtr/qrtr.h create mode 100644 net/qrtr/smd.c create mode 100644 net/qrtr/tun.c create mode 100644 net/rds/Kconfig create mode 100644 net/rds/Makefile create mode 100644 net/rds/af_rds.c create mode 100644 net/rds/bind.c create mode 100644 net/rds/cong.c create mode 100644 net/rds/connection.c create mode 100644 net/rds/ib.c create mode 100644 net/rds/ib.h create mode 100644 net/rds/ib_cm.c create mode 100644 net/rds/ib_frmr.c create mode 100644 net/rds/ib_mr.h create mode 100644 net/rds/ib_rdma.c create mode 100644 net/rds/ib_recv.c create mode 100644 net/rds/ib_ring.c create mode 100644 net/rds/ib_send.c create mode 100644 net/rds/ib_stats.c create mode 100644 net/rds/ib_sysctl.c create mode 100644 net/rds/info.c create mode 100644 net/rds/info.h create mode 100644 net/rds/loop.c create mode 100644 net/rds/loop.h create mode 100644 net/rds/message.c create mode 100644 net/rds/page.c create mode 100644 net/rds/rdma.c create mode 100644 net/rds/rdma_transport.c create mode 100644 net/rds/rdma_transport.h create mode 100644 net/rds/rds.h create mode 100644 net/rds/rds_single_path.h create mode 100644 net/rds/recv.c create mode 100644 net/rds/send.c create mode 100644 net/rds/stats.c create mode 100644 net/rds/sysctl.c create mode 100644 net/rds/tcp.c create mode 100644 net/rds/tcp.h create mode 100644 net/rds/tcp_connect.c create mode 100644 net/rds/tcp_listen.c create mode 100644 net/rds/tcp_recv.c create mode 100644 net/rds/tcp_send.c create mode 100644 net/rds/tcp_stats.c create mode 100644 net/rds/threads.c create mode 100644 net/rds/transport.c create mode 100644 net/rfkill/Kconfig create mode 100644 net/rfkill/Makefile create mode 100644 net/rfkill/core.c create mode 100644 net/rfkill/input.c create mode 100644 net/rfkill/rfkill-gpio.c create mode 100644 net/rfkill/rfkill.h create mode 100644 net/rose/Makefile create mode 100644 net/rose/af_rose.c create mode 100644 net/rose/rose_dev.c create mode 100644 net/rose/rose_in.c create mode 100644 net/rose/rose_link.c create mode 100644 net/rose/rose_loopback.c create mode 100644 net/rose/rose_out.c create mode 100644 net/rose/rose_route.c create mode 100644 net/rose/rose_subr.c create mode 100644 net/rose/rose_timer.c create mode 100644 net/rose/sysctl_net_rose.c create mode 100644 net/rxrpc/Kconfig create mode 100644 net/rxrpc/Makefile create mode 100644 net/rxrpc/af_rxrpc.c create mode 100644 net/rxrpc/ar-internal.h create mode 100644 net/rxrpc/call_accept.c create mode 100644 net/rxrpc/call_event.c create mode 100644 net/rxrpc/call_object.c create mode 100644 net/rxrpc/conn_client.c create mode 100644 net/rxrpc/conn_event.c create mode 100644 net/rxrpc/conn_object.c create mode 100644 net/rxrpc/conn_service.c create mode 100644 net/rxrpc/input.c create mode 100644 net/rxrpc/insecure.c create mode 100644 net/rxrpc/key.c create mode 100644 net/rxrpc/local_event.c create mode 100644 net/rxrpc/local_object.c create mode 100644 net/rxrpc/misc.c create mode 100644 
net/rxrpc/net_ns.c create mode 100644 net/rxrpc/output.c create mode 100644 net/rxrpc/peer_event.c create mode 100644 net/rxrpc/peer_object.c create mode 100644 net/rxrpc/proc.c create mode 100644 net/rxrpc/protocol.h create mode 100644 net/rxrpc/recvmsg.c create mode 100644 net/rxrpc/rtt.c create mode 100644 net/rxrpc/rxkad.c create mode 100644 net/rxrpc/security.c create mode 100644 net/rxrpc/sendmsg.c create mode 100644 net/rxrpc/server_key.c create mode 100644 net/rxrpc/skbuff.c create mode 100644 net/rxrpc/sysctl.c create mode 100644 net/rxrpc/utils.c create mode 100644 net/sched/Kconfig create mode 100644 net/sched/Makefile create mode 100644 net/sched/act_api.c create mode 100644 net/sched/act_bpf.c create mode 100644 net/sched/act_connmark.c create mode 100644 net/sched/act_csum.c create mode 100644 net/sched/act_ct.c create mode 100644 net/sched/act_ctinfo.c create mode 100644 net/sched/act_gact.c create mode 100644 net/sched/act_gate.c create mode 100644 net/sched/act_ife.c create mode 100644 net/sched/act_ipt.c create mode 100644 net/sched/act_meta_mark.c create mode 100644 net/sched/act_meta_skbprio.c create mode 100644 net/sched/act_meta_skbtcindex.c create mode 100644 net/sched/act_mirred.c create mode 100644 net/sched/act_mpls.c create mode 100644 net/sched/act_nat.c create mode 100644 net/sched/act_pedit.c create mode 100644 net/sched/act_police.c create mode 100644 net/sched/act_sample.c create mode 100644 net/sched/act_simple.c create mode 100644 net/sched/act_skbedit.c create mode 100644 net/sched/act_skbmod.c create mode 100644 net/sched/act_tunnel_key.c create mode 100644 net/sched/act_vlan.c create mode 100644 net/sched/cls_api.c create mode 100644 net/sched/cls_basic.c create mode 100644 net/sched/cls_bpf.c create mode 100644 net/sched/cls_cgroup.c create mode 100644 net/sched/cls_flow.c create mode 100644 net/sched/cls_flower.c create mode 100644 net/sched/cls_fw.c create mode 100644 net/sched/cls_matchall.c create mode 100644 net/sched/cls_route.c create mode 100644 net/sched/cls_u32.c create mode 100644 net/sched/em_canid.c create mode 100644 net/sched/em_cmp.c create mode 100644 net/sched/em_ipset.c create mode 100644 net/sched/em_ipt.c create mode 100644 net/sched/em_meta.c create mode 100644 net/sched/em_nbyte.c create mode 100644 net/sched/em_text.c create mode 100644 net/sched/em_u32.c create mode 100644 net/sched/ematch.c create mode 100644 net/sched/sch_api.c create mode 100644 net/sched/sch_atm.c create mode 100644 net/sched/sch_blackhole.c create mode 100644 net/sched/sch_cake.c create mode 100644 net/sched/sch_cbq.c create mode 100644 net/sched/sch_cbs.c create mode 100644 net/sched/sch_choke.c create mode 100644 net/sched/sch_codel.c create mode 100644 net/sched/sch_drr.c create mode 100644 net/sched/sch_dsmark.c create mode 100644 net/sched/sch_etf.c create mode 100644 net/sched/sch_ets.c create mode 100644 net/sched/sch_fifo.c create mode 100644 net/sched/sch_fq.c create mode 100644 net/sched/sch_fq_codel.c create mode 100644 net/sched/sch_fq_pie.c create mode 100644 net/sched/sch_frag.c create mode 100644 net/sched/sch_generic.c create mode 100644 net/sched/sch_gred.c create mode 100644 net/sched/sch_hfsc.c create mode 100644 net/sched/sch_hhf.c create mode 100644 net/sched/sch_htb.c create mode 100644 net/sched/sch_ingress.c create mode 100644 net/sched/sch_mq.c create mode 100644 net/sched/sch_mqprio.c create mode 100644 net/sched/sch_multiq.c create mode 100644 net/sched/sch_netem.c create mode 100644 net/sched/sch_pie.c create mode 100644 
net/sched/sch_plug.c create mode 100644 net/sched/sch_prio.c create mode 100644 net/sched/sch_qfq.c create mode 100644 net/sched/sch_red.c create mode 100644 net/sched/sch_sfb.c create mode 100644 net/sched/sch_sfq.c create mode 100644 net/sched/sch_skbprio.c create mode 100644 net/sched/sch_taprio.c create mode 100644 net/sched/sch_tbf.c create mode 100644 net/sched/sch_teql.c create mode 100644 net/sctp/Kconfig create mode 100644 net/sctp/Makefile create mode 100644 net/sctp/associola.c create mode 100644 net/sctp/auth.c create mode 100644 net/sctp/bind_addr.c create mode 100644 net/sctp/chunk.c create mode 100644 net/sctp/debug.c create mode 100644 net/sctp/diag.c create mode 100644 net/sctp/endpointola.c create mode 100644 net/sctp/input.c create mode 100644 net/sctp/inqueue.c create mode 100644 net/sctp/ipv6.c create mode 100644 net/sctp/objcnt.c create mode 100644 net/sctp/offload.c create mode 100644 net/sctp/output.c create mode 100644 net/sctp/outqueue.c create mode 100644 net/sctp/primitive.c create mode 100644 net/sctp/proc.c create mode 100644 net/sctp/protocol.c create mode 100644 net/sctp/sm_make_chunk.c create mode 100644 net/sctp/sm_sideeffect.c create mode 100644 net/sctp/sm_statefuns.c create mode 100644 net/sctp/sm_statetable.c create mode 100644 net/sctp/socket.c create mode 100644 net/sctp/stream.c create mode 100644 net/sctp/stream_interleave.c create mode 100644 net/sctp/stream_sched.c create mode 100644 net/sctp/stream_sched_prio.c create mode 100644 net/sctp/stream_sched_rr.c create mode 100644 net/sctp/sysctl.c create mode 100644 net/sctp/transport.c create mode 100644 net/sctp/tsnmap.c create mode 100644 net/sctp/ulpevent.c create mode 100644 net/sctp/ulpqueue.c create mode 100644 net/smc/Kconfig create mode 100644 net/smc/Makefile create mode 100644 net/smc/af_smc.c create mode 100644 net/smc/smc.h create mode 100644 net/smc/smc_cdc.c create mode 100644 net/smc/smc_cdc.h create mode 100644 net/smc/smc_clc.c create mode 100644 net/smc/smc_clc.h create mode 100644 net/smc/smc_close.c create mode 100644 net/smc/smc_close.h create mode 100644 net/smc/smc_core.c create mode 100644 net/smc/smc_core.h create mode 100644 net/smc/smc_diag.c create mode 100644 net/smc/smc_ib.c create mode 100644 net/smc/smc_ib.h create mode 100644 net/smc/smc_ism.c create mode 100644 net/smc/smc_ism.h create mode 100644 net/smc/smc_llc.c create mode 100644 net/smc/smc_llc.h create mode 100644 net/smc/smc_netlink.c create mode 100644 net/smc/smc_netlink.h create mode 100644 net/smc/smc_netns.h create mode 100644 net/smc/smc_pnet.c create mode 100644 net/smc/smc_pnet.h create mode 100644 net/smc/smc_rx.c create mode 100644 net/smc/smc_rx.h create mode 100644 net/smc/smc_stats.c create mode 100644 net/smc/smc_stats.h create mode 100644 net/smc/smc_sysctl.c create mode 100644 net/smc/smc_sysctl.h create mode 100644 net/smc/smc_tracepoint.c create mode 100644 net/smc/smc_tracepoint.h create mode 100644 net/smc/smc_tx.c create mode 100644 net/smc/smc_tx.h create mode 100644 net/smc/smc_wr.c create mode 100644 net/smc/smc_wr.h create mode 100644 net/socket.c create mode 100644 net/strparser/Kconfig create mode 100644 net/strparser/Makefile create mode 100644 net/strparser/strparser.c create mode 100644 net/sunrpc/Kconfig create mode 100644 net/sunrpc/Makefile create mode 100644 net/sunrpc/addr.c create mode 100644 net/sunrpc/auth.c create mode 100644 net/sunrpc/auth_gss/Makefile create mode 100644 net/sunrpc/auth_gss/auth_gss.c create mode 100644 net/sunrpc/auth_gss/auth_gss_internal.h create 
mode 100644 net/sunrpc/auth_gss/gss_generic_token.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_crypto.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_keys.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_mech.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_seal.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_seqnum.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_unseal.c create mode 100644 net/sunrpc/auth_gss/gss_krb5_wrap.c create mode 100644 net/sunrpc/auth_gss/gss_mech_switch.c create mode 100644 net/sunrpc/auth_gss/gss_rpc_upcall.c create mode 100644 net/sunrpc/auth_gss/gss_rpc_upcall.h create mode 100644 net/sunrpc/auth_gss/gss_rpc_xdr.c create mode 100644 net/sunrpc/auth_gss/gss_rpc_xdr.h create mode 100644 net/sunrpc/auth_gss/svcauth_gss.c create mode 100644 net/sunrpc/auth_gss/trace.c create mode 100644 net/sunrpc/auth_null.c create mode 100644 net/sunrpc/auth_unix.c create mode 100644 net/sunrpc/backchannel_rqst.c create mode 100644 net/sunrpc/cache.c create mode 100644 net/sunrpc/clnt.c create mode 100644 net/sunrpc/debugfs.c create mode 100644 net/sunrpc/fail.h create mode 100644 net/sunrpc/netns.h create mode 100644 net/sunrpc/rpc_pipe.c create mode 100644 net/sunrpc/rpcb_clnt.c create mode 100644 net/sunrpc/sched.c create mode 100644 net/sunrpc/socklib.c create mode 100644 net/sunrpc/socklib.h create mode 100644 net/sunrpc/stats.c create mode 100644 net/sunrpc/sunrpc.h create mode 100644 net/sunrpc/sunrpc_syms.c create mode 100644 net/sunrpc/svc.c create mode 100644 net/sunrpc/svc_xprt.c create mode 100644 net/sunrpc/svcauth.c create mode 100644 net/sunrpc/svcauth_unix.c create mode 100644 net/sunrpc/svcsock.c create mode 100644 net/sunrpc/sysctl.c create mode 100644 net/sunrpc/sysfs.c create mode 100644 net/sunrpc/sysfs.h create mode 100644 net/sunrpc/timer.c create mode 100644 net/sunrpc/xdr.c create mode 100644 net/sunrpc/xprt.c create mode 100644 net/sunrpc/xprtmultipath.c create mode 100644 net/sunrpc/xprtrdma/Makefile create mode 100644 net/sunrpc/xprtrdma/backchannel.c create mode 100644 net/sunrpc/xprtrdma/frwr_ops.c create mode 100644 net/sunrpc/xprtrdma/module.c create mode 100644 net/sunrpc/xprtrdma/rpc_rdma.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_backchannel.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_pcl.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_recvfrom.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_rw.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_sendto.c create mode 100644 net/sunrpc/xprtrdma/svc_rdma_transport.c create mode 100644 net/sunrpc/xprtrdma/transport.c create mode 100644 net/sunrpc/xprtrdma/verbs.c create mode 100644 net/sunrpc/xprtrdma/xprt_rdma.h create mode 100644 net/sunrpc/xprtsock.c create mode 100644 net/switchdev/Kconfig create mode 100644 net/switchdev/Makefile create mode 100644 net/switchdev/switchdev.c create mode 100644 net/sysctl_net.c create mode 100644 net/tipc/Kconfig create mode 100644 net/tipc/Makefile create mode 100644 net/tipc/addr.c create mode 100644 net/tipc/addr.h create mode 100644 net/tipc/bcast.c create mode 100644 net/tipc/bcast.h create mode 100644 net/tipc/bearer.c create mode 100644 net/tipc/bearer.h create mode 100644 net/tipc/core.c create mode 100644 net/tipc/core.h create mode 100644 net/tipc/crypto.c create mode 100644 net/tipc/crypto.h create mode 100644 net/tipc/diag.c create mode 100644 net/tipc/discover.c create mode 100644 net/tipc/discover.h create mode 100644 net/tipc/eth_media.c create mode 100644 
net/tipc/group.c create mode 100644 net/tipc/group.h create mode 100644 net/tipc/ib_media.c create mode 100644 net/tipc/link.c create mode 100644 net/tipc/link.h create mode 100644 net/tipc/monitor.c create mode 100644 net/tipc/monitor.h create mode 100644 net/tipc/msg.c create mode 100644 net/tipc/msg.h create mode 100644 net/tipc/name_distr.c create mode 100644 net/tipc/name_distr.h create mode 100644 net/tipc/name_table.c create mode 100644 net/tipc/name_table.h create mode 100644 net/tipc/net.c create mode 100644 net/tipc/net.h create mode 100644 net/tipc/netlink.c create mode 100644 net/tipc/netlink.h create mode 100644 net/tipc/netlink_compat.c create mode 100644 net/tipc/node.c create mode 100644 net/tipc/node.h create mode 100644 net/tipc/socket.c create mode 100644 net/tipc/socket.h create mode 100644 net/tipc/subscr.c create mode 100644 net/tipc/subscr.h create mode 100644 net/tipc/sysctl.c create mode 100644 net/tipc/topsrv.c create mode 100644 net/tipc/topsrv.h create mode 100644 net/tipc/trace.c create mode 100644 net/tipc/trace.h create mode 100644 net/tipc/udp_media.c create mode 100644 net/tipc/udp_media.h create mode 100644 net/tls/Kconfig create mode 100644 net/tls/Makefile create mode 100644 net/tls/tls.h create mode 100644 net/tls/tls_device.c create mode 100644 net/tls/tls_device_fallback.c create mode 100644 net/tls/tls_main.c create mode 100644 net/tls/tls_proc.c create mode 100644 net/tls/tls_strp.c create mode 100644 net/tls/tls_sw.c create mode 100644 net/tls/tls_toe.c create mode 100644 net/tls/trace.c create mode 100644 net/tls/trace.h create mode 100644 net/unix/Kconfig create mode 100644 net/unix/Makefile create mode 100644 net/unix/af_unix.c create mode 100644 net/unix/diag.c create mode 100644 net/unix/garbage.c create mode 100644 net/unix/scm.c create mode 100644 net/unix/scm.h create mode 100644 net/unix/sysctl_net_unix.c create mode 100644 net/unix/unix_bpf.c create mode 100644 net/vmw_vsock/Kconfig create mode 100644 net/vmw_vsock/Makefile create mode 100644 net/vmw_vsock/af_vsock.c create mode 100644 net/vmw_vsock/af_vsock_tap.c create mode 100644 net/vmw_vsock/diag.c create mode 100644 net/vmw_vsock/hyperv_transport.c create mode 100644 net/vmw_vsock/virtio_transport.c create mode 100644 net/vmw_vsock/virtio_transport_common.c create mode 100644 net/vmw_vsock/vmci_transport.c create mode 100644 net/vmw_vsock/vmci_transport.h create mode 100644 net/vmw_vsock/vmci_transport_notify.c create mode 100644 net/vmw_vsock/vmci_transport_notify.h create mode 100644 net/vmw_vsock/vmci_transport_notify_qstate.c create mode 100644 net/vmw_vsock/vsock_addr.c create mode 100644 net/vmw_vsock/vsock_loopback.c create mode 100644 net/wireless/.gitignore create mode 100644 net/wireless/Kconfig create mode 100644 net/wireless/Makefile create mode 100644 net/wireless/ap.c create mode 100644 net/wireless/certs/sforshee.hex create mode 100644 net/wireless/certs/wens.hex create mode 100644 net/wireless/chan.c create mode 100644 net/wireless/core.c create mode 100644 net/wireless/core.h create mode 100644 net/wireless/debugfs.c create mode 100644 net/wireless/debugfs.h create mode 100644 net/wireless/ethtool.c create mode 100644 net/wireless/ibss.c create mode 100644 net/wireless/lib80211.c create mode 100644 net/wireless/lib80211_crypt_ccmp.c create mode 100644 net/wireless/lib80211_crypt_tkip.c create mode 100644 net/wireless/lib80211_crypt_wep.c create mode 100644 net/wireless/mesh.c create mode 100644 net/wireless/mlme.c create mode 100644 net/wireless/nl80211.c create 
mode 100644 net/wireless/nl80211.h create mode 100644 net/wireless/ocb.c create mode 100644 net/wireless/of.c create mode 100644 net/wireless/pmsr.c create mode 100644 net/wireless/radiotap.c create mode 100644 net/wireless/rdev-ops.h create mode 100644 net/wireless/reg.c create mode 100644 net/wireless/reg.h create mode 100644 net/wireless/scan.c create mode 100644 net/wireless/sme.c create mode 100644 net/wireless/sysfs.c create mode 100644 net/wireless/sysfs.h create mode 100644 net/wireless/trace.c create mode 100644 net/wireless/trace.h create mode 100644 net/wireless/util.c create mode 100644 net/wireless/wext-compat.c create mode 100644 net/wireless/wext-compat.h create mode 100644 net/wireless/wext-core.c create mode 100644 net/wireless/wext-priv.c create mode 100644 net/wireless/wext-proc.c create mode 100644 net/wireless/wext-sme.c create mode 100644 net/wireless/wext-spy.c create mode 100644 net/x25/Kconfig create mode 100644 net/x25/Makefile create mode 100644 net/x25/af_x25.c create mode 100644 net/x25/sysctl_net_x25.c create mode 100644 net/x25/x25_dev.c create mode 100644 net/x25/x25_facilities.c create mode 100644 net/x25/x25_forward.c create mode 100644 net/x25/x25_in.c create mode 100644 net/x25/x25_link.c create mode 100644 net/x25/x25_out.c create mode 100644 net/x25/x25_proc.c create mode 100644 net/x25/x25_route.c create mode 100644 net/x25/x25_subr.c create mode 100644 net/x25/x25_timer.c create mode 100644 net/xdp/Kconfig create mode 100644 net/xdp/Makefile create mode 100644 net/xdp/xdp_umem.c create mode 100644 net/xdp/xdp_umem.h create mode 100644 net/xdp/xsk.c create mode 100644 net/xdp/xsk.h create mode 100644 net/xdp/xsk_buff_pool.c create mode 100644 net/xdp/xsk_diag.c create mode 100644 net/xdp/xsk_queue.c create mode 100644 net/xdp/xsk_queue.h create mode 100644 net/xdp/xskmap.c create mode 100644 net/xfrm/Kconfig create mode 100644 net/xfrm/Makefile create mode 100644 net/xfrm/espintcp.c create mode 100644 net/xfrm/xfrm_algo.c create mode 100644 net/xfrm/xfrm_compat.c create mode 100644 net/xfrm/xfrm_device.c create mode 100644 net/xfrm/xfrm_hash.c create mode 100644 net/xfrm/xfrm_hash.h create mode 100644 net/xfrm/xfrm_inout.h create mode 100644 net/xfrm/xfrm_input.c create mode 100644 net/xfrm/xfrm_interface_core.c create mode 100644 net/xfrm/xfrm_ipcomp.c create mode 100644 net/xfrm/xfrm_output.c create mode 100644 net/xfrm/xfrm_policy.c create mode 100644 net/xfrm/xfrm_proc.c create mode 100644 net/xfrm/xfrm_replay.c create mode 100644 net/xfrm/xfrm_state.c create mode 100644 net/xfrm/xfrm_sysctl.c create mode 100644 net/xfrm/xfrm_user.c (limited to 'net') diff --git a/net/6lowpan/6lowpan_i.h b/net/6lowpan/6lowpan_i.h new file mode 100644 index 000000000..01853cec0 --- /dev/null +++ b/net/6lowpan/6lowpan_i.h @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __6LOWPAN_I_H +#define __6LOWPAN_I_H + +#include + +#include + +/* caller need to be sure it's dev->type is ARPHRD_6LOWPAN */ +static inline bool lowpan_is_ll(const struct net_device *dev, + enum lowpan_lltypes lltype) +{ + return lowpan_dev(dev)->lltype == lltype; +} + +extern const struct ndisc_ops lowpan_ndisc_ops; + +int addrconf_ifid_802154_6lowpan(u8 *eui, struct net_device *dev); + +#ifdef CONFIG_6LOWPAN_DEBUGFS +void lowpan_dev_debugfs_init(struct net_device *dev); +void lowpan_dev_debugfs_exit(struct net_device *dev); + +void __init lowpan_debugfs_init(void); +void lowpan_debugfs_exit(void); +#else +static inline void lowpan_dev_debugfs_init(struct net_device *dev) { } 
+static inline void lowpan_dev_debugfs_exit(struct net_device *dev) { } + +static inline void __init lowpan_debugfs_init(void) { } +static inline void lowpan_debugfs_exit(void) { } +#endif /* CONFIG_6LOWPAN_DEBUGFS */ + +#endif /* __6LOWPAN_I_H */ diff --git a/net/6lowpan/Kconfig b/net/6lowpan/Kconfig new file mode 100644 index 000000000..d8fc45949 --- /dev/null +++ b/net/6lowpan/Kconfig @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig 6LOWPAN + tristate "6LoWPAN Support" + depends on IPV6 + help + This enables IPv6 over Low power Wireless Personal Area Network - + "6LoWPAN" which is supported by IEEE 802.15.4 or Bluetooth stacks. + +config 6LOWPAN_DEBUGFS + bool "6LoWPAN debugfs support" + depends on 6LOWPAN + depends on DEBUG_FS + help + This enables 6LoWPAN debugfs support. For example to manipulate + IPHC context information at runtime. + +menuconfig 6LOWPAN_NHC + tristate "Next Header and Generic Header Compression Support" + depends on 6LOWPAN + default y + help + Support for next header and generic header compression defined in + RFC6282 and RFC7400. + +if 6LOWPAN_NHC + +config 6LOWPAN_NHC_DEST + tristate "Destination Options Header Support" + default y + help + 6LoWPAN IPv6 Destination Options Header compression according to + RFC6282. + +config 6LOWPAN_NHC_FRAGMENT + tristate "Fragment Header Support" + default y + help + 6LoWPAN IPv6 Fragment Header compression according to RFC6282. + +config 6LOWPAN_NHC_HOP + tristate "Hop-by-Hop Options Header Support" + default y + help + 6LoWPAN IPv6 Hop-by-Hop Options Header compression according to + RFC6282. + +config 6LOWPAN_NHC_IPV6 + tristate "IPv6 Header Support" + default y + help + 6LoWPAN IPv6 Header compression according to RFC6282. + +config 6LOWPAN_NHC_MOBILITY + tristate "Mobility Header Support" + default y + help + 6LoWPAN IPv6 Mobility Header compression according to RFC6282. + +config 6LOWPAN_NHC_ROUTING + tristate "Routing Header Support" + default y + help + 6LoWPAN IPv6 Routing Header compression according to RFC6282. + +config 6LOWPAN_NHC_UDP + tristate "UDP Header Support" + default y + help + 6LoWPAN IPv6 UDP Header compression according to RFC6282. + +config 6LOWPAN_GHC_EXT_HDR_HOP + tristate "GHC Hop-by-Hop Options Header Support" + help + 6LoWPAN IPv6 Hop-by-Hop option generic header compression according + to RFC7400. + +config 6LOWPAN_GHC_UDP + tristate "GHC UDP Support" + help + 6LoWPAN IPv6 UDP generic header compression according to RFC7400. + +config 6LOWPAN_GHC_ICMPV6 + tristate "GHC ICMPv6 Support" + help + 6LoWPAN IPv6 ICMPv6 generic header compression according to RFC7400. + +config 6LOWPAN_GHC_EXT_HDR_DEST + tristate "GHC Destination Options Header Support" + help + 6LoWPAN IPv6 destination option generic header compression according + to RFC7400. + +config 6LOWPAN_GHC_EXT_HDR_FRAG + tristate "GHC Fragmentation Options Header Support" + help + 6LoWPAN IPv6 fragmentation option generic header compression + according to RFC7400. + +config 6LOWPAN_GHC_EXT_HDR_ROUTE + tristate "GHC Routing Options Header Support" + help + 6LoWPAN IPv6 routing option generic header compression according + to RFC7400. 
+ +endif diff --git a/net/6lowpan/Makefile b/net/6lowpan/Makefile new file mode 100644 index 000000000..2247b96db --- /dev/null +++ b/net/6lowpan/Makefile @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_6LOWPAN) += 6lowpan.o + +6lowpan-y := core.o iphc.o nhc.o ndisc.o +6lowpan-$(CONFIG_6LOWPAN_DEBUGFS) += debugfs.o + +#rfc6282 nhcs +obj-$(CONFIG_6LOWPAN_NHC_DEST) += nhc_dest.o +obj-$(CONFIG_6LOWPAN_NHC_FRAGMENT) += nhc_fragment.o +obj-$(CONFIG_6LOWPAN_NHC_HOP) += nhc_hop.o +obj-$(CONFIG_6LOWPAN_NHC_IPV6) += nhc_ipv6.o +obj-$(CONFIG_6LOWPAN_NHC_MOBILITY) += nhc_mobility.o +obj-$(CONFIG_6LOWPAN_NHC_ROUTING) += nhc_routing.o +obj-$(CONFIG_6LOWPAN_NHC_UDP) += nhc_udp.o + +#rfc7400 ghcs +obj-$(CONFIG_6LOWPAN_GHC_EXT_HDR_HOP) += nhc_ghc_ext_hop.o +obj-$(CONFIG_6LOWPAN_GHC_UDP) += nhc_ghc_udp.o +obj-$(CONFIG_6LOWPAN_GHC_ICMPV6) += nhc_ghc_icmpv6.o +obj-$(CONFIG_6LOWPAN_GHC_EXT_HDR_DEST) += nhc_ghc_ext_dest.o +obj-$(CONFIG_6LOWPAN_GHC_EXT_HDR_FRAG) += nhc_ghc_ext_frag.o +obj-$(CONFIG_6LOWPAN_GHC_EXT_HDR_ROUTE) += nhc_ghc_ext_route.o diff --git a/net/6lowpan/core.c b/net/6lowpan/core.c new file mode 100644 index 000000000..7b3341cef --- /dev/null +++ b/net/6lowpan/core.c @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * (C) 2015 Pengutronix, Alexander Aring + */ + +#include +#include + +#include +#include + +#include "6lowpan_i.h" + +int lowpan_register_netdevice(struct net_device *dev, + enum lowpan_lltypes lltype) +{ + int i, ret; + + switch (lltype) { + case LOWPAN_LLTYPE_IEEE802154: + dev->addr_len = EUI64_ADDR_LEN; + break; + + case LOWPAN_LLTYPE_BTLE: + dev->addr_len = ETH_ALEN; + break; + } + + dev->type = ARPHRD_6LOWPAN; + dev->mtu = IPV6_MIN_MTU; + + lowpan_dev(dev)->lltype = lltype; + + spin_lock_init(&lowpan_dev(dev)->ctx.lock); + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) + lowpan_dev(dev)->ctx.table[i].id = i; + + dev->ndisc_ops = &lowpan_ndisc_ops; + + ret = register_netdevice(dev); + if (ret < 0) + return ret; + + lowpan_dev_debugfs_init(dev); + + return ret; +} +EXPORT_SYMBOL(lowpan_register_netdevice); + +int lowpan_register_netdev(struct net_device *dev, + enum lowpan_lltypes lltype) +{ + int ret; + + rtnl_lock(); + ret = lowpan_register_netdevice(dev, lltype); + rtnl_unlock(); + return ret; +} +EXPORT_SYMBOL(lowpan_register_netdev); + +void lowpan_unregister_netdevice(struct net_device *dev) +{ + unregister_netdevice(dev); + lowpan_dev_debugfs_exit(dev); +} +EXPORT_SYMBOL(lowpan_unregister_netdevice); + +void lowpan_unregister_netdev(struct net_device *dev) +{ + rtnl_lock(); + lowpan_unregister_netdevice(dev); + rtnl_unlock(); +} +EXPORT_SYMBOL(lowpan_unregister_netdev); + +int addrconf_ifid_802154_6lowpan(u8 *eui, struct net_device *dev) +{ + struct wpan_dev *wpan_dev = lowpan_802154_dev(dev)->wdev->ieee802154_ptr; + + /* Set short_addr autoconfiguration if short_addr is present only */ + if (!lowpan_802154_is_valid_src_short_addr(wpan_dev->short_addr)) + return -1; + + /* For either address format, all zero addresses MUST NOT be used */ + if (wpan_dev->pan_id == cpu_to_le16(0x0000) && + wpan_dev->short_addr == cpu_to_le16(0x0000)) + return -1; + + /* Alternatively, if no PAN ID is known, 16 zero bits may be used */ + if (wpan_dev->pan_id == cpu_to_le16(IEEE802154_PAN_ID_BROADCAST)) + memset(eui, 0, 2); + else + ieee802154_le16_to_be16(eui, &wpan_dev->pan_id); + + /* The "Universal/Local" (U/L) bit shall be set to zero */ + eui[0] &= ~2; + eui[2] = 0; + eui[3] = 0xFF; + eui[4] = 0xFE; + eui[5] = 0; + 
ieee802154_le16_to_be16(&eui[6], &wpan_dev->short_addr); + return 0; +} + +static int lowpan_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct inet6_dev *idev; + struct in6_addr addr; + int i; + + if (dev->type != ARPHRD_6LOWPAN) + return NOTIFY_DONE; + + idev = __in6_dev_get(dev); + if (!idev) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UP: + case NETDEV_CHANGE: + /* (802.15.4 6LoWPAN short address slaac handling */ + if (lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154) && + addrconf_ifid_802154_6lowpan(addr.s6_addr + 8, dev) == 0) { + __ipv6_addr_set_half(&addr.s6_addr32[0], + htonl(0xFE800000), 0); + addrconf_add_linklocal(idev, &addr, 0); + } + break; + case NETDEV_DOWN: + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) + clear_bit(LOWPAN_IPHC_CTX_FLAG_ACTIVE, + &lowpan_dev(dev)->ctx.table[i].flags); + break; + default: + return NOTIFY_DONE; + } + + return NOTIFY_OK; +} + +static struct notifier_block lowpan_notifier = { + .notifier_call = lowpan_event, +}; + +static int __init lowpan_module_init(void) +{ + int ret; + + lowpan_debugfs_init(); + + ret = register_netdevice_notifier(&lowpan_notifier); + if (ret < 0) { + lowpan_debugfs_exit(); + return ret; + } + + request_module_nowait("nhc_dest"); + request_module_nowait("nhc_fragment"); + request_module_nowait("nhc_hop"); + request_module_nowait("nhc_ipv6"); + request_module_nowait("nhc_mobility"); + request_module_nowait("nhc_routing"); + request_module_nowait("nhc_udp"); + + return 0; +} + +static void __exit lowpan_module_exit(void) +{ + lowpan_debugfs_exit(); + unregister_netdevice_notifier(&lowpan_notifier); +} + +module_init(lowpan_module_init); +module_exit(lowpan_module_exit); + +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/debugfs.c b/net/6lowpan/debugfs.c new file mode 100644 index 000000000..600b9563b --- /dev/null +++ b/net/6lowpan/debugfs.c @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * (C) 2015 Pengutronix, Alexander Aring + * Copyright (c) 2015 Nordic Semiconductor. All Rights Reserved. 
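
As a worked illustration of the interface-identifier layout produced by addrconf_ifid_802154_6lowpan() above, the following standalone sketch (userspace C, not kernel code) builds the PAN-ID:00:FF:FE:00:short-address form with the "Universal/Local" bit cleared. The validity checks are reduced to the all-zero case, and the helper name and sample PAN ID / short address are invented for the example.

#include <stdint.h>
#include <stdio.h>

static void put_be16(uint8_t *p, uint16_t v)
{
	p[0] = v >> 8;		/* most significant byte first */
	p[1] = v & 0xff;
}

/* Build the 8-byte IID: PAN:PAN:00:FF:FE:00:SHORT:SHORT, U/L bit cleared.
 * 0xffff is treated as the broadcast PAN ID, in which case 16 zero bits
 * are used instead of the PAN ID.
 */
static int short_addr_to_iid(uint8_t iid[8], uint16_t pan_id,
			     uint16_t short_addr)
{
	/* an all-zero PAN ID + short address must not be used */
	if (pan_id == 0x0000 && short_addr == 0x0000)
		return -1;

	if (pan_id == 0xffff)
		put_be16(&iid[0], 0);
	else
		put_be16(&iid[0], pan_id);

	iid[0] &= ~0x02;	/* the "Universal/Local" bit shall be zero */
	iid[2] = 0x00;
	iid[3] = 0xFF;
	iid[4] = 0xFE;
	iid[5] = 0x00;
	put_be16(&iid[6], short_addr);

	return 0;
}

int main(void)
{
	uint8_t iid[8];
	int i;

	if (short_addr_to_iid(iid, 0xbeef, 0x0001) == 0) {
		for (i = 0; i < 8; i++)
			printf("%02x%c", iid[i], i == 7 ? '\n' : ':');
	}
	return 0;
}
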
+ */ + +#include + +#include "6lowpan_i.h" + +#define LOWPAN_DEBUGFS_CTX_PFX_NUM_ARGS 8 + +static struct dentry *lowpan_debugfs; + +static int lowpan_ctx_flag_active_set(void *data, u64 val) +{ + struct lowpan_iphc_ctx *ctx = data; + + if (val != 0 && val != 1) + return -EINVAL; + + if (val) + set_bit(LOWPAN_IPHC_CTX_FLAG_ACTIVE, &ctx->flags); + else + clear_bit(LOWPAN_IPHC_CTX_FLAG_ACTIVE, &ctx->flags); + + return 0; +} + +static int lowpan_ctx_flag_active_get(void *data, u64 *val) +{ + *val = lowpan_iphc_ctx_is_active(data); + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_flag_active_fops, + lowpan_ctx_flag_active_get, + lowpan_ctx_flag_active_set, "%llu\n"); + +static int lowpan_ctx_flag_c_set(void *data, u64 val) +{ + struct lowpan_iphc_ctx *ctx = data; + + if (val != 0 && val != 1) + return -EINVAL; + + if (val) + set_bit(LOWPAN_IPHC_CTX_FLAG_COMPRESSION, &ctx->flags); + else + clear_bit(LOWPAN_IPHC_CTX_FLAG_COMPRESSION, &ctx->flags); + + return 0; +} + +static int lowpan_ctx_flag_c_get(void *data, u64 *val) +{ + *val = lowpan_iphc_ctx_is_compression(data); + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_flag_c_fops, lowpan_ctx_flag_c_get, + lowpan_ctx_flag_c_set, "%llu\n"); + +static int lowpan_ctx_plen_set(void *data, u64 val) +{ + struct lowpan_iphc_ctx *ctx = data; + struct lowpan_iphc_ctx_table *t = + container_of(ctx, struct lowpan_iphc_ctx_table, table[ctx->id]); + + if (val > 128) + return -EINVAL; + + spin_lock_bh(&t->lock); + ctx->plen = val; + spin_unlock_bh(&t->lock); + + return 0; +} + +static int lowpan_ctx_plen_get(void *data, u64 *val) +{ + struct lowpan_iphc_ctx *ctx = data; + struct lowpan_iphc_ctx_table *t = + container_of(ctx, struct lowpan_iphc_ctx_table, table[ctx->id]); + + spin_lock_bh(&t->lock); + *val = ctx->plen; + spin_unlock_bh(&t->lock); + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_ctx_plen_fops, lowpan_ctx_plen_get, + lowpan_ctx_plen_set, "%llu\n"); + +static int lowpan_ctx_pfx_show(struct seq_file *file, void *offset) +{ + struct lowpan_iphc_ctx *ctx = file->private; + struct lowpan_iphc_ctx_table *t = + container_of(ctx, struct lowpan_iphc_ctx_table, table[ctx->id]); + + spin_lock_bh(&t->lock); + seq_printf(file, "%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x\n", + be16_to_cpu(ctx->pfx.s6_addr16[0]), + be16_to_cpu(ctx->pfx.s6_addr16[1]), + be16_to_cpu(ctx->pfx.s6_addr16[2]), + be16_to_cpu(ctx->pfx.s6_addr16[3]), + be16_to_cpu(ctx->pfx.s6_addr16[4]), + be16_to_cpu(ctx->pfx.s6_addr16[5]), + be16_to_cpu(ctx->pfx.s6_addr16[6]), + be16_to_cpu(ctx->pfx.s6_addr16[7])); + spin_unlock_bh(&t->lock); + + return 0; +} + +static int lowpan_ctx_pfx_open(struct inode *inode, struct file *file) +{ + return single_open(file, lowpan_ctx_pfx_show, inode->i_private); +} + +static ssize_t lowpan_ctx_pfx_write(struct file *fp, + const char __user *user_buf, size_t count, + loff_t *ppos) +{ + char buf[128] = {}; + struct seq_file *file = fp->private_data; + struct lowpan_iphc_ctx *ctx = file->private; + struct lowpan_iphc_ctx_table *t = + container_of(ctx, struct lowpan_iphc_ctx_table, table[ctx->id]); + int status = count, n, i; + unsigned int addr[8]; + + if (copy_from_user(&buf, user_buf, min_t(size_t, sizeof(buf) - 1, + count))) { + status = -EFAULT; + goto out; + } + + n = sscanf(buf, "%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x", + &addr[0], &addr[1], &addr[2], &addr[3], &addr[4], + &addr[5], &addr[6], &addr[7]); + if (n != LOWPAN_DEBUGFS_CTX_PFX_NUM_ARGS) { + status = -EINVAL; + goto out; + } + + spin_lock_bh(&t->lock); + for (i = 0; i < 8; i++) + 
ctx->pfx.s6_addr16[i] = cpu_to_be16(addr[i] & 0xffff); + spin_unlock_bh(&t->lock); + +out: + return status; +} + +static const struct file_operations lowpan_ctx_pfx_fops = { + .open = lowpan_ctx_pfx_open, + .read = seq_read, + .write = lowpan_ctx_pfx_write, + .llseek = seq_lseek, + .release = single_release, +}; + +static void lowpan_dev_debugfs_ctx_init(struct net_device *dev, + struct dentry *ctx, u8 id) +{ + struct lowpan_dev *ldev = lowpan_dev(dev); + struct dentry *root; + char buf[32]; + + if (WARN_ON_ONCE(id >= LOWPAN_IPHC_CTX_TABLE_SIZE)) + return; + + sprintf(buf, "%d", id); + + root = debugfs_create_dir(buf, ctx); + + debugfs_create_file("active", 0644, root, &ldev->ctx.table[id], + &lowpan_ctx_flag_active_fops); + + debugfs_create_file("compression", 0644, root, &ldev->ctx.table[id], + &lowpan_ctx_flag_c_fops); + + debugfs_create_file("prefix", 0644, root, &ldev->ctx.table[id], + &lowpan_ctx_pfx_fops); + + debugfs_create_file("prefix_len", 0644, root, &ldev->ctx.table[id], + &lowpan_ctx_plen_fops); +} + +static int lowpan_context_show(struct seq_file *file, void *offset) +{ + struct lowpan_iphc_ctx_table *t = file->private; + int i; + + seq_printf(file, "%3s|%-43s|%c\n", "cid", "prefix", 'C'); + seq_puts(file, "-------------------------------------------------\n"); + + spin_lock_bh(&t->lock); + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) { + if (!lowpan_iphc_ctx_is_active(&t->table[i])) + continue; + + seq_printf(file, "%3d|%39pI6c/%-3d|%d\n", t->table[i].id, + &t->table[i].pfx, t->table[i].plen, + lowpan_iphc_ctx_is_compression(&t->table[i])); + } + spin_unlock_bh(&t->lock); + + return 0; +} +DEFINE_SHOW_ATTRIBUTE(lowpan_context); + +static int lowpan_short_addr_get(void *data, u64 *val) +{ + struct wpan_dev *wdev = data; + + rtnl_lock(); + *val = le16_to_cpu(wdev->short_addr); + rtnl_unlock(); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_short_addr_fops, lowpan_short_addr_get, NULL, + "0x%04llx\n"); + +static void lowpan_dev_debugfs_802154_init(const struct net_device *dev, + struct lowpan_dev *ldev) +{ + struct dentry *root; + + if (!lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154)) + return; + + root = debugfs_create_dir("ieee802154", ldev->iface_debugfs); + + debugfs_create_file("short_addr", 0444, root, + lowpan_802154_dev(dev)->wdev->ieee802154_ptr, + &lowpan_short_addr_fops); +} + +void lowpan_dev_debugfs_init(struct net_device *dev) +{ + struct lowpan_dev *ldev = lowpan_dev(dev); + struct dentry *contexts; + int i; + + /* creating the root */ + ldev->iface_debugfs = debugfs_create_dir(dev->name, lowpan_debugfs); + + contexts = debugfs_create_dir("contexts", ldev->iface_debugfs); + + debugfs_create_file("show", 0644, contexts, &lowpan_dev(dev)->ctx, + &lowpan_context_fops); + + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) + lowpan_dev_debugfs_ctx_init(dev, contexts, i); + + lowpan_dev_debugfs_802154_init(dev, ldev); +} + +void lowpan_dev_debugfs_exit(struct net_device *dev) +{ + debugfs_remove_recursive(lowpan_dev(dev)->iface_debugfs); +} + +void __init lowpan_debugfs_init(void) +{ + lowpan_debugfs = debugfs_create_dir("6lowpan", NULL); +} + +void lowpan_debugfs_exit(void) +{ + debugfs_remove_recursive(lowpan_debugfs); +} diff --git a/net/6lowpan/iphc.c b/net/6lowpan/iphc.c new file mode 100644 index 000000000..52fad5dad --- /dev/null +++ b/net/6lowpan/iphc.c @@ -0,0 +1,1313 @@ +/* + * Copyright 2011, Siemens AG + * written by Alexander Smirnov + */ + +/* Based on patches from Jon Smirl + * Copyright (c) 2011 Jon Smirl + * + * This program is free software; 
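
The debugfs files registered above end up under <debugfs>/6lowpan/<ifname>/contexts/<id>/ (active, compression, prefix, prefix_len, plus a contexts/show summary). The snippet below is a small userspace sketch of configuring context 0; it assumes debugfs is mounted at /sys/kernel/debug, that the interface is called lowpan0, and fd00::/64 is only an example prefix.

#include <stdio.h>

static int write_str(const char *path, const char *val)
{
	FILE *f = fopen(path, "w");

	if (!f) {
		perror(path);
		return -1;
	}
	fputs(val, f);
	return fclose(f);
}

int main(void)
{
	const char *ctx = "/sys/kernel/debug/6lowpan/lowpan0/contexts/0";
	char path[256];

	/* the prefix write path parses eight colon-separated 16-bit hex groups */
	snprintf(path, sizeof(path), "%s/prefix", ctx);
	write_str(path, "fd00:0000:0000:0000:0000:0000:0000:0000");

	snprintf(path, sizeof(path), "%s/prefix_len", ctx);
	write_str(path, "64");

	/* allow this context for compression and mark it active */
	snprintf(path, sizeof(path), "%s/compression", ctx);
	write_str(path, "1");

	snprintf(path, sizeof(path), "%s/active", ctx);
	write_str(path, "1");

	return 0;
}
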
you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + */ + +/* Jon's code is based on 6lowpan implementation for Contiki which is: + * Copyright (c) 2008, Swedish Institute of Computer Science. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the Institute nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. 
+ */ + +#include +#include +#include + +#include +#include + +#include "6lowpan_i.h" +#include "nhc.h" + +/* Values of fields within the IPHC encoding first byte */ +#define LOWPAN_IPHC_TF_MASK 0x18 +#define LOWPAN_IPHC_TF_00 0x00 +#define LOWPAN_IPHC_TF_01 0x08 +#define LOWPAN_IPHC_TF_10 0x10 +#define LOWPAN_IPHC_TF_11 0x18 + +#define LOWPAN_IPHC_NH 0x04 + +#define LOWPAN_IPHC_HLIM_MASK 0x03 +#define LOWPAN_IPHC_HLIM_00 0x00 +#define LOWPAN_IPHC_HLIM_01 0x01 +#define LOWPAN_IPHC_HLIM_10 0x02 +#define LOWPAN_IPHC_HLIM_11 0x03 + +/* Values of fields within the IPHC encoding second byte */ +#define LOWPAN_IPHC_CID 0x80 + +#define LOWPAN_IPHC_SAC 0x40 + +#define LOWPAN_IPHC_SAM_MASK 0x30 +#define LOWPAN_IPHC_SAM_00 0x00 +#define LOWPAN_IPHC_SAM_01 0x10 +#define LOWPAN_IPHC_SAM_10 0x20 +#define LOWPAN_IPHC_SAM_11 0x30 + +#define LOWPAN_IPHC_M 0x08 + +#define LOWPAN_IPHC_DAC 0x04 + +#define LOWPAN_IPHC_DAM_MASK 0x03 +#define LOWPAN_IPHC_DAM_00 0x00 +#define LOWPAN_IPHC_DAM_01 0x01 +#define LOWPAN_IPHC_DAM_10 0x02 +#define LOWPAN_IPHC_DAM_11 0x03 + +/* ipv6 address based on mac + * second bit-flip (Universe/Local) is done according RFC2464 + */ +#define is_addr_mac_addr_based(a, m) \ + ((((a)->s6_addr[8]) == (((m)[0]) ^ 0x02)) && \ + (((a)->s6_addr[9]) == (m)[1]) && \ + (((a)->s6_addr[10]) == (m)[2]) && \ + (((a)->s6_addr[11]) == (m)[3]) && \ + (((a)->s6_addr[12]) == (m)[4]) && \ + (((a)->s6_addr[13]) == (m)[5]) && \ + (((a)->s6_addr[14]) == (m)[6]) && \ + (((a)->s6_addr[15]) == (m)[7])) + +/* check whether we can compress the IID to 16 bits, + * it's possible for unicast addresses with first 49 bits are zero only. + */ +#define lowpan_is_iid_16_bit_compressable(a) \ + ((((a)->s6_addr16[4]) == 0) && \ + (((a)->s6_addr[10]) == 0) && \ + (((a)->s6_addr[11]) == 0xff) && \ + (((a)->s6_addr[12]) == 0xfe) && \ + (((a)->s6_addr[13]) == 0)) + +/* check whether the 112-bit gid of the multicast address is mappable to: */ + +/* 48 bits, FFXX::00XX:XXXX:XXXX */ +#define lowpan_is_mcast_addr_compressable48(a) \ + ((((a)->s6_addr16[1]) == 0) && \ + (((a)->s6_addr16[2]) == 0) && \ + (((a)->s6_addr16[3]) == 0) && \ + (((a)->s6_addr16[4]) == 0) && \ + (((a)->s6_addr[10]) == 0)) + +/* 32 bits, FFXX::00XX:XXXX */ +#define lowpan_is_mcast_addr_compressable32(a) \ + ((((a)->s6_addr16[1]) == 0) && \ + (((a)->s6_addr16[2]) == 0) && \ + (((a)->s6_addr16[3]) == 0) && \ + (((a)->s6_addr16[4]) == 0) && \ + (((a)->s6_addr16[5]) == 0) && \ + (((a)->s6_addr[12]) == 0)) + +/* 8 bits, FF02::00XX */ +#define lowpan_is_mcast_addr_compressable8(a) \ + ((((a)->s6_addr[1]) == 2) && \ + (((a)->s6_addr16[1]) == 0) && \ + (((a)->s6_addr16[2]) == 0) && \ + (((a)->s6_addr16[3]) == 0) && \ + (((a)->s6_addr16[4]) == 0) && \ + (((a)->s6_addr16[5]) == 0) && \ + (((a)->s6_addr16[6]) == 0) && \ + (((a)->s6_addr[14]) == 0)) + +#define lowpan_is_linklocal_zero_padded(a) \ + (!(hdr->saddr.s6_addr[1] & 0x3f) && \ + !hdr->saddr.s6_addr16[1] && \ + !hdr->saddr.s6_addr32[1]) + +#define LOWPAN_IPHC_CID_DCI(cid) (cid & 0x0f) +#define LOWPAN_IPHC_CID_SCI(cid) ((cid & 0xf0) >> 4) + +static inline void +lowpan_iphc_uncompress_802154_lladdr(struct in6_addr *ipaddr, + const void *lladdr) +{ + const struct ieee802154_addr *addr = lladdr; + u8 eui64[EUI64_ADDR_LEN]; + + switch (addr->mode) { + case IEEE802154_ADDR_LONG: + ieee802154_le64_to_be64(eui64, &addr->extended_addr); + lowpan_iphc_uncompress_eui64_lladdr(ipaddr, eui64); + break; + case IEEE802154_ADDR_SHORT: + /* fe:80::ff:fe00:XXXX + * \__/ + * short_addr + * + * Universe/Local bit is zero. 
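
To make the stateless SAM/DAM compressibility macros above concrete, here is a userspace sketch (not kernel code) that reports how many unicast address bytes IPHC would have to carry inline: 16 for anything that is not link-local, 2 for the fe80::ff:fe00:XXXX short-address form matched by lowpan_is_iid_16_bit_compressable(), and 8 otherwise. The fully elided 0-bit case is left out because it needs the peer's link-layer address; treat this as a simplified model, not the exact kernel decision path.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* how many unicast address bytes IPHC would carry inline */
static int iphc_inline_bytes(const uint8_t a[16])
{
	static const uint8_t ll_pfx[8] = { 0xfe, 0x80, 0, 0, 0, 0, 0, 0 };

	/* not link-local: carried fully inline (or via a context) */
	if (memcmp(a, ll_pfx, 8) != 0)
		return 16;

	/* fe80::ff:fe00:XXXX -> only the last 16 bits go inline */
	if (a[8] == 0 && a[9] == 0 && a[10] == 0 &&
	    a[11] == 0xff && a[12] == 0xfe && a[13] == 0)
		return 2;

	/* fe80::IID -> the 64-bit IID goes inline */
	return 8;
}

int main(void)
{
	const uint8_t short_based[16] = { 0xfe, 0x80, 0, 0, 0, 0, 0, 0,
					  0, 0, 0, 0xff, 0xfe, 0, 0xab, 0xcd };

	printf("inline bytes: %d\n", iphc_inline_bytes(short_based)); /* 2 */
	return 0;
}
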
+ */ + ipaddr->s6_addr[0] = 0xFE; + ipaddr->s6_addr[1] = 0x80; + ipaddr->s6_addr[11] = 0xFF; + ipaddr->s6_addr[12] = 0xFE; + ieee802154_le16_to_be16(&ipaddr->s6_addr16[7], + &addr->short_addr); + break; + default: + /* should never handled and filtered by 802154 6lowpan */ + WARN_ON_ONCE(1); + break; + } +} + +static struct lowpan_iphc_ctx * +lowpan_iphc_ctx_get_by_id(const struct net_device *dev, u8 id) +{ + struct lowpan_iphc_ctx *ret = &lowpan_dev(dev)->ctx.table[id]; + + if (!lowpan_iphc_ctx_is_active(ret)) + return NULL; + + return ret; +} + +static struct lowpan_iphc_ctx * +lowpan_iphc_ctx_get_by_addr(const struct net_device *dev, + const struct in6_addr *addr) +{ + struct lowpan_iphc_ctx *table = lowpan_dev(dev)->ctx.table; + struct lowpan_iphc_ctx *ret = NULL; + struct in6_addr addr_pfx; + u8 addr_plen; + int i; + + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) { + /* Check if context is valid. A context that is not valid + * MUST NOT be used for compression. + */ + if (!lowpan_iphc_ctx_is_active(&table[i]) || + !lowpan_iphc_ctx_is_compression(&table[i])) + continue; + + ipv6_addr_prefix(&addr_pfx, addr, table[i].plen); + + /* if prefix len < 64, the remaining bits until 64th bit is + * zero. Otherwise we use table[i]->plen. + */ + if (table[i].plen < 64) + addr_plen = 64; + else + addr_plen = table[i].plen; + + if (ipv6_prefix_equal(&addr_pfx, &table[i].pfx, addr_plen)) { + /* remember first match */ + if (!ret) { + ret = &table[i]; + continue; + } + + /* get the context with longest prefix len */ + if (table[i].plen > ret->plen) + ret = &table[i]; + } + } + + return ret; +} + +static struct lowpan_iphc_ctx * +lowpan_iphc_ctx_get_by_mcast_addr(const struct net_device *dev, + const struct in6_addr *addr) +{ + struct lowpan_iphc_ctx *table = lowpan_dev(dev)->ctx.table; + struct lowpan_iphc_ctx *ret = NULL; + struct in6_addr addr_mcast, network_pfx = {}; + int i; + + /* init mcast address with */ + memcpy(&addr_mcast, addr, sizeof(*addr)); + + for (i = 0; i < LOWPAN_IPHC_CTX_TABLE_SIZE; i++) { + /* Check if context is valid. A context that is not valid + * MUST NOT be used for compression. + */ + if (!lowpan_iphc_ctx_is_active(&table[i]) || + !lowpan_iphc_ctx_is_compression(&table[i])) + continue; + + /* setting plen */ + addr_mcast.s6_addr[3] = table[i].plen; + /* get network prefix to copy into multicast address */ + ipv6_addr_prefix(&network_pfx, &table[i].pfx, + table[i].plen); + /* setting network prefix */ + memcpy(&addr_mcast.s6_addr[4], &network_pfx, 8); + + if (ipv6_addr_equal(addr, &addr_mcast)) { + ret = &table[i]; + break; + } + } + + return ret; +} + +static void lowpan_iphc_uncompress_lladdr(const struct net_device *dev, + struct in6_addr *ipaddr, + const void *lladdr) +{ + switch (dev->addr_len) { + case ETH_ALEN: + lowpan_iphc_uncompress_eui48_lladdr(ipaddr, lladdr); + break; + case EUI64_ADDR_LEN: + lowpan_iphc_uncompress_eui64_lladdr(ipaddr, lladdr); + break; + default: + WARN_ON_ONCE(1); + break; + } +} + +/* Uncompress address function for source and + * destination address(non-multicast). 
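
The context lookup just defined boils down to a longest-prefix match over the active, compression-enabled entries, never comparing fewer than 64 bits. Below is a minimal userspace sketch of that rule, assuming a plain byte array stands in for struct in6_addr and a reduced struct for struct lowpan_iphc_ctx; names such as ctx_by_addr() are invented for the example.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct ctx {
	uint8_t id;
	uint8_t plen;		/* prefix length in bits */
	uint8_t pfx[16];
	bool active;
	bool compression;
};

/* copy the first "bits" bits of src, zero the rest (cf. ipv6_addr_prefix()) */
static void prefix_copy(uint8_t dst[16], const uint8_t src[16], unsigned int bits)
{
	unsigned int full = bits / 8, rest = bits % 8;

	memset(dst, 0, 16);
	memcpy(dst, src, full);
	if (rest)
		dst[full] = src[full] & (uint8_t)(0xff << (8 - rest));
}

/* true if a and b agree on their first "bits" bits (cf. ipv6_prefix_equal()) */
static bool prefix_equal(const uint8_t *a, const uint8_t *b, unsigned int bits)
{
	uint8_t ta[16], tb[16];

	prefix_copy(ta, a, bits);
	prefix_copy(tb, b, bits);
	return memcmp(ta, tb, 16) == 0;
}

static const struct ctx *ctx_by_addr(const struct ctx *tbl, int n,
				     const uint8_t addr[16])
{
	const struct ctx *best = NULL;
	uint8_t addr_pfx[16];
	int i;

	for (i = 0; i < n; i++) {
		/* compare at least 64 bits, otherwise the context's plen */
		unsigned int plen = tbl[i].plen < 64 ? 64 : tbl[i].plen;

		if (!tbl[i].active || !tbl[i].compression)
			continue;

		prefix_copy(addr_pfx, addr, tbl[i].plen);
		if (!prefix_equal(addr_pfx, tbl[i].pfx, plen))
			continue;

		if (!best || tbl[i].plen > best->plen)
			best = &tbl[i];	/* longest prefix wins */
	}
	return best;
}

int main(void)
{
	struct ctx tbl[2] = {
		{ .id = 0, .plen = 64, .pfx = { 0xfd, 0x00 }, .active = true, .compression = true },
		{ .id = 1, .plen = 48, .pfx = { 0xfd, 0x00 }, .active = true, .compression = true },
	};
	uint8_t addr[16] = { 0xfd, 0x00, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 };
	const struct ctx *c = ctx_by_addr(tbl, 2, addr);

	printf("matched context %d\n", c ? c->id : -1);	/* context 0 (plen 64) */
	return 0;
}
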
+ * + * address_mode is the masked value for sam or dam value + */ +static int lowpan_iphc_uncompress_addr(struct sk_buff *skb, + const struct net_device *dev, + struct in6_addr *ipaddr, + u8 address_mode, const void *lladdr) +{ + bool fail; + + switch (address_mode) { + /* SAM and DAM are the same here */ + case LOWPAN_IPHC_DAM_00: + /* for global link addresses */ + fail = lowpan_fetch_skb(skb, ipaddr->s6_addr, 16); + break; + case LOWPAN_IPHC_SAM_01: + case LOWPAN_IPHC_DAM_01: + /* fe:80::XXXX:XXXX:XXXX:XXXX */ + ipaddr->s6_addr[0] = 0xFE; + ipaddr->s6_addr[1] = 0x80; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[8], 8); + break; + case LOWPAN_IPHC_SAM_10: + case LOWPAN_IPHC_DAM_10: + /* fe:80::ff:fe00:XXXX */ + ipaddr->s6_addr[0] = 0xFE; + ipaddr->s6_addr[1] = 0x80; + ipaddr->s6_addr[11] = 0xFF; + ipaddr->s6_addr[12] = 0xFE; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[14], 2); + break; + case LOWPAN_IPHC_SAM_11: + case LOWPAN_IPHC_DAM_11: + fail = false; + switch (lowpan_dev(dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + lowpan_iphc_uncompress_802154_lladdr(ipaddr, lladdr); + break; + default: + lowpan_iphc_uncompress_lladdr(dev, ipaddr, lladdr); + break; + } + break; + default: + pr_debug("Invalid address mode value: 0x%x\n", address_mode); + return -EINVAL; + } + + if (fail) { + pr_debug("Failed to fetch skb data\n"); + return -EIO; + } + + raw_dump_inline(NULL, "Reconstructed ipv6 addr is", + ipaddr->s6_addr, 16); + + return 0; +} + +/* Uncompress address function for source context + * based address(non-multicast). + */ +static int lowpan_iphc_uncompress_ctx_addr(struct sk_buff *skb, + const struct net_device *dev, + const struct lowpan_iphc_ctx *ctx, + struct in6_addr *ipaddr, + u8 address_mode, const void *lladdr) +{ + bool fail; + + switch (address_mode) { + /* SAM and DAM are the same here */ + case LOWPAN_IPHC_DAM_00: + fail = false; + /* SAM_00 -> unspec address :: + * Do nothing, address is already :: + * + * DAM 00 -> reserved should never occur. + */ + break; + case LOWPAN_IPHC_SAM_01: + case LOWPAN_IPHC_DAM_01: + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[8], 8); + ipv6_addr_prefix_copy(ipaddr, &ctx->pfx, ctx->plen); + break; + case LOWPAN_IPHC_SAM_10: + case LOWPAN_IPHC_DAM_10: + ipaddr->s6_addr[11] = 0xFF; + ipaddr->s6_addr[12] = 0xFE; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[14], 2); + ipv6_addr_prefix_copy(ipaddr, &ctx->pfx, ctx->plen); + break; + case LOWPAN_IPHC_SAM_11: + case LOWPAN_IPHC_DAM_11: + fail = false; + switch (lowpan_dev(dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + lowpan_iphc_uncompress_802154_lladdr(ipaddr, lladdr); + break; + default: + lowpan_iphc_uncompress_lladdr(dev, ipaddr, lladdr); + break; + } + ipv6_addr_prefix_copy(ipaddr, &ctx->pfx, ctx->plen); + break; + default: + pr_debug("Invalid sam value: 0x%x\n", address_mode); + return -EINVAL; + } + + if (fail) { + pr_debug("Failed to fetch skb data\n"); + return -EIO; + } + + raw_dump_inline(NULL, + "Reconstructed context based ipv6 src addr is", + ipaddr->s6_addr, 16); + + return 0; +} + +/* Uncompress function for multicast destination address, + * when M bit is set. + */ +static int lowpan_uncompress_multicast_daddr(struct sk_buff *skb, + struct in6_addr *ipaddr, + u8 address_mode) +{ + bool fail; + + switch (address_mode) { + case LOWPAN_IPHC_DAM_00: + /* 00: 128 bits. The full address + * is carried in-line. + */ + fail = lowpan_fetch_skb(skb, ipaddr->s6_addr, 16); + break; + case LOWPAN_IPHC_DAM_01: + /* 01: 48 bits. 
The address takes + * the form ffXX::00XX:XXXX:XXXX. + */ + ipaddr->s6_addr[0] = 0xFF; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[1], 1); + fail |= lowpan_fetch_skb(skb, &ipaddr->s6_addr[11], 5); + break; + case LOWPAN_IPHC_DAM_10: + /* 10: 32 bits. The address takes + * the form ffXX::00XX:XXXX. + */ + ipaddr->s6_addr[0] = 0xFF; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[1], 1); + fail |= lowpan_fetch_skb(skb, &ipaddr->s6_addr[13], 3); + break; + case LOWPAN_IPHC_DAM_11: + /* 11: 8 bits. The address takes + * the form ff02::00XX. + */ + ipaddr->s6_addr[0] = 0xFF; + ipaddr->s6_addr[1] = 0x02; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[15], 1); + break; + default: + pr_debug("DAM value has a wrong value: 0x%x\n", address_mode); + return -EINVAL; + } + + if (fail) { + pr_debug("Failed to fetch skb data\n"); + return -EIO; + } + + raw_dump_inline(NULL, "Reconstructed ipv6 multicast addr is", + ipaddr->s6_addr, 16); + + return 0; +} + +static int lowpan_uncompress_multicast_ctx_daddr(struct sk_buff *skb, + struct lowpan_iphc_ctx *ctx, + struct in6_addr *ipaddr, + u8 address_mode) +{ + struct in6_addr network_pfx = {}; + bool fail; + + ipaddr->s6_addr[0] = 0xFF; + fail = lowpan_fetch_skb(skb, &ipaddr->s6_addr[1], 2); + fail |= lowpan_fetch_skb(skb, &ipaddr->s6_addr[12], 4); + if (fail) + return -EIO; + + /* take prefix_len and network prefix from the context */ + ipaddr->s6_addr[3] = ctx->plen; + /* get network prefix to copy into multicast address */ + ipv6_addr_prefix(&network_pfx, &ctx->pfx, ctx->plen); + /* setting network prefix */ + memcpy(&ipaddr->s6_addr[4], &network_pfx, 8); + + return 0; +} + +/* get the ecn values from iphc tf format and set it to ipv6hdr */ +static inline void lowpan_iphc_tf_set_ecn(struct ipv6hdr *hdr, const u8 *tf) +{ + /* get the two higher bits which is ecn */ + u8 ecn = tf[0] & 0xc0; + + /* ECN takes 0x30 in hdr->flow_lbl[0] */ + hdr->flow_lbl[0] |= (ecn >> 2); +} + +/* get the dscp values from iphc tf format and set it to ipv6hdr */ +static inline void lowpan_iphc_tf_set_dscp(struct ipv6hdr *hdr, const u8 *tf) +{ + /* DSCP is at place after ECN */ + u8 dscp = tf[0] & 0x3f; + + /* The four highest bits need to be set at hdr->priority */ + hdr->priority |= ((dscp & 0x3c) >> 2); + /* The two lower bits is part of hdr->flow_lbl[0] */ + hdr->flow_lbl[0] |= ((dscp & 0x03) << 6); +} + +/* get the flow label values from iphc tf format and set it to ipv6hdr */ +static inline void lowpan_iphc_tf_set_lbl(struct ipv6hdr *hdr, const u8 *lbl) +{ + /* flow label is always some array started with lower nibble of + * flow_lbl[0] and followed with two bytes afterwards. Inside inline + * data the flow_lbl position can be different, which will be handled + * by lbl pointer. E.g. case "01" vs "00" the traffic class is 8 bit + * shifted, the different lbl pointer will handle that. + * + * The flow label will started at lower nibble of flow_lbl[0], the + * higher nibbles are part of DSCP + ECN. + */ + hdr->flow_lbl[0] |= lbl[0] & 0x0f; + memcpy(&hdr->flow_lbl[1], &lbl[1], 2); +} + +/* lowpan_iphc_tf_decompress - decompress the traffic class. + * This function will return zero on success, a value lower than zero if + * failed. 
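
For reference, the three compressed multicast forms decoded above can be reproduced with a few byte copies. The sketch below (userspace, not kernel code) mirrors the DAM = 01/10/11 cases and feeds in the single inline byte 0x01 to rebuild ff02::1; the fully inline DAM = 00 case is omitted.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* dam is the two-bit DAM value (1, 2 or 3); "inline_bytes" points at the
 * bytes carried in the compressed header, in the order they were fetched.
 */
static void mcast_uncompress(uint8_t ip[16], int dam, const uint8_t *inline_bytes)
{
	memset(ip, 0, 16);
	ip[0] = 0xff;

	switch (dam) {
	case 1:	/* 48 bits: ffXX::00XX:XXXX:XXXX */
		ip[1] = inline_bytes[0];
		memcpy(&ip[11], &inline_bytes[1], 5);
		break;
	case 2:	/* 32 bits: ffXX::00XX:XXXX */
		ip[1] = inline_bytes[0];
		memcpy(&ip[13], &inline_bytes[1], 3);
		break;
	case 3:	/* 8 bits: ff02::00XX */
		ip[1] = 0x02;
		ip[15] = inline_bytes[0];
		break;
	}
}

int main(void)
{
	uint8_t all_nodes_inline[1] = { 0x01 };
	uint8_t ip[16];
	int i;

	mcast_uncompress(ip, 3, all_nodes_inline);	/* -> ff02::1 */
	for (i = 0; i < 16; i++)
		printf("%02x%s", ip[i], i == 15 ? "\n" : (i % 2 ? ":" : ""));
	return 0;
}
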
+ */ +static int lowpan_iphc_tf_decompress(struct sk_buff *skb, struct ipv6hdr *hdr, + u8 val) +{ + u8 tf[4]; + + /* Traffic Class and Flow Label */ + switch (val) { + case LOWPAN_IPHC_TF_00: + /* ECN + DSCP + 4-bit Pad + Flow Label (4 bytes) */ + if (lowpan_fetch_skb(skb, tf, 4)) + return -EINVAL; + + /* 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |ECN| DSCP | rsv | Flow Label | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + lowpan_iphc_tf_set_ecn(hdr, tf); + lowpan_iphc_tf_set_dscp(hdr, tf); + lowpan_iphc_tf_set_lbl(hdr, &tf[1]); + break; + case LOWPAN_IPHC_TF_01: + /* ECN + 2-bit Pad + Flow Label (3 bytes), DSCP is elided. */ + if (lowpan_fetch_skb(skb, tf, 3)) + return -EINVAL; + + /* 1 2 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |ECN|rsv| Flow Label | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + lowpan_iphc_tf_set_ecn(hdr, tf); + lowpan_iphc_tf_set_lbl(hdr, &tf[0]); + break; + case LOWPAN_IPHC_TF_10: + /* ECN + DSCP (1 byte), Flow Label is elided. */ + if (lowpan_fetch_skb(skb, tf, 1)) + return -EINVAL; + + /* 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |ECN| DSCP | + * +-+-+-+-+-+-+-+-+ + */ + lowpan_iphc_tf_set_ecn(hdr, tf); + lowpan_iphc_tf_set_dscp(hdr, tf); + break; + case LOWPAN_IPHC_TF_11: + /* Traffic Class and Flow Label are elided */ + break; + default: + WARN_ON_ONCE(1); + return -EINVAL; + } + + return 0; +} + +/* TTL uncompression values */ +static const u8 lowpan_ttl_values[] = { + [LOWPAN_IPHC_HLIM_01] = 1, + [LOWPAN_IPHC_HLIM_10] = 64, + [LOWPAN_IPHC_HLIM_11] = 255, +}; + +int lowpan_header_decompress(struct sk_buff *skb, const struct net_device *dev, + const void *daddr, const void *saddr) +{ + struct ipv6hdr hdr = {}; + struct lowpan_iphc_ctx *ci; + u8 iphc0, iphc1, cid = 0; + int err; + + raw_dump_table(__func__, "raw skb data dump uncompressed", + skb->data, skb->len); + + if (lowpan_fetch_skb(skb, &iphc0, sizeof(iphc0)) || + lowpan_fetch_skb(skb, &iphc1, sizeof(iphc1))) + return -EINVAL; + + hdr.version = 6; + + /* default CID = 0, another if the CID flag is set */ + if (iphc1 & LOWPAN_IPHC_CID) { + if (lowpan_fetch_skb(skb, &cid, sizeof(cid))) + return -EINVAL; + } + + err = lowpan_iphc_tf_decompress(skb, &hdr, + iphc0 & LOWPAN_IPHC_TF_MASK); + if (err < 0) + return err; + + /* Next Header */ + if (!(iphc0 & LOWPAN_IPHC_NH)) { + /* Next header is carried inline */ + if (lowpan_fetch_skb(skb, &hdr.nexthdr, sizeof(hdr.nexthdr))) + return -EINVAL; + + pr_debug("NH flag is set, next header carried inline: %02x\n", + hdr.nexthdr); + } + + /* Hop Limit */ + if ((iphc0 & LOWPAN_IPHC_HLIM_MASK) != LOWPAN_IPHC_HLIM_00) { + hdr.hop_limit = lowpan_ttl_values[iphc0 & LOWPAN_IPHC_HLIM_MASK]; + } else { + if (lowpan_fetch_skb(skb, &hdr.hop_limit, + sizeof(hdr.hop_limit))) + return -EINVAL; + } + + if (iphc1 & LOWPAN_IPHC_SAC) { + spin_lock_bh(&lowpan_dev(dev)->ctx.lock); + ci = lowpan_iphc_ctx_get_by_id(dev, LOWPAN_IPHC_CID_SCI(cid)); + if (!ci) { + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + return -EINVAL; + } + + pr_debug("SAC bit is set. 
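
The TF = 00 layout decoded above carries ECN ahead of DSCP, the reverse of the IPv6 Traffic Class byte (DSCP in the upper six bits, ECN in the lower two), followed by a 4-bit pad and the 20-bit flow label. A small userspace sketch of that unpacking (not kernel code; the sample values are arbitrary) may make the bit positions easier to follow:

#include <stdint.h>
#include <stdio.h>

struct tc_fl {
	uint8_t  traffic_class;	/* DSCP(6) | ECN(2), as in the IPv6 header */
	uint32_t flow_label;	/* 20 bits */
};

static struct tc_fl tf00_decode(const uint8_t tf[4])
{
	struct tc_fl out;
	uint8_t ecn  = (tf[0] & 0xc0) >> 6;	/* two highest bits */
	uint8_t dscp = tf[0] & 0x3f;		/* six lowest bits */

	out.traffic_class = (uint8_t)((dscp << 2) | ecn);
	out.flow_label = ((uint32_t)(tf[1] & 0x0f) << 16) |
			 ((uint32_t)tf[2] << 8) | tf[3];
	return out;
}

int main(void)
{
	/* ECN = 01, DSCP = 46 (EF), flow label 0x12345 */
	const uint8_t tf[4] = { 0x40 | 46, 0x01, 0x23, 0x45 };
	struct tc_fl r = tf00_decode(tf);

	printf("tc=0x%02x flow=0x%05x\n", r.traffic_class, r.flow_label);
	return 0;
}
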
Handle context based source address.\n"); + err = lowpan_iphc_uncompress_ctx_addr(skb, dev, ci, &hdr.saddr, + iphc1 & LOWPAN_IPHC_SAM_MASK, + saddr); + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + } else { + /* Source address uncompression */ + pr_debug("source address stateless compression\n"); + err = lowpan_iphc_uncompress_addr(skb, dev, &hdr.saddr, + iphc1 & LOWPAN_IPHC_SAM_MASK, + saddr); + } + + /* Check on error of previous branch */ + if (err) + return -EINVAL; + + switch (iphc1 & (LOWPAN_IPHC_M | LOWPAN_IPHC_DAC)) { + case LOWPAN_IPHC_M | LOWPAN_IPHC_DAC: + skb->pkt_type = PACKET_BROADCAST; + + spin_lock_bh(&lowpan_dev(dev)->ctx.lock); + ci = lowpan_iphc_ctx_get_by_id(dev, LOWPAN_IPHC_CID_DCI(cid)); + if (!ci) { + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + return -EINVAL; + } + + /* multicast with context */ + pr_debug("dest: context-based mcast compression\n"); + err = lowpan_uncompress_multicast_ctx_daddr(skb, ci, + &hdr.daddr, + iphc1 & LOWPAN_IPHC_DAM_MASK); + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + break; + case LOWPAN_IPHC_M: + skb->pkt_type = PACKET_BROADCAST; + + /* multicast */ + err = lowpan_uncompress_multicast_daddr(skb, &hdr.daddr, + iphc1 & LOWPAN_IPHC_DAM_MASK); + break; + case LOWPAN_IPHC_DAC: + skb->pkt_type = PACKET_HOST; + + spin_lock_bh(&lowpan_dev(dev)->ctx.lock); + ci = lowpan_iphc_ctx_get_by_id(dev, LOWPAN_IPHC_CID_DCI(cid)); + if (!ci) { + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + return -EINVAL; + } + + /* Destination address context based uncompression */ + pr_debug("DAC bit is set. Handle context based destination address.\n"); + err = lowpan_iphc_uncompress_ctx_addr(skb, dev, ci, &hdr.daddr, + iphc1 & LOWPAN_IPHC_DAM_MASK, + daddr); + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + break; + default: + skb->pkt_type = PACKET_HOST; + + err = lowpan_iphc_uncompress_addr(skb, dev, &hdr.daddr, + iphc1 & LOWPAN_IPHC_DAM_MASK, + daddr); + pr_debug("dest: stateless compression mode %d dest %pI6c\n", + iphc1 & LOWPAN_IPHC_DAM_MASK, &hdr.daddr); + break; + } + + if (err) + return -EINVAL; + + /* Next header data uncompression */ + if (iphc0 & LOWPAN_IPHC_NH) { + err = lowpan_nhc_do_uncompression(skb, dev, &hdr); + if (err < 0) + return err; + } else { + err = skb_cow(skb, sizeof(hdr)); + if (unlikely(err)) + return err; + } + + switch (lowpan_dev(dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + if (lowpan_802154_cb(skb)->d_size) + hdr.payload_len = htons(lowpan_802154_cb(skb)->d_size - + sizeof(struct ipv6hdr)); + else + hdr.payload_len = htons(skb->len); + break; + default: + hdr.payload_len = htons(skb->len); + break; + } + + pr_debug("skb headroom size = %d, data length = %d\n", + skb_headroom(skb), skb->len); + + pr_debug("IPv6 header dump:\n\tversion = %d\n\tlength = %d\n\t" + "nexthdr = 0x%02x\n\thop_lim = %d\n\tdest = %pI6c\n", + hdr.version, ntohs(hdr.payload_len), hdr.nexthdr, + hdr.hop_limit, &hdr.daddr); + + skb_push(skb, sizeof(hdr)); + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb_copy_to_linear_data(skb, &hdr, sizeof(hdr)); + + raw_dump_table(__func__, "raw header dump", (u8 *)&hdr, sizeof(hdr)); + + return 0; +} +EXPORT_SYMBOL_GPL(lowpan_header_decompress); + +static const u8 lowpan_iphc_dam_to_sam_value[] = { + [LOWPAN_IPHC_DAM_00] = LOWPAN_IPHC_SAM_00, + [LOWPAN_IPHC_DAM_01] = LOWPAN_IPHC_SAM_01, + [LOWPAN_IPHC_DAM_10] = LOWPAN_IPHC_SAM_10, + [LOWPAN_IPHC_DAM_11] = LOWPAN_IPHC_SAM_11, +}; + +static inline bool +lowpan_iphc_compress_ctx_802154_lladdr(const struct in6_addr *ipaddr, + const struct lowpan_iphc_ctx 
*ctx, + const void *lladdr) +{ + const struct ieee802154_addr *addr = lladdr; + unsigned char extended_addr[EUI64_ADDR_LEN]; + bool lladdr_compress = false; + struct in6_addr tmp = {}; + + switch (addr->mode) { + case IEEE802154_ADDR_LONG: + ieee802154_le64_to_be64(&extended_addr, &addr->extended_addr); + /* check for SAM/DAM = 11 */ + memcpy(&tmp.s6_addr[8], &extended_addr, EUI64_ADDR_LEN); + /* second bit-flip (Universe/Local) is done according RFC2464 */ + tmp.s6_addr[8] ^= 0x02; + /* context information are always used */ + ipv6_addr_prefix_copy(&tmp, &ctx->pfx, ctx->plen); + if (ipv6_addr_equal(&tmp, ipaddr)) + lladdr_compress = true; + break; + case IEEE802154_ADDR_SHORT: + tmp.s6_addr[11] = 0xFF; + tmp.s6_addr[12] = 0xFE; + ieee802154_le16_to_be16(&tmp.s6_addr16[7], + &addr->short_addr); + /* context information are always used */ + ipv6_addr_prefix_copy(&tmp, &ctx->pfx, ctx->plen); + if (ipv6_addr_equal(&tmp, ipaddr)) + lladdr_compress = true; + break; + default: + /* should never handled and filtered by 802154 6lowpan */ + WARN_ON_ONCE(1); + break; + } + + return lladdr_compress; +} + +static bool lowpan_iphc_addr_equal(const struct net_device *dev, + const struct lowpan_iphc_ctx *ctx, + const struct in6_addr *ipaddr, + const void *lladdr) +{ + struct in6_addr tmp = {}; + + lowpan_iphc_uncompress_lladdr(dev, &tmp, lladdr); + + if (ctx) + ipv6_addr_prefix_copy(&tmp, &ctx->pfx, ctx->plen); + + return ipv6_addr_equal(&tmp, ipaddr); +} + +static u8 lowpan_compress_ctx_addr(u8 **hc_ptr, const struct net_device *dev, + const struct in6_addr *ipaddr, + const struct lowpan_iphc_ctx *ctx, + const unsigned char *lladdr, bool sam) +{ + struct in6_addr tmp = {}; + u8 dam; + + switch (lowpan_dev(dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + if (lowpan_iphc_compress_ctx_802154_lladdr(ipaddr, ctx, + lladdr)) { + dam = LOWPAN_IPHC_DAM_11; + goto out; + } + break; + default: + if (lowpan_iphc_addr_equal(dev, ctx, ipaddr, lladdr)) { + dam = LOWPAN_IPHC_DAM_11; + goto out; + } + break; + } + + memset(&tmp, 0, sizeof(tmp)); + /* check for SAM/DAM = 10 */ + tmp.s6_addr[11] = 0xFF; + tmp.s6_addr[12] = 0xFE; + memcpy(&tmp.s6_addr[14], &ipaddr->s6_addr[14], 2); + /* context information are always used */ + ipv6_addr_prefix_copy(&tmp, &ctx->pfx, ctx->plen); + if (ipv6_addr_equal(&tmp, ipaddr)) { + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[14], 2); + dam = LOWPAN_IPHC_DAM_10; + goto out; + } + + memset(&tmp, 0, sizeof(tmp)); + /* check for SAM/DAM = 01, should always match */ + memcpy(&tmp.s6_addr[8], &ipaddr->s6_addr[8], 8); + /* context information are always used */ + ipv6_addr_prefix_copy(&tmp, &ctx->pfx, ctx->plen); + if (ipv6_addr_equal(&tmp, ipaddr)) { + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[8], 8); + dam = LOWPAN_IPHC_DAM_01; + goto out; + } + + WARN_ONCE(1, "context found but no address mode matched\n"); + return LOWPAN_IPHC_DAM_00; +out: + + if (sam) + return lowpan_iphc_dam_to_sam_value[dam]; + else + return dam; +} + +static inline bool +lowpan_iphc_compress_802154_lladdr(const struct in6_addr *ipaddr, + const void *lladdr) +{ + const struct ieee802154_addr *addr = lladdr; + unsigned char extended_addr[EUI64_ADDR_LEN]; + bool lladdr_compress = false; + struct in6_addr tmp = {}; + + switch (addr->mode) { + case IEEE802154_ADDR_LONG: + ieee802154_le64_to_be64(&extended_addr, &addr->extended_addr); + if (is_addr_mac_addr_based(ipaddr, extended_addr)) + lladdr_compress = true; + break; + case IEEE802154_ADDR_SHORT: + /* fe:80::ff:fe00:XXXX + * \__/ + * short_addr + * + * 
Universe/Local bit is zero. + */ + tmp.s6_addr[0] = 0xFE; + tmp.s6_addr[1] = 0x80; + tmp.s6_addr[11] = 0xFF; + tmp.s6_addr[12] = 0xFE; + ieee802154_le16_to_be16(&tmp.s6_addr16[7], + &addr->short_addr); + if (ipv6_addr_equal(&tmp, ipaddr)) + lladdr_compress = true; + break; + default: + /* should never handled and filtered by 802154 6lowpan */ + WARN_ON_ONCE(1); + break; + } + + return lladdr_compress; +} + +static u8 lowpan_compress_addr_64(u8 **hc_ptr, const struct net_device *dev, + const struct in6_addr *ipaddr, + const unsigned char *lladdr, bool sam) +{ + u8 dam = LOWPAN_IPHC_DAM_01; + + switch (lowpan_dev(dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + if (lowpan_iphc_compress_802154_lladdr(ipaddr, lladdr)) { + dam = LOWPAN_IPHC_DAM_11; /* 0-bits */ + pr_debug("address compression 0 bits\n"); + goto out; + } + break; + default: + if (lowpan_iphc_addr_equal(dev, NULL, ipaddr, lladdr)) { + dam = LOWPAN_IPHC_DAM_11; + pr_debug("address compression 0 bits\n"); + goto out; + } + + break; + } + + if (lowpan_is_iid_16_bit_compressable(ipaddr)) { + /* compress IID to 16 bits xxxx::XXXX */ + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr16[7], 2); + dam = LOWPAN_IPHC_DAM_10; /* 16-bits */ + raw_dump_inline(NULL, "Compressed ipv6 addr is (16 bits)", + *hc_ptr - 2, 2); + goto out; + } + + /* do not compress IID => xxxx::IID */ + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr16[4], 8); + raw_dump_inline(NULL, "Compressed ipv6 addr is (64 bits)", + *hc_ptr - 8, 8); + +out: + + if (sam) + return lowpan_iphc_dam_to_sam_value[dam]; + else + return dam; +} + +/* lowpan_iphc_get_tc - get the ECN + DCSP fields in hc format */ +static inline u8 lowpan_iphc_get_tc(const struct ipv6hdr *hdr) +{ + u8 dscp, ecn; + + /* hdr->priority contains the higher bits of dscp, lower are part of + * flow_lbl[0]. Note ECN, DCSP is swapped in ipv6 hdr. + */ + dscp = (hdr->priority << 2) | ((hdr->flow_lbl[0] & 0xc0) >> 6); + /* ECN is at the two lower bits from first nibble of flow_lbl[0] */ + ecn = (hdr->flow_lbl[0] & 0x30); + /* for pretty debug output, also shift ecn to get the ecn value */ + pr_debug("ecn 0x%02x dscp 0x%02x\n", ecn >> 4, dscp); + /* ECN is at 0x30 now, shift it to have ECN + DCSP */ + return (ecn << 2) | dscp; +} + +/* lowpan_iphc_is_flow_lbl_zero - check if flow label is zero */ +static inline bool lowpan_iphc_is_flow_lbl_zero(const struct ipv6hdr *hdr) +{ + return ((!(hdr->flow_lbl[0] & 0x0f)) && + !hdr->flow_lbl[1] && !hdr->flow_lbl[2]); +} + +/* lowpan_iphc_tf_compress - compress the traffic class which is set by + * ipv6hdr. Return the corresponding format identifier which is used. + */ +static u8 lowpan_iphc_tf_compress(u8 **hc_ptr, const struct ipv6hdr *hdr) +{ + /* get ecn dscp data in a byteformat as: ECN(hi) + DSCP(lo) */ + u8 tc = lowpan_iphc_get_tc(hdr), tf[4], val; + + /* printout the traffic class in hc format */ + pr_debug("tc 0x%02x\n", tc); + + if (lowpan_iphc_is_flow_lbl_zero(hdr)) { + if (!tc) { + /* 11: Traffic Class and Flow Label are elided. */ + val = LOWPAN_IPHC_TF_11; + } else { + /* 10: ECN + DSCP (1 byte), Flow Label is elided. 
+ * + * 0 1 2 3 4 5 6 7 + * +-+-+-+-+-+-+-+-+ + * |ECN| DSCP | + * +-+-+-+-+-+-+-+-+ + */ + lowpan_push_hc_data(hc_ptr, &tc, sizeof(tc)); + val = LOWPAN_IPHC_TF_10; + } + } else { + /* check if dscp is zero, it's after the first two bit */ + if (!(tc & 0x3f)) { + /* 01: ECN + 2-bit Pad + Flow Label (3 bytes), DSCP is elided + * + * 1 2 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |ECN|rsv| Flow Label | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + memcpy(&tf[0], &hdr->flow_lbl[0], 3); + /* zero the highest 4-bits, contains DCSP + ECN */ + tf[0] &= ~0xf0; + /* set ECN */ + tf[0] |= (tc & 0xc0); + + lowpan_push_hc_data(hc_ptr, tf, 3); + val = LOWPAN_IPHC_TF_01; + } else { + /* 00: ECN + DSCP + 4-bit Pad + Flow Label (4 bytes) + * + * 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |ECN| DSCP | rsv | Flow Label | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ + memcpy(&tf[0], &tc, sizeof(tc)); + /* highest nibble of flow_lbl[0] is part of DSCP + ECN + * which will be the 4-bit pad and will be filled with + * zeros afterwards. + */ + memcpy(&tf[1], &hdr->flow_lbl[0], 3); + /* zero the 4-bit pad, which is reserved */ + tf[1] &= ~0xf0; + + lowpan_push_hc_data(hc_ptr, tf, 4); + val = LOWPAN_IPHC_TF_00; + } + } + + return val; +} + +static u8 lowpan_iphc_mcast_ctx_addr_compress(u8 **hc_ptr, + const struct lowpan_iphc_ctx *ctx, + const struct in6_addr *ipaddr) +{ + u8 data[6]; + + /* flags/scope, reserved (RIID) */ + memcpy(data, &ipaddr->s6_addr[1], 2); + /* group ID */ + memcpy(&data[1], &ipaddr->s6_addr[11], 4); + lowpan_push_hc_data(hc_ptr, data, 6); + + return LOWPAN_IPHC_DAM_00; +} + +static u8 lowpan_iphc_mcast_addr_compress(u8 **hc_ptr, + const struct in6_addr *ipaddr) +{ + u8 val; + + if (lowpan_is_mcast_addr_compressable8(ipaddr)) { + pr_debug("compressed to 1 octet\n"); + /* use last byte */ + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[15], 1); + val = LOWPAN_IPHC_DAM_11; + } else if (lowpan_is_mcast_addr_compressable32(ipaddr)) { + pr_debug("compressed to 4 octets\n"); + /* second byte + the last three */ + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[1], 1); + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[13], 3); + val = LOWPAN_IPHC_DAM_10; + } else if (lowpan_is_mcast_addr_compressable48(ipaddr)) { + pr_debug("compressed to 6 octets\n"); + /* second byte + the last five */ + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[1], 1); + lowpan_push_hc_data(hc_ptr, &ipaddr->s6_addr[11], 5); + val = LOWPAN_IPHC_DAM_01; + } else { + pr_debug("using full address\n"); + lowpan_push_hc_data(hc_ptr, ipaddr->s6_addr, 16); + val = LOWPAN_IPHC_DAM_00; + } + + return val; +} + +int lowpan_header_compress(struct sk_buff *skb, const struct net_device *dev, + const void *daddr, const void *saddr) +{ + u8 iphc0, iphc1, *hc_ptr, cid = 0; + struct ipv6hdr *hdr; + u8 head[LOWPAN_IPHC_MAX_HC_BUF_LEN] = {}; + struct lowpan_iphc_ctx *dci, *sci, dci_entry, sci_entry; + int ret, ipv6_daddr_type, ipv6_saddr_type; + + if (skb->protocol != htons(ETH_P_IPV6)) + return -EINVAL; + + hdr = ipv6_hdr(skb); + hc_ptr = head + 2; + + pr_debug("IPv6 header dump:\n\tversion = %d\n\tlength = %d\n" + "\tnexthdr = 0x%02x\n\thop_lim = %d\n\tdest = %pI6c\n", + hdr->version, ntohs(hdr->payload_len), hdr->nexthdr, + hdr->hop_limit, &hdr->daddr); + + raw_dump_table(__func__, "raw skb network header dump", + 
skb_network_header(skb), sizeof(struct ipv6hdr)); + + /* As we copy some bit-length fields, in the IPHC encoding bytes, + * we sometimes use |= + * If the field is 0, and the current bit value in memory is 1, + * this does not work. We therefore reset the IPHC encoding here + */ + iphc0 = LOWPAN_DISPATCH_IPHC; + iphc1 = 0; + + raw_dump_table(__func__, "sending raw skb network uncompressed packet", + skb->data, skb->len); + + ipv6_daddr_type = ipv6_addr_type(&hdr->daddr); + spin_lock_bh(&lowpan_dev(dev)->ctx.lock); + if (ipv6_daddr_type & IPV6_ADDR_MULTICAST) + dci = lowpan_iphc_ctx_get_by_mcast_addr(dev, &hdr->daddr); + else + dci = lowpan_iphc_ctx_get_by_addr(dev, &hdr->daddr); + if (dci) { + memcpy(&dci_entry, dci, sizeof(*dci)); + cid |= dci->id; + } + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + + spin_lock_bh(&lowpan_dev(dev)->ctx.lock); + sci = lowpan_iphc_ctx_get_by_addr(dev, &hdr->saddr); + if (sci) { + memcpy(&sci_entry, sci, sizeof(*sci)); + cid |= (sci->id << 4); + } + spin_unlock_bh(&lowpan_dev(dev)->ctx.lock); + + /* if cid is zero it will be compressed */ + if (cid) { + iphc1 |= LOWPAN_IPHC_CID; + lowpan_push_hc_data(&hc_ptr, &cid, sizeof(cid)); + } + + /* Traffic Class, Flow Label compression */ + iphc0 |= lowpan_iphc_tf_compress(&hc_ptr, hdr); + + /* NOTE: payload length is always compressed */ + + /* Check if we provide the nhc format for nexthdr and compression + * functionality. If not nexthdr is handled inline and not compressed. + */ + ret = lowpan_nhc_check_compression(skb, hdr, &hc_ptr); + if (ret == -ENOENT) + lowpan_push_hc_data(&hc_ptr, &hdr->nexthdr, + sizeof(hdr->nexthdr)); + else + iphc0 |= LOWPAN_IPHC_NH; + + /* Hop limit + * if 1: compress, encoding is 01 + * if 64: compress, encoding is 10 + * if 255: compress, encoding is 11 + * else do not compress + */ + switch (hdr->hop_limit) { + case 1: + iphc0 |= LOWPAN_IPHC_HLIM_01; + break; + case 64: + iphc0 |= LOWPAN_IPHC_HLIM_10; + break; + case 255: + iphc0 |= LOWPAN_IPHC_HLIM_11; + break; + default: + lowpan_push_hc_data(&hc_ptr, &hdr->hop_limit, + sizeof(hdr->hop_limit)); + } + + ipv6_saddr_type = ipv6_addr_type(&hdr->saddr); + /* source address compression */ + if (ipv6_saddr_type == IPV6_ADDR_ANY) { + pr_debug("source address is unspecified, setting SAC\n"); + iphc1 |= LOWPAN_IPHC_SAC; + } else { + if (sci) { + iphc1 |= lowpan_compress_ctx_addr(&hc_ptr, dev, + &hdr->saddr, + &sci_entry, saddr, + true); + iphc1 |= LOWPAN_IPHC_SAC; + } else { + if (ipv6_saddr_type & IPV6_ADDR_LINKLOCAL && + lowpan_is_linklocal_zero_padded(hdr->saddr)) { + iphc1 |= lowpan_compress_addr_64(&hc_ptr, dev, + &hdr->saddr, + saddr, true); + pr_debug("source address unicast link-local %pI6c iphc1 0x%02x\n", + &hdr->saddr, iphc1); + } else { + pr_debug("send the full source address\n"); + lowpan_push_hc_data(&hc_ptr, + hdr->saddr.s6_addr, 16); + } + } + } + + /* destination address compression */ + if (ipv6_daddr_type & IPV6_ADDR_MULTICAST) { + pr_debug("destination address is multicast: "); + iphc1 |= LOWPAN_IPHC_M; + if (dci) { + iphc1 |= lowpan_iphc_mcast_ctx_addr_compress(&hc_ptr, + &dci_entry, + &hdr->daddr); + iphc1 |= LOWPAN_IPHC_DAC; + } else { + iphc1 |= lowpan_iphc_mcast_addr_compress(&hc_ptr, + &hdr->daddr); + } + } else { + if (dci) { + iphc1 |= lowpan_compress_ctx_addr(&hc_ptr, dev, + &hdr->daddr, + &dci_entry, daddr, + false); + iphc1 |= LOWPAN_IPHC_DAC; + } else { + if (ipv6_daddr_type & IPV6_ADDR_LINKLOCAL && + lowpan_is_linklocal_zero_padded(hdr->daddr)) { + iphc1 |= lowpan_compress_addr_64(&hc_ptr, dev, + 
&hdr->daddr, + daddr, false); + pr_debug("dest address unicast link-local %pI6c iphc1 0x%02x\n", + &hdr->daddr, iphc1); + } else { + pr_debug("dest address unicast %pI6c\n", + &hdr->daddr); + lowpan_push_hc_data(&hc_ptr, + hdr->daddr.s6_addr, 16); + } + } + } + + /* next header compression */ + if (iphc0 & LOWPAN_IPHC_NH) { + ret = lowpan_nhc_do_compression(skb, hdr, &hc_ptr); + if (ret < 0) + return ret; + } + + head[0] = iphc0; + head[1] = iphc1; + + skb_pull(skb, sizeof(struct ipv6hdr)); + skb_reset_transport_header(skb); + memcpy(skb_push(skb, hc_ptr - head), head, hc_ptr - head); + skb_reset_network_header(skb); + + pr_debug("header len %d skb %u\n", (int)(hc_ptr - head), skb->len); + + raw_dump_table(__func__, "raw skb data dump compressed", + skb->data, skb->len); + return 0; +} +EXPORT_SYMBOL_GPL(lowpan_header_compress); diff --git a/net/6lowpan/ndisc.c b/net/6lowpan/ndisc.c new file mode 100644 index 000000000..16be8f8b2 --- /dev/null +++ b/net/6lowpan/ndisc.c @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * (C) 2016 Pengutronix, Alexander Aring + */ + +#include +#include +#include + +#include "6lowpan_i.h" + +static int lowpan_ndisc_is_useropt(u8 nd_opt_type) +{ + return nd_opt_type == ND_OPT_6CO; +} + +#if IS_ENABLED(CONFIG_IEEE802154_6LOWPAN) +#define NDISC_802154_SHORT_ADDR_LENGTH 1 +static int lowpan_ndisc_parse_802154_options(const struct net_device *dev, + struct nd_opt_hdr *nd_opt, + struct ndisc_options *ndopts) +{ + switch (nd_opt->nd_opt_len) { + case NDISC_802154_SHORT_ADDR_LENGTH: + if (ndopts->nd_802154_opt_array[nd_opt->nd_opt_type]) + ND_PRINTK(2, warn, + "%s: duplicated short addr ND6 option found: type=%d\n", + __func__, nd_opt->nd_opt_type); + else + ndopts->nd_802154_opt_array[nd_opt->nd_opt_type] = nd_opt; + return 1; + default: + /* all others will be handled by ndisc IPv6 option parsing */ + return 0; + } +} + +static int lowpan_ndisc_parse_options(const struct net_device *dev, + struct nd_opt_hdr *nd_opt, + struct ndisc_options *ndopts) +{ + if (!lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154)) + return 0; + + switch (nd_opt->nd_opt_type) { + case ND_OPT_SOURCE_LL_ADDR: + case ND_OPT_TARGET_LL_ADDR: + return lowpan_ndisc_parse_802154_options(dev, nd_opt, ndopts); + default: + return 0; + } +} + +static void lowpan_ndisc_802154_update(struct neighbour *n, u32 flags, + u8 icmp6_type, + const struct ndisc_options *ndopts) +{ + struct lowpan_802154_neigh *neigh = lowpan_802154_neigh(neighbour_priv(n)); + u8 *lladdr_short = NULL; + + switch (icmp6_type) { + case NDISC_ROUTER_SOLICITATION: + case NDISC_ROUTER_ADVERTISEMENT: + case NDISC_NEIGHBOUR_SOLICITATION: + if (ndopts->nd_802154_opts_src_lladdr) { + lladdr_short = __ndisc_opt_addr_data(ndopts->nd_802154_opts_src_lladdr, + IEEE802154_SHORT_ADDR_LEN, 0); + if (!lladdr_short) { + ND_PRINTK(2, warn, + "NA: invalid short link-layer address length\n"); + return; + } + } + break; + case NDISC_REDIRECT: + case NDISC_NEIGHBOUR_ADVERTISEMENT: + if (ndopts->nd_802154_opts_tgt_lladdr) { + lladdr_short = __ndisc_opt_addr_data(ndopts->nd_802154_opts_tgt_lladdr, + IEEE802154_SHORT_ADDR_LEN, 0); + if (!lladdr_short) { + ND_PRINTK(2, warn, + "NA: invalid short link-layer address length\n"); + return; + } + } + break; + default: + break; + } + + write_lock_bh(&n->lock); + if (lladdr_short) { + ieee802154_be16_to_le16(&neigh->short_addr, lladdr_short); + if (!lowpan_802154_is_valid_src_short_addr(neigh->short_addr)) + neigh->short_addr = cpu_to_le16(IEEE802154_ADDR_SHORT_UNSPEC); + } + 
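+	/* If a short address option was present, neigh->short_addr now holds
+	 * either that address or IEEE802154_ADDR_SHORT_UNSPEC when the value
+	 * lies outside the valid unicast range.
+	 */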
write_unlock_bh(&n->lock); +} + +static void lowpan_ndisc_update(const struct net_device *dev, + struct neighbour *n, u32 flags, u8 icmp6_type, + const struct ndisc_options *ndopts) +{ + if (!lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154)) + return; + + /* react on overrides only. TODO check if this is really right. */ + if (flags & NEIGH_UPDATE_F_OVERRIDE) + lowpan_ndisc_802154_update(n, flags, icmp6_type, ndopts); +} + +static int lowpan_ndisc_opt_addr_space(const struct net_device *dev, + u8 icmp6_type, struct neighbour *neigh, + u8 *ha_buf, u8 **ha) +{ + struct lowpan_802154_neigh *n; + struct wpan_dev *wpan_dev; + int addr_space = 0; + + if (!lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154)) + return 0; + + switch (icmp6_type) { + case NDISC_REDIRECT: + n = lowpan_802154_neigh(neighbour_priv(neigh)); + + read_lock_bh(&neigh->lock); + if (lowpan_802154_is_valid_src_short_addr(n->short_addr)) { + memcpy(ha_buf, &n->short_addr, + IEEE802154_SHORT_ADDR_LEN); + read_unlock_bh(&neigh->lock); + addr_space += __ndisc_opt_addr_space(IEEE802154_SHORT_ADDR_LEN, 0); + *ha = ha_buf; + } else { + read_unlock_bh(&neigh->lock); + } + break; + case NDISC_NEIGHBOUR_ADVERTISEMENT: + case NDISC_NEIGHBOUR_SOLICITATION: + case NDISC_ROUTER_SOLICITATION: + wpan_dev = lowpan_802154_dev(dev)->wdev->ieee802154_ptr; + + if (lowpan_802154_is_valid_src_short_addr(wpan_dev->short_addr)) + addr_space = __ndisc_opt_addr_space(IEEE802154_SHORT_ADDR_LEN, 0); + break; + default: + break; + } + + return addr_space; +} + +static void lowpan_ndisc_fill_addr_option(const struct net_device *dev, + struct sk_buff *skb, u8 icmp6_type, + const u8 *ha) +{ + struct wpan_dev *wpan_dev; + __be16 short_addr; + u8 opt_type; + + if (!lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154)) + return; + + switch (icmp6_type) { + case NDISC_REDIRECT: + if (ha) { + ieee802154_le16_to_be16(&short_addr, ha); + __ndisc_fill_addr_option(skb, ND_OPT_TARGET_LL_ADDR, + &short_addr, + IEEE802154_SHORT_ADDR_LEN, 0); + } + return; + case NDISC_NEIGHBOUR_ADVERTISEMENT: + opt_type = ND_OPT_TARGET_LL_ADDR; + break; + case NDISC_ROUTER_SOLICITATION: + case NDISC_NEIGHBOUR_SOLICITATION: + opt_type = ND_OPT_SOURCE_LL_ADDR; + break; + default: + return; + } + + wpan_dev = lowpan_802154_dev(dev)->wdev->ieee802154_ptr; + + if (lowpan_802154_is_valid_src_short_addr(wpan_dev->short_addr)) { + ieee802154_le16_to_be16(&short_addr, + &wpan_dev->short_addr); + __ndisc_fill_addr_option(skb, opt_type, &short_addr, + IEEE802154_SHORT_ADDR_LEN, 0); + } +} + +static void lowpan_ndisc_prefix_rcv_add_addr(struct net *net, + struct net_device *dev, + const struct prefix_info *pinfo, + struct inet6_dev *in6_dev, + struct in6_addr *addr, + int addr_type, u32 addr_flags, + bool sllao, bool tokenized, + __u32 valid_lft, + u32 prefered_lft, + bool dev_addr_generated) +{ + int err; + + /* generates short based address for RA PIO's */ + if (lowpan_is_ll(dev, LOWPAN_LLTYPE_IEEE802154) && dev_addr_generated && + !addrconf_ifid_802154_6lowpan(addr->s6_addr + 8, dev)) { + err = addrconf_prefix_rcv_add_addr(net, dev, pinfo, in6_dev, + addr, addr_type, addr_flags, + sllao, tokenized, valid_lft, + prefered_lft); + if (err) + ND_PRINTK(2, warn, + "RA: could not add a short address based address for prefix: %pI6c\n", + &pinfo->prefix); + } +} +#endif + +const struct ndisc_ops lowpan_ndisc_ops = { + .is_useropt = lowpan_ndisc_is_useropt, +#if IS_ENABLED(CONFIG_IEEE802154_6LOWPAN) + .parse_options = lowpan_ndisc_parse_options, + .update = lowpan_ndisc_update, + .opt_addr_space = lowpan_ndisc_opt_addr_space, 
+ .fill_addr_option = lowpan_ndisc_fill_addr_option, + .prefix_rcv_add_addr = lowpan_ndisc_prefix_rcv_add_addr, +#endif +}; diff --git a/net/6lowpan/nhc.c b/net/6lowpan/nhc.c new file mode 100644 index 000000000..7b3745953 --- /dev/null +++ b/net/6lowpan/nhc.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN next header compression + * + * Authors: + * Alexander Aring + */ + +#include + +#include + +#include "nhc.h" + +static const struct lowpan_nhc *lowpan_nexthdr_nhcs[NEXTHDR_MAX + 1]; +static DEFINE_SPINLOCK(lowpan_nhc_lock); + +static const struct lowpan_nhc *lowpan_nhc_by_nhcid(struct sk_buff *skb) +{ + const struct lowpan_nhc *nhc; + int i; + u8 id; + + if (!pskb_may_pull(skb, 1)) + return NULL; + + id = *skb->data; + + for (i = 0; i < NEXTHDR_MAX + 1; i++) { + nhc = lowpan_nexthdr_nhcs[i]; + if (!nhc) + continue; + + if ((id & nhc->idmask) == nhc->id) + return nhc; + } + + return NULL; +} + +int lowpan_nhc_check_compression(struct sk_buff *skb, + const struct ipv6hdr *hdr, u8 **hc_ptr) +{ + const struct lowpan_nhc *nhc; + int ret = 0; + + spin_lock_bh(&lowpan_nhc_lock); + + nhc = lowpan_nexthdr_nhcs[hdr->nexthdr]; + if (!(nhc && nhc->compress)) + ret = -ENOENT; + + spin_unlock_bh(&lowpan_nhc_lock); + + return ret; +} + +int lowpan_nhc_do_compression(struct sk_buff *skb, const struct ipv6hdr *hdr, + u8 **hc_ptr) +{ + int ret; + const struct lowpan_nhc *nhc; + + spin_lock_bh(&lowpan_nhc_lock); + + nhc = lowpan_nexthdr_nhcs[hdr->nexthdr]; + /* check if the nhc module was removed in unlocked part. + * TODO: this is a workaround we should prevent unloading + * of nhc modules while unlocked part, this will always drop + * the lowpan packet but it's very unlikely. + * + * Solution isn't easy because we need to decide at + * lowpan_nhc_check_compression if we do a compression or not. + * Because the inline data which is added to skb, we can't move this + * handling. 
+ */ + if (unlikely(!nhc || !nhc->compress)) { + ret = -EINVAL; + goto out; + } + + /* In the case of RAW sockets the transport header is not set by + * the ip6 stack so we must set it ourselves + */ + if (skb->transport_header == skb->network_header) + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + ret = nhc->compress(skb, hc_ptr); + if (ret < 0) + goto out; + + /* skip the transport header */ + skb_pull(skb, nhc->nexthdrlen); + +out: + spin_unlock_bh(&lowpan_nhc_lock); + + return ret; +} + +int lowpan_nhc_do_uncompression(struct sk_buff *skb, + const struct net_device *dev, + struct ipv6hdr *hdr) +{ + const struct lowpan_nhc *nhc; + int ret; + + spin_lock_bh(&lowpan_nhc_lock); + + nhc = lowpan_nhc_by_nhcid(skb); + if (nhc) { + if (nhc->uncompress) { + ret = nhc->uncompress(skb, sizeof(struct ipv6hdr) + + nhc->nexthdrlen); + if (ret < 0) { + spin_unlock_bh(&lowpan_nhc_lock); + return ret; + } + } else { + spin_unlock_bh(&lowpan_nhc_lock); + netdev_warn(dev, "received nhc id for %s which is not implemented.\n", + nhc->name); + return -ENOTSUPP; + } + } else { + spin_unlock_bh(&lowpan_nhc_lock); + netdev_warn(dev, "received unknown nhc id which was not found.\n"); + return -ENOENT; + } + + hdr->nexthdr = nhc->nexthdr; + skb_reset_transport_header(skb); + raw_dump_table(__func__, "raw transport header dump", + skb_transport_header(skb), nhc->nexthdrlen); + + spin_unlock_bh(&lowpan_nhc_lock); + + return 0; +} + +int lowpan_nhc_add(const struct lowpan_nhc *nhc) +{ + int ret = 0; + + spin_lock_bh(&lowpan_nhc_lock); + + if (lowpan_nexthdr_nhcs[nhc->nexthdr]) { + ret = -EEXIST; + goto out; + } + + lowpan_nexthdr_nhcs[nhc->nexthdr] = nhc; +out: + spin_unlock_bh(&lowpan_nhc_lock); + return ret; +} +EXPORT_SYMBOL(lowpan_nhc_add); + +void lowpan_nhc_del(const struct lowpan_nhc *nhc) +{ + spin_lock_bh(&lowpan_nhc_lock); + + lowpan_nexthdr_nhcs[nhc->nexthdr] = NULL; + + spin_unlock_bh(&lowpan_nhc_lock); + + synchronize_net(); +} +EXPORT_SYMBOL(lowpan_nhc_del); diff --git a/net/6lowpan/nhc.h b/net/6lowpan/nhc.h new file mode 100644 index 000000000..ab7b4977c --- /dev/null +++ b/net/6lowpan/nhc.h @@ -0,0 +1,133 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __6LOWPAN_NHC_H +#define __6LOWPAN_NHC_H + +#include +#include +#include + +#include +#include + +/** + * LOWPAN_NHC - helper macro to generate nh id fields and lowpan_nhc struct + * + * @__nhc: variable name of the lowpan_nhc struct. + * @_name: const char * of common header compression name. + * @_nexthdr: ipv6 nexthdr field for the header compression. + * @_nexthdrlen: ipv6 nexthdr len for the reserved space. + * @_id: one byte nhc id value. + * @_idmask: one byte nhc id mask value. + * @_uncompress: callback for uncompression call. + * @_compress: callback for compression call. 
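+ *
+ * A header type that only needs its NHC dispatch value to be recognized,
+ * with no payload compression callbacks, passes NULL for both callbacks,
+ * e.g. as done in nhc_dest.c:
+ *
+ *	LOWPAN_NHC(nhc_dest, "RFC6282 Destination Options", NEXTHDR_DEST, 0,
+ *		   LOWPAN_NHC_DEST_ID_0, LOWPAN_NHC_DEST_MASK_0, NULL, NULL);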
+ */ +#define LOWPAN_NHC(__nhc, _name, _nexthdr, \ + _hdrlen, _id, _idmask, \ + _uncompress, _compress) \ +static const struct lowpan_nhc __nhc = { \ + .name = _name, \ + .nexthdr = _nexthdr, \ + .nexthdrlen = _hdrlen, \ + .id = _id, \ + .idmask = _idmask, \ + .uncompress = _uncompress, \ + .compress = _compress, \ +} + +#define module_lowpan_nhc(__nhc) \ +static int __init __nhc##_init(void) \ +{ \ + return lowpan_nhc_add(&(__nhc)); \ +} \ +module_init(__nhc##_init); \ +static void __exit __nhc##_exit(void) \ +{ \ + lowpan_nhc_del(&(__nhc)); \ +} \ +module_exit(__nhc##_exit); + +/** + * struct lowpan_nhc - hold 6lowpan next hdr compression ifnformation + * + * @name: name of the specific next header compression + * @nexthdr: next header value of the protocol which should be compressed. + * @nexthdrlen: ipv6 nexthdr len for the reserved space. + * @id: one byte nhc id value. + * @idmask: one byte nhc id mask value. + * @compress: callback to do the header compression. + * @uncompress: callback to do the header uncompression. + */ +struct lowpan_nhc { + const char *name; + u8 nexthdr; + size_t nexthdrlen; + u8 id; + u8 idmask; + + int (*uncompress)(struct sk_buff *skb, size_t needed); + int (*compress)(struct sk_buff *skb, u8 **hc_ptr); +}; + +/** + * lowpan_nhc_by_nexthdr - return the 6lowpan nhc by ipv6 nexthdr. + * + * @nexthdr: ipv6 nexthdr value. + */ +struct lowpan_nhc *lowpan_nhc_by_nexthdr(u8 nexthdr); + +/** + * lowpan_nhc_check_compression - checks if we support compression format. If + * we support the nhc by nexthdr field, the function will return 0. If we + * don't support the nhc by nexthdr this function will return -ENOENT. + * + * @skb: skb of 6LoWPAN header to read nhc and replace header. + * @hdr: ipv6hdr to check the nexthdr value + * @hc_ptr: pointer for 6LoWPAN header which should increment at the end of + * replaced header. + */ +int lowpan_nhc_check_compression(struct sk_buff *skb, + const struct ipv6hdr *hdr, u8 **hc_ptr); + +/** + * lowpan_nhc_do_compression - calling compress callback for nhc + * + * @skb: skb of 6LoWPAN header to read nhc and replace header. + * @hdr: ipv6hdr to set the nexthdr value + * @hc_ptr: pointer for 6LoWPAN header which should increment at the end of + * replaced header. + */ +int lowpan_nhc_do_compression(struct sk_buff *skb, const struct ipv6hdr *hdr, + u8 **hc_ptr); + +/** + * lowpan_nhc_do_uncompression - calling uncompress callback for nhc + * + * @nhc: 6LoWPAN nhc context, get by lowpan_nhc_by_ functions. + * @skb: skb of 6LoWPAN header, skb->data should be pointed to nhc id value. + * @dev: netdevice for print logging information. + * @hdr: ipv6hdr for setting nexthdr value. + */ +int lowpan_nhc_do_uncompression(struct sk_buff *skb, + const struct net_device *dev, + struct ipv6hdr *hdr); + +/** + * lowpan_nhc_add - register a next header compression to framework + * + * @nhc: nhc which should be add. + */ +int lowpan_nhc_add(const struct lowpan_nhc *nhc); + +/** + * lowpan_nhc_del - delete a next header compression from framework + * + * @nhc: nhc which should be delete. 
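+ *
+ * The table entry is cleared under lowpan_nhc_lock and a synchronize_net()
+ * is issued before this function returns.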
+ */ +void lowpan_nhc_del(const struct lowpan_nhc *nhc); + +/** + * lowpan_nhc_init - adding all default nhcs + */ +void lowpan_nhc_init(void); + +#endif /* __6LOWPAN_NHC_H */ diff --git a/net/6lowpan/nhc_dest.c b/net/6lowpan/nhc_dest.c new file mode 100644 index 000000000..0cbcc7806 --- /dev/null +++ b/net/6lowpan/nhc_dest.c @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Destination Options Header compression according to + * RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_DEST_ID_0 0xe6 +#define LOWPAN_NHC_DEST_MASK_0 0xfe + +LOWPAN_NHC(nhc_dest, "RFC6282 Destination Options", NEXTHDR_DEST, 0, + LOWPAN_NHC_DEST_ID_0, LOWPAN_NHC_DEST_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_dest); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 Destination Options compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_fragment.c b/net/6lowpan/nhc_fragment.c new file mode 100644 index 000000000..9414552df --- /dev/null +++ b/net/6lowpan/nhc_fragment.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Fragment Header compression according to RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_FRAGMENT_ID_0 0xe4 +#define LOWPAN_NHC_FRAGMENT_MASK_0 0xfe + +LOWPAN_NHC(nhc_fragment, "RFC6282 Fragment", NEXTHDR_FRAGMENT, 0, + LOWPAN_NHC_FRAGMENT_ID_0, LOWPAN_NHC_FRAGMENT_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_fragment); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 Fragment compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_ext_dest.c b/net/6lowpan/nhc_ghc_ext_dest.c new file mode 100644 index 000000000..e4745ddd1 --- /dev/null +++ b/net/6lowpan/nhc_ghc_ext_dest.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN Extension Header compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_EXT_DEST_ID_0 0xb6 +#define LOWPAN_GHC_EXT_DEST_MASK_0 0xfe + +LOWPAN_NHC(ghc_ext_dest, "RFC7400 Destination Extension Header", NEXTHDR_DEST, + 0, LOWPAN_GHC_EXT_DEST_ID_0, LOWPAN_GHC_EXT_DEST_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_ext_dest); +MODULE_DESCRIPTION("6LoWPAN generic header destination extension compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_ext_frag.c b/net/6lowpan/nhc_ghc_ext_frag.c new file mode 100644 index 000000000..220e5abfa --- /dev/null +++ b/net/6lowpan/nhc_ghc_ext_frag.c @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN Extension Header compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_EXT_FRAG_ID_0 0xb4 +#define LOWPAN_GHC_EXT_FRAG_MASK_0 0xfe + +LOWPAN_NHC(ghc_ext_frag, "RFC7400 Fragmentation Extension Header", + NEXTHDR_FRAGMENT, 0, LOWPAN_GHC_EXT_FRAG_ID_0, + LOWPAN_GHC_EXT_FRAG_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_ext_frag); +MODULE_DESCRIPTION("6LoWPAN generic header fragmentation extension compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_ext_hop.c b/net/6lowpan/nhc_ghc_ext_hop.c new file mode 100644 index 000000000..9b0de4da7 --- /dev/null +++ b/net/6lowpan/nhc_ghc_ext_hop.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN Extension Header compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_EXT_HOP_ID_0 0xb0 +#define LOWPAN_GHC_EXT_HOP_MASK_0 0xfe + +LOWPAN_NHC(ghc_ext_hop, "RFC7400 Hop-by-Hop Extension Header", NEXTHDR_HOP, 0, + LOWPAN_GHC_EXT_HOP_ID_0, LOWPAN_GHC_EXT_HOP_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_ext_hop); 
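+
+/* The RFC7400 GHC extension header dispatch values differ only in bits 1-2:
+ * 0xb0 (hop-by-hop), 0xb2 (routing), 0xb4 (fragment) and 0xb6 (destination
+ * options), each matched against the 0xfe mask, i.e. with bit 0 ignored.
+ */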
+MODULE_DESCRIPTION("6LoWPAN generic header hop-by-hop extension compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_ext_route.c b/net/6lowpan/nhc_ghc_ext_route.c new file mode 100644 index 000000000..3e86faec5 --- /dev/null +++ b/net/6lowpan/nhc_ghc_ext_route.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN Extension Header compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_EXT_ROUTE_ID_0 0xb2 +#define LOWPAN_GHC_EXT_ROUTE_MASK_0 0xfe + +LOWPAN_NHC(ghc_ext_route, "RFC7400 Routing Extension Header", NEXTHDR_ROUTING, + 0, LOWPAN_GHC_EXT_ROUTE_ID_0, LOWPAN_GHC_EXT_ROUTE_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_ext_route); +MODULE_DESCRIPTION("6LoWPAN generic header routing extension compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_icmpv6.c b/net/6lowpan/nhc_ghc_icmpv6.c new file mode 100644 index 000000000..1634f3eb0 --- /dev/null +++ b/net/6lowpan/nhc_ghc_icmpv6.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN ICMPv6 compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_ICMPV6_ID_0 0xdf +#define LOWPAN_GHC_ICMPV6_MASK_0 0xff + +LOWPAN_NHC(ghc_icmpv6, "RFC7400 ICMPv6", NEXTHDR_ICMP, 0, + LOWPAN_GHC_ICMPV6_ID_0, LOWPAN_GHC_ICMPV6_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_icmpv6); +MODULE_DESCRIPTION("6LoWPAN generic header ICMPv6 compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ghc_udp.c b/net/6lowpan/nhc_ghc_udp.c new file mode 100644 index 000000000..4ac4813b7 --- /dev/null +++ b/net/6lowpan/nhc_ghc_udp.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN UDP compression according to RFC7400 + */ + +#include "nhc.h" + +#define LOWPAN_GHC_UDP_ID_0 0xd0 +#define LOWPAN_GHC_UDP_MASK_0 0xf8 + +LOWPAN_NHC(ghc_udp, "RFC7400 UDP", NEXTHDR_UDP, 0, + LOWPAN_GHC_UDP_ID_0, LOWPAN_GHC_UDP_MASK_0, NULL, NULL); + +module_lowpan_nhc(ghc_udp); +MODULE_DESCRIPTION("6LoWPAN generic header UDP compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_hop.c b/net/6lowpan/nhc_hop.c new file mode 100644 index 000000000..182087dfd --- /dev/null +++ b/net/6lowpan/nhc_hop.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Hop-by-Hop Options Header compression according to RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_HOP_ID_0 0xe0 +#define LOWPAN_NHC_HOP_MASK_0 0xfe + +LOWPAN_NHC(nhc_hop, "RFC6282 Hop-by-Hop Options", NEXTHDR_HOP, 0, + LOWPAN_NHC_HOP_ID_0, LOWPAN_NHC_HOP_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_hop); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 Hop-by-Hop Options compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_ipv6.c b/net/6lowpan/nhc_ipv6.c new file mode 100644 index 000000000..20242360b --- /dev/null +++ b/net/6lowpan/nhc_ipv6.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Header compression according to RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_IPV6_ID_0 0xee +#define LOWPAN_NHC_IPV6_MASK_0 0xfe + +LOWPAN_NHC(nhc_ipv6, "RFC6282 IPv6", NEXTHDR_IPV6, 0, LOWPAN_NHC_IPV6_ID_0, + LOWPAN_NHC_IPV6_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_ipv6); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 IPv6 compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_mobility.c b/net/6lowpan/nhc_mobility.c new file mode 100644 index 000000000..1c31d872c --- /dev/null +++ b/net/6lowpan/nhc_mobility.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: 
GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Mobility Header compression according to RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_MOBILITY_ID_0 0xe8 +#define LOWPAN_NHC_MOBILITY_MASK_0 0xfe + +LOWPAN_NHC(nhc_mobility, "RFC6282 Mobility", NEXTHDR_MOBILITY, 0, + LOWPAN_NHC_MOBILITY_ID_0, LOWPAN_NHC_MOBILITY_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_mobility); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 Mobility compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_routing.c b/net/6lowpan/nhc_routing.c new file mode 100644 index 000000000..dae03ebf7 --- /dev/null +++ b/net/6lowpan/nhc_routing.c @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 Routing Header compression according to RFC6282 + */ + +#include "nhc.h" + +#define LOWPAN_NHC_ROUTING_ID_0 0xe2 +#define LOWPAN_NHC_ROUTING_MASK_0 0xfe + +LOWPAN_NHC(nhc_routing, "RFC6282 Routing", NEXTHDR_ROUTING, 0, + LOWPAN_NHC_ROUTING_ID_0, LOWPAN_NHC_ROUTING_MASK_0, NULL, NULL); + +module_lowpan_nhc(nhc_routing); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 Routing compression"); +MODULE_LICENSE("GPL"); diff --git a/net/6lowpan/nhc_udp.c b/net/6lowpan/nhc_udp.c new file mode 100644 index 000000000..0a506c772 --- /dev/null +++ b/net/6lowpan/nhc_udp.c @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * 6LoWPAN IPv6 UDP compression according to RFC6282 + * + * Authors: + * Alexander Aring + * + * Original written by: + * Alexander Smirnov + * Jon Smirl + */ + +#include "nhc.h" + +#define LOWPAN_NHC_UDP_MASK 0xF8 +#define LOWPAN_NHC_UDP_ID 0xF0 + +#define LOWPAN_NHC_UDP_4BIT_PORT 0xF0B0 +#define LOWPAN_NHC_UDP_4BIT_MASK 0xFFF0 +#define LOWPAN_NHC_UDP_8BIT_PORT 0xF000 +#define LOWPAN_NHC_UDP_8BIT_MASK 0xFF00 + +/* values for port compression, _with checksum_ ie bit 5 set to 0 */ + +/* all inline */ +#define LOWPAN_NHC_UDP_CS_P_00 0xF0 +/* source 16bit inline, dest = 0xF0 + 8 bit inline */ +#define LOWPAN_NHC_UDP_CS_P_01 0xF1 +/* source = 0xF0 + 8bit inline, dest = 16 bit inline */ +#define LOWPAN_NHC_UDP_CS_P_10 0xF2 +/* source & dest = 0xF0B + 4bit inline */ +#define LOWPAN_NHC_UDP_CS_P_11 0xF3 +/* checksum elided */ +#define LOWPAN_NHC_UDP_CS_C 0x04 + +static int udp_uncompress(struct sk_buff *skb, size_t needed) +{ + u8 tmp = 0, val = 0; + struct udphdr uh; + bool fail; + int err; + + fail = lowpan_fetch_skb(skb, &tmp, sizeof(tmp)); + + pr_debug("UDP header uncompression\n"); + switch (tmp & LOWPAN_NHC_UDP_CS_P_11) { + case LOWPAN_NHC_UDP_CS_P_00: + fail |= lowpan_fetch_skb(skb, &uh.source, sizeof(uh.source)); + fail |= lowpan_fetch_skb(skb, &uh.dest, sizeof(uh.dest)); + break; + case LOWPAN_NHC_UDP_CS_P_01: + fail |= lowpan_fetch_skb(skb, &uh.source, sizeof(uh.source)); + fail |= lowpan_fetch_skb(skb, &val, sizeof(val)); + uh.dest = htons(val + LOWPAN_NHC_UDP_8BIT_PORT); + break; + case LOWPAN_NHC_UDP_CS_P_10: + fail |= lowpan_fetch_skb(skb, &val, sizeof(val)); + uh.source = htons(val + LOWPAN_NHC_UDP_8BIT_PORT); + fail |= lowpan_fetch_skb(skb, &uh.dest, sizeof(uh.dest)); + break; + case LOWPAN_NHC_UDP_CS_P_11: + fail |= lowpan_fetch_skb(skb, &val, sizeof(val)); + uh.source = htons(LOWPAN_NHC_UDP_4BIT_PORT + (val >> 4)); + uh.dest = htons(LOWPAN_NHC_UDP_4BIT_PORT + (val & 0x0f)); + break; + default: + BUG(); + } + + pr_debug("uncompressed UDP ports: src = %d, dst = %d\n", + ntohs(uh.source), ntohs(uh.dest)); + + /* checksum */ + if (tmp & LOWPAN_NHC_UDP_CS_C) { + pr_debug_ratelimited("checksum elided currently not supported\n"); + fail = true; + } else { + fail |= 
lowpan_fetch_skb(skb, &uh.check, sizeof(uh.check)); + } + + if (fail) + return -EINVAL; + + /* UDP length needs to be inferred from the lower layers + * here, we obtain the hint from the remaining size of the + * frame + */ + switch (lowpan_dev(skb->dev)->lltype) { + case LOWPAN_LLTYPE_IEEE802154: + if (lowpan_802154_cb(skb)->d_size) + uh.len = htons(lowpan_802154_cb(skb)->d_size - + sizeof(struct ipv6hdr)); + else + uh.len = htons(skb->len + sizeof(struct udphdr)); + break; + default: + uh.len = htons(skb->len + sizeof(struct udphdr)); + break; + } + pr_debug("uncompressed UDP length: src = %d", ntohs(uh.len)); + + /* replace the compressed UDP head by the uncompressed UDP + * header + */ + err = skb_cow(skb, needed); + if (unlikely(err)) + return err; + + skb_push(skb, sizeof(struct udphdr)); + skb_copy_to_linear_data(skb, &uh, sizeof(struct udphdr)); + + return 0; +} + +static int udp_compress(struct sk_buff *skb, u8 **hc_ptr) +{ + const struct udphdr *uh = udp_hdr(skb); + u8 tmp; + + if (((ntohs(uh->source) & LOWPAN_NHC_UDP_4BIT_MASK) == + LOWPAN_NHC_UDP_4BIT_PORT) && + ((ntohs(uh->dest) & LOWPAN_NHC_UDP_4BIT_MASK) == + LOWPAN_NHC_UDP_4BIT_PORT)) { + pr_debug("UDP header: both ports compression to 4 bits\n"); + /* compression value */ + tmp = LOWPAN_NHC_UDP_CS_P_11; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + /* source and destination port */ + tmp = ntohs(uh->dest) - LOWPAN_NHC_UDP_4BIT_PORT + + ((ntohs(uh->source) - LOWPAN_NHC_UDP_4BIT_PORT) << 4); + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + } else if ((ntohs(uh->dest) & LOWPAN_NHC_UDP_8BIT_MASK) == + LOWPAN_NHC_UDP_8BIT_PORT) { + pr_debug("UDP header: remove 8 bits of dest\n"); + /* compression value */ + tmp = LOWPAN_NHC_UDP_CS_P_01; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + /* source port */ + lowpan_push_hc_data(hc_ptr, &uh->source, sizeof(uh->source)); + /* destination port */ + tmp = ntohs(uh->dest) - LOWPAN_NHC_UDP_8BIT_PORT; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + } else if ((ntohs(uh->source) & LOWPAN_NHC_UDP_8BIT_MASK) == + LOWPAN_NHC_UDP_8BIT_PORT) { + pr_debug("UDP header: remove 8 bits of source\n"); + /* compression value */ + tmp = LOWPAN_NHC_UDP_CS_P_10; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + /* source port */ + tmp = ntohs(uh->source) - LOWPAN_NHC_UDP_8BIT_PORT; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + /* destination port */ + lowpan_push_hc_data(hc_ptr, &uh->dest, sizeof(uh->dest)); + } else { + pr_debug("UDP header: can't compress\n"); + /* compression value */ + tmp = LOWPAN_NHC_UDP_CS_P_00; + lowpan_push_hc_data(hc_ptr, &tmp, sizeof(tmp)); + /* source port */ + lowpan_push_hc_data(hc_ptr, &uh->source, sizeof(uh->source)); + /* destination port */ + lowpan_push_hc_data(hc_ptr, &uh->dest, sizeof(uh->dest)); + } + + /* checksum is always inline */ + lowpan_push_hc_data(hc_ptr, &uh->check, sizeof(uh->check)); + + return 0; +} + +LOWPAN_NHC(nhc_udp, "RFC6282 UDP", NEXTHDR_UDP, sizeof(struct udphdr), + LOWPAN_NHC_UDP_ID, LOWPAN_NHC_UDP_MASK, udp_uncompress, udp_compress); + +module_lowpan_nhc(nhc_udp); +MODULE_DESCRIPTION("6LoWPAN next header RFC6282 UDP compression"); +MODULE_LICENSE("GPL"); diff --git a/net/802/Kconfig b/net/802/Kconfig new file mode 100644 index 000000000..aaa83e888 --- /dev/null +++ b/net/802/Kconfig @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0-only +config STP + tristate + select LLC + +config GARP + tristate + select STP + +config MRP + tristate diff --git a/net/802/Makefile b/net/802/Makefile new file mode 100644 index 
000000000..bfed80221 --- /dev/null +++ b/net/802/Makefile @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux 802.x protocol layers. +# + +# Check the p8022 selections against net/core/Makefile. +obj-$(CONFIG_LLC) += p8022.o psnap.o +obj-$(CONFIG_NET_FC) += fc.o +obj-$(CONFIG_FDDI) += fddi.o +obj-$(CONFIG_HIPPI) += hippi.o +obj-$(CONFIG_ATALK) += p8022.o psnap.o +obj-$(CONFIG_STP) += stp.o +obj-$(CONFIG_GARP) += garp.o +obj-$(CONFIG_MRP) += mrp.o diff --git a/net/802/fc.c b/net/802/fc.c new file mode 100644 index 000000000..afd3d288a --- /dev/null +++ b/net/802/fc.c @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3: Fibre Channel device handling subroutines + * + * Vineet Abraham + * v 1.0 03/22/99 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Put the headers on a Fibre Channel packet. + */ + +static int fc_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + struct fch_hdr *fch; + int hdr_len; + + /* + * Add the 802.2 SNAP header if IP as the IPv4 code calls + * dev->hard_header directly. + */ + if (type == ETH_P_IP || type == ETH_P_ARP) + { + struct fcllc *fcllc; + + hdr_len = sizeof(struct fch_hdr) + sizeof(struct fcllc); + fch = skb_push(skb, hdr_len); + fcllc = (struct fcllc *)(fch+1); + fcllc->dsap = fcllc->ssap = EXTENDED_SAP; + fcllc->llc = UI_CMD; + fcllc->protid[0] = fcllc->protid[1] = fcllc->protid[2] = 0x00; + fcllc->ethertype = htons(type); + } + else + { + hdr_len = sizeof(struct fch_hdr); + fch = skb_push(skb, hdr_len); + } + + if(saddr) + memcpy(fch->saddr,saddr,dev->addr_len); + else + memcpy(fch->saddr,dev->dev_addr,dev->addr_len); + + if(daddr) + { + memcpy(fch->daddr,daddr,dev->addr_len); + return hdr_len; + } + return -hdr_len; +} + +static const struct header_ops fc_header_ops = { + .create = fc_header, +}; + +static void fc_setup(struct net_device *dev) +{ + dev->header_ops = &fc_header_ops; + dev->type = ARPHRD_IEEE802; + dev->hard_header_len = FC_HLEN; + dev->mtu = 2024; + dev->addr_len = FC_ALEN; + dev->tx_queue_len = 100; /* Long queues on fc */ + dev->flags = IFF_BROADCAST; + + memset(dev->broadcast, 0xFF, FC_ALEN); +} + +/** + * alloc_fcdev - Register fibre channel device + * @sizeof_priv: Size of additional driver-private structure to be allocated + * for this fibre channel device + * + * Fill in the fields of the device structure with fibre channel-generic values. + * + * Constructs a new net device, complete with a private data area of + * size @sizeof_priv. A 32-byte (not bit) alignment is enforced for + * this private data area. + */ +struct net_device *alloc_fcdev(int sizeof_priv) +{ + return alloc_netdev(sizeof_priv, "fc%d", NET_NAME_UNKNOWN, fc_setup); +} +EXPORT_SYMBOL(alloc_fcdev); diff --git a/net/802/fddi.c b/net/802/fddi.c new file mode 100644 index 000000000..7533ce26b --- /dev/null +++ b/net/802/fddi.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * FDDI-type device handling. + * + * Version: @(#)fddi.c 1.0.0 08/12/96 + * + * Authors: Lawrence V. 
Stefani, + * + * fddi.c is based on previous eth.c and tr.c work by + * Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Florian La Roche, + * Alan Cox, + * + * Changes + * Alan Cox : New arp/rebuild header + * Maciej W. Rozycki : IPv6 support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Create the FDDI MAC header for an arbitrary protocol layer + * + * saddr=NULL means use device source address + * daddr=NULL means leave destination address (eg unresolved arp) + */ + +static int fddi_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + int hl = FDDI_K_SNAP_HLEN; + struct fddihdr *fddi; + + if(type != ETH_P_IP && type != ETH_P_IPV6 && type != ETH_P_ARP) + hl=FDDI_K_8022_HLEN-3; + fddi = skb_push(skb, hl); + fddi->fc = FDDI_FC_K_ASYNC_LLC_DEF; + if(type == ETH_P_IP || type == ETH_P_IPV6 || type == ETH_P_ARP) + { + fddi->hdr.llc_snap.dsap = FDDI_EXTENDED_SAP; + fddi->hdr.llc_snap.ssap = FDDI_EXTENDED_SAP; + fddi->hdr.llc_snap.ctrl = FDDI_UI_CMD; + fddi->hdr.llc_snap.oui[0] = 0x00; + fddi->hdr.llc_snap.oui[1] = 0x00; + fddi->hdr.llc_snap.oui[2] = 0x00; + fddi->hdr.llc_snap.ethertype = htons(type); + } + + /* Set the source and destination hardware addresses */ + + if (saddr != NULL) + memcpy(fddi->saddr, saddr, dev->addr_len); + else + memcpy(fddi->saddr, dev->dev_addr, dev->addr_len); + + if (daddr != NULL) + { + memcpy(fddi->daddr, daddr, dev->addr_len); + return hl; + } + + return -hl; +} + +/* + * Determine the packet's protocol ID and fill in skb fields. + * This routine is called before an incoming packet is passed + * up. It's used to fill in specific skb fields and to set + * the proper pointer to the start of packet data (skb->data). + */ + +__be16 fddi_type_trans(struct sk_buff *skb, struct net_device *dev) +{ + struct fddihdr *fddi = (struct fddihdr *)skb->data; + __be16 type; + + /* + * Set mac.raw field to point to FC byte, set data field to point + * to start of packet data. Assume 802.2 SNAP frames for now. 
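+ *
+ * Frames whose DSAP is 0xe0 are the exception and are passed up as
+ * 802.2 frames (ETH_P_802_2) below.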
+ */ + + skb->dev = dev; + skb_reset_mac_header(skb); /* point to frame control (FC) */ + + if(fddi->hdr.llc_8022_1.dsap==0xe0) + { + skb_pull(skb, FDDI_K_8022_HLEN-3); + type = htons(ETH_P_802_2); + } + else + { + skb_pull(skb, FDDI_K_SNAP_HLEN); /* adjust for 21 byte header */ + type=fddi->hdr.llc_snap.ethertype; + } + + /* Set packet type based on destination address and flag settings */ + + if (*fddi->daddr & 0x01) + { + if (memcmp(fddi->daddr, dev->broadcast, FDDI_K_ALEN) == 0) + skb->pkt_type = PACKET_BROADCAST; + else + skb->pkt_type = PACKET_MULTICAST; + } + + else if (dev->flags & IFF_PROMISC) + { + if (memcmp(fddi->daddr, dev->dev_addr, FDDI_K_ALEN)) + skb->pkt_type = PACKET_OTHERHOST; + } + + /* Assume 802.2 SNAP frames, for now */ + + return type; +} + +EXPORT_SYMBOL(fddi_type_trans); + +static const struct header_ops fddi_header_ops = { + .create = fddi_header, +}; + + +static void fddi_setup(struct net_device *dev) +{ + dev->header_ops = &fddi_header_ops; + dev->type = ARPHRD_FDDI; + dev->hard_header_len = FDDI_K_SNAP_HLEN+3; /* Assume 802.2 SNAP hdr len + 3 pad bytes */ + dev->mtu = FDDI_K_SNAP_DLEN; /* Assume max payload of 802.2 SNAP frame */ + dev->min_mtu = FDDI_K_SNAP_HLEN; + dev->max_mtu = FDDI_K_SNAP_DLEN; + dev->addr_len = FDDI_K_ALEN; + dev->tx_queue_len = 100; /* Long queues on FDDI */ + dev->flags = IFF_BROADCAST | IFF_MULTICAST; + + memset(dev->broadcast, 0xFF, FDDI_K_ALEN); +} + +/** + * alloc_fddidev - Register FDDI device + * @sizeof_priv: Size of additional driver-private structure to be allocated + * for this FDDI device + * + * Fill in the fields of the device structure with FDDI-generic values. + * + * Constructs a new net device, complete with a private data area of + * size @sizeof_priv. A 32-byte (not bit) alignment is enforced for + * this private data area. 
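+ *
+ * Illustrative call from a driver (the private struct name is only a
+ * placeholder):
+ *	dev = alloc_fddidev(sizeof(struct my_fddi_private));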
+ */ +struct net_device *alloc_fddidev(int sizeof_priv) +{ + return alloc_netdev(sizeof_priv, "fddi%d", NET_NAME_UNKNOWN, + fddi_setup); +} +EXPORT_SYMBOL(alloc_fddidev); + +MODULE_LICENSE("GPL"); diff --git a/net/802/garp.c b/net/802/garp.c new file mode 100644 index 000000000..fc9eb02a9 --- /dev/null +++ b/net/802/garp.c @@ -0,0 +1,649 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IEEE 802.1D Generic Attribute Registration Protocol (GARP) + * + * Copyright (c) 2008 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int garp_join_time __read_mostly = 200; +module_param(garp_join_time, uint, 0644); +MODULE_PARM_DESC(garp_join_time, "Join time in ms (default 200ms)"); +MODULE_LICENSE("GPL"); + +static const struct garp_state_trans { + u8 state; + u8 action; +} garp_applicant_state_table[GARP_APPLICANT_MAX + 1][GARP_EVENT_MAX + 1] = { + [GARP_APPLICANT_VA] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_AA, + .action = GARP_ACTION_S_JOIN_IN }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_AA }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_LA }, + }, + [GARP_APPLICANT_AA] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_QA, + .action = GARP_ACTION_S_JOIN_IN }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QA }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_LA }, + }, + [GARP_APPLICANT_QA] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QA }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_LA }, + }, + [GARP_APPLICANT_LA] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_VO, + .action = GARP_ACTION_S_LEAVE_EMPTY }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_LA }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_LA }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_LA }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_VA }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_INVALID }, + }, + [GARP_APPLICANT_VP] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_AA, + .action = GARP_ACTION_S_JOIN_IN }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_AP }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = 
GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_VO }, + }, + [GARP_APPLICANT_AP] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_QA, + .action = GARP_ACTION_S_JOIN_IN }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QP }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_AO }, + }, + [GARP_APPLICANT_QP] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QP }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_QO }, + }, + [GARP_APPLICANT_VO] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_AO }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_VP }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_INVALID }, + }, + [GARP_APPLICANT_AO] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QO }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_AP }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_INVALID }, + }, + [GARP_APPLICANT_QO] = { + [GARP_EVENT_TRANSMIT_PDU] = { .state = GARP_APPLICANT_INVALID }, + [GARP_EVENT_R_JOIN_IN] = { .state = GARP_APPLICANT_QO }, + [GARP_EVENT_R_JOIN_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_IN] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_R_LEAVE_EMPTY] = { .state = GARP_APPLICANT_VO }, + [GARP_EVENT_REQ_JOIN] = { .state = GARP_APPLICANT_QP }, + [GARP_EVENT_REQ_LEAVE] = { .state = GARP_APPLICANT_INVALID }, + }, +}; + +static int garp_attr_cmp(const struct garp_attr *attr, + const void *data, u8 len, u8 type) +{ + if (attr->type != type) + return attr->type - type; + if (attr->dlen != len) + return attr->dlen - len; + return memcmp(attr->data, data, len); +} + +static struct garp_attr *garp_attr_lookup(const struct garp_applicant *app, + const void *data, u8 len, u8 type) +{ + struct rb_node *parent = app->gid.rb_node; + struct garp_attr *attr; + int d; + + while (parent) { + attr = rb_entry(parent, struct garp_attr, node); + d = garp_attr_cmp(attr, data, len, type); + if (d > 0) + parent = parent->rb_left; + else if (d < 0) + parent = parent->rb_right; + else + return attr; + } + return NULL; +} + +static struct garp_attr *garp_attr_create(struct garp_applicant *app, 
+ const void *data, u8 len, u8 type) +{ + struct rb_node *parent = NULL, **p = &app->gid.rb_node; + struct garp_attr *attr; + int d; + + while (*p) { + parent = *p; + attr = rb_entry(parent, struct garp_attr, node); + d = garp_attr_cmp(attr, data, len, type); + if (d > 0) + p = &parent->rb_left; + else if (d < 0) + p = &parent->rb_right; + else { + /* The attribute already exists; re-use it. */ + return attr; + } + } + attr = kmalloc(sizeof(*attr) + len, GFP_ATOMIC); + if (!attr) + return attr; + attr->state = GARP_APPLICANT_VO; + attr->type = type; + attr->dlen = len; + memcpy(attr->data, data, len); + + rb_link_node(&attr->node, parent, p); + rb_insert_color(&attr->node, &app->gid); + return attr; +} + +static void garp_attr_destroy(struct garp_applicant *app, struct garp_attr *attr) +{ + rb_erase(&attr->node, &app->gid); + kfree(attr); +} + +static void garp_attr_destroy_all(struct garp_applicant *app) +{ + struct rb_node *node, *next; + struct garp_attr *attr; + + for (node = rb_first(&app->gid); + next = node ? rb_next(node) : NULL, node != NULL; + node = next) { + attr = rb_entry(node, struct garp_attr, node); + garp_attr_destroy(app, attr); + } +} + +static int garp_pdu_init(struct garp_applicant *app) +{ + struct sk_buff *skb; + struct garp_pdu_hdr *gp; + +#define LLC_RESERVE sizeof(struct llc_pdu_un) + skb = alloc_skb(app->dev->mtu + LL_RESERVED_SPACE(app->dev), + GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + skb->dev = app->dev; + skb->protocol = htons(ETH_P_802_2); + skb_reserve(skb, LL_RESERVED_SPACE(app->dev) + LLC_RESERVE); + + gp = __skb_put(skb, sizeof(*gp)); + put_unaligned(htons(GARP_PROTOCOL_ID), &gp->protocol); + + app->pdu = skb; + return 0; +} + +static int garp_pdu_append_end_mark(struct garp_applicant *app) +{ + if (skb_tailroom(app->pdu) < sizeof(u8)) + return -1; + __skb_put_u8(app->pdu, GARP_END_MARK); + return 0; +} + +static void garp_pdu_queue(struct garp_applicant *app) +{ + if (!app->pdu) + return; + + garp_pdu_append_end_mark(app); + garp_pdu_append_end_mark(app); + + llc_pdu_header_init(app->pdu, LLC_PDU_TYPE_U, LLC_SAP_BSPAN, + LLC_SAP_BSPAN, LLC_PDU_CMD); + llc_pdu_init_as_ui_cmd(app->pdu); + llc_mac_hdr_init(app->pdu, app->dev->dev_addr, + app->app->proto.group_address); + + skb_queue_tail(&app->queue, app->pdu); + app->pdu = NULL; +} + +static void garp_queue_xmit(struct garp_applicant *app) +{ + struct sk_buff *skb; + + while ((skb = skb_dequeue(&app->queue))) + dev_queue_xmit(skb); +} + +static int garp_pdu_append_msg(struct garp_applicant *app, u8 attrtype) +{ + struct garp_msg_hdr *gm; + + if (skb_tailroom(app->pdu) < sizeof(*gm)) + return -1; + gm = __skb_put(app->pdu, sizeof(*gm)); + gm->attrtype = attrtype; + garp_cb(app->pdu)->cur_type = attrtype; + return 0; +} + +static int garp_pdu_append_attr(struct garp_applicant *app, + const struct garp_attr *attr, + enum garp_attr_event event) +{ + struct garp_attr_hdr *ga; + unsigned int len; + int err; +again: + if (!app->pdu) { + err = garp_pdu_init(app); + if (err < 0) + return err; + } + + if (garp_cb(app->pdu)->cur_type != attr->type) { + if (garp_cb(app->pdu)->cur_type && + garp_pdu_append_end_mark(app) < 0) + goto queue; + if (garp_pdu_append_msg(app, attr->type) < 0) + goto queue; + } + + len = sizeof(*ga) + attr->dlen; + if (skb_tailroom(app->pdu) < len) + goto queue; + ga = __skb_put(app->pdu, len); + ga->len = len; + ga->event = event; + memcpy(ga->data, attr->data, attr->dlen); + return 0; + +queue: + garp_pdu_queue(app); + goto again; +} + +static void garp_attr_event(struct 
garp_applicant *app, + struct garp_attr *attr, enum garp_event event) +{ + enum garp_applicant_state state; + + state = garp_applicant_state_table[attr->state][event].state; + if (state == GARP_APPLICANT_INVALID) + return; + + switch (garp_applicant_state_table[attr->state][event].action) { + case GARP_ACTION_NONE: + break; + case GARP_ACTION_S_JOIN_IN: + /* When appending the attribute fails, don't update state in + * order to retry on next TRANSMIT_PDU event. */ + if (garp_pdu_append_attr(app, attr, GARP_JOIN_IN) < 0) + return; + break; + case GARP_ACTION_S_LEAVE_EMPTY: + garp_pdu_append_attr(app, attr, GARP_LEAVE_EMPTY); + /* As a pure applicant, sending a leave message implies that + * the attribute was unregistered and can be destroyed. */ + garp_attr_destroy(app, attr); + return; + default: + WARN_ON(1); + } + + attr->state = state; +} + +int garp_request_join(const struct net_device *dev, + const struct garp_application *appl, + const void *data, u8 len, u8 type) +{ + struct garp_port *port = rtnl_dereference(dev->garp_port); + struct garp_applicant *app = rtnl_dereference(port->applicants[appl->type]); + struct garp_attr *attr; + + spin_lock_bh(&app->lock); + attr = garp_attr_create(app, data, len, type); + if (!attr) { + spin_unlock_bh(&app->lock); + return -ENOMEM; + } + garp_attr_event(app, attr, GARP_EVENT_REQ_JOIN); + spin_unlock_bh(&app->lock); + return 0; +} +EXPORT_SYMBOL_GPL(garp_request_join); + +void garp_request_leave(const struct net_device *dev, + const struct garp_application *appl, + const void *data, u8 len, u8 type) +{ + struct garp_port *port = rtnl_dereference(dev->garp_port); + struct garp_applicant *app = rtnl_dereference(port->applicants[appl->type]); + struct garp_attr *attr; + + spin_lock_bh(&app->lock); + attr = garp_attr_lookup(app, data, len, type); + if (!attr) { + spin_unlock_bh(&app->lock); + return; + } + garp_attr_event(app, attr, GARP_EVENT_REQ_LEAVE); + spin_unlock_bh(&app->lock); +} +EXPORT_SYMBOL_GPL(garp_request_leave); + +static void garp_gid_event(struct garp_applicant *app, enum garp_event event) +{ + struct rb_node *node, *next; + struct garp_attr *attr; + + for (node = rb_first(&app->gid); + next = node ? 
rb_next(node) : NULL, node != NULL; + node = next) { + attr = rb_entry(node, struct garp_attr, node); + garp_attr_event(app, attr, event); + } +} + +static void garp_join_timer_arm(struct garp_applicant *app) +{ + unsigned long delay; + + delay = prandom_u32_max(msecs_to_jiffies(garp_join_time)); + mod_timer(&app->join_timer, jiffies + delay); +} + +static void garp_join_timer(struct timer_list *t) +{ + struct garp_applicant *app = from_timer(app, t, join_timer); + + spin_lock(&app->lock); + garp_gid_event(app, GARP_EVENT_TRANSMIT_PDU); + garp_pdu_queue(app); + spin_unlock(&app->lock); + + garp_queue_xmit(app); + garp_join_timer_arm(app); +} + +static int garp_pdu_parse_end_mark(struct sk_buff *skb) +{ + if (!pskb_may_pull(skb, sizeof(u8))) + return -1; + if (*skb->data == GARP_END_MARK) { + skb_pull(skb, sizeof(u8)); + return -1; + } + return 0; +} + +static int garp_pdu_parse_attr(struct garp_applicant *app, struct sk_buff *skb, + u8 attrtype) +{ + const struct garp_attr_hdr *ga; + struct garp_attr *attr; + enum garp_event event; + unsigned int dlen; + + if (!pskb_may_pull(skb, sizeof(*ga))) + return -1; + ga = (struct garp_attr_hdr *)skb->data; + if (ga->len < sizeof(*ga)) + return -1; + + if (!pskb_may_pull(skb, ga->len)) + return -1; + skb_pull(skb, ga->len); + dlen = ga->len - sizeof(*ga); + + if (attrtype > app->app->maxattr) + return 0; + + switch (ga->event) { + case GARP_LEAVE_ALL: + if (dlen != 0) + return -1; + garp_gid_event(app, GARP_EVENT_R_LEAVE_EMPTY); + return 0; + case GARP_JOIN_EMPTY: + event = GARP_EVENT_R_JOIN_EMPTY; + break; + case GARP_JOIN_IN: + event = GARP_EVENT_R_JOIN_IN; + break; + case GARP_LEAVE_EMPTY: + event = GARP_EVENT_R_LEAVE_EMPTY; + break; + case GARP_EMPTY: + event = GARP_EVENT_R_EMPTY; + break; + default: + return 0; + } + + if (dlen == 0) + return -1; + attr = garp_attr_lookup(app, ga->data, dlen, attrtype); + if (attr == NULL) + return 0; + garp_attr_event(app, attr, event); + return 0; +} + +static int garp_pdu_parse_msg(struct garp_applicant *app, struct sk_buff *skb) +{ + const struct garp_msg_hdr *gm; + + if (!pskb_may_pull(skb, sizeof(*gm))) + return -1; + gm = (struct garp_msg_hdr *)skb->data; + if (gm->attrtype == 0) + return -1; + skb_pull(skb, sizeof(*gm)); + + while (skb->len > 0) { + if (garp_pdu_parse_attr(app, skb, gm->attrtype) < 0) + return -1; + if (garp_pdu_parse_end_mark(skb) < 0) + break; + } + return 0; +} + +static void garp_pdu_rcv(const struct stp_proto *proto, struct sk_buff *skb, + struct net_device *dev) +{ + struct garp_application *appl = proto->data; + struct garp_port *port; + struct garp_applicant *app; + const struct garp_pdu_hdr *gp; + + port = rcu_dereference(dev->garp_port); + if (!port) + goto err; + app = rcu_dereference(port->applicants[appl->type]); + if (!app) + goto err; + + if (!pskb_may_pull(skb, sizeof(*gp))) + goto err; + gp = (struct garp_pdu_hdr *)skb->data; + if (get_unaligned(&gp->protocol) != htons(GARP_PROTOCOL_ID)) + goto err; + skb_pull(skb, sizeof(*gp)); + + spin_lock(&app->lock); + while (skb->len > 0) { + if (garp_pdu_parse_msg(app, skb) < 0) + break; + if (garp_pdu_parse_end_mark(skb) < 0) + break; + } + spin_unlock(&app->lock); +err: + kfree_skb(skb); +} + +static int garp_init_port(struct net_device *dev) +{ + struct garp_port *port; + + port = kzalloc(sizeof(*port), GFP_KERNEL); + if (!port) + return -ENOMEM; + rcu_assign_pointer(dev->garp_port, port); + return 0; +} + +static void garp_release_port(struct net_device *dev) +{ + struct garp_port *port = rtnl_dereference(dev->garp_port); + 
unsigned int i; + + for (i = 0; i <= GARP_APPLICATION_MAX; i++) { + if (rtnl_dereference(port->applicants[i])) + return; + } + RCU_INIT_POINTER(dev->garp_port, NULL); + kfree_rcu(port, rcu); +} + +int garp_init_applicant(struct net_device *dev, struct garp_application *appl) +{ + struct garp_applicant *app; + int err; + + ASSERT_RTNL(); + + if (!rtnl_dereference(dev->garp_port)) { + err = garp_init_port(dev); + if (err < 0) + goto err1; + } + + err = -ENOMEM; + app = kzalloc(sizeof(*app), GFP_KERNEL); + if (!app) + goto err2; + + err = dev_mc_add(dev, appl->proto.group_address); + if (err < 0) + goto err3; + + app->dev = dev; + app->app = appl; + app->gid = RB_ROOT; + spin_lock_init(&app->lock); + skb_queue_head_init(&app->queue); + rcu_assign_pointer(dev->garp_port->applicants[appl->type], app); + timer_setup(&app->join_timer, garp_join_timer, 0); + garp_join_timer_arm(app); + return 0; + +err3: + kfree(app); +err2: + garp_release_port(dev); +err1: + return err; +} +EXPORT_SYMBOL_GPL(garp_init_applicant); + +void garp_uninit_applicant(struct net_device *dev, struct garp_application *appl) +{ + struct garp_port *port = rtnl_dereference(dev->garp_port); + struct garp_applicant *app = rtnl_dereference(port->applicants[appl->type]); + + ASSERT_RTNL(); + + RCU_INIT_POINTER(port->applicants[appl->type], NULL); + + /* Delete timer and generate a final TRANSMIT_PDU event to flush out + * all pending messages before the applicant is gone. */ + del_timer_sync(&app->join_timer); + + spin_lock_bh(&app->lock); + garp_gid_event(app, GARP_EVENT_TRANSMIT_PDU); + garp_attr_destroy_all(app); + garp_pdu_queue(app); + spin_unlock_bh(&app->lock); + + garp_queue_xmit(app); + + dev_mc_del(dev, appl->proto.group_address); + kfree_rcu(app, rcu); + garp_release_port(dev); +} +EXPORT_SYMBOL_GPL(garp_uninit_applicant); + +int garp_register_application(struct garp_application *appl) +{ + appl->proto.rcv = garp_pdu_rcv; + appl->proto.data = appl; + return stp_proto_register(&appl->proto); +} +EXPORT_SYMBOL_GPL(garp_register_application); + +void garp_unregister_application(struct garp_application *appl) +{ + stp_proto_unregister(&appl->proto); +} +EXPORT_SYMBOL_GPL(garp_unregister_application); diff --git a/net/802/hippi.c b/net/802/hippi.c new file mode 100644 index 000000000..1997b7dd2 --- /dev/null +++ b/net/802/hippi.c @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * HIPPI-type device handling. + * + * Version: @(#)hippi.c 1.0.0 05/29/97 + * + * Authors: Ross Biro + * Fred N. 
van Kempen, + * Mark Evans, + * Florian La Roche, + * Alan Cox, + * Jes Sorensen, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Create the HIPPI MAC header for an arbitrary protocol layer + * + * saddr=NULL means use device source address + * daddr=NULL means leave destination address (eg unresolved arp) + */ + +static int hippi_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + struct hippi_hdr *hip = skb_push(skb, HIPPI_HLEN); + struct hippi_cb *hcb = (struct hippi_cb *) skb->cb; + + if (!len){ + len = skb->len - HIPPI_HLEN; + printk("hippi_header(): length not supplied\n"); + } + + /* + * Due to the stupidity of the little endian byte-order we + * have to set the fp field this way. + */ + hip->fp.fixed = htonl(0x04800018); + hip->fp.d2_size = htonl(len + 8); + hip->le.fc = 0; + hip->le.double_wide = 0; /* only HIPPI 800 for the time being */ + hip->le.message_type = 0; /* Data PDU */ + + hip->le.dest_addr_type = 2; /* 12 bit SC address */ + hip->le.src_addr_type = 2; /* 12 bit SC address */ + + memcpy(hip->le.src_switch_addr, dev->dev_addr + 3, 3); + memset_startat(&hip->le, 0, reserved); + + hip->snap.dsap = HIPPI_EXTENDED_SAP; + hip->snap.ssap = HIPPI_EXTENDED_SAP; + hip->snap.ctrl = HIPPI_UI_CMD; + hip->snap.oui[0] = 0x00; + hip->snap.oui[1] = 0x00; + hip->snap.oui[2] = 0x00; + hip->snap.ethertype = htons(type); + + if (daddr) + { + memcpy(hip->le.dest_switch_addr, daddr + 3, 3); + memcpy(&hcb->ifield, daddr + 2, 4); + return HIPPI_HLEN; + } + hcb->ifield = 0; + return -((int)HIPPI_HLEN); +} + + +/* + * Determine the packet's protocol ID. + */ + +__be16 hippi_type_trans(struct sk_buff *skb, struct net_device *dev) +{ + struct hippi_hdr *hip; + + /* + * This is actually wrong ... question is if we really should + * set the raw address here. + */ + skb->dev = dev; + skb_reset_mac_header(skb); + hip = (struct hippi_hdr *)skb_mac_header(skb); + skb_pull(skb, HIPPI_HLEN); + + /* + * No fancy promisc stuff here now. + */ + + return hip->snap.ethertype; +} + +EXPORT_SYMBOL(hippi_type_trans); + +/* + * For HIPPI we will actually use the lower 4 bytes of the hardware + * address as the I-FIELD rather than the actual hardware address. + */ +int hippi_mac_addr(struct net_device *dev, void *p) +{ + struct sockaddr *addr = p; + if (netif_running(dev)) + return -EBUSY; + dev_addr_set(dev, addr->sa_data); + return 0; +} +EXPORT_SYMBOL(hippi_mac_addr); + +int hippi_neigh_setup_dev(struct net_device *dev, struct neigh_parms *p) +{ + /* Never send broadcast/multicast ARP messages */ + NEIGH_VAR_INIT(p, MCAST_PROBES, 0); + + /* In IPv6 unicast probes are valid even on NBMA, + * because they are encapsulated in normal IPv6 protocol. + * Should be a generic flag. + */ + if (p->tbl->family != AF_INET6) + NEIGH_VAR_INIT(p, UCAST_PROBES, 0); + return 0; +} +EXPORT_SYMBOL(hippi_neigh_setup_dev); + +static const struct header_ops hippi_header_ops = { + .create = hippi_header, +}; + + +static void hippi_setup(struct net_device *dev) +{ + dev->header_ops = &hippi_header_ops; + + /* + * We don't support HIPPI `ARP' for the time being, and probably + * never will unless someone else implements it. However we + * still need a fake ARPHRD to make ifconfig and friends play ball. 
+ */ + dev->type = ARPHRD_HIPPI; + dev->hard_header_len = HIPPI_HLEN; + dev->mtu = 65280; + dev->min_mtu = 68; + dev->max_mtu = 65280; + dev->addr_len = HIPPI_ALEN; + dev->tx_queue_len = 25 /* 5 */; + memset(dev->broadcast, 0xFF, HIPPI_ALEN); + + + /* + * HIPPI doesn't support broadcast+multicast and we only use + * static ARP tables. ARP is disabled by hippi_neigh_setup_dev. + */ + dev->flags = 0; +} + +/** + * alloc_hippi_dev - Register HIPPI device + * @sizeof_priv: Size of additional driver-private structure to be allocated + * for this HIPPI device + * + * Fill in the fields of the device structure with HIPPI-generic values. + * + * Constructs a new net device, complete with a private data area of + * size @sizeof_priv. A 32-byte (not bit) alignment is enforced for + * this private data area. + */ + +struct net_device *alloc_hippi_dev(int sizeof_priv) +{ + return alloc_netdev(sizeof_priv, "hip%d", NET_NAME_UNKNOWN, + hippi_setup); +} + +EXPORT_SYMBOL(alloc_hippi_dev); diff --git a/net/802/mrp.c b/net/802/mrp.c new file mode 100644 index 000000000..6c927d4b3 --- /dev/null +++ b/net/802/mrp.c @@ -0,0 +1,943 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IEEE 802.1Q Multiple Registration Protocol (MRP) + * + * Copyright (c) 2012 Massachusetts Institute of Technology + * + * Adapted from code in net/802/garp.c + * Copyright (c) 2008 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int mrp_join_time __read_mostly = 200; +module_param(mrp_join_time, uint, 0644); +MODULE_PARM_DESC(mrp_join_time, "Join time in ms (default 200ms)"); + +static unsigned int mrp_periodic_time __read_mostly = 1000; +module_param(mrp_periodic_time, uint, 0644); +MODULE_PARM_DESC(mrp_periodic_time, "Periodic time in ms (default 1s)"); + +MODULE_LICENSE("GPL"); + +static const u8 +mrp_applicant_state_table[MRP_APPLICANT_MAX + 1][MRP_EVENT_MAX + 1] = { + [MRP_APPLICANT_VO] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_VP, + [MRP_EVENT_LV] = MRP_APPLICANT_VO, + [MRP_EVENT_TX] = MRP_APPLICANT_VO, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_VO, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_AO, + [MRP_EVENT_R_IN] = MRP_APPLICANT_VO, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_VO, + [MRP_EVENT_R_MT] = MRP_APPLICANT_VO, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VO, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VO, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VO, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_VO, + }, + [MRP_APPLICANT_VP] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_VP, + [MRP_EVENT_LV] = MRP_APPLICANT_VO, + [MRP_EVENT_TX] = MRP_APPLICANT_AA, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_VP, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_AP, + [MRP_EVENT_R_IN] = MRP_APPLICANT_VP, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_VP, + [MRP_EVENT_R_MT] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VP, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VP, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_VP, + }, + [MRP_APPLICANT_VN] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_VN, + [MRP_EVENT_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_TX] = MRP_APPLICANT_AN, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_VN, + [MRP_EVENT_R_IN] = MRP_APPLICANT_VN, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_VN, + [MRP_EVENT_R_MT] = MRP_APPLICANT_VN, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VN, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VN, + [MRP_EVENT_REDECLARE] = 
MRP_APPLICANT_VN, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_VN, + }, + [MRP_APPLICANT_AN] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_AN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_AN, + [MRP_EVENT_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_TX] = MRP_APPLICANT_QA, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_AN, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_AN, + [MRP_EVENT_R_IN] = MRP_APPLICANT_AN, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AN, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AN, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VN, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VN, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VN, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AN, + }, + [MRP_APPLICANT_AA] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_AA, + [MRP_EVENT_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_TX] = MRP_APPLICANT_QA, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_AA, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QA, + [MRP_EVENT_R_IN] = MRP_APPLICANT_AA, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AA, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AA, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VP, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VP, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AA, + }, + [MRP_APPLICANT_QA] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_QA, + [MRP_EVENT_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_TX] = MRP_APPLICANT_QA, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_QA, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QA, + [MRP_EVENT_R_IN] = MRP_APPLICANT_QA, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AA, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AA, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VP, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VP, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AA, + }, + [MRP_APPLICANT_LA] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_AA, + [MRP_EVENT_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_TX] = MRP_APPLICANT_VO, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_LA, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_LA, + [MRP_EVENT_R_IN] = MRP_APPLICANT_LA, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_LA, + [MRP_EVENT_R_MT] = MRP_APPLICANT_LA, + [MRP_EVENT_R_LV] = MRP_APPLICANT_LA, + [MRP_EVENT_R_LA] = MRP_APPLICANT_LA, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_LA, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_LA, + }, + [MRP_APPLICANT_AO] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_AP, + [MRP_EVENT_LV] = MRP_APPLICANT_AO, + [MRP_EVENT_TX] = MRP_APPLICANT_AO, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_AO, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QO, + [MRP_EVENT_R_IN] = MRP_APPLICANT_AO, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AO, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AO, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VO, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VO, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VO, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AO, + }, + [MRP_APPLICANT_QO] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_QP, + [MRP_EVENT_LV] = MRP_APPLICANT_QO, + [MRP_EVENT_TX] = MRP_APPLICANT_QO, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_QO, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QO, + [MRP_EVENT_R_IN] = MRP_APPLICANT_QO, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AO, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AO, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VO, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VO, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VO, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_QO, + }, + [MRP_APPLICANT_AP] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_AP, + [MRP_EVENT_LV] = MRP_APPLICANT_AO, + [MRP_EVENT_TX] = MRP_APPLICANT_QA, 
+ [MRP_EVENT_R_NEW] = MRP_APPLICANT_AP, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QP, + [MRP_EVENT_R_IN] = MRP_APPLICANT_AP, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AP, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AP, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VP, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VP, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AP, + }, + [MRP_APPLICANT_QP] = { + [MRP_EVENT_NEW] = MRP_APPLICANT_VN, + [MRP_EVENT_JOIN] = MRP_APPLICANT_QP, + [MRP_EVENT_LV] = MRP_APPLICANT_QO, + [MRP_EVENT_TX] = MRP_APPLICANT_QP, + [MRP_EVENT_R_NEW] = MRP_APPLICANT_QP, + [MRP_EVENT_R_JOIN_IN] = MRP_APPLICANT_QP, + [MRP_EVENT_R_IN] = MRP_APPLICANT_QP, + [MRP_EVENT_R_JOIN_MT] = MRP_APPLICANT_AP, + [MRP_EVENT_R_MT] = MRP_APPLICANT_AP, + [MRP_EVENT_R_LV] = MRP_APPLICANT_VP, + [MRP_EVENT_R_LA] = MRP_APPLICANT_VP, + [MRP_EVENT_REDECLARE] = MRP_APPLICANT_VP, + [MRP_EVENT_PERIODIC] = MRP_APPLICANT_AP, + }, +}; + +static const u8 +mrp_tx_action_table[MRP_APPLICANT_MAX + 1] = { + [MRP_APPLICANT_VO] = MRP_TX_ACTION_S_IN_OPTIONAL, + [MRP_APPLICANT_VP] = MRP_TX_ACTION_S_JOIN_IN, + [MRP_APPLICANT_VN] = MRP_TX_ACTION_S_NEW, + [MRP_APPLICANT_AN] = MRP_TX_ACTION_S_NEW, + [MRP_APPLICANT_AA] = MRP_TX_ACTION_S_JOIN_IN, + [MRP_APPLICANT_QA] = MRP_TX_ACTION_S_JOIN_IN_OPTIONAL, + [MRP_APPLICANT_LA] = MRP_TX_ACTION_S_LV, + [MRP_APPLICANT_AO] = MRP_TX_ACTION_S_IN_OPTIONAL, + [MRP_APPLICANT_QO] = MRP_TX_ACTION_S_IN_OPTIONAL, + [MRP_APPLICANT_AP] = MRP_TX_ACTION_S_JOIN_IN, + [MRP_APPLICANT_QP] = MRP_TX_ACTION_S_IN_OPTIONAL, +}; + +static void mrp_attrvalue_inc(void *value, u8 len) +{ + u8 *v = (u8 *)value; + + /* Add 1 to the last byte. If it becomes zero, + * go to the previous byte and repeat. + */ + while (len > 0 && !++v[--len]) + ; +} + +static int mrp_attr_cmp(const struct mrp_attr *attr, + const void *value, u8 len, u8 type) +{ + if (attr->type != type) + return attr->type - type; + if (attr->len != len) + return attr->len - len; + return memcmp(attr->value, value, len); +} + +static struct mrp_attr *mrp_attr_lookup(const struct mrp_applicant *app, + const void *value, u8 len, u8 type) +{ + struct rb_node *parent = app->mad.rb_node; + struct mrp_attr *attr; + int d; + + while (parent) { + attr = rb_entry(parent, struct mrp_attr, node); + d = mrp_attr_cmp(attr, value, len, type); + if (d > 0) + parent = parent->rb_left; + else if (d < 0) + parent = parent->rb_right; + else + return attr; + } + return NULL; +} + +static struct mrp_attr *mrp_attr_create(struct mrp_applicant *app, + const void *value, u8 len, u8 type) +{ + struct rb_node *parent = NULL, **p = &app->mad.rb_node; + struct mrp_attr *attr; + int d; + + while (*p) { + parent = *p; + attr = rb_entry(parent, struct mrp_attr, node); + d = mrp_attr_cmp(attr, value, len, type); + if (d > 0) + p = &parent->rb_left; + else if (d < 0) + p = &parent->rb_right; + else { + /* The attribute already exists; re-use it. */ + return attr; + } + } + attr = kmalloc(sizeof(*attr) + len, GFP_ATOMIC); + if (!attr) + return attr; + attr->state = MRP_APPLICANT_VO; + attr->type = type; + attr->len = len; + memcpy(attr->value, value, len); + + rb_link_node(&attr->node, parent, p); + rb_insert_color(&attr->node, &app->mad); + return attr; +} + +static void mrp_attr_destroy(struct mrp_applicant *app, struct mrp_attr *attr) +{ + rb_erase(&attr->node, &app->mad); + kfree(attr); +} + +static void mrp_attr_destroy_all(struct mrp_applicant *app) +{ + struct rb_node *node, *next; + struct mrp_attr *attr; + + for (node = rb_first(&app->mad); + next = node ? 
rb_next(node) : NULL, node != NULL; + node = next) { + attr = rb_entry(node, struct mrp_attr, node); + mrp_attr_destroy(app, attr); + } +} + +static int mrp_pdu_init(struct mrp_applicant *app) +{ + struct sk_buff *skb; + struct mrp_pdu_hdr *ph; + + skb = alloc_skb(app->dev->mtu + LL_RESERVED_SPACE(app->dev), + GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + skb->dev = app->dev; + skb->protocol = app->app->pkttype.type; + skb_reserve(skb, LL_RESERVED_SPACE(app->dev)); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + + ph = __skb_put(skb, sizeof(*ph)); + ph->version = app->app->version; + + app->pdu = skb; + return 0; +} + +static int mrp_pdu_append_end_mark(struct mrp_applicant *app) +{ + __be16 *endmark; + + if (skb_tailroom(app->pdu) < sizeof(*endmark)) + return -1; + endmark = __skb_put(app->pdu, sizeof(*endmark)); + put_unaligned(MRP_END_MARK, endmark); + return 0; +} + +static void mrp_pdu_queue(struct mrp_applicant *app) +{ + if (!app->pdu) + return; + + if (mrp_cb(app->pdu)->mh) + mrp_pdu_append_end_mark(app); + mrp_pdu_append_end_mark(app); + + dev_hard_header(app->pdu, app->dev, ntohs(app->app->pkttype.type), + app->app->group_address, app->dev->dev_addr, + app->pdu->len); + + skb_queue_tail(&app->queue, app->pdu); + app->pdu = NULL; +} + +static void mrp_queue_xmit(struct mrp_applicant *app) +{ + struct sk_buff *skb; + + while ((skb = skb_dequeue(&app->queue))) + dev_queue_xmit(skb); +} + +static int mrp_pdu_append_msg_hdr(struct mrp_applicant *app, + u8 attrtype, u8 attrlen) +{ + struct mrp_msg_hdr *mh; + + if (mrp_cb(app->pdu)->mh) { + if (mrp_pdu_append_end_mark(app) < 0) + return -1; + mrp_cb(app->pdu)->mh = NULL; + mrp_cb(app->pdu)->vah = NULL; + } + + if (skb_tailroom(app->pdu) < sizeof(*mh)) + return -1; + mh = __skb_put(app->pdu, sizeof(*mh)); + mh->attrtype = attrtype; + mh->attrlen = attrlen; + mrp_cb(app->pdu)->mh = mh; + return 0; +} + +static int mrp_pdu_append_vecattr_hdr(struct mrp_applicant *app, + const void *firstattrvalue, u8 attrlen) +{ + struct mrp_vecattr_hdr *vah; + + if (skb_tailroom(app->pdu) < sizeof(*vah) + attrlen) + return -1; + vah = __skb_put(app->pdu, sizeof(*vah) + attrlen); + put_unaligned(0, &vah->lenflags); + memcpy(vah->firstattrvalue, firstattrvalue, attrlen); + mrp_cb(app->pdu)->vah = vah; + memcpy(mrp_cb(app->pdu)->attrvalue, firstattrvalue, attrlen); + return 0; +} + +static int mrp_pdu_append_vecattr_event(struct mrp_applicant *app, + const struct mrp_attr *attr, + enum mrp_vecattr_event vaevent) +{ + u16 len, pos; + u8 *vaevents; + int err; +again: + if (!app->pdu) { + err = mrp_pdu_init(app); + if (err < 0) + return err; + } + + /* If there is no Message header in the PDU, or the Message header is + * for a different attribute type, add an EndMark (if necessary) and a + * new Message header to the PDU. + */ + if (!mrp_cb(app->pdu)->mh || + mrp_cb(app->pdu)->mh->attrtype != attr->type || + mrp_cb(app->pdu)->mh->attrlen != attr->len) { + if (mrp_pdu_append_msg_hdr(app, attr->type, attr->len) < 0) + goto queue; + } + + /* If there is no VectorAttribute header for this Message in the PDU, + * or this attribute's value does not sequentially follow the previous + * attribute's value, add a new VectorAttribute header to the PDU. 
+ */ + if (!mrp_cb(app->pdu)->vah || + memcmp(mrp_cb(app->pdu)->attrvalue, attr->value, attr->len)) { + if (mrp_pdu_append_vecattr_hdr(app, attr->value, attr->len) < 0) + goto queue; + } + + len = be16_to_cpu(get_unaligned(&mrp_cb(app->pdu)->vah->lenflags)); + pos = len % 3; + + /* Events are packed into Vectors in the PDU, three to a byte. Add a + * byte to the end of the Vector if necessary. + */ + if (!pos) { + if (skb_tailroom(app->pdu) < sizeof(u8)) + goto queue; + vaevents = __skb_put(app->pdu, sizeof(u8)); + } else { + vaevents = (u8 *)(skb_tail_pointer(app->pdu) - sizeof(u8)); + } + + switch (pos) { + case 0: + *vaevents = vaevent * (__MRP_VECATTR_EVENT_MAX * + __MRP_VECATTR_EVENT_MAX); + break; + case 1: + *vaevents += vaevent * __MRP_VECATTR_EVENT_MAX; + break; + case 2: + *vaevents += vaevent; + break; + default: + WARN_ON(1); + } + + /* Increment the length of the VectorAttribute in the PDU, as well as + * the value of the next attribute that would continue its Vector. + */ + put_unaligned(cpu_to_be16(++len), &mrp_cb(app->pdu)->vah->lenflags); + mrp_attrvalue_inc(mrp_cb(app->pdu)->attrvalue, attr->len); + + return 0; + +queue: + mrp_pdu_queue(app); + goto again; +} + +static void mrp_attr_event(struct mrp_applicant *app, + struct mrp_attr *attr, enum mrp_event event) +{ + enum mrp_applicant_state state; + + state = mrp_applicant_state_table[attr->state][event]; + if (state == MRP_APPLICANT_INVALID) { + WARN_ON(1); + return; + } + + if (event == MRP_EVENT_TX) { + /* When appending the attribute fails, don't update its state + * in order to retry at the next TX event. + */ + + switch (mrp_tx_action_table[attr->state]) { + case MRP_TX_ACTION_NONE: + case MRP_TX_ACTION_S_JOIN_IN_OPTIONAL: + case MRP_TX_ACTION_S_IN_OPTIONAL: + break; + case MRP_TX_ACTION_S_NEW: + if (mrp_pdu_append_vecattr_event( + app, attr, MRP_VECATTR_EVENT_NEW) < 0) + return; + break; + case MRP_TX_ACTION_S_JOIN_IN: + if (mrp_pdu_append_vecattr_event( + app, attr, MRP_VECATTR_EVENT_JOIN_IN) < 0) + return; + break; + case MRP_TX_ACTION_S_LV: + if (mrp_pdu_append_vecattr_event( + app, attr, MRP_VECATTR_EVENT_LV) < 0) + return; + /* As a pure applicant, sending a leave message + * implies that the attribute was unregistered and + * can be destroyed. 
+ */ + mrp_attr_destroy(app, attr); + return; + default: + WARN_ON(1); + } + } + + attr->state = state; +} + +int mrp_request_join(const struct net_device *dev, + const struct mrp_application *appl, + const void *value, u8 len, u8 type) +{ + struct mrp_port *port = rtnl_dereference(dev->mrp_port); + struct mrp_applicant *app = rtnl_dereference( + port->applicants[appl->type]); + struct mrp_attr *attr; + + if (sizeof(struct mrp_skb_cb) + len > + sizeof_field(struct sk_buff, cb)) + return -ENOMEM; + + spin_lock_bh(&app->lock); + attr = mrp_attr_create(app, value, len, type); + if (!attr) { + spin_unlock_bh(&app->lock); + return -ENOMEM; + } + mrp_attr_event(app, attr, MRP_EVENT_JOIN); + spin_unlock_bh(&app->lock); + return 0; +} +EXPORT_SYMBOL_GPL(mrp_request_join); + +void mrp_request_leave(const struct net_device *dev, + const struct mrp_application *appl, + const void *value, u8 len, u8 type) +{ + struct mrp_port *port = rtnl_dereference(dev->mrp_port); + struct mrp_applicant *app = rtnl_dereference( + port->applicants[appl->type]); + struct mrp_attr *attr; + + if (sizeof(struct mrp_skb_cb) + len > + sizeof_field(struct sk_buff, cb)) + return; + + spin_lock_bh(&app->lock); + attr = mrp_attr_lookup(app, value, len, type); + if (!attr) { + spin_unlock_bh(&app->lock); + return; + } + mrp_attr_event(app, attr, MRP_EVENT_LV); + spin_unlock_bh(&app->lock); +} +EXPORT_SYMBOL_GPL(mrp_request_leave); + +static void mrp_mad_event(struct mrp_applicant *app, enum mrp_event event) +{ + struct rb_node *node, *next; + struct mrp_attr *attr; + + for (node = rb_first(&app->mad); + next = node ? rb_next(node) : NULL, node != NULL; + node = next) { + attr = rb_entry(node, struct mrp_attr, node); + mrp_attr_event(app, attr, event); + } +} + +static void mrp_join_timer_arm(struct mrp_applicant *app) +{ + unsigned long delay; + + delay = prandom_u32_max(msecs_to_jiffies(mrp_join_time)); + mod_timer(&app->join_timer, jiffies + delay); +} + +static void mrp_join_timer(struct timer_list *t) +{ + struct mrp_applicant *app = from_timer(app, t, join_timer); + + spin_lock(&app->lock); + mrp_mad_event(app, MRP_EVENT_TX); + mrp_pdu_queue(app); + spin_unlock(&app->lock); + + mrp_queue_xmit(app); + spin_lock(&app->lock); + if (likely(app->active)) + mrp_join_timer_arm(app); + spin_unlock(&app->lock); +} + +static void mrp_periodic_timer_arm(struct mrp_applicant *app) +{ + mod_timer(&app->periodic_timer, + jiffies + msecs_to_jiffies(mrp_periodic_time)); +} + +static void mrp_periodic_timer(struct timer_list *t) +{ + struct mrp_applicant *app = from_timer(app, t, periodic_timer); + + spin_lock(&app->lock); + if (likely(app->active)) { + mrp_mad_event(app, MRP_EVENT_PERIODIC); + mrp_pdu_queue(app); + mrp_periodic_timer_arm(app); + } + spin_unlock(&app->lock); +} + +static int mrp_pdu_parse_end_mark(struct sk_buff *skb, int *offset) +{ + __be16 endmark; + + if (skb_copy_bits(skb, *offset, &endmark, sizeof(endmark)) < 0) + return -1; + if (endmark == MRP_END_MARK) { + *offset += sizeof(endmark); + return -1; + } + return 0; +} + +static void mrp_pdu_parse_vecattr_event(struct mrp_applicant *app, + struct sk_buff *skb, + enum mrp_vecattr_event vaevent) +{ + struct mrp_attr *attr; + enum mrp_event event; + + attr = mrp_attr_lookup(app, mrp_cb(skb)->attrvalue, + mrp_cb(skb)->mh->attrlen, + mrp_cb(skb)->mh->attrtype); + if (attr == NULL) + return; + + switch (vaevent) { + case MRP_VECATTR_EVENT_NEW: + event = MRP_EVENT_R_NEW; + break; + case MRP_VECATTR_EVENT_JOIN_IN: + event = MRP_EVENT_R_JOIN_IN; + break; + case 
MRP_VECATTR_EVENT_IN: + event = MRP_EVENT_R_IN; + break; + case MRP_VECATTR_EVENT_JOIN_MT: + event = MRP_EVENT_R_JOIN_MT; + break; + case MRP_VECATTR_EVENT_MT: + event = MRP_EVENT_R_MT; + break; + case MRP_VECATTR_EVENT_LV: + event = MRP_EVENT_R_LV; + break; + default: + return; + } + + mrp_attr_event(app, attr, event); +} + +static int mrp_pdu_parse_vecattr(struct mrp_applicant *app, + struct sk_buff *skb, int *offset) +{ + struct mrp_vecattr_hdr _vah; + u16 valen; + u8 vaevents, vaevent; + + mrp_cb(skb)->vah = skb_header_pointer(skb, *offset, sizeof(_vah), + &_vah); + if (!mrp_cb(skb)->vah) + return -1; + *offset += sizeof(_vah); + + if (get_unaligned(&mrp_cb(skb)->vah->lenflags) & + MRP_VECATTR_HDR_FLAG_LA) + mrp_mad_event(app, MRP_EVENT_R_LA); + valen = be16_to_cpu(get_unaligned(&mrp_cb(skb)->vah->lenflags) & + MRP_VECATTR_HDR_LEN_MASK); + + /* The VectorAttribute structure in a PDU carries event information + * about one or more attributes having consecutive values. Only the + * value for the first attribute is contained in the structure. So + * we make a copy of that value, and then increment it each time we + * advance to the next event in its Vector. + */ + if (sizeof(struct mrp_skb_cb) + mrp_cb(skb)->mh->attrlen > + sizeof_field(struct sk_buff, cb)) + return -1; + if (skb_copy_bits(skb, *offset, mrp_cb(skb)->attrvalue, + mrp_cb(skb)->mh->attrlen) < 0) + return -1; + *offset += mrp_cb(skb)->mh->attrlen; + + /* In a VectorAttribute, the Vector contains events which are packed + * three to a byte. We process one byte of the Vector at a time. + */ + while (valen > 0) { + if (skb_copy_bits(skb, *offset, &vaevents, + sizeof(vaevents)) < 0) + return -1; + *offset += sizeof(vaevents); + + /* Extract and process the first event. */ + vaevent = vaevents / (__MRP_VECATTR_EVENT_MAX * + __MRP_VECATTR_EVENT_MAX); + if (vaevent >= __MRP_VECATTR_EVENT_MAX) { + /* The byte is malformed; stop processing. */ + return -1; + } + mrp_pdu_parse_vecattr_event(app, skb, vaevent); + + /* If present, extract and process the second event. */ + if (!--valen) + break; + mrp_attrvalue_inc(mrp_cb(skb)->attrvalue, + mrp_cb(skb)->mh->attrlen); + vaevents %= (__MRP_VECATTR_EVENT_MAX * + __MRP_VECATTR_EVENT_MAX); + vaevent = vaevents / __MRP_VECATTR_EVENT_MAX; + mrp_pdu_parse_vecattr_event(app, skb, vaevent); + + /* If present, extract and process the third event. 
*/ + if (!--valen) + break; + mrp_attrvalue_inc(mrp_cb(skb)->attrvalue, + mrp_cb(skb)->mh->attrlen); + vaevents %= __MRP_VECATTR_EVENT_MAX; + vaevent = vaevents; + mrp_pdu_parse_vecattr_event(app, skb, vaevent); + } + return 0; +} + +static int mrp_pdu_parse_msg(struct mrp_applicant *app, struct sk_buff *skb, + int *offset) +{ + struct mrp_msg_hdr _mh; + + mrp_cb(skb)->mh = skb_header_pointer(skb, *offset, sizeof(_mh), &_mh); + if (!mrp_cb(skb)->mh) + return -1; + *offset += sizeof(_mh); + + if (mrp_cb(skb)->mh->attrtype == 0 || + mrp_cb(skb)->mh->attrtype > app->app->maxattr || + mrp_cb(skb)->mh->attrlen == 0) + return -1; + + while (skb->len > *offset) { + if (mrp_pdu_parse_end_mark(skb, offset) < 0) + break; + if (mrp_pdu_parse_vecattr(app, skb, offset) < 0) + return -1; + } + return 0; +} + +static int mrp_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct mrp_application *appl = container_of(pt, struct mrp_application, + pkttype); + struct mrp_port *port; + struct mrp_applicant *app; + struct mrp_pdu_hdr _ph; + const struct mrp_pdu_hdr *ph; + int offset = skb_network_offset(skb); + + /* If the interface is in promiscuous mode, drop the packet if + * it was unicast to another host. + */ + if (unlikely(skb->pkt_type == PACKET_OTHERHOST)) + goto out; + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + goto out; + port = rcu_dereference(dev->mrp_port); + if (unlikely(!port)) + goto out; + app = rcu_dereference(port->applicants[appl->type]); + if (unlikely(!app)) + goto out; + + ph = skb_header_pointer(skb, offset, sizeof(_ph), &_ph); + if (!ph) + goto out; + offset += sizeof(_ph); + + if (ph->version != app->app->version) + goto out; + + spin_lock(&app->lock); + while (skb->len > offset) { + if (mrp_pdu_parse_end_mark(skb, &offset) < 0) + break; + if (mrp_pdu_parse_msg(app, skb, &offset) < 0) + break; + } + spin_unlock(&app->lock); +out: + kfree_skb(skb); + return 0; +} + +static int mrp_init_port(struct net_device *dev) +{ + struct mrp_port *port; + + port = kzalloc(sizeof(*port), GFP_KERNEL); + if (!port) + return -ENOMEM; + rcu_assign_pointer(dev->mrp_port, port); + return 0; +} + +static void mrp_release_port(struct net_device *dev) +{ + struct mrp_port *port = rtnl_dereference(dev->mrp_port); + unsigned int i; + + for (i = 0; i <= MRP_APPLICATION_MAX; i++) { + if (rtnl_dereference(port->applicants[i])) + return; + } + RCU_INIT_POINTER(dev->mrp_port, NULL); + kfree_rcu(port, rcu); +} + +int mrp_init_applicant(struct net_device *dev, struct mrp_application *appl) +{ + struct mrp_applicant *app; + int err; + + ASSERT_RTNL(); + + if (!rtnl_dereference(dev->mrp_port)) { + err = mrp_init_port(dev); + if (err < 0) + goto err1; + } + + err = -ENOMEM; + app = kzalloc(sizeof(*app), GFP_KERNEL); + if (!app) + goto err2; + + err = dev_mc_add(dev, appl->group_address); + if (err < 0) + goto err3; + + app->dev = dev; + app->app = appl; + app->mad = RB_ROOT; + app->active = true; + spin_lock_init(&app->lock); + skb_queue_head_init(&app->queue); + rcu_assign_pointer(dev->mrp_port->applicants[appl->type], app); + timer_setup(&app->join_timer, mrp_join_timer, 0); + mrp_join_timer_arm(app); + timer_setup(&app->periodic_timer, mrp_periodic_timer, 0); + mrp_periodic_timer_arm(app); + return 0; + +err3: + kfree(app); +err2: + mrp_release_port(dev); +err1: + return err; +} +EXPORT_SYMBOL_GPL(mrp_init_applicant); + +void mrp_uninit_applicant(struct net_device *dev, struct mrp_application *appl) +{ + struct mrp_port *port = 
rtnl_dereference(dev->mrp_port); + struct mrp_applicant *app = rtnl_dereference( + port->applicants[appl->type]); + + ASSERT_RTNL(); + + RCU_INIT_POINTER(port->applicants[appl->type], NULL); + + spin_lock_bh(&app->lock); + app->active = false; + spin_unlock_bh(&app->lock); + /* Delete timer and generate a final TX event to flush out + * all pending messages before the applicant is gone. + */ + del_timer_sync(&app->join_timer); + del_timer_sync(&app->periodic_timer); + + spin_lock_bh(&app->lock); + mrp_mad_event(app, MRP_EVENT_TX); + mrp_attr_destroy_all(app); + mrp_pdu_queue(app); + spin_unlock_bh(&app->lock); + + mrp_queue_xmit(app); + + dev_mc_del(dev, appl->group_address); + kfree_rcu(app, rcu); + mrp_release_port(dev); +} +EXPORT_SYMBOL_GPL(mrp_uninit_applicant); + +int mrp_register_application(struct mrp_application *appl) +{ + appl->pkttype.func = mrp_rcv; + dev_add_pack(&appl->pkttype); + return 0; +} +EXPORT_SYMBOL_GPL(mrp_register_application); + +void mrp_unregister_application(struct mrp_application *appl) +{ + dev_remove_pack(&appl->pkttype); +} +EXPORT_SYMBOL_GPL(mrp_unregister_application); diff --git a/net/802/p8022.c b/net/802/p8022.c new file mode 100644 index 000000000..79c231731 --- /dev/null +++ b/net/802/p8022.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3: Support for 802.2 demultiplexing off Ethernet + * + * Demultiplex 802.2 encoded protocols. We match the entry by the + * SSAP/DSAP pair and then deliver to the registered datalink that + * matches. The control byte is ignored and handling of such items + * is up to the routine passed the frame. + * + * Unlike the 802.3 datalink we have a list of 802.2 entries as + * there are multiple protocols to demux. The list is currently + * short (3 or 4 entries at most). The current demux assumes this. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int p8022_request(struct datalink_proto *dl, struct sk_buff *skb, + const unsigned char *dest) +{ + llc_build_and_send_ui_pkt(dl->sap, skb, dest, dl->sap->laddr.lsap); + return 0; +} + +struct datalink_proto *register_8022_client(unsigned char type, + int (*func)(struct sk_buff *skb, + struct net_device *dev, + struct packet_type *pt, + struct net_device *orig_dev)) +{ + struct datalink_proto *proto; + + proto = kmalloc(sizeof(*proto), GFP_ATOMIC); + if (proto) { + proto->type[0] = type; + proto->header_length = 3; + proto->request = p8022_request; + proto->sap = llc_sap_open(type, func); + if (!proto->sap) { + kfree(proto); + proto = NULL; + } + } + return proto; +} + +void unregister_8022_client(struct datalink_proto *proto) +{ + llc_sap_put(proto->sap); + kfree(proto); +} + +EXPORT_SYMBOL(register_8022_client); +EXPORT_SYMBOL(unregister_8022_client); + +MODULE_LICENSE("GPL"); diff --git a/net/802/psnap.c b/net/802/psnap.c new file mode 100644 index 000000000..1406bfdbd --- /dev/null +++ b/net/802/psnap.c @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SNAP data link layer. Derived from 802.2 + * + * Alan Cox , + * from the 802.2 layer by Greg Page. + * Merged in additions from Greg Page's psnap.c. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(snap_list); +static DEFINE_SPINLOCK(snap_lock); +static struct llc_sap *snap_sap; + +/* + * Find a snap client by matching the 5 bytes. 
+ */ +static struct datalink_proto *find_snap_client(const unsigned char *desc) +{ + struct datalink_proto *proto = NULL, *p; + + list_for_each_entry_rcu(p, &snap_list, node, lockdep_is_held(&snap_lock)) { + if (!memcmp(p->type, desc, 5)) { + proto = p; + break; + } + } + return proto; +} + +/* + * A SNAP packet has arrived + */ +static int snap_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + int rc = 1; + struct datalink_proto *proto; + static struct packet_type snap_packet_type = { + .type = cpu_to_be16(ETH_P_SNAP), + }; + + if (unlikely(!pskb_may_pull(skb, 5))) + goto drop; + + rcu_read_lock(); + proto = find_snap_client(skb_transport_header(skb)); + if (proto) { + /* Pass the frame on. */ + skb->transport_header += 5; + skb_pull_rcsum(skb, 5); + rc = proto->rcvfunc(skb, dev, &snap_packet_type, orig_dev); + } + rcu_read_unlock(); + + if (unlikely(!proto)) + goto drop; + +out: + return rc; + +drop: + kfree_skb(skb); + goto out; +} + +/* + * Put a SNAP header on a frame and pass to 802.2 + */ +static int snap_request(struct datalink_proto *dl, + struct sk_buff *skb, const u8 *dest) +{ + memcpy(skb_push(skb, 5), dl->type, 5); + llc_build_and_send_ui_pkt(snap_sap, skb, dest, snap_sap->laddr.lsap); + return 0; +} + +/* + * Set up the SNAP layer + */ +EXPORT_SYMBOL(register_snap_client); +EXPORT_SYMBOL(unregister_snap_client); + +static const char snap_err_msg[] __initconst = + KERN_CRIT "SNAP - unable to register with 802.2\n"; + +static int __init snap_init(void) +{ + snap_sap = llc_sap_open(0xAA, snap_rcv); + if (!snap_sap) { + printk(snap_err_msg); + return -EBUSY; + } + + return 0; +} + +module_init(snap_init); + +static void __exit snap_exit(void) +{ + llc_sap_put(snap_sap); +} + +module_exit(snap_exit); + + +/* + * Register SNAP clients. We don't yet use this for IP. + */ +struct datalink_proto *register_snap_client(const unsigned char *desc, + int (*rcvfunc)(struct sk_buff *, + struct net_device *, + struct packet_type *, + struct net_device *)) +{ + struct datalink_proto *proto = NULL; + + spin_lock_bh(&snap_lock); + + if (find_snap_client(desc)) + goto out; + + proto = kmalloc(sizeof(*proto), GFP_ATOMIC); + if (proto) { + memcpy(proto->type, desc, 5); + proto->rcvfunc = rcvfunc; + proto->header_length = 5 + 3; /* snap + 802.2 */ + proto->request = snap_request; + list_add_rcu(&proto->node, &snap_list); + } +out: + spin_unlock_bh(&snap_lock); + + return proto; +} + +/* + * Unregister SNAP clients. Protocols no longer want to play with us ... 
+ */ +void unregister_snap_client(struct datalink_proto *proto) +{ + spin_lock_bh(&snap_lock); + list_del_rcu(&proto->node); + spin_unlock_bh(&snap_lock); + + synchronize_net(); + + kfree(proto); +} + +MODULE_LICENSE("GPL"); diff --git a/net/802/stp.c b/net/802/stp.c new file mode 100644 index 000000000..d550d9f88 --- /dev/null +++ b/net/802/stp.c @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * STP SAP demux + * + * Copyright (c) 2008 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* 01:80:c2:00:00:20 - 01:80:c2:00:00:2F */ +#define GARP_ADDR_MIN 0x20 +#define GARP_ADDR_MAX 0x2F +#define GARP_ADDR_RANGE (GARP_ADDR_MAX - GARP_ADDR_MIN) + +static const struct stp_proto __rcu *garp_protos[GARP_ADDR_RANGE + 1] __read_mostly; +static const struct stp_proto __rcu *stp_proto __read_mostly; + +static struct llc_sap *sap __read_mostly; +static unsigned int sap_registered; +static DEFINE_MUTEX(stp_proto_mutex); + +/* Called under rcu_read_lock from LLC */ +static int stp_pdu_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + const struct ethhdr *eh = eth_hdr(skb); + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + const struct stp_proto *proto; + + if (pdu->ssap != LLC_SAP_BSPAN || + pdu->dsap != LLC_SAP_BSPAN || + pdu->ctrl_1 != LLC_PDU_TYPE_U) + goto err; + + if (eh->h_dest[5] >= GARP_ADDR_MIN && eh->h_dest[5] <= GARP_ADDR_MAX) { + proto = rcu_dereference(garp_protos[eh->h_dest[5] - + GARP_ADDR_MIN]); + if (proto && + !ether_addr_equal(eh->h_dest, proto->group_address)) + goto err; + } else + proto = rcu_dereference(stp_proto); + + if (!proto) + goto err; + + proto->rcv(proto, skb, dev); + return 0; + +err: + kfree_skb(skb); + return 0; +} + +int stp_proto_register(const struct stp_proto *proto) +{ + int err = 0; + + mutex_lock(&stp_proto_mutex); + if (sap_registered++ == 0) { + sap = llc_sap_open(LLC_SAP_BSPAN, stp_pdu_rcv); + if (!sap) { + err = -ENOMEM; + goto out; + } + } + if (is_zero_ether_addr(proto->group_address)) + rcu_assign_pointer(stp_proto, proto); + else + rcu_assign_pointer(garp_protos[proto->group_address[5] - + GARP_ADDR_MIN], proto); +out: + mutex_unlock(&stp_proto_mutex); + return err; +} +EXPORT_SYMBOL_GPL(stp_proto_register); + +void stp_proto_unregister(const struct stp_proto *proto) +{ + mutex_lock(&stp_proto_mutex); + if (is_zero_ether_addr(proto->group_address)) + RCU_INIT_POINTER(stp_proto, NULL); + else + RCU_INIT_POINTER(garp_protos[proto->group_address[5] - + GARP_ADDR_MIN], NULL); + synchronize_rcu(); + + if (--sap_registered == 0) + llc_sap_put(sap); + mutex_unlock(&stp_proto_mutex); +} +EXPORT_SYMBOL_GPL(stp_proto_unregister); + +MODULE_LICENSE("GPL"); diff --git a/net/8021q/Kconfig b/net/8021q/Kconfig new file mode 100644 index 000000000..8bf7a1765 --- /dev/null +++ b/net/8021q/Kconfig @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Configuration for 802.1Q VLAN support +# + +config VLAN_8021Q + tristate "802.1Q/802.1ad VLAN Support" + help + Select this and you will be able to create 802.1Q VLAN interfaces + on your Ethernet interfaces. 802.1Q VLAN supports almost + everything a regular Ethernet interface does, including + firewalling, bridging, and of course IP traffic. You will need + the 'ip' utility in order to effectively use VLANs. + See the VLAN web page for more information: + + + To compile this code as a module, choose M here: the module + will be called 8021q. + + If unsure, say N. 
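Editorial aside (not part of the patch): the vlan.c code added further down exposes both the netlink interface that the 'ip' utility uses and a legacy SIOCSIFVLAN ioctl taking a struct vlan_ioctl_args. The following is a minimal userspace sketch of creating a VLAN interface through that ioctl; the device name "eth0" and VLAN ID 5 are arbitrary examples, the program needs CAP_NET_ADMIN, and with iproute2 the equivalent is roughly "ip link add link eth0 name eth0.5 type vlan id 5".

/* Illustrative sketch only -- not part of this patch. With the default
 * RAW_PLUS_VID_NO_PAD name type this creates an interface named "eth0.5".
 */
#include <stdio.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <unistd.h>
#include <linux/if_vlan.h>
#include <linux/sockios.h>

int main(void)
{
	struct vlan_ioctl_args req;
	int fd = socket(AF_INET, SOCK_DGRAM, 0);	/* any socket works as an ioctl handle */

	if (fd < 0) {
		perror("socket");
		return 1;
	}

	memset(&req, 0, sizeof(req));
	req.cmd = ADD_VLAN_CMD;				/* add a VLAN on top of device1 */
	strncpy(req.device1, "eth0", sizeof(req.device1) - 1);
	req.u.VID = 5;					/* VLAN ID, must be below VLAN_VID_MASK */

	if (ioctl(fd, SIOCSIFVLAN, &req) < 0)		/* handled by vlan_ioctl_handler() in vlan.c */
		perror("SIOCSIFVLAN");

	close(fd);
	return 0;
}

New code should prefer the netlink path (vlan_netlink.c, RTM_NEWLINK with kind "vlan"); the ioctl is kept for compatibility with older vconfig-style tooling.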
+ +config VLAN_8021Q_GVRP + bool "GVRP (GARP VLAN Registration Protocol) support" + depends on VLAN_8021Q + select GARP + help + Select this to enable GVRP end-system support. GVRP is used for + automatic propagation of registered VLANs to switches. + + If unsure, say N. + +config VLAN_8021Q_MVRP + bool "MVRP (Multiple VLAN Registration Protocol) support" + depends on VLAN_8021Q + select MRP + help + Select this to enable MVRP end-system support. MVRP is used for + automatic propagation of registered VLANs to switches; it + supersedes GVRP and is not backwards-compatible. + + If unsure, say N. diff --git a/net/8021q/Makefile b/net/8021q/Makefile new file mode 100644 index 000000000..e05d4d7aa --- /dev/null +++ b/net/8021q/Makefile @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux VLAN layer. +# +obj-$(subst m,y,$(CONFIG_VLAN_8021Q)) += vlan_core.o +obj-$(CONFIG_VLAN_8021Q) += 8021q.o + +8021q-y := vlan.o vlan_dev.o vlan_netlink.o +8021q-$(CONFIG_VLAN_8021Q_GVRP) += vlan_gvrp.o +8021q-$(CONFIG_VLAN_8021Q_MVRP) += vlan_mvrp.o +8021q-$(CONFIG_PROC_FS) += vlanproc.o diff --git a/net/8021q/vlan.c b/net/8021q/vlan.c new file mode 100644 index 000000000..e40aa3e36 --- /dev/null +++ b/net/8021q/vlan.c @@ -0,0 +1,742 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET 802.1Q VLAN + * Ethernet-type device handling. + * + * Authors: Ben Greear + * Please send support related email to: netdev@vger.kernel.org + * VLAN Home Page: http://www.candelatech.com/~greear/vlan.html + * + * Fixes: + * Fix for packet capture - Nick Eggleston ; + * Add HW acceleration hooks - David S. Miller ; + * Correct all the locking - David S. Miller ; + * Use hash table for VLAN groups - David S. Miller + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "vlan.h" +#include "vlanproc.h" + +#define DRV_VERSION "1.8" + +/* Global VLAN variables */ + +unsigned int vlan_net_id __read_mostly; + +const char vlan_fullname[] = "802.1Q VLAN Support"; +const char vlan_version[] = DRV_VERSION; + +/* End of global variables definitions. 
*/ + +static int vlan_group_prealloc_vid(struct vlan_group *vg, + __be16 vlan_proto, u16 vlan_id) +{ + struct net_device **array; + unsigned int vidx; + unsigned int size; + int pidx; + + ASSERT_RTNL(); + + pidx = vlan_proto_idx(vlan_proto); + if (pidx < 0) + return -EINVAL; + + vidx = vlan_id / VLAN_GROUP_ARRAY_PART_LEN; + array = vg->vlan_devices_arrays[pidx][vidx]; + if (array != NULL) + return 0; + + size = sizeof(struct net_device *) * VLAN_GROUP_ARRAY_PART_LEN; + array = kzalloc(size, GFP_KERNEL_ACCOUNT); + if (array == NULL) + return -ENOBUFS; + + /* paired with smp_rmb() in __vlan_group_get_device() */ + smp_wmb(); + + vg->vlan_devices_arrays[pidx][vidx] = array; + return 0; +} + +static void vlan_stacked_transfer_operstate(const struct net_device *rootdev, + struct net_device *dev, + struct vlan_dev_priv *vlan) +{ + if (!(vlan->flags & VLAN_FLAG_BRIDGE_BINDING)) + netif_stacked_transfer_operstate(rootdev, dev); +} + +void unregister_vlan_dev(struct net_device *dev, struct list_head *head) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + struct vlan_info *vlan_info; + struct vlan_group *grp; + u16 vlan_id = vlan->vlan_id; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(real_dev->vlan_info); + BUG_ON(!vlan_info); + + grp = &vlan_info->grp; + + grp->nr_vlan_devs--; + + if (vlan->flags & VLAN_FLAG_MVRP) + vlan_mvrp_request_leave(dev); + if (vlan->flags & VLAN_FLAG_GVRP) + vlan_gvrp_request_leave(dev); + + vlan_group_set_device(grp, vlan->vlan_proto, vlan_id, NULL); + + netdev_upper_dev_unlink(real_dev, dev); + /* Because unregister_netdevice_queue() makes sure at least one rcu + * grace period is respected before device freeing, + * we dont need to call synchronize_net() here. + */ + unregister_netdevice_queue(dev, head); + + if (grp->nr_vlan_devs == 0) { + vlan_mvrp_uninit_applicant(real_dev); + vlan_gvrp_uninit_applicant(real_dev); + } + + vlan_vid_del(real_dev, vlan->vlan_proto, vlan_id); +} + +int vlan_check_real_dev(struct net_device *real_dev, + __be16 protocol, u16 vlan_id, + struct netlink_ext_ack *extack) +{ + const char *name = real_dev->name; + + if (real_dev->features & NETIF_F_VLAN_CHALLENGED) { + pr_info("VLANs not supported on %s\n", name); + NL_SET_ERR_MSG_MOD(extack, "VLANs not supported on device"); + return -EOPNOTSUPP; + } + + if (vlan_find_dev(real_dev, protocol, vlan_id) != NULL) { + NL_SET_ERR_MSG_MOD(extack, "VLAN device already exists"); + return -EEXIST; + } + + return 0; +} + +int register_vlan_dev(struct net_device *dev, struct netlink_ext_ack *extack) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + u16 vlan_id = vlan->vlan_id; + struct vlan_info *vlan_info; + struct vlan_group *grp; + int err; + + err = vlan_vid_add(real_dev, vlan->vlan_proto, vlan_id); + if (err) + return err; + + vlan_info = rtnl_dereference(real_dev->vlan_info); + /* vlan_info should be there now. 
vlan_vid_add took care of it */ + BUG_ON(!vlan_info); + + grp = &vlan_info->grp; + if (grp->nr_vlan_devs == 0) { + err = vlan_gvrp_init_applicant(real_dev); + if (err < 0) + goto out_vid_del; + err = vlan_mvrp_init_applicant(real_dev); + if (err < 0) + goto out_uninit_gvrp; + } + + err = vlan_group_prealloc_vid(grp, vlan->vlan_proto, vlan_id); + if (err < 0) + goto out_uninit_mvrp; + + err = register_netdevice(dev); + if (err < 0) + goto out_uninit_mvrp; + + err = netdev_upper_dev_link(real_dev, dev, extack); + if (err) + goto out_unregister_netdev; + + vlan_stacked_transfer_operstate(real_dev, dev, vlan); + linkwatch_fire_event(dev); /* _MUST_ call rfc2863_policy() */ + + /* So, got the sucker initialized, now lets place + * it into our local structure. + */ + vlan_group_set_device(grp, vlan->vlan_proto, vlan_id, dev); + grp->nr_vlan_devs++; + + return 0; + +out_unregister_netdev: + unregister_netdevice(dev); +out_uninit_mvrp: + if (grp->nr_vlan_devs == 0) + vlan_mvrp_uninit_applicant(real_dev); +out_uninit_gvrp: + if (grp->nr_vlan_devs == 0) + vlan_gvrp_uninit_applicant(real_dev); +out_vid_del: + vlan_vid_del(real_dev, vlan->vlan_proto, vlan_id); + return err; +} + +/* Attach a VLAN device to a mac address (ie Ethernet Card). + * Returns 0 if the device was created or a negative error code otherwise. + */ +static int register_vlan_device(struct net_device *real_dev, u16 vlan_id) +{ + struct net_device *new_dev; + struct vlan_dev_priv *vlan; + struct net *net = dev_net(real_dev); + struct vlan_net *vn = net_generic(net, vlan_net_id); + char name[IFNAMSIZ]; + int err; + + if (vlan_id >= VLAN_VID_MASK) + return -ERANGE; + + err = vlan_check_real_dev(real_dev, htons(ETH_P_8021Q), vlan_id, + NULL); + if (err < 0) + return err; + + /* Gotta set up the fields for the device. */ + switch (vn->name_type) { + case VLAN_NAME_TYPE_RAW_PLUS_VID: + /* name will look like: eth1.0005 */ + snprintf(name, IFNAMSIZ, "%s.%.4i", real_dev->name, vlan_id); + break; + case VLAN_NAME_TYPE_PLUS_VID_NO_PAD: + /* Put our vlan.VID in the name. + * Name will look like: vlan5 + */ + snprintf(name, IFNAMSIZ, "vlan%i", vlan_id); + break; + case VLAN_NAME_TYPE_RAW_PLUS_VID_NO_PAD: + /* Put our vlan.VID in the name. + * Name will look like: eth0.5 + */ + snprintf(name, IFNAMSIZ, "%s.%i", real_dev->name, vlan_id); + break; + case VLAN_NAME_TYPE_PLUS_VID: + /* Put our vlan.VID in the name. + * Name will look like: vlan0005 + */ + default: + snprintf(name, IFNAMSIZ, "vlan%.4i", vlan_id); + } + + new_dev = alloc_netdev(sizeof(struct vlan_dev_priv), name, + NET_NAME_UNKNOWN, vlan_setup); + + if (new_dev == NULL) + return -ENOBUFS; + + dev_net_set(new_dev, net); + /* need 4 bytes for extra VLAN header info, + * hope the underlying device can handle it. 
+ */ + new_dev->mtu = real_dev->mtu; + + vlan = vlan_dev_priv(new_dev); + vlan->vlan_proto = htons(ETH_P_8021Q); + vlan->vlan_id = vlan_id; + vlan->real_dev = real_dev; + vlan->dent = NULL; + vlan->flags = VLAN_FLAG_REORDER_HDR; + + new_dev->rtnl_link_ops = &vlan_link_ops; + err = register_vlan_dev(new_dev, NULL); + if (err < 0) + goto out_free_newdev; + + return 0; + +out_free_newdev: + free_netdev(new_dev); + return err; +} + +static void vlan_sync_address(struct net_device *dev, + struct net_device *vlandev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(vlandev); + + /* May be called without an actual change */ + if (ether_addr_equal(vlan->real_dev_addr, dev->dev_addr)) + return; + + /* vlan continues to inherit address of lower device */ + if (vlan_dev_inherit_address(vlandev, dev)) + goto out; + + /* vlan address was different from the old address and is equal to + * the new address */ + if (!ether_addr_equal(vlandev->dev_addr, vlan->real_dev_addr) && + ether_addr_equal(vlandev->dev_addr, dev->dev_addr)) + dev_uc_del(dev, vlandev->dev_addr); + + /* vlan address was equal to the old address and is different from + * the new address */ + if (ether_addr_equal(vlandev->dev_addr, vlan->real_dev_addr) && + !ether_addr_equal(vlandev->dev_addr, dev->dev_addr)) + dev_uc_add(dev, vlandev->dev_addr); + +out: + ether_addr_copy(vlan->real_dev_addr, dev->dev_addr); +} + +static void vlan_transfer_features(struct net_device *dev, + struct net_device *vlandev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(vlandev); + + netif_inherit_tso_max(vlandev, dev); + + if (vlan_hw_offload_capable(dev->features, vlan->vlan_proto)) + vlandev->hard_header_len = dev->hard_header_len; + else + vlandev->hard_header_len = dev->hard_header_len + VLAN_HLEN; + +#if IS_ENABLED(CONFIG_FCOE) + vlandev->fcoe_ddp_xid = dev->fcoe_ddp_xid; +#endif + + vlandev->priv_flags &= ~IFF_XMIT_DST_RELEASE; + vlandev->priv_flags |= (vlan->real_dev->priv_flags & IFF_XMIT_DST_RELEASE); + vlandev->hw_enc_features = vlan_tnl_features(vlan->real_dev); + + netdev_update_features(vlandev); +} + +static int __vlan_device_event(struct net_device *dev, unsigned long event) +{ + int err = 0; + + switch (event) { + case NETDEV_CHANGENAME: + vlan_proc_rem_dev(dev); + err = vlan_proc_add_dev(dev); + break; + case NETDEV_REGISTER: + err = vlan_proc_add_dev(dev); + break; + case NETDEV_UNREGISTER: + vlan_proc_rem_dev(dev); + break; + } + + return err; +} + +static int vlan_device_event(struct notifier_block *unused, unsigned long event, + void *ptr) +{ + struct netlink_ext_ack *extack = netdev_notifier_info_to_extack(ptr); + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct vlan_group *grp; + struct vlan_info *vlan_info; + int i, flgs; + struct net_device *vlandev; + struct vlan_dev_priv *vlan; + bool last = false; + LIST_HEAD(list); + int err; + + if (is_vlan_dev(dev)) { + int err = __vlan_device_event(dev, event); + + if (err) + return notifier_from_errno(err); + } + + if ((event == NETDEV_UP) && + (dev->features & NETIF_F_HW_VLAN_CTAG_FILTER)) { + pr_info("adding VLAN 0 to HW filter on device %s\n", + dev->name); + vlan_vid_add(dev, htons(ETH_P_8021Q), 0); + } + if (event == NETDEV_DOWN && + (dev->features & NETIF_F_HW_VLAN_CTAG_FILTER)) + vlan_vid_del(dev, htons(ETH_P_8021Q), 0); + + vlan_info = rtnl_dereference(dev->vlan_info); + if (!vlan_info) + goto out; + grp = &vlan_info->grp; + + /* It is OK that we do not hold the group lock right now, + * as we run under the RTNL lock. 
+ */ + + switch (event) { + case NETDEV_CHANGE: + /* Propagate real device state to vlan devices */ + vlan_group_for_each_dev(grp, i, vlandev) + vlan_stacked_transfer_operstate(dev, vlandev, + vlan_dev_priv(vlandev)); + break; + + case NETDEV_CHANGEADDR: + /* Adjust unicast filters on underlying device */ + vlan_group_for_each_dev(grp, i, vlandev) { + flgs = vlandev->flags; + if (!(flgs & IFF_UP)) + continue; + + vlan_sync_address(dev, vlandev); + } + break; + + case NETDEV_CHANGEMTU: + vlan_group_for_each_dev(grp, i, vlandev) { + if (vlandev->mtu <= dev->mtu) + continue; + + dev_set_mtu(vlandev, dev->mtu); + } + break; + + case NETDEV_FEAT_CHANGE: + /* Propagate device features to underlying device */ + vlan_group_for_each_dev(grp, i, vlandev) + vlan_transfer_features(dev, vlandev); + break; + + case NETDEV_DOWN: { + struct net_device *tmp; + LIST_HEAD(close_list); + + /* Put all VLANs for this dev in the down state too. */ + vlan_group_for_each_dev(grp, i, vlandev) { + flgs = vlandev->flags; + if (!(flgs & IFF_UP)) + continue; + + vlan = vlan_dev_priv(vlandev); + if (!(vlan->flags & VLAN_FLAG_LOOSE_BINDING)) + list_add(&vlandev->close_list, &close_list); + } + + dev_close_many(&close_list, false); + + list_for_each_entry_safe(vlandev, tmp, &close_list, close_list) { + vlan_stacked_transfer_operstate(dev, vlandev, + vlan_dev_priv(vlandev)); + list_del_init(&vlandev->close_list); + } + list_del(&close_list); + break; + } + case NETDEV_UP: + /* Put all VLANs for this dev in the up state too. */ + vlan_group_for_each_dev(grp, i, vlandev) { + flgs = dev_get_flags(vlandev); + if (flgs & IFF_UP) + continue; + + vlan = vlan_dev_priv(vlandev); + if (!(vlan->flags & VLAN_FLAG_LOOSE_BINDING)) + dev_change_flags(vlandev, flgs | IFF_UP, + extack); + vlan_stacked_transfer_operstate(dev, vlandev, vlan); + } + break; + + case NETDEV_UNREGISTER: + /* twiddle thumbs on netns device moves */ + if (dev->reg_state != NETREG_UNREGISTERING) + break; + + vlan_group_for_each_dev(grp, i, vlandev) { + /* removal of last vid destroys vlan_info, abort + * afterwards */ + if (vlan_info->nr_vids == 1) + last = true; + + unregister_vlan_dev(vlandev, &list); + if (last) + break; + } + unregister_netdevice_many(&list); + break; + + case NETDEV_PRE_TYPE_CHANGE: + /* Forbid underlaying device to change its type. */ + if (vlan_uses_dev(dev)) + return NOTIFY_BAD; + break; + + case NETDEV_NOTIFY_PEERS: + case NETDEV_BONDING_FAILOVER: + case NETDEV_RESEND_IGMP: + /* Propagate to vlan devices */ + vlan_group_for_each_dev(grp, i, vlandev) + call_netdevice_notifiers(event, vlandev); + break; + + case NETDEV_CVLAN_FILTER_PUSH_INFO: + err = vlan_filter_push_vids(vlan_info, htons(ETH_P_8021Q)); + if (err) + return notifier_from_errno(err); + break; + + case NETDEV_CVLAN_FILTER_DROP_INFO: + vlan_filter_drop_vids(vlan_info, htons(ETH_P_8021Q)); + break; + + case NETDEV_SVLAN_FILTER_PUSH_INFO: + err = vlan_filter_push_vids(vlan_info, htons(ETH_P_8021AD)); + if (err) + return notifier_from_errno(err); + break; + + case NETDEV_SVLAN_FILTER_DROP_INFO: + vlan_filter_drop_vids(vlan_info, htons(ETH_P_8021AD)); + break; + } + +out: + return NOTIFY_DONE; +} + +static struct notifier_block vlan_notifier_block __read_mostly = { + .notifier_call = vlan_device_event, +}; + +/* + * VLAN IOCTL handler. + * o execute requested action or pass command to the device driver + * arg is really a struct vlan_ioctl_args __user *. 
+ */ +static int vlan_ioctl_handler(struct net *net, void __user *arg) +{ + int err; + struct vlan_ioctl_args args; + struct net_device *dev = NULL; + + if (copy_from_user(&args, arg, sizeof(struct vlan_ioctl_args))) + return -EFAULT; + + /* Null terminate this sucker, just in case. */ + args.device1[sizeof(args.device1) - 1] = 0; + args.u.device2[sizeof(args.u.device2) - 1] = 0; + + rtnl_lock(); + + switch (args.cmd) { + case SET_VLAN_INGRESS_PRIORITY_CMD: + case SET_VLAN_EGRESS_PRIORITY_CMD: + case SET_VLAN_FLAG_CMD: + case ADD_VLAN_CMD: + case DEL_VLAN_CMD: + case GET_VLAN_REALDEV_NAME_CMD: + case GET_VLAN_VID_CMD: + err = -ENODEV; + dev = __dev_get_by_name(net, args.device1); + if (!dev) + goto out; + + err = -EINVAL; + if (args.cmd != ADD_VLAN_CMD && !is_vlan_dev(dev)) + goto out; + } + + switch (args.cmd) { + case SET_VLAN_INGRESS_PRIORITY_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + vlan_dev_set_ingress_priority(dev, + args.u.skb_priority, + args.vlan_qos); + err = 0; + break; + + case SET_VLAN_EGRESS_PRIORITY_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + err = vlan_dev_set_egress_priority(dev, + args.u.skb_priority, + args.vlan_qos); + break; + + case SET_VLAN_FLAG_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + err = vlan_dev_change_flags(dev, + args.vlan_qos ? args.u.flag : 0, + args.u.flag); + break; + + case SET_VLAN_NAME_TYPE_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + if (args.u.name_type < VLAN_NAME_TYPE_HIGHEST) { + struct vlan_net *vn; + + vn = net_generic(net, vlan_net_id); + vn->name_type = args.u.name_type; + err = 0; + } else { + err = -EINVAL; + } + break; + + case ADD_VLAN_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + err = register_vlan_device(dev, args.u.VID); + break; + + case DEL_VLAN_CMD: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + unregister_vlan_dev(dev, NULL); + err = 0; + break; + + case GET_VLAN_REALDEV_NAME_CMD: + err = 0; + vlan_dev_get_realdev_name(dev, args.u.device2, + sizeof(args.u.device2)); + if (copy_to_user(arg, &args, + sizeof(struct vlan_ioctl_args))) + err = -EFAULT; + break; + + case GET_VLAN_VID_CMD: + err = 0; + args.u.VID = vlan_dev_vlan_id(dev); + if (copy_to_user(arg, &args, + sizeof(struct vlan_ioctl_args))) + err = -EFAULT; + break; + + default: + err = -EOPNOTSUPP; + break; + } +out: + rtnl_unlock(); + return err; +} + +static int __net_init vlan_init_net(struct net *net) +{ + struct vlan_net *vn = net_generic(net, vlan_net_id); + int err; + + vn->name_type = VLAN_NAME_TYPE_RAW_PLUS_VID_NO_PAD; + + err = vlan_proc_init(net); + + return err; +} + +static void __net_exit vlan_exit_net(struct net *net) +{ + vlan_proc_cleanup(net); +} + +static struct pernet_operations vlan_net_ops = { + .init = vlan_init_net, + .exit = vlan_exit_net, + .id = &vlan_net_id, + .size = sizeof(struct vlan_net), +}; + +static int __init vlan_proto_init(void) +{ + int err; + + pr_info("%s v%s\n", vlan_fullname, vlan_version); + + err = register_pernet_subsys(&vlan_net_ops); + if (err < 0) + goto err0; + + err = register_netdevice_notifier(&vlan_notifier_block); + if (err < 0) + goto err2; + + err = vlan_gvrp_init(); + if (err < 0) + goto err3; + + err = vlan_mvrp_init(); + if (err < 0) + goto err4; + + err = vlan_netlink_init(); + if (err < 0) + goto err5; + + vlan_ioctl_set(vlan_ioctl_handler); + return 0; + +err5: + vlan_mvrp_uninit(); +err4: + 
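/* Editor's note (not part of the upstream patch): the errN labels in
 * vlan_proto_init() unwind, in reverse order, only the steps that had already
 * succeeded -- e.g. a vlan_netlink_init() failure enters at err5 and tears
 * down MVRP, GVRP, the netdevice notifier and the pernet subsystem, while a
 * vlan_mvrp_init() failure enters here at err4 and skips the MVRP teardown.
 */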
vlan_gvrp_uninit(); +err3: + unregister_netdevice_notifier(&vlan_notifier_block); +err2: + unregister_pernet_subsys(&vlan_net_ops); +err0: + return err; +} + +static void __exit vlan_cleanup_module(void) +{ + vlan_ioctl_set(NULL); + + vlan_netlink_fini(); + + unregister_netdevice_notifier(&vlan_notifier_block); + + unregister_pernet_subsys(&vlan_net_ops); + rcu_barrier(); /* Wait for completion of call_rcu()'s */ + + vlan_mvrp_uninit(); + vlan_gvrp_uninit(); +} + +module_init(vlan_proto_init); +module_exit(vlan_cleanup_module); + +MODULE_LICENSE("GPL"); +MODULE_VERSION(DRV_VERSION); diff --git a/net/8021q/vlan.h b/net/8021q/vlan.h new file mode 100644 index 000000000..5eaf38875 --- /dev/null +++ b/net/8021q/vlan.h @@ -0,0 +1,206 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __BEN_VLAN_802_1Q_INC__ +#define __BEN_VLAN_802_1Q_INC__ + +#include +#include +#include + +/* if this changes, algorithm will have to be reworked because this + * depends on completely exhausting the VLAN identifier space. Thus + * it gives constant time look-up, but in many cases it wastes memory. + */ +#define VLAN_GROUP_ARRAY_SPLIT_PARTS 8 +#define VLAN_GROUP_ARRAY_PART_LEN (VLAN_N_VID/VLAN_GROUP_ARRAY_SPLIT_PARTS) + +enum vlan_protos { + VLAN_PROTO_8021Q = 0, + VLAN_PROTO_8021AD, + VLAN_PROTO_NUM, +}; + +struct vlan_group { + unsigned int nr_vlan_devs; + struct hlist_node hlist; /* linked list */ + struct net_device **vlan_devices_arrays[VLAN_PROTO_NUM] + [VLAN_GROUP_ARRAY_SPLIT_PARTS]; +}; + +struct vlan_info { + struct net_device *real_dev; /* The ethernet(like) device + * the vlan is attached to. + */ + struct vlan_group grp; + struct list_head vid_list; + unsigned int nr_vids; + struct rcu_head rcu; +}; + +static inline int vlan_proto_idx(__be16 proto) +{ + switch (proto) { + case htons(ETH_P_8021Q): + return VLAN_PROTO_8021Q; + case htons(ETH_P_8021AD): + return VLAN_PROTO_8021AD; + default: + WARN(1, "invalid VLAN protocol: 0x%04x\n", ntohs(proto)); + return -EINVAL; + } +} + +static inline struct net_device *__vlan_group_get_device(struct vlan_group *vg, + unsigned int pidx, + u16 vlan_id) +{ + struct net_device **array; + + array = vg->vlan_devices_arrays[pidx] + [vlan_id / VLAN_GROUP_ARRAY_PART_LEN]; + + /* paired with smp_wmb() in vlan_group_prealloc_vid() */ + smp_rmb(); + + return array ? array[vlan_id % VLAN_GROUP_ARRAY_PART_LEN] : NULL; +} + +static inline struct net_device *vlan_group_get_device(struct vlan_group *vg, + __be16 vlan_proto, + u16 vlan_id) +{ + int pidx = vlan_proto_idx(vlan_proto); + + if (pidx < 0) + return NULL; + + return __vlan_group_get_device(vg, pidx, vlan_id); +} + +static inline void vlan_group_set_device(struct vlan_group *vg, + __be16 vlan_proto, u16 vlan_id, + struct net_device *dev) +{ + int pidx = vlan_proto_idx(vlan_proto); + struct net_device **array; + + if (!vg || pidx < 0) + return; + array = vg->vlan_devices_arrays[pidx] + [vlan_id / VLAN_GROUP_ARRAY_PART_LEN]; + array[vlan_id % VLAN_GROUP_ARRAY_PART_LEN] = dev; +} + +/* Must be invoked with rcu_read_lock or with RTNL. 
*/ +static inline struct net_device *vlan_find_dev(struct net_device *real_dev, + __be16 vlan_proto, u16 vlan_id) +{ + struct vlan_info *vlan_info = rcu_dereference_rtnl(real_dev->vlan_info); + + if (vlan_info) + return vlan_group_get_device(&vlan_info->grp, + vlan_proto, vlan_id); + + return NULL; +} + +static inline netdev_features_t vlan_tnl_features(struct net_device *real_dev) +{ + netdev_features_t ret; + + ret = real_dev->hw_enc_features & + (NETIF_F_CSUM_MASK | NETIF_F_GSO_SOFTWARE | + NETIF_F_GSO_ENCAP_ALL); + + if ((ret & NETIF_F_GSO_ENCAP_ALL) && (ret & NETIF_F_CSUM_MASK)) + return (ret & ~NETIF_F_CSUM_MASK) | NETIF_F_HW_CSUM; + return 0; +} + +#define vlan_group_for_each_dev(grp, i, dev) \ + for ((i) = 0; i < VLAN_PROTO_NUM * VLAN_N_VID; i++) \ + if (((dev) = __vlan_group_get_device((grp), (i) / VLAN_N_VID, \ + (i) % VLAN_N_VID))) + +int vlan_filter_push_vids(struct vlan_info *vlan_info, __be16 proto); +void vlan_filter_drop_vids(struct vlan_info *vlan_info, __be16 proto); + +/* found in vlan_dev.c */ +void vlan_dev_set_ingress_priority(const struct net_device *dev, + u32 skb_prio, u16 vlan_prio); +int vlan_dev_set_egress_priority(const struct net_device *dev, + u32 skb_prio, u16 vlan_prio); +void vlan_dev_free_egress_priority(const struct net_device *dev); +int vlan_dev_change_flags(const struct net_device *dev, u32 flag, u32 mask); +void vlan_dev_get_realdev_name(const struct net_device *dev, char *result, + size_t size); + +int vlan_check_real_dev(struct net_device *real_dev, + __be16 protocol, u16 vlan_id, + struct netlink_ext_ack *extack); +void vlan_setup(struct net_device *dev); +int register_vlan_dev(struct net_device *dev, struct netlink_ext_ack *extack); +void unregister_vlan_dev(struct net_device *dev, struct list_head *head); +bool vlan_dev_inherit_address(struct net_device *dev, + struct net_device *real_dev); + +static inline u32 vlan_get_ingress_priority(struct net_device *dev, + u16 vlan_tci) +{ + struct vlan_dev_priv *vip = vlan_dev_priv(dev); + + return vip->ingress_priority_map[(vlan_tci >> VLAN_PRIO_SHIFT) & 0x7]; +} + +#ifdef CONFIG_VLAN_8021Q_GVRP +int vlan_gvrp_request_join(const struct net_device *dev); +void vlan_gvrp_request_leave(const struct net_device *dev); +int vlan_gvrp_init_applicant(struct net_device *dev); +void vlan_gvrp_uninit_applicant(struct net_device *dev); +int vlan_gvrp_init(void); +void vlan_gvrp_uninit(void); +#else +static inline int vlan_gvrp_request_join(const struct net_device *dev) { return 0; } +static inline void vlan_gvrp_request_leave(const struct net_device *dev) {} +static inline int vlan_gvrp_init_applicant(struct net_device *dev) { return 0; } +static inline void vlan_gvrp_uninit_applicant(struct net_device *dev) {} +static inline int vlan_gvrp_init(void) { return 0; } +static inline void vlan_gvrp_uninit(void) {} +#endif + +#ifdef CONFIG_VLAN_8021Q_MVRP +int vlan_mvrp_request_join(const struct net_device *dev); +void vlan_mvrp_request_leave(const struct net_device *dev); +int vlan_mvrp_init_applicant(struct net_device *dev); +void vlan_mvrp_uninit_applicant(struct net_device *dev); +int vlan_mvrp_init(void); +void vlan_mvrp_uninit(void); +#else +static inline int vlan_mvrp_request_join(const struct net_device *dev) { return 0; } +static inline void vlan_mvrp_request_leave(const struct net_device *dev) {} +static inline int vlan_mvrp_init_applicant(struct net_device *dev) { return 0; } +static inline void vlan_mvrp_uninit_applicant(struct net_device *dev) {} +static inline int vlan_mvrp_init(void) { return 0; } +static 
inline void vlan_mvrp_uninit(void) {} +#endif + +extern const char vlan_fullname[]; +extern const char vlan_version[]; +int vlan_netlink_init(void); +void vlan_netlink_fini(void); + +extern struct rtnl_link_ops vlan_link_ops; + +extern unsigned int vlan_net_id; + +struct proc_dir_entry; + +struct vlan_net { + /* /proc/net/vlan */ + struct proc_dir_entry *proc_vlan_dir; + /* /proc/net/vlan/config */ + struct proc_dir_entry *proc_vlan_conf; + /* Determines interface naming scheme. */ + unsigned short name_type; +}; + +#endif /* !(__BEN_VLAN_802_1Q_INC__) */ diff --git a/net/8021q/vlan_core.c b/net/8021q/vlan_core.c new file mode 100644 index 000000000..f00158234 --- /dev/null +++ b/net/8021q/vlan_core.c @@ -0,0 +1,558 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include "vlan.h" + +bool vlan_do_receive(struct sk_buff **skbp) +{ + struct sk_buff *skb = *skbp; + __be16 vlan_proto = skb->vlan_proto; + u16 vlan_id = skb_vlan_tag_get_id(skb); + struct net_device *vlan_dev; + struct vlan_pcpu_stats *rx_stats; + + vlan_dev = vlan_find_dev(skb->dev, vlan_proto, vlan_id); + if (!vlan_dev) + return false; + + skb = *skbp = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + return false; + + if (unlikely(!(vlan_dev->flags & IFF_UP))) { + kfree_skb(skb); + *skbp = NULL; + return false; + } + + skb->dev = vlan_dev; + if (unlikely(skb->pkt_type == PACKET_OTHERHOST)) { + /* Our lower layer thinks this is not local, let's make sure. + * This allows the VLAN to have a different MAC than the + * underlying device, and still route correctly. */ + if (ether_addr_equal_64bits(eth_hdr(skb)->h_dest, vlan_dev->dev_addr)) + skb->pkt_type = PACKET_HOST; + } + + if (!(vlan_dev_priv(vlan_dev)->flags & VLAN_FLAG_REORDER_HDR) && + !netif_is_macvlan_port(vlan_dev) && + !netif_is_bridge_port(vlan_dev)) { + unsigned int offset = skb->data - skb_mac_header(skb); + + /* + * vlan_insert_tag expect skb->data pointing to mac header. + * So change skb->data before calling it and change back to + * original position later + */ + skb_push(skb, offset); + skb = *skbp = vlan_insert_inner_tag(skb, skb->vlan_proto, + skb->vlan_tci, skb->mac_len); + if (!skb) + return false; + skb_pull(skb, offset + VLAN_HLEN); + skb_reset_mac_len(skb); + } + + skb->priority = vlan_get_ingress_priority(vlan_dev, skb->vlan_tci); + __vlan_hwaccel_clear_tag(skb); + + rx_stats = this_cpu_ptr(vlan_dev_priv(vlan_dev)->vlan_pcpu_stats); + + u64_stats_update_begin(&rx_stats->syncp); + u64_stats_inc(&rx_stats->rx_packets); + u64_stats_add(&rx_stats->rx_bytes, skb->len); + if (skb->pkt_type == PACKET_MULTICAST) + u64_stats_inc(&rx_stats->rx_multicast); + u64_stats_update_end(&rx_stats->syncp); + + return true; +} + +/* Must be invoked with rcu_read_lock. */ +struct net_device *__vlan_find_dev_deep_rcu(struct net_device *dev, + __be16 vlan_proto, u16 vlan_id) +{ + struct vlan_info *vlan_info = rcu_dereference(dev->vlan_info); + + if (vlan_info) { + return vlan_group_get_device(&vlan_info->grp, + vlan_proto, vlan_id); + } else { + /* + * Lower devices of master uppers (bonding, team) do not have + * grp assigned to themselves. Grp is assigned to upper device + * instead. 
+ */ + struct net_device *upper_dev; + + upper_dev = netdev_master_upper_dev_get_rcu(dev); + if (upper_dev) + return __vlan_find_dev_deep_rcu(upper_dev, + vlan_proto, vlan_id); + } + + return NULL; +} +EXPORT_SYMBOL(__vlan_find_dev_deep_rcu); + +struct net_device *vlan_dev_real_dev(const struct net_device *dev) +{ + struct net_device *ret = vlan_dev_priv(dev)->real_dev; + + while (is_vlan_dev(ret)) + ret = vlan_dev_priv(ret)->real_dev; + + return ret; +} +EXPORT_SYMBOL(vlan_dev_real_dev); + +u16 vlan_dev_vlan_id(const struct net_device *dev) +{ + return vlan_dev_priv(dev)->vlan_id; +} +EXPORT_SYMBOL(vlan_dev_vlan_id); + +__be16 vlan_dev_vlan_proto(const struct net_device *dev) +{ + return vlan_dev_priv(dev)->vlan_proto; +} +EXPORT_SYMBOL(vlan_dev_vlan_proto); + +/* + * vlan info and vid list + */ + +static void vlan_group_free(struct vlan_group *grp) +{ + int i, j; + + for (i = 0; i < VLAN_PROTO_NUM; i++) + for (j = 0; j < VLAN_GROUP_ARRAY_SPLIT_PARTS; j++) + kfree(grp->vlan_devices_arrays[i][j]); +} + +static void vlan_info_free(struct vlan_info *vlan_info) +{ + vlan_group_free(&vlan_info->grp); + kfree(vlan_info); +} + +static void vlan_info_rcu_free(struct rcu_head *rcu) +{ + vlan_info_free(container_of(rcu, struct vlan_info, rcu)); +} + +static struct vlan_info *vlan_info_alloc(struct net_device *dev) +{ + struct vlan_info *vlan_info; + + vlan_info = kzalloc(sizeof(struct vlan_info), GFP_KERNEL); + if (!vlan_info) + return NULL; + + vlan_info->real_dev = dev; + INIT_LIST_HEAD(&vlan_info->vid_list); + return vlan_info; +} + +struct vlan_vid_info { + struct list_head list; + __be16 proto; + u16 vid; + int refcount; +}; + +static bool vlan_hw_filter_capable(const struct net_device *dev, __be16 proto) +{ + if (proto == htons(ETH_P_8021Q) && + dev->features & NETIF_F_HW_VLAN_CTAG_FILTER) + return true; + if (proto == htons(ETH_P_8021AD) && + dev->features & NETIF_F_HW_VLAN_STAG_FILTER) + return true; + return false; +} + +static struct vlan_vid_info *vlan_vid_info_get(struct vlan_info *vlan_info, + __be16 proto, u16 vid) +{ + struct vlan_vid_info *vid_info; + + list_for_each_entry(vid_info, &vlan_info->vid_list, list) { + if (vid_info->proto == proto && vid_info->vid == vid) + return vid_info; + } + return NULL; +} + +static struct vlan_vid_info *vlan_vid_info_alloc(__be16 proto, u16 vid) +{ + struct vlan_vid_info *vid_info; + + vid_info = kzalloc(sizeof(struct vlan_vid_info), GFP_KERNEL); + if (!vid_info) + return NULL; + vid_info->proto = proto; + vid_info->vid = vid; + + return vid_info; +} + +static int vlan_add_rx_filter_info(struct net_device *dev, __be16 proto, u16 vid) +{ + if (!vlan_hw_filter_capable(dev, proto)) + return 0; + + if (netif_device_present(dev)) + return dev->netdev_ops->ndo_vlan_rx_add_vid(dev, proto, vid); + else + return -ENODEV; +} + +static int vlan_kill_rx_filter_info(struct net_device *dev, __be16 proto, u16 vid) +{ + if (!vlan_hw_filter_capable(dev, proto)) + return 0; + + if (netif_device_present(dev)) + return dev->netdev_ops->ndo_vlan_rx_kill_vid(dev, proto, vid); + else + return -ENODEV; +} + +int vlan_for_each(struct net_device *dev, + int (*action)(struct net_device *dev, int vid, void *arg), + void *arg) +{ + struct vlan_vid_info *vid_info; + struct vlan_info *vlan_info; + struct net_device *vdev; + int ret; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(dev->vlan_info); + if (!vlan_info) + return 0; + + list_for_each_entry(vid_info, &vlan_info->vid_list, list) { + vdev = vlan_group_get_device(&vlan_info->grp, vid_info->proto, + vid_info->vid); + 
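/* Editor's note (not part of the upstream patch): vid_list holds one entry per
 * proto/vid pair added through vlan_vid_add(), so the callback runs once per
 * tracked VID; vdev is the 8021q device looked up from the group for that pair
 * and may be NULL when the VID was added for hardware filtering only (e.g.
 * VID 0). A hypothetical caller might look like:
 *
 *	static int dump_vid(struct net_device *vdev, int vid, void *arg)
 *	{
 *		pr_info("vid %d -> %s\n", vid, vdev ? vdev->name : "(none)");
 *		return 0;
 *	}
 *	...
 *	vlan_for_each(real_dev, dump_vid, NULL);
 */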
ret = action(vdev, vid_info->vid, arg); + if (ret) + return ret; + } + + return 0; +} +EXPORT_SYMBOL(vlan_for_each); + +int vlan_filter_push_vids(struct vlan_info *vlan_info, __be16 proto) +{ + struct net_device *real_dev = vlan_info->real_dev; + struct vlan_vid_info *vlan_vid_info; + int err; + + list_for_each_entry(vlan_vid_info, &vlan_info->vid_list, list) { + if (vlan_vid_info->proto == proto) { + err = vlan_add_rx_filter_info(real_dev, proto, + vlan_vid_info->vid); + if (err) + goto unwind; + } + } + + return 0; + +unwind: + list_for_each_entry_continue_reverse(vlan_vid_info, + &vlan_info->vid_list, list) { + if (vlan_vid_info->proto == proto) + vlan_kill_rx_filter_info(real_dev, proto, + vlan_vid_info->vid); + } + + return err; +} +EXPORT_SYMBOL(vlan_filter_push_vids); + +void vlan_filter_drop_vids(struct vlan_info *vlan_info, __be16 proto) +{ + struct vlan_vid_info *vlan_vid_info; + + list_for_each_entry(vlan_vid_info, &vlan_info->vid_list, list) + if (vlan_vid_info->proto == proto) + vlan_kill_rx_filter_info(vlan_info->real_dev, + vlan_vid_info->proto, + vlan_vid_info->vid); +} +EXPORT_SYMBOL(vlan_filter_drop_vids); + +static int __vlan_vid_add(struct vlan_info *vlan_info, __be16 proto, u16 vid, + struct vlan_vid_info **pvid_info) +{ + struct net_device *dev = vlan_info->real_dev; + struct vlan_vid_info *vid_info; + int err; + + vid_info = vlan_vid_info_alloc(proto, vid); + if (!vid_info) + return -ENOMEM; + + err = vlan_add_rx_filter_info(dev, proto, vid); + if (err) { + kfree(vid_info); + return err; + } + + list_add(&vid_info->list, &vlan_info->vid_list); + vlan_info->nr_vids++; + *pvid_info = vid_info; + return 0; +} + +int vlan_vid_add(struct net_device *dev, __be16 proto, u16 vid) +{ + struct vlan_info *vlan_info; + struct vlan_vid_info *vid_info; + bool vlan_info_created = false; + int err; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(dev->vlan_info); + if (!vlan_info) { + vlan_info = vlan_info_alloc(dev); + if (!vlan_info) + return -ENOMEM; + vlan_info_created = true; + } + vid_info = vlan_vid_info_get(vlan_info, proto, vid); + if (!vid_info) { + err = __vlan_vid_add(vlan_info, proto, vid, &vid_info); + if (err) + goto out_free_vlan_info; + } + vid_info->refcount++; + + if (vlan_info_created) + rcu_assign_pointer(dev->vlan_info, vlan_info); + + return 0; + +out_free_vlan_info: + if (vlan_info_created) + kfree(vlan_info); + return err; +} +EXPORT_SYMBOL(vlan_vid_add); + +static void __vlan_vid_del(struct vlan_info *vlan_info, + struct vlan_vid_info *vid_info) +{ + struct net_device *dev = vlan_info->real_dev; + __be16 proto = vid_info->proto; + u16 vid = vid_info->vid; + int err; + + err = vlan_kill_rx_filter_info(dev, proto, vid); + if (err && dev->reg_state != NETREG_UNREGISTERING) + netdev_warn(dev, "failed to kill vid %04x/%d\n", proto, vid); + + list_del(&vid_info->list); + kfree(vid_info); + vlan_info->nr_vids--; +} + +void vlan_vid_del(struct net_device *dev, __be16 proto, u16 vid) +{ + struct vlan_info *vlan_info; + struct vlan_vid_info *vid_info; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(dev->vlan_info); + if (!vlan_info) + return; + + vid_info = vlan_vid_info_get(vlan_info, proto, vid); + if (!vid_info) + return; + vid_info->refcount--; + if (vid_info->refcount == 0) { + __vlan_vid_del(vlan_info, vid_info); + if (vlan_info->nr_vids == 0) { + RCU_INIT_POINTER(dev->vlan_info, NULL); + call_rcu(&vlan_info->rcu, vlan_info_rcu_free); + } + } +} +EXPORT_SYMBOL(vlan_vid_del); + +int vlan_vids_add_by_dev(struct net_device *dev, + const struct net_device 
*by_dev) +{ + struct vlan_vid_info *vid_info; + struct vlan_info *vlan_info; + int err; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(by_dev->vlan_info); + if (!vlan_info) + return 0; + + list_for_each_entry(vid_info, &vlan_info->vid_list, list) { + if (!vlan_hw_filter_capable(by_dev, vid_info->proto)) + continue; + err = vlan_vid_add(dev, vid_info->proto, vid_info->vid); + if (err) + goto unwind; + } + return 0; + +unwind: + list_for_each_entry_continue_reverse(vid_info, + &vlan_info->vid_list, + list) { + if (!vlan_hw_filter_capable(by_dev, vid_info->proto)) + continue; + vlan_vid_del(dev, vid_info->proto, vid_info->vid); + } + + return err; +} +EXPORT_SYMBOL(vlan_vids_add_by_dev); + +void vlan_vids_del_by_dev(struct net_device *dev, + const struct net_device *by_dev) +{ + struct vlan_vid_info *vid_info; + struct vlan_info *vlan_info; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(by_dev->vlan_info); + if (!vlan_info) + return; + + list_for_each_entry(vid_info, &vlan_info->vid_list, list) { + if (!vlan_hw_filter_capable(by_dev, vid_info->proto)) + continue; + vlan_vid_del(dev, vid_info->proto, vid_info->vid); + } +} +EXPORT_SYMBOL(vlan_vids_del_by_dev); + +bool vlan_uses_dev(const struct net_device *dev) +{ + struct vlan_info *vlan_info; + + ASSERT_RTNL(); + + vlan_info = rtnl_dereference(dev->vlan_info); + if (!vlan_info) + return false; + return vlan_info->grp.nr_vlan_devs ? true : false; +} +EXPORT_SYMBOL(vlan_uses_dev); + +static struct sk_buff *vlan_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + const struct packet_offload *ptype; + unsigned int hlen, off_vlan; + struct sk_buff *pp = NULL; + struct vlan_hdr *vhdr; + struct sk_buff *p; + __be16 type; + int flush = 1; + + off_vlan = skb_gro_offset(skb); + hlen = off_vlan + sizeof(*vhdr); + vhdr = skb_gro_header(skb, hlen, off_vlan); + if (unlikely(!vhdr)) + goto out; + + type = vhdr->h_vlan_encapsulated_proto; + + ptype = gro_find_receive_by_type(type); + if (!ptype) + goto out; + + flush = 0; + + list_for_each_entry(p, head, list) { + struct vlan_hdr *vhdr2; + + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + vhdr2 = (struct vlan_hdr *)(p->data + off_vlan); + if (compare_vlan_header(vhdr, vhdr2)) + NAPI_GRO_CB(p)->same_flow = 0; + } + + skb_gro_pull(skb, sizeof(*vhdr)); + skb_gro_postpull_rcsum(skb, vhdr, sizeof(*vhdr)); + + pp = indirect_call_gro_receive_inet(ptype->callbacks.gro_receive, + ipv6_gro_receive, inet_gro_receive, + head, skb); + +out: + skb_gro_flush_final(skb, pp, flush); + + return pp; +} + +static int vlan_gro_complete(struct sk_buff *skb, int nhoff) +{ + struct vlan_hdr *vhdr = (struct vlan_hdr *)(skb->data + nhoff); + __be16 type = vhdr->h_vlan_encapsulated_proto; + struct packet_offload *ptype; + int err = -ENOENT; + + ptype = gro_find_complete_by_type(type); + if (ptype) + err = INDIRECT_CALL_INET(ptype->callbacks.gro_complete, + ipv6_gro_complete, inet_gro_complete, + skb, nhoff + sizeof(*vhdr)); + + return err; +} + +static struct packet_offload vlan_packet_offloads[] __read_mostly = { + { + .type = cpu_to_be16(ETH_P_8021Q), + .priority = 10, + .callbacks = { + .gro_receive = vlan_gro_receive, + .gro_complete = vlan_gro_complete, + }, + }, + { + .type = cpu_to_be16(ETH_P_8021AD), + .priority = 10, + .callbacks = { + .gro_receive = vlan_gro_receive, + .gro_complete = vlan_gro_complete, + }, + }, +}; + +static int __init vlan_offload_init(void) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(vlan_packet_offloads); i++) + dev_add_offload(&vlan_packet_offloads[i]); + + return 0; +} + 
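/* Editor's note (not part of the upstream patch): fs_initcall() below arranges
 * for vlan_offload_init() to run at init time, registering the GRO
 * receive/complete callbacks above for both the 802.1Q (0x8100) and 802.1ad
 * (0x88a8) ethertypes at offload priority 10, so VLAN-tagged flows can be
 * coalesced by GRO.
 */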
+fs_initcall(vlan_offload_init); diff --git a/net/8021q/vlan_dev.c b/net/8021q/vlan_dev.c new file mode 100644 index 000000000..d3e511e1e --- /dev/null +++ b/net/8021q/vlan_dev.c @@ -0,0 +1,876 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* -*- linux-c -*- + * INET 802.1Q VLAN + * Ethernet-type device handling. + * + * Authors: Ben Greear + * Please send support related email to: netdev@vger.kernel.org + * VLAN Home Page: http://www.candelatech.com/~greear/vlan.html + * + * Fixes: Mar 22 2001: Martin Bokaemper + * - reset skb->pkt_type on incoming packets when MAC was changed + * - see that changed MAC is saddr for outgoing packets + * Oct 20, 2001: Ard van Breeman: + * - Fix MC-list, finally. + * - Flush MC-list on VLAN destroy. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "vlan.h" +#include "vlanproc.h" +#include +#include + +/* + * Create the VLAN header for an arbitrary protocol layer + * + * saddr=NULL means use device source address + * daddr=NULL means leave destination address (eg unresolved arp) + * + * This is called when the SKB is moving down the stack towards the + * physical devices. + */ +static int vlan_dev_hard_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, + unsigned int len) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct vlan_hdr *vhdr; + unsigned int vhdrlen = 0; + u16 vlan_tci = 0; + int rc; + + if (!(vlan->flags & VLAN_FLAG_REORDER_HDR)) { + vhdr = skb_push(skb, VLAN_HLEN); + + vlan_tci = vlan->vlan_id; + vlan_tci |= vlan_dev_get_egress_qos_mask(dev, skb->priority); + vhdr->h_vlan_TCI = htons(vlan_tci); + + /* + * Set the protocol type. For a packet of type ETH_P_802_3/2 we + * put the length in here instead. + */ + if (type != ETH_P_802_3 && type != ETH_P_802_2) + vhdr->h_vlan_encapsulated_proto = htons(type); + else + vhdr->h_vlan_encapsulated_proto = htons(len); + + skb->protocol = vlan->vlan_proto; + type = ntohs(vlan->vlan_proto); + vhdrlen = VLAN_HLEN; + } + + /* Before delegating work to the lower layer, enter our MAC-address */ + if (saddr == NULL) + saddr = dev->dev_addr; + + /* Now make the underlying real hard header */ + dev = vlan->real_dev; + rc = dev_hard_header(skb, dev, type, daddr, saddr, len + vhdrlen); + if (rc > 0) + rc += vhdrlen; + return rc; +} + +static inline netdev_tx_t vlan_netpoll_send_skb(struct vlan_dev_priv *vlan, struct sk_buff *skb) +{ +#ifdef CONFIG_NET_POLL_CONTROLLER + return netpoll_send_skb(vlan->netpoll, skb); +#else + BUG(); + return NETDEV_TX_OK; +#endif +} + +static netdev_tx_t vlan_dev_hard_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct vlan_ethhdr *veth = (struct vlan_ethhdr *)(skb->data); + unsigned int len; + int ret; + + /* Handle non-VLAN frames if they are sent to us, for example by DHCP. + * + * NOTE: THIS ASSUMES DIX ETHERNET, SPECIFICALLY NOT SUPPORTING + * OTHER THINGS LIKE FDDI/TokenRing/802.3 SNAPs... 
+ */ + if (vlan->flags & VLAN_FLAG_REORDER_HDR || + veth->h_vlan_proto != vlan->vlan_proto) { + u16 vlan_tci; + vlan_tci = vlan->vlan_id; + vlan_tci |= vlan_dev_get_egress_qos_mask(dev, skb->priority); + __vlan_hwaccel_put_tag(skb, vlan->vlan_proto, vlan_tci); + } + + skb->dev = vlan->real_dev; + len = skb->len; + if (unlikely(netpoll_tx_running(dev))) + return vlan_netpoll_send_skb(vlan, skb); + + ret = dev_queue_xmit(skb); + + if (likely(ret == NET_XMIT_SUCCESS || ret == NET_XMIT_CN)) { + struct vlan_pcpu_stats *stats; + + stats = this_cpu_ptr(vlan->vlan_pcpu_stats); + u64_stats_update_begin(&stats->syncp); + u64_stats_inc(&stats->tx_packets); + u64_stats_add(&stats->tx_bytes, len); + u64_stats_update_end(&stats->syncp); + } else { + this_cpu_inc(vlan->vlan_pcpu_stats->tx_dropped); + } + + return ret; +} + +static int vlan_dev_change_mtu(struct net_device *dev, int new_mtu) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + unsigned int max_mtu = real_dev->mtu; + + if (netif_reduces_vlan_mtu(real_dev)) + max_mtu -= VLAN_HLEN; + if (max_mtu < new_mtu) + return -ERANGE; + + dev->mtu = new_mtu; + + return 0; +} + +void vlan_dev_set_ingress_priority(const struct net_device *dev, + u32 skb_prio, u16 vlan_prio) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + + if (vlan->ingress_priority_map[vlan_prio & 0x7] && !skb_prio) + vlan->nr_ingress_mappings--; + else if (!vlan->ingress_priority_map[vlan_prio & 0x7] && skb_prio) + vlan->nr_ingress_mappings++; + + vlan->ingress_priority_map[vlan_prio & 0x7] = skb_prio; +} + +int vlan_dev_set_egress_priority(const struct net_device *dev, + u32 skb_prio, u16 vlan_prio) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct vlan_priority_tci_mapping *mp = NULL; + struct vlan_priority_tci_mapping *np; + u32 vlan_qos = (vlan_prio << VLAN_PRIO_SHIFT) & VLAN_PRIO_MASK; + + /* See if a priority mapping exists.. */ + mp = vlan->egress_priority_map[skb_prio & 0xF]; + while (mp) { + if (mp->priority == skb_prio) { + if (mp->vlan_qos && !vlan_qos) + vlan->nr_egress_mappings--; + else if (!mp->vlan_qos && vlan_qos) + vlan->nr_egress_mappings++; + mp->vlan_qos = vlan_qos; + return 0; + } + mp = mp->next; + } + + /* Create a new mapping then. */ + mp = vlan->egress_priority_map[skb_prio & 0xF]; + np = kmalloc(sizeof(struct vlan_priority_tci_mapping), GFP_KERNEL); + if (!np) + return -ENOBUFS; + + np->next = mp; + np->priority = skb_prio; + np->vlan_qos = vlan_qos; + /* Before inserting this element in hash table, make sure all its fields + * are committed to memory. + * coupled with smp_rmb() in vlan_dev_get_egress_qos_mask() + */ + smp_wmb(); + vlan->egress_priority_map[skb_prio & 0xF] = np; + if (vlan_qos) + vlan->nr_egress_mappings++; + return 0; +} + +/* Flags are defined in the vlan_flags enum in + * include/uapi/linux/if_vlan.h file. 
+ */ +int vlan_dev_change_flags(const struct net_device *dev, u32 flags, u32 mask) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + u32 old_flags = vlan->flags; + + if (mask & ~(VLAN_FLAG_REORDER_HDR | VLAN_FLAG_GVRP | + VLAN_FLAG_LOOSE_BINDING | VLAN_FLAG_MVRP | + VLAN_FLAG_BRIDGE_BINDING)) + return -EINVAL; + + vlan->flags = (old_flags & ~mask) | (flags & mask); + + if (netif_running(dev) && (vlan->flags ^ old_flags) & VLAN_FLAG_GVRP) { + if (vlan->flags & VLAN_FLAG_GVRP) + vlan_gvrp_request_join(dev); + else + vlan_gvrp_request_leave(dev); + } + + if (netif_running(dev) && (vlan->flags ^ old_flags) & VLAN_FLAG_MVRP) { + if (vlan->flags & VLAN_FLAG_MVRP) + vlan_mvrp_request_join(dev); + else + vlan_mvrp_request_leave(dev); + } + return 0; +} + +void vlan_dev_get_realdev_name(const struct net_device *dev, char *result, size_t size) +{ + strscpy_pad(result, vlan_dev_priv(dev)->real_dev->name, size); +} + +bool vlan_dev_inherit_address(struct net_device *dev, + struct net_device *real_dev) +{ + if (dev->addr_assign_type != NET_ADDR_STOLEN) + return false; + + eth_hw_addr_set(dev, real_dev->dev_addr); + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + return true; +} + +static int vlan_dev_open(struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + int err; + + if (!(real_dev->flags & IFF_UP) && + !(vlan->flags & VLAN_FLAG_LOOSE_BINDING)) + return -ENETDOWN; + + if (!ether_addr_equal(dev->dev_addr, real_dev->dev_addr) && + !vlan_dev_inherit_address(dev, real_dev)) { + err = dev_uc_add(real_dev, dev->dev_addr); + if (err < 0) + goto out; + } + + if (dev->flags & IFF_ALLMULTI) { + err = dev_set_allmulti(real_dev, 1); + if (err < 0) + goto del_unicast; + } + if (dev->flags & IFF_PROMISC) { + err = dev_set_promiscuity(real_dev, 1); + if (err < 0) + goto clear_allmulti; + } + + ether_addr_copy(vlan->real_dev_addr, real_dev->dev_addr); + + if (vlan->flags & VLAN_FLAG_GVRP) + vlan_gvrp_request_join(dev); + + if (vlan->flags & VLAN_FLAG_MVRP) + vlan_mvrp_request_join(dev); + + if (netif_carrier_ok(real_dev) && + !(vlan->flags & VLAN_FLAG_BRIDGE_BINDING)) + netif_carrier_on(dev); + return 0; + +clear_allmulti: + if (dev->flags & IFF_ALLMULTI) + dev_set_allmulti(real_dev, -1); +del_unicast: + if (!ether_addr_equal(dev->dev_addr, real_dev->dev_addr)) + dev_uc_del(real_dev, dev->dev_addr); +out: + netif_carrier_off(dev); + return err; +} + +static int vlan_dev_stop(struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + + dev_mc_unsync(real_dev, dev); + dev_uc_unsync(real_dev, dev); + if (dev->flags & IFF_ALLMULTI) + dev_set_allmulti(real_dev, -1); + if (dev->flags & IFF_PROMISC) + dev_set_promiscuity(real_dev, -1); + + if (!ether_addr_equal(dev->dev_addr, real_dev->dev_addr)) + dev_uc_del(real_dev, dev->dev_addr); + + if (!(vlan->flags & VLAN_FLAG_BRIDGE_BINDING)) + netif_carrier_off(dev); + return 0; +} + +static int vlan_dev_set_mac_address(struct net_device *dev, void *p) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + struct sockaddr *addr = p; + int err; + + if (!is_valid_ether_addr(addr->sa_data)) + return -EADDRNOTAVAIL; + + if (!(dev->flags & IFF_UP)) + goto out; + + if (!ether_addr_equal(addr->sa_data, real_dev->dev_addr)) { + err = dev_uc_add(real_dev, addr->sa_data); + if (err < 0) + return err; + } + + if (!ether_addr_equal(dev->dev_addr, real_dev->dev_addr)) + dev_uc_del(real_dev, dev->dev_addr); + +out: + 
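/* Editor's note (not part of the upstream patch): in vlan_dev_set_mac_address()
 * below, if the vlan device is down the unicast-filter updates are skipped and
 * only the new address is recorded at the out: label; vlan_dev_open() will
 * sync it into real_dev's unicast filter if needed when the device is brought
 * up. If the device is up, the new address has already been added to (and the
 * old one removed from) real_dev's unicast list, except where either address
 * matches real_dev's own MAC.
 */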
eth_hw_addr_set(dev, addr->sa_data); + return 0; +} + +static int vlan_dev_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + struct ifreq ifrr; + int err = -EOPNOTSUPP; + + strscpy_pad(ifrr.ifr_name, real_dev->name, IFNAMSIZ); + ifrr.ifr_ifru = ifr->ifr_ifru; + + switch (cmd) { + case SIOCSHWTSTAMP: + if (!net_eq(dev_net(dev), dev_net(real_dev))) + break; + fallthrough; + case SIOCGMIIPHY: + case SIOCGMIIREG: + case SIOCSMIIREG: + case SIOCGHWTSTAMP: + if (netif_device_present(real_dev) && ops->ndo_eth_ioctl) + err = ops->ndo_eth_ioctl(real_dev, &ifrr, cmd); + break; + } + + if (!err) + ifr->ifr_ifru = ifrr.ifr_ifru; + + return err; +} + +static int vlan_dev_neigh_setup(struct net_device *dev, struct neigh_parms *pa) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int err = 0; + + if (netif_device_present(real_dev) && ops->ndo_neigh_setup) + err = ops->ndo_neigh_setup(real_dev, pa); + + return err; +} + +#if IS_ENABLED(CONFIG_FCOE) +static int vlan_dev_fcoe_ddp_setup(struct net_device *dev, u16 xid, + struct scatterlist *sgl, unsigned int sgc) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int rc = 0; + + if (ops->ndo_fcoe_ddp_setup) + rc = ops->ndo_fcoe_ddp_setup(real_dev, xid, sgl, sgc); + + return rc; +} + +static int vlan_dev_fcoe_ddp_done(struct net_device *dev, u16 xid) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int len = 0; + + if (ops->ndo_fcoe_ddp_done) + len = ops->ndo_fcoe_ddp_done(real_dev, xid); + + return len; +} + +static int vlan_dev_fcoe_enable(struct net_device *dev) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int rc = -EINVAL; + + if (ops->ndo_fcoe_enable) + rc = ops->ndo_fcoe_enable(real_dev); + return rc; +} + +static int vlan_dev_fcoe_disable(struct net_device *dev) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int rc = -EINVAL; + + if (ops->ndo_fcoe_disable) + rc = ops->ndo_fcoe_disable(real_dev); + return rc; +} + +static int vlan_dev_fcoe_ddp_target(struct net_device *dev, u16 xid, + struct scatterlist *sgl, unsigned int sgc) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int rc = 0; + + if (ops->ndo_fcoe_ddp_target) + rc = ops->ndo_fcoe_ddp_target(real_dev, xid, sgl, sgc); + + return rc; +} +#endif + +#ifdef NETDEV_FCOE_WWNN +static int vlan_dev_fcoe_get_wwn(struct net_device *dev, u64 *wwn, int type) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + const struct net_device_ops *ops = real_dev->netdev_ops; + int rc = -EINVAL; + + if (ops->ndo_fcoe_get_wwn) + rc = ops->ndo_fcoe_get_wwn(real_dev, wwn, type); + return rc; +} +#endif + +static void vlan_dev_change_rx_flags(struct net_device *dev, int change) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + + if (dev->flags & IFF_UP) { + if (change & IFF_ALLMULTI) + dev_set_allmulti(real_dev, dev->flags & IFF_ALLMULTI ? 1 : -1); + if (change & IFF_PROMISC) + dev_set_promiscuity(real_dev, dev->flags & IFF_PROMISC ? 
1 : -1); + } +} + +static void vlan_dev_set_rx_mode(struct net_device *vlan_dev) +{ + dev_mc_sync(vlan_dev_priv(vlan_dev)->real_dev, vlan_dev); + dev_uc_sync(vlan_dev_priv(vlan_dev)->real_dev, vlan_dev); +} + +/* + * vlan network devices have devices nesting below it, and are a special + * "super class" of normal network devices; split their locks off into a + * separate class since they always nest. + */ +static struct lock_class_key vlan_netdev_xmit_lock_key; +static struct lock_class_key vlan_netdev_addr_lock_key; + +static void vlan_dev_set_lockdep_one(struct net_device *dev, + struct netdev_queue *txq, + void *unused) +{ + lockdep_set_class(&txq->_xmit_lock, &vlan_netdev_xmit_lock_key); +} + +static void vlan_dev_set_lockdep_class(struct net_device *dev) +{ + lockdep_set_class(&dev->addr_list_lock, + &vlan_netdev_addr_lock_key); + netdev_for_each_tx_queue(dev, vlan_dev_set_lockdep_one, NULL); +} + +static __be16 vlan_parse_protocol(const struct sk_buff *skb) +{ + struct vlan_ethhdr *veth = (struct vlan_ethhdr *)(skb->data); + + return __vlan_get_protocol(skb, veth->h_vlan_proto, NULL); +} + +static const struct header_ops vlan_header_ops = { + .create = vlan_dev_hard_header, + .parse = eth_header_parse, + .parse_protocol = vlan_parse_protocol, +}; + +static int vlan_passthru_hard_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, + unsigned int len) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + + if (saddr == NULL) + saddr = dev->dev_addr; + + return dev_hard_header(skb, real_dev, type, daddr, saddr, len); +} + +static const struct header_ops vlan_passthru_header_ops = { + .create = vlan_passthru_hard_header, + .parse = eth_header_parse, + .parse_protocol = vlan_parse_protocol, +}; + +static struct device_type vlan_type = { + .name = "vlan", +}; + +static const struct net_device_ops vlan_netdev_ops; + +static int vlan_dev_init(struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + + netif_carrier_off(dev); + + /* IFF_BROADCAST|IFF_MULTICAST; ??? */ + dev->flags = real_dev->flags & ~(IFF_UP | IFF_PROMISC | IFF_ALLMULTI | + IFF_MASTER | IFF_SLAVE); + dev->state = (real_dev->state & ((1<<__LINK_STATE_NOCARRIER) | + (1<<__LINK_STATE_DORMANT))) | + (1<<__LINK_STATE_PRESENT); + + if (vlan->flags & VLAN_FLAG_BRIDGE_BINDING) + dev->state |= (1 << __LINK_STATE_NOCARRIER); + + dev->hw_features = NETIF_F_HW_CSUM | NETIF_F_SG | + NETIF_F_FRAGLIST | NETIF_F_GSO_SOFTWARE | + NETIF_F_GSO_ENCAP_ALL | + NETIF_F_HIGHDMA | NETIF_F_SCTP_CRC | + NETIF_F_ALL_FCOE; + + dev->features |= dev->hw_features | NETIF_F_LLTX; + netif_inherit_tso_max(dev, real_dev); + if (dev->features & NETIF_F_VLAN_FEATURES) + netdev_warn(real_dev, "VLAN features are set incorrectly. 
Q-in-Q configurations may not work correctly.\n"); + + dev->vlan_features = real_dev->vlan_features & ~NETIF_F_ALL_FCOE; + dev->hw_enc_features = vlan_tnl_features(real_dev); + dev->mpls_features = real_dev->mpls_features; + + /* ipv6 shared card related stuff */ + dev->dev_id = real_dev->dev_id; + + if (is_zero_ether_addr(dev->dev_addr)) { + eth_hw_addr_set(dev, real_dev->dev_addr); + dev->addr_assign_type = NET_ADDR_STOLEN; + } + if (is_zero_ether_addr(dev->broadcast)) + memcpy(dev->broadcast, real_dev->broadcast, dev->addr_len); + +#if IS_ENABLED(CONFIG_FCOE) + dev->fcoe_ddp_xid = real_dev->fcoe_ddp_xid; +#endif + + dev->needed_headroom = real_dev->needed_headroom; + if (vlan_hw_offload_capable(real_dev->features, vlan->vlan_proto)) { + dev->header_ops = &vlan_passthru_header_ops; + dev->hard_header_len = real_dev->hard_header_len; + } else { + dev->header_ops = &vlan_header_ops; + dev->hard_header_len = real_dev->hard_header_len + VLAN_HLEN; + } + + dev->netdev_ops = &vlan_netdev_ops; + + SET_NETDEV_DEVTYPE(dev, &vlan_type); + + vlan_dev_set_lockdep_class(dev); + + vlan->vlan_pcpu_stats = netdev_alloc_pcpu_stats(struct vlan_pcpu_stats); + if (!vlan->vlan_pcpu_stats) + return -ENOMEM; + + /* Get vlan's reference to real_dev */ + netdev_hold(real_dev, &vlan->dev_tracker, GFP_KERNEL); + + return 0; +} + +/* Note: this function might be called multiple times for the same device. */ +void vlan_dev_free_egress_priority(const struct net_device *dev) +{ + struct vlan_priority_tci_mapping *pm; + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + int i; + + for (i = 0; i < ARRAY_SIZE(vlan->egress_priority_map); i++) { + while ((pm = vlan->egress_priority_map[i]) != NULL) { + vlan->egress_priority_map[i] = pm->next; + kfree(pm); + } + } +} + +static void vlan_dev_uninit(struct net_device *dev) +{ + vlan_dev_free_egress_priority(dev); +} + +static netdev_features_t vlan_dev_fix_features(struct net_device *dev, + netdev_features_t features) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + netdev_features_t old_features = features; + netdev_features_t lower_features; + + lower_features = netdev_intersect_features((real_dev->vlan_features | + NETIF_F_RXCSUM), + real_dev->features); + + /* Add HW_CSUM setting to preserve user ability to control + * checksum offload on the vlan device. 
+ */ + if (lower_features & (NETIF_F_IP_CSUM|NETIF_F_IPV6_CSUM)) + lower_features |= NETIF_F_HW_CSUM; + features = netdev_intersect_features(features, lower_features); + features |= old_features & (NETIF_F_SOFT_FEATURES | NETIF_F_GSO_SOFTWARE); + features |= NETIF_F_LLTX; + + return features; +} + +static int vlan_ethtool_get_link_ksettings(struct net_device *dev, + struct ethtool_link_ksettings *cmd) +{ + const struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + + return __ethtool_get_link_ksettings(vlan->real_dev, cmd); +} + +static void vlan_ethtool_get_drvinfo(struct net_device *dev, + struct ethtool_drvinfo *info) +{ + strscpy(info->driver, vlan_fullname, sizeof(info->driver)); + strscpy(info->version, vlan_version, sizeof(info->version)); + strscpy(info->fw_version, "N/A", sizeof(info->fw_version)); +} + +static int vlan_ethtool_get_ts_info(struct net_device *dev, + struct ethtool_ts_info *info) +{ + const struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + const struct ethtool_ops *ops = vlan->real_dev->ethtool_ops; + struct phy_device *phydev = vlan->real_dev->phydev; + + if (phy_has_tsinfo(phydev)) { + return phy_ts_info(phydev, info); + } else if (ops->get_ts_info) { + return ops->get_ts_info(vlan->real_dev, info); + } else { + info->so_timestamping = SOF_TIMESTAMPING_RX_SOFTWARE | + SOF_TIMESTAMPING_SOFTWARE; + info->phc_index = -1; + } + + return 0; +} + +static void vlan_dev_get_stats64(struct net_device *dev, + struct rtnl_link_stats64 *stats) +{ + struct vlan_pcpu_stats *p; + u32 rx_errors = 0, tx_dropped = 0; + int i; + + for_each_possible_cpu(i) { + u64 rxpackets, rxbytes, rxmulticast, txpackets, txbytes; + unsigned int start; + + p = per_cpu_ptr(vlan_dev_priv(dev)->vlan_pcpu_stats, i); + do { + start = u64_stats_fetch_begin_irq(&p->syncp); + rxpackets = u64_stats_read(&p->rx_packets); + rxbytes = u64_stats_read(&p->rx_bytes); + rxmulticast = u64_stats_read(&p->rx_multicast); + txpackets = u64_stats_read(&p->tx_packets); + txbytes = u64_stats_read(&p->tx_bytes); + } while (u64_stats_fetch_retry_irq(&p->syncp, start)); + + stats->rx_packets += rxpackets; + stats->rx_bytes += rxbytes; + stats->multicast += rxmulticast; + stats->tx_packets += txpackets; + stats->tx_bytes += txbytes; + /* rx_errors & tx_dropped are u32 */ + rx_errors += READ_ONCE(p->rx_errors); + tx_dropped += READ_ONCE(p->tx_dropped); + } + stats->rx_errors = rx_errors; + stats->tx_dropped = tx_dropped; +} + +#ifdef CONFIG_NET_POLL_CONTROLLER +static void vlan_dev_poll_controller(struct net_device *dev) +{ + return; +} + +static int vlan_dev_netpoll_setup(struct net_device *dev, struct netpoll_info *npinfo) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev = vlan->real_dev; + struct netpoll *netpoll; + int err = 0; + + netpoll = kzalloc(sizeof(*netpoll), GFP_KERNEL); + err = -ENOMEM; + if (!netpoll) + goto out; + + err = __netpoll_setup(netpoll, real_dev); + if (err) { + kfree(netpoll); + goto out; + } + + vlan->netpoll = netpoll; + +out: + return err; +} + +static void vlan_dev_netpoll_cleanup(struct net_device *dev) +{ + struct vlan_dev_priv *vlan= vlan_dev_priv(dev); + struct netpoll *netpoll = vlan->netpoll; + + if (!netpoll) + return; + + vlan->netpoll = NULL; + __netpoll_free(netpoll); +} +#endif /* CONFIG_NET_POLL_CONTROLLER */ + +static int vlan_dev_get_iflink(const struct net_device *dev) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + + return real_dev->ifindex; +} + +static int vlan_dev_fill_forward_path(struct net_device_path_ctx *ctx, + struct 
net_device_path *path) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(ctx->dev); + + path->type = DEV_PATH_VLAN; + path->encap.id = vlan->vlan_id; + path->encap.proto = vlan->vlan_proto; + path->dev = ctx->dev; + ctx->dev = vlan->real_dev; + if (ctx->num_vlans >= ARRAY_SIZE(ctx->vlan)) + return -ENOSPC; + + ctx->vlan[ctx->num_vlans].id = vlan->vlan_id; + ctx->vlan[ctx->num_vlans].proto = vlan->vlan_proto; + ctx->num_vlans++; + + return 0; +} + +static const struct ethtool_ops vlan_ethtool_ops = { + .get_link_ksettings = vlan_ethtool_get_link_ksettings, + .get_drvinfo = vlan_ethtool_get_drvinfo, + .get_link = ethtool_op_get_link, + .get_ts_info = vlan_ethtool_get_ts_info, +}; + +static const struct net_device_ops vlan_netdev_ops = { + .ndo_change_mtu = vlan_dev_change_mtu, + .ndo_init = vlan_dev_init, + .ndo_uninit = vlan_dev_uninit, + .ndo_open = vlan_dev_open, + .ndo_stop = vlan_dev_stop, + .ndo_start_xmit = vlan_dev_hard_start_xmit, + .ndo_validate_addr = eth_validate_addr, + .ndo_set_mac_address = vlan_dev_set_mac_address, + .ndo_set_rx_mode = vlan_dev_set_rx_mode, + .ndo_change_rx_flags = vlan_dev_change_rx_flags, + .ndo_eth_ioctl = vlan_dev_ioctl, + .ndo_neigh_setup = vlan_dev_neigh_setup, + .ndo_get_stats64 = vlan_dev_get_stats64, +#if IS_ENABLED(CONFIG_FCOE) + .ndo_fcoe_ddp_setup = vlan_dev_fcoe_ddp_setup, + .ndo_fcoe_ddp_done = vlan_dev_fcoe_ddp_done, + .ndo_fcoe_enable = vlan_dev_fcoe_enable, + .ndo_fcoe_disable = vlan_dev_fcoe_disable, + .ndo_fcoe_ddp_target = vlan_dev_fcoe_ddp_target, +#endif +#ifdef NETDEV_FCOE_WWNN + .ndo_fcoe_get_wwn = vlan_dev_fcoe_get_wwn, +#endif +#ifdef CONFIG_NET_POLL_CONTROLLER + .ndo_poll_controller = vlan_dev_poll_controller, + .ndo_netpoll_setup = vlan_dev_netpoll_setup, + .ndo_netpoll_cleanup = vlan_dev_netpoll_cleanup, +#endif + .ndo_fix_features = vlan_dev_fix_features, + .ndo_get_iflink = vlan_dev_get_iflink, + .ndo_fill_forward_path = vlan_dev_fill_forward_path, +}; + +static void vlan_dev_free(struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + + free_percpu(vlan->vlan_pcpu_stats); + vlan->vlan_pcpu_stats = NULL; + + /* Get rid of the vlan's reference to real_dev */ + netdev_put(vlan->real_dev, &vlan->dev_tracker); +} + +void vlan_setup(struct net_device *dev) +{ + ether_setup(dev); + + dev->priv_flags |= IFF_802_1Q_VLAN | IFF_NO_QUEUE; + dev->priv_flags |= IFF_UNICAST_FLT; + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + netif_keep_dst(dev); + + dev->netdev_ops = &vlan_netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = vlan_dev_free; + dev->ethtool_ops = &vlan_ethtool_ops; + + dev->min_mtu = 0; + dev->max_mtu = ETH_MAX_MTU; + + eth_zero_addr(dev->broadcast); +} diff --git a/net/8021q/vlan_gvrp.c b/net/8021q/vlan_gvrp.c new file mode 100644 index 000000000..6b34b72aa --- /dev/null +++ b/net/8021q/vlan_gvrp.c @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IEEE 802.1Q GARP VLAN Registration Protocol (GVRP) + * + * Copyright (c) 2008 Patrick McHardy + */ +#include +#include +#include +#include "vlan.h" + +#define GARP_GVRP_ADDRESS { 0x01, 0x80, 0xc2, 0x00, 0x00, 0x21 } + +enum gvrp_attributes { + GVRP_ATTR_INVALID, + GVRP_ATTR_VID, + __GVRP_ATTR_MAX +}; +#define GVRP_ATTR_MAX (__GVRP_ATTR_MAX - 1) + +static struct garp_application vlan_gvrp_app __read_mostly = { + .proto.group_address = GARP_GVRP_ADDRESS, + .maxattr = GVRP_ATTR_MAX, + .type = GARP_APPLICATION_GVRP, +}; + +int vlan_gvrp_request_join(const struct net_device *dev) +{ + const struct vlan_dev_priv *vlan = 
vlan_dev_priv(dev); + __be16 vlan_id = htons(vlan->vlan_id); + + if (vlan->vlan_proto != htons(ETH_P_8021Q)) + return 0; + return garp_request_join(vlan->real_dev, &vlan_gvrp_app, + &vlan_id, sizeof(vlan_id), GVRP_ATTR_VID); +} + +void vlan_gvrp_request_leave(const struct net_device *dev) +{ + const struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + __be16 vlan_id = htons(vlan->vlan_id); + + if (vlan->vlan_proto != htons(ETH_P_8021Q)) + return; + garp_request_leave(vlan->real_dev, &vlan_gvrp_app, + &vlan_id, sizeof(vlan_id), GVRP_ATTR_VID); +} + +int vlan_gvrp_init_applicant(struct net_device *dev) +{ + return garp_init_applicant(dev, &vlan_gvrp_app); +} + +void vlan_gvrp_uninit_applicant(struct net_device *dev) +{ + garp_uninit_applicant(dev, &vlan_gvrp_app); +} + +int __init vlan_gvrp_init(void) +{ + return garp_register_application(&vlan_gvrp_app); +} + +void vlan_gvrp_uninit(void) +{ + garp_unregister_application(&vlan_gvrp_app); +} diff --git a/net/8021q/vlan_mvrp.c b/net/8021q/vlan_mvrp.c new file mode 100644 index 000000000..689eceeaa --- /dev/null +++ b/net/8021q/vlan_mvrp.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IEEE 802.1Q Multiple VLAN Registration Protocol (MVRP) + * + * Copyright (c) 2012 Massachusetts Institute of Technology + * + * Adapted from code in net/8021q/vlan_gvrp.c + * Copyright (c) 2008 Patrick McHardy + */ +#include +#include +#include +#include +#include "vlan.h" + +#define MRP_MVRP_ADDRESS { 0x01, 0x80, 0xc2, 0x00, 0x00, 0x21 } + +enum mvrp_attributes { + MVRP_ATTR_INVALID, + MVRP_ATTR_VID, + __MVRP_ATTR_MAX +}; +#define MVRP_ATTR_MAX (__MVRP_ATTR_MAX - 1) + +static struct mrp_application vlan_mrp_app __read_mostly = { + .type = MRP_APPLICATION_MVRP, + .maxattr = MVRP_ATTR_MAX, + .pkttype.type = htons(ETH_P_MVRP), + .group_address = MRP_MVRP_ADDRESS, + .version = 0, +}; + +int vlan_mvrp_request_join(const struct net_device *dev) +{ + const struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + __be16 vlan_id = htons(vlan->vlan_id); + + if (vlan->vlan_proto != htons(ETH_P_8021Q)) + return 0; + return mrp_request_join(vlan->real_dev, &vlan_mrp_app, + &vlan_id, sizeof(vlan_id), MVRP_ATTR_VID); +} + +void vlan_mvrp_request_leave(const struct net_device *dev) +{ + const struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + __be16 vlan_id = htons(vlan->vlan_id); + + if (vlan->vlan_proto != htons(ETH_P_8021Q)) + return; + mrp_request_leave(vlan->real_dev, &vlan_mrp_app, + &vlan_id, sizeof(vlan_id), MVRP_ATTR_VID); +} + +int vlan_mvrp_init_applicant(struct net_device *dev) +{ + return mrp_init_applicant(dev, &vlan_mrp_app); +} + +void vlan_mvrp_uninit_applicant(struct net_device *dev) +{ + mrp_uninit_applicant(dev, &vlan_mrp_app); +} + +int __init vlan_mvrp_init(void) +{ + return mrp_register_application(&vlan_mrp_app); +} + +void vlan_mvrp_uninit(void) +{ + mrp_unregister_application(&vlan_mrp_app); +} diff --git a/net/8021q/vlan_netlink.c b/net/8021q/vlan_netlink.c new file mode 100644 index 000000000..a3b68243f --- /dev/null +++ b/net/8021q/vlan_netlink.c @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VLAN netlink control interface + * + * Copyright (c) 2007 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include "vlan.h" + + +static const struct nla_policy vlan_policy[IFLA_VLAN_MAX + 1] = { + [IFLA_VLAN_ID] = { .type = NLA_U16 }, + [IFLA_VLAN_FLAGS] = { .len = sizeof(struct ifla_vlan_flags) }, + [IFLA_VLAN_EGRESS_QOS] = { .type = NLA_NESTED }, + [IFLA_VLAN_INGRESS_QOS] = { .type 
= NLA_NESTED }, + [IFLA_VLAN_PROTOCOL] = { .type = NLA_U16 }, +}; + +static const struct nla_policy vlan_map_policy[IFLA_VLAN_QOS_MAX + 1] = { + [IFLA_VLAN_QOS_MAPPING] = { .len = sizeof(struct ifla_vlan_qos_mapping) }, +}; + + +static inline int vlan_validate_qos_map(struct nlattr *attr) +{ + if (!attr) + return 0; + return nla_validate_nested_deprecated(attr, IFLA_VLAN_QOS_MAX, + vlan_map_policy, NULL); +} + +static int vlan_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ifla_vlan_flags *flags; + u16 id; + int err; + + if (tb[IFLA_ADDRESS]) { + if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN) { + NL_SET_ERR_MSG_MOD(extack, "Invalid link address"); + return -EINVAL; + } + if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS]))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid link address"); + return -EADDRNOTAVAIL; + } + } + + if (!data) { + NL_SET_ERR_MSG_MOD(extack, "VLAN properties not specified"); + return -EINVAL; + } + + if (data[IFLA_VLAN_PROTOCOL]) { + switch (nla_get_be16(data[IFLA_VLAN_PROTOCOL])) { + case htons(ETH_P_8021Q): + case htons(ETH_P_8021AD): + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Invalid VLAN protocol"); + return -EPROTONOSUPPORT; + } + } + + if (data[IFLA_VLAN_ID]) { + id = nla_get_u16(data[IFLA_VLAN_ID]); + if (id >= VLAN_VID_MASK) { + NL_SET_ERR_MSG_MOD(extack, "Invalid VLAN id"); + return -ERANGE; + } + } + if (data[IFLA_VLAN_FLAGS]) { + flags = nla_data(data[IFLA_VLAN_FLAGS]); + if ((flags->flags & flags->mask) & + ~(VLAN_FLAG_REORDER_HDR | VLAN_FLAG_GVRP | + VLAN_FLAG_LOOSE_BINDING | VLAN_FLAG_MVRP | + VLAN_FLAG_BRIDGE_BINDING)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid VLAN flags"); + return -EINVAL; + } + } + + err = vlan_validate_qos_map(data[IFLA_VLAN_INGRESS_QOS]); + if (err < 0) { + NL_SET_ERR_MSG_MOD(extack, "Invalid ingress QOS map"); + return err; + } + err = vlan_validate_qos_map(data[IFLA_VLAN_EGRESS_QOS]); + if (err < 0) { + NL_SET_ERR_MSG_MOD(extack, "Invalid egress QOS map"); + return err; + } + return 0; +} + +static int vlan_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ifla_vlan_flags *flags; + struct ifla_vlan_qos_mapping *m; + struct nlattr *attr; + int rem, err; + + if (data[IFLA_VLAN_FLAGS]) { + flags = nla_data(data[IFLA_VLAN_FLAGS]); + err = vlan_dev_change_flags(dev, flags->flags, flags->mask); + if (err) + return err; + } + if (data[IFLA_VLAN_INGRESS_QOS]) { + nla_for_each_nested(attr, data[IFLA_VLAN_INGRESS_QOS], rem) { + if (nla_type(attr) != IFLA_VLAN_QOS_MAPPING) + continue; + m = nla_data(attr); + vlan_dev_set_ingress_priority(dev, m->to, m->from); + } + } + if (data[IFLA_VLAN_EGRESS_QOS]) { + nla_for_each_nested(attr, data[IFLA_VLAN_EGRESS_QOS], rem) { + if (nla_type(attr) != IFLA_VLAN_QOS_MAPPING) + continue; + m = nla_data(attr); + err = vlan_dev_set_egress_priority(dev, m->from, m->to); + if (err) + return err; + } + } + return 0; +} + +static int vlan_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct net_device *real_dev; + unsigned int max_mtu; + __be16 proto; + int err; + + if (!data[IFLA_VLAN_ID]) { + NL_SET_ERR_MSG_MOD(extack, "VLAN id not specified"); + return -EINVAL; + } + + if (!tb[IFLA_LINK]) { + NL_SET_ERR_MSG_MOD(extack, "link not specified"); + return -EINVAL; + } + + real_dev = __dev_get_by_index(src_net, nla_get_u32(tb[IFLA_LINK])); + if 
(!real_dev) { + NL_SET_ERR_MSG_MOD(extack, "link does not exist"); + return -ENODEV; + } + + if (data[IFLA_VLAN_PROTOCOL]) + proto = nla_get_be16(data[IFLA_VLAN_PROTOCOL]); + else + proto = htons(ETH_P_8021Q); + + vlan->vlan_proto = proto; + vlan->vlan_id = nla_get_u16(data[IFLA_VLAN_ID]); + vlan->real_dev = real_dev; + dev->priv_flags |= (real_dev->priv_flags & IFF_XMIT_DST_RELEASE); + vlan->flags = VLAN_FLAG_REORDER_HDR; + + err = vlan_check_real_dev(real_dev, vlan->vlan_proto, vlan->vlan_id, + extack); + if (err < 0) + return err; + + max_mtu = netif_reduces_vlan_mtu(real_dev) ? real_dev->mtu - VLAN_HLEN : + real_dev->mtu; + if (!tb[IFLA_MTU]) + dev->mtu = max_mtu; + else if (dev->mtu > max_mtu) + return -EINVAL; + + /* Note: If this initial vlan_changelink() fails, we need + * to call vlan_dev_free_egress_priority() to free memory. + */ + err = vlan_changelink(dev, tb, data, extack); + + if (!err) + err = register_vlan_dev(dev, extack); + + if (err) + vlan_dev_free_egress_priority(dev); + return err; +} + +static inline size_t vlan_qos_map_size(unsigned int n) +{ + if (n == 0) + return 0; + /* IFLA_VLAN_{EGRESS,INGRESS}_QOS + n * IFLA_VLAN_QOS_MAPPING */ + return nla_total_size(sizeof(struct nlattr)) + + nla_total_size(sizeof(struct ifla_vlan_qos_mapping)) * n; +} + +static size_t vlan_get_size(const struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + + return nla_total_size(2) + /* IFLA_VLAN_PROTOCOL */ + nla_total_size(2) + /* IFLA_VLAN_ID */ + nla_total_size(sizeof(struct ifla_vlan_flags)) + /* IFLA_VLAN_FLAGS */ + vlan_qos_map_size(vlan->nr_ingress_mappings) + + vlan_qos_map_size(vlan->nr_egress_mappings); +} + +static int vlan_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(dev); + struct vlan_priority_tci_mapping *pm; + struct ifla_vlan_flags f; + struct ifla_vlan_qos_mapping m; + struct nlattr *nest; + unsigned int i; + + if (nla_put_be16(skb, IFLA_VLAN_PROTOCOL, vlan->vlan_proto) || + nla_put_u16(skb, IFLA_VLAN_ID, vlan->vlan_id)) + goto nla_put_failure; + if (vlan->flags) { + f.flags = vlan->flags; + f.mask = ~0; + if (nla_put(skb, IFLA_VLAN_FLAGS, sizeof(f), &f)) + goto nla_put_failure; + } + if (vlan->nr_ingress_mappings) { + nest = nla_nest_start_noflag(skb, IFLA_VLAN_INGRESS_QOS); + if (nest == NULL) + goto nla_put_failure; + + for (i = 0; i < ARRAY_SIZE(vlan->ingress_priority_map); i++) { + if (!vlan->ingress_priority_map[i]) + continue; + + m.from = i; + m.to = vlan->ingress_priority_map[i]; + if (nla_put(skb, IFLA_VLAN_QOS_MAPPING, + sizeof(m), &m)) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + + if (vlan->nr_egress_mappings) { + nest = nla_nest_start_noflag(skb, IFLA_VLAN_EGRESS_QOS); + if (nest == NULL) + goto nla_put_failure; + + for (i = 0; i < ARRAY_SIZE(vlan->egress_priority_map); i++) { + for (pm = vlan->egress_priority_map[i]; pm; + pm = pm->next) { + if (!pm->vlan_qos) + continue; + + m.from = pm->priority; + m.to = (pm->vlan_qos >> 13) & 0x7; + if (nla_put(skb, IFLA_VLAN_QOS_MAPPING, + sizeof(m), &m)) + goto nla_put_failure; + } + } + nla_nest_end(skb, nest); + } + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static struct net *vlan_get_link_net(const struct net_device *dev) +{ + struct net_device *real_dev = vlan_dev_priv(dev)->real_dev; + + return dev_net(real_dev); +} + +struct rtnl_link_ops vlan_link_ops __read_mostly = { + .kind = "vlan", + .maxtype = IFLA_VLAN_MAX, + .policy = vlan_policy, + .priv_size = sizeof(struct vlan_dev_priv), + 
.setup = vlan_setup, + .validate = vlan_validate, + .newlink = vlan_newlink, + .changelink = vlan_changelink, + .dellink = unregister_vlan_dev, + .get_size = vlan_get_size, + .fill_info = vlan_fill_info, + .get_link_net = vlan_get_link_net, +}; + +int __init vlan_netlink_init(void) +{ + return rtnl_link_register(&vlan_link_ops); +} + +void __exit vlan_netlink_fini(void) +{ + rtnl_link_unregister(&vlan_link_ops); +} + +MODULE_ALIAS_RTNL_LINK("vlan"); diff --git a/net/8021q/vlanproc.c b/net/8021q/vlanproc.c new file mode 100644 index 000000000..7825c1297 --- /dev/null +++ b/net/8021q/vlanproc.c @@ -0,0 +1,291 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/****************************************************************************** + * vlanproc.c VLAN Module. /proc filesystem interface. + * + * This module is completely hardware-independent and provides + * access to the router using Linux /proc filesystem. + * + * Author: Ben Greear, coppied from wanproc.c + * by: Gene Kozin + * + * Copyright: (c) 1998 Ben Greear + * + * ============================================================================ + * Jan 20, 1998 Ben Greear Initial Version + *****************************************************************************/ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "vlanproc.h" +#include "vlan.h" + +/****** Function Prototypes *************************************************/ + +/* Methods for preparing data for reading proc entries */ +static int vlan_seq_show(struct seq_file *seq, void *v); +static void *vlan_seq_start(struct seq_file *seq, loff_t *pos); +static void *vlan_seq_next(struct seq_file *seq, void *v, loff_t *pos); +static void vlan_seq_stop(struct seq_file *seq, void *); +static int vlandev_seq_show(struct seq_file *seq, void *v); + +/* + * Global Data + */ + + +/* + * Names of the proc directory entries + */ + +static const char name_root[] = "vlan"; +static const char name_conf[] = "config"; + +/* + * Structures for interfacing with the /proc filesystem. + * VLAN creates its own directory /proc/net/vlan with the following + * entries: + * config device status/configuration + * entry for each device + */ + +/* + * Generic /proc/net/vlan/ file and inode operations + */ + +static const struct seq_operations vlan_seq_ops = { + .start = vlan_seq_start, + .next = vlan_seq_next, + .stop = vlan_seq_stop, + .show = vlan_seq_show, +}; + +/* + * Proc filesystem directory entries. + */ + +/* Strings */ +static const char *const vlan_name_type_str[VLAN_NAME_TYPE_HIGHEST] = { + [VLAN_NAME_TYPE_RAW_PLUS_VID] = "VLAN_NAME_TYPE_RAW_PLUS_VID", + [VLAN_NAME_TYPE_PLUS_VID_NO_PAD] = "VLAN_NAME_TYPE_PLUS_VID_NO_PAD", + [VLAN_NAME_TYPE_RAW_PLUS_VID_NO_PAD] = "VLAN_NAME_TYPE_RAW_PLUS_VID_NO_PAD", + [VLAN_NAME_TYPE_PLUS_VID] = "VLAN_NAME_TYPE_PLUS_VID", +}; +/* + * Interface functions + */ + +/* + * Clean up /proc/net/vlan entries + */ + +void vlan_proc_cleanup(struct net *net) +{ + struct vlan_net *vn = net_generic(net, vlan_net_id); + + if (vn->proc_vlan_conf) + remove_proc_entry(name_conf, vn->proc_vlan_dir); + + if (vn->proc_vlan_dir) + remove_proc_entry(name_root, net->proc_net); + + /* Dynamically added entries should be cleaned up as their vlan_device + * is removed, so we should not have to take care of it here... 
+ */ +} + +/* + * Create /proc/net/vlan entries + */ + +int __net_init vlan_proc_init(struct net *net) +{ + struct vlan_net *vn = net_generic(net, vlan_net_id); + + vn->proc_vlan_dir = proc_net_mkdir(net, name_root, net->proc_net); + if (!vn->proc_vlan_dir) + goto err; + + vn->proc_vlan_conf = proc_create_net(name_conf, S_IFREG | 0600, + vn->proc_vlan_dir, &vlan_seq_ops, + sizeof(struct seq_net_private)); + if (!vn->proc_vlan_conf) + goto err; + return 0; + +err: + pr_err("can't create entry in proc filesystem!\n"); + vlan_proc_cleanup(net); + return -ENOBUFS; +} + +/* + * Add directory entry for VLAN device. + */ + +int vlan_proc_add_dev(struct net_device *vlandev) +{ + struct vlan_dev_priv *vlan = vlan_dev_priv(vlandev); + struct vlan_net *vn = net_generic(dev_net(vlandev), vlan_net_id); + + if (!strcmp(vlandev->name, name_conf)) + return -EINVAL; + vlan->dent = proc_create_single_data(vlandev->name, S_IFREG | 0600, + vn->proc_vlan_dir, vlandev_seq_show, vlandev); + if (!vlan->dent) + return -ENOBUFS; + return 0; +} + +/* + * Delete directory entry for VLAN device. + */ +void vlan_proc_rem_dev(struct net_device *vlandev) +{ + /** NOTE: This will consume the memory pointed to by dent, it seems. */ + proc_remove(vlan_dev_priv(vlandev)->dent); + vlan_dev_priv(vlandev)->dent = NULL; +} + +/****** Proc filesystem entry points ****************************************/ + +/* + * The following few functions build the content of /proc/net/vlan/config + */ + +/* start read of /proc/net/vlan/config */ +static void *vlan_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + struct net_device *dev; + struct net *net = seq_file_net(seq); + loff_t i = 1; + + rcu_read_lock(); + if (*pos == 0) + return SEQ_START_TOKEN; + + for_each_netdev_rcu(net, dev) { + if (!is_vlan_dev(dev)) + continue; + + if (i++ == *pos) + return dev; + } + + return NULL; +} + +static void *vlan_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net_device *dev; + struct net *net = seq_file_net(seq); + + ++*pos; + + dev = v; + if (v == SEQ_START_TOKEN) + dev = net_device_entry(&net->dev_base_head); + + for_each_netdev_continue_rcu(net, dev) { + if (!is_vlan_dev(dev)) + continue; + + return dev; + } + + return NULL; +} + +static void vlan_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +static int vlan_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = seq_file_net(seq); + struct vlan_net *vn = net_generic(net, vlan_net_id); + + if (v == SEQ_START_TOKEN) { + const char *nmtype = NULL; + + seq_puts(seq, "VLAN Dev name | VLAN ID\n"); + + if (vn->name_type < ARRAY_SIZE(vlan_name_type_str)) + nmtype = vlan_name_type_str[vn->name_type]; + + seq_printf(seq, "Name-Type: %s\n", + nmtype ? 
nmtype : "UNKNOWN"); + } else { + const struct net_device *vlandev = v; + const struct vlan_dev_priv *vlan = vlan_dev_priv(vlandev); + + seq_printf(seq, "%-15s| %d | %s\n", vlandev->name, + vlan->vlan_id, vlan->real_dev->name); + } + return 0; +} + +static int vlandev_seq_show(struct seq_file *seq, void *offset) +{ + struct net_device *vlandev = (struct net_device *) seq->private; + const struct vlan_dev_priv *vlan = vlan_dev_priv(vlandev); + struct rtnl_link_stats64 temp; + const struct rtnl_link_stats64 *stats; + static const char fmt64[] = "%30s %12llu\n"; + int i; + + if (!is_vlan_dev(vlandev)) + return 0; + + stats = dev_get_stats(vlandev, &temp); + seq_printf(seq, + "%s VID: %d REORDER_HDR: %i dev->priv_flags: %llx\n", + vlandev->name, vlan->vlan_id, + (int)(vlan->flags & 1), vlandev->priv_flags); + + seq_printf(seq, fmt64, "total frames received", stats->rx_packets); + seq_printf(seq, fmt64, "total bytes received", stats->rx_bytes); + seq_printf(seq, fmt64, "Broadcast/Multicast Rcvd", stats->multicast); + seq_puts(seq, "\n"); + seq_printf(seq, fmt64, "total frames transmitted", stats->tx_packets); + seq_printf(seq, fmt64, "total bytes transmitted", stats->tx_bytes); + seq_printf(seq, "Device: %s", vlan->real_dev->name); + /* now show all PRIORITY mappings relating to this VLAN */ + seq_printf(seq, "\nINGRESS priority mappings: " + "0:%u 1:%u 2:%u 3:%u 4:%u 5:%u 6:%u 7:%u\n", + vlan->ingress_priority_map[0], + vlan->ingress_priority_map[1], + vlan->ingress_priority_map[2], + vlan->ingress_priority_map[3], + vlan->ingress_priority_map[4], + vlan->ingress_priority_map[5], + vlan->ingress_priority_map[6], + vlan->ingress_priority_map[7]); + + seq_printf(seq, " EGRESS priority mappings: "); + for (i = 0; i < 16; i++) { + const struct vlan_priority_tci_mapping *mp + = vlan->egress_priority_map[i]; + while (mp) { + seq_printf(seq, "%u:%d ", + mp->priority, ((mp->vlan_qos >> 13) & 0x7)); + mp = mp->next; + } + } + seq_puts(seq, "\n"); + + return 0; +} diff --git a/net/8021q/vlanproc.h b/net/8021q/vlanproc.h new file mode 100644 index 000000000..48cd4b478 --- /dev/null +++ b/net/8021q/vlanproc.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __BEN_VLAN_PROC_INC__ +#define __BEN_VLAN_PROC_INC__ + +#ifdef CONFIG_PROC_FS +struct net; + +int vlan_proc_init(struct net *net); +void vlan_proc_rem_dev(struct net_device *vlandev); +int vlan_proc_add_dev(struct net_device *vlandev); +void vlan_proc_cleanup(struct net *net); + +#else /* No CONFIG_PROC_FS */ + +#define vlan_proc_init(net) (0) +#define vlan_proc_cleanup(net) do {} while (0) +#define vlan_proc_add_dev(dev) ({(void)(dev), 0; }) +#define vlan_proc_rem_dev(dev) do {} while (0) +#endif + +#endif /* !(__BEN_VLAN_PROC_INC__) */ diff --git a/net/9p/Kconfig b/net/9p/Kconfig new file mode 100644 index 000000000..deabbd376 --- /dev/null +++ b/net/9p/Kconfig @@ -0,0 +1,52 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# 9P protocol configuration +# + +menuconfig NET_9P + tristate "Plan 9 Resource Sharing Support (9P2000)" + help + If you say Y here, you will get experimental support for + Plan 9 resource sharing via the 9P2000 protocol. + + See for more information. + + If unsure, say N. + +if NET_9P + +config NET_9P_FD + default NET_9P + tristate "9P FD Transport" + help + This builds support for transports over TCP, Unix sockets and + filedescriptors. + +config NET_9P_VIRTIO + depends on VIRTIO + tristate "9P Virtio Transport" + help + This builds support for a transports between + guest partitions and a host partition. 
+ +config NET_9P_XEN + depends on XEN + select XEN_XENBUS_FRONTEND + tristate "9P Xen Transport" + help + This builds support for a transport for 9pfs between + two Xen domains. + + +config NET_9P_RDMA + depends on INET && INFINIBAND && INFINIBAND_ADDR_TRANS + tristate "9P RDMA Transport (Experimental)" + help + This builds support for an RDMA transport. + +config NET_9P_DEBUG + bool "Debug information" + help + Say Y if you want the 9P subsystem to log debug information. + +endif diff --git a/net/9p/Makefile b/net/9p/Makefile new file mode 100644 index 000000000..1df9b344c --- /dev/null +++ b/net/9p/Makefile @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_NET_9P) := 9pnet.o +obj-$(CONFIG_NET_9P_FD) += 9pnet_fd.o +obj-$(CONFIG_NET_9P_XEN) += 9pnet_xen.o +obj-$(CONFIG_NET_9P_VIRTIO) += 9pnet_virtio.o +obj-$(CONFIG_NET_9P_RDMA) += 9pnet_rdma.o + +9pnet-objs := \ + mod.o \ + client.o \ + error.o \ + protocol.o \ + trans_common.o \ + +9pnet_fd-objs := \ + trans_fd.o \ + +9pnet_virtio-objs := \ + trans_virtio.o \ + +9pnet_xen-objs := \ + trans_xen.o \ + +9pnet_rdma-objs := \ + trans_rdma.o \ diff --git a/net/9p/client.c b/net/9p/client.c new file mode 100644 index 000000000..84b93b04d --- /dev/null +++ b/net/9p/client.c @@ -0,0 +1,2276 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 9P Client + * + * Copyright (C) 2008 by Eric Van Hensbergen + * Copyright (C) 2007 by Latchesar Ionkov + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "protocol.h" + +#define CREATE_TRACE_POINTS +#include + +#define DEFAULT_MSIZE (128 * 1024) + +/* Client Option Parsing (code inspired by NFS code) + * - a little lazy - parse all client options + */ + +enum { + Opt_msize, + Opt_trans, + Opt_legacy, + Opt_version, + Opt_err, +}; + +static const match_table_t tokens = { + {Opt_msize, "msize=%u"}, + {Opt_legacy, "noextend"}, + {Opt_trans, "trans=%s"}, + {Opt_version, "version=%s"}, + {Opt_err, NULL}, +}; + +inline int p9_is_proto_dotl(struct p9_client *clnt) +{ + return clnt->proto_version == p9_proto_2000L; +} +EXPORT_SYMBOL(p9_is_proto_dotl); + +inline int p9_is_proto_dotu(struct p9_client *clnt) +{ + return clnt->proto_version == p9_proto_2000u; +} +EXPORT_SYMBOL(p9_is_proto_dotu); + +int p9_show_client_options(struct seq_file *m, struct p9_client *clnt) +{ + if (clnt->msize != DEFAULT_MSIZE) + seq_printf(m, ",msize=%u", clnt->msize); + seq_printf(m, ",trans=%s", clnt->trans_mod->name); + + switch (clnt->proto_version) { + case p9_proto_legacy: + seq_puts(m, ",noextend"); + break; + case p9_proto_2000u: + seq_puts(m, ",version=9p2000.u"); + break; + case p9_proto_2000L: + /* Default */ + break; + } + + if (clnt->trans_mod->show_options) + return clnt->trans_mod->show_options(m, clnt); + return 0; +} +EXPORT_SYMBOL(p9_show_client_options); + +/* Some error codes are taken directly from the server replies, + * make sure they are valid. 
+ */ +static int safe_errno(int err) +{ + if (err > 0 || err < -MAX_ERRNO) { + p9_debug(P9_DEBUG_ERROR, "Invalid error code %d\n", err); + return -EPROTO; + } + return err; +} + +/* Interpret mount option for protocol version */ +static int get_protocol_version(char *s) +{ + int version = -EINVAL; + + if (!strcmp(s, "9p2000")) { + version = p9_proto_legacy; + p9_debug(P9_DEBUG_9P, "Protocol version: Legacy\n"); + } else if (!strcmp(s, "9p2000.u")) { + version = p9_proto_2000u; + p9_debug(P9_DEBUG_9P, "Protocol version: 9P2000.u\n"); + } else if (!strcmp(s, "9p2000.L")) { + version = p9_proto_2000L; + p9_debug(P9_DEBUG_9P, "Protocol version: 9P2000.L\n"); + } else { + pr_info("Unknown protocol version %s\n", s); + } + + return version; +} + +/** + * parse_opts - parse mount options into client structure + * @opts: options string passed from mount + * @clnt: existing v9fs client information + * + * Return 0 upon success, -ERRNO upon failure + */ + +static int parse_opts(char *opts, struct p9_client *clnt) +{ + char *options, *tmp_options; + char *p; + substring_t args[MAX_OPT_ARGS]; + int option; + char *s; + int ret = 0; + + clnt->proto_version = p9_proto_2000L; + clnt->msize = DEFAULT_MSIZE; + + if (!opts) + return 0; + + tmp_options = kstrdup(opts, GFP_KERNEL); + if (!tmp_options) + return -ENOMEM; + options = tmp_options; + + while ((p = strsep(&options, ",")) != NULL) { + int token, r; + + if (!*p) + continue; + token = match_token(p, tokens, args); + switch (token) { + case Opt_msize: + r = match_int(&args[0], &option); + if (r < 0) { + p9_debug(P9_DEBUG_ERROR, + "integer field, but no integer?\n"); + ret = r; + continue; + } + if (option < 4096) { + p9_debug(P9_DEBUG_ERROR, + "msize should be at least 4k\n"); + ret = -EINVAL; + continue; + } + clnt->msize = option; + break; + case Opt_trans: + s = match_strdup(&args[0]); + if (!s) { + ret = -ENOMEM; + p9_debug(P9_DEBUG_ERROR, + "problem allocating copy of trans arg\n"); + goto free_and_return; + } + + v9fs_put_trans(clnt->trans_mod); + clnt->trans_mod = v9fs_get_trans_by_name(s); + if (!clnt->trans_mod) { + pr_info("Could not find request transport: %s\n", + s); + ret = -EINVAL; + } + kfree(s); + break; + case Opt_legacy: + clnt->proto_version = p9_proto_legacy; + break; + case Opt_version: + s = match_strdup(&args[0]); + if (!s) { + ret = -ENOMEM; + p9_debug(P9_DEBUG_ERROR, + "problem allocating copy of version arg\n"); + goto free_and_return; + } + r = get_protocol_version(s); + if (r < 0) + ret = r; + else + clnt->proto_version = r; + kfree(s); + break; + default: + continue; + } + } + +free_and_return: + if (ret) + v9fs_put_trans(clnt->trans_mod); + kfree(tmp_options); + return ret; +} + +static int p9_fcall_init(struct p9_client *c, struct p9_fcall *fc, + int alloc_msize) +{ + if (likely(c->fcall_cache) && alloc_msize == c->msize) { + fc->sdata = kmem_cache_alloc(c->fcall_cache, GFP_NOFS); + fc->cache = c->fcall_cache; + } else { + fc->sdata = kmalloc(alloc_msize, GFP_NOFS); + fc->cache = NULL; + } + if (!fc->sdata) + return -ENOMEM; + fc->capacity = alloc_msize; + return 0; +} + +void p9_fcall_fini(struct p9_fcall *fc) +{ + /* sdata can be NULL for interrupted requests in trans_rdma, + * and kmem_cache_free does not do NULL-check for us + */ + if (unlikely(!fc->sdata)) + return; + + if (fc->cache) + kmem_cache_free(fc->cache, fc->sdata); + else + kfree(fc->sdata); +} +EXPORT_SYMBOL(p9_fcall_fini); + +static struct kmem_cache *p9_req_cache; + +/** + * p9_tag_alloc - Allocate a new request. + * @c: Client session. 
+ * @type: Transaction type. + * @t_size: Buffer size for holding this request + * (automatic calculation by format template if 0). + * @r_size: Buffer size for holding server's reply on this request + * (automatic calculation by format template if 0). + * @fmt: Format template for assembling 9p request message + * (see p9pdu_vwritef). + * @ap: Variable arguments to be fed to passed format template + * (see p9pdu_vwritef). + * + * Context: Process context. + * Return: Pointer to new request. + */ +static struct p9_req_t * +p9_tag_alloc(struct p9_client *c, int8_t type, uint t_size, uint r_size, + const char *fmt, va_list ap) +{ + struct p9_req_t *req = kmem_cache_alloc(p9_req_cache, GFP_NOFS); + int alloc_tsize; + int alloc_rsize; + int tag; + va_list apc; + + va_copy(apc, ap); + alloc_tsize = min_t(size_t, c->msize, + t_size ?: p9_msg_buf_size(c, type, fmt, apc)); + va_end(apc); + + alloc_rsize = min_t(size_t, c->msize, + r_size ?: p9_msg_buf_size(c, type + 1, fmt, ap)); + + if (!req) + return ERR_PTR(-ENOMEM); + + if (p9_fcall_init(c, &req->tc, alloc_tsize)) + goto free_req; + if (p9_fcall_init(c, &req->rc, alloc_rsize)) + goto free; + + p9pdu_reset(&req->tc); + p9pdu_reset(&req->rc); + req->t_err = 0; + req->status = REQ_STATUS_ALLOC; + /* refcount needs to be set to 0 before inserting into the idr + * so p9_tag_lookup does not accept a request that is not fully + * initialized. refcount_set to 2 below will mark request ready. + */ + refcount_set(&req->refcount, 0); + init_waitqueue_head(&req->wq); + INIT_LIST_HEAD(&req->req_list); + + idr_preload(GFP_NOFS); + spin_lock_irq(&c->lock); + if (type == P9_TVERSION) + tag = idr_alloc(&c->reqs, req, P9_NOTAG, P9_NOTAG + 1, + GFP_NOWAIT); + else + tag = idr_alloc(&c->reqs, req, 0, P9_NOTAG, GFP_NOWAIT); + req->tc.tag = tag; + spin_unlock_irq(&c->lock); + idr_preload_end(); + if (tag < 0) + goto free; + + /* Init ref to two because in the general case there is one ref + * that is put asynchronously by a writer thread, one ref + * temporarily given by p9_tag_lookup and put by p9_client_cb + * in the recv thread, and one ref put by p9_req_put in the + * main thread. The only exception is virtio that does not use + * p9_tag_lookup but does not have a writer thread either + * (the write happens synchronously in the request/zc_request + * callback), so p9_client_cb eats the second ref there + * as the pointer is duplicated directly by virtqueue_add_sgs() + */ + refcount_set(&req->refcount, 2); + + return req; + +free: + p9_fcall_fini(&req->tc); + p9_fcall_fini(&req->rc); +free_req: + kmem_cache_free(p9_req_cache, req); + return ERR_PTR(-ENOMEM); +} + +/** + * p9_tag_lookup - Look up a request by tag. + * @c: Client session. + * @tag: Transaction ID. + * + * Context: Any context. + * Return: A request, or %NULL if there is no request with that tag. + */ +struct p9_req_t *p9_tag_lookup(struct p9_client *c, u16 tag) +{ + struct p9_req_t *req; + + rcu_read_lock(); +again: + req = idr_find(&c->reqs, tag); + if (req) { + /* We have to be careful with the req found under rcu_read_lock + * Thanks to SLAB_TYPESAFE_BY_RCU we can safely try to get the + * ref again without corrupting other data, then check again + * that the tag matches once we have the ref + */ + if (!p9_req_try_get(req)) + goto again; + if (req->tc.tag != tag) { + p9_req_put(c, req); + goto again; + } + } + rcu_read_unlock(); + + return req; +} +EXPORT_SYMBOL(p9_tag_lookup); + +/** + * p9_tag_remove - Remove a tag. + * @c: Client session. + * @r: Request of reference. 
+ * + * Context: Any context. + */ +static void p9_tag_remove(struct p9_client *c, struct p9_req_t *r) +{ + unsigned long flags; + u16 tag = r->tc.tag; + + p9_debug(P9_DEBUG_MUX, "freeing clnt %p req %p tag: %d\n", c, r, tag); + spin_lock_irqsave(&c->lock, flags); + idr_remove(&c->reqs, tag); + spin_unlock_irqrestore(&c->lock, flags); +} + +int p9_req_put(struct p9_client *c, struct p9_req_t *r) +{ + if (refcount_dec_and_test(&r->refcount)) { + p9_tag_remove(c, r); + + p9_fcall_fini(&r->tc); + p9_fcall_fini(&r->rc); + kmem_cache_free(p9_req_cache, r); + return 1; + } + return 0; +} +EXPORT_SYMBOL(p9_req_put); + +/** + * p9_tag_cleanup - cleans up tags structure and reclaims resources + * @c: v9fs client struct + * + * This frees resources associated with the tags structure + * + */ +static void p9_tag_cleanup(struct p9_client *c) +{ + struct p9_req_t *req; + int id; + + rcu_read_lock(); + idr_for_each_entry(&c->reqs, req, id) { + pr_info("Tag %d still in use\n", id); + if (p9_req_put(c, req) == 0) + pr_warn("Packet with tag %d has still references", + req->tc.tag); + } + rcu_read_unlock(); +} + +/** + * p9_client_cb - call back from transport to client + * @c: client state + * @req: request received + * @status: request status, one of REQ_STATUS_* + * + */ +void p9_client_cb(struct p9_client *c, struct p9_req_t *req, int status) +{ + p9_debug(P9_DEBUG_MUX, " tag %d\n", req->tc.tag); + + /* This barrier is needed to make sure any change made to req before + * the status change is visible to another thread + */ + smp_wmb(); + WRITE_ONCE(req->status, status); + + wake_up(&req->wq); + p9_debug(P9_DEBUG_MUX, "wakeup: %d\n", req->tc.tag); + p9_req_put(c, req); +} +EXPORT_SYMBOL(p9_client_cb); + +/** + * p9_parse_header - parse header arguments out of a packet + * @pdu: packet to parse + * @size: size of packet + * @type: type of request + * @tag: tag of packet + * @rewind: set if we need to rewind offset afterwards + */ + +int +p9_parse_header(struct p9_fcall *pdu, int32_t *size, int8_t *type, + int16_t *tag, int rewind) +{ + s8 r_type; + s16 r_tag; + s32 r_size; + int offset = pdu->offset; + int err; + + pdu->offset = 0; + + err = p9pdu_readf(pdu, 0, "dbw", &r_size, &r_type, &r_tag); + if (err) + goto rewind_and_exit; + + if (type) + *type = r_type; + if (tag) + *tag = r_tag; + if (size) + *size = r_size; + + if (pdu->size != r_size || r_size < 7) { + err = -EINVAL; + goto rewind_and_exit; + } + + pdu->id = r_type; + pdu->tag = r_tag; + + p9_debug(P9_DEBUG_9P, "<<< size=%d type: %d tag: %d\n", + pdu->size, pdu->id, pdu->tag); + +rewind_and_exit: + if (rewind) + pdu->offset = offset; + return err; +} +EXPORT_SYMBOL(p9_parse_header); + +/** + * p9_check_errors - check 9p packet for error return and process it + * @c: current client instance + * @req: request to parse and check for error conditions + * + * returns error code if one is discovered, otherwise returns 0 + * + * this will have to be more complicated if we have multiple + * error packet types + */ + +static int p9_check_errors(struct p9_client *c, struct p9_req_t *req) +{ + s8 type; + int err; + int ecode; + + err = p9_parse_header(&req->rc, NULL, &type, NULL, 0); + if (req->rc.size >= c->msize) { + p9_debug(P9_DEBUG_ERROR, + "requested packet size too big: %d\n", + req->rc.size); + return -EIO; + } + /* dump the response from server + * This should be after check errors which poplulate pdu_fcall. 
+ */ + trace_9p_protocol_dump(c, &req->rc); + if (err) { + p9_debug(P9_DEBUG_ERROR, "couldn't parse header %d\n", err); + return err; + } + if (type != P9_RERROR && type != P9_RLERROR) + return 0; + + if (!p9_is_proto_dotl(c)) { + char *ename = NULL; + + err = p9pdu_readf(&req->rc, c->proto_version, "s?d", + &ename, &ecode); + if (err) { + kfree(ename); + goto out_err; + } + + if (p9_is_proto_dotu(c) && ecode < 512) + err = -ecode; + + if (!err) { + err = p9_errstr2errno(ename, strlen(ename)); + + p9_debug(P9_DEBUG_9P, "<<< RERROR (%d) %s\n", + -ecode, ename); + } + kfree(ename); + } else { + err = p9pdu_readf(&req->rc, c->proto_version, "d", &ecode); + if (err) + goto out_err; + err = -ecode; + + p9_debug(P9_DEBUG_9P, "<<< RLERROR (%d)\n", -ecode); + } + + return err; + +out_err: + p9_debug(P9_DEBUG_ERROR, "couldn't parse error%d\n", err); + + return err; +} + +static struct p9_req_t * +p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...); + +/** + * p9_client_flush - flush (cancel) a request + * @c: client state + * @oldreq: request to cancel + * + * This sents a flush for a particular request and links + * the flush request to the original request. The current + * code only supports a single flush request although the protocol + * allows for multiple flush requests to be sent for a single request. + * + */ + +static int p9_client_flush(struct p9_client *c, struct p9_req_t *oldreq) +{ + struct p9_req_t *req; + s16 oldtag; + int err; + + err = p9_parse_header(&oldreq->tc, NULL, NULL, &oldtag, 1); + if (err) + return err; + + p9_debug(P9_DEBUG_9P, ">>> TFLUSH tag %d\n", oldtag); + + req = p9_client_rpc(c, P9_TFLUSH, "w", oldtag); + if (IS_ERR(req)) + return PTR_ERR(req); + + /* if we haven't received a response for oldreq, + * remove it from the list + */ + if (READ_ONCE(oldreq->status) == REQ_STATUS_SENT) { + if (c->trans_mod->cancelled) + c->trans_mod->cancelled(c, oldreq); + } + + p9_req_put(c, req); + return 0; +} + +static struct p9_req_t *p9_client_prepare_req(struct p9_client *c, + int8_t type, uint t_size, uint r_size, + const char *fmt, va_list ap) +{ + int err; + struct p9_req_t *req; + va_list apc; + + p9_debug(P9_DEBUG_MUX, "client %p op %d\n", c, type); + + /* we allow for any status other than disconnected */ + if (c->status == Disconnected) + return ERR_PTR(-EIO); + + /* if status is begin_disconnected we allow only clunk request */ + if (c->status == BeginDisconnect && type != P9_TCLUNK) + return ERR_PTR(-EIO); + + va_copy(apc, ap); + req = p9_tag_alloc(c, type, t_size, r_size, fmt, apc); + va_end(apc); + if (IS_ERR(req)) + return req; + + /* marshall the data */ + p9pdu_prepare(&req->tc, req->tc.tag, type); + err = p9pdu_vwritef(&req->tc, c->proto_version, fmt, ap); + if (err) + goto reterr; + p9pdu_finalize(c, &req->tc); + trace_9p_client_req(c, type, req->tc.tag); + return req; +reterr: + p9_req_put(c, req); + /* We have to put also the 2nd reference as it won't be used */ + p9_req_put(c, req); + return ERR_PTR(err); +} + +/** + * p9_client_rpc - issue a request and wait for a response + * @c: client session + * @type: type of request + * @fmt: protocol format string (see protocol.c) + * + * Returns request structure (which client must free using p9_req_put) + */ + +static struct p9_req_t * +p9_client_rpc(struct p9_client *c, int8_t type, const char *fmt, ...) 
+{ + va_list ap; + int sigpending, err; + unsigned long flags; + struct p9_req_t *req; + /* Passing zero for tsize/rsize to p9_client_prepare_req() tells it to + * auto determine an appropriate (small) request/response size + * according to actual message data being sent. Currently RDMA + * transport is excluded from this response message size optimization, + * as it would not cope with it, due to its pooled response buffers + * (using an optimized request size for RDMA as well though). + */ + const uint tsize = 0; + const uint rsize = c->trans_mod->pooled_rbuffers ? c->msize : 0; + + va_start(ap, fmt); + req = p9_client_prepare_req(c, type, tsize, rsize, fmt, ap); + va_end(ap); + if (IS_ERR(req)) + return req; + + if (signal_pending(current)) { + sigpending = 1; + clear_thread_flag(TIF_SIGPENDING); + } else { + sigpending = 0; + } + + err = c->trans_mod->request(c, req); + if (err < 0) { + /* write won't happen */ + p9_req_put(c, req); + if (err != -ERESTARTSYS && err != -EFAULT) + c->status = Disconnected; + goto recalc_sigpending; + } +again: + /* Wait for the response */ + err = wait_event_killable(req->wq, + READ_ONCE(req->status) >= REQ_STATUS_RCVD); + + /* Make sure our req is coherent with regard to updates in other + * threads - echoes to wmb() in the callback + */ + smp_rmb(); + + if (err == -ERESTARTSYS && c->status == Connected && + type == P9_TFLUSH) { + sigpending = 1; + clear_thread_flag(TIF_SIGPENDING); + goto again; + } + + if (READ_ONCE(req->status) == REQ_STATUS_ERROR) { + p9_debug(P9_DEBUG_ERROR, "req_status error %d\n", req->t_err); + err = req->t_err; + } + if (err == -ERESTARTSYS && c->status == Connected) { + p9_debug(P9_DEBUG_MUX, "flushing\n"); + sigpending = 1; + clear_thread_flag(TIF_SIGPENDING); + + if (c->trans_mod->cancel(c, req)) + p9_client_flush(c, req); + + /* if we received the response anyway, don't signal error */ + if (READ_ONCE(req->status) == REQ_STATUS_RCVD) + err = 0; + } +recalc_sigpending: + if (sigpending) { + spin_lock_irqsave(¤t->sighand->siglock, flags); + recalc_sigpending(); + spin_unlock_irqrestore(¤t->sighand->siglock, flags); + } + if (err < 0) + goto reterr; + + err = p9_check_errors(c, req); + trace_9p_client_res(c, type, req->rc.tag, err); + if (!err) + return req; +reterr: + p9_req_put(c, req); + return ERR_PTR(safe_errno(err)); +} + +/** + * p9_client_zc_rpc - issue a request and wait for a response + * @c: client session + * @type: type of request + * @uidata: destination for zero copy read + * @uodata: source for zero copy write + * @inlen: read buffer size + * @olen: write buffer size + * @in_hdrlen: reader header size, This is the size of response protocol data + * @fmt: protocol format string (see protocol.c) + * + * Returns request structure (which client must free using p9_req_put) + */ +static struct p9_req_t *p9_client_zc_rpc(struct p9_client *c, int8_t type, + struct iov_iter *uidata, + struct iov_iter *uodata, + int inlen, int olen, int in_hdrlen, + const char *fmt, ...) +{ + va_list ap; + int sigpending, err; + unsigned long flags; + struct p9_req_t *req; + + va_start(ap, fmt); + /* We allocate a inline protocol data of only 4k bytes. + * The actual content is passed in zero-copy fashion. 
+ */ + req = p9_client_prepare_req(c, type, P9_ZC_HDR_SZ, P9_ZC_HDR_SZ, fmt, ap); + va_end(ap); + if (IS_ERR(req)) + return req; + + if (signal_pending(current)) { + sigpending = 1; + clear_thread_flag(TIF_SIGPENDING); + } else { + sigpending = 0; + } + + err = c->trans_mod->zc_request(c, req, uidata, uodata, + inlen, olen, in_hdrlen); + if (err < 0) { + if (err == -EIO) + c->status = Disconnected; + if (err != -ERESTARTSYS) + goto recalc_sigpending; + } + if (READ_ONCE(req->status) == REQ_STATUS_ERROR) { + p9_debug(P9_DEBUG_ERROR, "req_status error %d\n", req->t_err); + err = req->t_err; + } + if (err == -ERESTARTSYS && c->status == Connected) { + p9_debug(P9_DEBUG_MUX, "flushing\n"); + sigpending = 1; + clear_thread_flag(TIF_SIGPENDING); + + if (c->trans_mod->cancel(c, req)) + p9_client_flush(c, req); + + /* if we received the response anyway, don't signal error */ + if (READ_ONCE(req->status) == REQ_STATUS_RCVD) + err = 0; + } +recalc_sigpending: + if (sigpending) { + spin_lock_irqsave(¤t->sighand->siglock, flags); + recalc_sigpending(); + spin_unlock_irqrestore(¤t->sighand->siglock, flags); + } + if (err < 0) + goto reterr; + + err = p9_check_errors(c, req); + trace_9p_client_res(c, type, req->rc.tag, err); + if (!err) + return req; +reterr: + p9_req_put(c, req); + return ERR_PTR(safe_errno(err)); +} + +static struct p9_fid *p9_fid_create(struct p9_client *clnt) +{ + int ret; + struct p9_fid *fid; + + p9_debug(P9_DEBUG_FID, "clnt %p\n", clnt); + fid = kzalloc(sizeof(*fid), GFP_KERNEL); + if (!fid) + return NULL; + + fid->mode = -1; + fid->uid = current_fsuid(); + fid->clnt = clnt; + refcount_set(&fid->count, 1); + + idr_preload(GFP_KERNEL); + spin_lock_irq(&clnt->lock); + ret = idr_alloc_u32(&clnt->fids, fid, &fid->fid, P9_NOFID - 1, + GFP_NOWAIT); + spin_unlock_irq(&clnt->lock); + idr_preload_end(); + if (!ret) { + trace_9p_fid_ref(fid, P9_FID_REF_CREATE); + return fid; + } + + kfree(fid); + return NULL; +} + +static void p9_fid_destroy(struct p9_fid *fid) +{ + struct p9_client *clnt; + unsigned long flags; + + p9_debug(P9_DEBUG_FID, "fid %d\n", fid->fid); + trace_9p_fid_ref(fid, P9_FID_REF_DESTROY); + clnt = fid->clnt; + spin_lock_irqsave(&clnt->lock, flags); + idr_remove(&clnt->fids, fid->fid); + spin_unlock_irqrestore(&clnt->lock, flags); + kfree(fid->rdir); + kfree(fid); +} + +/* We also need to export tracepoint symbols for tracepoint_enabled() */ +EXPORT_TRACEPOINT_SYMBOL(9p_fid_ref); + +void do_trace_9p_fid_get(struct p9_fid *fid) +{ + trace_9p_fid_ref(fid, P9_FID_REF_GET); +} +EXPORT_SYMBOL(do_trace_9p_fid_get); + +void do_trace_9p_fid_put(struct p9_fid *fid) +{ + trace_9p_fid_ref(fid, P9_FID_REF_PUT); +} +EXPORT_SYMBOL(do_trace_9p_fid_put); + +static int p9_client_version(struct p9_client *c) +{ + int err = 0; + struct p9_req_t *req; + char *version = NULL; + int msize; + + p9_debug(P9_DEBUG_9P, ">>> TVERSION msize %d protocol %d\n", + c->msize, c->proto_version); + + switch (c->proto_version) { + case p9_proto_2000L: + req = p9_client_rpc(c, P9_TVERSION, "ds", + c->msize, "9P2000.L"); + break; + case p9_proto_2000u: + req = p9_client_rpc(c, P9_TVERSION, "ds", + c->msize, "9P2000.u"); + break; + case p9_proto_legacy: + req = p9_client_rpc(c, P9_TVERSION, "ds", + c->msize, "9P2000"); + break; + default: + return -EINVAL; + } + + if (IS_ERR(req)) + return PTR_ERR(req); + + err = p9pdu_readf(&req->rc, c->proto_version, "ds", &msize, &version); + if (err) { + p9_debug(P9_DEBUG_9P, "version error %d\n", err); + trace_9p_protocol_dump(c, &req->rc); + goto error; + } + + 
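+ /* Only protocol version strings we recognize are accepted below; the
+  * negotiated msize may only shrink from what was requested, and a
+  * server reply advertising less than 4k is rejected.
+  */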
p9_debug(P9_DEBUG_9P, "<<< RVERSION msize %d %s\n", msize, version); + if (!strncmp(version, "9P2000.L", 8)) { + c->proto_version = p9_proto_2000L; + } else if (!strncmp(version, "9P2000.u", 8)) { + c->proto_version = p9_proto_2000u; + } else if (!strncmp(version, "9P2000", 6)) { + c->proto_version = p9_proto_legacy; + } else { + p9_debug(P9_DEBUG_ERROR, + "server returned an unknown version: %s\n", version); + err = -EREMOTEIO; + goto error; + } + + if (msize < 4096) { + p9_debug(P9_DEBUG_ERROR, + "server returned a msize < 4096: %d\n", msize); + err = -EREMOTEIO; + goto error; + } + if (msize < c->msize) + c->msize = msize; + +error: + kfree(version); + p9_req_put(c, req); + + return err; +} + +struct p9_client *p9_client_create(const char *dev_name, char *options) +{ + int err; + struct p9_client *clnt; + char *client_id; + + err = 0; + clnt = kmalloc(sizeof(*clnt), GFP_KERNEL); + if (!clnt) + return ERR_PTR(-ENOMEM); + + clnt->trans_mod = NULL; + clnt->trans = NULL; + clnt->fcall_cache = NULL; + + client_id = utsname()->nodename; + memcpy(clnt->name, client_id, strlen(client_id) + 1); + + spin_lock_init(&clnt->lock); + idr_init(&clnt->fids); + idr_init(&clnt->reqs); + + err = parse_opts(options, clnt); + if (err < 0) + goto free_client; + + if (!clnt->trans_mod) + clnt->trans_mod = v9fs_get_default_trans(); + + if (!clnt->trans_mod) { + err = -EPROTONOSUPPORT; + p9_debug(P9_DEBUG_ERROR, + "No transport defined or default transport\n"); + goto free_client; + } + + p9_debug(P9_DEBUG_MUX, "clnt %p trans %p msize %d protocol %d\n", + clnt, clnt->trans_mod, clnt->msize, clnt->proto_version); + + err = clnt->trans_mod->create(clnt, dev_name, options); + if (err) + goto put_trans; + + if (clnt->msize > clnt->trans_mod->maxsize) { + clnt->msize = clnt->trans_mod->maxsize; + pr_info("Limiting 'msize' to %d as this is the maximum " + "supported by transport %s\n", + clnt->msize, clnt->trans_mod->name + ); + } + + if (clnt->msize < 4096) { + p9_debug(P9_DEBUG_ERROR, + "Please specify a msize of at least 4k\n"); + err = -EINVAL; + goto close_trans; + } + + err = p9_client_version(clnt); + if (err) + goto close_trans; + + /* P9_HDRSZ + 4 is the smallest packet header we can have that is + * followed by data accessed from userspace by read + */ + clnt->fcall_cache = + kmem_cache_create_usercopy("9p-fcall-cache", clnt->msize, + 0, 0, P9_HDRSZ + 4, + clnt->msize - (P9_HDRSZ + 4), + NULL); + + return clnt; + +close_trans: + clnt->trans_mod->close(clnt); +put_trans: + v9fs_put_trans(clnt->trans_mod); +free_client: + kfree(clnt); + return ERR_PTR(err); +} +EXPORT_SYMBOL(p9_client_create); + +void p9_client_destroy(struct p9_client *clnt) +{ + struct p9_fid *fid; + int id; + + p9_debug(P9_DEBUG_MUX, "clnt %p\n", clnt); + + if (clnt->trans_mod) + clnt->trans_mod->close(clnt); + + v9fs_put_trans(clnt->trans_mod); + + idr_for_each_entry(&clnt->fids, fid, id) { + pr_info("Found fid %d not clunked\n", fid->fid); + p9_fid_destroy(fid); + } + + p9_tag_cleanup(clnt); + + kmem_cache_destroy(clnt->fcall_cache); + kfree(clnt); +} +EXPORT_SYMBOL(p9_client_destroy); + +void p9_client_disconnect(struct p9_client *clnt) +{ + p9_debug(P9_DEBUG_9P, "clnt %p\n", clnt); + clnt->status = Disconnected; +} +EXPORT_SYMBOL(p9_client_disconnect); + +void p9_client_begin_disconnect(struct p9_client *clnt) +{ + p9_debug(P9_DEBUG_9P, "clnt %p\n", clnt); + clnt->status = BeginDisconnect; +} +EXPORT_SYMBOL(p9_client_begin_disconnect); + +struct p9_fid *p9_client_attach(struct p9_client *clnt, struct p9_fid *afid, + const char *uname, 
kuid_t n_uname, + const char *aname) +{ + int err = 0; + struct p9_req_t *req; + struct p9_fid *fid; + struct p9_qid qid; + + p9_debug(P9_DEBUG_9P, ">>> TATTACH afid %d uname %s aname %s\n", + afid ? afid->fid : -1, uname, aname); + fid = p9_fid_create(clnt); + if (!fid) { + err = -ENOMEM; + goto error; + } + fid->uid = n_uname; + + req = p9_client_rpc(clnt, P9_TATTACH, "ddss?u", fid->fid, + afid ? afid->fid : P9_NOFID, uname, aname, n_uname); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", &qid); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RATTACH qid %x.%llx.%x\n", + qid.type, qid.path, qid.version); + + memmove(&fid->qid, &qid, sizeof(struct p9_qid)); + + p9_req_put(clnt, req); + return fid; + +error: + if (fid) + p9_fid_destroy(fid); + return ERR_PTR(err); +} +EXPORT_SYMBOL(p9_client_attach); + +struct p9_fid *p9_client_walk(struct p9_fid *oldfid, uint16_t nwname, + const unsigned char * const *wnames, int clone) +{ + int err; + struct p9_client *clnt; + struct p9_fid *fid; + struct p9_qid *wqids; + struct p9_req_t *req; + u16 nwqids, count; + + err = 0; + wqids = NULL; + clnt = oldfid->clnt; + if (clone) { + fid = p9_fid_create(clnt); + if (!fid) { + err = -ENOMEM; + goto error; + } + + fid->uid = oldfid->uid; + } else { + fid = oldfid; + } + + p9_debug(P9_DEBUG_9P, ">>> TWALK fids %d,%d nwname %ud wname[0] %s\n", + oldfid->fid, fid->fid, nwname, wnames ? wnames[0] : NULL); + req = p9_client_rpc(clnt, P9_TWALK, "ddT", oldfid->fid, fid->fid, + nwname, wnames); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "R", &nwqids, &wqids); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto clunk_fid; + } + p9_req_put(clnt, req); + + p9_debug(P9_DEBUG_9P, "<<< RWALK nwqid %d:\n", nwqids); + + if (nwqids != nwname) { + err = -ENOENT; + goto clunk_fid; + } + + for (count = 0; count < nwqids; count++) + p9_debug(P9_DEBUG_9P, "<<< [%d] %x.%llx.%x\n", + count, wqids[count].type, + wqids[count].path, + wqids[count].version); + + if (nwname) + memmove(&fid->qid, &wqids[nwqids - 1], sizeof(struct p9_qid)); + else + memmove(&fid->qid, &oldfid->qid, sizeof(struct p9_qid)); + + kfree(wqids); + return fid; + +clunk_fid: + kfree(wqids); + p9_fid_put(fid); + fid = NULL; + +error: + if (fid && fid != oldfid) + p9_fid_destroy(fid); + + return ERR_PTR(err); +} +EXPORT_SYMBOL(p9_client_walk); + +int p9_client_open(struct p9_fid *fid, int mode) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + struct p9_qid qid; + int iounit; + + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, ">>> %s fid %d mode %d\n", + p9_is_proto_dotl(clnt) ? "TLOPEN" : "TOPEN", fid->fid, mode); + err = 0; + + if (fid->mode != -1) + return -EINVAL; + + if (p9_is_proto_dotl(clnt)) + req = p9_client_rpc(clnt, P9_TLOPEN, "dd", fid->fid, mode); + else + req = p9_client_rpc(clnt, P9_TOPEN, "db", fid->fid, mode); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", &qid, &iounit); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto free_and_error; + } + + p9_debug(P9_DEBUG_9P, "<<< %s qid %x.%llx.%x iounit %x\n", + p9_is_proto_dotl(clnt) ? 
"RLOPEN" : "ROPEN", qid.type, + qid.path, qid.version, iounit); + + memmove(&fid->qid, &qid, sizeof(struct p9_qid)); + fid->mode = mode; + fid->iounit = iounit; + +free_and_error: + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_open); + +int p9_client_create_dotl(struct p9_fid *ofid, const char *name, u32 flags, + u32 mode, kgid_t gid, struct p9_qid *qid) +{ + int err = 0; + struct p9_client *clnt; + struct p9_req_t *req; + int iounit; + + p9_debug(P9_DEBUG_9P, + ">>> TLCREATE fid %d name %s flags %d mode %d gid %d\n", + ofid->fid, name, flags, mode, + from_kgid(&init_user_ns, gid)); + clnt = ofid->clnt; + + if (ofid->mode != -1) + return -EINVAL; + + req = p9_client_rpc(clnt, P9_TLCREATE, "dsddg", ofid->fid, name, flags, + mode, gid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", qid, &iounit); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto free_and_error; + } + + p9_debug(P9_DEBUG_9P, "<<< RLCREATE qid %x.%llx.%x iounit %x\n", + qid->type, qid->path, qid->version, iounit); + + memmove(&ofid->qid, qid, sizeof(struct p9_qid)); + ofid->mode = flags; + ofid->iounit = iounit; + +free_and_error: + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_create_dotl); + +int p9_client_fcreate(struct p9_fid *fid, const char *name, u32 perm, int mode, + char *extension) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + struct p9_qid qid; + int iounit; + + p9_debug(P9_DEBUG_9P, ">>> TCREATE fid %d name %s perm %d mode %d\n", + fid->fid, name, perm, mode); + err = 0; + clnt = fid->clnt; + + if (fid->mode != -1) + return -EINVAL; + + req = p9_client_rpc(clnt, P9_TCREATE, "dsdb?s", fid->fid, name, perm, + mode, extension); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Qd", &qid, &iounit); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto free_and_error; + } + + p9_debug(P9_DEBUG_9P, "<<< RCREATE qid %x.%llx.%x iounit %x\n", + qid.type, qid.path, qid.version, iounit); + + memmove(&fid->qid, &qid, sizeof(struct p9_qid)); + fid->mode = mode; + fid->iounit = iounit; + +free_and_error: + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_fcreate); + +int p9_client_symlink(struct p9_fid *dfid, const char *name, + const char *symtgt, kgid_t gid, struct p9_qid *qid) +{ + int err = 0; + struct p9_client *clnt; + struct p9_req_t *req; + + p9_debug(P9_DEBUG_9P, ">>> TSYMLINK dfid %d name %s symtgt %s\n", + dfid->fid, name, symtgt); + clnt = dfid->clnt; + + req = p9_client_rpc(clnt, P9_TSYMLINK, "dssg", dfid->fid, name, symtgt, + gid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto free_and_error; + } + + p9_debug(P9_DEBUG_9P, "<<< RSYMLINK qid %x.%llx.%x\n", + qid->type, qid->path, qid->version); + +free_and_error: + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_symlink); + +int p9_client_link(struct p9_fid *dfid, struct p9_fid *oldfid, const char *newname) +{ + struct p9_client *clnt; + struct p9_req_t *req; + + p9_debug(P9_DEBUG_9P, ">>> TLINK dfid %d oldfid %d newname %s\n", + dfid->fid, oldfid->fid, newname); + clnt = dfid->clnt; + req = p9_client_rpc(clnt, P9_TLINK, "dds", dfid->fid, oldfid->fid, + newname); + if (IS_ERR(req)) + return PTR_ERR(req); + + p9_debug(P9_DEBUG_9P, "<<< RLINK\n"); + 
p9_req_put(clnt, req); + return 0; +} +EXPORT_SYMBOL(p9_client_link); + +int p9_client_fsync(struct p9_fid *fid, int datasync) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + p9_debug(P9_DEBUG_9P, ">>> TFSYNC fid %d datasync:%d\n", + fid->fid, datasync); + err = 0; + clnt = fid->clnt; + + req = p9_client_rpc(clnt, P9_TFSYNC, "dd", fid->fid, datasync); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RFSYNC fid %d\n", fid->fid); + + p9_req_put(clnt, req); + +error: + return err; +} +EXPORT_SYMBOL(p9_client_fsync); + +int p9_client_clunk(struct p9_fid *fid) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + int retries = 0; + +again: + p9_debug(P9_DEBUG_9P, ">>> TCLUNK fid %d (try %d)\n", + fid->fid, retries); + err = 0; + clnt = fid->clnt; + + req = p9_client_rpc(clnt, P9_TCLUNK, "d", fid->fid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RCLUNK fid %d\n", fid->fid); + + p9_req_put(clnt, req); +error: + /* Fid is not valid even after a failed clunk + * If interrupted, retry once then give up and + * leak fid until umount. + */ + if (err == -ERESTARTSYS) { + if (retries++ == 0) + goto again; + } else { + p9_fid_destroy(fid); + } + return err; +} +EXPORT_SYMBOL(p9_client_clunk); + +int p9_client_remove(struct p9_fid *fid) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + p9_debug(P9_DEBUG_9P, ">>> TREMOVE fid %d\n", fid->fid); + err = 0; + clnt = fid->clnt; + + req = p9_client_rpc(clnt, P9_TREMOVE, "d", fid->fid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RREMOVE fid %d\n", fid->fid); + + p9_req_put(clnt, req); +error: + if (err == -ERESTARTSYS) + p9_fid_put(fid); + else + p9_fid_destroy(fid); + return err; +} +EXPORT_SYMBOL(p9_client_remove); + +int p9_client_unlinkat(struct p9_fid *dfid, const char *name, int flags) +{ + int err = 0; + struct p9_req_t *req; + struct p9_client *clnt; + + p9_debug(P9_DEBUG_9P, ">>> TUNLINKAT fid %d %s %d\n", + dfid->fid, name, flags); + + clnt = dfid->clnt; + req = p9_client_rpc(clnt, P9_TUNLINKAT, "dsd", dfid->fid, name, flags); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RUNLINKAT fid %d %s\n", dfid->fid, name); + + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_unlinkat); + +int +p9_client_read(struct p9_fid *fid, u64 offset, struct iov_iter *to, int *err) +{ + int total = 0; + *err = 0; + + while (iov_iter_count(to)) { + int count; + + count = p9_client_read_once(fid, offset, to, err); + if (!count || *err) + break; + offset += count; + total += count; + } + return total; +} +EXPORT_SYMBOL(p9_client_read); + +int +p9_client_read_once(struct p9_fid *fid, u64 offset, struct iov_iter *to, + int *err) +{ + struct p9_client *clnt = fid->clnt; + struct p9_req_t *req; + int count = iov_iter_count(to); + int rsize, received, non_zc = 0; + char *dataptr; + + *err = 0; + p9_debug(P9_DEBUG_9P, ">>> TREAD fid %d offset %llu %zu\n", + fid->fid, offset, iov_iter_count(to)); + + rsize = fid->iounit; + if (!rsize || rsize > clnt->msize - P9_IOHDRSZ) + rsize = clnt->msize - P9_IOHDRSZ; + + if (count < rsize) + rsize = count; + + /* Don't bother zerocopy for small IO (< 1024) */ + if (clnt->trans_mod->zc_request && rsize > 1024) { + /* response header len is 11 + * PDU Header(7) + IO Size (4) + */ + req = p9_client_zc_rpc(clnt, P9_TREAD, to, NULL, rsize, + 0, 11, "dqd", fid->fid, + offset, rsize); + } else { + 
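+ /* Fall back to an inline read: the payload is returned inside the
+  * regular reply buffer and copied out with copy_to_iter() below.
+  */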
non_zc = 1; + req = p9_client_rpc(clnt, P9_TREAD, "dqd", fid->fid, offset, + rsize); + } + if (IS_ERR(req)) { + *err = PTR_ERR(req); + if (!non_zc) + iov_iter_revert(to, count - iov_iter_count(to)); + return 0; + } + + *err = p9pdu_readf(&req->rc, clnt->proto_version, + "D", &received, &dataptr); + if (*err) { + if (!non_zc) + iov_iter_revert(to, count - iov_iter_count(to)); + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + return 0; + } + if (rsize < received) { + pr_err("bogus RREAD count (%d > %d)\n", received, rsize); + received = rsize; + } + + p9_debug(P9_DEBUG_9P, "<<< RREAD count %d\n", count); + + if (non_zc) { + int n = copy_to_iter(dataptr, received, to); + + if (n != received) { + *err = -EFAULT; + p9_req_put(clnt, req); + return n; + } + } else { + iov_iter_revert(to, count - received - iov_iter_count(to)); + } + p9_req_put(clnt, req); + return received; +} +EXPORT_SYMBOL(p9_client_read_once); + +int +p9_client_write(struct p9_fid *fid, u64 offset, struct iov_iter *from, int *err) +{ + struct p9_client *clnt = fid->clnt; + struct p9_req_t *req; + int total = 0; + *err = 0; + + p9_debug(P9_DEBUG_9P, ">>> TWRITE fid %d offset %llu count %zd\n", + fid->fid, offset, iov_iter_count(from)); + + while (iov_iter_count(from)) { + int count = iov_iter_count(from); + int rsize = fid->iounit; + int written; + + if (!rsize || rsize > clnt->msize - P9_IOHDRSZ) + rsize = clnt->msize - P9_IOHDRSZ; + + if (count < rsize) + rsize = count; + + /* Don't bother zerocopy for small IO (< 1024) */ + if (clnt->trans_mod->zc_request && rsize > 1024) { + req = p9_client_zc_rpc(clnt, P9_TWRITE, NULL, from, 0, + rsize, P9_ZC_HDR_SZ, "dqd", + fid->fid, offset, rsize); + } else { + req = p9_client_rpc(clnt, P9_TWRITE, "dqV", fid->fid, + offset, rsize, from); + } + if (IS_ERR(req)) { + iov_iter_revert(from, count - iov_iter_count(from)); + *err = PTR_ERR(req); + break; + } + + *err = p9pdu_readf(&req->rc, clnt->proto_version, "d", &written); + if (*err) { + iov_iter_revert(from, count - iov_iter_count(from)); + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + break; + } + if (rsize < written) { + pr_err("bogus RWRITE count (%d > %d)\n", written, rsize); + written = rsize; + } + + p9_debug(P9_DEBUG_9P, "<<< RWRITE count %d\n", count); + + p9_req_put(clnt, req); + iov_iter_revert(from, count - written - iov_iter_count(from)); + total += written; + offset += written; + } + return total; +} +EXPORT_SYMBOL(p9_client_write); + +struct p9_wstat *p9_client_stat(struct p9_fid *fid) +{ + int err; + struct p9_client *clnt; + struct p9_wstat *ret; + struct p9_req_t *req; + u16 ignored; + + p9_debug(P9_DEBUG_9P, ">>> TSTAT fid %d\n", fid->fid); + + ret = kmalloc(sizeof(*ret), GFP_KERNEL); + if (!ret) + return ERR_PTR(-ENOMEM); + + err = 0; + clnt = fid->clnt; + + req = p9_client_rpc(clnt, P9_TSTAT, "d", fid->fid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "wS", &ignored, ret); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto error; + } + + p9_debug(P9_DEBUG_9P, + "<<< RSTAT sz=%x type=%x dev=%x qid=%x.%llx.%x\n" + "<<< mode=%8.8x atime=%8.8x mtime=%8.8x length=%llx\n" + "<<< name=%s uid=%s gid=%s muid=%s extension=(%s)\n" + "<<< uid=%d gid=%d n_muid=%d\n", + ret->size, ret->type, ret->dev, ret->qid.type, ret->qid.path, + ret->qid.version, ret->mode, + ret->atime, ret->mtime, ret->length, + ret->name, ret->uid, ret->gid, ret->muid, ret->extension, + from_kuid(&init_user_ns, 
ret->n_uid), + from_kgid(&init_user_ns, ret->n_gid), + from_kuid(&init_user_ns, ret->n_muid)); + + p9_req_put(clnt, req); + return ret; + +error: + kfree(ret); + return ERR_PTR(err); +} +EXPORT_SYMBOL(p9_client_stat); + +struct p9_stat_dotl *p9_client_getattr_dotl(struct p9_fid *fid, + u64 request_mask) +{ + int err; + struct p9_client *clnt; + struct p9_stat_dotl *ret; + struct p9_req_t *req; + + p9_debug(P9_DEBUG_9P, ">>> TGETATTR fid %d, request_mask %lld\n", + fid->fid, request_mask); + + ret = kmalloc(sizeof(*ret), GFP_KERNEL); + if (!ret) + return ERR_PTR(-ENOMEM); + + err = 0; + clnt = fid->clnt; + + req = p9_client_rpc(clnt, P9_TGETATTR, "dq", fid->fid, request_mask); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "A", ret); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RGETATTR st_result_mask=%lld\n" + "<<< qid=%x.%llx.%x\n" + "<<< st_mode=%8.8x st_nlink=%llu\n" + "<<< st_uid=%d st_gid=%d\n" + "<<< st_rdev=%llx st_size=%llx st_blksize=%llu st_blocks=%llu\n" + "<<< st_atime_sec=%lld st_atime_nsec=%lld\n" + "<<< st_mtime_sec=%lld st_mtime_nsec=%lld\n" + "<<< st_ctime_sec=%lld st_ctime_nsec=%lld\n" + "<<< st_btime_sec=%lld st_btime_nsec=%lld\n" + "<<< st_gen=%lld st_data_version=%lld\n", + ret->st_result_mask, + ret->qid.type, ret->qid.path, ret->qid.version, + ret->st_mode, ret->st_nlink, + from_kuid(&init_user_ns, ret->st_uid), + from_kgid(&init_user_ns, ret->st_gid), + ret->st_rdev, ret->st_size, ret->st_blksize, ret->st_blocks, + ret->st_atime_sec, ret->st_atime_nsec, + ret->st_mtime_sec, ret->st_mtime_nsec, + ret->st_ctime_sec, ret->st_ctime_nsec, + ret->st_btime_sec, ret->st_btime_nsec, + ret->st_gen, ret->st_data_version); + + p9_req_put(clnt, req); + return ret; + +error: + kfree(ret); + return ERR_PTR(err); +} +EXPORT_SYMBOL(p9_client_getattr_dotl); + +static int p9_client_statsize(struct p9_wstat *wst, int proto_version) +{ + int ret; + + /* NOTE: size shouldn't include its own length */ + /* size[2] type[2] dev[4] qid[13] */ + /* mode[4] atime[4] mtime[4] length[8]*/ + /* name[s] uid[s] gid[s] muid[s] */ + ret = 2 + 4 + 13 + 4 + 4 + 4 + 8 + 2 + 2 + 2 + 2; + + if (wst->name) + ret += strlen(wst->name); + if (wst->uid) + ret += strlen(wst->uid); + if (wst->gid) + ret += strlen(wst->gid); + if (wst->muid) + ret += strlen(wst->muid); + + if (proto_version == p9_proto_2000u || + proto_version == p9_proto_2000L) { + /* extension[s] n_uid[4] n_gid[4] n_muid[4] */ + ret += 2 + 4 + 4 + 4; + if (wst->extension) + ret += strlen(wst->extension); + } + + return ret; +} + +int p9_client_wstat(struct p9_fid *fid, struct p9_wstat *wst) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + err = 0; + clnt = fid->clnt; + wst->size = p9_client_statsize(wst, clnt->proto_version); + p9_debug(P9_DEBUG_9P, ">>> TWSTAT fid %d\n", + fid->fid); + p9_debug(P9_DEBUG_9P, + " sz=%x type=%x dev=%x qid=%x.%llx.%x\n" + " mode=%8.8x atime=%8.8x mtime=%8.8x length=%llx\n" + " name=%s uid=%s gid=%s muid=%s extension=(%s)\n" + " uid=%d gid=%d n_muid=%d\n", + wst->size, wst->type, wst->dev, wst->qid.type, + wst->qid.path, wst->qid.version, + wst->mode, wst->atime, wst->mtime, wst->length, + wst->name, wst->uid, wst->gid, wst->muid, wst->extension, + from_kuid(&init_user_ns, wst->n_uid), + from_kgid(&init_user_ns, wst->n_gid), + from_kuid(&init_user_ns, wst->n_muid)); + + req = p9_client_rpc(clnt, P9_TWSTAT, "dwS", + fid->fid, wst->size + 2, wst); + 
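/*
 * Worked example (plain C, sample strings made up) for p9_client_statsize()
 * above with the legacy protocol: the fixed fields contribute
 * 2+4+13+4+4+4+8+2+2+2+2 = 47 bytes, which already includes the 2-byte
 * length prefix of each of the four strings, and each string then adds its
 * characters.  TWSTAT sends wst->size + 2 so the outer count also covers
 * the stat's own leading size[2] field.
 */
#include <stdio.h>
#include <string.h>

int main(void)
{
        const int fixed = 2 + 4 + 13 + 4 + 4 + 4 + 8 + 2 + 2 + 2 + 2;  /* 47 */
        const char *name = "file.txt", *uid = "glenda",
                   *gid = "glenda", *muid = "glenda";
        size_t size = fixed + strlen(name) + strlen(uid) +
                      strlen(gid) + strlen(muid);               /* 47 + 26 = 73 */

        printf("wstat size = %zu, on-the-wire count = %zu\n", size, size + 2);
        return 0;
}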
if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RWSTAT fid %d\n", fid->fid); + + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_wstat); + +int p9_client_setattr(struct p9_fid *fid, struct p9_iattr_dotl *p9attr) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, ">>> TSETATTR fid %d\n", fid->fid); + p9_debug(P9_DEBUG_9P, " valid=%x mode=%x uid=%d gid=%d size=%lld\n", + p9attr->valid, p9attr->mode, + from_kuid(&init_user_ns, p9attr->uid), + from_kgid(&init_user_ns, p9attr->gid), + p9attr->size); + p9_debug(P9_DEBUG_9P, " atime_sec=%lld atime_nsec=%lld\n", + p9attr->atime_sec, p9attr->atime_nsec); + p9_debug(P9_DEBUG_9P, " mtime_sec=%lld mtime_nsec=%lld\n", + p9attr->mtime_sec, p9attr->mtime_nsec); + + req = p9_client_rpc(clnt, P9_TSETATTR, "dI", fid->fid, p9attr); + + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RSETATTR fid %d\n", fid->fid); + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_setattr); + +int p9_client_statfs(struct p9_fid *fid, struct p9_rstatfs *sb) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + err = 0; + clnt = fid->clnt; + + p9_debug(P9_DEBUG_9P, ">>> TSTATFS fid %d\n", fid->fid); + + req = p9_client_rpc(clnt, P9_TSTATFS, "d", fid->fid); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "ddqqqqqqd", &sb->type, + &sb->bsize, &sb->blocks, &sb->bfree, &sb->bavail, + &sb->files, &sb->ffree, &sb->fsid, &sb->namelen); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto error; + } + + p9_debug(P9_DEBUG_9P, + "<<< RSTATFS fid %d type 0x%x bsize %u blocks %llu bfree %llu bavail %llu files %llu ffree %llu fsid %llu namelen %u\n", + fid->fid, sb->type, sb->bsize, sb->blocks, sb->bfree, + sb->bavail, sb->files, sb->ffree, sb->fsid, sb->namelen); + + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_statfs); + +int p9_client_rename(struct p9_fid *fid, + struct p9_fid *newdirfid, const char *name) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + err = 0; + clnt = fid->clnt; + + p9_debug(P9_DEBUG_9P, ">>> TRENAME fid %d newdirfid %d name %s\n", + fid->fid, newdirfid->fid, name); + + req = p9_client_rpc(clnt, P9_TRENAME, "dds", fid->fid, + newdirfid->fid, name); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RRENAME fid %d\n", fid->fid); + + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_rename); + +int p9_client_renameat(struct p9_fid *olddirfid, const char *old_name, + struct p9_fid *newdirfid, const char *new_name) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + err = 0; + clnt = olddirfid->clnt; + + p9_debug(P9_DEBUG_9P, + ">>> TRENAMEAT olddirfid %d old name %s newdirfid %d new name %s\n", + olddirfid->fid, old_name, newdirfid->fid, new_name); + + req = p9_client_rpc(clnt, P9_TRENAMEAT, "dsds", olddirfid->fid, + old_name, newdirfid->fid, new_name); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + p9_debug(P9_DEBUG_9P, "<<< RRENAMEAT newdirfid %d new name %s\n", + newdirfid->fid, new_name); + + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_renameat); + +/* An xattrwalk without @attr_name gives the fid for the lisxattr namespace + */ +struct p9_fid *p9_client_xattrwalk(struct p9_fid 
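/*
 * Reference layout only (not wire-accurate as a C struct, since the PDU is
 * packed with no padding): the order and widths behind the "ddqqqqqqd"
 * unpack of RSTATFS above, where 'd' is a little-endian u32 and 'q' a u64.
 */
#include <stdint.h>

struct ex_rstatfs_fields {
        uint32_t type;          /* d */
        uint32_t bsize;         /* d */
        uint64_t blocks;        /* q */
        uint64_t bfree;         /* q */
        uint64_t bavail;        /* q */
        uint64_t files;         /* q */
        uint64_t ffree;         /* q */
        uint64_t fsid;          /* q */
        uint32_t namelen;       /* d */
};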
*file_fid, + const char *attr_name, u64 *attr_size) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + struct p9_fid *attr_fid; + + err = 0; + clnt = file_fid->clnt; + attr_fid = p9_fid_create(clnt); + if (!attr_fid) { + err = -ENOMEM; + goto error; + } + p9_debug(P9_DEBUG_9P, + ">>> TXATTRWALK file_fid %d, attr_fid %d name '%s'\n", + file_fid->fid, attr_fid->fid, attr_name); + + req = p9_client_rpc(clnt, P9_TXATTRWALK, "dds", + file_fid->fid, attr_fid->fid, attr_name); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + err = p9pdu_readf(&req->rc, clnt->proto_version, "q", attr_size); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + p9_req_put(clnt, req); + goto clunk_fid; + } + p9_req_put(clnt, req); + p9_debug(P9_DEBUG_9P, "<<< RXATTRWALK fid %d size %llu\n", + attr_fid->fid, *attr_size); + return attr_fid; +clunk_fid: + p9_fid_put(attr_fid); + attr_fid = NULL; +error: + if (attr_fid && attr_fid != file_fid) + p9_fid_destroy(attr_fid); + + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(p9_client_xattrwalk); + +int p9_client_xattrcreate(struct p9_fid *fid, const char *name, + u64 attr_size, int flags) +{ + int err; + struct p9_req_t *req; + struct p9_client *clnt; + + p9_debug(P9_DEBUG_9P, + ">>> TXATTRCREATE fid %d name %s size %llu flag %d\n", + fid->fid, name, attr_size, flags); + err = 0; + clnt = fid->clnt; + req = p9_client_rpc(clnt, P9_TXATTRCREATE, "dsqd", + fid->fid, name, attr_size, flags); + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RXATTRCREATE fid %d\n", fid->fid); + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL_GPL(p9_client_xattrcreate); + +int p9_client_readdir(struct p9_fid *fid, char *data, u32 count, u64 offset) +{ + int err, rsize, non_zc = 0; + struct p9_client *clnt; + struct p9_req_t *req; + char *dataptr; + struct kvec kv = {.iov_base = data, .iov_len = count}; + struct iov_iter to; + + iov_iter_kvec(&to, ITER_DEST, &kv, 1, count); + + p9_debug(P9_DEBUG_9P, ">>> TREADDIR fid %d offset %llu count %d\n", + fid->fid, offset, count); + + err = 0; + clnt = fid->clnt; + + rsize = fid->iounit; + if (!rsize || rsize > clnt->msize - P9_READDIRHDRSZ) + rsize = clnt->msize - P9_READDIRHDRSZ; + + if (count < rsize) + rsize = count; + + /* Don't bother zerocopy for small IO (< 1024) */ + if (clnt->trans_mod->zc_request && rsize > 1024) { + /* response header len is 11 + * PDU Header(7) + IO Size (4) + */ + req = p9_client_zc_rpc(clnt, P9_TREADDIR, &to, NULL, rsize, 0, + 11, "dqd", fid->fid, offset, rsize); + } else { + non_zc = 1; + req = p9_client_rpc(clnt, P9_TREADDIR, "dqd", fid->fid, + offset, rsize); + } + if (IS_ERR(req)) { + err = PTR_ERR(req); + goto error; + } + + err = p9pdu_readf(&req->rc, clnt->proto_version, "D", &count, &dataptr); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto free_and_error; + } + if (rsize < count) { + pr_err("bogus RREADDIR count (%d > %d)\n", count, rsize); + count = rsize; + } + + p9_debug(P9_DEBUG_9P, "<<< RREADDIR count %d\n", count); + + if (non_zc) + memmove(data, dataptr, count); + + p9_req_put(clnt, req); + return count; + +free_and_error: + p9_req_put(clnt, req); +error: + return err; +} +EXPORT_SYMBOL(p9_client_readdir); + +int p9_client_mknod_dotl(struct p9_fid *fid, const char *name, int mode, + dev_t rdev, kgid_t gid, struct p9_qid *qid) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, + ">>> TMKNOD fid %d name %s mode %d major %d minor %d\n", + 
fid->fid, name, mode, MAJOR(rdev), MINOR(rdev)); + req = p9_client_rpc(clnt, P9_TMKNOD, "dsdddg", fid->fid, name, mode, + MAJOR(rdev), MINOR(rdev), gid); + if (IS_ERR(req)) + return PTR_ERR(req); + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RMKNOD qid %x.%llx.%x\n", + qid->type, qid->path, qid->version); + +error: + p9_req_put(clnt, req); + return err; +} +EXPORT_SYMBOL(p9_client_mknod_dotl); + +int p9_client_mkdir_dotl(struct p9_fid *fid, const char *name, int mode, + kgid_t gid, struct p9_qid *qid) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, ">>> TMKDIR fid %d name %s mode %d gid %d\n", + fid->fid, name, mode, from_kgid(&init_user_ns, gid)); + req = p9_client_rpc(clnt, P9_TMKDIR, "dsdg", + fid->fid, name, mode, gid); + if (IS_ERR(req)) + return PTR_ERR(req); + + err = p9pdu_readf(&req->rc, clnt->proto_version, "Q", qid); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RMKDIR qid %x.%llx.%x\n", qid->type, + qid->path, qid->version); + +error: + p9_req_put(clnt, req); + return err; +} +EXPORT_SYMBOL(p9_client_mkdir_dotl); + +int p9_client_lock_dotl(struct p9_fid *fid, struct p9_flock *flock, u8 *status) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, + ">>> TLOCK fid %d type %i flags %d start %lld length %lld proc_id %d client_id %s\n", + fid->fid, flock->type, flock->flags, flock->start, + flock->length, flock->proc_id, flock->client_id); + + req = p9_client_rpc(clnt, P9_TLOCK, "dbdqqds", fid->fid, flock->type, + flock->flags, flock->start, flock->length, + flock->proc_id, flock->client_id); + + if (IS_ERR(req)) + return PTR_ERR(req); + + err = p9pdu_readf(&req->rc, clnt->proto_version, "b", status); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RLOCK status %i\n", *status); +error: + p9_req_put(clnt, req); + return err; +} +EXPORT_SYMBOL(p9_client_lock_dotl); + +int p9_client_getlock_dotl(struct p9_fid *fid, struct p9_getlock *glock) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, + ">>> TGETLOCK fid %d, type %i start %lld length %lld proc_id %d client_id %s\n", + fid->fid, glock->type, glock->start, glock->length, + glock->proc_id, glock->client_id); + + req = p9_client_rpc(clnt, P9_TGETLOCK, "dbqqds", fid->fid, + glock->type, glock->start, glock->length, + glock->proc_id, glock->client_id); + + if (IS_ERR(req)) + return PTR_ERR(req); + + err = p9pdu_readf(&req->rc, clnt->proto_version, "bqqds", &glock->type, + &glock->start, &glock->length, &glock->proc_id, + &glock->client_id); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto error; + } + p9_debug(P9_DEBUG_9P, + "<<< RGETLOCK type %i start %lld length %lld proc_id %d client_id %s\n", + glock->type, glock->start, glock->length, + glock->proc_id, glock->client_id); +error: + p9_req_put(clnt, req); + return err; +} +EXPORT_SYMBOL(p9_client_getlock_dotl); + +int p9_client_readlink(struct p9_fid *fid, char **target) +{ + int err; + struct p9_client *clnt; + struct p9_req_t *req; + + err = 0; + clnt = fid->clnt; + p9_debug(P9_DEBUG_9P, ">>> TREADLINK fid %d\n", fid->fid); + + req = p9_client_rpc(clnt, P9_TREADLINK, "d", fid->fid); + if (IS_ERR(req)) + return PTR_ERR(req); + + err = 
p9pdu_readf(&req->rc, clnt->proto_version, "s", target); + if (err) { + trace_9p_protocol_dump(clnt, &req->rc); + goto error; + } + p9_debug(P9_DEBUG_9P, "<<< RREADLINK target %s\n", *target); +error: + p9_req_put(clnt, req); + return err; +} +EXPORT_SYMBOL(p9_client_readlink); + +int __init p9_client_init(void) +{ + p9_req_cache = KMEM_CACHE(p9_req_t, SLAB_TYPESAFE_BY_RCU); + return p9_req_cache ? 0 : -ENOMEM; +} + +void __exit p9_client_exit(void) +{ + kmem_cache_destroy(p9_req_cache); +} diff --git a/net/9p/error.c b/net/9p/error.c new file mode 100644 index 000000000..8da744494 --- /dev/null +++ b/net/9p/error.c @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Error string handling + * + * Plan 9 uses error strings, Unix uses error numbers. These functions + * try to help manage that and provide for dynamically adding error + * mappings. + * + * Copyright (C) 2004 by Eric Van Hensbergen + * Copyright (C) 2002 by Ron Minnich + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +/** + * struct errormap - map string errors from Plan 9 to Linux numeric ids + * @name: string sent over 9P + * @val: numeric id most closely representing @name + * @namelen: length of string + * @list: hash-table list for string lookup + */ +struct errormap { + char *name; + int val; + + int namelen; + struct hlist_node list; +}; + +#define ERRHASHSZ 32 +static struct hlist_head hash_errmap[ERRHASHSZ]; + +/* FixMe - reduce to a reasonable size */ +static struct errormap errmap[] = { + {"Operation not permitted", EPERM}, + {"wstat prohibited", EPERM}, + {"No such file or directory", ENOENT}, + {"directory entry not found", ENOENT}, + {"file not found", ENOENT}, + {"Interrupted system call", EINTR}, + {"Input/output error", EIO}, + {"No such device or address", ENXIO}, + {"Argument list too long", E2BIG}, + {"Bad file descriptor", EBADF}, + {"Resource temporarily unavailable", EAGAIN}, + {"Cannot allocate memory", ENOMEM}, + {"Permission denied", EACCES}, + {"Bad address", EFAULT}, + {"Block device required", ENOTBLK}, + {"Device or resource busy", EBUSY}, + {"File exists", EEXIST}, + {"Invalid cross-device link", EXDEV}, + {"No such device", ENODEV}, + {"Not a directory", ENOTDIR}, + {"Is a directory", EISDIR}, + {"Invalid argument", EINVAL}, + {"Too many open files in system", ENFILE}, + {"Too many open files", EMFILE}, + {"Text file busy", ETXTBSY}, + {"File too large", EFBIG}, + {"No space left on device", ENOSPC}, + {"Illegal seek", ESPIPE}, + {"Read-only file system", EROFS}, + {"Too many links", EMLINK}, + {"Broken pipe", EPIPE}, + {"Numerical argument out of domain", EDOM}, + {"Numerical result out of range", ERANGE}, + {"Resource deadlock avoided", EDEADLK}, + {"File name too long", ENAMETOOLONG}, + {"No locks available", ENOLCK}, + {"Function not implemented", ENOSYS}, + {"Directory not empty", ENOTEMPTY}, + {"Too many levels of symbolic links", ELOOP}, + {"No message of desired type", ENOMSG}, + {"Identifier removed", EIDRM}, + {"No data available", ENODATA}, + {"Machine is not on the network", ENONET}, + {"Package not installed", ENOPKG}, + {"Object is remote", EREMOTE}, + {"Link has been severed", ENOLINK}, + {"Communication error on send", ECOMM}, + {"Protocol error", EPROTO}, + {"Bad message", EBADMSG}, + {"File descriptor in bad state", EBADFD}, + {"Streams pipe error", ESTRPIPE}, + {"Too many users", EUSERS}, + {"Socket operation on non-socket", ENOTSOCK}, + {"Message too long", EMSGSIZE}, + {"Protocol not available", 
ENOPROTOOPT}, + {"Protocol not supported", EPROTONOSUPPORT}, + {"Socket type not supported", ESOCKTNOSUPPORT}, + {"Operation not supported", EOPNOTSUPP}, + {"Protocol family not supported", EPFNOSUPPORT}, + {"Network is down", ENETDOWN}, + {"Network is unreachable", ENETUNREACH}, + {"Network dropped connection on reset", ENETRESET}, + {"Software caused connection abort", ECONNABORTED}, + {"Connection reset by peer", ECONNRESET}, + {"No buffer space available", ENOBUFS}, + {"Transport endpoint is already connected", EISCONN}, + {"Transport endpoint is not connected", ENOTCONN}, + {"Cannot send after transport endpoint shutdown", ESHUTDOWN}, + {"Connection timed out", ETIMEDOUT}, + {"Connection refused", ECONNREFUSED}, + {"Host is down", EHOSTDOWN}, + {"No route to host", EHOSTUNREACH}, + {"Operation already in progress", EALREADY}, + {"Operation now in progress", EINPROGRESS}, + {"Is a named type file", EISNAM}, + {"Remote I/O error", EREMOTEIO}, + {"Disk quota exceeded", EDQUOT}, +/* errors from fossil, vacfs, and u9fs */ + {"fid unknown or out of range", EBADF}, + {"permission denied", EACCES}, + {"file does not exist", ENOENT}, + {"authentication failed", ECONNREFUSED}, + {"bad offset in directory read", ESPIPE}, + {"bad use of fid", EBADF}, + {"wstat can't convert between files and directories", EPERM}, + {"directory is not empty", ENOTEMPTY}, + {"file exists", EEXIST}, + {"file already exists", EEXIST}, + {"file or directory already exists", EEXIST}, + {"fid already in use", EBADF}, + {"file in use", ETXTBSY}, + {"i/o error", EIO}, + {"file already open for I/O", ETXTBSY}, + {"illegal mode", EINVAL}, + {"illegal name", ENAMETOOLONG}, + {"not a directory", ENOTDIR}, + {"not a member of proposed group", EPERM}, + {"not owner", EACCES}, + {"only owner can change group in wstat", EACCES}, + {"read only file system", EROFS}, + {"no access to special file", EPERM}, + {"i/o count too large", EIO}, + {"unknown group", EINVAL}, + {"unknown user", EINVAL}, + {"bogus wstat buffer", EPROTO}, + {"exclusive use file already open", EAGAIN}, + {"corrupted directory entry", EIO}, + {"corrupted file entry", EIO}, + {"corrupted block label", EIO}, + {"corrupted meta data", EIO}, + {"illegal offset", EINVAL}, + {"illegal path element", ENOENT}, + {"root of file system is corrupted", EIO}, + {"corrupted super block", EIO}, + {"protocol botch", EPROTO}, + {"file system is full", ENOSPC}, + {"file is in use", EAGAIN}, + {"directory entry is not allocated", ENOENT}, + {"file is read only", EROFS}, + {"file has been removed", EIDRM}, + {"only support truncation to zero length", EPERM}, + {"cannot remove root", EPERM}, + {"file too big", EFBIG}, + {"venti i/o error", EIO}, + /* these are not errors */ + {"u9fs rhostsauth: no authentication required", 0}, + {"u9fs authnone: no authentication required", 0}, + {NULL, -1} +}; + +/** + * p9_error_init - preload mappings into hash list + * + */ + +int p9_error_init(void) +{ + struct errormap *c; + int bucket; + + /* initialize hash table */ + for (bucket = 0; bucket < ERRHASHSZ; bucket++) + INIT_HLIST_HEAD(&hash_errmap[bucket]); + + /* load initial error map into hash table */ + for (c = errmap; c->name; c++) { + c->namelen = strlen(c->name); + bucket = jhash(c->name, c->namelen, 0) % ERRHASHSZ; + INIT_HLIST_NODE(&c->list); + hlist_add_head(&c->list, &hash_errmap[bucket]); + } + + return 1; +} +EXPORT_SYMBOL(p9_error_init); + +/** + * p9_errstr2errno - convert error string to error number + * @errstr: error string + * @len: length of error string + * + */ + +int 
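/*
 * Illustrative sketch in plain C, not the kernel implementation: the table
 * above maps Plan 9 error strings to errno values, p9_error_init() above
 * buckets them with jhash for cheap lookups, and p9_errstr2errno(), which
 * continues below, walks the matching bucket.  This simplified version uses
 * a linear scan and falls back to -EIO where the kernel code returns
 * -ESERVERFAULT; the ex_* names and the short table are made up.
 */
#include <errno.h>
#include <string.h>

struct ex_errmap { const char *name; int val; };

static const struct ex_errmap ex_errtab[] = {
        { "permission denied",          EACCES },
        { "file does not exist",        ENOENT },
        { "i/o error",                  EIO },
};

/* Convert a counted (not NUL-terminated) error string to a negative errno. */
static int ex_errstr2errno(const char *errstr, size_t len)
{
        size_t i;

        for (i = 0; i < sizeof(ex_errtab) / sizeof(ex_errtab[0]); i++)
                if (strlen(ex_errtab[i].name) == len &&
                    !memcmp(ex_errtab[i].name, errstr, len))
                        return -ex_errtab[i].val;
        return -EIO;    /* stand-in for the unknown-error fallback */
}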
p9_errstr2errno(char *errstr, int len) +{ + int errno; + struct errormap *c; + int bucket; + + errno = 0; + c = NULL; + bucket = jhash(errstr, len, 0) % ERRHASHSZ; + hlist_for_each_entry(c, &hash_errmap[bucket], list) { + if (c->namelen == len && !memcmp(c->name, errstr, len)) { + errno = c->val; + break; + } + } + + if (errno == 0) { + /* TODO: if error isn't found, add it dynamically */ + errstr[len] = 0; + pr_err("%s: server reported unknown error %s\n", + __func__, errstr); + errno = ESERVERFAULT; + } + + return -errno; +} +EXPORT_SYMBOL(p9_errstr2errno); diff --git a/net/9p/mod.c b/net/9p/mod.c new file mode 100644 index 000000000..55576c186 --- /dev/null +++ b/net/9p/mod.c @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 9P entry point + * + * Copyright (C) 2007 by Latchesar Ionkov + * Copyright (C) 2004 by Eric Van Hensbergen + * Copyright (C) 2002 by Ron Minnich + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_NET_9P_DEBUG +unsigned int p9_debug_level; /* feature-rific global debug level */ +EXPORT_SYMBOL(p9_debug_level); +module_param_named(debug, p9_debug_level, uint, 0); +MODULE_PARM_DESC(debug, "9P debugging level"); + +void _p9_debug(enum p9_debug_flags level, const char *func, + const char *fmt, ...) +{ + struct va_format vaf; + va_list args; + + if ((p9_debug_level & level) != level) + return; + + va_start(args, fmt); + + vaf.fmt = fmt; + vaf.va = &args; + + if (level == P9_DEBUG_9P) + pr_notice("(%8.8d) %pV", task_pid_nr(current), &vaf); + else + pr_notice("-- %s (%d): %pV", func, task_pid_nr(current), &vaf); + + va_end(args); +} +EXPORT_SYMBOL(_p9_debug); +#endif + +/* Dynamic Transport Registration Routines */ + +static DEFINE_SPINLOCK(v9fs_trans_lock); +static LIST_HEAD(v9fs_trans_list); + +/** + * v9fs_register_trans - register a new transport with 9p + * @m: structure describing the transport module and entry points + * + */ +void v9fs_register_trans(struct p9_trans_module *m) +{ + spin_lock(&v9fs_trans_lock); + list_add_tail(&m->list, &v9fs_trans_list); + spin_unlock(&v9fs_trans_lock); +} +EXPORT_SYMBOL(v9fs_register_trans); + +/** + * v9fs_unregister_trans - unregister a 9p transport + * @m: the transport to remove + * + */ +void v9fs_unregister_trans(struct p9_trans_module *m) +{ + spin_lock(&v9fs_trans_lock); + list_del_init(&m->list); + spin_unlock(&v9fs_trans_lock); +} +EXPORT_SYMBOL(v9fs_unregister_trans); + +static struct p9_trans_module *_p9_get_trans_by_name(const char *s) +{ + struct p9_trans_module *t, *found = NULL; + + spin_lock(&v9fs_trans_lock); + + list_for_each_entry(t, &v9fs_trans_list, list) + if (strcmp(t->name, s) == 0 && + try_module_get(t->owner)) { + found = t; + break; + } + + spin_unlock(&v9fs_trans_lock); + + return found; +} + +/** + * v9fs_get_trans_by_name - get transport with the matching name + * @s: string identifying transport + * + */ +struct p9_trans_module *v9fs_get_trans_by_name(const char *s) +{ + struct p9_trans_module *found = NULL; + + found = _p9_get_trans_by_name(s); + +#ifdef CONFIG_MODULES + if (!found) { + request_module("9p-%s", s); + found = _p9_get_trans_by_name(s); + } +#endif + + return found; +} +EXPORT_SYMBOL(v9fs_get_trans_by_name); + +static const char * const v9fs_default_transports[] = { + "virtio", "tcp", "fd", "unix", "xen", "rdma", +}; + +/** + * v9fs_get_default_trans - get the default transport + * + */ + +struct p9_trans_module 
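/*
 * Illustrative selection logic in plain C (ex_* names made up): it mirrors
 * how v9fs_get_trans_by_name() above matches a requested transport and how
 * v9fs_get_default_trans(), continued below, prefers a transport flagged as
 * default before falling back to another usable one.  Module reference
 * counting (try_module_get), locking and request_module() are left out.
 */
#include <stddef.h>
#include <string.h>

struct ex_trans { const char *name; int is_default; };

static const struct ex_trans ex_transports[] = {
        { "virtio", 1 }, { "tcp", 0 }, { "fd", 0 },
};

static const struct ex_trans *ex_get_trans(const char *requested)
{
        size_t i, n = sizeof(ex_transports) / sizeof(ex_transports[0]);

        if (requested) {
                for (i = 0; i < n; i++)
                        if (!strcmp(ex_transports[i].name, requested))
                                return &ex_transports[i];
                return NULL;    /* kernel would try request_module("9p-<name>") */
        }
        for (i = 0; i < n; i++)
                if (ex_transports[i].is_default)
                        return &ex_transports[i];
        return &ex_transports[0];
}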
*v9fs_get_default_trans(void) +{ + struct p9_trans_module *t, *found = NULL; + int i; + + spin_lock(&v9fs_trans_lock); + + list_for_each_entry(t, &v9fs_trans_list, list) + if (t->def && try_module_get(t->owner)) { + found = t; + break; + } + + if (!found) + list_for_each_entry(t, &v9fs_trans_list, list) + if (try_module_get(t->owner)) { + found = t; + break; + } + + spin_unlock(&v9fs_trans_lock); + + for (i = 0; !found && i < ARRAY_SIZE(v9fs_default_transports); i++) + found = v9fs_get_trans_by_name(v9fs_default_transports[i]); + + return found; +} +EXPORT_SYMBOL(v9fs_get_default_trans); + +/** + * v9fs_put_trans - put trans + * @m: transport to put + * + */ +void v9fs_put_trans(struct p9_trans_module *m) +{ + if (m) + module_put(m->owner); +} + +/** + * init_p9 - Initialize module + * + */ +static int __init init_p9(void) +{ + int ret; + + ret = p9_client_init(); + if (ret) + return ret; + + p9_error_init(); + pr_info("Installing 9P2000 support\n"); + + return ret; +} + +/** + * exit_p9 - shutdown module + * + */ + +static void __exit exit_p9(void) +{ + pr_info("Unloading 9P2000 support\n"); + + p9_client_exit(); +} + +module_init(init_p9) +module_exit(exit_p9) + +MODULE_AUTHOR("Latchesar Ionkov "); +MODULE_AUTHOR("Eric Van Hensbergen "); +MODULE_AUTHOR("Ron Minnich "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Plan 9 Resource Sharing Support (9P2000)"); diff --git a/net/9p/protocol.c b/net/9p/protocol.c new file mode 100644 index 000000000..0e6603b1e --- /dev/null +++ b/net/9p/protocol.c @@ -0,0 +1,801 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 9P Protocol Support Code + * + * Copyright (C) 2008 by Eric Van Hensbergen + * + * Base on code from Anthony Liguori + * Copyright (C) 2008 by IBM, Corp. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "protocol.h" + +#include + +/* len[2] text[len] */ +#define P9_STRLEN(s) \ + (2 + min_t(size_t, s ? strlen(s) : 0, USHRT_MAX)) + +/** + * p9_msg_buf_size - Returns a buffer size sufficiently large to hold the + * intended 9p message. + * @c: client + * @type: message type + * @fmt: format template for assembling request message + * (see p9pdu_vwritef) + * @ap: variable arguments to be fed to passed format template + * (see p9pdu_vwritef) + * + * Note: Even for response types (P9_R*) the format template and variable + * arguments must always be for the originating request type (P9_T*). + */ +size_t p9_msg_buf_size(struct p9_client *c, enum p9_msg_t type, + const char *fmt, va_list ap) +{ + /* size[4] type[1] tag[2] */ + const int hdr = 4 + 1 + 2; + /* ename[s] errno[4] */ + const int rerror_size = hdr + P9_ERRMAX + 4; + /* ecode[4] */ + const int rlerror_size = hdr + 4; + const int err_size = + c->proto_version == p9_proto_2000L ? rlerror_size : rerror_size; + + static_assert(NAME_MAX <= 4*1024, "p9_msg_buf_size() currently assumes " + "a max. 
allowed directory entry name length of 4k"); + + switch (type) { + + /* message types not used at all */ + case P9_TERROR: + case P9_TLERROR: + case P9_TAUTH: + case P9_RAUTH: + BUG(); + + /* variable length & potentially large message types */ + case P9_TATTACH: + BUG_ON(strcmp("ddss?u", fmt)); + va_arg(ap, int32_t); + va_arg(ap, int32_t); + { + const char *uname = va_arg(ap, const char *); + const char *aname = va_arg(ap, const char *); + /* fid[4] afid[4] uname[s] aname[s] n_uname[4] */ + return hdr + 4 + 4 + P9_STRLEN(uname) + P9_STRLEN(aname) + 4; + } + case P9_TWALK: + BUG_ON(strcmp("ddT", fmt)); + va_arg(ap, int32_t); + va_arg(ap, int32_t); + { + uint i, nwname = va_arg(ap, int); + size_t wname_all; + const char **wnames = va_arg(ap, const char **); + for (i = 0, wname_all = 0; i < nwname; ++i) { + wname_all += P9_STRLEN(wnames[i]); + } + /* fid[4] newfid[4] nwname[2] nwname*(wname[s]) */ + return hdr + 4 + 4 + 2 + wname_all; + } + case P9_RWALK: + BUG_ON(strcmp("ddT", fmt)); + va_arg(ap, int32_t); + va_arg(ap, int32_t); + { + uint nwname = va_arg(ap, int); + /* nwqid[2] nwqid*(wqid[13]) */ + return max_t(size_t, hdr + 2 + nwname * 13, err_size); + } + case P9_TCREATE: + BUG_ON(strcmp("dsdb?s", fmt)); + va_arg(ap, int32_t); + { + const char *name = va_arg(ap, const char *); + if (c->proto_version == p9_proto_legacy) { + /* fid[4] name[s] perm[4] mode[1] */ + return hdr + 4 + P9_STRLEN(name) + 4 + 1; + } else { + va_arg(ap, int32_t); + va_arg(ap, int); + { + const char *ext = va_arg(ap, const char *); + /* fid[4] name[s] perm[4] mode[1] extension[s] */ + return hdr + 4 + P9_STRLEN(name) + 4 + 1 + P9_STRLEN(ext); + } + } + } + case P9_TLCREATE: + BUG_ON(strcmp("dsddg", fmt)); + va_arg(ap, int32_t); + { + const char *name = va_arg(ap, const char *); + /* fid[4] name[s] flags[4] mode[4] gid[4] */ + return hdr + 4 + P9_STRLEN(name) + 4 + 4 + 4; + } + case P9_RREAD: + case P9_RREADDIR: + BUG_ON(strcmp("dqd", fmt)); + va_arg(ap, int32_t); + va_arg(ap, int64_t); + { + const int32_t count = va_arg(ap, int32_t); + /* count[4] data[count] */ + return max_t(size_t, hdr + 4 + count, err_size); + } + case P9_TWRITE: + BUG_ON(strcmp("dqV", fmt)); + va_arg(ap, int32_t); + va_arg(ap, int64_t); + { + const int32_t count = va_arg(ap, int32_t); + /* fid[4] offset[8] count[4] data[count] */ + return hdr + 4 + 8 + 4 + count; + } + case P9_TRENAMEAT: + BUG_ON(strcmp("dsds", fmt)); + va_arg(ap, int32_t); + { + const char *oldname, *newname; + oldname = va_arg(ap, const char *); + va_arg(ap, int32_t); + newname = va_arg(ap, const char *); + /* olddirfid[4] oldname[s] newdirfid[4] newname[s] */ + return hdr + 4 + P9_STRLEN(oldname) + 4 + P9_STRLEN(newname); + } + case P9_TSYMLINK: + BUG_ON(strcmp("dssg", fmt)); + va_arg(ap, int32_t); + { + const char *name = va_arg(ap, const char *); + const char *symtgt = va_arg(ap, const char *); + /* fid[4] name[s] symtgt[s] gid[4] */ + return hdr + 4 + P9_STRLEN(name) + P9_STRLEN(symtgt) + 4; + } + + case P9_RERROR: + return rerror_size; + case P9_RLERROR: + return rlerror_size; + + /* small message types */ + case P9_TWSTAT: + case P9_RSTAT: + case P9_RREADLINK: + case P9_TXATTRWALK: + case P9_TXATTRCREATE: + case P9_TLINK: + case P9_TMKDIR: + case P9_TMKNOD: + case P9_TRENAME: + case P9_TUNLINKAT: + case P9_TLOCK: + return 8 * 1024; + + /* tiny message types */ + default: + return 4 * 1024; + + } +} + +static int +p9pdu_writef(struct p9_fcall *pdu, int proto_version, const char *fmt, ...); + +void p9stat_free(struct p9_wstat *stbuf) +{ + kfree(stbuf->name); + 
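/*
 * Worked example (plain C, names made up) for the P9_TWALK case above:
 * size[4] type[1] tag[2] fid[4] newfid[4] nwname[2] nwname*(len[2] + name).
 * Walking to "usr", "glenda", "lib" gives 7 + 4 + 4 + 2 + 5 + 8 + 5 = 35.
 */
#include <stddef.h>
#include <string.h>

static size_t ex_twalk_size(const char * const *wnames, unsigned int nwname)
{
        size_t sz = (4 + 1 + 2) + 4 + 4 + 2;    /* header + fid + newfid + nwname */
        unsigned int i;

        for (i = 0; i < nwname; i++)
                sz += 2 + strlen(wnames[i]);
        return sz;
}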
stbuf->name = NULL; + kfree(stbuf->uid); + stbuf->uid = NULL; + kfree(stbuf->gid); + stbuf->gid = NULL; + kfree(stbuf->muid); + stbuf->muid = NULL; + kfree(stbuf->extension); + stbuf->extension = NULL; +} +EXPORT_SYMBOL(p9stat_free); + +size_t pdu_read(struct p9_fcall *pdu, void *data, size_t size) +{ + size_t len = min(pdu->size - pdu->offset, size); + + memcpy(data, &pdu->sdata[pdu->offset], len); + pdu->offset += len; + return size - len; +} + +static size_t pdu_write(struct p9_fcall *pdu, const void *data, size_t size) +{ + size_t len = min(pdu->capacity - pdu->size, size); + + memcpy(&pdu->sdata[pdu->size], data, len); + pdu->size += len; + return size - len; +} + +static size_t +pdu_write_u(struct p9_fcall *pdu, struct iov_iter *from, size_t size) +{ + size_t len = min(pdu->capacity - pdu->size, size); + + if (!copy_from_iter_full(&pdu->sdata[pdu->size], len, from)) + len = 0; + + pdu->size += len; + return size - len; +} + +/* b - int8_t + * w - int16_t + * d - int32_t + * q - int64_t + * s - string + * u - numeric uid + * g - numeric gid + * S - stat + * Q - qid + * D - data blob (int32_t size followed by void *, results are not freed) + * T - array of strings (int16_t count, followed by strings) + * R - array of qids (int16_t count, followed by qids) + * A - stat for 9p2000.L (p9_stat_dotl) + * ? - if optional = 1, continue parsing + */ + +static int +p9pdu_vreadf(struct p9_fcall *pdu, int proto_version, const char *fmt, + va_list ap) +{ + const char *ptr; + int errcode = 0; + + for (ptr = fmt; *ptr; ptr++) { + switch (*ptr) { + case 'b':{ + int8_t *val = va_arg(ap, int8_t *); + if (pdu_read(pdu, val, sizeof(*val))) { + errcode = -EFAULT; + break; + } + } + break; + case 'w':{ + int16_t *val = va_arg(ap, int16_t *); + __le16 le_val; + if (pdu_read(pdu, &le_val, sizeof(le_val))) { + errcode = -EFAULT; + break; + } + *val = le16_to_cpu(le_val); + } + break; + case 'd':{ + int32_t *val = va_arg(ap, int32_t *); + __le32 le_val; + if (pdu_read(pdu, &le_val, sizeof(le_val))) { + errcode = -EFAULT; + break; + } + *val = le32_to_cpu(le_val); + } + break; + case 'q':{ + int64_t *val = va_arg(ap, int64_t *); + __le64 le_val; + if (pdu_read(pdu, &le_val, sizeof(le_val))) { + errcode = -EFAULT; + break; + } + *val = le64_to_cpu(le_val); + } + break; + case 's':{ + char **sptr = va_arg(ap, char **); + uint16_t len; + + errcode = p9pdu_readf(pdu, proto_version, + "w", &len); + if (errcode) + break; + + *sptr = kmalloc(len + 1, GFP_NOFS); + if (*sptr == NULL) { + errcode = -ENOMEM; + break; + } + if (pdu_read(pdu, *sptr, len)) { + errcode = -EFAULT; + kfree(*sptr); + *sptr = NULL; + } else + (*sptr)[len] = 0; + } + break; + case 'u': { + kuid_t *uid = va_arg(ap, kuid_t *); + __le32 le_val; + if (pdu_read(pdu, &le_val, sizeof(le_val))) { + errcode = -EFAULT; + break; + } + *uid = make_kuid(&init_user_ns, + le32_to_cpu(le_val)); + } break; + case 'g': { + kgid_t *gid = va_arg(ap, kgid_t *); + __le32 le_val; + if (pdu_read(pdu, &le_val, sizeof(le_val))) { + errcode = -EFAULT; + break; + } + *gid = make_kgid(&init_user_ns, + le32_to_cpu(le_val)); + } break; + case 'Q':{ + struct p9_qid *qid = + va_arg(ap, struct p9_qid *); + + errcode = p9pdu_readf(pdu, proto_version, "bdq", + &qid->type, &qid->version, + &qid->path); + } + break; + case 'S':{ + struct p9_wstat *stbuf = + va_arg(ap, struct p9_wstat *); + + memset(stbuf, 0, sizeof(struct p9_wstat)); + stbuf->n_uid = stbuf->n_muid = INVALID_UID; + stbuf->n_gid = INVALID_GID; + + errcode = + p9pdu_readf(pdu, proto_version, + "wwdQdddqssss?sugu", + 
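/*
 * Byte-level illustration (plain C, ex_* names made up) of the 'd' and 's'
 * format letters described in the legend above: scalars go out little
 * endian, and a string is a 16-bit length followed by its bytes with no
 * terminating NUL.  The bounds checking that pdu_write() performs is
 * omitted here.
 */
#include <stddef.h>
#include <stdint.h>
#include <string.h>

static size_t ex_put_d(unsigned char *buf, uint32_t v)
{
        buf[0] = v & 0xff;
        buf[1] = (v >> 8) & 0xff;
        buf[2] = (v >> 16) & 0xff;
        buf[3] = (v >> 24) & 0xff;
        return 4;
}

static size_t ex_put_s(unsigned char *buf, const char *s)
{
        uint16_t len = (uint16_t)strlen(s);

        buf[0] = len & 0xff;            /* the 'w'-style 16-bit length prefix */
        buf[1] = (len >> 8) & 0xff;
        memcpy(buf + 2, s, len);
        return 2 + (size_t)len;
}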
&stbuf->size, &stbuf->type, + &stbuf->dev, &stbuf->qid, + &stbuf->mode, &stbuf->atime, + &stbuf->mtime, &stbuf->length, + &stbuf->name, &stbuf->uid, + &stbuf->gid, &stbuf->muid, + &stbuf->extension, + &stbuf->n_uid, &stbuf->n_gid, + &stbuf->n_muid); + if (errcode) + p9stat_free(stbuf); + } + break; + case 'D':{ + uint32_t *count = va_arg(ap, uint32_t *); + void **data = va_arg(ap, void **); + + errcode = + p9pdu_readf(pdu, proto_version, "d", count); + if (!errcode) { + *count = + min_t(uint32_t, *count, + pdu->size - pdu->offset); + *data = &pdu->sdata[pdu->offset]; + } + } + break; + case 'T':{ + uint16_t *nwname = va_arg(ap, uint16_t *); + char ***wnames = va_arg(ap, char ***); + + *wnames = NULL; + + errcode = p9pdu_readf(pdu, proto_version, + "w", nwname); + if (!errcode) { + *wnames = + kmalloc_array(*nwname, + sizeof(char *), + GFP_NOFS); + if (!*wnames) + errcode = -ENOMEM; + else + (*wnames)[0] = NULL; + } + + if (!errcode) { + int i; + + for (i = 0; i < *nwname; i++) { + errcode = + p9pdu_readf(pdu, + proto_version, + "s", + &(*wnames)[i]); + if (errcode) { + (*wnames)[i] = NULL; + break; + } + } + } + + if (errcode) { + if (*wnames) { + int i; + + for (i = 0; i < *nwname; i++) { + if (!(*wnames)[i]) + break; + kfree((*wnames)[i]); + } + kfree(*wnames); + *wnames = NULL; + } + } + } + break; + case 'R':{ + uint16_t *nwqid = va_arg(ap, uint16_t *); + struct p9_qid **wqids = + va_arg(ap, struct p9_qid **); + + *wqids = NULL; + + errcode = + p9pdu_readf(pdu, proto_version, "w", nwqid); + if (!errcode) { + *wqids = + kmalloc_array(*nwqid, + sizeof(struct p9_qid), + GFP_NOFS); + if (*wqids == NULL) + errcode = -ENOMEM; + } + + if (!errcode) { + int i; + + for (i = 0; i < *nwqid; i++) { + errcode = + p9pdu_readf(pdu, + proto_version, + "Q", + &(*wqids)[i]); + if (errcode) + break; + } + } + + if (errcode) { + kfree(*wqids); + *wqids = NULL; + } + } + break; + case 'A': { + struct p9_stat_dotl *stbuf = + va_arg(ap, struct p9_stat_dotl *); + + memset(stbuf, 0, sizeof(struct p9_stat_dotl)); + errcode = + p9pdu_readf(pdu, proto_version, + "qQdugqqqqqqqqqqqqqqq", + &stbuf->st_result_mask, + &stbuf->qid, + &stbuf->st_mode, + &stbuf->st_uid, &stbuf->st_gid, + &stbuf->st_nlink, + &stbuf->st_rdev, &stbuf->st_size, + &stbuf->st_blksize, &stbuf->st_blocks, + &stbuf->st_atime_sec, + &stbuf->st_atime_nsec, + &stbuf->st_mtime_sec, + &stbuf->st_mtime_nsec, + &stbuf->st_ctime_sec, + &stbuf->st_ctime_nsec, + &stbuf->st_btime_sec, + &stbuf->st_btime_nsec, + &stbuf->st_gen, + &stbuf->st_data_version); + } + break; + case '?': + if ((proto_version != p9_proto_2000u) && + (proto_version != p9_proto_2000L)) + return 0; + break; + default: + BUG(); + break; + } + + if (errcode) + break; + } + + return errcode; +} + +int +p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt, + va_list ap) +{ + const char *ptr; + int errcode = 0; + + for (ptr = fmt; *ptr; ptr++) { + switch (*ptr) { + case 'b':{ + int8_t val = va_arg(ap, int); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } + break; + case 'w':{ + __le16 val = cpu_to_le16(va_arg(ap, int)); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } + break; + case 'd':{ + __le32 val = cpu_to_le32(va_arg(ap, int32_t)); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } + break; + case 'q':{ + __le64 val = cpu_to_le64(va_arg(ap, int64_t)); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } + break; + case 's':{ + const char *sptr = va_arg(ap, const char *); + uint16_t len = 0; + if (sptr) + 
len = min_t(size_t, strlen(sptr), + USHRT_MAX); + + errcode = p9pdu_writef(pdu, proto_version, + "w", len); + if (!errcode && pdu_write(pdu, sptr, len)) + errcode = -EFAULT; + } + break; + case 'u': { + kuid_t uid = va_arg(ap, kuid_t); + __le32 val = cpu_to_le32( + from_kuid(&init_user_ns, uid)); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } break; + case 'g': { + kgid_t gid = va_arg(ap, kgid_t); + __le32 val = cpu_to_le32( + from_kgid(&init_user_ns, gid)); + if (pdu_write(pdu, &val, sizeof(val))) + errcode = -EFAULT; + } break; + case 'Q':{ + const struct p9_qid *qid = + va_arg(ap, const struct p9_qid *); + errcode = + p9pdu_writef(pdu, proto_version, "bdq", + qid->type, qid->version, + qid->path); + } break; + case 'S':{ + const struct p9_wstat *stbuf = + va_arg(ap, const struct p9_wstat *); + errcode = + p9pdu_writef(pdu, proto_version, + "wwdQdddqssss?sugu", + stbuf->size, stbuf->type, + stbuf->dev, &stbuf->qid, + stbuf->mode, stbuf->atime, + stbuf->mtime, stbuf->length, + stbuf->name, stbuf->uid, + stbuf->gid, stbuf->muid, + stbuf->extension, stbuf->n_uid, + stbuf->n_gid, stbuf->n_muid); + } break; + case 'V':{ + uint32_t count = va_arg(ap, uint32_t); + struct iov_iter *from = + va_arg(ap, struct iov_iter *); + errcode = p9pdu_writef(pdu, proto_version, "d", + count); + if (!errcode && pdu_write_u(pdu, from, count)) + errcode = -EFAULT; + } + break; + case 'T':{ + uint16_t nwname = va_arg(ap, int); + const char **wnames = va_arg(ap, const char **); + + errcode = p9pdu_writef(pdu, proto_version, "w", + nwname); + if (!errcode) { + int i; + + for (i = 0; i < nwname; i++) { + errcode = + p9pdu_writef(pdu, + proto_version, + "s", + wnames[i]); + if (errcode) + break; + } + } + } + break; + case 'R':{ + uint16_t nwqid = va_arg(ap, int); + struct p9_qid *wqids = + va_arg(ap, struct p9_qid *); + + errcode = p9pdu_writef(pdu, proto_version, "w", + nwqid); + if (!errcode) { + int i; + + for (i = 0; i < nwqid; i++) { + errcode = + p9pdu_writef(pdu, + proto_version, + "Q", + &wqids[i]); + if (errcode) + break; + } + } + } + break; + case 'I':{ + struct p9_iattr_dotl *p9attr = va_arg(ap, + struct p9_iattr_dotl *); + + errcode = p9pdu_writef(pdu, proto_version, + "ddugqqqqq", + p9attr->valid, + p9attr->mode, + p9attr->uid, + p9attr->gid, + p9attr->size, + p9attr->atime_sec, + p9attr->atime_nsec, + p9attr->mtime_sec, + p9attr->mtime_nsec); + } + break; + case '?': + if ((proto_version != p9_proto_2000u) && + (proto_version != p9_proto_2000L)) + return 0; + break; + default: + BUG(); + break; + } + + if (errcode) + break; + } + + return errcode; +} + +int p9pdu_readf(struct p9_fcall *pdu, int proto_version, const char *fmt, ...) +{ + va_list ap; + int ret; + + va_start(ap, fmt); + ret = p9pdu_vreadf(pdu, proto_version, fmt, ap); + va_end(ap); + + return ret; +} + +static int +p9pdu_writef(struct p9_fcall *pdu, int proto_version, const char *fmt, ...) 
+{ + va_list ap; + int ret; + + va_start(ap, fmt); + ret = p9pdu_vwritef(pdu, proto_version, fmt, ap); + va_end(ap); + + return ret; +} + +int p9stat_read(struct p9_client *clnt, char *buf, int len, struct p9_wstat *st) +{ + struct p9_fcall fake_pdu; + int ret; + + fake_pdu.size = len; + fake_pdu.capacity = len; + fake_pdu.sdata = buf; + fake_pdu.offset = 0; + + ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "S", st); + if (ret) { + p9_debug(P9_DEBUG_9P, "<<< p9stat_read failed: %d\n", ret); + trace_9p_protocol_dump(clnt, &fake_pdu); + return ret; + } + + return fake_pdu.offset; +} +EXPORT_SYMBOL(p9stat_read); + +int p9pdu_prepare(struct p9_fcall *pdu, int16_t tag, int8_t type) +{ + pdu->id = type; + return p9pdu_writef(pdu, 0, "dbw", 0, type, tag); +} + +int p9pdu_finalize(struct p9_client *clnt, struct p9_fcall *pdu) +{ + int size = pdu->size; + int err; + + pdu->size = 0; + err = p9pdu_writef(pdu, 0, "d", size); + pdu->size = size; + + trace_9p_protocol_dump(clnt, pdu); + p9_debug(P9_DEBUG_9P, ">>> size=%d type: %d tag: %d\n", + pdu->size, pdu->id, pdu->tag); + + return err; +} + +void p9pdu_reset(struct p9_fcall *pdu) +{ + pdu->offset = 0; + pdu->size = 0; +} + +int p9dirent_read(struct p9_client *clnt, char *buf, int len, + struct p9_dirent *dirent) +{ + struct p9_fcall fake_pdu; + int ret; + char *nameptr; + + fake_pdu.size = len; + fake_pdu.capacity = len; + fake_pdu.sdata = buf; + fake_pdu.offset = 0; + + ret = p9pdu_readf(&fake_pdu, clnt->proto_version, "Qqbs", &dirent->qid, + &dirent->d_off, &dirent->d_type, &nameptr); + if (ret) { + p9_debug(P9_DEBUG_9P, "<<< p9dirent_read failed: %d\n", ret); + trace_9p_protocol_dump(clnt, &fake_pdu); + return ret; + } + + ret = strscpy(dirent->d_name, nameptr, sizeof(dirent->d_name)); + if (ret < 0) { + p9_debug(P9_DEBUG_ERROR, + "On the wire dirent name too long: %s\n", + nameptr); + kfree(nameptr); + return ret; + } + kfree(nameptr); + + return fake_pdu.offset; +} +EXPORT_SYMBOL(p9dirent_read); diff --git a/net/9p/protocol.h b/net/9p/protocol.h new file mode 100644 index 000000000..ad2283d1f --- /dev/null +++ b/net/9p/protocol.h @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * 9P Protocol Support Code + * + * Copyright (C) 2008 by Eric Van Hensbergen + * + * Base on code from Anthony Liguori + * Copyright (C) 2008 by IBM, Corp. + */ + +size_t p9_msg_buf_size(struct p9_client *c, enum p9_msg_t type, + const char *fmt, va_list ap); +int p9pdu_vwritef(struct p9_fcall *pdu, int proto_version, const char *fmt, + va_list ap); +int p9pdu_readf(struct p9_fcall *pdu, int proto_version, const char *fmt, ...); +int p9pdu_prepare(struct p9_fcall *pdu, int16_t tag, int8_t type); +int p9pdu_finalize(struct p9_client *clnt, struct p9_fcall *pdu); +void p9pdu_reset(struct p9_fcall *pdu); +size_t pdu_read(struct p9_fcall *pdu, void *data, size_t size); diff --git a/net/9p/trans_common.c b/net/9p/trans_common.c new file mode 100644 index 000000000..c827f6945 --- /dev/null +++ b/net/9p/trans_common.c @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: LGPL-2.1 +/* + * Copyright IBM Corporation, 2010 + * Author Venkateswararao Jujjuri + */ + +#include +#include +#include "trans_common.h" + +/** + * p9_release_pages - Release pages after the transaction. 
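/*
 * Raw-buffer sketch (plain C, ex_* names made up) of the pattern used by
 * p9pdu_prepare() and p9pdu_finalize() above: the 7-byte header is
 * size[4] type[1] tag[2]; size is written as a placeholder first and
 * back-patched with the final length, which covers the header itself,
 * once the body has been appended.
 */
#include <stddef.h>
#include <stdint.h>

static size_t ex_prepare(unsigned char *buf, uint8_t type, uint16_t tag)
{
        buf[0] = buf[1] = buf[2] = buf[3] = 0;  /* size[4] placeholder */
        buf[4] = type;                          /* type[1] */
        buf[5] = tag & 0xff;                    /* tag[2], little endian */
        buf[6] = tag >> 8;
        return 7;                               /* the body starts here */
}

static void ex_finalize(unsigned char *buf, uint32_t size)
{
        buf[0] = size & 0xff;
        buf[1] = (size >> 8) & 0xff;
        buf[2] = (size >> 16) & 0xff;
        buf[3] = (size >> 24) & 0xff;
}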
+ * @pages: array of pages to be put + * @nr_pages: size of array + */ +void p9_release_pages(struct page **pages, int nr_pages) +{ + int i; + + for (i = 0; i < nr_pages; i++) + if (pages[i]) + put_page(pages[i]); +} +EXPORT_SYMBOL(p9_release_pages); diff --git a/net/9p/trans_common.h b/net/9p/trans_common.h new file mode 100644 index 000000000..32134db6a --- /dev/null +++ b/net/9p/trans_common.h @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: LGPL-2.1 */ +/* + * Copyright IBM Corporation, 2010 + * Author Venkateswararao Jujjuri + */ + +void p9_release_pages(struct page **pages, int nr_pages); diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c new file mode 100644 index 000000000..a69422366 --- /dev/null +++ b/net/9p/trans_fd.c @@ -0,0 +1,1205 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Fd transport layer. Includes deprecated socket layer. + * + * Copyright (C) 2006 by Russ Cox + * Copyright (C) 2004-2005 by Latchesar Ionkov + * Copyright (C) 2004-2008 by Eric Van Hensbergen + * Copyright (C) 1997-2002 by Ron Minnich + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include /* killme */ + +#define P9_PORT 564 +#define MAX_SOCK_BUF (1024*1024) +#define MAXPOLLWADDR 2 + +static struct p9_trans_module p9_tcp_trans; +static struct p9_trans_module p9_fd_trans; + +/** + * struct p9_fd_opts - per-transport options + * @rfd: file descriptor for reading (trans=fd) + * @wfd: file descriptor for writing (trans=fd) + * @port: port to connect to (trans=tcp) + * @privport: port is privileged + */ + +struct p9_fd_opts { + int rfd; + int wfd; + u16 port; + bool privport; +}; + +/* + * Option Parsing (code inspired by NFS code) + * - a little lazy - parse all fd-transport options + */ + +enum { + /* Options that take integer arguments */ + Opt_port, Opt_rfdno, Opt_wfdno, Opt_err, + /* Options that take no arguments */ + Opt_privport, +}; + +static const match_table_t tokens = { + {Opt_port, "port=%u"}, + {Opt_rfdno, "rfdno=%u"}, + {Opt_wfdno, "wfdno=%u"}, + {Opt_privport, "privport"}, + {Opt_err, NULL}, +}; + +enum { + Rworksched = 1, /* read work scheduled or running */ + Rpending = 2, /* can read */ + Wworksched = 4, /* write work scheduled or running */ + Wpending = 8, /* can write */ +}; + +struct p9_poll_wait { + struct p9_conn *conn; + wait_queue_entry_t wait; + wait_queue_head_t *wait_addr; +}; + +/** + * struct p9_conn - fd mux connection state information + * @mux_list: list link for mux to manage multiple connections (?) + * @client: reference to client instance for this connection + * @err: error state + * @req_lock: lock protecting req_list and requests statuses + * @req_list: accounting for requests which have been sent + * @unsent_req_list: accounting for requests that haven't been sent + * @rreq: read request + * @wreq: write request + * @req: current request being processed (if any) + * @tmp_buf: temporary buffer to read in header + * @rc: temporary fcall for reading current frame + * @wpos: write position for current frame + * @wsize: amount of data to write for current frame + * @wbuf: current write buffer + * @poll_pending_link: pending links to be polled per conn + * @poll_wait: array of wait_q's for various worker threads + * @pt: poll state + * @rq: current read work + * @wq: current write work + * @wsched: ???? 
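/*
 * Simplified userspace take (ex_* names made up) on the option parsing that
 * parse_opts() does further down with match_token()/match_int(): split the
 * mount option string on commas and pick out the keys listed in the tokens
 * table above, with the same defaults (port 564, no fds, no privport).
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

struct ex_fd_opts {
        int      rfd, wfd;
        uint16_t port;
        bool     privport;
};

static void ex_parse_opts(char *params, struct ex_fd_opts *opts)
{
        char *p;

        opts->port = 564;       /* P9_PORT */
        opts->rfd = opts->wfd = -1;
        opts->privport = false;

        while ((p = strsep(&params, ",")) != NULL) {
                if (!*p)
                        continue;
                if (!strncmp(p, "port=", 5))
                        opts->port = (uint16_t)strtoul(p + 5, NULL, 10);
                else if (!strncmp(p, "rfdno=", 6))
                        opts->rfd = (int)strtoul(p + 6, NULL, 10);
                else if (!strncmp(p, "wfdno=", 6))
                        opts->wfd = (int)strtoul(p + 6, NULL, 10);
                else if (!strcmp(p, "privport"))
                        opts->privport = true;
        }
}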
+ * + */ + +struct p9_conn { + struct list_head mux_list; + struct p9_client *client; + int err; + spinlock_t req_lock; + struct list_head req_list; + struct list_head unsent_req_list; + struct p9_req_t *rreq; + struct p9_req_t *wreq; + char tmp_buf[P9_HDRSZ]; + struct p9_fcall rc; + int wpos; + int wsize; + char *wbuf; + struct list_head poll_pending_link; + struct p9_poll_wait poll_wait[MAXPOLLWADDR]; + poll_table pt; + struct work_struct rq; + struct work_struct wq; + unsigned long wsched; +}; + +/** + * struct p9_trans_fd - transport state + * @rd: reference to file to read from + * @wr: reference of file to write to + * @conn: connection state reference + * + */ + +struct p9_trans_fd { + struct file *rd; + struct file *wr; + struct p9_conn conn; +}; + +static void p9_poll_workfn(struct work_struct *work); + +static DEFINE_SPINLOCK(p9_poll_lock); +static LIST_HEAD(p9_poll_pending_list); +static DECLARE_WORK(p9_poll_work, p9_poll_workfn); + +static unsigned int p9_ipport_resv_min = P9_DEF_MIN_RESVPORT; +static unsigned int p9_ipport_resv_max = P9_DEF_MAX_RESVPORT; + +static void p9_mux_poll_stop(struct p9_conn *m) +{ + unsigned long flags; + int i; + + for (i = 0; i < ARRAY_SIZE(m->poll_wait); i++) { + struct p9_poll_wait *pwait = &m->poll_wait[i]; + + if (pwait->wait_addr) { + remove_wait_queue(pwait->wait_addr, &pwait->wait); + pwait->wait_addr = NULL; + } + } + + spin_lock_irqsave(&p9_poll_lock, flags); + list_del_init(&m->poll_pending_link); + spin_unlock_irqrestore(&p9_poll_lock, flags); + + flush_work(&p9_poll_work); +} + +/** + * p9_conn_cancel - cancel all pending requests with error + * @m: mux data + * @err: error code + * + */ + +static void p9_conn_cancel(struct p9_conn *m, int err) +{ + struct p9_req_t *req, *rtmp; + LIST_HEAD(cancel_list); + + p9_debug(P9_DEBUG_ERROR, "mux %p err %d\n", m, err); + + spin_lock(&m->req_lock); + + if (m->err) { + spin_unlock(&m->req_lock); + return; + } + + m->err = err; + + list_for_each_entry_safe(req, rtmp, &m->req_list, req_list) { + list_move(&req->req_list, &cancel_list); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); + } + list_for_each_entry_safe(req, rtmp, &m->unsent_req_list, req_list) { + list_move(&req->req_list, &cancel_list); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); + } + + spin_unlock(&m->req_lock); + + list_for_each_entry_safe(req, rtmp, &cancel_list, req_list) { + p9_debug(P9_DEBUG_ERROR, "call back req %p\n", req); + list_del(&req->req_list); + if (!req->t_err) + req->t_err = err; + p9_client_cb(m->client, req, REQ_STATUS_ERROR); + } +} + +static __poll_t +p9_fd_poll(struct p9_client *client, struct poll_table_struct *pt, int *err) +{ + __poll_t ret; + struct p9_trans_fd *ts = NULL; + + if (client && client->status == Connected) + ts = client->trans; + + if (!ts) { + if (err) + *err = -EREMOTEIO; + return EPOLLERR; + } + + ret = vfs_poll(ts->rd, pt); + if (ts->rd != ts->wr) + ret = (ret & ~EPOLLOUT) | (vfs_poll(ts->wr, pt) & ~EPOLLIN); + return ret; +} + +/** + * p9_fd_read- read from a fd + * @client: client instance + * @v: buffer to receive data into + * @len: size of receive buffer + * + */ + +static int p9_fd_read(struct p9_client *client, void *v, int len) +{ + int ret; + struct p9_trans_fd *ts = NULL; + loff_t pos; + + if (client && client->status != Disconnected) + ts = client->trans; + + if (!ts) + return -EREMOTEIO; + + if (!(ts->rd->f_flags & O_NONBLOCK)) + p9_debug(P9_DEBUG_ERROR, "blocking read ...\n"); + + pos = ts->rd->f_pos; + ret = kernel_read(ts->rd, v, len, &pos); + if (ret <= 0 && ret != -ERESTARTSYS && 
ret != -EAGAIN) + client->status = Disconnected; + return ret; +} + +/** + * p9_read_work - called when there is some data to be read from a transport + * @work: container of work to be done + * + */ + +static void p9_read_work(struct work_struct *work) +{ + __poll_t n; + int err; + struct p9_conn *m; + + m = container_of(work, struct p9_conn, rq); + + if (m->err < 0) + return; + + p9_debug(P9_DEBUG_TRANS, "start mux %p pos %zd\n", m, m->rc.offset); + + if (!m->rc.sdata) { + m->rc.sdata = m->tmp_buf; + m->rc.offset = 0; + m->rc.capacity = P9_HDRSZ; /* start by reading header */ + } + + clear_bit(Rpending, &m->wsched); + p9_debug(P9_DEBUG_TRANS, "read mux %p pos %zd size: %zd = %zd\n", + m, m->rc.offset, m->rc.capacity, + m->rc.capacity - m->rc.offset); + err = p9_fd_read(m->client, m->rc.sdata + m->rc.offset, + m->rc.capacity - m->rc.offset); + p9_debug(P9_DEBUG_TRANS, "mux %p got %d bytes\n", m, err); + if (err == -EAGAIN) + goto end_clear; + + if (err <= 0) + goto error; + + m->rc.offset += err; + + /* header read in */ + if ((!m->rreq) && (m->rc.offset == m->rc.capacity)) { + p9_debug(P9_DEBUG_TRANS, "got new header\n"); + + /* Header size */ + m->rc.size = P9_HDRSZ; + err = p9_parse_header(&m->rc, &m->rc.size, NULL, NULL, 0); + if (err) { + p9_debug(P9_DEBUG_ERROR, + "error parsing header: %d\n", err); + goto error; + } + + p9_debug(P9_DEBUG_TRANS, + "mux %p pkt: size: %d bytes tag: %d\n", + m, m->rc.size, m->rc.tag); + + m->rreq = p9_tag_lookup(m->client, m->rc.tag); + if (!m->rreq || (m->rreq->status != REQ_STATUS_SENT)) { + p9_debug(P9_DEBUG_ERROR, "Unexpected packet tag %d\n", + m->rc.tag); + err = -EIO; + goto error; + } + + if (m->rc.size > m->rreq->rc.capacity) { + p9_debug(P9_DEBUG_ERROR, + "requested packet size too big: %d for tag %d with capacity %zd\n", + m->rc.size, m->rc.tag, m->rreq->rc.capacity); + err = -EIO; + goto error; + } + + if (!m->rreq->rc.sdata) { + p9_debug(P9_DEBUG_ERROR, + "No recv fcall for tag %d (req %p), disconnecting!\n", + m->rc.tag, m->rreq); + p9_req_put(m->client, m->rreq); + m->rreq = NULL; + err = -EIO; + goto error; + } + m->rc.sdata = m->rreq->rc.sdata; + memcpy(m->rc.sdata, m->tmp_buf, m->rc.capacity); + m->rc.capacity = m->rc.size; + } + + /* packet is read in + * not an else because some packets (like clunk) have no payload + */ + if ((m->rreq) && (m->rc.offset == m->rc.capacity)) { + p9_debug(P9_DEBUG_TRANS, "got new packet\n"); + m->rreq->rc.size = m->rc.offset; + spin_lock(&m->req_lock); + if (m->rreq->status == REQ_STATUS_SENT) { + list_del(&m->rreq->req_list); + p9_client_cb(m->client, m->rreq, REQ_STATUS_RCVD); + } else if (m->rreq->status == REQ_STATUS_FLSHD) { + /* Ignore replies associated with a cancelled request. 
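/*
 * Framing arithmetic only (plain C, ex_* names made up) behind the read
 * path above: a 9P frame begins with a 7-byte header whose little-endian
 * size[4] field counts the whole frame, header included, so once the
 * header is in the reader still needs size - 7 bytes of payload.
 */
#include <stddef.h>

static size_t ex_frame_size(const unsigned char hdr[7])
{
        return (size_t)hdr[0] | ((size_t)hdr[1] << 8) |
               ((size_t)hdr[2] << 16) | ((size_t)hdr[3] << 24);
}

static size_t ex_payload_left(const unsigned char hdr[7])
{
        return ex_frame_size(hdr) - 7;
}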
*/ + p9_debug(P9_DEBUG_TRANS, + "Ignore replies associated with a cancelled request\n"); + } else { + spin_unlock(&m->req_lock); + p9_debug(P9_DEBUG_ERROR, + "Request tag %d errored out while we were reading the reply\n", + m->rc.tag); + err = -EIO; + goto error; + } + spin_unlock(&m->req_lock); + m->rc.sdata = NULL; + m->rc.offset = 0; + m->rc.capacity = 0; + p9_req_put(m->client, m->rreq); + m->rreq = NULL; + } + +end_clear: + clear_bit(Rworksched, &m->wsched); + + if (!list_empty(&m->req_list)) { + if (test_and_clear_bit(Rpending, &m->wsched)) + n = EPOLLIN; + else + n = p9_fd_poll(m->client, NULL, NULL); + + if ((n & EPOLLIN) && !test_and_set_bit(Rworksched, &m->wsched)) { + p9_debug(P9_DEBUG_TRANS, "sched read work %p\n", m); + schedule_work(&m->rq); + } + } + + return; +error: + p9_conn_cancel(m, err); + clear_bit(Rworksched, &m->wsched); +} + +/** + * p9_fd_write - write to a socket + * @client: client instance + * @v: buffer to send data from + * @len: size of send buffer + * + */ + +static int p9_fd_write(struct p9_client *client, void *v, int len) +{ + ssize_t ret; + struct p9_trans_fd *ts = NULL; + + if (client && client->status != Disconnected) + ts = client->trans; + + if (!ts) + return -EREMOTEIO; + + if (!(ts->wr->f_flags & O_NONBLOCK)) + p9_debug(P9_DEBUG_ERROR, "blocking write ...\n"); + + ret = kernel_write(ts->wr, v, len, &ts->wr->f_pos); + if (ret <= 0 && ret != -ERESTARTSYS && ret != -EAGAIN) + client->status = Disconnected; + return ret; +} + +/** + * p9_write_work - called when a transport can send some data + * @work: container for work to be done + * + */ + +static void p9_write_work(struct work_struct *work) +{ + __poll_t n; + int err; + struct p9_conn *m; + struct p9_req_t *req; + + m = container_of(work, struct p9_conn, wq); + + if (m->err < 0) { + clear_bit(Wworksched, &m->wsched); + return; + } + + if (!m->wsize) { + spin_lock(&m->req_lock); + if (list_empty(&m->unsent_req_list)) { + clear_bit(Wworksched, &m->wsched); + spin_unlock(&m->req_lock); + return; + } + + req = list_entry(m->unsent_req_list.next, struct p9_req_t, + req_list); + WRITE_ONCE(req->status, REQ_STATUS_SENT); + p9_debug(P9_DEBUG_TRANS, "move req %p\n", req); + list_move_tail(&req->req_list, &m->req_list); + + m->wbuf = req->tc.sdata; + m->wsize = req->tc.size; + m->wpos = 0; + p9_req_get(req); + m->wreq = req; + spin_unlock(&m->req_lock); + } + + p9_debug(P9_DEBUG_TRANS, "mux %p pos %d size %d\n", + m, m->wpos, m->wsize); + clear_bit(Wpending, &m->wsched); + err = p9_fd_write(m->client, m->wbuf + m->wpos, m->wsize - m->wpos); + p9_debug(P9_DEBUG_TRANS, "mux %p sent %d bytes\n", m, err); + if (err == -EAGAIN) + goto end_clear; + + + if (err < 0) + goto error; + else if (err == 0) { + err = -EREMOTEIO; + goto error; + } + + m->wpos += err; + if (m->wpos == m->wsize) { + m->wpos = m->wsize = 0; + p9_req_put(m->client, m->wreq); + m->wreq = NULL; + } + +end_clear: + clear_bit(Wworksched, &m->wsched); + + if (m->wsize || !list_empty(&m->unsent_req_list)) { + if (test_and_clear_bit(Wpending, &m->wsched)) + n = EPOLLOUT; + else + n = p9_fd_poll(m->client, NULL, NULL); + + if ((n & EPOLLOUT) && + !test_and_set_bit(Wworksched, &m->wsched)) { + p9_debug(P9_DEBUG_TRANS, "sched write work %p\n", m); + schedule_work(&m->wq); + } + } + + return; + +error: + p9_conn_cancel(m, err); + clear_bit(Wworksched, &m->wsched); +} + +static int p9_pollwake(wait_queue_entry_t *wait, unsigned int mode, int sync, void *key) +{ + struct p9_poll_wait *pwait = + container_of(wait, struct p9_poll_wait, wait); + struct 
p9_conn *m = pwait->conn; + unsigned long flags; + + spin_lock_irqsave(&p9_poll_lock, flags); + if (list_empty(&m->poll_pending_link)) + list_add_tail(&m->poll_pending_link, &p9_poll_pending_list); + spin_unlock_irqrestore(&p9_poll_lock, flags); + + schedule_work(&p9_poll_work); + return 1; +} + +/** + * p9_pollwait - add poll task to the wait queue + * @filp: file pointer being polled + * @wait_address: wait_q to block on + * @p: poll state + * + * called by files poll operation to add v9fs-poll task to files wait queue + */ + +static void +p9_pollwait(struct file *filp, wait_queue_head_t *wait_address, poll_table *p) +{ + struct p9_conn *m = container_of(p, struct p9_conn, pt); + struct p9_poll_wait *pwait = NULL; + int i; + + for (i = 0; i < ARRAY_SIZE(m->poll_wait); i++) { + if (m->poll_wait[i].wait_addr == NULL) { + pwait = &m->poll_wait[i]; + break; + } + } + + if (!pwait) { + p9_debug(P9_DEBUG_ERROR, "not enough wait_address slots\n"); + return; + } + + pwait->conn = m; + pwait->wait_addr = wait_address; + init_waitqueue_func_entry(&pwait->wait, p9_pollwake); + add_wait_queue(wait_address, &pwait->wait); +} + +/** + * p9_conn_create - initialize the per-session mux data + * @client: client instance + * + * Note: Creates the polling task if this is the first session. + */ + +static void p9_conn_create(struct p9_client *client) +{ + __poll_t n; + struct p9_trans_fd *ts = client->trans; + struct p9_conn *m = &ts->conn; + + p9_debug(P9_DEBUG_TRANS, "client %p msize %d\n", client, client->msize); + + INIT_LIST_HEAD(&m->mux_list); + m->client = client; + + spin_lock_init(&m->req_lock); + INIT_LIST_HEAD(&m->req_list); + INIT_LIST_HEAD(&m->unsent_req_list); + INIT_WORK(&m->rq, p9_read_work); + INIT_WORK(&m->wq, p9_write_work); + INIT_LIST_HEAD(&m->poll_pending_link); + init_poll_funcptr(&m->pt, p9_pollwait); + + n = p9_fd_poll(client, &m->pt, NULL); + if (n & EPOLLIN) { + p9_debug(P9_DEBUG_TRANS, "mux %p can read\n", m); + set_bit(Rpending, &m->wsched); + } + + if (n & EPOLLOUT) { + p9_debug(P9_DEBUG_TRANS, "mux %p can write\n", m); + set_bit(Wpending, &m->wsched); + } +} + +/** + * p9_poll_mux - polls a mux and schedules read or write works if necessary + * @m: connection to poll + * + */ + +static void p9_poll_mux(struct p9_conn *m) +{ + __poll_t n; + int err = -ECONNRESET; + + if (m->err < 0) + return; + + n = p9_fd_poll(m->client, NULL, &err); + if (n & (EPOLLERR | EPOLLHUP | EPOLLNVAL)) { + p9_debug(P9_DEBUG_TRANS, "error mux %p err %d\n", m, n); + p9_conn_cancel(m, err); + } + + if (n & EPOLLIN) { + set_bit(Rpending, &m->wsched); + p9_debug(P9_DEBUG_TRANS, "mux %p can read\n", m); + if (!test_and_set_bit(Rworksched, &m->wsched)) { + p9_debug(P9_DEBUG_TRANS, "sched read work %p\n", m); + schedule_work(&m->rq); + } + } + + if (n & EPOLLOUT) { + set_bit(Wpending, &m->wsched); + p9_debug(P9_DEBUG_TRANS, "mux %p can write\n", m); + if ((m->wsize || !list_empty(&m->unsent_req_list)) && + !test_and_set_bit(Wworksched, &m->wsched)) { + p9_debug(P9_DEBUG_TRANS, "sched write work %p\n", m); + schedule_work(&m->wq); + } + } +} + +/** + * p9_fd_request - send 9P request + * The function can sleep until the request is scheduled for sending. + * The function can be interrupted. Return from the function is not + * a guarantee that the request is sent successfully. 
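+ * The request is only queued on the unsent list here; the write worker transmits it asynchronously.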
+ * + * @client: client instance + * @req: request to be sent + * + */ + +static int p9_fd_request(struct p9_client *client, struct p9_req_t *req) +{ + __poll_t n; + struct p9_trans_fd *ts = client->trans; + struct p9_conn *m = &ts->conn; + + p9_debug(P9_DEBUG_TRANS, "mux %p task %p tcall %p id %d\n", + m, current, &req->tc, req->tc.id); + if (m->err < 0) + return m->err; + + spin_lock(&m->req_lock); + WRITE_ONCE(req->status, REQ_STATUS_UNSENT); + list_add_tail(&req->req_list, &m->unsent_req_list); + spin_unlock(&m->req_lock); + + if (test_and_clear_bit(Wpending, &m->wsched)) + n = EPOLLOUT; + else + n = p9_fd_poll(m->client, NULL, NULL); + + if (n & EPOLLOUT && !test_and_set_bit(Wworksched, &m->wsched)) + schedule_work(&m->wq); + + return 0; +} + +static int p9_fd_cancel(struct p9_client *client, struct p9_req_t *req) +{ + struct p9_trans_fd *ts = client->trans; + struct p9_conn *m = &ts->conn; + int ret = 1; + + p9_debug(P9_DEBUG_TRANS, "client %p req %p\n", client, req); + + spin_lock(&m->req_lock); + + if (req->status == REQ_STATUS_UNSENT) { + list_del(&req->req_list); + WRITE_ONCE(req->status, REQ_STATUS_FLSHD); + p9_req_put(client, req); + ret = 0; + } + spin_unlock(&m->req_lock); + + return ret; +} + +static int p9_fd_cancelled(struct p9_client *client, struct p9_req_t *req) +{ + struct p9_trans_fd *ts = client->trans; + struct p9_conn *m = &ts->conn; + + p9_debug(P9_DEBUG_TRANS, "client %p req %p\n", client, req); + + spin_lock(&m->req_lock); + /* Ignore cancelled request if message has been received + * before lock. + */ + if (req->status == REQ_STATUS_RCVD) { + spin_unlock(&m->req_lock); + return 0; + } + + /* we haven't received a response for oldreq, + * remove it from the list. + */ + list_del(&req->req_list); + WRITE_ONCE(req->status, REQ_STATUS_FLSHD); + spin_unlock(&m->req_lock); + + p9_req_put(client, req); + + return 0; +} + +static int p9_fd_show_options(struct seq_file *m, struct p9_client *clnt) +{ + if (clnt->trans_mod == &p9_tcp_trans) { + if (clnt->trans_opts.tcp.port != P9_PORT) + seq_printf(m, ",port=%u", clnt->trans_opts.tcp.port); + } else if (clnt->trans_mod == &p9_fd_trans) { + if (clnt->trans_opts.fd.rfd != ~0) + seq_printf(m, ",rfd=%u", clnt->trans_opts.fd.rfd); + if (clnt->trans_opts.fd.wfd != ~0) + seq_printf(m, ",wfd=%u", clnt->trans_opts.fd.wfd); + } + return 0; +} + +/** + * parse_opts - parse mount options into p9_fd_opts structure + * @params: options string passed from mount + * @opts: fd transport-specific structure to parse options into + * + * Returns 0 upon success, -ERRNO upon failure + */ + +static int parse_opts(char *params, struct p9_fd_opts *opts) +{ + char *p; + substring_t args[MAX_OPT_ARGS]; + int option; + char *options, *tmp_options; + + opts->port = P9_PORT; + opts->rfd = ~0; + opts->wfd = ~0; + opts->privport = false; + + if (!params) + return 0; + + tmp_options = kstrdup(params, GFP_KERNEL); + if (!tmp_options) { + p9_debug(P9_DEBUG_ERROR, + "failed to allocate copy of option string\n"); + return -ENOMEM; + } + options = tmp_options; + + while ((p = strsep(&options, ",")) != NULL) { + int token; + int r; + if (!*p) + continue; + token = match_token(p, tokens, args); + if ((token != Opt_err) && (token != Opt_privport)) { + r = match_int(&args[0], &option); + if (r < 0) { + p9_debug(P9_DEBUG_ERROR, + "integer field, but no integer?\n"); + continue; + } + } + switch (token) { + case Opt_port: + opts->port = option; + break; + case Opt_rfdno: + opts->rfd = option; + break; + case Opt_wfdno: + opts->wfd = option; + break; + case 
Opt_privport: + opts->privport = true; + break; + default: + continue; + } + } + + kfree(tmp_options); + return 0; +} + +static int p9_fd_open(struct p9_client *client, int rfd, int wfd) +{ + struct p9_trans_fd *ts = kzalloc(sizeof(struct p9_trans_fd), + GFP_KERNEL); + if (!ts) + return -ENOMEM; + + ts->rd = fget(rfd); + if (!ts->rd) + goto out_free_ts; + if (!(ts->rd->f_mode & FMODE_READ)) + goto out_put_rd; + /* Prevent workers from hanging on IO when fd is a pipe. + * It's technically possible for userspace or concurrent mounts to + * modify this flag concurrently, which will likely result in a + * broken filesystem. However, just having bad flags here should + * not crash the kernel or cause any other sort of bug, so mark this + * particular data race as intentional so that tooling (like KCSAN) + * can allow it and detect further problems. + */ + data_race(ts->rd->f_flags |= O_NONBLOCK); + ts->wr = fget(wfd); + if (!ts->wr) + goto out_put_rd; + if (!(ts->wr->f_mode & FMODE_WRITE)) + goto out_put_wr; + data_race(ts->wr->f_flags |= O_NONBLOCK); + + client->trans = ts; + client->status = Connected; + + return 0; + +out_put_wr: + fput(ts->wr); +out_put_rd: + fput(ts->rd); +out_free_ts: + kfree(ts); + return -EIO; +} + +static int p9_socket_open(struct p9_client *client, struct socket *csocket) +{ + struct p9_trans_fd *p; + struct file *file; + + p = kzalloc(sizeof(struct p9_trans_fd), GFP_KERNEL); + if (!p) { + sock_release(csocket); + return -ENOMEM; + } + + csocket->sk->sk_allocation = GFP_NOIO; + file = sock_alloc_file(csocket, 0, NULL); + if (IS_ERR(file)) { + pr_err("%s (%d): failed to map fd\n", + __func__, task_pid_nr(current)); + kfree(p); + return PTR_ERR(file); + } + + get_file(file); + p->wr = p->rd = file; + client->trans = p; + client->status = Connected; + + p->rd->f_flags |= O_NONBLOCK; + + p9_conn_create(client); + return 0; +} + +/** + * p9_conn_destroy - cancels all pending requests of mux + * @m: mux to destroy + * + */ + +static void p9_conn_destroy(struct p9_conn *m) +{ + p9_debug(P9_DEBUG_TRANS, "mux %p prev %p next %p\n", + m, m->mux_list.prev, m->mux_list.next); + + p9_mux_poll_stop(m); + cancel_work_sync(&m->rq); + if (m->rreq) { + p9_req_put(m->client, m->rreq); + m->rreq = NULL; + } + cancel_work_sync(&m->wq); + if (m->wreq) { + p9_req_put(m->client, m->wreq); + m->wreq = NULL; + } + + p9_conn_cancel(m, -ECONNRESET); + + m->client = NULL; +} + +/** + * p9_fd_close - shutdown file descriptor transport + * @client: client instance + * + */ + +static void p9_fd_close(struct p9_client *client) +{ + struct p9_trans_fd *ts; + + if (!client) + return; + + ts = client->trans; + if (!ts) + return; + + client->status = Disconnected; + + p9_conn_destroy(&ts->conn); + + if (ts->rd) + fput(ts->rd); + if (ts->wr) + fput(ts->wr); + + kfree(ts); +} + +/* + * stolen from NFS - maybe should be made a generic function? 
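+ * Only checks that four integer fields are present and that none exceeds 255.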
+ */ +static inline int valid_ipaddr4(const char *buf) +{ + int rc, count, in[4]; + + rc = sscanf(buf, "%d.%d.%d.%d", &in[0], &in[1], &in[2], &in[3]); + if (rc != 4) + return -EINVAL; + for (count = 0; count < 4; count++) { + if (in[count] > 255) + return -EINVAL; + } + return 0; +} + +static int p9_bind_privport(struct socket *sock) +{ + struct sockaddr_in cl; + int port, err = -EINVAL; + + memset(&cl, 0, sizeof(cl)); + cl.sin_family = AF_INET; + cl.sin_addr.s_addr = htonl(INADDR_ANY); + for (port = p9_ipport_resv_max; port >= p9_ipport_resv_min; port--) { + cl.sin_port = htons((ushort)port); + err = kernel_bind(sock, (struct sockaddr *)&cl, sizeof(cl)); + if (err != -EADDRINUSE) + break; + } + return err; +} + + +static int +p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args) +{ + int err; + struct socket *csocket; + struct sockaddr_in sin_server; + struct p9_fd_opts opts; + + err = parse_opts(args, &opts); + if (err < 0) + return err; + + if (addr == NULL || valid_ipaddr4(addr) < 0) + return -EINVAL; + + csocket = NULL; + + client->trans_opts.tcp.port = opts.port; + client->trans_opts.tcp.privport = opts.privport; + sin_server.sin_family = AF_INET; + sin_server.sin_addr.s_addr = in_aton(addr); + sin_server.sin_port = htons(opts.port); + err = __sock_create(current->nsproxy->net_ns, PF_INET, + SOCK_STREAM, IPPROTO_TCP, &csocket, 1); + if (err) { + pr_err("%s (%d): problem creating socket\n", + __func__, task_pid_nr(current)); + return err; + } + + if (opts.privport) { + err = p9_bind_privport(csocket); + if (err < 0) { + pr_err("%s (%d): problem binding to privport\n", + __func__, task_pid_nr(current)); + sock_release(csocket); + return err; + } + } + + err = csocket->ops->connect(csocket, + (struct sockaddr *)&sin_server, + sizeof(struct sockaddr_in), 0); + if (err < 0) { + pr_err("%s (%d): problem connecting socket to %s\n", + __func__, task_pid_nr(current), addr); + sock_release(csocket); + return err; + } + + return p9_socket_open(client, csocket); +} + +static int +p9_fd_create_unix(struct p9_client *client, const char *addr, char *args) +{ + int err; + struct socket *csocket; + struct sockaddr_un sun_server; + + csocket = NULL; + + if (!addr || !strlen(addr)) + return -EINVAL; + + if (strlen(addr) >= UNIX_PATH_MAX) { + pr_err("%s (%d): address too long: %s\n", + __func__, task_pid_nr(current), addr); + return -ENAMETOOLONG; + } + + sun_server.sun_family = PF_UNIX; + strcpy(sun_server.sun_path, addr); + err = __sock_create(current->nsproxy->net_ns, PF_UNIX, + SOCK_STREAM, 0, &csocket, 1); + if (err < 0) { + pr_err("%s (%d): problem creating socket\n", + __func__, task_pid_nr(current)); + + return err; + } + err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server, + sizeof(struct sockaddr_un) - 1, 0); + if (err < 0) { + pr_err("%s (%d): problem connecting socket: %s: %d\n", + __func__, task_pid_nr(current), addr, err); + sock_release(csocket); + return err; + } + + return p9_socket_open(client, csocket); +} + +static int +p9_fd_create(struct p9_client *client, const char *addr, char *args) +{ + int err; + struct p9_fd_opts opts; + + err = parse_opts(args, &opts); + if (err < 0) + return err; + client->trans_opts.fd.rfd = opts.rfd; + client->trans_opts.fd.wfd = opts.wfd; + + if (opts.rfd == ~0 || opts.wfd == ~0) { + pr_err("Insufficient options for proto=fd\n"); + return -ENOPROTOOPT; + } + + err = p9_fd_open(client, opts.rfd, opts.wfd); + if (err < 0) + return err; + + p9_conn_create(client); + + return 0; +} + +static struct p9_trans_module 
p9_tcp_trans = { + .name = "tcp", + .maxsize = MAX_SOCK_BUF, + .pooled_rbuffers = false, + .def = 0, + .create = p9_fd_create_tcp, + .close = p9_fd_close, + .request = p9_fd_request, + .cancel = p9_fd_cancel, + .cancelled = p9_fd_cancelled, + .show_options = p9_fd_show_options, + .owner = THIS_MODULE, +}; +MODULE_ALIAS_9P("tcp"); + +static struct p9_trans_module p9_unix_trans = { + .name = "unix", + .maxsize = MAX_SOCK_BUF, + .def = 0, + .create = p9_fd_create_unix, + .close = p9_fd_close, + .request = p9_fd_request, + .cancel = p9_fd_cancel, + .cancelled = p9_fd_cancelled, + .show_options = p9_fd_show_options, + .owner = THIS_MODULE, +}; +MODULE_ALIAS_9P("unix"); + +static struct p9_trans_module p9_fd_trans = { + .name = "fd", + .maxsize = MAX_SOCK_BUF, + .def = 0, + .create = p9_fd_create, + .close = p9_fd_close, + .request = p9_fd_request, + .cancel = p9_fd_cancel, + .cancelled = p9_fd_cancelled, + .show_options = p9_fd_show_options, + .owner = THIS_MODULE, +}; +MODULE_ALIAS_9P("fd"); + +/** + * p9_poll_workfn - poll worker thread + * @work: work queue + * + * polls all v9fs transports for new events and queues the appropriate + * work to the work queue + * + */ + +static void p9_poll_workfn(struct work_struct *work) +{ + unsigned long flags; + + p9_debug(P9_DEBUG_TRANS, "start %p\n", current); + + spin_lock_irqsave(&p9_poll_lock, flags); + while (!list_empty(&p9_poll_pending_list)) { + struct p9_conn *conn = list_first_entry(&p9_poll_pending_list, + struct p9_conn, + poll_pending_link); + list_del_init(&conn->poll_pending_link); + spin_unlock_irqrestore(&p9_poll_lock, flags); + + p9_poll_mux(conn); + + spin_lock_irqsave(&p9_poll_lock, flags); + } + spin_unlock_irqrestore(&p9_poll_lock, flags); + + p9_debug(P9_DEBUG_TRANS, "finish\n"); +} + +static int __init p9_trans_fd_init(void) +{ + v9fs_register_trans(&p9_tcp_trans); + v9fs_register_trans(&p9_unix_trans); + v9fs_register_trans(&p9_fd_trans); + + return 0; +} + +static void __exit p9_trans_fd_exit(void) +{ + flush_work(&p9_poll_work); + v9fs_unregister_trans(&p9_tcp_trans); + v9fs_unregister_trans(&p9_unix_trans); + v9fs_unregister_trans(&p9_fd_trans); +} + +module_init(p9_trans_fd_init); +module_exit(p9_trans_fd_exit); + +MODULE_AUTHOR("Eric Van Hensbergen "); +MODULE_DESCRIPTION("Filedescriptor Transport for 9P"); +MODULE_LICENSE("GPL"); diff --git a/net/9p/trans_rdma.c b/net/9p/trans_rdma.c new file mode 100644 index 000000000..29a929230 --- /dev/null +++ b/net/9p/trans_rdma.c @@ -0,0 +1,782 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * RDMA transport layer based on the trans_fd.c implementation. 
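+ * Requests are posted as RDMA SEND work requests; replies arrive in pre-posted receive buffers.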
+ * + * Copyright (C) 2008 by Tom Tucker + * Copyright (C) 2006 by Russ Cox + * Copyright (C) 2004-2005 by Latchesar Ionkov + * Copyright (C) 2004-2008 by Eric Van Hensbergen + * Copyright (C) 1997-2002 by Ron Minnich + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define P9_PORT 5640 +#define P9_RDMA_SQ_DEPTH 32 +#define P9_RDMA_RQ_DEPTH 32 +#define P9_RDMA_SEND_SGE 4 +#define P9_RDMA_RECV_SGE 4 +#define P9_RDMA_IRD 0 +#define P9_RDMA_ORD 0 +#define P9_RDMA_TIMEOUT 30000 /* 30 seconds */ +#define P9_RDMA_MAXSIZE (1024*1024) /* 1MB */ + +/** + * struct p9_trans_rdma - RDMA transport instance + * + * @state: tracks the transport state machine for connection setup and tear down + * @cm_id: The RDMA CM ID + * @pd: Protection Domain pointer + * @qp: Queue Pair pointer + * @cq: Completion Queue pointer + * @timeout: Number of uSecs to wait for connection management events + * @privport: Whether a privileged port may be used + * @port: The port to use + * @sq_depth: The depth of the Send Queue + * @sq_sem: Semaphore for the SQ + * @rq_depth: The depth of the Receive Queue. + * @rq_sem: Semaphore for the RQ + * @excess_rc : Amount of posted Receive Contexts without a pending request. + * See rdma_request() + * @addr: The remote peer's address + * @req_lock: Protects the active request list + * @cm_done: Completion event for connection management tracking + */ +struct p9_trans_rdma { + enum { + P9_RDMA_INIT, + P9_RDMA_ADDR_RESOLVED, + P9_RDMA_ROUTE_RESOLVED, + P9_RDMA_CONNECTED, + P9_RDMA_FLUSHING, + P9_RDMA_CLOSING, + P9_RDMA_CLOSED, + } state; + struct rdma_cm_id *cm_id; + struct ib_pd *pd; + struct ib_qp *qp; + struct ib_cq *cq; + long timeout; + bool privport; + u16 port; + int sq_depth; + struct semaphore sq_sem; + int rq_depth; + struct semaphore rq_sem; + atomic_t excess_rc; + struct sockaddr_in addr; + spinlock_t req_lock; + + struct completion cm_done; +}; + +struct p9_rdma_req; + +/** + * struct p9_rdma_context - Keeps track of in-process WR + * + * @cqe: completion queue entry + * @busa: Bus address to unmap when the WR completes + * @req: Keeps track of requests (send) + * @rc: Keepts track of replies (receive) + */ +struct p9_rdma_context { + struct ib_cqe cqe; + dma_addr_t busa; + union { + struct p9_req_t *req; + struct p9_fcall rc; + }; +}; + +/** + * struct p9_rdma_opts - Collection of mount options + * @port: port of connection + * @privport: Whether a privileged port may be used + * @sq_depth: The requested depth of the SQ. This really doesn't need + * to be any deeper than the number of threads used in the client + * @rq_depth: The depth of the RQ. 
Should be greater than or equal to SQ depth + * @timeout: Time to wait in msecs for CM events + */ +struct p9_rdma_opts { + short port; + bool privport; + int sq_depth; + int rq_depth; + long timeout; +}; + +/* + * Option Parsing (code inspired by NFS code) + */ +enum { + /* Options that take integer arguments */ + Opt_port, Opt_rq_depth, Opt_sq_depth, Opt_timeout, + /* Options that take no argument */ + Opt_privport, + Opt_err, +}; + +static match_table_t tokens = { + {Opt_port, "port=%u"}, + {Opt_sq_depth, "sq=%u"}, + {Opt_rq_depth, "rq=%u"}, + {Opt_timeout, "timeout=%u"}, + {Opt_privport, "privport"}, + {Opt_err, NULL}, +}; + +static int p9_rdma_show_options(struct seq_file *m, struct p9_client *clnt) +{ + struct p9_trans_rdma *rdma = clnt->trans; + + if (rdma->port != P9_PORT) + seq_printf(m, ",port=%u", rdma->port); + if (rdma->sq_depth != P9_RDMA_SQ_DEPTH) + seq_printf(m, ",sq=%u", rdma->sq_depth); + if (rdma->rq_depth != P9_RDMA_RQ_DEPTH) + seq_printf(m, ",rq=%u", rdma->rq_depth); + if (rdma->timeout != P9_RDMA_TIMEOUT) + seq_printf(m, ",timeout=%lu", rdma->timeout); + if (rdma->privport) + seq_puts(m, ",privport"); + return 0; +} + +/** + * parse_opts - parse mount options into rdma options structure + * @params: options string passed from mount + * @opts: rdma transport-specific structure to parse options into + * + * Returns 0 upon success, -ERRNO upon failure + */ +static int parse_opts(char *params, struct p9_rdma_opts *opts) +{ + char *p; + substring_t args[MAX_OPT_ARGS]; + int option; + char *options, *tmp_options; + + opts->port = P9_PORT; + opts->sq_depth = P9_RDMA_SQ_DEPTH; + opts->rq_depth = P9_RDMA_RQ_DEPTH; + opts->timeout = P9_RDMA_TIMEOUT; + opts->privport = false; + + if (!params) + return 0; + + tmp_options = kstrdup(params, GFP_KERNEL); + if (!tmp_options) { + p9_debug(P9_DEBUG_ERROR, + "failed to allocate copy of option string\n"); + return -ENOMEM; + } + options = tmp_options; + + while ((p = strsep(&options, ",")) != NULL) { + int token; + int r; + if (!*p) + continue; + token = match_token(p, tokens, args); + if ((token != Opt_err) && (token != Opt_privport)) { + r = match_int(&args[0], &option); + if (r < 0) { + p9_debug(P9_DEBUG_ERROR, + "integer field, but no integer?\n"); + continue; + } + } + switch (token) { + case Opt_port: + opts->port = option; + break; + case Opt_sq_depth: + opts->sq_depth = option; + break; + case Opt_rq_depth: + opts->rq_depth = option; + break; + case Opt_timeout: + opts->timeout = option; + break; + case Opt_privport: + opts->privport = true; + break; + default: + continue; + } + } + /* RQ must be at least as large as the SQ */ + opts->rq_depth = max(opts->rq_depth, opts->sq_depth); + kfree(tmp_options); + return 0; +} + +static int +p9_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) +{ + struct p9_client *c = id->context; + struct p9_trans_rdma *rdma = c->trans; + switch (event->event) { + case RDMA_CM_EVENT_ADDR_RESOLVED: + BUG_ON(rdma->state != P9_RDMA_INIT); + rdma->state = P9_RDMA_ADDR_RESOLVED; + break; + + case RDMA_CM_EVENT_ROUTE_RESOLVED: + BUG_ON(rdma->state != P9_RDMA_ADDR_RESOLVED); + rdma->state = P9_RDMA_ROUTE_RESOLVED; + break; + + case RDMA_CM_EVENT_ESTABLISHED: + BUG_ON(rdma->state != P9_RDMA_ROUTE_RESOLVED); + rdma->state = P9_RDMA_CONNECTED; + break; + + case RDMA_CM_EVENT_DISCONNECTED: + if (rdma) + rdma->state = P9_RDMA_CLOSED; + c->status = Disconnected; + break; + + case RDMA_CM_EVENT_TIMEWAIT_EXIT: + break; + + case RDMA_CM_EVENT_ADDR_CHANGE: + case RDMA_CM_EVENT_ROUTE_ERROR: + case 
RDMA_CM_EVENT_DEVICE_REMOVAL: + case RDMA_CM_EVENT_MULTICAST_JOIN: + case RDMA_CM_EVENT_MULTICAST_ERROR: + case RDMA_CM_EVENT_REJECTED: + case RDMA_CM_EVENT_CONNECT_REQUEST: + case RDMA_CM_EVENT_CONNECT_RESPONSE: + case RDMA_CM_EVENT_CONNECT_ERROR: + case RDMA_CM_EVENT_ADDR_ERROR: + case RDMA_CM_EVENT_UNREACHABLE: + c->status = Disconnected; + rdma_disconnect(rdma->cm_id); + break; + default: + BUG(); + } + complete(&rdma->cm_done); + return 0; +} + +static void +recv_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct p9_client *client = cq->cq_context; + struct p9_trans_rdma *rdma = client->trans; + struct p9_rdma_context *c = + container_of(wc->wr_cqe, struct p9_rdma_context, cqe); + struct p9_req_t *req; + int err = 0; + int16_t tag; + + req = NULL; + ib_dma_unmap_single(rdma->cm_id->device, c->busa, client->msize, + DMA_FROM_DEVICE); + + if (wc->status != IB_WC_SUCCESS) + goto err_out; + + c->rc.size = wc->byte_len; + err = p9_parse_header(&c->rc, NULL, NULL, &tag, 1); + if (err) + goto err_out; + + req = p9_tag_lookup(client, tag); + if (!req) + goto err_out; + + /* Check that we have not yet received a reply for this request. + */ + if (unlikely(req->rc.sdata)) { + pr_err("Duplicate reply for request %d", tag); + goto err_out; + } + + req->rc.size = c->rc.size; + req->rc.sdata = c->rc.sdata; + p9_client_cb(client, req, REQ_STATUS_RCVD); + + out: + up(&rdma->rq_sem); + kfree(c); + return; + + err_out: + p9_debug(P9_DEBUG_ERROR, "req %p err %d status %d\n", + req, err, wc->status); + rdma->state = P9_RDMA_FLUSHING; + client->status = Disconnected; + goto out; +} + +static void +send_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct p9_client *client = cq->cq_context; + struct p9_trans_rdma *rdma = client->trans; + struct p9_rdma_context *c = + container_of(wc->wr_cqe, struct p9_rdma_context, cqe); + + ib_dma_unmap_single(rdma->cm_id->device, + c->busa, c->req->tc.size, + DMA_TO_DEVICE); + up(&rdma->sq_sem); + p9_req_put(client, c->req); + kfree(c); +} + +static void qp_event_handler(struct ib_event *event, void *context) +{ + p9_debug(P9_DEBUG_ERROR, "QP event %d context %p\n", + event->event, context); +} + +static void rdma_destroy_trans(struct p9_trans_rdma *rdma) +{ + if (!rdma) + return; + + if (rdma->qp && !IS_ERR(rdma->qp)) + ib_destroy_qp(rdma->qp); + + if (rdma->pd && !IS_ERR(rdma->pd)) + ib_dealloc_pd(rdma->pd); + + if (rdma->cq && !IS_ERR(rdma->cq)) + ib_free_cq(rdma->cq); + + if (rdma->cm_id && !IS_ERR(rdma->cm_id)) + rdma_destroy_id(rdma->cm_id); + + kfree(rdma); +} + +static int +post_recv(struct p9_client *client, struct p9_rdma_context *c) +{ + struct p9_trans_rdma *rdma = client->trans; + struct ib_recv_wr wr; + struct ib_sge sge; + int ret; + + c->busa = ib_dma_map_single(rdma->cm_id->device, + c->rc.sdata, client->msize, + DMA_FROM_DEVICE); + if (ib_dma_mapping_error(rdma->cm_id->device, c->busa)) + goto error; + + c->cqe.done = recv_done; + + sge.addr = c->busa; + sge.length = client->msize; + sge.lkey = rdma->pd->local_dma_lkey; + + wr.next = NULL; + wr.wr_cqe = &c->cqe; + wr.sg_list = &sge; + wr.num_sge = 1; + + ret = ib_post_recv(rdma->qp, &wr, NULL); + if (ret) + ib_dma_unmap_single(rdma->cm_id->device, c->busa, + client->msize, DMA_FROM_DEVICE); + return ret; + + error: + p9_debug(P9_DEBUG_ERROR, "EIO\n"); + return -EIO; +} + +static int rdma_request(struct p9_client *client, struct p9_req_t *req) +{ + struct p9_trans_rdma *rdma = client->trans; + struct ib_send_wr wr; + struct ib_sge sge; + int err = 0; + unsigned long flags; + struct p9_rdma_context *c = 
NULL; + struct p9_rdma_context *rpl_context = NULL; + + /* When an error occurs between posting the recv and the send, + * there will be a receive context posted without a pending request. + * Since there is no way to "un-post" it, we remember it and skip + * post_recv() for the next request. + * So here, + * see if we are this `next request' and need to absorb an excess rc. + * If yes, then drop and free our own, and do not recv_post(). + **/ + if (unlikely(atomic_read(&rdma->excess_rc) > 0)) { + if ((atomic_sub_return(1, &rdma->excess_rc) >= 0)) { + /* Got one! */ + p9_fcall_fini(&req->rc); + req->rc.sdata = NULL; + goto dont_need_post_recv; + } else { + /* We raced and lost. */ + atomic_inc(&rdma->excess_rc); + } + } + + /* Allocate an fcall for the reply */ + rpl_context = kmalloc(sizeof *rpl_context, GFP_NOFS); + if (!rpl_context) { + err = -ENOMEM; + goto recv_error; + } + rpl_context->rc.sdata = req->rc.sdata; + + /* + * Post a receive buffer for this request. We need to ensure + * there is a reply buffer available for every outstanding + * request. A flushed request can result in no reply for an + * outstanding request, so we must keep a count to avoid + * overflowing the RQ. + */ + if (down_interruptible(&rdma->rq_sem)) { + err = -EINTR; + goto recv_error; + } + + err = post_recv(client, rpl_context); + if (err) { + p9_debug(P9_DEBUG_ERROR, "POST RECV failed: %d\n", err); + goto recv_error; + } + /* remove posted receive buffer from request structure */ + req->rc.sdata = NULL; + +dont_need_post_recv: + /* Post the request */ + c = kmalloc(sizeof *c, GFP_NOFS); + if (!c) { + err = -ENOMEM; + goto send_error; + } + c->req = req; + + c->busa = ib_dma_map_single(rdma->cm_id->device, + c->req->tc.sdata, c->req->tc.size, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdma->cm_id->device, c->busa)) { + err = -EIO; + goto send_error; + } + + c->cqe.done = send_done; + + sge.addr = c->busa; + sge.length = c->req->tc.size; + sge.lkey = rdma->pd->local_dma_lkey; + + wr.next = NULL; + wr.wr_cqe = &c->cqe; + wr.opcode = IB_WR_SEND; + wr.send_flags = IB_SEND_SIGNALED; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (down_interruptible(&rdma->sq_sem)) { + err = -EINTR; + goto dma_unmap; + } + + /* Mark request as `sent' *before* we actually send it, + * because doing if after could erase the REQ_STATUS_RCVD + * status in case of a very fast reply. + */ + WRITE_ONCE(req->status, REQ_STATUS_SENT); + err = ib_post_send(rdma->qp, &wr, NULL); + if (err) + goto dma_unmap; + + /* Success */ + return 0; + +dma_unmap: + ib_dma_unmap_single(rdma->cm_id->device, c->busa, + c->req->tc.size, DMA_TO_DEVICE); + /* Handle errors that happened during or while preparing the send: */ + send_error: + WRITE_ONCE(req->status, REQ_STATUS_ERROR); + kfree(c); + p9_debug(P9_DEBUG_ERROR, "Error %d in rdma_request()\n", err); + + /* Ach. + * We did recv_post(), but not send. We have one recv_post in excess. 
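+ * Account for it in excess_rc so a later request can absorb the spare buffer.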
+ */ + atomic_inc(&rdma->excess_rc); + return err; + + /* Handle errors that happened during or while preparing post_recv(): */ + recv_error: + kfree(rpl_context); + spin_lock_irqsave(&rdma->req_lock, flags); + if (err != -EINTR && rdma->state < P9_RDMA_CLOSING) { + rdma->state = P9_RDMA_CLOSING; + spin_unlock_irqrestore(&rdma->req_lock, flags); + rdma_disconnect(rdma->cm_id); + } else + spin_unlock_irqrestore(&rdma->req_lock, flags); + return err; +} + +static void rdma_close(struct p9_client *client) +{ + struct p9_trans_rdma *rdma; + + if (!client) + return; + + rdma = client->trans; + if (!rdma) + return; + + client->status = Disconnected; + rdma_disconnect(rdma->cm_id); + rdma_destroy_trans(rdma); +} + +/** + * alloc_rdma - Allocate and initialize the rdma transport structure + * @opts: Mount options structure + */ +static struct p9_trans_rdma *alloc_rdma(struct p9_rdma_opts *opts) +{ + struct p9_trans_rdma *rdma; + + rdma = kzalloc(sizeof(struct p9_trans_rdma), GFP_KERNEL); + if (!rdma) + return NULL; + + rdma->port = opts->port; + rdma->privport = opts->privport; + rdma->sq_depth = opts->sq_depth; + rdma->rq_depth = opts->rq_depth; + rdma->timeout = opts->timeout; + spin_lock_init(&rdma->req_lock); + init_completion(&rdma->cm_done); + sema_init(&rdma->sq_sem, rdma->sq_depth); + sema_init(&rdma->rq_sem, rdma->rq_depth); + atomic_set(&rdma->excess_rc, 0); + + return rdma; +} + +static int rdma_cancel(struct p9_client *client, struct p9_req_t *req) +{ + /* Nothing to do here. + * We will take care of it (if we have to) in rdma_cancelled() + */ + return 1; +} + +/* A request has been fully flushed without a reply. + * That means we have posted one buffer in excess. + */ +static int rdma_cancelled(struct p9_client *client, struct p9_req_t *req) +{ + struct p9_trans_rdma *rdma = client->trans; + atomic_inc(&rdma->excess_rc); + return 0; +} + +static int p9_rdma_bind_privport(struct p9_trans_rdma *rdma) +{ + struct sockaddr_in cl = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + int port, err = -EINVAL; + + for (port = P9_DEF_MAX_RESVPORT; port >= P9_DEF_MIN_RESVPORT; port--) { + cl.sin_port = htons((ushort)port); + err = rdma_bind_addr(rdma->cm_id, (struct sockaddr *)&cl); + if (err != -EADDRINUSE) + break; + } + return err; +} + +/** + * rdma_create_trans - Transport method for creating a transport instance + * @client: client instance + * @addr: IP address string + * @args: Mount options string + */ +static int +rdma_create_trans(struct p9_client *client, const char *addr, char *args) +{ + int err; + struct p9_rdma_opts opts; + struct p9_trans_rdma *rdma; + struct rdma_conn_param conn_param; + struct ib_qp_init_attr qp_attr; + + if (addr == NULL) + return -EINVAL; + + /* Parse the transport specific mount options */ + err = parse_opts(args, &opts); + if (err < 0) + return err; + + /* Create and initialize the RDMA transport structure */ + rdma = alloc_rdma(&opts); + if (!rdma) + return -ENOMEM; + + /* Create the RDMA CM ID */ + rdma->cm_id = rdma_create_id(&init_net, p9_cm_event_handler, client, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(rdma->cm_id)) + goto error; + + /* Associate the client with the transport */ + client->trans = rdma; + + /* Bind to a privileged port if we need to */ + if (opts.privport) { + err = p9_rdma_bind_privport(rdma); + if (err < 0) { + pr_err("%s (%d): problem binding to privport: %d\n", + __func__, task_pid_nr(current), -err); + goto error; + } + } + + /* Resolve the server's address */ + rdma->addr.sin_family = AF_INET; + 
rdma->addr.sin_addr.s_addr = in_aton(addr); + rdma->addr.sin_port = htons(opts.port); + err = rdma_resolve_addr(rdma->cm_id, NULL, + (struct sockaddr *)&rdma->addr, + rdma->timeout); + if (err) + goto error; + err = wait_for_completion_interruptible(&rdma->cm_done); + if (err || (rdma->state != P9_RDMA_ADDR_RESOLVED)) + goto error; + + /* Resolve the route to the server */ + err = rdma_resolve_route(rdma->cm_id, rdma->timeout); + if (err) + goto error; + err = wait_for_completion_interruptible(&rdma->cm_done); + if (err || (rdma->state != P9_RDMA_ROUTE_RESOLVED)) + goto error; + + /* Create the Completion Queue */ + rdma->cq = ib_alloc_cq_any(rdma->cm_id->device, client, + opts.sq_depth + opts.rq_depth + 1, + IB_POLL_SOFTIRQ); + if (IS_ERR(rdma->cq)) + goto error; + + /* Create the Protection Domain */ + rdma->pd = ib_alloc_pd(rdma->cm_id->device, 0); + if (IS_ERR(rdma->pd)) + goto error; + + /* Create the Queue Pair */ + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.event_handler = qp_event_handler; + qp_attr.qp_context = client; + qp_attr.cap.max_send_wr = opts.sq_depth; + qp_attr.cap.max_recv_wr = opts.rq_depth; + qp_attr.cap.max_send_sge = P9_RDMA_SEND_SGE; + qp_attr.cap.max_recv_sge = P9_RDMA_RECV_SGE; + qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + qp_attr.qp_type = IB_QPT_RC; + qp_attr.send_cq = rdma->cq; + qp_attr.recv_cq = rdma->cq; + err = rdma_create_qp(rdma->cm_id, rdma->pd, &qp_attr); + if (err) + goto error; + rdma->qp = rdma->cm_id->qp; + + /* Request a connection */ + memset(&conn_param, 0, sizeof(conn_param)); + conn_param.private_data = NULL; + conn_param.private_data_len = 0; + conn_param.responder_resources = P9_RDMA_IRD; + conn_param.initiator_depth = P9_RDMA_ORD; + err = rdma_connect(rdma->cm_id, &conn_param); + if (err) + goto error; + err = wait_for_completion_interruptible(&rdma->cm_done); + if (err || (rdma->state != P9_RDMA_CONNECTED)) + goto error; + + client->status = Connected; + + return 0; + +error: + rdma_destroy_trans(rdma); + return -ENOTCONN; +} + +static struct p9_trans_module p9_rdma_trans = { + .name = "rdma", + .maxsize = P9_RDMA_MAXSIZE, + .pooled_rbuffers = true, + .def = 0, + .owner = THIS_MODULE, + .create = rdma_create_trans, + .close = rdma_close, + .request = rdma_request, + .cancel = rdma_cancel, + .cancelled = rdma_cancelled, + .show_options = p9_rdma_show_options, +}; + +/** + * p9_trans_rdma_init - Register the 9P RDMA transport driver + */ +static int __init p9_trans_rdma_init(void) +{ + v9fs_register_trans(&p9_rdma_trans); + return 0; +} + +static void __exit p9_trans_rdma_exit(void) +{ + v9fs_unregister_trans(&p9_rdma_trans); +} + +module_init(p9_trans_rdma_init); +module_exit(p9_trans_rdma_exit); +MODULE_ALIAS_9P("rdma"); + +MODULE_AUTHOR("Tom Tucker "); +MODULE_DESCRIPTION("RDMA Transport for 9P"); +MODULE_LICENSE("Dual BSD/GPL"); diff --git a/net/9p/trans_virtio.c b/net/9p/trans_virtio.c new file mode 100644 index 000000000..3cf660d8a --- /dev/null +++ b/net/9p/trans_virtio.c @@ -0,0 +1,839 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The Virtio 9p transport driver + * + * This is a block based transport driver based on the lguest block driver + * code. 
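+ * Requests and replies are exchanged over a single virtqueue as scatter-gather lists.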
+ * + * Copyright (C) 2007, 2008 Eric Van Hensbergen, IBM Corporation + * + * Based on virtio console driver + * Copyright (C) 2006, 2007 Rusty Russell, IBM Corporation + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "trans_common.h" + +#define VIRTQUEUE_NUM 128 + +/* a single mutex to manage channel initialization and attachment */ +static DEFINE_MUTEX(virtio_9p_lock); +static DECLARE_WAIT_QUEUE_HEAD(vp_wq); +static atomic_t vp_pinned = ATOMIC_INIT(0); + +/** + * struct virtio_chan - per-instance transport information + * @inuse: whether the channel is in use + * @lock: protects multiple elements within this structure + * @client: client instance + * @vdev: virtio dev associated with this channel + * @vq: virtio queue associated with this channel + * @ring_bufs_avail: flag to indicate there is some available in the ring buf + * @vc_wq: wait queue for waiting for thing to be added to ring buf + * @p9_max_pages: maximum number of pinned pages + * @sg: scatter gather list which is used to pack a request (protected?) + * @chan_list: linked list of channels + * + * We keep all per-channel information in a structure. + * This structure is allocated within the devices dev->mem space. + * A pointer to the structure will get put in the transport private. + * + */ + +struct virtio_chan { + bool inuse; + + spinlock_t lock; + + struct p9_client *client; + struct virtio_device *vdev; + struct virtqueue *vq; + int ring_bufs_avail; + wait_queue_head_t *vc_wq; + /* This is global limit. Since we don't have a global structure, + * will be placing it in each channel. + */ + unsigned long p9_max_pages; + /* Scatterlist: can be too big for stack. */ + struct scatterlist sg[VIRTQUEUE_NUM]; + /** + * @tag: name to identify a mount null terminated + */ + char *tag; + + struct list_head chan_list; +}; + +static struct list_head virtio_chan_list; + +/* How many bytes left in this page. */ +static unsigned int rest_of_page(void *data) +{ + return PAGE_SIZE - offset_in_page(data); +} + +/** + * p9_virtio_close - reclaim resources of a channel + * @client: client instance + * + * This reclaims a channel by freeing its resources and + * resetting its inuse flag. + * + */ + +static void p9_virtio_close(struct p9_client *client) +{ + struct virtio_chan *chan = client->trans; + + mutex_lock(&virtio_9p_lock); + if (chan) + chan->inuse = false; + mutex_unlock(&virtio_9p_lock); +} + +/** + * req_done - callback which signals activity from the server + * @vq: virtio queue activity was received on + * + * This notifies us that the server has triggered some activity + * on the virtio channel - most likely a response to request we + * sent. Figure out which requests now have responses and wake up + * those threads. + * + * Bugs: could do with some additional sanity checking, but appears to work. 
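+ * Completed buffers are drained from the virtqueue under chan->lock and their requests are marked as received.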
+ * + */ + +static void req_done(struct virtqueue *vq) +{ + struct virtio_chan *chan = vq->vdev->priv; + unsigned int len; + struct p9_req_t *req; + bool need_wakeup = false; + unsigned long flags; + + p9_debug(P9_DEBUG_TRANS, ": request done\n"); + + spin_lock_irqsave(&chan->lock, flags); + while ((req = virtqueue_get_buf(chan->vq, &len)) != NULL) { + if (!chan->ring_bufs_avail) { + chan->ring_bufs_avail = 1; + need_wakeup = true; + } + + if (len) { + req->rc.size = len; + p9_client_cb(chan->client, req, REQ_STATUS_RCVD); + } + } + spin_unlock_irqrestore(&chan->lock, flags); + /* Wakeup if anyone waiting for VirtIO ring space. */ + if (need_wakeup) + wake_up(chan->vc_wq); +} + +/** + * pack_sg_list - pack a scatter gather list from a linear buffer + * @sg: scatter/gather list to pack into + * @start: which segment of the sg_list to start at + * @limit: maximum segment to pack data to + * @data: data to pack into scatter/gather list + * @count: amount of data to pack into the scatter/gather list + * + * sg_lists have multiple segments of various sizes. This will pack + * arbitrary data into an existing scatter gather list, segmenting the + * data as necessary within constraints. + * + */ + +static int pack_sg_list(struct scatterlist *sg, int start, + int limit, char *data, int count) +{ + int s; + int index = start; + + while (count) { + s = rest_of_page(data); + if (s > count) + s = count; + BUG_ON(index >= limit); + /* Make sure we don't terminate early. */ + sg_unmark_end(&sg[index]); + sg_set_buf(&sg[index++], data, s); + count -= s; + data += s; + } + if (index-start) + sg_mark_end(&sg[index - 1]); + return index-start; +} + +/* We don't currently allow canceling of virtio requests */ +static int p9_virtio_cancel(struct p9_client *client, struct p9_req_t *req) +{ + return 1; +} + +/* Reply won't come, so drop req ref */ +static int p9_virtio_cancelled(struct p9_client *client, struct p9_req_t *req) +{ + p9_req_put(client, req); + return 0; +} + +/** + * pack_sg_list_p - Just like pack_sg_list. Instead of taking a buffer, + * this takes a list of pages. + * @sg: scatter/gather list to pack into + * @start: which segment of the sg_list to start at + * @limit: maximum number of pages in sg list. + * @pdata: a list of pages to add into sg. + * @nr_pages: number of pages to pack into the scatter/gather list + * @offs: amount of data in the beginning of first page _not_ to pack + * @count: amount of data to pack into the scatter/gather list + */ +static int +pack_sg_list_p(struct scatterlist *sg, int start, int limit, + struct page **pdata, int nr_pages, size_t offs, int count) +{ + int i = 0, s; + int data_off = offs; + int index = start; + + BUG_ON(nr_pages > (limit - start)); + /* + * if the first page doesn't start at + * page boundary find the offset + */ + while (nr_pages) { + s = PAGE_SIZE - data_off; + if (s > count) + s = count; + BUG_ON(index >= limit); + /* Make sure we don't terminate early. 
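+ * (a previous pack may have marked this entry as the end of the list)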
*/ + sg_unmark_end(&sg[index]); + sg_set_page(&sg[index++], pdata[i++], s, data_off); + data_off = 0; + count -= s; + nr_pages--; + } + + if (index-start) + sg_mark_end(&sg[index - 1]); + return index - start; +} + +/** + * p9_virtio_request - issue a request + * @client: client instance issuing the request + * @req: request to be issued + * + */ + +static int +p9_virtio_request(struct p9_client *client, struct p9_req_t *req) +{ + int err; + int in, out, out_sgs, in_sgs; + unsigned long flags; + struct virtio_chan *chan = client->trans; + struct scatterlist *sgs[2]; + + p9_debug(P9_DEBUG_TRANS, "9p debug: virtio request\n"); + + WRITE_ONCE(req->status, REQ_STATUS_SENT); +req_retry: + spin_lock_irqsave(&chan->lock, flags); + + out_sgs = in_sgs = 0; + /* Handle out VirtIO ring buffers */ + out = pack_sg_list(chan->sg, 0, + VIRTQUEUE_NUM, req->tc.sdata, req->tc.size); + if (out) + sgs[out_sgs++] = chan->sg; + + in = pack_sg_list(chan->sg, out, + VIRTQUEUE_NUM, req->rc.sdata, req->rc.capacity); + if (in) + sgs[out_sgs + in_sgs++] = chan->sg + out; + + err = virtqueue_add_sgs(chan->vq, sgs, out_sgs, in_sgs, req, + GFP_ATOMIC); + if (err < 0) { + if (err == -ENOSPC) { + chan->ring_bufs_avail = 0; + spin_unlock_irqrestore(&chan->lock, flags); + err = wait_event_killable(*chan->vc_wq, + chan->ring_bufs_avail); + if (err == -ERESTARTSYS) + return err; + + p9_debug(P9_DEBUG_TRANS, "Retry virtio request\n"); + goto req_retry; + } else { + spin_unlock_irqrestore(&chan->lock, flags); + p9_debug(P9_DEBUG_TRANS, + "virtio rpc add_sgs returned failure\n"); + return -EIO; + } + } + virtqueue_kick(chan->vq); + spin_unlock_irqrestore(&chan->lock, flags); + + p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n"); + return 0; +} + +static int p9_get_mapped_pages(struct virtio_chan *chan, + struct page ***pages, + struct iov_iter *data, + int count, + size_t *offs, + int *need_drop) +{ + int nr_pages; + int err; + + if (!iov_iter_count(data)) + return 0; + + if (!iov_iter_is_kvec(data)) { + int n; + /* + * We allow only p9_max_pages pinned. We wait for the + * Other zc request to finish here + */ + if (atomic_read(&vp_pinned) >= chan->p9_max_pages) { + err = wait_event_killable(vp_wq, + (atomic_read(&vp_pinned) < chan->p9_max_pages)); + if (err == -ERESTARTSYS) + return err; + } + n = iov_iter_get_pages_alloc2(data, pages, count, offs); + if (n < 0) + return n; + *need_drop = 1; + nr_pages = DIV_ROUND_UP(n + *offs, PAGE_SIZE); + atomic_add(nr_pages, &vp_pinned); + return n; + } else { + /* kernel buffer, no need to pin pages */ + int index; + size_t len; + void *p; + + /* we'd already checked that it's non-empty */ + while (1) { + len = iov_iter_single_seg_count(data); + if (likely(len)) { + p = data->kvec->iov_base + data->iov_offset; + break; + } + iov_iter_advance(data, 0); + } + if (len > count) + len = count; + + nr_pages = DIV_ROUND_UP((unsigned long)p + len, PAGE_SIZE) - + (unsigned long)p / PAGE_SIZE; + + *pages = kmalloc_array(nr_pages, sizeof(struct page *), + GFP_NOFS); + if (!*pages) + return -ENOMEM; + + *need_drop = 0; + p -= (*offs = offset_in_page(p)); + for (index = 0; index < nr_pages; index++) { + if (is_vmalloc_addr(p)) + (*pages)[index] = vmalloc_to_page(p); + else + (*pages)[index] = kmap_to_page(p); + p += PAGE_SIZE; + } + iov_iter_advance(data, len); + return len; + } +} + +static void handle_rerror(struct p9_req_t *req, int in_hdr_len, + size_t offs, struct page **pages) +{ + unsigned size, n; + void *to = req->rc.sdata + in_hdr_len; + + // Fits entirely into the static data? 
Nothing to do. + if (req->rc.size < in_hdr_len || !pages) + return; + + // Really long error message? Tough, truncate the reply. Might get + // rejected (we can't be arsed to adjust the size encoded in header, + // or string size for that matter), but it wouldn't be anything valid + // anyway. + if (unlikely(req->rc.size > P9_ZC_HDR_SZ)) + req->rc.size = P9_ZC_HDR_SZ; + + // data won't span more than two pages + size = req->rc.size - in_hdr_len; + n = PAGE_SIZE - offs; + if (size > n) { + memcpy_from_page(to, *pages++, offs, n); + offs = 0; + to += n; + size -= n; + } + memcpy_from_page(to, *pages, offs, size); +} + +/** + * p9_virtio_zc_request - issue a zero copy request + * @client: client instance issuing the request + * @req: request to be issued + * @uidata: user buffer that should be used for zero copy read + * @uodata: user buffer that should be used for zero copy write + * @inlen: read buffer size + * @outlen: write buffer size + * @in_hdr_len: reader header size, This is the size of response protocol data + * + */ +static int +p9_virtio_zc_request(struct p9_client *client, struct p9_req_t *req, + struct iov_iter *uidata, struct iov_iter *uodata, + int inlen, int outlen, int in_hdr_len) +{ + int in, out, err, out_sgs, in_sgs; + unsigned long flags; + int in_nr_pages = 0, out_nr_pages = 0; + struct page **in_pages = NULL, **out_pages = NULL; + struct virtio_chan *chan = client->trans; + struct scatterlist *sgs[4]; + size_t offs = 0; + int need_drop = 0; + int kicked = 0; + + p9_debug(P9_DEBUG_TRANS, "virtio request\n"); + + if (uodata) { + __le32 sz; + int n = p9_get_mapped_pages(chan, &out_pages, uodata, + outlen, &offs, &need_drop); + if (n < 0) { + err = n; + goto err_out; + } + out_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != outlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4); + outlen = n; + } + /* The size field of the message must include the length of the + * header and the length of the data. We didn't actually know + * the length of the data until this point so add it in now. + */ + sz = cpu_to_le32(req->tc.size + outlen); + memcpy(&req->tc.sdata[0], &sz, sizeof(sz)); + } else if (uidata) { + int n = p9_get_mapped_pages(chan, &in_pages, uidata, + inlen, &offs, &need_drop); + if (n < 0) { + err = n; + goto err_out; + } + in_nr_pages = DIV_ROUND_UP(n + offs, PAGE_SIZE); + if (n != inlen) { + __le32 v = cpu_to_le32(n); + memcpy(&req->tc.sdata[req->tc.size - 4], &v, 4); + inlen = n; + } + } + WRITE_ONCE(req->status, REQ_STATUS_SENT); +req_retry_pinned: + spin_lock_irqsave(&chan->lock, flags); + + out_sgs = in_sgs = 0; + + /* out data */ + out = pack_sg_list(chan->sg, 0, + VIRTQUEUE_NUM, req->tc.sdata, req->tc.size); + + if (out) + sgs[out_sgs++] = chan->sg; + + if (out_pages) { + sgs[out_sgs++] = chan->sg + out; + out += pack_sg_list_p(chan->sg, out, VIRTQUEUE_NUM, + out_pages, out_nr_pages, offs, outlen); + } + + /* + * Take care of in data + * For example TREAD have 11. + * 11 is the read/write header = PDU Header(7) + IO Size (4). + * Arrange in such a way that server places header in the + * allocated memory and payload onto the user buffer. 
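+ * (in_hdr_len covers only the protocol header; the payload pages are added as a separate sg list below)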
+ */ + in = pack_sg_list(chan->sg, out, + VIRTQUEUE_NUM, req->rc.sdata, in_hdr_len); + if (in) + sgs[out_sgs + in_sgs++] = chan->sg + out; + + if (in_pages) { + sgs[out_sgs + in_sgs++] = chan->sg + out + in; + in += pack_sg_list_p(chan->sg, out + in, VIRTQUEUE_NUM, + in_pages, in_nr_pages, offs, inlen); + } + + BUG_ON(out_sgs + in_sgs > ARRAY_SIZE(sgs)); + err = virtqueue_add_sgs(chan->vq, sgs, out_sgs, in_sgs, req, + GFP_ATOMIC); + if (err < 0) { + if (err == -ENOSPC) { + chan->ring_bufs_avail = 0; + spin_unlock_irqrestore(&chan->lock, flags); + err = wait_event_killable(*chan->vc_wq, + chan->ring_bufs_avail); + if (err == -ERESTARTSYS) + goto err_out; + + p9_debug(P9_DEBUG_TRANS, "Retry virtio request\n"); + goto req_retry_pinned; + } else { + spin_unlock_irqrestore(&chan->lock, flags); + p9_debug(P9_DEBUG_TRANS, + "virtio rpc add_sgs returned failure\n"); + err = -EIO; + goto err_out; + } + } + virtqueue_kick(chan->vq); + spin_unlock_irqrestore(&chan->lock, flags); + kicked = 1; + p9_debug(P9_DEBUG_TRANS, "virtio request kicked\n"); + err = wait_event_killable(req->wq, + READ_ONCE(req->status) >= REQ_STATUS_RCVD); + // RERROR needs reply (== error string) in static data + if (READ_ONCE(req->status) == REQ_STATUS_RCVD && + unlikely(req->rc.sdata[4] == P9_RERROR)) + handle_rerror(req, in_hdr_len, offs, in_pages); + + /* + * Non kernel buffers are pinned, unpin them + */ +err_out: + if (need_drop) { + if (in_pages) { + p9_release_pages(in_pages, in_nr_pages); + atomic_sub(in_nr_pages, &vp_pinned); + } + if (out_pages) { + p9_release_pages(out_pages, out_nr_pages); + atomic_sub(out_nr_pages, &vp_pinned); + } + /* wakeup anybody waiting for slots to pin pages */ + wake_up(&vp_wq); + } + kvfree(in_pages); + kvfree(out_pages); + if (!kicked) { + /* reply won't come */ + p9_req_put(client, req); + } + return err; +} + +static ssize_t p9_mount_tag_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct virtio_chan *chan; + struct virtio_device *vdev; + int tag_len; + + vdev = dev_to_virtio(dev); + chan = vdev->priv; + tag_len = strlen(chan->tag); + + memcpy(buf, chan->tag, tag_len + 1); + + return tag_len + 1; +} + +static DEVICE_ATTR(mount_tag, 0444, p9_mount_tag_show, NULL); + +/** + * p9_virtio_probe - probe for existence of 9P virtio channels + * @vdev: virtio device to probe + * + * This probes for existing virtio channels. + * + */ + +static int p9_virtio_probe(struct virtio_device *vdev) +{ + __u16 tag_len; + char *tag; + int err; + struct virtio_chan *chan; + + if (!vdev->config->get) { + dev_err(&vdev->dev, "%s failure: config access disabled\n", + __func__); + return -EINVAL; + } + + chan = kmalloc(sizeof(struct virtio_chan), GFP_KERNEL); + if (!chan) { + pr_err("Failed to allocate virtio 9P channel\n"); + err = -ENOMEM; + goto fail; + } + + chan->vdev = vdev; + + /* We expect one virtqueue, for requests. 
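+ * req_done() is registered as its completion callback.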
*/ + chan->vq = virtio_find_single_vq(vdev, req_done, "requests"); + if (IS_ERR(chan->vq)) { + err = PTR_ERR(chan->vq); + goto out_free_chan; + } + chan->vq->vdev->priv = chan; + spin_lock_init(&chan->lock); + + sg_init_table(chan->sg, VIRTQUEUE_NUM); + + chan->inuse = false; + if (virtio_has_feature(vdev, VIRTIO_9P_MOUNT_TAG)) { + virtio_cread(vdev, struct virtio_9p_config, tag_len, &tag_len); + } else { + err = -EINVAL; + goto out_free_vq; + } + tag = kzalloc(tag_len + 1, GFP_KERNEL); + if (!tag) { + err = -ENOMEM; + goto out_free_vq; + } + + virtio_cread_bytes(vdev, offsetof(struct virtio_9p_config, tag), + tag, tag_len); + chan->tag = tag; + err = sysfs_create_file(&(vdev->dev.kobj), &dev_attr_mount_tag.attr); + if (err) { + goto out_free_tag; + } + chan->vc_wq = kmalloc(sizeof(wait_queue_head_t), GFP_KERNEL); + if (!chan->vc_wq) { + err = -ENOMEM; + goto out_remove_file; + } + init_waitqueue_head(chan->vc_wq); + chan->ring_bufs_avail = 1; + /* Ceiling limit to avoid denial of service attacks */ + chan->p9_max_pages = nr_free_buffer_pages()/4; + + virtio_device_ready(vdev); + + mutex_lock(&virtio_9p_lock); + list_add_tail(&chan->chan_list, &virtio_chan_list); + mutex_unlock(&virtio_9p_lock); + + /* Let udev rules use the new mount_tag attribute. */ + kobject_uevent(&(vdev->dev.kobj), KOBJ_CHANGE); + + return 0; + +out_remove_file: + sysfs_remove_file(&vdev->dev.kobj, &dev_attr_mount_tag.attr); +out_free_tag: + kfree(tag); +out_free_vq: + vdev->config->del_vqs(vdev); +out_free_chan: + kfree(chan); +fail: + return err; +} + + +/** + * p9_virtio_create - allocate a new virtio channel + * @client: client instance invoking this transport + * @devname: string identifying the channel to connect to (unused) + * @args: args passed from sys_mount() for per-transport options (unused) + * + * This sets up a transport channel for 9p communication. Right now + * we only match the first available channel, but eventually we could look up + * alternate channels by matching devname versus a virtio_config entry. + * We use a simple reference count mechanism to ensure that only a single + * mount has a channel open at a time. + * + */ + +static int +p9_virtio_create(struct p9_client *client, const char *devname, char *args) +{ + struct virtio_chan *chan; + int ret = -ENOENT; + int found = 0; + + if (devname == NULL) + return -EINVAL; + + mutex_lock(&virtio_9p_lock); + list_for_each_entry(chan, &virtio_chan_list, chan_list) { + if (!strcmp(devname, chan->tag)) { + if (!chan->inuse) { + chan->inuse = true; + found = 1; + break; + } + ret = -EBUSY; + } + } + mutex_unlock(&virtio_9p_lock); + + if (!found) { + pr_err("no channels available for device %s\n", devname); + return ret; + } + + client->trans = (void *)chan; + client->status = Connected; + chan->client = client; + + return 0; +} + +/** + * p9_virtio_remove - clean up resources associated with a virtio device + * @vdev: virtio device to remove + * + */ + +static void p9_virtio_remove(struct virtio_device *vdev) +{ + struct virtio_chan *chan = vdev->priv; + unsigned long warning_time; + + mutex_lock(&virtio_9p_lock); + + /* Remove self from list so we don't get new users. */ + list_del(&chan->chan_list); + warning_time = jiffies; + + /* Wait for existing users to close. 
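+ * The channel mutex is dropped while sleeping so p9_virtio_close() can clear the inuse flag.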
*/ + while (chan->inuse) { + mutex_unlock(&virtio_9p_lock); + msleep(250); + if (time_after(jiffies, warning_time + 10 * HZ)) { + dev_emerg(&vdev->dev, + "p9_virtio_remove: waiting for device in use.\n"); + warning_time = jiffies; + } + mutex_lock(&virtio_9p_lock); + } + + mutex_unlock(&virtio_9p_lock); + + virtio_reset_device(vdev); + vdev->config->del_vqs(vdev); + + sysfs_remove_file(&(vdev->dev.kobj), &dev_attr_mount_tag.attr); + kobject_uevent(&(vdev->dev.kobj), KOBJ_CHANGE); + kfree(chan->tag); + kfree(chan->vc_wq); + kfree(chan); + +} + +static struct virtio_device_id id_table[] = { + { VIRTIO_ID_9P, VIRTIO_DEV_ANY_ID }, + { 0 }, +}; + +static unsigned int features[] = { + VIRTIO_9P_MOUNT_TAG, +}; + +/* The standard "struct lguest_driver": */ +static struct virtio_driver p9_virtio_drv = { + .feature_table = features, + .feature_table_size = ARRAY_SIZE(features), + .driver.name = KBUILD_MODNAME, + .driver.owner = THIS_MODULE, + .id_table = id_table, + .probe = p9_virtio_probe, + .remove = p9_virtio_remove, +}; + +static struct p9_trans_module p9_virtio_trans = { + .name = "virtio", + .create = p9_virtio_create, + .close = p9_virtio_close, + .request = p9_virtio_request, + .zc_request = p9_virtio_zc_request, + .cancel = p9_virtio_cancel, + .cancelled = p9_virtio_cancelled, + /* + * We leave one entry for input and one entry for response + * headers. We also skip one more entry to accommodate, address + * that are not at page boundary, that can result in an extra + * page in zero copy. + */ + .maxsize = PAGE_SIZE * (VIRTQUEUE_NUM - 3), + .pooled_rbuffers = false, + .def = 1, + .owner = THIS_MODULE, +}; + +/* The standard init function */ +static int __init p9_virtio_init(void) +{ + int rc; + + INIT_LIST_HEAD(&virtio_chan_list); + + v9fs_register_trans(&p9_virtio_trans); + rc = register_virtio_driver(&p9_virtio_drv); + if (rc) + v9fs_unregister_trans(&p9_virtio_trans); + + return rc; +} + +static void __exit p9_virtio_cleanup(void) +{ + unregister_virtio_driver(&p9_virtio_drv); + v9fs_unregister_trans(&p9_virtio_trans); +} + +module_init(p9_virtio_init); +module_exit(p9_virtio_cleanup); +MODULE_ALIAS_9P("virtio"); + +MODULE_DEVICE_TABLE(virtio, id_table); +MODULE_AUTHOR("Eric Van Hensbergen "); +MODULE_DESCRIPTION("Virtio 9p Transport"); +MODULE_LICENSE("GPL"); diff --git a/net/9p/trans_xen.c b/net/9p/trans_xen.c new file mode 100644 index 000000000..68027e4fb --- /dev/null +++ b/net/9p/trans_xen.c @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/fs/9p/trans_xen + * + * Xen transport layer. 
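+ * Requests are written into shared ring buffers and the backend is notified through event channels.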
+ * + * Copyright (C) 2017 by Stefano Stabellini + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define XEN_9PFS_NUM_RINGS 2 +#define XEN_9PFS_RING_ORDER 9 +#define XEN_9PFS_RING_SIZE(ring) XEN_FLEX_RING_SIZE(ring->intf->ring_order) + +struct xen_9pfs_header { + uint32_t size; + uint8_t id; + uint16_t tag; + + /* uint8_t sdata[]; */ +} __attribute__((packed)); + +/* One per ring, more than one per 9pfs share */ +struct xen_9pfs_dataring { + struct xen_9pfs_front_priv *priv; + + struct xen_9pfs_data_intf *intf; + grant_ref_t ref; + int evtchn; + int irq; + /* protect a ring from concurrent accesses */ + spinlock_t lock; + + struct xen_9pfs_data data; + wait_queue_head_t wq; + struct work_struct work; +}; + +/* One per 9pfs share */ +struct xen_9pfs_front_priv { + struct list_head list; + struct xenbus_device *dev; + char *tag; + struct p9_client *client; + + int num_rings; + struct xen_9pfs_dataring *rings; +}; + +static LIST_HEAD(xen_9pfs_devs); +static DEFINE_RWLOCK(xen_9pfs_lock); + +/* We don't currently allow canceling of requests */ +static int p9_xen_cancel(struct p9_client *client, struct p9_req_t *req) +{ + return 1; +} + +static int p9_xen_create(struct p9_client *client, const char *addr, char *args) +{ + struct xen_9pfs_front_priv *priv; + + if (addr == NULL) + return -EINVAL; + + read_lock(&xen_9pfs_lock); + list_for_each_entry(priv, &xen_9pfs_devs, list) { + if (!strcmp(priv->tag, addr)) { + priv->client = client; + read_unlock(&xen_9pfs_lock); + return 0; + } + } + read_unlock(&xen_9pfs_lock); + return -EINVAL; +} + +static void p9_xen_close(struct p9_client *client) +{ + struct xen_9pfs_front_priv *priv; + + read_lock(&xen_9pfs_lock); + list_for_each_entry(priv, &xen_9pfs_devs, list) { + if (priv->client == client) { + priv->client = NULL; + read_unlock(&xen_9pfs_lock); + return; + } + } + read_unlock(&xen_9pfs_lock); +} + +static bool p9_xen_write_todo(struct xen_9pfs_dataring *ring, RING_IDX size) +{ + RING_IDX cons, prod; + + cons = ring->intf->out_cons; + prod = ring->intf->out_prod; + virt_mb(); + + return XEN_9PFS_RING_SIZE(ring) - + xen_9pfs_queued(prod, cons, XEN_9PFS_RING_SIZE(ring)) >= size; +} + +static int p9_xen_request(struct p9_client *client, struct p9_req_t *p9_req) +{ + struct xen_9pfs_front_priv *priv; + RING_IDX cons, prod, masked_cons, masked_prod; + unsigned long flags; + u32 size = p9_req->tc.size; + struct xen_9pfs_dataring *ring; + int num; + + read_lock(&xen_9pfs_lock); + list_for_each_entry(priv, &xen_9pfs_devs, list) { + if (priv->client == client) + break; + } + read_unlock(&xen_9pfs_lock); + if (list_entry_is_head(priv, &xen_9pfs_devs, list)) + return -EINVAL; + + num = p9_req->tc.tag % priv->num_rings; + ring = &priv->rings[num]; + +again: + while (wait_event_killable(ring->wq, + p9_xen_write_todo(ring, size)) != 0) + ; + + spin_lock_irqsave(&ring->lock, flags); + cons = ring->intf->out_cons; + prod = ring->intf->out_prod; + virt_mb(); + + if (XEN_9PFS_RING_SIZE(ring) - + xen_9pfs_queued(prod, cons, XEN_9PFS_RING_SIZE(ring)) < size) { + spin_unlock_irqrestore(&ring->lock, flags); + goto again; + } + + masked_prod = xen_9pfs_mask(prod, XEN_9PFS_RING_SIZE(ring)); + masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE(ring)); + + xen_9pfs_write_packet(ring->data.out, p9_req->tc.sdata, size, + &masked_prod, masked_cons, + XEN_9PFS_RING_SIZE(ring)); + + WRITE_ONCE(p9_req->status, REQ_STATUS_SENT); + virt_wmb(); /* write ring before updating pointer */ + prod += size; + 
ring->intf->out_prod = prod; + spin_unlock_irqrestore(&ring->lock, flags); + notify_remote_via_irq(ring->irq); + p9_req_put(client, p9_req); + + return 0; +} + +static void p9_xen_response(struct work_struct *work) +{ + struct xen_9pfs_front_priv *priv; + struct xen_9pfs_dataring *ring; + RING_IDX cons, prod, masked_cons, masked_prod; + struct xen_9pfs_header h; + struct p9_req_t *req; + int status; + + ring = container_of(work, struct xen_9pfs_dataring, work); + priv = ring->priv; + + while (1) { + cons = ring->intf->in_cons; + prod = ring->intf->in_prod; + virt_rmb(); + + if (xen_9pfs_queued(prod, cons, XEN_9PFS_RING_SIZE(ring)) < + sizeof(h)) { + notify_remote_via_irq(ring->irq); + return; + } + + masked_prod = xen_9pfs_mask(prod, XEN_9PFS_RING_SIZE(ring)); + masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE(ring)); + + /* First, read just the header */ + xen_9pfs_read_packet(&h, ring->data.in, sizeof(h), + masked_prod, &masked_cons, + XEN_9PFS_RING_SIZE(ring)); + + req = p9_tag_lookup(priv->client, h.tag); + if (!req || req->status != REQ_STATUS_SENT) { + dev_warn(&priv->dev->dev, "Wrong req tag=%x\n", h.tag); + cons += h.size; + virt_mb(); + ring->intf->in_cons = cons; + continue; + } + + if (h.size > req->rc.capacity) { + dev_warn(&priv->dev->dev, + "requested packet size too big: %d for tag %d with capacity %zd\n", + h.size, h.tag, req->rc.capacity); + WRITE_ONCE(req->status, REQ_STATUS_ERROR); + goto recv_error; + } + + memcpy(&req->rc, &h, sizeof(h)); + req->rc.offset = 0; + + masked_cons = xen_9pfs_mask(cons, XEN_9PFS_RING_SIZE(ring)); + /* Then, read the whole packet (including the header) */ + xen_9pfs_read_packet(req->rc.sdata, ring->data.in, h.size, + masked_prod, &masked_cons, + XEN_9PFS_RING_SIZE(ring)); + +recv_error: + virt_mb(); + cons += h.size; + ring->intf->in_cons = cons; + + status = (req->status != REQ_STATUS_ERROR) ? 
+ REQ_STATUS_RCVD : REQ_STATUS_ERROR; + + p9_client_cb(priv->client, req, status); + } +} + +static irqreturn_t xen_9pfs_front_event_handler(int irq, void *r) +{ + struct xen_9pfs_dataring *ring = r; + + if (!ring || !ring->priv->client) { + /* ignore spurious interrupt */ + return IRQ_HANDLED; + } + + wake_up_interruptible(&ring->wq); + schedule_work(&ring->work); + + return IRQ_HANDLED; +} + +static struct p9_trans_module p9_xen_trans = { + .name = "xen", + .maxsize = 1 << (XEN_9PFS_RING_ORDER + XEN_PAGE_SHIFT - 2), + .pooled_rbuffers = false, + .def = 1, + .create = p9_xen_create, + .close = p9_xen_close, + .request = p9_xen_request, + .cancel = p9_xen_cancel, + .owner = THIS_MODULE, +}; + +static const struct xenbus_device_id xen_9pfs_front_ids[] = { + { "9pfs" }, + { "" } +}; + +static void xen_9pfs_front_free(struct xen_9pfs_front_priv *priv) +{ + int i, j; + + write_lock(&xen_9pfs_lock); + list_del(&priv->list); + write_unlock(&xen_9pfs_lock); + + for (i = 0; i < priv->num_rings; i++) { + struct xen_9pfs_dataring *ring = &priv->rings[i]; + + cancel_work_sync(&ring->work); + + if (!priv->rings[i].intf) + break; + if (priv->rings[i].irq > 0) + unbind_from_irqhandler(priv->rings[i].irq, priv->dev); + if (priv->rings[i].data.in) { + for (j = 0; + j < (1 << priv->rings[i].intf->ring_order); + j++) { + grant_ref_t ref; + + ref = priv->rings[i].intf->ref[j]; + gnttab_end_foreign_access(ref, NULL); + } + free_pages_exact(priv->rings[i].data.in, + 1UL << (priv->rings[i].intf->ring_order + + XEN_PAGE_SHIFT)); + } + gnttab_end_foreign_access(priv->rings[i].ref, NULL); + free_page((unsigned long)priv->rings[i].intf); + } + kfree(priv->rings); + kfree(priv->tag); + kfree(priv); +} + +static int xen_9pfs_front_remove(struct xenbus_device *dev) +{ + struct xen_9pfs_front_priv *priv = dev_get_drvdata(&dev->dev); + + dev_set_drvdata(&dev->dev, NULL); + xen_9pfs_front_free(priv); + return 0; +} + +static int xen_9pfs_front_alloc_dataring(struct xenbus_device *dev, + struct xen_9pfs_dataring *ring, + unsigned int order) +{ + int i = 0; + int ret = -ENOMEM; + void *bytes = NULL; + + init_waitqueue_head(&ring->wq); + spin_lock_init(&ring->lock); + INIT_WORK(&ring->work, p9_xen_response); + + ring->intf = (struct xen_9pfs_data_intf *)get_zeroed_page(GFP_KERNEL); + if (!ring->intf) + return ret; + ret = gnttab_grant_foreign_access(dev->otherend_id, + virt_to_gfn(ring->intf), 0); + if (ret < 0) + goto out; + ring->ref = ret; + bytes = alloc_pages_exact(1UL << (order + XEN_PAGE_SHIFT), + GFP_KERNEL | __GFP_ZERO); + if (!bytes) { + ret = -ENOMEM; + goto out; + } + for (; i < (1 << order); i++) { + ret = gnttab_grant_foreign_access( + dev->otherend_id, virt_to_gfn(bytes) + i, 0); + if (ret < 0) + goto out; + ring->intf->ref[i] = ret; + } + ring->intf->ring_order = order; + ring->data.in = bytes; + ring->data.out = bytes + XEN_FLEX_RING_SIZE(order); + + ret = xenbus_alloc_evtchn(dev, &ring->evtchn); + if (ret) + goto out; + ring->irq = bind_evtchn_to_irqhandler(ring->evtchn, + xen_9pfs_front_event_handler, + 0, "xen_9pfs-frontend", ring); + if (ring->irq >= 0) + return 0; + + xenbus_free_evtchn(dev, ring->evtchn); + ret = ring->irq; +out: + if (bytes) { + for (i--; i >= 0; i--) + gnttab_end_foreign_access(ring->intf->ref[i], NULL); + free_pages_exact(bytes, 1UL << (order + XEN_PAGE_SHIFT)); + } + gnttab_end_foreign_access(ring->ref, NULL); + free_page((unsigned long)ring->intf); + return ret; +} + +static int xen_9pfs_front_init(struct xenbus_device *dev) +{ + int ret, i; + struct xenbus_transaction xbt; + 
struct xen_9pfs_front_priv *priv = dev_get_drvdata(&dev->dev); + char *versions, *v; + unsigned int max_rings, max_ring_order, len = 0; + + versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len); + if (IS_ERR(versions)) + return PTR_ERR(versions); + for (v = versions; *v; v++) { + if (simple_strtoul(v, &v, 10) == 1) { + v = NULL; + break; + } + } + if (v) { + kfree(versions); + return -EINVAL; + } + kfree(versions); + max_rings = xenbus_read_unsigned(dev->otherend, "max-rings", 0); + if (max_rings < XEN_9PFS_NUM_RINGS) + return -EINVAL; + max_ring_order = xenbus_read_unsigned(dev->otherend, + "max-ring-page-order", 0); + if (max_ring_order > XEN_9PFS_RING_ORDER) + max_ring_order = XEN_9PFS_RING_ORDER; + if (p9_xen_trans.maxsize > XEN_FLEX_RING_SIZE(max_ring_order)) + p9_xen_trans.maxsize = XEN_FLEX_RING_SIZE(max_ring_order) / 2; + + priv->num_rings = XEN_9PFS_NUM_RINGS; + priv->rings = kcalloc(priv->num_rings, sizeof(*priv->rings), + GFP_KERNEL); + if (!priv->rings) { + kfree(priv); + return -ENOMEM; + } + + for (i = 0; i < priv->num_rings; i++) { + priv->rings[i].priv = priv; + ret = xen_9pfs_front_alloc_dataring(dev, &priv->rings[i], + max_ring_order); + if (ret < 0) + goto error; + } + + again: + ret = xenbus_transaction_start(&xbt); + if (ret) { + xenbus_dev_fatal(dev, ret, "starting transaction"); + goto error; + } + ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1); + if (ret) + goto error_xenbus; + ret = xenbus_printf(xbt, dev->nodename, "num-rings", "%u", + priv->num_rings); + if (ret) + goto error_xenbus; + for (i = 0; i < priv->num_rings; i++) { + char str[16]; + + BUILD_BUG_ON(XEN_9PFS_NUM_RINGS > 9); + sprintf(str, "ring-ref%d", i); + ret = xenbus_printf(xbt, dev->nodename, str, "%d", + priv->rings[i].ref); + if (ret) + goto error_xenbus; + + sprintf(str, "event-channel-%d", i); + ret = xenbus_printf(xbt, dev->nodename, str, "%u", + priv->rings[i].evtchn); + if (ret) + goto error_xenbus; + } + priv->tag = xenbus_read(xbt, dev->nodename, "tag", NULL); + if (IS_ERR(priv->tag)) { + ret = PTR_ERR(priv->tag); + goto error_xenbus; + } + ret = xenbus_transaction_end(xbt, 0); + if (ret) { + if (ret == -EAGAIN) + goto again; + xenbus_dev_fatal(dev, ret, "completing transaction"); + goto error; + } + + return 0; + + error_xenbus: + xenbus_transaction_end(xbt, 1); + xenbus_dev_fatal(dev, ret, "writing xenstore"); + error: + xen_9pfs_front_free(priv); + return ret; +} + +static int xen_9pfs_front_probe(struct xenbus_device *dev, + const struct xenbus_device_id *id) +{ + struct xen_9pfs_front_priv *priv = NULL; + + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (!priv) + return -ENOMEM; + + priv->dev = dev; + dev_set_drvdata(&dev->dev, priv); + + write_lock(&xen_9pfs_lock); + list_add_tail(&priv->list, &xen_9pfs_devs); + write_unlock(&xen_9pfs_lock); + + return 0; +} + +static int xen_9pfs_front_resume(struct xenbus_device *dev) +{ + dev_warn(&dev->dev, "suspend/resume unsupported\n"); + return 0; +} + +static void xen_9pfs_front_changed(struct xenbus_device *dev, + enum xenbus_state backend_state) +{ + switch (backend_state) { + case XenbusStateReconfiguring: + case XenbusStateReconfigured: + case XenbusStateInitialising: + case XenbusStateInitialised: + case XenbusStateUnknown: + break; + + case XenbusStateInitWait: + if (!xen_9pfs_front_init(dev)) + xenbus_switch_state(dev, XenbusStateInitialised); + break; + + case XenbusStateConnected: + xenbus_switch_state(dev, XenbusStateConnected); + break; + + case XenbusStateClosed: + if (dev->state == XenbusStateClosed) + 
break; + fallthrough; /* Missed the backend's CLOSING state */ + case XenbusStateClosing: + xenbus_frontend_closed(dev); + break; + } +} + +static struct xenbus_driver xen_9pfs_front_driver = { + .ids = xen_9pfs_front_ids, + .probe = xen_9pfs_front_probe, + .remove = xen_9pfs_front_remove, + .resume = xen_9pfs_front_resume, + .otherend_changed = xen_9pfs_front_changed, +}; + +static int __init p9_trans_xen_init(void) +{ + int rc; + + if (!xen_domain()) + return -ENODEV; + + pr_info("Initialising Xen transport for 9pfs\n"); + + v9fs_register_trans(&p9_xen_trans); + rc = xenbus_register_frontend(&xen_9pfs_front_driver); + if (rc) + v9fs_unregister_trans(&p9_xen_trans); + + return rc; +} +module_init(p9_trans_xen_init); +MODULE_ALIAS_9P("xen"); + +static void __exit p9_trans_xen_exit(void) +{ + v9fs_unregister_trans(&p9_xen_trans); + return xenbus_unregister_driver(&xen_9pfs_front_driver); +} +module_exit(p9_trans_xen_exit); + +MODULE_ALIAS("xen:9pfs"); +MODULE_AUTHOR("Stefano Stabellini "); +MODULE_DESCRIPTION("Xen Transport for 9P"); +MODULE_LICENSE("GPL"); diff --git a/net/Kconfig b/net/Kconfig new file mode 100644 index 000000000..48c33c222 --- /dev/null +++ b/net/Kconfig @@ -0,0 +1,474 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Network configuration +# + +menuconfig NET + bool "Networking support" + select NLATTR + select GENERIC_NET_UTILS + select BPF + help + Unless you really know what you are doing, you should say Y here. + The reason is that some programs need kernel networking support even + when running on a stand-alone machine that isn't connected to any + other computer. + + If you are upgrading from an older kernel, you + should consider updating your networking tools too because changes + in the kernel and the tools often go hand in hand. The tools are + contained in the package net-tools, the location and version number + of which are given in . + + For a general introduction to Linux networking, it is highly + recommended to read the NET-HOWTO, available from + . + +if NET + +config WANT_COMPAT_NETLINK_MESSAGES + bool + help + This option can be selected by other options that need compat + netlink messages. + +config COMPAT_NETLINK_MESSAGES + def_bool y + depends on COMPAT + depends on WEXT_CORE || WANT_COMPAT_NETLINK_MESSAGES + help + This option makes it possible to send different netlink messages + to tasks depending on whether the task is a compat task or not. To + achieve this, you need to set skb_shinfo(skb)->frag_list to the + compat skb before sending the skb, the netlink code will sort out + which message to actually pass to the task. + + Newly written code should NEVER need this option but do + compat-independent messages instead! + +config NET_INGRESS + bool + +config NET_EGRESS + bool + +config NET_REDIRECT + bool + +config SKB_EXTENSIONS + bool + +menu "Networking options" + +source "net/packet/Kconfig" +source "net/unix/Kconfig" +source "net/tls/Kconfig" +source "net/xfrm/Kconfig" +source "net/iucv/Kconfig" +source "net/smc/Kconfig" +source "net/xdp/Kconfig" + +config INET + bool "TCP/IP networking" + help + These are the protocols used on the Internet and on most local + Ethernets. It is highly recommended to say Y here (this will enlarge + your kernel by about 400 KB), since some programs (e.g. the X window + system) use TCP/IP even if your machine is not connected to any + other computer. You will get the so-called loopback device which + allows you to ping yourself (great fun, that!). 
+ + For an excellent introduction to Linux networking, please read the + Linux Networking HOWTO, available from + . + + If you say Y here and also to "/proc file system support" and + "Sysctl support" below, you can change various aspects of the + behavior of the TCP/IP code by writing to the (virtual) files in + /proc/sys/net/ipv4/*; the options are explained in the file + . + + Short answer: say Y. + +if INET +source "net/ipv4/Kconfig" +source "net/ipv6/Kconfig" +source "net/netlabel/Kconfig" +source "net/mptcp/Kconfig" + +endif # if INET + +config NETWORK_SECMARK + bool "Security Marking" + help + This enables security marking of network packets, similar + to nfmark, but designated for security purposes. + If you are unsure how to answer this question, answer N. + +config NET_PTP_CLASSIFY + def_bool n + +config NETWORK_PHY_TIMESTAMPING + bool "Timestamping in PHY devices" + select NET_PTP_CLASSIFY + help + This allows timestamping of network packets by PHYs (or + other MII bus snooping devices) with hardware timestamping + capabilities. This option adds some overhead in the transmit + and receive paths. + + If you are unsure how to answer this question, answer N. + +menuconfig NETFILTER + bool "Network packet filtering framework (Netfilter)" + help + Netfilter is a framework for filtering and mangling network packets + that pass through your Linux box. + + The most common use of packet filtering is to run your Linux box as + a firewall protecting a local network from the Internet. The type of + firewall provided by this kernel support is called a "packet + filter", which means that it can reject individual network packets + based on type, source, destination etc. The other kind of firewall, + a "proxy-based" one, is more secure but more intrusive and more + bothersome to set up; it inspects the network traffic much more + closely, modifies it and has knowledge about the higher level + protocols, which a packet filter lacks. Moreover, proxy-based + firewalls often require changes to the programs running on the local + clients. Proxy-based firewalls don't need support by the kernel, but + they are often combined with a packet filter, which only works if + you say Y here. + + You should also say Y here if you intend to use your Linux box as + the gateway to the Internet for a local network of machines without + globally valid IP addresses. This is called "masquerading": if one + of the computers on your local network wants to send something to + the outside, your box can "masquerade" as that computer, i.e. it + forwards the traffic to the intended outside destination, but + modifies the packets to make it look like they came from the + firewall box itself. It works both ways: if the outside host + replies, the Linux box will silently forward the traffic to the + correct local computer. This way, the computers on your local net + are completely invisible to the outside world, even though they can + reach the outside and can receive replies. It is even possible to + run globally visible servers from within a masqueraded local network + using a mechanism called portforwarding. Masquerading is also often + called NAT (Network Address Translation). + + Another use of Netfilter is in transparent proxying: if a machine on + the local network tries to connect to an outside host, your Linux + box can transparently forward the traffic to a local server, + typically a caching proxy server. + + Yet another use of Netfilter is building a bridging firewall. 
Using + a bridge with Network packet filtering enabled makes iptables "see" + the bridged traffic. For filtering on the lower network and Ethernet + protocols over the bridge, use ebtables (under bridge netfilter + configuration). + + Various modules exist for netfilter which replace the previous + masquerading (ipmasqadm), packet filtering (ipchains), transparent + proxying, and portforwarding mechanisms. Please see + under "iptables" for the location of + these packages. + +if NETFILTER + +config NETFILTER_ADVANCED + bool "Advanced netfilter configuration" + depends on NETFILTER + default y + help + If you say Y here you can select between all the netfilter modules. + If you say N the more unusual ones will not be shown and the + basic ones needed by most people will default to 'M'. + + If unsure, say Y. + +config BRIDGE_NETFILTER + tristate "Bridged IP/ARP packets filtering" + depends on BRIDGE + depends on NETFILTER && INET + depends on NETFILTER_ADVANCED + select NETFILTER_FAMILY_BRIDGE + select SKB_EXTENSIONS + help + Enabling this option will let arptables resp. iptables see bridged + ARP resp. IP traffic. If you want a bridging firewall, you probably + want this option enabled. + Enabling or disabling this option doesn't enable or disable + ebtables. + + If unsure, say N. + +source "net/netfilter/Kconfig" +source "net/ipv4/netfilter/Kconfig" +source "net/ipv6/netfilter/Kconfig" +source "net/bridge/netfilter/Kconfig" + +endif + +source "net/bpfilter/Kconfig" + +source "net/dccp/Kconfig" +source "net/sctp/Kconfig" +source "net/rds/Kconfig" +source "net/tipc/Kconfig" +source "net/atm/Kconfig" +source "net/l2tp/Kconfig" +source "net/802/Kconfig" +source "net/bridge/Kconfig" +source "net/dsa/Kconfig" +source "net/8021q/Kconfig" +source "net/llc/Kconfig" +source "drivers/net/appletalk/Kconfig" +source "net/x25/Kconfig" +source "net/lapb/Kconfig" +source "net/phonet/Kconfig" +source "net/6lowpan/Kconfig" +source "net/ieee802154/Kconfig" +source "net/mac802154/Kconfig" +source "net/sched/Kconfig" +source "net/dcb/Kconfig" +source "net/dns_resolver/Kconfig" +source "net/batman-adv/Kconfig" +source "net/openvswitch/Kconfig" +source "net/vmw_vsock/Kconfig" +source "net/netlink/Kconfig" +source "net/mpls/Kconfig" +source "net/nsh/Kconfig" +source "net/hsr/Kconfig" +source "net/switchdev/Kconfig" +source "net/l3mdev/Kconfig" +source "net/qrtr/Kconfig" +source "net/ncsi/Kconfig" + +config PCPU_DEV_REFCNT + bool "Use percpu variables to maintain network device refcount" + depends on SMP + default y + help + network device refcount are using per cpu variables if this option is set. + This can be forced to N to detect underflows (with a performance drop). + +config RPS + bool + depends on SMP && SYSFS + default y + +config RFS_ACCEL + bool + depends on RPS + select CPU_RMAP + default y + +config SOCK_RX_QUEUE_MAPPING + bool + +config XPS + bool + depends on SMP + select SOCK_RX_QUEUE_MAPPING + default y + +config HWBM + bool + +config CGROUP_NET_PRIO + bool "Network priority cgroup" + depends on CGROUPS + select SOCK_CGROUP_DATA + help + Cgroup subsystem for use in assigning processes to network priorities on + a per-interface basis. + +config CGROUP_NET_CLASSID + bool "Network classid cgroup" + depends on CGROUPS + select SOCK_CGROUP_DATA + help + Cgroup subsystem for use as general purpose socket classid marker that is + being used in cls_cgroup and for netfilter matching. 
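The classid described in the CGROUP_NET_CLASSID help above is set from user space by writing a 32-bit value into the cgroup's net_cls.classid file; cls_cgroup and the netfilter cgroup match then compare sockets against that value. A minimal sketch, not part of this patch, assuming a v1 net_cls controller mounted at /sys/fs/cgroup/net_cls with an illustrative child cgroup named "example":

	/* Tag all tasks placed in the "example" cgroup with classid 0x100001
	 * (major 0x0010, minor 0x0001), as consumed by cls_cgroup and the
	 * netfilter cgroup match. Paths are assumptions for the example. */
	#include <stdio.h>

	int main(void)
	{
		FILE *f = fopen("/sys/fs/cgroup/net_cls/example/net_cls.classid", "w");

		if (!f)
			return 1;
		fprintf(f, "0x100001\n");
		return fclose(f) ? 1 : 0;
	}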
+ +config NET_RX_BUSY_POLL + bool + default y if !PREEMPT_RT + +config BQL + bool + depends on SYSFS + select DQL + default y + +config BPF_STREAM_PARSER + bool "enable BPF STREAM_PARSER" + depends on INET + depends on BPF_SYSCALL + depends on CGROUP_BPF + select STREAM_PARSER + select NET_SOCK_MSG + help + Enabling this allows a TCP stream parser to be used with + BPF_MAP_TYPE_SOCKMAP. + +config NET_FLOW_LIMIT + bool + depends on RPS + default y + help + The network stack has to drop packets when a receive processing CPU's + backlog reaches netdev_max_backlog. If a few out of many active flows + generate the vast majority of load, drop their traffic earlier to + maintain capacity for the other flows. This feature provides servers + with many clients some protection against DoS by a single (spoofed) + flow that greatly exceeds average workload. + +menu "Network testing" + +config NET_PKTGEN + tristate "Packet Generator (USE WITH CAUTION)" + depends on INET && PROC_FS + help + This module will inject preconfigured packets, at a configurable + rate, out of a given interface. It is used for network interface + stress testing and performance analysis. If you don't understand + what was just said, you don't need it: say N. + + Documentation on how to use the packet generator can be found + at . + + To compile this code as a module, choose M here: the + module will be called pktgen. + +config NET_DROP_MONITOR + tristate "Network packet drop alerting service" + depends on INET && TRACEPOINTS + help + This feature provides an alerting service to userspace in the + event that packets are discarded in the network stack. Alerts + are broadcast via netlink socket to any listening user space + process. If you don't need network drop alerts, or if you are ok + just checking the various proc files and other utilities for + drop statistics, say N here. + +endmenu + +endmenu + +source "net/ax25/Kconfig" +source "net/can/Kconfig" +source "net/bluetooth/Kconfig" +source "net/rxrpc/Kconfig" +source "net/kcm/Kconfig" +source "net/strparser/Kconfig" +source "net/mctp/Kconfig" + +config FIB_RULES + bool + +menuconfig WIRELESS + bool "Wireless" + depends on !S390 + default y + +if WIRELESS + +source "net/wireless/Kconfig" +source "net/mac80211/Kconfig" + +endif # WIRELESS + +source "net/rfkill/Kconfig" +source "net/9p/Kconfig" +source "net/caif/Kconfig" +source "net/ceph/Kconfig" +source "net/nfc/Kconfig" +source "net/psample/Kconfig" +source "net/ife/Kconfig" + +config LWTUNNEL + bool "Network light weight tunnels" + help + This feature provides an infrastructure to support light weight + tunnels like mpls. There is no netdevice associated with a light + weight tunnel endpoint. Tunnel encapsulation parameters are stored + with light weight tunnel state associated with fib routes. + +config LWTUNNEL_BPF + bool "Execute BPF program as route nexthop action" + depends on LWTUNNEL && INET + default y if LWTUNNEL=y + help + Allows to run BPF programs as a nexthop action following a route + lookup for incoming and outgoing packets. + +config DST_CACHE + bool + default n + +config GRO_CELLS + bool + default n + +config SOCK_VALIDATE_XMIT + bool + +config NET_SELFTESTS + def_tristate PHYLIB + depends on PHYLIB && INET + +config NET_SOCK_MSG + bool + default n + help + The NET_SOCK_MSG provides a framework for plain sockets (e.g. TCP) or + ULPs (upper layer modules, e.g. TLS) to process L7 application data + with the help of BPF programs. 
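As a concrete illustration of the BPF_STREAM_PARSER and NET_SOCK_MSG help texts above, the sketch below shows the rough shape of a msg-verdict BPF program of the kind these options enable. It is not part of the patch, and the map and program names are made up; attached to the sockmap with BPF_SK_MSG_VERDICT, it runs on every sendmsg() from a socket in the map and simply passes the data through:

	/* Built with clang -target bpf; loaded and attached via libbpf. */
	#include <linux/bpf.h>
	#include <bpf/bpf_helpers.h>

	struct {
		__uint(type, BPF_MAP_TYPE_SOCKMAP);
		__uint(max_entries, 16);
		__type(key, __u32);
		__type(value, __u64);
	} sock_map SEC(".maps");

	SEC("sk_msg")
	int pass_all(struct sk_msg_md *msg)
	{
		/* SK_PASS forwards the message unchanged; SK_DROP would discard it. */
		return SK_PASS;
	}

	char _license[] SEC("license") = "GPL";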
+ +config NET_DEVLINK + bool + default n + +config PAGE_POOL + bool + +config PAGE_POOL_STATS + default n + bool "Page pool stats" + depends on PAGE_POOL + help + Enable page pool statistics to track page allocation and recycling + in page pools. This option incurs additional CPU cost in allocation + and recycle paths and additional memory cost to store the statistics. + These statistics are only available if this option is enabled and if + the driver using the page pool supports exporting this data. + + If unsure, say N. + +config FAILOVER + tristate "Generic failover module" + help + The failover module provides a generic interface for paravirtual + drivers to register a netdev and a set of ops with a failover + instance. The ops are used as event handlers that get called to + handle netdev register/unregister/link change/name change events + on slave pci ethernet devices with the same mac address as the + failover netdev. This enables paravirtual drivers to use a + VF as an accelerated low latency datapath. It also allows live + migration of VMs with direct attached VFs by failing over to the + paravirtual datapath when the VF is unplugged. + +config ETHTOOL_NETLINK + bool "Netlink interface for ethtool" + default y + help + An alternative userspace interface for ethtool based on generic + netlink. It provides better extensibility and some new features, + e.g. notification messages. + +config NETDEV_ADDR_LIST_TEST + tristate "Unit tests for device address list" + default KUNIT_ALL_TESTS + depends on KUNIT + +endif # if NET diff --git a/net/Kconfig.debug b/net/Kconfig.debug new file mode 100644 index 000000000..5e3fffe70 --- /dev/null +++ b/net/Kconfig.debug @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: GPL-2.0-only + +config NET_DEV_REFCNT_TRACKER + bool "Enable net device refcount tracking" + depends on DEBUG_KERNEL && STACKTRACE_SUPPORT && NET + select REF_TRACKER + default n + help + Enable debugging feature to track device references. + This adds memory and cpu costs. + +config NET_NS_REFCNT_TRACKER + bool "Enable networking namespace refcount tracking" + depends on DEBUG_KERNEL && STACKTRACE_SUPPORT && NET + select REF_TRACKER + default n + help + Enable debugging feature to track netns references. + This adds memory and cpu costs. + +config DEBUG_NET + bool "Add generic networking debug" + depends on DEBUG_KERNEL && NET + help + Enable extra sanity checks in networking. + This is mostly used by fuzzers, but is safe to select. diff --git a/net/Makefile b/net/Makefile new file mode 100644 index 000000000..0914bea9c --- /dev/null +++ b/net/Makefile @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the linux networking. +# +# 2 Sep 2000, Christoph Hellwig +# Rewritten to use lists instead of if-statements. 
+# + +obj-y := devres.o socket.o core/ + +obj-$(CONFIG_COMPAT) += compat.o + +# LLC has to be linked before the files in net/802/ +obj-$(CONFIG_LLC) += llc/ +obj-y += ethernet/ 802/ sched/ netlink/ bpf/ ethtool/ +obj-$(CONFIG_NETFILTER) += netfilter/ +obj-$(CONFIG_INET) += ipv4/ +obj-$(CONFIG_TLS) += tls/ +obj-$(CONFIG_XFRM) += xfrm/ +obj-$(CONFIG_UNIX_SCM) += unix/ +obj-y += ipv6/ +obj-$(CONFIG_BPFILTER) += bpfilter/ +obj-$(CONFIG_PACKET) += packet/ +obj-$(CONFIG_NET_KEY) += key/ +obj-$(CONFIG_BRIDGE) += bridge/ +obj-$(CONFIG_NET_DEVLINK) += devlink/ +obj-$(CONFIG_NET_DSA) += dsa/ +obj-$(CONFIG_ATALK) += appletalk/ +obj-$(CONFIG_X25) += x25/ +obj-$(CONFIG_LAPB) += lapb/ +obj-$(CONFIG_NETROM) += netrom/ +obj-$(CONFIG_ROSE) += rose/ +obj-$(CONFIG_AX25) += ax25/ +obj-$(CONFIG_CAN) += can/ +obj-$(CONFIG_BT) += bluetooth/ +obj-$(CONFIG_SUNRPC) += sunrpc/ +obj-$(CONFIG_AF_RXRPC) += rxrpc/ +obj-$(CONFIG_AF_KCM) += kcm/ +obj-$(CONFIG_STREAM_PARSER) += strparser/ +obj-$(CONFIG_ATM) += atm/ +obj-$(CONFIG_L2TP) += l2tp/ +obj-$(CONFIG_PHONET) += phonet/ +ifneq ($(CONFIG_VLAN_8021Q),) +obj-y += 8021q/ +endif +obj-$(CONFIG_IP_DCCP) += dccp/ +obj-$(CONFIG_IP_SCTP) += sctp/ +obj-$(CONFIG_RDS) += rds/ +obj-$(CONFIG_WIRELESS) += wireless/ +obj-$(CONFIG_MAC80211) += mac80211/ +obj-$(CONFIG_TIPC) += tipc/ +obj-$(CONFIG_NETLABEL) += netlabel/ +obj-$(CONFIG_IUCV) += iucv/ +obj-$(CONFIG_SMC) += smc/ +obj-$(CONFIG_RFKILL) += rfkill/ +obj-$(CONFIG_NET_9P) += 9p/ +obj-$(CONFIG_CAIF) += caif/ +obj-$(CONFIG_DCB) += dcb/ +obj-$(CONFIG_6LOWPAN) += 6lowpan/ +obj-$(CONFIG_IEEE802154) += ieee802154/ +obj-$(CONFIG_MAC802154) += mac802154/ + +obj-$(CONFIG_SYSCTL) += sysctl_net.o +obj-$(CONFIG_DNS_RESOLVER) += dns_resolver/ +obj-$(CONFIG_CEPH_LIB) += ceph/ +obj-$(CONFIG_BATMAN_ADV) += batman-adv/ +obj-$(CONFIG_NFC) += nfc/ +obj-$(CONFIG_PSAMPLE) += psample/ +obj-$(CONFIG_NET_IFE) += ife/ +obj-$(CONFIG_OPENVSWITCH) += openvswitch/ +obj-$(CONFIG_VSOCKETS) += vmw_vsock/ +obj-$(CONFIG_MPLS) += mpls/ +obj-$(CONFIG_NET_NSH) += nsh/ +obj-$(CONFIG_HSR) += hsr/ +obj-$(CONFIG_NET_SWITCHDEV) += switchdev/ +obj-$(CONFIG_NET_L3_MASTER_DEV) += l3mdev/ +obj-$(CONFIG_QRTR) += qrtr/ +obj-$(CONFIG_NET_NCSI) += ncsi/ +obj-$(CONFIG_XDP_SOCKETS) += xdp/ +obj-$(CONFIG_MPTCP) += mptcp/ +obj-$(CONFIG_MCTP) += mctp/ diff --git a/net/appletalk/Makefile b/net/appletalk/Makefile new file mode 100644 index 000000000..33164d972 --- /dev/null +++ b/net/appletalk/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux AppleTalk layer. +# + +obj-$(CONFIG_ATALK) += appletalk.o + +appletalk-y := aarp.o ddp.o dev.o +appletalk-$(CONFIG_PROC_FS) += atalk_proc.o +appletalk-$(CONFIG_SYSCTL) += sysctl_net_atalk.o diff --git a/net/appletalk/aarp.c b/net/appletalk/aarp.c new file mode 100644 index 000000000..c7236daa2 --- /dev/null +++ b/net/appletalk/aarp.c @@ -0,0 +1,1049 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * AARP: An implementation of the AppleTalk AARP protocol for + * Ethernet 'ELAP'. + * + * Alan Cox + * + * This doesn't fit cleanly with the IP arp. Potentially we can use + * the generic neighbour discovery code to clean this up. + * + * FIXME: + * We ought to handle the retransmits with a single list and a + * separate fast timer for when it is needed. + * Use neighbour discovery code. + * Token Ring Support. + * + * References: + * Inside AppleTalk (2nd Ed). + * Fixes: + * Jaume Grau - flush caches on AARP_PROBE + * Rob Newberry - Added proxy AARP and AARP proc fs, + * moved probing from DDP module. 
+ * Arnaldo C. Melo - don't mangle rx packets + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int sysctl_aarp_expiry_time = AARP_EXPIRY_TIME; +int sysctl_aarp_tick_time = AARP_TICK_TIME; +int sysctl_aarp_retransmit_limit = AARP_RETRANSMIT_LIMIT; +int sysctl_aarp_resolve_time = AARP_RESOLVE_TIME; + +/* Lists of aarp entries */ +/** + * struct aarp_entry - AARP entry + * @last_sent: Last time we xmitted the aarp request + * @packet_queue: Queue of frames wait for resolution + * @status: Used for proxy AARP + * @expires_at: Entry expiry time + * @target_addr: DDP Address + * @dev: Device to use + * @hwaddr: Physical i/f address of target/router + * @xmit_count: When this hits 10 we give up + * @next: Next entry in chain + */ +struct aarp_entry { + /* These first two are only used for unresolved entries */ + unsigned long last_sent; + struct sk_buff_head packet_queue; + int status; + unsigned long expires_at; + struct atalk_addr target_addr; + struct net_device *dev; + char hwaddr[ETH_ALEN]; + unsigned short xmit_count; + struct aarp_entry *next; +}; + +/* Hashed list of resolved, unresolved and proxy entries */ +static struct aarp_entry *resolved[AARP_HASH_SIZE]; +static struct aarp_entry *unresolved[AARP_HASH_SIZE]; +static struct aarp_entry *proxies[AARP_HASH_SIZE]; +static int unresolved_count; + +/* One lock protects it all. */ +static DEFINE_RWLOCK(aarp_lock); + +/* Used to walk the list and purge/kick entries. */ +static struct timer_list aarp_timer; + +/* + * Delete an aarp queue + * + * Must run under aarp_lock. + */ +static void __aarp_expire(struct aarp_entry *a) +{ + skb_queue_purge(&a->packet_queue); + kfree(a); +} + +/* + * Send an aarp queue entry request + * + * Must run under aarp_lock. + */ +static void __aarp_send_query(struct aarp_entry *a) +{ + static unsigned char aarp_eth_multicast[ETH_ALEN] = + { 0x09, 0x00, 0x07, 0xFF, 0xFF, 0xFF }; + struct net_device *dev = a->dev; + struct elapaarp *eah; + int len = dev->hard_header_len + sizeof(*eah) + aarp_dl->header_length; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + struct atalk_addr *sat = atalk_find_dev_addr(dev); + + if (!skb) + return; + + if (!sat) { + kfree_skb(skb); + return; + } + + /* Set up the buffer */ + skb_reserve(skb, dev->hard_header_len + aarp_dl->header_length); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_put(skb, sizeof(*eah)); + skb->protocol = htons(ETH_P_ATALK); + skb->dev = dev; + eah = aarp_hdr(skb); + + /* Set up the ARP */ + eah->hw_type = htons(AARP_HW_TYPE_ETHERNET); + eah->pa_type = htons(ETH_P_ATALK); + eah->hw_len = ETH_ALEN; + eah->pa_len = AARP_PA_ALEN; + eah->function = htons(AARP_REQUEST); + + ether_addr_copy(eah->hw_src, dev->dev_addr); + + eah->pa_src_zero = 0; + eah->pa_src_net = sat->s_net; + eah->pa_src_node = sat->s_node; + + eth_zero_addr(eah->hw_dst); + + eah->pa_dst_zero = 0; + eah->pa_dst_net = a->target_addr.s_net; + eah->pa_dst_node = a->target_addr.s_node; + + /* Send it */ + aarp_dl->request(aarp_dl, skb, aarp_eth_multicast); + /* Update the sending count */ + a->xmit_count++; + a->last_sent = jiffies; +} + +/* This runs under aarp_lock and in softint context, so only atomic memory + * allocations can be used. 
*/ +static void aarp_send_reply(struct net_device *dev, struct atalk_addr *us, + struct atalk_addr *them, unsigned char *sha) +{ + struct elapaarp *eah; + int len = dev->hard_header_len + sizeof(*eah) + aarp_dl->header_length; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + + if (!skb) + return; + + /* Set up the buffer */ + skb_reserve(skb, dev->hard_header_len + aarp_dl->header_length); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_put(skb, sizeof(*eah)); + skb->protocol = htons(ETH_P_ATALK); + skb->dev = dev; + eah = aarp_hdr(skb); + + /* Set up the ARP */ + eah->hw_type = htons(AARP_HW_TYPE_ETHERNET); + eah->pa_type = htons(ETH_P_ATALK); + eah->hw_len = ETH_ALEN; + eah->pa_len = AARP_PA_ALEN; + eah->function = htons(AARP_REPLY); + + ether_addr_copy(eah->hw_src, dev->dev_addr); + + eah->pa_src_zero = 0; + eah->pa_src_net = us->s_net; + eah->pa_src_node = us->s_node; + + if (!sha) + eth_zero_addr(eah->hw_dst); + else + ether_addr_copy(eah->hw_dst, sha); + + eah->pa_dst_zero = 0; + eah->pa_dst_net = them->s_net; + eah->pa_dst_node = them->s_node; + + /* Send it */ + aarp_dl->request(aarp_dl, skb, sha); +} + +/* + * Send probe frames. Called from aarp_probe_network and + * aarp_proxy_probe_network. + */ + +static void aarp_send_probe(struct net_device *dev, struct atalk_addr *us) +{ + struct elapaarp *eah; + int len = dev->hard_header_len + sizeof(*eah) + aarp_dl->header_length; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + static unsigned char aarp_eth_multicast[ETH_ALEN] = + { 0x09, 0x00, 0x07, 0xFF, 0xFF, 0xFF }; + + if (!skb) + return; + + /* Set up the buffer */ + skb_reserve(skb, dev->hard_header_len + aarp_dl->header_length); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_put(skb, sizeof(*eah)); + skb->protocol = htons(ETH_P_ATALK); + skb->dev = dev; + eah = aarp_hdr(skb); + + /* Set up the ARP */ + eah->hw_type = htons(AARP_HW_TYPE_ETHERNET); + eah->pa_type = htons(ETH_P_ATALK); + eah->hw_len = ETH_ALEN; + eah->pa_len = AARP_PA_ALEN; + eah->function = htons(AARP_PROBE); + + ether_addr_copy(eah->hw_src, dev->dev_addr); + + eah->pa_src_zero = 0; + eah->pa_src_net = us->s_net; + eah->pa_src_node = us->s_node; + + eth_zero_addr(eah->hw_dst); + + eah->pa_dst_zero = 0; + eah->pa_dst_net = us->s_net; + eah->pa_dst_node = us->s_node; + + /* Send it */ + aarp_dl->request(aarp_dl, skb, aarp_eth_multicast); +} + +/* + * Handle an aarp timer expire + * + * Must run under the aarp_lock. + */ + +static void __aarp_expire_timer(struct aarp_entry **n) +{ + struct aarp_entry *t; + + while (*n) + /* Expired ? */ + if (time_after(jiffies, (*n)->expires_at)) { + t = *n; + *n = (*n)->next; + __aarp_expire(t); + } else + n = &((*n)->next); +} + +/* + * Kick all pending requests 5 times a second. + * + * Must run under the aarp_lock. + */ +static void __aarp_kick(struct aarp_entry **n) +{ + struct aarp_entry *t; + + while (*n) + /* Expired: if this will be the 11th tx, we delete instead. */ + if ((*n)->xmit_count >= sysctl_aarp_retransmit_limit) { + t = *n; + *n = (*n)->next; + __aarp_expire(t); + } else { + __aarp_send_query(*n); + n = &((*n)->next); + } +} + +/* + * A device has gone down. Take all entries referring to the device + * and remove them. + * + * Must run under the aarp_lock. 
+ */ +static void __aarp_expire_device(struct aarp_entry **n, struct net_device *dev) +{ + struct aarp_entry *t; + + while (*n) + if ((*n)->dev == dev) { + t = *n; + *n = (*n)->next; + __aarp_expire(t); + } else + n = &((*n)->next); +} + +/* Handle the timer event */ +static void aarp_expire_timeout(struct timer_list *unused) +{ + int ct; + + write_lock_bh(&aarp_lock); + + for (ct = 0; ct < AARP_HASH_SIZE; ct++) { + __aarp_expire_timer(&resolved[ct]); + __aarp_kick(&unresolved[ct]); + __aarp_expire_timer(&unresolved[ct]); + __aarp_expire_timer(&proxies[ct]); + } + + write_unlock_bh(&aarp_lock); + mod_timer(&aarp_timer, jiffies + + (unresolved_count ? sysctl_aarp_tick_time : + sysctl_aarp_expiry_time)); +} + +/* Network device notifier chain handler. */ +static int aarp_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + int ct; + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (event == NETDEV_DOWN) { + write_lock_bh(&aarp_lock); + + for (ct = 0; ct < AARP_HASH_SIZE; ct++) { + __aarp_expire_device(&resolved[ct], dev); + __aarp_expire_device(&unresolved[ct], dev); + __aarp_expire_device(&proxies[ct], dev); + } + + write_unlock_bh(&aarp_lock); + } + return NOTIFY_DONE; +} + +/* Expire all entries in a hash chain */ +static void __aarp_expire_all(struct aarp_entry **n) +{ + struct aarp_entry *t; + + while (*n) { + t = *n; + *n = (*n)->next; + __aarp_expire(t); + } +} + +/* Cleanup all hash chains -- module unloading */ +static void aarp_purge(void) +{ + int ct; + + write_lock_bh(&aarp_lock); + for (ct = 0; ct < AARP_HASH_SIZE; ct++) { + __aarp_expire_all(&resolved[ct]); + __aarp_expire_all(&unresolved[ct]); + __aarp_expire_all(&proxies[ct]); + } + write_unlock_bh(&aarp_lock); +} + +/* + * Create a new aarp entry. This must use GFP_ATOMIC because it + * runs while holding spinlocks. + */ +static struct aarp_entry *aarp_alloc(void) +{ + struct aarp_entry *a = kmalloc(sizeof(*a), GFP_ATOMIC); + + if (a) + skb_queue_head_init(&a->packet_queue); + return a; +} + +/* + * Find an entry. We might return an expired but not yet purged entry. We + * don't care as it will do no harm. + * + * This must run under the aarp_lock. + */ +static struct aarp_entry *__aarp_find_entry(struct aarp_entry *list, + struct net_device *dev, + struct atalk_addr *sat) +{ + while (list) { + if (list->target_addr.s_net == sat->s_net && + list->target_addr.s_node == sat->s_node && + list->dev == dev) + break; + list = list->next; + } + + return list; +} + +/* Called from the DDP code, and thus must be exported. */ +void aarp_proxy_remove(struct net_device *dev, struct atalk_addr *sa) +{ + int hash = sa->s_node % (AARP_HASH_SIZE - 1); + struct aarp_entry *a; + + write_lock_bh(&aarp_lock); + + a = __aarp_find_entry(proxies[hash], dev, sa); + if (a) + a->expires_at = jiffies - 1; + + write_unlock_bh(&aarp_lock); +} + +/* This must run under aarp_lock. */ +static struct atalk_addr *__aarp_proxy_find(struct net_device *dev, + struct atalk_addr *sa) +{ + int hash = sa->s_node % (AARP_HASH_SIZE - 1); + struct aarp_entry *a = __aarp_find_entry(proxies[hash], dev, sa); + + return a ? sa : NULL; +} + +/* + * Probe a Phase 1 device or a device that requires its Net:Node to + * be set via an ioctl. 
+ */ +static void aarp_send_probe_phase1(struct atalk_iface *iface) +{ + struct ifreq atreq; + struct sockaddr_at *sa = (struct sockaddr_at *)&atreq.ifr_addr; + const struct net_device_ops *ops = iface->dev->netdev_ops; + + sa->sat_addr.s_node = iface->address.s_node; + sa->sat_addr.s_net = ntohs(iface->address.s_net); + + /* We pass the Net:Node to the drivers/cards by a Device ioctl. */ + if (!(ops->ndo_do_ioctl(iface->dev, &atreq, SIOCSIFADDR))) { + ops->ndo_do_ioctl(iface->dev, &atreq, SIOCGIFADDR); + if (iface->address.s_net != htons(sa->sat_addr.s_net) || + iface->address.s_node != sa->sat_addr.s_node) + iface->status |= ATIF_PROBE_FAIL; + + iface->address.s_net = htons(sa->sat_addr.s_net); + iface->address.s_node = sa->sat_addr.s_node; + } +} + + +void aarp_probe_network(struct atalk_iface *atif) +{ + if (atif->dev->type == ARPHRD_LOCALTLK || + atif->dev->type == ARPHRD_PPP) + aarp_send_probe_phase1(atif); + else { + unsigned int count; + + for (count = 0; count < AARP_RETRANSMIT_LIMIT; count++) { + aarp_send_probe(atif->dev, &atif->address); + + /* Defer 1/10th */ + msleep(100); + + if (atif->status & ATIF_PROBE_FAIL) + break; + } + } +} + +int aarp_proxy_probe_network(struct atalk_iface *atif, struct atalk_addr *sa) +{ + int hash, retval = -EPROTONOSUPPORT; + struct aarp_entry *entry; + unsigned int count; + + /* + * we don't currently support LocalTalk or PPP for proxy AARP; + * if someone wants to try and add it, have fun + */ + if (atif->dev->type == ARPHRD_LOCALTLK || + atif->dev->type == ARPHRD_PPP) + goto out; + + /* + * create a new AARP entry with the flags set to be published -- + * we need this one to hang around even if it's in use + */ + entry = aarp_alloc(); + retval = -ENOMEM; + if (!entry) + goto out; + + entry->expires_at = -1; + entry->status = ATIF_PROBE; + entry->target_addr.s_node = sa->s_node; + entry->target_addr.s_net = sa->s_net; + entry->dev = atif->dev; + + write_lock_bh(&aarp_lock); + + hash = sa->s_node % (AARP_HASH_SIZE - 1); + entry->next = proxies[hash]; + proxies[hash] = entry; + + for (count = 0; count < AARP_RETRANSMIT_LIMIT; count++) { + aarp_send_probe(atif->dev, sa); + + /* Defer 1/10th */ + write_unlock_bh(&aarp_lock); + msleep(100); + write_lock_bh(&aarp_lock); + + if (entry->status & ATIF_PROBE_FAIL) + break; + } + + if (entry->status & ATIF_PROBE_FAIL) { + entry->expires_at = jiffies - 1; /* free the entry */ + retval = -EADDRINUSE; /* return network full */ + } else { /* clear the probing flag */ + entry->status &= ~ATIF_PROBE; + retval = 1; + } + + write_unlock_bh(&aarp_lock); +out: + return retval; +} + +/* Send a DDP frame */ +int aarp_send_ddp(struct net_device *dev, struct sk_buff *skb, + struct atalk_addr *sa, void *hwaddr) +{ + static char ddp_eth_multicast[ETH_ALEN] = + { 0x09, 0x00, 0x07, 0xFF, 0xFF, 0xFF }; + int hash; + struct aarp_entry *a; + + skb_reset_network_header(skb); + + /* Check for LocalTalk first */ + if (dev->type == ARPHRD_LOCALTLK) { + struct atalk_addr *at = atalk_find_dev_addr(dev); + struct ddpehdr *ddp = (struct ddpehdr *)skb->data; + int ft = 2; + + /* + * Compressible ? + * + * IFF: src_net == dest_net == device_net + * (zero matches anything) + */ + + if ((!ddp->deh_snet || at->s_net == ddp->deh_snet) && + (!ddp->deh_dnet || at->s_net == ddp->deh_dnet)) { + skb_pull(skb, sizeof(*ddp) - 4); + + /* + * The upper two remaining bytes are the port + * numbers we just happen to need. Now put the + * length in the lower two. + */ + *((__be16 *)skb->data) = htons(skb->len); + ft = 1; + } + /* + * Nice and easy. 
No AARP type protocols occur here so we can + * just shovel it out with a 3 byte LLAP header + */ + + skb_push(skb, 3); + skb->data[0] = sa->s_node; + skb->data[1] = at->s_node; + skb->data[2] = ft; + skb->dev = dev; + goto sendit; + } + + /* On a PPP link we neither compress nor aarp. */ + if (dev->type == ARPHRD_PPP) { + skb->protocol = htons(ETH_P_PPPTALK); + skb->dev = dev; + goto sendit; + } + + /* Non ELAP we cannot do. */ + if (dev->type != ARPHRD_ETHER) + goto free_it; + + skb->dev = dev; + skb->protocol = htons(ETH_P_ATALK); + hash = sa->s_node % (AARP_HASH_SIZE - 1); + + /* Do we have a resolved entry? */ + if (sa->s_node == ATADDR_BCAST) { + /* Send it */ + ddp_dl->request(ddp_dl, skb, ddp_eth_multicast); + goto sent; + } + + write_lock_bh(&aarp_lock); + a = __aarp_find_entry(resolved[hash], dev, sa); + + if (a) { /* Return 1 and fill in the address */ + a->expires_at = jiffies + (sysctl_aarp_expiry_time * 10); + ddp_dl->request(ddp_dl, skb, a->hwaddr); + write_unlock_bh(&aarp_lock); + goto sent; + } + + /* Do we have an unresolved entry: This is the less common path */ + a = __aarp_find_entry(unresolved[hash], dev, sa); + if (a) { /* Queue onto the unresolved queue */ + skb_queue_tail(&a->packet_queue, skb); + goto out_unlock; + } + + /* Allocate a new entry */ + a = aarp_alloc(); + if (!a) { + /* Whoops slipped... good job it's an unreliable protocol 8) */ + write_unlock_bh(&aarp_lock); + goto free_it; + } + + /* Set up the queue */ + skb_queue_tail(&a->packet_queue, skb); + a->expires_at = jiffies + sysctl_aarp_resolve_time; + a->dev = dev; + a->next = unresolved[hash]; + a->target_addr = *sa; + a->xmit_count = 0; + unresolved[hash] = a; + unresolved_count++; + + /* Send an initial request for the address */ + __aarp_send_query(a); + + /* + * Switch to fast timer if needed (That is if this is the first + * unresolved entry to get added) + */ + + if (unresolved_count == 1) + mod_timer(&aarp_timer, jiffies + sysctl_aarp_tick_time); + + /* Now finally, it is safe to drop the lock. */ +out_unlock: + write_unlock_bh(&aarp_lock); + + /* Tell the ddp layer we have taken over for this frame. */ + goto sent; + +sendit: + if (skb->sk) + skb->priority = skb->sk->sk_priority; + if (dev_queue_xmit(skb)) + goto drop; +sent: + return NET_XMIT_SUCCESS; +free_it: + kfree_skb(skb); +drop: + return NET_XMIT_DROP; +} +EXPORT_SYMBOL(aarp_send_ddp); + +/* + * An entry in the aarp unresolved queue has become resolved. Send + * all the frames queued under it. + * + * Must run under aarp_lock. + */ +static void __aarp_resolved(struct aarp_entry **list, struct aarp_entry *a, + int hash) +{ + struct sk_buff *skb; + + while (*list) + if (*list == a) { + unresolved_count--; + *list = a->next; + + /* Move into the resolved list */ + a->next = resolved[hash]; + resolved[hash] = a; + + /* Kick frames off */ + while ((skb = skb_dequeue(&a->packet_queue)) != NULL) { + a->expires_at = jiffies + + sysctl_aarp_expiry_time * 10; + ddp_dl->request(ddp_dl, skb, a->hwaddr); + } + } else + list = &((*list)->next); +} + +/* + * This is called by the SNAP driver whenever we see an AARP SNAP + * frame. We currently only support Ethernet. + */ +static int aarp_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct elapaarp *ea = aarp_hdr(skb); + int hash, ret = 0; + __u16 function; + struct aarp_entry *a; + struct atalk_addr sa, *ma, da; + struct atalk_iface *ifa; + + if (!net_eq(dev_net(dev), &init_net)) + goto out0; + + /* We only do Ethernet SNAP AARP. 
*/ + if (dev->type != ARPHRD_ETHER) + goto out0; + + /* Frame size ok? */ + if (!skb_pull(skb, sizeof(*ea))) + goto out0; + + function = ntohs(ea->function); + + /* Sanity check fields. */ + if (function < AARP_REQUEST || function > AARP_PROBE || + ea->hw_len != ETH_ALEN || ea->pa_len != AARP_PA_ALEN || + ea->pa_src_zero || ea->pa_dst_zero) + goto out0; + + /* Looks good. */ + hash = ea->pa_src_node % (AARP_HASH_SIZE - 1); + + /* Build an address. */ + sa.s_node = ea->pa_src_node; + sa.s_net = ea->pa_src_net; + + /* Process the packet. Check for replies of me. */ + ifa = atalk_find_dev(dev); + if (!ifa) + goto out1; + + if (ifa->status & ATIF_PROBE && + ifa->address.s_node == ea->pa_dst_node && + ifa->address.s_net == ea->pa_dst_net) { + ifa->status |= ATIF_PROBE_FAIL; /* Fail the probe (in use) */ + goto out1; + } + + /* Check for replies of proxy AARP entries */ + da.s_node = ea->pa_dst_node; + da.s_net = ea->pa_dst_net; + + write_lock_bh(&aarp_lock); + a = __aarp_find_entry(proxies[hash], dev, &da); + + if (a && a->status & ATIF_PROBE) { + a->status |= ATIF_PROBE_FAIL; + /* + * we do not respond to probe or request packets of + * this address while we are probing this address + */ + goto unlock; + } + + switch (function) { + case AARP_REPLY: + if (!unresolved_count) /* Speed up */ + break; + + /* Find the entry. */ + a = __aarp_find_entry(unresolved[hash], dev, &sa); + if (!a || dev != a->dev) + break; + + /* We can fill one in - this is good. */ + ether_addr_copy(a->hwaddr, ea->hw_src); + __aarp_resolved(&unresolved[hash], a, hash); + if (!unresolved_count) + mod_timer(&aarp_timer, + jiffies + sysctl_aarp_expiry_time); + break; + + case AARP_REQUEST: + case AARP_PROBE: + + /* + * If it is my address set ma to my address and reply. + * We can treat probe and request the same. Probe + * simply means we shouldn't cache the querying host, + * as in a probe they are proposing an address not + * using one. + * + * Support for proxy-AARP added. We check if the + * address is one of our proxies before we toss the + * packet out. + */ + + sa.s_node = ea->pa_dst_node; + sa.s_net = ea->pa_dst_net; + + /* See if we have a matching proxy. */ + ma = __aarp_proxy_find(dev, &sa); + if (!ma) + ma = &ifa->address; + else { /* We need to make a copy of the entry. */ + da.s_node = sa.s_node; + da.s_net = sa.s_net; + ma = &da; + } + + if (function == AARP_PROBE) { + /* + * A probe implies someone trying to get an + * address. So as a precaution flush any + * entries we have for this address. + */ + a = __aarp_find_entry(resolved[sa.s_node % + (AARP_HASH_SIZE - 1)], + skb->dev, &sa); + + /* + * Make it expire next tick - that avoids us + * getting into a probe/flush/learn/probe/ + * flush/learn cycle during probing of a slow + * to respond host addr. + */ + if (a) { + a->expires_at = jiffies - 1; + mod_timer(&aarp_timer, jiffies + + sysctl_aarp_tick_time); + } + } + + if (sa.s_node != ma->s_node) + break; + + if (sa.s_net && ma->s_net && sa.s_net != ma->s_net) + break; + + sa.s_node = ea->pa_src_node; + sa.s_net = ea->pa_src_net; + + /* aarp_my_address has found the address to use for us. 
+ */ + aarp_send_reply(dev, ma, &sa, ea->hw_src); + break; + } + +unlock: + write_unlock_bh(&aarp_lock); +out1: + ret = 1; +out0: + kfree_skb(skb); + return ret; +} + +static struct notifier_block aarp_notifier = { + .notifier_call = aarp_device_event, +}; + +static unsigned char aarp_snap_id[] = { 0x00, 0x00, 0x00, 0x80, 0xF3 }; + +int __init aarp_proto_init(void) +{ + int rc; + + aarp_dl = register_snap_client(aarp_snap_id, aarp_rcv); + if (!aarp_dl) { + printk(KERN_CRIT "Unable to register AARP with SNAP.\n"); + return -ENOMEM; + } + timer_setup(&aarp_timer, aarp_expire_timeout, 0); + aarp_timer.expires = jiffies + sysctl_aarp_expiry_time; + add_timer(&aarp_timer); + rc = register_netdevice_notifier(&aarp_notifier); + if (rc) { + del_timer_sync(&aarp_timer); + unregister_snap_client(aarp_dl); + } + return rc; +} + +/* Remove the AARP entries associated with a device. */ +void aarp_device_down(struct net_device *dev) +{ + int ct; + + write_lock_bh(&aarp_lock); + + for (ct = 0; ct < AARP_HASH_SIZE; ct++) { + __aarp_expire_device(&resolved[ct], dev); + __aarp_expire_device(&unresolved[ct], dev); + __aarp_expire_device(&proxies[ct], dev); + } + + write_unlock_bh(&aarp_lock); +} + +#ifdef CONFIG_PROC_FS +/* + * Get the aarp entry that is in the chain described + * by the iterator. + * If pos is set then skip till that index. + * pos = 1 is the first entry + */ +static struct aarp_entry *iter_next(struct aarp_iter_state *iter, loff_t *pos) +{ + int ct = iter->bucket; + struct aarp_entry **table = iter->table; + loff_t off = 0; + struct aarp_entry *entry; + + rescan: + while (ct < AARP_HASH_SIZE) { + for (entry = table[ct]; entry; entry = entry->next) { + if (!pos || ++off == *pos) { + iter->table = table; + iter->bucket = ct; + return entry; + } + } + ++ct; + } + + if (table == resolved) { + ct = 0; + table = unresolved; + goto rescan; + } + if (table == unresolved) { + ct = 0; + table = proxies; + goto rescan; + } + return NULL; +} + +static void *aarp_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(aarp_lock) +{ + struct aarp_iter_state *iter = seq->private; + + read_lock_bh(&aarp_lock); + iter->table = resolved; + iter->bucket = 0; + + return *pos ? iter_next(iter, pos) : SEQ_START_TOKEN; +} + +static void *aarp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct aarp_entry *entry = v; + struct aarp_iter_state *iter = seq->private; + + ++*pos; + + /* first line after header */ + if (v == SEQ_START_TOKEN) + entry = iter_next(iter, NULL); + + /* next entry in current bucket */ + else if (entry->next) + entry = entry->next; + + /* next bucket or table */ + else { + ++iter->bucket; + entry = iter_next(iter, NULL); + } + return entry; +} + +static void aarp_seq_stop(struct seq_file *seq, void *v) + __releases(aarp_lock) +{ + read_unlock_bh(&aarp_lock); +} + +static const char *dt2str(unsigned long ticks) +{ + static char buf[32]; + + sprintf(buf, "%ld.%02ld", ticks / HZ, ((ticks % HZ) * 100) / HZ); + + return buf; +} + +static int aarp_seq_show(struct seq_file *seq, void *v) +{ + struct aarp_iter_state *iter = seq->private; + struct aarp_entry *entry = v; + unsigned long now = jiffies; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "Address Interface Hardware Address" + " Expires LastSend Retry Status\n"); + else { + seq_printf(seq, "%04X:%02X %-12s", + ntohs(entry->target_addr.s_net), + (unsigned int) entry->target_addr.s_node, + entry->dev ? 
entry->dev->name : "????"); + seq_printf(seq, "%pM", entry->hwaddr); + seq_printf(seq, " %8s", + dt2str((long)entry->expires_at - (long)now)); + if (iter->table == unresolved) + seq_printf(seq, " %8s %6hu", + dt2str(now - entry->last_sent), + entry->xmit_count); + else + seq_puts(seq, " "); + seq_printf(seq, " %s\n", + (iter->table == resolved) ? "resolved" + : (iter->table == unresolved) ? "unresolved" + : (iter->table == proxies) ? "proxies" + : "unknown"); + } + return 0; +} + +const struct seq_operations aarp_seq_ops = { + .start = aarp_seq_start, + .next = aarp_seq_next, + .stop = aarp_seq_stop, + .show = aarp_seq_show, +}; +#endif + +/* General module cleanup. Called from cleanup_module() in ddp.c. */ +void aarp_cleanup_module(void) +{ + del_timer_sync(&aarp_timer); + unregister_netdevice_notifier(&aarp_notifier); + unregister_snap_client(aarp_dl); + aarp_purge(); +} diff --git a/net/appletalk/atalk_proc.c b/net/appletalk/atalk_proc.c new file mode 100644 index 000000000..9c1241292 --- /dev/null +++ b/net/appletalk/atalk_proc.c @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * atalk_proc.c - proc support for Appletalk + * + * Copyright(c) Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include +#include +#include + + +static __inline__ struct atalk_iface *atalk_get_interface_idx(loff_t pos) +{ + struct atalk_iface *i; + + for (i = atalk_interfaces; pos && i; i = i->next) + --pos; + + return i; +} + +static void *atalk_seq_interface_start(struct seq_file *seq, loff_t *pos) + __acquires(atalk_interfaces_lock) +{ + loff_t l = *pos; + + read_lock_bh(&atalk_interfaces_lock); + return l ? atalk_get_interface_idx(--l) : SEQ_START_TOKEN; +} + +static void *atalk_seq_interface_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct atalk_iface *i; + + ++*pos; + if (v == SEQ_START_TOKEN) { + i = NULL; + if (atalk_interfaces) + i = atalk_interfaces; + goto out; + } + i = v; + i = i->next; +out: + return i; +} + +static void atalk_seq_interface_stop(struct seq_file *seq, void *v) + __releases(atalk_interfaces_lock) +{ + read_unlock_bh(&atalk_interfaces_lock); +} + +static int atalk_seq_interface_show(struct seq_file *seq, void *v) +{ + struct atalk_iface *iface; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Interface Address Networks " + "Status\n"); + goto out; + } + + iface = v; + seq_printf(seq, "%-16s %04X:%02X %04X-%04X %d\n", + iface->dev->name, ntohs(iface->address.s_net), + iface->address.s_node, ntohs(iface->nets.nr_firstnet), + ntohs(iface->nets.nr_lastnet), iface->status); +out: + return 0; +} + +static __inline__ struct atalk_route *atalk_get_route_idx(loff_t pos) +{ + struct atalk_route *r; + + for (r = atalk_routes; pos && r; r = r->next) + --pos; + + return r; +} + +static void *atalk_seq_route_start(struct seq_file *seq, loff_t *pos) + __acquires(atalk_routes_lock) +{ + loff_t l = *pos; + + read_lock_bh(&atalk_routes_lock); + return l ? 
atalk_get_route_idx(--l) : SEQ_START_TOKEN; +} + +static void *atalk_seq_route_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct atalk_route *r; + + ++*pos; + if (v == SEQ_START_TOKEN) { + r = NULL; + if (atalk_routes) + r = atalk_routes; + goto out; + } + r = v; + r = r->next; +out: + return r; +} + +static void atalk_seq_route_stop(struct seq_file *seq, void *v) + __releases(atalk_routes_lock) +{ + read_unlock_bh(&atalk_routes_lock); +} + +static int atalk_seq_route_show(struct seq_file *seq, void *v) +{ + struct atalk_route *rt; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Target Router Flags Dev\n"); + goto out; + } + + if (atrtr_default.dev) { + rt = &atrtr_default; + seq_printf(seq, "Default %04X:%02X %-4d %s\n", + ntohs(rt->gateway.s_net), rt->gateway.s_node, + rt->flags, rt->dev->name); + } + + rt = v; + seq_printf(seq, "%04X:%02X %04X:%02X %-4d %s\n", + ntohs(rt->target.s_net), rt->target.s_node, + ntohs(rt->gateway.s_net), rt->gateway.s_node, + rt->flags, rt->dev->name); +out: + return 0; +} + +static void *atalk_seq_socket_start(struct seq_file *seq, loff_t *pos) + __acquires(atalk_sockets_lock) +{ + read_lock_bh(&atalk_sockets_lock); + return seq_hlist_start_head(&atalk_sockets, *pos); +} + +static void *atalk_seq_socket_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &atalk_sockets, pos); +} + +static void atalk_seq_socket_stop(struct seq_file *seq, void *v) + __releases(atalk_sockets_lock) +{ + read_unlock_bh(&atalk_sockets_lock); +} + +static int atalk_seq_socket_show(struct seq_file *seq, void *v) +{ + struct sock *s; + struct atalk_sock *at; + + if (v == SEQ_START_TOKEN) { + seq_printf(seq, "Type Local_addr Remote_addr Tx_queue " + "Rx_queue St UID\n"); + goto out; + } + + s = sk_entry(v); + at = at_sk(s); + + seq_printf(seq, "%02X %04X:%02X:%02X %04X:%02X:%02X %08X:%08X " + "%02X %u\n", + s->sk_type, ntohs(at->src_net), at->src_node, at->src_port, + ntohs(at->dest_net), at->dest_node, at->dest_port, + sk_wmem_alloc_get(s), + sk_rmem_alloc_get(s), + s->sk_state, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(s))); +out: + return 0; +} + +static const struct seq_operations atalk_seq_interface_ops = { + .start = atalk_seq_interface_start, + .next = atalk_seq_interface_next, + .stop = atalk_seq_interface_stop, + .show = atalk_seq_interface_show, +}; + +static const struct seq_operations atalk_seq_route_ops = { + .start = atalk_seq_route_start, + .next = atalk_seq_route_next, + .stop = atalk_seq_route_stop, + .show = atalk_seq_route_show, +}; + +static const struct seq_operations atalk_seq_socket_ops = { + .start = atalk_seq_socket_start, + .next = atalk_seq_socket_next, + .stop = atalk_seq_socket_stop, + .show = atalk_seq_socket_show, +}; + +int __init atalk_proc_init(void) +{ + if (!proc_mkdir("atalk", init_net.proc_net)) + return -ENOMEM; + + if (!proc_create_seq("atalk/interface", 0444, init_net.proc_net, + &atalk_seq_interface_ops)) + goto out; + + if (!proc_create_seq("atalk/route", 0444, init_net.proc_net, + &atalk_seq_route_ops)) + goto out; + + if (!proc_create_seq("atalk/socket", 0444, init_net.proc_net, + &atalk_seq_socket_ops)) + goto out; + + if (!proc_create_seq_private("atalk/arp", 0444, init_net.proc_net, + &aarp_seq_ops, + sizeof(struct aarp_iter_state), NULL)) + goto out; + + return 0; + +out: + remove_proc_subtree("atalk", init_net.proc_net); + return -ENOMEM; +} + +void atalk_proc_exit(void) +{ + remove_proc_subtree("atalk", init_net.proc_net); +} diff --git a/net/appletalk/ddp.c b/net/appletalk/ddp.c new file 
mode 100644 index 000000000..f67f14db1 --- /dev/null +++ b/net/appletalk/ddp.c @@ -0,0 +1,2040 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * DDP: An implementation of the AppleTalk DDP protocol for + * Ethernet 'ELAP'. + * + * Alan Cox + * + * With more than a little assistance from + * + * Wesley Craig + * + * Fixes: + * Neil Horman : Added missing device ioctls + * Michael Callahan : Made routing work + * Wesley Craig : Fix probing to listen to a + * passed node id. + * Alan Cox : Added send/recvmsg support + * Alan Cox : Moved at. to protinfo in + * socket. + * Alan Cox : Added firewall hooks. + * Alan Cox : Supports new ARPHRD_LOOPBACK + * Christer Weinigel : Routing and /proc fixes. + * Bradford Johnson : LocalTalk. + * Tom Dyas : Module support. + * Alan Cox : Hooks for PPP (based on the + * LocalTalk hook). + * Alan Cox : Posix bits + * Alan Cox/Mike Freeman : Possible fix to NBP problems + * Bradford Johnson : IP-over-DDP (experimental) + * Jay Schulist : Moved IP-over-DDP to its own + * driver file. (ipddp.c & ipddp.h) + * Jay Schulist : Made work as module with + * AppleTalk drivers, cleaned it. + * Rob Newberry : Added proxy AARP and AARP + * procfs, moved probing to AARP + * module. + * Adrian Sun/ + * Michael Zuelsdorff : fix for net.0 packets. don't + * allow illegal ether/tokentalk + * port assignment. we lose a + * valid localtalk port as a + * result. + * Arnaldo C. de Melo : Cleanup, in preparation for + * shared skb support 8) + * Arnaldo C. de Melo : Move proc stuff to atalk_proc.c, + * use seq_file + */ + +#include +#include +#include +#include /* For TIOCOUTQ/INQ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct datalink_proto *ddp_dl, *aarp_dl; +static const struct proto_ops atalk_dgram_ops; + +/**************************************************************************\ +* * +* Handlers for the socket list. * +* * +\**************************************************************************/ + +HLIST_HEAD(atalk_sockets); +DEFINE_RWLOCK(atalk_sockets_lock); + +static inline void __atalk_insert_socket(struct sock *sk) +{ + sk_add_node(sk, &atalk_sockets); +} + +static inline void atalk_remove_socket(struct sock *sk) +{ + write_lock_bh(&atalk_sockets_lock); + sk_del_node_init(sk); + write_unlock_bh(&atalk_sockets_lock); +} + +static struct sock *atalk_search_socket(struct sockaddr_at *to, + struct atalk_iface *atif) +{ + struct sock *s; + + read_lock_bh(&atalk_sockets_lock); + sk_for_each(s, &atalk_sockets) { + struct atalk_sock *at = at_sk(s); + + if (to->sat_port != at->src_port) + continue; + + if (to->sat_addr.s_net == ATADDR_ANYNET && + to->sat_addr.s_node == ATADDR_BCAST) + goto found; + + if (to->sat_addr.s_net == at->src_net && + (to->sat_addr.s_node == at->src_node || + to->sat_addr.s_node == ATADDR_BCAST || + to->sat_addr.s_node == ATADDR_ANYNODE)) + goto found; + + /* XXXX.0 -- we got a request for this router. make sure + * that the node is appropriately set. 
*/ + if (to->sat_addr.s_node == ATADDR_ANYNODE && + to->sat_addr.s_net != ATADDR_ANYNET && + atif->address.s_node == at->src_node) { + to->sat_addr.s_node = atif->address.s_node; + goto found; + } + } + s = NULL; +found: + read_unlock_bh(&atalk_sockets_lock); + return s; +} + +/** + * atalk_find_or_insert_socket - Try to find a socket matching ADDR + * @sk: socket to insert in the list if it is not there already + * @sat: address to search for + * + * Try to find a socket matching ADDR in the socket list, if found then return + * it. If not, insert SK into the socket list. + * + * This entire operation must execute atomically. + */ +static struct sock *atalk_find_or_insert_socket(struct sock *sk, + struct sockaddr_at *sat) +{ + struct sock *s; + struct atalk_sock *at; + + write_lock_bh(&atalk_sockets_lock); + sk_for_each(s, &atalk_sockets) { + at = at_sk(s); + + if (at->src_net == sat->sat_addr.s_net && + at->src_node == sat->sat_addr.s_node && + at->src_port == sat->sat_port) + goto found; + } + s = NULL; + __atalk_insert_socket(sk); /* Wheee, it's free, assign and insert. */ +found: + write_unlock_bh(&atalk_sockets_lock); + return s; +} + +static void atalk_destroy_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + if (sk_has_allocations(sk)) { + sk->sk_timer.expires = jiffies + SOCK_DESTROY_TIME; + add_timer(&sk->sk_timer); + } else + sock_put(sk); +} + +static inline void atalk_destroy_socket(struct sock *sk) +{ + atalk_remove_socket(sk); + skb_queue_purge(&sk->sk_receive_queue); + + if (sk_has_allocations(sk)) { + timer_setup(&sk->sk_timer, atalk_destroy_timer, 0); + sk->sk_timer.expires = jiffies + SOCK_DESTROY_TIME; + add_timer(&sk->sk_timer); + } else + sock_put(sk); +} + +/**************************************************************************\ +* * +* Routing tables for the AppleTalk socket layer. * +* * +\**************************************************************************/ + +/* Anti-deadlock ordering is atalk_routes_lock --> iface_lock -DaveM */ +struct atalk_route *atalk_routes; +DEFINE_RWLOCK(atalk_routes_lock); + +struct atalk_iface *atalk_interfaces; +DEFINE_RWLOCK(atalk_interfaces_lock); + +/* For probing devices or in a routerless network */ +struct atalk_route atrtr_default; + +/* AppleTalk interface control */ +/* + * Drop a device. Doesn't drop any of its routes - that is the caller's + * problem. Called when we down the interface or delete the address. 
+ */ +static void atif_drop_device(struct net_device *dev) +{ + struct atalk_iface **iface = &atalk_interfaces; + struct atalk_iface *tmp; + + write_lock_bh(&atalk_interfaces_lock); + while ((tmp = *iface) != NULL) { + if (tmp->dev == dev) { + *iface = tmp->next; + dev_put(dev); + kfree(tmp); + dev->atalk_ptr = NULL; + } else + iface = &tmp->next; + } + write_unlock_bh(&atalk_interfaces_lock); +} + +static struct atalk_iface *atif_add_device(struct net_device *dev, + struct atalk_addr *sa) +{ + struct atalk_iface *iface = kzalloc(sizeof(*iface), GFP_KERNEL); + + if (!iface) + goto out; + + dev_hold(dev); + iface->dev = dev; + dev->atalk_ptr = iface; + iface->address = *sa; + iface->status = 0; + + write_lock_bh(&atalk_interfaces_lock); + iface->next = atalk_interfaces; + atalk_interfaces = iface; + write_unlock_bh(&atalk_interfaces_lock); +out: + return iface; +} + +/* Perform phase 2 AARP probing on our tentative address */ +static int atif_probe_device(struct atalk_iface *atif) +{ + int netrange = ntohs(atif->nets.nr_lastnet) - + ntohs(atif->nets.nr_firstnet) + 1; + int probe_net = ntohs(atif->address.s_net); + int probe_node = atif->address.s_node; + int netct, nodect; + + /* Offset the network we start probing with */ + if (probe_net == ATADDR_ANYNET) { + probe_net = ntohs(atif->nets.nr_firstnet); + if (netrange) + probe_net += jiffies % netrange; + } + if (probe_node == ATADDR_ANYNODE) + probe_node = jiffies & 0xFF; + + /* Scan the networks */ + atif->status |= ATIF_PROBE; + for (netct = 0; netct <= netrange; netct++) { + /* Sweep the available nodes from a given start */ + atif->address.s_net = htons(probe_net); + for (nodect = 0; nodect < 256; nodect++) { + atif->address.s_node = (nodect + probe_node) & 0xFF; + if (atif->address.s_node > 0 && + atif->address.s_node < 254) { + /* Probe a proposed address */ + aarp_probe_network(atif); + + if (!(atif->status & ATIF_PROBE_FAIL)) { + atif->status &= ~ATIF_PROBE; + return 0; + } + } + atif->status &= ~ATIF_PROBE_FAIL; + } + probe_net++; + if (probe_net > ntohs(atif->nets.nr_lastnet)) + probe_net = ntohs(atif->nets.nr_firstnet); + } + atif->status &= ~ATIF_PROBE; + + return -EADDRINUSE; /* Network is full... */ +} + + +/* Perform AARP probing for a proxy address */ +static int atif_proxy_probe_device(struct atalk_iface *atif, + struct atalk_addr *proxy_addr) +{ + int netrange = ntohs(atif->nets.nr_lastnet) - + ntohs(atif->nets.nr_firstnet) + 1; + /* we probe the interface's network */ + int probe_net = ntohs(atif->address.s_net); + int probe_node = ATADDR_ANYNODE; /* we'll take anything */ + int netct, nodect; + + /* Offset the network we start probing with */ + if (probe_net == ATADDR_ANYNET) { + probe_net = ntohs(atif->nets.nr_firstnet); + if (netrange) + probe_net += jiffies % netrange; + } + + if (probe_node == ATADDR_ANYNODE) + probe_node = jiffies & 0xFF; + + /* Scan the networks */ + for (netct = 0; netct <= netrange; netct++) { + /* Sweep the available nodes from a given start */ + proxy_addr->s_net = htons(probe_net); + for (nodect = 0; nodect < 256; nodect++) { + proxy_addr->s_node = (nodect + probe_node) & 0xFF; + if (proxy_addr->s_node > 0 && + proxy_addr->s_node < 254) { + /* Tell AARP to probe a proposed address */ + int ret = aarp_proxy_probe_network(atif, + proxy_addr); + + if (ret != -EADDRINUSE) + return ret; + } + } + probe_net++; + if (probe_net > ntohs(atif->nets.nr_lastnet)) + probe_net = ntohs(atif->nets.nr_firstnet); + } + + return -EADDRINUSE; /* Network is full... 
*/ +} + + +struct atalk_addr *atalk_find_dev_addr(struct net_device *dev) +{ + struct atalk_iface *iface = dev->atalk_ptr; + return iface ? &iface->address : NULL; +} + +static struct atalk_addr *atalk_find_primary(void) +{ + struct atalk_iface *fiface = NULL; + struct atalk_addr *retval; + struct atalk_iface *iface; + + /* + * Return a point-to-point interface only if + * there is no non-ptp interface available. + */ + read_lock_bh(&atalk_interfaces_lock); + for (iface = atalk_interfaces; iface; iface = iface->next) { + if (!fiface && !(iface->dev->flags & IFF_LOOPBACK)) + fiface = iface; + if (!(iface->dev->flags & (IFF_LOOPBACK | IFF_POINTOPOINT))) { + retval = &iface->address; + goto out; + } + } + + if (fiface) + retval = &fiface->address; + else if (atalk_interfaces) + retval = &atalk_interfaces->address; + else + retval = NULL; +out: + read_unlock_bh(&atalk_interfaces_lock); + return retval; +} + +/* + * Find a match for 'any network' - ie any of our interfaces with that + * node number will do just nicely. + */ +static struct atalk_iface *atalk_find_anynet(int node, struct net_device *dev) +{ + struct atalk_iface *iface = dev->atalk_ptr; + + if (!iface || iface->status & ATIF_PROBE) + goto out_err; + + if (node != ATADDR_BCAST && + iface->address.s_node != node && + node != ATADDR_ANYNODE) + goto out_err; +out: + return iface; +out_err: + iface = NULL; + goto out; +} + +/* Find a match for a specific network:node pair */ +static struct atalk_iface *atalk_find_interface(__be16 net, int node) +{ + struct atalk_iface *iface; + + read_lock_bh(&atalk_interfaces_lock); + for (iface = atalk_interfaces; iface; iface = iface->next) { + if ((node == ATADDR_BCAST || + node == ATADDR_ANYNODE || + iface->address.s_node == node) && + iface->address.s_net == net && + !(iface->status & ATIF_PROBE)) + break; + + /* XXXX.0 -- net.0 returns the iface associated with net */ + if (node == ATADDR_ANYNODE && net != ATADDR_ANYNET && + ntohs(iface->nets.nr_firstnet) <= ntohs(net) && + ntohs(net) <= ntohs(iface->nets.nr_lastnet)) + break; + } + read_unlock_bh(&atalk_interfaces_lock); + return iface; +} + + +/* + * Find a route for an AppleTalk packet. This ought to get cached in + * the socket (later on...). We know about host routes and the fact + * that a route must be direct to broadcast. + */ +static struct atalk_route *atrtr_find(struct atalk_addr *target) +{ + /* + * we must search through all routes unless we find a + * host route, because some host routes might overlap + * network routes + */ + struct atalk_route *net_route = NULL; + struct atalk_route *r; + + read_lock_bh(&atalk_routes_lock); + for (r = atalk_routes; r; r = r->next) { + if (!(r->flags & RTF_UP)) + continue; + + if (r->target.s_net == target->s_net) { + if (r->flags & RTF_HOST) { + /* + * if this host route is for the target, + * the we're done + */ + if (r->target.s_node == target->s_node) + goto out; + } else + /* + * this route will work if there isn't a + * direct host route, so cache it + */ + net_route = r; + } + } + + /* + * if we found a network route but not a direct host + * route, then return it + */ + if (net_route) + r = net_route; + else if (atrtr_default.dev) + r = &atrtr_default; + else /* No route can be found */ + r = NULL; +out: + read_unlock_bh(&atalk_routes_lock); + return r; +} + + +/* + * Given an AppleTalk network, find the device to use. This can be + * a simple lookup. + */ +struct net_device *atrtr_get_dev(struct atalk_addr *sa) +{ + struct atalk_route *atr = atrtr_find(sa); + return atr ? 
atr->dev : NULL; +} + +/* Set up a default router */ +static void atrtr_set_default(struct net_device *dev) +{ + atrtr_default.dev = dev; + atrtr_default.flags = RTF_UP; + atrtr_default.gateway.s_net = htons(0); + atrtr_default.gateway.s_node = 0; +} + +/* + * Add a router. Basically make sure it looks valid and stuff the + * entry in the list. While it uses netranges we always set them to one + * entry to work like netatalk. + */ +static int atrtr_create(struct rtentry *r, struct net_device *devhint) +{ + struct sockaddr_at *ta = (struct sockaddr_at *)&r->rt_dst; + struct sockaddr_at *ga = (struct sockaddr_at *)&r->rt_gateway; + struct atalk_route *rt; + struct atalk_iface *iface, *riface; + int retval = -EINVAL; + + /* + * Fixme: Raise/Lower a routing change semaphore for these + * operations. + */ + + /* Validate the request */ + if (ta->sat_family != AF_APPLETALK || + (!devhint && ga->sat_family != AF_APPLETALK)) + goto out; + + /* Now walk the routing table and make our decisions */ + write_lock_bh(&atalk_routes_lock); + for (rt = atalk_routes; rt; rt = rt->next) { + if (r->rt_flags != rt->flags) + continue; + + if (ta->sat_addr.s_net == rt->target.s_net) { + if (!(rt->flags & RTF_HOST)) + break; + if (ta->sat_addr.s_node == rt->target.s_node) + break; + } + } + + if (!devhint) { + riface = NULL; + + read_lock_bh(&atalk_interfaces_lock); + for (iface = atalk_interfaces; iface; iface = iface->next) { + if (!riface && + ntohs(ga->sat_addr.s_net) >= + ntohs(iface->nets.nr_firstnet) && + ntohs(ga->sat_addr.s_net) <= + ntohs(iface->nets.nr_lastnet)) + riface = iface; + + if (ga->sat_addr.s_net == iface->address.s_net && + ga->sat_addr.s_node == iface->address.s_node) + riface = iface; + } + read_unlock_bh(&atalk_interfaces_lock); + + retval = -ENETUNREACH; + if (!riface) + goto out_unlock; + + devhint = riface->dev; + } + + if (!rt) { + rt = kzalloc(sizeof(*rt), GFP_ATOMIC); + + retval = -ENOBUFS; + if (!rt) + goto out_unlock; + + rt->next = atalk_routes; + atalk_routes = rt; + } + + /* Fill in the routing entry */ + rt->target = ta->sat_addr; + dev_hold(devhint); + rt->dev = devhint; + rt->flags = r->rt_flags; + rt->gateway = ga->sat_addr; + + retval = 0; +out_unlock: + write_unlock_bh(&atalk_routes_lock); +out: + return retval; +} + +/* Delete a route. Find it and discard it */ +static int atrtr_delete(struct atalk_addr *addr) +{ + struct atalk_route **r = &atalk_routes; + int retval = 0; + struct atalk_route *tmp; + + write_lock_bh(&atalk_routes_lock); + while ((tmp = *r) != NULL) { + if (tmp->target.s_net == addr->s_net && + (!(tmp->flags&RTF_GATEWAY) || + tmp->target.s_node == addr->s_node)) { + *r = tmp->next; + dev_put(tmp->dev); + kfree(tmp); + goto out; + } + r = &tmp->next; + } + retval = -ENOENT; +out: + write_unlock_bh(&atalk_routes_lock); + return retval; +} + +/* + * Called when a device is downed. Just throw away any routes + * via it. 
+ */ +static void atrtr_device_down(struct net_device *dev) +{ + struct atalk_route **r = &atalk_routes; + struct atalk_route *tmp; + + write_lock_bh(&atalk_routes_lock); + while ((tmp = *r) != NULL) { + if (tmp->dev == dev) { + *r = tmp->next; + dev_put(dev); + kfree(tmp); + } else + r = &tmp->next; + } + write_unlock_bh(&atalk_routes_lock); + + if (atrtr_default.dev == dev) + atrtr_set_default(NULL); +} + +/* Actually down the interface */ +static inline void atalk_dev_down(struct net_device *dev) +{ + atrtr_device_down(dev); /* Remove all routes for the device */ + aarp_device_down(dev); /* Remove AARP entries for the device */ + atif_drop_device(dev); /* Remove the device */ +} + +/* + * A device event has occurred. Watch for devices going down and + * delete our use of them (iface and route). + */ +static int ddp_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (event == NETDEV_DOWN) + /* Discard any use of this */ + atalk_dev_down(dev); + + return NOTIFY_DONE; +} + +/* ioctl calls. Shouldn't even need touching */ +/* Device configuration ioctl calls */ +static int atif_ioctl(int cmd, void __user *arg) +{ + static char aarp_mcast[6] = { 0x09, 0x00, 0x00, 0xFF, 0xFF, 0xFF }; + struct ifreq atreq; + struct atalk_netrange *nr; + struct sockaddr_at *sa; + struct net_device *dev; + struct atalk_iface *atif; + int ct; + int limit; + struct rtentry rtdef; + int add_route; + + if (get_user_ifreq(&atreq, NULL, arg)) + return -EFAULT; + + dev = __dev_get_by_name(&init_net, atreq.ifr_name); + if (!dev) + return -ENODEV; + + sa = (struct sockaddr_at *)&atreq.ifr_addr; + atif = atalk_find_dev(dev); + + switch (cmd) { + case SIOCSIFADDR: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (sa->sat_family != AF_APPLETALK) + return -EINVAL; + if (dev->type != ARPHRD_ETHER && + dev->type != ARPHRD_LOOPBACK && + dev->type != ARPHRD_LOCALTLK && + dev->type != ARPHRD_PPP) + return -EPROTONOSUPPORT; + + nr = (struct atalk_netrange *)&sa->sat_zero[0]; + add_route = 1; + + /* + * if this is a point-to-point iface, and we already + * have an iface for this AppleTalk address, then we + * should not add a route + */ + if ((dev->flags & IFF_POINTOPOINT) && + atalk_find_interface(sa->sat_addr.s_net, + sa->sat_addr.s_node)) { + printk(KERN_DEBUG "AppleTalk: point-to-point " + "interface added with " + "existing address\n"); + add_route = 0; + } + + /* + * Phase 1 is fine on LocalTalk but we don't do + * EtherTalk phase 1. Anyone wanting to add it, go ahead. + */ + if (dev->type == ARPHRD_ETHER && nr->nr_phase != 2) + return -EPROTONOSUPPORT; + if (sa->sat_addr.s_node == ATADDR_BCAST || + sa->sat_addr.s_node == 254) + return -EINVAL; + if (atif) { + /* Already setting address */ + if (atif->status & ATIF_PROBE) + return -EBUSY; + + atif->address.s_net = sa->sat_addr.s_net; + atif->address.s_node = sa->sat_addr.s_node; + atrtr_device_down(dev); /* Flush old routes */ + } else { + atif = atif_add_device(dev, &sa->sat_addr); + if (!atif) + return -ENOMEM; + } + atif->nets = *nr; + + /* + * Check if the chosen address is used. If so we + * error and atalkd will try another. 
+ */ + + if (!(dev->flags & IFF_LOOPBACK) && + !(dev->flags & IFF_POINTOPOINT) && + atif_probe_device(atif) < 0) { + atif_drop_device(dev); + return -EADDRINUSE; + } + + /* Hey it worked - add the direct routes */ + sa = (struct sockaddr_at *)&rtdef.rt_gateway; + sa->sat_family = AF_APPLETALK; + sa->sat_addr.s_net = atif->address.s_net; + sa->sat_addr.s_node = atif->address.s_node; + sa = (struct sockaddr_at *)&rtdef.rt_dst; + rtdef.rt_flags = RTF_UP; + sa->sat_family = AF_APPLETALK; + sa->sat_addr.s_node = ATADDR_ANYNODE; + if (dev->flags & IFF_LOOPBACK || + dev->flags & IFF_POINTOPOINT) + rtdef.rt_flags |= RTF_HOST; + + /* Routerless initial state */ + if (nr->nr_firstnet == htons(0) && + nr->nr_lastnet == htons(0xFFFE)) { + sa->sat_addr.s_net = atif->address.s_net; + atrtr_create(&rtdef, dev); + atrtr_set_default(dev); + } else { + limit = ntohs(nr->nr_lastnet); + if (limit - ntohs(nr->nr_firstnet) > 4096) { + printk(KERN_WARNING "Too many routes/" + "iface.\n"); + return -EINVAL; + } + if (add_route) + for (ct = ntohs(nr->nr_firstnet); + ct <= limit; ct++) { + sa->sat_addr.s_net = htons(ct); + atrtr_create(&rtdef, dev); + } + } + dev_mc_add_global(dev, aarp_mcast); + return 0; + + case SIOCGIFADDR: + if (!atif) + return -EADDRNOTAVAIL; + + sa->sat_family = AF_APPLETALK; + sa->sat_addr = atif->address; + break; + + case SIOCGIFBRDADDR: + if (!atif) + return -EADDRNOTAVAIL; + + sa->sat_family = AF_APPLETALK; + sa->sat_addr.s_net = atif->address.s_net; + sa->sat_addr.s_node = ATADDR_BCAST; + break; + + case SIOCATALKDIFADDR: + case SIOCDIFADDR: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (sa->sat_family != AF_APPLETALK) + return -EINVAL; + atalk_dev_down(dev); + break; + + case SIOCSARP: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (sa->sat_family != AF_APPLETALK) + return -EINVAL; + /* + * for now, we only support proxy AARP on ELAP; + * we should be able to do it for LocalTalk, too. + */ + if (dev->type != ARPHRD_ETHER) + return -EPROTONOSUPPORT; + + /* + * atif points to the current interface on this network; + * we aren't concerned about its current status (at + * least for now), but it has all the settings about + * the network we're going to probe. Consequently, it + * must exist. + */ + if (!atif) + return -EADDRNOTAVAIL; + + nr = (struct atalk_netrange *)&(atif->nets); + /* + * Phase 1 is fine on Localtalk but we don't do + * Ethertalk phase 1. Anyone wanting to add it, go ahead. + */ + if (dev->type == ARPHRD_ETHER && nr->nr_phase != 2) + return -EPROTONOSUPPORT; + + if (sa->sat_addr.s_node == ATADDR_BCAST || + sa->sat_addr.s_node == 254) + return -EINVAL; + + /* + * Check if the chosen address is used. If so we + * error and ATCP will try another. + */ + if (atif_proxy_probe_device(atif, &(sa->sat_addr)) < 0) + return -EADDRINUSE; + + /* + * We now have an address on the local network, and + * the AARP code will defend it for us until we take it + * down. We don't set up any routes right now, because + * ATCP will install them manually via SIOCADDRT. 
+ */ + break; + + case SIOCDARP: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (sa->sat_family != AF_APPLETALK) + return -EINVAL; + if (!atif) + return -EADDRNOTAVAIL; + + /* give to aarp module to remove proxy entry */ + aarp_proxy_remove(atif->dev, &(sa->sat_addr)); + return 0; + } + + return put_user_ifreq(&atreq, arg); +} + +static int atrtr_ioctl_addrt(struct rtentry *rt) +{ + struct net_device *dev = NULL; + + if (rt->rt_dev) { + char name[IFNAMSIZ]; + + if (copy_from_user(name, rt->rt_dev, IFNAMSIZ-1)) + return -EFAULT; + name[IFNAMSIZ-1] = '\0'; + + dev = __dev_get_by_name(&init_net, name); + if (!dev) + return -ENODEV; + } + return atrtr_create(rt, dev); +} + +/* Routing ioctl() calls */ +static int atrtr_ioctl(unsigned int cmd, void __user *arg) +{ + struct rtentry rt; + + if (copy_from_user(&rt, arg, sizeof(rt))) + return -EFAULT; + + switch (cmd) { + case SIOCDELRT: + if (rt.rt_dst.sa_family != AF_APPLETALK) + return -EINVAL; + return atrtr_delete(&((struct sockaddr_at *) + &rt.rt_dst)->sat_addr); + + case SIOCADDRT: + return atrtr_ioctl_addrt(&rt); + } + return -EINVAL; +} + +/**************************************************************************\ +* * +* Handling for system calls applied via the various interfaces to an * +* AppleTalk socket object. * +* * +\**************************************************************************/ + +/* + * Checksum: This is 'optional'. It's quite likely also a good + * candidate for assembler hackery 8) + */ +static unsigned long atalk_sum_partial(const unsigned char *data, + int len, unsigned long sum) +{ + /* This ought to be unwrapped neatly. I'll trust gcc for now */ + while (len--) { + sum += *data++; + sum = rol16(sum, 1); + } + return sum; +} + +/* Checksum skb data -- similar to skb_checksum */ +static unsigned long atalk_sum_skb(const struct sk_buff *skb, int offset, + int len, unsigned long sum) +{ + int start = skb_headlen(skb); + struct sk_buff *frag_iter; + int i, copy; + + /* checksum stuff in header space */ + if ((copy = start - offset) > 0) { + if (copy > len) + copy = len; + sum = atalk_sum_partial(skb->data + offset, copy, sum); + if ((len -= copy) == 0) + return sum; + + offset += copy; + } + + /* checksum stuff in frags */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + const skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + WARN_ON(start > offset + len); + + end = start + skb_frag_size(frag); + if ((copy = end - offset) > 0) { + u8 *vaddr; + + if (copy > len) + copy = len; + vaddr = kmap_atomic(skb_frag_page(frag)); + sum = atalk_sum_partial(vaddr + skb_frag_off(frag) + + offset - start, copy, sum); + kunmap_atomic(vaddr); + + if (!(len -= copy)) + return sum; + offset += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + sum = atalk_sum_skb(frag_iter, offset - start, + copy, sum); + if ((len -= copy) == 0) + return sum; + offset += copy; + } + start = end; + } + + BUG_ON(len > 0); + + return sum; +} + +static __be16 atalk_checksum(const struct sk_buff *skb, int len) +{ + unsigned long sum; + + /* skip header 4 bytes */ + sum = atalk_sum_skb(skb, 4, len-4, 0); + + /* Use 0xFFFF for 0. 0 itself means none */ + return sum ? htons((unsigned short)sum) : htons(0xFFFF); +} + +static struct proto ddp_proto = { + .name = "DDP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct atalk_sock), +}; + +/* + * Create a socket. 
Initialise the socket, blank the addresses + * set the state. + */ +static int atalk_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + int rc = -ESOCKTNOSUPPORT; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + /* + * We permit SOCK_DGRAM and RAW is an extension. It is trivial to do + * and gives you the full ELAP frame. Should be handy for CAP 8) + */ + if (sock->type != SOCK_RAW && sock->type != SOCK_DGRAM) + goto out; + + rc = -EPERM; + if (sock->type == SOCK_RAW && !kern && !capable(CAP_NET_RAW)) + goto out; + + rc = -ENOMEM; + sk = sk_alloc(net, PF_APPLETALK, GFP_KERNEL, &ddp_proto, kern); + if (!sk) + goto out; + rc = 0; + sock->ops = &atalk_dgram_ops; + sock_init_data(sock, sk); + + /* Checksums on by default */ + sock_set_flag(sk, SOCK_ZAPPED); +out: + return rc; +} + +/* Free a socket. No work needed */ +static int atalk_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + sock_hold(sk); + lock_sock(sk); + + sock_orphan(sk); + sock->sk = NULL; + atalk_destroy_socket(sk); + + release_sock(sk); + sock_put(sk); + } + return 0; +} + +/** + * atalk_pick_and_bind_port - Pick a source port when one is not given + * @sk: socket to insert into the tables + * @sat: address to search for + * + * Pick a source port when one is not given. If we can find a suitable free + * one, we insert the socket into the tables using it. + * + * This whole operation must be atomic. + */ +static int atalk_pick_and_bind_port(struct sock *sk, struct sockaddr_at *sat) +{ + int retval; + + write_lock_bh(&atalk_sockets_lock); + + for (sat->sat_port = ATPORT_RESERVED; + sat->sat_port < ATPORT_LAST; + sat->sat_port++) { + struct sock *s; + + sk_for_each(s, &atalk_sockets) { + struct atalk_sock *at = at_sk(s); + + if (at->src_net == sat->sat_addr.s_net && + at->src_node == sat->sat_addr.s_node && + at->src_port == sat->sat_port) + goto try_next_port; + } + + /* Wheee, it's free, assign and insert. 
*/ + __atalk_insert_socket(sk); + at_sk(sk)->src_port = sat->sat_port; + retval = 0; + goto out; + +try_next_port:; + } + + retval = -EBUSY; +out: + write_unlock_bh(&atalk_sockets_lock); + return retval; +} + +static int atalk_autobind(struct sock *sk) +{ + struct atalk_sock *at = at_sk(sk); + struct sockaddr_at sat; + struct atalk_addr *ap = atalk_find_primary(); + int n = -EADDRNOTAVAIL; + + if (!ap || ap->s_net == htons(ATADDR_ANYNET)) + goto out; + + at->src_net = sat.sat_addr.s_net = ap->s_net; + at->src_node = sat.sat_addr.s_node = ap->s_node; + + n = atalk_pick_and_bind_port(sk, &sat); + if (!n) + sock_reset_flag(sk, SOCK_ZAPPED); +out: + return n; +} + +/* Set the address 'our end' of the connection */ +static int atalk_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sockaddr_at *addr = (struct sockaddr_at *)uaddr; + struct sock *sk = sock->sk; + struct atalk_sock *at = at_sk(sk); + int err; + + if (!sock_flag(sk, SOCK_ZAPPED) || + addr_len != sizeof(struct sockaddr_at)) + return -EINVAL; + + if (addr->sat_family != AF_APPLETALK) + return -EAFNOSUPPORT; + + lock_sock(sk); + if (addr->sat_addr.s_net == htons(ATADDR_ANYNET)) { + struct atalk_addr *ap = atalk_find_primary(); + + err = -EADDRNOTAVAIL; + if (!ap) + goto out; + + at->src_net = addr->sat_addr.s_net = ap->s_net; + at->src_node = addr->sat_addr.s_node = ap->s_node; + } else { + err = -EADDRNOTAVAIL; + if (!atalk_find_interface(addr->sat_addr.s_net, + addr->sat_addr.s_node)) + goto out; + + at->src_net = addr->sat_addr.s_net; + at->src_node = addr->sat_addr.s_node; + } + + if (addr->sat_port == ATADDR_ANYPORT) { + err = atalk_pick_and_bind_port(sk, addr); + + if (err < 0) + goto out; + } else { + at->src_port = addr->sat_port; + + err = -EADDRINUSE; + if (atalk_find_or_insert_socket(sk, addr)) + goto out; + } + + sock_reset_flag(sk, SOCK_ZAPPED); + err = 0; +out: + release_sock(sk); + return err; +} + +/* Set the address we talk to */ +static int atalk_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct atalk_sock *at = at_sk(sk); + struct sockaddr_at *addr; + int err; + + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + if (addr_len != sizeof(*addr)) + return -EINVAL; + + addr = (struct sockaddr_at *)uaddr; + + if (addr->sat_family != AF_APPLETALK) + return -EAFNOSUPPORT; + + if (addr->sat_addr.s_node == ATADDR_BCAST && + !sock_flag(sk, SOCK_BROADCAST)) { +#if 1 + pr_warn("atalk_connect: %s is broken and did not set SO_BROADCAST.\n", + current->comm); +#else + return -EACCES; +#endif + } + + lock_sock(sk); + err = -EBUSY; + if (sock_flag(sk, SOCK_ZAPPED)) + if (atalk_autobind(sk) < 0) + goto out; + + err = -ENETUNREACH; + if (!atrtr_get_dev(&addr->sat_addr)) + goto out; + + at->dest_port = addr->sat_port; + at->dest_net = addr->sat_addr.s_net; + at->dest_node = addr->sat_addr.s_node; + + sock->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + err = 0; +out: + release_sock(sk); + return err; +} + +/* + * Find the name of an AppleTalk socket. Just copy the right + * fields into the sockaddr. 
+ */ +static int atalk_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_at sat; + struct sock *sk = sock->sk; + struct atalk_sock *at = at_sk(sk); + int err; + + lock_sock(sk); + err = -ENOBUFS; + if (sock_flag(sk, SOCK_ZAPPED)) + if (atalk_autobind(sk) < 0) + goto out; + + memset(&sat, 0, sizeof(sat)); + + if (peer) { + err = -ENOTCONN; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + + sat.sat_addr.s_net = at->dest_net; + sat.sat_addr.s_node = at->dest_node; + sat.sat_port = at->dest_port; + } else { + sat.sat_addr.s_net = at->src_net; + sat.sat_addr.s_node = at->src_node; + sat.sat_port = at->src_port; + } + + sat.sat_family = AF_APPLETALK; + memcpy(uaddr, &sat, sizeof(sat)); + err = sizeof(struct sockaddr_at); + +out: + release_sock(sk); + return err; +} + +#if IS_ENABLED(CONFIG_IPDDP) +static __inline__ int is_ip_over_ddp(struct sk_buff *skb) +{ + return skb->data[12] == 22; +} + +static int handle_ip_over_ddp(struct sk_buff *skb) +{ + struct net_device *dev = __dev_get_by_name(&init_net, "ipddp0"); + struct net_device_stats *stats; + + /* This needs to be able to handle ipddp"N" devices */ + if (!dev) { + kfree_skb(skb); + return NET_RX_DROP; + } + + skb->protocol = htons(ETH_P_IP); + skb_pull(skb, 13); + skb->dev = dev; + skb_reset_transport_header(skb); + + stats = netdev_priv(dev); + stats->rx_packets++; + stats->rx_bytes += skb->len + 13; + return netif_rx(skb); /* Send the SKB up to a higher place. */ +} +#else +/* make it easy for gcc to optimize this test out, i.e. kill the code */ +#define is_ip_over_ddp(skb) 0 +#define handle_ip_over_ddp(skb) 0 +#endif + +static int atalk_route_packet(struct sk_buff *skb, struct net_device *dev, + struct ddpehdr *ddp, __u16 len_hops, int origlen) +{ + struct atalk_route *rt; + struct atalk_addr ta; + + /* + * Don't route multicast, etc., packets, or packets sent to "this + * network" + */ + if (skb->pkt_type != PACKET_HOST || !ddp->deh_dnet) { + /* + * FIXME: + * + * Can it ever happen that a packet is from a PPP iface and + * needs to be broadcast onto the default network? + */ + if (dev->type == ARPHRD_PPP) + printk(KERN_DEBUG "AppleTalk: didn't forward broadcast " + "packet received from PPP iface\n"); + goto free_it; + } + + ta.s_net = ddp->deh_dnet; + ta.s_node = ddp->deh_dnode; + + /* Route the packet */ + rt = atrtr_find(&ta); + /* increment hops count */ + len_hops += 1 << 10; + if (!rt || !(len_hops & (15 << 10))) + goto free_it; + + /* FIXME: use skb->cb to be able to use shared skbs */ + + /* + * Route goes through another gateway, so set the target to the + * gateway instead. + */ + + if (rt->flags & RTF_GATEWAY) { + ta.s_net = rt->gateway.s_net; + ta.s_node = rt->gateway.s_node; + } + + /* Fix up skb->len field */ + skb_trim(skb, min_t(unsigned int, origlen, + (rt->dev->hard_header_len + + ddp_dl->header_length + (len_hops & 1023)))); + + /* FIXME: use skb->cb to be able to use shared skbs */ + ddp->deh_len_hops = htons(len_hops); + + /* + * Send the buffer onwards + * + * Now we must always be careful. If it's come from LocalTalk to + * EtherTalk it might not fit + * + * Order matters here: If a packet has to be copied to make a new + * headroom (rare hopefully) then it won't need unsharing. + * + * Note. ddp-> becomes invalid at the realloc. 
+ */ + if (skb_headroom(skb) < 22) { + /* 22 bytes - 12 ether, 2 len, 3 802.2 5 snap */ + struct sk_buff *nskb = skb_realloc_headroom(skb, 32); + kfree_skb(skb); + skb = nskb; + } else + skb = skb_unshare(skb, GFP_ATOMIC); + + /* + * If the buffer didn't vanish into the lack of space bitbucket we can + * send it. + */ + if (skb == NULL) + goto drop; + + if (aarp_send_ddp(rt->dev, skb, &ta, NULL) == NET_XMIT_DROP) + return NET_RX_DROP; + return NET_RX_SUCCESS; +free_it: + kfree_skb(skb); +drop: + return NET_RX_DROP; +} + +/** + * atalk_rcv - Receive a packet (in skb) from device dev + * @skb: packet received + * @dev: network device where the packet comes from + * @pt: packet type + * @orig_dev: the original receive net device + * + * Receive a packet (in skb) from device dev. This has come from the SNAP + * decoder, and on entry skb->transport_header is the DDP header, skb->len + * is the DDP header, skb->len is the DDP length. The physical headers + * have been extracted. PPP should probably pass frames marked as for this + * layer. [ie ARPHRD_ETHERTALK] + */ +static int atalk_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct ddpehdr *ddp; + struct sock *sock; + struct atalk_iface *atif; + struct sockaddr_at tosat; + int origlen; + __u16 len_hops; + + if (!net_eq(dev_net(dev), &init_net)) + goto drop; + + /* Don't mangle buffer if shared */ + if (!(skb = skb_share_check(skb, GFP_ATOMIC))) + goto out; + + /* Size check and make sure header is contiguous */ + if (!pskb_may_pull(skb, sizeof(*ddp))) + goto drop; + + ddp = ddp_hdr(skb); + + len_hops = ntohs(ddp->deh_len_hops); + + /* Trim buffer in case of stray trailing data */ + origlen = skb->len; + skb_trim(skb, min_t(unsigned int, skb->len, len_hops & 1023)); + + /* + * Size check to see if ddp->deh_len was crap + * (Otherwise we'll detonate most spectacularly + * in the middle of atalk_checksum() or recvmsg()). + */ + if (skb->len < sizeof(*ddp) || skb->len < (len_hops & 1023)) { + pr_debug("AppleTalk: dropping corrupted frame (deh_len=%u, " + "skb->len=%u)\n", len_hops & 1023, skb->len); + goto drop; + } + + /* + * Any checksums. Note we don't do htons() on this == is assumed to be + * valid for net byte orders all over the networking code... + */ + if (ddp->deh_sum && + atalk_checksum(skb, len_hops & 1023) != ddp->deh_sum) + /* Not a valid AppleTalk frame - dustbin time */ + goto drop; + + /* Check the packet is aimed at us */ + if (!ddp->deh_dnet) /* Net 0 is 'this network' */ + atif = atalk_find_anynet(ddp->deh_dnode, dev); + else + atif = atalk_find_interface(ddp->deh_dnet, ddp->deh_dnode); + + if (!atif) { + /* Not ours, so we route the packet via the correct + * AppleTalk iface + */ + return atalk_route_packet(skb, dev, ddp, len_hops, origlen); + } + + /* if IP over DDP is not selected this code will be optimized out */ + if (is_ip_over_ddp(skb)) + return handle_ip_over_ddp(skb); + /* + * Which socket - atalk_search_socket() looks for a *full match* + * of the tuple. + */ + tosat.sat_addr.s_net = ddp->deh_dnet; + tosat.sat_addr.s_node = ddp->deh_dnode; + tosat.sat_port = ddp->deh_dport; + + sock = atalk_search_socket(&tosat, atif); + if (!sock) /* But not one of our sockets */ + goto drop; + + /* Queue packet (standard) */ + if (sock_queue_rcv_skb(sock, skb) < 0) + goto drop; + + return NET_RX_SUCCESS; + +drop: + kfree_skb(skb); +out: + return NET_RX_DROP; + +} + +/* + * Receive a LocalTalk frame. We make some demands on the caller here. 
+ * Caller must provide enough headroom on the packet to pull the short + * header and append a long one. + */ +static int ltalk_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + if (!net_eq(dev_net(dev), &init_net)) + goto freeit; + + /* Expand any short form frames */ + if (skb_mac_header(skb)[2] == 1) { + struct ddpehdr *ddp; + /* Find our address */ + struct atalk_addr *ap = atalk_find_dev_addr(dev); + + if (!ap || skb->len < sizeof(__be16) || skb->len > 1023) + goto freeit; + + /* Don't mangle buffer if shared */ + if (!(skb = skb_share_check(skb, GFP_ATOMIC))) + return 0; + + /* + * The push leaves us with a ddephdr not an shdr, and + * handily the port bytes in the right place preset. + */ + ddp = skb_push(skb, sizeof(*ddp) - 4); + + /* Now fill in the long header */ + + /* + * These two first. The mac overlays the new source/dest + * network information so we MUST copy these before + * we write the network numbers ! + */ + + ddp->deh_dnode = skb_mac_header(skb)[0]; /* From physical header */ + ddp->deh_snode = skb_mac_header(skb)[1]; /* From physical header */ + + ddp->deh_dnet = ap->s_net; /* Network number */ + ddp->deh_snet = ap->s_net; + ddp->deh_sum = 0; /* No checksum */ + /* + * Not sure about this bit... + */ + /* Non routable, so force a drop if we slip up later */ + ddp->deh_len_hops = htons(skb->len + (DDP_MAXHOPS << 10)); + } + skb_reset_transport_header(skb); + + return atalk_rcv(skb, dev, pt, orig_dev); +freeit: + kfree_skb(skb); + return 0; +} + +static int atalk_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct atalk_sock *at = at_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_at *, usat, msg->msg_name); + int flags = msg->msg_flags; + int loopback = 0; + struct sockaddr_at local_satalk, gsat; + struct sk_buff *skb; + struct net_device *dev; + struct ddpehdr *ddp; + int size, hard_header_len; + struct atalk_route *rt, *rt_lo = NULL; + int err; + + if (flags & ~(MSG_DONTWAIT|MSG_CMSG_COMPAT)) + return -EINVAL; + + if (len > DDP_MAXSZ) + return -EMSGSIZE; + + lock_sock(sk); + if (usat) { + err = -EBUSY; + if (sock_flag(sk, SOCK_ZAPPED)) + if (atalk_autobind(sk) < 0) + goto out; + + err = -EINVAL; + if (msg->msg_namelen < sizeof(*usat) || + usat->sat_family != AF_APPLETALK) + goto out; + + err = -EPERM; + /* netatalk didn't implement this check */ + if (usat->sat_addr.s_node == ATADDR_BCAST && + !sock_flag(sk, SOCK_BROADCAST)) { + goto out; + } + } else { + err = -ENOTCONN; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + usat = &local_satalk; + usat->sat_family = AF_APPLETALK; + usat->sat_port = at->dest_port; + usat->sat_addr.s_node = at->dest_node; + usat->sat_addr.s_net = at->dest_net; + } + + /* Build a packet */ + SOCK_DEBUG(sk, "SK %p: Got address.\n", sk); + + /* For headers */ + size = sizeof(struct ddpehdr) + len + ddp_dl->header_length; + + if (usat->sat_addr.s_net || usat->sat_addr.s_node == ATADDR_ANYNODE) { + rt = atrtr_find(&usat->sat_addr); + } else { + struct atalk_addr at_hint; + + at_hint.s_node = 0; + at_hint.s_net = at->src_net; + + rt = atrtr_find(&at_hint); + } + err = -ENETUNREACH; + if (!rt) + goto out; + + dev = rt->dev; + + SOCK_DEBUG(sk, "SK %p: Size needed %d, device %s\n", + sk, size, dev->name); + + hard_header_len = dev->hard_header_len; + /* Leave room for loopback hardware header if necessary */ + if (usat->sat_addr.s_node == ATADDR_BCAST && + (dev->flags & IFF_LOOPBACK || !(rt->flags & RTF_GATEWAY))) { + struct atalk_addr 
at_lo; + + at_lo.s_node = 0; + at_lo.s_net = 0; + + rt_lo = atrtr_find(&at_lo); + + if (rt_lo && rt_lo->dev->hard_header_len > hard_header_len) + hard_header_len = rt_lo->dev->hard_header_len; + } + + size += hard_header_len; + release_sock(sk); + skb = sock_alloc_send_skb(sk, size, (flags & MSG_DONTWAIT), &err); + lock_sock(sk); + if (!skb) + goto out; + + skb_reserve(skb, ddp_dl->header_length); + skb_reserve(skb, hard_header_len); + skb->dev = dev; + + SOCK_DEBUG(sk, "SK %p: Begin build.\n", sk); + + ddp = skb_put(skb, sizeof(struct ddpehdr)); + ddp->deh_len_hops = htons(len + sizeof(*ddp)); + ddp->deh_dnet = usat->sat_addr.s_net; + ddp->deh_snet = at->src_net; + ddp->deh_dnode = usat->sat_addr.s_node; + ddp->deh_snode = at->src_node; + ddp->deh_dport = usat->sat_port; + ddp->deh_sport = at->src_port; + + SOCK_DEBUG(sk, "SK %p: Copy user data (%zd bytes).\n", sk, len); + + err = memcpy_from_msg(skb_put(skb, len), msg, len); + if (err) { + kfree_skb(skb); + err = -EFAULT; + goto out; + } + + if (sk->sk_no_check_tx) + ddp->deh_sum = 0; + else + ddp->deh_sum = atalk_checksum(skb, len + sizeof(*ddp)); + + /* + * Loopback broadcast packets to non gateway targets (ie routes + * to group we are in) + */ + if (ddp->deh_dnode == ATADDR_BCAST && + !(rt->flags & RTF_GATEWAY) && !(dev->flags & IFF_LOOPBACK)) { + struct sk_buff *skb2 = skb_copy(skb, GFP_KERNEL); + + if (skb2) { + loopback = 1; + SOCK_DEBUG(sk, "SK %p: send out(copy).\n", sk); + /* + * If it fails it is queued/sent above in the aarp queue + */ + aarp_send_ddp(dev, skb2, &usat->sat_addr, NULL); + } + } + + if (dev->flags & IFF_LOOPBACK || loopback) { + SOCK_DEBUG(sk, "SK %p: Loop back.\n", sk); + /* loop back */ + skb_orphan(skb); + if (ddp->deh_dnode == ATADDR_BCAST) { + if (!rt_lo) { + kfree_skb(skb); + err = -ENETUNREACH; + goto out; + } + dev = rt_lo->dev; + skb->dev = dev; + } + ddp_dl->request(ddp_dl, skb, dev->dev_addr); + } else { + SOCK_DEBUG(sk, "SK %p: send out.\n", sk); + if (rt->flags & RTF_GATEWAY) { + gsat.sat_addr = rt->gateway; + usat = &gsat; + } + + /* + * If it fails it is queued/sent above in the aarp queue + */ + aarp_send_ddp(dev, skb, &usat->sat_addr, NULL); + } + SOCK_DEBUG(sk, "SK %p: Done write (%zd).\n", sk, len); + +out: + release_sock(sk); + return err ? : len; +} + +static int atalk_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct ddpehdr *ddp; + int copied = 0; + int offset = 0; + int err = 0; + struct sk_buff *skb; + + skb = skb_recv_datagram(sk, flags, &err); + lock_sock(sk); + + if (!skb) + goto out; + + /* FIXME: use skb->cb to be able to use shared skbs */ + ddp = ddp_hdr(skb); + copied = ntohs(ddp->deh_len_hops) & 1023; + + if (sk->sk_type != SOCK_RAW) { + offset = sizeof(*ddp); + copied -= offset; + } + + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + err = skb_copy_datagram_msg(skb, offset, msg, copied); + + if (!err && msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_at *, sat, msg->msg_name); + sat->sat_family = AF_APPLETALK; + sat->sat_port = ddp->deh_sport; + sat->sat_addr.s_node = ddp->deh_snode; + sat->sat_addr.s_net = ddp->deh_snet; + msg->msg_namelen = sizeof(*sat); + } + + skb_free_datagram(sk, skb); /* Free the datagram. */ + +out: + release_sock(sk); + return err ? : copied; +} + + +/* + * AppleTalk ioctl calls. 
+ */ +static int atalk_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + int rc = -ENOIOCTLCMD; + struct sock *sk = sock->sk; + void __user *argp = (void __user *)arg; + + switch (cmd) { + /* Protocol layer */ + case TIOCOUTQ: { + long amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + + if (amount < 0) + amount = 0; + rc = put_user(amount, (int __user *)argp); + break; + } + case TIOCINQ: { + struct sk_buff *skb; + long amount = 0; + + spin_lock_irq(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + amount = skb->len - sizeof(struct ddpehdr); + spin_unlock_irq(&sk->sk_receive_queue.lock); + rc = put_user(amount, (int __user *)argp); + break; + } + /* Routing */ + case SIOCADDRT: + case SIOCDELRT: + rc = -EPERM; + if (capable(CAP_NET_ADMIN)) + rc = atrtr_ioctl(cmd, argp); + break; + /* Interface */ + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFBRDADDR: + case SIOCATALKDIFADDR: + case SIOCDIFADDR: + case SIOCSARP: /* proxy AARP */ + case SIOCDARP: /* proxy AARP */ + rtnl_lock(); + rc = atif_ioctl(cmd, argp); + rtnl_unlock(); + break; + } + + return rc; +} + + +#ifdef CONFIG_COMPAT +static int atalk_compat_routing_ioctl(struct sock *sk, unsigned int cmd, + struct compat_rtentry __user *ur) +{ + compat_uptr_t rtdev; + struct rtentry rt; + + if (copy_from_user(&rt.rt_dst, &ur->rt_dst, + 3 * sizeof(struct sockaddr)) || + get_user(rt.rt_flags, &ur->rt_flags) || + get_user(rt.rt_metric, &ur->rt_metric) || + get_user(rt.rt_mtu, &ur->rt_mtu) || + get_user(rt.rt_window, &ur->rt_window) || + get_user(rt.rt_irtt, &ur->rt_irtt) || + get_user(rtdev, &ur->rt_dev)) + return -EFAULT; + + switch (cmd) { + case SIOCDELRT: + if (rt.rt_dst.sa_family != AF_APPLETALK) + return -EINVAL; + return atrtr_delete(&((struct sockaddr_at *) + &rt.rt_dst)->sat_addr); + + case SIOCADDRT: + rt.rt_dev = compat_ptr(rtdev); + return atrtr_ioctl_addrt(&rt); + default: + return -EINVAL; + } +} +static int atalk_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + struct sock *sk = sock->sk; + + switch (cmd) { + case SIOCADDRT: + case SIOCDELRT: + return atalk_compat_routing_ioctl(sk, cmd, argp); + /* + * SIOCATALKDIFADDR is a SIOCPROTOPRIVATE ioctl number, so we + * cannot handle it in common code. The data we access if ifreq + * here is compatible, so we can simply call the native + * handler. 
+ */ + case SIOCATALKDIFADDR: + return atalk_ioctl(sock, cmd, (unsigned long)argp); + default: + return -ENOIOCTLCMD; + } +} +#endif /* CONFIG_COMPAT */ + + +static const struct net_proto_family atalk_family_ops = { + .family = PF_APPLETALK, + .create = atalk_create, + .owner = THIS_MODULE, +}; + +static const struct proto_ops atalk_dgram_ops = { + .family = PF_APPLETALK, + .owner = THIS_MODULE, + .release = atalk_release, + .bind = atalk_bind, + .connect = atalk_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = atalk_getname, + .poll = datagram_poll, + .ioctl = atalk_ioctl, + .gettstamp = sock_gettstamp, +#ifdef CONFIG_COMPAT + .compat_ioctl = atalk_compat_ioctl, +#endif + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = atalk_sendmsg, + .recvmsg = atalk_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct notifier_block ddp_notifier = { + .notifier_call = ddp_device_event, +}; + +static struct packet_type ltalk_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_LOCALTALK), + .func = ltalk_rcv, +}; + +static struct packet_type ppptalk_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_PPPTALK), + .func = atalk_rcv, +}; + +static unsigned char ddp_snap_id[] = { 0x08, 0x00, 0x07, 0x80, 0x9B }; + +/* Export symbols for use by drivers when AppleTalk is a module */ +EXPORT_SYMBOL(atrtr_get_dev); +EXPORT_SYMBOL(atalk_find_dev_addr); + +/* Called by proto.c on kernel start up */ +static int __init atalk_init(void) +{ + int rc; + + rc = proto_register(&ddp_proto, 0); + if (rc) + goto out; + + rc = sock_register(&atalk_family_ops); + if (rc) + goto out_proto; + + ddp_dl = register_snap_client(ddp_snap_id, atalk_rcv); + if (!ddp_dl) { + pr_crit("Unable to register DDP with SNAP.\n"); + rc = -ENOMEM; + goto out_sock; + } + + dev_add_pack(<alk_packet_type); + dev_add_pack(&ppptalk_packet_type); + + rc = register_netdevice_notifier(&ddp_notifier); + if (rc) + goto out_snap; + + rc = aarp_proto_init(); + if (rc) + goto out_dev; + + rc = atalk_proc_init(); + if (rc) + goto out_aarp; + + rc = atalk_register_sysctl(); + if (rc) + goto out_proc; +out: + return rc; +out_proc: + atalk_proc_exit(); +out_aarp: + aarp_cleanup_module(); +out_dev: + unregister_netdevice_notifier(&ddp_notifier); +out_snap: + dev_remove_pack(&ppptalk_packet_type); + dev_remove_pack(<alk_packet_type); + unregister_snap_client(ddp_dl); +out_sock: + sock_unregister(PF_APPLETALK); +out_proto: + proto_unregister(&ddp_proto); + goto out; +} +module_init(atalk_init); + +/* + * No explicit module reference count manipulation is needed in the + * protocol. Socket layer sets module reference count for us + * and interfaces reference counting is done + * by the network device layer. + * + * Ergo, before the AppleTalk module can be removed, all AppleTalk + * sockets should be closed from user space. + */ +static void __exit atalk_exit(void) +{ +#ifdef CONFIG_SYSCTL + atalk_unregister_sysctl(); +#endif /* CONFIG_SYSCTL */ + atalk_proc_exit(); + aarp_cleanup_module(); /* General aarp clean-up. 
*/ + unregister_netdevice_notifier(&ddp_notifier); + dev_remove_pack(<alk_packet_type); + dev_remove_pack(&ppptalk_packet_type); + unregister_snap_client(ddp_dl); + sock_unregister(PF_APPLETALK); + proto_unregister(&ddp_proto); +} +module_exit(atalk_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Alan Cox "); +MODULE_DESCRIPTION("AppleTalk 0.20\n"); +MODULE_ALIAS_NETPROTO(PF_APPLETALK); diff --git a/net/appletalk/dev.c b/net/appletalk/dev.c new file mode 100644 index 000000000..284c8e585 --- /dev/null +++ b/net/appletalk/dev.c @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Moved here from drivers/net/net_init.c, which is: + * Written 1993,1994,1995 by Donald Becker. + */ + +#include +#include +#include +#include +#include + +static void ltalk_setup(struct net_device *dev) +{ + /* Fill in the fields of the device structure with localtalk-generic values. */ + + dev->type = ARPHRD_LOCALTLK; + dev->hard_header_len = LTALK_HLEN; + dev->mtu = LTALK_MTU; + dev->addr_len = LTALK_ALEN; + dev->tx_queue_len = 10; + + dev->broadcast[0] = 0xFF; + + dev->flags = IFF_BROADCAST|IFF_MULTICAST|IFF_NOARP; +} + +/** + * alloc_ltalkdev - Allocates and sets up an localtalk device + * @sizeof_priv: Size of additional driver-private structure to be allocated + * for this localtalk device + * + * Fill in the fields of the device structure with localtalk-generic + * values. Basically does everything except registering the device. + * + * Constructs a new net device, complete with a private data area of + * size @sizeof_priv. A 32-byte (not bit) alignment is enforced for + * this private data area. + */ + +struct net_device *alloc_ltalkdev(int sizeof_priv) +{ + return alloc_netdev(sizeof_priv, "lt%d", NET_NAME_UNKNOWN, + ltalk_setup); +} +EXPORT_SYMBOL(alloc_ltalkdev); diff --git a/net/appletalk/sysctl_net_atalk.c b/net/appletalk/sysctl_net_atalk.c new file mode 100644 index 000000000..d945b7c01 --- /dev/null +++ b/net/appletalk/sysctl_net_atalk.c @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * sysctl_net_atalk.c: sysctl interface to net AppleTalk subsystem. + * + * Begun April 1, 1996, Mike Shaver. + * Added /proc/sys/net/atalk directory entry (empty =) ). [MS] + * Dynamic registration, added aarp entries. 
(5/30/97 Chris Horn) + */ + +#include +#include +#include + +static struct ctl_table atalk_table[] = { + { + .procname = "aarp-expiry-time", + .data = &sysctl_aarp_expiry_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "aarp-tick-time", + .data = &sysctl_aarp_tick_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "aarp-retransmit-limit", + .data = &sysctl_aarp_retransmit_limit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "aarp-resolve-time", + .data = &sysctl_aarp_resolve_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { }, +}; + +static struct ctl_table_header *atalk_table_header; + +int __init atalk_register_sysctl(void) +{ + atalk_table_header = register_net_sysctl(&init_net, "net/appletalk", atalk_table); + if (!atalk_table_header) + return -ENOMEM; + return 0; +} + +void atalk_unregister_sysctl(void) +{ + unregister_net_sysctl_table(atalk_table_header); +} diff --git a/net/atm/Kconfig b/net/atm/Kconfig new file mode 100644 index 000000000..77343d57f --- /dev/null +++ b/net/atm/Kconfig @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Asynchronous Transfer Mode (ATM) +# + +config ATM + tristate "Asynchronous Transfer Mode (ATM)" + help + ATM is a high-speed networking technology for Local Area Networks + and Wide Area Networks. It uses a fixed packet size and is + connection oriented, allowing for the negotiation of minimum + bandwidth requirements. + + In order to participate in an ATM network, your Linux box needs an + ATM networking card. If you have that, say Y here and to the driver + of your ATM card below. + + Note that you need a set of user-space programs to actually make use + of ATM. See the file for + further details. + +config ATM_CLIP + tristate "Classical IP over ATM" + depends on ATM && INET + help + Classical IP over ATM for PVCs and SVCs, supporting InARP and + ATMARP. If you want to communicate with other IP hosts on your ATM + network, you will typically either say Y here or to "LAN Emulation + (LANE)" below. + +config ATM_CLIP_NO_ICMP + bool "Do NOT send ICMP if no neighbour" + depends on ATM_CLIP + help + Normally, an "ICMP host unreachable" message is sent if a neighbour + cannot be reached because there is no VC to it in the kernel's + ATMARP table. This may cause problems when ATMARP table entries are + briefly removed during revalidation. If you say Y here, packets to + such neighbours are silently discarded instead. + +config ATM_LANE + tristate "LAN Emulation (LANE) support" + depends on ATM + help + LAN Emulation emulates services of existing LANs across an ATM + network. Besides operating as a normal ATM end station client, the Linux + LANE client can also act as a proxy client bridging packets between + ELAN and Ethernet segments. You need LANE if you want to try MPOA. + +config ATM_MPOA + tristate "Multi-Protocol Over ATM (MPOA) support" + depends on ATM && INET && ATM_LANE!=n + help + Multi-Protocol Over ATM allows ATM edge devices such as routers, + bridges and ATM attached hosts to establish direct ATM VCs across + subnetwork boundaries. These shortcut connections bypass routers, + enhancing overall network performance.
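As a brief, illustrative aside before the remaining ATM files (not part of the upstream patch itself): the ctl_table registered by atalk_register_sysctl() above shows up under /proc/sys/net/appletalk/ once the appletalk module is loaded on a kernel built with CONFIG_SYSCTL. A minimal userspace sketch that dumps the four AARP tunables could look like the following; the program is hypothetical, and the jiffies-backed entries are read and written in seconds via proc_dointvec_jiffies, while aarp-retransmit-limit is a plain integer count.

	/* Hypothetical example, not part of the kernel sources: dump the
	 * AppleTalk AARP sysctls registered by sysctl_net_atalk.c.
	 */
	#include <stdio.h>

	int main(void)
	{
		static const char *const names[] = {
			"aarp-expiry-time",	/* seconds (proc_dointvec_jiffies) */
			"aarp-tick-time",	/* seconds (proc_dointvec_jiffies) */
			"aarp-retransmit-limit",/* plain count (proc_dointvec) */
			"aarp-resolve-time",	/* seconds (proc_dointvec_jiffies) */
		};
		unsigned int i;

		for (i = 0; i < sizeof(names) / sizeof(names[0]); i++) {
			char path[128];
			long val;
			FILE *f;

			snprintf(path, sizeof(path),
				 "/proc/sys/net/appletalk/%s", names[i]);
			f = fopen(path, "r");
			if (!f) {
				/* appletalk module not loaded, or no CONFIG_SYSCTL */
				perror(path);
				continue;
			}
			if (fscanf(f, "%ld", &val) == 1)
				printf("%s = %ld\n", names[i], val);
			fclose(f);
		}
		return 0;
	}

Writing a value back (the entries are mode 0644, so root only) goes through the same handlers; a new aarp-expiry-time, for instance, would be given in seconds and converted to jiffies by proc_dointvec_jiffies. The ATM configuration entries continue below.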
+ +config ATM_BR2684 + tristate "RFC1483/2684 Bridged protocols" + depends on ATM && INET + help + ATM PVCs can carry ethernet PDUs according to RFC2684 (formerly 1483) + This device will act like an ethernet from the kernels point of view, + with the traffic being carried by ATM PVCs (currently 1 PVC/device). + This is sometimes used over DSL lines. If in doubt, say N. + +config ATM_BR2684_IPFILTER + bool "Per-VC IP filter kludge" + depends on ATM_BR2684 + help + This is an experimental mechanism for users who need to terminate a + large number of IP-only vcc's. Do not enable this unless you are sure + you know what you are doing. diff --git a/net/atm/Makefile b/net/atm/Makefile new file mode 100644 index 000000000..bfec0f2d8 --- /dev/null +++ b/net/atm/Makefile @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the ATM Protocol Families. +# + +atm-y := addr.o pvc.o signaling.o svc.o ioctl.o common.o atm_misc.o raw.o resources.o atm_sysfs.o +mpoa-objs := mpc.o mpoa_caches.o mpoa_proc.o + +obj-$(CONFIG_ATM) += atm.o +obj-$(CONFIG_ATM_CLIP) += clip.o +obj-$(CONFIG_ATM_BR2684) += br2684.o +atm-$(CONFIG_PROC_FS) += proc.o + +obj-$(CONFIG_ATM_LANE) += lec.o +obj-$(CONFIG_ATM_MPOA) += mpoa.o +obj-$(CONFIG_PPPOATM) += pppoatm.o diff --git a/net/atm/addr.c b/net/atm/addr.c new file mode 100644 index 000000000..0530b63f5 --- /dev/null +++ b/net/atm/addr.c @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/addr.c - Local ATM address registry */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#include +#include +#include +#include + +#include "signaling.h" +#include "addr.h" + +static int check_addr(const struct sockaddr_atmsvc *addr) +{ + int i; + + if (addr->sas_family != AF_ATMSVC) + return -EAFNOSUPPORT; + if (!*addr->sas_addr.pub) + return *addr->sas_addr.prv ? 
0 : -EINVAL; + for (i = 1; i < ATM_E164_LEN + 1; i++) /* make sure it's \0-terminated */ + if (!addr->sas_addr.pub[i]) + return 0; + return -EINVAL; +} + +static int identical(const struct sockaddr_atmsvc *a, const struct sockaddr_atmsvc *b) +{ + if (*a->sas_addr.prv) + if (memcmp(a->sas_addr.prv, b->sas_addr.prv, ATM_ESA_LEN)) + return 0; + if (!*a->sas_addr.pub) + return !*b->sas_addr.pub; + if (!*b->sas_addr.pub) + return 0; + return !strcmp(a->sas_addr.pub, b->sas_addr.pub); +} + +static void notify_sigd(const struct atm_dev *dev) +{ + struct sockaddr_atmpvc pvc; + + pvc.sap_addr.itf = dev->number; + sigd_enq(NULL, as_itf_notify, NULL, &pvc, NULL); +} + +void atm_reset_addr(struct atm_dev *dev, enum atm_addr_type_t atype) +{ + unsigned long flags; + struct atm_dev_addr *this, *p; + struct list_head *head; + + spin_lock_irqsave(&dev->lock, flags); + if (atype == ATM_ADDR_LECS) + head = &dev->lecs; + else + head = &dev->local; + list_for_each_entry_safe(this, p, head, entry) { + list_del(&this->entry); + kfree(this); + } + spin_unlock_irqrestore(&dev->lock, flags); + if (head == &dev->local) + notify_sigd(dev); +} + +int atm_add_addr(struct atm_dev *dev, const struct sockaddr_atmsvc *addr, + enum atm_addr_type_t atype) +{ + unsigned long flags; + struct atm_dev_addr *this; + struct list_head *head; + int error; + + error = check_addr(addr); + if (error) + return error; + spin_lock_irqsave(&dev->lock, flags); + if (atype == ATM_ADDR_LECS) + head = &dev->lecs; + else + head = &dev->local; + list_for_each_entry(this, head, entry) { + if (identical(&this->addr, addr)) { + spin_unlock_irqrestore(&dev->lock, flags); + return -EEXIST; + } + } + this = kmalloc(sizeof(struct atm_dev_addr), GFP_ATOMIC); + if (!this) { + spin_unlock_irqrestore(&dev->lock, flags); + return -ENOMEM; + } + this->addr = *addr; + list_add(&this->entry, head); + spin_unlock_irqrestore(&dev->lock, flags); + if (head == &dev->local) + notify_sigd(dev); + return 0; +} + +int atm_del_addr(struct atm_dev *dev, const struct sockaddr_atmsvc *addr, + enum atm_addr_type_t atype) +{ + unsigned long flags; + struct atm_dev_addr *this; + struct list_head *head; + int error; + + error = check_addr(addr); + if (error) + return error; + spin_lock_irqsave(&dev->lock, flags); + if (atype == ATM_ADDR_LECS) + head = &dev->lecs; + else + head = &dev->local; + list_for_each_entry(this, head, entry) { + if (identical(&this->addr, addr)) { + list_del(&this->entry); + spin_unlock_irqrestore(&dev->lock, flags); + kfree(this); + if (head == &dev->local) + notify_sigd(dev); + return 0; + } + } + spin_unlock_irqrestore(&dev->lock, flags); + return -ENOENT; +} + +int atm_get_addr(struct atm_dev *dev, struct sockaddr_atmsvc __user * buf, + size_t size, enum atm_addr_type_t atype) +{ + unsigned long flags; + struct atm_dev_addr *this; + struct list_head *head; + int total = 0, error; + struct sockaddr_atmsvc *tmp_buf, *tmp_bufp; + + spin_lock_irqsave(&dev->lock, flags); + if (atype == ATM_ADDR_LECS) + head = &dev->lecs; + else + head = &dev->local; + list_for_each_entry(this, head, entry) + total += sizeof(struct sockaddr_atmsvc); + tmp_buf = tmp_bufp = kmalloc(total, GFP_ATOMIC); + if (!tmp_buf) { + spin_unlock_irqrestore(&dev->lock, flags); + return -ENOMEM; + } + list_for_each_entry(this, head, entry) + memcpy(tmp_bufp++, &this->addr, sizeof(struct sockaddr_atmsvc)); + spin_unlock_irqrestore(&dev->lock, flags); + error = total > size ? -E2BIG : total; + if (copy_to_user(buf, tmp_buf, total < size ? 
total : size)) + error = -EFAULT; + kfree(tmp_buf); + return error; +} diff --git a/net/atm/addr.h b/net/atm/addr.h new file mode 100644 index 000000000..da3f84841 --- /dev/null +++ b/net/atm/addr.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* net/atm/addr.h - Local ATM address registry */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + + +#ifndef NET_ATM_ADDR_H +#define NET_ATM_ADDR_H + +#include +#include + +void atm_reset_addr(struct atm_dev *dev, enum atm_addr_type_t type); +int atm_add_addr(struct atm_dev *dev, const struct sockaddr_atmsvc *addr, + enum atm_addr_type_t type); +int atm_del_addr(struct atm_dev *dev, const struct sockaddr_atmsvc *addr, + enum atm_addr_type_t type); +int atm_get_addr(struct atm_dev *dev, struct sockaddr_atmsvc __user *buf, + size_t size, enum atm_addr_type_t type); + +#endif diff --git a/net/atm/atm_misc.c b/net/atm/atm_misc.c new file mode 100644 index 000000000..a30b83c1c --- /dev/null +++ b/net/atm/atm_misc.c @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/atm_misc.c - Various functions for use by ATM drivers */ + +/* Written 1995-2000 by Werner Almesberger, EPFL ICA */ + +#include +#include +#include +#include +#include +#include +#include +#include + +int atm_charge(struct atm_vcc *vcc, int truesize) +{ + atm_force_charge(vcc, truesize); + if (atomic_read(&sk_atm(vcc)->sk_rmem_alloc) <= sk_atm(vcc)->sk_rcvbuf) + return 1; + atm_return(vcc, truesize); + atomic_inc(&vcc->stats->rx_drop); + return 0; +} +EXPORT_SYMBOL(atm_charge); + +struct sk_buff *atm_alloc_charge(struct atm_vcc *vcc, int pdu_size, + gfp_t gfp_flags) +{ + struct sock *sk = sk_atm(vcc); + int guess = SKB_TRUESIZE(pdu_size); + + atm_force_charge(vcc, guess); + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) { + struct sk_buff *skb = alloc_skb(pdu_size, gfp_flags); + + if (skb) { + atomic_add(skb->truesize-guess, + &sk->sk_rmem_alloc); + return skb; + } + } + atm_return(vcc, guess); + atomic_inc(&vcc->stats->rx_drop); + return NULL; +} +EXPORT_SYMBOL(atm_alloc_charge); + + +/* + * atm_pcr_goal returns the positive PCR if it should be rounded up, the + * negative PCR if it should be rounded down, and zero if the maximum available + * bandwidth should be used. 
+ * + * The rules are as follows (* = maximum, - = absent (0), x = value "x", + * (x+ = x or next value above x, x- = x or next value below): + * + * min max pcr result min max pcr result + * - - - * (UBR only) x - - x+ + * - - * * x - * * + * - - z z- x - z z- + * - * - * x * - x+ + * - * * * x * * * + * - * z z- x * z z- + * - y - y- x y - x+ + * - y * y- x y * y- + * - y z z- x y z z- + * + * All non-error cases can be converted with the following simple set of rules: + * + * if pcr == z then z- + * else if min == x && pcr == - then x+ + * else if max == y then y- + * else * + */ + +int atm_pcr_goal(const struct atm_trafprm *tp) +{ + if (tp->pcr && tp->pcr != ATM_MAX_PCR) + return -tp->pcr; + if (tp->min_pcr && !tp->pcr) + return tp->min_pcr; + if (tp->max_pcr != ATM_MAX_PCR) + return -tp->max_pcr; + return 0; +} +EXPORT_SYMBOL(atm_pcr_goal); + +void sonet_copy_stats(struct k_sonet_stats *from, struct sonet_stats *to) +{ +#define __HANDLE_ITEM(i) to->i = atomic_read(&from->i) + __SONET_ITEMS +#undef __HANDLE_ITEM +} +EXPORT_SYMBOL(sonet_copy_stats); + +void sonet_subtract_stats(struct k_sonet_stats *from, struct sonet_stats *to) +{ +#define __HANDLE_ITEM(i) atomic_sub(to->i, &from->i) + __SONET_ITEMS +#undef __HANDLE_ITEM +} +EXPORT_SYMBOL(sonet_subtract_stats); diff --git a/net/atm/atm_sysfs.c b/net/atm/atm_sysfs.c new file mode 100644 index 000000000..0fdbdfd19 --- /dev/null +++ b/net/atm/atm_sysfs.c @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: GPL-2.0 +/* ATM driver model support. */ + +#include +#include +#include +#include +#include +#include "common.h" +#include "resources.h" + +#define to_atm_dev(cldev) container_of(cldev, struct atm_dev, class_dev) + +static ssize_t type_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + struct atm_dev *adev = to_atm_dev(cdev); + + return scnprintf(buf, PAGE_SIZE, "%s\n", adev->type); +} + +static ssize_t address_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + struct atm_dev *adev = to_atm_dev(cdev); + + return scnprintf(buf, PAGE_SIZE, "%pM\n", adev->esi); +} + +static ssize_t atmaddress_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + unsigned long flags; + struct atm_dev *adev = to_atm_dev(cdev); + struct atm_dev_addr *aaddr; + int count = 0; + + spin_lock_irqsave(&adev->lock, flags); + list_for_each_entry(aaddr, &adev->local, entry) { + count += scnprintf(buf + count, PAGE_SIZE - count, + "%1phN.%2phN.%10phN.%6phN.%1phN\n", + &aaddr->addr.sas_addr.prv[0], + &aaddr->addr.sas_addr.prv[1], + &aaddr->addr.sas_addr.prv[3], + &aaddr->addr.sas_addr.prv[13], + &aaddr->addr.sas_addr.prv[19]); + } + spin_unlock_irqrestore(&adev->lock, flags); + + return count; +} + +static ssize_t atmindex_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + struct atm_dev *adev = to_atm_dev(cdev); + + return scnprintf(buf, PAGE_SIZE, "%d\n", adev->number); +} + +static ssize_t carrier_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + struct atm_dev *adev = to_atm_dev(cdev); + + return scnprintf(buf, PAGE_SIZE, "%d\n", + adev->signal == ATM_PHY_SIG_LOST ? 
0 : 1); +} + +static ssize_t link_rate_show(struct device *cdev, + struct device_attribute *attr, char *buf) +{ + struct atm_dev *adev = to_atm_dev(cdev); + int link_rate; + + /* show the link rate, not the data rate */ + switch (adev->link_rate) { + case ATM_OC3_PCR: + link_rate = 155520000; + break; + case ATM_OC12_PCR: + link_rate = 622080000; + break; + case ATM_25_PCR: + link_rate = 25600000; + break; + default: + link_rate = adev->link_rate * 8 * 53; + } + return scnprintf(buf, PAGE_SIZE, "%d\n", link_rate); +} + +static DEVICE_ATTR_RO(address); +static DEVICE_ATTR_RO(atmaddress); +static DEVICE_ATTR_RO(atmindex); +static DEVICE_ATTR_RO(carrier); +static DEVICE_ATTR_RO(type); +static DEVICE_ATTR_RO(link_rate); + +static struct device_attribute *atm_attrs[] = { + &dev_attr_atmaddress, + &dev_attr_address, + &dev_attr_atmindex, + &dev_attr_carrier, + &dev_attr_type, + &dev_attr_link_rate, + NULL +}; + + +static int atm_uevent(struct device *cdev, struct kobj_uevent_env *env) +{ + struct atm_dev *adev; + + if (!cdev) + return -ENODEV; + + adev = to_atm_dev(cdev); + if (!adev) + return -ENODEV; + + if (add_uevent_var(env, "NAME=%s%d", adev->type, adev->number)) + return -ENOMEM; + + return 0; +} + +static void atm_release(struct device *cdev) +{ + struct atm_dev *adev = to_atm_dev(cdev); + + kfree(adev); +} + +static struct class atm_class = { + .name = "atm", + .dev_release = atm_release, + .dev_uevent = atm_uevent, +}; + +int atm_register_sysfs(struct atm_dev *adev, struct device *parent) +{ + struct device *cdev = &adev->class_dev; + int i, j, err; + + cdev->class = &atm_class; + cdev->parent = parent; + dev_set_drvdata(cdev, adev); + + dev_set_name(cdev, "%s%d", adev->type, adev->number); + err = device_register(cdev); + if (err < 0) + return err; + + for (i = 0; atm_attrs[i]; i++) { + err = device_create_file(cdev, atm_attrs[i]); + if (err) + goto err_out; + } + + return 0; + +err_out: + for (j = 0; j < i; j++) + device_remove_file(cdev, atm_attrs[j]); + device_del(cdev); + return err; +} + +void atm_unregister_sysfs(struct atm_dev *adev) +{ + struct device *cdev = &adev->class_dev; + + device_del(cdev); +} + +int __init atm_sysfs_init(void) +{ + return class_register(&atm_class); +} + +void __exit atm_sysfs_exit(void) +{ + class_unregister(&atm_class); +} diff --git a/net/atm/br2684.c b/net/atm/br2684.c new file mode 100644 index 000000000..f666f2f98 --- /dev/null +++ b/net/atm/br2684.c @@ -0,0 +1,872 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Ethernet netdevice using ATM AAL5 as underlying carrier + * (RFC1483 obsoleted by RFC2684) for Linux + * + * Authors: Marcell GAL, 2000, XDSL Ltd, Hungary + * Eric Kinzie, 2006-2007, US Naval Research Laboratory + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common.h" + +static void skb_debug(const struct sk_buff *skb) +{ +#ifdef SKB_DEBUG +#define NUM2PRINT 50 + print_hex_dump(KERN_DEBUG, "br2684: skb: ", DUMP_OFFSET, + 16, 1, skb->data, min(NUM2PRINT, skb->len), true); +#endif +} + +#define BR2684_ETHERTYPE_LEN 2 +#define BR2684_PAD_LEN 2 + +#define LLC 0xaa, 0xaa, 0x03 +#define SNAP_BRIDGED 0x00, 0x80, 0xc2 +#define SNAP_ROUTED 0x00, 0x00, 0x00 +#define PID_ETHERNET 0x00, 0x07 +#define ETHERTYPE_IPV4 0x08, 0x00 +#define ETHERTYPE_IPV6 0x86, 0xdd +#define PAD_BRIDGED 0x00, 0x00 + +static const unsigned char ethertype_ipv4[] = { 
ETHERTYPE_IPV4 }; +static const unsigned char ethertype_ipv6[] = { ETHERTYPE_IPV6 }; +static const unsigned char llc_oui_pid_pad[] = + { LLC, SNAP_BRIDGED, PID_ETHERNET, PAD_BRIDGED }; +static const unsigned char pad[] = { PAD_BRIDGED }; +static const unsigned char llc_oui_ipv4[] = { LLC, SNAP_ROUTED, ETHERTYPE_IPV4 }; +static const unsigned char llc_oui_ipv6[] = { LLC, SNAP_ROUTED, ETHERTYPE_IPV6 }; + +enum br2684_encaps { + e_vc = BR2684_ENCAPS_VC, + e_llc = BR2684_ENCAPS_LLC, +}; + +struct br2684_vcc { + struct atm_vcc *atmvcc; + struct net_device *device; + /* keep old push, pop functions for chaining */ + void (*old_push)(struct atm_vcc *vcc, struct sk_buff *skb); + void (*old_pop)(struct atm_vcc *vcc, struct sk_buff *skb); + void (*old_release_cb)(struct atm_vcc *vcc); + struct module *old_owner; + enum br2684_encaps encaps; + struct list_head brvccs; +#ifdef CONFIG_ATM_BR2684_IPFILTER + struct br2684_filter filter; +#endif /* CONFIG_ATM_BR2684_IPFILTER */ + unsigned int copies_needed, copies_failed; + atomic_t qspace; +}; + +struct br2684_dev { + struct net_device *net_dev; + struct list_head br2684_devs; + int number; + struct list_head brvccs; /* one device <=> one vcc (before xmas) */ + int mac_was_set; + enum br2684_payload payload; +}; + +/* + * This lock should be held for writing any time the list of devices or + * their attached vcc's could be altered. It should be held for reading + * any time these are being queried. Note that we sometimes need to + * do read-locking under interrupting context, so write locking must block + * the current CPU's interrupts. + */ +static DEFINE_RWLOCK(devs_lock); + +static LIST_HEAD(br2684_devs); + +static inline struct br2684_dev *BRPRIV(const struct net_device *net_dev) +{ + return netdev_priv(net_dev); +} + +static inline struct net_device *list_entry_brdev(const struct list_head *le) +{ + return list_entry(le, struct br2684_dev, br2684_devs)->net_dev; +} + +static inline struct br2684_vcc *BR2684_VCC(const struct atm_vcc *atmvcc) +{ + return (struct br2684_vcc *)(atmvcc->user_back); +} + +static inline struct br2684_vcc *list_entry_brvcc(const struct list_head *le) +{ + return list_entry(le, struct br2684_vcc, brvccs); +} + +/* Caller should hold read_lock(&devs_lock) */ +static struct net_device *br2684_find_dev(const struct br2684_if_spec *s) +{ + struct list_head *lh; + struct net_device *net_dev; + switch (s->method) { + case BR2684_FIND_BYNUM: + list_for_each(lh, &br2684_devs) { + net_dev = list_entry_brdev(lh); + if (BRPRIV(net_dev)->number == s->spec.devnum) + return net_dev; + } + break; + case BR2684_FIND_BYIFNAME: + list_for_each(lh, &br2684_devs) { + net_dev = list_entry_brdev(lh); + if (!strncmp(net_dev->name, s->spec.ifname, IFNAMSIZ)) + return net_dev; + } + break; + } + return NULL; +} + +static int atm_dev_event(struct notifier_block *this, unsigned long event, + void *arg) +{ + struct atm_dev *atm_dev = arg; + struct list_head *lh; + struct net_device *net_dev; + struct br2684_vcc *brvcc; + struct atm_vcc *atm_vcc; + unsigned long flags; + + pr_debug("event=%ld dev=%p\n", event, atm_dev); + + read_lock_irqsave(&devs_lock, flags); + list_for_each(lh, &br2684_devs) { + net_dev = list_entry_brdev(lh); + + list_for_each_entry(brvcc, &BRPRIV(net_dev)->brvccs, brvccs) { + atm_vcc = brvcc->atmvcc; + if (atm_vcc && brvcc->atmvcc->dev == atm_dev) { + + if (atm_vcc->dev->signal == ATM_PHY_SIG_LOST) + netif_carrier_off(net_dev); + else + netif_carrier_on(net_dev); + + } + } + } + read_unlock_irqrestore(&devs_lock, flags); + + return 
NOTIFY_DONE; +} + +static struct notifier_block atm_dev_notifier = { + .notifier_call = atm_dev_event, +}; + +/* chained vcc->pop function. Check if we should wake the netif_queue */ +static void br2684_pop(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct br2684_vcc *brvcc = BR2684_VCC(vcc); + + pr_debug("(vcc %p ; net_dev %p )\n", vcc, brvcc->device); + brvcc->old_pop(vcc, skb); + + /* If the queue space just went up from zero, wake */ + if (atomic_inc_return(&brvcc->qspace) == 1) + netif_wake_queue(brvcc->device); +} + +/* + * Send a packet out a particular vcc. Not to useful right now, but paves + * the way for multiple vcc's per itf. Returns true if we can send, + * otherwise false + */ +static int br2684_xmit_vcc(struct sk_buff *skb, struct net_device *dev, + struct br2684_vcc *brvcc) +{ + struct br2684_dev *brdev = BRPRIV(dev); + struct atm_vcc *atmvcc; + int minheadroom = (brvcc->encaps == e_llc) ? + ((brdev->payload == p_bridged) ? + sizeof(llc_oui_pid_pad) : sizeof(llc_oui_ipv4)) : + ((brdev->payload == p_bridged) ? BR2684_PAD_LEN : 0); + + if (skb_headroom(skb) < minheadroom) { + struct sk_buff *skb2 = skb_realloc_headroom(skb, minheadroom); + brvcc->copies_needed++; + dev_kfree_skb(skb); + if (skb2 == NULL) { + brvcc->copies_failed++; + return 0; + } + skb = skb2; + } + + if (brvcc->encaps == e_llc) { + if (brdev->payload == p_bridged) { + skb_push(skb, sizeof(llc_oui_pid_pad)); + skb_copy_to_linear_data(skb, llc_oui_pid_pad, + sizeof(llc_oui_pid_pad)); + } else if (brdev->payload == p_routed) { + unsigned short prot = ntohs(skb->protocol); + + skb_push(skb, sizeof(llc_oui_ipv4)); + switch (prot) { + case ETH_P_IP: + skb_copy_to_linear_data(skb, llc_oui_ipv4, + sizeof(llc_oui_ipv4)); + break; + case ETH_P_IPV6: + skb_copy_to_linear_data(skb, llc_oui_ipv6, + sizeof(llc_oui_ipv6)); + break; + default: + dev_kfree_skb(skb); + return 0; + } + } + } else { /* e_vc */ + if (brdev->payload == p_bridged) { + skb_push(skb, 2); + memset(skb->data, 0, 2); + } + } + skb_debug(skb); + + ATM_SKB(skb)->vcc = atmvcc = brvcc->atmvcc; + pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n", skb, atmvcc, atmvcc->dev); + atm_account_tx(atmvcc, skb); + dev->stats.tx_packets++; + dev->stats.tx_bytes += skb->len; + + if (atomic_dec_return(&brvcc->qspace) < 1) { + /* No more please! */ + netif_stop_queue(brvcc->device); + /* We might have raced with br2684_pop() */ + if (unlikely(atomic_read(&brvcc->qspace) > 0)) + netif_wake_queue(brvcc->device); + } + + /* If this fails immediately, the skb will be freed and br2684_pop() + will wake the queue if appropriate. Just return an error so that + the stats are updated correctly */ + return !atmvcc->send(atmvcc, skb); +} + +static void br2684_release_cb(struct atm_vcc *atmvcc) +{ + struct br2684_vcc *brvcc = BR2684_VCC(atmvcc); + + if (atomic_read(&brvcc->qspace) > 0) + netif_wake_queue(brvcc->device); + + if (brvcc->old_release_cb) + brvcc->old_release_cb(atmvcc); +} + +static inline struct br2684_vcc *pick_outgoing_vcc(const struct sk_buff *skb, + const struct br2684_dev *brdev) +{ + return list_empty(&brdev->brvccs) ? 
NULL : list_entry_brvcc(brdev->brvccs.next); /* 1 vcc/dev right now */ +} + +static netdev_tx_t br2684_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct br2684_dev *brdev = BRPRIV(dev); + struct br2684_vcc *brvcc; + struct atm_vcc *atmvcc; + netdev_tx_t ret = NETDEV_TX_OK; + + pr_debug("skb_dst(skb)=%p\n", skb_dst(skb)); + read_lock(&devs_lock); + brvcc = pick_outgoing_vcc(skb, brdev); + if (brvcc == NULL) { + pr_debug("no vcc attached to dev %s\n", dev->name); + dev->stats.tx_errors++; + dev->stats.tx_carrier_errors++; + /* netif_stop_queue(dev); */ + dev_kfree_skb(skb); + goto out_devs; + } + atmvcc = brvcc->atmvcc; + + bh_lock_sock(sk_atm(atmvcc)); + + if (test_bit(ATM_VF_RELEASED, &atmvcc->flags) || + test_bit(ATM_VF_CLOSE, &atmvcc->flags) || + !test_bit(ATM_VF_READY, &atmvcc->flags)) { + dev->stats.tx_dropped++; + dev_kfree_skb(skb); + goto out; + } + + if (sock_owned_by_user(sk_atm(atmvcc))) { + netif_stop_queue(brvcc->device); + ret = NETDEV_TX_BUSY; + goto out; + } + + if (!br2684_xmit_vcc(skb, dev, brvcc)) { + /* + * We should probably use netif_*_queue() here, but that + * involves added complication. We need to walk before + * we can run. + * + * Don't free here! this pointer might be no longer valid! + */ + dev->stats.tx_errors++; + dev->stats.tx_fifo_errors++; + } + out: + bh_unlock_sock(sk_atm(atmvcc)); + out_devs: + read_unlock(&devs_lock); + return ret; +} + +/* + * We remember when the MAC gets set, so we don't override it later with + * the ESI of the ATM card of the first VC + */ +static int br2684_mac_addr(struct net_device *dev, void *p) +{ + int err = eth_mac_addr(dev, p); + if (!err) + BRPRIV(dev)->mac_was_set = 1; + return err; +} + +#ifdef CONFIG_ATM_BR2684_IPFILTER +/* this IOCTL is experimental. */ +static int br2684_setfilt(struct atm_vcc *atmvcc, void __user * arg) +{ + struct br2684_vcc *brvcc; + struct br2684_filter_set fs; + + if (copy_from_user(&fs, arg, sizeof fs)) + return -EFAULT; + if (fs.ifspec.method != BR2684_FIND_BYNOTHING) { + /* + * This is really a per-vcc thing, but we can also search + * by device. + */ + struct br2684_dev *brdev; + read_lock(&devs_lock); + brdev = BRPRIV(br2684_find_dev(&fs.ifspec)); + if (brdev == NULL || list_empty(&brdev->brvccs) || + brdev->brvccs.next != brdev->brvccs.prev) /* >1 VCC */ + brvcc = NULL; + else + brvcc = list_entry_brvcc(brdev->brvccs.next); + read_unlock(&devs_lock); + if (brvcc == NULL) + return -ESRCH; + } else + brvcc = BR2684_VCC(atmvcc); + memcpy(&brvcc->filter, &fs.filter, sizeof(brvcc->filter)); + return 0; +} + +/* Returns 1 if packet should be dropped */ +static inline int +packet_fails_filter(__be16 type, struct br2684_vcc *brvcc, struct sk_buff *skb) +{ + if (brvcc->filter.netmask == 0) + return 0; /* no filter in place */ + if (type == htons(ETH_P_IP) && + (((struct iphdr *)(skb->data))->daddr & brvcc->filter. + netmask) == brvcc->filter.prefix) + return 0; + if (type == htons(ETH_P_ARP)) + return 0; + /* + * TODO: we should probably filter ARPs too.. don't want to have + * them returning values that don't make sense, or is that ok? + */ + return 1; /* drop */ +} +#endif /* CONFIG_ATM_BR2684_IPFILTER */ + +static void br2684_close_vcc(struct br2684_vcc *brvcc) +{ + pr_debug("removing VCC %p from dev %p\n", brvcc, brvcc->device); + write_lock_irq(&devs_lock); + list_del(&brvcc->brvccs); + write_unlock_irq(&devs_lock); + brvcc->atmvcc->user_back = NULL; /* what about vcc->recvq ??? 
*/ + brvcc->atmvcc->release_cb = brvcc->old_release_cb; + brvcc->old_push(brvcc->atmvcc, NULL); /* pass on the bad news */ + module_put(brvcc->old_owner); + kfree(brvcc); +} + +/* when AAL5 PDU comes in: */ +static void br2684_push(struct atm_vcc *atmvcc, struct sk_buff *skb) +{ + struct br2684_vcc *brvcc = BR2684_VCC(atmvcc); + struct net_device *net_dev = brvcc->device; + struct br2684_dev *brdev = BRPRIV(net_dev); + + pr_debug("\n"); + + if (unlikely(skb == NULL)) { + /* skb==NULL means VCC is being destroyed */ + br2684_close_vcc(brvcc); + if (list_empty(&brdev->brvccs)) { + write_lock_irq(&devs_lock); + list_del(&brdev->br2684_devs); + write_unlock_irq(&devs_lock); + unregister_netdev(net_dev); + free_netdev(net_dev); + } + return; + } + + skb_debug(skb); + atm_return(atmvcc, skb->truesize); + pr_debug("skb from brdev %p\n", brdev); + if (brvcc->encaps == e_llc) { + + if (skb->len > 7 && skb->data[7] == 0x01) + __skb_trim(skb, skb->len - 4); + + /* accept packets that have "ipv[46]" in the snap header */ + if ((skb->len >= (sizeof(llc_oui_ipv4))) && + (memcmp(skb->data, llc_oui_ipv4, + sizeof(llc_oui_ipv4) - BR2684_ETHERTYPE_LEN) == 0)) { + if (memcmp(skb->data + 6, ethertype_ipv6, + sizeof(ethertype_ipv6)) == 0) + skb->protocol = htons(ETH_P_IPV6); + else if (memcmp(skb->data + 6, ethertype_ipv4, + sizeof(ethertype_ipv4)) == 0) + skb->protocol = htons(ETH_P_IP); + else + goto error; + skb_pull(skb, sizeof(llc_oui_ipv4)); + skb_reset_network_header(skb); + skb->pkt_type = PACKET_HOST; + /* + * Let us waste some time for checking the encapsulation. + * Note, that only 7 char is checked so frames with a valid FCS + * are also accepted (but FCS is not checked of course). + */ + } else if ((skb->len >= sizeof(llc_oui_pid_pad)) && + (memcmp(skb->data, llc_oui_pid_pad, 7) == 0)) { + skb_pull(skb, sizeof(llc_oui_pid_pad)); + skb->protocol = eth_type_trans(skb, net_dev); + } else + goto error; + + } else { /* e_vc */ + if (brdev->payload == p_routed) { + struct iphdr *iph; + + skb_reset_network_header(skb); + iph = ip_hdr(skb); + if (iph->version == 4) + skb->protocol = htons(ETH_P_IP); + else if (iph->version == 6) + skb->protocol = htons(ETH_P_IPV6); + else + goto error; + skb->pkt_type = PACKET_HOST; + } else { /* p_bridged */ + /* first 2 chars should be 0 */ + if (memcmp(skb->data, pad, BR2684_PAD_LEN) != 0) + goto error; + skb_pull(skb, BR2684_PAD_LEN); + skb->protocol = eth_type_trans(skb, net_dev); + } + } + +#ifdef CONFIG_ATM_BR2684_IPFILTER + if (unlikely(packet_fails_filter(skb->protocol, brvcc, skb))) + goto dropped; +#endif /* CONFIG_ATM_BR2684_IPFILTER */ + skb->dev = net_dev; + ATM_SKB(skb)->vcc = atmvcc; /* needed ? */ + pr_debug("received packet's protocol: %x\n", ntohs(skb->protocol)); + skb_debug(skb); + /* sigh, interface is down? 
*/ + if (unlikely(!(net_dev->flags & IFF_UP))) + goto dropped; + net_dev->stats.rx_packets++; + net_dev->stats.rx_bytes += skb->len; + memset(ATM_SKB(skb), 0, sizeof(struct atm_skb_data)); + netif_rx(skb); + return; + +dropped: + net_dev->stats.rx_dropped++; + goto free_skb; +error: + net_dev->stats.rx_errors++; +free_skb: + dev_kfree_skb(skb); +} + +/* + * Assign a vcc to a dev + * Note: we do not have explicit unassign, but look at _push() + */ +static int br2684_regvcc(struct atm_vcc *atmvcc, void __user * arg) +{ + struct br2684_vcc *brvcc; + struct br2684_dev *brdev; + struct net_device *net_dev; + struct atm_backend_br2684 be; + int err; + + if (copy_from_user(&be, arg, sizeof be)) + return -EFAULT; + brvcc = kzalloc(sizeof(struct br2684_vcc), GFP_KERNEL); + if (!brvcc) + return -ENOMEM; + /* + * Allow two packets in the ATM queue. One actually being sent, and one + * for the ATM 'TX done' handler to send. It shouldn't take long to get + * the next one from the netdev queue, when we need it. More than that + * would be bufferbloat. + */ + atomic_set(&brvcc->qspace, 2); + write_lock_irq(&devs_lock); + net_dev = br2684_find_dev(&be.ifspec); + if (net_dev == NULL) { + pr_err("tried to attach to non-existent device\n"); + err = -ENXIO; + goto error; + } + brdev = BRPRIV(net_dev); + if (atmvcc->push == NULL) { + err = -EBADFD; + goto error; + } + if (!list_empty(&brdev->brvccs)) { + /* Only 1 VCC/dev right now */ + err = -EEXIST; + goto error; + } + if (be.fcs_in != BR2684_FCSIN_NO || + be.fcs_out != BR2684_FCSOUT_NO || + be.fcs_auto || be.has_vpiid || be.send_padding || + (be.encaps != BR2684_ENCAPS_VC && + be.encaps != BR2684_ENCAPS_LLC) || + be.min_size != 0) { + err = -EINVAL; + goto error; + } + pr_debug("vcc=%p, encaps=%d, brvcc=%p\n", atmvcc, be.encaps, brvcc); + if (list_empty(&brdev->brvccs) && !brdev->mac_was_set) { + unsigned char *esi = atmvcc->dev->esi; + const u8 one = 1; + + if (esi[0] | esi[1] | esi[2] | esi[3] | esi[4] | esi[5]) + dev_addr_set(net_dev, esi); + else + dev_addr_mod(net_dev, 2, &one, 1); + } + list_add(&brvcc->brvccs, &brdev->brvccs); + write_unlock_irq(&devs_lock); + brvcc->device = net_dev; + brvcc->atmvcc = atmvcc; + atmvcc->user_back = brvcc; + brvcc->encaps = (enum br2684_encaps)be.encaps; + brvcc->old_push = atmvcc->push; + brvcc->old_pop = atmvcc->pop; + brvcc->old_release_cb = atmvcc->release_cb; + brvcc->old_owner = atmvcc->owner; + barrier(); + atmvcc->push = br2684_push; + atmvcc->pop = br2684_pop; + atmvcc->release_cb = br2684_release_cb; + atmvcc->owner = THIS_MODULE; + + /* initialize netdev carrier state */ + if (atmvcc->dev->signal == ATM_PHY_SIG_LOST) + netif_carrier_off(net_dev); + else + netif_carrier_on(net_dev); + + __module_get(THIS_MODULE); + + /* re-process everything received between connection setup and + backend setup */ + vcc_process_recv_queue(atmvcc); + return 0; + +error: + write_unlock_irq(&devs_lock); + kfree(brvcc); + return err; +} + +static const struct net_device_ops br2684_netdev_ops = { + .ndo_start_xmit = br2684_start_xmit, + .ndo_set_mac_address = br2684_mac_addr, + .ndo_validate_addr = eth_validate_addr, +}; + +static const struct net_device_ops br2684_netdev_ops_routed = { + .ndo_start_xmit = br2684_start_xmit, + .ndo_set_mac_address = br2684_mac_addr, +}; + +static void br2684_setup(struct net_device *netdev) +{ + struct br2684_dev *brdev = BRPRIV(netdev); + + ether_setup(netdev); + netdev->hard_header_len += sizeof(llc_oui_pid_pad); /* worst case */ + brdev->net_dev = netdev; + + netdev->netdev_ops = 
&br2684_netdev_ops; + + INIT_LIST_HEAD(&brdev->brvccs); +} + +static void br2684_setup_routed(struct net_device *netdev) +{ + struct br2684_dev *brdev = BRPRIV(netdev); + + brdev->net_dev = netdev; + netdev->hard_header_len = sizeof(llc_oui_ipv4); /* worst case */ + netdev->netdev_ops = &br2684_netdev_ops_routed; + netdev->addr_len = 0; + netdev->mtu = ETH_DATA_LEN; + netdev->min_mtu = 0; + netdev->max_mtu = ETH_MAX_MTU; + netdev->type = ARPHRD_PPP; + netdev->flags = IFF_POINTOPOINT | IFF_NOARP | IFF_MULTICAST; + netdev->tx_queue_len = 100; + INIT_LIST_HEAD(&brdev->brvccs); +} + +static int br2684_create(void __user *arg) +{ + int err; + struct net_device *netdev; + struct br2684_dev *brdev; + struct atm_newif_br2684 ni; + enum br2684_payload payload; + + pr_debug("\n"); + + if (copy_from_user(&ni, arg, sizeof ni)) + return -EFAULT; + + if (ni.media & BR2684_FLAG_ROUTED) + payload = p_routed; + else + payload = p_bridged; + ni.media &= 0xffff; /* strip flags */ + + if (ni.media != BR2684_MEDIA_ETHERNET || ni.mtu != 1500) + return -EINVAL; + + netdev = alloc_netdev(sizeof(struct br2684_dev), + ni.ifname[0] ? ni.ifname : "nas%d", + NET_NAME_UNKNOWN, + (payload == p_routed) ? br2684_setup_routed : br2684_setup); + if (!netdev) + return -ENOMEM; + + brdev = BRPRIV(netdev); + + pr_debug("registered netdev %s\n", netdev->name); + /* open, stop, do_ioctl ? */ + err = register_netdev(netdev); + if (err < 0) { + pr_err("register_netdev failed\n"); + free_netdev(netdev); + return err; + } + + write_lock_irq(&devs_lock); + + brdev->payload = payload; + + if (list_empty(&br2684_devs)) { + /* 1st br2684 device */ + brdev->number = 1; + } else + brdev->number = BRPRIV(list_entry_brdev(br2684_devs.prev))->number + 1; + + list_add_tail(&brdev->br2684_devs, &br2684_devs); + write_unlock_irq(&devs_lock); + return 0; +} + +/* + * This handles ioctls actually performed on our vcc - we must return + * -ENOIOCTLCMD for any unrecognized ioctl + */ +static int br2684_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct atm_vcc *atmvcc = ATM_SD(sock); + void __user *argp = (void __user *)arg; + atm_backend_t b; + + int err; + switch (cmd) { + case ATM_SETBACKEND: + case ATM_NEWBACKENDIF: + err = get_user(b, (atm_backend_t __user *) argp); + if (err) + return -EFAULT; + if (b != ATM_BACKEND_BR2684) + return -ENOIOCTLCMD; + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (cmd == ATM_SETBACKEND) { + if (sock->state != SS_CONNECTED) + return -EINVAL; + return br2684_regvcc(atmvcc, argp); + } else { + return br2684_create(argp); + } +#ifdef CONFIG_ATM_BR2684_IPFILTER + case BR2684_SETFILT: + if (atmvcc->push != br2684_push) + return -ENOIOCTLCMD; + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + err = br2684_setfilt(atmvcc, argp); + + return err; +#endif /* CONFIG_ATM_BR2684_IPFILTER */ + } + return -ENOIOCTLCMD; +} + +static struct atm_ioctl br2684_ioctl_ops = { + .owner = THIS_MODULE, + .ioctl = br2684_ioctl, +}; + +#ifdef CONFIG_PROC_FS +static void *br2684_seq_start(struct seq_file *seq, loff_t * pos) + __acquires(devs_lock) +{ + read_lock(&devs_lock); + return seq_list_start(&br2684_devs, *pos); +} + +static void *br2684_seq_next(struct seq_file *seq, void *v, loff_t * pos) +{ + return seq_list_next(v, &br2684_devs, pos); +} + +static void br2684_seq_stop(struct seq_file *seq, void *v) + __releases(devs_lock) +{ + read_unlock(&devs_lock); +} + +static int br2684_seq_show(struct seq_file *seq, void *v) +{ + const struct br2684_dev *brdev = list_entry(v, struct br2684_dev, + br2684_devs); 
+ const struct net_device *net_dev = brdev->net_dev; + const struct br2684_vcc *brvcc; + + seq_printf(seq, "dev %.16s: num=%d, mac=%pM (%s)\n", + net_dev->name, + brdev->number, + net_dev->dev_addr, + brdev->mac_was_set ? "set" : "auto"); + + list_for_each_entry(brvcc, &brdev->brvccs, brvccs) { + seq_printf(seq, " vcc %d.%d.%d: encaps=%s payload=%s" + ", failed copies %u/%u" + "\n", brvcc->atmvcc->dev->number, + brvcc->atmvcc->vpi, brvcc->atmvcc->vci, + (brvcc->encaps == e_llc) ? "LLC" : "VC", + (brdev->payload == p_bridged) ? "bridged" : "routed", + brvcc->copies_failed, brvcc->copies_needed); +#ifdef CONFIG_ATM_BR2684_IPFILTER + if (brvcc->filter.netmask != 0) + seq_printf(seq, " filter=%pI4/%pI4\n", + &brvcc->filter.prefix, + &brvcc->filter.netmask); +#endif /* CONFIG_ATM_BR2684_IPFILTER */ + } + return 0; +} + +static const struct seq_operations br2684_seq_ops = { + .start = br2684_seq_start, + .next = br2684_seq_next, + .stop = br2684_seq_stop, + .show = br2684_seq_show, +}; + +extern struct proc_dir_entry *atm_proc_root; /* from proc.c */ +#endif /* CONFIG_PROC_FS */ + +static int __init br2684_init(void) +{ +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *p; + p = proc_create_seq("br2684", 0, atm_proc_root, &br2684_seq_ops); + if (p == NULL) + return -ENOMEM; +#endif + register_atm_ioctl(&br2684_ioctl_ops); + register_atmdevice_notifier(&atm_dev_notifier); + return 0; +} + +static void __exit br2684_exit(void) +{ + struct net_device *net_dev; + struct br2684_dev *brdev; + struct br2684_vcc *brvcc; + deregister_atm_ioctl(&br2684_ioctl_ops); + +#ifdef CONFIG_PROC_FS + remove_proc_entry("br2684", atm_proc_root); +#endif + + + unregister_atmdevice_notifier(&atm_dev_notifier); + + while (!list_empty(&br2684_devs)) { + net_dev = list_entry_brdev(br2684_devs.next); + brdev = BRPRIV(net_dev); + while (!list_empty(&brdev->brvccs)) { + brvcc = list_entry_brvcc(brdev->brvccs.next); + br2684_close_vcc(brvcc); + } + + list_del(&brdev->br2684_devs); + unregister_netdev(net_dev); + free_netdev(net_dev); + } +} + +module_init(br2684_init); +module_exit(br2684_exit); + +MODULE_AUTHOR("Marcell GAL"); +MODULE_DESCRIPTION("RFC2684 bridged protocols over ATM/AAL5"); +MODULE_LICENSE("GPL"); diff --git a/net/atm/clip.c b/net/atm/clip.c new file mode 100644 index 000000000..294cb9efe --- /dev/null +++ b/net/atm/clip.c @@ -0,0 +1,929 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/atm/clip.c - RFC1577 Classical IP over ATM */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include /* for UINT_MAX */ +#include +#include +#include +#include +#include +#include +#include /* for some manifest constants */ +#include +#include +#include +#include +#include +#include +#include /* for net/route.h */ +#include /* for struct sockaddr_in */ +#include /* for IFF_UP */ +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for struct rtable and routing */ +#include /* icmp_send */ +#include +#include /* for HZ */ +#include +#include /* for htons etc. 
*/ +#include + +#include "common.h" +#include "resources.h" +#include + +static struct net_device *clip_devs; +static struct atm_vcc *atmarpd; +static struct timer_list idle_timer; +static const struct neigh_ops clip_neigh_ops; + +static int to_atmarpd(enum atmarp_ctrl_type type, int itf, __be32 ip) +{ + struct sock *sk; + struct atmarp_ctrl *ctrl; + struct sk_buff *skb; + + pr_debug("(%d)\n", type); + if (!atmarpd) + return -EUNATCH; + skb = alloc_skb(sizeof(struct atmarp_ctrl), GFP_ATOMIC); + if (!skb) + return -ENOMEM; + ctrl = skb_put(skb, sizeof(struct atmarp_ctrl)); + ctrl->type = type; + ctrl->itf_num = itf; + ctrl->ip = ip; + atm_force_charge(atmarpd, skb->truesize); + + sk = sk_atm(atmarpd); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + return 0; +} + +static void link_vcc(struct clip_vcc *clip_vcc, struct atmarp_entry *entry) +{ + pr_debug("%p to entry %p (neigh %p)\n", clip_vcc, entry, entry->neigh); + clip_vcc->entry = entry; + clip_vcc->xoff = 0; /* @@@ may overrun buffer by one packet */ + clip_vcc->next = entry->vccs; + entry->vccs = clip_vcc; + entry->neigh->used = jiffies; +} + +static void unlink_clip_vcc(struct clip_vcc *clip_vcc) +{ + struct atmarp_entry *entry = clip_vcc->entry; + struct clip_vcc **walk; + + if (!entry) { + pr_err("!clip_vcc->entry (clip_vcc %p)\n", clip_vcc); + return; + } + netif_tx_lock_bh(entry->neigh->dev); /* block clip_start_xmit() */ + entry->neigh->used = jiffies; + for (walk = &entry->vccs; *walk; walk = &(*walk)->next) + if (*walk == clip_vcc) { + int error; + + *walk = clip_vcc->next; /* atomic */ + clip_vcc->entry = NULL; + if (clip_vcc->xoff) + netif_wake_queue(entry->neigh->dev); + if (entry->vccs) + goto out; + entry->expires = jiffies - 1; + /* force resolution or expiration */ + error = neigh_update(entry->neigh, NULL, NUD_NONE, + NEIGH_UPDATE_F_ADMIN, 0); + if (error) + pr_err("neigh_update failed with %d\n", error); + goto out; + } + pr_err("ATMARP: failed (entry %p, vcc 0x%p)\n", entry, clip_vcc); +out: + netif_tx_unlock_bh(entry->neigh->dev); +} + +/* The neighbour entry n->lock is held. 
*/ +static int neigh_check_cb(struct neighbour *n) +{ + struct atmarp_entry *entry = neighbour_priv(n); + struct clip_vcc *cv; + + if (n->ops != &clip_neigh_ops) + return 0; + for (cv = entry->vccs; cv; cv = cv->next) { + unsigned long exp = cv->last_use + cv->idle_timeout; + + if (cv->idle_timeout && time_after(jiffies, exp)) { + pr_debug("releasing vcc %p->%p of entry %p\n", + cv, cv->vcc, entry); + vcc_release_async(cv->vcc, -ETIMEDOUT); + } + } + + if (entry->vccs || time_before(jiffies, entry->expires)) + return 0; + + if (refcount_read(&n->refcnt) > 1) { + struct sk_buff *skb; + + pr_debug("destruction postponed with ref %d\n", + refcount_read(&n->refcnt)); + + while ((skb = skb_dequeue(&n->arp_queue)) != NULL) + dev_kfree_skb(skb); + + return 0; + } + + pr_debug("expired neigh %p\n", n); + return 1; +} + +static void idle_timer_check(struct timer_list *unused) +{ + write_lock(&arp_tbl.lock); + __neigh_for_each_release(&arp_tbl, neigh_check_cb); + mod_timer(&idle_timer, jiffies + CLIP_CHECK_INTERVAL * HZ); + write_unlock(&arp_tbl.lock); +} + +static int clip_arp_rcv(struct sk_buff *skb) +{ + struct atm_vcc *vcc; + + pr_debug("\n"); + vcc = ATM_SKB(skb)->vcc; + if (!vcc || !atm_charge(vcc, skb->truesize)) { + dev_kfree_skb_any(skb); + return 0; + } + pr_debug("pushing to %p\n", vcc); + pr_debug("using %p\n", CLIP_VCC(vcc)->old_push); + CLIP_VCC(vcc)->old_push(vcc, skb); + return 0; +} + +static const unsigned char llc_oui[] = { + 0xaa, /* DSAP: non-ISO */ + 0xaa, /* SSAP: non-ISO */ + 0x03, /* Ctrl: Unnumbered Information Command PDU */ + 0x00, /* OUI: EtherType */ + 0x00, + 0x00 +}; + +static void clip_push(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct clip_vcc *clip_vcc = CLIP_VCC(vcc); + + pr_debug("\n"); + + if (!clip_devs) { + atm_return(vcc, skb->truesize); + kfree_skb(skb); + return; + } + + if (!skb) { + pr_debug("removing VCC %p\n", clip_vcc); + if (clip_vcc->entry) + unlink_clip_vcc(clip_vcc); + clip_vcc->old_push(vcc, NULL); /* pass on the bad news */ + kfree(clip_vcc); + return; + } + atm_return(vcc, skb->truesize); + skb->dev = clip_vcc->entry ? clip_vcc->entry->neigh->dev : clip_devs; + /* clip_vcc->entry == NULL if we don't have an IP address yet */ + if (!skb->dev) { + dev_kfree_skb_any(skb); + return; + } + ATM_SKB(skb)->vcc = vcc; + skb_reset_mac_header(skb); + if (!clip_vcc->encap || + skb->len < RFC1483LLC_LEN || + memcmp(skb->data, llc_oui, sizeof(llc_oui))) + skb->protocol = htons(ETH_P_IP); + else { + skb->protocol = ((__be16 *)skb->data)[3]; + skb_pull(skb, RFC1483LLC_LEN); + if (skb->protocol == htons(ETH_P_ARP)) { + skb->dev->stats.rx_packets++; + skb->dev->stats.rx_bytes += skb->len; + clip_arp_rcv(skb); + return; + } + } + clip_vcc->last_use = jiffies; + skb->dev->stats.rx_packets++; + skb->dev->stats.rx_bytes += skb->len; + memset(ATM_SKB(skb), 0, sizeof(struct atm_skb_data)); + netif_rx(skb); +} + +/* + * Note: these spinlocks _must_not_ block on non-SMP. The only goal is that + * clip_pop is atomic with respect to the critical section in clip_start_xmit. 
+ */ + +static void clip_pop(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct clip_vcc *clip_vcc = CLIP_VCC(vcc); + struct net_device *dev = skb->dev; + int old; + unsigned long flags; + + pr_debug("(vcc %p)\n", vcc); + clip_vcc->old_pop(vcc, skb); + /* skb->dev == NULL in outbound ARP packets */ + if (!dev) + return; + spin_lock_irqsave(&PRIV(dev)->xoff_lock, flags); + if (atm_may_send(vcc, 0)) { + old = xchg(&clip_vcc->xoff, 0); + if (old) + netif_wake_queue(dev); + } + spin_unlock_irqrestore(&PRIV(dev)->xoff_lock, flags); +} + +static void clip_neigh_solicit(struct neighbour *neigh, struct sk_buff *skb) +{ + __be32 *ip = (__be32 *) neigh->primary_key; + + pr_debug("(neigh %p, skb %p)\n", neigh, skb); + to_atmarpd(act_need, PRIV(neigh->dev)->number, *ip); +} + +static void clip_neigh_error(struct neighbour *neigh, struct sk_buff *skb) +{ +#ifndef CONFIG_ATM_CLIP_NO_ICMP + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0); +#endif + kfree_skb(skb); +} + +static const struct neigh_ops clip_neigh_ops = { + .family = AF_INET, + .solicit = clip_neigh_solicit, + .error_report = clip_neigh_error, + .output = neigh_direct_output, + .connected_output = neigh_direct_output, +}; + +static int clip_constructor(struct net_device *dev, struct neighbour *neigh) +{ + struct atmarp_entry *entry = neighbour_priv(neigh); + + if (neigh->tbl->family != AF_INET) + return -EINVAL; + + if (neigh->type != RTN_UNICAST) + return -EINVAL; + + neigh->nud_state = NUD_NONE; + neigh->ops = &clip_neigh_ops; + neigh->output = neigh->ops->output; + entry->neigh = neigh; + entry->vccs = NULL; + entry->expires = jiffies - 1; + + return 0; +} + +/* @@@ copy bh locking from arp.c -- need to bh-enable atm code before */ + +/* + * We play with the resolve flag: 0 and 1 have the usual meaning, but -1 means + * to allocate the neighbour entry but not to ask atmarpd for resolution. Also, + * don't increment the usage count. This is used to create entries in + * clip_setentry. 
+ */ + +static int clip_encap(struct atm_vcc *vcc, int mode) +{ + if (!CLIP_VCC(vcc)) + return -EBADFD; + + CLIP_VCC(vcc)->encap = mode; + return 0; +} + +static netdev_tx_t clip_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct clip_priv *clip_priv = PRIV(dev); + struct dst_entry *dst = skb_dst(skb); + struct atmarp_entry *entry; + struct neighbour *n; + struct atm_vcc *vcc; + struct rtable *rt; + __be32 *daddr; + int old; + unsigned long flags; + + pr_debug("(skb %p)\n", skb); + if (!dst) { + pr_err("skb_dst(skb) == NULL\n"); + dev_kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; + } + rt = (struct rtable *) dst; + if (rt->rt_gw_family == AF_INET) + daddr = &rt->rt_gw4; + else + daddr = &ip_hdr(skb)->daddr; + n = dst_neigh_lookup(dst, daddr); + if (!n) { + pr_err("NO NEIGHBOUR !\n"); + dev_kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; + } + entry = neighbour_priv(n); + if (!entry->vccs) { + if (time_after(jiffies, entry->expires)) { + /* should be resolved */ + entry->expires = jiffies + ATMARP_RETRY_DELAY * HZ; + to_atmarpd(act_need, PRIV(dev)->number, *((__be32 *)n->primary_key)); + } + if (entry->neigh->arp_queue.qlen < ATMARP_MAX_UNRES_PACKETS) + skb_queue_tail(&entry->neigh->arp_queue, skb); + else { + dev_kfree_skb(skb); + dev->stats.tx_dropped++; + } + goto out_release_neigh; + } + pr_debug("neigh %p, vccs %p\n", entry, entry->vccs); + ATM_SKB(skb)->vcc = vcc = entry->vccs->vcc; + pr_debug("using neighbour %p, vcc %p\n", n, vcc); + if (entry->vccs->encap) { + void *here; + + here = skb_push(skb, RFC1483LLC_LEN); + memcpy(here, llc_oui, sizeof(llc_oui)); + ((__be16 *) here)[3] = skb->protocol; + } + atm_account_tx(vcc, skb); + entry->vccs->last_use = jiffies; + pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n", skb, vcc, vcc->dev); + old = xchg(&entry->vccs->xoff, 1); /* assume XOFF ... */ + if (old) { + pr_warn("XOFF->XOFF transition\n"); + goto out_release_neigh; + } + dev->stats.tx_packets++; + dev->stats.tx_bytes += skb->len; + vcc->send(vcc, skb); + if (atm_may_send(vcc, 0)) { + entry->vccs->xoff = 0; + goto out_release_neigh; + } + spin_lock_irqsave(&clip_priv->xoff_lock, flags); + netif_stop_queue(dev); /* XOFF -> throttle immediately */ + barrier(); + if (!entry->vccs->xoff) + netif_start_queue(dev); + /* Oh, we just raced with clip_pop. netif_start_queue should be + good enough, because nothing should really be asleep because + of the brief netif_stop_queue. If this isn't true or if it + changes, use netif_wake_queue instead. 
*/ + spin_unlock_irqrestore(&clip_priv->xoff_lock, flags); +out_release_neigh: + neigh_release(n); + return NETDEV_TX_OK; +} + +static int clip_mkip(struct atm_vcc *vcc, int timeout) +{ + struct clip_vcc *clip_vcc; + + if (!vcc->push) + return -EBADFD; + clip_vcc = kmalloc(sizeof(struct clip_vcc), GFP_KERNEL); + if (!clip_vcc) + return -ENOMEM; + pr_debug("%p vcc %p\n", clip_vcc, vcc); + clip_vcc->vcc = vcc; + vcc->user_back = clip_vcc; + set_bit(ATM_VF_IS_CLIP, &vcc->flags); + clip_vcc->entry = NULL; + clip_vcc->xoff = 0; + clip_vcc->encap = 1; + clip_vcc->last_use = jiffies; + clip_vcc->idle_timeout = timeout * HZ; + clip_vcc->old_push = vcc->push; + clip_vcc->old_pop = vcc->pop; + vcc->push = clip_push; + vcc->pop = clip_pop; + + /* re-process everything received between connection setup and MKIP */ + vcc_process_recv_queue(vcc); + + return 0; +} + +static int clip_setentry(struct atm_vcc *vcc, __be32 ip) +{ + struct neighbour *neigh; + struct atmarp_entry *entry; + int error; + struct clip_vcc *clip_vcc; + struct rtable *rt; + + if (vcc->push != clip_push) { + pr_warn("non-CLIP VCC\n"); + return -EBADF; + } + clip_vcc = CLIP_VCC(vcc); + if (!ip) { + if (!clip_vcc->entry) { + pr_err("hiding hidden ATMARP entry\n"); + return 0; + } + pr_debug("remove\n"); + unlink_clip_vcc(clip_vcc); + return 0; + } + rt = ip_route_output(&init_net, ip, 0, 1, 0); + if (IS_ERR(rt)) + return PTR_ERR(rt); + neigh = __neigh_lookup(&arp_tbl, &ip, rt->dst.dev, 1); + ip_rt_put(rt); + if (!neigh) + return -ENOMEM; + entry = neighbour_priv(neigh); + if (entry != clip_vcc->entry) { + if (!clip_vcc->entry) + pr_debug("add\n"); + else { + pr_debug("update\n"); + unlink_clip_vcc(clip_vcc); + } + link_vcc(clip_vcc, entry); + } + error = neigh_update(neigh, llc_oui, NUD_PERMANENT, + NEIGH_UPDATE_F_OVERRIDE | NEIGH_UPDATE_F_ADMIN, 0); + neigh_release(neigh); + return error; +} + +static const struct net_device_ops clip_netdev_ops = { + .ndo_start_xmit = clip_start_xmit, + .ndo_neigh_construct = clip_constructor, +}; + +static void clip_setup(struct net_device *dev) +{ + dev->netdev_ops = &clip_netdev_ops; + dev->type = ARPHRD_ATM; + dev->neigh_priv_len = sizeof(struct atmarp_entry); + dev->hard_header_len = RFC1483LLC_LEN; + dev->mtu = RFC1626_MTU; + dev->tx_queue_len = 100; /* "normal" queue (packets) */ + /* When using a "real" qdisc, the qdisc determines the queue */ + /* length. tx_queue_len is only used for the default case, */ + /* without any more elaborate queuing. 100 is a reasonable */ + /* compromise between decent burst-tolerance and protection */ + /* against memory hogs. 
*/ + netif_keep_dst(dev); +} + +static int clip_create(int number) +{ + struct net_device *dev; + struct clip_priv *clip_priv; + int error; + + if (number != -1) { + for (dev = clip_devs; dev; dev = PRIV(dev)->next) + if (PRIV(dev)->number == number) + return -EEXIST; + } else { + number = 0; + for (dev = clip_devs; dev; dev = PRIV(dev)->next) + if (PRIV(dev)->number >= number) + number = PRIV(dev)->number + 1; + } + dev = alloc_netdev(sizeof(struct clip_priv), "", NET_NAME_UNKNOWN, + clip_setup); + if (!dev) + return -ENOMEM; + clip_priv = PRIV(dev); + sprintf(dev->name, "atm%d", number); + spin_lock_init(&clip_priv->xoff_lock); + clip_priv->number = number; + error = register_netdev(dev); + if (error) { + free_netdev(dev); + return error; + } + clip_priv->next = clip_devs; + clip_devs = dev; + pr_debug("registered (net:%s)\n", dev->name); + return number; +} + +static int clip_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (event == NETDEV_UNREGISTER) + return NOTIFY_DONE; + + /* ignore non-CLIP devices */ + if (dev->type != ARPHRD_ATM || dev->netdev_ops != &clip_netdev_ops) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UP: + pr_debug("NETDEV_UP\n"); + to_atmarpd(act_up, PRIV(dev)->number, 0); + break; + case NETDEV_GOING_DOWN: + pr_debug("NETDEV_DOWN\n"); + to_atmarpd(act_down, PRIV(dev)->number, 0); + break; + case NETDEV_CHANGE: + case NETDEV_CHANGEMTU: + pr_debug("NETDEV_CHANGE*\n"); + to_atmarpd(act_change, PRIV(dev)->number, 0); + break; + } + return NOTIFY_DONE; +} + +static int clip_inet_event(struct notifier_block *this, unsigned long event, + void *ifa) +{ + struct in_device *in_dev; + struct netdev_notifier_info info; + + in_dev = ((struct in_ifaddr *)ifa)->ifa_dev; + /* + * Transitions are of the down-change-up type, so it's sufficient to + * handle the change on up. 
+ */ + if (event != NETDEV_UP) + return NOTIFY_DONE; + netdev_notifier_info_init(&info, in_dev->dev); + return clip_device_event(this, NETDEV_CHANGE, &info); +} + +static struct notifier_block clip_dev_notifier = { + .notifier_call = clip_device_event, +}; + + + +static struct notifier_block clip_inet_notifier = { + .notifier_call = clip_inet_event, +}; + + + +static void atmarpd_close(struct atm_vcc *vcc) +{ + pr_debug("\n"); + + rtnl_lock(); + atmarpd = NULL; + skb_queue_purge(&sk_atm(vcc)->sk_receive_queue); + rtnl_unlock(); + + pr_debug("(done)\n"); + module_put(THIS_MODULE); +} + +static const struct atmdev_ops atmarpd_dev_ops = { + .close = atmarpd_close +}; + + +static struct atm_dev atmarpd_dev = { + .ops = &atmarpd_dev_ops, + .type = "arpd", + .number = 999, + .lock = __SPIN_LOCK_UNLOCKED(atmarpd_dev.lock) +}; + + +static int atm_init_atmarp(struct atm_vcc *vcc) +{ + rtnl_lock(); + if (atmarpd) { + rtnl_unlock(); + return -EADDRINUSE; + } + + mod_timer(&idle_timer, jiffies + CLIP_CHECK_INTERVAL * HZ); + + atmarpd = vcc; + set_bit(ATM_VF_META, &vcc->flags); + set_bit(ATM_VF_READY, &vcc->flags); + /* allow replies and avoid getting closed if signaling dies */ + vcc->dev = &atmarpd_dev; + vcc_insert_socket(sk_atm(vcc)); + vcc->push = NULL; + vcc->pop = NULL; /* crash */ + vcc->push_oam = NULL; /* crash */ + rtnl_unlock(); + return 0; +} + +static int clip_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct atm_vcc *vcc = ATM_SD(sock); + int err = 0; + + switch (cmd) { + case SIOCMKCLIP: + case ATMARPD_CTRL: + case ATMARP_MKIP: + case ATMARP_SETENTRY: + case ATMARP_ENCAP: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + break; + default: + return -ENOIOCTLCMD; + } + + switch (cmd) { + case SIOCMKCLIP: + err = clip_create(arg); + break; + case ATMARPD_CTRL: + err = atm_init_atmarp(vcc); + if (!err) { + sock->state = SS_CONNECTED; + __module_get(THIS_MODULE); + } + break; + case ATMARP_MKIP: + err = clip_mkip(vcc, arg); + break; + case ATMARP_SETENTRY: + err = clip_setentry(vcc, (__force __be32)arg); + break; + case ATMARP_ENCAP: + err = clip_encap(vcc, arg); + break; + } + return err; +} + +static struct atm_ioctl clip_ioctl_ops = { + .owner = THIS_MODULE, + .ioctl = clip_ioctl, +}; + +#ifdef CONFIG_PROC_FS + +static void svc_addr(struct seq_file *seq, struct sockaddr_atmsvc *addr) +{ + static int code[] = { 1, 2, 10, 6, 1, 0 }; + static int e164[] = { 1, 8, 4, 6, 1, 0 }; + + if (*addr->sas_addr.pub) { + seq_printf(seq, "%s", addr->sas_addr.pub); + if (*addr->sas_addr.prv) + seq_putc(seq, '+'); + } else if (!*addr->sas_addr.prv) { + seq_printf(seq, "%s", "(none)"); + return; + } + if (*addr->sas_addr.prv) { + unsigned char *prv = addr->sas_addr.prv; + int *fields; + int i, j; + + fields = *prv == ATM_AFI_E164 ? e164 : code; + for (i = 0; fields[i]; i++) { + for (j = fields[i]; j; j--) + seq_printf(seq, "%02X", *prv++); + if (fields[i + 1]) + seq_putc(seq, '.'); + } + } +} + +/* This means the neighbour entry has no attached VCC objects. 
*/ +#define SEQ_NO_VCC_TOKEN ((void *) 2) + +static void atmarp_info(struct seq_file *seq, struct neighbour *n, + struct atmarp_entry *entry, struct clip_vcc *clip_vcc) +{ + struct net_device *dev = n->dev; + unsigned long exp; + char buf[17]; + int svc, llc, off; + + svc = ((clip_vcc == SEQ_NO_VCC_TOKEN) || + (sk_atm(clip_vcc->vcc)->sk_family == AF_ATMSVC)); + + llc = ((clip_vcc == SEQ_NO_VCC_TOKEN) || clip_vcc->encap); + + if (clip_vcc == SEQ_NO_VCC_TOKEN) + exp = entry->neigh->used; + else + exp = clip_vcc->last_use; + + exp = (jiffies - exp) / HZ; + + seq_printf(seq, "%-6s%-4s%-4s%5ld ", + dev->name, svc ? "SVC" : "PVC", llc ? "LLC" : "NULL", exp); + + off = scnprintf(buf, sizeof(buf) - 1, "%pI4", n->primary_key); + while (off < 16) + buf[off++] = ' '; + buf[off] = '\0'; + seq_printf(seq, "%s", buf); + + if (clip_vcc == SEQ_NO_VCC_TOKEN) { + if (time_before(jiffies, entry->expires)) + seq_printf(seq, "(resolving)\n"); + else + seq_printf(seq, "(expired, ref %d)\n", + refcount_read(&entry->neigh->refcnt)); + } else if (!svc) { + seq_printf(seq, "%d.%d.%d\n", + clip_vcc->vcc->dev->number, + clip_vcc->vcc->vpi, clip_vcc->vcc->vci); + } else { + svc_addr(seq, &clip_vcc->vcc->remote); + seq_putc(seq, '\n'); + } +} + +struct clip_seq_state { + /* This member must be first. */ + struct neigh_seq_state ns; + + /* Local to clip specific iteration. */ + struct clip_vcc *vcc; +}; + +static struct clip_vcc *clip_seq_next_vcc(struct atmarp_entry *e, + struct clip_vcc *curr) +{ + if (!curr) { + curr = e->vccs; + if (!curr) + return SEQ_NO_VCC_TOKEN; + return curr; + } + if (curr == SEQ_NO_VCC_TOKEN) + return NULL; + + curr = curr->next; + + return curr; +} + +static void *clip_seq_vcc_walk(struct clip_seq_state *state, + struct atmarp_entry *e, loff_t * pos) +{ + struct clip_vcc *vcc = state->vcc; + + vcc = clip_seq_next_vcc(e, vcc); + if (vcc && pos != NULL) { + while (*pos) { + vcc = clip_seq_next_vcc(e, vcc); + if (!vcc) + break; + --(*pos); + } + } + state->vcc = vcc; + + return vcc; +} + +static void *clip_seq_sub_iter(struct neigh_seq_state *_state, + struct neighbour *n, loff_t * pos) +{ + struct clip_seq_state *state = (struct clip_seq_state *)_state; + + if (n->dev->type != ARPHRD_ATM) + return NULL; + + return clip_seq_vcc_walk(state, neighbour_priv(n), pos); +} + +static void *clip_seq_start(struct seq_file *seq, loff_t * pos) +{ + struct clip_seq_state *state = seq->private; + state->ns.neigh_sub_iter = clip_seq_sub_iter; + return neigh_seq_start(seq, pos, &arp_tbl, NEIGH_SEQ_NEIGH_ONLY); +} + +static int clip_seq_show(struct seq_file *seq, void *v) +{ + static char atm_arp_banner[] = + "IPitf TypeEncp Idle IP address ATM address\n"; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, atm_arp_banner); + } else { + struct clip_seq_state *state = seq->private; + struct clip_vcc *vcc = state->vcc; + struct neighbour *n = v; + + atmarp_info(seq, n, neighbour_priv(n), vcc); + } + return 0; +} + +static const struct seq_operations arp_seq_ops = { + .start = clip_seq_start, + .next = neigh_seq_next, + .stop = neigh_seq_stop, + .show = clip_seq_show, +}; +#endif + +static void atm_clip_exit_noproc(void); + +static int __init atm_clip_init(void) +{ + register_atm_ioctl(&clip_ioctl_ops); + register_netdevice_notifier(&clip_dev_notifier); + register_inetaddr_notifier(&clip_inet_notifier); + + timer_setup(&idle_timer, idle_timer_check, 0); + +#ifdef CONFIG_PROC_FS + { + struct proc_dir_entry *p; + + p = proc_create_net("arp", 0444, atm_proc_root, &arp_seq_ops, + sizeof(struct clip_seq_state)); + if (!p) 
{ + pr_err("Unable to initialize /proc/net/atm/arp\n"); + atm_clip_exit_noproc(); + return -ENOMEM; + } + } +#endif + + return 0; +} + +static void atm_clip_exit_noproc(void) +{ + struct net_device *dev, *next; + + unregister_inetaddr_notifier(&clip_inet_notifier); + unregister_netdevice_notifier(&clip_dev_notifier); + + deregister_atm_ioctl(&clip_ioctl_ops); + + /* First, stop the idle timer, so it stops banging + * on the table. + */ + del_timer_sync(&idle_timer); + + dev = clip_devs; + while (dev) { + next = PRIV(dev)->next; + unregister_netdev(dev); + free_netdev(dev); + dev = next; + } +} + +static void __exit atm_clip_exit(void) +{ + remove_proc_entry("arp", atm_proc_root); + + atm_clip_exit_noproc(); +} + +module_init(atm_clip_init); +module_exit(atm_clip_exit); +MODULE_AUTHOR("Werner Almesberger"); +MODULE_DESCRIPTION("Classical/IP over ATM interface"); +MODULE_LICENSE("GPL"); diff --git a/net/atm/common.c b/net/atm/common.c new file mode 100644 index 000000000..f7019df41 --- /dev/null +++ b/net/atm/common.c @@ -0,0 +1,895 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/atm/common.c - ATM sockets (common part for PVC and SVC) */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include /* struct socket, struct proto_ops */ +#include /* ATM stuff */ +#include +#include /* SOL_SOCKET */ +#include /* error codes */ +#include +#include +#include +#include /* 64-bit time for seconds */ +#include +#include +#include +#include +#include /* struct sock */ +#include +#include + +#include + +#include "resources.h" /* atm_find_dev */ +#include "common.h" /* prototypes */ +#include "protocols.h" /* atm_init_ */ +#include "addr.h" /* address registry */ +#include "signaling.h" /* for WAITING and sigd_attach */ + +struct hlist_head vcc_hash[VCC_HTABLE_SIZE]; +EXPORT_SYMBOL(vcc_hash); + +DEFINE_RWLOCK(vcc_sklist_lock); +EXPORT_SYMBOL(vcc_sklist_lock); + +static ATOMIC_NOTIFIER_HEAD(atm_dev_notify_chain); + +static void __vcc_insert_socket(struct sock *sk) +{ + struct atm_vcc *vcc = atm_sk(sk); + struct hlist_head *head = &vcc_hash[vcc->vci & (VCC_HTABLE_SIZE - 1)]; + sk->sk_hash = vcc->vci & (VCC_HTABLE_SIZE - 1); + sk_add_node(sk, head); +} + +void vcc_insert_socket(struct sock *sk) +{ + write_lock_irq(&vcc_sklist_lock); + __vcc_insert_socket(sk); + write_unlock_irq(&vcc_sklist_lock); +} +EXPORT_SYMBOL(vcc_insert_socket); + +static void vcc_remove_socket(struct sock *sk) +{ + write_lock_irq(&vcc_sklist_lock); + sk_del_node_init(sk); + write_unlock_irq(&vcc_sklist_lock); +} + +static bool vcc_tx_ready(struct atm_vcc *vcc, unsigned int size) +{ + struct sock *sk = sk_atm(vcc); + + if (sk_wmem_alloc_get(sk) && !atm_may_send(vcc, size)) { + pr_debug("Sorry: wmem_alloc = %d, size = %d, sndbuf = %d\n", + sk_wmem_alloc_get(sk), size, sk->sk_sndbuf); + return false; + } + return true; +} + +static void vcc_sock_destruct(struct sock *sk) +{ + if (atomic_read(&sk->sk_rmem_alloc)) + printk(KERN_DEBUG "%s: rmem leakage (%d bytes) detected.\n", + __func__, atomic_read(&sk->sk_rmem_alloc)); + + if (refcount_read(&sk->sk_wmem_alloc)) + printk(KERN_DEBUG "%s: wmem leakage (%d bytes) detected.\n", + __func__, refcount_read(&sk->sk_wmem_alloc)); +} + +static void vcc_def_wakeup(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up(&wq->wait); + rcu_read_unlock(); +} + +static inline int vcc_writable(struct sock *sk) +{ + struct 
atm_vcc *vcc = atm_sk(sk); + + return (vcc->qos.txtp.max_sdu + + refcount_read(&sk->sk_wmem_alloc)) <= sk->sk_sndbuf; +} + +static void vcc_write_space(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + + if (vcc_writable(sk)) { + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible(&wq->wait); + + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + } + + rcu_read_unlock(); +} + +static void vcc_release_cb(struct sock *sk) +{ + struct atm_vcc *vcc = atm_sk(sk); + + if (vcc->release_cb) + vcc->release_cb(vcc); +} + +static struct proto vcc_proto = { + .name = "VCC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct atm_vcc), + .release_cb = vcc_release_cb, +}; + +int vcc_create(struct net *net, struct socket *sock, int protocol, int family, int kern) +{ + struct sock *sk; + struct atm_vcc *vcc; + + sock->sk = NULL; + if (sock->type == SOCK_STREAM) + return -EINVAL; + sk = sk_alloc(net, family, GFP_KERNEL, &vcc_proto, kern); + if (!sk) + return -ENOMEM; + sock_init_data(sock, sk); + sk->sk_state_change = vcc_def_wakeup; + sk->sk_write_space = vcc_write_space; + + vcc = atm_sk(sk); + vcc->dev = NULL; + memset(&vcc->local, 0, sizeof(struct sockaddr_atmsvc)); + memset(&vcc->remote, 0, sizeof(struct sockaddr_atmsvc)); + vcc->qos.txtp.max_sdu = 1 << 16; /* for meta VCs */ + refcount_set(&sk->sk_wmem_alloc, 1); + atomic_set(&sk->sk_rmem_alloc, 0); + vcc->push = NULL; + vcc->pop = NULL; + vcc->owner = NULL; + vcc->push_oam = NULL; + vcc->release_cb = NULL; + vcc->vpi = vcc->vci = 0; /* no VCI/VPI yet */ + vcc->atm_options = vcc->aal_options = 0; + sk->sk_destruct = vcc_sock_destruct; + return 0; +} + +static void vcc_destroy_socket(struct sock *sk) +{ + struct atm_vcc *vcc = atm_sk(sk); + struct sk_buff *skb; + + set_bit(ATM_VF_CLOSE, &vcc->flags); + clear_bit(ATM_VF_READY, &vcc->flags); + if (vcc->dev && vcc->dev->ops->close) + vcc->dev->ops->close(vcc); + if (vcc->push) + vcc->push(vcc, NULL); /* atmarpd has no push */ + module_put(vcc->owner); + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + atm_return(vcc, skb->truesize); + kfree_skb(skb); + } + + if (vcc->dev && vcc->dev->ops->owner) { + module_put(vcc->dev->ops->owner); + atm_dev_put(vcc->dev); + } + + vcc_remove_socket(sk); +} + +int vcc_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + lock_sock(sk); + vcc_destroy_socket(sock->sk); + release_sock(sk); + sock_put(sk); + } + + return 0; +} + +void vcc_release_async(struct atm_vcc *vcc, int reply) +{ + struct sock *sk = sk_atm(vcc); + + set_bit(ATM_VF_CLOSE, &vcc->flags); + sk->sk_shutdown |= RCV_SHUTDOWN; + sk->sk_err = -reply; + clear_bit(ATM_VF_WAITING, &vcc->flags); + sk->sk_state_change(sk); +} +EXPORT_SYMBOL(vcc_release_async); + +void vcc_process_recv_queue(struct atm_vcc *vcc) +{ + struct sk_buff_head queue, *rq; + struct sk_buff *skb, *tmp; + unsigned long flags; + + __skb_queue_head_init(&queue); + rq = &sk_atm(vcc)->sk_receive_queue; + + spin_lock_irqsave(&rq->lock, flags); + skb_queue_splice_init(rq, &queue); + spin_unlock_irqrestore(&rq->lock, flags); + + skb_queue_walk_safe(&queue, skb, tmp) { + __skb_unlink(skb, &queue); + vcc->push(vcc, skb); + } +} +EXPORT_SYMBOL(vcc_process_recv_queue); + +void atm_dev_signal_change(struct atm_dev *dev, char signal) +{ + pr_debug("%s signal=%d dev=%p number=%d dev->signal=%d\n", + __func__, signal, dev, dev->number, dev->signal); + + /* atm driver sending invalid signal */ + WARN_ON(signal < ATM_PHY_SIG_LOST || signal > ATM_PHY_SIG_FOUND); + + if 
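/*
 * Consumers track carrier state by registering on atm_dev_notify_chain
 * through register_atmdevice_notifier(); the chain is run further down
 * with the new signal value and the struct atm_dev as notifier data.
 * A minimal listener (illustrative sketch only, not part of this file)
 * could look like:
 *
 *	static int my_signal_event(struct notifier_block *nb,
 *				   unsigned long signal, void *data)
 *	{
 *		struct atm_dev *dev = data;
 *
 *		pr_info("atm%d: signal is now %lu\n", dev->number, signal);
 *		return NOTIFY_DONE;
 *	}
 *	static struct notifier_block my_nb = { .notifier_call = my_signal_event };
 *	...
 *	register_atmdevice_notifier(&my_nb);
 */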
(dev->signal == signal) + return; /* no change */ + + dev->signal = signal; + + atomic_notifier_call_chain(&atm_dev_notify_chain, signal, dev); +} +EXPORT_SYMBOL(atm_dev_signal_change); + +void atm_dev_release_vccs(struct atm_dev *dev) +{ + int i; + + write_lock_irq(&vcc_sklist_lock); + for (i = 0; i < VCC_HTABLE_SIZE; i++) { + struct hlist_head *head = &vcc_hash[i]; + struct hlist_node *tmp; + struct sock *s; + struct atm_vcc *vcc; + + sk_for_each_safe(s, tmp, head) { + vcc = atm_sk(s); + if (vcc->dev == dev) { + vcc_release_async(vcc, -EPIPE); + sk_del_node_init(s); + } + } + } + write_unlock_irq(&vcc_sklist_lock); +} +EXPORT_SYMBOL(atm_dev_release_vccs); + +static int adjust_tp(struct atm_trafprm *tp, unsigned char aal) +{ + int max_sdu; + + if (!tp->traffic_class) + return 0; + switch (aal) { + case ATM_AAL0: + max_sdu = ATM_CELL_SIZE-1; + break; + case ATM_AAL34: + max_sdu = ATM_MAX_AAL34_PDU; + break; + default: + pr_warn("AAL problems ... (%d)\n", aal); + fallthrough; + case ATM_AAL5: + max_sdu = ATM_MAX_AAL5_PDU; + } + if (!tp->max_sdu) + tp->max_sdu = max_sdu; + else if (tp->max_sdu > max_sdu) + return -EINVAL; + if (!tp->max_cdv) + tp->max_cdv = ATM_MAX_CDV; + return 0; +} + +static int check_ci(const struct atm_vcc *vcc, short vpi, int vci) +{ + struct hlist_head *head = &vcc_hash[vci & (VCC_HTABLE_SIZE - 1)]; + struct sock *s; + struct atm_vcc *walk; + + sk_for_each(s, head) { + walk = atm_sk(s); + if (walk->dev != vcc->dev) + continue; + if (test_bit(ATM_VF_ADDR, &walk->flags) && walk->vpi == vpi && + walk->vci == vci && ((walk->qos.txtp.traffic_class != + ATM_NONE && vcc->qos.txtp.traffic_class != ATM_NONE) || + (walk->qos.rxtp.traffic_class != ATM_NONE && + vcc->qos.rxtp.traffic_class != ATM_NONE))) + return -EADDRINUSE; + } + + /* allow VCCs with same VPI/VCI iff they don't collide on + TX/RX (but we may refuse such sharing for other reasons, + e.g. 
if protocol requires to have both channels) */ + + return 0; +} + +static int find_ci(const struct atm_vcc *vcc, short *vpi, int *vci) +{ + static short p; /* poor man's per-device cache */ + static int c; + short old_p; + int old_c; + int err; + + if (*vpi != ATM_VPI_ANY && *vci != ATM_VCI_ANY) { + err = check_ci(vcc, *vpi, *vci); + return err; + } + /* last scan may have left values out of bounds for current device */ + if (*vpi != ATM_VPI_ANY) + p = *vpi; + else if (p >= 1 << vcc->dev->ci_range.vpi_bits) + p = 0; + if (*vci != ATM_VCI_ANY) + c = *vci; + else if (c < ATM_NOT_RSV_VCI || c >= 1 << vcc->dev->ci_range.vci_bits) + c = ATM_NOT_RSV_VCI; + old_p = p; + old_c = c; + do { + if (!check_ci(vcc, p, c)) { + *vpi = p; + *vci = c; + return 0; + } + if (*vci == ATM_VCI_ANY) { + c++; + if (c >= 1 << vcc->dev->ci_range.vci_bits) + c = ATM_NOT_RSV_VCI; + } + if ((c == ATM_NOT_RSV_VCI || *vci != ATM_VCI_ANY) && + *vpi == ATM_VPI_ANY) { + p++; + if (p >= 1 << vcc->dev->ci_range.vpi_bits) + p = 0; + } + } while (old_p != p || old_c != c); + return -EADDRINUSE; +} + +static int __vcc_connect(struct atm_vcc *vcc, struct atm_dev *dev, short vpi, + int vci) +{ + struct sock *sk = sk_atm(vcc); + int error; + + if ((vpi != ATM_VPI_UNSPEC && vpi != ATM_VPI_ANY && + vpi >> dev->ci_range.vpi_bits) || (vci != ATM_VCI_UNSPEC && + vci != ATM_VCI_ANY && vci >> dev->ci_range.vci_bits)) + return -EINVAL; + if (vci > 0 && vci < ATM_NOT_RSV_VCI && !capable(CAP_NET_BIND_SERVICE)) + return -EPERM; + error = -ENODEV; + if (!try_module_get(dev->ops->owner)) + return error; + vcc->dev = dev; + write_lock_irq(&vcc_sklist_lock); + if (test_bit(ATM_DF_REMOVED, &dev->flags) || + (error = find_ci(vcc, &vpi, &vci))) { + write_unlock_irq(&vcc_sklist_lock); + goto fail_module_put; + } + vcc->vpi = vpi; + vcc->vci = vci; + __vcc_insert_socket(sk); + write_unlock_irq(&vcc_sklist_lock); + switch (vcc->qos.aal) { + case ATM_AAL0: + error = atm_init_aal0(vcc); + vcc->stats = &dev->stats.aal0; + break; + case ATM_AAL34: + error = atm_init_aal34(vcc); + vcc->stats = &dev->stats.aal34; + break; + case ATM_NO_AAL: + /* ATM_AAL5 is also used in the "0 for default" case */ + vcc->qos.aal = ATM_AAL5; + fallthrough; + case ATM_AAL5: + error = atm_init_aal5(vcc); + vcc->stats = &dev->stats.aal5; + break; + default: + error = -EPROTOTYPE; + } + if (!error) + error = adjust_tp(&vcc->qos.txtp, vcc->qos.aal); + if (!error) + error = adjust_tp(&vcc->qos.rxtp, vcc->qos.aal); + if (error) + goto fail; + pr_debug("VCC %d.%d, AAL %d\n", vpi, vci, vcc->qos.aal); + pr_debug(" TX: %d, PCR %d..%d, SDU %d\n", + vcc->qos.txtp.traffic_class, + vcc->qos.txtp.min_pcr, + vcc->qos.txtp.max_pcr, + vcc->qos.txtp.max_sdu); + pr_debug(" RX: %d, PCR %d..%d, SDU %d\n", + vcc->qos.rxtp.traffic_class, + vcc->qos.rxtp.min_pcr, + vcc->qos.rxtp.max_pcr, + vcc->qos.rxtp.max_sdu); + + if (dev->ops->open) { + error = dev->ops->open(vcc); + if (error) + goto fail; + } + return 0; + +fail: + vcc_remove_socket(sk); +fail_module_put: + module_put(dev->ops->owner); + /* ensure we get dev module ref count correct */ + vcc->dev = NULL; + return error; +} + +int vcc_connect(struct socket *sock, int itf, short vpi, int vci) +{ + struct atm_dev *dev; + struct atm_vcc *vcc = ATM_SD(sock); + int error; + + pr_debug("(vpi %d, vci %d)\n", vpi, vci); + if (sock->state == SS_CONNECTED) + return -EISCONN; + if (sock->state != SS_UNCONNECTED) + return -EINVAL; + if (!(vpi || vci)) + return -EINVAL; + + if (vpi != ATM_VPI_UNSPEC && vci != ATM_VCI_UNSPEC) + clear_bit(ATM_VF_PARTIAL, 
&vcc->flags); + else + if (test_bit(ATM_VF_PARTIAL, &vcc->flags)) + return -EINVAL; + pr_debug("(TX: cl %d,bw %d-%d,sdu %d; " + "RX: cl %d,bw %d-%d,sdu %d,AAL %s%d)\n", + vcc->qos.txtp.traffic_class, vcc->qos.txtp.min_pcr, + vcc->qos.txtp.max_pcr, vcc->qos.txtp.max_sdu, + vcc->qos.rxtp.traffic_class, vcc->qos.rxtp.min_pcr, + vcc->qos.rxtp.max_pcr, vcc->qos.rxtp.max_sdu, + vcc->qos.aal == ATM_AAL5 ? "" : + vcc->qos.aal == ATM_AAL0 ? "" : " ??? code ", + vcc->qos.aal == ATM_AAL0 ? 0 : vcc->qos.aal); + if (!test_bit(ATM_VF_HASQOS, &vcc->flags)) + return -EBADFD; + if (vcc->qos.txtp.traffic_class == ATM_ANYCLASS || + vcc->qos.rxtp.traffic_class == ATM_ANYCLASS) + return -EINVAL; + if (likely(itf != ATM_ITF_ANY)) { + dev = try_then_request_module(atm_dev_lookup(itf), + "atm-device-%d", itf); + } else { + dev = NULL; + mutex_lock(&atm_dev_mutex); + if (!list_empty(&atm_devs)) { + dev = list_entry(atm_devs.next, + struct atm_dev, dev_list); + atm_dev_hold(dev); + } + mutex_unlock(&atm_dev_mutex); + } + if (!dev) + return -ENODEV; + error = __vcc_connect(vcc, dev, vpi, vci); + if (error) { + atm_dev_put(dev); + return error; + } + if (vpi == ATM_VPI_UNSPEC || vci == ATM_VCI_UNSPEC) + set_bit(ATM_VF_PARTIAL, &vcc->flags); + if (test_bit(ATM_VF_READY, &ATM_SD(sock)->flags)) + sock->state = SS_CONNECTED; + return 0; +} + +int vcc_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct atm_vcc *vcc; + struct sk_buff *skb; + int copied, error = -EINVAL; + + if (sock->state != SS_CONNECTED) + return -ENOTCONN; + + /* only handle MSG_DONTWAIT and MSG_PEEK */ + if (flags & ~(MSG_DONTWAIT | MSG_PEEK)) + return -EOPNOTSUPP; + + vcc = ATM_SD(sock); + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + test_bit(ATM_VF_CLOSE, &vcc->flags) || + !test_bit(ATM_VF_READY, &vcc->flags)) + return 0; + + skb = skb_recv_datagram(sk, flags, &error); + if (!skb) + return error; + + copied = skb->len; + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + error = skb_copy_datagram_msg(skb, 0, msg, copied); + if (error) + return error; + sock_recv_cmsgs(msg, sk, skb); + + if (!(flags & MSG_PEEK)) { + pr_debug("%d -= %d\n", atomic_read(&sk->sk_rmem_alloc), + skb->truesize); + atm_return(vcc, skb->truesize); + } + + skb_free_datagram(sk, skb); + return copied; +} + +int vcc_sendmsg(struct socket *sock, struct msghdr *m, size_t size) +{ + struct sock *sk = sock->sk; + DEFINE_WAIT(wait); + struct atm_vcc *vcc; + struct sk_buff *skb; + int eff, error; + + lock_sock(sk); + if (sock->state != SS_CONNECTED) { + error = -ENOTCONN; + goto out; + } + if (m->msg_name) { + error = -EISCONN; + goto out; + } + vcc = ATM_SD(sock); + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + test_bit(ATM_VF_CLOSE, &vcc->flags) || + !test_bit(ATM_VF_READY, &vcc->flags)) { + error = -EPIPE; + send_sig(SIGPIPE, current, 0); + goto out; + } + if (!size) { + error = 0; + goto out; + } + if (size > vcc->qos.txtp.max_sdu) { + error = -EMSGSIZE; + goto out; + } + + eff = (size+3) & ~3; /* align to word boundary */ + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + error = 0; + while (!vcc_tx_ready(vcc, eff)) { + if (m->msg_flags & MSG_DONTWAIT) { + error = -EAGAIN; + break; + } + schedule(); + if (signal_pending(current)) { + error = -ERESTARTSYS; + break; + } + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + test_bit(ATM_VF_CLOSE, &vcc->flags) || + !test_bit(ATM_VF_READY, &vcc->flags)) { + error = -EPIPE; + send_sig(SIGPIPE, current, 0); + break; + } + 
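/*
 * Still no room: go back to sleep.  Buffer space comes back when the
 * driver hands completed skbs to vcc->pop() (typically atm_pop_raw()
 * for the plain AAL protocols), which drops sk_wmem_alloc and calls
 * sk->sk_write_space(), i.e. vcc_write_space() above, waking this
 * wait queue.
 */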
prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + } + finish_wait(sk_sleep(sk), &wait); + if (error) + goto out; + + skb = alloc_skb(eff, GFP_KERNEL); + if (!skb) { + error = -ENOMEM; + goto out; + } + pr_debug("%d += %d\n", sk_wmem_alloc_get(sk), skb->truesize); + atm_account_tx(vcc, skb); + + skb->dev = NULL; /* for paths shared with net_device interfaces */ + if (!copy_from_iter_full(skb_put(skb, size), size, &m->msg_iter)) { + kfree_skb(skb); + error = -EFAULT; + goto out; + } + if (eff != size) + memset(skb->data + size, 0, eff-size); + error = vcc->dev->ops->send(vcc, skb); + error = error ? error : size; +out: + release_sock(sk); + return error; +} + +__poll_t vcc_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + struct sock *sk = sock->sk; + struct atm_vcc *vcc; + __poll_t mask; + + sock_poll_wait(file, sock, wait); + mask = 0; + + vcc = ATM_SD(sock); + + /* exceptional events */ + if (sk->sk_err) + mask = EPOLLERR; + + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + test_bit(ATM_VF_CLOSE, &vcc->flags)) + mask |= EPOLLHUP; + + /* readable? */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* writable? */ + if (sock->state == SS_CONNECTING && + test_bit(ATM_VF_WAITING, &vcc->flags)) + return mask; + + if (vcc->qos.txtp.traffic_class != ATM_NONE && + vcc_writable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + return mask; +} + +static int atm_change_qos(struct atm_vcc *vcc, struct atm_qos *qos) +{ + int error; + + /* + * Don't let the QoS change the already connected AAL type nor the + * traffic class. + */ + if (qos->aal != vcc->qos.aal || + qos->rxtp.traffic_class != vcc->qos.rxtp.traffic_class || + qos->txtp.traffic_class != vcc->qos.txtp.traffic_class) + return -EINVAL; + error = adjust_tp(&qos->txtp, qos->aal); + if (!error) + error = adjust_tp(&qos->rxtp, qos->aal); + if (error) + return error; + if (!vcc->dev->ops->change_qos) + return -EOPNOTSUPP; + if (sk_atm(vcc)->sk_family == AF_ATMPVC) + return vcc->dev->ops->change_qos(vcc, qos, ATM_MF_SET); + return svc_change_qos(vcc, qos); +} + +static int check_tp(const struct atm_trafprm *tp) +{ + /* @@@ Should be merged with adjust_tp */ + if (!tp->traffic_class || tp->traffic_class == ATM_ANYCLASS) + return 0; + if (tp->traffic_class != ATM_UBR && !tp->min_pcr && !tp->pcr && + !tp->max_pcr) + return -EINVAL; + if (tp->min_pcr == ATM_MAX_PCR) + return -EINVAL; + if (tp->min_pcr && tp->max_pcr && tp->max_pcr != ATM_MAX_PCR && + tp->min_pcr > tp->max_pcr) + return -EINVAL; + /* + * We allow pcr to be outside [min_pcr,max_pcr], because later + * adjustment may still push it in the valid range. 
+ */ + return 0; +} + +static int check_qos(const struct atm_qos *qos) +{ + int error; + + if (!qos->txtp.traffic_class && !qos->rxtp.traffic_class) + return -EINVAL; + if (qos->txtp.traffic_class != qos->rxtp.traffic_class && + qos->txtp.traffic_class && qos->rxtp.traffic_class && + qos->txtp.traffic_class != ATM_ANYCLASS && + qos->rxtp.traffic_class != ATM_ANYCLASS) + return -EINVAL; + error = check_tp(&qos->txtp); + if (error) + return error; + return check_tp(&qos->rxtp); +} + +int vcc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct atm_vcc *vcc; + unsigned long value; + int error; + + if (__SO_LEVEL_MATCH(optname, level) && optlen != __SO_SIZE(optname)) + return -EINVAL; + + vcc = ATM_SD(sock); + switch (optname) { + case SO_ATMQOS: + { + struct atm_qos qos; + + if (copy_from_sockptr(&qos, optval, sizeof(qos))) + return -EFAULT; + error = check_qos(&qos); + if (error) + return error; + if (sock->state == SS_CONNECTED) + return atm_change_qos(vcc, &qos); + if (sock->state != SS_UNCONNECTED) + return -EBADFD; + vcc->qos = qos; + set_bit(ATM_VF_HASQOS, &vcc->flags); + return 0; + } + case SO_SETCLP: + if (copy_from_sockptr(&value, optval, sizeof(value))) + return -EFAULT; + if (value) + vcc->atm_options |= ATM_ATMOPT_CLP; + else + vcc->atm_options &= ~ATM_ATMOPT_CLP; + return 0; + default: + return -EINVAL; + } +} + +int vcc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct atm_vcc *vcc; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + if (__SO_LEVEL_MATCH(optname, level) && len != __SO_SIZE(optname)) + return -EINVAL; + + vcc = ATM_SD(sock); + switch (optname) { + case SO_ATMQOS: + if (!test_bit(ATM_VF_HASQOS, &vcc->flags)) + return -EINVAL; + return copy_to_user(optval, &vcc->qos, sizeof(vcc->qos)) + ? -EFAULT : 0; + case SO_SETCLP: + return put_user(vcc->atm_options & ATM_ATMOPT_CLP ? 1 : 0, + (unsigned long __user *)optval) ? -EFAULT : 0; + case SO_ATMPVC: + { + struct sockaddr_atmpvc pvc; + + if (!vcc->dev || !test_bit(ATM_VF_ADDR, &vcc->flags)) + return -ENOTCONN; + memset(&pvc, 0, sizeof(pvc)); + pvc.sap_family = AF_ATMPVC; + pvc.sap_addr.itf = vcc->dev->number; + pvc.sap_addr.vpi = vcc->vpi; + pvc.sap_addr.vci = vcc->vci; + return copy_to_user(optval, &pvc, sizeof(pvc)) ? 
-EFAULT : 0; + } + default: + return -EINVAL; + } +} + +int register_atmdevice_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_register(&atm_dev_notify_chain, nb); +} +EXPORT_SYMBOL_GPL(register_atmdevice_notifier); + +void unregister_atmdevice_notifier(struct notifier_block *nb) +{ + atomic_notifier_chain_unregister(&atm_dev_notify_chain, nb); +} +EXPORT_SYMBOL_GPL(unregister_atmdevice_notifier); + +static int __init atm_init(void) +{ + int error; + + error = proto_register(&vcc_proto, 0); + if (error < 0) + goto out; + error = atmpvc_init(); + if (error < 0) { + pr_err("atmpvc_init() failed with %d\n", error); + goto out_unregister_vcc_proto; + } + error = atmsvc_init(); + if (error < 0) { + pr_err("atmsvc_init() failed with %d\n", error); + goto out_atmpvc_exit; + } + error = atm_proc_init(); + if (error < 0) { + pr_err("atm_proc_init() failed with %d\n", error); + goto out_atmsvc_exit; + } + error = atm_sysfs_init(); + if (error < 0) { + pr_err("atm_sysfs_init() failed with %d\n", error); + goto out_atmproc_exit; + } +out: + return error; +out_atmproc_exit: + atm_proc_exit(); +out_atmsvc_exit: + atmsvc_exit(); +out_atmpvc_exit: + atmsvc_exit(); +out_unregister_vcc_proto: + proto_unregister(&vcc_proto); + goto out; +} + +static void __exit atm_exit(void) +{ + atm_proc_exit(); + atm_sysfs_exit(); + atmsvc_exit(); + atmpvc_exit(); + proto_unregister(&vcc_proto); +} + +subsys_initcall(atm_init); + +module_exit(atm_exit); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_ATMPVC); +MODULE_ALIAS_NETPROTO(PF_ATMSVC); diff --git a/net/atm/common.h b/net/atm/common.h new file mode 100644 index 000000000..a1e56e8de --- /dev/null +++ b/net/atm/common.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* net/atm/common.h - ATM sockets (common part for PVC and SVC) */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + + +#ifndef NET_ATM_COMMON_H +#define NET_ATM_COMMON_H + +#include +#include /* for poll_table */ + + +int vcc_create(struct net *net, struct socket *sock, int protocol, int family, int kern); +int vcc_release(struct socket *sock); +int vcc_connect(struct socket *sock, int itf, short vpi, int vci); +int vcc_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags); +int vcc_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len); +__poll_t vcc_poll(struct file *file, struct socket *sock, poll_table *wait); +int vcc_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg); +int vcc_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg); +int vcc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen); +int vcc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen); +void vcc_process_recv_queue(struct atm_vcc *vcc); + +int atmpvc_init(void); +void atmpvc_exit(void); +int atmsvc_init(void); +void atmsvc_exit(void); +int atm_sysfs_init(void); +void atm_sysfs_exit(void); + +#ifdef CONFIG_PROC_FS +int atm_proc_init(void); +void atm_proc_exit(void); +#else +static inline int atm_proc_init(void) +{ + return 0; +} + +static inline void atm_proc_exit(void) +{ + /* nothing */ +} +#endif /* CONFIG_PROC_FS */ + +/* SVC */ +int svc_change_qos(struct atm_vcc *vcc,struct atm_qos *qos); + +void atm_dev_release_vccs(struct atm_dev *dev); + +#endif diff --git a/net/atm/ioctl.c b/net/atm/ioctl.c new file mode 100644 index 000000000..f81f8d56f --- /dev/null +++ b/net/atm/ioctl.c @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: 
GPL-2.0 +/* ATM ioctl handling */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ +/* 2003 John Levon */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include /* struct socket, struct proto_ops */ +#include /* ATM stuff */ +#include +#include /* CLIP_*ENCAP */ +#include /* manifest constants */ +#include +#include /* for ioctls */ +#include +#include +#include +#include +#include +#include +#include + +#include "resources.h" +#include "signaling.h" /* for WAITING and sigd_attach */ +#include "common.h" + + +static DEFINE_MUTEX(ioctl_mutex); +static LIST_HEAD(ioctl_list); + + +void register_atm_ioctl(struct atm_ioctl *ioctl) +{ + mutex_lock(&ioctl_mutex); + list_add_tail(&ioctl->list, &ioctl_list); + mutex_unlock(&ioctl_mutex); +} +EXPORT_SYMBOL(register_atm_ioctl); + +void deregister_atm_ioctl(struct atm_ioctl *ioctl) +{ + mutex_lock(&ioctl_mutex); + list_del(&ioctl->list); + mutex_unlock(&ioctl_mutex); +} +EXPORT_SYMBOL(deregister_atm_ioctl); + +static int do_vcc_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg, int compat) +{ + struct sock *sk = sock->sk; + struct atm_vcc *vcc; + int error; + struct list_head *pos; + void __user *argp = (void __user *)arg; + void __user *buf; + int __user *len; + + vcc = ATM_SD(sock); + switch (cmd) { + case SIOCOUTQ: + if (sock->state != SS_CONNECTED || + !test_bit(ATM_VF_READY, &vcc->flags)) { + error = -EINVAL; + goto done; + } + error = put_user(sk->sk_sndbuf - sk_wmem_alloc_get(sk), + (int __user *)argp) ? -EFAULT : 0; + goto done; + case SIOCINQ: + { + struct sk_buff *skb; + int amount; + + if (sock->state != SS_CONNECTED) { + error = -EINVAL; + goto done; + } + spin_lock_irq(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + amount = skb ? skb->len : 0; + spin_unlock_irq(&sk->sk_receive_queue.lock); + error = put_user(amount, (int __user *)argp) ? -EFAULT : 0; + goto done; + } + case ATM_SETSC: + net_warn_ratelimited("ATM_SETSC is obsolete; used by %s:%d\n", + current->comm, task_pid_nr(current)); + error = 0; + goto done; + case ATMSIGD_CTRL: + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + /* + * The user/kernel protocol for exchanging signalling + * info uses kernel pointers as opaque references, + * so the holder of the file descriptor can scribble + * on the kernel... so we should make sure that we + * have the same privileges that /proc/kcore needs + */ + if (!capable(CAP_SYS_RAWIO)) { + error = -EPERM; + goto done; + } +#ifdef CONFIG_COMPAT + /* WTF? I don't even want to _think_ about making this + work for 32-bit userspace. TBH I don't really want + to think about it at all. dwmw2. 
*/ + if (compat) { + net_warn_ratelimited("32-bit task cannot be atmsigd\n"); + error = -EINVAL; + goto done; + } +#endif + error = sigd_attach(vcc); + if (!error) + sock->state = SS_CONNECTED; + goto done; + case ATM_SETBACKEND: + case ATM_NEWBACKENDIF: + { + atm_backend_t backend; + error = get_user(backend, (atm_backend_t __user *)argp); + if (error) + goto done; + switch (backend) { + case ATM_BACKEND_PPP: + request_module("pppoatm"); + break; + case ATM_BACKEND_BR2684: + request_module("br2684"); + break; + } + break; + } + case ATMMPC_CTRL: + case ATMMPC_DATA: + request_module("mpoa"); + break; + case ATMARPD_CTRL: + request_module("clip"); + break; + case ATMLEC_CTRL: + request_module("lec"); + break; + } + + error = -ENOIOCTLCMD; + + mutex_lock(&ioctl_mutex); + list_for_each(pos, &ioctl_list) { + struct atm_ioctl *ic = list_entry(pos, struct atm_ioctl, list); + if (try_module_get(ic->owner)) { + error = ic->ioctl(sock, cmd, arg); + module_put(ic->owner); + if (error != -ENOIOCTLCMD) + break; + } + } + mutex_unlock(&ioctl_mutex); + + if (error != -ENOIOCTLCMD) + goto done; + + if (cmd == ATM_GETNAMES) { + if (IS_ENABLED(CONFIG_COMPAT) && compat) { +#ifdef CONFIG_COMPAT + struct compat_atm_iobuf __user *ciobuf = argp; + compat_uptr_t cbuf; + len = &ciobuf->length; + if (get_user(cbuf, &ciobuf->buffer)) + return -EFAULT; + buf = compat_ptr(cbuf); +#endif + } else { + struct atm_iobuf __user *iobuf = argp; + len = &iobuf->length; + if (get_user(buf, &iobuf->buffer)) + return -EFAULT; + } + error = atm_getnames(buf, len); + } else { + int number; + + if (IS_ENABLED(CONFIG_COMPAT) && compat) { +#ifdef CONFIG_COMPAT + struct compat_atmif_sioc __user *csioc = argp; + compat_uptr_t carg; + + len = &csioc->length; + if (get_user(carg, &csioc->arg)) + return -EFAULT; + buf = compat_ptr(carg); + if (get_user(number, &csioc->number)) + return -EFAULT; +#endif + } else { + struct atmif_sioc __user *sioc = argp; + + len = &sioc->length; + if (get_user(buf, &sioc->arg)) + return -EFAULT; + if (get_user(number, &sioc->number)) + return -EFAULT; + } + error = atm_dev_ioctl(cmd, buf, len, number, compat); + } + +done: + return error; +} + +int vcc_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_vcc_ioctl(sock, cmd, arg, 0); +} + +#ifdef CONFIG_COMPAT +/* + * FIXME: + * The compat_ioctl handling is duplicated, using both these conversion + * routines and the compat argument to the actual handlers. Both + * versions are somewhat incomplete and should be merged, e.g. by + * moving the ioctl number translation into the actual handlers and + * killing the conversion code. 
+ * + * -arnd, November 2009 + */ +#define ATM_GETLINKRATE32 _IOW('a', ATMIOC_ITF+1, struct compat_atmif_sioc) +#define ATM_GETNAMES32 _IOW('a', ATMIOC_ITF+3, struct compat_atm_iobuf) +#define ATM_GETTYPE32 _IOW('a', ATMIOC_ITF+4, struct compat_atmif_sioc) +#define ATM_GETESI32 _IOW('a', ATMIOC_ITF+5, struct compat_atmif_sioc) +#define ATM_GETADDR32 _IOW('a', ATMIOC_ITF+6, struct compat_atmif_sioc) +#define ATM_RSTADDR32 _IOW('a', ATMIOC_ITF+7, struct compat_atmif_sioc) +#define ATM_ADDADDR32 _IOW('a', ATMIOC_ITF+8, struct compat_atmif_sioc) +#define ATM_DELADDR32 _IOW('a', ATMIOC_ITF+9, struct compat_atmif_sioc) +#define ATM_GETCIRANGE32 _IOW('a', ATMIOC_ITF+10, struct compat_atmif_sioc) +#define ATM_SETCIRANGE32 _IOW('a', ATMIOC_ITF+11, struct compat_atmif_sioc) +#define ATM_SETESI32 _IOW('a', ATMIOC_ITF+12, struct compat_atmif_sioc) +#define ATM_SETESIF32 _IOW('a', ATMIOC_ITF+13, struct compat_atmif_sioc) +#define ATM_GETSTAT32 _IOW('a', ATMIOC_SARCOM+0, struct compat_atmif_sioc) +#define ATM_GETSTATZ32 _IOW('a', ATMIOC_SARCOM+1, struct compat_atmif_sioc) +#define ATM_GETLOOP32 _IOW('a', ATMIOC_SARCOM+2, struct compat_atmif_sioc) +#define ATM_SETLOOP32 _IOW('a', ATMIOC_SARCOM+3, struct compat_atmif_sioc) +#define ATM_QUERYLOOP32 _IOW('a', ATMIOC_SARCOM+4, struct compat_atmif_sioc) + +static struct { + unsigned int cmd32; + unsigned int cmd; +} atm_ioctl_map[] = { + { ATM_GETLINKRATE32, ATM_GETLINKRATE }, + { ATM_GETNAMES32, ATM_GETNAMES }, + { ATM_GETTYPE32, ATM_GETTYPE }, + { ATM_GETESI32, ATM_GETESI }, + { ATM_GETADDR32, ATM_GETADDR }, + { ATM_RSTADDR32, ATM_RSTADDR }, + { ATM_ADDADDR32, ATM_ADDADDR }, + { ATM_DELADDR32, ATM_DELADDR }, + { ATM_GETCIRANGE32, ATM_GETCIRANGE }, + { ATM_SETCIRANGE32, ATM_SETCIRANGE }, + { ATM_SETESI32, ATM_SETESI }, + { ATM_SETESIF32, ATM_SETESIF }, + { ATM_GETSTAT32, ATM_GETSTAT }, + { ATM_GETSTATZ32, ATM_GETSTATZ }, + { ATM_GETLOOP32, ATM_GETLOOP }, + { ATM_SETLOOP32, ATM_SETLOOP }, + { ATM_QUERYLOOP32, ATM_QUERYLOOP }, +}; + +#define NR_ATM_IOCTL ARRAY_SIZE(atm_ioctl_map) + +static int do_atm_iobuf(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct compat_atm_iobuf __user *iobuf32 = compat_ptr(arg); + u32 data; + + if (get_user(data, &iobuf32->buffer)) + return -EFAULT; + + return atm_getnames(&iobuf32->length, compat_ptr(data)); +} + +static int do_atmif_sioc(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct compat_atmif_sioc __user *sioc32 = compat_ptr(arg); + int number; + u32 data; + + if (get_user(data, &sioc32->arg) || get_user(number, &sioc32->number)) + return -EFAULT; + return atm_dev_ioctl(cmd, compat_ptr(data), &sioc32->length, number, 0); +} + +static int do_atm_ioctl(struct socket *sock, unsigned int cmd32, + unsigned long arg) +{ + int i; + unsigned int cmd = 0; + + switch (cmd32) { + case SONET_GETSTAT: + case SONET_GETSTATZ: + case SONET_GETDIAG: + case SONET_SETDIAG: + case SONET_CLRDIAG: + case SONET_SETFRAMING: + case SONET_GETFRAMING: + case SONET_GETFRSENSE: + return do_atmif_sioc(sock, cmd32, arg); + } + + for (i = 0; i < NR_ATM_IOCTL; i++) { + if (cmd32 == atm_ioctl_map[i].cmd32) { + cmd = atm_ioctl_map[i].cmd; + break; + } + } + if (i == NR_ATM_IOCTL) + return -EINVAL; + + switch (cmd) { + case ATM_GETNAMES: + return do_atm_iobuf(sock, cmd, arg); + + case ATM_GETLINKRATE: + case ATM_GETTYPE: + case ATM_GETESI: + case ATM_GETADDR: + case ATM_RSTADDR: + case ATM_ADDADDR: + case ATM_DELADDR: + case ATM_GETCIRANGE: + case ATM_SETCIRANGE: + case ATM_SETESI: + case ATM_SETESIF: + case 
ATM_GETSTAT: + case ATM_GETSTATZ: + case ATM_GETLOOP: + case ATM_SETLOOP: + case ATM_QUERYLOOP: + return do_atmif_sioc(sock, cmd, arg); + } + + return -EINVAL; +} + +int vcc_compat_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + int ret; + + ret = do_vcc_ioctl(sock, cmd, arg, 1); + if (ret != -ENOIOCTLCMD) + return ret; + + return do_atm_ioctl(sock, cmd, arg); +} +#endif diff --git a/net/atm/lec.c b/net/atm/lec.c new file mode 100644 index 000000000..6257bf12e --- /dev/null +++ b/net/atm/lec.c @@ -0,0 +1,2237 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * lec.c: Lan Emulation driver + * + * Marko Kiiskila + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include + +/* We are ethernet device */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* And atm device */ +#include +#include + +/* Proxy LEC knows about bridging */ +#if IS_ENABLED(CONFIG_BRIDGE) +#include "../bridge/br_private.h" + +static unsigned char bridge_ula_lec[] = { 0x01, 0x80, 0xc2, 0x00, 0x00 }; +#endif + +/* Modular too */ +#include +#include + +/* Hardening for Spectre-v1 */ +#include + +#include "lec.h" +#include "lec_arpc.h" +#include "resources.h" + +#define DUMP_PACKETS 0 /* + * 0 = None, + * 1 = 30 first bytes + * 2 = Whole packet + */ + +#define LEC_UNRES_QUE_LEN 8 /* + * number of tx packets to queue for a + * single destination while waiting for SVC + */ + +static int lec_open(struct net_device *dev); +static netdev_tx_t lec_start_xmit(struct sk_buff *skb, + struct net_device *dev); +static int lec_close(struct net_device *dev); +static struct lec_arp_table *lec_arp_find(struct lec_priv *priv, + const unsigned char *mac_addr); +static int lec_arp_remove(struct lec_priv *priv, + struct lec_arp_table *to_remove); +/* LANE2 functions */ +static void lane2_associate_ind(struct net_device *dev, const u8 *mac_address, + const u8 *tlvs, u32 sizeoftlvs); +static int lane2_resolve(struct net_device *dev, const u8 *dst_mac, int force, + u8 **tlvs, u32 *sizeoftlvs); +static int lane2_associate_req(struct net_device *dev, const u8 *lan_dst, + const u8 *tlvs, u32 sizeoftlvs); + +static int lec_addr_delete(struct lec_priv *priv, const unsigned char *atm_addr, + unsigned long permanent); +static void lec_arp_check_empties(struct lec_priv *priv, + struct atm_vcc *vcc, struct sk_buff *skb); +static void lec_arp_destroy(struct lec_priv *priv); +static void lec_arp_init(struct lec_priv *priv); +static struct atm_vcc *lec_arp_resolve(struct lec_priv *priv, + const unsigned char *mac_to_find, + int is_rdesc, + struct lec_arp_table **ret_entry); +static void lec_arp_update(struct lec_priv *priv, const unsigned char *mac_addr, + const unsigned char *atm_addr, + unsigned long remoteflag, + unsigned int targetless_le_arp); +static void lec_flush_complete(struct lec_priv *priv, unsigned long tran_id); +static int lec_mcast_make(struct lec_priv *priv, struct atm_vcc *vcc); +static void lec_set_flush_tran_id(struct lec_priv *priv, + const unsigned char *atm_addr, + unsigned long tran_id); +static void lec_vcc_added(struct lec_priv *priv, + const struct atmlec_ioc *ioc_data, + struct atm_vcc *vcc, + void (*old_push)(struct atm_vcc *vcc, + struct sk_buff *skb)); +static void lec_vcc_close(struct lec_priv *priv, struct atm_vcc *vcc); + +/* must be done under lec_arp_lock */ +static inline void lec_arp_hold(struct lec_arp_table *entry) +{ + refcount_inc(&entry->usage); +} + +static 
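/* Table entries are reference counted: lec_arp_hold() above pins an
 * entry (callers must hold lec_arp_lock), and the matching
 * lec_arp_put() frees it once the last reference is gone. */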
inline void lec_arp_put(struct lec_arp_table *entry) +{ + if (refcount_dec_and_test(&entry->usage)) + kfree(entry); +} + +static struct lane2_ops lane2_ops = { + .resolve = lane2_resolve, /* spec 3.1.3 */ + .associate_req = lane2_associate_req, /* spec 3.1.4 */ + .associate_indicator = NULL /* spec 3.1.5 */ +}; + +static unsigned char bus_mac[ETH_ALEN] = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }; + +/* Device structures */ +static struct net_device *dev_lec[MAX_LEC_ITF]; + +#if IS_ENABLED(CONFIG_BRIDGE) +static void lec_handle_bridge(struct sk_buff *skb, struct net_device *dev) +{ + char *buff; + struct lec_priv *priv; + + /* + * Check if this is a BPDU. If so, ask zeppelin to send + * LE_TOPOLOGY_REQUEST with the same value of Topology Change bit + * as the Config BPDU has + */ + buff = skb->data + skb->dev->hard_header_len; + if (*buff++ == 0x42 && *buff++ == 0x42 && *buff++ == 0x03) { + struct sock *sk; + struct sk_buff *skb2; + struct atmlec_msg *mesg; + + skb2 = alloc_skb(sizeof(struct atmlec_msg), GFP_ATOMIC); + if (skb2 == NULL) + return; + skb2->len = sizeof(struct atmlec_msg); + mesg = (struct atmlec_msg *)skb2->data; + mesg->type = l_topology_change; + buff += 4; + mesg->content.normal.flag = *buff & 0x01; + /* 0x01 is topology change */ + + priv = netdev_priv(dev); + atm_force_charge(priv->lecd, skb2->truesize); + sk = sk_atm(priv->lecd); + skb_queue_tail(&sk->sk_receive_queue, skb2); + sk->sk_data_ready(sk); + } +} +#endif /* IS_ENABLED(CONFIG_BRIDGE) */ + +/* + * Open/initialize the netdevice. This is called (in the current kernel) + * sometime after booting when the 'ifconfig' program is run. + * + * This routine should set everything up anew at each open, even + * registers that "should" only need to be set once at boot, so that + * there is non-reboot way to recover if something goes wrong. 
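 * In this driver lec_open() itself only (re)starts the transmit queue;
 * the per-interface state is actually built when the LANE daemon
 * attaches through lecd_attach() further down.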
+ */ + +static int lec_open(struct net_device *dev) +{ + netif_start_queue(dev); + + return 0; +} + +static void +lec_send(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + + ATM_SKB(skb)->vcc = vcc; + atm_account_tx(vcc, skb); + + if (vcc->send(vcc, skb) < 0) { + dev->stats.tx_dropped++; + return; + } + + dev->stats.tx_packets++; + dev->stats.tx_bytes += skb->len; +} + +static void lec_tx_timeout(struct net_device *dev, unsigned int txqueue) +{ + pr_info("%s\n", dev->name); + netif_trans_update(dev); + netif_wake_queue(dev); +} + +static netdev_tx_t lec_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct sk_buff *skb2; + struct lec_priv *priv = netdev_priv(dev); + struct lecdatahdr_8023 *lec_h; + struct atm_vcc *vcc; + struct lec_arp_table *entry; + unsigned char *dst; + int min_frame_size; + int is_rdesc; + + pr_debug("called\n"); + if (!priv->lecd) { + pr_info("%s:No lecd attached\n", dev->name); + dev->stats.tx_errors++; + netif_stop_queue(dev); + kfree_skb(skb); + return NETDEV_TX_OK; + } + + pr_debug("skbuff head:%lx data:%lx tail:%lx end:%lx\n", + (long)skb->head, (long)skb->data, (long)skb_tail_pointer(skb), + (long)skb_end_pointer(skb)); +#if IS_ENABLED(CONFIG_BRIDGE) + if (memcmp(skb->data, bridge_ula_lec, sizeof(bridge_ula_lec)) == 0) + lec_handle_bridge(skb, dev); +#endif + + /* Make sure we have room for lec_id */ + if (skb_headroom(skb) < 2) { + pr_debug("reallocating skb\n"); + skb2 = skb_realloc_headroom(skb, LEC_HEADER_LEN); + if (unlikely(!skb2)) { + kfree_skb(skb); + return NETDEV_TX_OK; + } + consume_skb(skb); + skb = skb2; + } + skb_push(skb, 2); + + /* Put le header to place */ + lec_h = (struct lecdatahdr_8023 *)skb->data; + lec_h->le_header = htons(priv->lecid); + +#if DUMP_PACKETS >= 2 +#define MAX_DUMP_SKB 99 +#elif DUMP_PACKETS >= 1 +#define MAX_DUMP_SKB 30 +#endif +#if DUMP_PACKETS >= 1 + printk(KERN_DEBUG "%s: send datalen:%ld lecid:%4.4x\n", + dev->name, skb->len, priv->lecid); + print_hex_dump(KERN_DEBUG, "", DUMP_OFFSET, 16, 1, + skb->data, min(skb->len, MAX_DUMP_SKB), true); +#endif /* DUMP_PACKETS >= 1 */ + + /* Minimum ethernet-frame size */ + min_frame_size = LEC_MINIMUM_8023_SIZE; + if (skb->len < min_frame_size) { + if ((skb->len + skb_tailroom(skb)) < min_frame_size) { + skb2 = skb_copy_expand(skb, 0, + min_frame_size - skb->truesize, + GFP_ATOMIC); + dev_kfree_skb(skb); + if (skb2 == NULL) { + dev->stats.tx_dropped++; + return NETDEV_TX_OK; + } + skb = skb2; + } + skb_put(skb, min_frame_size - skb->len); + } + + /* Send to right vcc */ + is_rdesc = 0; + dst = lec_h->h_dest; + entry = NULL; + vcc = lec_arp_resolve(priv, dst, is_rdesc, &entry); + pr_debug("%s:vcc:%p vcc_flags:%lx, entry:%p\n", + dev->name, vcc, vcc ? 
vcc->flags : 0, entry); + if (!vcc || !test_bit(ATM_VF_READY, &vcc->flags)) { + if (entry && (entry->tx_wait.qlen < LEC_UNRES_QUE_LEN)) { + pr_debug("%s:queuing packet, MAC address %pM\n", + dev->name, lec_h->h_dest); + skb_queue_tail(&entry->tx_wait, skb); + } else { + pr_debug("%s:tx queue full or no arp entry, dropping, MAC address: %pM\n", + dev->name, lec_h->h_dest); + dev->stats.tx_dropped++; + dev_kfree_skb(skb); + } + goto out; + } +#if DUMP_PACKETS > 0 + printk(KERN_DEBUG "%s:sending to vpi:%d vci:%d\n", + dev->name, vcc->vpi, vcc->vci); +#endif /* DUMP_PACKETS > 0 */ + + while (entry && (skb2 = skb_dequeue(&entry->tx_wait))) { + pr_debug("emptying tx queue, MAC address %pM\n", lec_h->h_dest); + lec_send(vcc, skb2); + } + + lec_send(vcc, skb); + + if (!atm_may_send(vcc, 0)) { + struct lec_vcc_priv *vpriv = LEC_VCC_PRIV(vcc); + + vpriv->xoff = 1; + netif_stop_queue(dev); + + /* + * vcc->pop() might have occurred in between, making + * the vcc usuable again. Since xmit is serialized, + * this is the only situation we have to re-test. + */ + + if (atm_may_send(vcc, 0)) + netif_wake_queue(dev); + } + +out: + if (entry) + lec_arp_put(entry); + netif_trans_update(dev); + return NETDEV_TX_OK; +} + +/* The inverse routine to net_open(). */ +static int lec_close(struct net_device *dev) +{ + netif_stop_queue(dev); + return 0; +} + +static int lec_atm_send(struct atm_vcc *vcc, struct sk_buff *skb) +{ + static const u8 zero_addr[ETH_ALEN] = {}; + unsigned long flags; + struct net_device *dev = (struct net_device *)vcc->proto_data; + struct lec_priv *priv = netdev_priv(dev); + struct atmlec_msg *mesg; + struct lec_arp_table *entry; + char *tmp; /* FIXME */ + + WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc)); + mesg = (struct atmlec_msg *)skb->data; + tmp = skb->data; + tmp += sizeof(struct atmlec_msg); + pr_debug("%s: msg from zeppelin:%d\n", dev->name, mesg->type); + switch (mesg->type) { + case l_set_mac_addr: + eth_hw_addr_set(dev, mesg->content.normal.mac_addr); + break; + case l_del_mac_addr: + eth_hw_addr_set(dev, zero_addr); + break; + case l_addr_delete: + lec_addr_delete(priv, mesg->content.normal.atm_addr, + mesg->content.normal.flag); + break; + case l_topology_change: + priv->topology_change = mesg->content.normal.flag; + break; + case l_flush_complete: + lec_flush_complete(priv, mesg->content.normal.flag); + break; + case l_narp_req: /* LANE2: see 7.1.35 in the lane2 spec */ + spin_lock_irqsave(&priv->lec_arp_lock, flags); + entry = lec_arp_find(priv, mesg->content.normal.mac_addr); + lec_arp_remove(priv, entry); + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + + if (mesg->content.normal.no_source_le_narp) + break; + fallthrough; + case l_arp_update: + lec_arp_update(priv, mesg->content.normal.mac_addr, + mesg->content.normal.atm_addr, + mesg->content.normal.flag, + mesg->content.normal.targetless_le_arp); + pr_debug("in l_arp_update\n"); + if (mesg->sizeoftlvs != 0) { /* LANE2 3.1.5 */ + pr_debug("LANE2 3.1.5, got tlvs, size %d\n", + mesg->sizeoftlvs); + lane2_associate_ind(dev, mesg->content.normal.mac_addr, + tmp, mesg->sizeoftlvs); + } + break; + case l_config: + priv->maximum_unknown_frame_count = + mesg->content.config.maximum_unknown_frame_count; + priv->max_unknown_frame_time = + (mesg->content.config.max_unknown_frame_time * HZ); + priv->max_retry_count = mesg->content.config.max_retry_count; + priv->aging_time = (mesg->content.config.aging_time * HZ); + priv->forward_delay_time = + (mesg->content.config.forward_delay_time * HZ); + 
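/* (All of the configuration times in this block arrive from the
 * daemon in seconds and are converted to jiffies by scaling with HZ.) */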
priv->arp_response_time = + (mesg->content.config.arp_response_time * HZ); + priv->flush_timeout = (mesg->content.config.flush_timeout * HZ); + priv->path_switching_delay = + (mesg->content.config.path_switching_delay * HZ); + priv->lane_version = mesg->content.config.lane_version; + /* LANE2 */ + priv->lane2_ops = NULL; + if (priv->lane_version > 1) + priv->lane2_ops = &lane2_ops; + rtnl_lock(); + if (dev_set_mtu(dev, mesg->content.config.mtu)) + pr_info("%s: change_mtu to %d failed\n", + dev->name, mesg->content.config.mtu); + rtnl_unlock(); + priv->is_proxy = mesg->content.config.is_proxy; + break; + case l_flush_tran_id: + lec_set_flush_tran_id(priv, mesg->content.normal.atm_addr, + mesg->content.normal.flag); + break; + case l_set_lecid: + priv->lecid = + (unsigned short)(0xffff & mesg->content.normal.flag); + break; + case l_should_bridge: +#if IS_ENABLED(CONFIG_BRIDGE) + { + pr_debug("%s: bridge zeppelin asks about %pM\n", + dev->name, mesg->content.proxy.mac_addr); + + if (br_fdb_test_addr_hook == NULL) + break; + + if (br_fdb_test_addr_hook(dev, mesg->content.proxy.mac_addr)) { + /* hit from bridge table, send LE_ARP_RESPONSE */ + struct sk_buff *skb2; + struct sock *sk; + + pr_debug("%s: entry found, responding to zeppelin\n", + dev->name); + skb2 = alloc_skb(sizeof(struct atmlec_msg), GFP_ATOMIC); + if (skb2 == NULL) + break; + skb2->len = sizeof(struct atmlec_msg); + skb_copy_to_linear_data(skb2, mesg, sizeof(*mesg)); + atm_force_charge(priv->lecd, skb2->truesize); + sk = sk_atm(priv->lecd); + skb_queue_tail(&sk->sk_receive_queue, skb2); + sk->sk_data_ready(sk); + } + } +#endif /* IS_ENABLED(CONFIG_BRIDGE) */ + break; + default: + pr_info("%s: Unknown message type %d\n", dev->name, mesg->type); + dev_kfree_skb(skb); + return -EINVAL; + } + dev_kfree_skb(skb); + return 0; +} + +static void lec_atm_close(struct atm_vcc *vcc) +{ + struct sk_buff *skb; + struct net_device *dev = (struct net_device *)vcc->proto_data; + struct lec_priv *priv = netdev_priv(dev); + + priv->lecd = NULL; + /* Do something needful? 
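 * What follows is the needful minimum: stop the transmit queue, tear
 * down the LE_ARP cache and drain any messages still queued for the
 * (now gone) daemon.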
*/ + + netif_stop_queue(dev); + lec_arp_destroy(priv); + + if (skb_peek(&sk_atm(vcc)->sk_receive_queue)) + pr_info("%s closing with messages pending\n", dev->name); + while ((skb = skb_dequeue(&sk_atm(vcc)->sk_receive_queue))) { + atm_return(vcc, skb->truesize); + dev_kfree_skb(skb); + } + + pr_info("%s: Shut down!\n", dev->name); + module_put(THIS_MODULE); +} + +static const struct atmdev_ops lecdev_ops = { + .close = lec_atm_close, + .send = lec_atm_send +}; + +static struct atm_dev lecatm_dev = { + .ops = &lecdev_ops, + .type = "lec", + .number = 999, /* dummy device number */ + .lock = __SPIN_LOCK_UNLOCKED(lecatm_dev.lock) +}; + +/* + * LANE2: new argument struct sk_buff *data contains + * the LE_ARP based TLVs introduced in the LANE2 spec + */ +static int +send_to_lecd(struct lec_priv *priv, atmlec_msg_type type, + const unsigned char *mac_addr, const unsigned char *atm_addr, + struct sk_buff *data) +{ + struct sock *sk; + struct sk_buff *skb; + struct atmlec_msg *mesg; + + if (!priv || !priv->lecd) + return -1; + skb = alloc_skb(sizeof(struct atmlec_msg), GFP_ATOMIC); + if (!skb) + return -1; + skb->len = sizeof(struct atmlec_msg); + mesg = (struct atmlec_msg *)skb->data; + memset(mesg, 0, sizeof(struct atmlec_msg)); + mesg->type = type; + if (data != NULL) + mesg->sizeoftlvs = data->len; + if (mac_addr) + ether_addr_copy(mesg->content.normal.mac_addr, mac_addr); + else + mesg->content.normal.targetless_le_arp = 1; + if (atm_addr) + memcpy(&mesg->content.normal.atm_addr, atm_addr, ATM_ESA_LEN); + + atm_force_charge(priv->lecd, skb->truesize); + sk = sk_atm(priv->lecd); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + + if (data != NULL) { + pr_debug("about to send %d bytes of data\n", data->len); + atm_force_charge(priv->lecd, data->truesize); + skb_queue_tail(&sk->sk_receive_queue, data); + sk->sk_data_ready(sk); + } + + return 0; +} + +static void lec_set_multicast_list(struct net_device *dev) +{ + /* + * by default, all multicast frames arrive over the bus. 
+ * eventually support selective multicast service + */ +} + +static const struct net_device_ops lec_netdev_ops = { + .ndo_open = lec_open, + .ndo_stop = lec_close, + .ndo_start_xmit = lec_start_xmit, + .ndo_tx_timeout = lec_tx_timeout, + .ndo_set_rx_mode = lec_set_multicast_list, +}; + +static const unsigned char lec_ctrl_magic[] = { + 0xff, + 0x00, + 0x01, + 0x01 +}; + +#define LEC_DATA_DIRECT_8023 2 +#define LEC_DATA_DIRECT_8025 3 + +static int lec_is_data_direct(struct atm_vcc *vcc) +{ + return ((vcc->sap.blli[0].l3.tr9577.snap[4] == LEC_DATA_DIRECT_8023) || + (vcc->sap.blli[0].l3.tr9577.snap[4] == LEC_DATA_DIRECT_8025)); +} + +static void lec_push(struct atm_vcc *vcc, struct sk_buff *skb) +{ + unsigned long flags; + struct net_device *dev = (struct net_device *)vcc->proto_data; + struct lec_priv *priv = netdev_priv(dev); + +#if DUMP_PACKETS > 0 + printk(KERN_DEBUG "%s: vcc vpi:%d vci:%d\n", + dev->name, vcc->vpi, vcc->vci); +#endif + if (!skb) { + pr_debug("%s: null skb\n", dev->name); + lec_vcc_close(priv, vcc); + return; + } +#if DUMP_PACKETS >= 2 +#define MAX_SKB_DUMP 99 +#elif DUMP_PACKETS >= 1 +#define MAX_SKB_DUMP 30 +#endif +#if DUMP_PACKETS > 0 + printk(KERN_DEBUG "%s: rcv datalen:%ld lecid:%4.4x\n", + dev->name, skb->len, priv->lecid); + print_hex_dump(KERN_DEBUG, "", DUMP_OFFSET, 16, 1, + skb->data, min(MAX_SKB_DUMP, skb->len), true); +#endif /* DUMP_PACKETS > 0 */ + if (memcmp(skb->data, lec_ctrl_magic, 4) == 0) { + /* Control frame, to daemon */ + struct sock *sk = sk_atm(vcc); + + pr_debug("%s: To daemon\n", dev->name); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + } else { /* Data frame, queue to protocol handlers */ + struct lec_arp_table *entry; + unsigned char *src, *dst; + + atm_return(vcc, skb->truesize); + if (*(__be16 *) skb->data == htons(priv->lecid) || + !priv->lecd || !(dev->flags & IFF_UP)) { + /* + * Probably looping back, or if lecd is missing, + * lecd has gone down + */ + pr_debug("Ignoring frame...\n"); + dev_kfree_skb(skb); + return; + } + dst = ((struct lecdatahdr_8023 *)skb->data)->h_dest; + + /* + * If this is a Data Direct VCC, and the VCC does not match + * the LE_ARP cache entry, delete the LE_ARP cache entry. 
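 * (A mismatch means the cached entry is stale; removing it lets the
 * address be re-learned from the daemon or from later LE_ARP traffic.)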
+ */ + spin_lock_irqsave(&priv->lec_arp_lock, flags); + if (lec_is_data_direct(vcc)) { + src = ((struct lecdatahdr_8023 *)skb->data)->h_source; + entry = lec_arp_find(priv, src); + if (entry && entry->vcc != vcc) { + lec_arp_remove(priv, entry); + lec_arp_put(entry); + } + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + + if (!(dst[0] & 0x01) && /* Never filter Multi/Broadcast */ + !priv->is_proxy && /* Proxy wants all the packets */ + memcmp(dst, dev->dev_addr, dev->addr_len)) { + dev_kfree_skb(skb); + return; + } + if (!hlist_empty(&priv->lec_arp_empty_ones)) + lec_arp_check_empties(priv, vcc, skb); + skb_pull(skb, 2); /* skip lec_id */ + skb->protocol = eth_type_trans(skb, dev); + dev->stats.rx_packets++; + dev->stats.rx_bytes += skb->len; + memset(ATM_SKB(skb), 0, sizeof(struct atm_skb_data)); + netif_rx(skb); + } +} + +static void lec_pop(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct lec_vcc_priv *vpriv = LEC_VCC_PRIV(vcc); + struct net_device *dev = skb->dev; + + if (vpriv == NULL) { + pr_info("vpriv = NULL!?!?!?\n"); + return; + } + + vpriv->old_pop(vcc, skb); + + if (vpriv->xoff && atm_may_send(vcc, 0)) { + vpriv->xoff = 0; + if (netif_running(dev) && netif_queue_stopped(dev)) + netif_wake_queue(dev); + } +} + +static int lec_vcc_attach(struct atm_vcc *vcc, void __user *arg) +{ + struct lec_vcc_priv *vpriv; + int bytes_left; + struct atmlec_ioc ioc_data; + + /* Lecd must be up in this case */ + bytes_left = copy_from_user(&ioc_data, arg, sizeof(struct atmlec_ioc)); + if (bytes_left != 0) + pr_info("copy from user failed for %d bytes\n", bytes_left); + if (ioc_data.dev_num < 0 || ioc_data.dev_num >= MAX_LEC_ITF) + return -EINVAL; + ioc_data.dev_num = array_index_nospec(ioc_data.dev_num, MAX_LEC_ITF); + if (!dev_lec[ioc_data.dev_num]) + return -EINVAL; + vpriv = kmalloc(sizeof(struct lec_vcc_priv), GFP_KERNEL); + if (!vpriv) + return -ENOMEM; + vpriv->xoff = 0; + vpriv->old_pop = vcc->pop; + vcc->user_back = vpriv; + vcc->pop = lec_pop; + lec_vcc_added(netdev_priv(dev_lec[ioc_data.dev_num]), + &ioc_data, vcc, vcc->push); + vcc->proto_data = dev_lec[ioc_data.dev_num]; + vcc->push = lec_push; + return 0; +} + +static int lec_mcast_attach(struct atm_vcc *vcc, int arg) +{ + if (arg < 0 || arg >= MAX_LEC_ITF) + return -EINVAL; + arg = array_index_nospec(arg, MAX_LEC_ITF); + if (!dev_lec[arg]) + return -EINVAL; + vcc->proto_data = dev_lec[arg]; + return lec_mcast_make(netdev_priv(dev_lec[arg]), vcc); +} + +/* Initialize device. 
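 * Reached through the ATMLEC_CTRL ioctl when the LANE daemon attaches:
 * allocates and registers lec<arg> on first use, records the control
 * VCC in priv->lecd and loads the default LANE parameters.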
*/ +static int lecd_attach(struct atm_vcc *vcc, int arg) +{ + int i; + struct lec_priv *priv; + + if (arg < 0) + arg = 0; + if (arg >= MAX_LEC_ITF) + return -EINVAL; + i = array_index_nospec(arg, MAX_LEC_ITF); + if (!dev_lec[i]) { + int size; + + size = sizeof(struct lec_priv); + dev_lec[i] = alloc_etherdev(size); + if (!dev_lec[i]) + return -ENOMEM; + dev_lec[i]->netdev_ops = &lec_netdev_ops; + dev_lec[i]->max_mtu = 18190; + snprintf(dev_lec[i]->name, IFNAMSIZ, "lec%d", i); + if (register_netdev(dev_lec[i])) { + free_netdev(dev_lec[i]); + return -EINVAL; + } + + priv = netdev_priv(dev_lec[i]); + } else { + priv = netdev_priv(dev_lec[i]); + if (priv->lecd) + return -EADDRINUSE; + } + lec_arp_init(priv); + priv->itfnum = i; /* LANE2 addition */ + priv->lecd = vcc; + vcc->dev = &lecatm_dev; + vcc_insert_socket(sk_atm(vcc)); + + vcc->proto_data = dev_lec[i]; + set_bit(ATM_VF_META, &vcc->flags); + set_bit(ATM_VF_READY, &vcc->flags); + + /* Set default values to these variables */ + priv->maximum_unknown_frame_count = 1; + priv->max_unknown_frame_time = (1 * HZ); + priv->vcc_timeout_period = (1200 * HZ); + priv->max_retry_count = 1; + priv->aging_time = (300 * HZ); + priv->forward_delay_time = (15 * HZ); + priv->topology_change = 0; + priv->arp_response_time = (1 * HZ); + priv->flush_timeout = (4 * HZ); + priv->path_switching_delay = (6 * HZ); + + if (dev_lec[i]->flags & IFF_UP) + netif_start_queue(dev_lec[i]); + __module_get(THIS_MODULE); + return i; +} + +#ifdef CONFIG_PROC_FS +static const char *lec_arp_get_status_string(unsigned char status) +{ + static const char *const lec_arp_status_string[] = { + "ESI_UNKNOWN ", + "ESI_ARP_PENDING ", + "ESI_VC_PENDING ", + " ", + "ESI_FLUSH_PENDING ", + "ESI_FORWARD_DIRECT" + }; + + if (status > ESI_FORWARD_DIRECT) + status = 3; /* ESI_UNDEFINED */ + return lec_arp_status_string[status]; +} + +static void lec_info(struct seq_file *seq, struct lec_arp_table *entry) +{ + seq_printf(seq, "%pM ", entry->mac_addr); + seq_printf(seq, "%*phN ", ATM_ESA_LEN, entry->atm_addr); + seq_printf(seq, "%s %4.4x", lec_arp_get_status_string(entry->status), + entry->flags & 0xffff); + if (entry->vcc) + seq_printf(seq, "%3d %3d ", entry->vcc->vpi, entry->vcc->vci); + else + seq_printf(seq, " "); + if (entry->recv_vcc) { + seq_printf(seq, " %3d %3d", entry->recv_vcc->vpi, + entry->recv_vcc->vci); + } + seq_putc(seq, '\n'); +} + +struct lec_state { + unsigned long flags; + struct lec_priv *locked; + struct hlist_node *node; + struct net_device *dev; + int itf; + int arp_table; + int misc_table; +}; + +static void *lec_tbl_walk(struct lec_state *state, struct hlist_head *tbl, + loff_t *l) +{ + struct hlist_node *e = state->node; + + if (!e) + e = tbl->first; + if (e == SEQ_START_TOKEN) { + e = tbl->first; + --*l; + } + + for (; e; e = e->next) { + if (--*l < 0) + break; + } + state->node = e; + + return (*l < 0) ? 
state : NULL; +} + +static void *lec_arp_walk(struct lec_state *state, loff_t *l, + struct lec_priv *priv) +{ + void *v = NULL; + int p; + + for (p = state->arp_table; p < LEC_ARP_TABLE_SIZE; p++) { + v = lec_tbl_walk(state, &priv->lec_arp_tables[p], l); + if (v) + break; + } + state->arp_table = p; + return v; +} + +static void *lec_misc_walk(struct lec_state *state, loff_t *l, + struct lec_priv *priv) +{ + struct hlist_head *lec_misc_tables[] = { + &priv->lec_arp_empty_ones, + &priv->lec_no_forward, + &priv->mcast_fwds + }; + void *v = NULL; + int q; + + for (q = state->misc_table; q < ARRAY_SIZE(lec_misc_tables); q++) { + v = lec_tbl_walk(state, lec_misc_tables[q], l); + if (v) + break; + } + state->misc_table = q; + return v; +} + +static void *lec_priv_walk(struct lec_state *state, loff_t *l, + struct lec_priv *priv) +{ + if (!state->locked) { + state->locked = priv; + spin_lock_irqsave(&priv->lec_arp_lock, state->flags); + } + if (!lec_arp_walk(state, l, priv) && !lec_misc_walk(state, l, priv)) { + spin_unlock_irqrestore(&priv->lec_arp_lock, state->flags); + state->locked = NULL; + /* Partial state reset for the next time we get called */ + state->arp_table = state->misc_table = 0; + } + return state->locked; +} + +static void *lec_itf_walk(struct lec_state *state, loff_t *l) +{ + struct net_device *dev; + void *v; + + dev = state->dev ? state->dev : dev_lec[state->itf]; + v = (dev && netdev_priv(dev)) ? + lec_priv_walk(state, l, netdev_priv(dev)) : NULL; + if (!v && dev) { + dev_put(dev); + /* Partial state reset for the next time we get called */ + dev = NULL; + } + state->dev = dev; + return v; +} + +static void *lec_get_idx(struct lec_state *state, loff_t l) +{ + void *v = NULL; + + for (; state->itf < MAX_LEC_ITF; state->itf++) { + v = lec_itf_walk(state, &l); + if (v) + break; + } + return v; +} + +static void *lec_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct lec_state *state = seq->private; + + state->itf = 0; + state->dev = NULL; + state->locked = NULL; + state->arp_table = 0; + state->misc_table = 0; + state->node = SEQ_START_TOKEN; + + return *pos ? 
lec_get_idx(state, *pos) : SEQ_START_TOKEN; +} + +static void lec_seq_stop(struct seq_file *seq, void *v) +{ + struct lec_state *state = seq->private; + + if (state->dev) { + spin_unlock_irqrestore(&state->locked->lec_arp_lock, + state->flags); + dev_put(state->dev); + } +} + +static void *lec_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct lec_state *state = seq->private; + + ++*pos; + return lec_get_idx(state, 1); +} + +static int lec_seq_show(struct seq_file *seq, void *v) +{ + static const char lec_banner[] = + "Itf MAC ATM destination" + " Status Flags " + "VPI/VCI Recv VPI/VCI\n"; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, lec_banner); + else { + struct lec_state *state = seq->private; + struct net_device *dev = state->dev; + struct lec_arp_table *entry = hlist_entry(state->node, + struct lec_arp_table, + next); + + seq_printf(seq, "%s ", dev->name); + lec_info(seq, entry); + } + return 0; +} + +static const struct seq_operations lec_seq_ops = { + .start = lec_seq_start, + .next = lec_seq_next, + .stop = lec_seq_stop, + .show = lec_seq_show, +}; +#endif + +static int lane_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct atm_vcc *vcc = ATM_SD(sock); + int err = 0; + + switch (cmd) { + case ATMLEC_CTRL: + case ATMLEC_MCAST: + case ATMLEC_DATA: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + break; + default: + return -ENOIOCTLCMD; + } + + switch (cmd) { + case ATMLEC_CTRL: + err = lecd_attach(vcc, (int)arg); + if (err >= 0) + sock->state = SS_CONNECTED; + break; + case ATMLEC_MCAST: + err = lec_mcast_attach(vcc, (int)arg); + break; + case ATMLEC_DATA: + err = lec_vcc_attach(vcc, (void __user *)arg); + break; + } + + return err; +} + +static struct atm_ioctl lane_ioctl_ops = { + .owner = THIS_MODULE, + .ioctl = lane_ioctl, +}; + +static int __init lane_module_init(void) +{ +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *p; + + p = proc_create_seq_private("lec", 0444, atm_proc_root, &lec_seq_ops, + sizeof(struct lec_state), NULL); + if (!p) { + pr_err("Unable to initialize /proc/net/atm/lec\n"); + return -ENOMEM; + } +#endif + + register_atm_ioctl(&lane_ioctl_ops); + pr_info("lec.c: initialized\n"); + return 0; +} + +static void __exit lane_module_cleanup(void) +{ + int i; + +#ifdef CONFIG_PROC_FS + remove_proc_entry("lec", atm_proc_root); +#endif + + deregister_atm_ioctl(&lane_ioctl_ops); + + for (i = 0; i < MAX_LEC_ITF; i++) { + if (dev_lec[i] != NULL) { + unregister_netdev(dev_lec[i]); + free_netdev(dev_lec[i]); + dev_lec[i] = NULL; + } + } +} + +module_init(lane_module_init); +module_exit(lane_module_cleanup); + +/* + * LANE2: 3.1.3, LE_RESOLVE.request + * Non force allocates memory and fills in *tlvs, fills in *sizeoftlvs. + * If sizeoftlvs == NULL the default TLVs associated with this + * lec will be used. 
+ * If dst_mac == NULL, targetless LE_ARP will be sent + */ +static int lane2_resolve(struct net_device *dev, const u8 *dst_mac, int force, + u8 **tlvs, u32 *sizeoftlvs) +{ + unsigned long flags; + struct lec_priv *priv = netdev_priv(dev); + struct lec_arp_table *table; + struct sk_buff *skb; + int retval; + + if (force == 0) { + spin_lock_irqsave(&priv->lec_arp_lock, flags); + table = lec_arp_find(priv, dst_mac); + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + if (table == NULL) + return -1; + + *tlvs = kmemdup(table->tlvs, table->sizeoftlvs, GFP_ATOMIC); + if (*tlvs == NULL) + return -1; + + *sizeoftlvs = table->sizeoftlvs; + + return 0; + } + + if (sizeoftlvs == NULL) + retval = send_to_lecd(priv, l_arp_xmt, dst_mac, NULL, NULL); + + else { + skb = alloc_skb(*sizeoftlvs, GFP_ATOMIC); + if (skb == NULL) + return -1; + skb->len = *sizeoftlvs; + skb_copy_to_linear_data(skb, *tlvs, *sizeoftlvs); + retval = send_to_lecd(priv, l_arp_xmt, dst_mac, NULL, skb); + } + return retval; +} + +/* + * LANE2: 3.1.4, LE_ASSOCIATE.request + * Associate the *tlvs with the *lan_dst address. + * Will overwrite any previous association + * Returns 1 for success, 0 for failure (out of memory) + * + */ +static int lane2_associate_req(struct net_device *dev, const u8 *lan_dst, + const u8 *tlvs, u32 sizeoftlvs) +{ + int retval; + struct sk_buff *skb; + struct lec_priv *priv = netdev_priv(dev); + + if (!ether_addr_equal(lan_dst, dev->dev_addr)) + return 0; /* not our mac address */ + + kfree(priv->tlvs); /* NULL if there was no previous association */ + + priv->tlvs = kmemdup(tlvs, sizeoftlvs, GFP_KERNEL); + if (priv->tlvs == NULL) + return 0; + priv->sizeoftlvs = sizeoftlvs; + + skb = alloc_skb(sizeoftlvs, GFP_ATOMIC); + if (skb == NULL) + return 0; + skb->len = sizeoftlvs; + skb_copy_to_linear_data(skb, tlvs, sizeoftlvs); + retval = send_to_lecd(priv, l_associate_req, NULL, NULL, skb); + if (retval != 0) + pr_info("lec.c: lane2_associate_req() failed\n"); + /* + * If the previous association has changed we must + * somehow notify other LANE entities about the change + */ + return 1; +} + +/* + * LANE2: 3.1.5, LE_ASSOCIATE.indication + * + */ +static void lane2_associate_ind(struct net_device *dev, const u8 *mac_addr, + const u8 *tlvs, u32 sizeoftlvs) +{ +#if 0 + int i = 0; +#endif + struct lec_priv *priv = netdev_priv(dev); +#if 0 /* + * Why have the TLVs in LE_ARP entries + * since we do not use them? When you + * uncomment this code, make sure the + * TLVs get freed when entry is killed + */ + struct lec_arp_table *entry = lec_arp_find(priv, mac_addr); + + if (entry == NULL) + return; /* should not happen */ + + kfree(entry->tlvs); + + entry->tlvs = kmemdup(tlvs, sizeoftlvs, GFP_KERNEL); + if (entry->tlvs == NULL) + return; + entry->sizeoftlvs = sizeoftlvs; +#endif +#if 0 + pr_info("\n"); + pr_info("dump of tlvs, sizeoftlvs=%d\n", sizeoftlvs); + while (i < sizeoftlvs) + pr_cont("%02x ", tlvs[i++]); + + pr_cont("\n"); +#endif + + /* tell MPOA about the TLVs we saw */ + if (priv->lane2_ops && priv->lane2_ops->associate_indicator) { + priv->lane2_ops->associate_indicator(dev, mac_addr, + tlvs, sizeoftlvs); + } +} + +/* + * Here starts what used to lec_arpc.c + * + * lec_arpc.c was added here when making + * lane client modular. October 1997 + */ + +#include +#include +#include +#include +#include +#include + +#if 0 +#define pr_debug(format, args...) 
+/* + #define pr_debug printk +*/ +#endif +#define DEBUG_ARP_TABLE 0 + +#define LEC_ARP_REFRESH_INTERVAL (3*HZ) + +static void lec_arp_check_expire(struct work_struct *work); +static void lec_arp_expire_arp(struct timer_list *t); + +/* + * Arp table funcs + */ + +#define HASH(ch) (ch & (LEC_ARP_TABLE_SIZE - 1)) + +/* + * Initialization of arp-cache + */ +static void lec_arp_init(struct lec_priv *priv) +{ + unsigned short i; + + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) + INIT_HLIST_HEAD(&priv->lec_arp_tables[i]); + INIT_HLIST_HEAD(&priv->lec_arp_empty_ones); + INIT_HLIST_HEAD(&priv->lec_no_forward); + INIT_HLIST_HEAD(&priv->mcast_fwds); + spin_lock_init(&priv->lec_arp_lock); + INIT_DELAYED_WORK(&priv->lec_arp_work, lec_arp_check_expire); + schedule_delayed_work(&priv->lec_arp_work, LEC_ARP_REFRESH_INTERVAL); +} + +static void lec_arp_clear_vccs(struct lec_arp_table *entry) +{ + if (entry->vcc) { + struct atm_vcc *vcc = entry->vcc; + struct lec_vcc_priv *vpriv = LEC_VCC_PRIV(vcc); + struct net_device *dev = (struct net_device *)vcc->proto_data; + + vcc->pop = vpriv->old_pop; + if (vpriv->xoff) + netif_wake_queue(dev); + kfree(vpriv); + vcc->user_back = NULL; + vcc->push = entry->old_push; + vcc_release_async(vcc, -EPIPE); + entry->vcc = NULL; + } + if (entry->recv_vcc) { + struct atm_vcc *vcc = entry->recv_vcc; + struct lec_vcc_priv *vpriv = LEC_VCC_PRIV(vcc); + + kfree(vpriv); + vcc->user_back = NULL; + + entry->recv_vcc->push = entry->old_recv_push; + vcc_release_async(entry->recv_vcc, -EPIPE); + entry->recv_vcc = NULL; + } +} + +/* + * Insert entry to lec_arp_table + * LANE2: Add to the end of the list to satisfy 8.1.13 + */ +static inline void +lec_arp_add(struct lec_priv *priv, struct lec_arp_table *entry) +{ + struct hlist_head *tmp; + + tmp = &priv->lec_arp_tables[HASH(entry->mac_addr[ETH_ALEN - 1])]; + hlist_add_head(&entry->next, tmp); + + pr_debug("Added entry:%pM\n", entry->mac_addr); +} + +/* + * Remove entry from lec_arp_table + */ +static int +lec_arp_remove(struct lec_priv *priv, struct lec_arp_table *to_remove) +{ + struct lec_arp_table *entry; + int i, remove_vcc = 1; + + if (!to_remove) + return -1; + + hlist_del(&to_remove->next); + del_timer(&to_remove->timer); + + /* + * If this is the only MAC connected to this VCC, + * also tear down the VCC + */ + if (to_remove->status >= ESI_FLUSH_PENDING) { + /* + * ESI_FLUSH_PENDING, ESI_FORWARD_DIRECT + */ + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry(entry, + &priv->lec_arp_tables[i], next) { + if (memcmp(to_remove->atm_addr, + entry->atm_addr, ATM_ESA_LEN) == 0) { + remove_vcc = 0; + break; + } + } + } + if (remove_vcc) + lec_arp_clear_vccs(to_remove); + } + skb_queue_purge(&to_remove->tx_wait); /* FIXME: good place for this? 
*/ + + pr_debug("Removed entry:%pM\n", to_remove->mac_addr); + return 0; +} + +#if DEBUG_ARP_TABLE +static const char *get_status_string(unsigned char st) +{ + switch (st) { + case ESI_UNKNOWN: + return "ESI_UNKNOWN"; + case ESI_ARP_PENDING: + return "ESI_ARP_PENDING"; + case ESI_VC_PENDING: + return "ESI_VC_PENDING"; + case ESI_FLUSH_PENDING: + return "ESI_FLUSH_PENDING"; + case ESI_FORWARD_DIRECT: + return "ESI_FORWARD_DIRECT"; + } + return ""; +} + +static void dump_arp_table(struct lec_priv *priv) +{ + struct lec_arp_table *rulla; + char buf[256]; + int i, offset; + + pr_info("Dump %p:\n", priv); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry(rulla, + &priv->lec_arp_tables[i], next) { + offset = 0; + offset += sprintf(buf, "%d: %p\n", i, rulla); + offset += sprintf(buf + offset, "Mac: %pM ", + rulla->mac_addr); + offset += sprintf(buf + offset, "Atm: %*ph ", ATM_ESA_LEN, + rulla->atm_addr); + offset += sprintf(buf + offset, + "Vcc vpi:%d vci:%d, Recv_vcc vpi:%d vci:%d Last_used:%lx, Timestamp:%lx, No_tries:%d ", + rulla->vcc ? rulla->vcc->vpi : 0, + rulla->vcc ? rulla->vcc->vci : 0, + rulla->recv_vcc ? rulla->recv_vcc-> + vpi : 0, + rulla->recv_vcc ? rulla->recv_vcc-> + vci : 0, rulla->last_used, + rulla->timestamp, rulla->no_tries); + offset += + sprintf(buf + offset, + "Flags:%x, Packets_flooded:%x, Status: %s ", + rulla->flags, rulla->packets_flooded, + get_status_string(rulla->status)); + pr_info("%s\n", buf); + } + } + + if (!hlist_empty(&priv->lec_no_forward)) + pr_info("No forward\n"); + hlist_for_each_entry(rulla, &priv->lec_no_forward, next) { + offset = 0; + offset += sprintf(buf + offset, "Mac: %pM ", rulla->mac_addr); + offset += sprintf(buf + offset, "Atm: %*ph ", ATM_ESA_LEN, + rulla->atm_addr); + offset += sprintf(buf + offset, + "Vcc vpi:%d vci:%d, Recv_vcc vpi:%d vci:%d Last_used:%lx, Timestamp:%lx, No_tries:%d ", + rulla->vcc ? rulla->vcc->vpi : 0, + rulla->vcc ? rulla->vcc->vci : 0, + rulla->recv_vcc ? rulla->recv_vcc->vpi : 0, + rulla->recv_vcc ? rulla->recv_vcc->vci : 0, + rulla->last_used, + rulla->timestamp, rulla->no_tries); + offset += sprintf(buf + offset, + "Flags:%x, Packets_flooded:%x, Status: %s ", + rulla->flags, rulla->packets_flooded, + get_status_string(rulla->status)); + pr_info("%s\n", buf); + } + + if (!hlist_empty(&priv->lec_arp_empty_ones)) + pr_info("Empty ones\n"); + hlist_for_each_entry(rulla, &priv->lec_arp_empty_ones, next) { + offset = 0; + offset += sprintf(buf + offset, "Mac: %pM ", rulla->mac_addr); + offset += sprintf(buf + offset, "Atm: %*ph ", ATM_ESA_LEN, + rulla->atm_addr); + offset += sprintf(buf + offset, + "Vcc vpi:%d vci:%d, Recv_vcc vpi:%d vci:%d Last_used:%lx, Timestamp:%lx, No_tries:%d ", + rulla->vcc ? rulla->vcc->vpi : 0, + rulla->vcc ? rulla->vcc->vci : 0, + rulla->recv_vcc ? rulla->recv_vcc->vpi : 0, + rulla->recv_vcc ? 
rulla->recv_vcc->vci : 0, + rulla->last_used, + rulla->timestamp, rulla->no_tries); + offset += sprintf(buf + offset, + "Flags:%x, Packets_flooded:%x, Status: %s ", + rulla->flags, rulla->packets_flooded, + get_status_string(rulla->status)); + pr_info("%s", buf); + } + + if (!hlist_empty(&priv->mcast_fwds)) + pr_info("Multicast Forward VCCs\n"); + hlist_for_each_entry(rulla, &priv->mcast_fwds, next) { + offset = 0; + offset += sprintf(buf + offset, "Mac: %pM ", rulla->mac_addr); + offset += sprintf(buf + offset, "Atm: %*ph ", ATM_ESA_LEN, + rulla->atm_addr); + offset += sprintf(buf + offset, + "Vcc vpi:%d vci:%d, Recv_vcc vpi:%d vci:%d Last_used:%lx, Timestamp:%lx, No_tries:%d ", + rulla->vcc ? rulla->vcc->vpi : 0, + rulla->vcc ? rulla->vcc->vci : 0, + rulla->recv_vcc ? rulla->recv_vcc->vpi : 0, + rulla->recv_vcc ? rulla->recv_vcc->vci : 0, + rulla->last_used, + rulla->timestamp, rulla->no_tries); + offset += sprintf(buf + offset, + "Flags:%x, Packets_flooded:%x, Status: %s ", + rulla->flags, rulla->packets_flooded, + get_status_string(rulla->status)); + pr_info("%s\n", buf); + } + +} +#else +#define dump_arp_table(priv) do { } while (0) +#endif + +/* + * Destruction of arp-cache + */ +static void lec_arp_destroy(struct lec_priv *priv) +{ + unsigned long flags; + struct hlist_node *next; + struct lec_arp_table *entry; + int i; + + cancel_delayed_work_sync(&priv->lec_arp_work); + + /* + * Remove all entries + */ + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_tables[i], next) { + lec_arp_remove(priv, entry); + lec_arp_put(entry); + } + INIT_HLIST_HEAD(&priv->lec_arp_tables[i]); + } + + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_empty_ones, next) { + del_timer_sync(&entry->timer); + lec_arp_clear_vccs(entry); + hlist_del(&entry->next); + lec_arp_put(entry); + } + INIT_HLIST_HEAD(&priv->lec_arp_empty_ones); + + hlist_for_each_entry_safe(entry, next, + &priv->lec_no_forward, next) { + del_timer_sync(&entry->timer); + lec_arp_clear_vccs(entry); + hlist_del(&entry->next); + lec_arp_put(entry); + } + INIT_HLIST_HEAD(&priv->lec_no_forward); + + hlist_for_each_entry_safe(entry, next, &priv->mcast_fwds, next) { + /* No timer, LANEv2 7.1.20 and 2.3.5.3 */ + lec_arp_clear_vccs(entry); + hlist_del(&entry->next); + lec_arp_put(entry); + } + INIT_HLIST_HEAD(&priv->mcast_fwds); + priv->mcast_vcc = NULL; + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); +} + +/* + * Find entry by mac_address + */ +static struct lec_arp_table *lec_arp_find(struct lec_priv *priv, + const unsigned char *mac_addr) +{ + struct hlist_head *head; + struct lec_arp_table *entry; + + pr_debug("%pM\n", mac_addr); + + head = &priv->lec_arp_tables[HASH(mac_addr[ETH_ALEN - 1])]; + hlist_for_each_entry(entry, head, next) { + if (ether_addr_equal(mac_addr, entry->mac_addr)) + return entry; + } + return NULL; +} + +static struct lec_arp_table *make_entry(struct lec_priv *priv, + const unsigned char *mac_addr) +{ + struct lec_arp_table *to_return; + + to_return = kzalloc(sizeof(struct lec_arp_table), GFP_ATOMIC); + if (!to_return) + return NULL; + ether_addr_copy(to_return->mac_addr, mac_addr); + INIT_HLIST_NODE(&to_return->next); + timer_setup(&to_return->timer, lec_arp_expire_arp, 0); + to_return->last_used = jiffies; + to_return->priv = priv; + skb_queue_head_init(&to_return->tx_wait); + refcount_set(&to_return->usage, 1); + return to_return; +} + +/* Arp sent timer expired */ +static void lec_arp_expire_arp(struct 
timer_list *t) +{ + struct lec_arp_table *entry; + + entry = from_timer(entry, t, timer); + + pr_debug("\n"); + if (entry->status == ESI_ARP_PENDING) { + if (entry->no_tries <= entry->priv->max_retry_count) { + if (entry->is_rdesc) + send_to_lecd(entry->priv, l_rdesc_arp_xmt, + entry->mac_addr, NULL, NULL); + else + send_to_lecd(entry->priv, l_arp_xmt, + entry->mac_addr, NULL, NULL); + entry->no_tries++; + } + mod_timer(&entry->timer, jiffies + (1 * HZ)); + } +} + +/* Unknown/unused vcc expire, remove associated entry */ +static void lec_arp_expire_vcc(struct timer_list *t) +{ + unsigned long flags; + struct lec_arp_table *to_remove = from_timer(to_remove, t, timer); + struct lec_priv *priv = to_remove->priv; + + del_timer(&to_remove->timer); + + pr_debug("%p %p: vpi:%d vci:%d\n", + to_remove, priv, + to_remove->vcc ? to_remove->recv_vcc->vpi : 0, + to_remove->vcc ? to_remove->recv_vcc->vci : 0); + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + hlist_del(&to_remove->next); + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + + lec_arp_clear_vccs(to_remove); + lec_arp_put(to_remove); +} + +static bool __lec_arp_check_expire(struct lec_arp_table *entry, + unsigned long now, + struct lec_priv *priv) +{ + unsigned long time_to_check; + + if ((entry->flags) & LEC_REMOTE_FLAG && priv->topology_change) + time_to_check = priv->forward_delay_time; + else + time_to_check = priv->aging_time; + + pr_debug("About to expire: %lx - %lx > %lx\n", + now, entry->last_used, time_to_check); + if (time_after(now, entry->last_used + time_to_check) && + !(entry->flags & LEC_PERMANENT_FLAG) && + !(entry->mac_addr[0] & 0x01)) { /* LANE2: 7.1.20 */ + /* Remove entry */ + pr_debug("Entry timed out\n"); + lec_arp_remove(priv, entry); + lec_arp_put(entry); + } else { + /* Something else */ + if ((entry->status == ESI_VC_PENDING || + entry->status == ESI_ARP_PENDING) && + time_after_eq(now, entry->timestamp + + priv->max_unknown_frame_time)) { + entry->timestamp = jiffies; + entry->packets_flooded = 0; + if (entry->status == ESI_VC_PENDING) + send_to_lecd(priv, l_svc_setup, + entry->mac_addr, + entry->atm_addr, + NULL); + } + if (entry->status == ESI_FLUSH_PENDING && + time_after_eq(now, entry->timestamp + + priv->path_switching_delay)) { + lec_arp_hold(entry); + return true; + } + } + + return false; +} +/* + * Expire entries. + * 1. Re-set timer + * 2. For each entry, delete entries that have aged past the age limit. + * 3. For each entry, depending on the status of the entry, perform + * the following maintenance. + * a. If status is ESI_VC_PENDING or ESI_ARP_PENDING then if the + * tick_count is above the max_unknown_frame_time, clear + * the tick_count to zero and clear the packets_flooded counter + * to zero. This supports the packet rate limit per address + * while flooding unknowns. + * b. If the status is ESI_FLUSH_PENDING and the tick_count is greater + * than or equal to the path_switching_delay, change the status + * to ESI_FORWARD_DIRECT. This causes the flush period to end + * regardless of the progress of the flush protocol. 
+ */ +static void lec_arp_check_expire(struct work_struct *work) +{ + unsigned long flags; + struct lec_priv *priv = + container_of(work, struct lec_priv, lec_arp_work.work); + struct hlist_node *next; + struct lec_arp_table *entry; + unsigned long now; + int i; + + pr_debug("%p\n", priv); + now = jiffies; +restart: + spin_lock_irqsave(&priv->lec_arp_lock, flags); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_tables[i], next) { + if (__lec_arp_check_expire(entry, now, priv)) { + struct sk_buff *skb; + struct atm_vcc *vcc = entry->vcc; + + spin_unlock_irqrestore(&priv->lec_arp_lock, + flags); + while ((skb = skb_dequeue(&entry->tx_wait))) + lec_send(vcc, skb); + entry->last_used = jiffies; + entry->status = ESI_FORWARD_DIRECT; + lec_arp_put(entry); + + goto restart; + } + } + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + + schedule_delayed_work(&priv->lec_arp_work, LEC_ARP_REFRESH_INTERVAL); +} + +/* + * Try to find vcc where mac_address is attached. + * + */ +static struct atm_vcc *lec_arp_resolve(struct lec_priv *priv, + const unsigned char *mac_to_find, + int is_rdesc, + struct lec_arp_table **ret_entry) +{ + unsigned long flags; + struct lec_arp_table *entry; + struct atm_vcc *found; + + if (mac_to_find[0] & 0x01) { + switch (priv->lane_version) { + case 1: + return priv->mcast_vcc; + case 2: /* LANE2 wants arp for multicast addresses */ + if (ether_addr_equal(mac_to_find, bus_mac)) + return priv->mcast_vcc; + break; + default: + break; + } + } + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + entry = lec_arp_find(priv, mac_to_find); + + if (entry) { + if (entry->status == ESI_FORWARD_DIRECT) { + /* Connection Ok */ + entry->last_used = jiffies; + lec_arp_hold(entry); + *ret_entry = entry; + found = entry->vcc; + goto out; + } + /* + * If the LE_ARP cache entry is still pending, reset count to 0 + * so another LE_ARP request can be made for this frame. + */ + if (entry->status == ESI_ARP_PENDING) + entry->no_tries = 0; + /* + * Data direct VC not yet set up, check to see if the unknown + * frame count is greater than the limit. If the limit has + * not been reached, allow the caller to send packet to + * BUS. + */ + if (entry->status != ESI_FLUSH_PENDING && + entry->packets_flooded < + priv->maximum_unknown_frame_count) { + entry->packets_flooded++; + pr_debug("Flooding..\n"); + found = priv->mcast_vcc; + goto out; + } + /* + * We got here because entry->status == ESI_FLUSH_PENDING + * or BUS flood limit was reached for an entry which is + * in ESI_ARP_PENDING or ESI_VC_PENDING state. 
+ */ + lec_arp_hold(entry); + *ret_entry = entry; + pr_debug("entry->status %d entry->vcc %p\n", entry->status, + entry->vcc); + found = NULL; + } else { + /* No matching entry was found */ + entry = make_entry(priv, mac_to_find); + pr_debug("Making entry\n"); + if (!entry) { + found = priv->mcast_vcc; + goto out; + } + lec_arp_add(priv, entry); + /* We want arp-request(s) to be sent */ + entry->packets_flooded = 1; + entry->status = ESI_ARP_PENDING; + entry->no_tries = 1; + entry->last_used = entry->timestamp = jiffies; + entry->is_rdesc = is_rdesc; + if (entry->is_rdesc) + send_to_lecd(priv, l_rdesc_arp_xmt, mac_to_find, NULL, + NULL); + else + send_to_lecd(priv, l_arp_xmt, mac_to_find, NULL, NULL); + entry->timer.expires = jiffies + (1 * HZ); + entry->timer.function = lec_arp_expire_arp; + add_timer(&entry->timer); + found = priv->mcast_vcc; + } + +out: + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + return found; +} + +static int +lec_addr_delete(struct lec_priv *priv, const unsigned char *atm_addr, + unsigned long permanent) +{ + unsigned long flags; + struct hlist_node *next; + struct lec_arp_table *entry; + int i; + + pr_debug("\n"); + spin_lock_irqsave(&priv->lec_arp_lock, flags); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_tables[i], next) { + if (!memcmp(atm_addr, entry->atm_addr, ATM_ESA_LEN) && + (permanent || + !(entry->flags & LEC_PERMANENT_FLAG))) { + lec_arp_remove(priv, entry); + lec_arp_put(entry); + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + return 0; + } + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + return -1; +} + +/* + * Notifies: Response to arp_request (atm_addr != NULL) + */ +static void +lec_arp_update(struct lec_priv *priv, const unsigned char *mac_addr, + const unsigned char *atm_addr, unsigned long remoteflag, + unsigned int targetless_le_arp) +{ + unsigned long flags; + struct hlist_node *next; + struct lec_arp_table *entry, *tmp; + int i; + + pr_debug("%smac:%pM\n", + (targetless_le_arp) ? "targetless " : "", mac_addr); + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + entry = lec_arp_find(priv, mac_addr); + if (entry == NULL && targetless_le_arp) + goto out; /* + * LANE2: ignore targetless LE_ARPs for which + * we have no entry in the cache. 
7.1.30 + */ + if (!hlist_empty(&priv->lec_arp_empty_ones)) { + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_empty_ones, next) { + if (memcmp(entry->atm_addr, atm_addr, ATM_ESA_LEN) == 0) { + hlist_del(&entry->next); + del_timer(&entry->timer); + tmp = lec_arp_find(priv, mac_addr); + if (tmp) { + del_timer(&tmp->timer); + tmp->status = ESI_FORWARD_DIRECT; + memcpy(tmp->atm_addr, atm_addr, ATM_ESA_LEN); + tmp->vcc = entry->vcc; + tmp->old_push = entry->old_push; + tmp->last_used = jiffies; + del_timer(&entry->timer); + lec_arp_put(entry); + entry = tmp; + } else { + entry->status = ESI_FORWARD_DIRECT; + ether_addr_copy(entry->mac_addr, + mac_addr); + entry->last_used = jiffies; + lec_arp_add(priv, entry); + } + if (remoteflag) + entry->flags |= LEC_REMOTE_FLAG; + else + entry->flags &= ~LEC_REMOTE_FLAG; + pr_debug("After update\n"); + dump_arp_table(priv); + goto out; + } + } + } + + entry = lec_arp_find(priv, mac_addr); + if (!entry) { + entry = make_entry(priv, mac_addr); + if (!entry) + goto out; + entry->status = ESI_UNKNOWN; + lec_arp_add(priv, entry); + /* Temporary, changes before end of function */ + } + memcpy(entry->atm_addr, atm_addr, ATM_ESA_LEN); + del_timer(&entry->timer); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry(tmp, + &priv->lec_arp_tables[i], next) { + if (entry != tmp && + !memcmp(tmp->atm_addr, atm_addr, ATM_ESA_LEN)) { + /* Vcc to this host exists */ + if (tmp->status > ESI_VC_PENDING) { + /* + * ESI_FLUSH_PENDING, + * ESI_FORWARD_DIRECT + */ + entry->vcc = tmp->vcc; + entry->old_push = tmp->old_push; + } + entry->status = tmp->status; + break; + } + } + } + if (remoteflag) + entry->flags |= LEC_REMOTE_FLAG; + else + entry->flags &= ~LEC_REMOTE_FLAG; + if (entry->status == ESI_ARP_PENDING || entry->status == ESI_UNKNOWN) { + entry->status = ESI_VC_PENDING; + send_to_lecd(priv, l_svc_setup, entry->mac_addr, atm_addr, NULL); + } + pr_debug("After update2\n"); + dump_arp_table(priv); +out: + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); +} + +/* + * Notifies: Vcc setup ready + */ +static void +lec_vcc_added(struct lec_priv *priv, const struct atmlec_ioc *ioc_data, + struct atm_vcc *vcc, + void (*old_push) (struct atm_vcc *vcc, struct sk_buff *skb)) +{ + unsigned long flags; + struct lec_arp_table *entry; + int i, found_entry = 0; + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + /* Vcc for Multicast Forward. No timer, LANEv2 7.1.20 and 2.3.5.3 */ + if (ioc_data->receive == 2) { + pr_debug("LEC_ARP: Attaching mcast forward\n"); +#if 0 + entry = lec_arp_find(priv, bus_mac); + if (!entry) { + pr_info("LEC_ARP: Multicast entry not found!\n"); + goto out; + } + memcpy(entry->atm_addr, ioc_data->atm_addr, ATM_ESA_LEN); + entry->recv_vcc = vcc; + entry->old_recv_push = old_push; +#endif + entry = make_entry(priv, bus_mac); + if (entry == NULL) + goto out; + del_timer(&entry->timer); + memcpy(entry->atm_addr, ioc_data->atm_addr, ATM_ESA_LEN); + entry->recv_vcc = vcc; + entry->old_recv_push = old_push; + hlist_add_head(&entry->next, &priv->mcast_fwds); + goto out; + } else if (ioc_data->receive == 1) { + /* + * Vcc which we don't want to make default vcc, + * attach it anyway. 
+ */ + pr_debug("LEC_ARP:Attaching data direct, not default: %*phN\n", + ATM_ESA_LEN, ioc_data->atm_addr); + entry = make_entry(priv, bus_mac); + if (entry == NULL) + goto out; + memcpy(entry->atm_addr, ioc_data->atm_addr, ATM_ESA_LEN); + eth_zero_addr(entry->mac_addr); + entry->recv_vcc = vcc; + entry->old_recv_push = old_push; + entry->status = ESI_UNKNOWN; + entry->timer.expires = jiffies + priv->vcc_timeout_period; + entry->timer.function = lec_arp_expire_vcc; + hlist_add_head(&entry->next, &priv->lec_no_forward); + add_timer(&entry->timer); + dump_arp_table(priv); + goto out; + } + pr_debug("LEC_ARP:Attaching data direct, default: %*phN\n", + ATM_ESA_LEN, ioc_data->atm_addr); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry(entry, + &priv->lec_arp_tables[i], next) { + if (memcmp + (ioc_data->atm_addr, entry->atm_addr, + ATM_ESA_LEN) == 0) { + pr_debug("LEC_ARP: Attaching data direct\n"); + pr_debug("Currently -> Vcc: %d, Rvcc:%d\n", + entry->vcc ? entry->vcc->vci : 0, + entry->recv_vcc ? entry->recv_vcc-> + vci : 0); + found_entry = 1; + del_timer(&entry->timer); + entry->vcc = vcc; + entry->old_push = old_push; + if (entry->status == ESI_VC_PENDING) { + if (priv->maximum_unknown_frame_count + == 0) + entry->status = + ESI_FORWARD_DIRECT; + else { + entry->timestamp = jiffies; + entry->status = + ESI_FLUSH_PENDING; +#if 0 + send_to_lecd(priv, l_flush_xmt, + NULL, + entry->atm_addr, + NULL); +#endif + } + } else { + /* + * They were forming a connection + * to us, and we to them. Our + * ATM address is numerically lower + * than theirs, so we make connection + * we formed into default VCC (8.1.11). + * Connection they made gets torn + * down. This might confuse some + * clients. Can be changed if + * someone reports trouble... 
+ */ + ; + } + } + } + } + if (found_entry) { + pr_debug("After vcc was added\n"); + dump_arp_table(priv); + goto out; + } + /* + * Not found, snatch address from first data packet that arrives + * from this vcc + */ + entry = make_entry(priv, bus_mac); + if (!entry) + goto out; + entry->vcc = vcc; + entry->old_push = old_push; + memcpy(entry->atm_addr, ioc_data->atm_addr, ATM_ESA_LEN); + eth_zero_addr(entry->mac_addr); + entry->status = ESI_UNKNOWN; + hlist_add_head(&entry->next, &priv->lec_arp_empty_ones); + entry->timer.expires = jiffies + priv->vcc_timeout_period; + entry->timer.function = lec_arp_expire_vcc; + add_timer(&entry->timer); + pr_debug("After vcc was added\n"); + dump_arp_table(priv); +out: + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); +} + +static void lec_flush_complete(struct lec_priv *priv, unsigned long tran_id) +{ + unsigned long flags; + struct lec_arp_table *entry; + int i; + + pr_debug("%lx\n", tran_id); +restart: + spin_lock_irqsave(&priv->lec_arp_lock, flags); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) { + hlist_for_each_entry(entry, + &priv->lec_arp_tables[i], next) { + if (entry->flush_tran_id == tran_id && + entry->status == ESI_FLUSH_PENDING) { + struct sk_buff *skb; + struct atm_vcc *vcc = entry->vcc; + + lec_arp_hold(entry); + spin_unlock_irqrestore(&priv->lec_arp_lock, + flags); + while ((skb = skb_dequeue(&entry->tx_wait))) + lec_send(vcc, skb); + entry->last_used = jiffies; + entry->status = ESI_FORWARD_DIRECT; + lec_arp_put(entry); + pr_debug("LEC_ARP: Flushed\n"); + goto restart; + } + } + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + dump_arp_table(priv); +} + +static void +lec_set_flush_tran_id(struct lec_priv *priv, + const unsigned char *atm_addr, unsigned long tran_id) +{ + unsigned long flags; + struct lec_arp_table *entry; + int i; + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) + hlist_for_each_entry(entry, + &priv->lec_arp_tables[i], next) { + if (!memcmp(atm_addr, entry->atm_addr, ATM_ESA_LEN)) { + entry->flush_tran_id = tran_id; + pr_debug("Set flush transaction id to %lx for %p\n", + tran_id, entry); + } + } + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); +} + +static int lec_mcast_make(struct lec_priv *priv, struct atm_vcc *vcc) +{ + unsigned long flags; + unsigned char mac_addr[] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + }; + struct lec_arp_table *to_add; + struct lec_vcc_priv *vpriv; + int err = 0; + + vpriv = kmalloc(sizeof(struct lec_vcc_priv), GFP_KERNEL); + if (!vpriv) + return -ENOMEM; + vpriv->xoff = 0; + vpriv->old_pop = vcc->pop; + vcc->user_back = vpriv; + vcc->pop = lec_pop; + spin_lock_irqsave(&priv->lec_arp_lock, flags); + to_add = make_entry(priv, mac_addr); + if (!to_add) { + vcc->pop = vpriv->old_pop; + kfree(vpriv); + err = -ENOMEM; + goto out; + } + memcpy(to_add->atm_addr, vcc->remote.sas_addr.prv, ATM_ESA_LEN); + to_add->status = ESI_FORWARD_DIRECT; + to_add->flags |= LEC_PERMANENT_FLAG; + to_add->vcc = vcc; + to_add->old_push = vcc->push; + vcc->push = lec_push; + priv->mcast_vcc = vcc; + lec_arp_add(priv, to_add); +out: + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + return err; +} + +static void lec_vcc_close(struct lec_priv *priv, struct atm_vcc *vcc) +{ + unsigned long flags; + struct hlist_node *next; + struct lec_arp_table *entry; + int i; + + pr_debug("LEC_ARP: lec_vcc_close vpi:%d vci:%d\n", vcc->vpi, vcc->vci); + dump_arp_table(priv); + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + + for (i = 0; i < LEC_ARP_TABLE_SIZE; i++) 
{ + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_tables[i], next) { + if (vcc == entry->vcc) { + lec_arp_remove(priv, entry); + lec_arp_put(entry); + if (priv->mcast_vcc == vcc) + priv->mcast_vcc = NULL; + } + } + } + + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_empty_ones, next) { + if (entry->vcc == vcc) { + lec_arp_clear_vccs(entry); + del_timer(&entry->timer); + hlist_del(&entry->next); + lec_arp_put(entry); + } + } + + hlist_for_each_entry_safe(entry, next, + &priv->lec_no_forward, next) { + if (entry->recv_vcc == vcc) { + lec_arp_clear_vccs(entry); + del_timer(&entry->timer); + hlist_del(&entry->next); + lec_arp_put(entry); + } + } + + hlist_for_each_entry_safe(entry, next, &priv->mcast_fwds, next) { + if (entry->recv_vcc == vcc) { + lec_arp_clear_vccs(entry); + /* No timer, LANEv2 7.1.20 and 2.3.5.3 */ + hlist_del(&entry->next); + lec_arp_put(entry); + } + } + + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); + dump_arp_table(priv); +} + +static void +lec_arp_check_empties(struct lec_priv *priv, + struct atm_vcc *vcc, struct sk_buff *skb) +{ + unsigned long flags; + struct hlist_node *next; + struct lec_arp_table *entry, *tmp; + struct lecdatahdr_8023 *hdr = (struct lecdatahdr_8023 *)skb->data; + unsigned char *src = hdr->h_source; + + spin_lock_irqsave(&priv->lec_arp_lock, flags); + hlist_for_each_entry_safe(entry, next, + &priv->lec_arp_empty_ones, next) { + if (vcc == entry->vcc) { + del_timer(&entry->timer); + ether_addr_copy(entry->mac_addr, src); + entry->status = ESI_FORWARD_DIRECT; + entry->last_used = jiffies; + /* We might have got an entry */ + tmp = lec_arp_find(priv, src); + if (tmp) { + lec_arp_remove(priv, tmp); + lec_arp_put(tmp); + } + hlist_del(&entry->next); + lec_arp_add(priv, entry); + goto out; + } + } + pr_debug("LEC_ARP: Arp_check_empties: entry not found!\n"); +out: + spin_unlock_irqrestore(&priv->lec_arp_lock, flags); +} + +MODULE_LICENSE("GPL"); diff --git a/net/atm/lec.h b/net/atm/lec.h new file mode 100644 index 000000000..be0e2667b --- /dev/null +++ b/net/atm/lec.h @@ -0,0 +1,155 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Lan Emulation client header file + * + * Marko Kiiskila + */ + +#ifndef _LEC_H_ +#define _LEC_H_ + +#include +#include +#include + +#define LEC_HEADER_LEN 16 + +struct lecdatahdr_8023 { + __be16 le_header; + unsigned char h_dest[ETH_ALEN]; + unsigned char h_source[ETH_ALEN]; + __be16 h_type; +}; + +struct lecdatahdr_8025 { + __be16 le_header; + unsigned char ac_pad; + unsigned char fc; + unsigned char h_dest[ETH_ALEN]; + unsigned char h_source[ETH_ALEN]; +}; + +#define LEC_MINIMUM_8023_SIZE 62 +#define LEC_MINIMUM_8025_SIZE 16 + +/* + * Operations that a LANE2 capable device can do. The first two functions + * are used to make the device do things. See spec 3.1.3 and 3.1.4. + * + * The third function is intended for the MPOA component sitting on + * top of the LANE device. The MPOA component assigns its own function + * to (*associate_indicator)() and the LANE device will use that + * function to tell about TLVs it sees floating through. + * + */ +struct lane2_ops { + int (*resolve) (struct net_device *dev, const u8 *dst_mac, int force, + u8 **tlvs, u32 *sizeoftlvs); + int (*associate_req) (struct net_device *dev, const u8 *lan_dst, + const u8 *tlvs, u32 sizeoftlvs); + void (*associate_indicator) (struct net_device *dev, const u8 *mac_addr, + const u8 *tlvs, u32 sizeoftlvs); +}; + +/* + * ATM LAN Emulation supports both LLC & Dix Ethernet EtherType + * frames. + * + * 1.
Dix Ethernet EtherType frames are encoded by placing EtherType + * field in h_type field. Data follows immediately after header. + * 2. LLC Data frames whose total length, including LLC field and data, + * but not padding required to meet the minimum data frame length, + * is less than ETH_P_802_3_MIN MUST be encoded by placing that length + * in the h_type field. The LLC field follows header immediately. + * 3. LLC data frames longer than this maximum MUST be encoded by placing + * the value 0 in the h_type field. + * + */ + +/* Hash table size */ +#define LEC_ARP_TABLE_SIZE 16 + +struct lec_priv { + unsigned short lecid; /* Lecid of this client */ + struct hlist_head lec_arp_empty_ones; + /* Used for storing VCC's that don't have a MAC address attached yet */ + struct hlist_head lec_arp_tables[LEC_ARP_TABLE_SIZE]; + /* Actual LE ARP table */ + struct hlist_head lec_no_forward; + /* + * Used for storing VCC's (and forward packets from) which are to + * age out by not using them to forward packets. + * This is because to some LE clients there will be 2 VCCs. Only + * one of them gets used. + */ + struct hlist_head mcast_fwds; + /* + * With LANEv2 it is possible that BUS (or a special multicast server) + * establishes multiple Multicast Forward VCCs to us. This list + * collects all those VCCs. LANEv1 client has only one item in this + * list. These entries are not aged out. + */ + spinlock_t lec_arp_lock; + struct atm_vcc *mcast_vcc; /* Default Multicast Send VCC */ + struct atm_vcc *lecd; + struct delayed_work lec_arp_work; /* C10 */ + unsigned int maximum_unknown_frame_count; + /* + * Within the period of time defined by this variable, the client will send + * no more than C10 frames to BUS for a given unicast destination. (C11) + */ + unsigned long max_unknown_frame_time; + /* + * If no traffic has been sent in this vcc for this period of time, + * vcc will be torn down (C12) + */ + unsigned long vcc_timeout_period; + /* + * An LE Client MUST not retry an LE_ARP_REQUEST for a + * given frame's LAN Destination more than maximum retry count times, + * after the first LEC_ARP_REQUEST (C13) + */ + unsigned short max_retry_count; + /* + * Max time the client will maintain an entry in its arp cache in + * absence of a verification of that relationship (C17) + */ + unsigned long aging_time; + /* + * Max time the client will maintain an entry in cache when + * topology change flag is true (C18) + */ + unsigned long forward_delay_time; /* Topology change flag (C19) */ + int topology_change; + /* + * Max time the client expects an LE_ARP_REQUEST/LE_ARP_RESPONSE + * cycle to take (C20) + */ + unsigned long arp_response_time; + /* + * Time limit to wait to receive an LE_FLUSH_RESPONSE after the + * LE_FLUSH_REQUEST has been sent before taking recovery action. (C21) + */ + unsigned long flush_timeout; + /* The time since sending a frame to the bus after which the + * LE Client may assume that the frame has been either discarded or + * delivered to the recipient (C22) + */ + unsigned long path_switching_delay; + + u8 *tlvs; /* LANE2: TLVs are new */ + u32 sizeoftlvs; /* The size of the tlv array in bytes */ + int lane_version; /* LANE2 */ + int itfnum; /* e.g.
2 for lec2, 5 for lec5 */ + struct lane2_ops *lane2_ops; /* can be NULL for LANE v1 */ + int is_proxy; /* bridge between ATM and Ethernet */ +}; + +struct lec_vcc_priv { + void (*old_pop) (struct atm_vcc *vcc, struct sk_buff *skb); + int xoff; +}; + +#define LEC_VCC_PRIV(vcc) ((struct lec_vcc_priv *)((vcc)->user_back)) + +#endif /* _LEC_H_ */ diff --git a/net/atm/lec_arpc.h b/net/atm/lec_arpc.h new file mode 100644 index 000000000..39115fe07 --- /dev/null +++ b/net/atm/lec_arpc.h @@ -0,0 +1,97 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Lec arp cache + * + * Marko Kiiskila + */ +#ifndef _LEC_ARP_H_ +#define _LEC_ARP_H_ +#include +#include +#include +#include + +struct lec_arp_table { + struct hlist_node next; /* Linked entry list */ + unsigned char atm_addr[ATM_ESA_LEN]; /* Atm address */ + unsigned char mac_addr[ETH_ALEN]; /* Mac address */ + int is_rdesc; /* Mac address is a route descriptor */ + struct atm_vcc *vcc; /* Vcc this entry is attached */ + struct atm_vcc *recv_vcc; /* Vcc we receive data from */ + + void (*old_push) (struct atm_vcc *vcc, struct sk_buff *skb); + /* Push that leads to daemon */ + + void (*old_recv_push) (struct atm_vcc *vcc, struct sk_buff *skb); + /* Push that leads to daemon */ + + unsigned long last_used; /* For expiry */ + unsigned long timestamp; /* Used for various timestamping things: + * 1. FLUSH started + * (status=ESI_FLUSH_PENDING) + * 2. Counting to + * max_unknown_frame_time + * (status=ESI_ARP_PENDING|| + * status=ESI_VC_PENDING) + */ + unsigned char no_tries; /* No of times arp retry has been tried */ + unsigned char status; /* Status of this entry */ + unsigned short flags; /* Flags for this entry */ + unsigned short packets_flooded; /* Data packets flooded */ + unsigned long flush_tran_id; /* Transaction id in flush protocol */ + struct timer_list timer; /* Arping timer */ + struct lec_priv *priv; /* Pointer back */ + u8 *tlvs; + u32 sizeoftlvs; /* + * LANE2: Each MAC address can have TLVs + * associated with it. sizeoftlvs tells + * the length of the tlvs array + */ + struct sk_buff_head tx_wait; /* wait queue for outgoing packets */ + refcount_t usage; /* usage count */ +}; + +/* + * LANE2: Template tlv struct for accessing + * the tlvs in the lec_arp_table->tlvs array + */ +struct tlv { + u32 type; + u8 length; + u8 value[255]; +}; + +/* Status fields */ +#define ESI_UNKNOWN 0 /* + * Next packet sent to this mac address + * causes ARP-request to be sent + */ +#define ESI_ARP_PENDING 1 /* + * There is no ATM address associated with this + * 48-bit address. The LE-ARP protocol is in + * progress. + */ +#define ESI_VC_PENDING 2 /* + * There is a valid ATM address associated with + * this 48-bit address but there is no VC set + * up to that ATM address. The signaling + * protocol is in process. + */ +#define ESI_FLUSH_PENDING 4 /* + * The LEC has been notified of the FLUSH_START + * status and it is assumed that the flush + * protocol is in process. + */ +#define ESI_FORWARD_DIRECT 5 /* + * Either the Path Switching Delay (C22) has + * elapsed or the LEC has notified the Mapping + * that the flush protocol has completed. In + * either case, it is safe to forward packets + * to this address via the data direct VC. 
+ */ + +/* Flag values */ +#define LEC_REMOTE_FLAG 0x0001 +#define LEC_PERMANENT_FLAG 0x0002 + +#endif /* _LEC_ARP_H_ */ diff --git a/net/atm/mpc.c b/net/atm/mpc.c new file mode 100644 index 000000000..033871e71 --- /dev/null +++ b/net/atm/mpc.c @@ -0,0 +1,1535 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* We are an ethernet device */ +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for ip_fast_csum() */ +#include +#include +#include + +/* And atm device */ +#include +#include +#include +/* Modular too */ +#include + +#include "lec.h" +#include "mpc.h" +#include "resources.h" + +/* + * mpc.c: Implementation of MPOA client kernel part + */ + +#if 0 +#define dprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __func__, ##args) +#define dprintk_cont(format, args...) printk(KERN_CONT format, ##args) +#else +#define dprintk(format, args...) \ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __func__, ##args);\ + } while (0) +#define dprintk_cont(format, args...) \ + do { if (0) printk(KERN_CONT format, ##args); } while (0) +#endif + +#if 0 +#define ddprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __func__, ##args) +#define ddprintk_cont(format, args...) printk(KERN_CONT format, ##args) +#else +#define ddprintk(format, args...) \ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __func__, ##args);\ + } while (0) +#define ddprintk_cont(format, args...) \ + do { if (0) printk(KERN_CONT format, ##args); } while (0) +#endif + +/* mpc_daemon -> kernel */ +static void MPOA_trigger_rcvd(struct k_message *msg, struct mpoa_client *mpc); +static void MPOA_res_reply_rcvd(struct k_message *msg, struct mpoa_client *mpc); +static void ingress_purge_rcvd(struct k_message *msg, struct mpoa_client *mpc); +static void egress_purge_rcvd(struct k_message *msg, struct mpoa_client *mpc); +static void mps_death(struct k_message *msg, struct mpoa_client *mpc); +static void clean_up(struct k_message *msg, struct mpoa_client *mpc, + int action); +static void MPOA_cache_impos_rcvd(struct k_message *msg, + struct mpoa_client *mpc); +static void set_mpc_ctrl_addr_rcvd(struct k_message *mesg, + struct mpoa_client *mpc); +static void set_mps_mac_addr_rcvd(struct k_message *mesg, + struct mpoa_client *mpc); + +static const uint8_t *copy_macs(struct mpoa_client *mpc, + const uint8_t *router_mac, + const uint8_t *tlvs, uint8_t mps_macs, + uint8_t device_type); +static void purge_egress_shortcut(struct atm_vcc *vcc, eg_cache_entry *entry); + +static void send_set_mps_ctrl_addr(const char *addr, struct mpoa_client *mpc); +static void mpoad_close(struct atm_vcc *vcc); +static int msg_from_mpoad(struct atm_vcc *vcc, struct sk_buff *skb); + +static void mpc_push(struct atm_vcc *vcc, struct sk_buff *skb); +static netdev_tx_t mpc_send_packet(struct sk_buff *skb, + struct net_device *dev); +static int mpoa_event_listener(struct notifier_block *mpoa_notifier, + unsigned long event, void *dev); +static void mpc_timer_refresh(void); +static void mpc_cache_check(struct timer_list *unused); + +static struct llc_snap_hdr llc_snap_mpoa_ctrl = { + 0xaa, 0xaa, 0x03, + {0x00, 0x00, 0x5e}, + {0x00, 0x03} /* For MPOA control PDUs */ +}; +static struct llc_snap_hdr llc_snap_mpoa_data = { + 0xaa, 0xaa, 0x03, + {0x00, 0x00, 0x00}, + {0x08, 0x00} /* This is for IP PDUs only */ +}; +static struct llc_snap_hdr 
llc_snap_mpoa_data_tagged = { + 0xaa, 0xaa, 0x03, + {0x00, 0x00, 0x00}, + {0x88, 0x4c} /* This is for tagged data PDUs */ +}; + +static struct notifier_block mpoa_notifier = { + mpoa_event_listener, + NULL, + 0 +}; + +struct mpoa_client *mpcs = NULL; /* FIXME */ +static struct atm_mpoa_qos *qos_head = NULL; +static DEFINE_TIMER(mpc_timer, mpc_cache_check); + + +static struct mpoa_client *find_mpc_by_itfnum(int itf) +{ + struct mpoa_client *mpc; + + mpc = mpcs; /* our global linked list */ + while (mpc != NULL) { + if (mpc->dev_num == itf) + return mpc; + mpc = mpc->next; + } + + return NULL; /* not found */ +} + +static struct mpoa_client *find_mpc_by_vcc(struct atm_vcc *vcc) +{ + struct mpoa_client *mpc; + + mpc = mpcs; /* our global linked list */ + while (mpc != NULL) { + if (mpc->mpoad_vcc == vcc) + return mpc; + mpc = mpc->next; + } + + return NULL; /* not found */ +} + +static struct mpoa_client *find_mpc_by_lec(struct net_device *dev) +{ + struct mpoa_client *mpc; + + mpc = mpcs; /* our global linked list */ + while (mpc != NULL) { + if (mpc->dev == dev) + return mpc; + mpc = mpc->next; + } + + return NULL; /* not found */ +} + +/* + * Functions for managing QoS list + */ + +/* + * Overwrites the old entry or makes a new one. + */ +struct atm_mpoa_qos *atm_mpoa_add_qos(__be32 dst_ip, struct atm_qos *qos) +{ + struct atm_mpoa_qos *entry; + + entry = atm_mpoa_search_qos(dst_ip); + if (entry != NULL) { + entry->qos = *qos; + return entry; + } + + entry = kmalloc(sizeof(struct atm_mpoa_qos), GFP_KERNEL); + if (entry == NULL) { + pr_info("mpoa: out of memory\n"); + return entry; + } + + entry->ipaddr = dst_ip; + entry->qos = *qos; + + entry->next = qos_head; + qos_head = entry; + + return entry; +} + +struct atm_mpoa_qos *atm_mpoa_search_qos(__be32 dst_ip) +{ + struct atm_mpoa_qos *qos; + + qos = qos_head; + while (qos) { + if (qos->ipaddr == dst_ip) + break; + qos = qos->next; + } + + return qos; +} + +/* + * Returns 0 for failure + */ +int atm_mpoa_delete_qos(struct atm_mpoa_qos *entry) +{ + struct atm_mpoa_qos *curr; + + if (entry == NULL) + return 0; + if (entry == qos_head) { + qos_head = qos_head->next; + kfree(entry); + return 1; + } + + curr = qos_head; + while (curr != NULL) { + if (curr->next == entry) { + curr->next = entry->next; + kfree(entry); + return 1; + } + curr = curr->next; + } + + return 0; +} + +/* this is buggered - we need locking for qos_head */ +void atm_mpoa_disp_qos(struct seq_file *m) +{ + struct atm_mpoa_qos *qos; + + qos = qos_head; + seq_printf(m, "QoS entries for shortcuts:\n"); + seq_printf(m, "IP address\n TX:max_pcr pcr min_pcr max_cdv max_sdu\n RX:max_pcr pcr min_pcr max_cdv max_sdu\n"); + + while (qos != NULL) { + seq_printf(m, "%pI4\n %-7d %-7d %-7d %-7d %-7d\n %-7d %-7d %-7d %-7d %-7d\n", + &qos->ipaddr, + qos->qos.txtp.max_pcr, + qos->qos.txtp.pcr, + qos->qos.txtp.min_pcr, + qos->qos.txtp.max_cdv, + qos->qos.txtp.max_sdu, + qos->qos.rxtp.max_pcr, + qos->qos.rxtp.pcr, + qos->qos.rxtp.min_pcr, + qos->qos.rxtp.max_cdv, + qos->qos.rxtp.max_sdu); + qos = qos->next; + } +} + +static struct net_device *find_lec_by_itfnum(int itf) +{ + struct net_device *dev; + char name[IFNAMSIZ]; + + sprintf(name, "lec%d", itf); + dev = dev_get_by_name(&init_net, name); + + return dev; +} + +static struct mpoa_client *alloc_mpc(void) +{ + struct mpoa_client *mpc; + + mpc = kzalloc(sizeof(struct mpoa_client), GFP_KERNEL); + if (mpc == NULL) + return NULL; + rwlock_init(&mpc->ingress_lock); + rwlock_init(&mpc->egress_lock); + mpc->next = mpcs; + atm_mpoa_init_cache(mpc); 
+ + mpc->parameters.mpc_p1 = MPC_P1; + mpc->parameters.mpc_p2 = MPC_P2; + memset(mpc->parameters.mpc_p3, 0, sizeof(mpc->parameters.mpc_p3)); + mpc->parameters.mpc_p4 = MPC_P4; + mpc->parameters.mpc_p5 = MPC_P5; + mpc->parameters.mpc_p6 = MPC_P6; + + mpcs = mpc; + + return mpc; +} + +/* + * + * start_mpc() puts the MPC on line. All the packets destined + * to the lec underneath us are now being monitored and + * shortcuts will be established. + * + */ +static void start_mpc(struct mpoa_client *mpc, struct net_device *dev) +{ + + dprintk("(%s)\n", mpc->dev->name); + if (!dev->netdev_ops) + pr_info("(%s) not starting\n", dev->name); + else { + mpc->old_ops = dev->netdev_ops; + mpc->new_ops = *mpc->old_ops; + mpc->new_ops.ndo_start_xmit = mpc_send_packet; + dev->netdev_ops = &mpc->new_ops; + } +} + +static void stop_mpc(struct mpoa_client *mpc) +{ + struct net_device *dev = mpc->dev; + dprintk("(%s)", mpc->dev->name); + + /* Lets not nullify lec device's dev->hard_start_xmit */ + if (dev->netdev_ops != &mpc->new_ops) { + dprintk_cont(" mpc already stopped, not fatal\n"); + return; + } + dprintk_cont("\n"); + + dev->netdev_ops = mpc->old_ops; + mpc->old_ops = NULL; + + /* close_shortcuts(mpc); ??? FIXME */ +} + +static const char *mpoa_device_type_string(char type) __attribute__ ((unused)); + +static const char *mpoa_device_type_string(char type) +{ + switch (type) { + case NON_MPOA: + return "non-MPOA device"; + case MPS: + return "MPS"; + case MPC: + return "MPC"; + case MPS_AND_MPC: + return "both MPS and MPC"; + } + + return "unspecified (non-MPOA) device"; +} + +/* + * lec device calls this via its netdev_priv(dev)->lane2_ops + * ->associate_indicator() when it sees a TLV in LE_ARP packet. + * We fill in the pointer above when we see a LANE2 lec initializing + * See LANE2 spec 3.1.5 + * + * Quite a big and ugly function but when you look at it + * all it does is to try to locate and parse MPOA Device + * Type TLV. + * We give our lec a pointer to this function and when the + * lec sees a TLV it uses the pointer to call this function. 
+ * + */ +static void lane2_assoc_ind(struct net_device *dev, const u8 *mac_addr, + const u8 *tlvs, u32 sizeoftlvs) +{ + uint32_t type; + uint8_t length, mpoa_device_type, number_of_mps_macs; + const uint8_t *end_of_tlvs; + struct mpoa_client *mpc; + + mpoa_device_type = number_of_mps_macs = 0; /* silence gcc */ + dprintk("(%s) received TLV(s), ", dev->name); + dprintk("total length of all TLVs %d\n", sizeoftlvs); + mpc = find_mpc_by_lec(dev); /* Sampo-Fix: moved here from below */ + if (mpc == NULL) { + pr_info("(%s) no mpc\n", dev->name); + return; + } + end_of_tlvs = tlvs + sizeoftlvs; + while (end_of_tlvs - tlvs >= 5) { + type = ((tlvs[0] << 24) | (tlvs[1] << 16) | + (tlvs[2] << 8) | tlvs[3]); + length = tlvs[4]; + tlvs += 5; + dprintk(" type 0x%x length %02x\n", type, length); + if (tlvs + length > end_of_tlvs) { + pr_info("TLV value extends past its buffer, aborting parse\n"); + return; + } + + if (type == 0) { + pr_info("mpoa: (%s) TLV type was 0, returning\n", + dev->name); + return; + } + + if (type != TLV_MPOA_DEVICE_TYPE) { + tlvs += length; + continue; /* skip other TLVs */ + } + mpoa_device_type = *tlvs++; + number_of_mps_macs = *tlvs++; + dprintk("(%s) MPOA device type '%s', ", + dev->name, mpoa_device_type_string(mpoa_device_type)); + if (mpoa_device_type == MPS_AND_MPC && + length < (42 + number_of_mps_macs*ETH_ALEN)) { /* :) */ + pr_info("(%s) short MPOA Device Type TLV\n", + dev->name); + continue; + } + if ((mpoa_device_type == MPS || mpoa_device_type == MPC) && + length < 22 + number_of_mps_macs*ETH_ALEN) { + pr_info("(%s) short MPOA Device Type TLV\n", dev->name); + continue; + } + if (mpoa_device_type != MPS && + mpoa_device_type != MPS_AND_MPC) { + dprintk("ignoring non-MPS device "); + if (mpoa_device_type == MPC) + tlvs += 20; + continue; /* we are only interested in MPSs */ + } + if (number_of_mps_macs == 0 && + mpoa_device_type == MPS_AND_MPC) { + pr_info("(%s) MPS_AND_MPC has zero MACs\n", dev->name); + continue; /* someone should read the spec */ + } + dprintk_cont("this MPS has %d MAC addresses\n", + number_of_mps_macs); + + /* + * ok, now we can go and tell our daemon + * the control address of MPS + */ + send_set_mps_ctrl_addr(tlvs, mpc); + + tlvs = copy_macs(mpc, mac_addr, tlvs, + number_of_mps_macs, mpoa_device_type); + if (tlvs == NULL) + return; + } + if (end_of_tlvs - tlvs != 0) + pr_info("(%s) ignoring %zd bytes of trailing TLV garbage\n", + dev->name, end_of_tlvs - tlvs); +} + +/* + * Store at least advertizing router's MAC address + * plus the possible MAC address(es) to mpc->mps_macs. + * For a freshly allocated MPOA client mpc->mps_macs == 0. + */ +static const uint8_t *copy_macs(struct mpoa_client *mpc, + const uint8_t *router_mac, + const uint8_t *tlvs, uint8_t mps_macs, + uint8_t device_type) +{ + int num_macs; + num_macs = (mps_macs > 1) ? mps_macs : 1; + + if (mpc->number_of_mps_macs != num_macs) { /* need to reallocate? 
*/ + if (mpc->number_of_mps_macs != 0) + kfree(mpc->mps_macs); + mpc->number_of_mps_macs = 0; + mpc->mps_macs = kmalloc_array(ETH_ALEN, num_macs, GFP_KERNEL); + if (mpc->mps_macs == NULL) { + pr_info("(%s) out of mem\n", mpc->dev->name); + return NULL; + } + } + ether_addr_copy(mpc->mps_macs, router_mac); + tlvs += 20; if (device_type == MPS_AND_MPC) tlvs += 20; + if (mps_macs > 0) + memcpy(mpc->mps_macs, tlvs, mps_macs*ETH_ALEN); + tlvs += mps_macs*ETH_ALEN; + mpc->number_of_mps_macs = num_macs; + + return tlvs; +} + +static int send_via_shortcut(struct sk_buff *skb, struct mpoa_client *mpc) +{ + in_cache_entry *entry; + struct iphdr *iph; + char *buff; + __be32 ipaddr = 0; + + static struct { + struct llc_snap_hdr hdr; + __be32 tag; + } tagged_llc_snap_hdr = { + {0xaa, 0xaa, 0x03, {0x00, 0x00, 0x00}, {0x88, 0x4c}}, + 0 + }; + + buff = skb->data + mpc->dev->hard_header_len; + iph = (struct iphdr *)buff; + ipaddr = iph->daddr; + + ddprintk("(%s) ipaddr 0x%x\n", + mpc->dev->name, ipaddr); + + entry = mpc->in_ops->get(ipaddr, mpc); + if (entry == NULL) { + entry = mpc->in_ops->add_entry(ipaddr, mpc); + if (entry != NULL) + mpc->in_ops->put(entry); + return 1; + } + /* threshold not exceeded or VCC not ready */ + if (mpc->in_ops->cache_hit(entry, mpc) != OPEN) { + ddprintk("(%s) cache_hit: returns != OPEN\n", + mpc->dev->name); + mpc->in_ops->put(entry); + return 1; + } + + ddprintk("(%s) using shortcut\n", + mpc->dev->name); + /* MPOA spec A.1.4, MPOA client must decrement IP ttl at least by one */ + if (iph->ttl <= 1) { + ddprintk("(%s) IP ttl = %u, using LANE\n", + mpc->dev->name, iph->ttl); + mpc->in_ops->put(entry); + return 1; + } + iph->ttl--; + iph->check = 0; + iph->check = ip_fast_csum((unsigned char *)iph, iph->ihl); + + if (entry->ctrl_info.tag != 0) { + ddprintk("(%s) adding tag 0x%x\n", + mpc->dev->name, entry->ctrl_info.tag); + tagged_llc_snap_hdr.tag = entry->ctrl_info.tag; + skb_pull(skb, ETH_HLEN); /* get rid of Eth header */ + skb_push(skb, sizeof(tagged_llc_snap_hdr)); + /* add LLC/SNAP header */ + skb_copy_to_linear_data(skb, &tagged_llc_snap_hdr, + sizeof(tagged_llc_snap_hdr)); + } else { + skb_pull(skb, ETH_HLEN); /* get rid of Eth header */ + skb_push(skb, sizeof(struct llc_snap_hdr)); + /* add LLC/SNAP header + tag */ + skb_copy_to_linear_data(skb, &llc_snap_mpoa_data, + sizeof(struct llc_snap_hdr)); + } + + atm_account_tx(entry->shortcut, skb); + entry->shortcut->send(entry->shortcut, skb); + entry->packets_fwded++; + mpc->in_ops->put(entry); + + return 0; +} + +/* + * Probably needs some error checks and locking, not sure... + */ +static netdev_tx_t mpc_send_packet(struct sk_buff *skb, + struct net_device *dev) +{ + struct mpoa_client *mpc; + struct ethhdr *eth; + int i = 0; + + mpc = find_mpc_by_lec(dev); /* this should NEVER fail */ + if (mpc == NULL) { + pr_info("(%s) no MPC found\n", dev->name); + goto non_ip; + } + + eth = (struct ethhdr *)skb->data; + if (eth->h_proto != htons(ETH_P_IP)) + goto non_ip; /* Multi-Protocol Over ATM :-) */ + + /* Weed out funny packets (e.g., AF_PACKET or raw). 
*/ + if (skb->len < ETH_HLEN + sizeof(struct iphdr)) + goto non_ip; + skb_set_network_header(skb, ETH_HLEN); + if (skb->len < ETH_HLEN + ip_hdr(skb)->ihl * 4 || ip_hdr(skb)->ihl < 5) + goto non_ip; + + while (i < mpc->number_of_mps_macs) { + if (ether_addr_equal(eth->h_dest, mpc->mps_macs + i * ETH_ALEN)) + if (send_via_shortcut(skb, mpc) == 0) /* try shortcut */ + return NETDEV_TX_OK; + i++; + } + +non_ip: + return __netdev_start_xmit(mpc->old_ops, skb, dev, false); +} + +static int atm_mpoa_vcc_attach(struct atm_vcc *vcc, void __user *arg) +{ + int bytes_left; + struct mpoa_client *mpc; + struct atmmpc_ioc ioc_data; + in_cache_entry *in_entry; + __be32 ipaddr; + + bytes_left = copy_from_user(&ioc_data, arg, sizeof(struct atmmpc_ioc)); + if (bytes_left != 0) { + pr_info("mpoa:Short read (missed %d bytes) from userland\n", + bytes_left); + return -EFAULT; + } + ipaddr = ioc_data.ipaddr; + if (ioc_data.dev_num < 0 || ioc_data.dev_num >= MAX_LEC_ITF) + return -EINVAL; + + mpc = find_mpc_by_itfnum(ioc_data.dev_num); + if (mpc == NULL) + return -EINVAL; + + if (ioc_data.type == MPC_SOCKET_INGRESS) { + in_entry = mpc->in_ops->get(ipaddr, mpc); + if (in_entry == NULL || + in_entry->entry_state < INGRESS_RESOLVED) { + pr_info("(%s) did not find RESOLVED entry from ingress cache\n", + mpc->dev->name); + if (in_entry != NULL) + mpc->in_ops->put(in_entry); + return -EINVAL; + } + pr_info("(%s) attaching ingress SVC, entry = %pI4\n", + mpc->dev->name, &in_entry->ctrl_info.in_dst_ip); + in_entry->shortcut = vcc; + mpc->in_ops->put(in_entry); + } else { + pr_info("(%s) attaching egress SVC\n", mpc->dev->name); + } + + vcc->proto_data = mpc->dev; + vcc->push = mpc_push; + + return 0; +} + +/* + * + */ +static void mpc_vcc_close(struct atm_vcc *vcc, struct net_device *dev) +{ + struct mpoa_client *mpc; + in_cache_entry *in_entry; + eg_cache_entry *eg_entry; + + mpc = find_mpc_by_lec(dev); + if (mpc == NULL) { + pr_info("(%s) close for unknown MPC\n", dev->name); + return; + } + + dprintk("(%s)\n", dev->name); + in_entry = mpc->in_ops->get_by_vcc(vcc, mpc); + if (in_entry) { + dprintk("(%s) ingress SVC closed ip = %pI4\n", + mpc->dev->name, &in_entry->ctrl_info.in_dst_ip); + in_entry->shortcut = NULL; + mpc->in_ops->put(in_entry); + } + eg_entry = mpc->eg_ops->get_by_vcc(vcc, mpc); + if (eg_entry) { + dprintk("(%s) egress SVC closed\n", mpc->dev->name); + eg_entry->shortcut = NULL; + mpc->eg_ops->put(eg_entry); + } + + if (in_entry == NULL && eg_entry == NULL) + dprintk("(%s) unused vcc closed\n", dev->name); +} + +static void mpc_push(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct net_device *dev = (struct net_device *)vcc->proto_data; + struct sk_buff *new_skb; + eg_cache_entry *eg; + struct mpoa_client *mpc; + __be32 tag; + char *tmp; + + ddprintk("(%s)\n", dev->name); + if (skb == NULL) { + dprintk("(%s) null skb, closing VCC\n", dev->name); + mpc_vcc_close(vcc, dev); + return; + } + + skb->dev = dev; + if (memcmp(skb->data, &llc_snap_mpoa_ctrl, + sizeof(struct llc_snap_hdr)) == 0) { + struct sock *sk = sk_atm(vcc); + + dprintk("(%s) control packet arrived\n", dev->name); + /* Pass control packets to daemon */ + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + return; + } + + /* data coming over the shortcut */ + atm_return(vcc, skb->truesize); + + mpc = find_mpc_by_lec(dev); + if (mpc == NULL) { + pr_info("(%s) unknown MPC\n", dev->name); + return; + } + + if (memcmp(skb->data, &llc_snap_mpoa_data_tagged, + sizeof(struct llc_snap_hdr)) == 0) { /* MPOA tagged data */ + 
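+		/*
+		 * Tagged MPOA data: the 32-bit tag sits immediately after the
+		 * LLC/SNAP header and is used below to look up the egress
+		 * cache entry this frame belongs to.
+		 */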
ddprintk("(%s) tagged data packet arrived\n", dev->name); + + } else if (memcmp(skb->data, &llc_snap_mpoa_data, + sizeof(struct llc_snap_hdr)) == 0) { /* MPOA data */ + pr_info("(%s) Unsupported non-tagged data packet arrived. Purging\n", + dev->name); + dev_kfree_skb_any(skb); + return; + } else { + pr_info("(%s) garbage arrived, purging\n", dev->name); + dev_kfree_skb_any(skb); + return; + } + + tmp = skb->data + sizeof(struct llc_snap_hdr); + tag = *(__be32 *)tmp; + + eg = mpc->eg_ops->get_by_tag(tag, mpc); + if (eg == NULL) { + pr_info("mpoa: (%s) Didn't find egress cache entry, tag = %u\n", + dev->name, tag); + purge_egress_shortcut(vcc, NULL); + dev_kfree_skb_any(skb); + return; + } + + /* + * See if ingress MPC is using shortcut we opened as a return channel. + * This means we have a bi-directional vcc opened by us. + */ + if (eg->shortcut == NULL) { + eg->shortcut = vcc; + pr_info("(%s) egress SVC in use\n", dev->name); + } + + skb_pull(skb, sizeof(struct llc_snap_hdr) + sizeof(tag)); + /* get rid of LLC/SNAP header */ + new_skb = skb_realloc_headroom(skb, eg->ctrl_info.DH_length); + /* LLC/SNAP is shorter than MAC header :( */ + dev_kfree_skb_any(skb); + if (new_skb == NULL) { + mpc->eg_ops->put(eg); + return; + } + skb_push(new_skb, eg->ctrl_info.DH_length); /* add MAC header */ + skb_copy_to_linear_data(new_skb, eg->ctrl_info.DLL_header, + eg->ctrl_info.DH_length); + new_skb->protocol = eth_type_trans(new_skb, dev); + skb_reset_network_header(new_skb); + + eg->latest_ip_addr = ip_hdr(new_skb)->saddr; + eg->packets_rcvd++; + mpc->eg_ops->put(eg); + + memset(ATM_SKB(new_skb), 0, sizeof(struct atm_skb_data)); + netif_rx(new_skb); +} + +static const struct atmdev_ops mpc_ops = { /* only send is required */ + .close = mpoad_close, + .send = msg_from_mpoad +}; + +static struct atm_dev mpc_dev = { + .ops = &mpc_ops, + .type = "mpc", + .number = 42, + .lock = __SPIN_LOCK_UNLOCKED(mpc_dev.lock) + /* members not explicitly initialised will be 0 */ +}; + +static int atm_mpoa_mpoad_attach(struct atm_vcc *vcc, int arg) +{ + struct mpoa_client *mpc; + struct lec_priv *priv; + int err; + + if (mpcs == NULL) { + mpc_timer_refresh(); + + /* This lets us now how our LECs are doing */ + err = register_netdevice_notifier(&mpoa_notifier); + if (err < 0) { + del_timer(&mpc_timer); + return err; + } + } + + mpc = find_mpc_by_itfnum(arg); + if (mpc == NULL) { + dprintk("allocating new mpc for itf %d\n", arg); + mpc = alloc_mpc(); + if (mpc == NULL) + return -ENOMEM; + mpc->dev_num = arg; + mpc->dev = find_lec_by_itfnum(arg); + /* NULL if there was no lec */ + } + if (mpc->mpoad_vcc) { + pr_info("mpoad is already present for itf %d\n", arg); + return -EADDRINUSE; + } + + if (mpc->dev) { /* check if the lec is LANE2 capable */ + priv = netdev_priv(mpc->dev); + if (priv->lane_version < 2) { + dev_put(mpc->dev); + mpc->dev = NULL; + } else + priv->lane2_ops->associate_indicator = lane2_assoc_ind; + } + + mpc->mpoad_vcc = vcc; + vcc->dev = &mpc_dev; + vcc_insert_socket(sk_atm(vcc)); + set_bit(ATM_VF_META, &vcc->flags); + set_bit(ATM_VF_READY, &vcc->flags); + + if (mpc->dev) { + char empty[ATM_ESA_LEN]; + memset(empty, 0, ATM_ESA_LEN); + + start_mpc(mpc, mpc->dev); + /* set address if mpcd e.g. gets killed and restarted. 
+ * If we do not do it now we have to wait for the next LE_ARP + */ + if (memcmp(mpc->mps_ctrl_addr, empty, ATM_ESA_LEN) != 0) + send_set_mps_ctrl_addr(mpc->mps_ctrl_addr, mpc); + } + + __module_get(THIS_MODULE); + return arg; +} + +static void send_set_mps_ctrl_addr(const char *addr, struct mpoa_client *mpc) +{ + struct k_message mesg; + + memcpy(mpc->mps_ctrl_addr, addr, ATM_ESA_LEN); + + mesg.type = SET_MPS_CTRL_ADDR; + memcpy(mesg.MPS_ctrl, addr, ATM_ESA_LEN); + msg_to_mpoad(&mesg, mpc); +} + +static void mpoad_close(struct atm_vcc *vcc) +{ + struct mpoa_client *mpc; + struct sk_buff *skb; + + mpc = find_mpc_by_vcc(vcc); + if (mpc == NULL) { + pr_info("did not find MPC\n"); + return; + } + if (!mpc->mpoad_vcc) { + pr_info("close for non-present mpoad\n"); + return; + } + + mpc->mpoad_vcc = NULL; + if (mpc->dev) { + struct lec_priv *priv = netdev_priv(mpc->dev); + priv->lane2_ops->associate_indicator = NULL; + stop_mpc(mpc); + dev_put(mpc->dev); + } + + mpc->in_ops->destroy_cache(mpc); + mpc->eg_ops->destroy_cache(mpc); + + while ((skb = skb_dequeue(&sk_atm(vcc)->sk_receive_queue))) { + atm_return(vcc, skb->truesize); + kfree_skb(skb); + } + + pr_info("(%s) going down\n", + (mpc->dev) ? mpc->dev->name : ""); + module_put(THIS_MODULE); +} + +/* + * + */ +static int msg_from_mpoad(struct atm_vcc *vcc, struct sk_buff *skb) +{ + + struct mpoa_client *mpc = find_mpc_by_vcc(vcc); + struct k_message *mesg = (struct k_message *)skb->data; + WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc)); + + if (mpc == NULL) { + pr_info("no mpc found\n"); + return 0; + } + dprintk("(%s)", mpc->dev ? mpc->dev->name : ""); + switch (mesg->type) { + case MPOA_RES_REPLY_RCVD: + dprintk_cont("mpoa_res_reply_rcvd\n"); + MPOA_res_reply_rcvd(mesg, mpc); + break; + case MPOA_TRIGGER_RCVD: + dprintk_cont("mpoa_trigger_rcvd\n"); + MPOA_trigger_rcvd(mesg, mpc); + break; + case INGRESS_PURGE_RCVD: + dprintk_cont("nhrp_purge_rcvd\n"); + ingress_purge_rcvd(mesg, mpc); + break; + case EGRESS_PURGE_RCVD: + dprintk_cont("egress_purge_reply_rcvd\n"); + egress_purge_rcvd(mesg, mpc); + break; + case MPS_DEATH: + dprintk_cont("mps_death\n"); + mps_death(mesg, mpc); + break; + case CACHE_IMPOS_RCVD: + dprintk_cont("cache_impos_rcvd\n"); + MPOA_cache_impos_rcvd(mesg, mpc); + break; + case SET_MPC_CTRL_ADDR: + dprintk_cont("set_mpc_ctrl_addr\n"); + set_mpc_ctrl_addr_rcvd(mesg, mpc); + break; + case SET_MPS_MAC_ADDR: + dprintk_cont("set_mps_mac_addr\n"); + set_mps_mac_addr_rcvd(mesg, mpc); + break; + case CLEAN_UP_AND_EXIT: + dprintk_cont("clean_up_and_exit\n"); + clean_up(mesg, mpc, DIE); + break; + case RELOAD: + dprintk_cont("reload\n"); + clean_up(mesg, mpc, RELOAD); + break; + case SET_MPC_PARAMS: + dprintk_cont("set_mpc_params\n"); + mpc->parameters = mesg->content.params; + break; + default: + dprintk_cont("unknown message %d\n", mesg->type); + break; + } + kfree_skb(skb); + + return 0; +} + +/* Remember that this function may not do things that sleep */ +int msg_to_mpoad(struct k_message *mesg, struct mpoa_client *mpc) +{ + struct sk_buff *skb; + struct sock *sk; + + if (mpc == NULL || !mpc->mpoad_vcc) { + pr_info("mesg %d to a non-existent mpoad\n", mesg->type); + return -ENXIO; + } + + skb = alloc_skb(sizeof(struct k_message), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + skb_put(skb, sizeof(struct k_message)); + skb_copy_to_linear_data(skb, mesg, sizeof(*mesg)); + atm_force_charge(mpc->mpoad_vcc, skb->truesize); + + sk = sk_atm(mpc->mpoad_vcc); + skb_queue_tail(&sk->sk_receive_queue, skb); + 
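+	/* wake up mpoad waiting on the control socket's receive queue */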
sk->sk_data_ready(sk); + + return 0; +} + +static int mpoa_event_listener(struct notifier_block *mpoa_notifier, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct mpoa_client *mpc; + struct lec_priv *priv; + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (strncmp(dev->name, "lec", 3)) + return NOTIFY_DONE; /* we are only interested in lec:s */ + + switch (event) { + case NETDEV_REGISTER: /* a new lec device was allocated */ + priv = netdev_priv(dev); + if (priv->lane_version < 2) + break; + priv->lane2_ops->associate_indicator = lane2_assoc_ind; + mpc = find_mpc_by_itfnum(priv->itfnum); + if (mpc == NULL) { + dprintk("allocating new mpc for %s\n", dev->name); + mpc = alloc_mpc(); + if (mpc == NULL) { + pr_info("no new mpc"); + break; + } + } + mpc->dev_num = priv->itfnum; + mpc->dev = dev; + dev_hold(dev); + dprintk("(%s) was initialized\n", dev->name); + break; + case NETDEV_UNREGISTER: + /* the lec device was deallocated */ + mpc = find_mpc_by_lec(dev); + if (mpc == NULL) + break; + dprintk("device (%s) was deallocated\n", dev->name); + stop_mpc(mpc); + dev_put(mpc->dev); + mpc->dev = NULL; + break; + case NETDEV_UP: + /* the dev was ifconfig'ed up */ + mpc = find_mpc_by_lec(dev); + if (mpc == NULL) + break; + if (mpc->mpoad_vcc != NULL) + start_mpc(mpc, dev); + break; + case NETDEV_DOWN: + /* the dev was ifconfig'ed down */ + /* this means that the flow of packets from the + * upper layer stops + */ + mpc = find_mpc_by_lec(dev); + if (mpc == NULL) + break; + if (mpc->mpoad_vcc != NULL) + stop_mpc(mpc); + break; + case NETDEV_REBOOT: + case NETDEV_CHANGE: + case NETDEV_CHANGEMTU: + case NETDEV_CHANGEADDR: + case NETDEV_GOING_DOWN: + break; + default: + break; + } + + return NOTIFY_DONE; +} + +/* + * Functions which are called after a message is received from mpcd. + * Msg is reused on purpose. + */ + + +static void MPOA_trigger_rcvd(struct k_message *msg, struct mpoa_client *mpc) +{ + __be32 dst_ip = msg->content.in_info.in_dst_ip; + in_cache_entry *entry; + + entry = mpc->in_ops->get(dst_ip, mpc); + if (entry == NULL) { + entry = mpc->in_ops->add_entry(dst_ip, mpc); + entry->entry_state = INGRESS_RESOLVING; + msg->type = SND_MPOA_RES_RQST; + msg->content.in_info = entry->ctrl_info; + msg_to_mpoad(msg, mpc); + entry->reply_wait = ktime_get_seconds(); + mpc->in_ops->put(entry); + return; + } + + if (entry->entry_state == INGRESS_INVALID) { + entry->entry_state = INGRESS_RESOLVING; + msg->type = SND_MPOA_RES_RQST; + msg->content.in_info = entry->ctrl_info; + msg_to_mpoad(msg, mpc); + entry->reply_wait = ktime_get_seconds(); + mpc->in_ops->put(entry); + return; + } + + pr_info("(%s) entry already in resolving state\n", + (mpc->dev) ? mpc->dev->name : ""); + mpc->in_ops->put(entry); +} + +/* + * Things get complicated because we have to check if there's an egress + * shortcut with suitable traffic parameters we could use. + */ +static void check_qos_and_open_shortcut(struct k_message *msg, + struct mpoa_client *client, + in_cache_entry *entry) +{ + __be32 dst_ip = msg->content.in_info.in_dst_ip; + struct atm_mpoa_qos *qos = atm_mpoa_search_qos(dst_ip); + eg_cache_entry *eg_entry = client->eg_ops->get_by_src_ip(dst_ip, client); + + if (eg_entry && eg_entry->shortcut) { + if (eg_entry->shortcut->qos.txtp.traffic_class & + msg->qos.txtp.traffic_class & + (qos ? 
qos->qos.txtp.traffic_class : ATM_UBR | ATM_CBR)) { + if (eg_entry->shortcut->qos.txtp.traffic_class == ATM_UBR) + entry->shortcut = eg_entry->shortcut; + else if (eg_entry->shortcut->qos.txtp.max_pcr > 0) + entry->shortcut = eg_entry->shortcut; + } + if (entry->shortcut) { + dprintk("(%s) using egress SVC to reach %pI4\n", + client->dev->name, &dst_ip); + client->eg_ops->put(eg_entry); + return; + } + } + if (eg_entry != NULL) + client->eg_ops->put(eg_entry); + + /* No luck in the egress cache we must open an ingress SVC */ + msg->type = OPEN_INGRESS_SVC; + if (qos && + (qos->qos.txtp.traffic_class == msg->qos.txtp.traffic_class)) { + msg->qos = qos->qos; + pr_info("(%s) trying to get a CBR shortcut\n", + client->dev->name); + } else + memset(&msg->qos, 0, sizeof(struct atm_qos)); + msg_to_mpoad(msg, client); +} + +static void MPOA_res_reply_rcvd(struct k_message *msg, struct mpoa_client *mpc) +{ + __be32 dst_ip = msg->content.in_info.in_dst_ip; + in_cache_entry *entry = mpc->in_ops->get(dst_ip, mpc); + + dprintk("(%s) ip %pI4\n", + mpc->dev->name, &dst_ip); + ddprintk("(%s) entry = %p", + mpc->dev->name, entry); + if (entry == NULL) { + pr_info("(%s) ARGH, received res. reply for an entry that doesn't exist.\n", + mpc->dev->name); + return; + } + ddprintk_cont(" entry_state = %d ", entry->entry_state); + + if (entry->entry_state == INGRESS_RESOLVED) { + pr_info("(%s) RESOLVED entry!\n", mpc->dev->name); + mpc->in_ops->put(entry); + return; + } + + entry->ctrl_info = msg->content.in_info; + entry->time = ktime_get_seconds(); + /* Used in refreshing func from now on */ + entry->reply_wait = ktime_get_seconds(); + entry->refresh_time = 0; + ddprintk_cont("entry->shortcut = %p\n", entry->shortcut); + + if (entry->entry_state == INGRESS_RESOLVING && + entry->shortcut != NULL) { + entry->entry_state = INGRESS_RESOLVED; + mpc->in_ops->put(entry); + return; /* Shortcut already open... 
*/ + } + + if (entry->shortcut != NULL) { + pr_info("(%s) entry->shortcut != NULL, impossible!\n", + mpc->dev->name); + mpc->in_ops->put(entry); + return; + } + + check_qos_and_open_shortcut(msg, mpc, entry); + entry->entry_state = INGRESS_RESOLVED; + mpc->in_ops->put(entry); + + return; + +} + +static void ingress_purge_rcvd(struct k_message *msg, struct mpoa_client *mpc) +{ + __be32 dst_ip = msg->content.in_info.in_dst_ip; + __be32 mask = msg->ip_mask; + in_cache_entry *entry = mpc->in_ops->get_with_mask(dst_ip, mpc, mask); + + if (entry == NULL) { + pr_info("(%s) purge for a non-existing entry, ip = %pI4\n", + mpc->dev->name, &dst_ip); + return; + } + + do { + dprintk("(%s) removing an ingress entry, ip = %pI4\n", + mpc->dev->name, &dst_ip); + write_lock_bh(&mpc->ingress_lock); + mpc->in_ops->remove_entry(entry, mpc); + write_unlock_bh(&mpc->ingress_lock); + mpc->in_ops->put(entry); + entry = mpc->in_ops->get_with_mask(dst_ip, mpc, mask); + } while (entry != NULL); +} + +static void egress_purge_rcvd(struct k_message *msg, struct mpoa_client *mpc) +{ + __be32 cache_id = msg->content.eg_info.cache_id; + eg_cache_entry *entry = mpc->eg_ops->get_by_cache_id(cache_id, mpc); + + if (entry == NULL) { + dprintk("(%s) purge for a non-existing entry\n", + mpc->dev->name); + return; + } + + write_lock_irq(&mpc->egress_lock); + mpc->eg_ops->remove_entry(entry, mpc); + write_unlock_irq(&mpc->egress_lock); + + mpc->eg_ops->put(entry); +} + +static void purge_egress_shortcut(struct atm_vcc *vcc, eg_cache_entry *entry) +{ + struct sock *sk; + struct k_message *purge_msg; + struct sk_buff *skb; + + dprintk("entering\n"); + if (vcc == NULL) { + pr_info("vcc == NULL\n"); + return; + } + + skb = alloc_skb(sizeof(struct k_message), GFP_ATOMIC); + if (skb == NULL) { + pr_info("out of memory\n"); + return; + } + + skb_put(skb, sizeof(struct k_message)); + memset(skb->data, 0, sizeof(struct k_message)); + purge_msg = (struct k_message *)skb->data; + purge_msg->type = DATA_PLANE_PURGE; + if (entry != NULL) + purge_msg->content.eg_info = entry->ctrl_info; + + atm_force_charge(vcc, skb->truesize); + + sk = sk_atm(vcc); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + dprintk("exiting\n"); +} + +/* + * Our MPS died. Tell our daemon to send NHRP data plane purge to each + * of the egress shortcuts we have. 
+ */ +static void mps_death(struct k_message *msg, struct mpoa_client *mpc) +{ + eg_cache_entry *entry; + + dprintk("(%s)\n", mpc->dev->name); + + if (memcmp(msg->MPS_ctrl, mpc->mps_ctrl_addr, ATM_ESA_LEN)) { + pr_info("(%s) wrong MPS\n", mpc->dev->name); + return; + } + + /* FIXME: This knows too much of the cache structure */ + read_lock_irq(&mpc->egress_lock); + entry = mpc->eg_cache; + while (entry != NULL) { + purge_egress_shortcut(entry->shortcut, entry); + entry = entry->next; + } + read_unlock_irq(&mpc->egress_lock); + + mpc->in_ops->destroy_cache(mpc); + mpc->eg_ops->destroy_cache(mpc); +} + +static void MPOA_cache_impos_rcvd(struct k_message *msg, + struct mpoa_client *mpc) +{ + uint16_t holding_time; + eg_cache_entry *entry = mpc->eg_ops->get_by_cache_id(msg->content.eg_info.cache_id, mpc); + + holding_time = msg->content.eg_info.holding_time; + dprintk("(%s) entry = %p, holding_time = %u\n", + mpc->dev->name, entry, holding_time); + if (entry == NULL && holding_time) { + entry = mpc->eg_ops->add_entry(msg, mpc); + mpc->eg_ops->put(entry); + return; + } + if (holding_time) { + mpc->eg_ops->update(entry, holding_time); + return; + } + + write_lock_irq(&mpc->egress_lock); + mpc->eg_ops->remove_entry(entry, mpc); + write_unlock_irq(&mpc->egress_lock); + + mpc->eg_ops->put(entry); +} + +static void set_mpc_ctrl_addr_rcvd(struct k_message *mesg, + struct mpoa_client *mpc) +{ + struct lec_priv *priv; + int i, retval ; + + uint8_t tlv[4 + 1 + 1 + 1 + ATM_ESA_LEN]; + + tlv[0] = 00; tlv[1] = 0xa0; tlv[2] = 0x3e; tlv[3] = 0x2a; /* type */ + tlv[4] = 1 + 1 + ATM_ESA_LEN; /* length */ + tlv[5] = 0x02; /* MPOA client */ + tlv[6] = 0x00; /* number of MPS MAC addresses */ + + memcpy(&tlv[7], mesg->MPS_ctrl, ATM_ESA_LEN); /* MPC ctrl ATM addr */ + memcpy(mpc->our_ctrl_addr, mesg->MPS_ctrl, ATM_ESA_LEN); + + dprintk("(%s) setting MPC ctrl ATM address to", + mpc->dev ? 
mpc->dev->name : ""); + for (i = 7; i < sizeof(tlv); i++) + dprintk_cont(" %02x", tlv[i]); + dprintk_cont("\n"); + + if (mpc->dev) { + priv = netdev_priv(mpc->dev); + retval = priv->lane2_ops->associate_req(mpc->dev, + mpc->dev->dev_addr, + tlv, sizeof(tlv)); + if (retval == 0) + pr_info("(%s) MPOA device type TLV association failed\n", + mpc->dev->name); + retval = priv->lane2_ops->resolve(mpc->dev, NULL, 1, NULL, NULL); + if (retval < 0) + pr_info("(%s) targetless LE_ARP request failed\n", + mpc->dev->name); + } +} + +static void set_mps_mac_addr_rcvd(struct k_message *msg, + struct mpoa_client *client) +{ + + if (client->number_of_mps_macs) + kfree(client->mps_macs); + client->number_of_mps_macs = 0; + client->mps_macs = kmemdup(msg->MPS_ctrl, ETH_ALEN, GFP_KERNEL); + if (client->mps_macs == NULL) { + pr_info("out of memory\n"); + return; + } + client->number_of_mps_macs = 1; +} + +/* + * purge egress cache and tell daemon to 'action' (DIE, RELOAD) + */ +static void clean_up(struct k_message *msg, struct mpoa_client *mpc, int action) +{ + + eg_cache_entry *entry; + msg->type = SND_EGRESS_PURGE; + + + /* FIXME: This knows too much of the cache structure */ + read_lock_irq(&mpc->egress_lock); + entry = mpc->eg_cache; + while (entry != NULL) { + msg->content.eg_info = entry->ctrl_info; + dprintk("cache_id %u\n", entry->ctrl_info.cache_id); + msg_to_mpoad(msg, mpc); + entry = entry->next; + } + read_unlock_irq(&mpc->egress_lock); + + msg->type = action; + msg_to_mpoad(msg, mpc); +} + +static unsigned long checking_time; + +static void mpc_timer_refresh(void) +{ + mpc_timer.expires = jiffies + (MPC_P2 * HZ); + checking_time = mpc_timer.expires; + add_timer(&mpc_timer); +} + +static void mpc_cache_check(struct timer_list *unused) +{ + struct mpoa_client *mpc = mpcs; + static unsigned long previous_resolving_check_time; + static unsigned long previous_refresh_time; + + while (mpc != NULL) { + mpc->in_ops->clear_count(mpc); + mpc->eg_ops->clear_expired(mpc); + if (checking_time - previous_resolving_check_time > + mpc->parameters.mpc_p4 * HZ) { + mpc->in_ops->check_resolving(mpc); + previous_resolving_check_time = checking_time; + } + if (checking_time - previous_refresh_time > + mpc->parameters.mpc_p5 * HZ) { + mpc->in_ops->refresh(mpc); + previous_refresh_time = checking_time; + } + mpc = mpc->next; + } + mpc_timer_refresh(); +} + +static int atm_mpoa_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + int err = 0; + struct atm_vcc *vcc = ATM_SD(sock); + + if (cmd != ATMMPC_CTRL && cmd != ATMMPC_DATA) + return -ENOIOCTLCMD; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case ATMMPC_CTRL: + err = atm_mpoa_mpoad_attach(vcc, (int)arg); + if (err >= 0) + sock->state = SS_CONNECTED; + break; + case ATMMPC_DATA: + err = atm_mpoa_vcc_attach(vcc, (void __user *)arg); + break; + default: + break; + } + return err; +} + +static struct atm_ioctl atm_ioctl_ops = { + .owner = THIS_MODULE, + .ioctl = atm_mpoa_ioctl, +}; + +static __init int atm_mpoa_init(void) +{ + register_atm_ioctl(&atm_ioctl_ops); + + if (mpc_proc_init() != 0) + pr_info("failed to initialize /proc/mpoa\n"); + + pr_info("mpc.c: initialized\n"); + + return 0; +} + +static void __exit atm_mpoa_cleanup(void) +{ + struct mpoa_client *mpc, *tmp; + struct atm_mpoa_qos *qos, *nextqos; + struct lec_priv *priv; + + mpc_proc_clean(); + + del_timer_sync(&mpc_timer); + unregister_netdevice_notifier(&mpoa_notifier); + deregister_atm_ioctl(&atm_ioctl_ops); + + mpc = mpcs; + mpcs = NULL; + while (mpc != NULL) { 
+ tmp = mpc->next; + if (mpc->dev != NULL) { + stop_mpc(mpc); + priv = netdev_priv(mpc->dev); + if (priv->lane2_ops != NULL) + priv->lane2_ops->associate_indicator = NULL; + } + ddprintk("about to clear caches\n"); + mpc->in_ops->destroy_cache(mpc); + mpc->eg_ops->destroy_cache(mpc); + ddprintk("caches cleared\n"); + kfree(mpc->mps_macs); + memset(mpc, 0, sizeof(struct mpoa_client)); + ddprintk("about to kfree %p\n", mpc); + kfree(mpc); + ddprintk("next mpc is at %p\n", tmp); + mpc = tmp; + } + + qos = qos_head; + qos_head = NULL; + while (qos != NULL) { + nextqos = qos->next; + dprintk("freeing qos entry %p\n", qos); + kfree(qos); + qos = nextqos; + } +} + +module_init(atm_mpoa_init); +module_exit(atm_mpoa_cleanup); + +MODULE_LICENSE("GPL"); diff --git a/net/atm/mpc.h b/net/atm/mpc.h new file mode 100644 index 000000000..454abd076 --- /dev/null +++ b/net/atm/mpc.h @@ -0,0 +1,65 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _MPC_H_ +#define _MPC_H_ + +#include +#include +#include +#include +#include +#include "mpoa_caches.h" + +/* kernel -> mpc-daemon */ +int msg_to_mpoad(struct k_message *msg, struct mpoa_client *mpc); + +struct mpoa_client { + struct mpoa_client *next; + struct net_device *dev; /* lec in question */ + int dev_num; /* e.g. 2 for lec2 */ + + struct atm_vcc *mpoad_vcc; /* control channel to mpoad */ + uint8_t mps_ctrl_addr[ATM_ESA_LEN]; /* MPS control ATM address */ + uint8_t our_ctrl_addr[ATM_ESA_LEN]; /* MPC's control ATM address */ + + rwlock_t ingress_lock; + const struct in_cache_ops *in_ops; /* ingress cache operations */ + in_cache_entry *in_cache; /* the ingress cache of this MPC */ + + rwlock_t egress_lock; + const struct eg_cache_ops *eg_ops; /* egress cache operations */ + eg_cache_entry *eg_cache; /* the egress cache of this MPC */ + + uint8_t *mps_macs; /* array of MPS MAC addresses, >=1 */ + int number_of_mps_macs; /* number of the above MAC addresses */ + struct mpc_parameters parameters; /* parameters for this client */ + + const struct net_device_ops *old_ops; + struct net_device_ops new_ops; +}; + + +struct atm_mpoa_qos { + struct atm_mpoa_qos *next; + __be32 ipaddr; + struct atm_qos qos; +}; + + +/* MPOA QoS operations */ +struct atm_mpoa_qos *atm_mpoa_add_qos(__be32 dst_ip, struct atm_qos *qos); +struct atm_mpoa_qos *atm_mpoa_search_qos(__be32 dst_ip); +int atm_mpoa_delete_qos(struct atm_mpoa_qos *qos); + +/* Display QoS entries. This is for the procfs */ +struct seq_file; +void atm_mpoa_disp_qos(struct seq_file *m); + +#ifdef CONFIG_PROC_FS +int mpc_proc_init(void); +void mpc_proc_clean(void); +#else +#define mpc_proc_init() (0) +#define mpc_proc_clean() do { } while(0) +#endif + +#endif /* _MPC_H_ */ diff --git a/net/atm/mpoa_caches.c b/net/atm/mpoa_caches.c new file mode 100644 index 000000000..f7a2f0e41 --- /dev/null +++ b/net/atm/mpoa_caches.c @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +#include "mpoa_caches.h" +#include "mpc.h" + +/* + * mpoa_caches.c: Implementation of ingress and egress cache + * handling functions + */ + +#if 0 +#define dprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args) /* debug */ +#else +#define dprintk(format, args...) \ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args);\ + } while (0) +#endif + +#if 0 +#define ddprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args) /* debug */ +#else +#define ddprintk(format, args...) 
\ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args);\ + } while (0) +#endif + +static in_cache_entry *in_cache_get(__be32 dst_ip, + struct mpoa_client *client) +{ + in_cache_entry *entry; + + read_lock_bh(&client->ingress_lock); + entry = client->in_cache; + while (entry != NULL) { + if (entry->ctrl_info.in_dst_ip == dst_ip) { + refcount_inc(&entry->use); + read_unlock_bh(&client->ingress_lock); + return entry; + } + entry = entry->next; + } + read_unlock_bh(&client->ingress_lock); + + return NULL; +} + +static in_cache_entry *in_cache_get_with_mask(__be32 dst_ip, + struct mpoa_client *client, + __be32 mask) +{ + in_cache_entry *entry; + + read_lock_bh(&client->ingress_lock); + entry = client->in_cache; + while (entry != NULL) { + if ((entry->ctrl_info.in_dst_ip & mask) == (dst_ip & mask)) { + refcount_inc(&entry->use); + read_unlock_bh(&client->ingress_lock); + return entry; + } + entry = entry->next; + } + read_unlock_bh(&client->ingress_lock); + + return NULL; + +} + +static in_cache_entry *in_cache_get_by_vcc(struct atm_vcc *vcc, + struct mpoa_client *client) +{ + in_cache_entry *entry; + + read_lock_bh(&client->ingress_lock); + entry = client->in_cache; + while (entry != NULL) { + if (entry->shortcut == vcc) { + refcount_inc(&entry->use); + read_unlock_bh(&client->ingress_lock); + return entry; + } + entry = entry->next; + } + read_unlock_bh(&client->ingress_lock); + + return NULL; +} + +static in_cache_entry *in_cache_add_entry(__be32 dst_ip, + struct mpoa_client *client) +{ + in_cache_entry *entry = kzalloc(sizeof(in_cache_entry), GFP_KERNEL); + + if (entry == NULL) { + pr_info("mpoa: mpoa_caches.c: new_in_cache_entry: out of memory\n"); + return NULL; + } + + dprintk("adding an ingress entry, ip = %pI4\n", &dst_ip); + + refcount_set(&entry->use, 1); + dprintk("new_in_cache_entry: about to lock\n"); + write_lock_bh(&client->ingress_lock); + entry->next = client->in_cache; + entry->prev = NULL; + if (client->in_cache != NULL) + client->in_cache->prev = entry; + client->in_cache = entry; + + memcpy(entry->MPS_ctrl_ATM_addr, client->mps_ctrl_addr, ATM_ESA_LEN); + entry->ctrl_info.in_dst_ip = dst_ip; + entry->time = ktime_get_seconds(); + entry->retry_time = client->parameters.mpc_p4; + entry->count = 1; + entry->entry_state = INGRESS_INVALID; + entry->ctrl_info.holding_time = HOLDING_TIME_DEFAULT; + refcount_inc(&entry->use); + + write_unlock_bh(&client->ingress_lock); + dprintk("new_in_cache_entry: unlocked\n"); + + return entry; +} + +static int cache_hit(in_cache_entry *entry, struct mpoa_client *mpc) +{ + struct atm_mpoa_qos *qos; + struct k_message msg; + + entry->count++; + if (entry->entry_state == INGRESS_RESOLVED && entry->shortcut != NULL) + return OPEN; + + if (entry->entry_state == INGRESS_REFRESHING) { + if (entry->count > mpc->parameters.mpc_p1) { + msg.type = SND_MPOA_RES_RQST; + msg.content.in_info = entry->ctrl_info; + memcpy(msg.MPS_ctrl, mpc->mps_ctrl_addr, ATM_ESA_LEN); + qos = atm_mpoa_search_qos(entry->ctrl_info.in_dst_ip); + if (qos != NULL) + msg.qos = qos->qos; + msg_to_mpoad(&msg, mpc); + entry->reply_wait = ktime_get_seconds(); + entry->entry_state = INGRESS_RESOLVING; + } + if (entry->shortcut != NULL) + return OPEN; + return CLOSED; + } + + if (entry->entry_state == INGRESS_RESOLVING && entry->shortcut != NULL) + return OPEN; + + if (entry->count > mpc->parameters.mpc_p1 && + entry->entry_state == INGRESS_INVALID) { + dprintk("(%s) threshold exceeded for ip %pI4, sending MPOA res req\n", + mpc->dev->name, 
&entry->ctrl_info.in_dst_ip); + entry->entry_state = INGRESS_RESOLVING; + msg.type = SND_MPOA_RES_RQST; + memcpy(msg.MPS_ctrl, mpc->mps_ctrl_addr, ATM_ESA_LEN); + msg.content.in_info = entry->ctrl_info; + qos = atm_mpoa_search_qos(entry->ctrl_info.in_dst_ip); + if (qos != NULL) + msg.qos = qos->qos; + msg_to_mpoad(&msg, mpc); + entry->reply_wait = ktime_get_seconds(); + } + + return CLOSED; +} + +static void in_cache_put(in_cache_entry *entry) +{ + if (refcount_dec_and_test(&entry->use)) { + kfree_sensitive(entry); + } +} + +/* + * This should be called with write lock on + */ +static void in_cache_remove_entry(in_cache_entry *entry, + struct mpoa_client *client) +{ + struct atm_vcc *vcc; + struct k_message msg; + + vcc = entry->shortcut; + dprintk("removing an ingress entry, ip = %pI4\n", + &entry->ctrl_info.in_dst_ip); + + if (entry->prev != NULL) + entry->prev->next = entry->next; + else + client->in_cache = entry->next; + if (entry->next != NULL) + entry->next->prev = entry->prev; + client->in_ops->put(entry); + if (client->in_cache == NULL && client->eg_cache == NULL) { + msg.type = STOP_KEEP_ALIVE_SM; + msg_to_mpoad(&msg, client); + } + + /* Check if the egress side still uses this VCC */ + if (vcc != NULL) { + eg_cache_entry *eg_entry = client->eg_ops->get_by_vcc(vcc, + client); + if (eg_entry != NULL) { + client->eg_ops->put(eg_entry); + return; + } + vcc_release_async(vcc, -EPIPE); + } +} + +/* Call this every MPC-p2 seconds... Not exactly correct solution, + but an easy one... */ +static void clear_count_and_expired(struct mpoa_client *client) +{ + in_cache_entry *entry, *next_entry; + time64_t now; + + now = ktime_get_seconds(); + + write_lock_bh(&client->ingress_lock); + entry = client->in_cache; + while (entry != NULL) { + entry->count = 0; + next_entry = entry->next; + if ((now - entry->time) > entry->ctrl_info.holding_time) { + dprintk("holding time expired, ip = %pI4\n", + &entry->ctrl_info.in_dst_ip); + client->in_ops->remove_entry(entry, client); + } + entry = next_entry; + } + write_unlock_bh(&client->ingress_lock); +} + +/* Call this every MPC-p4 seconds. */ +static void check_resolving_entries(struct mpoa_client *client) +{ + + struct atm_mpoa_qos *qos; + in_cache_entry *entry; + time64_t now; + struct k_message msg; + + now = ktime_get_seconds(); + + read_lock_bh(&client->ingress_lock); + entry = client->in_cache; + while (entry != NULL) { + if (entry->entry_state == INGRESS_RESOLVING) { + + if ((now - entry->hold_down) + < client->parameters.mpc_p6) { + entry = entry->next; /* Entry in hold down */ + continue; + } + if ((now - entry->reply_wait) > entry->retry_time) { + entry->retry_time = MPC_C1 * (entry->retry_time); + /* + * Retry time maximum exceeded, + * put entry in hold down. + */ + if (entry->retry_time > client->parameters.mpc_p5) { + entry->hold_down = ktime_get_seconds(); + entry->retry_time = client->parameters.mpc_p4; + entry = entry->next; + continue; + } + /* Ask daemon to send a resolution request. */ + memset(&entry->hold_down, 0, sizeof(time64_t)); + msg.type = SND_MPOA_RES_RTRY; + memcpy(msg.MPS_ctrl, client->mps_ctrl_addr, ATM_ESA_LEN); + msg.content.in_info = entry->ctrl_info; + qos = atm_mpoa_search_qos(entry->ctrl_info.in_dst_ip); + if (qos != NULL) + msg.qos = qos->qos; + msg_to_mpoad(&msg, client); + entry->reply_wait = ktime_get_seconds(); + } + } + entry = entry->next; + } + read_unlock_bh(&client->ingress_lock); +} + +/* Call this every MPC-p5 seconds. 
*/ +static void refresh_entries(struct mpoa_client *client) +{ + time64_t now; + struct in_cache_entry *entry = client->in_cache; + + ddprintk("refresh_entries\n"); + now = ktime_get_seconds(); + + read_lock_bh(&client->ingress_lock); + while (entry != NULL) { + if (entry->entry_state == INGRESS_RESOLVED) { + if (!(entry->refresh_time)) + entry->refresh_time = (2 * (entry->ctrl_info.holding_time))/3; + if ((now - entry->reply_wait) > + entry->refresh_time) { + dprintk("refreshing an entry.\n"); + entry->entry_state = INGRESS_REFRESHING; + + } + } + entry = entry->next; + } + read_unlock_bh(&client->ingress_lock); +} + +static void in_destroy_cache(struct mpoa_client *mpc) +{ + write_lock_irq(&mpc->ingress_lock); + while (mpc->in_cache != NULL) + mpc->in_ops->remove_entry(mpc->in_cache, mpc); + write_unlock_irq(&mpc->ingress_lock); +} + +static eg_cache_entry *eg_cache_get_by_cache_id(__be32 cache_id, + struct mpoa_client *mpc) +{ + eg_cache_entry *entry; + + read_lock_irq(&mpc->egress_lock); + entry = mpc->eg_cache; + while (entry != NULL) { + if (entry->ctrl_info.cache_id == cache_id) { + refcount_inc(&entry->use); + read_unlock_irq(&mpc->egress_lock); + return entry; + } + entry = entry->next; + } + read_unlock_irq(&mpc->egress_lock); + + return NULL; +} + +/* This can be called from any context since it saves CPU flags */ +static eg_cache_entry *eg_cache_get_by_tag(__be32 tag, struct mpoa_client *mpc) +{ + unsigned long flags; + eg_cache_entry *entry; + + read_lock_irqsave(&mpc->egress_lock, flags); + entry = mpc->eg_cache; + while (entry != NULL) { + if (entry->ctrl_info.tag == tag) { + refcount_inc(&entry->use); + read_unlock_irqrestore(&mpc->egress_lock, flags); + return entry; + } + entry = entry->next; + } + read_unlock_irqrestore(&mpc->egress_lock, flags); + + return NULL; +} + +/* This can be called from any context since it saves CPU flags */ +static eg_cache_entry *eg_cache_get_by_vcc(struct atm_vcc *vcc, + struct mpoa_client *mpc) +{ + unsigned long flags; + eg_cache_entry *entry; + + read_lock_irqsave(&mpc->egress_lock, flags); + entry = mpc->eg_cache; + while (entry != NULL) { + if (entry->shortcut == vcc) { + refcount_inc(&entry->use); + read_unlock_irqrestore(&mpc->egress_lock, flags); + return entry; + } + entry = entry->next; + } + read_unlock_irqrestore(&mpc->egress_lock, flags); + + return NULL; +} + +static eg_cache_entry *eg_cache_get_by_src_ip(__be32 ipaddr, + struct mpoa_client *mpc) +{ + eg_cache_entry *entry; + + read_lock_irq(&mpc->egress_lock); + entry = mpc->eg_cache; + while (entry != NULL) { + if (entry->latest_ip_addr == ipaddr) { + refcount_inc(&entry->use); + read_unlock_irq(&mpc->egress_lock); + return entry; + } + entry = entry->next; + } + read_unlock_irq(&mpc->egress_lock); + + return NULL; +} + +static void eg_cache_put(eg_cache_entry *entry) +{ + if (refcount_dec_and_test(&entry->use)) { + kfree_sensitive(entry); + } +} + +/* + * This should be called with write lock on + */ +static void eg_cache_remove_entry(eg_cache_entry *entry, + struct mpoa_client *client) +{ + struct atm_vcc *vcc; + struct k_message msg; + + vcc = entry->shortcut; + dprintk("removing an egress entry.\n"); + if (entry->prev != NULL) + entry->prev->next = entry->next; + else + client->eg_cache = entry->next; + if (entry->next != NULL) + entry->next->prev = entry->prev; + client->eg_ops->put(entry); + if (client->in_cache == NULL && client->eg_cache == NULL) { + msg.type = STOP_KEEP_ALIVE_SM; + msg_to_mpoad(&msg, client); + } + + /* Check if the ingress side still uses this VCC 
*/ + if (vcc != NULL) { + in_cache_entry *in_entry = client->in_ops->get_by_vcc(vcc, client); + if (in_entry != NULL) { + client->in_ops->put(in_entry); + return; + } + vcc_release_async(vcc, -EPIPE); + } +} + +static eg_cache_entry *eg_cache_add_entry(struct k_message *msg, + struct mpoa_client *client) +{ + eg_cache_entry *entry = kzalloc(sizeof(eg_cache_entry), GFP_KERNEL); + + if (entry == NULL) { + pr_info("out of memory\n"); + return NULL; + } + + dprintk("adding an egress entry, ip = %pI4, this should be our IP\n", + &msg->content.eg_info.eg_dst_ip); + + refcount_set(&entry->use, 1); + dprintk("new_eg_cache_entry: about to lock\n"); + write_lock_irq(&client->egress_lock); + entry->next = client->eg_cache; + entry->prev = NULL; + if (client->eg_cache != NULL) + client->eg_cache->prev = entry; + client->eg_cache = entry; + + memcpy(entry->MPS_ctrl_ATM_addr, client->mps_ctrl_addr, ATM_ESA_LEN); + entry->ctrl_info = msg->content.eg_info; + entry->time = ktime_get_seconds(); + entry->entry_state = EGRESS_RESOLVED; + dprintk("new_eg_cache_entry cache_id %u\n", + ntohl(entry->ctrl_info.cache_id)); + dprintk("mps_ip = %pI4\n", &entry->ctrl_info.mps_ip); + refcount_inc(&entry->use); + + write_unlock_irq(&client->egress_lock); + dprintk("new_eg_cache_entry: unlocked\n"); + + return entry; +} + +static void update_eg_cache_entry(eg_cache_entry *entry, uint16_t holding_time) +{ + entry->time = ktime_get_seconds(); + entry->entry_state = EGRESS_RESOLVED; + entry->ctrl_info.holding_time = holding_time; +} + +static void clear_expired(struct mpoa_client *client) +{ + eg_cache_entry *entry, *next_entry; + time64_t now; + struct k_message msg; + + now = ktime_get_seconds(); + + write_lock_irq(&client->egress_lock); + entry = client->eg_cache; + while (entry != NULL) { + next_entry = entry->next; + if ((now - entry->time) > entry->ctrl_info.holding_time) { + msg.type = SND_EGRESS_PURGE; + msg.content.eg_info = entry->ctrl_info; + dprintk("egress_cache: holding time expired, cache_id = %u.\n", + ntohl(entry->ctrl_info.cache_id)); + msg_to_mpoad(&msg, client); + client->eg_ops->remove_entry(entry, client); + } + entry = next_entry; + } + write_unlock_irq(&client->egress_lock); +} + +static void eg_destroy_cache(struct mpoa_client *mpc) +{ + write_lock_irq(&mpc->egress_lock); + while (mpc->eg_cache != NULL) + mpc->eg_ops->remove_entry(mpc->eg_cache, mpc); + write_unlock_irq(&mpc->egress_lock); +} + + +static const struct in_cache_ops ingress_ops = { + .add_entry = in_cache_add_entry, + .get = in_cache_get, + .get_with_mask = in_cache_get_with_mask, + .get_by_vcc = in_cache_get_by_vcc, + .put = in_cache_put, + .remove_entry = in_cache_remove_entry, + .cache_hit = cache_hit, + .clear_count = clear_count_and_expired, + .check_resolving = check_resolving_entries, + .refresh = refresh_entries, + .destroy_cache = in_destroy_cache +}; + +static const struct eg_cache_ops egress_ops = { + .add_entry = eg_cache_add_entry, + .get_by_cache_id = eg_cache_get_by_cache_id, + .get_by_tag = eg_cache_get_by_tag, + .get_by_vcc = eg_cache_get_by_vcc, + .get_by_src_ip = eg_cache_get_by_src_ip, + .put = eg_cache_put, + .remove_entry = eg_cache_remove_entry, + .update = update_eg_cache_entry, + .clear_expired = clear_expired, + .destroy_cache = eg_destroy_cache +}; + +void atm_mpoa_init_cache(struct mpoa_client *mpc) +{ + mpc->in_ops = &ingress_ops; + mpc->eg_ops = &egress_ops; +} diff --git a/net/atm/mpoa_caches.h b/net/atm/mpoa_caches.h new file mode 100644 index 000000000..464c4c7f8 --- /dev/null +++ 
b/net/atm/mpoa_caches.h @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef MPOA_CACHES_H +#define MPOA_CACHES_H + +#include +#include +#include +#include +#include +#include +#include + +struct mpoa_client; + +void atm_mpoa_init_cache(struct mpoa_client *mpc); + +typedef struct in_cache_entry { + struct in_cache_entry *next; + struct in_cache_entry *prev; + time64_t time; + time64_t reply_wait; + time64_t hold_down; + uint32_t packets_fwded; + uint16_t entry_state; + uint32_t retry_time; + uint32_t refresh_time; + uint32_t count; + struct atm_vcc *shortcut; + uint8_t MPS_ctrl_ATM_addr[ATM_ESA_LEN]; + struct in_ctrl_info ctrl_info; + refcount_t use; +} in_cache_entry; + +struct in_cache_ops{ + in_cache_entry *(*add_entry)(__be32 dst_ip, + struct mpoa_client *client); + in_cache_entry *(*get)(__be32 dst_ip, struct mpoa_client *client); + in_cache_entry *(*get_with_mask)(__be32 dst_ip, + struct mpoa_client *client, + __be32 mask); + in_cache_entry *(*get_by_vcc)(struct atm_vcc *vcc, + struct mpoa_client *client); + void (*put)(in_cache_entry *entry); + void (*remove_entry)(in_cache_entry *delEntry, + struct mpoa_client *client ); + int (*cache_hit)(in_cache_entry *entry, + struct mpoa_client *client); + void (*clear_count)(struct mpoa_client *client); + void (*check_resolving)(struct mpoa_client *client); + void (*refresh)(struct mpoa_client *client); + void (*destroy_cache)(struct mpoa_client *mpc); +}; + +typedef struct eg_cache_entry{ + struct eg_cache_entry *next; + struct eg_cache_entry *prev; + time64_t time; + uint8_t MPS_ctrl_ATM_addr[ATM_ESA_LEN]; + struct atm_vcc *shortcut; + uint32_t packets_rcvd; + uint16_t entry_state; + __be32 latest_ip_addr; /* The src IP address of the last packet */ + struct eg_ctrl_info ctrl_info; + refcount_t use; +} eg_cache_entry; + +struct eg_cache_ops{ + eg_cache_entry *(*add_entry)(struct k_message *msg, struct mpoa_client *client); + eg_cache_entry *(*get_by_cache_id)(__be32 cache_id, struct mpoa_client *client); + eg_cache_entry *(*get_by_tag)(__be32 cache_id, struct mpoa_client *client); + eg_cache_entry *(*get_by_vcc)(struct atm_vcc *vcc, struct mpoa_client *client); + eg_cache_entry *(*get_by_src_ip)(__be32 ipaddr, struct mpoa_client *client); + void (*put)(eg_cache_entry *entry); + void (*remove_entry)(eg_cache_entry *entry, struct mpoa_client *client); + void (*update)(eg_cache_entry *entry, uint16_t holding_time); + void (*clear_expired)(struct mpoa_client *client); + void (*destroy_cache)(struct mpoa_client *mpc); +}; + + +/* Ingress cache entry states */ + +#define INGRESS_REFRESHING 3 +#define INGRESS_RESOLVED 2 +#define INGRESS_RESOLVING 1 +#define INGRESS_INVALID 0 + +/* VCC states */ + +#define OPEN 1 +#define CLOSED 0 + +/* Egress cache entry states */ + +#define EGRESS_RESOLVED 2 +#define EGRESS_PURGE 1 +#define EGRESS_INVALID 0 + +#endif diff --git a/net/atm/mpoa_proc.c b/net/atm/mpoa_proc.c new file mode 100644 index 000000000..aaf64b953 --- /dev/null +++ b/net/atm/mpoa_proc.c @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: GPL-2.0 +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#ifdef CONFIG_PROC_FS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "mpc.h" +#include "mpoa_caches.h" + +/* + * mpoa_proc.c: Implementation MPOA client's proc + * file system statistics + */ + +#if 1 +#define dprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args) /* debug */ +#else +#define dprintk(format, args...) 
\ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args);\ + } while (0) +#endif + +#if 0 +#define ddprintk(format, args...) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args) /* debug */ +#else +#define ddprintk(format, args...) \ + do { if (0) \ + printk(KERN_DEBUG "mpoa:%s: " format, __FILE__, ##args);\ + } while (0) +#endif + +#define STAT_FILE_NAME "mpc" /* Our statistic file's name */ + +extern struct mpoa_client *mpcs; +extern struct proc_dir_entry *atm_proc_root; /* from proc.c. */ + +static int proc_mpc_open(struct inode *inode, struct file *file); +static ssize_t proc_mpc_write(struct file *file, const char __user *buff, + size_t nbytes, loff_t *ppos); + +static int parse_qos(const char *buff); + +static const struct proc_ops mpc_proc_ops = { + .proc_open = proc_mpc_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_write = proc_mpc_write, + .proc_release = seq_release, +}; + +/* + * Returns the state of an ingress cache entry as a string + */ +static const char *ingress_state_string(int state) +{ + switch (state) { + case INGRESS_RESOLVING: + return "resolving "; + case INGRESS_RESOLVED: + return "resolved "; + case INGRESS_INVALID: + return "invalid "; + case INGRESS_REFRESHING: + return "refreshing "; + } + + return ""; +} + +/* + * Returns the state of an egress cache entry as a string + */ +static const char *egress_state_string(int state) +{ + switch (state) { + case EGRESS_RESOLVED: + return "resolved "; + case EGRESS_PURGE: + return "purge "; + case EGRESS_INVALID: + return "invalid "; + } + + return ""; +} + +/* + * FIXME: mpcs (and per-mpc lists) have no locking whatsoever. + */ + +static void *mpc_start(struct seq_file *m, loff_t *pos) +{ + loff_t l = *pos; + struct mpoa_client *mpc; + + if (!l--) + return SEQ_START_TOKEN; + for (mpc = mpcs; mpc; mpc = mpc->next) + if (!l--) + return mpc; + return NULL; +} + +static void *mpc_next(struct seq_file *m, void *v, loff_t *pos) +{ + struct mpoa_client *p = v; + (*pos)++; + return v == SEQ_START_TOKEN ? mpcs : p->next; +} + +static void mpc_stop(struct seq_file *m, void *v) +{ +} + +/* + * READING function - called when the /proc/atm/mpoa file is read from. 
+ */ +static int mpc_show(struct seq_file *m, void *v) +{ + struct mpoa_client *mpc = v; + int i; + in_cache_entry *in_entry; + eg_cache_entry *eg_entry; + time64_t now; + unsigned char ip_string[16]; + + if (v == SEQ_START_TOKEN) { + atm_mpoa_disp_qos(m); + return 0; + } + + seq_printf(m, "\nInterface %d:\n\n", mpc->dev_num); + seq_printf(m, "Ingress Entries:\nIP address State Holding time Packets fwded VPI VCI\n"); + now = ktime_get_seconds(); + + for (in_entry = mpc->in_cache; in_entry; in_entry = in_entry->next) { + unsigned long seconds_delta = now - in_entry->time; + + sprintf(ip_string, "%pI4", &in_entry->ctrl_info.in_dst_ip); + seq_printf(m, "%-16s%s%-14lu%-12u", + ip_string, + ingress_state_string(in_entry->entry_state), + in_entry->ctrl_info.holding_time - + seconds_delta, + in_entry->packets_fwded); + if (in_entry->shortcut) + seq_printf(m, " %-3d %-3d", + in_entry->shortcut->vpi, + in_entry->shortcut->vci); + seq_printf(m, "\n"); + } + + seq_printf(m, "\n"); + seq_printf(m, "Egress Entries:\nIngress MPC ATM addr\nCache-id State Holding time Packets recvd Latest IP addr VPI VCI\n"); + for (eg_entry = mpc->eg_cache; eg_entry; eg_entry = eg_entry->next) { + unsigned char *p = eg_entry->ctrl_info.in_MPC_data_ATM_addr; + unsigned long seconds_delta = now - eg_entry->time; + + for (i = 0; i < ATM_ESA_LEN; i++) + seq_printf(m, "%02x", p[i]); + seq_printf(m, "\n%-16lu%s%-14lu%-15u", + (unsigned long)ntohl(eg_entry->ctrl_info.cache_id), + egress_state_string(eg_entry->entry_state), + (eg_entry->ctrl_info.holding_time - seconds_delta), + eg_entry->packets_rcvd); + + /* latest IP address */ + sprintf(ip_string, "%pI4", &eg_entry->latest_ip_addr); + seq_printf(m, "%-16s", ip_string); + + if (eg_entry->shortcut) + seq_printf(m, " %-3d %-3d", + eg_entry->shortcut->vpi, + eg_entry->shortcut->vci); + seq_printf(m, "\n"); + } + seq_printf(m, "\n"); + return 0; +} + +static const struct seq_operations mpc_op = { + .start = mpc_start, + .next = mpc_next, + .stop = mpc_stop, + .show = mpc_show +}; + +static int proc_mpc_open(struct inode *inode, struct file *file) +{ + return seq_open(file, &mpc_op); +} + +static ssize_t proc_mpc_write(struct file *file, const char __user *buff, + size_t nbytes, loff_t *ppos) +{ + char *page, *p; + unsigned int len; + + if (nbytes == 0) + return 0; + + if (nbytes >= PAGE_SIZE) + nbytes = PAGE_SIZE-1; + + page = (char *)__get_free_page(GFP_KERNEL); + if (!page) + return -ENOMEM; + + for (p = page, len = 0; len < nbytes; p++) { + if (get_user(*p, buff++)) { + free_page((unsigned long)page); + return -EFAULT; + } + len += 1; + if (*p == '\0' || *p == '\n') + break; + } + + *p = '\0'; + + if (!parse_qos(page)) + printk("mpoa: proc_mpc_write: could not parse '%s'\n", page); + + free_page((unsigned long)page); + + return len; +} + +static int parse_qos(const char *buff) +{ + /* possible lines look like this + * add 130.230.54.142 tx=max_pcr,max_sdu rx=max_pcr,max_sdu + */ + unsigned char ip[4]; + int tx_pcr, tx_sdu, rx_pcr, rx_sdu; + __be32 ipaddr; + struct atm_qos qos; + + memset(&qos, 0, sizeof(struct atm_qos)); + + if (sscanf(buff, "del %hhu.%hhu.%hhu.%hhu", + ip, ip+1, ip+2, ip+3) == 4) { + ipaddr = *(__be32 *)ip; + return atm_mpoa_delete_qos(atm_mpoa_search_qos(ipaddr)); + } + + if (sscanf(buff, "add %hhu.%hhu.%hhu.%hhu tx=%d,%d rx=tx", + ip, ip+1, ip+2, ip+3, &tx_pcr, &tx_sdu) == 6) { + rx_pcr = tx_pcr; + rx_sdu = tx_sdu; + } else if (sscanf(buff, "add %hhu.%hhu.%hhu.%hhu tx=%d,%d rx=%d,%d", + ip, ip+1, ip+2, ip+3, &tx_pcr, &tx_sdu, &rx_pcr, &rx_sdu) != 8) + 
return 0; + + ipaddr = *(__be32 *)ip; + qos.txtp.traffic_class = ATM_CBR; + qos.txtp.max_pcr = tx_pcr; + qos.txtp.max_sdu = tx_sdu; + qos.rxtp.traffic_class = ATM_CBR; + qos.rxtp.max_pcr = rx_pcr; + qos.rxtp.max_sdu = rx_sdu; + qos.aal = ATM_AAL5; + dprintk("parse_qos(): setting qos parameters to tx=%d,%d rx=%d,%d\n", + qos.txtp.max_pcr, qos.txtp.max_sdu, + qos.rxtp.max_pcr, qos.rxtp.max_sdu); + + atm_mpoa_add_qos(ipaddr, &qos); + return 1; +} + +/* + * INITIALIZATION function - called when module is initialized/loaded. + */ +int mpc_proc_init(void) +{ + struct proc_dir_entry *p; + + p = proc_create(STAT_FILE_NAME, 0, atm_proc_root, &mpc_proc_ops); + if (!p) { + pr_err("Unable to initialize /proc/atm/%s\n", STAT_FILE_NAME); + return -ENOMEM; + } + return 0; +} + +/* + * DELETING function - called when module is removed. + */ +void mpc_proc_clean(void) +{ + remove_proc_entry(STAT_FILE_NAME, atm_proc_root); +} + +#endif /* CONFIG_PROC_FS */ diff --git a/net/atm/pppoatm.c b/net/atm/pppoatm.c new file mode 100644 index 000000000..3e4f17d33 --- /dev/null +++ b/net/atm/pppoatm.c @@ -0,0 +1,491 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* net/atm/pppoatm.c - RFC2364 PPP over ATM/AAL5 */ + +/* Copyright 1999-2000 by Mitchell Blank Jr */ +/* Based on clip.c; 1995-1999 by Werner Almesberger, EPFL LRC/ICA */ +/* And on ppp_async.c; Copyright 1999 Paul Mackerras */ +/* And help from Jens Axboe */ + +/* + * + * This driver provides the encapsulation and framing for sending + * and receiving PPP frames in ATM AAL5 PDUs. + */ + +/* + * One shortcoming of this driver is that it does not comply with + * section 8 of RFC2364 - we are supposed to detect a change + * in encapsulation and immediately abort the connection (in order + * to avoid a black-hole being created if our peer loses state + * and changes encapsulation unilaterally. However, since the + * ppp_generic layer actually does the decapsulation, we need + * a way of notifying it when we _think_ there might be a problem) + * There's two cases: + * 1. LLC-encapsulation was missing when it was enabled. In + * this case, we should tell the upper layer "tear down + * this session if this skb looks ok to you" + * 2. LLC-encapsulation was present when it was disabled. Then + * we need to tell the upper layer "this packet may be + * ok, but if its in error tear down the session" + * These hooks are not yet available in ppp_generic + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +enum pppoatm_encaps { + e_autodetect = PPPOATM_ENCAPS_AUTODETECT, + e_vc = PPPOATM_ENCAPS_VC, + e_llc = PPPOATM_ENCAPS_LLC, +}; + +struct pppoatm_vcc { + struct atm_vcc *atmvcc; /* VCC descriptor */ + void (*old_push)(struct atm_vcc *, struct sk_buff *); + void (*old_pop)(struct atm_vcc *, struct sk_buff *); + void (*old_release_cb)(struct atm_vcc *); + struct module *old_owner; + /* keep old push/pop for detaching */ + enum pppoatm_encaps encaps; + atomic_t inflight; + unsigned long blocked; + int flags; /* SC_COMP_PROT - compress protocol */ + struct ppp_channel chan; /* interface to generic ppp layer */ + struct tasklet_struct wakeup_tasklet; +}; + +/* + * We want to allow two packets in the queue. The one that's currently in + * flight, and *one* queued up ready for the ATM device to send immediately + * from its TX done IRQ. 
We want to be able to use atomic_inc_not_zero(), so + * inflight == -2 represents an empty queue, -1 one packet, and zero means + * there are two packets in the queue. + */ +#define NONE_INFLIGHT -2 + +#define BLOCKED 0 + +/* + * Header used for LLC Encapsulated PPP (4 bytes) followed by the LCP protocol + * ID (0xC021) used in autodetection + */ +static const unsigned char pppllc[6] = { 0xFE, 0xFE, 0x03, 0xCF, 0xC0, 0x21 }; +#define LLC_LEN (4) + +static inline struct pppoatm_vcc *atmvcc_to_pvcc(const struct atm_vcc *atmvcc) +{ + return (struct pppoatm_vcc *) (atmvcc->user_back); +} + +static inline struct pppoatm_vcc *chan_to_pvcc(const struct ppp_channel *chan) +{ + return (struct pppoatm_vcc *) (chan->private); +} + +/* + * We can't do this directly from our _pop handler, since the ppp code + * doesn't want to be called in interrupt context, so we do it from + * a tasklet + */ +static void pppoatm_wakeup_sender(struct tasklet_struct *t) +{ + struct pppoatm_vcc *pvcc = from_tasklet(pvcc, t, wakeup_tasklet); + + ppp_output_wakeup(&pvcc->chan); +} + +static void pppoatm_release_cb(struct atm_vcc *atmvcc) +{ + struct pppoatm_vcc *pvcc = atmvcc_to_pvcc(atmvcc); + + /* + * As in pppoatm_pop(), it's safe to clear the BLOCKED bit here because + * the wakeup *can't* race with pppoatm_send(). They both hold the PPP + * channel's ->downl lock. And the potential race with *setting* it, + * which leads to the double-check dance in pppoatm_may_send(), doesn't + * exist here. In the sock_owned_by_user() case in pppoatm_send(), we + * set the BLOCKED bit while the socket is still locked. We know that + * ->release_cb() can't be called until that's done. + */ + if (test_and_clear_bit(BLOCKED, &pvcc->blocked)) + tasklet_schedule(&pvcc->wakeup_tasklet); + if (pvcc->old_release_cb) + pvcc->old_release_cb(atmvcc); +} +/* + * This gets called every time the ATM card has finished sending our + * skb. The ->old_pop will take care up normal atm flow control, + * but we also need to wake up the device if we blocked it + */ +static void pppoatm_pop(struct atm_vcc *atmvcc, struct sk_buff *skb) +{ + struct pppoatm_vcc *pvcc = atmvcc_to_pvcc(atmvcc); + + pvcc->old_pop(atmvcc, skb); + atomic_dec(&pvcc->inflight); + + /* + * We always used to run the wakeup tasklet unconditionally here, for + * fear of race conditions where we clear the BLOCKED flag just as we + * refuse another packet in pppoatm_send(). This was quite inefficient. + * + * In fact it's OK. The PPP core will only ever call pppoatm_send() + * while holding the channel->downl lock. And ppp_output_wakeup() as + * called by the tasklet will *also* grab that lock. So even if another + * CPU is in pppoatm_send() right now, the tasklet isn't going to race + * with it. The wakeup *will* happen after the other CPU is safely out + * of pppoatm_send() again. + * + * So if the CPU in pppoatm_send() has already set the BLOCKED bit and + * it about to return, that's fine. We trigger a wakeup which will + * happen later. And if the CPU in pppoatm_send() *hasn't* set the + * BLOCKED bit yet, that's fine too because of the double check in + * pppoatm_may_send() which is commented there. 
+ */ + if (test_and_clear_bit(BLOCKED, &pvcc->blocked)) + tasklet_schedule(&pvcc->wakeup_tasklet); +} + +/* + * Unbind from PPP - currently we only do this when closing the socket, + * but we could put this into an ioctl if need be + */ +static void pppoatm_unassign_vcc(struct atm_vcc *atmvcc) +{ + struct pppoatm_vcc *pvcc; + pvcc = atmvcc_to_pvcc(atmvcc); + atmvcc->push = pvcc->old_push; + atmvcc->pop = pvcc->old_pop; + atmvcc->release_cb = pvcc->old_release_cb; + tasklet_kill(&pvcc->wakeup_tasklet); + ppp_unregister_channel(&pvcc->chan); + atmvcc->user_back = NULL; + kfree(pvcc); +} + +/* Called when an AAL5 PDU comes in */ +static void pppoatm_push(struct atm_vcc *atmvcc, struct sk_buff *skb) +{ + struct pppoatm_vcc *pvcc = atmvcc_to_pvcc(atmvcc); + pr_debug("\n"); + if (skb == NULL) { /* VCC was closed */ + struct module *module; + + pr_debug("removing ATMPPP VCC %p\n", pvcc); + module = pvcc->old_owner; + pppoatm_unassign_vcc(atmvcc); + atmvcc->push(atmvcc, NULL); /* Pass along bad news */ + module_put(module); + return; + } + atm_return(atmvcc, skb->truesize); + switch (pvcc->encaps) { + case e_llc: + if (skb->len < LLC_LEN || + memcmp(skb->data, pppllc, LLC_LEN)) + goto error; + skb_pull(skb, LLC_LEN); + break; + case e_autodetect: + if (pvcc->chan.ppp == NULL) { /* Not bound yet! */ + kfree_skb(skb); + return; + } + if (skb->len >= sizeof(pppllc) && + !memcmp(skb->data, pppllc, sizeof(pppllc))) { + pvcc->encaps = e_llc; + skb_pull(skb, LLC_LEN); + break; + } + if (skb->len >= (sizeof(pppllc) - LLC_LEN) && + !memcmp(skb->data, &pppllc[LLC_LEN], + sizeof(pppllc) - LLC_LEN)) { + pvcc->encaps = e_vc; + pvcc->chan.mtu += LLC_LEN; + break; + } + pr_debug("Couldn't autodetect yet (skb: %6ph)\n", skb->data); + goto error; + case e_vc: + break; + } + ppp_input(&pvcc->chan, skb); + return; + +error: + kfree_skb(skb); + ppp_input_error(&pvcc->chan, 0); +} + +static int pppoatm_may_send(struct pppoatm_vcc *pvcc, int size) +{ + /* + * It's not clear that we need to bother with using atm_may_send() + * to check we don't exceed sk->sk_sndbuf. If userspace sets a + * value of sk_sndbuf which is lower than the MTU, we're going to + * block for ever. But the code always did that before we introduced + * the packet count limit, so... + */ + if (atm_may_send(pvcc->atmvcc, size) && + atomic_inc_not_zero(&pvcc->inflight)) + return 1; + + /* + * We use test_and_set_bit() rather than set_bit() here because + * we need to ensure there's a memory barrier after it. The bit + * *must* be set before we do the atomic_inc() on pvcc->inflight. + * There's no smp_mb__after_set_bit(), so it's this or abuse + * smp_mb__after_atomic(). + */ + test_and_set_bit(BLOCKED, &pvcc->blocked); + + /* + * We may have raced with pppoatm_pop(). If it ran for the + * last packet in the queue, *just* before we set the BLOCKED + * bit, then it might never run again and the channel could + * remain permanently blocked. Cope with that race by checking + * *again*. If it did run in that window, we'll have space on + * the queue now and can return success. It's harmless to leave + * the BLOCKED flag set, since it's only used as a trigger to + * run the wakeup tasklet. Another wakeup will never hurt. + * If pppoatm_pop() is running but hasn't got as far as making + * space on the queue yet, then it hasn't checked the BLOCKED + * flag yet either, so we're safe in that case too. It'll issue + * an "immediate" wakeup... 
where "immediate" actually involves + * taking the PPP channel's ->downl lock, which is held by the + * code path that calls pppoatm_send(), and is thus going to + * wait for us to finish. + */ + if (atm_may_send(pvcc->atmvcc, size) && + atomic_inc_not_zero(&pvcc->inflight)) + return 1; + + return 0; +} +/* + * Called by the ppp_generic.c to send a packet - returns true if packet + * was accepted. If we return false, then it's our job to call + * ppp_output_wakeup(chan) when we're feeling more up to it. + * Note that in the ENOMEM case (as opposed to the !atm_may_send case) + * we should really drop the packet, but the generic layer doesn't + * support this yet. We just return 'DROP_PACKET' which we actually define + * as success, just to be clear what we're really doing. + */ +#define DROP_PACKET 1 +static int pppoatm_send(struct ppp_channel *chan, struct sk_buff *skb) +{ + struct pppoatm_vcc *pvcc = chan_to_pvcc(chan); + struct atm_vcc *vcc; + int ret; + + ATM_SKB(skb)->vcc = pvcc->atmvcc; + pr_debug("(skb=0x%p, vcc=0x%p)\n", skb, pvcc->atmvcc); + if (skb->data[0] == '\0' && (pvcc->flags & SC_COMP_PROT)) + (void) skb_pull(skb, 1); + + vcc = ATM_SKB(skb)->vcc; + bh_lock_sock(sk_atm(vcc)); + if (sock_owned_by_user(sk_atm(vcc))) { + /* + * Needs to happen (and be flushed, hence test_and_) before we unlock + * the socket. It needs to be seen by the time our ->release_cb gets + * called. + */ + test_and_set_bit(BLOCKED, &pvcc->blocked); + goto nospace; + } + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + test_bit(ATM_VF_CLOSE, &vcc->flags) || + !test_bit(ATM_VF_READY, &vcc->flags)) { + bh_unlock_sock(sk_atm(vcc)); + kfree_skb(skb); + return DROP_PACKET; + } + + switch (pvcc->encaps) { /* LLC encapsulation needed */ + case e_llc: + if (skb_headroom(skb) < LLC_LEN) { + struct sk_buff *n; + n = skb_realloc_headroom(skb, LLC_LEN); + if (n != NULL && + !pppoatm_may_send(pvcc, n->truesize)) { + kfree_skb(n); + goto nospace; + } + consume_skb(skb); + skb = n; + if (skb == NULL) { + bh_unlock_sock(sk_atm(vcc)); + return DROP_PACKET; + } + } else if (!pppoatm_may_send(pvcc, skb->truesize)) + goto nospace; + memcpy(skb_push(skb, LLC_LEN), pppllc, LLC_LEN); + break; + case e_vc: + if (!pppoatm_may_send(pvcc, skb->truesize)) + goto nospace; + break; + case e_autodetect: + bh_unlock_sock(sk_atm(vcc)); + pr_debug("Trying to send without setting encaps!\n"); + kfree_skb(skb); + return 1; + } + + atm_account_tx(vcc, skb); + pr_debug("atm_skb(%p)->vcc(%p)->dev(%p)\n", + skb, ATM_SKB(skb)->vcc, ATM_SKB(skb)->vcc->dev); + ret = ATM_SKB(skb)->vcc->send(ATM_SKB(skb)->vcc, skb) + ? DROP_PACKET : 1; + bh_unlock_sock(sk_atm(vcc)); + return ret; +nospace: + bh_unlock_sock(sk_atm(vcc)); + /* + * We don't have space to send this SKB now, but we might have + * already applied SC_COMP_PROT compression, so may need to undo + */ + if ((pvcc->flags & SC_COMP_PROT) && skb_headroom(skb) > 0 && + skb->data[-1] == '\0') + (void) skb_push(skb, 1); + return 0; +} + +/* This handles ioctls sent to the /dev/ppp interface */ +static int pppoatm_devppp_ioctl(struct ppp_channel *chan, unsigned int cmd, + unsigned long arg) +{ + switch (cmd) { + case PPPIOCGFLAGS: + return put_user(chan_to_pvcc(chan)->flags, (int __user *) arg) + ? -EFAULT : 0; + case PPPIOCSFLAGS: + return get_user(chan_to_pvcc(chan)->flags, (int __user *) arg) + ? 
-EFAULT : 0; + } + return -ENOTTY; +} + +static const struct ppp_channel_ops pppoatm_ops = { + .start_xmit = pppoatm_send, + .ioctl = pppoatm_devppp_ioctl, +}; + +static int pppoatm_assign_vcc(struct atm_vcc *atmvcc, void __user *arg) +{ + struct atm_backend_ppp be; + struct pppoatm_vcc *pvcc; + int err; + + if (copy_from_user(&be, arg, sizeof be)) + return -EFAULT; + if (be.encaps != PPPOATM_ENCAPS_AUTODETECT && + be.encaps != PPPOATM_ENCAPS_VC && be.encaps != PPPOATM_ENCAPS_LLC) + return -EINVAL; + pvcc = kzalloc(sizeof(*pvcc), GFP_KERNEL); + if (pvcc == NULL) + return -ENOMEM; + pvcc->atmvcc = atmvcc; + + /* Maximum is zero, so that we can use atomic_inc_not_zero() */ + atomic_set(&pvcc->inflight, NONE_INFLIGHT); + pvcc->old_push = atmvcc->push; + pvcc->old_pop = atmvcc->pop; + pvcc->old_owner = atmvcc->owner; + pvcc->old_release_cb = atmvcc->release_cb; + pvcc->encaps = (enum pppoatm_encaps) be.encaps; + pvcc->chan.private = pvcc; + pvcc->chan.ops = &pppoatm_ops; + pvcc->chan.mtu = atmvcc->qos.txtp.max_sdu - PPP_HDRLEN - + (be.encaps == e_vc ? 0 : LLC_LEN); + tasklet_setup(&pvcc->wakeup_tasklet, pppoatm_wakeup_sender); + err = ppp_register_channel(&pvcc->chan); + if (err != 0) { + kfree(pvcc); + return err; + } + atmvcc->user_back = pvcc; + atmvcc->push = pppoatm_push; + atmvcc->pop = pppoatm_pop; + atmvcc->release_cb = pppoatm_release_cb; + __module_get(THIS_MODULE); + atmvcc->owner = THIS_MODULE; + + /* re-process everything received between connection setup and + backend setup */ + vcc_process_recv_queue(atmvcc); + return 0; +} + +/* + * This handles ioctls actually performed on our vcc - we must return + * -ENOIOCTLCMD for any unrecognized ioctl + */ +static int pppoatm_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct atm_vcc *atmvcc = ATM_SD(sock); + void __user *argp = (void __user *)arg; + + if (cmd != ATM_SETBACKEND && atmvcc->push != pppoatm_push) + return -ENOIOCTLCMD; + switch (cmd) { + case ATM_SETBACKEND: { + atm_backend_t b; + if (get_user(b, (atm_backend_t __user *) argp)) + return -EFAULT; + if (b != ATM_BACKEND_PPP) + return -ENOIOCTLCMD; + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + if (sock->state != SS_CONNECTED) + return -EINVAL; + return pppoatm_assign_vcc(atmvcc, argp); + } + case PPPIOCGCHAN: + return put_user(ppp_channel_index(&atmvcc_to_pvcc(atmvcc)-> + chan), (int __user *) argp) ? -EFAULT : 0; + case PPPIOCGUNIT: + return put_user(ppp_unit_number(&atmvcc_to_pvcc(atmvcc)-> + chan), (int __user *) argp) ? -EFAULT : 0; + } + return -ENOIOCTLCMD; +} + +static struct atm_ioctl pppoatm_ioctl_ops = { + .owner = THIS_MODULE, + .ioctl = pppoatm_ioctl, +}; + +static int __init pppoatm_init(void) +{ + register_atm_ioctl(&pppoatm_ioctl_ops); + return 0; +} + +static void __exit pppoatm_exit(void) +{ + deregister_atm_ioctl(&pppoatm_ioctl_ops); +} + +module_init(pppoatm_init); +module_exit(pppoatm_exit); + +MODULE_AUTHOR("Mitchell Blank Jr "); +MODULE_DESCRIPTION("RFC2364 PPP over ATM/AAL5"); +MODULE_LICENSE("GPL"); diff --git a/net/atm/proc.c b/net/atm/proc.c new file mode 100644 index 000000000..9bf736290 --- /dev/null +++ b/net/atm/proc.c @@ -0,0 +1,400 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/proc.c - ATM /proc interface + * + * Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA + * + * seq_file api usage by romieu@fr.zoreil.com + * + * Evaluating the efficiency of the whole thing if left as an exercise to + * the reader. 
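The file added below implements the /proc/atm/* entries through the seq_file API; purely as an illustration of the output it produces (and not as part of this patch), the following userspace sketch reads /proc/atm/devices and pulls out the interface number and device type, assuming the row layout printed by atm_dev_seq_show() and atm_dev_info() further down.

#include <stdio.h>

int main(void)
{
        char line[512];
        FILE *f = fopen("/proc/atm/devices", "r");

        if (!f) {
                perror("/proc/atm/devices");
                return 1;
        }
        while (fgets(line, sizeof(line), f)) {
                int itf;
                char type[16];

                /* Data rows start with the interface number; the banner
                 * line does not, so sscanf() skips it. */
                if (sscanf(line, "%d %15s", &itf, type) == 2)
                        printf("ATM itf %d: type '%s'\n", itf, type);
        }
        fclose(f);
        return 0;
}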
+ */ + +#include /* for EXPORT_SYMBOL */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for __init */ +#include +#include +#include +#include +#include /* for HZ */ +#include +#include "resources.h" +#include "common.h" /* atm_proc_init prototype */ +#include "signaling.h" /* to get sigd - ugly too */ + +static ssize_t proc_dev_atm_read(struct file *file, char __user *buf, + size_t count, loff_t *pos); + +static const struct proc_ops atm_dev_proc_ops = { + .proc_read = proc_dev_atm_read, + .proc_lseek = noop_llseek, +}; + +static void add_stats(struct seq_file *seq, const char *aal, + const struct k_atm_aal_stats *stats) +{ + seq_printf(seq, "%s ( %d %d %d %d %d )", aal, + atomic_read(&stats->tx), atomic_read(&stats->tx_err), + atomic_read(&stats->rx), atomic_read(&stats->rx_err), + atomic_read(&stats->rx_drop)); +} + +static void atm_dev_info(struct seq_file *seq, const struct atm_dev *dev) +{ + int i; + + seq_printf(seq, "%3d %-8s", dev->number, dev->type); + for (i = 0; i < ESI_LEN; i++) + seq_printf(seq, "%02x", dev->esi[i]); + seq_puts(seq, " "); + add_stats(seq, "0", &dev->stats.aal0); + seq_puts(seq, " "); + add_stats(seq, "5", &dev->stats.aal5); + seq_printf(seq, "\t[%d]", refcount_read(&dev->refcnt)); + seq_putc(seq, '\n'); +} + +struct vcc_state { + int bucket; + struct sock *sk; +}; + +static inline int compare_family(struct sock *sk, int family) +{ + return !family || (sk->sk_family == family); +} + +static int __vcc_walk(struct sock **sock, int family, int *bucket, loff_t l) +{ + struct sock *sk = *sock; + + if (sk == SEQ_START_TOKEN) { + for (*bucket = 0; *bucket < VCC_HTABLE_SIZE; ++*bucket) { + struct hlist_head *head = &vcc_hash[*bucket]; + + sk = hlist_empty(head) ? NULL : __sk_head(head); + if (sk) + break; + } + l--; + } +try_again: + for (; sk; sk = sk_next(sk)) { + l -= compare_family(sk, family); + if (l < 0) + goto out; + } + if (!sk && ++*bucket < VCC_HTABLE_SIZE) { + sk = sk_head(&vcc_hash[*bucket]); + goto try_again; + } + sk = SEQ_START_TOKEN; +out: + *sock = sk; + return (l < 0); +} + +static inline void *vcc_walk(struct seq_file *seq, loff_t l) +{ + struct vcc_state *state = seq->private; + int family = (uintptr_t)(pde_data(file_inode(seq->file))); + + return __vcc_walk(&state->sk, family, &state->bucket, l) ? + state : NULL; +} + +static void *vcc_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(vcc_sklist_lock) +{ + struct vcc_state *state = seq->private; + loff_t left = *pos; + + read_lock(&vcc_sklist_lock); + state->sk = SEQ_START_TOKEN; + return left ? vcc_walk(seq, left) : SEQ_START_TOKEN; +} + +static void vcc_seq_stop(struct seq_file *seq, void *v) + __releases(vcc_sklist_lock) +{ + read_unlock(&vcc_sklist_lock); +} + +static void *vcc_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + v = vcc_walk(seq, 1); + (*pos)++; + return v; +} + +static void pvc_info(struct seq_file *seq, struct atm_vcc *vcc) +{ + static const char *const class_name[] = { + "off", "UBR", "CBR", "VBR", "ABR"}; + static const char *const aal_name[] = { + "---", "1", "2", "3/4", /* 0- 3 */ + "???", "5", "???", "???", /* 4- 7 */ + "???", "???", "???", "???", /* 8-11 */ + "???", "0", "???", "???"}; /* 12-15 */ + + seq_printf(seq, "%3d %3d %5d %-3s %7d %-5s %7d %-6s", + vcc->dev->number, vcc->vpi, vcc->vci, + vcc->qos.aal >= ARRAY_SIZE(aal_name) ? 
"err" : + aal_name[vcc->qos.aal], vcc->qos.rxtp.min_pcr, + class_name[vcc->qos.rxtp.traffic_class], + vcc->qos.txtp.min_pcr, + class_name[vcc->qos.txtp.traffic_class]); + if (test_bit(ATM_VF_IS_CLIP, &vcc->flags)) { + struct clip_vcc *clip_vcc = CLIP_VCC(vcc); + struct net_device *dev; + + dev = clip_vcc->entry ? clip_vcc->entry->neigh->dev : NULL; + seq_printf(seq, "CLIP, Itf:%s, Encap:", + dev ? dev->name : "none?"); + seq_printf(seq, "%s", clip_vcc->encap ? "LLC/SNAP" : "None"); + } + seq_putc(seq, '\n'); +} + +static const char *vcc_state(struct atm_vcc *vcc) +{ + static const char *const map[] = { ATM_VS2TXT_MAP }; + + return map[ATM_VF2VS(vcc->flags)]; +} + +static void vcc_info(struct seq_file *seq, struct atm_vcc *vcc) +{ + struct sock *sk = sk_atm(vcc); + + seq_printf(seq, "%pK ", vcc); + if (!vcc->dev) + seq_printf(seq, "Unassigned "); + else + seq_printf(seq, "%3d %3d %5d ", vcc->dev->number, vcc->vpi, + vcc->vci); + switch (sk->sk_family) { + case AF_ATMPVC: + seq_printf(seq, "PVC"); + break; + case AF_ATMSVC: + seq_printf(seq, "SVC"); + break; + default: + seq_printf(seq, "%3d", sk->sk_family); + } + seq_printf(seq, " %04lx %5d %7d/%7d %7d/%7d [%d]\n", + vcc->flags, sk->sk_err, + sk_wmem_alloc_get(sk), sk->sk_sndbuf, + sk_rmem_alloc_get(sk), sk->sk_rcvbuf, + refcount_read(&sk->sk_refcnt)); +} + +static void svc_info(struct seq_file *seq, struct atm_vcc *vcc) +{ + if (!vcc->dev) + seq_printf(seq, sizeof(void *) == 4 ? + "N/A@%pK%10s" : "N/A@%pK%2s", vcc, ""); + else + seq_printf(seq, "%3d %3d %5d ", + vcc->dev->number, vcc->vpi, vcc->vci); + seq_printf(seq, "%-10s ", vcc_state(vcc)); + seq_printf(seq, "%s%s", vcc->remote.sas_addr.pub, + *vcc->remote.sas_addr.pub && *vcc->remote.sas_addr.prv ? "+" : ""); + if (*vcc->remote.sas_addr.prv) { + int i; + + for (i = 0; i < ATM_ESA_LEN; i++) + seq_printf(seq, "%02x", vcc->remote.sas_addr.prv[i]); + } + seq_putc(seq, '\n'); +} + +static int atm_dev_seq_show(struct seq_file *seq, void *v) +{ + static char atm_dev_banner[] = + "Itf Type ESI/\"MAC\"addr " + "AAL(TX,err,RX,err,drop) ... [refcnt]\n"; + + if (v == &atm_devs) + seq_puts(seq, atm_dev_banner); + else { + struct atm_dev *dev = list_entry(v, struct atm_dev, dev_list); + + atm_dev_info(seq, dev); + } + return 0; +} + +static const struct seq_operations atm_dev_seq_ops = { + .start = atm_dev_seq_start, + .next = atm_dev_seq_next, + .stop = atm_dev_seq_stop, + .show = atm_dev_seq_show, +}; + +static int pvc_seq_show(struct seq_file *seq, void *v) +{ + static char atm_pvc_banner[] = + "Itf VPI VCI AAL RX(PCR,Class) TX(PCR,Class)\n"; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, atm_pvc_banner); + else { + struct vcc_state *state = seq->private; + struct atm_vcc *vcc = atm_sk(state->sk); + + pvc_info(seq, vcc); + } + return 0; +} + +static const struct seq_operations pvc_seq_ops = { + .start = vcc_seq_start, + .next = vcc_seq_next, + .stop = vcc_seq_stop, + .show = pvc_seq_show, +}; + +static int vcc_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_printf(seq, sizeof(void *) == 4 ? 
"%-8s%s" : "%-16s%s", + "Address ", "Itf VPI VCI Fam Flags Reply " + "Send buffer Recv buffer [refcnt]\n"); + } else { + struct vcc_state *state = seq->private; + struct atm_vcc *vcc = atm_sk(state->sk); + + vcc_info(seq, vcc); + } + return 0; +} + +static const struct seq_operations vcc_seq_ops = { + .start = vcc_seq_start, + .next = vcc_seq_next, + .stop = vcc_seq_stop, + .show = vcc_seq_show, +}; + +static int svc_seq_show(struct seq_file *seq, void *v) +{ + static const char atm_svc_banner[] = + "Itf VPI VCI State Remote\n"; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, atm_svc_banner); + else { + struct vcc_state *state = seq->private; + struct atm_vcc *vcc = atm_sk(state->sk); + + svc_info(seq, vcc); + } + return 0; +} + +static const struct seq_operations svc_seq_ops = { + .start = vcc_seq_start, + .next = vcc_seq_next, + .stop = vcc_seq_stop, + .show = svc_seq_show, +}; + +static ssize_t proc_dev_atm_read(struct file *file, char __user *buf, + size_t count, loff_t *pos) +{ + struct atm_dev *dev; + unsigned long page; + int length; + + if (count == 0) + return 0; + page = get_zeroed_page(GFP_KERNEL); + if (!page) + return -ENOMEM; + dev = pde_data(file_inode(file)); + if (!dev->ops->proc_read) + length = -EINVAL; + else { + length = dev->ops->proc_read(dev, pos, (char *)page); + if (length > count) + length = -EINVAL; + } + if (length >= 0) { + if (copy_to_user(buf, (char *)page, length)) + length = -EFAULT; + (*pos)++; + } + free_page(page); + return length; +} + +struct proc_dir_entry *atm_proc_root; +EXPORT_SYMBOL(atm_proc_root); + + +int atm_proc_dev_register(struct atm_dev *dev) +{ + int error; + + /* No proc info */ + if (!dev->ops->proc_read) + return 0; + + error = -ENOMEM; + dev->proc_name = kasprintf(GFP_KERNEL, "%s:%d", dev->type, dev->number); + if (!dev->proc_name) + goto err_out; + + dev->proc_entry = proc_create_data(dev->proc_name, 0, atm_proc_root, + &atm_dev_proc_ops, dev); + if (!dev->proc_entry) + goto err_free_name; + return 0; + +err_free_name: + kfree(dev->proc_name); +err_out: + return error; +} + +void atm_proc_dev_deregister(struct atm_dev *dev) +{ + if (!dev->ops->proc_read) + return; + + remove_proc_entry(dev->proc_name, atm_proc_root); + kfree(dev->proc_name); +} + +int __init atm_proc_init(void) +{ + atm_proc_root = proc_net_mkdir(&init_net, "atm", init_net.proc_net); + if (!atm_proc_root) + return -ENOMEM; + proc_create_seq("devices", 0444, atm_proc_root, &atm_dev_seq_ops); + proc_create_seq_private("pvc", 0444, atm_proc_root, &pvc_seq_ops, + sizeof(struct vcc_state), (void *)(uintptr_t)PF_ATMPVC); + proc_create_seq_private("svc", 0444, atm_proc_root, &svc_seq_ops, + sizeof(struct vcc_state), (void *)(uintptr_t)PF_ATMSVC); + proc_create_seq_private("vc", 0444, atm_proc_root, &vcc_seq_ops, + sizeof(struct vcc_state), NULL); + return 0; +} + +void atm_proc_exit(void) +{ + remove_proc_subtree("atm", init_net.proc_net); +} diff --git a/net/atm/protocols.h b/net/atm/protocols.h new file mode 100644 index 000000000..18d4d008b --- /dev/null +++ b/net/atm/protocols.h @@ -0,0 +1,14 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* net/atm/protocols.h - ATM protocol handler entry points */ + +/* Written 1995-1997 by Werner Almesberger, EPFL LRC */ + + +#ifndef NET_ATM_PROTOCOLS_H +#define NET_ATM_PROTOCOLS_H + +int atm_init_aal0(struct atm_vcc *vcc); /* "raw" AAL0 */ +int atm_init_aal34(struct atm_vcc *vcc);/* "raw" AAL3/4 transport */ +int atm_init_aal5(struct atm_vcc *vcc); /* "raw" AAL5 transport */ + +#endif diff --git a/net/atm/pvc.c b/net/atm/pvc.c new file 
mode 100644 index 000000000..53e7d3f39 --- /dev/null +++ b/net/atm/pvc.c @@ -0,0 +1,163 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/pvc.c - ATM PVC sockets */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + + +#include /* struct socket, struct proto_ops */ +#include /* ATM stuff */ +#include /* ATM devices */ +#include /* error codes */ +#include /* printk */ +#include +#include +#include +#include +#include /* for sock_no_* */ + +#include "resources.h" /* devs and vccs */ +#include "common.h" /* common for PVCs and SVCs */ + + +static int pvc_shutdown(struct socket *sock, int how) +{ + return 0; +} + +static int pvc_bind(struct socket *sock, struct sockaddr *sockaddr, + int sockaddr_len) +{ + struct sock *sk = sock->sk; + struct sockaddr_atmpvc *addr; + struct atm_vcc *vcc; + int error; + + if (sockaddr_len != sizeof(struct sockaddr_atmpvc)) + return -EINVAL; + addr = (struct sockaddr_atmpvc *)sockaddr; + if (addr->sap_family != AF_ATMPVC) + return -EAFNOSUPPORT; + lock_sock(sk); + vcc = ATM_SD(sock); + if (!test_bit(ATM_VF_HASQOS, &vcc->flags)) { + error = -EBADFD; + goto out; + } + if (test_bit(ATM_VF_PARTIAL, &vcc->flags)) { + if (vcc->vpi != ATM_VPI_UNSPEC) + addr->sap_addr.vpi = vcc->vpi; + if (vcc->vci != ATM_VCI_UNSPEC) + addr->sap_addr.vci = vcc->vci; + } + error = vcc_connect(sock, addr->sap_addr.itf, addr->sap_addr.vpi, + addr->sap_addr.vci); +out: + release_sock(sk); + return error; +} + +static int pvc_connect(struct socket *sock, struct sockaddr *sockaddr, + int sockaddr_len, int flags) +{ + return pvc_bind(sock, sockaddr, sockaddr_len); +} + +static int pvc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + int error; + + lock_sock(sk); + error = vcc_setsockopt(sock, level, optname, optval, optlen); + release_sock(sk); + return error; +} + +static int pvc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int error; + + lock_sock(sk); + error = vcc_getsockopt(sock, level, optname, optval, optlen); + release_sock(sk); + return error; +} + +static int pvc_getname(struct socket *sock, struct sockaddr *sockaddr, + int peer) +{ + struct sockaddr_atmpvc *addr; + struct atm_vcc *vcc = ATM_SD(sock); + + if (!vcc->dev || !test_bit(ATM_VF_ADDR, &vcc->flags)) + return -ENOTCONN; + addr = (struct sockaddr_atmpvc *)sockaddr; + memset(addr, 0, sizeof(*addr)); + addr->sap_family = AF_ATMPVC; + addr->sap_addr.itf = vcc->dev->number; + addr->sap_addr.vpi = vcc->vpi; + addr->sap_addr.vci = vcc->vci; + return sizeof(struct sockaddr_atmpvc); +} + +static const struct proto_ops pvc_proto_ops = { + .family = PF_ATMPVC, + .owner = THIS_MODULE, + + .release = vcc_release, + .bind = pvc_bind, + .connect = pvc_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = pvc_getname, + .poll = vcc_poll, + .ioctl = vcc_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = vcc_compat_ioctl, +#endif + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = pvc_shutdown, + .setsockopt = pvc_setsockopt, + .getsockopt = pvc_getsockopt, + .sendmsg = vcc_sendmsg, + .recvmsg = vcc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + + +static int pvc_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + if (net != &init_net) + return -EAFNOSUPPORT; + + sock->ops = &pvc_proto_ops; + return vcc_create(net, sock, protocol, PF_ATMPVC, kern); +} + +static 
const struct net_proto_family pvc_family_ops = { + .family = PF_ATMPVC, + .create = pvc_create, + .owner = THIS_MODULE, +}; + + +/* + * Initialize the ATM PVC protocol family + */ + + +int __init atmpvc_init(void) +{ + return sock_register(&pvc_family_ops); +} + +void atmpvc_exit(void) +{ + sock_unregister(PF_ATMPVC); +} diff --git a/net/atm/raw.c b/net/atm/raw.c new file mode 100644 index 000000000..2b5f78a7e --- /dev/null +++ b/net/atm/raw.c @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/raw.c - Raw AAL0 and AAL5 transports */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "protocols.h" + +/* + * SKB == NULL indicates that the link is being closed + */ + +static void atm_push_raw(struct atm_vcc *vcc, struct sk_buff *skb) +{ + if (skb) { + struct sock *sk = sk_atm(vcc); + + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + } +} + +static void atm_pop_raw(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct sock *sk = sk_atm(vcc); + + pr_debug("(%d) %d -= %d\n", + vcc->vci, sk_wmem_alloc_get(sk), ATM_SKB(skb)->acct_truesize); + WARN_ON(refcount_sub_and_test(ATM_SKB(skb)->acct_truesize, &sk->sk_wmem_alloc)); + dev_kfree_skb_any(skb); + sk->sk_write_space(sk); +} + +static int atm_send_aal0(struct atm_vcc *vcc, struct sk_buff *skb) +{ + /* + * Note that if vpi/vci are _ANY or _UNSPEC the below will + * still work + */ + if (!capable(CAP_NET_ADMIN) && + (((u32 *)skb->data)[0] & (ATM_HDR_VPI_MASK | ATM_HDR_VCI_MASK)) != + ((vcc->vpi << ATM_HDR_VPI_SHIFT) | + (vcc->vci << ATM_HDR_VCI_SHIFT))) { + kfree_skb(skb); + return -EADDRNOTAVAIL; + } + if (vcc->dev->ops->send_bh) + return vcc->dev->ops->send_bh(vcc, skb); + return vcc->dev->ops->send(vcc, skb); +} + +int atm_init_aal0(struct atm_vcc *vcc) +{ + vcc->push = atm_push_raw; + vcc->pop = atm_pop_raw; + vcc->push_oam = NULL; + vcc->send = atm_send_aal0; + return 0; +} + +int atm_init_aal34(struct atm_vcc *vcc) +{ + vcc->push = atm_push_raw; + vcc->pop = atm_pop_raw; + vcc->push_oam = NULL; + if (vcc->dev->ops->send_bh) + vcc->send = vcc->dev->ops->send_bh; + else + vcc->send = vcc->dev->ops->send; + return 0; +} + +int atm_init_aal5(struct atm_vcc *vcc) +{ + vcc->push = atm_push_raw; + vcc->pop = atm_pop_raw; + vcc->push_oam = NULL; + if (vcc->dev->ops->send_bh) + vcc->send = vcc->dev->ops->send_bh; + else + vcc->send = vcc->dev->ops->send; + return 0; +} +EXPORT_SYMBOL(atm_init_aal5); diff --git a/net/atm/resources.c b/net/atm/resources.c new file mode 100644 index 000000000..995d29e7f --- /dev/null +++ b/net/atm/resources.c @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/resources.c - Statically allocated resources */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +/* Fixes + * Arnaldo Carvalho de Melo + * 2002/01 - don't free the whole struct sock on sk->destruct time, + * use the default destruct function initialized by sock_init_data */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include +#include +#include +#include /* for barrier */ +#include +#include +#include +#include +#include +#include + +#include /* for struct sock */ + +#include "common.h" +#include "resources.h" +#include "addr.h" + + +LIST_HEAD(atm_devs); +DEFINE_MUTEX(atm_dev_mutex); + +static struct atm_dev *__alloc_atm_dev(const char *type) +{ + struct atm_dev *dev; + + dev = 
kzalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) + return NULL; + dev->type = type; + dev->signal = ATM_PHY_SIG_UNKNOWN; + dev->link_rate = ATM_OC3_PCR; + spin_lock_init(&dev->lock); + INIT_LIST_HEAD(&dev->local); + INIT_LIST_HEAD(&dev->lecs); + + return dev; +} + +static struct atm_dev *__atm_dev_lookup(int number) +{ + struct atm_dev *dev; + + list_for_each_entry(dev, &atm_devs, dev_list) { + if (dev->number == number) { + atm_dev_hold(dev); + return dev; + } + } + return NULL; +} + +struct atm_dev *atm_dev_lookup(int number) +{ + struct atm_dev *dev; + + mutex_lock(&atm_dev_mutex); + dev = __atm_dev_lookup(number); + mutex_unlock(&atm_dev_mutex); + return dev; +} +EXPORT_SYMBOL(atm_dev_lookup); + +struct atm_dev *atm_dev_register(const char *type, struct device *parent, + const struct atmdev_ops *ops, int number, + unsigned long *flags) +{ + struct atm_dev *dev, *inuse; + + dev = __alloc_atm_dev(type); + if (!dev) { + pr_err("no space for dev %s\n", type); + return NULL; + } + mutex_lock(&atm_dev_mutex); + if (number != -1) { + inuse = __atm_dev_lookup(number); + if (inuse) { + atm_dev_put(inuse); + mutex_unlock(&atm_dev_mutex); + kfree(dev); + return NULL; + } + dev->number = number; + } else { + dev->number = 0; + while ((inuse = __atm_dev_lookup(dev->number))) { + atm_dev_put(inuse); + dev->number++; + } + } + + dev->ops = ops; + if (flags) + dev->flags = *flags; + else + memset(&dev->flags, 0, sizeof(dev->flags)); + memset(&dev->stats, 0, sizeof(dev->stats)); + refcount_set(&dev->refcnt, 1); + + if (atm_proc_dev_register(dev) < 0) { + pr_err("atm_proc_dev_register failed for dev %s\n", type); + goto out_fail; + } + + if (atm_register_sysfs(dev, parent) < 0) { + pr_err("atm_register_sysfs failed for dev %s\n", type); + atm_proc_dev_deregister(dev); + goto out_fail; + } + + list_add_tail(&dev->dev_list, &atm_devs); + +out: + mutex_unlock(&atm_dev_mutex); + return dev; + +out_fail: + kfree(dev); + dev = NULL; + goto out; +} +EXPORT_SYMBOL(atm_dev_register); + +void atm_dev_deregister(struct atm_dev *dev) +{ + BUG_ON(test_bit(ATM_DF_REMOVED, &dev->flags)); + set_bit(ATM_DF_REMOVED, &dev->flags); + + /* + * if we remove current device from atm_devs list, new device + * with same number can appear, such we need deregister proc, + * release async all vccs and remove them from vccs list too + */ + mutex_lock(&atm_dev_mutex); + list_del(&dev->dev_list); + mutex_unlock(&atm_dev_mutex); + + atm_dev_release_vccs(dev); + atm_unregister_sysfs(dev); + atm_proc_dev_deregister(dev); + + atm_dev_put(dev); +} +EXPORT_SYMBOL(atm_dev_deregister); + +static void copy_aal_stats(struct k_atm_aal_stats *from, + struct atm_aal_stats *to) +{ +#define __HANDLE_ITEM(i) to->i = atomic_read(&from->i) + __AAL_STAT_ITEMS +#undef __HANDLE_ITEM +} + +static void subtract_aal_stats(struct k_atm_aal_stats *from, + struct atm_aal_stats *to) +{ +#define __HANDLE_ITEM(i) atomic_sub(to->i, &from->i) + __AAL_STAT_ITEMS +#undef __HANDLE_ITEM +} + +static int fetch_stats(struct atm_dev *dev, struct atm_dev_stats __user *arg, + int zero) +{ + struct atm_dev_stats tmp; + int error = 0; + + copy_aal_stats(&dev->stats.aal0, &tmp.aal0); + copy_aal_stats(&dev->stats.aal34, &tmp.aal34); + copy_aal_stats(&dev->stats.aal5, &tmp.aal5); + if (arg) + error = copy_to_user(arg, &tmp, sizeof(tmp)); + if (zero && !error) { + subtract_aal_stats(&dev->stats.aal0, &tmp.aal0); + subtract_aal_stats(&dev->stats.aal34, &tmp.aal34); + subtract_aal_stats(&dev->stats.aal5, &tmp.aal5); + } + return error ? 
-EFAULT : 0; +} + +int atm_getnames(void __user *buf, int __user *iobuf_len) +{ + int error, len, size = 0; + struct atm_dev *dev; + struct list_head *p; + int *tmp_buf, *tmp_p; + + if (get_user(len, iobuf_len)) + return -EFAULT; + mutex_lock(&atm_dev_mutex); + list_for_each(p, &atm_devs) + size += sizeof(int); + if (size > len) { + mutex_unlock(&atm_dev_mutex); + return -E2BIG; + } + tmp_buf = kmalloc(size, GFP_ATOMIC); + if (!tmp_buf) { + mutex_unlock(&atm_dev_mutex); + return -ENOMEM; + } + tmp_p = tmp_buf; + list_for_each_entry(dev, &atm_devs, dev_list) { + *tmp_p++ = dev->number; + } + mutex_unlock(&atm_dev_mutex); + error = ((copy_to_user(buf, tmp_buf, size)) || + put_user(size, iobuf_len)) + ? -EFAULT : 0; + kfree(tmp_buf); + return error; +} + +int atm_dev_ioctl(unsigned int cmd, void __user *buf, int __user *sioc_len, + int number, int compat) +{ + int error, len, size = 0; + struct atm_dev *dev; + + if (get_user(len, sioc_len)) + return -EFAULT; + + dev = try_then_request_module(atm_dev_lookup(number), "atm-device-%d", + number); + if (!dev) + return -ENODEV; + + switch (cmd) { + case ATM_GETTYPE: + size = strlen(dev->type) + 1; + if (copy_to_user(buf, dev->type, size)) { + error = -EFAULT; + goto done; + } + break; + case ATM_GETESI: + size = ESI_LEN; + if (copy_to_user(buf, dev->esi, size)) { + error = -EFAULT; + goto done; + } + break; + case ATM_SETESI: + { + int i; + + for (i = 0; i < ESI_LEN; i++) + if (dev->esi[i]) { + error = -EEXIST; + goto done; + } + } + fallthrough; + case ATM_SETESIF: + { + unsigned char esi[ESI_LEN]; + + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + if (copy_from_user(esi, buf, ESI_LEN)) { + error = -EFAULT; + goto done; + } + memcpy(dev->esi, esi, ESI_LEN); + error = ESI_LEN; + goto done; + } + case ATM_GETSTATZ: + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + fallthrough; + case ATM_GETSTAT: + size = sizeof(struct atm_dev_stats); + error = fetch_stats(dev, buf, cmd == ATM_GETSTATZ); + if (error) + goto done; + break; + case ATM_GETCIRANGE: + size = sizeof(struct atm_cirange); + if (copy_to_user(buf, &dev->ci_range, size)) { + error = -EFAULT; + goto done; + } + break; + case ATM_GETLINKRATE: + size = sizeof(int); + if (copy_to_user(buf, &dev->link_rate, size)) { + error = -EFAULT; + goto done; + } + break; + case ATM_RSTADDR: + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + atm_reset_addr(dev, ATM_ADDR_LOCAL); + break; + case ATM_ADDADDR: + case ATM_DELADDR: + case ATM_ADDLECSADDR: + case ATM_DELLECSADDR: + { + struct sockaddr_atmsvc addr; + + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + + if (copy_from_user(&addr, buf, sizeof(addr))) { + error = -EFAULT; + goto done; + } + if (cmd == ATM_ADDADDR || cmd == ATM_ADDLECSADDR) + error = atm_add_addr(dev, &addr, + (cmd == ATM_ADDADDR ? + ATM_ADDR_LOCAL : ATM_ADDR_LECS)); + else + error = atm_del_addr(dev, &addr, + (cmd == ATM_DELADDR ? + ATM_ADDR_LOCAL : ATM_ADDR_LECS)); + goto done; + } + case ATM_GETADDR: + case ATM_GETLECSADDR: + error = atm_get_addr(dev, buf, len, + (cmd == ATM_GETADDR ? + ATM_ADDR_LOCAL : ATM_ADDR_LECS)); + if (error < 0) + goto done; + size = error; + /* may return 0, but later on size == 0 means "don't + write the length" */ + error = put_user(size, sioc_len) ? 
-EFAULT : 0; + goto done; + case ATM_SETLOOP: + if (__ATM_LM_XTRMT((int) (unsigned long) buf) && + __ATM_LM_XTLOC((int) (unsigned long) buf) > + __ATM_LM_XTRMT((int) (unsigned long) buf)) { + error = -EINVAL; + goto done; + } + fallthrough; + case ATM_SETCIRANGE: + case SONET_GETSTATZ: + case SONET_SETDIAG: + case SONET_CLRDIAG: + case SONET_SETFRAMING: + if (!capable(CAP_NET_ADMIN)) { + error = -EPERM; + goto done; + } + fallthrough; + default: + if (IS_ENABLED(CONFIG_COMPAT) && compat) { +#ifdef CONFIG_COMPAT + if (!dev->ops->compat_ioctl) { + error = -EINVAL; + goto done; + } + size = dev->ops->compat_ioctl(dev, cmd, buf); +#endif + } else { + if (!dev->ops->ioctl) { + error = -EINVAL; + goto done; + } + size = dev->ops->ioctl(dev, cmd, buf); + } + if (size < 0) { + error = (size == -ENOIOCTLCMD ? -ENOTTY : size); + goto done; + } + } + + if (size) + error = put_user(size, sioc_len) ? -EFAULT : 0; + else + error = 0; +done: + atm_dev_put(dev); + return error; +} + +#ifdef CONFIG_PROC_FS +void *atm_dev_seq_start(struct seq_file *seq, loff_t *pos) +{ + mutex_lock(&atm_dev_mutex); + return seq_list_start_head(&atm_devs, *pos); +} + +void atm_dev_seq_stop(struct seq_file *seq, void *v) +{ + mutex_unlock(&atm_dev_mutex); +} + +void *atm_dev_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_list_next(v, &atm_devs, pos); +} +#endif diff --git a/net/atm/resources.h b/net/atm/resources.h new file mode 100644 index 000000000..4a0839e92 --- /dev/null +++ b/net/atm/resources.h @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* net/atm/resources.h - ATM-related resources */ + +/* Written 1995-1998 by Werner Almesberger, EPFL LRC/ICA */ + + +#ifndef NET_ATM_RESOURCES_H +#define NET_ATM_RESOURCES_H + +#include +#include + + +extern struct list_head atm_devs; +extern struct mutex atm_dev_mutex; + +int atm_getnames(void __user *buf, int __user *iobuf_len); +int atm_dev_ioctl(unsigned int cmd, void __user *buf, int __user *sioc_len, + int number, int compat); + +#ifdef CONFIG_PROC_FS + +#include + +void *atm_dev_seq_start(struct seq_file *seq, loff_t *pos); +void atm_dev_seq_stop(struct seq_file *seq, void *v); +void *atm_dev_seq_next(struct seq_file *seq, void *v, loff_t *pos); + + +int atm_proc_dev_register(struct atm_dev *dev); +void atm_proc_dev_deregister(struct atm_dev *dev); + +#else + +static inline int atm_proc_dev_register(struct atm_dev *dev) +{ + return 0; +} + +static inline void atm_proc_dev_deregister(struct atm_dev *dev) +{ + /* nothing */ +} + +#endif /* CONFIG_PROC_FS */ + +int atm_register_sysfs(struct atm_dev *adev, struct device *parent); +void atm_unregister_sysfs(struct atm_dev *adev); +#endif diff --git a/net/atm/signaling.c b/net/atm/signaling.c new file mode 100644 index 000000000..5de06ab8e --- /dev/null +++ b/net/atm/signaling.c @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/signaling.c - ATM signaling */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include /* error codes */ +#include /* printk */ +#include +#include +#include /* jiffies and HZ */ +#include /* ATM stuff */ +#include +#include +#include +#include +#include + +#include "resources.h" +#include "signaling.h" + +struct atm_vcc *sigd = NULL; + +static void sigd_put_skb(struct sk_buff *skb) +{ + if (!sigd) { + pr_debug("atmsvc: no signaling daemon\n"); + kfree_skb(skb); + return; + } + atm_force_charge(sigd, skb->truesize); + skb_queue_tail(&sk_atm(sigd)->sk_receive_queue, skb); + 
sk_atm(sigd)->sk_data_ready(sk_atm(sigd)); +} + +static void modify_qos(struct atm_vcc *vcc, struct atmsvc_msg *msg) +{ + struct sk_buff *skb; + + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || + !test_bit(ATM_VF_READY, &vcc->flags)) + return; + msg->type = as_error; + if (!vcc->dev->ops->change_qos) + msg->reply = -EOPNOTSUPP; + else { + /* should lock VCC */ + msg->reply = vcc->dev->ops->change_qos(vcc, &msg->qos, + msg->reply); + if (!msg->reply) + msg->type = as_okay; + } + /* + * Should probably just turn around the old skb. But then, the buffer + * space accounting needs to follow the change too. Maybe later. + */ + while (!(skb = alloc_skb(sizeof(struct atmsvc_msg), GFP_KERNEL))) + schedule(); + *(struct atmsvc_msg *)skb_put(skb, sizeof(struct atmsvc_msg)) = *msg; + sigd_put_skb(skb); +} + +static int sigd_send(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct atmsvc_msg *msg; + struct atm_vcc *session_vcc; + struct sock *sk; + + msg = (struct atmsvc_msg *) skb->data; + WARN_ON(refcount_sub_and_test(skb->truesize, &sk_atm(vcc)->sk_wmem_alloc)); + vcc = *(struct atm_vcc **) &msg->vcc; + pr_debug("%d (0x%lx)\n", (int)msg->type, (unsigned long)vcc); + sk = sk_atm(vcc); + + switch (msg->type) { + case as_okay: + sk->sk_err = -msg->reply; + clear_bit(ATM_VF_WAITING, &vcc->flags); + if (!*vcc->local.sas_addr.prv && !*vcc->local.sas_addr.pub) { + vcc->local.sas_family = AF_ATMSVC; + memcpy(vcc->local.sas_addr.prv, + msg->local.sas_addr.prv, ATM_ESA_LEN); + memcpy(vcc->local.sas_addr.pub, + msg->local.sas_addr.pub, ATM_E164_LEN + 1); + } + session_vcc = vcc->session ? vcc->session : vcc; + if (session_vcc->vpi || session_vcc->vci) + break; + session_vcc->itf = msg->pvc.sap_addr.itf; + session_vcc->vpi = msg->pvc.sap_addr.vpi; + session_vcc->vci = msg->pvc.sap_addr.vci; + if (session_vcc->vpi || session_vcc->vci) + session_vcc->qos = msg->qos; + break; + case as_error: + clear_bit(ATM_VF_REGIS, &vcc->flags); + clear_bit(ATM_VF_READY, &vcc->flags); + sk->sk_err = -msg->reply; + clear_bit(ATM_VF_WAITING, &vcc->flags); + break; + case as_indicate: + vcc = *(struct atm_vcc **)&msg->listen_vcc; + sk = sk_atm(vcc); + pr_debug("as_indicate!!!\n"); + lock_sock(sk); + if (sk_acceptq_is_full(sk)) { + sigd_enq(NULL, as_reject, vcc, NULL, NULL); + dev_kfree_skb(skb); + goto as_indicate_complete; + } + sk_acceptq_added(sk); + skb_queue_tail(&sk->sk_receive_queue, skb); + pr_debug("waking sk_sleep(sk) 0x%p\n", sk_sleep(sk)); + sk->sk_state_change(sk); +as_indicate_complete: + release_sock(sk); + return 0; + case as_close: + set_bit(ATM_VF_RELEASED, &vcc->flags); + vcc_release_async(vcc, msg->reply); + goto out; + case as_modify: + modify_qos(vcc, msg); + break; + case as_addparty: + case as_dropparty: + sk->sk_err_soft = -msg->reply; + /* < 0 failure, otherwise ep_ref */ + clear_bit(ATM_VF_WAITING, &vcc->flags); + break; + default: + pr_alert("bad message type %d\n", (int)msg->type); + return -EINVAL; + } + sk->sk_state_change(sk); +out: + dev_kfree_skb(skb); + return 0; +} + +void sigd_enq2(struct atm_vcc *vcc, enum atmsvc_msg_type type, + struct atm_vcc *listen_vcc, const struct sockaddr_atmpvc *pvc, + const struct sockaddr_atmsvc *svc, const struct atm_qos *qos, + int reply) +{ + struct sk_buff *skb; + struct atmsvc_msg *msg; + static unsigned int session = 0; + + pr_debug("%d (0x%p)\n", (int)type, vcc); + while (!(skb = alloc_skb(sizeof(struct atmsvc_msg), GFP_KERNEL))) + schedule(); + msg = skb_put_zero(skb, sizeof(struct atmsvc_msg)); + msg->type = type; + *(struct atm_vcc **) &msg->vcc = vcc; + 
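        /*
         * The vcc and listen_vcc message fields carry the kernel's
         * struct atm_vcc pointer as an opaque cookie: it is stored by
         * type-punning here and read back the same way in sigd_send()
         * above, so the signaling daemon is expected only to echo it,
         * never to interpret it.
         */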
*(struct atm_vcc **) &msg->listen_vcc = listen_vcc; + msg->reply = reply; + if (qos) + msg->qos = *qos; + if (vcc) + msg->sap = vcc->sap; + if (svc) + msg->svc = *svc; + if (vcc) + msg->local = vcc->local; + if (pvc) + msg->pvc = *pvc; + if (vcc) { + if (type == as_connect && test_bit(ATM_VF_SESSION, &vcc->flags)) + msg->session = ++session; + /* every new pmp connect gets the next session number */ + } + sigd_put_skb(skb); + if (vcc) + set_bit(ATM_VF_REGIS, &vcc->flags); +} + +void sigd_enq(struct atm_vcc *vcc, enum atmsvc_msg_type type, + struct atm_vcc *listen_vcc, const struct sockaddr_atmpvc *pvc, + const struct sockaddr_atmsvc *svc) +{ + sigd_enq2(vcc, type, listen_vcc, pvc, svc, vcc ? &vcc->qos : NULL, 0); + /* other ISP applications may use "reply" */ +} + +static void purge_vcc(struct atm_vcc *vcc) +{ + if (sk_atm(vcc)->sk_family == PF_ATMSVC && + !test_bit(ATM_VF_META, &vcc->flags)) { + set_bit(ATM_VF_RELEASED, &vcc->flags); + clear_bit(ATM_VF_REGIS, &vcc->flags); + vcc_release_async(vcc, -EUNATCH); + } +} + +static void sigd_close(struct atm_vcc *vcc) +{ + struct sock *s; + int i; + + pr_debug("\n"); + sigd = NULL; + if (skb_peek(&sk_atm(vcc)->sk_receive_queue)) + pr_err("closing with requests pending\n"); + skb_queue_purge(&sk_atm(vcc)->sk_receive_queue); + + read_lock(&vcc_sklist_lock); + for (i = 0; i < VCC_HTABLE_SIZE; ++i) { + struct hlist_head *head = &vcc_hash[i]; + + sk_for_each(s, head) { + vcc = atm_sk(s); + + purge_vcc(vcc); + } + } + read_unlock(&vcc_sklist_lock); +} + +static const struct atmdev_ops sigd_dev_ops = { + .close = sigd_close, + .send = sigd_send +}; + +static struct atm_dev sigd_dev = { + .ops = &sigd_dev_ops, + .type = "sig", + .number = 999, + .lock = __SPIN_LOCK_UNLOCKED(sigd_dev.lock) +}; + +int sigd_attach(struct atm_vcc *vcc) +{ + if (sigd) + return -EADDRINUSE; + pr_debug("\n"); + sigd = vcc; + vcc->dev = &sigd_dev; + vcc_insert_socket(sk_atm(vcc)); + set_bit(ATM_VF_META, &vcc->flags); + set_bit(ATM_VF_READY, &vcc->flags); + return 0; +} diff --git a/net/atm/signaling.h b/net/atm/signaling.h new file mode 100644 index 000000000..2df8220f7 --- /dev/null +++ b/net/atm/signaling.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* net/atm/signaling.h - ATM signaling */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + + +#ifndef NET_ATM_SIGNALING_H +#define NET_ATM_SIGNALING_H + +#include +#include +#include + + +extern struct atm_vcc *sigd; /* needed in svc_release */ + + +/* + * sigd_enq is a wrapper for sigd_enq2, covering the more common cases, and + * avoiding huge lists of null values. 
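As a generic illustration of the wrapper-with-defaults pattern this comment describes (and not a reimplementation of the declarations below), a minimal standalone sketch might look like this; msg_enq2()/msg_enq() and struct qos_hint are invented names playing the roles of sigd_enq2()/sigd_enq() and the optional qos/reply arguments.

#include <stdio.h>

struct qos_hint { int pcr; };

/* Full-parameter variant: every caller-tunable argument is explicit. */
static void msg_enq2(const char *type, const struct qos_hint *qos, int reply)
{
        printf("type=%s pcr=%d reply=%d\n",
               type, qos ? qos->pcr : 0, reply);
}

/* Convenience wrapper: fills in the common defaults so ordinary callers
 * avoid a tail of NULLs and zeroes. */
static void msg_enq(const char *type)
{
        msg_enq2(type, NULL, 0);
}

int main(void)
{
        struct qos_hint hint = { .pcr = 155000 };

        msg_enq("as_connect");                  /* common case, defaults supplied */
        msg_enq2("as_modify", &hint, -1);       /* full control when needed */
        return 0;
}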
+ */ + +void sigd_enq2(struct atm_vcc *vcc,enum atmsvc_msg_type type, + struct atm_vcc *listen_vcc,const struct sockaddr_atmpvc *pvc, + const struct sockaddr_atmsvc *svc,const struct atm_qos *qos,int reply); +void sigd_enq(struct atm_vcc *vcc,enum atmsvc_msg_type type, + struct atm_vcc *listen_vcc,const struct sockaddr_atmpvc *pvc, + const struct sockaddr_atmsvc *svc); +int sigd_attach(struct atm_vcc *vcc); + +#endif diff --git a/net/atm/svc.c b/net/atm/svc.c new file mode 100644 index 000000000..4a02bcaad --- /dev/null +++ b/net/atm/svc.c @@ -0,0 +1,692 @@ +// SPDX-License-Identifier: GPL-2.0 +/* net/atm/svc.c - ATM SVC sockets */ + +/* Written 1995-2000 by Werner Almesberger, EPFL LRC/ICA */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s: " fmt, __func__ + +#include +#include /* struct socket, struct proto_ops */ +#include /* error codes */ +#include /* printk */ +#include +#include +#include +#include /* O_NONBLOCK */ +#include +#include /* ATM stuff */ +#include +#include +#include +#include +#include /* for sock_no_* */ +#include +#include + +#include "resources.h" +#include "common.h" /* common for PVCs and SVCs */ +#include "signaling.h" +#include "addr.h" + +static int svc_create(struct net *net, struct socket *sock, int protocol, + int kern); + +/* + * Note: since all this is still nicely synchronized with the signaling demon, + * there's no need to protect sleep loops with clis. If signaling is + * moved into the kernel, that would change. + */ + + +static int svc_shutdown(struct socket *sock, int how) +{ + return 0; +} + +static void svc_disconnect(struct atm_vcc *vcc) +{ + DEFINE_WAIT(wait); + struct sk_buff *skb; + struct sock *sk = sk_atm(vcc); + + pr_debug("%p\n", vcc); + if (test_bit(ATM_VF_REGIS, &vcc->flags)) { + sigd_enq(vcc, as_close, NULL, NULL, NULL); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_UNINTERRUPTIBLE); + if (test_bit(ATM_VF_RELEASED, &vcc->flags) || !sigd) + break; + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + } + /* beware - socket is still in use by atmsigd until the last + as_indicate has been answered */ + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + atm_return(vcc, skb->truesize); + pr_debug("LISTEN REL\n"); + sigd_enq2(NULL, as_reject, vcc, NULL, NULL, &vcc->qos, 0); + dev_kfree_skb(skb); + } + clear_bit(ATM_VF_REGIS, &vcc->flags); + /* ... 
may retry later */ +} + +static int svc_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct atm_vcc *vcc; + + if (sk) { + vcc = ATM_SD(sock); + pr_debug("%p\n", vcc); + clear_bit(ATM_VF_READY, &vcc->flags); + /* + * VCC pointer is used as a reference, + * so we must not free it (thereby subjecting it to re-use) + * before all pending connections are closed + */ + svc_disconnect(vcc); + vcc_release(sock); + } + return 0; +} + +static int svc_bind(struct socket *sock, struct sockaddr *sockaddr, + int sockaddr_len) +{ + DEFINE_WAIT(wait); + struct sock *sk = sock->sk; + struct sockaddr_atmsvc *addr; + struct atm_vcc *vcc; + int error; + + if (sockaddr_len != sizeof(struct sockaddr_atmsvc)) + return -EINVAL; + lock_sock(sk); + if (sock->state == SS_CONNECTED) { + error = -EISCONN; + goto out; + } + if (sock->state != SS_UNCONNECTED) { + error = -EINVAL; + goto out; + } + vcc = ATM_SD(sock); + addr = (struct sockaddr_atmsvc *) sockaddr; + if (addr->sas_family != AF_ATMSVC) { + error = -EAFNOSUPPORT; + goto out; + } + clear_bit(ATM_VF_BOUND, &vcc->flags); + /* failing rebind will kill old binding */ + /* @@@ check memory (de)allocation on rebind */ + if (!test_bit(ATM_VF_HASQOS, &vcc->flags)) { + error = -EBADFD; + goto out; + } + vcc->local = *addr; + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq(vcc, as_bind, NULL, NULL, &vcc->local); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_UNINTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &vcc->flags) || !sigd) + break; + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + clear_bit(ATM_VF_REGIS, &vcc->flags); /* doesn't count */ + if (!sigd) { + error = -EUNATCH; + goto out; + } + if (!sk->sk_err) + set_bit(ATM_VF_BOUND, &vcc->flags); + error = -sk->sk_err; +out: + release_sock(sk); + return error; +} + +static int svc_connect(struct socket *sock, struct sockaddr *sockaddr, + int sockaddr_len, int flags) +{ + DEFINE_WAIT(wait); + struct sock *sk = sock->sk; + struct sockaddr_atmsvc *addr; + struct atm_vcc *vcc = ATM_SD(sock); + int error; + + pr_debug("%p\n", vcc); + lock_sock(sk); + if (sockaddr_len != sizeof(struct sockaddr_atmsvc)) { + error = -EINVAL; + goto out; + } + + switch (sock->state) { + default: + error = -EINVAL; + goto out; + case SS_CONNECTED: + error = -EISCONN; + goto out; + case SS_CONNECTING: + if (test_bit(ATM_VF_WAITING, &vcc->flags)) { + error = -EALREADY; + goto out; + } + sock->state = SS_UNCONNECTED; + if (sk->sk_err) { + error = -sk->sk_err; + goto out; + } + break; + case SS_UNCONNECTED: + addr = (struct sockaddr_atmsvc *) sockaddr; + if (addr->sas_family != AF_ATMSVC) { + error = -EAFNOSUPPORT; + goto out; + } + if (!test_bit(ATM_VF_HASQOS, &vcc->flags)) { + error = -EBADFD; + goto out; + } + if (vcc->qos.txtp.traffic_class == ATM_ANYCLASS || + vcc->qos.rxtp.traffic_class == ATM_ANYCLASS) { + error = -EINVAL; + goto out; + } + if (!vcc->qos.txtp.traffic_class && + !vcc->qos.rxtp.traffic_class) { + error = -EINVAL; + goto out; + } + vcc->remote = *addr; + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq(vcc, as_connect, NULL, NULL, &vcc->remote); + if (flags & O_NONBLOCK) { + sock->state = SS_CONNECTING; + error = -EINPROGRESS; + goto out; + } + error = 0; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + while (test_bit(ATM_VF_WAITING, &vcc->flags) && sigd) { + schedule(); + if (!signal_pending(current)) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + continue; + } + pr_debug("*ABORT*\n"); + /* + * This is tricky: + * Kernel ---close--> Demon + * Kernel 
<--close--- Demon + * or + * Kernel ---close--> Demon + * Kernel <--error--- Demon + * or + * Kernel ---close--> Demon + * Kernel <--okay---- Demon + * Kernel <--close--- Demon + */ + sigd_enq(vcc, as_close, NULL, NULL, NULL); + while (test_bit(ATM_VF_WAITING, &vcc->flags) && sigd) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + schedule(); + } + if (!sk->sk_err) + while (!test_bit(ATM_VF_RELEASED, &vcc->flags) && + sigd) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + schedule(); + } + clear_bit(ATM_VF_REGIS, &vcc->flags); + clear_bit(ATM_VF_RELEASED, &vcc->flags); + clear_bit(ATM_VF_CLOSE, &vcc->flags); + /* we're gone now but may connect later */ + error = -EINTR; + break; + } + finish_wait(sk_sleep(sk), &wait); + if (error) + goto out; + if (!sigd) { + error = -EUNATCH; + goto out; + } + if (sk->sk_err) { + error = -sk->sk_err; + goto out; + } + } + + vcc->qos.txtp.max_pcr = SELECT_TOP_PCR(vcc->qos.txtp); + vcc->qos.txtp.pcr = 0; + vcc->qos.txtp.min_pcr = 0; + + error = vcc_connect(sock, vcc->itf, vcc->vpi, vcc->vci); + if (!error) + sock->state = SS_CONNECTED; + else + (void)svc_disconnect(vcc); +out: + release_sock(sk); + return error; +} + +static int svc_listen(struct socket *sock, int backlog) +{ + DEFINE_WAIT(wait); + struct sock *sk = sock->sk; + struct atm_vcc *vcc = ATM_SD(sock); + int error; + + pr_debug("%p\n", vcc); + lock_sock(sk); + /* let server handle listen on unbound sockets */ + if (test_bit(ATM_VF_SESSION, &vcc->flags)) { + error = -EINVAL; + goto out; + } + if (test_bit(ATM_VF_LISTEN, &vcc->flags)) { + error = -EADDRINUSE; + goto out; + } + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq(vcc, as_listen, NULL, NULL, &vcc->local); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_UNINTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &vcc->flags) || !sigd) + break; + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + if (!sigd) { + error = -EUNATCH; + goto out; + } + set_bit(ATM_VF_LISTEN, &vcc->flags); + vcc_insert_socket(sk); + sk->sk_max_ack_backlog = backlog > 0 ? 
backlog : ATM_BACKLOG_DEFAULT; + error = -sk->sk_err; +out: + release_sock(sk); + return error; +} + +static int svc_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + struct atmsvc_msg *msg; + struct atm_vcc *old_vcc = ATM_SD(sock); + struct atm_vcc *new_vcc; + int error; + + lock_sock(sk); + + error = svc_create(sock_net(sk), newsock, 0, kern); + if (error) + goto out; + + new_vcc = ATM_SD(newsock); + + pr_debug("%p -> %p\n", old_vcc, new_vcc); + while (1) { + DEFINE_WAIT(wait); + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + while (!(skb = skb_dequeue(&sk->sk_receive_queue)) && + sigd) { + if (test_bit(ATM_VF_RELEASED, &old_vcc->flags)) + break; + if (test_bit(ATM_VF_CLOSE, &old_vcc->flags)) { + error = -sk->sk_err; + break; + } + if (flags & O_NONBLOCK) { + error = -EAGAIN; + break; + } + release_sock(sk); + schedule(); + lock_sock(sk); + if (signal_pending(current)) { + error = -ERESTARTSYS; + break; + } + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + } + finish_wait(sk_sleep(sk), &wait); + if (error) + goto out; + if (!skb) { + error = -EUNATCH; + goto out; + } + msg = (struct atmsvc_msg *)skb->data; + new_vcc->qos = msg->qos; + set_bit(ATM_VF_HASQOS, &new_vcc->flags); + new_vcc->remote = msg->svc; + new_vcc->local = msg->local; + new_vcc->sap = msg->sap; + error = vcc_connect(newsock, msg->pvc.sap_addr.itf, + msg->pvc.sap_addr.vpi, + msg->pvc.sap_addr.vci); + dev_kfree_skb(skb); + sk_acceptq_removed(sk); + if (error) { + sigd_enq2(NULL, as_reject, old_vcc, NULL, NULL, + &old_vcc->qos, error); + error = error == -EAGAIN ? -EBUSY : error; + goto out; + } + /* wait should be short, so we ignore the non-blocking flag */ + set_bit(ATM_VF_WAITING, &new_vcc->flags); + sigd_enq(new_vcc, as_accept, old_vcc, NULL, NULL); + for (;;) { + prepare_to_wait(sk_sleep(sk_atm(new_vcc)), &wait, + TASK_UNINTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &new_vcc->flags) || !sigd) + break; + release_sock(sk); + schedule(); + lock_sock(sk); + } + finish_wait(sk_sleep(sk_atm(new_vcc)), &wait); + if (!sigd) { + error = -EUNATCH; + goto out; + } + if (!sk_atm(new_vcc)->sk_err) + break; + if (sk_atm(new_vcc)->sk_err != ERESTARTSYS) { + error = -sk_atm(new_vcc)->sk_err; + goto out; + } + } + newsock->state = SS_CONNECTED; +out: + release_sock(sk); + return error; +} + +static int svc_getname(struct socket *sock, struct sockaddr *sockaddr, + int peer) +{ + struct sockaddr_atmsvc *addr; + + addr = (struct sockaddr_atmsvc *) sockaddr; + memcpy(addr, peer ? 
&ATM_SD(sock)->remote : &ATM_SD(sock)->local, + sizeof(struct sockaddr_atmsvc)); + return sizeof(struct sockaddr_atmsvc); +} + +int svc_change_qos(struct atm_vcc *vcc, struct atm_qos *qos) +{ + struct sock *sk = sk_atm(vcc); + DEFINE_WAIT(wait); + + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq2(vcc, as_modify, NULL, NULL, &vcc->local, qos, 0); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_UNINTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &vcc->flags) || + test_bit(ATM_VF_RELEASED, &vcc->flags) || !sigd) { + break; + } + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + if (!sigd) + return -EUNATCH; + return -sk->sk_err; +} + +static int svc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct atm_vcc *vcc = ATM_SD(sock); + int value, error = 0; + + lock_sock(sk); + switch (optname) { + case SO_ATMSAP: + if (level != SOL_ATM || optlen != sizeof(struct atm_sap)) { + error = -EINVAL; + goto out; + } + if (copy_from_sockptr(&vcc->sap, optval, optlen)) { + error = -EFAULT; + goto out; + } + set_bit(ATM_VF_HASSAP, &vcc->flags); + break; + case SO_MULTIPOINT: + if (level != SOL_ATM || optlen != sizeof(int)) { + error = -EINVAL; + goto out; + } + if (copy_from_sockptr(&value, optval, sizeof(int))) { + error = -EFAULT; + goto out; + } + if (value == 1) + set_bit(ATM_VF_SESSION, &vcc->flags); + else if (value == 0) + clear_bit(ATM_VF_SESSION, &vcc->flags); + else + error = -EINVAL; + break; + default: + error = vcc_setsockopt(sock, level, optname, optval, optlen); + } + +out: + release_sock(sk); + return error; +} + +static int svc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int error = 0, len; + + lock_sock(sk); + if (!__SO_LEVEL_MATCH(optname, level) || optname != SO_ATMSAP) { + error = vcc_getsockopt(sock, level, optname, optval, optlen); + goto out; + } + if (get_user(len, optlen)) { + error = -EFAULT; + goto out; + } + if (len != sizeof(struct atm_sap)) { + error = -EINVAL; + goto out; + } + if (copy_to_user(optval, &ATM_SD(sock)->sap, sizeof(struct atm_sap))) { + error = -EFAULT; + goto out; + } +out: + release_sock(sk); + return error; +} + +static int svc_addparty(struct socket *sock, struct sockaddr *sockaddr, + int sockaddr_len, int flags) +{ + DEFINE_WAIT(wait); + struct sock *sk = sock->sk; + struct atm_vcc *vcc = ATM_SD(sock); + int error; + + lock_sock(sk); + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq(vcc, as_addparty, NULL, NULL, + (struct sockaddr_atmsvc *) sockaddr); + if (flags & O_NONBLOCK) { + error = -EINPROGRESS; + goto out; + } + pr_debug("added wait queue\n"); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &vcc->flags) || !sigd) + break; + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + error = -xchg(&sk->sk_err_soft, 0); +out: + release_sock(sk); + return error; +} + +static int svc_dropparty(struct socket *sock, int ep_ref) +{ + DEFINE_WAIT(wait); + struct sock *sk = sock->sk; + struct atm_vcc *vcc = ATM_SD(sock); + int error; + + lock_sock(sk); + set_bit(ATM_VF_WAITING, &vcc->flags); + sigd_enq2(vcc, as_dropparty, NULL, NULL, NULL, NULL, ep_ref); + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + if (!test_bit(ATM_VF_WAITING, &vcc->flags) || !sigd) + break; + schedule(); + } + finish_wait(sk_sleep(sk), &wait); + if (!sigd) { + error = -EUNATCH; + goto out; + } + error = 
-xchg(&sk->sk_err_soft, 0); +out: + release_sock(sk); + return error; +} + +static int svc_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + int error, ep_ref; + struct sockaddr_atmsvc sa; + struct atm_vcc *vcc = ATM_SD(sock); + + switch (cmd) { + case ATM_ADDPARTY: + if (!test_bit(ATM_VF_SESSION, &vcc->flags)) + return -EINVAL; + if (copy_from_user(&sa, (void __user *) arg, sizeof(sa))) + return -EFAULT; + error = svc_addparty(sock, (struct sockaddr *)&sa, sizeof(sa), + 0); + break; + case ATM_DROPPARTY: + if (!test_bit(ATM_VF_SESSION, &vcc->flags)) + return -EINVAL; + if (copy_from_user(&ep_ref, (void __user *) arg, sizeof(int))) + return -EFAULT; + error = svc_dropparty(sock, ep_ref); + break; + default: + error = vcc_ioctl(sock, cmd, arg); + } + + return error; +} + +#ifdef CONFIG_COMPAT +static int svc_compat_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* The definition of ATM_ADDPARTY uses the size of struct atm_iobuf. + But actually it takes a struct sockaddr_atmsvc, which doesn't need + compat handling. So all we have to do is fix up cmd... */ + if (cmd == COMPAT_ATM_ADDPARTY) + cmd = ATM_ADDPARTY; + + if (cmd == ATM_ADDPARTY || cmd == ATM_DROPPARTY) + return svc_ioctl(sock, cmd, arg); + else + return vcc_compat_ioctl(sock, cmd, arg); +} +#endif /* CONFIG_COMPAT */ + +static const struct proto_ops svc_proto_ops = { + .family = PF_ATMSVC, + .owner = THIS_MODULE, + + .release = svc_release, + .bind = svc_bind, + .connect = svc_connect, + .socketpair = sock_no_socketpair, + .accept = svc_accept, + .getname = svc_getname, + .poll = vcc_poll, + .ioctl = svc_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = svc_compat_ioctl, +#endif + .gettstamp = sock_gettstamp, + .listen = svc_listen, + .shutdown = svc_shutdown, + .setsockopt = svc_setsockopt, + .getsockopt = svc_getsockopt, + .sendmsg = vcc_sendmsg, + .recvmsg = vcc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + + +static int svc_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + int error; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + sock->ops = &svc_proto_ops; + error = vcc_create(net, sock, protocol, AF_ATMSVC, kern); + if (error) + return error; + ATM_SD(sock)->local.sas_family = AF_ATMSVC; + ATM_SD(sock)->remote.sas_family = AF_ATMSVC; + return 0; +} + +static const struct net_proto_family svc_family_ops = { + .family = PF_ATMSVC, + .create = svc_create, + .owner = THIS_MODULE, +}; + + +/* + * Initialize the ATM SVC protocol family + */ + +int __init atmsvc_init(void) +{ + return sock_register(&svc_family_ops); +} + +void atmsvc_exit(void) +{ + sock_unregister(PF_ATMSVC); +} diff --git a/net/ax25/Kconfig b/net/ax25/Kconfig new file mode 100644 index 000000000..d3a9843a0 --- /dev/null +++ b/net/ax25/Kconfig @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Amateur Radio protocols and AX.25 device configuration +# + +menuconfig HAMRADIO + depends on NET && !S390 + bool "Amateur Radio support" + help + If you want to connect your Linux box to an amateur radio, answer Y + here. You want to read + and more specifically about AX.25 on Linux + . + + Note that the answer to this question won't directly affect the + kernel: saying N will just cause the configurator to skip all + the questions about amateur radio. 
+ +comment "Packet Radio protocols" + depends on HAMRADIO + +config AX25 + tristate "Amateur Radio AX.25 Level 2 protocol" + depends on HAMRADIO + help + This is the protocol used for computer communication over amateur + radio. It is either used by itself for point-to-point links, or to + carry other protocols such as tcp/ip. To use it, you need a device + that connects your Linux box to your amateur radio. You can either + use a low speed TNC (a Terminal Node Controller acts as a kind of + modem connecting your computer's serial port to your radio's + microphone input and speaker output) supporting the KISS protocol or + one of the various SCC cards that are supported by the generic Z8530 + or the DMA SCC driver. Another option are the Baycom modem serial + and parallel port hacks or the sound card modem (supported by their + own drivers). If you say Y here, you also have to say Y to one of + those drivers. + + Information about where to get supporting software for Linux amateur + radio as well as information about how to configure an AX.25 port is + contained in the AX25-HOWTO, available from + . You might also want to + check out the file in the + kernel source. More information about digital amateur radio in + general is on the WWW at + . + + To compile this driver as a module, choose M here: the + module will be called ax25. + +config AX25_DAMA_SLAVE + bool "AX.25 DAMA Slave support" + default y + depends on AX25 + help + DAMA is a mechanism to prevent collisions when doing AX.25 + networking. A DAMA server (called "master") accepts incoming traffic + from clients (called "slaves") and redistributes it to other slaves. + If you say Y here, your Linux box will act as a DAMA slave; this is + transparent in that you don't have to do any special DAMA + configuration. Linux cannot yet act as a DAMA server. This option + only compiles DAMA slave support into the kernel. It still needs to + be enabled at runtime. For more about DAMA see + . If unsure, say Y. + +# placeholder until implemented +config AX25_DAMA_MASTER + bool 'AX.25 DAMA Master support' + depends on AX25_DAMA_SLAVE && BROKEN + help + DAMA is a mechanism to prevent collisions when doing AX.25 + networking. A DAMA server (called "master") accepts incoming traffic + from clients (called "slaves") and redistributes it to other slaves. + If you say Y here, your Linux box will act as a DAMA master; this is + transparent in that you don't have to do any special DAMA + configuration. Linux cannot yet act as a DAMA server. This option + only compiles DAMA slave support into the kernel. It still needs to + be explicitly enabled, so if unsure, say Y. + +config NETROM + tristate "Amateur Radio NET/ROM protocol" + depends on AX25 + help + NET/ROM is a network layer protocol on top of AX.25 useful for + routing. + + A comprehensive listing of all the software for Linux amateur radio + users as well as information about how to configure an AX.25 port is + contained in the Linux Ham Wiki, available from + . You also might want to check out the + file . More information about + digital amateur radio in general is on the WWW at + . + + To compile this driver as a module, choose M here: the + module will be called netrom. + +config ROSE + tristate "Amateur Radio X.25 PLP (Rose)" + depends on AX25 + help + The Packet Layer Protocol (PLP) is a way to route packets over X.25 + connections in general and amateur radio AX.25 connections in + particular, essentially an alternative to NET/ROM. 
+ + A comprehensive listing of all the software for Linux amateur radio + users as well as information about how to configure an AX.25 port is + contained in the Linux Ham Wiki, available from + . You also might want to check out the + file . More information about + digital amateur radio in general is on the WWW at + . + + To compile this driver as a module, choose M here: the + module will be called rose. + +menu "AX.25 network device drivers" + depends on HAMRADIO && AX25 + +source "drivers/net/hamradio/Kconfig" + +endmenu diff --git a/net/ax25/Makefile b/net/ax25/Makefile new file mode 100644 index 000000000..2e53affc8 --- /dev/null +++ b/net/ax25/Makefile @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux AX.25 layer. +# + +obj-$(CONFIG_AX25) += ax25.o + +ax25-y := ax25_addr.o ax25_dev.o ax25_iface.o ax25_in.o ax25_ip.o ax25_out.o \ + ax25_route.o ax25_std_in.o ax25_std_subr.o ax25_std_timer.o \ + ax25_subr.o ax25_timer.o ax25_uid.o af_ax25.o +ax25-$(CONFIG_AX25_DAMA_SLAVE) += ax25_ds_in.o ax25_ds_subr.o ax25_ds_timer.o +ax25-$(CONFIG_SYSCTL) += sysctl_net_ax25.o diff --git a/net/ax25/af_ax25.c b/net/ax25/af_ax25.c new file mode 100644 index 000000000..6b4c25a92 --- /dev/null +++ b/net/ax25/af_ax25.c @@ -0,0 +1,2083 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Darryl Miles G7LED (dlm@g7led.demon.co.uk) + * Copyright (C) Steven Whitehouse GW7RRM (stevew@acm.org) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Hans-Joachim Hetscher DD8NE (dd8ne@bnv-bamberg.de) + * Copyright (C) Hans Alblas PE1AYX (hans@esrac.ele.tue.nl) + * Copyright (C) Frederic Rible F1OAT (frible@teaser.fr) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + +HLIST_HEAD(ax25_list); +DEFINE_SPINLOCK(ax25_list_lock); + +static const struct proto_ops ax25_proto_ops; + +static void ax25_free_sock(struct sock *sk) +{ + ax25_cb_put(sk_to_ax25(sk)); +} + +/* + * Socket removal during an interrupt is now safe. + */ +static void ax25_cb_del(ax25_cb *ax25) +{ + spin_lock_bh(&ax25_list_lock); + if (!hlist_unhashed(&ax25->ax25_node)) { + hlist_del_init(&ax25->ax25_node); + ax25_cb_put(ax25); + } + spin_unlock_bh(&ax25_list_lock); +} + +/* + * Kill all bound sockets on a dropped device. 
+ */ +static void ax25_kill_by_device(struct net_device *dev) +{ + ax25_dev *ax25_dev; + ax25_cb *s; + struct sock *sk; + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + return; + ax25_dev->device_up = false; + + spin_lock_bh(&ax25_list_lock); +again: + ax25_for_each(s, &ax25_list) { + if (s->ax25_dev == ax25_dev) { + sk = s->sk; + if (!sk) { + spin_unlock_bh(&ax25_list_lock); + ax25_disconnect(s, ENETUNREACH); + s->ax25_dev = NULL; + ax25_cb_del(s); + spin_lock_bh(&ax25_list_lock); + goto again; + } + sock_hold(sk); + spin_unlock_bh(&ax25_list_lock); + lock_sock(sk); + ax25_disconnect(s, ENETUNREACH); + s->ax25_dev = NULL; + if (sk->sk_socket) { + netdev_put(ax25_dev->dev, + &ax25_dev->dev_tracker); + ax25_dev_put(ax25_dev); + } + ax25_cb_del(s); + release_sock(sk); + spin_lock_bh(&ax25_list_lock); + sock_put(sk); + /* The entry could have been deleted from the + * list meanwhile and thus the next pointer is + * no longer valid. Play it safe and restart + * the scan. Forward progress is ensured + * because we set s->ax25_dev to NULL and we + * are never passed a NULL 'dev' argument. + */ + goto again; + } + } + spin_unlock_bh(&ax25_list_lock); +} + +/* + * Handle device status changes. + */ +static int ax25_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + /* Reject non AX.25 devices */ + if (dev->type != ARPHRD_AX25) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UP: + ax25_dev_device_up(dev); + break; + case NETDEV_DOWN: + ax25_kill_by_device(dev); + ax25_rt_device_down(dev); + ax25_dev_device_down(dev); + break; + default: + break; + } + + return NOTIFY_DONE; +} + +/* + * Add a socket to the bound sockets list. + */ +void ax25_cb_add(ax25_cb *ax25) +{ + spin_lock_bh(&ax25_list_lock); + ax25_cb_hold(ax25); + hlist_add_head(&ax25->ax25_node, &ax25_list); + spin_unlock_bh(&ax25_list_lock); +} + +/* + * Find a socket that wants to accept the SABM we have just + * received. + */ +struct sock *ax25_find_listener(ax25_address *addr, int digi, + struct net_device *dev, int type) +{ + ax25_cb *s; + + spin_lock(&ax25_list_lock); + ax25_for_each(s, &ax25_list) { + if ((s->iamdigi && !digi) || (!s->iamdigi && digi)) + continue; + if (s->sk && !ax25cmp(&s->source_addr, addr) && + s->sk->sk_type == type && s->sk->sk_state == TCP_LISTEN) { + /* If device is null we match any device */ + if (s->ax25_dev == NULL || s->ax25_dev->dev == dev) { + sock_hold(s->sk); + spin_unlock(&ax25_list_lock); + return s->sk; + } + } + } + spin_unlock(&ax25_list_lock); + + return NULL; +} + +/* + * Find an AX.25 socket given both ends. + */ +struct sock *ax25_get_socket(ax25_address *my_addr, ax25_address *dest_addr, + int type) +{ + struct sock *sk = NULL; + ax25_cb *s; + + spin_lock(&ax25_list_lock); + ax25_for_each(s, &ax25_list) { + if (s->sk && !ax25cmp(&s->source_addr, my_addr) && + !ax25cmp(&s->dest_addr, dest_addr) && + s->sk->sk_type == type) { + sk = s->sk; + sock_hold(sk); + break; + } + } + + spin_unlock(&ax25_list_lock); + + return sk; +} + +/* + * Find an AX.25 control block given both ends. It will only pick up + * floating AX.25 control blocks or non Raw socket bound control blocks. 
+ */ +ax25_cb *ax25_find_cb(const ax25_address *src_addr, ax25_address *dest_addr, + ax25_digi *digi, struct net_device *dev) +{ + ax25_cb *s; + + spin_lock_bh(&ax25_list_lock); + ax25_for_each(s, &ax25_list) { + if (s->sk && s->sk->sk_type != SOCK_SEQPACKET) + continue; + if (s->ax25_dev == NULL) + continue; + if (ax25cmp(&s->source_addr, src_addr) == 0 && ax25cmp(&s->dest_addr, dest_addr) == 0 && s->ax25_dev->dev == dev) { + if (digi != NULL && digi->ndigi != 0) { + if (s->digipeat == NULL) + continue; + if (ax25digicmp(s->digipeat, digi) != 0) + continue; + } else { + if (s->digipeat != NULL && s->digipeat->ndigi != 0) + continue; + } + ax25_cb_hold(s); + spin_unlock_bh(&ax25_list_lock); + + return s; + } + } + spin_unlock_bh(&ax25_list_lock); + + return NULL; +} + +EXPORT_SYMBOL(ax25_find_cb); + +void ax25_send_to_raw(ax25_address *addr, struct sk_buff *skb, int proto) +{ + ax25_cb *s; + struct sk_buff *copy; + + spin_lock(&ax25_list_lock); + ax25_for_each(s, &ax25_list) { + if (s->sk != NULL && ax25cmp(&s->source_addr, addr) == 0 && + s->sk->sk_type == SOCK_RAW && + s->sk->sk_protocol == proto && + s->ax25_dev->dev == skb->dev && + atomic_read(&s->sk->sk_rmem_alloc) <= s->sk->sk_rcvbuf) { + if ((copy = skb_clone(skb, GFP_ATOMIC)) == NULL) + continue; + if (sock_queue_rcv_skb(s->sk, copy) != 0) + kfree_skb(copy); + } + } + spin_unlock(&ax25_list_lock); +} + +/* + * Deferred destroy. + */ +void ax25_destroy_socket(ax25_cb *); + +/* + * Handler for deferred kills. + */ +static void ax25_destroy_timer(struct timer_list *t) +{ + ax25_cb *ax25 = from_timer(ax25, t, dtimer); + struct sock *sk; + + sk=ax25->sk; + + bh_lock_sock(sk); + sock_hold(sk); + ax25_destroy_socket(ax25); + bh_unlock_sock(sk); + sock_put(sk); +} + +/* + * This is called from user mode and the timers. Thus it protects itself + * against interrupt users but doesn't worry about being called during + * work. Once it is removed from the queue no interrupt or bottom half + * will touch it and we are (fairly 8-) ) safe. + */ +void ax25_destroy_socket(ax25_cb *ax25) +{ + struct sk_buff *skb; + + ax25_cb_del(ax25); + + ax25_stop_heartbeat(ax25); + ax25_stop_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_stop_t3timer(ax25); + ax25_stop_idletimer(ax25); + + ax25_clear_queues(ax25); /* Flush the queues */ + + if (ax25->sk != NULL) { + while ((skb = skb_dequeue(&ax25->sk->sk_receive_queue)) != NULL) { + if (skb->sk != ax25->sk) { + /* A pending connection */ + ax25_cb *sax25 = sk_to_ax25(skb->sk); + + /* Queue the unaccepted socket for death */ + sock_orphan(skb->sk); + + /* 9A4GL: hack to release unaccepted sockets */ + skb->sk->sk_state = TCP_LISTEN; + + ax25_start_heartbeat(sax25); + sax25->state = AX25_STATE_0; + } + + kfree_skb(skb); + } + skb_queue_purge(&ax25->sk->sk_write_queue); + } + + if (ax25->sk != NULL) { + if (sk_has_allocations(ax25->sk)) { + /* Defer: outstanding buffers */ + timer_setup(&ax25->dtimer, ax25_destroy_timer, 0); + ax25->dtimer.expires = jiffies + 2 * HZ; + add_timer(&ax25->dtimer); + } else { + struct sock *sk=ax25->sk; + ax25->sk=NULL; + sock_put(sk); + } + } else { + ax25_cb_put(ax25); + } +} + +/* + * dl1bke 960311: set parameters for existing AX.25 connections, + * includes a KILL command to abort any connection. 
+ * VERY useful for debugging ;-) + */ +static int ax25_ctl_ioctl(const unsigned int cmd, void __user *arg) +{ + struct ax25_ctl_struct ax25_ctl; + ax25_digi digi; + ax25_dev *ax25_dev; + ax25_cb *ax25; + unsigned int k; + int ret = 0; + + if (copy_from_user(&ax25_ctl, arg, sizeof(ax25_ctl))) + return -EFAULT; + + if (ax25_ctl.digi_count > AX25_MAX_DIGIS) + return -EINVAL; + + if (ax25_ctl.arg > ULONG_MAX / HZ && ax25_ctl.cmd != AX25_KILL) + return -EINVAL; + + ax25_dev = ax25_addr_ax25dev(&ax25_ctl.port_addr); + if (!ax25_dev) + return -ENODEV; + + digi.ndigi = ax25_ctl.digi_count; + for (k = 0; k < digi.ndigi; k++) + digi.calls[k] = ax25_ctl.digi_addr[k]; + + ax25 = ax25_find_cb(&ax25_ctl.source_addr, &ax25_ctl.dest_addr, &digi, ax25_dev->dev); + if (!ax25) { + ax25_dev_put(ax25_dev); + return -ENOTCONN; + } + + switch (ax25_ctl.cmd) { + case AX25_KILL: + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); +#ifdef CONFIG_AX25_DAMA_SLAVE + if (ax25_dev->dama.slave && ax25->ax25_dev->values[AX25_VALUES_PROTOCOL] == AX25_PROTO_DAMA_SLAVE) + ax25_dama_off(ax25); +#endif + ax25_disconnect(ax25, ENETRESET); + break; + + case AX25_WINDOW: + if (ax25->modulus == AX25_MODULUS) { + if (ax25_ctl.arg < 1 || ax25_ctl.arg > 7) + goto einval_put; + } else { + if (ax25_ctl.arg < 1 || ax25_ctl.arg > 63) + goto einval_put; + } + ax25->window = ax25_ctl.arg; + break; + + case AX25_T1: + if (ax25_ctl.arg < 1 || ax25_ctl.arg > ULONG_MAX / HZ) + goto einval_put; + ax25->rtt = (ax25_ctl.arg * HZ) / 2; + ax25->t1 = ax25_ctl.arg * HZ; + break; + + case AX25_T2: + if (ax25_ctl.arg < 1 || ax25_ctl.arg > ULONG_MAX / HZ) + goto einval_put; + ax25->t2 = ax25_ctl.arg * HZ; + break; + + case AX25_N2: + if (ax25_ctl.arg < 1 || ax25_ctl.arg > 31) + goto einval_put; + ax25->n2count = 0; + ax25->n2 = ax25_ctl.arg; + break; + + case AX25_T3: + if (ax25_ctl.arg > ULONG_MAX / HZ) + goto einval_put; + ax25->t3 = ax25_ctl.arg * HZ; + break; + + case AX25_IDLE: + if (ax25_ctl.arg > ULONG_MAX / (60 * HZ)) + goto einval_put; + + ax25->idle = ax25_ctl.arg * 60 * HZ; + break; + + case AX25_PACLEN: + if (ax25_ctl.arg < 16 || ax25_ctl.arg > 65535) + goto einval_put; + ax25->paclen = ax25_ctl.arg; + break; + + default: + goto einval_put; + } + +out_put: + ax25_dev_put(ax25_dev); + ax25_cb_put(ax25); + return ret; + +einval_put: + ret = -EINVAL; + goto out_put; +} + +static void ax25_fillin_cb_from_dev(ax25_cb *ax25, ax25_dev *ax25_dev) +{ + ax25->rtt = msecs_to_jiffies(ax25_dev->values[AX25_VALUES_T1]) / 2; + ax25->t1 = msecs_to_jiffies(ax25_dev->values[AX25_VALUES_T1]); + ax25->t2 = msecs_to_jiffies(ax25_dev->values[AX25_VALUES_T2]); + ax25->t3 = msecs_to_jiffies(ax25_dev->values[AX25_VALUES_T3]); + ax25->n2 = ax25_dev->values[AX25_VALUES_N2]; + ax25->paclen = ax25_dev->values[AX25_VALUES_PACLEN]; + ax25->idle = msecs_to_jiffies(ax25_dev->values[AX25_VALUES_IDLE]); + ax25->backoff = ax25_dev->values[AX25_VALUES_BACKOFF]; + + if (ax25_dev->values[AX25_VALUES_AXDEFMODE]) { + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25_dev->values[AX25_VALUES_EWINDOW]; + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25_dev->values[AX25_VALUES_WINDOW]; + } +} + +/* + * Fill in a created AX.25 created control block with the default + * values for a particular device. 
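ax25_ctl_ioctl() above services the SIOCAX25CTLCON ioctl, which lets an administrator (CAP_NET_ADMIN) retune or kill an existing connection identified by port, source, destination and digipeater path. A rough userspace sketch of killing a connection follows; it is illustrative only and not part of this patch — the header name and the way the ax25_address fields are filled in (shifted-ASCII callsigns, e.g. via libax25) are assumptions.

#include <string.h>
#include <sys/ioctl.h>
#include <netax25/ax25.h>	/* assumed: struct ax25_ctl_struct, AX25_KILL */

/* Illustrative only.  fd is any AF_AX25 socket; the three addresses
 * must already hold shifted-ASCII callsigns.  Requires CAP_NET_ADMIN.
 */
static int ax25_kill_connection(int fd, const ax25_address *port,
				const ax25_address *src,
				const ax25_address *dst)
{
	struct ax25_ctl_struct ctl;

	memset(&ctl, 0, sizeof(ctl));
	ctl.port_addr   = *port;	/* callsign of the local port   */
	ctl.source_addr = *src;		/* our end of the connection    */
	ctl.dest_addr   = *dst;		/* remote end of the connection */
	ctl.digi_count  = 0;		/* direct path, no digipeaters  */
	ctl.cmd         = AX25_KILL;
	ctl.arg         = 0;		/* ignored for AX25_KILL        */

	return ioctl(fd, SIOCAX25CTLCON, &ctl);
}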
+ */ +void ax25_fillin_cb(ax25_cb *ax25, ax25_dev *ax25_dev) +{ + ax25->ax25_dev = ax25_dev; + + if (ax25->ax25_dev != NULL) { + ax25_fillin_cb_from_dev(ax25, ax25_dev); + return; + } + + /* + * No device, use kernel / AX.25 spec default values + */ + ax25->rtt = msecs_to_jiffies(AX25_DEF_T1) / 2; + ax25->t1 = msecs_to_jiffies(AX25_DEF_T1); + ax25->t2 = msecs_to_jiffies(AX25_DEF_T2); + ax25->t3 = msecs_to_jiffies(AX25_DEF_T3); + ax25->n2 = AX25_DEF_N2; + ax25->paclen = AX25_DEF_PACLEN; + ax25->idle = msecs_to_jiffies(AX25_DEF_IDLE); + ax25->backoff = AX25_DEF_BACKOFF; + + if (AX25_DEF_AXDEFMODE) { + ax25->modulus = AX25_EMODULUS; + ax25->window = AX25_DEF_EWINDOW; + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = AX25_DEF_WINDOW; + } +} + +/* + * Create an empty AX.25 control block. + */ +ax25_cb *ax25_create_cb(void) +{ + ax25_cb *ax25; + + if ((ax25 = kzalloc(sizeof(*ax25), GFP_ATOMIC)) == NULL) + return NULL; + + refcount_set(&ax25->refcount, 1); + + skb_queue_head_init(&ax25->write_queue); + skb_queue_head_init(&ax25->frag_queue); + skb_queue_head_init(&ax25->ack_queue); + skb_queue_head_init(&ax25->reseq_queue); + + ax25_setup_timers(ax25); + + ax25_fillin_cb(ax25, NULL); + + ax25->state = AX25_STATE_0; + + return ax25; +} + +/* + * Handling for system calls applied via the various interfaces to an + * AX25 socket object + */ + +static int ax25_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + ax25_cb *ax25; + struct net_device *dev; + char devname[IFNAMSIZ]; + unsigned int opt; + int res = 0; + + if (level != SOL_AX25) + return -ENOPROTOOPT; + + if (optlen < sizeof(unsigned int)) + return -EINVAL; + + if (copy_from_sockptr(&opt, optval, sizeof(unsigned int))) + return -EFAULT; + + lock_sock(sk); + ax25 = sk_to_ax25(sk); + + switch (optname) { + case AX25_WINDOW: + if (ax25->modulus == AX25_MODULUS) { + if (opt < 1 || opt > 7) { + res = -EINVAL; + break; + } + } else { + if (opt < 1 || opt > 63) { + res = -EINVAL; + break; + } + } + ax25->window = opt; + break; + + case AX25_T1: + if (opt < 1 || opt > UINT_MAX / HZ) { + res = -EINVAL; + break; + } + ax25->rtt = (opt * HZ) >> 1; + ax25->t1 = opt * HZ; + break; + + case AX25_T2: + if (opt < 1 || opt > UINT_MAX / HZ) { + res = -EINVAL; + break; + } + ax25->t2 = opt * HZ; + break; + + case AX25_N2: + if (opt < 1 || opt > 31) { + res = -EINVAL; + break; + } + ax25->n2 = opt; + break; + + case AX25_T3: + if (opt < 1 || opt > UINT_MAX / HZ) { + res = -EINVAL; + break; + } + ax25->t3 = opt * HZ; + break; + + case AX25_IDLE: + if (opt > UINT_MAX / (60 * HZ)) { + res = -EINVAL; + break; + } + ax25->idle = opt * 60 * HZ; + break; + + case AX25_BACKOFF: + if (opt > 2) { + res = -EINVAL; + break; + } + ax25->backoff = opt; + break; + + case AX25_EXTSEQ: + ax25->modulus = opt ? AX25_EMODULUS : AX25_MODULUS; + break; + + case AX25_PIDINCL: + ax25->pidincl = opt ? 1 : 0; + break; + + case AX25_IAMDIGI: + ax25->iamdigi = opt ? 
1 : 0; + break; + + case AX25_PACLEN: + if (opt < 16 || opt > 65535) { + res = -EINVAL; + break; + } + ax25->paclen = opt; + break; + + case SO_BINDTODEVICE: + if (optlen > IFNAMSIZ - 1) + optlen = IFNAMSIZ - 1; + + memset(devname, 0, sizeof(devname)); + + if (copy_from_sockptr(devname, optval, optlen)) { + res = -EFAULT; + break; + } + + if (sk->sk_type == SOCK_SEQPACKET && + (sock->state != SS_UNCONNECTED || + sk->sk_state == TCP_LISTEN)) { + res = -EADDRNOTAVAIL; + break; + } + + rtnl_lock(); + dev = __dev_get_by_name(&init_net, devname); + if (!dev) { + rtnl_unlock(); + res = -ENODEV; + break; + } + + ax25->ax25_dev = ax25_dev_ax25dev(dev); + if (!ax25->ax25_dev) { + rtnl_unlock(); + res = -ENODEV; + break; + } + ax25_fillin_cb(ax25, ax25->ax25_dev); + rtnl_unlock(); + break; + + default: + res = -ENOPROTOOPT; + } + release_sock(sk); + + return res; +} + +static int ax25_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + ax25_cb *ax25; + struct ax25_dev *ax25_dev; + char devname[IFNAMSIZ]; + void *valptr; + int val = 0; + int maxlen, length; + + if (level != SOL_AX25) + return -ENOPROTOOPT; + + if (get_user(maxlen, optlen)) + return -EFAULT; + + if (maxlen < 1) + return -EFAULT; + + valptr = (void *) &val; + length = min_t(unsigned int, maxlen, sizeof(int)); + + lock_sock(sk); + ax25 = sk_to_ax25(sk); + + switch (optname) { + case AX25_WINDOW: + val = ax25->window; + break; + + case AX25_T1: + val = ax25->t1 / HZ; + break; + + case AX25_T2: + val = ax25->t2 / HZ; + break; + + case AX25_N2: + val = ax25->n2; + break; + + case AX25_T3: + val = ax25->t3 / HZ; + break; + + case AX25_IDLE: + val = ax25->idle / (60 * HZ); + break; + + case AX25_BACKOFF: + val = ax25->backoff; + break; + + case AX25_EXTSEQ: + val = (ax25->modulus == AX25_EMODULUS); + break; + + case AX25_PIDINCL: + val = ax25->pidincl; + break; + + case AX25_IAMDIGI: + val = ax25->iamdigi; + break; + + case AX25_PACLEN: + val = ax25->paclen; + break; + + case SO_BINDTODEVICE: + ax25_dev = ax25->ax25_dev; + + if (ax25_dev != NULL && ax25_dev->dev != NULL) { + strscpy(devname, ax25_dev->dev->name, sizeof(devname)); + length = strlen(devname) + 1; + } else { + *devname = '\0'; + length = 1; + } + + valptr = (void *) devname; + break; + + default: + release_sock(sk); + return -ENOPROTOOPT; + } + release_sock(sk); + + if (put_user(length, optlen)) + return -EFAULT; + + return copy_to_user(optval, valptr, length) ? -EFAULT : 0; +} + +static int ax25_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int res = 0; + + lock_sock(sk); + if (sk->sk_type == SOCK_SEQPACKET && sk->sk_state != TCP_LISTEN) { + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + goto out; + } + res = -EOPNOTSUPP; + +out: + release_sock(sk); + + return res; +} + +/* + * XXX: when creating ax25_sock we should update the .obj_size setting + * below. 
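The SOL_AX25 options validated above take plain integer values in human-scale units — seconds for AX25_T1/T2/T3, minutes for AX25_IDLE, frames for AX25_WINDOW — and the kernel converts them to jiffies internally. A minimal userspace sketch, assuming the option names come from <netax25/ax25.h> (libax25) or the uapi <linux/ax25.h>:

#include <stdio.h>
#include <sys/socket.h>
#include <netax25/ax25.h>	/* assumed location of the AX25_* option names */

/* Illustrative only: tune a connected-mode AX.25 socket created with
 * socket(AF_AX25, SOCK_SEQPACKET, 0).  Values use the units expected
 * at the socket API, not jiffies.
 */
static int tune_ax25_socket(int fd)
{
	unsigned int t1 = 10;		/* T1 retry timer, seconds            */
	unsigned int window = 4;	/* outstanding I frames (mod 8: 1..7) */
	unsigned int paclen = 0;
	socklen_t len = sizeof(paclen);

	if (setsockopt(fd, SOL_AX25, AX25_T1, &t1, sizeof(t1)) < 0)
		return -1;
	if (setsockopt(fd, SOL_AX25, AX25_WINDOW, &window, sizeof(window)) < 0)
		return -1;

	/* Read back the current maximum packet length. */
	if (getsockopt(fd, SOL_AX25, AX25_PACLEN, &paclen, &len) < 0)
		return -1;

	printf("paclen = %u\n", paclen);
	return 0;
}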
+ */ +static struct proto ax25_proto = { + .name = "AX25", + .owner = THIS_MODULE, + .obj_size = sizeof(struct ax25_sock), +}; + +static int ax25_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + ax25_cb *ax25; + + if (protocol < 0 || protocol > U8_MAX) + return -EINVAL; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + switch (sock->type) { + case SOCK_DGRAM: + if (protocol == 0 || protocol == PF_AX25) + protocol = AX25_P_TEXT; + break; + + case SOCK_SEQPACKET: + switch (protocol) { + case 0: + case PF_AX25: /* For CLX */ + protocol = AX25_P_TEXT; + break; + case AX25_P_SEGMENT: +#ifdef CONFIG_INET + case AX25_P_ARP: + case AX25_P_IP: +#endif +#ifdef CONFIG_NETROM + case AX25_P_NETROM: +#endif +#ifdef CONFIG_ROSE + case AX25_P_ROSE: +#endif + return -ESOCKTNOSUPPORT; +#ifdef CONFIG_NETROM_MODULE + case AX25_P_NETROM: + if (ax25_protocol_is_registered(AX25_P_NETROM)) + return -ESOCKTNOSUPPORT; + break; +#endif +#ifdef CONFIG_ROSE_MODULE + case AX25_P_ROSE: + if (ax25_protocol_is_registered(AX25_P_ROSE)) + return -ESOCKTNOSUPPORT; + break; +#endif + default: + break; + } + break; + + case SOCK_RAW: + if (!capable(CAP_NET_RAW)) + return -EPERM; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sk = sk_alloc(net, PF_AX25, GFP_ATOMIC, &ax25_proto, kern); + if (sk == NULL) + return -ENOMEM; + + ax25 = ax25_sk(sk)->cb = ax25_create_cb(); + if (!ax25) { + sk_free(sk); + return -ENOMEM; + } + + sock_init_data(sock, sk); + + sk->sk_destruct = ax25_free_sock; + sock->ops = &ax25_proto_ops; + sk->sk_protocol = protocol; + + ax25->sk = sk; + + return 0; +} + +struct sock *ax25_make_new(struct sock *osk, struct ax25_dev *ax25_dev) +{ + struct sock *sk; + ax25_cb *ax25, *oax25; + + sk = sk_alloc(sock_net(osk), PF_AX25, GFP_ATOMIC, osk->sk_prot, 0); + if (sk == NULL) + return NULL; + + if ((ax25 = ax25_create_cb()) == NULL) { + sk_free(sk); + return NULL; + } + + switch (osk->sk_type) { + case SOCK_DGRAM: + break; + case SOCK_SEQPACKET: + break; + default: + sk_free(sk); + ax25_cb_put(ax25); + return NULL; + } + + sock_init_data(NULL, sk); + + sk->sk_type = osk->sk_type; + sk->sk_priority = osk->sk_priority; + sk->sk_protocol = osk->sk_protocol; + sk->sk_rcvbuf = osk->sk_rcvbuf; + sk->sk_sndbuf = osk->sk_sndbuf; + sk->sk_state = TCP_ESTABLISHED; + sock_copy_flags(sk, osk); + + oax25 = sk_to_ax25(osk); + + ax25->modulus = oax25->modulus; + ax25->backoff = oax25->backoff; + ax25->pidincl = oax25->pidincl; + ax25->iamdigi = oax25->iamdigi; + ax25->rtt = oax25->rtt; + ax25->t1 = oax25->t1; + ax25->t2 = oax25->t2; + ax25->t3 = oax25->t3; + ax25->n2 = oax25->n2; + ax25->idle = oax25->idle; + ax25->paclen = oax25->paclen; + ax25->window = oax25->window; + + ax25->ax25_dev = ax25_dev; + ax25->source_addr = oax25->source_addr; + + if (oax25->digipeat != NULL) { + ax25->digipeat = kmemdup(oax25->digipeat, sizeof(ax25_digi), + GFP_ATOMIC); + if (ax25->digipeat == NULL) { + sk_free(sk); + ax25_cb_put(ax25); + return NULL; + } + } + + ax25_sk(sk)->cb = ax25; + sk->sk_destruct = ax25_free_sock; + ax25->sk = sk; + + return sk; +} + +static int ax25_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + ax25_cb *ax25; + ax25_dev *ax25_dev; + + if (sk == NULL) + return 0; + + sock_hold(sk); + lock_sock(sk); + sock_orphan(sk); + ax25 = sk_to_ax25(sk); + ax25_dev = ax25->ax25_dev; + + if (sk->sk_type == SOCK_SEQPACKET) { + switch (ax25->state) { + case AX25_STATE_0: + if (!sock_flag(ax25->sk, SOCK_DEAD)) { + release_sock(sk); + ax25_disconnect(ax25, 0); 
+ lock_sock(sk); + } + ax25_destroy_socket(ax25); + break; + + case AX25_STATE_1: + case AX25_STATE_2: + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + release_sock(sk); + ax25_disconnect(ax25, 0); + lock_sock(sk); + if (!sock_flag(ax25->sk, SOCK_DESTROY)) + ax25_destroy_socket(ax25); + break; + + case AX25_STATE_3: + case AX25_STATE_4: + ax25_clear_queues(ax25); + ax25->n2count = 0; + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_send_control(ax25, + AX25_DISC, + AX25_POLLON, + AX25_COMMAND); + ax25_stop_t2timer(ax25); + ax25_stop_t3timer(ax25); + ax25_stop_idletimer(ax25); + break; +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + ax25_stop_t3timer(ax25); + ax25_stop_idletimer(ax25); + break; +#endif + } + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + ax25->state = AX25_STATE_2; + sk->sk_state = TCP_CLOSE; + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DESTROY); + break; + + default: + break; + } + } else { + sk->sk_state = TCP_CLOSE; + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + ax25_destroy_socket(ax25); + } + if (ax25_dev) { + if (!ax25_dev->device_up) { + del_timer_sync(&ax25->timer); + del_timer_sync(&ax25->t1timer); + del_timer_sync(&ax25->t2timer); + del_timer_sync(&ax25->t3timer); + del_timer_sync(&ax25->idletimer); + } + netdev_put(ax25_dev->dev, &ax25->dev_tracker); + ax25_dev_put(ax25_dev); + } + + sock->sk = NULL; + release_sock(sk); + sock_put(sk); + + return 0; +} + +/* + * We support a funny extension here so you can (as root) give any callsign + * digipeated via a local address as source. This hack is obsolete now + * that we've implemented support for SO_BINDTODEVICE. It is however small + * and trivially backward compatible. + */ +static int ax25_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + struct full_sockaddr_ax25 *addr = (struct full_sockaddr_ax25 *)uaddr; + ax25_dev *ax25_dev = NULL; + ax25_uid_assoc *user; + ax25_address call; + ax25_cb *ax25; + int err = 0; + + if (addr_len != sizeof(struct sockaddr_ax25) && + addr_len != sizeof(struct full_sockaddr_ax25)) + /* support for old structure may go away some time + * ax25_bind(): uses old (6 digipeater) socket structure. 
+ */ + if ((addr_len < sizeof(struct sockaddr_ax25) + sizeof(ax25_address) * 6) || + (addr_len > sizeof(struct full_sockaddr_ax25))) + return -EINVAL; + + if (addr->fsa_ax25.sax25_family != AF_AX25) + return -EINVAL; + + user = ax25_findbyuid(current_euid()); + if (user) { + call = user->call; + ax25_uid_put(user); + } else { + if (ax25_uid_policy && !capable(CAP_NET_ADMIN)) + return -EACCES; + + call = addr->fsa_ax25.sax25_call; + } + + lock_sock(sk); + + ax25 = sk_to_ax25(sk); + if (!sock_flag(sk, SOCK_ZAPPED)) { + err = -EINVAL; + goto out; + } + + ax25->source_addr = call; + + /* + * User already set interface with SO_BINDTODEVICE + */ + if (ax25->ax25_dev != NULL) + goto done; + + if (addr_len > sizeof(struct sockaddr_ax25) && addr->fsa_ax25.sax25_ndigis == 1) { + if (ax25cmp(&addr->fsa_digipeater[0], &null_ax25_address) != 0 && + (ax25_dev = ax25_addr_ax25dev(&addr->fsa_digipeater[0])) == NULL) { + err = -EADDRNOTAVAIL; + goto out; + } + } else { + if ((ax25_dev = ax25_addr_ax25dev(&addr->fsa_ax25.sax25_call)) == NULL) { + err = -EADDRNOTAVAIL; + goto out; + } + } + + if (ax25_dev) { + ax25_fillin_cb(ax25, ax25_dev); + netdev_hold(ax25_dev->dev, &ax25->dev_tracker, GFP_ATOMIC); + } + +done: + ax25_cb_add(ax25); + sock_reset_flag(sk, SOCK_ZAPPED); + +out: + release_sock(sk); + + return err; +} + +/* + * FIXME: nonblock behaviour looks like it may have a bug. + */ +static int __must_check ax25_connect(struct socket *sock, + struct sockaddr *uaddr, int addr_len, int flags) +{ + struct sock *sk = sock->sk; + ax25_cb *ax25 = sk_to_ax25(sk), *ax25t; + struct full_sockaddr_ax25 *fsa = (struct full_sockaddr_ax25 *)uaddr; + ax25_digi *digi = NULL; + int ct = 0, err = 0; + + /* + * some sanity checks. code further down depends on this + */ + + if (addr_len == sizeof(struct sockaddr_ax25)) + /* support for this will go away in early 2.5.x + * ax25_connect(): uses obsolete socket structure + */ + ; + else if (addr_len != sizeof(struct full_sockaddr_ax25)) + /* support for old structure may go away some time + * ax25_connect(): uses old (6 digipeater) socket structure. + */ + if ((addr_len < sizeof(struct sockaddr_ax25) + sizeof(ax25_address) * 6) || + (addr_len > sizeof(struct full_sockaddr_ax25))) + return -EINVAL; + + + if (fsa->fsa_ax25.sax25_family != AF_AX25) + return -EINVAL; + + lock_sock(sk); + + /* deal with restarts */ + if (sock->state == SS_CONNECTING) { + switch (sk->sk_state) { + case TCP_SYN_SENT: /* still trying */ + err = -EINPROGRESS; + goto out_release; + + case TCP_ESTABLISHED: /* connection established */ + sock->state = SS_CONNECTED; + goto out_release; + + case TCP_CLOSE: /* connection refused */ + sock->state = SS_UNCONNECTED; + err = -ECONNREFUSED; + goto out_release; + } + } + + if (sk->sk_state == TCP_ESTABLISHED && sk->sk_type == SOCK_SEQPACKET) { + err = -EISCONN; /* No reconnect on a seqpacket socket */ + goto out_release; + } + + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + kfree(ax25->digipeat); + ax25->digipeat = NULL; + + /* + * Handle digi-peaters to be used. + */ + if (addr_len > sizeof(struct sockaddr_ax25) && + fsa->fsa_ax25.sax25_ndigis != 0) { + /* Valid number of digipeaters ? 
*/ + if (fsa->fsa_ax25.sax25_ndigis < 1 || + fsa->fsa_ax25.sax25_ndigis > AX25_MAX_DIGIS || + addr_len < sizeof(struct sockaddr_ax25) + + sizeof(ax25_address) * fsa->fsa_ax25.sax25_ndigis) { + err = -EINVAL; + goto out_release; + } + + if ((digi = kmalloc(sizeof(ax25_digi), GFP_KERNEL)) == NULL) { + err = -ENOBUFS; + goto out_release; + } + + digi->ndigi = fsa->fsa_ax25.sax25_ndigis; + digi->lastrepeat = -1; + + while (ct < fsa->fsa_ax25.sax25_ndigis) { + if ((fsa->fsa_digipeater[ct].ax25_call[6] & + AX25_HBIT) && ax25->iamdigi) { + digi->repeated[ct] = 1; + digi->lastrepeat = ct; + } else { + digi->repeated[ct] = 0; + } + digi->calls[ct] = fsa->fsa_digipeater[ct]; + ct++; + } + } + + /* + * Must bind first - autobinding in this may or may not work. If + * the socket is already bound, check to see if the device has + * been filled in, error if it hasn't. + */ + if (sock_flag(sk, SOCK_ZAPPED)) { + /* check if we can remove this feature. It is broken. */ + printk(KERN_WARNING "ax25_connect(): %s uses autobind, please contact jreuter@yaina.de\n", + current->comm); + if ((err = ax25_rt_autobind(ax25, &fsa->fsa_ax25.sax25_call)) < 0) { + kfree(digi); + goto out_release; + } + + ax25_fillin_cb(ax25, ax25->ax25_dev); + ax25_cb_add(ax25); + } else { + if (ax25->ax25_dev == NULL) { + kfree(digi); + err = -EHOSTUNREACH; + goto out_release; + } + } + + if (sk->sk_type == SOCK_SEQPACKET && + (ax25t=ax25_find_cb(&ax25->source_addr, &fsa->fsa_ax25.sax25_call, digi, + ax25->ax25_dev->dev))) { + kfree(digi); + err = -EADDRINUSE; /* Already such a connection */ + ax25_cb_put(ax25t); + goto out_release; + } + + ax25->dest_addr = fsa->fsa_ax25.sax25_call; + ax25->digipeat = digi; + + /* First the easy one */ + if (sk->sk_type != SOCK_SEQPACKET) { + sock->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + goto out_release; + } + + /* Move to connecting socket, ax.25 lapb WAIT_UA.. 
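The digipeater handling above copies up to AX25_MAX_DIGIS callsigns out of the trailing fsa_digipeater[] array of struct full_sockaddr_ax25, with sax25_ndigis giving the count. The sketch below shows how userspace might build such an address for connect() on a SOCK_SEQPACKET socket; fill_call() is a hypothetical helper standing in for a shifted-ASCII callsign converter such as libax25's ax25_aton_entry().

#include <string.h>
#include <sys/socket.h>
#include <netax25/ax25.h>	/* assumed: struct full_sockaddr_ax25 */

/* Hypothetical helper: converts "CALL-N" into shifted-ASCII form. */
void fill_call(ax25_address *a, const char *callsign);

/* Illustrative only: connect to "dest" through a single digipeater. */
static int connect_via_digi(int fd, const char *dest, const char *digi)
{
	struct full_sockaddr_ax25 fsa;

	memset(&fsa, 0, sizeof(fsa));
	fsa.fsa_ax25.sax25_family = AF_AX25;
	fill_call(&fsa.fsa_ax25.sax25_call, dest);	/* remote station */
	fsa.fsa_ax25.sax25_ndigis = 1;			/* one digipeater */
	fill_call(&fsa.fsa_digipeater[0], digi);

	return connect(fd, (struct sockaddr *)&fsa, sizeof(fsa));
}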
*/ + sock->state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_establish_data_link(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + if (ax25->ax25_dev->dama.slave) + ax25_ds_establish_data_link(ax25); + else + ax25_std_establish_data_link(ax25); + break; +#endif + } + + ax25->state = AX25_STATE_1; + + ax25_start_heartbeat(ax25); + + /* Now the loop */ + if (sk->sk_state != TCP_ESTABLISHED && (flags & O_NONBLOCK)) { + err = -EINPROGRESS; + goto out_release; + } + + if (sk->sk_state == TCP_SYN_SENT) { + DEFINE_WAIT(wait); + + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + if (sk->sk_state != TCP_SYN_SENT) + break; + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + + if (err) + goto out_release; + } + + if (sk->sk_state != TCP_ESTABLISHED) { + /* Not in ABM, not in WAIT_UA -> failed */ + sock->state = SS_UNCONNECTED; + err = sock_error(sk); /* Always set at this point */ + goto out_release; + } + + sock->state = SS_CONNECTED; + + err = 0; +out_release: + release_sock(sk); + + return err; +} + +static int ax25_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sk_buff *skb; + struct sock *newsk; + DEFINE_WAIT(wait); + struct sock *sk; + int err = 0; + + if (sock->state != SS_UNCONNECTED) + return -EINVAL; + + if ((sk = sock->sk) == NULL) + return -EINVAL; + + lock_sock(sk); + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EOPNOTSUPP; + goto out; + } + + if (sk->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto out; + } + + /* + * The read queue this time is holding sockets ready to use + * hooked into the SABM we saved + */ + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + skb = skb_dequeue(&sk->sk_receive_queue); + if (skb) + break; + + if (flags & O_NONBLOCK) { + err = -EWOULDBLOCK; + break; + } + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + + if (err) + goto out; + + newsk = skb->sk; + sock_graft(newsk, newsock); + + /* Now attach up the new socket */ + kfree_skb(skb); + sk_acceptq_removed(sk); + newsock->state = SS_CONNECTED; + +out: + release_sock(sk); + + return err; +} + +static int ax25_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct full_sockaddr_ax25 *fsa = (struct full_sockaddr_ax25 *)uaddr; + struct sock *sk = sock->sk; + unsigned char ndigi, i; + ax25_cb *ax25; + int err = 0; + + memset(fsa, 0, sizeof(*fsa)); + lock_sock(sk); + ax25 = sk_to_ax25(sk); + + if (peer != 0) { + if (sk->sk_state != TCP_ESTABLISHED) { + err = -ENOTCONN; + goto out; + } + + fsa->fsa_ax25.sax25_family = AF_AX25; + fsa->fsa_ax25.sax25_call = ax25->dest_addr; + + if (ax25->digipeat != NULL) { + ndigi = ax25->digipeat->ndigi; + fsa->fsa_ax25.sax25_ndigis = ndigi; + for (i = 0; i < ndigi; i++) + fsa->fsa_digipeater[i] = + ax25->digipeat->calls[i]; + } + } else { + fsa->fsa_ax25.sax25_family = AF_AX25; + fsa->fsa_ax25.sax25_call = ax25->source_addr; + fsa->fsa_ax25.sax25_ndigis = 1; + if (ax25->ax25_dev != NULL) { + memcpy(&fsa->fsa_digipeater[0], + ax25->ax25_dev->dev->dev_addr, AX25_ADDR_LEN); 
+ } else { + fsa->fsa_digipeater[0] = null_ax25_address; + } + } + err = sizeof (struct full_sockaddr_ax25); + +out: + release_sock(sk); + + return err; +} + +static int ax25_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_ax25 *, usax, msg->msg_name); + struct sock *sk = sock->sk; + struct sockaddr_ax25 sax; + struct sk_buff *skb; + ax25_digi dtmp, *dp; + ax25_cb *ax25; + size_t size; + int lv, err, addr_len = msg->msg_namelen; + + if (msg->msg_flags & ~(MSG_DONTWAIT|MSG_EOR|MSG_CMSG_COMPAT)) + return -EINVAL; + + lock_sock(sk); + ax25 = sk_to_ax25(sk); + + if (sock_flag(sk, SOCK_ZAPPED)) { + err = -EADDRNOTAVAIL; + goto out; + } + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + send_sig(SIGPIPE, current, 0); + err = -EPIPE; + goto out; + } + + if (ax25->ax25_dev == NULL) { + err = -ENETUNREACH; + goto out; + } + + if (len > ax25->ax25_dev->dev->mtu) { + err = -EMSGSIZE; + goto out; + } + + if (usax != NULL) { + if (usax->sax25_family != AF_AX25) { + err = -EINVAL; + goto out; + } + + if (addr_len == sizeof(struct sockaddr_ax25)) + /* ax25_sendmsg(): uses obsolete socket structure */ + ; + else if (addr_len != sizeof(struct full_sockaddr_ax25)) + /* support for old structure may go away some time + * ax25_sendmsg(): uses old (6 digipeater) + * socket structure. + */ + if ((addr_len < sizeof(struct sockaddr_ax25) + sizeof(ax25_address) * 6) || + (addr_len > sizeof(struct full_sockaddr_ax25))) { + err = -EINVAL; + goto out; + } + + + if (addr_len > sizeof(struct sockaddr_ax25) && usax->sax25_ndigis != 0) { + int ct = 0; + struct full_sockaddr_ax25 *fsa = (struct full_sockaddr_ax25 *)usax; + + /* Valid number of digipeaters ? */ + if (usax->sax25_ndigis < 1 || + usax->sax25_ndigis > AX25_MAX_DIGIS || + addr_len < sizeof(struct sockaddr_ax25) + + sizeof(ax25_address) * usax->sax25_ndigis) { + err = -EINVAL; + goto out; + } + + dtmp.ndigi = usax->sax25_ndigis; + + while (ct < usax->sax25_ndigis) { + dtmp.repeated[ct] = 0; + dtmp.calls[ct] = fsa->fsa_digipeater[ct]; + ct++; + } + + dtmp.lastrepeat = 0; + } + + sax = *usax; + if (sk->sk_type == SOCK_SEQPACKET && + ax25cmp(&ax25->dest_addr, &sax.sax25_call)) { + err = -EISCONN; + goto out; + } + if (usax->sax25_ndigis == 0) + dp = NULL; + else + dp = &dtmp; + } else { + /* + * FIXME: 1003.1g - if the socket is like this because + * it has become closed (not started closed) and is VC + * we ought to SIGPIPE, EPIPE + */ + if (sk->sk_state != TCP_ESTABLISHED) { + err = -ENOTCONN; + goto out; + } + sax.sax25_family = AF_AX25; + sax.sax25_call = ax25->dest_addr; + dp = ax25->digipeat; + } + + /* Build a packet */ + /* Assume the worst case */ + size = len + ax25->ax25_dev->dev->hard_header_len; + + skb = sock_alloc_send_skb(sk, size, msg->msg_flags&MSG_DONTWAIT, &err); + if (skb == NULL) + goto out; + + skb_reserve(skb, size - len); + + /* User data follows immediately after the AX.25 data */ + if (memcpy_from_msg(skb_put(skb, len), msg, len)) { + err = -EFAULT; + kfree_skb(skb); + goto out; + } + + skb_reset_network_header(skb); + + /* Add the PID if one is not supplied by the user in the skb */ + if (!ax25->pidincl) + *(u8 *)skb_push(skb, 1) = sk->sk_protocol; + + if (sk->sk_type == SOCK_SEQPACKET) { + /* Connected mode sockets go via the LAPB machine */ + if (sk->sk_state != TCP_ESTABLISHED) { + kfree_skb(skb); + err = -ENOTCONN; + goto out; + } + + /* Shove it onto the queue and kick */ + ax25_output(ax25, ax25->paclen, skb); + + err = len; + goto out; + } + + skb_push(skb, 1 + ax25_addr_size(dp)); + + 
/* Building AX.25 Header */ + + /* Build an AX.25 header */ + lv = ax25_addr_build(skb->data, &ax25->source_addr, &sax.sax25_call, + dp, AX25_COMMAND, AX25_MODULUS); + + skb_set_transport_header(skb, lv); + + *skb_transport_header(skb) = AX25_UI; + + /* Datagram frames go straight out of the door as UI */ + ax25_queue_xmit(skb, ax25->ax25_dev->dev); + + err = len; + +out: + release_sock(sk); + + return err; +} + +static int ax25_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb, *last; + struct sk_buff_head *sk_queue; + int copied; + int err = 0; + int off = 0; + long timeo; + + lock_sock(sk); + /* + * This works for seqpacket too. The receiver has ordered the + * queue for us! We do one quick check first though + */ + if (sk->sk_type == SOCK_SEQPACKET && sk->sk_state != TCP_ESTABLISHED) { + err = -ENOTCONN; + goto out; + } + + /* We need support for non-blocking reads. */ + sk_queue = &sk->sk_receive_queue; + skb = __skb_try_recv_datagram(sk, sk_queue, flags, &off, &err, &last); + /* If no packet is available, release_sock(sk) and try again. */ + if (!skb) { + if (err != -EAGAIN) + goto out; + release_sock(sk); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + while (timeo && !__skb_wait_for_more_packets(sk, sk_queue, &err, + &timeo, last)) { + skb = __skb_try_recv_datagram(sk, sk_queue, flags, &off, + &err, &last); + if (skb) + break; + + if (err != -EAGAIN) + goto done; + } + if (!skb) + goto done; + lock_sock(sk); + } + + if (!sk_to_ax25(sk)->pidincl) + skb_pull(skb, 1); /* Remove PID */ + + skb_reset_transport_header(skb); + copied = skb->len; + + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + skb_copy_datagram_msg(skb, 0, msg, copied); + + if (msg->msg_name) { + ax25_digi digi; + ax25_address src; + const unsigned char *mac = skb_mac_header(skb); + DECLARE_SOCKADDR(struct sockaddr_ax25 *, sax, msg->msg_name); + + memset(sax, 0, sizeof(struct full_sockaddr_ax25)); + ax25_addr_parse(mac + 1, skb->data - mac - 1, &src, NULL, + &digi, NULL, NULL); + sax->sax25_family = AF_AX25; + /* We set this correctly, even though we may not let the + application know the digi calls further down (because it + did NOT ask to know them). This could get political... 
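On SOCK_DGRAM sockets, ax25_sendmsg() above pushes the protocol id (unless AX25_PIDINCL is set) and an address field, then transmits the result as a UI frame, while ax25_recvmsg() strips them again on the way up. A minimal sketch of the matching userspace call, reusing the hypothetical fill_call() helper from the connect() sketch and assuming the same headers:

#include <string.h>
#include <sys/socket.h>
#include <netax25/ax25.h>

void fill_call(ax25_address *a, const char *callsign);	/* hypothetical, as above */

/* Illustrative only: send one UI datagram on a socket created with
 * socket(AF_AX25, SOCK_DGRAM, AX25_P_TEXT) and bound to a local
 * callsign; the kernel adds the PID and the address field shown above.
 */
static int send_ui(int fd, const char *dest, const char *text)
{
	struct full_sockaddr_ax25 to;

	memset(&to, 0, sizeof(to));
	to.fsa_ax25.sax25_family = AF_AX25;
	fill_call(&to.fsa_ax25.sax25_call, dest);

	return sendto(fd, text, strlen(text), 0,
		      (struct sockaddr *)&to, sizeof(to));
}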
**/ + sax->sax25_ndigis = digi.ndigi; + sax->sax25_call = src; + + if (sax->sax25_ndigis != 0) { + int ct; + struct full_sockaddr_ax25 *fsa = (struct full_sockaddr_ax25 *)sax; + + for (ct = 0; ct < digi.ndigi; ct++) + fsa->fsa_digipeater[ct] = digi.calls[ct]; + } + msg->msg_namelen = sizeof(struct full_sockaddr_ax25); + } + + skb_free_datagram(sk, skb); + err = copied; + +out: + release_sock(sk); + +done: + return err; +} + +static int ax25_shutdown(struct socket *sk, int how) +{ + /* FIXME - generate DM and RNR states */ + return -EOPNOTSUPP; +} + +static int ax25_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + void __user *argp = (void __user *)arg; + int res = 0; + + lock_sock(sk); + switch (cmd) { + case TIOCOUTQ: { + long amount; + + amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (amount < 0) + amount = 0; + res = put_user(amount, (int __user *)argp); + break; + } + + case TIOCINQ: { + struct sk_buff *skb; + long amount = 0L; + /* These two are safe on a single CPU system as only user tasks fiddle here */ + if ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) + amount = skb->len; + res = put_user(amount, (int __user *) argp); + break; + } + + case SIOCAX25ADDUID: /* Add a uid to the uid/call map table */ + case SIOCAX25DELUID: /* Delete a uid from the uid/call map table */ + case SIOCAX25GETUID: { + struct sockaddr_ax25 sax25; + if (copy_from_user(&sax25, argp, sizeof(sax25))) { + res = -EFAULT; + break; + } + res = ax25_uid_ioctl(cmd, &sax25); + break; + } + + case SIOCAX25NOUID: { /* Set the default policy (default/bar) */ + long amount; + if (!capable(CAP_NET_ADMIN)) { + res = -EPERM; + break; + } + if (get_user(amount, (long __user *)argp)) { + res = -EFAULT; + break; + } + if (amount < 0 || amount > AX25_NOUID_BLOCK) { + res = -EINVAL; + break; + } + ax25_uid_policy = amount; + res = 0; + break; + } + + case SIOCADDRT: + case SIOCDELRT: + case SIOCAX25OPTRT: + if (!capable(CAP_NET_ADMIN)) { + res = -EPERM; + break; + } + res = ax25_rt_ioctl(cmd, argp); + break; + + case SIOCAX25CTLCON: + if (!capable(CAP_NET_ADMIN)) { + res = -EPERM; + break; + } + res = ax25_ctl_ioctl(cmd, argp); + break; + + case SIOCAX25GETINFO: + case SIOCAX25GETINFOOLD: { + ax25_cb *ax25 = sk_to_ax25(sk); + struct ax25_info_struct ax25_info; + + ax25_info.t1 = ax25->t1 / HZ; + ax25_info.t2 = ax25->t2 / HZ; + ax25_info.t3 = ax25->t3 / HZ; + ax25_info.idle = ax25->idle / (60 * HZ); + ax25_info.n2 = ax25->n2; + ax25_info.t1timer = ax25_display_timer(&ax25->t1timer) / HZ; + ax25_info.t2timer = ax25_display_timer(&ax25->t2timer) / HZ; + ax25_info.t3timer = ax25_display_timer(&ax25->t3timer) / HZ; + ax25_info.idletimer = ax25_display_timer(&ax25->idletimer) / (60 * HZ); + ax25_info.n2count = ax25->n2count; + ax25_info.state = ax25->state; + ax25_info.rcv_q = sk_rmem_alloc_get(sk); + ax25_info.snd_q = sk_wmem_alloc_get(sk); + ax25_info.vs = ax25->vs; + ax25_info.vr = ax25->vr; + ax25_info.va = ax25->va; + ax25_info.vs_max = ax25->vs; /* reserved */ + ax25_info.paclen = ax25->paclen; + ax25_info.window = ax25->window; + + /* old structure? 
*/ + if (cmd == SIOCAX25GETINFOOLD) { + static int warned = 0; + if (!warned) { + printk(KERN_INFO "%s uses old SIOCAX25GETINFO\n", + current->comm); + warned=1; + } + + if (copy_to_user(argp, &ax25_info, sizeof(struct ax25_info_struct_deprecated))) { + res = -EFAULT; + break; + } + } else { + if (copy_to_user(argp, &ax25_info, sizeof(struct ax25_info_struct))) { + res = -EINVAL; + break; + } + } + res = 0; + break; + } + + case SIOCAX25ADDFWD: + case SIOCAX25DELFWD: { + struct ax25_fwd_struct ax25_fwd; + if (!capable(CAP_NET_ADMIN)) { + res = -EPERM; + break; + } + if (copy_from_user(&ax25_fwd, argp, sizeof(ax25_fwd))) { + res = -EFAULT; + break; + } + res = ax25_fwd_ioctl(cmd, &ax25_fwd); + break; + } + + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + res = -EINVAL; + break; + + default: + res = -ENOIOCTLCMD; + break; + } + release_sock(sk); + + return res; +} + +#ifdef CONFIG_PROC_FS + +static void *ax25_info_start(struct seq_file *seq, loff_t *pos) + __acquires(ax25_list_lock) +{ + spin_lock_bh(&ax25_list_lock); + return seq_hlist_start(&ax25_list, *pos); +} + +static void *ax25_info_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &ax25_list, pos); +} + +static void ax25_info_stop(struct seq_file *seq, void *v) + __releases(ax25_list_lock) +{ + spin_unlock_bh(&ax25_list_lock); +} + +static int ax25_info_show(struct seq_file *seq, void *v) +{ + ax25_cb *ax25 = hlist_entry(v, struct ax25_cb, ax25_node); + char buf[11]; + int k; + + + /* + * New format: + * magic dev src_addr dest_addr,digi1,digi2,.. st vs vr va t1 t1 t2 t2 t3 t3 idle idle n2 n2 rtt window paclen Snd-Q Rcv-Q inode + */ + + seq_printf(seq, "%p %s %s%s ", + ax25, + ax25->ax25_dev == NULL? "???" : ax25->ax25_dev->dev->name, + ax2asc(buf, &ax25->source_addr), + ax25->iamdigi? "*":""); + seq_printf(seq, "%s", ax2asc(buf, &ax25->dest_addr)); + + for (k=0; (ax25->digipeat != NULL) && (k < ax25->digipeat->ndigi); k++) { + seq_printf(seq, ",%s%s", + ax2asc(buf, &ax25->digipeat->calls[k]), + ax25->digipeat->repeated[k]? 
"*":""); + } + + seq_printf(seq, " %d %d %d %d %lu %lu %lu %lu %lu %lu %lu %lu %d %d %lu %d %d", + ax25->state, + ax25->vs, ax25->vr, ax25->va, + ax25_display_timer(&ax25->t1timer) / HZ, ax25->t1 / HZ, + ax25_display_timer(&ax25->t2timer) / HZ, ax25->t2 / HZ, + ax25_display_timer(&ax25->t3timer) / HZ, ax25->t3 / HZ, + ax25_display_timer(&ax25->idletimer) / (60 * HZ), + ax25->idle / (60 * HZ), + ax25->n2count, ax25->n2, + ax25->rtt / HZ, + ax25->window, + ax25->paclen); + + if (ax25->sk != NULL) { + seq_printf(seq, " %d %d %lu\n", + sk_wmem_alloc_get(ax25->sk), + sk_rmem_alloc_get(ax25->sk), + sock_i_ino(ax25->sk)); + } else { + seq_puts(seq, " * * *\n"); + } + return 0; +} + +static const struct seq_operations ax25_info_seqops = { + .start = ax25_info_start, + .next = ax25_info_next, + .stop = ax25_info_stop, + .show = ax25_info_show, +}; +#endif + +static const struct net_proto_family ax25_family_ops = { + .family = PF_AX25, + .create = ax25_create, + .owner = THIS_MODULE, +}; + +static const struct proto_ops ax25_proto_ops = { + .family = PF_AX25, + .owner = THIS_MODULE, + .release = ax25_release, + .bind = ax25_bind, + .connect = ax25_connect, + .socketpair = sock_no_socketpair, + .accept = ax25_accept, + .getname = ax25_getname, + .poll = datagram_poll, + .ioctl = ax25_ioctl, + .gettstamp = sock_gettstamp, + .listen = ax25_listen, + .shutdown = ax25_shutdown, + .setsockopt = ax25_setsockopt, + .getsockopt = ax25_getsockopt, + .sendmsg = ax25_sendmsg, + .recvmsg = ax25_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +/* + * Called by socket.c on kernel start up + */ +static struct packet_type ax25_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_AX25), + .func = ax25_kiss_rcv, +}; + +static struct notifier_block ax25_dev_notifier = { + .notifier_call = ax25_device_event, +}; + +static int __init ax25_init(void) +{ + int rc = proto_register(&ax25_proto, 0); + + if (rc != 0) + goto out; + + sock_register(&ax25_family_ops); + dev_add_pack(&ax25_packet_type); + register_netdevice_notifier(&ax25_dev_notifier); + + proc_create_seq("ax25_route", 0444, init_net.proc_net, &ax25_rt_seqops); + proc_create_seq("ax25", 0444, init_net.proc_net, &ax25_info_seqops); + proc_create_seq("ax25_calls", 0444, init_net.proc_net, + &ax25_uid_seqops); +out: + return rc; +} +module_init(ax25_init); + + +MODULE_AUTHOR("Jonathan Naylor G4KLX "); +MODULE_DESCRIPTION("The amateur radio AX.25 link layer protocol"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_AX25); + +static void __exit ax25_exit(void) +{ + remove_proc_entry("ax25_route", init_net.proc_net); + remove_proc_entry("ax25", init_net.proc_net); + remove_proc_entry("ax25_calls", init_net.proc_net); + + unregister_netdevice_notifier(&ax25_dev_notifier); + + dev_remove_pack(&ax25_packet_type); + + sock_unregister(PF_AX25); + proto_unregister(&ax25_proto); + + ax25_rt_free(); + ax25_uid_free(); + ax25_dev_free(); +} +module_exit(ax25_exit); diff --git a/net/ax25/ax25_addr.c b/net/ax25/ax25_addr.c new file mode 100644 index 000000000..f68865a4d --- /dev/null +++ b/net/ax25/ax25_addr.c @@ -0,0 +1,303 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * The default broadcast address of an interface is QST-0; the default address + * is LINUX-1. 
The null address is defined as a callsign of all spaces with + * an SSID of zero. + */ + +const ax25_address ax25_bcast = + {{'Q' << 1, 'S' << 1, 'T' << 1, ' ' << 1, ' ' << 1, ' ' << 1, 0 << 1}}; +const ax25_address ax25_defaddr = + {{'L' << 1, 'I' << 1, 'N' << 1, 'U' << 1, 'X' << 1, ' ' << 1, 1 << 1}}; +const ax25_address null_ax25_address = + {{' ' << 1, ' ' << 1, ' ' << 1, ' ' << 1, ' ' << 1, ' ' << 1, 0 << 1}}; + +EXPORT_SYMBOL_GPL(ax25_bcast); +EXPORT_SYMBOL_GPL(ax25_defaddr); +EXPORT_SYMBOL(null_ax25_address); + +/* + * ax25 -> ascii conversion + */ +char *ax2asc(char *buf, const ax25_address *a) +{ + char c, *s; + int n; + + for (n = 0, s = buf; n < 6; n++) { + c = (a->ax25_call[n] >> 1) & 0x7F; + + if (c != ' ') *s++ = c; + } + + *s++ = '-'; + + if ((n = ((a->ax25_call[6] >> 1) & 0x0F)) > 9) { + *s++ = '1'; + n -= 10; + } + + *s++ = n + '0'; + *s++ = '\0'; + + if (*buf == '\0' || *buf == '-') + return "*"; + + return buf; + +} + +EXPORT_SYMBOL(ax2asc); + +/* + * ascii -> ax25 conversion + */ +void asc2ax(ax25_address *addr, const char *callsign) +{ + const char *s; + int n; + + for (s = callsign, n = 0; n < 6; n++) { + if (*s != '\0' && *s != '-') + addr->ax25_call[n] = *s++; + else + addr->ax25_call[n] = ' '; + addr->ax25_call[n] <<= 1; + addr->ax25_call[n] &= 0xFE; + } + + if (*s++ == '\0') { + addr->ax25_call[6] = 0x00; + return; + } + + addr->ax25_call[6] = *s++ - '0'; + + if (*s != '\0') { + addr->ax25_call[6] *= 10; + addr->ax25_call[6] += *s++ - '0'; + } + + addr->ax25_call[6] <<= 1; + addr->ax25_call[6] &= 0x1E; +} + +EXPORT_SYMBOL(asc2ax); + +/* + * Compare two ax.25 addresses + */ +int ax25cmp(const ax25_address *a, const ax25_address *b) +{ + int ct = 0; + + while (ct < 6) { + if ((a->ax25_call[ct] & 0xFE) != (b->ax25_call[ct] & 0xFE)) /* Clean off repeater bits */ + return 1; + ct++; + } + + if ((a->ax25_call[ct] & 0x1E) == (b->ax25_call[ct] & 0x1E)) /* SSID without control bit */ + return 0; + + return 2; /* Partial match */ +} + +EXPORT_SYMBOL(ax25cmp); + +/* + * Compare two AX.25 digipeater paths. 
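ax2asc() and asc2ax() above handle the AX.25 address format: each of the six callsign characters is stored as its ASCII code shifted left one bit, and the SSID sits in bits 1-4 of the seventh byte. A standalone illustration of the decode step (ordinary userspace C, not kernel code), using the ax25_bcast value defined above:

#include <stdio.h>

/* Decodes the shifted-ASCII callsign format: characters are
 * ASCII << 1, the SSID lives in bits 1-4 of the seventh byte.
 */
int main(void)
{
	/* "QST-0", the broadcast address ax25_bcast defined above. */
	unsigned char call[7] = {
		'Q' << 1, 'S' << 1, 'T' << 1,
		' ' << 1, ' ' << 1, ' ' << 1, 0 << 1
	};
	int i;

	for (i = 0; i < 6; i++) {
		char c = (call[i] >> 1) & 0x7F;
		if (c != ' ')
			putchar(c);
	}
	printf("-%d\n", (call[6] >> 1) & 0x0F);	/* prints "QST-0" */
	return 0;
}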
+ */ +int ax25digicmp(const ax25_digi *digi1, const ax25_digi *digi2) +{ + int i; + + if (digi1->ndigi != digi2->ndigi) + return 1; + + if (digi1->lastrepeat != digi2->lastrepeat) + return 1; + + for (i = 0; i < digi1->ndigi; i++) + if (ax25cmp(&digi1->calls[i], &digi2->calls[i]) != 0) + return 1; + + return 0; +} + +/* + * Given an AX.25 address pull of to, from, digi list, command/response and the start of data + * + */ +const unsigned char *ax25_addr_parse(const unsigned char *buf, int len, + ax25_address *src, ax25_address *dest, ax25_digi *digi, int *flags, + int *dama) +{ + int d = 0; + + if (len < 14) return NULL; + + if (flags != NULL) { + *flags = 0; + + if (buf[6] & AX25_CBIT) + *flags = AX25_COMMAND; + if (buf[13] & AX25_CBIT) + *flags = AX25_RESPONSE; + } + + if (dama != NULL) + *dama = ~buf[13] & AX25_DAMA_FLAG; + + /* Copy to, from */ + if (dest != NULL) + memcpy(dest, buf + 0, AX25_ADDR_LEN); + if (src != NULL) + memcpy(src, buf + 7, AX25_ADDR_LEN); + + buf += 2 * AX25_ADDR_LEN; + len -= 2 * AX25_ADDR_LEN; + + digi->lastrepeat = -1; + digi->ndigi = 0; + + while (!(buf[-1] & AX25_EBIT)) { + if (d >= AX25_MAX_DIGIS) + return NULL; + if (len < AX25_ADDR_LEN) + return NULL; + + memcpy(&digi->calls[d], buf, AX25_ADDR_LEN); + digi->ndigi = d + 1; + + if (buf[6] & AX25_HBIT) { + digi->repeated[d] = 1; + digi->lastrepeat = d; + } else { + digi->repeated[d] = 0; + } + + buf += AX25_ADDR_LEN; + len -= AX25_ADDR_LEN; + d++; + } + + return buf; +} + +/* + * Assemble an AX.25 header from the bits + */ +int ax25_addr_build(unsigned char *buf, const ax25_address *src, + const ax25_address *dest, const ax25_digi *d, int flag, int modulus) +{ + int len = 0; + int ct = 0; + + memcpy(buf, dest, AX25_ADDR_LEN); + buf[6] &= ~(AX25_EBIT | AX25_CBIT); + buf[6] |= AX25_SSSID_SPARE; + + if (flag == AX25_COMMAND) buf[6] |= AX25_CBIT; + + buf += AX25_ADDR_LEN; + len += AX25_ADDR_LEN; + + memcpy(buf, src, AX25_ADDR_LEN); + buf[6] &= ~(AX25_EBIT | AX25_CBIT); + buf[6] &= ~AX25_SSSID_SPARE; + + if (modulus == AX25_MODULUS) + buf[6] |= AX25_SSSID_SPARE; + else + buf[6] |= AX25_ESSID_SPARE; + + if (flag == AX25_RESPONSE) buf[6] |= AX25_CBIT; + + /* + * Fast path the normal digiless path + */ + if (d == NULL || d->ndigi == 0) { + buf[6] |= AX25_EBIT; + return 2 * AX25_ADDR_LEN; + } + + buf += AX25_ADDR_LEN; + len += AX25_ADDR_LEN; + + while (ct < d->ndigi) { + memcpy(buf, &d->calls[ct], AX25_ADDR_LEN); + + if (d->repeated[ct]) + buf[6] |= AX25_HBIT; + else + buf[6] &= ~AX25_HBIT; + + buf[6] &= ~AX25_EBIT; + buf[6] |= AX25_SSSID_SPARE; + + buf += AX25_ADDR_LEN; + len += AX25_ADDR_LEN; + ct++; + } + + buf[-1] |= AX25_EBIT; + + return len; +} + +int ax25_addr_size(const ax25_digi *dp) +{ + if (dp == NULL) + return 2 * AX25_ADDR_LEN; + + return AX25_ADDR_LEN * (2 + dp->ndigi); +} + +/* + * Reverse Digipeat List. 
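ax25_addr_build() and ax25_addr_size() above lay out the on-air address field as 7 bytes per callsign — destination first, then source, then each digipeater — with the E (end-of-address) bit set only in the final byte. A trivial standalone illustration of the resulting sizes:

#include <stdio.h>

#define ADDR_LEN 7	/* mirrors AX25_ADDR_LEN: 6 callsign bytes + SSID byte */

/* Same arithmetic as ax25_addr_size(): destination + source are
 * always present, plus one 7-byte entry per digipeater.
 */
static int addr_size(int ndigi)
{
	return ADDR_LEN * (2 + ndigi);
}

int main(void)
{
	printf("direct path    : %d bytes\n", addr_size(0));	/* 14 */
	printf("two digipeaters: %d bytes\n", addr_size(2));	/* 28 */
	return 0;
}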
May not pass both parameters as same struct + */ +void ax25_digi_invert(const ax25_digi *in, ax25_digi *out) +{ + int ct; + + out->ndigi = in->ndigi; + out->lastrepeat = in->ndigi - in->lastrepeat - 2; + + /* Invert the digipeaters */ + for (ct = 0; ct < in->ndigi; ct++) { + out->calls[ct] = in->calls[in->ndigi - ct - 1]; + + if (ct <= out->lastrepeat) { + out->calls[ct].ax25_call[6] |= AX25_HBIT; + out->repeated[ct] = 1; + } else { + out->calls[ct].ax25_call[6] &= ~AX25_HBIT; + out->repeated[ct] = 0; + } + } +} diff --git a/net/ax25/ax25_dev.c b/net/ax25/ax25_dev.c new file mode 100644 index 000000000..c5462486d --- /dev/null +++ b/net/ax25/ax25_dev.c @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +ax25_dev *ax25_dev_list; +DEFINE_SPINLOCK(ax25_dev_lock); + +ax25_dev *ax25_addr_ax25dev(ax25_address *addr) +{ + ax25_dev *ax25_dev, *res = NULL; + + spin_lock_bh(&ax25_dev_lock); + for (ax25_dev = ax25_dev_list; ax25_dev != NULL; ax25_dev = ax25_dev->next) + if (ax25cmp(addr, (const ax25_address *)ax25_dev->dev->dev_addr) == 0) { + res = ax25_dev; + ax25_dev_hold(ax25_dev); + } + spin_unlock_bh(&ax25_dev_lock); + + return res; +} + +/* + * This is called when an interface is brought up. These are + * reasonable defaults. + */ +void ax25_dev_device_up(struct net_device *dev) +{ + ax25_dev *ax25_dev; + + ax25_dev = kzalloc(sizeof(*ax25_dev), GFP_KERNEL); + if (!ax25_dev) { + printk(KERN_ERR "AX.25: ax25_dev_device_up - out of memory\n"); + return; + } + + refcount_set(&ax25_dev->refcount, 1); + dev->ax25_ptr = ax25_dev; + ax25_dev->dev = dev; + netdev_hold(dev, &ax25_dev->dev_tracker, GFP_KERNEL); + ax25_dev->forward = NULL; + ax25_dev->device_up = true; + + ax25_dev->values[AX25_VALUES_IPDEFMODE] = AX25_DEF_IPDEFMODE; + ax25_dev->values[AX25_VALUES_AXDEFMODE] = AX25_DEF_AXDEFMODE; + ax25_dev->values[AX25_VALUES_BACKOFF] = AX25_DEF_BACKOFF; + ax25_dev->values[AX25_VALUES_CONMODE] = AX25_DEF_CONMODE; + ax25_dev->values[AX25_VALUES_WINDOW] = AX25_DEF_WINDOW; + ax25_dev->values[AX25_VALUES_EWINDOW] = AX25_DEF_EWINDOW; + ax25_dev->values[AX25_VALUES_T1] = AX25_DEF_T1; + ax25_dev->values[AX25_VALUES_T2] = AX25_DEF_T2; + ax25_dev->values[AX25_VALUES_T3] = AX25_DEF_T3; + ax25_dev->values[AX25_VALUES_IDLE] = AX25_DEF_IDLE; + ax25_dev->values[AX25_VALUES_N2] = AX25_DEF_N2; + ax25_dev->values[AX25_VALUES_PACLEN] = AX25_DEF_PACLEN; + ax25_dev->values[AX25_VALUES_PROTOCOL] = AX25_DEF_PROTOCOL; + ax25_dev->values[AX25_VALUES_DS_TIMEOUT]= AX25_DEF_DS_TIMEOUT; + +#if defined(CONFIG_AX25_DAMA_SLAVE) || defined(CONFIG_AX25_DAMA_MASTER) + ax25_ds_setup_timer(ax25_dev); +#endif + + spin_lock_bh(&ax25_dev_lock); + ax25_dev->next = ax25_dev_list; + ax25_dev_list = ax25_dev; + spin_unlock_bh(&ax25_dev_lock); + ax25_dev_hold(ax25_dev); + + ax25_register_dev_sysctl(ax25_dev); +} + +void ax25_dev_device_down(struct net_device *dev) +{ + ax25_dev *s, *ax25_dev; + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + return; + + ax25_unregister_dev_sysctl(ax25_dev); + + spin_lock_bh(&ax25_dev_lock); + +#ifdef CONFIG_AX25_DAMA_SLAVE + ax25_ds_del_timer(ax25_dev); +#endif + + /* + * Remove any packet forwarding that points to this device. 
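+ * Forwarding entries are installed from user space with the
+ * SIOCAX25ADDFWD ioctl handled further down in this file; a rough
+ * sketch (error handling omitted, fd and the two callsign variables
+ * are illustrative only):
+ *
+ *	struct ax25_fwd_struct fwd;
+ *
+ *	fwd.port_from = source_port_call;	both are ax25_address
+ *	fwd.port_to   = forward_port_call;	values of local ports
+ *	ioctl(fd, SIOCAX25ADDFWD, &fwd);
+ *
+ * Whatever was set up that way has to be undone here so that no stale
+ * net_device pointer outlives the interface being taken down.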
+ */ + for (s = ax25_dev_list; s != NULL; s = s->next) + if (s->forward == dev) + s->forward = NULL; + + if ((s = ax25_dev_list) == ax25_dev) { + ax25_dev_list = s->next; + goto unlock_put; + } + + while (s != NULL && s->next != NULL) { + if (s->next == ax25_dev) { + s->next = ax25_dev->next; + goto unlock_put; + } + + s = s->next; + } + spin_unlock_bh(&ax25_dev_lock); + dev->ax25_ptr = NULL; + ax25_dev_put(ax25_dev); + return; + +unlock_put: + spin_unlock_bh(&ax25_dev_lock); + ax25_dev_put(ax25_dev); + dev->ax25_ptr = NULL; + netdev_put(dev, &ax25_dev->dev_tracker); + ax25_dev_put(ax25_dev); +} + +int ax25_fwd_ioctl(unsigned int cmd, struct ax25_fwd_struct *fwd) +{ + ax25_dev *ax25_dev, *fwd_dev; + + if ((ax25_dev = ax25_addr_ax25dev(&fwd->port_from)) == NULL) + return -EINVAL; + + switch (cmd) { + case SIOCAX25ADDFWD: + fwd_dev = ax25_addr_ax25dev(&fwd->port_to); + if (!fwd_dev) { + ax25_dev_put(ax25_dev); + return -EINVAL; + } + if (ax25_dev->forward) { + ax25_dev_put(fwd_dev); + ax25_dev_put(ax25_dev); + return -EINVAL; + } + ax25_dev->forward = fwd_dev->dev; + ax25_dev_put(fwd_dev); + ax25_dev_put(ax25_dev); + break; + + case SIOCAX25DELFWD: + if (!ax25_dev->forward) { + ax25_dev_put(ax25_dev); + return -EINVAL; + } + ax25_dev->forward = NULL; + ax25_dev_put(ax25_dev); + break; + + default: + ax25_dev_put(ax25_dev); + return -EINVAL; + } + + return 0; +} + +struct net_device *ax25_fwd_dev(struct net_device *dev) +{ + ax25_dev *ax25_dev; + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + return dev; + + if (ax25_dev->forward == NULL) + return dev; + + return ax25_dev->forward; +} + +/* + * Free all memory associated with device structures. + */ +void __exit ax25_dev_free(void) +{ + ax25_dev *s, *ax25_dev; + + spin_lock_bh(&ax25_dev_lock); + ax25_dev = ax25_dev_list; + while (ax25_dev != NULL) { + s = ax25_dev; + netdev_put(ax25_dev->dev, &ax25_dev->dev_tracker); + ax25_dev = ax25_dev->next; + kfree(s); + } + ax25_dev_list = NULL; + spin_unlock_bh(&ax25_dev_lock); +} diff --git a/net/ax25/ax25_ds_in.c b/net/ax25/ax25_ds_in.c new file mode 100644 index 000000000..c62f8fb06 --- /dev/null +++ b/net/ax25/ax25_ds_in.c @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * State machine for state 1, Awaiting Connection State. + * The handling of the timer(s) is in file ax25_ds_timer.c. + * Handling of state 0 and connection release is in ax25.c. 
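+ *
+ * In outline (matching the switch below): an incoming SABM or SABME is
+ * answered with UA, a DISC with DM, and a UA from the peer completes
+ * the set-up: the timers are restarted, V(S)/V(R)/V(A) are zeroed and
+ * the control block moves to state 3.  A DM with the P/F bit set
+ * refuses the connection with ECONNREFUSED.  A successful DAMA connect
+ * therefore looks like
+ *
+ *	us   -> peer	SABM(E) P
+ *	peer -> us	UA      F	state 1 -> state 3, followed by
+ *					the mandatory RR response NR=0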
+ */ +static int ax25_ds_state1_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int pf, int type) +{ + switch (frametype) { + case AX25_SABM: + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + break; + + case AX25_SABME: + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_EWINDOW]; + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_DM, pf, AX25_RESPONSE); + break; + + case AX25_UA: + ax25_calculate_rtt(ax25); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + ax25->vs = 0; + ax25->va = 0; + ax25->vr = 0; + ax25->state = AX25_STATE_3; + ax25->n2count = 0; + if (ax25->sk != NULL) { + bh_lock_sock(ax25->sk); + ax25->sk->sk_state = TCP_ESTABLISHED; + /* + * For WAIT_SABM connections we will produce an accept + * ready socket here + */ + if (!sock_flag(ax25->sk, SOCK_DEAD)) + ax25->sk->sk_state_change(ax25->sk); + bh_unlock_sock(ax25->sk); + } + ax25_dama_on(ax25); + + /* according to DK4EG's spec we are required to + * send a RR RESPONSE FINAL NR=0. + */ + + ax25_std_enquiry_response(ax25); + break; + + case AX25_DM: + if (pf) + ax25_disconnect(ax25, ECONNREFUSED); + break; + + default: + if (pf) + ax25_send_control(ax25, AX25_SABM, AX25_POLLON, AX25_COMMAND); + break; + } + + return 0; +} + +/* + * State machine for state 2, Awaiting Release State. + * The handling of the timer(s) is in file ax25_ds_timer.c + * Handling of state 0 and connection release is in ax25.c. + */ +static int ax25_ds_state2_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int pf, int type) +{ + switch (frametype) { + case AX25_SABM: + case AX25_SABME: + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + ax25_dama_off(ax25); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_dama_off(ax25); + ax25_disconnect(ax25, 0); + break; + + case AX25_DM: + case AX25_UA: + if (pf) { + ax25_dama_off(ax25); + ax25_disconnect(ax25, 0); + } + break; + + case AX25_I: + case AX25_REJ: + case AX25_RNR: + case AX25_RR: + if (pf) { + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + ax25_dama_off(ax25); + } + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 3, Connected State. + * The handling of the timer(s) is in file ax25_timer.c + * Handling of state 0 and connection release is in ax25.c. 
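+ *
+ * A short worked example of the sequence number handling done below,
+ * assuming modulo-8 operation and V(R) = 3: an I frame with N(S) = 3
+ * is in sequence, is passed to ax25_rx_iframe() and advances V(R) to
+ * (3 + 1) % 8 = 4; an I frame with N(S) = 5 only sets AX25_COND_REJECT
+ * and provokes an enquiry response, because in DAMA mode an RR/RNR
+ * carrying the current N(R) is enough to make the master retransmit
+ * the missing frames (no explicit REJ is sent, see ax25_ds_subr.c).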
+ */ +static int ax25_ds_state3_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int ns, int nr, int pf, int type) +{ + int queued = 0; + + switch (frametype) { + case AX25_SABM: + case AX25_SABME: + if (frametype == AX25_SABM) { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + } else { + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_EWINDOW]; + } + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + ax25->condition = 0x00; + ax25->vs = 0; + ax25->va = 0; + ax25->vr = 0; + ax25_requeue_frames(ax25); + ax25_dama_on(ax25); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_dama_off(ax25); + ax25_disconnect(ax25, 0); + break; + + case AX25_DM: + ax25_dama_off(ax25); + ax25_disconnect(ax25, ECONNRESET); + break; + + case AX25_RR: + case AX25_RNR: + if (frametype == AX25_RR) + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + else + ax25->condition |= AX25_COND_PEER_RX_BUSY; + + if (ax25_validate_nr(ax25, nr)) { + if (ax25_check_iframes_acked(ax25, nr)) + ax25->n2count=0; + if (type == AX25_COMMAND && pf) + ax25_ds_enquiry_response(ax25); + } else { + ax25_ds_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_REJ: + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + + if (ax25_validate_nr(ax25, nr)) { + if (ax25->va != nr) + ax25->n2count=0; + + ax25_frames_acked(ax25, nr); + ax25_calculate_rtt(ax25); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + ax25_requeue_frames(ax25); + + if (type == AX25_COMMAND && pf) + ax25_ds_enquiry_response(ax25); + } else { + ax25_ds_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_I: + if (!ax25_validate_nr(ax25, nr)) { + ax25_ds_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + break; + } + if (ax25->condition & AX25_COND_PEER_RX_BUSY) { + ax25_frames_acked(ax25, nr); + ax25->n2count = 0; + } else { + if (ax25_check_iframes_acked(ax25, nr)) + ax25->n2count = 0; + } + if (ax25->condition & AX25_COND_OWN_RX_BUSY) { + if (pf) ax25_ds_enquiry_response(ax25); + break; + } + if (ns == ax25->vr) { + ax25->vr = (ax25->vr + 1) % ax25->modulus; + queued = ax25_rx_iframe(ax25, skb); + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25->vr = ns; /* ax25->vr - 1 */ + ax25->condition &= ~AX25_COND_REJECT; + if (pf) { + ax25_ds_enquiry_response(ax25); + } else { + if (!(ax25->condition & AX25_COND_ACK_PENDING)) { + ax25->condition |= AX25_COND_ACK_PENDING; + ax25_start_t2timer(ax25); + } + } + } else { + if (ax25->condition & AX25_COND_REJECT) { + if (pf) ax25_ds_enquiry_response(ax25); + } else { + ax25->condition |= AX25_COND_REJECT; + ax25_ds_enquiry_response(ax25); + ax25->condition &= ~AX25_COND_ACK_PENDING; + } + } + break; + + case AX25_FRMR: + case AX25_ILLEGAL: + ax25_ds_establish_data_link(ax25); + ax25->state = AX25_STATE_1; + break; + + default: + break; + } + + return queued; +} + +/* + * Higher level upcall for a LAPB frame + */ +int ax25_ds_frame_in(ax25_cb *ax25, struct sk_buff *skb, int type) +{ + int queued = 0, frametype, ns, nr, pf; + + frametype = ax25_decode(ax25, skb, &ns, &nr, &pf); + + switch (ax25->state) { + case AX25_STATE_1: + queued = ax25_ds_state1_machine(ax25, skb, frametype, pf, type); + break; + case AX25_STATE_2: + queued = ax25_ds_state2_machine(ax25, skb, frametype, pf, type); + break; + case AX25_STATE_3: + queued = ax25_ds_state3_machine(ax25, skb, 
frametype, ns, nr, pf, type); + break; + } + + return queued; +} diff --git a/net/ax25/ax25_ds_subr.c b/net/ax25/ax25_ds_subr.c new file mode 100644 index 000000000..f00e27df3 --- /dev/null +++ b/net/ax25/ax25_ds_subr.c @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void ax25_ds_nr_error_recovery(ax25_cb *ax25) +{ + ax25_ds_establish_data_link(ax25); +} + +/* + * dl1bke 960114: transmit I frames on DAMA poll + */ +void ax25_ds_enquiry_response(ax25_cb *ax25) +{ + ax25_cb *ax25o; + + /* Please note that neither DK4EG's nor DG2FEF's + * DAMA spec mention the following behaviour as seen + * with TheFirmware: + * + * DB0ACH->DL1BKE [DAMA] + * DL1BKE->DB0ACH + * DL1BKE-7->DB0PRA-6 DB0ACH + * DL1BKE->DB0ACH + * + * The Flexnet DAMA Master implementation apparently + * insists on the "proper" AX.25 behaviour: + * + * DB0ACH->DL1BKE [DAMA] + * DL1BKE->DB0ACH + * DL1BKE->DB0ACH + * DL1BKE-7->DB0PRA-6 DB0ACH + * + * Flexnet refuses to send us *any* I frame if we send + * a REJ in case AX25_COND_REJECT is set. It is superfluous in + * this mode anyway (a RR or RNR invokes the retransmission). + * Is this a Flexnet bug? + */ + + ax25_std_enquiry_response(ax25); + + if (!(ax25->condition & AX25_COND_PEER_RX_BUSY)) { + ax25_requeue_frames(ax25); + ax25_kick(ax25); + } + + if (ax25->state == AX25_STATE_1 || ax25->state == AX25_STATE_2 || skb_peek(&ax25->ack_queue) != NULL) + ax25_ds_t1_timeout(ax25); + else + ax25->n2count = 0; + + ax25_start_t3timer(ax25); + ax25_ds_set_timer(ax25->ax25_dev); + + spin_lock(&ax25_list_lock); + ax25_for_each(ax25o, &ax25_list) { + if (ax25o == ax25) + continue; + + if (ax25o->ax25_dev != ax25->ax25_dev) + continue; + + if (ax25o->state == AX25_STATE_1 || ax25o->state == AX25_STATE_2) { + ax25_ds_t1_timeout(ax25o); + continue; + } + + if (!(ax25o->condition & AX25_COND_PEER_RX_BUSY) && ax25o->state == AX25_STATE_3) { + ax25_requeue_frames(ax25o); + ax25_kick(ax25o); + } + + if (ax25o->state == AX25_STATE_1 || ax25o->state == AX25_STATE_2 || skb_peek(&ax25o->ack_queue) != NULL) + ax25_ds_t1_timeout(ax25o); + + /* do not start T3 for listening sockets (tnx DD8NE) */ + + if (ax25o->state != AX25_STATE_0) + ax25_start_t3timer(ax25o); + } + spin_unlock(&ax25_list_lock); +} + +void ax25_ds_establish_data_link(ax25_cb *ax25) +{ + ax25->condition &= AX25_COND_DAMA_MODE; + ax25->n2count = 0; + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_start_t3timer(ax25); +} + +/* + * :::FIXME::: + * This is a kludge. Not all drivers recognize kiss commands. + * We need a driver level request to switch duplex mode, that does + * either SCC changing, PI config or KISS as required. Currently + * this request isn't reliable. 
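+ *
+ * The frame sent here is just a two byte KISS control sequence,
+ * <command, parameter>, queued on the device like any other KISS
+ * frame.  The DAMA code below only ever issues
+ *
+ *	ax25_kiss_cmd(ax25_dev, 5, 1);	entering DAMA slave mode
+ *	ax25_kiss_cmd(ax25_dev, 5, 0);	leaving DAMA slave mode
+ *
+ * where command 5 is conventionally the KISS full-duplex parameter.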
+ */ +static void ax25_kiss_cmd(ax25_dev *ax25_dev, unsigned char cmd, unsigned char param) +{ + struct sk_buff *skb; + unsigned char *p; + + if (ax25_dev->dev == NULL) + return; + + if ((skb = alloc_skb(2, GFP_ATOMIC)) == NULL) + return; + + skb_reset_network_header(skb); + p = skb_put(skb, 2); + + *p++ = cmd; + *p++ = param; + + skb->protocol = ax25_type_trans(skb, ax25_dev->dev); + + dev_queue_xmit(skb); +} + +/* + * A nasty problem arises if we count the number of DAMA connections + * wrong, especially when connections on the device already existed + * and our network node (or the sysop) decides to turn on DAMA Master + * mode. We thus flag the 'real' slave connections with + * ax25->dama_slave=1 and look on every disconnect if still slave + * connections exist. + */ +static int ax25_check_dama_slave(ax25_dev *ax25_dev) +{ + ax25_cb *ax25; + int res = 0; + + spin_lock(&ax25_list_lock); + ax25_for_each(ax25, &ax25_list) + if (ax25->ax25_dev == ax25_dev && (ax25->condition & AX25_COND_DAMA_MODE) && ax25->state > AX25_STATE_1) { + res = 1; + break; + } + spin_unlock(&ax25_list_lock); + + return res; +} + +static void ax25_dev_dama_on(ax25_dev *ax25_dev) +{ + if (ax25_dev == NULL) + return; + + if (ax25_dev->dama.slave == 0) + ax25_kiss_cmd(ax25_dev, 5, 1); + + ax25_dev->dama.slave = 1; + ax25_ds_set_timer(ax25_dev); +} + +void ax25_dev_dama_off(ax25_dev *ax25_dev) +{ + if (ax25_dev == NULL) + return; + + if (ax25_dev->dama.slave && !ax25_check_dama_slave(ax25_dev)) { + ax25_kiss_cmd(ax25_dev, 5, 0); + ax25_dev->dama.slave = 0; + ax25_ds_del_timer(ax25_dev); + } +} + +void ax25_dama_on(ax25_cb *ax25) +{ + ax25_dev_dama_on(ax25->ax25_dev); + ax25->condition |= AX25_COND_DAMA_MODE; +} + +void ax25_dama_off(ax25_cb *ax25) +{ + ax25->condition &= ~AX25_COND_DAMA_MODE; + ax25_dev_dama_off(ax25->ax25_dev); +} diff --git a/net/ax25/ax25_ds_timer.c b/net/ax25/ax25_ds_timer.c new file mode 100644 index 000000000..c4f8adbf8 --- /dev/null +++ b/net/ax25/ax25_ds_timer.c @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void ax25_ds_timeout(struct timer_list *); + +/* + * Add DAMA slave timeout timer to timer list. + * Unlike the connection based timers the timeout function gets + * triggered every second. Please note that NET_AX25_DAMA_SLAVE_TIMEOUT + * (aka /proc/sys/net/ax25/{dev}/dama_slave_timeout) is still in + * 1/10th of a second. + */ + +void ax25_ds_setup_timer(ax25_dev *ax25_dev) +{ + timer_setup(&ax25_dev->dama.slave_timer, ax25_ds_timeout, 0); +} + +void ax25_ds_del_timer(ax25_dev *ax25_dev) +{ + if (ax25_dev) + del_timer(&ax25_dev->dama.slave_timer); +} + +void ax25_ds_set_timer(ax25_dev *ax25_dev) +{ + if (ax25_dev == NULL) /* paranoia */ + return; + + ax25_dev->dama.slave_timeout = + msecs_to_jiffies(ax25_dev->values[AX25_VALUES_DS_TIMEOUT]) / 10; + mod_timer(&ax25_dev->dama.slave_timer, jiffies + HZ); +} + +/* + * DAMA Slave Timeout + * Silently discard all (slave) connections in case our master forgot us... + */ + +static void ax25_ds_timeout(struct timer_list *t) +{ + ax25_dev *ax25_dev = from_timer(ax25_dev, t, dama.slave_timer); + ax25_cb *ax25; + + if (ax25_dev == NULL || !ax25_dev->dama.slave) + return; /* Yikes! 
*/ + + if (!ax25_dev->dama.slave_timeout || --ax25_dev->dama.slave_timeout) { + ax25_ds_set_timer(ax25_dev); + return; + } + + spin_lock(&ax25_list_lock); + ax25_for_each(ax25, &ax25_list) { + if (ax25->ax25_dev != ax25_dev || !(ax25->condition & AX25_COND_DAMA_MODE)) + continue; + + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + ax25_disconnect(ax25, ETIMEDOUT); + } + spin_unlock(&ax25_list_lock); + + ax25_dev_dama_off(ax25_dev); +} + +void ax25_ds_heartbeat_expiry(ax25_cb *ax25) +{ + struct sock *sk=ax25->sk; + + if (sk) + bh_lock_sock(sk); + + switch (ax25->state) { + + case AX25_STATE_0: + case AX25_STATE_2: + /* Magic here: If we listen() and a new link dies before it + is accepted() it isn't 'dead' so doesn't get removed. */ + if (!sk || sock_flag(sk, SOCK_DESTROY) || + (sk->sk_state == TCP_LISTEN && + sock_flag(sk, SOCK_DEAD))) { + if (sk) { + sock_hold(sk); + ax25_destroy_socket(ax25); + bh_unlock_sock(sk); + /* Ungrab socket and destroy it */ + sock_put(sk); + } else + ax25_destroy_socket(ax25); + return; + } + break; + + case AX25_STATE_3: + /* + * Check the state of the receive buffer. + */ + if (sk != NULL) { + if (atomic_read(&sk->sk_rmem_alloc) < + (sk->sk_rcvbuf >> 1) && + (ax25->condition & AX25_COND_OWN_RX_BUSY)) { + ax25->condition &= ~AX25_COND_OWN_RX_BUSY; + ax25->condition &= ~AX25_COND_ACK_PENDING; + break; + } + } + break; + } + + if (sk) + bh_unlock_sock(sk); + + ax25_start_heartbeat(ax25); +} + +/* dl1bke 960114: T3 works much like the IDLE timeout, but + * gets reloaded with every frame for this + * connection. + */ +void ax25_ds_t3timer_expiry(ax25_cb *ax25) +{ + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + ax25_dama_off(ax25); + ax25_disconnect(ax25, ETIMEDOUT); +} + +/* dl1bke 960228: close the connection when IDLE expires. + * unlike T3 this timer gets reloaded only on + * I frames. + */ +void ax25_ds_idletimer_expiry(ax25_cb *ax25) +{ + ax25_clear_queues(ax25); + + ax25->n2count = 0; + ax25->state = AX25_STATE_2; + + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + ax25_stop_t3timer(ax25); + + if (ax25->sk != NULL) { + bh_lock_sock(ax25->sk); + ax25->sk->sk_state = TCP_CLOSE; + ax25->sk->sk_err = 0; + ax25->sk->sk_shutdown |= SEND_SHUTDOWN; + if (!sock_flag(ax25->sk, SOCK_DEAD)) { + ax25->sk->sk_state_change(ax25->sk); + sock_set_flag(ax25->sk, SOCK_DEAD); + } + bh_unlock_sock(ax25->sk); + } +} + +/* dl1bke 960114: The DAMA protocol requires to send data and SABM/DISC + * within the poll of any connected channel. Remember + * that we are not allowed to send anything unless we + * get polled by the Master. + * + * Thus we'll have to do parts of our T1 handling in + * ax25_enquiry_response(). 
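+ *
+ * In outline: every T1 expiry below bumps n2count; while the retry
+ * budget lasts, state 1 re-sends SABM/SABME with the poll bit off
+ * (a DAMA slave must wait to be polled), while states 2 and 3 merely
+ * count.  Once n2count reaches N2 the link is torn down with
+ * ETIMEDOUT (state 2 first sends a DISC, state 3 a DM response); the
+ * one exception is state 1 running modulo-128, which falls back to
+ * modulo-8 once and restarts the count.  In every case T1 is then
+ * recalculated and restarted.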
+ */ +void ax25_ds_t1_timeout(ax25_cb *ax25) +{ + switch (ax25->state) { + case AX25_STATE_1: + if (ax25->n2count == ax25->n2) { + if (ax25->modulus == AX25_MODULUS) { + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + ax25->n2count = 0; + ax25_send_control(ax25, AX25_SABM, AX25_POLLOFF, AX25_COMMAND); + } + } else { + ax25->n2count++; + if (ax25->modulus == AX25_MODULUS) + ax25_send_control(ax25, AX25_SABM, AX25_POLLOFF, AX25_COMMAND); + else + ax25_send_control(ax25, AX25_SABME, AX25_POLLOFF, AX25_COMMAND); + } + break; + + case AX25_STATE_2: + if (ax25->n2count == ax25->n2) { + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + if (!sock_flag(ax25->sk, SOCK_DESTROY)) + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->n2count++; + } + break; + + case AX25_STATE_3: + if (ax25->n2count == ax25->n2) { + ax25_send_control(ax25, AX25_DM, AX25_POLLON, AX25_RESPONSE); + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->n2count++; + } + break; + } + + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); +} diff --git a/net/ax25/ax25_iface.c b/net/ax25/ax25_iface.c new file mode 100644 index 000000000..979bc4b82 --- /dev/null +++ b/net/ax25/ax25_iface.c @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct ax25_protocol *protocol_list; +static DEFINE_RWLOCK(protocol_list_lock); + +static HLIST_HEAD(ax25_linkfail_list); +static DEFINE_SPINLOCK(linkfail_lock); + +static struct listen_struct { + struct listen_struct *next; + ax25_address callsign; + struct net_device *dev; +} *listen_list = NULL; +static DEFINE_SPINLOCK(listen_lock); + +/* + * Do not register the internal protocols AX25_P_TEXT, AX25_P_SEGMENT, + * AX25_P_IP or AX25_P_ARP ... 
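+ *
+ * Layers above AX.25 (NET/ROM and ROSE in the tree) register their PID
+ * handler once at init time, roughly as in the sketch below; MY_PID
+ * and my_pid_rx are illustrative names only:
+ *
+ *	static int my_pid_rx(struct sk_buff *skb, ax25_cb *ax25);
+ *
+ *	static struct ax25_protocol my_proto = {
+ *		.pid	= MY_PID,
+ *		.func	= my_pid_rx,
+ *	};
+ *
+ *	ax25_register_pid(&my_proto);
+ *		...			(for the module's lifetime)
+ *	ax25_protocol_release(MY_PID);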
+ */ +void ax25_register_pid(struct ax25_protocol *ap) +{ + write_lock_bh(&protocol_list_lock); + ap->next = protocol_list; + protocol_list = ap; + write_unlock_bh(&protocol_list_lock); +} + +EXPORT_SYMBOL_GPL(ax25_register_pid); + +void ax25_protocol_release(unsigned int pid) +{ + struct ax25_protocol *protocol; + + write_lock_bh(&protocol_list_lock); + protocol = protocol_list; + if (protocol == NULL) + goto out; + + if (protocol->pid == pid) { + protocol_list = protocol->next; + goto out; + } + + while (protocol != NULL && protocol->next != NULL) { + if (protocol->next->pid == pid) { + protocol->next = protocol->next->next; + goto out; + } + + protocol = protocol->next; + } +out: + write_unlock_bh(&protocol_list_lock); +} + +EXPORT_SYMBOL(ax25_protocol_release); + +void ax25_linkfail_register(struct ax25_linkfail *lf) +{ + spin_lock_bh(&linkfail_lock); + hlist_add_head(&lf->lf_node, &ax25_linkfail_list); + spin_unlock_bh(&linkfail_lock); +} + +EXPORT_SYMBOL(ax25_linkfail_register); + +void ax25_linkfail_release(struct ax25_linkfail *lf) +{ + spin_lock_bh(&linkfail_lock); + hlist_del_init(&lf->lf_node); + spin_unlock_bh(&linkfail_lock); +} + +EXPORT_SYMBOL(ax25_linkfail_release); + +int ax25_listen_register(const ax25_address *callsign, struct net_device *dev) +{ + struct listen_struct *listen; + + if (ax25_listen_mine(callsign, dev)) + return 0; + + if ((listen = kmalloc(sizeof(*listen), GFP_ATOMIC)) == NULL) + return -ENOMEM; + + listen->callsign = *callsign; + listen->dev = dev; + + spin_lock_bh(&listen_lock); + listen->next = listen_list; + listen_list = listen; + spin_unlock_bh(&listen_lock); + + return 0; +} + +EXPORT_SYMBOL(ax25_listen_register); + +void ax25_listen_release(const ax25_address *callsign, struct net_device *dev) +{ + struct listen_struct *s, *listen; + + spin_lock_bh(&listen_lock); + listen = listen_list; + if (listen == NULL) { + spin_unlock_bh(&listen_lock); + return; + } + + if (ax25cmp(&listen->callsign, callsign) == 0 && listen->dev == dev) { + listen_list = listen->next; + spin_unlock_bh(&listen_lock); + kfree(listen); + return; + } + + while (listen != NULL && listen->next != NULL) { + if (ax25cmp(&listen->next->callsign, callsign) == 0 && listen->next->dev == dev) { + s = listen->next; + listen->next = listen->next->next; + spin_unlock_bh(&listen_lock); + kfree(s); + return; + } + + listen = listen->next; + } + spin_unlock_bh(&listen_lock); +} + +EXPORT_SYMBOL(ax25_listen_release); + +int (*ax25_protocol_function(unsigned int pid))(struct sk_buff *, ax25_cb *) +{ + int (*res)(struct sk_buff *, ax25_cb *) = NULL; + struct ax25_protocol *protocol; + + read_lock(&protocol_list_lock); + for (protocol = protocol_list; protocol != NULL; protocol = protocol->next) + if (protocol->pid == pid) { + res = protocol->func; + break; + } + read_unlock(&protocol_list_lock); + + return res; +} + +int ax25_listen_mine(const ax25_address *callsign, struct net_device *dev) +{ + struct listen_struct *listen; + + spin_lock_bh(&listen_lock); + for (listen = listen_list; listen != NULL; listen = listen->next) + if (ax25cmp(&listen->callsign, callsign) == 0 && + (listen->dev == dev || listen->dev == NULL)) { + spin_unlock_bh(&listen_lock); + return 1; + } + spin_unlock_bh(&listen_lock); + + return 0; +} + +void ax25_link_failed(ax25_cb *ax25, int reason) +{ + struct ax25_linkfail *lf; + + spin_lock_bh(&linkfail_lock); + hlist_for_each_entry(lf, &ax25_linkfail_list, lf_node) + lf->func(ax25, reason); + spin_unlock_bh(&linkfail_lock); +} + +int ax25_protocol_is_registered(unsigned int 
pid) +{ + struct ax25_protocol *protocol; + int res = 0; + + read_lock_bh(&protocol_list_lock); + for (protocol = protocol_list; protocol != NULL; protocol = protocol->next) + if (protocol->pid == pid) { + res = 1; + break; + } + read_unlock_bh(&protocol_list_lock); + + return res; +} diff --git a/net/ax25/ax25_in.c b/net/ax25/ax25_in.c new file mode 100644 index 000000000..1cac25aca --- /dev/null +++ b/net/ax25/ax25_in.c @@ -0,0 +1,451 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Hans-Joachim Hetscher DD8NE (dd8ne@bnv-bamberg.de) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Given a fragment, queue it on the fragment queue and if the fragment + * is complete, send it back to ax25_rx_iframe. + */ +static int ax25_rx_fragment(ax25_cb *ax25, struct sk_buff *skb) +{ + struct sk_buff *skbn, *skbo; + + if (ax25->fragno != 0) { + if (!(*skb->data & AX25_SEG_FIRST)) { + if ((ax25->fragno - 1) == (*skb->data & AX25_SEG_REM)) { + /* Enqueue fragment */ + ax25->fragno = *skb->data & AX25_SEG_REM; + skb_pull(skb, 1); /* skip fragno */ + ax25->fraglen += skb->len; + skb_queue_tail(&ax25->frag_queue, skb); + + /* Last fragment received ? */ + if (ax25->fragno == 0) { + skbn = alloc_skb(AX25_MAX_HEADER_LEN + + ax25->fraglen, + GFP_ATOMIC); + if (!skbn) { + skb_queue_purge(&ax25->frag_queue); + return 1; + } + + skb_reserve(skbn, AX25_MAX_HEADER_LEN); + + skbn->dev = ax25->ax25_dev->dev; + skb_reset_network_header(skbn); + skb_reset_transport_header(skbn); + + /* Copy data from the fragments */ + while ((skbo = skb_dequeue(&ax25->frag_queue)) != NULL) { + skb_copy_from_linear_data(skbo, + skb_put(skbn, skbo->len), + skbo->len); + kfree_skb(skbo); + } + + ax25->fraglen = 0; + + if (ax25_rx_iframe(ax25, skbn) == 0) + kfree_skb(skbn); + } + + return 1; + } + } + } else { + /* First fragment received */ + if (*skb->data & AX25_SEG_FIRST) { + skb_queue_purge(&ax25->frag_queue); + ax25->fragno = *skb->data & AX25_SEG_REM; + skb_pull(skb, 1); /* skip fragno */ + ax25->fraglen = skb->len; + skb_queue_tail(&ax25->frag_queue, skb); + return 1; + } + } + + return 0; +} + +/* + * This is where all valid I frames are sent to, to be dispatched to + * whichever protocol requires them. + */ +int ax25_rx_iframe(ax25_cb *ax25, struct sk_buff *skb) +{ + int (*func)(struct sk_buff *, ax25_cb *); + unsigned char pid; + int queued = 0; + + if (skb == NULL) return 0; + + ax25_start_idletimer(ax25); + + pid = *skb->data; + + if (pid == AX25_P_IP) { + /* working around a TCP bug to keep additional listeners + * happy. TCP re-uses the buffer and destroys the original + * content. 
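+ * A full skb_copy() (rather than skb_clone(), which would still share
+ * the data area) is therefore taken below; the original buffer is
+ * freed and the copy, with the PID byte removed, is handed to
+ * netif_rx() as an ordinary IP packet.  If the copy fails we simply
+ * fall back to delivering the original skb.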
+ */ + struct sk_buff *skbn = skb_copy(skb, GFP_ATOMIC); + if (skbn != NULL) { + kfree_skb(skb); + skb = skbn; + } + + skb_pull(skb, 1); /* Remove PID */ + skb->mac_header = skb->network_header; + skb_reset_network_header(skb); + skb->dev = ax25->ax25_dev->dev; + skb->pkt_type = PACKET_HOST; + skb->protocol = htons(ETH_P_IP); + netif_rx(skb); + return 1; + } + if (pid == AX25_P_SEGMENT) { + skb_pull(skb, 1); /* Remove PID */ + return ax25_rx_fragment(ax25, skb); + } + + if ((func = ax25_protocol_function(pid)) != NULL) { + skb_pull(skb, 1); /* Remove PID */ + return (*func)(skb, ax25); + } + + if (ax25->sk != NULL && ax25->ax25_dev->values[AX25_VALUES_CONMODE] == 2) { + if ((!ax25->pidincl && ax25->sk->sk_protocol == pid) || + ax25->pidincl) { + if (sock_queue_rcv_skb(ax25->sk, skb) == 0) + queued = 1; + else + ax25->condition |= AX25_COND_OWN_RX_BUSY; + } + } + + return queued; +} + +/* + * Higher level upcall for a LAPB frame + */ +static int ax25_process_rx_frame(ax25_cb *ax25, struct sk_buff *skb, int type, int dama) +{ + int queued = 0; + + if (ax25->state == AX25_STATE_0) + return 0; + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + queued = ax25_std_frame_in(ax25, skb, type); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (dama || ax25->ax25_dev->dama.slave) + queued = ax25_ds_frame_in(ax25, skb, type); + else + queued = ax25_std_frame_in(ax25, skb, type); + break; +#endif + } + + return queued; +} + +static int ax25_rcv(struct sk_buff *skb, struct net_device *dev, + const ax25_address *dev_addr, struct packet_type *ptype) +{ + ax25_address src, dest, *next_digi = NULL; + int type = 0, mine = 0, dama; + struct sock *make, *sk; + ax25_digi dp, reverse_dp; + ax25_cb *ax25; + ax25_dev *ax25_dev; + + /* + * Process the AX.25/LAPB frame. + */ + + skb_reset_transport_header(skb); + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + goto free; + + /* + * Parse the address header. + */ + + if (ax25_addr_parse(skb->data, skb->len, &src, &dest, &dp, &type, &dama) == NULL) + goto free; + + /* + * Ours perhaps ? + */ + if (dp.lastrepeat + 1 < dp.ndigi) /* Not yet digipeated completely */ + next_digi = &dp.calls[dp.lastrepeat + 1]; + + /* + * Pull of the AX.25 headers leaving the CTRL/PID bytes + */ + skb_pull(skb, ax25_addr_size(&dp)); + + /* For our port addresses ? 
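+ *	   (a frame counts as ours only when every digipeater in the
+ *	    path has already repeated it, lastrepeat + 1 == ndigi, and
+ *	    the destination matches either the interface callsign or a
+ *	    callsign registered through ax25_listen_register)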
*/ + if (ax25cmp(&dest, dev_addr) == 0 && dp.lastrepeat + 1 == dp.ndigi) + mine = 1; + + /* Also match on any registered callsign from L3/4 */ + if (!mine && ax25_listen_mine(&dest, dev) && dp.lastrepeat + 1 == dp.ndigi) + mine = 1; + + /* UI frame - bypass LAPB processing */ + if ((*skb->data & ~0x10) == AX25_UI && dp.lastrepeat + 1 == dp.ndigi) { + skb_set_transport_header(skb, 2); /* skip control and pid */ + + ax25_send_to_raw(&dest, skb, skb->data[1]); + + if (!mine && ax25cmp(&dest, (ax25_address *)dev->broadcast) != 0) + goto free; + + /* Now we are pointing at the pid byte */ + switch (skb->data[1]) { + case AX25_P_IP: + skb_pull(skb,2); /* drop PID/CTRL */ + skb_reset_transport_header(skb); + skb_reset_network_header(skb); + skb->dev = dev; + skb->pkt_type = PACKET_HOST; + skb->protocol = htons(ETH_P_IP); + netif_rx(skb); + break; + + case AX25_P_ARP: + skb_pull(skb,2); + skb_reset_transport_header(skb); + skb_reset_network_header(skb); + skb->dev = dev; + skb->pkt_type = PACKET_HOST; + skb->protocol = htons(ETH_P_ARP); + netif_rx(skb); + break; + case AX25_P_TEXT: + /* Now find a suitable dgram socket */ + sk = ax25_get_socket(&dest, &src, SOCK_DGRAM); + if (sk != NULL) { + bh_lock_sock(sk); + if (atomic_read(&sk->sk_rmem_alloc) >= + sk->sk_rcvbuf) { + kfree_skb(skb); + } else { + /* + * Remove the control and PID. + */ + skb_pull(skb, 2); + if (sock_queue_rcv_skb(sk, skb) != 0) + kfree_skb(skb); + } + bh_unlock_sock(sk); + sock_put(sk); + } else { + kfree_skb(skb); + } + break; + + default: + kfree_skb(skb); /* Will scan SOCK_AX25 RAW sockets */ + break; + } + + return 0; + } + + /* + * Is connected mode supported on this device ? + * If not, should we DM the incoming frame (except DMs) or + * silently ignore them. For now we stay quiet. + */ + if (ax25_dev->values[AX25_VALUES_CONMODE] == 0) + goto free; + + /* LAPB */ + + /* AX.25 state 1-4 */ + + ax25_digi_invert(&dp, &reverse_dp); + + if ((ax25 = ax25_find_cb(&dest, &src, &reverse_dp, dev)) != NULL) { + /* + * Process the frame. If it is queued up internally it + * returns one otherwise we free it immediately. This + * routine itself wakes the user context layers so we do + * no further work + */ + if (ax25_process_rx_frame(ax25, skb, type, dama) == 0) + kfree_skb(skb); + + ax25_cb_put(ax25); + return 0; + } + + /* AX.25 state 0 (disconnected) */ + + /* a) received not a SABM(E) */ + + if ((*skb->data & ~AX25_PF) != AX25_SABM && + (*skb->data & ~AX25_PF) != AX25_SABME) { + /* + * Never reply to a DM. Also ignore any connects for + * addresses that are not our interfaces and not a socket. 
+ */ + if ((*skb->data & ~AX25_PF) != AX25_DM && mine) + ax25_return_dm(dev, &src, &dest, &dp); + + goto free; + } + + /* b) received SABM(E) */ + + if (dp.lastrepeat + 1 == dp.ndigi) + sk = ax25_find_listener(&dest, 0, dev, SOCK_SEQPACKET); + else + sk = ax25_find_listener(next_digi, 1, dev, SOCK_SEQPACKET); + + if (sk != NULL) { + bh_lock_sock(sk); + if (sk_acceptq_is_full(sk) || + (make = ax25_make_new(sk, ax25_dev)) == NULL) { + if (mine) + ax25_return_dm(dev, &src, &dest, &dp); + kfree_skb(skb); + bh_unlock_sock(sk); + sock_put(sk); + + return 0; + } + + ax25 = sk_to_ax25(make); + skb_set_owner_r(skb, make); + skb_queue_head(&sk->sk_receive_queue, skb); + + make->sk_state = TCP_ESTABLISHED; + + sk_acceptq_added(sk); + bh_unlock_sock(sk); + } else { + if (!mine) + goto free; + + if ((ax25 = ax25_create_cb()) == NULL) { + ax25_return_dm(dev, &src, &dest, &dp); + goto free; + } + + ax25_fillin_cb(ax25, ax25_dev); + } + + ax25->source_addr = dest; + ax25->dest_addr = src; + + /* + * Sort out any digipeated paths. + */ + if (dp.ndigi && !ax25->digipeat && + (ax25->digipeat = kmalloc(sizeof(ax25_digi), GFP_ATOMIC)) == NULL) { + kfree_skb(skb); + ax25_destroy_socket(ax25); + if (sk) + sock_put(sk); + return 0; + } + + if (dp.ndigi == 0) { + kfree(ax25->digipeat); + ax25->digipeat = NULL; + } else { + /* Reverse the source SABM's path */ + memcpy(ax25->digipeat, &reverse_dp, sizeof(ax25_digi)); + } + + if ((*skb->data & ~AX25_PF) == AX25_SABME) { + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25_dev->values[AX25_VALUES_EWINDOW]; + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25_dev->values[AX25_VALUES_WINDOW]; + } + + ax25_send_control(ax25, AX25_UA, AX25_POLLON, AX25_RESPONSE); + +#ifdef CONFIG_AX25_DAMA_SLAVE + if (dama && ax25->ax25_dev->values[AX25_VALUES_PROTOCOL] == AX25_PROTO_DAMA_SLAVE) + ax25_dama_on(ax25); +#endif + + ax25->state = AX25_STATE_3; + + ax25_cb_add(ax25); + + ax25_start_heartbeat(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + + if (sk) { + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + sock_put(sk); + } else { +free: + kfree_skb(skb); + } + return 0; +} + +/* + * Receive an AX.25 frame via a SLIP interface. + */ +int ax25_kiss_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *ptype, struct net_device *orig_dev) +{ + skb_orphan(skb); + + if (!net_eq(dev_net(dev), &init_net)) { + kfree_skb(skb); + return 0; + } + + if ((*skb->data & 0x0F) != 0) { + kfree_skb(skb); /* Not a KISS data frame */ + return 0; + } + + skb_pull(skb, AX25_KISS_HEADER_LEN); /* Remove the KISS byte */ + + return ax25_rcv(skb, dev, (const ax25_address *)dev->dev_addr, ptype); +} diff --git a/net/ax25/ax25_ip.c b/net/ax25/ax25_ip.c new file mode 100644 index 000000000..36249776c --- /dev/null +++ b/net/ax25/ax25_ip.c @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * IP over AX.25 encapsulation. 
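+ *
+ * The header built by ax25_hard_header() below is the classic UI frame
+ * layout, prefixed with the one byte KISS "data" marker:
+ *
+ *	+------+----------------+----------------+------+-----+
+ *	| 0x00 | dest (7 bytes) | src (7 bytes)  | 0x03 | PID |
+ *	+------+----------------+----------------+------+-----+
+ *	 KISS    bit-shifted callsign + SSID        UI    0xCC = IP
+ *	                                                  0xCD = ARP
+ *
+ * i.e. AX25_HEADER_LEN bytes in total when no digipeaters are used.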
+ */ + +/* + * Shove an AX.25 UI header on an IP packet and handle ARP + */ + +#ifdef CONFIG_INET + +static int ax25_hard_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + unsigned char *buff; + + /* they sometimes come back to us... */ + if (type == ETH_P_AX25) + return 0; + + /* header is an AX.25 UI frame from us to them */ + buff = skb_push(skb, AX25_HEADER_LEN); + *buff++ = 0x00; /* KISS DATA */ + + if (daddr != NULL) + memcpy(buff, daddr, dev->addr_len); /* Address specified */ + + buff[6] &= ~AX25_CBIT; + buff[6] &= ~AX25_EBIT; + buff[6] |= AX25_SSSID_SPARE; + buff += AX25_ADDR_LEN; + + if (saddr != NULL) + memcpy(buff, saddr, dev->addr_len); + else + memcpy(buff, dev->dev_addr, dev->addr_len); + + buff[6] &= ~AX25_CBIT; + buff[6] |= AX25_EBIT; + buff[6] |= AX25_SSSID_SPARE; + buff += AX25_ADDR_LEN; + + *buff++ = AX25_UI; /* UI */ + + /* Append a suitable AX.25 PID */ + switch (type) { + case ETH_P_IP: + *buff++ = AX25_P_IP; + break; + case ETH_P_ARP: + *buff++ = AX25_P_ARP; + break; + default: + printk(KERN_ERR "AX.25: ax25_hard_header - wrong protocol type 0x%2.2x\n", type); + *buff++ = 0; + break; + } + + if (daddr != NULL) + return AX25_HEADER_LEN; + + return -AX25_HEADER_LEN; /* Unfinished header */ +} + +netdev_tx_t ax25_ip_xmit(struct sk_buff *skb) +{ + struct sk_buff *ourskb; + unsigned char *bp = skb->data; + ax25_route *route; + struct net_device *dev = NULL; + ax25_address *src, *dst; + ax25_digi *digipeat = NULL; + ax25_dev *ax25_dev; + ax25_cb *ax25; + char ip_mode = ' '; + + dst = (ax25_address *)(bp + 1); + src = (ax25_address *)(bp + 8); + + ax25_route_lock_use(); + route = ax25_get_route(dst, NULL); + if (route) { + digipeat = route->digipeat; + dev = route->dev; + ip_mode = route->ip_mode; + } + + if (dev == NULL) + dev = skb->dev; + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) { + kfree_skb(skb); + goto put; + } + + if (bp[16] == AX25_P_IP) { + if (ip_mode == 'V' || (ip_mode == ' ' && ax25_dev->values[AX25_VALUES_IPDEFMODE])) { + /* + * We copy the buffer and release the original thereby + * keeping it straight + * + * Note: we report 1 back so the caller will + * not feed the frame direct to the physical device + * We don't want that to happen. (It won't be upset + * as we have pulled the frame from the queue by + * freeing it). + * + * NB: TCP modifies buffers that are still + * on a device queue, thus we use skb_copy() + * instead of using skb_clone() unless this + * gets fixed. + */ + + ax25_address src_c; + ax25_address dst_c; + + if ((ourskb = skb_copy(skb, GFP_ATOMIC)) == NULL) { + kfree_skb(skb); + goto put; + } + + if (skb->sk != NULL) + skb_set_owner_w(ourskb, skb->sk); + + kfree_skb(skb); + /* dl9sau: bugfix + * after kfree_skb(), dst and src which were pointer + * to bp which is part of skb->data would not be valid + * anymore hope that after skb_pull(ourskb, ..) 
our + * dsc_c and src_c will not become invalid + */ + bp = ourskb->data; + dst_c = *(ax25_address *)(bp + 1); + src_c = *(ax25_address *)(bp + 8); + + skb_pull(ourskb, AX25_HEADER_LEN - 1); /* Keep PID */ + skb_reset_network_header(ourskb); + + ax25=ax25_send_frame( + ourskb, + ax25_dev->values[AX25_VALUES_PACLEN], + &src_c, + &dst_c, digipeat, dev); + if (ax25) { + ax25_cb_put(ax25); + } + goto put; + } + } + + bp[7] &= ~AX25_CBIT; + bp[7] &= ~AX25_EBIT; + bp[7] |= AX25_SSSID_SPARE; + + bp[14] &= ~AX25_CBIT; + bp[14] |= AX25_EBIT; + bp[14] |= AX25_SSSID_SPARE; + + skb_pull(skb, AX25_KISS_HEADER_LEN); + + if (digipeat != NULL) { + if ((ourskb = ax25_rt_build_path(skb, src, dst, route->digipeat)) == NULL) + goto put; + + skb = ourskb; + } + + ax25_queue_xmit(skb, dev); + +put: + + ax25_route_lock_unuse(); + return NETDEV_TX_OK; +} + +#else /* INET */ + +static int ax25_hard_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + return -AX25_HEADER_LEN; +} + +netdev_tx_t ax25_ip_xmit(struct sk_buff *skb) +{ + kfree_skb(skb); + return NETDEV_TX_OK; +} +#endif + +static bool ax25_validate_header(const char *header, unsigned int len) +{ + ax25_digi digi; + + if (!len) + return false; + + if (header[0]) + return true; + + return ax25_addr_parse(header + 1, len - 1, NULL, NULL, &digi, NULL, + NULL); +} + +const struct header_ops ax25_header_ops = { + .create = ax25_hard_header, + .validate = ax25_validate_header, +}; + +EXPORT_SYMBOL(ax25_header_ops); +EXPORT_SYMBOL(ax25_ip_xmit); diff --git a/net/ax25/ax25_out.c b/net/ax25/ax25_out.c new file mode 100644 index 000000000..3db76d247 --- /dev/null +++ b/net/ax25/ax25_out.c @@ -0,0 +1,386 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(ax25_frag_lock); + +ax25_cb *ax25_send_frame(struct sk_buff *skb, int paclen, const ax25_address *src, ax25_address *dest, ax25_digi *digi, struct net_device *dev) +{ + ax25_dev *ax25_dev; + ax25_cb *ax25; + + /* + * Take the default packet length for the device if zero is + * specified. + */ + if (paclen == 0) { + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + return NULL; + + paclen = ax25_dev->values[AX25_VALUES_PACLEN]; + } + + /* + * Look for an existing connection. 
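+ * Whether an existing control block is found here or a new one is set
+ * up further down, the caller receives an ax25_cb with an extra
+ * reference and must drop it with ax25_cb_put() when finished, in the
+ * way ax25_ip_xmit() in ax25_ip.c does (arguments abbreviated):
+ *
+ *	ax25 = ax25_send_frame(ourskb, paclen, &src_c, &dst_c,
+ *			       digipeat, dev);
+ *	if (ax25)
+ *		ax25_cb_put(ax25);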
+ */ + if ((ax25 = ax25_find_cb(src, dest, digi, dev)) != NULL) { + ax25_output(ax25, paclen, skb); + return ax25; /* It already existed */ + } + + if ((ax25_dev = ax25_dev_ax25dev(dev)) == NULL) + return NULL; + + if ((ax25 = ax25_create_cb()) == NULL) + return NULL; + + ax25_fillin_cb(ax25, ax25_dev); + + ax25->source_addr = *src; + ax25->dest_addr = *dest; + + if (digi != NULL) { + ax25->digipeat = kmemdup(digi, sizeof(*digi), GFP_ATOMIC); + if (ax25->digipeat == NULL) { + ax25_cb_put(ax25); + return NULL; + } + } + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_establish_data_link(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (ax25_dev->dama.slave) + ax25_ds_establish_data_link(ax25); + else + ax25_std_establish_data_link(ax25); + break; +#endif + } + + /* + * There is one ref for the state machine; a caller needs + * one more to put it back, just like with the existing one. + */ + ax25_cb_hold(ax25); + + ax25_cb_add(ax25); + + ax25->state = AX25_STATE_1; + + ax25_start_heartbeat(ax25); + + ax25_output(ax25, paclen, skb); + + return ax25; /* We had to create it */ +} + +EXPORT_SYMBOL(ax25_send_frame); + +/* + * All outgoing AX.25 I frames pass via this routine. Therefore this is + * where the fragmentation of frames takes place. If fragment is set to + * zero then we are not allowed to do fragmentation, even if the frame + * is too large. + */ +void ax25_output(ax25_cb *ax25, int paclen, struct sk_buff *skb) +{ + struct sk_buff *skbn; + unsigned char *p; + int frontlen, len, fragno, ka9qfrag, first = 1; + + if (paclen < 16) { + WARN_ON_ONCE(1); + kfree_skb(skb); + return; + } + + if ((skb->len - 1) > paclen) { + if (*skb->data == AX25_P_TEXT) { + skb_pull(skb, 1); /* skip PID */ + ka9qfrag = 0; + } else { + paclen -= 2; /* Allow for fragment control info */ + ka9qfrag = 1; + } + + fragno = skb->len / paclen; + if (skb->len % paclen == 0) fragno--; + + frontlen = skb_headroom(skb); /* Address space + CTRL */ + + while (skb->len > 0) { + spin_lock_bh(&ax25_frag_lock); + if ((skbn = alloc_skb(paclen + 2 + frontlen, GFP_ATOMIC)) == NULL) { + spin_unlock_bh(&ax25_frag_lock); + printk(KERN_CRIT "AX.25: ax25_output - out of memory\n"); + return; + } + + if (skb->sk != NULL) + skb_set_owner_w(skbn, skb->sk); + + spin_unlock_bh(&ax25_frag_lock); + + len = (paclen > skb->len) ? skb->len : paclen; + + if (ka9qfrag == 1) { + skb_reserve(skbn, frontlen + 2); + skb_set_network_header(skbn, + skb_network_offset(skb)); + skb_copy_from_linear_data(skb, skb_put(skbn, len), len); + p = skb_push(skbn, 2); + + *p++ = AX25_P_SEGMENT; + + *p = fragno--; + if (first) { + *p |= AX25_SEG_FIRST; + first = 0; + } + } else { + skb_reserve(skbn, frontlen + 1); + skb_set_network_header(skbn, + skb_network_offset(skb)); + skb_copy_from_linear_data(skb, skb_put(skbn, len), len); + p = skb_push(skbn, 1); + *p = AX25_P_TEXT; + } + + skb_pull(skb, len); + skb_queue_tail(&ax25->write_queue, skbn); /* Throw it on the queue */ + } + + kfree_skb(skb); + } else { + skb_queue_tail(&ax25->write_queue, skb); /* Throw it on the queue */ + } + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_kick(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + /* + * A DAMA slave is _required_ to work as normal AX.25L2V2 + * if no DAMA master is available. 
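+ * Hence the frame is only left sitting on the write queue, waiting for
+ * the next poll from the master, while dama.slave is actually set;
+ * otherwise ax25_kick() transmits it immediately just as the standard
+ * protocol does.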
+ */ + case AX25_PROTO_DAMA_SLAVE: + if (!ax25->ax25_dev->dama.slave) ax25_kick(ax25); + break; +#endif + } +} + +/* + * This procedure is passed a buffer descriptor for an iframe. It builds + * the rest of the control part of the frame and then writes it out. + */ +static void ax25_send_iframe(ax25_cb *ax25, struct sk_buff *skb, int poll_bit) +{ + unsigned char *frame; + + if (skb == NULL) + return; + + skb_reset_network_header(skb); + + if (ax25->modulus == AX25_MODULUS) { + frame = skb_push(skb, 1); + + *frame = AX25_I; + *frame |= (poll_bit) ? AX25_PF : 0; + *frame |= (ax25->vr << 5); + *frame |= (ax25->vs << 1); + } else { + frame = skb_push(skb, 2); + + frame[0] = AX25_I; + frame[0] |= (ax25->vs << 1); + frame[1] = (poll_bit) ? AX25_EPF : 0; + frame[1] |= (ax25->vr << 1); + } + + ax25_start_idletimer(ax25); + + ax25_transmit_buffer(ax25, skb, AX25_COMMAND); +} + +void ax25_kick(ax25_cb *ax25) +{ + struct sk_buff *skb, *skbn; + int last = 1; + unsigned short start, end, next; + + if (ax25->state != AX25_STATE_3 && ax25->state != AX25_STATE_4) + return; + + if (ax25->condition & AX25_COND_PEER_RX_BUSY) + return; + + if (skb_peek(&ax25->write_queue) == NULL) + return; + + start = (skb_peek(&ax25->ack_queue) == NULL) ? ax25->va : ax25->vs; + end = (ax25->va + ax25->window) % ax25->modulus; + + if (start == end) + return; + + /* + * Transmit data until either we're out of data to send or + * the window is full. Send a poll on the final I frame if + * the window is filled. + */ + + /* + * Dequeue the frame and copy it. + * Check for race with ax25_clear_queues(). + */ + skb = skb_dequeue(&ax25->write_queue); + if (!skb) + return; + + ax25->vs = start; + + do { + if ((skbn = skb_clone(skb, GFP_ATOMIC)) == NULL) { + skb_queue_head(&ax25->write_queue, skb); + break; + } + + if (skb->sk != NULL) + skb_set_owner_w(skbn, skb->sk); + + next = (ax25->vs + 1) % ax25->modulus; + last = (next == end); + + /* + * Transmit the frame copy. + * bke 960114: do not set the Poll bit on the last frame + * in DAMA mode. + */ + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_send_iframe(ax25, skbn, (last) ? AX25_POLLON : AX25_POLLOFF); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + ax25_send_iframe(ax25, skbn, AX25_POLLOFF); + break; +#endif + } + + ax25->vs = next; + + /* + * Requeue the original data frame. + */ + skb_queue_tail(&ax25->ack_queue, skb); + + } while (!last && (skb = skb_dequeue(&ax25->write_queue)) != NULL); + + ax25->condition &= ~AX25_COND_ACK_PENDING; + + if (!ax25_t1timer_running(ax25)) { + ax25_stop_t3timer(ax25); + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + } +} + +void ax25_transmit_buffer(ax25_cb *ax25, struct sk_buff *skb, int type) +{ + unsigned char *ptr; + int headroom; + + if (ax25->ax25_dev == NULL) { + ax25_disconnect(ax25, ENETUNREACH); + return; + } + + headroom = ax25_addr_size(ax25->digipeat); + + if (unlikely(skb_headroom(skb) < headroom)) { + skb = skb_expand_head(skb, headroom); + if (!skb) { + printk(KERN_CRIT "AX.25: ax25_transmit_buffer - out of memory\n"); + return; + } + } + + ptr = skb_push(skb, headroom); + + ax25_addr_build(ptr, &ax25->source_addr, &ax25->dest_addr, ax25->digipeat, type, ax25->modulus); + + ax25_queue_xmit(skb, ax25->ax25_dev->dev); +} + +/* + * A small shim to dev_queue_xmit to add the KISS control byte, and do + * any packet forwarding in operation. 
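+ * Every frame leaving an AX.25 device therefore starts with the one
+ * byte KISS frame type 0x00 ("data"), followed by the address field,
+ * the control byte(s) and, for I and UI frames, the PID and payload.
+ * If a forwarding entry was configured with SIOCAX25ADDFWD,
+ * ax25_fwd_dev() silently redirects the frame to the forward device
+ * before it is queued.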
+ */ +void ax25_queue_xmit(struct sk_buff *skb, struct net_device *dev) +{ + unsigned char *ptr; + + skb->protocol = ax25_type_trans(skb, ax25_fwd_dev(dev)); + + ptr = skb_push(skb, 1); + *ptr = 0x00; /* KISS */ + + dev_queue_xmit(skb); +} + +int ax25_check_iframes_acked(ax25_cb *ax25, unsigned short nr) +{ + if (ax25->vs == nr) { + ax25_frames_acked(ax25, nr); + ax25_calculate_rtt(ax25); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + return 1; + } else { + if (ax25->va != nr) { + ax25_frames_acked(ax25, nr); + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + return 1; + } + } + return 0; +} diff --git a/net/ax25/ax25_route.c b/net/ax25/ax25_route.c new file mode 100644 index 000000000..b7c4d656a --- /dev/null +++ b/net/ax25/ax25_route.c @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Steven Whitehouse GW7RRM (stevew@acm.org) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Hans-Joachim Hetscher DD8NE (dd8ne@bnv-bamberg.de) + * Copyright (C) Frederic Rible F1OAT (frible@teaser.fr) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static ax25_route *ax25_route_list; +DEFINE_RWLOCK(ax25_route_lock); + +void ax25_rt_device_down(struct net_device *dev) +{ + ax25_route *s, *t, *ax25_rt; + + write_lock_bh(&ax25_route_lock); + ax25_rt = ax25_route_list; + while (ax25_rt != NULL) { + s = ax25_rt; + ax25_rt = ax25_rt->next; + + if (s->dev == dev) { + if (ax25_route_list == s) { + ax25_route_list = s->next; + kfree(s->digipeat); + kfree(s); + } else { + for (t = ax25_route_list; t != NULL; t = t->next) { + if (t->next == s) { + t->next = s->next; + kfree(s->digipeat); + kfree(s); + break; + } + } + } + } + } + write_unlock_bh(&ax25_route_lock); +} + +static int __must_check ax25_rt_add(struct ax25_routes_struct *route) +{ + ax25_route *ax25_rt; + ax25_dev *ax25_dev; + int i; + + if (route->digi_count > AX25_MAX_DIGIS) + return -EINVAL; + + ax25_dev = ax25_addr_ax25dev(&route->port_addr); + if (!ax25_dev) + return -EINVAL; + + write_lock_bh(&ax25_route_lock); + + ax25_rt = ax25_route_list; + while (ax25_rt != NULL) { + if (ax25cmp(&ax25_rt->callsign, &route->dest_addr) == 0 && + ax25_rt->dev == ax25_dev->dev) { + kfree(ax25_rt->digipeat); + ax25_rt->digipeat = NULL; + if (route->digi_count != 0) { + if ((ax25_rt->digipeat = kmalloc(sizeof(ax25_digi), GFP_ATOMIC)) == NULL) { + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + return -ENOMEM; + } + ax25_rt->digipeat->lastrepeat = -1; + ax25_rt->digipeat->ndigi = route->digi_count; + for (i = 0; i < route->digi_count; i++) { + ax25_rt->digipeat->repeated[i] = 0; + ax25_rt->digipeat->calls[i] = route->digi_addr[i]; + } + } + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + return 0; + } + ax25_rt = ax25_rt->next; + } + + if ((ax25_rt = kmalloc(sizeof(ax25_route), GFP_ATOMIC)) == NULL) { + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + return -ENOMEM; + } + + ax25_rt->callsign = route->dest_addr; + ax25_rt->dev = ax25_dev->dev; + ax25_rt->digipeat = NULL; + ax25_rt->ip_mode = ' '; + if (route->digi_count != 0) { + if ((ax25_rt->digipeat = kmalloc(sizeof(ax25_digi), GFP_ATOMIC)) == NULL) { + 
write_unlock_bh(&ax25_route_lock); + kfree(ax25_rt); + ax25_dev_put(ax25_dev); + return -ENOMEM; + } + ax25_rt->digipeat->lastrepeat = -1; + ax25_rt->digipeat->ndigi = route->digi_count; + for (i = 0; i < route->digi_count; i++) { + ax25_rt->digipeat->repeated[i] = 0; + ax25_rt->digipeat->calls[i] = route->digi_addr[i]; + } + } + ax25_rt->next = ax25_route_list; + ax25_route_list = ax25_rt; + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + + return 0; +} + +void __ax25_put_route(ax25_route *ax25_rt) +{ + kfree(ax25_rt->digipeat); + kfree(ax25_rt); +} + +static int ax25_rt_del(struct ax25_routes_struct *route) +{ + ax25_route *s, *t, *ax25_rt; + ax25_dev *ax25_dev; + + if ((ax25_dev = ax25_addr_ax25dev(&route->port_addr)) == NULL) + return -EINVAL; + + write_lock_bh(&ax25_route_lock); + + ax25_rt = ax25_route_list; + while (ax25_rt != NULL) { + s = ax25_rt; + ax25_rt = ax25_rt->next; + if (s->dev == ax25_dev->dev && + ax25cmp(&route->dest_addr, &s->callsign) == 0) { + if (ax25_route_list == s) { + ax25_route_list = s->next; + __ax25_put_route(s); + } else { + for (t = ax25_route_list; t != NULL; t = t->next) { + if (t->next == s) { + t->next = s->next; + __ax25_put_route(s); + break; + } + } + } + } + } + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + + return 0; +} + +static int ax25_rt_opt(struct ax25_route_opt_struct *rt_option) +{ + ax25_route *ax25_rt; + ax25_dev *ax25_dev; + int err = 0; + + if ((ax25_dev = ax25_addr_ax25dev(&rt_option->port_addr)) == NULL) + return -EINVAL; + + write_lock_bh(&ax25_route_lock); + + ax25_rt = ax25_route_list; + while (ax25_rt != NULL) { + if (ax25_rt->dev == ax25_dev->dev && + ax25cmp(&rt_option->dest_addr, &ax25_rt->callsign) == 0) { + switch (rt_option->cmd) { + case AX25_SET_RT_IPMODE: + switch (rt_option->arg) { + case ' ': + case 'D': + case 'V': + ax25_rt->ip_mode = rt_option->arg; + break; + default: + err = -EINVAL; + goto out; + } + break; + default: + err = -EINVAL; + goto out; + } + } + ax25_rt = ax25_rt->next; + } + +out: + write_unlock_bh(&ax25_route_lock); + ax25_dev_put(ax25_dev); + return err; +} + +int ax25_rt_ioctl(unsigned int cmd, void __user *arg) +{ + struct ax25_route_opt_struct rt_option; + struct ax25_routes_struct route; + + switch (cmd) { + case SIOCADDRT: + if (copy_from_user(&route, arg, sizeof(route))) + return -EFAULT; + return ax25_rt_add(&route); + + case SIOCDELRT: + if (copy_from_user(&route, arg, sizeof(route))) + return -EFAULT; + return ax25_rt_del(&route); + + case SIOCAX25OPTRT: + if (copy_from_user(&rt_option, arg, sizeof(rt_option))) + return -EFAULT; + return ax25_rt_opt(&rt_option); + + default: + return -EINVAL; + } +} + +#ifdef CONFIG_PROC_FS + +static void *ax25_rt_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(ax25_route_lock) +{ + struct ax25_route *ax25_rt; + int i = 1; + + read_lock(&ax25_route_lock); + if (*pos == 0) + return SEQ_START_TOKEN; + + for (ax25_rt = ax25_route_list; ax25_rt != NULL; ax25_rt = ax25_rt->next) { + if (i == *pos) + return ax25_rt; + ++i; + } + + return NULL; +} + +static void *ax25_rt_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return (v == SEQ_START_TOKEN) ? 
ax25_route_list : + ((struct ax25_route *) v)->next; +} + +static void ax25_rt_seq_stop(struct seq_file *seq, void *v) + __releases(ax25_route_lock) +{ + read_unlock(&ax25_route_lock); +} + +static int ax25_rt_seq_show(struct seq_file *seq, void *v) +{ + char buf[11]; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, "callsign dev mode digipeaters\n"); + else { + struct ax25_route *ax25_rt = v; + const char *callsign; + int i; + + if (ax25cmp(&ax25_rt->callsign, &null_ax25_address) == 0) + callsign = "default"; + else + callsign = ax2asc(buf, &ax25_rt->callsign); + + seq_printf(seq, "%-9s %-4s", + callsign, + ax25_rt->dev ? ax25_rt->dev->name : "???"); + + switch (ax25_rt->ip_mode) { + case 'V': + seq_puts(seq, " vc"); + break; + case 'D': + seq_puts(seq, " dg"); + break; + default: + seq_puts(seq, " *"); + break; + } + + if (ax25_rt->digipeat != NULL) + for (i = 0; i < ax25_rt->digipeat->ndigi; i++) + seq_printf(seq, " %s", + ax2asc(buf, &ax25_rt->digipeat->calls[i])); + + seq_puts(seq, "\n"); + } + return 0; +} + +const struct seq_operations ax25_rt_seqops = { + .start = ax25_rt_seq_start, + .next = ax25_rt_seq_next, + .stop = ax25_rt_seq_stop, + .show = ax25_rt_seq_show, +}; +#endif + +/* + * Find AX.25 route + * + * Only routes with a reference count of zero can be destroyed. + * Must be called with ax25_route_lock read locked. + */ +ax25_route *ax25_get_route(ax25_address *addr, struct net_device *dev) +{ + ax25_route *ax25_spe_rt = NULL; + ax25_route *ax25_def_rt = NULL; + ax25_route *ax25_rt; + + /* + * Bind to the physical interface we heard them on, or the default + * route if none is found; + */ + for (ax25_rt = ax25_route_list; ax25_rt != NULL; ax25_rt = ax25_rt->next) { + if (dev == NULL) { + if (ax25cmp(&ax25_rt->callsign, addr) == 0 && ax25_rt->dev != NULL) + ax25_spe_rt = ax25_rt; + if (ax25cmp(&ax25_rt->callsign, &null_ax25_address) == 0 && ax25_rt->dev != NULL) + ax25_def_rt = ax25_rt; + } else { + if (ax25cmp(&ax25_rt->callsign, addr) == 0 && ax25_rt->dev == dev) + ax25_spe_rt = ax25_rt; + if (ax25cmp(&ax25_rt->callsign, &null_ax25_address) == 0 && ax25_rt->dev == dev) + ax25_def_rt = ax25_rt; + } + } + + ax25_rt = ax25_def_rt; + if (ax25_spe_rt != NULL) + ax25_rt = ax25_spe_rt; + + return ax25_rt; +} + +/* + * Adjust path: If you specify a default route and want to connect + * a target on the digipeater path but w/o having a special route + * set before, the path has to be truncated from your target on. + */ +static inline void ax25_adjust_path(ax25_address *addr, ax25_digi *digipeat) +{ + int k; + + for (k = 0; k < digipeat->ndigi; k++) { + if (ax25cmp(addr, &digipeat->calls[k]) == 0) + break; + } + + digipeat->ndigi = k; +} + + +/* + * Find which interface to use. 
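+ * ax25_rt_autobind() picks the route for the destination callsign,
+ * binds the control block to that route's device, takes the source
+ * callsign from the UID map (falling back to the device address,
+ * subject to ax25_uid_policy) and copies the route's digipeater path,
+ * truncated at the destination if it already appears on the path.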
+ */ +int ax25_rt_autobind(ax25_cb *ax25, ax25_address *addr) +{ + ax25_uid_assoc *user; + ax25_route *ax25_rt; + int err = 0; + + ax25_route_lock_use(); + ax25_rt = ax25_get_route(addr, NULL); + if (!ax25_rt) { + ax25_route_lock_unuse(); + return -EHOSTUNREACH; + } + if ((ax25->ax25_dev = ax25_dev_ax25dev(ax25_rt->dev)) == NULL) { + err = -EHOSTUNREACH; + goto put; + } + + user = ax25_findbyuid(current_euid()); + if (user) { + ax25->source_addr = user->call; + ax25_uid_put(user); + } else { + if (ax25_uid_policy && !capable(CAP_NET_BIND_SERVICE)) { + err = -EPERM; + goto put; + } + ax25->source_addr = *(ax25_address *)ax25->ax25_dev->dev->dev_addr; + } + + if (ax25_rt->digipeat != NULL) { + ax25->digipeat = kmemdup(ax25_rt->digipeat, sizeof(ax25_digi), + GFP_ATOMIC); + if (ax25->digipeat == NULL) { + err = -ENOMEM; + goto put; + } + ax25_adjust_path(addr, ax25->digipeat); + } + + if (ax25->sk != NULL) { + local_bh_disable(); + bh_lock_sock(ax25->sk); + sock_reset_flag(ax25->sk, SOCK_ZAPPED); + bh_unlock_sock(ax25->sk); + local_bh_enable(); + } + +put: + ax25_route_lock_unuse(); + return err; +} + +struct sk_buff *ax25_rt_build_path(struct sk_buff *skb, ax25_address *src, + ax25_address *dest, ax25_digi *digi) +{ + unsigned char *bp; + int len; + + len = digi->ndigi * AX25_ADDR_LEN; + + if (unlikely(skb_headroom(skb) < len)) { + skb = skb_expand_head(skb, len); + if (!skb) { + printk(KERN_CRIT "AX.25: ax25_dg_build_path - out of memory\n"); + return NULL; + } + } + + bp = skb_push(skb, len); + + ax25_addr_build(bp, src, dest, digi, AX25_COMMAND, AX25_MODULUS); + + return skb; +} + +/* + * Free all memory associated with routing structures. + */ +void __exit ax25_rt_free(void) +{ + ax25_route *s, *ax25_rt = ax25_route_list; + + write_lock_bh(&ax25_route_lock); + while (ax25_rt != NULL) { + s = ax25_rt; + ax25_rt = ax25_rt->next; + + kfree(s->digipeat); + kfree(s); + } + write_unlock_bh(&ax25_route_lock); +} diff --git a/net/ax25/ax25_std_in.c b/net/ax25/ax25_std_in.c new file mode 100644 index 000000000..ba176196a --- /dev/null +++ b/net/ax25/ax25_std_in.c @@ -0,0 +1,443 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Hans-Joachim Hetscher DD8NE (dd8ne@bnv-bamberg.de) + * + * Most of this code is based on the SDL diagrams published in the 7th ARRL + * Computer Networking Conference papers. The diagrams have mistakes in them, + * but are mostly correct. Before you modify the code could you read the SDL + * diagrams as the code is not obvious and probably very easy to break. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * State machine for state 1, Awaiting Connection State. + * The handling of the timer(s) is in file ax25_std_timer.c. + * Handling of state 0 and connection release is in ax25.c. 
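+ * An incoming SABM or SABME is accepted with a UA response, DISC is
+ * refused with DM, a final UA completes our own connection attempt and
+ * moves the link to state 3, and a final DM either refuses the
+ * connection or falls back from SABME to standard SABM operation.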
+ */ +static int ax25_std_state1_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int pf, int type) +{ + switch (frametype) { + case AX25_SABM: + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + break; + + case AX25_SABME: + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_EWINDOW]; + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_DM, pf, AX25_RESPONSE); + break; + + case AX25_UA: + if (pf) { + ax25_calculate_rtt(ax25); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + ax25->vs = 0; + ax25->va = 0; + ax25->vr = 0; + ax25->state = AX25_STATE_3; + ax25->n2count = 0; + if (ax25->sk != NULL) { + bh_lock_sock(ax25->sk); + ax25->sk->sk_state = TCP_ESTABLISHED; + /* For WAIT_SABM connections we will produce an accept ready socket here */ + if (!sock_flag(ax25->sk, SOCK_DEAD)) + ax25->sk->sk_state_change(ax25->sk); + bh_unlock_sock(ax25->sk); + } + } + break; + + case AX25_DM: + if (pf) { + if (ax25->modulus == AX25_MODULUS) { + ax25_disconnect(ax25, ECONNREFUSED); + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + } + } + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 2, Awaiting Release State. + * The handling of the timer(s) is in file ax25_std_timer.c + * Handling of state 0 and connection release is in ax25.c. + */ +static int ax25_std_state2_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int pf, int type) +{ + switch (frametype) { + case AX25_SABM: + case AX25_SABME: + ax25_send_control(ax25, AX25_DM, pf, AX25_RESPONSE); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_disconnect(ax25, 0); + break; + + case AX25_DM: + case AX25_UA: + if (pf) + ax25_disconnect(ax25, 0); + break; + + case AX25_I: + case AX25_REJ: + case AX25_RNR: + case AX25_RR: + if (pf) ax25_send_control(ax25, AX25_DM, AX25_POLLON, AX25_RESPONSE); + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 3, Connected State. + * The handling of the timer(s) is in file ax25_std_timer.c + * Handling of state 0 and connection release is in ax25.c. 
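+ * Connected state: in-sequence I frames are delivered and out-of-
+ * sequence ones answered with a single REJ, RR/RNR track the peer's
+ * receiver-busy condition, and SABM/SABME resets the link variables.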
+ */ +static int ax25_std_state3_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int ns, int nr, int pf, int type) +{ + int queued = 0; + + switch (frametype) { + case AX25_SABM: + case AX25_SABME: + if (frametype == AX25_SABM) { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + } else { + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_EWINDOW]; + } + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_stop_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + ax25->condition = 0x00; + ax25->vs = 0; + ax25->va = 0; + ax25->vr = 0; + ax25_requeue_frames(ax25); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_disconnect(ax25, 0); + break; + + case AX25_DM: + ax25_disconnect(ax25, ECONNRESET); + break; + + case AX25_RR: + case AX25_RNR: + if (frametype == AX25_RR) + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + else + ax25->condition |= AX25_COND_PEER_RX_BUSY; + if (type == AX25_COMMAND && pf) + ax25_std_enquiry_response(ax25); + if (ax25_validate_nr(ax25, nr)) { + ax25_check_iframes_acked(ax25, nr); + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_REJ: + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + if (type == AX25_COMMAND && pf) + ax25_std_enquiry_response(ax25); + if (ax25_validate_nr(ax25, nr)) { + ax25_frames_acked(ax25, nr); + ax25_calculate_rtt(ax25); + ax25_stop_t1timer(ax25); + ax25_start_t3timer(ax25); + ax25_requeue_frames(ax25); + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_I: + if (!ax25_validate_nr(ax25, nr)) { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + break; + } + if (ax25->condition & AX25_COND_PEER_RX_BUSY) { + ax25_frames_acked(ax25, nr); + } else { + ax25_check_iframes_acked(ax25, nr); + } + if (ax25->condition & AX25_COND_OWN_RX_BUSY) { + if (pf) ax25_std_enquiry_response(ax25); + break; + } + if (ns == ax25->vr) { + ax25->vr = (ax25->vr + 1) % ax25->modulus; + queued = ax25_rx_iframe(ax25, skb); + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25->vr = ns; /* ax25->vr - 1 */ + ax25->condition &= ~AX25_COND_REJECT; + if (pf) { + ax25_std_enquiry_response(ax25); + } else { + if (!(ax25->condition & AX25_COND_ACK_PENDING)) { + ax25->condition |= AX25_COND_ACK_PENDING; + ax25_start_t2timer(ax25); + } + } + } else { + if (ax25->condition & AX25_COND_REJECT) { + if (pf) ax25_std_enquiry_response(ax25); + } else { + ax25->condition |= AX25_COND_REJECT; + ax25_send_control(ax25, AX25_REJ, pf, AX25_RESPONSE); + ax25->condition &= ~AX25_COND_ACK_PENDING; + } + } + break; + + case AX25_FRMR: + case AX25_ILLEGAL: + ax25_std_establish_data_link(ax25); + ax25->state = AX25_STATE_1; + break; + + default: + break; + } + + return queued; +} + +/* + * State machine for state 4, Timer Recovery State. + * The handling of the timer(s) is in file ax25_std_timer.c + * Handling of state 0 and connection release is in ax25.c. 
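+ * Timer recovery state: a final RR, RNR or REJ response stops T1, and
+ * once every outstanding I frame is acknowledged (V(S) == V(A)) the
+ * link returns to state 3; otherwise the unacked frames are requeued.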
+ */ +static int ax25_std_state4_machine(ax25_cb *ax25, struct sk_buff *skb, int frametype, int ns, int nr, int pf, int type) +{ + int queued = 0; + + switch (frametype) { + case AX25_SABM: + case AX25_SABME: + if (frametype == AX25_SABM) { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + } else { + ax25->modulus = AX25_EMODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_EWINDOW]; + } + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_stop_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_start_t3timer(ax25); + ax25_start_idletimer(ax25); + ax25->condition = 0x00; + ax25->vs = 0; + ax25->va = 0; + ax25->vr = 0; + ax25->state = AX25_STATE_3; + ax25->n2count = 0; + ax25_requeue_frames(ax25); + break; + + case AX25_DISC: + ax25_send_control(ax25, AX25_UA, pf, AX25_RESPONSE); + ax25_disconnect(ax25, 0); + break; + + case AX25_DM: + ax25_disconnect(ax25, ECONNRESET); + break; + + case AX25_RR: + case AX25_RNR: + if (frametype == AX25_RR) + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + else + ax25->condition |= AX25_COND_PEER_RX_BUSY; + if (type == AX25_RESPONSE && pf) { + ax25_stop_t1timer(ax25); + ax25->n2count = 0; + if (ax25_validate_nr(ax25, nr)) { + ax25_frames_acked(ax25, nr); + if (ax25->vs == ax25->va) { + ax25_start_t3timer(ax25); + ax25->state = AX25_STATE_3; + } else { + ax25_requeue_frames(ax25); + } + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + } + if (type == AX25_COMMAND && pf) + ax25_std_enquiry_response(ax25); + if (ax25_validate_nr(ax25, nr)) { + ax25_frames_acked(ax25, nr); + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_REJ: + ax25->condition &= ~AX25_COND_PEER_RX_BUSY; + if (pf && type == AX25_RESPONSE) { + ax25_stop_t1timer(ax25); + ax25->n2count = 0; + if (ax25_validate_nr(ax25, nr)) { + ax25_frames_acked(ax25, nr); + if (ax25->vs == ax25->va) { + ax25_start_t3timer(ax25); + ax25->state = AX25_STATE_3; + } else { + ax25_requeue_frames(ax25); + } + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + } + if (type == AX25_COMMAND && pf) + ax25_std_enquiry_response(ax25); + if (ax25_validate_nr(ax25, nr)) { + ax25_frames_acked(ax25, nr); + ax25_requeue_frames(ax25); + } else { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + } + break; + + case AX25_I: + if (!ax25_validate_nr(ax25, nr)) { + ax25_std_nr_error_recovery(ax25); + ax25->state = AX25_STATE_1; + break; + } + ax25_frames_acked(ax25, nr); + if (ax25->condition & AX25_COND_OWN_RX_BUSY) { + if (pf) + ax25_std_enquiry_response(ax25); + break; + } + if (ns == ax25->vr) { + ax25->vr = (ax25->vr + 1) % ax25->modulus; + queued = ax25_rx_iframe(ax25, skb); + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25->vr = ns; /* ax25->vr - 1 */ + ax25->condition &= ~AX25_COND_REJECT; + if (pf) { + ax25_std_enquiry_response(ax25); + } else { + if (!(ax25->condition & AX25_COND_ACK_PENDING)) { + ax25->condition |= AX25_COND_ACK_PENDING; + ax25_start_t2timer(ax25); + } + } + } else { + if (ax25->condition & AX25_COND_REJECT) { + if (pf) ax25_std_enquiry_response(ax25); + } else { + ax25->condition |= AX25_COND_REJECT; + ax25_send_control(ax25, AX25_REJ, pf, AX25_RESPONSE); + ax25->condition &= ~AX25_COND_ACK_PENDING; + } + } + break; + + case AX25_FRMR: + case AX25_ILLEGAL: + ax25_std_establish_data_link(ax25); + ax25->state = AX25_STATE_1; + break; + + default: + break; + } + + return queued; +} + +/* + * Higher 
level upcall for a LAPB frame + */ +int ax25_std_frame_in(ax25_cb *ax25, struct sk_buff *skb, int type) +{ + int queued = 0, frametype, ns, nr, pf; + + frametype = ax25_decode(ax25, skb, &ns, &nr, &pf); + + switch (ax25->state) { + case AX25_STATE_1: + queued = ax25_std_state1_machine(ax25, skb, frametype, pf, type); + break; + case AX25_STATE_2: + queued = ax25_std_state2_machine(ax25, skb, frametype, pf, type); + break; + case AX25_STATE_3: + queued = ax25_std_state3_machine(ax25, skb, frametype, ns, nr, pf, type); + break; + case AX25_STATE_4: + queued = ax25_std_state4_machine(ax25, skb, frametype, ns, nr, pf, type); + break; + } + + ax25_kick(ax25); + + return queued; +} diff --git a/net/ax25/ax25_std_subr.c b/net/ax25/ax25_std_subr.c new file mode 100644 index 000000000..4c36f1342 --- /dev/null +++ b/net/ax25/ax25_std_subr.c @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * The following routines are taken from page 170 of the 7th ARRL Computer + * Networking Conference paper, as is the whole state machine. + */ + +void ax25_std_nr_error_recovery(ax25_cb *ax25) +{ + ax25_std_establish_data_link(ax25); +} + +void ax25_std_establish_data_link(ax25_cb *ax25) +{ + ax25->condition = 0x00; + ax25->n2count = 0; + + if (ax25->modulus == AX25_MODULUS) + ax25_send_control(ax25, AX25_SABM, AX25_POLLON, AX25_COMMAND); + else + ax25_send_control(ax25, AX25_SABME, AX25_POLLON, AX25_COMMAND); + + ax25_calculate_t1(ax25); + ax25_stop_idletimer(ax25); + ax25_stop_t3timer(ax25); + ax25_stop_t2timer(ax25); + ax25_start_t1timer(ax25); +} + +void ax25_std_transmit_enquiry(ax25_cb *ax25) +{ + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25_send_control(ax25, AX25_RNR, AX25_POLLON, AX25_COMMAND); + else + ax25_send_control(ax25, AX25_RR, AX25_POLLON, AX25_COMMAND); + + ax25->condition &= ~AX25_COND_ACK_PENDING; + + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); +} + +void ax25_std_enquiry_response(ax25_cb *ax25) +{ + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25_send_control(ax25, AX25_RNR, AX25_POLLON, AX25_RESPONSE); + else + ax25_send_control(ax25, AX25_RR, AX25_POLLON, AX25_RESPONSE); + + ax25->condition &= ~AX25_COND_ACK_PENDING; +} + +void ax25_std_timeout_response(ax25_cb *ax25) +{ + if (ax25->condition & AX25_COND_OWN_RX_BUSY) + ax25_send_control(ax25, AX25_RNR, AX25_POLLOFF, AX25_RESPONSE); + else + ax25_send_control(ax25, AX25_RR, AX25_POLLOFF, AX25_RESPONSE); + + ax25->condition &= ~AX25_COND_ACK_PENDING; +} diff --git a/net/ax25/ax25_std_timer.c b/net/ax25/ax25_std_timer.c new file mode 100644 index 000000000..b17da4121 --- /dev/null +++ b/net/ax25/ax25_std_timer.c @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Frederic Rible F1OAT (frible@teaser.fr) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void ax25_std_heartbeat_expiry(ax25_cb *ax25) +{ + struct sock *sk = ax25->sk; + + if (sk) + bh_lock_sock(sk); + + switch (ax25->state) { 
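+ /*
+ * States 0 and 2: reap control blocks whose socket is gone or marked
+ * for destruction.  States 3 and 4: clear our own receiver-busy
+ * condition once the socket receive buffer has drained below half of
+ * sk_rcvbuf.
+ */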
+ case AX25_STATE_0: + case AX25_STATE_2: + /* Magic here: If we listen() and a new link dies before it + is accepted() it isn't 'dead' so doesn't get removed. */ + if (!sk || sock_flag(sk, SOCK_DESTROY) || + (sk->sk_state == TCP_LISTEN && + sock_flag(sk, SOCK_DEAD))) { + if (sk) { + sock_hold(sk); + ax25_destroy_socket(ax25); + bh_unlock_sock(sk); + /* Ungrab socket and destroy it */ + sock_put(sk); + } else + ax25_destroy_socket(ax25); + return; + } + break; + + case AX25_STATE_3: + case AX25_STATE_4: + /* + * Check the state of the receive buffer. + */ + if (sk != NULL) { + if (atomic_read(&sk->sk_rmem_alloc) < + (sk->sk_rcvbuf >> 1) && + (ax25->condition & AX25_COND_OWN_RX_BUSY)) { + ax25->condition &= ~AX25_COND_OWN_RX_BUSY; + ax25->condition &= ~AX25_COND_ACK_PENDING; + ax25_send_control(ax25, AX25_RR, AX25_POLLOFF, AX25_RESPONSE); + break; + } + } + } + + if (sk) + bh_unlock_sock(sk); + + ax25_start_heartbeat(ax25); +} + +void ax25_std_t2timer_expiry(ax25_cb *ax25) +{ + if (ax25->condition & AX25_COND_ACK_PENDING) { + ax25->condition &= ~AX25_COND_ACK_PENDING; + ax25_std_timeout_response(ax25); + } +} + +void ax25_std_t3timer_expiry(ax25_cb *ax25) +{ + ax25->n2count = 0; + ax25_std_transmit_enquiry(ax25); + ax25->state = AX25_STATE_4; +} + +void ax25_std_idletimer_expiry(ax25_cb *ax25) +{ + ax25_clear_queues(ax25); + + ax25->n2count = 0; + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + ax25->state = AX25_STATE_2; + + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_stop_t3timer(ax25); + + if (ax25->sk != NULL) { + bh_lock_sock(ax25->sk); + ax25->sk->sk_state = TCP_CLOSE; + ax25->sk->sk_err = 0; + ax25->sk->sk_shutdown |= SEND_SHUTDOWN; + if (!sock_flag(ax25->sk, SOCK_DEAD)) { + ax25->sk->sk_state_change(ax25->sk); + sock_set_flag(ax25->sk, SOCK_DEAD); + } + bh_unlock_sock(ax25->sk); + } +} + +void ax25_std_t1timer_expiry(ax25_cb *ax25) +{ + switch (ax25->state) { + case AX25_STATE_1: + if (ax25->n2count == ax25->n2) { + if (ax25->modulus == AX25_MODULUS) { + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->modulus = AX25_MODULUS; + ax25->window = ax25->ax25_dev->values[AX25_VALUES_WINDOW]; + ax25->n2count = 0; + ax25_send_control(ax25, AX25_SABM, AX25_POLLON, AX25_COMMAND); + } + } else { + ax25->n2count++; + if (ax25->modulus == AX25_MODULUS) + ax25_send_control(ax25, AX25_SABM, AX25_POLLON, AX25_COMMAND); + else + ax25_send_control(ax25, AX25_SABME, AX25_POLLON, AX25_COMMAND); + } + break; + + case AX25_STATE_2: + if (ax25->n2count == ax25->n2) { + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + if (!sock_flag(ax25->sk, SOCK_DESTROY)) + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->n2count++; + ax25_send_control(ax25, AX25_DISC, AX25_POLLON, AX25_COMMAND); + } + break; + + case AX25_STATE_3: + ax25->n2count = 1; + ax25_std_transmit_enquiry(ax25); + ax25->state = AX25_STATE_4; + break; + + case AX25_STATE_4: + if (ax25->n2count == ax25->n2) { + ax25_send_control(ax25, AX25_DM, AX25_POLLON, AX25_RESPONSE); + ax25_disconnect(ax25, ETIMEDOUT); + return; + } else { + ax25->n2count++; + ax25_std_transmit_enquiry(ax25); + } + break; + } + + ax25_calculate_t1(ax25); + ax25_start_t1timer(ax25); +} diff --git a/net/ax25/ax25_subr.c b/net/ax25/ax25_subr.c new file mode 100644 index 000000000..9ff98f46d --- /dev/null +++ b/net/ax25/ax25_subr.c @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright 
(C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Frederic Rible F1OAT (frible@teaser.fr) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This routine purges all the queues of frames. + */ +void ax25_clear_queues(ax25_cb *ax25) +{ + skb_queue_purge(&ax25->write_queue); + skb_queue_purge(&ax25->ack_queue); + skb_queue_purge(&ax25->reseq_queue); + skb_queue_purge(&ax25->frag_queue); +} + +/* + * This routine purges the input queue of those frames that have been + * acknowledged. This replaces the boxes labelled "V(a) <- N(r)" on the + * SDL diagram. + */ +void ax25_frames_acked(ax25_cb *ax25, unsigned short nr) +{ + struct sk_buff *skb; + + /* + * Remove all the ack-ed frames from the ack queue. + */ + if (ax25->va != nr) { + while (skb_peek(&ax25->ack_queue) != NULL && ax25->va != nr) { + skb = skb_dequeue(&ax25->ack_queue); + kfree_skb(skb); + ax25->va = (ax25->va + 1) % ax25->modulus; + } + } +} + +void ax25_requeue_frames(ax25_cb *ax25) +{ + struct sk_buff *skb; + + /* + * Requeue all the un-ack-ed frames on the output queue to be picked + * up by ax25_kick called from the timer. This arrangement handles the + * possibility of an empty output queue. + */ + while ((skb = skb_dequeue_tail(&ax25->ack_queue)) != NULL) + skb_queue_head(&ax25->write_queue, skb); +} + +/* + * Validate that the value of nr is between va and vs. Return true or + * false for testing. + */ +int ax25_validate_nr(ax25_cb *ax25, unsigned short nr) +{ + unsigned short vc = ax25->va; + + while (vc != ax25->vs) { + if (nr == vc) return 1; + vc = (vc + 1) % ax25->modulus; + } + + if (nr == ax25->vs) return 1; + + return 0; +} + +/* + * This routine is the centralised routine for parsing the control + * information for the different frame formats. + */ +int ax25_decode(ax25_cb *ax25, struct sk_buff *skb, int *ns, int *nr, int *pf) +{ + unsigned char *frame; + int frametype = AX25_ILLEGAL; + + frame = skb->data; + *ns = *nr = *pf = 0; + + if (ax25->modulus == AX25_MODULUS) { + if ((frame[0] & AX25_S) == 0) { + frametype = AX25_I; /* I frame - carries NR/NS/PF */ + *ns = (frame[0] >> 1) & 0x07; + *nr = (frame[0] >> 5) & 0x07; + *pf = frame[0] & AX25_PF; + } else if ((frame[0] & AX25_U) == 1) { /* S frame - take out PF/NR */ + frametype = frame[0] & 0x0F; + *nr = (frame[0] >> 5) & 0x07; + *pf = frame[0] & AX25_PF; + } else if ((frame[0] & AX25_U) == 3) { /* U frame - take out PF */ + frametype = frame[0] & ~AX25_PF; + *pf = frame[0] & AX25_PF; + } + skb_pull(skb, 1); + } else { + if ((frame[0] & AX25_S) == 0) { + frametype = AX25_I; /* I frame - carries NR/NS/PF */ + *ns = (frame[0] >> 1) & 0x7F; + *nr = (frame[1] >> 1) & 0x7F; + *pf = frame[1] & AX25_EPF; + skb_pull(skb, 2); + } else if ((frame[0] & AX25_U) == 1) { /* S frame - take out PF/NR */ + frametype = frame[0] & 0x0F; + *nr = (frame[1] >> 1) & 0x7F; + *pf = frame[1] & AX25_EPF; + skb_pull(skb, 2); + } else if ((frame[0] & AX25_U) == 3) { /* U frame - take out PF */ + frametype = frame[0] & ~AX25_PF; + *pf = frame[0] & AX25_PF; + skb_pull(skb, 1); + } + } + + return frametype; +} + +/* + * This routine is called when the HDLC layer internally generates a + * command or response for the remote machine ( eg. RR, UA etc. ). + * Only supervisory or unnumbered frames are processed. 
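+ * For modulus-8 links the whole control field fits in one octet, with
+ * N(R) in the top three bits for supervisory frames; modulus-128 links
+ * use a two-octet control field for supervisory frames while unnumbered
+ * frames keep the single octet.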
+ */ +void ax25_send_control(ax25_cb *ax25, int frametype, int poll_bit, int type) +{ + struct sk_buff *skb; + unsigned char *dptr; + + if ((skb = alloc_skb(ax25->ax25_dev->dev->hard_header_len + 2, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, ax25->ax25_dev->dev->hard_header_len); + + skb_reset_network_header(skb); + + /* Assume a response - address structure for DTE */ + if (ax25->modulus == AX25_MODULUS) { + dptr = skb_put(skb, 1); + *dptr = frametype; + *dptr |= (poll_bit) ? AX25_PF : 0; + if ((frametype & AX25_U) == AX25_S) /* S frames carry NR */ + *dptr |= (ax25->vr << 5); + } else { + if ((frametype & AX25_U) == AX25_U) { + dptr = skb_put(skb, 1); + *dptr = frametype; + *dptr |= (poll_bit) ? AX25_PF : 0; + } else { + dptr = skb_put(skb, 2); + dptr[0] = frametype; + dptr[1] = (ax25->vr << 1); + dptr[1] |= (poll_bit) ? AX25_EPF : 0; + } + } + + ax25_transmit_buffer(ax25, skb, type); +} + +/* + * Send a 'DM' to an unknown connection attempt, or an invalid caller. + * + * Note: src here is the sender, thus it's the target of the DM + */ +void ax25_return_dm(struct net_device *dev, ax25_address *src, ax25_address *dest, ax25_digi *digi) +{ + struct sk_buff *skb; + char *dptr; + ax25_digi retdigi; + + if (dev == NULL) + return; + + if ((skb = alloc_skb(dev->hard_header_len + 1, GFP_ATOMIC)) == NULL) + return; /* Next SABM will get DM'd */ + + skb_reserve(skb, dev->hard_header_len); + skb_reset_network_header(skb); + + ax25_digi_invert(digi, &retdigi); + + dptr = skb_put(skb, 1); + + *dptr = AX25_DM | AX25_PF; + + /* + * Do the address ourselves + */ + dptr = skb_push(skb, ax25_addr_size(digi)); + dptr += ax25_addr_build(dptr, dest, src, &retdigi, AX25_RESPONSE, AX25_MODULUS); + + ax25_queue_xmit(skb, dev); +} + +/* + * Exponential backoff for AX.25 + */ +void ax25_calculate_t1(ax25_cb *ax25) +{ + int n, t = 2; + + switch (ax25->backoff) { + case 0: + break; + + case 1: + t += 2 * ax25->n2count; + break; + + case 2: + for (n = 0; n < ax25->n2count; n++) + t *= 2; + if (t > 8) t = 8; + break; + } + + ax25->t1 = t * ax25->rtt; +} + +/* + * Calculate the Round Trip Time + */ +void ax25_calculate_rtt(ax25_cb *ax25) +{ + if (ax25->backoff == 0) + return; + + if (ax25_t1timer_running(ax25) && ax25->n2count == 0) + ax25->rtt = (9 * ax25->rtt + ax25->t1 - ax25_display_timer(&ax25->t1timer)) / 10; + + if (ax25->rtt < AX25_T1CLAMPLO) + ax25->rtt = AX25_T1CLAMPLO; + + if (ax25->rtt > AX25_T1CLAMPHI) + ax25->rtt = AX25_T1CLAMPHI; +} + +void ax25_disconnect(ax25_cb *ax25, int reason) +{ + ax25_clear_queues(ax25); + + if (reason == ENETUNREACH) { + del_timer_sync(&ax25->timer); + del_timer_sync(&ax25->t1timer); + del_timer_sync(&ax25->t2timer); + del_timer_sync(&ax25->t3timer); + del_timer_sync(&ax25->idletimer); + } else { + if (ax25->sk && !sock_flag(ax25->sk, SOCK_DESTROY)) + ax25_stop_heartbeat(ax25); + ax25_stop_t1timer(ax25); + ax25_stop_t2timer(ax25); + ax25_stop_t3timer(ax25); + ax25_stop_idletimer(ax25); + } + + ax25->state = AX25_STATE_0; + + ax25_link_failed(ax25, reason); + + if (ax25->sk != NULL) { + local_bh_disable(); + bh_lock_sock(ax25->sk); + ax25->sk->sk_state = TCP_CLOSE; + ax25->sk->sk_err = reason; + ax25->sk->sk_shutdown |= SEND_SHUTDOWN; + if (!sock_flag(ax25->sk, SOCK_DEAD)) { + ax25->sk->sk_state_change(ax25->sk); + sock_set_flag(ax25->sk, SOCK_DEAD); + } + bh_unlock_sock(ax25->sk); + local_bh_enable(); + } +} diff --git a/net/ax25/ax25_timer.c b/net/ax25/ax25_timer.c new file mode 100644 index 000000000..9f7cb0a7c --- /dev/null +++ b/net/ax25/ax25_timer.c @@ -0,0 
+1,224 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Tomi Manninen OH2BNS (oh2bns@sral.fi) + * Copyright (C) Darryl Miles G7LED (dlm@g7led.demon.co.uk) + * Copyright (C) Joerg Reuter DL1BKE (jreuter@yaina.de) + * Copyright (C) Frederic Rible F1OAT (frible@teaser.fr) + * Copyright (C) 2002 Ralf Baechle DO1GRB (ralf@gnu.org) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void ax25_heartbeat_expiry(struct timer_list *); +static void ax25_t1timer_expiry(struct timer_list *); +static void ax25_t2timer_expiry(struct timer_list *); +static void ax25_t3timer_expiry(struct timer_list *); +static void ax25_idletimer_expiry(struct timer_list *); + +void ax25_setup_timers(ax25_cb *ax25) +{ + timer_setup(&ax25->timer, ax25_heartbeat_expiry, 0); + timer_setup(&ax25->t1timer, ax25_t1timer_expiry, 0); + timer_setup(&ax25->t2timer, ax25_t2timer_expiry, 0); + timer_setup(&ax25->t3timer, ax25_t3timer_expiry, 0); + timer_setup(&ax25->idletimer, ax25_idletimer_expiry, 0); +} + +void ax25_start_heartbeat(ax25_cb *ax25) +{ + mod_timer(&ax25->timer, jiffies + 5 * HZ); +} + +void ax25_start_t1timer(ax25_cb *ax25) +{ + mod_timer(&ax25->t1timer, jiffies + ax25->t1); +} + +void ax25_start_t2timer(ax25_cb *ax25) +{ + mod_timer(&ax25->t2timer, jiffies + ax25->t2); +} + +void ax25_start_t3timer(ax25_cb *ax25) +{ + if (ax25->t3 > 0) + mod_timer(&ax25->t3timer, jiffies + ax25->t3); + else + del_timer(&ax25->t3timer); +} + +void ax25_start_idletimer(ax25_cb *ax25) +{ + if (ax25->idle > 0) + mod_timer(&ax25->idletimer, jiffies + ax25->idle); + else + del_timer(&ax25->idletimer); +} + +void ax25_stop_heartbeat(ax25_cb *ax25) +{ + del_timer(&ax25->timer); +} + +void ax25_stop_t1timer(ax25_cb *ax25) +{ + del_timer(&ax25->t1timer); +} + +void ax25_stop_t2timer(ax25_cb *ax25) +{ + del_timer(&ax25->t2timer); +} + +void ax25_stop_t3timer(ax25_cb *ax25) +{ + del_timer(&ax25->t3timer); +} + +void ax25_stop_idletimer(ax25_cb *ax25) +{ + del_timer(&ax25->idletimer); +} + +int ax25_t1timer_running(ax25_cb *ax25) +{ + return timer_pending(&ax25->t1timer); +} + +unsigned long ax25_display_timer(struct timer_list *timer) +{ + long delta = timer->expires - jiffies; + + if (!timer_pending(timer)) + return 0; + + return max(0L, delta); +} + +EXPORT_SYMBOL(ax25_display_timer); + +static void ax25_heartbeat_expiry(struct timer_list *t) +{ + int proto = AX25_PROTO_STD_SIMPLEX; + ax25_cb *ax25 = from_timer(ax25, t, timer); + + if (ax25->ax25_dev) + proto = ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]; + + switch (proto) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_heartbeat_expiry(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (ax25->ax25_dev->dama.slave) + ax25_ds_heartbeat_expiry(ax25); + else + ax25_std_heartbeat_expiry(ax25); + break; +#endif + } +} + +static void ax25_t1timer_expiry(struct timer_list *t) +{ + ax25_cb *ax25 = from_timer(ax25, t, t1timer); + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_t1timer_expiry(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (!ax25->ax25_dev->dama.slave) + ax25_std_t1timer_expiry(ax25); + break; 
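+ /* The standard T1 handling is bypassed entirely while the
+ * interface is operating as a DAMA slave.
+ */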
+#endif + } +} + +static void ax25_t2timer_expiry(struct timer_list *t) +{ + ax25_cb *ax25 = from_timer(ax25, t, t2timer); + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_t2timer_expiry(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (!ax25->ax25_dev->dama.slave) + ax25_std_t2timer_expiry(ax25); + break; +#endif + } +} + +static void ax25_t3timer_expiry(struct timer_list *t) +{ + ax25_cb *ax25 = from_timer(ax25, t, t3timer); + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_t3timer_expiry(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (ax25->ax25_dev->dama.slave) + ax25_ds_t3timer_expiry(ax25); + else + ax25_std_t3timer_expiry(ax25); + break; +#endif + } +} + +static void ax25_idletimer_expiry(struct timer_list *t) +{ + ax25_cb *ax25 = from_timer(ax25, t, idletimer); + + switch (ax25->ax25_dev->values[AX25_VALUES_PROTOCOL]) { + case AX25_PROTO_STD_SIMPLEX: + case AX25_PROTO_STD_DUPLEX: + ax25_std_idletimer_expiry(ax25); + break; + +#ifdef CONFIG_AX25_DAMA_SLAVE + case AX25_PROTO_DAMA_SLAVE: + if (ax25->ax25_dev->dama.slave) + ax25_ds_idletimer_expiry(ax25); + else + ax25_std_idletimer_expiry(ax25); + break; +#endif + } +} diff --git a/net/ax25/ax25_uid.c b/net/ax25/ax25_uid.c new file mode 100644 index 000000000..241e4680e --- /dev/null +++ b/net/ax25/ax25_uid.c @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Callsign/UID mapper. This is in kernel space for security on multi-amateur machines. 
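+ * Each association maps a kernel uid to an AX.25 callsign.  Whether a
+ * user without an entry may still use the port's own callsign is
+ * governed by ax25_uid_policy (see ax25_rt_autobind).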
+ */ + +static HLIST_HEAD(ax25_uid_list); +static DEFINE_RWLOCK(ax25_uid_lock); + +int ax25_uid_policy; + +EXPORT_SYMBOL(ax25_uid_policy); + +ax25_uid_assoc *ax25_findbyuid(kuid_t uid) +{ + ax25_uid_assoc *ax25_uid, *res = NULL; + + read_lock(&ax25_uid_lock); + ax25_uid_for_each(ax25_uid, &ax25_uid_list) { + if (uid_eq(ax25_uid->uid, uid)) { + ax25_uid_hold(ax25_uid); + res = ax25_uid; + break; + } + } + read_unlock(&ax25_uid_lock); + + return res; +} + +EXPORT_SYMBOL(ax25_findbyuid); + +int ax25_uid_ioctl(int cmd, struct sockaddr_ax25 *sax) +{ + ax25_uid_assoc *ax25_uid; + ax25_uid_assoc *user; + unsigned long res; + + switch (cmd) { + case SIOCAX25GETUID: + res = -ENOENT; + read_lock(&ax25_uid_lock); + ax25_uid_for_each(ax25_uid, &ax25_uid_list) { + if (ax25cmp(&sax->sax25_call, &ax25_uid->call) == 0) { + res = from_kuid_munged(current_user_ns(), ax25_uid->uid); + break; + } + } + read_unlock(&ax25_uid_lock); + + return res; + + case SIOCAX25ADDUID: + { + kuid_t sax25_kuid; + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + sax25_kuid = make_kuid(current_user_ns(), sax->sax25_uid); + if (!uid_valid(sax25_kuid)) + return -EINVAL; + user = ax25_findbyuid(sax25_kuid); + if (user) { + ax25_uid_put(user); + return -EEXIST; + } + if (sax->sax25_uid == 0) + return -EINVAL; + if ((ax25_uid = kmalloc(sizeof(*ax25_uid), GFP_KERNEL)) == NULL) + return -ENOMEM; + + refcount_set(&ax25_uid->refcount, 1); + ax25_uid->uid = sax25_kuid; + ax25_uid->call = sax->sax25_call; + + write_lock(&ax25_uid_lock); + hlist_add_head(&ax25_uid->uid_node, &ax25_uid_list); + write_unlock(&ax25_uid_lock); + + return 0; + } + case SIOCAX25DELUID: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + ax25_uid = NULL; + write_lock(&ax25_uid_lock); + ax25_uid_for_each(ax25_uid, &ax25_uid_list) { + if (ax25cmp(&sax->sax25_call, &ax25_uid->call) == 0) + break; + } + if (ax25_uid == NULL) { + write_unlock(&ax25_uid_lock); + return -ENOENT; + } + hlist_del_init(&ax25_uid->uid_node); + ax25_uid_put(ax25_uid); + write_unlock(&ax25_uid_lock); + + return 0; + + default: + return -EINVAL; + } + + return -EINVAL; /*NOTREACHED */ +} + +#ifdef CONFIG_PROC_FS + +static void *ax25_uid_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(ax25_uid_lock) +{ + read_lock(&ax25_uid_lock); + return seq_hlist_start_head(&ax25_uid_list, *pos); +} + +static void *ax25_uid_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &ax25_uid_list, pos); +} + +static void ax25_uid_seq_stop(struct seq_file *seq, void *v) + __releases(ax25_uid_lock) +{ + read_unlock(&ax25_uid_lock); +} + +static int ax25_uid_seq_show(struct seq_file *seq, void *v) +{ + char buf[11]; + + if (v == SEQ_START_TOKEN) + seq_printf(seq, "Policy: %d\n", ax25_uid_policy); + else { + struct ax25_uid_assoc *pt; + + pt = hlist_entry(v, struct ax25_uid_assoc, uid_node); + seq_printf(seq, "%6d %s\n", + from_kuid_munged(seq_user_ns(seq), pt->uid), + ax2asc(buf, &pt->call)); + } + return 0; +} + +const struct seq_operations ax25_uid_seqops = { + .start = ax25_uid_seq_start, + .next = ax25_uid_seq_next, + .stop = ax25_uid_seq_stop, + .show = ax25_uid_seq_show, +}; +#endif + +/* + * Free all memory associated with UID/Callsign structures. 
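+ * Called on module unload; every association is unhashed and its
+ * reference dropped.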
+ */ +void __exit ax25_uid_free(void) +{ + ax25_uid_assoc *ax25_uid; + + write_lock(&ax25_uid_lock); +again: + ax25_uid_for_each(ax25_uid, &ax25_uid_list) { + hlist_del_init(&ax25_uid->uid_node); + ax25_uid_put(ax25_uid); + goto again; + } + write_unlock(&ax25_uid_lock); +} diff --git a/net/ax25/sysctl_net_ax25.c b/net/ax25/sysctl_net_ax25.c new file mode 100644 index 000000000..2154d004d --- /dev/null +++ b/net/ax25/sysctl_net_ax25.c @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) 1996 Mike Shaver (shaver@zeroknowledge.com) + */ +#include +#include +#include +#include +#include + +static int min_ipdefmode[1], max_ipdefmode[] = {1}; +static int min_axdefmode[1], max_axdefmode[] = {1}; +static int min_backoff[1], max_backoff[] = {2}; +static int min_conmode[1], max_conmode[] = {2}; +static int min_window[] = {1}, max_window[] = {7}; +static int min_ewindow[] = {1}, max_ewindow[] = {63}; +static int min_t1[] = {1}, max_t1[] = {30000}; +static int min_t2[] = {1}, max_t2[] = {20000}; +static int min_t3[1], max_t3[] = {3600000}; +static int min_idle[1], max_idle[] = {65535000}; +static int min_n2[] = {1}, max_n2[] = {31}; +static int min_paclen[] = {1}, max_paclen[] = {512}; +static int min_proto[1], max_proto[] = { AX25_PROTO_MAX }; +#ifdef CONFIG_AX25_DAMA_SLAVE +static int min_ds_timeout[1], max_ds_timeout[] = {65535000}; +#endif + +static const struct ctl_table ax25_param_table[] = { + { + .procname = "ip_default_mode", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ipdefmode, + .extra2 = &max_ipdefmode + }, + { + .procname = "ax25_default_mode", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_axdefmode, + .extra2 = &max_axdefmode + }, + { + .procname = "backoff_type", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_backoff, + .extra2 = &max_backoff + }, + { + .procname = "connect_mode", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_conmode, + .extra2 = &max_conmode + }, + { + .procname = "standard_window_size", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_window, + .extra2 = &max_window + }, + { + .procname = "extended_window_size", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ewindow, + .extra2 = &max_ewindow + }, + { + .procname = "t1_timeout", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t1, + .extra2 = &max_t1 + }, + { + .procname = "t2_timeout", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t2, + .extra2 = &max_t2 + }, + { + .procname = "t3_timeout", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t3, + .extra2 = &max_t3 + }, + { + .procname = "idle_timeout", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_idle, + .extra2 = &max_idle + }, + { + .procname = "maximum_retry_count", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_n2, + .extra2 = &max_n2 + }, + { + .procname = "maximum_packet_length", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_paclen, + .extra2 = &max_paclen + }, + { + .procname = "protocol", + .maxlen = sizeof(int), + .mode 
= 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_proto, + .extra2 = &max_proto + }, +#ifdef CONFIG_AX25_DAMA_SLAVE + { + .procname = "dama_slave_timeout", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ds_timeout, + .extra2 = &max_ds_timeout + }, +#endif + + { } /* that's all, folks! */ +}; + +int ax25_register_dev_sysctl(ax25_dev *ax25_dev) +{ + char path[sizeof("net/ax25/") + IFNAMSIZ]; + int k; + struct ctl_table *table; + + table = kmemdup(ax25_param_table, sizeof(ax25_param_table), GFP_KERNEL); + if (!table) + return -ENOMEM; + + for (k = 0; k < AX25_MAX_VALUES; k++) + table[k].data = &ax25_dev->values[k]; + + snprintf(path, sizeof(path), "net/ax25/%s", ax25_dev->dev->name); + ax25_dev->sysheader = register_net_sysctl(&init_net, path, table); + if (!ax25_dev->sysheader) { + kfree(table); + return -ENOMEM; + } + return 0; +} + +void ax25_unregister_dev_sysctl(ax25_dev *ax25_dev) +{ + struct ctl_table_header *header = ax25_dev->sysheader; + struct ctl_table *table; + + if (header) { + ax25_dev->sysheader = NULL; + table = header->ctl_table_arg; + unregister_net_sysctl_table(header); + kfree(table); + } +} diff --git a/net/batman-adv/Kconfig b/net/batman-adv/Kconfig new file mode 100644 index 000000000..860a0786b --- /dev/null +++ b/net/batman-adv/Kconfig @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) B.A.T.M.A.N. contributors: +# +# Marek Lindner, Simon Wunderlich + +# +# B.A.T.M.A.N meshing protocol +# + +config BATMAN_ADV + tristate "B.A.T.M.A.N. Advanced Meshing Protocol" + select LIBCRC32C + help + B.A.T.M.A.N. (better approach to mobile ad-hoc networking) is + a routing protocol for multi-hop ad-hoc mesh networks. The + networks may be wired or wireless. See + https://www.open-mesh.org/ for more information and user space + tools. + +config BATMAN_ADV_BATMAN_V + bool "B.A.T.M.A.N. V protocol" + depends on BATMAN_ADV && !(CFG80211=m && BATMAN_ADV=y) + default y + help + This option enables the B.A.T.M.A.N. V protocol, the successor + of the currently used B.A.T.M.A.N. IV protocol. The main + changes include splitting of the OGM protocol into a neighbor + discovery protocol (Echo Location Protocol, ELP) and a new OGM + Protocol OGMv2 for flooding protocol information through the + network, as well as a throughput based metric. + B.A.T.M.A.N. V is currently considered experimental and not + compatible to B.A.T.M.A.N. IV networks. + +config BATMAN_ADV_BLA + bool "Bridge Loop Avoidance" + depends on BATMAN_ADV && INET + select CRC16 + default y + help + This option enables BLA (Bridge Loop Avoidance), a mechanism + to avoid Ethernet frames looping when mesh nodes are connected + to both the same LAN and the same mesh. If you will never use + more than one mesh node in the same LAN, you can safely remove + this feature and save some space. + +config BATMAN_ADV_DAT + bool "Distributed ARP Table" + depends on BATMAN_ADV && INET + default y + help + This option enables DAT (Distributed ARP Table), a DHT based + mechanism that increases ARP reliability on sparse wireless + mesh networks. If you think that your network does not need + this option you can safely remove it and save some space. + +config BATMAN_ADV_NC + bool "Network Coding" + depends on BATMAN_ADV + help + This option enables network coding, a mechanism that aims to + increase the overall network throughput by fusing multiple + packets in one transmission. 
+ Note that interfaces controlled by batman-adv must be manually + configured to have promiscuous mode enabled in order to make + network coding work. + If you think that your network does not need this feature you + can safely disable it and save some space. + +config BATMAN_ADV_MCAST + bool "Multicast optimisation" + depends on BATMAN_ADV && INET && !(BRIDGE=m && BATMAN_ADV=y) + default y + help + This option enables the multicast optimisation which aims to + reduce the air overhead while improving the reliability of + multicast messages. + +config BATMAN_ADV_DEBUG + bool "B.A.T.M.A.N. debugging" + depends on BATMAN_ADV + help + This is an option for use by developers; most people should + say N here. This enables compilation of support for + outputting debugging information to the tracing buffer. The output is + controlled via the batadv netdev specific log_level setting. + +config BATMAN_ADV_TRACING + bool "B.A.T.M.A.N. tracing support" + depends on BATMAN_ADV + depends on EVENT_TRACING + help + This is an option for use by developers; most people should + say N here. Select this option to gather traces like the debug + messages using the generic tracing infrastructure of the kernel. + BATMAN_ADV_DEBUG must also be selected to get trace events for + batadv_dbg. diff --git a/net/batman-adv/Makefile b/net/batman-adv/Makefile new file mode 100644 index 000000000..3bd0760c7 --- /dev/null +++ b/net/batman-adv/Makefile @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: GPL-2.0 +# Copyright (C) B.A.T.M.A.N. contributors: +# +# Marek Lindner, Simon Wunderlich + +obj-$(CONFIG_BATMAN_ADV) += batman-adv.o +batman-adv-y += bat_algo.o +batman-adv-y += bat_iv_ogm.o +batman-adv-$(CONFIG_BATMAN_ADV_BATMAN_V) += bat_v.o +batman-adv-$(CONFIG_BATMAN_ADV_BATMAN_V) += bat_v_elp.o +batman-adv-$(CONFIG_BATMAN_ADV_BATMAN_V) += bat_v_ogm.o +batman-adv-y += bitarray.o +batman-adv-$(CONFIG_BATMAN_ADV_BLA) += bridge_loop_avoidance.o +batman-adv-$(CONFIG_BATMAN_ADV_DAT) += distributed-arp-table.o +batman-adv-y += fragmentation.o +batman-adv-y += gateway_client.o +batman-adv-y += gateway_common.o +batman-adv-y += hard-interface.o +batman-adv-y += hash.o +batman-adv-$(CONFIG_BATMAN_ADV_DEBUG) += log.o +batman-adv-y += main.o +batman-adv-$(CONFIG_BATMAN_ADV_MCAST) += multicast.o +batman-adv-y += netlink.o +batman-adv-$(CONFIG_BATMAN_ADV_NC) += network-coding.o +batman-adv-y += originator.o +batman-adv-y += routing.o +batman-adv-y += send.o +batman-adv-y += soft-interface.o +batman-adv-$(CONFIG_BATMAN_ADV_TRACING) += trace.o +batman-adv-y += tp_meter.o +batman-adv-y += translation-table.o +batman-adv-y += tvlv.o + +CFLAGS_trace.o := -I$(src) diff --git a/net/batman-adv/bat_algo.c b/net/batman-adv/bat_algo.c new file mode 100644 index 000000000..4eee53d19 --- /dev/null +++ b/net/batman-adv/bat_algo.c @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "netlink.h" + +char batadv_routing_algo[20] = "BATMAN_IV"; +static struct hlist_head batadv_algo_list; + +/** + * batadv_algo_init() - Initialize batman-adv algorithm management data + * structures + */ +void batadv_algo_init(void) +{ + INIT_HLIST_HEAD(&batadv_algo_list); +} + +/** + * batadv_algo_get() - Search for algorithm with specific name + * @name: algorithm name to find + * + * Return: Pointer to batadv_algo_ops on success, NULL otherwise + */ +struct batadv_algo_ops *batadv_algo_get(const char *name) +{ + struct batadv_algo_ops *bat_algo_ops = NULL, *bat_algo_ops_tmp; + + hlist_for_each_entry(bat_algo_ops_tmp, &batadv_algo_list, list) { + if (strcmp(bat_algo_ops_tmp->name, name) != 0) + continue; + + bat_algo_ops = bat_algo_ops_tmp; + break; + } + + return bat_algo_ops; +} + +/** + * batadv_algo_register() - Register callbacks for a mesh algorithm + * @bat_algo_ops: mesh algorithm callbacks to add + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_algo_register(struct batadv_algo_ops *bat_algo_ops) +{ + struct batadv_algo_ops *bat_algo_ops_tmp; + + bat_algo_ops_tmp = batadv_algo_get(bat_algo_ops->name); + if (bat_algo_ops_tmp) { + pr_info("Trying to register already registered routing algorithm: %s\n", + bat_algo_ops->name); + return -EEXIST; + } + + /* all algorithms must implement all ops (for now) */ + if (!bat_algo_ops->iface.enable || + !bat_algo_ops->iface.disable || + !bat_algo_ops->iface.update_mac || + !bat_algo_ops->iface.primary_set || + !bat_algo_ops->neigh.cmp || + !bat_algo_ops->neigh.is_similar_or_better) { + pr_info("Routing algo '%s' does not implement required ops\n", + bat_algo_ops->name); + return -EINVAL; + } + + INIT_HLIST_NODE(&bat_algo_ops->list); + hlist_add_head(&bat_algo_ops->list, &batadv_algo_list); + + return 0; +} + +/** + * batadv_algo_select() - Select algorithm of soft interface + * @bat_priv: the bat priv with all the soft interface information + * @name: name of the algorithm to select + * + * The algorithm callbacks for the soft interface will be set when the algorithm + * with the correct name was found. Any previous selected algorithm will not be + * deinitialized and the new selected algorithm will also not be initialized. + * It is therefore not allowed to call batadv_algo_select outside the creation + * function of the soft interface. 
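+ * The algorithm is looked up purely by name via batadv_algo_get() and,
+ * when found, stored in bat_priv->algo_ops.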
+ * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_algo_select(struct batadv_priv *bat_priv, const char *name) +{ + struct batadv_algo_ops *bat_algo_ops; + + bat_algo_ops = batadv_algo_get(name); + if (!bat_algo_ops) + return -EINVAL; + + bat_priv->algo_ops = bat_algo_ops; + + return 0; +} + +static int batadv_param_set_ra(const char *val, const struct kernel_param *kp) +{ + struct batadv_algo_ops *bat_algo_ops; + char *algo_name = (char *)val; + size_t name_len = strlen(algo_name); + + if (name_len > 0 && algo_name[name_len - 1] == '\n') + algo_name[name_len - 1] = '\0'; + + bat_algo_ops = batadv_algo_get(algo_name); + if (!bat_algo_ops) { + pr_err("Routing algorithm '%s' is not supported\n", algo_name); + return -EINVAL; + } + + return param_set_copystring(algo_name, kp); +} + +static const struct kernel_param_ops batadv_param_ops_ra = { + .set = batadv_param_set_ra, + .get = param_get_string, +}; + +static struct kparam_string batadv_param_string_ra = { + .maxlen = sizeof(batadv_routing_algo), + .string = batadv_routing_algo, +}; + +module_param_cb(routing_algo, &batadv_param_ops_ra, &batadv_param_string_ra, + 0644); + +/** + * batadv_algo_dump_entry() - fill in information about one supported routing + * algorithm + * @msg: netlink message to be sent back + * @portid: Port to reply to + * @seq: Sequence number of message + * @bat_algo_ops: Algorithm to be dumped + * + * Return: Error number, or 0 on success + */ +static int batadv_algo_dump_entry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_algo_ops *bat_algo_ops) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, + NLM_F_MULTI, BATADV_CMD_GET_ROUTING_ALGOS); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_string(msg, BATADV_ATTR_ALGO_NAME, bat_algo_ops->name)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_algo_dump() - fill in information about supported routing + * algorithms + * @msg: netlink message to be sent back + * @cb: Parameters to the netlink request + * + * Return: Length of reply message. + */ +int batadv_algo_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + int portid = NETLINK_CB(cb->skb).portid; + struct batadv_algo_ops *bat_algo_ops; + int skip = cb->args[0]; + int i = 0; + + hlist_for_each_entry(bat_algo_ops, &batadv_algo_list, list) { + if (i++ < skip) + continue; + + if (batadv_algo_dump_entry(msg, portid, cb->nlh->nlmsg_seq, + bat_algo_ops)) { + i--; + break; + } + } + + cb->args[0] = i; + + return msg->len; +} diff --git a/net/batman-adv/bat_algo.h b/net/batman-adv/bat_algo.h new file mode 100644 index 000000000..2c486374a --- /dev/null +++ b/net/batman-adv/bat_algo.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Linus Lüssing + */ + +#ifndef _NET_BATMAN_ADV_BAT_ALGO_H_ +#define _NET_BATMAN_ADV_BAT_ALGO_H_ + +#include "main.h" + +#include +#include +#include + +extern char batadv_routing_algo[]; +extern struct list_head batadv_hardif_list; + +void batadv_algo_init(void); +struct batadv_algo_ops *batadv_algo_get(const char *name); +int batadv_algo_register(struct batadv_algo_ops *bat_algo_ops); +int batadv_algo_select(struct batadv_priv *bat_priv, const char *name); +int batadv_algo_dump(struct sk_buff *msg, struct netlink_callback *cb); + +#endif /* _NET_BATMAN_ADV_BAT_ALGO_H_ */ diff --git a/net/batman-adv/bat_iv_ogm.c b/net/batman-adv/bat_iv_ogm.c new file mode 100644 index 000000000..7f6a7c96a --- /dev/null +++ b/net/batman-adv/bat_iv_ogm.c @@ -0,0 +1,2551 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "bat_iv_ogm.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bitarray.h" +#include "gateway_client.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "network-coding.h" +#include "originator.h" +#include "routing.h" +#include "send.h" +#include "translation-table.h" +#include "tvlv.h" + +static void batadv_iv_send_outstanding_bat_ogm_packet(struct work_struct *work); + +/** + * enum batadv_dup_status - duplicate status + */ +enum batadv_dup_status { + /** @BATADV_NO_DUP: the packet is no duplicate */ + BATADV_NO_DUP = 0, + + /** + * @BATADV_ORIG_DUP: OGM is a duplicate in the originator (but not for + * the neighbor) + */ + BATADV_ORIG_DUP, + + /** @BATADV_NEIGH_DUP: OGM is a duplicate for the neighbor */ + BATADV_NEIGH_DUP, + + /** + * @BATADV_PROTECTED: originator is currently protected (after reboot) + */ + BATADV_PROTECTED, +}; + +/** + * batadv_ring_buffer_set() - update the ring buffer with the given value + * @lq_recv: pointer to the ring buffer + * @lq_index: index to store the value at + * @value: value to store in the ring buffer + */ +static void batadv_ring_buffer_set(u8 lq_recv[], u8 *lq_index, u8 value) +{ + lq_recv[*lq_index] = value; + *lq_index = (*lq_index + 1) % BATADV_TQ_GLOBAL_WINDOW_SIZE; +} + +/** + * batadv_ring_buffer_avg() - compute the average of all non-zero values stored + * in the given ring buffer + * @lq_recv: pointer to the ring buffer + * + * Return: computed average value. + */ +static u8 batadv_ring_buffer_avg(const u8 lq_recv[]) +{ + const u8 *ptr; + u16 count = 0; + u16 i = 0; + u16 sum = 0; + + ptr = lq_recv; + + while (i < BATADV_TQ_GLOBAL_WINDOW_SIZE) { + if (*ptr != 0) { + count++; + sum += *ptr; + } + + i++; + ptr++; + } + + if (count == 0) + return 0; + + return (u8)(sum / count); +} + +/** + * batadv_iv_ogm_orig_get() - retrieve or create (if does not exist) an + * originator + * @bat_priv: the bat priv with all the soft interface information + * @addr: mac address of the originator + * + * Return: the originator object corresponding to the passed mac address or NULL + * on failure. + * If the object does not exist, it is created and initialised. 
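+ * The extra reference taken with kref_get() before batadv_hash_add()
+ * belongs to the hash table; both it and the reference from
+ * batadv_orig_node_new() are dropped again if the insertion fails.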
+ */ +static struct batadv_orig_node * +batadv_iv_ogm_orig_get(struct batadv_priv *bat_priv, const u8 *addr) +{ + struct batadv_orig_node *orig_node; + int hash_added; + + orig_node = batadv_orig_hash_find(bat_priv, addr); + if (orig_node) + return orig_node; + + orig_node = batadv_orig_node_new(bat_priv, addr); + if (!orig_node) + return NULL; + + spin_lock_init(&orig_node->bat_iv.ogm_cnt_lock); + + kref_get(&orig_node->refcount); + hash_added = batadv_hash_add(bat_priv->orig_hash, batadv_compare_orig, + batadv_choose_orig, orig_node, + &orig_node->hash_entry); + if (hash_added != 0) + goto free_orig_node_hash; + + return orig_node; + +free_orig_node_hash: + /* reference for batadv_hash_add */ + batadv_orig_node_put(orig_node); + /* reference from batadv_orig_node_new */ + batadv_orig_node_put(orig_node); + + return NULL; +} + +static struct batadv_neigh_node * +batadv_iv_ogm_neigh_new(struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr, + struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh) +{ + struct batadv_neigh_node *neigh_node; + + neigh_node = batadv_neigh_node_get_or_create(orig_node, + hard_iface, neigh_addr); + if (!neigh_node) + goto out; + + neigh_node->orig_node = orig_neigh; + +out: + return neigh_node; +} + +static int batadv_iv_ogm_iface_enable(struct batadv_hard_iface *hard_iface) +{ + struct batadv_ogm_packet *batadv_ogm_packet; + unsigned char *ogm_buff; + u32 random_seqno; + + mutex_lock(&hard_iface->bat_iv.ogm_buff_mutex); + + /* randomize initial seqno to avoid collision */ + get_random_bytes(&random_seqno, sizeof(random_seqno)); + atomic_set(&hard_iface->bat_iv.ogm_seqno, random_seqno); + + hard_iface->bat_iv.ogm_buff_len = BATADV_OGM_HLEN; + ogm_buff = kmalloc(hard_iface->bat_iv.ogm_buff_len, GFP_ATOMIC); + if (!ogm_buff) { + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); + return -ENOMEM; + } + + hard_iface->bat_iv.ogm_buff = ogm_buff; + + batadv_ogm_packet = (struct batadv_ogm_packet *)ogm_buff; + batadv_ogm_packet->packet_type = BATADV_IV_OGM; + batadv_ogm_packet->version = BATADV_COMPAT_VERSION; + batadv_ogm_packet->ttl = 2; + batadv_ogm_packet->flags = BATADV_NO_FLAGS; + batadv_ogm_packet->reserved = 0; + batadv_ogm_packet->tq = BATADV_TQ_MAX_VALUE; + + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); + + return 0; +} + +static void batadv_iv_ogm_iface_disable(struct batadv_hard_iface *hard_iface) +{ + mutex_lock(&hard_iface->bat_iv.ogm_buff_mutex); + + kfree(hard_iface->bat_iv.ogm_buff); + hard_iface->bat_iv.ogm_buff = NULL; + + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); +} + +static void batadv_iv_ogm_iface_update_mac(struct batadv_hard_iface *hard_iface) +{ + struct batadv_ogm_packet *batadv_ogm_packet; + void *ogm_buff; + + mutex_lock(&hard_iface->bat_iv.ogm_buff_mutex); + + ogm_buff = hard_iface->bat_iv.ogm_buff; + if (!ogm_buff) + goto unlock; + + batadv_ogm_packet = ogm_buff; + ether_addr_copy(batadv_ogm_packet->orig, + hard_iface->net_dev->dev_addr); + ether_addr_copy(batadv_ogm_packet->prev_sender, + hard_iface->net_dev->dev_addr); + +unlock: + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); +} + +static void +batadv_iv_ogm_primary_iface_set(struct batadv_hard_iface *hard_iface) +{ + struct batadv_ogm_packet *batadv_ogm_packet; + void *ogm_buff; + + mutex_lock(&hard_iface->bat_iv.ogm_buff_mutex); + + ogm_buff = hard_iface->bat_iv.ogm_buff; + if (!ogm_buff) + goto unlock; + + batadv_ogm_packet = ogm_buff; + batadv_ogm_packet->ttl = BATADV_TTL; + +unlock: + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); +} 
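+/* The interface helpers above all operate on the per-interface OGM
+ * buffer allocated in batadv_iv_ogm_iface_enable(); every access is
+ * serialised by bat_iv.ogm_buff_mutex and batadv_iv_ogm_iface_disable()
+ * frees the buffer under the same mutex.
+ */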
+ +/* when do we schedule our own ogm to be sent */ +static unsigned long +batadv_iv_ogm_emit_send_time(const struct batadv_priv *bat_priv) +{ + unsigned int msecs; + + msecs = atomic_read(&bat_priv->orig_interval) - BATADV_JITTER; + msecs += prandom_u32_max(2 * BATADV_JITTER); + + return jiffies + msecs_to_jiffies(msecs); +} + +/* when do we schedule a ogm packet to be sent */ +static unsigned long batadv_iv_ogm_fwd_send_time(void) +{ + return jiffies + msecs_to_jiffies(prandom_u32_max(BATADV_JITTER / 2)); +} + +/* apply hop penalty for a normal link */ +static u8 batadv_hop_penalty(u8 tq, const struct batadv_priv *bat_priv) +{ + int hop_penalty = atomic_read(&bat_priv->hop_penalty); + int new_tq; + + new_tq = tq * (BATADV_TQ_MAX_VALUE - hop_penalty); + new_tq /= BATADV_TQ_MAX_VALUE; + + return new_tq; +} + +/** + * batadv_iv_ogm_aggr_packet() - checks if there is another OGM attached + * @buff_pos: current position in the skb + * @packet_len: total length of the skb + * @ogm_packet: potential OGM in buffer + * + * Return: true if there is enough space for another OGM, false otherwise. + */ +static bool +batadv_iv_ogm_aggr_packet(int buff_pos, int packet_len, + const struct batadv_ogm_packet *ogm_packet) +{ + int next_buff_pos = 0; + + /* check if there is enough space for the header */ + next_buff_pos += buff_pos + sizeof(*ogm_packet); + if (next_buff_pos > packet_len) + return false; + + /* check if there is enough space for the optional TVLV */ + next_buff_pos += ntohs(ogm_packet->tvlv_len); + + return (next_buff_pos <= packet_len) && + (next_buff_pos <= BATADV_MAX_AGGREGATION_BYTES); +} + +/* send a batman ogm to a given interface */ +static void batadv_iv_ogm_send_to_if(struct batadv_forw_packet *forw_packet, + struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + const char *fwd_str; + u8 packet_num; + s16 buff_pos; + struct batadv_ogm_packet *batadv_ogm_packet; + struct sk_buff *skb; + u8 *packet_pos; + + if (hard_iface->if_status != BATADV_IF_ACTIVE) + return; + + packet_num = 0; + buff_pos = 0; + packet_pos = forw_packet->skb->data; + batadv_ogm_packet = (struct batadv_ogm_packet *)packet_pos; + + /* adjust all flags and log packets */ + while (batadv_iv_ogm_aggr_packet(buff_pos, forw_packet->packet_len, + batadv_ogm_packet)) { + /* we might have aggregated direct link packets with an + * ordinary base packet + */ + if (forw_packet->direct_link_flags & BIT(packet_num) && + forw_packet->if_incoming == hard_iface) + batadv_ogm_packet->flags |= BATADV_DIRECTLINK; + else + batadv_ogm_packet->flags &= ~BATADV_DIRECTLINK; + + if (packet_num > 0 || !forw_packet->own) + fwd_str = "Forwarding"; + else + fwd_str = "Sending own"; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s %spacket (originator %pM, seqno %u, TQ %d, TTL %d, IDF %s) on interface %s [%pM]\n", + fwd_str, (packet_num > 0 ? "aggregated " : ""), + batadv_ogm_packet->orig, + ntohl(batadv_ogm_packet->seqno), + batadv_ogm_packet->tq, batadv_ogm_packet->ttl, + ((batadv_ogm_packet->flags & BATADV_DIRECTLINK) ? 
+ "on" : "off"), + hard_iface->net_dev->name, + hard_iface->net_dev->dev_addr); + + buff_pos += BATADV_OGM_HLEN; + buff_pos += ntohs(batadv_ogm_packet->tvlv_len); + packet_num++; + packet_pos = forw_packet->skb->data + buff_pos; + batadv_ogm_packet = (struct batadv_ogm_packet *)packet_pos; + } + + /* create clone because function is called more than once */ + skb = skb_clone(forw_packet->skb, GFP_ATOMIC); + if (skb) { + batadv_inc_counter(bat_priv, BATADV_CNT_MGMT_TX); + batadv_add_counter(bat_priv, BATADV_CNT_MGMT_TX_BYTES, + skb->len + ETH_HLEN); + batadv_send_broadcast_skb(skb, hard_iface); + } +} + +/* send a batman ogm packet */ +static void batadv_iv_ogm_emit(struct batadv_forw_packet *forw_packet) +{ + struct net_device *soft_iface; + + if (!forw_packet->if_incoming) { + pr_err("Error - can't forward packet: incoming iface not specified\n"); + return; + } + + soft_iface = forw_packet->if_incoming->soft_iface; + + if (WARN_ON(!forw_packet->if_outgoing)) + return; + + if (forw_packet->if_outgoing->soft_iface != soft_iface) { + pr_warn("%s: soft interface switch for queued OGM\n", __func__); + return; + } + + if (forw_packet->if_incoming->if_status != BATADV_IF_ACTIVE) + return; + + /* only for one specific outgoing interface */ + batadv_iv_ogm_send_to_if(forw_packet, forw_packet->if_outgoing); +} + +/** + * batadv_iv_ogm_can_aggregate() - find out if an OGM can be aggregated on an + * existing forward packet + * @new_bat_ogm_packet: OGM packet to be aggregated + * @bat_priv: the bat priv with all the soft interface information + * @packet_len: (total) length of the OGM + * @send_time: timestamp (jiffies) when the packet is to be sent + * @directlink: true if this is a direct link packet + * @if_incoming: interface where the packet was received + * @if_outgoing: interface for which the retransmission should be considered + * @forw_packet: the forwarded packet which should be checked + * + * Return: true if new_packet can be aggregated with forw_packet + */ +static bool +batadv_iv_ogm_can_aggregate(const struct batadv_ogm_packet *new_bat_ogm_packet, + struct batadv_priv *bat_priv, + int packet_len, unsigned long send_time, + bool directlink, + const struct batadv_hard_iface *if_incoming, + const struct batadv_hard_iface *if_outgoing, + const struct batadv_forw_packet *forw_packet) +{ + struct batadv_ogm_packet *batadv_ogm_packet; + int aggregated_bytes = forw_packet->packet_len + packet_len; + struct batadv_hard_iface *primary_if = NULL; + bool res = false; + unsigned long aggregation_end_time; + + batadv_ogm_packet = (struct batadv_ogm_packet *)forw_packet->skb->data; + aggregation_end_time = send_time; + aggregation_end_time += msecs_to_jiffies(BATADV_MAX_AGGREGATION_MS); + + /* we can aggregate the current packet to this aggregated packet + * if: + * + * - the send time is within our MAX_AGGREGATION_MS time + * - the resulting packet won't be bigger than + * MAX_AGGREGATION_BYTES + * otherwise aggregation is not possible + */ + if (!time_before(send_time, forw_packet->send_time) || + !time_after_eq(aggregation_end_time, forw_packet->send_time)) + return false; + + if (aggregated_bytes > BATADV_MAX_AGGREGATION_BYTES) + return false; + + /* packet is not leaving on the same interface. 
*/ + if (forw_packet->if_outgoing != if_outgoing) + return false; + + /* check aggregation compatibility + * -> direct link packets are broadcasted on + * their interface only + * -> aggregate packet if the current packet is + * a "global" packet as well as the base + * packet + */ + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return false; + + /* packets without direct link flag and high TTL + * are flooded through the net + */ + if (!directlink && + !(batadv_ogm_packet->flags & BATADV_DIRECTLINK) && + batadv_ogm_packet->ttl != 1 && + + /* own packets originating non-primary + * interfaces leave only that interface + */ + (!forw_packet->own || + forw_packet->if_incoming == primary_if)) { + res = true; + goto out; + } + + /* if the incoming packet is sent via this one + * interface only - we still can aggregate + */ + if (directlink && + new_bat_ogm_packet->ttl == 1 && + forw_packet->if_incoming == if_incoming && + + /* packets from direct neighbors or + * own secondary interface packets + * (= secondary interface packets in general) + */ + (batadv_ogm_packet->flags & BATADV_DIRECTLINK || + (forw_packet->own && + forw_packet->if_incoming != primary_if))) { + res = true; + goto out; + } + +out: + batadv_hardif_put(primary_if); + return res; +} + +/** + * batadv_iv_ogm_aggregate_new() - create a new aggregated packet and add this + * packet to it. + * @packet_buff: pointer to the OGM + * @packet_len: (total) length of the OGM + * @send_time: timestamp (jiffies) when the packet is to be sent + * @direct_link: whether this OGM has direct link status + * @if_incoming: interface where the packet was received + * @if_outgoing: interface for which the retransmission should be considered + * @own_packet: true if it is a self-generated ogm + */ +static void batadv_iv_ogm_aggregate_new(const unsigned char *packet_buff, + int packet_len, unsigned long send_time, + bool direct_link, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + int own_packet) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_forw_packet *forw_packet_aggr; + struct sk_buff *skb; + unsigned char *skb_buff; + unsigned int skb_size; + atomic_t *queue_left = own_packet ? 
NULL : &bat_priv->batman_queue_left; + + if (atomic_read(&bat_priv->aggregated_ogms) && + packet_len < BATADV_MAX_AGGREGATION_BYTES) + skb_size = BATADV_MAX_AGGREGATION_BYTES; + else + skb_size = packet_len; + + skb_size += ETH_HLEN; + + skb = netdev_alloc_skb_ip_align(NULL, skb_size); + if (!skb) + return; + + forw_packet_aggr = batadv_forw_packet_alloc(if_incoming, if_outgoing, + queue_left, bat_priv, skb); + if (!forw_packet_aggr) { + kfree_skb(skb); + return; + } + + forw_packet_aggr->skb->priority = TC_PRIO_CONTROL; + skb_reserve(forw_packet_aggr->skb, ETH_HLEN); + + skb_buff = skb_put(forw_packet_aggr->skb, packet_len); + forw_packet_aggr->packet_len = packet_len; + memcpy(skb_buff, packet_buff, packet_len); + + forw_packet_aggr->own = own_packet; + forw_packet_aggr->direct_link_flags = BATADV_NO_FLAGS; + forw_packet_aggr->send_time = send_time; + + /* save packet direct link flag status */ + if (direct_link) + forw_packet_aggr->direct_link_flags |= 1; + + INIT_DELAYED_WORK(&forw_packet_aggr->delayed_work, + batadv_iv_send_outstanding_bat_ogm_packet); + + batadv_forw_packet_ogmv1_queue(bat_priv, forw_packet_aggr, send_time); +} + +/* aggregate a new packet into the existing ogm packet */ +static void batadv_iv_ogm_aggregate(struct batadv_forw_packet *forw_packet_aggr, + const unsigned char *packet_buff, + int packet_len, bool direct_link) +{ + unsigned long new_direct_link_flag; + + skb_put_data(forw_packet_aggr->skb, packet_buff, packet_len); + forw_packet_aggr->packet_len += packet_len; + forw_packet_aggr->num_packets++; + + /* save packet direct link flag status */ + if (direct_link) { + new_direct_link_flag = BIT(forw_packet_aggr->num_packets); + forw_packet_aggr->direct_link_flags |= new_direct_link_flag; + } +} + +/** + * batadv_iv_ogm_queue_add() - queue up an OGM for transmission + * @bat_priv: the bat priv with all the soft interface information + * @packet_buff: pointer to the OGM + * @packet_len: (total) length of the OGM + * @if_incoming: interface where the packet was received + * @if_outgoing: interface for which the retransmission should be considered + * @own_packet: true if it is a self-generated ogm + * @send_time: timestamp (jiffies) when the packet is to be sent + */ +static void batadv_iv_ogm_queue_add(struct batadv_priv *bat_priv, + unsigned char *packet_buff, + int packet_len, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + int own_packet, unsigned long send_time) +{ + /* _aggr -> pointer to the packet we want to aggregate with + * _pos -> pointer to the position in the queue + */ + struct batadv_forw_packet *forw_packet_aggr = NULL; + struct batadv_forw_packet *forw_packet_pos = NULL; + struct batadv_ogm_packet *batadv_ogm_packet; + bool direct_link; + unsigned long max_aggregation_jiffies; + + batadv_ogm_packet = (struct batadv_ogm_packet *)packet_buff; + direct_link = !!(batadv_ogm_packet->flags & BATADV_DIRECTLINK); + max_aggregation_jiffies = msecs_to_jiffies(BATADV_MAX_AGGREGATION_MS); + + /* find position for the packet in the forward queue */ + spin_lock_bh(&bat_priv->forw_bat_list_lock); + /* own packets are not to be aggregated */ + if (atomic_read(&bat_priv->aggregated_ogms) && !own_packet) { + hlist_for_each_entry(forw_packet_pos, + &bat_priv->forw_bat_list, list) { + if (batadv_iv_ogm_can_aggregate(batadv_ogm_packet, + bat_priv, packet_len, + send_time, direct_link, + if_incoming, + if_outgoing, + forw_packet_pos)) { + forw_packet_aggr = forw_packet_pos; + break; + } + } + } + + /* nothing to aggregate with - 
either aggregation disabled or no + * suitable aggregation packet found + */ + if (!forw_packet_aggr) { + /* the following section can run without the lock */ + spin_unlock_bh(&bat_priv->forw_bat_list_lock); + + /* if we could not aggregate this packet with one of the others + * we hold it back for a while, so that it might be aggregated + * later on + */ + if (!own_packet && atomic_read(&bat_priv->aggregated_ogms)) + send_time += max_aggregation_jiffies; + + batadv_iv_ogm_aggregate_new(packet_buff, packet_len, + send_time, direct_link, + if_incoming, if_outgoing, + own_packet); + } else { + batadv_iv_ogm_aggregate(forw_packet_aggr, packet_buff, + packet_len, direct_link); + spin_unlock_bh(&bat_priv->forw_bat_list_lock); + } +} + +static void batadv_iv_ogm_forward(struct batadv_orig_node *orig_node, + const struct ethhdr *ethhdr, + struct batadv_ogm_packet *batadv_ogm_packet, + bool is_single_hop_neigh, + bool is_from_best_next_hop, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + u16 tvlv_len; + + if (batadv_ogm_packet->ttl <= 1) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, "ttl exceeded\n"); + return; + } + + if (!is_from_best_next_hop) { + /* Mark the forwarded packet when it is not coming from our + * best next hop. We still need to forward the packet for our + * neighbor link quality detection to work in case the packet + * originated from a single hop neighbor. Otherwise we can + * simply drop the ogm. + */ + if (is_single_hop_neigh) + batadv_ogm_packet->flags |= BATADV_NOT_BEST_NEXT_HOP; + else + return; + } + + tvlv_len = ntohs(batadv_ogm_packet->tvlv_len); + + batadv_ogm_packet->ttl--; + ether_addr_copy(batadv_ogm_packet->prev_sender, ethhdr->h_source); + + /* apply hop penalty */ + batadv_ogm_packet->tq = batadv_hop_penalty(batadv_ogm_packet->tq, + bat_priv); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Forwarding packet: tq: %i, ttl: %i\n", + batadv_ogm_packet->tq, batadv_ogm_packet->ttl); + + if (is_single_hop_neigh) + batadv_ogm_packet->flags |= BATADV_DIRECTLINK; + else + batadv_ogm_packet->flags &= ~BATADV_DIRECTLINK; + + batadv_iv_ogm_queue_add(bat_priv, (unsigned char *)batadv_ogm_packet, + BATADV_OGM_HLEN + tvlv_len, + if_incoming, if_outgoing, 0, + batadv_iv_ogm_fwd_send_time()); +} + +/** + * batadv_iv_ogm_slide_own_bcast_window() - bitshift own OGM broadcast windows + * for the given interface + * @hard_iface: the interface for which the windows have to be shifted + */ +static void +batadv_iv_ogm_slide_own_bcast_window(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + struct batadv_orig_node *orig_node; + struct batadv_orig_ifinfo *orig_ifinfo; + unsigned long *word; + u32 i; + u8 *w; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) { + hlist_for_each_entry_rcu(orig_ifinfo, + &orig_node->ifinfo_list, + list) { + if (orig_ifinfo->if_outgoing != hard_iface) + continue; + + spin_lock_bh(&orig_node->bat_iv.ogm_cnt_lock); + word = orig_ifinfo->bat_iv.bcast_own; + batadv_bit_get_packet(bat_priv, word, 1, 0); + w = &orig_ifinfo->bat_iv.bcast_own_sum; + *w = bitmap_weight(word, + BATADV_TQ_LOCAL_WINDOW_SIZE); + spin_unlock_bh(&orig_node->bat_iv.ogm_cnt_lock); + } + } + rcu_read_unlock(); + } +} + +/** + * batadv_iv_ogm_schedule_buff() 
- schedule submission of hardif ogm buffer + * @hard_iface: interface whose ogm buffer should be transmitted + */ +static void batadv_iv_ogm_schedule_buff(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + unsigned char **ogm_buff = &hard_iface->bat_iv.ogm_buff; + struct batadv_ogm_packet *batadv_ogm_packet; + struct batadv_hard_iface *primary_if, *tmp_hard_iface; + int *ogm_buff_len = &hard_iface->bat_iv.ogm_buff_len; + u32 seqno; + u16 tvlv_len = 0; + unsigned long send_time; + + lockdep_assert_held(&hard_iface->bat_iv.ogm_buff_mutex); + + /* interface already disabled by batadv_iv_ogm_iface_disable */ + if (!*ogm_buff) + return; + + /* the interface gets activated here to avoid race conditions between + * the moment of activating the interface in + * hardif_activate_interface() where the originator mac is set and + * outdated packets (especially uninitialized mac addresses) in the + * packet queue + */ + if (hard_iface->if_status == BATADV_IF_TO_BE_ACTIVATED) + hard_iface->if_status = BATADV_IF_ACTIVE; + + primary_if = batadv_primary_if_get_selected(bat_priv); + + if (hard_iface == primary_if) { + /* tt changes have to be committed before the tvlv data is + * appended as it may alter the tt tvlv container + */ + batadv_tt_local_commit_changes(bat_priv); + tvlv_len = batadv_tvlv_container_ogm_append(bat_priv, ogm_buff, + ogm_buff_len, + BATADV_OGM_HLEN); + } + + batadv_ogm_packet = (struct batadv_ogm_packet *)(*ogm_buff); + batadv_ogm_packet->tvlv_len = htons(tvlv_len); + + /* change sequence number to network order */ + seqno = (u32)atomic_read(&hard_iface->bat_iv.ogm_seqno); + batadv_ogm_packet->seqno = htonl(seqno); + atomic_inc(&hard_iface->bat_iv.ogm_seqno); + + batadv_iv_ogm_slide_own_bcast_window(hard_iface); + + send_time = batadv_iv_ogm_emit_send_time(bat_priv); + + if (hard_iface != primary_if) { + /* OGMs from secondary interfaces are only scheduled on their + * respective interfaces. + */ + batadv_iv_ogm_queue_add(bat_priv, *ogm_buff, *ogm_buff_len, + hard_iface, hard_iface, 1, send_time); + goto out; + } + + /* OGMs from primary interfaces are scheduled on all + * interfaces. 
+ */ + rcu_read_lock(); + list_for_each_entry_rcu(tmp_hard_iface, &batadv_hardif_list, list) { + if (tmp_hard_iface->soft_iface != hard_iface->soft_iface) + continue; + + if (!kref_get_unless_zero(&tmp_hard_iface->refcount)) + continue; + + batadv_iv_ogm_queue_add(bat_priv, *ogm_buff, + *ogm_buff_len, hard_iface, + tmp_hard_iface, 1, send_time); + + batadv_hardif_put(tmp_hard_iface); + } + rcu_read_unlock(); + +out: + batadv_hardif_put(primary_if); +} + +static void batadv_iv_ogm_schedule(struct batadv_hard_iface *hard_iface) +{ + if (hard_iface->if_status == BATADV_IF_NOT_IN_USE || + hard_iface->if_status == BATADV_IF_TO_BE_REMOVED) + return; + + mutex_lock(&hard_iface->bat_iv.ogm_buff_mutex); + batadv_iv_ogm_schedule_buff(hard_iface); + mutex_unlock(&hard_iface->bat_iv.ogm_buff_mutex); +} + +/** + * batadv_iv_orig_ifinfo_sum() - Get bcast_own sum for originator over interface + * @orig_node: originator which rebroadcasted the OGMs directly + * @if_outgoing: interface which transmitted the original OGM and received the + * direct rebroadcast + * + * Return: Number of replied (rebroadcasted) OGMs which were transmitted by + * an originator and directly (without intermediate hop) received by a specific + * interface + */ +static u8 batadv_iv_orig_ifinfo_sum(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + u8 sum; + + orig_ifinfo = batadv_orig_ifinfo_get(orig_node, if_outgoing); + if (!orig_ifinfo) + return 0; + + spin_lock_bh(&orig_node->bat_iv.ogm_cnt_lock); + sum = orig_ifinfo->bat_iv.bcast_own_sum; + spin_unlock_bh(&orig_node->bat_iv.ogm_cnt_lock); + + batadv_orig_ifinfo_put(orig_ifinfo); + + return sum; +} + +/** + * batadv_iv_ogm_orig_update() - use OGM to update corresponding data in an + * originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the orig node who originally emitted the ogm packet + * @orig_ifinfo: ifinfo for the outgoing interface of the orig_node + * @ethhdr: Ethernet header of the OGM + * @batadv_ogm_packet: the ogm packet + * @if_incoming: interface where the packet was received + * @if_outgoing: interface for which the retransmission should be considered + * @dup_status: the duplicate status of this ogm packet. 
+ */ +static void +batadv_iv_ogm_orig_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_orig_ifinfo *orig_ifinfo, + const struct ethhdr *ethhdr, + const struct batadv_ogm_packet *batadv_ogm_packet, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + enum batadv_dup_status dup_status) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo = NULL; + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + struct batadv_neigh_node *neigh_node = NULL; + struct batadv_neigh_node *tmp_neigh_node = NULL; + struct batadv_neigh_node *router = NULL; + u8 sum_orig, sum_neigh; + u8 *neigh_addr; + u8 tq_avg; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s(): Searching and updating originator entry of received packet\n", + __func__); + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_neigh_node, + &orig_node->neigh_list, list) { + neigh_addr = tmp_neigh_node->addr; + if (batadv_compare_eth(neigh_addr, ethhdr->h_source) && + tmp_neigh_node->if_incoming == if_incoming && + kref_get_unless_zero(&tmp_neigh_node->refcount)) { + if (WARN(neigh_node, "too many matching neigh_nodes")) + batadv_neigh_node_put(neigh_node); + neigh_node = tmp_neigh_node; + continue; + } + + if (dup_status != BATADV_NO_DUP) + continue; + + /* only update the entry for this outgoing interface */ + neigh_ifinfo = batadv_neigh_ifinfo_get(tmp_neigh_node, + if_outgoing); + if (!neigh_ifinfo) + continue; + + spin_lock_bh(&tmp_neigh_node->ifinfo_lock); + batadv_ring_buffer_set(neigh_ifinfo->bat_iv.tq_recv, + &neigh_ifinfo->bat_iv.tq_index, 0); + tq_avg = batadv_ring_buffer_avg(neigh_ifinfo->bat_iv.tq_recv); + neigh_ifinfo->bat_iv.tq_avg = tq_avg; + spin_unlock_bh(&tmp_neigh_node->ifinfo_lock); + + batadv_neigh_ifinfo_put(neigh_ifinfo); + neigh_ifinfo = NULL; + } + + if (!neigh_node) { + struct batadv_orig_node *orig_tmp; + + orig_tmp = batadv_iv_ogm_orig_get(bat_priv, ethhdr->h_source); + if (!orig_tmp) + goto unlock; + + neigh_node = batadv_iv_ogm_neigh_new(if_incoming, + ethhdr->h_source, + orig_node, orig_tmp); + + batadv_orig_node_put(orig_tmp); + if (!neigh_node) + goto unlock; + } else { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Updating existing last-hop neighbor of originator\n"); + } + + rcu_read_unlock(); + neigh_ifinfo = batadv_neigh_ifinfo_new(neigh_node, if_outgoing); + if (!neigh_ifinfo) + goto out; + + neigh_node->last_seen = jiffies; + + spin_lock_bh(&neigh_node->ifinfo_lock); + batadv_ring_buffer_set(neigh_ifinfo->bat_iv.tq_recv, + &neigh_ifinfo->bat_iv.tq_index, + batadv_ogm_packet->tq); + tq_avg = batadv_ring_buffer_avg(neigh_ifinfo->bat_iv.tq_recv); + neigh_ifinfo->bat_iv.tq_avg = tq_avg; + spin_unlock_bh(&neigh_node->ifinfo_lock); + + if (dup_status == BATADV_NO_DUP) { + orig_ifinfo->last_ttl = batadv_ogm_packet->ttl; + neigh_ifinfo->last_ttl = batadv_ogm_packet->ttl; + } + + /* if this neighbor already is our next hop there is nothing + * to change + */ + router = batadv_orig_router_get(orig_node, if_outgoing); + if (router == neigh_node) + goto out; + + if (router) { + router_ifinfo = batadv_neigh_ifinfo_get(router, if_outgoing); + if (!router_ifinfo) + goto out; + + /* if this neighbor does not offer a better TQ we won't + * consider it + */ + if (router_ifinfo->bat_iv.tq_avg > neigh_ifinfo->bat_iv.tq_avg) + goto out; + } + + /* if the TQ is the same and the link not more symmetric we + * won't consider it either + */ + if (router_ifinfo && + neigh_ifinfo->bat_iv.tq_avg == router_ifinfo->bat_iv.tq_avg) { + sum_orig = 
batadv_iv_orig_ifinfo_sum(router->orig_node, + router->if_incoming); + sum_neigh = batadv_iv_orig_ifinfo_sum(neigh_node->orig_node, + neigh_node->if_incoming); + if (sum_orig >= sum_neigh) + goto out; + } + + batadv_update_route(bat_priv, orig_node, if_outgoing, neigh_node); + goto out; + +unlock: + rcu_read_unlock(); +out: + batadv_neigh_node_put(neigh_node); + batadv_neigh_node_put(router); + batadv_neigh_ifinfo_put(neigh_ifinfo); + batadv_neigh_ifinfo_put(router_ifinfo); +} + +/** + * batadv_iv_ogm_calc_tq() - calculate tq for current received ogm packet + * @orig_node: the orig node who originally emitted the ogm packet + * @orig_neigh_node: the orig node struct of the neighbor who sent the packet + * @batadv_ogm_packet: the ogm packet + * @if_incoming: interface where the packet was received + * @if_outgoing: interface for which the retransmission should be considered + * + * Return: true if the link can be considered bidirectional, false otherwise + */ +static bool batadv_iv_ogm_calc_tq(struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + struct batadv_ogm_packet *batadv_ogm_packet, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_neigh_node *neigh_node = NULL, *tmp_neigh_node; + struct batadv_neigh_ifinfo *neigh_ifinfo; + u8 total_count; + u8 orig_eq_count, neigh_rq_count, neigh_rq_inv, tq_own; + unsigned int tq_iface_hop_penalty = BATADV_TQ_MAX_VALUE; + unsigned int neigh_rq_inv_cube, neigh_rq_max_cube; + unsigned int tq_asym_penalty, inv_asym_penalty; + unsigned int combined_tq; + bool ret = false; + + /* find corresponding one hop neighbor */ + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_neigh_node, + &orig_neigh_node->neigh_list, list) { + if (!batadv_compare_eth(tmp_neigh_node->addr, + orig_neigh_node->orig)) + continue; + + if (tmp_neigh_node->if_incoming != if_incoming) + continue; + + if (!kref_get_unless_zero(&tmp_neigh_node->refcount)) + continue; + + neigh_node = tmp_neigh_node; + break; + } + rcu_read_unlock(); + + if (!neigh_node) + neigh_node = batadv_iv_ogm_neigh_new(if_incoming, + orig_neigh_node->orig, + orig_neigh_node, + orig_neigh_node); + + if (!neigh_node) + goto out; + + /* if orig_node is direct neighbor update neigh_node last_seen */ + if (orig_node == orig_neigh_node) + neigh_node->last_seen = jiffies; + + orig_node->last_seen = jiffies; + + /* find packet count of corresponding one hop neighbor */ + orig_eq_count = batadv_iv_orig_ifinfo_sum(orig_neigh_node, if_incoming); + neigh_ifinfo = batadv_neigh_ifinfo_new(neigh_node, if_outgoing); + if (neigh_ifinfo) { + neigh_rq_count = neigh_ifinfo->bat_iv.real_packet_count; + batadv_neigh_ifinfo_put(neigh_ifinfo); + } else { + neigh_rq_count = 0; + } + + /* pay attention to not get a value bigger than 100 % */ + if (orig_eq_count > neigh_rq_count) + total_count = neigh_rq_count; + else + total_count = orig_eq_count; + + /* if we have too few packets (too less data) we set tq_own to zero + * if we receive too few packets it is not considered bidirectional + */ + if (total_count < BATADV_TQ_LOCAL_BIDRECT_SEND_MINIMUM || + neigh_rq_count < BATADV_TQ_LOCAL_BIDRECT_RECV_MINIMUM) + tq_own = 0; + else + /* neigh_node->real_packet_count is never zero as we + * only purge old information when getting new + * information + */ + tq_own = (BATADV_TQ_MAX_VALUE * total_count) / neigh_rq_count; + + /* 1 - ((1-x) ** 3), normalized to TQ_MAX_VALUE this does + * affect the 
nearly-symmetric links only a little, but + * punishes asymmetric links more. This will give a value + * between 0 and TQ_MAX_VALUE + */ + neigh_rq_inv = BATADV_TQ_LOCAL_WINDOW_SIZE - neigh_rq_count; + neigh_rq_inv_cube = neigh_rq_inv * neigh_rq_inv * neigh_rq_inv; + neigh_rq_max_cube = BATADV_TQ_LOCAL_WINDOW_SIZE * + BATADV_TQ_LOCAL_WINDOW_SIZE * + BATADV_TQ_LOCAL_WINDOW_SIZE; + inv_asym_penalty = BATADV_TQ_MAX_VALUE * neigh_rq_inv_cube; + inv_asym_penalty /= neigh_rq_max_cube; + tq_asym_penalty = BATADV_TQ_MAX_VALUE - inv_asym_penalty; + tq_iface_hop_penalty -= atomic_read(&if_incoming->hop_penalty); + + /* penalize if the OGM is forwarded on the same interface. WiFi + * interfaces and other half duplex devices suffer from throughput + * drops as they can't send and receive at the same time. + */ + if (if_outgoing && if_incoming == if_outgoing && + batadv_is_wifi_hardif(if_outgoing)) + tq_iface_hop_penalty = batadv_hop_penalty(tq_iface_hop_penalty, + bat_priv); + + combined_tq = batadv_ogm_packet->tq * + tq_own * + tq_asym_penalty * + tq_iface_hop_penalty; + combined_tq /= BATADV_TQ_MAX_VALUE * + BATADV_TQ_MAX_VALUE * + BATADV_TQ_MAX_VALUE; + batadv_ogm_packet->tq = combined_tq; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "bidirectional: orig = %pM neigh = %pM => own_bcast = %2i, real recv = %2i, local tq: %3i, asym_penalty: %3i, iface_hop_penalty: %3i, total tq: %3i, if_incoming = %s, if_outgoing = %s\n", + orig_node->orig, orig_neigh_node->orig, total_count, + neigh_rq_count, tq_own, tq_asym_penalty, + tq_iface_hop_penalty, batadv_ogm_packet->tq, + if_incoming->net_dev->name, + if_outgoing ? if_outgoing->net_dev->name : "DEFAULT"); + + /* if link has the minimum required transmission quality + * consider it bidirectional + */ + if (batadv_ogm_packet->tq >= BATADV_TQ_TOTAL_BIDRECT_LIMIT) + ret = true; + +out: + batadv_neigh_node_put(neigh_node); + return ret; +} + +/** + * batadv_iv_ogm_update_seqnos() - process a batman packet for all interfaces, + * adjust the sequence number and find out whether it is a duplicate + * @ethhdr: ethernet header of the packet + * @batadv_ogm_packet: OGM packet to be considered + * @if_incoming: interface on which the OGM packet was received + * @if_outgoing: interface for which the retransmission should be considered + * + * Return: duplicate status as enum batadv_dup_status + */ +static enum batadv_dup_status +batadv_iv_ogm_update_seqnos(const struct ethhdr *ethhdr, + const struct batadv_ogm_packet *batadv_ogm_packet, + const struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_orig_node *orig_node; + struct batadv_orig_ifinfo *orig_ifinfo = NULL; + struct batadv_neigh_node *neigh_node; + struct batadv_neigh_ifinfo *neigh_ifinfo; + bool is_dup; + s32 seq_diff; + bool need_update = false; + int set_mark; + enum batadv_dup_status ret = BATADV_NO_DUP; + u32 seqno = ntohl(batadv_ogm_packet->seqno); + u8 *neigh_addr; + u8 packet_count; + unsigned long *bitmap; + + orig_node = batadv_iv_ogm_orig_get(bat_priv, batadv_ogm_packet->orig); + if (!orig_node) + return BATADV_NO_DUP; + + orig_ifinfo = batadv_orig_ifinfo_new(orig_node, if_outgoing); + if (WARN_ON(!orig_ifinfo)) { + batadv_orig_node_put(orig_node); + return 0; + } + + spin_lock_bh(&orig_node->bat_iv.ogm_cnt_lock); + seq_diff = seqno - orig_ifinfo->last_real_seqno; + + /* signalize caller that the packet is to be dropped. 
*/ + if (!hlist_empty(&orig_node->neigh_list) && + batadv_window_protected(bat_priv, seq_diff, + BATADV_TQ_LOCAL_WINDOW_SIZE, + &orig_ifinfo->batman_seqno_reset, NULL)) { + ret = BATADV_PROTECTED; + goto out; + } + + rcu_read_lock(); + hlist_for_each_entry_rcu(neigh_node, &orig_node->neigh_list, list) { + neigh_ifinfo = batadv_neigh_ifinfo_new(neigh_node, + if_outgoing); + if (!neigh_ifinfo) + continue; + + neigh_addr = neigh_node->addr; + is_dup = batadv_test_bit(neigh_ifinfo->bat_iv.real_bits, + orig_ifinfo->last_real_seqno, + seqno); + + if (batadv_compare_eth(neigh_addr, ethhdr->h_source) && + neigh_node->if_incoming == if_incoming) { + set_mark = 1; + if (is_dup) + ret = BATADV_NEIGH_DUP; + } else { + set_mark = 0; + if (is_dup && ret != BATADV_NEIGH_DUP) + ret = BATADV_ORIG_DUP; + } + + /* if the window moved, set the update flag. */ + bitmap = neigh_ifinfo->bat_iv.real_bits; + need_update |= batadv_bit_get_packet(bat_priv, bitmap, + seq_diff, set_mark); + + packet_count = bitmap_weight(bitmap, + BATADV_TQ_LOCAL_WINDOW_SIZE); + neigh_ifinfo->bat_iv.real_packet_count = packet_count; + batadv_neigh_ifinfo_put(neigh_ifinfo); + } + rcu_read_unlock(); + + if (need_update) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s updating last_seqno: old %u, new %u\n", + if_outgoing ? if_outgoing->net_dev->name : "DEFAULT", + orig_ifinfo->last_real_seqno, seqno); + orig_ifinfo->last_real_seqno = seqno; + } + +out: + spin_unlock_bh(&orig_node->bat_iv.ogm_cnt_lock); + batadv_orig_node_put(orig_node); + batadv_orig_ifinfo_put(orig_ifinfo); + return ret; +} + +/** + * batadv_iv_ogm_process_per_outif() - process a batman iv OGM for an outgoing + * interface + * @skb: the skb containing the OGM + * @ogm_offset: offset from skb->data to start of ogm header + * @orig_node: the (cached) orig node for the originator of this OGM + * @if_incoming: the interface where this packet was received + * @if_outgoing: the interface for which the packet should be considered + */ +static void +batadv_iv_ogm_process_per_outif(const struct sk_buff *skb, int ogm_offset, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_hardif_neigh_node *hardif_neigh = NULL; + struct batadv_neigh_node *router = NULL; + struct batadv_neigh_node *router_router = NULL; + struct batadv_orig_node *orig_neigh_node; + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_neigh_node *orig_neigh_router = NULL; + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + struct batadv_ogm_packet *ogm_packet; + enum batadv_dup_status dup_status; + bool is_from_best_next_hop = false; + bool is_single_hop_neigh = false; + bool sameseq, similar_ttl; + struct sk_buff *skb_priv; + struct ethhdr *ethhdr; + u8 *prev_sender; + bool is_bidirect; + + /* create a private copy of the skb, as some functions change tq value + * and/or flags. 
+ */ + skb_priv = skb_copy(skb, GFP_ATOMIC); + if (!skb_priv) + return; + + ethhdr = eth_hdr(skb_priv); + ogm_packet = (struct batadv_ogm_packet *)(skb_priv->data + ogm_offset); + + dup_status = batadv_iv_ogm_update_seqnos(ethhdr, ogm_packet, + if_incoming, if_outgoing); + if (batadv_compare_eth(ethhdr->h_source, ogm_packet->orig)) + is_single_hop_neigh = true; + + if (dup_status == BATADV_PROTECTED) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: packet within seqno protection time (sender: %pM)\n", + ethhdr->h_source); + goto out; + } + + if (ogm_packet->tq == 0) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: originator packet with tq equal 0\n"); + goto out; + } + + if (is_single_hop_neigh) { + hardif_neigh = batadv_hardif_neigh_get(if_incoming, + ethhdr->h_source); + if (hardif_neigh) + hardif_neigh->last_seen = jiffies; + } + + router = batadv_orig_router_get(orig_node, if_outgoing); + if (router) { + router_router = batadv_orig_router_get(router->orig_node, + if_outgoing); + router_ifinfo = batadv_neigh_ifinfo_get(router, if_outgoing); + } + + if ((router_ifinfo && router_ifinfo->bat_iv.tq_avg != 0) && + (batadv_compare_eth(router->addr, ethhdr->h_source))) + is_from_best_next_hop = true; + + prev_sender = ogm_packet->prev_sender; + /* avoid temporary routing loops */ + if (router && router_router && + (batadv_compare_eth(router->addr, prev_sender)) && + !(batadv_compare_eth(ogm_packet->orig, prev_sender)) && + (batadv_compare_eth(router->addr, router_router->addr))) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: ignoring all rebroadcast packets that may make me loop (sender: %pM)\n", + ethhdr->h_source); + goto out; + } + + if (if_outgoing == BATADV_IF_DEFAULT) + batadv_tvlv_ogm_receive(bat_priv, ogm_packet, orig_node); + + /* if sender is a direct neighbor the sender mac equals + * originator mac + */ + if (is_single_hop_neigh) + orig_neigh_node = orig_node; + else + orig_neigh_node = batadv_iv_ogm_orig_get(bat_priv, + ethhdr->h_source); + + if (!orig_neigh_node) + goto out; + + /* Update nc_nodes of the originator */ + batadv_nc_update_nc_node(bat_priv, orig_node, orig_neigh_node, + ogm_packet, is_single_hop_neigh); + + orig_neigh_router = batadv_orig_router_get(orig_neigh_node, + if_outgoing); + + /* drop packet if sender is not a direct neighbor and if we + * don't route towards it + */ + if (!is_single_hop_neigh && !orig_neigh_router) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: OGM via unknown neighbor!\n"); + goto out_neigh; + } + + is_bidirect = batadv_iv_ogm_calc_tq(orig_node, orig_neigh_node, + ogm_packet, if_incoming, + if_outgoing); + + /* update ranking if it is not a duplicate or has the same + * seqno and similar ttl as the non-duplicate + */ + orig_ifinfo = batadv_orig_ifinfo_new(orig_node, if_outgoing); + if (!orig_ifinfo) + goto out_neigh; + + sameseq = orig_ifinfo->last_real_seqno == ntohl(ogm_packet->seqno); + similar_ttl = (orig_ifinfo->last_ttl - 3) <= ogm_packet->ttl; + + if (is_bidirect && (dup_status == BATADV_NO_DUP || + (sameseq && similar_ttl))) { + batadv_iv_ogm_orig_update(bat_priv, orig_node, + orig_ifinfo, ethhdr, + ogm_packet, if_incoming, + if_outgoing, dup_status); + } + batadv_orig_ifinfo_put(orig_ifinfo); + + /* only forward for specific interface, not for the default one. 
*/ + if (if_outgoing == BATADV_IF_DEFAULT) + goto out_neigh; + + /* is single hop (direct) neighbor */ + if (is_single_hop_neigh) { + /* OGMs from secondary interfaces should only be scheduled once + * per interface where they have been received, not multiple times + */ + if (ogm_packet->ttl <= 2 && + if_incoming != if_outgoing) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: OGM from secondary interface and wrong outgoing interface\n"); + goto out_neigh; + } + /* mark direct link on incoming interface */ + batadv_iv_ogm_forward(orig_node, ethhdr, ogm_packet, + is_single_hop_neigh, + is_from_best_next_hop, if_incoming, + if_outgoing); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Forwarding packet: rebroadcast neighbor packet with direct link flag\n"); + goto out_neigh; + } + + /* multihop originator */ + if (!is_bidirect) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: not received via bidirectional link\n"); + goto out_neigh; + } + + if (dup_status == BATADV_NEIGH_DUP) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: duplicate packet received\n"); + goto out_neigh; + } + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Forwarding packet: rebroadcast originator packet\n"); + batadv_iv_ogm_forward(orig_node, ethhdr, ogm_packet, + is_single_hop_neigh, is_from_best_next_hop, + if_incoming, if_outgoing); + +out_neigh: + if (orig_neigh_node && !is_single_hop_neigh) + batadv_orig_node_put(orig_neigh_node); +out: + batadv_neigh_ifinfo_put(router_ifinfo); + batadv_neigh_node_put(router); + batadv_neigh_node_put(router_router); + batadv_neigh_node_put(orig_neigh_router); + batadv_hardif_neigh_put(hardif_neigh); + + consume_skb(skb_priv); +} + +/** + * batadv_iv_ogm_process_reply() - Check OGM for direct reply and process it + * @ogm_packet: rebroadcast OGM packet to process + * @if_incoming: the interface where this packet was received + * @orig_node: originator which rebroadcasted the OGMs + * @if_incoming_seqno: OGM sequence number when rebroadcast was received + */ +static void batadv_iv_ogm_process_reply(struct batadv_ogm_packet *ogm_packet, + struct batadv_hard_iface *if_incoming, + struct batadv_orig_node *orig_node, + u32 if_incoming_seqno) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + s32 bit_pos; + u8 *weight; + + /* neighbor has to indicate direct link and it has to + * come via the corresponding interface + */ + if (!(ogm_packet->flags & BATADV_DIRECTLINK)) + return; + + if (!batadv_compare_eth(if_incoming->net_dev->dev_addr, + ogm_packet->orig)) + return; + + orig_ifinfo = batadv_orig_ifinfo_get(orig_node, if_incoming); + if (!orig_ifinfo) + return; + + /* save packet seqno for bidirectional check */ + spin_lock_bh(&orig_node->bat_iv.ogm_cnt_lock); + bit_pos = if_incoming_seqno - 2; + bit_pos -= ntohl(ogm_packet->seqno); + batadv_set_bit(orig_ifinfo->bat_iv.bcast_own, bit_pos); + weight = &orig_ifinfo->bat_iv.bcast_own_sum; + *weight = bitmap_weight(orig_ifinfo->bat_iv.bcast_own, + BATADV_TQ_LOCAL_WINDOW_SIZE); + spin_unlock_bh(&orig_node->bat_iv.ogm_cnt_lock); + + batadv_orig_ifinfo_put(orig_ifinfo); +} + +/** + * batadv_iv_ogm_process() - process an incoming batman iv OGM + * @skb: the skb containing the OGM + * @ogm_offset: offset to the OGM which should be processed (for aggregates) + * @if_incoming: the interface where this packet was received + */ +static void batadv_iv_ogm_process(const struct sk_buff *skb, int ogm_offset, + struct batadv_hard_iface *if_incoming) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct 
batadv_orig_node *orig_neigh_node, *orig_node; + struct batadv_hard_iface *hard_iface; + struct batadv_ogm_packet *ogm_packet; + u32 if_incoming_seqno; + bool has_directlink_flag; + struct ethhdr *ethhdr; + bool is_my_oldorig = false; + bool is_my_addr = false; + bool is_my_orig = false; + + ogm_packet = (struct batadv_ogm_packet *)(skb->data + ogm_offset); + ethhdr = eth_hdr(skb); + + /* Silently drop when the batman packet is actually not a + * correct packet. + * + * This might happen if a packet is padded (e.g. Ethernet has a + * minimum frame length of 64 byte) and the aggregation interprets + * it as an additional length. + * + * TODO: A more sane solution would be to have a bit in the + * batadv_ogm_packet to detect whether the packet is the last + * packet in an aggregation. Here we expect that the padding + * is always zero (or not 0x01) + */ + if (ogm_packet->packet_type != BATADV_IV_OGM) + return; + + /* could be changed by schedule_own_packet() */ + if_incoming_seqno = atomic_read(&if_incoming->bat_iv.ogm_seqno); + + if (ogm_packet->flags & BATADV_DIRECTLINK) + has_directlink_flag = true; + else + has_directlink_flag = false; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Received BATMAN packet via NB: %pM, IF: %s [%pM] (from OG: %pM, via prev OG: %pM, seqno %u, tq %d, TTL %d, V %d, IDF %d)\n", + ethhdr->h_source, if_incoming->net_dev->name, + if_incoming->net_dev->dev_addr, ogm_packet->orig, + ogm_packet->prev_sender, ntohl(ogm_packet->seqno), + ogm_packet->tq, ogm_packet->ttl, + ogm_packet->version, has_directlink_flag); + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE) + continue; + + if (hard_iface->soft_iface != if_incoming->soft_iface) + continue; + + if (batadv_compare_eth(ethhdr->h_source, + hard_iface->net_dev->dev_addr)) + is_my_addr = true; + + if (batadv_compare_eth(ogm_packet->orig, + hard_iface->net_dev->dev_addr)) + is_my_orig = true; + + if (batadv_compare_eth(ogm_packet->prev_sender, + hard_iface->net_dev->dev_addr)) + is_my_oldorig = true; + } + rcu_read_unlock(); + + if (is_my_addr) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: received my own broadcast (sender: %pM)\n", + ethhdr->h_source); + return; + } + + if (is_my_orig) { + orig_neigh_node = batadv_iv_ogm_orig_get(bat_priv, + ethhdr->h_source); + if (!orig_neigh_node) + return; + + batadv_iv_ogm_process_reply(ogm_packet, if_incoming, + orig_neigh_node, if_incoming_seqno); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: originator packet from myself (via neighbor)\n"); + batadv_orig_node_put(orig_neigh_node); + return; + } + + if (is_my_oldorig) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: ignoring all rebroadcast echos (sender: %pM)\n", + ethhdr->h_source); + return; + } + + if (ogm_packet->flags & BATADV_NOT_BEST_NEXT_HOP) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: ignoring all packets not forwarded from the best next hop (sender: %pM)\n", + ethhdr->h_source); + return; + } + + orig_node = batadv_iv_ogm_orig_get(bat_priv, ogm_packet->orig); + if (!orig_node) + return; + + batadv_iv_ogm_process_per_outif(skb, ogm_offset, orig_node, + if_incoming, BATADV_IF_DEFAULT); + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE) + continue; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + continue; + + 
batadv_iv_ogm_process_per_outif(skb, ogm_offset, orig_node, + if_incoming, hard_iface); + + batadv_hardif_put(hard_iface); + } + rcu_read_unlock(); + + batadv_orig_node_put(orig_node); +} + +static void batadv_iv_send_outstanding_bat_ogm_packet(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_forw_packet *forw_packet; + struct batadv_priv *bat_priv; + bool dropped = false; + + delayed_work = to_delayed_work(work); + forw_packet = container_of(delayed_work, struct batadv_forw_packet, + delayed_work); + bat_priv = netdev_priv(forw_packet->if_incoming->soft_iface); + + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_DEACTIVATING) { + dropped = true; + goto out; + } + + batadv_iv_ogm_emit(forw_packet); + + /* we have to have at least one packet in the queue to determine the + * queues wake up time unless we are shutting down. + * + * only re-schedule if this is the "original" copy, e.g. the OGM of the + * primary interface should only be rescheduled once per period, but + * this function will be called for the forw_packet instances of the + * other secondary interfaces as well. + */ + if (forw_packet->own && + forw_packet->if_incoming == forw_packet->if_outgoing) + batadv_iv_ogm_schedule(forw_packet->if_incoming); + +out: + /* do we get something for free()? */ + if (batadv_forw_packet_steal(forw_packet, + &bat_priv->forw_bat_list_lock)) + batadv_forw_packet_free(forw_packet, dropped); +} + +static int batadv_iv_ogm_receive(struct sk_buff *skb, + struct batadv_hard_iface *if_incoming) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_ogm_packet *ogm_packet; + u8 *packet_pos; + int ogm_offset; + bool res; + int ret = NET_RX_DROP; + + res = batadv_check_management_packet(skb, if_incoming, BATADV_OGM_HLEN); + if (!res) + goto free_skb; + + /* did we receive a B.A.T.M.A.N. IV OGM packet on an interface + * that does not have B.A.T.M.A.N. IV enabled ? + */ + if (bat_priv->algo_ops->iface.enable != batadv_iv_ogm_iface_enable) + goto free_skb; + + batadv_inc_counter(bat_priv, BATADV_CNT_MGMT_RX); + batadv_add_counter(bat_priv, BATADV_CNT_MGMT_RX_BYTES, + skb->len + ETH_HLEN); + + ogm_offset = 0; + ogm_packet = (struct batadv_ogm_packet *)skb->data; + + /* unpack the aggregated packets and process them one by one */ + while (batadv_iv_ogm_aggr_packet(ogm_offset, skb_headlen(skb), + ogm_packet)) { + batadv_iv_ogm_process(skb, ogm_offset, if_incoming); + + ogm_offset += BATADV_OGM_HLEN; + ogm_offset += ntohs(ogm_packet->tvlv_len); + + packet_pos = skb->data + ogm_offset; + ogm_packet = (struct batadv_ogm_packet *)packet_pos; + } + + ret = NET_RX_SUCCESS; + +free_skb: + if (ret == NET_RX_SUCCESS) + consume_skb(skb); + else + kfree_skb(skb); + + return ret; +} + +/** + * batadv_iv_ogm_neigh_get_tq_avg() - Get the TQ average for a neighbour on a + * given outgoing interface. + * @neigh_node: Neighbour of interest + * @if_outgoing: Outgoing interface of interest + * @tq_avg: Pointer of where to store the TQ average + * + * Return: False if no average TQ available, otherwise true. 
+ */ +static bool +batadv_iv_ogm_neigh_get_tq_avg(struct batadv_neigh_node *neigh_node, + struct batadv_hard_iface *if_outgoing, + u8 *tq_avg) +{ + struct batadv_neigh_ifinfo *n_ifinfo; + + n_ifinfo = batadv_neigh_ifinfo_get(neigh_node, if_outgoing); + if (!n_ifinfo) + return false; + + *tq_avg = n_ifinfo->bat_iv.tq_avg; + batadv_neigh_ifinfo_put(n_ifinfo); + + return true; +} + +/** + * batadv_iv_ogm_orig_dump_subentry() - Dump an originator subentry into a + * message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @orig_node: Originator to dump + * @neigh_node: Single hops neighbour + * @best: Is the best originator + * + * Return: Error code, or 0 on success + */ +static int +batadv_iv_ogm_orig_dump_subentry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + bool best) +{ + void *hdr; + u8 tq_avg; + unsigned int last_seen_msecs; + + last_seen_msecs = jiffies_to_msecs(jiffies - orig_node->last_seen); + + if (!batadv_iv_ogm_neigh_get_tq_avg(neigh_node, if_outgoing, &tq_avg)) + return 0; + + if (if_outgoing != BATADV_IF_DEFAULT && + if_outgoing != neigh_node->if_incoming) + return 0; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, + NLM_F_MULTI, BATADV_CMD_GET_ORIGINATORS); + if (!hdr) + return -ENOBUFS; + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, + orig_node->orig) || + nla_put(msg, BATADV_ATTR_NEIGH_ADDRESS, ETH_ALEN, + neigh_node->addr) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + neigh_node->if_incoming->net_dev->name) || + nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + neigh_node->if_incoming->net_dev->ifindex) || + nla_put_u8(msg, BATADV_ATTR_TQ, tq_avg) || + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, + last_seen_msecs)) + goto nla_put_failure; + + if (best && nla_put_flag(msg, BATADV_ATTR_FLAG_BEST)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_iv_ogm_orig_dump_entry() - Dump an originator entry into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @orig_node: Originator to dump + * @sub_s: Number of sub entries to skip + * + * This function assumes the caller holds rcu_read_lock(). 
+ * + * Return: Error code, or 0 on success + */ +static int +batadv_iv_ogm_orig_dump_entry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct batadv_orig_node *orig_node, int *sub_s) +{ + struct batadv_neigh_node *neigh_node_best; + struct batadv_neigh_node *neigh_node; + int sub = 0; + bool best; + u8 tq_avg_best; + + neigh_node_best = batadv_orig_router_get(orig_node, if_outgoing); + if (!neigh_node_best) + goto out; + + if (!batadv_iv_ogm_neigh_get_tq_avg(neigh_node_best, if_outgoing, + &tq_avg_best)) + goto out; + + if (tq_avg_best == 0) + goto out; + + hlist_for_each_entry_rcu(neigh_node, &orig_node->neigh_list, list) { + if (sub++ < *sub_s) + continue; + + best = (neigh_node == neigh_node_best); + + if (batadv_iv_ogm_orig_dump_subentry(msg, portid, seq, + bat_priv, if_outgoing, + orig_node, neigh_node, + best)) { + batadv_neigh_node_put(neigh_node_best); + + *sub_s = sub - 1; + return -EMSGSIZE; + } + } + + out: + batadv_neigh_node_put(neigh_node_best); + + *sub_s = 0; + return 0; +} + +/** + * batadv_iv_ogm_orig_dump_bucket() - Dump an originator bucket into a + * message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @head: Bucket to be dumped + * @idx_s: Number of entries to be skipped + * @sub: Number of sub entries to be skipped + * + * Return: Error code, or 0 on success + */ +static int +batadv_iv_ogm_orig_dump_bucket(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct hlist_head *head, int *idx_s, int *sub) +{ + struct batadv_orig_node *orig_node; + int idx = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) { + if (idx++ < *idx_s) + continue; + + if (batadv_iv_ogm_orig_dump_entry(msg, portid, seq, bat_priv, + if_outgoing, orig_node, + sub)) { + rcu_read_unlock(); + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + rcu_read_unlock(); + + *idx_s = 0; + *sub = 0; + return 0; +} + +/** + * batadv_iv_ogm_orig_dump() - Dump the originators into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + */ +static void +batadv_iv_ogm_orig_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int sub = cb->args[2]; + int portid = NETLINK_CB(cb->skb).portid; + + while (bucket < hash->size) { + head = &hash->table[bucket]; + + if (batadv_iv_ogm_orig_dump_bucket(msg, portid, + cb->nlh->nlmsg_seq, + bat_priv, if_outgoing, head, + &idx, &sub)) + break; + + bucket++; + } + + cb->args[0] = bucket; + cb->args[1] = idx; + cb->args[2] = sub; +} + +/** + * batadv_iv_ogm_neigh_diff() - calculate tq difference of two neighbors + * @neigh1: the first neighbor object of the comparison + * @if_outgoing1: outgoing interface for the first neighbor + * @neigh2: the second neighbor object of the comparison + * @if_outgoing2: outgoing interface for the second neighbor + * @diff: pointer to integer receiving the calculated 
difference + * + * The content of *@diff is only valid when this function returns true. + * It is less, equal to or greater than 0 if the metric via neigh1 is lower, + * the same as or higher than the metric via neigh2 + * + * Return: true when the difference could be calculated, false otherwise + */ +static bool batadv_iv_ogm_neigh_diff(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2, + int *diff) +{ + struct batadv_neigh_ifinfo *neigh1_ifinfo, *neigh2_ifinfo; + u8 tq1, tq2; + bool ret = true; + + neigh1_ifinfo = batadv_neigh_ifinfo_get(neigh1, if_outgoing1); + neigh2_ifinfo = batadv_neigh_ifinfo_get(neigh2, if_outgoing2); + + if (!neigh1_ifinfo || !neigh2_ifinfo) { + ret = false; + goto out; + } + + tq1 = neigh1_ifinfo->bat_iv.tq_avg; + tq2 = neigh2_ifinfo->bat_iv.tq_avg; + *diff = (int)tq1 - (int)tq2; + +out: + batadv_neigh_ifinfo_put(neigh1_ifinfo); + batadv_neigh_ifinfo_put(neigh2_ifinfo); + + return ret; +} + +/** + * batadv_iv_ogm_neigh_dump_neigh() - Dump a neighbour into a netlink message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @hardif_neigh: Neighbour to be dumped + * + * Return: Error code, or 0 on success + */ +static int +batadv_iv_ogm_neigh_dump_neigh(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_hardif_neigh_node *hardif_neigh) +{ + void *hdr; + unsigned int last_seen_msecs; + + last_seen_msecs = jiffies_to_msecs(jiffies - hardif_neigh->last_seen); + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, + NLM_F_MULTI, BATADV_CMD_GET_NEIGHBORS); + if (!hdr) + return -ENOBUFS; + + if (nla_put(msg, BATADV_ATTR_NEIGH_ADDRESS, ETH_ALEN, + hardif_neigh->addr) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + hardif_neigh->if_incoming->net_dev->name) || + nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + hardif_neigh->if_incoming->net_dev->ifindex) || + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, + last_seen_msecs)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_iv_ogm_neigh_dump_hardif() - Dump the neighbours of a hard interface + * into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @hard_iface: Hard interface to dump the neighbours for + * @idx_s: Number of entries to skip + * + * This function assumes the caller holds rcu_read_lock(). 
+ * + * Return: Error code, or 0 on success + */ +static int +batadv_iv_ogm_neigh_dump_hardif(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface, + int *idx_s) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + int idx = 0; + + hlist_for_each_entry_rcu(hardif_neigh, + &hard_iface->neigh_list, list) { + if (idx++ < *idx_s) + continue; + + if (batadv_iv_ogm_neigh_dump_neigh(msg, portid, seq, + hardif_neigh)) { + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + + *idx_s = 0; + return 0; +} + +/** + * batadv_iv_ogm_neigh_dump() - Dump the neighbours into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @single_hardif: Limit dump to this hard interface + */ +static void +batadv_iv_ogm_neigh_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *single_hardif) +{ + struct batadv_hard_iface *hard_iface; + int i_hardif = 0; + int i_hardif_s = cb->args[0]; + int idx = cb->args[1]; + int portid = NETLINK_CB(cb->skb).portid; + + rcu_read_lock(); + if (single_hardif) { + if (i_hardif_s == 0) { + if (batadv_iv_ogm_neigh_dump_hardif(msg, portid, + cb->nlh->nlmsg_seq, + bat_priv, + single_hardif, + &idx) == 0) + i_hardif++; + } + } else { + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, + list) { + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (i_hardif++ < i_hardif_s) + continue; + + if (batadv_iv_ogm_neigh_dump_hardif(msg, portid, + cb->nlh->nlmsg_seq, + bat_priv, + hard_iface, &idx)) { + i_hardif--; + break; + } + } + } + rcu_read_unlock(); + + cb->args[0] = i_hardif; + cb->args[1] = idx; +} + +/** + * batadv_iv_ogm_neigh_cmp() - compare the metrics of two neighbors + * @neigh1: the first neighbor object of the comparison + * @if_outgoing1: outgoing interface for the first neighbor + * @neigh2: the second neighbor object of the comparison + * @if_outgoing2: outgoing interface for the second neighbor + * + * Return: a value less, equal to or greater than 0 if the metric via neigh1 is + * lower, the same as or higher than the metric via neigh2 + */ +static int batadv_iv_ogm_neigh_cmp(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2) +{ + bool ret; + int diff; + + ret = batadv_iv_ogm_neigh_diff(neigh1, if_outgoing1, neigh2, + if_outgoing2, &diff); + if (!ret) + return 0; + + return diff; +} + +/** + * batadv_iv_ogm_neigh_is_sob() - check if neigh1 is similarly good or better + * than neigh2 from the metric prospective + * @neigh1: the first neighbor object of the comparison + * @if_outgoing1: outgoing interface for the first neighbor + * @neigh2: the second neighbor object of the comparison + * @if_outgoing2: outgoing interface for the second neighbor + * + * Return: true if the metric via neigh1 is equally good or better than + * the metric via neigh2, false otherwise. 
+ */ +static bool +batadv_iv_ogm_neigh_is_sob(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2) +{ + bool ret; + int diff; + + ret = batadv_iv_ogm_neigh_diff(neigh1, if_outgoing1, neigh2, + if_outgoing2, &diff); + if (!ret) + return false; + + ret = diff > -BATADV_TQ_SIMILARITY_THRESHOLD; + return ret; +} + +static void batadv_iv_iface_enabled(struct batadv_hard_iface *hard_iface) +{ + /* begin scheduling originator messages on that interface */ + batadv_iv_ogm_schedule(hard_iface); +} + +/** + * batadv_iv_init_sel_class() - initialize GW selection class + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_iv_init_sel_class(struct batadv_priv *bat_priv) +{ + /* set default TQ difference threshold to 20 */ + atomic_set(&bat_priv->gw.sel_class, 20); +} + +static struct batadv_gw_node * +batadv_iv_gw_get_best_gw_node(struct batadv_priv *bat_priv) +{ + struct batadv_neigh_node *router; + struct batadv_neigh_ifinfo *router_ifinfo; + struct batadv_gw_node *gw_node, *curr_gw = NULL; + u64 max_gw_factor = 0; + u64 tmp_gw_factor = 0; + u8 max_tq = 0; + u8 tq_avg; + struct batadv_orig_node *orig_node; + + rcu_read_lock(); + hlist_for_each_entry_rcu(gw_node, &bat_priv->gw.gateway_list, list) { + orig_node = gw_node->orig_node; + router = batadv_orig_router_get(orig_node, BATADV_IF_DEFAULT); + if (!router) + continue; + + router_ifinfo = batadv_neigh_ifinfo_get(router, + BATADV_IF_DEFAULT); + if (!router_ifinfo) + goto next; + + if (!kref_get_unless_zero(&gw_node->refcount)) + goto next; + + tq_avg = router_ifinfo->bat_iv.tq_avg; + + switch (atomic_read(&bat_priv->gw.sel_class)) { + case 1: /* fast connection */ + tmp_gw_factor = tq_avg * tq_avg; + tmp_gw_factor *= gw_node->bandwidth_down; + tmp_gw_factor *= 100 * 100; + tmp_gw_factor >>= 18; + + if (tmp_gw_factor > max_gw_factor || + (tmp_gw_factor == max_gw_factor && + tq_avg > max_tq)) { + batadv_gw_node_put(curr_gw); + curr_gw = gw_node; + kref_get(&curr_gw->refcount); + } + break; + + default: /* 2: stable connection (use best statistic) + * 3: fast-switch (use best statistic but change as + * soon as a better gateway appears) + * XX: late-switch (use best statistic but change as + * soon as a better gateway appears which has + * $routing_class more tq points) + */ + if (tq_avg > max_tq) { + batadv_gw_node_put(curr_gw); + curr_gw = gw_node; + kref_get(&curr_gw->refcount); + } + break; + } + + if (tq_avg > max_tq) + max_tq = tq_avg; + + if (tmp_gw_factor > max_gw_factor) + max_gw_factor = tmp_gw_factor; + + batadv_gw_node_put(gw_node); + +next: + batadv_neigh_node_put(router); + batadv_neigh_ifinfo_put(router_ifinfo); + } + rcu_read_unlock(); + + return curr_gw; +} + +static bool batadv_iv_gw_is_eligible(struct batadv_priv *bat_priv, + struct batadv_orig_node *curr_gw_orig, + struct batadv_orig_node *orig_node) +{ + struct batadv_neigh_ifinfo *router_orig_ifinfo = NULL; + struct batadv_neigh_ifinfo *router_gw_ifinfo = NULL; + struct batadv_neigh_node *router_gw = NULL; + struct batadv_neigh_node *router_orig = NULL; + u8 gw_tq_avg, orig_tq_avg; + bool ret = false; + + /* dynamic re-election is performed only on fast or late switch */ + if (atomic_read(&bat_priv->gw.sel_class) <= 2) + return false; + + router_gw = batadv_orig_router_get(curr_gw_orig, BATADV_IF_DEFAULT); + if (!router_gw) { + ret = true; + goto out; + } + + router_gw_ifinfo = batadv_neigh_ifinfo_get(router_gw, + BATADV_IF_DEFAULT); + if 
(!router_gw_ifinfo) { + ret = true; + goto out; + } + + router_orig = batadv_orig_router_get(orig_node, BATADV_IF_DEFAULT); + if (!router_orig) + goto out; + + router_orig_ifinfo = batadv_neigh_ifinfo_get(router_orig, + BATADV_IF_DEFAULT); + if (!router_orig_ifinfo) + goto out; + + gw_tq_avg = router_gw_ifinfo->bat_iv.tq_avg; + orig_tq_avg = router_orig_ifinfo->bat_iv.tq_avg; + + /* the TQ value has to be better */ + if (orig_tq_avg < gw_tq_avg) + goto out; + + /* if the routing class is greater than 3 the value tells us how much + * greater the TQ value of the new gateway must be + */ + if ((atomic_read(&bat_priv->gw.sel_class) > 3) && + (orig_tq_avg - gw_tq_avg < atomic_read(&bat_priv->gw.sel_class))) + goto out; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Restarting gateway selection: better gateway found (tq curr: %i, tq new: %i)\n", + gw_tq_avg, orig_tq_avg); + + ret = true; +out: + batadv_neigh_ifinfo_put(router_gw_ifinfo); + batadv_neigh_ifinfo_put(router_orig_ifinfo); + batadv_neigh_node_put(router_gw); + batadv_neigh_node_put(router_orig); + + return ret; +} + +/** + * batadv_iv_gw_dump_entry() - Dump a gateway into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @gw_node: Gateway to be dumped + * + * Return: Error code, or 0 on success + */ +static int batadv_iv_gw_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_gw_node *gw_node) +{ + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + struct batadv_neigh_node *router; + struct batadv_gw_node *curr_gw = NULL; + int ret = 0; + void *hdr; + + router = batadv_orig_router_get(gw_node->orig_node, BATADV_IF_DEFAULT); + if (!router) + goto out; + + router_ifinfo = batadv_neigh_ifinfo_get(router, BATADV_IF_DEFAULT); + if (!router_ifinfo) + goto out; + + curr_gw = batadv_gw_get_selected_gw_node(bat_priv); + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_GATEWAYS); + if (!hdr) { + ret = -ENOBUFS; + goto out; + } + + genl_dump_check_consistent(cb, hdr); + + ret = -EMSGSIZE; + + if (curr_gw == gw_node) + if (nla_put_flag(msg, BATADV_ATTR_FLAG_BEST)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, + gw_node->orig_node->orig) || + nla_put_u8(msg, BATADV_ATTR_TQ, router_ifinfo->bat_iv.tq_avg) || + nla_put(msg, BATADV_ATTR_ROUTER, ETH_ALEN, + router->addr) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + router->if_incoming->net_dev->name) || + nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + router->if_incoming->net_dev->ifindex) || + nla_put_u32(msg, BATADV_ATTR_BANDWIDTH_DOWN, + gw_node->bandwidth_down) || + nla_put_u32(msg, BATADV_ATTR_BANDWIDTH_UP, + gw_node->bandwidth_up)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + genlmsg_end(msg, hdr); + ret = 0; + +out: + batadv_gw_node_put(curr_gw); + batadv_neigh_ifinfo_put(router_ifinfo); + batadv_neigh_node_put(router); + return ret; +} + +/** + * batadv_iv_gw_dump() - Dump gateways into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + */ +static void batadv_iv_gw_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv) +{ + int portid = NETLINK_CB(cb->skb).portid; + struct batadv_gw_node 
*gw_node; + int idx_skip = cb->args[0]; + int idx = 0; + + spin_lock_bh(&bat_priv->gw.list_lock); + cb->seq = bat_priv->gw.generation << 1 | 1; + + hlist_for_each_entry(gw_node, &bat_priv->gw.gateway_list, list) { + if (idx++ < idx_skip) + continue; + + if (batadv_iv_gw_dump_entry(msg, portid, cb, bat_priv, + gw_node)) { + idx_skip = idx - 1; + goto unlock; + } + } + + idx_skip = idx; +unlock: + spin_unlock_bh(&bat_priv->gw.list_lock); + + cb->args[0] = idx_skip; +} + +static struct batadv_algo_ops batadv_batman_iv __read_mostly = { + .name = "BATMAN_IV", + .iface = { + .enable = batadv_iv_ogm_iface_enable, + .enabled = batadv_iv_iface_enabled, + .disable = batadv_iv_ogm_iface_disable, + .update_mac = batadv_iv_ogm_iface_update_mac, + .primary_set = batadv_iv_ogm_primary_iface_set, + }, + .neigh = { + .cmp = batadv_iv_ogm_neigh_cmp, + .is_similar_or_better = batadv_iv_ogm_neigh_is_sob, + .dump = batadv_iv_ogm_neigh_dump, + }, + .orig = { + .dump = batadv_iv_ogm_orig_dump, + }, + .gw = { + .init_sel_class = batadv_iv_init_sel_class, + .get_best_gw_node = batadv_iv_gw_get_best_gw_node, + .is_eligible = batadv_iv_gw_is_eligible, + .dump = batadv_iv_gw_dump, + }, +}; + +/** + * batadv_iv_init() - B.A.T.M.A.N. IV initialization function + * + * Return: 0 on success or negative error number in case of failure + */ +int __init batadv_iv_init(void) +{ + int ret; + + /* batman originator packet */ + ret = batadv_recv_handler_register(BATADV_IV_OGM, + batadv_iv_ogm_receive); + if (ret < 0) + goto out; + + ret = batadv_algo_register(&batadv_batman_iv); + if (ret < 0) + goto handler_unregister; + + goto out; + +handler_unregister: + batadv_recv_handler_unregister(BATADV_IV_OGM); +out: + return ret; +} diff --git a/net/batman-adv/bat_iv_ogm.h b/net/batman-adv/bat_iv_ogm.h new file mode 100644 index 000000000..04b01bd68 --- /dev/null +++ b/net/batman-adv/bat_iv_ogm.h @@ -0,0 +1,14 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_BAT_IV_OGM_H_ +#define _NET_BATMAN_ADV_BAT_IV_OGM_H_ + +#include "main.h" + +int batadv_iv_init(void); + +#endif /* _NET_BATMAN_ADV_BAT_IV_OGM_H_ */ diff --git a/net/batman-adv/bat_v.c b/net/batman-adv/bat_v.c new file mode 100644 index 000000000..54e41fc70 --- /dev/null +++ b/net/batman-adv/bat_v.c @@ -0,0 +1,910 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Linus Lüssing, Marek Lindner + */ + +#include "bat_v.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bat_v_elp.h" +#include "bat_v_ogm.h" +#include "gateway_client.h" +#include "gateway_common.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" + +static void batadv_v_iface_activate(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + struct batadv_hard_iface *primary_if; + + primary_if = batadv_primary_if_get_selected(bat_priv); + + if (primary_if) { + batadv_v_elp_iface_activate(primary_if, hard_iface); + batadv_hardif_put(primary_if); + } + + /* B.A.T.M.A.N. 
V does not use any queuing mechanism, therefore it can + * set the interface as ACTIVE right away, without any risk of race + * condition + */ + if (hard_iface->if_status == BATADV_IF_TO_BE_ACTIVATED) + hard_iface->if_status = BATADV_IF_ACTIVE; +} + +static int batadv_v_iface_enable(struct batadv_hard_iface *hard_iface) +{ + int ret; + + ret = batadv_v_elp_iface_enable(hard_iface); + if (ret < 0) + return ret; + + ret = batadv_v_ogm_iface_enable(hard_iface); + if (ret < 0) + batadv_v_elp_iface_disable(hard_iface); + + return ret; +} + +static void batadv_v_iface_disable(struct batadv_hard_iface *hard_iface) +{ + batadv_v_ogm_iface_disable(hard_iface); + batadv_v_elp_iface_disable(hard_iface); +} + +static void batadv_v_primary_iface_set(struct batadv_hard_iface *hard_iface) +{ + batadv_v_elp_primary_iface_set(hard_iface); + batadv_v_ogm_primary_iface_set(hard_iface); +} + +/** + * batadv_v_iface_update_mac() - react to hard-interface MAC address change + * @hard_iface: the modified interface + * + * If the modified interface is the primary one, update the originator + * address in the ELP and OGM messages to reflect the new MAC address. + */ +static void batadv_v_iface_update_mac(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + struct batadv_hard_iface *primary_if; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (primary_if != hard_iface) + goto out; + + batadv_v_primary_iface_set(hard_iface); +out: + batadv_hardif_put(primary_if); +} + +static void +batadv_v_hardif_neigh_init(struct batadv_hardif_neigh_node *hardif_neigh) +{ + ewma_throughput_init(&hardif_neigh->bat_v.throughput); + INIT_WORK(&hardif_neigh->bat_v.metric_work, + batadv_v_elp_throughput_metric_update); +} + +/** + * batadv_v_neigh_dump_neigh() - Dump a neighbour into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @hardif_neigh: Neighbour to dump + * + * Return: Error code, or 0 on success + */ +static int +batadv_v_neigh_dump_neigh(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_hardif_neigh_node *hardif_neigh) +{ + void *hdr; + unsigned int last_seen_msecs; + u32 throughput; + + last_seen_msecs = jiffies_to_msecs(jiffies - hardif_neigh->last_seen); + throughput = ewma_throughput_read(&hardif_neigh->bat_v.throughput); + throughput = throughput * 100; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_NEIGHBORS); + if (!hdr) + return -ENOBUFS; + + if (nla_put(msg, BATADV_ATTR_NEIGH_ADDRESS, ETH_ALEN, + hardif_neigh->addr) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + hardif_neigh->if_incoming->net_dev->name) || + nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + hardif_neigh->if_incoming->net_dev->ifindex) || + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, + last_seen_msecs) || + nla_put_u32(msg, BATADV_ATTR_THROUGHPUT, throughput)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_v_neigh_dump_hardif() - Dump the neighbours of a hard interface into + * a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @hard_iface: The hard interface to be dumped + * @idx_s: Entries to be skipped + * + * This function assumes the caller holds rcu_read_lock(). 
+ * + * Return: Error code, or 0 on success + */ +static int +batadv_v_neigh_dump_hardif(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface, + int *idx_s) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + int idx = 0; + + hlist_for_each_entry_rcu(hardif_neigh, + &hard_iface->neigh_list, list) { + if (idx++ < *idx_s) + continue; + + if (batadv_v_neigh_dump_neigh(msg, portid, seq, hardif_neigh)) { + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + + *idx_s = 0; + return 0; +} + +/** + * batadv_v_neigh_dump() - Dump the neighbours of a hard interface into a + * message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @single_hardif: Limit dumping to this hard interface + */ +static void +batadv_v_neigh_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *single_hardif) +{ + struct batadv_hard_iface *hard_iface; + int i_hardif = 0; + int i_hardif_s = cb->args[0]; + int idx = cb->args[1]; + int portid = NETLINK_CB(cb->skb).portid; + + rcu_read_lock(); + if (single_hardif) { + if (i_hardif_s == 0) { + if (batadv_v_neigh_dump_hardif(msg, portid, + cb->nlh->nlmsg_seq, + bat_priv, single_hardif, + &idx) == 0) + i_hardif++; + } + } else { + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (i_hardif++ < i_hardif_s) + continue; + + if (batadv_v_neigh_dump_hardif(msg, portid, + cb->nlh->nlmsg_seq, + bat_priv, hard_iface, + &idx)) { + i_hardif--; + break; + } + } + } + rcu_read_unlock(); + + cb->args[0] = i_hardif; + cb->args[1] = idx; +} + +/** + * batadv_v_orig_dump_subentry() - Dump an originator subentry into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @orig_node: Originator to dump + * @neigh_node: Single hops neighbour + * @best: Is the best originator + * + * Return: Error code, or 0 on success + */ +static int +batadv_v_orig_dump_subentry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + bool best) +{ + struct batadv_neigh_ifinfo *n_ifinfo; + unsigned int last_seen_msecs; + u32 throughput; + void *hdr; + + n_ifinfo = batadv_neigh_ifinfo_get(neigh_node, if_outgoing); + if (!n_ifinfo) + return 0; + + throughput = n_ifinfo->bat_v.throughput * 100; + + batadv_neigh_ifinfo_put(n_ifinfo); + + last_seen_msecs = jiffies_to_msecs(jiffies - orig_node->last_seen); + + if (if_outgoing != BATADV_IF_DEFAULT && + if_outgoing != neigh_node->if_incoming) + return 0; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_ORIGINATORS); + if (!hdr) + return -ENOBUFS; + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, orig_node->orig) || + nla_put(msg, BATADV_ATTR_NEIGH_ADDRESS, ETH_ALEN, + neigh_node->addr) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + neigh_node->if_incoming->net_dev->name) || + nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + neigh_node->if_incoming->net_dev->ifindex) || + nla_put_u32(msg, BATADV_ATTR_THROUGHPUT, throughput) || + nla_put_u32(msg, 
BATADV_ATTR_LAST_SEEN_MSECS, + last_seen_msecs)) + goto nla_put_failure; + + if (best && nla_put_flag(msg, BATADV_ATTR_FLAG_BEST)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_v_orig_dump_entry() - Dump an originator entry into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @orig_node: Originator to dump + * @sub_s: Number of sub entries to skip + * + * This function assumes the caller holds rcu_read_lock(). + * + * Return: Error code, or 0 on success + */ +static int +batadv_v_orig_dump_entry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct batadv_orig_node *orig_node, int *sub_s) +{ + struct batadv_neigh_node *neigh_node_best; + struct batadv_neigh_node *neigh_node; + int sub = 0; + bool best; + + neigh_node_best = batadv_orig_router_get(orig_node, if_outgoing); + if (!neigh_node_best) + goto out; + + hlist_for_each_entry_rcu(neigh_node, &orig_node->neigh_list, list) { + if (sub++ < *sub_s) + continue; + + best = (neigh_node == neigh_node_best); + + if (batadv_v_orig_dump_subentry(msg, portid, seq, bat_priv, + if_outgoing, orig_node, + neigh_node, best)) { + batadv_neigh_node_put(neigh_node_best); + + *sub_s = sub - 1; + return -EMSGSIZE; + } + } + + out: + batadv_neigh_node_put(neigh_node_best); + + *sub_s = 0; + return 0; +} + +/** + * batadv_v_orig_dump_bucket() - Dump an originator bucket into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + * @head: Bucket to be dumped + * @idx_s: Number of entries to be skipped + * @sub: Number of sub entries to be skipped + * + * Return: Error code, or 0 on success + */ +static int +batadv_v_orig_dump_bucket(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing, + struct hlist_head *head, int *idx_s, int *sub) +{ + struct batadv_orig_node *orig_node; + int idx = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) { + if (idx++ < *idx_s) + continue; + + if (batadv_v_orig_dump_entry(msg, portid, seq, bat_priv, + if_outgoing, orig_node, sub)) { + rcu_read_unlock(); + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + rcu_read_unlock(); + + *idx_s = 0; + *sub = 0; + return 0; +} + +/** + * batadv_v_orig_dump() - Dump the originators into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @if_outgoing: Limit dump to entries with this outgoing interface + */ +static void +batadv_v_orig_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int sub = cb->args[2]; + int portid = NETLINK_CB(cb->skb).portid; + + while (bucket < hash->size) { + head = &hash->table[bucket]; + + if (batadv_v_orig_dump_bucket(msg, portid, + 
cb->nlh->nlmsg_seq, + bat_priv, if_outgoing, head, &idx, + &sub)) + break; + + bucket++; + } + + cb->args[0] = bucket; + cb->args[1] = idx; + cb->args[2] = sub; +} + +static int batadv_v_neigh_cmp(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2) +{ + struct batadv_neigh_ifinfo *ifinfo1, *ifinfo2; + int ret = 0; + + ifinfo1 = batadv_neigh_ifinfo_get(neigh1, if_outgoing1); + if (!ifinfo1) + goto err_ifinfo1; + + ifinfo2 = batadv_neigh_ifinfo_get(neigh2, if_outgoing2); + if (!ifinfo2) + goto err_ifinfo2; + + ret = ifinfo1->bat_v.throughput - ifinfo2->bat_v.throughput; + + batadv_neigh_ifinfo_put(ifinfo2); +err_ifinfo2: + batadv_neigh_ifinfo_put(ifinfo1); +err_ifinfo1: + return ret; +} + +static bool batadv_v_neigh_is_sob(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2) +{ + struct batadv_neigh_ifinfo *ifinfo1, *ifinfo2; + u32 threshold; + bool ret = false; + + ifinfo1 = batadv_neigh_ifinfo_get(neigh1, if_outgoing1); + if (!ifinfo1) + goto err_ifinfo1; + + ifinfo2 = batadv_neigh_ifinfo_get(neigh2, if_outgoing2); + if (!ifinfo2) + goto err_ifinfo2; + + threshold = ifinfo1->bat_v.throughput / 4; + threshold = ifinfo1->bat_v.throughput - threshold; + + ret = ifinfo2->bat_v.throughput > threshold; + + batadv_neigh_ifinfo_put(ifinfo2); +err_ifinfo2: + batadv_neigh_ifinfo_put(ifinfo1); +err_ifinfo1: + return ret; +} + +/** + * batadv_v_init_sel_class() - initialize GW selection class + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_v_init_sel_class(struct batadv_priv *bat_priv) +{ + /* set default throughput difference threshold to 5Mbps */ + atomic_set(&bat_priv->gw.sel_class, 50); +} + +static ssize_t batadv_v_store_sel_class(struct batadv_priv *bat_priv, + char *buff, size_t count) +{ + u32 old_class, class; + + if (!batadv_parse_throughput(bat_priv->soft_iface, buff, + "B.A.T.M.A.N. V GW selection class", + &class)) + return -EINVAL; + + old_class = atomic_read(&bat_priv->gw.sel_class); + atomic_set(&bat_priv->gw.sel_class, class); + + if (old_class != class) + batadv_gw_reselect(bat_priv); + + return count; +} + +/** + * batadv_v_gw_throughput_get() - retrieve the GW-bandwidth for a given GW + * @gw_node: the GW to retrieve the metric for + * @bw: the pointer where the metric will be stored. The metric is computed as + * the minimum between the GW advertised throughput and the path throughput to + * it in the mesh + * + * Return: 0 on success, -1 on failure + */ +static int batadv_v_gw_throughput_get(struct batadv_gw_node *gw_node, u32 *bw) +{ + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + struct batadv_orig_node *orig_node; + struct batadv_neigh_node *router; + int ret = -1; + + orig_node = gw_node->orig_node; + router = batadv_orig_router_get(orig_node, BATADV_IF_DEFAULT); + if (!router) + goto out; + + router_ifinfo = batadv_neigh_ifinfo_get(router, BATADV_IF_DEFAULT); + if (!router_ifinfo) + goto out; + + /* the GW metric is computed as the minimum between the path throughput + * to reach the GW itself and the advertised bandwidth. 
+ * This gives us an approximation of the effective throughput that the + * client can expect via this particular GW node + */ + *bw = router_ifinfo->bat_v.throughput; + *bw = min_t(u32, *bw, gw_node->bandwidth_down); + + ret = 0; +out: + batadv_neigh_node_put(router); + batadv_neigh_ifinfo_put(router_ifinfo); + + return ret; +} + +/** + * batadv_v_gw_get_best_gw_node() - retrieve the best GW node + * @bat_priv: the bat priv with all the soft interface information + * + * Return: the GW node having the best GW-metric, NULL if no GW is known + */ +static struct batadv_gw_node * +batadv_v_gw_get_best_gw_node(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *gw_node, *curr_gw = NULL; + u32 max_bw = 0, bw; + + rcu_read_lock(); + hlist_for_each_entry_rcu(gw_node, &bat_priv->gw.gateway_list, list) { + if (!kref_get_unless_zero(&gw_node->refcount)) + continue; + + if (batadv_v_gw_throughput_get(gw_node, &bw) < 0) + goto next; + + if (curr_gw && bw <= max_bw) + goto next; + + batadv_gw_node_put(curr_gw); + + curr_gw = gw_node; + kref_get(&curr_gw->refcount); + max_bw = bw; + +next: + batadv_gw_node_put(gw_node); + } + rcu_read_unlock(); + + return curr_gw; +} + +/** + * batadv_v_gw_is_eligible() - check if a originator would be selected as GW + * @bat_priv: the bat priv with all the soft interface information + * @curr_gw_orig: originator representing the currently selected GW + * @orig_node: the originator representing the new candidate + * + * Return: true if orig_node can be selected as current GW, false otherwise + */ +static bool batadv_v_gw_is_eligible(struct batadv_priv *bat_priv, + struct batadv_orig_node *curr_gw_orig, + struct batadv_orig_node *orig_node) +{ + struct batadv_gw_node *curr_gw, *orig_gw = NULL; + u32 gw_throughput, orig_throughput, threshold; + bool ret = false; + + threshold = atomic_read(&bat_priv->gw.sel_class); + + curr_gw = batadv_gw_node_get(bat_priv, curr_gw_orig); + if (!curr_gw) { + ret = true; + goto out; + } + + if (batadv_v_gw_throughput_get(curr_gw, &gw_throughput) < 0) { + ret = true; + goto out; + } + + orig_gw = batadv_gw_node_get(bat_priv, orig_node); + if (!orig_gw) + goto out; + + if (batadv_v_gw_throughput_get(orig_gw, &orig_throughput) < 0) + goto out; + + if (orig_throughput < gw_throughput) + goto out; + + if ((orig_throughput - gw_throughput) < threshold) + goto out; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Restarting gateway selection: better gateway found (throughput curr: %u, throughput new: %u)\n", + gw_throughput, orig_throughput); + + ret = true; +out: + batadv_gw_node_put(curr_gw); + batadv_gw_node_put(orig_gw); + + return ret; +} + +/** + * batadv_v_gw_dump_entry() - Dump a gateway into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @gw_node: Gateway to be dumped + * + * Return: Error code, or 0 on success + */ +static int batadv_v_gw_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_gw_node *gw_node) +{ + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + struct batadv_neigh_node *router; + struct batadv_gw_node *curr_gw = NULL; + int ret = 0; + void *hdr; + + router = batadv_orig_router_get(gw_node->orig_node, BATADV_IF_DEFAULT); + if (!router) + goto out; + + router_ifinfo = batadv_neigh_ifinfo_get(router, BATADV_IF_DEFAULT); + if (!router_ifinfo) + goto out; + + curr_gw = 
batadv_gw_get_selected_gw_node(bat_priv); + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_GATEWAYS); + if (!hdr) { + ret = -ENOBUFS; + goto out; + } + + genl_dump_check_consistent(cb, hdr); + + ret = -EMSGSIZE; + + if (curr_gw == gw_node) { + if (nla_put_flag(msg, BATADV_ATTR_FLAG_BEST)) { + genlmsg_cancel(msg, hdr); + goto out; + } + } + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, + gw_node->orig_node->orig)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put_u32(msg, BATADV_ATTR_THROUGHPUT, + router_ifinfo->bat_v.throughput)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put(msg, BATADV_ATTR_ROUTER, ETH_ALEN, router->addr)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + router->if_incoming->net_dev->name)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + router->if_incoming->net_dev->ifindex)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put_u32(msg, BATADV_ATTR_BANDWIDTH_DOWN, + gw_node->bandwidth_down)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put_u32(msg, BATADV_ATTR_BANDWIDTH_UP, gw_node->bandwidth_up)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + genlmsg_end(msg, hdr); + ret = 0; + +out: + batadv_gw_node_put(curr_gw); + batadv_neigh_ifinfo_put(router_ifinfo); + batadv_neigh_node_put(router); + return ret; +} + +/** + * batadv_v_gw_dump() - Dump gateways into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + */ +static void batadv_v_gw_dump(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *bat_priv) +{ + int portid = NETLINK_CB(cb->skb).portid; + struct batadv_gw_node *gw_node; + int idx_skip = cb->args[0]; + int idx = 0; + + spin_lock_bh(&bat_priv->gw.list_lock); + cb->seq = bat_priv->gw.generation << 1 | 1; + + hlist_for_each_entry(gw_node, &bat_priv->gw.gateway_list, list) { + if (idx++ < idx_skip) + continue; + + if (batadv_v_gw_dump_entry(msg, portid, cb, bat_priv, + gw_node)) { + idx_skip = idx - 1; + goto unlock; + } + } + + idx_skip = idx; +unlock: + spin_unlock_bh(&bat_priv->gw.list_lock); + + cb->args[0] = idx_skip; +} + +static struct batadv_algo_ops batadv_batman_v __read_mostly = { + .name = "BATMAN_V", + .iface = { + .activate = batadv_v_iface_activate, + .enable = batadv_v_iface_enable, + .disable = batadv_v_iface_disable, + .update_mac = batadv_v_iface_update_mac, + .primary_set = batadv_v_primary_iface_set, + }, + .neigh = { + .hardif_init = batadv_v_hardif_neigh_init, + .cmp = batadv_v_neigh_cmp, + .is_similar_or_better = batadv_v_neigh_is_sob, + .dump = batadv_v_neigh_dump, + }, + .orig = { + .dump = batadv_v_orig_dump, + }, + .gw = { + .init_sel_class = batadv_v_init_sel_class, + .store_sel_class = batadv_v_store_sel_class, + .get_best_gw_node = batadv_v_gw_get_best_gw_node, + .is_eligible = batadv_v_gw_is_eligible, + .dump = batadv_v_gw_dump, + }, +}; + +/** + * batadv_v_hardif_init() - initialize the algorithm specific fields in the + * hard-interface object + * @hard_iface: the hard-interface to initialize + */ +void batadv_v_hardif_init(struct batadv_hard_iface *hard_iface) +{ + /* enable link throughput auto-detection by setting the throughput + * override to zero + */ + atomic_set(&hard_iface->bat_v.throughput_override, 0); + atomic_set(&hard_iface->bat_v.elp_interval, 500); + + 
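/* the values set above are interpreted by the ELP code: the throughput + * override is in multiples of 100 Kbit/s (zero keeps auto-detection + * enabled) and elp_interval is the ELP broadcast period in milliseconds + */ +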
hard_iface->bat_v.aggr_len = 0; + skb_queue_head_init(&hard_iface->bat_v.aggr_list); + INIT_DELAYED_WORK(&hard_iface->bat_v.aggr_wq, + batadv_v_ogm_aggr_work); +} + +/** + * batadv_v_mesh_init() - initialize the B.A.T.M.A.N. V private resources for a + * mesh + * @bat_priv: the object representing the mesh interface to initialise + * + * Return: 0 on success or a negative error code otherwise + */ +int batadv_v_mesh_init(struct batadv_priv *bat_priv) +{ + int ret = 0; + + ret = batadv_v_ogm_init(bat_priv); + if (ret < 0) + return ret; + + return 0; +} + +/** + * batadv_v_mesh_free() - free the B.A.T.M.A.N. V private resources for a mesh + * @bat_priv: the object representing the mesh interface to free + */ +void batadv_v_mesh_free(struct batadv_priv *bat_priv) +{ + batadv_v_ogm_free(bat_priv); +} + +/** + * batadv_v_init() - B.A.T.M.A.N. V initialization function + * + * Description: Takes care of initializing all the subcomponents. + * It is invoked upon module load only. + * + * Return: 0 on success or a negative error code otherwise + */ +int __init batadv_v_init(void) +{ + int ret; + + /* B.A.T.M.A.N. V echo location protocol packet */ + ret = batadv_recv_handler_register(BATADV_ELP, + batadv_v_elp_packet_recv); + if (ret < 0) + return ret; + + ret = batadv_recv_handler_register(BATADV_OGM2, + batadv_v_ogm_packet_recv); + if (ret < 0) + goto elp_unregister; + + ret = batadv_algo_register(&batadv_batman_v); + if (ret < 0) + goto ogm_unregister; + + return ret; + +ogm_unregister: + batadv_recv_handler_unregister(BATADV_OGM2); + +elp_unregister: + batadv_recv_handler_unregister(BATADV_ELP); + + return ret; +} diff --git a/net/batman-adv/bat_v.h b/net/batman-adv/bat_v.h new file mode 100644 index 000000000..964431f4d --- /dev/null +++ b/net/batman-adv/bat_v.h @@ -0,0 +1,41 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Linus Lüssing + */ + +#ifndef _NET_BATMAN_ADV_BAT_V_H_ +#define _NET_BATMAN_ADV_BAT_V_H_ + +#include "main.h" + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + +int batadv_v_init(void); +void batadv_v_hardif_init(struct batadv_hard_iface *hardif); +int batadv_v_mesh_init(struct batadv_priv *bat_priv); +void batadv_v_mesh_free(struct batadv_priv *bat_priv); + +#else + +static inline int batadv_v_init(void) +{ + return 0; +} + +static inline void batadv_v_hardif_init(struct batadv_hard_iface *hardif) +{ +} + +static inline int batadv_v_mesh_init(struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline void batadv_v_mesh_free(struct batadv_priv *bat_priv) +{ +} + +#endif /* CONFIG_BATMAN_ADV_BATMAN_V */ + +#endif /* _NET_BATMAN_ADV_BAT_V_H_ */ diff --git a/net/batman-adv/bat_v_elp.c b/net/batman-adv/bat_v_elp.c new file mode 100644 index 000000000..98a624f32 --- /dev/null +++ b/net/batman-adv/bat_v_elp.c @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Linus Lüssing, Marek Lindner + */ + +#include "bat_v_elp.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bat_v_ogm.h" +#include "hard-interface.h" +#include "log.h" +#include "originator.h" +#include "routing.h" +#include "send.h" + +/** + * batadv_v_elp_start_timer() - restart timer for ELP periodic work + * @hard_iface: the interface for which the timer has to be reset + */ +static void batadv_v_elp_start_timer(struct batadv_hard_iface *hard_iface) +{ + unsigned int msecs; + + msecs = atomic_read(&hard_iface->bat_v.elp_interval) - BATADV_JITTER; + msecs += prandom_u32_max(2 * BATADV_JITTER); + + queue_delayed_work(batadv_event_workqueue, &hard_iface->bat_v.elp_wq, + msecs_to_jiffies(msecs)); +} + +/** + * batadv_v_elp_get_throughput() - get the throughput towards a neighbour + * @neigh: the neighbour for which the throughput has to be obtained + * + * Return: The throughput towards the given neighbour in multiples of 100kpbs + * (a value of '1' equals 0.1Mbps, '10' equals 1Mbps, etc). + */ +static u32 batadv_v_elp_get_throughput(struct batadv_hardif_neigh_node *neigh) +{ + struct batadv_hard_iface *hard_iface = neigh->if_incoming; + struct ethtool_link_ksettings link_settings; + struct net_device *real_netdev; + struct station_info sinfo; + u32 throughput; + int ret; + + /* if the user specified a customised value for this interface, then + * return it directly + */ + throughput = atomic_read(&hard_iface->bat_v.throughput_override); + if (throughput != 0) + return throughput; + + /* if this is a wireless device, then ask its throughput through + * cfg80211 API + */ + if (batadv_is_wifi_hardif(hard_iface)) { + if (!batadv_is_cfg80211_hardif(hard_iface)) + /* unsupported WiFi driver version */ + goto default_throughput; + + real_netdev = batadv_get_real_netdev(hard_iface->net_dev); + if (!real_netdev) + goto default_throughput; + + ret = cfg80211_get_station(real_netdev, neigh->addr, &sinfo); + + if (!ret) { + /* free the TID stats immediately */ + cfg80211_sinfo_release_content(&sinfo); + } + + dev_put(real_netdev); + if (ret == -ENOENT) { + /* Node is not associated anymore! It would be + * possible to delete this neighbor. For now set + * the throughput metric to 0. + */ + return 0; + } + if (ret) + goto default_throughput; + + if (sinfo.filled & BIT(NL80211_STA_INFO_EXPECTED_THROUGHPUT)) + return sinfo.expected_throughput / 100; + + /* try to estimate the expected throughput based on reported tx + * rates + */ + if (sinfo.filled & BIT(NL80211_STA_INFO_TX_BITRATE)) + return cfg80211_calculate_bitrate(&sinfo.txrate) / 3; + + goto default_throughput; + } + + /* if not a wifi interface, check if this device provides data via + * ethtool (e.g. 
an Ethernet adapter) + */ + rtnl_lock(); + ret = __ethtool_get_link_ksettings(hard_iface->net_dev, &link_settings); + rtnl_unlock(); + if (ret == 0) { + /* link characteristics might change over time */ + if (link_settings.base.duplex == DUPLEX_FULL) + hard_iface->bat_v.flags |= BATADV_FULL_DUPLEX; + else + hard_iface->bat_v.flags &= ~BATADV_FULL_DUPLEX; + + throughput = link_settings.base.speed; + if (throughput && throughput != SPEED_UNKNOWN) + return throughput * 10; + } + +default_throughput: + if (!(hard_iface->bat_v.flags & BATADV_WARNING_DEFAULT)) { + batadv_info(hard_iface->soft_iface, + "WiFi driver or ethtool info does not provide information about link speeds on interface %s, therefore defaulting to hardcoded throughput values of %u.%1u Mbps. Consider overriding the throughput manually or checking your driver.\n", + hard_iface->net_dev->name, + BATADV_THROUGHPUT_DEFAULT_VALUE / 10, + BATADV_THROUGHPUT_DEFAULT_VALUE % 10); + hard_iface->bat_v.flags |= BATADV_WARNING_DEFAULT; + } + + /* if none of the above cases apply, return the base_throughput */ + return BATADV_THROUGHPUT_DEFAULT_VALUE; +} + +/** + * batadv_v_elp_throughput_metric_update() - worker updating the throughput + * metric of a single hop neighbour + * @work: the work queue item + */ +void batadv_v_elp_throughput_metric_update(struct work_struct *work) +{ + struct batadv_hardif_neigh_node_bat_v *neigh_bat_v; + struct batadv_hardif_neigh_node *neigh; + + neigh_bat_v = container_of(work, struct batadv_hardif_neigh_node_bat_v, + metric_work); + neigh = container_of(neigh_bat_v, struct batadv_hardif_neigh_node, + bat_v); + + ewma_throughput_add(&neigh->bat_v.throughput, + batadv_v_elp_get_throughput(neigh)); + + /* decrement refcounter to balance increment performed before scheduling + * this task + */ + batadv_hardif_neigh_put(neigh); +} + +/** + * batadv_v_elp_wifi_neigh_probe() - send link probing packets to a neighbour + * @neigh: the neighbour to probe + * + * Sends a predefined number of unicast wifi packets to a given neighbour in + * order to trigger the throughput estimation on this link by the RC algorithm. + * Packets are sent only if there is not enough payload unicast traffic towards + * this neighbour.. + * + * Return: True on success and false in case of error during skb preparation. + */ +static bool +batadv_v_elp_wifi_neigh_probe(struct batadv_hardif_neigh_node *neigh) +{ + struct batadv_hard_iface *hard_iface = neigh->if_incoming; + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + unsigned long last_tx_diff; + struct sk_buff *skb; + int probe_len, i; + int elp_skb_len; + + /* this probing routine is for Wifi neighbours only */ + if (!batadv_is_wifi_hardif(hard_iface)) + return true; + + /* probe the neighbor only if no unicast packets have been sent + * to it in the last 100 milliseconds: this is the rate control + * algorithm sampling interval (minstrel). 
In this way, if not + * enough traffic has been sent to the neighbor, batman-adv can + * generate 2 probe packets and push the RC algorithm to perform + * the sampling + */ + last_tx_diff = jiffies_to_msecs(jiffies - neigh->bat_v.last_unicast_tx); + if (last_tx_diff <= BATADV_ELP_PROBE_MAX_TX_DIFF) + return true; + + probe_len = max_t(int, sizeof(struct batadv_elp_packet), + BATADV_ELP_MIN_PROBE_SIZE); + + for (i = 0; i < BATADV_ELP_PROBES_PER_NODE; i++) { + elp_skb_len = hard_iface->bat_v.elp_skb->len; + skb = skb_copy_expand(hard_iface->bat_v.elp_skb, 0, + probe_len - elp_skb_len, + GFP_ATOMIC); + if (!skb) + return false; + + /* Tell the skb to get as big as the allocated space (we want + * the packet to be exactly of that size to make the link + * throughput estimation effective. + */ + skb_put_zero(skb, probe_len - hard_iface->bat_v.elp_skb->len); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Sending unicast (probe) ELP packet on interface %s to %pM\n", + hard_iface->net_dev->name, neigh->addr); + + batadv_send_skb_packet(skb, hard_iface, neigh->addr); + } + + return true; +} + +/** + * batadv_v_elp_periodic_work() - ELP periodic task per interface + * @work: work queue item + * + * Emits broadcast ELP messages in regular intervals. + */ +static void batadv_v_elp_periodic_work(struct work_struct *work) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + struct batadv_hard_iface *hard_iface; + struct batadv_hard_iface_bat_v *bat_v; + struct batadv_elp_packet *elp_packet; + struct batadv_priv *bat_priv; + struct sk_buff *skb; + u32 elp_interval; + bool ret; + + bat_v = container_of(work, struct batadv_hard_iface_bat_v, elp_wq.work); + hard_iface = container_of(bat_v, struct batadv_hard_iface, bat_v); + bat_priv = netdev_priv(hard_iface->soft_iface); + + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_DEACTIVATING) + goto out; + + /* we are in the process of shutting this interface down */ + if (hard_iface->if_status == BATADV_IF_NOT_IN_USE || + hard_iface->if_status == BATADV_IF_TO_BE_REMOVED) + goto out; + + /* the interface was enabled but may not be ready yet */ + if (hard_iface->if_status != BATADV_IF_ACTIVE) + goto restart_timer; + + skb = skb_copy(hard_iface->bat_v.elp_skb, GFP_ATOMIC); + if (!skb) + goto restart_timer; + + elp_packet = (struct batadv_elp_packet *)skb->data; + elp_packet->seqno = htonl(atomic_read(&hard_iface->bat_v.elp_seqno)); + elp_interval = atomic_read(&hard_iface->bat_v.elp_interval); + elp_packet->elp_interval = htonl(elp_interval); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Sending broadcast ELP packet on interface %s, seqno %u\n", + hard_iface->net_dev->name, + atomic_read(&hard_iface->bat_v.elp_seqno)); + + batadv_send_broadcast_skb(skb, hard_iface); + + atomic_inc(&hard_iface->bat_v.elp_seqno); + + /* The throughput metric is updated on each sent packet. This way, if a + * node is dead and no longer sends packets, batman-adv is still able to + * react timely to its death. 
+ * + * The throughput metric is updated by following these steps: + * 1) if the hard_iface is wifi => send a number of unicast ELPs for + * probing/sampling to each neighbor + * 2) update the throughput metric value of each neighbor (note that the + * value retrieved in this step might be 100ms old because the + * probing packets at point 1) could still be in the HW queue) + */ + rcu_read_lock(); + hlist_for_each_entry_rcu(hardif_neigh, &hard_iface->neigh_list, list) { + if (!batadv_v_elp_wifi_neigh_probe(hardif_neigh)) + /* if something goes wrong while probing, better to stop + * sending packets immediately and reschedule the task + */ + break; + + if (!kref_get_unless_zero(&hardif_neigh->refcount)) + continue; + + /* Reading the estimated throughput from cfg80211 is a task that + * may sleep and that is not allowed in an rcu protected + * context. Therefore schedule a task for that. + */ + ret = queue_work(batadv_event_workqueue, + &hardif_neigh->bat_v.metric_work); + + if (!ret) + batadv_hardif_neigh_put(hardif_neigh); + } + rcu_read_unlock(); + +restart_timer: + batadv_v_elp_start_timer(hard_iface); +out: + return; +} + +/** + * batadv_v_elp_iface_enable() - setup the ELP interface private resources + * @hard_iface: interface for which the data has to be prepared + * + * Return: 0 on success or a -ENOMEM in case of failure. + */ +int batadv_v_elp_iface_enable(struct batadv_hard_iface *hard_iface) +{ + static const size_t tvlv_padding = sizeof(__be32); + struct batadv_elp_packet *elp_packet; + unsigned char *elp_buff; + u32 random_seqno; + size_t size; + int res = -ENOMEM; + + size = ETH_HLEN + NET_IP_ALIGN + BATADV_ELP_HLEN + tvlv_padding; + hard_iface->bat_v.elp_skb = dev_alloc_skb(size); + if (!hard_iface->bat_v.elp_skb) + goto out; + + skb_reserve(hard_iface->bat_v.elp_skb, ETH_HLEN + NET_IP_ALIGN); + elp_buff = skb_put_zero(hard_iface->bat_v.elp_skb, + BATADV_ELP_HLEN + tvlv_padding); + elp_packet = (struct batadv_elp_packet *)elp_buff; + + elp_packet->packet_type = BATADV_ELP; + elp_packet->version = BATADV_COMPAT_VERSION; + + /* randomize initial seqno to avoid collision */ + get_random_bytes(&random_seqno, sizeof(random_seqno)); + atomic_set(&hard_iface->bat_v.elp_seqno, random_seqno); + + /* assume full-duplex by default */ + hard_iface->bat_v.flags |= BATADV_FULL_DUPLEX; + + /* warn the user (again) if there is no throughput data is available */ + hard_iface->bat_v.flags &= ~BATADV_WARNING_DEFAULT; + + if (batadv_is_wifi_hardif(hard_iface)) + hard_iface->bat_v.flags &= ~BATADV_FULL_DUPLEX; + + INIT_DELAYED_WORK(&hard_iface->bat_v.elp_wq, + batadv_v_elp_periodic_work); + batadv_v_elp_start_timer(hard_iface); + res = 0; + +out: + return res; +} + +/** + * batadv_v_elp_iface_disable() - release ELP interface private resources + * @hard_iface: interface for which the resources have to be released + */ +void batadv_v_elp_iface_disable(struct batadv_hard_iface *hard_iface) +{ + cancel_delayed_work_sync(&hard_iface->bat_v.elp_wq); + + dev_kfree_skb(hard_iface->bat_v.elp_skb); + hard_iface->bat_v.elp_skb = NULL; +} + +/** + * batadv_v_elp_iface_activate() - update the ELP buffer belonging to the given + * hard-interface + * @primary_iface: the new primary interface + * @hard_iface: interface holding the to-be-updated buffer + */ +void batadv_v_elp_iface_activate(struct batadv_hard_iface *primary_iface, + struct batadv_hard_iface *hard_iface) +{ + struct batadv_elp_packet *elp_packet; + struct sk_buff *skb; + + if (!hard_iface->bat_v.elp_skb) + return; + + skb = 
hard_iface->bat_v.elp_skb; + elp_packet = (struct batadv_elp_packet *)skb->data; + ether_addr_copy(elp_packet->orig, + primary_iface->net_dev->dev_addr); +} + +/** + * batadv_v_elp_primary_iface_set() - change internal data to reflect the new + * primary interface + * @primary_iface: the new primary interface + */ +void batadv_v_elp_primary_iface_set(struct batadv_hard_iface *primary_iface) +{ + struct batadv_hard_iface *hard_iface; + + /* update orig field of every elp iface belonging to this mesh */ + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (primary_iface->soft_iface != hard_iface->soft_iface) + continue; + + batadv_v_elp_iface_activate(primary_iface, hard_iface); + } + rcu_read_unlock(); +} + +/** + * batadv_v_elp_neigh_update() - update an ELP neighbour node + * @bat_priv: the bat priv with all the soft interface information + * @neigh_addr: the neighbour interface address + * @if_incoming: the interface the packet was received through + * @elp_packet: the received ELP packet + * + * Updates the ELP neighbour node state with the data received within the new + * ELP packet. + */ +static void batadv_v_elp_neigh_update(struct batadv_priv *bat_priv, + u8 *neigh_addr, + struct batadv_hard_iface *if_incoming, + struct batadv_elp_packet *elp_packet) + +{ + struct batadv_neigh_node *neigh; + struct batadv_orig_node *orig_neigh; + struct batadv_hardif_neigh_node *hardif_neigh; + s32 seqno_diff; + s32 elp_latest_seqno; + + orig_neigh = batadv_v_ogm_orig_get(bat_priv, elp_packet->orig); + if (!orig_neigh) + return; + + neigh = batadv_neigh_node_get_or_create(orig_neigh, + if_incoming, neigh_addr); + if (!neigh) + goto orig_free; + + hardif_neigh = batadv_hardif_neigh_get(if_incoming, neigh_addr); + if (!hardif_neigh) + goto neigh_free; + + elp_latest_seqno = hardif_neigh->bat_v.elp_latest_seqno; + seqno_diff = ntohl(elp_packet->seqno) - elp_latest_seqno; + + /* known or older sequence numbers are ignored. However always adopt + * if the router seems to have been restarted. + */ + if (seqno_diff < 1 && seqno_diff > -BATADV_ELP_MAX_AGE) + goto hardif_free; + + neigh->last_seen = jiffies; + hardif_neigh->last_seen = jiffies; + hardif_neigh->bat_v.elp_latest_seqno = ntohl(elp_packet->seqno); + hardif_neigh->bat_v.elp_interval = ntohl(elp_packet->elp_interval); + +hardif_free: + batadv_hardif_neigh_put(hardif_neigh); +neigh_free: + batadv_neigh_node_put(neigh); +orig_free: + batadv_orig_node_put(orig_neigh); +} + +/** + * batadv_v_elp_packet_recv() - main ELP packet handler + * @skb: the received packet + * @if_incoming: the interface this packet was received through + * + * Return: NET_RX_SUCCESS and consumes the skb if the packet was properly + * processed or NET_RX_DROP in case of failure. + */ +int batadv_v_elp_packet_recv(struct sk_buff *skb, + struct batadv_hard_iface *if_incoming) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_elp_packet *elp_packet; + struct batadv_hard_iface *primary_if; + struct ethhdr *ethhdr; + bool res; + int ret = NET_RX_DROP; + + res = batadv_check_management_packet(skb, if_incoming, BATADV_ELP_HLEN); + if (!res) + goto free_skb; + + ethhdr = eth_hdr(skb); + if (batadv_is_my_mac(bat_priv, ethhdr->h_source)) + goto free_skb; + + /* did we receive a B.A.T.M.A.N. V ELP packet on an interface + * that does not have B.A.T.M.A.N. V ELP enabled ? 
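+ * If so, drop it: ELP packets are only processed while the mesh + * interface runs the BATMAN_V algorithm.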
+ */ + if (strcmp(bat_priv->algo_ops->name, "BATMAN_V") != 0) + goto free_skb; + + elp_packet = (struct batadv_elp_packet *)skb->data; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Received ELP packet from %pM seqno %u ORIG: %pM\n", + ethhdr->h_source, ntohl(elp_packet->seqno), + elp_packet->orig); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto free_skb; + + batadv_v_elp_neigh_update(bat_priv, ethhdr->h_source, if_incoming, + elp_packet); + + ret = NET_RX_SUCCESS; + batadv_hardif_put(primary_if); + +free_skb: + if (ret == NET_RX_SUCCESS) + consume_skb(skb); + else + kfree_skb(skb); + + return ret; +} diff --git a/net/batman-adv/bat_v_elp.h b/net/batman-adv/bat_v_elp.h new file mode 100644 index 000000000..9e2740195 --- /dev/null +++ b/net/batman-adv/bat_v_elp.h @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Linus Lüssing, Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_BAT_V_ELP_H_ +#define _NET_BATMAN_ADV_BAT_V_ELP_H_ + +#include "main.h" + +#include +#include + +int batadv_v_elp_iface_enable(struct batadv_hard_iface *hard_iface); +void batadv_v_elp_iface_disable(struct batadv_hard_iface *hard_iface); +void batadv_v_elp_iface_activate(struct batadv_hard_iface *primary_iface, + struct batadv_hard_iface *hard_iface); +void batadv_v_elp_primary_iface_set(struct batadv_hard_iface *primary_iface); +int batadv_v_elp_packet_recv(struct sk_buff *skb, + struct batadv_hard_iface *if_incoming); +void batadv_v_elp_throughput_metric_update(struct work_struct *work); + +#endif /* _NET_BATMAN_ADV_BAT_V_ELP_H_ */ diff --git a/net/batman-adv/bat_v_ogm.c b/net/batman-adv/bat_v_ogm.c new file mode 100644 index 000000000..9f4815f4c --- /dev/null +++ b/net/batman-adv/bat_v_ogm.c @@ -0,0 +1,1088 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Antonio Quartulli + */ + +#include "bat_v_ogm.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "originator.h" +#include "routing.h" +#include "send.h" +#include "translation-table.h" +#include "tvlv.h" + +/** + * batadv_v_ogm_orig_get() - retrieve and possibly create an originator node + * @bat_priv: the bat priv with all the soft interface information + * @addr: the address of the originator + * + * Return: the orig_node corresponding to the specified address. If such an + * object does not exist, it is allocated here. In case of allocation failure + * returns NULL. 
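+ * + * The returned orig_node carries a reference which the caller has to + * release with batadv_orig_node_put() once it is no longer needed.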
+ */ +struct batadv_orig_node *batadv_v_ogm_orig_get(struct batadv_priv *bat_priv, + const u8 *addr) +{ + struct batadv_orig_node *orig_node; + int hash_added; + + orig_node = batadv_orig_hash_find(bat_priv, addr); + if (orig_node) + return orig_node; + + orig_node = batadv_orig_node_new(bat_priv, addr); + if (!orig_node) + return NULL; + + kref_get(&orig_node->refcount); + hash_added = batadv_hash_add(bat_priv->orig_hash, batadv_compare_orig, + batadv_choose_orig, orig_node, + &orig_node->hash_entry); + if (hash_added != 0) { + /* remove refcnt for newly created orig_node and hash entry */ + batadv_orig_node_put(orig_node); + batadv_orig_node_put(orig_node); + orig_node = NULL; + } + + return orig_node; +} + +/** + * batadv_v_ogm_start_queue_timer() - restart the OGM aggregation timer + * @hard_iface: the interface to use to send the OGM + */ +static void batadv_v_ogm_start_queue_timer(struct batadv_hard_iface *hard_iface) +{ + unsigned int msecs = BATADV_MAX_AGGREGATION_MS * 1000; + + /* msecs * [0.9, 1.1] */ + msecs += prandom_u32_max(msecs / 5) - (msecs / 10); + queue_delayed_work(batadv_event_workqueue, &hard_iface->bat_v.aggr_wq, + msecs_to_jiffies(msecs / 1000)); +} + +/** + * batadv_v_ogm_start_timer() - restart the OGM sending timer + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_v_ogm_start_timer(struct batadv_priv *bat_priv) +{ + unsigned long msecs; + /* this function may be invoked in different contexts (ogm rescheduling + * or hard_iface activation), but the work timer should not be reset + */ + if (delayed_work_pending(&bat_priv->bat_v.ogm_wq)) + return; + + msecs = atomic_read(&bat_priv->orig_interval) - BATADV_JITTER; + msecs += prandom_u32_max(2 * BATADV_JITTER); + queue_delayed_work(batadv_event_workqueue, &bat_priv->bat_v.ogm_wq, + msecs_to_jiffies(msecs)); +} + +/** + * batadv_v_ogm_send_to_if() - send a batman ogm using a given interface + * @skb: the OGM to send + * @hard_iface: the interface to use to send the OGM + */ +static void batadv_v_ogm_send_to_if(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + + if (hard_iface->if_status != BATADV_IF_ACTIVE) { + kfree_skb(skb); + return; + } + + batadv_inc_counter(bat_priv, BATADV_CNT_MGMT_TX); + batadv_add_counter(bat_priv, BATADV_CNT_MGMT_TX_BYTES, + skb->len + ETH_HLEN); + + batadv_send_broadcast_skb(skb, hard_iface); +} + +/** + * batadv_v_ogm_len() - OGMv2 packet length + * @skb: the OGM to check + * + * Return: Length of the given OGMv2 packet, including tvlv length, excluding + * ethernet header length. + */ +static unsigned int batadv_v_ogm_len(struct sk_buff *skb) +{ + struct batadv_ogm2_packet *ogm_packet; + + ogm_packet = (struct batadv_ogm2_packet *)skb->data; + return BATADV_OGM2_HLEN + ntohs(ogm_packet->tvlv_len); +} + +/** + * batadv_v_ogm_queue_left() - check if given OGM still fits aggregation queue + * @skb: the OGM to check + * @hard_iface: the interface to use to send the OGM + * + * Caller needs to hold the hard_iface->bat_v.aggr_list.lock. + * + * Return: True, if the given OGMv2 packet still fits, false otherwise. 
+ */ +static bool batadv_v_ogm_queue_left(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface) +{ + unsigned int max = min_t(unsigned int, hard_iface->net_dev->mtu, + BATADV_MAX_AGGREGATION_BYTES); + unsigned int ogm_len = batadv_v_ogm_len(skb); + + lockdep_assert_held(&hard_iface->bat_v.aggr_list.lock); + + return hard_iface->bat_v.aggr_len + ogm_len <= max; +} + +/** + * batadv_v_ogm_aggr_list_free - free all elements in an aggregation queue + * @hard_iface: the interface holding the aggregation queue + * + * Empties the OGMv2 aggregation queue and frees all the skbs it contains. + * + * Caller needs to hold the hard_iface->bat_v.aggr_list.lock. + */ +static void batadv_v_ogm_aggr_list_free(struct batadv_hard_iface *hard_iface) +{ + lockdep_assert_held(&hard_iface->bat_v.aggr_list.lock); + + __skb_queue_purge(&hard_iface->bat_v.aggr_list); + hard_iface->bat_v.aggr_len = 0; +} + +/** + * batadv_v_ogm_aggr_send() - flush & send aggregation queue + * @hard_iface: the interface with the aggregation queue to flush + * + * Aggregates all OGMv2 packets currently in the aggregation queue into a + * single OGMv2 packet and transmits this aggregate. + * + * The aggregation queue is empty after this call. + * + * Caller needs to hold the hard_iface->bat_v.aggr_list.lock. + */ +static void batadv_v_ogm_aggr_send(struct batadv_hard_iface *hard_iface) +{ + unsigned int aggr_len = hard_iface->bat_v.aggr_len; + struct sk_buff *skb_aggr; + unsigned int ogm_len; + struct sk_buff *skb; + + lockdep_assert_held(&hard_iface->bat_v.aggr_list.lock); + + if (!aggr_len) + return; + + skb_aggr = dev_alloc_skb(aggr_len + ETH_HLEN + NET_IP_ALIGN); + if (!skb_aggr) { + batadv_v_ogm_aggr_list_free(hard_iface); + return; + } + + skb_reserve(skb_aggr, ETH_HLEN + NET_IP_ALIGN); + skb_reset_network_header(skb_aggr); + + while ((skb = __skb_dequeue(&hard_iface->bat_v.aggr_list))) { + hard_iface->bat_v.aggr_len -= batadv_v_ogm_len(skb); + + ogm_len = batadv_v_ogm_len(skb); + skb_put_data(skb_aggr, skb->data, ogm_len); + + consume_skb(skb); + } + + batadv_v_ogm_send_to_if(skb_aggr, hard_iface); +} + +/** + * batadv_v_ogm_queue_on_if() - queue a batman ogm on a given interface + * @skb: the OGM to queue + * @hard_iface: the interface to queue the OGM on + */ +static void batadv_v_ogm_queue_on_if(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + + if (!atomic_read(&bat_priv->aggregated_ogms)) { + batadv_v_ogm_send_to_if(skb, hard_iface); + return; + } + + spin_lock_bh(&hard_iface->bat_v.aggr_list.lock); + if (!batadv_v_ogm_queue_left(skb, hard_iface)) + batadv_v_ogm_aggr_send(hard_iface); + + hard_iface->bat_v.aggr_len += batadv_v_ogm_len(skb); + __skb_queue_tail(&hard_iface->bat_v.aggr_list, skb); + spin_unlock_bh(&hard_iface->bat_v.aggr_list.lock); +} + +/** + * batadv_v_ogm_send_softif() - periodic worker broadcasting the own OGM + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_v_ogm_send_softif(struct batadv_priv *bat_priv) +{ + struct batadv_hard_iface *hard_iface; + struct batadv_ogm2_packet *ogm_packet; + struct sk_buff *skb, *skb_tmp; + unsigned char *ogm_buff; + int ogm_buff_len; + u16 tvlv_len = 0; + int ret; + + lockdep_assert_held(&bat_priv->bat_v.ogm_buff_mutex); + + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_DEACTIVATING) + goto out; + + ogm_buff = bat_priv->bat_v.ogm_buff; + ogm_buff_len = bat_priv->bat_v.ogm_buff_len; + /* tt changes have to be committed before the 
tvlv data is + * appended as it may alter the tt tvlv container + */ + batadv_tt_local_commit_changes(bat_priv); + tvlv_len = batadv_tvlv_container_ogm_append(bat_priv, &ogm_buff, + &ogm_buff_len, + BATADV_OGM2_HLEN); + + bat_priv->bat_v.ogm_buff = ogm_buff; + bat_priv->bat_v.ogm_buff_len = ogm_buff_len; + + skb = netdev_alloc_skb_ip_align(NULL, ETH_HLEN + ogm_buff_len); + if (!skb) + goto reschedule; + + skb_reserve(skb, ETH_HLEN); + skb_put_data(skb, ogm_buff, ogm_buff_len); + + ogm_packet = (struct batadv_ogm2_packet *)skb->data; + ogm_packet->seqno = htonl(atomic_read(&bat_priv->bat_v.ogm_seqno)); + atomic_inc(&bat_priv->bat_v.ogm_seqno); + ogm_packet->tvlv_len = htons(tvlv_len); + + /* broadcast on every interface */ + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + continue; + + ret = batadv_hardif_no_broadcast(hard_iface, NULL, NULL); + if (ret) { + char *type; + + switch (ret) { + case BATADV_HARDIF_BCAST_NORECIPIENT: + type = "no neighbor"; + break; + case BATADV_HARDIF_BCAST_DUPFWD: + type = "single neighbor is source"; + break; + case BATADV_HARDIF_BCAST_DUPORIG: + type = "single neighbor is originator"; + break; + default: + type = "unknown"; + } + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, "OGM2 from ourselves on %s suppressed: %s\n", + hard_iface->net_dev->name, type); + + batadv_hardif_put(hard_iface); + continue; + } + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Sending own OGM2 packet (originator %pM, seqno %u, throughput %u, TTL %d) on interface %s [%pM]\n", + ogm_packet->orig, ntohl(ogm_packet->seqno), + ntohl(ogm_packet->throughput), ogm_packet->ttl, + hard_iface->net_dev->name, + hard_iface->net_dev->dev_addr); + + /* this skb gets consumed by batadv_v_ogm_send_to_if() */ + skb_tmp = skb_clone(skb, GFP_ATOMIC); + if (!skb_tmp) { + batadv_hardif_put(hard_iface); + break; + } + + batadv_v_ogm_queue_on_if(skb_tmp, hard_iface); + batadv_hardif_put(hard_iface); + } + rcu_read_unlock(); + + consume_skb(skb); + +reschedule: + batadv_v_ogm_start_timer(bat_priv); +out: + return; +} + +/** + * batadv_v_ogm_send() - periodic worker broadcasting the own OGM + * @work: work queue item + */ +static void batadv_v_ogm_send(struct work_struct *work) +{ + struct batadv_priv_bat_v *bat_v; + struct batadv_priv *bat_priv; + + bat_v = container_of(work, struct batadv_priv_bat_v, ogm_wq.work); + bat_priv = container_of(bat_v, struct batadv_priv, bat_v); + + mutex_lock(&bat_priv->bat_v.ogm_buff_mutex); + batadv_v_ogm_send_softif(bat_priv); + mutex_unlock(&bat_priv->bat_v.ogm_buff_mutex); +} + +/** + * batadv_v_ogm_aggr_work() - OGM queue periodic task per interface + * @work: work queue item + * + * Emits aggregated OGM messages in regular intervals. + */ +void batadv_v_ogm_aggr_work(struct work_struct *work) +{ + struct batadv_hard_iface_bat_v *batv; + struct batadv_hard_iface *hard_iface; + + batv = container_of(work, struct batadv_hard_iface_bat_v, aggr_wq.work); + hard_iface = container_of(batv, struct batadv_hard_iface, bat_v); + + spin_lock_bh(&hard_iface->bat_v.aggr_list.lock); + batadv_v_ogm_aggr_send(hard_iface); + spin_unlock_bh(&hard_iface->bat_v.aggr_list.lock); + + batadv_v_ogm_start_queue_timer(hard_iface); +} + +/** + * batadv_v_ogm_iface_enable() - prepare an interface for B.A.T.M.A.N. V + * @hard_iface: the interface to prepare + * + * Takes care of scheduling its own OGM sending routine for this interface. 
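+ * It also arms the per-interface OGM aggregation timer, so that queued
+ * OGMs are flushed out in regular intervals.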
+ * + * Return: 0 on success or a negative error code otherwise + */ +int batadv_v_ogm_iface_enable(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + + batadv_v_ogm_start_queue_timer(hard_iface); + batadv_v_ogm_start_timer(bat_priv); + + return 0; +} + +/** + * batadv_v_ogm_iface_disable() - release OGM interface private resources + * @hard_iface: interface for which the resources have to be released + */ +void batadv_v_ogm_iface_disable(struct batadv_hard_iface *hard_iface) +{ + cancel_delayed_work_sync(&hard_iface->bat_v.aggr_wq); + + spin_lock_bh(&hard_iface->bat_v.aggr_list.lock); + batadv_v_ogm_aggr_list_free(hard_iface); + spin_unlock_bh(&hard_iface->bat_v.aggr_list.lock); +} + +/** + * batadv_v_ogm_primary_iface_set() - set a new primary interface + * @primary_iface: the new primary interface + */ +void batadv_v_ogm_primary_iface_set(struct batadv_hard_iface *primary_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(primary_iface->soft_iface); + struct batadv_ogm2_packet *ogm_packet; + + mutex_lock(&bat_priv->bat_v.ogm_buff_mutex); + if (!bat_priv->bat_v.ogm_buff) + goto unlock; + + ogm_packet = (struct batadv_ogm2_packet *)bat_priv->bat_v.ogm_buff; + ether_addr_copy(ogm_packet->orig, primary_iface->net_dev->dev_addr); + +unlock: + mutex_unlock(&bat_priv->bat_v.ogm_buff_mutex); +} + +/** + * batadv_v_forward_penalty() - apply a penalty to the throughput metric + * forwarded with B.A.T.M.A.N. V OGMs + * @bat_priv: the bat priv with all the soft interface information + * @if_incoming: the interface where the OGM has been received + * @if_outgoing: the interface where the OGM has to be forwarded to + * @throughput: the current throughput + * + * Apply a penalty on the current throughput metric value based on the + * characteristic of the interface where the OGM has been received. + * + * Initially the per hardif hop penalty is applied to the throughput. After + * that the return value is then computed as follows: + * - throughput * 50% if the incoming and outgoing interface are the + * same WiFi interface and the throughput is above + * 1MBit/s + * - throughput if the outgoing interface is the default + * interface (i.e. this OGM is processed for the + * internal table and not forwarded) + * - throughput * node hop penalty otherwise + * + * Return: the penalised throughput metric. + */ +static u32 batadv_v_forward_penalty(struct batadv_priv *bat_priv, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + u32 throughput) +{ + int if_hop_penalty = atomic_read(&if_incoming->hop_penalty); + int hop_penalty = atomic_read(&bat_priv->hop_penalty); + int hop_penalty_max = BATADV_TQ_MAX_VALUE; + + /* Apply per hardif hop penalty */ + throughput = throughput * (hop_penalty_max - if_hop_penalty) / + hop_penalty_max; + + /* Don't apply hop penalty in default originator table. */ + if (if_outgoing == BATADV_IF_DEFAULT) + return throughput; + + /* Forwarding on the same WiFi interface cuts the throughput in half + * due to the store & forward characteristics of WIFI. + * Very low throughput values are the exception. 
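+ *
+ * Worked example with made-up values (if_hop_penalty 0, hop_penalty 51,
+ * i.e. ~20% of BATADV_TQ_MAX_VALUE): a received throughput of 1000
+ * forwarded on a different interface becomes 1000 * (255 - 51) / 255 = 800,
+ * while forwarding on the same half-duplex WiFi interface yields
+ * 1000 / 2 = 500 instead.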
+ */ + if (throughput > 10 && + if_incoming == if_outgoing && + !(if_incoming->bat_v.flags & BATADV_FULL_DUPLEX)) + return throughput / 2; + + /* hop penalty of 255 equals 100% */ + return throughput * (hop_penalty_max - hop_penalty) / hop_penalty_max; +} + +/** + * batadv_v_ogm_forward() - check conditions and forward an OGM to the given + * outgoing interface + * @bat_priv: the bat priv with all the soft interface information + * @ogm_received: previously received OGM to be forwarded + * @orig_node: the originator which has been updated + * @neigh_node: the neigh_node through with the OGM has been received + * @if_incoming: the interface on which this OGM was received on + * @if_outgoing: the interface to which the OGM has to be forwarded to + * + * Forward an OGM to an interface after having altered the throughput metric and + * the TTL value contained in it. The original OGM isn't modified. + */ +static void batadv_v_ogm_forward(struct batadv_priv *bat_priv, + const struct batadv_ogm2_packet *ogm_received, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo = NULL; + struct batadv_orig_ifinfo *orig_ifinfo = NULL; + struct batadv_neigh_node *router = NULL; + struct batadv_ogm2_packet *ogm_forward; + unsigned char *skb_buff; + struct sk_buff *skb; + size_t packet_len; + u16 tvlv_len; + + /* only forward for specific interfaces, not for the default one. */ + if (if_outgoing == BATADV_IF_DEFAULT) + goto out; + + orig_ifinfo = batadv_orig_ifinfo_new(orig_node, if_outgoing); + if (!orig_ifinfo) + goto out; + + /* acquire possibly updated router */ + router = batadv_orig_router_get(orig_node, if_outgoing); + + /* strict rule: forward packets coming from the best next hop only */ + if (neigh_node != router) + goto out; + + /* don't forward the same seqno twice on one interface */ + if (orig_ifinfo->last_seqno_forwarded == ntohl(ogm_received->seqno)) + goto out; + + orig_ifinfo->last_seqno_forwarded = ntohl(ogm_received->seqno); + + if (ogm_received->ttl <= 1) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, "ttl exceeded\n"); + goto out; + } + + neigh_ifinfo = batadv_neigh_ifinfo_get(neigh_node, if_outgoing); + if (!neigh_ifinfo) + goto out; + + tvlv_len = ntohs(ogm_received->tvlv_len); + + packet_len = BATADV_OGM2_HLEN + tvlv_len; + skb = netdev_alloc_skb_ip_align(if_outgoing->net_dev, + ETH_HLEN + packet_len); + if (!skb) + goto out; + + skb_reserve(skb, ETH_HLEN); + skb_buff = skb_put_data(skb, ogm_received, packet_len); + + /* apply forward penalty */ + ogm_forward = (struct batadv_ogm2_packet *)skb_buff; + ogm_forward->throughput = htonl(neigh_ifinfo->bat_v.throughput); + ogm_forward->ttl--; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Forwarding OGM2 packet on %s: throughput %u, ttl %u, received via %s\n", + if_outgoing->net_dev->name, ntohl(ogm_forward->throughput), + ogm_forward->ttl, if_incoming->net_dev->name); + + batadv_v_ogm_queue_on_if(skb, if_outgoing); + +out: + batadv_orig_ifinfo_put(orig_ifinfo); + batadv_neigh_node_put(router); + batadv_neigh_ifinfo_put(neigh_ifinfo); +} + +/** + * batadv_v_ogm_metric_update() - update route metric based on OGM + * @bat_priv: the bat priv with all the soft interface information + * @ogm2: OGM2 structure + * @orig_node: Originator structure for which the OGM has been received + * @neigh_node: the neigh_node through with the OGM has been received + * @if_incoming: the interface where this packet 
was received + * @if_outgoing: the interface for which the packet should be considered + * + * Return: + * 1 if the OGM is new, + * 0 if it is not new but valid, + * <0 on error (e.g. old OGM) + */ +static int batadv_v_ogm_metric_update(struct batadv_priv *bat_priv, + const struct batadv_ogm2_packet *ogm2, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_neigh_ifinfo *neigh_ifinfo = NULL; + bool protection_started = false; + int ret = -EINVAL; + u32 path_throughput; + s32 seq_diff; + + orig_ifinfo = batadv_orig_ifinfo_new(orig_node, if_outgoing); + if (!orig_ifinfo) + goto out; + + seq_diff = ntohl(ogm2->seqno) - orig_ifinfo->last_real_seqno; + + if (!hlist_empty(&orig_node->neigh_list) && + batadv_window_protected(bat_priv, seq_diff, + BATADV_OGM_MAX_AGE, + &orig_ifinfo->batman_seqno_reset, + &protection_started)) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: packet within window protection time from %pM\n", + ogm2->orig); + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Last reset: %ld, %ld\n", + orig_ifinfo->batman_seqno_reset, jiffies); + goto out; + } + + /* drop packets with old seqnos, however accept the first packet after + * a host has been rebooted. + */ + if (seq_diff < 0 && !protection_started) + goto out; + + neigh_node->last_seen = jiffies; + + orig_node->last_seen = jiffies; + + orig_ifinfo->last_real_seqno = ntohl(ogm2->seqno); + orig_ifinfo->last_ttl = ogm2->ttl; + + neigh_ifinfo = batadv_neigh_ifinfo_new(neigh_node, if_outgoing); + if (!neigh_ifinfo) + goto out; + + path_throughput = batadv_v_forward_penalty(bat_priv, if_incoming, + if_outgoing, + ntohl(ogm2->throughput)); + neigh_ifinfo->bat_v.throughput = path_throughput; + neigh_ifinfo->bat_v.last_seqno = ntohl(ogm2->seqno); + neigh_ifinfo->last_ttl = ogm2->ttl; + + if (seq_diff > 0 || protection_started) + ret = 1; + else + ret = 0; +out: + batadv_orig_ifinfo_put(orig_ifinfo); + batadv_neigh_ifinfo_put(neigh_ifinfo); + + return ret; +} + +/** + * batadv_v_ogm_route_update() - update routes based on OGM + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: the Ethernet header of the OGM2 + * @ogm2: OGM2 structure + * @orig_node: Originator structure for which the OGM has been received + * @neigh_node: the neigh_node through with the OGM has been received + * @if_incoming: the interface where this packet was received + * @if_outgoing: the interface for which the packet should be considered + * + * Return: true if the packet should be forwarded, false otherwise + */ +static bool batadv_v_ogm_route_update(struct batadv_priv *bat_priv, + const struct ethhdr *ethhdr, + const struct batadv_ogm2_packet *ogm2, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_neigh_node *router = NULL; + struct batadv_orig_node *orig_neigh_node; + struct batadv_neigh_node *orig_neigh_router = NULL; + struct batadv_neigh_ifinfo *router_ifinfo = NULL, *neigh_ifinfo = NULL; + u32 router_throughput, neigh_throughput; + u32 router_last_seqno; + u32 neigh_last_seqno; + s32 neigh_seq_diff; + bool forward = false; + + orig_neigh_node = batadv_v_ogm_orig_get(bat_priv, ethhdr->h_source); + if (!orig_neigh_node) + goto out; + + orig_neigh_router = batadv_orig_router_get(orig_neigh_node, + if_outgoing); + + /* drop packet 
if sender is not a direct neighbor and if we + * don't route towards it + */ + router = batadv_orig_router_get(orig_node, if_outgoing); + if (router && router->orig_node != orig_node && !orig_neigh_router) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: OGM via unknown neighbor!\n"); + goto out; + } + + /* Mark the OGM to be considered for forwarding, and update routes + * if needed. + */ + forward = true; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Searching and updating originator entry of received packet\n"); + + /* if this neighbor already is our next hop there is nothing + * to change + */ + if (router == neigh_node) + goto out; + + /* don't consider neighbours with worse throughput. + * also switch route if this seqno is BATADV_V_MAX_ORIGDIFF newer than + * the last received seqno from our best next hop. + */ + if (router) { + router_ifinfo = batadv_neigh_ifinfo_get(router, if_outgoing); + neigh_ifinfo = batadv_neigh_ifinfo_get(neigh_node, if_outgoing); + + /* if these are not allocated, something is wrong. */ + if (!router_ifinfo || !neigh_ifinfo) + goto out; + + neigh_last_seqno = neigh_ifinfo->bat_v.last_seqno; + router_last_seqno = router_ifinfo->bat_v.last_seqno; + neigh_seq_diff = neigh_last_seqno - router_last_seqno; + router_throughput = router_ifinfo->bat_v.throughput; + neigh_throughput = neigh_ifinfo->bat_v.throughput; + + if (neigh_seq_diff < BATADV_OGM_MAX_ORIGDIFF && + router_throughput >= neigh_throughput) + goto out; + } + + batadv_update_route(bat_priv, orig_node, if_outgoing, neigh_node); +out: + batadv_neigh_node_put(router); + batadv_neigh_node_put(orig_neigh_router); + batadv_orig_node_put(orig_neigh_node); + batadv_neigh_ifinfo_put(router_ifinfo); + batadv_neigh_ifinfo_put(neigh_ifinfo); + + return forward; +} + +/** + * batadv_v_ogm_process_per_outif() - process a batman v OGM for an outgoing if + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: the Ethernet header of the OGM2 + * @ogm2: OGM2 structure + * @orig_node: Originator structure for which the OGM has been received + * @neigh_node: the neigh_node through with the OGM has been received + * @if_incoming: the interface where this packet was received + * @if_outgoing: the interface for which the packet should be considered + */ +static void +batadv_v_ogm_process_per_outif(struct batadv_priv *bat_priv, + const struct ethhdr *ethhdr, + const struct batadv_ogm2_packet *ogm2, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node, + struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing) +{ + int seqno_age; + bool forward; + + /* first, update the metric with according sanity checks */ + seqno_age = batadv_v_ogm_metric_update(bat_priv, ogm2, orig_node, + neigh_node, if_incoming, + if_outgoing); + + /* outdated sequence numbers are to be discarded */ + if (seqno_age < 0) + return; + + /* only unknown & newer OGMs contain TVLVs we are interested in */ + if (seqno_age > 0 && if_outgoing == BATADV_IF_DEFAULT) + batadv_tvlv_containers_process(bat_priv, true, orig_node, + NULL, NULL, + (unsigned char *)(ogm2 + 1), + ntohs(ogm2->tvlv_len)); + + /* if the metric update went through, update routes if needed */ + forward = batadv_v_ogm_route_update(bat_priv, ethhdr, ogm2, orig_node, + neigh_node, if_incoming, + if_outgoing); + + /* if the routes have been processed correctly, check and forward */ + if (forward) + batadv_v_ogm_forward(bat_priv, ogm2, orig_node, neigh_node, + if_incoming, if_outgoing); +} + +/** + * 
batadv_v_ogm_aggr_packet() - checks if there is another OGM aggregated + * @buff_pos: current position in the skb + * @packet_len: total length of the skb + * @ogm2_packet: potential OGM2 in buffer + * + * Return: true if there is enough space for another OGM, false otherwise. + */ +static bool +batadv_v_ogm_aggr_packet(int buff_pos, int packet_len, + const struct batadv_ogm2_packet *ogm2_packet) +{ + int next_buff_pos = 0; + + /* check if there is enough space for the header */ + next_buff_pos += buff_pos + sizeof(*ogm2_packet); + if (next_buff_pos > packet_len) + return false; + + /* check if there is enough space for the optional TVLV */ + next_buff_pos += ntohs(ogm2_packet->tvlv_len); + + return (next_buff_pos <= packet_len) && + (next_buff_pos <= BATADV_MAX_AGGREGATION_BYTES); +} + +/** + * batadv_v_ogm_process() - process an incoming batman v OGM + * @skb: the skb containing the OGM + * @ogm_offset: offset to the OGM which should be processed (for aggregates) + * @if_incoming: the interface where this packet was received + */ +static void batadv_v_ogm_process(const struct sk_buff *skb, int ogm_offset, + struct batadv_hard_iface *if_incoming) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct ethhdr *ethhdr; + struct batadv_orig_node *orig_node = NULL; + struct batadv_hardif_neigh_node *hardif_neigh = NULL; + struct batadv_neigh_node *neigh_node = NULL; + struct batadv_hard_iface *hard_iface; + struct batadv_ogm2_packet *ogm_packet; + u32 ogm_throughput, link_throughput, path_throughput; + int ret; + + ethhdr = eth_hdr(skb); + ogm_packet = (struct batadv_ogm2_packet *)(skb->data + ogm_offset); + + ogm_throughput = ntohl(ogm_packet->throughput); + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Received OGM2 packet via NB: %pM, IF: %s [%pM] (from OG: %pM, seqno %u, throughput %u, TTL %u, V %u, tvlv_len %u)\n", + ethhdr->h_source, if_incoming->net_dev->name, + if_incoming->net_dev->dev_addr, ogm_packet->orig, + ntohl(ogm_packet->seqno), ogm_throughput, ogm_packet->ttl, + ogm_packet->version, ntohs(ogm_packet->tvlv_len)); + + if (batadv_is_my_mac(bat_priv, ogm_packet->orig)) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: originator packet from ourself\n"); + return; + } + + /* If the throughput metric is 0, immediately drop the packet. No need + * to create orig_node / neigh_node for an unusable route. + */ + if (ogm_throughput == 0) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: originator packet with throughput metric of 0\n"); + return; + } + + /* require ELP packets be to received from this neighbor first */ + hardif_neigh = batadv_hardif_neigh_get(if_incoming, ethhdr->h_source); + if (!hardif_neigh) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: OGM via unknown neighbor!\n"); + goto out; + } + + orig_node = batadv_v_ogm_orig_get(bat_priv, ogm_packet->orig); + if (!orig_node) + goto out; + + neigh_node = batadv_neigh_node_get_or_create(orig_node, if_incoming, + ethhdr->h_source); + if (!neigh_node) + goto out; + + /* Update the received throughput metric to match the link + * characteristic: + * - If this OGM traveled one hop so far (emitted by single hop + * neighbor) the path throughput metric equals the link throughput. + * - For OGMs traversing more than hop the path throughput metric is + * the smaller of the path throughput and the link throughput. 
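+ *
+ * Example with made-up values: an OGM advertising a path throughput of
+ * 500 that arrives over a link with an EWMA estimate of 300 is
+ * re-advertised with min(500, 300) = 300.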
+ */ + link_throughput = ewma_throughput_read(&hardif_neigh->bat_v.throughput); + path_throughput = min_t(u32, link_throughput, ogm_throughput); + ogm_packet->throughput = htonl(path_throughput); + + batadv_v_ogm_process_per_outif(bat_priv, ethhdr, ogm_packet, orig_node, + neigh_node, if_incoming, + BATADV_IF_DEFAULT); + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE) + continue; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + continue; + + ret = batadv_hardif_no_broadcast(hard_iface, + ogm_packet->orig, + hardif_neigh->orig); + + if (ret) { + char *type; + + switch (ret) { + case BATADV_HARDIF_BCAST_NORECIPIENT: + type = "no neighbor"; + break; + case BATADV_HARDIF_BCAST_DUPFWD: + type = "single neighbor is source"; + break; + case BATADV_HARDIF_BCAST_DUPORIG: + type = "single neighbor is originator"; + break; + default: + type = "unknown"; + } + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, "OGM2 packet from %pM on %s suppressed: %s\n", + ogm_packet->orig, hard_iface->net_dev->name, + type); + + batadv_hardif_put(hard_iface); + continue; + } + + batadv_v_ogm_process_per_outif(bat_priv, ethhdr, ogm_packet, + orig_node, neigh_node, + if_incoming, hard_iface); + + batadv_hardif_put(hard_iface); + } + rcu_read_unlock(); +out: + batadv_orig_node_put(orig_node); + batadv_neigh_node_put(neigh_node); + batadv_hardif_neigh_put(hardif_neigh); +} + +/** + * batadv_v_ogm_packet_recv() - OGM2 receiving handler + * @skb: the received OGM + * @if_incoming: the interface where this OGM has been received + * + * Return: NET_RX_SUCCESS and consume the skb on success or returns NET_RX_DROP + * (without freeing the skb) on failure + */ +int batadv_v_ogm_packet_recv(struct sk_buff *skb, + struct batadv_hard_iface *if_incoming) +{ + struct batadv_priv *bat_priv = netdev_priv(if_incoming->soft_iface); + struct batadv_ogm2_packet *ogm_packet; + struct ethhdr *ethhdr; + int ogm_offset; + u8 *packet_pos; + int ret = NET_RX_DROP; + + /* did we receive a OGM2 packet on an interface that does not have + * B.A.T.M.A.N. V enabled ? 
+ */ + if (strcmp(bat_priv->algo_ops->name, "BATMAN_V") != 0) + goto free_skb; + + if (!batadv_check_management_packet(skb, if_incoming, BATADV_OGM2_HLEN)) + goto free_skb; + + ethhdr = eth_hdr(skb); + if (batadv_is_my_mac(bat_priv, ethhdr->h_source)) + goto free_skb; + + batadv_inc_counter(bat_priv, BATADV_CNT_MGMT_RX); + batadv_add_counter(bat_priv, BATADV_CNT_MGMT_RX_BYTES, + skb->len + ETH_HLEN); + + ogm_offset = 0; + ogm_packet = (struct batadv_ogm2_packet *)skb->data; + + while (batadv_v_ogm_aggr_packet(ogm_offset, skb_headlen(skb), + ogm_packet)) { + batadv_v_ogm_process(skb, ogm_offset, if_incoming); + + ogm_offset += BATADV_OGM2_HLEN; + ogm_offset += ntohs(ogm_packet->tvlv_len); + + packet_pos = skb->data + ogm_offset; + ogm_packet = (struct batadv_ogm2_packet *)packet_pos; + } + + ret = NET_RX_SUCCESS; + +free_skb: + if (ret == NET_RX_SUCCESS) + consume_skb(skb); + else + kfree_skb(skb); + + return ret; +} + +/** + * batadv_v_ogm_init() - initialise the OGM2 engine + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or a negative error code in case of failure + */ +int batadv_v_ogm_init(struct batadv_priv *bat_priv) +{ + struct batadv_ogm2_packet *ogm_packet; + unsigned char *ogm_buff; + u32 random_seqno; + + bat_priv->bat_v.ogm_buff_len = BATADV_OGM2_HLEN; + ogm_buff = kzalloc(bat_priv->bat_v.ogm_buff_len, GFP_ATOMIC); + if (!ogm_buff) + return -ENOMEM; + + bat_priv->bat_v.ogm_buff = ogm_buff; + ogm_packet = (struct batadv_ogm2_packet *)ogm_buff; + ogm_packet->packet_type = BATADV_OGM2; + ogm_packet->version = BATADV_COMPAT_VERSION; + ogm_packet->ttl = BATADV_TTL; + ogm_packet->flags = BATADV_NO_FLAGS; + ogm_packet->throughput = htonl(BATADV_THROUGHPUT_MAX_VALUE); + + /* randomize initial seqno to avoid collision */ + get_random_bytes(&random_seqno, sizeof(random_seqno)); + atomic_set(&bat_priv->bat_v.ogm_seqno, random_seqno); + INIT_DELAYED_WORK(&bat_priv->bat_v.ogm_wq, batadv_v_ogm_send); + + mutex_init(&bat_priv->bat_v.ogm_buff_mutex); + + return 0; +} + +/** + * batadv_v_ogm_free() - free OGM private resources + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_v_ogm_free(struct batadv_priv *bat_priv) +{ + cancel_delayed_work_sync(&bat_priv->bat_v.ogm_wq); + + mutex_lock(&bat_priv->bat_v.ogm_buff_mutex); + + kfree(bat_priv->bat_v.ogm_buff); + bat_priv->bat_v.ogm_buff = NULL; + bat_priv->bat_v.ogm_buff_len = 0; + + mutex_unlock(&bat_priv->bat_v.ogm_buff_mutex); +} diff --git a/net/batman-adv/bat_v_ogm.h b/net/batman-adv/bat_v_ogm.h new file mode 100644 index 000000000..edeffedec --- /dev/null +++ b/net/batman-adv/bat_v_ogm.h @@ -0,0 +1,27 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Antonio Quartulli + */ + +#ifndef _NET_BATMAN_ADV_BAT_V_OGM_H_ +#define _NET_BATMAN_ADV_BAT_V_OGM_H_ + +#include "main.h" + +#include +#include +#include + +int batadv_v_ogm_init(struct batadv_priv *bat_priv); +void batadv_v_ogm_free(struct batadv_priv *bat_priv); +void batadv_v_ogm_aggr_work(struct work_struct *work); +int batadv_v_ogm_iface_enable(struct batadv_hard_iface *hard_iface); +void batadv_v_ogm_iface_disable(struct batadv_hard_iface *hard_iface); +struct batadv_orig_node *batadv_v_ogm_orig_get(struct batadv_priv *bat_priv, + const u8 *addr); +void batadv_v_ogm_primary_iface_set(struct batadv_hard_iface *primary_iface); +int batadv_v_ogm_packet_recv(struct sk_buff *skb, + struct batadv_hard_iface *if_incoming); + +#endif /* _NET_BATMAN_ADV_BAT_V_OGM_H_ */ diff --git a/net/batman-adv/bitarray.c b/net/batman-adv/bitarray.c new file mode 100644 index 000000000..649c41f39 --- /dev/null +++ b/net/batman-adv/bitarray.c @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Simon Wunderlich, Marek Lindner + */ + +#include "bitarray.h" +#include "main.h" + +#include + +#include "log.h" + +/* shift the packet array by n places. */ +static void batadv_bitmap_shift_left(unsigned long *seq_bits, s32 n) +{ + if (n <= 0 || n >= BATADV_TQ_LOCAL_WINDOW_SIZE) + return; + + bitmap_shift_left(seq_bits, seq_bits, n, BATADV_TQ_LOCAL_WINDOW_SIZE); +} + +/** + * batadv_bit_get_packet() - receive and process one packet within the sequence + * number window + * @priv: the bat priv with all the soft interface information + * @seq_bits: pointer to the sequence number receive packet + * @seq_num_diff: difference between the current/received sequence number and + * the last sequence number + * @set_mark: whether this packet should be marked in seq_bits + * + * Return: true if the window was moved (either new or very old), + * false if the window was not moved/shifted. + */ +bool batadv_bit_get_packet(void *priv, unsigned long *seq_bits, + s32 seq_num_diff, int set_mark) +{ + struct batadv_priv *bat_priv = priv; + + /* sequence number is slightly older. We already got a sequence number + * higher than this one, so we just mark it. + */ + if (seq_num_diff <= 0 && seq_num_diff > -BATADV_TQ_LOCAL_WINDOW_SIZE) { + if (set_mark) + batadv_set_bit(seq_bits, -seq_num_diff); + return false; + } + + /* sequence number is slightly newer, so we shift the window and + * set the mark if required + */ + if (seq_num_diff > 0 && seq_num_diff < BATADV_TQ_LOCAL_WINDOW_SIZE) { + batadv_bitmap_shift_left(seq_bits, seq_num_diff); + + if (set_mark) + batadv_set_bit(seq_bits, 0); + return true; + } + + /* sequence number is much newer, probably missed a lot of packets */ + if (seq_num_diff >= BATADV_TQ_LOCAL_WINDOW_SIZE && + seq_num_diff < BATADV_EXPECTED_SEQNO_RANGE) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "We missed a lot of packets (%i) !\n", + seq_num_diff - 1); + bitmap_zero(seq_bits, BATADV_TQ_LOCAL_WINDOW_SIZE); + if (set_mark) + batadv_set_bit(seq_bits, 0); + return true; + } + + /* received a much older packet. The other host either restarted + * or the old packet got delayed somewhere in the network. The + * packet should be dropped without calling this function if the + * seqno window is protected. 
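+ * seqno window is protected.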
+ * + * seq_num_diff <= -BATADV_TQ_LOCAL_WINDOW_SIZE + * or + * seq_num_diff >= BATADV_EXPECTED_SEQNO_RANGE + */ + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Other host probably restarted!\n"); + + bitmap_zero(seq_bits, BATADV_TQ_LOCAL_WINDOW_SIZE); + if (set_mark) + batadv_set_bit(seq_bits, 0); + + return true; +} diff --git a/net/batman-adv/bitarray.h b/net/batman-adv/bitarray.h new file mode 100644 index 000000000..37f7ae413 --- /dev/null +++ b/net/batman-adv/bitarray.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Simon Wunderlich, Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_BITARRAY_H_ +#define _NET_BATMAN_ADV_BITARRAY_H_ + +#include "main.h" + +#include +#include +#include +#include + +/** + * batadv_test_bit() - check if bit is set in the current window + * + * @seq_bits: pointer to the sequence number receive packet + * @last_seqno: latest sequence number in seq_bits + * @curr_seqno: sequence number to test for + * + * Return: true if the corresponding bit in the given seq_bits indicates true + * and curr_seqno is within range of last_seqno. Otherwise returns false. + */ +static inline bool batadv_test_bit(const unsigned long *seq_bits, + u32 last_seqno, u32 curr_seqno) +{ + s32 diff; + + diff = last_seqno - curr_seqno; + if (diff < 0 || diff >= BATADV_TQ_LOCAL_WINDOW_SIZE) + return false; + return test_bit(diff, seq_bits) != 0; +} + +/** + * batadv_set_bit() - Turn corresponding bit on, so we can remember that we got + * the packet + * @seq_bits: bitmap of the packet receive window + * @n: relative sequence number of newly received packet + */ +static inline void batadv_set_bit(unsigned long *seq_bits, s32 n) +{ + /* if too old, just drop it */ + if (n < 0 || n >= BATADV_TQ_LOCAL_WINDOW_SIZE) + return; + + set_bit(n, seq_bits); /* turn the position on */ +} + +bool batadv_bit_get_packet(void *priv, unsigned long *seq_bits, + s32 seq_num_diff, int set_mark); + +#endif /* _NET_BATMAN_ADV_BITARRAY_H_ */ diff --git a/net/batman-adv/bridge_loop_avoidance.c b/net/batman-adv/bridge_loop_avoidance.c new file mode 100644 index 000000000..37ce6cfb3 --- /dev/null +++ b/net/batman-adv/bridge_loop_avoidance.c @@ -0,0 +1,2502 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Simon Wunderlich + */ + +#include "bridge_loop_avoidance.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" +#include "soft-interface.h" +#include "translation-table.h" + +static const u8 batadv_announce_mac[4] = {0x43, 0x05, 0x43, 0x05}; + +static void batadv_bla_periodic_work(struct work_struct *work); +static void +batadv_bla_send_announce(struct batadv_priv *bat_priv, + struct batadv_bla_backbone_gw *backbone_gw); + +/** + * batadv_choose_claim() - choose the right bucket for a claim. 
+ * @data: data to hash + * @size: size of the hash table + * + * Return: the hash index of the claim + */ +static inline u32 batadv_choose_claim(const void *data, u32 size) +{ + const struct batadv_bla_claim *claim = data; + u32 hash = 0; + + hash = jhash(&claim->addr, sizeof(claim->addr), hash); + hash = jhash(&claim->vid, sizeof(claim->vid), hash); + + return hash % size; +} + +/** + * batadv_choose_backbone_gw() - choose the right bucket for a backbone gateway. + * @data: data to hash + * @size: size of the hash table + * + * Return: the hash index of the backbone gateway + */ +static inline u32 batadv_choose_backbone_gw(const void *data, u32 size) +{ + const struct batadv_bla_backbone_gw *gw; + u32 hash = 0; + + gw = data; + hash = jhash(&gw->orig, sizeof(gw->orig), hash); + hash = jhash(&gw->vid, sizeof(gw->vid), hash); + + return hash % size; +} + +/** + * batadv_compare_backbone_gw() - compare address and vid of two backbone gws + * @node: list node of the first entry to compare + * @data2: pointer to the second backbone gateway + * + * Return: true if the backbones have the same data, false otherwise + */ +static bool batadv_compare_backbone_gw(const struct hlist_node *node, + const void *data2) +{ + const void *data1 = container_of(node, struct batadv_bla_backbone_gw, + hash_entry); + const struct batadv_bla_backbone_gw *gw1 = data1; + const struct batadv_bla_backbone_gw *gw2 = data2; + + if (!batadv_compare_eth(gw1->orig, gw2->orig)) + return false; + + if (gw1->vid != gw2->vid) + return false; + + return true; +} + +/** + * batadv_compare_claim() - compare address and vid of two claims + * @node: list node of the first entry to compare + * @data2: pointer to the second claims + * + * Return: true if the claim have the same data, 0 otherwise + */ +static bool batadv_compare_claim(const struct hlist_node *node, + const void *data2) +{ + const void *data1 = container_of(node, struct batadv_bla_claim, + hash_entry); + const struct batadv_bla_claim *cl1 = data1; + const struct batadv_bla_claim *cl2 = data2; + + if (!batadv_compare_eth(cl1->addr, cl2->addr)) + return false; + + if (cl1->vid != cl2->vid) + return false; + + return true; +} + +/** + * batadv_backbone_gw_release() - release backbone gw from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the backbone gw + */ +static void batadv_backbone_gw_release(struct kref *ref) +{ + struct batadv_bla_backbone_gw *backbone_gw; + + backbone_gw = container_of(ref, struct batadv_bla_backbone_gw, + refcount); + + kfree_rcu(backbone_gw, rcu); +} + +/** + * batadv_backbone_gw_put() - decrement the backbone gw refcounter and possibly + * release it + * @backbone_gw: backbone gateway to be free'd + */ +static void batadv_backbone_gw_put(struct batadv_bla_backbone_gw *backbone_gw) +{ + if (!backbone_gw) + return; + + kref_put(&backbone_gw->refcount, batadv_backbone_gw_release); +} + +/** + * batadv_claim_release() - release claim from lists and queue for free after + * rcu grace period + * @ref: kref pointer of the claim + */ +static void batadv_claim_release(struct kref *ref) +{ + struct batadv_bla_claim *claim; + struct batadv_bla_backbone_gw *old_backbone_gw; + + claim = container_of(ref, struct batadv_bla_claim, refcount); + + spin_lock_bh(&claim->backbone_lock); + old_backbone_gw = claim->backbone_gw; + claim->backbone_gw = NULL; + spin_unlock_bh(&claim->backbone_lock); + + spin_lock_bh(&old_backbone_gw->crc_lock); + old_backbone_gw->crc ^= crc16(0, claim->addr, ETH_ALEN); + 
spin_unlock_bh(&old_backbone_gw->crc_lock); + + batadv_backbone_gw_put(old_backbone_gw); + + kfree_rcu(claim, rcu); +} + +/** + * batadv_claim_put() - decrement the claim refcounter and possibly release it + * @claim: claim to be free'd + */ +static void batadv_claim_put(struct batadv_bla_claim *claim) +{ + if (!claim) + return; + + kref_put(&claim->refcount, batadv_claim_release); +} + +/** + * batadv_claim_hash_find() - looks for a claim in the claim hash + * @bat_priv: the bat priv with all the soft interface information + * @data: search data (may be local/static data) + * + * Return: claim if found or NULL otherwise. + */ +static struct batadv_bla_claim * +batadv_claim_hash_find(struct batadv_priv *bat_priv, + struct batadv_bla_claim *data) +{ + struct batadv_hashtable *hash = bat_priv->bla.claim_hash; + struct hlist_head *head; + struct batadv_bla_claim *claim; + struct batadv_bla_claim *claim_tmp = NULL; + int index; + + if (!hash) + return NULL; + + index = batadv_choose_claim(data, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(claim, head, hash_entry) { + if (!batadv_compare_claim(&claim->hash_entry, data)) + continue; + + if (!kref_get_unless_zero(&claim->refcount)) + continue; + + claim_tmp = claim; + break; + } + rcu_read_unlock(); + + return claim_tmp; +} + +/** + * batadv_backbone_hash_find() - looks for a backbone gateway in the hash + * @bat_priv: the bat priv with all the soft interface information + * @addr: the address of the originator + * @vid: the VLAN ID + * + * Return: backbone gateway if found or NULL otherwise + */ +static struct batadv_bla_backbone_gw * +batadv_backbone_hash_find(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid) +{ + struct batadv_hashtable *hash = bat_priv->bla.backbone_hash; + struct hlist_head *head; + struct batadv_bla_backbone_gw search_entry, *backbone_gw; + struct batadv_bla_backbone_gw *backbone_gw_tmp = NULL; + int index; + + if (!hash) + return NULL; + + ether_addr_copy(search_entry.orig, addr); + search_entry.vid = vid; + + index = batadv_choose_backbone_gw(&search_entry, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(backbone_gw, head, hash_entry) { + if (!batadv_compare_backbone_gw(&backbone_gw->hash_entry, + &search_entry)) + continue; + + if (!kref_get_unless_zero(&backbone_gw->refcount)) + continue; + + backbone_gw_tmp = backbone_gw; + break; + } + rcu_read_unlock(); + + return backbone_gw_tmp; +} + +/** + * batadv_bla_del_backbone_claims() - delete all claims for a backbone + * @backbone_gw: backbone gateway where the claims should be removed + */ +static void +batadv_bla_del_backbone_claims(struct batadv_bla_backbone_gw *backbone_gw) +{ + struct batadv_hashtable *hash; + struct hlist_node *node_tmp; + struct hlist_head *head; + struct batadv_bla_claim *claim; + int i; + spinlock_t *list_lock; /* protects write access to the hash lists */ + + hash = backbone_gw->bat_priv->bla.claim_hash; + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(claim, node_tmp, + head, hash_entry) { + if (claim->backbone_gw != backbone_gw) + continue; + + batadv_claim_put(claim); + hlist_del_rcu(&claim->hash_entry); + } + spin_unlock_bh(list_lock); + } + + /* all claims gone, initialize CRC */ + spin_lock_bh(&backbone_gw->crc_lock); + backbone_gw->crc = BATADV_BLA_CRC_INIT; + spin_unlock_bh(&backbone_gw->crc_lock); +} + 
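+/* The backbone gateway's claim checksum (crc) is maintained incrementally:
+ * it is the XOR of crc16() over the MAC addresses of all clients currently
+ * claimed by that gateway. Adding a claim XORs its crc16 in and removing a
+ * claim XORs the same value out again, so the checksum never needs to be
+ * recomputed from scratch and ANNOUNCE frames can carry it as a cheap
+ * consistency check of the whole claim table.
+ */
+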
+/** + * batadv_bla_send_claim() - sends a claim frame according to the provided info + * @bat_priv: the bat priv with all the soft interface information + * @mac: the mac address to be announced within the claim + * @vid: the VLAN ID + * @claimtype: the type of the claim (CLAIM, UNCLAIM, ANNOUNCE, ...) + */ +static void batadv_bla_send_claim(struct batadv_priv *bat_priv, const u8 *mac, + unsigned short vid, int claimtype) +{ + struct sk_buff *skb; + struct ethhdr *ethhdr; + struct batadv_hard_iface *primary_if; + struct net_device *soft_iface; + u8 *hw_src; + struct batadv_bla_claim_dst local_claim_dest; + __be32 zeroip = 0; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return; + + memcpy(&local_claim_dest, &bat_priv->bla.claim_dest, + sizeof(local_claim_dest)); + local_claim_dest.type = claimtype; + + soft_iface = primary_if->soft_iface; + + skb = arp_create(ARPOP_REPLY, ETH_P_ARP, + /* IP DST: 0.0.0.0 */ + zeroip, + primary_if->soft_iface, + /* IP SRC: 0.0.0.0 */ + zeroip, + /* Ethernet DST: Broadcast */ + NULL, + /* Ethernet SRC/HW SRC: originator mac */ + primary_if->net_dev->dev_addr, + /* HW DST: FF:43:05:XX:YY:YY + * with XX = claim type + * and YY:YY = group id + */ + (u8 *)&local_claim_dest); + + if (!skb) + goto out; + + ethhdr = (struct ethhdr *)skb->data; + hw_src = (u8 *)ethhdr + ETH_HLEN + sizeof(struct arphdr); + + /* now we pretend that the client would have sent this ... */ + switch (claimtype) { + case BATADV_CLAIM_TYPE_CLAIM: + /* normal claim frame + * set Ethernet SRC to the clients mac + */ + ether_addr_copy(ethhdr->h_source, mac); + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): CLAIM %pM on vid %d\n", __func__, mac, + batadv_print_vid(vid)); + break; + case BATADV_CLAIM_TYPE_UNCLAIM: + /* unclaim frame + * set HW SRC to the clients mac + */ + ether_addr_copy(hw_src, mac); + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): UNCLAIM %pM on vid %d\n", __func__, mac, + batadv_print_vid(vid)); + break; + case BATADV_CLAIM_TYPE_ANNOUNCE: + /* announcement frame + * set HW SRC to the special mac containing the crc + */ + ether_addr_copy(hw_src, mac); + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): ANNOUNCE of %pM on vid %d\n", __func__, + ethhdr->h_source, batadv_print_vid(vid)); + break; + case BATADV_CLAIM_TYPE_REQUEST: + /* request frame + * set HW SRC and header destination to the receiving backbone + * gws mac + */ + ether_addr_copy(hw_src, mac); + ether_addr_copy(ethhdr->h_dest, mac); + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): REQUEST of %pM to %pM on vid %d\n", __func__, + ethhdr->h_source, ethhdr->h_dest, + batadv_print_vid(vid)); + break; + case BATADV_CLAIM_TYPE_LOOPDETECT: + ether_addr_copy(ethhdr->h_source, mac); + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): LOOPDETECT of %pM to %pM on vid %d\n", + __func__, ethhdr->h_source, ethhdr->h_dest, + batadv_print_vid(vid)); + + break; + } + + if (vid & BATADV_VLAN_HAS_TAG) { + skb = vlan_insert_tag(skb, htons(ETH_P_8021Q), + vid & VLAN_VID_MASK); + if (!skb) + goto out; + } + + skb_reset_mac_header(skb); + skb->protocol = eth_type_trans(skb, soft_iface); + batadv_inc_counter(bat_priv, BATADV_CNT_RX); + batadv_add_counter(bat_priv, BATADV_CNT_RX_BYTES, + skb->len + ETH_HLEN); + + netif_rx(skb); +out: + batadv_hardif_put(primary_if); +} + +/** + * batadv_bla_loopdetect_report() - worker for reporting the loop + * @work: work queue item + * + * Throws an uevent, as the loopdetect check function can't do that itself + * since the kernel may sleep while throwing uevents. 
+ */ +static void batadv_bla_loopdetect_report(struct work_struct *work) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct batadv_priv *bat_priv; + char vid_str[6] = { '\0' }; + + backbone_gw = container_of(work, struct batadv_bla_backbone_gw, + report_work); + bat_priv = backbone_gw->bat_priv; + + batadv_info(bat_priv->soft_iface, + "Possible loop on VLAN %d detected which can't be handled by BLA - please check your network setup!\n", + batadv_print_vid(backbone_gw->vid)); + snprintf(vid_str, sizeof(vid_str), "%d", + batadv_print_vid(backbone_gw->vid)); + vid_str[sizeof(vid_str) - 1] = 0; + + batadv_throw_uevent(bat_priv, BATADV_UEV_BLA, BATADV_UEV_LOOPDETECT, + vid_str); + + batadv_backbone_gw_put(backbone_gw); +} + +/** + * batadv_bla_get_backbone_gw() - finds or creates a backbone gateway + * @bat_priv: the bat priv with all the soft interface information + * @orig: the mac address of the originator + * @vid: the VLAN ID + * @own_backbone: set if the requested backbone is local + * + * Return: the (possibly created) backbone gateway or NULL on error + */ +static struct batadv_bla_backbone_gw * +batadv_bla_get_backbone_gw(struct batadv_priv *bat_priv, const u8 *orig, + unsigned short vid, bool own_backbone) +{ + struct batadv_bla_backbone_gw *entry; + struct batadv_orig_node *orig_node; + int hash_added; + + entry = batadv_backbone_hash_find(bat_priv, orig, vid); + + if (entry) + return entry; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): not found (%pM, %d), creating new entry\n", __func__, + orig, batadv_print_vid(vid)); + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + return NULL; + + entry->vid = vid; + entry->lasttime = jiffies; + entry->crc = BATADV_BLA_CRC_INIT; + entry->bat_priv = bat_priv; + spin_lock_init(&entry->crc_lock); + atomic_set(&entry->request_sent, 0); + atomic_set(&entry->wait_periods, 0); + ether_addr_copy(entry->orig, orig); + INIT_WORK(&entry->report_work, batadv_bla_loopdetect_report); + kref_init(&entry->refcount); + + kref_get(&entry->refcount); + hash_added = batadv_hash_add(bat_priv->bla.backbone_hash, + batadv_compare_backbone_gw, + batadv_choose_backbone_gw, entry, + &entry->hash_entry); + + if (unlikely(hash_added != 0)) { + /* hash failed, free the structure */ + kfree(entry); + return NULL; + } + + /* this is a gateway now, remove any TT entry on this VLAN */ + orig_node = batadv_orig_hash_find(bat_priv, orig); + if (orig_node) { + batadv_tt_global_del_orig(bat_priv, orig_node, vid, + "became a backbone gateway"); + batadv_orig_node_put(orig_node); + } + + if (own_backbone) { + batadv_bla_send_announce(bat_priv, entry); + + /* this will be decreased in the worker thread */ + atomic_inc(&entry->request_sent); + atomic_set(&entry->wait_periods, BATADV_BLA_WAIT_PERIODS); + atomic_inc(&bat_priv->bla.num_requests); + } + + return entry; +} + +/** + * batadv_bla_update_own_backbone_gw() - updates the own backbone gw for a VLAN + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the selected primary interface + * @vid: VLAN identifier + * + * update or add the own backbone gw to make sure we announce + * where we receive other backbone gws + */ +static void +batadv_bla_update_own_backbone_gw(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + unsigned short vid) +{ + struct batadv_bla_backbone_gw *backbone_gw; + + backbone_gw = batadv_bla_get_backbone_gw(bat_priv, + primary_if->net_dev->dev_addr, + vid, true); + if (unlikely(!backbone_gw)) + return; + + backbone_gw->lasttime = 
jiffies; + batadv_backbone_gw_put(backbone_gw); +} + +/** + * batadv_bla_answer_request() - answer a bla request by sending own claims + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: interface where the request came on + * @vid: the vid where the request came on + * + * Repeat all of our own claims, and finally send an ANNOUNCE frame + * to allow the requester another check if the CRC is correct now. + */ +static void batadv_bla_answer_request(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + unsigned short vid) +{ + struct hlist_head *head; + struct batadv_hashtable *hash; + struct batadv_bla_claim *claim; + struct batadv_bla_backbone_gw *backbone_gw; + int i; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): received a claim request, send all of our own claims again\n", + __func__); + + backbone_gw = batadv_backbone_hash_find(bat_priv, + primary_if->net_dev->dev_addr, + vid); + if (!backbone_gw) + return; + + hash = bat_priv->bla.claim_hash; + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(claim, head, hash_entry) { + /* only own claims are interesting */ + if (claim->backbone_gw != backbone_gw) + continue; + + batadv_bla_send_claim(bat_priv, claim->addr, claim->vid, + BATADV_CLAIM_TYPE_CLAIM); + } + rcu_read_unlock(); + } + + /* finally, send an announcement frame */ + batadv_bla_send_announce(bat_priv, backbone_gw); + batadv_backbone_gw_put(backbone_gw); +} + +/** + * batadv_bla_send_request() - send a request to repeat claims + * @backbone_gw: the backbone gateway from whom we are out of sync + * + * When the crc is wrong, ask the backbone gateway for a full table update. + * After the request, it will repeat all of his own claims and finally + * send an announcement claim with which we can check again. + */ +static void batadv_bla_send_request(struct batadv_bla_backbone_gw *backbone_gw) +{ + /* first, remove all old entries */ + batadv_bla_del_backbone_claims(backbone_gw); + + batadv_dbg(BATADV_DBG_BLA, backbone_gw->bat_priv, + "Sending REQUEST to %pM\n", backbone_gw->orig); + + /* send request */ + batadv_bla_send_claim(backbone_gw->bat_priv, backbone_gw->orig, + backbone_gw->vid, BATADV_CLAIM_TYPE_REQUEST); + + /* no local broadcasts should be sent or received, for now. 
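+ * While bla.num_requests is non-zero, the BLA rx/tx handlers drop
+ * broadcast frames, until the answering claims and a matching ANNOUNCE
+ * have been received.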
*/ + if (!atomic_read(&backbone_gw->request_sent)) { + atomic_inc(&backbone_gw->bat_priv->bla.num_requests); + atomic_set(&backbone_gw->request_sent, 1); + } +} + +/** + * batadv_bla_send_announce() - Send an announcement frame + * @bat_priv: the bat priv with all the soft interface information + * @backbone_gw: our backbone gateway which should be announced + */ +static void batadv_bla_send_announce(struct batadv_priv *bat_priv, + struct batadv_bla_backbone_gw *backbone_gw) +{ + u8 mac[ETH_ALEN]; + __be16 crc; + + memcpy(mac, batadv_announce_mac, 4); + spin_lock_bh(&backbone_gw->crc_lock); + crc = htons(backbone_gw->crc); + spin_unlock_bh(&backbone_gw->crc_lock); + memcpy(&mac[4], &crc, 2); + + batadv_bla_send_claim(bat_priv, mac, backbone_gw->vid, + BATADV_CLAIM_TYPE_ANNOUNCE); +} + +/** + * batadv_bla_add_claim() - Adds a claim in the claim hash + * @bat_priv: the bat priv with all the soft interface information + * @mac: the mac address of the claim + * @vid: the VLAN ID of the frame + * @backbone_gw: the backbone gateway which claims it + */ +static void batadv_bla_add_claim(struct batadv_priv *bat_priv, + const u8 *mac, const unsigned short vid, + struct batadv_bla_backbone_gw *backbone_gw) +{ + struct batadv_bla_backbone_gw *old_backbone_gw; + struct batadv_bla_claim *claim; + struct batadv_bla_claim search_claim; + bool remove_crc = false; + int hash_added; + + ether_addr_copy(search_claim.addr, mac); + search_claim.vid = vid; + claim = batadv_claim_hash_find(bat_priv, &search_claim); + + /* create a new claim entry if it does not exist yet. */ + if (!claim) { + claim = kzalloc(sizeof(*claim), GFP_ATOMIC); + if (!claim) + return; + + ether_addr_copy(claim->addr, mac); + spin_lock_init(&claim->backbone_lock); + claim->vid = vid; + claim->lasttime = jiffies; + kref_get(&backbone_gw->refcount); + claim->backbone_gw = backbone_gw; + kref_init(&claim->refcount); + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): adding new entry %pM, vid %d to hash ...\n", + __func__, mac, batadv_print_vid(vid)); + + kref_get(&claim->refcount); + hash_added = batadv_hash_add(bat_priv->bla.claim_hash, + batadv_compare_claim, + batadv_choose_claim, claim, + &claim->hash_entry); + + if (unlikely(hash_added != 0)) { + /* only local changes happened. 
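+ * (the new claim could not be inserted into the claim hash, so drop
+ * it again)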
*/ + kfree(claim); + return; + } + } else { + claim->lasttime = jiffies; + if (claim->backbone_gw == backbone_gw) + /* no need to register a new backbone */ + goto claim_free_ref; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): changing ownership for %pM, vid %d to gw %pM\n", + __func__, mac, batadv_print_vid(vid), + backbone_gw->orig); + + remove_crc = true; + } + + /* replace backbone_gw atomically and adjust reference counters */ + spin_lock_bh(&claim->backbone_lock); + old_backbone_gw = claim->backbone_gw; + kref_get(&backbone_gw->refcount); + claim->backbone_gw = backbone_gw; + spin_unlock_bh(&claim->backbone_lock); + + if (remove_crc) { + /* remove claim address from old backbone_gw */ + spin_lock_bh(&old_backbone_gw->crc_lock); + old_backbone_gw->crc ^= crc16(0, claim->addr, ETH_ALEN); + spin_unlock_bh(&old_backbone_gw->crc_lock); + } + + batadv_backbone_gw_put(old_backbone_gw); + + /* add claim address to new backbone_gw */ + spin_lock_bh(&backbone_gw->crc_lock); + backbone_gw->crc ^= crc16(0, claim->addr, ETH_ALEN); + spin_unlock_bh(&backbone_gw->crc_lock); + backbone_gw->lasttime = jiffies; + +claim_free_ref: + batadv_claim_put(claim); +} + +/** + * batadv_bla_claim_get_backbone_gw() - Get valid reference for backbone_gw of + * claim + * @claim: claim whose backbone_gw should be returned + * + * Return: valid reference to claim::backbone_gw + */ +static struct batadv_bla_backbone_gw * +batadv_bla_claim_get_backbone_gw(struct batadv_bla_claim *claim) +{ + struct batadv_bla_backbone_gw *backbone_gw; + + spin_lock_bh(&claim->backbone_lock); + backbone_gw = claim->backbone_gw; + kref_get(&backbone_gw->refcount); + spin_unlock_bh(&claim->backbone_lock); + + return backbone_gw; +} + +/** + * batadv_bla_del_claim() - delete a claim from the claim hash + * @bat_priv: the bat priv with all the soft interface information + * @mac: mac address of the claim to be removed + * @vid: VLAN id for the claim to be removed + */ +static void batadv_bla_del_claim(struct batadv_priv *bat_priv, + const u8 *mac, const unsigned short vid) +{ + struct batadv_bla_claim search_claim, *claim; + struct batadv_bla_claim *claim_removed_entry; + struct hlist_node *claim_removed_node; + + ether_addr_copy(search_claim.addr, mac); + search_claim.vid = vid; + claim = batadv_claim_hash_find(bat_priv, &search_claim); + if (!claim) + return; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, "%s(): %pM, vid %d\n", __func__, + mac, batadv_print_vid(vid)); + + claim_removed_node = batadv_hash_remove(bat_priv->bla.claim_hash, + batadv_compare_claim, + batadv_choose_claim, claim); + if (!claim_removed_node) + goto free_claim; + + /* reference from the hash is gone */ + claim_removed_entry = hlist_entry(claim_removed_node, + struct batadv_bla_claim, hash_entry); + batadv_claim_put(claim_removed_entry); + +free_claim: + /* don't need the reference from hash_find() anymore */ + batadv_claim_put(claim); +} + +/** + * batadv_handle_announce() - check for ANNOUNCE frame + * @bat_priv: the bat priv with all the soft interface information + * @an_addr: announcement mac address (ARP Sender HW address) + * @backbone_addr: originator address of the sender (Ethernet source MAC) + * @vid: the VLAN ID of the frame + * + * Return: true if handled + */ +static bool batadv_handle_announce(struct batadv_priv *bat_priv, u8 *an_addr, + u8 *backbone_addr, unsigned short vid) +{ + struct batadv_bla_backbone_gw *backbone_gw; + u16 backbone_crc, crc; + + if (memcmp(an_addr, batadv_announce_mac, 4) != 0) + return false; + + backbone_gw = 
batadv_bla_get_backbone_gw(bat_priv, backbone_addr, vid, + false); + + if (unlikely(!backbone_gw)) + return true; + + /* handle as ANNOUNCE frame */ + backbone_gw->lasttime = jiffies; + crc = ntohs(*((__force __be16 *)(&an_addr[4]))); + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): ANNOUNCE vid %d (sent by %pM)... CRC = %#.4x\n", + __func__, batadv_print_vid(vid), backbone_gw->orig, crc); + + spin_lock_bh(&backbone_gw->crc_lock); + backbone_crc = backbone_gw->crc; + spin_unlock_bh(&backbone_gw->crc_lock); + + if (backbone_crc != crc) { + batadv_dbg(BATADV_DBG_BLA, backbone_gw->bat_priv, + "%s(): CRC FAILED for %pM/%d (my = %#.4x, sent = %#.4x)\n", + __func__, backbone_gw->orig, + batadv_print_vid(backbone_gw->vid), + backbone_crc, crc); + + batadv_bla_send_request(backbone_gw); + } else { + /* if we have sent a request and the crc was OK, + * we can allow traffic again. + */ + if (atomic_read(&backbone_gw->request_sent)) { + atomic_dec(&backbone_gw->bat_priv->bla.num_requests); + atomic_set(&backbone_gw->request_sent, 0); + } + } + + batadv_backbone_gw_put(backbone_gw); + return true; +} + +/** + * batadv_handle_request() - check for REQUEST frame + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the primary hard interface of this batman soft interface + * @backbone_addr: backbone address to be requested (ARP sender HW MAC) + * @ethhdr: ethernet header of a packet + * @vid: the VLAN ID of the frame + * + * Return: true if handled + */ +static bool batadv_handle_request(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + u8 *backbone_addr, struct ethhdr *ethhdr, + unsigned short vid) +{ + /* check for REQUEST frame */ + if (!batadv_compare_eth(backbone_addr, ethhdr->h_dest)) + return false; + + /* sanity check, this should not happen on a normal switch, + * we ignore it in this case. 
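+ * (the REQUEST is addressed to a different backbone gateway, so treat
+ * the frame as handled without answering it)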
+ */ + if (!batadv_compare_eth(ethhdr->h_dest, primary_if->net_dev->dev_addr)) + return true; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): REQUEST vid %d (sent by %pM)...\n", + __func__, batadv_print_vid(vid), ethhdr->h_source); + + batadv_bla_answer_request(bat_priv, primary_if, vid); + return true; +} + +/** + * batadv_handle_unclaim() - check for UNCLAIM frame + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the primary hard interface of this batman soft interface + * @backbone_addr: originator address of the backbone (Ethernet source) + * @claim_addr: Client to be unclaimed (ARP sender HW MAC) + * @vid: the VLAN ID of the frame + * + * Return: true if handled + */ +static bool batadv_handle_unclaim(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + const u8 *backbone_addr, const u8 *claim_addr, + unsigned short vid) +{ + struct batadv_bla_backbone_gw *backbone_gw; + + /* unclaim in any case if it is our own */ + if (primary_if && batadv_compare_eth(backbone_addr, + primary_if->net_dev->dev_addr)) + batadv_bla_send_claim(bat_priv, claim_addr, vid, + BATADV_CLAIM_TYPE_UNCLAIM); + + backbone_gw = batadv_backbone_hash_find(bat_priv, backbone_addr, vid); + + if (!backbone_gw) + return true; + + /* this must be an UNCLAIM frame */ + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): UNCLAIM %pM on vid %d (sent by %pM)...\n", __func__, + claim_addr, batadv_print_vid(vid), backbone_gw->orig); + + batadv_bla_del_claim(bat_priv, claim_addr, vid); + batadv_backbone_gw_put(backbone_gw); + return true; +} + +/** + * batadv_handle_claim() - check for CLAIM frame + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the primary hard interface of this batman soft interface + * @backbone_addr: originator address of the backbone (Ethernet Source) + * @claim_addr: client mac address to be claimed (ARP sender HW MAC) + * @vid: the VLAN ID of the frame + * + * Return: true if handled + */ +static bool batadv_handle_claim(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + const u8 *backbone_addr, const u8 *claim_addr, + unsigned short vid) +{ + struct batadv_bla_backbone_gw *backbone_gw; + + /* register the gateway if not yet available, and add the claim. */ + + backbone_gw = batadv_bla_get_backbone_gw(bat_priv, backbone_addr, vid, + false); + + if (unlikely(!backbone_gw)) + return true; + + /* this must be a CLAIM frame */ + batadv_bla_add_claim(bat_priv, claim_addr, vid, backbone_gw); + if (batadv_compare_eth(backbone_addr, primary_if->net_dev->dev_addr)) + batadv_bla_send_claim(bat_priv, claim_addr, vid, + BATADV_CLAIM_TYPE_CLAIM); + + /* TODO: we could call something like tt_local_del() here. */ + + batadv_backbone_gw_put(backbone_gw); + return true; +} + +/** + * batadv_check_claim_group() - check for claim group membership + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the primary interface of this batman interface + * @hw_src: the Hardware source in the ARP Header + * @hw_dst: the Hardware destination in the ARP Header + * @ethhdr: pointer to the Ethernet header of the claim frame + * + * checks if it is a claim packet and if it's on the same group. + * This function also applies the group ID of the sender + * if it is in the same mesh. 
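+ * The numerically larger group ID wins, so all backbone gateways in a
+ * mesh eventually converge on the same claim group.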
+ * + * Return: + * 2 - if it is a claim packet and on the same group + * 1 - if is a claim packet from another group + * 0 - if it is not a claim packet + */ +static int batadv_check_claim_group(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + u8 *hw_src, u8 *hw_dst, + struct ethhdr *ethhdr) +{ + u8 *backbone_addr; + struct batadv_orig_node *orig_node; + struct batadv_bla_claim_dst *bla_dst, *bla_dst_own; + + bla_dst = (struct batadv_bla_claim_dst *)hw_dst; + bla_dst_own = &bat_priv->bla.claim_dest; + + /* if announcement packet, use the source, + * otherwise assume it is in the hw_src + */ + switch (bla_dst->type) { + case BATADV_CLAIM_TYPE_CLAIM: + backbone_addr = hw_src; + break; + case BATADV_CLAIM_TYPE_REQUEST: + case BATADV_CLAIM_TYPE_ANNOUNCE: + case BATADV_CLAIM_TYPE_UNCLAIM: + backbone_addr = ethhdr->h_source; + break; + default: + return 0; + } + + /* don't accept claim frames from ourselves */ + if (batadv_compare_eth(backbone_addr, primary_if->net_dev->dev_addr)) + return 0; + + /* if its already the same group, it is fine. */ + if (bla_dst->group == bla_dst_own->group) + return 2; + + /* lets see if this originator is in our mesh */ + orig_node = batadv_orig_hash_find(bat_priv, backbone_addr); + + /* don't accept claims from gateways which are not in + * the same mesh or group. + */ + if (!orig_node) + return 1; + + /* if our mesh friends mac is bigger, use it for ourselves. */ + if (ntohs(bla_dst->group) > ntohs(bla_dst_own->group)) { + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "taking other backbones claim group: %#.4x\n", + ntohs(bla_dst->group)); + bla_dst_own->group = bla_dst->group; + } + + batadv_orig_node_put(orig_node); + + return 2; +} + +/** + * batadv_bla_process_claim() - Check if this is a claim frame, and process it + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the primary hard interface of this batman soft interface + * @skb: the frame to be checked + * + * Return: true if it was a claim frame, otherwise return false to + * tell the callee that it can use the frame on its own. + */ +static bool batadv_bla_process_claim(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + struct sk_buff *skb) +{ + struct batadv_bla_claim_dst *bla_dst, *bla_dst_own; + u8 *hw_src, *hw_dst; + struct vlan_hdr *vhdr, vhdr_buf; + struct ethhdr *ethhdr; + struct arphdr *arphdr; + unsigned short vid; + int vlan_depth = 0; + __be16 proto; + int headlen; + int ret; + + vid = batadv_get_vid(skb, 0); + ethhdr = eth_hdr(skb); + + proto = ethhdr->h_proto; + headlen = ETH_HLEN; + if (vid & BATADV_VLAN_HAS_TAG) { + /* Traverse the VLAN/Ethertypes. + * + * At this point it is known that the first protocol is a VLAN + * header, so start checking at the encapsulated protocol. + * + * The depth of the VLAN headers is recorded to drop BLA claim + * frames encapsulated into multiple VLAN headers (QinQ). + */ + do { + vhdr = skb_header_pointer(skb, headlen, VLAN_HLEN, + &vhdr_buf); + if (!vhdr) + return false; + + proto = vhdr->h_vlan_encapsulated_proto; + headlen += VLAN_HLEN; + vlan_depth++; + } while (proto == htons(ETH_P_8021Q)); + } + + if (proto != htons(ETH_P_ARP)) + return false; /* not a claim frame */ + + /* this must be a ARP frame. check if it is a claim. 
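In a claim frame the ARP target hardware address is overlaid with the
+ * claim destination: a 3 byte magic (ff:43:05), a claim type and a 16 bit
+ * group id, while the sender hardware address carries type dependent data
+ * such as the backbone address (CLAIM) or the announced CRC (ANNOUNCE).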
*/ + + if (unlikely(!pskb_may_pull(skb, headlen + arp_hdr_len(skb->dev)))) + return false; + + /* pskb_may_pull() may have modified the pointers, get ethhdr again */ + ethhdr = eth_hdr(skb); + arphdr = (struct arphdr *)((u8 *)ethhdr + headlen); + + /* Check whether the ARP frame carries a valid + * IP information + */ + if (arphdr->ar_hrd != htons(ARPHRD_ETHER)) + return false; + if (arphdr->ar_pro != htons(ETH_P_IP)) + return false; + if (arphdr->ar_hln != ETH_ALEN) + return false; + if (arphdr->ar_pln != 4) + return false; + + hw_src = (u8 *)arphdr + sizeof(struct arphdr); + hw_dst = hw_src + ETH_ALEN + 4; + bla_dst = (struct batadv_bla_claim_dst *)hw_dst; + bla_dst_own = &bat_priv->bla.claim_dest; + + /* check if it is a claim frame in general */ + if (memcmp(bla_dst->magic, bla_dst_own->magic, + sizeof(bla_dst->magic)) != 0) + return false; + + /* check if there is a claim frame encapsulated deeper in (QinQ) and + * drop that, as this is not supported by BLA but should also not be + * sent via the mesh. + */ + if (vlan_depth > 1) + return true; + + /* Let the loopdetect frames on the mesh in any case. */ + if (bla_dst->type == BATADV_CLAIM_TYPE_LOOPDETECT) + return false; + + /* check if it is a claim frame. */ + ret = batadv_check_claim_group(bat_priv, primary_if, hw_src, hw_dst, + ethhdr); + if (ret == 1) + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): received a claim frame from another group. From: %pM on vid %d ...(hw_src %pM, hw_dst %pM)\n", + __func__, ethhdr->h_source, batadv_print_vid(vid), + hw_src, hw_dst); + + if (ret < 2) + return !!ret; + + /* become a backbone gw ourselves on this vlan if not happened yet */ + batadv_bla_update_own_backbone_gw(bat_priv, primary_if, vid); + + /* check for the different types of claim frames ... */ + switch (bla_dst->type) { + case BATADV_CLAIM_TYPE_CLAIM: + if (batadv_handle_claim(bat_priv, primary_if, hw_src, + ethhdr->h_source, vid)) + return true; + break; + case BATADV_CLAIM_TYPE_UNCLAIM: + if (batadv_handle_unclaim(bat_priv, primary_if, + ethhdr->h_source, hw_src, vid)) + return true; + break; + + case BATADV_CLAIM_TYPE_ANNOUNCE: + if (batadv_handle_announce(bat_priv, hw_src, ethhdr->h_source, + vid)) + return true; + break; + case BATADV_CLAIM_TYPE_REQUEST: + if (batadv_handle_request(bat_priv, primary_if, hw_src, ethhdr, + vid)) + return true; + break; + } + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): ERROR - this looks like a claim frame, but is useless. eth src %pM on vid %d ...(hw_src %pM, hw_dst %pM)\n", + __func__, ethhdr->h_source, batadv_print_vid(vid), hw_src, + hw_dst); + return true; +} + +/** + * batadv_bla_purge_backbone_gw() - Remove backbone gateways after a timeout or + * immediately + * @bat_priv: the bat priv with all the soft interface information + * @now: whether the whole hash shall be wiped now + * + * Check when we last heard from other nodes, and remove them in case of + * a time out, or clean all backbone gws if now is set. 
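+ * Purged gateways also drop all claims that still point at them, and a
+ * request still pending for them no longer blocks multicast traffic.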
+ */ +static void batadv_bla_purge_backbone_gw(struct batadv_priv *bat_priv, int now) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct hlist_node *node_tmp; + struct hlist_head *head; + struct batadv_hashtable *hash; + spinlock_t *list_lock; /* protects write access to the hash lists */ + int i; + + hash = bat_priv->bla.backbone_hash; + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(backbone_gw, node_tmp, + head, hash_entry) { + if (now) + goto purge_now; + if (!batadv_has_timed_out(backbone_gw->lasttime, + BATADV_BLA_BACKBONE_TIMEOUT)) + continue; + + batadv_dbg(BATADV_DBG_BLA, backbone_gw->bat_priv, + "%s(): backbone gw %pM timed out\n", + __func__, backbone_gw->orig); + +purge_now: + /* don't wait for the pending request anymore */ + if (atomic_read(&backbone_gw->request_sent)) + atomic_dec(&bat_priv->bla.num_requests); + + batadv_bla_del_backbone_claims(backbone_gw); + + hlist_del_rcu(&backbone_gw->hash_entry); + batadv_backbone_gw_put(backbone_gw); + } + spin_unlock_bh(list_lock); + } +} + +/** + * batadv_bla_purge_claims() - Remove claims after a timeout or immediately + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the selected primary interface, may be NULL if now is set + * @now: whether the whole hash shall be wiped now + * + * Check when we heard last time from our own claims, and remove them in case of + * a time out, or clean all claims if now is set + */ +static void batadv_bla_purge_claims(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + int now) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct batadv_bla_claim *claim; + struct hlist_head *head; + struct batadv_hashtable *hash; + int i; + + hash = bat_priv->bla.claim_hash; + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(claim, head, hash_entry) { + backbone_gw = batadv_bla_claim_get_backbone_gw(claim); + if (now) + goto purge_now; + + if (!batadv_compare_eth(backbone_gw->orig, + primary_if->net_dev->dev_addr)) + goto skip; + + if (!batadv_has_timed_out(claim->lasttime, + BATADV_BLA_CLAIM_TIMEOUT)) + goto skip; + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): timed out.\n", __func__); + +purge_now: + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): %pM, vid %d\n", __func__, + claim->addr, claim->vid); + + batadv_handle_unclaim(bat_priv, primary_if, + backbone_gw->orig, + claim->addr, claim->vid); +skip: + batadv_backbone_gw_put(backbone_gw); + } + rcu_read_unlock(); + } +} + +/** + * batadv_bla_update_orig_address() - Update the backbone gateways when the own + * originator address changes + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: the new selected primary_if + * @oldif: the old primary interface, may be NULL + */ +void batadv_bla_update_orig_address(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + struct batadv_hard_iface *oldif) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct hlist_head *head; + struct batadv_hashtable *hash; + __be16 group; + int i; + + /* reset bridge loop avoidance group id */ + group = htons(crc16(0, primary_if->net_dev->dev_addr, ETH_ALEN)); + bat_priv->bla.claim_dest.group = group; + + /* purge everything when bridge loop avoidance is turned off */ + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + oldif = NULL; + + if (!oldif) { + 
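/* bridge loop avoidance was switched off or there is no previous
+ * primary interface to translate from: wipe all claims and backbone
+ * gateways and start over with the new group id. */ +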
batadv_bla_purge_claims(bat_priv, NULL, 1); + batadv_bla_purge_backbone_gw(bat_priv, 1); + return; + } + + hash = bat_priv->bla.backbone_hash; + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(backbone_gw, head, hash_entry) { + /* own orig still holds the old value. */ + if (!batadv_compare_eth(backbone_gw->orig, + oldif->net_dev->dev_addr)) + continue; + + ether_addr_copy(backbone_gw->orig, + primary_if->net_dev->dev_addr); + /* send an announce frame so others will ask for our + * claims and update their tables. + */ + batadv_bla_send_announce(bat_priv, backbone_gw); + } + rcu_read_unlock(); + } +} + +/** + * batadv_bla_send_loopdetect() - send a loopdetect frame + * @bat_priv: the bat priv with all the soft interface information + * @backbone_gw: the backbone gateway for which a loop should be detected + * + * To detect loops that the bridge loop avoidance can't handle, send a loop + * detection packet on the backbone. Unlike other BLA frames, this frame will + * be allowed on the mesh by other nodes. If it is received on the mesh, this + * indicates that there is a loop. + */ +static void +batadv_bla_send_loopdetect(struct batadv_priv *bat_priv, + struct batadv_bla_backbone_gw *backbone_gw) +{ + batadv_dbg(BATADV_DBG_BLA, bat_priv, "Send loopdetect frame for vid %d\n", + backbone_gw->vid); + batadv_bla_send_claim(bat_priv, bat_priv->bla.loopdetect_addr, + backbone_gw->vid, BATADV_CLAIM_TYPE_LOOPDETECT); +} + +/** + * batadv_bla_status_update() - purge bla interfaces if necessary + * @net_dev: the soft interface net device + */ +void batadv_bla_status_update(struct net_device *net_dev) +{ + struct batadv_priv *bat_priv = netdev_priv(net_dev); + struct batadv_hard_iface *primary_if; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return; + + /* this function already purges everything when bla is disabled, + * so just call that one. + */ + batadv_bla_update_orig_address(bat_priv, primary_if, primary_if); + batadv_hardif_put(primary_if); +} + +/** + * batadv_bla_periodic_work() - performs periodic bla work + * @work: kernel work struct + * + * periodic work to do: + * * purge structures when they are too old + * * send announcements + */ +static void batadv_bla_periodic_work(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv *bat_priv; + struct batadv_priv_bla *priv_bla; + struct hlist_head *head; + struct batadv_bla_backbone_gw *backbone_gw; + struct batadv_hashtable *hash; + struct batadv_hard_iface *primary_if; + bool send_loopdetect = false; + int i; + + delayed_work = to_delayed_work(work); + priv_bla = container_of(delayed_work, struct batadv_priv_bla, work); + bat_priv = container_of(priv_bla, struct batadv_priv, bla); + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + batadv_bla_purge_claims(bat_priv, primary_if, 0); + batadv_bla_purge_backbone_gw(bat_priv, 0); + + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + goto out; + + if (atomic_dec_and_test(&bat_priv->bla.loopdetect_next)) { + /* set a new random mac address for the next bridge loop + * detection frames. Set the locally administered bit to avoid + * collisions with users mac addresses. 
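+ * The ba:be prefix written below keeps the locally administered bit set
+ * and is what batadv_bla_is_loopdetect_mac() matches on.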
+ */ + eth_random_addr(bat_priv->bla.loopdetect_addr); + bat_priv->bla.loopdetect_addr[0] = 0xba; + bat_priv->bla.loopdetect_addr[1] = 0xbe; + bat_priv->bla.loopdetect_lasttime = jiffies; + atomic_set(&bat_priv->bla.loopdetect_next, + BATADV_BLA_LOOPDETECT_PERIODS); + + /* mark for sending loop detect on all VLANs */ + send_loopdetect = true; + } + + hash = bat_priv->bla.backbone_hash; + if (!hash) + goto out; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(backbone_gw, head, hash_entry) { + if (!batadv_compare_eth(backbone_gw->orig, + primary_if->net_dev->dev_addr)) + continue; + + backbone_gw->lasttime = jiffies; + + batadv_bla_send_announce(bat_priv, backbone_gw); + if (send_loopdetect) + batadv_bla_send_loopdetect(bat_priv, + backbone_gw); + + /* request_sent is only set after creation to avoid + * problems when we are not yet known as backbone gw + * in the backbone. + * + * We can reset this now after we waited some periods + * to give bridge forward delays and bla group forming + * some grace time. + */ + + if (atomic_read(&backbone_gw->request_sent) == 0) + continue; + + if (!atomic_dec_and_test(&backbone_gw->wait_periods)) + continue; + + atomic_dec(&backbone_gw->bat_priv->bla.num_requests); + atomic_set(&backbone_gw->request_sent, 0); + } + rcu_read_unlock(); + } +out: + batadv_hardif_put(primary_if); + + queue_delayed_work(batadv_event_workqueue, &bat_priv->bla.work, + msecs_to_jiffies(BATADV_BLA_PERIOD_LENGTH)); +} + +/* The hash for claim and backbone hash receive the same key because they + * are getting initialized by hash_new with the same key. Reinitializing + * them with to different keys to allow nested locking without generating + * lockdep warnings + */ +static struct lock_class_key batadv_claim_hash_lock_class_key; +static struct lock_class_key batadv_backbone_hash_lock_class_key; + +/** + * batadv_bla_init() - initialize all bla structures + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success, < 0 on error. 
+ */ +int batadv_bla_init(struct batadv_priv *bat_priv) +{ + int i; + u8 claim_dest[ETH_ALEN] = {0xff, 0x43, 0x05, 0x00, 0x00, 0x00}; + struct batadv_hard_iface *primary_if; + u16 crc; + unsigned long entrytime; + + spin_lock_init(&bat_priv->bla.bcast_duplist_lock); + + batadv_dbg(BATADV_DBG_BLA, bat_priv, "bla hash registering\n"); + + /* setting claim destination address */ + memcpy(&bat_priv->bla.claim_dest.magic, claim_dest, 3); + bat_priv->bla.claim_dest.type = 0; + primary_if = batadv_primary_if_get_selected(bat_priv); + if (primary_if) { + crc = crc16(0, primary_if->net_dev->dev_addr, ETH_ALEN); + bat_priv->bla.claim_dest.group = htons(crc); + batadv_hardif_put(primary_if); + } else { + bat_priv->bla.claim_dest.group = 0; /* will be set later */ + } + + /* initialize the duplicate list */ + entrytime = jiffies - msecs_to_jiffies(BATADV_DUPLIST_TIMEOUT); + for (i = 0; i < BATADV_DUPLIST_SIZE; i++) + bat_priv->bla.bcast_duplist[i].entrytime = entrytime; + bat_priv->bla.bcast_duplist_curr = 0; + + atomic_set(&bat_priv->bla.loopdetect_next, + BATADV_BLA_LOOPDETECT_PERIODS); + + if (bat_priv->bla.claim_hash) + return 0; + + bat_priv->bla.claim_hash = batadv_hash_new(128); + if (!bat_priv->bla.claim_hash) + return -ENOMEM; + + bat_priv->bla.backbone_hash = batadv_hash_new(32); + if (!bat_priv->bla.backbone_hash) { + batadv_hash_destroy(bat_priv->bla.claim_hash); + return -ENOMEM; + } + + batadv_hash_set_lock_class(bat_priv->bla.claim_hash, + &batadv_claim_hash_lock_class_key); + batadv_hash_set_lock_class(bat_priv->bla.backbone_hash, + &batadv_backbone_hash_lock_class_key); + + batadv_dbg(BATADV_DBG_BLA, bat_priv, "bla hashes initialized\n"); + + INIT_DELAYED_WORK(&bat_priv->bla.work, batadv_bla_periodic_work); + + queue_delayed_work(batadv_event_workqueue, &bat_priv->bla.work, + msecs_to_jiffies(BATADV_BLA_PERIOD_LENGTH)); + return 0; +} + +/** + * batadv_bla_check_duplist() - Check if a frame is in the broadcast dup. + * @bat_priv: the bat priv with all the soft interface information + * @skb: contains the multicast packet to be checked + * @payload_ptr: pointer to position inside the head buffer of the skb + * marking the start of the data to be CRC'ed + * @orig: originator mac address, NULL if unknown + * + * Check if it is on our broadcast list. Another gateway might have sent the + * same packet because it is connected to the same backbone, so we have to + * remove this duplicate. + * + * This is performed by checking the CRC, which will tell us + * with a good chance that it is the same packet. If it is furthermore + * sent by another host, drop it. We allow equal packets from + * the same host however as this might be intended. + * + * Return: true if a packet is in the duplicate list, false otherwise. + */ +static bool batadv_bla_check_duplist(struct batadv_priv *bat_priv, + struct sk_buff *skb, u8 *payload_ptr, + const u8 *orig) +{ + struct batadv_bcast_duplist_entry *entry; + bool ret = false; + int i, curr; + __be32 crc; + + /* calculate the crc ... 
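of the payload: together with the originator address it identifies a
+ * recently seen frame in the duplicate ring buffer below.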
*/ + crc = batadv_skb_crc32(skb, payload_ptr); + + spin_lock_bh(&bat_priv->bla.bcast_duplist_lock); + + for (i = 0; i < BATADV_DUPLIST_SIZE; i++) { + curr = (bat_priv->bla.bcast_duplist_curr + i); + curr %= BATADV_DUPLIST_SIZE; + entry = &bat_priv->bla.bcast_duplist[curr]; + + /* we can stop searching if the entry is too old ; + * later entries will be even older + */ + if (batadv_has_timed_out(entry->entrytime, + BATADV_DUPLIST_TIMEOUT)) + break; + + if (entry->crc != crc) + continue; + + /* are the originators both known and not anonymous? */ + if (orig && !is_zero_ether_addr(orig) && + !is_zero_ether_addr(entry->orig)) { + /* If known, check if the new frame came from + * the same originator: + * We are safe to take identical frames from the + * same orig, if known, as multiplications in + * the mesh are detected via the (orig, seqno) pair. + * So we can be a bit more liberal here and allow + * identical frames from the same orig which the source + * host might have sent multiple times on purpose. + */ + if (batadv_compare_eth(entry->orig, orig)) + continue; + } + + /* this entry seems to match: same crc, not too old, + * and from another gw. therefore return true to forbid it. + */ + ret = true; + goto out; + } + /* not found, add a new entry (overwrite the oldest entry) + * and allow it, its the first occurrence. + */ + curr = (bat_priv->bla.bcast_duplist_curr + BATADV_DUPLIST_SIZE - 1); + curr %= BATADV_DUPLIST_SIZE; + entry = &bat_priv->bla.bcast_duplist[curr]; + entry->crc = crc; + entry->entrytime = jiffies; + + /* known originator */ + if (orig) + ether_addr_copy(entry->orig, orig); + /* anonymous originator */ + else + eth_zero_addr(entry->orig); + + bat_priv->bla.bcast_duplist_curr = curr; + +out: + spin_unlock_bh(&bat_priv->bla.bcast_duplist_lock); + + return ret; +} + +/** + * batadv_bla_check_ucast_duplist() - Check if a frame is in the broadcast dup. + * @bat_priv: the bat priv with all the soft interface information + * @skb: contains the multicast packet to be checked, decapsulated from a + * unicast_packet + * + * Check if it is on our broadcast list. Another gateway might have sent the + * same packet because it is connected to the same backbone, so we have to + * remove this duplicate. + * + * Return: true if a packet is in the duplicate list, false otherwise. + */ +static bool batadv_bla_check_ucast_duplist(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + return batadv_bla_check_duplist(bat_priv, skb, (u8 *)skb->data, NULL); +} + +/** + * batadv_bla_check_bcast_duplist() - Check if a frame is in the broadcast dup. + * @bat_priv: the bat priv with all the soft interface information + * @skb: contains the bcast_packet to be checked + * + * Check if it is on our broadcast list. Another gateway might have sent the + * same packet because it is connected to the same backbone, so we have to + * remove this duplicate. + * + * Return: true if a packet is in the duplicate list, false otherwise. + */ +bool batadv_bla_check_bcast_duplist(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_bcast_packet *bcast_packet; + u8 *payload_ptr; + + bcast_packet = (struct batadv_bcast_packet *)skb->data; + payload_ptr = (u8 *)(bcast_packet + 1); + + return batadv_bla_check_duplist(bat_priv, skb, payload_ptr, + bcast_packet->orig); +} + +/** + * batadv_bla_is_backbone_gw_orig() - Check if the originator is a gateway for + * the VLAN identified by vid. 
+ * @bat_priv: the bat priv with all the soft interface information + * @orig: originator mac address + * @vid: VLAN identifier + * + * Return: true if orig is a backbone for this vid, false otherwise. + */ +bool batadv_bla_is_backbone_gw_orig(struct batadv_priv *bat_priv, u8 *orig, + unsigned short vid) +{ + struct batadv_hashtable *hash = bat_priv->bla.backbone_hash; + struct hlist_head *head; + struct batadv_bla_backbone_gw *backbone_gw; + int i; + + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + return false; + + if (!hash) + return false; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(backbone_gw, head, hash_entry) { + if (batadv_compare_eth(backbone_gw->orig, orig) && + backbone_gw->vid == vid) { + rcu_read_unlock(); + return true; + } + } + rcu_read_unlock(); + } + + return false; +} + +/** + * batadv_bla_is_backbone_gw() - check if originator is a backbone gw for a VLAN + * @skb: the frame to be checked + * @orig_node: the orig_node of the frame + * @hdr_size: maximum length of the frame + * + * Return: true if the orig_node is also a gateway on the soft interface, + * otherwise it returns false. + */ +bool batadv_bla_is_backbone_gw(struct sk_buff *skb, + struct batadv_orig_node *orig_node, int hdr_size) +{ + struct batadv_bla_backbone_gw *backbone_gw; + unsigned short vid; + + if (!atomic_read(&orig_node->bat_priv->bridge_loop_avoidance)) + return false; + + /* first, find out the vid. */ + if (!pskb_may_pull(skb, hdr_size + ETH_HLEN)) + return false; + + vid = batadv_get_vid(skb, hdr_size); + + /* see if this originator is a backbone gw for this VLAN */ + backbone_gw = batadv_backbone_hash_find(orig_node->bat_priv, + orig_node->orig, vid); + if (!backbone_gw) + return false; + + batadv_backbone_gw_put(backbone_gw); + return true; +} + +/** + * batadv_bla_free() - free all bla structures + * @bat_priv: the bat priv with all the soft interface information + * + * for softinterface free or module unload + */ +void batadv_bla_free(struct batadv_priv *bat_priv) +{ + struct batadv_hard_iface *primary_if; + + cancel_delayed_work_sync(&bat_priv->bla.work); + primary_if = batadv_primary_if_get_selected(bat_priv); + + if (bat_priv->bla.claim_hash) { + batadv_bla_purge_claims(bat_priv, primary_if, 1); + batadv_hash_destroy(bat_priv->bla.claim_hash); + bat_priv->bla.claim_hash = NULL; + } + if (bat_priv->bla.backbone_hash) { + batadv_bla_purge_backbone_gw(bat_priv, 1); + batadv_hash_destroy(bat_priv->bla.backbone_hash); + bat_priv->bla.backbone_hash = NULL; + } + batadv_hardif_put(primary_if); +} + +/** + * batadv_bla_loopdetect_check() - check and handle a detected loop + * @bat_priv: the bat priv with all the soft interface information + * @skb: the packet to check + * @primary_if: interface where the request came on + * @vid: the VLAN ID of the frame + * + * Checks if this packet is a loop detect frame which has been sent by us, + * throws an uevent and logs the event if that is the case. + * + * Return: true if it is a loop detect frame which is to be dropped, false + * otherwise. + */ +static bool +batadv_bla_loopdetect_check(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_hard_iface *primary_if, + unsigned short vid) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct ethhdr *ethhdr; + bool ret; + + ethhdr = eth_hdr(skb); + + /* Only check for the MAC address and skip more checks here for + * performance reasons - this function is on the hotpath, after all. 
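+ * Seeing our own loopdetect address as Ethernet source coming back from
+ * the mesh within BATADV_BLA_LOOPDETECT_TIMEOUT is treated as a loop and
+ * reported from the queued report work.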
+ */ + if (!batadv_compare_eth(ethhdr->h_source, + bat_priv->bla.loopdetect_addr)) + return false; + + /* If the packet came too late, don't forward it on the mesh + * but don't consider that as loop. It might be a coincidence. + */ + if (batadv_has_timed_out(bat_priv->bla.loopdetect_lasttime, + BATADV_BLA_LOOPDETECT_TIMEOUT)) + return true; + + backbone_gw = batadv_bla_get_backbone_gw(bat_priv, + primary_if->net_dev->dev_addr, + vid, true); + if (unlikely(!backbone_gw)) + return true; + + ret = queue_work(batadv_event_workqueue, &backbone_gw->report_work); + + /* backbone_gw is unreferenced in the report work function + * if queue_work() call was successful + */ + if (!ret) + batadv_backbone_gw_put(backbone_gw); + + return true; +} + +/** + * batadv_bla_rx() - check packets coming from the mesh. + * @bat_priv: the bat priv with all the soft interface information + * @skb: the frame to be checked + * @vid: the VLAN ID of the frame + * @packet_type: the batman packet type this frame came in + * + * batadv_bla_rx avoidance checks if: + * * we have to race for a claim + * * if the frame is allowed on the LAN + * + * In these cases, the skb is further handled by this function + * + * Return: true if handled, otherwise it returns false and the caller shall + * further process the skb. + */ +bool batadv_bla_rx(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid, int packet_type) +{ + struct batadv_bla_backbone_gw *backbone_gw; + struct ethhdr *ethhdr; + struct batadv_bla_claim search_claim, *claim = NULL; + struct batadv_hard_iface *primary_if; + bool own_claim; + bool ret; + + ethhdr = eth_hdr(skb); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto handled; + + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + goto allow; + + if (batadv_bla_loopdetect_check(bat_priv, skb, primary_if, vid)) + goto handled; + + if (unlikely(atomic_read(&bat_priv->bla.num_requests))) + /* don't allow multicast packets while requests are in flight */ + if (is_multicast_ether_addr(ethhdr->h_dest)) + /* Both broadcast flooding or multicast-via-unicasts + * delivery might send to multiple backbone gateways + * sharing the same LAN and therefore need to coordinate + * which backbone gateway forwards into the LAN, + * by claiming the payload source address. + * + * Broadcast flooding and multicast-via-unicasts + * delivery use the following two batman packet types. + * Note: explicitly exclude BATADV_UNICAST_4ADDR, + * as the DHCP gateway feature will send explicitly + * to only one BLA gateway, so the claiming process + * should be avoided there. + */ + if (packet_type == BATADV_BCAST || + packet_type == BATADV_UNICAST) + goto handled; + + /* potential duplicates from foreign BLA backbone gateways via + * multicast-in-unicast packets + */ + if (is_multicast_ether_addr(ethhdr->h_dest) && + packet_type == BATADV_UNICAST && + batadv_bla_check_ucast_duplist(bat_priv, skb)) + goto handled; + + ether_addr_copy(search_claim.addr, ethhdr->h_source); + search_claim.vid = vid; + claim = batadv_claim_hash_find(bat_priv, &search_claim); + + if (!claim) { + /* possible optimization: race for a claim */ + /* No claim exists yet, claim it for us! + */ + + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): Unclaimed MAC %pM found. Claim it. Local: %s\n", + __func__, ethhdr->h_source, + batadv_is_my_client(bat_priv, + ethhdr->h_source, vid) ? 
+ "yes" : "no"); + batadv_handle_claim(bat_priv, primary_if, + primary_if->net_dev->dev_addr, + ethhdr->h_source, vid); + goto allow; + } + + /* if it is our own claim ... */ + backbone_gw = batadv_bla_claim_get_backbone_gw(claim); + own_claim = batadv_compare_eth(backbone_gw->orig, + primary_if->net_dev->dev_addr); + batadv_backbone_gw_put(backbone_gw); + + if (own_claim) { + /* ... allow it in any case */ + claim->lasttime = jiffies; + goto allow; + } + + /* if it is a multicast ... */ + if (is_multicast_ether_addr(ethhdr->h_dest) && + (packet_type == BATADV_BCAST || packet_type == BATADV_UNICAST)) { + /* ... drop it. the responsible gateway is in charge. + * + * We need to check packet type because with the gateway + * feature, broadcasts (like DHCP requests) may be sent + * using a unicast 4 address packet type. See comment above. + */ + goto handled; + } else { + /* seems the client considers us as its best gateway. + * send a claim and update the claim table + * immediately. + */ + batadv_handle_claim(bat_priv, primary_if, + primary_if->net_dev->dev_addr, + ethhdr->h_source, vid); + goto allow; + } +allow: + batadv_bla_update_own_backbone_gw(bat_priv, primary_if, vid); + ret = false; + goto out; + +handled: + kfree_skb(skb); + ret = true; + +out: + batadv_hardif_put(primary_if); + batadv_claim_put(claim); + return ret; +} + +/** + * batadv_bla_tx() - check packets going into the mesh + * @bat_priv: the bat priv with all the soft interface information + * @skb: the frame to be checked + * @vid: the VLAN ID of the frame + * + * batadv_bla_tx checks if: + * * a claim was received which has to be processed + * * the frame is allowed on the mesh + * + * in these cases, the skb is further handled by this function. + * + * This call might reallocate skb data. + * + * Return: true if handled, otherwise it returns false and the caller shall + * further process the skb. + */ +bool batadv_bla_tx(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid) +{ + struct ethhdr *ethhdr; + struct batadv_bla_claim search_claim, *claim = NULL; + struct batadv_bla_backbone_gw *backbone_gw; + struct batadv_hard_iface *primary_if; + bool client_roamed; + bool ret = false; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + goto allow; + + if (batadv_bla_process_claim(bat_priv, primary_if, skb)) + goto handled; + + ethhdr = eth_hdr(skb); + + if (unlikely(atomic_read(&bat_priv->bla.num_requests))) + /* don't allow broadcasts while requests are in flight */ + if (is_multicast_ether_addr(ethhdr->h_dest)) + goto handled; + + ether_addr_copy(search_claim.addr, ethhdr->h_source); + search_claim.vid = vid; + + claim = batadv_claim_hash_find(bat_priv, &search_claim); + + /* if no claim exists, allow it. */ + if (!claim) + goto allow; + + /* check if we are responsible. */ + backbone_gw = batadv_bla_claim_get_backbone_gw(claim); + client_roamed = batadv_compare_eth(backbone_gw->orig, + primary_if->net_dev->dev_addr); + batadv_backbone_gw_put(backbone_gw); + + if (client_roamed) { + /* if yes, the client has roamed and we have + * to unclaim it. + */ + if (batadv_has_timed_out(claim->lasttime, 100)) { + /* only unclaim if the last claim entry is + * older than 100 ms to make sure we really + * have a roaming client here. + */ + batadv_dbg(BATADV_DBG_BLA, bat_priv, "%s(): Roaming client %pM detected. 
Unclaim it.\n", + __func__, ethhdr->h_source); + batadv_handle_unclaim(bat_priv, primary_if, + primary_if->net_dev->dev_addr, + ethhdr->h_source, vid); + goto allow; + } else { + batadv_dbg(BATADV_DBG_BLA, bat_priv, "%s(): Race for claim %pM detected. Drop packet.\n", + __func__, ethhdr->h_source); + goto handled; + } + } + + /* check if it is a multicast/broadcast frame */ + if (is_multicast_ether_addr(ethhdr->h_dest)) { + /* drop it. the responsible gateway has forwarded it into + * the backbone network. + */ + goto handled; + } else { + /* we must allow it. at least if we are + * responsible for the DESTINATION. + */ + goto allow; + } +allow: + batadv_bla_update_own_backbone_gw(bat_priv, primary_if, vid); + ret = false; + goto out; +handled: + ret = true; +out: + batadv_hardif_put(primary_if); + batadv_claim_put(claim); + return ret; +} + +/** + * batadv_bla_claim_dump_entry() - dump one entry of the claim table + * to a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @primary_if: primary interface + * @claim: entry to dump + * + * Return: 0 or error code. + */ +static int +batadv_bla_claim_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hard_iface *primary_if, + struct batadv_bla_claim *claim) +{ + const u8 *primary_addr = primary_if->net_dev->dev_addr; + u16 backbone_crc; + bool is_own; + void *hdr; + int ret = -EINVAL; + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_BLA_CLAIM); + if (!hdr) { + ret = -ENOBUFS; + goto out; + } + + genl_dump_check_consistent(cb, hdr); + + is_own = batadv_compare_eth(claim->backbone_gw->orig, + primary_addr); + + spin_lock_bh(&claim->backbone_gw->crc_lock); + backbone_crc = claim->backbone_gw->crc; + spin_unlock_bh(&claim->backbone_gw->crc_lock); + + if (is_own) + if (nla_put_flag(msg, BATADV_ATTR_BLA_OWN)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put(msg, BATADV_ATTR_BLA_ADDRESS, ETH_ALEN, claim->addr) || + nla_put_u16(msg, BATADV_ATTR_BLA_VID, claim->vid) || + nla_put(msg, BATADV_ATTR_BLA_BACKBONE, ETH_ALEN, + claim->backbone_gw->orig) || + nla_put_u16(msg, BATADV_ATTR_BLA_CRC, + backbone_crc)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + genlmsg_end(msg, hdr); + ret = 0; + +out: + return ret; +} + +/** + * batadv_bla_claim_dump_bucket() - dump one bucket of the claim table + * to a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @primary_if: primary interface + * @hash: hash to dump + * @bucket: bucket index to dump + * @idx_skip: How many entries to skip + * + * Return: always 0. 
+ */ +static int +batadv_bla_claim_dump_bucket(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hard_iface *primary_if, + struct batadv_hashtable *hash, unsigned int bucket, + int *idx_skip) +{ + struct batadv_bla_claim *claim; + int idx = 0; + int ret = 0; + + spin_lock_bh(&hash->list_locks[bucket]); + cb->seq = atomic_read(&hash->generation) << 1 | 1; + + hlist_for_each_entry(claim, &hash->table[bucket], hash_entry) { + if (idx++ < *idx_skip) + continue; + + ret = batadv_bla_claim_dump_entry(msg, portid, cb, + primary_if, claim); + if (ret) { + *idx_skip = idx - 1; + goto unlock; + } + } + + *idx_skip = 0; +unlock: + spin_unlock_bh(&hash->list_locks[bucket]); + return ret; +} + +/** + * batadv_bla_claim_dump() - dump claim table to a netlink socket + * @msg: buffer for the message + * @cb: callback structure containing arguments + * + * Return: message length. + */ +int batadv_bla_claim_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct batadv_hard_iface *primary_if = NULL; + int portid = NETLINK_CB(cb->skb).portid; + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_hashtable *hash; + struct batadv_priv *bat_priv; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int ifindex; + int ret = 0; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + hash = bat_priv->bla.claim_hash; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + while (bucket < hash->size) { + if (batadv_bla_claim_dump_bucket(msg, portid, cb, primary_if, + hash, bucket, &idx)) + break; + bucket++; + } + + cb->args[0] = bucket; + cb->args[1] = idx; + + ret = msg->len; + +out: + batadv_hardif_put(primary_if); + + dev_put(soft_iface); + + return ret; +} + +/** + * batadv_bla_backbone_dump_entry() - dump one entry of the backbone table to a + * netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @primary_if: primary interface + * @backbone_gw: entry to dump + * + * Return: 0 or error code. 
+ */ +static int +batadv_bla_backbone_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hard_iface *primary_if, + struct batadv_bla_backbone_gw *backbone_gw) +{ + const u8 *primary_addr = primary_if->net_dev->dev_addr; + u16 backbone_crc; + bool is_own; + int msecs; + void *hdr; + int ret = -EINVAL; + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_BLA_BACKBONE); + if (!hdr) { + ret = -ENOBUFS; + goto out; + } + + genl_dump_check_consistent(cb, hdr); + + is_own = batadv_compare_eth(backbone_gw->orig, primary_addr); + + spin_lock_bh(&backbone_gw->crc_lock); + backbone_crc = backbone_gw->crc; + spin_unlock_bh(&backbone_gw->crc_lock); + + msecs = jiffies_to_msecs(jiffies - backbone_gw->lasttime); + + if (is_own) + if (nla_put_flag(msg, BATADV_ATTR_BLA_OWN)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + if (nla_put(msg, BATADV_ATTR_BLA_BACKBONE, ETH_ALEN, + backbone_gw->orig) || + nla_put_u16(msg, BATADV_ATTR_BLA_VID, backbone_gw->vid) || + nla_put_u16(msg, BATADV_ATTR_BLA_CRC, + backbone_crc) || + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, msecs)) { + genlmsg_cancel(msg, hdr); + goto out; + } + + genlmsg_end(msg, hdr); + ret = 0; + +out: + return ret; +} + +/** + * batadv_bla_backbone_dump_bucket() - dump one bucket of the backbone table to + * a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @primary_if: primary interface + * @hash: hash to dump + * @bucket: bucket index to dump + * @idx_skip: How many entries to skip + * + * Return: always 0. + */ +static int +batadv_bla_backbone_dump_bucket(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hard_iface *primary_if, + struct batadv_hashtable *hash, + unsigned int bucket, int *idx_skip) +{ + struct batadv_bla_backbone_gw *backbone_gw; + int idx = 0; + int ret = 0; + + spin_lock_bh(&hash->list_locks[bucket]); + cb->seq = atomic_read(&hash->generation) << 1 | 1; + + hlist_for_each_entry(backbone_gw, &hash->table[bucket], hash_entry) { + if (idx++ < *idx_skip) + continue; + + ret = batadv_bla_backbone_dump_entry(msg, portid, cb, + primary_if, backbone_gw); + if (ret) { + *idx_skip = idx - 1; + goto unlock; + } + } + + *idx_skip = 0; +unlock: + spin_unlock_bh(&hash->list_locks[bucket]); + return ret; +} + +/** + * batadv_bla_backbone_dump() - dump backbone table to a netlink socket + * @msg: buffer for the message + * @cb: callback structure containing arguments + * + * Return: message length. 
+ */ +int batadv_bla_backbone_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct batadv_hard_iface *primary_if = NULL; + int portid = NETLINK_CB(cb->skb).portid; + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_hashtable *hash; + struct batadv_priv *bat_priv; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int ifindex; + int ret = 0; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + hash = bat_priv->bla.backbone_hash; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + while (bucket < hash->size) { + if (batadv_bla_backbone_dump_bucket(msg, portid, cb, primary_if, + hash, bucket, &idx)) + break; + bucket++; + } + + cb->args[0] = bucket; + cb->args[1] = idx; + + ret = msg->len; + +out: + batadv_hardif_put(primary_if); + + dev_put(soft_iface); + + return ret; +} + +#ifdef CONFIG_BATMAN_ADV_DAT +/** + * batadv_bla_check_claim() - check if address is claimed + * + * @bat_priv: the bat priv with all the soft interface information + * @addr: mac address of which the claim status is checked + * @vid: the VLAN ID + * + * addr is checked if this address is claimed by the local device itself. + * + * Return: true if bla is disabled or the mac is claimed by the device, + * false if the device addr is already claimed by another gateway + */ +bool batadv_bla_check_claim(struct batadv_priv *bat_priv, + u8 *addr, unsigned short vid) +{ + struct batadv_bla_claim search_claim; + struct batadv_bla_claim *claim = NULL; + struct batadv_hard_iface *primary_if = NULL; + bool ret = true; + + if (!atomic_read(&bat_priv->bridge_loop_avoidance)) + return ret; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return ret; + + /* First look if the mac address is claimed */ + ether_addr_copy(search_claim.addr, addr); + search_claim.vid = vid; + + claim = batadv_claim_hash_find(bat_priv, &search_claim); + + /* If there is a claim and we are not owner of the claim, + * return false. + */ + if (claim) { + if (!batadv_compare_eth(claim->backbone_gw->orig, + primary_if->net_dev->dev_addr)) + ret = false; + batadv_claim_put(claim); + } + + batadv_hardif_put(primary_if); + return ret; +} +#endif diff --git a/net/batman-adv/bridge_loop_avoidance.h b/net/batman-adv/bridge_loop_avoidance.h new file mode 100644 index 000000000..8673a2659 --- /dev/null +++ b/net/batman-adv/bridge_loop_avoidance.h @@ -0,0 +1,132 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_BLA_H_ +#define _NET_BATMAN_ADV_BLA_H_ + +#include "main.h" + +#include +#include +#include +#include +#include +#include + +/** + * batadv_bla_is_loopdetect_mac() - check if the mac address is from a loop + * detect frame sent by bridge loop avoidance + * @mac: mac address to check + * + * Return: true if the it looks like a loop detect frame + * (mac starts with BA:BE), false otherwise + */ +static inline bool batadv_bla_is_loopdetect_mac(const uint8_t *mac) +{ + if (mac[0] == 0xba && mac[1] == 0xbe) + return true; + + return false; +} + +#ifdef CONFIG_BATMAN_ADV_BLA +bool batadv_bla_rx(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid, int packet_type); +bool batadv_bla_tx(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid); +bool batadv_bla_is_backbone_gw(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + int hdr_size); +int batadv_bla_claim_dump(struct sk_buff *msg, struct netlink_callback *cb); +int batadv_bla_backbone_dump(struct sk_buff *msg, struct netlink_callback *cb); +bool batadv_bla_is_backbone_gw_orig(struct batadv_priv *bat_priv, u8 *orig, + unsigned short vid); +bool batadv_bla_check_bcast_duplist(struct batadv_priv *bat_priv, + struct sk_buff *skb); +void batadv_bla_update_orig_address(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + struct batadv_hard_iface *oldif); +void batadv_bla_status_update(struct net_device *net_dev); +int batadv_bla_init(struct batadv_priv *bat_priv); +void batadv_bla_free(struct batadv_priv *bat_priv); +#ifdef CONFIG_BATMAN_ADV_DAT +bool batadv_bla_check_claim(struct batadv_priv *bat_priv, u8 *addr, + unsigned short vid); +#endif +#define BATADV_BLA_CRC_INIT 0 +#else /* ifdef CONFIG_BATMAN_ADV_BLA */ + +static inline bool batadv_bla_rx(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid, + int packet_type) +{ + return false; +} + +static inline bool batadv_bla_tx(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + return false; +} + +static inline bool batadv_bla_is_backbone_gw(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + int hdr_size) +{ + return false; +} + +static inline bool batadv_bla_is_backbone_gw_orig(struct batadv_priv *bat_priv, + u8 *orig, unsigned short vid) +{ + return false; +} + +static inline bool +batadv_bla_check_bcast_duplist(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + return false; +} + +static inline void +batadv_bla_update_orig_address(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if, + struct batadv_hard_iface *oldif) +{ +} + +static inline int batadv_bla_init(struct batadv_priv *bat_priv) +{ + return 1; +} + +static inline void batadv_bla_free(struct batadv_priv *bat_priv) +{ +} + +static inline int batadv_bla_claim_dump(struct sk_buff *msg, + struct netlink_callback *cb) +{ + return -EOPNOTSUPP; +} + +static inline int batadv_bla_backbone_dump(struct sk_buff *msg, + struct netlink_callback *cb) +{ + return -EOPNOTSUPP; +} + +static inline +bool batadv_bla_check_claim(struct batadv_priv *bat_priv, u8 *addr, + unsigned short vid) +{ + return true; +} + +#endif /* ifdef CONFIG_BATMAN_ADV_BLA */ + +#endif /* ifndef _NET_BATMAN_ADV_BLA_H_ */ diff --git a/net/batman-adv/distributed-arp-table.c b/net/batman-adv/distributed-arp-table.c new file mode 100644 index 000000000..a02332520 --- /dev/null +++ b/net/batman-adv/distributed-arp-table.c @@ -0,0 +1,1831 @@ +// SPDX-License-Identifier: 
GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Antonio Quartulli + */ + +#include "distributed-arp-table.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bridge_loop_avoidance.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" +#include "send.h" +#include "soft-interface.h" +#include "translation-table.h" +#include "tvlv.h" + +enum batadv_bootpop { + BATADV_BOOTREPLY = 2, +}; + +enum batadv_boothtype { + BATADV_HTYPE_ETHERNET = 1, +}; + +enum batadv_dhcpoptioncode { + BATADV_DHCP_OPT_PAD = 0, + BATADV_DHCP_OPT_MSG_TYPE = 53, + BATADV_DHCP_OPT_END = 255, +}; + +enum batadv_dhcptype { + BATADV_DHCPACK = 5, +}; + +/* { 99, 130, 83, 99 } */ +#define BATADV_DHCP_MAGIC 1669485411 + +struct batadv_dhcp_packet { + __u8 op; + __u8 htype; + __u8 hlen; + __u8 hops; + __be32 xid; + __be16 secs; + __be16 flags; + __be32 ciaddr; + __be32 yiaddr; + __be32 siaddr; + __be32 giaddr; + __u8 chaddr[16]; + __u8 sname[64]; + __u8 file[128]; + __be32 magic; + /* __u8 options[]; */ +}; + +#define BATADV_DHCP_YIADDR_LEN sizeof(((struct batadv_dhcp_packet *)0)->yiaddr) +#define BATADV_DHCP_CHADDR_LEN sizeof(((struct batadv_dhcp_packet *)0)->chaddr) + +static void batadv_dat_purge(struct work_struct *work); + +/** + * batadv_dat_start_timer() - initialise the DAT periodic worker + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_dat_start_timer(struct batadv_priv *bat_priv) +{ + queue_delayed_work(batadv_event_workqueue, &bat_priv->dat.work, + msecs_to_jiffies(10000)); +} + +/** + * batadv_dat_entry_release() - release dat_entry from lists and queue for free + * after rcu grace period + * @ref: kref pointer of the dat_entry + */ +static void batadv_dat_entry_release(struct kref *ref) +{ + struct batadv_dat_entry *dat_entry; + + dat_entry = container_of(ref, struct batadv_dat_entry, refcount); + + kfree_rcu(dat_entry, rcu); +} + +/** + * batadv_dat_entry_put() - decrement the dat_entry refcounter and possibly + * release it + * @dat_entry: dat_entry to be free'd + */ +static void batadv_dat_entry_put(struct batadv_dat_entry *dat_entry) +{ + if (!dat_entry) + return; + + kref_put(&dat_entry->refcount, batadv_dat_entry_release); +} + +/** + * batadv_dat_to_purge() - check whether a dat_entry has to be purged or not + * @dat_entry: the entry to check + * + * Return: true if the entry has to be purged now, false otherwise. + */ +static bool batadv_dat_to_purge(struct batadv_dat_entry *dat_entry) +{ + return batadv_has_timed_out(dat_entry->last_update, + BATADV_DAT_ENTRY_TIMEOUT); +} + +/** + * __batadv_dat_purge() - delete entries from the DAT local storage + * @bat_priv: the bat priv with all the soft interface information + * @to_purge: function in charge to decide whether an entry has to be purged or + * not. This function takes the dat_entry as argument and has to + * returns a boolean value: true is the entry has to be deleted, + * false otherwise + * + * Loops over each entry in the DAT local storage and deletes it if and only if + * the to_purge function passed as argument returns true. 
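+ * Entries are unlinked under the per bucket list lock and freed only after
+ * an RCU grace period via batadv_dat_entry_put().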
+ */ +static void __batadv_dat_purge(struct batadv_priv *bat_priv, + bool (*to_purge)(struct batadv_dat_entry *)) +{ + spinlock_t *list_lock; /* protects write access to the hash lists */ + struct batadv_dat_entry *dat_entry; + struct hlist_node *node_tmp; + struct hlist_head *head; + u32 i; + + if (!bat_priv->dat.hash) + return; + + for (i = 0; i < bat_priv->dat.hash->size; i++) { + head = &bat_priv->dat.hash->table[i]; + list_lock = &bat_priv->dat.hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(dat_entry, node_tmp, head, + hash_entry) { + /* if a helper function has been passed as parameter, + * ask it if the entry has to be purged or not + */ + if (to_purge && !to_purge(dat_entry)) + continue; + + hlist_del_rcu(&dat_entry->hash_entry); + batadv_dat_entry_put(dat_entry); + } + spin_unlock_bh(list_lock); + } +} + +/** + * batadv_dat_purge() - periodic task that deletes old entries from the local + * DAT hash table + * @work: kernel work struct + */ +static void batadv_dat_purge(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv_dat *priv_dat; + struct batadv_priv *bat_priv; + + delayed_work = to_delayed_work(work); + priv_dat = container_of(delayed_work, struct batadv_priv_dat, work); + bat_priv = container_of(priv_dat, struct batadv_priv, dat); + + __batadv_dat_purge(bat_priv, batadv_dat_to_purge); + batadv_dat_start_timer(bat_priv); +} + +/** + * batadv_compare_dat() - comparing function used in the local DAT hash table + * @node: node in the local table + * @data2: second object to compare the node to + * + * Return: true if the two entries are the same, false otherwise. + */ +static bool batadv_compare_dat(const struct hlist_node *node, const void *data2) +{ + const void *data1 = container_of(node, struct batadv_dat_entry, + hash_entry); + + return memcmp(data1, data2, sizeof(__be32)) == 0; +} + +/** + * batadv_arp_hw_src() - extract the hw_src field from an ARP packet + * @skb: ARP packet + * @hdr_size: size of the possible header before the ARP packet + * + * Return: the value of the hw_src field in the ARP packet. + */ +static u8 *batadv_arp_hw_src(struct sk_buff *skb, int hdr_size) +{ + u8 *addr; + + addr = (u8 *)(skb->data + hdr_size); + addr += ETH_HLEN + sizeof(struct arphdr); + + return addr; +} + +/** + * batadv_arp_ip_src() - extract the ip_src field from an ARP packet + * @skb: ARP packet + * @hdr_size: size of the possible header before the ARP packet + * + * Return: the value of the ip_src field in the ARP packet. + */ +static __be32 batadv_arp_ip_src(struct sk_buff *skb, int hdr_size) +{ + return *(__force __be32 *)(batadv_arp_hw_src(skb, hdr_size) + ETH_ALEN); +} + +/** + * batadv_arp_hw_dst() - extract the hw_dst field from an ARP packet + * @skb: ARP packet + * @hdr_size: size of the possible header before the ARP packet + * + * Return: the value of the hw_dst field in the ARP packet. + */ +static u8 *batadv_arp_hw_dst(struct sk_buff *skb, int hdr_size) +{ + return batadv_arp_hw_src(skb, hdr_size) + ETH_ALEN + 4; +} + +/** + * batadv_arp_ip_dst() - extract the ip_dst field from an ARP packet + * @skb: ARP packet + * @hdr_size: size of the possible header before the ARP packet + * + * Return: the value of the ip_dst field in the ARP packet. 
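+ * (like the accessors above, this assumes the 28 byte Ethernet/IPv4 ARP
+ * layout: sender hw, sender ip, target hw, target ip after the arphdr)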
+ */ +static __be32 batadv_arp_ip_dst(struct sk_buff *skb, int hdr_size) +{ + u8 *dst = batadv_arp_hw_src(skb, hdr_size) + ETH_ALEN * 2 + 4; + + return *(__force __be32 *)dst; +} + +/** + * batadv_hash_dat() - compute the hash value for an IP address + * @data: data to hash + * @size: size of the hash table + * + * Return: the selected index in the hash table for the given data. + */ +static u32 batadv_hash_dat(const void *data, u32 size) +{ + u32 hash = 0; + const struct batadv_dat_entry *dat = data; + const unsigned char *key; + __be16 vid; + u32 i; + + key = (__force const unsigned char *)&dat->ip; + for (i = 0; i < sizeof(dat->ip); i++) { + hash += key[i]; + hash += (hash << 10); + hash ^= (hash >> 6); + } + + vid = htons(dat->vid); + key = (__force const unsigned char *)&vid; + for (i = 0; i < sizeof(dat->vid); i++) { + hash += key[i]; + hash += (hash << 10); + hash ^= (hash >> 6); + } + + hash += (hash << 3); + hash ^= (hash >> 11); + hash += (hash << 15); + + return hash % size; +} + +/** + * batadv_dat_entry_hash_find() - look for a given dat_entry in the local hash + * table + * @bat_priv: the bat priv with all the soft interface information + * @ip: search key + * @vid: VLAN identifier + * + * Return: the dat_entry if found, NULL otherwise. + */ +static struct batadv_dat_entry * +batadv_dat_entry_hash_find(struct batadv_priv *bat_priv, __be32 ip, + unsigned short vid) +{ + struct hlist_head *head; + struct batadv_dat_entry to_find, *dat_entry, *dat_entry_tmp = NULL; + struct batadv_hashtable *hash = bat_priv->dat.hash; + u32 index; + + if (!hash) + return NULL; + + to_find.ip = ip; + to_find.vid = vid; + + index = batadv_hash_dat(&to_find, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(dat_entry, head, hash_entry) { + if (dat_entry->ip != ip) + continue; + + if (!kref_get_unless_zero(&dat_entry->refcount)) + continue; + + dat_entry_tmp = dat_entry; + break; + } + rcu_read_unlock(); + + return dat_entry_tmp; +} + +/** + * batadv_dat_entry_add() - add a new dat entry or update it if already exists + * @bat_priv: the bat priv with all the soft interface information + * @ip: ipv4 to add/edit + * @mac_addr: mac address to assign to the given ipv4 + * @vid: VLAN identifier + */ +static void batadv_dat_entry_add(struct batadv_priv *bat_priv, __be32 ip, + u8 *mac_addr, unsigned short vid) +{ + struct batadv_dat_entry *dat_entry; + int hash_added; + + dat_entry = batadv_dat_entry_hash_find(bat_priv, ip, vid); + /* if this entry is already known, just update it */ + if (dat_entry) { + if (!batadv_compare_eth(dat_entry->mac_addr, mac_addr)) + ether_addr_copy(dat_entry->mac_addr, mac_addr); + dat_entry->last_update = jiffies; + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Entry updated: %pI4 %pM (vid: %d)\n", + &dat_entry->ip, dat_entry->mac_addr, + batadv_print_vid(vid)); + goto out; + } + + dat_entry = kmalloc(sizeof(*dat_entry), GFP_ATOMIC); + if (!dat_entry) + goto out; + + dat_entry->ip = ip; + dat_entry->vid = vid; + ether_addr_copy(dat_entry->mac_addr, mac_addr); + dat_entry->last_update = jiffies; + kref_init(&dat_entry->refcount); + + kref_get(&dat_entry->refcount); + hash_added = batadv_hash_add(bat_priv->dat.hash, batadv_compare_dat, + batadv_hash_dat, dat_entry, + &dat_entry->hash_entry); + + if (unlikely(hash_added != 0)) { + /* remove the reference for the hash */ + batadv_dat_entry_put(dat_entry); + goto out; + } + + batadv_dbg(BATADV_DBG_DAT, bat_priv, "New entry added: %pI4 %pM (vid: %d)\n", + &dat_entry->ip, dat_entry->mac_addr, 
batadv_print_vid(vid)); + +out: + batadv_dat_entry_put(dat_entry); +} + +#ifdef CONFIG_BATMAN_ADV_DEBUG + +/** + * batadv_dbg_arp() - print a debug message containing all the ARP packet + * details + * @bat_priv: the bat priv with all the soft interface information + * @skb: ARP packet + * @hdr_size: size of the possible header before the ARP packet + * @msg: message to print together with the debugging information + */ +static void batadv_dbg_arp(struct batadv_priv *bat_priv, struct sk_buff *skb, + int hdr_size, char *msg) +{ + struct batadv_unicast_4addr_packet *unicast_4addr_packet; + struct batadv_bcast_packet *bcast_pkt; + u8 *orig_addr; + __be32 ip_src, ip_dst; + + if (msg) + batadv_dbg(BATADV_DBG_DAT, bat_priv, "%s\n", msg); + + ip_src = batadv_arp_ip_src(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "ARP MSG = [src: %pM-%pI4 dst: %pM-%pI4]\n", + batadv_arp_hw_src(skb, hdr_size), &ip_src, + batadv_arp_hw_dst(skb, hdr_size), &ip_dst); + + if (hdr_size < sizeof(struct batadv_unicast_packet)) + return; + + unicast_4addr_packet = (struct batadv_unicast_4addr_packet *)skb->data; + + switch (unicast_4addr_packet->u.packet_type) { + case BATADV_UNICAST: + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "* encapsulated within a UNICAST packet\n"); + break; + case BATADV_UNICAST_4ADDR: + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "* encapsulated within a UNICAST_4ADDR packet (src: %pM)\n", + unicast_4addr_packet->src); + switch (unicast_4addr_packet->subtype) { + case BATADV_P_DAT_DHT_PUT: + batadv_dbg(BATADV_DBG_DAT, bat_priv, "* type: DAT_DHT_PUT\n"); + break; + case BATADV_P_DAT_DHT_GET: + batadv_dbg(BATADV_DBG_DAT, bat_priv, "* type: DAT_DHT_GET\n"); + break; + case BATADV_P_DAT_CACHE_REPLY: + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "* type: DAT_CACHE_REPLY\n"); + break; + case BATADV_P_DATA: + batadv_dbg(BATADV_DBG_DAT, bat_priv, "* type: DATA\n"); + break; + default: + batadv_dbg(BATADV_DBG_DAT, bat_priv, "* type: Unknown (%u)!\n", + unicast_4addr_packet->u.packet_type); + } + break; + case BATADV_BCAST: + bcast_pkt = (struct batadv_bcast_packet *)unicast_4addr_packet; + orig_addr = bcast_pkt->orig; + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "* encapsulated within a BCAST packet (src: %pM)\n", + orig_addr); + break; + default: + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "* encapsulated within an unknown packet type (0x%x)\n", + unicast_4addr_packet->u.packet_type); + } +} + +#else + +static void batadv_dbg_arp(struct batadv_priv *bat_priv, struct sk_buff *skb, + int hdr_size, char *msg) +{ +} + +#endif /* CONFIG_BATMAN_ADV_DEBUG */ + +/** + * batadv_is_orig_node_eligible() - check whether a node can be a DHT candidate + * @res: the array with the already selected candidates + * @select: number of already selected candidates + * @tmp_max: address of the currently evaluated node + * @max: current round max address + * @last_max: address of the last selected candidate + * @candidate: orig_node under evaluation + * @max_orig_node: last selected candidate + * + * Return: true if the node has been elected as next candidate or false + * otherwise. 
+ */ +static bool batadv_is_orig_node_eligible(struct batadv_dat_candidate *res, + int select, batadv_dat_addr_t tmp_max, + batadv_dat_addr_t max, + batadv_dat_addr_t last_max, + struct batadv_orig_node *candidate, + struct batadv_orig_node *max_orig_node) +{ + bool ret = false; + int j; + + /* check if orig node candidate is running DAT */ + if (!test_bit(BATADV_ORIG_CAPA_HAS_DAT, &candidate->capabilities)) + goto out; + + /* Check if this node has already been selected... */ + for (j = 0; j < select; j++) + if (res[j].orig_node == candidate) + break; + /* ..and possibly skip it */ + if (j < select) + goto out; + /* sanity check: has it already been selected? This should not happen */ + if (tmp_max > last_max) + goto out; + /* check if during this iteration an originator with a closer dht + * address has already been found + */ + if (tmp_max < max) + goto out; + /* this is an hash collision with the temporary selected node. Choose + * the one with the lowest address + */ + if (tmp_max == max && max_orig_node && + batadv_compare_eth(candidate->orig, max_orig_node->orig)) + goto out; + + ret = true; +out: + return ret; +} + +/** + * batadv_choose_next_candidate() - select the next DHT candidate + * @bat_priv: the bat priv with all the soft interface information + * @cands: candidates array + * @select: number of candidates already present in the array + * @ip_key: key to look up in the DHT + * @last_max: pointer where the address of the selected candidate will be saved + */ +static void batadv_choose_next_candidate(struct batadv_priv *bat_priv, + struct batadv_dat_candidate *cands, + int select, batadv_dat_addr_t ip_key, + batadv_dat_addr_t *last_max) +{ + batadv_dat_addr_t max = 0; + batadv_dat_addr_t tmp_max = 0; + struct batadv_orig_node *orig_node, *max_orig_node = NULL; + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + int i; + + /* if no node is eligible as candidate, leave the candidate type as + * NOT_FOUND + */ + cands[select].type = BATADV_DAT_CANDIDATE_NOT_FOUND; + + /* iterate over the originator list and find the node with the closest + * dat_address which has not been selected yet + */ + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) { + /* the dht space is a ring using unsigned addresses */ + tmp_max = BATADV_DAT_ADDR_MAX - orig_node->dat_addr + + ip_key; + + if (!batadv_is_orig_node_eligible(cands, select, + tmp_max, max, + *last_max, orig_node, + max_orig_node)) + continue; + + if (!kref_get_unless_zero(&orig_node->refcount)) + continue; + + max = tmp_max; + batadv_orig_node_put(max_orig_node); + max_orig_node = orig_node; + } + rcu_read_unlock(); + } + if (max_orig_node) { + cands[select].type = BATADV_DAT_CANDIDATE_ORIG; + cands[select].orig_node = max_orig_node; + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "dat_select_candidates() %d: selected %pM addr=%u dist=%u\n", + select, max_orig_node->orig, max_orig_node->dat_addr, + max); + } + *last_max = max; +} + +/** + * batadv_dat_select_candidates() - select the nodes which the DHT message has + * to be sent to + * @bat_priv: the bat priv with all the soft interface information + * @ip_dst: ipv4 to look up in the DHT + * @vid: VLAN identifier + * + * An originator O is selected if and only if its DHT_ID value is one of three + * closest values (from the LEFT, with wrap around if needed) then the hash + * value of the key. ip_dst is the key. 
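+ *
+ * Worked example (hypothetical numbers, assuming a 16 bit
+ * batadv_dat_addr_t): with hash(ip_dst) = 5 and originators whose
+ * DHT_IDs are 3, 6, 9 and 200, the ring distance computed in
+ * batadv_choose_next_candidate(), BATADV_DAT_ADDR_MAX - DHT_ID + key,
+ * is 1, 65534, 65531 and 65340 respectively, so the originators with
+ * IDs 6, 9 and 200 are selected.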
+ * + * Return: the candidate array of size BATADV_DAT_CANDIDATE_NUM. + */ +static struct batadv_dat_candidate * +batadv_dat_select_candidates(struct batadv_priv *bat_priv, __be32 ip_dst, + unsigned short vid) +{ + int select; + batadv_dat_addr_t last_max = BATADV_DAT_ADDR_MAX, ip_key; + struct batadv_dat_candidate *res; + struct batadv_dat_entry dat; + + if (!bat_priv->orig_hash) + return NULL; + + res = kmalloc_array(BATADV_DAT_CANDIDATES_NUM, sizeof(*res), + GFP_ATOMIC); + if (!res) + return NULL; + + dat.ip = ip_dst; + dat.vid = vid; + ip_key = (batadv_dat_addr_t)batadv_hash_dat(&dat, + BATADV_DAT_ADDR_MAX); + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "%s(): IP=%pI4 hash(IP)=%u\n", __func__, &ip_dst, + ip_key); + + for (select = 0; select < BATADV_DAT_CANDIDATES_NUM; select++) + batadv_choose_next_candidate(bat_priv, res, select, ip_key, + &last_max); + + return res; +} + +/** + * batadv_dat_forward_data() - copy and send payload to the selected candidates + * @bat_priv: the bat priv with all the soft interface information + * @skb: payload to send + * @ip: the DHT key + * @vid: VLAN identifier + * @packet_subtype: unicast4addr packet subtype to use + * + * This function copies the skb with pskb_copy() and is sent as a unicast packet + * to each of the selected candidates. + * + * Return: true if the packet is sent to at least one candidate, false + * otherwise. + */ +static bool batadv_dat_forward_data(struct batadv_priv *bat_priv, + struct sk_buff *skb, __be32 ip, + unsigned short vid, int packet_subtype) +{ + int i; + bool ret = false; + int send_status; + struct batadv_neigh_node *neigh_node = NULL; + struct sk_buff *tmp_skb; + struct batadv_dat_candidate *cand; + + cand = batadv_dat_select_candidates(bat_priv, ip, vid); + if (!cand) + goto out; + + batadv_dbg(BATADV_DBG_DAT, bat_priv, "DHT_SEND for %pI4\n", &ip); + + for (i = 0; i < BATADV_DAT_CANDIDATES_NUM; i++) { + if (cand[i].type == BATADV_DAT_CANDIDATE_NOT_FOUND) + continue; + + neigh_node = batadv_orig_router_get(cand[i].orig_node, + BATADV_IF_DEFAULT); + if (!neigh_node) + goto free_orig; + + tmp_skb = pskb_copy_for_clone(skb, GFP_ATOMIC); + if (!batadv_send_skb_prepare_unicast_4addr(bat_priv, tmp_skb, + cand[i].orig_node, + packet_subtype)) { + kfree_skb(tmp_skb); + goto free_neigh; + } + + send_status = batadv_send_unicast_skb(tmp_skb, neigh_node); + if (send_status == NET_XMIT_SUCCESS) { + /* count the sent packet */ + switch (packet_subtype) { + case BATADV_P_DAT_DHT_GET: + batadv_inc_counter(bat_priv, + BATADV_CNT_DAT_GET_TX); + break; + case BATADV_P_DAT_DHT_PUT: + batadv_inc_counter(bat_priv, + BATADV_CNT_DAT_PUT_TX); + break; + } + + /* packet sent to a candidate: return true */ + ret = true; + } +free_neigh: + batadv_neigh_node_put(neigh_node); +free_orig: + batadv_orig_node_put(cand[i].orig_node); + } + +out: + kfree(cand); + return ret; +} + +/** + * batadv_dat_tvlv_container_update() - update the dat tvlv container after dat + * setting change + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_dat_tvlv_container_update(struct batadv_priv *bat_priv) +{ + char dat_mode; + + dat_mode = atomic_read(&bat_priv->distributed_arp_table); + + switch (dat_mode) { + case 0: + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_DAT, 1); + break; + case 1: + batadv_tvlv_container_register(bat_priv, BATADV_TVLV_DAT, 1, + NULL, 0); + break; + } +} + +/** + * batadv_dat_status_update() - update the dat tvlv container after dat + * setting change + * @net_dev: the soft interface net 
device + */ +void batadv_dat_status_update(struct net_device *net_dev) +{ + struct batadv_priv *bat_priv = netdev_priv(net_dev); + + batadv_dat_tvlv_container_update(bat_priv); +} + +/** + * batadv_dat_tvlv_ogm_handler_v1() - process incoming dat tvlv container + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node of the ogm + * @flags: flags indicating the tvlv state (see batadv_tvlv_handler_flags) + * @tvlv_value: tvlv buffer containing the gateway data + * @tvlv_value_len: tvlv buffer length + */ +static void batadv_dat_tvlv_ogm_handler_v1(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, u16 tvlv_value_len) +{ + if (flags & BATADV_TVLV_HANDLER_OGM_CIFNOTFND) + clear_bit(BATADV_ORIG_CAPA_HAS_DAT, &orig->capabilities); + else + set_bit(BATADV_ORIG_CAPA_HAS_DAT, &orig->capabilities); +} + +/** + * batadv_dat_hash_free() - free the local DAT hash table + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_dat_hash_free(struct batadv_priv *bat_priv) +{ + if (!bat_priv->dat.hash) + return; + + __batadv_dat_purge(bat_priv, NULL); + + batadv_hash_destroy(bat_priv->dat.hash); + + bat_priv->dat.hash = NULL; +} + +/** + * batadv_dat_init() - initialise the DAT internals + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 in case of success, a negative error code otherwise + */ +int batadv_dat_init(struct batadv_priv *bat_priv) +{ + if (bat_priv->dat.hash) + return 0; + + bat_priv->dat.hash = batadv_hash_new(1024); + + if (!bat_priv->dat.hash) + return -ENOMEM; + + INIT_DELAYED_WORK(&bat_priv->dat.work, batadv_dat_purge); + batadv_dat_start_timer(bat_priv); + + batadv_tvlv_handler_register(bat_priv, batadv_dat_tvlv_ogm_handler_v1, + NULL, BATADV_TVLV_DAT, 1, + BATADV_TVLV_HANDLER_OGM_CIFNOTFND); + batadv_dat_tvlv_container_update(bat_priv); + return 0; +} + +/** + * batadv_dat_free() - free the DAT internals + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_dat_free(struct batadv_priv *bat_priv) +{ + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_DAT, 1); + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_DAT, 1); + + cancel_delayed_work_sync(&bat_priv->dat.work); + + batadv_dat_hash_free(bat_priv); +} + +/** + * batadv_dat_cache_dump_entry() - dump one entry of the DAT cache table to a + * netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @dat_entry: entry to dump + * + * Return: 0 or error code. 
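+ *
+ * The resulting dump backs BATADV_CMD_GET_DAT_CACHE and is what
+ * userspace tools such as batctl (for instance "batctl dat_cache")
+ * display.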
+ */ +static int +batadv_dat_cache_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_dat_entry *dat_entry) +{ + int msecs; + void *hdr; + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_DAT_CACHE); + if (!hdr) + return -ENOBUFS; + + genl_dump_check_consistent(cb, hdr); + + msecs = jiffies_to_msecs(jiffies - dat_entry->last_update); + + if (nla_put_in_addr(msg, BATADV_ATTR_DAT_CACHE_IP4ADDRESS, + dat_entry->ip) || + nla_put(msg, BATADV_ATTR_DAT_CACHE_HWADDRESS, ETH_ALEN, + dat_entry->mac_addr) || + nla_put_u16(msg, BATADV_ATTR_DAT_CACHE_VID, dat_entry->vid) || + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, msecs)) { + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; + } + + genlmsg_end(msg, hdr); + return 0; +} + +/** + * batadv_dat_cache_dump_bucket() - dump one bucket of the DAT cache table to + * a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @hash: hash to dump + * @bucket: bucket index to dump + * @idx_skip: How many entries to skip + * + * Return: 0 or error code. + */ +static int +batadv_dat_cache_dump_bucket(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hashtable *hash, unsigned int bucket, + int *idx_skip) +{ + struct batadv_dat_entry *dat_entry; + int idx = 0; + + spin_lock_bh(&hash->list_locks[bucket]); + cb->seq = atomic_read(&hash->generation) << 1 | 1; + + hlist_for_each_entry(dat_entry, &hash->table[bucket], hash_entry) { + if (idx < *idx_skip) + goto skip; + + if (batadv_dat_cache_dump_entry(msg, portid, cb, dat_entry)) { + spin_unlock_bh(&hash->list_locks[bucket]); + *idx_skip = idx; + + return -EMSGSIZE; + } + +skip: + idx++; + } + spin_unlock_bh(&hash->list_locks[bucket]); + + return 0; +} + +/** + * batadv_dat_cache_dump() - dump DAT cache table to a netlink socket + * @msg: buffer for the message + * @cb: callback structure containing arguments + * + * Return: message length. + */ +int batadv_dat_cache_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct batadv_hard_iface *primary_if = NULL; + int portid = NETLINK_CB(cb->skb).portid; + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_hashtable *hash; + struct batadv_priv *bat_priv; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int ifindex; + int ret = 0; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + hash = bat_priv->dat.hash; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + while (bucket < hash->size) { + if (batadv_dat_cache_dump_bucket(msg, portid, cb, hash, bucket, + &idx)) + break; + + bucket++; + idx = 0; + } + + cb->args[0] = bucket; + cb->args[1] = idx; + + ret = msg->len; + +out: + batadv_hardif_put(primary_if); + + dev_put(soft_iface); + + return ret; +} + +/** + * batadv_arp_get_type() - parse an ARP packet and gets the type + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to analyse + * @hdr_size: size of the possible header before the ARP packet in the skb + * + * Return: the ARP type if the skb contains a valid ARP packet, 0 otherwise. 
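+ * The type is the ARP operation (ar_op) in host byte order, e.g.
+ * ARPOP_REQUEST (1) or ARPOP_REPLY (2).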
+ */ +static u16 batadv_arp_get_type(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + struct arphdr *arphdr; + struct ethhdr *ethhdr; + __be32 ip_src, ip_dst; + u8 *hw_src, *hw_dst; + u16 type = 0; + + /* pull the ethernet header */ + if (unlikely(!pskb_may_pull(skb, hdr_size + ETH_HLEN))) + goto out; + + ethhdr = (struct ethhdr *)(skb->data + hdr_size); + + if (ethhdr->h_proto != htons(ETH_P_ARP)) + goto out; + + /* pull the ARP payload */ + if (unlikely(!pskb_may_pull(skb, hdr_size + ETH_HLEN + + arp_hdr_len(skb->dev)))) + goto out; + + arphdr = (struct arphdr *)(skb->data + hdr_size + ETH_HLEN); + + /* check whether the ARP packet carries a valid IP information */ + if (arphdr->ar_hrd != htons(ARPHRD_ETHER)) + goto out; + + if (arphdr->ar_pro != htons(ETH_P_IP)) + goto out; + + if (arphdr->ar_hln != ETH_ALEN) + goto out; + + if (arphdr->ar_pln != 4) + goto out; + + /* Check for bad reply/request. If the ARP message is not sane, DAT + * will simply ignore it + */ + ip_src = batadv_arp_ip_src(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + if (ipv4_is_loopback(ip_src) || ipv4_is_multicast(ip_src) || + ipv4_is_loopback(ip_dst) || ipv4_is_multicast(ip_dst) || + ipv4_is_zeronet(ip_src) || ipv4_is_lbcast(ip_src) || + ipv4_is_zeronet(ip_dst) || ipv4_is_lbcast(ip_dst)) + goto out; + + hw_src = batadv_arp_hw_src(skb, hdr_size); + if (is_zero_ether_addr(hw_src) || is_multicast_ether_addr(hw_src)) + goto out; + + /* don't care about the destination MAC address in ARP requests */ + if (arphdr->ar_op != htons(ARPOP_REQUEST)) { + hw_dst = batadv_arp_hw_dst(skb, hdr_size); + if (is_zero_ether_addr(hw_dst) || + is_multicast_ether_addr(hw_dst)) + goto out; + } + + type = ntohs(arphdr->ar_op); +out: + return type; +} + +/** + * batadv_dat_get_vid() - extract the VLAN identifier from skb if any + * @skb: the buffer containing the packet to extract the VID from + * @hdr_size: the size of the batman-adv header encapsulating the packet + * + * Return: If the packet embedded in the skb is vlan tagged this function + * returns the VID with the BATADV_VLAN_HAS_TAG flag. Otherwise BATADV_NO_FLAGS + * is returned. + */ +static unsigned short batadv_dat_get_vid(struct sk_buff *skb, int *hdr_size) +{ + unsigned short vid; + + vid = batadv_get_vid(skb, *hdr_size); + + /* ARP parsing functions jump forward of hdr_size + ETH_HLEN. + * If the header contained in the packet is a VLAN one (which is longer) + * hdr_size is updated so that the functions will still skip the + * correct amount of bytes. + */ + if (vid & BATADV_VLAN_HAS_TAG) + *hdr_size += VLAN_HLEN; + + return vid; +} + +/** + * batadv_dat_arp_create_reply() - create an ARP Reply + * @bat_priv: the bat priv with all the soft interface information + * @ip_src: ARP sender IP + * @ip_dst: ARP target IP + * @hw_src: Ethernet source and ARP sender MAC + * @hw_dst: Ethernet destination and ARP target MAC + * @vid: VLAN identifier (optional, set to zero otherwise) + * + * Creates an ARP Reply from the given values, optionally encapsulated in a + * VLAN header. + * + * Return: An skb containing an ARP Reply. 
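+ * NULL is returned if allocating the reply (or inserting the VLAN tag)
+ * fails.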
+ */ +static struct sk_buff * +batadv_dat_arp_create_reply(struct batadv_priv *bat_priv, __be32 ip_src, + __be32 ip_dst, u8 *hw_src, u8 *hw_dst, + unsigned short vid) +{ + struct sk_buff *skb; + + skb = arp_create(ARPOP_REPLY, ETH_P_ARP, ip_dst, bat_priv->soft_iface, + ip_src, hw_dst, hw_src, hw_dst); + if (!skb) + return NULL; + + skb_reset_mac_header(skb); + + if (vid & BATADV_VLAN_HAS_TAG) + skb = vlan_insert_tag(skb, htons(ETH_P_8021Q), + vid & VLAN_VID_MASK); + + return skb; +} + +/** + * batadv_dat_snoop_outgoing_arp_request() - snoop the ARP request and try to + * answer using DAT + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to check + * + * Return: true if the message has been sent to the dht candidates, false + * otherwise. In case of a positive return value the message has to be enqueued + * to permit the fallback. + */ +bool batadv_dat_snoop_outgoing_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + u16 type = 0; + __be32 ip_dst, ip_src; + u8 *hw_src; + bool ret = false; + struct batadv_dat_entry *dat_entry = NULL; + struct sk_buff *skb_new; + struct net_device *soft_iface = bat_priv->soft_iface; + int hdr_size = 0; + unsigned short vid; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + goto out; + + vid = batadv_dat_get_vid(skb, &hdr_size); + + type = batadv_arp_get_type(bat_priv, skb, hdr_size); + /* If the node gets an ARP_REQUEST it has to send a DHT_GET unicast + * message to the selected DHT candidates + */ + if (type != ARPOP_REQUEST) + goto out; + + batadv_dbg_arp(bat_priv, skb, hdr_size, "Parsing outgoing ARP REQUEST"); + + ip_src = batadv_arp_ip_src(skb, hdr_size); + hw_src = batadv_arp_hw_src(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + + dat_entry = batadv_dat_entry_hash_find(bat_priv, ip_dst, vid); + if (dat_entry) { + /* If the ARP request is destined for a local client the local + * client will answer itself. DAT would only generate a + * duplicate packet. + * + * Moreover, if the soft-interface is enslaved into a bridge, an + * additional DAT answer may trigger kernel warnings about + * a packet coming from the wrong port. + */ + if (batadv_is_my_client(bat_priv, dat_entry->mac_addr, vid)) { + ret = true; + goto out; + } + + /* If BLA is enabled, only send ARP replies if we have claimed + * the destination for the ARP request or if no one else of + * the backbone gws belonging to our backbone has claimed the + * destination. + */ + if (!batadv_bla_check_claim(bat_priv, + dat_entry->mac_addr, vid)) { + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Device %pM claimed by another backbone gw. 
Don't send ARP reply!", + dat_entry->mac_addr); + ret = true; + goto out; + } + + skb_new = batadv_dat_arp_create_reply(bat_priv, ip_dst, ip_src, + dat_entry->mac_addr, + hw_src, vid); + if (!skb_new) + goto out; + + skb_new->protocol = eth_type_trans(skb_new, soft_iface); + + batadv_inc_counter(bat_priv, BATADV_CNT_RX); + batadv_add_counter(bat_priv, BATADV_CNT_RX_BYTES, + skb->len + ETH_HLEN + hdr_size); + + netif_rx(skb_new); + batadv_dbg(BATADV_DBG_DAT, bat_priv, "ARP request replied locally\n"); + ret = true; + } else { + /* Send the request to the DHT */ + ret = batadv_dat_forward_data(bat_priv, skb, ip_dst, vid, + BATADV_P_DAT_DHT_GET); + } +out: + batadv_dat_entry_put(dat_entry); + return ret; +} + +/** + * batadv_dat_snoop_incoming_arp_request() - snoop the ARP request and try to + * answer using the local DAT storage + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to check + * @hdr_size: size of the encapsulation header + * + * Return: true if the request has been answered, false otherwise. + */ +bool batadv_dat_snoop_incoming_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + u16 type; + __be32 ip_src, ip_dst; + u8 *hw_src; + struct sk_buff *skb_new; + struct batadv_dat_entry *dat_entry = NULL; + bool ret = false; + unsigned short vid; + int err; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + goto out; + + vid = batadv_dat_get_vid(skb, &hdr_size); + + type = batadv_arp_get_type(bat_priv, skb, hdr_size); + if (type != ARPOP_REQUEST) + goto out; + + hw_src = batadv_arp_hw_src(skb, hdr_size); + ip_src = batadv_arp_ip_src(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + + batadv_dbg_arp(bat_priv, skb, hdr_size, "Parsing incoming ARP REQUEST"); + + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + + dat_entry = batadv_dat_entry_hash_find(bat_priv, ip_dst, vid); + if (!dat_entry) + goto out; + + skb_new = batadv_dat_arp_create_reply(bat_priv, ip_dst, ip_src, + dat_entry->mac_addr, hw_src, vid); + if (!skb_new) + goto out; + + /* To preserve backwards compatibility, the node has choose the outgoing + * format based on the incoming request packet type. The assumption is + * that a node not using the 4addr packet format doesn't support it. 
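+ * In practice: a request that arrived as a 4addr unicast is answered
+ * with a BATADV_P_DAT_CACHE_REPLY 4addr packet, anything else with a
+ * plain unicast.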
+ */ + if (hdr_size == sizeof(struct batadv_unicast_4addr_packet)) + err = batadv_send_skb_via_tt_4addr(bat_priv, skb_new, + BATADV_P_DAT_CACHE_REPLY, + NULL, vid); + else + err = batadv_send_skb_via_tt(bat_priv, skb_new, NULL, vid); + + if (err != NET_XMIT_DROP) { + batadv_inc_counter(bat_priv, BATADV_CNT_DAT_CACHED_REPLY_TX); + ret = true; + } +out: + batadv_dat_entry_put(dat_entry); + if (ret) + kfree_skb(skb); + return ret; +} + +/** + * batadv_dat_snoop_outgoing_arp_reply() - snoop the ARP reply and fill the DHT + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to check + */ +void batadv_dat_snoop_outgoing_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + u16 type; + __be32 ip_src, ip_dst; + u8 *hw_src, *hw_dst; + int hdr_size = 0; + unsigned short vid; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + return; + + vid = batadv_dat_get_vid(skb, &hdr_size); + + type = batadv_arp_get_type(bat_priv, skb, hdr_size); + if (type != ARPOP_REPLY) + return; + + batadv_dbg_arp(bat_priv, skb, hdr_size, "Parsing outgoing ARP REPLY"); + + hw_src = batadv_arp_hw_src(skb, hdr_size); + ip_src = batadv_arp_ip_src(skb, hdr_size); + hw_dst = batadv_arp_hw_dst(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + batadv_dat_entry_add(bat_priv, ip_dst, hw_dst, vid); + + /* Send the ARP reply to the candidates for both the IP addresses that + * the node obtained from the ARP reply + */ + batadv_dat_forward_data(bat_priv, skb, ip_src, vid, + BATADV_P_DAT_DHT_PUT); + batadv_dat_forward_data(bat_priv, skb, ip_dst, vid, + BATADV_P_DAT_DHT_PUT); +} + +/** + * batadv_dat_snoop_incoming_arp_reply() - snoop the ARP reply and fill the + * local DAT storage only + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to check + * @hdr_size: size of the encapsulation header + * + * Return: true if the packet was snooped and consumed by DAT. False if the + * packet has to be delivered to the interface + */ +bool batadv_dat_snoop_incoming_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + struct batadv_dat_entry *dat_entry = NULL; + u16 type; + __be32 ip_src, ip_dst; + u8 *hw_src, *hw_dst; + bool dropped = false; + unsigned short vid; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + goto out; + + vid = batadv_dat_get_vid(skb, &hdr_size); + + type = batadv_arp_get_type(bat_priv, skb, hdr_size); + if (type != ARPOP_REPLY) + goto out; + + batadv_dbg_arp(bat_priv, skb, hdr_size, "Parsing incoming ARP REPLY"); + + hw_src = batadv_arp_hw_src(skb, hdr_size); + ip_src = batadv_arp_ip_src(skb, hdr_size); + hw_dst = batadv_arp_hw_dst(skb, hdr_size); + ip_dst = batadv_arp_ip_dst(skb, hdr_size); + + /* If ip_dst is already in cache and has the right mac address, + * drop this frame if this ARP reply is destined for us because it's + * most probably an ARP reply generated by another node of the DHT. + * We have most probably received already a reply earlier. Delivering + * this frame would lead to doubled receive of an ARP reply. 
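+ * The addresses carried by the reply are still added to the local
+ * cache below before the duplicate is dropped.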
+ */ + dat_entry = batadv_dat_entry_hash_find(bat_priv, ip_src, vid); + if (dat_entry && batadv_compare_eth(hw_src, dat_entry->mac_addr)) { + batadv_dbg(BATADV_DBG_DAT, bat_priv, "Doubled ARP reply removed: ARP MSG = [src: %pM-%pI4 dst: %pM-%pI4]; dat_entry: %pM-%pI4\n", + hw_src, &ip_src, hw_dst, &ip_dst, + dat_entry->mac_addr, &dat_entry->ip); + dropped = true; + } + + /* Update our internal cache with both the IP addresses the node got + * within the ARP reply + */ + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + batadv_dat_entry_add(bat_priv, ip_dst, hw_dst, vid); + + if (dropped) + goto out; + + /* If BLA is enabled, only forward ARP replies if we have claimed the + * source of the ARP reply or if no one else of the same backbone has + * already claimed that client. This prevents that different gateways + * to the same backbone all forward the ARP reply leading to multiple + * replies in the backbone. + */ + if (!batadv_bla_check_claim(bat_priv, hw_src, vid)) { + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Device %pM claimed by another backbone gw. Drop ARP reply.\n", + hw_src); + dropped = true; + goto out; + } + + /* if this REPLY is directed to a client of mine, let's deliver the + * packet to the interface + */ + dropped = !batadv_is_my_client(bat_priv, hw_dst, vid); + + /* if this REPLY is sent on behalf of a client of mine, let's drop the + * packet because the client will reply by itself + */ + dropped |= batadv_is_my_client(bat_priv, hw_src, vid); +out: + if (dropped) + kfree_skb(skb); + batadv_dat_entry_put(dat_entry); + /* if dropped == false -> deliver to the interface */ + return dropped; +} + +/** + * batadv_dat_check_dhcp_ipudp() - check skb for IP+UDP headers valid for DHCP + * @skb: the packet to check + * @ip_src: a buffer to store the IPv4 source address in + * + * Checks whether the given skb has an IP and UDP header valid for a DHCP + * message from a DHCP server. And if so, stores the IPv4 source address in + * the provided buffer. + * + * Return: True if valid, false otherwise. + */ +static bool +batadv_dat_check_dhcp_ipudp(struct sk_buff *skb, __be32 *ip_src) +{ + unsigned int offset = skb_network_offset(skb); + struct udphdr *udphdr, _udphdr; + struct iphdr *iphdr, _iphdr; + + iphdr = skb_header_pointer(skb, offset, sizeof(_iphdr), &_iphdr); + if (!iphdr || iphdr->version != 4 || iphdr->ihl * 4 < sizeof(_iphdr)) + return false; + + if (iphdr->protocol != IPPROTO_UDP) + return false; + + offset += iphdr->ihl * 4; + skb_set_transport_header(skb, offset); + + udphdr = skb_header_pointer(skb, offset, sizeof(_udphdr), &_udphdr); + if (!udphdr || udphdr->source != htons(67)) + return false; + + *ip_src = get_unaligned(&iphdr->saddr); + + return true; +} + +/** + * batadv_dat_check_dhcp() - examine packet for valid DHCP message + * @skb: the packet to check + * @proto: ethernet protocol hint (behind a potential vlan) + * @ip_src: a buffer to store the IPv4 source address in + * + * Checks whether the given skb is a valid DHCP packet. And if so, stores the + * IPv4 source address in the provided buffer. + * + * Caller needs to ensure that the skb network header is set correctly. + * + * Return: If skb is a valid DHCP packet, then returns its op code + * (e.g. BOOTREPLY vs. BOOTREQUEST). Otherwise returns -EINVAL. 
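+ *
+ * A packet only qualifies if it is IPv4/UDP sent from the DHCP server
+ * port (67), carries an Ethernet hardware type/length in the fixed
+ * DHCP header and contains the DHCP magic cookie at the expected
+ * offset.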
+ */ +static int +batadv_dat_check_dhcp(struct sk_buff *skb, __be16 proto, __be32 *ip_src) +{ + __be32 *magic, _magic; + unsigned int offset; + struct { + __u8 op; + __u8 htype; + __u8 hlen; + __u8 hops; + } *dhcp_h, _dhcp_h; + + if (proto != htons(ETH_P_IP)) + return -EINVAL; + + if (!batadv_dat_check_dhcp_ipudp(skb, ip_src)) + return -EINVAL; + + offset = skb_transport_offset(skb) + sizeof(struct udphdr); + if (skb->len < offset + sizeof(struct batadv_dhcp_packet)) + return -EINVAL; + + dhcp_h = skb_header_pointer(skb, offset, sizeof(_dhcp_h), &_dhcp_h); + if (!dhcp_h || dhcp_h->htype != BATADV_HTYPE_ETHERNET || + dhcp_h->hlen != ETH_ALEN) + return -EINVAL; + + offset += offsetof(struct batadv_dhcp_packet, magic); + + magic = skb_header_pointer(skb, offset, sizeof(_magic), &_magic); + if (!magic || get_unaligned(magic) != htonl(BATADV_DHCP_MAGIC)) + return -EINVAL; + + return dhcp_h->op; +} + +/** + * batadv_dat_get_dhcp_message_type() - get message type of a DHCP packet + * @skb: the DHCP packet to parse + * + * Iterates over the DHCP options of the given DHCP packet to find a + * DHCP Message Type option and parse it. + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. + * + * Return: The found DHCP message type value, if found. -EINVAL otherwise. + */ +static int batadv_dat_get_dhcp_message_type(struct sk_buff *skb) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + u8 *type, _type; + struct { + u8 type; + u8 len; + } *tl, _tl; + + offset += sizeof(struct batadv_dhcp_packet); + + while ((tl = skb_header_pointer(skb, offset, sizeof(_tl), &_tl))) { + if (tl->type == BATADV_DHCP_OPT_MSG_TYPE) + break; + + if (tl->type == BATADV_DHCP_OPT_END) + break; + + if (tl->type == BATADV_DHCP_OPT_PAD) + offset++; + else + offset += tl->len + sizeof(_tl); + } + + /* Option Overload Code not supported */ + if (!tl || tl->type != BATADV_DHCP_OPT_MSG_TYPE || + tl->len != sizeof(_type)) + return -EINVAL; + + offset += sizeof(_tl); + + type = skb_header_pointer(skb, offset, sizeof(_type), &_type); + if (!type) + return -EINVAL; + + return *type; +} + +/** + * batadv_dat_dhcp_get_yiaddr() - get yiaddr from a DHCP packet + * @skb: the DHCP packet to parse + * @buf: a buffer to store the yiaddr in + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. + * + * Return: True on success, false otherwise. + */ +static bool batadv_dat_dhcp_get_yiaddr(struct sk_buff *skb, __be32 *buf) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + __be32 *yiaddr; + + offset += offsetof(struct batadv_dhcp_packet, yiaddr); + yiaddr = skb_header_pointer(skb, offset, BATADV_DHCP_YIADDR_LEN, buf); + + if (!yiaddr) + return false; + + if (yiaddr != buf) + *buf = get_unaligned(yiaddr); + + return true; +} + +/** + * batadv_dat_get_dhcp_chaddr() - get chaddr from a DHCP packet + * @skb: the DHCP packet to parse + * @buf: a buffer to store the chaddr in + * + * Caller needs to ensure that the given skb is a valid DHCP packet and + * that the skb transport header is set correctly. 
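+ *
+ * On success @buf holds the client hardware address either way: when
+ * skb_header_pointer() returns a pointer into the skb instead of
+ * filling @buf, the data is copied over explicitly.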
+ * + * Return: True on success, false otherwise + */ +static bool batadv_dat_get_dhcp_chaddr(struct sk_buff *skb, u8 *buf) +{ + unsigned int offset = skb_transport_offset(skb) + sizeof(struct udphdr); + u8 *chaddr; + + offset += offsetof(struct batadv_dhcp_packet, chaddr); + chaddr = skb_header_pointer(skb, offset, BATADV_DHCP_CHADDR_LEN, buf); + + if (!chaddr) + return false; + + if (chaddr != buf) + memcpy(buf, chaddr, BATADV_DHCP_CHADDR_LEN); + + return true; +} + +/** + * batadv_dat_put_dhcp() - puts addresses from a DHCP packet into the DHT and + * DAT cache + * @bat_priv: the bat priv with all the soft interface information + * @chaddr: the DHCP client MAC address + * @yiaddr: the DHCP client IP address + * @hw_dst: the DHCP server MAC address + * @ip_dst: the DHCP server IP address + * @vid: VLAN identifier + * + * Adds given MAC/IP pairs to the local DAT cache and propagates them further + * into the DHT. + * + * For the DHT propagation, client MAC + IP will appear as the ARP Reply + * transmitter (and hw_dst/ip_dst as the target). + */ +static void batadv_dat_put_dhcp(struct batadv_priv *bat_priv, u8 *chaddr, + __be32 yiaddr, u8 *hw_dst, __be32 ip_dst, + unsigned short vid) +{ + struct sk_buff *skb; + + skb = batadv_dat_arp_create_reply(bat_priv, yiaddr, ip_dst, chaddr, + hw_dst, vid); + if (!skb) + return; + + skb_set_network_header(skb, ETH_HLEN); + + batadv_dat_entry_add(bat_priv, yiaddr, chaddr, vid); + batadv_dat_entry_add(bat_priv, ip_dst, hw_dst, vid); + + batadv_dat_forward_data(bat_priv, skb, yiaddr, vid, + BATADV_P_DAT_DHT_PUT); + batadv_dat_forward_data(bat_priv, skb, ip_dst, vid, + BATADV_P_DAT_DHT_PUT); + + consume_skb(skb); + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from outgoing DHCPACK (server address): %pI4, %pM (vid: %i)\n", + &ip_dst, hw_dst, batadv_print_vid(vid)); + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from outgoing DHCPACK (client address): %pI4, %pM (vid: %i)\n", + &yiaddr, chaddr, batadv_print_vid(vid)); +} + +/** + * batadv_dat_check_dhcp_ack() - examine packet for valid DHCP message + * @skb: the packet to check + * @proto: ethernet protocol hint (behind a potential vlan) + * @ip_src: a buffer to store the IPv4 source address in + * @chaddr: a buffer to store the DHCP Client Hardware Address in + * @yiaddr: a buffer to store the DHCP Your IP Address in + * + * Checks whether the given skb is a valid DHCPACK. And if so, stores the + * IPv4 server source address (ip_src), client MAC address (chaddr) and client + * IPv4 address (yiaddr) in the provided buffers. + * + * Caller needs to ensure that the skb network header is set correctly. + * + * Return: True if the skb is a valid DHCPACK. False otherwise. + */ +static bool +batadv_dat_check_dhcp_ack(struct sk_buff *skb, __be16 proto, __be32 *ip_src, + u8 *chaddr, __be32 *yiaddr) +{ + int type; + + type = batadv_dat_check_dhcp(skb, proto, ip_src); + if (type != BATADV_BOOTREPLY) + return false; + + type = batadv_dat_get_dhcp_message_type(skb); + if (type != BATADV_DHCPACK) + return false; + + if (!batadv_dat_dhcp_get_yiaddr(skb, yiaddr)) + return false; + + if (!batadv_dat_get_dhcp_chaddr(skb, chaddr)) + return false; + + return true; +} + +/** + * batadv_dat_snoop_outgoing_dhcp_ack() - snoop DHCPACK and fill DAT with it + * @bat_priv: the bat priv with all the soft interface information + * @skb: the packet to snoop + * @proto: ethernet protocol hint (behind a potential vlan) + * @vid: VLAN identifier + * + * This function first checks whether the given skb is a valid DHCPACK. 
If + * so then its source MAC and IP as well as its DHCP Client Hardware Address + * field and DHCP Your IP Address field are added to the local DAT cache and + * propagated into the DHT. + * + * Caller needs to ensure that the skb mac and network headers are set + * correctly. + */ +void batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, + __be16 proto, + unsigned short vid) +{ + u8 chaddr[BATADV_DHCP_CHADDR_LEN]; + __be32 ip_src, yiaddr; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + return; + + if (!batadv_dat_check_dhcp_ack(skb, proto, &ip_src, chaddr, &yiaddr)) + return; + + batadv_dat_put_dhcp(bat_priv, chaddr, yiaddr, eth_hdr(skb)->h_source, + ip_src, vid); +} + +/** + * batadv_dat_snoop_incoming_dhcp_ack() - snoop DHCPACK and fill DAT cache + * @bat_priv: the bat priv with all the soft interface information + * @skb: the packet to snoop + * @hdr_size: header size, up to the tail of the batman-adv header + * + * This function first checks whether the given skb is a valid DHCPACK. If + * so then its source MAC and IP as well as its DHCP Client Hardware Address + * field and DHCP Your IP Address field are added to the local DAT cache. + */ +void batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + u8 chaddr[BATADV_DHCP_CHADDR_LEN]; + struct ethhdr *ethhdr; + __be32 ip_src, yiaddr; + unsigned short vid; + __be16 proto; + u8 *hw_src; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + return; + + if (unlikely(!pskb_may_pull(skb, hdr_size + ETH_HLEN))) + return; + + ethhdr = (struct ethhdr *)(skb->data + hdr_size); + skb_set_network_header(skb, hdr_size + ETH_HLEN); + proto = ethhdr->h_proto; + + if (!batadv_dat_check_dhcp_ack(skb, proto, &ip_src, chaddr, &yiaddr)) + return; + + hw_src = ethhdr->h_source; + vid = batadv_dat_get_vid(skb, &hdr_size); + + batadv_dat_entry_add(bat_priv, yiaddr, chaddr, vid); + batadv_dat_entry_add(bat_priv, ip_src, hw_src, vid); + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from incoming DHCPACK (server address): %pI4, %pM (vid: %i)\n", + &ip_src, hw_src, batadv_print_vid(vid)); + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "Snooped from incoming DHCPACK (client address): %pI4, %pM (vid: %i)\n", + &yiaddr, chaddr, batadv_print_vid(vid)); +} + +/** + * batadv_dat_drop_broadcast_packet() - check if an ARP request has to be + * dropped (because the node has already obtained the reply via DAT) or not + * @bat_priv: the bat priv with all the soft interface information + * @forw_packet: the broadcast packet + * + * Return: true if the node can drop the packet, false otherwise. 
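+ * The broadcast is only dropped when the local DAT cache already holds
+ * an entry for the requested IP; otherwise the request falls back to
+ * the normal broadcast path.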
+ */ +bool batadv_dat_drop_broadcast_packet(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet) +{ + u16 type; + __be32 ip_dst; + struct batadv_dat_entry *dat_entry = NULL; + bool ret = false; + int hdr_size = sizeof(struct batadv_bcast_packet); + unsigned short vid; + + if (!atomic_read(&bat_priv->distributed_arp_table)) + goto out; + + /* If this packet is an ARP_REQUEST and the node already has the + * information that it is going to ask, then the packet can be dropped + */ + if (batadv_forw_packet_is_rebroadcast(forw_packet)) + goto out; + + vid = batadv_dat_get_vid(forw_packet->skb, &hdr_size); + + type = batadv_arp_get_type(bat_priv, forw_packet->skb, hdr_size); + if (type != ARPOP_REQUEST) + goto out; + + ip_dst = batadv_arp_ip_dst(forw_packet->skb, hdr_size); + dat_entry = batadv_dat_entry_hash_find(bat_priv, ip_dst, vid); + /* check if the node already got this entry */ + if (!dat_entry) { + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "ARP Request for %pI4: fallback\n", &ip_dst); + goto out; + } + + batadv_dbg(BATADV_DBG_DAT, bat_priv, + "ARP Request for %pI4: fallback prevented\n", &ip_dst); + ret = true; + +out: + batadv_dat_entry_put(dat_entry); + return ret; +} diff --git a/net/batman-adv/distributed-arp-table.h b/net/batman-adv/distributed-arp-table.h new file mode 100644 index 000000000..bed7f3d20 --- /dev/null +++ b/net/batman-adv/distributed-arp-table.h @@ -0,0 +1,186 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Antonio Quartulli + */ + +#ifndef _NET_BATMAN_ADV_DISTRIBUTED_ARP_TABLE_H_ +#define _NET_BATMAN_ADV_DISTRIBUTED_ARP_TABLE_H_ + +#include "main.h" + +#include +#include +#include +#include +#include +#include + +#include "originator.h" + +#ifdef CONFIG_BATMAN_ADV_DAT + +/* BATADV_DAT_ADDR_MAX - maximum address value in the DHT space */ +#define BATADV_DAT_ADDR_MAX ((batadv_dat_addr_t)~(batadv_dat_addr_t)0) + +void batadv_dat_status_update(struct net_device *net_dev); +bool batadv_dat_snoop_outgoing_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb); +bool batadv_dat_snoop_incoming_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size); +void batadv_dat_snoop_outgoing_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb); +bool batadv_dat_snoop_incoming_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size); +void batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, + __be16 proto, + unsigned short vid); +void batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size); +bool batadv_dat_drop_broadcast_packet(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet); + +/** + * batadv_dat_init_orig_node_addr() - assign a DAT address to the orig_node + * @orig_node: the node to assign the DAT address to + */ +static inline void +batadv_dat_init_orig_node_addr(struct batadv_orig_node *orig_node) +{ + u32 addr; + + addr = batadv_choose_orig(orig_node->orig, BATADV_DAT_ADDR_MAX); + orig_node->dat_addr = (batadv_dat_addr_t)addr; +} + +/** + * batadv_dat_init_own_addr() - assign a DAT address to the node itself + * @bat_priv: the bat priv with all the soft interface information + * @primary_if: a pointer to the primary interface + */ +static inline void +batadv_dat_init_own_addr(struct batadv_priv *bat_priv, + struct batadv_hard_iface *primary_if) +{ + u32 addr; + + addr = batadv_choose_orig(primary_if->net_dev->dev_addr, + BATADV_DAT_ADDR_MAX); + + 
bat_priv->dat.addr = (batadv_dat_addr_t)addr; +} + +int batadv_dat_init(struct batadv_priv *bat_priv); +void batadv_dat_free(struct batadv_priv *bat_priv); +int batadv_dat_cache_dump(struct sk_buff *msg, struct netlink_callback *cb); + +/** + * batadv_dat_inc_counter() - increment the correct DAT packet counter + * @bat_priv: the bat priv with all the soft interface information + * @subtype: the 4addr subtype of the packet to be counted + * + * Updates the ethtool statistics for the received packet if it is a DAT subtype + */ +static inline void batadv_dat_inc_counter(struct batadv_priv *bat_priv, + u8 subtype) +{ + switch (subtype) { + case BATADV_P_DAT_DHT_GET: + batadv_inc_counter(bat_priv, + BATADV_CNT_DAT_GET_RX); + break; + case BATADV_P_DAT_DHT_PUT: + batadv_inc_counter(bat_priv, + BATADV_CNT_DAT_PUT_RX); + break; + } +} + +#else + +static inline void batadv_dat_status_update(struct net_device *net_dev) +{ +} + +static inline bool +batadv_dat_snoop_outgoing_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + return false; +} + +static inline bool +batadv_dat_snoop_incoming_arp_request(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + return false; +} + +static inline bool +batadv_dat_snoop_outgoing_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + return false; +} + +static inline bool +batadv_dat_snoop_incoming_arp_reply(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + return false; +} + +static inline void +batadv_dat_snoop_outgoing_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, __be16 proto, + unsigned short vid) +{ +} + +static inline void +batadv_dat_snoop_incoming_dhcp_ack(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ +} + +static inline bool +batadv_dat_drop_broadcast_packet(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet) +{ + return false; +} + +static inline void +batadv_dat_init_orig_node_addr(struct batadv_orig_node *orig_node) +{ +} + +static inline void batadv_dat_init_own_addr(struct batadv_priv *bat_priv, + struct batadv_hard_iface *iface) +{ +} + +static inline int batadv_dat_init(struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline void batadv_dat_free(struct batadv_priv *bat_priv) +{ +} + +static inline int +batadv_dat_cache_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + return -EOPNOTSUPP; +} + +static inline void batadv_dat_inc_counter(struct batadv_priv *bat_priv, + u8 subtype) +{ +} + +#endif /* CONFIG_BATMAN_ADV_DAT */ + +#endif /* _NET_BATMAN_ADV_DISTRIBUTED_ARP_TABLE_H_ */ diff --git a/net/batman-adv/fragmentation.c b/net/batman-adv/fragmentation.c new file mode 100644 index 000000000..c120c7c6d --- /dev/null +++ b/net/batman-adv/fragmentation.c @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Martin Hundebøll + */ + +#include "fragmentation.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hard-interface.h" +#include "originator.h" +#include "routing.h" +#include "send.h" + +/** + * batadv_frag_clear_chain() - delete entries in the fragment buffer chain + * @head: head of chain with entries. + * @dropped: whether the chain is cleared because all fragments are dropped + * + * Free fragments in the passed hlist. Should be called with appropriate lock. 
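+ * Whether the skbs are released with kfree_skb() (accounted as
+ * dropped) or with consume_skb() is controlled by @dropped.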
+ */ +static void batadv_frag_clear_chain(struct hlist_head *head, bool dropped) +{ + struct batadv_frag_list_entry *entry; + struct hlist_node *node; + + hlist_for_each_entry_safe(entry, node, head, list) { + hlist_del(&entry->list); + + if (dropped) + kfree_skb(entry->skb); + else + consume_skb(entry->skb); + + kfree(entry); + } +} + +/** + * batadv_frag_purge_orig() - free fragments associated to an orig + * @orig_node: originator to free fragments from + * @check_cb: optional function to tell if an entry should be purged + */ +void batadv_frag_purge_orig(struct batadv_orig_node *orig_node, + bool (*check_cb)(struct batadv_frag_table_entry *)) +{ + struct batadv_frag_table_entry *chain; + u8 i; + + for (i = 0; i < BATADV_FRAG_BUFFER_COUNT; i++) { + chain = &orig_node->fragments[i]; + spin_lock_bh(&chain->lock); + + if (!check_cb || check_cb(chain)) { + batadv_frag_clear_chain(&chain->fragment_list, true); + chain->size = 0; + } + + spin_unlock_bh(&chain->lock); + } +} + +/** + * batadv_frag_size_limit() - maximum possible size of packet to be fragmented + * + * Return: the maximum size of payload that can be fragmented. + */ +static int batadv_frag_size_limit(void) +{ + int limit = BATADV_FRAG_MAX_FRAG_SIZE; + + limit -= sizeof(struct batadv_frag_packet); + limit *= BATADV_FRAG_MAX_FRAGMENTS; + + return limit; +} + +/** + * batadv_frag_init_chain() - check and prepare fragment chain for new fragment + * @chain: chain in fragments table to init + * @seqno: sequence number of the received fragment + * + * Make chain ready for a fragment with sequence number "seqno". Delete existing + * entries if they have an "old" sequence number. + * + * Caller must hold chain->lock. + * + * Return: true if chain is empty and the caller can just insert the new + * fragment without searching for the right position. + */ +static bool batadv_frag_init_chain(struct batadv_frag_table_entry *chain, + u16 seqno) +{ + lockdep_assert_held(&chain->lock); + + if (chain->seqno == seqno) + return false; + + if (!hlist_empty(&chain->fragment_list)) + batadv_frag_clear_chain(&chain->fragment_list, true); + + chain->size = 0; + chain->seqno = seqno; + + return true; +} + +/** + * batadv_frag_insert_packet() - insert a fragment into a fragment chain + * @orig_node: originator that the fragment was received from + * @skb: skb to insert + * @chain_out: list head to attach complete chains of fragments to + * + * Insert a new fragment into the reverse ordered chain in the right table + * entry. The hash table entry is cleared if "old" fragments exist in it. + * + * Return: true if skb is buffered, false on error. If the chain has all the + * fragments needed to merge the packet, the chain is moved to the passed head + * to avoid locking the chain in the table. + */ +static bool batadv_frag_insert_packet(struct batadv_orig_node *orig_node, + struct sk_buff *skb, + struct hlist_head *chain_out) +{ + struct batadv_frag_table_entry *chain; + struct batadv_frag_list_entry *frag_entry_new = NULL, *frag_entry_curr; + struct batadv_frag_list_entry *frag_entry_last = NULL; + struct batadv_frag_packet *frag_packet; + u8 bucket; + u16 seqno, hdr_size = sizeof(struct batadv_frag_packet); + bool ret = false; + + /* Linearize packet to avoid linearizing 16 packets in a row when doing + * the later merge. Non-linear merge should be added to remove this + * linearization. 
+ */ + if (skb_linearize(skb) < 0) + goto err; + + frag_packet = (struct batadv_frag_packet *)skb->data; + seqno = ntohs(frag_packet->seqno); + bucket = seqno % BATADV_FRAG_BUFFER_COUNT; + + frag_entry_new = kmalloc(sizeof(*frag_entry_new), GFP_ATOMIC); + if (!frag_entry_new) + goto err; + + frag_entry_new->skb = skb; + frag_entry_new->no = frag_packet->no; + + /* Select entry in the "chain table" and delete any prior fragments + * with another sequence number. batadv_frag_init_chain() returns true, + * if the list is empty at return. + */ + chain = &orig_node->fragments[bucket]; + spin_lock_bh(&chain->lock); + if (batadv_frag_init_chain(chain, seqno)) { + hlist_add_head(&frag_entry_new->list, &chain->fragment_list); + chain->size = skb->len - hdr_size; + chain->timestamp = jiffies; + chain->total_size = ntohs(frag_packet->total_size); + ret = true; + goto out; + } + + /* Find the position for the new fragment. */ + hlist_for_each_entry(frag_entry_curr, &chain->fragment_list, list) { + /* Drop packet if fragment already exists. */ + if (frag_entry_curr->no == frag_entry_new->no) + goto err_unlock; + + /* Order fragments from highest to lowest. */ + if (frag_entry_curr->no < frag_entry_new->no) { + hlist_add_before(&frag_entry_new->list, + &frag_entry_curr->list); + chain->size += skb->len - hdr_size; + chain->timestamp = jiffies; + ret = true; + goto out; + } + + /* store current entry because it could be the last in list */ + frag_entry_last = frag_entry_curr; + } + + /* Reached the end of the list, so insert after 'frag_entry_last'. */ + if (likely(frag_entry_last)) { + hlist_add_behind(&frag_entry_new->list, &frag_entry_last->list); + chain->size += skb->len - hdr_size; + chain->timestamp = jiffies; + ret = true; + } + +out: + if (chain->size > batadv_frag_size_limit() || + chain->total_size != ntohs(frag_packet->total_size) || + chain->total_size > batadv_frag_size_limit()) { + /* Clear chain if total size of either the list or the packet + * exceeds the maximum size of one merged packet. Don't allow + * packets to have different total_size. + */ + batadv_frag_clear_chain(&chain->fragment_list, true); + chain->size = 0; + } else if (ntohs(frag_packet->total_size) == chain->size) { + /* All fragments received. Hand over chain to caller. */ + hlist_move_list(&chain->fragment_list, chain_out); + chain->size = 0; + } + +err_unlock: + spin_unlock_bh(&chain->lock); + +err: + if (!ret) { + kfree(frag_entry_new); + kfree_skb(skb); + } + + return ret; +} + +/** + * batadv_frag_merge_packets() - merge a chain of fragments + * @chain: head of chain with fragments + * + * Expand the first skb in the chain and copy the content of the remaining + * skb's into the expanded one. After doing so, clear the chain. + * + * Return: the merged skb or NULL on error. + */ +static struct sk_buff * +batadv_frag_merge_packets(struct hlist_head *chain) +{ + struct batadv_frag_packet *packet; + struct batadv_frag_list_entry *entry; + struct sk_buff *skb_out; + int size, hdr_size = sizeof(struct batadv_frag_packet); + bool dropped = false; + + /* Remove first entry, as this is the destination for the rest of the + * fragments. + */ + entry = hlist_entry(chain->first, struct batadv_frag_list_entry, list); + hlist_del(&entry->list); + skb_out = entry->skb; + kfree(entry); + + packet = (struct batadv_frag_packet *)skb_out->data; + size = ntohs(packet->total_size) + hdr_size; + + /* Make room for the rest of the fragments. 
*/ + if (pskb_expand_head(skb_out, 0, size - skb_out->len, GFP_ATOMIC) < 0) { + kfree_skb(skb_out); + skb_out = NULL; + dropped = true; + goto free; + } + + /* Move the existing MAC header to just before the payload. (Override + * the fragment header.) + */ + skb_pull(skb_out, hdr_size); + skb_out->ip_summed = CHECKSUM_NONE; + memmove(skb_out->data - ETH_HLEN, skb_mac_header(skb_out), ETH_HLEN); + skb_set_mac_header(skb_out, -ETH_HLEN); + skb_reset_network_header(skb_out); + skb_reset_transport_header(skb_out); + + /* Copy the payload of the each fragment into the last skb */ + hlist_for_each_entry(entry, chain, list) { + size = entry->skb->len - hdr_size; + skb_put_data(skb_out, entry->skb->data + hdr_size, size); + } + +free: + /* Locking is not needed, because 'chain' is not part of any orig. */ + batadv_frag_clear_chain(chain, dropped); + return skb_out; +} + +/** + * batadv_frag_skb_buffer() - buffer fragment for later merge + * @skb: skb to buffer + * @orig_node_src: originator that the skb is received from + * + * Add fragment to buffer and merge fragments if possible. + * + * There are three possible outcomes: 1) Packet is merged: Return true and + * set *skb to merged packet; 2) Packet is buffered: Return true and set *skb + * to NULL; 3) Error: Return false and free skb. + * + * Return: true when the packet is merged or buffered, false when skb is not + * used. + */ +bool batadv_frag_skb_buffer(struct sk_buff **skb, + struct batadv_orig_node *orig_node_src) +{ + struct sk_buff *skb_out = NULL; + struct hlist_head head = HLIST_HEAD_INIT; + bool ret = false; + + /* Add packet to buffer and table entry if merge is possible. */ + if (!batadv_frag_insert_packet(orig_node_src, *skb, &head)) + goto out_err; + + /* Leave if more fragments are needed to merge. */ + if (hlist_empty(&head)) + goto out; + + skb_out = batadv_frag_merge_packets(&head); + if (!skb_out) + goto out_err; + +out: + ret = true; +out_err: + *skb = skb_out; + return ret; +} + +/** + * batadv_frag_skb_fwd() - forward fragments that would exceed MTU when merged + * @skb: skb to forward + * @recv_if: interface that the skb is received on + * @orig_node_src: originator that the skb is received from + * + * Look up the next-hop of the fragments payload and check if the merged packet + * will exceed the MTU towards the next-hop. If so, the fragment is forwarded + * without merging it. + * + * Return: true if the fragment is consumed/forwarded, false otherwise. + */ +bool batadv_frag_skb_fwd(struct sk_buff *skb, + struct batadv_hard_iface *recv_if, + struct batadv_orig_node *orig_node_src) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_orig_node *orig_node_dst; + struct batadv_neigh_node *neigh_node = NULL; + struct batadv_frag_packet *packet; + u16 total_size; + bool ret = false; + + packet = (struct batadv_frag_packet *)skb->data; + orig_node_dst = batadv_orig_hash_find(bat_priv, packet->dest); + if (!orig_node_dst) + goto out; + + neigh_node = batadv_find_router(bat_priv, orig_node_dst, recv_if); + if (!neigh_node) + goto out; + + /* Forward the fragment, if the merged packet would be too big to + * be assembled. 
+ */ + total_size = ntohs(packet->total_size); + if (total_size > neigh_node->if_incoming->net_dev->mtu) { + batadv_inc_counter(bat_priv, BATADV_CNT_FRAG_FWD); + batadv_add_counter(bat_priv, BATADV_CNT_FRAG_FWD_BYTES, + skb->len + ETH_HLEN); + + packet->ttl--; + batadv_send_unicast_skb(skb, neigh_node); + ret = true; + } + +out: + batadv_orig_node_put(orig_node_dst); + batadv_neigh_node_put(neigh_node); + return ret; +} + +/** + * batadv_frag_create() - create a fragment from skb + * @net_dev: outgoing device for fragment + * @skb: skb to create fragment from + * @frag_head: header to use in new fragment + * @fragment_size: size of new fragment + * + * Split the passed skb into two fragments: A new one with size matching the + * passed mtu and the old one with the rest. The new skb contains data from the + * tail of the old skb. + * + * Return: the new fragment, NULL on error. + */ +static struct sk_buff *batadv_frag_create(struct net_device *net_dev, + struct sk_buff *skb, + struct batadv_frag_packet *frag_head, + unsigned int fragment_size) +{ + unsigned int ll_reserved = LL_RESERVED_SPACE(net_dev); + unsigned int tailroom = net_dev->needed_tailroom; + struct sk_buff *skb_fragment; + unsigned int header_size = sizeof(*frag_head); + unsigned int mtu = fragment_size + header_size; + + skb_fragment = dev_alloc_skb(ll_reserved + mtu + tailroom); + if (!skb_fragment) + goto err; + + skb_fragment->priority = skb->priority; + + /* Eat the last mtu-bytes of the skb */ + skb_reserve(skb_fragment, ll_reserved + header_size); + skb_split(skb, skb_fragment, skb->len - fragment_size); + + /* Add the header */ + skb_push(skb_fragment, header_size); + memcpy(skb_fragment->data, frag_head, header_size); + +err: + return skb_fragment; +} + +/** + * batadv_frag_send_packet() - create up to 16 fragments from the passed skb + * @skb: skb to create fragments from + * @orig_node: final destination of the created fragments + * @neigh_node: next-hop of the created fragments + * + * Return: the netdev tx status or a negative errno code on a failure + */ +int batadv_frag_send_packet(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node) +{ + struct net_device *net_dev = neigh_node->if_incoming->net_dev; + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_frag_packet frag_header; + struct sk_buff *skb_fragment; + unsigned int mtu = net_dev->mtu; + unsigned int header_size = sizeof(frag_header); + unsigned int max_fragment_size, num_fragments; + int ret; + + /* To avoid merge and refragmentation at next-hops we never send + * fragments larger than BATADV_FRAG_MAX_FRAG_SIZE + */ + mtu = min_t(unsigned int, mtu, BATADV_FRAG_MAX_FRAG_SIZE); + max_fragment_size = mtu - header_size; + + if (skb->len == 0 || max_fragment_size == 0) + return -EINVAL; + + num_fragments = (skb->len - 1) / max_fragment_size + 1; + max_fragment_size = (skb->len - 1) / num_fragments + 1; + + /* Don't even try to fragment, if we need more than 16 fragments */ + if (num_fragments > BATADV_FRAG_MAX_FRAGMENTS) { + ret = -EAGAIN; + goto free_skb; + } + + bat_priv = orig_node->bat_priv; + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) { + ret = -EINVAL; + goto free_skb; + } + + /* GRO might have added fragments to the fragment list instead of + * frags[]. 
But this is not handled by skb_split and must be + * linearized to avoid incorrect length information after all + * batman-adv fragments were created and submitted to the + * hard-interface + */ + if (skb_has_frag_list(skb) && __skb_linearize(skb)) { + ret = -ENOMEM; + goto free_skb; + } + + /* Create one header to be copied to all fragments */ + frag_header.packet_type = BATADV_UNICAST_FRAG; + frag_header.version = BATADV_COMPAT_VERSION; + frag_header.ttl = BATADV_TTL; + frag_header.seqno = htons(atomic_inc_return(&bat_priv->frag_seqno)); + frag_header.reserved = 0; + frag_header.no = 0; + frag_header.total_size = htons(skb->len); + + /* skb->priority values from 256->263 are magic values to + * directly indicate a specific 802.1d priority. This is used + * to allow 802.1d priority to be passed directly in from VLAN + * tags, etc. + */ + if (skb->priority >= 256 && skb->priority <= 263) + frag_header.priority = skb->priority - 256; + else + frag_header.priority = 0; + + ether_addr_copy(frag_header.orig, primary_if->net_dev->dev_addr); + ether_addr_copy(frag_header.dest, orig_node->orig); + + /* Eat and send fragments from the tail of skb */ + while (skb->len > max_fragment_size) { + /* The initial check in this function should cover this case */ + if (unlikely(frag_header.no == BATADV_FRAG_MAX_FRAGMENTS - 1)) { + ret = -EINVAL; + goto put_primary_if; + } + + skb_fragment = batadv_frag_create(net_dev, skb, &frag_header, + max_fragment_size); + if (!skb_fragment) { + ret = -ENOMEM; + goto put_primary_if; + } + + batadv_inc_counter(bat_priv, BATADV_CNT_FRAG_TX); + batadv_add_counter(bat_priv, BATADV_CNT_FRAG_TX_BYTES, + skb_fragment->len + ETH_HLEN); + ret = batadv_send_unicast_skb(skb_fragment, neigh_node); + if (ret != NET_XMIT_SUCCESS) { + ret = NET_XMIT_DROP; + goto put_primary_if; + } + + frag_header.no++; + } + + /* make sure that there is at least enough head for the fragmentation + * and ethernet headers + */ + ret = skb_cow_head(skb, ETH_HLEN + header_size); + if (ret < 0) + goto put_primary_if; + + skb_push(skb, header_size); + memcpy(skb->data, &frag_header, header_size); + + /* Send the last fragment */ + batadv_inc_counter(bat_priv, BATADV_CNT_FRAG_TX); + batadv_add_counter(bat_priv, BATADV_CNT_FRAG_TX_BYTES, + skb->len + ETH_HLEN); + ret = batadv_send_unicast_skb(skb, neigh_node); + /* skb was consumed */ + skb = NULL; + +put_primary_if: + batadv_hardif_put(primary_if); +free_skb: + kfree_skb(skb); + + return ret; +} diff --git a/net/batman-adv/fragmentation.h b/net/batman-adv/fragmentation.h new file mode 100644 index 000000000..dbf0871f8 --- /dev/null +++ b/net/batman-adv/fragmentation.h @@ -0,0 +1,44 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Martin Hundebøll + */ + +#ifndef _NET_BATMAN_ADV_FRAGMENTATION_H_ +#define _NET_BATMAN_ADV_FRAGMENTATION_H_ + +#include "main.h" + +#include +#include +#include +#include +#include + +void batadv_frag_purge_orig(struct batadv_orig_node *orig, + bool (*check_cb)(struct batadv_frag_table_entry *)); +bool batadv_frag_skb_fwd(struct sk_buff *skb, + struct batadv_hard_iface *recv_if, + struct batadv_orig_node *orig_node_src); +bool batadv_frag_skb_buffer(struct sk_buff **skb, + struct batadv_orig_node *orig_node); +int batadv_frag_send_packet(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + struct batadv_neigh_node *neigh_node); + +/** + * batadv_frag_check_entry() - check if a list of fragments has timed out + * @frags_entry: table entry to check + * + * Return: true if the frags entry has timed out, false otherwise. + */ +static inline bool +batadv_frag_check_entry(struct batadv_frag_table_entry *frags_entry) +{ + if (!hlist_empty(&frags_entry->fragment_list) && + batadv_has_timed_out(frags_entry->timestamp, BATADV_FRAG_TIMEOUT)) + return true; + return false; +} + +#endif /* _NET_BATMAN_ADV_FRAGMENTATION_H_ */ diff --git a/net/batman-adv/gateway_client.c b/net/batman-adv/gateway_client.c new file mode 100644 index 000000000..d26124bc2 --- /dev/null +++ b/net/batman-adv/gateway_client.c @@ -0,0 +1,769 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner + */ + +#include "gateway_client.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hard-interface.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" +#include "routing.h" +#include "soft-interface.h" +#include "translation-table.h" + +/* These are the offsets of the "hw type" and "hw address length" in the dhcp + * packet starting at the beginning of the dhcp header + */ +#define BATADV_DHCP_HTYPE_OFFSET 1 +#define BATADV_DHCP_HLEN_OFFSET 2 +/* Value of htype representing Ethernet */ +#define BATADV_DHCP_HTYPE_ETHERNET 0x01 +/* This is the offset of the "chaddr" field in the dhcp packet starting at the + * beginning of the dhcp header + */ +#define BATADV_DHCP_CHADDR_OFFSET 28 + +/** + * batadv_gw_node_release() - release gw_node from lists and queue for free + * after rcu grace period + * @ref: kref pointer of the gw_node + */ +void batadv_gw_node_release(struct kref *ref) +{ + struct batadv_gw_node *gw_node; + + gw_node = container_of(ref, struct batadv_gw_node, refcount); + + batadv_orig_node_put(gw_node->orig_node); + kfree_rcu(gw_node, rcu); +} + +/** + * batadv_gw_get_selected_gw_node() - Get currently selected gateway + * @bat_priv: the bat priv with all the soft interface information + * + * Return: selected gateway (with increased refcnt), NULL on errors + */ +struct batadv_gw_node * +batadv_gw_get_selected_gw_node(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *gw_node; + + rcu_read_lock(); + gw_node = rcu_dereference(bat_priv->gw.curr_gw); + if (!gw_node) + goto out; + + if (!kref_get_unless_zero(&gw_node->refcount)) + gw_node = NULL; + +out: + rcu_read_unlock(); + return gw_node; +} + +/** + * batadv_gw_get_selected_orig() - Get originator of currently selected gateway + * @bat_priv: the bat priv with all the soft interface information + * + * Return: orig_node of 
selected gateway (with increased refcnt), NULL on errors + */ +struct batadv_orig_node * +batadv_gw_get_selected_orig(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *gw_node; + struct batadv_orig_node *orig_node = NULL; + + gw_node = batadv_gw_get_selected_gw_node(bat_priv); + if (!gw_node) + goto out; + + rcu_read_lock(); + orig_node = gw_node->orig_node; + if (!orig_node) + goto unlock; + + if (!kref_get_unless_zero(&orig_node->refcount)) + orig_node = NULL; + +unlock: + rcu_read_unlock(); +out: + batadv_gw_node_put(gw_node); + return orig_node; +} + +static void batadv_gw_select(struct batadv_priv *bat_priv, + struct batadv_gw_node *new_gw_node) +{ + struct batadv_gw_node *curr_gw_node; + + spin_lock_bh(&bat_priv->gw.list_lock); + + if (new_gw_node) + kref_get(&new_gw_node->refcount); + + curr_gw_node = rcu_replace_pointer(bat_priv->gw.curr_gw, new_gw_node, + true); + + batadv_gw_node_put(curr_gw_node); + + spin_unlock_bh(&bat_priv->gw.list_lock); +} + +/** + * batadv_gw_reselect() - force a gateway reselection + * @bat_priv: the bat priv with all the soft interface information + * + * Set a flag to remind the GW component to perform a new gateway reselection. + * However this function does not ensure that the current gateway is going to be + * deselected. The reselection mechanism may elect the same gateway once again. + * + * This means that invoking batadv_gw_reselect() does not guarantee a gateway + * change and therefore a uevent is not necessarily expected. + */ +void batadv_gw_reselect(struct batadv_priv *bat_priv) +{ + atomic_set(&bat_priv->gw.reselect, 1); +} + +/** + * batadv_gw_check_client_stop() - check if client mode has been switched off + * @bat_priv: the bat priv with all the soft interface information + * + * This function assumes the caller has checked that the gw state *is actually + * changing*. This function is not supposed to be called when there is no state + * change. 
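batadv_gw_reselect() only arms a one-shot flag; the actual switch happens the next time the election runs and consumes that flag. A plain-C model of the trigger, where an int stands in for the atomic_t and the helper names are made up for illustration:

#include <stdbool.h>
#include <stdio.h>

static int reselect;		/* models bat_priv->gw.reselect */
static bool have_curr_gw = true;

static bool election_runs(void)
{
	bool triggered = false;

	/* mirrors: if (!batadv_atomic_dec_not_zero(&reselect) && curr_gw) skip */
	if (reselect > 0) {
		reselect--;
		triggered = true;
	}
	return triggered || !have_curr_gw;
}

int main(void)
{
	printf("gateway selected, no trigger -> election runs: %d\n", election_runs());
	reselect = 1;				/* batadv_gw_reselect() */
	printf("after reselect               -> election runs: %d\n", election_runs());
	printf("flag already consumed        -> election runs: %d\n", election_runs());
	return 0;
}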
+ */ +void batadv_gw_check_client_stop(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *curr_gw; + + if (atomic_read(&bat_priv->gw.mode) != BATADV_GW_MODE_CLIENT) + return; + + curr_gw = batadv_gw_get_selected_gw_node(bat_priv); + if (!curr_gw) + return; + + /* deselect the current gateway so that next time that client mode is + * enabled a proper GW_ADD event can be sent + */ + batadv_gw_select(bat_priv, NULL); + + /* if batman-adv is switching the gw client mode off and a gateway was + * already selected, send a DEL uevent + */ + batadv_throw_uevent(bat_priv, BATADV_UEV_GW, BATADV_UEV_DEL, NULL); + + batadv_gw_node_put(curr_gw); +} + +/** + * batadv_gw_election() - Elect the best gateway + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_gw_election(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *curr_gw = NULL; + struct batadv_gw_node *next_gw = NULL; + struct batadv_neigh_node *router = NULL; + struct batadv_neigh_ifinfo *router_ifinfo = NULL; + char gw_addr[18] = { '\0' }; + + if (atomic_read(&bat_priv->gw.mode) != BATADV_GW_MODE_CLIENT) + goto out; + + if (!bat_priv->algo_ops->gw.get_best_gw_node) + goto out; + + curr_gw = batadv_gw_get_selected_gw_node(bat_priv); + + if (!batadv_atomic_dec_not_zero(&bat_priv->gw.reselect) && curr_gw) + goto out; + + /* if gw.reselect is set to 1 it means that a previous call to + * gw.is_eligible() said that we have a new best GW, therefore it can + * now be picked from the list and selected + */ + next_gw = bat_priv->algo_ops->gw.get_best_gw_node(bat_priv); + + if (curr_gw == next_gw) + goto out; + + if (next_gw) { + sprintf(gw_addr, "%pM", next_gw->orig_node->orig); + + router = batadv_orig_router_get(next_gw->orig_node, + BATADV_IF_DEFAULT); + if (!router) { + batadv_gw_reselect(bat_priv); + goto out; + } + + router_ifinfo = batadv_neigh_ifinfo_get(router, + BATADV_IF_DEFAULT); + if (!router_ifinfo) { + batadv_gw_reselect(bat_priv); + goto out; + } + } + + if (curr_gw && !next_gw) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Removing selected gateway - no gateway in range\n"); + batadv_throw_uevent(bat_priv, BATADV_UEV_GW, BATADV_UEV_DEL, + NULL); + } else if (!curr_gw && next_gw) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Adding route to gateway %pM (bandwidth: %u.%u/%u.%u MBit, tq: %i)\n", + next_gw->orig_node->orig, + next_gw->bandwidth_down / 10, + next_gw->bandwidth_down % 10, + next_gw->bandwidth_up / 10, + next_gw->bandwidth_up % 10, + router_ifinfo->bat_iv.tq_avg); + batadv_throw_uevent(bat_priv, BATADV_UEV_GW, BATADV_UEV_ADD, + gw_addr); + } else { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Changing route to gateway %pM (bandwidth: %u.%u/%u.%u MBit, tq: %i)\n", + next_gw->orig_node->orig, + next_gw->bandwidth_down / 10, + next_gw->bandwidth_down % 10, + next_gw->bandwidth_up / 10, + next_gw->bandwidth_up % 10, + router_ifinfo->bat_iv.tq_avg); + batadv_throw_uevent(bat_priv, BATADV_UEV_GW, BATADV_UEV_CHANGE, + gw_addr); + } + + batadv_gw_select(bat_priv, next_gw); + +out: + batadv_gw_node_put(curr_gw); + batadv_gw_node_put(next_gw); + batadv_neigh_node_put(router); + batadv_neigh_ifinfo_put(router_ifinfo); +} + +/** + * batadv_gw_check_election() - Elect orig node as best gateway when eligible + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be checked + */ +void batadv_gw_check_election(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_orig_node *curr_gw_orig; + + /* abort immediately 
if the routing algorithm does not support gateway + * election + */ + if (!bat_priv->algo_ops->gw.is_eligible) + return; + + curr_gw_orig = batadv_gw_get_selected_orig(bat_priv); + if (!curr_gw_orig) + goto reselect; + + /* this node already is the gateway */ + if (curr_gw_orig == orig_node) + goto out; + + if (!bat_priv->algo_ops->gw.is_eligible(bat_priv, curr_gw_orig, + orig_node)) + goto out; + +reselect: + batadv_gw_reselect(bat_priv); +out: + batadv_orig_node_put(curr_gw_orig); +} + +/** + * batadv_gw_node_add() - add gateway node to list of available gateways + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: originator announcing gateway capabilities + * @gateway: announced bandwidth information + * + * Has to be called with the appropriate locks being acquired + * (gw.list_lock). + */ +static void batadv_gw_node_add(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_tvlv_gateway_data *gateway) +{ + struct batadv_gw_node *gw_node; + + lockdep_assert_held(&bat_priv->gw.list_lock); + + if (gateway->bandwidth_down == 0) + return; + + gw_node = kzalloc(sizeof(*gw_node), GFP_ATOMIC); + if (!gw_node) + return; + + kref_init(&gw_node->refcount); + INIT_HLIST_NODE(&gw_node->list); + kref_get(&orig_node->refcount); + gw_node->orig_node = orig_node; + gw_node->bandwidth_down = ntohl(gateway->bandwidth_down); + gw_node->bandwidth_up = ntohl(gateway->bandwidth_up); + + kref_get(&gw_node->refcount); + hlist_add_head_rcu(&gw_node->list, &bat_priv->gw.gateway_list); + bat_priv->gw.generation++; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Found new gateway %pM -> gw bandwidth: %u.%u/%u.%u MBit\n", + orig_node->orig, + ntohl(gateway->bandwidth_down) / 10, + ntohl(gateway->bandwidth_down) % 10, + ntohl(gateway->bandwidth_up) / 10, + ntohl(gateway->bandwidth_up) % 10); + + /* don't return reference to new gw_node */ + batadv_gw_node_put(gw_node); +} + +/** + * batadv_gw_node_get() - retrieve gateway node from list of available gateways + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: originator announcing gateway capabilities + * + * Return: gateway node if found or NULL otherwise. 
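The bandwidth values announced in the gateway TVLV appear to be carried in steps of 100 kbit/s, which is why the log messages above print them as value/10 '.' value%10 MBit. A minimal example of that formatting, with made-up sample values:

#include <stdio.h>

int main(void)
{
	unsigned int bandwidth_down = 105;	/* 10.5 MBit/s */
	unsigned int bandwidth_up = 24;		/*  2.4 MBit/s */

	printf("gw bandwidth: %u.%u/%u.%u MBit\n",
	       bandwidth_down / 10, bandwidth_down % 10,
	       bandwidth_up / 10, bandwidth_up % 10);
	return 0;
}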
+ */ +struct batadv_gw_node *batadv_gw_node_get(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_gw_node *gw_node_tmp, *gw_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(gw_node_tmp, &bat_priv->gw.gateway_list, + list) { + if (gw_node_tmp->orig_node != orig_node) + continue; + + if (!kref_get_unless_zero(&gw_node_tmp->refcount)) + continue; + + gw_node = gw_node_tmp; + break; + } + rcu_read_unlock(); + + return gw_node; +} + +/** + * batadv_gw_node_update() - update list of available gateways with changed + * bandwidth information + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: originator announcing gateway capabilities + * @gateway: announced bandwidth information + */ +void batadv_gw_node_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_tvlv_gateway_data *gateway) +{ + struct batadv_gw_node *gw_node, *curr_gw = NULL; + + spin_lock_bh(&bat_priv->gw.list_lock); + gw_node = batadv_gw_node_get(bat_priv, orig_node); + if (!gw_node) { + batadv_gw_node_add(bat_priv, orig_node, gateway); + spin_unlock_bh(&bat_priv->gw.list_lock); + goto out; + } + spin_unlock_bh(&bat_priv->gw.list_lock); + + if (gw_node->bandwidth_down == ntohl(gateway->bandwidth_down) && + gw_node->bandwidth_up == ntohl(gateway->bandwidth_up)) + goto out; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Gateway bandwidth of originator %pM changed from %u.%u/%u.%u MBit to %u.%u/%u.%u MBit\n", + orig_node->orig, + gw_node->bandwidth_down / 10, + gw_node->bandwidth_down % 10, + gw_node->bandwidth_up / 10, + gw_node->bandwidth_up % 10, + ntohl(gateway->bandwidth_down) / 10, + ntohl(gateway->bandwidth_down) % 10, + ntohl(gateway->bandwidth_up) / 10, + ntohl(gateway->bandwidth_up) % 10); + + gw_node->bandwidth_down = ntohl(gateway->bandwidth_down); + gw_node->bandwidth_up = ntohl(gateway->bandwidth_up); + + if (ntohl(gateway->bandwidth_down) == 0) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Gateway %pM removed from gateway list\n", + orig_node->orig); + + /* Note: We don't need a NULL check here, since curr_gw never + * gets dereferenced. 
+ */ + spin_lock_bh(&bat_priv->gw.list_lock); + if (!hlist_unhashed(&gw_node->list)) { + hlist_del_init_rcu(&gw_node->list); + batadv_gw_node_put(gw_node); + bat_priv->gw.generation++; + } + spin_unlock_bh(&bat_priv->gw.list_lock); + + curr_gw = batadv_gw_get_selected_gw_node(bat_priv); + if (gw_node == curr_gw) + batadv_gw_reselect(bat_priv); + + batadv_gw_node_put(curr_gw); + } + +out: + batadv_gw_node_put(gw_node); +} + +/** + * batadv_gw_node_delete() - Remove orig_node from gateway list + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is currently in process of being removed + */ +void batadv_gw_node_delete(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_tvlv_gateway_data gateway; + + gateway.bandwidth_down = 0; + gateway.bandwidth_up = 0; + + batadv_gw_node_update(bat_priv, orig_node, &gateway); +} + +/** + * batadv_gw_node_free() - Free gateway information from soft interface + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_gw_node_free(struct batadv_priv *bat_priv) +{ + struct batadv_gw_node *gw_node; + struct hlist_node *node_tmp; + + spin_lock_bh(&bat_priv->gw.list_lock); + hlist_for_each_entry_safe(gw_node, node_tmp, + &bat_priv->gw.gateway_list, list) { + hlist_del_init_rcu(&gw_node->list); + batadv_gw_node_put(gw_node); + bat_priv->gw.generation++; + } + spin_unlock_bh(&bat_priv->gw.list_lock); +} + +/** + * batadv_gw_dump() - Dump gateways into a message + * @msg: Netlink message to dump into + * @cb: Control block containing additional options + * + * Return: Error code, or length of message + */ +int batadv_gw_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct batadv_hard_iface *primary_if = NULL; + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_priv *bat_priv; + int ifindex; + int ret; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + if (!bat_priv->algo_ops->gw.dump) { + ret = -EOPNOTSUPP; + goto out; + } + + bat_priv->algo_ops->gw.dump(msg, cb, bat_priv); + + ret = msg->len; + +out: + batadv_hardif_put(primary_if); + dev_put(soft_iface); + + return ret; +} + +/** + * batadv_gw_dhcp_recipient_get() - check if a packet is a DHCP message + * @skb: the packet to check + * @header_len: a pointer to the batman-adv header size + * @chaddr: buffer where the client address will be stored. Valid + * only if the function returns BATADV_DHCP_TO_CLIENT + * + * This function may re-allocate the data buffer of the skb passed as argument. 
+ * + * Return: + * - BATADV_DHCP_NO if the packet is not a dhcp message or if there was an error + * while parsing it + * - BATADV_DHCP_TO_SERVER if this is a message going to the DHCP server + * - BATADV_DHCP_TO_CLIENT if this is a message going to a DHCP client + */ +enum batadv_dhcp_recipient +batadv_gw_dhcp_recipient_get(struct sk_buff *skb, unsigned int *header_len, + u8 *chaddr) +{ + enum batadv_dhcp_recipient ret = BATADV_DHCP_NO; + struct ethhdr *ethhdr; + struct iphdr *iphdr; + struct ipv6hdr *ipv6hdr; + struct udphdr *udphdr; + struct vlan_ethhdr *vhdr; + int chaddr_offset; + __be16 proto; + u8 *p; + + /* check for ethernet header */ + if (!pskb_may_pull(skb, *header_len + ETH_HLEN)) + return BATADV_DHCP_NO; + + ethhdr = eth_hdr(skb); + proto = ethhdr->h_proto; + *header_len += ETH_HLEN; + + /* check for initial vlan header */ + if (proto == htons(ETH_P_8021Q)) { + if (!pskb_may_pull(skb, *header_len + VLAN_HLEN)) + return BATADV_DHCP_NO; + + vhdr = vlan_eth_hdr(skb); + proto = vhdr->h_vlan_encapsulated_proto; + *header_len += VLAN_HLEN; + } + + /* check for ip header */ + switch (proto) { + case htons(ETH_P_IP): + if (!pskb_may_pull(skb, *header_len + sizeof(*iphdr))) + return BATADV_DHCP_NO; + + iphdr = (struct iphdr *)(skb->data + *header_len); + *header_len += iphdr->ihl * 4; + + /* check for udp header */ + if (iphdr->protocol != IPPROTO_UDP) + return BATADV_DHCP_NO; + + break; + case htons(ETH_P_IPV6): + if (!pskb_may_pull(skb, *header_len + sizeof(*ipv6hdr))) + return BATADV_DHCP_NO; + + ipv6hdr = (struct ipv6hdr *)(skb->data + *header_len); + *header_len += sizeof(*ipv6hdr); + + /* check for udp header */ + if (ipv6hdr->nexthdr != IPPROTO_UDP) + return BATADV_DHCP_NO; + + break; + default: + return BATADV_DHCP_NO; + } + + if (!pskb_may_pull(skb, *header_len + sizeof(*udphdr))) + return BATADV_DHCP_NO; + + udphdr = (struct udphdr *)(skb->data + *header_len); + *header_len += sizeof(*udphdr); + + /* check for bootp port */ + switch (proto) { + case htons(ETH_P_IP): + if (udphdr->dest == htons(67)) + ret = BATADV_DHCP_TO_SERVER; + else if (udphdr->source == htons(67)) + ret = BATADV_DHCP_TO_CLIENT; + break; + case htons(ETH_P_IPV6): + if (udphdr->dest == htons(547)) + ret = BATADV_DHCP_TO_SERVER; + else if (udphdr->source == htons(547)) + ret = BATADV_DHCP_TO_CLIENT; + break; + } + + chaddr_offset = *header_len + BATADV_DHCP_CHADDR_OFFSET; + /* store the client address if the message is going to a client */ + if (ret == BATADV_DHCP_TO_CLIENT) { + if (!pskb_may_pull(skb, chaddr_offset + ETH_ALEN)) + return BATADV_DHCP_NO; + + /* check if the DHCP packet carries an Ethernet DHCP */ + p = skb->data + *header_len + BATADV_DHCP_HTYPE_OFFSET; + if (*p != BATADV_DHCP_HTYPE_ETHERNET) + return BATADV_DHCP_NO; + + /* check if the DHCP packet carries a valid Ethernet address */ + p = skb->data + *header_len + BATADV_DHCP_HLEN_OFFSET; + if (*p != ETH_ALEN) + return BATADV_DHCP_NO; + + ether_addr_copy(chaddr, skb->data + chaddr_offset); + } + + return ret; +} + +/** + * batadv_gw_out_of_range() - check if the dhcp request destination is the best + * gateway + * @bat_priv: the bat priv with all the soft interface information + * @skb: the outgoing packet + * + * Check if the skb is a DHCP request and if it is sent to the current best GW + * server. Due to topology changes it may be the case that the GW server + * previously selected is not the best one anymore. + * + * This call might reallocate skb data. + * Must be invoked only when the DHCP packet is going TO a DHCP SERVER. 
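The recipient decision ultimately comes down to the UDP ports checked in the function body: traffic addressed to the server port is TO_SERVER, traffic sourced from it is TO_CLIENT, everything else is not DHCP. A stand-alone sketch of that rule; the enum and function names are illustrative, while 67 and 547 are the DHCPv4/DHCPv6 server ports used in the code:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

enum dhcp_recipient { DHCP_NO, DHCP_TO_SERVER, DHCP_TO_CLIENT };

static enum dhcp_recipient classify(bool ipv6, uint16_t sport, uint16_t dport)
{
	uint16_t server_port = ipv6 ? 547 : 67;

	if (dport == server_port)
		return DHCP_TO_SERVER;
	if (sport == server_port)
		return DHCP_TO_CLIENT;
	return DHCP_NO;
}

int main(void)
{
	printf("v4 68->67:   %d\n", classify(false, 68, 67));	/* TO_SERVER */
	printf("v4 67->68:   %d\n", classify(false, 67, 68));	/* TO_CLIENT */
	printf("v6 546->547: %d\n", classify(true, 546, 547));	/* TO_SERVER */
	printf("plain UDP:   %d\n", classify(false, 5000, 5001));	/* NO */
	return 0;
}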
+ * + * Return: true if the packet destination is unicast and it is not the best gw, + * false otherwise. + */ +bool batadv_gw_out_of_range(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_neigh_node *neigh_curr = NULL; + struct batadv_neigh_node *neigh_old = NULL; + struct batadv_orig_node *orig_dst_node = NULL; + struct batadv_gw_node *gw_node = NULL; + struct batadv_gw_node *curr_gw = NULL; + struct batadv_neigh_ifinfo *curr_ifinfo, *old_ifinfo; + struct ethhdr *ethhdr = (struct ethhdr *)skb->data; + bool out_of_range = false; + u8 curr_tq_avg; + unsigned short vid; + + vid = batadv_get_vid(skb, 0); + + if (is_multicast_ether_addr(ethhdr->h_dest)) + goto out; + + orig_dst_node = batadv_transtable_search(bat_priv, ethhdr->h_source, + ethhdr->h_dest, vid); + if (!orig_dst_node) + goto out; + + gw_node = batadv_gw_node_get(bat_priv, orig_dst_node); + if (!gw_node) + goto out; + + switch (atomic_read(&bat_priv->gw.mode)) { + case BATADV_GW_MODE_SERVER: + /* If we are a GW then we are our best GW. We can artificially + * set the tq towards ourself as the maximum value + */ + curr_tq_avg = BATADV_TQ_MAX_VALUE; + break; + case BATADV_GW_MODE_CLIENT: + curr_gw = batadv_gw_get_selected_gw_node(bat_priv); + if (!curr_gw) + goto out; + + /* packet is going to our gateway */ + if (curr_gw->orig_node == orig_dst_node) + goto out; + + /* If the dhcp packet has been sent to a different gw, + * we have to evaluate whether the old gw is still + * reliable enough + */ + neigh_curr = batadv_find_router(bat_priv, curr_gw->orig_node, + NULL); + if (!neigh_curr) + goto out; + + curr_ifinfo = batadv_neigh_ifinfo_get(neigh_curr, + BATADV_IF_DEFAULT); + if (!curr_ifinfo) + goto out; + + curr_tq_avg = curr_ifinfo->bat_iv.tq_avg; + batadv_neigh_ifinfo_put(curr_ifinfo); + + break; + case BATADV_GW_MODE_OFF: + default: + goto out; + } + + neigh_old = batadv_find_router(bat_priv, orig_dst_node, NULL); + if (!neigh_old) + goto out; + + old_ifinfo = batadv_neigh_ifinfo_get(neigh_old, BATADV_IF_DEFAULT); + if (!old_ifinfo) + goto out; + + if ((curr_tq_avg - old_ifinfo->bat_iv.tq_avg) > BATADV_GW_THRESHOLD) + out_of_range = true; + batadv_neigh_ifinfo_put(old_ifinfo); + +out: + batadv_orig_node_put(orig_dst_node); + batadv_gw_node_put(curr_gw); + batadv_gw_node_put(gw_node); + batadv_neigh_node_put(neigh_old); + batadv_neigh_node_put(neigh_curr); + return out_of_range; +} diff --git a/net/batman-adv/gateway_client.h b/net/batman-adv/gateway_client.h new file mode 100644 index 000000000..95c2ccdaa --- /dev/null +++ b/net/batman-adv/gateway_client.h @@ -0,0 +1,55 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_GATEWAY_CLIENT_H_ +#define _NET_BATMAN_ADV_GATEWAY_CLIENT_H_ + +#include "main.h" + +#include +#include +#include +#include +#include + +void batadv_gw_check_client_stop(struct batadv_priv *bat_priv); +void batadv_gw_reselect(struct batadv_priv *bat_priv); +void batadv_gw_election(struct batadv_priv *bat_priv); +struct batadv_orig_node * +batadv_gw_get_selected_orig(struct batadv_priv *bat_priv); +void batadv_gw_check_election(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node); +void batadv_gw_node_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_tvlv_gateway_data *gateway); +void batadv_gw_node_delete(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node); +void batadv_gw_node_free(struct batadv_priv *bat_priv); +void batadv_gw_node_release(struct kref *ref); +struct batadv_gw_node * +batadv_gw_get_selected_gw_node(struct batadv_priv *bat_priv); +int batadv_gw_dump(struct sk_buff *msg, struct netlink_callback *cb); +bool batadv_gw_out_of_range(struct batadv_priv *bat_priv, struct sk_buff *skb); +enum batadv_dhcp_recipient +batadv_gw_dhcp_recipient_get(struct sk_buff *skb, unsigned int *header_len, + u8 *chaddr); +struct batadv_gw_node *batadv_gw_node_get(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node); + +/** + * batadv_gw_node_put() - decrement the gw_node refcounter and possibly release + * it + * @gw_node: gateway node to free + */ +static inline void batadv_gw_node_put(struct batadv_gw_node *gw_node) +{ + if (!gw_node) + return; + + kref_put(&gw_node->refcount, batadv_gw_node_release); +} + +#endif /* _NET_BATMAN_ADV_GATEWAY_CLIENT_H_ */ diff --git a/net/batman-adv/gateway_common.c b/net/batman-adv/gateway_common.c new file mode 100644 index 000000000..9349c76f3 --- /dev/null +++ b/net/batman-adv/gateway_common.c @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner + */ + +#include "gateway_common.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "gateway_client.h" +#include "log.h" +#include "tvlv.h" + +/** + * batadv_parse_throughput() - parse supplied string buffer to extract + * throughput information + * @net_dev: the soft interface net device + * @buff: string buffer to parse + * @description: text shown when throughput string cannot be parsed + * @throughput: pointer holding the returned throughput information + * + * Return: false on parse error and true otherwise. 
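The parser accepts an optional 'kbit' or 'mbit' suffix and normalises the value to the internal unit of 100 kbit/s: 'mbit' values are multiplied by 10, plain or 'kbit' values divided by 100. A reduced user-space version of that conversion, with error and overflow handling omitted and an illustrative function name:

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>

/* reduced model of the unit handling: returns throughput in 100 kbit/s
 * steps; the real code also rejects parse errors and overflows
 */
static uint32_t parse_throughput(const char *buff)
{
	char tmp[32];
	size_t len = strlen(buff);
	uint64_t value;

	strncpy(tmp, buff, sizeof(tmp) - 1);
	tmp[sizeof(tmp) - 1] = '\0';

	if (len > 4 && strcasecmp(tmp + len - 4, "mbit") == 0) {
		tmp[len - 4] = '\0';
		value = strtoull(tmp, NULL, 10);
		return (uint32_t)(value * 10);
	}
	if (len > 4 && strcasecmp(tmp + len - 4, "kbit") == 0)
		tmp[len - 4] = '\0';

	value = strtoull(tmp, NULL, 10);
	return (uint32_t)(value / 100);
}

int main(void)
{
	printf("10mbit   -> %u\n", parse_throughput("10mbit"));		/* 100 */
	printf("5000kbit -> %u\n", parse_throughput("5000kbit"));	/*  50 */
	printf("5000     -> %u\n", parse_throughput("5000"));		/*  50 */
	return 0;
}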
+ */ +bool batadv_parse_throughput(struct net_device *net_dev, char *buff, + const char *description, u32 *throughput) +{ + enum batadv_bandwidth_units bw_unit_type = BATADV_BW_UNIT_KBIT; + u64 lthroughput; + char *tmp_ptr; + int ret; + + if (strlen(buff) > 4) { + tmp_ptr = buff + strlen(buff) - 4; + + if (strncasecmp(tmp_ptr, "mbit", 4) == 0) + bw_unit_type = BATADV_BW_UNIT_MBIT; + + if (strncasecmp(tmp_ptr, "kbit", 4) == 0 || + bw_unit_type == BATADV_BW_UNIT_MBIT) + *tmp_ptr = '\0'; + } + + ret = kstrtou64(buff, 10, <hroughput); + if (ret) { + batadv_err(net_dev, + "Invalid throughput speed for %s: %s\n", + description, buff); + return false; + } + + switch (bw_unit_type) { + case BATADV_BW_UNIT_MBIT: + /* prevent overflow */ + if (U64_MAX / 10 < lthroughput) { + batadv_err(net_dev, + "Throughput speed for %s too large: %s\n", + description, buff); + return false; + } + + lthroughput *= 10; + break; + case BATADV_BW_UNIT_KBIT: + default: + lthroughput = div_u64(lthroughput, 100); + break; + } + + if (lthroughput > U32_MAX) { + batadv_err(net_dev, + "Throughput speed for %s too large: %s\n", + description, buff); + return false; + } + + *throughput = lthroughput; + + return true; +} + +/** + * batadv_parse_gw_bandwidth() - parse supplied string buffer to extract + * download and upload bandwidth information + * @net_dev: the soft interface net device + * @buff: string buffer to parse + * @down: pointer holding the returned download bandwidth information + * @up: pointer holding the returned upload bandwidth information + * + * Return: false on parse error and true otherwise. + */ +static bool batadv_parse_gw_bandwidth(struct net_device *net_dev, char *buff, + u32 *down, u32 *up) +{ + char *slash_ptr; + bool ret; + + slash_ptr = strchr(buff, '/'); + if (slash_ptr) + *slash_ptr = 0; + + ret = batadv_parse_throughput(net_dev, buff, "download gateway speed", + down); + if (!ret) + return false; + + /* we also got some upload info */ + if (slash_ptr) { + ret = batadv_parse_throughput(net_dev, slash_ptr + 1, + "upload gateway speed", up); + if (!ret) + return false; + } + + return true; +} + +/** + * batadv_gw_tvlv_container_update() - update the gw tvlv container after + * gateway setting change + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_gw_tvlv_container_update(struct batadv_priv *bat_priv) +{ + struct batadv_tvlv_gateway_data gw; + u32 down, up; + char gw_mode; + + gw_mode = atomic_read(&bat_priv->gw.mode); + + switch (gw_mode) { + case BATADV_GW_MODE_OFF: + case BATADV_GW_MODE_CLIENT: + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_GW, 1); + break; + case BATADV_GW_MODE_SERVER: + down = atomic_read(&bat_priv->gw.bandwidth_down); + up = atomic_read(&bat_priv->gw.bandwidth_up); + gw.bandwidth_down = htonl(down); + gw.bandwidth_up = htonl(up); + batadv_tvlv_container_register(bat_priv, BATADV_TVLV_GW, 1, + &gw, sizeof(gw)); + break; + } +} + +/** + * batadv_gw_bandwidth_set() - Parse and set download/upload gateway bandwidth + * from supplied string buffer + * @net_dev: netdev struct of the soft interface + * @buff: the buffer containing the user data + * @count: number of bytes in the buffer + * + * Return: 'count' on success or a negative error code in case of failure + */ +ssize_t batadv_gw_bandwidth_set(struct net_device *net_dev, char *buff, + size_t count) +{ + struct batadv_priv *bat_priv = netdev_priv(net_dev); + u32 down_curr; + u32 up_curr; + u32 down_new = 0; + u32 up_new = 0; + bool ret; + + down_curr = (unsigned 
int)atomic_read(&bat_priv->gw.bandwidth_down); + up_curr = (unsigned int)atomic_read(&bat_priv->gw.bandwidth_up); + + ret = batadv_parse_gw_bandwidth(net_dev, buff, &down_new, &up_new); + if (!ret) + return -EINVAL; + + if (!down_new) + down_new = 1; + + if (!up_new) + up_new = down_new / 5; + + if (!up_new) + up_new = 1; + + if (down_curr == down_new && up_curr == up_new) + return count; + + batadv_gw_reselect(bat_priv); + batadv_info(net_dev, + "Changing gateway bandwidth from: '%u.%u/%u.%u MBit' to: '%u.%u/%u.%u MBit'\n", + down_curr / 10, down_curr % 10, up_curr / 10, up_curr % 10, + down_new / 10, down_new % 10, up_new / 10, up_new % 10); + + atomic_set(&bat_priv->gw.bandwidth_down, down_new); + atomic_set(&bat_priv->gw.bandwidth_up, up_new); + batadv_gw_tvlv_container_update(bat_priv); + + return count; +} + +/** + * batadv_gw_tvlv_ogm_handler_v1() - process incoming gateway tvlv container + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node of the ogm + * @flags: flags indicating the tvlv state (see batadv_tvlv_handler_flags) + * @tvlv_value: tvlv buffer containing the gateway data + * @tvlv_value_len: tvlv buffer length + */ +static void batadv_gw_tvlv_ogm_handler_v1(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, u16 tvlv_value_len) +{ + struct batadv_tvlv_gateway_data gateway, *gateway_ptr; + + /* only fetch the tvlv value if the handler wasn't called via the + * CIFNOTFND flag and if there is data to fetch + */ + if (flags & BATADV_TVLV_HANDLER_OGM_CIFNOTFND || + tvlv_value_len < sizeof(gateway)) { + gateway.bandwidth_down = 0; + gateway.bandwidth_up = 0; + } else { + gateway_ptr = tvlv_value; + gateway.bandwidth_down = gateway_ptr->bandwidth_down; + gateway.bandwidth_up = gateway_ptr->bandwidth_up; + if (gateway.bandwidth_down == 0 || + gateway.bandwidth_up == 0) { + gateway.bandwidth_down = 0; + gateway.bandwidth_up = 0; + } + } + + batadv_gw_node_update(bat_priv, orig, &gateway); + + /* restart gateway selection */ + if (gateway.bandwidth_down != 0 && + atomic_read(&bat_priv->gw.mode) == BATADV_GW_MODE_CLIENT) + batadv_gw_check_election(bat_priv, orig); +} + +/** + * batadv_gw_init() - initialise the gateway handling internals + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_gw_init(struct batadv_priv *bat_priv) +{ + if (bat_priv->algo_ops->gw.init_sel_class) + bat_priv->algo_ops->gw.init_sel_class(bat_priv); + else + atomic_set(&bat_priv->gw.sel_class, 1); + + batadv_tvlv_handler_register(bat_priv, batadv_gw_tvlv_ogm_handler_v1, + NULL, BATADV_TVLV_GW, 1, + BATADV_TVLV_HANDLER_OGM_CIFNOTFND); +} + +/** + * batadv_gw_free() - free the gateway handling internals + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_gw_free(struct batadv_priv *bat_priv) +{ + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_GW, 1); + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_GW, 1); +} diff --git a/net/batman-adv/gateway_common.h b/net/batman-adv/gateway_common.h new file mode 100644 index 000000000..87c37f907 --- /dev/null +++ b/net/batman-adv/gateway_common.h @@ -0,0 +1,38 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_GATEWAY_COMMON_H_ +#define _NET_BATMAN_ADV_GATEWAY_COMMON_H_ + +#include "main.h" + +#include +#include + +/** + * enum batadv_bandwidth_units - bandwidth unit types + */ +enum batadv_bandwidth_units { + /** @BATADV_BW_UNIT_KBIT: unit type kbit */ + BATADV_BW_UNIT_KBIT, + + /** @BATADV_BW_UNIT_MBIT: unit type mbit */ + BATADV_BW_UNIT_MBIT, +}; + +#define BATADV_GW_MODE_OFF_NAME "off" +#define BATADV_GW_MODE_CLIENT_NAME "client" +#define BATADV_GW_MODE_SERVER_NAME "server" + +ssize_t batadv_gw_bandwidth_set(struct net_device *net_dev, char *buff, + size_t count); +void batadv_gw_tvlv_container_update(struct batadv_priv *bat_priv); +void batadv_gw_init(struct batadv_priv *bat_priv); +void batadv_gw_free(struct batadv_priv *bat_priv); +bool batadv_parse_throughput(struct net_device *net_dev, char *buff, + const char *description, u32 *throughput); + +#endif /* _NET_BATMAN_ADV_GATEWAY_COMMON_H_ */ diff --git a/net/batman-adv/hard-interface.c b/net/batman-adv/hard-interface.c new file mode 100644 index 000000000..24c9c0c3f --- /dev/null +++ b/net/batman-adv/hard-interface.c @@ -0,0 +1,1022 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "hard-interface.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_v.h" +#include "bridge_loop_avoidance.h" +#include "distributed-arp-table.h" +#include "gateway_client.h" +#include "log.h" +#include "originator.h" +#include "send.h" +#include "soft-interface.h" +#include "translation-table.h" + +/** + * batadv_hardif_release() - release hard interface from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the hard interface + */ +void batadv_hardif_release(struct kref *ref) +{ + struct batadv_hard_iface *hard_iface; + + hard_iface = container_of(ref, struct batadv_hard_iface, refcount); + dev_put(hard_iface->net_dev); + + kfree_rcu(hard_iface, rcu); +} + +/** + * batadv_hardif_get_by_netdev() - Get hard interface object of a net_device + * @net_dev: net_device to search for + * + * Return: batadv_hard_iface of net_dev (with increased refcnt), NULL on errors + */ +struct batadv_hard_iface * +batadv_hardif_get_by_netdev(const struct net_device *net_dev) +{ + struct batadv_hard_iface *hard_iface; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->net_dev == net_dev && + kref_get_unless_zero(&hard_iface->refcount)) + goto out; + } + + hard_iface = NULL; + +out: + rcu_read_unlock(); + return hard_iface; +} + +/** + * batadv_getlink_net() - return link net namespace (of use fallback) + * @netdev: net_device to check + * @fallback_net: return in case get_link_net is not available for @netdev + * + * Return: result of rtnl_link_ops->get_link_net or @fallback_net + */ +static struct net *batadv_getlink_net(const struct net_device *netdev, + struct net *fallback_net) +{ + if (!netdev->rtnl_link_ops) + return fallback_net; + + if (!netdev->rtnl_link_ops->get_link_net) + return fallback_net; + + return netdev->rtnl_link_ops->get_link_net(netdev); +} + +/** + * batadv_mutual_parents() - check if two devices are each others parent + * @dev1: 1st net dev + * @net1: 1st devices netns + * @dev2: 2nd net dev + * @net2: 2nd devices 
netns + * + * veth devices come in pairs and each is the parent of the other! + * + * Return: true if the devices are each others parent, otherwise false + */ +static bool batadv_mutual_parents(const struct net_device *dev1, + struct net *net1, + const struct net_device *dev2, + struct net *net2) +{ + int dev1_parent_iflink = dev_get_iflink(dev1); + int dev2_parent_iflink = dev_get_iflink(dev2); + const struct net *dev1_parent_net; + const struct net *dev2_parent_net; + + dev1_parent_net = batadv_getlink_net(dev1, net1); + dev2_parent_net = batadv_getlink_net(dev2, net2); + + if (!dev1_parent_iflink || !dev2_parent_iflink) + return false; + + return (dev1_parent_iflink == dev2->ifindex) && + (dev2_parent_iflink == dev1->ifindex) && + net_eq(dev1_parent_net, net2) && + net_eq(dev2_parent_net, net1); +} + +/** + * batadv_is_on_batman_iface() - check if a device is a batman iface descendant + * @net_dev: the device to check + * + * If the user creates any virtual device on top of a batman-adv interface, it + * is important to prevent this new interface from being used to create a new + * mesh network (this behaviour would lead to a batman-over-batman + * configuration). This function recursively checks all the fathers of the + * device passed as argument looking for a batman-adv soft interface. + * + * Return: true if the device is descendant of a batman-adv mesh interface (or + * if it is a batman-adv interface itself), false otherwise + */ +static bool batadv_is_on_batman_iface(const struct net_device *net_dev) +{ + struct net *net = dev_net(net_dev); + struct net_device *parent_dev; + struct net *parent_net; + int iflink; + bool ret; + + /* check if this is a batman-adv mesh interface */ + if (batadv_softif_is_valid(net_dev)) + return true; + + iflink = dev_get_iflink(net_dev); + if (iflink == 0) + return false; + + parent_net = batadv_getlink_net(net_dev, net); + + /* iflink to itself, most likely physical device */ + if (net == parent_net && iflink == net_dev->ifindex) + return false; + + /* recurse over the parent device */ + parent_dev = __dev_get_by_index((struct net *)parent_net, iflink); + if (!parent_dev) { + pr_warn("Cannot find parent device. Skipping batadv-on-batadv check for %s\n", + net_dev->name); + return false; + } + + if (batadv_mutual_parents(net_dev, net, parent_dev, parent_net)) + return false; + + ret = batadv_is_on_batman_iface(parent_dev); + + return ret; +} + +static bool batadv_is_valid_iface(const struct net_device *net_dev) +{ + if (net_dev->flags & IFF_LOOPBACK) + return false; + + if (net_dev->type != ARPHRD_ETHER) + return false; + + if (net_dev->addr_len != ETH_ALEN) + return false; + + /* no batman over batman */ + if (batadv_is_on_batman_iface(net_dev)) + return false; + + return true; +} + +/** + * batadv_get_real_netdevice() - check if the given netdev struct is a virtual + * interface on top of another 'real' interface + * @netdev: the device to check + * + * Callers must hold the rtnl semaphore. You may want batadv_get_real_netdev() + * instead of this. + * + * Return: the 'real' net device or the original net device and NULL in case + * of an error. 
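Reduced to plain integers, and ignoring the namespace comparison, the veth pair check in batadv_mutual_parents() above boils down to a cross check of iflink against ifindex:

#include <stdbool.h>
#include <stdio.h>

struct dev { int ifindex; int iflink; };

/* two devices are mutual parents when each one's iflink points at the
 * other's ifindex
 */
static bool mutual_parents(const struct dev *a, const struct dev *b)
{
	if (!a->iflink || !b->iflink)
		return false;
	return a->iflink == b->ifindex && b->iflink == a->ifindex;
}

int main(void)
{
	struct dev veth0 = { .ifindex = 4, .iflink = 5 };
	struct dev veth1 = { .ifindex = 5, .iflink = 4 };

	printf("veth pair is mutual parents: %d\n",
	       mutual_parents(&veth0, &veth1));
	return 0;
}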
+ */ +static struct net_device *batadv_get_real_netdevice(struct net_device *netdev) +{ + struct batadv_hard_iface *hard_iface = NULL; + struct net_device *real_netdev = NULL; + struct net *real_net; + struct net *net; + int iflink; + + ASSERT_RTNL(); + + if (!netdev) + return NULL; + + iflink = dev_get_iflink(netdev); + if (iflink == 0) { + dev_hold(netdev); + return netdev; + } + + hard_iface = batadv_hardif_get_by_netdev(netdev); + if (!hard_iface || !hard_iface->soft_iface) + goto out; + + net = dev_net(hard_iface->soft_iface); + real_net = batadv_getlink_net(netdev, net); + + /* iflink to itself, most likely physical device */ + if (net == real_net && netdev->ifindex == iflink) { + real_netdev = netdev; + dev_hold(real_netdev); + goto out; + } + + real_netdev = dev_get_by_index(real_net, iflink); + +out: + batadv_hardif_put(hard_iface); + return real_netdev; +} + +/** + * batadv_get_real_netdev() - check if the given net_device struct is a virtual + * interface on top of another 'real' interface + * @net_device: the device to check + * + * Return: the 'real' net device or the original net device and NULL in case + * of an error. + */ +struct net_device *batadv_get_real_netdev(struct net_device *net_device) +{ + struct net_device *real_netdev; + + rtnl_lock(); + real_netdev = batadv_get_real_netdevice(net_device); + rtnl_unlock(); + + return real_netdev; +} + +/** + * batadv_is_wext_netdev() - check if the given net_device struct is a + * wext wifi interface + * @net_device: the device to check + * + * Return: true if the net device is a wext wireless device, false + * otherwise. + */ +static bool batadv_is_wext_netdev(struct net_device *net_device) +{ + if (!net_device) + return false; + +#ifdef CONFIG_WIRELESS_EXT + /* pre-cfg80211 drivers have to implement WEXT, so it is possible to + * check for wireless_handlers != NULL + */ + if (net_device->wireless_handlers) + return true; +#endif + + return false; +} + +/** + * batadv_is_cfg80211_netdev() - check if the given net_device struct is a + * cfg80211 wifi interface + * @net_device: the device to check + * + * Return: true if the net device is a cfg80211 wireless device, false + * otherwise. 
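The wext and cfg80211 detection helpers feed a small bitmask of wifi flags, distinguishing whether the wireless property was found on the interface itself (direct) or on the real device underneath it (indirect). An illustrative model of how those bits combine; the bit positions below are assumptions, the real flag values are defined elsewhere in the batman-adv headers:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define WIFI_WEXT_DIRECT	(1U << 0)	/* assumed bit layout */
#define WIFI_CFG80211_DIRECT	(1U << 1)
#define WIFI_WEXT_INDIRECT	(1U << 2)
#define WIFI_CFG80211_INDIRECT	(1U << 3)

static bool is_cfg80211(uint32_t wifi_flags)
{
	return wifi_flags & (WIFI_CFG80211_DIRECT | WIFI_CFG80211_INDIRECT);
}

int main(void)
{
	/* e.g. a VLAN on top of a cfg80211 wlan device: only the indirect
	 * flag is set, but it still counts as a cfg80211 hardif
	 */
	uint32_t wifi_flags = WIFI_CFG80211_INDIRECT;

	printf("is wifi:     %d\n", wifi_flags != 0);
	printf("is cfg80211: %d\n", is_cfg80211(wifi_flags));
	return 0;
}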
+ */ +static bool batadv_is_cfg80211_netdev(struct net_device *net_device) +{ + if (!net_device) + return false; + +#if IS_ENABLED(CONFIG_CFG80211) + /* cfg80211 drivers have to set ieee80211_ptr */ + if (net_device->ieee80211_ptr) + return true; +#endif + + return false; +} + +/** + * batadv_wifi_flags_evaluate() - calculate wifi flags for net_device + * @net_device: the device to check + * + * Return: batadv_hard_iface_wifi_flags flags of the device + */ +static u32 batadv_wifi_flags_evaluate(struct net_device *net_device) +{ + u32 wifi_flags = 0; + struct net_device *real_netdev; + + if (batadv_is_wext_netdev(net_device)) + wifi_flags |= BATADV_HARDIF_WIFI_WEXT_DIRECT; + + if (batadv_is_cfg80211_netdev(net_device)) + wifi_flags |= BATADV_HARDIF_WIFI_CFG80211_DIRECT; + + real_netdev = batadv_get_real_netdevice(net_device); + if (!real_netdev) + return wifi_flags; + + if (real_netdev == net_device) + goto out; + + if (batadv_is_wext_netdev(real_netdev)) + wifi_flags |= BATADV_HARDIF_WIFI_WEXT_INDIRECT; + + if (batadv_is_cfg80211_netdev(real_netdev)) + wifi_flags |= BATADV_HARDIF_WIFI_CFG80211_INDIRECT; + +out: + dev_put(real_netdev); + return wifi_flags; +} + +/** + * batadv_is_cfg80211_hardif() - check if the given hardif is a cfg80211 wifi + * interface + * @hard_iface: the device to check + * + * Return: true if the net device is a cfg80211 wireless device, false + * otherwise. + */ +bool batadv_is_cfg80211_hardif(struct batadv_hard_iface *hard_iface) +{ + u32 allowed_flags = 0; + + allowed_flags |= BATADV_HARDIF_WIFI_CFG80211_DIRECT; + allowed_flags |= BATADV_HARDIF_WIFI_CFG80211_INDIRECT; + + return !!(hard_iface->wifi_flags & allowed_flags); +} + +/** + * batadv_is_wifi_hardif() - check if the given hardif is a wifi interface + * @hard_iface: the device to check + * + * Return: true if the net device is a 802.11 wireless device, false otherwise. + */ +bool batadv_is_wifi_hardif(struct batadv_hard_iface *hard_iface) +{ + if (!hard_iface) + return false; + + return hard_iface->wifi_flags != 0; +} + +/** + * batadv_hardif_no_broadcast() - check whether (re)broadcast is necessary + * @if_outgoing: the outgoing interface checked and considered for (re)broadcast + * @orig_addr: the originator of this packet + * @orig_neigh: originator address of the forwarder we just got the packet from + * (NULL if we originated) + * + * Checks whether a packet needs to be (re)broadcasted on the given interface. 
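The (re)broadcast check enumerated next reduces to a small decision table over the neighbor count on the outgoing interface. An equivalent stand-alone restatement, with names shortened for illustration:

#include <stdio.h>

enum bcast_decision {
	BCAST_OK,		/* several neighbors: rebroadcast        */
	BCAST_NORECIPIENT,	/* no neighbor on this interface         */
	BCAST_DUPFWD,		/* single neighbor is the forwarder      */
	BCAST_DUPORIG,		/* single neighbor is the originator     */
};

/* neigh_count is the number of neighbors on the outgoing interface; the
 * two flags say whether the only neighbor equals the originator or the
 * node the packet was received from
 */
static enum bcast_decision no_broadcast(unsigned int neigh_count,
					int only_neigh_is_orig,
					int only_neigh_is_forwarder)
{
	if (neigh_count == 0)
		return BCAST_NORECIPIENT;
	if (neigh_count > 1)
		return BCAST_OK;
	if (only_neigh_is_orig)
		return BCAST_DUPORIG;
	if (only_neigh_is_forwarder)
		return BCAST_DUPFWD;
	return BCAST_OK;
}

int main(void)
{
	printf("no neighbor:           %d\n", no_broadcast(0, 0, 0));
	printf("only the originator:   %d\n", no_broadcast(1, 1, 0));
	printf("two or more neighbors: %d\n", no_broadcast(2, 0, 0));
	return 0;
}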
+ * + * Return: + * BATADV_HARDIF_BCAST_NORECIPIENT: No neighbor on interface + * BATADV_HARDIF_BCAST_DUPFWD: Just one neighbor, but it is the forwarder + * BATADV_HARDIF_BCAST_DUPORIG: Just one neighbor, but it is the originator + * BATADV_HARDIF_BCAST_OK: Several neighbors, must broadcast + */ +int batadv_hardif_no_broadcast(struct batadv_hard_iface *if_outgoing, + u8 *orig_addr, u8 *orig_neigh) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + struct hlist_node *first; + int ret = BATADV_HARDIF_BCAST_OK; + + rcu_read_lock(); + + /* 0 neighbors -> no (re)broadcast */ + first = rcu_dereference(hlist_first_rcu(&if_outgoing->neigh_list)); + if (!first) { + ret = BATADV_HARDIF_BCAST_NORECIPIENT; + goto out; + } + + /* >1 neighbors -> (re)broadcast */ + if (rcu_dereference(hlist_next_rcu(first))) + goto out; + + hardif_neigh = hlist_entry(first, struct batadv_hardif_neigh_node, + list); + + /* 1 neighbor, is the originator -> no rebroadcast */ + if (orig_addr && batadv_compare_eth(hardif_neigh->orig, orig_addr)) { + ret = BATADV_HARDIF_BCAST_DUPORIG; + /* 1 neighbor, is the one we received from -> no rebroadcast */ + } else if (orig_neigh && + batadv_compare_eth(hardif_neigh->orig, orig_neigh)) { + ret = BATADV_HARDIF_BCAST_DUPFWD; + } + +out: + rcu_read_unlock(); + return ret; +} + +static struct batadv_hard_iface * +batadv_hardif_get_active(const struct net_device *soft_iface) +{ + struct batadv_hard_iface *hard_iface; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != soft_iface) + continue; + + if (hard_iface->if_status == BATADV_IF_ACTIVE && + kref_get_unless_zero(&hard_iface->refcount)) + goto out; + } + + hard_iface = NULL; + +out: + rcu_read_unlock(); + return hard_iface; +} + +static void batadv_primary_if_update_addr(struct batadv_priv *bat_priv, + struct batadv_hard_iface *oldif) +{ + struct batadv_hard_iface *primary_if; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + batadv_dat_init_own_addr(bat_priv, primary_if); + batadv_bla_update_orig_address(bat_priv, primary_if, oldif); +out: + batadv_hardif_put(primary_if); +} + +static void batadv_primary_if_select(struct batadv_priv *bat_priv, + struct batadv_hard_iface *new_hard_iface) +{ + struct batadv_hard_iface *curr_hard_iface; + + ASSERT_RTNL(); + + if (new_hard_iface) + kref_get(&new_hard_iface->refcount); + + curr_hard_iface = rcu_replace_pointer(bat_priv->primary_if, + new_hard_iface, 1); + + if (!new_hard_iface) + goto out; + + bat_priv->algo_ops->iface.primary_set(new_hard_iface); + batadv_primary_if_update_addr(bat_priv, curr_hard_iface); + +out: + batadv_hardif_put(curr_hard_iface); +} + +static bool +batadv_hardif_is_iface_up(const struct batadv_hard_iface *hard_iface) +{ + if (hard_iface->net_dev->flags & IFF_UP) + return true; + + return false; +} + +static void batadv_check_known_mac_addr(const struct net_device *net_dev) +{ + const struct batadv_hard_iface *hard_iface; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE && + hard_iface->if_status != BATADV_IF_TO_BE_ACTIVATED) + continue; + + if (hard_iface->net_dev == net_dev) + continue; + + if (!batadv_compare_eth(hard_iface->net_dev->dev_addr, + net_dev->dev_addr)) + continue; + + pr_warn("The newly added mac address (%pM) already exists on: %s\n", + net_dev->dev_addr, hard_iface->net_dev->name); + pr_warn("It is strongly recommended to keep mac addresses unique to 
avoid problems!\n"); + } + rcu_read_unlock(); +} + +/** + * batadv_hardif_recalc_extra_skbroom() - Recalculate skbuff extra head/tailroom + * @soft_iface: netdev struct of the mesh interface + */ +static void batadv_hardif_recalc_extra_skbroom(struct net_device *soft_iface) +{ + const struct batadv_hard_iface *hard_iface; + unsigned short lower_header_len = ETH_HLEN; + unsigned short lower_headroom = 0; + unsigned short lower_tailroom = 0; + unsigned short needed_headroom; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status == BATADV_IF_NOT_IN_USE) + continue; + + if (hard_iface->soft_iface != soft_iface) + continue; + + lower_header_len = max_t(unsigned short, lower_header_len, + hard_iface->net_dev->hard_header_len); + + lower_headroom = max_t(unsigned short, lower_headroom, + hard_iface->net_dev->needed_headroom); + + lower_tailroom = max_t(unsigned short, lower_tailroom, + hard_iface->net_dev->needed_tailroom); + } + rcu_read_unlock(); + + needed_headroom = lower_headroom + (lower_header_len - ETH_HLEN); + needed_headroom += batadv_max_header_len(); + + /* fragmentation headers don't strip the unicast/... header */ + needed_headroom += sizeof(struct batadv_frag_packet); + + soft_iface->needed_headroom = needed_headroom; + soft_iface->needed_tailroom = lower_tailroom; +} + +/** + * batadv_hardif_min_mtu() - Calculate maximum MTU for soft interface + * @soft_iface: netdev struct of the soft interface + * + * Return: MTU for the soft-interface (limited by the minimal MTU of all active + * slave interfaces) + */ +int batadv_hardif_min_mtu(struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + const struct batadv_hard_iface *hard_iface; + int min_mtu = INT_MAX; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE && + hard_iface->if_status != BATADV_IF_TO_BE_ACTIVATED) + continue; + + if (hard_iface->soft_iface != soft_iface) + continue; + + min_mtu = min_t(int, hard_iface->net_dev->mtu, min_mtu); + } + rcu_read_unlock(); + + if (atomic_read(&bat_priv->fragmentation) == 0) + goto out; + + /* with fragmentation enabled the maximum size of internally generated + * packets such as translation table exchanges or tvlv containers, etc + * has to be calculated + */ + min_mtu = min_t(int, min_mtu, BATADV_FRAG_MAX_FRAG_SIZE); + min_mtu -= sizeof(struct batadv_frag_packet); + min_mtu *= BATADV_FRAG_MAX_FRAGMENTS; + +out: + /* report to the other components the maximum amount of bytes that + * batman-adv can send over the wire (without considering the payload + * overhead). For example, this value is used by TT to compute the + * maximum local table size + */ + atomic_set(&bat_priv->packet_size_max, min_mtu); + + /* the real soft-interface MTU is computed by removing the payload + * overhead from the maximum amount of bytes that was just computed. 
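With fragmentation enabled, the limit is no longer the slave MTU alone but up to 16 fragments worth of payload, and the resulting soft interface MTU ends up capped at ETH_DATA_LEN. A worked example of the computation; 1400, 16, 20 and 32 bytes are assumed values for BATADV_FRAG_MAX_FRAG_SIZE, BATADV_FRAG_MAX_FRAGMENTS, the fragment header and batadv_max_header_len():

#include <stdio.h>

int main(void)
{
	int slave_mtu = 1500;
	int fragmentation = 1;
	int min_mtu = slave_mtu;

	if (fragmentation) {
		if (min_mtu > 1400)	/* assumed BATADV_FRAG_MAX_FRAG_SIZE */
			min_mtu = 1400;
		min_mtu -= 20;		/* room for the fragment header      */
		min_mtu *= 16;		/* up to 16 fragments per packet     */
	}

	min_mtu -= 32;			/* assumed batadv_max_header_len()   */
	if (min_mtu > 1500)
		min_mtu = 1500;		/* never above ETH_DATA_LEN          */

	printf("soft interface MTU: %d\n", min_mtu);	/* 1500 */
	return 0;
}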
+ * + * However batman-adv does not support MTUs bigger than ETH_DATA_LEN + */ + return min_t(int, min_mtu - batadv_max_header_len(), ETH_DATA_LEN); +} + +/** + * batadv_update_min_mtu() - Adjusts the MTU if a new interface with a smaller + * MTU appeared + * @soft_iface: netdev struct of the soft interface + */ +void batadv_update_min_mtu(struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + int limit_mtu; + int mtu; + + mtu = batadv_hardif_min_mtu(soft_iface); + + if (bat_priv->mtu_set_by_user) + limit_mtu = bat_priv->mtu_set_by_user; + else + limit_mtu = ETH_DATA_LEN; + + mtu = min(mtu, limit_mtu); + dev_set_mtu(soft_iface, mtu); + + /* Check if the local translate table should be cleaned up to match a + * new (and smaller) MTU. + */ + batadv_tt_local_resize_to_mtu(soft_iface); +} + +static void +batadv_hardif_activate_interface(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + + if (hard_iface->if_status != BATADV_IF_INACTIVE) + goto out; + + bat_priv = netdev_priv(hard_iface->soft_iface); + + bat_priv->algo_ops->iface.update_mac(hard_iface); + hard_iface->if_status = BATADV_IF_TO_BE_ACTIVATED; + + /* the first active interface becomes our primary interface or + * the next active interface after the old primary interface was removed + */ + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + batadv_primary_if_select(bat_priv, hard_iface); + + batadv_info(hard_iface->soft_iface, "Interface activated: %s\n", + hard_iface->net_dev->name); + + batadv_update_min_mtu(hard_iface->soft_iface); + + if (bat_priv->algo_ops->iface.activate) + bat_priv->algo_ops->iface.activate(hard_iface); + +out: + batadv_hardif_put(primary_if); +} + +static void +batadv_hardif_deactivate_interface(struct batadv_hard_iface *hard_iface) +{ + if (hard_iface->if_status != BATADV_IF_ACTIVE && + hard_iface->if_status != BATADV_IF_TO_BE_ACTIVATED) + return; + + hard_iface->if_status = BATADV_IF_INACTIVE; + + batadv_info(hard_iface->soft_iface, "Interface deactivated: %s\n", + hard_iface->net_dev->name); + + batadv_update_min_mtu(hard_iface->soft_iface); +} + +/** + * batadv_hardif_enable_interface() - Enslave hard interface to soft interface + * @hard_iface: hard interface to add to soft interface + * @soft_iface: netdev struct of the mesh interface + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_hardif_enable_interface(struct batadv_hard_iface *hard_iface, + struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv; + __be16 ethertype = htons(ETH_P_BATMAN); + int max_header_len = batadv_max_header_len(); + int ret; + + if (hard_iface->net_dev->mtu < ETH_MIN_MTU + max_header_len) + return -EINVAL; + + if (hard_iface->if_status != BATADV_IF_NOT_IN_USE) + goto out; + + kref_get(&hard_iface->refcount); + + dev_hold(soft_iface); + hard_iface->soft_iface = soft_iface; + bat_priv = netdev_priv(hard_iface->soft_iface); + + ret = netdev_master_upper_dev_link(hard_iface->net_dev, + soft_iface, NULL, NULL, NULL); + if (ret) + goto err_dev; + + ret = bat_priv->algo_ops->iface.enable(hard_iface); + if (ret < 0) + goto err_upper; + + hard_iface->if_status = BATADV_IF_INACTIVE; + + kref_get(&hard_iface->refcount); + hard_iface->batman_adv_ptype.type = ethertype; + hard_iface->batman_adv_ptype.func = batadv_batman_skb_recv; + hard_iface->batman_adv_ptype.dev = hard_iface->net_dev; + dev_add_pack(&hard_iface->batman_adv_ptype); + + 
batadv_info(hard_iface->soft_iface, "Adding interface: %s\n", + hard_iface->net_dev->name); + + if (atomic_read(&bat_priv->fragmentation) && + hard_iface->net_dev->mtu < ETH_DATA_LEN + max_header_len) + batadv_info(hard_iface->soft_iface, + "The MTU of interface %s is too small (%i) to handle the transport of batman-adv packets. Packets going over this interface will be fragmented on layer2 which could impact the performance. Setting the MTU to %i would solve the problem.\n", + hard_iface->net_dev->name, hard_iface->net_dev->mtu, + ETH_DATA_LEN + max_header_len); + + if (!atomic_read(&bat_priv->fragmentation) && + hard_iface->net_dev->mtu < ETH_DATA_LEN + max_header_len) + batadv_info(hard_iface->soft_iface, + "The MTU of interface %s is too small (%i) to handle the transport of batman-adv packets. If you experience problems getting traffic through try increasing the MTU to %i.\n", + hard_iface->net_dev->name, hard_iface->net_dev->mtu, + ETH_DATA_LEN + max_header_len); + + if (batadv_hardif_is_iface_up(hard_iface)) + batadv_hardif_activate_interface(hard_iface); + else + batadv_err(hard_iface->soft_iface, + "Not using interface %s (retrying later): interface not active\n", + hard_iface->net_dev->name); + + batadv_hardif_recalc_extra_skbroom(soft_iface); + + if (bat_priv->algo_ops->iface.enabled) + bat_priv->algo_ops->iface.enabled(hard_iface); + +out: + return 0; + +err_upper: + netdev_upper_dev_unlink(hard_iface->net_dev, soft_iface); +err_dev: + hard_iface->soft_iface = NULL; + dev_put(soft_iface); + batadv_hardif_put(hard_iface); + return ret; +} + +/** + * batadv_hardif_cnt() - get number of interfaces enslaved to soft interface + * @soft_iface: soft interface to check + * + * This function is only using RCU for locking - the result can therefore be + * off when another function is modifying the list at the same time. The + * caller can use the rtnl_lock to make sure that the count is accurate. 
+ * + * Return: number of connected/enslaved hard interfaces + */ +static size_t batadv_hardif_cnt(const struct net_device *soft_iface) +{ + struct batadv_hard_iface *hard_iface; + size_t count = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != soft_iface) + continue; + + count++; + } + rcu_read_unlock(); + + return count; +} + +/** + * batadv_hardif_disable_interface() - Remove hard interface from soft interface + * @hard_iface: hard interface to be removed + */ +void batadv_hardif_disable_interface(struct batadv_hard_iface *hard_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + struct batadv_hard_iface *primary_if = NULL; + + batadv_hardif_deactivate_interface(hard_iface); + + if (hard_iface->if_status != BATADV_IF_INACTIVE) + goto out; + + batadv_info(hard_iface->soft_iface, "Removing interface: %s\n", + hard_iface->net_dev->name); + dev_remove_pack(&hard_iface->batman_adv_ptype); + batadv_hardif_put(hard_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (hard_iface == primary_if) { + struct batadv_hard_iface *new_if; + + new_if = batadv_hardif_get_active(hard_iface->soft_iface); + batadv_primary_if_select(bat_priv, new_if); + + batadv_hardif_put(new_if); + } + + bat_priv->algo_ops->iface.disable(hard_iface); + hard_iface->if_status = BATADV_IF_NOT_IN_USE; + + /* delete all references to this hard_iface */ + batadv_purge_orig_ref(bat_priv); + batadv_purge_outstanding_packets(bat_priv, hard_iface); + dev_put(hard_iface->soft_iface); + + netdev_upper_dev_unlink(hard_iface->net_dev, hard_iface->soft_iface); + batadv_hardif_recalc_extra_skbroom(hard_iface->soft_iface); + + /* nobody uses this interface anymore */ + if (batadv_hardif_cnt(hard_iface->soft_iface) <= 1) + batadv_gw_check_client_stop(bat_priv); + + hard_iface->soft_iface = NULL; + batadv_hardif_put(hard_iface); + +out: + batadv_hardif_put(primary_if); +} + +static struct batadv_hard_iface * +batadv_hardif_add_interface(struct net_device *net_dev) +{ + struct batadv_hard_iface *hard_iface; + + ASSERT_RTNL(); + + if (!batadv_is_valid_iface(net_dev)) + goto out; + + dev_hold(net_dev); + + hard_iface = kzalloc(sizeof(*hard_iface), GFP_ATOMIC); + if (!hard_iface) + goto release_dev; + + hard_iface->net_dev = net_dev; + hard_iface->soft_iface = NULL; + hard_iface->if_status = BATADV_IF_NOT_IN_USE; + + INIT_LIST_HEAD(&hard_iface->list); + INIT_HLIST_HEAD(&hard_iface->neigh_list); + + mutex_init(&hard_iface->bat_iv.ogm_buff_mutex); + spin_lock_init(&hard_iface->neigh_list_lock); + kref_init(&hard_iface->refcount); + + hard_iface->num_bcasts = BATADV_NUM_BCASTS_DEFAULT; + hard_iface->wifi_flags = batadv_wifi_flags_evaluate(net_dev); + if (batadv_is_wifi_hardif(hard_iface)) + hard_iface->num_bcasts = BATADV_NUM_BCASTS_WIRELESS; + + atomic_set(&hard_iface->hop_penalty, 0); + + batadv_v_hardif_init(hard_iface); + + batadv_check_known_mac_addr(hard_iface->net_dev); + kref_get(&hard_iface->refcount); + list_add_tail_rcu(&hard_iface->list, &batadv_hardif_list); + batadv_hardif_generation++; + + return hard_iface; + +release_dev: + dev_put(net_dev); +out: + return NULL; +} + +static void batadv_hardif_remove_interface(struct batadv_hard_iface *hard_iface) +{ + ASSERT_RTNL(); + + /* first deactivate interface */ + if (hard_iface->if_status != BATADV_IF_NOT_IN_USE) + batadv_hardif_disable_interface(hard_iface); + + if (hard_iface->if_status != BATADV_IF_NOT_IN_USE) + return; + + hard_iface->if_status = 
BATADV_IF_TO_BE_REMOVED; + batadv_hardif_put(hard_iface); +} + +/** + * batadv_hard_if_event_softif() - Handle events for soft interfaces + * @event: NETDEV_* event to handle + * @net_dev: net_device which generated an event + * + * Return: NOTIFY_* result + */ +static int batadv_hard_if_event_softif(unsigned long event, + struct net_device *net_dev) +{ + struct batadv_priv *bat_priv; + + switch (event) { + case NETDEV_REGISTER: + bat_priv = netdev_priv(net_dev); + batadv_softif_create_vlan(bat_priv, BATADV_NO_FLAGS); + break; + } + + return NOTIFY_DONE; +} + +static int batadv_hard_if_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *net_dev = netdev_notifier_info_to_dev(ptr); + struct batadv_hard_iface *hard_iface; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_priv *bat_priv; + + if (batadv_softif_is_valid(net_dev)) + return batadv_hard_if_event_softif(event, net_dev); + + hard_iface = batadv_hardif_get_by_netdev(net_dev); + if (!hard_iface && (event == NETDEV_REGISTER || + event == NETDEV_POST_TYPE_CHANGE)) + hard_iface = batadv_hardif_add_interface(net_dev); + + if (!hard_iface) + goto out; + + switch (event) { + case NETDEV_UP: + batadv_hardif_activate_interface(hard_iface); + break; + case NETDEV_GOING_DOWN: + case NETDEV_DOWN: + batadv_hardif_deactivate_interface(hard_iface); + break; + case NETDEV_UNREGISTER: + case NETDEV_PRE_TYPE_CHANGE: + list_del_rcu(&hard_iface->list); + batadv_hardif_generation++; + + batadv_hardif_remove_interface(hard_iface); + break; + case NETDEV_CHANGEMTU: + if (hard_iface->soft_iface) + batadv_update_min_mtu(hard_iface->soft_iface); + break; + case NETDEV_CHANGEADDR: + if (hard_iface->if_status == BATADV_IF_NOT_IN_USE) + goto hardif_put; + + batadv_check_known_mac_addr(hard_iface->net_dev); + + bat_priv = netdev_priv(hard_iface->soft_iface); + bat_priv->algo_ops->iface.update_mac(hard_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto hardif_put; + + if (hard_iface == primary_if) + batadv_primary_if_update_addr(bat_priv, NULL); + break; + case NETDEV_CHANGEUPPER: + hard_iface->wifi_flags = batadv_wifi_flags_evaluate(net_dev); + if (batadv_is_wifi_hardif(hard_iface)) + hard_iface->num_bcasts = BATADV_NUM_BCASTS_WIRELESS; + break; + default: + break; + } + +hardif_put: + batadv_hardif_put(hard_iface); +out: + batadv_hardif_put(primary_if); + return NOTIFY_DONE; +} + +struct notifier_block batadv_hard_if_notifier = { + .notifier_call = batadv_hard_if_event, +}; diff --git a/net/batman-adv/hard-interface.h b/net/batman-adv/hard-interface.h new file mode 100644 index 000000000..64f660dbb --- /dev/null +++ b/net/batman-adv/hard-interface.h @@ -0,0 +1,122 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_HARD_INTERFACE_H_ +#define _NET_BATMAN_ADV_HARD_INTERFACE_H_ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include + +/** + * enum batadv_hard_if_state - State of a hard interface + */ +enum batadv_hard_if_state { + /** + * @BATADV_IF_NOT_IN_USE: interface is not used as slave interface of a + * batman-adv soft interface + */ + BATADV_IF_NOT_IN_USE, + + /** + * @BATADV_IF_TO_BE_REMOVED: interface will be removed from soft + * interface + */ + BATADV_IF_TO_BE_REMOVED, + + /** @BATADV_IF_INACTIVE: interface is deactivated */ + BATADV_IF_INACTIVE, + + /** @BATADV_IF_ACTIVE: interface is used */ + BATADV_IF_ACTIVE, + + /** @BATADV_IF_TO_BE_ACTIVATED: interface is getting activated */ + BATADV_IF_TO_BE_ACTIVATED, +}; + +/** + * enum batadv_hard_if_bcast - broadcast avoidance options + */ +enum batadv_hard_if_bcast { + /** @BATADV_HARDIF_BCAST_OK: Do broadcast on according hard interface */ + BATADV_HARDIF_BCAST_OK = 0, + + /** + * @BATADV_HARDIF_BCAST_NORECIPIENT: Broadcast not needed, there is no + * recipient + */ + BATADV_HARDIF_BCAST_NORECIPIENT, + + /** + * @BATADV_HARDIF_BCAST_DUPFWD: There is just the neighbor we got it + * from + */ + BATADV_HARDIF_BCAST_DUPFWD, + + /** @BATADV_HARDIF_BCAST_DUPORIG: There is just the originator */ + BATADV_HARDIF_BCAST_DUPORIG, +}; + +extern struct notifier_block batadv_hard_if_notifier; + +struct net_device *batadv_get_real_netdev(struct net_device *net_device); +bool batadv_is_cfg80211_hardif(struct batadv_hard_iface *hard_iface); +bool batadv_is_wifi_hardif(struct batadv_hard_iface *hard_iface); +struct batadv_hard_iface* +batadv_hardif_get_by_netdev(const struct net_device *net_dev); +int batadv_hardif_enable_interface(struct batadv_hard_iface *hard_iface, + struct net_device *soft_iface); +void batadv_hardif_disable_interface(struct batadv_hard_iface *hard_iface); +int batadv_hardif_min_mtu(struct net_device *soft_iface); +void batadv_update_min_mtu(struct net_device *soft_iface); +void batadv_hardif_release(struct kref *ref); +int batadv_hardif_no_broadcast(struct batadv_hard_iface *if_outgoing, + u8 *orig_addr, u8 *orig_neigh); + +/** + * batadv_hardif_put() - decrement the hard interface refcounter and possibly + * release it + * @hard_iface: the hard interface to free + */ +static inline void batadv_hardif_put(struct batadv_hard_iface *hard_iface) +{ + if (!hard_iface) + return; + + kref_put(&hard_iface->refcount, batadv_hardif_release); +} + +/** + * batadv_primary_if_get_selected() - Get reference to primary interface + * @bat_priv: the bat priv with all the soft interface information + * + * Return: primary interface (with increased refcnt), otherwise NULL + */ +static inline struct batadv_hard_iface * +batadv_primary_if_get_selected(struct batadv_priv *bat_priv) +{ + struct batadv_hard_iface *hard_iface; + + rcu_read_lock(); + hard_iface = rcu_dereference(bat_priv->primary_if); + if (!hard_iface) + goto out; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + hard_iface = NULL; + +out: + rcu_read_unlock(); + return hard_iface; +} + +#endif /* _NET_BATMAN_ADV_HARD_INTERFACE_H_ */ diff --git a/net/batman-adv/hash.c b/net/batman-adv/hash.c new file mode 100644 index 000000000..8016e6197 --- /dev/null +++ b/net/batman-adv/hash.c @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Simon Wunderlich, Marek Lindner + */ + +#include "hash.h" +#include "main.h" + +#include +#include +#include + +/* clears the hash */ +static void batadv_hash_init(struct batadv_hashtable *hash) +{ + u32 i; + + for (i = 0; i < hash->size; i++) { + INIT_HLIST_HEAD(&hash->table[i]); + spin_lock_init(&hash->list_locks[i]); + } + + atomic_set(&hash->generation, 0); +} + +/** + * batadv_hash_destroy() - Free only the hashtable and the hash itself + * @hash: hash object to destroy + */ +void batadv_hash_destroy(struct batadv_hashtable *hash) +{ + kfree(hash->list_locks); + kfree(hash->table); + kfree(hash); +} + +/** + * batadv_hash_new() - Allocates and clears the hashtable + * @size: number of hash buckets to allocate + * + * Return: newly allocated hashtable, NULL on errors + */ +struct batadv_hashtable *batadv_hash_new(u32 size) +{ + struct batadv_hashtable *hash; + + hash = kmalloc(sizeof(*hash), GFP_ATOMIC); + if (!hash) + return NULL; + + hash->table = kmalloc_array(size, sizeof(*hash->table), GFP_ATOMIC); + if (!hash->table) + goto free_hash; + + hash->list_locks = kmalloc_array(size, sizeof(*hash->list_locks), + GFP_ATOMIC); + if (!hash->list_locks) + goto free_table; + + hash->size = size; + batadv_hash_init(hash); + return hash; + +free_table: + kfree(hash->table); +free_hash: + kfree(hash); + return NULL; +} + +/** + * batadv_hash_set_lock_class() - Set specific lockdep class for hash spinlocks + * @hash: hash object to modify + * @key: lockdep class key address + */ +void batadv_hash_set_lock_class(struct batadv_hashtable *hash, + struct lock_class_key *key) +{ + u32 i; + + for (i = 0; i < hash->size; i++) + lockdep_set_class(&hash->list_locks[i], key); +} diff --git a/net/batman-adv/hash.h b/net/batman-adv/hash.h new file mode 100644 index 000000000..fb251c385 --- /dev/null +++ b/net/batman-adv/hash.h @@ -0,0 +1,157 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Simon Wunderlich, Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_HASH_H_ +#define _NET_BATMAN_ADV_HASH_H_ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +/* callback to a compare function. should compare 2 element data for their + * keys + * + * Return: true if same and false if not same + */ +typedef bool (*batadv_hashdata_compare_cb)(const struct hlist_node *, + const void *); + +/* the hashfunction + * + * Return: an index based on the key in the data of the first argument and the + * size the second + */ +typedef u32 (*batadv_hashdata_choose_cb)(const void *, u32); +typedef void (*batadv_hashdata_free_cb)(struct hlist_node *, void *); + +/** + * struct batadv_hashtable - Wrapper of simple hlist based hashtable + */ +struct batadv_hashtable { + /** @table: the hashtable itself with the buckets */ + struct hlist_head *table; + + /** @list_locks: spinlock for each hash list entry */ + spinlock_t *list_locks; + + /** @size: size of hashtable */ + u32 size; + + /** @generation: current (generation) sequence number */ + atomic_t generation; +}; + +/* allocates and clears the hash */ +struct batadv_hashtable *batadv_hash_new(u32 size); + +/* set class key for all locks */ +void batadv_hash_set_lock_class(struct batadv_hashtable *hash, + struct lock_class_key *key); + +/* free only the hashtable and the hash itself. 
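+ * Entries still linked into the buckets are not freed; users are expected to
+ * empty the table first. A minimal teardown sketch, assuming a caller-side
+ * struct with an hlist_node member named hash_entry and local variables
+ * entry, tmp and i (a real user would also drop its own reference on each
+ * entry before deleting it):
+ *
+ *	for (i = 0; i < hash->size; i++) {
+ *		spin_lock_bh(&hash->list_locks[i]);
+ *		hlist_for_each_entry_safe(entry, tmp, &hash->table[i],
+ *					  hash_entry)
+ *			hlist_del_rcu(&entry->hash_entry);
+ *		spin_unlock_bh(&hash->list_locks[i]);
+ *	}
+ *	batadv_hash_destroy(hash);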
*/ +void batadv_hash_destroy(struct batadv_hashtable *hash); + +/** + * batadv_hash_add() - adds data to the hashtable + * @hash: storage hash table + * @compare: callback to determine if 2 hash elements are identical + * @choose: callback calculating the hash index + * @data: data passed to the aforementioned callbacks as argument + * @data_node: to be added element + * + * Return: 0 on success, 1 if the element already is in the hash + * and -1 on error. + */ +static inline int batadv_hash_add(struct batadv_hashtable *hash, + batadv_hashdata_compare_cb compare, + batadv_hashdata_choose_cb choose, + const void *data, + struct hlist_node *data_node) +{ + u32 index; + int ret = -1; + struct hlist_head *head; + struct hlist_node *node; + spinlock_t *list_lock; /* spinlock to protect write access */ + + if (!hash) + goto out; + + index = choose(data, hash->size); + head = &hash->table[index]; + list_lock = &hash->list_locks[index]; + + spin_lock_bh(list_lock); + + hlist_for_each(node, head) { + if (!compare(node, data)) + continue; + + ret = 1; + goto unlock; + } + + /* no duplicate found in list, add new element */ + hlist_add_head_rcu(data_node, head); + atomic_inc(&hash->generation); + + ret = 0; + +unlock: + spin_unlock_bh(list_lock); +out: + return ret; +} + +/** + * batadv_hash_remove() - Removes data from hash, if found + * @hash: hash table + * @compare: callback to determine if 2 hash elements are identical + * @choose: callback calculating the hash index + * @data: data passed to the aforementioned callbacks as argument + * + * ata could be the structure you use with just the key filled, we just need + * the key for comparing. + * + * Return: returns pointer do data on success, so you can remove the used + * structure yourself, or NULL on error + */ +static inline void *batadv_hash_remove(struct batadv_hashtable *hash, + batadv_hashdata_compare_cb compare, + batadv_hashdata_choose_cb choose, + void *data) +{ + u32 index; + struct hlist_node *node; + struct hlist_head *head; + void *data_save = NULL; + + index = choose(data, hash->size); + head = &hash->table[index]; + + spin_lock_bh(&hash->list_locks[index]); + hlist_for_each(node, head) { + if (!compare(node, data)) + continue; + + data_save = node; + hlist_del_rcu(node); + atomic_inc(&hash->generation); + break; + } + spin_unlock_bh(&hash->list_locks[index]); + + return data_save; +} + +#endif /* _NET_BATMAN_ADV_HASH_H_ */ diff --git a/net/batman-adv/log.c b/net/batman-adv/log.c new file mode 100644 index 000000000..7a93a1e94 --- /dev/null +++ b/net/batman-adv/log.c @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner + */ + +#include "log.h" +#include "main.h" + +#include + +#include "trace.h" + +/** + * batadv_debug_log() - Add debug log entry + * @bat_priv: the bat priv with all the soft interface information + * @fmt: format string + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_debug_log(struct batadv_priv *bat_priv, const char *fmt, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, fmt); + + vaf.fmt = fmt; + vaf.va = &args; + + trace_batadv_dbg(bat_priv, &vaf); + + va_end(args); + + return 0; +} diff --git a/net/batman-adv/log.h b/net/batman-adv/log.h new file mode 100644 index 000000000..6717c965f --- /dev/null +++ b/net/batman-adv/log.h @@ -0,0 +1,143 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_LOG_H_ +#define _NET_BATMAN_ADV_LOG_H_ + +#include "main.h" + +#include +#include +#include +#include + +#ifdef CONFIG_BATMAN_ADV_DEBUG + +int batadv_debug_log_setup(struct batadv_priv *bat_priv); +void batadv_debug_log_cleanup(struct batadv_priv *bat_priv); + +#else + +static inline int batadv_debug_log_setup(struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline void batadv_debug_log_cleanup(struct batadv_priv *bat_priv) +{ +} + +#endif + +/** + * enum batadv_dbg_level - available log levels + */ +enum batadv_dbg_level { + /** @BATADV_DBG_BATMAN: OGM and TQ computations related messages */ + BATADV_DBG_BATMAN = BIT(0), + + /** @BATADV_DBG_ROUTES: route added / changed / deleted */ + BATADV_DBG_ROUTES = BIT(1), + + /** @BATADV_DBG_TT: translation table messages */ + BATADV_DBG_TT = BIT(2), + + /** @BATADV_DBG_BLA: bridge loop avoidance messages */ + BATADV_DBG_BLA = BIT(3), + + /** @BATADV_DBG_DAT: ARP snooping and DAT related messages */ + BATADV_DBG_DAT = BIT(4), + + /** @BATADV_DBG_NC: network coding related messages */ + BATADV_DBG_NC = BIT(5), + + /** @BATADV_DBG_MCAST: multicast related messages */ + BATADV_DBG_MCAST = BIT(6), + + /** @BATADV_DBG_TP_METER: throughput meter messages */ + BATADV_DBG_TP_METER = BIT(7), + + /** @BATADV_DBG_ALL: the union of all the above log levels */ + BATADV_DBG_ALL = 255, +}; + +#ifdef CONFIG_BATMAN_ADV_DEBUG +int batadv_debug_log(struct batadv_priv *bat_priv, const char *fmt, ...) +__printf(2, 3); + +/** + * _batadv_dbg() - Store debug output with(out) rate limiting + * @type: type of debug message + * @bat_priv: the bat priv with all the soft interface information + * @ratelimited: whether output should be rate limited + * @fmt: format string + * @arg: variable arguments + */ +#define _batadv_dbg(type, bat_priv, ratelimited, fmt, arg...) \ + do { \ + struct batadv_priv *__batpriv = (bat_priv); \ + if (atomic_read(&__batpriv->log_level) & (type) && \ + (!(ratelimited) || net_ratelimit())) \ + batadv_debug_log(__batpriv, fmt, ## arg); \ + } \ + while (0) +#else /* !CONFIG_BATMAN_ADV_DEBUG */ +__printf(4, 5) +static inline void _batadv_dbg(int type __always_unused, + struct batadv_priv *bat_priv __always_unused, + int ratelimited __always_unused, + const char *fmt __always_unused, ...) +{ +} +#endif + +/** + * batadv_dbg() - Store debug output without rate limiting + * @type: type of debug message + * @bat_priv: the bat priv with all the soft interface information + * @arg: format string and variable arguments + */ +#define batadv_dbg(type, bat_priv, arg...) \ + _batadv_dbg(type, bat_priv, 0, ## arg) + +/** + * batadv_dbg_ratelimited() - Store debug output with rate limiting + * @type: type of debug message + * @bat_priv: the bat priv with all the soft interface information + * @arg: format string and variable arguments + */ +#define batadv_dbg_ratelimited(type, bat_priv, arg...) \ + _batadv_dbg(type, bat_priv, 1, ## arg) + +/** + * batadv_info() - Store message in debug buffer and print it to kmsg buffer + * @net_dev: the soft interface net device + * @fmt: format string + * @arg: variable arguments + */ +#define batadv_info(net_dev, fmt, arg...) 
\ + do { \ + struct net_device *_netdev = (net_dev); \ + struct batadv_priv *_batpriv = netdev_priv(_netdev); \ + batadv_dbg(BATADV_DBG_ALL, _batpriv, fmt, ## arg); \ + pr_info("%s: " fmt, _netdev->name, ## arg); \ + } while (0) + +/** + * batadv_err() - Store error in debug buffer and print it to kmsg buffer + * @net_dev: the soft interface net device + * @fmt: format string + * @arg: variable arguments + */ +#define batadv_err(net_dev, fmt, arg...) \ + do { \ + struct net_device *_netdev = (net_dev); \ + struct batadv_priv *_batpriv = netdev_priv(_netdev); \ + batadv_dbg(BATADV_DBG_ALL, _batpriv, fmt, ## arg); \ + pr_err("%s: " fmt, _netdev->name, ## arg); \ + } while (0) + +#endif /* _NET_BATMAN_ADV_LOG_H_ */ diff --git a/net/batman-adv/main.c b/net/batman-adv/main.c new file mode 100644 index 000000000..e8a449915 --- /dev/null +++ b/net/batman-adv/main.c @@ -0,0 +1,731 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bat_iv_ogm.h" +#include "bat_v.h" +#include "bridge_loop_avoidance.h" +#include "distributed-arp-table.h" +#include "gateway_client.h" +#include "gateway_common.h" +#include "hard-interface.h" +#include "log.h" +#include "multicast.h" +#include "netlink.h" +#include "network-coding.h" +#include "originator.h" +#include "routing.h" +#include "send.h" +#include "soft-interface.h" +#include "tp_meter.h" +#include "translation-table.h" + +/* List manipulations on hardif_list have to be rtnl_lock()'ed, + * list traversals just rcu-locked + */ +struct list_head batadv_hardif_list; +unsigned int batadv_hardif_generation; +static int (*batadv_rx_handler[256])(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); + +unsigned char batadv_broadcast_addr[] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +struct workqueue_struct *batadv_event_workqueue; + +static void batadv_recv_handler_init(void); + +#define BATADV_UEV_TYPE_VAR "BATTYPE=" +#define BATADV_UEV_ACTION_VAR "BATACTION=" +#define BATADV_UEV_DATA_VAR "BATDATA=" + +static char *batadv_uev_action_str[] = { + "add", + "del", + "change", + "loopdetect", +}; + +static char *batadv_uev_type_str[] = { + "gw", + "bla", +}; + +static int __init batadv_init(void) +{ + int ret; + + ret = batadv_tt_cache_init(); + if (ret < 0) + return ret; + + INIT_LIST_HEAD(&batadv_hardif_list); + batadv_algo_init(); + + batadv_recv_handler_init(); + + batadv_v_init(); + batadv_iv_init(); + batadv_nc_init(); + batadv_tp_meter_init(); + + batadv_event_workqueue = create_singlethread_workqueue("bat_events"); + if (!batadv_event_workqueue) + goto err_create_wq; + + register_netdevice_notifier(&batadv_hard_if_notifier); + rtnl_link_register(&batadv_link_ops); + batadv_netlink_register(); + + pr_info("B.A.T.M.A.N. 
advanced %s (compatibility version %i) loaded\n", + BATADV_SOURCE_VERSION, BATADV_COMPAT_VERSION); + + return 0; + +err_create_wq: + batadv_tt_cache_destroy(); + + return -ENOMEM; +} + +static void __exit batadv_exit(void) +{ + batadv_netlink_unregister(); + rtnl_link_unregister(&batadv_link_ops); + unregister_netdevice_notifier(&batadv_hard_if_notifier); + + destroy_workqueue(batadv_event_workqueue); + batadv_event_workqueue = NULL; + + rcu_barrier(); + + batadv_tt_cache_destroy(); +} + +/** + * batadv_mesh_init() - Initialize soft interface + * @soft_iface: netdev struct of the soft interface + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_mesh_init(struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + int ret; + + spin_lock_init(&bat_priv->forw_bat_list_lock); + spin_lock_init(&bat_priv->forw_bcast_list_lock); + spin_lock_init(&bat_priv->tt.changes_list_lock); + spin_lock_init(&bat_priv->tt.req_list_lock); + spin_lock_init(&bat_priv->tt.roam_list_lock); + spin_lock_init(&bat_priv->tt.last_changeset_lock); + spin_lock_init(&bat_priv->tt.commit_lock); + spin_lock_init(&bat_priv->gw.list_lock); +#ifdef CONFIG_BATMAN_ADV_MCAST + spin_lock_init(&bat_priv->mcast.mla_lock); + spin_lock_init(&bat_priv->mcast.want_lists_lock); +#endif + spin_lock_init(&bat_priv->tvlv.container_list_lock); + spin_lock_init(&bat_priv->tvlv.handler_list_lock); + spin_lock_init(&bat_priv->softif_vlan_list_lock); + spin_lock_init(&bat_priv->tp_list_lock); + + INIT_HLIST_HEAD(&bat_priv->forw_bat_list); + INIT_HLIST_HEAD(&bat_priv->forw_bcast_list); + INIT_HLIST_HEAD(&bat_priv->gw.gateway_list); +#ifdef CONFIG_BATMAN_ADV_MCAST + INIT_HLIST_HEAD(&bat_priv->mcast.want_all_unsnoopables_list); + INIT_HLIST_HEAD(&bat_priv->mcast.want_all_ipv4_list); + INIT_HLIST_HEAD(&bat_priv->mcast.want_all_ipv6_list); +#endif + INIT_LIST_HEAD(&bat_priv->tt.changes_list); + INIT_HLIST_HEAD(&bat_priv->tt.req_list); + INIT_LIST_HEAD(&bat_priv->tt.roam_list); +#ifdef CONFIG_BATMAN_ADV_MCAST + INIT_HLIST_HEAD(&bat_priv->mcast.mla_list); +#endif + INIT_HLIST_HEAD(&bat_priv->tvlv.container_list); + INIT_HLIST_HEAD(&bat_priv->tvlv.handler_list); + INIT_HLIST_HEAD(&bat_priv->softif_vlan_list); + INIT_HLIST_HEAD(&bat_priv->tp_list); + + bat_priv->gw.generation = 0; + + ret = batadv_originator_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_orig; + } + + ret = batadv_tt_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_tt; + } + + ret = batadv_v_mesh_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_v; + } + + ret = batadv_bla_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_bla; + } + + ret = batadv_dat_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_dat; + } + + ret = batadv_nc_mesh_init(bat_priv); + if (ret < 0) { + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + goto err_nc; + } + + batadv_gw_init(bat_priv); + batadv_mcast_init(bat_priv); + + atomic_set(&bat_priv->gw.reselect, 0); + atomic_set(&bat_priv->mesh_state, BATADV_MESH_ACTIVE); + + return 0; + +err_nc: + batadv_dat_free(bat_priv); +err_dat: + batadv_bla_free(bat_priv); +err_bla: + batadv_v_mesh_free(bat_priv); +err_v: + batadv_tt_free(bat_priv); +err_tt: + batadv_originator_free(bat_priv); +err_orig: + 
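+	/* shared tail of all error paths above: purge queued packets and put
+	 * the mesh back into the inactive state
+	 */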
batadv_purge_outstanding_packets(bat_priv, NULL); + atomic_set(&bat_priv->mesh_state, BATADV_MESH_INACTIVE); + + return ret; +} + +/** + * batadv_mesh_free() - Deinitialize soft interface + * @soft_iface: netdev struct of the soft interface + */ +void batadv_mesh_free(struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + + atomic_set(&bat_priv->mesh_state, BATADV_MESH_DEACTIVATING); + + batadv_purge_outstanding_packets(bat_priv, NULL); + + batadv_gw_node_free(bat_priv); + + batadv_v_mesh_free(bat_priv); + batadv_nc_mesh_free(bat_priv); + batadv_dat_free(bat_priv); + batadv_bla_free(bat_priv); + + batadv_mcast_free(bat_priv); + + /* Free the TT and the originator tables only after having terminated + * all the other depending components which may use these structures for + * their purposes. + */ + batadv_tt_free(bat_priv); + + /* Since the originator table clean up routine is accessing the TT + * tables as well, it has to be invoked after the TT tables have been + * freed and marked as empty. This ensures that no cleanup RCU callbacks + * accessing the TT data are scheduled for later execution. + */ + batadv_originator_free(bat_priv); + + batadv_gw_free(bat_priv); + + free_percpu(bat_priv->bat_counters); + bat_priv->bat_counters = NULL; + + atomic_set(&bat_priv->mesh_state, BATADV_MESH_INACTIVE); +} + +/** + * batadv_is_my_mac() - check if the given mac address belongs to any of the + * real interfaces in the current mesh + * @bat_priv: the bat priv with all the soft interface information + * @addr: the address to check + * + * Return: 'true' if the mac address was found, false otherwise. + */ +bool batadv_is_my_mac(struct batadv_priv *bat_priv, const u8 *addr) +{ + const struct batadv_hard_iface *hard_iface; + bool is_my_mac = false; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE) + continue; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (batadv_compare_eth(hard_iface->net_dev->dev_addr, addr)) { + is_my_mac = true; + break; + } + } + rcu_read_unlock(); + return is_my_mac; +} + +/** + * batadv_max_header_len() - calculate maximum encapsulation overhead for a + * payload packet + * + * Return: the maximum encapsulation overhead in bytes. + */ +int batadv_max_header_len(void) +{ + int header_len = 0; + + header_len = max_t(int, header_len, + sizeof(struct batadv_unicast_packet)); + header_len = max_t(int, header_len, + sizeof(struct batadv_unicast_4addr_packet)); + header_len = max_t(int, header_len, + sizeof(struct batadv_bcast_packet)); + +#ifdef CONFIG_BATMAN_ADV_NC + header_len = max_t(int, header_len, + sizeof(struct batadv_coded_packet)); +#endif + + return header_len + ETH_HLEN; +} + +/** + * batadv_skb_set_priority() - sets skb priority according to packet content + * @skb: the packet to be sent + * @offset: offset to the packet content + * + * This function sets a value between 256 and 263 (802.1d priority), which + * can be interpreted by the cfg80211 or other drivers. 
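+ *
+ * For example, with the 3-bit precedence mapping implemented below, a frame
+ * tagged with VLAN PCP 5, or an IPv4 packet marked with DSCP class EF
+ * (DS field 0xb8), ends up with skb->priority 261 (256 + 5).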
+ */ +void batadv_skb_set_priority(struct sk_buff *skb, int offset) +{ + struct iphdr ip_hdr_tmp, *ip_hdr; + struct ipv6hdr ip6_hdr_tmp, *ip6_hdr; + struct ethhdr ethhdr_tmp, *ethhdr; + struct vlan_ethhdr *vhdr, vhdr_tmp; + u32 prio; + + /* already set, do nothing */ + if (skb->priority >= 256 && skb->priority <= 263) + return; + + ethhdr = skb_header_pointer(skb, offset, sizeof(*ethhdr), ðhdr_tmp); + if (!ethhdr) + return; + + switch (ethhdr->h_proto) { + case htons(ETH_P_8021Q): + vhdr = skb_header_pointer(skb, offset + sizeof(*vhdr), + sizeof(*vhdr), &vhdr_tmp); + if (!vhdr) + return; + prio = ntohs(vhdr->h_vlan_TCI) & VLAN_PRIO_MASK; + prio = prio >> VLAN_PRIO_SHIFT; + break; + case htons(ETH_P_IP): + ip_hdr = skb_header_pointer(skb, offset + sizeof(*ethhdr), + sizeof(*ip_hdr), &ip_hdr_tmp); + if (!ip_hdr) + return; + prio = (ipv4_get_dsfield(ip_hdr) & 0xfc) >> 5; + break; + case htons(ETH_P_IPV6): + ip6_hdr = skb_header_pointer(skb, offset + sizeof(*ethhdr), + sizeof(*ip6_hdr), &ip6_hdr_tmp); + if (!ip6_hdr) + return; + prio = (ipv6_get_dsfield(ip6_hdr) & 0xfc) >> 5; + break; + default: + return; + } + + skb->priority = prio + 256; +} + +static int batadv_recv_unhandled_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + kfree_skb(skb); + + return NET_RX_DROP; +} + +/* incoming packets with the batman ethertype received on any active hard + * interface + */ + +/** + * batadv_batman_skb_recv() - Handle incoming message from an hard interface + * @skb: the received packet + * @dev: the net device that the packet was received on + * @ptype: packet type of incoming packet (ETH_P_BATMAN) + * @orig_dev: the original receive net device (e.g. bonded device) + * + * Return: NET_RX_SUCCESS on success or NET_RX_DROP in case of failure + */ +int batadv_batman_skb_recv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *ptype, + struct net_device *orig_dev) +{ + struct batadv_priv *bat_priv; + struct batadv_ogm_packet *batadv_ogm_packet; + struct batadv_hard_iface *hard_iface; + u8 idx; + + hard_iface = container_of(ptype, struct batadv_hard_iface, + batman_adv_ptype); + + /* Prevent processing a packet received on an interface which is getting + * shut down otherwise the packet may trigger de-reference errors + * further down in the receive path. + */ + if (!kref_get_unless_zero(&hard_iface->refcount)) + goto err_out; + + skb = skb_share_check(skb, GFP_ATOMIC); + + /* skb was released by skb_share_check() */ + if (!skb) + goto err_put; + + /* packet should hold at least type and version */ + if (unlikely(!pskb_may_pull(skb, 2))) + goto err_free; + + /* expect a valid ethernet header here. 
*/ + if (unlikely(skb->mac_len != ETH_HLEN || !skb_mac_header(skb))) + goto err_free; + + if (!hard_iface->soft_iface) + goto err_free; + + bat_priv = netdev_priv(hard_iface->soft_iface); + + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + goto err_free; + + /* discard frames on not active interfaces */ + if (hard_iface->if_status != BATADV_IF_ACTIVE) + goto err_free; + + batadv_ogm_packet = (struct batadv_ogm_packet *)skb->data; + + if (batadv_ogm_packet->version != BATADV_COMPAT_VERSION) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Drop packet: incompatible batman version (%i)\n", + batadv_ogm_packet->version); + goto err_free; + } + + /* reset control block to avoid left overs from previous users */ + memset(skb->cb, 0, sizeof(struct batadv_skb_cb)); + + idx = batadv_ogm_packet->packet_type; + (*batadv_rx_handler[idx])(skb, hard_iface); + + batadv_hardif_put(hard_iface); + + /* return NET_RX_SUCCESS in any case as we + * most probably dropped the packet for + * routing-logical reasons. + */ + return NET_RX_SUCCESS; + +err_free: + kfree_skb(skb); +err_put: + batadv_hardif_put(hard_iface); +err_out: + return NET_RX_DROP; +} + +static void batadv_recv_handler_init(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(batadv_rx_handler); i++) + batadv_rx_handler[i] = batadv_recv_unhandled_packet; + + for (i = BATADV_UNICAST_MIN; i <= BATADV_UNICAST_MAX; i++) + batadv_rx_handler[i] = batadv_recv_unhandled_unicast_packet; + + /* compile time checks for sizes */ + BUILD_BUG_ON(sizeof(struct batadv_bla_claim_dst) != 6); + BUILD_BUG_ON(sizeof(struct batadv_ogm_packet) != 24); + BUILD_BUG_ON(sizeof(struct batadv_icmp_header) != 20); + BUILD_BUG_ON(sizeof(struct batadv_icmp_packet) != 20); + BUILD_BUG_ON(sizeof(struct batadv_icmp_packet_rr) != 116); + BUILD_BUG_ON(sizeof(struct batadv_unicast_packet) != 10); + BUILD_BUG_ON(sizeof(struct batadv_unicast_4addr_packet) != 18); + BUILD_BUG_ON(sizeof(struct batadv_frag_packet) != 20); + BUILD_BUG_ON(sizeof(struct batadv_bcast_packet) != 14); + BUILD_BUG_ON(sizeof(struct batadv_coded_packet) != 46); + BUILD_BUG_ON(sizeof(struct batadv_unicast_tvlv_packet) != 20); + BUILD_BUG_ON(sizeof(struct batadv_tvlv_hdr) != 4); + BUILD_BUG_ON(sizeof(struct batadv_tvlv_gateway_data) != 8); + BUILD_BUG_ON(sizeof(struct batadv_tvlv_tt_vlan_data) != 8); + BUILD_BUG_ON(sizeof(struct batadv_tvlv_tt_change) != 12); + BUILD_BUG_ON(sizeof(struct batadv_tvlv_roam_adv) != 8); + + i = sizeof_field(struct sk_buff, cb); + BUILD_BUG_ON(sizeof(struct batadv_skb_cb) > i); + + /* broadcast packet */ + batadv_rx_handler[BATADV_BCAST] = batadv_recv_bcast_packet; + + /* unicast packets ... 
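+ * (each assignment below replaces the batadv_recv_unhandled_unicast_packet
+ * default installed by the loop above)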
*/ + /* unicast with 4 addresses packet */ + batadv_rx_handler[BATADV_UNICAST_4ADDR] = batadv_recv_unicast_packet; + /* unicast packet */ + batadv_rx_handler[BATADV_UNICAST] = batadv_recv_unicast_packet; + /* unicast tvlv packet */ + batadv_rx_handler[BATADV_UNICAST_TVLV] = batadv_recv_unicast_tvlv; + /* batman icmp packet */ + batadv_rx_handler[BATADV_ICMP] = batadv_recv_icmp_packet; + /* Fragmented packets */ + batadv_rx_handler[BATADV_UNICAST_FRAG] = batadv_recv_frag_packet; +} + +/** + * batadv_recv_handler_register() - Register handler for batman-adv packet type + * @packet_type: batadv_packettype which should be handled + * @recv_handler: receive handler for the packet type + * + * Return: 0 on success or negative error number in case of failure + */ +int +batadv_recv_handler_register(u8 packet_type, + int (*recv_handler)(struct sk_buff *, + struct batadv_hard_iface *)) +{ + int (*curr)(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); + curr = batadv_rx_handler[packet_type]; + + if (curr != batadv_recv_unhandled_packet && + curr != batadv_recv_unhandled_unicast_packet) + return -EBUSY; + + batadv_rx_handler[packet_type] = recv_handler; + return 0; +} + +/** + * batadv_recv_handler_unregister() - Unregister handler for packet type + * @packet_type: batadv_packettype which should no longer be handled + */ +void batadv_recv_handler_unregister(u8 packet_type) +{ + batadv_rx_handler[packet_type] = batadv_recv_unhandled_packet; +} + +/** + * batadv_skb_crc32() - calculate CRC32 of the whole packet and skip bytes in + * the header + * @skb: skb pointing to fragmented socket buffers + * @payload_ptr: Pointer to position inside the head buffer of the skb + * marking the start of the data to be CRC'ed + * + * payload_ptr must always point to an address in the skb head buffer and not to + * a fragment. + * + * Return: big endian crc32c of the checksummed data + */ +__be32 batadv_skb_crc32(struct sk_buff *skb, u8 *payload_ptr) +{ + u32 crc = 0; + unsigned int from; + unsigned int to = skb->len; + struct skb_seq_state st; + const u8 *data; + unsigned int len; + unsigned int consumed = 0; + + from = (unsigned int)(payload_ptr - skb->data); + + skb_prepare_seq_read(skb, from, to, &st); + while ((len = skb_seq_read(consumed, &data, &st)) != 0) { + crc = crc32c(crc, data, len); + consumed += len; + } + + return htonl(crc); +} + +/** + * batadv_get_vid() - extract the VLAN identifier from skb if any + * @skb: the buffer containing the packet + * @header_len: length of the batman header preceding the ethernet header + * + * Return: VID with the BATADV_VLAN_HAS_TAG flag when the packet embedded in the + * skb is vlan tagged. Otherwise BATADV_NO_FLAGS. 
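+ *
+ * For example, a frame tagged with VID 42 yields 42 | BATADV_VLAN_HAS_TAG,
+ * while an untagged (non-802.1Q) frame yields BATADV_NO_FLAGS.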
+ */ +unsigned short batadv_get_vid(struct sk_buff *skb, size_t header_len) +{ + struct ethhdr *ethhdr = (struct ethhdr *)(skb->data + header_len); + struct vlan_ethhdr *vhdr; + unsigned short vid; + + if (ethhdr->h_proto != htons(ETH_P_8021Q)) + return BATADV_NO_FLAGS; + + if (!pskb_may_pull(skb, header_len + VLAN_ETH_HLEN)) + return BATADV_NO_FLAGS; + + vhdr = (struct vlan_ethhdr *)(skb->data + header_len); + vid = ntohs(vhdr->h_vlan_TCI) & VLAN_VID_MASK; + vid |= BATADV_VLAN_HAS_TAG; + + return vid; +} + +/** + * batadv_vlan_ap_isola_get() - return AP isolation status for the given vlan + * @bat_priv: the bat priv with all the soft interface information + * @vid: the VLAN identifier for which the AP isolation attributed as to be + * looked up + * + * Return: true if AP isolation is on for the VLAN identified by vid, false + * otherwise + */ +bool batadv_vlan_ap_isola_get(struct batadv_priv *bat_priv, unsigned short vid) +{ + bool ap_isolation_enabled = false; + struct batadv_softif_vlan *vlan; + + /* if the AP isolation is requested on a VLAN, then check for its + * setting in the proper VLAN private data structure + */ + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (vlan) { + ap_isolation_enabled = atomic_read(&vlan->ap_isolation); + batadv_softif_vlan_put(vlan); + } + + return ap_isolation_enabled; +} + +/** + * batadv_throw_uevent() - Send an uevent with batman-adv specific env data + * @bat_priv: the bat priv with all the soft interface information + * @type: subsystem type of event. Stored in uevent's BATTYPE + * @action: action type of event. Stored in uevent's BATACTION + * @data: string with additional information to the event (ignored for + * BATADV_UEV_DEL). Stored in uevent's BATDATA + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_throw_uevent(struct batadv_priv *bat_priv, enum batadv_uev_type type, + enum batadv_uev_action action, const char *data) +{ + int ret = -ENOMEM; + struct kobject *bat_kobj; + char *uevent_env[4] = { NULL, NULL, NULL, NULL }; + + bat_kobj = &bat_priv->soft_iface->dev.kobj; + + uevent_env[0] = kasprintf(GFP_ATOMIC, + "%s%s", BATADV_UEV_TYPE_VAR, + batadv_uev_type_str[type]); + if (!uevent_env[0]) + goto out; + + uevent_env[1] = kasprintf(GFP_ATOMIC, + "%s%s", BATADV_UEV_ACTION_VAR, + batadv_uev_action_str[action]); + if (!uevent_env[1]) + goto out; + + /* If the event is DEL, ignore the data field */ + if (action != BATADV_UEV_DEL) { + uevent_env[2] = kasprintf(GFP_ATOMIC, + "%s%s", BATADV_UEV_DATA_VAR, data); + if (!uevent_env[2]) + goto out; + } + + ret = kobject_uevent_env(bat_kobj, KOBJ_CHANGE, uevent_env); +out: + kfree(uevent_env[0]); + kfree(uevent_env[1]); + kfree(uevent_env[2]); + + if (ret) + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Impossible to send uevent for (%s,%s,%s) event (err: %d)\n", + batadv_uev_type_str[type], + batadv_uev_action_str[action], + (action == BATADV_UEV_DEL ? "NULL" : data), ret); + return ret; +} + +module_init(batadv_init); +module_exit(batadv_exit); + +MODULE_LICENSE("GPL"); + +MODULE_AUTHOR(BATADV_DRIVER_AUTHOR); +MODULE_DESCRIPTION(BATADV_DRIVER_DESC); +MODULE_VERSION(BATADV_SOURCE_VERSION); +MODULE_ALIAS_RTNL_LINK("batadv"); +MODULE_ALIAS_GENL_FAMILY(BATADV_NL_NAME); diff --git a/net/batman-adv/main.h b/net/batman-adv/main.h new file mode 100644 index 000000000..c48803b32 --- /dev/null +++ b/net/batman-adv/main.h @@ -0,0 +1,384 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_MAIN_H_ +#define _NET_BATMAN_ADV_MAIN_H_ + +#define BATADV_DRIVER_AUTHOR "Marek Lindner , " \ + "Simon Wunderlich " +#define BATADV_DRIVER_DESC "B.A.T.M.A.N. advanced" +#define BATADV_DRIVER_DEVICE "batman-adv" + +#ifndef BATADV_SOURCE_VERSION +#define BATADV_SOURCE_VERSION "2022.3" +#endif + +/* B.A.T.M.A.N. parameters */ + +#define BATADV_TQ_MAX_VALUE 255 +#define BATADV_THROUGHPUT_MAX_VALUE 0xFFFFFFFF +#define BATADV_JITTER 20 + +/* Time To Live of broadcast messages */ +#define BATADV_TTL 50 + +/* maximum sequence number age of broadcast messages */ +#define BATADV_BCAST_MAX_AGE 64 + +/* purge originators after time in seconds if no valid packet comes in + * -> TODO: check influence on BATADV_TQ_LOCAL_WINDOW_SIZE + */ +#define BATADV_PURGE_TIMEOUT 200000 /* 200 seconds */ +#define BATADV_TT_LOCAL_TIMEOUT 600000 /* in milliseconds */ +#define BATADV_TT_CLIENT_ROAM_TIMEOUT 600000 /* in milliseconds */ +#define BATADV_TT_CLIENT_TEMP_TIMEOUT 600000 /* in milliseconds */ +#define BATADV_TT_WORK_PERIOD 5000 /* 5 seconds */ +#define BATADV_ORIG_WORK_PERIOD 1000 /* 1 second */ +#define BATADV_MCAST_WORK_PERIOD 500 /* 0.5 seconds */ +#define BATADV_DAT_ENTRY_TIMEOUT (5 * 60000) /* 5 mins in milliseconds */ +/* sliding packet range of received originator messages in sequence numbers + * (should be a multiple of our word size) + */ +#define BATADV_TQ_LOCAL_WINDOW_SIZE 64 +/* milliseconds we have to keep pending tt_req */ +#define BATADV_TT_REQUEST_TIMEOUT 3000 + +#define BATADV_TQ_GLOBAL_WINDOW_SIZE 5 +#define BATADV_TQ_LOCAL_BIDRECT_SEND_MINIMUM 1 +#define BATADV_TQ_LOCAL_BIDRECT_RECV_MINIMUM 1 +#define BATADV_TQ_TOTAL_BIDRECT_LIMIT 1 + +/* B.A.T.M.A.N. V */ +#define BATADV_THROUGHPUT_DEFAULT_VALUE 10 /* 1 Mbps */ +#define BATADV_ELP_PROBES_PER_NODE 2 +#define BATADV_ELP_MIN_PROBE_SIZE 200 /* bytes */ +#define BATADV_ELP_PROBE_MAX_TX_DIFF 100 /* milliseconds */ +#define BATADV_ELP_MAX_AGE 64 +#define BATADV_OGM_MAX_ORIGDIFF 5 +#define BATADV_OGM_MAX_AGE 64 + +/* number of OGMs sent with the last tt diff */ +#define BATADV_TT_OGM_APPEND_MAX 3 + +/* Time in which a client can roam at most ROAMING_MAX_COUNT times in + * milliseconds + */ +#define BATADV_ROAMING_MAX_TIME 20000 +#define BATADV_ROAMING_MAX_COUNT 5 + +#define BATADV_NO_FLAGS 0 + +#define BATADV_NULL_IFINDEX 0 /* dummy ifindex used to avoid iface checks */ + +#define BATADV_NO_MARK 0 + +/* default interface for multi interface operation. The default interface is + * used for communication which originated locally (i.e. is not forwarded) + * or where special forwarding is not desired/necessary. 
+ */ +#define BATADV_IF_DEFAULT ((struct batadv_hard_iface *)NULL) + +#define BATADV_NUM_WORDS BITS_TO_LONGS(BATADV_TQ_LOCAL_WINDOW_SIZE) + +#define BATADV_LOG_BUF_LEN 8192 /* has to be a power of 2 */ + +/* number of packets to send for broadcasts on different interface types */ +#define BATADV_NUM_BCASTS_DEFAULT 1 +#define BATADV_NUM_BCASTS_WIRELESS 3 + +/* length of the single packet used by the TP meter */ +#define BATADV_TP_PACKET_LEN ETH_DATA_LEN + +/* msecs after which an ARP_REQUEST is sent in broadcast as fallback */ +#define ARP_REQ_DELAY 250 +/* numbers of originator to contact for any PUT/GET DHT operation */ +#define BATADV_DAT_CANDIDATES_NUM 3 + +/* BATADV_TQ_SIMILARITY_THRESHOLD - TQ points that a secondary metric can differ + * at most from the primary one in order to be still considered acceptable + */ +#define BATADV_TQ_SIMILARITY_THRESHOLD 50 + +/* should not be bigger than 512 bytes or change the size of + * forw_packet->direct_link_flags + */ +#define BATADV_MAX_AGGREGATION_BYTES 512 +#define BATADV_MAX_AGGREGATION_MS 100 + +#define BATADV_BLA_PERIOD_LENGTH 10000 /* 10 seconds */ +#define BATADV_BLA_BACKBONE_TIMEOUT (BATADV_BLA_PERIOD_LENGTH * 6) +#define BATADV_BLA_CLAIM_TIMEOUT (BATADV_BLA_PERIOD_LENGTH * 10) +#define BATADV_BLA_WAIT_PERIODS 3 +#define BATADV_BLA_LOOPDETECT_PERIODS 6 +#define BATADV_BLA_LOOPDETECT_TIMEOUT 3000 /* 3 seconds */ + +#define BATADV_DUPLIST_SIZE 16 +#define BATADV_DUPLIST_TIMEOUT 500 /* 500 ms */ +/* don't reset again within 30 seconds */ +#define BATADV_RESET_PROTECTION_MS 30000 +#define BATADV_EXPECTED_SEQNO_RANGE 65536 + +#define BATADV_NC_NODE_TIMEOUT 10000 /* Milliseconds */ + +/** + * BATADV_TP_MAX_NUM - maximum number of simultaneously active tp sessions + */ +#define BATADV_TP_MAX_NUM 5 + +/** + * enum batadv_mesh_state - State of a soft interface + */ +enum batadv_mesh_state { + /** @BATADV_MESH_INACTIVE: soft interface is not yet running */ + BATADV_MESH_INACTIVE, + + /** @BATADV_MESH_ACTIVE: interface is up and running */ + BATADV_MESH_ACTIVE, + + /** @BATADV_MESH_DEACTIVATING: interface is getting shut down */ + BATADV_MESH_DEACTIVATING, +}; + +#define BATADV_BCAST_QUEUE_LEN 256 +#define BATADV_BATMAN_QUEUE_LEN 256 + +/** + * enum batadv_uev_action - action type of uevent + */ +enum batadv_uev_action { + /** @BATADV_UEV_ADD: gateway was selected (after none was selected) */ + BATADV_UEV_ADD = 0, + + /** + * @BATADV_UEV_DEL: selected gateway was removed and none is selected + * anymore + */ + BATADV_UEV_DEL, + + /** + * @BATADV_UEV_CHANGE: a different gateway was selected as based gateway + */ + BATADV_UEV_CHANGE, + + /** + * @BATADV_UEV_LOOPDETECT: loop was detected which cannot be handled by + * bridge loop avoidance + */ + BATADV_UEV_LOOPDETECT, +}; + +/** + * enum batadv_uev_type - Type of uevent + */ +enum batadv_uev_type { + /** @BATADV_UEV_GW: selected gateway was modified */ + BATADV_UEV_GW = 0, + + /** @BATADV_UEV_BLA: bridge loop avoidance event */ + BATADV_UEV_BLA, +}; + +#define BATADV_GW_THRESHOLD 50 + +/* Number of fragment chains for each orig_node */ +#define BATADV_FRAG_BUFFER_COUNT 8 +/* Maximum number of fragments for one packet */ +#define BATADV_FRAG_MAX_FRAGMENTS 16 +/* Maxumim size of each fragment */ +#define BATADV_FRAG_MAX_FRAG_SIZE 1280 +/* Time to keep fragments while waiting for rest of the fragments */ +#define BATADV_FRAG_TIMEOUT 10000 + +#define BATADV_DAT_CANDIDATE_NOT_FOUND 0 +#define BATADV_DAT_CANDIDATE_ORIG 1 + +/* Debug Messages */ +#ifdef pr_fmt +#undef pr_fmt +#endif +/* Append 
'batman-adv: ' before kernel messages */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +/* Kernel headers */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "types.h" +#include "main.h" + +/** + * batadv_print_vid() - return printable version of vid information + * @vid: the VLAN identifier + * + * Return: -1 when no VLAN is used, VLAN id otherwise + */ +static inline int batadv_print_vid(unsigned short vid) +{ + if (vid & BATADV_VLAN_HAS_TAG) + return (int)(vid & VLAN_VID_MASK); + else + return -1; +} + +extern struct list_head batadv_hardif_list; +extern unsigned int batadv_hardif_generation; + +extern unsigned char batadv_broadcast_addr[]; +extern struct workqueue_struct *batadv_event_workqueue; + +int batadv_mesh_init(struct net_device *soft_iface); +void batadv_mesh_free(struct net_device *soft_iface); +bool batadv_is_my_mac(struct batadv_priv *bat_priv, const u8 *addr); +int batadv_max_header_len(void); +void batadv_skb_set_priority(struct sk_buff *skb, int offset); +int batadv_batman_skb_recv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *ptype, + struct net_device *orig_dev); +int +batadv_recv_handler_register(u8 packet_type, + int (*recv_handler)(struct sk_buff *, + struct batadv_hard_iface *)); +void batadv_recv_handler_unregister(u8 packet_type); +__be32 batadv_skb_crc32(struct sk_buff *skb, u8 *payload_ptr); + +/** + * batadv_compare_eth() - Compare two not u16 aligned Ethernet addresses + * @data1: Pointer to a six-byte array containing the Ethernet address + * @data2: Pointer other six-byte array containing the Ethernet address + * + * note: can't use ether_addr_equal() as it requires aligned memory + * + * Return: true if they are the same ethernet addr + */ +static inline bool batadv_compare_eth(const void *data1, const void *data2) +{ + return ether_addr_equal_unaligned(data1, data2); +} + +/** + * batadv_has_timed_out() - compares current time (jiffies) and timestamp + + * timeout + * @timestamp: base value to compare with (in jiffies) + * @timeout: added to base value before comparing (in milliseconds) + * + * Return: true if current time is after timestamp + timeout + */ +static inline bool batadv_has_timed_out(unsigned long timestamp, + unsigned int timeout) +{ + return time_is_before_jiffies(timestamp + msecs_to_jiffies(timeout)); +} + +/** + * batadv_atomic_dec_not_zero() - Decrease unless the number is 0 + * @v: pointer of type atomic_t + * + * Return: non-zero if v was not 0, and zero otherwise. + */ +#define batadv_atomic_dec_not_zero(v) atomic_add_unless((v), -1, 0) + +/** + * batadv_smallest_signed_int() - Returns the smallest signed integer in two's + * complement with the sizeof x + * @x: type of integer + * + * Return: smallest signed integer of type + */ +#define batadv_smallest_signed_int(x) (1u << (7u + 8u * (sizeof(x) - 1u))) + +/** + * batadv_seq_before() - Checks if a sequence number x is a predecessor of y + * @x: potential predecessor of @y + * @y: value to compare @x against + * + * It handles overflows/underflows and can correctly check for a predecessor + * unless the variable sequence number has grown by more than + * 2**(bitwidth(x)-1)-1. 
+ * + * This means that for a u8 with the maximum value 255, it would think: + * + * * when adding nothing - it is neither a predecessor nor a successor + * * before adding more than 127 to the starting value - it is a predecessor, + * * when adding 128 - it is neither a predecessor nor a successor, + * * after adding more than 127 to the starting value - it is a successor + * + * Return: true when x is a predecessor of y, false otherwise + */ +#define batadv_seq_before(x, y) ({ \ + typeof(x)_d1 = (x); \ + typeof(y)_d2 = (y); \ + typeof(x)_dummy = (_d1 - _d2); \ + (void)(&_d1 == &_d2); \ + _dummy > batadv_smallest_signed_int(_dummy); \ +}) + +/** + * batadv_seq_after() - Checks if a sequence number x is a successor of y + * @x: potential successor of @y + * @y: value to compare @x against + * + * It handles overflows/underflows and can correctly check for a successor + * unless the variable sequence number has grown by more than + * 2**(bitwidth(x)-1)-1. + * + * This means that for a u8 with the maximum value 255, it would think: + * + * * when adding nothing - it is neither a predecessor nor a successor + * * before adding more than 127 to the starting value - it is a predecessor, + * * when adding 128 - it is neither a predecessor nor a successor, + * * after adding more than 127 to the starting value - it is a successor + * + * Return: true when x is a successor of y, false otherwise + */ +#define batadv_seq_after(x, y) batadv_seq_before(y, x) + +/** + * batadv_add_counter() - Add to per cpu statistics counter of soft interface + * @bat_priv: the bat priv with all the soft interface information + * @idx: counter index which should be modified + * @count: value to increase counter by + * + * Stop preemption on local cpu while incrementing the counter + */ +static inline void batadv_add_counter(struct batadv_priv *bat_priv, size_t idx, + size_t count) +{ + this_cpu_add(bat_priv->bat_counters[idx], count); +} + +/** + * batadv_inc_counter() - Increase per cpu statistics counter of soft interface + * @b: the bat priv with all the soft interface information + * @i: counter index which should be modified + */ +#define batadv_inc_counter(b, i) batadv_add_counter(b, i, 1) + +/** + * BATADV_SKB_CB() - Get batadv_skb_cb from skb control buffer + * @__skb: skb holding the control buffer + * + * The members of the control buffer are defined in struct batadv_skb_cb in + * types.h. The macro is inspired by the similar macro TCP_SKB_CB() in tcp.h. + * + * Return: pointer to the batadv_skb_cb of the skb + */ +#define BATADV_SKB_CB(__skb) ((struct batadv_skb_cb *)&((__skb)->cb[0])) + +unsigned short batadv_get_vid(struct sk_buff *skb, size_t header_len); +bool batadv_vlan_ap_isola_get(struct batadv_priv *bat_priv, unsigned short vid); +int batadv_throw_uevent(struct batadv_priv *bat_priv, enum batadv_uev_type type, + enum batadv_uev_action action, const char *data); + +#endif /* _NET_BATMAN_ADV_MAIN_H_ */ diff --git a/net/batman-adv/multicast.c b/net/batman-adv/multicast.c new file mode 100644 index 000000000..b23845591 --- /dev/null +++ b/net/batman-adv/multicast.c @@ -0,0 +1,2317 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Linus Lüssing + */ + +#include "multicast.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bridge_loop_avoidance.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "send.h" +#include "soft-interface.h" +#include "translation-table.h" +#include "tvlv.h" + +static void batadv_mcast_mla_update(struct work_struct *work); + +/** + * batadv_mcast_start_timer() - schedule the multicast periodic worker + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_mcast_start_timer(struct batadv_priv *bat_priv) +{ + queue_delayed_work(batadv_event_workqueue, &bat_priv->mcast.work, + msecs_to_jiffies(BATADV_MCAST_WORK_PERIOD)); +} + +/** + * batadv_mcast_get_bridge() - get the bridge on top of the softif if it exists + * @soft_iface: netdev struct of the mesh interface + * + * If the given soft interface has a bridge on top then the refcount + * of the according net device is increased. + * + * Return: NULL if no such bridge exists. Otherwise the net device of the + * bridge. + */ +static struct net_device *batadv_mcast_get_bridge(struct net_device *soft_iface) +{ + struct net_device *upper = soft_iface; + + rcu_read_lock(); + do { + upper = netdev_master_upper_dev_get_rcu(upper); + } while (upper && !netif_is_bridge_master(upper)); + + dev_hold(upper); + rcu_read_unlock(); + + return upper; +} + +/** + * batadv_mcast_mla_rtr_flags_softif_get_ipv4() - get mcast router flags from + * node for IPv4 + * @dev: the interface to check + * + * Checks the presence of an IPv4 multicast router on this node. + * + * Caller needs to hold rcu read lock. + * + * Return: BATADV_NO_FLAGS if present, BATADV_MCAST_WANT_NO_RTR4 otherwise. + */ +static u8 batadv_mcast_mla_rtr_flags_softif_get_ipv4(struct net_device *dev) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + + if (in_dev && IN_DEV_MFORWARD(in_dev)) + return BATADV_NO_FLAGS; + else + return BATADV_MCAST_WANT_NO_RTR4; +} + +/** + * batadv_mcast_mla_rtr_flags_softif_get_ipv6() - get mcast router flags from + * node for IPv6 + * @dev: the interface to check + * + * Checks the presence of an IPv6 multicast router on this node. + * + * Caller needs to hold rcu read lock. + * + * Return: BATADV_NO_FLAGS if present, BATADV_MCAST_WANT_NO_RTR6 otherwise. + */ +#if IS_ENABLED(CONFIG_IPV6_MROUTE) +static u8 batadv_mcast_mla_rtr_flags_softif_get_ipv6(struct net_device *dev) +{ + struct inet6_dev *in6_dev = __in6_dev_get(dev); + + if (in6_dev && atomic_read(&in6_dev->cnf.mc_forwarding)) + return BATADV_NO_FLAGS; + else + return BATADV_MCAST_WANT_NO_RTR6; +} +#else +static inline u8 +batadv_mcast_mla_rtr_flags_softif_get_ipv6(struct net_device *dev) +{ + return BATADV_MCAST_WANT_NO_RTR6; +} +#endif + +/** + * batadv_mcast_mla_rtr_flags_softif_get() - get mcast router flags from node + * @bat_priv: the bat priv with all the soft interface information + * @bridge: bridge interface on top of the soft_iface if present, + * otherwise pass NULL + * + * Checks the presence of IPv4 and IPv6 multicast routers on this + * node. 
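+ *
+ * Unlike the two per-protocol helpers above, this function takes the RCU
+ * read lock itself, so the caller does not need to hold it.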
+ * + * Return: + * BATADV_NO_FLAGS: Both an IPv4 and IPv6 multicast router is present + * BATADV_MCAST_WANT_NO_RTR4: No IPv4 multicast router is present + * BATADV_MCAST_WANT_NO_RTR6: No IPv6 multicast router is present + * The former two OR'd: no multicast router is present + */ +static u8 batadv_mcast_mla_rtr_flags_softif_get(struct batadv_priv *bat_priv, + struct net_device *bridge) +{ + struct net_device *dev = bridge ? bridge : bat_priv->soft_iface; + u8 flags = BATADV_NO_FLAGS; + + rcu_read_lock(); + + flags |= batadv_mcast_mla_rtr_flags_softif_get_ipv4(dev); + flags |= batadv_mcast_mla_rtr_flags_softif_get_ipv6(dev); + + rcu_read_unlock(); + + return flags; +} + +/** + * batadv_mcast_mla_rtr_flags_bridge_get() - get mcast router flags from bridge + * @bat_priv: the bat priv with all the soft interface information + * @bridge: bridge interface on top of the soft_iface if present, + * otherwise pass NULL + * + * Checks the presence of IPv4 and IPv6 multicast routers behind a bridge. + * + * Return: + * BATADV_NO_FLAGS: Both an IPv4 and IPv6 multicast router is present + * BATADV_MCAST_WANT_NO_RTR4: No IPv4 multicast router is present + * BATADV_MCAST_WANT_NO_RTR6: No IPv6 multicast router is present + * The former two OR'd: no multicast router is present + */ +static u8 batadv_mcast_mla_rtr_flags_bridge_get(struct batadv_priv *bat_priv, + struct net_device *bridge) +{ + struct net_device *dev = bat_priv->soft_iface; + u8 flags = BATADV_NO_FLAGS; + + if (!bridge) + return BATADV_MCAST_WANT_NO_RTR4 | BATADV_MCAST_WANT_NO_RTR6; + + if (!br_multicast_has_router_adjacent(dev, ETH_P_IP)) + flags |= BATADV_MCAST_WANT_NO_RTR4; + if (!br_multicast_has_router_adjacent(dev, ETH_P_IPV6)) + flags |= BATADV_MCAST_WANT_NO_RTR6; + + return flags; +} + +/** + * batadv_mcast_mla_rtr_flags_get() - get multicast router flags + * @bat_priv: the bat priv with all the soft interface information + * @bridge: bridge interface on top of the soft_iface if present, + * otherwise pass NULL + * + * Checks the presence of IPv4 and IPv6 multicast routers on this + * node or behind its bridge. + * + * Return: + * BATADV_NO_FLAGS: Both an IPv4 and IPv6 multicast router is present + * BATADV_MCAST_WANT_NO_RTR4: No IPv4 multicast router is present + * BATADV_MCAST_WANT_NO_RTR6: No IPv6 multicast router is present + * The former two OR'd: no multicast router is present + */ +static u8 batadv_mcast_mla_rtr_flags_get(struct batadv_priv *bat_priv, + struct net_device *bridge) +{ + u8 flags = BATADV_MCAST_WANT_NO_RTR4 | BATADV_MCAST_WANT_NO_RTR6; + + flags &= batadv_mcast_mla_rtr_flags_softif_get(bat_priv, bridge); + flags &= batadv_mcast_mla_rtr_flags_bridge_get(bat_priv, bridge); + + return flags; +} + +/** + * batadv_mcast_mla_flags_get() - get the new multicast flags + * @bat_priv: the bat priv with all the soft interface information + * + * Return: A set of flags for the current/next TVLV, querier and + * bridge state. 
+ */ +static struct batadv_mcast_mla_flags +batadv_mcast_mla_flags_get(struct batadv_priv *bat_priv) +{ + struct net_device *dev = bat_priv->soft_iface; + struct batadv_mcast_querier_state *qr4, *qr6; + struct batadv_mcast_mla_flags mla_flags; + struct net_device *bridge; + + bridge = batadv_mcast_get_bridge(dev); + + memset(&mla_flags, 0, sizeof(mla_flags)); + mla_flags.enabled = 1; + mla_flags.tvlv_flags |= batadv_mcast_mla_rtr_flags_get(bat_priv, + bridge); + + if (!bridge) + return mla_flags; + + dev_put(bridge); + + mla_flags.bridged = 1; + qr4 = &mla_flags.querier_ipv4; + qr6 = &mla_flags.querier_ipv6; + + if (!IS_ENABLED(CONFIG_BRIDGE_IGMP_SNOOPING)) + pr_warn_once("No bridge IGMP snooping compiled - multicast optimizations disabled\n"); + + qr4->exists = br_multicast_has_querier_anywhere(dev, ETH_P_IP); + qr4->shadowing = br_multicast_has_querier_adjacent(dev, ETH_P_IP); + + qr6->exists = br_multicast_has_querier_anywhere(dev, ETH_P_IPV6); + qr6->shadowing = br_multicast_has_querier_adjacent(dev, ETH_P_IPV6); + + mla_flags.tvlv_flags |= BATADV_MCAST_WANT_ALL_UNSNOOPABLES; + + /* 1) If no querier exists at all, then multicast listeners on + * our local TT clients behind the bridge will keep silent. + * 2) If the selected querier is on one of our local TT clients, + * behind the bridge, then this querier might shadow multicast + * listeners on our local TT clients, behind this bridge. + * + * In both cases, we will signalize other batman nodes that + * we need all multicast traffic of the according protocol. + */ + if (!qr4->exists || qr4->shadowing) { + mla_flags.tvlv_flags |= BATADV_MCAST_WANT_ALL_IPV4; + mla_flags.tvlv_flags &= ~BATADV_MCAST_WANT_NO_RTR4; + } + + if (!qr6->exists || qr6->shadowing) { + mla_flags.tvlv_flags |= BATADV_MCAST_WANT_ALL_IPV6; + mla_flags.tvlv_flags &= ~BATADV_MCAST_WANT_NO_RTR6; + } + + return mla_flags; +} + +/** + * batadv_mcast_mla_is_duplicate() - check whether an address is in a list + * @mcast_addr: the multicast address to check + * @mcast_list: the list with multicast addresses to search in + * + * Return: true if the given address is already in the given list. + * Otherwise returns false. + */ +static bool batadv_mcast_mla_is_duplicate(u8 *mcast_addr, + struct hlist_head *mcast_list) +{ + struct batadv_hw_addr *mcast_entry; + + hlist_for_each_entry(mcast_entry, mcast_list, list) + if (batadv_compare_eth(mcast_entry->addr, mcast_addr)) + return true; + + return false; +} + +/** + * batadv_mcast_mla_softif_get_ipv4() - get softif IPv4 multicast listeners + * @dev: the device to collect multicast addresses from + * @mcast_list: a list to put found addresses into + * @flags: flags indicating the new multicast state + * + * Collects multicast addresses of IPv4 multicast listeners residing + * on this kernel on the given soft interface, dev, in + * the given mcast_list. In general, multicast listeners provided by + * your multicast receiving applications run directly on this node. + * + * Return: -ENOMEM on memory allocation error or the number of + * items added to the mcast_list otherwise. 
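+ *
+ * For illustration, with both BATADV_MCAST_WANT_ALL_UNSNOOPABLES and
+ * BATADV_MCAST_WANT_NO_RTR4 set in flags->tvlv_flags, the loop below
+ * treats two example addresses as follows:
+ *
+ *   224.0.0.1 -> skipped: link-local, covered by the unsnoopables flag
+ *   239.1.2.3 -> mapped via ip_eth_mc_map() to 01:00:5e:01:02:03 and
+ *                added to mcast_list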
+ */ +static int +batadv_mcast_mla_softif_get_ipv4(struct net_device *dev, + struct hlist_head *mcast_list, + struct batadv_mcast_mla_flags *flags) +{ + struct batadv_hw_addr *new; + struct in_device *in_dev; + u8 mcast_addr[ETH_ALEN]; + struct ip_mc_list *pmc; + int ret = 0; + + if (flags->tvlv_flags & BATADV_MCAST_WANT_ALL_IPV4) + return 0; + + rcu_read_lock(); + + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) { + rcu_read_unlock(); + return 0; + } + + for (pmc = rcu_dereference(in_dev->mc_list); pmc; + pmc = rcu_dereference(pmc->next_rcu)) { + if (flags->tvlv_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES && + ipv4_is_local_multicast(pmc->multiaddr)) + continue; + + if (!(flags->tvlv_flags & BATADV_MCAST_WANT_NO_RTR4) && + !ipv4_is_local_multicast(pmc->multiaddr)) + continue; + + ip_eth_mc_map(pmc->multiaddr, mcast_addr); + + if (batadv_mcast_mla_is_duplicate(mcast_addr, mcast_list)) + continue; + + new = kmalloc(sizeof(*new), GFP_ATOMIC); + if (!new) { + ret = -ENOMEM; + break; + } + + ether_addr_copy(new->addr, mcast_addr); + hlist_add_head(&new->list, mcast_list); + ret++; + } + rcu_read_unlock(); + + return ret; +} + +/** + * batadv_mcast_mla_softif_get_ipv6() - get softif IPv6 multicast listeners + * @dev: the device to collect multicast addresses from + * @mcast_list: a list to put found addresses into + * @flags: flags indicating the new multicast state + * + * Collects multicast addresses of IPv6 multicast listeners residing + * on this kernel on the given soft interface, dev, in + * the given mcast_list. In general, multicast listeners provided by + * your multicast receiving applications run directly on this node. + * + * Return: -ENOMEM on memory allocation error or the number of + * items added to the mcast_list otherwise. + */ +#if IS_ENABLED(CONFIG_IPV6) +static int +batadv_mcast_mla_softif_get_ipv6(struct net_device *dev, + struct hlist_head *mcast_list, + struct batadv_mcast_mla_flags *flags) +{ + struct batadv_hw_addr *new; + struct inet6_dev *in6_dev; + u8 mcast_addr[ETH_ALEN]; + struct ifmcaddr6 *pmc6; + int ret = 0; + + if (flags->tvlv_flags & BATADV_MCAST_WANT_ALL_IPV6) + return 0; + + rcu_read_lock(); + + in6_dev = __in6_dev_get(dev); + if (!in6_dev) { + rcu_read_unlock(); + return 0; + } + + for (pmc6 = rcu_dereference(in6_dev->mc_list); + pmc6; + pmc6 = rcu_dereference(pmc6->next)) { + if (IPV6_ADDR_MC_SCOPE(&pmc6->mca_addr) < + IPV6_ADDR_SCOPE_LINKLOCAL) + continue; + + if (flags->tvlv_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES && + ipv6_addr_is_ll_all_nodes(&pmc6->mca_addr)) + continue; + + if (!(flags->tvlv_flags & BATADV_MCAST_WANT_NO_RTR6) && + IPV6_ADDR_MC_SCOPE(&pmc6->mca_addr) > + IPV6_ADDR_SCOPE_LINKLOCAL) + continue; + + ipv6_eth_mc_map(&pmc6->mca_addr, mcast_addr); + + if (batadv_mcast_mla_is_duplicate(mcast_addr, mcast_list)) + continue; + + new = kmalloc(sizeof(*new), GFP_ATOMIC); + if (!new) { + ret = -ENOMEM; + break; + } + + ether_addr_copy(new->addr, mcast_addr); + hlist_add_head(&new->list, mcast_list); + ret++; + } + rcu_read_unlock(); + + return ret; +} +#else +static inline int +batadv_mcast_mla_softif_get_ipv6(struct net_device *dev, + struct hlist_head *mcast_list, + struct batadv_mcast_mla_flags *flags) +{ + return 0; +} +#endif + +/** + * batadv_mcast_mla_softif_get() - get softif multicast listeners + * @dev: the device to collect multicast addresses from + * @mcast_list: a list to put found addresses into + * @flags: flags indicating the new multicast state + * + * Collects multicast addresses of multicast listeners residing + * on this 
kernel on the given soft interface, dev, in + * the given mcast_list. In general, multicast listeners provided by + * your multicast receiving applications run directly on this node. + * + * If there is a bridge interface on top of dev, collect from that one + * instead. Just like with IP addresses and routes, multicast listeners + * will(/should) register to the bridge interface instead of an + * enslaved bat0. + * + * Return: -ENOMEM on memory allocation error or the number of + * items added to the mcast_list otherwise. + */ +static int +batadv_mcast_mla_softif_get(struct net_device *dev, + struct hlist_head *mcast_list, + struct batadv_mcast_mla_flags *flags) +{ + struct net_device *bridge = batadv_mcast_get_bridge(dev); + int ret4, ret6 = 0; + + if (bridge) + dev = bridge; + + ret4 = batadv_mcast_mla_softif_get_ipv4(dev, mcast_list, flags); + if (ret4 < 0) + goto out; + + ret6 = batadv_mcast_mla_softif_get_ipv6(dev, mcast_list, flags); + if (ret6 < 0) { + ret4 = 0; + goto out; + } + +out: + dev_put(bridge); + + return ret4 + ret6; +} + +/** + * batadv_mcast_mla_br_addr_cpy() - copy a bridge multicast address + * @dst: destination to write to - a multicast MAC address + * @src: source to read from - a multicast IP address + * + * Converts a given multicast IPv4/IPv6 address from a bridge + * to its matching multicast MAC address and copies it into the given + * destination buffer. + * + * Caller needs to make sure the destination buffer can hold + * at least ETH_ALEN bytes. + */ +static void batadv_mcast_mla_br_addr_cpy(char *dst, const struct br_ip *src) +{ + if (src->proto == htons(ETH_P_IP)) + ip_eth_mc_map(src->dst.ip4, dst); +#if IS_ENABLED(CONFIG_IPV6) + else if (src->proto == htons(ETH_P_IPV6)) + ipv6_eth_mc_map(&src->dst.ip6, dst); +#endif + else + eth_zero_addr(dst); +} + +/** + * batadv_mcast_mla_bridge_get() - get bridged-in multicast listeners + * @dev: a bridge slave whose bridge to collect multicast addresses from + * @mcast_list: a list to put found addresses into + * @flags: flags indicating the new multicast state + * + * Collects multicast addresses of multicast listeners residing + * on foreign, non-mesh devices which we gave access to our mesh via + * a bridge on top of the given soft interface, dev, in the given + * mcast_list. + * + * Return: -ENOMEM on memory allocation error or the number of + * items added to the mcast_list otherwise. 
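+ *
+ * For illustration, each br_ip entry returned by
+ * br_multicast_list_adjacent() that passes the filters below is turned
+ * into a multicast MAC address by batadv_mcast_mla_br_addr_cpy(), e.g.:
+ *
+ *   239.1.2.3 -> 01:00:5e:01:02:03
+ *   ff05::fb  -> 33:33:00:00:00:fb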
+ */ +static int batadv_mcast_mla_bridge_get(struct net_device *dev, + struct hlist_head *mcast_list, + struct batadv_mcast_mla_flags *flags) +{ + struct list_head bridge_mcast_list = LIST_HEAD_INIT(bridge_mcast_list); + struct br_ip_list *br_ip_entry, *tmp; + u8 tvlv_flags = flags->tvlv_flags; + struct batadv_hw_addr *new; + u8 mcast_addr[ETH_ALEN]; + int ret; + + /* we don't need to detect these devices/listeners, the IGMP/MLD + * snooping code of the Linux bridge already does that for us + */ + ret = br_multicast_list_adjacent(dev, &bridge_mcast_list); + if (ret < 0) + goto out; + + list_for_each_entry(br_ip_entry, &bridge_mcast_list, list) { + if (br_ip_entry->addr.proto == htons(ETH_P_IP)) { + if (tvlv_flags & BATADV_MCAST_WANT_ALL_IPV4) + continue; + + if (tvlv_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES && + ipv4_is_local_multicast(br_ip_entry->addr.dst.ip4)) + continue; + + if (!(tvlv_flags & BATADV_MCAST_WANT_NO_RTR4) && + !ipv4_is_local_multicast(br_ip_entry->addr.dst.ip4)) + continue; + } + +#if IS_ENABLED(CONFIG_IPV6) + if (br_ip_entry->addr.proto == htons(ETH_P_IPV6)) { + if (tvlv_flags & BATADV_MCAST_WANT_ALL_IPV6) + continue; + + if (tvlv_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES && + ipv6_addr_is_ll_all_nodes(&br_ip_entry->addr.dst.ip6)) + continue; + + if (!(tvlv_flags & BATADV_MCAST_WANT_NO_RTR6) && + IPV6_ADDR_MC_SCOPE(&br_ip_entry->addr.dst.ip6) > + IPV6_ADDR_SCOPE_LINKLOCAL) + continue; + } +#endif + + batadv_mcast_mla_br_addr_cpy(mcast_addr, &br_ip_entry->addr); + if (batadv_mcast_mla_is_duplicate(mcast_addr, mcast_list)) + continue; + + new = kmalloc(sizeof(*new), GFP_ATOMIC); + if (!new) { + ret = -ENOMEM; + break; + } + + ether_addr_copy(new->addr, mcast_addr); + hlist_add_head(&new->list, mcast_list); + } + +out: + list_for_each_entry_safe(br_ip_entry, tmp, &bridge_mcast_list, list) { + list_del(&br_ip_entry->list); + kfree(br_ip_entry); + } + + return ret; +} + +/** + * batadv_mcast_mla_list_free() - free a list of multicast addresses + * @mcast_list: the list to free + * + * Removes and frees all items in the given mcast_list. + */ +static void batadv_mcast_mla_list_free(struct hlist_head *mcast_list) +{ + struct batadv_hw_addr *mcast_entry; + struct hlist_node *tmp; + + hlist_for_each_entry_safe(mcast_entry, tmp, mcast_list, list) { + hlist_del(&mcast_entry->list); + kfree(mcast_entry); + } +} + +/** + * batadv_mcast_mla_tt_retract() - clean up multicast listener announcements + * @bat_priv: the bat priv with all the soft interface information + * @mcast_list: a list of addresses which should _not_ be removed + * + * Retracts the announcement of any multicast listener from the + * translation table except the ones listed in the given mcast_list. + * + * If mcast_list is NULL then all are retracted. 
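+ *
+ * For illustration, __batadv_mcast_mla_update() pairs this with
+ * batadv_mcast_mla_tt_add(), roughly:
+ *
+ *   spin_lock(&bat_priv->mcast.mla_lock);
+ *   batadv_mcast_mla_tt_retract(bat_priv, &mcast_list);
+ *   batadv_mcast_mla_tt_add(bat_priv, &mcast_list);
+ *   batadv_mcast_mla_flags_update(bat_priv, &flags);
+ *   spin_unlock(&bat_priv->mcast.mla_lock);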
+ */ +static void batadv_mcast_mla_tt_retract(struct batadv_priv *bat_priv, + struct hlist_head *mcast_list) +{ + struct batadv_hw_addr *mcast_entry; + struct hlist_node *tmp; + + hlist_for_each_entry_safe(mcast_entry, tmp, &bat_priv->mcast.mla_list, + list) { + if (mcast_list && + batadv_mcast_mla_is_duplicate(mcast_entry->addr, + mcast_list)) + continue; + + batadv_tt_local_remove(bat_priv, mcast_entry->addr, + BATADV_NO_FLAGS, + "mcast TT outdated", false); + + hlist_del(&mcast_entry->list); + kfree(mcast_entry); + } +} + +/** + * batadv_mcast_mla_tt_add() - add multicast listener announcements + * @bat_priv: the bat priv with all the soft interface information + * @mcast_list: a list of addresses which are going to get added + * + * Adds multicast listener announcements from the given mcast_list to the + * translation table if they have not been added yet. + */ +static void batadv_mcast_mla_tt_add(struct batadv_priv *bat_priv, + struct hlist_head *mcast_list) +{ + struct batadv_hw_addr *mcast_entry; + struct hlist_node *tmp; + + if (!mcast_list) + return; + + hlist_for_each_entry_safe(mcast_entry, tmp, mcast_list, list) { + if (batadv_mcast_mla_is_duplicate(mcast_entry->addr, + &bat_priv->mcast.mla_list)) + continue; + + if (!batadv_tt_local_add(bat_priv->soft_iface, + mcast_entry->addr, BATADV_NO_FLAGS, + BATADV_NULL_IFINDEX, BATADV_NO_MARK)) + continue; + + hlist_del(&mcast_entry->list); + hlist_add_head(&mcast_entry->list, &bat_priv->mcast.mla_list); + } +} + +/** + * batadv_mcast_querier_log() - debug output regarding the querier status on + * link + * @bat_priv: the bat priv with all the soft interface information + * @str_proto: a string for the querier protocol (e.g. "IGMP" or "MLD") + * @old_state: the previous querier state on our link + * @new_state: the new querier state on our link + * + * Outputs debug messages to the logging facility with log level 'mcast' + * regarding changes to the querier status on the link which are relevant + * to our multicast optimizations. + * + * Usually this is about whether a querier appeared or vanished in + * our mesh or whether the querier is in the suboptimal position of being + * behind our local bridge segment: Snooping switches will directly + * forward listener reports to the querier, therefore batman-adv and + * the bridge will potentially not see these listeners - the querier is + * potentially shadowing listeners from us then. + * + * This is only interesting for nodes with a bridge on top of their + * soft interface. 
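+ *
+ * For illustration, with str_proto set to "IGMP" this emits messages
+ * such as "IGMP Querier appeared" or
+ * "IGMP Querier disappeared - multicast optimizations disabled".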
+ */ +static void +batadv_mcast_querier_log(struct batadv_priv *bat_priv, char *str_proto, + struct batadv_mcast_querier_state *old_state, + struct batadv_mcast_querier_state *new_state) +{ + if (!old_state->exists && new_state->exists) + batadv_info(bat_priv->soft_iface, "%s Querier appeared\n", + str_proto); + else if (old_state->exists && !new_state->exists) + batadv_info(bat_priv->soft_iface, + "%s Querier disappeared - multicast optimizations disabled\n", + str_proto); + else if (!bat_priv->mcast.mla_flags.bridged && !new_state->exists) + batadv_info(bat_priv->soft_iface, + "No %s Querier present - multicast optimizations disabled\n", + str_proto); + + if (new_state->exists) { + if ((!old_state->shadowing && new_state->shadowing) || + (!old_state->exists && new_state->shadowing)) + batadv_dbg(BATADV_DBG_MCAST, bat_priv, + "%s Querier is behind our bridged segment: Might shadow listeners\n", + str_proto); + else if (old_state->shadowing && !new_state->shadowing) + batadv_dbg(BATADV_DBG_MCAST, bat_priv, + "%s Querier is not behind our bridged segment\n", + str_proto); + } +} + +/** + * batadv_mcast_bridge_log() - debug output for topology changes in bridged + * setups + * @bat_priv: the bat priv with all the soft interface information + * @new_flags: flags indicating the new multicast state + * + * If no bridges are ever used on this node, then this function does nothing. + * + * Otherwise this function outputs debug information to the 'mcast' log level + * which might be relevant to our multicast optimizations. + * + * More precisely, it outputs information when a bridge interface is added or + * removed from a soft interface. And when a bridge is present, it further + * outputs information about the querier state which is relevant for the + * multicast flags this node is going to set. + */ +static void +batadv_mcast_bridge_log(struct batadv_priv *bat_priv, + struct batadv_mcast_mla_flags *new_flags) +{ + struct batadv_mcast_mla_flags *old_flags = &bat_priv->mcast.mla_flags; + + if (!old_flags->bridged && new_flags->bridged) + batadv_dbg(BATADV_DBG_MCAST, bat_priv, + "Bridge added: Setting Unsnoopables(U)-flag\n"); + else if (old_flags->bridged && !new_flags->bridged) + batadv_dbg(BATADV_DBG_MCAST, bat_priv, + "Bridge removed: Unsetting Unsnoopables(U)-flag\n"); + + if (new_flags->bridged) { + batadv_mcast_querier_log(bat_priv, "IGMP", + &old_flags->querier_ipv4, + &new_flags->querier_ipv4); + batadv_mcast_querier_log(bat_priv, "MLD", + &old_flags->querier_ipv6, + &new_flags->querier_ipv6); + } +} + +/** + * batadv_mcast_flags_log() - output debug information about mcast flag changes + * @bat_priv: the bat priv with all the soft interface information + * @flags: TVLV flags indicating the new multicast state + * + * Whenever the multicast TVLV flags this node announces change, this function + * should be used to notify userspace about the change. + */ +static void batadv_mcast_flags_log(struct batadv_priv *bat_priv, u8 flags) +{ + bool old_enabled = bat_priv->mcast.mla_flags.enabled; + u8 old_flags = bat_priv->mcast.mla_flags.tvlv_flags; + char str_old_flags[] = "[.... . ]"; + + sprintf(str_old_flags, "[%c%c%c%s%s]", + (old_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES) ? 'U' : '.', + (old_flags & BATADV_MCAST_WANT_ALL_IPV4) ? '4' : '.', + (old_flags & BATADV_MCAST_WANT_ALL_IPV6) ? '6' : '.', + !(old_flags & BATADV_MCAST_WANT_NO_RTR4) ? "R4" : ". ", + !(old_flags & BATADV_MCAST_WANT_NO_RTR6) ? "R6" : ". 
"); + + batadv_dbg(BATADV_DBG_MCAST, bat_priv, + "Changing multicast flags from '%s' to '[%c%c%c%s%s]'\n", + old_enabled ? str_old_flags : "", + (flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES) ? 'U' : '.', + (flags & BATADV_MCAST_WANT_ALL_IPV4) ? '4' : '.', + (flags & BATADV_MCAST_WANT_ALL_IPV6) ? '6' : '.', + !(flags & BATADV_MCAST_WANT_NO_RTR4) ? "R4" : ". ", + !(flags & BATADV_MCAST_WANT_NO_RTR6) ? "R6" : ". "); +} + +/** + * batadv_mcast_mla_flags_update() - update multicast flags + * @bat_priv: the bat priv with all the soft interface information + * @flags: flags indicating the new multicast state + * + * Updates the own multicast tvlv with our current multicast related settings, + * capabilities and inabilities. + */ +static void +batadv_mcast_mla_flags_update(struct batadv_priv *bat_priv, + struct batadv_mcast_mla_flags *flags) +{ + struct batadv_tvlv_mcast_data mcast_data; + + if (!memcmp(flags, &bat_priv->mcast.mla_flags, sizeof(*flags))) + return; + + batadv_mcast_bridge_log(bat_priv, flags); + batadv_mcast_flags_log(bat_priv, flags->tvlv_flags); + + mcast_data.flags = flags->tvlv_flags; + memset(mcast_data.reserved, 0, sizeof(mcast_data.reserved)); + + batadv_tvlv_container_register(bat_priv, BATADV_TVLV_MCAST, 2, + &mcast_data, sizeof(mcast_data)); + + bat_priv->mcast.mla_flags = *flags; +} + +/** + * __batadv_mcast_mla_update() - update the own MLAs + * @bat_priv: the bat priv with all the soft interface information + * + * Updates the own multicast listener announcements in the translation + * table as well as the own, announced multicast tvlv container. + * + * Note that non-conflicting reads and writes to bat_priv->mcast.mla_list + * in batadv_mcast_mla_tt_retract() and batadv_mcast_mla_tt_add() are + * ensured by the non-parallel execution of the worker this function + * belongs to. + */ +static void __batadv_mcast_mla_update(struct batadv_priv *bat_priv) +{ + struct net_device *soft_iface = bat_priv->soft_iface; + struct hlist_head mcast_list = HLIST_HEAD_INIT; + struct batadv_mcast_mla_flags flags; + int ret; + + flags = batadv_mcast_mla_flags_get(bat_priv); + + ret = batadv_mcast_mla_softif_get(soft_iface, &mcast_list, &flags); + if (ret < 0) + goto out; + + ret = batadv_mcast_mla_bridge_get(soft_iface, &mcast_list, &flags); + if (ret < 0) + goto out; + + spin_lock(&bat_priv->mcast.mla_lock); + batadv_mcast_mla_tt_retract(bat_priv, &mcast_list); + batadv_mcast_mla_tt_add(bat_priv, &mcast_list); + batadv_mcast_mla_flags_update(bat_priv, &flags); + spin_unlock(&bat_priv->mcast.mla_lock); + +out: + batadv_mcast_mla_list_free(&mcast_list); +} + +/** + * batadv_mcast_mla_update() - update the own MLAs + * @work: kernel work struct + * + * Updates the own multicast listener announcements in the translation + * table as well as the own, announced multicast tvlv container. + * + * In the end, reschedules the work timer. + */ +static void batadv_mcast_mla_update(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv_mcast *priv_mcast; + struct batadv_priv *bat_priv; + + delayed_work = to_delayed_work(work); + priv_mcast = container_of(delayed_work, struct batadv_priv_mcast, work); + bat_priv = container_of(priv_mcast, struct batadv_priv, mcast); + + __batadv_mcast_mla_update(bat_priv); + batadv_mcast_start_timer(bat_priv); +} + +/** + * batadv_mcast_is_report_ipv4() - check for IGMP reports + * @skb: the ethernet frame destined for the mesh + * + * This call might reallocate skb data. + * + * Checks whether the given frame is a valid IGMP report. 
+ * + * Return: If so then true, otherwise false. + */ +static bool batadv_mcast_is_report_ipv4(struct sk_buff *skb) +{ + if (ip_mc_check_igmp(skb) < 0) + return false; + + switch (igmp_hdr(skb)->type) { + case IGMP_HOST_MEMBERSHIP_REPORT: + case IGMPV2_HOST_MEMBERSHIP_REPORT: + case IGMPV3_HOST_MEMBERSHIP_REPORT: + return true; + } + + return false; +} + +/** + * batadv_mcast_forw_mode_check_ipv4() - check for optimized forwarding + * potential + * @bat_priv: the bat priv with all the soft interface information + * @skb: the IPv4 packet to check + * @is_unsnoopable: stores whether the destination is snoopable + * @is_routable: stores whether the destination is routable + * + * Checks whether the given IPv4 packet has the potential to be forwarded with a + * mode more optimal than classic flooding. + * + * Return: If so then 0. Otherwise -EINVAL or -ENOMEM in case of memory + * allocation failure. + */ +static int batadv_mcast_forw_mode_check_ipv4(struct batadv_priv *bat_priv, + struct sk_buff *skb, + bool *is_unsnoopable, + int *is_routable) +{ + struct iphdr *iphdr; + + /* We might fail due to out-of-memory -> drop it */ + if (!pskb_may_pull(skb, sizeof(struct ethhdr) + sizeof(*iphdr))) + return -ENOMEM; + + if (batadv_mcast_is_report_ipv4(skb)) + return -EINVAL; + + iphdr = ip_hdr(skb); + + /* link-local multicast listeners behind a bridge are + * not snoopable (see RFC4541, section 2.1.2.2) + */ + if (ipv4_is_local_multicast(iphdr->daddr)) + *is_unsnoopable = true; + else + *is_routable = ETH_P_IP; + + return 0; +} + +/** + * batadv_mcast_is_report_ipv6() - check for MLD reports + * @skb: the ethernet frame destined for the mesh + * + * This call might reallocate skb data. + * + * Checks whether the given frame is a valid MLD report. + * + * Return: If so then true, otherwise false. + */ +static bool batadv_mcast_is_report_ipv6(struct sk_buff *skb) +{ + if (ipv6_mc_check_mld(skb) < 0) + return false; + + switch (icmp6_hdr(skb)->icmp6_type) { + case ICMPV6_MGM_REPORT: + case ICMPV6_MLD2_REPORT: + return true; + } + + return false; +} + +/** + * batadv_mcast_forw_mode_check_ipv6() - check for optimized forwarding + * potential + * @bat_priv: the bat priv with all the soft interface information + * @skb: the IPv6 packet to check + * @is_unsnoopable: stores whether the destination is snoopable + * @is_routable: stores whether the destination is routable + * + * Checks whether the given IPv6 packet has the potential to be forwarded with a + * mode more optimal than classic flooding. + * + * Return: If so then 0. 
Otherwise -EINVAL is or -ENOMEM if we are out of memory + */ +static int batadv_mcast_forw_mode_check_ipv6(struct batadv_priv *bat_priv, + struct sk_buff *skb, + bool *is_unsnoopable, + int *is_routable) +{ + struct ipv6hdr *ip6hdr; + + /* We might fail due to out-of-memory -> drop it */ + if (!pskb_may_pull(skb, sizeof(struct ethhdr) + sizeof(*ip6hdr))) + return -ENOMEM; + + if (batadv_mcast_is_report_ipv6(skb)) + return -EINVAL; + + ip6hdr = ipv6_hdr(skb); + + if (IPV6_ADDR_MC_SCOPE(&ip6hdr->daddr) < IPV6_ADDR_SCOPE_LINKLOCAL) + return -EINVAL; + + /* link-local-all-nodes multicast listeners behind a bridge are + * not snoopable (see RFC4541, section 3, paragraph 3) + */ + if (ipv6_addr_is_ll_all_nodes(&ip6hdr->daddr)) + *is_unsnoopable = true; + else if (IPV6_ADDR_MC_SCOPE(&ip6hdr->daddr) > IPV6_ADDR_SCOPE_LINKLOCAL) + *is_routable = ETH_P_IPV6; + + return 0; +} + +/** + * batadv_mcast_forw_mode_check() - check for optimized forwarding potential + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast frame to check + * @is_unsnoopable: stores whether the destination is snoopable + * @is_routable: stores whether the destination is routable + * + * Checks whether the given multicast ethernet frame has the potential to be + * forwarded with a mode more optimal than classic flooding. + * + * Return: If so then 0. Otherwise -EINVAL is or -ENOMEM if we are out of memory + */ +static int batadv_mcast_forw_mode_check(struct batadv_priv *bat_priv, + struct sk_buff *skb, + bool *is_unsnoopable, + int *is_routable) +{ + struct ethhdr *ethhdr = eth_hdr(skb); + + if (!atomic_read(&bat_priv->multicast_mode)) + return -EINVAL; + + switch (ntohs(ethhdr->h_proto)) { + case ETH_P_IP: + return batadv_mcast_forw_mode_check_ipv4(bat_priv, skb, + is_unsnoopable, + is_routable); + case ETH_P_IPV6: + if (!IS_ENABLED(CONFIG_IPV6)) + return -EINVAL; + + return batadv_mcast_forw_mode_check_ipv6(bat_priv, skb, + is_unsnoopable, + is_routable); + default: + return -EINVAL; + } +} + +/** + * batadv_mcast_forw_want_all_ip_count() - count nodes with unspecific mcast + * interest + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: ethernet header of a packet + * + * Return: the number of nodes which want all IPv4 multicast traffic if the + * given ethhdr is from an IPv4 packet or the number of nodes which want all + * IPv6 traffic if it matches an IPv6 packet. + */ +static int batadv_mcast_forw_want_all_ip_count(struct batadv_priv *bat_priv, + struct ethhdr *ethhdr) +{ + switch (ntohs(ethhdr->h_proto)) { + case ETH_P_IP: + return atomic_read(&bat_priv->mcast.num_want_all_ipv4); + case ETH_P_IPV6: + return atomic_read(&bat_priv->mcast.num_want_all_ipv6); + default: + /* we shouldn't be here... */ + return 0; + } +} + +/** + * batadv_mcast_forw_rtr_count() - count nodes with a multicast router + * @bat_priv: the bat priv with all the soft interface information + * @protocol: the ethernet protocol type to count multicast routers for + * + * Return: the number of nodes which want all routable IPv4 multicast traffic + * if the protocol is ETH_P_IP or the number of nodes which want all routable + * IPv6 traffic if the protocol is ETH_P_IPV6. Otherwise returns 0. 
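+ *
+ * For illustration, @protocol is the value stored in is_routable by the
+ * batadv_mcast_forw_mode_check_*() helpers above, i.e. ETH_P_IP or
+ * ETH_P_IPV6 for a routable destination; any other value falls into the
+ * default branch and yields 0.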
+ */ + +static int batadv_mcast_forw_rtr_count(struct batadv_priv *bat_priv, + int protocol) +{ + switch (protocol) { + case ETH_P_IP: + return atomic_read(&bat_priv->mcast.num_want_all_rtr4); + case ETH_P_IPV6: + return atomic_read(&bat_priv->mcast.num_want_all_rtr6); + default: + return 0; + } +} + +/** + * batadv_mcast_forw_tt_node_get() - get a multicast tt node + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: the ether header containing the multicast destination + * + * Return: an orig_node matching the multicast address provided by ethhdr + * via a translation table lookup. This increases the returned nodes refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_tt_node_get(struct batadv_priv *bat_priv, + struct ethhdr *ethhdr) +{ + return batadv_transtable_search(bat_priv, NULL, ethhdr->h_dest, + BATADV_NO_FLAGS); +} + +/** + * batadv_mcast_forw_ipv4_node_get() - get a node with an ipv4 flag + * @bat_priv: the bat priv with all the soft interface information + * + * Return: an orig_node which has the BATADV_MCAST_WANT_ALL_IPV4 flag set and + * increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_ipv4_node_get(struct batadv_priv *bat_priv) +{ + struct batadv_orig_node *tmp_orig_node, *orig_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_orig_node, + &bat_priv->mcast.want_all_ipv4_list, + mcast_want_all_ipv4_node) { + if (!kref_get_unless_zero(&tmp_orig_node->refcount)) + continue; + + orig_node = tmp_orig_node; + break; + } + rcu_read_unlock(); + + return orig_node; +} + +/** + * batadv_mcast_forw_ipv6_node_get() - get a node with an ipv6 flag + * @bat_priv: the bat priv with all the soft interface information + * + * Return: an orig_node which has the BATADV_MCAST_WANT_ALL_IPV6 flag set + * and increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_ipv6_node_get(struct batadv_priv *bat_priv) +{ + struct batadv_orig_node *tmp_orig_node, *orig_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_orig_node, + &bat_priv->mcast.want_all_ipv6_list, + mcast_want_all_ipv6_node) { + if (!kref_get_unless_zero(&tmp_orig_node->refcount)) + continue; + + orig_node = tmp_orig_node; + break; + } + rcu_read_unlock(); + + return orig_node; +} + +/** + * batadv_mcast_forw_ip_node_get() - get a node with an ipv4/ipv6 flag + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: an ethernet header to determine the protocol family from + * + * Return: an orig_node which has the BATADV_MCAST_WANT_ALL_IPV4 or + * BATADV_MCAST_WANT_ALL_IPV6 flag, depending on the provided ethhdr, sets and + * increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_ip_node_get(struct batadv_priv *bat_priv, + struct ethhdr *ethhdr) +{ + switch (ntohs(ethhdr->h_proto)) { + case ETH_P_IP: + return batadv_mcast_forw_ipv4_node_get(bat_priv); + case ETH_P_IPV6: + return batadv_mcast_forw_ipv6_node_get(bat_priv); + default: + /* we shouldn't be here... */ + return NULL; + } +} + +/** + * batadv_mcast_forw_unsnoop_node_get() - get a node with an unsnoopable flag + * @bat_priv: the bat priv with all the soft interface information + * + * Return: an orig_node which has the BATADV_MCAST_WANT_ALL_UNSNOOPABLES flag + * set and increases its refcount. 
+ */ +static struct batadv_orig_node * +batadv_mcast_forw_unsnoop_node_get(struct batadv_priv *bat_priv) +{ + struct batadv_orig_node *tmp_orig_node, *orig_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_orig_node, + &bat_priv->mcast.want_all_unsnoopables_list, + mcast_want_all_unsnoopables_node) { + if (!kref_get_unless_zero(&tmp_orig_node->refcount)) + continue; + + orig_node = tmp_orig_node; + break; + } + rcu_read_unlock(); + + return orig_node; +} + +/** + * batadv_mcast_forw_rtr4_node_get() - get a node with an ipv4 mcast router flag + * @bat_priv: the bat priv with all the soft interface information + * + * Return: an orig_node which has the BATADV_MCAST_WANT_NO_RTR4 flag unset and + * increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_rtr4_node_get(struct batadv_priv *bat_priv) +{ + struct batadv_orig_node *tmp_orig_node, *orig_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_orig_node, + &bat_priv->mcast.want_all_rtr4_list, + mcast_want_all_rtr4_node) { + if (!kref_get_unless_zero(&tmp_orig_node->refcount)) + continue; + + orig_node = tmp_orig_node; + break; + } + rcu_read_unlock(); + + return orig_node; +} + +/** + * batadv_mcast_forw_rtr6_node_get() - get a node with an ipv6 mcast router flag + * @bat_priv: the bat priv with all the soft interface information + * + * Return: an orig_node which has the BATADV_MCAST_WANT_NO_RTR6 flag unset + * and increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_rtr6_node_get(struct batadv_priv *bat_priv) +{ + struct batadv_orig_node *tmp_orig_node, *orig_node = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_orig_node, + &bat_priv->mcast.want_all_rtr6_list, + mcast_want_all_rtr6_node) { + if (!kref_get_unless_zero(&tmp_orig_node->refcount)) + continue; + + orig_node = tmp_orig_node; + break; + } + rcu_read_unlock(); + + return orig_node; +} + +/** + * batadv_mcast_forw_rtr_node_get() - get a node with an ipv4/ipv6 router flag + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: an ethernet header to determine the protocol family from + * + * Return: an orig_node which has no BATADV_MCAST_WANT_NO_RTR4 or + * BATADV_MCAST_WANT_NO_RTR6 flag, depending on the provided ethhdr, set and + * increases its refcount. + */ +static struct batadv_orig_node * +batadv_mcast_forw_rtr_node_get(struct batadv_priv *bat_priv, + struct ethhdr *ethhdr) +{ + switch (ntohs(ethhdr->h_proto)) { + case ETH_P_IP: + return batadv_mcast_forw_rtr4_node_get(bat_priv); + case ETH_P_IPV6: + return batadv_mcast_forw_rtr6_node_get(bat_priv); + default: + /* we shouldn't be here... */ + return NULL; + } +} + +/** + * batadv_mcast_forw_mode() - check on how to forward a multicast packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: The multicast packet to check + * @orig: an originator to be set to forward the skb to + * @is_routable: stores whether the destination is routable + * + * Return: the forwarding mode as enum batadv_forw_mode and in case of + * BATADV_FORW_SINGLE set the orig to the single originator the skb + * should be forwarded to. 
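+ *
+ * For illustration, with total_count being the sum of the matching TT,
+ * want-all, unsnoopable and router counts, the decision below boils
+ * down to:
+ *
+ *   total_count == 0                  -> BATADV_FORW_NONE
+ *   total_count == 1                  -> BATADV_FORW_SINGLE (orig set),
+ *                                        or BATADV_FORW_NONE if no node
+ *                                        could be looked up
+ *   no unsnoopable listeners and
+ *   total_count <= multicast_fanout   -> BATADV_FORW_SOME
+ *   otherwise                         -> BATADV_FORW_ALL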
+ */ +enum batadv_forw_mode +batadv_mcast_forw_mode(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_orig_node **orig, int *is_routable) +{ + int ret, tt_count, ip_count, unsnoop_count, total_count; + bool is_unsnoopable = false; + unsigned int mcast_fanout; + struct ethhdr *ethhdr; + int rtr_count = 0; + + ret = batadv_mcast_forw_mode_check(bat_priv, skb, &is_unsnoopable, + is_routable); + if (ret == -ENOMEM) + return BATADV_FORW_NONE; + else if (ret < 0) + return BATADV_FORW_ALL; + + ethhdr = eth_hdr(skb); + + tt_count = batadv_tt_global_hash_count(bat_priv, ethhdr->h_dest, + BATADV_NO_FLAGS); + ip_count = batadv_mcast_forw_want_all_ip_count(bat_priv, ethhdr); + unsnoop_count = !is_unsnoopable ? 0 : + atomic_read(&bat_priv->mcast.num_want_all_unsnoopables); + rtr_count = batadv_mcast_forw_rtr_count(bat_priv, *is_routable); + + total_count = tt_count + ip_count + unsnoop_count + rtr_count; + + switch (total_count) { + case 1: + if (tt_count) + *orig = batadv_mcast_forw_tt_node_get(bat_priv, ethhdr); + else if (ip_count) + *orig = batadv_mcast_forw_ip_node_get(bat_priv, ethhdr); + else if (unsnoop_count) + *orig = batadv_mcast_forw_unsnoop_node_get(bat_priv); + else if (rtr_count) + *orig = batadv_mcast_forw_rtr_node_get(bat_priv, + ethhdr); + + if (*orig) + return BATADV_FORW_SINGLE; + + fallthrough; + case 0: + return BATADV_FORW_NONE; + default: + mcast_fanout = atomic_read(&bat_priv->multicast_fanout); + + if (!unsnoop_count && total_count <= mcast_fanout) + return BATADV_FORW_SOME; + } + + return BATADV_FORW_ALL; +} + +/** + * batadv_mcast_forw_send_orig() - send a multicast packet to an originator + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to send + * @vid: the vlan identifier + * @orig_node: the originator to send the packet to + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. + */ +int batadv_mcast_forw_send_orig(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned short vid, + struct batadv_orig_node *orig_node) +{ + /* Avoid sending multicast-in-unicast packets to other BLA + * gateways - they already got the frame from the LAN side + * we share with them. + * TODO: Refactor to take BLA into account earlier, to avoid + * reducing the mcast_fanout count. + */ + if (batadv_bla_is_backbone_gw_orig(bat_priv, orig_node->orig, vid)) { + dev_kfree_skb(skb); + return NET_XMIT_SUCCESS; + } + + return batadv_send_skb_unicast(bat_priv, skb, BATADV_UNICAST, 0, + orig_node, vid); +} + +/** + * batadv_mcast_forw_tt() - forwards a packet to multicast listeners + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any multicast + * listener registered in the translation table. A transmission is performed + * via a batman-adv unicast packet for each such destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure, NET_XMIT_SUCCESS + * otherwise. 
+ */ +static int +batadv_mcast_forw_tt(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid) +{ + int ret = NET_XMIT_SUCCESS; + struct sk_buff *newskb; + + struct batadv_tt_orig_list_entry *orig_entry; + + struct batadv_tt_global_entry *tt_global; + const u8 *addr = eth_hdr(skb)->h_dest; + + tt_global = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt_global) + goto out; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_entry, &tt_global->orig_list, list) { + newskb = skb_copy(skb, GFP_ATOMIC); + if (!newskb) { + ret = NET_XMIT_DROP; + break; + } + + batadv_mcast_forw_send_orig(bat_priv, newskb, vid, + orig_entry->orig_node); + } + rcu_read_unlock(); + + batadv_tt_global_entry_put(tt_global); + +out: + return ret; +} + +/** + * batadv_mcast_forw_want_all_ipv4() - forward to nodes with want-all-ipv4 + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_ALL_IPV4 flag set. A transmission is performed via a + * batman-adv unicast packet for each such destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure, NET_XMIT_SUCCESS + * otherwise. + */ +static int +batadv_mcast_forw_want_all_ipv4(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + struct batadv_orig_node *orig_node; + int ret = NET_XMIT_SUCCESS; + struct sk_buff *newskb; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, + &bat_priv->mcast.want_all_ipv4_list, + mcast_want_all_ipv4_node) { + newskb = skb_copy(skb, GFP_ATOMIC); + if (!newskb) { + ret = NET_XMIT_DROP; + break; + } + + batadv_mcast_forw_send_orig(bat_priv, newskb, vid, orig_node); + } + rcu_read_unlock(); + return ret; +} + +/** + * batadv_mcast_forw_want_all_ipv6() - forward to nodes with want-all-ipv6 + * @bat_priv: the bat priv with all the soft interface information + * @skb: The multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_ALL_IPV6 flag set. A transmission is performed via a + * batman-adv unicast packet for each such destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure, NET_XMIT_SUCCESS + * otherwise. + */ +static int +batadv_mcast_forw_want_all_ipv6(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + struct batadv_orig_node *orig_node; + int ret = NET_XMIT_SUCCESS; + struct sk_buff *newskb; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, + &bat_priv->mcast.want_all_ipv6_list, + mcast_want_all_ipv6_node) { + newskb = skb_copy(skb, GFP_ATOMIC); + if (!newskb) { + ret = NET_XMIT_DROP; + break; + } + + batadv_mcast_forw_send_orig(bat_priv, newskb, vid, orig_node); + } + rcu_read_unlock(); + return ret; +} + +/** + * batadv_mcast_forw_want_all() - forward packet to nodes in a want-all list + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_ALL_IPV4 or BATADV_MCAST_WANT_ALL_IPV6 flag set. A + * transmission is performed via a batman-adv unicast packet for each such + * destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure or if the protocol family + * is neither IPv4 nor IPv6. NET_XMIT_SUCCESS otherwise. 
+ */ +static int +batadv_mcast_forw_want_all(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + switch (ntohs(eth_hdr(skb)->h_proto)) { + case ETH_P_IP: + return batadv_mcast_forw_want_all_ipv4(bat_priv, skb, vid); + case ETH_P_IPV6: + return batadv_mcast_forw_want_all_ipv6(bat_priv, skb, vid); + default: + /* we shouldn't be here... */ + return NET_XMIT_DROP; + } +} + +/** + * batadv_mcast_forw_want_all_rtr4() - forward to nodes with want-all-rtr4 + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_NO_RTR4 flag unset. A transmission is performed via a + * batman-adv unicast packet for each such destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure, NET_XMIT_SUCCESS + * otherwise. + */ +static int +batadv_mcast_forw_want_all_rtr4(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + struct batadv_orig_node *orig_node; + int ret = NET_XMIT_SUCCESS; + struct sk_buff *newskb; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, + &bat_priv->mcast.want_all_rtr4_list, + mcast_want_all_rtr4_node) { + newskb = skb_copy(skb, GFP_ATOMIC); + if (!newskb) { + ret = NET_XMIT_DROP; + break; + } + + batadv_mcast_forw_send_orig(bat_priv, newskb, vid, orig_node); + } + rcu_read_unlock(); + return ret; +} + +/** + * batadv_mcast_forw_want_all_rtr6() - forward to nodes with want-all-rtr6 + * @bat_priv: the bat priv with all the soft interface information + * @skb: The multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_NO_RTR6 flag unset. A transmission is performed via a + * batman-adv unicast packet for each such destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure, NET_XMIT_SUCCESS + * otherwise. + */ +static int +batadv_mcast_forw_want_all_rtr6(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + struct batadv_orig_node *orig_node; + int ret = NET_XMIT_SUCCESS; + struct sk_buff *newskb; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, + &bat_priv->mcast.want_all_rtr6_list, + mcast_want_all_rtr6_node) { + newskb = skb_copy(skb, GFP_ATOMIC); + if (!newskb) { + ret = NET_XMIT_DROP; + break; + } + + batadv_mcast_forw_send_orig(bat_priv, newskb, vid, orig_node); + } + rcu_read_unlock(); + return ret; +} + +/** + * batadv_mcast_forw_want_rtr() - forward packet to nodes in a want-all-rtr list + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * + * Sends copies of a frame with multicast destination to any node with a + * BATADV_MCAST_WANT_NO_RTR4 or BATADV_MCAST_WANT_NO_RTR6 flag unset. A + * transmission is performed via a batman-adv unicast packet for each such + * destination node. + * + * Return: NET_XMIT_DROP on memory allocation failure or if the protocol family + * is neither IPv4 nor IPv6. NET_XMIT_SUCCESS otherwise. + */ +static int +batadv_mcast_forw_want_rtr(struct batadv_priv *bat_priv, + struct sk_buff *skb, unsigned short vid) +{ + switch (ntohs(eth_hdr(skb)->h_proto)) { + case ETH_P_IP: + return batadv_mcast_forw_want_all_rtr4(bat_priv, skb, vid); + case ETH_P_IPV6: + return batadv_mcast_forw_want_all_rtr6(bat_priv, skb, vid); + default: + /* we shouldn't be here... 
*/ + return NET_XMIT_DROP; + } +} + +/** + * batadv_mcast_forw_send() - send packet to any detected multicast recipient + * @bat_priv: the bat priv with all the soft interface information + * @skb: the multicast packet to transmit + * @vid: the vlan identifier + * @is_routable: stores whether the destination is routable + * + * Sends copies of a frame with multicast destination to any node that signaled + * interest in it, that is either via the translation table or the according + * want-all flags. A transmission is performed via a batman-adv unicast packet + * for each such destination node. + * + * The given skb is consumed/freed. + * + * Return: NET_XMIT_DROP on memory allocation failure or if the protocol family + * is neither IPv4 nor IPv6. NET_XMIT_SUCCESS otherwise. + */ +int batadv_mcast_forw_send(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid, int is_routable) +{ + int ret; + + ret = batadv_mcast_forw_tt(bat_priv, skb, vid); + if (ret != NET_XMIT_SUCCESS) { + kfree_skb(skb); + return ret; + } + + ret = batadv_mcast_forw_want_all(bat_priv, skb, vid); + if (ret != NET_XMIT_SUCCESS) { + kfree_skb(skb); + return ret; + } + + if (!is_routable) + goto skip_mc_router; + + ret = batadv_mcast_forw_want_rtr(bat_priv, skb, vid); + if (ret != NET_XMIT_SUCCESS) { + kfree_skb(skb); + return ret; + } + +skip_mc_router: + consume_skb(skb); + return ret; +} + +/** + * batadv_mcast_want_unsnoop_update() - update unsnoop counter and list + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node which multicast state might have changed of + * @mcast_flags: flags indicating the new multicast state + * + * If the BATADV_MCAST_WANT_ALL_UNSNOOPABLES flag of this originator, + * orig, has toggled then this method updates the counter and the list + * accordingly. + * + * Caller needs to hold orig->mcast_handler_lock. 
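+ *
+ * For illustration, batadv_mcast_tvlv_ogm_handler() below drives this
+ * and the other want_*_update() helpers, roughly:
+ *
+ *   spin_lock_bh(&orig->mcast_handler_lock);
+ *   batadv_mcast_want_unsnoop_update(bat_priv, orig, mcast_flags);
+ *   ...
+ *   orig->mcast_flags = mcast_flags;
+ *   spin_unlock_bh(&orig->mcast_handler_lock);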
+ */ +static void batadv_mcast_want_unsnoop_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 mcast_flags) +{ + struct hlist_node *node = &orig->mcast_want_all_unsnoopables_node; + struct hlist_head *head = &bat_priv->mcast.want_all_unsnoopables_list; + + lockdep_assert_held(&orig->mcast_handler_lock); + + /* switched from flag unset to set */ + if (mcast_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES && + !(orig->mcast_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES)) { + atomic_inc(&bat_priv->mcast.num_want_all_unsnoopables); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(!hlist_unhashed(node)); + + hlist_add_head_rcu(node, head); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + /* switched from flag set to unset */ + } else if (!(mcast_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES) && + orig->mcast_flags & BATADV_MCAST_WANT_ALL_UNSNOOPABLES) { + atomic_dec(&bat_priv->mcast.num_want_all_unsnoopables); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(hlist_unhashed(node)); + + hlist_del_init_rcu(node); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + } +} + +/** + * batadv_mcast_want_ipv4_update() - update want-all-ipv4 counter and list + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node which multicast state might have changed of + * @mcast_flags: flags indicating the new multicast state + * + * If the BATADV_MCAST_WANT_ALL_IPV4 flag of this originator, orig, has + * toggled then this method updates the counter and the list accordingly. + * + * Caller needs to hold orig->mcast_handler_lock. + */ +static void batadv_mcast_want_ipv4_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 mcast_flags) +{ + struct hlist_node *node = &orig->mcast_want_all_ipv4_node; + struct hlist_head *head = &bat_priv->mcast.want_all_ipv4_list; + + lockdep_assert_held(&orig->mcast_handler_lock); + + /* switched from flag unset to set */ + if (mcast_flags & BATADV_MCAST_WANT_ALL_IPV4 && + !(orig->mcast_flags & BATADV_MCAST_WANT_ALL_IPV4)) { + atomic_inc(&bat_priv->mcast.num_want_all_ipv4); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(!hlist_unhashed(node)); + + hlist_add_head_rcu(node, head); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + /* switched from flag set to unset */ + } else if (!(mcast_flags & BATADV_MCAST_WANT_ALL_IPV4) && + orig->mcast_flags & BATADV_MCAST_WANT_ALL_IPV4) { + atomic_dec(&bat_priv->mcast.num_want_all_ipv4); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(hlist_unhashed(node)); + + hlist_del_init_rcu(node); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + } +} + +/** + * batadv_mcast_want_ipv6_update() - update want-all-ipv6 counter and list + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node which multicast state might have changed of + * @mcast_flags: flags indicating the new multicast state + * + * If the BATADV_MCAST_WANT_ALL_IPV6 flag of this originator, orig, has + * toggled then this method updates the counter and the list accordingly. + * + * Caller needs to hold orig->mcast_handler_lock. 
+ */ +static void batadv_mcast_want_ipv6_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 mcast_flags) +{ + struct hlist_node *node = &orig->mcast_want_all_ipv6_node; + struct hlist_head *head = &bat_priv->mcast.want_all_ipv6_list; + + lockdep_assert_held(&orig->mcast_handler_lock); + + /* switched from flag unset to set */ + if (mcast_flags & BATADV_MCAST_WANT_ALL_IPV6 && + !(orig->mcast_flags & BATADV_MCAST_WANT_ALL_IPV6)) { + atomic_inc(&bat_priv->mcast.num_want_all_ipv6); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(!hlist_unhashed(node)); + + hlist_add_head_rcu(node, head); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + /* switched from flag set to unset */ + } else if (!(mcast_flags & BATADV_MCAST_WANT_ALL_IPV6) && + orig->mcast_flags & BATADV_MCAST_WANT_ALL_IPV6) { + atomic_dec(&bat_priv->mcast.num_want_all_ipv6); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(hlist_unhashed(node)); + + hlist_del_init_rcu(node); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + } +} + +/** + * batadv_mcast_want_rtr4_update() - update want-all-rtr4 counter and list + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node which multicast state might have changed of + * @mcast_flags: flags indicating the new multicast state + * + * If the BATADV_MCAST_WANT_NO_RTR4 flag of this originator, orig, has + * toggled then this method updates the counter and the list accordingly. + * + * Caller needs to hold orig->mcast_handler_lock. + */ +static void batadv_mcast_want_rtr4_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 mcast_flags) +{ + struct hlist_node *node = &orig->mcast_want_all_rtr4_node; + struct hlist_head *head = &bat_priv->mcast.want_all_rtr4_list; + + lockdep_assert_held(&orig->mcast_handler_lock); + + /* switched from flag set to unset */ + if (!(mcast_flags & BATADV_MCAST_WANT_NO_RTR4) && + orig->mcast_flags & BATADV_MCAST_WANT_NO_RTR4) { + atomic_inc(&bat_priv->mcast.num_want_all_rtr4); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(!hlist_unhashed(node)); + + hlist_add_head_rcu(node, head); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + /* switched from flag unset to set */ + } else if (mcast_flags & BATADV_MCAST_WANT_NO_RTR4 && + !(orig->mcast_flags & BATADV_MCAST_WANT_NO_RTR4)) { + atomic_dec(&bat_priv->mcast.num_want_all_rtr4); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(hlist_unhashed(node)); + + hlist_del_init_rcu(node); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + } +} + +/** + * batadv_mcast_want_rtr6_update() - update want-all-rtr6 counter and list + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node which multicast state might have changed of + * @mcast_flags: flags indicating the new multicast state + * + * If the BATADV_MCAST_WANT_NO_RTR6 flag of this originator, orig, has + * toggled then this method updates the counter and the list accordingly. + * + * Caller needs to hold orig->mcast_handler_lock. 
+ */ +static void batadv_mcast_want_rtr6_update(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 mcast_flags) +{ + struct hlist_node *node = &orig->mcast_want_all_rtr6_node; + struct hlist_head *head = &bat_priv->mcast.want_all_rtr6_list; + + lockdep_assert_held(&orig->mcast_handler_lock); + + /* switched from flag set to unset */ + if (!(mcast_flags & BATADV_MCAST_WANT_NO_RTR6) && + orig->mcast_flags & BATADV_MCAST_WANT_NO_RTR6) { + atomic_inc(&bat_priv->mcast.num_want_all_rtr6); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(!hlist_unhashed(node)); + + hlist_add_head_rcu(node, head); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + /* switched from flag unset to set */ + } else if (mcast_flags & BATADV_MCAST_WANT_NO_RTR6 && + !(orig->mcast_flags & BATADV_MCAST_WANT_NO_RTR6)) { + atomic_dec(&bat_priv->mcast.num_want_all_rtr6); + + spin_lock_bh(&bat_priv->mcast.want_lists_lock); + /* flag checks above + mcast_handler_lock prevents this */ + WARN_ON(hlist_unhashed(node)); + + hlist_del_init_rcu(node); + spin_unlock_bh(&bat_priv->mcast.want_lists_lock); + } +} + +/** + * batadv_mcast_tvlv_flags_get() - get multicast flags from an OGM TVLV + * @enabled: whether the originator has multicast TVLV support enabled + * @tvlv_value: tvlv buffer containing the multicast flags + * @tvlv_value_len: tvlv buffer length + * + * Return: multicast flags for the given tvlv buffer + */ +static u8 +batadv_mcast_tvlv_flags_get(bool enabled, void *tvlv_value, u16 tvlv_value_len) +{ + u8 mcast_flags = BATADV_NO_FLAGS; + + if (enabled && tvlv_value && tvlv_value_len >= sizeof(mcast_flags)) + mcast_flags = *(u8 *)tvlv_value; + + if (!enabled) { + mcast_flags |= BATADV_MCAST_WANT_ALL_IPV4; + mcast_flags |= BATADV_MCAST_WANT_ALL_IPV6; + } + + /* remove redundant flags to avoid sending duplicate packets later */ + if (mcast_flags & BATADV_MCAST_WANT_ALL_IPV4) + mcast_flags |= BATADV_MCAST_WANT_NO_RTR4; + + if (mcast_flags & BATADV_MCAST_WANT_ALL_IPV6) + mcast_flags |= BATADV_MCAST_WANT_NO_RTR6; + + return mcast_flags; +} + +/** + * batadv_mcast_tvlv_ogm_handler() - process incoming multicast tvlv container + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node of the ogm + * @flags: flags indicating the tvlv state (see batadv_tvlv_handler_flags) + * @tvlv_value: tvlv buffer containing the multicast data + * @tvlv_value_len: tvlv buffer length + */ +static void batadv_mcast_tvlv_ogm_handler(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, + u16 tvlv_value_len) +{ + bool orig_mcast_enabled = !(flags & BATADV_TVLV_HANDLER_OGM_CIFNOTFND); + u8 mcast_flags; + + mcast_flags = batadv_mcast_tvlv_flags_get(orig_mcast_enabled, + tvlv_value, tvlv_value_len); + + spin_lock_bh(&orig->mcast_handler_lock); + + if (orig_mcast_enabled && + !test_bit(BATADV_ORIG_CAPA_HAS_MCAST, &orig->capabilities)) { + set_bit(BATADV_ORIG_CAPA_HAS_MCAST, &orig->capabilities); + } else if (!orig_mcast_enabled && + test_bit(BATADV_ORIG_CAPA_HAS_MCAST, &orig->capabilities)) { + clear_bit(BATADV_ORIG_CAPA_HAS_MCAST, &orig->capabilities); + } + + set_bit(BATADV_ORIG_CAPA_HAS_MCAST, &orig->capa_initialized); + + batadv_mcast_want_unsnoop_update(bat_priv, orig, mcast_flags); + batadv_mcast_want_ipv4_update(bat_priv, orig, mcast_flags); + batadv_mcast_want_ipv6_update(bat_priv, orig, mcast_flags); + batadv_mcast_want_rtr4_update(bat_priv, orig, mcast_flags); + 
batadv_mcast_want_rtr6_update(bat_priv, orig, mcast_flags); + + orig->mcast_flags = mcast_flags; + spin_unlock_bh(&orig->mcast_handler_lock); +} + +/** + * batadv_mcast_init() - initialize the multicast optimizations structures + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_mcast_init(struct batadv_priv *bat_priv) +{ + batadv_tvlv_handler_register(bat_priv, batadv_mcast_tvlv_ogm_handler, + NULL, BATADV_TVLV_MCAST, 2, + BATADV_TVLV_HANDLER_OGM_CIFNOTFND); + + INIT_DELAYED_WORK(&bat_priv->mcast.work, batadv_mcast_mla_update); + batadv_mcast_start_timer(bat_priv); +} + +/** + * batadv_mcast_mesh_info_put() - put multicast info into a netlink message + * @msg: buffer for the message + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 or error code. + */ +int batadv_mcast_mesh_info_put(struct sk_buff *msg, + struct batadv_priv *bat_priv) +{ + u32 flags = bat_priv->mcast.mla_flags.tvlv_flags; + u32 flags_priv = BATADV_NO_FLAGS; + + if (bat_priv->mcast.mla_flags.bridged) { + flags_priv |= BATADV_MCAST_FLAGS_BRIDGED; + + if (bat_priv->mcast.mla_flags.querier_ipv4.exists) + flags_priv |= BATADV_MCAST_FLAGS_QUERIER_IPV4_EXISTS; + if (bat_priv->mcast.mla_flags.querier_ipv6.exists) + flags_priv |= BATADV_MCAST_FLAGS_QUERIER_IPV6_EXISTS; + if (bat_priv->mcast.mla_flags.querier_ipv4.shadowing) + flags_priv |= BATADV_MCAST_FLAGS_QUERIER_IPV4_SHADOWING; + if (bat_priv->mcast.mla_flags.querier_ipv6.shadowing) + flags_priv |= BATADV_MCAST_FLAGS_QUERIER_IPV6_SHADOWING; + } + + if (nla_put_u32(msg, BATADV_ATTR_MCAST_FLAGS, flags) || + nla_put_u32(msg, BATADV_ATTR_MCAST_FLAGS_PRIV, flags_priv)) + return -EMSGSIZE; + + return 0; +} + +/** + * batadv_mcast_flags_dump_entry() - dump one entry of the multicast flags table + * to a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @orig_node: originator to dump the multicast flags of + * + * Return: 0 or error code. + */ +static int +batadv_mcast_flags_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_orig_node *orig_node) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_MCAST_FLAGS); + if (!hdr) + return -ENOBUFS; + + genl_dump_check_consistent(cb, hdr); + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, + orig_node->orig)) { + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; + } + + if (test_bit(BATADV_ORIG_CAPA_HAS_MCAST, + &orig_node->capabilities)) { + if (nla_put_u32(msg, BATADV_ATTR_MCAST_FLAGS, + orig_node->mcast_flags)) { + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; + } + } + + genlmsg_end(msg, hdr); + return 0; +} + +/** + * batadv_mcast_flags_dump_bucket() - dump one bucket of the multicast flags + * table to a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @hash: hash to dump + * @bucket: bucket index to dump + * @idx_skip: How many entries to skip + * + * Return: 0 or error code. 
+ */ +static int +batadv_mcast_flags_dump_bucket(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_hashtable *hash, + unsigned int bucket, long *idx_skip) +{ + struct batadv_orig_node *orig_node; + long idx = 0; + + spin_lock_bh(&hash->list_locks[bucket]); + cb->seq = atomic_read(&hash->generation) << 1 | 1; + + hlist_for_each_entry(orig_node, &hash->table[bucket], hash_entry) { + if (!test_bit(BATADV_ORIG_CAPA_HAS_MCAST, + &orig_node->capa_initialized)) + continue; + + if (idx < *idx_skip) + goto skip; + + if (batadv_mcast_flags_dump_entry(msg, portid, cb, orig_node)) { + spin_unlock_bh(&hash->list_locks[bucket]); + *idx_skip = idx; + + return -EMSGSIZE; + } + +skip: + idx++; + } + spin_unlock_bh(&hash->list_locks[bucket]); + + return 0; +} + +/** + * __batadv_mcast_flags_dump() - dump multicast flags table to a netlink socket + * @msg: buffer for the message + * @portid: netlink port + * @cb: Control block containing additional options + * @bat_priv: the bat priv with all the soft interface information + * @bucket: current bucket to dump + * @idx: index in current bucket to the next entry to dump + * + * Return: 0 or error code. + */ +static int +__batadv_mcast_flags_dump(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_priv *bat_priv, long *bucket, long *idx) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + long bucket_tmp = *bucket; + long idx_tmp = *idx; + + while (bucket_tmp < hash->size) { + if (batadv_mcast_flags_dump_bucket(msg, portid, cb, hash, + bucket_tmp, &idx_tmp)) + break; + + bucket_tmp++; + idx_tmp = 0; + } + + *bucket = bucket_tmp; + *idx = idx_tmp; + + return msg->len; +} + +/** + * batadv_mcast_netlink_get_primary() - get primary interface from netlink + * callback + * @cb: netlink callback structure + * @primary_if: the primary interface pointer to return the result in + * + * Return: 0 or error code. + */ +static int +batadv_mcast_netlink_get_primary(struct netlink_callback *cb, + struct batadv_hard_iface **primary_if) +{ + struct batadv_hard_iface *hard_iface = NULL; + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_priv *bat_priv; + int ifindex; + int ret = 0; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + hard_iface = batadv_primary_if_get_selected(bat_priv); + if (!hard_iface || hard_iface->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + +out: + dev_put(soft_iface); + + if (!ret && primary_if) + *primary_if = hard_iface; + else + batadv_hardif_put(hard_iface); + + return ret; +} + +/** + * batadv_mcast_flags_dump() - dump multicast flags table to a netlink socket + * @msg: buffer for the message + * @cb: callback structure containing arguments + * + * Return: message length. 
+ */ +int batadv_mcast_flags_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct batadv_hard_iface *primary_if = NULL; + int portid = NETLINK_CB(cb->skb).portid; + struct batadv_priv *bat_priv; + long *bucket = &cb->args[0]; + long *idx = &cb->args[1]; + int ret; + + ret = batadv_mcast_netlink_get_primary(cb, &primary_if); + if (ret) + return ret; + + bat_priv = netdev_priv(primary_if->soft_iface); + ret = __batadv_mcast_flags_dump(msg, portid, cb, bat_priv, bucket, idx); + + batadv_hardif_put(primary_if); + return ret; +} + +/** + * batadv_mcast_free() - free the multicast optimizations structures + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_mcast_free(struct batadv_priv *bat_priv) +{ + cancel_delayed_work_sync(&bat_priv->mcast.work); + + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_MCAST, 2); + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_MCAST, 2); + + /* safely calling outside of worker, as worker was canceled above */ + batadv_mcast_mla_tt_retract(bat_priv, NULL); +} + +/** + * batadv_mcast_purge_orig() - reset originator global mcast state modifications + * @orig: the originator which is going to get purged + */ +void batadv_mcast_purge_orig(struct batadv_orig_node *orig) +{ + struct batadv_priv *bat_priv = orig->bat_priv; + + spin_lock_bh(&orig->mcast_handler_lock); + + batadv_mcast_want_unsnoop_update(bat_priv, orig, BATADV_NO_FLAGS); + batadv_mcast_want_ipv4_update(bat_priv, orig, BATADV_NO_FLAGS); + batadv_mcast_want_ipv6_update(bat_priv, orig, BATADV_NO_FLAGS); + batadv_mcast_want_rtr4_update(bat_priv, orig, + BATADV_MCAST_WANT_NO_RTR4); + batadv_mcast_want_rtr6_update(bat_priv, orig, + BATADV_MCAST_WANT_NO_RTR6); + + spin_unlock_bh(&orig->mcast_handler_lock); +} diff --git a/net/batman-adv/multicast.h b/net/batman-adv/multicast.h new file mode 100644 index 000000000..8aec818d0 --- /dev/null +++ b/net/batman-adv/multicast.h @@ -0,0 +1,123 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Linus Lüssing + */ + +#ifndef _NET_BATMAN_ADV_MULTICAST_H_ +#define _NET_BATMAN_ADV_MULTICAST_H_ + +#include "main.h" + +#include +#include + +/** + * enum batadv_forw_mode - the way a packet should be forwarded as + */ +enum batadv_forw_mode { + /** + * @BATADV_FORW_ALL: forward the packet to all nodes (currently via + * classic flooding) + */ + BATADV_FORW_ALL, + + /** + * @BATADV_FORW_SOME: forward the packet to some nodes (currently via + * a multicast-to-unicast conversion and the BATMAN unicast routing + * protocol) + */ + BATADV_FORW_SOME, + + /** + * @BATADV_FORW_SINGLE: forward the packet to a single node (currently + * via the BATMAN unicast routing protocol) + */ + BATADV_FORW_SINGLE, + + /** @BATADV_FORW_NONE: don't forward, drop it */ + BATADV_FORW_NONE, +}; + +#ifdef CONFIG_BATMAN_ADV_MCAST + +enum batadv_forw_mode +batadv_mcast_forw_mode(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_orig_node **mcast_single_orig, + int *is_routable); + +int batadv_mcast_forw_send_orig(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned short vid, + struct batadv_orig_node *orig_node); + +int batadv_mcast_forw_send(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid, int is_routable); + +void batadv_mcast_init(struct batadv_priv *bat_priv); + +int batadv_mcast_mesh_info_put(struct sk_buff *msg, + struct batadv_priv *bat_priv); + +int batadv_mcast_flags_dump(struct sk_buff *msg, struct netlink_callback *cb); + +void batadv_mcast_free(struct batadv_priv *bat_priv); + +void batadv_mcast_purge_orig(struct batadv_orig_node *orig_node); + +#else + +static inline enum batadv_forw_mode +batadv_mcast_forw_mode(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_orig_node **mcast_single_orig, + int *is_routable) +{ + return BATADV_FORW_ALL; +} + +static inline int +batadv_mcast_forw_send_orig(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned short vid, + struct batadv_orig_node *orig_node) +{ + kfree_skb(skb); + return NET_XMIT_DROP; +} + +static inline int +batadv_mcast_forw_send(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid, int is_routable) +{ + kfree_skb(skb); + return NET_XMIT_DROP; +} + +static inline int batadv_mcast_init(struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline int +batadv_mcast_mesh_info_put(struct sk_buff *msg, struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline int batadv_mcast_flags_dump(struct sk_buff *msg, + struct netlink_callback *cb) +{ + return -EOPNOTSUPP; +} + +static inline void batadv_mcast_free(struct batadv_priv *bat_priv) +{ +} + +static inline void batadv_mcast_purge_orig(struct batadv_orig_node *orig_node) +{ +} + +#endif /* CONFIG_BATMAN_ADV_MCAST */ + +#endif /* _NET_BATMAN_ADV_MULTICAST_H_ */ diff --git a/net/batman-adv/netlink.c b/net/batman-adv/netlink.c new file mode 100644 index 000000000..86e0664e0 --- /dev/null +++ b/net/batman-adv/netlink.c @@ -0,0 +1,1522 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Matthias Schiffer + */ + +#include "netlink.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bridge_loop_avoidance.h" +#include "distributed-arp-table.h" +#include "gateway_client.h" +#include "gateway_common.h" +#include "hard-interface.h" +#include "log.h" +#include "multicast.h" +#include "network-coding.h" +#include "originator.h" +#include "soft-interface.h" +#include "tp_meter.h" +#include "translation-table.h" + +struct genl_family batadv_netlink_family; + +/* multicast groups */ +enum batadv_netlink_multicast_groups { + BATADV_NL_MCGRP_CONFIG, + BATADV_NL_MCGRP_TPMETER, +}; + +/** + * enum batadv_genl_ops_flags - flags for genl_ops's internal_flags + */ +enum batadv_genl_ops_flags { + /** + * @BATADV_FLAG_NEED_MESH: request requires valid soft interface in + * attribute BATADV_ATTR_MESH_IFINDEX and expects a pointer to it to be + * saved in info->user_ptr[0] + */ + BATADV_FLAG_NEED_MESH = BIT(0), + + /** + * @BATADV_FLAG_NEED_HARDIF: request requires valid hard interface in + * attribute BATADV_ATTR_HARD_IFINDEX and expects a pointer to it to be + * saved in info->user_ptr[1] + */ + BATADV_FLAG_NEED_HARDIF = BIT(1), + + /** + * @BATADV_FLAG_NEED_VLAN: request requires valid vlan in + * attribute BATADV_ATTR_VLANID and expects a pointer to it to be + * saved in info->user_ptr[1] + */ + BATADV_FLAG_NEED_VLAN = BIT(2), +}; + +static const struct genl_multicast_group batadv_netlink_mcgrps[] = { + [BATADV_NL_MCGRP_CONFIG] = { .name = BATADV_NL_MCAST_GROUP_CONFIG }, + [BATADV_NL_MCGRP_TPMETER] = { .name = BATADV_NL_MCAST_GROUP_TPMETER }, +}; + +static const struct nla_policy batadv_netlink_policy[NUM_BATADV_ATTR] = { + [BATADV_ATTR_VERSION] = { .type = NLA_STRING }, + [BATADV_ATTR_ALGO_NAME] = { .type = NLA_STRING }, + [BATADV_ATTR_MESH_IFINDEX] = { .type = NLA_U32 }, + [BATADV_ATTR_MESH_IFNAME] = { .type = NLA_STRING }, + [BATADV_ATTR_MESH_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_HARD_IFINDEX] = { .type = NLA_U32 }, + [BATADV_ATTR_HARD_IFNAME] = { .type = NLA_STRING }, + [BATADV_ATTR_HARD_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_ORIG_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_TPMETER_RESULT] = { .type = NLA_U8 }, + [BATADV_ATTR_TPMETER_TEST_TIME] = { .type = NLA_U32 }, + [BATADV_ATTR_TPMETER_BYTES] = { .type = NLA_U64 }, + [BATADV_ATTR_TPMETER_COOKIE] = { .type = NLA_U32 }, + [BATADV_ATTR_ACTIVE] = { .type = NLA_FLAG }, + [BATADV_ATTR_TT_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_TT_TTVN] = { .type = NLA_U8 }, + [BATADV_ATTR_TT_LAST_TTVN] = { .type = NLA_U8 }, + [BATADV_ATTR_TT_CRC32] = { .type = NLA_U32 }, + [BATADV_ATTR_TT_VID] = { .type = NLA_U16 }, + [BATADV_ATTR_TT_FLAGS] = { .type = NLA_U32 }, + [BATADV_ATTR_FLAG_BEST] = { .type = NLA_FLAG }, + [BATADV_ATTR_LAST_SEEN_MSECS] = { .type = NLA_U32 }, + [BATADV_ATTR_NEIGH_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_TQ] = { .type = NLA_U8 }, + [BATADV_ATTR_THROUGHPUT] = { .type = NLA_U32 }, + [BATADV_ATTR_BANDWIDTH_UP] = { .type = NLA_U32 }, + [BATADV_ATTR_BANDWIDTH_DOWN] = { .type = NLA_U32 }, + [BATADV_ATTR_ROUTER] = { .len = ETH_ALEN }, + [BATADV_ATTR_BLA_OWN] = { .type = NLA_FLAG }, + [BATADV_ATTR_BLA_ADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_BLA_VID] = { .type = NLA_U16 }, + 
[BATADV_ATTR_BLA_BACKBONE] = { .len = ETH_ALEN }, + [BATADV_ATTR_BLA_CRC] = { .type = NLA_U16 }, + [BATADV_ATTR_DAT_CACHE_IP4ADDRESS] = { .type = NLA_U32 }, + [BATADV_ATTR_DAT_CACHE_HWADDRESS] = { .len = ETH_ALEN }, + [BATADV_ATTR_DAT_CACHE_VID] = { .type = NLA_U16 }, + [BATADV_ATTR_MCAST_FLAGS] = { .type = NLA_U32 }, + [BATADV_ATTR_MCAST_FLAGS_PRIV] = { .type = NLA_U32 }, + [BATADV_ATTR_VLANID] = { .type = NLA_U16 }, + [BATADV_ATTR_AGGREGATED_OGMS_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_AP_ISOLATION_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_ISOLATION_MARK] = { .type = NLA_U32 }, + [BATADV_ATTR_ISOLATION_MASK] = { .type = NLA_U32 }, + [BATADV_ATTR_BONDING_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_FRAGMENTATION_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_GW_BANDWIDTH_DOWN] = { .type = NLA_U32 }, + [BATADV_ATTR_GW_BANDWIDTH_UP] = { .type = NLA_U32 }, + [BATADV_ATTR_GW_MODE] = { .type = NLA_U8 }, + [BATADV_ATTR_GW_SEL_CLASS] = { .type = NLA_U32 }, + [BATADV_ATTR_HOP_PENALTY] = { .type = NLA_U8 }, + [BATADV_ATTR_LOG_LEVEL] = { .type = NLA_U32 }, + [BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_MULTICAST_FANOUT] = { .type = NLA_U32 }, + [BATADV_ATTR_NETWORK_CODING_ENABLED] = { .type = NLA_U8 }, + [BATADV_ATTR_ORIG_INTERVAL] = { .type = NLA_U32 }, + [BATADV_ATTR_ELP_INTERVAL] = { .type = NLA_U32 }, + [BATADV_ATTR_THROUGHPUT_OVERRIDE] = { .type = NLA_U32 }, +}; + +/** + * batadv_netlink_get_ifindex() - Extract an interface index from a message + * @nlh: Message header + * @attrtype: Attribute which holds an interface index + * + * Return: interface index, or 0. + */ +int +batadv_netlink_get_ifindex(const struct nlmsghdr *nlh, int attrtype) +{ + struct nlattr *attr = nlmsg_find_attr(nlh, GENL_HDRLEN, attrtype); + + return (attr && nla_len(attr) == sizeof(u32)) ? 
nla_get_u32(attr) : 0; +} + +/** + * batadv_netlink_mesh_fill_ap_isolation() - Add ap_isolation softif attribute + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_mesh_fill_ap_isolation(struct sk_buff *msg, + struct batadv_priv *bat_priv) +{ + struct batadv_softif_vlan *vlan; + u8 ap_isolation; + + vlan = batadv_softif_vlan_get(bat_priv, BATADV_NO_FLAGS); + if (!vlan) + return 0; + + ap_isolation = atomic_read(&vlan->ap_isolation); + batadv_softif_vlan_put(vlan); + + return nla_put_u8(msg, BATADV_ATTR_AP_ISOLATION_ENABLED, + !!ap_isolation); +} + +/** + * batadv_netlink_set_mesh_ap_isolation() - Set ap_isolation from genl msg + * @attr: parsed BATADV_ATTR_AP_ISOLATION_ENABLED attribute + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_mesh_ap_isolation(struct nlattr *attr, + struct batadv_priv *bat_priv) +{ + struct batadv_softif_vlan *vlan; + + vlan = batadv_softif_vlan_get(bat_priv, BATADV_NO_FLAGS); + if (!vlan) + return -ENOENT; + + atomic_set(&vlan->ap_isolation, !!nla_get_u8(attr)); + batadv_softif_vlan_put(vlan); + + return 0; +} + +/** + * batadv_netlink_mesh_fill() - Fill message with mesh attributes + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @cmd: type of message to generate + * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_mesh_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags) +{ + struct net_device *soft_iface = bat_priv->soft_iface; + struct batadv_hard_iface *primary_if = NULL; + struct net_device *hard_iface; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_string(msg, BATADV_ATTR_VERSION, BATADV_SOURCE_VERSION) || + nla_put_string(msg, BATADV_ATTR_ALGO_NAME, + bat_priv->algo_ops->name) || + nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, soft_iface->ifindex) || + nla_put_string(msg, BATADV_ATTR_MESH_IFNAME, soft_iface->name) || + nla_put(msg, BATADV_ATTR_MESH_ADDRESS, ETH_ALEN, + soft_iface->dev_addr) || + nla_put_u8(msg, BATADV_ATTR_TT_TTVN, + (u8)atomic_read(&bat_priv->tt.vn))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_BLA + if (nla_put_u16(msg, BATADV_ATTR_BLA_CRC, + ntohs(bat_priv->bla.claim_dest.group))) + goto nla_put_failure; +#endif + + if (batadv_mcast_mesh_info_put(msg, bat_priv)) + goto nla_put_failure; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (primary_if && primary_if->if_status == BATADV_IF_ACTIVE) { + hard_iface = primary_if->net_dev; + + if (nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + hard_iface->ifindex) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + hard_iface->name) || + nla_put(msg, BATADV_ATTR_HARD_ADDRESS, ETH_ALEN, + hard_iface->dev_addr)) + goto nla_put_failure; + } + + if (nla_put_u8(msg, BATADV_ATTR_AGGREGATED_OGMS_ENABLED, + !!atomic_read(&bat_priv->aggregated_ogms))) + goto nla_put_failure; + + if (batadv_netlink_mesh_fill_ap_isolation(msg, bat_priv)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_ISOLATION_MARK, + 
bat_priv->isolation_mark)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_ISOLATION_MASK, + bat_priv->isolation_mark_mask)) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_BONDING_ENABLED, + !!atomic_read(&bat_priv->bonding))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_BLA + if (nla_put_u8(msg, BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED, + !!atomic_read(&bat_priv->bridge_loop_avoidance))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_BLA */ + +#ifdef CONFIG_BATMAN_ADV_DAT + if (nla_put_u8(msg, BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED, + !!atomic_read(&bat_priv->distributed_arp_table))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_DAT */ + + if (nla_put_u8(msg, BATADV_ATTR_FRAGMENTATION_ENABLED, + !!atomic_read(&bat_priv->fragmentation))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_GW_BANDWIDTH_DOWN, + atomic_read(&bat_priv->gw.bandwidth_down))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_GW_BANDWIDTH_UP, + atomic_read(&bat_priv->gw.bandwidth_up))) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_GW_MODE, + atomic_read(&bat_priv->gw.mode))) + goto nla_put_failure; + + if (bat_priv->algo_ops->gw.get_best_gw_node && + bat_priv->algo_ops->gw.is_eligible) { + /* GW selection class is not available if the routing algorithm + * in use does not implement the GW API + */ + if (nla_put_u32(msg, BATADV_ATTR_GW_SEL_CLASS, + atomic_read(&bat_priv->gw.sel_class))) + goto nla_put_failure; + } + + if (nla_put_u8(msg, BATADV_ATTR_HOP_PENALTY, + atomic_read(&bat_priv->hop_penalty))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_DEBUG + if (nla_put_u32(msg, BATADV_ATTR_LOG_LEVEL, + atomic_read(&bat_priv->log_level))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_DEBUG */ + +#ifdef CONFIG_BATMAN_ADV_MCAST + if (nla_put_u8(msg, BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED, + !atomic_read(&bat_priv->multicast_mode))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_MULTICAST_FANOUT, + atomic_read(&bat_priv->multicast_fanout))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_MCAST */ + +#ifdef CONFIG_BATMAN_ADV_NC + if (nla_put_u8(msg, BATADV_ATTR_NETWORK_CODING_ENABLED, + !!atomic_read(&bat_priv->network_coding))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_NC */ + + if (nla_put_u32(msg, BATADV_ATTR_ORIG_INTERVAL, + atomic_read(&bat_priv->orig_interval))) + goto nla_put_failure; + + batadv_hardif_put(primary_if); + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + batadv_hardif_put(primary_if); + + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_netlink_notify_mesh() - send softif attributes to listener + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_notify_mesh(struct batadv_priv *bat_priv) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_mesh_fill(msg, bat_priv, BATADV_CMD_SET_MESH, + 0, 0, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_mesh() - Get softif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_mesh(struct sk_buff *skb, struct 
genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_mesh_fill(msg, bat_priv, BATADV_CMD_GET_MESH, + info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_mesh() - Set softif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_mesh(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_AGGREGATED_OGMS_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AGGREGATED_OGMS_ENABLED]; + + atomic_set(&bat_priv->aggregated_ogms, !!nla_get_u8(attr)); + } + + if (info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]; + + batadv_netlink_set_mesh_ap_isolation(attr, bat_priv); + } + + if (info->attrs[BATADV_ATTR_ISOLATION_MARK]) { + attr = info->attrs[BATADV_ATTR_ISOLATION_MARK]; + + bat_priv->isolation_mark = nla_get_u32(attr); + } + + if (info->attrs[BATADV_ATTR_ISOLATION_MASK]) { + attr = info->attrs[BATADV_ATTR_ISOLATION_MASK]; + + bat_priv->isolation_mark_mask = nla_get_u32(attr); + } + + if (info->attrs[BATADV_ATTR_BONDING_ENABLED]) { + attr = info->attrs[BATADV_ATTR_BONDING_ENABLED]; + + atomic_set(&bat_priv->bonding, !!nla_get_u8(attr)); + } + +#ifdef CONFIG_BATMAN_ADV_BLA + if (info->attrs[BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED]) { + attr = info->attrs[BATADV_ATTR_BRIDGE_LOOP_AVOIDANCE_ENABLED]; + + atomic_set(&bat_priv->bridge_loop_avoidance, + !!nla_get_u8(attr)); + batadv_bla_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_BLA */ + +#ifdef CONFIG_BATMAN_ADV_DAT + if (info->attrs[BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED]) { + attr = info->attrs[BATADV_ATTR_DISTRIBUTED_ARP_TABLE_ENABLED]; + + atomic_set(&bat_priv->distributed_arp_table, + !!nla_get_u8(attr)); + batadv_dat_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_DAT */ + + if (info->attrs[BATADV_ATTR_FRAGMENTATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_FRAGMENTATION_ENABLED]; + + atomic_set(&bat_priv->fragmentation, !!nla_get_u8(attr)); + + rtnl_lock(); + batadv_update_min_mtu(bat_priv->soft_iface); + rtnl_unlock(); + } + + if (info->attrs[BATADV_ATTR_GW_BANDWIDTH_DOWN]) { + attr = info->attrs[BATADV_ATTR_GW_BANDWIDTH_DOWN]; + + atomic_set(&bat_priv->gw.bandwidth_down, nla_get_u32(attr)); + batadv_gw_tvlv_container_update(bat_priv); + } + + if (info->attrs[BATADV_ATTR_GW_BANDWIDTH_UP]) { + attr = info->attrs[BATADV_ATTR_GW_BANDWIDTH_UP]; + + atomic_set(&bat_priv->gw.bandwidth_up, nla_get_u32(attr)); + batadv_gw_tvlv_container_update(bat_priv); + } + + if (info->attrs[BATADV_ATTR_GW_MODE]) { + u8 gw_mode; + + attr = info->attrs[BATADV_ATTR_GW_MODE]; + gw_mode = nla_get_u8(attr); + + if (gw_mode <= BATADV_GW_MODE_SERVER) { + /* Invoking batadv_gw_reselect() is not enough to really + * de-select the current GW. It will only instruct the + * gateway client code to perform a re-election the next + * time that this is needed. + * + * When gw client mode is being switched off the current + * GW must be de-selected explicitly otherwise no GW_ADD + * uevent is thrown on client mode re-activation. 
This + * is operation is performed in + * batadv_gw_check_client_stop(). + */ + batadv_gw_reselect(bat_priv); + + /* always call batadv_gw_check_client_stop() before + * changing the gateway state + */ + batadv_gw_check_client_stop(bat_priv); + atomic_set(&bat_priv->gw.mode, gw_mode); + batadv_gw_tvlv_container_update(bat_priv); + } + } + + if (info->attrs[BATADV_ATTR_GW_SEL_CLASS] && + bat_priv->algo_ops->gw.get_best_gw_node && + bat_priv->algo_ops->gw.is_eligible) { + /* setting the GW selection class is allowed only if the routing + * algorithm in use implements the GW API + */ + + u32 sel_class_max = 0xffffffffu; + u32 sel_class; + + attr = info->attrs[BATADV_ATTR_GW_SEL_CLASS]; + sel_class = nla_get_u32(attr); + + if (!bat_priv->algo_ops->gw.store_sel_class) + sel_class_max = BATADV_TQ_MAX_VALUE; + + if (sel_class >= 1 && sel_class <= sel_class_max) { + atomic_set(&bat_priv->gw.sel_class, sel_class); + batadv_gw_reselect(bat_priv); + } + } + + if (info->attrs[BATADV_ATTR_HOP_PENALTY]) { + attr = info->attrs[BATADV_ATTR_HOP_PENALTY]; + + atomic_set(&bat_priv->hop_penalty, nla_get_u8(attr)); + } + +#ifdef CONFIG_BATMAN_ADV_DEBUG + if (info->attrs[BATADV_ATTR_LOG_LEVEL]) { + attr = info->attrs[BATADV_ATTR_LOG_LEVEL]; + + atomic_set(&bat_priv->log_level, + nla_get_u32(attr) & BATADV_DBG_ALL); + } +#endif /* CONFIG_BATMAN_ADV_DEBUG */ + +#ifdef CONFIG_BATMAN_ADV_MCAST + if (info->attrs[BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED]) { + attr = info->attrs[BATADV_ATTR_MULTICAST_FORCEFLOOD_ENABLED]; + + atomic_set(&bat_priv->multicast_mode, !nla_get_u8(attr)); + } + + if (info->attrs[BATADV_ATTR_MULTICAST_FANOUT]) { + attr = info->attrs[BATADV_ATTR_MULTICAST_FANOUT]; + + atomic_set(&bat_priv->multicast_fanout, nla_get_u32(attr)); + } +#endif /* CONFIG_BATMAN_ADV_MCAST */ + +#ifdef CONFIG_BATMAN_ADV_NC + if (info->attrs[BATADV_ATTR_NETWORK_CODING_ENABLED]) { + attr = info->attrs[BATADV_ATTR_NETWORK_CODING_ENABLED]; + + atomic_set(&bat_priv->network_coding, !!nla_get_u8(attr)); + batadv_nc_status_update(bat_priv->soft_iface); + } +#endif /* CONFIG_BATMAN_ADV_NC */ + + if (info->attrs[BATADV_ATTR_ORIG_INTERVAL]) { + u32 orig_interval; + + attr = info->attrs[BATADV_ATTR_ORIG_INTERVAL]; + orig_interval = nla_get_u32(attr); + + orig_interval = min_t(u32, orig_interval, INT_MAX); + orig_interval = max_t(u32, orig_interval, 2 * BATADV_JITTER); + + atomic_set(&bat_priv->orig_interval, orig_interval); + } + + batadv_netlink_notify_mesh(bat_priv); + + return 0; +} + +/** + * batadv_netlink_tp_meter_put() - Fill information of started tp_meter session + * @msg: netlink message to be sent back + * @cookie: tp meter session cookie + * + * Return: 0 on success, < 0 on error + */ +static int +batadv_netlink_tp_meter_put(struct sk_buff *msg, u32 cookie) +{ + if (nla_put_u32(msg, BATADV_ATTR_TPMETER_COOKIE, cookie)) + return -ENOBUFS; + + return 0; +} + +/** + * batadv_netlink_tpmeter_notify() - send tp_meter result via netlink to client + * @bat_priv: the bat priv with all the soft interface information + * @dst: destination of tp_meter session + * @result: reason for tp meter session stop + * @test_time: total time of the tp_meter session + * @total_bytes: bytes acked to the receiver + * @cookie: cookie of tp_meter session + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_tpmeter_notify(struct batadv_priv *bat_priv, const u8 *dst, + u8 result, u32 test_time, u64 total_bytes, + u32 cookie) +{ + struct sk_buff *msg; + void *hdr; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if 
(!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &batadv_netlink_family, 0, + BATADV_CMD_TP_METER); + if (!hdr) { + ret = -ENOBUFS; + goto err_genlmsg; + } + + if (nla_put_u32(msg, BATADV_ATTR_TPMETER_COOKIE, cookie)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_TPMETER_TEST_TIME, test_time)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, BATADV_ATTR_TPMETER_BYTES, total_bytes, + BATADV_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_TPMETER_RESULT, result)) + goto nla_put_failure; + + if (nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, dst)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_TPMETER, GFP_KERNEL); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + ret = -EMSGSIZE; + +err_genlmsg: + nlmsg_free(msg); + return ret; +} + +/** + * batadv_netlink_tp_meter_start() - Start a new tp_meter session + * @skb: received netlink message + * @info: receiver information + * + * Return: 0 on success, < 0 on error + */ +static int +batadv_netlink_tp_meter_start(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg = NULL; + u32 test_length; + void *msg_head; + u32 cookie; + u8 *dst; + int ret; + + if (!info->attrs[BATADV_ATTR_ORIG_ADDRESS]) + return -EINVAL; + + if (!info->attrs[BATADV_ATTR_TPMETER_TEST_TIME]) + return -EINVAL; + + dst = nla_data(info->attrs[BATADV_ATTR_ORIG_ADDRESS]); + + test_length = nla_get_u32(info->attrs[BATADV_ATTR_TPMETER_TEST_TIME]); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto out; + } + + msg_head = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &batadv_netlink_family, 0, + BATADV_CMD_TP_METER); + if (!msg_head) { + ret = -ENOBUFS; + goto out; + } + + batadv_tp_start(bat_priv, dst, test_length, &cookie); + + ret = batadv_netlink_tp_meter_put(msg, cookie); + + out: + if (ret) { + if (msg) + nlmsg_free(msg); + return ret; + } + + genlmsg_end(msg, msg_head); + return genlmsg_reply(msg, info); +} + +/** + * batadv_netlink_tp_meter_cancel() - Cancel a running tp_meter session + * @skb: received netlink message + * @info: receiver information + * + * Return: 0 on success, < 0 on error + */ +static int +batadv_netlink_tp_meter_cancel(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_priv *bat_priv = info->user_ptr[0]; + u8 *dst; + int ret = 0; + + if (!info->attrs[BATADV_ATTR_ORIG_ADDRESS]) + return -EINVAL; + + dst = nla_data(info->attrs[BATADV_ATTR_ORIG_ADDRESS]); + + batadv_tp_stop(bat_priv, dst, BATADV_TP_REASON_CANCEL); + + return ret; +} + +/** + * batadv_netlink_hardif_fill() - Fill message with hardif attributes + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @hard_iface: hard interface which was modified + * @cmd: type of message to generate + * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message + * @cb: Control block containing additional options + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_hardif_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags, + struct netlink_callback *cb) +{ + struct net_device *net_dev = hard_iface->net_dev; + void *hdr; + + hdr = 
genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (cb) + genl_dump_check_consistent(cb, hdr); + + if (nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, + bat_priv->soft_iface->ifindex)) + goto nla_put_failure; + + if (nla_put_string(msg, BATADV_ATTR_MESH_IFNAME, + bat_priv->soft_iface->name)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_HARD_IFINDEX, + net_dev->ifindex) || + nla_put_string(msg, BATADV_ATTR_HARD_IFNAME, + net_dev->name) || + nla_put(msg, BATADV_ATTR_HARD_ADDRESS, ETH_ALEN, + net_dev->dev_addr)) + goto nla_put_failure; + + if (hard_iface->if_status == BATADV_IF_ACTIVE) { + if (nla_put_flag(msg, BATADV_ATTR_ACTIVE)) + goto nla_put_failure; + } + + if (nla_put_u8(msg, BATADV_ATTR_HOP_PENALTY, + atomic_read(&hard_iface->hop_penalty))) + goto nla_put_failure; + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + if (nla_put_u32(msg, BATADV_ATTR_ELP_INTERVAL, + atomic_read(&hard_iface->bat_v.elp_interval))) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_THROUGHPUT_OVERRIDE, + atomic_read(&hard_iface->bat_v.throughput_override))) + goto nla_put_failure; +#endif /* CONFIG_BATMAN_ADV_BATMAN_V */ + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_netlink_notify_hardif() - send hardif attributes to listener + * @bat_priv: the bat priv with all the soft interface information + * @hard_iface: hard interface which was modified + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_notify_hardif(struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_SET_HARDIF, 0, 0, 0, NULL); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_hardif() - Get hardif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_hardif(struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_GET_HARDIF, + info->snd_portid, info->snd_seq, 0, + NULL); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_hardif() - Set hardif attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_hardif(struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_HOP_PENALTY]) { + attr = info->attrs[BATADV_ATTR_HOP_PENALTY]; + + atomic_set(&hard_iface->hop_penalty, nla_get_u8(attr)); + } + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + + if (info->attrs[BATADV_ATTR_ELP_INTERVAL]) { + attr = 
info->attrs[BATADV_ATTR_ELP_INTERVAL]; + + atomic_set(&hard_iface->bat_v.elp_interval, nla_get_u32(attr)); + } + + if (info->attrs[BATADV_ATTR_THROUGHPUT_OVERRIDE]) { + attr = info->attrs[BATADV_ATTR_THROUGHPUT_OVERRIDE]; + + atomic_set(&hard_iface->bat_v.throughput_override, + nla_get_u32(attr)); + } +#endif /* CONFIG_BATMAN_ADV_BATMAN_V */ + + batadv_netlink_notify_hardif(bat_priv, hard_iface); + + return 0; +} + +/** + * batadv_netlink_dump_hardif() - Dump all hard interface into a messages + * @msg: Netlink message to dump into + * @cb: Parameters from query + * + * Return: error code, or length of reply message on success + */ +static int +batadv_netlink_dump_hardif(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_hard_iface *hard_iface; + struct batadv_priv *bat_priv; + int ifindex; + int portid = NETLINK_CB(cb->skb).portid; + int skip = cb->args[0]; + int i = 0; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface) + return -ENODEV; + + if (!batadv_softif_is_valid(soft_iface)) { + dev_put(soft_iface); + return -ENODEV; + } + + bat_priv = netdev_priv(soft_iface); + + rtnl_lock(); + cb->seq = batadv_hardif_generation << 1 | 1; + + list_for_each_entry(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != soft_iface) + continue; + + if (i++ < skip) + continue; + + if (batadv_netlink_hardif_fill(msg, bat_priv, hard_iface, + BATADV_CMD_GET_HARDIF, + portid, cb->nlh->nlmsg_seq, + NLM_F_MULTI, cb)) { + i--; + break; + } + } + + rtnl_unlock(); + + dev_put(soft_iface); + + cb->args[0] = i; + + return msg->len; +} + +/** + * batadv_netlink_vlan_fill() - Fill message with vlan attributes + * @msg: Netlink message to dump into + * @bat_priv: the bat priv with all the soft interface information + * @vlan: vlan which was modified + * @cmd: type of message to generate + * @portid: Port making netlink request + * @seq: sequence number for message + * @flags: Additional flags for message + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_vlan_fill(struct sk_buff *msg, + struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan, + enum batadv_nl_commands cmd, + u32 portid, u32 seq, int flags) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, + bat_priv->soft_iface->ifindex)) + goto nla_put_failure; + + if (nla_put_string(msg, BATADV_ATTR_MESH_IFNAME, + bat_priv->soft_iface->name)) + goto nla_put_failure; + + if (nla_put_u32(msg, BATADV_ATTR_VLANID, vlan->vid & VLAN_VID_MASK)) + goto nla_put_failure; + + if (nla_put_u8(msg, BATADV_ATTR_AP_ISOLATION_ENABLED, + !!atomic_read(&vlan->ap_isolation))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_netlink_notify_vlan() - send vlan attributes to listener + * @bat_priv: the bat priv with all the soft interface information + * @vlan: vlan which was modified + * + * Return: 0 on success, < 0 on error + */ +int batadv_netlink_notify_vlan(struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = 
batadv_netlink_vlan_fill(msg, bat_priv, vlan, + BATADV_CMD_SET_VLAN, 0, 0, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + genlmsg_multicast_netns(&batadv_netlink_family, + dev_net(bat_priv->soft_iface), msg, 0, + BATADV_NL_MCGRP_CONFIG, GFP_KERNEL); + + return 0; +} + +/** + * batadv_netlink_get_vlan() - Get vlan attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_get_vlan(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_softif_vlan *vlan = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = batadv_netlink_vlan_fill(msg, bat_priv, vlan, BATADV_CMD_GET_VLAN, + info->snd_portid, info->snd_seq, 0); + if (ret < 0) { + nlmsg_free(msg); + return ret; + } + + ret = genlmsg_reply(msg, info); + + return ret; +} + +/** + * batadv_netlink_set_vlan() - Get vlan attributes + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_netlink_set_vlan(struct sk_buff *skb, struct genl_info *info) +{ + struct batadv_softif_vlan *vlan = info->user_ptr[1]; + struct batadv_priv *bat_priv = info->user_ptr[0]; + struct nlattr *attr; + + if (info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]) { + attr = info->attrs[BATADV_ATTR_AP_ISOLATION_ENABLED]; + + atomic_set(&vlan->ap_isolation, !!nla_get_u8(attr)); + } + + batadv_netlink_notify_vlan(bat_priv, vlan); + + return 0; +} + +/** + * batadv_get_softif_from_info() - Retrieve soft interface from genl attributes + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to soft interface (with increased refcnt) on success, error + * pointer on error + */ +static struct net_device * +batadv_get_softif_from_info(struct net *net, struct genl_info *info) +{ + struct net_device *soft_iface; + int ifindex; + + if (!info->attrs[BATADV_ATTR_MESH_IFINDEX]) + return ERR_PTR(-EINVAL); + + ifindex = nla_get_u32(info->attrs[BATADV_ATTR_MESH_IFINDEX]); + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface) + return ERR_PTR(-ENODEV); + + if (!batadv_softif_is_valid(soft_iface)) + goto err_put_softif; + + return soft_iface; + +err_put_softif: + dev_put(soft_iface); + + return ERR_PTR(-EINVAL); +} + +/** + * batadv_get_hardif_from_info() - Retrieve hardif from genl attributes + * @bat_priv: the bat priv with all the soft interface information + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to hard interface (with increased refcnt) on success, error + * pointer on error + */ +static struct batadv_hard_iface * +batadv_get_hardif_from_info(struct batadv_priv *bat_priv, struct net *net, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface; + struct net_device *hard_dev; + unsigned int hardif_index; + + if (!info->attrs[BATADV_ATTR_HARD_IFINDEX]) + return ERR_PTR(-EINVAL); + + hardif_index = nla_get_u32(info->attrs[BATADV_ATTR_HARD_IFINDEX]); + + hard_dev = dev_get_by_index(net, hardif_index); + if (!hard_dev) + return ERR_PTR(-ENODEV); + + hard_iface = batadv_hardif_get_by_netdev(hard_dev); + if (!hard_iface) + goto err_put_harddev; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + goto err_put_hardif; + + /* hard_dev is referenced by hard_iface and not 
needed here */ + dev_put(hard_dev); + + return hard_iface; + +err_put_hardif: + batadv_hardif_put(hard_iface); +err_put_harddev: + dev_put(hard_dev); + + return ERR_PTR(-EINVAL); +} + +/** + * batadv_get_vlan_from_info() - Retrieve vlan from genl attributes + * @bat_priv: the bat priv with all the soft interface information + * @net: the applicable net namespace + * @info: receiver information + * + * Return: Pointer to vlan on success (with increased refcnt), error pointer + * on error + */ +static struct batadv_softif_vlan * +batadv_get_vlan_from_info(struct batadv_priv *bat_priv, struct net *net, + struct genl_info *info) +{ + struct batadv_softif_vlan *vlan; + u16 vid; + + if (!info->attrs[BATADV_ATTR_VLANID]) + return ERR_PTR(-EINVAL); + + vid = nla_get_u16(info->attrs[BATADV_ATTR_VLANID]); + + vlan = batadv_softif_vlan_get(bat_priv, vid | BATADV_VLAN_HAS_TAG); + if (!vlan) + return ERR_PTR(-ENOENT); + + return vlan; +} + +/** + * batadv_pre_doit() - Prepare batman-adv genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + * + * Return: 0 on success or negative error number in case of failure + */ +static int batadv_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct batadv_hard_iface *hard_iface; + struct batadv_priv *bat_priv = NULL; + struct batadv_softif_vlan *vlan; + struct net_device *soft_iface; + u8 user_ptr1_flags; + u8 mesh_dep_flags; + int ret; + + user_ptr1_flags = BATADV_FLAG_NEED_HARDIF | BATADV_FLAG_NEED_VLAN; + if (WARN_ON(hweight8(ops->internal_flags & user_ptr1_flags) > 1)) + return -EINVAL; + + mesh_dep_flags = BATADV_FLAG_NEED_HARDIF | BATADV_FLAG_NEED_VLAN; + if (WARN_ON((ops->internal_flags & mesh_dep_flags) && + (~ops->internal_flags & BATADV_FLAG_NEED_MESH))) + return -EINVAL; + + if (ops->internal_flags & BATADV_FLAG_NEED_MESH) { + soft_iface = batadv_get_softif_from_info(net, info); + if (IS_ERR(soft_iface)) + return PTR_ERR(soft_iface); + + bat_priv = netdev_priv(soft_iface); + info->user_ptr[0] = bat_priv; + } + + if (ops->internal_flags & BATADV_FLAG_NEED_HARDIF) { + hard_iface = batadv_get_hardif_from_info(bat_priv, net, info); + if (IS_ERR(hard_iface)) { + ret = PTR_ERR(hard_iface); + goto err_put_softif; + } + + info->user_ptr[1] = hard_iface; + } + + if (ops->internal_flags & BATADV_FLAG_NEED_VLAN) { + vlan = batadv_get_vlan_from_info(bat_priv, net, info); + if (IS_ERR(vlan)) { + ret = PTR_ERR(vlan); + goto err_put_softif; + } + + info->user_ptr[1] = vlan; + } + + return 0; + +err_put_softif: + if (bat_priv) + dev_put(bat_priv->soft_iface); + + return ret; +} + +/** + * batadv_post_doit() - End batman-adv genl doit request + * @ops: requested netlink operation + * @skb: Netlink message with request data + * @info: receiver information + */ +static void batadv_post_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct batadv_hard_iface *hard_iface; + struct batadv_softif_vlan *vlan; + struct batadv_priv *bat_priv; + + if (ops->internal_flags & BATADV_FLAG_NEED_HARDIF && + info->user_ptr[1]) { + hard_iface = info->user_ptr[1]; + + batadv_hardif_put(hard_iface); + } + + if (ops->internal_flags & BATADV_FLAG_NEED_VLAN && info->user_ptr[1]) { + vlan = info->user_ptr[1]; + batadv_softif_vlan_put(vlan); + } + + if (ops->internal_flags & BATADV_FLAG_NEED_MESH && info->user_ptr[0]) { + bat_priv = info->user_ptr[0]; + dev_put(bat_priv->soft_iface); + } +} + 
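
The batadv_pre_doit()/batadv_post_doit() pair above lets every doit handler rely on info->user_ptr[0] holding a referenced bat_priv (and, depending on internal_flags, user_ptr[1] holding a hard interface or VLAN) without repeating the lookups. For context, the resulting generic netlink family can also be exercised from userspace; the sketch below is only an illustrative aside and not part of this patch. It assumes libnl-genl-3 and the uapi header <linux/batman_adv.h> are available, that a soft interface named "bat0" exists, and it omits error handling for brevity.

/* Minimal userspace sketch: query BATADV_CMD_GET_MESH via libnl-3.
 * Build (illustrative): cc get_mesh.c $(pkg-config --cflags --libs libnl-genl-3.0)
 */
#include <stdio.h>
#include <net/if.h>
#include <netlink/netlink.h>
#include <netlink/genl/genl.h>
#include <netlink/genl/ctrl.h>
#include <linux/batman_adv.h>

static int mesh_cb(struct nl_msg *msg, void *arg)
{
	struct nlattr *attrs[BATADV_ATTR_MAX + 1];

	/* parse the reply built by batadv_netlink_mesh_fill() */
	if (genlmsg_parse(nlmsg_hdr(msg), 0, attrs, BATADV_ATTR_MAX, NULL))
		return NL_SKIP;

	if (attrs[BATADV_ATTR_ALGO_NAME])
		printf("algo: %s\n",
		       nla_get_string(attrs[BATADV_ATTR_ALGO_NAME]));
	if (attrs[BATADV_ATTR_ORIG_INTERVAL])
		printf("orig_interval: %u ms\n",
		       nla_get_u32(attrs[BATADV_ATTR_ORIG_INTERVAL]));

	return NL_OK;
}

int main(void)
{
	struct nl_sock *sk = nl_socket_alloc();
	struct nl_msg *msg;
	int family;

	genl_connect(sk);
	family = genl_ctrl_resolve(sk, BATADV_NL_NAME);

	nl_socket_modify_cb(sk, NL_CB_VALID, NL_CB_CUSTOM, mesh_cb, NULL);

	msg = nlmsg_alloc();
	genlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, family, 0, 0,
		    BATADV_CMD_GET_MESH, 1);
	/* "bat0" is an assumed interface name; BATADV_FLAG_NEED_MESH means the
	 * doit handler requires a valid mesh ifindex in this attribute.
	 */
	nla_put_u32(msg, BATADV_ATTR_MESH_IFINDEX, if_nametoindex("bat0"));

	nl_send_auto(sk, msg);
	nl_recvmsgs_default(sk);

	nlmsg_free(msg);
	nl_socket_free(sk);
	return 0;
}

On the kernel side such a request takes the BATADV_CMD_GET_MESH path registered below: batadv_pre_doit() resolves the soft interface from BATADV_ATTR_MESH_IFINDEX, batadv_netlink_get_mesh() builds the reply through batadv_netlink_mesh_fill(), and batadv_post_doit() drops the acquired reference again.
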
+static const struct genl_small_ops batadv_netlink_ops[] = { + { + .cmd = BATADV_CMD_GET_MESH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + /* can be retrieved by unprivileged users */ + .doit = batadv_netlink_get_mesh, + .internal_flags = BATADV_FLAG_NEED_MESH, + }, + { + .cmd = BATADV_CMD_TP_METER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = batadv_netlink_tp_meter_start, + .internal_flags = BATADV_FLAG_NEED_MESH, + }, + { + .cmd = BATADV_CMD_TP_METER_CANCEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = batadv_netlink_tp_meter_cancel, + .internal_flags = BATADV_FLAG_NEED_MESH, + }, + { + .cmd = BATADV_CMD_GET_ROUTING_ALGOS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_algo_dump, + }, + { + .cmd = BATADV_CMD_GET_HARDIF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + /* can be retrieved by unprivileged users */ + .dumpit = batadv_netlink_dump_hardif, + .doit = batadv_netlink_get_hardif, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_HARDIF, + }, + { + .cmd = BATADV_CMD_GET_TRANSTABLE_LOCAL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_tt_local_dump, + }, + { + .cmd = BATADV_CMD_GET_TRANSTABLE_GLOBAL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_tt_global_dump, + }, + { + .cmd = BATADV_CMD_GET_ORIGINATORS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_orig_dump, + }, + { + .cmd = BATADV_CMD_GET_NEIGHBORS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_hardif_neigh_dump, + }, + { + .cmd = BATADV_CMD_GET_GATEWAYS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_gw_dump, + }, + { + .cmd = BATADV_CMD_GET_BLA_CLAIM, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_bla_claim_dump, + }, + { + .cmd = BATADV_CMD_GET_BLA_BACKBONE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_bla_backbone_dump, + }, + { + .cmd = BATADV_CMD_GET_DAT_CACHE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_dat_cache_dump, + }, + { + .cmd = BATADV_CMD_GET_MCAST_FLAGS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .dumpit = batadv_mcast_flags_dump, + }, + { + .cmd = BATADV_CMD_SET_MESH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = batadv_netlink_set_mesh, + .internal_flags = BATADV_FLAG_NEED_MESH, + }, + { + .cmd = BATADV_CMD_SET_HARDIF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = batadv_netlink_set_hardif, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_HARDIF, + }, + { + .cmd = BATADV_CMD_GET_VLAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + /* can be retrieved by unprivileged users */ + .doit = batadv_netlink_get_vlan, + .internal_flags = BATADV_FLAG_NEED_MESH | + 
BATADV_FLAG_NEED_VLAN, + }, + { + .cmd = BATADV_CMD_SET_VLAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = batadv_netlink_set_vlan, + .internal_flags = BATADV_FLAG_NEED_MESH | + BATADV_FLAG_NEED_VLAN, + }, +}; + +struct genl_family batadv_netlink_family __ro_after_init = { + .hdrsize = 0, + .name = BATADV_NL_NAME, + .version = 1, + .maxattr = BATADV_ATTR_MAX, + .policy = batadv_netlink_policy, + .netnsok = true, + .pre_doit = batadv_pre_doit, + .post_doit = batadv_post_doit, + .module = THIS_MODULE, + .small_ops = batadv_netlink_ops, + .n_small_ops = ARRAY_SIZE(batadv_netlink_ops), + .resv_start_op = BATADV_CMD_SET_VLAN + 1, + .mcgrps = batadv_netlink_mcgrps, + .n_mcgrps = ARRAY_SIZE(batadv_netlink_mcgrps), +}; + +/** + * batadv_netlink_register() - register batadv genl netlink family + */ +void __init batadv_netlink_register(void) +{ + int ret; + + ret = genl_register_family(&batadv_netlink_family); + if (ret) + pr_warn("unable to register netlink family"); +} + +/** + * batadv_netlink_unregister() - unregister batadv genl netlink family + */ +void batadv_netlink_unregister(void) +{ + genl_unregister_family(&batadv_netlink_family); +} diff --git a/net/batman-adv/netlink.h b/net/batman-adv/netlink.h new file mode 100644 index 000000000..48102cc74 --- /dev/null +++ b/net/batman-adv/netlink.h @@ -0,0 +1,32 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Matthias Schiffer + */ + +#ifndef _NET_BATMAN_ADV_NETLINK_H_ +#define _NET_BATMAN_ADV_NETLINK_H_ + +#include "main.h" + +#include +#include +#include + +void batadv_netlink_register(void); +void batadv_netlink_unregister(void); +int batadv_netlink_get_ifindex(const struct nlmsghdr *nlh, int attrtype); + +int batadv_netlink_tpmeter_notify(struct batadv_priv *bat_priv, const u8 *dst, + u8 result, u32 test_time, u64 total_bytes, + u32 cookie); + +int batadv_netlink_notify_mesh(struct batadv_priv *bat_priv); +int batadv_netlink_notify_hardif(struct batadv_priv *bat_priv, + struct batadv_hard_iface *hard_iface); +int batadv_netlink_notify_vlan(struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan); + +extern struct genl_family batadv_netlink_family; + +#endif /* _NET_BATMAN_ADV_NETLINK_H_ */ diff --git a/net/batman-adv/network-coding.c b/net/batman-adv/network-coding.c new file mode 100644 index 000000000..5f4aeeb60 --- /dev/null +++ b/net/batman-adv/network-coding.c @@ -0,0 +1,1873 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Martin Hundebøll, Jeppe Ledet-Pedersen + */ + +#include "network-coding.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hash.h" +#include "log.h" +#include "originator.h" +#include "routing.h" +#include "send.h" +#include "tvlv.h" + +static struct lock_class_key batadv_nc_coding_hash_lock_class_key; +static struct lock_class_key batadv_nc_decoding_hash_lock_class_key; + +static void batadv_nc_worker(struct work_struct *work); +static int batadv_nc_recv_coded_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); + +/** + * batadv_nc_init() - one-time initialization for network coding + * + * Return: 0 on success or negative error number in case of failure + */ +int __init batadv_nc_init(void) +{ + /* Register our packet type */ + return batadv_recv_handler_register(BATADV_CODED, + batadv_nc_recv_coded_packet); +} + +/** + * batadv_nc_start_timer() - initialise the nc periodic worker + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_nc_start_timer(struct batadv_priv *bat_priv) +{ + queue_delayed_work(batadv_event_workqueue, &bat_priv->nc.work, + msecs_to_jiffies(10)); +} + +/** + * batadv_nc_tvlv_container_update() - update the network coding tvlv container + * after network coding setting change + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_nc_tvlv_container_update(struct batadv_priv *bat_priv) +{ + char nc_mode; + + nc_mode = atomic_read(&bat_priv->network_coding); + + switch (nc_mode) { + case 0: + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_NC, 1); + break; + case 1: + batadv_tvlv_container_register(bat_priv, BATADV_TVLV_NC, 1, + NULL, 0); + break; + } +} + +/** + * batadv_nc_status_update() - update the network coding tvlv container after + * network coding setting change + * @net_dev: the soft interface net device + */ +void batadv_nc_status_update(struct net_device *net_dev) +{ + struct batadv_priv *bat_priv = netdev_priv(net_dev); + + batadv_nc_tvlv_container_update(bat_priv); +} + +/** + * batadv_nc_tvlv_ogm_handler_v1() - process incoming nc tvlv container + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node of the ogm + * @flags: flags indicating the tvlv state (see batadv_tvlv_handler_flags) + * @tvlv_value: tvlv buffer containing the gateway data + * @tvlv_value_len: tvlv buffer length + */ +static void batadv_nc_tvlv_ogm_handler_v1(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, u16 tvlv_value_len) +{ + if (flags & BATADV_TVLV_HANDLER_OGM_CIFNOTFND) + clear_bit(BATADV_ORIG_CAPA_HAS_NC, &orig->capabilities); + else + set_bit(BATADV_ORIG_CAPA_HAS_NC, &orig->capabilities); +} + +/** + * batadv_nc_mesh_init() - initialise coding hash table and start housekeeping + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_nc_mesh_init(struct batadv_priv *bat_priv) +{ + bat_priv->nc.timestamp_fwd_flush = jiffies; + bat_priv->nc.timestamp_sniffed_purge = jiffies; + + if (bat_priv->nc.coding_hash || bat_priv->nc.decoding_hash) + return 0; + + bat_priv->nc.coding_hash = batadv_hash_new(128); + if 
(!bat_priv->nc.coding_hash) + goto err; + + batadv_hash_set_lock_class(bat_priv->nc.coding_hash, + &batadv_nc_coding_hash_lock_class_key); + + bat_priv->nc.decoding_hash = batadv_hash_new(128); + if (!bat_priv->nc.decoding_hash) { + batadv_hash_destroy(bat_priv->nc.coding_hash); + goto err; + } + + batadv_hash_set_lock_class(bat_priv->nc.decoding_hash, + &batadv_nc_decoding_hash_lock_class_key); + + INIT_DELAYED_WORK(&bat_priv->nc.work, batadv_nc_worker); + batadv_nc_start_timer(bat_priv); + + batadv_tvlv_handler_register(bat_priv, batadv_nc_tvlv_ogm_handler_v1, + NULL, BATADV_TVLV_NC, 1, + BATADV_TVLV_HANDLER_OGM_CIFNOTFND); + batadv_nc_tvlv_container_update(bat_priv); + return 0; + +err: + return -ENOMEM; +} + +/** + * batadv_nc_init_bat_priv() - initialise the nc specific bat_priv variables + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_nc_init_bat_priv(struct batadv_priv *bat_priv) +{ + atomic_set(&bat_priv->network_coding, 0); + bat_priv->nc.min_tq = 200; + bat_priv->nc.max_fwd_delay = 10; + bat_priv->nc.max_buffer_time = 200; +} + +/** + * batadv_nc_init_orig() - initialise the nc fields of an orig_node + * @orig_node: the orig_node which is going to be initialised + */ +void batadv_nc_init_orig(struct batadv_orig_node *orig_node) +{ + INIT_LIST_HEAD(&orig_node->in_coding_list); + INIT_LIST_HEAD(&orig_node->out_coding_list); + spin_lock_init(&orig_node->in_coding_list_lock); + spin_lock_init(&orig_node->out_coding_list_lock); +} + +/** + * batadv_nc_node_release() - release nc_node from lists and queue for free + * after rcu grace period + * @ref: kref pointer of the nc_node + */ +static void batadv_nc_node_release(struct kref *ref) +{ + struct batadv_nc_node *nc_node; + + nc_node = container_of(ref, struct batadv_nc_node, refcount); + + batadv_orig_node_put(nc_node->orig_node); + kfree_rcu(nc_node, rcu); +} + +/** + * batadv_nc_node_put() - decrement the nc_node refcounter and possibly + * release it + * @nc_node: nc_node to be free'd + */ +static void batadv_nc_node_put(struct batadv_nc_node *nc_node) +{ + if (!nc_node) + return; + + kref_put(&nc_node->refcount, batadv_nc_node_release); +} + +/** + * batadv_nc_path_release() - release nc_path from lists and queue for free + * after rcu grace period + * @ref: kref pointer of the nc_path + */ +static void batadv_nc_path_release(struct kref *ref) +{ + struct batadv_nc_path *nc_path; + + nc_path = container_of(ref, struct batadv_nc_path, refcount); + + kfree_rcu(nc_path, rcu); +} + +/** + * batadv_nc_path_put() - decrement the nc_path refcounter and possibly + * release it + * @nc_path: nc_path to be free'd + */ +static void batadv_nc_path_put(struct batadv_nc_path *nc_path) +{ + if (!nc_path) + return; + + kref_put(&nc_path->refcount, batadv_nc_path_release); +} + +/** + * batadv_nc_packet_free() - frees nc packet + * @nc_packet: the nc packet to free + * @dropped: whether the packet is freed because is dropped + */ +static void batadv_nc_packet_free(struct batadv_nc_packet *nc_packet, + bool dropped) +{ + if (dropped) + kfree_skb(nc_packet->skb); + else + consume_skb(nc_packet->skb); + + batadv_nc_path_put(nc_packet->nc_path); + kfree(nc_packet); +} + +/** + * batadv_nc_to_purge_nc_node() - checks whether an nc node has to be purged + * @bat_priv: the bat priv with all the soft interface information + * @nc_node: the nc node to check + * + * Return: true if the entry has to be purged now, false otherwise + */ +static bool batadv_nc_to_purge_nc_node(struct batadv_priv *bat_priv, + struct 
batadv_nc_node *nc_node) +{ + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + return true; + + return batadv_has_timed_out(nc_node->last_seen, BATADV_NC_NODE_TIMEOUT); +} + +/** + * batadv_nc_to_purge_nc_path_coding() - checks whether an nc path has timed out + * @bat_priv: the bat priv with all the soft interface information + * @nc_path: the nc path to check + * + * Return: true if the entry has to be purged now, false otherwise + */ +static bool batadv_nc_to_purge_nc_path_coding(struct batadv_priv *bat_priv, + struct batadv_nc_path *nc_path) +{ + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + return true; + + /* purge the path when no packets has been added for 10 times the + * max_fwd_delay time + */ + return batadv_has_timed_out(nc_path->last_valid, + bat_priv->nc.max_fwd_delay * 10); +} + +/** + * batadv_nc_to_purge_nc_path_decoding() - checks whether an nc path has timed + * out + * @bat_priv: the bat priv with all the soft interface information + * @nc_path: the nc path to check + * + * Return: true if the entry has to be purged now, false otherwise + */ +static bool batadv_nc_to_purge_nc_path_decoding(struct batadv_priv *bat_priv, + struct batadv_nc_path *nc_path) +{ + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + return true; + + /* purge the path when no packets has been added for 10 times the + * max_buffer time + */ + return batadv_has_timed_out(nc_path->last_valid, + bat_priv->nc.max_buffer_time * 10); +} + +/** + * batadv_nc_purge_orig_nc_nodes() - go through list of nc nodes and purge stale + * entries + * @bat_priv: the bat priv with all the soft interface information + * @list: list of nc nodes + * @lock: nc node list lock + * @to_purge: function in charge to decide whether an entry has to be purged or + * not. This function takes the nc node as argument and has to return + * a boolean value: true if the entry has to be deleted, false + * otherwise + */ +static void +batadv_nc_purge_orig_nc_nodes(struct batadv_priv *bat_priv, + struct list_head *list, + spinlock_t *lock, + bool (*to_purge)(struct batadv_priv *, + struct batadv_nc_node *)) +{ + struct batadv_nc_node *nc_node, *nc_node_tmp; + + /* For each nc_node in list */ + spin_lock_bh(lock); + list_for_each_entry_safe(nc_node, nc_node_tmp, list, list) { + /* if an helper function has been passed as parameter, + * ask it if the entry has to be purged or not + */ + if (to_purge && !to_purge(bat_priv, nc_node)) + continue; + + batadv_dbg(BATADV_DBG_NC, bat_priv, + "Removing nc_node %pM -> %pM\n", + nc_node->addr, nc_node->orig_node->orig); + list_del_rcu(&nc_node->list); + batadv_nc_node_put(nc_node); + } + spin_unlock_bh(lock); +} + +/** + * batadv_nc_purge_orig() - purges all nc node data attached of the given + * originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig_node with the nc node entries to be purged + * @to_purge: function in charge to decide whether an entry has to be purged or + * not. 
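/*
 * Editor's illustrative sketch (not part of the patch): the purge helpers
 * above keep an nc_path only while packets keep arriving; a path becomes
 * stale once nothing has been added for ten times the forwarding delay
 * (or ten times the buffering time for the decoding hash). The userspace
 * analogue below models that rule with a monotonic clock. timed_out(),
 * nc_path_stale() and the millisecond constants are hypothetical names
 * for this example only, not batman-adv symbols.
 */
#include <stdbool.h>
#include <stdint.h>
#include <time.h>

static uint64_t now_ms(void)
{
        struct timespec ts;

        clock_gettime(CLOCK_MONOTONIC, &ts);
        return (uint64_t)ts.tv_sec * 1000 + ts.tv_nsec / 1000000;
}

/* analogue of a jiffies-based timeout test: true once timeout_ms elapsed */
static bool timed_out(uint64_t timestamp_ms, uint64_t timeout_ms)
{
        return now_ms() > timestamp_ms + timeout_ms;
}

/* analogue of the coding-path purge rule with max_fwd_delay = 10 ms */
static bool nc_path_stale(uint64_t last_valid_ms, bool mesh_active)
{
        const uint64_t max_fwd_delay_ms = 10;

        if (!mesh_active)
                return true;

        return timed_out(last_valid_ms, 10 * max_fwd_delay_ms);
}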
This function takes the nc node as argument and has to return + * a boolean value: true is the entry has to be deleted, false + * otherwise + */ +void batadv_nc_purge_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + bool (*to_purge)(struct batadv_priv *, + struct batadv_nc_node *)) +{ + /* Check ingoing nc_node's of this orig_node */ + batadv_nc_purge_orig_nc_nodes(bat_priv, &orig_node->in_coding_list, + &orig_node->in_coding_list_lock, + to_purge); + + /* Check outgoing nc_node's of this orig_node */ + batadv_nc_purge_orig_nc_nodes(bat_priv, &orig_node->out_coding_list, + &orig_node->out_coding_list_lock, + to_purge); +} + +/** + * batadv_nc_purge_orig_hash() - traverse entire originator hash to check if + * they have timed out nc nodes + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_nc_purge_orig_hash(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + struct batadv_orig_node *orig_node; + u32 i; + + if (!hash) + return; + + /* For each orig_node */ + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) + batadv_nc_purge_orig(bat_priv, orig_node, + batadv_nc_to_purge_nc_node); + rcu_read_unlock(); + } +} + +/** + * batadv_nc_purge_paths() - traverse all nc paths part of the hash and remove + * unused ones + * @bat_priv: the bat priv with all the soft interface information + * @hash: hash table containing the nc paths to check + * @to_purge: function in charge to decide whether an entry has to be purged or + * not. This function takes the nc node as argument and has to return + * a boolean value: true is the entry has to be deleted, false + * otherwise + */ +static void batadv_nc_purge_paths(struct batadv_priv *bat_priv, + struct batadv_hashtable *hash, + bool (*to_purge)(struct batadv_priv *, + struct batadv_nc_path *)) +{ + struct hlist_head *head; + struct hlist_node *node_tmp; + struct batadv_nc_path *nc_path; + spinlock_t *lock; /* Protects lists in hash */ + u32 i; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + lock = &hash->list_locks[i]; + + /* For each nc_path in this bin */ + spin_lock_bh(lock); + hlist_for_each_entry_safe(nc_path, node_tmp, head, hash_entry) { + /* if an helper function has been passed as parameter, + * ask it if the entry has to be purged or not + */ + if (to_purge && !to_purge(bat_priv, nc_path)) + continue; + + /* purging an non-empty nc_path should never happen, but + * is observed under high CPU load. Delay the purging + * until next iteration to allow the packet_list to be + * emptied first. 
+ */ + if (!unlikely(list_empty(&nc_path->packet_list))) { + net_ratelimited_function(printk, + KERN_WARNING + "Skipping free of non-empty nc_path (%pM -> %pM)!\n", + nc_path->prev_hop, + nc_path->next_hop); + continue; + } + + /* nc_path is unused, so remove it */ + batadv_dbg(BATADV_DBG_NC, bat_priv, + "Remove nc_path %pM -> %pM\n", + nc_path->prev_hop, nc_path->next_hop); + hlist_del_rcu(&nc_path->hash_entry); + batadv_nc_path_put(nc_path); + } + spin_unlock_bh(lock); + } +} + +/** + * batadv_nc_hash_key_gen() - computes the nc_path hash key + * @key: buffer to hold the final hash key + * @src: source ethernet mac address going into the hash key + * @dst: destination ethernet mac address going into the hash key + */ +static void batadv_nc_hash_key_gen(struct batadv_nc_path *key, const char *src, + const char *dst) +{ + memcpy(key->prev_hop, src, sizeof(key->prev_hop)); + memcpy(key->next_hop, dst, sizeof(key->next_hop)); +} + +/** + * batadv_nc_hash_choose() - compute the hash value for an nc path + * @data: data to hash + * @size: size of the hash table + * + * Return: the selected index in the hash table for the given data. + */ +static u32 batadv_nc_hash_choose(const void *data, u32 size) +{ + const struct batadv_nc_path *nc_path = data; + u32 hash = 0; + + hash = jhash(&nc_path->prev_hop, sizeof(nc_path->prev_hop), hash); + hash = jhash(&nc_path->next_hop, sizeof(nc_path->next_hop), hash); + + return hash % size; +} + +/** + * batadv_nc_hash_compare() - comparing function used in the network coding hash + * tables + * @node: node in the local table + * @data2: second object to compare the node to + * + * Return: true if the two entry are the same, false otherwise + */ +static bool batadv_nc_hash_compare(const struct hlist_node *node, + const void *data2) +{ + const struct batadv_nc_path *nc_path1, *nc_path2; + + nc_path1 = container_of(node, struct batadv_nc_path, hash_entry); + nc_path2 = data2; + + /* Return 1 if the two keys are identical */ + if (!batadv_compare_eth(nc_path1->prev_hop, nc_path2->prev_hop)) + return false; + + if (!batadv_compare_eth(nc_path1->next_hop, nc_path2->next_hop)) + return false; + + return true; +} + +/** + * batadv_nc_hash_find() - search for an existing nc path and return it + * @hash: hash table containing the nc path + * @data: search key + * + * Return: the nc_path if found, NULL otherwise. + */ +static struct batadv_nc_path * +batadv_nc_hash_find(struct batadv_hashtable *hash, + void *data) +{ + struct hlist_head *head; + struct batadv_nc_path *nc_path, *nc_path_tmp = NULL; + int index; + + if (!hash) + return NULL; + + index = batadv_nc_hash_choose(data, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(nc_path, head, hash_entry) { + if (!batadv_nc_hash_compare(&nc_path->hash_entry, data)) + continue; + + if (!kref_get_unless_zero(&nc_path->refcount)) + continue; + + nc_path_tmp = nc_path; + break; + } + rcu_read_unlock(); + + return nc_path_tmp; +} + +/** + * batadv_nc_send_packet() - send non-coded packet and free nc_packet struct + * @nc_packet: the nc packet to send + */ +static void batadv_nc_send_packet(struct batadv_nc_packet *nc_packet) +{ + batadv_send_unicast_skb(nc_packet->skb, nc_packet->neigh_node); + nc_packet->skb = NULL; + batadv_nc_packet_free(nc_packet, false); +} + +/** + * batadv_nc_sniffed_purge() - Checks timestamp of given sniffed nc_packet. 
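/*
 * Editor's illustrative sketch (not part of the patch): an nc_path is
 * keyed by the ordered pair (prev_hop, next_hop), so traffic between the
 * same two hops always lands in the same bucket of the coding/decoding
 * hash. The kernel code feeds both MAC addresses through jhash(); the
 * sketch below uses FNV-1a purely as a stand-in to show how the pair is
 * reduced to a bucket index. nc_bucket() and ETH_ALEN_SKETCH are
 * hypothetical names for this example only.
 */
#include <stddef.h>
#include <stdint.h>

#define ETH_ALEN_SKETCH 6

static uint32_t fnv1a(uint32_t hash, const uint8_t *data, size_t len)
{
        size_t i;

        for (i = 0; i < len; i++) {
                hash ^= data[i];
                hash *= 16777619u;
        }
        return hash;
}

/* analogue of the path hash: mix prev_hop, then next_hop, then fold */
static uint32_t nc_bucket(const uint8_t prev_hop[ETH_ALEN_SKETCH],
                          const uint8_t next_hop[ETH_ALEN_SKETCH],
                          uint32_t table_size)
{
        uint32_t hash = 2166136261u;

        hash = fnv1a(hash, prev_hop, ETH_ALEN_SKETCH);
        hash = fnv1a(hash, next_hop, ETH_ALEN_SKETCH);

        return hash % table_size;
}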
+ * @bat_priv: the bat priv with all the soft interface information + * @nc_path: the nc path the packet belongs to + * @nc_packet: the nc packet to be checked + * + * Checks whether the given sniffed (overheard) nc_packet has hit its buffering + * timeout. If so, the packet is no longer kept and the entry deleted from the + * queue. Has to be called with the appropriate locks. + * + * Return: false as soon as the entry in the fifo queue has not been timed out + * yet and true otherwise. + */ +static bool batadv_nc_sniffed_purge(struct batadv_priv *bat_priv, + struct batadv_nc_path *nc_path, + struct batadv_nc_packet *nc_packet) +{ + unsigned long timeout = bat_priv->nc.max_buffer_time; + bool res = false; + + lockdep_assert_held(&nc_path->packet_list_lock); + + /* Packets are added to tail, so the remaining packets did not time + * out and we can stop processing the current queue + */ + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_ACTIVE && + !batadv_has_timed_out(nc_packet->timestamp, timeout)) + goto out; + + /* purge nc packet */ + list_del(&nc_packet->list); + batadv_nc_packet_free(nc_packet, true); + + res = true; + +out: + return res; +} + +/** + * batadv_nc_fwd_flush() - Checks the timestamp of the given nc packet. + * @bat_priv: the bat priv with all the soft interface information + * @nc_path: the nc path the packet belongs to + * @nc_packet: the nc packet to be checked + * + * Checks whether the given nc packet has hit its forward timeout. If so, the + * packet is no longer delayed, immediately sent and the entry deleted from the + * queue. Has to be called with the appropriate locks. + * + * Return: false as soon as the entry in the fifo queue has not been timed out + * yet and true otherwise. + */ +static bool batadv_nc_fwd_flush(struct batadv_priv *bat_priv, + struct batadv_nc_path *nc_path, + struct batadv_nc_packet *nc_packet) +{ + unsigned long timeout = bat_priv->nc.max_fwd_delay; + + lockdep_assert_held(&nc_path->packet_list_lock); + + /* Packets are added to tail, so the remaining packets did not time + * out and we can stop processing the current queue + */ + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_ACTIVE && + !batadv_has_timed_out(nc_packet->timestamp, timeout)) + return false; + + /* Send packet */ + batadv_inc_counter(bat_priv, BATADV_CNT_FORWARD); + batadv_add_counter(bat_priv, BATADV_CNT_FORWARD_BYTES, + nc_packet->skb->len + ETH_HLEN); + list_del(&nc_packet->list); + batadv_nc_send_packet(nc_packet); + + return true; +} + +/** + * batadv_nc_process_nc_paths() - traverse given nc packet pool and free timed + * out nc packets + * @bat_priv: the bat priv with all the soft interface information + * @hash: to be processed hash table + * @process_fn: Function called to process given nc packet. Should return true + * to encourage this function to proceed with the next packet. + * Otherwise the rest of the current queue is skipped. 
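/*
 * Editor's illustrative sketch (not part of the patch): because buffered
 * nc packets are appended to the tail of packet_list in arrival order,
 * the flush/purge callbacks above may stop at the first entry that has
 * not yet timed out - everything behind it is necessarily newer. The
 * hypothetical drain_queue() below shows that early-exit pattern on a
 * plain singly linked list; it is not batman-adv code.
 */
#include <stddef.h>
#include <stdint.h>

struct sketch_pkt {
        struct sketch_pkt *next;
        uint64_t timestamp_ms;  /* time the packet was queued */
};

/* returns the new head; "flush" handles each expired packet */
static struct sketch_pkt *drain_queue(struct sketch_pkt *head,
                                      uint64_t now_ms, uint64_t timeout_ms,
                                      void (*flush)(struct sketch_pkt *))
{
        while (head) {
                /* oldest entry still fresh -> all later entries are too */
                if (now_ms - head->timestamp_ms < timeout_ms)
                        break;

                struct sketch_pkt *expired = head;

                head = head->next;
                flush(expired);
        }

        return head;
}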
+ */ +static void +batadv_nc_process_nc_paths(struct batadv_priv *bat_priv, + struct batadv_hashtable *hash, + bool (*process_fn)(struct batadv_priv *, + struct batadv_nc_path *, + struct batadv_nc_packet *)) +{ + struct hlist_head *head; + struct batadv_nc_packet *nc_packet, *nc_packet_tmp; + struct batadv_nc_path *nc_path; + bool ret; + int i; + + if (!hash) + return; + + /* Loop hash table bins */ + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + /* Loop coding paths */ + rcu_read_lock(); + hlist_for_each_entry_rcu(nc_path, head, hash_entry) { + /* Loop packets */ + spin_lock_bh(&nc_path->packet_list_lock); + list_for_each_entry_safe(nc_packet, nc_packet_tmp, + &nc_path->packet_list, list) { + ret = process_fn(bat_priv, nc_path, nc_packet); + if (!ret) + break; + } + spin_unlock_bh(&nc_path->packet_list_lock); + } + rcu_read_unlock(); + } +} + +/** + * batadv_nc_worker() - periodic task for housekeeping related to network + * coding + * @work: kernel work struct + */ +static void batadv_nc_worker(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv_nc *priv_nc; + struct batadv_priv *bat_priv; + unsigned long timeout; + + delayed_work = to_delayed_work(work); + priv_nc = container_of(delayed_work, struct batadv_priv_nc, work); + bat_priv = container_of(priv_nc, struct batadv_priv, nc); + + batadv_nc_purge_orig_hash(bat_priv); + batadv_nc_purge_paths(bat_priv, bat_priv->nc.coding_hash, + batadv_nc_to_purge_nc_path_coding); + batadv_nc_purge_paths(bat_priv, bat_priv->nc.decoding_hash, + batadv_nc_to_purge_nc_path_decoding); + + timeout = bat_priv->nc.max_fwd_delay; + + if (batadv_has_timed_out(bat_priv->nc.timestamp_fwd_flush, timeout)) { + batadv_nc_process_nc_paths(bat_priv, bat_priv->nc.coding_hash, + batadv_nc_fwd_flush); + bat_priv->nc.timestamp_fwd_flush = jiffies; + } + + if (batadv_has_timed_out(bat_priv->nc.timestamp_sniffed_purge, + bat_priv->nc.max_buffer_time)) { + batadv_nc_process_nc_paths(bat_priv, bat_priv->nc.decoding_hash, + batadv_nc_sniffed_purge); + bat_priv->nc.timestamp_sniffed_purge = jiffies; + } + + /* Schedule a new check */ + batadv_nc_start_timer(bat_priv); +} + +/** + * batadv_can_nc_with_orig() - checks whether the given orig node is suitable + * for coding or not + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: neighboring orig node which may be used as nc candidate + * @ogm_packet: incoming ogm packet also used for the checks + * + * Return: true if: + * 1) The OGM must have the most recent sequence number. + * 2) The TTL must be decremented by one and only one. + * 3) The OGM must be received from the first hop from orig_node. + * 4) The TQ value of the OGM must be above bat_priv->nc.min_tq. 
+ */ +static bool batadv_can_nc_with_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_ogm_packet *ogm_packet) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + u32 last_real_seqno; + u8 last_ttl; + + orig_ifinfo = batadv_orig_ifinfo_get(orig_node, BATADV_IF_DEFAULT); + if (!orig_ifinfo) + return false; + + last_ttl = orig_ifinfo->last_ttl; + last_real_seqno = orig_ifinfo->last_real_seqno; + batadv_orig_ifinfo_put(orig_ifinfo); + + if (last_real_seqno != ntohl(ogm_packet->seqno)) + return false; + if (last_ttl != ogm_packet->ttl + 1) + return false; + if (!batadv_compare_eth(ogm_packet->orig, ogm_packet->prev_sender)) + return false; + if (ogm_packet->tq < bat_priv->nc.min_tq) + return false; + + return true; +} + +/** + * batadv_nc_find_nc_node() - search for an existing nc node and return it + * @orig_node: orig node originating the ogm packet + * @orig_neigh_node: neighboring orig node from which we received the ogm packet + * (can be equal to orig_node) + * @in_coding: traverse incoming or outgoing network coding list + * + * Return: the nc_node if found, NULL otherwise. + */ +static struct batadv_nc_node * +batadv_nc_find_nc_node(struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + bool in_coding) +{ + struct batadv_nc_node *nc_node, *nc_node_out = NULL; + struct list_head *list; + + if (in_coding) + list = &orig_neigh_node->in_coding_list; + else + list = &orig_neigh_node->out_coding_list; + + /* Traverse list of nc_nodes to orig_node */ + rcu_read_lock(); + list_for_each_entry_rcu(nc_node, list, list) { + if (!batadv_compare_eth(nc_node->addr, orig_node->orig)) + continue; + + if (!kref_get_unless_zero(&nc_node->refcount)) + continue; + + /* Found a match */ + nc_node_out = nc_node; + break; + } + rcu_read_unlock(); + + return nc_node_out; +} + +/** + * batadv_nc_get_nc_node() - retrieves an nc node or creates the entry if it was + * not found + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node originating the ogm packet + * @orig_neigh_node: neighboring orig node from which we received the ogm packet + * (can be equal to orig_node) + * @in_coding: traverse incoming or outgoing network coding list + * + * Return: the nc_node if found or created, NULL in case of an error. 
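/*
 * Editor's illustrative sketch (not part of the patch): the nc_node
 * lookup above walks an RCU-protected list and only hands out an object
 * after kref_get_unless_zero() succeeds, i.e. after confirming it is not
 * already being torn down. The userspace analogue below shows that
 * "acquire only if still live" step with C11 atomics; RCU itself is
 * elided and struct sketch_obj / try_get_ref() are hypothetical names.
 */
#include <stdatomic.h>
#include <stdbool.h>

struct sketch_obj {
        atomic_uint refcount;   /* 0 means the object is being freed */
};

/* analogue of kref_get_unless_zero(): never revive a dying object */
static bool try_get_ref(struct sketch_obj *obj)
{
        unsigned int old = atomic_load(&obj->refcount);

        while (old != 0) {
                if (atomic_compare_exchange_weak(&obj->refcount, &old,
                                                 old + 1))
                        return true;
                /* old was reloaded by the failed CAS; loop and retry */
        }

        return false;
}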
+ */ +static struct batadv_nc_node * +batadv_nc_get_nc_node(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + bool in_coding) +{ + struct batadv_nc_node *nc_node; + spinlock_t *lock; /* Used to lock list selected by "int in_coding" */ + struct list_head *list; + + /* Select ingoing or outgoing coding node */ + if (in_coding) { + lock = &orig_neigh_node->in_coding_list_lock; + list = &orig_neigh_node->in_coding_list; + } else { + lock = &orig_neigh_node->out_coding_list_lock; + list = &orig_neigh_node->out_coding_list; + } + + spin_lock_bh(lock); + + /* Check if nc_node is already added */ + nc_node = batadv_nc_find_nc_node(orig_node, orig_neigh_node, in_coding); + + /* Node found */ + if (nc_node) + goto unlock; + + nc_node = kzalloc(sizeof(*nc_node), GFP_ATOMIC); + if (!nc_node) + goto unlock; + + /* Initialize nc_node */ + INIT_LIST_HEAD(&nc_node->list); + kref_init(&nc_node->refcount); + ether_addr_copy(nc_node->addr, orig_node->orig); + kref_get(&orig_neigh_node->refcount); + nc_node->orig_node = orig_neigh_node; + + batadv_dbg(BATADV_DBG_NC, bat_priv, "Adding nc_node %pM -> %pM\n", + nc_node->addr, nc_node->orig_node->orig); + + /* Add nc_node to orig_node */ + kref_get(&nc_node->refcount); + list_add_tail_rcu(&nc_node->list, list); + +unlock: + spin_unlock_bh(lock); + + return nc_node; +} + +/** + * batadv_nc_update_nc_node() - updates stored incoming and outgoing nc node + * structs (best called on incoming OGMs) + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node originating the ogm packet + * @orig_neigh_node: neighboring orig node from which we received the ogm packet + * (can be equal to orig_node) + * @ogm_packet: incoming ogm packet + * @is_single_hop_neigh: orig_node is a single hop neighbor + */ +void batadv_nc_update_nc_node(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + struct batadv_ogm_packet *ogm_packet, + int is_single_hop_neigh) +{ + struct batadv_nc_node *in_nc_node = NULL; + struct batadv_nc_node *out_nc_node = NULL; + + /* Check if network coding is enabled */ + if (!atomic_read(&bat_priv->network_coding)) + goto out; + + /* check if orig node is network coding enabled */ + if (!test_bit(BATADV_ORIG_CAPA_HAS_NC, &orig_node->capabilities)) + goto out; + + /* accept ogms from 'good' neighbors and single hop neighbors */ + if (!batadv_can_nc_with_orig(bat_priv, orig_node, ogm_packet) && + !is_single_hop_neigh) + goto out; + + /* Add orig_node as in_nc_node on hop */ + in_nc_node = batadv_nc_get_nc_node(bat_priv, orig_node, + orig_neigh_node, true); + if (!in_nc_node) + goto out; + + in_nc_node->last_seen = jiffies; + + /* Add hop as out_nc_node on orig_node */ + out_nc_node = batadv_nc_get_nc_node(bat_priv, orig_neigh_node, + orig_node, false); + if (!out_nc_node) + goto out; + + out_nc_node->last_seen = jiffies; + +out: + batadv_nc_node_put(in_nc_node); + batadv_nc_node_put(out_nc_node); +} + +/** + * batadv_nc_get_path() - get existing nc_path or allocate a new one + * @bat_priv: the bat priv with all the soft interface information + * @hash: hash table containing the nc path + * @src: ethernet source address - first half of the nc path search key + * @dst: ethernet destination address - second half of the nc path search key + * + * Return: pointer to nc_path if the path was found or created, returns NULL + * on error. 
+ */ +static struct batadv_nc_path *batadv_nc_get_path(struct batadv_priv *bat_priv, + struct batadv_hashtable *hash, + u8 *src, + u8 *dst) +{ + int hash_added; + struct batadv_nc_path *nc_path, nc_path_key; + + batadv_nc_hash_key_gen(&nc_path_key, src, dst); + + /* Search for existing nc_path */ + nc_path = batadv_nc_hash_find(hash, (void *)&nc_path_key); + + if (nc_path) { + /* Set timestamp to delay removal of nc_path */ + nc_path->last_valid = jiffies; + return nc_path; + } + + /* No existing nc_path was found; create a new */ + nc_path = kzalloc(sizeof(*nc_path), GFP_ATOMIC); + + if (!nc_path) + return NULL; + + /* Initialize nc_path */ + INIT_LIST_HEAD(&nc_path->packet_list); + spin_lock_init(&nc_path->packet_list_lock); + kref_init(&nc_path->refcount); + nc_path->last_valid = jiffies; + ether_addr_copy(nc_path->next_hop, dst); + ether_addr_copy(nc_path->prev_hop, src); + + batadv_dbg(BATADV_DBG_NC, bat_priv, "Adding nc_path %pM -> %pM\n", + nc_path->prev_hop, + nc_path->next_hop); + + /* Add nc_path to hash table */ + kref_get(&nc_path->refcount); + hash_added = batadv_hash_add(hash, batadv_nc_hash_compare, + batadv_nc_hash_choose, &nc_path_key, + &nc_path->hash_entry); + + if (hash_added < 0) { + kfree(nc_path); + return NULL; + } + + return nc_path; +} + +/** + * batadv_nc_random_weight_tq() - scale the receivers TQ-value to avoid unfair + * selection of a receiver with slightly lower TQ than the other + * @tq: to be weighted tq value + * + * Return: scaled tq value + */ +static u8 batadv_nc_random_weight_tq(u8 tq) +{ + /* randomize the estimated packet loss (max TQ - estimated TQ) */ + u8 rand_tq = prandom_u32_max(BATADV_TQ_MAX_VALUE + 1 - tq); + + /* convert to (randomized) estimated tq again */ + return BATADV_TQ_MAX_VALUE - rand_tq; +} + +/** + * batadv_nc_memxor() - XOR destination with source + * @dst: byte array to XOR into + * @src: byte array to XOR from + * @len: length of destination array + */ +static void batadv_nc_memxor(char *dst, const char *src, unsigned int len) +{ + unsigned int i; + + for (i = 0; i < len; ++i) + dst[i] ^= src[i]; +} + +/** + * batadv_nc_code_packets() - code a received unicast_packet with an nc packet + * into a coded_packet and send it + * @bat_priv: the bat priv with all the soft interface information + * @skb: data skb to forward + * @ethhdr: pointer to the ethernet header inside the skb + * @nc_packet: structure containing the packet to the skb can be coded with + * @neigh_node: next hop to forward packet to + * + * Return: true if both packets are consumed, false otherwise. 
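/*
 * Editor's illustrative sketch (not part of the patch): the weighting
 * above randomises only the estimated loss component (TQ_MAX - tq), so
 * the weighted value is drawn uniformly from [tq, TQ_MAX]. Two neighbors
 * with similar link quality therefore both get a fair chance of being
 * picked as the MAC-header destination. The userspace sketch below
 * reproduces that mapping with rand(); weighted_tq() is a hypothetical
 * name and 255 stands in for the maximum TQ value.
 */
#include <stdint.h>
#include <stdlib.h>

static uint8_t weighted_tq(uint8_t tq)
{
        const uint8_t tq_max = 255;

        /* pick a random estimated loss in [0, tq_max - tq] */
        uint8_t rand_loss = (uint8_t)(rand() % (tq_max + 1 - tq));

        /* convert back to a (randomised) tq in [tq, tq_max] */
        return tq_max - rand_loss;
}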
+ */ +static bool batadv_nc_code_packets(struct batadv_priv *bat_priv, + struct sk_buff *skb, + struct ethhdr *ethhdr, + struct batadv_nc_packet *nc_packet, + struct batadv_neigh_node *neigh_node) +{ + u8 tq_weighted_neigh, tq_weighted_coding, tq_tmp; + struct sk_buff *skb_dest, *skb_src; + struct batadv_unicast_packet *packet1; + struct batadv_unicast_packet *packet2; + struct batadv_coded_packet *coded_packet; + struct batadv_neigh_node *neigh_tmp, *router_neigh, *first_dest; + struct batadv_neigh_node *router_coding = NULL, *second_dest; + struct batadv_neigh_ifinfo *router_neigh_ifinfo = NULL; + struct batadv_neigh_ifinfo *router_coding_ifinfo = NULL; + u8 *first_source, *second_source; + __be32 packet_id1, packet_id2; + size_t count; + bool res = false; + int coding_len; + int unicast_size = sizeof(*packet1); + int coded_size = sizeof(*coded_packet); + int header_add = coded_size - unicast_size; + + /* TODO: do we need to consider the outgoing interface for + * coded packets? + */ + router_neigh = batadv_orig_router_get(neigh_node->orig_node, + BATADV_IF_DEFAULT); + if (!router_neigh) + goto out; + + router_neigh_ifinfo = batadv_neigh_ifinfo_get(router_neigh, + BATADV_IF_DEFAULT); + if (!router_neigh_ifinfo) + goto out; + + neigh_tmp = nc_packet->neigh_node; + router_coding = batadv_orig_router_get(neigh_tmp->orig_node, + BATADV_IF_DEFAULT); + if (!router_coding) + goto out; + + router_coding_ifinfo = batadv_neigh_ifinfo_get(router_coding, + BATADV_IF_DEFAULT); + if (!router_coding_ifinfo) + goto out; + + tq_tmp = router_neigh_ifinfo->bat_iv.tq_avg; + tq_weighted_neigh = batadv_nc_random_weight_tq(tq_tmp); + tq_tmp = router_coding_ifinfo->bat_iv.tq_avg; + tq_weighted_coding = batadv_nc_random_weight_tq(tq_tmp); + + /* Select one destination for the MAC-header dst-field based on + * weighted TQ-values. + */ + if (tq_weighted_neigh >= tq_weighted_coding) { + /* Destination from nc_packet is selected for MAC-header */ + first_dest = nc_packet->neigh_node; + first_source = nc_packet->nc_path->prev_hop; + second_dest = neigh_node; + second_source = ethhdr->h_source; + packet1 = (struct batadv_unicast_packet *)nc_packet->skb->data; + packet2 = (struct batadv_unicast_packet *)skb->data; + packet_id1 = nc_packet->packet_id; + packet_id2 = batadv_skb_crc32(skb, + skb->data + sizeof(*packet2)); + } else { + /* Destination for skb is selected for MAC-header */ + first_dest = neigh_node; + first_source = ethhdr->h_source; + second_dest = nc_packet->neigh_node; + second_source = nc_packet->nc_path->prev_hop; + packet1 = (struct batadv_unicast_packet *)skb->data; + packet2 = (struct batadv_unicast_packet *)nc_packet->skb->data; + packet_id1 = batadv_skb_crc32(skb, + skb->data + sizeof(*packet1)); + packet_id2 = nc_packet->packet_id; + } + + /* Instead of zero padding the smallest data buffer, we + * code into the largest. 
+ */ + if (skb->len <= nc_packet->skb->len) { + skb_dest = nc_packet->skb; + skb_src = skb; + } else { + skb_dest = skb; + skb_src = nc_packet->skb; + } + + /* coding_len is used when decoding the packet shorter packet */ + coding_len = skb_src->len - unicast_size; + + if (skb_linearize(skb_dest) < 0 || skb_linearize(skb_src) < 0) + goto out; + + skb_push(skb_dest, header_add); + + coded_packet = (struct batadv_coded_packet *)skb_dest->data; + skb_reset_mac_header(skb_dest); + + coded_packet->packet_type = BATADV_CODED; + coded_packet->version = BATADV_COMPAT_VERSION; + coded_packet->ttl = packet1->ttl; + + /* Info about first unicast packet */ + ether_addr_copy(coded_packet->first_source, first_source); + ether_addr_copy(coded_packet->first_orig_dest, packet1->dest); + coded_packet->first_crc = packet_id1; + coded_packet->first_ttvn = packet1->ttvn; + + /* Info about second unicast packet */ + ether_addr_copy(coded_packet->second_dest, second_dest->addr); + ether_addr_copy(coded_packet->second_source, second_source); + ether_addr_copy(coded_packet->second_orig_dest, packet2->dest); + coded_packet->second_crc = packet_id2; + coded_packet->second_ttl = packet2->ttl; + coded_packet->second_ttvn = packet2->ttvn; + coded_packet->coded_len = htons(coding_len); + + /* This is where the magic happens: Code skb_src into skb_dest */ + batadv_nc_memxor(skb_dest->data + coded_size, + skb_src->data + unicast_size, coding_len); + + /* Update counters accordingly */ + if (BATADV_SKB_CB(skb_src)->decoded && + BATADV_SKB_CB(skb_dest)->decoded) { + /* Both packets are recoded */ + count = skb_src->len + ETH_HLEN; + count += skb_dest->len + ETH_HLEN; + batadv_add_counter(bat_priv, BATADV_CNT_NC_RECODE, 2); + batadv_add_counter(bat_priv, BATADV_CNT_NC_RECODE_BYTES, count); + } else if (!BATADV_SKB_CB(skb_src)->decoded && + !BATADV_SKB_CB(skb_dest)->decoded) { + /* Both packets are newly coded */ + count = skb_src->len + ETH_HLEN; + count += skb_dest->len + ETH_HLEN; + batadv_add_counter(bat_priv, BATADV_CNT_NC_CODE, 2); + batadv_add_counter(bat_priv, BATADV_CNT_NC_CODE_BYTES, count); + } else if (BATADV_SKB_CB(skb_src)->decoded && + !BATADV_SKB_CB(skb_dest)->decoded) { + /* skb_src recoded and skb_dest is newly coded */ + batadv_inc_counter(bat_priv, BATADV_CNT_NC_RECODE); + batadv_add_counter(bat_priv, BATADV_CNT_NC_RECODE_BYTES, + skb_src->len + ETH_HLEN); + batadv_inc_counter(bat_priv, BATADV_CNT_NC_CODE); + batadv_add_counter(bat_priv, BATADV_CNT_NC_CODE_BYTES, + skb_dest->len + ETH_HLEN); + } else if (!BATADV_SKB_CB(skb_src)->decoded && + BATADV_SKB_CB(skb_dest)->decoded) { + /* skb_src is newly coded and skb_dest is recoded */ + batadv_inc_counter(bat_priv, BATADV_CNT_NC_CODE); + batadv_add_counter(bat_priv, BATADV_CNT_NC_CODE_BYTES, + skb_src->len + ETH_HLEN); + batadv_inc_counter(bat_priv, BATADV_CNT_NC_RECODE); + batadv_add_counter(bat_priv, BATADV_CNT_NC_RECODE_BYTES, + skb_dest->len + ETH_HLEN); + } + + /* skb_src is now coded into skb_dest, so free it */ + consume_skb(skb_src); + + /* avoid duplicate free of skb from nc_packet */ + nc_packet->skb = NULL; + batadv_nc_packet_free(nc_packet, false); + + /* Send the coded packet and return true */ + batadv_send_unicast_skb(skb_dest, first_dest); + res = true; +out: + batadv_neigh_node_put(router_neigh); + batadv_neigh_node_put(router_coding); + batadv_neigh_ifinfo_put(router_neigh_ifinfo); + batadv_neigh_ifinfo_put(router_coding_ifinfo); + return res; +} + +/** + * batadv_nc_skb_coding_possible() - true if a decoded skb is available at dst. 
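/*
 * Editor's illustrative sketch (not part of the patch): the coding step
 * above is a plain byte-wise XOR of the shorter payload into the longer
 * one, and decoding is the same XOR applied again with the buffered
 * copy, since (a ^ b) ^ b == a. Bytes of the longer packet beyond the
 * shorter one are left untouched, which is equivalent to zero-padding
 * the shorter payload. The demo below is standalone userspace code;
 * memxor() mirrors the byte-wise XOR helper but is a local function here
 * and the example payloads are arbitrary.
 */
#include <assert.h>
#include <stddef.h>
#include <string.h>

static void memxor(char *dst, const char *src, size_t len)
{
        size_t i;

        for (i = 0; i < len; i++)
                dst[i] ^= src[i];
}

int main(void)
{
        char long_pkt[8] = "ABCDEFG";   /* stands in for the longer payload */
        char short_pkt[4] = "xyz";      /* stands in for the shorter payload */
        char coded[8];

        /* encode: XOR the short payload into a copy of the long one */
        memcpy(coded, long_pkt, sizeof(coded));
        memxor(coded, short_pkt, sizeof(short_pkt));

        /* decode at a node that buffered short_pkt: recover long_pkt */
        memxor(coded, short_pkt, sizeof(short_pkt));
        assert(memcmp(coded, long_pkt, sizeof(long_pkt)) == 0);

        return 0;
}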
+ * @skb: data skb to forward + * @dst: destination mac address of the other skb to code with + * @src: source mac address of skb + * + * Whenever we network code a packet we have to check whether we received it in + * a network coded form. If so, we may not be able to use it for coding because + * some neighbors may also have received (overheard) the packet in the network + * coded form without being able to decode it. It is hard to know which of the + * neighboring nodes was able to decode the packet, therefore we can only + * re-code the packet if the source of the previous encoded packet is involved. + * Since the source encoded the packet we can be certain it has all necessary + * decode information. + * + * Return: true if coding of a decoded packet is allowed. + */ +static bool batadv_nc_skb_coding_possible(struct sk_buff *skb, u8 *dst, u8 *src) +{ + if (BATADV_SKB_CB(skb)->decoded && !batadv_compare_eth(dst, src)) + return false; + return true; +} + +/** + * batadv_nc_path_search() - Find the coding path matching in_nc_node and + * out_nc_node to retrieve a buffered packet that can be used for coding. + * @bat_priv: the bat priv with all the soft interface information + * @in_nc_node: pointer to skb next hop's neighbor nc node + * @out_nc_node: pointer to skb source's neighbor nc node + * @skb: data skb to forward + * @eth_dst: next hop mac address of skb + * + * Return: true if coding of a decoded skb is allowed. + */ +static struct batadv_nc_packet * +batadv_nc_path_search(struct batadv_priv *bat_priv, + struct batadv_nc_node *in_nc_node, + struct batadv_nc_node *out_nc_node, + struct sk_buff *skb, + u8 *eth_dst) +{ + struct batadv_nc_path *nc_path, nc_path_key; + struct batadv_nc_packet *nc_packet_out = NULL; + struct batadv_nc_packet *nc_packet, *nc_packet_tmp; + struct batadv_hashtable *hash = bat_priv->nc.coding_hash; + int idx; + + if (!hash) + return NULL; + + /* Create almost path key */ + batadv_nc_hash_key_gen(&nc_path_key, in_nc_node->addr, + out_nc_node->addr); + idx = batadv_nc_hash_choose(&nc_path_key, hash->size); + + /* Check for coding opportunities in this nc_path */ + rcu_read_lock(); + hlist_for_each_entry_rcu(nc_path, &hash->table[idx], hash_entry) { + if (!batadv_compare_eth(nc_path->prev_hop, in_nc_node->addr)) + continue; + + if (!batadv_compare_eth(nc_path->next_hop, out_nc_node->addr)) + continue; + + spin_lock_bh(&nc_path->packet_list_lock); + if (list_empty(&nc_path->packet_list)) { + spin_unlock_bh(&nc_path->packet_list_lock); + continue; + } + + list_for_each_entry_safe(nc_packet, nc_packet_tmp, + &nc_path->packet_list, list) { + if (!batadv_nc_skb_coding_possible(nc_packet->skb, + eth_dst, + in_nc_node->addr)) + continue; + + /* Coding opportunity is found! */ + list_del(&nc_packet->list); + nc_packet_out = nc_packet; + break; + } + + spin_unlock_bh(&nc_path->packet_list_lock); + break; + } + rcu_read_unlock(); + + return nc_packet_out; +} + +/** + * batadv_nc_skb_src_search() - Loops through the list of neighboring nodes of + * the skb's sender (may be equal to the originator). + * @bat_priv: the bat priv with all the soft interface information + * @skb: data skb to forward + * @eth_dst: next hop mac address of skb + * @eth_src: source mac address of skb + * @in_nc_node: pointer to skb next hop's neighbor nc node + * + * Return: an nc packet if a suitable coding packet was found, NULL otherwise. 
+ */ +static struct batadv_nc_packet * +batadv_nc_skb_src_search(struct batadv_priv *bat_priv, + struct sk_buff *skb, + u8 *eth_dst, + u8 *eth_src, + struct batadv_nc_node *in_nc_node) +{ + struct batadv_orig_node *orig_node; + struct batadv_nc_node *out_nc_node; + struct batadv_nc_packet *nc_packet = NULL; + + orig_node = batadv_orig_hash_find(bat_priv, eth_src); + if (!orig_node) + return NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(out_nc_node, + &orig_node->out_coding_list, list) { + /* Check if the skb is decoded and if recoding is possible */ + if (!batadv_nc_skb_coding_possible(skb, + out_nc_node->addr, eth_src)) + continue; + + /* Search for an opportunity in this nc_path */ + nc_packet = batadv_nc_path_search(bat_priv, in_nc_node, + out_nc_node, skb, eth_dst); + if (nc_packet) + break; + } + rcu_read_unlock(); + + batadv_orig_node_put(orig_node); + return nc_packet; +} + +/** + * batadv_nc_skb_store_before_coding() - set the ethernet src and dst of the + * unicast skb before it is stored for use in later decoding + * @bat_priv: the bat priv with all the soft interface information + * @skb: data skb to store + * @eth_dst_new: new destination mac address of skb + */ +static void batadv_nc_skb_store_before_coding(struct batadv_priv *bat_priv, + struct sk_buff *skb, + u8 *eth_dst_new) +{ + struct ethhdr *ethhdr; + + /* Copy skb header to change the mac header */ + skb = pskb_copy_for_clone(skb, GFP_ATOMIC); + if (!skb) + return; + + /* Set the mac header as if we actually sent the packet uncoded */ + ethhdr = eth_hdr(skb); + ether_addr_copy(ethhdr->h_source, ethhdr->h_dest); + ether_addr_copy(ethhdr->h_dest, eth_dst_new); + + /* Set data pointer to MAC header to mimic packets from our tx path */ + skb_push(skb, ETH_HLEN); + + /* Add the packet to the decoding packet pool */ + batadv_nc_skb_store_for_decoding(bat_priv, skb); + + /* batadv_nc_skb_store_for_decoding() clones the skb, so we must free + * our ref + */ + consume_skb(skb); +} + +/** + * batadv_nc_skb_dst_search() - Loops through list of neighboring nodes to dst. + * @skb: data skb to forward + * @neigh_node: next hop to forward packet to + * @ethhdr: pointer to the ethernet header inside the skb + * + * Loops through the list of neighboring nodes the next hop has a good + * connection to (receives OGMs with a sufficient quality). We need to find a + * neighbor of our next hop that potentially sent a packet which our next hop + * also received (overheard) and has stored for later decoding. 
+ * + * Return: true if the skb was consumed (encoded packet sent) or false otherwise + */ +static bool batadv_nc_skb_dst_search(struct sk_buff *skb, + struct batadv_neigh_node *neigh_node, + struct ethhdr *ethhdr) +{ + struct net_device *netdev = neigh_node->if_incoming->soft_iface; + struct batadv_priv *bat_priv = netdev_priv(netdev); + struct batadv_orig_node *orig_node = neigh_node->orig_node; + struct batadv_nc_node *nc_node; + struct batadv_nc_packet *nc_packet = NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(nc_node, &orig_node->in_coding_list, list) { + /* Search for coding opportunity with this in_nc_node */ + nc_packet = batadv_nc_skb_src_search(bat_priv, skb, + neigh_node->addr, + ethhdr->h_source, nc_node); + + /* Opportunity was found, so stop searching */ + if (nc_packet) + break; + } + rcu_read_unlock(); + + if (!nc_packet) + return false; + + /* Save packets for later decoding */ + batadv_nc_skb_store_before_coding(bat_priv, skb, + neigh_node->addr); + batadv_nc_skb_store_before_coding(bat_priv, nc_packet->skb, + nc_packet->neigh_node->addr); + + /* Code and send packets */ + if (batadv_nc_code_packets(bat_priv, skb, ethhdr, nc_packet, + neigh_node)) + return true; + + /* out of mem ? Coding failed - we have to free the buffered packet + * to avoid memleaks. The skb passed as argument will be dealt with + * by the calling function. + */ + batadv_nc_send_packet(nc_packet); + return false; +} + +/** + * batadv_nc_skb_add_to_path() - buffer skb for later encoding / decoding + * @skb: skb to add to path + * @nc_path: path to add skb to + * @neigh_node: next hop to forward packet to + * @packet_id: checksum to identify packet + * + * Return: true if the packet was buffered or false in case of an error. + */ +static bool batadv_nc_skb_add_to_path(struct sk_buff *skb, + struct batadv_nc_path *nc_path, + struct batadv_neigh_node *neigh_node, + __be32 packet_id) +{ + struct batadv_nc_packet *nc_packet; + + nc_packet = kzalloc(sizeof(*nc_packet), GFP_ATOMIC); + if (!nc_packet) + return false; + + /* Initialize nc_packet */ + nc_packet->timestamp = jiffies; + nc_packet->packet_id = packet_id; + nc_packet->skb = skb; + nc_packet->neigh_node = neigh_node; + nc_packet->nc_path = nc_path; + + /* Add coding packet to list */ + spin_lock_bh(&nc_path->packet_list_lock); + list_add_tail(&nc_packet->list, &nc_path->packet_list); + spin_unlock_bh(&nc_path->packet_list_lock); + + return true; +} + +/** + * batadv_nc_skb_forward() - try to code a packet or add it to the coding packet + * buffer + * @skb: data skb to forward + * @neigh_node: next hop to forward packet to + * + * Return: true if the skb was consumed (encoded packet sent) or false otherwise + */ +bool batadv_nc_skb_forward(struct sk_buff *skb, + struct batadv_neigh_node *neigh_node) +{ + const struct net_device *netdev = neigh_node->if_incoming->soft_iface; + struct batadv_priv *bat_priv = netdev_priv(netdev); + struct batadv_unicast_packet *packet; + struct batadv_nc_path *nc_path; + struct ethhdr *ethhdr = eth_hdr(skb); + __be32 packet_id; + u8 *payload; + + /* Check if network coding is enabled */ + if (!atomic_read(&bat_priv->network_coding)) + goto out; + + /* We only handle unicast packets */ + payload = skb_network_header(skb); + packet = (struct batadv_unicast_packet *)payload; + if (packet->packet_type != BATADV_UNICAST) + goto out; + + /* Try to find a coding opportunity and send the skb if one is found */ + if (batadv_nc_skb_dst_search(skb, neigh_node, ethhdr)) + return true; + + /* Find or create a nc_path for this 
src-dst pair */ + nc_path = batadv_nc_get_path(bat_priv, + bat_priv->nc.coding_hash, + ethhdr->h_source, + neigh_node->addr); + + if (!nc_path) + goto out; + + /* Add skb to nc_path */ + packet_id = batadv_skb_crc32(skb, payload + sizeof(*packet)); + if (!batadv_nc_skb_add_to_path(skb, nc_path, neigh_node, packet_id)) + goto free_nc_path; + + /* Packet is consumed */ + return true; + +free_nc_path: + batadv_nc_path_put(nc_path); +out: + /* Packet is not consumed */ + return false; +} + +/** + * batadv_nc_skb_store_for_decoding() - save a clone of the skb which can be + * used when decoding coded packets + * @bat_priv: the bat priv with all the soft interface information + * @skb: data skb to store + */ +void batadv_nc_skb_store_for_decoding(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_unicast_packet *packet; + struct batadv_nc_path *nc_path; + struct ethhdr *ethhdr = eth_hdr(skb); + __be32 packet_id; + u8 *payload; + + /* Check if network coding is enabled */ + if (!atomic_read(&bat_priv->network_coding)) + goto out; + + /* Check for supported packet type */ + payload = skb_network_header(skb); + packet = (struct batadv_unicast_packet *)payload; + if (packet->packet_type != BATADV_UNICAST) + goto out; + + /* Find existing nc_path or create a new */ + nc_path = batadv_nc_get_path(bat_priv, + bat_priv->nc.decoding_hash, + ethhdr->h_source, + ethhdr->h_dest); + + if (!nc_path) + goto out; + + /* Clone skb and adjust skb->data to point at batman header */ + skb = skb_clone(skb, GFP_ATOMIC); + if (unlikely(!skb)) + goto free_nc_path; + + if (unlikely(!pskb_may_pull(skb, ETH_HLEN))) + goto free_skb; + + if (unlikely(!skb_pull_rcsum(skb, ETH_HLEN))) + goto free_skb; + + /* Add skb to nc_path */ + packet_id = batadv_skb_crc32(skb, payload + sizeof(*packet)); + if (!batadv_nc_skb_add_to_path(skb, nc_path, NULL, packet_id)) + goto free_skb; + + batadv_inc_counter(bat_priv, BATADV_CNT_NC_BUFFER); + return; + +free_skb: + kfree_skb(skb); +free_nc_path: + batadv_nc_path_put(nc_path); +out: + return; +} + +/** + * batadv_nc_skb_store_sniffed_unicast() - check if a received unicast packet + * should be saved in the decoding buffer and, if so, store it there + * @bat_priv: the bat priv with all the soft interface information + * @skb: unicast skb to store + */ +void batadv_nc_skb_store_sniffed_unicast(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct ethhdr *ethhdr = eth_hdr(skb); + + if (batadv_is_my_mac(bat_priv, ethhdr->h_dest)) + return; + + /* Set data pointer to MAC header to mimic packets from our tx path */ + skb_push(skb, ETH_HLEN); + + batadv_nc_skb_store_for_decoding(bat_priv, skb); +} + +/** + * batadv_nc_skb_decode_packet() - decode given skb using the decode data stored + * in nc_packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: unicast skb to decode + * @nc_packet: decode data needed to decode the skb + * + * Return: pointer to decoded unicast packet if the packet was decoded or NULL + * in case of an error. 
+ */ +static struct batadv_unicast_packet * +batadv_nc_skb_decode_packet(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_nc_packet *nc_packet) +{ + const int h_size = sizeof(struct batadv_unicast_packet); + const int h_diff = sizeof(struct batadv_coded_packet) - h_size; + struct batadv_unicast_packet *unicast_packet; + struct batadv_coded_packet coded_packet_tmp; + struct ethhdr *ethhdr, ethhdr_tmp; + u8 *orig_dest, ttl, ttvn; + unsigned int coding_len; + int err; + + /* Save headers temporarily */ + memcpy(&coded_packet_tmp, skb->data, sizeof(coded_packet_tmp)); + memcpy(ðhdr_tmp, skb_mac_header(skb), sizeof(ethhdr_tmp)); + + if (skb_cow(skb, 0) < 0) + return NULL; + + if (unlikely(!skb_pull_rcsum(skb, h_diff))) + return NULL; + + /* Data points to batman header, so set mac header 14 bytes before + * and network to data + */ + skb_set_mac_header(skb, -ETH_HLEN); + skb_reset_network_header(skb); + + /* Reconstruct original mac header */ + ethhdr = eth_hdr(skb); + *ethhdr = ethhdr_tmp; + + /* Select the correct unicast header information based on the location + * of our mac address in the coded_packet header + */ + if (batadv_is_my_mac(bat_priv, coded_packet_tmp.second_dest)) { + /* If we are the second destination the packet was overheard, + * so the Ethernet address must be copied to h_dest and + * pkt_type changed from PACKET_OTHERHOST to PACKET_HOST + */ + ether_addr_copy(ethhdr->h_dest, coded_packet_tmp.second_dest); + skb->pkt_type = PACKET_HOST; + + orig_dest = coded_packet_tmp.second_orig_dest; + ttl = coded_packet_tmp.second_ttl; + ttvn = coded_packet_tmp.second_ttvn; + } else { + orig_dest = coded_packet_tmp.first_orig_dest; + ttl = coded_packet_tmp.ttl; + ttvn = coded_packet_tmp.first_ttvn; + } + + coding_len = ntohs(coded_packet_tmp.coded_len); + + if (coding_len > skb->len) + return NULL; + + /* Here the magic is reversed: + * extract the missing packet from the received coded packet + */ + batadv_nc_memxor(skb->data + h_size, + nc_packet->skb->data + h_size, + coding_len); + + /* Resize decoded skb if decoded with larger packet */ + if (nc_packet->skb->len > coding_len + h_size) { + err = pskb_trim_rcsum(skb, coding_len + h_size); + if (err) + return NULL; + } + + /* Create decoded unicast packet */ + unicast_packet = (struct batadv_unicast_packet *)skb->data; + unicast_packet->packet_type = BATADV_UNICAST; + unicast_packet->version = BATADV_COMPAT_VERSION; + unicast_packet->ttl = ttl; + ether_addr_copy(unicast_packet->dest, orig_dest); + unicast_packet->ttvn = ttvn; + + batadv_nc_packet_free(nc_packet, false); + return unicast_packet; +} + +/** + * batadv_nc_find_decoding_packet() - search through buffered decoding data to + * find the data needed to decode the coded packet + * @bat_priv: the bat priv with all the soft interface information + * @ethhdr: pointer to the ethernet header inside the coded packet + * @coded: coded packet we try to find decode data for + * + * Return: pointer to nc packet if the needed data was found or NULL otherwise. 
+ */ +static struct batadv_nc_packet * +batadv_nc_find_decoding_packet(struct batadv_priv *bat_priv, + struct ethhdr *ethhdr, + struct batadv_coded_packet *coded) +{ + struct batadv_hashtable *hash = bat_priv->nc.decoding_hash; + struct batadv_nc_packet *tmp_nc_packet, *nc_packet = NULL; + struct batadv_nc_path *nc_path, nc_path_key; + u8 *dest, *source; + __be32 packet_id; + int index; + + if (!hash) + return NULL; + + /* Select the correct packet id based on the location of our mac-addr */ + dest = ethhdr->h_source; + if (!batadv_is_my_mac(bat_priv, coded->second_dest)) { + source = coded->second_source; + packet_id = coded->second_crc; + } else { + source = coded->first_source; + packet_id = coded->first_crc; + } + + batadv_nc_hash_key_gen(&nc_path_key, source, dest); + index = batadv_nc_hash_choose(&nc_path_key, hash->size); + + /* Search for matching coding path */ + rcu_read_lock(); + hlist_for_each_entry_rcu(nc_path, &hash->table[index], hash_entry) { + /* Find matching nc_packet */ + spin_lock_bh(&nc_path->packet_list_lock); + list_for_each_entry(tmp_nc_packet, + &nc_path->packet_list, list) { + if (packet_id == tmp_nc_packet->packet_id) { + list_del(&tmp_nc_packet->list); + + nc_packet = tmp_nc_packet; + break; + } + } + spin_unlock_bh(&nc_path->packet_list_lock); + + if (nc_packet) + break; + } + rcu_read_unlock(); + + if (!nc_packet) + batadv_dbg(BATADV_DBG_NC, bat_priv, + "No decoding packet found for %u\n", packet_id); + + return nc_packet; +} + +/** + * batadv_nc_recv_coded_packet() - try to decode coded packet and enqueue the + * resulting unicast packet + * @skb: incoming coded packet + * @recv_if: pointer to interface this packet was received on + * + * Return: NET_RX_SUCCESS if the packet has been consumed or NET_RX_DROP + * otherwise. 
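/*
 * Editor's illustrative sketch (not part of the patch): the receiver of
 * a coded packet recovers the missing half by looking up the clone it
 * buffered earlier. The lookup key is the payload checksum that the
 * encoder placed into the coded header (first_crc/second_crc), so both
 * sides must compute the id over the same bytes. The sketch below uses a
 * trivial multiplicative checksum as a stand-in for the CRC the kernel
 * uses; compute_id(), find_buffered() and struct sketch_buf are
 * hypothetical names for this example only.
 */
#include <stddef.h>
#include <stdint.h>

struct sketch_buf {
        uint32_t packet_id;     /* checksum stored when the copy was buffered */
        const uint8_t *payload;
        size_t len;
};

/* stand-in for the payload checksum used as the packet id */
static uint32_t compute_id(const uint8_t *payload, size_t len)
{
        uint32_t id = 0;
        size_t i;

        for (i = 0; i < len; i++)
                id = id * 31 + payload[i];
        return id;
}

/* return the buffered copy whose id matches the one from the coded header */
static struct sketch_buf *find_buffered(struct sketch_buf *bufs, size_t n,
                                        uint32_t id_from_header)
{
        size_t i;

        for (i = 0; i < n; i++) {
                if (bufs[i].packet_id == id_from_header)
                        return &bufs[i];
        }
        return NULL;
}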
+ */ +static int batadv_nc_recv_coded_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_unicast_packet *unicast_packet; + struct batadv_coded_packet *coded_packet; + struct batadv_nc_packet *nc_packet; + struct ethhdr *ethhdr; + int hdr_size = sizeof(*coded_packet); + + /* Check if network coding is enabled */ + if (!atomic_read(&bat_priv->network_coding)) + goto free_skb; + + /* Make sure we can access (and remove) header */ + if (unlikely(!pskb_may_pull(skb, hdr_size))) + goto free_skb; + + coded_packet = (struct batadv_coded_packet *)skb->data; + ethhdr = eth_hdr(skb); + + /* Verify frame is destined for us */ + if (!batadv_is_my_mac(bat_priv, ethhdr->h_dest) && + !batadv_is_my_mac(bat_priv, coded_packet->second_dest)) + goto free_skb; + + /* Update stat counter */ + if (batadv_is_my_mac(bat_priv, coded_packet->second_dest)) + batadv_inc_counter(bat_priv, BATADV_CNT_NC_SNIFFED); + + nc_packet = batadv_nc_find_decoding_packet(bat_priv, ethhdr, + coded_packet); + if (!nc_packet) { + batadv_inc_counter(bat_priv, BATADV_CNT_NC_DECODE_FAILED); + goto free_skb; + } + + /* Make skb's linear, because decoding accesses the entire buffer */ + if (skb_linearize(skb) < 0) + goto free_nc_packet; + + if (skb_linearize(nc_packet->skb) < 0) + goto free_nc_packet; + + /* Decode the packet */ + unicast_packet = batadv_nc_skb_decode_packet(bat_priv, skb, nc_packet); + if (!unicast_packet) { + batadv_inc_counter(bat_priv, BATADV_CNT_NC_DECODE_FAILED); + goto free_nc_packet; + } + + /* Mark packet as decoded to do correct recoding when forwarding */ + BATADV_SKB_CB(skb)->decoded = true; + batadv_inc_counter(bat_priv, BATADV_CNT_NC_DECODE); + batadv_add_counter(bat_priv, BATADV_CNT_NC_DECODE_BYTES, + skb->len + ETH_HLEN); + return batadv_recv_unicast_packet(skb, recv_if); + +free_nc_packet: + batadv_nc_packet_free(nc_packet, true); +free_skb: + kfree_skb(skb); + + return NET_RX_DROP; +} + +/** + * batadv_nc_mesh_free() - clean up network coding memory + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_nc_mesh_free(struct batadv_priv *bat_priv) +{ + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_NC, 1); + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_NC, 1); + cancel_delayed_work_sync(&bat_priv->nc.work); + + batadv_nc_purge_paths(bat_priv, bat_priv->nc.coding_hash, NULL); + batadv_hash_destroy(bat_priv->nc.coding_hash); + batadv_nc_purge_paths(bat_priv, bat_priv->nc.decoding_hash, NULL); + batadv_hash_destroy(bat_priv->nc.decoding_hash); +} diff --git a/net/batman-adv/network-coding.h b/net/batman-adv/network-coding.h new file mode 100644 index 000000000..368cc3130 --- /dev/null +++ b/net/batman-adv/network-coding.h @@ -0,0 +1,106 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Martin Hundebøll, Jeppe Ledet-Pedersen + */ + +#ifndef _NET_BATMAN_ADV_NETWORK_CODING_H_ +#define _NET_BATMAN_ADV_NETWORK_CODING_H_ + +#include "main.h" + +#include +#include +#include +#include + +#ifdef CONFIG_BATMAN_ADV_NC + +void batadv_nc_status_update(struct net_device *net_dev); +int batadv_nc_init(void); +int batadv_nc_mesh_init(struct batadv_priv *bat_priv); +void batadv_nc_mesh_free(struct batadv_priv *bat_priv); +void batadv_nc_update_nc_node(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + struct batadv_ogm_packet *ogm_packet, + int is_single_hop_neigh); +void batadv_nc_purge_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + bool (*to_purge)(struct batadv_priv *, + struct batadv_nc_node *)); +void batadv_nc_init_bat_priv(struct batadv_priv *bat_priv); +void batadv_nc_init_orig(struct batadv_orig_node *orig_node); +bool batadv_nc_skb_forward(struct sk_buff *skb, + struct batadv_neigh_node *neigh_node); +void batadv_nc_skb_store_for_decoding(struct batadv_priv *bat_priv, + struct sk_buff *skb); +void batadv_nc_skb_store_sniffed_unicast(struct batadv_priv *bat_priv, + struct sk_buff *skb); + +#else /* ifdef CONFIG_BATMAN_ADV_NC */ + +static inline void batadv_nc_status_update(struct net_device *net_dev) +{ +} + +static inline int batadv_nc_init(void) +{ + return 0; +} + +static inline int batadv_nc_mesh_init(struct batadv_priv *bat_priv) +{ + return 0; +} + +static inline void batadv_nc_mesh_free(struct batadv_priv *bat_priv) +{ +} + +static inline void +batadv_nc_update_nc_node(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_orig_node *orig_neigh_node, + struct batadv_ogm_packet *ogm_packet, + int is_single_hop_neigh) +{ +} + +static inline void +batadv_nc_purge_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + bool (*to_purge)(struct batadv_priv *, + struct batadv_nc_node *)) +{ +} + +static inline void batadv_nc_init_bat_priv(struct batadv_priv *bat_priv) +{ +} + +static inline void batadv_nc_init_orig(struct batadv_orig_node *orig_node) +{ +} + +static inline bool batadv_nc_skb_forward(struct sk_buff *skb, + struct batadv_neigh_node *neigh_node) +{ + return false; +} + +static inline void +batadv_nc_skb_store_for_decoding(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ +} + +static inline void +batadv_nc_skb_store_sniffed_unicast(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ +} + +#endif /* ifdef CONFIG_BATMAN_ADV_NC */ + +#endif /* _NET_BATMAN_ADV_NETWORK_CODING_H_ */ diff --git a/net/batman-adv/originator.c b/net/batman-adv/originator.c new file mode 100644 index 000000000..34903df4f --- /dev/null +++ b/net/batman-adv/originator.c @@ -0,0 +1,1349 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "originator.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "distributed-arp-table.h" +#include "fragmentation.h" +#include "gateway_client.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "multicast.h" +#include "netlink.h" +#include "network-coding.h" +#include "routing.h" +#include "soft-interface.h" +#include "translation-table.h" + +/* hash class keys */ +static struct lock_class_key batadv_orig_hash_lock_class_key; + +/** + * batadv_orig_hash_find() - Find and return originator from orig_hash + * @bat_priv: the bat priv with all the soft interface information + * @data: mac address of the originator + * + * Return: orig_node (with increased refcnt), NULL on errors + */ +struct batadv_orig_node * +batadv_orig_hash_find(struct batadv_priv *bat_priv, const void *data) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_head *head; + struct batadv_orig_node *orig_node, *orig_node_tmp = NULL; + int index; + + if (!hash) + return NULL; + + index = batadv_choose_orig(data, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_node, head, hash_entry) { + if (!batadv_compare_eth(orig_node, data)) + continue; + + if (!kref_get_unless_zero(&orig_node->refcount)) + continue; + + orig_node_tmp = orig_node; + break; + } + rcu_read_unlock(); + + return orig_node_tmp; +} + +static void batadv_purge_orig(struct work_struct *work); + +/** + * batadv_compare_orig() - comparing function used in the originator hash table + * @node: node in the local table + * @data2: second object to compare the node to + * + * Return: true if they are the same originator + */ +bool batadv_compare_orig(const struct hlist_node *node, const void *data2) +{ + const void *data1 = container_of(node, struct batadv_orig_node, + hash_entry); + + return batadv_compare_eth(data1, data2); +} + +/** + * batadv_orig_node_vlan_get() - get an orig_node_vlan object + * @orig_node: the originator serving the VLAN + * @vid: the VLAN identifier + * + * Return: the vlan object identified by vid and belonging to orig_node or NULL + * if it does not exist. + */ +struct batadv_orig_node_vlan * +batadv_orig_node_vlan_get(struct batadv_orig_node *orig_node, + unsigned short vid) +{ + struct batadv_orig_node_vlan *vlan = NULL, *tmp; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, &orig_node->vlan_list, list) { + if (tmp->vid != vid) + continue; + + if (!kref_get_unless_zero(&tmp->refcount)) + continue; + + vlan = tmp; + + break; + } + rcu_read_unlock(); + + return vlan; +} + +/** + * batadv_orig_node_vlan_new() - search and possibly create an orig_node_vlan + * object + * @orig_node: the originator serving the VLAN + * @vid: the VLAN identifier + * + * Return: NULL in case of failure or the vlan object identified by vid and + * belonging to orig_node otherwise. The object is created and added to the list + * if it does not exist. + * + * The object is returned with refcounter increased by 1. 
+ */ +struct batadv_orig_node_vlan * +batadv_orig_node_vlan_new(struct batadv_orig_node *orig_node, + unsigned short vid) +{ + struct batadv_orig_node_vlan *vlan; + + spin_lock_bh(&orig_node->vlan_list_lock); + + /* first look if an object for this vid already exists */ + vlan = batadv_orig_node_vlan_get(orig_node, vid); + if (vlan) + goto out; + + vlan = kzalloc(sizeof(*vlan), GFP_ATOMIC); + if (!vlan) + goto out; + + kref_init(&vlan->refcount); + vlan->vid = vid; + + kref_get(&vlan->refcount); + hlist_add_head_rcu(&vlan->list, &orig_node->vlan_list); + +out: + spin_unlock_bh(&orig_node->vlan_list_lock); + + return vlan; +} + +/** + * batadv_orig_node_vlan_release() - release originator-vlan object from lists + * and queue for free after rcu grace period + * @ref: kref pointer of the originator-vlan object + */ +void batadv_orig_node_vlan_release(struct kref *ref) +{ + struct batadv_orig_node_vlan *orig_vlan; + + orig_vlan = container_of(ref, struct batadv_orig_node_vlan, refcount); + + kfree_rcu(orig_vlan, rcu); +} + +/** + * batadv_originator_init() - Initialize all originator structures + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_originator_init(struct batadv_priv *bat_priv) +{ + if (bat_priv->orig_hash) + return 0; + + bat_priv->orig_hash = batadv_hash_new(1024); + + if (!bat_priv->orig_hash) + goto err; + + batadv_hash_set_lock_class(bat_priv->orig_hash, + &batadv_orig_hash_lock_class_key); + + INIT_DELAYED_WORK(&bat_priv->orig_work, batadv_purge_orig); + queue_delayed_work(batadv_event_workqueue, + &bat_priv->orig_work, + msecs_to_jiffies(BATADV_ORIG_WORK_PERIOD)); + + return 0; + +err: + return -ENOMEM; +} + +/** + * batadv_neigh_ifinfo_release() - release neigh_ifinfo from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the neigh_ifinfo + */ +void batadv_neigh_ifinfo_release(struct kref *ref) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo; + + neigh_ifinfo = container_of(ref, struct batadv_neigh_ifinfo, refcount); + + if (neigh_ifinfo->if_outgoing != BATADV_IF_DEFAULT) + batadv_hardif_put(neigh_ifinfo->if_outgoing); + + kfree_rcu(neigh_ifinfo, rcu); +} + +/** + * batadv_hardif_neigh_release() - release hardif neigh node from lists and + * queue for free after rcu grace period + * @ref: kref pointer of the neigh_node + */ +void batadv_hardif_neigh_release(struct kref *ref) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + + hardif_neigh = container_of(ref, struct batadv_hardif_neigh_node, + refcount); + + spin_lock_bh(&hardif_neigh->if_incoming->neigh_list_lock); + hlist_del_init_rcu(&hardif_neigh->list); + spin_unlock_bh(&hardif_neigh->if_incoming->neigh_list_lock); + + batadv_hardif_put(hardif_neigh->if_incoming); + kfree_rcu(hardif_neigh, rcu); +} + +/** + * batadv_neigh_node_release() - release neigh_node from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the neigh_node + */ +void batadv_neigh_node_release(struct kref *ref) +{ + struct hlist_node *node_tmp; + struct batadv_neigh_node *neigh_node; + struct batadv_neigh_ifinfo *neigh_ifinfo; + + neigh_node = container_of(ref, struct batadv_neigh_node, refcount); + + hlist_for_each_entry_safe(neigh_ifinfo, node_tmp, + &neigh_node->ifinfo_list, list) { + batadv_neigh_ifinfo_put(neigh_ifinfo); + } + + batadv_hardif_neigh_put(neigh_node->hardif_neigh); + + batadv_hardif_put(neigh_node->if_incoming); + + kfree_rcu(neigh_node, rcu); +} + +/** + * 
batadv_orig_router_get() - router to the originator depending on iface + * @orig_node: the orig node for the router + * @if_outgoing: the interface where the payload packet has been received or + * the OGM should be sent to + * + * Return: the neighbor which should be the router for this orig_node/iface. + * + * The object is returned with refcounter increased by 1. + */ +struct batadv_neigh_node * +batadv_orig_router_get(struct batadv_orig_node *orig_node, + const struct batadv_hard_iface *if_outgoing) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_neigh_node *router = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_ifinfo, &orig_node->ifinfo_list, list) { + if (orig_ifinfo->if_outgoing != if_outgoing) + continue; + + router = rcu_dereference(orig_ifinfo->router); + break; + } + + if (router && !kref_get_unless_zero(&router->refcount)) + router = NULL; + + rcu_read_unlock(); + return router; +} + +/** + * batadv_orig_ifinfo_get() - find the ifinfo from an orig_node + * @orig_node: the orig node to be queried + * @if_outgoing: the interface for which the ifinfo should be acquired + * + * Return: the requested orig_ifinfo or NULL if not found. + * + * The object is returned with refcounter increased by 1. + */ +struct batadv_orig_ifinfo * +batadv_orig_ifinfo_get(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_orig_ifinfo *tmp, *orig_ifinfo = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, &orig_node->ifinfo_list, + list) { + if (tmp->if_outgoing != if_outgoing) + continue; + + if (!kref_get_unless_zero(&tmp->refcount)) + continue; + + orig_ifinfo = tmp; + break; + } + rcu_read_unlock(); + + return orig_ifinfo; +} + +/** + * batadv_orig_ifinfo_new() - search and possibly create an orig_ifinfo object + * @orig_node: the orig node to be queried + * @if_outgoing: the interface for which the ifinfo should be acquired + * + * Return: NULL in case of failure or the orig_ifinfo object for the if_outgoing + * interface otherwise. The object is created and added to the list + * if it does not exist. + * + * The object is returned with refcounter increased by 1. + */ +struct batadv_orig_ifinfo * +batadv_orig_ifinfo_new(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + unsigned long reset_time; + + spin_lock_bh(&orig_node->neigh_list_lock); + + orig_ifinfo = batadv_orig_ifinfo_get(orig_node, if_outgoing); + if (orig_ifinfo) + goto out; + + orig_ifinfo = kzalloc(sizeof(*orig_ifinfo), GFP_ATOMIC); + if (!orig_ifinfo) + goto out; + + if (if_outgoing != BATADV_IF_DEFAULT) + kref_get(&if_outgoing->refcount); + + reset_time = jiffies - 1; + reset_time -= msecs_to_jiffies(BATADV_RESET_PROTECTION_MS); + orig_ifinfo->batman_seqno_reset = reset_time; + orig_ifinfo->if_outgoing = if_outgoing; + INIT_HLIST_NODE(&orig_ifinfo->list); + kref_init(&orig_ifinfo->refcount); + + kref_get(&orig_ifinfo->refcount); + hlist_add_head_rcu(&orig_ifinfo->list, + &orig_node->ifinfo_list); +out: + spin_unlock_bh(&orig_node->neigh_list_lock); + return orig_ifinfo; +} + +/** + * batadv_neigh_ifinfo_get() - find the ifinfo from an neigh_node + * @neigh: the neigh node to be queried + * @if_outgoing: the interface for which the ifinfo should be acquired + * + * The object is returned with refcounter increased by 1. 
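+ *
+ * The lookup runs under rcu_read_lock() and takes the reference with
+ * kref_get_unless_zero(), so no caller-side lock is needed; the reference
+ * has to be dropped again via batadv_neigh_ifinfo_put(). Illustrative
+ * sketch:
+ *
+ *	neigh_ifinfo = batadv_neigh_ifinfo_get(neigh, if_outgoing);
+ *	if (neigh_ifinfo) {
+ *		...
+ *		batadv_neigh_ifinfo_put(neigh_ifinfo);
+ *	}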
+ * + * Return: the requested neigh_ifinfo or NULL if not found + */ +struct batadv_neigh_ifinfo * +batadv_neigh_ifinfo_get(struct batadv_neigh_node *neigh, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo = NULL, + *tmp_neigh_ifinfo; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_neigh_ifinfo, &neigh->ifinfo_list, + list) { + if (tmp_neigh_ifinfo->if_outgoing != if_outgoing) + continue; + + if (!kref_get_unless_zero(&tmp_neigh_ifinfo->refcount)) + continue; + + neigh_ifinfo = tmp_neigh_ifinfo; + break; + } + rcu_read_unlock(); + + return neigh_ifinfo; +} + +/** + * batadv_neigh_ifinfo_new() - search and possibly create an neigh_ifinfo object + * @neigh: the neigh node to be queried + * @if_outgoing: the interface for which the ifinfo should be acquired + * + * Return: NULL in case of failure or the neigh_ifinfo object for the + * if_outgoing interface otherwise. The object is created and added to the list + * if it does not exist. + * + * The object is returned with refcounter increased by 1. + */ +struct batadv_neigh_ifinfo * +batadv_neigh_ifinfo_new(struct batadv_neigh_node *neigh, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo; + + spin_lock_bh(&neigh->ifinfo_lock); + + neigh_ifinfo = batadv_neigh_ifinfo_get(neigh, if_outgoing); + if (neigh_ifinfo) + goto out; + + neigh_ifinfo = kzalloc(sizeof(*neigh_ifinfo), GFP_ATOMIC); + if (!neigh_ifinfo) + goto out; + + if (if_outgoing) + kref_get(&if_outgoing->refcount); + + INIT_HLIST_NODE(&neigh_ifinfo->list); + kref_init(&neigh_ifinfo->refcount); + neigh_ifinfo->if_outgoing = if_outgoing; + + kref_get(&neigh_ifinfo->refcount); + hlist_add_head_rcu(&neigh_ifinfo->list, &neigh->ifinfo_list); + +out: + spin_unlock_bh(&neigh->ifinfo_lock); + + return neigh_ifinfo; +} + +/** + * batadv_neigh_node_get() - retrieve a neighbour from the list + * @orig_node: originator which the neighbour belongs to + * @hard_iface: the interface where this neighbour is connected to + * @addr: the address of the neighbour + * + * Looks for and possibly returns a neighbour belonging to this originator list + * which is connected through the provided hard interface. + * + * Return: neighbor when found. Otherwise NULL + */ +static struct batadv_neigh_node * +batadv_neigh_node_get(const struct batadv_orig_node *orig_node, + const struct batadv_hard_iface *hard_iface, + const u8 *addr) +{ + struct batadv_neigh_node *tmp_neigh_node, *res = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_neigh_node, &orig_node->neigh_list, list) { + if (!batadv_compare_eth(tmp_neigh_node->addr, addr)) + continue; + + if (tmp_neigh_node->if_incoming != hard_iface) + continue; + + if (!kref_get_unless_zero(&tmp_neigh_node->refcount)) + continue; + + res = tmp_neigh_node; + break; + } + rcu_read_unlock(); + + return res; +} + +/** + * batadv_hardif_neigh_create() - create a hardif neighbour node + * @hard_iface: the interface this neighbour is connected to + * @neigh_addr: the interface address of the neighbour to retrieve + * @orig_node: originator object representing the neighbour + * + * Return: the hardif neighbour node if found or created or NULL otherwise. 
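+ *
+ * Callers normally go through batadv_hardif_neigh_get_or_create(), which
+ * first tries a lock-free batadv_hardif_neigh_get() and only falls back to
+ * this function when no entry exists. The lookup is repeated here under
+ * hard_iface->neigh_list_lock, so a concurrent caller cannot add the same
+ * neighbour twice.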
+ */ +static struct batadv_hardif_neigh_node * +batadv_hardif_neigh_create(struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr, + struct batadv_orig_node *orig_node) +{ + struct batadv_priv *bat_priv = netdev_priv(hard_iface->soft_iface); + struct batadv_hardif_neigh_node *hardif_neigh; + + spin_lock_bh(&hard_iface->neigh_list_lock); + + /* check if neighbor hasn't been added in the meantime */ + hardif_neigh = batadv_hardif_neigh_get(hard_iface, neigh_addr); + if (hardif_neigh) + goto out; + + hardif_neigh = kzalloc(sizeof(*hardif_neigh), GFP_ATOMIC); + if (!hardif_neigh) + goto out; + + kref_get(&hard_iface->refcount); + INIT_HLIST_NODE(&hardif_neigh->list); + ether_addr_copy(hardif_neigh->addr, neigh_addr); + ether_addr_copy(hardif_neigh->orig, orig_node->orig); + hardif_neigh->if_incoming = hard_iface; + hardif_neigh->last_seen = jiffies; + + kref_init(&hardif_neigh->refcount); + + if (bat_priv->algo_ops->neigh.hardif_init) + bat_priv->algo_ops->neigh.hardif_init(hardif_neigh); + + hlist_add_head_rcu(&hardif_neigh->list, &hard_iface->neigh_list); + +out: + spin_unlock_bh(&hard_iface->neigh_list_lock); + return hardif_neigh; +} + +/** + * batadv_hardif_neigh_get_or_create() - retrieve or create a hardif neighbour + * node + * @hard_iface: the interface this neighbour is connected to + * @neigh_addr: the interface address of the neighbour to retrieve + * @orig_node: originator object representing the neighbour + * + * Return: the hardif neighbour node if found or created or NULL otherwise. + */ +static struct batadv_hardif_neigh_node * +batadv_hardif_neigh_get_or_create(struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr, + struct batadv_orig_node *orig_node) +{ + struct batadv_hardif_neigh_node *hardif_neigh; + + /* first check without locking to avoid the overhead */ + hardif_neigh = batadv_hardif_neigh_get(hard_iface, neigh_addr); + if (hardif_neigh) + return hardif_neigh; + + return batadv_hardif_neigh_create(hard_iface, neigh_addr, orig_node); +} + +/** + * batadv_hardif_neigh_get() - retrieve a hardif neighbour from the list + * @hard_iface: the interface where this neighbour is connected to + * @neigh_addr: the address of the neighbour + * + * Looks for and possibly returns a neighbour belonging to this hard interface. + * + * Return: neighbor when found. Otherwise NULL + */ +struct batadv_hardif_neigh_node * +batadv_hardif_neigh_get(const struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr) +{ + struct batadv_hardif_neigh_node *tmp_hardif_neigh, *hardif_neigh = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp_hardif_neigh, + &hard_iface->neigh_list, list) { + if (!batadv_compare_eth(tmp_hardif_neigh->addr, neigh_addr)) + continue; + + if (!kref_get_unless_zero(&tmp_hardif_neigh->refcount)) + continue; + + hardif_neigh = tmp_hardif_neigh; + break; + } + rcu_read_unlock(); + + return hardif_neigh; +} + +/** + * batadv_neigh_node_create() - create a neigh node object + * @orig_node: originator object representing the neighbour + * @hard_iface: the interface where the neighbour is connected to + * @neigh_addr: the mac address of the neighbour interface + * + * Allocates a new neigh_node object and initialises all the generic fields. + * + * Return: the neighbour node if found or created or NULL otherwise. 
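+ *
+ * The new object takes its own references on @hard_iface and on the matching
+ * hardif_neigh entry and is linked into @orig_node->neigh_list; the caller
+ * only has to drop the returned reference with batadv_neigh_node_put() when
+ * it is done with the neighbour.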
+ */ +static struct batadv_neigh_node * +batadv_neigh_node_create(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr) +{ + struct batadv_neigh_node *neigh_node; + struct batadv_hardif_neigh_node *hardif_neigh = NULL; + + spin_lock_bh(&orig_node->neigh_list_lock); + + neigh_node = batadv_neigh_node_get(orig_node, hard_iface, neigh_addr); + if (neigh_node) + goto out; + + hardif_neigh = batadv_hardif_neigh_get_or_create(hard_iface, + neigh_addr, orig_node); + if (!hardif_neigh) + goto out; + + neigh_node = kzalloc(sizeof(*neigh_node), GFP_ATOMIC); + if (!neigh_node) + goto out; + + INIT_HLIST_NODE(&neigh_node->list); + INIT_HLIST_HEAD(&neigh_node->ifinfo_list); + spin_lock_init(&neigh_node->ifinfo_lock); + + kref_get(&hard_iface->refcount); + ether_addr_copy(neigh_node->addr, neigh_addr); + neigh_node->if_incoming = hard_iface; + neigh_node->orig_node = orig_node; + neigh_node->last_seen = jiffies; + + /* increment unique neighbor refcount */ + kref_get(&hardif_neigh->refcount); + neigh_node->hardif_neigh = hardif_neigh; + + /* extra reference for return */ + kref_init(&neigh_node->refcount); + + kref_get(&neigh_node->refcount); + hlist_add_head_rcu(&neigh_node->list, &orig_node->neigh_list); + + batadv_dbg(BATADV_DBG_BATMAN, orig_node->bat_priv, + "Creating new neighbor %pM for orig_node %pM on interface %s\n", + neigh_addr, orig_node->orig, hard_iface->net_dev->name); + +out: + spin_unlock_bh(&orig_node->neigh_list_lock); + + batadv_hardif_neigh_put(hardif_neigh); + return neigh_node; +} + +/** + * batadv_neigh_node_get_or_create() - retrieve or create a neigh node object + * @orig_node: originator object representing the neighbour + * @hard_iface: the interface where the neighbour is connected to + * @neigh_addr: the mac address of the neighbour interface + * + * Return: the neighbour node if found or created or NULL otherwise. 
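+ *
+ * Typical usage (illustrative sketch only, error handling elided):
+ *
+ *	neigh_node = batadv_neigh_node_get_or_create(orig_node, hard_iface,
+ *						     neigh_addr);
+ *	if (!neigh_node)
+ *		goto out;
+ *	...
+ *	batadv_neigh_node_put(neigh_node);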
+ */ +struct batadv_neigh_node * +batadv_neigh_node_get_or_create(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr) +{ + struct batadv_neigh_node *neigh_node; + + /* first check without locking to avoid the overhead */ + neigh_node = batadv_neigh_node_get(orig_node, hard_iface, neigh_addr); + if (neigh_node) + return neigh_node; + + return batadv_neigh_node_create(orig_node, hard_iface, neigh_addr); +} + +/** + * batadv_hardif_neigh_dump() - Dump to netlink the neighbor infos for a + * specific outgoing interface + * @msg: message to dump into + * @cb: parameters for the dump + * + * Return: 0 or error value + */ +int batadv_hardif_neigh_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct net_device *hard_iface = NULL; + struct batadv_hard_iface *hardif = BATADV_IF_DEFAULT; + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + int ret; + int ifindex, hard_ifindex; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + hard_ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_HARD_IFINDEX); + if (hard_ifindex) { + hard_iface = dev_get_by_index(net, hard_ifindex); + if (hard_iface) + hardif = batadv_hardif_get_by_netdev(hard_iface); + + if (!hardif) { + ret = -ENODEV; + goto out; + } + + if (hardif->soft_iface != soft_iface) { + ret = -ENOENT; + goto out; + } + } + + if (!bat_priv->algo_ops->neigh.dump) { + ret = -EOPNOTSUPP; + goto out; + } + + bat_priv->algo_ops->neigh.dump(msg, cb, bat_priv, hardif); + + ret = msg->len; + + out: + batadv_hardif_put(hardif); + dev_put(hard_iface); + batadv_hardif_put(primary_if); + dev_put(soft_iface); + + return ret; +} + +/** + * batadv_orig_ifinfo_release() - release orig_ifinfo from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the orig_ifinfo + */ +void batadv_orig_ifinfo_release(struct kref *ref) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_neigh_node *router; + + orig_ifinfo = container_of(ref, struct batadv_orig_ifinfo, refcount); + + if (orig_ifinfo->if_outgoing != BATADV_IF_DEFAULT) + batadv_hardif_put(orig_ifinfo->if_outgoing); + + /* this is the last reference to this object */ + router = rcu_dereference_protected(orig_ifinfo->router, true); + batadv_neigh_node_put(router); + + kfree_rcu(orig_ifinfo, rcu); +} + +/** + * batadv_orig_node_free_rcu() - free the orig_node + * @rcu: rcu pointer of the orig_node + */ +static void batadv_orig_node_free_rcu(struct rcu_head *rcu) +{ + struct batadv_orig_node *orig_node; + + orig_node = container_of(rcu, struct batadv_orig_node, rcu); + + batadv_mcast_purge_orig(orig_node); + + batadv_frag_purge_orig(orig_node, NULL); + + kfree(orig_node->tt_buff); + kfree(orig_node); +} + +/** + * batadv_orig_node_release() - release orig_node from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the orig_node + */ +void batadv_orig_node_release(struct kref *ref) +{ + struct hlist_node *node_tmp; + struct batadv_neigh_node *neigh_node; + struct batadv_orig_node *orig_node; + 
struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_orig_node_vlan *vlan; + struct batadv_orig_ifinfo *last_candidate; + + orig_node = container_of(ref, struct batadv_orig_node, refcount); + + spin_lock_bh(&orig_node->neigh_list_lock); + + /* for all neighbors towards this originator ... */ + hlist_for_each_entry_safe(neigh_node, node_tmp, + &orig_node->neigh_list, list) { + hlist_del_rcu(&neigh_node->list); + batadv_neigh_node_put(neigh_node); + } + + hlist_for_each_entry_safe(orig_ifinfo, node_tmp, + &orig_node->ifinfo_list, list) { + hlist_del_rcu(&orig_ifinfo->list); + batadv_orig_ifinfo_put(orig_ifinfo); + } + + last_candidate = orig_node->last_bonding_candidate; + orig_node->last_bonding_candidate = NULL; + spin_unlock_bh(&orig_node->neigh_list_lock); + + batadv_orig_ifinfo_put(last_candidate); + + spin_lock_bh(&orig_node->vlan_list_lock); + hlist_for_each_entry_safe(vlan, node_tmp, &orig_node->vlan_list, list) { + hlist_del_rcu(&vlan->list); + batadv_orig_node_vlan_put(vlan); + } + spin_unlock_bh(&orig_node->vlan_list_lock); + + /* Free nc_nodes */ + batadv_nc_purge_orig(orig_node->bat_priv, orig_node, NULL); + + call_rcu(&orig_node->rcu, batadv_orig_node_free_rcu); +} + +/** + * batadv_originator_free() - Free all originator structures + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_originator_free(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_node *node_tmp; + struct hlist_head *head; + spinlock_t *list_lock; /* spinlock to protect write access */ + struct batadv_orig_node *orig_node; + u32 i; + + if (!hash) + return; + + cancel_delayed_work_sync(&bat_priv->orig_work); + + bat_priv->orig_hash = NULL; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(orig_node, node_tmp, + head, hash_entry) { + hlist_del_rcu(&orig_node->hash_entry); + batadv_orig_node_put(orig_node); + } + spin_unlock_bh(list_lock); + } + + batadv_hash_destroy(hash); +} + +/** + * batadv_orig_node_new() - creates a new orig_node + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the originator + * + * Creates a new originator object and initialises all the generic fields. + * The new object is not added to the originator list. + * + * Return: the newly created object or NULL on failure. 
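+ *
+ * It is up to the caller (normally the routing algorithm implementation) to
+ * insert the node into bat_priv->orig_hash and to take an extra reference
+ * for that hash entry. Simplified caller sketch (illustrative only):
+ *
+ *	orig_node = batadv_orig_node_new(bat_priv, addr);
+ *	if (!orig_node)
+ *		return NULL;
+ *	kref_get(&orig_node->refcount);
+ *	hash_added = batadv_hash_add(bat_priv->orig_hash, batadv_compare_orig,
+ *				     batadv_choose_orig, orig_node,
+ *				     &orig_node->hash_entry);
+ *
+ * On a non-zero return of batadv_hash_add() both the hash reference and the
+ * initial reference have to be dropped again via batadv_orig_node_put().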
+ */ +struct batadv_orig_node *batadv_orig_node_new(struct batadv_priv *bat_priv, + const u8 *addr) +{ + struct batadv_orig_node *orig_node; + struct batadv_orig_node_vlan *vlan; + unsigned long reset_time; + int i; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Creating new originator: %pM\n", addr); + + orig_node = kzalloc(sizeof(*orig_node), GFP_ATOMIC); + if (!orig_node) + return NULL; + + INIT_HLIST_HEAD(&orig_node->neigh_list); + INIT_HLIST_HEAD(&orig_node->vlan_list); + INIT_HLIST_HEAD(&orig_node->ifinfo_list); + spin_lock_init(&orig_node->bcast_seqno_lock); + spin_lock_init(&orig_node->neigh_list_lock); + spin_lock_init(&orig_node->tt_buff_lock); + spin_lock_init(&orig_node->tt_lock); + spin_lock_init(&orig_node->vlan_list_lock); + + batadv_nc_init_orig(orig_node); + + /* extra reference for return */ + kref_init(&orig_node->refcount); + + orig_node->bat_priv = bat_priv; + ether_addr_copy(orig_node->orig, addr); + batadv_dat_init_orig_node_addr(orig_node); + atomic_set(&orig_node->last_ttvn, 0); + orig_node->tt_buff = NULL; + orig_node->tt_buff_len = 0; + orig_node->last_seen = jiffies; + reset_time = jiffies - 1 - msecs_to_jiffies(BATADV_RESET_PROTECTION_MS); + orig_node->bcast_seqno_reset = reset_time; + +#ifdef CONFIG_BATMAN_ADV_MCAST + orig_node->mcast_flags = BATADV_MCAST_WANT_NO_RTR4; + orig_node->mcast_flags |= BATADV_MCAST_WANT_NO_RTR6; + INIT_HLIST_NODE(&orig_node->mcast_want_all_unsnoopables_node); + INIT_HLIST_NODE(&orig_node->mcast_want_all_ipv4_node); + INIT_HLIST_NODE(&orig_node->mcast_want_all_ipv6_node); + spin_lock_init(&orig_node->mcast_handler_lock); +#endif + + /* create a vlan object for the "untagged" LAN */ + vlan = batadv_orig_node_vlan_new(orig_node, BATADV_NO_FLAGS); + if (!vlan) + goto free_orig_node; + /* batadv_orig_node_vlan_new() increases the refcounter. 
+ * Immediately release vlan since it is not needed anymore in this + * context + */ + batadv_orig_node_vlan_put(vlan); + + for (i = 0; i < BATADV_FRAG_BUFFER_COUNT; i++) { + INIT_HLIST_HEAD(&orig_node->fragments[i].fragment_list); + spin_lock_init(&orig_node->fragments[i].lock); + orig_node->fragments[i].size = 0; + } + + return orig_node; +free_orig_node: + kfree(orig_node); + return NULL; +} + +/** + * batadv_purge_neigh_ifinfo() - purge obsolete ifinfo entries from neighbor + * @bat_priv: the bat priv with all the soft interface information + * @neigh: orig node which is to be checked + */ +static void +batadv_purge_neigh_ifinfo(struct batadv_priv *bat_priv, + struct batadv_neigh_node *neigh) +{ + struct batadv_neigh_ifinfo *neigh_ifinfo; + struct batadv_hard_iface *if_outgoing; + struct hlist_node *node_tmp; + + spin_lock_bh(&neigh->ifinfo_lock); + + /* for all ifinfo objects for this neighinator */ + hlist_for_each_entry_safe(neigh_ifinfo, node_tmp, + &neigh->ifinfo_list, list) { + if_outgoing = neigh_ifinfo->if_outgoing; + + /* always keep the default interface */ + if (if_outgoing == BATADV_IF_DEFAULT) + continue; + + /* don't purge if the interface is not (going) down */ + if (if_outgoing->if_status != BATADV_IF_INACTIVE && + if_outgoing->if_status != BATADV_IF_NOT_IN_USE && + if_outgoing->if_status != BATADV_IF_TO_BE_REMOVED) + continue; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "neighbor/ifinfo purge: neighbor %pM, iface: %s\n", + neigh->addr, if_outgoing->net_dev->name); + + hlist_del_rcu(&neigh_ifinfo->list); + batadv_neigh_ifinfo_put(neigh_ifinfo); + } + + spin_unlock_bh(&neigh->ifinfo_lock); +} + +/** + * batadv_purge_orig_ifinfo() - purge obsolete ifinfo entries from originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be checked + * + * Return: true if any ifinfo entry was purged, false otherwise. 
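+ *
+ * Only entries whose outgoing interface is in the BATADV_IF_INACTIVE,
+ * BATADV_IF_NOT_IN_USE or BATADV_IF_TO_BE_REMOVED state are dropped; the
+ * entry for BATADV_IF_DEFAULT is always kept.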
+ */ +static bool +batadv_purge_orig_ifinfo(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_hard_iface *if_outgoing; + struct hlist_node *node_tmp; + bool ifinfo_purged = false; + + spin_lock_bh(&orig_node->neigh_list_lock); + + /* for all ifinfo objects for this originator */ + hlist_for_each_entry_safe(orig_ifinfo, node_tmp, + &orig_node->ifinfo_list, list) { + if_outgoing = orig_ifinfo->if_outgoing; + + /* always keep the default interface */ + if (if_outgoing == BATADV_IF_DEFAULT) + continue; + + /* don't purge if the interface is not (going) down */ + if (if_outgoing->if_status != BATADV_IF_INACTIVE && + if_outgoing->if_status != BATADV_IF_NOT_IN_USE && + if_outgoing->if_status != BATADV_IF_TO_BE_REMOVED) + continue; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "router/ifinfo purge: originator %pM, iface: %s\n", + orig_node->orig, if_outgoing->net_dev->name); + + ifinfo_purged = true; + + hlist_del_rcu(&orig_ifinfo->list); + batadv_orig_ifinfo_put(orig_ifinfo); + if (orig_node->last_bonding_candidate == orig_ifinfo) { + orig_node->last_bonding_candidate = NULL; + batadv_orig_ifinfo_put(orig_ifinfo); + } + } + + spin_unlock_bh(&orig_node->neigh_list_lock); + + return ifinfo_purged; +} + +/** + * batadv_purge_orig_neighbors() - purges neighbors from originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be checked + * + * Return: true if any neighbor was purged, false otherwise + */ +static bool +batadv_purge_orig_neighbors(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct hlist_node *node_tmp; + struct batadv_neigh_node *neigh_node; + bool neigh_purged = false; + unsigned long last_seen; + struct batadv_hard_iface *if_incoming; + + spin_lock_bh(&orig_node->neigh_list_lock); + + /* for all neighbors towards this originator ... */ + hlist_for_each_entry_safe(neigh_node, node_tmp, + &orig_node->neigh_list, list) { + last_seen = neigh_node->last_seen; + if_incoming = neigh_node->if_incoming; + + if (batadv_has_timed_out(last_seen, BATADV_PURGE_TIMEOUT) || + if_incoming->if_status == BATADV_IF_INACTIVE || + if_incoming->if_status == BATADV_IF_NOT_IN_USE || + if_incoming->if_status == BATADV_IF_TO_BE_REMOVED) { + if (if_incoming->if_status == BATADV_IF_INACTIVE || + if_incoming->if_status == BATADV_IF_NOT_IN_USE || + if_incoming->if_status == BATADV_IF_TO_BE_REMOVED) + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "neighbor purge: originator %pM, neighbor: %pM, iface: %s\n", + orig_node->orig, neigh_node->addr, + if_incoming->net_dev->name); + else + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "neighbor timeout: originator %pM, neighbor: %pM, last_seen: %u\n", + orig_node->orig, neigh_node->addr, + jiffies_to_msecs(last_seen)); + + neigh_purged = true; + + hlist_del_rcu(&neigh_node->list); + batadv_neigh_node_put(neigh_node); + } else { + /* only necessary if not the whole neighbor is to be + * deleted, but some interface has been removed. + */ + batadv_purge_neigh_ifinfo(bat_priv, neigh_node); + } + } + + spin_unlock_bh(&orig_node->neigh_list_lock); + return neigh_purged; +} + +/** + * batadv_find_best_neighbor() - finds the best neighbor after purging + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be checked + * @if_outgoing: the interface for which the metric should be compared + * + * Return: the current best neighbor, with refcount increased. 
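+ *
+ * The result may be NULL if @orig_node has no usable neighbor on
+ * @if_outgoing. batadv_neigh_node_put() accepts NULL, so the reference can
+ * be dropped unconditionally, as in this caller sketch (the same pattern is
+ * used by batadv_purge_orig_node()):
+ *
+ *	best = batadv_find_best_neighbor(bat_priv, orig_node, hard_iface);
+ *	batadv_update_route(bat_priv, orig_node, hard_iface, best);
+ *	batadv_neigh_node_put(best);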
+ */ +static struct batadv_neigh_node * +batadv_find_best_neighbor(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing) +{ + struct batadv_neigh_node *best = NULL, *neigh; + struct batadv_algo_ops *bao = bat_priv->algo_ops; + + rcu_read_lock(); + hlist_for_each_entry_rcu(neigh, &orig_node->neigh_list, list) { + if (best && (bao->neigh.cmp(neigh, if_outgoing, best, + if_outgoing) <= 0)) + continue; + + if (!kref_get_unless_zero(&neigh->refcount)) + continue; + + batadv_neigh_node_put(best); + + best = neigh; + } + rcu_read_unlock(); + + return best; +} + +/** + * batadv_purge_orig_node() - purges obsolete information from an orig_node + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be checked + * + * This function checks if the orig_node or substructures of it have become + * obsolete, and purges this information if that's the case. + * + * Return: true if the orig_node is to be removed, false otherwise. + */ +static bool batadv_purge_orig_node(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_neigh_node *best_neigh_node; + struct batadv_hard_iface *hard_iface; + bool changed_ifinfo, changed_neigh; + + if (batadv_has_timed_out(orig_node->last_seen, + 2 * BATADV_PURGE_TIMEOUT)) { + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "Originator timeout: originator %pM, last_seen %u\n", + orig_node->orig, + jiffies_to_msecs(orig_node->last_seen)); + return true; + } + changed_ifinfo = batadv_purge_orig_ifinfo(bat_priv, orig_node); + changed_neigh = batadv_purge_orig_neighbors(bat_priv, orig_node); + + if (!changed_ifinfo && !changed_neigh) + return false; + + /* first for NULL ... */ + best_neigh_node = batadv_find_best_neighbor(bat_priv, orig_node, + BATADV_IF_DEFAULT); + batadv_update_route(bat_priv, orig_node, BATADV_IF_DEFAULT, + best_neigh_node); + batadv_neigh_node_put(best_neigh_node); + + /* ... then for all other interfaces. */ + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->if_status != BATADV_IF_ACTIVE) + continue; + + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + continue; + + best_neigh_node = batadv_find_best_neighbor(bat_priv, + orig_node, + hard_iface); + batadv_update_route(bat_priv, orig_node, hard_iface, + best_neigh_node); + batadv_neigh_node_put(best_neigh_node); + + batadv_hardif_put(hard_iface); + } + rcu_read_unlock(); + + return false; +} + +/** + * batadv_purge_orig_ref() - Purge all outdated originators + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_purge_orig_ref(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash = bat_priv->orig_hash; + struct hlist_node *node_tmp; + struct hlist_head *head; + spinlock_t *list_lock; /* spinlock to protect write access */ + struct batadv_orig_node *orig_node; + u32 i; + + if (!hash) + return; + + /* for all origins... 
*/ + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(orig_node, node_tmp, + head, hash_entry) { + if (batadv_purge_orig_node(bat_priv, orig_node)) { + batadv_gw_node_delete(bat_priv, orig_node); + hlist_del_rcu(&orig_node->hash_entry); + batadv_tt_global_del_orig(orig_node->bat_priv, + orig_node, -1, + "originator timed out"); + batadv_orig_node_put(orig_node); + continue; + } + + batadv_frag_purge_orig(orig_node, + batadv_frag_check_entry); + } + spin_unlock_bh(list_lock); + } + + batadv_gw_election(bat_priv); +} + +static void batadv_purge_orig(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv *bat_priv; + + delayed_work = to_delayed_work(work); + bat_priv = container_of(delayed_work, struct batadv_priv, orig_work); + batadv_purge_orig_ref(bat_priv); + queue_delayed_work(batadv_event_workqueue, + &bat_priv->orig_work, + msecs_to_jiffies(BATADV_ORIG_WORK_PERIOD)); +} + +/** + * batadv_orig_dump() - Dump to netlink the originator infos for a specific + * outgoing interface + * @msg: message to dump into + * @cb: parameters for the dump + * + * Return: 0 or error value + */ +int batadv_orig_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct net_device *hard_iface = NULL; + struct batadv_hard_iface *hardif = BATADV_IF_DEFAULT; + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + int ret; + int ifindex, hard_ifindex; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + hard_ifindex = batadv_netlink_get_ifindex(cb->nlh, + BATADV_ATTR_HARD_IFINDEX); + if (hard_ifindex) { + hard_iface = dev_get_by_index(net, hard_ifindex); + if (hard_iface) + hardif = batadv_hardif_get_by_netdev(hard_iface); + + if (!hardif) { + ret = -ENODEV; + goto out; + } + + if (hardif->soft_iface != soft_iface) { + ret = -ENOENT; + goto out; + } + } + + if (!bat_priv->algo_ops->orig.dump) { + ret = -EOPNOTSUPP; + goto out; + } + + bat_priv->algo_ops->orig.dump(msg, cb, bat_priv, hardif); + + ret = msg->len; + + out: + batadv_hardif_put(hardif); + dev_put(hard_iface); + batadv_hardif_put(primary_if); + dev_put(soft_iface); + + return ret; +} diff --git a/net/batman-adv/originator.h b/net/batman-adv/originator.h new file mode 100644 index 000000000..ea3d69e4e --- /dev/null +++ b/net/batman-adv/originator.h @@ -0,0 +1,167 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_ORIGINATOR_H_ +#define _NET_BATMAN_ADV_ORIGINATOR_H_ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include + +bool batadv_compare_orig(const struct hlist_node *node, const void *data2); +int batadv_originator_init(struct batadv_priv *bat_priv); +void batadv_originator_free(struct batadv_priv *bat_priv); +void batadv_purge_orig_ref(struct batadv_priv *bat_priv); +void batadv_orig_node_release(struct kref *ref); +struct batadv_orig_node *batadv_orig_node_new(struct batadv_priv *bat_priv, + const u8 *addr); +struct batadv_hardif_neigh_node * +batadv_hardif_neigh_get(const struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr); +void batadv_hardif_neigh_release(struct kref *ref); +struct batadv_neigh_node * +batadv_neigh_node_get_or_create(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *hard_iface, + const u8 *neigh_addr); +void batadv_neigh_node_release(struct kref *ref); +struct batadv_neigh_node * +batadv_orig_router_get(struct batadv_orig_node *orig_node, + const struct batadv_hard_iface *if_outgoing); +struct batadv_neigh_ifinfo * +batadv_neigh_ifinfo_new(struct batadv_neigh_node *neigh, + struct batadv_hard_iface *if_outgoing); +struct batadv_neigh_ifinfo * +batadv_neigh_ifinfo_get(struct batadv_neigh_node *neigh, + struct batadv_hard_iface *if_outgoing); +void batadv_neigh_ifinfo_release(struct kref *ref); + +int batadv_hardif_neigh_dump(struct sk_buff *msg, struct netlink_callback *cb); + +struct batadv_orig_ifinfo * +batadv_orig_ifinfo_get(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing); +struct batadv_orig_ifinfo * +batadv_orig_ifinfo_new(struct batadv_orig_node *orig_node, + struct batadv_hard_iface *if_outgoing); +void batadv_orig_ifinfo_release(struct kref *ref); + +int batadv_orig_dump(struct sk_buff *msg, struct netlink_callback *cb); +struct batadv_orig_node_vlan * +batadv_orig_node_vlan_new(struct batadv_orig_node *orig_node, + unsigned short vid); +struct batadv_orig_node_vlan * +batadv_orig_node_vlan_get(struct batadv_orig_node *orig_node, + unsigned short vid); +void batadv_orig_node_vlan_release(struct kref *ref); + +/** + * batadv_choose_orig() - Return the index of the orig entry in the hash table + * @data: mac address of the originator node + * @size: the size of the hash table + * + * Return: the hash index where the object represented by @data should be + * stored at. 
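+ *
+ * The index is computed by hashing the first ETH_ALEN bytes of @data with
+ * jhash() and reducing the result modulo @size. The same function has to be
+ * used for both insertion and lookup (batadv_orig_hash_find() relies on it),
+ * otherwise entries would end up in different buckets.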
+ */ +static inline u32 batadv_choose_orig(const void *data, u32 size) +{ + u32 hash = 0; + + hash = jhash(data, ETH_ALEN, hash); + return hash % size; +} + +struct batadv_orig_node * +batadv_orig_hash_find(struct batadv_priv *bat_priv, const void *data); + +/** + * batadv_orig_node_vlan_put() - decrement the refcounter and possibly release + * the originator-vlan object + * @orig_vlan: the originator-vlan object to release + */ +static inline void +batadv_orig_node_vlan_put(struct batadv_orig_node_vlan *orig_vlan) +{ + if (!orig_vlan) + return; + + kref_put(&orig_vlan->refcount, batadv_orig_node_vlan_release); +} + +/** + * batadv_neigh_ifinfo_put() - decrement the refcounter and possibly release + * the neigh_ifinfo + * @neigh_ifinfo: the neigh_ifinfo object to release + */ +static inline void +batadv_neigh_ifinfo_put(struct batadv_neigh_ifinfo *neigh_ifinfo) +{ + if (!neigh_ifinfo) + return; + + kref_put(&neigh_ifinfo->refcount, batadv_neigh_ifinfo_release); +} + +/** + * batadv_hardif_neigh_put() - decrement the hardif neighbors refcounter + * and possibly release it + * @hardif_neigh: hardif neigh neighbor to free + */ +static inline void +batadv_hardif_neigh_put(struct batadv_hardif_neigh_node *hardif_neigh) +{ + if (!hardif_neigh) + return; + + kref_put(&hardif_neigh->refcount, batadv_hardif_neigh_release); +} + +/** + * batadv_neigh_node_put() - decrement the neighbors refcounter and possibly + * release it + * @neigh_node: neigh neighbor to free + */ +static inline void batadv_neigh_node_put(struct batadv_neigh_node *neigh_node) +{ + if (!neigh_node) + return; + + kref_put(&neigh_node->refcount, batadv_neigh_node_release); +} + +/** + * batadv_orig_ifinfo_put() - decrement the refcounter and possibly release + * the orig_ifinfo + * @orig_ifinfo: the orig_ifinfo object to release + */ +static inline void +batadv_orig_ifinfo_put(struct batadv_orig_ifinfo *orig_ifinfo) +{ + if (!orig_ifinfo) + return; + + kref_put(&orig_ifinfo->refcount, batadv_orig_ifinfo_release); +} + +/** + * batadv_orig_node_put() - decrement the orig node refcounter and possibly + * release it + * @orig_node: the orig node to free + */ +static inline void batadv_orig_node_put(struct batadv_orig_node *orig_node) +{ + if (!orig_node) + return; + + kref_put(&orig_node->refcount, batadv_orig_node_release); +} + +#endif /* _NET_BATMAN_ADV_ORIGINATOR_H_ */ diff --git a/net/batman-adv/routing.c b/net/batman-adv/routing.c new file mode 100644 index 000000000..83f31494e --- /dev/null +++ b/net/batman-adv/routing.c @@ -0,0 +1,1273 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "routing.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bitarray.h" +#include "bridge_loop_avoidance.h" +#include "distributed-arp-table.h" +#include "fragmentation.h" +#include "hard-interface.h" +#include "log.h" +#include "network-coding.h" +#include "originator.h" +#include "send.h" +#include "soft-interface.h" +#include "tp_meter.h" +#include "translation-table.h" +#include "tvlv.h" + +static int batadv_route_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); + +/** + * _batadv_update_route() - set the router for this originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be configured + * @recv_if: the receive interface for which this route is set + * @neigh_node: neighbor which should be the next router + * + * This function does not perform any error checks + */ +static void _batadv_update_route(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if, + struct batadv_neigh_node *neigh_node) +{ + struct batadv_orig_ifinfo *orig_ifinfo; + struct batadv_neigh_node *curr_router; + + orig_ifinfo = batadv_orig_ifinfo_get(orig_node, recv_if); + if (!orig_ifinfo) + return; + + spin_lock_bh(&orig_node->neigh_list_lock); + /* curr_router used earlier may not be the current orig_ifinfo->router + * anymore because it was dereferenced outside of the neigh_list_lock + * protected region. After the new best neighbor has replace the current + * best neighbor the reference counter needs to decrease. Consequently, + * the code needs to ensure the curr_router variable contains a pointer + * to the replaced best neighbor. 
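+	 * rcu_replace_pointer() below returns exactly that replaced pointer
+	 * while neigh_list_lock is held, so the reference taken when the old
+	 * router was installed is handed over to curr_router and finally
+	 * dropped via batadv_neigh_node_put() at the end of this function.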
+ */ + + /* increase refcount of new best neighbor */ + if (neigh_node) + kref_get(&neigh_node->refcount); + + curr_router = rcu_replace_pointer(orig_ifinfo->router, neigh_node, + true); + spin_unlock_bh(&orig_node->neigh_list_lock); + batadv_orig_ifinfo_put(orig_ifinfo); + + /* route deleted */ + if (curr_router && !neigh_node) { + batadv_dbg(BATADV_DBG_ROUTES, bat_priv, + "Deleting route towards: %pM\n", orig_node->orig); + batadv_tt_global_del_orig(bat_priv, orig_node, -1, + "Deleted route towards originator"); + + /* route added */ + } else if (!curr_router && neigh_node) { + batadv_dbg(BATADV_DBG_ROUTES, bat_priv, + "Adding route towards: %pM (via %pM)\n", + orig_node->orig, neigh_node->addr); + /* route changed */ + } else if (neigh_node && curr_router) { + batadv_dbg(BATADV_DBG_ROUTES, bat_priv, + "Changing route towards: %pM (now via %pM - was via %pM)\n", + orig_node->orig, neigh_node->addr, + curr_router->addr); + } + + /* decrease refcount of previous best neighbor */ + batadv_neigh_node_put(curr_router); +} + +/** + * batadv_update_route() - set the router for this originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which is to be configured + * @recv_if: the receive interface for which this route is set + * @neigh_node: neighbor which should be the next router + */ +void batadv_update_route(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if, + struct batadv_neigh_node *neigh_node) +{ + struct batadv_neigh_node *router = NULL; + + if (!orig_node) + goto out; + + router = batadv_orig_router_get(orig_node, recv_if); + + if (router != neigh_node) + _batadv_update_route(bat_priv, orig_node, recv_if, neigh_node); + +out: + batadv_neigh_node_put(router); +} + +/** + * batadv_window_protected() - checks whether the host restarted and is in the + * protection time. + * @bat_priv: the bat priv with all the soft interface information + * @seq_num_diff: difference between the current/received sequence number and + * the last sequence number + * @seq_old_max_diff: maximum age of sequence number not considered as restart + * @last_reset: jiffies timestamp of the last reset, will be updated when reset + * is detected + * @protection_started: is set to true if the protection window was started, + * doesn't change otherwise. + * + * Return: + * false if the packet is to be accepted. + * true if the packet is to be ignored. 
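+ *
+ * Example (with 64 as an arbitrary value for @seq_old_max_diff): a
+ * seq_num_diff of -70, or one of BATADV_EXPECTED_SEQNO_RANGE or more,
+ * triggers the check; the packet is then ignored unless
+ * BATADV_RESET_PROTECTION_MS have already passed since @last_reset, in
+ * which case the window is restarted and the packet is accepted. A
+ * seq_num_diff of -10 is within both bounds and is always accepted here.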
+ */ +bool batadv_window_protected(struct batadv_priv *bat_priv, s32 seq_num_diff, + s32 seq_old_max_diff, unsigned long *last_reset, + bool *protection_started) +{ + if (seq_num_diff <= -seq_old_max_diff || + seq_num_diff >= BATADV_EXPECTED_SEQNO_RANGE) { + if (!batadv_has_timed_out(*last_reset, + BATADV_RESET_PROTECTION_MS)) + return true; + + *last_reset = jiffies; + if (protection_started) + *protection_started = true; + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "old packet received, start protection\n"); + } + + return false; +} + +/** + * batadv_check_management_packet() - Check preconditions for management packets + * @skb: incoming packet buffer + * @hard_iface: incoming hard interface + * @header_len: minimal header length of packet type + * + * Return: true when management preconditions are met, false otherwise + */ +bool batadv_check_management_packet(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface, + int header_len) +{ + struct ethhdr *ethhdr; + + /* drop packet if it has not necessary minimum size */ + if (unlikely(!pskb_may_pull(skb, header_len))) + return false; + + ethhdr = eth_hdr(skb); + + /* packet with broadcast indication but unicast recipient */ + if (!is_broadcast_ether_addr(ethhdr->h_dest)) + return false; + + /* packet with invalid sender address */ + if (!is_valid_ether_addr(ethhdr->h_source)) + return false; + + /* create a copy of the skb, if needed, to modify it. */ + if (skb_cow(skb, 0) < 0) + return false; + + /* keep skb linear */ + if (skb_linearize(skb) < 0) + return false; + + return true; +} + +/** + * batadv_recv_my_icmp_packet() - receive an icmp packet locally + * @bat_priv: the bat priv with all the soft interface information + * @skb: icmp packet to process + * + * Return: NET_RX_SUCCESS if the packet has been consumed or NET_RX_DROP + * otherwise. + */ +static int batadv_recv_my_icmp_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_hard_iface *primary_if = NULL; + struct batadv_orig_node *orig_node = NULL; + struct batadv_icmp_header *icmph; + int res, ret = NET_RX_DROP; + + icmph = (struct batadv_icmp_header *)skb->data; + + switch (icmph->msg_type) { + case BATADV_ECHO_REQUEST: + /* answer echo request (ping) */ + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* get routing information */ + orig_node = batadv_orig_hash_find(bat_priv, icmph->orig); + if (!orig_node) + goto out; + + /* create a copy of the skb, if needed, to modify it. 
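+	 * skb_cow() only reallocates when the buffer is cloned or has less
+	 * than ETH_HLEN bytes of headroom; either way icmph is re-read below
+	 * because skb->data may have moved.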
*/ + if (skb_cow(skb, ETH_HLEN) < 0) + goto out; + + icmph = (struct batadv_icmp_header *)skb->data; + + ether_addr_copy(icmph->dst, icmph->orig); + ether_addr_copy(icmph->orig, primary_if->net_dev->dev_addr); + icmph->msg_type = BATADV_ECHO_REPLY; + icmph->ttl = BATADV_TTL; + + res = batadv_send_skb_to_orig(skb, orig_node, NULL); + if (res == NET_XMIT_SUCCESS) + ret = NET_RX_SUCCESS; + + /* skb was consumed */ + skb = NULL; + break; + case BATADV_TP: + if (!pskb_may_pull(skb, sizeof(struct batadv_icmp_tp_packet))) + goto out; + + batadv_tp_meter_recv(bat_priv, skb); + ret = NET_RX_SUCCESS; + /* skb was consumed */ + skb = NULL; + goto out; + default: + /* drop unknown type */ + goto out; + } +out: + batadv_hardif_put(primary_if); + batadv_orig_node_put(orig_node); + + kfree_skb(skb); + + return ret; +} + +static int batadv_recv_icmp_ttl_exceeded(struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_hard_iface *primary_if = NULL; + struct batadv_orig_node *orig_node = NULL; + struct batadv_icmp_packet *icmp_packet; + int res, ret = NET_RX_DROP; + + icmp_packet = (struct batadv_icmp_packet *)skb->data; + + /* send TTL exceeded if packet is an echo request (traceroute) */ + if (icmp_packet->msg_type != BATADV_ECHO_REQUEST) { + pr_debug("Warning - can't forward icmp packet from %pM to %pM: ttl exceeded\n", + icmp_packet->orig, icmp_packet->dst); + goto out; + } + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* get routing information */ + orig_node = batadv_orig_hash_find(bat_priv, icmp_packet->orig); + if (!orig_node) + goto out; + + /* create a copy of the skb, if needed, to modify it. */ + if (skb_cow(skb, ETH_HLEN) < 0) + goto out; + + icmp_packet = (struct batadv_icmp_packet *)skb->data; + + ether_addr_copy(icmp_packet->dst, icmp_packet->orig); + ether_addr_copy(icmp_packet->orig, primary_if->net_dev->dev_addr); + icmp_packet->msg_type = BATADV_TTL_EXCEEDED; + icmp_packet->ttl = BATADV_TTL; + + res = batadv_send_skb_to_orig(skb, orig_node, NULL); + if (res == NET_RX_SUCCESS) + ret = NET_XMIT_SUCCESS; + + /* skb was consumed */ + skb = NULL; + +out: + batadv_hardif_put(primary_if); + batadv_orig_node_put(orig_node); + + kfree_skb(skb); + + return ret; +} + +/** + * batadv_recv_icmp_packet() - Process incoming icmp packet + * @skb: incoming packet buffer + * @recv_if: incoming hard interface + * + * Return: NET_RX_SUCCESS on success or NET_RX_DROP in case of failure + */ +int batadv_recv_icmp_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_icmp_header *icmph; + struct batadv_icmp_packet_rr *icmp_packet_rr; + struct ethhdr *ethhdr; + struct batadv_orig_node *orig_node = NULL; + int hdr_size = sizeof(struct batadv_icmp_header); + int res, ret = NET_RX_DROP; + + /* drop packet if it has not necessary minimum size */ + if (unlikely(!pskb_may_pull(skb, hdr_size))) + goto free_skb; + + ethhdr = eth_hdr(skb); + + /* packet with unicast indication but non-unicast recipient */ + if (!is_valid_ether_addr(ethhdr->h_dest)) + goto free_skb; + + /* packet with broadcast/multicast sender address */ + if (is_multicast_ether_addr(ethhdr->h_source)) + goto free_skb; + + /* not for me */ + if (!batadv_is_my_mac(bat_priv, ethhdr->h_dest)) + goto free_skb; + + icmph = (struct batadv_icmp_header *)skb->data; + + /* add record route information if not full */ + if ((icmph->msg_type == BATADV_ECHO_REPLY || + icmph->msg_type == BATADV_ECHO_REQUEST) 
&& + skb->len >= sizeof(struct batadv_icmp_packet_rr)) { + if (skb_linearize(skb) < 0) + goto free_skb; + + /* create a copy of the skb, if needed, to modify it. */ + if (skb_cow(skb, ETH_HLEN) < 0) + goto free_skb; + + ethhdr = eth_hdr(skb); + icmph = (struct batadv_icmp_header *)skb->data; + icmp_packet_rr = (struct batadv_icmp_packet_rr *)icmph; + if (icmp_packet_rr->rr_cur >= BATADV_RR_LEN) + goto free_skb; + + ether_addr_copy(icmp_packet_rr->rr[icmp_packet_rr->rr_cur], + ethhdr->h_dest); + icmp_packet_rr->rr_cur++; + } + + /* packet for me */ + if (batadv_is_my_mac(bat_priv, icmph->dst)) + return batadv_recv_my_icmp_packet(bat_priv, skb); + + /* TTL exceeded */ + if (icmph->ttl < 2) + return batadv_recv_icmp_ttl_exceeded(bat_priv, skb); + + /* get routing information */ + orig_node = batadv_orig_hash_find(bat_priv, icmph->dst); + if (!orig_node) + goto free_skb; + + /* create a copy of the skb, if needed, to modify it. */ + if (skb_cow(skb, ETH_HLEN) < 0) + goto put_orig_node; + + icmph = (struct batadv_icmp_header *)skb->data; + + /* decrement ttl */ + icmph->ttl--; + + /* route it */ + res = batadv_send_skb_to_orig(skb, orig_node, recv_if); + if (res == NET_XMIT_SUCCESS) + ret = NET_RX_SUCCESS; + + /* skb was consumed */ + skb = NULL; + +put_orig_node: + batadv_orig_node_put(orig_node); +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_check_unicast_packet() - Check for malformed unicast packets + * @bat_priv: the bat priv with all the soft interface information + * @skb: packet to check + * @hdr_size: size of header to pull + * + * Checks for short header and bad addresses in the given packet. + * + * Return: negative value when check fails and 0 otherwise. The negative value + * depends on the reason: -ENODATA for bad header, -EBADR for broadcast + * destination or source, and -EREMOTE for non-local (other host) destination. + */ +static int batadv_check_unicast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_size) +{ + struct ethhdr *ethhdr; + + /* drop packet if it has not necessary minimum size */ + if (unlikely(!pskb_may_pull(skb, hdr_size))) + return -ENODATA; + + ethhdr = eth_hdr(skb); + + /* packet with unicast indication but non-unicast recipient */ + if (!is_valid_ether_addr(ethhdr->h_dest)) + return -EBADR; + + /* packet with broadcast/multicast sender address */ + if (is_multicast_ether_addr(ethhdr->h_source)) + return -EBADR; + + /* not for me */ + if (!batadv_is_my_mac(bat_priv, ethhdr->h_dest)) + return -EREMOTE; + + return 0; +} + +/** + * batadv_last_bonding_get() - Get last_bonding_candidate of orig_node + * @orig_node: originator node whose last bonding candidate should be retrieved + * + * Return: last bonding candidate of router or NULL if not found + * + * The object is returned with refcounter increased by 1. 
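+ *
+ * The reference has to be dropped again with batadv_orig_ifinfo_put(), which
+ * also accepts a NULL result; batadv_find_router() does this after its
+ * candidate loop.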
+ */ +static struct batadv_orig_ifinfo * +batadv_last_bonding_get(struct batadv_orig_node *orig_node) +{ + struct batadv_orig_ifinfo *last_bonding_candidate; + + spin_lock_bh(&orig_node->neigh_list_lock); + last_bonding_candidate = orig_node->last_bonding_candidate; + + if (last_bonding_candidate) + kref_get(&last_bonding_candidate->refcount); + spin_unlock_bh(&orig_node->neigh_list_lock); + + return last_bonding_candidate; +} + +/** + * batadv_last_bonding_replace() - Replace last_bonding_candidate of orig_node + * @orig_node: originator node whose bonding candidates should be replaced + * @new_candidate: new bonding candidate or NULL + */ +static void +batadv_last_bonding_replace(struct batadv_orig_node *orig_node, + struct batadv_orig_ifinfo *new_candidate) +{ + struct batadv_orig_ifinfo *old_candidate; + + spin_lock_bh(&orig_node->neigh_list_lock); + old_candidate = orig_node->last_bonding_candidate; + + if (new_candidate) + kref_get(&new_candidate->refcount); + orig_node->last_bonding_candidate = new_candidate; + spin_unlock_bh(&orig_node->neigh_list_lock); + + batadv_orig_ifinfo_put(old_candidate); +} + +/** + * batadv_find_router() - find a suitable router for this originator + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the destination node + * @recv_if: pointer to interface this packet was received on + * + * Return: the router which should be used for this orig_node on + * this interface, or NULL if not available. + */ +struct batadv_neigh_node * +batadv_find_router(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if) +{ + struct batadv_algo_ops *bao = bat_priv->algo_ops; + struct batadv_neigh_node *first_candidate_router = NULL; + struct batadv_neigh_node *next_candidate_router = NULL; + struct batadv_neigh_node *router, *cand_router = NULL; + struct batadv_neigh_node *last_cand_router = NULL; + struct batadv_orig_ifinfo *cand, *first_candidate = NULL; + struct batadv_orig_ifinfo *next_candidate = NULL; + struct batadv_orig_ifinfo *last_candidate; + bool last_candidate_found = false; + + if (!orig_node) + return NULL; + + router = batadv_orig_router_get(orig_node, recv_if); + + if (!router) + return router; + + /* only consider bonding for recv_if == BATADV_IF_DEFAULT (first hop) + * and if activated. + */ + if (!(recv_if == BATADV_IF_DEFAULT && atomic_read(&bat_priv->bonding))) + return router; + + /* bonding: loop through the list of possible routers found + * for the various outgoing interfaces and find a candidate after + * the last chosen bonding candidate (next_candidate). If no such + * router is found, use the first candidate found (the previously + * chosen bonding candidate might have been the last one in the list). + * If this can't be found either, return the previously chosen + * router - obviously there are no other candidates. + */ + rcu_read_lock(); + last_candidate = batadv_last_bonding_get(orig_node); + if (last_candidate) + last_cand_router = rcu_dereference(last_candidate->router); + + hlist_for_each_entry_rcu(cand, &orig_node->ifinfo_list, list) { + /* acquire some structures and references ... 
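+	 * The candidate ifinfo entry and, if present, its router are pinned
+	 * with kref_get_unless_zero(). Both are dropped again at the "next:"
+	 * label; when the loop breaks with a selected candidate the
+	 * references are kept and released in the cleanup block at the end
+	 * of this function.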
*/ + if (!kref_get_unless_zero(&cand->refcount)) + continue; + + cand_router = rcu_dereference(cand->router); + if (!cand_router) + goto next; + + if (!kref_get_unless_zero(&cand_router->refcount)) { + cand_router = NULL; + goto next; + } + + /* alternative candidate should be good enough to be + * considered + */ + if (!bao->neigh.is_similar_or_better(cand_router, + cand->if_outgoing, router, + recv_if)) + goto next; + + /* don't use the same router twice */ + if (last_cand_router == cand_router) + goto next; + + /* mark the first possible candidate */ + if (!first_candidate) { + kref_get(&cand_router->refcount); + kref_get(&cand->refcount); + first_candidate = cand; + first_candidate_router = cand_router; + } + + /* check if the loop has already passed the previously selected + * candidate ... this function should select the next candidate + * AFTER the previously used bonding candidate. + */ + if (!last_candidate || last_candidate_found) { + next_candidate = cand; + next_candidate_router = cand_router; + break; + } + + if (last_candidate == cand) + last_candidate_found = true; +next: + /* free references */ + if (cand_router) { + batadv_neigh_node_put(cand_router); + cand_router = NULL; + } + batadv_orig_ifinfo_put(cand); + } + rcu_read_unlock(); + + /* After finding candidates, handle the three cases: + * 1) there is a next candidate, use that + * 2) there is no next candidate, use the first of the list + * 3) there is no candidate at all, return the default router + */ + if (next_candidate) { + batadv_neigh_node_put(router); + + kref_get(&next_candidate_router->refcount); + router = next_candidate_router; + batadv_last_bonding_replace(orig_node, next_candidate); + } else if (first_candidate) { + batadv_neigh_node_put(router); + + kref_get(&first_candidate_router->refcount); + router = first_candidate_router; + batadv_last_bonding_replace(orig_node, first_candidate); + } else { + batadv_last_bonding_replace(orig_node, NULL); + } + + /* cleanup of candidates */ + if (first_candidate) { + batadv_neigh_node_put(first_candidate_router); + batadv_orig_ifinfo_put(first_candidate); + } + + if (next_candidate) { + batadv_neigh_node_put(next_candidate_router); + batadv_orig_ifinfo_put(next_candidate); + } + + batadv_orig_ifinfo_put(last_candidate); + + return router; +} + +static int batadv_route_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_orig_node *orig_node = NULL; + struct batadv_unicast_packet *unicast_packet; + struct ethhdr *ethhdr = eth_hdr(skb); + int res, hdr_len, ret = NET_RX_DROP; + unsigned int len; + + unicast_packet = (struct batadv_unicast_packet *)skb->data; + + /* TTL exceeded */ + if (unicast_packet->ttl < 2) { + pr_debug("Warning - can't forward unicast packet from %pM to %pM: ttl exceeded\n", + ethhdr->h_source, unicast_packet->dest); + goto free_skb; + } + + /* get routing information */ + orig_node = batadv_orig_hash_find(bat_priv, unicast_packet->dest); + + if (!orig_node) + goto free_skb; + + /* create a copy of the skb, if needed, to modify it. 
*/ + if (skb_cow(skb, ETH_HLEN) < 0) + goto put_orig_node; + + /* decrement ttl */ + unicast_packet = (struct batadv_unicast_packet *)skb->data; + unicast_packet->ttl--; + + switch (unicast_packet->packet_type) { + case BATADV_UNICAST_4ADDR: + hdr_len = sizeof(struct batadv_unicast_4addr_packet); + break; + case BATADV_UNICAST: + hdr_len = sizeof(struct batadv_unicast_packet); + break; + default: + /* other packet types not supported - yet */ + hdr_len = -1; + break; + } + + if (hdr_len > 0) + batadv_skb_set_priority(skb, hdr_len); + + len = skb->len; + res = batadv_send_skb_to_orig(skb, orig_node, recv_if); + + /* translate transmit result into receive result */ + if (res == NET_XMIT_SUCCESS) { + ret = NET_RX_SUCCESS; + /* skb was transmitted and consumed */ + batadv_inc_counter(bat_priv, BATADV_CNT_FORWARD); + batadv_add_counter(bat_priv, BATADV_CNT_FORWARD_BYTES, + len + ETH_HLEN); + } + + /* skb was consumed */ + skb = NULL; + +put_orig_node: + batadv_orig_node_put(orig_node); +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_reroute_unicast_packet() - update the unicast header for re-routing + * @bat_priv: the bat priv with all the soft interface information + * @skb: unicast packet to process + * @unicast_packet: the unicast header to be updated + * @dst_addr: the payload destination + * @vid: VLAN identifier + * + * Search the translation table for dst_addr and update the unicast header with + * the new corresponding information (originator address where the destination + * client currently is and its known TTVN) + * + * Return: true if the packet header has been updated, false otherwise + */ +static bool +batadv_reroute_unicast_packet(struct batadv_priv *bat_priv, struct sk_buff *skb, + struct batadv_unicast_packet *unicast_packet, + u8 *dst_addr, unsigned short vid) +{ + struct batadv_orig_node *orig_node = NULL; + struct batadv_hard_iface *primary_if = NULL; + bool ret = false; + const u8 *orig_addr; + u8 orig_ttvn; + + if (batadv_is_my_client(bat_priv, dst_addr, vid)) { + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + orig_addr = primary_if->net_dev->dev_addr; + orig_ttvn = (u8)atomic_read(&bat_priv->tt.vn); + } else { + orig_node = batadv_transtable_search(bat_priv, NULL, dst_addr, + vid); + if (!orig_node) + goto out; + + if (batadv_compare_eth(orig_node->orig, unicast_packet->dest)) + goto out; + + orig_addr = orig_node->orig; + orig_ttvn = (u8)atomic_read(&orig_node->last_ttvn); + } + + /* update the packet header */ + skb_postpull_rcsum(skb, unicast_packet, sizeof(*unicast_packet)); + ether_addr_copy(unicast_packet->dest, orig_addr); + unicast_packet->ttvn = orig_ttvn; + skb_postpush_rcsum(skb, unicast_packet, sizeof(*unicast_packet)); + + ret = true; +out: + batadv_hardif_put(primary_if); + batadv_orig_node_put(orig_node); + + return ret; +} + +static bool batadv_check_unicast_ttvn(struct batadv_priv *bat_priv, + struct sk_buff *skb, int hdr_len) +{ + struct batadv_unicast_packet *unicast_packet; + struct batadv_hard_iface *primary_if; + struct batadv_orig_node *orig_node; + u8 curr_ttvn, old_ttvn; + struct ethhdr *ethhdr; + unsigned short vid; + int is_old_ttvn; + + /* check if there is enough data before accessing it */ + if (!pskb_may_pull(skb, hdr_len + ETH_HLEN)) + return false; + + /* create a copy of the skb (in case of re-routing) to modify it.
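The TTL handling in the forwarding path above follows the usual rule: a unicast frame whose ttl is already below 2 is not forwarded, otherwise the ttl is decremented before the frame goes out again. A small standalone sketch of that rule (illustrative only, not kernel code):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Drop a frame that would expire on the next hop, otherwise decrement
 * its ttl before forwarding it.
 */
static bool may_forward(uint8_t *ttl)
{
	if (*ttl < 2)
		return false;	/* would expire on the next hop: drop */

	(*ttl)--;
	return true;
}

int main(void)
{
	uint8_t ttl = 3;

	while (may_forward(&ttl))
		printf("forwarded, remaining ttl=%u\n", (unsigned int)ttl);
	printf("dropped with ttl=%u\n", (unsigned int)ttl);
	return 0;
}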
*/ + if (skb_cow(skb, sizeof(*unicast_packet)) < 0) + return false; + + unicast_packet = (struct batadv_unicast_packet *)skb->data; + vid = batadv_get_vid(skb, hdr_len); + ethhdr = (struct ethhdr *)(skb->data + hdr_len); + + /* do not reroute multicast frames in a unicast header */ + if (is_multicast_ether_addr(ethhdr->h_dest)) + return true; + + /* check if the destination client was served by this node and it is now + * roaming. In this case, it means that the node has got a ROAM_ADV + * message and that it knows the new destination in the mesh to re-route + * the packet to + */ + if (batadv_tt_local_client_is_roaming(bat_priv, ethhdr->h_dest, vid)) { + if (batadv_reroute_unicast_packet(bat_priv, skb, unicast_packet, + ethhdr->h_dest, vid)) + batadv_dbg_ratelimited(BATADV_DBG_TT, + bat_priv, + "Rerouting unicast packet to %pM (dst=%pM): Local Roaming\n", + unicast_packet->dest, + ethhdr->h_dest); + /* at this point the mesh destination should have been + * substituted with the originator address found in the global + * table. If not, let the packet go untouched anyway because + * there is nothing the node can do + */ + return true; + } + + /* retrieve the TTVN known by this node for the packet destination. This + * value is used later to check if the node which sent (or re-routed + * last time) the packet had an updated information or not + */ + curr_ttvn = (u8)atomic_read(&bat_priv->tt.vn); + if (!batadv_is_my_mac(bat_priv, unicast_packet->dest)) { + orig_node = batadv_orig_hash_find(bat_priv, + unicast_packet->dest); + /* if it is not possible to find the orig_node representing the + * destination, the packet can immediately be dropped as it will + * not be possible to deliver it + */ + if (!orig_node) + return false; + + curr_ttvn = (u8)atomic_read(&orig_node->last_ttvn); + batadv_orig_node_put(orig_node); + } + + /* check if the TTVN contained in the packet is fresher than what the + * node knows + */ + is_old_ttvn = batadv_seq_before(unicast_packet->ttvn, curr_ttvn); + if (!is_old_ttvn) + return true; + + old_ttvn = unicast_packet->ttvn; + /* the packet was forged based on outdated network information. 
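The freshness test above relies on wrap-around safe arithmetic: an 8-bit TTVN difference is interpreted as a signed value, so a packet stamped with version 250 still counts as older than version 2 after the counter wrapped. A sketch in the spirit of that seq_before() comparison (not the batman-adv macro itself):

#include <stdint.h>
#include <stdio.h>

/* Wrap-around aware "is x older than y" check for 8-bit version numbers:
 * the difference is reinterpreted as a signed byte.
 */
static int ttvn_is_older(uint8_t x, uint8_t y)
{
	return (int8_t)(uint8_t)(x - y) < 0;
}

int main(void)
{
	printf("%d\n", ttvn_is_older(4, 5));	/* 1: one step behind */
	printf("%d\n", ttvn_is_older(5, 5));	/* 0: up to date */
	printf("%d\n", ttvn_is_older(250, 2));	/* 1: older across the wrap */
	printf("%d\n", ttvn_is_older(2, 250));	/* 0: newer across the wrap */
	return 0;
}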
Its + * destination can possibly be updated and forwarded towards the new + * target host + */ + if (batadv_reroute_unicast_packet(bat_priv, skb, unicast_packet, + ethhdr->h_dest, vid)) { + batadv_dbg_ratelimited(BATADV_DBG_TT, bat_priv, + "Rerouting unicast packet to %pM (dst=%pM): TTVN mismatch old_ttvn=%u new_ttvn=%u\n", + unicast_packet->dest, ethhdr->h_dest, + old_ttvn, curr_ttvn); + return true; + } + + /* the packet has not been re-routed: either the destination is + * currently served by this node or there is no destination at all and + * it is possible to drop the packet + */ + if (!batadv_is_my_client(bat_priv, ethhdr->h_dest, vid)) + return false; + + /* update the header in order to let the packet be delivered to this + * node's soft interface + */ + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return false; + + /* update the packet header */ + skb_postpull_rcsum(skb, unicast_packet, sizeof(*unicast_packet)); + ether_addr_copy(unicast_packet->dest, primary_if->net_dev->dev_addr); + unicast_packet->ttvn = curr_ttvn; + skb_postpush_rcsum(skb, unicast_packet, sizeof(*unicast_packet)); + + batadv_hardif_put(primary_if); + + return true; +} + +/** + * batadv_recv_unhandled_unicast_packet() - receive and process packets which + * are in the unicast number space but not yet known to the implementation + * @skb: unicast tvlv packet to process + * @recv_if: pointer to interface this packet was received on + * + * Return: NET_RX_SUCCESS if the packet has been consumed or NET_RX_DROP + * otherwise. + */ +int batadv_recv_unhandled_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_unicast_packet *unicast_packet; + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + int check, hdr_size = sizeof(*unicast_packet); + + check = batadv_check_unicast_packet(bat_priv, skb, hdr_size); + if (check < 0) + goto free_skb; + + /* we don't know about this type, drop it. 
*/ + unicast_packet = (struct batadv_unicast_packet *)skb->data; + if (batadv_is_my_mac(bat_priv, unicast_packet->dest)) + goto free_skb; + + return batadv_route_unicast_packet(skb, recv_if); + +free_skb: + kfree_skb(skb); + return NET_RX_DROP; +} + +/** + * batadv_recv_unicast_packet() - Process incoming unicast packet + * @skb: incoming packet buffer + * @recv_if: incoming hard interface + * + * Return: NET_RX_SUCCESS on success or NET_RX_DROP in case of failure + */ +int batadv_recv_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_unicast_packet *unicast_packet; + struct batadv_unicast_4addr_packet *unicast_4addr_packet; + u8 *orig_addr, *orig_addr_gw; + struct batadv_orig_node *orig_node = NULL, *orig_node_gw = NULL; + int check, hdr_size = sizeof(*unicast_packet); + enum batadv_subtype subtype; + int ret = NET_RX_DROP; + bool is4addr, is_gw; + + unicast_packet = (struct batadv_unicast_packet *)skb->data; + is4addr = unicast_packet->packet_type == BATADV_UNICAST_4ADDR; + /* the caller function should have already pulled 2 bytes */ + if (is4addr) + hdr_size = sizeof(*unicast_4addr_packet); + + /* function returns -EREMOTE for promiscuous packets */ + check = batadv_check_unicast_packet(bat_priv, skb, hdr_size); + + /* Even though the packet is not for us, we might save it to use for + * decoding a later received coded packet + */ + if (check == -EREMOTE) + batadv_nc_skb_store_sniffed_unicast(bat_priv, skb); + + if (check < 0) + goto free_skb; + if (!batadv_check_unicast_ttvn(bat_priv, skb, hdr_size)) + goto free_skb; + + unicast_packet = (struct batadv_unicast_packet *)skb->data; + + /* packet for me */ + if (batadv_is_my_mac(bat_priv, unicast_packet->dest)) { + /* If this is a unicast packet from another backbone gw, + * drop it. + */ + orig_addr_gw = eth_hdr(skb)->h_source; + orig_node_gw = batadv_orig_hash_find(bat_priv, orig_addr_gw); + if (orig_node_gw) { + is_gw = batadv_bla_is_backbone_gw(skb, orig_node_gw, + hdr_size); + batadv_orig_node_put(orig_node_gw); + if (is_gw) { + batadv_dbg(BATADV_DBG_BLA, bat_priv, + "%s(): Dropped unicast pkt received from another backbone gw %pM.\n", + __func__, orig_addr_gw); + goto free_skb; + } + } + + if (is4addr) { + unicast_4addr_packet = + (struct batadv_unicast_4addr_packet *)skb->data; + subtype = unicast_4addr_packet->subtype; + batadv_dat_inc_counter(bat_priv, subtype); + + /* Only payload data should be considered for speedy + * join. For example, DAT also uses unicast 4addr + * types, but those packets should not be considered + * for speedy join, since the clients do not actually + * reside at the sending originator.
+ */ + if (subtype == BATADV_P_DATA) { + orig_addr = unicast_4addr_packet->src; + orig_node = batadv_orig_hash_find(bat_priv, + orig_addr); + } + } + + if (batadv_dat_snoop_incoming_arp_request(bat_priv, skb, + hdr_size)) + goto rx_success; + if (batadv_dat_snoop_incoming_arp_reply(bat_priv, skb, + hdr_size)) + goto rx_success; + + batadv_dat_snoop_incoming_dhcp_ack(bat_priv, skb, hdr_size); + + batadv_interface_rx(recv_if->soft_iface, skb, hdr_size, + orig_node); + +rx_success: + batadv_orig_node_put(orig_node); + + return NET_RX_SUCCESS; + } + + ret = batadv_route_unicast_packet(skb, recv_if); + /* skb was consumed */ + skb = NULL; + +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_recv_unicast_tvlv() - receive and process unicast tvlv packets + * @skb: unicast tvlv packet to process + * @recv_if: pointer to interface this packet was received on + * + * Return: NET_RX_SUCCESS if the packet has been consumed or NET_RX_DROP + * otherwise. + */ +int batadv_recv_unicast_tvlv(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_unicast_tvlv_packet *unicast_tvlv_packet; + unsigned char *tvlv_buff; + u16 tvlv_buff_len; + int hdr_size = sizeof(*unicast_tvlv_packet); + int ret = NET_RX_DROP; + + if (batadv_check_unicast_packet(bat_priv, skb, hdr_size) < 0) + goto free_skb; + + /* the header is likely to be modified while forwarding */ + if (skb_cow(skb, hdr_size) < 0) + goto free_skb; + + /* packet needs to be linearized to access the tvlv content */ + if (skb_linearize(skb) < 0) + goto free_skb; + + unicast_tvlv_packet = (struct batadv_unicast_tvlv_packet *)skb->data; + + tvlv_buff = (unsigned char *)(skb->data + hdr_size); + tvlv_buff_len = ntohs(unicast_tvlv_packet->tvlv_len); + + if (tvlv_buff_len > skb->len - hdr_size) + goto free_skb; + + ret = batadv_tvlv_containers_process(bat_priv, false, NULL, + unicast_tvlv_packet->src, + unicast_tvlv_packet->dst, + tvlv_buff, tvlv_buff_len); + + if (ret != NET_RX_SUCCESS) { + ret = batadv_route_unicast_packet(skb, recv_if); + /* skb was consumed */ + skb = NULL; + } + +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_recv_frag_packet() - process received fragment + * @skb: the received fragment + * @recv_if: interface that the skb is received on + * + * This function does one of the three following things: 1) Forward fragment, if + * the assembled packet will exceed our MTU; 2) Buffer fragment, if we still + * lack further fragments; 3) Merge fragments, if we have all needed parts. + * + * Return: NET_RX_DROP if the skb is not consumed, NET_RX_SUCCESS otherwise. + */ +int batadv_recv_frag_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_orig_node *orig_node_src = NULL; + struct batadv_frag_packet *frag_packet; + int ret = NET_RX_DROP; + + if (batadv_check_unicast_packet(bat_priv, skb, + sizeof(*frag_packet)) < 0) + goto free_skb; + + frag_packet = (struct batadv_frag_packet *)skb->data; + orig_node_src = batadv_orig_hash_find(bat_priv, frag_packet->orig); + if (!orig_node_src) + goto free_skb; + + skb->priority = frag_packet->priority + 256; + + /* Route the fragment if it is not for us and too big to be merged. 
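The fragment handler described above either forwards, buffers or merges; the buffer-and-merge step amounts to "remember each part until all of them arrived, then concatenate". A toy version with two fixed-size fragments and none of the real per-originator chains, timeouts or size checks (all names are made up for illustration):

#include <stdio.h>
#include <string.h>

#define NUM_FRAGS 2
#define FRAG_LEN  8

struct frag_buf {
	char data[NUM_FRAGS][FRAG_LEN + 1];
	int seen[NUM_FRAGS];
};

/* Store fragment 'no'; return 1 and write the merged payload once every
 * fragment has arrived, 0 while parts are still missing.
 */
static int frag_add(struct frag_buf *buf, int no, const char *payload,
		    char *merged, size_t merged_len)
{
	int i;

	snprintf(buf->data[no], sizeof(buf->data[no]), "%s", payload);
	buf->seen[no] = 1;

	for (i = 0; i < NUM_FRAGS; i++)
		if (!buf->seen[i])
			return 0;	/* still waiting for other parts */

	merged[0] = '\0';
	for (i = 0; i < NUM_FRAGS; i++)
		strncat(merged, buf->data[i], merged_len - strlen(merged) - 1);
	return 1;
}

int main(void)
{
	struct frag_buf buf = { 0 };
	char merged[NUM_FRAGS * FRAG_LEN + 1];

	if (!frag_add(&buf, 1, "world", merged, sizeof(merged)))
		printf("fragment 1 buffered\n");
	if (frag_add(&buf, 0, "hello ", merged, sizeof(merged)))
		printf("merged: %s\n", merged);
	return 0;
}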
*/ + if (!batadv_is_my_mac(bat_priv, frag_packet->dest) && + batadv_frag_skb_fwd(skb, recv_if, orig_node_src)) { + /* skb was consumed */ + skb = NULL; + ret = NET_RX_SUCCESS; + goto put_orig_node; + } + + batadv_inc_counter(bat_priv, BATADV_CNT_FRAG_RX); + batadv_add_counter(bat_priv, BATADV_CNT_FRAG_RX_BYTES, skb->len); + + /* Add fragment to buffer and merge if possible. */ + if (!batadv_frag_skb_buffer(&skb, orig_node_src)) + goto put_orig_node; + + /* Deliver merged packet to the appropriate handler, if it was + * merged + */ + if (skb) { + batadv_batman_skb_recv(skb, recv_if->net_dev, + &recv_if->batman_adv_ptype, NULL); + /* skb was consumed */ + skb = NULL; + } + + ret = NET_RX_SUCCESS; + +put_orig_node: + batadv_orig_node_put(orig_node_src); +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_recv_bcast_packet() - Process incoming broadcast packet + * @skb: incoming packet buffer + * @recv_if: incoming hard interface + * + * Return: NET_RX_SUCCESS on success or NET_RX_DROP in case of failure + */ +int batadv_recv_bcast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = netdev_priv(recv_if->soft_iface); + struct batadv_orig_node *orig_node = NULL; + struct batadv_bcast_packet *bcast_packet; + struct ethhdr *ethhdr; + int hdr_size = sizeof(*bcast_packet); + s32 seq_diff; + u32 seqno; + int ret; + + /* drop packet if it does not have the necessary minimum size */ + if (unlikely(!pskb_may_pull(skb, hdr_size))) + goto free_skb; + + ethhdr = eth_hdr(skb); + + /* packet with broadcast indication but unicast recipient */ + if (!is_broadcast_ether_addr(ethhdr->h_dest)) + goto free_skb; + + /* packet with broadcast/multicast sender address */ + if (is_multicast_ether_addr(ethhdr->h_source)) + goto free_skb; + + /* ignore broadcasts sent by myself */ + if (batadv_is_my_mac(bat_priv, ethhdr->h_source)) + goto free_skb; + + bcast_packet = (struct batadv_bcast_packet *)skb->data; + + /* ignore broadcasts originated by myself */ + if (batadv_is_my_mac(bat_priv, bcast_packet->orig)) + goto free_skb; + + if (bcast_packet->ttl-- < 2) + goto free_skb; + + orig_node = batadv_orig_hash_find(bat_priv, bcast_packet->orig); + + if (!orig_node) + goto free_skb; + + spin_lock_bh(&orig_node->bcast_seqno_lock); + + seqno = ntohl(bcast_packet->seqno); + /* check whether the packet is a duplicate */ + if (batadv_test_bit(orig_node->bcast_bits, orig_node->last_bcast_seqno, + seqno)) + goto spin_unlock; + + seq_diff = seqno - orig_node->last_bcast_seqno; + + /* check whether the packet is old and the host just restarted. */ + if (batadv_window_protected(bat_priv, seq_diff, + BATADV_BCAST_MAX_AGE, + &orig_node->bcast_seqno_reset, NULL)) + goto spin_unlock; + + /* mark broadcast in flood history, update window position + * if required. + */ + if (batadv_bit_get_packet(bat_priv, orig_node->bcast_bits, seq_diff, 1)) + orig_node->last_bcast_seqno = seqno; + + spin_unlock_bh(&orig_node->bcast_seqno_lock); + + /* check whether this has been sent by another originator before */ + if (batadv_bla_check_bcast_duplist(bat_priv, skb)) + goto free_skb; + + batadv_skb_set_priority(skb, sizeof(struct batadv_bcast_packet)); + + /* rebroadcast packet */ + ret = batadv_forw_bcast_packet(bat_priv, skb, 0, false); + if (ret == NETDEV_TX_BUSY) + goto free_skb; + + /* don't hand the broadcast up if it is from an originator + * from the same backbone.
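The duplicate check above keeps a bitmap of recently seen broadcast sequence numbers relative to the newest one. A compact userspace sketch of such a sliding window, assuming a 64-entry history and leaving out the restart protection (window_protected) of the real code; all names are illustrative:

#include <stdint.h>
#include <stdio.h>

#define WIN 64

struct bcast_win {
	uint64_t bits;		/* bit i set: (last_seqno - i) already seen */
	uint32_t last_seqno;
	int valid;
};

/* Return 1 if seqno was seen before (or is too old to track), 0 if new. */
static int bcast_is_dup(struct bcast_win *w, uint32_t seqno)
{
	int32_t diff;

	if (!w->valid) {
		w->valid = 1;
		w->last_seqno = seqno;
		w->bits = 1;
		return 0;
	}

	diff = (int32_t)(seqno - w->last_seqno);
	if (diff > 0) {				/* newer: shift the window */
		w->bits = (diff >= WIN) ? 0 : w->bits << diff;
		w->bits |= 1;
		w->last_seqno = seqno;
		return 0;
	}
	if (-diff >= WIN)			/* too old to track: drop */
		return 1;
	if (w->bits & (1ULL << -diff))		/* seen before */
		return 1;
	w->bits |= 1ULL << -diff;		/* late but new */
	return 0;
}

int main(void)
{
	struct bcast_win w = { 0 };
	uint32_t seq[] = { 10, 11, 11, 9, 9, 12 };
	int i;

	for (i = 0; i < 6; i++)
		printf("seqno %u -> %s\n", (unsigned int)seq[i],
		       bcast_is_dup(&w, seq[i]) ? "duplicate" : "new");
	return 0;
}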
+ */ + if (batadv_bla_is_backbone_gw(skb, orig_node, hdr_size)) + goto free_skb; + + if (batadv_dat_snoop_incoming_arp_request(bat_priv, skb, hdr_size)) + goto rx_success; + if (batadv_dat_snoop_incoming_arp_reply(bat_priv, skb, hdr_size)) + goto rx_success; + + batadv_dat_snoop_incoming_dhcp_ack(bat_priv, skb, hdr_size); + + /* broadcast for me */ + batadv_interface_rx(recv_if->soft_iface, skb, hdr_size, orig_node); + +rx_success: + ret = NET_RX_SUCCESS; + goto out; + +spin_unlock: + spin_unlock_bh(&orig_node->bcast_seqno_lock); +free_skb: + kfree_skb(skb); + ret = NET_RX_DROP; +out: + batadv_orig_node_put(orig_node); + return ret; +} diff --git a/net/batman-adv/routing.h b/net/batman-adv/routing.h new file mode 100644 index 000000000..5f387786e --- /dev/null +++ b/net/batman-adv/routing.h @@ -0,0 +1,46 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_ROUTING_H_ +#define _NET_BATMAN_ADV_ROUTING_H_ + +#include "main.h" + +#include +#include + +bool batadv_check_management_packet(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface, + int header_len); +void batadv_update_route(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if, + struct batadv_neigh_node *neigh_node); +int batadv_recv_icmp_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_frag_packet(struct sk_buff *skb, + struct batadv_hard_iface *iface); +int batadv_recv_bcast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_tt_query(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_roam_adv(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_unicast_tvlv(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +int batadv_recv_unhandled_unicast_packet(struct sk_buff *skb, + struct batadv_hard_iface *recv_if); +struct batadv_neigh_node * +batadv_find_router(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if); +bool batadv_window_protected(struct batadv_priv *bat_priv, s32 seq_num_diff, + s32 seq_old_max_diff, unsigned long *last_reset, + bool *protection_started); + +#endif /* _NET_BATMAN_ADV_ROUTING_H_ */ diff --git a/net/batman-adv/send.c b/net/batman-adv/send.c new file mode 100644 index 000000000..0379b1268 --- /dev/null +++ b/net/batman-adv/send.c @@ -0,0 +1,1135 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "send.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "distributed-arp-table.h" +#include "fragmentation.h" +#include "gateway_client.h" +#include "hard-interface.h" +#include "log.h" +#include "network-coding.h" +#include "originator.h" +#include "routing.h" +#include "soft-interface.h" +#include "translation-table.h" + +static void batadv_send_outstanding_bcast_packet(struct work_struct *work); + +/** + * batadv_send_skb_packet() - send an already prepared packet + * @skb: the packet to send + * @hard_iface: the interface to use to send the broadcast packet + * @dst_addr: the payload destination + * + * Send out an already prepared packet to the given neighbor or broadcast it + * using the specified interface. Either hard_iface or neigh_node must be not + * NULL. + * If neigh_node is NULL, then the packet is broadcasted using hard_iface, + * otherwise it is sent as unicast to the given neighbor. + * + * Regardless of the return value, the skb is consumed. + * + * Return: A negative errno code is returned on a failure. A success does not + * guarantee the frame will be transmitted as it may be dropped due + * to congestion or traffic shaping. + */ +int batadv_send_skb_packet(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface, + const u8 *dst_addr) +{ + struct batadv_priv *bat_priv; + struct ethhdr *ethhdr; + int ret; + + bat_priv = netdev_priv(hard_iface->soft_iface); + + if (hard_iface->if_status != BATADV_IF_ACTIVE) + goto send_skb_err; + + if (unlikely(!hard_iface->net_dev)) + goto send_skb_err; + + if (!(hard_iface->net_dev->flags & IFF_UP)) { + pr_warn("Interface %s is not up - can't send packet via that interface!\n", + hard_iface->net_dev->name); + goto send_skb_err; + } + + /* push to the ethernet header. */ + if (batadv_skb_head_push(skb, ETH_HLEN) < 0) + goto send_skb_err; + + skb_reset_mac_header(skb); + + ethhdr = eth_hdr(skb); + ether_addr_copy(ethhdr->h_source, hard_iface->net_dev->dev_addr); + ether_addr_copy(ethhdr->h_dest, dst_addr); + ethhdr->h_proto = htons(ETH_P_BATMAN); + + skb_set_network_header(skb, ETH_HLEN); + skb->protocol = htons(ETH_P_BATMAN); + + skb->dev = hard_iface->net_dev; + + /* Save a clone of the skb to use when decoding coded packets */ + batadv_nc_skb_store_for_decoding(bat_priv, skb); + + /* dev_queue_xmit() returns a negative result on error. However on + * congestion and traffic shaping, it drops and returns NET_XMIT_DROP + * (which is > 0). This will not be treated as an error. + */ + ret = dev_queue_xmit(skb); + return net_xmit_eval(ret); +send_skb_err: + kfree_skb(skb); + return NET_XMIT_DROP; +} + +/** + * batadv_send_broadcast_skb() - Send broadcast packet via hard interface + * @skb: packet to be transmitted (with batadv header and no outer eth header) + * @hard_iface: outgoing interface + * + * Return: A negative errno code is returned on a failure. A success does not + * guarantee the frame will be transmitted as it may be dropped due + * to congestion or traffic shaping. 
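Pushing the outer Ethernet header, as done above before handing the frame to dev_queue_xmit(), is plain byte layout: destination MAC, source MAC, then the EtherType in network byte order (0x4305 is the EtherType used for batman-adv frames). A standalone sketch against a raw buffer, with made-up addresses:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define ETH_ALEN 6

/* Write a 14-byte Ethernet header into the front of a frame buffer. */
static void push_eth_header(uint8_t *frame, const uint8_t *dst,
			    const uint8_t *src, uint16_t proto)
{
	memcpy(frame, dst, ETH_ALEN);
	memcpy(frame + ETH_ALEN, src, ETH_ALEN);
	frame[12] = proto >> 8;		/* EtherType, big endian */
	frame[13] = proto & 0xff;
}

int main(void)
{
	uint8_t frame[14];
	uint8_t dst[ETH_ALEN] = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff };
	uint8_t src[ETH_ALEN] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x01 };
	int i;

	push_eth_header(frame, dst, src, 0x4305);

	for (i = 0; i < 14; i++)
		printf("%02x%s", frame[i], i == 13 ? "\n" : ":");
	return 0;
}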
+ */ +int batadv_send_broadcast_skb(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface) +{ + return batadv_send_skb_packet(skb, hard_iface, batadv_broadcast_addr); +} + +/** + * batadv_send_unicast_skb() - Send unicast packet to neighbor + * @skb: packet to be transmitted (with batadv header and no outer eth header) + * @neigh: neighbor which is used as next hop to destination + * + * Return: A negative errno code is returned on a failure. A success does not + * guarantee the frame will be transmitted as it may be dropped due + * to congestion or traffic shaping. + */ +int batadv_send_unicast_skb(struct sk_buff *skb, + struct batadv_neigh_node *neigh) +{ +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + struct batadv_hardif_neigh_node *hardif_neigh; +#endif + int ret; + + ret = batadv_send_skb_packet(skb, neigh->if_incoming, neigh->addr); + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + hardif_neigh = batadv_hardif_neigh_get(neigh->if_incoming, neigh->addr); + + if (hardif_neigh && ret != NET_XMIT_DROP) + hardif_neigh->bat_v.last_unicast_tx = jiffies; + + batadv_hardif_neigh_put(hardif_neigh); +#endif + + return ret; +} + +/** + * batadv_send_skb_to_orig() - Lookup next-hop and transmit skb. + * @skb: Packet to be transmitted. + * @orig_node: Final destination of the packet. + * @recv_if: Interface used when receiving the packet (can be NULL). + * + * Looks up the best next-hop towards the passed originator and passes the + * skb on for preparation of MAC header. If the packet originated from this + * host, NULL can be passed as recv_if and no interface alternating is + * attempted. + * + * Return: negative errno code on a failure, -EINPROGRESS if the skb is + * buffered for later transmit or the NET_XMIT status returned by the + * lower routine if the packet has been passed down. + */ +int batadv_send_skb_to_orig(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if) +{ + struct batadv_priv *bat_priv = orig_node->bat_priv; + struct batadv_neigh_node *neigh_node; + int ret; + + /* batadv_find_router() increases neigh_nodes refcount if found. */ + neigh_node = batadv_find_router(bat_priv, orig_node, recv_if); + if (!neigh_node) { + ret = -EINVAL; + goto free_skb; + } + + /* Check if the skb is too large to send in one piece and fragment + * it if needed. + */ + if (atomic_read(&bat_priv->fragmentation) && + skb->len > neigh_node->if_incoming->net_dev->mtu) { + /* Fragment and send packet. */ + ret = batadv_frag_send_packet(skb, orig_node, neigh_node); + /* skb was consumed */ + skb = NULL; + + goto put_neigh_node; + } + + /* try to network code the packet, if it is received on an interface + * (i.e. being forwarded). If the packet originates from this node or if + * network coding fails, then send the packet as usual. + */ + if (recv_if && batadv_nc_skb_forward(skb, neigh_node)) + ret = -EINPROGRESS; + else + ret = batadv_send_unicast_skb(skb, neigh_node); + + /* skb was consumed */ + skb = NULL; + +put_neigh_node: + batadv_neigh_node_put(neigh_node); +free_skb: + kfree_skb(skb); + + return ret; +} + +/** + * batadv_send_skb_push_fill_unicast() - extend the buffer and initialize the + * common fields for unicast packets + * @skb: the skb carrying the unicast header to initialize + * @hdr_size: amount of bytes to push at the beginning of the skb + * @orig_node: the destination node + * + * Return: false if the buffer extension was not possible or true otherwise. 
+ */ +static bool +batadv_send_skb_push_fill_unicast(struct sk_buff *skb, int hdr_size, + struct batadv_orig_node *orig_node) +{ + struct batadv_unicast_packet *unicast_packet; + u8 ttvn = (u8)atomic_read(&orig_node->last_ttvn); + + if (batadv_skb_head_push(skb, hdr_size) < 0) + return false; + + unicast_packet = (struct batadv_unicast_packet *)skb->data; + unicast_packet->version = BATADV_COMPAT_VERSION; + /* batman packet type: unicast */ + unicast_packet->packet_type = BATADV_UNICAST; + /* set unicast ttl */ + unicast_packet->ttl = BATADV_TTL; + /* copy the destination for faster routing */ + ether_addr_copy(unicast_packet->dest, orig_node->orig); + /* set the destination tt version number */ + unicast_packet->ttvn = ttvn; + + return true; +} + +/** + * batadv_send_skb_prepare_unicast() - encapsulate an skb with a unicast header + * @skb: the skb containing the payload to encapsulate + * @orig_node: the destination node + * + * Return: false if the payload could not be encapsulated or true otherwise. + */ +static bool batadv_send_skb_prepare_unicast(struct sk_buff *skb, + struct batadv_orig_node *orig_node) +{ + size_t uni_size = sizeof(struct batadv_unicast_packet); + + return batadv_send_skb_push_fill_unicast(skb, uni_size, orig_node); +} + +/** + * batadv_send_skb_prepare_unicast_4addr() - encapsulate an skb with a + * unicast 4addr header + * @bat_priv: the bat priv with all the soft interface information + * @skb: the skb containing the payload to encapsulate + * @orig: the destination node + * @packet_subtype: the unicast 4addr packet subtype to use + * + * Return: false if the payload could not be encapsulated or true otherwise. + */ +bool batadv_send_skb_prepare_unicast_4addr(struct batadv_priv *bat_priv, + struct sk_buff *skb, + struct batadv_orig_node *orig, + int packet_subtype) +{ + struct batadv_hard_iface *primary_if; + struct batadv_unicast_4addr_packet *uc_4addr_packet; + bool ret = false; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* Pull the header space and fill the unicast_packet substructure. + * We can do that because the first member of the uc_4addr_packet + * is of type struct unicast_packet + */ + if (!batadv_send_skb_push_fill_unicast(skb, sizeof(*uc_4addr_packet), + orig)) + goto out; + + uc_4addr_packet = (struct batadv_unicast_4addr_packet *)skb->data; + uc_4addr_packet->u.packet_type = BATADV_UNICAST_4ADDR; + ether_addr_copy(uc_4addr_packet->src, primary_if->net_dev->dev_addr); + uc_4addr_packet->subtype = packet_subtype; + uc_4addr_packet->reserved = 0; + + ret = true; +out: + batadv_hardif_put(primary_if); + return ret; +} + +/** + * batadv_send_skb_unicast() - encapsulate and send an skb via unicast + * @bat_priv: the bat priv with all the soft interface information + * @skb: payload to send + * @packet_type: the batman unicast packet type to use + * @packet_subtype: the unicast 4addr packet subtype (only relevant for unicast + * 4addr packets) + * @orig_node: the originator to send the packet to + * @vid: the vid to be used to search the translation table + * + * Wrap the given skb into a batman-adv unicast or unicast-4addr header + * depending on whether BATADV_UNICAST or BATADV_UNICAST_4ADDR was supplied + * as packet_type. Then send this frame to the given orig_node. + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. 
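The unicast encapsulation above just stamps a handful of fields in front of the payload: packet type, compatibility version, TTL, destination originator and the TTVN the sender currently associates with it. A sketch with a stand-in struct; the numeric values and the layout are placeholders for illustration, not the authoritative on-wire format:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct uni_hdr {
	uint8_t packet_type;
	uint8_t version;
	uint8_t ttl;
	uint8_t ttvn;
	uint8_t dest[6];
} __attribute__((packed));

/* Fill the common unicast header fields for a given destination. */
static void fill_unicast(struct uni_hdr *hdr, const uint8_t *dest_orig,
			 uint8_t ttvn)
{
	hdr->packet_type = 0x40;	/* "unicast" placeholder value */
	hdr->version = 15;		/* compat version placeholder */
	hdr->ttl = 50;			/* default hop limit placeholder */
	hdr->ttvn = ttvn;
	memcpy(hdr->dest, dest_orig, sizeof(hdr->dest));
}

int main(void)
{
	struct uni_hdr hdr;
	uint8_t dest[6] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x42 };

	fill_unicast(&hdr, dest, 7);
	printf("type=%#x ttl=%u ttvn=%u\n", (unsigned int)hdr.packet_type,
	       (unsigned int)hdr.ttl, (unsigned int)hdr.ttvn);
	return 0;
}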
+ */ +int batadv_send_skb_unicast(struct batadv_priv *bat_priv, + struct sk_buff *skb, int packet_type, + int packet_subtype, + struct batadv_orig_node *orig_node, + unsigned short vid) +{ + struct batadv_unicast_packet *unicast_packet; + struct ethhdr *ethhdr; + int ret = NET_XMIT_DROP; + + if (!orig_node) + goto out; + + switch (packet_type) { + case BATADV_UNICAST: + if (!batadv_send_skb_prepare_unicast(skb, orig_node)) + goto out; + break; + case BATADV_UNICAST_4ADDR: + if (!batadv_send_skb_prepare_unicast_4addr(bat_priv, skb, + orig_node, + packet_subtype)) + goto out; + break; + default: + /* this function supports UNICAST and UNICAST_4ADDR only. It + * should never be invoked with any other packet type + */ + goto out; + } + + /* skb->data might have been reallocated by + * batadv_send_skb_prepare_unicast{,_4addr}() + */ + ethhdr = eth_hdr(skb); + unicast_packet = (struct batadv_unicast_packet *)skb->data; + + /* inform the destination node that we are still missing a correct route + * for this client. The destination will receive this packet and will + * try to reroute it because the ttvn contained in the header is less + * than the current one + */ + if (batadv_tt_global_client_is_roaming(bat_priv, ethhdr->h_dest, vid)) + unicast_packet->ttvn = unicast_packet->ttvn - 1; + + ret = batadv_send_skb_to_orig(skb, orig_node, NULL); + /* skb was consumed */ + skb = NULL; + +out: + kfree_skb(skb); + return ret; +} + +/** + * batadv_send_skb_via_tt_generic() - send an skb via TT lookup + * @bat_priv: the bat priv with all the soft interface information + * @skb: payload to send + * @packet_type: the batman unicast packet type to use + * @packet_subtype: the unicast 4addr packet subtype (only relevant for unicast + * 4addr packets) + * @dst_hint: can be used to override the destination contained in the skb + * @vid: the vid to be used to search the translation table + * + * Look up the recipient node for the destination address in the ethernet + * header via the translation table. Wrap the given skb into a batman-adv + * unicast or unicast-4addr header depending on whether BATADV_UNICAST or + * BATADV_UNICAST_4ADDR was supplied as packet_type. Then send this frame + * to the according destination node. + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. + */ +int batadv_send_skb_via_tt_generic(struct batadv_priv *bat_priv, + struct sk_buff *skb, int packet_type, + int packet_subtype, u8 *dst_hint, + unsigned short vid) +{ + struct ethhdr *ethhdr = (struct ethhdr *)skb->data; + struct batadv_orig_node *orig_node; + u8 *src, *dst; + int ret; + + src = ethhdr->h_source; + dst = ethhdr->h_dest; + + /* if we got an hint! let's send the packet to this client (if any) */ + if (dst_hint) { + src = NULL; + dst = dst_hint; + } + orig_node = batadv_transtable_search(bat_priv, src, dst, vid); + + ret = batadv_send_skb_unicast(bat_priv, skb, packet_type, + packet_subtype, orig_node, vid); + + batadv_orig_node_put(orig_node); + + return ret; +} + +/** + * batadv_send_skb_via_gw() - send an skb via gateway lookup + * @bat_priv: the bat priv with all the soft interface information + * @skb: payload to send + * @vid: the vid to be used to search the translation table + * + * Look up the currently selected gateway. Wrap the given skb into a batman-adv + * unicast header and send this frame to this gateway node. + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. 
+ */ +int batadv_send_skb_via_gw(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid) +{ + struct batadv_orig_node *orig_node; + int ret; + + orig_node = batadv_gw_get_selected_orig(bat_priv); + ret = batadv_send_skb_unicast(bat_priv, skb, BATADV_UNICAST_4ADDR, + BATADV_P_DATA, orig_node, vid); + + batadv_orig_node_put(orig_node); + + return ret; +} + +/** + * batadv_forw_packet_free() - free a forwarding packet + * @forw_packet: The packet to free + * @dropped: whether the packet is freed because is dropped + * + * This frees a forwarding packet and releases any resources it might + * have claimed. + */ +void batadv_forw_packet_free(struct batadv_forw_packet *forw_packet, + bool dropped) +{ + if (dropped) + kfree_skb(forw_packet->skb); + else + consume_skb(forw_packet->skb); + + batadv_hardif_put(forw_packet->if_incoming); + batadv_hardif_put(forw_packet->if_outgoing); + if (forw_packet->queue_left) + atomic_inc(forw_packet->queue_left); + kfree(forw_packet); +} + +/** + * batadv_forw_packet_alloc() - allocate a forwarding packet + * @if_incoming: The (optional) if_incoming to be grabbed + * @if_outgoing: The (optional) if_outgoing to be grabbed + * @queue_left: The (optional) queue counter to decrease + * @bat_priv: The bat_priv for the mesh of this forw_packet + * @skb: The raw packet this forwarding packet shall contain + * + * Allocates a forwarding packet and tries to get a reference to the + * (optional) if_incoming, if_outgoing and queue_left. If queue_left + * is NULL then bat_priv is optional, too. + * + * Return: An allocated forwarding packet on success, NULL otherwise. + */ +struct batadv_forw_packet * +batadv_forw_packet_alloc(struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + atomic_t *queue_left, + struct batadv_priv *bat_priv, + struct sk_buff *skb) +{ + struct batadv_forw_packet *forw_packet; + const char *qname; + + if (queue_left && !batadv_atomic_dec_not_zero(queue_left)) { + qname = "unknown"; + + if (queue_left == &bat_priv->bcast_queue_left) + qname = "bcast"; + + if (queue_left == &bat_priv->batman_queue_left) + qname = "batman"; + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s queue is full\n", qname); + + return NULL; + } + + forw_packet = kmalloc(sizeof(*forw_packet), GFP_ATOMIC); + if (!forw_packet) + goto err; + + if (if_incoming) + kref_get(&if_incoming->refcount); + + if (if_outgoing) + kref_get(&if_outgoing->refcount); + + INIT_HLIST_NODE(&forw_packet->list); + INIT_HLIST_NODE(&forw_packet->cleanup_list); + forw_packet->skb = skb; + forw_packet->queue_left = queue_left; + forw_packet->if_incoming = if_incoming; + forw_packet->if_outgoing = if_outgoing; + forw_packet->num_packets = 0; + + return forw_packet; + +err: + if (queue_left) + atomic_inc(queue_left); + + return NULL; +} + +/** + * batadv_forw_packet_was_stolen() - check whether someone stole this packet + * @forw_packet: the forwarding packet to check + * + * This function checks whether the given forwarding packet was claimed by + * someone else for free(). + * + * Return: True if someone stole it, false otherwise. + */ +static bool +batadv_forw_packet_was_stolen(struct batadv_forw_packet *forw_packet) +{ + return !hlist_unhashed(&forw_packet->cleanup_list); +} + +/** + * batadv_forw_packet_steal() - claim a forw_packet for free() + * @forw_packet: the forwarding packet to steal + * @lock: a key to the store to steal from (e.g. 
forw_{bat,bcast}_list_lock) + * + * This function tries to steal a specific forw_packet from global + * visibility for the purpose of getting it for free(). That means + * the caller is *not* allowed to requeue it afterwards. + * + * Return: True if stealing was successful. False if someone else stole it + * before us. + */ +bool batadv_forw_packet_steal(struct batadv_forw_packet *forw_packet, + spinlock_t *lock) +{ + /* did purging routine steal it earlier? */ + spin_lock_bh(lock); + if (batadv_forw_packet_was_stolen(forw_packet)) { + spin_unlock_bh(lock); + return false; + } + + hlist_del_init(&forw_packet->list); + + /* Just to spot misuse of this function */ + hlist_add_fake(&forw_packet->cleanup_list); + + spin_unlock_bh(lock); + return true; +} + +/** + * batadv_forw_packet_list_steal() - claim a list of forward packets for free() + * @forw_list: the to be stolen forward packets + * @cleanup_list: a backup pointer, to be able to dispose the packet later + * @hard_iface: the interface to steal forward packets from + * + * This function claims responsibility to free any forw_packet queued on the + * given hard_iface. If hard_iface is NULL forwarding packets on all hard + * interfaces will be claimed. + * + * The packets are being moved from the forw_list to the cleanup_list. This + * makes it possible for already running threads to notice the claim. + */ +static void +batadv_forw_packet_list_steal(struct hlist_head *forw_list, + struct hlist_head *cleanup_list, + const struct batadv_hard_iface *hard_iface) +{ + struct batadv_forw_packet *forw_packet; + struct hlist_node *safe_tmp_node; + + hlist_for_each_entry_safe(forw_packet, safe_tmp_node, + forw_list, list) { + /* if purge_outstanding_packets() was called with an argument + * we delete only packets belonging to the given interface + */ + if (hard_iface && + forw_packet->if_incoming != hard_iface && + forw_packet->if_outgoing != hard_iface) + continue; + + hlist_del(&forw_packet->list); + hlist_add_head(&forw_packet->cleanup_list, cleanup_list); + } +} + +/** + * batadv_forw_packet_list_free() - free a list of forward packets + * @head: a list of to be freed forw_packets + * + * This function cancels the scheduling of any packet in the provided list, + * waits for any possibly running packet forwarding thread to finish and + * finally, safely frees this forward packet. + * + * This function might sleep. + */ +static void batadv_forw_packet_list_free(struct hlist_head *head) +{ + struct batadv_forw_packet *forw_packet; + struct hlist_node *safe_tmp_node; + + hlist_for_each_entry_safe(forw_packet, safe_tmp_node, head, + cleanup_list) { + cancel_delayed_work_sync(&forw_packet->delayed_work); + + hlist_del(&forw_packet->cleanup_list); + batadv_forw_packet_free(forw_packet, true); + } +} + +/** + * batadv_forw_packet_queue() - try to queue a forwarding packet + * @forw_packet: the forwarding packet to queue + * @lock: a key to the store (e.g. forw_{bat,bcast}_list_lock) + * @head: the shelve to queue it on (e.g. forw_{bat,bcast}_list) + * @send_time: timestamp (jiffies) when the packet is to be sent + * + * This function tries to (re)queue a forwarding packet. Requeuing + * is prevented if the according interface is shutting down + * (e.g. if batadv_forw_packet_list_steal() was called for this + * packet earlier). + * + * Calling batadv_forw_packet_queue() after a call to + * batadv_forw_packet_steal() is forbidden! + * + * Caller needs to ensure that forw_packet->delayed_work was initialized. 
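The steal-before-free scheme above can be reduced to a flag guarded by a lock: whichever path (delayed worker or purge routine) claims the packet first is responsible for freeing it, the loser backs off and must not requeue. A pthread-based sketch of that idea (build with -pthread; plain mutexes stand in for the kernel spinlocks, and all names are illustrative):

#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

struct pkt {
	pthread_mutex_t lock;
	bool stolen;
	char *data;
};

/* Return true only for the first caller; everyone else must back off. */
static bool pkt_steal(struct pkt *p)
{
	bool won;

	pthread_mutex_lock(&p->lock);
	won = !p->stolen;
	p->stolen = true;
	pthread_mutex_unlock(&p->lock);
	return won;
}

int main(void)
{
	struct pkt p = { PTHREAD_MUTEX_INITIALIZER, false, malloc(16) };

	if (pkt_steal(&p)) {		/* e.g. the purge path wins ...   */
		free(p.data);
		printf("purge path freed the packet\n");
	}
	if (!pkt_steal(&p))		/* ... so the worker backs off */
		printf("worker lost the race, not freeing again\n");
	return 0;
}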
+ */ +static void batadv_forw_packet_queue(struct batadv_forw_packet *forw_packet, + spinlock_t *lock, struct hlist_head *head, + unsigned long send_time) +{ + spin_lock_bh(lock); + + /* did purging routine steal it from us? */ + if (batadv_forw_packet_was_stolen(forw_packet)) { + /* If you got it for free() without trouble, then + * don't get back into the queue after stealing... + */ + WARN_ONCE(hlist_fake(&forw_packet->cleanup_list), + "Requeuing after batadv_forw_packet_steal() not allowed!\n"); + + spin_unlock_bh(lock); + return; + } + + hlist_del_init(&forw_packet->list); + hlist_add_head(&forw_packet->list, head); + + queue_delayed_work(batadv_event_workqueue, + &forw_packet->delayed_work, + send_time - jiffies); + spin_unlock_bh(lock); +} + +/** + * batadv_forw_packet_bcast_queue() - try to queue a broadcast packet + * @bat_priv: the bat priv with all the soft interface information + * @forw_packet: the forwarding packet to queue + * @send_time: timestamp (jiffies) when the packet is to be sent + * + * This function tries to (re)queue a broadcast packet. + * + * Caller needs to ensure that forw_packet->delayed_work was initialized. + */ +static void +batadv_forw_packet_bcast_queue(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet, + unsigned long send_time) +{ + batadv_forw_packet_queue(forw_packet, &bat_priv->forw_bcast_list_lock, + &bat_priv->forw_bcast_list, send_time); +} + +/** + * batadv_forw_packet_ogmv1_queue() - try to queue an OGMv1 packet + * @bat_priv: the bat priv with all the soft interface information + * @forw_packet: the forwarding packet to queue + * @send_time: timestamp (jiffies) when the packet is to be sent + * + * This function tries to (re)queue an OGMv1 packet. + * + * Caller needs to ensure that forw_packet->delayed_work was initialized. + */ +void batadv_forw_packet_ogmv1_queue(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet, + unsigned long send_time) +{ + batadv_forw_packet_queue(forw_packet, &bat_priv->forw_bat_list_lock, + &bat_priv->forw_bat_list, send_time); +} + +/** + * batadv_forw_bcast_packet_to_list() - queue broadcast packet for transmissions + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to add + * @delay: number of jiffies to wait before sending + * @own_packet: true if it is a self-generated broadcast packet + * @if_in: the interface where the packet was received on + * @if_out: the outgoing interface to queue on + * + * Adds a broadcast packet to the queue and sets up timers. Broadcast packets + * are sent multiple times to increase probability for being received. + * + * This call clones the given skb, hence the caller needs to take into + * account that the data segment of the original skb might not be + * modifiable anymore. + * + * Return: NETDEV_TX_OK on success and NETDEV_TX_BUSY on errors. 
+ */ +static int batadv_forw_bcast_packet_to_list(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet, + struct batadv_hard_iface *if_in, + struct batadv_hard_iface *if_out) +{ + struct batadv_forw_packet *forw_packet; + unsigned long send_time = jiffies; + struct sk_buff *newskb; + + newskb = skb_clone(skb, GFP_ATOMIC); + if (!newskb) + goto err; + + forw_packet = batadv_forw_packet_alloc(if_in, if_out, + &bat_priv->bcast_queue_left, + bat_priv, newskb); + if (!forw_packet) + goto err_packet_free; + + forw_packet->own = own_packet; + + INIT_DELAYED_WORK(&forw_packet->delayed_work, + batadv_send_outstanding_bcast_packet); + + send_time += delay ? delay : msecs_to_jiffies(5); + + batadv_forw_packet_bcast_queue(bat_priv, forw_packet, send_time); + return NETDEV_TX_OK; + +err_packet_free: + kfree_skb(newskb); +err: + return NETDEV_TX_BUSY; +} + +/** + * batadv_forw_bcast_packet_if() - forward and queue a broadcast packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to add + * @delay: number of jiffies to wait before sending + * @own_packet: true if it is a self-generated broadcast packet + * @if_in: the interface where the packet was received on + * @if_out: the outgoing interface to forward to + * + * Transmits a broadcast packet on the specified interface either immediately + * or if a delay is given after that. Furthermore, queues additional + * retransmissions if this interface is a wireless one. + * + * This call clones the given skb, hence the caller needs to take into + * account that the data segment of the original skb might not be + * modifiable anymore. + * + * Return: NETDEV_TX_OK on success and NETDEV_TX_BUSY on errors. + */ +static int batadv_forw_bcast_packet_if(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet, + struct batadv_hard_iface *if_in, + struct batadv_hard_iface *if_out) +{ + unsigned int num_bcasts = if_out->num_bcasts; + struct sk_buff *newskb; + int ret = NETDEV_TX_OK; + + if (!delay) { + newskb = skb_clone(skb, GFP_ATOMIC); + if (!newskb) + return NETDEV_TX_BUSY; + + batadv_send_broadcast_skb(newskb, if_out); + num_bcasts--; + } + + /* delayed broadcast or rebroadcasts? */ + if (num_bcasts >= 1) { + BATADV_SKB_CB(skb)->num_bcasts = num_bcasts; + + ret = batadv_forw_bcast_packet_to_list(bat_priv, skb, delay, + own_packet, if_in, + if_out); + } + + return ret; +} + +/** + * batadv_send_no_broadcast() - check whether (re)broadcast is necessary + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to check + * @own_packet: true if it is a self-generated broadcast packet + * @if_out: the outgoing interface checked and considered for (re)broadcast + * + * Return: False if a packet needs to be (re)broadcasted on the given interface, + * true otherwise. + */ +static bool batadv_send_no_broadcast(struct batadv_priv *bat_priv, + struct sk_buff *skb, bool own_packet, + struct batadv_hard_iface *if_out) +{ + struct batadv_hardif_neigh_node *neigh_node = NULL; + struct batadv_bcast_packet *bcast_packet; + u8 *orig_neigh; + u8 *neigh_addr; + char *type; + int ret; + + if (!own_packet) { + neigh_addr = eth_hdr(skb)->h_source; + neigh_node = batadv_hardif_neigh_get(if_out, + neigh_addr); + } + + bcast_packet = (struct batadv_bcast_packet *)skb->data; + orig_neigh = neigh_node ? 
neigh_node->orig : NULL; + + ret = batadv_hardif_no_broadcast(if_out, bcast_packet->orig, + orig_neigh); + + batadv_hardif_neigh_put(neigh_node); + + /* ok, may broadcast */ + if (!ret) + return false; + + /* no broadcast */ + switch (ret) { + case BATADV_HARDIF_BCAST_NORECIPIENT: + type = "no neighbor"; + break; + case BATADV_HARDIF_BCAST_DUPFWD: + type = "single neighbor is source"; + break; + case BATADV_HARDIF_BCAST_DUPORIG: + type = "single neighbor is originator"; + break; + default: + type = "unknown"; + } + + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "BCAST packet from orig %pM on %s suppressed: %s\n", + bcast_packet->orig, + if_out->net_dev->name, type); + + return true; +} + +/** + * __batadv_forw_bcast_packet() - forward and queue a broadcast packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to add + * @delay: number of jiffies to wait before sending + * @own_packet: true if it is a self-generated broadcast packet + * + * Transmits a broadcast packet either immediately or if a delay is given + * after that. Furthermore, queues additional retransmissions on wireless + * interfaces. + * + * This call clones the given skb, hence the caller needs to take into + * account that the data segment of the given skb might not be + * modifiable anymore. + * + * Return: NETDEV_TX_OK on success and NETDEV_TX_BUSY on errors. + */ +static int __batadv_forw_bcast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet) +{ + struct batadv_hard_iface *hard_iface; + struct batadv_hard_iface *primary_if; + int ret = NETDEV_TX_OK; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + return NETDEV_TX_BUSY; + + rcu_read_lock(); + list_for_each_entry_rcu(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface != bat_priv->soft_iface) + continue; + + if (!kref_get_unless_zero(&hard_iface->refcount)) + continue; + + if (batadv_send_no_broadcast(bat_priv, skb, own_packet, + hard_iface)) { + batadv_hardif_put(hard_iface); + continue; + } + + ret = batadv_forw_bcast_packet_if(bat_priv, skb, delay, + own_packet, primary_if, + hard_iface); + batadv_hardif_put(hard_iface); + + if (ret == NETDEV_TX_BUSY) + break; + } + rcu_read_unlock(); + + batadv_hardif_put(primary_if); + return ret; +} + +/** + * batadv_forw_bcast_packet() - forward and queue a broadcast packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to add + * @delay: number of jiffies to wait before sending + * @own_packet: true if it is a self-generated broadcast packet + * + * Transmits a broadcast packet either immediately or if a delay is given + * after that. Furthermore, queues additional retransmissions on wireless + * interfaces. + * + * Return: NETDEV_TX_OK on success and NETDEV_TX_BUSY on errors. + */ +int batadv_forw_bcast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet) +{ + return __batadv_forw_bcast_packet(bat_priv, skb, delay, own_packet); +} + +/** + * batadv_send_bcast_packet() - send and queue a broadcast packet + * @bat_priv: the bat priv with all the soft interface information + * @skb: broadcast packet to add + * @delay: number of jiffies to wait before sending + * @own_packet: true if it is a self-generated broadcast packet + * + * Transmits a broadcast packet either immediately or if a delay is given + * after that. Furthermore, queues additional retransmissions on wireless + * interfaces. 
+ * + * Consumes the provided skb. + */ +void batadv_send_bcast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet) +{ + __batadv_forw_bcast_packet(bat_priv, skb, delay, own_packet); + consume_skb(skb); +} + +/** + * batadv_forw_packet_bcasts_left() - check if a retransmission is necessary + * @forw_packet: the forwarding packet to check + * + * Checks whether a given packet has any (re)transmissions left on the provided + * interface. + * + * hard_iface may be NULL: In that case the number of transmissions this skb had + * so far is compared with the maximum amount of retransmissions independent of + * any interface instead. + * + * Return: True if (re)transmissions are left, false otherwise. + */ +static bool +batadv_forw_packet_bcasts_left(struct batadv_forw_packet *forw_packet) +{ + return BATADV_SKB_CB(forw_packet->skb)->num_bcasts; +} + +/** + * batadv_forw_packet_bcasts_dec() - decrement retransmission counter of a + * packet + * @forw_packet: the packet to decrease the counter for + */ +static void +batadv_forw_packet_bcasts_dec(struct batadv_forw_packet *forw_packet) +{ + BATADV_SKB_CB(forw_packet->skb)->num_bcasts--; +} + +/** + * batadv_forw_packet_is_rebroadcast() - check packet for previous transmissions + * @forw_packet: the packet to check + * + * Return: True if this packet was transmitted before, false otherwise. + */ +bool batadv_forw_packet_is_rebroadcast(struct batadv_forw_packet *forw_packet) +{ + unsigned char num_bcasts = BATADV_SKB_CB(forw_packet->skb)->num_bcasts; + + return num_bcasts != forw_packet->if_outgoing->num_bcasts; +} + +/** + * batadv_send_outstanding_bcast_packet() - transmit a queued broadcast packet + * @work: work queue item + * + * Transmits a queued broadcast packet and if necessary reschedules it. + */ +static void batadv_send_outstanding_bcast_packet(struct work_struct *work) +{ + unsigned long send_time = jiffies + msecs_to_jiffies(5); + struct batadv_forw_packet *forw_packet; + struct delayed_work *delayed_work; + struct batadv_priv *bat_priv; + struct sk_buff *skb1; + bool dropped = false; + + delayed_work = to_delayed_work(work); + forw_packet = container_of(delayed_work, struct batadv_forw_packet, + delayed_work); + bat_priv = netdev_priv(forw_packet->if_incoming->soft_iface); + + if (atomic_read(&bat_priv->mesh_state) == BATADV_MESH_DEACTIVATING) { + dropped = true; + goto out; + } + + if (batadv_dat_drop_broadcast_packet(bat_priv, forw_packet)) { + dropped = true; + goto out; + } + + /* send a copy of the saved skb */ + skb1 = skb_clone(forw_packet->skb, GFP_ATOMIC); + if (!skb1) + goto out; + + batadv_send_broadcast_skb(skb1, forw_packet->if_outgoing); + batadv_forw_packet_bcasts_dec(forw_packet); + + if (batadv_forw_packet_bcasts_left(forw_packet)) { + batadv_forw_packet_bcast_queue(bat_priv, forw_packet, + send_time); + return; + } + +out: + /* do we get something for free()? */ + if (batadv_forw_packet_steal(forw_packet, + &bat_priv->forw_bcast_list_lock)) + batadv_forw_packet_free(forw_packet, dropped); +} + +/** + * batadv_purge_outstanding_packets() - stop/purge scheduled bcast/OGMv1 packets + * @bat_priv: the bat priv with all the soft interface information + * @hard_iface: the hard interface to cancel and purge bcast/ogm packets on + * + * This method cancels and purges any broadcast and OGMv1 packet on the given + * hard_iface. If hard_iface is NULL, broadcast and OGMv1 packets on all hard + * interfaces will be canceled and purged. + * + * This function might sleep. 
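The delayed broadcast worker above implements a simple countdown: transmit one copy, decrement the per-packet counter and requeue itself only while copies remain. Sketched below with illustrative numbers (wireless interfaces typically request several copies, wired ones a single copy):

#include <stdio.h>

/* Send one copy per run and requeue while retransmissions are left. */
static void send_outstanding(unsigned int num_bcasts)
{
	while (num_bcasts) {
		printf("sending copy, %u left afterwards\n", --num_bcasts);
		if (num_bcasts)
			printf("  requeue with ~5ms delay\n");
	}
}

int main(void)
{
	send_outstanding(3);	/* wireless-like interface */
	send_outstanding(1);	/* wired-like interface */
	return 0;
}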
+ */ +void +batadv_purge_outstanding_packets(struct batadv_priv *bat_priv, + const struct batadv_hard_iface *hard_iface) +{ + struct hlist_head head = HLIST_HEAD_INIT; + + if (hard_iface) + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s(): %s\n", + __func__, hard_iface->net_dev->name); + else + batadv_dbg(BATADV_DBG_BATMAN, bat_priv, + "%s()\n", __func__); + + /* claim bcast list for free() */ + spin_lock_bh(&bat_priv->forw_bcast_list_lock); + batadv_forw_packet_list_steal(&bat_priv->forw_bcast_list, &head, + hard_iface); + spin_unlock_bh(&bat_priv->forw_bcast_list_lock); + + /* claim batman packet list for free() */ + spin_lock_bh(&bat_priv->forw_bat_list_lock); + batadv_forw_packet_list_steal(&bat_priv->forw_bat_list, &head, + hard_iface); + spin_unlock_bh(&bat_priv->forw_bat_list_lock); + + /* then cancel or wait for packet workers to finish and free */ + batadv_forw_packet_list_free(&head); +} diff --git a/net/batman-adv/send.h b/net/batman-adv/send.h new file mode 100644 index 000000000..08af251b7 --- /dev/null +++ b/net/batman-adv/send.h @@ -0,0 +1,116 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_SEND_H_ +#define _NET_BATMAN_ADV_SEND_H_ + +#include "main.h" + +#include +#include +#include +#include +#include + +void batadv_forw_packet_free(struct batadv_forw_packet *forw_packet, + bool dropped); +struct batadv_forw_packet * +batadv_forw_packet_alloc(struct batadv_hard_iface *if_incoming, + struct batadv_hard_iface *if_outgoing, + atomic_t *queue_left, + struct batadv_priv *bat_priv, + struct sk_buff *skb); +bool batadv_forw_packet_steal(struct batadv_forw_packet *packet, spinlock_t *l); +void batadv_forw_packet_ogmv1_queue(struct batadv_priv *bat_priv, + struct batadv_forw_packet *forw_packet, + unsigned long send_time); +bool batadv_forw_packet_is_rebroadcast(struct batadv_forw_packet *forw_packet); + +int batadv_send_skb_to_orig(struct sk_buff *skb, + struct batadv_orig_node *orig_node, + struct batadv_hard_iface *recv_if); +int batadv_send_skb_packet(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface, + const u8 *dst_addr); +int batadv_send_broadcast_skb(struct sk_buff *skb, + struct batadv_hard_iface *hard_iface); +int batadv_send_unicast_skb(struct sk_buff *skb, + struct batadv_neigh_node *neigh_node); +int batadv_forw_bcast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet); +void batadv_send_bcast_packet(struct batadv_priv *bat_priv, + struct sk_buff *skb, + unsigned long delay, + bool own_packet); +void +batadv_purge_outstanding_packets(struct batadv_priv *bat_priv, + const struct batadv_hard_iface *hard_iface); +bool batadv_send_skb_prepare_unicast_4addr(struct batadv_priv *bat_priv, + struct sk_buff *skb, + struct batadv_orig_node *orig_node, + int packet_subtype); +int batadv_send_skb_unicast(struct batadv_priv *bat_priv, + struct sk_buff *skb, int packet_type, + int packet_subtype, + struct batadv_orig_node *orig_node, + unsigned short vid); +int batadv_send_skb_via_tt_generic(struct batadv_priv *bat_priv, + struct sk_buff *skb, int packet_type, + int packet_subtype, u8 *dst_hint, + unsigned short vid); +int batadv_send_skb_via_gw(struct batadv_priv *bat_priv, struct sk_buff *skb, + unsigned short vid); + +/** + * batadv_send_skb_via_tt() - send an skb via TT lookup + * @bat_priv: the bat priv with all the soft interface information + * @skb: the payload to send + * @dst_hint: can be used to override 
the destination contained in the skb + * @vid: the vid to be used to search the translation table + * + * Look up the recipient node for the destination address in the ethernet + * header via the translation table. Wrap the given skb into a batman-adv + * unicast header. Then send this frame to the according destination node. + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. + */ +static inline int batadv_send_skb_via_tt(struct batadv_priv *bat_priv, + struct sk_buff *skb, u8 *dst_hint, + unsigned short vid) +{ + return batadv_send_skb_via_tt_generic(bat_priv, skb, BATADV_UNICAST, 0, + dst_hint, vid); +} + +/** + * batadv_send_skb_via_tt_4addr() - send an skb via TT lookup + * @bat_priv: the bat priv with all the soft interface information + * @skb: the payload to send + * @packet_subtype: the unicast 4addr packet subtype to use + * @dst_hint: can be used to override the destination contained in the skb + * @vid: the vid to be used to search the translation table + * + * Look up the recipient node for the destination address in the ethernet + * header via the translation table. Wrap the given skb into a batman-adv + * unicast-4addr header. Then send this frame to the according destination + * node. + * + * Return: NET_XMIT_DROP in case of error or NET_XMIT_SUCCESS otherwise. + */ +static inline int batadv_send_skb_via_tt_4addr(struct batadv_priv *bat_priv, + struct sk_buff *skb, + int packet_subtype, + u8 *dst_hint, + unsigned short vid) +{ + return batadv_send_skb_via_tt_generic(bat_priv, skb, + BATADV_UNICAST_4ADDR, + packet_subtype, dst_hint, vid); +} + +#endif /* _NET_BATMAN_ADV_SEND_H_ */ diff --git a/net/batman-adv/soft-interface.c b/net/batman-adv/soft-interface.c new file mode 100644 index 000000000..d7b525a49 --- /dev/null +++ b/net/batman-adv/soft-interface.c @@ -0,0 +1,1132 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "soft-interface.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bat_algo.h" +#include "bridge_loop_avoidance.h" +#include "distributed-arp-table.h" +#include "gateway_client.h" +#include "hard-interface.h" +#include "multicast.h" +#include "network-coding.h" +#include "originator.h" +#include "send.h" +#include "translation-table.h" + +/** + * batadv_skb_head_push() - Increase header size and move (push) head pointer + * @skb: packet buffer which should be modified + * @len: number of bytes to add + * + * Return: 0 on success or negative error number in case of failure + */ +int batadv_skb_head_push(struct sk_buff *skb, unsigned int len) +{ + int result; + + /* TODO: We must check if we can release all references to non-payload + * data using __skb_header_release in our skbs to allow skb_cow_header + * to work optimally. This means that those skbs are not allowed to read + * or write any data which is before the current position of skb->data + * after that call and thus allow other skbs with the same data buffer + * to write freely in that area. 
+ */ + result = skb_cow_head(skb, len); + if (result < 0) + return result; + + skb_push(skb, len); + return 0; +} + +static int batadv_interface_open(struct net_device *dev) +{ + netif_start_queue(dev); + return 0; +} + +static int batadv_interface_release(struct net_device *dev) +{ + netif_stop_queue(dev); + return 0; +} + +/** + * batadv_sum_counter() - Sum the cpu-local counters for index 'idx' + * @bat_priv: the bat priv with all the soft interface information + * @idx: index of counter to sum up + * + * Return: sum of all cpu-local counters + */ +static u64 batadv_sum_counter(struct batadv_priv *bat_priv, size_t idx) +{ + u64 *counters, sum = 0; + int cpu; + + for_each_possible_cpu(cpu) { + counters = per_cpu_ptr(bat_priv->bat_counters, cpu); + sum += counters[idx]; + } + + return sum; +} + +static struct net_device_stats *batadv_interface_stats(struct net_device *dev) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + struct net_device_stats *stats = &dev->stats; + + stats->tx_packets = batadv_sum_counter(bat_priv, BATADV_CNT_TX); + stats->tx_bytes = batadv_sum_counter(bat_priv, BATADV_CNT_TX_BYTES); + stats->tx_dropped = batadv_sum_counter(bat_priv, BATADV_CNT_TX_DROPPED); + stats->rx_packets = batadv_sum_counter(bat_priv, BATADV_CNT_RX); + stats->rx_bytes = batadv_sum_counter(bat_priv, BATADV_CNT_RX_BYTES); + return stats; +} + +static int batadv_interface_set_mac_addr(struct net_device *dev, void *p) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + struct batadv_softif_vlan *vlan; + struct sockaddr *addr = p; + u8 old_addr[ETH_ALEN]; + + if (!is_valid_ether_addr(addr->sa_data)) + return -EADDRNOTAVAIL; + + ether_addr_copy(old_addr, dev->dev_addr); + eth_hw_addr_set(dev, addr->sa_data); + + /* only modify transtable if it has been initialized before */ + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + return 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(vlan, &bat_priv->softif_vlan_list, list) { + batadv_tt_local_remove(bat_priv, old_addr, vlan->vid, + "mac address changed", false); + batadv_tt_local_add(dev, addr->sa_data, vlan->vid, + BATADV_NULL_IFINDEX, BATADV_NO_MARK); + } + rcu_read_unlock(); + + return 0; +} + +static int batadv_interface_change_mtu(struct net_device *dev, int new_mtu) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + + /* check ranges */ + if (new_mtu < 68 || new_mtu > batadv_hardif_min_mtu(dev)) + return -EINVAL; + + dev->mtu = new_mtu; + bat_priv->mtu_set_by_user = new_mtu; + + return 0; +} + +/** + * batadv_interface_set_rx_mode() - set the rx mode of a device + * @dev: registered network device to modify + * + * We do not actually need to set any rx filters for the virtual batman + * soft interface. However a dummy handler enables a user to set static + * multicast listeners for instance. 
+ */ +static void batadv_interface_set_rx_mode(struct net_device *dev) +{ +} + +static netdev_tx_t batadv_interface_tx(struct sk_buff *skb, + struct net_device *soft_iface) +{ + struct ethhdr *ethhdr; + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + struct batadv_hard_iface *primary_if = NULL; + struct batadv_bcast_packet *bcast_packet; + static const u8 stp_addr[ETH_ALEN] = {0x01, 0x80, 0xC2, 0x00, + 0x00, 0x00}; + static const u8 ectp_addr[ETH_ALEN] = {0xCF, 0x00, 0x00, 0x00, + 0x00, 0x00}; + enum batadv_dhcp_recipient dhcp_rcp = BATADV_DHCP_NO; + u8 *dst_hint = NULL, chaddr[ETH_ALEN]; + struct vlan_ethhdr *vhdr; + unsigned int header_len = 0; + int data_len = skb->len, ret; + unsigned long brd_delay = 0; + bool do_bcast = false, client_added; + unsigned short vid; + u32 seqno; + int gw_mode; + enum batadv_forw_mode forw_mode = BATADV_FORW_SINGLE; + struct batadv_orig_node *mcast_single_orig = NULL; + int mcast_is_routable = 0; + int network_offset = ETH_HLEN; + __be16 proto; + + if (atomic_read(&bat_priv->mesh_state) != BATADV_MESH_ACTIVE) + goto dropped; + + /* reset control block to avoid left overs from previous users */ + memset(skb->cb, 0, sizeof(struct batadv_skb_cb)); + + netif_trans_update(soft_iface); + vid = batadv_get_vid(skb, 0); + + skb_reset_mac_header(skb); + ethhdr = eth_hdr(skb); + + proto = ethhdr->h_proto; + + switch (ntohs(proto)) { + case ETH_P_8021Q: + if (!pskb_may_pull(skb, sizeof(*vhdr))) + goto dropped; + vhdr = vlan_eth_hdr(skb); + proto = vhdr->h_vlan_encapsulated_proto; + + /* drop batman-in-batman packets to prevent loops */ + if (proto != htons(ETH_P_BATMAN)) { + network_offset += VLAN_HLEN; + break; + } + + fallthrough; + case ETH_P_BATMAN: + goto dropped; + } + + skb_set_network_header(skb, network_offset); + + if (batadv_bla_tx(bat_priv, skb, vid)) + goto dropped; + + /* skb->data might have been reallocated by batadv_bla_tx() */ + ethhdr = eth_hdr(skb); + + /* Register the client MAC in the transtable */ + if (!is_multicast_ether_addr(ethhdr->h_source) && + !batadv_bla_is_loopdetect_mac(ethhdr->h_source)) { + client_added = batadv_tt_local_add(soft_iface, ethhdr->h_source, + vid, skb->skb_iif, + skb->mark); + if (!client_added) + goto dropped; + } + + /* Snoop address candidates from DHCPACKs for early DAT filling */ + batadv_dat_snoop_outgoing_dhcp_ack(bat_priv, skb, proto, vid); + + /* don't accept stp packets. STP does not help in meshes. + * better use the bridge loop avoidance ... + * + * The same goes for ECTP sent at least by some Cisco Switches, + * it might confuse the mesh when used with bridge loop avoidance. + */ + if (batadv_compare_eth(ethhdr->h_dest, stp_addr)) + goto dropped; + + if (batadv_compare_eth(ethhdr->h_dest, ectp_addr)) + goto dropped; + + gw_mode = atomic_read(&bat_priv->gw.mode); + if (is_multicast_ether_addr(ethhdr->h_dest)) { + /* if gw mode is off, broadcast every packet */ + if (gw_mode == BATADV_GW_MODE_OFF) { + do_bcast = true; + goto send; + } + + dhcp_rcp = batadv_gw_dhcp_recipient_get(skb, &header_len, + chaddr); + /* skb->data may have been modified by + * batadv_gw_dhcp_recipient_get() + */ + ethhdr = eth_hdr(skb); + /* if gw_mode is on, broadcast any non-DHCP message. 
+ * All the DHCP packets are going to be sent as unicast + */ + if (dhcp_rcp == BATADV_DHCP_NO) { + do_bcast = true; + goto send; + } + + if (dhcp_rcp == BATADV_DHCP_TO_CLIENT) + dst_hint = chaddr; + else if ((gw_mode == BATADV_GW_MODE_SERVER) && + (dhcp_rcp == BATADV_DHCP_TO_SERVER)) + /* gateways should not forward any DHCP message if + * directed to a DHCP server + */ + goto dropped; + +send: + if (do_bcast && !is_broadcast_ether_addr(ethhdr->h_dest)) { + forw_mode = batadv_mcast_forw_mode(bat_priv, skb, + &mcast_single_orig, + &mcast_is_routable); + if (forw_mode == BATADV_FORW_NONE) + goto dropped; + + if (forw_mode == BATADV_FORW_SINGLE || + forw_mode == BATADV_FORW_SOME) + do_bcast = false; + } + } + + batadv_skb_set_priority(skb, 0); + + /* ethernet packet should be broadcasted */ + if (do_bcast) { + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto dropped; + + /* in case of ARP request, we do not immediately broadcasti the + * packet, instead we first wait for DAT to try to retrieve the + * correct ARP entry + */ + if (batadv_dat_snoop_outgoing_arp_request(bat_priv, skb)) + brd_delay = msecs_to_jiffies(ARP_REQ_DELAY); + + if (batadv_skb_head_push(skb, sizeof(*bcast_packet)) < 0) + goto dropped; + + bcast_packet = (struct batadv_bcast_packet *)skb->data; + bcast_packet->version = BATADV_COMPAT_VERSION; + bcast_packet->ttl = BATADV_TTL - 1; + + /* batman packet type: broadcast */ + bcast_packet->packet_type = BATADV_BCAST; + bcast_packet->reserved = 0; + + /* hw address of first interface is the orig mac because only + * this mac is known throughout the mesh + */ + ether_addr_copy(bcast_packet->orig, + primary_if->net_dev->dev_addr); + + /* set broadcast sequence number */ + seqno = atomic_inc_return(&bat_priv->bcast_seqno); + bcast_packet->seqno = htonl(seqno); + + batadv_send_bcast_packet(bat_priv, skb, brd_delay, true); + /* unicast packet */ + } else { + /* DHCP packets going to a server will use the GW feature */ + if (dhcp_rcp == BATADV_DHCP_TO_SERVER) { + ret = batadv_gw_out_of_range(bat_priv, skb); + if (ret) + goto dropped; + ret = batadv_send_skb_via_gw(bat_priv, skb, vid); + } else if (mcast_single_orig) { + ret = batadv_mcast_forw_send_orig(bat_priv, skb, vid, + mcast_single_orig); + } else if (forw_mode == BATADV_FORW_SOME) { + ret = batadv_mcast_forw_send(bat_priv, skb, vid, + mcast_is_routable); + } else { + if (batadv_dat_snoop_outgoing_arp_request(bat_priv, + skb)) + goto dropped; + + batadv_dat_snoop_outgoing_arp_reply(bat_priv, skb); + + ret = batadv_send_skb_via_tt(bat_priv, skb, dst_hint, + vid); + } + if (ret != NET_XMIT_SUCCESS) + goto dropped_freed; + } + + batadv_inc_counter(bat_priv, BATADV_CNT_TX); + batadv_add_counter(bat_priv, BATADV_CNT_TX_BYTES, data_len); + goto end; + +dropped: + kfree_skb(skb); +dropped_freed: + batadv_inc_counter(bat_priv, BATADV_CNT_TX_DROPPED); +end: + batadv_orig_node_put(mcast_single_orig); + batadv_hardif_put(primary_if); + return NETDEV_TX_OK; +} + +/** + * batadv_interface_rx() - receive ethernet frame on local batman-adv interface + * @soft_iface: local interface which will receive the ethernet frame + * @skb: ethernet frame for @soft_iface + * @hdr_size: size of already parsed batman-adv header + * @orig_node: originator from which the batman-adv packet was sent + * + * Sends an ethernet frame to the receive path of the local @soft_iface. + * skb->data has still point to the batman-adv header with the size @hdr_size. 
+ * The caller has to have parsed this header already and made sure that at least + * @hdr_size bytes are still available for pull in @skb. + * + * The packet may still get dropped. This can happen when the encapsulated + * ethernet frame is invalid or contains again an batman-adv packet. Also + * unicast packets will be dropped directly when it was sent between two + * isolated clients. + */ +void batadv_interface_rx(struct net_device *soft_iface, + struct sk_buff *skb, int hdr_size, + struct batadv_orig_node *orig_node) +{ + struct batadv_bcast_packet *batadv_bcast_packet; + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + struct vlan_ethhdr *vhdr; + struct ethhdr *ethhdr; + unsigned short vid; + int packet_type; + + batadv_bcast_packet = (struct batadv_bcast_packet *)skb->data; + packet_type = batadv_bcast_packet->packet_type; + + skb_pull_rcsum(skb, hdr_size); + skb_reset_mac_header(skb); + + /* clean the netfilter state now that the batman-adv header has been + * removed + */ + nf_reset_ct(skb); + + if (unlikely(!pskb_may_pull(skb, ETH_HLEN))) + goto dropped; + + vid = batadv_get_vid(skb, 0); + ethhdr = eth_hdr(skb); + + switch (ntohs(ethhdr->h_proto)) { + case ETH_P_8021Q: + if (!pskb_may_pull(skb, VLAN_ETH_HLEN)) + goto dropped; + + vhdr = skb_vlan_eth_hdr(skb); + + /* drop batman-in-batman packets to prevent loops */ + if (vhdr->h_vlan_encapsulated_proto != htons(ETH_P_BATMAN)) + break; + + fallthrough; + case ETH_P_BATMAN: + goto dropped; + } + + /* skb->dev & skb->pkt_type are set here */ + skb->protocol = eth_type_trans(skb, soft_iface); + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN); + + batadv_inc_counter(bat_priv, BATADV_CNT_RX); + batadv_add_counter(bat_priv, BATADV_CNT_RX_BYTES, + skb->len + ETH_HLEN); + + /* Let the bridge loop avoidance check the packet. If will + * not handle it, we can safely push it up. + */ + if (batadv_bla_rx(bat_priv, skb, vid, packet_type)) + goto out; + + if (orig_node) + batadv_tt_add_temporary_global_entry(bat_priv, orig_node, + ethhdr->h_source, vid); + + if (is_multicast_ether_addr(ethhdr->h_dest)) { + /* set the mark on broadcast packets if AP isolation is ON and + * the packet is coming from an "isolated" client + */ + if (batadv_vlan_ap_isola_get(bat_priv, vid) && + batadv_tt_global_is_isolated(bat_priv, ethhdr->h_source, + vid)) { + /* save bits in skb->mark not covered by the mask and + * apply the mark on the rest + */ + skb->mark &= ~bat_priv->isolation_mark_mask; + skb->mark |= bat_priv->isolation_mark; + } + } else if (batadv_is_ap_isolated(bat_priv, ethhdr->h_source, + ethhdr->h_dest, vid)) { + goto dropped; + } + + netif_rx(skb); + goto out; + +dropped: + kfree_skb(skb); +out: + return; +} + +/** + * batadv_softif_vlan_release() - release vlan from lists and queue for free + * after rcu grace period + * @ref: kref pointer of the vlan object + */ +void batadv_softif_vlan_release(struct kref *ref) +{ + struct batadv_softif_vlan *vlan; + + vlan = container_of(ref, struct batadv_softif_vlan, refcount); + + spin_lock_bh(&vlan->bat_priv->softif_vlan_list_lock); + hlist_del_rcu(&vlan->list); + spin_unlock_bh(&vlan->bat_priv->softif_vlan_list_lock); + + kfree_rcu(vlan, rcu); +} + +/** + * batadv_softif_vlan_get() - get the vlan object for a specific vid + * @bat_priv: the bat priv with all the soft interface information + * @vid: the identifier of the vlan object to retrieve + * + * Return: the private data of the vlan matching the vid passed as argument or + * NULL otherwise. 
The refcounter of the returned object is incremented by 1. + */ +struct batadv_softif_vlan *batadv_softif_vlan_get(struct batadv_priv *bat_priv, + unsigned short vid) +{ + struct batadv_softif_vlan *vlan_tmp, *vlan = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(vlan_tmp, &bat_priv->softif_vlan_list, list) { + if (vlan_tmp->vid != vid) + continue; + + if (!kref_get_unless_zero(&vlan_tmp->refcount)) + continue; + + vlan = vlan_tmp; + break; + } + rcu_read_unlock(); + + return vlan; +} + +/** + * batadv_softif_create_vlan() - allocate the needed resources for a new vlan + * @bat_priv: the bat priv with all the soft interface information + * @vid: the VLAN identifier + * + * Return: 0 on success, a negative error otherwise. + */ +int batadv_softif_create_vlan(struct batadv_priv *bat_priv, unsigned short vid) +{ + struct batadv_softif_vlan *vlan; + + spin_lock_bh(&bat_priv->softif_vlan_list_lock); + + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (vlan) { + batadv_softif_vlan_put(vlan); + spin_unlock_bh(&bat_priv->softif_vlan_list_lock); + return -EEXIST; + } + + vlan = kzalloc(sizeof(*vlan), GFP_ATOMIC); + if (!vlan) { + spin_unlock_bh(&bat_priv->softif_vlan_list_lock); + return -ENOMEM; + } + + vlan->bat_priv = bat_priv; + vlan->vid = vid; + kref_init(&vlan->refcount); + + atomic_set(&vlan->ap_isolation, 0); + + kref_get(&vlan->refcount); + hlist_add_head_rcu(&vlan->list, &bat_priv->softif_vlan_list); + spin_unlock_bh(&bat_priv->softif_vlan_list_lock); + + /* add a new TT local entry. This one will be marked with the NOPURGE + * flag + */ + batadv_tt_local_add(bat_priv->soft_iface, + bat_priv->soft_iface->dev_addr, vid, + BATADV_NULL_IFINDEX, BATADV_NO_MARK); + + /* don't return reference to new softif_vlan */ + batadv_softif_vlan_put(vlan); + + return 0; +} + +/** + * batadv_softif_destroy_vlan() - remove and destroy a softif_vlan object + * @bat_priv: the bat priv with all the soft interface information + * @vlan: the object to remove + */ +static void batadv_softif_destroy_vlan(struct batadv_priv *bat_priv, + struct batadv_softif_vlan *vlan) +{ + /* explicitly remove the associated TT local entry because it is marked + * with the NOPURGE flag + */ + batadv_tt_local_remove(bat_priv, bat_priv->soft_iface->dev_addr, + vlan->vid, "vlan interface destroyed", false); + + batadv_softif_vlan_put(vlan); +} + +/** + * batadv_interface_add_vid() - ndo_add_vid API implementation + * @dev: the netdev of the mesh interface + * @proto: protocol of the vlan id + * @vid: identifier of the new vlan + * + * Set up all the internal structures for handling the new vlan on top of the + * mesh interface + * + * Return: 0 on success or a negative error code in case of failure. + */ +static int batadv_interface_add_vid(struct net_device *dev, __be16 proto, + unsigned short vid) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + struct batadv_softif_vlan *vlan; + + /* only 802.1Q vlans are supported. + * batman-adv does not know how to handle other types + */ + if (proto != htons(ETH_P_8021Q)) + return -EINVAL; + + vid |= BATADV_VLAN_HAS_TAG; + + /* if a new vlan is getting created and it already exists, it means that + * it was not deleted yet. batadv_softif_vlan_get() increases the + * refcount in order to revive the object. + * + * if it does not exist then create it. + */ + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (!vlan) + return batadv_softif_create_vlan(bat_priv, vid); + + /* add a new TT local entry. This one will be marked with the NOPURGE + * flag. 
This must be added again, even if the vlan object already + * exists, because the entry was deleted by kill_vid() + */ + batadv_tt_local_add(bat_priv->soft_iface, + bat_priv->soft_iface->dev_addr, vid, + BATADV_NULL_IFINDEX, BATADV_NO_MARK); + + return 0; +} + +/** + * batadv_interface_kill_vid() - ndo_kill_vid API implementation + * @dev: the netdev of the mesh interface + * @proto: protocol of the vlan id + * @vid: identifier of the deleted vlan + * + * Destroy all the internal structures used to handle the vlan identified by vid + * on top of the mesh interface + * + * Return: 0 on success, -EINVAL if the specified prototype is not ETH_P_8021Q + * or -ENOENT if the specified vlan id wasn't registered. + */ +static int batadv_interface_kill_vid(struct net_device *dev, __be16 proto, + unsigned short vid) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + struct batadv_softif_vlan *vlan; + + /* only 802.1Q vlans are supported. batman-adv does not know how to + * handle other types + */ + if (proto != htons(ETH_P_8021Q)) + return -EINVAL; + + vlan = batadv_softif_vlan_get(bat_priv, vid | BATADV_VLAN_HAS_TAG); + if (!vlan) + return -ENOENT; + + batadv_softif_destroy_vlan(bat_priv, vlan); + + /* finally free the vlan object */ + batadv_softif_vlan_put(vlan); + + return 0; +} + +/* batman-adv network devices have devices nesting below it and are a special + * "super class" of normal network devices; split their locks off into a + * separate class since they always nest. + */ +static struct lock_class_key batadv_netdev_xmit_lock_key; +static struct lock_class_key batadv_netdev_addr_lock_key; + +/** + * batadv_set_lockdep_class_one() - Set lockdep class for a single tx queue + * @dev: device which owns the tx queue + * @txq: tx queue to modify + * @_unused: always NULL + */ +static void batadv_set_lockdep_class_one(struct net_device *dev, + struct netdev_queue *txq, + void *_unused) +{ + lockdep_set_class(&txq->_xmit_lock, &batadv_netdev_xmit_lock_key); +} + +/** + * batadv_set_lockdep_class() - Set txq and addr_list lockdep class + * @dev: network device to modify + */ +static void batadv_set_lockdep_class(struct net_device *dev) +{ + lockdep_set_class(&dev->addr_list_lock, &batadv_netdev_addr_lock_key); + netdev_for_each_tx_queue(dev, batadv_set_lockdep_class_one, NULL); +} + +/** + * batadv_softif_init_late() - late stage initialization of soft interface + * @dev: registered network device to modify + * + * Return: error code on failures + */ +static int batadv_softif_init_late(struct net_device *dev) +{ + struct batadv_priv *bat_priv; + u32 random_seqno; + int ret; + size_t cnt_len = sizeof(u64) * BATADV_CNT_NUM; + + batadv_set_lockdep_class(dev); + + bat_priv = netdev_priv(dev); + bat_priv->soft_iface = dev; + + /* batadv_interface_stats() needs to be available as soon as + * register_netdevice() has been called + */ + bat_priv->bat_counters = __alloc_percpu(cnt_len, __alignof__(u64)); + if (!bat_priv->bat_counters) + return -ENOMEM; + + atomic_set(&bat_priv->aggregated_ogms, 1); + atomic_set(&bat_priv->bonding, 0); +#ifdef CONFIG_BATMAN_ADV_BLA + atomic_set(&bat_priv->bridge_loop_avoidance, 1); +#endif +#ifdef CONFIG_BATMAN_ADV_DAT + atomic_set(&bat_priv->distributed_arp_table, 1); +#endif +#ifdef CONFIG_BATMAN_ADV_MCAST + atomic_set(&bat_priv->multicast_mode, 1); + atomic_set(&bat_priv->multicast_fanout, 16); + atomic_set(&bat_priv->mcast.num_want_all_unsnoopables, 0); + atomic_set(&bat_priv->mcast.num_want_all_ipv4, 0); + atomic_set(&bat_priv->mcast.num_want_all_ipv6, 0); +#endif + 
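/* A minimal userspace sketch (not part of the upstream file) of the
 * per-CPU counter layout allocated above via __alloc_percpu() and read
 * back by batadv_sum_counter(): every CPU owns a private u64 slot per
 * counter index, writers only touch their local slot, and a reader sums
 * all slots. NR_MODEL_CPUS, MODEL_CNT_NUM, model_counters and
 * model_sum_counter() are illustrative names, not kernel symbols.
 */
#include <stddef.h>
#include <stdint.h>

#define NR_MODEL_CPUS	4	/* stand-in for the possible-CPU mask */
#define MODEL_CNT_NUM	8	/* stand-in for BATADV_CNT_NUM */

static uint64_t model_counters[NR_MODEL_CPUS][MODEL_CNT_NUM];

/* mirrors the for_each_possible_cpu() loop in batadv_sum_counter() */
static uint64_t model_sum_counter(size_t idx)
{
	uint64_t sum = 0;
	int cpu;

	for (cpu = 0; cpu < NR_MODEL_CPUS; cpu++)
		sum += model_counters[cpu][idx];

	return sum;
}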
atomic_set(&bat_priv->gw.mode, BATADV_GW_MODE_OFF); + atomic_set(&bat_priv->gw.bandwidth_down, 100); + atomic_set(&bat_priv->gw.bandwidth_up, 20); + atomic_set(&bat_priv->orig_interval, 1000); + atomic_set(&bat_priv->hop_penalty, 30); +#ifdef CONFIG_BATMAN_ADV_DEBUG + atomic_set(&bat_priv->log_level, 0); +#endif + atomic_set(&bat_priv->fragmentation, 1); + atomic_set(&bat_priv->packet_size_max, ETH_DATA_LEN); + atomic_set(&bat_priv->bcast_queue_left, BATADV_BCAST_QUEUE_LEN); + atomic_set(&bat_priv->batman_queue_left, BATADV_BATMAN_QUEUE_LEN); + + atomic_set(&bat_priv->mesh_state, BATADV_MESH_INACTIVE); + atomic_set(&bat_priv->bcast_seqno, 1); + atomic_set(&bat_priv->tt.vn, 0); + atomic_set(&bat_priv->tt.local_changes, 0); + atomic_set(&bat_priv->tt.ogm_append_cnt, 0); +#ifdef CONFIG_BATMAN_ADV_BLA + atomic_set(&bat_priv->bla.num_requests, 0); +#endif + atomic_set(&bat_priv->tp_num, 0); + + bat_priv->tt.last_changeset = NULL; + bat_priv->tt.last_changeset_len = 0; + bat_priv->isolation_mark = 0; + bat_priv->isolation_mark_mask = 0; + + /* randomize initial seqno to avoid collision */ + get_random_bytes(&random_seqno, sizeof(random_seqno)); + atomic_set(&bat_priv->frag_seqno, random_seqno); + + bat_priv->primary_if = NULL; + + batadv_nc_init_bat_priv(bat_priv); + + if (!bat_priv->algo_ops) { + ret = batadv_algo_select(bat_priv, batadv_routing_algo); + if (ret < 0) + goto free_bat_counters; + } + + ret = batadv_mesh_init(dev); + if (ret < 0) + goto free_bat_counters; + + return 0; + +free_bat_counters: + free_percpu(bat_priv->bat_counters); + bat_priv->bat_counters = NULL; + + return ret; +} + +/** + * batadv_softif_slave_add() - Add a slave interface to a batadv_soft_interface + * @dev: batadv_soft_interface used as master interface + * @slave_dev: net_device which should become the slave interface + * @extack: extended ACK report struct + * + * Return: 0 if successful or error otherwise. + */ +static int batadv_softif_slave_add(struct net_device *dev, + struct net_device *slave_dev, + struct netlink_ext_ack *extack) +{ + struct batadv_hard_iface *hard_iface; + int ret = -EINVAL; + + hard_iface = batadv_hardif_get_by_netdev(slave_dev); + if (!hard_iface || hard_iface->soft_iface) + goto out; + + ret = batadv_hardif_enable_interface(hard_iface, dev); + +out: + batadv_hardif_put(hard_iface); + return ret; +} + +/** + * batadv_softif_slave_del() - Delete a slave iface from a batadv_soft_interface + * @dev: batadv_soft_interface used as master interface + * @slave_dev: net_device which should be removed from the master interface + * + * Return: 0 if successful or error otherwise. 
+ */ +static int batadv_softif_slave_del(struct net_device *dev, + struct net_device *slave_dev) +{ + struct batadv_hard_iface *hard_iface; + int ret = -EINVAL; + + hard_iface = batadv_hardif_get_by_netdev(slave_dev); + + if (!hard_iface || hard_iface->soft_iface != dev) + goto out; + + batadv_hardif_disable_interface(hard_iface); + ret = 0; + +out: + batadv_hardif_put(hard_iface); + return ret; +} + +static const struct net_device_ops batadv_netdev_ops = { + .ndo_init = batadv_softif_init_late, + .ndo_open = batadv_interface_open, + .ndo_stop = batadv_interface_release, + .ndo_get_stats = batadv_interface_stats, + .ndo_vlan_rx_add_vid = batadv_interface_add_vid, + .ndo_vlan_rx_kill_vid = batadv_interface_kill_vid, + .ndo_set_mac_address = batadv_interface_set_mac_addr, + .ndo_change_mtu = batadv_interface_change_mtu, + .ndo_set_rx_mode = batadv_interface_set_rx_mode, + .ndo_start_xmit = batadv_interface_tx, + .ndo_validate_addr = eth_validate_addr, + .ndo_add_slave = batadv_softif_slave_add, + .ndo_del_slave = batadv_softif_slave_del, +}; + +static void batadv_get_drvinfo(struct net_device *dev, + struct ethtool_drvinfo *info) +{ + strscpy(info->driver, "B.A.T.M.A.N. advanced", sizeof(info->driver)); + strscpy(info->version, BATADV_SOURCE_VERSION, sizeof(info->version)); + strscpy(info->fw_version, "N/A", sizeof(info->fw_version)); + strscpy(info->bus_info, "batman", sizeof(info->bus_info)); +} + +/* Inspired by drivers/net/ethernet/dlink/sundance.c:1702 + * Declare each description string in struct.name[] to get fixed sized buffer + * and compile time checking for strings longer than ETH_GSTRING_LEN. + */ +static const struct { + const char name[ETH_GSTRING_LEN]; +} batadv_counters_strings[] = { + { "tx" }, + { "tx_bytes" }, + { "tx_dropped" }, + { "rx" }, + { "rx_bytes" }, + { "forward" }, + { "forward_bytes" }, + { "mgmt_tx" }, + { "mgmt_tx_bytes" }, + { "mgmt_rx" }, + { "mgmt_rx_bytes" }, + { "frag_tx" }, + { "frag_tx_bytes" }, + { "frag_rx" }, + { "frag_rx_bytes" }, + { "frag_fwd" }, + { "frag_fwd_bytes" }, + { "tt_request_tx" }, + { "tt_request_rx" }, + { "tt_response_tx" }, + { "tt_response_rx" }, + { "tt_roam_adv_tx" }, + { "tt_roam_adv_rx" }, +#ifdef CONFIG_BATMAN_ADV_DAT + { "dat_get_tx" }, + { "dat_get_rx" }, + { "dat_put_tx" }, + { "dat_put_rx" }, + { "dat_cached_reply_tx" }, +#endif +#ifdef CONFIG_BATMAN_ADV_NC + { "nc_code" }, + { "nc_code_bytes" }, + { "nc_recode" }, + { "nc_recode_bytes" }, + { "nc_buffer" }, + { "nc_decode" }, + { "nc_decode_bytes" }, + { "nc_decode_failed" }, + { "nc_sniffed" }, +#endif +}; + +static void batadv_get_strings(struct net_device *dev, u32 stringset, u8 *data) +{ + if (stringset == ETH_SS_STATS) + memcpy(data, batadv_counters_strings, + sizeof(batadv_counters_strings)); +} + +static void batadv_get_ethtool_stats(struct net_device *dev, + struct ethtool_stats *stats, u64 *data) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + int i; + + for (i = 0; i < BATADV_CNT_NUM; i++) + data[i] = batadv_sum_counter(bat_priv, i); +} + +static int batadv_get_sset_count(struct net_device *dev, int stringset) +{ + if (stringset == ETH_SS_STATS) + return BATADV_CNT_NUM; + + return -EOPNOTSUPP; +} + +static const struct ethtool_ops batadv_ethtool_ops = { + .get_drvinfo = batadv_get_drvinfo, + .get_link = ethtool_op_get_link, + .get_strings = batadv_get_strings, + .get_ethtool_stats = batadv_get_ethtool_stats, + .get_sset_count = batadv_get_sset_count, +}; + +/** + * batadv_softif_free() - Deconstructor of batadv_soft_interface + * @dev: Device to cleanup 
and remove + */ +static void batadv_softif_free(struct net_device *dev) +{ + batadv_mesh_free(dev); + + /* some scheduled RCU callbacks need the bat_priv struct to accomplish + * their tasks. Wait for them all to be finished before freeing the + * netdev and its private data (bat_priv) + */ + rcu_barrier(); +} + +/** + * batadv_softif_init_early() - early stage initialization of soft interface + * @dev: registered network device to modify + */ +static void batadv_softif_init_early(struct net_device *dev) +{ + ether_setup(dev); + + dev->netdev_ops = &batadv_netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = batadv_softif_free; + dev->features |= NETIF_F_HW_VLAN_CTAG_FILTER | NETIF_F_NETNS_LOCAL; + dev->features |= NETIF_F_LLTX; + dev->priv_flags |= IFF_NO_QUEUE; + + /* can't call min_mtu, because the needed variables + * have not been initialized yet + */ + dev->mtu = ETH_DATA_LEN; + + /* generate random address */ + eth_hw_addr_random(dev); + + dev->ethtool_ops = &batadv_ethtool_ops; +} + +/** + * batadv_softif_validate() - validate configuration of new batadv link + * @tb: IFLA_INFO_DATA netlink attributes + * @data: enum batadv_ifla_attrs attributes + * @extack: extended ACK report struct + * + * Return: 0 if successful or error otherwise. + */ +static int batadv_softif_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct batadv_algo_ops *algo_ops; + + if (!data) + return 0; + + if (data[IFLA_BATADV_ALGO_NAME]) { + algo_ops = batadv_algo_get(nla_data(data[IFLA_BATADV_ALGO_NAME])); + if (!algo_ops) + return -EINVAL; + } + + return 0; +} + +/** + * batadv_softif_newlink() - pre-initialize and register new batadv link + * @src_net: the applicable net namespace + * @dev: network device to register + * @tb: IFLA_INFO_DATA netlink attributes + * @data: enum batadv_ifla_attrs attributes + * @extack: extended ACK report struct + * + * Return: 0 if successful or error otherwise. 
+ */ +static int batadv_softif_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct batadv_priv *bat_priv = netdev_priv(dev); + const char *algo_name; + int err; + + if (data && data[IFLA_BATADV_ALGO_NAME]) { + algo_name = nla_data(data[IFLA_BATADV_ALGO_NAME]); + err = batadv_algo_select(bat_priv, algo_name); + if (err) + return -EINVAL; + } + + return register_netdevice(dev); +} + +/** + * batadv_softif_destroy_netlink() - deletion of batadv_soft_interface via + * netlink + * @soft_iface: the to-be-removed batman-adv interface + * @head: list pointer + */ +static void batadv_softif_destroy_netlink(struct net_device *soft_iface, + struct list_head *head) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + struct batadv_hard_iface *hard_iface; + struct batadv_softif_vlan *vlan; + + list_for_each_entry(hard_iface, &batadv_hardif_list, list) { + if (hard_iface->soft_iface == soft_iface) + batadv_hardif_disable_interface(hard_iface); + } + + /* destroy the "untagged" VLAN */ + vlan = batadv_softif_vlan_get(bat_priv, BATADV_NO_FLAGS); + if (vlan) { + batadv_softif_destroy_vlan(bat_priv, vlan); + batadv_softif_vlan_put(vlan); + } + + unregister_netdevice_queue(soft_iface, head); +} + +/** + * batadv_softif_is_valid() - Check whether device is a batadv soft interface + * @net_dev: device which should be checked + * + * Return: true when net_dev is a batman-adv interface, false otherwise + */ +bool batadv_softif_is_valid(const struct net_device *net_dev) +{ + if (net_dev->netdev_ops->ndo_start_xmit == batadv_interface_tx) + return true; + + return false; +} + +static const struct nla_policy batadv_ifla_policy[IFLA_BATADV_MAX + 1] = { + [IFLA_BATADV_ALGO_NAME] = { .type = NLA_NUL_STRING }, +}; + +struct rtnl_link_ops batadv_link_ops __read_mostly = { + .kind = "batadv", + .priv_size = sizeof(struct batadv_priv), + .setup = batadv_softif_init_early, + .maxtype = IFLA_BATADV_MAX, + .policy = batadv_ifla_policy, + .validate = batadv_softif_validate, + .newlink = batadv_softif_newlink, + .dellink = batadv_softif_destroy_netlink, +}; diff --git a/net/batman-adv/soft-interface.h b/net/batman-adv/soft-interface.h new file mode 100644 index 000000000..9f2003f1a --- /dev/null +++ b/net/batman-adv/soft-interface.h @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Marek Lindner + */ + +#ifndef _NET_BATMAN_ADV_SOFT_INTERFACE_H_ +#define _NET_BATMAN_ADV_SOFT_INTERFACE_H_ + +#include "main.h" + +#include +#include +#include +#include +#include + +int batadv_skb_head_push(struct sk_buff *skb, unsigned int len); +void batadv_interface_rx(struct net_device *soft_iface, + struct sk_buff *skb, int hdr_size, + struct batadv_orig_node *orig_node); +bool batadv_softif_is_valid(const struct net_device *net_dev); +extern struct rtnl_link_ops batadv_link_ops; +int batadv_softif_create_vlan(struct batadv_priv *bat_priv, unsigned short vid); +void batadv_softif_vlan_release(struct kref *ref); +struct batadv_softif_vlan *batadv_softif_vlan_get(struct batadv_priv *bat_priv, + unsigned short vid); + +/** + * batadv_softif_vlan_put() - decrease the vlan object refcounter and + * possibly release it + * @vlan: the vlan object to release + */ +static inline void batadv_softif_vlan_put(struct batadv_softif_vlan *vlan) +{ + if (!vlan) + return; + + kref_put(&vlan->refcount, batadv_softif_vlan_release); +} + +#endif /* _NET_BATMAN_ADV_SOFT_INTERFACE_H_ */ diff --git a/net/batman-adv/tp_meter.c b/net/batman-adv/tp_meter.c new file mode 100644 index 000000000..7f3dd3c39 --- /dev/null +++ b/net/batman-adv/tp_meter.c @@ -0,0 +1,1490 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Edo Monticelli, Antonio Quartulli + */ + +#include "tp_meter.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hard-interface.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" +#include "send.h" + +/** + * BATADV_TP_DEF_TEST_LENGTH - Default test length if not specified by the user + * in milliseconds + */ +#define BATADV_TP_DEF_TEST_LENGTH 10000 + +/** + * BATADV_TP_AWND - Advertised window by the receiver (in bytes) + */ +#define BATADV_TP_AWND 0x20000000 + +/** + * BATADV_TP_RECV_TIMEOUT - Receiver activity timeout. If the receiver does not + * get anything for such amount of milliseconds, the connection is killed + */ +#define BATADV_TP_RECV_TIMEOUT 1000 + +/** + * BATADV_TP_MAX_RTO - Maximum sender timeout. If the sender RTO gets beyond + * such amount of milliseconds, the receiver is considered unreachable and the + * connection is killed + */ +#define BATADV_TP_MAX_RTO 30000 + +/** + * BATADV_TP_FIRST_SEQ - First seqno of each session. 
The number is rather high + * in order to immediately trigger a wrap around (test purposes) + */ +#define BATADV_TP_FIRST_SEQ ((u32)-1 - 2000) + +/** + * BATADV_TP_PLEN - length of the payload (data after the batadv_unicast header) + * to simulate + */ +#define BATADV_TP_PLEN (BATADV_TP_PACKET_LEN - ETH_HLEN - \ + sizeof(struct batadv_unicast_packet)) + +static u8 batadv_tp_prerandom[4096] __read_mostly; + +/** + * batadv_tp_session_cookie() - generate session cookie based on session ids + * @session: TP session identifier + * @icmp_uid: icmp pseudo uid of the tp session + * + * Return: 32 bit tp_meter session cookie + */ +static u32 batadv_tp_session_cookie(const u8 session[2], u8 icmp_uid) +{ + u32 cookie; + + cookie = icmp_uid << 16; + cookie |= session[0] << 8; + cookie |= session[1]; + + return cookie; +} + +/** + * batadv_tp_cwnd() - compute the new cwnd size + * @base: base cwnd size value + * @increment: the value to add to base to get the new size + * @min: minimum cwnd value (usually MSS) + * + * Return the new cwnd size and ensure it does not exceed the Advertised + * Receiver Window size. It is wrapped around safely. + * For details refer to Section 3.1 of RFC5681 + * + * Return: new congestion window size in bytes + */ +static u32 batadv_tp_cwnd(u32 base, u32 increment, u32 min) +{ + u32 new_size = base + increment; + + /* check for wrap-around */ + if (new_size < base) + new_size = (u32)ULONG_MAX; + + new_size = min_t(u32, new_size, BATADV_TP_AWND); + + return max_t(u32, new_size, min); +} + +/** + * batadv_tp_update_cwnd() - update the Congestion Windows + * @tp_vars: the private data of the current TP meter session + * @mss: maximum segment size of transmission + * + * 1) if the session is in Slow Start, the CWND has to be increased by 1 + * MSS every unique received ACK + * 2) if the session is in Congestion Avoidance, the CWND has to be + * increased by MSS * MSS / CWND for every unique received ACK + */ +static void batadv_tp_update_cwnd(struct batadv_tp_vars *tp_vars, u32 mss) +{ + spin_lock_bh(&tp_vars->cwnd_lock); + + /* slow start... */ + if (tp_vars->cwnd <= tp_vars->ss_threshold) { + tp_vars->dec_cwnd = 0; + tp_vars->cwnd = batadv_tp_cwnd(tp_vars->cwnd, mss, mss); + spin_unlock_bh(&tp_vars->cwnd_lock); + return; + } + + /* increment CWND at least of 1 (section 3.1 of RFC5681) */ + tp_vars->dec_cwnd += max_t(u32, 1U << 3, + ((mss * mss) << 6) / (tp_vars->cwnd << 3)); + if (tp_vars->dec_cwnd < (mss << 3)) { + spin_unlock_bh(&tp_vars->cwnd_lock); + return; + } + + tp_vars->cwnd = batadv_tp_cwnd(tp_vars->cwnd, mss, mss); + tp_vars->dec_cwnd = 0; + + spin_unlock_bh(&tp_vars->cwnd_lock); +} + +/** + * batadv_tp_update_rto() - calculate new retransmission timeout + * @tp_vars: the private data of the current TP meter session + * @new_rtt: new roundtrip time in msec + */ +static void batadv_tp_update_rto(struct batadv_tp_vars *tp_vars, + u32 new_rtt) +{ + long m = new_rtt; + + /* RTT update + * Details in Section 2.2 and 2.3 of RFC6298 + * + * It's tricky to understand. Don't lose hair please. 
+ * Inspired by tcp_rtt_estimator() tcp_input.c + */ + if (tp_vars->srtt != 0) { + m -= (tp_vars->srtt >> 3); /* m is now error in rtt est */ + tp_vars->srtt += m; /* rtt = 7/8 srtt + 1/8 new */ + if (m < 0) + m = -m; + + m -= (tp_vars->rttvar >> 2); + tp_vars->rttvar += m; /* mdev ~= 3/4 rttvar + 1/4 new */ + } else { + /* first measure getting in */ + tp_vars->srtt = m << 3; /* take the measured time to be srtt */ + tp_vars->rttvar = m << 1; /* new_rtt / 2 */ + } + + /* rto = srtt + 4 * rttvar. + * rttvar is scaled by 4, therefore doesn't need to be multiplied + */ + tp_vars->rto = (tp_vars->srtt >> 3) + tp_vars->rttvar; +} + +/** + * batadv_tp_batctl_notify() - send client status result to client + * @reason: reason for tp meter session stop + * @dst: destination of tp_meter session + * @bat_priv: the bat priv with all the soft interface information + * @start_time: start of transmission in jiffies + * @total_sent: bytes acked to the receiver + * @cookie: cookie of tp_meter session + */ +static void batadv_tp_batctl_notify(enum batadv_tp_meter_reason reason, + const u8 *dst, struct batadv_priv *bat_priv, + unsigned long start_time, u64 total_sent, + u32 cookie) +{ + u32 test_time; + u8 result; + u32 total_bytes; + + if (!batadv_tp_is_error(reason)) { + result = BATADV_TP_REASON_COMPLETE; + test_time = jiffies_to_msecs(jiffies - start_time); + total_bytes = total_sent; + } else { + result = reason; + test_time = 0; + total_bytes = 0; + } + + batadv_netlink_tpmeter_notify(bat_priv, dst, result, test_time, + total_bytes, cookie); +} + +/** + * batadv_tp_batctl_error_notify() - send client error result to client + * @reason: reason for tp meter session stop + * @dst: destination of tp_meter session + * @bat_priv: the bat priv with all the soft interface information + * @cookie: cookie of tp_meter session + */ +static void batadv_tp_batctl_error_notify(enum batadv_tp_meter_reason reason, + const u8 *dst, + struct batadv_priv *bat_priv, + u32 cookie) +{ + batadv_tp_batctl_notify(reason, dst, bat_priv, 0, 0, cookie); +} + +/** + * batadv_tp_list_find() - find a tp_vars object in the global list + * @bat_priv: the bat priv with all the soft interface information + * @dst: the other endpoint MAC address to look for + * + * Look for a tp_vars object matching dst as end_point and return it after + * having increment the refcounter. Return NULL is not found + * + * Return: matching tp_vars or NULL when no tp_vars with @dst was found + */ +static struct batadv_tp_vars *batadv_tp_list_find(struct batadv_priv *bat_priv, + const u8 *dst) +{ + struct batadv_tp_vars *pos, *tp_vars = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(pos, &bat_priv->tp_list, list) { + if (!batadv_compare_eth(pos->other_end, dst)) + continue; + + /* most of the time this function is invoked during the normal + * process..it makes sens to pay more when the session is + * finished and to speed the process up during the measurement + */ + if (unlikely(!kref_get_unless_zero(&pos->refcount))) + continue; + + tp_vars = pos; + break; + } + rcu_read_unlock(); + + return tp_vars; +} + +/** + * batadv_tp_list_find_session() - find tp_vars session object in the global + * list + * @bat_priv: the bat priv with all the soft interface information + * @dst: the other endpoint MAC address to look for + * @session: session identifier + * + * Look for a tp_vars object matching dst as end_point, session as tp meter + * session and return it after having increment the refcounter. 
Return NULL + * is not found + * + * Return: matching tp_vars or NULL when no tp_vars was found + */ +static struct batadv_tp_vars * +batadv_tp_list_find_session(struct batadv_priv *bat_priv, const u8 *dst, + const u8 *session) +{ + struct batadv_tp_vars *pos, *tp_vars = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(pos, &bat_priv->tp_list, list) { + if (!batadv_compare_eth(pos->other_end, dst)) + continue; + + if (memcmp(pos->session, session, sizeof(pos->session)) != 0) + continue; + + /* most of the time this function is invoked during the normal + * process..it makes sense to pay more when the session is + * finished and to speed the process up during the measurement + */ + if (unlikely(!kref_get_unless_zero(&pos->refcount))) + continue; + + tp_vars = pos; + break; + } + rcu_read_unlock(); + + return tp_vars; +} + +/** + * batadv_tp_vars_release() - release batadv_tp_vars from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the batadv_tp_vars + */ +static void batadv_tp_vars_release(struct kref *ref) +{ + struct batadv_tp_vars *tp_vars; + struct batadv_tp_unacked *un, *safe; + + tp_vars = container_of(ref, struct batadv_tp_vars, refcount); + + /* lock should not be needed because this object is now out of any + * context! + */ + spin_lock_bh(&tp_vars->unacked_lock); + list_for_each_entry_safe(un, safe, &tp_vars->unacked_list, list) { + list_del(&un->list); + kfree(un); + } + spin_unlock_bh(&tp_vars->unacked_lock); + + kfree_rcu(tp_vars, rcu); +} + +/** + * batadv_tp_vars_put() - decrement the batadv_tp_vars refcounter and possibly + * release it + * @tp_vars: the private data of the current TP meter session to be free'd + */ +static void batadv_tp_vars_put(struct batadv_tp_vars *tp_vars) +{ + if (!tp_vars) + return; + + kref_put(&tp_vars->refcount, batadv_tp_vars_release); +} + +/** + * batadv_tp_sender_cleanup() - cleanup sender data and drop and timer + * @bat_priv: the bat priv with all the soft interface information + * @tp_vars: the private data of the current TP meter session to cleanup + */ +static void batadv_tp_sender_cleanup(struct batadv_priv *bat_priv, + struct batadv_tp_vars *tp_vars) +{ + cancel_delayed_work(&tp_vars->finish_work); + + spin_lock_bh(&tp_vars->bat_priv->tp_list_lock); + hlist_del_rcu(&tp_vars->list); + spin_unlock_bh(&tp_vars->bat_priv->tp_list_lock); + + /* drop list reference */ + batadv_tp_vars_put(tp_vars); + + atomic_dec(&tp_vars->bat_priv->tp_num); + + /* kill the timer and remove its reference */ + del_timer_sync(&tp_vars->timer); + /* the worker might have rearmed itself therefore we kill it again. 
Note + * that if the worker should run again before invoking the following + * del_timer(), it would not re-arm itself once again because the status + * is OFF now + */ + del_timer(&tp_vars->timer); + batadv_tp_vars_put(tp_vars); +} + +/** + * batadv_tp_sender_end() - print info about ended session and inform client + * @bat_priv: the bat priv with all the soft interface information + * @tp_vars: the private data of the current TP meter session + */ +static void batadv_tp_sender_end(struct batadv_priv *bat_priv, + struct batadv_tp_vars *tp_vars) +{ + u32 session_cookie; + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Test towards %pM finished..shutting down (reason=%d)\n", + tp_vars->other_end, tp_vars->reason); + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Last timing stats: SRTT=%ums RTTVAR=%ums RTO=%ums\n", + tp_vars->srtt >> 3, tp_vars->rttvar >> 2, tp_vars->rto); + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Final values: cwnd=%u ss_threshold=%u\n", + tp_vars->cwnd, tp_vars->ss_threshold); + + session_cookie = batadv_tp_session_cookie(tp_vars->session, + tp_vars->icmp_uid); + + batadv_tp_batctl_notify(tp_vars->reason, + tp_vars->other_end, + bat_priv, + tp_vars->start_time, + atomic64_read(&tp_vars->tot_sent), + session_cookie); +} + +/** + * batadv_tp_sender_shutdown() - let sender thread/timer stop gracefully + * @tp_vars: the private data of the current TP meter session + * @reason: reason for tp meter session stop + */ +static void batadv_tp_sender_shutdown(struct batadv_tp_vars *tp_vars, + enum batadv_tp_meter_reason reason) +{ + if (!atomic_dec_and_test(&tp_vars->sending)) + return; + + tp_vars->reason = reason; +} + +/** + * batadv_tp_sender_finish() - stop sender session after test_length was reached + * @work: delayed work reference of the related tp_vars + */ +static void batadv_tp_sender_finish(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_tp_vars *tp_vars; + + delayed_work = to_delayed_work(work); + tp_vars = container_of(delayed_work, struct batadv_tp_vars, + finish_work); + + batadv_tp_sender_shutdown(tp_vars, BATADV_TP_REASON_COMPLETE); +} + +/** + * batadv_tp_reset_sender_timer() - reschedule the sender timer + * @tp_vars: the private TP meter data for this session + * + * Reschedule the timer using tp_vars->rto as delay + */ +static void batadv_tp_reset_sender_timer(struct batadv_tp_vars *tp_vars) +{ + /* most of the time this function is invoked while normal packet + * reception... + */ + if (unlikely(atomic_read(&tp_vars->sending) == 0)) + /* timer ref will be dropped in batadv_tp_sender_cleanup */ + return; + + mod_timer(&tp_vars->timer, jiffies + msecs_to_jiffies(tp_vars->rto)); +} + +/** + * batadv_tp_sender_timeout() - timer that fires in case of packet loss + * @t: address to timer_list inside tp_vars + * + * If fired it means that there was packet loss. 
+ * Switch to Slow Start, set the ss_threshold to half of the current cwnd and + * reset the cwnd to 3*MSS + */ +static void batadv_tp_sender_timeout(struct timer_list *t) +{ + struct batadv_tp_vars *tp_vars = from_timer(tp_vars, t, timer); + struct batadv_priv *bat_priv = tp_vars->bat_priv; + + if (atomic_read(&tp_vars->sending) == 0) + return; + + /* if the user waited long enough...shutdown the test */ + if (unlikely(tp_vars->rto >= BATADV_TP_MAX_RTO)) { + batadv_tp_sender_shutdown(tp_vars, + BATADV_TP_REASON_DST_UNREACHABLE); + return; + } + + /* RTO exponential backoff + * Details in Section 5.5 of RFC6298 + */ + tp_vars->rto <<= 1; + + spin_lock_bh(&tp_vars->cwnd_lock); + + tp_vars->ss_threshold = tp_vars->cwnd >> 1; + if (tp_vars->ss_threshold < BATADV_TP_PLEN * 2) + tp_vars->ss_threshold = BATADV_TP_PLEN * 2; + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: RTO fired during test towards %pM! cwnd=%u new ss_thr=%u, resetting last_sent to %u\n", + tp_vars->other_end, tp_vars->cwnd, tp_vars->ss_threshold, + atomic_read(&tp_vars->last_acked)); + + tp_vars->cwnd = BATADV_TP_PLEN * 3; + + spin_unlock_bh(&tp_vars->cwnd_lock); + + /* resend the non-ACKed packets.. */ + tp_vars->last_sent = atomic_read(&tp_vars->last_acked); + wake_up(&tp_vars->more_bytes); + + batadv_tp_reset_sender_timer(tp_vars); +} + +/** + * batadv_tp_fill_prerandom() - Fill buffer with prefetched random bytes + * @tp_vars: the private TP meter data for this session + * @buf: Buffer to fill with bytes + * @nbytes: amount of pseudorandom bytes + */ +static void batadv_tp_fill_prerandom(struct batadv_tp_vars *tp_vars, + u8 *buf, size_t nbytes) +{ + u32 local_offset; + size_t bytes_inbuf; + size_t to_copy; + size_t pos = 0; + + spin_lock_bh(&tp_vars->prerandom_lock); + local_offset = tp_vars->prerandom_offset; + tp_vars->prerandom_offset += nbytes; + tp_vars->prerandom_offset %= sizeof(batadv_tp_prerandom); + spin_unlock_bh(&tp_vars->prerandom_lock); + + while (nbytes) { + local_offset %= sizeof(batadv_tp_prerandom); + bytes_inbuf = sizeof(batadv_tp_prerandom) - local_offset; + to_copy = min(nbytes, bytes_inbuf); + + memcpy(&buf[pos], &batadv_tp_prerandom[local_offset], to_copy); + pos += to_copy; + nbytes -= to_copy; + local_offset = 0; + } +} + +/** + * batadv_tp_send_msg() - send a single message + * @tp_vars: the private TP meter data for this session + * @src: source mac address + * @orig_node: the originator of the destination + * @seqno: sequence number of this packet + * @len: length of the entire packet + * @session: session identifier + * @uid: local ICMP "socket" index + * @timestamp: timestamp in jiffies which is replied in ack + * + * Create and send a single TP Meter message. 
+ * + * Return: 0 on success, BATADV_TP_REASON_DST_UNREACHABLE if the destination is + * not reachable, BATADV_TP_REASON_MEMORY_ERROR if the packet couldn't be + * allocated + */ +static int batadv_tp_send_msg(struct batadv_tp_vars *tp_vars, const u8 *src, + struct batadv_orig_node *orig_node, + u32 seqno, size_t len, const u8 *session, + int uid, u32 timestamp) +{ + struct batadv_icmp_tp_packet *icmp; + struct sk_buff *skb; + int r; + u8 *data; + size_t data_len; + + skb = netdev_alloc_skb_ip_align(NULL, len + ETH_HLEN); + if (unlikely(!skb)) + return BATADV_TP_REASON_MEMORY_ERROR; + + skb_reserve(skb, ETH_HLEN); + icmp = skb_put(skb, sizeof(*icmp)); + + /* fill the icmp header */ + ether_addr_copy(icmp->dst, orig_node->orig); + ether_addr_copy(icmp->orig, src); + icmp->version = BATADV_COMPAT_VERSION; + icmp->packet_type = BATADV_ICMP; + icmp->ttl = BATADV_TTL; + icmp->msg_type = BATADV_TP; + icmp->uid = uid; + + icmp->subtype = BATADV_TP_MSG; + memcpy(icmp->session, session, sizeof(icmp->session)); + icmp->seqno = htonl(seqno); + icmp->timestamp = htonl(timestamp); + + data_len = len - sizeof(*icmp); + data = skb_put(skb, data_len); + batadv_tp_fill_prerandom(tp_vars, data, data_len); + + r = batadv_send_skb_to_orig(skb, orig_node, NULL); + if (r == NET_XMIT_SUCCESS) + return 0; + + return BATADV_TP_REASON_CANT_SEND; +} + +/** + * batadv_tp_recv_ack() - ACK receiving function + * @bat_priv: the bat priv with all the soft interface information + * @skb: the buffer containing the received packet + * + * Process a received TP ACK packet + */ +static void batadv_tp_recv_ack(struct batadv_priv *bat_priv, + const struct sk_buff *skb) +{ + struct batadv_hard_iface *primary_if = NULL; + struct batadv_orig_node *orig_node = NULL; + const struct batadv_icmp_tp_packet *icmp; + struct batadv_tp_vars *tp_vars; + const unsigned char *dev_addr; + size_t packet_len, mss; + u32 rtt, recv_ack, cwnd; + + packet_len = BATADV_TP_PLEN; + mss = BATADV_TP_PLEN; + packet_len += sizeof(struct batadv_unicast_packet); + + icmp = (struct batadv_icmp_tp_packet *)skb->data; + + /* find the tp_vars */ + tp_vars = batadv_tp_list_find_session(bat_priv, icmp->orig, + icmp->session); + if (unlikely(!tp_vars)) + return; + + if (unlikely(atomic_read(&tp_vars->sending) == 0)) + goto out; + + /* old ACK? silently drop it.. */ + if (batadv_seq_before(ntohl(icmp->seqno), + (u32)atomic_read(&tp_vars->last_acked))) + goto out; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (unlikely(!primary_if)) + goto out; + + orig_node = batadv_orig_hash_find(bat_priv, icmp->orig); + if (unlikely(!orig_node)) + goto out; + + /* update RTO with the new sampled RTT, if any */ + rtt = jiffies_to_msecs(jiffies) - ntohl(icmp->timestamp); + if (icmp->timestamp && rtt) + batadv_tp_update_rto(tp_vars, rtt); + + /* ACK for new data... 
reset the timer */ + batadv_tp_reset_sender_timer(tp_vars); + + recv_ack = ntohl(icmp->seqno); + + /* check if this ACK is a duplicate */ + if (atomic_read(&tp_vars->last_acked) == recv_ack) { + atomic_inc(&tp_vars->dup_acks); + if (atomic_read(&tp_vars->dup_acks) != 3) + goto out; + + if (recv_ack >= tp_vars->recover) + goto out; + + /* if this is the third duplicate ACK do Fast Retransmit */ + batadv_tp_send_msg(tp_vars, primary_if->net_dev->dev_addr, + orig_node, recv_ack, packet_len, + icmp->session, icmp->uid, + jiffies_to_msecs(jiffies)); + + spin_lock_bh(&tp_vars->cwnd_lock); + + /* Fast Recovery */ + tp_vars->fast_recovery = true; + /* Set recover to the last outstanding seqno when Fast Recovery + * is entered. RFC6582, Section 3.2, step 1 + */ + tp_vars->recover = tp_vars->last_sent; + tp_vars->ss_threshold = tp_vars->cwnd >> 1; + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: Fast Recovery, (cur cwnd=%u) ss_thr=%u last_sent=%u recv_ack=%u\n", + tp_vars->cwnd, tp_vars->ss_threshold, + tp_vars->last_sent, recv_ack); + tp_vars->cwnd = batadv_tp_cwnd(tp_vars->ss_threshold, 3 * mss, + mss); + tp_vars->dec_cwnd = 0; + tp_vars->last_sent = recv_ack; + + spin_unlock_bh(&tp_vars->cwnd_lock); + } else { + /* count the acked data */ + atomic64_add(recv_ack - atomic_read(&tp_vars->last_acked), + &tp_vars->tot_sent); + /* reset the duplicate ACKs counter */ + atomic_set(&tp_vars->dup_acks, 0); + + if (tp_vars->fast_recovery) { + /* partial ACK */ + if (batadv_seq_before(recv_ack, tp_vars->recover)) { + /* this is another hole in the window. React + * immediately as specified by NewReno (see + * Section 3.2 of RFC6582 for details) + */ + dev_addr = primary_if->net_dev->dev_addr; + batadv_tp_send_msg(tp_vars, dev_addr, + orig_node, recv_ack, + packet_len, icmp->session, + icmp->uid, + jiffies_to_msecs(jiffies)); + tp_vars->cwnd = batadv_tp_cwnd(tp_vars->cwnd, + mss, mss); + } else { + tp_vars->fast_recovery = false; + /* set cwnd to the value of ss_threshold at the + * moment that Fast Recovery was entered. 
+ * RFC6582, Section 3.2, step 3 + */ + cwnd = batadv_tp_cwnd(tp_vars->ss_threshold, 0, + mss); + tp_vars->cwnd = cwnd; + } + goto move_twnd; + } + + if (recv_ack - atomic_read(&tp_vars->last_acked) >= mss) + batadv_tp_update_cwnd(tp_vars, mss); +move_twnd: + /* move the Transmit Window */ + atomic_set(&tp_vars->last_acked, recv_ack); + } + + wake_up(&tp_vars->more_bytes); +out: + batadv_hardif_put(primary_if); + batadv_orig_node_put(orig_node); + batadv_tp_vars_put(tp_vars); +} + +/** + * batadv_tp_avail() - check if congestion window is not full + * @tp_vars: the private data of the current TP meter session + * @payload_len: size of the payload of a single message + * + * Return: true when congestion window is not full, false otherwise + */ +static bool batadv_tp_avail(struct batadv_tp_vars *tp_vars, + size_t payload_len) +{ + u32 win_left, win_limit; + + win_limit = atomic_read(&tp_vars->last_acked) + tp_vars->cwnd; + win_left = win_limit - tp_vars->last_sent; + + return win_left >= payload_len; +} + +/** + * batadv_tp_wait_available() - wait until congestion window becomes free or + * timeout is reached + * @tp_vars: the private data of the current TP meter session + * @plen: size of the payload of a single message + * + * Return: 0 if the condition evaluated to false after the timeout elapsed, + * 1 if the condition evaluated to true after the timeout elapsed, the + * remaining jiffies (at least 1) if the condition evaluated to true before + * the timeout elapsed, or -ERESTARTSYS if it was interrupted by a signal. + */ +static int batadv_tp_wait_available(struct batadv_tp_vars *tp_vars, size_t plen) +{ + int ret; + + ret = wait_event_interruptible_timeout(tp_vars->more_bytes, + batadv_tp_avail(tp_vars, plen), + HZ / 10); + + return ret; +} + +/** + * batadv_tp_send() - main sending thread of a tp meter session + * @arg: address of the related tp_vars + * + * Return: nothing, this function never returns + */ +static int batadv_tp_send(void *arg) +{ + struct batadv_tp_vars *tp_vars = arg; + struct batadv_priv *bat_priv = tp_vars->bat_priv; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_orig_node *orig_node = NULL; + size_t payload_len, packet_len; + int err = 0; + + if (unlikely(tp_vars->role != BATADV_TP_SENDER)) { + err = BATADV_TP_REASON_DST_UNREACHABLE; + tp_vars->reason = err; + goto out; + } + + orig_node = batadv_orig_hash_find(bat_priv, tp_vars->other_end); + if (unlikely(!orig_node)) { + err = BATADV_TP_REASON_DST_UNREACHABLE; + tp_vars->reason = err; + goto out; + } + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (unlikely(!primary_if)) { + err = BATADV_TP_REASON_DST_UNREACHABLE; + tp_vars->reason = err; + goto out; + } + + /* assume that all the hard_interfaces have a correctly + * configured MTU, so use the soft_iface MTU as MSS. + * This might not be true and in that case the fragmentation + * should be used. 
+ * Now, try to send the packet as it is + */ + payload_len = BATADV_TP_PLEN; + BUILD_BUG_ON(sizeof(struct batadv_icmp_tp_packet) > BATADV_TP_PLEN); + + batadv_tp_reset_sender_timer(tp_vars); + + /* queue the worker in charge of terminating the test */ + queue_delayed_work(batadv_event_workqueue, &tp_vars->finish_work, + msecs_to_jiffies(tp_vars->test_length)); + + while (atomic_read(&tp_vars->sending) != 0) { + if (unlikely(!batadv_tp_avail(tp_vars, payload_len))) { + batadv_tp_wait_available(tp_vars, payload_len); + continue; + } + + /* to emulate normal unicast traffic, add to the payload len + * the size of the unicast header + */ + packet_len = payload_len + sizeof(struct batadv_unicast_packet); + + err = batadv_tp_send_msg(tp_vars, primary_if->net_dev->dev_addr, + orig_node, tp_vars->last_sent, + packet_len, + tp_vars->session, tp_vars->icmp_uid, + jiffies_to_msecs(jiffies)); + + /* something went wrong during the preparation/transmission */ + if (unlikely(err && err != BATADV_TP_REASON_CANT_SEND)) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: %s() cannot send packets (%d)\n", + __func__, err); + /* ensure nobody else tries to stop the thread now */ + if (atomic_dec_and_test(&tp_vars->sending)) + tp_vars->reason = err; + break; + } + + /* right-shift the TWND */ + if (!err) + tp_vars->last_sent += payload_len; + + cond_resched(); + } + +out: + batadv_hardif_put(primary_if); + batadv_orig_node_put(orig_node); + + batadv_tp_sender_end(bat_priv, tp_vars); + batadv_tp_sender_cleanup(bat_priv, tp_vars); + + batadv_tp_vars_put(tp_vars); + + return 0; +} + +/** + * batadv_tp_start_kthread() - start new thread which manages the tp meter + * sender + * @tp_vars: the private data of the current TP meter session + */ +static void batadv_tp_start_kthread(struct batadv_tp_vars *tp_vars) +{ + struct task_struct *kthread; + struct batadv_priv *bat_priv = tp_vars->bat_priv; + u32 session_cookie; + + kref_get(&tp_vars->refcount); + kthread = kthread_create(batadv_tp_send, tp_vars, "kbatadv_tp_meter"); + if (IS_ERR(kthread)) { + session_cookie = batadv_tp_session_cookie(tp_vars->session, + tp_vars->icmp_uid); + pr_err("batadv: cannot create tp meter kthread\n"); + batadv_tp_batctl_error_notify(BATADV_TP_REASON_MEMORY_ERROR, + tp_vars->other_end, + bat_priv, session_cookie); + + /* drop reserved reference for kthread */ + batadv_tp_vars_put(tp_vars); + + /* cleanup of failed tp meter variables */ + batadv_tp_sender_cleanup(bat_priv, tp_vars); + return; + } + + wake_up_process(kthread); +} + +/** + * batadv_tp_start() - start a new tp meter session + * @bat_priv: the bat priv with all the soft interface information + * @dst: the receiver MAC address + * @test_length: test length in milliseconds + * @cookie: session cookie + */ +void batadv_tp_start(struct batadv_priv *bat_priv, const u8 *dst, + u32 test_length, u32 *cookie) +{ + struct batadv_tp_vars *tp_vars; + u8 session_id[2]; + u8 icmp_uid; + u32 session_cookie; + + get_random_bytes(session_id, sizeof(session_id)); + get_random_bytes(&icmp_uid, 1); + session_cookie = batadv_tp_session_cookie(session_id, icmp_uid); + *cookie = session_cookie; + + /* look for an already existing test towards this node */ + spin_lock_bh(&bat_priv->tp_list_lock); + tp_vars = batadv_tp_list_find(bat_priv, dst); + if (tp_vars) { + spin_unlock_bh(&bat_priv->tp_list_lock); + batadv_tp_vars_put(tp_vars); + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: test to or from the same node already ongoing, aborting\n"); + 
batadv_tp_batctl_error_notify(BATADV_TP_REASON_ALREADY_ONGOING, + dst, bat_priv, session_cookie); + return; + } + + if (!atomic_add_unless(&bat_priv->tp_num, 1, BATADV_TP_MAX_NUM)) { + spin_unlock_bh(&bat_priv->tp_list_lock); + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: too many ongoing sessions, aborting (SEND)\n"); + batadv_tp_batctl_error_notify(BATADV_TP_REASON_TOO_MANY, dst, + bat_priv, session_cookie); + return; + } + + tp_vars = kmalloc(sizeof(*tp_vars), GFP_ATOMIC); + if (!tp_vars) { + spin_unlock_bh(&bat_priv->tp_list_lock); + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: %s cannot allocate list elements\n", + __func__); + batadv_tp_batctl_error_notify(BATADV_TP_REASON_MEMORY_ERROR, + dst, bat_priv, session_cookie); + return; + } + + /* initialize tp_vars */ + ether_addr_copy(tp_vars->other_end, dst); + kref_init(&tp_vars->refcount); + tp_vars->role = BATADV_TP_SENDER; + atomic_set(&tp_vars->sending, 1); + memcpy(tp_vars->session, session_id, sizeof(session_id)); + tp_vars->icmp_uid = icmp_uid; + + tp_vars->last_sent = BATADV_TP_FIRST_SEQ; + atomic_set(&tp_vars->last_acked, BATADV_TP_FIRST_SEQ); + tp_vars->fast_recovery = false; + tp_vars->recover = BATADV_TP_FIRST_SEQ; + + /* initialise the CWND to 3*MSS (Section 3.1 in RFC5681). + * For batman-adv the MSS is the size of the payload received by the + * soft_interface, hence its MTU + */ + tp_vars->cwnd = BATADV_TP_PLEN * 3; + /* at the beginning initialise the SS threshold to the biggest possible + * window size, hence the AWND size + */ + tp_vars->ss_threshold = BATADV_TP_AWND; + + /* RTO initial value is 3 seconds. + * Details in Section 2.1 of RFC6298 + */ + tp_vars->rto = 1000; + tp_vars->srtt = 0; + tp_vars->rttvar = 0; + + atomic64_set(&tp_vars->tot_sent, 0); + + kref_get(&tp_vars->refcount); + timer_setup(&tp_vars->timer, batadv_tp_sender_timeout, 0); + + tp_vars->bat_priv = bat_priv; + tp_vars->start_time = jiffies; + + init_waitqueue_head(&tp_vars->more_bytes); + + spin_lock_init(&tp_vars->unacked_lock); + INIT_LIST_HEAD(&tp_vars->unacked_list); + + spin_lock_init(&tp_vars->cwnd_lock); + + tp_vars->prerandom_offset = 0; + spin_lock_init(&tp_vars->prerandom_lock); + + kref_get(&tp_vars->refcount); + hlist_add_head_rcu(&tp_vars->list, &bat_priv->tp_list); + spin_unlock_bh(&bat_priv->tp_list_lock); + + tp_vars->test_length = test_length; + if (!tp_vars->test_length) + tp_vars->test_length = BATADV_TP_DEF_TEST_LENGTH; + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: starting throughput meter towards %pM (length=%ums)\n", + dst, test_length); + + /* init work item for finished tp tests */ + INIT_DELAYED_WORK(&tp_vars->finish_work, batadv_tp_sender_finish); + + /* start tp kthread. 
This way the write() call issued from userspace can + * happily return and avoid to block + */ + batadv_tp_start_kthread(tp_vars); + + /* don't return reference to new tp_vars */ + batadv_tp_vars_put(tp_vars); +} + +/** + * batadv_tp_stop() - stop currently running tp meter session + * @bat_priv: the bat priv with all the soft interface information + * @dst: the receiver MAC address + * @return_value: reason for tp meter session stop + */ +void batadv_tp_stop(struct batadv_priv *bat_priv, const u8 *dst, + u8 return_value) +{ + struct batadv_orig_node *orig_node; + struct batadv_tp_vars *tp_vars; + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: stopping test towards %pM\n", dst); + + orig_node = batadv_orig_hash_find(bat_priv, dst); + if (!orig_node) + return; + + tp_vars = batadv_tp_list_find(bat_priv, orig_node->orig); + if (!tp_vars) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: trying to interrupt an already over connection\n"); + goto out; + } + + batadv_tp_sender_shutdown(tp_vars, return_value); + batadv_tp_vars_put(tp_vars); +out: + batadv_orig_node_put(orig_node); +} + +/** + * batadv_tp_reset_receiver_timer() - reset the receiver shutdown timer + * @tp_vars: the private data of the current TP meter session + * + * start the receiver shutdown timer or reset it if already started + */ +static void batadv_tp_reset_receiver_timer(struct batadv_tp_vars *tp_vars) +{ + mod_timer(&tp_vars->timer, + jiffies + msecs_to_jiffies(BATADV_TP_RECV_TIMEOUT)); +} + +/** + * batadv_tp_receiver_shutdown() - stop a tp meter receiver when timeout is + * reached without received ack + * @t: address to timer_list inside tp_vars + */ +static void batadv_tp_receiver_shutdown(struct timer_list *t) +{ + struct batadv_tp_vars *tp_vars = from_timer(tp_vars, t, timer); + struct batadv_tp_unacked *un, *safe; + struct batadv_priv *bat_priv; + + bat_priv = tp_vars->bat_priv; + + /* if there is recent activity rearm the timer */ + if (!batadv_has_timed_out(tp_vars->last_recv_time, + BATADV_TP_RECV_TIMEOUT)) { + /* reset the receiver shutdown timer */ + batadv_tp_reset_receiver_timer(tp_vars); + return; + } + + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Shutting down for inactivity (more than %dms) from %pM\n", + BATADV_TP_RECV_TIMEOUT, tp_vars->other_end); + + spin_lock_bh(&tp_vars->bat_priv->tp_list_lock); + hlist_del_rcu(&tp_vars->list); + spin_unlock_bh(&tp_vars->bat_priv->tp_list_lock); + + /* drop list reference */ + batadv_tp_vars_put(tp_vars); + + atomic_dec(&bat_priv->tp_num); + + spin_lock_bh(&tp_vars->unacked_lock); + list_for_each_entry_safe(un, safe, &tp_vars->unacked_list, list) { + list_del(&un->list); + kfree(un); + } + spin_unlock_bh(&tp_vars->unacked_lock); + + /* drop reference of timer */ + batadv_tp_vars_put(tp_vars); +} + +/** + * batadv_tp_send_ack() - send an ACK packet + * @bat_priv: the bat priv with all the soft interface information + * @dst: the mac address of the destination originator + * @seq: the sequence number to ACK + * @timestamp: the timestamp to echo back in the ACK + * @session: session identifier + * @socket_index: local ICMP socket identifier + * + * Return: 0 on success, a positive integer representing the reason of the + * failure otherwise + */ +static int batadv_tp_send_ack(struct batadv_priv *bat_priv, const u8 *dst, + u32 seq, __be32 timestamp, const u8 *session, + int socket_index) +{ + struct batadv_hard_iface *primary_if = NULL; + struct batadv_orig_node *orig_node; + struct batadv_icmp_tp_packet *icmp; + struct sk_buff *skb; + int r, ret; + + 
orig_node = batadv_orig_hash_find(bat_priv, dst); + if (unlikely(!orig_node)) { + ret = BATADV_TP_REASON_DST_UNREACHABLE; + goto out; + } + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (unlikely(!primary_if)) { + ret = BATADV_TP_REASON_DST_UNREACHABLE; + goto out; + } + + skb = netdev_alloc_skb_ip_align(NULL, sizeof(*icmp) + ETH_HLEN); + if (unlikely(!skb)) { + ret = BATADV_TP_REASON_MEMORY_ERROR; + goto out; + } + + skb_reserve(skb, ETH_HLEN); + icmp = skb_put(skb, sizeof(*icmp)); + icmp->packet_type = BATADV_ICMP; + icmp->version = BATADV_COMPAT_VERSION; + icmp->ttl = BATADV_TTL; + icmp->msg_type = BATADV_TP; + ether_addr_copy(icmp->dst, orig_node->orig); + ether_addr_copy(icmp->orig, primary_if->net_dev->dev_addr); + icmp->uid = socket_index; + + icmp->subtype = BATADV_TP_ACK; + memcpy(icmp->session, session, sizeof(icmp->session)); + icmp->seqno = htonl(seq); + icmp->timestamp = timestamp; + + /* send the ack */ + r = batadv_send_skb_to_orig(skb, orig_node, NULL); + if (unlikely(r < 0) || r == NET_XMIT_DROP) { + ret = BATADV_TP_REASON_DST_UNREACHABLE; + goto out; + } + ret = 0; + +out: + batadv_orig_node_put(orig_node); + batadv_hardif_put(primary_if); + + return ret; +} + +/** + * batadv_tp_handle_out_of_order() - store an out of order packet + * @tp_vars: the private data of the current TP meter session + * @skb: the buffer containing the received packet + * + * Store the out of order packet in the unacked list for late processing. This + * packets are kept in this list so that they can be ACKed at once as soon as + * all the previous packets have been received + * + * Return: true if the packed has been successfully processed, false otherwise + */ +static bool batadv_tp_handle_out_of_order(struct batadv_tp_vars *tp_vars, + const struct sk_buff *skb) +{ + const struct batadv_icmp_tp_packet *icmp; + struct batadv_tp_unacked *un, *new; + u32 payload_len; + bool added = false; + + new = kmalloc(sizeof(*new), GFP_ATOMIC); + if (unlikely(!new)) + return false; + + icmp = (struct batadv_icmp_tp_packet *)skb->data; + + new->seqno = ntohl(icmp->seqno); + payload_len = skb->len - sizeof(struct batadv_unicast_packet); + new->len = payload_len; + + spin_lock_bh(&tp_vars->unacked_lock); + /* if the list is empty immediately attach this new object */ + if (list_empty(&tp_vars->unacked_list)) { + list_add(&new->list, &tp_vars->unacked_list); + goto out; + } + + /* otherwise loop over the list and either drop the packet because this + * is a duplicate or store it at the right position. + * + * The iteration is done in the reverse way because it is likely that + * the last received packet (the one being processed now) has a bigger + * seqno than all the others already stored. + */ + list_for_each_entry_reverse(un, &tp_vars->unacked_list, list) { + /* check for duplicates */ + if (new->seqno == un->seqno) { + if (new->len > un->len) + un->len = new->len; + kfree(new); + added = true; + break; + } + + /* look for the right position */ + if (batadv_seq_before(new->seqno, un->seqno)) + continue; + + /* as soon as an entry having a bigger seqno is found, the new + * one is attached _after_ it. 
In this way the list is kept in + * ascending order + */ + list_add_tail(&new->list, &un->list); + added = true; + break; + } + + /* received packet with smallest seqno out of order; add it to front */ + if (!added) + list_add(&new->list, &tp_vars->unacked_list); + +out: + spin_unlock_bh(&tp_vars->unacked_lock); + + return true; +} + +/** + * batadv_tp_ack_unordered() - update number received bytes in current stream + * without gaps + * @tp_vars: the private data of the current TP meter session + */ +static void batadv_tp_ack_unordered(struct batadv_tp_vars *tp_vars) +{ + struct batadv_tp_unacked *un, *safe; + u32 to_ack; + + /* go through the unacked packet list and possibly ACK them as + * well + */ + spin_lock_bh(&tp_vars->unacked_lock); + list_for_each_entry_safe(un, safe, &tp_vars->unacked_list, list) { + /* the list is ordered, therefore it is possible to stop as soon + * there is a gap between the last acked seqno and the seqno of + * the packet under inspection + */ + if (batadv_seq_before(tp_vars->last_recv, un->seqno)) + break; + + to_ack = un->seqno + un->len - tp_vars->last_recv; + + if (batadv_seq_before(tp_vars->last_recv, un->seqno + un->len)) + tp_vars->last_recv += to_ack; + + list_del(&un->list); + kfree(un); + } + spin_unlock_bh(&tp_vars->unacked_lock); +} + +/** + * batadv_tp_init_recv() - return matching or create new receiver tp_vars + * @bat_priv: the bat priv with all the soft interface information + * @icmp: received icmp tp msg + * + * Return: corresponding tp_vars or NULL on errors + */ +static struct batadv_tp_vars * +batadv_tp_init_recv(struct batadv_priv *bat_priv, + const struct batadv_icmp_tp_packet *icmp) +{ + struct batadv_tp_vars *tp_vars; + + spin_lock_bh(&bat_priv->tp_list_lock); + tp_vars = batadv_tp_list_find_session(bat_priv, icmp->orig, + icmp->session); + if (tp_vars) + goto out_unlock; + + if (!atomic_add_unless(&bat_priv->tp_num, 1, BATADV_TP_MAX_NUM)) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: too many ongoing sessions, aborting (RECV)\n"); + goto out_unlock; + } + + tp_vars = kmalloc(sizeof(*tp_vars), GFP_ATOMIC); + if (!tp_vars) + goto out_unlock; + + ether_addr_copy(tp_vars->other_end, icmp->orig); + tp_vars->role = BATADV_TP_RECEIVER; + memcpy(tp_vars->session, icmp->session, sizeof(tp_vars->session)); + tp_vars->last_recv = BATADV_TP_FIRST_SEQ; + tp_vars->bat_priv = bat_priv; + kref_init(&tp_vars->refcount); + + spin_lock_init(&tp_vars->unacked_lock); + INIT_LIST_HEAD(&tp_vars->unacked_list); + + kref_get(&tp_vars->refcount); + hlist_add_head_rcu(&tp_vars->list, &bat_priv->tp_list); + + kref_get(&tp_vars->refcount); + timer_setup(&tp_vars->timer, batadv_tp_receiver_shutdown, 0); + + batadv_tp_reset_receiver_timer(tp_vars); + +out_unlock: + spin_unlock_bh(&bat_priv->tp_list_lock); + + return tp_vars; +} + +/** + * batadv_tp_recv_msg() - process a single data message + * @bat_priv: the bat priv with all the soft interface information + * @skb: the buffer containing the received packet + * + * Process a received TP MSG packet + */ +static void batadv_tp_recv_msg(struct batadv_priv *bat_priv, + const struct sk_buff *skb) +{ + const struct batadv_icmp_tp_packet *icmp; + struct batadv_tp_vars *tp_vars; + size_t packet_size; + u32 seqno; + + icmp = (struct batadv_icmp_tp_packet *)skb->data; + + seqno = ntohl(icmp->seqno); + /* check if this is the first seqno. This means that if the + * first packet is lost, the tp meter does not work anymore! 
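The receiver keeps out-of-order payloads sorted by seqno and later slides last_recv forward across every entry that has become contiguous, stopping at the first gap. A small sketch of that sweep, assuming batadv_seq_before() is the usual wrap-safe signed-difference comparison:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* wrap-safe "x comes before y" on a 32-bit sequence space (assumption:
 * batadv_seq_before() is the customary signed-difference test)
 */
static bool seq_before(uint32_t x, uint32_t y)
{
        return (int32_t)(x - y) < 0;
}

int main(void)
{
        /* sorted (seqno, len) pairs standing in for the unacked_list */
        struct { uint32_t seqno, len; } unacked[] = {
                { 1000, 200 }, { 1200, 200 }, { 1600, 200 },
        };
        uint32_t last_recv = 1000;

        for (size_t i = 0; i < sizeof(unacked) / sizeof(unacked[0]); i++) {
                if (seq_before(last_recv, unacked[i].seqno))
                        break;  /* gap found: stop, later data stays queued */
                if (seq_before(last_recv, unacked[i].seqno + unacked[i].len))
                        last_recv = unacked[i].seqno + unacked[i].len;
        }
        printf("last_recv=%u\n", last_recv);    /* 1400: 1600 lies beyond the gap */
        return 0;
}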
+ */ + if (seqno == BATADV_TP_FIRST_SEQ) { + tp_vars = batadv_tp_init_recv(bat_priv, icmp); + if (!tp_vars) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: seqno != BATADV_TP_FIRST_SEQ cannot initiate connection\n"); + goto out; + } + } else { + tp_vars = batadv_tp_list_find_session(bat_priv, icmp->orig, + icmp->session); + if (!tp_vars) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Unexpected packet from %pM!\n", + icmp->orig); + goto out; + } + } + + if (unlikely(tp_vars->role != BATADV_TP_RECEIVER)) { + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Meter: dropping packet: not expected (role=%u)\n", + tp_vars->role); + goto out; + } + + tp_vars->last_recv_time = jiffies; + + /* if the packet is a duplicate, it may be the case that an ACK has been + * lost. Resend the ACK + */ + if (batadv_seq_before(seqno, tp_vars->last_recv)) + goto send_ack; + + /* if the packet is out of order enqueue it */ + if (ntohl(icmp->seqno) != tp_vars->last_recv) { + /* exit immediately (and do not send any ACK) if the packet has + * not been enqueued correctly + */ + if (!batadv_tp_handle_out_of_order(tp_vars, skb)) + goto out; + + /* send a duplicate ACK */ + goto send_ack; + } + + /* if everything was fine count the ACKed bytes */ + packet_size = skb->len - sizeof(struct batadv_unicast_packet); + tp_vars->last_recv += packet_size; + + /* check if this ordered message filled a gap.... */ + batadv_tp_ack_unordered(tp_vars); + +send_ack: + /* send the ACK. If the received packet was out of order, the ACK that + * is going to be sent is a duplicate (the sender will count them and + * possibly enter Fast Retransmit as soon as it has reached 3) + */ + batadv_tp_send_ack(bat_priv, icmp->orig, tp_vars->last_recv, + icmp->timestamp, icmp->session, icmp->uid); +out: + batadv_tp_vars_put(tp_vars); +} + +/** + * batadv_tp_meter_recv() - main TP Meter receiving function + * @bat_priv: the bat priv with all the soft interface information + * @skb: the buffer containing the received packet + */ +void batadv_tp_meter_recv(struct batadv_priv *bat_priv, struct sk_buff *skb) +{ + struct batadv_icmp_tp_packet *icmp; + + icmp = (struct batadv_icmp_tp_packet *)skb->data; + + switch (icmp->subtype) { + case BATADV_TP_MSG: + batadv_tp_recv_msg(bat_priv, skb); + break; + case BATADV_TP_ACK: + batadv_tp_recv_ack(bat_priv, skb); + break; + default: + batadv_dbg(BATADV_DBG_TP_METER, bat_priv, + "Received unknown TP Metric packet type %u\n", + icmp->subtype); + } + consume_skb(skb); +} + +/** + * batadv_tp_meter_init() - initialize global tp_meter structures + */ +void __init batadv_tp_meter_init(void) +{ + get_random_bytes(batadv_tp_prerandom, sizeof(batadv_tp_prerandom)); +} diff --git a/net/batman-adv/tp_meter.h b/net/batman-adv/tp_meter.h new file mode 100644 index 000000000..f0046d366 --- /dev/null +++ b/net/batman-adv/tp_meter.h @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. 
contributors: + * + * Edo Monticelli, Antonio Quartulli + */ + +#ifndef _NET_BATMAN_ADV_TP_METER_H_ +#define _NET_BATMAN_ADV_TP_METER_H_ + +#include "main.h" + +#include +#include + +void batadv_tp_meter_init(void); +void batadv_tp_start(struct batadv_priv *bat_priv, const u8 *dst, + u32 test_length, u32 *cookie); +void batadv_tp_stop(struct batadv_priv *bat_priv, const u8 *dst, + u8 return_value); +void batadv_tp_meter_recv(struct batadv_priv *bat_priv, struct sk_buff *skb); + +#endif /* _NET_BATMAN_ADV_TP_METER_H_ */ diff --git a/net/batman-adv/trace.c b/net/batman-adv/trace.c new file mode 100644 index 000000000..ec8b95190 --- /dev/null +++ b/net/batman-adv/trace.c @@ -0,0 +1,8 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Sven Eckelmann + */ + +#define CREATE_TRACE_POINTS +#include "trace.h" diff --git a/net/batman-adv/trace.h b/net/batman-adv/trace.h new file mode 100644 index 000000000..5dd52bc5c --- /dev/null +++ b/net/batman-adv/trace.h @@ -0,0 +1,64 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Sven Eckelmann + */ + +#if !defined(_NET_BATMAN_ADV_TRACE_H_) || defined(TRACE_HEADER_MULTI_READ) +#define _NET_BATMAN_ADV_TRACE_H_ + +#include "main.h" + +#include +#include +#include +#include + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM batadv + +/* provide dummy function when tracing is disabled */ +#if !defined(CONFIG_BATMAN_ADV_TRACING) + +#undef TRACE_EVENT +#define TRACE_EVENT(name, proto, ...) \ + static inline void trace_ ## name(proto) {} + +#endif /* CONFIG_BATMAN_ADV_TRACING */ + +TRACE_EVENT(batadv_dbg, + + TP_PROTO(struct batadv_priv *bat_priv, + struct va_format *vaf), + + TP_ARGS(bat_priv, vaf), + + TP_STRUCT__entry( + __string(device, bat_priv->soft_iface->name) + __string(driver, KBUILD_MODNAME) + __vstring(msg, vaf->fmt, vaf->va) + ), + + TP_fast_assign( + __assign_str(device, bat_priv->soft_iface->name); + __assign_str(driver, KBUILD_MODNAME); + __assign_vstr(msg, vaf->fmt, vaf->va); + ), + + TP_printk( + "%s %s %s", + __get_str(driver), + __get_str(device), + __get_str(msg) + ) +); + +#endif /* _NET_BATMAN_ADV_TRACE_H_ || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace + +/* This part must be outside protection */ +#include diff --git a/net/batman-adv/translation-table.c b/net/batman-adv/translation-table.c new file mode 100644 index 000000000..5d8cee747 --- /dev/null +++ b/net/batman-adv/translation-table.c @@ -0,0 +1,4290 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) B.A.T.M.A.N. 
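When CONFIG_BATMAN_ADV_TRACING is disabled, the trace.h header above redefines TRACE_EVENT() so that every trace_<name>() call site compiles into an empty inline stub. A toy userspace illustration of that expansion; the event name and the extra argument standing in for the TP_*() blocks are placeholders:

#include <stdio.h>

/* stand-in for the "tracing disabled" branch of trace.h: every event turns
 * into an empty static inline stub, so call sites cost nothing
 */
#define TRACE_EVENT(name, proto, ...) \
        static inline void trace_ ## name(proto) {}

/* hypothetical event; the third argument mimics the TP_*() blocks and is
 * swallowed by __VA_ARGS__
 */
TRACE_EVENT(batadv_dbg_demo, const char *msg, unused)

int main(void)
{
        trace_batadv_dbg_demo("compiled away when tracing is off");
        printf("no tracepoints emitted\n");
        return 0;
}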
contributors: + * + * Marek Lindner, Simon Wunderlich, Antonio Quartulli + */ + +#include "translation-table.h" +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "bridge_loop_avoidance.h" +#include "hard-interface.h" +#include "hash.h" +#include "log.h" +#include "netlink.h" +#include "originator.h" +#include "soft-interface.h" +#include "tvlv.h" + +static struct kmem_cache *batadv_tl_cache __read_mostly; +static struct kmem_cache *batadv_tg_cache __read_mostly; +static struct kmem_cache *batadv_tt_orig_cache __read_mostly; +static struct kmem_cache *batadv_tt_change_cache __read_mostly; +static struct kmem_cache *batadv_tt_req_cache __read_mostly; +static struct kmem_cache *batadv_tt_roam_cache __read_mostly; + +/* hash class keys */ +static struct lock_class_key batadv_tt_local_hash_lock_class_key; +static struct lock_class_key batadv_tt_global_hash_lock_class_key; + +static void batadv_send_roam_adv(struct batadv_priv *bat_priv, u8 *client, + unsigned short vid, + struct batadv_orig_node *orig_node); +static void batadv_tt_purge(struct work_struct *work); +static void +batadv_tt_global_del_orig_list(struct batadv_tt_global_entry *tt_global_entry); +static void batadv_tt_global_del(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const unsigned char *addr, + unsigned short vid, const char *message, + bool roaming); + +/** + * batadv_compare_tt() - check if two TT entries are the same + * @node: the list element pointer of the first TT entry + * @data2: pointer to the tt_common_entry of the second TT entry + * + * Compare the MAC address and the VLAN ID of the two TT entries and check if + * they are the same TT client. + * Return: true if the two TT clients are the same, false otherwise + */ +static bool batadv_compare_tt(const struct hlist_node *node, const void *data2) +{ + const void *data1 = container_of(node, struct batadv_tt_common_entry, + hash_entry); + const struct batadv_tt_common_entry *tt1 = data1; + const struct batadv_tt_common_entry *tt2 = data2; + + return (tt1->vid == tt2->vid) && batadv_compare_eth(data1, data2); +} + +/** + * batadv_choose_tt() - return the index of the tt entry in the hash table + * @data: pointer to the tt_common_entry object to map + * @size: the size of the hash table + * + * Return: the hash index where the object represented by 'data' should be + * stored at. + */ +static inline u32 batadv_choose_tt(const void *data, u32 size) +{ + const struct batadv_tt_common_entry *tt; + u32 hash = 0; + + tt = data; + hash = jhash(&tt->addr, ETH_ALEN, hash); + hash = jhash(&tt->vid, sizeof(tt->vid), hash); + + return hash % size; +} + +/** + * batadv_tt_hash_find() - look for a client in the given hash table + * @hash: the hash table to search + * @addr: the mac address of the client to look for + * @vid: VLAN identifier + * + * Return: a pointer to the tt_common struct belonging to the searched client if + * found, NULL otherwise. 
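batadv_choose_tt() above chains two jhash() rounds, first over the 6-byte MAC and then over the VLAN id, so the same client seen on different VLANs lands in different buckets. A userspace approximation of that chaining; the hash routine below is a trivial stand-in for the kernel's jhash(), only the structure matters:

#include <stdint.h>
#include <stdio.h>

#define ETH_ALEN 6

/* trivial stand-in for the kernel's jhash(); only the chaining matters here */
static uint32_t toy_hash(const void *key, uint32_t len, uint32_t initval)
{
        const uint8_t *p = key;
        uint32_t h = initval ^ 0x9e3779b9u;

        while (len--)
                h = (h ^ *p++) * 0x01000193u;
        return h;
}

/* mirror of the two-stage bucket selection in batadv_choose_tt() */
static uint32_t choose_tt_bucket(const uint8_t addr[ETH_ALEN],
                                 uint16_t vid, uint32_t table_size)
{
        uint32_t hash = 0;

        hash = toy_hash(addr, ETH_ALEN, hash);          /* first the MAC ...    */
        hash = toy_hash(&vid, sizeof(vid), hash);       /* ... then the VLAN id */
        return hash % table_size;
}

int main(void)
{
        const uint8_t mac[ETH_ALEN] = { 0x02, 0x00, 0x00, 0xaa, 0xbb, 0xcc };

        printf("vid 0  -> bucket %u\n", choose_tt_bucket(mac, 0, 1024));
        printf("vid 42 -> bucket %u\n", choose_tt_bucket(mac, 42, 1024));
        return 0;
}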
+ */ +static struct batadv_tt_common_entry * +batadv_tt_hash_find(struct batadv_hashtable *hash, const u8 *addr, + unsigned short vid) +{ + struct hlist_head *head; + struct batadv_tt_common_entry to_search, *tt, *tt_tmp = NULL; + u32 index; + + if (!hash) + return NULL; + + ether_addr_copy(to_search.addr, addr); + to_search.vid = vid; + + index = batadv_choose_tt(&to_search, hash->size); + head = &hash->table[index]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tt, head, hash_entry) { + if (!batadv_compare_eth(tt, addr)) + continue; + + if (tt->vid != vid) + continue; + + if (!kref_get_unless_zero(&tt->refcount)) + continue; + + tt_tmp = tt; + break; + } + rcu_read_unlock(); + + return tt_tmp; +} + +/** + * batadv_tt_local_hash_find() - search the local table for a given client + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client to look for + * @vid: VLAN identifier + * + * Return: a pointer to the corresponding tt_local_entry struct if the client is + * found, NULL otherwise. + */ +static struct batadv_tt_local_entry * +batadv_tt_local_hash_find(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid) +{ + struct batadv_tt_common_entry *tt_common_entry; + struct batadv_tt_local_entry *tt_local_entry = NULL; + + tt_common_entry = batadv_tt_hash_find(bat_priv->tt.local_hash, addr, + vid); + if (tt_common_entry) + tt_local_entry = container_of(tt_common_entry, + struct batadv_tt_local_entry, + common); + return tt_local_entry; +} + +/** + * batadv_tt_global_hash_find() - search the global table for a given client + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client to look for + * @vid: VLAN identifier + * + * Return: a pointer to the corresponding tt_global_entry struct if the client + * is found, NULL otherwise. 
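batadv_tt_hash_find() above walks the bucket under the RCU read lock only and therefore has to take a reference with kref_get_unless_zero() before returning; an entry whose refcount already hit zero is treated as gone. A userspace analogue of that "get a reference unless the object is dying" step, built on C11 atomics rather than the kernel's kref:

#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

/* take a reference only if the object is not already on its way to be freed */
static bool ref_get_unless_zero(atomic_uint *refcount)
{
        unsigned int old = atomic_load(refcount);

        while (old != 0) {
                if (atomic_compare_exchange_weak(refcount, &old, old + 1))
                        return true;    /* reference taken */
        }
        return false;                   /* object is dying, pretend it is gone */
}

int main(void)
{
        atomic_uint live = 2, dying = 0;

        printf("live:  %d\n", ref_get_unless_zero(&live));      /* 1 */
        printf("dying: %d\n", ref_get_unless_zero(&dying));     /* 0 */
        return 0;
}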
+ */ +struct batadv_tt_global_entry * +batadv_tt_global_hash_find(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid) +{ + struct batadv_tt_common_entry *tt_common_entry; + struct batadv_tt_global_entry *tt_global_entry = NULL; + + tt_common_entry = batadv_tt_hash_find(bat_priv->tt.global_hash, addr, + vid); + if (tt_common_entry) + tt_global_entry = container_of(tt_common_entry, + struct batadv_tt_global_entry, + common); + return tt_global_entry; +} + +/** + * batadv_tt_local_entry_free_rcu() - free the tt_local_entry + * @rcu: rcu pointer of the tt_local_entry + */ +static void batadv_tt_local_entry_free_rcu(struct rcu_head *rcu) +{ + struct batadv_tt_local_entry *tt_local_entry; + + tt_local_entry = container_of(rcu, struct batadv_tt_local_entry, + common.rcu); + + kmem_cache_free(batadv_tl_cache, tt_local_entry); +} + +/** + * batadv_tt_local_entry_release() - release tt_local_entry from lists and queue + * for free after rcu grace period + * @ref: kref pointer of the nc_node + */ +static void batadv_tt_local_entry_release(struct kref *ref) +{ + struct batadv_tt_local_entry *tt_local_entry; + + tt_local_entry = container_of(ref, struct batadv_tt_local_entry, + common.refcount); + + batadv_softif_vlan_put(tt_local_entry->vlan); + + call_rcu(&tt_local_entry->common.rcu, batadv_tt_local_entry_free_rcu); +} + +/** + * batadv_tt_local_entry_put() - decrement the tt_local_entry refcounter and + * possibly release it + * @tt_local_entry: tt_local_entry to be free'd + */ +static void +batadv_tt_local_entry_put(struct batadv_tt_local_entry *tt_local_entry) +{ + if (!tt_local_entry) + return; + + kref_put(&tt_local_entry->common.refcount, + batadv_tt_local_entry_release); +} + +/** + * batadv_tt_global_entry_free_rcu() - free the tt_global_entry + * @rcu: rcu pointer of the tt_global_entry + */ +static void batadv_tt_global_entry_free_rcu(struct rcu_head *rcu) +{ + struct batadv_tt_global_entry *tt_global_entry; + + tt_global_entry = container_of(rcu, struct batadv_tt_global_entry, + common.rcu); + + kmem_cache_free(batadv_tg_cache, tt_global_entry); +} + +/** + * batadv_tt_global_entry_release() - release tt_global_entry from lists and + * queue for free after rcu grace period + * @ref: kref pointer of the nc_node + */ +void batadv_tt_global_entry_release(struct kref *ref) +{ + struct batadv_tt_global_entry *tt_global_entry; + + tt_global_entry = container_of(ref, struct batadv_tt_global_entry, + common.refcount); + + batadv_tt_global_del_orig_list(tt_global_entry); + + call_rcu(&tt_global_entry->common.rcu, batadv_tt_global_entry_free_rcu); +} + +/** + * batadv_tt_global_hash_count() - count the number of orig entries + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client to count entries for + * @vid: VLAN identifier + * + * Return: the number of originators advertising the given address/data + * (excluding our self). 
+ */ +int batadv_tt_global_hash_count(struct batadv_priv *bat_priv, + const u8 *addr, unsigned short vid) +{ + struct batadv_tt_global_entry *tt_global_entry; + int count; + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt_global_entry) + return 0; + + count = atomic_read(&tt_global_entry->orig_list_count); + batadv_tt_global_entry_put(tt_global_entry); + + return count; +} + +/** + * batadv_tt_local_size_mod() - change the size by v of the local table + * identified by vid + * @bat_priv: the bat priv with all the soft interface information + * @vid: the VLAN identifier of the sub-table to change + * @v: the amount to sum to the local table size + */ +static void batadv_tt_local_size_mod(struct batadv_priv *bat_priv, + unsigned short vid, int v) +{ + struct batadv_softif_vlan *vlan; + + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (!vlan) + return; + + atomic_add(v, &vlan->tt.num_entries); + + batadv_softif_vlan_put(vlan); +} + +/** + * batadv_tt_local_size_inc() - increase by one the local table size for the + * given vid + * @bat_priv: the bat priv with all the soft interface information + * @vid: the VLAN identifier + */ +static void batadv_tt_local_size_inc(struct batadv_priv *bat_priv, + unsigned short vid) +{ + batadv_tt_local_size_mod(bat_priv, vid, 1); +} + +/** + * batadv_tt_local_size_dec() - decrease by one the local table size for the + * given vid + * @bat_priv: the bat priv with all the soft interface information + * @vid: the VLAN identifier + */ +static void batadv_tt_local_size_dec(struct batadv_priv *bat_priv, + unsigned short vid) +{ + batadv_tt_local_size_mod(bat_priv, vid, -1); +} + +/** + * batadv_tt_global_size_mod() - change the size by v of the global table + * for orig_node identified by vid + * @orig_node: the originator for which the table has to be modified + * @vid: the VLAN identifier + * @v: the amount to sum to the global table size + */ +static void batadv_tt_global_size_mod(struct batadv_orig_node *orig_node, + unsigned short vid, int v) +{ + struct batadv_orig_node_vlan *vlan; + + vlan = batadv_orig_node_vlan_new(orig_node, vid); + if (!vlan) + return; + + if (atomic_add_return(v, &vlan->tt.num_entries) == 0) { + spin_lock_bh(&orig_node->vlan_list_lock); + if (!hlist_unhashed(&vlan->list)) { + hlist_del_init_rcu(&vlan->list); + batadv_orig_node_vlan_put(vlan); + } + spin_unlock_bh(&orig_node->vlan_list_lock); + } + + batadv_orig_node_vlan_put(vlan); +} + +/** + * batadv_tt_global_size_inc() - increase by one the global table size for the + * given vid + * @orig_node: the originator which global table size has to be decreased + * @vid: the vlan identifier + */ +static void batadv_tt_global_size_inc(struct batadv_orig_node *orig_node, + unsigned short vid) +{ + batadv_tt_global_size_mod(orig_node, vid, 1); +} + +/** + * batadv_tt_global_size_dec() - decrease by one the global table size for the + * given vid + * @orig_node: the originator which global table size has to be decreased + * @vid: the vlan identifier + */ +static void batadv_tt_global_size_dec(struct batadv_orig_node *orig_node, + unsigned short vid) +{ + batadv_tt_global_size_mod(orig_node, vid, -1); +} + +/** + * batadv_tt_orig_list_entry_free_rcu() - free the orig_entry + * @rcu: rcu pointer of the orig_entry + */ +static void batadv_tt_orig_list_entry_free_rcu(struct rcu_head *rcu) +{ + struct batadv_tt_orig_list_entry *orig_entry; + + orig_entry = container_of(rcu, struct batadv_tt_orig_list_entry, rcu); + + kmem_cache_free(batadv_tt_orig_cache, 
orig_entry); +} + +/** + * batadv_tt_orig_list_entry_release() - release tt orig entry from lists and + * queue for free after rcu grace period + * @ref: kref pointer of the tt orig entry + */ +static void batadv_tt_orig_list_entry_release(struct kref *ref) +{ + struct batadv_tt_orig_list_entry *orig_entry; + + orig_entry = container_of(ref, struct batadv_tt_orig_list_entry, + refcount); + + batadv_orig_node_put(orig_entry->orig_node); + call_rcu(&orig_entry->rcu, batadv_tt_orig_list_entry_free_rcu); +} + +/** + * batadv_tt_orig_list_entry_put() - decrement the tt orig entry refcounter and + * possibly release it + * @orig_entry: tt orig entry to be free'd + */ +static void +batadv_tt_orig_list_entry_put(struct batadv_tt_orig_list_entry *orig_entry) +{ + if (!orig_entry) + return; + + kref_put(&orig_entry->refcount, batadv_tt_orig_list_entry_release); +} + +/** + * batadv_tt_local_event() - store a local TT event (ADD/DEL) + * @bat_priv: the bat priv with all the soft interface information + * @tt_local_entry: the TT entry involved in the event + * @event_flags: flags to store in the event structure + */ +static void batadv_tt_local_event(struct batadv_priv *bat_priv, + struct batadv_tt_local_entry *tt_local_entry, + u8 event_flags) +{ + struct batadv_tt_change_node *tt_change_node, *entry, *safe; + struct batadv_tt_common_entry *common = &tt_local_entry->common; + u8 flags = common->flags | event_flags; + bool event_removed = false; + bool del_op_requested, del_op_entry; + + tt_change_node = kmem_cache_alloc(batadv_tt_change_cache, GFP_ATOMIC); + if (!tt_change_node) + return; + + tt_change_node->change.flags = flags; + memset(tt_change_node->change.reserved, 0, + sizeof(tt_change_node->change.reserved)); + ether_addr_copy(tt_change_node->change.addr, common->addr); + tt_change_node->change.vid = htons(common->vid); + + del_op_requested = flags & BATADV_TT_CLIENT_DEL; + + /* check for ADD+DEL or DEL+ADD events */ + spin_lock_bh(&bat_priv->tt.changes_list_lock); + list_for_each_entry_safe(entry, safe, &bat_priv->tt.changes_list, + list) { + if (!batadv_compare_eth(entry->change.addr, common->addr)) + continue; + + /* DEL+ADD in the same orig interval have no effect and can be + * removed to avoid silly behaviour on the receiver side. The + * other way around (ADD+DEL) can happen in case of roaming of + * a client still in the NEW state. Roaming of NEW clients is + * now possible due to automatically recognition of "temporary" + * clients + */ + del_op_entry = entry->change.flags & BATADV_TT_CLIENT_DEL; + if (!del_op_requested && del_op_entry) + goto del; + if (del_op_requested && !del_op_entry) + goto del; + + /* this is a second add in the same originator interval. It + * means that flags have been changed: update them! + */ + if (!del_op_requested && !del_op_entry) + entry->change.flags = flags; + + continue; +del: + list_del(&entry->list); + kmem_cache_free(batadv_tt_change_cache, entry); + kmem_cache_free(batadv_tt_change_cache, tt_change_node); + event_removed = true; + goto unlock; + } + + /* track the change in the OGMinterval list */ + list_add_tail(&tt_change_node->list, &bat_priv->tt.changes_list); + +unlock: + spin_unlock_bh(&bat_priv->tt.changes_list_lock); + + if (event_removed) + atomic_dec(&bat_priv->tt.local_changes); + else + atomic_inc(&bat_priv->tt.local_changes); +} + +/** + * batadv_tt_len() - compute length in bytes of given number of tt changes + * @changes_num: number of tt changes + * + * Return: computed length in bytes. 
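batadv_tt_len() and batadv_tt_entries() (just below) only convert between a count of TT changes and the bytes those change records occupy on the wire. A worked example; the 12-byte record size matches a packed flags/reserved/addr/vid layout but is an assumption here, the driver takes it from sizeof() on the packet.h struct:

#include <stdio.h>

/* assumed size of one on-wire TT change record (flags + reserved + addr + vid) */
#define TT_CHANGE_LEN 12

static int tt_len(int changes_num)                      /* bytes needed for N changes */
{
        return changes_num * TT_CHANGE_LEN;
}

static unsigned short tt_entries(unsigned short len)    /* changes fitting into len   */
{
        return len / TT_CHANGE_LEN;
}

int main(void)
{
        printf("10 changes -> %d bytes\n", tt_len(10));         /* 120 */
        printf("100 bytes  -> %u entries\n", tt_entries(100));  /* 8   */
        return 0;
}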
+ */ +static int batadv_tt_len(int changes_num) +{ + return changes_num * sizeof(struct batadv_tvlv_tt_change); +} + +/** + * batadv_tt_entries() - compute the number of entries fitting in tt_len bytes + * @tt_len: available space + * + * Return: the number of entries. + */ +static u16 batadv_tt_entries(u16 tt_len) +{ + return tt_len / batadv_tt_len(1); +} + +/** + * batadv_tt_local_table_transmit_size() - calculates the local translation + * table size when transmitted over the air + * @bat_priv: the bat priv with all the soft interface information + * + * Return: local translation table size in bytes. + */ +static int batadv_tt_local_table_transmit_size(struct batadv_priv *bat_priv) +{ + u16 num_vlan = 0; + u16 tt_local_entries = 0; + struct batadv_softif_vlan *vlan; + int hdr_size; + + rcu_read_lock(); + hlist_for_each_entry_rcu(vlan, &bat_priv->softif_vlan_list, list) { + num_vlan++; + tt_local_entries += atomic_read(&vlan->tt.num_entries); + } + rcu_read_unlock(); + + /* header size of tvlv encapsulated tt response payload */ + hdr_size = sizeof(struct batadv_unicast_tvlv_packet); + hdr_size += sizeof(struct batadv_tvlv_hdr); + hdr_size += sizeof(struct batadv_tvlv_tt_data); + hdr_size += num_vlan * sizeof(struct batadv_tvlv_tt_vlan_data); + + return hdr_size + batadv_tt_len(tt_local_entries); +} + +static int batadv_tt_local_init(struct batadv_priv *bat_priv) +{ + if (bat_priv->tt.local_hash) + return 0; + + bat_priv->tt.local_hash = batadv_hash_new(1024); + + if (!bat_priv->tt.local_hash) + return -ENOMEM; + + batadv_hash_set_lock_class(bat_priv->tt.local_hash, + &batadv_tt_local_hash_lock_class_key); + + return 0; +} + +static void batadv_tt_global_free(struct batadv_priv *bat_priv, + struct batadv_tt_global_entry *tt_global, + const char *message) +{ + struct batadv_tt_global_entry *tt_removed_entry; + struct hlist_node *tt_removed_node; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Deleting global tt entry %pM (vid: %d): %s\n", + tt_global->common.addr, + batadv_print_vid(tt_global->common.vid), message); + + tt_removed_node = batadv_hash_remove(bat_priv->tt.global_hash, + batadv_compare_tt, + batadv_choose_tt, + &tt_global->common); + if (!tt_removed_node) + return; + + /* drop reference of remove hash entry */ + tt_removed_entry = hlist_entry(tt_removed_node, + struct batadv_tt_global_entry, + common.hash_entry); + batadv_tt_global_entry_put(tt_removed_entry); +} + +/** + * batadv_tt_local_add() - add a new client to the local table or update an + * existing client + * @soft_iface: netdev struct of the mesh interface + * @addr: the mac address of the client to add + * @vid: VLAN identifier + * @ifindex: index of the interface where the client is connected to (useful to + * identify wireless clients) + * @mark: the value contained in the skb->mark field of the received packet (if + * any) + * + * Return: true if the client was successfully added, false otherwise. 
+ */ +bool batadv_tt_local_add(struct net_device *soft_iface, const u8 *addr, + unsigned short vid, int ifindex, u32 mark) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + struct batadv_tt_local_entry *tt_local; + struct batadv_tt_global_entry *tt_global = NULL; + struct net *net = dev_net(soft_iface); + struct batadv_softif_vlan *vlan; + struct net_device *in_dev = NULL; + struct batadv_hard_iface *in_hardif = NULL; + struct hlist_head *head; + struct batadv_tt_orig_list_entry *orig_entry; + int hash_added, table_size, packet_size_max; + bool ret = false; + bool roamed_back = false; + u8 remote_flags; + u32 match_mark; + + if (ifindex != BATADV_NULL_IFINDEX) + in_dev = dev_get_by_index(net, ifindex); + + if (in_dev) + in_hardif = batadv_hardif_get_by_netdev(in_dev); + + tt_local = batadv_tt_local_hash_find(bat_priv, addr, vid); + + if (!is_multicast_ether_addr(addr)) + tt_global = batadv_tt_global_hash_find(bat_priv, addr, vid); + + if (tt_local) { + tt_local->last_seen = jiffies; + if (tt_local->common.flags & BATADV_TT_CLIENT_PENDING) { + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Re-adding pending client %pM (vid: %d)\n", + addr, batadv_print_vid(vid)); + /* whatever the reason why the PENDING flag was set, + * this is a client which was enqueued to be removed in + * this orig_interval. Since it popped up again, the + * flag can be reset like it was never enqueued + */ + tt_local->common.flags &= ~BATADV_TT_CLIENT_PENDING; + goto add_event; + } + + if (tt_local->common.flags & BATADV_TT_CLIENT_ROAM) { + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Roaming client %pM (vid: %d) came back to its original location\n", + addr, batadv_print_vid(vid)); + /* the ROAM flag is set because this client roamed away + * and the node got a roaming_advertisement message. Now + * that the client popped up again at its original + * location such flag can be unset + */ + tt_local->common.flags &= ~BATADV_TT_CLIENT_ROAM; + roamed_back = true; + } + goto check_roaming; + } + + /* Ignore the client if we cannot send it in a full table response. 
*/ + table_size = batadv_tt_local_table_transmit_size(bat_priv); + table_size += batadv_tt_len(1); + packet_size_max = atomic_read(&bat_priv->packet_size_max); + if (table_size > packet_size_max) { + net_ratelimited_function(batadv_info, soft_iface, + "Local translation table size (%i) exceeds maximum packet size (%i); Ignoring new local tt entry: %pM\n", + table_size, packet_size_max, addr); + goto out; + } + + tt_local = kmem_cache_alloc(batadv_tl_cache, GFP_ATOMIC); + if (!tt_local) + goto out; + + /* increase the refcounter of the related vlan */ + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (!vlan) { + net_ratelimited_function(batadv_info, soft_iface, + "adding TT local entry %pM to non-existent VLAN %d\n", + addr, batadv_print_vid(vid)); + kmem_cache_free(batadv_tl_cache, tt_local); + tt_local = NULL; + goto out; + } + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Creating new local tt entry: %pM (vid: %d, ttvn: %d)\n", + addr, batadv_print_vid(vid), + (u8)atomic_read(&bat_priv->tt.vn)); + + ether_addr_copy(tt_local->common.addr, addr); + /* The local entry has to be marked as NEW to avoid to send it in + * a full table response going out before the next ttvn increment + * (consistency check) + */ + tt_local->common.flags = BATADV_TT_CLIENT_NEW; + tt_local->common.vid = vid; + if (batadv_is_wifi_hardif(in_hardif)) + tt_local->common.flags |= BATADV_TT_CLIENT_WIFI; + kref_init(&tt_local->common.refcount); + tt_local->last_seen = jiffies; + tt_local->common.added_at = tt_local->last_seen; + tt_local->vlan = vlan; + + /* the batman interface mac and multicast addresses should never be + * purged + */ + if (batadv_compare_eth(addr, soft_iface->dev_addr) || + is_multicast_ether_addr(addr)) + tt_local->common.flags |= BATADV_TT_CLIENT_NOPURGE; + + kref_get(&tt_local->common.refcount); + hash_added = batadv_hash_add(bat_priv->tt.local_hash, batadv_compare_tt, + batadv_choose_tt, &tt_local->common, + &tt_local->common.hash_entry); + + if (unlikely(hash_added != 0)) { + /* remove the reference for the hash */ + batadv_tt_local_entry_put(tt_local); + goto out; + } + +add_event: + batadv_tt_local_event(bat_priv, tt_local, BATADV_NO_FLAGS); + +check_roaming: + /* Check whether it is a roaming, but don't do anything if the roaming + * process has already been handled + */ + if (tt_global && !(tt_global->common.flags & BATADV_TT_CLIENT_ROAM)) { + /* These node are probably going to update their tt table */ + head = &tt_global->orig_list; + rcu_read_lock(); + hlist_for_each_entry_rcu(orig_entry, head, list) { + batadv_send_roam_adv(bat_priv, tt_global->common.addr, + tt_global->common.vid, + orig_entry->orig_node); + } + rcu_read_unlock(); + if (roamed_back) { + batadv_tt_global_free(bat_priv, tt_global, + "Roaming canceled"); + } else { + /* The global entry has to be marked as ROAMING and + * has to be kept for consistency purpose + */ + tt_global->common.flags |= BATADV_TT_CLIENT_ROAM; + tt_global->roam_at = jiffies; + } + } + + /* store the current remote flags before altering them. 
This helps + * understanding is flags are changing or not + */ + remote_flags = tt_local->common.flags & BATADV_TT_REMOTE_MASK; + + if (batadv_is_wifi_hardif(in_hardif)) + tt_local->common.flags |= BATADV_TT_CLIENT_WIFI; + else + tt_local->common.flags &= ~BATADV_TT_CLIENT_WIFI; + + /* check the mark in the skb: if it's equal to the configured + * isolation_mark, it means the packet is coming from an isolated + * non-mesh client + */ + match_mark = (mark & bat_priv->isolation_mark_mask); + if (bat_priv->isolation_mark_mask && + match_mark == bat_priv->isolation_mark) + tt_local->common.flags |= BATADV_TT_CLIENT_ISOLA; + else + tt_local->common.flags &= ~BATADV_TT_CLIENT_ISOLA; + + /* if any "dynamic" flag has been modified, resend an ADD event for this + * entry so that all the nodes can get the new flags + */ + if (remote_flags ^ (tt_local->common.flags & BATADV_TT_REMOTE_MASK)) + batadv_tt_local_event(bat_priv, tt_local, BATADV_NO_FLAGS); + + ret = true; +out: + batadv_hardif_put(in_hardif); + dev_put(in_dev); + batadv_tt_local_entry_put(tt_local); + batadv_tt_global_entry_put(tt_global); + return ret; +} + +/** + * batadv_tt_prepare_tvlv_global_data() - prepare the TVLV TT header to send + * within a TT Response directed to another node + * @orig_node: originator for which the TT data has to be prepared + * @tt_data: uninitialised pointer to the address of the TVLV buffer + * @tt_change: uninitialised pointer to the address of the area where the TT + * changed can be stored + * @tt_len: pointer to the length to reserve to the tt_change. if -1 this + * function reserves the amount of space needed to send the entire global TT + * table. In case of success the value is updated with the real amount of + * reserved bytes + * Allocate the needed amount of memory for the entire TT TVLV and write its + * header made up of one tvlv_tt_data object and a series of tvlv_tt_vlan_data + * objects, one per active VLAN served by the originator node. + * + * Return: the size of the allocated buffer or 0 in case of failure. 
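The skb->mark test above is a masked comparison, so only the bits selected by isolation_mark_mask have to match the configured isolation_mark; with the mask unset no client is ever marked isolated. A standalone sketch of the same check with made-up mark values:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* true if the packet mark flags the client as isolated, as in
 * batadv_tt_local_add(): compare only the bits covered by the mask
 */
static bool client_is_isolated(uint32_t skb_mark, uint32_t isolation_mark,
                               uint32_t isolation_mark_mask)
{
        uint32_t match_mark = skb_mark & isolation_mark_mask;

        return isolation_mark_mask && match_mark == isolation_mark;
}

int main(void)
{
        /* example policy: bits 16..23 of the mark carry the isolation tag 0x42 */
        const uint32_t mark = 0x00420000, mask = 0x00ff0000;

        printf("%d\n", client_is_isolated(0x00421234, mark, mask)); /* 1 */
        printf("%d\n", client_is_isolated(0x00111234, mark, mask)); /* 0 */
        printf("%d\n", client_is_isolated(0x00421234, mark, 0));    /* 0: mask unset */
        return 0;
}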
+ */ +static u16 +batadv_tt_prepare_tvlv_global_data(struct batadv_orig_node *orig_node, + struct batadv_tvlv_tt_data **tt_data, + struct batadv_tvlv_tt_change **tt_change, + s32 *tt_len) +{ + u16 num_vlan = 0; + u16 num_entries = 0; + u16 change_offset; + u16 tvlv_len; + struct batadv_tvlv_tt_vlan_data *tt_vlan; + struct batadv_orig_node_vlan *vlan; + u8 *tt_change_ptr; + + spin_lock_bh(&orig_node->vlan_list_lock); + hlist_for_each_entry(vlan, &orig_node->vlan_list, list) { + num_vlan++; + num_entries += atomic_read(&vlan->tt.num_entries); + } + + change_offset = sizeof(**tt_data); + change_offset += num_vlan * sizeof(*tt_vlan); + + /* if tt_len is negative, allocate the space needed by the full table */ + if (*tt_len < 0) + *tt_len = batadv_tt_len(num_entries); + + tvlv_len = *tt_len; + tvlv_len += change_offset; + + *tt_data = kmalloc(tvlv_len, GFP_ATOMIC); + if (!*tt_data) { + *tt_len = 0; + goto out; + } + + (*tt_data)->flags = BATADV_NO_FLAGS; + (*tt_data)->ttvn = atomic_read(&orig_node->last_ttvn); + (*tt_data)->num_vlan = htons(num_vlan); + + tt_vlan = (struct batadv_tvlv_tt_vlan_data *)(*tt_data + 1); + hlist_for_each_entry(vlan, &orig_node->vlan_list, list) { + tt_vlan->vid = htons(vlan->vid); + tt_vlan->crc = htonl(vlan->tt.crc); + tt_vlan->reserved = 0; + + tt_vlan++; + } + + tt_change_ptr = (u8 *)*tt_data + change_offset; + *tt_change = (struct batadv_tvlv_tt_change *)tt_change_ptr; + +out: + spin_unlock_bh(&orig_node->vlan_list_lock); + return tvlv_len; +} + +/** + * batadv_tt_prepare_tvlv_local_data() - allocate and prepare the TT TVLV for + * this node + * @bat_priv: the bat priv with all the soft interface information + * @tt_data: uninitialised pointer to the address of the TVLV buffer + * @tt_change: uninitialised pointer to the address of the area where the TT + * changes can be stored + * @tt_len: pointer to the length to reserve to the tt_change. if -1 this + * function reserves the amount of space needed to send the entire local TT + * table. In case of success the value is updated with the real amount of + * reserved bytes + * + * Allocate the needed amount of memory for the entire TT TVLV and write its + * header made up by one tvlv_tt_data object and a series of tvlv_tt_vlan_data + * objects, one per active VLAN. + * + * Return: the size of the allocated buffer or 0 in case of failure. 
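Both batadv_tt_prepare_tvlv_*_data() helpers lay the buffer out as one tvlv_tt_data header, then num_vlan tvlv_tt_vlan_data records, then the raw change records, and derive the tt_change pointer from that offset. A sketch of the offset arithmetic with placeholder record sizes (the real ones come from sizeof() on the packet.h structs):

#include <stdint.h>
#include <stdio.h>

/* placeholder on-wire sizes; the driver uses sizeof() on the packet.h structs */
#define TT_DATA_LEN      4      /* flags + ttvn + num_vlan       */
#define TT_VLAN_DATA_LEN 8      /* crc + vid + reserved          */
#define TT_CHANGE_LEN    12     /* flags + reserved + addr + vid */

int main(void)
{
        uint16_t num_vlan = 2, num_entries = 10;

        /* same shape as the change_offset/tvlv_len computation above */
        uint16_t change_offset = TT_DATA_LEN + num_vlan * TT_VLAN_DATA_LEN;
        uint16_t tt_len = num_entries * TT_CHANGE_LEN;
        uint16_t tvlv_len = change_offset + tt_len;

        printf("header ends at %u, changes take %u, total %u bytes\n",
               change_offset, tt_len, tvlv_len);        /* 20, 120, 140 */
        return 0;
}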
+ */ +static u16 +batadv_tt_prepare_tvlv_local_data(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_data **tt_data, + struct batadv_tvlv_tt_change **tt_change, + s32 *tt_len) +{ + struct batadv_tvlv_tt_vlan_data *tt_vlan; + struct batadv_softif_vlan *vlan; + u16 num_vlan = 0; + u16 vlan_entries = 0; + u16 total_entries = 0; + u16 tvlv_len; + u8 *tt_change_ptr; + int change_offset; + + spin_lock_bh(&bat_priv->softif_vlan_list_lock); + hlist_for_each_entry(vlan, &bat_priv->softif_vlan_list, list) { + vlan_entries = atomic_read(&vlan->tt.num_entries); + if (vlan_entries < 1) + continue; + + num_vlan++; + total_entries += vlan_entries; + } + + change_offset = sizeof(**tt_data); + change_offset += num_vlan * sizeof(*tt_vlan); + + /* if tt_len is negative, allocate the space needed by the full table */ + if (*tt_len < 0) + *tt_len = batadv_tt_len(total_entries); + + tvlv_len = *tt_len; + tvlv_len += change_offset; + + *tt_data = kmalloc(tvlv_len, GFP_ATOMIC); + if (!*tt_data) { + tvlv_len = 0; + goto out; + } + + (*tt_data)->flags = BATADV_NO_FLAGS; + (*tt_data)->ttvn = atomic_read(&bat_priv->tt.vn); + (*tt_data)->num_vlan = htons(num_vlan); + + tt_vlan = (struct batadv_tvlv_tt_vlan_data *)(*tt_data + 1); + hlist_for_each_entry(vlan, &bat_priv->softif_vlan_list, list) { + vlan_entries = atomic_read(&vlan->tt.num_entries); + if (vlan_entries < 1) + continue; + + tt_vlan->vid = htons(vlan->vid); + tt_vlan->crc = htonl(vlan->tt.crc); + tt_vlan->reserved = 0; + + tt_vlan++; + } + + tt_change_ptr = (u8 *)*tt_data + change_offset; + *tt_change = (struct batadv_tvlv_tt_change *)tt_change_ptr; + +out: + spin_unlock_bh(&bat_priv->softif_vlan_list_lock); + return tvlv_len; +} + +/** + * batadv_tt_tvlv_container_update() - update the translation table tvlv + * container after local tt changes have been committed + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_tt_tvlv_container_update(struct batadv_priv *bat_priv) +{ + struct batadv_tt_change_node *entry, *safe; + struct batadv_tvlv_tt_data *tt_data; + struct batadv_tvlv_tt_change *tt_change; + int tt_diff_len, tt_change_len = 0; + int tt_diff_entries_num = 0; + int tt_diff_entries_count = 0; + u16 tvlv_len; + + tt_diff_entries_num = atomic_read(&bat_priv->tt.local_changes); + tt_diff_len = batadv_tt_len(tt_diff_entries_num); + + /* if we have too many changes for one packet don't send any + * and wait for the tt table request which will be fragmented + */ + if (tt_diff_len > bat_priv->soft_iface->mtu) + tt_diff_len = 0; + + tvlv_len = batadv_tt_prepare_tvlv_local_data(bat_priv, &tt_data, + &tt_change, &tt_diff_len); + if (!tvlv_len) + return; + + tt_data->flags = BATADV_TT_OGM_DIFF; + + if (tt_diff_len == 0) + goto container_register; + + spin_lock_bh(&bat_priv->tt.changes_list_lock); + atomic_set(&bat_priv->tt.local_changes, 0); + + list_for_each_entry_safe(entry, safe, &bat_priv->tt.changes_list, + list) { + if (tt_diff_entries_count < tt_diff_entries_num) { + memcpy(tt_change + tt_diff_entries_count, + &entry->change, + sizeof(struct batadv_tvlv_tt_change)); + tt_diff_entries_count++; + } + list_del(&entry->list); + kmem_cache_free(batadv_tt_change_cache, entry); + } + spin_unlock_bh(&bat_priv->tt.changes_list_lock); + + /* Keep the buffer for possible tt_request */ + spin_lock_bh(&bat_priv->tt.last_changeset_lock); + kfree(bat_priv->tt.last_changeset); + bat_priv->tt.last_changeset_len = 0; + bat_priv->tt.last_changeset = NULL; + tt_change_len = batadv_tt_len(tt_diff_entries_count); + /* check whether 
this new OGM has no changes due to size problems */ + if (tt_diff_entries_count > 0) { + /* if kmalloc() fails we will reply with the full table + * instead of providing the diff + */ + bat_priv->tt.last_changeset = kzalloc(tt_diff_len, GFP_ATOMIC); + if (bat_priv->tt.last_changeset) { + memcpy(bat_priv->tt.last_changeset, + tt_change, tt_change_len); + bat_priv->tt.last_changeset_len = tt_diff_len; + } + } + spin_unlock_bh(&bat_priv->tt.last_changeset_lock); + +container_register: + batadv_tvlv_container_register(bat_priv, BATADV_TVLV_TT, 1, tt_data, + tvlv_len); + kfree(tt_data); +} + +/** + * batadv_tt_local_dump_entry() - Dump one TT local entry into a message + * @msg :Netlink message to dump into + * @portid: Port making netlink request + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @common: tt local & tt global common data + * + * Return: Error code, or 0 on success + */ +static int +batadv_tt_local_dump_entry(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_tt_common_entry *common) +{ + void *hdr; + struct batadv_softif_vlan *vlan; + struct batadv_tt_local_entry *local; + unsigned int last_seen_msecs; + u32 crc; + + local = container_of(common, struct batadv_tt_local_entry, common); + last_seen_msecs = jiffies_to_msecs(jiffies - local->last_seen); + + vlan = batadv_softif_vlan_get(bat_priv, common->vid); + if (!vlan) + return 0; + + crc = vlan->tt.crc; + + batadv_softif_vlan_put(vlan); + + hdr = genlmsg_put(msg, portid, cb->nlh->nlmsg_seq, + &batadv_netlink_family, NLM_F_MULTI, + BATADV_CMD_GET_TRANSTABLE_LOCAL); + if (!hdr) + return -ENOBUFS; + + genl_dump_check_consistent(cb, hdr); + + if (nla_put(msg, BATADV_ATTR_TT_ADDRESS, ETH_ALEN, common->addr) || + nla_put_u32(msg, BATADV_ATTR_TT_CRC32, crc) || + nla_put_u16(msg, BATADV_ATTR_TT_VID, common->vid) || + nla_put_u32(msg, BATADV_ATTR_TT_FLAGS, common->flags)) + goto nla_put_failure; + + if (!(common->flags & BATADV_TT_CLIENT_NOPURGE) && + nla_put_u32(msg, BATADV_ATTR_LAST_SEEN_MSECS, last_seen_msecs)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_tt_local_dump_bucket() - Dump one TT local bucket into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @cb: Control block containing additional options + * @bat_priv: The bat priv with all the soft interface information + * @hash: hash to dump + * @bucket: bucket index to dump + * @idx_s: Number of entries to skip + * + * Return: Error code, or 0 on success + */ +static int +batadv_tt_local_dump_bucket(struct sk_buff *msg, u32 portid, + struct netlink_callback *cb, + struct batadv_priv *bat_priv, + struct batadv_hashtable *hash, unsigned int bucket, + int *idx_s) +{ + struct batadv_tt_common_entry *common; + int idx = 0; + + spin_lock_bh(&hash->list_locks[bucket]); + cb->seq = atomic_read(&hash->generation) << 1 | 1; + + hlist_for_each_entry(common, &hash->table[bucket], hash_entry) { + if (idx++ < *idx_s) + continue; + + if (batadv_tt_local_dump_entry(msg, portid, cb, bat_priv, + common)) { + spin_unlock_bh(&hash->list_locks[bucket]); + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + spin_unlock_bh(&hash->list_locks[bucket]); + + *idx_s = 0; + return 0; +} + +/** + * batadv_tt_local_dump() - Dump TT local entries into a message + * @msg: Netlink message to dump into + * @cb: Parameters from query 
+ * + * Return: Error code, or 0 on success + */ +int batadv_tt_local_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_hashtable *hash; + int ret; + int ifindex; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int portid = NETLINK_CB(cb->skb).portid; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + hash = bat_priv->tt.local_hash; + + while (bucket < hash->size) { + if (batadv_tt_local_dump_bucket(msg, portid, cb, bat_priv, + hash, bucket, &idx)) + break; + + bucket++; + } + + ret = msg->len; + + out: + batadv_hardif_put(primary_if); + dev_put(soft_iface); + + cb->args[0] = bucket; + cb->args[1] = idx; + + return ret; +} + +static void +batadv_tt_local_set_pending(struct batadv_priv *bat_priv, + struct batadv_tt_local_entry *tt_local_entry, + u16 flags, const char *message) +{ + batadv_tt_local_event(bat_priv, tt_local_entry, flags); + + /* The local client has to be marked as "pending to be removed" but has + * to be kept in the table in order to send it in a full table + * response issued before the net ttvn increment (consistency check) + */ + tt_local_entry->common.flags |= BATADV_TT_CLIENT_PENDING; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Local tt entry (%pM, vid: %d) pending to be removed: %s\n", + tt_local_entry->common.addr, + batadv_print_vid(tt_local_entry->common.vid), message); +} + +/** + * batadv_tt_local_remove() - logically remove an entry from the local table + * @bat_priv: the bat priv with all the soft interface information + * @addr: the MAC address of the client to remove + * @vid: VLAN identifier + * @message: message to append to the log on deletion + * @roaming: true if the deletion is due to a roaming event + * + * Return: the flags assigned to the local entry before being deleted + */ +u16 batadv_tt_local_remove(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid, const char *message, + bool roaming) +{ + struct batadv_tt_local_entry *tt_removed_entry; + struct batadv_tt_local_entry *tt_local_entry; + u16 flags, curr_flags = BATADV_NO_FLAGS; + struct hlist_node *tt_removed_node; + + tt_local_entry = batadv_tt_local_hash_find(bat_priv, addr, vid); + if (!tt_local_entry) + goto out; + + curr_flags = tt_local_entry->common.flags; + + flags = BATADV_TT_CLIENT_DEL; + /* if this global entry addition is due to a roaming, the node has to + * mark the local entry as "roamed" in order to correctly reroute + * packets later + */ + if (roaming) { + flags |= BATADV_TT_CLIENT_ROAM; + /* mark the local client as ROAMed */ + tt_local_entry->common.flags |= BATADV_TT_CLIENT_ROAM; + } + + if (!(tt_local_entry->common.flags & BATADV_TT_CLIENT_NEW)) { + batadv_tt_local_set_pending(bat_priv, tt_local_entry, flags, + message); + goto out; + } + /* if this client has been added right now, it is possible to + * immediately purge it + */ + batadv_tt_local_event(bat_priv, tt_local_entry, BATADV_TT_CLIENT_DEL); + + tt_removed_node = batadv_hash_remove(bat_priv->tt.local_hash, + 
batadv_compare_tt, + batadv_choose_tt, + &tt_local_entry->common); + if (!tt_removed_node) + goto out; + + /* drop reference of remove hash entry */ + tt_removed_entry = hlist_entry(tt_removed_node, + struct batadv_tt_local_entry, + common.hash_entry); + batadv_tt_local_entry_put(tt_removed_entry); + +out: + batadv_tt_local_entry_put(tt_local_entry); + + return curr_flags; +} + +/** + * batadv_tt_local_purge_list() - purge inactive tt local entries + * @bat_priv: the bat priv with all the soft interface information + * @head: pointer to the list containing the local tt entries + * @timeout: parameter deciding whether a given tt local entry is considered + * inactive or not + */ +static void batadv_tt_local_purge_list(struct batadv_priv *bat_priv, + struct hlist_head *head, + int timeout) +{ + struct batadv_tt_local_entry *tt_local_entry; + struct batadv_tt_common_entry *tt_common_entry; + struct hlist_node *node_tmp; + + hlist_for_each_entry_safe(tt_common_entry, node_tmp, head, + hash_entry) { + tt_local_entry = container_of(tt_common_entry, + struct batadv_tt_local_entry, + common); + if (tt_local_entry->common.flags & BATADV_TT_CLIENT_NOPURGE) + continue; + + /* entry already marked for deletion */ + if (tt_local_entry->common.flags & BATADV_TT_CLIENT_PENDING) + continue; + + if (!batadv_has_timed_out(tt_local_entry->last_seen, timeout)) + continue; + + batadv_tt_local_set_pending(bat_priv, tt_local_entry, + BATADV_TT_CLIENT_DEL, "timed out"); + } +} + +/** + * batadv_tt_local_purge() - purge inactive tt local entries + * @bat_priv: the bat priv with all the soft interface information + * @timeout: parameter deciding whether a given tt local entry is considered + * inactive or not + */ +static void batadv_tt_local_purge(struct batadv_priv *bat_priv, + int timeout) +{ + struct batadv_hashtable *hash = bat_priv->tt.local_hash; + struct hlist_head *head; + spinlock_t *list_lock; /* protects write access to the hash lists */ + u32 i; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + batadv_tt_local_purge_list(bat_priv, head, timeout); + spin_unlock_bh(list_lock); + } +} + +static void batadv_tt_local_table_free(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash; + spinlock_t *list_lock; /* protects write access to the hash lists */ + struct batadv_tt_common_entry *tt_common_entry; + struct batadv_tt_local_entry *tt_local; + struct hlist_node *node_tmp; + struct hlist_head *head; + u32 i; + + if (!bat_priv->tt.local_hash) + return; + + hash = bat_priv->tt.local_hash; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(tt_common_entry, node_tmp, + head, hash_entry) { + hlist_del_rcu(&tt_common_entry->hash_entry); + tt_local = container_of(tt_common_entry, + struct batadv_tt_local_entry, + common); + + batadv_tt_local_entry_put(tt_local); + } + spin_unlock_bh(list_lock); + } + + batadv_hash_destroy(hash); + + bat_priv->tt.local_hash = NULL; +} + +static int batadv_tt_global_init(struct batadv_priv *bat_priv) +{ + if (bat_priv->tt.global_hash) + return 0; + + bat_priv->tt.global_hash = batadv_hash_new(1024); + + if (!bat_priv->tt.global_hash) + return -ENOMEM; + + batadv_hash_set_lock_class(bat_priv->tt.global_hash, + &batadv_tt_global_hash_lock_class_key); + + return 0; +} + +static void batadv_tt_changes_list_free(struct batadv_priv *bat_priv) +{ + struct batadv_tt_change_node *entry, *safe; + 
+ spin_lock_bh(&bat_priv->tt.changes_list_lock); + + list_for_each_entry_safe(entry, safe, &bat_priv->tt.changes_list, + list) { + list_del(&entry->list); + kmem_cache_free(batadv_tt_change_cache, entry); + } + + atomic_set(&bat_priv->tt.local_changes, 0); + spin_unlock_bh(&bat_priv->tt.changes_list_lock); +} + +/** + * batadv_tt_global_orig_entry_find() - find a TT orig_list_entry + * @entry: the TT global entry where the orig_list_entry has to be + * extracted from + * @orig_node: the originator for which the orig_list_entry has to be found + * + * retrieve the orig_tt_list_entry belonging to orig_node from the + * batadv_tt_global_entry list + * + * Return: it with an increased refcounter, NULL if not found + */ +static struct batadv_tt_orig_list_entry * +batadv_tt_global_orig_entry_find(const struct batadv_tt_global_entry *entry, + const struct batadv_orig_node *orig_node) +{ + struct batadv_tt_orig_list_entry *tmp_orig_entry, *orig_entry = NULL; + const struct hlist_head *head; + + rcu_read_lock(); + head = &entry->orig_list; + hlist_for_each_entry_rcu(tmp_orig_entry, head, list) { + if (tmp_orig_entry->orig_node != orig_node) + continue; + if (!kref_get_unless_zero(&tmp_orig_entry->refcount)) + continue; + + orig_entry = tmp_orig_entry; + break; + } + rcu_read_unlock(); + + return orig_entry; +} + +/** + * batadv_tt_global_entry_has_orig() - check if a TT global entry is also + * handled by a given originator + * @entry: the TT global entry to check + * @orig_node: the originator to search in the list + * @flags: a pointer to store TT flags for the given @entry received + * from @orig_node + * + * find out if an orig_node is already in the list of a tt_global_entry. + * + * Return: true if found, false otherwise + */ +static bool +batadv_tt_global_entry_has_orig(const struct batadv_tt_global_entry *entry, + const struct batadv_orig_node *orig_node, + u8 *flags) +{ + struct batadv_tt_orig_list_entry *orig_entry; + bool found = false; + + orig_entry = batadv_tt_global_orig_entry_find(entry, orig_node); + if (orig_entry) { + found = true; + + if (flags) + *flags = orig_entry->flags; + + batadv_tt_orig_list_entry_put(orig_entry); + } + + return found; +} + +/** + * batadv_tt_global_sync_flags() - update TT sync flags + * @tt_global: the TT global entry to update sync flags in + * + * Updates the sync flag bits in the tt_global flag attribute with a logical + * OR of all sync flags from any of its TT orig entries. 
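+ *
+ * For example (flag values chosen for illustration only, assuming
+ * BATADV_TT_CLIENT_WIFI is part of BATADV_TT_SYNC_MASK while
+ * BATADV_TT_CLIENT_ROAM is not):
+ *
+ *   orig entry A flags:    BATADV_TT_CLIENT_WIFI
+ *   orig entry B flags:    BATADV_NO_FLAGS
+ *   common.flags before:   BATADV_TT_CLIENT_ROAM
+ *
+ *   sync bits    = A | B = BATADV_TT_CLIENT_WIFI
+ *   common.flags = sync bits | (common.flags & ~BATADV_TT_SYNC_MASK)
+ *                = BATADV_TT_CLIENT_WIFI | BATADV_TT_CLIENT_ROAM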
+ */ +static void +batadv_tt_global_sync_flags(struct batadv_tt_global_entry *tt_global) +{ + struct batadv_tt_orig_list_entry *orig_entry; + const struct hlist_head *head; + u16 flags = BATADV_NO_FLAGS; + + rcu_read_lock(); + head = &tt_global->orig_list; + hlist_for_each_entry_rcu(orig_entry, head, list) + flags |= orig_entry->flags; + rcu_read_unlock(); + + flags |= tt_global->common.flags & (~BATADV_TT_SYNC_MASK); + tt_global->common.flags = flags; +} + +/** + * batadv_tt_global_orig_entry_add() - add or update a TT orig entry + * @tt_global: the TT global entry to add an orig entry in + * @orig_node: the originator to add an orig entry for + * @ttvn: translation table version number of this changeset + * @flags: TT sync flags + */ +static void +batadv_tt_global_orig_entry_add(struct batadv_tt_global_entry *tt_global, + struct batadv_orig_node *orig_node, int ttvn, + u8 flags) +{ + struct batadv_tt_orig_list_entry *orig_entry; + + spin_lock_bh(&tt_global->list_lock); + + orig_entry = batadv_tt_global_orig_entry_find(tt_global, orig_node); + if (orig_entry) { + /* refresh the ttvn: the current value could be a bogus one that + * was added during a "temporary client detection" + */ + orig_entry->ttvn = ttvn; + orig_entry->flags = flags; + goto sync_flags; + } + + orig_entry = kmem_cache_zalloc(batadv_tt_orig_cache, GFP_ATOMIC); + if (!orig_entry) + goto out; + + INIT_HLIST_NODE(&orig_entry->list); + kref_get(&orig_node->refcount); + batadv_tt_global_size_inc(orig_node, tt_global->common.vid); + orig_entry->orig_node = orig_node; + orig_entry->ttvn = ttvn; + orig_entry->flags = flags; + kref_init(&orig_entry->refcount); + + kref_get(&orig_entry->refcount); + hlist_add_head_rcu(&orig_entry->list, + &tt_global->orig_list); + atomic_inc(&tt_global->orig_list_count); + +sync_flags: + batadv_tt_global_sync_flags(tt_global); +out: + batadv_tt_orig_list_entry_put(orig_entry); + + spin_unlock_bh(&tt_global->list_lock); +} + +/** + * batadv_tt_global_add() - add a new TT global entry or update an existing one + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the originator announcing the client + * @tt_addr: the mac address of the non-mesh client + * @vid: VLAN identifier + * @flags: TT flags that have to be set for this non-mesh client + * @ttvn: the tt version number ever announcing this non-mesh client + * + * Add a new TT global entry for the given originator. If the entry already + * exists add a new reference to the given originator (a global entry can have + * references to multiple originators) and adjust the flags attribute to reflect + * the function argument. + * If a TT local entry exists for this non-mesh client remove it. + * + * The caller must hold the orig_node refcount. 
+ * + * Return: true if the new entry has been added, false otherwise + */ +static bool batadv_tt_global_add(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const unsigned char *tt_addr, + unsigned short vid, u16 flags, u8 ttvn) +{ + struct batadv_tt_global_entry *tt_global_entry; + struct batadv_tt_local_entry *tt_local_entry; + bool ret = false; + int hash_added; + struct batadv_tt_common_entry *common; + u16 local_flags; + + /* ignore global entries from backbone nodes */ + if (batadv_bla_is_backbone_gw_orig(bat_priv, orig_node->orig, vid)) + return true; + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, tt_addr, vid); + tt_local_entry = batadv_tt_local_hash_find(bat_priv, tt_addr, vid); + + /* if the node already has a local client for this entry, it has to wait + * for a roaming advertisement instead of manually messing up the global + * table + */ + if ((flags & BATADV_TT_CLIENT_TEMP) && tt_local_entry && + !(tt_local_entry->common.flags & BATADV_TT_CLIENT_NEW)) + goto out; + + if (!tt_global_entry) { + tt_global_entry = kmem_cache_zalloc(batadv_tg_cache, + GFP_ATOMIC); + if (!tt_global_entry) + goto out; + + common = &tt_global_entry->common; + ether_addr_copy(common->addr, tt_addr); + common->vid = vid; + + if (!is_multicast_ether_addr(common->addr)) + common->flags = flags & (~BATADV_TT_SYNC_MASK); + + tt_global_entry->roam_at = 0; + /* node must store current time in case of roaming. This is + * needed to purge this entry out on timeout (if nobody claims + * it) + */ + if (flags & BATADV_TT_CLIENT_ROAM) + tt_global_entry->roam_at = jiffies; + kref_init(&common->refcount); + common->added_at = jiffies; + + INIT_HLIST_HEAD(&tt_global_entry->orig_list); + atomic_set(&tt_global_entry->orig_list_count, 0); + spin_lock_init(&tt_global_entry->list_lock); + + kref_get(&common->refcount); + hash_added = batadv_hash_add(bat_priv->tt.global_hash, + batadv_compare_tt, + batadv_choose_tt, common, + &common->hash_entry); + + if (unlikely(hash_added != 0)) { + /* remove the reference for the hash */ + batadv_tt_global_entry_put(tt_global_entry); + goto out_remove; + } + } else { + common = &tt_global_entry->common; + /* If there is already a global entry, we can use this one for + * our processing. + * But if we are trying to add a temporary client then here are + * two options at this point: + * 1) the global client is not a temporary client: the global + * client has to be left as it is, temporary information + * should never override any already known client state + * 2) the global client is a temporary client: purge the + * originator list and add the new one orig_entry + */ + if (flags & BATADV_TT_CLIENT_TEMP) { + if (!(common->flags & BATADV_TT_CLIENT_TEMP)) + goto out; + if (batadv_tt_global_entry_has_orig(tt_global_entry, + orig_node, NULL)) + goto out_remove; + batadv_tt_global_del_orig_list(tt_global_entry); + goto add_orig_entry; + } + + /* if the client was temporary added before receiving the first + * OGM announcing it, we have to clear the TEMP flag. Also, + * remove the previous temporary orig node and re-add it + * if required. If the orig entry changed, the new one which + * is a non-temporary entry is preferred. 
+ */ + if (common->flags & BATADV_TT_CLIENT_TEMP) { + batadv_tt_global_del_orig_list(tt_global_entry); + common->flags &= ~BATADV_TT_CLIENT_TEMP; + } + + /* the change can carry possible "attribute" flags like the + * TT_CLIENT_TEMP, therefore they have to be copied in the + * client entry + */ + if (!is_multicast_ether_addr(common->addr)) + common->flags |= flags & (~BATADV_TT_SYNC_MASK); + + /* If there is the BATADV_TT_CLIENT_ROAM flag set, there is only + * one originator left in the list and we previously received a + * delete + roaming change for this originator. + * + * We should first delete the old originator before adding the + * new one. + */ + if (common->flags & BATADV_TT_CLIENT_ROAM) { + batadv_tt_global_del_orig_list(tt_global_entry); + common->flags &= ~BATADV_TT_CLIENT_ROAM; + tt_global_entry->roam_at = 0; + } + } +add_orig_entry: + /* add the new orig_entry (if needed) or update it */ + batadv_tt_global_orig_entry_add(tt_global_entry, orig_node, ttvn, + flags & BATADV_TT_SYNC_MASK); + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Creating new global tt entry: %pM (vid: %d, via %pM)\n", + common->addr, batadv_print_vid(common->vid), + orig_node->orig); + ret = true; + +out_remove: + /* Do not remove multicast addresses from the local hash on + * global additions + */ + if (is_multicast_ether_addr(tt_addr)) + goto out; + + /* remove address from local hash if present */ + local_flags = batadv_tt_local_remove(bat_priv, tt_addr, vid, + "global tt received", + flags & BATADV_TT_CLIENT_ROAM); + tt_global_entry->common.flags |= local_flags & BATADV_TT_CLIENT_WIFI; + + if (!(flags & BATADV_TT_CLIENT_ROAM)) + /* this is a normal global add. Therefore the client is not in a + * roaming state anymore. + */ + tt_global_entry->common.flags &= ~BATADV_TT_CLIENT_ROAM; + +out: + batadv_tt_global_entry_put(tt_global_entry); + batadv_tt_local_entry_put(tt_local_entry); + return ret; +} + +/** + * batadv_transtable_best_orig() - Get best originator list entry from tt entry + * @bat_priv: the bat priv with all the soft interface information + * @tt_global_entry: global translation table entry to be analyzed + * + * This function assumes the caller holds rcu_read_lock(). + * Return: best originator list entry or NULL on errors. 
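+ *
+ * A minimal sketch of the expected calling pattern (mirroring
+ * batadv_transtable_search() further below; the variable names are only
+ * illustrative):
+ *
+ *   rcu_read_lock();
+ *   best_entry = batadv_transtable_best_orig(bat_priv, tt_global_entry);
+ *   if (best_entry)
+ *           orig_node = best_entry->orig_node;
+ *   if (orig_node && !kref_get_unless_zero(&orig_node->refcount))
+ *           orig_node = NULL;
+ *   rcu_read_unlock();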
+ */ +static struct batadv_tt_orig_list_entry * +batadv_transtable_best_orig(struct batadv_priv *bat_priv, + struct batadv_tt_global_entry *tt_global_entry) +{ + struct batadv_neigh_node *router, *best_router = NULL; + struct batadv_algo_ops *bao = bat_priv->algo_ops; + struct hlist_head *head; + struct batadv_tt_orig_list_entry *orig_entry, *best_entry = NULL; + + head = &tt_global_entry->orig_list; + hlist_for_each_entry_rcu(orig_entry, head, list) { + router = batadv_orig_router_get(orig_entry->orig_node, + BATADV_IF_DEFAULT); + if (!router) + continue; + + if (best_router && + bao->neigh.cmp(router, BATADV_IF_DEFAULT, best_router, + BATADV_IF_DEFAULT) <= 0) { + batadv_neigh_node_put(router); + continue; + } + + /* release the refcount for the "old" best */ + batadv_neigh_node_put(best_router); + + best_entry = orig_entry; + best_router = router; + } + + batadv_neigh_node_put(best_router); + + return best_entry; +} + +/** + * batadv_tt_global_dump_subentry() - Dump all TT local entries into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @common: tt local & tt global common data + * @orig: Originator node announcing a non-mesh client + * @best: Is the best originator for the TT entry + * + * Return: Error code, or 0 on success + */ +static int +batadv_tt_global_dump_subentry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_tt_common_entry *common, + struct batadv_tt_orig_list_entry *orig, + bool best) +{ + u16 flags = (common->flags & (~BATADV_TT_SYNC_MASK)) | orig->flags; + void *hdr; + struct batadv_orig_node_vlan *vlan; + u8 last_ttvn; + u32 crc; + + vlan = batadv_orig_node_vlan_get(orig->orig_node, + common->vid); + if (!vlan) + return 0; + + crc = vlan->tt.crc; + + batadv_orig_node_vlan_put(vlan); + + hdr = genlmsg_put(msg, portid, seq, &batadv_netlink_family, + NLM_F_MULTI, + BATADV_CMD_GET_TRANSTABLE_GLOBAL); + if (!hdr) + return -ENOBUFS; + + last_ttvn = atomic_read(&orig->orig_node->last_ttvn); + + if (nla_put(msg, BATADV_ATTR_TT_ADDRESS, ETH_ALEN, common->addr) || + nla_put(msg, BATADV_ATTR_ORIG_ADDRESS, ETH_ALEN, + orig->orig_node->orig) || + nla_put_u8(msg, BATADV_ATTR_TT_TTVN, orig->ttvn) || + nla_put_u8(msg, BATADV_ATTR_TT_LAST_TTVN, last_ttvn) || + nla_put_u32(msg, BATADV_ATTR_TT_CRC32, crc) || + nla_put_u16(msg, BATADV_ATTR_TT_VID, common->vid) || + nla_put_u32(msg, BATADV_ATTR_TT_FLAGS, flags)) + goto nla_put_failure; + + if (best && nla_put_flag(msg, BATADV_ATTR_FLAG_BEST)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +/** + * batadv_tt_global_dump_entry() - Dump one TT global entry into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @common: tt local & tt global common data + * @sub_s: Number of entries to skip + * + * This function assumes the caller holds rcu_read_lock(). 
+ * + * Return: Error code, or 0 on success + */ +static int +batadv_tt_global_dump_entry(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct batadv_tt_common_entry *common, int *sub_s) +{ + struct batadv_tt_orig_list_entry *orig_entry, *best_entry; + struct batadv_tt_global_entry *global; + struct hlist_head *head; + int sub = 0; + bool best; + + global = container_of(common, struct batadv_tt_global_entry, common); + best_entry = batadv_transtable_best_orig(bat_priv, global); + head = &global->orig_list; + + hlist_for_each_entry_rcu(orig_entry, head, list) { + if (sub++ < *sub_s) + continue; + + best = (orig_entry == best_entry); + + if (batadv_tt_global_dump_subentry(msg, portid, seq, common, + orig_entry, best)) { + *sub_s = sub - 1; + return -EMSGSIZE; + } + } + + *sub_s = 0; + return 0; +} + +/** + * batadv_tt_global_dump_bucket() - Dump one TT local bucket into a message + * @msg: Netlink message to dump into + * @portid: Port making netlink request + * @seq: Sequence number of netlink message + * @bat_priv: The bat priv with all the soft interface information + * @head: Pointer to the list containing the global tt entries + * @idx_s: Number of entries to skip + * @sub: Number of entries to skip + * + * Return: Error code, or 0 on success + */ +static int +batadv_tt_global_dump_bucket(struct sk_buff *msg, u32 portid, u32 seq, + struct batadv_priv *bat_priv, + struct hlist_head *head, int *idx_s, int *sub) +{ + struct batadv_tt_common_entry *common; + int idx = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(common, head, hash_entry) { + if (idx++ < *idx_s) + continue; + + if (batadv_tt_global_dump_entry(msg, portid, seq, bat_priv, + common, sub)) { + rcu_read_unlock(); + *idx_s = idx - 1; + return -EMSGSIZE; + } + } + rcu_read_unlock(); + + *idx_s = 0; + *sub = 0; + return 0; +} + +/** + * batadv_tt_global_dump() - Dump TT global entries into a message + * @msg: Netlink message to dump into + * @cb: Parameters from query + * + * Return: Error code, or length of message on success + */ +int batadv_tt_global_dump(struct sk_buff *msg, struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct net_device *soft_iface; + struct batadv_priv *bat_priv; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_hashtable *hash; + struct hlist_head *head; + int ret; + int ifindex; + int bucket = cb->args[0]; + int idx = cb->args[1]; + int sub = cb->args[2]; + int portid = NETLINK_CB(cb->skb).portid; + + ifindex = batadv_netlink_get_ifindex(cb->nlh, BATADV_ATTR_MESH_IFINDEX); + if (!ifindex) + return -EINVAL; + + soft_iface = dev_get_by_index(net, ifindex); + if (!soft_iface || !batadv_softif_is_valid(soft_iface)) { + ret = -ENODEV; + goto out; + } + + bat_priv = netdev_priv(soft_iface); + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if || primary_if->if_status != BATADV_IF_ACTIVE) { + ret = -ENOENT; + goto out; + } + + hash = bat_priv->tt.global_hash; + + while (bucket < hash->size) { + head = &hash->table[bucket]; + + if (batadv_tt_global_dump_bucket(msg, portid, + cb->nlh->nlmsg_seq, bat_priv, + head, &idx, &sub)) + break; + + bucket++; + } + + ret = msg->len; + + out: + batadv_hardif_put(primary_if); + dev_put(soft_iface); + + cb->args[0] = bucket; + cb->args[1] = idx; + cb->args[2] = sub; + + return ret; +} + +/** + * _batadv_tt_global_del_orig_entry() - remove and free an orig_entry + * @tt_global_entry: the global entry to remove the orig_entry from + * @orig_entry: the orig entry to remove and free 
+ *
+ * Remove an orig_entry from its list in the given tt_global_entry and
+ * free this orig_entry afterwards.
+ *
+ * Caller must hold tt_global_entry->list_lock and ensure orig_entry->list is
+ * part of a list.
+ */
+static void
+_batadv_tt_global_del_orig_entry(struct batadv_tt_global_entry *tt_global_entry,
+				 struct batadv_tt_orig_list_entry *orig_entry)
+{
+	lockdep_assert_held(&tt_global_entry->list_lock);
+
+	batadv_tt_global_size_dec(orig_entry->orig_node,
+				  tt_global_entry->common.vid);
+	atomic_dec(&tt_global_entry->orig_list_count);
+	/* requires holding tt_global_entry->list_lock and orig_entry->list
+	 * being part of a list
+	 */
+	hlist_del_rcu(&orig_entry->list);
+	batadv_tt_orig_list_entry_put(orig_entry);
+}
+
+/* deletes the orig list of a tt_global_entry */
+static void
+batadv_tt_global_del_orig_list(struct batadv_tt_global_entry *tt_global_entry)
+{
+	struct hlist_head *head;
+	struct hlist_node *safe;
+	struct batadv_tt_orig_list_entry *orig_entry;
+
+	spin_lock_bh(&tt_global_entry->list_lock);
+	head = &tt_global_entry->orig_list;
+	hlist_for_each_entry_safe(orig_entry, safe, head, list)
+		_batadv_tt_global_del_orig_entry(tt_global_entry, orig_entry);
+	spin_unlock_bh(&tt_global_entry->list_lock);
+}
+
+/**
+ * batadv_tt_global_del_orig_node() - remove orig_node from a global tt entry
+ * @bat_priv: the bat priv with all the soft interface information
+ * @tt_global_entry: the global entry to remove the orig_node from
+ * @orig_node: the originator announcing the client
+ * @message: message to append to the log on deletion
+ *
+ * Remove the given orig_node and its corresponding orig_entry from the given
+ * global tt entry.
+ */
+static void
+batadv_tt_global_del_orig_node(struct batadv_priv *bat_priv,
+			       struct batadv_tt_global_entry *tt_global_entry,
+			       struct batadv_orig_node *orig_node,
+			       const char *message)
+{
+	struct hlist_head *head;
+	struct hlist_node *safe;
+	struct batadv_tt_orig_list_entry *orig_entry;
+	unsigned short vid;
+
+	spin_lock_bh(&tt_global_entry->list_lock);
+	head = &tt_global_entry->orig_list;
+	hlist_for_each_entry_safe(orig_entry, safe, head, list) {
+		if (orig_entry->orig_node == orig_node) {
+			vid = tt_global_entry->common.vid;
+			batadv_dbg(BATADV_DBG_TT, bat_priv,
+				   "Deleting %pM from global tt entry %pM (vid: %d): %s\n",
+				   orig_node->orig,
+				   tt_global_entry->common.addr,
+				   batadv_print_vid(vid), message);
+			_batadv_tt_global_del_orig_entry(tt_global_entry,
+							 orig_entry);
+		}
+	}
+	spin_unlock_bh(&tt_global_entry->list_lock);
+}
+
+/* If the client is to be deleted, we check if it is the last originator entry
+ * within the tt_global entry. If yes, we set the BATADV_TT_CLIENT_ROAM flag
+ * and the timer, otherwise we simply remove the originator scheduled for
+ * deletion.
+ */
+static void
+batadv_tt_global_del_roaming(struct batadv_priv *bat_priv,
+			     struct batadv_tt_global_entry *tt_global_entry,
+			     struct batadv_orig_node *orig_node,
+			     const char *message)
+{
+	bool last_entry = true;
+	struct hlist_head *head;
+	struct batadv_tt_orig_list_entry *orig_entry;
+
+	/* no local entry exists, case 1:
+	 * Check if this is the last one or if other entries exist.
+	 */
+
+	rcu_read_lock();
+	head = &tt_global_entry->orig_list;
+	hlist_for_each_entry_rcu(orig_entry, head, list) {
+		if (orig_entry->orig_node != orig_node) {
+			last_entry = false;
+			break;
+		}
+	}
+	rcu_read_unlock();
+
+	if (last_entry) {
+		/* it's the last one, mark for roaming.
*/ + tt_global_entry->common.flags |= BATADV_TT_CLIENT_ROAM; + tt_global_entry->roam_at = jiffies; + } else { + /* there is another entry, we can simply delete this + * one and can still use the other one. + */ + batadv_tt_global_del_orig_node(bat_priv, tt_global_entry, + orig_node, message); + } +} + +/** + * batadv_tt_global_del() - remove a client from the global table + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: an originator serving this client + * @addr: the mac address of the client + * @vid: VLAN identifier + * @message: a message explaining the reason for deleting the client to print + * for debugging purpose + * @roaming: true if the deletion has been triggered by a roaming event + */ +static void batadv_tt_global_del(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const unsigned char *addr, unsigned short vid, + const char *message, bool roaming) +{ + struct batadv_tt_global_entry *tt_global_entry; + struct batadv_tt_local_entry *local_entry = NULL; + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt_global_entry) + goto out; + + if (!roaming) { + batadv_tt_global_del_orig_node(bat_priv, tt_global_entry, + orig_node, message); + + if (hlist_empty(&tt_global_entry->orig_list)) + batadv_tt_global_free(bat_priv, tt_global_entry, + message); + + goto out; + } + + /* if we are deleting a global entry due to a roam + * event, there are two possibilities: + * 1) the client roamed from node A to node B => if there + * is only one originator left for this client, we mark + * it with BATADV_TT_CLIENT_ROAM, we start a timer and we + * wait for node B to claim it. In case of timeout + * the entry is purged. + * + * If there are other originators left, we directly delete + * the originator. + * 2) the client roamed to us => we can directly delete + * the global entry, since it is useless now. + */ + local_entry = batadv_tt_local_hash_find(bat_priv, + tt_global_entry->common.addr, + vid); + if (local_entry) { + /* local entry exists, case 2: client roamed to us. */ + batadv_tt_global_del_orig_list(tt_global_entry); + batadv_tt_global_free(bat_priv, tt_global_entry, message); + } else { + /* no local entry exists, case 1: check for roaming */ + batadv_tt_global_del_roaming(bat_priv, tt_global_entry, + orig_node, message); + } + +out: + batadv_tt_global_entry_put(tt_global_entry); + batadv_tt_local_entry_put(local_entry); +} + +/** + * batadv_tt_global_del_orig() - remove all the TT global entries belonging to + * the given originator matching the provided vid + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the originator owning the entries to remove + * @match_vid: the VLAN identifier to match. 
If negative all the entries will be + * removed + * @message: debug message to print as "reason" + */ +void batadv_tt_global_del_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + s32 match_vid, + const char *message) +{ + struct batadv_tt_global_entry *tt_global; + struct batadv_tt_common_entry *tt_common_entry; + u32 i; + struct batadv_hashtable *hash = bat_priv->tt.global_hash; + struct hlist_node *safe; + struct hlist_head *head; + spinlock_t *list_lock; /* protects write access to the hash lists */ + unsigned short vid; + + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(tt_common_entry, safe, + head, hash_entry) { + /* remove only matching entries */ + if (match_vid >= 0 && tt_common_entry->vid != match_vid) + continue; + + tt_global = container_of(tt_common_entry, + struct batadv_tt_global_entry, + common); + + batadv_tt_global_del_orig_node(bat_priv, tt_global, + orig_node, message); + + if (hlist_empty(&tt_global->orig_list)) { + vid = tt_global->common.vid; + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Deleting global tt entry %pM (vid: %d): %s\n", + tt_global->common.addr, + batadv_print_vid(vid), message); + hlist_del_rcu(&tt_common_entry->hash_entry); + batadv_tt_global_entry_put(tt_global); + } + } + spin_unlock_bh(list_lock); + } + clear_bit(BATADV_ORIG_CAPA_HAS_TT, &orig_node->capa_initialized); +} + +static bool batadv_tt_global_to_purge(struct batadv_tt_global_entry *tt_global, + char **msg) +{ + bool purge = false; + unsigned long roam_timeout = BATADV_TT_CLIENT_ROAM_TIMEOUT; + unsigned long temp_timeout = BATADV_TT_CLIENT_TEMP_TIMEOUT; + + if ((tt_global->common.flags & BATADV_TT_CLIENT_ROAM) && + batadv_has_timed_out(tt_global->roam_at, roam_timeout)) { + purge = true; + *msg = "Roaming timeout\n"; + } + + if ((tt_global->common.flags & BATADV_TT_CLIENT_TEMP) && + batadv_has_timed_out(tt_global->common.added_at, temp_timeout)) { + purge = true; + *msg = "Temporary client timeout\n"; + } + + return purge; +} + +static void batadv_tt_global_purge(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash = bat_priv->tt.global_hash; + struct hlist_head *head; + struct hlist_node *node_tmp; + spinlock_t *list_lock; /* protects write access to the hash lists */ + u32 i; + char *msg = NULL; + struct batadv_tt_common_entry *tt_common; + struct batadv_tt_global_entry *tt_global; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(tt_common, node_tmp, head, + hash_entry) { + tt_global = container_of(tt_common, + struct batadv_tt_global_entry, + common); + + if (!batadv_tt_global_to_purge(tt_global, &msg)) + continue; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Deleting global tt entry %pM (vid: %d): %s\n", + tt_global->common.addr, + batadv_print_vid(tt_global->common.vid), + msg); + + hlist_del_rcu(&tt_common->hash_entry); + + batadv_tt_global_entry_put(tt_global); + } + spin_unlock_bh(list_lock); + } +} + +static void batadv_tt_global_table_free(struct batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash; + spinlock_t *list_lock; /* protects write access to the hash lists */ + struct batadv_tt_common_entry *tt_common_entry; + struct batadv_tt_global_entry *tt_global; + struct hlist_node *node_tmp; + struct hlist_head *head; + u32 i; + + if (!bat_priv->tt.global_hash) + return; + + hash = bat_priv->tt.global_hash; + 
+ for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(tt_common_entry, node_tmp, + head, hash_entry) { + hlist_del_rcu(&tt_common_entry->hash_entry); + tt_global = container_of(tt_common_entry, + struct batadv_tt_global_entry, + common); + batadv_tt_global_entry_put(tt_global); + } + spin_unlock_bh(list_lock); + } + + batadv_hash_destroy(hash); + + bat_priv->tt.global_hash = NULL; +} + +static bool +_batadv_is_ap_isolated(struct batadv_tt_local_entry *tt_local_entry, + struct batadv_tt_global_entry *tt_global_entry) +{ + if (tt_local_entry->common.flags & BATADV_TT_CLIENT_WIFI && + tt_global_entry->common.flags & BATADV_TT_CLIENT_WIFI) + return true; + + /* check if the two clients are marked as isolated */ + if (tt_local_entry->common.flags & BATADV_TT_CLIENT_ISOLA && + tt_global_entry->common.flags & BATADV_TT_CLIENT_ISOLA) + return true; + + return false; +} + +/** + * batadv_transtable_search() - get the mesh destination for a given client + * @bat_priv: the bat priv with all the soft interface information + * @src: mac address of the source client + * @addr: mac address of the destination client + * @vid: VLAN identifier + * + * Return: a pointer to the originator that was selected as destination in the + * mesh for contacting the client 'addr', NULL otherwise. + * In case of multiple originators serving the same client, the function returns + * the best one (best in terms of metric towards the destination node). + * + * If the two clients are AP isolated the function returns NULL. + */ +struct batadv_orig_node *batadv_transtable_search(struct batadv_priv *bat_priv, + const u8 *src, + const u8 *addr, + unsigned short vid) +{ + struct batadv_tt_local_entry *tt_local_entry = NULL; + struct batadv_tt_global_entry *tt_global_entry = NULL; + struct batadv_orig_node *orig_node = NULL; + struct batadv_tt_orig_list_entry *best_entry; + + if (src && batadv_vlan_ap_isola_get(bat_priv, vid)) { + tt_local_entry = batadv_tt_local_hash_find(bat_priv, src, vid); + if (!tt_local_entry || + (tt_local_entry->common.flags & BATADV_TT_CLIENT_PENDING)) + goto out; + } + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt_global_entry) + goto out; + + /* check whether the clients should not communicate due to AP + * isolation + */ + if (tt_local_entry && + _batadv_is_ap_isolated(tt_local_entry, tt_global_entry)) + goto out; + + rcu_read_lock(); + best_entry = batadv_transtable_best_orig(bat_priv, tt_global_entry); + /* found anything? */ + if (best_entry) + orig_node = best_entry->orig_node; + if (orig_node && !kref_get_unless_zero(&orig_node->refcount)) + orig_node = NULL; + rcu_read_unlock(); + +out: + batadv_tt_global_entry_put(tt_global_entry); + batadv_tt_local_entry_put(tt_local_entry); + + return orig_node; +} + +/** + * batadv_tt_global_crc() - calculates the checksum of the local table belonging + * to the given orig_node + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: originator for which the CRC should be computed + * @vid: VLAN identifier for which the CRC32 has to be computed + * + * This function computes the checksum for the global table corresponding to a + * specific originator. In particular, the checksum is computed as follows: For + * each client connected to the originator the CRC32C of the MAC address and the + * VID is computed and then all the CRC32Cs of the various clients are xor'ed + * together. 
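+ *
+ * A sketch of the per-client contribution, matching the loop below (crc32c()
+ * and htons() are the helpers used by this function; client_mac stands for
+ * the client's MAC address, tt_common->addr in the code):
+ *
+ *   tmp_vid  = htons(vid);
+ *   crc_tmp  = crc32c(0, &tmp_vid, sizeof(tmp_vid));
+ *   crc_tmp  = crc32c(crc_tmp, &flags, sizeof(flags));
+ *   crc     ^= crc32c(crc_tmp, client_mac, ETH_ALEN);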
+ * + * The idea behind is that CRC32C should be used as much as possible in order to + * produce a unique hash of the table, but since the order which is used to feed + * the CRC32C function affects the result and since every node in the network + * probably sorts the clients differently, the hash function cannot be directly + * computed over the entire table. Hence the CRC32C is used only on + * the single client entry, while all the results are then xor'ed together + * because the XOR operation can combine them all while trying to reduce the + * noise as much as possible. + * + * Return: the checksum of the global table of a given originator. + */ +static u32 batadv_tt_global_crc(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + unsigned short vid) +{ + struct batadv_hashtable *hash = bat_priv->tt.global_hash; + struct batadv_tt_orig_list_entry *tt_orig; + struct batadv_tt_common_entry *tt_common; + struct batadv_tt_global_entry *tt_global; + struct hlist_head *head; + u32 i, crc_tmp, crc = 0; + u8 flags; + __be16 tmp_vid; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tt_common, head, hash_entry) { + tt_global = container_of(tt_common, + struct batadv_tt_global_entry, + common); + /* compute the CRC only for entries belonging to the + * VLAN identified by the vid passed as parameter + */ + if (tt_common->vid != vid) + continue; + + /* Roaming clients are in the global table for + * consistency only. They don't have to be + * taken into account while computing the + * global crc + */ + if (tt_common->flags & BATADV_TT_CLIENT_ROAM) + continue; + /* Temporary clients have not been announced yet, so + * they have to be skipped while computing the global + * crc + */ + if (tt_common->flags & BATADV_TT_CLIENT_TEMP) + continue; + + /* find out if this global entry is announced by this + * originator + */ + tt_orig = batadv_tt_global_orig_entry_find(tt_global, + orig_node); + if (!tt_orig) + continue; + + /* use network order to read the VID: this ensures that + * every node reads the bytes in the same order. + */ + tmp_vid = htons(tt_common->vid); + crc_tmp = crc32c(0, &tmp_vid, sizeof(tmp_vid)); + + /* compute the CRC on flags that have to be kept in sync + * among nodes + */ + flags = tt_orig->flags; + crc_tmp = crc32c(crc_tmp, &flags, sizeof(flags)); + + crc ^= crc32c(crc_tmp, tt_common->addr, ETH_ALEN); + + batadv_tt_orig_list_entry_put(tt_orig); + } + rcu_read_unlock(); + } + + return crc; +} + +/** + * batadv_tt_local_crc() - calculates the checksum of the local table + * @bat_priv: the bat priv with all the soft interface information + * @vid: VLAN identifier for which the CRC32 has to be computed + * + * For details about the computation, please refer to the documentation for + * batadv_tt_global_crc(). 
+ * + * Return: the checksum of the local table + */ +static u32 batadv_tt_local_crc(struct batadv_priv *bat_priv, + unsigned short vid) +{ + struct batadv_hashtable *hash = bat_priv->tt.local_hash; + struct batadv_tt_common_entry *tt_common; + struct hlist_head *head; + u32 i, crc_tmp, crc = 0; + u8 flags; + __be16 tmp_vid; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tt_common, head, hash_entry) { + /* compute the CRC only for entries belonging to the + * VLAN identified by vid + */ + if (tt_common->vid != vid) + continue; + + /* not yet committed clients have not to be taken into + * account while computing the CRC + */ + if (tt_common->flags & BATADV_TT_CLIENT_NEW) + continue; + + /* use network order to read the VID: this ensures that + * every node reads the bytes in the same order. + */ + tmp_vid = htons(tt_common->vid); + crc_tmp = crc32c(0, &tmp_vid, sizeof(tmp_vid)); + + /* compute the CRC on flags that have to be kept in sync + * among nodes + */ + flags = tt_common->flags & BATADV_TT_SYNC_MASK; + crc_tmp = crc32c(crc_tmp, &flags, sizeof(flags)); + + crc ^= crc32c(crc_tmp, tt_common->addr, ETH_ALEN); + } + rcu_read_unlock(); + } + + return crc; +} + +/** + * batadv_tt_req_node_release() - free tt_req node entry + * @ref: kref pointer of the tt req_node entry + */ +static void batadv_tt_req_node_release(struct kref *ref) +{ + struct batadv_tt_req_node *tt_req_node; + + tt_req_node = container_of(ref, struct batadv_tt_req_node, refcount); + + kmem_cache_free(batadv_tt_req_cache, tt_req_node); +} + +/** + * batadv_tt_req_node_put() - decrement the tt_req_node refcounter and + * possibly release it + * @tt_req_node: tt_req_node to be free'd + */ +static void batadv_tt_req_node_put(struct batadv_tt_req_node *tt_req_node) +{ + if (!tt_req_node) + return; + + kref_put(&tt_req_node->refcount, batadv_tt_req_node_release); +} + +static void batadv_tt_req_list_free(struct batadv_priv *bat_priv) +{ + struct batadv_tt_req_node *node; + struct hlist_node *safe; + + spin_lock_bh(&bat_priv->tt.req_list_lock); + + hlist_for_each_entry_safe(node, safe, &bat_priv->tt.req_list, list) { + hlist_del_init(&node->list); + batadv_tt_req_node_put(node); + } + + spin_unlock_bh(&bat_priv->tt.req_list_lock); +} + +static void batadv_tt_save_orig_buffer(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const void *tt_buff, + u16 tt_buff_len) +{ + /* Replace the old buffer only if I received something in the + * last OGM (the OGM could carry no changes) + */ + spin_lock_bh(&orig_node->tt_buff_lock); + if (tt_buff_len > 0) { + kfree(orig_node->tt_buff); + orig_node->tt_buff_len = 0; + orig_node->tt_buff = kmalloc(tt_buff_len, GFP_ATOMIC); + if (orig_node->tt_buff) { + memcpy(orig_node->tt_buff, tt_buff, tt_buff_len); + orig_node->tt_buff_len = tt_buff_len; + } + } + spin_unlock_bh(&orig_node->tt_buff_lock); +} + +static void batadv_tt_req_purge(struct batadv_priv *bat_priv) +{ + struct batadv_tt_req_node *node; + struct hlist_node *safe; + + spin_lock_bh(&bat_priv->tt.req_list_lock); + hlist_for_each_entry_safe(node, safe, &bat_priv->tt.req_list, list) { + if (batadv_has_timed_out(node->issued_at, + BATADV_TT_REQUEST_TIMEOUT)) { + hlist_del_init(&node->list); + batadv_tt_req_node_put(node); + } + } + spin_unlock_bh(&bat_priv->tt.req_list_lock); +} + +/** + * batadv_tt_req_node_new() - search and possibly create a tt_req_node object + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node 
this request is being issued for + * + * Return: the pointer to the new tt_req_node struct if no request + * has already been issued for this orig_node, NULL otherwise. + */ +static struct batadv_tt_req_node * +batadv_tt_req_node_new(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_tt_req_node *tt_req_node_tmp, *tt_req_node = NULL; + + spin_lock_bh(&bat_priv->tt.req_list_lock); + hlist_for_each_entry(tt_req_node_tmp, &bat_priv->tt.req_list, list) { + if (batadv_compare_eth(tt_req_node_tmp, orig_node) && + !batadv_has_timed_out(tt_req_node_tmp->issued_at, + BATADV_TT_REQUEST_TIMEOUT)) + goto unlock; + } + + tt_req_node = kmem_cache_alloc(batadv_tt_req_cache, GFP_ATOMIC); + if (!tt_req_node) + goto unlock; + + kref_init(&tt_req_node->refcount); + ether_addr_copy(tt_req_node->addr, orig_node->orig); + tt_req_node->issued_at = jiffies; + + kref_get(&tt_req_node->refcount); + hlist_add_head(&tt_req_node->list, &bat_priv->tt.req_list); +unlock: + spin_unlock_bh(&bat_priv->tt.req_list_lock); + return tt_req_node; +} + +/** + * batadv_tt_local_valid() - verify local tt entry and get flags + * @entry_ptr: to be checked local tt entry + * @data_ptr: not used but definition required to satisfy the callback prototype + * @flags: a pointer to store TT flags for this client to + * + * Checks the validity of the given local TT entry. If it is, then the provided + * flags pointer is updated. + * + * Return: true if the entry is a valid, false otherwise. + */ +static bool batadv_tt_local_valid(const void *entry_ptr, + const void *data_ptr, + u8 *flags) +{ + const struct batadv_tt_common_entry *tt_common_entry = entry_ptr; + + if (tt_common_entry->flags & BATADV_TT_CLIENT_NEW) + return false; + + if (flags) + *flags = tt_common_entry->flags; + + return true; +} + +/** + * batadv_tt_global_valid() - verify global tt entry and get flags + * @entry_ptr: to be checked global tt entry + * @data_ptr: an orig_node object (may be NULL) + * @flags: a pointer to store TT flags for this client to + * + * Checks the validity of the given global TT entry. If it is, then the provided + * flags pointer is updated either with the common (summed) TT flags if data_ptr + * is NULL or the specific, per originator TT flags otherwise. + * + * Return: true if the entry is a valid, false otherwise. + */ +static bool batadv_tt_global_valid(const void *entry_ptr, + const void *data_ptr, + u8 *flags) +{ + const struct batadv_tt_common_entry *tt_common_entry = entry_ptr; + const struct batadv_tt_global_entry *tt_global_entry; + const struct batadv_orig_node *orig_node = data_ptr; + + if (tt_common_entry->flags & BATADV_TT_CLIENT_ROAM || + tt_common_entry->flags & BATADV_TT_CLIENT_TEMP) + return false; + + tt_global_entry = container_of(tt_common_entry, + struct batadv_tt_global_entry, + common); + + return batadv_tt_global_entry_has_orig(tt_global_entry, orig_node, + flags); +} + +/** + * batadv_tt_tvlv_generate() - fill the tvlv buff with the tt entries from the + * specified tt hash + * @bat_priv: the bat priv with all the soft interface information + * @hash: hash table containing the tt entries + * @tt_len: expected tvlv tt data buffer length in number of bytes + * @tvlv_buff: pointer to the buffer to fill with the TT data + * @valid_cb: function to filter tt change entries and to return TT flags + * @cb_data: data passed to the filter function as argument + * + * Fills the tvlv buff with the tt entries from the specified hash. If valid_cb + * is not provided then this becomes a no-op. 
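+ *
+ * A minimal usage sketch (mirroring batadv_send_other_tt_response() later in
+ * this file; the destination buffer and its length are only illustrative):
+ *
+ *   batadv_tt_tvlv_generate(bat_priv, bat_priv->tt.global_hash,
+ *                           tt_change, tt_len,
+ *                           batadv_tt_global_valid, req_dst_orig_node);
+ *
+ * The callback decides whether an entry is copied into the buffer and reports
+ * the TT flags to advertise for it.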
+ */ +static void batadv_tt_tvlv_generate(struct batadv_priv *bat_priv, + struct batadv_hashtable *hash, + void *tvlv_buff, u16 tt_len, + bool (*valid_cb)(const void *, + const void *, + u8 *flags), + void *cb_data) +{ + struct batadv_tt_common_entry *tt_common_entry; + struct batadv_tvlv_tt_change *tt_change; + struct hlist_head *head; + u16 tt_tot, tt_num_entries = 0; + u8 flags; + bool ret; + u32 i; + + tt_tot = batadv_tt_entries(tt_len); + tt_change = tvlv_buff; + + if (!valid_cb) + return; + + rcu_read_lock(); + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + hlist_for_each_entry_rcu(tt_common_entry, + head, hash_entry) { + if (tt_tot == tt_num_entries) + break; + + ret = valid_cb(tt_common_entry, cb_data, &flags); + if (!ret) + continue; + + ether_addr_copy(tt_change->addr, tt_common_entry->addr); + tt_change->flags = flags; + tt_change->vid = htons(tt_common_entry->vid); + memset(tt_change->reserved, 0, + sizeof(tt_change->reserved)); + + tt_num_entries++; + tt_change++; + } + } + rcu_read_unlock(); +} + +/** + * batadv_tt_global_check_crc() - check if all the CRCs are correct + * @orig_node: originator for which the CRCs have to be checked + * @tt_vlan: pointer to the first tvlv VLAN entry + * @num_vlan: number of tvlv VLAN entries + * + * Return: true if all the received CRCs match the locally stored ones, false + * otherwise + */ +static bool batadv_tt_global_check_crc(struct batadv_orig_node *orig_node, + struct batadv_tvlv_tt_vlan_data *tt_vlan, + u16 num_vlan) +{ + struct batadv_tvlv_tt_vlan_data *tt_vlan_tmp; + struct batadv_orig_node_vlan *vlan; + int i, orig_num_vlan; + u32 crc; + + /* check if each received CRC matches the locally stored one */ + for (i = 0; i < num_vlan; i++) { + tt_vlan_tmp = tt_vlan + i; + + /* if orig_node is a backbone node for this VLAN, don't check + * the CRC as we ignore all the global entries over it + */ + if (batadv_bla_is_backbone_gw_orig(orig_node->bat_priv, + orig_node->orig, + ntohs(tt_vlan_tmp->vid))) + continue; + + vlan = batadv_orig_node_vlan_get(orig_node, + ntohs(tt_vlan_tmp->vid)); + if (!vlan) + return false; + + crc = vlan->tt.crc; + batadv_orig_node_vlan_put(vlan); + + if (crc != ntohl(tt_vlan_tmp->crc)) + return false; + } + + /* check if any excess VLANs exist locally for the originator + * which are not mentioned in the TVLV from the originator. 
+ */ + rcu_read_lock(); + orig_num_vlan = 0; + hlist_for_each_entry_rcu(vlan, &orig_node->vlan_list, list) + orig_num_vlan++; + rcu_read_unlock(); + + if (orig_num_vlan > num_vlan) + return false; + + return true; +} + +/** + * batadv_tt_local_update_crc() - update all the local CRCs + * @bat_priv: the bat priv with all the soft interface information + */ +static void batadv_tt_local_update_crc(struct batadv_priv *bat_priv) +{ + struct batadv_softif_vlan *vlan; + + /* recompute the global CRC for each VLAN */ + rcu_read_lock(); + hlist_for_each_entry_rcu(vlan, &bat_priv->softif_vlan_list, list) { + vlan->tt.crc = batadv_tt_local_crc(bat_priv, vlan->vid); + } + rcu_read_unlock(); +} + +/** + * batadv_tt_global_update_crc() - update all the global CRCs for this orig_node + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the orig_node for which the CRCs have to be updated + */ +static void batadv_tt_global_update_crc(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node) +{ + struct batadv_orig_node_vlan *vlan; + u32 crc; + + /* recompute the global CRC for each VLAN */ + rcu_read_lock(); + hlist_for_each_entry_rcu(vlan, &orig_node->vlan_list, list) { + /* if orig_node is a backbone node for this VLAN, don't compute + * the CRC as we ignore all the global entries over it + */ + if (batadv_bla_is_backbone_gw_orig(bat_priv, orig_node->orig, + vlan->vid)) + continue; + + crc = batadv_tt_global_crc(bat_priv, orig_node, vlan->vid); + vlan->tt.crc = crc; + } + rcu_read_unlock(); +} + +/** + * batadv_send_tt_request() - send a TT Request message to a given node + * @bat_priv: the bat priv with all the soft interface information + * @dst_orig_node: the destination of the message + * @ttvn: the version number that the source of the message is looking for + * @tt_vlan: pointer to the first tvlv VLAN object to request + * @num_vlan: number of tvlv VLAN entries + * @full_table: ask for the entire translation table if true, while only for the + * last TT diff otherwise + * + * Return: true if the TT Request was sent, false otherwise + */ +static bool batadv_send_tt_request(struct batadv_priv *bat_priv, + struct batadv_orig_node *dst_orig_node, + u8 ttvn, + struct batadv_tvlv_tt_vlan_data *tt_vlan, + u16 num_vlan, bool full_table) +{ + struct batadv_tvlv_tt_data *tvlv_tt_data = NULL; + struct batadv_tt_req_node *tt_req_node = NULL; + struct batadv_tvlv_tt_vlan_data *tt_vlan_req; + struct batadv_hard_iface *primary_if; + bool ret = false; + int i, size; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* The new tt_req will be issued only if I'm not waiting for a + * reply from the same orig_node yet + */ + tt_req_node = batadv_tt_req_node_new(bat_priv, dst_orig_node); + if (!tt_req_node) + goto out; + + size = sizeof(*tvlv_tt_data) + sizeof(*tt_vlan_req) * num_vlan; + tvlv_tt_data = kzalloc(size, GFP_ATOMIC); + if (!tvlv_tt_data) + goto out; + + tvlv_tt_data->flags = BATADV_TT_REQUEST; + tvlv_tt_data->ttvn = ttvn; + tvlv_tt_data->num_vlan = htons(num_vlan); + + /* send all the CRCs within the request. 
This is needed by intermediate + * nodes to ensure they have the correct table before replying + */ + tt_vlan_req = (struct batadv_tvlv_tt_vlan_data *)(tvlv_tt_data + 1); + for (i = 0; i < num_vlan; i++) { + tt_vlan_req->vid = tt_vlan->vid; + tt_vlan_req->crc = tt_vlan->crc; + + tt_vlan_req++; + tt_vlan++; + } + + if (full_table) + tvlv_tt_data->flags |= BATADV_TT_FULL_TABLE; + + batadv_dbg(BATADV_DBG_TT, bat_priv, "Sending TT_REQUEST to %pM [%c]\n", + dst_orig_node->orig, full_table ? 'F' : '.'); + + batadv_inc_counter(bat_priv, BATADV_CNT_TT_REQUEST_TX); + batadv_tvlv_unicast_send(bat_priv, primary_if->net_dev->dev_addr, + dst_orig_node->orig, BATADV_TVLV_TT, 1, + tvlv_tt_data, size); + ret = true; + +out: + batadv_hardif_put(primary_if); + + if (ret && tt_req_node) { + spin_lock_bh(&bat_priv->tt.req_list_lock); + if (!hlist_unhashed(&tt_req_node->list)) { + hlist_del_init(&tt_req_node->list); + batadv_tt_req_node_put(tt_req_node); + } + spin_unlock_bh(&bat_priv->tt.req_list_lock); + } + + batadv_tt_req_node_put(tt_req_node); + + kfree(tvlv_tt_data); + return ret; +} + +/** + * batadv_send_other_tt_response() - send reply to tt request concerning another + * node's translation table + * @bat_priv: the bat priv with all the soft interface information + * @tt_data: tt data containing the tt request information + * @req_src: mac address of tt request sender + * @req_dst: mac address of tt request recipient + * + * Return: true if tt request reply was sent, false otherwise. + */ +static bool batadv_send_other_tt_response(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_data *tt_data, + u8 *req_src, u8 *req_dst) +{ + struct batadv_orig_node *req_dst_orig_node; + struct batadv_orig_node *res_dst_orig_node = NULL; + struct batadv_tvlv_tt_change *tt_change; + struct batadv_tvlv_tt_data *tvlv_tt_data = NULL; + struct batadv_tvlv_tt_vlan_data *tt_vlan; + bool ret = false, full_table; + u8 orig_ttvn, req_ttvn; + u16 tvlv_len; + s32 tt_len; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Received TT_REQUEST from %pM for ttvn: %u (%pM) [%c]\n", + req_src, tt_data->ttvn, req_dst, + ((tt_data->flags & BATADV_TT_FULL_TABLE) ? 
'F' : '.')); + + /* Let's get the orig node of the REAL destination */ + req_dst_orig_node = batadv_orig_hash_find(bat_priv, req_dst); + if (!req_dst_orig_node) + goto out; + + res_dst_orig_node = batadv_orig_hash_find(bat_priv, req_src); + if (!res_dst_orig_node) + goto out; + + orig_ttvn = (u8)atomic_read(&req_dst_orig_node->last_ttvn); + req_ttvn = tt_data->ttvn; + + tt_vlan = (struct batadv_tvlv_tt_vlan_data *)(tt_data + 1); + /* this node doesn't have the requested data */ + if (orig_ttvn != req_ttvn || + !batadv_tt_global_check_crc(req_dst_orig_node, tt_vlan, + ntohs(tt_data->num_vlan))) + goto out; + + /* If the full table has been explicitly requested */ + if (tt_data->flags & BATADV_TT_FULL_TABLE || + !req_dst_orig_node->tt_buff) + full_table = true; + else + full_table = false; + + /* TT fragmentation hasn't been implemented yet, so send as many + * TT entries fit a single packet as possible only + */ + if (!full_table) { + spin_lock_bh(&req_dst_orig_node->tt_buff_lock); + tt_len = req_dst_orig_node->tt_buff_len; + + tvlv_len = batadv_tt_prepare_tvlv_global_data(req_dst_orig_node, + &tvlv_tt_data, + &tt_change, + &tt_len); + if (!tt_len) + goto unlock; + + /* Copy the last orig_node's OGM buffer */ + memcpy(tt_change, req_dst_orig_node->tt_buff, + req_dst_orig_node->tt_buff_len); + spin_unlock_bh(&req_dst_orig_node->tt_buff_lock); + } else { + /* allocate the tvlv, put the tt_data and all the tt_vlan_data + * in the initial part + */ + tt_len = -1; + tvlv_len = batadv_tt_prepare_tvlv_global_data(req_dst_orig_node, + &tvlv_tt_data, + &tt_change, + &tt_len); + if (!tt_len) + goto out; + + /* fill the rest of the tvlv with the real TT entries */ + batadv_tt_tvlv_generate(bat_priv, bat_priv->tt.global_hash, + tt_change, tt_len, + batadv_tt_global_valid, + req_dst_orig_node); + } + + /* Don't send the response, if larger than fragmented packet. */ + tt_len = sizeof(struct batadv_unicast_tvlv_packet) + tvlv_len; + if (tt_len > atomic_read(&bat_priv->packet_size_max)) { + net_ratelimited_function(batadv_info, bat_priv->soft_iface, + "Ignoring TT_REQUEST from %pM; Response size exceeds max packet size.\n", + res_dst_orig_node->orig); + goto out; + } + + tvlv_tt_data->flags = BATADV_TT_RESPONSE; + tvlv_tt_data->ttvn = req_ttvn; + + if (full_table) + tvlv_tt_data->flags |= BATADV_TT_FULL_TABLE; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Sending TT_RESPONSE %pM for %pM [%c] (ttvn: %u)\n", + res_dst_orig_node->orig, req_dst_orig_node->orig, + full_table ? 'F' : '.', req_ttvn); + + batadv_inc_counter(bat_priv, BATADV_CNT_TT_RESPONSE_TX); + + batadv_tvlv_unicast_send(bat_priv, req_dst_orig_node->orig, + req_src, BATADV_TVLV_TT, 1, tvlv_tt_data, + tvlv_len); + + ret = true; + goto out; + +unlock: + spin_unlock_bh(&req_dst_orig_node->tt_buff_lock); + +out: + batadv_orig_node_put(res_dst_orig_node); + batadv_orig_node_put(req_dst_orig_node); + kfree(tvlv_tt_data); + return ret; +} + +/** + * batadv_send_my_tt_response() - send reply to tt request concerning this + * node's translation table + * @bat_priv: the bat priv with all the soft interface information + * @tt_data: tt data containing the tt request information + * @req_src: mac address of tt request sender + * + * Return: true if tt request reply was sent, false otherwise. 
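+ *
+ * In short (from the body below): the full local table is sent when the
+ * requester explicitly asked for it, when the requested ttvn differs from the
+ * current local one, or when no last_changeset buffer is available; otherwise
+ * only the last committed diff is copied into the reply.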
+ */ +static bool batadv_send_my_tt_response(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_data *tt_data, + u8 *req_src) +{ + struct batadv_tvlv_tt_data *tvlv_tt_data = NULL; + struct batadv_hard_iface *primary_if = NULL; + struct batadv_tvlv_tt_change *tt_change; + struct batadv_orig_node *orig_node; + u8 my_ttvn, req_ttvn; + u16 tvlv_len; + bool full_table; + s32 tt_len; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Received TT_REQUEST from %pM for ttvn: %u (me) [%c]\n", + req_src, tt_data->ttvn, + ((tt_data->flags & BATADV_TT_FULL_TABLE) ? 'F' : '.')); + + spin_lock_bh(&bat_priv->tt.commit_lock); + + my_ttvn = (u8)atomic_read(&bat_priv->tt.vn); + req_ttvn = tt_data->ttvn; + + orig_node = batadv_orig_hash_find(bat_priv, req_src); + if (!orig_node) + goto out; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* If the full table has been explicitly requested or the gap + * is too big send the whole local translation table + */ + if (tt_data->flags & BATADV_TT_FULL_TABLE || my_ttvn != req_ttvn || + !bat_priv->tt.last_changeset) + full_table = true; + else + full_table = false; + + /* TT fragmentation hasn't been implemented yet, so send as many + * TT entries fit a single packet as possible only + */ + if (!full_table) { + spin_lock_bh(&bat_priv->tt.last_changeset_lock); + + tt_len = bat_priv->tt.last_changeset_len; + tvlv_len = batadv_tt_prepare_tvlv_local_data(bat_priv, + &tvlv_tt_data, + &tt_change, + &tt_len); + if (!tt_len || !tvlv_len) + goto unlock; + + /* Copy the last orig_node's OGM buffer */ + memcpy(tt_change, bat_priv->tt.last_changeset, + bat_priv->tt.last_changeset_len); + spin_unlock_bh(&bat_priv->tt.last_changeset_lock); + } else { + req_ttvn = (u8)atomic_read(&bat_priv->tt.vn); + + /* allocate the tvlv, put the tt_data and all the tt_vlan_data + * in the initial part + */ + tt_len = -1; + tvlv_len = batadv_tt_prepare_tvlv_local_data(bat_priv, + &tvlv_tt_data, + &tt_change, + &tt_len); + if (!tt_len || !tvlv_len) + goto out; + + /* fill the rest of the tvlv with the real TT entries */ + batadv_tt_tvlv_generate(bat_priv, bat_priv->tt.local_hash, + tt_change, tt_len, + batadv_tt_local_valid, NULL); + } + + tvlv_tt_data->flags = BATADV_TT_RESPONSE; + tvlv_tt_data->ttvn = req_ttvn; + + if (full_table) + tvlv_tt_data->flags |= BATADV_TT_FULL_TABLE; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Sending TT_RESPONSE to %pM [%c] (ttvn: %u)\n", + orig_node->orig, full_table ? 'F' : '.', req_ttvn); + + batadv_inc_counter(bat_priv, BATADV_CNT_TT_RESPONSE_TX); + + batadv_tvlv_unicast_send(bat_priv, primary_if->net_dev->dev_addr, + req_src, BATADV_TVLV_TT, 1, tvlv_tt_data, + tvlv_len); + + goto out; + +unlock: + spin_unlock_bh(&bat_priv->tt.last_changeset_lock); +out: + spin_unlock_bh(&bat_priv->tt.commit_lock); + batadv_orig_node_put(orig_node); + batadv_hardif_put(primary_if); + kfree(tvlv_tt_data); + /* The packet was for this host, so it doesn't need to be re-routed */ + return true; +} + +/** + * batadv_send_tt_response() - send reply to tt request + * @bat_priv: the bat priv with all the soft interface information + * @tt_data: tt data containing the tt request information + * @req_src: mac address of tt request sender + * @req_dst: mac address of tt request recipient + * + * Return: true if tt request reply was sent, false otherwise. 
+ */ +static bool batadv_send_tt_response(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_data *tt_data, + u8 *req_src, u8 *req_dst) +{ + if (batadv_is_my_mac(bat_priv, req_dst)) + return batadv_send_my_tt_response(bat_priv, tt_data, req_src); + return batadv_send_other_tt_response(bat_priv, tt_data, req_src, + req_dst); +} + +static void _batadv_tt_update_changes(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + struct batadv_tvlv_tt_change *tt_change, + u16 tt_num_changes, u8 ttvn) +{ + int i; + int roams; + + for (i = 0; i < tt_num_changes; i++) { + if ((tt_change + i)->flags & BATADV_TT_CLIENT_DEL) { + roams = (tt_change + i)->flags & BATADV_TT_CLIENT_ROAM; + batadv_tt_global_del(bat_priv, orig_node, + (tt_change + i)->addr, + ntohs((tt_change + i)->vid), + "tt removed by changes", + roams); + } else { + if (!batadv_tt_global_add(bat_priv, orig_node, + (tt_change + i)->addr, + ntohs((tt_change + i)->vid), + (tt_change + i)->flags, ttvn)) + /* In case of problem while storing a + * global_entry, we stop the updating + * procedure without committing the + * ttvn change. This will avoid to send + * corrupted data on tt_request + */ + return; + } + } + set_bit(BATADV_ORIG_CAPA_HAS_TT, &orig_node->capa_initialized); +} + +static void batadv_tt_fill_gtable(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_change *tt_change, + u8 ttvn, u8 *resp_src, + u16 num_entries) +{ + struct batadv_orig_node *orig_node; + + orig_node = batadv_orig_hash_find(bat_priv, resp_src); + if (!orig_node) + goto out; + + /* Purge the old table first.. */ + batadv_tt_global_del_orig(bat_priv, orig_node, -1, + "Received full table"); + + _batadv_tt_update_changes(bat_priv, orig_node, tt_change, num_entries, + ttvn); + + spin_lock_bh(&orig_node->tt_buff_lock); + kfree(orig_node->tt_buff); + orig_node->tt_buff_len = 0; + orig_node->tt_buff = NULL; + spin_unlock_bh(&orig_node->tt_buff_lock); + + atomic_set(&orig_node->last_ttvn, ttvn); + +out: + batadv_orig_node_put(orig_node); +} + +static void batadv_tt_update_changes(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + u16 tt_num_changes, u8 ttvn, + struct batadv_tvlv_tt_change *tt_change) +{ + _batadv_tt_update_changes(bat_priv, orig_node, tt_change, + tt_num_changes, ttvn); + + batadv_tt_save_orig_buffer(bat_priv, orig_node, tt_change, + batadv_tt_len(tt_num_changes)); + atomic_set(&orig_node->last_ttvn, ttvn); +} + +/** + * batadv_is_my_client() - check if a client is served by the local node + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client to check + * @vid: VLAN identifier + * + * Return: true if the client is served by this node, false otherwise. 
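+ *
+ * A hypothetical caller on the transmit path could, for instance, use
+ * batadv_is_my_client(bat_priv, eth->h_dest, vid) to detect that a frame is
+ * destined for a locally attached client and skip mesh forwarding entirely.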
+ */ +bool batadv_is_my_client(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid) +{ + struct batadv_tt_local_entry *tt_local_entry; + bool ret = false; + + tt_local_entry = batadv_tt_local_hash_find(bat_priv, addr, vid); + if (!tt_local_entry) + goto out; + /* Check if the client has been logically deleted (but is kept for + * consistency purpose) + */ + if ((tt_local_entry->common.flags & BATADV_TT_CLIENT_PENDING) || + (tt_local_entry->common.flags & BATADV_TT_CLIENT_ROAM)) + goto out; + ret = true; +out: + batadv_tt_local_entry_put(tt_local_entry); + return ret; +} + +/** + * batadv_handle_tt_response() - process incoming tt reply + * @bat_priv: the bat priv with all the soft interface information + * @tt_data: tt data containing the tt request information + * @resp_src: mac address of tt reply sender + * @num_entries: number of tt change entries appended to the tt data + */ +static void batadv_handle_tt_response(struct batadv_priv *bat_priv, + struct batadv_tvlv_tt_data *tt_data, + u8 *resp_src, u16 num_entries) +{ + struct batadv_tt_req_node *node; + struct hlist_node *safe; + struct batadv_orig_node *orig_node = NULL; + struct batadv_tvlv_tt_change *tt_change; + u8 *tvlv_ptr = (u8 *)tt_data; + u16 change_offset; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Received TT_RESPONSE from %pM for ttvn %d t_size: %d [%c]\n", + resp_src, tt_data->ttvn, num_entries, + ((tt_data->flags & BATADV_TT_FULL_TABLE) ? 'F' : '.')); + + orig_node = batadv_orig_hash_find(bat_priv, resp_src); + if (!orig_node) + goto out; + + spin_lock_bh(&orig_node->tt_lock); + + change_offset = sizeof(struct batadv_tvlv_tt_vlan_data); + change_offset *= ntohs(tt_data->num_vlan); + change_offset += sizeof(*tt_data); + tvlv_ptr += change_offset; + + tt_change = (struct batadv_tvlv_tt_change *)tvlv_ptr; + if (tt_data->flags & BATADV_TT_FULL_TABLE) { + batadv_tt_fill_gtable(bat_priv, tt_change, tt_data->ttvn, + resp_src, num_entries); + } else { + batadv_tt_update_changes(bat_priv, orig_node, num_entries, + tt_data->ttvn, tt_change); + } + + /* Recalculate the CRC for this orig_node and store it */ + batadv_tt_global_update_crc(bat_priv, orig_node); + + spin_unlock_bh(&orig_node->tt_lock); + + /* Delete the tt_req_node from pending tt_requests list */ + spin_lock_bh(&bat_priv->tt.req_list_lock); + hlist_for_each_entry_safe(node, safe, &bat_priv->tt.req_list, list) { + if (!batadv_compare_eth(node->addr, resp_src)) + continue; + hlist_del_init(&node->list); + batadv_tt_req_node_put(node); + } + + spin_unlock_bh(&bat_priv->tt.req_list_lock); +out: + batadv_orig_node_put(orig_node); +} + +static void batadv_tt_roam_list_free(struct batadv_priv *bat_priv) +{ + struct batadv_tt_roam_node *node, *safe; + + spin_lock_bh(&bat_priv->tt.roam_list_lock); + + list_for_each_entry_safe(node, safe, &bat_priv->tt.roam_list, list) { + list_del(&node->list); + kmem_cache_free(batadv_tt_roam_cache, node); + } + + spin_unlock_bh(&bat_priv->tt.roam_list_lock); +} + +static void batadv_tt_roam_purge(struct batadv_priv *bat_priv) +{ + struct batadv_tt_roam_node *node, *safe; + + spin_lock_bh(&bat_priv->tt.roam_list_lock); + list_for_each_entry_safe(node, safe, &bat_priv->tt.roam_list, list) { + if (!batadv_has_timed_out(node->first_time, + BATADV_ROAMING_MAX_TIME)) + continue; + + list_del(&node->list); + kmem_cache_free(batadv_tt_roam_cache, node); + } + spin_unlock_bh(&bat_priv->tt.roam_list_lock); +} + +/** + * batadv_tt_check_roam_count() - check if a client has roamed too frequently + * @bat_priv: the bat priv with all the 
soft interface information + * @client: mac address of the roaming client + * + * This function checks whether the client already reached the + * maximum number of possible roaming phases. In this case the ROAMING_ADV + * will not be sent. + * + * Return: true if the ROAMING_ADV can be sent, false otherwise + */ +static bool batadv_tt_check_roam_count(struct batadv_priv *bat_priv, u8 *client) +{ + struct batadv_tt_roam_node *tt_roam_node; + bool ret = false; + + spin_lock_bh(&bat_priv->tt.roam_list_lock); + /* The new tt_req will be issued only if I'm not waiting for a + * reply from the same orig_node yet + */ + list_for_each_entry(tt_roam_node, &bat_priv->tt.roam_list, list) { + if (!batadv_compare_eth(tt_roam_node->addr, client)) + continue; + + if (batadv_has_timed_out(tt_roam_node->first_time, + BATADV_ROAMING_MAX_TIME)) + continue; + + if (!batadv_atomic_dec_not_zero(&tt_roam_node->counter)) + /* Sorry, you roamed too many times! */ + goto unlock; + ret = true; + break; + } + + if (!ret) { + tt_roam_node = kmem_cache_alloc(batadv_tt_roam_cache, + GFP_ATOMIC); + if (!tt_roam_node) + goto unlock; + + tt_roam_node->first_time = jiffies; + atomic_set(&tt_roam_node->counter, + BATADV_ROAMING_MAX_COUNT - 1); + ether_addr_copy(tt_roam_node->addr, client); + + list_add(&tt_roam_node->list, &bat_priv->tt.roam_list); + ret = true; + } + +unlock: + spin_unlock_bh(&bat_priv->tt.roam_list_lock); + return ret; +} + +/** + * batadv_send_roam_adv() - send a roaming advertisement message + * @bat_priv: the bat priv with all the soft interface information + * @client: mac address of the roaming client + * @vid: VLAN identifier + * @orig_node: message destination + * + * Send a ROAMING_ADV message to the node which was previously serving this + * client. This is done to inform the node that from now on all traffic destined + * for this particular roamed client has to be forwarded to the sender of the + * roaming message. 
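+ *
+ * The advertisement itself is small: a struct batadv_tvlv_roam_adv carrying
+ * only the roamed client's MAC address and its VLAN id, sent as a unicast
+ * TVLV of type BATADV_TVLV_ROAM to the client's previous originator.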
+ */ +static void batadv_send_roam_adv(struct batadv_priv *bat_priv, u8 *client, + unsigned short vid, + struct batadv_orig_node *orig_node) +{ + struct batadv_hard_iface *primary_if; + struct batadv_tvlv_roam_adv tvlv_roam; + + primary_if = batadv_primary_if_get_selected(bat_priv); + if (!primary_if) + goto out; + + /* before going on we have to check whether the client has + * already roamed to us too many times + */ + if (!batadv_tt_check_roam_count(bat_priv, client)) + goto out; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Sending ROAMING_ADV to %pM (client %pM, vid: %d)\n", + orig_node->orig, client, batadv_print_vid(vid)); + + batadv_inc_counter(bat_priv, BATADV_CNT_TT_ROAM_ADV_TX); + + memcpy(tvlv_roam.client, client, sizeof(tvlv_roam.client)); + tvlv_roam.vid = htons(vid); + + batadv_tvlv_unicast_send(bat_priv, primary_if->net_dev->dev_addr, + orig_node->orig, BATADV_TVLV_ROAM, 1, + &tvlv_roam, sizeof(tvlv_roam)); + +out: + batadv_hardif_put(primary_if); +} + +static void batadv_tt_purge(struct work_struct *work) +{ + struct delayed_work *delayed_work; + struct batadv_priv_tt *priv_tt; + struct batadv_priv *bat_priv; + + delayed_work = to_delayed_work(work); + priv_tt = container_of(delayed_work, struct batadv_priv_tt, work); + bat_priv = container_of(priv_tt, struct batadv_priv, tt); + + batadv_tt_local_purge(bat_priv, BATADV_TT_LOCAL_TIMEOUT); + batadv_tt_global_purge(bat_priv); + batadv_tt_req_purge(bat_priv); + batadv_tt_roam_purge(bat_priv); + + queue_delayed_work(batadv_event_workqueue, &bat_priv->tt.work, + msecs_to_jiffies(BATADV_TT_WORK_PERIOD)); +} + +/** + * batadv_tt_free() - Free translation table of soft interface + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_tt_free(struct batadv_priv *bat_priv) +{ + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_ROAM, 1); + + batadv_tvlv_container_unregister(bat_priv, BATADV_TVLV_TT, 1); + batadv_tvlv_handler_unregister(bat_priv, BATADV_TVLV_TT, 1); + + cancel_delayed_work_sync(&bat_priv->tt.work); + + batadv_tt_local_table_free(bat_priv); + batadv_tt_global_table_free(bat_priv); + batadv_tt_req_list_free(bat_priv); + batadv_tt_changes_list_free(bat_priv); + batadv_tt_roam_list_free(bat_priv); + + kfree(bat_priv->tt.last_changeset); +} + +/** + * batadv_tt_local_set_flags() - set or unset the specified flags on the local + * table and possibly count them in the TT size + * @bat_priv: the bat priv with all the soft interface information + * @flags: the flag to switch + * @enable: whether to set or unset the flag + * @count: whether to increase the TT size by the number of changed entries + */ +static void batadv_tt_local_set_flags(struct batadv_priv *bat_priv, u16 flags, + bool enable, bool count) +{ + struct batadv_hashtable *hash = bat_priv->tt.local_hash; + struct batadv_tt_common_entry *tt_common_entry; + struct hlist_head *head; + u32 i; + + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tt_common_entry, + head, hash_entry) { + if (enable) { + if ((tt_common_entry->flags & flags) == flags) + continue; + tt_common_entry->flags |= flags; + } else { + if (!(tt_common_entry->flags & flags)) + continue; + tt_common_entry->flags &= ~flags; + } + + if (!count) + continue; + + batadv_tt_local_size_inc(bat_priv, + tt_common_entry->vid); + } + rcu_read_unlock(); + } +} + +/* Purge out all the tt local entries marked with BATADV_TT_CLIENT_PENDING */ +static void batadv_tt_local_purge_pending_clients(struct 
batadv_priv *bat_priv) +{ + struct batadv_hashtable *hash = bat_priv->tt.local_hash; + struct batadv_tt_common_entry *tt_common; + struct batadv_tt_local_entry *tt_local; + struct hlist_node *node_tmp; + struct hlist_head *head; + spinlock_t *list_lock; /* protects write access to the hash lists */ + u32 i; + + if (!hash) + return; + + for (i = 0; i < hash->size; i++) { + head = &hash->table[i]; + list_lock = &hash->list_locks[i]; + + spin_lock_bh(list_lock); + hlist_for_each_entry_safe(tt_common, node_tmp, head, + hash_entry) { + if (!(tt_common->flags & BATADV_TT_CLIENT_PENDING)) + continue; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Deleting local tt entry (%pM, vid: %d): pending\n", + tt_common->addr, + batadv_print_vid(tt_common->vid)); + + batadv_tt_local_size_dec(bat_priv, tt_common->vid); + hlist_del_rcu(&tt_common->hash_entry); + tt_local = container_of(tt_common, + struct batadv_tt_local_entry, + common); + + batadv_tt_local_entry_put(tt_local); + } + spin_unlock_bh(list_lock); + } +} + +/** + * batadv_tt_local_commit_changes_nolock() - commit all pending local tt changes + * which have been queued in the time since the last commit + * @bat_priv: the bat priv with all the soft interface information + * + * Caller must hold tt->commit_lock. + */ +static void batadv_tt_local_commit_changes_nolock(struct batadv_priv *bat_priv) +{ + lockdep_assert_held(&bat_priv->tt.commit_lock); + + if (atomic_read(&bat_priv->tt.local_changes) < 1) { + if (!batadv_atomic_dec_not_zero(&bat_priv->tt.ogm_append_cnt)) + batadv_tt_tvlv_container_update(bat_priv); + return; + } + + batadv_tt_local_set_flags(bat_priv, BATADV_TT_CLIENT_NEW, false, true); + + batadv_tt_local_purge_pending_clients(bat_priv); + batadv_tt_local_update_crc(bat_priv); + + /* Increment the TTVN only once per OGM interval */ + atomic_inc(&bat_priv->tt.vn); + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Local changes committed, updating to ttvn %u\n", + (u8)atomic_read(&bat_priv->tt.vn)); + + /* reset the sending counter */ + atomic_set(&bat_priv->tt.ogm_append_cnt, BATADV_TT_OGM_APPEND_MAX); + batadv_tt_tvlv_container_update(bat_priv); +} + +/** + * batadv_tt_local_commit_changes() - commit all pending local tt changes which + * have been queued in the time since the last commit + * @bat_priv: the bat priv with all the soft interface information + */ +void batadv_tt_local_commit_changes(struct batadv_priv *bat_priv) +{ + spin_lock_bh(&bat_priv->tt.commit_lock); + batadv_tt_local_commit_changes_nolock(bat_priv); + spin_unlock_bh(&bat_priv->tt.commit_lock); +} + +/** + * batadv_is_ap_isolated() - Check if packet from upper layer should be dropped + * @bat_priv: the bat priv with all the soft interface information + * @src: source mac address of packet + * @dst: destination mac address of packet + * @vid: vlan id of packet + * + * Return: true when src+dst(+vid) pair should be isolated, false otherwise + */ +bool batadv_is_ap_isolated(struct batadv_priv *bat_priv, u8 *src, u8 *dst, + unsigned short vid) +{ + struct batadv_tt_local_entry *tt_local_entry; + struct batadv_tt_global_entry *tt_global_entry; + struct batadv_softif_vlan *vlan; + bool ret = false; + + vlan = batadv_softif_vlan_get(bat_priv, vid); + if (!vlan) + return false; + + if (!atomic_read(&vlan->ap_isolation)) + goto vlan_put; + + tt_local_entry = batadv_tt_local_hash_find(bat_priv, dst, vid); + if (!tt_local_entry) + goto vlan_put; + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, src, vid); + if (!tt_global_entry) + goto local_entry_put; + + if 
(_batadv_is_ap_isolated(tt_local_entry, tt_global_entry)) + ret = true; + + batadv_tt_global_entry_put(tt_global_entry); +local_entry_put: + batadv_tt_local_entry_put(tt_local_entry); +vlan_put: + batadv_softif_vlan_put(vlan); + return ret; +} + +/** + * batadv_tt_update_orig() - update global translation table with new tt + * information received via ogms + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: the orig_node of the ogm + * @tt_buff: pointer to the first tvlv VLAN entry + * @tt_num_vlan: number of tvlv VLAN entries + * @tt_change: pointer to the first entry in the TT buffer + * @tt_num_changes: number of tt changes inside the tt buffer + * @ttvn: translation table version number of this changeset + */ +static void batadv_tt_update_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const void *tt_buff, u16 tt_num_vlan, + struct batadv_tvlv_tt_change *tt_change, + u16 tt_num_changes, u8 ttvn) +{ + u8 orig_ttvn = (u8)atomic_read(&orig_node->last_ttvn); + struct batadv_tvlv_tt_vlan_data *tt_vlan; + bool full_table = true; + bool has_tt_init; + + tt_vlan = (struct batadv_tvlv_tt_vlan_data *)tt_buff; + has_tt_init = test_bit(BATADV_ORIG_CAPA_HAS_TT, + &orig_node->capa_initialized); + + /* orig table not initialised AND first diff is in the OGM OR the ttvn + * increased by one -> we can apply the attached changes + */ + if ((!has_tt_init && ttvn == 1) || ttvn - orig_ttvn == 1) { + /* the OGM could not contain the changes due to their size or + * because they have already been sent BATADV_TT_OGM_APPEND_MAX + * times. + * In this case send a tt request + */ + if (!tt_num_changes) { + full_table = false; + goto request_table; + } + + spin_lock_bh(&orig_node->tt_lock); + + batadv_tt_update_changes(bat_priv, orig_node, tt_num_changes, + ttvn, tt_change); + + /* Even if we received the precomputed crc with the OGM, we + * prefer to recompute it to spot any possible inconsistency + * in the global table + */ + batadv_tt_global_update_crc(bat_priv, orig_node); + + spin_unlock_bh(&orig_node->tt_lock); + + /* The ttvn alone is not enough to guarantee consistency + * because a single value could represent different states + * (due to the wrap around). Thus a node has to check whether + * the resulting table (after applying the changes) is still + * consistent or not. E.g. a node could disconnect while its + * ttvn is X and reconnect on ttvn = X + TTVN_MAX: in this case + * checking the CRC value is mandatory to detect the + * inconsistency + */ + if (!batadv_tt_global_check_crc(orig_node, tt_vlan, + tt_num_vlan)) + goto request_table; + } else { + /* if we missed more than one change or our tables are not + * in sync anymore -> request fresh tt data + */ + if (!has_tt_init || ttvn != orig_ttvn || + !batadv_tt_global_check_crc(orig_node, tt_vlan, + tt_num_vlan)) { +request_table: + batadv_dbg(BATADV_DBG_TT, bat_priv, + "TT inconsistency for %pM. Need to retrieve the correct information (ttvn: %u last_ttvn: %u num_changes: %u)\n", + orig_node->orig, ttvn, orig_ttvn, + tt_num_changes); + batadv_send_tt_request(bat_priv, orig_node, ttvn, + tt_vlan, tt_num_vlan, + full_table); + return; + } + } +} + +/** + * batadv_tt_global_client_is_roaming() - check if a client is marked as roaming + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client to check + * @vid: VLAN identifier + * + * Return: true if we know that the client has moved from its old originator + * to another one. 
This entry is still kept for consistency purposes and will be + * deleted later by a DEL or because of timeout + */ +bool batadv_tt_global_client_is_roaming(struct batadv_priv *bat_priv, + u8 *addr, unsigned short vid) +{ + struct batadv_tt_global_entry *tt_global_entry; + bool ret = false; + + tt_global_entry = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt_global_entry) + goto out; + + ret = tt_global_entry->common.flags & BATADV_TT_CLIENT_ROAM; + batadv_tt_global_entry_put(tt_global_entry); +out: + return ret; +} + +/** + * batadv_tt_local_client_is_roaming() - tells whether the client is roaming + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the local client to query + * @vid: VLAN identifier + * + * Return: true if the local client is known to be roaming (it is not served by + * this node anymore) or not. If yes, the client is still present in the table + * to keep the latter consistent with the node TTVN + */ +bool batadv_tt_local_client_is_roaming(struct batadv_priv *bat_priv, + u8 *addr, unsigned short vid) +{ + struct batadv_tt_local_entry *tt_local_entry; + bool ret = false; + + tt_local_entry = batadv_tt_local_hash_find(bat_priv, addr, vid); + if (!tt_local_entry) + goto out; + + ret = tt_local_entry->common.flags & BATADV_TT_CLIENT_ROAM; + batadv_tt_local_entry_put(tt_local_entry); +out: + return ret; +} + +/** + * batadv_tt_add_temporary_global_entry() - Add temporary entry to global TT + * @bat_priv: the bat priv with all the soft interface information + * @orig_node: orig node which the temporary entry should be associated with + * @addr: mac address of the client + * @vid: VLAN id of the new temporary global translation table + * + * Return: true when temporary tt entry could be added, false otherwise + */ +bool batadv_tt_add_temporary_global_entry(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const unsigned char *addr, + unsigned short vid) +{ + /* ignore loop detect macs, they are not supposed to be in the tt local + * data as well. + */ + if (batadv_bla_is_loopdetect_mac(addr)) + return false; + + if (!batadv_tt_global_add(bat_priv, orig_node, addr, vid, + BATADV_TT_CLIENT_TEMP, + atomic_read(&orig_node->last_ttvn))) + return false; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Added temporary global client (addr: %pM, vid: %d, orig: %pM)\n", + addr, batadv_print_vid(vid), orig_node->orig); + + return true; +} + +/** + * batadv_tt_local_resize_to_mtu() - resize the local translation table fit the + * maximum packet size that can be transported through the mesh + * @soft_iface: netdev struct of the mesh interface + * + * Remove entries older than 'timeout' and half timeout if more entries need + * to be removed. 
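+ *
+ * Put differently: the loop below starts purging with half of the regular
+ * local timeout and keeps halving that value until the serialized table fits
+ * into packet_size_max, then commits the resulting changes in one go.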
+ */ +void batadv_tt_local_resize_to_mtu(struct net_device *soft_iface) +{ + struct batadv_priv *bat_priv = netdev_priv(soft_iface); + int packet_size_max = atomic_read(&bat_priv->packet_size_max); + int table_size, timeout = BATADV_TT_LOCAL_TIMEOUT / 2; + bool reduced = false; + + spin_lock_bh(&bat_priv->tt.commit_lock); + + while (true) { + table_size = batadv_tt_local_table_transmit_size(bat_priv); + if (packet_size_max >= table_size) + break; + + batadv_tt_local_purge(bat_priv, timeout); + batadv_tt_local_purge_pending_clients(bat_priv); + + timeout /= 2; + reduced = true; + net_ratelimited_function(batadv_info, soft_iface, + "Forced to purge local tt entries to fit new maximum fragment MTU (%i)\n", + packet_size_max); + } + + /* commit these changes immediately, to avoid synchronization problem + * with the TTVN + */ + if (reduced) + batadv_tt_local_commit_changes_nolock(bat_priv); + + spin_unlock_bh(&bat_priv->tt.commit_lock); +} + +/** + * batadv_tt_tvlv_ogm_handler_v1() - process incoming tt tvlv container + * @bat_priv: the bat priv with all the soft interface information + * @orig: the orig_node of the ogm + * @flags: flags indicating the tvlv state (see batadv_tvlv_handler_flags) + * @tvlv_value: tvlv buffer containing the gateway data + * @tvlv_value_len: tvlv buffer length + */ +static void batadv_tt_tvlv_ogm_handler_v1(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, void *tvlv_value, + u16 tvlv_value_len) +{ + struct batadv_tvlv_tt_vlan_data *tt_vlan; + struct batadv_tvlv_tt_change *tt_change; + struct batadv_tvlv_tt_data *tt_data; + u16 num_entries, num_vlan; + + if (tvlv_value_len < sizeof(*tt_data)) + return; + + tt_data = tvlv_value; + tvlv_value_len -= sizeof(*tt_data); + + num_vlan = ntohs(tt_data->num_vlan); + + if (tvlv_value_len < sizeof(*tt_vlan) * num_vlan) + return; + + tt_vlan = (struct batadv_tvlv_tt_vlan_data *)(tt_data + 1); + tt_change = (struct batadv_tvlv_tt_change *)(tt_vlan + num_vlan); + tvlv_value_len -= sizeof(*tt_vlan) * num_vlan; + + num_entries = batadv_tt_entries(tvlv_value_len); + + batadv_tt_update_orig(bat_priv, orig, tt_vlan, num_vlan, tt_change, + num_entries, tt_data->ttvn); +} + +/** + * batadv_tt_tvlv_unicast_handler_v1() - process incoming (unicast) tt tvlv + * container + * @bat_priv: the bat priv with all the soft interface information + * @src: mac address of tt tvlv sender + * @dst: mac address of tt tvlv recipient + * @tvlv_value: tvlv buffer containing the tt data + * @tvlv_value_len: tvlv buffer length + * + * Return: NET_RX_DROP if the tt tvlv is to be re-routed, NET_RX_SUCCESS + * otherwise. 
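+ *
+ * Roughly, the handler below strips the fixed tt_data header and the
+ * per-VLAN descriptors from tvlv_value_len and treats whatever space is left
+ * as appended batadv_tvlv_tt_change entries before dispatching on
+ * BATADV_TT_REQUEST / BATADV_TT_RESPONSE.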
+ */ +static int batadv_tt_tvlv_unicast_handler_v1(struct batadv_priv *bat_priv, + u8 *src, u8 *dst, + void *tvlv_value, + u16 tvlv_value_len) +{ + struct batadv_tvlv_tt_data *tt_data; + u16 tt_vlan_len, tt_num_entries; + char tt_flag; + bool ret; + + if (tvlv_value_len < sizeof(*tt_data)) + return NET_RX_SUCCESS; + + tt_data = tvlv_value; + tvlv_value_len -= sizeof(*tt_data); + + tt_vlan_len = sizeof(struct batadv_tvlv_tt_vlan_data); + tt_vlan_len *= ntohs(tt_data->num_vlan); + + if (tvlv_value_len < tt_vlan_len) + return NET_RX_SUCCESS; + + tvlv_value_len -= tt_vlan_len; + tt_num_entries = batadv_tt_entries(tvlv_value_len); + + switch (tt_data->flags & BATADV_TT_DATA_TYPE_MASK) { + case BATADV_TT_REQUEST: + batadv_inc_counter(bat_priv, BATADV_CNT_TT_REQUEST_RX); + + /* If this node cannot provide a TT response the tt_request is + * forwarded + */ + ret = batadv_send_tt_response(bat_priv, tt_data, src, dst); + if (!ret) { + if (tt_data->flags & BATADV_TT_FULL_TABLE) + tt_flag = 'F'; + else + tt_flag = '.'; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Routing TT_REQUEST to %pM [%c]\n", + dst, tt_flag); + /* tvlv API will re-route the packet */ + return NET_RX_DROP; + } + break; + case BATADV_TT_RESPONSE: + batadv_inc_counter(bat_priv, BATADV_CNT_TT_RESPONSE_RX); + + if (batadv_is_my_mac(bat_priv, dst)) { + batadv_handle_tt_response(bat_priv, tt_data, + src, tt_num_entries); + return NET_RX_SUCCESS; + } + + if (tt_data->flags & BATADV_TT_FULL_TABLE) + tt_flag = 'F'; + else + tt_flag = '.'; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Routing TT_RESPONSE to %pM [%c]\n", dst, tt_flag); + + /* tvlv API will re-route the packet */ + return NET_RX_DROP; + } + + return NET_RX_SUCCESS; +} + +/** + * batadv_roam_tvlv_unicast_handler_v1() - process incoming tt roam tvlv + * container + * @bat_priv: the bat priv with all the soft interface information + * @src: mac address of tt tvlv sender + * @dst: mac address of tt tvlv recipient + * @tvlv_value: tvlv buffer containing the tt data + * @tvlv_value_len: tvlv buffer length + * + * Return: NET_RX_DROP if the tt roam tvlv is to be re-routed, NET_RX_SUCCESS + * otherwise. + */ +static int batadv_roam_tvlv_unicast_handler_v1(struct batadv_priv *bat_priv, + u8 *src, u8 *dst, + void *tvlv_value, + u16 tvlv_value_len) +{ + struct batadv_tvlv_roam_adv *roaming_adv; + struct batadv_orig_node *orig_node = NULL; + + /* If this node is not the intended recipient of the + * roaming advertisement the packet is forwarded + * (the tvlv API will re-route the packet). + */ + if (!batadv_is_my_mac(bat_priv, dst)) + return NET_RX_DROP; + + if (tvlv_value_len < sizeof(*roaming_adv)) + goto out; + + orig_node = batadv_orig_hash_find(bat_priv, src); + if (!orig_node) + goto out; + + batadv_inc_counter(bat_priv, BATADV_CNT_TT_ROAM_ADV_RX); + roaming_adv = tvlv_value; + + batadv_dbg(BATADV_DBG_TT, bat_priv, + "Received ROAMING_ADV from %pM (client %pM)\n", + src, roaming_adv->client); + + batadv_tt_global_add(bat_priv, orig_node, roaming_adv->client, + ntohs(roaming_adv->vid), BATADV_TT_CLIENT_ROAM, + atomic_read(&orig_node->last_ttvn) + 1); + +out: + batadv_orig_node_put(orig_node); + return NET_RX_SUCCESS; +} + +/** + * batadv_tt_init() - initialise the translation table internals + * @bat_priv: the bat priv with all the soft interface information + * + * Return: 0 on success or negative error number in case of failure. 
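+ *
+ * A minimal caller sketch (the error label below is purely illustrative):
+ *
+ *	ret = batadv_tt_init(bat_priv);
+ *	if (ret < 0)
+ *		goto err_undo_earlier_init;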
+ */ +int batadv_tt_init(struct batadv_priv *bat_priv) +{ + int ret; + + /* synchronized flags must be remote */ + BUILD_BUG_ON(!(BATADV_TT_SYNC_MASK & BATADV_TT_REMOTE_MASK)); + + ret = batadv_tt_local_init(bat_priv); + if (ret < 0) + return ret; + + ret = batadv_tt_global_init(bat_priv); + if (ret < 0) { + batadv_tt_local_table_free(bat_priv); + return ret; + } + + batadv_tvlv_handler_register(bat_priv, batadv_tt_tvlv_ogm_handler_v1, + batadv_tt_tvlv_unicast_handler_v1, + BATADV_TVLV_TT, 1, BATADV_NO_FLAGS); + + batadv_tvlv_handler_register(bat_priv, NULL, + batadv_roam_tvlv_unicast_handler_v1, + BATADV_TVLV_ROAM, 1, BATADV_NO_FLAGS); + + INIT_DELAYED_WORK(&bat_priv->tt.work, batadv_tt_purge); + queue_delayed_work(batadv_event_workqueue, &bat_priv->tt.work, + msecs_to_jiffies(BATADV_TT_WORK_PERIOD)); + + return 1; +} + +/** + * batadv_tt_global_is_isolated() - check if a client is marked as isolated + * @bat_priv: the bat priv with all the soft interface information + * @addr: the mac address of the client + * @vid: the identifier of the VLAN where this client is connected + * + * Return: true if the client is marked with the TT_CLIENT_ISOLA flag, false + * otherwise + */ +bool batadv_tt_global_is_isolated(struct batadv_priv *bat_priv, + const u8 *addr, unsigned short vid) +{ + struct batadv_tt_global_entry *tt; + bool ret; + + tt = batadv_tt_global_hash_find(bat_priv, addr, vid); + if (!tt) + return false; + + ret = tt->common.flags & BATADV_TT_CLIENT_ISOLA; + + batadv_tt_global_entry_put(tt); + + return ret; +} + +/** + * batadv_tt_cache_init() - Initialize tt memory object cache + * + * Return: 0 on success or negative error number in case of failure. + */ +int __init batadv_tt_cache_init(void) +{ + size_t tl_size = sizeof(struct batadv_tt_local_entry); + size_t tg_size = sizeof(struct batadv_tt_global_entry); + size_t tt_orig_size = sizeof(struct batadv_tt_orig_list_entry); + size_t tt_change_size = sizeof(struct batadv_tt_change_node); + size_t tt_req_size = sizeof(struct batadv_tt_req_node); + size_t tt_roam_size = sizeof(struct batadv_tt_roam_node); + + batadv_tl_cache = kmem_cache_create("batadv_tl_cache", tl_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tl_cache) + return -ENOMEM; + + batadv_tg_cache = kmem_cache_create("batadv_tg_cache", tg_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tg_cache) + goto err_tt_tl_destroy; + + batadv_tt_orig_cache = kmem_cache_create("batadv_tt_orig_cache", + tt_orig_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tt_orig_cache) + goto err_tt_tg_destroy; + + batadv_tt_change_cache = kmem_cache_create("batadv_tt_change_cache", + tt_change_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tt_change_cache) + goto err_tt_orig_destroy; + + batadv_tt_req_cache = kmem_cache_create("batadv_tt_req_cache", + tt_req_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tt_req_cache) + goto err_tt_change_destroy; + + batadv_tt_roam_cache = kmem_cache_create("batadv_tt_roam_cache", + tt_roam_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!batadv_tt_roam_cache) + goto err_tt_req_destroy; + + return 0; + +err_tt_req_destroy: + kmem_cache_destroy(batadv_tt_req_cache); + batadv_tt_req_cache = NULL; +err_tt_change_destroy: + kmem_cache_destroy(batadv_tt_change_cache); + batadv_tt_change_cache = NULL; +err_tt_orig_destroy: + kmem_cache_destroy(batadv_tt_orig_cache); + batadv_tt_orig_cache = NULL; +err_tt_tg_destroy: + kmem_cache_destroy(batadv_tg_cache); + batadv_tg_cache = NULL; +err_tt_tl_destroy: + kmem_cache_destroy(batadv_tl_cache); + 
batadv_tl_cache = NULL; + + return -ENOMEM; +} + +/** + * batadv_tt_cache_destroy() - Destroy tt memory object cache + */ +void batadv_tt_cache_destroy(void) +{ + kmem_cache_destroy(batadv_tl_cache); + kmem_cache_destroy(batadv_tg_cache); + kmem_cache_destroy(batadv_tt_orig_cache); + kmem_cache_destroy(batadv_tt_change_cache); + kmem_cache_destroy(batadv_tt_req_cache); + kmem_cache_destroy(batadv_tt_roam_cache); +} diff --git a/net/batman-adv/translation-table.h b/net/batman-adv/translation-table.h new file mode 100644 index 000000000..d18740d9a --- /dev/null +++ b/net/batman-adv/translation-table.h @@ -0,0 +1,74 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich, Antonio Quartulli + */ + +#ifndef _NET_BATMAN_ADV_TRANSLATION_TABLE_H_ +#define _NET_BATMAN_ADV_TRANSLATION_TABLE_H_ + +#include "main.h" + +#include +#include +#include +#include +#include + +int batadv_tt_init(struct batadv_priv *bat_priv); +bool batadv_tt_local_add(struct net_device *soft_iface, const u8 *addr, + unsigned short vid, int ifindex, u32 mark); +u16 batadv_tt_local_remove(struct batadv_priv *bat_priv, + const u8 *addr, unsigned short vid, + const char *message, bool roaming); +int batadv_tt_local_dump(struct sk_buff *msg, struct netlink_callback *cb); +int batadv_tt_global_dump(struct sk_buff *msg, struct netlink_callback *cb); +void batadv_tt_global_del_orig(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + s32 match_vid, const char *message); +struct batadv_tt_global_entry * +batadv_tt_global_hash_find(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid); +void batadv_tt_global_entry_release(struct kref *ref); +int batadv_tt_global_hash_count(struct batadv_priv *bat_priv, + const u8 *addr, unsigned short vid); +struct batadv_orig_node *batadv_transtable_search(struct batadv_priv *bat_priv, + const u8 *src, const u8 *addr, + unsigned short vid); +void batadv_tt_free(struct batadv_priv *bat_priv); +bool batadv_is_my_client(struct batadv_priv *bat_priv, const u8 *addr, + unsigned short vid); +bool batadv_is_ap_isolated(struct batadv_priv *bat_priv, u8 *src, u8 *dst, + unsigned short vid); +void batadv_tt_local_commit_changes(struct batadv_priv *bat_priv); +bool batadv_tt_global_client_is_roaming(struct batadv_priv *bat_priv, + u8 *addr, unsigned short vid); +bool batadv_tt_local_client_is_roaming(struct batadv_priv *bat_priv, + u8 *addr, unsigned short vid); +void batadv_tt_local_resize_to_mtu(struct net_device *soft_iface); +bool batadv_tt_add_temporary_global_entry(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig_node, + const unsigned char *addr, + unsigned short vid); +bool batadv_tt_global_is_isolated(struct batadv_priv *bat_priv, + const u8 *addr, unsigned short vid); + +int batadv_tt_cache_init(void); +void batadv_tt_cache_destroy(void); + +/** + * batadv_tt_global_entry_put() - decrement the tt_global_entry refcounter and + * possibly release it + * @tt_global_entry: tt_global_entry to be free'd + */ +static inline void +batadv_tt_global_entry_put(struct batadv_tt_global_entry *tt_global_entry) +{ + if (!tt_global_entry) + return; + + kref_put(&tt_global_entry->common.refcount, + batadv_tt_global_entry_release); +} + +#endif /* _NET_BATMAN_ADV_TRANSLATION_TABLE_H_ */ diff --git a/net/batman-adv/tvlv.c b/net/batman-adv/tvlv.c new file mode 100644 index 000000000..7ec2e2343 --- /dev/null +++ b/net/batman-adv/tvlv.c @@ -0,0 +1,636 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (C) 
B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#include "main.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "originator.h" +#include "send.h" +#include "tvlv.h" + +/** + * batadv_tvlv_handler_release() - release tvlv handler from lists and queue for + * free after rcu grace period + * @ref: kref pointer of the tvlv + */ +static void batadv_tvlv_handler_release(struct kref *ref) +{ + struct batadv_tvlv_handler *tvlv_handler; + + tvlv_handler = container_of(ref, struct batadv_tvlv_handler, refcount); + kfree_rcu(tvlv_handler, rcu); +} + +/** + * batadv_tvlv_handler_put() - decrement the tvlv container refcounter and + * possibly release it + * @tvlv_handler: the tvlv handler to free + */ +static void batadv_tvlv_handler_put(struct batadv_tvlv_handler *tvlv_handler) +{ + if (!tvlv_handler) + return; + + kref_put(&tvlv_handler->refcount, batadv_tvlv_handler_release); +} + +/** + * batadv_tvlv_handler_get() - retrieve tvlv handler from the tvlv handler list + * based on the provided type and version (both need to match) + * @bat_priv: the bat priv with all the soft interface information + * @type: tvlv handler type to look for + * @version: tvlv handler version to look for + * + * Return: tvlv handler if found or NULL otherwise. + */ +static struct batadv_tvlv_handler * +batadv_tvlv_handler_get(struct batadv_priv *bat_priv, u8 type, u8 version) +{ + struct batadv_tvlv_handler *tvlv_handler_tmp, *tvlv_handler = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tvlv_handler_tmp, + &bat_priv->tvlv.handler_list, list) { + if (tvlv_handler_tmp->type != type) + continue; + + if (tvlv_handler_tmp->version != version) + continue; + + if (!kref_get_unless_zero(&tvlv_handler_tmp->refcount)) + continue; + + tvlv_handler = tvlv_handler_tmp; + break; + } + rcu_read_unlock(); + + return tvlv_handler; +} + +/** + * batadv_tvlv_container_release() - release tvlv from lists and free + * @ref: kref pointer of the tvlv + */ +static void batadv_tvlv_container_release(struct kref *ref) +{ + struct batadv_tvlv_container *tvlv; + + tvlv = container_of(ref, struct batadv_tvlv_container, refcount); + kfree(tvlv); +} + +/** + * batadv_tvlv_container_put() - decrement the tvlv container refcounter and + * possibly release it + * @tvlv: the tvlv container to free + */ +static void batadv_tvlv_container_put(struct batadv_tvlv_container *tvlv) +{ + if (!tvlv) + return; + + kref_put(&tvlv->refcount, batadv_tvlv_container_release); +} + +/** + * batadv_tvlv_container_get() - retrieve tvlv container from the tvlv container + * list based on the provided type and version (both need to match) + * @bat_priv: the bat priv with all the soft interface information + * @type: tvlv container type to look for + * @version: tvlv container version to look for + * + * Has to be called with the appropriate locks being acquired + * (tvlv.container_list_lock). + * + * Return: tvlv container if found or NULL otherwise. 
+ */ +static struct batadv_tvlv_container * +batadv_tvlv_container_get(struct batadv_priv *bat_priv, u8 type, u8 version) +{ + struct batadv_tvlv_container *tvlv_tmp, *tvlv = NULL; + + lockdep_assert_held(&bat_priv->tvlv.container_list_lock); + + hlist_for_each_entry(tvlv_tmp, &bat_priv->tvlv.container_list, list) { + if (tvlv_tmp->tvlv_hdr.type != type) + continue; + + if (tvlv_tmp->tvlv_hdr.version != version) + continue; + + kref_get(&tvlv_tmp->refcount); + tvlv = tvlv_tmp; + break; + } + + return tvlv; +} + +/** + * batadv_tvlv_container_list_size() - calculate the size of the tvlv container + * list entries + * @bat_priv: the bat priv with all the soft interface information + * + * Has to be called with the appropriate locks being acquired + * (tvlv.container_list_lock). + * + * Return: size of all currently registered tvlv containers in bytes. + */ +static u16 batadv_tvlv_container_list_size(struct batadv_priv *bat_priv) +{ + struct batadv_tvlv_container *tvlv; + u16 tvlv_len = 0; + + lockdep_assert_held(&bat_priv->tvlv.container_list_lock); + + hlist_for_each_entry(tvlv, &bat_priv->tvlv.container_list, list) { + tvlv_len += sizeof(struct batadv_tvlv_hdr); + tvlv_len += ntohs(tvlv->tvlv_hdr.len); + } + + return tvlv_len; +} + +/** + * batadv_tvlv_container_remove() - remove tvlv container from the tvlv + * container list + * @bat_priv: the bat priv with all the soft interface information + * @tvlv: the to be removed tvlv container + * + * Has to be called with the appropriate locks being acquired + * (tvlv.container_list_lock). + */ +static void batadv_tvlv_container_remove(struct batadv_priv *bat_priv, + struct batadv_tvlv_container *tvlv) +{ + lockdep_assert_held(&bat_priv->tvlv.container_list_lock); + + if (!tvlv) + return; + + hlist_del(&tvlv->list); + + /* first call to decrement the counter, second call to free */ + batadv_tvlv_container_put(tvlv); + batadv_tvlv_container_put(tvlv); +} + +/** + * batadv_tvlv_container_unregister() - unregister tvlv container based on the + * provided type and version (both need to match) + * @bat_priv: the bat priv with all the soft interface information + * @type: tvlv container type to unregister + * @version: tvlv container type to unregister + */ +void batadv_tvlv_container_unregister(struct batadv_priv *bat_priv, + u8 type, u8 version) +{ + struct batadv_tvlv_container *tvlv; + + spin_lock_bh(&bat_priv->tvlv.container_list_lock); + tvlv = batadv_tvlv_container_get(bat_priv, type, version); + batadv_tvlv_container_remove(bat_priv, tvlv); + spin_unlock_bh(&bat_priv->tvlv.container_list_lock); +} + +/** + * batadv_tvlv_container_register() - register tvlv type, version and content + * to be propagated with each (primary interface) OGM + * @bat_priv: the bat priv with all the soft interface information + * @type: tvlv container type + * @version: tvlv container version + * @tvlv_value: tvlv container content + * @tvlv_value_len: tvlv container content length + * + * If a container of the same type and version was already registered the new + * content is going to replace the old one. 
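+ *
+ * As a sketch (the type constant and payload below are purely illustrative),
+ * a feature announcing a small capability blob in every OGM would call:
+ *
+ *	batadv_tvlv_container_register(bat_priv, MY_TVLV_TYPE, 1,
+ *				       &my_blob, sizeof(my_blob));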
+ */ +void batadv_tvlv_container_register(struct batadv_priv *bat_priv, + u8 type, u8 version, + void *tvlv_value, u16 tvlv_value_len) +{ + struct batadv_tvlv_container *tvlv_old, *tvlv_new; + + if (!tvlv_value) + tvlv_value_len = 0; + + tvlv_new = kzalloc(sizeof(*tvlv_new) + tvlv_value_len, GFP_ATOMIC); + if (!tvlv_new) + return; + + tvlv_new->tvlv_hdr.version = version; + tvlv_new->tvlv_hdr.type = type; + tvlv_new->tvlv_hdr.len = htons(tvlv_value_len); + + memcpy(tvlv_new + 1, tvlv_value, ntohs(tvlv_new->tvlv_hdr.len)); + INIT_HLIST_NODE(&tvlv_new->list); + kref_init(&tvlv_new->refcount); + + spin_lock_bh(&bat_priv->tvlv.container_list_lock); + tvlv_old = batadv_tvlv_container_get(bat_priv, type, version); + batadv_tvlv_container_remove(bat_priv, tvlv_old); + + kref_get(&tvlv_new->refcount); + hlist_add_head(&tvlv_new->list, &bat_priv->tvlv.container_list); + spin_unlock_bh(&bat_priv->tvlv.container_list_lock); + + /* don't return reference to new tvlv_container */ + batadv_tvlv_container_put(tvlv_new); +} + +/** + * batadv_tvlv_realloc_packet_buff() - reallocate packet buffer to accommodate + * requested packet size + * @packet_buff: packet buffer + * @packet_buff_len: packet buffer size + * @min_packet_len: requested packet minimum size + * @additional_packet_len: requested additional packet size on top of minimum + * size + * + * Return: true of the packet buffer could be changed to the requested size, + * false otherwise. + */ +static bool batadv_tvlv_realloc_packet_buff(unsigned char **packet_buff, + int *packet_buff_len, + int min_packet_len, + int additional_packet_len) +{ + unsigned char *new_buff; + + new_buff = kmalloc(min_packet_len + additional_packet_len, GFP_ATOMIC); + + /* keep old buffer if kmalloc should fail */ + if (!new_buff) + return false; + + memcpy(new_buff, *packet_buff, min_packet_len); + kfree(*packet_buff); + *packet_buff = new_buff; + *packet_buff_len = min_packet_len + additional_packet_len; + + return true; +} + +/** + * batadv_tvlv_container_ogm_append() - append tvlv container content to given + * OGM packet buffer + * @bat_priv: the bat priv with all the soft interface information + * @packet_buff: ogm packet buffer + * @packet_buff_len: ogm packet buffer size including ogm header and tvlv + * content + * @packet_min_len: ogm header size to be preserved for the OGM itself + * + * The ogm packet might be enlarged or shrunk depending on the current size + * and the size of the to-be-appended tvlv containers. + * + * Return: size of all appended tvlv containers in bytes. 
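+ *
+ * Internally this first reallocates *packet_buff to packet_min_len plus the
+ * summed size of all registered containers and then copies each tvlv_hdr and
+ * its payload back to back behind the preserved OGM header.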
+ */ +u16 batadv_tvlv_container_ogm_append(struct batadv_priv *bat_priv, + unsigned char **packet_buff, + int *packet_buff_len, int packet_min_len) +{ + struct batadv_tvlv_container *tvlv; + struct batadv_tvlv_hdr *tvlv_hdr; + u16 tvlv_value_len; + void *tvlv_value; + bool ret; + + spin_lock_bh(&bat_priv->tvlv.container_list_lock); + tvlv_value_len = batadv_tvlv_container_list_size(bat_priv); + + ret = batadv_tvlv_realloc_packet_buff(packet_buff, packet_buff_len, + packet_min_len, tvlv_value_len); + + if (!ret) + goto end; + + if (!tvlv_value_len) + goto end; + + tvlv_value = (*packet_buff) + packet_min_len; + + hlist_for_each_entry(tvlv, &bat_priv->tvlv.container_list, list) { + tvlv_hdr = tvlv_value; + tvlv_hdr->type = tvlv->tvlv_hdr.type; + tvlv_hdr->version = tvlv->tvlv_hdr.version; + tvlv_hdr->len = tvlv->tvlv_hdr.len; + tvlv_value = tvlv_hdr + 1; + memcpy(tvlv_value, tvlv + 1, ntohs(tvlv->tvlv_hdr.len)); + tvlv_value = (u8 *)tvlv_value + ntohs(tvlv->tvlv_hdr.len); + } + +end: + spin_unlock_bh(&bat_priv->tvlv.container_list_lock); + return tvlv_value_len; +} + +/** + * batadv_tvlv_call_handler() - parse the given tvlv buffer to call the + * appropriate handlers + * @bat_priv: the bat priv with all the soft interface information + * @tvlv_handler: tvlv callback function handling the tvlv content + * @ogm_source: flag indicating whether the tvlv is an ogm or a unicast packet + * @orig_node: orig node emitting the ogm packet + * @src: source mac address of the unicast packet + * @dst: destination mac address of the unicast packet + * @tvlv_value: tvlv content + * @tvlv_value_len: tvlv content length + * + * Return: success if the handler was not found or the return value of the + * handler callback. + */ +static int batadv_tvlv_call_handler(struct batadv_priv *bat_priv, + struct batadv_tvlv_handler *tvlv_handler, + bool ogm_source, + struct batadv_orig_node *orig_node, + u8 *src, u8 *dst, + void *tvlv_value, u16 tvlv_value_len) +{ + if (!tvlv_handler) + return NET_RX_SUCCESS; + + if (ogm_source) { + if (!tvlv_handler->ogm_handler) + return NET_RX_SUCCESS; + + if (!orig_node) + return NET_RX_SUCCESS; + + tvlv_handler->ogm_handler(bat_priv, orig_node, + BATADV_NO_FLAGS, + tvlv_value, tvlv_value_len); + tvlv_handler->flags |= BATADV_TVLV_HANDLER_OGM_CALLED; + } else { + if (!src) + return NET_RX_SUCCESS; + + if (!dst) + return NET_RX_SUCCESS; + + if (!tvlv_handler->unicast_handler) + return NET_RX_SUCCESS; + + return tvlv_handler->unicast_handler(bat_priv, src, + dst, tvlv_value, + tvlv_value_len); + } + + return NET_RX_SUCCESS; +} + +/** + * batadv_tvlv_containers_process() - parse the given tvlv buffer to call the + * appropriate handlers + * @bat_priv: the bat priv with all the soft interface information + * @ogm_source: flag indicating whether the tvlv is an ogm or a unicast packet + * @orig_node: orig node emitting the ogm packet + * @src: source mac address of the unicast packet + * @dst: destination mac address of the unicast packet + * @tvlv_value: tvlv content + * @tvlv_value_len: tvlv content length + * + * Return: success when processing an OGM or the return value of all called + * handler callbacks. 
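+ *
+ * The parsing loop below walks the buffer header by header, looks up a
+ * handler matching each (type, version) pair and, for OGM sourced payloads,
+ * afterwards notifies handlers flagged BATADV_TVLV_HANDLER_OGM_CIFNOTFND
+ * that were not called during this round.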
+ */ +int batadv_tvlv_containers_process(struct batadv_priv *bat_priv, + bool ogm_source, + struct batadv_orig_node *orig_node, + u8 *src, u8 *dst, + void *tvlv_value, u16 tvlv_value_len) +{ + struct batadv_tvlv_handler *tvlv_handler; + struct batadv_tvlv_hdr *tvlv_hdr; + u16 tvlv_value_cont_len; + u8 cifnotfound = BATADV_TVLV_HANDLER_OGM_CIFNOTFND; + int ret = NET_RX_SUCCESS; + + while (tvlv_value_len >= sizeof(*tvlv_hdr)) { + tvlv_hdr = tvlv_value; + tvlv_value_cont_len = ntohs(tvlv_hdr->len); + tvlv_value = tvlv_hdr + 1; + tvlv_value_len -= sizeof(*tvlv_hdr); + + if (tvlv_value_cont_len > tvlv_value_len) + break; + + tvlv_handler = batadv_tvlv_handler_get(bat_priv, + tvlv_hdr->type, + tvlv_hdr->version); + + ret |= batadv_tvlv_call_handler(bat_priv, tvlv_handler, + ogm_source, orig_node, + src, dst, tvlv_value, + tvlv_value_cont_len); + batadv_tvlv_handler_put(tvlv_handler); + tvlv_value = (u8 *)tvlv_value + tvlv_value_cont_len; + tvlv_value_len -= tvlv_value_cont_len; + } + + if (!ogm_source) + return ret; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tvlv_handler, + &bat_priv->tvlv.handler_list, list) { + if ((tvlv_handler->flags & BATADV_TVLV_HANDLER_OGM_CIFNOTFND) && + !(tvlv_handler->flags & BATADV_TVLV_HANDLER_OGM_CALLED)) + tvlv_handler->ogm_handler(bat_priv, orig_node, + cifnotfound, NULL, 0); + + tvlv_handler->flags &= ~BATADV_TVLV_HANDLER_OGM_CALLED; + } + rcu_read_unlock(); + + return NET_RX_SUCCESS; +} + +/** + * batadv_tvlv_ogm_receive() - process an incoming ogm and call the appropriate + * handlers + * @bat_priv: the bat priv with all the soft interface information + * @batadv_ogm_packet: ogm packet containing the tvlv containers + * @orig_node: orig node emitting the ogm packet + */ +void batadv_tvlv_ogm_receive(struct batadv_priv *bat_priv, + struct batadv_ogm_packet *batadv_ogm_packet, + struct batadv_orig_node *orig_node) +{ + void *tvlv_value; + u16 tvlv_value_len; + + if (!batadv_ogm_packet) + return; + + tvlv_value_len = ntohs(batadv_ogm_packet->tvlv_len); + if (!tvlv_value_len) + return; + + tvlv_value = batadv_ogm_packet + 1; + + batadv_tvlv_containers_process(bat_priv, true, orig_node, NULL, NULL, + tvlv_value, tvlv_value_len); +} + +/** + * batadv_tvlv_handler_register() - register tvlv handler based on the provided + * type and version (both need to match) for ogm tvlv payload and/or unicast + * payload + * @bat_priv: the bat priv with all the soft interface information + * @optr: ogm tvlv handler callback function. This function receives the orig + * node, flags and the tvlv content as argument to process. + * @uptr: unicast tvlv handler callback function. This function receives the + * source & destination of the unicast packet as well as the tvlv content + * to process. 
+ * @type: tvlv handler type to be registered + * @version: tvlv handler version to be registered + * @flags: flags to enable or disable TVLV API behavior + */ +void batadv_tvlv_handler_register(struct batadv_priv *bat_priv, + void (*optr)(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, + u16 tvlv_value_len), + int (*uptr)(struct batadv_priv *bat_priv, + u8 *src, u8 *dst, + void *tvlv_value, + u16 tvlv_value_len), + u8 type, u8 version, u8 flags) +{ + struct batadv_tvlv_handler *tvlv_handler; + + spin_lock_bh(&bat_priv->tvlv.handler_list_lock); + + tvlv_handler = batadv_tvlv_handler_get(bat_priv, type, version); + if (tvlv_handler) { + spin_unlock_bh(&bat_priv->tvlv.handler_list_lock); + batadv_tvlv_handler_put(tvlv_handler); + return; + } + + tvlv_handler = kzalloc(sizeof(*tvlv_handler), GFP_ATOMIC); + if (!tvlv_handler) { + spin_unlock_bh(&bat_priv->tvlv.handler_list_lock); + return; + } + + tvlv_handler->ogm_handler = optr; + tvlv_handler->unicast_handler = uptr; + tvlv_handler->type = type; + tvlv_handler->version = version; + tvlv_handler->flags = flags; + kref_init(&tvlv_handler->refcount); + INIT_HLIST_NODE(&tvlv_handler->list); + + kref_get(&tvlv_handler->refcount); + hlist_add_head_rcu(&tvlv_handler->list, &bat_priv->tvlv.handler_list); + spin_unlock_bh(&bat_priv->tvlv.handler_list_lock); + + /* don't return reference to new tvlv_handler */ + batadv_tvlv_handler_put(tvlv_handler); +} + +/** + * batadv_tvlv_handler_unregister() - unregister tvlv handler based on the + * provided type and version (both need to match) + * @bat_priv: the bat priv with all the soft interface information + * @type: tvlv handler type to be unregistered + * @version: tvlv handler version to be unregistered + */ +void batadv_tvlv_handler_unregister(struct batadv_priv *bat_priv, + u8 type, u8 version) +{ + struct batadv_tvlv_handler *tvlv_handler; + + tvlv_handler = batadv_tvlv_handler_get(bat_priv, type, version); + if (!tvlv_handler) + return; + + batadv_tvlv_handler_put(tvlv_handler); + spin_lock_bh(&bat_priv->tvlv.handler_list_lock); + hlist_del_rcu(&tvlv_handler->list); + spin_unlock_bh(&bat_priv->tvlv.handler_list_lock); + batadv_tvlv_handler_put(tvlv_handler); +} + +/** + * batadv_tvlv_unicast_send() - send a unicast packet with tvlv payload to the + * specified host + * @bat_priv: the bat priv with all the soft interface information + * @src: source mac address of the unicast packet + * @dst: destination mac address of the unicast packet + * @type: tvlv type + * @version: tvlv version + * @tvlv_value: tvlv content + * @tvlv_value_len: tvlv content length + */ +void batadv_tvlv_unicast_send(struct batadv_priv *bat_priv, const u8 *src, + const u8 *dst, u8 type, u8 version, + void *tvlv_value, u16 tvlv_value_len) +{ + struct batadv_unicast_tvlv_packet *unicast_tvlv_packet; + struct batadv_tvlv_hdr *tvlv_hdr; + struct batadv_orig_node *orig_node; + struct sk_buff *skb; + unsigned char *tvlv_buff; + unsigned int tvlv_len; + ssize_t hdr_len = sizeof(*unicast_tvlv_packet); + + orig_node = batadv_orig_hash_find(bat_priv, dst); + if (!orig_node) + return; + + tvlv_len = sizeof(*tvlv_hdr) + tvlv_value_len; + + skb = netdev_alloc_skb_ip_align(NULL, ETH_HLEN + hdr_len + tvlv_len); + if (!skb) + goto out; + + skb->priority = TC_PRIO_CONTROL; + skb_reserve(skb, ETH_HLEN); + tvlv_buff = skb_put(skb, sizeof(*unicast_tvlv_packet) + tvlv_len); + unicast_tvlv_packet = (struct batadv_unicast_tvlv_packet *)tvlv_buff; + unicast_tvlv_packet->packet_type = BATADV_UNICAST_TVLV; 
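+	/* fill in the remaining unicast tvlv header fields: compat version,
+	 * TTL, padding, total tvlv length and the source/destination
+	 * originator addresses
+	 */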
+ unicast_tvlv_packet->version = BATADV_COMPAT_VERSION; + unicast_tvlv_packet->ttl = BATADV_TTL; + unicast_tvlv_packet->reserved = 0; + unicast_tvlv_packet->tvlv_len = htons(tvlv_len); + unicast_tvlv_packet->align = 0; + ether_addr_copy(unicast_tvlv_packet->src, src); + ether_addr_copy(unicast_tvlv_packet->dst, dst); + + tvlv_buff = (unsigned char *)(unicast_tvlv_packet + 1); + tvlv_hdr = (struct batadv_tvlv_hdr *)tvlv_buff; + tvlv_hdr->version = version; + tvlv_hdr->type = type; + tvlv_hdr->len = htons(tvlv_value_len); + tvlv_buff += sizeof(*tvlv_hdr); + memcpy(tvlv_buff, tvlv_value, tvlv_value_len); + + batadv_send_skb_to_orig(skb, orig_node, NULL); +out: + batadv_orig_node_put(orig_node); +} diff --git a/net/batman-adv/tvlv.h b/net/batman-adv/tvlv.h new file mode 100644 index 000000000..4cf8af00f --- /dev/null +++ b/net/batman-adv/tvlv.h @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_TVLV_H_ +#define _NET_BATMAN_ADV_TVLV_H_ + +#include "main.h" + +#include +#include + +void batadv_tvlv_container_register(struct batadv_priv *bat_priv, + u8 type, u8 version, + void *tvlv_value, u16 tvlv_value_len); +u16 batadv_tvlv_container_ogm_append(struct batadv_priv *bat_priv, + unsigned char **packet_buff, + int *packet_buff_len, int packet_min_len); +void batadv_tvlv_ogm_receive(struct batadv_priv *bat_priv, + struct batadv_ogm_packet *batadv_ogm_packet, + struct batadv_orig_node *orig_node); +void batadv_tvlv_container_unregister(struct batadv_priv *bat_priv, + u8 type, u8 version); + +void batadv_tvlv_handler_register(struct batadv_priv *bat_priv, + void (*optr)(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, + void *tvlv_value, + u16 tvlv_value_len), + int (*uptr)(struct batadv_priv *bat_priv, + u8 *src, u8 *dst, + void *tvlv_value, + u16 tvlv_value_len), + u8 type, u8 version, u8 flags); +void batadv_tvlv_handler_unregister(struct batadv_priv *bat_priv, + u8 type, u8 version); +int batadv_tvlv_containers_process(struct batadv_priv *bat_priv, + bool ogm_source, + struct batadv_orig_node *orig_node, + u8 *src, u8 *dst, + void *tvlv_buff, u16 tvlv_buff_len); +void batadv_tvlv_unicast_send(struct batadv_priv *bat_priv, const u8 *src, + const u8 *dst, u8 type, u8 version, + void *tvlv_value, u16 tvlv_value_len); + +#endif /* _NET_BATMAN_ADV_TVLV_H_ */ diff --git a/net/batman-adv/types.h b/net/batman-adv/types.h new file mode 100644 index 000000000..76791815b --- /dev/null +++ b/net/batman-adv/types.h @@ -0,0 +1,2378 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright (C) B.A.T.M.A.N. contributors: + * + * Marek Lindner, Simon Wunderlich + */ + +#ifndef _NET_BATMAN_ADV_TYPES_H_ +#define _NET_BATMAN_ADV_TYPES_H_ + +#ifndef _NET_BATMAN_ADV_MAIN_H_ +#error only "main.h" can be included directly +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for linux/wait.h */ +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_BATMAN_ADV_DAT + +/** + * typedef batadv_dat_addr_t - type used for all DHT addresses + * + * If it is changed, BATADV_DAT_ADDR_MAX is changed as well. 
+ * + * *Please be careful: batadv_dat_addr_t must be UNSIGNED* + */ +typedef u16 batadv_dat_addr_t; + +#endif /* CONFIG_BATMAN_ADV_DAT */ + +/** + * enum batadv_dhcp_recipient - dhcp destination + */ +enum batadv_dhcp_recipient { + /** @BATADV_DHCP_NO: packet is not a dhcp message */ + BATADV_DHCP_NO = 0, + + /** @BATADV_DHCP_TO_SERVER: dhcp message is directed to a server */ + BATADV_DHCP_TO_SERVER, + + /** @BATADV_DHCP_TO_CLIENT: dhcp message is directed to a client */ + BATADV_DHCP_TO_CLIENT, +}; + +/** + * BATADV_TT_REMOTE_MASK - bitmask selecting the flags that are sent over the + * wire only + */ +#define BATADV_TT_REMOTE_MASK 0x00FF + +/** + * BATADV_TT_SYNC_MASK - bitmask of the flags that need to be kept in sync + * among the nodes. These flags are used to compute the global/local CRC + */ +#define BATADV_TT_SYNC_MASK 0x00F0 + +/** + * struct batadv_hard_iface_bat_iv - per hard-interface B.A.T.M.A.N. IV data + */ +struct batadv_hard_iface_bat_iv { + /** @ogm_buff: buffer holding the OGM packet */ + unsigned char *ogm_buff; + + /** @ogm_buff_len: length of the OGM packet buffer */ + int ogm_buff_len; + + /** @ogm_seqno: OGM sequence number - used to identify each OGM */ + atomic_t ogm_seqno; + + /** @ogm_buff_mutex: lock protecting ogm_buff and ogm_buff_len */ + struct mutex ogm_buff_mutex; +}; + +/** + * enum batadv_v_hard_iface_flags - interface flags useful to B.A.T.M.A.N. V + */ +enum batadv_v_hard_iface_flags { + /** + * @BATADV_FULL_DUPLEX: tells if the connection over this link is + * full-duplex + */ + BATADV_FULL_DUPLEX = BIT(0), + + /** + * @BATADV_WARNING_DEFAULT: tells whether we have warned the user that + * no throughput data is available for this interface and that default + * values are assumed. + */ + BATADV_WARNING_DEFAULT = BIT(1), +}; + +/** + * struct batadv_hard_iface_bat_v - per hard-interface B.A.T.M.A.N. 
V data + */ +struct batadv_hard_iface_bat_v { + /** @elp_interval: time interval between two ELP transmissions */ + atomic_t elp_interval; + + /** @elp_seqno: current ELP sequence number */ + atomic_t elp_seqno; + + /** @elp_skb: base skb containing the ELP message to send */ + struct sk_buff *elp_skb; + + /** @elp_wq: workqueue used to schedule ELP transmissions */ + struct delayed_work elp_wq; + + /** @aggr_wq: workqueue used to transmit queued OGM packets */ + struct delayed_work aggr_wq; + + /** @aggr_list: queue for to be aggregated OGM packets */ + struct sk_buff_head aggr_list; + + /** @aggr_len: size of the OGM aggregate (excluding ethernet header) */ + unsigned int aggr_len; + + /** + * @throughput_override: throughput override to disable link + * auto-detection + */ + atomic_t throughput_override; + + /** @flags: interface specific flags */ + u8 flags; +}; + +/** + * enum batadv_hard_iface_wifi_flags - Flags describing the wifi configuration + * of a batadv_hard_iface + */ +enum batadv_hard_iface_wifi_flags { + /** @BATADV_HARDIF_WIFI_WEXT_DIRECT: it is a wext wifi device */ + BATADV_HARDIF_WIFI_WEXT_DIRECT = BIT(0), + + /** @BATADV_HARDIF_WIFI_CFG80211_DIRECT: it is a cfg80211 wifi device */ + BATADV_HARDIF_WIFI_CFG80211_DIRECT = BIT(1), + + /** + * @BATADV_HARDIF_WIFI_WEXT_INDIRECT: link device is a wext wifi device + */ + BATADV_HARDIF_WIFI_WEXT_INDIRECT = BIT(2), + + /** + * @BATADV_HARDIF_WIFI_CFG80211_INDIRECT: link device is a cfg80211 wifi + * device + */ + BATADV_HARDIF_WIFI_CFG80211_INDIRECT = BIT(3), +}; + +/** + * struct batadv_hard_iface - network device known to batman-adv + */ +struct batadv_hard_iface { + /** @list: list node for batadv_hardif_list */ + struct list_head list; + + /** @if_status: status of the interface for batman-adv */ + char if_status; + + /** + * @num_bcasts: number of payload re-broadcasts on this interface (ARQ) + */ + u8 num_bcasts; + + /** + * @wifi_flags: flags whether this is (directly or indirectly) a wifi + * interface + */ + u32 wifi_flags; + + /** @net_dev: pointer to the net_device */ + struct net_device *net_dev; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** + * @batman_adv_ptype: packet type describing packets that should be + * processed by batman-adv for this interface + */ + struct packet_type batman_adv_ptype; + + /** + * @soft_iface: the batman-adv interface which uses this network + * interface + */ + struct net_device *soft_iface; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; + + /** + * @hop_penalty: penalty which will be applied to the tq-field + * of an OGM received via this interface + */ + atomic_t hop_penalty; + + /** @bat_iv: per hard-interface B.A.T.M.A.N. IV data */ + struct batadv_hard_iface_bat_iv bat_iv; + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + /** @bat_v: per hard-interface B.A.T.M.A.N. V data */ + struct batadv_hard_iface_bat_v bat_v; +#endif + + /** + * @neigh_list: list of unique single hop neighbors via this interface + */ + struct hlist_head neigh_list; + + /** @neigh_list_lock: lock protecting neigh_list */ + spinlock_t neigh_list_lock; +}; + +/** + * struct batadv_orig_ifinfo_bat_iv - B.A.T.M.A.N. 
IV private orig_ifinfo + * members + */ +struct batadv_orig_ifinfo_bat_iv { + /** + * @bcast_own: bitfield which counts the number of our OGMs this + * orig_node rebroadcasted "back" to us (relative to last_real_seqno) + */ + DECLARE_BITMAP(bcast_own, BATADV_TQ_LOCAL_WINDOW_SIZE); + + /** @bcast_own_sum: sum of bcast_own */ + u8 bcast_own_sum; +}; + +/** + * struct batadv_orig_ifinfo - originator info per outgoing interface + */ +struct batadv_orig_ifinfo { + /** @list: list node for &batadv_orig_node.ifinfo_list */ + struct hlist_node list; + + /** @if_outgoing: pointer to outgoing hard-interface */ + struct batadv_hard_iface *if_outgoing; + + /** @router: router that should be used to reach this originator */ + struct batadv_neigh_node __rcu *router; + + /** @last_real_seqno: last and best known sequence number */ + u32 last_real_seqno; + + /** @last_ttl: ttl of last received packet */ + u8 last_ttl; + + /** @last_seqno_forwarded: seqno of the OGM which was forwarded last */ + u32 last_seqno_forwarded; + + /** @batman_seqno_reset: time when the batman seqno window was reset */ + unsigned long batman_seqno_reset; + + /** @bat_iv: B.A.T.M.A.N. IV private structure */ + struct batadv_orig_ifinfo_bat_iv bat_iv; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_frag_table_entry - head in the fragment buffer table + */ +struct batadv_frag_table_entry { + /** @fragment_list: head of list with fragments */ + struct hlist_head fragment_list; + + /** @lock: lock to protect the list of fragments */ + spinlock_t lock; + + /** @timestamp: time (jiffie) of last received fragment */ + unsigned long timestamp; + + /** @seqno: sequence number of the fragments in the list */ + u16 seqno; + + /** @size: accumulated size of packets in list */ + u16 size; + + /** @total_size: expected size of the assembled packet */ + u16 total_size; +}; + +/** + * struct batadv_frag_list_entry - entry in a list of fragments + */ +struct batadv_frag_list_entry { + /** @list: list node information */ + struct hlist_node list; + + /** @skb: fragment */ + struct sk_buff *skb; + + /** @no: fragment number in the set */ + u8 no; +}; + +/** + * struct batadv_vlan_tt - VLAN specific TT attributes + */ +struct batadv_vlan_tt { + /** @crc: CRC32 checksum of the entries belonging to this vlan */ + u32 crc; + + /** @num_entries: number of TT entries for this VLAN */ + atomic_t num_entries; +}; + +/** + * struct batadv_orig_node_vlan - VLAN specific data per orig_node + */ +struct batadv_orig_node_vlan { + /** @vid: the VLAN identifier */ + unsigned short vid; + + /** @tt: VLAN specific TT attributes */ + struct batadv_vlan_tt tt; + + /** @list: list node for &batadv_orig_node.vlan_list */ + struct hlist_node list; + + /** + * @refcount: number of context where this object is currently in use + */ + struct kref refcount; + + /** @rcu: struct used for freeing in a RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_orig_bat_iv - B.A.T.M.A.N. 
IV private orig_node members + */ +struct batadv_orig_bat_iv { + /** + * @ogm_cnt_lock: lock protecting &batadv_orig_ifinfo_bat_iv.bcast_own, + * &batadv_orig_ifinfo_bat_iv.bcast_own_sum, + * &batadv_neigh_ifinfo_bat_iv.bat_iv.real_bits and + * &batadv_neigh_ifinfo_bat_iv.real_packet_count + */ + spinlock_t ogm_cnt_lock; +}; + +/** + * struct batadv_orig_node - structure for orig_list maintaining nodes of mesh + */ +struct batadv_orig_node { + /** @orig: originator ethernet address */ + u8 orig[ETH_ALEN]; + + /** @ifinfo_list: list for routers per outgoing interface */ + struct hlist_head ifinfo_list; + + /** + * @last_bonding_candidate: pointer to last ifinfo of last used router + */ + struct batadv_orig_ifinfo *last_bonding_candidate; + +#ifdef CONFIG_BATMAN_ADV_DAT + /** @dat_addr: address of the orig node in the distributed hash */ + batadv_dat_addr_t dat_addr; +#endif + + /** @last_seen: time when last packet from this node was received */ + unsigned long last_seen; + + /** + * @bcast_seqno_reset: time when the broadcast seqno window was reset + */ + unsigned long bcast_seqno_reset; + +#ifdef CONFIG_BATMAN_ADV_MCAST + /** + * @mcast_handler_lock: synchronizes mcast-capability and -flag changes + */ + spinlock_t mcast_handler_lock; + + /** @mcast_flags: multicast flags announced by the orig node */ + u8 mcast_flags; + + /** + * @mcast_want_all_unsnoopables_node: a list node for the + * mcast.want_all_unsnoopables list + */ + struct hlist_node mcast_want_all_unsnoopables_node; + + /** + * @mcast_want_all_ipv4_node: a list node for the mcast.want_all_ipv4 + * list + */ + struct hlist_node mcast_want_all_ipv4_node; + /** + * @mcast_want_all_ipv6_node: a list node for the mcast.want_all_ipv6 + * list + */ + struct hlist_node mcast_want_all_ipv6_node; + + /** + * @mcast_want_all_rtr4_node: a list node for the mcast.want_all_rtr4 + * list + */ + struct hlist_node mcast_want_all_rtr4_node; + /** + * @mcast_want_all_rtr6_node: a list node for the mcast.want_all_rtr6 + * list + */ + struct hlist_node mcast_want_all_rtr6_node; +#endif + + /** @capabilities: announced capabilities of this originator */ + unsigned long capabilities; + + /** + * @capa_initialized: bitfield to remember whether a capability was + * initialized + */ + unsigned long capa_initialized; + + /** @last_ttvn: last seen translation table version number */ + atomic_t last_ttvn; + + /** @tt_buff: last tt changeset this node received from the orig node */ + unsigned char *tt_buff; + + /** + * @tt_buff_len: length of the last tt changeset this node received + * from the orig node + */ + s16 tt_buff_len; + + /** @tt_buff_lock: lock that protects tt_buff and tt_buff_len */ + spinlock_t tt_buff_lock; + + /** + * @tt_lock: avoids concurrent read from and write to the table. Table + * update is made up of two operations (data structure update and + * metadata -CRC/TTVN-recalculation) and they have to be executed + * atomically in order to avoid another thread to read the + * table/metadata between those. 
+ */ + spinlock_t tt_lock; + + /** + * @bcast_bits: bitfield containing the info which payload broadcast + * originated from this orig node this host already has seen (relative + * to last_bcast_seqno) + */ + DECLARE_BITMAP(bcast_bits, BATADV_TQ_LOCAL_WINDOW_SIZE); + + /** + * @last_bcast_seqno: last broadcast sequence number received by this + * host + */ + u32 last_bcast_seqno; + + /** + * @neigh_list: list of potential next hop neighbor towards this orig + * node + */ + struct hlist_head neigh_list; + + /** + * @neigh_list_lock: lock protecting neigh_list, ifinfo_list, + * last_bonding_candidate and router + */ + spinlock_t neigh_list_lock; + + /** @hash_entry: hlist node for &batadv_priv.orig_hash */ + struct hlist_node hash_entry; + + /** @bat_priv: pointer to soft_iface this orig node belongs to */ + struct batadv_priv *bat_priv; + + /** @bcast_seqno_lock: lock protecting bcast_bits & last_bcast_seqno */ + spinlock_t bcast_seqno_lock; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; + +#ifdef CONFIG_BATMAN_ADV_NC + /** @in_coding_list: list of nodes this orig can hear */ + struct list_head in_coding_list; + + /** @out_coding_list: list of nodes that can hear this orig */ + struct list_head out_coding_list; + + /** @in_coding_list_lock: protects in_coding_list */ + spinlock_t in_coding_list_lock; + + /** @out_coding_list_lock: protects out_coding_list */ + spinlock_t out_coding_list_lock; +#endif + + /** @fragments: array with heads for fragment chains */ + struct batadv_frag_table_entry fragments[BATADV_FRAG_BUFFER_COUNT]; + + /** + * @vlan_list: a list of orig_node_vlan structs, one per VLAN served by + * the originator represented by this object + */ + struct hlist_head vlan_list; + + /** @vlan_list_lock: lock protecting vlan_list */ + spinlock_t vlan_list_lock; + + /** @bat_iv: B.A.T.M.A.N. IV private structure */ + struct batadv_orig_bat_iv bat_iv; +}; + +/** + * enum batadv_orig_capabilities - orig node capabilities + */ +enum batadv_orig_capabilities { + /** + * @BATADV_ORIG_CAPA_HAS_DAT: orig node has distributed arp table + * enabled + */ + BATADV_ORIG_CAPA_HAS_DAT, + + /** @BATADV_ORIG_CAPA_HAS_NC: orig node has network coding enabled */ + BATADV_ORIG_CAPA_HAS_NC, + + /** @BATADV_ORIG_CAPA_HAS_TT: orig node has tt capability */ + BATADV_ORIG_CAPA_HAS_TT, + + /** + * @BATADV_ORIG_CAPA_HAS_MCAST: orig node has some multicast capability + * (= orig node announces a tvlv of type BATADV_TVLV_MCAST) + */ + BATADV_ORIG_CAPA_HAS_MCAST, +}; + +/** + * struct batadv_gw_node - structure for orig nodes announcing gw capabilities + */ +struct batadv_gw_node { + /** @list: list node for &batadv_priv_gw.list */ + struct hlist_node list; + + /** @orig_node: pointer to corresponding orig node */ + struct batadv_orig_node *orig_node; + + /** @bandwidth_down: advertised uplink download bandwidth */ + u32 bandwidth_down; + + /** @bandwidth_up: advertised uplink upload bandwidth */ + u32 bandwidth_up; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +DECLARE_EWMA(throughput, 10, 8) + +/** + * struct batadv_hardif_neigh_node_bat_v - B.A.T.M.A.N. 
V private neighbor + * information + */ +struct batadv_hardif_neigh_node_bat_v { + /** @throughput: ewma link throughput towards this neighbor */ + struct ewma_throughput throughput; + + /** @elp_interval: time interval between two ELP transmissions */ + u32 elp_interval; + + /** @elp_latest_seqno: latest and best known ELP sequence number */ + u32 elp_latest_seqno; + + /** + * @last_unicast_tx: when the last unicast packet has been sent to this + * neighbor + */ + unsigned long last_unicast_tx; + + /** @metric_work: work queue callback item for metric update */ + struct work_struct metric_work; +}; + +/** + * struct batadv_hardif_neigh_node - unique neighbor per hard-interface + */ +struct batadv_hardif_neigh_node { + /** @list: list node for &batadv_hard_iface.neigh_list */ + struct hlist_node list; + + /** @addr: the MAC address of the neighboring interface */ + u8 addr[ETH_ALEN]; + + /** + * @orig: the address of the originator this neighbor node belongs to + */ + u8 orig[ETH_ALEN]; + + /** @if_incoming: pointer to incoming hard-interface */ + struct batadv_hard_iface *if_incoming; + + /** @last_seen: when last packet via this neighbor was received */ + unsigned long last_seen; + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + /** @bat_v: B.A.T.M.A.N. V private data */ + struct batadv_hardif_neigh_node_bat_v bat_v; +#endif + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in a RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_neigh_node - structure for single hops neighbors + */ +struct batadv_neigh_node { + /** @list: list node for &batadv_orig_node.neigh_list */ + struct hlist_node list; + + /** @orig_node: pointer to corresponding orig_node */ + struct batadv_orig_node *orig_node; + + /** @addr: the MAC address of the neighboring interface */ + u8 addr[ETH_ALEN]; + + /** @ifinfo_list: list for routing metrics per outgoing interface */ + struct hlist_head ifinfo_list; + + /** @ifinfo_lock: lock protecting ifinfo_list and its members */ + spinlock_t ifinfo_lock; + + /** @if_incoming: pointer to incoming hard-interface */ + struct batadv_hard_iface *if_incoming; + + /** @last_seen: when last packet via this neighbor was received */ + unsigned long last_seen; + + /** @hardif_neigh: hardif_neigh of this neighbor */ + struct batadv_hardif_neigh_node *hardif_neigh; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_neigh_ifinfo_bat_iv - neighbor information per outgoing + * interface for B.A.T.M.A.N. IV + */ +struct batadv_neigh_ifinfo_bat_iv { + /** @tq_recv: ring buffer of received TQ values from this neigh node */ + u8 tq_recv[BATADV_TQ_GLOBAL_WINDOW_SIZE]; + + /** @tq_index: ring buffer index */ + u8 tq_index; + + /** + * @tq_avg: averaged tq of all tq values in the ring buffer (tq_recv) + */ + u8 tq_avg; + + /** + * @real_bits: bitfield containing the number of OGMs received from this + * neigh node (relative to orig_node->last_real_seqno) + */ + DECLARE_BITMAP(real_bits, BATADV_TQ_LOCAL_WINDOW_SIZE); + + /** @real_packet_count: counted result of real_bits */ + u8 real_packet_count; +}; + +/** + * struct batadv_neigh_ifinfo_bat_v - neighbor information per outgoing + * interface for B.A.T.M.A.N. 
V + */ +struct batadv_neigh_ifinfo_bat_v { + /** + * @throughput: last throughput metric received from originator via this + * neigh + */ + u32 throughput; + + /** @last_seqno: last sequence number known for this neighbor */ + u32 last_seqno; +}; + +/** + * struct batadv_neigh_ifinfo - neighbor information per outgoing interface + */ +struct batadv_neigh_ifinfo { + /** @list: list node for &batadv_neigh_node.ifinfo_list */ + struct hlist_node list; + + /** @if_outgoing: pointer to outgoing hard-interface */ + struct batadv_hard_iface *if_outgoing; + + /** @bat_iv: B.A.T.M.A.N. IV private structure */ + struct batadv_neigh_ifinfo_bat_iv bat_iv; + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + /** @bat_v: B.A.T.M.A.N. V private data */ + struct batadv_neigh_ifinfo_bat_v bat_v; +#endif + + /** @last_ttl: last received ttl from this neigh node */ + u8 last_ttl; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in a RCU-safe manner */ + struct rcu_head rcu; +}; + +#ifdef CONFIG_BATMAN_ADV_BLA + +/** + * struct batadv_bcast_duplist_entry - structure for LAN broadcast suppression + */ +struct batadv_bcast_duplist_entry { + /** @orig: mac address of orig node originating the broadcast */ + u8 orig[ETH_ALEN]; + + /** @crc: crc32 checksum of broadcast payload */ + __be32 crc; + + /** @entrytime: time when the broadcast packet was received */ + unsigned long entrytime; +}; +#endif + +/** + * enum batadv_counters - indices for traffic counters + */ +enum batadv_counters { + /** @BATADV_CNT_TX: transmitted payload traffic packet counter */ + BATADV_CNT_TX, + + /** @BATADV_CNT_TX_BYTES: transmitted payload traffic bytes counter */ + BATADV_CNT_TX_BYTES, + + /** + * @BATADV_CNT_TX_DROPPED: dropped transmission payload traffic packet + * counter + */ + BATADV_CNT_TX_DROPPED, + + /** @BATADV_CNT_RX: received payload traffic packet counter */ + BATADV_CNT_RX, + + /** @BATADV_CNT_RX_BYTES: received payload traffic bytes counter */ + BATADV_CNT_RX_BYTES, + + /** @BATADV_CNT_FORWARD: forwarded payload traffic packet counter */ + BATADV_CNT_FORWARD, + + /** + * @BATADV_CNT_FORWARD_BYTES: forwarded payload traffic bytes counter + */ + BATADV_CNT_FORWARD_BYTES, + + /** + * @BATADV_CNT_MGMT_TX: transmitted routing protocol traffic packet + * counter + */ + BATADV_CNT_MGMT_TX, + + /** + * @BATADV_CNT_MGMT_TX_BYTES: transmitted routing protocol traffic bytes + * counter + */ + BATADV_CNT_MGMT_TX_BYTES, + + /** + * @BATADV_CNT_MGMT_RX: received routing protocol traffic packet counter + */ + BATADV_CNT_MGMT_RX, + + /** + * @BATADV_CNT_MGMT_RX_BYTES: received routing protocol traffic bytes + * counter + */ + BATADV_CNT_MGMT_RX_BYTES, + + /** @BATADV_CNT_FRAG_TX: transmitted fragment traffic packet counter */ + BATADV_CNT_FRAG_TX, + + /** + * @BATADV_CNT_FRAG_TX_BYTES: transmitted fragment traffic bytes counter + */ + BATADV_CNT_FRAG_TX_BYTES, + + /** @BATADV_CNT_FRAG_RX: received fragment traffic packet counter */ + BATADV_CNT_FRAG_RX, + + /** + * @BATADV_CNT_FRAG_RX_BYTES: received fragment traffic bytes counter + */ + BATADV_CNT_FRAG_RX_BYTES, + + /** @BATADV_CNT_FRAG_FWD: forwarded fragment traffic packet counter */ + BATADV_CNT_FRAG_FWD, + + /** + * @BATADV_CNT_FRAG_FWD_BYTES: forwarded fragment traffic bytes counter + */ + BATADV_CNT_FRAG_FWD_BYTES, + + /** + * @BATADV_CNT_TT_REQUEST_TX: transmitted tt req traffic packet counter + */ + BATADV_CNT_TT_REQUEST_TX, + + /** @BATADV_CNT_TT_REQUEST_RX: received tt req traffic packet counter */ + 
BATADV_CNT_TT_REQUEST_RX, + + /** + * @BATADV_CNT_TT_RESPONSE_TX: transmitted tt resp traffic packet + * counter + */ + BATADV_CNT_TT_RESPONSE_TX, + + /** + * @BATADV_CNT_TT_RESPONSE_RX: received tt resp traffic packet counter + */ + BATADV_CNT_TT_RESPONSE_RX, + + /** + * @BATADV_CNT_TT_ROAM_ADV_TX: transmitted tt roam traffic packet + * counter + */ + BATADV_CNT_TT_ROAM_ADV_TX, + + /** + * @BATADV_CNT_TT_ROAM_ADV_RX: received tt roam traffic packet counter + */ + BATADV_CNT_TT_ROAM_ADV_RX, + +#ifdef CONFIG_BATMAN_ADV_DAT + /** + * @BATADV_CNT_DAT_GET_TX: transmitted dht GET traffic packet counter + */ + BATADV_CNT_DAT_GET_TX, + + /** @BATADV_CNT_DAT_GET_RX: received dht GET traffic packet counter */ + BATADV_CNT_DAT_GET_RX, + + /** + * @BATADV_CNT_DAT_PUT_TX: transmitted dht PUT traffic packet counter + */ + BATADV_CNT_DAT_PUT_TX, + + /** @BATADV_CNT_DAT_PUT_RX: received dht PUT traffic packet counter */ + BATADV_CNT_DAT_PUT_RX, + + /** + * @BATADV_CNT_DAT_CACHED_REPLY_TX: transmitted dat cache reply traffic + * packet counter + */ + BATADV_CNT_DAT_CACHED_REPLY_TX, +#endif + +#ifdef CONFIG_BATMAN_ADV_NC + /** + * @BATADV_CNT_NC_CODE: transmitted nc-combined traffic packet counter + */ + BATADV_CNT_NC_CODE, + + /** + * @BATADV_CNT_NC_CODE_BYTES: transmitted nc-combined traffic bytes + * counter + */ + BATADV_CNT_NC_CODE_BYTES, + + /** + * @BATADV_CNT_NC_RECODE: transmitted nc-recombined traffic packet + * counter + */ + BATADV_CNT_NC_RECODE, + + /** + * @BATADV_CNT_NC_RECODE_BYTES: transmitted nc-recombined traffic bytes + * counter + */ + BATADV_CNT_NC_RECODE_BYTES, + + /** + * @BATADV_CNT_NC_BUFFER: counter for packets buffered for later nc + * decoding + */ + BATADV_CNT_NC_BUFFER, + + /** + * @BATADV_CNT_NC_DECODE: received and nc-decoded traffic packet counter + */ + BATADV_CNT_NC_DECODE, + + /** + * @BATADV_CNT_NC_DECODE_BYTES: received and nc-decoded traffic bytes + * counter + */ + BATADV_CNT_NC_DECODE_BYTES, + + /** + * @BATADV_CNT_NC_DECODE_FAILED: received and decode-failed traffic + * packet counter + */ + BATADV_CNT_NC_DECODE_FAILED, + + /** + * @BATADV_CNT_NC_SNIFFED: counter for nc-decoded packets received in + * promisc mode. 
+ */ + BATADV_CNT_NC_SNIFFED, +#endif + + /** @BATADV_CNT_NUM: number of traffic counters */ + BATADV_CNT_NUM, +}; + +/** + * struct batadv_priv_tt - per mesh interface translation table data + */ +struct batadv_priv_tt { + /** @vn: translation table version number */ + atomic_t vn; + + /** + * @ogm_append_cnt: counter of number of OGMs containing the local tt + * diff + */ + atomic_t ogm_append_cnt; + + /** @local_changes: changes registered in an originator interval */ + atomic_t local_changes; + + /** + * @changes_list: tracks tt local changes within an originator interval + */ + struct list_head changes_list; + + /** @local_hash: local translation table hash table */ + struct batadv_hashtable *local_hash; + + /** @global_hash: global translation table hash table */ + struct batadv_hashtable *global_hash; + + /** @req_list: list of pending & unanswered tt_requests */ + struct hlist_head req_list; + + /** + * @roam_list: list of the last roaming events of each client limiting + * the number of roaming events to avoid route flapping + */ + struct list_head roam_list; + + /** @changes_list_lock: lock protecting changes_list */ + spinlock_t changes_list_lock; + + /** @req_list_lock: lock protecting req_list */ + spinlock_t req_list_lock; + + /** @roam_list_lock: lock protecting roam_list */ + spinlock_t roam_list_lock; + + /** @last_changeset: last tt changeset this host has generated */ + unsigned char *last_changeset; + + /** + * @last_changeset_len: length of last tt changeset this host has + * generated + */ + s16 last_changeset_len; + + /** + * @last_changeset_lock: lock protecting last_changeset & + * last_changeset_len + */ + spinlock_t last_changeset_lock; + + /** + * @commit_lock: prevents from executing a local TT commit while reading + * the local table. The local TT commit is made up of two operations + * (data structure update and metadata -CRC/TTVN- recalculation) and + * they have to be executed atomically in order to avoid another thread + * to read the table/metadata between those. + */ + spinlock_t commit_lock; + + /** @work: work queue callback item for translation table purging */ + struct delayed_work work; +}; + +#ifdef CONFIG_BATMAN_ADV_BLA + +/** + * struct batadv_priv_bla - per mesh interface bridge loop avoidance data + */ +struct batadv_priv_bla { + /** @num_requests: number of bla requests in flight */ + atomic_t num_requests; + + /** + * @claim_hash: hash table containing mesh nodes this host has claimed + */ + struct batadv_hashtable *claim_hash; + + /** + * @backbone_hash: hash table containing all detected backbone gateways + */ + struct batadv_hashtable *backbone_hash; + + /** @loopdetect_addr: MAC address used for own loopdetection frames */ + u8 loopdetect_addr[ETH_ALEN]; + + /** + * @loopdetect_lasttime: time when the loopdetection frames were sent + */ + unsigned long loopdetect_lasttime; + + /** + * @loopdetect_next: how many periods to wait for the next loopdetect + * process + */ + atomic_t loopdetect_next; + + /** + * @bcast_duplist: recently received broadcast packets array (for + * broadcast duplicate suppression) + */ + struct batadv_bcast_duplist_entry bcast_duplist[BATADV_DUPLIST_SIZE]; + + /** + * @bcast_duplist_curr: index of last broadcast packet added to + * bcast_duplist + */ + int bcast_duplist_curr; + + /** + * @bcast_duplist_lock: lock protecting bcast_duplist & + * bcast_duplist_curr + */ + spinlock_t bcast_duplist_lock; + + /** @claim_dest: local claim data (e.g. 
claim group) */ + struct batadv_bla_claim_dst claim_dest; + + /** @work: work queue callback item for cleanups & bla announcements */ + struct delayed_work work; +}; +#endif + +#ifdef CONFIG_BATMAN_ADV_DEBUG + +/** + * struct batadv_priv_debug_log - debug logging data + */ +struct batadv_priv_debug_log { + /** @log_buff: buffer holding the logs (ring buffer) */ + char log_buff[BATADV_LOG_BUF_LEN]; + + /** @log_start: index of next character to read */ + unsigned long log_start; + + /** @log_end: index of next character to write */ + unsigned long log_end; + + /** @lock: lock protecting log_buff, log_start & log_end */ + spinlock_t lock; + + /** @queue_wait: log reader's wait queue */ + wait_queue_head_t queue_wait; +}; +#endif + +/** + * struct batadv_priv_gw - per mesh interface gateway data + */ +struct batadv_priv_gw { + /** @gateway_list: list of available gateway nodes */ + struct hlist_head gateway_list; + + /** @list_lock: lock protecting gateway_list, curr_gw, generation */ + spinlock_t list_lock; + + /** @curr_gw: pointer to currently selected gateway node */ + struct batadv_gw_node __rcu *curr_gw; + + /** @generation: current (generation) sequence number */ + unsigned int generation; + + /** + * @mode: gateway operation: off, client or server (see batadv_gw_modes) + */ + atomic_t mode; + + /** @sel_class: gateway selection class (applies if gw_mode client) */ + atomic_t sel_class; + + /** + * @bandwidth_down: advertised uplink download bandwidth (if gw_mode + * server) + */ + atomic_t bandwidth_down; + + /** + * @bandwidth_up: advertised uplink upload bandwidth (if gw_mode server) + */ + atomic_t bandwidth_up; + + /** @reselect: bool indicating a gateway re-selection is in progress */ + atomic_t reselect; +}; + +/** + * struct batadv_priv_tvlv - per mesh interface tvlv data + */ +struct batadv_priv_tvlv { + /** + * @container_list: list of registered tvlv containers to be sent with + * each OGM + */ + struct hlist_head container_list; + + /** @handler_list: list of the various tvlv content handlers */ + struct hlist_head handler_list; + + /** @container_list_lock: protects tvlv container list access */ + spinlock_t container_list_lock; + + /** @handler_list_lock: protects handler list access */ + spinlock_t handler_list_lock; +}; + +#ifdef CONFIG_BATMAN_ADV_DAT + +/** + * struct batadv_priv_dat - per mesh interface DAT private data + */ +struct batadv_priv_dat { + /** @addr: node DAT address */ + batadv_dat_addr_t addr; + + /** @hash: hashtable representing the local ARP cache */ + struct batadv_hashtable *hash; + + /** @work: work queue callback item for cache purging */ + struct delayed_work work; +}; +#endif + +#ifdef CONFIG_BATMAN_ADV_MCAST +/** + * struct batadv_mcast_querier_state - IGMP/MLD querier state when bridged + */ +struct batadv_mcast_querier_state { + /** @exists: whether a querier exists in the mesh */ + unsigned char exists:1; + + /** + * @shadowing: if a querier exists, whether it is potentially shadowing + * multicast listeners (i.e. 
querier is behind our own bridge segment) + */ + unsigned char shadowing:1; +}; + +/** + * struct batadv_mcast_mla_flags - flags for the querier, bridge and tvlv state + */ +struct batadv_mcast_mla_flags { + /** @querier_ipv4: the current state of an IGMP querier in the mesh */ + struct batadv_mcast_querier_state querier_ipv4; + + /** @querier_ipv6: the current state of an MLD querier in the mesh */ + struct batadv_mcast_querier_state querier_ipv6; + + /** @enabled: whether the multicast tvlv is currently enabled */ + unsigned char enabled:1; + + /** @bridged: whether the soft interface has a bridge on top */ + unsigned char bridged:1; + + /** @tvlv_flags: the flags we have last sent in our mcast tvlv */ + u8 tvlv_flags; +}; + +/** + * struct batadv_priv_mcast - per mesh interface mcast data + */ +struct batadv_priv_mcast { + /** + * @mla_list: list of multicast addresses we are currently announcing + * via TT + */ + struct hlist_head mla_list; /* see __batadv_mcast_mla_update() */ + + /** + * @want_all_unsnoopables_list: a list of orig_nodes wanting all + * unsnoopable multicast traffic + */ + struct hlist_head want_all_unsnoopables_list; + + /** + * @want_all_ipv4_list: a list of orig_nodes wanting all IPv4 multicast + * traffic + */ + struct hlist_head want_all_ipv4_list; + + /** + * @want_all_ipv6_list: a list of orig_nodes wanting all IPv6 multicast + * traffic + */ + struct hlist_head want_all_ipv6_list; + + /** + * @want_all_rtr4_list: a list of orig_nodes wanting all routable IPv4 + * multicast traffic + */ + struct hlist_head want_all_rtr4_list; + + /** + * @want_all_rtr6_list: a list of orig_nodes wanting all routable IPv6 + * multicast traffic + */ + struct hlist_head want_all_rtr6_list; + + /** + * @mla_flags: flags for the querier, bridge and tvlv state + */ + struct batadv_mcast_mla_flags mla_flags; + + /** + * @mla_lock: a lock protecting mla_list and mla_flags + */ + spinlock_t mla_lock; + + /** + * @num_want_all_unsnoopables: number of nodes wanting unsnoopable IP + * traffic + */ + atomic_t num_want_all_unsnoopables; + + /** @num_want_all_ipv4: counter for items in want_all_ipv4_list */ + atomic_t num_want_all_ipv4; + + /** @num_want_all_ipv6: counter for items in want_all_ipv6_list */ + atomic_t num_want_all_ipv6; + + /** @num_want_all_rtr4: counter for items in want_all_rtr4_list */ + atomic_t num_want_all_rtr4; + + /** @num_want_all_rtr6: counter for items in want_all_rtr6_list */ + atomic_t num_want_all_rtr6; + + /** + * @want_lists_lock: lock for protecting modifications to mcasts + * want_all_{unsnoopables,ipv4,ipv6}_list (traversals are rcu-locked) + */ + spinlock_t want_lists_lock; + + /** @work: work queue callback item for multicast TT and TVLV updates */ + struct delayed_work work; +}; +#endif + +/** + * struct batadv_priv_nc - per mesh interface network coding private data + */ +struct batadv_priv_nc { + /** @work: work queue callback item for cleanup */ + struct delayed_work work; + + /** + * @min_tq: only consider neighbors for encoding if neigh_tq > min_tq + */ + u8 min_tq; + + /** + * @max_fwd_delay: maximum packet forward delay to allow coding of + * packets + */ + u32 max_fwd_delay; + + /** + * @max_buffer_time: buffer time for sniffed packets used to decoding + */ + u32 max_buffer_time; + + /** + * @timestamp_fwd_flush: timestamp of last forward packet queue flush + */ + unsigned long timestamp_fwd_flush; + + /** + * @timestamp_sniffed_purge: timestamp of last sniffed packet queue + * purge + */ + unsigned long timestamp_sniffed_purge; + + /** + * 
@coding_hash: Hash table used to buffer skbs while waiting for + * another incoming skb to code it with. Skbs are added to the buffer + * just before being forwarded in routing.c + */ + struct batadv_hashtable *coding_hash; + + /** + * @decoding_hash: Hash table used to buffer skbs that might be needed + * to decode a received coded skb. The buffer is used for 1) skbs + * arriving on the soft-interface; 2) skbs overheard on the + * hard-interface; and 3) skbs forwarded by batman-adv. + */ + struct batadv_hashtable *decoding_hash; +}; + +/** + * struct batadv_tp_unacked - unacked packet meta-information + * + * This struct is supposed to represent a buffer unacked packet. However, since + * the purpose of the TP meter is to count the traffic only, there is no need to + * store the entire sk_buff, the starting offset and the length are enough + */ +struct batadv_tp_unacked { + /** @seqno: seqno of the unacked packet */ + u32 seqno; + + /** @len: length of the packet */ + u16 len; + + /** @list: list node for &batadv_tp_vars.unacked_list */ + struct list_head list; +}; + +/** + * enum batadv_tp_meter_role - Modus in tp meter session + */ +enum batadv_tp_meter_role { + /** @BATADV_TP_RECEIVER: Initialized as receiver */ + BATADV_TP_RECEIVER, + + /** @BATADV_TP_SENDER: Initialized as sender */ + BATADV_TP_SENDER +}; + +/** + * struct batadv_tp_vars - tp meter private variables per session + */ +struct batadv_tp_vars { + /** @list: list node for &bat_priv.tp_list */ + struct hlist_node list; + + /** @timer: timer for ack (receiver) and retry (sender) */ + struct timer_list timer; + + /** @bat_priv: pointer to the mesh object */ + struct batadv_priv *bat_priv; + + /** @start_time: start time in jiffies */ + unsigned long start_time; + + /** @other_end: mac address of remote */ + u8 other_end[ETH_ALEN]; + + /** @role: receiver/sender modi */ + enum batadv_tp_meter_role role; + + /** @sending: sending binary semaphore: 1 if sending, 0 is not */ + atomic_t sending; + + /** @reason: reason for a stopped session */ + enum batadv_tp_meter_reason reason; + + /** @finish_work: work item for the finishing procedure */ + struct delayed_work finish_work; + + /** @test_length: test length in milliseconds */ + u32 test_length; + + /** @session: TP session identifier */ + u8 session[2]; + + /** @icmp_uid: local ICMP "socket" index */ + u8 icmp_uid; + + /* sender variables */ + + /** @dec_cwnd: decimal part of the cwnd used during linear growth */ + u16 dec_cwnd; + + /** @cwnd: current size of the congestion window */ + u32 cwnd; + + /** @cwnd_lock: lock do protect @cwnd & @dec_cwnd */ + spinlock_t cwnd_lock; + + /** + * @ss_threshold: Slow Start threshold. 
Once cwnd exceeds this value the + * connection switches to the Congestion Avoidance state + */ + u32 ss_threshold; + + /** @last_acked: last acked byte */ + atomic_t last_acked; + + /** @last_sent: last sent byte, not yet acked */ + u32 last_sent; + + /** @tot_sent: amount of data sent/ACKed so far */ + atomic64_t tot_sent; + + /** @dup_acks: duplicate ACKs counter */ + atomic_t dup_acks; + + /** @fast_recovery: true if in Fast Recovery mode */ + unsigned char fast_recovery:1; + + /** @recover: last sent seqno when entering Fast Recovery */ + u32 recover; + + /** @rto: sender timeout */ + u32 rto; + + /** @srtt: smoothed RTT scaled by 2^3 */ + u32 srtt; + + /** @rttvar: RTT variation scaled by 2^2 */ + u32 rttvar; + + /** + * @more_bytes: waiting queue anchor when waiting for more ack/retry + * timeout + */ + wait_queue_head_t more_bytes; + + /** @prerandom_offset: offset inside the prerandom buffer */ + u32 prerandom_offset; + + /** @prerandom_lock: spinlock protecting access to prerandom_offset */ + spinlock_t prerandom_lock; + + /* receiver variables */ + + /** @last_recv: last in-order received packet */ + u32 last_recv; + + /** @unacked_list: list of unacked packets (meta-info only) */ + struct list_head unacked_list; + + /** @unacked_lock: protect unacked_list */ + spinlock_t unacked_lock; + + /** @last_recv_time: time (jiffies) a msg was received */ + unsigned long last_recv_time; + + /** @refcount: number of context where the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_softif_vlan - per VLAN attributes set + */ +struct batadv_softif_vlan { + /** @bat_priv: pointer to the mesh object */ + struct batadv_priv *bat_priv; + + /** @vid: VLAN identifier */ + unsigned short vid; + + /** @ap_isolation: AP isolation state */ + atomic_t ap_isolation; /* boolean */ + + /** @tt: TT private attributes (VLAN specific) */ + struct batadv_vlan_tt tt; + + /** @list: list node for &bat_priv.softif_vlan_list */ + struct hlist_node list; + + /** + * @refcount: number of context where this object is currently in use + */ + struct kref refcount; + + /** @rcu: struct used for freeing in a RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_priv_bat_v - B.A.T.M.A.N. 
V per soft-interface private data + */ +struct batadv_priv_bat_v { + /** @ogm_buff: buffer holding the OGM packet */ + unsigned char *ogm_buff; + + /** @ogm_buff_len: length of the OGM packet buffer */ + int ogm_buff_len; + + /** @ogm_seqno: OGM sequence number - used to identify each OGM */ + atomic_t ogm_seqno; + + /** @ogm_buff_mutex: lock protecting ogm_buff and ogm_buff_len */ + struct mutex ogm_buff_mutex; + + /** @ogm_wq: workqueue used to schedule OGM transmissions */ + struct delayed_work ogm_wq; +}; + +/** + * struct batadv_priv - per mesh interface data + */ +struct batadv_priv { + /** + * @mesh_state: current status of the mesh + * (inactive/active/deactivating) + */ + atomic_t mesh_state; + + /** @soft_iface: net device which holds this struct as private data */ + struct net_device *soft_iface; + + /** + * @mtu_set_by_user: MTU was set once by user + * protected by rtnl_lock + */ + int mtu_set_by_user; + + /** + * @bat_counters: mesh internal traffic statistic counters (see + * batadv_counters) + */ + u64 __percpu *bat_counters; /* Per cpu counters */ + + /** + * @aggregated_ogms: bool indicating whether OGM aggregation is enabled + */ + atomic_t aggregated_ogms; + + /** @bonding: bool indicating whether traffic bonding is enabled */ + atomic_t bonding; + + /** + * @fragmentation: bool indicating whether traffic fragmentation is + * enabled + */ + atomic_t fragmentation; + + /** + * @packet_size_max: max packet size that can be transmitted via + * multiple fragmented skbs or a single frame if fragmentation is + * disabled + */ + atomic_t packet_size_max; + + /** + * @frag_seqno: incremental counter to identify chains of egress + * fragments + */ + atomic_t frag_seqno; + +#ifdef CONFIG_BATMAN_ADV_BLA + /** + * @bridge_loop_avoidance: bool indicating whether bridge loop + * avoidance is enabled + */ + atomic_t bridge_loop_avoidance; +#endif + +#ifdef CONFIG_BATMAN_ADV_DAT + /** + * @distributed_arp_table: bool indicating whether distributed ARP table + * is enabled + */ + atomic_t distributed_arp_table; +#endif + +#ifdef CONFIG_BATMAN_ADV_MCAST + /** + * @multicast_mode: Enable or disable multicast optimizations on this + * node's sender/originating side + */ + atomic_t multicast_mode; + + /** + * @multicast_fanout: Maximum number of packet copies to generate for a + * multicast-to-unicast conversion + */ + atomic_t multicast_fanout; +#endif + + /** @orig_interval: OGM broadcast interval in milliseconds */ + atomic_t orig_interval; + + /** + * @hop_penalty: penalty which will be applied to an OGM's tq-field on + * every hop + */ + atomic_t hop_penalty; + +#ifdef CONFIG_BATMAN_ADV_DEBUG + /** @log_level: configured log level (see batadv_dbg_level) */ + atomic_t log_level; +#endif + + /** + * @isolation_mark: the skb->mark value used to match packets for AP + * isolation + */ + u32 isolation_mark; + + /** + * @isolation_mark_mask: bitmask identifying the bits in skb->mark to be + * used for the isolation mark + */ + u32 isolation_mark_mask; + + /** @bcast_seqno: last sent broadcast packet sequence number */ + atomic_t bcast_seqno; + + /** + * @bcast_queue_left: number of remaining buffered broadcast packet + * slots + */ + atomic_t bcast_queue_left; + + /** @batman_queue_left: number of remaining OGM packet slots */ + atomic_t batman_queue_left; + + /** @forw_bat_list: list of aggregated OGMs that will be forwarded */ + struct hlist_head forw_bat_list; + + /** + * @forw_bcast_list: list of broadcast packets that will be + * rebroadcasted + */ + struct hlist_head forw_bcast_list; + 
+ /** @tp_list: list of tp sessions */ + struct hlist_head tp_list; + + /** @orig_hash: hash table containing mesh participants (orig nodes) */ + struct batadv_hashtable *orig_hash; + + /** @forw_bat_list_lock: lock protecting forw_bat_list */ + spinlock_t forw_bat_list_lock; + + /** @forw_bcast_list_lock: lock protecting forw_bcast_list */ + spinlock_t forw_bcast_list_lock; + + /** @tp_list_lock: spinlock protecting @tp_list */ + spinlock_t tp_list_lock; + + /** @tp_num: number of currently active tp sessions */ + atomic_t tp_num; + + /** @orig_work: work queue callback item for orig node purging */ + struct delayed_work orig_work; + + /** + * @primary_if: one of the hard-interfaces assigned to this mesh + * interface becomes the primary interface + */ + struct batadv_hard_iface __rcu *primary_if; /* rcu protected pointer */ + + /** @algo_ops: routing algorithm used by this mesh interface */ + struct batadv_algo_ops *algo_ops; + + /** + * @softif_vlan_list: a list of softif_vlan structs, one per VLAN + * created on top of the mesh interface represented by this object + */ + struct hlist_head softif_vlan_list; + + /** @softif_vlan_list_lock: lock protecting softif_vlan_list */ + spinlock_t softif_vlan_list_lock; + +#ifdef CONFIG_BATMAN_ADV_BLA + /** @bla: bridge loop avoidance data */ + struct batadv_priv_bla bla; +#endif + +#ifdef CONFIG_BATMAN_ADV_DEBUG + /** @debug_log: holding debug logging relevant data */ + struct batadv_priv_debug_log *debug_log; +#endif + + /** @gw: gateway data */ + struct batadv_priv_gw gw; + + /** @tt: translation table data */ + struct batadv_priv_tt tt; + + /** @tvlv: type-version-length-value data */ + struct batadv_priv_tvlv tvlv; + +#ifdef CONFIG_BATMAN_ADV_DAT + /** @dat: distributed arp table data */ + struct batadv_priv_dat dat; +#endif + +#ifdef CONFIG_BATMAN_ADV_MCAST + /** @mcast: multicast data */ + struct batadv_priv_mcast mcast; +#endif + +#ifdef CONFIG_BATMAN_ADV_NC + /** + * @network_coding: bool indicating whether network coding is enabled + */ + atomic_t network_coding; + + /** @nc: network coding data */ + struct batadv_priv_nc nc; +#endif /* CONFIG_BATMAN_ADV_NC */ + +#ifdef CONFIG_BATMAN_ADV_BATMAN_V + /** @bat_v: B.A.T.M.A.N. 
V per soft-interface private data */ + struct batadv_priv_bat_v bat_v; +#endif +}; + +#ifdef CONFIG_BATMAN_ADV_BLA + +/** + * struct batadv_bla_backbone_gw - batman-adv gateway bridged into the LAN + */ +struct batadv_bla_backbone_gw { + /** + * @orig: originator address of backbone node (mac address of primary + * iface) + */ + u8 orig[ETH_ALEN]; + + /** @vid: vlan id this gateway was detected on */ + unsigned short vid; + + /** @hash_entry: hlist node for &batadv_priv_bla.backbone_hash */ + struct hlist_node hash_entry; + + /** @bat_priv: pointer to soft_iface this backbone gateway belongs to */ + struct batadv_priv *bat_priv; + + /** @lasttime: last time we heard of this backbone gw */ + unsigned long lasttime; + + /** + * @wait_periods: grace time for bridge forward delays and bla group + * forming at bootup phase - no bcast traffic is formwared until it has + * elapsed + */ + atomic_t wait_periods; + + /** + * @request_sent: if this bool is set to true we are out of sync with + * this backbone gateway - no bcast traffic is formwared until the + * situation was resolved + */ + atomic_t request_sent; + + /** @crc: crc16 checksum over all claims */ + u16 crc; + + /** @crc_lock: lock protecting crc */ + spinlock_t crc_lock; + + /** @report_work: work struct for reporting detected loops */ + struct work_struct report_work; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_bla_claim - claimed non-mesh client structure + */ +struct batadv_bla_claim { + /** @addr: mac address of claimed non-mesh client */ + u8 addr[ETH_ALEN]; + + /** @vid: vlan id this client was detected on */ + unsigned short vid; + + /** @backbone_gw: pointer to backbone gw claiming this client */ + struct batadv_bla_backbone_gw *backbone_gw; + + /** @backbone_lock: lock protecting backbone_gw pointer */ + spinlock_t backbone_lock; + + /** @lasttime: last time we heard of claim (locals only) */ + unsigned long lasttime; + + /** @hash_entry: hlist node for &batadv_priv_bla.claim_hash */ + struct hlist_node hash_entry; + + /** @refcount: number of contexts the object is used */ + struct rcu_head rcu; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct kref refcount; +}; +#endif + +/** + * struct batadv_tt_common_entry - tt local & tt global common data + */ +struct batadv_tt_common_entry { + /** @addr: mac address of non-mesh client */ + u8 addr[ETH_ALEN]; + + /** @vid: VLAN identifier */ + unsigned short vid; + + /** + * @hash_entry: hlist node for &batadv_priv_tt.local_hash or for + * &batadv_priv_tt.global_hash + */ + struct hlist_node hash_entry; + + /** @flags: various state handling flags (see batadv_tt_client_flags) */ + u16 flags; + + /** @added_at: timestamp used for purging stale tt common entries */ + unsigned long added_at; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_tt_local_entry - translation table local entry data + */ +struct batadv_tt_local_entry { + /** @common: general translation table data */ + struct batadv_tt_common_entry common; + + /** @last_seen: timestamp used for purging stale tt local entries */ + unsigned long last_seen; + + /** @vlan: soft-interface vlan of the entry */ + struct batadv_softif_vlan *vlan; +}; + +/** + * struct batadv_tt_global_entry - translation table global 
entry data + */ +struct batadv_tt_global_entry { + /** @common: general translation table data */ + struct batadv_tt_common_entry common; + + /** @orig_list: list of orig nodes announcing this non-mesh client */ + struct hlist_head orig_list; + + /** @orig_list_count: number of items in the orig_list */ + atomic_t orig_list_count; + + /** @list_lock: lock protecting orig_list */ + spinlock_t list_lock; + + /** @roam_at: time at which TT_GLOBAL_ROAM was set */ + unsigned long roam_at; +}; + +/** + * struct batadv_tt_orig_list_entry - orig node announcing a non-mesh client + */ +struct batadv_tt_orig_list_entry { + /** @orig_node: pointer to orig node announcing this non-mesh client */ + struct batadv_orig_node *orig_node; + + /** + * @ttvn: translation table version number which added the non-mesh + * client + */ + u8 ttvn; + + /** @flags: per orig entry TT sync flags */ + u8 flags; + + /** @list: list node for &batadv_tt_global_entry.orig_list */ + struct hlist_node list; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_tt_change_node - structure for tt changes occurred + */ +struct batadv_tt_change_node { + /** @list: list node for &batadv_priv_tt.changes_list */ + struct list_head list; + + /** @change: holds the actual translation table diff data */ + struct batadv_tvlv_tt_change change; +}; + +/** + * struct batadv_tt_req_node - data to keep track of the tt requests in flight + */ +struct batadv_tt_req_node { + /** + * @addr: mac address of the originator this request was sent to + */ + u8 addr[ETH_ALEN]; + + /** @issued_at: timestamp used for purging stale tt requests */ + unsigned long issued_at; + + /** @refcount: number of contexts the object is used by */ + struct kref refcount; + + /** @list: list node for &batadv_priv_tt.req_list */ + struct hlist_node list; +}; + +/** + * struct batadv_tt_roam_node - roaming client data + */ +struct batadv_tt_roam_node { + /** @addr: mac address of the client in the roaming phase */ + u8 addr[ETH_ALEN]; + + /** + * @counter: number of allowed roaming events per client within a single + * OGM interval (changes are committed with each OGM) + */ + atomic_t counter; + + /** + * @first_time: timestamp used for purging stale roaming node entries + */ + unsigned long first_time; + + /** @list: list node for &batadv_priv_tt.roam_list */ + struct list_head list; +}; + +/** + * struct batadv_nc_node - network coding node + */ +struct batadv_nc_node { + /** @list: next and prev pointer for the list handling */ + struct list_head list; + + /** @addr: the node's mac address */ + u8 addr[ETH_ALEN]; + + /** @refcount: number of contexts the object is used by */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; + + /** @orig_node: pointer to corresponding orig node struct */ + struct batadv_orig_node *orig_node; + + /** @last_seen: timestamp of last ogm received from this node */ + unsigned long last_seen; +}; + +/** + * struct batadv_nc_path - network coding path + */ +struct batadv_nc_path { + /** @hash_entry: next and prev pointer for the list handling */ + struct hlist_node hash_entry; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; + + /** @refcount: number of contexts the object is used by */ + struct kref refcount; + + /** @packet_list: list of buffered packets for this path */ + struct list_head packet_list; 
+ + /** @packet_list_lock: access lock for packet list */ + spinlock_t packet_list_lock; + + /** @next_hop: next hop (destination) of path */ + u8 next_hop[ETH_ALEN]; + + /** @prev_hop: previous hop (source) of path */ + u8 prev_hop[ETH_ALEN]; + + /** @last_valid: timestamp for last validation of path */ + unsigned long last_valid; +}; + +/** + * struct batadv_nc_packet - network coding packet used when coding and + * decoding packets + */ +struct batadv_nc_packet { + /** @list: next and prev pointer for the list handling */ + struct list_head list; + + /** @packet_id: crc32 checksum of skb data */ + __be32 packet_id; + + /** + * @timestamp: field containing the info when the packet was added to + * path + */ + unsigned long timestamp; + + /** @neigh_node: pointer to original next hop neighbor of skb */ + struct batadv_neigh_node *neigh_node; + + /** @skb: skb which can be encoded or used for decoding */ + struct sk_buff *skb; + + /** @nc_path: pointer to path this nc packet is attached to */ + struct batadv_nc_path *nc_path; +}; + +/** + * struct batadv_skb_cb - control buffer structure used to store private data + * relevant to batman-adv in the skb->cb buffer in skbs. + */ +struct batadv_skb_cb { + /** + * @decoded: Marks a skb as decoded, which is checked when searching for + * coding opportunities in network-coding.c + */ + unsigned char decoded:1; + + /** @num_bcasts: Counter for broadcast packet retransmissions */ + unsigned char num_bcasts; +}; + +/** + * struct batadv_forw_packet - structure for bcast packets to be sent/forwarded + */ +struct batadv_forw_packet { + /** + * @list: list node for &batadv_priv.forw.bcast_list and + * &batadv_priv.forw.bat_list + */ + struct hlist_node list; + + /** @cleanup_list: list node for purging functions */ + struct hlist_node cleanup_list; + + /** @send_time: execution time for delayed_work (packet sending) */ + unsigned long send_time; + + /** + * @own: bool for locally generated packets (local OGMs are re-scheduled + * after sending) + */ + u8 own; + + /** @skb: bcast packet's skb buffer */ + struct sk_buff *skb; + + /** @packet_len: size of aggregated OGM packet inside the skb buffer */ + u16 packet_len; + + /** @direct_link_flags: direct link flags for aggregated OGM packets */ + u32 direct_link_flags; + + /** @num_packets: counter for aggregated OGMv1 packets */ + u8 num_packets; + + /** @delayed_work: work queue callback item for packet sending */ + struct delayed_work delayed_work; + + /** + * @if_incoming: pointer to incoming hard-iface or primary iface if + * locally generated packet + */ + struct batadv_hard_iface *if_incoming; + + /** + * @if_outgoing: packet where the packet should be sent to, or NULL if + * unspecified + */ + struct batadv_hard_iface *if_outgoing; + + /** @queue_left: The queue (counter) this packet was applied to */ + atomic_t *queue_left; +}; + +/** + * struct batadv_algo_iface_ops - mesh algorithm callbacks (interface specific) + */ +struct batadv_algo_iface_ops { + /** + * @activate: start routing mechanisms when hard-interface is brought up + * (optional) + */ + void (*activate)(struct batadv_hard_iface *hard_iface); + + /** @enable: init routing info when hard-interface is enabled */ + int (*enable)(struct batadv_hard_iface *hard_iface); + + /** @enabled: notification when hard-interface was enabled (optional) */ + void (*enabled)(struct batadv_hard_iface *hard_iface); + + /** @disable: de-init routing info when hard-interface is disabled */ + void (*disable)(struct batadv_hard_iface *hard_iface); + + /** + 
* @update_mac: (re-)init mac addresses of the protocol information + * belonging to this hard-interface + */ + void (*update_mac)(struct batadv_hard_iface *hard_iface); + + /** @primary_set: called when primary interface is selected / changed */ + void (*primary_set)(struct batadv_hard_iface *hard_iface); +}; + +/** + * struct batadv_algo_neigh_ops - mesh algorithm callbacks (neighbour specific) + */ +struct batadv_algo_neigh_ops { + /** @hardif_init: called on creation of single hop entry (optional) */ + void (*hardif_init)(struct batadv_hardif_neigh_node *neigh); + + /** + * @cmp: compare the metrics of two neighbors for their respective + * outgoing interfaces + */ + int (*cmp)(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2); + + /** + * @is_similar_or_better: check if neigh1 is equally similar or better + * than neigh2 for their respective outgoing interface from the metric + * prospective + */ + bool (*is_similar_or_better)(struct batadv_neigh_node *neigh1, + struct batadv_hard_iface *if_outgoing1, + struct batadv_neigh_node *neigh2, + struct batadv_hard_iface *if_outgoing2); + + /** @dump: dump neighbors to a netlink socket (optional) */ + void (*dump)(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *priv, + struct batadv_hard_iface *hard_iface); +}; + +/** + * struct batadv_algo_orig_ops - mesh algorithm callbacks (originator specific) + */ +struct batadv_algo_orig_ops { + /** @dump: dump originators to a netlink socket (optional) */ + void (*dump)(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *priv, + struct batadv_hard_iface *hard_iface); +}; + +/** + * struct batadv_algo_gw_ops - mesh algorithm callbacks (GW specific) + */ +struct batadv_algo_gw_ops { + /** @init_sel_class: initialize GW selection class (optional) */ + void (*init_sel_class)(struct batadv_priv *bat_priv); + + /** + * @store_sel_class: parse and stores a new GW selection class + * (optional) + */ + ssize_t (*store_sel_class)(struct batadv_priv *bat_priv, char *buff, + size_t count); + /** + * @get_best_gw_node: select the best GW from the list of available + * nodes (optional) + */ + struct batadv_gw_node *(*get_best_gw_node) + (struct batadv_priv *bat_priv); + + /** + * @is_eligible: check if a newly discovered GW is a potential candidate + * for the election as best GW (optional) + */ + bool (*is_eligible)(struct batadv_priv *bat_priv, + struct batadv_orig_node *curr_gw_orig, + struct batadv_orig_node *orig_node); + + /** @dump: dump gateways to a netlink socket (optional) */ + void (*dump)(struct sk_buff *msg, struct netlink_callback *cb, + struct batadv_priv *priv); +}; + +/** + * struct batadv_algo_ops - mesh algorithm callbacks + */ +struct batadv_algo_ops { + /** @list: list node for the batadv_algo_list */ + struct hlist_node list; + + /** @name: name of the algorithm */ + char *name; + + /** @iface: callbacks related to interface handling */ + struct batadv_algo_iface_ops iface; + + /** @neigh: callbacks related to neighbors handling */ + struct batadv_algo_neigh_ops neigh; + + /** @orig: callbacks related to originators handling */ + struct batadv_algo_orig_ops orig; + + /** @gw: callbacks related to GW mode */ + struct batadv_algo_gw_ops gw; +}; + +/** + * struct batadv_dat_entry - it is a single entry of batman-adv ARP backend. 
It + * is used to stored ARP entries needed for the global DAT cache + */ +struct batadv_dat_entry { + /** @ip: the IPv4 corresponding to this DAT/ARP entry */ + __be32 ip; + + /** @mac_addr: the MAC address associated to the stored IPv4 */ + u8 mac_addr[ETH_ALEN]; + + /** @vid: the vlan ID associated to this entry */ + unsigned short vid; + + /** + * @last_update: time in jiffies when this entry was refreshed last time + */ + unsigned long last_update; + + /** @hash_entry: hlist node for &batadv_priv_dat.hash */ + struct hlist_node hash_entry; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * struct batadv_hw_addr - a list entry for a MAC address + */ +struct batadv_hw_addr { + /** @list: list node for the linking of entries */ + struct hlist_node list; + + /** @addr: the MAC address of this list entry */ + unsigned char addr[ETH_ALEN]; +}; + +/** + * struct batadv_dat_candidate - candidate destination for DAT operations + */ +struct batadv_dat_candidate { + /** + * @type: the type of the selected candidate. It can one of the + * following: + * - BATADV_DAT_CANDIDATE_NOT_FOUND + * - BATADV_DAT_CANDIDATE_ORIG + */ + int type; + + /** + * @orig_node: if type is BATADV_DAT_CANDIDATE_ORIG this field points to + * the corresponding originator node structure + */ + struct batadv_orig_node *orig_node; +}; + +/** + * struct batadv_tvlv_container - container for tvlv appended to OGMs + */ +struct batadv_tvlv_container { + /** @list: hlist node for &batadv_priv_tvlv.container_list */ + struct hlist_node list; + + /** @tvlv_hdr: tvlv header information needed to construct the tvlv */ + struct batadv_tvlv_hdr tvlv_hdr; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; +}; + +/** + * struct batadv_tvlv_handler - handler for specific tvlv type and version + */ +struct batadv_tvlv_handler { + /** @list: hlist node for &batadv_priv_tvlv.handler_list */ + struct hlist_node list; + + /** + * @ogm_handler: handler callback which is given the tvlv payload to + * process on incoming OGM packets + */ + void (*ogm_handler)(struct batadv_priv *bat_priv, + struct batadv_orig_node *orig, + u8 flags, void *tvlv_value, u16 tvlv_value_len); + + /** + * @unicast_handler: handler callback which is given the tvlv payload to + * process on incoming unicast tvlv packets + */ + int (*unicast_handler)(struct batadv_priv *bat_priv, + u8 *src, u8 *dst, + void *tvlv_value, u16 tvlv_value_len); + + /** @type: tvlv type this handler feels responsible for */ + u8 type; + + /** @version: tvlv version this handler feels responsible for */ + u8 version; + + /** @flags: tvlv handler flags */ + u8 flags; + + /** @refcount: number of contexts the object is used */ + struct kref refcount; + + /** @rcu: struct used for freeing in an RCU-safe manner */ + struct rcu_head rcu; +}; + +/** + * enum batadv_tvlv_handler_flags - tvlv handler flags definitions + */ +enum batadv_tvlv_handler_flags { + /** + * @BATADV_TVLV_HANDLER_OGM_CIFNOTFND: tvlv ogm processing function + * will call this handler even if its type was not found (with no data) + */ + BATADV_TVLV_HANDLER_OGM_CIFNOTFND = BIT(1), + + /** + * @BATADV_TVLV_HANDLER_OGM_CALLED: interval tvlv handling flag - the + * API marks a handler as being called, so it won't be called if the + * BATADV_TVLV_HANDLER_OGM_CIFNOTFND flag was set + */ + BATADV_TVLV_HANDLER_OGM_CALLED = BIT(2), +}; + +#endif /* _NET_BATMAN_ADV_TYPES_H_ */ 
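The tvlv.h prototypes and the batadv_tvlv_handler flags quoted above already describe the whole TVLV plumbing, so a short usage sketch may help before moving on to the Bluetooth 6LoWPAN code. Nothing below is part of the patch itself: the TVLV type number 0x42, the payload struct and the callback bodies are invented for illustration, while the function signatures are the ones declared in tvlv.h.

/* Hypothetical TVLV user inside the batman-adv tree (illustration only). */
#include "main.h"
#include "tvlv.h"

struct my_tvlv_data {
        __be32 value;           /* example payload carried in every OGM */
};

/* OGM handler: called per originator that announced this TVLV (or, with
 * BATADV_TVLV_HANDLER_OGM_CIFNOTFND, even when it did not announce it).
 */
static void my_tvlv_ogm_handler(struct batadv_priv *bat_priv,
                                struct batadv_orig_node *orig,
                                u8 flags, void *tvlv_value, u16 tvlv_value_len)
{
        if (tvlv_value_len < sizeof(struct my_tvlv_data))
                return;         /* absent or truncated payload: ignore */

        /* ... update per-originator state from tvlv_value here ... */
}

/* Unicast handler: consume a TVLV carried in a unicast TVLV packet. */
static int my_tvlv_unicast_handler(struct batadv_priv *bat_priv,
                                   u8 *src, u8 *dst,
                                   void *tvlv_value, u16 tvlv_value_len)
{
        if (tvlv_value_len < sizeof(struct my_tvlv_data))
                return NET_RX_DROP;

        /* ... react to the payload sent by src ... */
        return NET_RX_SUCCESS;
}

static void my_tvlv_init(struct batadv_priv *bat_priv)
{
        struct my_tvlv_data data = { .value = cpu_to_be32(1) };

        /* announce the payload with every OGM ... */
        batadv_tvlv_container_register(bat_priv, 0x42, 1, &data, sizeof(data));

        /* ... and register callbacks for what other nodes send */
        batadv_tvlv_handler_register(bat_priv, my_tvlv_ogm_handler,
                                     my_tvlv_unicast_handler, 0x42, 1,
                                     BATADV_TVLV_HANDLER_OGM_CIFNOTFND);
}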
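The fragment buffer structures above (batadv_frag_table_entry and batadv_frag_list_entry) imply simple bookkeeping: fragments of one chain share a @seqno, their lengths are accumulated in @size, and the chain is complete once @size reaches @total_size. The standalone sketch below only illustrates that accounting under invented names and limits; the kernel's real reassembly code additionally handles locking, timeouts and ordered list merging.

/* Standalone illustration of the reassembly accounting implied by
 * batadv_frag_table_entry; not the kernel code, just the bookkeeping idea.
 */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#define FRAG_MAX 16     /* illustrative limit, not a batman-adv constant */

struct frag_chain {
        uint16_t seqno;          /* sequence number shared by the chain */
        uint16_t size;           /* bytes buffered so far */
        uint16_t total_size;     /* expected size of the whole packet */
        uint8_t  have[FRAG_MAX]; /* which fragment numbers were seen */
};

/* Returns true once every byte of the original packet has been buffered;
 * the caller is expected to reset the entry after reassembling.
 */
static bool frag_chain_add(struct frag_chain *c, uint16_t seqno, uint8_t no,
                           uint16_t len, uint16_t total_size)
{
        if (c->size == 0) {
                /* first fragment of a new chain: (re)initialise the entry */
                memset(c, 0, sizeof(*c));
                c->seqno = seqno;
                c->total_size = total_size;
        } else if (c->seqno != seqno || c->total_size != total_size) {
                return false;   /* belongs to a different chain */
        }

        if (no >= FRAG_MAX || c->have[no])
                return false;   /* out of range or duplicate fragment */

        c->have[no] = 1;
        c->size += len;

        return c->size == c->total_size;
}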
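DECLARE_EWMA(throughput, 10, 8), quoted just before batadv_hardif_neigh_node_bat_v, expands to the ewma_throughput_init/_add/_read helpers from linux/average.h. A minimal sketch of how the @throughput member would typically be fed and read follows; the function names and the sample parameter are placeholders, and the field only exists when CONFIG_BATMAN_ADV_BATMAN_V is enabled.

/* Placeholders showing the DECLARE_EWMA()-generated helpers in use. */
static void example_neigh_metric_init(struct batadv_hardif_neigh_node *neigh)
{
        /* must be called once before the first sample is added */
        ewma_throughput_init(&neigh->bat_v.throughput);
}

static void example_neigh_metric_sample(struct batadv_hardif_neigh_node *neigh,
                                        unsigned long sample)
{
        /* fold a new measurement into the weighted moving average */
        ewma_throughput_add(&neigh->bat_v.throughput, sample);
}

static unsigned long
example_neigh_metric_read(struct batadv_hardif_neigh_node *neigh)
{
        /* smoothed value, in the same unit as the samples */
        return ewma_throughput_read(&neigh->bat_v.throughput);
}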
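batadv_neigh_ifinfo_bat_iv keeps the most recently received TQ values in a small ring buffer (@tq_recv indexed by @tq_index) and caches their mean in @tq_avg. The sketch below shows that sliding-window average in isolation; the window size is a placeholder for BATADV_TQ_GLOBAL_WINDOW_SIZE (defined in main.h, outside this hunk), and the in-kernel version differs in details such as how unused slots are treated.

/* Standalone sketch of the sliding-window TQ average suggested by
 * tq_recv/tq_index/tq_avg; simplified, not the kernel implementation.
 */
#include <stdint.h>

#define TQ_WINDOW_SIZE 5        /* placeholder window size */

struct tq_ring {
        uint8_t recv[TQ_WINDOW_SIZE];   /* last received TQ values */
        uint8_t index;                  /* next slot to overwrite */
        uint8_t avg;                    /* cached mean of recv[] */
};

static void tq_ring_set(struct tq_ring *ring, uint8_t tq)
{
        unsigned int sum = 0;
        unsigned int i;

        /* overwrite the oldest sample and advance the write index */
        ring->recv[ring->index] = tq;
        ring->index = (ring->index + 1) % TQ_WINDOW_SIZE;

        /* recompute the cached average over the whole window */
        for (i = 0; i < TQ_WINDOW_SIZE; i++)
                sum += ring->recv[i];

        ring->avg = sum / TQ_WINDOW_SIZE;
}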
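In batadv_tp_vars, @srtt ("smoothed RTT scaled by 2^3"), @rttvar ("RTT variation scaled by 2^2") and @rto match the classic scaled-integer RTT estimator of RFC 6298. The standalone function below spells out that update rule with exactly this scaling; it is a textbook sketch, not code taken from tp_meter.c.

/* Scaled-integer RTT estimator implied by the srtt (<<3) and rttvar (<<2)
 * fields: the RFC 6298 update rule in its usual integer form.
 */
#include <stdint.h>
#include <stdlib.h>

struct rtt_state {
        uint32_t srtt;   /* smoothed RTT, scaled by 8 */
        uint32_t rttvar; /* RTT variation, scaled by 4 */
        uint32_t rto;    /* retransmission timeout, same time unit */
};

static void rtt_update(struct rtt_state *s, uint32_t rtt_sample)
{
        if (!s->srtt) {
                /* first measurement: SRTT = R, RTTVAR = R/2 */
                s->srtt = rtt_sample << 3;
                s->rttvar = rtt_sample << 1;
        } else {
                int32_t m = (int32_t)rtt_sample - (int32_t)(s->srtt >> 3);

                /* SRTT += (R - SRTT) / 8, kept scaled by 8 */
                s->srtt += m;

                /* RTTVAR += (|R - SRTT| - RTTVAR) / 4, kept scaled by 4 */
                s->rttvar += abs(m) - (int32_t)(s->rttvar >> 2);
        }

        /* RTO = SRTT + 4 * RTTVAR (the scaling factors cancel out) */
        s->rto = (s->srtt >> 3) + s->rttvar;
}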
diff --git a/net/bluetooth/6lowpan.c b/net/bluetooth/6lowpan.c new file mode 100644 index 000000000..4eb1b3ced --- /dev/null +++ b/net/bluetooth/6lowpan.c @@ -0,0 +1,1283 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + Copyright (c) 2013-2014 Intel Corp. + +*/ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include /* for the compression support */ + +#define VERSION "0.1" + +static struct dentry *lowpan_enable_debugfs; +static struct dentry *lowpan_control_debugfs; + +#define IFACE_NAME_TEMPLATE "bt%d" + +struct skb_cb { + struct in6_addr addr; + struct in6_addr gw; + struct l2cap_chan *chan; +}; +#define lowpan_cb(skb) ((struct skb_cb *)((skb)->cb)) + +/* The devices list contains those devices that we are acting + * as a proxy. The BT 6LoWPAN device is a virtual device that + * connects to the Bluetooth LE device. The real connection to + * BT device is done via l2cap layer. There exists one + * virtual device / one BT 6LoWPAN network (=hciX device). + * The list contains struct lowpan_dev elements. + */ +static LIST_HEAD(bt_6lowpan_devices); +static DEFINE_SPINLOCK(devices_lock); + +static bool enable_6lowpan; + +/* We are listening incoming connections via this channel + */ +static struct l2cap_chan *listen_chan; +static DEFINE_MUTEX(set_lock); + +struct lowpan_peer { + struct list_head list; + struct rcu_head rcu; + struct l2cap_chan *chan; + + /* peer addresses in various formats */ + unsigned char lladdr[ETH_ALEN]; + struct in6_addr peer_addr; +}; + +struct lowpan_btle_dev { + struct list_head list; + + struct hci_dev *hdev; + struct net_device *netdev; + struct list_head peers; + atomic_t peer_count; /* number of items in peers list */ + + struct work_struct delete_netdev; + struct delayed_work notify_peers; +}; + +static inline struct lowpan_btle_dev * +lowpan_btle_dev(const struct net_device *netdev) +{ + return (struct lowpan_btle_dev *)lowpan_dev(netdev)->priv; +} + +static inline void peer_add(struct lowpan_btle_dev *dev, + struct lowpan_peer *peer) +{ + list_add_rcu(&peer->list, &dev->peers); + atomic_inc(&dev->peer_count); +} + +static inline bool peer_del(struct lowpan_btle_dev *dev, + struct lowpan_peer *peer) +{ + list_del_rcu(&peer->list); + kfree_rcu(peer, rcu); + + module_put(THIS_MODULE); + + if (atomic_dec_and_test(&dev->peer_count)) { + BT_DBG("last peer"); + return true; + } + + return false; +} + +static inline struct lowpan_peer * +__peer_lookup_chan(struct lowpan_btle_dev *dev, struct l2cap_chan *chan) +{ + struct lowpan_peer *peer; + + list_for_each_entry_rcu(peer, &dev->peers, list) { + if (peer->chan == chan) + return peer; + } + + return NULL; +} + +static inline struct lowpan_peer * +__peer_lookup_conn(struct lowpan_btle_dev *dev, struct l2cap_conn *conn) +{ + struct lowpan_peer *peer; + + list_for_each_entry_rcu(peer, &dev->peers, list) { + if (peer->chan->conn == conn) + return peer; + } + + return NULL; +} + +static inline struct lowpan_peer *peer_lookup_dst(struct lowpan_btle_dev *dev, + struct in6_addr *daddr, + struct sk_buff *skb) +{ + struct rt6_info *rt = (struct rt6_info *)skb_dst(skb); + int count = atomic_read(&dev->peer_count); + const struct in6_addr *nexthop; + struct lowpan_peer *peer; + struct neighbour *neigh; + + BT_DBG("peers %d addr %pI6c rt %p", count, daddr, rt); + + if (!rt) { + if (ipv6_addr_any(&lowpan_cb(skb)->gw)) { + /* There is neither route nor gateway, + * probably the destination is a direct peer. 
+ */ + nexthop = daddr; + } else { + /* There is a known gateway + */ + nexthop = &lowpan_cb(skb)->gw; + } + } else { + nexthop = rt6_nexthop(rt, daddr); + + /* We need to remember the address because it is needed + * by bt_xmit() when sending the packet. In bt_xmit(), the + * destination routing info is not set. + */ + memcpy(&lowpan_cb(skb)->gw, nexthop, sizeof(struct in6_addr)); + } + + BT_DBG("gw %pI6c", nexthop); + + rcu_read_lock(); + + list_for_each_entry_rcu(peer, &dev->peers, list) { + BT_DBG("dst addr %pMR dst type %u ip %pI6c", + &peer->chan->dst, peer->chan->dst_type, + &peer->peer_addr); + + if (!ipv6_addr_cmp(&peer->peer_addr, nexthop)) { + rcu_read_unlock(); + return peer; + } + } + + /* use the neighbour cache for matching addresses assigned by SLAAC */ + neigh = __ipv6_neigh_lookup(dev->netdev, nexthop); + if (neigh) { + list_for_each_entry_rcu(peer, &dev->peers, list) { + if (!memcmp(neigh->ha, peer->lladdr, ETH_ALEN)) { + neigh_release(neigh); + rcu_read_unlock(); + return peer; + } + } + neigh_release(neigh); + } + + rcu_read_unlock(); + + return NULL; +} + +static struct lowpan_peer *lookup_peer(struct l2cap_conn *conn) +{ + struct lowpan_btle_dev *entry; + struct lowpan_peer *peer = NULL; + + rcu_read_lock(); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + peer = __peer_lookup_conn(entry, conn); + if (peer) + break; + } + + rcu_read_unlock(); + + return peer; +} + +static struct lowpan_btle_dev *lookup_dev(struct l2cap_conn *conn) +{ + struct lowpan_btle_dev *entry; + struct lowpan_btle_dev *dev = NULL; + + rcu_read_lock(); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + if (conn->hcon->hdev == entry->hdev) { + dev = entry; + break; + } + } + + rcu_read_unlock(); + + return dev; +} + +static int give_skb_to_upper(struct sk_buff *skb, struct net_device *dev) +{ + struct sk_buff *skb_cp; + + skb_cp = skb_copy(skb, GFP_ATOMIC); + if (!skb_cp) + return NET_RX_DROP; + + return netif_rx(skb_cp); +} + +static int iphc_decompress(struct sk_buff *skb, struct net_device *netdev, + struct lowpan_peer *peer) +{ + const u8 *saddr; + + saddr = peer->lladdr; + + return lowpan_header_decompress(skb, netdev, netdev->dev_addr, saddr); +} + +static int recv_pkt(struct sk_buff *skb, struct net_device *dev, + struct lowpan_peer *peer) +{ + struct sk_buff *local_skb; + int ret; + + if (!netif_running(dev)) + goto drop; + + if (dev->type != ARPHRD_6LOWPAN || !skb->len) + goto drop; + + skb_reset_network_header(skb); + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + goto drop; + + /* check that it's our buffer */ + if (lowpan_is_ipv6(*skb_network_header(skb))) { + /* Pull off the 1-byte of 6lowpan header. */ + skb_pull(skb, 1); + + /* Copy the packet so that the IPv6 header is + * properly aligned. 
+ */ + local_skb = skb_copy_expand(skb, NET_SKB_PAD - 1, + skb_tailroom(skb), GFP_ATOMIC); + if (!local_skb) + goto drop; + + local_skb->protocol = htons(ETH_P_IPV6); + local_skb->pkt_type = PACKET_HOST; + local_skb->dev = dev; + + skb_set_transport_header(local_skb, sizeof(struct ipv6hdr)); + + if (give_skb_to_upper(local_skb, dev) != NET_RX_SUCCESS) { + kfree_skb(local_skb); + goto drop; + } + + dev->stats.rx_bytes += skb->len; + dev->stats.rx_packets++; + + consume_skb(local_skb); + consume_skb(skb); + } else if (lowpan_is_iphc(*skb_network_header(skb))) { + local_skb = skb_clone(skb, GFP_ATOMIC); + if (!local_skb) + goto drop; + + local_skb->dev = dev; + + ret = iphc_decompress(local_skb, dev, peer); + if (ret < 0) { + BT_DBG("iphc_decompress failed: %d", ret); + kfree_skb(local_skb); + goto drop; + } + + local_skb->protocol = htons(ETH_P_IPV6); + local_skb->pkt_type = PACKET_HOST; + + if (give_skb_to_upper(local_skb, dev) + != NET_RX_SUCCESS) { + kfree_skb(local_skb); + goto drop; + } + + dev->stats.rx_bytes += skb->len; + dev->stats.rx_packets++; + + consume_skb(local_skb); + consume_skb(skb); + } else { + BT_DBG("unknown packet type"); + goto drop; + } + + return NET_RX_SUCCESS; + +drop: + dev->stats.rx_dropped++; + return NET_RX_DROP; +} + +/* Packet from BT LE device */ +static int chan_recv_cb(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct lowpan_btle_dev *dev; + struct lowpan_peer *peer; + int err; + + peer = lookup_peer(chan->conn); + if (!peer) + return -ENOENT; + + dev = lookup_dev(chan->conn); + if (!dev || !dev->netdev) + return -ENOENT; + + err = recv_pkt(skb, dev->netdev, peer); + if (err) { + BT_DBG("recv pkt %d", err); + err = -EAGAIN; + } + + return err; +} + +static int setup_header(struct sk_buff *skb, struct net_device *netdev, + bdaddr_t *peer_addr, u8 *peer_addr_type) +{ + struct in6_addr ipv6_daddr; + struct ipv6hdr *hdr; + struct lowpan_btle_dev *dev; + struct lowpan_peer *peer; + u8 *daddr; + int err, status = 0; + + hdr = ipv6_hdr(skb); + + dev = lowpan_btle_dev(netdev); + + memcpy(&ipv6_daddr, &hdr->daddr, sizeof(ipv6_daddr)); + + if (ipv6_addr_is_multicast(&ipv6_daddr)) { + lowpan_cb(skb)->chan = NULL; + daddr = NULL; + } else { + BT_DBG("dest IP %pI6c", &ipv6_daddr); + + /* The packet might be sent to 6lowpan interface + * because of routing (either via default route + * or user set route) so get peer according to + * the destination address. + */ + peer = peer_lookup_dst(dev, &ipv6_daddr, skb); + if (!peer) { + BT_DBG("no such peer"); + return -ENOENT; + } + + daddr = peer->lladdr; + *peer_addr = peer->chan->dst; + *peer_addr_type = peer->chan->dst_type; + lowpan_cb(skb)->chan = peer->chan; + + status = 1; + } + + lowpan_header_compress(skb, netdev, daddr, dev->netdev->dev_addr); + + err = dev_hard_header(skb, netdev, ETH_P_IPV6, NULL, NULL, 0); + if (err < 0) + return err; + + return status; +} + +static int header_create(struct sk_buff *skb, struct net_device *netdev, + unsigned short type, const void *_daddr, + const void *_saddr, unsigned int len) +{ + if (type != ETH_P_IPV6) + return -EINVAL; + + return 0; +} + +/* Packet to BT LE device */ +static int send_pkt(struct l2cap_chan *chan, struct sk_buff *skb, + struct net_device *netdev) +{ + struct msghdr msg; + struct kvec iv; + int err; + + /* Remember the skb so that we can send EAGAIN to the caller if + * we run out of credits. 
+ */ + chan->data = skb; + + iv.iov_base = skb->data; + iv.iov_len = skb->len; + + memset(&msg, 0, sizeof(msg)); + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iv, 1, skb->len); + + err = l2cap_chan_send(chan, &msg, skb->len); + if (err > 0) { + netdev->stats.tx_bytes += err; + netdev->stats.tx_packets++; + return 0; + } + + if (err < 0) + netdev->stats.tx_errors++; + + return err; +} + +static int send_mcast_pkt(struct sk_buff *skb, struct net_device *netdev) +{ + struct sk_buff *local_skb; + struct lowpan_btle_dev *entry; + int err = 0; + + rcu_read_lock(); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + struct lowpan_peer *pentry; + struct lowpan_btle_dev *dev; + + if (entry->netdev != netdev) + continue; + + dev = lowpan_btle_dev(entry->netdev); + + list_for_each_entry_rcu(pentry, &dev->peers, list) { + int ret; + + local_skb = skb_clone(skb, GFP_ATOMIC); + + BT_DBG("xmit %s to %pMR type %u IP %pI6c chan %p", + netdev->name, + &pentry->chan->dst, pentry->chan->dst_type, + &pentry->peer_addr, pentry->chan); + ret = send_pkt(pentry->chan, local_skb, netdev); + if (ret < 0) + err = ret; + + kfree_skb(local_skb); + } + } + + rcu_read_unlock(); + + return err; +} + +static netdev_tx_t bt_xmit(struct sk_buff *skb, struct net_device *netdev) +{ + int err = 0; + bdaddr_t addr; + u8 addr_type; + + /* We must take a copy of the skb before we modify/replace the ipv6 + * header as the header could be used elsewhere + */ + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + return NET_XMIT_DROP; + + /* Return values from setup_header() + * <0 - error, packet is dropped + * 0 - this is a multicast packet + * 1 - this is unicast packet + */ + err = setup_header(skb, netdev, &addr, &addr_type); + if (err < 0) { + kfree_skb(skb); + return NET_XMIT_DROP; + } + + if (err) { + if (lowpan_cb(skb)->chan) { + BT_DBG("xmit %s to %pMR type %u IP %pI6c chan %p", + netdev->name, &addr, addr_type, + &lowpan_cb(skb)->addr, lowpan_cb(skb)->chan); + err = send_pkt(lowpan_cb(skb)->chan, skb, netdev); + } else { + err = -ENOENT; + } + } else { + /* We need to send the packet to every device behind this + * interface. + */ + err = send_mcast_pkt(skb, netdev); + } + + dev_kfree_skb(skb); + + if (err) + BT_DBG("ERROR: xmit failed (%d)", err); + + return err < 0 ? 
NET_XMIT_DROP : err; +} + +static int bt_dev_init(struct net_device *dev) +{ + netdev_lockdep_set_classes(dev); + + return 0; +} + +static const struct net_device_ops netdev_ops = { + .ndo_init = bt_dev_init, + .ndo_start_xmit = bt_xmit, +}; + +static const struct header_ops header_ops = { + .create = header_create, +}; + +static void netdev_setup(struct net_device *dev) +{ + dev->hard_header_len = 0; + dev->needed_tailroom = 0; + dev->flags = IFF_RUNNING | IFF_MULTICAST; + dev->watchdog_timeo = 0; + dev->tx_queue_len = DEFAULT_TX_QUEUE_LEN; + + dev->netdev_ops = &netdev_ops; + dev->header_ops = &header_ops; + dev->needs_free_netdev = true; +} + +static struct device_type bt_type = { + .name = "bluetooth", +}; + +static void ifup(struct net_device *netdev) +{ + int err; + + rtnl_lock(); + err = dev_open(netdev, NULL); + if (err < 0) + BT_INFO("iface %s cannot be opened (%d)", netdev->name, err); + rtnl_unlock(); +} + +static void ifdown(struct net_device *netdev) +{ + rtnl_lock(); + dev_close(netdev); + rtnl_unlock(); +} + +static void do_notify_peers(struct work_struct *work) +{ + struct lowpan_btle_dev *dev = container_of(work, struct lowpan_btle_dev, + notify_peers.work); + + netdev_notify_peers(dev->netdev); /* send neighbour adv at startup */ +} + +static bool is_bt_6lowpan(struct hci_conn *hcon) +{ + if (hcon->type != LE_LINK) + return false; + + if (!enable_6lowpan) + return false; + + return true; +} + +static struct l2cap_chan *chan_create(void) +{ + struct l2cap_chan *chan; + + chan = l2cap_chan_create(); + if (!chan) + return NULL; + + l2cap_chan_set_defaults(chan); + + chan->chan_type = L2CAP_CHAN_CONN_ORIENTED; + chan->mode = L2CAP_MODE_LE_FLOWCTL; + chan->imtu = 1280; + + return chan; +} + +static struct l2cap_chan *add_peer_chan(struct l2cap_chan *chan, + struct lowpan_btle_dev *dev, + bool new_netdev) +{ + struct lowpan_peer *peer; + + peer = kzalloc(sizeof(*peer), GFP_ATOMIC); + if (!peer) + return NULL; + + peer->chan = chan; + + baswap((void *)peer->lladdr, &chan->dst); + + lowpan_iphc_uncompress_eui48_lladdr(&peer->peer_addr, peer->lladdr); + + spin_lock(&devices_lock); + INIT_LIST_HEAD(&peer->list); + peer_add(dev, peer); + spin_unlock(&devices_lock); + + /* Notifying peers about us needs to be done without locks held */ + if (new_netdev) + INIT_DELAYED_WORK(&dev->notify_peers, do_notify_peers); + schedule_delayed_work(&dev->notify_peers, msecs_to_jiffies(100)); + + return peer->chan; +} + +static int setup_netdev(struct l2cap_chan *chan, struct lowpan_btle_dev **dev) +{ + struct net_device *netdev; + bdaddr_t addr; + int err; + + netdev = alloc_netdev(LOWPAN_PRIV_SIZE(sizeof(struct lowpan_btle_dev)), + IFACE_NAME_TEMPLATE, NET_NAME_UNKNOWN, + netdev_setup); + if (!netdev) + return -ENOMEM; + + netdev->addr_assign_type = NET_ADDR_PERM; + baswap(&addr, &chan->src); + __dev_addr_set(netdev, &addr, sizeof(addr)); + + netdev->netdev_ops = &netdev_ops; + SET_NETDEV_DEV(netdev, &chan->conn->hcon->hdev->dev); + SET_NETDEV_DEVTYPE(netdev, &bt_type); + + *dev = lowpan_btle_dev(netdev); + (*dev)->netdev = netdev; + (*dev)->hdev = chan->conn->hcon->hdev; + INIT_LIST_HEAD(&(*dev)->peers); + + spin_lock(&devices_lock); + INIT_LIST_HEAD(&(*dev)->list); + list_add_rcu(&(*dev)->list, &bt_6lowpan_devices); + spin_unlock(&devices_lock); + + err = lowpan_register_netdev(netdev, LOWPAN_LLTYPE_BTLE); + if (err < 0) { + BT_INFO("register_netdev failed %d", err); + spin_lock(&devices_lock); + list_del_rcu(&(*dev)->list); + spin_unlock(&devices_lock); + free_netdev(netdev); + goto out; + } + 
+ BT_DBG("ifindex %d peer bdaddr %pMR type %d my addr %pMR type %d", + netdev->ifindex, &chan->dst, chan->dst_type, + &chan->src, chan->src_type); + set_bit(__LINK_STATE_PRESENT, &netdev->state); + + return 0; + +out: + return err; +} + +static inline void chan_ready_cb(struct l2cap_chan *chan) +{ + struct lowpan_btle_dev *dev; + bool new_netdev = false; + + dev = lookup_dev(chan->conn); + + BT_DBG("chan %p conn %p dev %p", chan, chan->conn, dev); + + if (!dev) { + if (setup_netdev(chan, &dev) < 0) { + l2cap_chan_del(chan, -ENOENT); + return; + } + new_netdev = true; + } + + if (!try_module_get(THIS_MODULE)) + return; + + add_peer_chan(chan, dev, new_netdev); + ifup(dev->netdev); +} + +static inline struct l2cap_chan *chan_new_conn_cb(struct l2cap_chan *pchan) +{ + struct l2cap_chan *chan; + + chan = chan_create(); + if (!chan) + return NULL; + + chan->ops = pchan->ops; + + BT_DBG("chan %p pchan %p", chan, pchan); + + return chan; +} + +static void delete_netdev(struct work_struct *work) +{ + struct lowpan_btle_dev *entry = container_of(work, + struct lowpan_btle_dev, + delete_netdev); + + lowpan_unregister_netdev(entry->netdev); + + /* The entry pointer is deleted by the netdev destructor. */ +} + +static void chan_close_cb(struct l2cap_chan *chan) +{ + struct lowpan_btle_dev *entry; + struct lowpan_btle_dev *dev = NULL; + struct lowpan_peer *peer; + int err = -ENOENT; + bool last = false, remove = true; + + BT_DBG("chan %p conn %p", chan, chan->conn); + + if (chan->conn && chan->conn->hcon) { + if (!is_bt_6lowpan(chan->conn->hcon)) + return; + + /* If conn is set, then the netdev is also there and we should + * not remove it. + */ + remove = false; + } + + spin_lock(&devices_lock); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + dev = lowpan_btle_dev(entry->netdev); + peer = __peer_lookup_chan(dev, chan); + if (peer) { + last = peer_del(dev, peer); + err = 0; + + BT_DBG("dev %p removing %speer %p", dev, + last ? "last " : "1 ", peer); + BT_DBG("chan %p orig refcnt %u", chan, + kref_read(&chan->kref)); + + l2cap_chan_put(chan); + break; + } + } + + if (!err && last && dev && !atomic_read(&dev->peer_count)) { + spin_unlock(&devices_lock); + + cancel_delayed_work_sync(&dev->notify_peers); + + ifdown(dev->netdev); + + if (remove) { + INIT_WORK(&entry->delete_netdev, delete_netdev); + schedule_work(&entry->delete_netdev); + } + } else { + spin_unlock(&devices_lock); + } +} + +static void chan_state_change_cb(struct l2cap_chan *chan, int state, int err) +{ + BT_DBG("chan %p conn %p state %s err %d", chan, chan->conn, + state_to_string(state), err); +} + +static struct sk_buff *chan_alloc_skb_cb(struct l2cap_chan *chan, + unsigned long hdr_len, + unsigned long len, int nb) +{ + /* Note that we must allocate using GFP_ATOMIC here as + * this function is called originally from netdev hard xmit + * function in atomic context. 
+ */ + return bt_skb_alloc(hdr_len + len, GFP_ATOMIC); +} + +static void chan_suspend_cb(struct l2cap_chan *chan) +{ + struct lowpan_btle_dev *dev; + + BT_DBG("chan %p suspend", chan); + + dev = lookup_dev(chan->conn); + if (!dev || !dev->netdev) + return; + + netif_stop_queue(dev->netdev); +} + +static void chan_resume_cb(struct l2cap_chan *chan) +{ + struct lowpan_btle_dev *dev; + + BT_DBG("chan %p resume", chan); + + dev = lookup_dev(chan->conn); + if (!dev || !dev->netdev) + return; + + netif_wake_queue(dev->netdev); +} + +static long chan_get_sndtimeo_cb(struct l2cap_chan *chan) +{ + return L2CAP_CONN_TIMEOUT; +} + +static const struct l2cap_ops bt_6lowpan_chan_ops = { + .name = "L2CAP 6LoWPAN channel", + .new_connection = chan_new_conn_cb, + .recv = chan_recv_cb, + .close = chan_close_cb, + .state_change = chan_state_change_cb, + .ready = chan_ready_cb, + .resume = chan_resume_cb, + .suspend = chan_suspend_cb, + .get_sndtimeo = chan_get_sndtimeo_cb, + .alloc_skb = chan_alloc_skb_cb, + + .teardown = l2cap_chan_no_teardown, + .defer = l2cap_chan_no_defer, + .set_shutdown = l2cap_chan_no_set_shutdown, +}; + +static int bt_6lowpan_connect(bdaddr_t *addr, u8 dst_type) +{ + struct l2cap_chan *chan; + int err; + + chan = chan_create(); + if (!chan) + return -EINVAL; + + chan->ops = &bt_6lowpan_chan_ops; + + err = l2cap_chan_connect(chan, cpu_to_le16(L2CAP_PSM_IPSP), 0, + addr, dst_type); + + BT_DBG("chan %p err %d", chan, err); + if (err < 0) + l2cap_chan_put(chan); + + return err; +} + +static int bt_6lowpan_disconnect(struct l2cap_conn *conn, u8 dst_type) +{ + struct lowpan_peer *peer; + + BT_DBG("conn %p dst type %u", conn, dst_type); + + peer = lookup_peer(conn); + if (!peer) + return -ENOENT; + + BT_DBG("peer %p chan %p", peer, peer->chan); + + l2cap_chan_close(peer->chan, ENOENT); + + return 0; +} + +static struct l2cap_chan *bt_6lowpan_listen(void) +{ + bdaddr_t *addr = BDADDR_ANY; + struct l2cap_chan *chan; + int err; + + if (!enable_6lowpan) + return NULL; + + chan = chan_create(); + if (!chan) + return NULL; + + chan->ops = &bt_6lowpan_chan_ops; + chan->state = BT_LISTEN; + chan->src_type = BDADDR_LE_PUBLIC; + + atomic_set(&chan->nesting, L2CAP_NESTING_PARENT); + + BT_DBG("chan %p src type %u", chan, chan->src_type); + + err = l2cap_add_psm(chan, addr, cpu_to_le16(L2CAP_PSM_IPSP)); + if (err) { + l2cap_chan_put(chan); + BT_ERR("psm cannot be added err %d", err); + return NULL; + } + + return chan; +} + +static int get_l2cap_conn(char *buf, bdaddr_t *addr, u8 *addr_type, + struct l2cap_conn **conn) +{ + struct hci_conn *hcon; + struct hci_dev *hdev; + int n; + + n = sscanf(buf, "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx %hhu", + &addr->b[5], &addr->b[4], &addr->b[3], + &addr->b[2], &addr->b[1], &addr->b[0], + addr_type); + + if (n < 7) + return -EINVAL; + + /* The LE_PUBLIC address type is ignored because of BDADDR_ANY */ + hdev = hci_get_route(addr, BDADDR_ANY, BDADDR_LE_PUBLIC); + if (!hdev) + return -ENOENT; + + hci_dev_lock(hdev); + hcon = hci_conn_hash_lookup_le(hdev, addr, *addr_type); + hci_dev_unlock(hdev); + hci_dev_put(hdev); + + if (!hcon) + return -ENOENT; + + *conn = (struct l2cap_conn *)hcon->l2cap_data; + + BT_DBG("conn %p dst %pMR type %u", *conn, &hcon->dst, hcon->dst_type); + + return 0; +} + +static void disconnect_all_peers(void) +{ + struct lowpan_btle_dev *entry; + struct lowpan_peer *peer, *tmp_peer, *new_peer; + struct list_head peers; + + INIT_LIST_HEAD(&peers); + + /* We make a separate list of peers as the close_cb() will + * modify the device peers list so it is 
better not to mess + * with the same list at the same time. + */ + + rcu_read_lock(); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + list_for_each_entry_rcu(peer, &entry->peers, list) { + new_peer = kmalloc(sizeof(*new_peer), GFP_ATOMIC); + if (!new_peer) + break; + + new_peer->chan = peer->chan; + INIT_LIST_HEAD(&new_peer->list); + + list_add(&new_peer->list, &peers); + } + } + + rcu_read_unlock(); + + spin_lock(&devices_lock); + list_for_each_entry_safe(peer, tmp_peer, &peers, list) { + l2cap_chan_close(peer->chan, ENOENT); + + list_del_rcu(&peer->list); + kfree_rcu(peer, rcu); + } + spin_unlock(&devices_lock); +} + +struct set_enable { + struct work_struct work; + bool flag; +}; + +static void do_enable_set(struct work_struct *work) +{ + struct set_enable *set_enable = container_of(work, + struct set_enable, work); + + if (!set_enable->flag || enable_6lowpan != set_enable->flag) + /* Disconnect existing connections if 6lowpan is + * disabled + */ + disconnect_all_peers(); + + enable_6lowpan = set_enable->flag; + + mutex_lock(&set_lock); + if (listen_chan) { + l2cap_chan_close(listen_chan, 0); + l2cap_chan_put(listen_chan); + } + + listen_chan = bt_6lowpan_listen(); + mutex_unlock(&set_lock); + + kfree(set_enable); +} + +static int lowpan_enable_set(void *data, u64 val) +{ + struct set_enable *set_enable; + + set_enable = kzalloc(sizeof(*set_enable), GFP_KERNEL); + if (!set_enable) + return -ENOMEM; + + set_enable->flag = !!val; + INIT_WORK(&set_enable->work, do_enable_set); + + schedule_work(&set_enable->work); + + return 0; +} + +static int lowpan_enable_get(void *data, u64 *val) +{ + *val = enable_6lowpan; + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(lowpan_enable_fops, lowpan_enable_get, + lowpan_enable_set, "%llu\n"); + +static ssize_t lowpan_control_write(struct file *fp, + const char __user *user_buffer, + size_t count, + loff_t *position) +{ + char buf[32]; + size_t buf_size = min(count, sizeof(buf) - 1); + int ret; + bdaddr_t addr; + u8 addr_type; + struct l2cap_conn *conn = NULL; + + if (copy_from_user(buf, user_buffer, buf_size)) + return -EFAULT; + + buf[buf_size] = '\0'; + + if (memcmp(buf, "connect ", 8) == 0) { + ret = get_l2cap_conn(&buf[8], &addr, &addr_type, &conn); + if (ret == -EINVAL) + return ret; + + mutex_lock(&set_lock); + if (listen_chan) { + l2cap_chan_close(listen_chan, 0); + l2cap_chan_put(listen_chan); + listen_chan = NULL; + } + mutex_unlock(&set_lock); + + if (conn) { + struct lowpan_peer *peer; + + if (!is_bt_6lowpan(conn->hcon)) + return -EINVAL; + + peer = lookup_peer(conn); + if (peer) { + BT_DBG("6LoWPAN connection already exists"); + return -EALREADY; + } + + BT_DBG("conn %p dst %pMR type %d user %u", conn, + &conn->hcon->dst, conn->hcon->dst_type, + addr_type); + } + + ret = bt_6lowpan_connect(&addr, addr_type); + if (ret < 0) + return ret; + + return count; + } + + if (memcmp(buf, "disconnect ", 11) == 0) { + ret = get_l2cap_conn(&buf[11], &addr, &addr_type, &conn); + if (ret < 0) + return ret; + + ret = bt_6lowpan_disconnect(conn, addr_type); + if (ret < 0) + return ret; + + return count; + } + + return count; +} + +static int lowpan_control_show(struct seq_file *f, void *ptr) +{ + struct lowpan_btle_dev *entry; + struct lowpan_peer *peer; + + spin_lock(&devices_lock); + + list_for_each_entry(entry, &bt_6lowpan_devices, list) { + list_for_each_entry(peer, &entry->peers, list) + seq_printf(f, "%pMR (type %u)\n", + &peer->chan->dst, peer->chan->dst_type); + } + + spin_unlock(&devices_lock); + + return 0; +} + +static int 
lowpan_control_open(struct inode *inode, struct file *file) +{ + return single_open(file, lowpan_control_show, inode->i_private); +} + +static const struct file_operations lowpan_control_fops = { + .open = lowpan_control_open, + .read = seq_read, + .write = lowpan_control_write, + .llseek = seq_lseek, + .release = single_release, +}; + +static void disconnect_devices(void) +{ + struct lowpan_btle_dev *entry, *tmp, *new_dev; + struct list_head devices; + + INIT_LIST_HEAD(&devices); + + /* We make a separate list of devices because the unregister_netdev() + * will call device_event() which will also want to modify the same + * devices list. + */ + + rcu_read_lock(); + + list_for_each_entry_rcu(entry, &bt_6lowpan_devices, list) { + new_dev = kmalloc(sizeof(*new_dev), GFP_ATOMIC); + if (!new_dev) + break; + + new_dev->netdev = entry->netdev; + INIT_LIST_HEAD(&new_dev->list); + + list_add_rcu(&new_dev->list, &devices); + } + + rcu_read_unlock(); + + list_for_each_entry_safe(entry, tmp, &devices, list) { + ifdown(entry->netdev); + BT_DBG("Unregistering netdev %s %p", + entry->netdev->name, entry->netdev); + lowpan_unregister_netdev(entry->netdev); + kfree(entry); + } +} + +static int device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *netdev = netdev_notifier_info_to_dev(ptr); + struct lowpan_btle_dev *entry; + + if (netdev->type != ARPHRD_6LOWPAN) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UNREGISTER: + spin_lock(&devices_lock); + list_for_each_entry(entry, &bt_6lowpan_devices, list) { + if (entry->netdev == netdev) { + BT_DBG("Unregistered netdev %s %p", + netdev->name, netdev); + list_del(&entry->list); + break; + } + } + spin_unlock(&devices_lock); + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block bt_6lowpan_dev_notifier = { + .notifier_call = device_event, +}; + +static int __init bt_6lowpan_init(void) +{ + lowpan_enable_debugfs = debugfs_create_file_unsafe("6lowpan_enable", + 0644, bt_debugfs, + NULL, + &lowpan_enable_fops); + lowpan_control_debugfs = debugfs_create_file("6lowpan_control", 0644, + bt_debugfs, NULL, + &lowpan_control_fops); + + return register_netdevice_notifier(&bt_6lowpan_dev_notifier); +} + +static void __exit bt_6lowpan_exit(void) +{ + debugfs_remove(lowpan_enable_debugfs); + debugfs_remove(lowpan_control_debugfs); + + if (listen_chan) { + l2cap_chan_close(listen_chan, 0); + l2cap_chan_put(listen_chan); + } + + disconnect_devices(); + + unregister_netdevice_notifier(&bt_6lowpan_dev_notifier); +} + +module_init(bt_6lowpan_init); +module_exit(bt_6lowpan_exit); + +MODULE_AUTHOR("Jukka Rissanen "); +MODULE_DESCRIPTION("Bluetooth 6LoWPAN"); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); diff --git a/net/bluetooth/Kconfig b/net/bluetooth/Kconfig new file mode 100644 index 000000000..ae3bdc6df --- /dev/null +++ b/net/bluetooth/Kconfig @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Bluetooth subsystem configuration +# + +menuconfig BT + tristate "Bluetooth subsystem support" + depends on !S390 + depends on RFKILL || !RFKILL + select CRC16 + select CRYPTO + select CRYPTO_SKCIPHER + select CRYPTO_LIB_AES + imply CRYPTO_AES + select CRYPTO_CMAC + select CRYPTO_ECB + select CRYPTO_SHA256 + select CRYPTO_ECDH + help + Bluetooth is low-cost, low-power, short-range wireless technology. + It was designed as a replacement for cables and other short-range + technologies like IrDA. Bluetooth operates in personal area range + that typically extends up to 10 meters. 
More information about + Bluetooth can be found at . + + Linux Bluetooth subsystem consist of several layers: + Bluetooth Core + HCI device and connection manager, scheduler + SCO audio links + L2CAP (Logical Link Control and Adaptation Protocol) + SMP (Security Manager Protocol) on LE (Low Energy) links + ISO isochronous links + HCI Device drivers (Interface to the hardware) + RFCOMM Module (RFCOMM Protocol) + BNEP Module (Bluetooth Network Encapsulation Protocol) + CMTP Module (CAPI Message Transport Protocol) + HIDP Module (Human Interface Device Protocol) + + Say Y here to compile Bluetooth support into the kernel or say M to + compile it as module (bluetooth). + + To use Linux Bluetooth subsystem, you will need several user-space + utilities like hciconfig and bluetoothd. These utilities and updates + to Bluetooth kernel modules are provided in the BlueZ packages. For + more information, see . + +config BT_BREDR + bool "Bluetooth Classic (BR/EDR) features" + depends on BT + default y + help + Bluetooth Classic includes support for Basic Rate (BR) + available with Bluetooth version 1.0b or later and support + for Enhanced Data Rate (EDR) available with Bluetooth + version 2.0 or later. + +source "net/bluetooth/rfcomm/Kconfig" + +source "net/bluetooth/bnep/Kconfig" + +source "net/bluetooth/cmtp/Kconfig" + +source "net/bluetooth/hidp/Kconfig" + +config BT_HS + bool "Bluetooth High Speed (HS) features" + depends on BT_BREDR + help + Bluetooth High Speed includes support for off-loading + Bluetooth connections via 802.11 (wifi) physical layer + available with Bluetooth version 3.0 or later. + +config BT_LE + bool "Bluetooth Low Energy (LE) features" + depends on BT + default y + help + Bluetooth Low Energy includes support low-energy physical + layer available with Bluetooth version 4.0 or later. + +config BT_6LOWPAN + tristate "Bluetooth 6LoWPAN support" + depends on BT_LE && 6LOWPAN + help + IPv6 compression over Bluetooth Low Energy. + +config BT_LEDS + bool "Enable LED triggers" + depends on BT + depends on LEDS_CLASS + select LEDS_TRIGGERS + help + This option selects a few LED triggers for different + Bluetooth events. + +config BT_MSFTEXT + bool "Enable Microsoft extensions" + depends on BT + help + This options enables support for the Microsoft defined HCI + vendor extensions. + +config BT_AOSPEXT + bool "Enable Android Open Source Project extensions" + depends on BT + help + This options enables support for the Android Open Source + Project defined HCI vendor extensions. + +config BT_DEBUGFS + bool "Export Bluetooth internals in debugfs" + depends on BT && DEBUG_FS + default y + help + Provide extensive information about internal Bluetooth states + in debugfs. + +config BT_SELFTEST + bool "Bluetooth self testing support" + depends on BT && DEBUG_KERNEL + help + Run self tests when initializing the Bluetooth subsystem. This + is a developer option and can cause significant delay when booting + the system. + + When the Bluetooth subsystem is built as module, then the test + cases are run first thing at module load time. When the Bluetooth + subsystem is compiled into the kernel image, then the test cases + are run late in the initcall hierarchy. + +config BT_SELFTEST_ECDH + bool "ECDH test cases" + depends on BT_LE && BT_SELFTEST + help + Run test cases for ECDH cryptographic functionality used by the + Bluetooth Low Energy Secure Connections feature. 
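Most of the symbols above (BT_HS, BT_LE, BT_6LOWPAN, the selftest switches) are plain bool or tristate options that the C code later tests with the kernel's IS_ENABLED() macro; a2mp.h further down in this patch guards its declarations with exactly that pattern for CONFIG_BT_HS. A minimal, purely illustrative sketch, with the example_* name invented:

/* Illustration only: how a Kconfig symbol from the menu above is
 * typically consumed in C code. CONFIG_BT_HS comes from "config BT_HS";
 * the function name is made up.
 */
#include <linux/kconfig.h>
#include <linux/types.h>

static inline bool example_bt_hs_built(void)
{
        /* true for BT_HS=y; IS_ENABLED() also covers =m for tristate
         * symbols such as BT or BT_6LOWPAN
         */
        return IS_ENABLED(CONFIG_BT_HS);
}

The #if IS_ENABLED(CONFIG_BT_HS) block near the end of a2mp.h below is the compile-time variant of the same test.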
+ +config BT_SELFTEST_SMP + bool "SMP test cases" + depends on BT_LE && BT_SELFTEST + help + Run test cases for SMP cryptographic functionality, including both + legacy SMP as well as the Secure Connections features. + +config BT_FEATURE_DEBUG + bool "Enable runtime option for debugging statements" + depends on BT && !DYNAMIC_DEBUG + help + This provides an option to enable/disable debugging statements + at runtime via the experimental features interface. + +source "drivers/bluetooth/Kconfig" diff --git a/net/bluetooth/Makefile b/net/bluetooth/Makefile new file mode 100644 index 000000000..0e7b7db42 --- /dev/null +++ b/net/bluetooth/Makefile @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux Bluetooth subsystem. +# + +obj-$(CONFIG_BT) += bluetooth.o +obj-$(CONFIG_BT_RFCOMM) += rfcomm/ +obj-$(CONFIG_BT_BNEP) += bnep/ +obj-$(CONFIG_BT_CMTP) += cmtp/ +obj-$(CONFIG_BT_HIDP) += hidp/ +obj-$(CONFIG_BT_6LOWPAN) += bluetooth_6lowpan.o + +bluetooth_6lowpan-y := 6lowpan.o + +bluetooth-y := af_bluetooth.o hci_core.o hci_conn.o hci_event.o mgmt.o \ + hci_sock.o hci_sysfs.o l2cap_core.o l2cap_sock.o smp.o lib.o \ + ecdh_helper.o hci_request.o mgmt_util.o mgmt_config.o hci_codec.o \ + eir.o hci_sync.o + +bluetooth-$(CONFIG_BT_BREDR) += sco.o +bluetooth-$(CONFIG_BT_LE) += iso.o +bluetooth-$(CONFIG_BT_HS) += a2mp.o amp.o +bluetooth-$(CONFIG_BT_LEDS) += leds.o +bluetooth-$(CONFIG_BT_MSFTEXT) += msft.o +bluetooth-$(CONFIG_BT_AOSPEXT) += aosp.o +bluetooth-$(CONFIG_BT_DEBUGFS) += hci_debugfs.o +bluetooth-$(CONFIG_BT_SELFTEST) += selftest.o diff --git a/net/bluetooth/a2mp.c b/net/bluetooth/a2mp.c new file mode 100644 index 000000000..e7adb8a98 --- /dev/null +++ b/net/bluetooth/a2mp.c @@ -0,0 +1,1054 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + Copyright (c) 2010,2011 Code Aurora Forum. All rights reserved. + Copyright (c) 2011,2012 Intel Corp. 
+ +*/ + +#include +#include +#include + +#include "hci_request.h" +#include "a2mp.h" +#include "amp.h" + +#define A2MP_FEAT_EXT 0x8000 + +/* Global AMP Manager list */ +static LIST_HEAD(amp_mgr_list); +static DEFINE_MUTEX(amp_mgr_list_lock); + +/* A2MP build & send command helper functions */ +static struct a2mp_cmd *__a2mp_build(u8 code, u8 ident, u16 len, void *data) +{ + struct a2mp_cmd *cmd; + int plen; + + plen = sizeof(*cmd) + len; + cmd = kzalloc(plen, GFP_KERNEL); + if (!cmd) + return NULL; + + cmd->code = code; + cmd->ident = ident; + cmd->len = cpu_to_le16(len); + + memcpy(cmd->data, data, len); + + return cmd; +} + +static void a2mp_send(struct amp_mgr *mgr, u8 code, u8 ident, u16 len, void *data) +{ + struct l2cap_chan *chan = mgr->a2mp_chan; + struct a2mp_cmd *cmd; + u16 total_len = len + sizeof(*cmd); + struct kvec iv; + struct msghdr msg; + + cmd = __a2mp_build(code, ident, len, data); + if (!cmd) + return; + + iv.iov_base = cmd; + iv.iov_len = total_len; + + memset(&msg, 0, sizeof(msg)); + + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iv, 1, total_len); + + l2cap_chan_send(chan, &msg, total_len); + + kfree(cmd); +} + +static u8 __next_ident(struct amp_mgr *mgr) +{ + if (++mgr->ident == 0) + mgr->ident = 1; + + return mgr->ident; +} + +static struct amp_mgr *amp_mgr_lookup_by_state(u8 state) +{ + struct amp_mgr *mgr; + + mutex_lock(&amp_mgr_list_lock); + list_for_each_entry(mgr, &amp_mgr_list, list) { + if (test_and_clear_bit(state, &mgr->state)) { + amp_mgr_get(mgr); + mutex_unlock(&amp_mgr_list_lock); + return mgr; + } + } + mutex_unlock(&amp_mgr_list_lock); + + return NULL; +} + +/* hci_dev_list shall be locked */ +static void __a2mp_add_cl(struct amp_mgr *mgr, struct a2mp_cl *cl) +{ + struct hci_dev *hdev; + int i = 1; + + cl[0].id = AMP_ID_BREDR; + cl[0].type = AMP_TYPE_BREDR; + cl[0].status = AMP_STATUS_BLUETOOTH_ONLY; + + list_for_each_entry(hdev, &hci_dev_list, list) { + if (hdev->dev_type == HCI_AMP) { + cl[i].id = hdev->id; + cl[i].type = hdev->amp_type; + if (test_bit(HCI_UP, &hdev->flags)) + cl[i].status = hdev->amp_status; + else + cl[i].status = AMP_STATUS_POWERED_DOWN; + i++; + } + } +} + +/* Processing A2MP messages */ +static int a2mp_command_rej(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_cmd_rej *rej = (void *) skb->data; + + if (le16_to_cpu(hdr->len) < sizeof(*rej)) + return -EINVAL; + + BT_DBG("ident %u reason %d", hdr->ident, le16_to_cpu(rej->reason)); + + skb_pull(skb, sizeof(*rej)); + + return 0; +} + +static int a2mp_discover_req(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_discov_req *req = (void *) skb->data; + u16 len = le16_to_cpu(hdr->len); + struct a2mp_discov_rsp *rsp; + u16 ext_feat; + u8 num_ctrl; + struct hci_dev *hdev; + + if (len < sizeof(*req)) + return -EINVAL; + + skb_pull(skb, sizeof(*req)); + + ext_feat = le16_to_cpu(req->ext_feat); + + BT_DBG("mtu %d efm 0x%4.4x", le16_to_cpu(req->mtu), ext_feat); + + /* check that packet is not broken for now */ + while (ext_feat & A2MP_FEAT_EXT) { + if (len < sizeof(ext_feat)) + return -EINVAL; + + ext_feat = get_unaligned_le16(skb->data); + BT_DBG("efm 0x%4.4x", ext_feat); + len -= sizeof(ext_feat); + skb_pull(skb, sizeof(ext_feat)); + } + + read_lock(&hci_dev_list_lock); + + /* at minimum the BR/EDR needs to be listed */ + num_ctrl = 1; + + list_for_each_entry(hdev, &hci_dev_list, list) { + if (hdev->dev_type == HCI_AMP) + num_ctrl++; + } + + len = struct_size(rsp, cl, num_ctrl); + rsp = kmalloc(len, GFP_ATOMIC); + if (!rsp) { +
read_unlock(&hci_dev_list_lock); + return -ENOMEM; + } + + rsp->mtu = cpu_to_le16(L2CAP_A2MP_DEFAULT_MTU); + rsp->ext_feat = 0; + + __a2mp_add_cl(mgr, rsp->cl); + + read_unlock(&hci_dev_list_lock); + + a2mp_send(mgr, A2MP_DISCOVER_RSP, hdr->ident, len, rsp); + + kfree(rsp); + return 0; +} + +static int a2mp_discover_rsp(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_discov_rsp *rsp = (void *) skb->data; + u16 len = le16_to_cpu(hdr->len); + struct a2mp_cl *cl; + u16 ext_feat; + bool found = false; + + if (len < sizeof(*rsp)) + return -EINVAL; + + len -= sizeof(*rsp); + skb_pull(skb, sizeof(*rsp)); + + ext_feat = le16_to_cpu(rsp->ext_feat); + + BT_DBG("mtu %d efm 0x%4.4x", le16_to_cpu(rsp->mtu), ext_feat); + + /* check that packet is not broken for now */ + while (ext_feat & A2MP_FEAT_EXT) { + if (len < sizeof(ext_feat)) + return -EINVAL; + + ext_feat = get_unaligned_le16(skb->data); + BT_DBG("efm 0x%4.4x", ext_feat); + len -= sizeof(ext_feat); + skb_pull(skb, sizeof(ext_feat)); + } + + cl = (void *) skb->data; + while (len >= sizeof(*cl)) { + BT_DBG("Remote AMP id %u type %u status %u", cl->id, cl->type, + cl->status); + + if (cl->id != AMP_ID_BREDR && cl->type != AMP_TYPE_BREDR) { + struct a2mp_info_req req; + + found = true; + + memset(&req, 0, sizeof(req)); + + req.id = cl->id; + a2mp_send(mgr, A2MP_GETINFO_REQ, __next_ident(mgr), + sizeof(req), &req); + } + + len -= sizeof(*cl); + cl = skb_pull(skb, sizeof(*cl)); + } + + /* Fall back to L2CAP init sequence */ + if (!found) { + struct l2cap_conn *conn = mgr->l2cap_conn; + struct l2cap_chan *chan; + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + + BT_DBG("chan %p state %s", chan, + state_to_string(chan->state)); + + if (chan->scid == L2CAP_CID_A2MP) + continue; + + l2cap_chan_lock(chan); + + if (chan->state == BT_CONNECT) + l2cap_send_conn_req(chan); + + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); + } + + return 0; +} + +static int a2mp_change_notify(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_cl *cl = (void *) skb->data; + + while (skb->len >= sizeof(*cl)) { + BT_DBG("Controller id %u type %u status %u", cl->id, cl->type, + cl->status); + cl = skb_pull(skb, sizeof(*cl)); + } + + /* TODO send A2MP_CHANGE_RSP */ + + return 0; +} + +static void read_local_amp_info_complete(struct hci_dev *hdev, u8 status, + u16 opcode) +{ + BT_DBG("%s status 0x%2.2x", hdev->name, status); + + a2mp_send_getinfo_rsp(hdev); +} + +static int a2mp_getinfo_req(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_info_req *req = (void *) skb->data; + struct hci_dev *hdev; + struct hci_request hreq; + int err = 0; + + if (le16_to_cpu(hdr->len) < sizeof(*req)) + return -EINVAL; + + BT_DBG("id %u", req->id); + + hdev = hci_dev_get(req->id); + if (!hdev || hdev->dev_type != HCI_AMP) { + struct a2mp_info_rsp rsp; + + memset(&rsp, 0, sizeof(rsp)); + + rsp.id = req->id; + rsp.status = A2MP_STATUS_INVALID_CTRL_ID; + + a2mp_send(mgr, A2MP_GETINFO_RSP, hdr->ident, sizeof(rsp), + &rsp); + + goto done; + } + + set_bit(READ_LOC_AMP_INFO, &mgr->state); + hci_req_init(&hreq, hdev); + hci_req_add(&hreq, HCI_OP_READ_LOCAL_AMP_INFO, 0, NULL); + err = hci_req_run(&hreq, read_local_amp_info_complete); + if (err < 0) + a2mp_send_getinfo_rsp(hdev); + +done: + if (hdev) + hci_dev_put(hdev); + + skb_pull(skb, sizeof(*req)); + return 0; +} + +static int a2mp_getinfo_rsp(struct amp_mgr *mgr, struct sk_buff *skb, + struct 
a2mp_cmd *hdr) +{ + struct a2mp_info_rsp *rsp = (struct a2mp_info_rsp *) skb->data; + struct a2mp_amp_assoc_req req; + struct amp_ctrl *ctrl; + + if (le16_to_cpu(hdr->len) < sizeof(*rsp)) + return -EINVAL; + + BT_DBG("id %u status 0x%2.2x", rsp->id, rsp->status); + + if (rsp->status) + return -EINVAL; + + ctrl = amp_ctrl_add(mgr, rsp->id); + if (!ctrl) + return -ENOMEM; + + memset(&req, 0, sizeof(req)); + + req.id = rsp->id; + a2mp_send(mgr, A2MP_GETAMPASSOC_REQ, __next_ident(mgr), sizeof(req), + &req); + + skb_pull(skb, sizeof(*rsp)); + return 0; +} + +static int a2mp_getampassoc_req(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_amp_assoc_req *req = (void *) skb->data; + struct hci_dev *hdev; + struct amp_mgr *tmp; + + if (le16_to_cpu(hdr->len) < sizeof(*req)) + return -EINVAL; + + BT_DBG("id %u", req->id); + + /* Make sure that other request is not processed */ + tmp = amp_mgr_lookup_by_state(READ_LOC_AMP_ASSOC); + + hdev = hci_dev_get(req->id); + if (!hdev || hdev->amp_type == AMP_TYPE_BREDR || tmp) { + struct a2mp_amp_assoc_rsp rsp; + + memset(&rsp, 0, sizeof(rsp)); + rsp.id = req->id; + + if (tmp) { + rsp.status = A2MP_STATUS_COLLISION_OCCURED; + amp_mgr_put(tmp); + } else { + rsp.status = A2MP_STATUS_INVALID_CTRL_ID; + } + + a2mp_send(mgr, A2MP_GETAMPASSOC_RSP, hdr->ident, sizeof(rsp), + &rsp); + + goto done; + } + + amp_read_loc_assoc(hdev, mgr); + +done: + if (hdev) + hci_dev_put(hdev); + + skb_pull(skb, sizeof(*req)); + return 0; +} + +static int a2mp_getampassoc_rsp(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_amp_assoc_rsp *rsp = (void *) skb->data; + u16 len = le16_to_cpu(hdr->len); + struct hci_dev *hdev; + struct amp_ctrl *ctrl; + struct hci_conn *hcon; + size_t assoc_len; + + if (len < sizeof(*rsp)) + return -EINVAL; + + assoc_len = len - sizeof(*rsp); + + BT_DBG("id %u status 0x%2.2x assoc len %zu", rsp->id, rsp->status, + assoc_len); + + if (rsp->status) + return -EINVAL; + + /* Save remote ASSOC data */ + ctrl = amp_ctrl_lookup(mgr, rsp->id); + if (ctrl) { + u8 *assoc; + + assoc = kmemdup(rsp->amp_assoc, assoc_len, GFP_KERNEL); + if (!assoc) { + amp_ctrl_put(ctrl); + return -ENOMEM; + } + + ctrl->assoc = assoc; + ctrl->assoc_len = assoc_len; + ctrl->assoc_rem_len = assoc_len; + ctrl->assoc_len_so_far = 0; + + amp_ctrl_put(ctrl); + } + + /* Create Phys Link */ + hdev = hci_dev_get(rsp->id); + if (!hdev) + return -EINVAL; + + hcon = phylink_add(hdev, mgr, rsp->id, true); + if (!hcon) + goto done; + + BT_DBG("Created hcon %p: loc:%u -> rem:%u", hcon, hdev->id, rsp->id); + + mgr->bredr_chan->remote_amp_id = rsp->id; + + amp_create_phylink(hdev, mgr, hcon); + +done: + hci_dev_put(hdev); + skb_pull(skb, len); + return 0; +} + +static int a2mp_createphyslink_req(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_physlink_req *req = (void *) skb->data; + struct a2mp_physlink_rsp rsp; + struct hci_dev *hdev; + struct hci_conn *hcon; + struct amp_ctrl *ctrl; + + if (le16_to_cpu(hdr->len) < sizeof(*req)) + return -EINVAL; + + BT_DBG("local_id %u, remote_id %u", req->local_id, req->remote_id); + + memset(&rsp, 0, sizeof(rsp)); + + rsp.local_id = req->remote_id; + rsp.remote_id = req->local_id; + + hdev = hci_dev_get(req->remote_id); + if (!hdev || hdev->amp_type == AMP_TYPE_BREDR) { + rsp.status = A2MP_STATUS_INVALID_CTRL_ID; + goto send_rsp; + } + + ctrl = amp_ctrl_lookup(mgr, rsp.remote_id); + if (!ctrl) { + ctrl = amp_ctrl_add(mgr, rsp.remote_id); + if (ctrl) { + 
amp_ctrl_get(ctrl); + } else { + rsp.status = A2MP_STATUS_UNABLE_START_LINK_CREATION; + goto send_rsp; + } + } + + if (ctrl) { + size_t assoc_len = le16_to_cpu(hdr->len) - sizeof(*req); + u8 *assoc; + + assoc = kmemdup(req->amp_assoc, assoc_len, GFP_KERNEL); + if (!assoc) { + amp_ctrl_put(ctrl); + hci_dev_put(hdev); + return -ENOMEM; + } + + ctrl->assoc = assoc; + ctrl->assoc_len = assoc_len; + ctrl->assoc_rem_len = assoc_len; + ctrl->assoc_len_so_far = 0; + + amp_ctrl_put(ctrl); + } + + hcon = phylink_add(hdev, mgr, req->local_id, false); + if (hcon) { + amp_accept_phylink(hdev, mgr, hcon); + rsp.status = A2MP_STATUS_SUCCESS; + } else { + rsp.status = A2MP_STATUS_UNABLE_START_LINK_CREATION; + } + +send_rsp: + if (hdev) + hci_dev_put(hdev); + + /* Reply error now and success after HCI Write Remote AMP Assoc + command complete with success status + */ + if (rsp.status != A2MP_STATUS_SUCCESS) { + a2mp_send(mgr, A2MP_CREATEPHYSLINK_RSP, hdr->ident, + sizeof(rsp), &rsp); + } else { + set_bit(WRITE_REMOTE_AMP_ASSOC, &mgr->state); + mgr->ident = hdr->ident; + } + + skb_pull(skb, le16_to_cpu(hdr->len)); + return 0; +} + +static int a2mp_discphyslink_req(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + struct a2mp_physlink_req *req = (void *) skb->data; + struct a2mp_physlink_rsp rsp; + struct hci_dev *hdev; + struct hci_conn *hcon; + + if (le16_to_cpu(hdr->len) < sizeof(*req)) + return -EINVAL; + + BT_DBG("local_id %u remote_id %u", req->local_id, req->remote_id); + + memset(&rsp, 0, sizeof(rsp)); + + rsp.local_id = req->remote_id; + rsp.remote_id = req->local_id; + rsp.status = A2MP_STATUS_SUCCESS; + + hdev = hci_dev_get(req->remote_id); + if (!hdev) { + rsp.status = A2MP_STATUS_INVALID_CTRL_ID; + goto send_rsp; + } + + hcon = hci_conn_hash_lookup_ba(hdev, AMP_LINK, + &mgr->l2cap_conn->hcon->dst); + if (!hcon) { + bt_dev_err(hdev, "no phys link exist"); + rsp.status = A2MP_STATUS_NO_PHYSICAL_LINK_EXISTS; + goto clean; + } + + /* TODO Disconnect Phys Link here */ + +clean: + hci_dev_put(hdev); + +send_rsp: + a2mp_send(mgr, A2MP_DISCONNPHYSLINK_RSP, hdr->ident, sizeof(rsp), &rsp); + + skb_pull(skb, sizeof(*req)); + return 0; +} + +static inline int a2mp_cmd_rsp(struct amp_mgr *mgr, struct sk_buff *skb, + struct a2mp_cmd *hdr) +{ + BT_DBG("ident %u code 0x%2.2x", hdr->ident, hdr->code); + + skb_pull(skb, le16_to_cpu(hdr->len)); + return 0; +} + +/* Handle A2MP signalling */ +static int a2mp_chan_recv_cb(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct a2mp_cmd *hdr; + struct amp_mgr *mgr = chan->data; + int err = 0; + + amp_mgr_get(mgr); + + while (skb->len >= sizeof(*hdr)) { + u16 len; + + hdr = (void *) skb->data; + len = le16_to_cpu(hdr->len); + + BT_DBG("code 0x%2.2x id %u len %u", hdr->code, hdr->ident, len); + + skb_pull(skb, sizeof(*hdr)); + + if (len > skb->len || !hdr->ident) { + err = -EINVAL; + break; + } + + mgr->ident = hdr->ident; + + switch (hdr->code) { + case A2MP_COMMAND_REJ: + a2mp_command_rej(mgr, skb, hdr); + break; + + case A2MP_DISCOVER_REQ: + err = a2mp_discover_req(mgr, skb, hdr); + break; + + case A2MP_CHANGE_NOTIFY: + err = a2mp_change_notify(mgr, skb, hdr); + break; + + case A2MP_GETINFO_REQ: + err = a2mp_getinfo_req(mgr, skb, hdr); + break; + + case A2MP_GETAMPASSOC_REQ: + err = a2mp_getampassoc_req(mgr, skb, hdr); + break; + + case A2MP_CREATEPHYSLINK_REQ: + err = a2mp_createphyslink_req(mgr, skb, hdr); + break; + + case A2MP_DISCONNPHYSLINK_REQ: + err = a2mp_discphyslink_req(mgr, skb, hdr); + break; + + case A2MP_DISCOVER_RSP: + err = 
a2mp_discover_rsp(mgr, skb, hdr); + break; + + case A2MP_GETINFO_RSP: + err = a2mp_getinfo_rsp(mgr, skb, hdr); + break; + + case A2MP_GETAMPASSOC_RSP: + err = a2mp_getampassoc_rsp(mgr, skb, hdr); + break; + + case A2MP_CHANGE_RSP: + case A2MP_CREATEPHYSLINK_RSP: + case A2MP_DISCONNPHYSLINK_RSP: + err = a2mp_cmd_rsp(mgr, skb, hdr); + break; + + default: + BT_ERR("Unknown A2MP sig cmd 0x%2.2x", hdr->code); + err = -EINVAL; + break; + } + } + + if (err) { + struct a2mp_cmd_rej rej; + + memset(&rej, 0, sizeof(rej)); + + rej.reason = cpu_to_le16(0); + hdr = (void *) skb->data; + + BT_DBG("Send A2MP Rej: cmd 0x%2.2x err %d", hdr->code, err); + + a2mp_send(mgr, A2MP_COMMAND_REJ, hdr->ident, sizeof(rej), + &rej); + } + + /* Always free skb and return success error code to prevent + from sending L2CAP Disconnect over A2MP channel */ + kfree_skb(skb); + + amp_mgr_put(mgr); + + return 0; +} + +static void a2mp_chan_close_cb(struct l2cap_chan *chan) +{ + l2cap_chan_put(chan); +} + +static void a2mp_chan_state_change_cb(struct l2cap_chan *chan, int state, + int err) +{ + struct amp_mgr *mgr = chan->data; + + if (!mgr) + return; + + BT_DBG("chan %p state %s", chan, state_to_string(state)); + + chan->state = state; + + switch (state) { + case BT_CLOSED: + if (mgr) + amp_mgr_put(mgr); + break; + } +} + +static struct sk_buff *a2mp_chan_alloc_skb_cb(struct l2cap_chan *chan, + unsigned long hdr_len, + unsigned long len, int nb) +{ + struct sk_buff *skb; + + skb = bt_skb_alloc(hdr_len + len, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + return skb; +} + +static const struct l2cap_ops a2mp_chan_ops = { + .name = "L2CAP A2MP channel", + .recv = a2mp_chan_recv_cb, + .close = a2mp_chan_close_cb, + .state_change = a2mp_chan_state_change_cb, + .alloc_skb = a2mp_chan_alloc_skb_cb, + + /* Not implemented for A2MP */ + .new_connection = l2cap_chan_no_new_connection, + .teardown = l2cap_chan_no_teardown, + .ready = l2cap_chan_no_ready, + .defer = l2cap_chan_no_defer, + .resume = l2cap_chan_no_resume, + .set_shutdown = l2cap_chan_no_set_shutdown, + .get_sndtimeo = l2cap_chan_no_get_sndtimeo, +}; + +static struct l2cap_chan *a2mp_chan_open(struct l2cap_conn *conn, bool locked) +{ + struct l2cap_chan *chan; + int err; + + chan = l2cap_chan_create(); + if (!chan) + return NULL; + + BT_DBG("chan %p", chan); + + chan->chan_type = L2CAP_CHAN_FIXED; + chan->scid = L2CAP_CID_A2MP; + chan->dcid = L2CAP_CID_A2MP; + chan->omtu = L2CAP_A2MP_DEFAULT_MTU; + chan->imtu = L2CAP_A2MP_DEFAULT_MTU; + chan->flush_to = L2CAP_DEFAULT_FLUSH_TO; + + chan->ops = &a2mp_chan_ops; + + l2cap_chan_set_defaults(chan); + chan->remote_max_tx = chan->max_tx; + chan->remote_tx_win = chan->tx_win; + + chan->retrans_timeout = L2CAP_DEFAULT_RETRANS_TO; + chan->monitor_timeout = L2CAP_DEFAULT_MONITOR_TO; + + skb_queue_head_init(&chan->tx_q); + + chan->mode = L2CAP_MODE_ERTM; + + err = l2cap_ertm_init(chan); + if (err < 0) { + l2cap_chan_del(chan, 0); + return NULL; + } + + chan->conf_state = 0; + + if (locked) + __l2cap_chan_add(conn, chan); + else + l2cap_chan_add(conn, chan); + + chan->remote_mps = chan->omtu; + chan->mps = chan->omtu; + + chan->state = BT_CONNECTED; + + return chan; +} + +/* AMP Manager functions */ +struct amp_mgr *amp_mgr_get(struct amp_mgr *mgr) +{ + BT_DBG("mgr %p orig refcnt %d", mgr, kref_read(&mgr->kref)); + + kref_get(&mgr->kref); + + return mgr; +} + +static void amp_mgr_destroy(struct kref *kref) +{ + struct amp_mgr *mgr = container_of(kref, struct amp_mgr, kref); + + BT_DBG("mgr %p", mgr); + + 
mutex_lock(&amp_mgr_list_lock); + list_del(&mgr->list); + mutex_unlock(&amp_mgr_list_lock); + + amp_ctrl_list_flush(mgr); + kfree(mgr); +} + +int amp_mgr_put(struct amp_mgr *mgr) +{ + BT_DBG("mgr %p orig refcnt %d", mgr, kref_read(&mgr->kref)); + + return kref_put(&mgr->kref, &amp_mgr_destroy); +} + +static struct amp_mgr *amp_mgr_create(struct l2cap_conn *conn, bool locked) +{ + struct amp_mgr *mgr; + struct l2cap_chan *chan; + + mgr = kzalloc(sizeof(*mgr), GFP_KERNEL); + if (!mgr) + return NULL; + + BT_DBG("conn %p mgr %p", conn, mgr); + + mgr->l2cap_conn = conn; + + chan = a2mp_chan_open(conn, locked); + if (!chan) { + kfree(mgr); + return NULL; + } + + mgr->a2mp_chan = chan; + chan->data = mgr; + + conn->hcon->amp_mgr = mgr; + + kref_init(&mgr->kref); + + /* Remote AMP ctrl list initialization */ + INIT_LIST_HEAD(&mgr->amp_ctrls); + mutex_init(&mgr->amp_ctrls_lock); + + mutex_lock(&amp_mgr_list_lock); + list_add(&mgr->list, &amp_mgr_list); + mutex_unlock(&amp_mgr_list_lock); + + return mgr; +} + +struct l2cap_chan *a2mp_channel_create(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + struct amp_mgr *mgr; + + if (conn->hcon->type != ACL_LINK) + return NULL; + + mgr = amp_mgr_create(conn, false); + if (!mgr) { + BT_ERR("Could not create AMP manager"); + return NULL; + } + + BT_DBG("mgr: %p chan %p", mgr, mgr->a2mp_chan); + + return mgr->a2mp_chan; +} + +void a2mp_send_getinfo_rsp(struct hci_dev *hdev) +{ + struct amp_mgr *mgr; + struct a2mp_info_rsp rsp; + + mgr = amp_mgr_lookup_by_state(READ_LOC_AMP_INFO); + if (!mgr) + return; + + BT_DBG("%s mgr %p", hdev->name, mgr); + + memset(&rsp, 0, sizeof(rsp)); + + rsp.id = hdev->id; + rsp.status = A2MP_STATUS_INVALID_CTRL_ID; + + if (hdev->amp_type != AMP_TYPE_BREDR) { + rsp.status = 0; + rsp.total_bw = cpu_to_le32(hdev->amp_total_bw); + rsp.max_bw = cpu_to_le32(hdev->amp_max_bw); + rsp.min_latency = cpu_to_le32(hdev->amp_min_latency); + rsp.pal_cap = cpu_to_le16(hdev->amp_pal_cap); + rsp.assoc_size = cpu_to_le16(hdev->amp_assoc_size); + } + + a2mp_send(mgr, A2MP_GETINFO_RSP, mgr->ident, sizeof(rsp), &rsp); + amp_mgr_put(mgr); +} + +void a2mp_send_getampassoc_rsp(struct hci_dev *hdev, u8 status) +{ + struct amp_mgr *mgr; + struct amp_assoc *loc_assoc = &hdev->loc_assoc; + struct a2mp_amp_assoc_rsp *rsp; + size_t len; + + mgr = amp_mgr_lookup_by_state(READ_LOC_AMP_ASSOC); + if (!mgr) + return; + + BT_DBG("%s mgr %p", hdev->name, mgr); + + len = sizeof(struct a2mp_amp_assoc_rsp) + loc_assoc->len; + rsp = kzalloc(len, GFP_KERNEL); + if (!rsp) { + amp_mgr_put(mgr); + return; + } + + rsp->id = hdev->id; + + if (status) { + rsp->status = A2MP_STATUS_INVALID_CTRL_ID; + } else { + rsp->status = A2MP_STATUS_SUCCESS; + memcpy(rsp->amp_assoc, loc_assoc->data, loc_assoc->len); + } + + a2mp_send(mgr, A2MP_GETAMPASSOC_RSP, mgr->ident, len, rsp); + amp_mgr_put(mgr); + kfree(rsp); +} + +void a2mp_send_create_phy_link_req(struct hci_dev *hdev, u8 status) +{ + struct amp_mgr *mgr; + struct amp_assoc *loc_assoc = &hdev->loc_assoc; + struct a2mp_physlink_req *req; + struct l2cap_chan *bredr_chan; + size_t len; + + mgr = amp_mgr_lookup_by_state(READ_LOC_AMP_ASSOC_FINAL); + if (!mgr) + return; + + len = sizeof(*req) + loc_assoc->len; + + BT_DBG("%s mgr %p assoc_len %zu", hdev->name, mgr, len); + + req = kzalloc(len, GFP_KERNEL); + if (!req) { + amp_mgr_put(mgr); + return; + } + + bredr_chan = mgr->bredr_chan; + if (!bredr_chan) + goto clean; + + req->local_id = hdev->id; + req->remote_id = bredr_chan->remote_amp_id; + memcpy(req->amp_assoc, loc_assoc->data, loc_assoc->len); + 
a2mp_send(mgr, A2MP_CREATEPHYSLINK_REQ, __next_ident(mgr), len, req); + +clean: + amp_mgr_put(mgr); + kfree(req); +} + +void a2mp_send_create_phy_link_rsp(struct hci_dev *hdev, u8 status) +{ + struct amp_mgr *mgr; + struct a2mp_physlink_rsp rsp; + struct hci_conn *hs_hcon; + + mgr = amp_mgr_lookup_by_state(WRITE_REMOTE_AMP_ASSOC); + if (!mgr) + return; + + memset(&rsp, 0, sizeof(rsp)); + + hs_hcon = hci_conn_hash_lookup_state(hdev, AMP_LINK, BT_CONNECT); + if (!hs_hcon) { + rsp.status = A2MP_STATUS_UNABLE_START_LINK_CREATION; + } else { + rsp.remote_id = hs_hcon->remote_id; + rsp.status = A2MP_STATUS_SUCCESS; + } + + BT_DBG("%s mgr %p hs_hcon %p status %u", hdev->name, mgr, hs_hcon, + status); + + rsp.local_id = hdev->id; + a2mp_send(mgr, A2MP_CREATEPHYSLINK_RSP, mgr->ident, sizeof(rsp), &rsp); + amp_mgr_put(mgr); +} + +void a2mp_discover_amp(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct amp_mgr *mgr = conn->hcon->amp_mgr; + struct a2mp_discov_req req; + + BT_DBG("chan %p conn %p mgr %p", chan, conn, mgr); + + if (!mgr) { + mgr = amp_mgr_create(conn, true); + if (!mgr) + return; + } + + mgr->bredr_chan = chan; + + memset(&req, 0, sizeof(req)); + + req.mtu = cpu_to_le16(L2CAP_A2MP_DEFAULT_MTU); + req.ext_feat = 0; + a2mp_send(mgr, A2MP_DISCOVER_REQ, 1, sizeof(req), &req); +} diff --git a/net/bluetooth/a2mp.h b/net/bluetooth/a2mp.h new file mode 100644 index 000000000..2fd253a61 --- /dev/null +++ b/net/bluetooth/a2mp.h @@ -0,0 +1,154 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + Copyright (c) 2010,2011 Code Aurora Forum. All rights reserved. + Copyright (c) 2011,2012 Intel Corp. + +*/ + +#ifndef __A2MP_H +#define __A2MP_H + +#include + +enum amp_mgr_state { + READ_LOC_AMP_INFO, + READ_LOC_AMP_ASSOC, + READ_LOC_AMP_ASSOC_FINAL, + WRITE_REMOTE_AMP_ASSOC, +}; + +struct amp_mgr { + struct list_head list; + struct l2cap_conn *l2cap_conn; + struct l2cap_chan *a2mp_chan; + struct l2cap_chan *bredr_chan; + struct kref kref; + __u8 ident; + __u8 handle; + unsigned long state; + unsigned long flags; + + struct list_head amp_ctrls; + struct mutex amp_ctrls_lock; +}; + +struct a2mp_cmd { + __u8 code; + __u8 ident; + __le16 len; + __u8 data[]; +} __packed; + +/* A2MP command codes */ +#define A2MP_COMMAND_REJ 0x01 +struct a2mp_cmd_rej { + __le16 reason; + __u8 data[]; +} __packed; + +#define A2MP_DISCOVER_REQ 0x02 +struct a2mp_discov_req { + __le16 mtu; + __le16 ext_feat; +} __packed; + +struct a2mp_cl { + __u8 id; + __u8 type; + __u8 status; +} __packed; + +#define A2MP_DISCOVER_RSP 0x03 +struct a2mp_discov_rsp { + __le16 mtu; + __le16 ext_feat; + struct a2mp_cl cl[]; +} __packed; + +#define A2MP_CHANGE_NOTIFY 0x04 +#define A2MP_CHANGE_RSP 0x05 + +#define A2MP_GETINFO_REQ 0x06 +struct a2mp_info_req { + __u8 id; +} __packed; + +#define A2MP_GETINFO_RSP 0x07 +struct a2mp_info_rsp { + __u8 id; + __u8 status; + __le32 total_bw; + __le32 max_bw; + __le32 min_latency; + __le16 pal_cap; + __le16 assoc_size; +} __packed; + +#define A2MP_GETAMPASSOC_REQ 0x08 +struct a2mp_amp_assoc_req { + __u8 id; +} __packed; + +#define A2MP_GETAMPASSOC_RSP 0x09 +struct a2mp_amp_assoc_rsp { + __u8 id; + __u8 status; + __u8 amp_assoc[]; +} __packed; + +#define A2MP_CREATEPHYSLINK_REQ 0x0A +#define A2MP_DISCONNPHYSLINK_REQ 0x0C +struct a2mp_physlink_req { + __u8 local_id; + __u8 remote_id; + __u8 amp_assoc[]; +} __packed; + +#define A2MP_CREATEPHYSLINK_RSP 0x0B +#define A2MP_DISCONNPHYSLINK_RSP 0x0D +struct a2mp_physlink_rsp { + __u8 local_id; + __u8 remote_id; + __u8 status; +} __packed; + 
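Every request/response structure above travels inside the generic struct a2mp_cmd header; __a2mp_build() in a2mp.c earlier in this patch fills that header and appends the payload. A hypothetical helper showing the same wire layout for a GetInfo request, with the example_ name invented and the usual kernel headers assumed:

/* Illustration only, mirroring __a2mp_build() in a2mp.c: wrap an A2MP
 * payload in struct a2mp_cmd. The example_ name is made up; sending the
 * PDU over the L2CAP channel (see a2mp_send()) is omitted.
 */
#include <linux/slab.h>
#include <linux/string.h>

static struct a2mp_cmd *example_build_getinfo_req(u8 ident, u8 ctrl_id)
{
        struct a2mp_info_req req = { .id = ctrl_id };
        struct a2mp_cmd *cmd;

        cmd = kzalloc(sizeof(*cmd) + sizeof(req), GFP_KERNEL);
        if (!cmd)
                return NULL;

        cmd->code = A2MP_GETINFO_REQ;           /* 0x06, defined above */
        cmd->ident = ident;                     /* must be non-zero, see a2mp_chan_recv_cb() */
        cmd->len = cpu_to_le16(sizeof(req));    /* payload length only */
        memcpy(cmd->data, &req, sizeof(req));

        return cmd;     /* caller frees it with kfree() after sending */
}

a2mp_send() in a2mp.c then pushes exactly this kind of buffer through the A2MP L2CAP channel.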
+/* A2MP response status */ +#define A2MP_STATUS_SUCCESS 0x00 +#define A2MP_STATUS_INVALID_CTRL_ID 0x01 +#define A2MP_STATUS_UNABLE_START_LINK_CREATION 0x02 +#define A2MP_STATUS_NO_PHYSICAL_LINK_EXISTS 0x02 +#define A2MP_STATUS_COLLISION_OCCURED 0x03 +#define A2MP_STATUS_DISCONN_REQ_RECVD 0x04 +#define A2MP_STATUS_PHYS_LINK_EXISTS 0x05 +#define A2MP_STATUS_SECURITY_VIOLATION 0x06 + +struct amp_mgr *amp_mgr_get(struct amp_mgr *mgr); + +#if IS_ENABLED(CONFIG_BT_HS) +int amp_mgr_put(struct amp_mgr *mgr); +struct l2cap_chan *a2mp_channel_create(struct l2cap_conn *conn, + struct sk_buff *skb); +void a2mp_discover_amp(struct l2cap_chan *chan); +#else +static inline int amp_mgr_put(struct amp_mgr *mgr) +{ + return 0; +} + +static inline struct l2cap_chan *a2mp_channel_create(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + return NULL; +} + +static inline void a2mp_discover_amp(struct l2cap_chan *chan) +{ +} +#endif + +void a2mp_send_getinfo_rsp(struct hci_dev *hdev); +void a2mp_send_getampassoc_rsp(struct hci_dev *hdev, u8 status); +void a2mp_send_create_phy_link_req(struct hci_dev *hdev, u8 status); +void a2mp_send_create_phy_link_rsp(struct hci_dev *hdev, u8 status); + +#endif /* __A2MP_H */ diff --git a/net/bluetooth/af_bluetooth.c b/net/bluetooth/af_bluetooth.c new file mode 100644 index 000000000..f1b751035 --- /dev/null +++ b/net/bluetooth/af_bluetooth.c @@ -0,0 +1,812 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth address family and sockets. 
*/ + +#include +#include +#include +#include + +#include + +#include +#include + +#include "leds.h" +#include "selftest.h" + +/* Bluetooth sockets */ +#define BT_MAX_PROTO (BTPROTO_LAST + 1) +static const struct net_proto_family *bt_proto[BT_MAX_PROTO]; +static DEFINE_RWLOCK(bt_proto_lock); + +static struct lock_class_key bt_lock_key[BT_MAX_PROTO]; +static const char *const bt_key_strings[BT_MAX_PROTO] = { + "sk_lock-AF_BLUETOOTH-BTPROTO_L2CAP", + "sk_lock-AF_BLUETOOTH-BTPROTO_HCI", + "sk_lock-AF_BLUETOOTH-BTPROTO_SCO", + "sk_lock-AF_BLUETOOTH-BTPROTO_RFCOMM", + "sk_lock-AF_BLUETOOTH-BTPROTO_BNEP", + "sk_lock-AF_BLUETOOTH-BTPROTO_CMTP", + "sk_lock-AF_BLUETOOTH-BTPROTO_HIDP", + "sk_lock-AF_BLUETOOTH-BTPROTO_AVDTP", + "sk_lock-AF_BLUETOOTH-BTPROTO_ISO", +}; + +static struct lock_class_key bt_slock_key[BT_MAX_PROTO]; +static const char *const bt_slock_key_strings[BT_MAX_PROTO] = { + "slock-AF_BLUETOOTH-BTPROTO_L2CAP", + "slock-AF_BLUETOOTH-BTPROTO_HCI", + "slock-AF_BLUETOOTH-BTPROTO_SCO", + "slock-AF_BLUETOOTH-BTPROTO_RFCOMM", + "slock-AF_BLUETOOTH-BTPROTO_BNEP", + "slock-AF_BLUETOOTH-BTPROTO_CMTP", + "slock-AF_BLUETOOTH-BTPROTO_HIDP", + "slock-AF_BLUETOOTH-BTPROTO_AVDTP", + "slock-AF_BLUETOOTH-BTPROTO_ISO", +}; + +void bt_sock_reclassify_lock(struct sock *sk, int proto) +{ + BUG_ON(!sk); + BUG_ON(!sock_allow_reclassification(sk)); + + sock_lock_init_class_and_name(sk, + bt_slock_key_strings[proto], &bt_slock_key[proto], + bt_key_strings[proto], &bt_lock_key[proto]); +} +EXPORT_SYMBOL(bt_sock_reclassify_lock); + +int bt_sock_register(int proto, const struct net_proto_family *ops) +{ + int err = 0; + + if (proto < 0 || proto >= BT_MAX_PROTO) + return -EINVAL; + + write_lock(&bt_proto_lock); + + if (bt_proto[proto]) + err = -EEXIST; + else + bt_proto[proto] = ops; + + write_unlock(&bt_proto_lock); + + return err; +} +EXPORT_SYMBOL(bt_sock_register); + +void bt_sock_unregister(int proto) +{ + if (proto < 0 || proto >= BT_MAX_PROTO) + return; + + write_lock(&bt_proto_lock); + bt_proto[proto] = NULL; + write_unlock(&bt_proto_lock); +} +EXPORT_SYMBOL(bt_sock_unregister); + +static int bt_sock_create(struct net *net, struct socket *sock, int proto, + int kern) +{ + int err; + + if (net != &init_net) + return -EAFNOSUPPORT; + + if (proto < 0 || proto >= BT_MAX_PROTO) + return -EINVAL; + + if (!bt_proto[proto]) + request_module("bt-proto-%d", proto); + + err = -EPROTONOSUPPORT; + + read_lock(&bt_proto_lock); + + if (bt_proto[proto] && try_module_get(bt_proto[proto]->owner)) { + err = bt_proto[proto]->create(net, sock, proto, kern); + if (!err) + bt_sock_reclassify_lock(sock->sk, proto); + module_put(bt_proto[proto]->owner); + } + + read_unlock(&bt_proto_lock); + + return err; +} + +void bt_sock_link(struct bt_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_add_node(sk, &l->head); + write_unlock(&l->lock); +} +EXPORT_SYMBOL(bt_sock_link); + +void bt_sock_unlink(struct bt_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_del_node_init(sk); + write_unlock(&l->lock); +} +EXPORT_SYMBOL(bt_sock_unlink); + +void bt_accept_enqueue(struct sock *parent, struct sock *sk, bool bh) +{ + BT_DBG("parent %p, sk %p", parent, sk); + + sock_hold(sk); + + if (bh) + bh_lock_sock_nested(sk); + else + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + list_add_tail(&bt_sk(sk)->accept_q, &bt_sk(parent)->accept_q); + bt_sk(sk)->parent = parent; + + if (bh) + bh_unlock_sock(sk); + else + release_sock(sk); + + sk_acceptq_added(parent); +} +EXPORT_SYMBOL(bt_accept_enqueue); + +/* Calling function must 
hold the sk lock. + * bt_sk(sk)->parent must be non-NULL meaning sk is in the parent list. + */ +void bt_accept_unlink(struct sock *sk) +{ + BT_DBG("sk %p state %d", sk, sk->sk_state); + + list_del_init(&bt_sk(sk)->accept_q); + sk_acceptq_removed(bt_sk(sk)->parent); + bt_sk(sk)->parent = NULL; + sock_put(sk); +} +EXPORT_SYMBOL(bt_accept_unlink); + +struct sock *bt_accept_dequeue(struct sock *parent, struct socket *newsock) +{ + struct bt_sock *s, *n; + struct sock *sk; + + BT_DBG("parent %p", parent); + +restart: + list_for_each_entry_safe(s, n, &bt_sk(parent)->accept_q, accept_q) { + sk = (struct sock *)s; + + /* Prevent early freeing of sk due to unlink and sock_kill */ + sock_hold(sk); + lock_sock(sk); + + /* Check sk has not already been unlinked via + * bt_accept_unlink() due to serialisation caused by sk locking + */ + if (!bt_sk(sk)->parent) { + BT_DBG("sk %p, already unlinked", sk); + release_sock(sk); + sock_put(sk); + + /* Restart the loop as sk is no longer in the list + * and also avoid a potential infinite loop because + * list_for_each_entry_safe() is not thread safe. + */ + goto restart; + } + + /* sk is safely in the parent list so reduce reference count */ + sock_put(sk); + + /* FIXME: Is this check still needed */ + if (sk->sk_state == BT_CLOSED) { + bt_accept_unlink(sk); + release_sock(sk); + continue; + } + + if (sk->sk_state == BT_CONNECTED || !newsock || + test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags)) { + bt_accept_unlink(sk); + if (newsock) + sock_graft(sk, newsock); + + release_sock(sk); + return sk; + } + + release_sock(sk); + } + + return NULL; +} +EXPORT_SYMBOL(bt_accept_dequeue); + +int bt_sock_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + size_t copied; + size_t skblen; + int err; + + BT_DBG("sock %p sk %p len %zu", sock, sk, len); + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + lock_sock(sk); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) { + if (sk->sk_shutdown & RCV_SHUTDOWN) + err = 0; + + release_sock(sk); + return err; + } + + skblen = skb->len; + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + skb_reset_transport_header(skb); + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err == 0) { + sock_recv_cmsgs(msg, sk, skb); + + if (msg->msg_name && bt_sk(sk)->skb_msg_name) + bt_sk(sk)->skb_msg_name(skb, msg->msg_name, + &msg->msg_namelen); + + if (bt_sk(sk)->skb_put_cmsg) + bt_sk(sk)->skb_put_cmsg(skb, msg, sk); + } + + skb_free_datagram(sk, skb); + + release_sock(sk); + + if (flags & MSG_TRUNC) + copied = skblen; + + return err ? 
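/* GNU ?: shorthand: a non-zero err is returned as-is, otherwise the number of bytes copied */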
: copied; +} +EXPORT_SYMBOL(bt_sock_recvmsg); + +static long bt_sock_data_wait(struct sock *sk, long timeo) +{ + DECLARE_WAITQUEUE(wait, current); + + add_wait_queue(sk_sleep(sk), &wait); + for (;;) { + set_current_state(TASK_INTERRUPTIBLE); + + if (!skb_queue_empty(&sk->sk_receive_queue)) + break; + + if (sk->sk_err || (sk->sk_shutdown & RCV_SHUTDOWN)) + break; + + if (signal_pending(current) || !timeo) + break; + + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + } + + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + return timeo; +} + +int bt_sock_stream_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + int err = 0; + size_t target, copied = 0; + long timeo; + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + BT_DBG("sk %p size %zu", sk, size); + + lock_sock(sk); + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, size); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + do { + struct sk_buff *skb; + int chunk; + + skb = skb_dequeue(&sk->sk_receive_queue); + if (!skb) { + if (copied >= target) + break; + + err = sock_error(sk); + if (err) + break; + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + + err = -EAGAIN; + if (!timeo) + break; + + timeo = bt_sock_data_wait(sk, timeo); + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + goto out; + } + continue; + } + + chunk = min_t(unsigned int, skb->len, size); + if (skb_copy_datagram_msg(skb, 0, msg, chunk)) { + skb_queue_head(&sk->sk_receive_queue, skb); + if (!copied) + copied = -EFAULT; + break; + } + copied += chunk; + size -= chunk; + + sock_recv_cmsgs(msg, sk, skb); + + if (!(flags & MSG_PEEK)) { + int skb_len = skb_headlen(skb); + + if (chunk <= skb_len) { + __skb_pull(skb, chunk); + } else { + struct sk_buff *frag; + + __skb_pull(skb, skb_len); + chunk -= skb_len; + + skb_walk_frags(skb, frag) { + if (chunk <= frag->len) { + /* Pulling partial data */ + skb->len -= chunk; + skb->data_len -= chunk; + __skb_pull(frag, chunk); + break; + } else if (frag->len) { + /* Pulling all frag data */ + chunk -= frag->len; + skb->len -= frag->len; + skb->data_len -= frag->len; + __skb_pull(frag, frag->len); + } + } + } + + if (skb->len) { + skb_queue_head(&sk->sk_receive_queue, skb); + break; + } + kfree_skb(skb); + + } else { + /* put message back and return */ + skb_queue_head(&sk->sk_receive_queue, skb); + break; + } + } while (size); + +out: + release_sock(sk); + return copied ? : err; +} +EXPORT_SYMBOL(bt_sock_stream_recvmsg); + +static inline __poll_t bt_accept_poll(struct sock *parent) +{ + struct bt_sock *s, *n; + struct sock *sk; + + list_for_each_entry_safe(s, n, &bt_sk(parent)->accept_q, accept_q) { + sk = (struct sock *)s; + if (sk->sk_state == BT_CONNECTED || + (test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags) && + sk->sk_state == BT_CONNECT2)) + return EPOLLIN | EPOLLRDNORM; + } + + return 0; +} + +__poll_t bt_sock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask = 0; + + poll_wait(file, sk_sleep(sk), wait); + + if (sk->sk_state == BT_LISTEN) + return bt_accept_poll(sk); + + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? 
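/* EPOLLPRI is added only when SOCK_SELECT_ERR_QUEUE is set on the socket */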
EPOLLPRI : 0); + + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + + if (sk->sk_shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + if (sk->sk_state == BT_CLOSED) + mask |= EPOLLHUP; + + if (sk->sk_state == BT_CONNECT || + sk->sk_state == BT_CONNECT2 || + sk->sk_state == BT_CONFIG) + return mask; + + if (!test_bit(BT_SK_SUSPEND, &bt_sk(sk)->flags) && sock_writeable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + else + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + return mask; +} +EXPORT_SYMBOL(bt_sock_poll); + +int bt_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + long amount; + int err; + + BT_DBG("sk %p cmd %x arg %lx", sk, cmd, arg); + + switch (cmd) { + case TIOCOUTQ: + if (sk->sk_state == BT_LISTEN) + return -EINVAL; + + amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (amount < 0) + amount = 0; + err = put_user(amount, (int __user *)arg); + break; + + case TIOCINQ: + if (sk->sk_state == BT_LISTEN) + return -EINVAL; + + lock_sock(sk); + skb = skb_peek(&sk->sk_receive_queue); + amount = skb ? skb->len : 0; + release_sock(sk); + err = put_user(amount, (int __user *)arg); + break; + + default: + err = -ENOIOCTLCMD; + break; + } + + return err; +} +EXPORT_SYMBOL(bt_sock_ioctl); + +/* This function expects the sk lock to be held when called */ +int bt_sock_wait_state(struct sock *sk, int state, unsigned long timeo) +{ + DECLARE_WAITQUEUE(wait, current); + int err = 0; + + BT_DBG("sk %p", sk); + + add_wait_queue(sk_sleep(sk), &wait); + set_current_state(TASK_INTERRUPTIBLE); + while (sk->sk_state != state) { + if (!timeo) { + err = -EINPROGRESS; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + set_current_state(TASK_INTERRUPTIBLE); + + err = sock_error(sk); + if (err) + break; + } + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + return err; +} +EXPORT_SYMBOL(bt_sock_wait_state); + +/* This function expects the sk lock to be held when called */ +int bt_sock_wait_ready(struct sock *sk, unsigned int msg_flags) +{ + DECLARE_WAITQUEUE(wait, current); + unsigned long timeo; + int err = 0; + + BT_DBG("sk %p", sk); + + timeo = sock_sndtimeo(sk, !!(msg_flags & MSG_DONTWAIT)); + + add_wait_queue(sk_sleep(sk), &wait); + set_current_state(TASK_INTERRUPTIBLE); + while (test_bit(BT_SK_SUSPEND, &bt_sk(sk)->flags)) { + if (!timeo) { + err = -EAGAIN; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + set_current_state(TASK_INTERRUPTIBLE); + + err = sock_error(sk); + if (err) + break; + } + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + + return err; +} +EXPORT_SYMBOL(bt_sock_wait_ready); + +#ifdef CONFIG_PROC_FS +static void *bt_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(seq->private->l->lock) +{ + struct bt_sock_list *l = pde_data(file_inode(seq->file)); + + read_lock(&l->lock); + return seq_hlist_start_head(&l->head, *pos); +} + +static void *bt_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bt_sock_list *l = pde_data(file_inode(seq->file)); + + return seq_hlist_next(v, &l->head, pos); +} + +static void bt_seq_stop(struct seq_file *seq, void *v) + 
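/* sparse annotation follows: the stop callback drops the read lock taken in bt_seq_start() */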
__releases(seq->private->l->lock) +{ + struct bt_sock_list *l = pde_data(file_inode(seq->file)); + + read_unlock(&l->lock); +} + +static int bt_seq_show(struct seq_file *seq, void *v) +{ + struct bt_sock_list *l = pde_data(file_inode(seq->file)); + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "sk RefCnt Rmem Wmem User Inode Parent"); + + if (l->custom_seq_show) { + seq_putc(seq, ' '); + l->custom_seq_show(seq, v); + } + + seq_putc(seq, '\n'); + } else { + struct sock *sk = sk_entry(v); + struct bt_sock *bt = bt_sk(sk); + + seq_printf(seq, + "%pK %-6d %-6u %-6u %-6u %-6lu %-6lu", + sk, + refcount_read(&sk->sk_refcnt), + sk_rmem_alloc_get(sk), + sk_wmem_alloc_get(sk), + from_kuid(seq_user_ns(seq), sock_i_uid(sk)), + sock_i_ino(sk), + bt->parent ? sock_i_ino(bt->parent) : 0LU); + + if (l->custom_seq_show) { + seq_putc(seq, ' '); + l->custom_seq_show(seq, v); + } + + seq_putc(seq, '\n'); + } + return 0; +} + +static const struct seq_operations bt_seq_ops = { + .start = bt_seq_start, + .next = bt_seq_next, + .stop = bt_seq_stop, + .show = bt_seq_show, +}; + +int bt_procfs_init(struct net *net, const char *name, + struct bt_sock_list *sk_list, + int (*seq_show)(struct seq_file *, void *)) +{ + sk_list->custom_seq_show = seq_show; + + if (!proc_create_seq_data(name, 0, net->proc_net, &bt_seq_ops, sk_list)) + return -ENOMEM; + return 0; +} + +void bt_procfs_cleanup(struct net *net, const char *name) +{ + remove_proc_entry(name, net->proc_net); +} +#else +int bt_procfs_init(struct net *net, const char *name, + struct bt_sock_list *sk_list, + int (*seq_show)(struct seq_file *, void *)) +{ + return 0; +} + +void bt_procfs_cleanup(struct net *net, const char *name) +{ +} +#endif +EXPORT_SYMBOL(bt_procfs_init); +EXPORT_SYMBOL(bt_procfs_cleanup); + +static const struct net_proto_family bt_sock_family_ops = { + .owner = THIS_MODULE, + .family = PF_BLUETOOTH, + .create = bt_sock_create, +}; + +struct dentry *bt_debugfs; +EXPORT_SYMBOL_GPL(bt_debugfs); + +#define VERSION __stringify(BT_SUBSYS_VERSION) "." 
\ + __stringify(BT_SUBSYS_REVISION) + +static int __init bt_init(void) +{ + int err; + + sock_skb_cb_check_size(sizeof(struct bt_skb_cb)); + + BT_INFO("Core ver %s", VERSION); + + err = bt_selftest(); + if (err < 0) + return err; + + bt_debugfs = debugfs_create_dir("bluetooth", NULL); + + bt_leds_init(); + + err = bt_sysfs_init(); + if (err < 0) + goto cleanup_led; + + err = sock_register(&bt_sock_family_ops); + if (err) + goto cleanup_sysfs; + + BT_INFO("HCI device and connection manager initialized"); + + err = hci_sock_init(); + if (err) + goto unregister_socket; + + err = l2cap_init(); + if (err) + goto cleanup_socket; + + err = sco_init(); + if (err) + goto cleanup_cap; + + err = mgmt_init(); + if (err) + goto cleanup_sco; + + return 0; + +cleanup_sco: + sco_exit(); +cleanup_cap: + l2cap_exit(); +cleanup_socket: + hci_sock_cleanup(); +unregister_socket: + sock_unregister(PF_BLUETOOTH); +cleanup_sysfs: + bt_sysfs_cleanup(); +cleanup_led: + bt_leds_cleanup(); + return err; +} + +static void __exit bt_exit(void) +{ + mgmt_exit(); + + sco_exit(); + + l2cap_exit(); + + hci_sock_cleanup(); + + sock_unregister(PF_BLUETOOTH); + + bt_sysfs_cleanup(); + + bt_leds_cleanup(); + + debugfs_remove_recursive(bt_debugfs); +} + +subsys_initcall(bt_init); +module_exit(bt_exit); + +MODULE_AUTHOR("Marcel Holtmann "); +MODULE_DESCRIPTION("Bluetooth Core ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_BLUETOOTH); diff --git a/net/bluetooth/amp.c b/net/bluetooth/amp.c new file mode 100644 index 000000000..2134f92bd --- /dev/null +++ b/net/bluetooth/amp.c @@ -0,0 +1,591 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + Copyright (c) 2011,2012 Intel Corp. + +*/ + +#include +#include +#include +#include + +#include "hci_request.h" +#include "a2mp.h" +#include "amp.h" + +/* Remote AMP Controllers interface */ +void amp_ctrl_get(struct amp_ctrl *ctrl) +{ + BT_DBG("ctrl %p orig refcnt %d", ctrl, + kref_read(&ctrl->kref)); + + kref_get(&ctrl->kref); +} + +static void amp_ctrl_destroy(struct kref *kref) +{ + struct amp_ctrl *ctrl = container_of(kref, struct amp_ctrl, kref); + + BT_DBG("ctrl %p", ctrl); + + kfree(ctrl->assoc); + kfree(ctrl); +} + +int amp_ctrl_put(struct amp_ctrl *ctrl) +{ + BT_DBG("ctrl %p orig refcnt %d", ctrl, + kref_read(&ctrl->kref)); + + return kref_put(&ctrl->kref, &_ctrl_destroy); +} + +struct amp_ctrl *amp_ctrl_add(struct amp_mgr *mgr, u8 id) +{ + struct amp_ctrl *ctrl; + + ctrl = kzalloc(sizeof(*ctrl), GFP_KERNEL); + if (!ctrl) + return NULL; + + kref_init(&ctrl->kref); + ctrl->id = id; + + mutex_lock(&mgr->amp_ctrls_lock); + list_add(&ctrl->list, &mgr->amp_ctrls); + mutex_unlock(&mgr->amp_ctrls_lock); + + BT_DBG("mgr %p ctrl %p", mgr, ctrl); + + return ctrl; +} + +void amp_ctrl_list_flush(struct amp_mgr *mgr) +{ + struct amp_ctrl *ctrl, *n; + + BT_DBG("mgr %p", mgr); + + mutex_lock(&mgr->amp_ctrls_lock); + list_for_each_entry_safe(ctrl, n, &mgr->amp_ctrls, list) { + list_del(&ctrl->list); + amp_ctrl_put(ctrl); + } + mutex_unlock(&mgr->amp_ctrls_lock); +} + +struct amp_ctrl *amp_ctrl_lookup(struct amp_mgr *mgr, u8 id) +{ + struct amp_ctrl *ctrl; + + BT_DBG("mgr %p id %u", mgr, id); + + mutex_lock(&mgr->amp_ctrls_lock); + list_for_each_entry(ctrl, &mgr->amp_ctrls, list) { + if (ctrl->id == id) { + amp_ctrl_get(ctrl); + mutex_unlock(&mgr->amp_ctrls_lock); + return ctrl; + } + } + mutex_unlock(&mgr->amp_ctrls_lock); + + return NULL; +} + +/* Physical Link interface */ +static u8 __next_handle(struct amp_mgr *mgr) +{ + if (++mgr->handle == 0) + 
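/* never return handle 0: on u8 wrap-around go straight from 0xff back to 1 */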
mgr->handle = 1; + + return mgr->handle; +} + +struct hci_conn *phylink_add(struct hci_dev *hdev, struct amp_mgr *mgr, + u8 remote_id, bool out) +{ + bdaddr_t *dst = &mgr->l2cap_conn->hcon->dst; + struct hci_conn *hcon; + u8 role = out ? HCI_ROLE_MASTER : HCI_ROLE_SLAVE; + + hcon = hci_conn_add(hdev, AMP_LINK, dst, role); + if (!hcon) + return NULL; + + BT_DBG("hcon %p dst %pMR", hcon, dst); + + hcon->state = BT_CONNECT; + hcon->attempt++; + hcon->handle = __next_handle(mgr); + hcon->remote_id = remote_id; + hcon->amp_mgr = amp_mgr_get(mgr); + + return hcon; +} + +/* AMP crypto key generation interface */ +static int hmac_sha256(u8 *key, u8 ksize, char *plaintext, u8 psize, u8 *output) +{ + struct crypto_shash *tfm; + struct shash_desc *shash; + int ret; + + if (!ksize) + return -EINVAL; + + tfm = crypto_alloc_shash("hmac(sha256)", 0, 0); + if (IS_ERR(tfm)) { + BT_DBG("crypto_alloc_ahash failed: err %ld", PTR_ERR(tfm)); + return PTR_ERR(tfm); + } + + ret = crypto_shash_setkey(tfm, key, ksize); + if (ret) { + BT_DBG("crypto_ahash_setkey failed: err %d", ret); + goto failed; + } + + shash = kzalloc(sizeof(*shash) + crypto_shash_descsize(tfm), + GFP_KERNEL); + if (!shash) { + ret = -ENOMEM; + goto failed; + } + + shash->tfm = tfm; + + ret = crypto_shash_digest(shash, plaintext, psize, output); + + kfree(shash); + +failed: + crypto_free_shash(tfm); + return ret; +} + +int phylink_gen_key(struct hci_conn *conn, u8 *data, u8 *len, u8 *type) +{ + struct hci_dev *hdev = conn->hdev; + struct link_key *key; + u8 keybuf[HCI_AMP_LINK_KEY_SIZE]; + u8 gamp_key[HCI_AMP_LINK_KEY_SIZE]; + int err; + + if (!hci_conn_check_link_mode(conn)) + return -EACCES; + + BT_DBG("conn %p key_type %d", conn, conn->key_type); + + /* Legacy key */ + if (conn->key_type < 3) { + bt_dev_err(hdev, "legacy key type %u", conn->key_type); + return -EACCES; + } + + *type = conn->key_type; + *len = HCI_AMP_LINK_KEY_SIZE; + + key = hci_find_link_key(hdev, &conn->dst); + if (!key) { + BT_DBG("No Link key for conn %p dst %pMR", conn, &conn->dst); + return -EACCES; + } + + /* BR/EDR Link Key concatenated together with itself */ + memcpy(&keybuf[0], key->val, HCI_LINK_KEY_SIZE); + memcpy(&keybuf[HCI_LINK_KEY_SIZE], key->val, HCI_LINK_KEY_SIZE); + + /* Derive Generic AMP Link Key (gamp) */ + err = hmac_sha256(keybuf, HCI_AMP_LINK_KEY_SIZE, "gamp", 4, gamp_key); + if (err) { + bt_dev_err(hdev, "could not derive Generic AMP Key: err %d", err); + return err; + } + + if (conn->key_type == HCI_LK_DEBUG_COMBINATION) { + BT_DBG("Use Generic AMP Key (gamp)"); + memcpy(data, gamp_key, HCI_AMP_LINK_KEY_SIZE); + return err; + } + + /* Derive Dedicated AMP Link Key: "802b" is 802.11 PAL keyID */ + return hmac_sha256(gamp_key, HCI_AMP_LINK_KEY_SIZE, "802b", 4, data); +} + +static void read_local_amp_assoc_complete(struct hci_dev *hdev, u8 status, + u16 opcode, struct sk_buff *skb) +{ + struct hci_rp_read_local_amp_assoc *rp = (void *)skb->data; + struct amp_assoc *assoc = &hdev->loc_assoc; + size_t rem_len, frag_len; + + BT_DBG("%s status 0x%2.2x", hdev->name, rp->status); + + if (rp->status) + goto send_rsp; + + frag_len = skb->len - sizeof(*rp); + rem_len = __le16_to_cpu(rp->rem_len); + + if (rem_len > frag_len) { + BT_DBG("frag_len %zu rem_len %zu", frag_len, rem_len); + + memcpy(assoc->data + assoc->offset, rp->frag, frag_len); + assoc->offset += frag_len; + + /* Read other fragments */ + amp_read_loc_assoc_frag(hdev, rp->phy_handle); + + return; + } + + memcpy(assoc->data + assoc->offset, rp->frag, rem_len); + assoc->len = assoc->offset + 
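/* last fragment: the total assoc length is everything accumulated so far plus this remainder */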
rem_len; + assoc->offset = 0; + +send_rsp: + /* Send A2MP Rsp when all fragments are received */ + a2mp_send_getampassoc_rsp(hdev, rp->status); + a2mp_send_create_phy_link_req(hdev, rp->status); +} + +void amp_read_loc_assoc_frag(struct hci_dev *hdev, u8 phy_handle) +{ + struct hci_cp_read_local_amp_assoc cp; + struct amp_assoc *loc_assoc = &hdev->loc_assoc; + struct hci_request req; + int err; + + BT_DBG("%s handle %u", hdev->name, phy_handle); + + cp.phy_handle = phy_handle; + cp.max_len = cpu_to_le16(hdev->amp_assoc_size); + cp.len_so_far = cpu_to_le16(loc_assoc->offset); + + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_READ_LOCAL_AMP_ASSOC, sizeof(cp), &cp); + err = hci_req_run_skb(&req, read_local_amp_assoc_complete); + if (err < 0) + a2mp_send_getampassoc_rsp(hdev, A2MP_STATUS_INVALID_CTRL_ID); +} + +void amp_read_loc_assoc(struct hci_dev *hdev, struct amp_mgr *mgr) +{ + struct hci_cp_read_local_amp_assoc cp; + struct hci_request req; + int err; + + memset(&hdev->loc_assoc, 0, sizeof(struct amp_assoc)); + memset(&cp, 0, sizeof(cp)); + + cp.max_len = cpu_to_le16(hdev->amp_assoc_size); + + set_bit(READ_LOC_AMP_ASSOC, &mgr->state); + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_READ_LOCAL_AMP_ASSOC, sizeof(cp), &cp); + err = hci_req_run_skb(&req, read_local_amp_assoc_complete); + if (err < 0) + a2mp_send_getampassoc_rsp(hdev, A2MP_STATUS_INVALID_CTRL_ID); +} + +void amp_read_loc_assoc_final_data(struct hci_dev *hdev, + struct hci_conn *hcon) +{ + struct hci_cp_read_local_amp_assoc cp; + struct amp_mgr *mgr = hcon->amp_mgr; + struct hci_request req; + int err; + + if (!mgr) + return; + + cp.phy_handle = hcon->handle; + cp.len_so_far = cpu_to_le16(0); + cp.max_len = cpu_to_le16(hdev->amp_assoc_size); + + set_bit(READ_LOC_AMP_ASSOC_FINAL, &mgr->state); + + /* Read Local AMP Assoc final link information data */ + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_READ_LOCAL_AMP_ASSOC, sizeof(cp), &cp); + err = hci_req_run_skb(&req, read_local_amp_assoc_complete); + if (err < 0) + a2mp_send_getampassoc_rsp(hdev, A2MP_STATUS_INVALID_CTRL_ID); +} + +static void write_remote_amp_assoc_complete(struct hci_dev *hdev, u8 status, + u16 opcode, struct sk_buff *skb) +{ + struct hci_rp_write_remote_amp_assoc *rp = (void *)skb->data; + + BT_DBG("%s status 0x%2.2x phy_handle 0x%2.2x", + hdev->name, rp->status, rp->phy_handle); + + if (rp->status) + return; + + amp_write_rem_assoc_continue(hdev, rp->phy_handle); +} + +/* Write AMP Assoc data fragments, returns true with last fragment written*/ +static bool amp_write_rem_assoc_frag(struct hci_dev *hdev, + struct hci_conn *hcon) +{ + struct hci_cp_write_remote_amp_assoc *cp; + struct amp_mgr *mgr = hcon->amp_mgr; + struct amp_ctrl *ctrl; + struct hci_request req; + u16 frag_len, len; + + ctrl = amp_ctrl_lookup(mgr, hcon->remote_id); + if (!ctrl) + return false; + + if (!ctrl->assoc_rem_len) { + BT_DBG("all fragments are written"); + ctrl->assoc_rem_len = ctrl->assoc_len; + ctrl->assoc_len_so_far = 0; + + amp_ctrl_put(ctrl); + return true; + } + + frag_len = min_t(u16, 248, ctrl->assoc_rem_len); + len = frag_len + sizeof(*cp); + + cp = kzalloc(len, GFP_KERNEL); + if (!cp) { + amp_ctrl_put(ctrl); + return false; + } + + BT_DBG("hcon %p ctrl %p frag_len %u assoc_len %u rem_len %u", + hcon, ctrl, frag_len, ctrl->assoc_len, ctrl->assoc_rem_len); + + cp->phy_handle = hcon->handle; + cp->len_so_far = cpu_to_le16(ctrl->assoc_len_so_far); + cp->rem_len = cpu_to_le16(ctrl->assoc_rem_len); + memcpy(cp->frag, ctrl->assoc, frag_len); + + 
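/* advance the bookkeeping so the next HCI Write Remote AMP ASSOC fragment reports the updated len_so_far and rem_len */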
ctrl->assoc_len_so_far += frag_len; + ctrl->assoc_rem_len -= frag_len; + + amp_ctrl_put(ctrl); + + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_WRITE_REMOTE_AMP_ASSOC, len, cp); + hci_req_run_skb(&req, write_remote_amp_assoc_complete); + + kfree(cp); + + return false; +} + +void amp_write_rem_assoc_continue(struct hci_dev *hdev, u8 handle) +{ + struct hci_conn *hcon; + + BT_DBG("%s phy handle 0x%2.2x", hdev->name, handle); + + hcon = hci_conn_hash_lookup_handle(hdev, handle); + if (!hcon) + return; + + /* Send A2MP create phylink rsp when all fragments are written */ + if (amp_write_rem_assoc_frag(hdev, hcon)) + a2mp_send_create_phy_link_rsp(hdev, 0); +} + +void amp_write_remote_assoc(struct hci_dev *hdev, u8 handle) +{ + struct hci_conn *hcon; + + BT_DBG("%s phy handle 0x%2.2x", hdev->name, handle); + + hcon = hci_conn_hash_lookup_handle(hdev, handle); + if (!hcon) + return; + + BT_DBG("%s phy handle 0x%2.2x hcon %p", hdev->name, handle, hcon); + + amp_write_rem_assoc_frag(hdev, hcon); +} + +static void create_phylink_complete(struct hci_dev *hdev, u8 status, + u16 opcode) +{ + struct hci_cp_create_phy_link *cp; + + BT_DBG("%s status 0x%2.2x", hdev->name, status); + + cp = hci_sent_cmd_data(hdev, HCI_OP_CREATE_PHY_LINK); + if (!cp) + return; + + hci_dev_lock(hdev); + + if (status) { + struct hci_conn *hcon; + + hcon = hci_conn_hash_lookup_handle(hdev, cp->phy_handle); + if (hcon) + hci_conn_del(hcon); + } else { + amp_write_remote_assoc(hdev, cp->phy_handle); + } + + hci_dev_unlock(hdev); +} + +void amp_create_phylink(struct hci_dev *hdev, struct amp_mgr *mgr, + struct hci_conn *hcon) +{ + struct hci_cp_create_phy_link cp; + struct hci_request req; + + cp.phy_handle = hcon->handle; + + BT_DBG("%s hcon %p phy handle 0x%2.2x", hdev->name, hcon, + hcon->handle); + + if (phylink_gen_key(mgr->l2cap_conn->hcon, cp.key, &cp.key_len, + &cp.key_type)) { + BT_DBG("Cannot create link key"); + return; + } + + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_CREATE_PHY_LINK, sizeof(cp), &cp); + hci_req_run(&req, create_phylink_complete); +} + +static void accept_phylink_complete(struct hci_dev *hdev, u8 status, + u16 opcode) +{ + struct hci_cp_accept_phy_link *cp; + + BT_DBG("%s status 0x%2.2x", hdev->name, status); + + if (status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_ACCEPT_PHY_LINK); + if (!cp) + return; + + amp_write_remote_assoc(hdev, cp->phy_handle); +} + +void amp_accept_phylink(struct hci_dev *hdev, struct amp_mgr *mgr, + struct hci_conn *hcon) +{ + struct hci_cp_accept_phy_link cp; + struct hci_request req; + + cp.phy_handle = hcon->handle; + + BT_DBG("%s hcon %p phy handle 0x%2.2x", hdev->name, hcon, + hcon->handle); + + if (phylink_gen_key(mgr->l2cap_conn->hcon, cp.key, &cp.key_len, + &cp.key_type)) { + BT_DBG("Cannot create link key"); + return; + } + + hci_req_init(&req, hdev); + hci_req_add(&req, HCI_OP_ACCEPT_PHY_LINK, sizeof(cp), &cp); + hci_req_run(&req, accept_phylink_complete); +} + +void amp_physical_cfm(struct hci_conn *bredr_hcon, struct hci_conn *hs_hcon) +{ + struct hci_dev *bredr_hdev = hci_dev_hold(bredr_hcon->hdev); + struct amp_mgr *mgr = hs_hcon->amp_mgr; + struct l2cap_chan *bredr_chan; + + BT_DBG("bredr_hcon %p hs_hcon %p mgr %p", bredr_hcon, hs_hcon, mgr); + + if (!bredr_hdev || !mgr || !mgr->bredr_chan) + return; + + bredr_chan = mgr->bredr_chan; + + l2cap_chan_lock(bredr_chan); + + set_bit(FLAG_EFS_ENABLE, &bredr_chan->flags); + bredr_chan->remote_amp_id = hs_hcon->remote_id; + bredr_chan->local_amp_id = hs_hcon->hdev->id; + bredr_chan->hs_hcon 
= hs_hcon; + bredr_chan->conn->mtu = hs_hcon->hdev->block_mtu; + + __l2cap_physical_cfm(bredr_chan, 0); + + l2cap_chan_unlock(bredr_chan); + + hci_dev_put(bredr_hdev); +} + +void amp_create_logical_link(struct l2cap_chan *chan) +{ + struct hci_conn *hs_hcon = chan->hs_hcon; + struct hci_cp_create_accept_logical_link cp; + struct hci_dev *hdev; + + BT_DBG("chan %p hs_hcon %p dst %pMR", chan, hs_hcon, + &chan->conn->hcon->dst); + + if (!hs_hcon) + return; + + hdev = hci_dev_hold(chan->hs_hcon->hdev); + if (!hdev) + return; + + cp.phy_handle = hs_hcon->handle; + + cp.tx_flow_spec.id = chan->local_id; + cp.tx_flow_spec.stype = chan->local_stype; + cp.tx_flow_spec.msdu = cpu_to_le16(chan->local_msdu); + cp.tx_flow_spec.sdu_itime = cpu_to_le32(chan->local_sdu_itime); + cp.tx_flow_spec.acc_lat = cpu_to_le32(chan->local_acc_lat); + cp.tx_flow_spec.flush_to = cpu_to_le32(chan->local_flush_to); + + cp.rx_flow_spec.id = chan->remote_id; + cp.rx_flow_spec.stype = chan->remote_stype; + cp.rx_flow_spec.msdu = cpu_to_le16(chan->remote_msdu); + cp.rx_flow_spec.sdu_itime = cpu_to_le32(chan->remote_sdu_itime); + cp.rx_flow_spec.acc_lat = cpu_to_le32(chan->remote_acc_lat); + cp.rx_flow_spec.flush_to = cpu_to_le32(chan->remote_flush_to); + + if (hs_hcon->out) + hci_send_cmd(hdev, HCI_OP_CREATE_LOGICAL_LINK, sizeof(cp), + &cp); + else + hci_send_cmd(hdev, HCI_OP_ACCEPT_LOGICAL_LINK, sizeof(cp), + &cp); + + hci_dev_put(hdev); +} + +void amp_disconnect_logical_link(struct hci_chan *hchan) +{ + struct hci_conn *hcon = hchan->conn; + struct hci_cp_disconn_logical_link cp; + + if (hcon->state != BT_CONNECTED) { + BT_DBG("hchan %p not connected", hchan); + return; + } + + cp.log_handle = cpu_to_le16(hchan->handle); + hci_send_cmd(hcon->hdev, HCI_OP_DISCONN_LOGICAL_LINK, sizeof(cp), &cp); +} + +void amp_destroy_logical_link(struct hci_chan *hchan, u8 reason) +{ + BT_DBG("hchan %p", hchan); + + hci_chan_del(hchan); +} diff --git a/net/bluetooth/amp.h b/net/bluetooth/amp.h new file mode 100644 index 000000000..832764dfb --- /dev/null +++ b/net/bluetooth/amp.h @@ -0,0 +1,61 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + Copyright (c) 2011,2012 Intel Corp. 
+ +*/ + +#ifndef __AMP_H +#define __AMP_H + +struct amp_ctrl { + struct list_head list; + struct kref kref; + __u8 id; + __u16 assoc_len_so_far; + __u16 assoc_rem_len; + __u16 assoc_len; + __u8 *assoc; +}; + +int amp_ctrl_put(struct amp_ctrl *ctrl); +void amp_ctrl_get(struct amp_ctrl *ctrl); +struct amp_ctrl *amp_ctrl_add(struct amp_mgr *mgr, u8 id); +struct amp_ctrl *amp_ctrl_lookup(struct amp_mgr *mgr, u8 id); +void amp_ctrl_list_flush(struct amp_mgr *mgr); + +struct hci_conn *phylink_add(struct hci_dev *hdev, struct amp_mgr *mgr, + u8 remote_id, bool out); + +int phylink_gen_key(struct hci_conn *hcon, u8 *data, u8 *len, u8 *type); + +void amp_read_loc_info(struct hci_dev *hdev, struct amp_mgr *mgr); +void amp_read_loc_assoc_frag(struct hci_dev *hdev, u8 phy_handle); +void amp_read_loc_assoc(struct hci_dev *hdev, struct amp_mgr *mgr); +void amp_read_loc_assoc_final_data(struct hci_dev *hdev, + struct hci_conn *hcon); +void amp_create_phylink(struct hci_dev *hdev, struct amp_mgr *mgr, + struct hci_conn *hcon); +void amp_accept_phylink(struct hci_dev *hdev, struct amp_mgr *mgr, + struct hci_conn *hcon); + +#if IS_ENABLED(CONFIG_BT_HS) +void amp_create_logical_link(struct l2cap_chan *chan); +void amp_disconnect_logical_link(struct hci_chan *hchan); +#else +static inline void amp_create_logical_link(struct l2cap_chan *chan) +{ +} + +static inline void amp_disconnect_logical_link(struct hci_chan *hchan) +{ +} +#endif + +void amp_write_remote_assoc(struct hci_dev *hdev, u8 handle); +void amp_write_rem_assoc_continue(struct hci_dev *hdev, u8 handle); +void amp_physical_cfm(struct hci_conn *bredr_hcon, struct hci_conn *hs_hcon); +void amp_create_logical_link(struct l2cap_chan *chan); +void amp_disconnect_logical_link(struct hci_chan *hchan); +void amp_destroy_logical_link(struct hci_chan *hchan, u8 reason); + +#endif /* __AMP_H */ diff --git a/net/bluetooth/aosp.c b/net/bluetooth/aosp.c new file mode 100644 index 000000000..1d67836e9 --- /dev/null +++ b/net/bluetooth/aosp.c @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2021 Intel Corporation + */ + +#include +#include + +#include "aosp.h" + +/* Command complete parameters of LE_Get_Vendor_Capabilities_Command + * The parameters grow over time. The base version that declares the + * version_supported field is v0.95. 
Refer to + * https://cs.android.com/android/platform/superproject/+/master:system/ + * bt/gd/hci/controller.cc;l=452?q=le_get_vendor_capabilities_handler + */ +struct aosp_rp_le_get_vendor_capa { + /* v0.95: 15 octets */ + __u8 status; + __u8 max_advt_instances; + __u8 offloaded_resolution_of_private_address; + __le16 total_scan_results_storage; + __u8 max_irk_list_sz; + __u8 filtering_support; + __u8 max_filter; + __u8 activity_energy_info_support; + __le16 version_supported; + __le16 total_num_of_advt_tracked; + __u8 extended_scan_support; + __u8 debug_logging_supported; + /* v0.96: 16 octets */ + __u8 le_address_generation_offloading_support; + /* v0.98: 21 octets */ + __le32 a2dp_source_offload_capability_mask; + __u8 bluetooth_quality_report_support; + /* v1.00: 25 octets */ + __le32 dynamic_audio_buffer_support; +} __packed; + +#define VENDOR_CAPA_BASE_SIZE 15 +#define VENDOR_CAPA_0_98_SIZE 21 + +void aosp_do_open(struct hci_dev *hdev) +{ + struct sk_buff *skb; + struct aosp_rp_le_get_vendor_capa *rp; + u16 version_supported; + + if (!hdev->aosp_capable) + return; + + bt_dev_dbg(hdev, "Initialize AOSP extension"); + + /* LE Get Vendor Capabilities Command */ + skb = __hci_cmd_sync(hdev, hci_opcode_pack(0x3f, 0x153), 0, NULL, + HCI_CMD_TIMEOUT); + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + skb = ERR_PTR(-EIO); + + bt_dev_err(hdev, "AOSP get vendor capabilities (%ld)", + PTR_ERR(skb)); + return; + } + + /* A basic length check */ + if (skb->len < VENDOR_CAPA_BASE_SIZE) + goto length_error; + + rp = (struct aosp_rp_le_get_vendor_capa *)skb->data; + + version_supported = le16_to_cpu(rp->version_supported); + /* AOSP displays the verion number like v0.98, v1.00, etc. */ + bt_dev_info(hdev, "AOSP extensions version v%u.%02u", + version_supported >> 8, version_supported & 0xff); + + /* Do not support very old versions. */ + if (version_supported < 95) { + bt_dev_warn(hdev, "AOSP capabilities version %u too old", + version_supported); + goto done; + } + + if (version_supported < 98) { + bt_dev_warn(hdev, "AOSP quality report is not supported"); + goto done; + } + + if (skb->len < VENDOR_CAPA_0_98_SIZE) + goto length_error; + + /* The bluetooth_quality_report_support is defined at version + * v0.98. Refer to + * https://cs.android.com/android/platform/superproject/+/ + * master:system/bt/gd/hci/controller.cc;l=477 + */ + if (rp->bluetooth_quality_report_support) { + hdev->aosp_quality_report = true; + bt_dev_info(hdev, "AOSP quality report is supported"); + } + + goto done; + +length_error: + bt_dev_err(hdev, "AOSP capabilities length %d too short", skb->len); + +done: + kfree_skb(skb); +} + +void aosp_do_close(struct hci_dev *hdev) +{ + if (!hdev->aosp_capable) + return; + + bt_dev_dbg(hdev, "Cleanup of AOSP extension"); +} + +/* BQR command */ +#define BQR_OPCODE hci_opcode_pack(0x3f, 0x015e) + +/* BQR report action */ +#define REPORT_ACTION_ADD 0x00 +#define REPORT_ACTION_DELETE 0x01 +#define REPORT_ACTION_CLEAR 0x02 + +/* BQR event masks */ +#define QUALITY_MONITORING BIT(0) +#define APPRAOCHING_LSTO BIT(1) +#define A2DP_AUDIO_CHOPPY BIT(2) +#define SCO_VOICE_CHOPPY BIT(3) + +#define DEFAULT_BQR_EVENT_MASK (QUALITY_MONITORING | APPRAOCHING_LSTO | \ + A2DP_AUDIO_CHOPPY | SCO_VOICE_CHOPPY) + +/* Reporting at milliseconds so as not to stress the controller too much. 
+ * Range: 0 ~ 65535 ms + */ +#define DEFALUT_REPORT_INTERVAL_MS 5000 + +struct aosp_bqr_cp { + __u8 report_action; + __u32 event_mask; + __u16 min_report_interval; +} __packed; + +static int enable_quality_report(struct hci_dev *hdev) +{ + struct sk_buff *skb; + struct aosp_bqr_cp cp; + + cp.report_action = REPORT_ACTION_ADD; + cp.event_mask = DEFAULT_BQR_EVENT_MASK; + cp.min_report_interval = DEFALUT_REPORT_INTERVAL_MS; + + skb = __hci_cmd_sync(hdev, BQR_OPCODE, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + skb = ERR_PTR(-EIO); + + bt_dev_err(hdev, "Enabling Android BQR failed (%ld)", + PTR_ERR(skb)); + return PTR_ERR(skb); + } + + kfree_skb(skb); + return 0; +} + +static int disable_quality_report(struct hci_dev *hdev) +{ + struct sk_buff *skb; + struct aosp_bqr_cp cp = { 0 }; + + cp.report_action = REPORT_ACTION_CLEAR; + + skb = __hci_cmd_sync(hdev, BQR_OPCODE, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + skb = ERR_PTR(-EIO); + + bt_dev_err(hdev, "Disabling Android BQR failed (%ld)", + PTR_ERR(skb)); + return PTR_ERR(skb); + } + + kfree_skb(skb); + return 0; +} + +bool aosp_has_quality_report(struct hci_dev *hdev) +{ + return hdev->aosp_quality_report; +} + +int aosp_set_quality_report(struct hci_dev *hdev, bool enable) +{ + if (!aosp_has_quality_report(hdev)) + return -EOPNOTSUPP; + + bt_dev_dbg(hdev, "quality report enable %d", enable); + + /* Enable or disable the quality report feature. */ + if (enable) + return enable_quality_report(hdev); + else + return disable_quality_report(hdev); +} diff --git a/net/bluetooth/aosp.h b/net/bluetooth/aosp.h new file mode 100644 index 000000000..2fd8886d5 --- /dev/null +++ b/net/bluetooth/aosp.h @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2021 Intel Corporation + */ + +#if IS_ENABLED(CONFIG_BT_AOSPEXT) + +void aosp_do_open(struct hci_dev *hdev); +void aosp_do_close(struct hci_dev *hdev); + +bool aosp_has_quality_report(struct hci_dev *hdev); +int aosp_set_quality_report(struct hci_dev *hdev, bool enable); + +#else + +static inline void aosp_do_open(struct hci_dev *hdev) {} +static inline void aosp_do_close(struct hci_dev *hdev) {} + +static inline bool aosp_has_quality_report(struct hci_dev *hdev) +{ + return false; +} + +static inline int aosp_set_quality_report(struct hci_dev *hdev, bool enable) +{ + return -EOPNOTSUPP; +} + +#endif diff --git a/net/bluetooth/bnep/Kconfig b/net/bluetooth/bnep/Kconfig new file mode 100644 index 000000000..aac02b5b0 --- /dev/null +++ b/net/bluetooth/bnep/Kconfig @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: GPL-2.0-only +config BT_BNEP + tristate "BNEP protocol support" + depends on BT_BREDR + select CRC32 + help + BNEP (Bluetooth Network Encapsulation Protocol) is Ethernet + emulation layer on top of Bluetooth. BNEP is required for + Bluetooth PAN (Personal Area Network). + + Say Y here to compile BNEP support into the kernel or say M to + compile it as module (bnep). + +config BT_BNEP_MC_FILTER + bool "Multicast filter support" + depends on BT_BNEP + help + This option enables the multicast filter support for BNEP. + +config BT_BNEP_PROTO_FILTER + bool "Protocol filter support" + depends on BT_BNEP + help + This option enables the protocol filter support for BNEP. 
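Taken together, these entries let BNEP be built as the bnep module with both optional filters compiled in. As a hedged illustration (one reasonable choice, assuming the usual CONFIG_BT prerequisites rather than anything mandated by this patch), the matching kernel configuration fragment is:

    CONFIG_BT=m
    CONFIG_BT_BREDR=y
    CONFIG_BT_BNEP=m
    CONFIG_BT_BNEP_MC_FILTER=y
    CONFIG_BT_BNEP_PROTO_FILTER=y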
+ diff --git a/net/bluetooth/bnep/Makefile b/net/bluetooth/bnep/Makefile new file mode 100644 index 000000000..8af9d56bb --- /dev/null +++ b/net/bluetooth/bnep/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Bluetooth BNEP layer. +# + +obj-$(CONFIG_BT_BNEP) += bnep.o + +bnep-objs := core.o sock.o netdev.o diff --git a/net/bluetooth/bnep/bnep.h b/net/bluetooth/bnep/bnep.h new file mode 100644 index 000000000..9680473ed --- /dev/null +++ b/net/bluetooth/bnep/bnep.h @@ -0,0 +1,173 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + BNEP protocol definition for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002 Maxim Krasnyansky + +*/ + +#ifndef _BNEP_H +#define _BNEP_H + +#include +#include +#include + +/* Limits */ +#define BNEP_MAX_PROTO_FILTERS 5 +#define BNEP_MAX_MULTICAST_FILTERS 20 + +/* UUIDs */ +#define BNEP_BASE_UUID 0x0000000000001000800000805F9B34FB +#define BNEP_UUID16 0x02 +#define BNEP_UUID32 0x04 +#define BNEP_UUID128 0x16 + +#define BNEP_SVC_PANU 0x1115 +#define BNEP_SVC_NAP 0x1116 +#define BNEP_SVC_GN 0x1117 + +/* Packet types */ +#define BNEP_GENERAL 0x00 +#define BNEP_CONTROL 0x01 +#define BNEP_COMPRESSED 0x02 +#define BNEP_COMPRESSED_SRC_ONLY 0x03 +#define BNEP_COMPRESSED_DST_ONLY 0x04 + +/* Control types */ +#define BNEP_CMD_NOT_UNDERSTOOD 0x00 +#define BNEP_SETUP_CONN_REQ 0x01 +#define BNEP_SETUP_CONN_RSP 0x02 +#define BNEP_FILTER_NET_TYPE_SET 0x03 +#define BNEP_FILTER_NET_TYPE_RSP 0x04 +#define BNEP_FILTER_MULTI_ADDR_SET 0x05 +#define BNEP_FILTER_MULTI_ADDR_RSP 0x06 + +/* Extension types */ +#define BNEP_EXT_CONTROL 0x00 + +/* Response messages */ +#define BNEP_SUCCESS 0x00 + +#define BNEP_CONN_INVALID_DST 0x01 +#define BNEP_CONN_INVALID_SRC 0x02 +#define BNEP_CONN_INVALID_SVC 0x03 +#define BNEP_CONN_NOT_ALLOWED 0x04 + +#define BNEP_FILTER_UNSUPPORTED_REQ 0x01 +#define BNEP_FILTER_INVALID_RANGE 0x02 +#define BNEP_FILTER_INVALID_MCADDR 0x02 +#define BNEP_FILTER_LIMIT_REACHED 0x03 +#define BNEP_FILTER_DENIED_SECURITY 0x04 + +/* L2CAP settings */ +#define BNEP_MTU 1691 +#define BNEP_PSM 0x0f +#define BNEP_FLUSH_TO 0xffff +#define BNEP_CONNECT_TO 15 +#define BNEP_FILTER_TO 15 + +/* Headers */ +#define BNEP_TYPE_MASK 0x7f +#define BNEP_EXT_HEADER 0x80 + +struct bnep_setup_conn_req { + __u8 type; + __u8 ctrl; + __u8 uuid_size; + __u8 service[]; +} __packed; + +struct bnep_set_filter_req { + __u8 type; + __u8 ctrl; + __be16 len; + __u8 list[]; +} __packed; + +struct bnep_control_rsp { + __u8 type; + __u8 ctrl; + __be16 resp; +} __packed; + +struct bnep_ext_hdr { + __u8 type; + __u8 len; + __u8 data[]; +} __packed; + +/* BNEP ioctl defines */ +#define BNEPCONNADD _IOW('B', 200, int) +#define BNEPCONNDEL _IOW('B', 201, int) +#define BNEPGETCONNLIST _IOR('B', 210, int) +#define BNEPGETCONNINFO _IOR('B', 211, int) +#define BNEPGETSUPPFEAT _IOR('B', 212, int) + +#define BNEP_SETUP_RESPONSE 0 +#define BNEP_SETUP_RSP_SENT 10 + +struct bnep_connadd_req { + int sock; /* Connected socket */ + __u32 flags; + __u16 role; + char device[16]; /* Name of the Ethernet device */ +}; + +struct bnep_conndel_req { + __u32 flags; + __u8 dst[ETH_ALEN]; +}; + +struct bnep_conninfo { + __u32 flags; + __u16 role; + __u16 state; + __u8 dst[ETH_ALEN]; + char device[16]; +}; + +struct bnep_connlist_req { + __u32 cnum; + struct bnep_conninfo __user *ci; +}; + +struct bnep_proto_filter { + __u16 start; + __u16 end; +}; + +int bnep_add_connection(struct bnep_connadd_req *req, struct socket *sock); +int bnep_del_connection(struct bnep_conndel_req *req); +int 
bnep_get_connlist(struct bnep_connlist_req *req); +int bnep_get_conninfo(struct bnep_conninfo *ci); + +/* BNEP sessions */ +struct bnep_session { + struct list_head list; + + unsigned int role; + unsigned long state; + unsigned long flags; + atomic_t terminate; + struct task_struct *task; + + struct ethhdr eh; + struct msghdr msg; + + struct bnep_proto_filter proto_filter[BNEP_MAX_PROTO_FILTERS]; + unsigned long long mc_filter; + + struct socket *sock; + struct net_device *dev; +}; + +void bnep_net_setup(struct net_device *dev); +int bnep_sock_init(void); +void bnep_sock_cleanup(void); + +static inline int bnep_mc_hash(__u8 *addr) +{ + return crc32_be(~0, addr, ETH_ALEN) >> 26; +} + +#endif diff --git a/net/bluetooth/bnep/core.c b/net/bluetooth/bnep/core.c new file mode 100644 index 000000000..5a6a49885 --- /dev/null +++ b/net/bluetooth/bnep/core.c @@ -0,0 +1,769 @@ +/* + BNEP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2001-2002 Inventel Systemes + Written 2001-2002 by + Clément Moreau + David Libault + + Copyright (C) 2002 Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "bnep.h" + +#define VERSION "1.3" + +static bool compress_src = true; +static bool compress_dst = true; + +static LIST_HEAD(bnep_session_list); +static DECLARE_RWSEM(bnep_session_sem); + +static struct bnep_session *__bnep_get_session(u8 *dst) +{ + struct bnep_session *s; + + BT_DBG(""); + + list_for_each_entry(s, &bnep_session_list, list) + if (ether_addr_equal(dst, s->eh.h_source)) + return s; + + return NULL; +} + +static void __bnep_link_session(struct bnep_session *s) +{ + list_add(&s->list, &bnep_session_list); +} + +static void __bnep_unlink_session(struct bnep_session *s) +{ + list_del(&s->list); +} + +static int bnep_send(struct bnep_session *s, void *data, size_t len) +{ + struct socket *sock = s->sock; + struct kvec iv = { data, len }; + + return kernel_sendmsg(sock, &s->msg, &iv, 1, len); +} + +static int bnep_send_rsp(struct bnep_session *s, u8 ctrl, u16 resp) +{ + struct bnep_control_rsp rsp; + rsp.type = BNEP_CONTROL; + rsp.ctrl = ctrl; + rsp.resp = htons(resp); + return bnep_send(s, &rsp, sizeof(rsp)); +} + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER +static inline void bnep_set_default_proto_filter(struct bnep_session *s) +{ + /* (IPv4, ARP) */ + s->proto_filter[0].start = ETH_P_IP; + s->proto_filter[0].end = ETH_P_ARP; + /* (RARP, AppleTalk) */ + s->proto_filter[1].start = ETH_P_RARP; + s->proto_filter[1].end = ETH_P_AARP; + /* (IPX, IPv6) */ + s->proto_filter[2].start = ETH_P_IPX; + s->proto_filter[2].end = ETH_P_IPV6; +} +#endif + +static int bnep_ctrl_set_netfilter(struct bnep_session *s, __be16 *data, int len) +{ + int n; + + if (len < 2) + return -EILSEQ; + + n = get_unaligned_be16(data); + data++; + len -= 2; + + if (len < n) + return -EILSEQ; + + BT_DBG("filter len %d", n); + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER + n /= 4; + if (n <= BNEP_MAX_PROTO_FILTERS) { + struct bnep_proto_filter *f = s->proto_filter; + int i; + + for (i = 0; i < n; i++) { + f[i].start = get_unaligned_be16(data++); + f[i].end = get_unaligned_be16(data++); + + BT_DBG("proto filter start %u end %u", + f[i].start, f[i].end); + } + + if (i < BNEP_MAX_PROTO_FILTERS) + memset(f + i, 0, sizeof(*f)); + + if (n == 0) + bnep_set_default_proto_filter(s); + + bnep_send_rsp(s, BNEP_FILTER_NET_TYPE_RSP, BNEP_SUCCESS); + } else { + bnep_send_rsp(s, BNEP_FILTER_NET_TYPE_RSP, BNEP_FILTER_LIMIT_REACHED); + } +#else + bnep_send_rsp(s, BNEP_FILTER_NET_TYPE_RSP, BNEP_FILTER_UNSUPPORTED_REQ); +#endif + return 0; +} + +static int bnep_ctrl_set_mcfilter(struct bnep_session *s, u8 *data, int len) +{ + int n; + + if (len < 2) + return -EILSEQ; + + n = get_unaligned_be16(data); + data += 2; + len -= 2; + + if (len < n) + return -EILSEQ; + + BT_DBG("filter len %d", n); + +#ifdef CONFIG_BT_BNEP_MC_FILTER + n /= (ETH_ALEN * 2); + + if (n > 0) { + int i; + + s->mc_filter = 0; + + /* Always send broadcast */ + set_bit(bnep_mc_hash(s->dev->broadcast), (ulong *) &s->mc_filter); + + /* Add address ranges to the multicast hash */ + for (; n > 0; n--) { + u8 a1[6], *a2; + + memcpy(a1, data, ETH_ALEN); + data += ETH_ALEN; + a2 = data; + data += ETH_ALEN; + + BT_DBG("mc filter %pMR -> %pMR", a1, a2); + + /* Iterate from a1 to a2 */ + set_bit(bnep_mc_hash(a1), (ulong *) &s->mc_filter); + while (memcmp(a1, a2, 6) < 0 && s->mc_filter != ~0LL) { + /* Increment a1 */ + i = 5; + while (i >= 0 && ++a1[i--] == 0) + ; + + set_bit(bnep_mc_hash(a1), (ulong *) &s->mc_filter); + } + } + } + + BT_DBG("mc filter hash 0x%llx", s->mc_filter); + + 
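/* the 64-bit mc_filter keeps one bit per bucket; bnep_mc_hash() maps an address to the top six bits of its big-endian CRC-32 */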
bnep_send_rsp(s, BNEP_FILTER_MULTI_ADDR_RSP, BNEP_SUCCESS); +#else + bnep_send_rsp(s, BNEP_FILTER_MULTI_ADDR_RSP, BNEP_FILTER_UNSUPPORTED_REQ); +#endif + return 0; +} + +static int bnep_rx_control(struct bnep_session *s, void *data, int len) +{ + u8 cmd = *(u8 *)data; + int err = 0; + + data++; + len--; + + switch (cmd) { + case BNEP_CMD_NOT_UNDERSTOOD: + case BNEP_SETUP_CONN_RSP: + case BNEP_FILTER_NET_TYPE_RSP: + case BNEP_FILTER_MULTI_ADDR_RSP: + /* Ignore these for now */ + break; + + case BNEP_FILTER_NET_TYPE_SET: + err = bnep_ctrl_set_netfilter(s, data, len); + break; + + case BNEP_FILTER_MULTI_ADDR_SET: + err = bnep_ctrl_set_mcfilter(s, data, len); + break; + + case BNEP_SETUP_CONN_REQ: + /* Successful response should be sent only once */ + if (test_bit(BNEP_SETUP_RESPONSE, &s->flags) && + !test_and_set_bit(BNEP_SETUP_RSP_SENT, &s->flags)) + err = bnep_send_rsp(s, BNEP_SETUP_CONN_RSP, + BNEP_SUCCESS); + else + err = bnep_send_rsp(s, BNEP_SETUP_CONN_RSP, + BNEP_CONN_NOT_ALLOWED); + break; + + default: { + u8 pkt[3]; + pkt[0] = BNEP_CONTROL; + pkt[1] = BNEP_CMD_NOT_UNDERSTOOD; + pkt[2] = cmd; + err = bnep_send(s, pkt, sizeof(pkt)); + } + break; + } + + return err; +} + +static int bnep_rx_extension(struct bnep_session *s, struct sk_buff *skb) +{ + struct bnep_ext_hdr *h; + int err = 0; + + do { + h = (void *) skb->data; + if (!skb_pull(skb, sizeof(*h))) { + err = -EILSEQ; + break; + } + + BT_DBG("type 0x%x len %u", h->type, h->len); + + switch (h->type & BNEP_TYPE_MASK) { + case BNEP_EXT_CONTROL: + bnep_rx_control(s, skb->data, skb->len); + break; + + default: + /* Unknown extension, skip it. */ + break; + } + + if (!skb_pull(skb, h->len)) { + err = -EILSEQ; + break; + } + } while (!err && (h->type & BNEP_EXT_HEADER)); + + return err; +} + +static u8 __bnep_rx_hlen[] = { + ETH_HLEN, /* BNEP_GENERAL */ + 0, /* BNEP_CONTROL */ + 2, /* BNEP_COMPRESSED */ + ETH_ALEN + 2, /* BNEP_COMPRESSED_SRC_ONLY */ + ETH_ALEN + 2 /* BNEP_COMPRESSED_DST_ONLY */ +}; + +static int bnep_rx_frame(struct bnep_session *s, struct sk_buff *skb) +{ + struct net_device *dev = s->dev; + struct sk_buff *nskb; + u8 type, ctrl_type; + + dev->stats.rx_bytes += skb->len; + + type = *(u8 *) skb->data; + skb_pull(skb, 1); + ctrl_type = *(u8 *)skb->data; + + if ((type & BNEP_TYPE_MASK) >= sizeof(__bnep_rx_hlen)) + goto badframe; + + if ((type & BNEP_TYPE_MASK) == BNEP_CONTROL) { + if (bnep_rx_control(s, skb->data, skb->len) < 0) { + dev->stats.tx_errors++; + kfree_skb(skb); + return 0; + } + + if (!(type & BNEP_EXT_HEADER)) { + kfree_skb(skb); + return 0; + } + + /* Verify and pull ctrl message since it's already processed */ + switch (ctrl_type) { + case BNEP_SETUP_CONN_REQ: + /* Pull: ctrl type (1 b), len (1 b), data (len bytes) */ + if (!skb_pull(skb, 2 + *(u8 *)(skb->data + 1) * 2)) + goto badframe; + break; + case BNEP_FILTER_MULTI_ADDR_SET: + case BNEP_FILTER_NET_TYPE_SET: + /* Pull: ctrl type (1 b), len (2 b), data (len bytes) */ + if (!skb_pull(skb, 3 + *(u16 *)(skb->data + 1) * 2)) + goto badframe; + break; + default: + kfree_skb(skb); + return 0; + } + } else { + skb_reset_mac_header(skb); + + /* Verify and pull out header */ + if (!skb_pull(skb, __bnep_rx_hlen[type & BNEP_TYPE_MASK])) + goto badframe; + + s->eh.h_proto = get_unaligned((__be16 *) (skb->data - 2)); + } + + if (type & BNEP_EXT_HEADER) { + if (bnep_rx_extension(s, skb) < 0) + goto badframe; + } + + /* Strip 802.1p header */ + if (ntohs(s->eh.h_proto) == ETH_P_8021Q) { + if (!skb_pull(skb, 4)) + goto badframe; + s->eh.h_proto = get_unaligned((__be16 
*) (skb->data - 2)); + } + + /* We have to alloc new skb and copy data here :(. Because original skb + * may not be modified and because of the alignment requirements. */ + nskb = alloc_skb(2 + ETH_HLEN + skb->len, GFP_KERNEL); + if (!nskb) { + dev->stats.rx_dropped++; + kfree_skb(skb); + return -ENOMEM; + } + skb_reserve(nskb, 2); + + /* Decompress header and construct ether frame */ + switch (type & BNEP_TYPE_MASK) { + case BNEP_COMPRESSED: + __skb_put_data(nskb, &s->eh, ETH_HLEN); + break; + + case BNEP_COMPRESSED_SRC_ONLY: + __skb_put_data(nskb, s->eh.h_dest, ETH_ALEN); + __skb_put_data(nskb, skb_mac_header(skb), ETH_ALEN); + put_unaligned(s->eh.h_proto, (__be16 *) __skb_put(nskb, 2)); + break; + + case BNEP_COMPRESSED_DST_ONLY: + __skb_put_data(nskb, skb_mac_header(skb), ETH_ALEN); + __skb_put_data(nskb, s->eh.h_source, ETH_ALEN + 2); + break; + + case BNEP_GENERAL: + __skb_put_data(nskb, skb_mac_header(skb), ETH_ALEN * 2); + put_unaligned(s->eh.h_proto, (__be16 *) __skb_put(nskb, 2)); + break; + } + + skb_copy_from_linear_data(skb, __skb_put(nskb, skb->len), skb->len); + kfree_skb(skb); + + dev->stats.rx_packets++; + nskb->ip_summed = CHECKSUM_NONE; + nskb->protocol = eth_type_trans(nskb, dev); + netif_rx(nskb); + return 0; + +badframe: + dev->stats.rx_errors++; + kfree_skb(skb); + return 0; +} + +static u8 __bnep_tx_types[] = { + BNEP_GENERAL, + BNEP_COMPRESSED_SRC_ONLY, + BNEP_COMPRESSED_DST_ONLY, + BNEP_COMPRESSED +}; + +static int bnep_tx_frame(struct bnep_session *s, struct sk_buff *skb) +{ + struct ethhdr *eh = (void *) skb->data; + struct socket *sock = s->sock; + struct kvec iv[3]; + int len = 0, il = 0; + u8 type = 0; + + BT_DBG("skb %p dev %p type %u", skb, skb->dev, skb->pkt_type); + + if (!skb->dev) { + /* Control frame sent by us */ + goto send; + } + + iv[il++] = (struct kvec) { &type, 1 }; + len++; + + if (compress_src && ether_addr_equal(eh->h_dest, s->eh.h_source)) + type |= 0x01; + + if (compress_dst && ether_addr_equal(eh->h_source, s->eh.h_dest)) + type |= 0x02; + + if (type) + skb_pull(skb, ETH_ALEN * 2); + + type = __bnep_tx_types[type]; + switch (type) { + case BNEP_COMPRESSED_SRC_ONLY: + iv[il++] = (struct kvec) { eh->h_source, ETH_ALEN }; + len += ETH_ALEN; + break; + + case BNEP_COMPRESSED_DST_ONLY: + iv[il++] = (struct kvec) { eh->h_dest, ETH_ALEN }; + len += ETH_ALEN; + break; + } + +send: + iv[il++] = (struct kvec) { skb->data, skb->len }; + len += skb->len; + + /* FIXME: linearize skb */ + { + len = kernel_sendmsg(sock, &s->msg, iv, il, len); + } + kfree_skb(skb); + + if (len > 0) { + s->dev->stats.tx_bytes += len; + s->dev->stats.tx_packets++; + return 0; + } + + return len; +} + +static int bnep_session(void *arg) +{ + struct bnep_session *s = arg; + struct net_device *dev = s->dev; + struct sock *sk = s->sock->sk; + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + BT_DBG(""); + + set_user_nice(current, -15); + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (atomic_read(&s->terminate)) + break; + /* RX */ + while ((skb = skb_dequeue(&sk->sk_receive_queue))) { + skb_orphan(skb); + if (!skb_linearize(skb)) + bnep_rx_frame(s, skb); + else + kfree_skb(skb); + } + + if (sk->sk_state != BT_CONNECTED) + break; + + /* TX */ + while ((skb = skb_dequeue(&sk->sk_write_queue))) + if (bnep_tx_frame(s, skb)) + break; + netif_wake_queue(dev); + + /* + * wait_woken() performs the necessary memory barriers + * for us; see the header comment for this primitive. 
+ */ + wait_woken(&wait, TASK_INTERRUPTIBLE, MAX_SCHEDULE_TIMEOUT); + } + remove_wait_queue(sk_sleep(sk), &wait); + + /* Cleanup session */ + down_write(&bnep_session_sem); + + /* Delete network device */ + unregister_netdev(dev); + + /* Wakeup user-space polling for socket errors */ + s->sock->sk->sk_err = EUNATCH; + + wake_up_interruptible(sk_sleep(s->sock->sk)); + + /* Release the socket */ + fput(s->sock->file); + + __bnep_unlink_session(s); + + up_write(&bnep_session_sem); + free_netdev(dev); + module_put_and_kthread_exit(0); + return 0; +} + +static struct device *bnep_get_device(struct bnep_session *session) +{ + struct l2cap_conn *conn = l2cap_pi(session->sock->sk)->chan->conn; + + if (!conn || !conn->hcon) + return NULL; + + return &conn->hcon->dev; +} + +static struct device_type bnep_type = { + .name = "bluetooth", +}; + +int bnep_add_connection(struct bnep_connadd_req *req, struct socket *sock) +{ + u32 valid_flags = BIT(BNEP_SETUP_RESPONSE); + struct net_device *dev; + struct bnep_session *s, *ss; + u8 dst[ETH_ALEN], src[ETH_ALEN]; + int err; + + BT_DBG(""); + + if (!l2cap_is_socket(sock)) + return -EBADFD; + + if (req->flags & ~valid_flags) + return -EINVAL; + + baswap((void *) dst, &l2cap_pi(sock->sk)->chan->dst); + baswap((void *) src, &l2cap_pi(sock->sk)->chan->src); + + /* session struct allocated as private part of net_device */ + dev = alloc_netdev(sizeof(struct bnep_session), + (*req->device) ? req->device : "bnep%d", + NET_NAME_UNKNOWN, + bnep_net_setup); + if (!dev) + return -ENOMEM; + + down_write(&bnep_session_sem); + + ss = __bnep_get_session(dst); + if (ss && ss->state == BT_CONNECTED) { + err = -EEXIST; + goto failed; + } + + s = netdev_priv(dev); + + /* This is rx header therefore addresses are swapped. + * ie. eh.h_dest is our local address. */ + memcpy(s->eh.h_dest, &src, ETH_ALEN); + memcpy(s->eh.h_source, &dst, ETH_ALEN); + eth_hw_addr_set(dev, s->eh.h_dest); + + s->dev = dev; + s->sock = sock; + s->role = req->role; + s->state = BT_CONNECTED; + s->flags = req->flags; + + s->msg.msg_flags = MSG_NOSIGNAL; + +#ifdef CONFIG_BT_BNEP_MC_FILTER + /* Set default mc filter to not filter out any mc addresses + * as defined in the BNEP specification (revision 0.95a) + * http://grouper.ieee.org/groups/802/15/Bluetooth/BNEP.pdf + */ + s->mc_filter = ~0LL; +#endif + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER + /* Set default protocol filter */ + bnep_set_default_proto_filter(s); +#endif + + SET_NETDEV_DEV(dev, bnep_get_device(s)); + SET_NETDEV_DEVTYPE(dev, &bnep_type); + + err = register_netdev(dev); + if (err) + goto failed; + + __bnep_link_session(s); + + __module_get(THIS_MODULE); + s->task = kthread_run(bnep_session, s, "kbnepd %s", dev->name); + if (IS_ERR(s->task)) { + /* Session thread start failed, gotta cleanup. 
*/ + module_put(THIS_MODULE); + unregister_netdev(dev); + __bnep_unlink_session(s); + err = PTR_ERR(s->task); + goto failed; + } + + up_write(&bnep_session_sem); + strcpy(req->device, dev->name); + return 0; + +failed: + up_write(&bnep_session_sem); + free_netdev(dev); + return err; +} + +int bnep_del_connection(struct bnep_conndel_req *req) +{ + u32 valid_flags = 0; + struct bnep_session *s; + int err = 0; + + BT_DBG(""); + + if (req->flags & ~valid_flags) + return -EINVAL; + + down_read(&bnep_session_sem); + + s = __bnep_get_session(req->dst); + if (s) { + atomic_inc(&s->terminate); + wake_up_interruptible(sk_sleep(s->sock->sk)); + } else + err = -ENOENT; + + up_read(&bnep_session_sem); + return err; +} + +static void __bnep_copy_ci(struct bnep_conninfo *ci, struct bnep_session *s) +{ + u32 valid_flags = BIT(BNEP_SETUP_RESPONSE); + + memset(ci, 0, sizeof(*ci)); + memcpy(ci->dst, s->eh.h_source, ETH_ALEN); + strcpy(ci->device, s->dev->name); + ci->flags = s->flags & valid_flags; + ci->state = s->state; + ci->role = s->role; +} + +int bnep_get_connlist(struct bnep_connlist_req *req) +{ + struct bnep_session *s; + int err = 0, n = 0; + + down_read(&bnep_session_sem); + + list_for_each_entry(s, &bnep_session_list, list) { + struct bnep_conninfo ci; + + __bnep_copy_ci(&ci, s); + + if (copy_to_user(req->ci, &ci, sizeof(ci))) { + err = -EFAULT; + break; + } + + if (++n >= req->cnum) + break; + + req->ci++; + } + req->cnum = n; + + up_read(&bnep_session_sem); + return err; +} + +int bnep_get_conninfo(struct bnep_conninfo *ci) +{ + struct bnep_session *s; + int err = 0; + + down_read(&bnep_session_sem); + + s = __bnep_get_session(ci->dst); + if (s) + __bnep_copy_ci(ci, s); + else + err = -ENOENT; + + up_read(&bnep_session_sem); + return err; +} + +static int __init bnep_init(void) +{ + char flt[50] = ""; + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER + strcat(flt, "protocol "); +#endif + +#ifdef CONFIG_BT_BNEP_MC_FILTER + strcat(flt, "multicast"); +#endif + + BT_INFO("BNEP (Ethernet Emulation) ver %s", VERSION); + if (flt[0]) + BT_INFO("BNEP filters: %s", flt); + + bnep_sock_init(); + return 0; +} + +static void __exit bnep_exit(void) +{ + bnep_sock_cleanup(); +} + +module_init(bnep_init); +module_exit(bnep_exit); + +module_param(compress_src, bool, 0644); +MODULE_PARM_DESC(compress_src, "Compress sources headers"); + +module_param(compress_dst, bool, 0644); +MODULE_PARM_DESC(compress_dst, "Compress destination headers"); + +MODULE_AUTHOR("Marcel Holtmann "); +MODULE_DESCRIPTION("Bluetooth BNEP ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("bt-proto-4"); diff --git a/net/bluetooth/bnep/netdev.c b/net/bluetooth/bnep/netdev.c new file mode 100644 index 000000000..cc1cff631 --- /dev/null +++ b/net/bluetooth/bnep/netdev.c @@ -0,0 +1,230 @@ +/* + BNEP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2001-2002 Inventel Systemes + Written 2001-2002 by + Clément Moreau + David Libault + + Copyright (C) 2002 Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. 
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#include +#include +#include + +#include "bnep.h" + +#define BNEP_TX_QUEUE_LEN 20 + +static int bnep_net_open(struct net_device *dev) +{ + netif_start_queue(dev); + return 0; +} + +static int bnep_net_close(struct net_device *dev) +{ + netif_stop_queue(dev); + return 0; +} + +static void bnep_net_set_mc_list(struct net_device *dev) +{ +#ifdef CONFIG_BT_BNEP_MC_FILTER + struct bnep_session *s = netdev_priv(dev); + struct sock *sk = s->sock->sk; + struct bnep_set_filter_req *r; + struct sk_buff *skb; + int size; + + BT_DBG("%s mc_count %d", dev->name, netdev_mc_count(dev)); + + size = sizeof(*r) + (BNEP_MAX_MULTICAST_FILTERS + 1) * ETH_ALEN * 2; + skb = alloc_skb(size, GFP_ATOMIC); + if (!skb) { + BT_ERR("%s Multicast list allocation failed", dev->name); + return; + } + + r = (void *) skb->data; + __skb_put(skb, sizeof(*r)); + + r->type = BNEP_CONTROL; + r->ctrl = BNEP_FILTER_MULTI_ADDR_SET; + + if (dev->flags & (IFF_PROMISC | IFF_ALLMULTI)) { + u8 start[ETH_ALEN] = { 0x01 }; + + /* Request all addresses */ + __skb_put_data(skb, start, ETH_ALEN); + __skb_put_data(skb, dev->broadcast, ETH_ALEN); + r->len = htons(ETH_ALEN * 2); + } else { + struct netdev_hw_addr *ha; + int i, len = skb->len; + + if (dev->flags & IFF_BROADCAST) { + __skb_put_data(skb, dev->broadcast, ETH_ALEN); + __skb_put_data(skb, dev->broadcast, ETH_ALEN); + } + + /* FIXME: We should group addresses here. */ + + i = 0; + netdev_for_each_mc_addr(ha, dev) { + if (i == BNEP_MAX_MULTICAST_FILTERS) + break; + __skb_put_data(skb, ha->addr, ETH_ALEN); + __skb_put_data(skb, ha->addr, ETH_ALEN); + + i++; + } + r->len = htons(skb->len - len); + } + + skb_queue_tail(&sk->sk_write_queue, skb); + wake_up_interruptible(sk_sleep(sk)); +#endif +} + +static int bnep_net_set_mac_addr(struct net_device *dev, void *arg) +{ + BT_DBG("%s", dev->name); + return 0; +} + +static void bnep_net_timeout(struct net_device *dev, unsigned int txqueue) +{ + BT_DBG("net_timeout"); + netif_wake_queue(dev); +} + +#ifdef CONFIG_BT_BNEP_MC_FILTER +static int bnep_net_mc_filter(struct sk_buff *skb, struct bnep_session *s) +{ + struct ethhdr *eh = (void *) skb->data; + + if ((eh->h_dest[0] & 1) && !test_bit(bnep_mc_hash(eh->h_dest), (ulong *) &s->mc_filter)) + return 1; + return 0; +} +#endif + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER +/* Determine ether protocol. Based on eth_type_trans. 
*/ +static u16 bnep_net_eth_proto(struct sk_buff *skb) +{ + struct ethhdr *eh = (void *) skb->data; + u16 proto = ntohs(eh->h_proto); + + if (proto >= ETH_P_802_3_MIN) + return proto; + + if (get_unaligned((__be16 *) skb->data) == htons(0xFFFF)) + return ETH_P_802_3; + + return ETH_P_802_2; +} + +static int bnep_net_proto_filter(struct sk_buff *skb, struct bnep_session *s) +{ + u16 proto = bnep_net_eth_proto(skb); + struct bnep_proto_filter *f = s->proto_filter; + int i; + + for (i = 0; i < BNEP_MAX_PROTO_FILTERS && f[i].end; i++) { + if (proto >= f[i].start && proto <= f[i].end) + return 0; + } + + BT_DBG("BNEP: filtered skb %p, proto 0x%.4x", skb, proto); + return 1; +} +#endif + +static netdev_tx_t bnep_net_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct bnep_session *s = netdev_priv(dev); + struct sock *sk = s->sock->sk; + + BT_DBG("skb %p, dev %p", skb, dev); + +#ifdef CONFIG_BT_BNEP_MC_FILTER + if (bnep_net_mc_filter(skb, s)) { + kfree_skb(skb); + return NETDEV_TX_OK; + } +#endif + +#ifdef CONFIG_BT_BNEP_PROTO_FILTER + if (bnep_net_proto_filter(skb, s)) { + kfree_skb(skb); + return NETDEV_TX_OK; + } +#endif + + /* + * We cannot send L2CAP packets from here as we are potentially in a bh. + * So we have to queue them and wake up session thread which is sleeping + * on the sk_sleep(sk). + */ + netif_trans_update(dev); + skb_queue_tail(&sk->sk_write_queue, skb); + wake_up_interruptible(sk_sleep(sk)); + + if (skb_queue_len(&sk->sk_write_queue) >= BNEP_TX_QUEUE_LEN) { + BT_DBG("tx queue is full"); + + /* Stop queuing. + * Session thread will do netif_wake_queue() */ + netif_stop_queue(dev); + } + + return NETDEV_TX_OK; +} + +static const struct net_device_ops bnep_netdev_ops = { + .ndo_open = bnep_net_open, + .ndo_stop = bnep_net_close, + .ndo_start_xmit = bnep_net_xmit, + .ndo_validate_addr = eth_validate_addr, + .ndo_set_rx_mode = bnep_net_set_mc_list, + .ndo_set_mac_address = bnep_net_set_mac_addr, + .ndo_tx_timeout = bnep_net_timeout, + +}; + +void bnep_net_setup(struct net_device *dev) +{ + + eth_broadcast_addr(dev->broadcast); + dev->addr_len = ETH_ALEN; + + ether_setup(dev); + dev->min_mtu = 0; + dev->max_mtu = ETH_MAX_MTU; + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->netdev_ops = &bnep_netdev_ops; + + dev->watchdog_timeo = HZ * 2; +} diff --git a/net/bluetooth/bnep/sock.c b/net/bluetooth/bnep/sock.c new file mode 100644 index 000000000..57d509d77 --- /dev/null +++ b/net/bluetooth/bnep/sock.c @@ -0,0 +1,268 @@ +/* + BNEP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2001-2002 Inventel Systemes + Written 2001-2002 by + David Libault + + Copyright (C) 2002 Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include +#include +#include + +#include "bnep.h" + +static struct bt_sock_list bnep_sk_list = { + .lock = __RW_LOCK_UNLOCKED(bnep_sk_list.lock) +}; + +static int bnep_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + BT_DBG("sock %p sk %p", sock, sk); + + if (!sk) + return 0; + + bt_sock_unlink(&bnep_sk_list, sk); + + sock_orphan(sk); + sock_put(sk); + return 0; +} + +static int do_bnep_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) +{ + struct bnep_connlist_req cl; + struct bnep_connadd_req ca; + struct bnep_conndel_req cd; + struct bnep_conninfo ci; + struct socket *nsock; + __u32 supp_feat = BIT(BNEP_SETUP_RESPONSE); + int err; + + BT_DBG("cmd %x arg %p", cmd, argp); + + switch (cmd) { + case BNEPCONNADD: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ca, argp, sizeof(ca))) + return -EFAULT; + + nsock = sockfd_lookup(ca.sock, &err); + if (!nsock) + return err; + + if (nsock->sk->sk_state != BT_CONNECTED) { + sockfd_put(nsock); + return -EBADFD; + } + ca.device[sizeof(ca.device)-1] = 0; + + err = bnep_add_connection(&ca, nsock); + if (!err) { + if (copy_to_user(argp, &ca, sizeof(ca))) + err = -EFAULT; + } else + sockfd_put(nsock); + + return err; + + case BNEPCONNDEL: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&cd, argp, sizeof(cd))) + return -EFAULT; + + return bnep_del_connection(&cd); + + case BNEPGETCONNLIST: + if (copy_from_user(&cl, argp, sizeof(cl))) + return -EFAULT; + + if (cl.cnum <= 0) + return -EINVAL; + + err = bnep_get_connlist(&cl); + if (!err && copy_to_user(argp, &cl, sizeof(cl))) + return -EFAULT; + + return err; + + case BNEPGETCONNINFO: + if (copy_from_user(&ci, argp, sizeof(ci))) + return -EFAULT; + + err = bnep_get_conninfo(&ci); + if (!err && copy_to_user(argp, &ci, sizeof(ci))) + return -EFAULT; + + return err; + + case BNEPGETSUPPFEAT: + if (copy_to_user(argp, &supp_feat, sizeof(supp_feat))) + return -EFAULT; + + return 0; + + default: + return -EINVAL; + } + + return 0; +} + +static int bnep_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_bnep_sock_ioctl(sock, cmd, (void __user *)arg); +} + +#ifdef CONFIG_COMPAT +static int bnep_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + if (cmd == BNEPGETCONNLIST) { + struct bnep_connlist_req cl; + unsigned __user *p = argp; + u32 uci; + int err; + + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) + return -EFAULT; + + cl.ci = compat_ptr(uci); + + if (cl.cnum <= 0) + return -EINVAL; + + err = bnep_get_connlist(&cl); + + if (!err && put_user(cl.cnum, p)) + err = -EFAULT; + + return err; + } + + return do_bnep_sock_ioctl(sock, cmd, argp); +} +#endif + +static const struct proto_ops bnep_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = bnep_sock_release, + .ioctl = bnep_sock_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = bnep_sock_compat_ioctl, +#endif + .bind = sock_no_bind, + .getname = sock_no_getname, + .sendmsg = sock_no_sendmsg, + .recvmsg = sock_no_recvmsg, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .mmap = sock_no_mmap +}; + +static struct proto bnep_proto = { + .name = "BNEP", + .owner = THIS_MODULE, + 
.obj_size = sizeof(struct bt_sock) +}; + +static int bnep_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + sk = sk_alloc(net, PF_BLUETOOTH, GFP_ATOMIC, &bnep_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + sock->ops = &bnep_sock_ops; + + sock->state = SS_UNCONNECTED; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = protocol; + sk->sk_state = BT_OPEN; + + bt_sock_link(&bnep_sk_list, sk); + return 0; +} + +static const struct net_proto_family bnep_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = bnep_sock_create +}; + +int __init bnep_sock_init(void) +{ + int err; + + err = proto_register(&bnep_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_BNEP, &bnep_sock_family_ops); + if (err < 0) { + BT_ERR("Can't register BNEP socket"); + goto error; + } + + err = bt_procfs_init(&init_net, "bnep", &bnep_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create BNEP proc file"); + bt_sock_unregister(BTPROTO_BNEP); + goto error; + } + + BT_INFO("BNEP socket layer initialized"); + + return 0; + +error: + proto_unregister(&bnep_proto); + return err; +} + +void __exit bnep_sock_cleanup(void) +{ + bt_procfs_cleanup(&init_net, "bnep"); + bt_sock_unregister(BTPROTO_BNEP); + proto_unregister(&bnep_proto); +} diff --git a/net/bluetooth/cmtp/Kconfig b/net/bluetooth/cmtp/Kconfig new file mode 100644 index 000000000..c8337786d --- /dev/null +++ b/net/bluetooth/cmtp/Kconfig @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-2.0-only +config BT_CMTP + tristate "CMTP protocol support" + depends on BT_BREDR && ISDN_CAPI + help + CMTP (CAPI Message Transport Protocol) is a transport layer + for CAPI messages. CMTP is required for the Bluetooth Common + ISDN Access Profile. + + Say Y here to compile CMTP support into the kernel or say M to + compile it as module (cmtp). + diff --git a/net/bluetooth/cmtp/Makefile b/net/bluetooth/cmtp/Makefile new file mode 100644 index 000000000..b2262ca97 --- /dev/null +++ b/net/bluetooth/cmtp/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Bluetooth CMTP layer +# + +obj-$(CONFIG_BT_CMTP) += cmtp.o + +cmtp-objs := core.o sock.o capi.o diff --git a/net/bluetooth/cmtp/capi.c b/net/bluetooth/cmtp/capi.c new file mode 100644 index 000000000..f3bedc3b6 --- /dev/null +++ b/net/bluetooth/cmtp/capi.c @@ -0,0 +1,595 @@ +/* + CMTP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002-2003 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "cmtp.h" + +#define CAPI_INTEROPERABILITY 0x20 + +#define CAPI_INTEROPERABILITY_REQ CAPICMD(CAPI_INTEROPERABILITY, CAPI_REQ) +#define CAPI_INTEROPERABILITY_CONF CAPICMD(CAPI_INTEROPERABILITY, CAPI_CONF) +#define CAPI_INTEROPERABILITY_IND CAPICMD(CAPI_INTEROPERABILITY, CAPI_IND) +#define CAPI_INTEROPERABILITY_RESP CAPICMD(CAPI_INTEROPERABILITY, CAPI_RESP) + +#define CAPI_INTEROPERABILITY_REQ_LEN (CAPI_MSG_BASELEN + 2) +#define CAPI_INTEROPERABILITY_CONF_LEN (CAPI_MSG_BASELEN + 4) +#define CAPI_INTEROPERABILITY_IND_LEN (CAPI_MSG_BASELEN + 2) +#define CAPI_INTEROPERABILITY_RESP_LEN (CAPI_MSG_BASELEN + 2) + +#define CAPI_FUNCTION_REGISTER 0 +#define CAPI_FUNCTION_RELEASE 1 +#define CAPI_FUNCTION_GET_PROFILE 2 +#define CAPI_FUNCTION_GET_MANUFACTURER 3 +#define CAPI_FUNCTION_GET_VERSION 4 +#define CAPI_FUNCTION_GET_SERIAL_NUMBER 5 +#define CAPI_FUNCTION_MANUFACTURER 6 +#define CAPI_FUNCTION_LOOPBACK 7 + + +#define CMTP_MSGNUM 1 +#define CMTP_APPLID 2 +#define CMTP_MAPPING 3 + +static struct cmtp_application *cmtp_application_add(struct cmtp_session *session, __u16 appl) +{ + struct cmtp_application *app = kzalloc(sizeof(*app), GFP_KERNEL); + + BT_DBG("session %p application %p appl %u", session, app, appl); + + if (!app) + return NULL; + + app->state = BT_OPEN; + app->appl = appl; + + list_add_tail(&app->list, &session->applications); + + return app; +} + +static void cmtp_application_del(struct cmtp_session *session, struct cmtp_application *app) +{ + BT_DBG("session %p application %p", session, app); + + if (app) { + list_del(&app->list); + kfree(app); + } +} + +static struct cmtp_application *cmtp_application_get(struct cmtp_session *session, int pattern, __u16 value) +{ + struct cmtp_application *app; + + list_for_each_entry(app, &session->applications, list) { + switch (pattern) { + case CMTP_MSGNUM: + if (app->msgnum == value) + return app; + break; + case CMTP_APPLID: + if (app->appl == value) + return app; + break; + case CMTP_MAPPING: + if (app->mapping == value) + return app; + break; + } + } + + return NULL; +} + +static int cmtp_msgnum_get(struct cmtp_session *session) +{ + session->msgnum++; + + if ((session->msgnum & 0xff) > 200) + session->msgnum = CMTP_INITIAL_MSGNUM + 1; + + return session->msgnum; +} + +static void cmtp_send_capimsg(struct cmtp_session *session, struct sk_buff *skb) +{ + struct cmtp_scb *scb = (void *) skb->cb; + + BT_DBG("session %p skb %p len %u", session, skb, skb->len); + + scb->id = -1; + scb->data = (CAPIMSG_COMMAND(skb->data) == CAPI_DATA_B3); + + skb_queue_tail(&session->transmit, skb); + + wake_up_interruptible(sk_sleep(session->sock->sk)); +} + +static void cmtp_send_interopmsg(struct cmtp_session *session, + __u8 subcmd, __u16 appl, __u16 msgnum, + __u16 function, unsigned char *buf, int len) +{ + struct sk_buff *skb; + unsigned char *s; + + BT_DBG("session %p subcmd 0x%02x appl %u msgnum %u", session, subcmd, appl, msgnum); + + skb = alloc_skb(CAPI_MSG_BASELEN + 6 + len, GFP_ATOMIC); + if (!skb) { + BT_ERR("Can't allocate memory for interoperability packet"); + return; + } + + s = skb_put(skb, CAPI_MSG_BASELEN + 6 + len); + + capimsg_setu16(s, 0, CAPI_MSG_BASELEN + 6 + 
len); + capimsg_setu16(s, 2, appl); + capimsg_setu8 (s, 4, CAPI_INTEROPERABILITY); + capimsg_setu8 (s, 5, subcmd); + capimsg_setu16(s, 6, msgnum); + + /* Interoperability selector (Bluetooth Device Management) */ + capimsg_setu16(s, 8, 0x0001); + + capimsg_setu8 (s, 10, 3 + len); + capimsg_setu16(s, 11, function); + capimsg_setu8 (s, 13, len); + + if (len > 0) + memcpy(s + 14, buf, len); + + cmtp_send_capimsg(session, skb); +} + +static void cmtp_recv_interopmsg(struct cmtp_session *session, struct sk_buff *skb) +{ + struct capi_ctr *ctrl = &session->ctrl; + struct cmtp_application *application; + __u16 appl, msgnum, func, info; + __u32 controller; + + BT_DBG("session %p skb %p len %u", session, skb, skb->len); + + switch (CAPIMSG_SUBCOMMAND(skb->data)) { + case CAPI_CONF: + if (skb->len < CAPI_MSG_BASELEN + 10) + break; + + func = CAPIMSG_U16(skb->data, CAPI_MSG_BASELEN + 5); + info = CAPIMSG_U16(skb->data, CAPI_MSG_BASELEN + 8); + + switch (func) { + case CAPI_FUNCTION_REGISTER: + msgnum = CAPIMSG_MSGID(skb->data); + + application = cmtp_application_get(session, CMTP_MSGNUM, msgnum); + if (application) { + application->state = BT_CONNECTED; + application->msgnum = 0; + application->mapping = CAPIMSG_APPID(skb->data); + wake_up_interruptible(&session->wait); + } + + break; + + case CAPI_FUNCTION_RELEASE: + appl = CAPIMSG_APPID(skb->data); + + application = cmtp_application_get(session, CMTP_MAPPING, appl); + if (application) { + application->state = BT_CLOSED; + application->msgnum = 0; + wake_up_interruptible(&session->wait); + } + + break; + + case CAPI_FUNCTION_GET_PROFILE: + if (skb->len < CAPI_MSG_BASELEN + 11 + sizeof(capi_profile)) + break; + + controller = CAPIMSG_U16(skb->data, CAPI_MSG_BASELEN + 11); + msgnum = CAPIMSG_MSGID(skb->data); + + if (!info && (msgnum == CMTP_INITIAL_MSGNUM)) { + session->ncontroller = controller; + wake_up_interruptible(&session->wait); + break; + } + + if (!info && ctrl) { + memcpy(&ctrl->profile, + skb->data + CAPI_MSG_BASELEN + 11, + sizeof(capi_profile)); + session->state = BT_CONNECTED; + capi_ctr_ready(ctrl); + } + + break; + + case CAPI_FUNCTION_GET_MANUFACTURER: + if (skb->len < CAPI_MSG_BASELEN + 15) + break; + + if (!info && ctrl) { + int len = min_t(uint, CAPI_MANUFACTURER_LEN, + skb->data[CAPI_MSG_BASELEN + 14]); + + memset(ctrl->manu, 0, CAPI_MANUFACTURER_LEN); + strncpy(ctrl->manu, + skb->data + CAPI_MSG_BASELEN + 15, len); + } + + break; + + case CAPI_FUNCTION_GET_VERSION: + if (skb->len < CAPI_MSG_BASELEN + 32) + break; + + if (!info && ctrl) { + ctrl->version.majorversion = CAPIMSG_U32(skb->data, CAPI_MSG_BASELEN + 16); + ctrl->version.minorversion = CAPIMSG_U32(skb->data, CAPI_MSG_BASELEN + 20); + ctrl->version.majormanuversion = CAPIMSG_U32(skb->data, CAPI_MSG_BASELEN + 24); + ctrl->version.minormanuversion = CAPIMSG_U32(skb->data, CAPI_MSG_BASELEN + 28); + } + + break; + + case CAPI_FUNCTION_GET_SERIAL_NUMBER: + if (skb->len < CAPI_MSG_BASELEN + 17) + break; + + if (!info && ctrl) { + int len = min_t(uint, CAPI_SERIAL_LEN, + skb->data[CAPI_MSG_BASELEN + 16]); + + memset(ctrl->serial, 0, CAPI_SERIAL_LEN); + strncpy(ctrl->serial, + skb->data + CAPI_MSG_BASELEN + 17, len); + } + + break; + } + + break; + + case CAPI_IND: + if (skb->len < CAPI_MSG_BASELEN + 6) + break; + + func = CAPIMSG_U16(skb->data, CAPI_MSG_BASELEN + 3); + + if (func == CAPI_FUNCTION_LOOPBACK) { + int len = min_t(uint, skb->len - CAPI_MSG_BASELEN - 6, + skb->data[CAPI_MSG_BASELEN + 5]); + appl = CAPIMSG_APPID(skb->data); + msgnum = CAPIMSG_MSGID(skb->data); + 
cmtp_send_interopmsg(session, CAPI_RESP, appl, msgnum, func, + skb->data + CAPI_MSG_BASELEN + 6, len); + } + + break; + } + + kfree_skb(skb); +} + +void cmtp_recv_capimsg(struct cmtp_session *session, struct sk_buff *skb) +{ + struct capi_ctr *ctrl = &session->ctrl; + struct cmtp_application *application; + __u16 appl; + __u32 contr; + + BT_DBG("session %p skb %p len %u", session, skb, skb->len); + + if (skb->len < CAPI_MSG_BASELEN) + return; + + if (CAPIMSG_COMMAND(skb->data) == CAPI_INTEROPERABILITY) { + cmtp_recv_interopmsg(session, skb); + return; + } + + if (session->flags & BIT(CMTP_LOOPBACK)) { + kfree_skb(skb); + return; + } + + appl = CAPIMSG_APPID(skb->data); + contr = CAPIMSG_CONTROL(skb->data); + + application = cmtp_application_get(session, CMTP_MAPPING, appl); + if (application) { + appl = application->appl; + CAPIMSG_SETAPPID(skb->data, appl); + } else { + BT_ERR("Can't find application with id %u", appl); + kfree_skb(skb); + return; + } + + if ((contr & 0x7f) == 0x01) { + contr = (contr & 0xffffff80) | session->num; + CAPIMSG_SETCONTROL(skb->data, contr); + } + + capi_ctr_handle_message(ctrl, appl, skb); +} + +static int cmtp_load_firmware(struct capi_ctr *ctrl, capiloaddata *data) +{ + BT_DBG("ctrl %p data %p", ctrl, data); + + return 0; +} + +static void cmtp_reset_ctr(struct capi_ctr *ctrl) +{ + struct cmtp_session *session = ctrl->driverdata; + + BT_DBG("ctrl %p", ctrl); + + capi_ctr_down(ctrl); + + atomic_inc(&session->terminate); + wake_up_process(session->task); +} + +static void cmtp_register_appl(struct capi_ctr *ctrl, __u16 appl, capi_register_params *rp) +{ + DECLARE_WAITQUEUE(wait, current); + struct cmtp_session *session = ctrl->driverdata; + struct cmtp_application *application; + unsigned long timeo = CMTP_INTEROP_TIMEOUT; + unsigned char buf[8]; + int err = 0, nconn, want = rp->level3cnt; + + BT_DBG("ctrl %p appl %u level3cnt %u datablkcnt %u datablklen %u", + ctrl, appl, rp->level3cnt, rp->datablkcnt, rp->datablklen); + + application = cmtp_application_add(session, appl); + if (!application) { + BT_ERR("Can't allocate memory for new application"); + return; + } + + if (want < 0) + nconn = ctrl->profile.nbchannel * -want; + else + nconn = want; + + if (nconn == 0) + nconn = ctrl->profile.nbchannel; + + capimsg_setu16(buf, 0, nconn); + capimsg_setu16(buf, 2, rp->datablkcnt); + capimsg_setu16(buf, 4, rp->datablklen); + + application->state = BT_CONFIG; + application->msgnum = cmtp_msgnum_get(session); + + cmtp_send_interopmsg(session, CAPI_REQ, 0x0000, application->msgnum, + CAPI_FUNCTION_REGISTER, buf, 6); + + add_wait_queue(&session->wait, &wait); + while (1) { + set_current_state(TASK_INTERRUPTIBLE); + + if (!timeo) { + err = -EAGAIN; + break; + } + + if (application->state == BT_CLOSED) { + err = -application->err; + break; + } + + if (application->state == BT_CONNECTED) + break; + + if (signal_pending(current)) { + err = -EINTR; + break; + } + + timeo = schedule_timeout(timeo); + } + set_current_state(TASK_RUNNING); + remove_wait_queue(&session->wait, &wait); + + if (err) { + cmtp_application_del(session, application); + return; + } +} + +static void cmtp_release_appl(struct capi_ctr *ctrl, __u16 appl) +{ + struct cmtp_session *session = ctrl->driverdata; + struct cmtp_application *application; + + BT_DBG("ctrl %p appl %u", ctrl, appl); + + application = cmtp_application_get(session, CMTP_APPLID, appl); + if (!application) { + BT_ERR("Can't find application"); + return; + } + + application->msgnum = cmtp_msgnum_get(session); + + 
cmtp_send_interopmsg(session, CAPI_REQ, application->mapping, application->msgnum, + CAPI_FUNCTION_RELEASE, NULL, 0); + + wait_event_interruptible_timeout(session->wait, + (application->state == BT_CLOSED), CMTP_INTEROP_TIMEOUT); + + cmtp_application_del(session, application); +} + +static u16 cmtp_send_message(struct capi_ctr *ctrl, struct sk_buff *skb) +{ + struct cmtp_session *session = ctrl->driverdata; + struct cmtp_application *application; + __u16 appl; + __u32 contr; + + BT_DBG("ctrl %p skb %p", ctrl, skb); + + appl = CAPIMSG_APPID(skb->data); + contr = CAPIMSG_CONTROL(skb->data); + + application = cmtp_application_get(session, CMTP_APPLID, appl); + if ((!application) || (application->state != BT_CONNECTED)) { + BT_ERR("Can't find application with id %u", appl); + return CAPI_ILLAPPNR; + } + + CAPIMSG_SETAPPID(skb->data, application->mapping); + + if ((contr & 0x7f) == session->num) { + contr = (contr & 0xffffff80) | 0x01; + CAPIMSG_SETCONTROL(skb->data, contr); + } + + cmtp_send_capimsg(session, skb); + + return CAPI_NOERROR; +} + +static char *cmtp_procinfo(struct capi_ctr *ctrl) +{ + return "CAPI Message Transport Protocol"; +} + +static int cmtp_proc_show(struct seq_file *m, void *v) +{ + struct capi_ctr *ctrl = m->private; + struct cmtp_session *session = ctrl->driverdata; + struct cmtp_application *app; + + seq_printf(m, "%s\n\n", cmtp_procinfo(ctrl)); + seq_printf(m, "addr %s\n", session->name); + seq_printf(m, "ctrl %d\n", session->num); + + list_for_each_entry(app, &session->applications, list) { + seq_printf(m, "appl %u -> %u\n", app->appl, app->mapping); + } + + return 0; +} + +int cmtp_attach_device(struct cmtp_session *session) +{ + unsigned char buf[4]; + long ret; + + BT_DBG("session %p", session); + + capimsg_setu32(buf, 0, 0); + + cmtp_send_interopmsg(session, CAPI_REQ, 0xffff, CMTP_INITIAL_MSGNUM, + CAPI_FUNCTION_GET_PROFILE, buf, 4); + + ret = wait_event_interruptible_timeout(session->wait, + session->ncontroller, CMTP_INTEROP_TIMEOUT); + + BT_INFO("Found %d CAPI controller(s) on device %s", session->ncontroller, session->name); + + if (!ret) + return -ETIMEDOUT; + + if (!session->ncontroller) + return -ENODEV; + + if (session->ncontroller > 1) + BT_INFO("Setting up only CAPI controller 1"); + + session->ctrl.owner = THIS_MODULE; + session->ctrl.driverdata = session; + strcpy(session->ctrl.name, session->name); + + session->ctrl.driver_name = "cmtp"; + session->ctrl.load_firmware = cmtp_load_firmware; + session->ctrl.reset_ctr = cmtp_reset_ctr; + session->ctrl.register_appl = cmtp_register_appl; + session->ctrl.release_appl = cmtp_release_appl; + session->ctrl.send_message = cmtp_send_message; + + session->ctrl.procinfo = cmtp_procinfo; + session->ctrl.proc_show = cmtp_proc_show; + + if (attach_capi_ctr(&session->ctrl) < 0) { + BT_ERR("Can't attach new controller"); + return -EBUSY; + } + + session->num = session->ctrl.cnr; + + BT_DBG("session %p num %d", session, session->num); + + capimsg_setu32(buf, 0, 1); + + cmtp_send_interopmsg(session, CAPI_REQ, 0xffff, cmtp_msgnum_get(session), + CAPI_FUNCTION_GET_MANUFACTURER, buf, 4); + + cmtp_send_interopmsg(session, CAPI_REQ, 0xffff, cmtp_msgnum_get(session), + CAPI_FUNCTION_GET_VERSION, buf, 4); + + cmtp_send_interopmsg(session, CAPI_REQ, 0xffff, cmtp_msgnum_get(session), + CAPI_FUNCTION_GET_SERIAL_NUMBER, buf, 4); + + cmtp_send_interopmsg(session, CAPI_REQ, 0xffff, cmtp_msgnum_get(session), + CAPI_FUNCTION_GET_PROFILE, buf, 4); + + return 0; +} + +void cmtp_detach_device(struct cmtp_session *session) +{ + 
BT_DBG("session %p", session); + + detach_capi_ctr(&session->ctrl); +} diff --git a/net/bluetooth/cmtp/cmtp.h b/net/bluetooth/cmtp/cmtp.h new file mode 100644 index 000000000..f6b9dc4e4 --- /dev/null +++ b/net/bluetooth/cmtp/cmtp.h @@ -0,0 +1,129 @@ +/* + CMTP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002-2003 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#ifndef __CMTP_H +#define __CMTP_H + +#include +#include + +#define BTNAMSIZ 21 + +/* CMTP ioctl defines */ +#define CMTPCONNADD _IOW('C', 200, int) +#define CMTPCONNDEL _IOW('C', 201, int) +#define CMTPGETCONNLIST _IOR('C', 210, int) +#define CMTPGETCONNINFO _IOR('C', 211, int) + +#define CMTP_LOOPBACK 0 + +struct cmtp_connadd_req { + int sock; /* Connected socket */ + __u32 flags; +}; + +struct cmtp_conndel_req { + bdaddr_t bdaddr; + __u32 flags; +}; + +struct cmtp_conninfo { + bdaddr_t bdaddr; + __u32 flags; + __u16 state; + int num; +}; + +struct cmtp_connlist_req { + __u32 cnum; + struct cmtp_conninfo __user *ci; +}; + +int cmtp_add_connection(struct cmtp_connadd_req *req, struct socket *sock); +int cmtp_del_connection(struct cmtp_conndel_req *req); +int cmtp_get_connlist(struct cmtp_connlist_req *req); +int cmtp_get_conninfo(struct cmtp_conninfo *ci); + +/* CMTP session defines */ +#define CMTP_INTEROP_TIMEOUT (HZ * 5) +#define CMTP_INITIAL_MSGNUM 0xff00 + +struct cmtp_session { + struct list_head list; + + struct socket *sock; + + bdaddr_t bdaddr; + + unsigned long state; + unsigned long flags; + + uint mtu; + + char name[BTNAMSIZ]; + + atomic_t terminate; + struct task_struct *task; + + wait_queue_head_t wait; + + int ncontroller; + int num; + struct capi_ctr ctrl; + + struct list_head applications; + + unsigned long blockids; + int msgnum; + + struct sk_buff_head transmit; + + struct sk_buff *reassembly[16]; +}; + +struct cmtp_application { + struct list_head list; + + unsigned long state; + int err; + + __u16 appl; + __u16 mapping; + + __u16 msgnum; +}; + +struct cmtp_scb { + int id; + int data; +}; + +int cmtp_attach_device(struct cmtp_session *session); +void cmtp_detach_device(struct cmtp_session *session); + +void cmtp_recv_capimsg(struct cmtp_session *session, struct sk_buff *skb); + +/* CMTP init defines */ +int cmtp_init_sockets(void); +void cmtp_cleanup_sockets(void); + +#endif /* __CMTP_H */ diff --git a/net/bluetooth/cmtp/core.c b/net/bluetooth/cmtp/core.c new file mode 100644 index 000000000..90d130588 --- /dev/null +++ b/net/bluetooth/cmtp/core.c @@ -0,0 +1,519 @@ +/* + CMTP implementation for Linux Bluetooth stack (BlueZ). 
+ Copyright (C) 2002-2003 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include "cmtp.h" + +#define VERSION "1.0" + +static DECLARE_RWSEM(cmtp_session_sem); +static LIST_HEAD(cmtp_session_list); + +static struct cmtp_session *__cmtp_get_session(bdaddr_t *bdaddr) +{ + struct cmtp_session *session; + + BT_DBG(""); + + list_for_each_entry(session, &cmtp_session_list, list) + if (!bacmp(bdaddr, &session->bdaddr)) + return session; + + return NULL; +} + +static void __cmtp_link_session(struct cmtp_session *session) +{ + list_add(&session->list, &cmtp_session_list); +} + +static void __cmtp_unlink_session(struct cmtp_session *session) +{ + list_del(&session->list); +} + +static void __cmtp_copy_session(struct cmtp_session *session, struct cmtp_conninfo *ci) +{ + u32 valid_flags = BIT(CMTP_LOOPBACK); + memset(ci, 0, sizeof(*ci)); + bacpy(&ci->bdaddr, &session->bdaddr); + + ci->flags = session->flags & valid_flags; + ci->state = session->state; + + ci->num = session->num; +} + + +static inline int cmtp_alloc_block_id(struct cmtp_session *session) +{ + int i, id = -1; + + for (i = 0; i < 16; i++) + if (!test_and_set_bit(i, &session->blockids)) { + id = i; + break; + } + + return id; +} + +static inline void cmtp_free_block_id(struct cmtp_session *session, int id) +{ + clear_bit(id, &session->blockids); +} + +static inline void cmtp_add_msgpart(struct cmtp_session *session, int id, const unsigned char *buf, int count) +{ + struct sk_buff *skb = session->reassembly[id], *nskb; + int size; + + BT_DBG("session %p buf %p count %d", session, buf, count); + + size = (skb) ? 
skb->len + count : count; + + nskb = alloc_skb(size, GFP_ATOMIC); + if (!nskb) { + BT_ERR("Can't allocate memory for CAPI message"); + return; + } + + if (skb && (skb->len > 0)) + skb_copy_from_linear_data(skb, skb_put(nskb, skb->len), skb->len); + + skb_put_data(nskb, buf, count); + + session->reassembly[id] = nskb; + + kfree_skb(skb); +} + +static inline int cmtp_recv_frame(struct cmtp_session *session, struct sk_buff *skb) +{ + __u8 hdr, hdrlen, id; + __u16 len; + + BT_DBG("session %p skb %p len %d", session, skb, skb->len); + + while (skb->len > 0) { + hdr = skb->data[0]; + + switch (hdr & 0xc0) { + case 0x40: + hdrlen = 2; + len = skb->data[1]; + break; + case 0x80: + hdrlen = 3; + len = skb->data[1] | (skb->data[2] << 8); + break; + default: + hdrlen = 1; + len = 0; + break; + } + + id = (hdr & 0x3c) >> 2; + + BT_DBG("hdr 0x%02x hdrlen %d len %d id %d", hdr, hdrlen, len, id); + + if (hdrlen + len > skb->len) { + BT_ERR("Wrong size or header information in CMTP frame"); + break; + } + + if (len == 0) { + skb_pull(skb, hdrlen); + continue; + } + + switch (hdr & 0x03) { + case 0x00: + cmtp_add_msgpart(session, id, skb->data + hdrlen, len); + cmtp_recv_capimsg(session, session->reassembly[id]); + session->reassembly[id] = NULL; + break; + case 0x01: + cmtp_add_msgpart(session, id, skb->data + hdrlen, len); + break; + default: + kfree_skb(session->reassembly[id]); + session->reassembly[id] = NULL; + break; + } + + skb_pull(skb, hdrlen + len); + } + + kfree_skb(skb); + return 0; +} + +static int cmtp_send_frame(struct cmtp_session *session, unsigned char *data, int len) +{ + struct socket *sock = session->sock; + struct kvec iv = { data, len }; + struct msghdr msg; + + BT_DBG("session %p data %p len %d", session, data, len); + + if (!len) + return 0; + + memset(&msg, 0, sizeof(msg)); + + return kernel_sendmsg(sock, &msg, &iv, 1, len); +} + +static void cmtp_process_transmit(struct cmtp_session *session) +{ + struct sk_buff *skb, *nskb; + unsigned char *hdr; + unsigned int size, tail; + + BT_DBG("session %p", session); + + nskb = alloc_skb(session->mtu, GFP_ATOMIC); + if (!nskb) { + BT_ERR("Can't allocate memory for new frame"); + return; + } + + while ((skb = skb_dequeue(&session->transmit))) { + struct cmtp_scb *scb = (void *) skb->cb; + + tail = session->mtu - nskb->len; + if (tail < 5) { + cmtp_send_frame(session, nskb->data, nskb->len); + skb_trim(nskb, 0); + tail = session->mtu; + } + + size = min_t(uint, ((tail < 258) ? (tail - 2) : (tail - 3)), skb->len); + + if (scb->id < 0) { + scb->id = cmtp_alloc_block_id(session); + if (scb->id < 0) { + skb_queue_head(&session->transmit, skb); + break; + } + } + + if (size < 256) { + hdr = skb_put(nskb, 2); + hdr[0] = 0x40 + | ((scb->id << 2) & 0x3c) + | ((skb->len == size) ? 0x00 : 0x01); + hdr[1] = size; + } else { + hdr = skb_put(nskb, 3); + hdr[0] = 0x80 + | ((scb->id << 2) & 0x3c) + | ((skb->len == size) ? 
0x00 : 0x01); + hdr[1] = size & 0xff; + hdr[2] = size >> 8; + } + + skb_copy_from_linear_data(skb, skb_put(nskb, size), size); + skb_pull(skb, size); + + if (skb->len > 0) { + skb_queue_head(&session->transmit, skb); + } else { + cmtp_free_block_id(session, scb->id); + if (scb->data) { + cmtp_send_frame(session, nskb->data, nskb->len); + skb_trim(nskb, 0); + } + kfree_skb(skb); + } + } + + cmtp_send_frame(session, nskb->data, nskb->len); + + kfree_skb(nskb); +} + +static int cmtp_session(void *arg) +{ + struct cmtp_session *session = arg; + struct sock *sk = session->sock->sk; + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + BT_DBG("session %p", session); + + set_user_nice(current, -15); + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (atomic_read(&session->terminate)) + break; + if (sk->sk_state != BT_CONNECTED) + break; + + while ((skb = skb_dequeue(&sk->sk_receive_queue))) { + skb_orphan(skb); + if (!skb_linearize(skb)) + cmtp_recv_frame(session, skb); + else + kfree_skb(skb); + } + + cmtp_process_transmit(session); + + /* + * wait_woken() performs the necessary memory barriers + * for us; see the header comment for this primitive. + */ + wait_woken(&wait, TASK_INTERRUPTIBLE, MAX_SCHEDULE_TIMEOUT); + } + remove_wait_queue(sk_sleep(sk), &wait); + + down_write(&cmtp_session_sem); + + if (!(session->flags & BIT(CMTP_LOOPBACK))) + cmtp_detach_device(session); + + fput(session->sock->file); + + __cmtp_unlink_session(session); + + up_write(&cmtp_session_sem); + + kfree(session); + module_put_and_kthread_exit(0); + return 0; +} + +int cmtp_add_connection(struct cmtp_connadd_req *req, struct socket *sock) +{ + u32 valid_flags = BIT(CMTP_LOOPBACK); + struct cmtp_session *session, *s; + int i, err; + + BT_DBG(""); + + if (!l2cap_is_socket(sock)) + return -EBADFD; + + if (req->flags & ~valid_flags) + return -EINVAL; + + session = kzalloc(sizeof(struct cmtp_session), GFP_KERNEL); + if (!session) + return -ENOMEM; + + down_write(&cmtp_session_sem); + + s = __cmtp_get_session(&l2cap_pi(sock->sk)->chan->dst); + if (s && s->state == BT_CONNECTED) { + err = -EEXIST; + goto failed; + } + + bacpy(&session->bdaddr, &l2cap_pi(sock->sk)->chan->dst); + + session->mtu = min_t(uint, l2cap_pi(sock->sk)->chan->omtu, + l2cap_pi(sock->sk)->chan->imtu); + + BT_DBG("mtu %d", session->mtu); + + sprintf(session->name, "%pMR", &session->bdaddr); + + session->sock = sock; + session->state = BT_CONFIG; + + init_waitqueue_head(&session->wait); + + session->msgnum = CMTP_INITIAL_MSGNUM; + + INIT_LIST_HEAD(&session->applications); + + skb_queue_head_init(&session->transmit); + + for (i = 0; i < 16; i++) + session->reassembly[i] = NULL; + + session->flags = req->flags; + + __cmtp_link_session(session); + + __module_get(THIS_MODULE); + session->task = kthread_run(cmtp_session, session, "kcmtpd_ctr_%d", + session->num); + if (IS_ERR(session->task)) { + module_put(THIS_MODULE); + err = PTR_ERR(session->task); + goto unlink; + } + + if (!(session->flags & BIT(CMTP_LOOPBACK))) { + err = cmtp_attach_device(session); + if (err < 0) { + /* Caller will call fput in case of failure, and so + * will cmtp_session kthread. 
+ */ + get_file(session->sock->file); + + atomic_inc(&session->terminate); + wake_up_interruptible(sk_sleep(session->sock->sk)); + up_write(&cmtp_session_sem); + return err; + } + } + + up_write(&cmtp_session_sem); + return 0; + +unlink: + __cmtp_unlink_session(session); + +failed: + up_write(&cmtp_session_sem); + kfree(session); + return err; +} + +int cmtp_del_connection(struct cmtp_conndel_req *req) +{ + u32 valid_flags = 0; + struct cmtp_session *session; + int err = 0; + + BT_DBG(""); + + if (req->flags & ~valid_flags) + return -EINVAL; + + down_read(&cmtp_session_sem); + + session = __cmtp_get_session(&req->bdaddr); + if (session) { + /* Flush the transmit queue */ + skb_queue_purge(&session->transmit); + + /* Stop session thread */ + atomic_inc(&session->terminate); + + /* + * See the comment preceding the call to wait_woken() + * in cmtp_session(). + */ + wake_up_interruptible(sk_sleep(session->sock->sk)); + } else + err = -ENOENT; + + up_read(&cmtp_session_sem); + return err; +} + +int cmtp_get_connlist(struct cmtp_connlist_req *req) +{ + struct cmtp_session *session; + int err = 0, n = 0; + + BT_DBG(""); + + down_read(&cmtp_session_sem); + + list_for_each_entry(session, &cmtp_session_list, list) { + struct cmtp_conninfo ci; + + __cmtp_copy_session(session, &ci); + + if (copy_to_user(req->ci, &ci, sizeof(ci))) { + err = -EFAULT; + break; + } + + if (++n >= req->cnum) + break; + + req->ci++; + } + req->cnum = n; + + up_read(&cmtp_session_sem); + return err; +} + +int cmtp_get_conninfo(struct cmtp_conninfo *ci) +{ + struct cmtp_session *session; + int err = 0; + + down_read(&cmtp_session_sem); + + session = __cmtp_get_session(&ci->bdaddr); + if (session) + __cmtp_copy_session(session, ci); + else + err = -ENOENT; + + up_read(&cmtp_session_sem); + return err; +} + + +static int __init cmtp_init(void) +{ + BT_INFO("CMTP (CAPI Emulation) ver %s", VERSION); + + return cmtp_init_sockets(); +} + +static void __exit cmtp_exit(void) +{ + cmtp_cleanup_sockets(); +} + +module_init(cmtp_init); +module_exit(cmtp_exit); + +MODULE_AUTHOR("Marcel Holtmann "); +MODULE_DESCRIPTION("Bluetooth CMTP ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("bt-proto-5"); diff --git a/net/bluetooth/cmtp/sock.c b/net/bluetooth/cmtp/sock.c new file mode 100644 index 000000000..96d49d9fa --- /dev/null +++ b/net/bluetooth/cmtp/sock.c @@ -0,0 +1,271 @@ +/* + CMTP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002-2003 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +#include "cmtp.h" + +static struct bt_sock_list cmtp_sk_list = { + .lock = __RW_LOCK_UNLOCKED(cmtp_sk_list.lock) +}; + +static int cmtp_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + BT_DBG("sock %p sk %p", sock, sk); + + if (!sk) + return 0; + + bt_sock_unlink(&cmtp_sk_list, sk); + + sock_orphan(sk); + sock_put(sk); + + return 0; +} + +static int do_cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) +{ + struct cmtp_connadd_req ca; + struct cmtp_conndel_req cd; + struct cmtp_connlist_req cl; + struct cmtp_conninfo ci; + struct socket *nsock; + int err; + + BT_DBG("cmd %x arg %p", cmd, argp); + + switch (cmd) { + case CMTPCONNADD: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ca, argp, sizeof(ca))) + return -EFAULT; + + nsock = sockfd_lookup(ca.sock, &err); + if (!nsock) + return err; + + if (nsock->sk->sk_state != BT_CONNECTED) { + sockfd_put(nsock); + return -EBADFD; + } + + err = cmtp_add_connection(&ca, nsock); + if (!err) { + if (copy_to_user(argp, &ca, sizeof(ca))) + err = -EFAULT; + } else + sockfd_put(nsock); + + return err; + + case CMTPCONNDEL: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&cd, argp, sizeof(cd))) + return -EFAULT; + + return cmtp_del_connection(&cd); + + case CMTPGETCONNLIST: + if (copy_from_user(&cl, argp, sizeof(cl))) + return -EFAULT; + + if (cl.cnum <= 0) + return -EINVAL; + + err = cmtp_get_connlist(&cl); + if (!err && copy_to_user(argp, &cl, sizeof(cl))) + return -EFAULT; + + return err; + + case CMTPGETCONNINFO: + if (copy_from_user(&ci, argp, sizeof(ci))) + return -EFAULT; + + err = cmtp_get_conninfo(&ci); + if (!err && copy_to_user(argp, &ci, sizeof(ci))) + return -EFAULT; + + return err; + } + + return -EINVAL; +} + +static int cmtp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_cmtp_sock_ioctl(sock, cmd, (void __user *)arg); +} + +#ifdef CONFIG_COMPAT +static int cmtp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + if (cmd == CMTPGETCONNLIST) { + struct cmtp_connlist_req cl; + u32 __user *p = argp; + u32 uci; + int err; + + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) + return -EFAULT; + + cl.ci = compat_ptr(uci); + + if (cl.cnum <= 0) + return -EINVAL; + + err = cmtp_get_connlist(&cl); + + if (!err && put_user(cl.cnum, p)) + err = -EFAULT; + + return err; + } + + return do_cmtp_sock_ioctl(sock, cmd, argp); +} +#endif + +static const struct proto_ops cmtp_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = cmtp_sock_release, + .ioctl = cmtp_sock_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = cmtp_sock_compat_ioctl, +#endif + .bind = sock_no_bind, + .getname = sock_no_getname, + .sendmsg = sock_no_sendmsg, + .recvmsg = sock_no_recvmsg, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .mmap = sock_no_mmap +}; + +static struct proto cmtp_proto = { + .name = "CMTP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct bt_sock) +}; + +static int cmtp_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + sk = 
sk_alloc(net, PF_BLUETOOTH, GFP_ATOMIC, &cmtp_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + sock->ops = &cmtp_sock_ops; + + sock->state = SS_UNCONNECTED; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = protocol; + sk->sk_state = BT_OPEN; + + bt_sock_link(&cmtp_sk_list, sk); + + return 0; +} + +static const struct net_proto_family cmtp_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = cmtp_sock_create +}; + +int cmtp_init_sockets(void) +{ + int err; + + err = proto_register(&cmtp_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_CMTP, &cmtp_sock_family_ops); + if (err < 0) { + BT_ERR("Can't register CMTP socket"); + goto error; + } + + err = bt_procfs_init(&init_net, "cmtp", &cmtp_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create CMTP proc file"); + bt_sock_unregister(BTPROTO_HIDP); + goto error; + } + + BT_INFO("CMTP socket layer initialized"); + + return 0; + +error: + proto_unregister(&cmtp_proto); + return err; +} + +void cmtp_cleanup_sockets(void) +{ + bt_procfs_cleanup(&init_net, "cmtp"); + bt_sock_unregister(BTPROTO_CMTP); + proto_unregister(&cmtp_proto); +} diff --git a/net/bluetooth/ecdh_helper.c b/net/bluetooth/ecdh_helper.c new file mode 100644 index 000000000..989401f11 --- /dev/null +++ b/net/bluetooth/ecdh_helper.c @@ -0,0 +1,228 @@ +/* + * ECDH helper functions - KPP wrappings + * + * Copyright (C) 2017 Intel Corporation + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation; + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + * IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + * CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + * ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + * COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + * SOFTWARE IS DISCLAIMED. + */ +#include "ecdh_helper.h" + +#include +#include + +struct ecdh_completion { + struct completion completion; + int err; +}; + +static void ecdh_complete(struct crypto_async_request *req, int err) +{ + struct ecdh_completion *res = req->data; + + if (err == -EINPROGRESS) + return; + + res->err = err; + complete(&res->completion); +} + +static inline void swap_digits(u64 *in, u64 *out, unsigned int ndigits) +{ + int i; + + for (i = 0; i < ndigits; i++) + out[i] = __swab64(in[ndigits - 1 - i]); +} + +/* compute_ecdh_secret() - function assumes that the private key was + * already set. + * @tfm: KPP tfm handle allocated with crypto_alloc_kpp(). + * @public_key: pair's ecc public key. + * secret: memory where the ecdh computed shared secret will be saved. + * + * Return: zero on success; error code in case of error. 
+ */ +int compute_ecdh_secret(struct crypto_kpp *tfm, const u8 public_key[64], + u8 secret[32]) +{ + struct kpp_request *req; + u8 *tmp; + struct ecdh_completion result; + struct scatterlist src, dst; + int err; + + tmp = kmalloc(64, GFP_KERNEL); + if (!tmp) + return -ENOMEM; + + req = kpp_request_alloc(tfm, GFP_KERNEL); + if (!req) { + err = -ENOMEM; + goto free_tmp; + } + + init_completion(&result.completion); + + swap_digits((u64 *)public_key, (u64 *)tmp, 4); /* x */ + swap_digits((u64 *)&public_key[32], (u64 *)&tmp[32], 4); /* y */ + + sg_init_one(&src, tmp, 64); + sg_init_one(&dst, secret, 32); + kpp_request_set_input(req, &src, 64); + kpp_request_set_output(req, &dst, 32); + kpp_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG, + ecdh_complete, &result); + err = crypto_kpp_compute_shared_secret(req); + if (err == -EINPROGRESS) { + wait_for_completion(&result.completion); + err = result.err; + } + if (err < 0) { + pr_err("alg: ecdh: compute shared secret failed. err %d\n", + err); + goto free_all; + } + + swap_digits((u64 *)secret, (u64 *)tmp, 4); + memcpy(secret, tmp, 32); + +free_all: + kpp_request_free(req); +free_tmp: + kfree_sensitive(tmp); + return err; +} + +/* set_ecdh_privkey() - set or generate ecc private key. + * + * Function generates an ecc private key in the crypto subsystem when receiving + * a NULL private key or sets the received key when not NULL. + * + * @tfm: KPP tfm handle allocated with crypto_alloc_kpp(). + * @private_key: user's ecc private key. When not NULL, the key is expected + * in little endian format. + * + * Return: zero on success; error code in case of error. + */ +int set_ecdh_privkey(struct crypto_kpp *tfm, const u8 private_key[32]) +{ + u8 *buf, *tmp = NULL; + unsigned int buf_len; + int err; + struct ecdh p = {0}; + + if (private_key) { + tmp = kmalloc(32, GFP_KERNEL); + if (!tmp) + return -ENOMEM; + swap_digits((u64 *)private_key, (u64 *)tmp, 4); + p.key = tmp; + p.key_size = 32; + } + + buf_len = crypto_ecdh_key_len(&p); + buf = kmalloc(buf_len, GFP_KERNEL); + if (!buf) { + err = -ENOMEM; + goto free_tmp; + } + + err = crypto_ecdh_encode_key(buf, buf_len, &p); + if (err) + goto free_all; + + err = crypto_kpp_set_secret(tfm, buf, buf_len); + /* fall through */ +free_all: + kfree_sensitive(buf); +free_tmp: + kfree_sensitive(tmp); + return err; +} + +/* generate_ecdh_public_key() - function assumes that the private key was + * already set. + * + * @tfm: KPP tfm handle allocated with crypto_alloc_kpp(). + * @public_key: memory where the computed ecc public key will be saved. + * + * Return: zero on success; error code in case of error. + */ +int generate_ecdh_public_key(struct crypto_kpp *tfm, u8 public_key[64]) +{ + struct kpp_request *req; + u8 *tmp; + struct ecdh_completion result; + struct scatterlist dst; + int err; + + tmp = kmalloc(64, GFP_KERNEL); + if (!tmp) + return -ENOMEM; + + req = kpp_request_alloc(tfm, GFP_KERNEL); + if (!req) { + err = -ENOMEM; + goto free_tmp; + } + + init_completion(&result.completion); + sg_init_one(&dst, tmp, 64); + kpp_request_set_input(req, NULL, 0); + kpp_request_set_output(req, &dst, 64); + kpp_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG, + ecdh_complete, &result); + + err = crypto_kpp_generate_public_key(req); + if (err == -EINPROGRESS) { + wait_for_completion(&result.completion); + err = result.err; + } + if (err < 0) + goto free_all; + + /* The public key is handed back in little endian as expected by + * the Security Manager Protocol. 
+ */ + swap_digits((u64 *)tmp, (u64 *)public_key, 4); /* x */ + swap_digits((u64 *)&tmp[32], (u64 *)&public_key[32], 4); /* y */ + +free_all: + kpp_request_free(req); +free_tmp: + kfree(tmp); + return err; +} + +/* generate_ecdh_keys() - generate ecc key pair. + * + * @tfm: KPP tfm handle allocated with crypto_alloc_kpp(). + * @public_key: memory where the computed ecc public key will be saved. + * + * Return: zero on success; error code in case of error. + */ +int generate_ecdh_keys(struct crypto_kpp *tfm, u8 public_key[64]) +{ + int err; + + err = set_ecdh_privkey(tfm, NULL); + if (err) + return err; + + return generate_ecdh_public_key(tfm, public_key); +} diff --git a/net/bluetooth/ecdh_helper.h b/net/bluetooth/ecdh_helper.h new file mode 100644 index 000000000..830723971 --- /dev/null +++ b/net/bluetooth/ecdh_helper.h @@ -0,0 +1,30 @@ +/* + * ECDH helper functions - KPP wrappings + * + * Copyright (C) 2017 Intel Corporation + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation; + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + * IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + * CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + * ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + * COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + * SOFTWARE IS DISCLAIMED. 
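Taken together, these helpers are typically driven in three steps: allocate a KPP transform, set or generate the private key, then derive the public key and/or the shared secret. A minimal sketch; the function name, the "ecdh-nist-p256" algorithm string and the remote_pk buffer are assumptions, modelled on how the SMP code consumes this API:

        static int example_ecdh_pairing(const u8 remote_pk[64], u8 dhkey[32])
        {
                struct crypto_kpp *tfm;
                u8 local_pk[64];
                int err;

                tfm = crypto_alloc_kpp("ecdh-nist-p256", 0, 0);
                if (IS_ERR(tfm))
                        return PTR_ERR(tfm);

                err = generate_ecdh_keys(tfm, local_pk); /* NULL private key: generate one */
                if (!err)
                        err = compute_ecdh_secret(tfm, remote_pk, dhkey);

                crypto_free_kpp(tfm);
                return err;
        }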
+ */ +#include +#include + +int compute_ecdh_secret(struct crypto_kpp *tfm, const u8 pair_public_key[64], + u8 secret[32]); +int set_ecdh_privkey(struct crypto_kpp *tfm, const u8 private_key[32]); +int generate_ecdh_public_key(struct crypto_kpp *tfm, u8 public_key[64]); +int generate_ecdh_keys(struct crypto_kpp *tfm, u8 public_key[64]); diff --git a/net/bluetooth/eir.c b/net/bluetooth/eir.c new file mode 100644 index 000000000..8a85f6cdf --- /dev/null +++ b/net/bluetooth/eir.c @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * BlueZ - Bluetooth protocol stack for Linux + * + * Copyright (C) 2021 Intel Corporation + */ + +#include +#include +#include + +#include "eir.h" + +#define PNP_INFO_SVCLASS_ID 0x1200 + +static u8 eir_append_name(u8 *eir, u16 eir_len, u8 type, u8 *data, u8 data_len) +{ + u8 name[HCI_MAX_SHORT_NAME_LENGTH + 1]; + + /* If data is already NULL terminated just pass it directly */ + if (data[data_len - 1] == '\0') + return eir_append_data(eir, eir_len, type, data, data_len); + + memcpy(name, data, HCI_MAX_SHORT_NAME_LENGTH); + name[HCI_MAX_SHORT_NAME_LENGTH] = '\0'; + + return eir_append_data(eir, eir_len, type, name, sizeof(name)); +} + +u8 eir_append_local_name(struct hci_dev *hdev, u8 *ptr, u8 ad_len) +{ + size_t short_len; + size_t complete_len; + + /* no space left for name (+ NULL + type + len) */ + if ((HCI_MAX_AD_LENGTH - ad_len) < HCI_MAX_SHORT_NAME_LENGTH + 3) + return ad_len; + + /* use complete name if present and fits */ + complete_len = strnlen(hdev->dev_name, sizeof(hdev->dev_name)); + if (complete_len && complete_len <= HCI_MAX_SHORT_NAME_LENGTH) + return eir_append_name(ptr, ad_len, EIR_NAME_COMPLETE, + hdev->dev_name, complete_len + 1); + + /* use short name if present */ + short_len = strnlen(hdev->short_name, sizeof(hdev->short_name)); + if (short_len) + return eir_append_name(ptr, ad_len, EIR_NAME_SHORT, + hdev->short_name, + short_len == HCI_MAX_SHORT_NAME_LENGTH ? 
+ short_len : short_len + 1); + + /* use shortened full name if present, we already know that name + * is longer then HCI_MAX_SHORT_NAME_LENGTH + */ + if (complete_len) + return eir_append_name(ptr, ad_len, EIR_NAME_SHORT, + hdev->dev_name, + HCI_MAX_SHORT_NAME_LENGTH); + + return ad_len; +} + +u8 eir_append_appearance(struct hci_dev *hdev, u8 *ptr, u8 ad_len) +{ + return eir_append_le16(ptr, ad_len, EIR_APPEARANCE, hdev->appearance); +} + +u8 eir_append_service_data(u8 *eir, u16 eir_len, u16 uuid, u8 *data, + u8 data_len) +{ + eir[eir_len++] = sizeof(u8) + sizeof(uuid) + data_len; + eir[eir_len++] = EIR_SERVICE_DATA; + put_unaligned_le16(uuid, &eir[eir_len]); + eir_len += sizeof(uuid); + memcpy(&eir[eir_len], data, data_len); + eir_len += data_len; + + return eir_len; +} + +static u8 *create_uuid16_list(struct hci_dev *hdev, u8 *data, ptrdiff_t len) +{ + u8 *ptr = data, *uuids_start = NULL; + struct bt_uuid *uuid; + + if (len < 4) + return ptr; + + list_for_each_entry(uuid, &hdev->uuids, list) { + u16 uuid16; + + if (uuid->size != 16) + continue; + + uuid16 = get_unaligned_le16(&uuid->uuid[12]); + if (uuid16 < 0x1100) + continue; + + if (uuid16 == PNP_INFO_SVCLASS_ID) + continue; + + if (!uuids_start) { + uuids_start = ptr; + uuids_start[0] = 1; + uuids_start[1] = EIR_UUID16_ALL; + ptr += 2; + } + + /* Stop if not enough space to put next UUID */ + if ((ptr - data) + sizeof(u16) > len) { + uuids_start[1] = EIR_UUID16_SOME; + break; + } + + *ptr++ = (uuid16 & 0x00ff); + *ptr++ = (uuid16 & 0xff00) >> 8; + uuids_start[0] += sizeof(uuid16); + } + + return ptr; +} + +static u8 *create_uuid32_list(struct hci_dev *hdev, u8 *data, ptrdiff_t len) +{ + u8 *ptr = data, *uuids_start = NULL; + struct bt_uuid *uuid; + + if (len < 6) + return ptr; + + list_for_each_entry(uuid, &hdev->uuids, list) { + if (uuid->size != 32) + continue; + + if (!uuids_start) { + uuids_start = ptr; + uuids_start[0] = 1; + uuids_start[1] = EIR_UUID32_ALL; + ptr += 2; + } + + /* Stop if not enough space to put next UUID */ + if ((ptr - data) + sizeof(u32) > len) { + uuids_start[1] = EIR_UUID32_SOME; + break; + } + + memcpy(ptr, &uuid->uuid[12], sizeof(u32)); + ptr += sizeof(u32); + uuids_start[0] += sizeof(u32); + } + + return ptr; +} + +static u8 *create_uuid128_list(struct hci_dev *hdev, u8 *data, ptrdiff_t len) +{ + u8 *ptr = data, *uuids_start = NULL; + struct bt_uuid *uuid; + + if (len < 18) + return ptr; + + list_for_each_entry(uuid, &hdev->uuids, list) { + if (uuid->size != 128) + continue; + + if (!uuids_start) { + uuids_start = ptr; + uuids_start[0] = 1; + uuids_start[1] = EIR_UUID128_ALL; + ptr += 2; + } + + /* Stop if not enough space to put next UUID */ + if ((ptr - data) + 16 > len) { + uuids_start[1] = EIR_UUID128_SOME; + break; + } + + memcpy(ptr, uuid->uuid, 16); + ptr += 16; + uuids_start[0] += 16; + } + + return ptr; +} + +void eir_create(struct hci_dev *hdev, u8 *data) +{ + u8 *ptr = data; + size_t name_len; + + name_len = strnlen(hdev->dev_name, sizeof(hdev->dev_name)); + + if (name_len > 0) { + /* EIR Data type */ + if (name_len > 48) { + name_len = 48; + ptr[1] = EIR_NAME_SHORT; + } else { + ptr[1] = EIR_NAME_COMPLETE; + } + + /* EIR Data length */ + ptr[0] = name_len + 1; + + memcpy(ptr + 2, hdev->dev_name, name_len); + + ptr += (name_len + 2); + } + + if (hdev->inq_tx_power != HCI_TX_POWER_INVALID) { + ptr[0] = 2; + ptr[1] = EIR_TX_POWER; + ptr[2] = (u8)hdev->inq_tx_power; + + ptr += 3; + } + + if (hdev->devid_source > 0) { + ptr[0] = 9; + ptr[1] = EIR_DEVICE_ID; + + 
put_unaligned_le16(hdev->devid_source, ptr + 2); + put_unaligned_le16(hdev->devid_vendor, ptr + 4); + put_unaligned_le16(hdev->devid_product, ptr + 6); + put_unaligned_le16(hdev->devid_version, ptr + 8); + + ptr += 10; + } + + ptr = create_uuid16_list(hdev, ptr, HCI_MAX_EIR_LENGTH - (ptr - data)); + ptr = create_uuid32_list(hdev, ptr, HCI_MAX_EIR_LENGTH - (ptr - data)); + ptr = create_uuid128_list(hdev, ptr, HCI_MAX_EIR_LENGTH - (ptr - data)); +} + +u8 eir_create_per_adv_data(struct hci_dev *hdev, u8 instance, u8 *ptr) +{ + struct adv_info *adv = NULL; + u8 ad_len = 0; + + /* Return 0 when the current instance identifier is invalid. */ + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return 0; + } + + if (adv) { + memcpy(ptr, adv->per_adv_data, adv->per_adv_data_len); + ad_len += adv->per_adv_data_len; + ptr += adv->per_adv_data_len; + } + + return ad_len; +} + +u8 eir_create_adv_data(struct hci_dev *hdev, u8 instance, u8 *ptr) +{ + struct adv_info *adv = NULL; + u8 ad_len = 0, flags = 0; + u32 instance_flags; + + /* Return 0 when the current instance identifier is invalid. */ + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return 0; + } + + instance_flags = hci_adv_instance_flags(hdev, instance); + + /* If instance already has the flags set skip adding it once + * again. + */ + if (adv && eir_get_data(adv->adv_data, adv->adv_data_len, EIR_FLAGS, + NULL)) + goto skip_flags; + + /* The Add Advertising command allows userspace to set both the general + * and limited discoverable flags. + */ + if (instance_flags & MGMT_ADV_FLAG_DISCOV) + flags |= LE_AD_GENERAL; + + if (instance_flags & MGMT_ADV_FLAG_LIMITED_DISCOV) + flags |= LE_AD_LIMITED; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + flags |= LE_AD_NO_BREDR; + + if (flags || (instance_flags & MGMT_ADV_FLAG_MANAGED_FLAGS)) { + /* If a discovery flag wasn't provided, simply use the global + * settings. + */ + if (!flags) + flags |= mgmt_get_adv_discov_flags(hdev); + + /* If flags would still be empty, then there is no need to + * include the "Flags" AD field". 
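For a typical LE-only peripheral in general discoverable mode the flags computed above come out as LE_AD_GENERAL | LE_AD_NO_BREDR and are emitted below as the familiar three-byte field at the start of the advertising data (purely illustrative):

        u8 flags_field[3] = { 0x02, EIR_FLAGS, LE_AD_GENERAL | LE_AD_NO_BREDR }; /* 02 01 06 */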
+ */ + if (flags) { + ptr[0] = 0x02; + ptr[1] = EIR_FLAGS; + ptr[2] = flags; + + ad_len += 3; + ptr += 3; + } + } + +skip_flags: + if (adv) { + memcpy(ptr, adv->adv_data, adv->adv_data_len); + ad_len += adv->adv_data_len; + ptr += adv->adv_data_len; + } + + if (instance_flags & MGMT_ADV_FLAG_TX_POWER) { + s8 adv_tx_power; + + if (ext_adv_capable(hdev)) { + if (adv) + adv_tx_power = adv->tx_power; + else + adv_tx_power = hdev->adv_tx_power; + } else { + adv_tx_power = hdev->adv_tx_power; + } + + /* Provide Tx Power only if we can provide a valid value for it */ + if (adv_tx_power != HCI_TX_POWER_INVALID) { + ptr[0] = 0x02; + ptr[1] = EIR_TX_POWER; + ptr[2] = (u8)adv_tx_power; + + ad_len += 3; + ptr += 3; + } + } + + return ad_len; +} + +static u8 create_default_scan_rsp(struct hci_dev *hdev, u8 *ptr) +{ + u8 scan_rsp_len = 0; + + if (hdev->appearance) + scan_rsp_len = eir_append_appearance(hdev, ptr, scan_rsp_len); + + return eir_append_local_name(hdev, ptr, scan_rsp_len); +} + +u8 eir_create_scan_rsp(struct hci_dev *hdev, u8 instance, u8 *ptr) +{ + struct adv_info *adv; + u8 scan_rsp_len = 0; + + if (!instance) + return create_default_scan_rsp(hdev, ptr); + + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return 0; + + if ((adv->flags & MGMT_ADV_FLAG_APPEARANCE) && hdev->appearance) + scan_rsp_len = eir_append_appearance(hdev, ptr, scan_rsp_len); + + memcpy(&ptr[scan_rsp_len], adv->scan_rsp_data, adv->scan_rsp_len); + + scan_rsp_len += adv->scan_rsp_len; + + if (adv->flags & MGMT_ADV_FLAG_LOCAL_NAME) + scan_rsp_len = eir_append_local_name(hdev, ptr, scan_rsp_len); + + return scan_rsp_len; +} + +void *eir_get_service_data(u8 *eir, size_t eir_len, u16 uuid, size_t *len) +{ + while ((eir = eir_get_data(eir, eir_len, EIR_SERVICE_DATA, len))) { + u16 value = get_unaligned_le16(eir); + + if (uuid == value) { + if (len) + *len -= 2; + return &eir[2]; + } + + eir += *len; + eir_len -= *len; + } + + return NULL; +} diff --git a/net/bluetooth/eir.h b/net/bluetooth/eir.h new file mode 100644 index 000000000..0df19f2f4 --- /dev/null +++ b/net/bluetooth/eir.h @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * BlueZ - Bluetooth protocol stack for Linux + * + * Copyright (C) 2021 Intel Corporation + */ + +#include + +void eir_create(struct hci_dev *hdev, u8 *data); + +u8 eir_create_adv_data(struct hci_dev *hdev, u8 instance, u8 *ptr); +u8 eir_create_scan_rsp(struct hci_dev *hdev, u8 instance, u8 *ptr); +u8 eir_create_per_adv_data(struct hci_dev *hdev, u8 instance, u8 *ptr); + +u8 eir_append_local_name(struct hci_dev *hdev, u8 *eir, u8 ad_len); +u8 eir_append_appearance(struct hci_dev *hdev, u8 *ptr, u8 ad_len); +u8 eir_append_service_data(u8 *eir, u16 eir_len, u16 uuid, u8 *data, + u8 data_len); + +static inline u16 eir_precalc_len(u8 data_len) +{ + return sizeof(u8) * 2 + data_len; +} + +static inline u16 eir_append_data(u8 *eir, u16 eir_len, u8 type, + u8 *data, u8 data_len) +{ + eir[eir_len++] = sizeof(type) + data_len; + eir[eir_len++] = type; + memcpy(&eir[eir_len], data, data_len); + eir_len += data_len; + + return eir_len; +} + +static inline u16 eir_append_le16(u8 *eir, u16 eir_len, u8 type, u16 data) +{ + eir[eir_len++] = sizeof(type) + sizeof(data); + eir[eir_len++] = type; + put_unaligned_le16(data, &eir[eir_len]); + eir_len += sizeof(data); + + return eir_len; +} + +static inline u16 eir_skb_put_data(struct sk_buff *skb, u8 type, u8 *data, u8 data_len) +{ + u8 *eir; + u16 eir_len; + + eir_len = eir_precalc_len(data_len); + eir = skb_put(skb, eir_len); + 
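As a worked example of the length/type/value layout produced by eir_append_data() and eir_append_le16() above (the buffer is illustrative; the appearance call mirrors eir_append_appearance()):

        u8 buf[HCI_MAX_EIR_LENGTH];
        u16 len = 0;

        /* emits 06 09 'B' 'l' 'u' 'e' 'Z' - the length octet covers type + data */
        len = eir_append_data(buf, len, EIR_NAME_COMPLETE, (u8 *)"BlueZ", 5);
        /* emits 03 19 <appearance, little endian> */
        len = eir_append_le16(buf, len, EIR_APPEARANCE, hdev->appearance);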
WARN_ON(sizeof(type) + data_len > U8_MAX); + eir[0] = sizeof(type) + data_len; + eir[1] = type; + memcpy(&eir[2], data, data_len); + + return eir_len; +} + +static inline void *eir_get_data(u8 *eir, size_t eir_len, u8 type, + size_t *data_len) +{ + size_t parsed = 0; + + if (eir_len < 2) + return NULL; + + while (parsed < eir_len - 1) { + u8 field_len = eir[0]; + + if (field_len == 0) + break; + + parsed += field_len + 1; + + if (parsed > eir_len) + break; + + if (eir[1] != type) { + eir += field_len + 1; + continue; + } + + /* Zero length data */ + if (field_len == 1) + return NULL; + + if (data_len) + *data_len = field_len - 1; + + return &eir[2]; + } + + return NULL; +} + +void *eir_get_service_data(u8 *eir, size_t eir_len, u16 uuid, size_t *len); diff --git a/net/bluetooth/hci_codec.c b/net/bluetooth/hci_codec.c new file mode 100644 index 000000000..3cc135bb1 --- /dev/null +++ b/net/bluetooth/hci_codec.c @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: GPL-2.0 + +/* Copyright (C) 2021 Intel Corporation */ + +#include +#include +#include "hci_codec.h" + +static int hci_codec_list_add(struct list_head *list, + struct hci_op_read_local_codec_caps *sent, + struct hci_rp_read_local_codec_caps *rp, + void *caps, + __u32 len) +{ + struct codec_list *entry; + + entry = kzalloc(sizeof(*entry) + len, GFP_KERNEL); + if (!entry) + return -ENOMEM; + + entry->id = sent->id; + if (sent->id == 0xFF) { + entry->cid = __le16_to_cpu(sent->cid); + entry->vid = __le16_to_cpu(sent->vid); + } + entry->transport = sent->transport; + entry->len = len; + entry->num_caps = 0; + if (rp) { + entry->num_caps = rp->num_caps; + memcpy(entry->caps, caps, len); + } + list_add(&entry->list, list); + + return 0; +} + +void hci_codec_list_clear(struct list_head *codec_list) +{ + struct codec_list *c, *n; + + list_for_each_entry_safe(c, n, codec_list, list) { + list_del(&c->list); + kfree(c); + } +} + +static void hci_read_codec_capabilities(struct hci_dev *hdev, __u8 transport, + struct hci_op_read_local_codec_caps + *cmd) +{ + __u8 i; + + for (i = 0; i < TRANSPORT_TYPE_MAX; i++) { + if (transport & BIT(i)) { + struct hci_rp_read_local_codec_caps *rp; + struct hci_codec_caps *caps; + struct sk_buff *skb; + __u8 j; + __u32 len; + + cmd->transport = i; + + /* If Read_Codec_Capabilities command is not supported + * then just add codec to the list without caps + */ + if (!(hdev->commands[45] & 0x08)) { + hci_dev_lock(hdev); + hci_codec_list_add(&hdev->local_codecs, cmd, + NULL, NULL, 0); + hci_dev_unlock(hdev); + continue; + } + + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODEC_CAPS, + sizeof(*cmd), cmd, 0, HCI_CMD_TIMEOUT, NULL); + if (IS_ERR(skb)) { + bt_dev_err(hdev, "Failed to read codec capabilities (%ld)", + PTR_ERR(skb)); + continue; + } + + if (skb->len < sizeof(*rp)) + goto error; + + rp = (void *)skb->data; + + if (rp->status) + goto error; + + if (!rp->num_caps) { + len = 0; + /* this codec doesn't have capabilities */ + goto skip_caps_parse; + } + + skb_pull(skb, sizeof(*rp)); + + for (j = 0, len = 0; j < rp->num_caps; j++) { + caps = (void *)skb->data; + if (skb->len < sizeof(*caps)) + goto error; + if (skb->len < caps->len) + goto error; + len += sizeof(caps->len) + caps->len; + skb_pull(skb, sizeof(caps->len) + caps->len); + } + +skip_caps_parse: + hci_dev_lock(hdev); + hci_codec_list_add(&hdev->local_codecs, cmd, rp, + (__u8 *)rp + sizeof(*rp), len); + hci_dev_unlock(hdev); +error: + kfree_skb(skb); + } + } +} + +void hci_read_supported_codecs(struct hci_dev *hdev) +{ + struct sk_buff *skb; + struct 
hci_rp_read_local_supported_codecs *rp; + struct hci_std_codecs *std_codecs; + struct hci_vnd_codecs *vnd_codecs; + struct hci_op_read_local_codec_caps caps; + __u8 i; + + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODECS, 0, NULL, + 0, HCI_CMD_TIMEOUT, NULL); + + if (IS_ERR(skb)) { + bt_dev_err(hdev, "Failed to read local supported codecs (%ld)", + PTR_ERR(skb)); + return; + } + + if (skb->len < sizeof(*rp)) + goto error; + + rp = (void *)skb->data; + + if (rp->status) + goto error; + + skb_pull(skb, sizeof(rp->status)); + + std_codecs = (void *)skb->data; + + /* validate codecs length before accessing */ + if (skb->len < flex_array_size(std_codecs, codec, std_codecs->num) + + sizeof(std_codecs->num)) + goto error; + + /* enumerate codec capabilities of standard codecs */ + memset(&caps, 0, sizeof(caps)); + for (i = 0; i < std_codecs->num; i++) { + caps.id = std_codecs->codec[i]; + caps.direction = 0x00; + hci_read_codec_capabilities(hdev, + LOCAL_CODEC_ACL_MASK | LOCAL_CODEC_SCO_MASK, &caps); + } + + skb_pull(skb, flex_array_size(std_codecs, codec, std_codecs->num) + + sizeof(std_codecs->num)); + + vnd_codecs = (void *)skb->data; + + /* validate vendor codecs length before accessing */ + if (skb->len < + flex_array_size(vnd_codecs, codec, vnd_codecs->num) + + sizeof(vnd_codecs->num)) + goto error; + + /* enumerate vendor codec capabilities */ + for (i = 0; i < vnd_codecs->num; i++) { + caps.id = 0xFF; + caps.cid = vnd_codecs->codec[i].cid; + caps.vid = vnd_codecs->codec[i].vid; + caps.direction = 0x00; + hci_read_codec_capabilities(hdev, + LOCAL_CODEC_ACL_MASK | LOCAL_CODEC_SCO_MASK, &caps); + } + +error: + kfree_skb(skb); +} + +void hci_read_supported_codecs_v2(struct hci_dev *hdev) +{ + struct sk_buff *skb; + struct hci_rp_read_local_supported_codecs_v2 *rp; + struct hci_std_codecs_v2 *std_codecs; + struct hci_vnd_codecs_v2 *vnd_codecs; + struct hci_op_read_local_codec_caps caps; + __u8 i; + + skb = __hci_cmd_sync_sk(hdev, HCI_OP_READ_LOCAL_CODECS_V2, 0, NULL, + 0, HCI_CMD_TIMEOUT, NULL); + + if (IS_ERR(skb)) { + bt_dev_err(hdev, "Failed to read local supported codecs (%ld)", + PTR_ERR(skb)); + return; + } + + if (skb->len < sizeof(*rp)) + goto error; + + rp = (void *)skb->data; + + if (rp->status) + goto error; + + skb_pull(skb, sizeof(rp->status)); + + std_codecs = (void *)skb->data; + + /* check for payload data length before accessing */ + if (skb->len < flex_array_size(std_codecs, codec, std_codecs->num) + + sizeof(std_codecs->num)) + goto error; + + memset(&caps, 0, sizeof(caps)); + + for (i = 0; i < std_codecs->num; i++) { + caps.id = std_codecs->codec[i].id; + hci_read_codec_capabilities(hdev, std_codecs->codec[i].transport, + &caps); + } + + skb_pull(skb, flex_array_size(std_codecs, codec, std_codecs->num) + + sizeof(std_codecs->num)); + + vnd_codecs = (void *)skb->data; + + /* check for payload data length before accessing */ + if (skb->len < + flex_array_size(vnd_codecs, codec, vnd_codecs->num) + + sizeof(vnd_codecs->num)) + goto error; + + for (i = 0; i < vnd_codecs->num; i++) { + caps.id = 0xFF; + caps.cid = vnd_codecs->codec[i].cid; + caps.vid = vnd_codecs->codec[i].vid; + hci_read_codec_capabilities(hdev, vnd_codecs->codec[i].transport, + &caps); + } + +error: + kfree_skb(skb); +} diff --git a/net/bluetooth/hci_codec.h b/net/bluetooth/hci_codec.h new file mode 100644 index 000000000..a2751930f --- /dev/null +++ b/net/bluetooth/hci_codec.h @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +/* Copyright (C) 2014 Intel Corporation */ + +void 
hci_read_supported_codecs(struct hci_dev *hdev); +void hci_read_supported_codecs_v2(struct hci_dev *hdev); +void hci_codec_list_clear(struct list_head *codec_list); diff --git a/net/bluetooth/hci_conn.c b/net/bluetooth/hci_conn.c new file mode 100644 index 000000000..12d368753 --- /dev/null +++ b/net/bluetooth/hci_conn.c @@ -0,0 +1,2917 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (c) 2000-2001, 2010, Code Aurora Forum. All rights reserved. + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth HCI connection handling. */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "hci_request.h" +#include "smp.h" +#include "a2mp.h" +#include "eir.h" + +struct sco_param { + u16 pkt_type; + u16 max_latency; + u8 retrans_effort; +}; + +struct conn_handle_t { + struct hci_conn *conn; + __u16 handle; +}; + +static const struct sco_param esco_param_cvsd[] = { + { EDR_ESCO_MASK & ~ESCO_2EV3, 0x000a, 0x01 }, /* S3 */ + { EDR_ESCO_MASK & ~ESCO_2EV3, 0x0007, 0x01 }, /* S2 */ + { EDR_ESCO_MASK | ESCO_EV3, 0x0007, 0x01 }, /* S1 */ + { EDR_ESCO_MASK | ESCO_HV3, 0xffff, 0x01 }, /* D1 */ + { EDR_ESCO_MASK | ESCO_HV1, 0xffff, 0x01 }, /* D0 */ +}; + +static const struct sco_param sco_param_cvsd[] = { + { EDR_ESCO_MASK | ESCO_HV3, 0xffff, 0xff }, /* D1 */ + { EDR_ESCO_MASK | ESCO_HV1, 0xffff, 0xff }, /* D0 */ +}; + +static const struct sco_param esco_param_msbc[] = { + { EDR_ESCO_MASK & ~ESCO_2EV3, 0x000d, 0x02 }, /* T2 */ + { EDR_ESCO_MASK | ESCO_EV3, 0x0008, 0x02 }, /* T1 */ +}; + +/* This function requires the caller holds hdev->lock */ +static void hci_connect_le_scan_cleanup(struct hci_conn *conn, u8 status) +{ + struct hci_conn_params *params; + struct hci_dev *hdev = conn->hdev; + struct smp_irk *irk; + bdaddr_t *bdaddr; + u8 bdaddr_type; + + bdaddr = &conn->dst; + bdaddr_type = conn->dst_type; + + /* Check if we need to convert to identity address */ + irk = hci_get_irk(hdev, bdaddr, bdaddr_type); + if (irk) { + bdaddr = &irk->bdaddr; + bdaddr_type = irk->addr_type; + } + + params = hci_pend_le_action_lookup(&hdev->pend_le_conns, bdaddr, + bdaddr_type); + if (!params) + return; + + if (params->conn) { + hci_conn_drop(params->conn); + hci_conn_put(params->conn); + params->conn = NULL; + } + + if (!params->explicit_connect) + return; + + /* If the status indicates successful cancellation of + * the attempt (i.e. Unknown Connection Id) there's no point of + * notifying failure since we'll go back to keep trying to + * connect. 
The only exception is explicit connect requests + * where a timeout + cancel does indicate an actual failure. + */ + if (status && status != HCI_ERROR_UNKNOWN_CONN_ID) + mgmt_connect_failed(hdev, &conn->dst, conn->type, + conn->dst_type, status); + + /* The connection attempt was doing scan for new RPA, and is + * in scan phase. If params are not associated with any other + * autoconnect action, remove them completely. If they are, just unmark + * them as waiting for connection, by clearing explicit_connect field. + */ + params->explicit_connect = false; + + hci_pend_le_list_del_init(params); + + switch (params->auto_connect) { + case HCI_AUTO_CONN_EXPLICIT: + hci_conn_params_del(hdev, bdaddr, bdaddr_type); + /* return instead of break to avoid duplicate scan update */ + return; + case HCI_AUTO_CONN_DIRECT: + case HCI_AUTO_CONN_ALWAYS: + hci_pend_le_list_add(params, &hdev->pend_le_conns); + break; + case HCI_AUTO_CONN_REPORT: + hci_pend_le_list_add(params, &hdev->pend_le_reports); + break; + default: + break; + } + + hci_update_passive_scan(hdev); +} + +static void hci_conn_cleanup(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + if (test_bit(HCI_CONN_PARAM_REMOVAL_PEND, &conn->flags)) + hci_conn_params_del(conn->hdev, &conn->dst, conn->dst_type); + + if (test_and_clear_bit(HCI_CONN_FLUSH_KEY, &conn->flags)) + hci_remove_link_key(hdev, &conn->dst); + + hci_chan_list_flush(conn); + + hci_conn_hash_del(hdev, conn); + + if (conn->cleanup) + conn->cleanup(conn); + + if (conn->type == SCO_LINK || conn->type == ESCO_LINK) { + switch (conn->setting & SCO_AIRMODE_MASK) { + case SCO_AIRMODE_CVSD: + case SCO_AIRMODE_TRANSP: + if (hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_DISABLE_SCO); + break; + } + } else { + if (hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_CONN_DEL); + } + + debugfs_remove_recursive(conn->debugfs); + + hci_conn_del_sysfs(conn); + + hci_dev_put(hdev); +} + +static void le_scan_cleanup(struct work_struct *work) +{ + struct hci_conn *conn = container_of(work, struct hci_conn, + le_scan_cleanup); + struct hci_dev *hdev = conn->hdev; + struct hci_conn *c = NULL; + + BT_DBG("%s hcon %p", hdev->name, conn); + + hci_dev_lock(hdev); + + /* Check that the hci_conn is still around */ + rcu_read_lock(); + list_for_each_entry_rcu(c, &hdev->conn_hash.list, list) { + if (c == conn) + break; + } + rcu_read_unlock(); + + if (c == conn) { + hci_connect_le_scan_cleanup(conn, 0x00); + hci_conn_cleanup(conn); + } + + hci_dev_unlock(hdev); + hci_dev_put(hdev); + hci_conn_put(conn); +} + +static void hci_connect_le_scan_remove(struct hci_conn *conn) +{ + BT_DBG("%s hcon %p", conn->hdev->name, conn); + + /* We can't call hci_conn_del/hci_conn_cleanup here since that + * could deadlock with another hci_conn_del() call that's holding + * hci_dev_lock and doing cancel_delayed_work_sync(&conn->disc_work). + * Instead, grab temporary extra references to the hci_dev and + * hci_conn and perform the necessary cleanup in a separate work + * callback. + */ + + hci_dev_hold(conn->hdev); + hci_conn_get(conn); + + /* Even though we hold a reference to the hdev, many other + * things might get cleaned up meanwhile, including the hdev's + * own workqueue, so we can't use that for scheduling. 
+ */ + schedule_work(&conn->le_scan_cleanup); +} + +static void hci_acl_create_connection(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct inquiry_entry *ie; + struct hci_cp_create_conn cp; + + BT_DBG("hcon %p", conn); + + /* Many controllers disallow HCI Create Connection while it is doing + * HCI Inquiry. So we cancel the Inquiry first before issuing HCI Create + * Connection. This may cause the MGMT discovering state to become false + * without user space's request but it is okay since the MGMT Discovery + * APIs do not promise that discovery should be done forever. Instead, + * the user space monitors the status of MGMT discovering and it may + * request for discovery again when this flag becomes false. + */ + if (test_bit(HCI_INQUIRY, &hdev->flags)) { + /* Put this connection to "pending" state so that it will be + * executed after the inquiry cancel command complete event. + */ + conn->state = BT_CONNECT2; + hci_send_cmd(hdev, HCI_OP_INQUIRY_CANCEL, 0, NULL); + return; + } + + conn->state = BT_CONNECT; + conn->out = true; + conn->role = HCI_ROLE_MASTER; + + conn->attempt++; + + conn->link_policy = hdev->link_policy; + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, &conn->dst); + cp.pscan_rep_mode = 0x02; + + ie = hci_inquiry_cache_lookup(hdev, &conn->dst); + if (ie) { + if (inquiry_entry_age(ie) <= INQUIRY_ENTRY_AGE_MAX) { + cp.pscan_rep_mode = ie->data.pscan_rep_mode; + cp.pscan_mode = ie->data.pscan_mode; + cp.clock_offset = ie->data.clock_offset | + cpu_to_le16(0x8000); + } + + memcpy(conn->dev_class, ie->data.dev_class, 3); + } + + cp.pkt_type = cpu_to_le16(conn->pkt_type); + if (lmp_rswitch_capable(hdev) && !(hdev->link_mode & HCI_LM_MASTER)) + cp.role_switch = 0x01; + else + cp.role_switch = 0x00; + + hci_send_cmd(hdev, HCI_OP_CREATE_CONN, sizeof(cp), &cp); +} + +int hci_disconnect(struct hci_conn *conn, __u8 reason) +{ + BT_DBG("hcon %p", conn); + + /* When we are central of an established connection and it enters + * the disconnect timeout, then go ahead and try to read the + * current clock offset. Processing of the result is done + * within the event handling and hci_clock_offset_evt function. 
+ */ + if (conn->type == ACL_LINK && conn->role == HCI_ROLE_MASTER && + (conn->state == BT_CONNECTED || conn->state == BT_CONFIG)) { + struct hci_dev *hdev = conn->hdev; + struct hci_cp_read_clock_offset clkoff_cp; + + clkoff_cp.handle = cpu_to_le16(conn->handle); + hci_send_cmd(hdev, HCI_OP_READ_CLOCK_OFFSET, sizeof(clkoff_cp), + &clkoff_cp); + } + + return hci_abort_conn(conn, reason); +} + +static void hci_add_sco(struct hci_conn *conn, __u16 handle) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_cp_add_sco cp; + + BT_DBG("hcon %p", conn); + + conn->state = BT_CONNECT; + conn->out = true; + + conn->attempt++; + + cp.handle = cpu_to_le16(handle); + cp.pkt_type = cpu_to_le16(conn->pkt_type); + + hci_send_cmd(hdev, HCI_OP_ADD_SCO, sizeof(cp), &cp); +} + +static bool find_next_esco_param(struct hci_conn *conn, + const struct sco_param *esco_param, int size) +{ + for (; conn->attempt <= size; conn->attempt++) { + if (lmp_esco_2m_capable(conn->link) || + (esco_param[conn->attempt - 1].pkt_type & ESCO_2EV3)) + break; + BT_DBG("hcon %p skipped attempt %d, eSCO 2M not supported", + conn, conn->attempt); + } + + return conn->attempt <= size; +} + +static int configure_datapath_sync(struct hci_dev *hdev, struct bt_codec *codec) +{ + int err; + __u8 vnd_len, *vnd_data = NULL; + struct hci_op_configure_data_path *cmd = NULL; + + err = hdev->get_codec_config_data(hdev, ESCO_LINK, codec, &vnd_len, + &vnd_data); + if (err < 0) + goto error; + + cmd = kzalloc(sizeof(*cmd) + vnd_len, GFP_KERNEL); + if (!cmd) { + err = -ENOMEM; + goto error; + } + + err = hdev->get_data_path_id(hdev, &cmd->data_path_id); + if (err < 0) + goto error; + + cmd->vnd_len = vnd_len; + memcpy(cmd->vnd_data, vnd_data, vnd_len); + + cmd->direction = 0x00; + __hci_cmd_sync_status(hdev, HCI_CONFIGURE_DATA_PATH, + sizeof(*cmd) + vnd_len, cmd, HCI_CMD_TIMEOUT); + + cmd->direction = 0x01; + err = __hci_cmd_sync_status(hdev, HCI_CONFIGURE_DATA_PATH, + sizeof(*cmd) + vnd_len, cmd, + HCI_CMD_TIMEOUT); +error: + + kfree(cmd); + kfree(vnd_data); + return err; +} + +static int hci_enhanced_setup_sync(struct hci_dev *hdev, void *data) +{ + struct conn_handle_t *conn_handle = data; + struct hci_conn *conn = conn_handle->conn; + __u16 handle = conn_handle->handle; + struct hci_cp_enhanced_setup_sync_conn cp; + const struct sco_param *param; + + kfree(conn_handle); + + bt_dev_dbg(hdev, "hcon %p", conn); + + /* for offload use case, codec needs to configured before opening SCO */ + if (conn->codec.data_path) + configure_datapath_sync(hdev, &conn->codec); + + conn->state = BT_CONNECT; + conn->out = true; + + conn->attempt++; + + memset(&cp, 0x00, sizeof(cp)); + + cp.handle = cpu_to_le16(handle); + + cp.tx_bandwidth = cpu_to_le32(0x00001f40); + cp.rx_bandwidth = cpu_to_le32(0x00001f40); + + switch (conn->codec.id) { + case BT_CODEC_MSBC: + if (!find_next_esco_param(conn, esco_param_msbc, + ARRAY_SIZE(esco_param_msbc))) + return -EINVAL; + + param = &esco_param_msbc[conn->attempt - 1]; + cp.tx_coding_format.id = 0x05; + cp.rx_coding_format.id = 0x05; + cp.tx_codec_frame_size = __cpu_to_le16(60); + cp.rx_codec_frame_size = __cpu_to_le16(60); + cp.in_bandwidth = __cpu_to_le32(32000); + cp.out_bandwidth = __cpu_to_le32(32000); + cp.in_coding_format.id = 0x04; + cp.out_coding_format.id = 0x04; + cp.in_coded_data_size = __cpu_to_le16(16); + cp.out_coded_data_size = __cpu_to_le16(16); + cp.in_pcm_data_format = 2; + cp.out_pcm_data_format = 2; + cp.in_pcm_sample_payload_msb_pos = 0; + cp.out_pcm_sample_payload_msb_pos = 0; + cp.in_data_path = 
conn->codec.data_path; + cp.out_data_path = conn->codec.data_path; + cp.in_transport_unit_size = 1; + cp.out_transport_unit_size = 1; + break; + + case BT_CODEC_TRANSPARENT: + if (!find_next_esco_param(conn, esco_param_msbc, + ARRAY_SIZE(esco_param_msbc))) + return false; + param = &esco_param_msbc[conn->attempt - 1]; + cp.tx_coding_format.id = 0x03; + cp.rx_coding_format.id = 0x03; + cp.tx_codec_frame_size = __cpu_to_le16(60); + cp.rx_codec_frame_size = __cpu_to_le16(60); + cp.in_bandwidth = __cpu_to_le32(0x1f40); + cp.out_bandwidth = __cpu_to_le32(0x1f40); + cp.in_coding_format.id = 0x03; + cp.out_coding_format.id = 0x03; + cp.in_coded_data_size = __cpu_to_le16(16); + cp.out_coded_data_size = __cpu_to_le16(16); + cp.in_pcm_data_format = 2; + cp.out_pcm_data_format = 2; + cp.in_pcm_sample_payload_msb_pos = 0; + cp.out_pcm_sample_payload_msb_pos = 0; + cp.in_data_path = conn->codec.data_path; + cp.out_data_path = conn->codec.data_path; + cp.in_transport_unit_size = 1; + cp.out_transport_unit_size = 1; + break; + + case BT_CODEC_CVSD: + if (lmp_esco_capable(conn->link)) { + if (!find_next_esco_param(conn, esco_param_cvsd, + ARRAY_SIZE(esco_param_cvsd))) + return -EINVAL; + param = &esco_param_cvsd[conn->attempt - 1]; + } else { + if (conn->attempt > ARRAY_SIZE(sco_param_cvsd)) + return -EINVAL; + param = &sco_param_cvsd[conn->attempt - 1]; + } + cp.tx_coding_format.id = 2; + cp.rx_coding_format.id = 2; + cp.tx_codec_frame_size = __cpu_to_le16(60); + cp.rx_codec_frame_size = __cpu_to_le16(60); + cp.in_bandwidth = __cpu_to_le32(16000); + cp.out_bandwidth = __cpu_to_le32(16000); + cp.in_coding_format.id = 4; + cp.out_coding_format.id = 4; + cp.in_coded_data_size = __cpu_to_le16(16); + cp.out_coded_data_size = __cpu_to_le16(16); + cp.in_pcm_data_format = 2; + cp.out_pcm_data_format = 2; + cp.in_pcm_sample_payload_msb_pos = 0; + cp.out_pcm_sample_payload_msb_pos = 0; + cp.in_data_path = conn->codec.data_path; + cp.out_data_path = conn->codec.data_path; + cp.in_transport_unit_size = 16; + cp.out_transport_unit_size = 16; + break; + default: + return -EINVAL; + } + + cp.retrans_effort = param->retrans_effort; + cp.pkt_type = __cpu_to_le16(param->pkt_type); + cp.max_latency = __cpu_to_le16(param->max_latency); + + if (hci_send_cmd(hdev, HCI_OP_ENHANCED_SETUP_SYNC_CONN, sizeof(cp), &cp) < 0) + return -EIO; + + return 0; +} + +static bool hci_setup_sync_conn(struct hci_conn *conn, __u16 handle) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_cp_setup_sync_conn cp; + const struct sco_param *param; + + bt_dev_dbg(hdev, "hcon %p", conn); + + conn->state = BT_CONNECT; + conn->out = true; + + conn->attempt++; + + cp.handle = cpu_to_le16(handle); + + cp.tx_bandwidth = cpu_to_le32(0x00001f40); + cp.rx_bandwidth = cpu_to_le32(0x00001f40); + cp.voice_setting = cpu_to_le16(conn->setting); + + switch (conn->setting & SCO_AIRMODE_MASK) { + case SCO_AIRMODE_TRANSP: + if (!find_next_esco_param(conn, esco_param_msbc, + ARRAY_SIZE(esco_param_msbc))) + return false; + param = &esco_param_msbc[conn->attempt - 1]; + break; + case SCO_AIRMODE_CVSD: + if (lmp_esco_capable(conn->link)) { + if (!find_next_esco_param(conn, esco_param_cvsd, + ARRAY_SIZE(esco_param_cvsd))) + return false; + param = &esco_param_cvsd[conn->attempt - 1]; + } else { + if (conn->attempt > ARRAY_SIZE(sco_param_cvsd)) + return false; + param = &sco_param_cvsd[conn->attempt - 1]; + } + break; + default: + return false; + } + + cp.retrans_effort = param->retrans_effort; + cp.pkt_type = __cpu_to_le16(param->pkt_type); + cp.max_latency = 
__cpu_to_le16(param->max_latency); + + if (hci_send_cmd(hdev, HCI_OP_SETUP_SYNC_CONN, sizeof(cp), &cp) < 0) + return false; + + return true; +} + +bool hci_setup_sync(struct hci_conn *conn, __u16 handle) +{ + int result; + struct conn_handle_t *conn_handle; + + if (enhanced_sync_conn_capable(conn->hdev)) { + conn_handle = kzalloc(sizeof(*conn_handle), GFP_KERNEL); + + if (!conn_handle) + return false; + + conn_handle->conn = conn; + conn_handle->handle = handle; + result = hci_cmd_sync_queue(conn->hdev, hci_enhanced_setup_sync, + conn_handle, NULL); + if (result < 0) + kfree(conn_handle); + + return result == 0; + } + + return hci_setup_sync_conn(conn, handle); +} + +u8 hci_le_conn_update(struct hci_conn *conn, u16 min, u16 max, u16 latency, + u16 to_multiplier) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_conn_params *params; + struct hci_cp_le_conn_update cp; + + hci_dev_lock(hdev); + + params = hci_conn_params_lookup(hdev, &conn->dst, conn->dst_type); + if (params) { + params->conn_min_interval = min; + params->conn_max_interval = max; + params->conn_latency = latency; + params->supervision_timeout = to_multiplier; + } + + hci_dev_unlock(hdev); + + memset(&cp, 0, sizeof(cp)); + cp.handle = cpu_to_le16(conn->handle); + cp.conn_interval_min = cpu_to_le16(min); + cp.conn_interval_max = cpu_to_le16(max); + cp.conn_latency = cpu_to_le16(latency); + cp.supervision_timeout = cpu_to_le16(to_multiplier); + cp.min_ce_len = cpu_to_le16(0x0000); + cp.max_ce_len = cpu_to_le16(0x0000); + + hci_send_cmd(hdev, HCI_OP_LE_CONN_UPDATE, sizeof(cp), &cp); + + if (params) + return 0x01; + + return 0x00; +} + +void hci_le_start_enc(struct hci_conn *conn, __le16 ediv, __le64 rand, + __u8 ltk[16], __u8 key_size) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_cp_le_start_enc cp; + + BT_DBG("hcon %p", conn); + + memset(&cp, 0, sizeof(cp)); + + cp.handle = cpu_to_le16(conn->handle); + cp.rand = rand; + cp.ediv = ediv; + memcpy(cp.ltk, ltk, key_size); + + hci_send_cmd(hdev, HCI_OP_LE_START_ENC, sizeof(cp), &cp); +} + +/* Device _must_ be locked */ +void hci_sco_setup(struct hci_conn *conn, __u8 status) +{ + struct hci_conn *sco = conn->link; + + if (!sco) + return; + + BT_DBG("hcon %p", conn); + + if (!status) { + if (lmp_esco_capable(conn->hdev)) + hci_setup_sync(sco, conn->handle); + else + hci_add_sco(sco, conn->handle); + } else { + hci_connect_cfm(sco, status); + hci_conn_del(sco); + } +} + +static void hci_conn_timeout(struct work_struct *work) +{ + struct hci_conn *conn = container_of(work, struct hci_conn, + disc_work.work); + int refcnt = atomic_read(&conn->refcnt); + + BT_DBG("hcon %p state %s", conn, state_to_string(conn->state)); + + WARN_ON(refcnt < 0); + + /* FIXME: It was observed that in pairing failed scenario, refcnt + * drops below 0. Probably this is because l2cap_conn_del calls + * l2cap_chan_del for each channel, and inside l2cap_chan_del conn is + * dropped. After that loop hci_chan_del is called which also drops + * conn. For now make sure that ACL is alive if refcnt is higher then 0, + * otherwise drop it. 
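The timeout above only takes effect once every user of the connection has released its reference; the usual pairing, as used throughout this file, is simply (illustrative sketch):

        hci_conn_hold(conn);    /* keep the link alive while it is in use */
        /* ... exchange data, run security procedures, ... */
        hci_conn_drop(conn);    /* last drop re-arms the delayed disconnect work */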
+ */ + if (refcnt > 0) + return; + + /* LE connections in scanning state need special handling */ + if (conn->state == BT_CONNECT && conn->type == LE_LINK && + test_bit(HCI_CONN_SCANNING, &conn->flags)) { + hci_connect_le_scan_remove(conn); + return; + } + + hci_abort_conn(conn, hci_proto_disconn_ind(conn)); +} + +/* Enter sniff mode */ +static void hci_conn_idle(struct work_struct *work) +{ + struct hci_conn *conn = container_of(work, struct hci_conn, + idle_work.work); + struct hci_dev *hdev = conn->hdev; + + BT_DBG("hcon %p mode %d", conn, conn->mode); + + if (!lmp_sniff_capable(hdev) || !lmp_sniff_capable(conn)) + return; + + if (conn->mode != HCI_CM_ACTIVE || !(conn->link_policy & HCI_LP_SNIFF)) + return; + + if (lmp_sniffsubr_capable(hdev) && lmp_sniffsubr_capable(conn)) { + struct hci_cp_sniff_subrate cp; + cp.handle = cpu_to_le16(conn->handle); + cp.max_latency = cpu_to_le16(0); + cp.min_remote_timeout = cpu_to_le16(0); + cp.min_local_timeout = cpu_to_le16(0); + hci_send_cmd(hdev, HCI_OP_SNIFF_SUBRATE, sizeof(cp), &cp); + } + + if (!test_and_set_bit(HCI_CONN_MODE_CHANGE_PEND, &conn->flags)) { + struct hci_cp_sniff_mode cp; + cp.handle = cpu_to_le16(conn->handle); + cp.max_interval = cpu_to_le16(hdev->sniff_max_interval); + cp.min_interval = cpu_to_le16(hdev->sniff_min_interval); + cp.attempt = cpu_to_le16(4); + cp.timeout = cpu_to_le16(1); + hci_send_cmd(hdev, HCI_OP_SNIFF_MODE, sizeof(cp), &cp); + } +} + +static void hci_conn_auto_accept(struct work_struct *work) +{ + struct hci_conn *conn = container_of(work, struct hci_conn, + auto_accept_work.work); + + hci_send_cmd(conn->hdev, HCI_OP_USER_CONFIRM_REPLY, sizeof(conn->dst), + &conn->dst); +} + +static void le_disable_advertising(struct hci_dev *hdev) +{ + if (ext_adv_capable(hdev)) { + struct hci_cp_le_set_ext_adv_enable cp; + + cp.enable = 0x00; + cp.num_of_sets = 0x00; + + hci_send_cmd(hdev, HCI_OP_LE_SET_EXT_ADV_ENABLE, sizeof(cp), + &cp); + } else { + u8 enable = 0x00; + hci_send_cmd(hdev, HCI_OP_LE_SET_ADV_ENABLE, sizeof(enable), + &enable); + } +} + +static void le_conn_timeout(struct work_struct *work) +{ + struct hci_conn *conn = container_of(work, struct hci_conn, + le_conn_timeout.work); + struct hci_dev *hdev = conn->hdev; + + BT_DBG(""); + + /* We could end up here due to having done directed advertising, + * so clean up the state if necessary. This should however only + * happen with broken hardware or if low duty cycle was used + * (which doesn't have a timeout of its own). 
+ */ + if (conn->role == HCI_ROLE_SLAVE) { + /* Disable LE Advertising */ + le_disable_advertising(hdev); + hci_dev_lock(hdev); + hci_conn_failed(conn, HCI_ERROR_ADVERTISING_TIMEOUT); + hci_dev_unlock(hdev); + return; + } + + hci_abort_conn(conn, HCI_ERROR_REMOTE_USER_TERM); +} + +struct iso_cig_params { + struct hci_cp_le_set_cig_params cp; + struct hci_cis_params cis[0x1f]; +}; + +struct iso_list_data { + union { + u8 cig; + u8 big; + }; + union { + u8 cis; + u8 bis; + u16 sync_handle; + }; + int count; + struct iso_cig_params pdu; +}; + +static void bis_list(struct hci_conn *conn, void *data) +{ + struct iso_list_data *d = data; + + /* Skip if not broadcast/ANY address */ + if (bacmp(&conn->dst, BDADDR_ANY)) + return; + + if (d->big != conn->iso_qos.big || d->bis == BT_ISO_QOS_BIS_UNSET || + d->bis != conn->iso_qos.bis) + return; + + d->count++; +} + +static void find_bis(struct hci_conn *conn, void *data) +{ + struct iso_list_data *d = data; + + /* Ignore unicast */ + if (bacmp(&conn->dst, BDADDR_ANY)) + return; + + d->count++; +} + +static int terminate_big_sync(struct hci_dev *hdev, void *data) +{ + struct iso_list_data *d = data; + + bt_dev_dbg(hdev, "big 0x%2.2x bis 0x%2.2x", d->big, d->bis); + + hci_remove_ext_adv_instance_sync(hdev, d->bis, NULL); + + /* Check if ISO connection is a BIS and terminate BIG if there are + * no other connections using it. + */ + hci_conn_hash_list_state(hdev, find_bis, ISO_LINK, BT_CONNECTED, d); + if (d->count) + return 0; + + return hci_le_terminate_big_sync(hdev, d->big, + HCI_ERROR_LOCAL_HOST_TERM); +} + +static void terminate_big_destroy(struct hci_dev *hdev, void *data, int err) +{ + kfree(data); +} + +static int hci_le_terminate_big(struct hci_dev *hdev, u8 big, u8 bis) +{ + struct iso_list_data *d; + int ret; + + bt_dev_dbg(hdev, "big 0x%2.2x bis 0x%2.2x", big, bis); + + d = kmalloc(sizeof(*d), GFP_KERNEL); + if (!d) + return -ENOMEM; + + memset(d, 0, sizeof(*d)); + d->big = big; + d->bis = bis; + + ret = hci_cmd_sync_queue(hdev, terminate_big_sync, d, + terminate_big_destroy); + if (ret) + kfree(d); + + return ret; +} + +static int big_terminate_sync(struct hci_dev *hdev, void *data) +{ + struct iso_list_data *d = data; + + bt_dev_dbg(hdev, "big 0x%2.2x sync_handle 0x%4.4x", d->big, + d->sync_handle); + + /* Check if ISO connection is a BIS and terminate BIG if there are + * no other connections using it. + */ + hci_conn_hash_list_state(hdev, find_bis, ISO_LINK, BT_CONNECTED, d); + if (d->count) + return 0; + + hci_le_big_terminate_sync(hdev, d->big); + + return hci_le_pa_terminate_sync(hdev, d->sync_handle); +} + +static int hci_le_big_terminate(struct hci_dev *hdev, u8 big, u16 sync_handle) +{ + struct iso_list_data *d; + int ret; + + bt_dev_dbg(hdev, "big 0x%2.2x sync_handle 0x%4.4x", big, sync_handle); + + d = kmalloc(sizeof(*d), GFP_KERNEL); + if (!d) + return -ENOMEM; + + memset(d, 0, sizeof(*d)); + d->big = big; + d->sync_handle = sync_handle; + + ret = hci_cmd_sync_queue(hdev, big_terminate_sync, d, + terminate_big_destroy); + if (ret) + kfree(d); + + return ret; +} + +/* Cleanup BIS connection + * + * Detects if there any BIS left connected in a BIG + * broadcaster: Remove advertising instance and terminate BIG. + * broadcaster receiver: Teminate BIG sync and terminate PA sync. 
+ */ +static void bis_cleanup(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (conn->role == HCI_ROLE_MASTER) { + if (!test_and_clear_bit(HCI_CONN_PER_ADV, &conn->flags)) + return; + + hci_le_terminate_big(hdev, conn->iso_qos.big, + conn->iso_qos.bis); + } else { + hci_le_big_terminate(hdev, conn->iso_qos.big, + conn->sync_handle); + } +} + +static int remove_cig_sync(struct hci_dev *hdev, void *data) +{ + u8 handle = PTR_ERR(data); + + return hci_le_remove_cig_sync(hdev, handle); +} + +static int hci_le_remove_cig(struct hci_dev *hdev, u8 handle) +{ + bt_dev_dbg(hdev, "handle 0x%2.2x", handle); + + return hci_cmd_sync_queue(hdev, remove_cig_sync, ERR_PTR(handle), NULL); +} + +static void find_cis(struct hci_conn *conn, void *data) +{ + struct iso_list_data *d = data; + + /* Ignore broadcast */ + if (!bacmp(&conn->dst, BDADDR_ANY)) + return; + + d->count++; +} + +/* Cleanup CIS connection: + * + * Detects if there any CIS left connected in a CIG and remove it. + */ +static void cis_cleanup(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct iso_list_data d; + + memset(&d, 0, sizeof(d)); + d.cig = conn->iso_qos.cig; + + /* Check if ISO connection is a CIS and remove CIG if there are + * no other connections using it. + */ + hci_conn_hash_list_state(hdev, find_cis, ISO_LINK, BT_BOUND, &d); + hci_conn_hash_list_state(hdev, find_cis, ISO_LINK, BT_CONNECT, &d); + hci_conn_hash_list_state(hdev, find_cis, ISO_LINK, BT_CONNECTED, &d); + if (d.count) + return; + + hci_le_remove_cig(hdev, conn->iso_qos.cig); +} + +struct hci_conn *hci_conn_add(struct hci_dev *hdev, int type, bdaddr_t *dst, + u8 role) +{ + struct hci_conn *conn; + + BT_DBG("%s dst %pMR", hdev->name, dst); + + conn = kzalloc(sizeof(*conn), GFP_KERNEL); + if (!conn) + return NULL; + + bacpy(&conn->dst, dst); + bacpy(&conn->src, &hdev->bdaddr); + conn->handle = HCI_CONN_HANDLE_UNSET; + conn->hdev = hdev; + conn->type = type; + conn->role = role; + conn->mode = HCI_CM_ACTIVE; + conn->state = BT_OPEN; + conn->auth_type = HCI_AT_GENERAL_BONDING; + conn->io_capability = hdev->io_capability; + conn->remote_auth = 0xff; + conn->key_type = 0xff; + conn->rssi = HCI_RSSI_INVALID; + conn->tx_power = HCI_TX_POWER_INVALID; + conn->max_tx_power = HCI_TX_POWER_INVALID; + + set_bit(HCI_CONN_POWER_SAVE, &conn->flags); + conn->disc_timeout = HCI_DISCONN_TIMEOUT; + + /* Set Default Authenticated payload timeout to 30s */ + conn->auth_payload_timeout = DEFAULT_AUTH_PAYLOAD_TIMEOUT; + + if (conn->role == HCI_ROLE_MASTER) + conn->out = true; + + switch (type) { + case ACL_LINK: + conn->pkt_type = hdev->pkt_type & ACL_PTYPE_MASK; + break; + case LE_LINK: + /* conn->src should reflect the local identity address */ + hci_copy_identity_address(hdev, &conn->src, &conn->src_type); + break; + case ISO_LINK: + /* conn->src should reflect the local identity address */ + hci_copy_identity_address(hdev, &conn->src, &conn->src_type); + + /* set proper cleanup function */ + if (!bacmp(dst, BDADDR_ANY)) + conn->cleanup = bis_cleanup; + else if (conn->role == HCI_ROLE_MASTER) + conn->cleanup = cis_cleanup; + + break; + case SCO_LINK: + if (lmp_esco_capable(hdev)) + conn->pkt_type = (hdev->esco_type & SCO_ESCO_MASK) | + (hdev->esco_type & EDR_ESCO_MASK); + else + conn->pkt_type = hdev->pkt_type & SCO_PTYPE_MASK; + break; + case ESCO_LINK: + conn->pkt_type = hdev->esco_type & ~EDR_ESCO_MASK; + break; + } + + skb_queue_head_init(&conn->data_q); + + INIT_LIST_HEAD(&conn->chan_list); + + 
INIT_DELAYED_WORK(&conn->disc_work, hci_conn_timeout); + INIT_DELAYED_WORK(&conn->auto_accept_work, hci_conn_auto_accept); + INIT_DELAYED_WORK(&conn->idle_work, hci_conn_idle); + INIT_DELAYED_WORK(&conn->le_conn_timeout, le_conn_timeout); + INIT_WORK(&conn->le_scan_cleanup, le_scan_cleanup); + + atomic_set(&conn->refcnt, 0); + + hci_dev_hold(hdev); + + hci_conn_hash_add(hdev, conn); + + /* The SCO and eSCO connections will only be notified when their + * setup has been completed. This is different to ACL links which + * can be notified right away. + */ + if (conn->type != SCO_LINK && conn->type != ESCO_LINK) { + if (hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_CONN_ADD); + } + + hci_conn_init_sysfs(conn); + + return conn; +} + +static bool hci_conn_unlink(struct hci_conn *conn) +{ + if (!conn->link) + return false; + + conn->link->link = NULL; + conn->link = NULL; + + return true; +} + +int hci_conn_del(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + BT_DBG("%s hcon %p handle %d", hdev->name, conn, conn->handle); + + cancel_delayed_work_sync(&conn->disc_work); + cancel_delayed_work_sync(&conn->auto_accept_work); + cancel_delayed_work_sync(&conn->idle_work); + + if (conn->type == ACL_LINK) { + struct hci_conn *link = conn->link; + + if (link) { + hci_conn_unlink(conn); + /* Due to race, SCO connection might be not established + * yet at this point. Delete it now, otherwise it is + * possible for it to be stuck and can't be deleted. + */ + if (link->handle == HCI_CONN_HANDLE_UNSET) + hci_conn_del(link); + } + + /* Unacked frames */ + hdev->acl_cnt += conn->sent; + } else if (conn->type == LE_LINK) { + cancel_delayed_work(&conn->le_conn_timeout); + + if (hdev->le_pkts) + hdev->le_cnt += conn->sent; + else + hdev->acl_cnt += conn->sent; + } else { + struct hci_conn *acl = conn->link; + + if (acl) { + hci_conn_unlink(conn); + hci_conn_drop(acl); + } + + /* Unacked ISO frames */ + if (conn->type == ISO_LINK) { + if (hdev->iso_pkts) + hdev->iso_cnt += conn->sent; + else if (hdev->le_pkts) + hdev->le_cnt += conn->sent; + else + hdev->acl_cnt += conn->sent; + } + } + + if (conn->amp_mgr) + amp_mgr_put(conn->amp_mgr); + + skb_queue_purge(&conn->data_q); + + /* Remove the connection from the list and cleanup its remaining + * state. This is a separate function since for some cases like + * BT_CONNECT_SCAN we *only* want the cleanup part without the + * rest of hci_conn_del. 
+ */ + hci_conn_cleanup(conn); + + return 0; +} + +struct hci_dev *hci_get_route(bdaddr_t *dst, bdaddr_t *src, uint8_t src_type) +{ + int use_src = bacmp(src, BDADDR_ANY); + struct hci_dev *hdev = NULL, *d; + + BT_DBG("%pMR -> %pMR", src, dst); + + read_lock(&hci_dev_list_lock); + + list_for_each_entry(d, &hci_dev_list, list) { + if (!test_bit(HCI_UP, &d->flags) || + hci_dev_test_flag(d, HCI_USER_CHANNEL) || + d->dev_type != HCI_PRIMARY) + continue; + + /* Simple routing: + * No source address - find interface with bdaddr != dst + * Source address - find interface with bdaddr == src + */ + + if (use_src) { + bdaddr_t id_addr; + u8 id_addr_type; + + if (src_type == BDADDR_BREDR) { + if (!lmp_bredr_capable(d)) + continue; + bacpy(&id_addr, &d->bdaddr); + id_addr_type = BDADDR_BREDR; + } else { + if (!lmp_le_capable(d)) + continue; + + hci_copy_identity_address(d, &id_addr, + &id_addr_type); + + /* Convert from HCI to three-value type */ + if (id_addr_type == ADDR_LE_DEV_PUBLIC) + id_addr_type = BDADDR_LE_PUBLIC; + else + id_addr_type = BDADDR_LE_RANDOM; + } + + if (!bacmp(&id_addr, src) && id_addr_type == src_type) { + hdev = d; break; + } + } else { + if (bacmp(&d->bdaddr, dst)) { + hdev = d; break; + } + } + } + + if (hdev) + hdev = hci_dev_hold(hdev); + + read_unlock(&hci_dev_list_lock); + return hdev; +} +EXPORT_SYMBOL(hci_get_route); + +/* This function requires the caller holds hdev->lock */ +static void hci_le_conn_failed(struct hci_conn *conn, u8 status) +{ + struct hci_dev *hdev = conn->hdev; + + hci_connect_le_scan_cleanup(conn, status); + + /* Enable advertising in case this was a failed connection + * attempt as a peripheral. + */ + hci_enable_advertising(hdev); +} + +/* This function requires the caller holds hdev->lock */ +void hci_conn_failed(struct hci_conn *conn, u8 status) +{ + struct hci_dev *hdev = conn->hdev; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + switch (conn->type) { + case LE_LINK: + hci_le_conn_failed(conn, status); + break; + case ACL_LINK: + mgmt_connect_failed(hdev, &conn->dst, conn->type, + conn->dst_type, status); + break; + } + + conn->state = BT_CLOSED; + hci_connect_cfm(conn, status); + hci_conn_del(conn); +} + +static void create_le_conn_complete(struct hci_dev *hdev, void *data, int err) +{ + struct hci_conn *conn = data; + + bt_dev_dbg(hdev, "err %d", err); + + hci_dev_lock(hdev); + + if (!err) { + hci_connect_le_scan_cleanup(conn, 0x00); + goto done; + } + + /* Check if connection is still pending */ + if (conn != hci_lookup_le_connect(hdev)) + goto done; + + hci_conn_failed(conn, bt_status(err)); + +done: + hci_dev_unlock(hdev); +} + +static int hci_connect_le_sync(struct hci_dev *hdev, void *data) +{ + struct hci_conn *conn = data; + + bt_dev_dbg(hdev, "conn %p", conn); + + return hci_le_create_conn_sync(hdev, conn); +} + +struct hci_conn *hci_connect_le(struct hci_dev *hdev, bdaddr_t *dst, + u8 dst_type, bool dst_resolved, u8 sec_level, + u16 conn_timeout, u8 role) +{ + struct hci_conn *conn; + struct smp_irk *irk; + int err; + + /* Let's make sure that le is enabled.*/ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + if (lmp_le_capable(hdev)) + return ERR_PTR(-ECONNREFUSED); + + return ERR_PTR(-EOPNOTSUPP); + } + + /* Since the controller supports only one LE connection attempt at a + * time, we return -EBUSY if there is any connection attempt running. 
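Callers therefore have to be prepared for an ERR_PTR-encoded failure; a minimal sketch of the calling convention (hdev, dst, dst_type and conn_timeout are assumed to come from the caller's context):

        struct hci_conn *conn;

        conn = hci_connect_le(hdev, dst, dst_type, false, BT_SECURITY_LOW,
                              conn_timeout, HCI_ROLE_MASTER);
        if (IS_ERR(conn))
                return PTR_ERR(conn);   /* e.g. -EBUSY while another attempt is pending */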
+ */ + if (hci_lookup_le_connect(hdev)) + return ERR_PTR(-EBUSY); + + /* If there's already a connection object but it's not in + * scanning state it means it must already be established, in + * which case we can't do anything else except report a failure + * to connect. + */ + conn = hci_conn_hash_lookup_le(hdev, dst, dst_type); + if (conn && !test_bit(HCI_CONN_SCANNING, &conn->flags)) { + return ERR_PTR(-EBUSY); + } + + /* Check if the destination address has been resolved by the controller + * since if it did then the identity address shall be used. + */ + if (!dst_resolved) { + /* When given an identity address with existing identity + * resolving key, the connection needs to be established + * to a resolvable random address. + * + * Storing the resolvable random address is required here + * to handle connection failures. The address will later + * be resolved back into the original identity address + * from the connect request. + */ + irk = hci_find_irk_by_addr(hdev, dst, dst_type); + if (irk && bacmp(&irk->rpa, BDADDR_ANY)) { + dst = &irk->rpa; + dst_type = ADDR_LE_DEV_RANDOM; + } + } + + if (conn) { + bacpy(&conn->dst, dst); + } else { + conn = hci_conn_add(hdev, LE_LINK, dst, role); + if (!conn) + return ERR_PTR(-ENOMEM); + hci_conn_hold(conn); + conn->pending_sec_level = sec_level; + } + + conn->dst_type = dst_type; + conn->sec_level = BT_SECURITY_LOW; + conn->conn_timeout = conn_timeout; + + conn->state = BT_CONNECT; + clear_bit(HCI_CONN_SCANNING, &conn->flags); + + err = hci_cmd_sync_queue(hdev, hci_connect_le_sync, conn, + create_le_conn_complete); + if (err) { + hci_conn_del(conn); + return ERR_PTR(err); + } + + return conn; +} + +static bool is_connected(struct hci_dev *hdev, bdaddr_t *addr, u8 type) +{ + struct hci_conn *conn; + + conn = hci_conn_hash_lookup_le(hdev, addr, type); + if (!conn) + return false; + + if (conn->state != BT_CONNECTED) + return false; + + return true; +} + +/* This function requires the caller holds hdev->lock */ +static int hci_explicit_conn_params_set(struct hci_dev *hdev, + bdaddr_t *addr, u8 addr_type) +{ + struct hci_conn_params *params; + + if (is_connected(hdev, addr, addr_type)) + return -EISCONN; + + params = hci_conn_params_lookup(hdev, addr, addr_type); + if (!params) { + params = hci_conn_params_add(hdev, addr, addr_type); + if (!params) + return -ENOMEM; + + /* If we created new params, mark them to be deleted in + * hci_connect_le_scan_cleanup. It's different case than + * existing disabled params, those will stay after cleanup. 
+ */ + params->auto_connect = HCI_AUTO_CONN_EXPLICIT; + } + + /* We're trying to connect, so make sure params are at pend_le_conns */ + if (params->auto_connect == HCI_AUTO_CONN_DISABLED || + params->auto_connect == HCI_AUTO_CONN_REPORT || + params->auto_connect == HCI_AUTO_CONN_EXPLICIT) { + hci_pend_le_list_del_init(params); + hci_pend_le_list_add(params, &hdev->pend_le_conns); + } + + params->explicit_connect = true; + + BT_DBG("addr %pMR (type %u) auto_connect %u", addr, addr_type, + params->auto_connect); + + return 0; +} + +static int qos_set_big(struct hci_dev *hdev, struct bt_iso_qos *qos) +{ + struct iso_list_data data; + + /* Allocate a BIG if not set */ + if (qos->big == BT_ISO_QOS_BIG_UNSET) { + for (data.big = 0x00; data.big < 0xef; data.big++) { + data.count = 0; + data.bis = 0xff; + + hci_conn_hash_list_state(hdev, bis_list, ISO_LINK, + BT_BOUND, &data); + if (!data.count) + break; + } + + if (data.big == 0xef) + return -EADDRNOTAVAIL; + + /* Update BIG */ + qos->big = data.big; + } + + return 0; +} + +static int qos_set_bis(struct hci_dev *hdev, struct bt_iso_qos *qos) +{ + struct iso_list_data data; + + /* Allocate BIS if not set */ + if (qos->bis == BT_ISO_QOS_BIS_UNSET) { + /* Find an unused adv set to advertise BIS, skip instance 0x00 + * since it is reserved as general purpose set. + */ + for (data.bis = 0x01; data.bis < hdev->le_num_of_adv_sets; + data.bis++) { + data.count = 0; + + hci_conn_hash_list_state(hdev, bis_list, ISO_LINK, + BT_BOUND, &data); + if (!data.count) + break; + } + + if (data.bis == hdev->le_num_of_adv_sets) + return -EADDRNOTAVAIL; + + /* Update BIS */ + qos->bis = data.bis; + } + + return 0; +} + +/* This function requires the caller holds hdev->lock */ +static struct hci_conn *hci_add_bis(struct hci_dev *hdev, bdaddr_t *dst, + struct bt_iso_qos *qos) +{ + struct hci_conn *conn; + struct iso_list_data data; + int err; + + /* Let's make sure that le is enabled.*/ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + if (lmp_le_capable(hdev)) + return ERR_PTR(-ECONNREFUSED); + return ERR_PTR(-EOPNOTSUPP); + } + + err = qos_set_big(hdev, qos); + if (err) + return ERR_PTR(err); + + err = qos_set_bis(hdev, qos); + if (err) + return ERR_PTR(err); + + data.big = qos->big; + data.bis = qos->bis; + data.count = 0; + + /* Check if there is already a matching BIG/BIS */ + hci_conn_hash_list_state(hdev, bis_list, ISO_LINK, BT_BOUND, &data); + if (data.count) + return ERR_PTR(-EADDRINUSE); + + conn = hci_conn_hash_lookup_bis(hdev, dst, qos->big, qos->bis); + if (conn) + return ERR_PTR(-EADDRINUSE); + + conn = hci_conn_add(hdev, ISO_LINK, dst, HCI_ROLE_MASTER); + if (!conn) + return ERR_PTR(-ENOMEM); + + set_bit(HCI_CONN_PER_ADV, &conn->flags); + conn->state = BT_CONNECT; + + hci_conn_hold(conn); + return conn; +} + +/* This function requires the caller holds hdev->lock */ +struct hci_conn *hci_connect_le_scan(struct hci_dev *hdev, bdaddr_t *dst, + u8 dst_type, u8 sec_level, + u16 conn_timeout, + enum conn_reasons conn_reason) +{ + struct hci_conn *conn; + + /* Let's make sure that le is enabled.*/ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + if (lmp_le_capable(hdev)) + return ERR_PTR(-ECONNREFUSED); + + return ERR_PTR(-EOPNOTSUPP); + } + + /* Some devices send ATT messages as soon as the physical link is + * established. To be able to handle these ATT messages, the user- + * space first establishes the connection and then starts the pairing + * process. 
+ * + * So if a hci_conn object already exists for the following connection + * attempt, we simply update pending_sec_level and auth_type fields + * and return the object found. + */ + conn = hci_conn_hash_lookup_le(hdev, dst, dst_type); + if (conn) { + if (conn->pending_sec_level < sec_level) + conn->pending_sec_level = sec_level; + goto done; + } + + BT_DBG("requesting refresh of dst_addr"); + + conn = hci_conn_add(hdev, LE_LINK, dst, HCI_ROLE_MASTER); + if (!conn) + return ERR_PTR(-ENOMEM); + + if (hci_explicit_conn_params_set(hdev, dst, dst_type) < 0) { + hci_conn_del(conn); + return ERR_PTR(-EBUSY); + } + + conn->state = BT_CONNECT; + set_bit(HCI_CONN_SCANNING, &conn->flags); + conn->dst_type = dst_type; + conn->sec_level = BT_SECURITY_LOW; + conn->pending_sec_level = sec_level; + conn->conn_timeout = conn_timeout; + conn->conn_reason = conn_reason; + + hci_update_passive_scan(hdev); + +done: + hci_conn_hold(conn); + return conn; +} + +struct hci_conn *hci_connect_acl(struct hci_dev *hdev, bdaddr_t *dst, + u8 sec_level, u8 auth_type, + enum conn_reasons conn_reason) +{ + struct hci_conn *acl; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + if (lmp_bredr_capable(hdev)) + return ERR_PTR(-ECONNREFUSED); + + return ERR_PTR(-EOPNOTSUPP); + } + + /* Reject outgoing connection to device with same BD ADDR against + * CVE-2020-26555 + */ + if (!bacmp(&hdev->bdaddr, dst)) { + bt_dev_dbg(hdev, "Reject connection with same BD_ADDR %pMR\n", + dst); + return ERR_PTR(-ECONNREFUSED); + } + + acl = hci_conn_hash_lookup_ba(hdev, ACL_LINK, dst); + if (!acl) { + acl = hci_conn_add(hdev, ACL_LINK, dst, HCI_ROLE_MASTER); + if (!acl) + return ERR_PTR(-ENOMEM); + } + + hci_conn_hold(acl); + + acl->conn_reason = conn_reason; + if (acl->state == BT_OPEN || acl->state == BT_CLOSED) { + acl->sec_level = BT_SECURITY_LOW; + acl->pending_sec_level = sec_level; + acl->auth_type = auth_type; + hci_acl_create_connection(acl); + } + + return acl; +} + +struct hci_conn *hci_connect_sco(struct hci_dev *hdev, int type, bdaddr_t *dst, + __u16 setting, struct bt_codec *codec) +{ + struct hci_conn *acl; + struct hci_conn *sco; + + acl = hci_connect_acl(hdev, dst, BT_SECURITY_LOW, HCI_AT_NO_BONDING, + CONN_REASON_SCO_CONNECT); + if (IS_ERR(acl)) + return acl; + + sco = hci_conn_hash_lookup_ba(hdev, type, dst); + if (!sco) { + sco = hci_conn_add(hdev, type, dst, HCI_ROLE_MASTER); + if (!sco) { + hci_conn_drop(acl); + return ERR_PTR(-ENOMEM); + } + } + + acl->link = sco; + sco->link = acl; + + hci_conn_hold(sco); + + sco->setting = setting; + sco->codec = *codec; + + if (acl->state == BT_CONNECTED && + (sco->state == BT_OPEN || sco->state == BT_CLOSED)) { + set_bit(HCI_CONN_POWER_SAVE, &acl->flags); + hci_conn_enter_active_mode(acl, BT_POWER_FORCE_ACTIVE_ON); + + if (test_bit(HCI_CONN_MODE_CHANGE_PEND, &acl->flags)) { + /* defer SCO setup until mode change completed */ + set_bit(HCI_CONN_SCO_SETUP_PEND, &acl->flags); + return sco; + } + + hci_sco_setup(acl, 0x00); + } + + return sco; +} + +static void cis_add(struct iso_list_data *d, struct bt_iso_qos *qos) +{ + struct hci_cis_params *cis = &d->pdu.cis[d->pdu.cp.num_cis]; + + cis->cis_id = qos->cis; + cis->c_sdu = cpu_to_le16(qos->out.sdu); + cis->p_sdu = cpu_to_le16(qos->in.sdu); + cis->c_phy = qos->out.phy ? qos->out.phy : qos->in.phy; + cis->p_phy = qos->in.phy ? 
qos->in.phy : qos->out.phy; + cis->c_rtn = qos->out.rtn; + cis->p_rtn = qos->in.rtn; + + d->pdu.cp.num_cis++; +} + +static void cis_list(struct hci_conn *conn, void *data) +{ + struct iso_list_data *d = data; + + /* Skip if broadcast/ANY address */ + if (!bacmp(&conn->dst, BDADDR_ANY)) + return; + + if (d->cig != conn->iso_qos.cig || d->cis == BT_ISO_QOS_CIS_UNSET || + d->cis != conn->iso_qos.cis) + return; + + d->count++; + + if (d->pdu.cp.cig_id == BT_ISO_QOS_CIG_UNSET || + d->count >= ARRAY_SIZE(d->pdu.cis)) + return; + + cis_add(d, &conn->iso_qos); +} + +static int hci_le_create_big(struct hci_conn *conn, struct bt_iso_qos *qos) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_cp_le_create_big cp; + + memset(&cp, 0, sizeof(cp)); + + cp.handle = qos->big; + cp.adv_handle = qos->bis; + cp.num_bis = 0x01; + hci_cpu_to_le24(qos->out.interval, cp.bis.sdu_interval); + cp.bis.sdu = cpu_to_le16(qos->out.sdu); + cp.bis.latency = cpu_to_le16(qos->out.latency); + cp.bis.rtn = qos->out.rtn; + cp.bis.phy = qos->out.phy; + cp.bis.packing = qos->packing; + cp.bis.framing = qos->framing; + cp.bis.encryption = 0x00; + memset(&cp.bis.bcode, 0, sizeof(cp.bis.bcode)); + + return hci_send_cmd(hdev, HCI_OP_LE_CREATE_BIG, sizeof(cp), &cp); +} + +static void set_cig_params_complete(struct hci_dev *hdev, void *data, int err) +{ + struct iso_cig_params *pdu = data; + + bt_dev_dbg(hdev, ""); + + if (err) + bt_dev_err(hdev, "Unable to set CIG parameters: %d", err); + + kfree(pdu); +} + +static int set_cig_params_sync(struct hci_dev *hdev, void *data) +{ + struct iso_cig_params *pdu = data; + u32 plen; + + plen = sizeof(pdu->cp) + pdu->cp.num_cis * sizeof(pdu->cis[0]); + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_CIG_PARAMS, plen, pdu, + HCI_CMD_TIMEOUT); +} + +static bool hci_le_set_cig_params(struct hci_conn *conn, struct bt_iso_qos *qos) +{ + struct hci_dev *hdev = conn->hdev; + struct iso_list_data data; + struct iso_cig_params *pdu; + + memset(&data, 0, sizeof(data)); + + /* Allocate a CIG if not set */ + if (qos->cig == BT_ISO_QOS_CIG_UNSET) { + for (data.cig = 0x00; data.cig < 0xff; data.cig++) { + data.count = 0; + data.cis = 0xff; + + hci_conn_hash_list_state(hdev, cis_list, ISO_LINK, + BT_BOUND, &data); + if (data.count) + continue; + + hci_conn_hash_list_state(hdev, cis_list, ISO_LINK, + BT_CONNECTED, &data); + if (!data.count) + break; + } + + if (data.cig == 0xff) + return false; + + /* Update CIG */ + qos->cig = data.cig; + } + + data.pdu.cp.cig_id = qos->cig; + hci_cpu_to_le24(qos->out.interval, data.pdu.cp.c_interval); + hci_cpu_to_le24(qos->in.interval, data.pdu.cp.p_interval); + data.pdu.cp.sca = qos->sca; + data.pdu.cp.packing = qos->packing; + data.pdu.cp.framing = qos->framing; + data.pdu.cp.c_latency = cpu_to_le16(qos->out.latency); + data.pdu.cp.p_latency = cpu_to_le16(qos->in.latency); + + if (qos->cis != BT_ISO_QOS_CIS_UNSET) { + data.count = 0; + data.cig = qos->cig; + data.cis = qos->cis; + + hci_conn_hash_list_state(hdev, cis_list, ISO_LINK, BT_BOUND, + &data); + if (data.count) + return false; + + cis_add(&data, qos); + } + + /* Reprogram all CIS(s) with the same CIG */ + for (data.cig = qos->cig, data.cis = 0x00; data.cis < 0x11; + data.cis++) { + data.count = 0; + + hci_conn_hash_list_state(hdev, cis_list, ISO_LINK, BT_BOUND, + &data); + if (data.count) + continue; + + /* Allocate a CIS if not set */ + if (qos->cis == BT_ISO_QOS_CIS_UNSET) { + /* Update CIS */ + qos->cis = data.cis; + cis_add(&data, qos); + } + } + + if (qos->cis == BT_ISO_QOS_CIS_UNSET || 
!data.pdu.cp.num_cis) + return false; + + pdu = kzalloc(sizeof(*pdu), GFP_KERNEL); + if (!pdu) + return false; + + memcpy(pdu, &data.pdu, sizeof(*pdu)); + + if (hci_cmd_sync_queue(hdev, set_cig_params_sync, pdu, + set_cig_params_complete) < 0) { + kfree(pdu); + return false; + } + + return true; +} + +struct hci_conn *hci_bind_cis(struct hci_dev *hdev, bdaddr_t *dst, + __u8 dst_type, struct bt_iso_qos *qos) +{ + struct hci_conn *cis; + + cis = hci_conn_hash_lookup_cis(hdev, dst, dst_type); + if (!cis) { + cis = hci_conn_add(hdev, ISO_LINK, dst, HCI_ROLE_MASTER); + if (!cis) + return ERR_PTR(-ENOMEM); + cis->cleanup = cis_cleanup; + cis->dst_type = dst_type; + } + + if (cis->state == BT_CONNECTED) + return cis; + + /* Check if CIS has been set and the settings matches */ + if (cis->state == BT_BOUND && + !memcmp(&cis->iso_qos, qos, sizeof(*qos))) + return cis; + + /* Update LINK PHYs according to QoS preference */ + cis->le_tx_phy = qos->out.phy; + cis->le_rx_phy = qos->in.phy; + + /* If output interval is not set use the input interval as it cannot be + * 0x000000. + */ + if (!qos->out.interval) + qos->out.interval = qos->in.interval; + + /* If input interval is not set use the output interval as it cannot be + * 0x000000. + */ + if (!qos->in.interval) + qos->in.interval = qos->out.interval; + + /* If output latency is not set use the input latency as it cannot be + * 0x0000. + */ + if (!qos->out.latency) + qos->out.latency = qos->in.latency; + + /* If input latency is not set use the output latency as it cannot be + * 0x0000. + */ + if (!qos->in.latency) + qos->in.latency = qos->out.latency; + + if (!hci_le_set_cig_params(cis, qos)) { + hci_conn_drop(cis); + return ERR_PTR(-EINVAL); + } + + cis->iso_qos = *qos; + cis->state = BT_BOUND; + + return cis; +} + +bool hci_iso_setup_path(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_cp_le_setup_iso_path cmd; + + memset(&cmd, 0, sizeof(cmd)); + + if (conn->iso_qos.out.sdu) { + cmd.handle = cpu_to_le16(conn->handle); + cmd.direction = 0x00; /* Input (Host to Controller) */ + cmd.path = 0x00; /* HCI path if enabled */ + cmd.codec = 0x03; /* Transparent Data */ + + if (hci_send_cmd(hdev, HCI_OP_LE_SETUP_ISO_PATH, sizeof(cmd), + &cmd) < 0) + return false; + } + + if (conn->iso_qos.in.sdu) { + cmd.handle = cpu_to_le16(conn->handle); + cmd.direction = 0x01; /* Output (Controller to Host) */ + cmd.path = 0x00; /* HCI path if enabled */ + cmd.codec = 0x03; /* Transparent Data */ + + if (hci_send_cmd(hdev, HCI_OP_LE_SETUP_ISO_PATH, sizeof(cmd), + &cmd) < 0) + return false; + } + + return true; +} + +static int hci_create_cis_sync(struct hci_dev *hdev, void *data) +{ + struct { + struct hci_cp_le_create_cis cp; + struct hci_cis cis[0x1f]; + } cmd; + struct hci_conn *conn = data; + u8 cig; + + memset(&cmd, 0, sizeof(cmd)); + cmd.cis[0].acl_handle = cpu_to_le16(conn->link->handle); + cmd.cis[0].cis_handle = cpu_to_le16(conn->handle); + cmd.cp.num_cis++; + cig = conn->iso_qos.cig; + + hci_dev_lock(hdev); + + rcu_read_lock(); + + list_for_each_entry_rcu(conn, &hdev->conn_hash.list, list) { + struct hci_cis *cis = &cmd.cis[cmd.cp.num_cis]; + + if (conn == data || conn->type != ISO_LINK || + conn->state == BT_CONNECTED || conn->iso_qos.cig != cig) + continue; + + /* Check if all CIS(s) belonging to a CIG are ready */ + if (!conn->link || conn->link->state != BT_CONNECTED || + conn->state != BT_CONNECT) { + cmd.cp.num_cis = 0; + break; + } + + /* Group all CIS with state BT_CONNECT since the spec don't + * allow to send them 
individually: + * + * BLUETOOTH CORE SPECIFICATION Version 5.3 | Vol 4, Part E + * page 2566: + * + * If the Host issues this command before all the + * HCI_LE_CIS_Established events from the previous use of the + * command have been generated, the Controller shall return the + * error code Command Disallowed (0x0C). + */ + cis->acl_handle = cpu_to_le16(conn->link->handle); + cis->cis_handle = cpu_to_le16(conn->handle); + cmd.cp.num_cis++; + } + + rcu_read_unlock(); + + hci_dev_unlock(hdev); + + if (!cmd.cp.num_cis) + return 0; + + return hci_send_cmd(hdev, HCI_OP_LE_CREATE_CIS, sizeof(cmd.cp) + + sizeof(cmd.cis[0]) * cmd.cp.num_cis, &cmd); +} + +int hci_le_create_cis(struct hci_conn *conn) +{ + struct hci_conn *cis; + struct hci_dev *hdev = conn->hdev; + int err; + + switch (conn->type) { + case LE_LINK: + if (!conn->link || conn->state != BT_CONNECTED) + return -EINVAL; + cis = conn->link; + break; + case ISO_LINK: + cis = conn; + break; + default: + return -EINVAL; + } + + if (cis->state == BT_CONNECT) + return 0; + + /* Queue Create CIS */ + err = hci_cmd_sync_queue(hdev, hci_create_cis_sync, cis, NULL); + if (err) + return err; + + cis->state = BT_CONNECT; + + return 0; +} + +static void hci_iso_qos_setup(struct hci_dev *hdev, struct hci_conn *conn, + struct bt_iso_io_qos *qos, __u8 phy) +{ + /* Only set MTU if PHY is enabled */ + if (!qos->sdu && qos->phy) { + if (hdev->iso_mtu > 0) + qos->sdu = hdev->iso_mtu; + else if (hdev->le_mtu > 0) + qos->sdu = hdev->le_mtu; + else + qos->sdu = hdev->acl_mtu; + } + + /* Use the same PHY as ACL if set to any */ + if (qos->phy == BT_ISO_PHY_ANY) + qos->phy = phy; + + /* Use LE ACL connection interval if not set */ + if (!qos->interval) + /* ACL interval unit in 1.25 ms to us */ + qos->interval = conn->le_conn_interval * 1250; + + /* Use LE ACL connection latency if not set */ + if (!qos->latency) + qos->latency = conn->le_conn_latency; +} + +static void hci_bind_bis(struct hci_conn *conn, + struct bt_iso_qos *qos) +{ + /* Update LINK PHYs according to QoS preference */ + conn->le_tx_phy = qos->out.phy; + conn->le_tx_phy = qos->out.phy; + conn->iso_qos = *qos; + conn->state = BT_BOUND; +} + +static int create_big_sync(struct hci_dev *hdev, void *data) +{ + struct hci_conn *conn = data; + struct bt_iso_qos *qos = &conn->iso_qos; + u16 interval, sync_interval = 0; + u32 flags = 0; + int err; + + if (qos->out.phy == 0x02) + flags |= MGMT_ADV_FLAG_SEC_2M; + + /* Align intervals */ + interval = qos->out.interval / 1250; + + if (qos->bis) + sync_interval = qos->sync_interval * 1600; + + err = hci_start_per_adv_sync(hdev, qos->bis, conn->le_per_adv_data_len, + conn->le_per_adv_data, flags, interval, + interval, sync_interval); + if (err) + return err; + + return hci_le_create_big(conn, &conn->iso_qos); +} + +static void create_pa_complete(struct hci_dev *hdev, void *data, int err) +{ + struct hci_cp_le_pa_create_sync *cp = data; + + bt_dev_dbg(hdev, ""); + + if (err) + bt_dev_err(hdev, "Unable to create PA: %d", err); + + kfree(cp); +} + +static int create_pa_sync(struct hci_dev *hdev, void *data) +{ + struct hci_cp_le_pa_create_sync *cp = data; + int err; + + err = __hci_cmd_sync_status(hdev, HCI_OP_LE_PA_CREATE_SYNC, + sizeof(*cp), cp, HCI_CMD_TIMEOUT); + if (err) { + hci_dev_clear_flag(hdev, HCI_PA_SYNC); + return err; + } + + return hci_update_passive_scan_sync(hdev); +} + +int hci_pa_create_sync(struct hci_dev *hdev, bdaddr_t *dst, __u8 dst_type, + __u8 sid) +{ + struct hci_cp_le_pa_create_sync *cp; + + if (hci_dev_test_and_set_flag(hdev, 
HCI_PA_SYNC)) + return -EBUSY; + + cp = kmalloc(sizeof(*cp), GFP_KERNEL); + if (!cp) { + hci_dev_clear_flag(hdev, HCI_PA_SYNC); + return -ENOMEM; + } + + /* Convert from ISO socket address type to HCI address type */ + if (dst_type == BDADDR_LE_PUBLIC) + dst_type = ADDR_LE_DEV_PUBLIC; + else + dst_type = ADDR_LE_DEV_RANDOM; + + memset(cp, 0, sizeof(*cp)); + cp->sid = sid; + cp->addr_type = dst_type; + bacpy(&cp->addr, dst); + + /* Queue start pa_create_sync and scan */ + return hci_cmd_sync_queue(hdev, create_pa_sync, cp, create_pa_complete); +} + +int hci_le_big_create_sync(struct hci_dev *hdev, struct bt_iso_qos *qos, + __u16 sync_handle, __u8 num_bis, __u8 bis[]) +{ + struct _packed { + struct hci_cp_le_big_create_sync cp; + __u8 bis[0x11]; + } pdu; + int err; + + if (num_bis > sizeof(pdu.bis)) + return -EINVAL; + + err = qos_set_big(hdev, qos); + if (err) + return err; + + memset(&pdu, 0, sizeof(pdu)); + pdu.cp.handle = qos->big; + pdu.cp.sync_handle = cpu_to_le16(sync_handle); + pdu.cp.num_bis = num_bis; + memcpy(pdu.bis, bis, num_bis); + + return hci_send_cmd(hdev, HCI_OP_LE_BIG_CREATE_SYNC, + sizeof(pdu.cp) + num_bis, &pdu); +} + +static void create_big_complete(struct hci_dev *hdev, void *data, int err) +{ + struct hci_conn *conn = data; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (err) { + bt_dev_err(hdev, "Unable to create BIG: %d", err); + hci_connect_cfm(conn, err); + hci_conn_del(conn); + } +} + +struct hci_conn *hci_connect_bis(struct hci_dev *hdev, bdaddr_t *dst, + __u8 dst_type, struct bt_iso_qos *qos, + __u8 base_len, __u8 *base) +{ + struct hci_conn *conn; + int err; + + /* We need hci_conn object using the BDADDR_ANY as dst */ + conn = hci_add_bis(hdev, dst, qos); + if (IS_ERR(conn)) + return conn; + + hci_bind_bis(conn, qos); + + /* Add Basic Announcement into Peridic Adv Data if BASE is set */ + if (base_len && base) { + base_len = eir_append_service_data(conn->le_per_adv_data, 0, + 0x1851, base, base_len); + conn->le_per_adv_data_len = base_len; + } + + /* Queue start periodic advertising and create BIG */ + err = hci_cmd_sync_queue(hdev, create_big_sync, conn, + create_big_complete); + if (err < 0) { + hci_conn_drop(conn); + return ERR_PTR(err); + } + + hci_iso_qos_setup(hdev, conn, &qos->out, + conn->le_tx_phy ? conn->le_tx_phy : + hdev->le_tx_def_phys); + + return conn; +} + +struct hci_conn *hci_connect_cis(struct hci_dev *hdev, bdaddr_t *dst, + __u8 dst_type, struct bt_iso_qos *qos) +{ + struct hci_conn *le; + struct hci_conn *cis; + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) + le = hci_connect_le(hdev, dst, dst_type, false, + BT_SECURITY_LOW, + HCI_LE_CONN_TIMEOUT, + HCI_ROLE_SLAVE); + else + le = hci_connect_le_scan(hdev, dst, dst_type, + BT_SECURITY_LOW, + HCI_LE_CONN_TIMEOUT, + CONN_REASON_ISO_CONNECT); + if (IS_ERR(le)) + return le; + + hci_iso_qos_setup(hdev, le, &qos->out, + le->le_tx_phy ? le->le_tx_phy : hdev->le_tx_def_phys); + hci_iso_qos_setup(hdev, le, &qos->in, + le->le_rx_phy ? le->le_rx_phy : hdev->le_rx_def_phys); + + cis = hci_bind_cis(hdev, dst, dst_type, qos); + if (IS_ERR(cis)) { + hci_conn_drop(le); + return cis; + } + + le->link = cis; + cis->link = le; + + hci_conn_hold(cis); + + /* If LE is already connected and CIS handle is already set proceed to + * Create CIS immediately. 
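+ * Otherwise hci_le_create_cis() is skipped here and is expected to run + * later, once the LE link is up and a CIS handle has been assigned.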
+ */ + if (le->state == BT_CONNECTED && cis->handle != HCI_CONN_HANDLE_UNSET) + hci_le_create_cis(le); + + return cis; +} + +/* Check link security requirement */ +int hci_conn_check_link_mode(struct hci_conn *conn) +{ + BT_DBG("hcon %p", conn); + + /* In Secure Connections Only mode, it is required that Secure + * Connections is used and the link is encrypted with AES-CCM + * using a P-256 authenticated combination key. + */ + if (hci_dev_test_flag(conn->hdev, HCI_SC_ONLY)) { + if (!hci_conn_sc_enabled(conn) || + !test_bit(HCI_CONN_AES_CCM, &conn->flags) || + conn->key_type != HCI_LK_AUTH_COMBINATION_P256) + return 0; + } + + /* AES encryption is required for Level 4: + * + * BLUETOOTH CORE SPECIFICATION Version 5.2 | Vol 3, Part C + * page 1319: + * + * 128-bit equivalent strength for link and encryption keys + * required using FIPS approved algorithms (E0 not allowed, + * SAFER+ not allowed, and P-192 not allowed; encryption key + * not shortened) + */ + if (conn->sec_level == BT_SECURITY_FIPS && + !test_bit(HCI_CONN_AES_CCM, &conn->flags)) { + bt_dev_err(conn->hdev, + "Invalid security: Missing AES-CCM usage"); + return 0; + } + + if (hci_conn_ssp_enabled(conn) && + !test_bit(HCI_CONN_ENCRYPT, &conn->flags)) + return 0; + + return 1; +} + +/* Authenticate remote device */ +static int hci_conn_auth(struct hci_conn *conn, __u8 sec_level, __u8 auth_type) +{ + BT_DBG("hcon %p", conn); + + if (conn->pending_sec_level > sec_level) + sec_level = conn->pending_sec_level; + + if (sec_level > conn->sec_level) + conn->pending_sec_level = sec_level; + else if (test_bit(HCI_CONN_AUTH, &conn->flags)) + return 1; + + /* Make sure we preserve an existing MITM requirement*/ + auth_type |= (conn->auth_type & 0x01); + + conn->auth_type = auth_type; + + if (!test_and_set_bit(HCI_CONN_AUTH_PEND, &conn->flags)) { + struct hci_cp_auth_requested cp; + + cp.handle = cpu_to_le16(conn->handle); + hci_send_cmd(conn->hdev, HCI_OP_AUTH_REQUESTED, + sizeof(cp), &cp); + + /* Set the ENCRYPT_PEND to trigger encryption after + * authentication. + */ + if (!test_bit(HCI_CONN_ENCRYPT, &conn->flags)) + set_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags); + } + + return 0; +} + +/* Encrypt the link */ +static void hci_conn_encrypt(struct hci_conn *conn) +{ + BT_DBG("hcon %p", conn); + + if (!test_and_set_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags)) { + struct hci_cp_set_conn_encrypt cp; + cp.handle = cpu_to_le16(conn->handle); + cp.encrypt = 0x01; + hci_send_cmd(conn->hdev, HCI_OP_SET_CONN_ENCRYPT, sizeof(cp), + &cp); + } +} + +/* Enable security */ +int hci_conn_security(struct hci_conn *conn, __u8 sec_level, __u8 auth_type, + bool initiator) +{ + BT_DBG("hcon %p", conn); + + if (conn->type == LE_LINK) + return smp_conn_security(conn, sec_level); + + /* For sdp we don't need the link key. */ + if (sec_level == BT_SECURITY_SDP) + return 1; + + /* For non 2.1 devices and low security level we don't need the link + key. */ + if (sec_level == BT_SECURITY_LOW && !hci_conn_ssp_enabled(conn)) + return 1; + + /* For other security levels we need the link key. */ + if (!test_bit(HCI_CONN_AUTH, &conn->flags)) + goto auth; + + switch (conn->key_type) { + case HCI_LK_AUTH_COMBINATION_P256: + /* An authenticated FIPS approved combination key has + * sufficient security for security level 4 or lower. + */ + if (sec_level <= BT_SECURITY_FIPS) + goto encrypt; + break; + case HCI_LK_AUTH_COMBINATION_P192: + /* An authenticated combination key has sufficient security for + * security level 3 or lower. 
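+ * (In these checks BT_SECURITY_MEDIUM, BT_SECURITY_HIGH and + * BT_SECURITY_FIPS correspond to security levels 2, 3 and 4.)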
+ */ + if (sec_level <= BT_SECURITY_HIGH) + goto encrypt; + break; + case HCI_LK_UNAUTH_COMBINATION_P192: + case HCI_LK_UNAUTH_COMBINATION_P256: + /* An unauthenticated combination key has sufficient security + * for security level 2 or lower. + */ + if (sec_level <= BT_SECURITY_MEDIUM) + goto encrypt; + break; + case HCI_LK_COMBINATION: + /* A combination key has always sufficient security for the + * security levels 2 or lower. High security level requires the + * combination key is generated using maximum PIN code length + * (16). For pre 2.1 units. + */ + if (sec_level <= BT_SECURITY_MEDIUM || conn->pin_length == 16) + goto encrypt; + break; + default: + break; + } + +auth: + if (test_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags)) + return 0; + + if (initiator) + set_bit(HCI_CONN_AUTH_INITIATOR, &conn->flags); + + if (!hci_conn_auth(conn, sec_level, auth_type)) + return 0; + +encrypt: + if (test_bit(HCI_CONN_ENCRYPT, &conn->flags)) { + /* Ensure that the encryption key size has been read, + * otherwise stall the upper layer responses. + */ + if (!conn->enc_key_size) + return 0; + + /* Nothing else needed, all requirements are met */ + return 1; + } + + hci_conn_encrypt(conn); + return 0; +} +EXPORT_SYMBOL(hci_conn_security); + +/* Check secure link requirement */ +int hci_conn_check_secure(struct hci_conn *conn, __u8 sec_level) +{ + BT_DBG("hcon %p", conn); + + /* Accept if non-secure or higher security level is required */ + if (sec_level != BT_SECURITY_HIGH && sec_level != BT_SECURITY_FIPS) + return 1; + + /* Accept if secure or higher security level is already present */ + if (conn->sec_level == BT_SECURITY_HIGH || + conn->sec_level == BT_SECURITY_FIPS) + return 1; + + /* Reject not secure link */ + return 0; +} +EXPORT_SYMBOL(hci_conn_check_secure); + +/* Switch role */ +int hci_conn_switch_role(struct hci_conn *conn, __u8 role) +{ + BT_DBG("hcon %p", conn); + + if (role == conn->role) + return 1; + + if (!test_and_set_bit(HCI_CONN_RSWITCH_PEND, &conn->flags)) { + struct hci_cp_switch_role cp; + bacpy(&cp.bdaddr, &conn->dst); + cp.role = role; + hci_send_cmd(conn->hdev, HCI_OP_SWITCH_ROLE, sizeof(cp), &cp); + } + + return 0; +} +EXPORT_SYMBOL(hci_conn_switch_role); + +/* Enter active mode */ +void hci_conn_enter_active_mode(struct hci_conn *conn, __u8 force_active) +{ + struct hci_dev *hdev = conn->hdev; + + BT_DBG("hcon %p mode %d", conn, conn->mode); + + if (conn->mode != HCI_CM_SNIFF) + goto timer; + + if (!test_bit(HCI_CONN_POWER_SAVE, &conn->flags) && !force_active) + goto timer; + + if (!test_and_set_bit(HCI_CONN_MODE_CHANGE_PEND, &conn->flags)) { + struct hci_cp_exit_sniff_mode cp; + cp.handle = cpu_to_le16(conn->handle); + hci_send_cmd(hdev, HCI_OP_EXIT_SNIFF_MODE, sizeof(cp), &cp); + } + +timer: + if (hdev->idle_timeout > 0) + queue_delayed_work(hdev->workqueue, &conn->idle_work, + msecs_to_jiffies(hdev->idle_timeout)); +} + +/* Drop all connection on the device */ +void hci_conn_hash_flush(struct hci_dev *hdev) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *c, *n; + + BT_DBG("hdev %s", hdev->name); + + list_for_each_entry_safe(c, n, &h->list, list) { + c->state = BT_CLOSED; + + hci_disconn_cfm(c, HCI_ERROR_LOCAL_HOST_TERM); + + /* Unlink before deleting otherwise it is possible that + * hci_conn_del removes the link which may cause the list to + * contain items already freed. 
+ */ + hci_conn_unlink(c); + hci_conn_del(c); + } +} + +/* Check pending connect attempts */ +void hci_conn_check_pending(struct hci_dev *hdev) +{ + struct hci_conn *conn; + + BT_DBG("hdev %s", hdev->name); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_state(hdev, ACL_LINK, BT_CONNECT2); + if (conn) + hci_acl_create_connection(conn); + + hci_dev_unlock(hdev); +} + +static u32 get_link_mode(struct hci_conn *conn) +{ + u32 link_mode = 0; + + if (conn->role == HCI_ROLE_MASTER) + link_mode |= HCI_LM_MASTER; + + if (test_bit(HCI_CONN_ENCRYPT, &conn->flags)) + link_mode |= HCI_LM_ENCRYPT; + + if (test_bit(HCI_CONN_AUTH, &conn->flags)) + link_mode |= HCI_LM_AUTH; + + if (test_bit(HCI_CONN_SECURE, &conn->flags)) + link_mode |= HCI_LM_SECURE; + + if (test_bit(HCI_CONN_FIPS, &conn->flags)) + link_mode |= HCI_LM_FIPS; + + return link_mode; +} + +int hci_get_conn_list(void __user *arg) +{ + struct hci_conn *c; + struct hci_conn_list_req req, *cl; + struct hci_conn_info *ci; + struct hci_dev *hdev; + int n = 0, size, err; + + if (copy_from_user(&req, arg, sizeof(req))) + return -EFAULT; + + if (!req.conn_num || req.conn_num > (PAGE_SIZE * 2) / sizeof(*ci)) + return -EINVAL; + + size = sizeof(req) + req.conn_num * sizeof(*ci); + + cl = kmalloc(size, GFP_KERNEL); + if (!cl) + return -ENOMEM; + + hdev = hci_dev_get(req.dev_id); + if (!hdev) { + kfree(cl); + return -ENODEV; + } + + ci = cl->conn_info; + + hci_dev_lock(hdev); + list_for_each_entry(c, &hdev->conn_hash.list, list) { + bacpy(&(ci + n)->bdaddr, &c->dst); + (ci + n)->handle = c->handle; + (ci + n)->type = c->type; + (ci + n)->out = c->out; + (ci + n)->state = c->state; + (ci + n)->link_mode = get_link_mode(c); + if (++n >= req.conn_num) + break; + } + hci_dev_unlock(hdev); + + cl->dev_id = hdev->id; + cl->conn_num = n; + size = sizeof(req) + n * sizeof(*ci); + + hci_dev_put(hdev); + + err = copy_to_user(arg, cl, size); + kfree(cl); + + return err ? -EFAULT : 0; +} + +int hci_get_conn_info(struct hci_dev *hdev, void __user *arg) +{ + struct hci_conn_info_req req; + struct hci_conn_info ci; + struct hci_conn *conn; + char __user *ptr = arg + sizeof(req); + + if (copy_from_user(&req, arg, sizeof(req))) + return -EFAULT; + + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_ba(hdev, req.type, &req.bdaddr); + if (conn) { + bacpy(&ci.bdaddr, &conn->dst); + ci.handle = conn->handle; + ci.type = conn->type; + ci.out = conn->out; + ci.state = conn->state; + ci.link_mode = get_link_mode(conn); + } + hci_dev_unlock(hdev); + + if (!conn) + return -ENOENT; + + return copy_to_user(ptr, &ci, sizeof(ci)) ? -EFAULT : 0; +} + +int hci_get_auth_info(struct hci_dev *hdev, void __user *arg) +{ + struct hci_auth_info_req req; + struct hci_conn *conn; + + if (copy_from_user(&req, arg, sizeof(req))) + return -EFAULT; + + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &req.bdaddr); + if (conn) + req.type = conn->auth_type; + hci_dev_unlock(hdev); + + if (!conn) + return -ENOENT; + + return copy_to_user(arg, &req, sizeof(req)) ? 
-EFAULT : 0; +} + +struct hci_chan *hci_chan_create(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_chan *chan; + + BT_DBG("%s hcon %p", hdev->name, conn); + + if (test_bit(HCI_CONN_DROP, &conn->flags)) { + BT_DBG("Refusing to create new hci_chan"); + return NULL; + } + + chan = kzalloc(sizeof(*chan), GFP_KERNEL); + if (!chan) + return NULL; + + chan->conn = hci_conn_get(conn); + skb_queue_head_init(&chan->data_q); + chan->state = BT_CONNECTED; + + list_add_rcu(&chan->list, &conn->chan_list); + + return chan; +} + +void hci_chan_del(struct hci_chan *chan) +{ + struct hci_conn *conn = chan->conn; + struct hci_dev *hdev = conn->hdev; + + BT_DBG("%s hcon %p chan %p", hdev->name, conn, chan); + + list_del_rcu(&chan->list); + + synchronize_rcu(); + + /* Prevent new hci_chan's to be created for this hci_conn */ + set_bit(HCI_CONN_DROP, &conn->flags); + + hci_conn_put(conn); + + skb_queue_purge(&chan->data_q); + kfree(chan); +} + +void hci_chan_list_flush(struct hci_conn *conn) +{ + struct hci_chan *chan, *n; + + BT_DBG("hcon %p", conn); + + list_for_each_entry_safe(chan, n, &conn->chan_list, list) + hci_chan_del(chan); +} + +static struct hci_chan *__hci_chan_lookup_handle(struct hci_conn *hcon, + __u16 handle) +{ + struct hci_chan *hchan; + + list_for_each_entry(hchan, &hcon->chan_list, list) { + if (hchan->handle == handle) + return hchan; + } + + return NULL; +} + +struct hci_chan *hci_chan_lookup_handle(struct hci_dev *hdev, __u16 handle) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *hcon; + struct hci_chan *hchan = NULL; + + rcu_read_lock(); + + list_for_each_entry_rcu(hcon, &h->list, list) { + hchan = __hci_chan_lookup_handle(hcon, handle); + if (hchan) + break; + } + + rcu_read_unlock(); + + return hchan; +} + +u32 hci_conn_get_phy(struct hci_conn *conn) +{ + u32 phys = 0; + + /* BLUETOOTH CORE SPECIFICATION Version 5.2 | Vol 2, Part B page 471: + * Table 6.2: Packets defined for synchronous, asynchronous, and + * CPB logical transport types. + */ + switch (conn->type) { + case SCO_LINK: + /* SCO logical transport (1 Mb/s): + * HV1, HV2, HV3 and DV. + */ + phys |= BT_PHY_BR_1M_1SLOT; + + break; + + case ACL_LINK: + /* ACL logical transport (1 Mb/s) ptt=0: + * DH1, DM3, DH3, DM5 and DH5. + */ + phys |= BT_PHY_BR_1M_1SLOT; + + if (conn->pkt_type & (HCI_DM3 | HCI_DH3)) + phys |= BT_PHY_BR_1M_3SLOT; + + if (conn->pkt_type & (HCI_DM5 | HCI_DH5)) + phys |= BT_PHY_BR_1M_5SLOT; + + /* ACL logical transport (2 Mb/s) ptt=1: + * 2-DH1, 2-DH3 and 2-DH5. + */ + if (!(conn->pkt_type & HCI_2DH1)) + phys |= BT_PHY_EDR_2M_1SLOT; + + if (!(conn->pkt_type & HCI_2DH3)) + phys |= BT_PHY_EDR_2M_3SLOT; + + if (!(conn->pkt_type & HCI_2DH5)) + phys |= BT_PHY_EDR_2M_5SLOT; + + /* ACL logical transport (3 Mb/s) ptt=1: + * 3-DH1, 3-DH3 and 3-DH5. 
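+ * Note that the EDR bits in pkt_type mean "shall not be used", which + * is why the 2-DH and 3-DH tests above and below are negated.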
+ */ + if (!(conn->pkt_type & HCI_3DH1)) + phys |= BT_PHY_EDR_3M_1SLOT; + + if (!(conn->pkt_type & HCI_3DH3)) + phys |= BT_PHY_EDR_3M_3SLOT; + + if (!(conn->pkt_type & HCI_3DH5)) + phys |= BT_PHY_EDR_3M_5SLOT; + + break; + + case ESCO_LINK: + /* eSCO logical transport (1 Mb/s): EV3, EV4 and EV5 */ + phys |= BT_PHY_BR_1M_1SLOT; + + if (!(conn->pkt_type & (ESCO_EV4 | ESCO_EV5))) + phys |= BT_PHY_BR_1M_3SLOT; + + /* eSCO logical transport (2 Mb/s): 2-EV3, 2-EV5 */ + if (!(conn->pkt_type & ESCO_2EV3)) + phys |= BT_PHY_EDR_2M_1SLOT; + + if (!(conn->pkt_type & ESCO_2EV5)) + phys |= BT_PHY_EDR_2M_3SLOT; + + /* eSCO logical transport (3 Mb/s): 3-EV3, 3-EV5 */ + if (!(conn->pkt_type & ESCO_3EV3)) + phys |= BT_PHY_EDR_3M_1SLOT; + + if (!(conn->pkt_type & ESCO_3EV5)) + phys |= BT_PHY_EDR_3M_3SLOT; + + break; + + case LE_LINK: + if (conn->le_tx_phy & HCI_LE_SET_PHY_1M) + phys |= BT_PHY_LE_1M_TX; + + if (conn->le_rx_phy & HCI_LE_SET_PHY_1M) + phys |= BT_PHY_LE_1M_RX; + + if (conn->le_tx_phy & HCI_LE_SET_PHY_2M) + phys |= BT_PHY_LE_2M_TX; + + if (conn->le_rx_phy & HCI_LE_SET_PHY_2M) + phys |= BT_PHY_LE_2M_RX; + + if (conn->le_tx_phy & HCI_LE_SET_PHY_CODED) + phys |= BT_PHY_LE_CODED_TX; + + if (conn->le_rx_phy & HCI_LE_SET_PHY_CODED) + phys |= BT_PHY_LE_CODED_RX; + + break; + } + + return phys; +} + +int hci_abort_conn(struct hci_conn *conn, u8 reason) +{ + int r = 0; + + if (test_and_set_bit(HCI_CONN_CANCEL, &conn->flags)) + return 0; + + switch (conn->state) { + case BT_CONNECTED: + case BT_CONFIG: + if (conn->type == AMP_LINK) { + struct hci_cp_disconn_phy_link cp; + + cp.phy_handle = HCI_PHY_HANDLE(conn->handle); + cp.reason = reason; + r = hci_send_cmd(conn->hdev, HCI_OP_DISCONN_PHY_LINK, + sizeof(cp), &cp); + } else { + struct hci_cp_disconnect dc; + + dc.handle = cpu_to_le16(conn->handle); + dc.reason = reason; + r = hci_send_cmd(conn->hdev, HCI_OP_DISCONNECT, + sizeof(dc), &dc); + } + + conn->state = BT_DISCONN; + + break; + case BT_CONNECT: + if (conn->type == LE_LINK) { + if (test_bit(HCI_CONN_SCANNING, &conn->flags)) + break; + r = hci_send_cmd(conn->hdev, + HCI_OP_LE_CREATE_CONN_CANCEL, 0, NULL); + } else if (conn->type == ACL_LINK) { + if (conn->hdev->hci_ver < BLUETOOTH_VER_1_2) + break; + r = hci_send_cmd(conn->hdev, + HCI_OP_CREATE_CONN_CANCEL, + 6, &conn->dst); + } + break; + case BT_CONNECT2: + if (conn->type == ACL_LINK) { + struct hci_cp_reject_conn_req rej; + + bacpy(&rej.bdaddr, &conn->dst); + rej.reason = reason; + + r = hci_send_cmd(conn->hdev, + HCI_OP_REJECT_CONN_REQ, + sizeof(rej), &rej); + } else if (conn->type == SCO_LINK || conn->type == ESCO_LINK) { + struct hci_cp_reject_sync_conn_req rej; + + bacpy(&rej.bdaddr, &conn->dst); + + /* SCO rejection has its own limited set of + * allowed error values (0x0D-0x0F) which isn't + * compatible with most values passed to this + * function. To be safe hard-code one of the + * values that's suitable for SCO. 
+ */ + rej.reason = HCI_ERROR_REJ_LIMITED_RESOURCES; + + r = hci_send_cmd(conn->hdev, + HCI_OP_REJECT_SYNC_CONN_REQ, + sizeof(rej), &rej); + } + break; + default: + conn->state = BT_CLOSED; + break; + } + + return r; +} diff --git a/net/bluetooth/hci_core.c b/net/bluetooth/hci_core.c new file mode 100644 index 000000000..6a1db678d --- /dev/null +++ b/net/bluetooth/hci_core.c @@ -0,0 +1,4158 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + Copyright (C) 2011 ProFUSION Embedded Systems + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth HCI core. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "hci_request.h" +#include "hci_debugfs.h" +#include "smp.h" +#include "leds.h" +#include "msft.h" +#include "aosp.h" +#include "hci_codec.h" + +static void hci_rx_work(struct work_struct *work); +static void hci_cmd_work(struct work_struct *work); +static void hci_tx_work(struct work_struct *work); + +/* HCI device list */ +LIST_HEAD(hci_dev_list); +DEFINE_RWLOCK(hci_dev_list_lock); + +/* HCI callback list */ +LIST_HEAD(hci_cb_list); +DEFINE_MUTEX(hci_cb_list_lock); + +/* HCI ID Numbering */ +static DEFINE_IDA(hci_index_ida); + +static int hci_scan_req(struct hci_request *req, unsigned long opt) +{ + __u8 scan = opt; + + BT_DBG("%s %x", req->hdev->name, scan); + + /* Inquiry and Page scans */ + hci_req_add(req, HCI_OP_WRITE_SCAN_ENABLE, 1, &scan); + return 0; +} + +static int hci_auth_req(struct hci_request *req, unsigned long opt) +{ + __u8 auth = opt; + + BT_DBG("%s %x", req->hdev->name, auth); + + /* Authentication */ + hci_req_add(req, HCI_OP_WRITE_AUTH_ENABLE, 1, &auth); + return 0; +} + +static int hci_encrypt_req(struct hci_request *req, unsigned long opt) +{ + __u8 encrypt = opt; + + BT_DBG("%s %x", req->hdev->name, encrypt); + + /* Encryption */ + hci_req_add(req, HCI_OP_WRITE_ENCRYPT_MODE, 1, &encrypt); + return 0; +} + +static int hci_linkpol_req(struct hci_request *req, unsigned long opt) +{ + __le16 policy = cpu_to_le16(opt); + + BT_DBG("%s %x", req->hdev->name, policy); + + /* Default link policy */ + hci_req_add(req, HCI_OP_WRITE_DEF_LINK_POLICY, 2, &policy); + return 0; +} + +/* Get HCI device by index. + * Device is held on return. 
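+ * The caller owns the returned reference and must release it with + * hci_dev_put().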
*/ +struct hci_dev *hci_dev_get(int index) +{ + struct hci_dev *hdev = NULL, *d; + + BT_DBG("%d", index); + + if (index < 0) + return NULL; + + read_lock(&hci_dev_list_lock); + list_for_each_entry(d, &hci_dev_list, list) { + if (d->id == index) { + hdev = hci_dev_hold(d); + break; + } + } + read_unlock(&hci_dev_list_lock); + return hdev; +} + +/* ---- Inquiry support ---- */ + +bool hci_discovery_active(struct hci_dev *hdev) +{ + struct discovery_state *discov = &hdev->discovery; + + switch (discov->state) { + case DISCOVERY_FINDING: + case DISCOVERY_RESOLVING: + return true; + + default: + return false; + } +} + +void hci_discovery_set_state(struct hci_dev *hdev, int state) +{ + int old_state = hdev->discovery.state; + + BT_DBG("%s state %u -> %u", hdev->name, hdev->discovery.state, state); + + if (old_state == state) + return; + + hdev->discovery.state = state; + + switch (state) { + case DISCOVERY_STOPPED: + hci_update_passive_scan(hdev); + + if (old_state != DISCOVERY_STARTING) + mgmt_discovering(hdev, 0); + break; + case DISCOVERY_STARTING: + break; + case DISCOVERY_FINDING: + mgmt_discovering(hdev, 1); + break; + case DISCOVERY_RESOLVING: + break; + case DISCOVERY_STOPPING: + break; + } +} + +void hci_inquiry_cache_flush(struct hci_dev *hdev) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *p, *n; + + list_for_each_entry_safe(p, n, &cache->all, all) { + list_del(&p->all); + kfree(p); + } + + INIT_LIST_HEAD(&cache->unknown); + INIT_LIST_HEAD(&cache->resolve); +} + +struct inquiry_entry *hci_inquiry_cache_lookup(struct hci_dev *hdev, + bdaddr_t *bdaddr) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *e; + + BT_DBG("cache %p, %pMR", cache, bdaddr); + + list_for_each_entry(e, &cache->all, all) { + if (!bacmp(&e->data.bdaddr, bdaddr)) + return e; + } + + return NULL; +} + +struct inquiry_entry *hci_inquiry_cache_lookup_unknown(struct hci_dev *hdev, + bdaddr_t *bdaddr) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *e; + + BT_DBG("cache %p, %pMR", cache, bdaddr); + + list_for_each_entry(e, &cache->unknown, list) { + if (!bacmp(&e->data.bdaddr, bdaddr)) + return e; + } + + return NULL; +} + +struct inquiry_entry *hci_inquiry_cache_lookup_resolve(struct hci_dev *hdev, + bdaddr_t *bdaddr, + int state) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *e; + + BT_DBG("cache %p bdaddr %pMR state %d", cache, bdaddr, state); + + list_for_each_entry(e, &cache->resolve, list) { + if (!bacmp(bdaddr, BDADDR_ANY) && e->name_state == state) + return e; + if (!bacmp(&e->data.bdaddr, bdaddr)) + return e; + } + + return NULL; +} + +void hci_inquiry_cache_update_resolve(struct hci_dev *hdev, + struct inquiry_entry *ie) +{ + struct discovery_state *cache = &hdev->discovery; + struct list_head *pos = &cache->resolve; + struct inquiry_entry *p; + + list_del(&ie->list); + + list_for_each_entry(p, &cache->resolve, list) { + if (p->name_state != NAME_PENDING && + abs(p->data.rssi) >= abs(ie->data.rssi)) + break; + pos = &p->list; + } + + list_add(&ie->list, pos); +} + +u32 hci_inquiry_cache_update(struct hci_dev *hdev, struct inquiry_data *data, + bool name_known) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *ie; + u32 flags = 0; + + BT_DBG("cache %p, %pMR", cache, &data->bdaddr); + + hci_remove_remote_oob_data(hdev, &data->bdaddr, BDADDR_BREDR); + + if (!data->ssp_mode) + flags |= MGMT_DEV_FOUND_LEGACY_PAIRING; + + ie = hci_inquiry_cache_lookup(hdev, 
&data->bdaddr); + if (ie) { + if (!ie->data.ssp_mode) + flags |= MGMT_DEV_FOUND_LEGACY_PAIRING; + + if (ie->name_state == NAME_NEEDED && + data->rssi != ie->data.rssi) { + ie->data.rssi = data->rssi; + hci_inquiry_cache_update_resolve(hdev, ie); + } + + goto update; + } + + /* Entry not in the cache. Add new one. */ + ie = kzalloc(sizeof(*ie), GFP_KERNEL); + if (!ie) { + flags |= MGMT_DEV_FOUND_CONFIRM_NAME; + goto done; + } + + list_add(&ie->all, &cache->all); + + if (name_known) { + ie->name_state = NAME_KNOWN; + } else { + ie->name_state = NAME_NOT_KNOWN; + list_add(&ie->list, &cache->unknown); + } + +update: + if (name_known && ie->name_state != NAME_KNOWN && + ie->name_state != NAME_PENDING) { + ie->name_state = NAME_KNOWN; + list_del(&ie->list); + } + + memcpy(&ie->data, data, sizeof(*data)); + ie->timestamp = jiffies; + cache->timestamp = jiffies; + + if (ie->name_state == NAME_NOT_KNOWN) + flags |= MGMT_DEV_FOUND_CONFIRM_NAME; + +done: + return flags; +} + +static int inquiry_cache_dump(struct hci_dev *hdev, int num, __u8 *buf) +{ + struct discovery_state *cache = &hdev->discovery; + struct inquiry_info *info = (struct inquiry_info *) buf; + struct inquiry_entry *e; + int copied = 0; + + list_for_each_entry(e, &cache->all, all) { + struct inquiry_data *data = &e->data; + + if (copied >= num) + break; + + bacpy(&info->bdaddr, &data->bdaddr); + info->pscan_rep_mode = data->pscan_rep_mode; + info->pscan_period_mode = data->pscan_period_mode; + info->pscan_mode = data->pscan_mode; + memcpy(info->dev_class, data->dev_class, 3); + info->clock_offset = data->clock_offset; + + info++; + copied++; + } + + BT_DBG("cache %p, copied %d", cache, copied); + return copied; +} + +static int hci_inq_req(struct hci_request *req, unsigned long opt) +{ + struct hci_inquiry_req *ir = (struct hci_inquiry_req *) opt; + struct hci_dev *hdev = req->hdev; + struct hci_cp_inquiry cp; + + BT_DBG("%s", hdev->name); + + if (test_bit(HCI_INQUIRY, &hdev->flags)) + return 0; + + /* Start Inquiry */ + memcpy(&cp.lap, &ir->lap, 3); + cp.length = ir->length; + cp.num_rsp = ir->num_rsp; + hci_req_add(req, HCI_OP_INQUIRY, sizeof(cp), &cp); + + return 0; +} + +int hci_inquiry(void __user *arg) +{ + __u8 __user *ptr = arg; + struct hci_inquiry_req ir; + struct hci_dev *hdev; + int err = 0, do_inquiry = 0, max_rsp; + long timeo; + __u8 *buf; + + if (copy_from_user(&ir, ptr, sizeof(ir))) + return -EFAULT; + + hdev = hci_dev_get(ir.dev_id); + if (!hdev) + return -ENODEV; + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = -EBUSY; + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + err = -EOPNOTSUPP; + goto done; + } + + if (hdev->dev_type != HCI_PRIMARY) { + err = -EOPNOTSUPP; + goto done; + } + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + err = -EOPNOTSUPP; + goto done; + } + + /* Restrict maximum inquiry length to 60 seconds */ + if (ir.length > 60) { + err = -EINVAL; + goto done; + } + + hci_dev_lock(hdev); + if (inquiry_cache_age(hdev) > INQUIRY_CACHE_AGE_MAX || + inquiry_cache_empty(hdev) || ir.flags & IREQ_CACHE_FLUSH) { + hci_inquiry_cache_flush(hdev); + do_inquiry = 1; + } + hci_dev_unlock(hdev); + + timeo = ir.length * msecs_to_jiffies(2000); + + if (do_inquiry) { + err = hci_req_sync(hdev, hci_inq_req, (unsigned long) &ir, + timeo, NULL); + if (err < 0) + goto done; + + /* Wait until Inquiry procedure finishes (HCI_INQUIRY flag is + * cleared). If it is interrupted by a signal, return -EINTR. 
+ */ + if (wait_on_bit(&hdev->flags, HCI_INQUIRY, + TASK_INTERRUPTIBLE)) { + err = -EINTR; + goto done; + } + } + + /* for unlimited number of responses we will use buffer with + * 255 entries + */ + max_rsp = (ir.num_rsp == 0) ? 255 : ir.num_rsp; + + /* cache_dump can't sleep. Therefore we allocate temp buffer and then + * copy it to the user space. + */ + buf = kmalloc_array(max_rsp, sizeof(struct inquiry_info), GFP_KERNEL); + if (!buf) { + err = -ENOMEM; + goto done; + } + + hci_dev_lock(hdev); + ir.num_rsp = inquiry_cache_dump(hdev, max_rsp, buf); + hci_dev_unlock(hdev); + + BT_DBG("num_rsp %d", ir.num_rsp); + + if (!copy_to_user(ptr, &ir, sizeof(ir))) { + ptr += sizeof(ir); + if (copy_to_user(ptr, buf, sizeof(struct inquiry_info) * + ir.num_rsp)) + err = -EFAULT; + } else + err = -EFAULT; + + kfree(buf); + +done: + hci_dev_put(hdev); + return err; +} + +static int hci_dev_do_open(struct hci_dev *hdev) +{ + int ret = 0; + + BT_DBG("%s %p", hdev->name, hdev); + + hci_req_sync_lock(hdev); + + ret = hci_dev_open_sync(hdev); + + hci_req_sync_unlock(hdev); + return ret; +} + +/* ---- HCI ioctl helpers ---- */ + +int hci_dev_open(__u16 dev) +{ + struct hci_dev *hdev; + int err; + + hdev = hci_dev_get(dev); + if (!hdev) + return -ENODEV; + + /* Devices that are marked as unconfigured can only be powered + * up as user channel. Trying to bring them up as normal devices + * will result into a failure. Only user channel operation is + * possible. + * + * When this function is called for a user channel, the flag + * HCI_USER_CHANNEL will be set first before attempting to + * open the device. + */ + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED) && + !hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = -EOPNOTSUPP; + goto done; + } + + /* We need to ensure that no other power on/off work is pending + * before proceeding to call hci_dev_do_open. This is + * particularly important if the setup procedure has not yet + * completed. + */ + if (hci_dev_test_and_clear_flag(hdev, HCI_AUTO_OFF)) + cancel_delayed_work(&hdev->power_off); + + /* After this call it is guaranteed that the setup procedure + * has finished. This means that error conditions like RFKILL + * or no valid public or static random address apply. + */ + flush_workqueue(hdev->req_workqueue); + + /* For controllers not using the management interface and that + * are brought up using legacy ioctl, set the HCI_BONDABLE bit + * so that pairing works for them. Once the management interface + * is in use this bit will be cleared again and userspace has + * to explicitly enable it. 
+ */ + if (!hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + !hci_dev_test_flag(hdev, HCI_MGMT)) + hci_dev_set_flag(hdev, HCI_BONDABLE); + + err = hci_dev_do_open(hdev); + +done: + hci_dev_put(hdev); + return err; +} + +int hci_dev_do_close(struct hci_dev *hdev) +{ + int err; + + BT_DBG("%s %p", hdev->name, hdev); + + hci_req_sync_lock(hdev); + + err = hci_dev_close_sync(hdev); + + hci_req_sync_unlock(hdev); + + return err; +} + +int hci_dev_close(__u16 dev) +{ + struct hci_dev *hdev; + int err; + + hdev = hci_dev_get(dev); + if (!hdev) + return -ENODEV; + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = -EBUSY; + goto done; + } + + cancel_work_sync(&hdev->power_on); + if (hci_dev_test_and_clear_flag(hdev, HCI_AUTO_OFF)) + cancel_delayed_work(&hdev->power_off); + + err = hci_dev_do_close(hdev); + +done: + hci_dev_put(hdev); + return err; +} + +static int hci_dev_do_reset(struct hci_dev *hdev) +{ + int ret; + + BT_DBG("%s %p", hdev->name, hdev); + + hci_req_sync_lock(hdev); + + /* Drop queues */ + skb_queue_purge(&hdev->rx_q); + skb_queue_purge(&hdev->cmd_q); + + /* Cancel these to avoid queueing non-chained pending work */ + hci_dev_set_flag(hdev, HCI_CMD_DRAIN_WORKQUEUE); + /* Wait for + * + * if (!hci_dev_test_flag(hdev, HCI_CMD_DRAIN_WORKQUEUE)) + * queue_delayed_work(&hdev->{cmd,ncmd}_timer) + * + * inside RCU section to see the flag or complete scheduling. + */ + synchronize_rcu(); + /* Explicitly cancel works in case scheduled after setting the flag. */ + cancel_delayed_work(&hdev->cmd_timer); + cancel_delayed_work(&hdev->ncmd_timer); + + /* Avoid potential lockdep warnings from the *_flush() calls by + * ensuring the workqueue is empty up front. + */ + drain_workqueue(hdev->workqueue); + + hci_dev_lock(hdev); + hci_inquiry_cache_flush(hdev); + hci_conn_hash_flush(hdev); + hci_dev_unlock(hdev); + + if (hdev->flush) + hdev->flush(hdev); + + hci_dev_clear_flag(hdev, HCI_CMD_DRAIN_WORKQUEUE); + + atomic_set(&hdev->cmd_cnt, 1); + hdev->acl_cnt = 0; + hdev->sco_cnt = 0; + hdev->le_cnt = 0; + hdev->iso_cnt = 0; + + ret = hci_reset_sync(hdev); + + hci_req_sync_unlock(hdev); + return ret; +} + +int hci_dev_reset(__u16 dev) +{ + struct hci_dev *hdev; + int err; + + hdev = hci_dev_get(dev); + if (!hdev) + return -ENODEV; + + if (!test_bit(HCI_UP, &hdev->flags)) { + err = -ENETDOWN; + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = -EBUSY; + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + err = -EOPNOTSUPP; + goto done; + } + + err = hci_dev_do_reset(hdev); + +done: + hci_dev_put(hdev); + return err; +} + +int hci_dev_reset_stat(__u16 dev) +{ + struct hci_dev *hdev; + int ret = 0; + + hdev = hci_dev_get(dev); + if (!hdev) + return -ENODEV; + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + ret = -EBUSY; + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + ret = -EOPNOTSUPP; + goto done; + } + + memset(&hdev->stat, 0, sizeof(struct hci_dev_stats)); + +done: + hci_dev_put(hdev); + return ret; +} + +static void hci_update_passive_scan_state(struct hci_dev *hdev, u8 scan) +{ + bool conn_changed, discov_changed; + + BT_DBG("%s scan 0x%02x", hdev->name, scan); + + if ((scan & SCAN_PAGE)) + conn_changed = !hci_dev_test_and_set_flag(hdev, + HCI_CONNECTABLE); + else + conn_changed = hci_dev_test_and_clear_flag(hdev, + HCI_CONNECTABLE); + + if ((scan & SCAN_INQUIRY)) { + discov_changed = !hci_dev_test_and_set_flag(hdev, + HCI_DISCOVERABLE); + } else { + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + discov_changed = 
hci_dev_test_and_clear_flag(hdev, + HCI_DISCOVERABLE); + } + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + return; + + if (conn_changed || discov_changed) { + /* In case this was disabled through mgmt */ + hci_dev_set_flag(hdev, HCI_BREDR_ENABLED); + + if (hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + hci_update_adv_data(hdev, hdev->cur_adv_instance); + + mgmt_new_settings(hdev); + } +} + +int hci_dev_cmd(unsigned int cmd, void __user *arg) +{ + struct hci_dev *hdev; + struct hci_dev_req dr; + int err = 0; + + if (copy_from_user(&dr, arg, sizeof(dr))) + return -EFAULT; + + hdev = hci_dev_get(dr.dev_id); + if (!hdev) + return -ENODEV; + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = -EBUSY; + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + err = -EOPNOTSUPP; + goto done; + } + + if (hdev->dev_type != HCI_PRIMARY) { + err = -EOPNOTSUPP; + goto done; + } + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + err = -EOPNOTSUPP; + goto done; + } + + switch (cmd) { + case HCISETAUTH: + err = hci_req_sync(hdev, hci_auth_req, dr.dev_opt, + HCI_INIT_TIMEOUT, NULL); + break; + + case HCISETENCRYPT: + if (!lmp_encrypt_capable(hdev)) { + err = -EOPNOTSUPP; + break; + } + + if (!test_bit(HCI_AUTH, &hdev->flags)) { + /* Auth must be enabled first */ + err = hci_req_sync(hdev, hci_auth_req, dr.dev_opt, + HCI_INIT_TIMEOUT, NULL); + if (err) + break; + } + + err = hci_req_sync(hdev, hci_encrypt_req, dr.dev_opt, + HCI_INIT_TIMEOUT, NULL); + break; + + case HCISETSCAN: + err = hci_req_sync(hdev, hci_scan_req, dr.dev_opt, + HCI_INIT_TIMEOUT, NULL); + + /* Ensure that the connectable and discoverable states + * get correctly modified as this was a non-mgmt change. + */ + if (!err) + hci_update_passive_scan_state(hdev, dr.dev_opt); + break; + + case HCISETLINKPOL: + err = hci_req_sync(hdev, hci_linkpol_req, dr.dev_opt, + HCI_INIT_TIMEOUT, NULL); + break; + + case HCISETLINKMODE: + hdev->link_mode = ((__u16) dr.dev_opt) & + (HCI_LM_MASTER | HCI_LM_ACCEPT); + break; + + case HCISETPTYPE: + if (hdev->pkt_type == (__u16) dr.dev_opt) + break; + + hdev->pkt_type = (__u16) dr.dev_opt; + mgmt_phy_configuration_changed(hdev, NULL); + break; + + case HCISETACLMTU: + hdev->acl_mtu = *((__u16 *) &dr.dev_opt + 1); + hdev->acl_pkts = *((__u16 *) &dr.dev_opt + 0); + break; + + case HCISETSCOMTU: + hdev->sco_mtu = *((__u16 *) &dr.dev_opt + 1); + hdev->sco_pkts = *((__u16 *) &dr.dev_opt + 0); + break; + + default: + err = -EINVAL; + break; + } + +done: + hci_dev_put(hdev); + return err; +} + +int hci_get_dev_list(void __user *arg) +{ + struct hci_dev *hdev; + struct hci_dev_list_req *dl; + struct hci_dev_req *dr; + int n = 0, size, err; + __u16 dev_num; + + if (get_user(dev_num, (__u16 __user *) arg)) + return -EFAULT; + + if (!dev_num || dev_num > (PAGE_SIZE * 2) / sizeof(*dr)) + return -EINVAL; + + size = sizeof(*dl) + dev_num * sizeof(*dr); + + dl = kzalloc(size, GFP_KERNEL); + if (!dl) + return -ENOMEM; + + dr = dl->dev_req; + + read_lock(&hci_dev_list_lock); + list_for_each_entry(hdev, &hci_dev_list, list) { + unsigned long flags = hdev->flags; + + /* When the auto-off is configured it means the transport + * is running, but in that case still indicate that the + * device is actually down. 
+ */ + if (hci_dev_test_flag(hdev, HCI_AUTO_OFF)) + flags &= ~BIT(HCI_UP); + + (dr + n)->dev_id = hdev->id; + (dr + n)->dev_opt = flags; + + if (++n >= dev_num) + break; + } + read_unlock(&hci_dev_list_lock); + + dl->dev_num = n; + size = sizeof(*dl) + n * sizeof(*dr); + + err = copy_to_user(arg, dl, size); + kfree(dl); + + return err ? -EFAULT : 0; +} + +int hci_get_dev_info(void __user *arg) +{ + struct hci_dev *hdev; + struct hci_dev_info di; + unsigned long flags; + int err = 0; + + if (copy_from_user(&di, arg, sizeof(di))) + return -EFAULT; + + hdev = hci_dev_get(di.dev_id); + if (!hdev) + return -ENODEV; + + /* When the auto-off is configured it means the transport + * is running, but in that case still indicate that the + * device is actually down. + */ + if (hci_dev_test_flag(hdev, HCI_AUTO_OFF)) + flags = hdev->flags & ~BIT(HCI_UP); + else + flags = hdev->flags; + + strcpy(di.name, hdev->name); + di.bdaddr = hdev->bdaddr; + di.type = (hdev->bus & 0x0f) | ((hdev->dev_type & 0x03) << 4); + di.flags = flags; + di.pkt_type = hdev->pkt_type; + if (lmp_bredr_capable(hdev)) { + di.acl_mtu = hdev->acl_mtu; + di.acl_pkts = hdev->acl_pkts; + di.sco_mtu = hdev->sco_mtu; + di.sco_pkts = hdev->sco_pkts; + } else { + di.acl_mtu = hdev->le_mtu; + di.acl_pkts = hdev->le_pkts; + di.sco_mtu = 0; + di.sco_pkts = 0; + } + di.link_policy = hdev->link_policy; + di.link_mode = hdev->link_mode; + + memcpy(&di.stat, &hdev->stat, sizeof(di.stat)); + memcpy(&di.features, &hdev->features, sizeof(di.features)); + + if (copy_to_user(arg, &di, sizeof(di))) + err = -EFAULT; + + hci_dev_put(hdev); + + return err; +} + +/* ---- Interface to HCI drivers ---- */ + +static int hci_rfkill_set_block(void *data, bool blocked) +{ + struct hci_dev *hdev = data; + + BT_DBG("%p name %s blocked %d", hdev, hdev->name, blocked); + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) + return -EBUSY; + + if (blocked) { + hci_dev_set_flag(hdev, HCI_RFKILLED); + if (!hci_dev_test_flag(hdev, HCI_SETUP) && + !hci_dev_test_flag(hdev, HCI_CONFIG)) + hci_dev_do_close(hdev); + } else { + hci_dev_clear_flag(hdev, HCI_RFKILLED); + } + + return 0; +} + +static const struct rfkill_ops hci_rfkill_ops = { + .set_block = hci_rfkill_set_block, +}; + +static void hci_power_on(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, power_on); + int err; + + BT_DBG("%s", hdev->name); + + if (test_bit(HCI_UP, &hdev->flags) && + hci_dev_test_flag(hdev, HCI_MGMT) && + hci_dev_test_and_clear_flag(hdev, HCI_AUTO_OFF)) { + cancel_delayed_work(&hdev->power_off); + err = hci_powered_update_sync(hdev); + mgmt_power_on(hdev, err); + return; + } + + err = hci_dev_do_open(hdev); + if (err < 0) { + hci_dev_lock(hdev); + mgmt_set_powered_failed(hdev, err); + hci_dev_unlock(hdev); + return; + } + + /* During the HCI setup phase, a few error conditions are + * ignored and they need to be checked now. If they are still + * valid, it is important to turn the device back off. 
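+ * Examples are rfkill being active or the controller lacking a valid + * public or static random address, as checked below.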
+ */ + if (hci_dev_test_flag(hdev, HCI_RFKILLED) || + hci_dev_test_flag(hdev, HCI_UNCONFIGURED) || + (hdev->dev_type == HCI_PRIMARY && + !bacmp(&hdev->bdaddr, BDADDR_ANY) && + !bacmp(&hdev->static_addr, BDADDR_ANY))) { + hci_dev_clear_flag(hdev, HCI_AUTO_OFF); + hci_dev_do_close(hdev); + } else if (hci_dev_test_flag(hdev, HCI_AUTO_OFF)) { + queue_delayed_work(hdev->req_workqueue, &hdev->power_off, + HCI_AUTO_OFF_TIMEOUT); + } + + if (hci_dev_test_and_clear_flag(hdev, HCI_SETUP)) { + /* For unconfigured devices, set the HCI_RAW flag + * so that userspace can easily identify them. + */ + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + set_bit(HCI_RAW, &hdev->flags); + + /* For fully configured devices, this will send + * the Index Added event. For unconfigured devices, + * it will send Unconfigued Index Added event. + * + * Devices with HCI_QUIRK_RAW_DEVICE are ignored + * and no event will be send. + */ + mgmt_index_added(hdev); + } else if (hci_dev_test_and_clear_flag(hdev, HCI_CONFIG)) { + /* When the controller is now configured, then it + * is important to clear the HCI_RAW flag. + */ + if (!hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + clear_bit(HCI_RAW, &hdev->flags); + + /* Powering on the controller with HCI_CONFIG set only + * happens with the transition from unconfigured to + * configured. This will send the Index Added event. + */ + mgmt_index_added(hdev); + } +} + +static void hci_power_off(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + power_off.work); + + BT_DBG("%s", hdev->name); + + hci_dev_do_close(hdev); +} + +static void hci_error_reset(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, error_reset); + + BT_DBG("%s", hdev->name); + + if (hdev->hw_error) + hdev->hw_error(hdev, hdev->hw_error_code); + else + bt_dev_err(hdev, "hardware error 0x%2.2x", hdev->hw_error_code); + + if (hci_dev_do_close(hdev)) + return; + + hci_dev_do_open(hdev); +} + +void hci_uuids_clear(struct hci_dev *hdev) +{ + struct bt_uuid *uuid, *tmp; + + list_for_each_entry_safe(uuid, tmp, &hdev->uuids, list) { + list_del(&uuid->list); + kfree(uuid); + } +} + +void hci_link_keys_clear(struct hci_dev *hdev) +{ + struct link_key *key, *tmp; + + list_for_each_entry_safe(key, tmp, &hdev->link_keys, list) { + list_del_rcu(&key->list); + kfree_rcu(key, rcu); + } +} + +void hci_smp_ltks_clear(struct hci_dev *hdev) +{ + struct smp_ltk *k, *tmp; + + list_for_each_entry_safe(k, tmp, &hdev->long_term_keys, list) { + list_del_rcu(&k->list); + kfree_rcu(k, rcu); + } +} + +void hci_smp_irks_clear(struct hci_dev *hdev) +{ + struct smp_irk *k, *tmp; + + list_for_each_entry_safe(k, tmp, &hdev->identity_resolving_keys, list) { + list_del_rcu(&k->list); + kfree_rcu(k, rcu); + } +} + +void hci_blocked_keys_clear(struct hci_dev *hdev) +{ + struct blocked_key *b, *tmp; + + list_for_each_entry_safe(b, tmp, &hdev->blocked_keys, list) { + list_del_rcu(&b->list); + kfree_rcu(b, rcu); + } +} + +bool hci_is_blocked_key(struct hci_dev *hdev, u8 type, u8 val[16]) +{ + bool blocked = false; + struct blocked_key *b; + + rcu_read_lock(); + list_for_each_entry_rcu(b, &hdev->blocked_keys, list) { + if (b->type == type && !memcmp(b->val, val, sizeof(b->val))) { + blocked = true; + break; + } + } + + rcu_read_unlock(); + return blocked; +} + +struct link_key *hci_find_link_key(struct hci_dev *hdev, bdaddr_t *bdaddr) +{ + struct link_key *k; + + rcu_read_lock(); + list_for_each_entry_rcu(k, &hdev->link_keys, list) { + if (bacmp(bdaddr, &k->bdaddr) == 0) { + 
rcu_read_unlock(); + + if (hci_is_blocked_key(hdev, + HCI_BLOCKED_KEY_TYPE_LINKKEY, + k->val)) { + bt_dev_warn_ratelimited(hdev, + "Link key blocked for %pMR", + &k->bdaddr); + return NULL; + } + + return k; + } + } + rcu_read_unlock(); + + return NULL; +} + +static bool hci_persistent_key(struct hci_dev *hdev, struct hci_conn *conn, + u8 key_type, u8 old_key_type) +{ + /* Legacy key */ + if (key_type < 0x03) + return true; + + /* Debug keys are insecure so don't store them persistently */ + if (key_type == HCI_LK_DEBUG_COMBINATION) + return false; + + /* Changed combination key and there's no previous one */ + if (key_type == HCI_LK_CHANGED_COMBINATION && old_key_type == 0xff) + return false; + + /* Security mode 3 case */ + if (!conn) + return true; + + /* BR/EDR key derived using SC from an LE link */ + if (conn->type == LE_LINK) + return true; + + /* Neither local nor remote side had no-bonding as requirement */ + if (conn->auth_type > 0x01 && conn->remote_auth > 0x01) + return true; + + /* Local side had dedicated bonding as requirement */ + if (conn->auth_type == 0x02 || conn->auth_type == 0x03) + return true; + + /* Remote side had dedicated bonding as requirement */ + if (conn->remote_auth == 0x02 || conn->remote_auth == 0x03) + return true; + + /* If none of the above criteria match, then don't store the key + * persistently */ + return false; +} + +static u8 ltk_role(u8 type) +{ + if (type == SMP_LTK) + return HCI_ROLE_MASTER; + + return HCI_ROLE_SLAVE; +} + +struct smp_ltk *hci_find_ltk(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type, u8 role) +{ + struct smp_ltk *k; + + rcu_read_lock(); + list_for_each_entry_rcu(k, &hdev->long_term_keys, list) { + if (addr_type != k->bdaddr_type || bacmp(bdaddr, &k->bdaddr)) + continue; + + if (smp_ltk_is_sc(k) || ltk_role(k->type) == role) { + rcu_read_unlock(); + + if (hci_is_blocked_key(hdev, HCI_BLOCKED_KEY_TYPE_LTK, + k->val)) { + bt_dev_warn_ratelimited(hdev, + "LTK blocked for %pMR", + &k->bdaddr); + return NULL; + } + + return k; + } + } + rcu_read_unlock(); + + return NULL; +} + +struct smp_irk *hci_find_irk_by_rpa(struct hci_dev *hdev, bdaddr_t *rpa) +{ + struct smp_irk *irk_to_return = NULL; + struct smp_irk *irk; + + rcu_read_lock(); + list_for_each_entry_rcu(irk, &hdev->identity_resolving_keys, list) { + if (!bacmp(&irk->rpa, rpa)) { + irk_to_return = irk; + goto done; + } + } + + list_for_each_entry_rcu(irk, &hdev->identity_resolving_keys, list) { + if (smp_irk_matches(hdev, irk->val, rpa)) { + bacpy(&irk->rpa, rpa); + irk_to_return = irk; + goto done; + } + } + +done: + if (irk_to_return && hci_is_blocked_key(hdev, HCI_BLOCKED_KEY_TYPE_IRK, + irk_to_return->val)) { + bt_dev_warn_ratelimited(hdev, "Identity key blocked for %pMR", + &irk_to_return->bdaddr); + irk_to_return = NULL; + } + + rcu_read_unlock(); + + return irk_to_return; +} + +struct smp_irk *hci_find_irk_by_addr(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type) +{ + struct smp_irk *irk_to_return = NULL; + struct smp_irk *irk; + + /* Identity Address must be public or static random */ + if (addr_type == ADDR_LE_DEV_RANDOM && (bdaddr->b[5] & 0xc0) != 0xc0) + return NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(irk, &hdev->identity_resolving_keys, list) { + if (addr_type == irk->addr_type && + bacmp(bdaddr, &irk->bdaddr) == 0) { + irk_to_return = irk; + goto done; + } + } + +done: + + if (irk_to_return && hci_is_blocked_key(hdev, HCI_BLOCKED_KEY_TYPE_IRK, + irk_to_return->val)) { + bt_dev_warn_ratelimited(hdev, "Identity key blocked for %pMR", + 
&irk_to_return->bdaddr); + irk_to_return = NULL; + } + + rcu_read_unlock(); + + return irk_to_return; +} + +struct link_key *hci_add_link_key(struct hci_dev *hdev, struct hci_conn *conn, + bdaddr_t *bdaddr, u8 *val, u8 type, + u8 pin_len, bool *persistent) +{ + struct link_key *key, *old_key; + u8 old_key_type; + + old_key = hci_find_link_key(hdev, bdaddr); + if (old_key) { + old_key_type = old_key->type; + key = old_key; + } else { + old_key_type = conn ? conn->key_type : 0xff; + key = kzalloc(sizeof(*key), GFP_KERNEL); + if (!key) + return NULL; + list_add_rcu(&key->list, &hdev->link_keys); + } + + BT_DBG("%s key for %pMR type %u", hdev->name, bdaddr, type); + + /* Some buggy controller combinations generate a changed + * combination key for legacy pairing even when there's no + * previous key */ + if (type == HCI_LK_CHANGED_COMBINATION && + (!conn || conn->remote_auth == 0xff) && old_key_type == 0xff) { + type = HCI_LK_COMBINATION; + if (conn) + conn->key_type = type; + } + + bacpy(&key->bdaddr, bdaddr); + memcpy(key->val, val, HCI_LINK_KEY_SIZE); + key->pin_len = pin_len; + + if (type == HCI_LK_CHANGED_COMBINATION) + key->type = old_key_type; + else + key->type = type; + + if (persistent) + *persistent = hci_persistent_key(hdev, conn, type, + old_key_type); + + return key; +} + +struct smp_ltk *hci_add_ltk(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type, u8 type, u8 authenticated, + u8 tk[16], u8 enc_size, __le16 ediv, __le64 rand) +{ + struct smp_ltk *key, *old_key; + u8 role = ltk_role(type); + + old_key = hci_find_ltk(hdev, bdaddr, addr_type, role); + if (old_key) + key = old_key; + else { + key = kzalloc(sizeof(*key), GFP_KERNEL); + if (!key) + return NULL; + list_add_rcu(&key->list, &hdev->long_term_keys); + } + + bacpy(&key->bdaddr, bdaddr); + key->bdaddr_type = addr_type; + memcpy(key->val, tk, sizeof(key->val)); + key->authenticated = authenticated; + key->ediv = ediv; + key->rand = rand; + key->enc_size = enc_size; + key->type = type; + + return key; +} + +struct smp_irk *hci_add_irk(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type, u8 val[16], bdaddr_t *rpa) +{ + struct smp_irk *irk; + + irk = hci_find_irk_by_addr(hdev, bdaddr, addr_type); + if (!irk) { + irk = kzalloc(sizeof(*irk), GFP_KERNEL); + if (!irk) + return NULL; + + bacpy(&irk->bdaddr, bdaddr); + irk->addr_type = addr_type; + + list_add_rcu(&irk->list, &hdev->identity_resolving_keys); + } + + memcpy(irk->val, val, 16); + bacpy(&irk->rpa, rpa); + + return irk; +} + +int hci_remove_link_key(struct hci_dev *hdev, bdaddr_t *bdaddr) +{ + struct link_key *key; + + key = hci_find_link_key(hdev, bdaddr); + if (!key) + return -ENOENT; + + BT_DBG("%s removing %pMR", hdev->name, bdaddr); + + list_del_rcu(&key->list); + kfree_rcu(key, rcu); + + return 0; +} + +int hci_remove_ltk(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 bdaddr_type) +{ + struct smp_ltk *k, *tmp; + int removed = 0; + + list_for_each_entry_safe(k, tmp, &hdev->long_term_keys, list) { + if (bacmp(bdaddr, &k->bdaddr) || k->bdaddr_type != bdaddr_type) + continue; + + BT_DBG("%s removing %pMR", hdev->name, bdaddr); + + list_del_rcu(&k->list); + kfree_rcu(k, rcu); + removed++; + } + + return removed ? 
0 : -ENOENT; +} + +void hci_remove_irk(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 addr_type) +{ + struct smp_irk *k, *tmp; + + list_for_each_entry_safe(k, tmp, &hdev->identity_resolving_keys, list) { + if (bacmp(bdaddr, &k->bdaddr) || k->addr_type != addr_type) + continue; + + BT_DBG("%s removing %pMR", hdev->name, bdaddr); + + list_del_rcu(&k->list); + kfree_rcu(k, rcu); + } +} + +bool hci_bdaddr_is_paired(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 type) +{ + struct smp_ltk *k; + struct smp_irk *irk; + u8 addr_type; + + if (type == BDADDR_BREDR) { + if (hci_find_link_key(hdev, bdaddr)) + return true; + return false; + } + + /* Convert to HCI addr type which struct smp_ltk uses */ + if (type == BDADDR_LE_PUBLIC) + addr_type = ADDR_LE_DEV_PUBLIC; + else + addr_type = ADDR_LE_DEV_RANDOM; + + irk = hci_get_irk(hdev, bdaddr, addr_type); + if (irk) { + bdaddr = &irk->bdaddr; + addr_type = irk->addr_type; + } + + rcu_read_lock(); + list_for_each_entry_rcu(k, &hdev->long_term_keys, list) { + if (k->bdaddr_type == addr_type && !bacmp(bdaddr, &k->bdaddr)) { + rcu_read_unlock(); + return true; + } + } + rcu_read_unlock(); + + return false; +} + +/* HCI command timer function */ +static void hci_cmd_timeout(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + cmd_timer.work); + + if (hdev->sent_cmd) { + struct hci_command_hdr *sent = (void *) hdev->sent_cmd->data; + u16 opcode = __le16_to_cpu(sent->opcode); + + bt_dev_err(hdev, "command 0x%4.4x tx timeout", opcode); + } else { + bt_dev_err(hdev, "command tx timeout"); + } + + if (hdev->cmd_timeout) + hdev->cmd_timeout(hdev); + + atomic_set(&hdev->cmd_cnt, 1); + queue_work(hdev->workqueue, &hdev->cmd_work); +} + +/* HCI ncmd timer function */ +static void hci_ncmd_timeout(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + ncmd_timer.work); + + bt_dev_err(hdev, "Controller not accepting commands anymore: ncmd = 0"); + + /* During HCI_INIT phase no events can be injected if the ncmd timer + * triggers since the procedure has its own timeout handling. 
+ */ + if (test_bit(HCI_INIT, &hdev->flags)) + return; + + /* This is an irrecoverable state, inject hardware error event */ + hci_reset_dev(hdev); +} + +struct oob_data *hci_find_remote_oob_data(struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 bdaddr_type) +{ + struct oob_data *data; + + list_for_each_entry(data, &hdev->remote_oob_data, list) { + if (bacmp(bdaddr, &data->bdaddr) != 0) + continue; + if (data->bdaddr_type != bdaddr_type) + continue; + return data; + } + + return NULL; +} + +int hci_remove_remote_oob_data(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 bdaddr_type) +{ + struct oob_data *data; + + data = hci_find_remote_oob_data(hdev, bdaddr, bdaddr_type); + if (!data) + return -ENOENT; + + BT_DBG("%s removing %pMR (%u)", hdev->name, bdaddr, bdaddr_type); + + list_del(&data->list); + kfree(data); + + return 0; +} + +void hci_remote_oob_data_clear(struct hci_dev *hdev) +{ + struct oob_data *data, *n; + + list_for_each_entry_safe(data, n, &hdev->remote_oob_data, list) { + list_del(&data->list); + kfree(data); + } +} + +int hci_add_remote_oob_data(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 bdaddr_type, u8 *hash192, u8 *rand192, + u8 *hash256, u8 *rand256) +{ + struct oob_data *data; + + data = hci_find_remote_oob_data(hdev, bdaddr, bdaddr_type); + if (!data) { + data = kmalloc(sizeof(*data), GFP_KERNEL); + if (!data) + return -ENOMEM; + + bacpy(&data->bdaddr, bdaddr); + data->bdaddr_type = bdaddr_type; + list_add(&data->list, &hdev->remote_oob_data); + } + + if (hash192 && rand192) { + memcpy(data->hash192, hash192, sizeof(data->hash192)); + memcpy(data->rand192, rand192, sizeof(data->rand192)); + if (hash256 && rand256) + data->present = 0x03; + } else { + memset(data->hash192, 0, sizeof(data->hash192)); + memset(data->rand192, 0, sizeof(data->rand192)); + if (hash256 && rand256) + data->present = 0x02; + else + data->present = 0x00; + } + + if (hash256 && rand256) { + memcpy(data->hash256, hash256, sizeof(data->hash256)); + memcpy(data->rand256, rand256, sizeof(data->rand256)); + } else { + memset(data->hash256, 0, sizeof(data->hash256)); + memset(data->rand256, 0, sizeof(data->rand256)); + if (hash192 && rand192) + data->present = 0x01; + } + + BT_DBG("%s for %pMR", hdev->name, bdaddr); + + return 0; +} + +/* This function requires the caller holds hdev->lock */ +struct adv_info *hci_find_adv_instance(struct hci_dev *hdev, u8 instance) +{ + struct adv_info *adv_instance; + + list_for_each_entry(adv_instance, &hdev->adv_instances, list) { + if (adv_instance->instance == instance) + return adv_instance; + } + + return NULL; +} + +/* This function requires the caller holds hdev->lock */ +struct adv_info *hci_get_next_instance(struct hci_dev *hdev, u8 instance) +{ + struct adv_info *cur_instance; + + cur_instance = hci_find_adv_instance(hdev, instance); + if (!cur_instance) + return NULL; + + if (cur_instance == list_last_entry(&hdev->adv_instances, + struct adv_info, list)) + return list_first_entry(&hdev->adv_instances, + struct adv_info, list); + else + return list_next_entry(cur_instance, list); +} + +/* This function requires the caller holds hdev->lock */ +int hci_remove_adv_instance(struct hci_dev *hdev, u8 instance) +{ + struct adv_info *adv_instance; + + adv_instance = hci_find_adv_instance(hdev, instance); + if (!adv_instance) + return -ENOENT; + + BT_DBG("%s removing %dMR", hdev->name, instance); + + if (hdev->cur_adv_instance == instance) { + if (hdev->adv_instance_timeout) { + cancel_delayed_work(&hdev->adv_instance_expire); + hdev->adv_instance_timeout = 0; + } + 
hdev->cur_adv_instance = 0x00; + } + + cancel_delayed_work_sync(&adv_instance->rpa_expired_cb); + + list_del(&adv_instance->list); + kfree(adv_instance); + + hdev->adv_instance_cnt--; + + return 0; +} + +void hci_adv_instances_set_rpa_expired(struct hci_dev *hdev, bool rpa_expired) +{ + struct adv_info *adv_instance, *n; + + list_for_each_entry_safe(adv_instance, n, &hdev->adv_instances, list) + adv_instance->rpa_expired = rpa_expired; +} + +/* This function requires the caller holds hdev->lock */ +void hci_adv_instances_clear(struct hci_dev *hdev) +{ + struct adv_info *adv_instance, *n; + + if (hdev->adv_instance_timeout) { + cancel_delayed_work(&hdev->adv_instance_expire); + hdev->adv_instance_timeout = 0; + } + + list_for_each_entry_safe(adv_instance, n, &hdev->adv_instances, list) { + cancel_delayed_work_sync(&adv_instance->rpa_expired_cb); + list_del(&adv_instance->list); + kfree(adv_instance); + } + + hdev->adv_instance_cnt = 0; + hdev->cur_adv_instance = 0x00; +} + +static void adv_instance_rpa_expired(struct work_struct *work) +{ + struct adv_info *adv_instance = container_of(work, struct adv_info, + rpa_expired_cb.work); + + BT_DBG(""); + + adv_instance->rpa_expired = true; +} + +/* This function requires the caller holds hdev->lock */ +struct adv_info *hci_add_adv_instance(struct hci_dev *hdev, u8 instance, + u32 flags, u16 adv_data_len, u8 *adv_data, + u16 scan_rsp_len, u8 *scan_rsp_data, + u16 timeout, u16 duration, s8 tx_power, + u32 min_interval, u32 max_interval, + u8 mesh_handle) +{ + struct adv_info *adv; + + adv = hci_find_adv_instance(hdev, instance); + if (adv) { + memset(adv->adv_data, 0, sizeof(adv->adv_data)); + memset(adv->scan_rsp_data, 0, sizeof(adv->scan_rsp_data)); + memset(adv->per_adv_data, 0, sizeof(adv->per_adv_data)); + } else { + if (hdev->adv_instance_cnt >= hdev->le_num_of_adv_sets || + instance < 1 || instance > hdev->le_num_of_adv_sets + 1) + return ERR_PTR(-EOVERFLOW); + + adv = kzalloc(sizeof(*adv), GFP_KERNEL); + if (!adv) + return ERR_PTR(-ENOMEM); + + adv->pending = true; + adv->instance = instance; + list_add(&adv->list, &hdev->adv_instances); + hdev->adv_instance_cnt++; + } + + adv->flags = flags; + adv->min_interval = min_interval; + adv->max_interval = max_interval; + adv->tx_power = tx_power; + /* Defining a mesh_handle changes the timing units to ms, + * rather than seconds, and ties the instance to the requested + * mesh_tx queue. 
+ */ + adv->mesh = mesh_handle; + + hci_set_adv_instance_data(hdev, instance, adv_data_len, adv_data, + scan_rsp_len, scan_rsp_data); + + adv->timeout = timeout; + adv->remaining_time = timeout; + + if (duration == 0) + adv->duration = hdev->def_multi_adv_rotation_duration; + else + adv->duration = duration; + + INIT_DELAYED_WORK(&adv->rpa_expired_cb, adv_instance_rpa_expired); + + BT_DBG("%s for %dMR", hdev->name, instance); + + return adv; +} + +/* This function requires the caller holds hdev->lock */ +struct adv_info *hci_add_per_instance(struct hci_dev *hdev, u8 instance, + u32 flags, u8 data_len, u8 *data, + u32 min_interval, u32 max_interval) +{ + struct adv_info *adv; + + adv = hci_add_adv_instance(hdev, instance, flags, 0, NULL, 0, NULL, + 0, 0, HCI_ADV_TX_POWER_NO_PREFERENCE, + min_interval, max_interval, 0); + if (IS_ERR(adv)) + return adv; + + adv->periodic = true; + adv->per_adv_data_len = data_len; + + if (data) + memcpy(adv->per_adv_data, data, data_len); + + return adv; +} + +/* This function requires the caller holds hdev->lock */ +int hci_set_adv_instance_data(struct hci_dev *hdev, u8 instance, + u16 adv_data_len, u8 *adv_data, + u16 scan_rsp_len, u8 *scan_rsp_data) +{ + struct adv_info *adv; + + adv = hci_find_adv_instance(hdev, instance); + + /* If advertisement doesn't exist, we can't modify its data */ + if (!adv) + return -ENOENT; + + if (adv_data_len && ADV_DATA_CMP(adv, adv_data, adv_data_len)) { + memset(adv->adv_data, 0, sizeof(adv->adv_data)); + memcpy(adv->adv_data, adv_data, adv_data_len); + adv->adv_data_len = adv_data_len; + adv->adv_data_changed = true; + } + + if (scan_rsp_len && SCAN_RSP_CMP(adv, scan_rsp_data, scan_rsp_len)) { + memset(adv->scan_rsp_data, 0, sizeof(adv->scan_rsp_data)); + memcpy(adv->scan_rsp_data, scan_rsp_data, scan_rsp_len); + adv->scan_rsp_len = scan_rsp_len; + adv->scan_rsp_changed = true; + } + + /* Mark as changed if there are flags which would affect it */ + if (((adv->flags & MGMT_ADV_FLAG_APPEARANCE) && hdev->appearance) || + adv->flags & MGMT_ADV_FLAG_LOCAL_NAME) + adv->scan_rsp_changed = true; + + return 0; +} + +/* This function requires the caller holds hdev->lock */ +u32 hci_adv_instance_flags(struct hci_dev *hdev, u8 instance) +{ + u32 flags; + struct adv_info *adv; + + if (instance == 0x00) { + /* Instance 0 always manages the "Tx Power" and "Flags" + * fields + */ + flags = MGMT_ADV_FLAG_TX_POWER | MGMT_ADV_FLAG_MANAGED_FLAGS; + + /* For instance 0, the HCI_ADVERTISING_CONNECTABLE setting + * corresponds to the "connectable" instance flag. + */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING_CONNECTABLE)) + flags |= MGMT_ADV_FLAG_CONNECTABLE; + + if (hci_dev_test_flag(hdev, HCI_LIMITED_DISCOVERABLE)) + flags |= MGMT_ADV_FLAG_LIMITED_DISCOV; + else if (hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) + flags |= MGMT_ADV_FLAG_DISCOV; + + return flags; + } + + adv = hci_find_adv_instance(hdev, instance); + + /* Return 0 when we got an invalid instance identifier. */ + if (!adv) + return 0; + + return adv->flags; +} + +bool hci_adv_instance_is_scannable(struct hci_dev *hdev, u8 instance) +{ + struct adv_info *adv; + + /* Instance 0x00 always set local name */ + if (instance == 0x00) + return true; + + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return false; + + if (adv->flags & MGMT_ADV_FLAG_APPEARANCE || + adv->flags & MGMT_ADV_FLAG_LOCAL_NAME) + return true; + + return adv->scan_rsp_len ? 
true : false; +} + +/* This function requires the caller holds hdev->lock */ +void hci_adv_monitors_clear(struct hci_dev *hdev) +{ + struct adv_monitor *monitor; + int handle; + + idr_for_each_entry(&hdev->adv_monitors_idr, monitor, handle) + hci_free_adv_monitor(hdev, monitor); + + idr_destroy(&hdev->adv_monitors_idr); +} + +/* Frees the monitor structure and do some bookkeepings. + * This function requires the caller holds hdev->lock. + */ +void hci_free_adv_monitor(struct hci_dev *hdev, struct adv_monitor *monitor) +{ + struct adv_pattern *pattern; + struct adv_pattern *tmp; + + if (!monitor) + return; + + list_for_each_entry_safe(pattern, tmp, &monitor->patterns, list) { + list_del(&pattern->list); + kfree(pattern); + } + + if (monitor->handle) + idr_remove(&hdev->adv_monitors_idr, monitor->handle); + + if (monitor->state != ADV_MONITOR_STATE_NOT_REGISTERED) { + hdev->adv_monitors_cnt--; + mgmt_adv_monitor_removed(hdev, monitor->handle); + } + + kfree(monitor); +} + +/* Assigns handle to a monitor, and if offloading is supported and power is on, + * also attempts to forward the request to the controller. + * This function requires the caller holds hci_req_sync_lock. + */ +int hci_add_adv_monitor(struct hci_dev *hdev, struct adv_monitor *monitor) +{ + int min, max, handle; + int status = 0; + + if (!monitor) + return -EINVAL; + + hci_dev_lock(hdev); + + min = HCI_MIN_ADV_MONITOR_HANDLE; + max = HCI_MIN_ADV_MONITOR_HANDLE + HCI_MAX_ADV_MONITOR_NUM_HANDLES; + handle = idr_alloc(&hdev->adv_monitors_idr, monitor, min, max, + GFP_KERNEL); + + hci_dev_unlock(hdev); + + if (handle < 0) + return handle; + + monitor->handle = handle; + + if (!hdev_is_powered(hdev)) + return status; + + switch (hci_get_adv_monitor_offload_ext(hdev)) { + case HCI_ADV_MONITOR_EXT_NONE: + bt_dev_dbg(hdev, "add monitor %d status %d", + monitor->handle, status); + /* Message was not forwarded to controller - not an error */ + break; + + case HCI_ADV_MONITOR_EXT_MSFT: + status = msft_add_monitor_pattern(hdev, monitor); + bt_dev_dbg(hdev, "add monitor %d msft status %d", + handle, status); + break; + } + + return status; +} + +/* Attempts to tell the controller and free the monitor. If somehow the + * controller doesn't have a corresponding handle, remove anyway. + * This function requires the caller holds hci_req_sync_lock. 
+ */ +static int hci_remove_adv_monitor(struct hci_dev *hdev, + struct adv_monitor *monitor) +{ + int status = 0; + int handle; + + switch (hci_get_adv_monitor_offload_ext(hdev)) { + case HCI_ADV_MONITOR_EXT_NONE: /* also goes here when powered off */ + bt_dev_dbg(hdev, "remove monitor %d status %d", + monitor->handle, status); + goto free_monitor; + + case HCI_ADV_MONITOR_EXT_MSFT: + handle = monitor->handle; + status = msft_remove_monitor(hdev, monitor); + bt_dev_dbg(hdev, "remove monitor %d msft status %d", + handle, status); + break; + } + + /* In case no matching handle registered, just free the monitor */ + if (status == -ENOENT) + goto free_monitor; + + return status; + +free_monitor: + if (status == -ENOENT) + bt_dev_warn(hdev, "Removing monitor with no matching handle %d", + monitor->handle); + hci_free_adv_monitor(hdev, monitor); + + return status; +} + +/* This function requires the caller holds hci_req_sync_lock */ +int hci_remove_single_adv_monitor(struct hci_dev *hdev, u16 handle) +{ + struct adv_monitor *monitor = idr_find(&hdev->adv_monitors_idr, handle); + + if (!monitor) + return -EINVAL; + + return hci_remove_adv_monitor(hdev, monitor); +} + +/* This function requires the caller holds hci_req_sync_lock */ +int hci_remove_all_adv_monitor(struct hci_dev *hdev) +{ + struct adv_monitor *monitor; + int idr_next_id = 0; + int status = 0; + + while (1) { + monitor = idr_get_next(&hdev->adv_monitors_idr, &idr_next_id); + if (!monitor) + break; + + status = hci_remove_adv_monitor(hdev, monitor); + if (status) + return status; + + idr_next_id++; + } + + return status; +} + +/* This function requires the caller holds hdev->lock */ +bool hci_is_adv_monitoring(struct hci_dev *hdev) +{ + return !idr_is_empty(&hdev->adv_monitors_idr); +} + +int hci_get_adv_monitor_offload_ext(struct hci_dev *hdev) +{ + if (msft_monitor_supported(hdev)) + return HCI_ADV_MONITOR_EXT_MSFT; + + return HCI_ADV_MONITOR_EXT_NONE; +} + +struct bdaddr_list *hci_bdaddr_list_lookup(struct list_head *bdaddr_list, + bdaddr_t *bdaddr, u8 type) +{ + struct bdaddr_list *b; + + list_for_each_entry(b, bdaddr_list, list) { + if (!bacmp(&b->bdaddr, bdaddr) && b->bdaddr_type == type) + return b; + } + + return NULL; +} + +struct bdaddr_list_with_irk *hci_bdaddr_list_lookup_with_irk( + struct list_head *bdaddr_list, bdaddr_t *bdaddr, + u8 type) +{ + struct bdaddr_list_with_irk *b; + + list_for_each_entry(b, bdaddr_list, list) { + if (!bacmp(&b->bdaddr, bdaddr) && b->bdaddr_type == type) + return b; + } + + return NULL; +} + +struct bdaddr_list_with_flags * +hci_bdaddr_list_lookup_with_flags(struct list_head *bdaddr_list, + bdaddr_t *bdaddr, u8 type) +{ + struct bdaddr_list_with_flags *b; + + list_for_each_entry(b, bdaddr_list, list) { + if (!bacmp(&b->bdaddr, bdaddr) && b->bdaddr_type == type) + return b; + } + + return NULL; +} + +void hci_bdaddr_list_clear(struct list_head *bdaddr_list) +{ + struct bdaddr_list *b, *n; + + list_for_each_entry_safe(b, n, bdaddr_list, list) { + list_del(&b->list); + kfree(b); + } +} + +int hci_bdaddr_list_add(struct list_head *list, bdaddr_t *bdaddr, u8 type) +{ + struct bdaddr_list *entry; + + if (!bacmp(bdaddr, BDADDR_ANY)) + return -EBADF; + + if (hci_bdaddr_list_lookup(list, bdaddr, type)) + return -EEXIST; + + entry = kzalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) + return -ENOMEM; + + bacpy(&entry->bdaddr, bdaddr); + entry->bdaddr_type = type; + + list_add(&entry->list, list); + + return 0; +} + +int hci_bdaddr_list_add_with_irk(struct list_head *list, bdaddr_t *bdaddr, + u8 type, 
u8 *peer_irk, u8 *local_irk)
+{
+ struct bdaddr_list_with_irk *entry;
+
+ if (!bacmp(bdaddr, BDADDR_ANY))
+ return -EBADF;
+
+ if (hci_bdaddr_list_lookup(list, bdaddr, type))
+ return -EEXIST;
+
+ entry = kzalloc(sizeof(*entry), GFP_KERNEL);
+ if (!entry)
+ return -ENOMEM;
+
+ bacpy(&entry->bdaddr, bdaddr);
+ entry->bdaddr_type = type;
+
+ if (peer_irk)
+ memcpy(entry->peer_irk, peer_irk, 16);
+
+ if (local_irk)
+ memcpy(entry->local_irk, local_irk, 16);
+
+ list_add(&entry->list, list);
+
+ return 0;
+}
+
+int hci_bdaddr_list_add_with_flags(struct list_head *list, bdaddr_t *bdaddr,
+ u8 type, u32 flags)
+{
+ struct bdaddr_list_with_flags *entry;
+
+ if (!bacmp(bdaddr, BDADDR_ANY))
+ return -EBADF;
+
+ if (hci_bdaddr_list_lookup(list, bdaddr, type))
+ return -EEXIST;
+
+ entry = kzalloc(sizeof(*entry), GFP_KERNEL);
+ if (!entry)
+ return -ENOMEM;
+
+ bacpy(&entry->bdaddr, bdaddr);
+ entry->bdaddr_type = type;
+ entry->flags = flags;
+
+ list_add(&entry->list, list);
+
+ return 0;
+}
+
+int hci_bdaddr_list_del(struct list_head *list, bdaddr_t *bdaddr, u8 type)
+{
+ struct bdaddr_list *entry;
+
+ if (!bacmp(bdaddr, BDADDR_ANY)) {
+ hci_bdaddr_list_clear(list);
+ return 0;
+ }
+
+ entry = hci_bdaddr_list_lookup(list, bdaddr, type);
+ if (!entry)
+ return -ENOENT;
+
+ list_del(&entry->list);
+ kfree(entry);
+
+ return 0;
+}
+
+int hci_bdaddr_list_del_with_irk(struct list_head *list, bdaddr_t *bdaddr,
+ u8 type)
+{
+ struct bdaddr_list_with_irk *entry;
+
+ if (!bacmp(bdaddr, BDADDR_ANY)) {
+ hci_bdaddr_list_clear(list);
+ return 0;
+ }
+
+ entry = hci_bdaddr_list_lookup_with_irk(list, bdaddr, type);
+ if (!entry)
+ return -ENOENT;
+
+ list_del(&entry->list);
+ kfree(entry);
+
+ return 0;
+}
+
+int hci_bdaddr_list_del_with_flags(struct list_head *list, bdaddr_t *bdaddr,
+ u8 type)
+{
+ struct bdaddr_list_with_flags *entry;
+
+ if (!bacmp(bdaddr, BDADDR_ANY)) {
+ hci_bdaddr_list_clear(list);
+ return 0;
+ }
+
+ entry = hci_bdaddr_list_lookup_with_flags(list, bdaddr, type);
+ if (!entry)
+ return -ENOENT;
+
+ list_del(&entry->list);
+ kfree(entry);
+
+ return 0;
+}
+
+/* This function requires the caller holds hdev->lock */
+struct hci_conn_params *hci_conn_params_lookup(struct hci_dev *hdev,
+ bdaddr_t *addr, u8 addr_type)
+{
+ struct hci_conn_params *params;
+
+ list_for_each_entry(params, &hdev->le_conn_params, list) {
+ if (bacmp(&params->addr, addr) == 0 &&
+ params->addr_type == addr_type) {
+ return params;
+ }
+ }
+
+ return NULL;
+}
+
+/* This function requires the caller holds hdev->lock or rcu_read_lock */
+struct hci_conn_params *hci_pend_le_action_lookup(struct list_head *list,
+ bdaddr_t *addr, u8 addr_type)
+{
+ struct hci_conn_params *param;
+
+ rcu_read_lock();
+
+ list_for_each_entry_rcu(param, list, action) {
+ if (bacmp(&param->addr, addr) == 0 &&
+ param->addr_type == addr_type) {
+ rcu_read_unlock();
+ return param;
+ }
+ }
+
+ rcu_read_unlock();
+
+ return NULL;
+}
+
+/* This function requires the caller holds hdev->lock */
+void hci_pend_le_list_del_init(struct hci_conn_params *param)
+{
+ if (list_empty(&param->action))
+ return;
+
+ list_del_rcu(&param->action);
+ synchronize_rcu();
+ INIT_LIST_HEAD(&param->action);
+}
+
+/* This function requires the caller holds hdev->lock */
+void hci_pend_le_list_add(struct hci_conn_params *param,
+ struct list_head *list)
+{
+ list_add_rcu(&param->action, list);
+}
+
+/* This function requires the caller holds hdev->lock */
+struct hci_conn_params *hci_conn_params_add(struct hci_dev *hdev,
+ bdaddr_t *addr, u8 addr_type)
+{
+ struct hci_conn_params *params;
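+ /* Reuse an existing entry for this address if one is already
+ * present; otherwise allocate a new one and seed it with the
+ * controller's default LE connection parameters.
+ */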
+
+ params = hci_conn_params_lookup(hdev, addr, addr_type);
+ if (params)
+ return params;
+
+ params = kzalloc(sizeof(*params), GFP_KERNEL);
+ if (!params) {
+ bt_dev_err(hdev, "out of memory");
+ return NULL;
+ }
+
+ bacpy(&params->addr, addr);
+ params->addr_type = addr_type;
+
+ list_add(&params->list, &hdev->le_conn_params);
+ INIT_LIST_HEAD(&params->action);
+
+ params->conn_min_interval = hdev->le_conn_min_interval;
+ params->conn_max_interval = hdev->le_conn_max_interval;
+ params->conn_latency = hdev->le_conn_latency;
+ params->supervision_timeout = hdev->le_supv_timeout;
+ params->auto_connect = HCI_AUTO_CONN_DISABLED;
+
+ BT_DBG("addr %pMR (type %u)", addr, addr_type);
+
+ return params;
+}
+
+void hci_conn_params_free(struct hci_conn_params *params)
+{
+ hci_pend_le_list_del_init(params);
+
+ if (params->conn) {
+ hci_conn_drop(params->conn);
+ hci_conn_put(params->conn);
+ }
+
+ list_del(&params->list);
+ kfree(params);
+}
+
+/* This function requires the caller holds hdev->lock */
+void hci_conn_params_del(struct hci_dev *hdev, bdaddr_t *addr, u8 addr_type)
+{
+ struct hci_conn_params *params;
+
+ params = hci_conn_params_lookup(hdev, addr, addr_type);
+ if (!params)
+ return;
+
+ hci_conn_params_free(params);
+
+ hci_update_passive_scan(hdev);
+
+ BT_DBG("addr %pMR (type %u)", addr, addr_type);
+}
+
+/* This function requires the caller holds hdev->lock */
+void hci_conn_params_clear_disabled(struct hci_dev *hdev)
+{
+ struct hci_conn_params *params, *tmp;
+
+ list_for_each_entry_safe(params, tmp, &hdev->le_conn_params, list) {
+ if (params->auto_connect != HCI_AUTO_CONN_DISABLED)
+ continue;
+
+ /* If trying to establish one time connection to disabled
+ * device, leave the params, but mark them as just once.
+ */
+ if (params->explicit_connect) {
+ params->auto_connect = HCI_AUTO_CONN_EXPLICIT;
+ continue;
+ }
+
+ hci_conn_params_free(params);
+ }
+
+ BT_DBG("All LE disabled connection parameters were removed");
+}
+
+/* This function requires the caller holds hdev->lock */
+static void hci_conn_params_clear_all(struct hci_dev *hdev)
+{
+ struct hci_conn_params *params, *tmp;
+
+ list_for_each_entry_safe(params, tmp, &hdev->le_conn_params, list)
+ hci_conn_params_free(params);
+
+ BT_DBG("All LE connection parameters were removed");
+}
+
+/* Copy the Identity Address of the controller.
+ *
+ * If the controller has a public BD_ADDR, then by default use that one.
+ * If this is a LE only controller without a public address, default to
+ * the static random address.
+ *
+ * For debugging purposes it is possible to force controllers with a
+ * public address to use the static random address instead.
+ *
+ * In case BR/EDR has been disabled on a dual-mode controller and
+ * userspace has configured a static address, then that address
+ * becomes the identity address instead of the public BR/EDR address.
+ */ +void hci_copy_identity_address(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 *bdaddr_type) +{ + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + !bacmp(&hdev->bdaddr, BDADDR_ANY) || + (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + bacmp(&hdev->static_addr, BDADDR_ANY))) { + bacpy(bdaddr, &hdev->static_addr); + *bdaddr_type = ADDR_LE_DEV_RANDOM; + } else { + bacpy(bdaddr, &hdev->bdaddr); + *bdaddr_type = ADDR_LE_DEV_PUBLIC; + } +} + +static void hci_clear_wake_reason(struct hci_dev *hdev) +{ + hci_dev_lock(hdev); + + hdev->wake_reason = 0; + bacpy(&hdev->wake_addr, BDADDR_ANY); + hdev->wake_addr_type = 0; + + hci_dev_unlock(hdev); +} + +static int hci_suspend_notifier(struct notifier_block *nb, unsigned long action, + void *data) +{ + struct hci_dev *hdev = + container_of(nb, struct hci_dev, suspend_notifier); + int ret = 0; + + /* Userspace has full control of this device. Do nothing. */ + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) + return NOTIFY_DONE; + + /* To avoid a potential race with hci_unregister_dev. */ + hci_dev_hold(hdev); + + if (action == PM_SUSPEND_PREPARE) + ret = hci_suspend_dev(hdev); + else if (action == PM_POST_SUSPEND) + ret = hci_resume_dev(hdev); + + if (ret) + bt_dev_err(hdev, "Suspend notifier action (%lu) failed: %d", + action, ret); + + hci_dev_put(hdev); + return NOTIFY_DONE; +} + +/* Alloc HCI device */ +struct hci_dev *hci_alloc_dev_priv(int sizeof_priv) +{ + struct hci_dev *hdev; + unsigned int alloc_size; + + alloc_size = sizeof(*hdev); + if (sizeof_priv) { + /* Fixme: May need ALIGN-ment? */ + alloc_size += sizeof_priv; + } + + hdev = kzalloc(alloc_size, GFP_KERNEL); + if (!hdev) + return NULL; + + hdev->pkt_type = (HCI_DM1 | HCI_DH1 | HCI_HV1); + hdev->esco_type = (ESCO_HV1); + hdev->link_mode = (HCI_LM_ACCEPT); + hdev->num_iac = 0x01; /* One IAC support is mandatory */ + hdev->io_capability = 0x03; /* No Input No Output */ + hdev->manufacturer = 0xffff; /* Default to internal use */ + hdev->inq_tx_power = HCI_TX_POWER_INVALID; + hdev->adv_tx_power = HCI_TX_POWER_INVALID; + hdev->adv_instance_cnt = 0; + hdev->cur_adv_instance = 0x00; + hdev->adv_instance_timeout = 0; + + hdev->advmon_allowlist_duration = 300; + hdev->advmon_no_filter_duration = 500; + hdev->enable_advmon_interleave_scan = 0x00; /* Default to disable */ + + hdev->sniff_max_interval = 800; + hdev->sniff_min_interval = 80; + + hdev->le_adv_channel_map = 0x07; + hdev->le_adv_min_interval = 0x0800; + hdev->le_adv_max_interval = 0x0800; + hdev->le_scan_interval = 0x0060; + hdev->le_scan_window = 0x0030; + hdev->le_scan_int_suspend = 0x0400; + hdev->le_scan_window_suspend = 0x0012; + hdev->le_scan_int_discovery = DISCOV_LE_SCAN_INT; + hdev->le_scan_window_discovery = DISCOV_LE_SCAN_WIN; + hdev->le_scan_int_adv_monitor = 0x0060; + hdev->le_scan_window_adv_monitor = 0x0030; + hdev->le_scan_int_connect = 0x0060; + hdev->le_scan_window_connect = 0x0060; + hdev->le_conn_min_interval = 0x0018; + hdev->le_conn_max_interval = 0x0028; + hdev->le_conn_latency = 0x0000; + hdev->le_supv_timeout = 0x002a; + hdev->le_def_tx_len = 0x001b; + hdev->le_def_tx_time = 0x0148; + hdev->le_max_tx_len = 0x001b; + hdev->le_max_tx_time = 0x0148; + hdev->le_max_rx_len = 0x001b; + hdev->le_max_rx_time = 0x0148; + hdev->le_max_key_size = SMP_MAX_ENC_KEY_SIZE; + hdev->le_min_key_size = SMP_MIN_ENC_KEY_SIZE; + hdev->le_tx_def_phys = HCI_LE_SET_PHY_1M; + hdev->le_rx_def_phys = HCI_LE_SET_PHY_1M; + hdev->le_num_of_adv_sets = HCI_MAX_ADV_INSTANCES; + hdev->def_multi_adv_rotation_duration = 
HCI_DEFAULT_ADV_DURATION; + hdev->def_le_autoconnect_timeout = HCI_LE_AUTOCONN_TIMEOUT; + hdev->min_le_tx_power = HCI_TX_POWER_INVALID; + hdev->max_le_tx_power = HCI_TX_POWER_INVALID; + + hdev->rpa_timeout = HCI_DEFAULT_RPA_TIMEOUT; + hdev->discov_interleaved_timeout = DISCOV_INTERLEAVED_TIMEOUT; + hdev->conn_info_min_age = DEFAULT_CONN_INFO_MIN_AGE; + hdev->conn_info_max_age = DEFAULT_CONN_INFO_MAX_AGE; + hdev->auth_payload_timeout = DEFAULT_AUTH_PAYLOAD_TIMEOUT; + hdev->min_enc_key_size = HCI_MIN_ENC_KEY_SIZE; + + /* default 1.28 sec page scan */ + hdev->def_page_scan_type = PAGE_SCAN_TYPE_STANDARD; + hdev->def_page_scan_int = 0x0800; + hdev->def_page_scan_window = 0x0012; + + mutex_init(&hdev->lock); + mutex_init(&hdev->req_lock); + + INIT_LIST_HEAD(&hdev->mesh_pending); + INIT_LIST_HEAD(&hdev->mgmt_pending); + INIT_LIST_HEAD(&hdev->reject_list); + INIT_LIST_HEAD(&hdev->accept_list); + INIT_LIST_HEAD(&hdev->uuids); + INIT_LIST_HEAD(&hdev->link_keys); + INIT_LIST_HEAD(&hdev->long_term_keys); + INIT_LIST_HEAD(&hdev->identity_resolving_keys); + INIT_LIST_HEAD(&hdev->remote_oob_data); + INIT_LIST_HEAD(&hdev->le_accept_list); + INIT_LIST_HEAD(&hdev->le_resolv_list); + INIT_LIST_HEAD(&hdev->le_conn_params); + INIT_LIST_HEAD(&hdev->pend_le_conns); + INIT_LIST_HEAD(&hdev->pend_le_reports); + INIT_LIST_HEAD(&hdev->conn_hash.list); + INIT_LIST_HEAD(&hdev->adv_instances); + INIT_LIST_HEAD(&hdev->blocked_keys); + INIT_LIST_HEAD(&hdev->monitored_devices); + + INIT_LIST_HEAD(&hdev->local_codecs); + INIT_WORK(&hdev->rx_work, hci_rx_work); + INIT_WORK(&hdev->cmd_work, hci_cmd_work); + INIT_WORK(&hdev->tx_work, hci_tx_work); + INIT_WORK(&hdev->power_on, hci_power_on); + INIT_WORK(&hdev->error_reset, hci_error_reset); + + hci_cmd_sync_init(hdev); + + INIT_DELAYED_WORK(&hdev->power_off, hci_power_off); + + skb_queue_head_init(&hdev->rx_q); + skb_queue_head_init(&hdev->cmd_q); + skb_queue_head_init(&hdev->raw_q); + + init_waitqueue_head(&hdev->req_wait_q); + + INIT_DELAYED_WORK(&hdev->cmd_timer, hci_cmd_timeout); + INIT_DELAYED_WORK(&hdev->ncmd_timer, hci_ncmd_timeout); + + hci_request_setup(hdev); + + hci_init_sysfs(hdev); + discovery_init(hdev); + + return hdev; +} +EXPORT_SYMBOL(hci_alloc_dev_priv); + +/* Free HCI device */ +void hci_free_dev(struct hci_dev *hdev) +{ + /* will free via device release */ + put_device(&hdev->dev); +} +EXPORT_SYMBOL(hci_free_dev); + +/* Register HCI device */ +int hci_register_dev(struct hci_dev *hdev) +{ + int id, error; + + if (!hdev->open || !hdev->close || !hdev->send) + return -EINVAL; + + /* Do not allow HCI_AMP devices to register at index 0, + * so the index can be used as the AMP controller ID. 
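+ * Primary controllers may take any index, starting from 0.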
+ */ + switch (hdev->dev_type) { + case HCI_PRIMARY: + id = ida_simple_get(&hci_index_ida, 0, HCI_MAX_ID, GFP_KERNEL); + break; + case HCI_AMP: + id = ida_simple_get(&hci_index_ida, 1, HCI_MAX_ID, GFP_KERNEL); + break; + default: + return -EINVAL; + } + + if (id < 0) + return id; + + error = dev_set_name(&hdev->dev, "hci%u", id); + if (error) + return error; + + hdev->name = dev_name(&hdev->dev); + hdev->id = id; + + BT_DBG("%p name %s bus %d", hdev, hdev->name, hdev->bus); + + hdev->workqueue = alloc_ordered_workqueue("%s", WQ_HIGHPRI, hdev->name); + if (!hdev->workqueue) { + error = -ENOMEM; + goto err; + } + + hdev->req_workqueue = alloc_ordered_workqueue("%s", WQ_HIGHPRI, + hdev->name); + if (!hdev->req_workqueue) { + destroy_workqueue(hdev->workqueue); + error = -ENOMEM; + goto err; + } + + if (!IS_ERR_OR_NULL(bt_debugfs)) + hdev->debugfs = debugfs_create_dir(hdev->name, bt_debugfs); + + error = device_add(&hdev->dev); + if (error < 0) + goto err_wqueue; + + hci_leds_init(hdev); + + hdev->rfkill = rfkill_alloc(hdev->name, &hdev->dev, + RFKILL_TYPE_BLUETOOTH, &hci_rfkill_ops, + hdev); + if (hdev->rfkill) { + if (rfkill_register(hdev->rfkill) < 0) { + rfkill_destroy(hdev->rfkill); + hdev->rfkill = NULL; + } + } + + if (hdev->rfkill && rfkill_blocked(hdev->rfkill)) + hci_dev_set_flag(hdev, HCI_RFKILLED); + + hci_dev_set_flag(hdev, HCI_SETUP); + hci_dev_set_flag(hdev, HCI_AUTO_OFF); + + if (hdev->dev_type == HCI_PRIMARY) { + /* Assume BR/EDR support until proven otherwise (such as + * through reading supported features during init. + */ + hci_dev_set_flag(hdev, HCI_BREDR_ENABLED); + } + + write_lock(&hci_dev_list_lock); + list_add(&hdev->list, &hci_dev_list); + write_unlock(&hci_dev_list_lock); + + /* Devices that are marked for raw-only usage are unconfigured + * and should not be included in normal operation. + */ + if (test_bit(HCI_QUIRK_RAW_DEVICE, &hdev->quirks)) + hci_dev_set_flag(hdev, HCI_UNCONFIGURED); + + /* Mark Remote Wakeup connection flag as supported if driver has wakeup + * callback. 
+ */ + if (hdev->wakeup) + hdev->conn_flags |= HCI_CONN_FLAG_REMOTE_WAKEUP; + + hci_sock_dev_event(hdev, HCI_DEV_REG); + hci_dev_hold(hdev); + + error = hci_register_suspend_notifier(hdev); + if (error) + BT_WARN("register suspend notifier failed error:%d\n", error); + + queue_work(hdev->req_workqueue, &hdev->power_on); + + idr_init(&hdev->adv_monitors_idr); + msft_register(hdev); + + return id; + +err_wqueue: + debugfs_remove_recursive(hdev->debugfs); + destroy_workqueue(hdev->workqueue); + destroy_workqueue(hdev->req_workqueue); +err: + ida_simple_remove(&hci_index_ida, hdev->id); + + return error; +} +EXPORT_SYMBOL(hci_register_dev); + +/* Unregister HCI device */ +void hci_unregister_dev(struct hci_dev *hdev) +{ + BT_DBG("%p name %s bus %d", hdev, hdev->name, hdev->bus); + + mutex_lock(&hdev->unregister_lock); + hci_dev_set_flag(hdev, HCI_UNREGISTER); + mutex_unlock(&hdev->unregister_lock); + + write_lock(&hci_dev_list_lock); + list_del(&hdev->list); + write_unlock(&hci_dev_list_lock); + + cancel_work_sync(&hdev->power_on); + + hci_cmd_sync_clear(hdev); + + hci_unregister_suspend_notifier(hdev); + + msft_unregister(hdev); + + hci_dev_do_close(hdev); + + if (!test_bit(HCI_INIT, &hdev->flags) && + !hci_dev_test_flag(hdev, HCI_SETUP) && + !hci_dev_test_flag(hdev, HCI_CONFIG)) { + hci_dev_lock(hdev); + mgmt_index_removed(hdev); + hci_dev_unlock(hdev); + } + + /* mgmt_index_removed should take care of emptying the + * pending list */ + BUG_ON(!list_empty(&hdev->mgmt_pending)); + + hci_sock_dev_event(hdev, HCI_DEV_UNREG); + + if (hdev->rfkill) { + rfkill_unregister(hdev->rfkill); + rfkill_destroy(hdev->rfkill); + } + + device_del(&hdev->dev); + /* Actual cleanup is deferred until hci_release_dev(). */ + hci_dev_put(hdev); +} +EXPORT_SYMBOL(hci_unregister_dev); + +/* Release HCI device */ +void hci_release_dev(struct hci_dev *hdev) +{ + debugfs_remove_recursive(hdev->debugfs); + kfree_const(hdev->hw_info); + kfree_const(hdev->fw_info); + + destroy_workqueue(hdev->workqueue); + destroy_workqueue(hdev->req_workqueue); + + hci_dev_lock(hdev); + hci_bdaddr_list_clear(&hdev->reject_list); + hci_bdaddr_list_clear(&hdev->accept_list); + hci_uuids_clear(hdev); + hci_link_keys_clear(hdev); + hci_smp_ltks_clear(hdev); + hci_smp_irks_clear(hdev); + hci_remote_oob_data_clear(hdev); + hci_adv_instances_clear(hdev); + hci_adv_monitors_clear(hdev); + hci_bdaddr_list_clear(&hdev->le_accept_list); + hci_bdaddr_list_clear(&hdev->le_resolv_list); + hci_conn_params_clear_all(hdev); + hci_discovery_filter_clear(hdev); + hci_blocked_keys_clear(hdev); + hci_codec_list_clear(&hdev->local_codecs); + hci_dev_unlock(hdev); + + ida_simple_remove(&hci_index_ida, hdev->id); + kfree_skb(hdev->sent_cmd); + kfree_skb(hdev->recv_event); + kfree(hdev); +} +EXPORT_SYMBOL(hci_release_dev); + +int hci_register_suspend_notifier(struct hci_dev *hdev) +{ + int ret = 0; + + if (!hdev->suspend_notifier.notifier_call && + !test_bit(HCI_QUIRK_NO_SUSPEND_NOTIFIER, &hdev->quirks)) { + hdev->suspend_notifier.notifier_call = hci_suspend_notifier; + ret = register_pm_notifier(&hdev->suspend_notifier); + } + + return ret; +} + +int hci_unregister_suspend_notifier(struct hci_dev *hdev) +{ + int ret = 0; + + if (hdev->suspend_notifier.notifier_call) { + ret = unregister_pm_notifier(&hdev->suspend_notifier); + if (!ret) + hdev->suspend_notifier.notifier_call = NULL; + } + + return ret; +} + +/* Suspend HCI device */ +int hci_suspend_dev(struct hci_dev *hdev) +{ + int ret; + + bt_dev_dbg(hdev, ""); + + /* Suspend should only act on when powered. 
*/ + if (!hdev_is_powered(hdev) || + hci_dev_test_flag(hdev, HCI_UNREGISTER)) + return 0; + + /* If powering down don't attempt to suspend */ + if (mgmt_powering_down(hdev)) + return 0; + + hci_req_sync_lock(hdev); + ret = hci_suspend_sync(hdev); + hci_req_sync_unlock(hdev); + + hci_clear_wake_reason(hdev); + mgmt_suspending(hdev, hdev->suspend_state); + + hci_sock_dev_event(hdev, HCI_DEV_SUSPEND); + return ret; +} +EXPORT_SYMBOL(hci_suspend_dev); + +/* Resume HCI device */ +int hci_resume_dev(struct hci_dev *hdev) +{ + int ret; + + bt_dev_dbg(hdev, ""); + + /* Resume should only act on when powered. */ + if (!hdev_is_powered(hdev) || + hci_dev_test_flag(hdev, HCI_UNREGISTER)) + return 0; + + /* If powering down don't attempt to resume */ + if (mgmt_powering_down(hdev)) + return 0; + + hci_req_sync_lock(hdev); + ret = hci_resume_sync(hdev); + hci_req_sync_unlock(hdev); + + mgmt_resuming(hdev, hdev->wake_reason, &hdev->wake_addr, + hdev->wake_addr_type); + + hci_sock_dev_event(hdev, HCI_DEV_RESUME); + return ret; +} +EXPORT_SYMBOL(hci_resume_dev); + +/* Reset HCI device */ +int hci_reset_dev(struct hci_dev *hdev) +{ + static const u8 hw_err[] = { HCI_EV_HARDWARE_ERROR, 0x01, 0x00 }; + struct sk_buff *skb; + + skb = bt_skb_alloc(3, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hci_skb_pkt_type(skb) = HCI_EVENT_PKT; + skb_put_data(skb, hw_err, 3); + + bt_dev_err(hdev, "Injecting HCI hardware error event"); + + /* Send Hardware Error to upper stack */ + return hci_recv_frame(hdev, skb); +} +EXPORT_SYMBOL(hci_reset_dev); + +/* Receive frame from HCI drivers */ +int hci_recv_frame(struct hci_dev *hdev, struct sk_buff *skb) +{ + if (!hdev || (!test_bit(HCI_UP, &hdev->flags) + && !test_bit(HCI_INIT, &hdev->flags))) { + kfree_skb(skb); + return -ENXIO; + } + + switch (hci_skb_pkt_type(skb)) { + case HCI_EVENT_PKT: + break; + case HCI_ACLDATA_PKT: + /* Detect if ISO packet has been sent as ACL */ + if (hci_conn_num(hdev, ISO_LINK)) { + __u16 handle = __le16_to_cpu(hci_acl_hdr(skb)->handle); + __u8 type; + + type = hci_conn_lookup_type(hdev, hci_handle(handle)); + if (type == ISO_LINK) + hci_skb_pkt_type(skb) = HCI_ISODATA_PKT; + } + break; + case HCI_SCODATA_PKT: + break; + case HCI_ISODATA_PKT: + break; + default: + kfree_skb(skb); + return -EINVAL; + } + + /* Incoming skb */ + bt_cb(skb)->incoming = 1; + + /* Time stamp */ + __net_timestamp(skb); + + skb_queue_tail(&hdev->rx_q, skb); + queue_work(hdev->workqueue, &hdev->rx_work); + + return 0; +} +EXPORT_SYMBOL(hci_recv_frame); + +/* Receive diagnostic message from HCI drivers */ +int hci_recv_diag(struct hci_dev *hdev, struct sk_buff *skb) +{ + /* Mark as diagnostic packet */ + hci_skb_pkt_type(skb) = HCI_DIAG_PKT; + + /* Time stamp */ + __net_timestamp(skb); + + skb_queue_tail(&hdev->rx_q, skb); + queue_work(hdev->workqueue, &hdev->rx_work); + + return 0; +} +EXPORT_SYMBOL(hci_recv_diag); + +void hci_set_hw_info(struct hci_dev *hdev, const char *fmt, ...) +{ + va_list vargs; + + va_start(vargs, fmt); + kfree_const(hdev->hw_info); + hdev->hw_info = kvasprintf_const(GFP_KERNEL, fmt, vargs); + va_end(vargs); +} +EXPORT_SYMBOL(hci_set_hw_info); + +void hci_set_fw_info(struct hci_dev *hdev, const char *fmt, ...) 
+{ + va_list vargs; + + va_start(vargs, fmt); + kfree_const(hdev->fw_info); + hdev->fw_info = kvasprintf_const(GFP_KERNEL, fmt, vargs); + va_end(vargs); +} +EXPORT_SYMBOL(hci_set_fw_info); + +/* ---- Interface to upper protocols ---- */ + +int hci_register_cb(struct hci_cb *cb) +{ + BT_DBG("%p name %s", cb, cb->name); + + mutex_lock(&hci_cb_list_lock); + list_add_tail(&cb->list, &hci_cb_list); + mutex_unlock(&hci_cb_list_lock); + + return 0; +} +EXPORT_SYMBOL(hci_register_cb); + +int hci_unregister_cb(struct hci_cb *cb) +{ + BT_DBG("%p name %s", cb, cb->name); + + mutex_lock(&hci_cb_list_lock); + list_del(&cb->list); + mutex_unlock(&hci_cb_list_lock); + + return 0; +} +EXPORT_SYMBOL(hci_unregister_cb); + +static int hci_send_frame(struct hci_dev *hdev, struct sk_buff *skb) +{ + int err; + + BT_DBG("%s type %d len %d", hdev->name, hci_skb_pkt_type(skb), + skb->len); + + /* Time stamp */ + __net_timestamp(skb); + + /* Send copy to monitor */ + hci_send_to_monitor(hdev, skb); + + if (atomic_read(&hdev->promisc)) { + /* Send copy to the sockets */ + hci_send_to_sock(hdev, skb); + } + + /* Get rid of skb owner, prior to sending to the driver. */ + skb_orphan(skb); + + if (!test_bit(HCI_RUNNING, &hdev->flags)) { + kfree_skb(skb); + return -EINVAL; + } + + err = hdev->send(hdev, skb); + if (err < 0) { + bt_dev_err(hdev, "sending frame failed (%d)", err); + kfree_skb(skb); + return err; + } + + return 0; +} + +/* Send HCI command */ +int hci_send_cmd(struct hci_dev *hdev, __u16 opcode, __u32 plen, + const void *param) +{ + struct sk_buff *skb; + + BT_DBG("%s opcode 0x%4.4x plen %d", hdev->name, opcode, plen); + + skb = hci_prepare_cmd(hdev, opcode, plen, param); + if (!skb) { + bt_dev_err(hdev, "no memory for command"); + return -ENOMEM; + } + + /* Stand-alone HCI commands must be flagged as + * single-command requests. + */ + bt_cb(skb)->hci.req_flags |= HCI_REQ_START; + + skb_queue_tail(&hdev->cmd_q, skb); + queue_work(hdev->workqueue, &hdev->cmd_work); + + return 0; +} + +int __hci_cmd_send(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param) +{ + struct sk_buff *skb; + + if (hci_opcode_ogf(opcode) != 0x3f) { + /* A controller receiving a command shall respond with either + * a Command Status Event or a Command Complete Event. + * Therefore, all standard HCI commands must be sent via the + * standard API, using hci_send_cmd or hci_cmd_sync helpers. + * Some vendors do not comply with this rule for vendor-specific + * commands and do not return any event. We want to support + * unresponded commands for such cases only. 
+ */ + bt_dev_err(hdev, "unresponded command not supported"); + return -EINVAL; + } + + skb = hci_prepare_cmd(hdev, opcode, plen, param); + if (!skb) { + bt_dev_err(hdev, "no memory for command (opcode 0x%4.4x)", + opcode); + return -ENOMEM; + } + + hci_send_frame(hdev, skb); + + return 0; +} +EXPORT_SYMBOL(__hci_cmd_send); + +/* Get data from the previously sent command */ +void *hci_sent_cmd_data(struct hci_dev *hdev, __u16 opcode) +{ + struct hci_command_hdr *hdr; + + if (!hdev->sent_cmd) + return NULL; + + hdr = (void *) hdev->sent_cmd->data; + + if (hdr->opcode != cpu_to_le16(opcode)) + return NULL; + + BT_DBG("%s opcode 0x%4.4x", hdev->name, opcode); + + return hdev->sent_cmd->data + HCI_COMMAND_HDR_SIZE; +} + +/* Get data from last received event */ +void *hci_recv_event_data(struct hci_dev *hdev, __u8 event) +{ + struct hci_event_hdr *hdr; + int offset; + + if (!hdev->recv_event) + return NULL; + + hdr = (void *)hdev->recv_event->data; + offset = sizeof(*hdr); + + if (hdr->evt != event) { + /* In case of LE metaevent check the subevent match */ + if (hdr->evt == HCI_EV_LE_META) { + struct hci_ev_le_meta *ev; + + ev = (void *)hdev->recv_event->data + offset; + offset += sizeof(*ev); + if (ev->subevent == event) + goto found; + } + return NULL; + } + +found: + bt_dev_dbg(hdev, "event 0x%2.2x", event); + + return hdev->recv_event->data + offset; +} + +/* Send ACL data */ +static void hci_add_acl_hdr(struct sk_buff *skb, __u16 handle, __u16 flags) +{ + struct hci_acl_hdr *hdr; + int len = skb->len; + + skb_push(skb, HCI_ACL_HDR_SIZE); + skb_reset_transport_header(skb); + hdr = (struct hci_acl_hdr *)skb_transport_header(skb); + hdr->handle = cpu_to_le16(hci_handle_pack(handle, flags)); + hdr->dlen = cpu_to_le16(len); +} + +static void hci_queue_acl(struct hci_chan *chan, struct sk_buff_head *queue, + struct sk_buff *skb, __u16 flags) +{ + struct hci_conn *conn = chan->conn; + struct hci_dev *hdev = conn->hdev; + struct sk_buff *list; + + skb->len = skb_headlen(skb); + skb->data_len = 0; + + hci_skb_pkt_type(skb) = HCI_ACLDATA_PKT; + + switch (hdev->dev_type) { + case HCI_PRIMARY: + hci_add_acl_hdr(skb, conn->handle, flags); + break; + case HCI_AMP: + hci_add_acl_hdr(skb, chan->handle, flags); + break; + default: + bt_dev_err(hdev, "unknown dev_type %d", hdev->dev_type); + return; + } + + list = skb_shinfo(skb)->frag_list; + if (!list) { + /* Non fragmented */ + BT_DBG("%s nonfrag skb %p len %d", hdev->name, skb, skb->len); + + skb_queue_tail(queue, skb); + } else { + /* Fragmented */ + BT_DBG("%s frag %p len %d", hdev->name, skb, skb->len); + + skb_shinfo(skb)->frag_list = NULL; + + /* Queue all fragments atomically. We need to use spin_lock_bh + * here because of 6LoWPAN links, as there this function is + * called from softirq and using normal spin lock could cause + * deadlocks. 
+ */ + spin_lock_bh(&queue->lock); + + __skb_queue_tail(queue, skb); + + flags &= ~ACL_START; + flags |= ACL_CONT; + do { + skb = list; list = list->next; + + hci_skb_pkt_type(skb) = HCI_ACLDATA_PKT; + hci_add_acl_hdr(skb, conn->handle, flags); + + BT_DBG("%s frag %p len %d", hdev->name, skb, skb->len); + + __skb_queue_tail(queue, skb); + } while (list); + + spin_unlock_bh(&queue->lock); + } +} + +void hci_send_acl(struct hci_chan *chan, struct sk_buff *skb, __u16 flags) +{ + struct hci_dev *hdev = chan->conn->hdev; + + BT_DBG("%s chan %p flags 0x%4.4x", hdev->name, chan, flags); + + hci_queue_acl(chan, &chan->data_q, skb, flags); + + queue_work(hdev->workqueue, &hdev->tx_work); +} + +/* Send SCO data */ +void hci_send_sco(struct hci_conn *conn, struct sk_buff *skb) +{ + struct hci_dev *hdev = conn->hdev; + struct hci_sco_hdr hdr; + + BT_DBG("%s len %d", hdev->name, skb->len); + + hdr.handle = cpu_to_le16(conn->handle); + hdr.dlen = skb->len; + + skb_push(skb, HCI_SCO_HDR_SIZE); + skb_reset_transport_header(skb); + memcpy(skb_transport_header(skb), &hdr, HCI_SCO_HDR_SIZE); + + hci_skb_pkt_type(skb) = HCI_SCODATA_PKT; + + skb_queue_tail(&conn->data_q, skb); + queue_work(hdev->workqueue, &hdev->tx_work); +} + +/* Send ISO data */ +static void hci_add_iso_hdr(struct sk_buff *skb, __u16 handle, __u8 flags) +{ + struct hci_iso_hdr *hdr; + int len = skb->len; + + skb_push(skb, HCI_ISO_HDR_SIZE); + skb_reset_transport_header(skb); + hdr = (struct hci_iso_hdr *)skb_transport_header(skb); + hdr->handle = cpu_to_le16(hci_handle_pack(handle, flags)); + hdr->dlen = cpu_to_le16(len); +} + +static void hci_queue_iso(struct hci_conn *conn, struct sk_buff_head *queue, + struct sk_buff *skb) +{ + struct hci_dev *hdev = conn->hdev; + struct sk_buff *list; + __u16 flags; + + skb->len = skb_headlen(skb); + skb->data_len = 0; + + hci_skb_pkt_type(skb) = HCI_ISODATA_PKT; + + list = skb_shinfo(skb)->frag_list; + + flags = hci_iso_flags_pack(list ? ISO_START : ISO_SINGLE, 0x00); + hci_add_iso_hdr(skb, conn->handle, flags); + + if (!list) { + /* Non fragmented */ + BT_DBG("%s nonfrag skb %p len %d", hdev->name, skb, skb->len); + + skb_queue_tail(queue, skb); + } else { + /* Fragmented */ + BT_DBG("%s frag %p len %d", hdev->name, skb, skb->len); + + skb_shinfo(skb)->frag_list = NULL; + + __skb_queue_tail(queue, skb); + + do { + skb = list; list = list->next; + + hci_skb_pkt_type(skb) = HCI_ISODATA_PKT; + flags = hci_iso_flags_pack(list ? ISO_CONT : ISO_END, + 0x00); + hci_add_iso_hdr(skb, conn->handle, flags); + + BT_DBG("%s frag %p len %d", hdev->name, skb, skb->len); + + __skb_queue_tail(queue, skb); + } while (list); + } +} + +void hci_send_iso(struct hci_conn *conn, struct sk_buff *skb) +{ + struct hci_dev *hdev = conn->hdev; + + BT_DBG("%s len %d", hdev->name, skb->len); + + hci_queue_iso(conn, &conn->data_q, skb); + + queue_work(hdev->workqueue, &hdev->tx_work); +} + +/* ---- HCI TX task (outgoing data) ---- */ + +/* HCI Connection scheduler */ +static inline void hci_quote_sent(struct hci_conn *conn, int num, int *quote) +{ + struct hci_dev *hdev; + int cnt, q; + + if (!conn) { + *quote = 0; + return; + } + + hdev = conn->hdev; + + switch (conn->type) { + case ACL_LINK: + cnt = hdev->acl_cnt; + break; + case AMP_LINK: + cnt = hdev->block_cnt; + break; + case SCO_LINK: + case ESCO_LINK: + cnt = hdev->sco_cnt; + break; + case LE_LINK: + cnt = hdev->le_mtu ? hdev->le_cnt : hdev->acl_cnt; + break; + case ISO_LINK: + cnt = hdev->iso_mtu ? hdev->iso_cnt : + hdev->le_mtu ? 
hdev->le_cnt : hdev->acl_cnt; + break; + default: + cnt = 0; + bt_dev_err(hdev, "unknown link type %d", conn->type); + } + + q = cnt / num; + *quote = q ? q : 1; +} + +static struct hci_conn *hci_low_sent(struct hci_dev *hdev, __u8 type, + int *quote) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *conn = NULL, *c; + unsigned int num = 0, min = ~0; + + /* We don't have to lock device here. Connections are always + * added and removed with TX task disabled. */ + + rcu_read_lock(); + + list_for_each_entry_rcu(c, &h->list, list) { + if (c->type != type || skb_queue_empty(&c->data_q)) + continue; + + if (c->state != BT_CONNECTED && c->state != BT_CONFIG) + continue; + + num++; + + if (c->sent < min) { + min = c->sent; + conn = c; + } + + if (hci_conn_num(hdev, type) == num) + break; + } + + rcu_read_unlock(); + + hci_quote_sent(conn, num, quote); + + BT_DBG("conn %p quote %d", conn, *quote); + return conn; +} + +static void hci_link_tx_to(struct hci_dev *hdev, __u8 type) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *c; + + bt_dev_err(hdev, "link tx timeout"); + + rcu_read_lock(); + + /* Kill stalled connections */ + list_for_each_entry_rcu(c, &h->list, list) { + if (c->type == type && c->sent) { + bt_dev_err(hdev, "killing stalled connection %pMR", + &c->dst); + hci_disconnect(c, HCI_ERROR_REMOTE_USER_TERM); + } + } + + rcu_read_unlock(); +} + +static struct hci_chan *hci_chan_sent(struct hci_dev *hdev, __u8 type, + int *quote) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_chan *chan = NULL; + unsigned int num = 0, min = ~0, cur_prio = 0; + struct hci_conn *conn; + int conn_num = 0; + + BT_DBG("%s", hdev->name); + + rcu_read_lock(); + + list_for_each_entry_rcu(conn, &h->list, list) { + struct hci_chan *tmp; + + if (conn->type != type) + continue; + + if (conn->state != BT_CONNECTED && conn->state != BT_CONFIG) + continue; + + conn_num++; + + list_for_each_entry_rcu(tmp, &conn->chan_list, list) { + struct sk_buff *skb; + + if (skb_queue_empty(&tmp->data_q)) + continue; + + skb = skb_peek(&tmp->data_q); + if (skb->priority < cur_prio) + continue; + + if (skb->priority > cur_prio) { + num = 0; + min = ~0; + cur_prio = skb->priority; + } + + num++; + + if (conn->sent < min) { + min = conn->sent; + chan = tmp; + } + } + + if (hci_conn_num(hdev, type) == conn_num) + break; + } + + rcu_read_unlock(); + + if (!chan) + return NULL; + + hci_quote_sent(chan->conn, num, quote); + + BT_DBG("chan %p quote %d", chan, *quote); + return chan; +} + +static void hci_prio_recalculate(struct hci_dev *hdev, __u8 type) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *conn; + int num = 0; + + BT_DBG("%s", hdev->name); + + rcu_read_lock(); + + list_for_each_entry_rcu(conn, &h->list, list) { + struct hci_chan *chan; + + if (conn->type != type) + continue; + + if (conn->state != BT_CONNECTED && conn->state != BT_CONFIG) + continue; + + num++; + + list_for_each_entry_rcu(chan, &conn->chan_list, list) { + struct sk_buff *skb; + + if (chan->sent) { + chan->sent = 0; + continue; + } + + if (skb_queue_empty(&chan->data_q)) + continue; + + skb = skb_peek(&chan->data_q); + if (skb->priority >= HCI_PRIO_MAX - 1) + continue; + + skb->priority = HCI_PRIO_MAX - 1; + + BT_DBG("chan %p skb %p promoted to %d", chan, skb, + skb->priority); + } + + if (hci_conn_num(hdev, type) == num) + break; + } + + rcu_read_unlock(); + +} + +static inline int __get_blocks(struct hci_dev *hdev, struct sk_buff *skb) +{ + /* Calculate count of blocks used by this packet */ + 
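+ /* For illustration only (the 256-byte block length is a made-up
+ * example, not a value taken from this code): an skb carrying 1020
+ * bytes of ACL payload after the 4-byte ACL header would need
+ * DIV_ROUND_UP(1020, 256) = 4 controller data blocks.
+ */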
return DIV_ROUND_UP(skb->len - HCI_ACL_HDR_SIZE, hdev->block_len); +} + +static void __check_timeout(struct hci_dev *hdev, unsigned int cnt, u8 type) +{ + unsigned long last_tx; + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + return; + + switch (type) { + case LE_LINK: + last_tx = hdev->le_last_tx; + break; + default: + last_tx = hdev->acl_last_tx; + break; + } + + /* tx timeout must be longer than maximum link supervision timeout + * (40.9 seconds) + */ + if (!cnt && time_after(jiffies, last_tx + HCI_ACL_TX_TIMEOUT)) + hci_link_tx_to(hdev, type); +} + +/* Schedule SCO */ +static void hci_sched_sco(struct hci_dev *hdev) +{ + struct hci_conn *conn; + struct sk_buff *skb; + int quote; + + BT_DBG("%s", hdev->name); + + if (!hci_conn_num(hdev, SCO_LINK)) + return; + + while (hdev->sco_cnt && (conn = hci_low_sent(hdev, SCO_LINK, "e))) { + while (quote-- && (skb = skb_dequeue(&conn->data_q))) { + BT_DBG("skb %p len %d", skb, skb->len); + hci_send_frame(hdev, skb); + + conn->sent++; + if (conn->sent == ~0) + conn->sent = 0; + } + } +} + +static void hci_sched_esco(struct hci_dev *hdev) +{ + struct hci_conn *conn; + struct sk_buff *skb; + int quote; + + BT_DBG("%s", hdev->name); + + if (!hci_conn_num(hdev, ESCO_LINK)) + return; + + while (hdev->sco_cnt && (conn = hci_low_sent(hdev, ESCO_LINK, + "e))) { + while (quote-- && (skb = skb_dequeue(&conn->data_q))) { + BT_DBG("skb %p len %d", skb, skb->len); + hci_send_frame(hdev, skb); + + conn->sent++; + if (conn->sent == ~0) + conn->sent = 0; + } + } +} + +static void hci_sched_acl_pkt(struct hci_dev *hdev) +{ + unsigned int cnt = hdev->acl_cnt; + struct hci_chan *chan; + struct sk_buff *skb; + int quote; + + __check_timeout(hdev, cnt, ACL_LINK); + + while (hdev->acl_cnt && + (chan = hci_chan_sent(hdev, ACL_LINK, "e))) { + u32 priority = (skb_peek(&chan->data_q))->priority; + while (quote-- && (skb = skb_peek(&chan->data_q))) { + BT_DBG("chan %p skb %p len %d priority %u", chan, skb, + skb->len, skb->priority); + + /* Stop if priority has changed */ + if (skb->priority < priority) + break; + + skb = skb_dequeue(&chan->data_q); + + hci_conn_enter_active_mode(chan->conn, + bt_cb(skb)->force_active); + + hci_send_frame(hdev, skb); + hdev->acl_last_tx = jiffies; + + hdev->acl_cnt--; + chan->sent++; + chan->conn->sent++; + + /* Send pending SCO packets right away */ + hci_sched_sco(hdev); + hci_sched_esco(hdev); + } + } + + if (cnt != hdev->acl_cnt) + hci_prio_recalculate(hdev, ACL_LINK); +} + +static void hci_sched_acl_blk(struct hci_dev *hdev) +{ + unsigned int cnt = hdev->block_cnt; + struct hci_chan *chan; + struct sk_buff *skb; + int quote; + u8 type; + + BT_DBG("%s", hdev->name); + + if (hdev->dev_type == HCI_AMP) + type = AMP_LINK; + else + type = ACL_LINK; + + __check_timeout(hdev, cnt, type); + + while (hdev->block_cnt > 0 && + (chan = hci_chan_sent(hdev, type, "e))) { + u32 priority = (skb_peek(&chan->data_q))->priority; + while (quote > 0 && (skb = skb_peek(&chan->data_q))) { + int blocks; + + BT_DBG("chan %p skb %p len %d priority %u", chan, skb, + skb->len, skb->priority); + + /* Stop if priority has changed */ + if (skb->priority < priority) + break; + + skb = skb_dequeue(&chan->data_q); + + blocks = __get_blocks(hdev, skb); + if (blocks > hdev->block_cnt) + return; + + hci_conn_enter_active_mode(chan->conn, + bt_cb(skb)->force_active); + + hci_send_frame(hdev, skb); + hdev->acl_last_tx = jiffies; + + hdev->block_cnt -= blocks; + quote -= blocks; + + chan->sent += blocks; + chan->conn->sent += blocks; + } + } + + if (cnt != 
hdev->block_cnt) + hci_prio_recalculate(hdev, type); +} + +static void hci_sched_acl(struct hci_dev *hdev) +{ + BT_DBG("%s", hdev->name); + + /* No ACL link over BR/EDR controller */ + if (!hci_conn_num(hdev, ACL_LINK) && hdev->dev_type == HCI_PRIMARY) + return; + + /* No AMP link over AMP controller */ + if (!hci_conn_num(hdev, AMP_LINK) && hdev->dev_type == HCI_AMP) + return; + + switch (hdev->flow_ctl_mode) { + case HCI_FLOW_CTL_MODE_PACKET_BASED: + hci_sched_acl_pkt(hdev); + break; + + case HCI_FLOW_CTL_MODE_BLOCK_BASED: + hci_sched_acl_blk(hdev); + break; + } +} + +static void hci_sched_le(struct hci_dev *hdev) +{ + struct hci_chan *chan; + struct sk_buff *skb; + int quote, cnt, tmp; + + BT_DBG("%s", hdev->name); + + if (!hci_conn_num(hdev, LE_LINK)) + return; + + cnt = hdev->le_pkts ? hdev->le_cnt : hdev->acl_cnt; + + __check_timeout(hdev, cnt, LE_LINK); + + tmp = cnt; + while (cnt && (chan = hci_chan_sent(hdev, LE_LINK, "e))) { + u32 priority = (skb_peek(&chan->data_q))->priority; + while (quote-- && (skb = skb_peek(&chan->data_q))) { + BT_DBG("chan %p skb %p len %d priority %u", chan, skb, + skb->len, skb->priority); + + /* Stop if priority has changed */ + if (skb->priority < priority) + break; + + skb = skb_dequeue(&chan->data_q); + + hci_send_frame(hdev, skb); + hdev->le_last_tx = jiffies; + + cnt--; + chan->sent++; + chan->conn->sent++; + + /* Send pending SCO packets right away */ + hci_sched_sco(hdev); + hci_sched_esco(hdev); + } + } + + if (hdev->le_pkts) + hdev->le_cnt = cnt; + else + hdev->acl_cnt = cnt; + + if (cnt != tmp) + hci_prio_recalculate(hdev, LE_LINK); +} + +/* Schedule CIS */ +static void hci_sched_iso(struct hci_dev *hdev) +{ + struct hci_conn *conn; + struct sk_buff *skb; + int quote, *cnt; + + BT_DBG("%s", hdev->name); + + if (!hci_conn_num(hdev, ISO_LINK)) + return; + + cnt = hdev->iso_pkts ? &hdev->iso_cnt : + hdev->le_pkts ? 
&hdev->le_cnt : &hdev->acl_cnt; + while (*cnt && (conn = hci_low_sent(hdev, ISO_LINK, "e))) { + while (quote-- && (skb = skb_dequeue(&conn->data_q))) { + BT_DBG("skb %p len %d", skb, skb->len); + hci_send_frame(hdev, skb); + + conn->sent++; + if (conn->sent == ~0) + conn->sent = 0; + (*cnt)--; + } + } +} + +static void hci_tx_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, tx_work); + struct sk_buff *skb; + + BT_DBG("%s acl %d sco %d le %d iso %d", hdev->name, hdev->acl_cnt, + hdev->sco_cnt, hdev->le_cnt, hdev->iso_cnt); + + if (!hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + /* Schedule queues and send stuff to HCI driver */ + hci_sched_sco(hdev); + hci_sched_esco(hdev); + hci_sched_iso(hdev); + hci_sched_acl(hdev); + hci_sched_le(hdev); + } + + /* Send next queued raw (unknown type) packet */ + while ((skb = skb_dequeue(&hdev->raw_q))) + hci_send_frame(hdev, skb); +} + +/* ----- HCI RX task (incoming data processing) ----- */ + +/* ACL data packet */ +static void hci_acldata_packet(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct hci_acl_hdr *hdr = (void *) skb->data; + struct hci_conn *conn; + __u16 handle, flags; + + skb_pull(skb, HCI_ACL_HDR_SIZE); + + handle = __le16_to_cpu(hdr->handle); + flags = hci_flags(handle); + handle = hci_handle(handle); + + BT_DBG("%s len %d handle 0x%4.4x flags 0x%4.4x", hdev->name, skb->len, + handle, flags); + + hdev->stat.acl_rx++; + + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_handle(hdev, handle); + hci_dev_unlock(hdev); + + if (conn) { + hci_conn_enter_active_mode(conn, BT_POWER_FORCE_ACTIVE_OFF); + + /* Send to upper protocol */ + l2cap_recv_acldata(conn, skb, flags); + return; + } else { + bt_dev_err(hdev, "ACL packet for unknown connection handle %d", + handle); + } + + kfree_skb(skb); +} + +/* SCO data packet */ +static void hci_scodata_packet(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct hci_sco_hdr *hdr = (void *) skb->data; + struct hci_conn *conn; + __u16 handle, flags; + + skb_pull(skb, HCI_SCO_HDR_SIZE); + + handle = __le16_to_cpu(hdr->handle); + flags = hci_flags(handle); + handle = hci_handle(handle); + + BT_DBG("%s len %d handle 0x%4.4x flags 0x%4.4x", hdev->name, skb->len, + handle, flags); + + hdev->stat.sco_rx++; + + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_handle(hdev, handle); + hci_dev_unlock(hdev); + + if (conn) { + /* Send to upper protocol */ + bt_cb(skb)->sco.pkt_status = flags & 0x03; + sco_recv_scodata(conn, skb); + return; + } else { + bt_dev_err_ratelimited(hdev, "SCO packet for unknown connection handle %d", + handle); + } + + kfree_skb(skb); +} + +static void hci_isodata_packet(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct hci_iso_hdr *hdr; + struct hci_conn *conn; + __u16 handle, flags; + + hdr = skb_pull_data(skb, sizeof(*hdr)); + if (!hdr) { + bt_dev_err(hdev, "ISO packet too small"); + goto drop; + } + + handle = __le16_to_cpu(hdr->handle); + flags = hci_flags(handle); + handle = hci_handle(handle); + + bt_dev_dbg(hdev, "len %d handle 0x%4.4x flags 0x%4.4x", skb->len, + handle, flags); + + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_handle(hdev, handle); + hci_dev_unlock(hdev); + + if (!conn) { + bt_dev_err(hdev, "ISO packet for unknown connection handle %d", + handle); + goto drop; + } + + /* Send to upper protocol */ + iso_recv(conn, skb, flags); + return; + +drop: + kfree_skb(skb); +} + +static bool hci_req_is_complete(struct hci_dev *hdev) +{ + struct sk_buff *skb; + + skb = skb_peek(&hdev->cmd_q); + if (!skb) + return 
true; + + return (bt_cb(skb)->hci.req_flags & HCI_REQ_START); +} + +static void hci_resend_last(struct hci_dev *hdev) +{ + struct hci_command_hdr *sent; + struct sk_buff *skb; + u16 opcode; + + if (!hdev->sent_cmd) + return; + + sent = (void *) hdev->sent_cmd->data; + opcode = __le16_to_cpu(sent->opcode); + if (opcode == HCI_OP_RESET) + return; + + skb = skb_clone(hdev->sent_cmd, GFP_KERNEL); + if (!skb) + return; + + skb_queue_head(&hdev->cmd_q, skb); + queue_work(hdev->workqueue, &hdev->cmd_work); +} + +void hci_req_cmd_complete(struct hci_dev *hdev, u16 opcode, u8 status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb) +{ + struct sk_buff *skb; + unsigned long flags; + + BT_DBG("opcode 0x%04x status 0x%02x", opcode, status); + + /* If the completed command doesn't match the last one that was + * sent we need to do special handling of it. + */ + if (!hci_sent_cmd_data(hdev, opcode)) { + /* Some CSR based controllers generate a spontaneous + * reset complete event during init and any pending + * command will never be completed. In such a case we + * need to resend whatever was the last sent + * command. + */ + if (test_bit(HCI_INIT, &hdev->flags) && opcode == HCI_OP_RESET) + hci_resend_last(hdev); + + return; + } + + /* If we reach this point this event matches the last command sent */ + hci_dev_clear_flag(hdev, HCI_CMD_PENDING); + + /* If the command succeeded and there's still more commands in + * this request the request is not yet complete. + */ + if (!status && !hci_req_is_complete(hdev)) + return; + + /* If this was the last command in a request the complete + * callback would be found in hdev->sent_cmd instead of the + * command queue (hdev->cmd_q). + */ + if (bt_cb(hdev->sent_cmd)->hci.req_flags & HCI_REQ_SKB) { + *req_complete_skb = bt_cb(hdev->sent_cmd)->hci.req_complete_skb; + return; + } + + if (bt_cb(hdev->sent_cmd)->hci.req_complete) { + *req_complete = bt_cb(hdev->sent_cmd)->hci.req_complete; + return; + } + + /* Remove all pending commands belonging to this request */ + spin_lock_irqsave(&hdev->cmd_q.lock, flags); + while ((skb = __skb_dequeue(&hdev->cmd_q))) { + if (bt_cb(skb)->hci.req_flags & HCI_REQ_START) { + __skb_queue_head(&hdev->cmd_q, skb); + break; + } + + if (bt_cb(skb)->hci.req_flags & HCI_REQ_SKB) + *req_complete_skb = bt_cb(skb)->hci.req_complete_skb; + else + *req_complete = bt_cb(skb)->hci.req_complete; + dev_kfree_skb_irq(skb); + } + spin_unlock_irqrestore(&hdev->cmd_q.lock, flags); +} + +static void hci_rx_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, rx_work); + struct sk_buff *skb; + + BT_DBG("%s", hdev->name); + + /* The kcov_remote functions used for collecting packet parsing + * coverage information from this background thread and associate + * the coverage with the syscall's thread which originally injected + * the packet. This helps fuzzing the kernel. + */ + for (; (skb = skb_dequeue(&hdev->rx_q)); kcov_remote_stop()) { + kcov_remote_start_common(skb_get_kcov_handle(skb)); + + /* Send copy to monitor */ + hci_send_to_monitor(hdev, skb); + + if (atomic_read(&hdev->promisc)) { + /* Send copy to the sockets */ + hci_send_to_sock(hdev, skb); + } + + /* If the device has been opened in HCI_USER_CHANNEL, + * the userspace has exclusive access to device. + * When device is HCI_INIT, we still need to process + * the data packets to the driver in order + * to complete its setup(). 
+ */ + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + !test_bit(HCI_INIT, &hdev->flags)) { + kfree_skb(skb); + continue; + } + + if (test_bit(HCI_INIT, &hdev->flags)) { + /* Don't process data packets in this states. */ + switch (hci_skb_pkt_type(skb)) { + case HCI_ACLDATA_PKT: + case HCI_SCODATA_PKT: + case HCI_ISODATA_PKT: + kfree_skb(skb); + continue; + } + } + + /* Process frame */ + switch (hci_skb_pkt_type(skb)) { + case HCI_EVENT_PKT: + BT_DBG("%s Event packet", hdev->name); + hci_event_packet(hdev, skb); + break; + + case HCI_ACLDATA_PKT: + BT_DBG("%s ACL data packet", hdev->name); + hci_acldata_packet(hdev, skb); + break; + + case HCI_SCODATA_PKT: + BT_DBG("%s SCO data packet", hdev->name); + hci_scodata_packet(hdev, skb); + break; + + case HCI_ISODATA_PKT: + BT_DBG("%s ISO data packet", hdev->name); + hci_isodata_packet(hdev, skb); + break; + + default: + kfree_skb(skb); + break; + } + } +} + +static void hci_cmd_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, cmd_work); + struct sk_buff *skb; + + BT_DBG("%s cmd_cnt %d cmd queued %d", hdev->name, + atomic_read(&hdev->cmd_cnt), skb_queue_len(&hdev->cmd_q)); + + /* Send queued commands */ + if (atomic_read(&hdev->cmd_cnt)) { + skb = skb_dequeue(&hdev->cmd_q); + if (!skb) + return; + + kfree_skb(hdev->sent_cmd); + + hdev->sent_cmd = skb_clone(skb, GFP_KERNEL); + if (hdev->sent_cmd) { + int res; + if (hci_req_status_pend(hdev)) + hci_dev_set_flag(hdev, HCI_CMD_PENDING); + atomic_dec(&hdev->cmd_cnt); + + res = hci_send_frame(hdev, skb); + if (res < 0) + __hci_cmd_sync_cancel(hdev, -res); + + rcu_read_lock(); + if (test_bit(HCI_RESET, &hdev->flags) || + hci_dev_test_flag(hdev, HCI_CMD_DRAIN_WORKQUEUE)) + cancel_delayed_work(&hdev->cmd_timer); + else + queue_delayed_work(hdev->workqueue, &hdev->cmd_timer, + HCI_CMD_TIMEOUT); + rcu_read_unlock(); + } else { + skb_queue_head(&hdev->cmd_q, skb); + queue_work(hdev->workqueue, &hdev->cmd_work); + } + } +} diff --git a/net/bluetooth/hci_debugfs.c b/net/bluetooth/hci_debugfs.c new file mode 100644 index 000000000..6124b3425 --- /dev/null +++ b/net/bluetooth/hci_debugfs.c @@ -0,0 +1,1379 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/ + +#include + +#include +#include + +#include "smp.h" +#include "hci_request.h" +#include "hci_debugfs.h" + +#define DEFINE_QUIRK_ATTRIBUTE(__name, __quirk) \ +static ssize_t __name ## _read(struct file *file, \ + char __user *user_buf, \ + size_t count, loff_t *ppos) \ +{ \ + struct hci_dev *hdev = file->private_data; \ + char buf[3]; \ + \ + buf[0] = test_bit(__quirk, &hdev->quirks) ? 'Y' : 'N'; \ + buf[1] = '\n'; \ + buf[2] = '\0'; \ + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); \ +} \ + \ +static ssize_t __name ## _write(struct file *file, \ + const char __user *user_buf, \ + size_t count, loff_t *ppos) \ +{ \ + struct hci_dev *hdev = file->private_data; \ + bool enable; \ + int err; \ + \ + if (test_bit(HCI_UP, &hdev->flags)) \ + return -EBUSY; \ + \ + err = kstrtobool_from_user(user_buf, count, &enable); \ + if (err) \ + return err; \ + \ + if (enable == test_bit(__quirk, &hdev->quirks)) \ + return -EALREADY; \ + \ + change_bit(__quirk, &hdev->quirks); \ + \ + return count; \ +} \ + \ +static const struct file_operations __name ## _fops = { \ + .open = simple_open, \ + .read = __name ## _read, \ + .write = __name ## _write, \ + .llseek = default_llseek, \ +} \ + +#define DEFINE_INFO_ATTRIBUTE(__name, __field) \ +static int __name ## _show(struct seq_file *f, void *ptr) \ +{ \ + struct hci_dev *hdev = f->private; \ + \ + hci_dev_lock(hdev); \ + seq_printf(f, "%s\n", hdev->__field ? : ""); \ + hci_dev_unlock(hdev); \ + \ + return 0; \ +} \ + \ +DEFINE_SHOW_ATTRIBUTE(__name) + +static int features_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + u8 p; + + hci_dev_lock(hdev); + for (p = 0; p < HCI_MAX_PAGES && p <= hdev->max_page; p++) + seq_printf(f, "%2u: %8ph\n", p, hdev->features[p]); + if (lmp_le_capable(hdev)) + seq_printf(f, "LE: %8ph\n", hdev->le_features); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(features); + +static int device_id_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + + hci_dev_lock(hdev); + seq_printf(f, "%4.4x:%4.4x:%4.4x:%4.4x\n", hdev->devid_source, + hdev->devid_vendor, hdev->devid_product, hdev->devid_version); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(device_id); + +static int device_list_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct hci_conn_params *p; + struct bdaddr_list *b; + + hci_dev_lock(hdev); + list_for_each_entry(b, &hdev->accept_list, list) + seq_printf(f, "%pMR (type %u)\n", &b->bdaddr, b->bdaddr_type); + list_for_each_entry(p, &hdev->le_conn_params, list) { + seq_printf(f, "%pMR (type %u) %u\n", &p->addr, p->addr_type, + p->auto_connect); + } + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(device_list); + +static int blacklist_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + struct bdaddr_list *b; + + hci_dev_lock(hdev); + list_for_each_entry(b, &hdev->reject_list, list) + seq_printf(f, "%pMR (type %u)\n", &b->bdaddr, b->bdaddr_type); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(blacklist); + +static int blocked_keys_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + struct blocked_key *key; + + rcu_read_lock(); + list_for_each_entry_rcu(key, &hdev->blocked_keys, list) + seq_printf(f, "%u %*phN\n", key->type, 16, key->val); + rcu_read_unlock(); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(blocked_keys); + +static int uuids_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; 
+ struct bt_uuid *uuid; + + hci_dev_lock(hdev); + list_for_each_entry(uuid, &hdev->uuids, list) { + u8 i, val[16]; + + /* The Bluetooth UUID values are stored in big endian, + * but with reversed byte order. So convert them into + * the right order for the %pUb modifier. + */ + for (i = 0; i < 16; i++) + val[i] = uuid->uuid[15 - i]; + + seq_printf(f, "%pUb\n", val); + } + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(uuids); + +static int remote_oob_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct oob_data *data; + + hci_dev_lock(hdev); + list_for_each_entry(data, &hdev->remote_oob_data, list) { + seq_printf(f, "%pMR (type %u) %u %*phN %*phN %*phN %*phN\n", + &data->bdaddr, data->bdaddr_type, data->present, + 16, data->hash192, 16, data->rand192, + 16, data->hash256, 16, data->rand256); + } + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(remote_oob); + +static int conn_info_min_age_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val == 0 || val > hdev->conn_info_max_age) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->conn_info_min_age = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int conn_info_min_age_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->conn_info_min_age; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(conn_info_min_age_fops, conn_info_min_age_get, + conn_info_min_age_set, "%llu\n"); + +static int conn_info_max_age_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val == 0 || val < hdev->conn_info_min_age) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->conn_info_max_age = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int conn_info_max_age_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->conn_info_max_age; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(conn_info_max_age_fops, conn_info_max_age_get, + conn_info_max_age_set, "%llu\n"); + +static ssize_t use_debug_keys_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_USE_DEBUG_KEYS) ? 'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static const struct file_operations use_debug_keys_fops = { + .open = simple_open, + .read = use_debug_keys_read, + .llseek = default_llseek, +}; + +static ssize_t sc_only_mode_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_SC_ONLY) ? 
'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static const struct file_operations sc_only_mode_fops = { + .open = simple_open, + .read = sc_only_mode_read, + .llseek = default_llseek, +}; + +DEFINE_INFO_ATTRIBUTE(hardware_info, hw_info); +DEFINE_INFO_ATTRIBUTE(firmware_info, fw_info); + +void hci_debugfs_create_common(struct hci_dev *hdev) +{ + debugfs_create_file("features", 0444, hdev->debugfs, hdev, + &features_fops); + debugfs_create_u16("manufacturer", 0444, hdev->debugfs, + &hdev->manufacturer); + debugfs_create_u8("hci_version", 0444, hdev->debugfs, &hdev->hci_ver); + debugfs_create_u16("hci_revision", 0444, hdev->debugfs, &hdev->hci_rev); + debugfs_create_u8("hardware_error", 0444, hdev->debugfs, + &hdev->hw_error_code); + debugfs_create_file("device_id", 0444, hdev->debugfs, hdev, + &device_id_fops); + + debugfs_create_file("device_list", 0444, hdev->debugfs, hdev, + &device_list_fops); + debugfs_create_file("blacklist", 0444, hdev->debugfs, hdev, + &blacklist_fops); + debugfs_create_file("blocked_keys", 0444, hdev->debugfs, hdev, + &blocked_keys_fops); + debugfs_create_file("uuids", 0444, hdev->debugfs, hdev, &uuids_fops); + debugfs_create_file("remote_oob", 0400, hdev->debugfs, hdev, + &remote_oob_fops); + + debugfs_create_file("conn_info_min_age", 0644, hdev->debugfs, hdev, + &conn_info_min_age_fops); + debugfs_create_file("conn_info_max_age", 0644, hdev->debugfs, hdev, + &conn_info_max_age_fops); + + if (lmp_ssp_capable(hdev) || lmp_le_capable(hdev)) + debugfs_create_file("use_debug_keys", 0444, hdev->debugfs, + hdev, &use_debug_keys_fops); + + if (lmp_sc_capable(hdev) || lmp_le_capable(hdev)) + debugfs_create_file("sc_only_mode", 0444, hdev->debugfs, + hdev, &sc_only_mode_fops); + + if (hdev->hw_info) + debugfs_create_file("hardware_info", 0444, hdev->debugfs, + hdev, &hardware_info_fops); + + if (hdev->fw_info) + debugfs_create_file("firmware_info", 0444, hdev->debugfs, + hdev, &firmware_info_fops); +} + +static int inquiry_cache_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + struct discovery_state *cache = &hdev->discovery; + struct inquiry_entry *e; + + hci_dev_lock(hdev); + + list_for_each_entry(e, &cache->all, all) { + struct inquiry_data *data = &e->data; + seq_printf(f, "%pMR %d %d %d 0x%.2x%.2x%.2x 0x%.4x %d %d %u\n", + &data->bdaddr, + data->pscan_rep_mode, data->pscan_period_mode, + data->pscan_mode, data->dev_class[2], + data->dev_class[1], data->dev_class[0], + __le16_to_cpu(data->clock_offset), + data->rssi, data->ssp_mode, e->timestamp); + } + + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(inquiry_cache); + +static int link_keys_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct link_key *key; + + rcu_read_lock(); + list_for_each_entry_rcu(key, &hdev->link_keys, list) + seq_printf(f, "%pMR %u %*phN %u\n", &key->bdaddr, key->type, + HCI_LINK_KEY_SIZE, key->val, key->pin_len); + rcu_read_unlock(); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(link_keys); + +static int dev_class_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + + hci_dev_lock(hdev); + seq_printf(f, "0x%.2x%.2x%.2x\n", hdev->dev_class[2], + hdev->dev_class[1], hdev->dev_class[0]); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(dev_class); + +static int voice_setting_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->voice_setting; + hci_dev_unlock(hdev); + + 
return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(voice_setting_fops, voice_setting_get, + NULL, "0x%4.4llx\n"); + +static ssize_t ssp_debug_mode_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hdev->ssp_debug_mode ? 'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static const struct file_operations ssp_debug_mode_fops = { + .open = simple_open, + .read = ssp_debug_mode_read, + .llseek = default_llseek, +}; + +static int auto_accept_delay_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + hdev->auto_accept_delay = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int min_encrypt_key_size_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 1 || val > 16) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->min_enc_key_size = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int min_encrypt_key_size_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->min_enc_key_size; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(min_encrypt_key_size_fops, + min_encrypt_key_size_get, + min_encrypt_key_size_set, "%llu\n"); + +static int auto_accept_delay_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->auto_accept_delay; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(auto_accept_delay_fops, auto_accept_delay_get, + auto_accept_delay_set, "%llu\n"); + +static ssize_t force_bredr_smp_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_FORCE_BREDR_SMP) ? 
'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static ssize_t force_bredr_smp_write(struct file *file, + const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + bool enable; + int err; + + err = kstrtobool_from_user(user_buf, count, &enable); + if (err) + return err; + + err = smp_force_bredr(hdev, enable); + if (err) + return err; + + return count; +} + +static const struct file_operations force_bredr_smp_fops = { + .open = simple_open, + .read = force_bredr_smp_read, + .write = force_bredr_smp_write, + .llseek = default_llseek, +}; + +static int idle_timeout_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val != 0 && (val < 500 || val > 3600000)) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->idle_timeout = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int idle_timeout_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->idle_timeout; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(idle_timeout_fops, idle_timeout_get, + idle_timeout_set, "%llu\n"); + +static int sniff_min_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val == 0 || val % 2 || val > hdev->sniff_max_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->sniff_min_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int sniff_min_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->sniff_min_interval; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(sniff_min_interval_fops, sniff_min_interval_get, + sniff_min_interval_set, "%llu\n"); + +static int sniff_max_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val == 0 || val % 2 || val < hdev->sniff_min_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->sniff_max_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int sniff_max_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->sniff_max_interval; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(sniff_max_interval_fops, sniff_max_interval_get, + sniff_max_interval_set, "%llu\n"); + +void hci_debugfs_create_bredr(struct hci_dev *hdev) +{ + debugfs_create_file("inquiry_cache", 0444, hdev->debugfs, hdev, + &inquiry_cache_fops); + debugfs_create_file("link_keys", 0400, hdev->debugfs, hdev, + &link_keys_fops); + debugfs_create_file("dev_class", 0444, hdev->debugfs, hdev, + &dev_class_fops); + debugfs_create_file("voice_setting", 0444, hdev->debugfs, hdev, + &voice_setting_fops); + + /* If the controller does not support BR/EDR Secure Connections + * feature, then the BR/EDR SMP channel shall not be present. + * + * To test this with Bluetooth 4.0 controllers, create a debugfs + * switch that allows forcing BR/EDR SMP support and accepting + * cross-transport pairing on non-AES encrypted connections. 
+ */ + if (!lmp_sc_capable(hdev)) + debugfs_create_file("force_bredr_smp", 0644, hdev->debugfs, + hdev, &force_bredr_smp_fops); + + if (lmp_ssp_capable(hdev)) { + debugfs_create_file("ssp_debug_mode", 0444, hdev->debugfs, + hdev, &ssp_debug_mode_fops); + debugfs_create_file("min_encrypt_key_size", 0644, hdev->debugfs, + hdev, &min_encrypt_key_size_fops); + debugfs_create_file("auto_accept_delay", 0644, hdev->debugfs, + hdev, &auto_accept_delay_fops); + } + + if (lmp_sniff_capable(hdev)) { + debugfs_create_file("idle_timeout", 0644, hdev->debugfs, + hdev, &idle_timeout_fops); + debugfs_create_file("sniff_min_interval", 0644, hdev->debugfs, + hdev, &sniff_min_interval_fops); + debugfs_create_file("sniff_max_interval", 0644, hdev->debugfs, + hdev, &sniff_max_interval_fops); + } +} + +static int identity_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + bdaddr_t addr; + u8 addr_type; + + hci_dev_lock(hdev); + + hci_copy_identity_address(hdev, &addr, &addr_type); + + seq_printf(f, "%pMR (type %u) %*phN %pMR\n", &addr, addr_type, + 16, hdev->irk, &hdev->rpa); + + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(identity); + +static int rpa_timeout_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + /* Require the RPA timeout to be at least 30 seconds and at most + * 24 hours. + */ + if (val < 30 || val > (60 * 60 * 24)) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->rpa_timeout = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int rpa_timeout_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->rpa_timeout; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(rpa_timeout_fops, rpa_timeout_get, + rpa_timeout_set, "%llu\n"); + +static int random_address_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + + hci_dev_lock(hdev); + seq_printf(f, "%pMR\n", &hdev->random_addr); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(random_address); + +static int static_address_show(struct seq_file *f, void *p) +{ + struct hci_dev *hdev = f->private; + + hci_dev_lock(hdev); + seq_printf(f, "%pMR\n", &hdev->static_addr); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(static_address); + +static ssize_t force_static_address_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) ? 
'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static ssize_t force_static_address_write(struct file *file, + const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + bool enable; + int err; + + if (test_bit(HCI_UP, &hdev->flags)) + return -EBUSY; + + err = kstrtobool_from_user(user_buf, count, &enable); + if (err) + return err; + + if (enable == hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR)) + return -EALREADY; + + hci_dev_change_flag(hdev, HCI_FORCE_STATIC_ADDR); + + return count; +} + +static const struct file_operations force_static_address_fops = { + .open = simple_open, + .read = force_static_address_read, + .write = force_static_address_write, + .llseek = default_llseek, +}; + +static int white_list_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct bdaddr_list *b; + + hci_dev_lock(hdev); + list_for_each_entry(b, &hdev->le_accept_list, list) + seq_printf(f, "%pMR (type %u)\n", &b->bdaddr, b->bdaddr_type); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(white_list); + +static int resolv_list_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct bdaddr_list *b; + + hci_dev_lock(hdev); + list_for_each_entry(b, &hdev->le_resolv_list, list) + seq_printf(f, "%pMR (type %u)\n", &b->bdaddr, b->bdaddr_type); + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(resolv_list); + +static int identity_resolving_keys_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct smp_irk *irk; + + rcu_read_lock(); + list_for_each_entry_rcu(irk, &hdev->identity_resolving_keys, list) { + seq_printf(f, "%pMR (type %u) %*phN %pMR\n", + &irk->bdaddr, irk->addr_type, + 16, irk->val, &irk->rpa); + } + rcu_read_unlock(); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(identity_resolving_keys); + +static int long_term_keys_show(struct seq_file *f, void *ptr) +{ + struct hci_dev *hdev = f->private; + struct smp_ltk *ltk; + + rcu_read_lock(); + list_for_each_entry_rcu(ltk, &hdev->long_term_keys, list) + seq_printf(f, "%pMR (type %u) %u 0x%02x %u %.4x %.16llx %*phN\n", + <k->bdaddr, ltk->bdaddr_type, ltk->authenticated, + ltk->type, ltk->enc_size, __le16_to_cpu(ltk->ediv), + __le64_to_cpu(ltk->rand), 16, ltk->val); + rcu_read_unlock(); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(long_term_keys); + +static int conn_min_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x0006 || val > 0x0c80 || val > hdev->le_conn_max_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_conn_min_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int conn_min_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_conn_min_interval; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(conn_min_interval_fops, conn_min_interval_get, + conn_min_interval_set, "%llu\n"); + +static int conn_max_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x0006 || val > 0x0c80 || val < hdev->le_conn_min_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_conn_max_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int conn_max_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_conn_max_interval; + hci_dev_unlock(hdev); + + return 0; +} + 
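Every tunable in this file follows the same debugfs pattern: a get/set pair that validates the requested value and updates the corresponding hci_dev field under hci_dev_lock(), a DEFINE_DEBUGFS_ATTRIBUTE() that generates the file_operations with a decimal "%llu\n" format, and a debugfs_create_file() call in one of the hci_debugfs_create_*() helpers. The minimal sketch below (not taken from the patch) shows that shape with invented names: struct example_dev, its example_timeout field, the 100..60000 range and example_debugfs_create() are hypothetical and used only to illustrate the pattern.

#include <linux/debugfs.h>
#include <linux/mutex.h>
#include <linux/types.h>

struct example_dev {
	struct mutex lock;
	u32 example_timeout;	/* hypothetical tunable, in msec */
	struct dentry *debugfs;	/* parent debugfs directory */
};

static int example_timeout_set(void *data, u64 val)
{
	struct example_dev *dev = data;

	/* Reject out-of-range values, as the conn_*_interval and
	 * *_key_size setters above do.
	 */
	if (val < 100 || val > 60000)
		return -EINVAL;

	mutex_lock(&dev->lock);
	dev->example_timeout = val;
	mutex_unlock(&dev->lock);

	return 0;
}

static int example_timeout_get(void *data, u64 *val)
{
	struct example_dev *dev = data;

	mutex_lock(&dev->lock);
	*val = dev->example_timeout;
	mutex_unlock(&dev->lock);

	return 0;
}

DEFINE_DEBUGFS_ATTRIBUTE(example_timeout_fops, example_timeout_get,
			 example_timeout_set, "%llu\n");

static void example_debugfs_create(struct example_dev *dev)
{
	/* 0644: owner read/write, world readable, matching the writable
	 * tunables registered in this file.
	 */
	debugfs_create_file("example_timeout", 0644, dev->debugfs, dev,
			    &example_timeout_fops);
}

Once registered, the file accepts and reports the value as plain decimal text (the "%llu\n" format), so it can be read and written with ordinary file I/O under whatever debugfs directory dev->debugfs points at.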
+DEFINE_DEBUGFS_ATTRIBUTE(conn_max_interval_fops, conn_max_interval_get, + conn_max_interval_set, "%llu\n"); + +static int conn_latency_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val > 0x01f3) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_conn_latency = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int conn_latency_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_conn_latency; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(conn_latency_fops, conn_latency_get, + conn_latency_set, "%llu\n"); + +static int supervision_timeout_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x000a || val > 0x0c80) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_supv_timeout = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int supervision_timeout_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_supv_timeout; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(supervision_timeout_fops, supervision_timeout_get, + supervision_timeout_set, "%llu\n"); + +static int adv_channel_map_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x01 || val > 0x07) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_adv_channel_map = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int adv_channel_map_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_adv_channel_map; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(adv_channel_map_fops, adv_channel_map_get, + adv_channel_map_set, "%llu\n"); + +static int adv_min_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x0020 || val > 0x4000 || val > hdev->le_adv_max_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_adv_min_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int adv_min_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_adv_min_interval; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(adv_min_interval_fops, adv_min_interval_get, + adv_min_interval_set, "%llu\n"); + +static int adv_max_interval_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x0020 || val > 0x4000 || val < hdev->le_adv_min_interval) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->le_adv_max_interval = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int adv_max_interval_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_adv_max_interval; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(adv_max_interval_fops, adv_max_interval_get, + adv_max_interval_set, "%llu\n"); + +static int min_key_size_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + if (val > hdev->le_max_key_size || val < SMP_MIN_ENC_KEY_SIZE) { + hci_dev_unlock(hdev); + return -EINVAL; + } + + hdev->le_min_key_size = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int min_key_size_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_min_key_size; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(min_key_size_fops, min_key_size_get, + min_key_size_set, "%llu\n"); + +static int max_key_size_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + if 
(val > SMP_MAX_ENC_KEY_SIZE || val < hdev->le_min_key_size) { + hci_dev_unlock(hdev); + return -EINVAL; + } + + hdev->le_max_key_size = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int max_key_size_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->le_max_key_size; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(max_key_size_fops, max_key_size_get, + max_key_size_set, "%llu\n"); + +static int auth_payload_timeout_set(void *data, u64 val) +{ + struct hci_dev *hdev = data; + + if (val < 0x0001 || val > 0xffff) + return -EINVAL; + + hci_dev_lock(hdev); + hdev->auth_payload_timeout = val; + hci_dev_unlock(hdev); + + return 0; +} + +static int auth_payload_timeout_get(void *data, u64 *val) +{ + struct hci_dev *hdev = data; + + hci_dev_lock(hdev); + *val = hdev->auth_payload_timeout; + hci_dev_unlock(hdev); + + return 0; +} + +DEFINE_DEBUGFS_ATTRIBUTE(auth_payload_timeout_fops, + auth_payload_timeout_get, + auth_payload_timeout_set, "%llu\n"); + +static ssize_t force_no_mitm_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_FORCE_NO_MITM) ? 'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static ssize_t force_no_mitm_write(struct file *file, + const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[32]; + size_t buf_size = min(count, (sizeof(buf) - 1)); + bool enable; + + if (copy_from_user(buf, user_buf, buf_size)) + return -EFAULT; + + buf[buf_size] = '\0'; + if (strtobool(buf, &enable)) + return -EINVAL; + + if (enable == hci_dev_test_flag(hdev, HCI_FORCE_NO_MITM)) + return -EALREADY; + + hci_dev_change_flag(hdev, HCI_FORCE_NO_MITM); + + return count; +} + +static const struct file_operations force_no_mitm_fops = { + .open = simple_open, + .read = force_no_mitm_read, + .write = force_no_mitm_write, + .llseek = default_llseek, +}; + +DEFINE_QUIRK_ATTRIBUTE(quirk_strict_duplicate_filter, + HCI_QUIRK_STRICT_DUPLICATE_FILTER); +DEFINE_QUIRK_ATTRIBUTE(quirk_simultaneous_discovery, + HCI_QUIRK_SIMULTANEOUS_DISCOVERY); + +void hci_debugfs_create_le(struct hci_dev *hdev) +{ + debugfs_create_file("identity", 0400, hdev->debugfs, hdev, + &identity_fops); + debugfs_create_file("rpa_timeout", 0644, hdev->debugfs, hdev, + &rpa_timeout_fops); + debugfs_create_file("random_address", 0444, hdev->debugfs, hdev, + &random_address_fops); + debugfs_create_file("static_address", 0444, hdev->debugfs, hdev, + &static_address_fops); + + /* For controllers with a public address, provide a debug + * option to force the usage of the configured static + * address. By default the public address is used. 
+ */ + if (bacmp(&hdev->bdaddr, BDADDR_ANY)) + debugfs_create_file("force_static_address", 0644, + hdev->debugfs, hdev, + &force_static_address_fops); + + debugfs_create_u8("white_list_size", 0444, hdev->debugfs, + &hdev->le_accept_list_size); + debugfs_create_file("white_list", 0444, hdev->debugfs, hdev, + &white_list_fops); + debugfs_create_u8("resolv_list_size", 0444, hdev->debugfs, + &hdev->le_resolv_list_size); + debugfs_create_file("resolv_list", 0444, hdev->debugfs, hdev, + &resolv_list_fops); + debugfs_create_file("identity_resolving_keys", 0400, hdev->debugfs, + hdev, &identity_resolving_keys_fops); + debugfs_create_file("long_term_keys", 0400, hdev->debugfs, hdev, + &long_term_keys_fops); + debugfs_create_file("conn_min_interval", 0644, hdev->debugfs, hdev, + &conn_min_interval_fops); + debugfs_create_file("conn_max_interval", 0644, hdev->debugfs, hdev, + &conn_max_interval_fops); + debugfs_create_file("conn_latency", 0644, hdev->debugfs, hdev, + &conn_latency_fops); + debugfs_create_file("supervision_timeout", 0644, hdev->debugfs, hdev, + &supervision_timeout_fops); + debugfs_create_file("adv_channel_map", 0644, hdev->debugfs, hdev, + &adv_channel_map_fops); + debugfs_create_file("adv_min_interval", 0644, hdev->debugfs, hdev, + &adv_min_interval_fops); + debugfs_create_file("adv_max_interval", 0644, hdev->debugfs, hdev, + &adv_max_interval_fops); + debugfs_create_u16("discov_interleaved_timeout", 0644, hdev->debugfs, + &hdev->discov_interleaved_timeout); + debugfs_create_file("min_key_size", 0644, hdev->debugfs, hdev, + &min_key_size_fops); + debugfs_create_file("max_key_size", 0644, hdev->debugfs, hdev, + &max_key_size_fops); + debugfs_create_file("auth_payload_timeout", 0644, hdev->debugfs, hdev, + &auth_payload_timeout_fops); + debugfs_create_file("force_no_mitm", 0644, hdev->debugfs, hdev, + &force_no_mitm_fops); + + debugfs_create_file("quirk_strict_duplicate_filter", 0644, + hdev->debugfs, hdev, + &quirk_strict_duplicate_filter_fops); + debugfs_create_file("quirk_simultaneous_discovery", 0644, + hdev->debugfs, hdev, + &quirk_simultaneous_discovery_fops); +} + +void hci_debugfs_create_conn(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + char name[6]; + + if (IS_ERR_OR_NULL(hdev->debugfs) || conn->debugfs) + return; + + snprintf(name, sizeof(name), "%u", conn->handle); + conn->debugfs = debugfs_create_dir(name, hdev->debugfs); +} + +static ssize_t dut_mode_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_DUT_MODE) ? 
'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static ssize_t dut_mode_write(struct file *file, const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + struct sk_buff *skb; + bool enable; + int err; + + if (!test_bit(HCI_UP, &hdev->flags)) + return -ENETDOWN; + + err = kstrtobool_from_user(user_buf, count, &enable); + if (err) + return err; + + if (enable == hci_dev_test_flag(hdev, HCI_DUT_MODE)) + return -EALREADY; + + hci_req_sync_lock(hdev); + if (enable) + skb = __hci_cmd_sync(hdev, HCI_OP_ENABLE_DUT_MODE, 0, NULL, + HCI_CMD_TIMEOUT); + else + skb = __hci_cmd_sync(hdev, HCI_OP_RESET, 0, NULL, + HCI_CMD_TIMEOUT); + hci_req_sync_unlock(hdev); + + if (IS_ERR(skb)) + return PTR_ERR(skb); + + kfree_skb(skb); + + hci_dev_change_flag(hdev, HCI_DUT_MODE); + + return count; +} + +static const struct file_operations dut_mode_fops = { + .open = simple_open, + .read = dut_mode_read, + .write = dut_mode_write, + .llseek = default_llseek, +}; + +static ssize_t vendor_diag_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + char buf[3]; + + buf[0] = hci_dev_test_flag(hdev, HCI_VENDOR_DIAG) ? 'Y' : 'N'; + buf[1] = '\n'; + buf[2] = '\0'; + return simple_read_from_buffer(user_buf, count, ppos, buf, 2); +} + +static ssize_t vendor_diag_write(struct file *file, const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct hci_dev *hdev = file->private_data; + bool enable; + int err; + + err = kstrtobool_from_user(user_buf, count, &enable); + if (err) + return err; + + /* When the diagnostic flags are not persistent and the transport + * is not active or in user channel operation, then there is no need + * for the vendor callback. Instead just store the desired value and + * the setting will be programmed when the controller gets powered on. + */ + if (test_bit(HCI_QUIRK_NON_PERSISTENT_DIAG, &hdev->quirks) && + (!test_bit(HCI_RUNNING, &hdev->flags) || + hci_dev_test_flag(hdev, HCI_USER_CHANNEL))) + goto done; + + hci_req_sync_lock(hdev); + err = hdev->set_diag(hdev, enable); + hci_req_sync_unlock(hdev); + + if (err < 0) + return err; + +done: + if (enable) + hci_dev_set_flag(hdev, HCI_VENDOR_DIAG); + else + hci_dev_clear_flag(hdev, HCI_VENDOR_DIAG); + + return count; +} + +static const struct file_operations vendor_diag_fops = { + .open = simple_open, + .read = vendor_diag_read, + .write = vendor_diag_write, + .llseek = default_llseek, +}; + +void hci_debugfs_create_basic(struct hci_dev *hdev) +{ + debugfs_create_file("dut_mode", 0644, hdev->debugfs, hdev, + &dut_mode_fops); + + if (hdev->set_diag) + debugfs_create_file("vendor_diag", 0644, hdev->debugfs, hdev, + &vendor_diag_fops); +} diff --git a/net/bluetooth/hci_debugfs.h b/net/bluetooth/hci_debugfs.h new file mode 100644 index 000000000..9a8a7c93b --- /dev/null +++ b/net/bluetooth/hci_debugfs.h @@ -0,0 +1,53 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. 
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#if IS_ENABLED(CONFIG_BT_DEBUGFS) + +void hci_debugfs_create_common(struct hci_dev *hdev); +void hci_debugfs_create_bredr(struct hci_dev *hdev); +void hci_debugfs_create_le(struct hci_dev *hdev); +void hci_debugfs_create_conn(struct hci_conn *conn); +void hci_debugfs_create_basic(struct hci_dev *hdev); + +#else + +static inline void hci_debugfs_create_common(struct hci_dev *hdev) +{ +} + +static inline void hci_debugfs_create_bredr(struct hci_dev *hdev) +{ +} + +static inline void hci_debugfs_create_le(struct hci_dev *hdev) +{ +} + +static inline void hci_debugfs_create_conn(struct hci_conn *conn) +{ +} + +static inline void hci_debugfs_create_basic(struct hci_dev *hdev) +{ +} + +#endif diff --git a/net/bluetooth/hci_event.c b/net/bluetooth/hci_event.c new file mode 100644 index 000000000..56ecc5f97 --- /dev/null +++ b/net/bluetooth/hci_event.c @@ -0,0 +1,7576 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (c) 2000-2001, 2010, Code Aurora Forum. All rights reserved. + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth HCI event handling. 
*/ + +#include +#include +#include + +#include +#include +#include + +#include "hci_request.h" +#include "hci_debugfs.h" +#include "hci_codec.h" +#include "a2mp.h" +#include "amp.h" +#include "smp.h" +#include "msft.h" +#include "eir.h" + +#define ZERO_KEY "\x00\x00\x00\x00\x00\x00\x00\x00" \ + "\x00\x00\x00\x00\x00\x00\x00\x00" + +#define secs_to_jiffies(_secs) msecs_to_jiffies((_secs) * 1000) + +/* Handle HCI Event packets */ + +static void *hci_ev_skb_pull(struct hci_dev *hdev, struct sk_buff *skb, + u8 ev, size_t len) +{ + void *data; + + data = skb_pull_data(skb, len); + if (!data) + bt_dev_err(hdev, "Malformed Event: 0x%2.2x", ev); + + return data; +} + +static void *hci_cc_skb_pull(struct hci_dev *hdev, struct sk_buff *skb, + u16 op, size_t len) +{ + void *data; + + data = skb_pull_data(skb, len); + if (!data) + bt_dev_err(hdev, "Malformed Command Complete: 0x%4.4x", op); + + return data; +} + +static void *hci_le_ev_skb_pull(struct hci_dev *hdev, struct sk_buff *skb, + u8 ev, size_t len) +{ + void *data; + + data = skb_pull_data(skb, len); + if (!data) + bt_dev_err(hdev, "Malformed LE Event: 0x%2.2x", ev); + + return data; +} + +static u8 hci_cc_inquiry_cancel(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + /* It is possible that we receive Inquiry Complete event right + * before we receive Inquiry Cancel Command Complete event, in + * which case the latter event should have status of Command + * Disallowed (0x0c). This should not be treated as error, since + * we actually achieve what Inquiry Cancel wants to achieve, + * which is to end the last Inquiry session. + */ + if (rp->status == 0x0c && !test_bit(HCI_INQUIRY, &hdev->flags)) { + bt_dev_warn(hdev, "Ignoring error of Inquiry Cancel command"); + rp->status = 0x00; + } + + if (rp->status) + return rp->status; + + clear_bit(HCI_INQUIRY, &hdev->flags); + smp_mb__after_atomic(); /* wake_up_bit advises about this barrier */ + wake_up_bit(&hdev->flags, HCI_INQUIRY); + + hci_dev_lock(hdev); + /* Set discovery state to stopped if we're not doing LE active + * scanning. 
+ */ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN) || + hdev->le_scan_type != LE_SCAN_ACTIVE) + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + hci_dev_unlock(hdev); + + hci_conn_check_pending(hdev); + + return rp->status; +} + +static u8 hci_cc_periodic_inq(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_set_flag(hdev, HCI_PERIODIC_INQ); + + return rp->status; +} + +static u8 hci_cc_exit_periodic_inq(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_clear_flag(hdev, HCI_PERIODIC_INQ); + + hci_conn_check_pending(hdev); + + return rp->status; +} + +static u8 hci_cc_remote_name_req_cancel(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + return rp->status; +} + +static u8 hci_cc_role_discovery(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_role_discovery *rp = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->role = rp->role; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_link_policy(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_link_policy *rp = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->link_policy = __le16_to_cpu(rp->policy); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_write_link_policy(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_write_link_policy *rp = data; + struct hci_conn *conn; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_LINK_POLICY); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->link_policy = get_unaligned_le16(sent + 2); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_def_link_policy(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_def_link_policy *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->link_policy = __le16_to_cpu(rp->policy); + + return rp->status; +} + +static u8 hci_cc_write_def_link_policy(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_DEF_LINK_POLICY); + if (!sent) + return rp->status; + + hdev->link_policy = get_unaligned_le16(sent); + + return rp->status; +} + +static u8 hci_cc_reset(struct hci_dev *hdev, void *data, struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + clear_bit(HCI_RESET, &hdev->flags); + + if (rp->status) + return 
rp->status; + + /* Reset all non-persistent flags */ + hci_dev_clear_volatile_flags(hdev); + + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + + hdev->inq_tx_power = HCI_TX_POWER_INVALID; + hdev->adv_tx_power = HCI_TX_POWER_INVALID; + + memset(hdev->adv_data, 0, sizeof(hdev->adv_data)); + hdev->adv_data_len = 0; + + memset(hdev->scan_rsp_data, 0, sizeof(hdev->scan_rsp_data)); + hdev->scan_rsp_data_len = 0; + + hdev->le_scan_type = LE_SCAN_PASSIVE; + + hdev->ssp_debug_mode = 0; + + hci_bdaddr_list_clear(&hdev->le_accept_list); + hci_bdaddr_list_clear(&hdev->le_resolv_list); + + return rp->status; +} + +static u8 hci_cc_read_stored_link_key(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_stored_link_key *rp = data; + struct hci_cp_read_stored_link_key *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_READ_STORED_LINK_KEY); + if (!sent) + return rp->status; + + if (!rp->status && sent->read_all == 0x01) { + hdev->stored_max_keys = le16_to_cpu(rp->max_keys); + hdev->stored_num_keys = le16_to_cpu(rp->num_keys); + } + + return rp->status; +} + +static u8 hci_cc_delete_stored_link_key(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_delete_stored_link_key *rp = data; + u16 num_keys; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + num_keys = le16_to_cpu(rp->num_keys); + + if (num_keys <= hdev->stored_num_keys) + hdev->stored_num_keys -= num_keys; + else + hdev->stored_num_keys = 0; + + return rp->status; +} + +static u8 hci_cc_write_local_name(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_LOCAL_NAME); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_set_local_name_complete(hdev, sent, rp->status); + else if (!rp->status) + memcpy(hdev->dev_name, sent, HCI_MAX_NAME_LENGTH); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_local_name(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_name *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG)) + memcpy(hdev->dev_name, rp->name, HCI_MAX_NAME_LENGTH); + + return rp->status; +} + +static u8 hci_cc_write_auth_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_AUTH_ENABLE); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (!rp->status) { + __u8 param = *((__u8 *) sent); + + if (param == AUTH_ENABLED) + set_bit(HCI_AUTH, &hdev->flags); + else + clear_bit(HCI_AUTH, &hdev->flags); + } + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_auth_enable_complete(hdev, rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_write_encrypt_mode(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u8 param; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_ENCRYPT_MODE); + if (!sent) + return rp->status; + + param = *((__u8 *) sent); + + if 
(param) + set_bit(HCI_ENCRYPT, &hdev->flags); + else + clear_bit(HCI_ENCRYPT, &hdev->flags); + + return rp->status; +} + +static u8 hci_cc_write_scan_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u8 param; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_SCAN_ENABLE); + if (!sent) + return rp->status; + + param = *((__u8 *) sent); + + hci_dev_lock(hdev); + + if (rp->status) { + hdev->discov_timeout = 0; + goto done; + } + + if (param & SCAN_INQUIRY) + set_bit(HCI_ISCAN, &hdev->flags); + else + clear_bit(HCI_ISCAN, &hdev->flags); + + if (param & SCAN_PAGE) + set_bit(HCI_PSCAN, &hdev->flags); + else + clear_bit(HCI_PSCAN, &hdev->flags); + +done: + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_set_event_filter(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_set_event_filter *cp; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_SET_EVENT_FLT); + if (!sent) + return rp->status; + + cp = (struct hci_cp_set_event_filter *)sent; + + if (cp->flt_type == HCI_FLT_CLEAR_ALL) + hci_dev_clear_flag(hdev, HCI_EVENT_FILTER_CONFIGURED); + else + hci_dev_set_flag(hdev, HCI_EVENT_FILTER_CONFIGURED); + + return rp->status; +} + +static u8 hci_cc_read_class_of_dev(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_class_of_dev *rp = data; + + if (WARN_ON(!hdev)) + return HCI_ERROR_UNSPECIFIED; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + memcpy(hdev->dev_class, rp->dev_class, 3); + + bt_dev_dbg(hdev, "class 0x%.2x%.2x%.2x", hdev->dev_class[2], + hdev->dev_class[1], hdev->dev_class[0]); + + return rp->status; +} + +static u8 hci_cc_write_class_of_dev(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_CLASS_OF_DEV); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (!rp->status) + memcpy(hdev->dev_class, sent, 3); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_set_class_of_dev_complete(hdev, sent, rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_voice_setting(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_voice_setting *rp = data; + __u16 setting; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + setting = __le16_to_cpu(rp->voice_setting); + + if (hdev->voice_setting == setting) + return rp->status; + + hdev->voice_setting = setting; + + bt_dev_dbg(hdev, "voice setting 0x%4.4x", setting); + + if (hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_VOICE_SETTING); + + return rp->status; +} + +static u8 hci_cc_write_voice_setting(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u16 setting; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_VOICE_SETTING); + if (!sent) + return rp->status; + + setting = get_unaligned_le16(sent); + + if (hdev->voice_setting == setting) + return rp->status; + + hdev->voice_setting = setting; + + bt_dev_dbg(hdev, "voice setting 0x%4.4x", setting); + + if 
(hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_VOICE_SETTING); + + return rp->status; +} + +static u8 hci_cc_read_num_supported_iac(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_num_supported_iac *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->num_iac = rp->num_iac; + + bt_dev_dbg(hdev, "num iac %d", hdev->num_iac); + + return rp->status; +} + +static u8 hci_cc_write_ssp_mode(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_write_ssp_mode *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_SSP_MODE); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (!rp->status) { + if (sent->mode) + hdev->features[1][0] |= LMP_HOST_SSP; + else + hdev->features[1][0] &= ~LMP_HOST_SSP; + } + + if (!rp->status) { + if (sent->mode) + hci_dev_set_flag(hdev, HCI_SSP_ENABLED); + else + hci_dev_clear_flag(hdev, HCI_SSP_ENABLED); + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_write_sc_support(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_write_sc_support *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_SC_SUPPORT); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (!rp->status) { + if (sent->support) + hdev->features[1][0] |= LMP_HOST_SC; + else + hdev->features[1][0] &= ~LMP_HOST_SC; + } + + if (!hci_dev_test_flag(hdev, HCI_MGMT) && !rp->status) { + if (sent->support) + hci_dev_set_flag(hdev, HCI_SC_ENABLED); + else + hci_dev_clear_flag(hdev, HCI_SC_ENABLED); + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_local_version(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_version *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG)) { + hdev->hci_ver = rp->hci_ver; + hdev->hci_rev = __le16_to_cpu(rp->hci_rev); + hdev->lmp_ver = rp->lmp_ver; + hdev->manufacturer = __le16_to_cpu(rp->manufacturer); + hdev->lmp_subver = __le16_to_cpu(rp->lmp_subver); + } + + return rp->status; +} + +static u8 hci_cc_read_enc_key_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_enc_key_size *rp = data; + struct hci_conn *conn; + u16 handle; + u8 status = rp->status; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + handle = le16_to_cpu(rp->handle); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, handle); + if (!conn) { + status = 0xFF; + goto done; + } + + /* While unexpected, the read_enc_key_size command may fail. The most + * secure approach is to then assume the key size is 0 to force a + * disconnection. + */ + if (status) { + bt_dev_err(hdev, "failed to read key size for handle %u", + handle); + conn->enc_key_size = 0; + } else { + conn->enc_key_size = rp->key_size; + status = 0; + + if (conn->enc_key_size < hdev->min_enc_key_size) { + /* As slave role, the conn->state has been set to + * BT_CONNECTED and l2cap conn req might not be received + * yet, at this moment the l2cap layer almost does + * nothing with the non-zero status. + * So we also clear encrypt related bits, and then the + * handler of l2cap conn req will get the right secure + * state at a later time. 
+ */ + status = HCI_ERROR_AUTH_FAILURE; + clear_bit(HCI_CONN_ENCRYPT, &conn->flags); + clear_bit(HCI_CONN_AES_CCM, &conn->flags); + } + } + + hci_encrypt_cfm(conn, status); + +done: + hci_dev_unlock(hdev); + + return status; +} + +static u8 hci_cc_read_local_commands(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_commands *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG)) + memcpy(hdev->commands, rp->commands, sizeof(hdev->commands)); + + return rp->status; +} + +static u8 hci_cc_read_auth_payload_timeout(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_auth_payload_to *rp = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->auth_payload_timeout = __le16_to_cpu(rp->timeout); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_write_auth_payload_timeout(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_write_auth_payload_to *rp = data; + struct hci_conn *conn; + void *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_AUTH_PAYLOAD_TO); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->auth_payload_timeout = get_unaligned_le16(sent + 2); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_local_features(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_features *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + memcpy(hdev->features, rp->features, 8); + + /* Adjust default settings according to features + * supported by device. 
*/ + + if (hdev->features[0][0] & LMP_3SLOT) + hdev->pkt_type |= (HCI_DM3 | HCI_DH3); + + if (hdev->features[0][0] & LMP_5SLOT) + hdev->pkt_type |= (HCI_DM5 | HCI_DH5); + + if (hdev->features[0][1] & LMP_HV2) { + hdev->pkt_type |= (HCI_HV2); + hdev->esco_type |= (ESCO_HV2); + } + + if (hdev->features[0][1] & LMP_HV3) { + hdev->pkt_type |= (HCI_HV3); + hdev->esco_type |= (ESCO_HV3); + } + + if (lmp_esco_capable(hdev)) + hdev->esco_type |= (ESCO_EV3); + + if (hdev->features[0][4] & LMP_EV4) + hdev->esco_type |= (ESCO_EV4); + + if (hdev->features[0][4] & LMP_EV5) + hdev->esco_type |= (ESCO_EV5); + + if (hdev->features[0][5] & LMP_EDR_ESCO_2M) + hdev->esco_type |= (ESCO_2EV3); + + if (hdev->features[0][5] & LMP_EDR_ESCO_3M) + hdev->esco_type |= (ESCO_3EV3); + + if (hdev->features[0][5] & LMP_EDR_3S_ESCO) + hdev->esco_type |= (ESCO_2EV5 | ESCO_3EV5); + + return rp->status; +} + +static u8 hci_cc_read_local_ext_features(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_ext_features *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (hdev->max_page < rp->max_page) { + if (test_bit(HCI_QUIRK_BROKEN_LOCAL_EXT_FEATURES_PAGE_2, + &hdev->quirks)) + bt_dev_warn(hdev, "broken local ext features page 2"); + else + hdev->max_page = rp->max_page; + } + + if (rp->page < HCI_MAX_PAGES) + memcpy(hdev->features[rp->page], rp->features, 8); + + return rp->status; +} + +static u8 hci_cc_read_flow_control_mode(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_flow_control_mode *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->flow_ctl_mode = rp->mode; + + return rp->status; +} + +static u8 hci_cc_read_buffer_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_buffer_size *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->acl_mtu = __le16_to_cpu(rp->acl_mtu); + hdev->sco_mtu = rp->sco_mtu; + hdev->acl_pkts = __le16_to_cpu(rp->acl_max_pkt); + hdev->sco_pkts = __le16_to_cpu(rp->sco_max_pkt); + + if (test_bit(HCI_QUIRK_FIXUP_BUFFER_SIZE, &hdev->quirks)) { + hdev->sco_mtu = 64; + hdev->sco_pkts = 8; + } + + hdev->acl_cnt = hdev->acl_pkts; + hdev->sco_cnt = hdev->sco_pkts; + + BT_DBG("%s acl mtu %d:%d sco mtu %d:%d", hdev->name, hdev->acl_mtu, + hdev->acl_pkts, hdev->sco_mtu, hdev->sco_pkts); + + return rp->status; +} + +static u8 hci_cc_read_bd_addr(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_bd_addr *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (test_bit(HCI_INIT, &hdev->flags)) + bacpy(&hdev->bdaddr, &rp->bdaddr); + + if (hci_dev_test_flag(hdev, HCI_SETUP)) + bacpy(&hdev->setup_addr, &rp->bdaddr); + + return rp->status; +} + +static u8 hci_cc_read_local_pairing_opts(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_pairing_opts *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG)) { + hdev->pairing_opts = rp->pairing_opts; + hdev->max_enc_key_size = rp->max_key_size; + } + + return rp->status; +} + +static u8 hci_cc_read_page_scan_activity(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_page_scan_activity *rp = data; + + 
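/* Aside: the ACL/SCO packet counts read back in hci_cc_read_buffer_size()
 * above are, roughly speaking, transmit credits: the TX path decrements
 * acl_cnt/sco_cnt for every packet handed to the controller, and the
 * Number Of Completed Packets event gives credits back. The sketch below
 * is a minimal, standalone illustration of such a credit scheme; the
 * names tx_credits, tx_credit_take and tx_credit_return are made up for
 * illustration and are not part of this patch or of the kernel.
 */
struct tx_credits {
        unsigned int cnt;       /* credits currently available */
        unsigned int pkts;      /* total reported by Read Buffer Size */
};

/* Take one credit before queueing a packet to the controller; returns
 * false when every controller buffer is currently in use.
 */
static bool tx_credit_take(struct tx_credits *c)
{
        if (!c->cnt)
                return false;
        c->cnt--;
        return true;
}

/* Give credits back when a Number Of Completed Packets event reports
 * that the controller has freed 'completed' buffers.
 */
static void tx_credit_return(struct tx_credits *c, unsigned int completed)
{
        c->cnt += completed;
        if (c->cnt > c->pkts)
                c->cnt = c->pkts;
}
/* In the real stack these counters live directly on struct hci_dev
 * (acl_cnt/sco_cnt) rather than in a separate structure.
 */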
bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (test_bit(HCI_INIT, &hdev->flags)) { + hdev->page_scan_interval = __le16_to_cpu(rp->interval); + hdev->page_scan_window = __le16_to_cpu(rp->window); + } + + return rp->status; +} + +static u8 hci_cc_write_page_scan_activity(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_write_page_scan_activity *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_PAGE_SCAN_ACTIVITY); + if (!sent) + return rp->status; + + hdev->page_scan_interval = __le16_to_cpu(sent->interval); + hdev->page_scan_window = __le16_to_cpu(sent->window); + + return rp->status; +} + +static u8 hci_cc_read_page_scan_type(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_page_scan_type *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (test_bit(HCI_INIT, &hdev->flags)) + hdev->page_scan_type = rp->type; + + return rp->status; +} + +static u8 hci_cc_write_page_scan_type(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + u8 *type; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + type = hci_sent_cmd_data(hdev, HCI_OP_WRITE_PAGE_SCAN_TYPE); + if (type) + hdev->page_scan_type = *type; + + return rp->status; +} + +static u8 hci_cc_read_data_block_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_data_block_size *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->block_mtu = __le16_to_cpu(rp->max_acl_len); + hdev->block_len = __le16_to_cpu(rp->block_len); + hdev->num_blocks = __le16_to_cpu(rp->num_blocks); + + hdev->block_cnt = hdev->num_blocks; + + BT_DBG("%s blk mtu %d cnt %d len %d", hdev->name, hdev->block_mtu, + hdev->block_cnt, hdev->block_len); + + return rp->status; +} + +static u8 hci_cc_read_clock(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_clock *rp = data; + struct hci_cp_read_clock *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + + cp = hci_sent_cmd_data(hdev, HCI_OP_READ_CLOCK); + if (!cp) + goto unlock; + + if (cp->which == 0x00) { + hdev->clock = le32_to_cpu(rp->clock); + goto unlock; + } + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) { + conn->clock = le32_to_cpu(rp->clock); + conn->clock_accuracy = le16_to_cpu(rp->accuracy); + } + +unlock: + hci_dev_unlock(hdev); + return rp->status; +} + +static u8 hci_cc_read_local_amp_info(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_amp_info *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->amp_status = rp->amp_status; + hdev->amp_total_bw = __le32_to_cpu(rp->total_bw); + hdev->amp_max_bw = __le32_to_cpu(rp->max_bw); + hdev->amp_min_latency = __le32_to_cpu(rp->min_latency); + hdev->amp_max_pdu = __le32_to_cpu(rp->max_pdu); + hdev->amp_type = rp->amp_type; + hdev->amp_pal_cap = __le16_to_cpu(rp->pal_cap); + hdev->amp_assoc_size = __le16_to_cpu(rp->max_assoc_size); + hdev->amp_be_flush_to = __le32_to_cpu(rp->be_flush_to); + hdev->amp_max_flush_to = 
__le32_to_cpu(rp->max_flush_to); + + return rp->status; +} + +static u8 hci_cc_read_inq_rsp_tx_power(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_inq_rsp_tx_power *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->inq_tx_power = rp->tx_power; + + return rp->status; +} + +static u8 hci_cc_read_def_err_data_reporting(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_def_err_data_reporting *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->err_data_reporting = rp->err_data_reporting; + + return rp->status; +} + +static u8 hci_cc_write_def_err_data_reporting(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_write_def_err_data_reporting *cp; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_WRITE_DEF_ERR_DATA_REPORTING); + if (!cp) + return rp->status; + + hdev->err_data_reporting = cp->err_data_reporting; + + return rp->status; +} + +static u8 hci_cc_pin_code_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_pin_code_reply *rp = data; + struct hci_cp_pin_code_reply *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_pin_code_reply_complete(hdev, &rp->bdaddr, rp->status); + + if (rp->status) + goto unlock; + + cp = hci_sent_cmd_data(hdev, HCI_OP_PIN_CODE_REPLY); + if (!cp) + goto unlock; + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->bdaddr); + if (conn) + conn->pin_length = cp->pin_len; + +unlock: + hci_dev_unlock(hdev); + return rp->status; +} + +static u8 hci_cc_pin_code_neg_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_pin_code_neg_reply *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_pin_code_neg_reply_complete(hdev, &rp->bdaddr, + rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_read_buffer_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_buffer_size *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->le_mtu = __le16_to_cpu(rp->le_mtu); + hdev->le_pkts = rp->le_max_pkt; + + hdev->le_cnt = hdev->le_pkts; + + BT_DBG("%s le mtu %d:%d", hdev->name, hdev->le_mtu, hdev->le_pkts); + + return rp->status; +} + +static u8 hci_cc_le_read_local_features(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_local_features *rp = data; + + BT_DBG("%s status 0x%2.2x", hdev->name, rp->status); + + if (rp->status) + return rp->status; + + memcpy(hdev->le_features, rp->features, 8); + + return rp->status; +} + +static u8 hci_cc_le_read_adv_tx_power(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_adv_tx_power *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->adv_tx_power = rp->tx_power; + + return rp->status; +} + +static u8 hci_cc_user_confirm_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_user_confirm_reply *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if 
(hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_confirm_reply_complete(hdev, &rp->bdaddr, ACL_LINK, 0, + rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_user_confirm_neg_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_user_confirm_reply *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_confirm_neg_reply_complete(hdev, &rp->bdaddr, + ACL_LINK, 0, rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_user_passkey_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_user_confirm_reply *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_passkey_reply_complete(hdev, &rp->bdaddr, ACL_LINK, + 0, rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_user_passkey_neg_reply(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_user_confirm_reply *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_passkey_neg_reply_complete(hdev, &rp->bdaddr, + ACL_LINK, 0, rp->status); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_local_oob_data(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_oob_data *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + return rp->status; +} + +static u8 hci_cc_read_local_oob_ext_data(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_local_oob_ext_data *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + return rp->status; +} + +static u8 hci_cc_le_set_random_addr(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + bdaddr_t *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_RANDOM_ADDR); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + bacpy(&hdev->random_addr, sent); + + if (!bacmp(&hdev->rpa, sent)) { + hci_dev_clear_flag(hdev, HCI_RPA_EXPIRED); + queue_delayed_work(hdev->workqueue, &hdev->rpa_expired, + secs_to_jiffies(hdev->rpa_timeout)); + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_set_default_phy(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_le_set_default_phy *cp; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_DEFAULT_PHY); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + + hdev->le_tx_def_phys = cp->tx_phys; + hdev->le_rx_def_phys = cp->rx_phys; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_set_adv_set_random_addr(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_le_set_adv_set_rand_addr *cp; + struct adv_info *adv; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_ADV_SET_RAND_ADDR); + /* Update only in case the adv instance since handle 0x00 shall be using + * HCI_OP_LE_SET_RANDOM_ADDR since that allows both extended and + * non-extended adverting. 
+ */ + if (!cp || !cp->handle) + return rp->status; + + hci_dev_lock(hdev); + + adv = hci_find_adv_instance(hdev, cp->handle); + if (adv) { + bacpy(&adv->random_addr, &cp->bdaddr); + if (!bacmp(&hdev->rpa, &cp->bdaddr)) { + adv->rpa_expired = false; + queue_delayed_work(hdev->workqueue, + &adv->rpa_expired_cb, + secs_to_jiffies(hdev->rpa_timeout)); + } + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_remove_adv_set(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + u8 *instance; + int err; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + instance = hci_sent_cmd_data(hdev, HCI_OP_LE_REMOVE_ADV_SET); + if (!instance) + return rp->status; + + hci_dev_lock(hdev); + + err = hci_remove_adv_instance(hdev, *instance); + if (!err) + mgmt_advertising_removed(hci_skb_sk(hdev->sent_cmd), hdev, + *instance); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_clear_adv_sets(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct adv_info *adv, *n; + int err; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + if (!hci_sent_cmd_data(hdev, HCI_OP_LE_CLEAR_ADV_SETS)) + return rp->status; + + hci_dev_lock(hdev); + + list_for_each_entry_safe(adv, n, &hdev->adv_instances, list) { + u8 instance = adv->instance; + + err = hci_remove_adv_instance(hdev, instance); + if (!err) + mgmt_advertising_removed(hci_skb_sk(hdev->sent_cmd), + hdev, instance); + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_read_transmit_power(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_transmit_power *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->min_le_tx_power = rp->min_le_tx_power; + hdev->max_le_tx_power = rp->max_le_tx_power; + + return rp->status; +} + +static u8 hci_cc_le_set_privacy_mode(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_le_set_privacy_mode *cp; + struct hci_conn_params *params; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_PRIVACY_MODE); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + + params = hci_conn_params_lookup(hdev, &cp->bdaddr, cp->bdaddr_type); + if (params) + WRITE_ONCE(params->privacy_mode, cp->mode); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_set_adv_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u8 *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_ADV_ENABLE); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + /* If we're doing connection initiation as peripheral. Set a + * timeout in case something goes wrong. 
+ */ + if (*sent) { + struct hci_conn *conn; + + hci_dev_set_flag(hdev, HCI_LE_ADV); + + conn = hci_lookup_le_connect(hdev); + if (conn) + queue_delayed_work(hdev->workqueue, + &conn->le_conn_timeout, + conn->conn_timeout); + } else { + hci_dev_clear_flag(hdev, HCI_LE_ADV); + } + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_set_ext_adv_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_ext_adv_enable *cp; + struct hci_cp_ext_adv_set *set; + struct adv_info *adv = NULL, *n; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_EXT_ADV_ENABLE); + if (!cp) + return rp->status; + + set = (void *)cp->data; + + hci_dev_lock(hdev); + + if (cp->num_of_sets) + adv = hci_find_adv_instance(hdev, set->handle); + + if (cp->enable) { + struct hci_conn *conn; + + hci_dev_set_flag(hdev, HCI_LE_ADV); + + if (adv) + adv->enabled = true; + + conn = hci_lookup_le_connect(hdev); + if (conn) + queue_delayed_work(hdev->workqueue, + &conn->le_conn_timeout, + conn->conn_timeout); + } else { + if (cp->num_of_sets) { + if (adv) + adv->enabled = false; + + /* If just one instance was disabled check if there are + * any other instance enabled before clearing HCI_LE_ADV + */ + list_for_each_entry_safe(adv, n, &hdev->adv_instances, + list) { + if (adv->enabled) + goto unlock; + } + } else { + /* All instances shall be considered disabled */ + list_for_each_entry_safe(adv, n, &hdev->adv_instances, + list) + adv->enabled = false; + } + + hci_dev_clear_flag(hdev, HCI_LE_ADV); + } + +unlock: + hci_dev_unlock(hdev); + return rp->status; +} + +static u8 hci_cc_le_set_scan_param(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_scan_param *cp; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_SCAN_PARAM); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + + hdev->le_scan_type = cp->type; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_set_ext_scan_param(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_ext_scan_params *cp; + struct hci_ev_status *rp = data; + struct hci_cp_le_scan_phy_params *phy_param; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_EXT_SCAN_PARAMS); + if (!cp) + return rp->status; + + phy_param = (void *)cp->data; + + hci_dev_lock(hdev); + + hdev->le_scan_type = phy_param->type; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static bool has_pending_adv_report(struct hci_dev *hdev) +{ + struct discovery_state *d = &hdev->discovery; + + return bacmp(&d->last_adv_addr, BDADDR_ANY); +} + +static void clear_pending_adv_report(struct hci_dev *hdev) +{ + struct discovery_state *d = &hdev->discovery; + + bacpy(&d->last_adv_addr, BDADDR_ANY); + d->last_adv_data_len = 0; +} + +static void store_pending_adv_report(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 bdaddr_type, s8 rssi, u32 flags, + u8 *data, u8 len) +{ + struct discovery_state *d = &hdev->discovery; + + if (len > HCI_MAX_AD_LENGTH) + return; + + bacpy(&d->last_adv_addr, bdaddr); + d->last_adv_addr_type = bdaddr_type; + d->last_adv_rssi = rssi; + d->last_adv_flags = flags; + memcpy(d->last_adv_data, data, len); + d->last_adv_data_len = len; +} + +static 
void le_set_scan_enable_complete(struct hci_dev *hdev, u8 enable) +{ + hci_dev_lock(hdev); + + switch (enable) { + case LE_SCAN_ENABLE: + hci_dev_set_flag(hdev, HCI_LE_SCAN); + if (hdev->le_scan_type == LE_SCAN_ACTIVE) + clear_pending_adv_report(hdev); + if (hci_dev_test_flag(hdev, HCI_MESH)) + hci_discovery_set_state(hdev, DISCOVERY_FINDING); + break; + + case LE_SCAN_DISABLE: + /* We do this here instead of when setting DISCOVERY_STOPPED + * since the latter would potentially require waiting for + * inquiry to stop too. + */ + if (has_pending_adv_report(hdev)) { + struct discovery_state *d = &hdev->discovery; + + mgmt_device_found(hdev, &d->last_adv_addr, LE_LINK, + d->last_adv_addr_type, NULL, + d->last_adv_rssi, d->last_adv_flags, + d->last_adv_data, + d->last_adv_data_len, NULL, 0, 0); + } + + /* Cancel this timer so that we don't try to disable scanning + * when it's already disabled. + */ + cancel_delayed_work(&hdev->le_scan_disable); + + hci_dev_clear_flag(hdev, HCI_LE_SCAN); + + /* The HCI_LE_SCAN_INTERRUPTED flag indicates that we + * interrupted scanning due to a connect request. Mark + * therefore discovery as stopped. + */ + if (hci_dev_test_and_clear_flag(hdev, HCI_LE_SCAN_INTERRUPTED)) + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + else if (!hci_dev_test_flag(hdev, HCI_LE_ADV) && + hdev->discovery.state == DISCOVERY_FINDING) + queue_work(hdev->workqueue, &hdev->reenable_adv_work); + + break; + + default: + bt_dev_err(hdev, "use of reserved LE_Scan_Enable param %d", + enable); + break; + } + + hci_dev_unlock(hdev); +} + +static u8 hci_cc_le_set_scan_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_scan_enable *cp; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_SCAN_ENABLE); + if (!cp) + return rp->status; + + le_set_scan_enable_complete(hdev, cp->enable); + + return rp->status; +} + +static u8 hci_cc_le_set_ext_scan_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_ext_scan_enable *cp; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_EXT_SCAN_ENABLE); + if (!cp) + return rp->status; + + le_set_scan_enable_complete(hdev, cp->enable); + + return rp->status; +} + +static u8 hci_cc_le_read_num_adv_sets(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_num_supported_adv_sets *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x No of Adv sets %u", rp->status, + rp->num_of_sets); + + if (rp->status) + return rp->status; + + hdev->le_num_of_adv_sets = rp->num_of_sets; + + return rp->status; +} + +static u8 hci_cc_le_read_accept_list_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_accept_list_size *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x size %u", rp->status, rp->size); + + if (rp->status) + return rp->status; + + hdev->le_accept_list_size = rp->size; + + return rp->status; +} + +static u8 hci_cc_le_clear_accept_list(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_clear(&hdev->le_accept_list); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_add_to_accept_list(struct 
hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_add_to_accept_list *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_ADD_TO_ACCEPT_LIST); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_add(&hdev->le_accept_list, &sent->bdaddr, + sent->bdaddr_type); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_del_from_accept_list(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_del_from_accept_list *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_DEL_FROM_ACCEPT_LIST); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_del(&hdev->le_accept_list, &sent->bdaddr, + sent->bdaddr_type); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_read_supported_states(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_supported_states *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + memcpy(hdev->le_states, rp->le_states, 8); + + return rp->status; +} + +static u8 hci_cc_le_read_def_data_len(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_def_data_len *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->le_def_tx_len = le16_to_cpu(rp->tx_len); + hdev->le_def_tx_time = le16_to_cpu(rp->tx_time); + + return rp->status; +} + +static u8 hci_cc_le_write_def_data_len(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_write_def_data_len *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_WRITE_DEF_DATA_LEN); + if (!sent) + return rp->status; + + hdev->le_def_tx_len = le16_to_cpu(sent->tx_len); + hdev->le_def_tx_time = le16_to_cpu(sent->tx_time); + + return rp->status; +} + +static u8 hci_cc_le_add_to_resolv_list(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_add_to_resolv_list *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_ADD_TO_RESOLV_LIST); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_add_with_irk(&hdev->le_resolv_list, &sent->bdaddr, + sent->bdaddr_type, sent->peer_irk, + sent->local_irk); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_del_from_resolv_list(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_del_from_resolv_list *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_DEL_FROM_RESOLV_LIST); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_del_with_irk(&hdev->le_resolv_list, &sent->bdaddr, + sent->bdaddr_type); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_clear_resolv_list(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + 
return rp->status; + + hci_dev_lock(hdev); + hci_bdaddr_list_clear(&hdev->le_resolv_list); + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_read_resolv_list_size(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_resolv_list_size *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x size %u", rp->status, rp->size); + + if (rp->status) + return rp->status; + + hdev->le_resolv_list_size = rp->size; + + return rp->status; +} + +static u8 hci_cc_le_set_addr_resolution_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u8 *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_ADDR_RESOLV_ENABLE); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (*sent) + hci_dev_set_flag(hdev, HCI_LL_RPA_RESOLUTION); + else + hci_dev_clear_flag(hdev, HCI_LL_RPA_RESOLUTION); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_read_max_data_len(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_max_data_len *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->le_max_tx_len = le16_to_cpu(rp->tx_len); + hdev->le_max_tx_time = le16_to_cpu(rp->tx_time); + hdev->le_max_rx_len = le16_to_cpu(rp->rx_len); + hdev->le_max_rx_time = le16_to_cpu(rp->rx_time); + + return rp->status; +} + +static u8 hci_cc_write_le_host_supported(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_write_le_host_supported *sent; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_WRITE_LE_HOST_SUPPORTED); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (sent->le) { + hdev->features[1][0] |= LMP_HOST_LE; + hci_dev_set_flag(hdev, HCI_LE_ENABLED); + } else { + hdev->features[1][0] &= ~LMP_HOST_LE; + hci_dev_clear_flag(hdev, HCI_LE_ENABLED); + hci_dev_clear_flag(hdev, HCI_ADVERTISING); + } + + if (sent->simul) + hdev->features[1][0] |= LMP_HOST_LE_BREDR; + else + hdev->features[1][0] &= ~LMP_HOST_LE_BREDR; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_set_adv_param(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_le_set_adv_param *cp; + struct hci_ev_status *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_ADV_PARAM); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + hdev->adv_addr_type = cp->own_address_type; + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_set_ext_adv_param(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_set_ext_adv_params *rp = data; + struct hci_cp_le_set_ext_adv_params *cp; + struct adv_info *adv_instance; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_EXT_ADV_PARAMS); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + hdev->adv_addr_type = cp->own_addr_type; + if (!cp->handle) { + /* Store in hdev for instance 0 */ + hdev->adv_tx_power = rp->tx_power; + } else { + adv_instance = hci_find_adv_instance(hdev, cp->handle); + if (adv_instance) + adv_instance->tx_power = rp->tx_power; + } + /* Update adv data as tx power is known now */ + 
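/* Aside: most handlers in this file follow the same shape as the block
 * above. A Command Complete event carries only the return parameters,
 * so the handler recovers the parameters of the command it answers via
 * hci_sent_cmd_data() and then updates hdev state under hci_dev_lock().
 * The sketch below shows that shape only; HCI_OP_EXAMPLE, struct
 * hci_cp_example and hci_cc_example are placeholders invented for
 * illustration, not real opcodes, types or handlers.
 */
#define HCI_OP_EXAMPLE 0xfcf0 /* placeholder opcode for illustration */

struct hci_cp_example {
        __u8 mode;
} __packed;

static u8 hci_cc_example(struct hci_dev *hdev, void *data,
                         struct sk_buff *skb)
{
        struct hci_ev_status *rp = data;
        struct hci_cp_example *cp;

        if (rp->status)
                return rp->status;

        /* Recover the parameters of the command this event completes */
        cp = hci_sent_cmd_data(hdev, HCI_OP_EXAMPLE);
        if (!cp)
                return rp->status;

        hci_dev_lock(hdev);
        /* apply cp-> values to hdev state here */
        bt_dev_dbg(hdev, "mode %u", cp->mode);
        hci_dev_unlock(hdev);

        return rp->status;
}
/* The real handlers differ mainly in which sent-command opcode they look
 * up and which hdev or hci_conn fields they update with it.
 */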
hci_update_adv_data(hdev, cp->handle); + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_rssi(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_read_rssi *rp = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (conn) + conn->rssi = rp->rssi; + + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_read_tx_power(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_cp_read_tx_power *sent; + struct hci_rp_read_tx_power *rp = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_READ_TX_POWER); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(rp->handle)); + if (!conn) + goto unlock; + + switch (sent->type) { + case 0x00: + conn->tx_power = rp->tx_power; + break; + case 0x01: + conn->max_tx_power = rp->tx_power; + break; + } + +unlock: + hci_dev_unlock(hdev); + return rp->status; +} + +static u8 hci_cc_write_ssp_debug_mode(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + u8 *mode; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + mode = hci_sent_cmd_data(hdev, HCI_OP_WRITE_SSP_DEBUG_MODE); + if (mode) + hdev->ssp_debug_mode = *mode; + + return rp->status; +} + +static void hci_cs_inquiry(struct hci_dev *hdev, __u8 status) +{ + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (status) { + hci_conn_check_pending(hdev); + return; + } + + if (hci_sent_cmd_data(hdev, HCI_OP_INQUIRY)) + set_bit(HCI_INQUIRY, &hdev->flags); +} + +static void hci_cs_create_conn(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_create_conn *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + cp = hci_sent_cmd_data(hdev, HCI_OP_CREATE_CONN); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->bdaddr); + + bt_dev_dbg(hdev, "bdaddr %pMR hcon %p", &cp->bdaddr, conn); + + if (status) { + if (conn && conn->state == BT_CONNECT) { + if (status != 0x0c || conn->attempt > 2) { + conn->state = BT_CLOSED; + hci_connect_cfm(conn, status); + hci_conn_del(conn); + } else + conn->state = BT_CONNECT2; + } + } else { + if (!conn) { + conn = hci_conn_add(hdev, ACL_LINK, &cp->bdaddr, + HCI_ROLE_MASTER); + if (!conn) + bt_dev_err(hdev, "no memory for new connection"); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_add_sco(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_add_sco *cp; + struct hci_conn *acl, *sco; + __u16 handle; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_ADD_SCO); + if (!cp) + return; + + handle = __le16_to_cpu(cp->handle); + + bt_dev_dbg(hdev, "handle 0x%4.4x", handle); + + hci_dev_lock(hdev); + + acl = hci_conn_hash_lookup_handle(hdev, handle); + if (acl) { + sco = acl->link; + if (sco) { + sco->state = BT_CLOSED; + + hci_connect_cfm(sco, status); + hci_conn_del(sco); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_auth_requested(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_auth_requested *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = 
hci_sent_cmd_data(hdev, HCI_OP_AUTH_REQUESTED); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + if (conn->state == BT_CONFIG) { + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_set_conn_encrypt(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_set_conn_encrypt *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_SET_CONN_ENCRYPT); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + if (conn->state == BT_CONFIG) { + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static int hci_outgoing_auth_needed(struct hci_dev *hdev, + struct hci_conn *conn) +{ + if (conn->state != BT_CONFIG || !conn->out) + return 0; + + if (conn->pending_sec_level == BT_SECURITY_SDP) + return 0; + + /* Only request authentication for SSP connections or non-SSP + * devices with sec_level MEDIUM or HIGH or if MITM protection + * is requested. + */ + if (!hci_conn_ssp_enabled(conn) && !(conn->auth_type & 0x01) && + conn->pending_sec_level != BT_SECURITY_FIPS && + conn->pending_sec_level != BT_SECURITY_HIGH && + conn->pending_sec_level != BT_SECURITY_MEDIUM) + return 0; + + return 1; +} + +static int hci_resolve_name(struct hci_dev *hdev, + struct inquiry_entry *e) +{ + struct hci_cp_remote_name_req cp; + + memset(&cp, 0, sizeof(cp)); + + bacpy(&cp.bdaddr, &e->data.bdaddr); + cp.pscan_rep_mode = e->data.pscan_rep_mode; + cp.pscan_mode = e->data.pscan_mode; + cp.clock_offset = e->data.clock_offset; + + return hci_send_cmd(hdev, HCI_OP_REMOTE_NAME_REQ, sizeof(cp), &cp); +} + +static bool hci_resolve_next_name(struct hci_dev *hdev) +{ + struct discovery_state *discov = &hdev->discovery; + struct inquiry_entry *e; + + if (list_empty(&discov->resolve)) + return false; + + /* We should stop if we already spent too much time resolving names. */ + if (time_after(jiffies, discov->name_resolve_timeout)) { + bt_dev_warn_ratelimited(hdev, "Name resolve takes too long."); + return false; + } + + e = hci_inquiry_cache_lookup_resolve(hdev, BDADDR_ANY, NAME_NEEDED); + if (!e) + return false; + + if (hci_resolve_name(hdev, e) == 0) { + e->name_state = NAME_PENDING; + return true; + } + + return false; +} + +static void hci_check_pending_name(struct hci_dev *hdev, struct hci_conn *conn, + bdaddr_t *bdaddr, u8 *name, u8 name_len) +{ + struct discovery_state *discov = &hdev->discovery; + struct inquiry_entry *e; + + /* Update the mgmt connected state if necessary. Be careful with + * conn objects that exist but are not (yet) connected however. + * Only those in BT_CONFIG or BT_CONNECTED states can be + * considered connected. + */ + if (conn && + (conn->state == BT_CONFIG || conn->state == BT_CONNECTED) && + !test_and_set_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) + mgmt_device_connected(hdev, conn, name, name_len); + + if (discov->state == DISCOVERY_STOPPED) + return; + + if (discov->state == DISCOVERY_STOPPING) + goto discov_complete; + + if (discov->state != DISCOVERY_RESOLVING) + return; + + e = hci_inquiry_cache_lookup_resolve(hdev, bdaddr, NAME_PENDING); + /* If the device was not found in a list of found devices names of which + * are pending. 
there is no need to continue resolving a next name as it + * will be done upon receiving another Remote Name Request Complete + * Event */ + if (!e) + return; + + list_del(&e->list); + + e->name_state = name ? NAME_KNOWN : NAME_NOT_KNOWN; + mgmt_remote_name(hdev, bdaddr, ACL_LINK, 0x00, e->data.rssi, + name, name_len); + + if (hci_resolve_next_name(hdev)) + return; + +discov_complete: + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); +} + +static void hci_cs_remote_name_req(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_remote_name_req *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + /* If successful wait for the name req complete event before + * checking for the need to do authentication */ + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_REMOTE_NAME_REQ); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->bdaddr); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + hci_check_pending_name(hdev, conn, &cp->bdaddr, NULL, 0); + + if (!conn) + goto unlock; + + if (!hci_outgoing_auth_needed(hdev, conn)) + goto unlock; + + if (!test_and_set_bit(HCI_CONN_AUTH_PEND, &conn->flags)) { + struct hci_cp_auth_requested auth_cp; + + set_bit(HCI_CONN_AUTH_INITIATOR, &conn->flags); + + auth_cp.handle = __cpu_to_le16(conn->handle); + hci_send_cmd(hdev, HCI_OP_AUTH_REQUESTED, + sizeof(auth_cp), &auth_cp); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_cs_read_remote_features(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_read_remote_features *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_READ_REMOTE_FEATURES); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + if (conn->state == BT_CONFIG) { + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_read_remote_ext_features(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_read_remote_ext_features *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_READ_REMOTE_EXT_FEATURES); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + if (conn->state == BT_CONFIG) { + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_setup_sync_conn(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_setup_sync_conn *cp; + struct hci_conn *acl, *sco; + __u16 handle; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_SETUP_SYNC_CONN); + if (!cp) + return; + + handle = __le16_to_cpu(cp->handle); + + bt_dev_dbg(hdev, "handle 0x%4.4x", handle); + + hci_dev_lock(hdev); + + acl = hci_conn_hash_lookup_handle(hdev, handle); + if (acl) { + sco = acl->link; + if (sco) { + sco->state = BT_CLOSED; + + hci_connect_cfm(sco, status); + hci_conn_del(sco); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_enhanced_setup_sync_conn(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_enhanced_setup_sync_conn *cp; + struct hci_conn *acl, *sco; + __u16 handle; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_ENHANCED_SETUP_SYNC_CONN); + if (!cp) + return; + + handle = 
__le16_to_cpu(cp->handle); + + bt_dev_dbg(hdev, "handle 0x%4.4x", handle); + + hci_dev_lock(hdev); + + acl = hci_conn_hash_lookup_handle(hdev, handle); + if (acl) { + sco = acl->link; + if (sco) { + sco->state = BT_CLOSED; + + hci_connect_cfm(sco, status); + hci_conn_del(sco); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_sniff_mode(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_sniff_mode *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_SNIFF_MODE); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + clear_bit(HCI_CONN_MODE_CHANGE_PEND, &conn->flags); + + if (test_and_clear_bit(HCI_CONN_SCO_SETUP_PEND, &conn->flags)) + hci_sco_setup(conn, status); + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_exit_sniff_mode(struct hci_dev *hdev, __u8 status) +{ + struct hci_cp_exit_sniff_mode *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_EXIT_SNIFF_MODE); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + clear_bit(HCI_CONN_MODE_CHANGE_PEND, &conn->flags); + + if (test_and_clear_bit(HCI_CONN_SCO_SETUP_PEND, &conn->flags)) + hci_sco_setup(conn, status); + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_disconnect(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_disconnect *cp; + struct hci_conn_params *params; + struct hci_conn *conn; + bool mgmt_conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + /* Wait for HCI_EV_DISCONN_COMPLETE if status 0x00 and not suspended + * otherwise cleanup the connection immediately. + */ + if (!status && !hdev->suspended) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_DISCONNECT); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (!conn) + goto unlock; + + if (status) { + mgmt_disconnect_failed(hdev, &conn->dst, conn->type, + conn->dst_type, status); + + if (conn->type == LE_LINK && conn->role == HCI_ROLE_SLAVE) { + hdev->cur_adv_instance = conn->adv_instance; + hci_enable_advertising(hdev); + } + + /* Inform sockets conn is gone before we delete it */ + hci_disconn_cfm(conn, HCI_ERROR_UNSPECIFIED); + + goto done; + } + + mgmt_conn = test_and_clear_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags); + + if (conn->type == ACL_LINK) { + if (test_and_clear_bit(HCI_CONN_FLUSH_KEY, &conn->flags)) + hci_remove_link_key(hdev, &conn->dst); + } + + params = hci_conn_params_lookup(hdev, &conn->dst, conn->dst_type); + if (params) { + switch (params->auto_connect) { + case HCI_AUTO_CONN_LINK_LOSS: + if (cp->reason != HCI_ERROR_CONNECTION_TIMEOUT) + break; + fallthrough; + + case HCI_AUTO_CONN_DIRECT: + case HCI_AUTO_CONN_ALWAYS: + hci_pend_le_list_del_init(params); + hci_pend_le_list_add(params, &hdev->pend_le_conns); + break; + + default: + break; + } + } + + mgmt_device_disconnected(hdev, &conn->dst, conn->type, conn->dst_type, + cp->reason, mgmt_conn); + + hci_disconn_cfm(conn, cp->reason); + +done: + /* If the disconnection failed for any reason, the upper layer + * does not retry to disconnect in current implementation. + * Hence, we need to do some basic cleanup here and re-enable + * advertising if necessary. 
+ */ + hci_conn_del(conn); +unlock: + hci_dev_unlock(hdev); +} + +static u8 ev_bdaddr_type(struct hci_dev *hdev, u8 type, bool *resolved) +{ + /* When using controller based address resolution, then the new + * address types 0x02 and 0x03 are used. These types need to be + * converted back into either public address or random address type + */ + switch (type) { + case ADDR_LE_DEV_PUBLIC_RESOLVED: + if (resolved) + *resolved = true; + return ADDR_LE_DEV_PUBLIC; + case ADDR_LE_DEV_RANDOM_RESOLVED: + if (resolved) + *resolved = true; + return ADDR_LE_DEV_RANDOM; + } + + if (resolved) + *resolved = false; + return type; +} + +static void cs_le_create_conn(struct hci_dev *hdev, bdaddr_t *peer_addr, + u8 peer_addr_type, u8 own_address_type, + u8 filter_policy) +{ + struct hci_conn *conn; + + conn = hci_conn_hash_lookup_le(hdev, peer_addr, + peer_addr_type); + if (!conn) + return; + + own_address_type = ev_bdaddr_type(hdev, own_address_type, NULL); + + /* Store the initiator and responder address information which + * is needed for SMP. These values will not change during the + * lifetime of the connection. + */ + conn->init_addr_type = own_address_type; + if (own_address_type == ADDR_LE_DEV_RANDOM) + bacpy(&conn->init_addr, &hdev->random_addr); + else + bacpy(&conn->init_addr, &hdev->bdaddr); + + conn->resp_addr_type = peer_addr_type; + bacpy(&conn->resp_addr, peer_addr); +} + +static void hci_cs_le_create_conn(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_le_create_conn *cp; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + /* All connection failure handling is taken care of by the + * hci_conn_failed function which is triggered by the HCI + * request completion callbacks used for connecting. + */ + if (status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_CREATE_CONN); + if (!cp) + return; + + hci_dev_lock(hdev); + + cs_le_create_conn(hdev, &cp->peer_addr, cp->peer_addr_type, + cp->own_address_type, cp->filter_policy); + + hci_dev_unlock(hdev); +} + +static void hci_cs_le_ext_create_conn(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_le_ext_create_conn *cp; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + /* All connection failure handling is taken care of by the + * hci_conn_failed function which is triggered by the HCI + * request completion callbacks used for connecting. 
+ */ + if (status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_EXT_CREATE_CONN); + if (!cp) + return; + + hci_dev_lock(hdev); + + cs_le_create_conn(hdev, &cp->peer_addr, cp->peer_addr_type, + cp->own_addr_type, cp->filter_policy); + + hci_dev_unlock(hdev); +} + +static void hci_cs_le_read_remote_features(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_le_read_remote_features *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_READ_REMOTE_FEATURES); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (conn) { + if (conn->state == BT_CONFIG) { + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_cs_le_start_enc(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_le_start_enc *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + hci_dev_lock(hdev); + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_START_ENC); + if (!cp) + goto unlock; + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (!conn) + goto unlock; + + if (conn->state != BT_CONNECTED) + goto unlock; + + hci_disconnect(conn, HCI_ERROR_AUTH_FAILURE); + hci_conn_drop(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_cs_switch_role(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_switch_role *cp; + struct hci_conn *conn; + + BT_DBG("%s status 0x%2.2x", hdev->name, status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_SWITCH_ROLE); + if (!cp) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->bdaddr); + if (conn) + clear_bit(HCI_CONN_RSWITCH_PEND, &conn->flags); + + hci_dev_unlock(hdev); +} + +static void hci_inquiry_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *ev = data; + struct discovery_state *discov = &hdev->discovery; + struct inquiry_entry *e; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_conn_check_pending(hdev); + + if (!test_and_clear_bit(HCI_INQUIRY, &hdev->flags)) + return; + + smp_mb__after_atomic(); /* wake_up_bit advises about this barrier */ + wake_up_bit(&hdev->flags, HCI_INQUIRY); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + return; + + hci_dev_lock(hdev); + + if (discov->state != DISCOVERY_FINDING) + goto unlock; + + if (list_empty(&discov->resolve)) { + /* When BR/EDR inquiry is active and no LE scanning is in + * progress, then change discovery state to indicate completion. + * + * When running LE scanning and BR/EDR inquiry simultaneously + * and the LE scan already finished, then change the discovery + * state to indicate completion. + */ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN) || + !test_bit(HCI_QUIRK_SIMULTANEOUS_DISCOVERY, &hdev->quirks)) + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + goto unlock; + } + + e = hci_inquiry_cache_lookup_resolve(hdev, BDADDR_ANY, NAME_NEEDED); + if (e && hci_resolve_name(hdev, e) == 0) { + e->name_state = NAME_PENDING; + hci_discovery_set_state(hdev, DISCOVERY_RESOLVING); + discov->name_resolve_timeout = jiffies + NAME_RESOLVE_DURATION; + } else { + /* When BR/EDR inquiry is active and no LE scanning is in + * progress, then change discovery state to indicate completion. 
+ * + * When running LE scanning and BR/EDR inquiry simultaneously + * and the LE scan already finished, then change the discovery + * state to indicate completion. + */ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN) || + !test_bit(HCI_QUIRK_SIMULTANEOUS_DISCOVERY, &hdev->quirks)) + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_inquiry_result_evt(struct hci_dev *hdev, void *edata, + struct sk_buff *skb) +{ + struct hci_ev_inquiry_result *ev = edata; + struct inquiry_data data; + int i; + + if (!hci_ev_skb_pull(hdev, skb, HCI_EV_INQUIRY_RESULT, + flex_array_size(ev, info, ev->num))) + return; + + bt_dev_dbg(hdev, "num %d", ev->num); + + if (!ev->num) + return; + + if (hci_dev_test_flag(hdev, HCI_PERIODIC_INQ)) + return; + + hci_dev_lock(hdev); + + for (i = 0; i < ev->num; i++) { + struct inquiry_info *info = &ev->info[i]; + u32 flags; + + bacpy(&data.bdaddr, &info->bdaddr); + data.pscan_rep_mode = info->pscan_rep_mode; + data.pscan_period_mode = info->pscan_period_mode; + data.pscan_mode = info->pscan_mode; + memcpy(data.dev_class, info->dev_class, 3); + data.clock_offset = info->clock_offset; + data.rssi = HCI_RSSI_INVALID; + data.ssp_mode = 0x00; + + flags = hci_inquiry_cache_update(hdev, &data, false); + + mgmt_device_found(hdev, &info->bdaddr, ACL_LINK, 0x00, + info->dev_class, HCI_RSSI_INVALID, + flags, NULL, 0, NULL, 0, 0); + } + + hci_dev_unlock(hdev); +} + +static void hci_conn_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_conn_complete *ev = data; + struct hci_conn *conn; + u8 status = ev->status; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ev->link_type, &ev->bdaddr); + if (!conn) { + /* In case of error status and there is no connection pending + * just unlock as there is nothing to cleanup. + */ + if (ev->status) + goto unlock; + + /* Connection may not exist if auto-connected. Check the bredr + * allowlist to see if this device is allowed to auto connect. + * If link is an ACL type, create a connection class + * automatically. + * + * Auto-connect will only occur if the event filter is + * programmed with a given address. Right now, event filter is + * only used during suspend. + */ + if (ev->link_type == ACL_LINK && + hci_bdaddr_list_lookup_with_flags(&hdev->accept_list, + &ev->bdaddr, + BDADDR_BREDR)) { + conn = hci_conn_add(hdev, ev->link_type, &ev->bdaddr, + HCI_ROLE_SLAVE); + if (!conn) { + bt_dev_err(hdev, "no memory for new conn"); + goto unlock; + } + } else { + if (ev->link_type != SCO_LINK) + goto unlock; + + conn = hci_conn_hash_lookup_ba(hdev, ESCO_LINK, + &ev->bdaddr); + if (!conn) + goto unlock; + + conn->type = SCO_LINK; + } + } + + /* The HCI_Connection_Complete event is only sent once per connection. + * Processing it more than once per connection can corrupt kernel memory. + * + * As the connection handle is set here for the first time, it indicates + * whether the connection is already set up. 
+ */ + if (conn->handle != HCI_CONN_HANDLE_UNSET) { + bt_dev_err(hdev, "Ignoring HCI_Connection_Complete for existing connection"); + goto unlock; + } + + if (!status) { + conn->handle = __le16_to_cpu(ev->handle); + if (conn->handle > HCI_CONN_HANDLE_MAX) { + bt_dev_err(hdev, "Invalid handle: 0x%4.4x > 0x%4.4x", + conn->handle, HCI_CONN_HANDLE_MAX); + status = HCI_ERROR_INVALID_PARAMETERS; + goto done; + } + + if (conn->type == ACL_LINK) { + conn->state = BT_CONFIG; + hci_conn_hold(conn); + + if (!conn->out && !hci_conn_ssp_enabled(conn) && + !hci_find_link_key(hdev, &ev->bdaddr)) + conn->disc_timeout = HCI_PAIRING_TIMEOUT; + else + conn->disc_timeout = HCI_DISCONN_TIMEOUT; + } else + conn->state = BT_CONNECTED; + + hci_debugfs_create_conn(conn); + hci_conn_add_sysfs(conn); + + if (test_bit(HCI_AUTH, &hdev->flags)) + set_bit(HCI_CONN_AUTH, &conn->flags); + + if (test_bit(HCI_ENCRYPT, &hdev->flags)) + set_bit(HCI_CONN_ENCRYPT, &conn->flags); + + /* Get remote features */ + if (conn->type == ACL_LINK) { + struct hci_cp_read_remote_features cp; + cp.handle = ev->handle; + hci_send_cmd(hdev, HCI_OP_READ_REMOTE_FEATURES, + sizeof(cp), &cp); + + hci_update_scan(hdev); + } + + /* Set packet type for incoming connection */ + if (!conn->out && hdev->hci_ver < BLUETOOTH_VER_2_0) { + struct hci_cp_change_conn_ptype cp; + cp.handle = ev->handle; + cp.pkt_type = cpu_to_le16(conn->pkt_type); + hci_send_cmd(hdev, HCI_OP_CHANGE_CONN_PTYPE, sizeof(cp), + &cp); + } + } + + if (conn->type == ACL_LINK) + hci_sco_setup(conn, ev->status); + +done: + if (status) { + hci_conn_failed(conn, status); + } else if (ev->link_type == SCO_LINK) { + switch (conn->setting & SCO_AIRMODE_MASK) { + case SCO_AIRMODE_CVSD: + if (hdev->notify) + hdev->notify(hdev, HCI_NOTIFY_ENABLE_SCO_CVSD); + break; + } + + hci_connect_cfm(conn, status); + } + +unlock: + hci_dev_unlock(hdev); + + hci_conn_check_pending(hdev); +} + +static void hci_reject_conn(struct hci_dev *hdev, bdaddr_t *bdaddr) +{ + struct hci_cp_reject_conn_req cp; + + bacpy(&cp.bdaddr, bdaddr); + cp.reason = HCI_ERROR_REJ_BAD_ADDR; + hci_send_cmd(hdev, HCI_OP_REJECT_CONN_REQ, sizeof(cp), &cp); +} + +static void hci_conn_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_conn_request *ev = data; + int mask = hdev->link_mode; + struct inquiry_entry *ie; + struct hci_conn *conn; + __u8 flags = 0; + + bt_dev_dbg(hdev, "bdaddr %pMR type 0x%x", &ev->bdaddr, ev->link_type); + + /* Reject incoming connection from device with same BD ADDR against + * CVE-2020-26555 + */ + if (hdev && !bacmp(&hdev->bdaddr, &ev->bdaddr)) { + bt_dev_dbg(hdev, "Reject connection with same BD_ADDR %pMR\n", + &ev->bdaddr); + hci_reject_conn(hdev, &ev->bdaddr); + return; + } + + mask |= hci_proto_connect_ind(hdev, &ev->bdaddr, ev->link_type, + &flags); + + if (!(mask & HCI_LM_ACCEPT)) { + hci_reject_conn(hdev, &ev->bdaddr); + return; + } + + hci_dev_lock(hdev); + + if (hci_bdaddr_list_lookup(&hdev->reject_list, &ev->bdaddr, + BDADDR_BREDR)) { + hci_reject_conn(hdev, &ev->bdaddr); + goto unlock; + } + + /* Require HCI_CONNECTABLE or an accept list entry to accept the + * connection. These features are only touched through mgmt so + * only do the checks if HCI_MGMT is set. 
+ */ + if (hci_dev_test_flag(hdev, HCI_MGMT) && + !hci_dev_test_flag(hdev, HCI_CONNECTABLE) && + !hci_bdaddr_list_lookup_with_flags(&hdev->accept_list, &ev->bdaddr, + BDADDR_BREDR)) { + hci_reject_conn(hdev, &ev->bdaddr); + goto unlock; + } + + /* Connection accepted */ + + ie = hci_inquiry_cache_lookup(hdev, &ev->bdaddr); + if (ie) + memcpy(ie->data.dev_class, ev->dev_class, 3); + + conn = hci_conn_hash_lookup_ba(hdev, ev->link_type, + &ev->bdaddr); + if (!conn) { + conn = hci_conn_add(hdev, ev->link_type, &ev->bdaddr, + HCI_ROLE_SLAVE); + if (!conn) { + bt_dev_err(hdev, "no memory for new connection"); + goto unlock; + } + } + + memcpy(conn->dev_class, ev->dev_class, 3); + + hci_dev_unlock(hdev); + + if (ev->link_type == ACL_LINK || + (!(flags & HCI_PROTO_DEFER) && !lmp_esco_capable(hdev))) { + struct hci_cp_accept_conn_req cp; + conn->state = BT_CONNECT; + + bacpy(&cp.bdaddr, &ev->bdaddr); + + if (lmp_rswitch_capable(hdev) && (mask & HCI_LM_MASTER)) + cp.role = 0x00; /* Become central */ + else + cp.role = 0x01; /* Remain peripheral */ + + hci_send_cmd(hdev, HCI_OP_ACCEPT_CONN_REQ, sizeof(cp), &cp); + } else if (!(flags & HCI_PROTO_DEFER)) { + struct hci_cp_accept_sync_conn_req cp; + conn->state = BT_CONNECT; + + bacpy(&cp.bdaddr, &ev->bdaddr); + cp.pkt_type = cpu_to_le16(conn->pkt_type); + + cp.tx_bandwidth = cpu_to_le32(0x00001f40); + cp.rx_bandwidth = cpu_to_le32(0x00001f40); + cp.max_latency = cpu_to_le16(0xffff); + cp.content_format = cpu_to_le16(hdev->voice_setting); + cp.retrans_effort = 0xff; + + hci_send_cmd(hdev, HCI_OP_ACCEPT_SYNC_CONN_REQ, sizeof(cp), + &cp); + } else { + conn->state = BT_CONNECT2; + hci_connect_cfm(conn, 0); + } + + return; +unlock: + hci_dev_unlock(hdev); +} + +static u8 hci_to_mgmt_reason(u8 err) +{ + switch (err) { + case HCI_ERROR_CONNECTION_TIMEOUT: + return MGMT_DEV_DISCONN_TIMEOUT; + case HCI_ERROR_REMOTE_USER_TERM: + case HCI_ERROR_REMOTE_LOW_RESOURCES: + case HCI_ERROR_REMOTE_POWER_OFF: + return MGMT_DEV_DISCONN_REMOTE; + case HCI_ERROR_LOCAL_HOST_TERM: + return MGMT_DEV_DISCONN_LOCAL_HOST; + default: + return MGMT_DEV_DISCONN_UNKNOWN; + } +} + +static void hci_disconn_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_disconn_complete *ev = data; + u8 reason; + struct hci_conn_params *params; + struct hci_conn *conn; + bool mgmt_connected; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + if (ev->status) { + mgmt_disconnect_failed(hdev, &conn->dst, conn->type, + conn->dst_type, ev->status); + goto unlock; + } + + conn->state = BT_CLOSED; + + mgmt_connected = test_and_clear_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags); + + if (test_bit(HCI_CONN_AUTH_FAILURE, &conn->flags)) + reason = MGMT_DEV_DISCONN_AUTH_FAILURE; + else + reason = hci_to_mgmt_reason(ev->reason); + + mgmt_device_disconnected(hdev, &conn->dst, conn->type, conn->dst_type, + reason, mgmt_connected); + + if (conn->type == ACL_LINK) { + if (test_and_clear_bit(HCI_CONN_FLUSH_KEY, &conn->flags)) + hci_remove_link_key(hdev, &conn->dst); + + hci_update_scan(hdev); + } + + params = hci_conn_params_lookup(hdev, &conn->dst, conn->dst_type); + if (params) { + switch (params->auto_connect) { + case HCI_AUTO_CONN_LINK_LOSS: + if (ev->reason != HCI_ERROR_CONNECTION_TIMEOUT) + break; + fallthrough; + + case HCI_AUTO_CONN_DIRECT: + case HCI_AUTO_CONN_ALWAYS: + hci_pend_le_list_del_init(params); + hci_pend_le_list_add(params, 
&hdev->pend_le_conns); + hci_update_passive_scan(hdev); + break; + + default: + break; + } + } + + hci_disconn_cfm(conn, ev->reason); + + /* Re-enable advertising if necessary, since it might + * have been disabled by the connection. From the + * HCI_LE_Set_Advertise_Enable command description in + * the core specification (v4.0): + * "The Controller shall continue advertising until the Host + * issues an LE_Set_Advertise_Enable command with + * Advertising_Enable set to 0x00 (Advertising is disabled) + * or until a connection is created or until the Advertising + * is timed out due to Directed Advertising." + */ + if (conn->type == LE_LINK && conn->role == HCI_ROLE_SLAVE) { + hdev->cur_adv_instance = conn->adv_instance; + hci_enable_advertising(hdev); + } + + hci_conn_del(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_auth_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_auth_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + if (!ev->status) { + clear_bit(HCI_CONN_AUTH_FAILURE, &conn->flags); + set_bit(HCI_CONN_AUTH, &conn->flags); + conn->sec_level = conn->pending_sec_level; + } else { + if (ev->status == HCI_ERROR_PIN_OR_KEY_MISSING) + set_bit(HCI_CONN_AUTH_FAILURE, &conn->flags); + + mgmt_auth_failed(conn, ev->status); + } + + clear_bit(HCI_CONN_AUTH_PEND, &conn->flags); + + if (conn->state == BT_CONFIG) { + if (!ev->status && hci_conn_ssp_enabled(conn)) { + struct hci_cp_set_conn_encrypt cp; + cp.handle = ev->handle; + cp.encrypt = 0x01; + hci_send_cmd(hdev, HCI_OP_SET_CONN_ENCRYPT, sizeof(cp), + &cp); + } else { + conn->state = BT_CONNECTED; + hci_connect_cfm(conn, ev->status); + hci_conn_drop(conn); + } + } else { + hci_auth_cfm(conn, ev->status); + + hci_conn_hold(conn); + conn->disc_timeout = HCI_DISCONN_TIMEOUT; + hci_conn_drop(conn); + } + + if (test_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags)) { + if (!ev->status) { + struct hci_cp_set_conn_encrypt cp; + cp.handle = ev->handle; + cp.encrypt = 0x01; + hci_send_cmd(hdev, HCI_OP_SET_CONN_ENCRYPT, sizeof(cp), + &cp); + } else { + clear_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags); + hci_encrypt_cfm(conn, ev->status); + } + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_remote_name_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_remote_name *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_conn_check_pending(hdev); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + goto check_auth; + + if (ev->status == 0) + hci_check_pending_name(hdev, conn, &ev->bdaddr, ev->name, + strnlen(ev->name, HCI_MAX_NAME_LENGTH)); + else + hci_check_pending_name(hdev, conn, &ev->bdaddr, NULL, 0); + +check_auth: + if (!conn) + goto unlock; + + if (!hci_outgoing_auth_needed(hdev, conn)) + goto unlock; + + if (!test_and_set_bit(HCI_CONN_AUTH_PEND, &conn->flags)) { + struct hci_cp_auth_requested cp; + + set_bit(HCI_CONN_AUTH_INITIATOR, &conn->flags); + + cp.handle = __cpu_to_le16(conn->handle); + hci_send_cmd(hdev, HCI_OP_AUTH_REQUESTED, sizeof(cp), &cp); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_encrypt_change_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_encrypt_change *ev = data; + struct hci_conn *conn; + + 
bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + if (!ev->status) { + if (ev->encrypt) { + /* Encryption implies authentication */ + set_bit(HCI_CONN_AUTH, &conn->flags); + set_bit(HCI_CONN_ENCRYPT, &conn->flags); + conn->sec_level = conn->pending_sec_level; + + /* P-256 authentication key implies FIPS */ + if (conn->key_type == HCI_LK_AUTH_COMBINATION_P256) + set_bit(HCI_CONN_FIPS, &conn->flags); + + if ((conn->type == ACL_LINK && ev->encrypt == 0x02) || + conn->type == LE_LINK) + set_bit(HCI_CONN_AES_CCM, &conn->flags); + } else { + clear_bit(HCI_CONN_ENCRYPT, &conn->flags); + clear_bit(HCI_CONN_AES_CCM, &conn->flags); + } + } + + /* We should disregard the current RPA and generate a new one + * whenever the encryption procedure fails. + */ + if (ev->status && conn->type == LE_LINK) { + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + hci_adv_instances_set_rpa_expired(hdev, true); + } + + clear_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags); + + /* Check link security requirements are met */ + if (!hci_conn_check_link_mode(conn)) + ev->status = HCI_ERROR_AUTH_FAILURE; + + if (ev->status && conn->state == BT_CONNECTED) { + if (ev->status == HCI_ERROR_PIN_OR_KEY_MISSING) + set_bit(HCI_CONN_AUTH_FAILURE, &conn->flags); + + /* Notify upper layers so they can cleanup before + * disconnecting. + */ + hci_encrypt_cfm(conn, ev->status); + hci_disconnect(conn, HCI_ERROR_AUTH_FAILURE); + hci_conn_drop(conn); + goto unlock; + } + + /* Try reading the encryption key size for encrypted ACL links */ + if (!ev->status && ev->encrypt && conn->type == ACL_LINK) { + struct hci_cp_read_enc_key_size cp; + + /* Only send HCI_Read_Encryption_Key_Size if the + * controller really supports it. If it doesn't, assume + * the default size (16). + */ + if (!(hdev->commands[20] & 0x10)) { + conn->enc_key_size = HCI_LINK_KEY_SIZE; + goto notify; + } + + cp.handle = cpu_to_le16(conn->handle); + if (hci_send_cmd(hdev, HCI_OP_READ_ENC_KEY_SIZE, + sizeof(cp), &cp)) { + bt_dev_err(hdev, "sending read key size failed"); + conn->enc_key_size = HCI_LINK_KEY_SIZE; + goto notify; + } + + goto unlock; + } + + /* Set the default Authenticated Payload Timeout after + * an LE Link is established. As per Core Spec v5.0, Vol 2, Part B + * Section 3.3, the HCI command WRITE_AUTH_PAYLOAD_TIMEOUT should be + * sent when the link is active and Encryption is enabled, the conn + * type can be either LE or ACL and controller must support LMP Ping. + * Ensure for AES-CCM encryption as well. 
+ */ + if (test_bit(HCI_CONN_ENCRYPT, &conn->flags) && + test_bit(HCI_CONN_AES_CCM, &conn->flags) && + ((conn->type == ACL_LINK && lmp_ping_capable(hdev)) || + (conn->type == LE_LINK && (hdev->le_features[0] & HCI_LE_PING)))) { + struct hci_cp_write_auth_payload_to cp; + + cp.handle = cpu_to_le16(conn->handle); + cp.timeout = cpu_to_le16(hdev->auth_payload_timeout); + hci_send_cmd(conn->hdev, HCI_OP_WRITE_AUTH_PAYLOAD_TO, + sizeof(cp), &cp); + } + +notify: + hci_encrypt_cfm(conn, ev->status); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_change_link_key_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_change_link_key_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn) { + if (!ev->status) + set_bit(HCI_CONN_SECURE, &conn->flags); + + clear_bit(HCI_CONN_AUTH_PEND, &conn->flags); + + hci_key_change_cfm(conn, ev->status); + } + + hci_dev_unlock(hdev); +} + +static void hci_remote_features_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_remote_features *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + if (!ev->status) + memcpy(conn->features[0], ev->features, 8); + + if (conn->state != BT_CONFIG) + goto unlock; + + if (!ev->status && lmp_ext_feat_capable(hdev) && + lmp_ext_feat_capable(conn)) { + struct hci_cp_read_remote_ext_features cp; + cp.handle = ev->handle; + cp.page = 0x01; + hci_send_cmd(hdev, HCI_OP_READ_REMOTE_EXT_FEATURES, + sizeof(cp), &cp); + goto unlock; + } + + if (!ev->status && !test_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) { + struct hci_cp_remote_name_req cp; + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, &conn->dst); + cp.pscan_rep_mode = 0x02; + hci_send_cmd(hdev, HCI_OP_REMOTE_NAME_REQ, sizeof(cp), &cp); + } else if (!test_and_set_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) + mgmt_device_connected(hdev, conn, NULL, 0); + + if (!hci_outgoing_auth_needed(hdev, conn)) { + conn->state = BT_CONNECTED; + hci_connect_cfm(conn, ev->status); + hci_conn_drop(conn); + } + +unlock: + hci_dev_unlock(hdev); +} + +static inline void handle_cmd_cnt_and_timer(struct hci_dev *hdev, u8 ncmd) +{ + cancel_delayed_work(&hdev->cmd_timer); + + rcu_read_lock(); + if (!test_bit(HCI_RESET, &hdev->flags)) { + if (ncmd) { + cancel_delayed_work(&hdev->ncmd_timer); + atomic_set(&hdev->cmd_cnt, 1); + } else { + if (!hci_dev_test_flag(hdev, HCI_CMD_DRAIN_WORKQUEUE)) + queue_delayed_work(hdev->workqueue, &hdev->ncmd_timer, + HCI_NCMD_TIMEOUT); + } + } + rcu_read_unlock(); +} + +static u8 hci_cc_le_read_buffer_size_v2(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_read_buffer_size_v2 *rp = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + hdev->le_mtu = __le16_to_cpu(rp->acl_mtu); + hdev->le_pkts = rp->acl_max_pkt; + hdev->iso_mtu = __le16_to_cpu(rp->iso_mtu); + hdev->iso_pkts = rp->iso_max_pkt; + + hdev->le_cnt = hdev->le_pkts; + hdev->iso_cnt = hdev->iso_pkts; + + BT_DBG("%s acl mtu %d:%d iso mtu %d:%d", hdev->name, hdev->acl_mtu, + hdev->acl_pkts, hdev->iso_mtu, hdev->iso_pkts); + + return rp->status; +} + +static u8 hci_cc_le_set_cig_params(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct 
hci_rp_le_set_cig_params *rp = data; + struct hci_conn *conn; + int i = 0; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + hci_dev_lock(hdev); + + if (rp->status) { + while ((conn = hci_conn_hash_lookup_cig(hdev, rp->cig_id))) { + conn->state = BT_CLOSED; + hci_connect_cfm(conn, rp->status); + hci_conn_del(conn); + } + goto unlock; + } + + rcu_read_lock(); + + list_for_each_entry_rcu(conn, &hdev->conn_hash.list, list) { + if (conn->type != ISO_LINK || conn->iso_qos.cig != rp->cig_id || + conn->state == BT_CONNECTED) + continue; + + conn->handle = __le16_to_cpu(rp->handle[i++]); + + bt_dev_dbg(hdev, "%p handle 0x%4.4x link %p", conn, + conn->handle, conn->link); + + /* Create CIS if LE is already connected */ + if (conn->link && conn->link->state == BT_CONNECTED) { + rcu_read_unlock(); + hci_le_create_cis(conn->link); + rcu_read_lock(); + } + + if (i == rp->num_handles) + break; + } + + rcu_read_unlock(); + +unlock: + hci_dev_unlock(hdev); + + return rp->status; +} + +static u8 hci_cc_le_setup_iso_path(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_rp_le_setup_iso_path *rp = data; + struct hci_cp_le_setup_iso_path *cp; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SETUP_ISO_PATH); + if (!cp) + return rp->status; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(cp->handle)); + if (!conn) + goto unlock; + + if (rp->status) { + hci_connect_cfm(conn, rp->status); + hci_conn_del(conn); + goto unlock; + } + + switch (cp->direction) { + /* Input (Host to Controller) */ + case 0x00: + /* Only confirm connection if output only */ + if (conn->iso_qos.out.sdu && !conn->iso_qos.in.sdu) + hci_connect_cfm(conn, rp->status); + break; + /* Output (Controller to Host) */ + case 0x01: + /* Confirm connection since conn->iso_qos is always configured + * last. 
+ */ + hci_connect_cfm(conn, rp->status); + break; + } + +unlock: + hci_dev_unlock(hdev); + return rp->status; +} + +static void hci_cs_le_create_big(struct hci_dev *hdev, u8 status) +{ + bt_dev_dbg(hdev, "status 0x%2.2x", status); +} + +static u8 hci_cc_set_per_adv_param(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + struct hci_cp_le_set_per_adv_params *cp; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_PER_ADV_PARAMS); + if (!cp) + return rp->status; + + /* TODO: set the conn state */ + return rp->status; +} + +static u8 hci_cc_le_set_per_adv_enable(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_status *rp = data; + __u8 *sent; + + bt_dev_dbg(hdev, "status 0x%2.2x", rp->status); + + if (rp->status) + return rp->status; + + sent = hci_sent_cmd_data(hdev, HCI_OP_LE_SET_PER_ADV_ENABLE); + if (!sent) + return rp->status; + + hci_dev_lock(hdev); + + if (*sent) + hci_dev_set_flag(hdev, HCI_LE_PER_ADV); + else + hci_dev_clear_flag(hdev, HCI_LE_PER_ADV); + + hci_dev_unlock(hdev); + + return rp->status; +} + +#define HCI_CC_VL(_op, _func, _min, _max) \ +{ \ + .op = _op, \ + .func = _func, \ + .min_len = _min, \ + .max_len = _max, \ +} + +#define HCI_CC(_op, _func, _len) \ + HCI_CC_VL(_op, _func, _len, _len) + +#define HCI_CC_STATUS(_op, _func) \ + HCI_CC(_op, _func, sizeof(struct hci_ev_status)) + +static const struct hci_cc { + u16 op; + u8 (*func)(struct hci_dev *hdev, void *data, struct sk_buff *skb); + u16 min_len; + u16 max_len; +} hci_cc_table[] = { + HCI_CC_STATUS(HCI_OP_INQUIRY_CANCEL, hci_cc_inquiry_cancel), + HCI_CC_STATUS(HCI_OP_PERIODIC_INQ, hci_cc_periodic_inq), + HCI_CC_STATUS(HCI_OP_EXIT_PERIODIC_INQ, hci_cc_exit_periodic_inq), + HCI_CC_STATUS(HCI_OP_REMOTE_NAME_REQ_CANCEL, + hci_cc_remote_name_req_cancel), + HCI_CC(HCI_OP_ROLE_DISCOVERY, hci_cc_role_discovery, + sizeof(struct hci_rp_role_discovery)), + HCI_CC(HCI_OP_READ_LINK_POLICY, hci_cc_read_link_policy, + sizeof(struct hci_rp_read_link_policy)), + HCI_CC(HCI_OP_WRITE_LINK_POLICY, hci_cc_write_link_policy, + sizeof(struct hci_rp_write_link_policy)), + HCI_CC(HCI_OP_READ_DEF_LINK_POLICY, hci_cc_read_def_link_policy, + sizeof(struct hci_rp_read_def_link_policy)), + HCI_CC_STATUS(HCI_OP_WRITE_DEF_LINK_POLICY, + hci_cc_write_def_link_policy), + HCI_CC_STATUS(HCI_OP_RESET, hci_cc_reset), + HCI_CC(HCI_OP_READ_STORED_LINK_KEY, hci_cc_read_stored_link_key, + sizeof(struct hci_rp_read_stored_link_key)), + HCI_CC(HCI_OP_DELETE_STORED_LINK_KEY, hci_cc_delete_stored_link_key, + sizeof(struct hci_rp_delete_stored_link_key)), + HCI_CC_STATUS(HCI_OP_WRITE_LOCAL_NAME, hci_cc_write_local_name), + HCI_CC(HCI_OP_READ_LOCAL_NAME, hci_cc_read_local_name, + sizeof(struct hci_rp_read_local_name)), + HCI_CC_STATUS(HCI_OP_WRITE_AUTH_ENABLE, hci_cc_write_auth_enable), + HCI_CC_STATUS(HCI_OP_WRITE_ENCRYPT_MODE, hci_cc_write_encrypt_mode), + HCI_CC_STATUS(HCI_OP_WRITE_SCAN_ENABLE, hci_cc_write_scan_enable), + HCI_CC_STATUS(HCI_OP_SET_EVENT_FLT, hci_cc_set_event_filter), + HCI_CC(HCI_OP_READ_CLASS_OF_DEV, hci_cc_read_class_of_dev, + sizeof(struct hci_rp_read_class_of_dev)), + HCI_CC_STATUS(HCI_OP_WRITE_CLASS_OF_DEV, hci_cc_write_class_of_dev), + HCI_CC(HCI_OP_READ_VOICE_SETTING, hci_cc_read_voice_setting, + sizeof(struct hci_rp_read_voice_setting)), + HCI_CC_STATUS(HCI_OP_WRITE_VOICE_SETTING, hci_cc_write_voice_setting), + HCI_CC(HCI_OP_READ_NUM_SUPPORTED_IAC, 
hci_cc_read_num_supported_iac, + sizeof(struct hci_rp_read_num_supported_iac)), + HCI_CC_STATUS(HCI_OP_WRITE_SSP_MODE, hci_cc_write_ssp_mode), + HCI_CC_STATUS(HCI_OP_WRITE_SC_SUPPORT, hci_cc_write_sc_support), + HCI_CC(HCI_OP_READ_AUTH_PAYLOAD_TO, hci_cc_read_auth_payload_timeout, + sizeof(struct hci_rp_read_auth_payload_to)), + HCI_CC(HCI_OP_WRITE_AUTH_PAYLOAD_TO, hci_cc_write_auth_payload_timeout, + sizeof(struct hci_rp_write_auth_payload_to)), + HCI_CC(HCI_OP_READ_LOCAL_VERSION, hci_cc_read_local_version, + sizeof(struct hci_rp_read_local_version)), + HCI_CC(HCI_OP_READ_LOCAL_COMMANDS, hci_cc_read_local_commands, + sizeof(struct hci_rp_read_local_commands)), + HCI_CC(HCI_OP_READ_LOCAL_FEATURES, hci_cc_read_local_features, + sizeof(struct hci_rp_read_local_features)), + HCI_CC(HCI_OP_READ_LOCAL_EXT_FEATURES, hci_cc_read_local_ext_features, + sizeof(struct hci_rp_read_local_ext_features)), + HCI_CC(HCI_OP_READ_BUFFER_SIZE, hci_cc_read_buffer_size, + sizeof(struct hci_rp_read_buffer_size)), + HCI_CC(HCI_OP_READ_BD_ADDR, hci_cc_read_bd_addr, + sizeof(struct hci_rp_read_bd_addr)), + HCI_CC(HCI_OP_READ_LOCAL_PAIRING_OPTS, hci_cc_read_local_pairing_opts, + sizeof(struct hci_rp_read_local_pairing_opts)), + HCI_CC(HCI_OP_READ_PAGE_SCAN_ACTIVITY, hci_cc_read_page_scan_activity, + sizeof(struct hci_rp_read_page_scan_activity)), + HCI_CC_STATUS(HCI_OP_WRITE_PAGE_SCAN_ACTIVITY, + hci_cc_write_page_scan_activity), + HCI_CC(HCI_OP_READ_PAGE_SCAN_TYPE, hci_cc_read_page_scan_type, + sizeof(struct hci_rp_read_page_scan_type)), + HCI_CC_STATUS(HCI_OP_WRITE_PAGE_SCAN_TYPE, hci_cc_write_page_scan_type), + HCI_CC(HCI_OP_READ_DATA_BLOCK_SIZE, hci_cc_read_data_block_size, + sizeof(struct hci_rp_read_data_block_size)), + HCI_CC(HCI_OP_READ_FLOW_CONTROL_MODE, hci_cc_read_flow_control_mode, + sizeof(struct hci_rp_read_flow_control_mode)), + HCI_CC(HCI_OP_READ_LOCAL_AMP_INFO, hci_cc_read_local_amp_info, + sizeof(struct hci_rp_read_local_amp_info)), + HCI_CC(HCI_OP_READ_CLOCK, hci_cc_read_clock, + sizeof(struct hci_rp_read_clock)), + HCI_CC(HCI_OP_READ_ENC_KEY_SIZE, hci_cc_read_enc_key_size, + sizeof(struct hci_rp_read_enc_key_size)), + HCI_CC(HCI_OP_READ_INQ_RSP_TX_POWER, hci_cc_read_inq_rsp_tx_power, + sizeof(struct hci_rp_read_inq_rsp_tx_power)), + HCI_CC(HCI_OP_READ_DEF_ERR_DATA_REPORTING, + hci_cc_read_def_err_data_reporting, + sizeof(struct hci_rp_read_def_err_data_reporting)), + HCI_CC_STATUS(HCI_OP_WRITE_DEF_ERR_DATA_REPORTING, + hci_cc_write_def_err_data_reporting), + HCI_CC(HCI_OP_PIN_CODE_REPLY, hci_cc_pin_code_reply, + sizeof(struct hci_rp_pin_code_reply)), + HCI_CC(HCI_OP_PIN_CODE_NEG_REPLY, hci_cc_pin_code_neg_reply, + sizeof(struct hci_rp_pin_code_neg_reply)), + HCI_CC(HCI_OP_READ_LOCAL_OOB_DATA, hci_cc_read_local_oob_data, + sizeof(struct hci_rp_read_local_oob_data)), + HCI_CC(HCI_OP_READ_LOCAL_OOB_EXT_DATA, hci_cc_read_local_oob_ext_data, + sizeof(struct hci_rp_read_local_oob_ext_data)), + HCI_CC(HCI_OP_LE_READ_BUFFER_SIZE, hci_cc_le_read_buffer_size, + sizeof(struct hci_rp_le_read_buffer_size)), + HCI_CC(HCI_OP_LE_READ_LOCAL_FEATURES, hci_cc_le_read_local_features, + sizeof(struct hci_rp_le_read_local_features)), + HCI_CC(HCI_OP_LE_READ_ADV_TX_POWER, hci_cc_le_read_adv_tx_power, + sizeof(struct hci_rp_le_read_adv_tx_power)), + HCI_CC(HCI_OP_USER_CONFIRM_REPLY, hci_cc_user_confirm_reply, + sizeof(struct hci_rp_user_confirm_reply)), + HCI_CC(HCI_OP_USER_CONFIRM_NEG_REPLY, hci_cc_user_confirm_neg_reply, + sizeof(struct hci_rp_user_confirm_reply)), + HCI_CC(HCI_OP_USER_PASSKEY_REPLY, 
hci_cc_user_passkey_reply, + sizeof(struct hci_rp_user_confirm_reply)), + HCI_CC(HCI_OP_USER_PASSKEY_NEG_REPLY, hci_cc_user_passkey_neg_reply, + sizeof(struct hci_rp_user_confirm_reply)), + HCI_CC_STATUS(HCI_OP_LE_SET_RANDOM_ADDR, hci_cc_le_set_random_addr), + HCI_CC_STATUS(HCI_OP_LE_SET_ADV_ENABLE, hci_cc_le_set_adv_enable), + HCI_CC_STATUS(HCI_OP_LE_SET_SCAN_PARAM, hci_cc_le_set_scan_param), + HCI_CC_STATUS(HCI_OP_LE_SET_SCAN_ENABLE, hci_cc_le_set_scan_enable), + HCI_CC(HCI_OP_LE_READ_ACCEPT_LIST_SIZE, + hci_cc_le_read_accept_list_size, + sizeof(struct hci_rp_le_read_accept_list_size)), + HCI_CC_STATUS(HCI_OP_LE_CLEAR_ACCEPT_LIST, hci_cc_le_clear_accept_list), + HCI_CC_STATUS(HCI_OP_LE_ADD_TO_ACCEPT_LIST, + hci_cc_le_add_to_accept_list), + HCI_CC_STATUS(HCI_OP_LE_DEL_FROM_ACCEPT_LIST, + hci_cc_le_del_from_accept_list), + HCI_CC(HCI_OP_LE_READ_SUPPORTED_STATES, hci_cc_le_read_supported_states, + sizeof(struct hci_rp_le_read_supported_states)), + HCI_CC(HCI_OP_LE_READ_DEF_DATA_LEN, hci_cc_le_read_def_data_len, + sizeof(struct hci_rp_le_read_def_data_len)), + HCI_CC_STATUS(HCI_OP_LE_WRITE_DEF_DATA_LEN, + hci_cc_le_write_def_data_len), + HCI_CC_STATUS(HCI_OP_LE_ADD_TO_RESOLV_LIST, + hci_cc_le_add_to_resolv_list), + HCI_CC_STATUS(HCI_OP_LE_DEL_FROM_RESOLV_LIST, + hci_cc_le_del_from_resolv_list), + HCI_CC_STATUS(HCI_OP_LE_CLEAR_RESOLV_LIST, + hci_cc_le_clear_resolv_list), + HCI_CC(HCI_OP_LE_READ_RESOLV_LIST_SIZE, hci_cc_le_read_resolv_list_size, + sizeof(struct hci_rp_le_read_resolv_list_size)), + HCI_CC_STATUS(HCI_OP_LE_SET_ADDR_RESOLV_ENABLE, + hci_cc_le_set_addr_resolution_enable), + HCI_CC(HCI_OP_LE_READ_MAX_DATA_LEN, hci_cc_le_read_max_data_len, + sizeof(struct hci_rp_le_read_max_data_len)), + HCI_CC_STATUS(HCI_OP_WRITE_LE_HOST_SUPPORTED, + hci_cc_write_le_host_supported), + HCI_CC_STATUS(HCI_OP_LE_SET_ADV_PARAM, hci_cc_set_adv_param), + HCI_CC(HCI_OP_READ_RSSI, hci_cc_read_rssi, + sizeof(struct hci_rp_read_rssi)), + HCI_CC(HCI_OP_READ_TX_POWER, hci_cc_read_tx_power, + sizeof(struct hci_rp_read_tx_power)), + HCI_CC_STATUS(HCI_OP_WRITE_SSP_DEBUG_MODE, hci_cc_write_ssp_debug_mode), + HCI_CC_STATUS(HCI_OP_LE_SET_EXT_SCAN_PARAMS, + hci_cc_le_set_ext_scan_param), + HCI_CC_STATUS(HCI_OP_LE_SET_EXT_SCAN_ENABLE, + hci_cc_le_set_ext_scan_enable), + HCI_CC_STATUS(HCI_OP_LE_SET_DEFAULT_PHY, hci_cc_le_set_default_phy), + HCI_CC(HCI_OP_LE_READ_NUM_SUPPORTED_ADV_SETS, + hci_cc_le_read_num_adv_sets, + sizeof(struct hci_rp_le_read_num_supported_adv_sets)), + HCI_CC(HCI_OP_LE_SET_EXT_ADV_PARAMS, hci_cc_set_ext_adv_param, + sizeof(struct hci_rp_le_set_ext_adv_params)), + HCI_CC_STATUS(HCI_OP_LE_SET_EXT_ADV_ENABLE, + hci_cc_le_set_ext_adv_enable), + HCI_CC_STATUS(HCI_OP_LE_SET_ADV_SET_RAND_ADDR, + hci_cc_le_set_adv_set_random_addr), + HCI_CC_STATUS(HCI_OP_LE_REMOVE_ADV_SET, hci_cc_le_remove_adv_set), + HCI_CC_STATUS(HCI_OP_LE_CLEAR_ADV_SETS, hci_cc_le_clear_adv_sets), + HCI_CC_STATUS(HCI_OP_LE_SET_PER_ADV_PARAMS, hci_cc_set_per_adv_param), + HCI_CC_STATUS(HCI_OP_LE_SET_PER_ADV_ENABLE, + hci_cc_le_set_per_adv_enable), + HCI_CC(HCI_OP_LE_READ_TRANSMIT_POWER, hci_cc_le_read_transmit_power, + sizeof(struct hci_rp_le_read_transmit_power)), + HCI_CC_STATUS(HCI_OP_LE_SET_PRIVACY_MODE, hci_cc_le_set_privacy_mode), + HCI_CC(HCI_OP_LE_READ_BUFFER_SIZE_V2, hci_cc_le_read_buffer_size_v2, + sizeof(struct hci_rp_le_read_buffer_size_v2)), + HCI_CC_VL(HCI_OP_LE_SET_CIG_PARAMS, hci_cc_le_set_cig_params, + sizeof(struct hci_rp_le_set_cig_params), HCI_MAX_EVENT_SIZE), + HCI_CC(HCI_OP_LE_SETUP_ISO_PATH, 
hci_cc_le_setup_iso_path, + sizeof(struct hci_rp_le_setup_iso_path)), +}; + +static u8 hci_cc_func(struct hci_dev *hdev, const struct hci_cc *cc, + struct sk_buff *skb) +{ + void *data; + + if (skb->len < cc->min_len) { + bt_dev_err(hdev, "unexpected cc 0x%4.4x length: %u < %u", + cc->op, skb->len, cc->min_len); + return HCI_ERROR_UNSPECIFIED; + } + + /* Just warn if the length is over max_len; it may still be possible to + * partially parse the cc, so leave it to the callback to decide if that + * is acceptable. + */ + if (skb->len > cc->max_len) + bt_dev_warn(hdev, "unexpected cc 0x%4.4x length: %u > %u", + cc->op, skb->len, cc->max_len); + + data = hci_cc_skb_pull(hdev, skb, cc->op, cc->min_len); + if (!data) + return HCI_ERROR_UNSPECIFIED; + + return cc->func(hdev, data, skb); +} + +static void hci_cmd_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb, u16 *opcode, u8 *status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb) +{ + struct hci_ev_cmd_complete *ev = data; + int i; + + *opcode = __le16_to_cpu(ev->opcode); + + bt_dev_dbg(hdev, "opcode 0x%4.4x", *opcode); + + for (i = 0; i < ARRAY_SIZE(hci_cc_table); i++) { + if (hci_cc_table[i].op == *opcode) { + *status = hci_cc_func(hdev, &hci_cc_table[i], skb); + break; + } + } + + if (i == ARRAY_SIZE(hci_cc_table)) { + /* Unknown opcode, assume byte 0 contains the status, so + * that e.g. __hci_cmd_sync() properly returns errors + * for vendor specific commands sent by HCI drivers. + * If a vendor doesn't actually follow this convention we may + * need to introduce a vendor CC table in order to properly set + * the status. + */ + *status = skb->data[0]; + } + + handle_cmd_cnt_and_timer(hdev, ev->ncmd); + + hci_req_cmd_complete(hdev, *opcode, *status, req_complete, + req_complete_skb); + + if (hci_dev_test_flag(hdev, HCI_CMD_PENDING)) { + bt_dev_err(hdev, + "unexpected event for opcode 0x%4.4x", *opcode); + return; + } + + if (atomic_read(&hdev->cmd_cnt) && !skb_queue_empty(&hdev->cmd_q)) + queue_work(hdev->workqueue, &hdev->cmd_work); +} + +static void hci_cs_le_create_cis(struct hci_dev *hdev, u8 status) +{ + struct hci_cp_le_create_cis *cp; + int i; + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + if (!status) + return; + + cp = hci_sent_cmd_data(hdev, HCI_OP_LE_CREATE_CIS); + if (!cp) + return; + + hci_dev_lock(hdev); + + /* Remove connection if command failed */ + for (i = 0; cp->num_cis; cp->num_cis--, i++) { + struct hci_conn *conn; + u16 handle; + + handle = __le16_to_cpu(cp->cis[i].cis_handle); + + conn = hci_conn_hash_lookup_handle(hdev, handle); + if (conn) { + conn->state = BT_CLOSED; + hci_connect_cfm(conn, status); + hci_conn_del(conn); + } + } + + hci_dev_unlock(hdev); +} + +#define HCI_CS(_op, _func) \ +{ \ + .op = _op, \ + .func = _func, \ +} + +static const struct hci_cs { + u16 op; + void (*func)(struct hci_dev *hdev, __u8 status); +} hci_cs_table[] = { + HCI_CS(HCI_OP_INQUIRY, hci_cs_inquiry), + HCI_CS(HCI_OP_CREATE_CONN, hci_cs_create_conn), + HCI_CS(HCI_OP_DISCONNECT, hci_cs_disconnect), + HCI_CS(HCI_OP_ADD_SCO, hci_cs_add_sco), + HCI_CS(HCI_OP_AUTH_REQUESTED, hci_cs_auth_requested), + HCI_CS(HCI_OP_SET_CONN_ENCRYPT, hci_cs_set_conn_encrypt), + HCI_CS(HCI_OP_REMOTE_NAME_REQ, hci_cs_remote_name_req), + HCI_CS(HCI_OP_READ_REMOTE_FEATURES, hci_cs_read_remote_features), + HCI_CS(HCI_OP_READ_REMOTE_EXT_FEATURES, + hci_cs_read_remote_ext_features), + HCI_CS(HCI_OP_SETUP_SYNC_CONN, hci_cs_setup_sync_conn), + HCI_CS(HCI_OP_ENHANCED_SETUP_SYNC_CONN, +
hci_cs_enhanced_setup_sync_conn), + HCI_CS(HCI_OP_SNIFF_MODE, hci_cs_sniff_mode), + HCI_CS(HCI_OP_EXIT_SNIFF_MODE, hci_cs_exit_sniff_mode), + HCI_CS(HCI_OP_SWITCH_ROLE, hci_cs_switch_role), + HCI_CS(HCI_OP_LE_CREATE_CONN, hci_cs_le_create_conn), + HCI_CS(HCI_OP_LE_READ_REMOTE_FEATURES, hci_cs_le_read_remote_features), + HCI_CS(HCI_OP_LE_START_ENC, hci_cs_le_start_enc), + HCI_CS(HCI_OP_LE_EXT_CREATE_CONN, hci_cs_le_ext_create_conn), + HCI_CS(HCI_OP_LE_CREATE_CIS, hci_cs_le_create_cis), + HCI_CS(HCI_OP_LE_CREATE_BIG, hci_cs_le_create_big), +}; + +static void hci_cmd_status_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb, u16 *opcode, u8 *status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb) +{ + struct hci_ev_cmd_status *ev = data; + int i; + + *opcode = __le16_to_cpu(ev->opcode); + *status = ev->status; + + bt_dev_dbg(hdev, "opcode 0x%4.4x", *opcode); + + for (i = 0; i < ARRAY_SIZE(hci_cs_table); i++) { + if (hci_cs_table[i].op == *opcode) { + hci_cs_table[i].func(hdev, ev->status); + break; + } + } + + handle_cmd_cnt_and_timer(hdev, ev->ncmd); + + /* Indicate request completion if the command failed. Also, if + * we're not waiting for a special event and we get a success + * command status we should try to flag the request as completed + * (since for this kind of commands there will not be a command + * complete event). + */ + if (ev->status || (hdev->sent_cmd && !hci_skb_event(hdev->sent_cmd))) { + hci_req_cmd_complete(hdev, *opcode, ev->status, req_complete, + req_complete_skb); + if (hci_dev_test_flag(hdev, HCI_CMD_PENDING)) { + bt_dev_err(hdev, "unexpected event for opcode 0x%4.4x", + *opcode); + return; + } + } + + if (atomic_read(&hdev->cmd_cnt) && !skb_queue_empty(&hdev->cmd_q)) + queue_work(hdev->workqueue, &hdev->cmd_work); +} + +static void hci_hardware_error_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_hardware_error *ev = data; + + bt_dev_dbg(hdev, "code 0x%2.2x", ev->code); + + hdev->hw_error_code = ev->code; + + queue_work(hdev->req_workqueue, &hdev->error_reset); +} + +static void hci_role_change_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_role_change *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (conn) { + if (!ev->status) + conn->role = ev->role; + + clear_bit(HCI_CONN_RSWITCH_PEND, &conn->flags); + + hci_role_switch_cfm(conn, ev->status, ev->role); + } + + hci_dev_unlock(hdev); +} + +static void hci_num_comp_pkts_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_num_comp_pkts *ev = data; + int i; + + if (!hci_ev_skb_pull(hdev, skb, HCI_EV_NUM_COMP_PKTS, + flex_array_size(ev, handles, ev->num))) + return; + + if (hdev->flow_ctl_mode != HCI_FLOW_CTL_MODE_PACKET_BASED) { + bt_dev_err(hdev, "wrong event for mode %d", hdev->flow_ctl_mode); + return; + } + + bt_dev_dbg(hdev, "num %d", ev->num); + + for (i = 0; i < ev->num; i++) { + struct hci_comp_pkts_info *info = &ev->handles[i]; + struct hci_conn *conn; + __u16 handle, count; + + handle = __le16_to_cpu(info->handle); + count = __le16_to_cpu(info->count); + + conn = hci_conn_hash_lookup_handle(hdev, handle); + if (!conn) + continue; + + conn->sent -= count; + + switch (conn->type) { + case ACL_LINK: + hdev->acl_cnt += count; + if (hdev->acl_cnt > hdev->acl_pkts) + hdev->acl_cnt = hdev->acl_pkts; + break; + + case LE_LINK: + if (hdev->le_pkts) { + 
hdev->le_cnt += count; + if (hdev->le_cnt > hdev->le_pkts) + hdev->le_cnt = hdev->le_pkts; + } else { + hdev->acl_cnt += count; + if (hdev->acl_cnt > hdev->acl_pkts) + hdev->acl_cnt = hdev->acl_pkts; + } + break; + + case SCO_LINK: + hdev->sco_cnt += count; + if (hdev->sco_cnt > hdev->sco_pkts) + hdev->sco_cnt = hdev->sco_pkts; + break; + + case ISO_LINK: + if (hdev->iso_pkts) { + hdev->iso_cnt += count; + if (hdev->iso_cnt > hdev->iso_pkts) + hdev->iso_cnt = hdev->iso_pkts; + } else if (hdev->le_pkts) { + hdev->le_cnt += count; + if (hdev->le_cnt > hdev->le_pkts) + hdev->le_cnt = hdev->le_pkts; + } else { + hdev->acl_cnt += count; + if (hdev->acl_cnt > hdev->acl_pkts) + hdev->acl_cnt = hdev->acl_pkts; + } + break; + + default: + bt_dev_err(hdev, "unknown type %d conn %p", + conn->type, conn); + break; + } + } + + queue_work(hdev->workqueue, &hdev->tx_work); +} + +static struct hci_conn *__hci_conn_lookup_handle(struct hci_dev *hdev, + __u16 handle) +{ + struct hci_chan *chan; + + switch (hdev->dev_type) { + case HCI_PRIMARY: + return hci_conn_hash_lookup_handle(hdev, handle); + case HCI_AMP: + chan = hci_chan_lookup_handle(hdev, handle); + if (chan) + return chan->conn; + break; + default: + bt_dev_err(hdev, "unknown dev_type %d", hdev->dev_type); + break; + } + + return NULL; +} + +static void hci_num_comp_blocks_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_num_comp_blocks *ev = data; + int i; + + if (!hci_ev_skb_pull(hdev, skb, HCI_EV_NUM_COMP_BLOCKS, + flex_array_size(ev, handles, ev->num_hndl))) + return; + + if (hdev->flow_ctl_mode != HCI_FLOW_CTL_MODE_BLOCK_BASED) { + bt_dev_err(hdev, "wrong event for mode %d", + hdev->flow_ctl_mode); + return; + } + + bt_dev_dbg(hdev, "num_blocks %d num_hndl %d", ev->num_blocks, + ev->num_hndl); + + for (i = 0; i < ev->num_hndl; i++) { + struct hci_comp_blocks_info *info = &ev->handles[i]; + struct hci_conn *conn = NULL; + __u16 handle, block_count; + + handle = __le16_to_cpu(info->handle); + block_count = __le16_to_cpu(info->blocks); + + conn = __hci_conn_lookup_handle(hdev, handle); + if (!conn) + continue; + + conn->sent -= block_count; + + switch (conn->type) { + case ACL_LINK: + case AMP_LINK: + hdev->block_cnt += block_count; + if (hdev->block_cnt > hdev->num_blocks) + hdev->block_cnt = hdev->num_blocks; + break; + + default: + bt_dev_err(hdev, "unknown type %d conn %p", + conn->type, conn); + break; + } + } + + queue_work(hdev->workqueue, &hdev->tx_work); +} + +static void hci_mode_change_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_mode_change *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn) { + conn->mode = ev->mode; + + if (!test_and_clear_bit(HCI_CONN_MODE_CHANGE_PEND, + &conn->flags)) { + if (conn->mode == HCI_CM_ACTIVE) + set_bit(HCI_CONN_POWER_SAVE, &conn->flags); + else + clear_bit(HCI_CONN_POWER_SAVE, &conn->flags); + } + + if (test_and_clear_bit(HCI_CONN_SCO_SETUP_PEND, &conn->flags)) + hci_sco_setup(conn, ev->status); + } + + hci_dev_unlock(hdev); +} + +static void hci_pin_code_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_pin_code_req *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + goto unlock; + + if (conn->state == BT_CONNECTED) { + hci_conn_hold(conn); + 
conn->disc_timeout = HCI_PAIRING_TIMEOUT; + hci_conn_drop(conn); + } + + if (!hci_dev_test_flag(hdev, HCI_BONDABLE) && + !test_bit(HCI_CONN_AUTH_INITIATOR, &conn->flags)) { + hci_send_cmd(hdev, HCI_OP_PIN_CODE_NEG_REPLY, + sizeof(ev->bdaddr), &ev->bdaddr); + } else if (hci_dev_test_flag(hdev, HCI_MGMT)) { + u8 secure; + + if (conn->pending_sec_level == BT_SECURITY_HIGH) + secure = 1; + else + secure = 0; + + mgmt_pin_code_request(hdev, &ev->bdaddr, secure); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void conn_set_key(struct hci_conn *conn, u8 key_type, u8 pin_len) +{ + if (key_type == HCI_LK_CHANGED_COMBINATION) + return; + + conn->pin_length = pin_len; + conn->key_type = key_type; + + switch (key_type) { + case HCI_LK_LOCAL_UNIT: + case HCI_LK_REMOTE_UNIT: + case HCI_LK_DEBUG_COMBINATION: + return; + case HCI_LK_COMBINATION: + if (pin_len == 16) + conn->pending_sec_level = BT_SECURITY_HIGH; + else + conn->pending_sec_level = BT_SECURITY_MEDIUM; + break; + case HCI_LK_UNAUTH_COMBINATION_P192: + case HCI_LK_UNAUTH_COMBINATION_P256: + conn->pending_sec_level = BT_SECURITY_MEDIUM; + break; + case HCI_LK_AUTH_COMBINATION_P192: + conn->pending_sec_level = BT_SECURITY_HIGH; + break; + case HCI_LK_AUTH_COMBINATION_P256: + conn->pending_sec_level = BT_SECURITY_FIPS; + break; + } +} + +static void hci_link_key_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_link_key_req *ev = data; + struct hci_cp_link_key_reply cp; + struct hci_conn *conn; + struct link_key *key; + + bt_dev_dbg(hdev, ""); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + return; + + hci_dev_lock(hdev); + + key = hci_find_link_key(hdev, &ev->bdaddr); + if (!key) { + bt_dev_dbg(hdev, "link key not found for %pMR", &ev->bdaddr); + goto not_found; + } + + bt_dev_dbg(hdev, "found key type %u for %pMR", key->type, &ev->bdaddr); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (conn) { + clear_bit(HCI_CONN_NEW_LINK_KEY, &conn->flags); + + if ((key->type == HCI_LK_UNAUTH_COMBINATION_P192 || + key->type == HCI_LK_UNAUTH_COMBINATION_P256) && + conn->auth_type != 0xff && (conn->auth_type & 0x01)) { + bt_dev_dbg(hdev, "ignoring unauthenticated key"); + goto not_found; + } + + if (key->type == HCI_LK_COMBINATION && key->pin_len < 16 && + (conn->pending_sec_level == BT_SECURITY_HIGH || + conn->pending_sec_level == BT_SECURITY_FIPS)) { + bt_dev_dbg(hdev, "ignoring key unauthenticated for high security"); + goto not_found; + } + + conn_set_key(conn, key->type, key->pin_len); + } + + bacpy(&cp.bdaddr, &ev->bdaddr); + memcpy(cp.link_key, key->val, HCI_LINK_KEY_SIZE); + + hci_send_cmd(hdev, HCI_OP_LINK_KEY_REPLY, sizeof(cp), &cp); + + hci_dev_unlock(hdev); + + return; + +not_found: + hci_send_cmd(hdev, HCI_OP_LINK_KEY_NEG_REPLY, 6, &ev->bdaddr); + hci_dev_unlock(hdev); +} + +static void hci_link_key_notify_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_link_key_notify *ev = data; + struct hci_conn *conn; + struct link_key *key; + bool persistent; + u8 pin_len = 0; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + goto unlock; + + /* Ignore NULL link key against CVE-2020-26555 */ + if (!crypto_memneq(ev->link_key, ZERO_KEY, HCI_LINK_KEY_SIZE)) { + bt_dev_dbg(hdev, "Ignore NULL link key (ZERO KEY) for %pMR", + &ev->bdaddr); + hci_disconnect(conn, HCI_ERROR_AUTH_FAILURE); + hci_conn_drop(conn); + goto unlock; + } + + hci_conn_hold(conn); + conn->disc_timeout = 
HCI_DISCONN_TIMEOUT; + hci_conn_drop(conn); + + set_bit(HCI_CONN_NEW_LINK_KEY, &conn->flags); + conn_set_key(conn, ev->key_type, conn->pin_length); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + goto unlock; + + key = hci_add_link_key(hdev, conn, &ev->bdaddr, ev->link_key, + ev->key_type, pin_len, &persistent); + if (!key) + goto unlock; + + /* Update connection information since adding the key will have + * fixed up the type in the case of changed combination keys. + */ + if (ev->key_type == HCI_LK_CHANGED_COMBINATION) + conn_set_key(conn, key->type, key->pin_len); + + mgmt_new_link_key(hdev, key, persistent); + + /* Keep debug keys around only if the HCI_KEEP_DEBUG_KEYS flag + * is set. If it's not set simply remove the key from the kernel + * list (we've still notified user space about it but with + * store_hint being 0). + */ + if (key->type == HCI_LK_DEBUG_COMBINATION && + !hci_dev_test_flag(hdev, HCI_KEEP_DEBUG_KEYS)) { + list_del_rcu(&key->list); + kfree_rcu(key, rcu); + goto unlock; + } + + if (persistent) + clear_bit(HCI_CONN_FLUSH_KEY, &conn->flags); + else + set_bit(HCI_CONN_FLUSH_KEY, &conn->flags); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_clock_offset_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_clock_offset *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn && !ev->status) { + struct inquiry_entry *ie; + + ie = hci_inquiry_cache_lookup(hdev, &conn->dst); + if (ie) { + ie->data.clock_offset = ev->clock_offset; + ie->timestamp = jiffies; + } + } + + hci_dev_unlock(hdev); +} + +static void hci_pkt_type_change_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_pkt_type_change *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn && !ev->status) + conn->pkt_type = __le16_to_cpu(ev->pkt_type); + + hci_dev_unlock(hdev); +} + +static void hci_pscan_rep_mode_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_pscan_rep_mode *ev = data; + struct inquiry_entry *ie; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + ie = hci_inquiry_cache_lookup(hdev, &ev->bdaddr); + if (ie) { + ie->data.pscan_rep_mode = ev->pscan_rep_mode; + ie->timestamp = jiffies; + } + + hci_dev_unlock(hdev); +} + +static void hci_inquiry_result_with_rssi_evt(struct hci_dev *hdev, void *edata, + struct sk_buff *skb) +{ + struct hci_ev_inquiry_result_rssi *ev = edata; + struct inquiry_data data; + int i; + + bt_dev_dbg(hdev, "num_rsp %d", ev->num); + + if (!ev->num) + return; + + if (hci_dev_test_flag(hdev, HCI_PERIODIC_INQ)) + return; + + hci_dev_lock(hdev); + + if (skb->len == array_size(ev->num, + sizeof(struct inquiry_info_rssi_pscan))) { + struct inquiry_info_rssi_pscan *info; + + for (i = 0; i < ev->num; i++) { + u32 flags; + + info = hci_ev_skb_pull(hdev, skb, + HCI_EV_INQUIRY_RESULT_WITH_RSSI, + sizeof(*info)); + if (!info) { + bt_dev_err(hdev, "Malformed HCI Event: 0x%2.2x", + HCI_EV_INQUIRY_RESULT_WITH_RSSI); + goto unlock; + } + + bacpy(&data.bdaddr, &info->bdaddr); + data.pscan_rep_mode = info->pscan_rep_mode; + data.pscan_period_mode = info->pscan_period_mode; + data.pscan_mode = info->pscan_mode; + memcpy(data.dev_class, info->dev_class, 3); + data.clock_offset = info->clock_offset; + data.rssi = info->rssi; + 
data.ssp_mode = 0x00; + + flags = hci_inquiry_cache_update(hdev, &data, false); + + mgmt_device_found(hdev, &info->bdaddr, ACL_LINK, 0x00, + info->dev_class, info->rssi, + flags, NULL, 0, NULL, 0, 0); + } + } else if (skb->len == array_size(ev->num, + sizeof(struct inquiry_info_rssi))) { + struct inquiry_info_rssi *info; + + for (i = 0; i < ev->num; i++) { + u32 flags; + + info = hci_ev_skb_pull(hdev, skb, + HCI_EV_INQUIRY_RESULT_WITH_RSSI, + sizeof(*info)); + if (!info) { + bt_dev_err(hdev, "Malformed HCI Event: 0x%2.2x", + HCI_EV_INQUIRY_RESULT_WITH_RSSI); + goto unlock; + } + + bacpy(&data.bdaddr, &info->bdaddr); + data.pscan_rep_mode = info->pscan_rep_mode; + data.pscan_period_mode = info->pscan_period_mode; + data.pscan_mode = 0x00; + memcpy(data.dev_class, info->dev_class, 3); + data.clock_offset = info->clock_offset; + data.rssi = info->rssi; + data.ssp_mode = 0x00; + + flags = hci_inquiry_cache_update(hdev, &data, false); + + mgmt_device_found(hdev, &info->bdaddr, ACL_LINK, 0x00, + info->dev_class, info->rssi, + flags, NULL, 0, NULL, 0, 0); + } + } else { + bt_dev_err(hdev, "Malformed HCI Event: 0x%2.2x", + HCI_EV_INQUIRY_RESULT_WITH_RSSI); + } +unlock: + hci_dev_unlock(hdev); +} + +static void hci_remote_ext_features_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_remote_ext_features *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + if (ev->page < HCI_MAX_PAGES) + memcpy(conn->features[ev->page], ev->features, 8); + + if (!ev->status && ev->page == 0x01) { + struct inquiry_entry *ie; + + ie = hci_inquiry_cache_lookup(hdev, &conn->dst); + if (ie) + ie->data.ssp_mode = (ev->features[0] & LMP_HOST_SSP); + + if (ev->features[0] & LMP_HOST_SSP) { + set_bit(HCI_CONN_SSP_ENABLED, &conn->flags); + } else { + /* It is mandatory by the Bluetooth specification that + * Extended Inquiry Results are only used when Secure + * Simple Pairing is enabled, but some devices violate + * this. 
+ * + * To make these devices work, the internal SSP + * enabled flag needs to be cleared if the remote host + * features do not indicate SSP support */ + clear_bit(HCI_CONN_SSP_ENABLED, &conn->flags); + } + + if (ev->features[0] & LMP_HOST_SC) + set_bit(HCI_CONN_SC_ENABLED, &conn->flags); + } + + if (conn->state != BT_CONFIG) + goto unlock; + + if (!ev->status && !test_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) { + struct hci_cp_remote_name_req cp; + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, &conn->dst); + cp.pscan_rep_mode = 0x02; + hci_send_cmd(hdev, HCI_OP_REMOTE_NAME_REQ, sizeof(cp), &cp); + } else if (!test_and_set_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) + mgmt_device_connected(hdev, conn, NULL, 0); + + if (!hci_outgoing_auth_needed(hdev, conn)) { + conn->state = BT_CONNECTED; + hci_connect_cfm(conn, ev->status); + hci_conn_drop(conn); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_sync_conn_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_sync_conn_complete *ev = data; + struct hci_conn *conn; + u8 status = ev->status; + + switch (ev->link_type) { + case SCO_LINK: + case ESCO_LINK: + break; + default: + /* As per Core 5.3 Vol 4 Part E 7.7.35 (p.2219), Link_Type + * for HCI_Synchronous_Connection_Complete is limited to + * either SCO or eSCO + */ + bt_dev_err(hdev, "Ignoring connect complete event for invalid link type"); + return; + } + + bt_dev_dbg(hdev, "status 0x%2.2x", status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ev->link_type, &ev->bdaddr); + if (!conn) { + if (ev->link_type == ESCO_LINK) + goto unlock; + + /* When the link type in the event indicates SCO connection + * and lookup of the connection object fails, then check + * if an eSCO connection object exists. + * + * The core limits the synchronous connections to either + * SCO or eSCO. The eSCO connection is preferred and tried + * to be setup first and until successfully established, + * the link type will be hinted as eSCO. + */ + conn = hci_conn_hash_lookup_ba(hdev, ESCO_LINK, &ev->bdaddr); + if (!conn) + goto unlock; + } + + /* The HCI_Synchronous_Connection_Complete event is only sent once per connection. + * Processing it more than once per connection can corrupt kernel memory. + * + * As the connection handle is set here for the first time, it indicates + * whether the connection is already set up. 
+ */ + if (conn->handle != HCI_CONN_HANDLE_UNSET) { + bt_dev_err(hdev, "Ignoring HCI_Sync_Conn_Complete event for existing connection"); + goto unlock; + } + + switch (status) { + case 0x00: + conn->handle = __le16_to_cpu(ev->handle); + if (conn->handle > HCI_CONN_HANDLE_MAX) { + bt_dev_err(hdev, "Invalid handle: 0x%4.4x > 0x%4.4x", + conn->handle, HCI_CONN_HANDLE_MAX); + status = HCI_ERROR_INVALID_PARAMETERS; + conn->state = BT_CLOSED; + break; + } + + conn->state = BT_CONNECTED; + conn->type = ev->link_type; + + hci_debugfs_create_conn(conn); + hci_conn_add_sysfs(conn); + break; + + case 0x10: /* Connection Accept Timeout */ + case 0x0d: /* Connection Rejected due to Limited Resources */ + case 0x11: /* Unsupported Feature or Parameter Value */ + case 0x1c: /* SCO interval rejected */ + case 0x1a: /* Unsupported Remote Feature */ + case 0x1e: /* Invalid LMP Parameters */ + case 0x1f: /* Unspecified error */ + case 0x20: /* Unsupported LMP Parameter value */ + if (conn->out) { + conn->pkt_type = (hdev->esco_type & SCO_ESCO_MASK) | + (hdev->esco_type & EDR_ESCO_MASK); + if (hci_setup_sync(conn, conn->link->handle)) + goto unlock; + } + fallthrough; + + default: + conn->state = BT_CLOSED; + break; + } + + bt_dev_dbg(hdev, "SCO connected with air mode: %02x", ev->air_mode); + /* Notify only in case of SCO over HCI transport data path which + * is zero and non-zero value shall be non-HCI transport data path + */ + if (conn->codec.data_path == 0 && hdev->notify) { + switch (ev->air_mode) { + case 0x02: + hdev->notify(hdev, HCI_NOTIFY_ENABLE_SCO_CVSD); + break; + case 0x03: + hdev->notify(hdev, HCI_NOTIFY_ENABLE_SCO_TRANSP); + break; + } + } + + hci_connect_cfm(conn, status); + if (status) + hci_conn_del(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static inline size_t eir_get_length(u8 *eir, size_t eir_len) +{ + size_t parsed = 0; + + while (parsed < eir_len) { + u8 field_len = eir[0]; + + if (field_len == 0) + return parsed; + + parsed += field_len + 1; + eir += field_len + 1; + } + + return eir_len; +} + +static void hci_extended_inquiry_result_evt(struct hci_dev *hdev, void *edata, + struct sk_buff *skb) +{ + struct hci_ev_ext_inquiry_result *ev = edata; + struct inquiry_data data; + size_t eir_len; + int i; + + if (!hci_ev_skb_pull(hdev, skb, HCI_EV_EXTENDED_INQUIRY_RESULT, + flex_array_size(ev, info, ev->num))) + return; + + bt_dev_dbg(hdev, "num %d", ev->num); + + if (!ev->num) + return; + + if (hci_dev_test_flag(hdev, HCI_PERIODIC_INQ)) + return; + + hci_dev_lock(hdev); + + for (i = 0; i < ev->num; i++) { + struct extended_inquiry_info *info = &ev->info[i]; + u32 flags; + bool name_known; + + bacpy(&data.bdaddr, &info->bdaddr); + data.pscan_rep_mode = info->pscan_rep_mode; + data.pscan_period_mode = info->pscan_period_mode; + data.pscan_mode = 0x00; + memcpy(data.dev_class, info->dev_class, 3); + data.clock_offset = info->clock_offset; + data.rssi = info->rssi; + data.ssp_mode = 0x01; + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + name_known = eir_get_data(info->data, + sizeof(info->data), + EIR_NAME_COMPLETE, NULL); + else + name_known = true; + + flags = hci_inquiry_cache_update(hdev, &data, name_known); + + eir_len = eir_get_length(info->data, sizeof(info->data)); + + mgmt_device_found(hdev, &info->bdaddr, ACL_LINK, 0x00, + info->dev_class, info->rssi, + flags, info->data, eir_len, NULL, 0, 0); + } + + hci_dev_unlock(hdev); +} + +static void hci_key_refresh_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_key_refresh_complete *ev = data; + 
struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x handle 0x%4.4x", ev->status, + __le16_to_cpu(ev->handle)); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + /* For BR/EDR the necessary steps are taken through the + * auth_complete event. + */ + if (conn->type != LE_LINK) + goto unlock; + + if (!ev->status) + conn->sec_level = conn->pending_sec_level; + + clear_bit(HCI_CONN_ENCRYPT_PEND, &conn->flags); + + if (ev->status && conn->state == BT_CONNECTED) { + hci_disconnect(conn, HCI_ERROR_AUTH_FAILURE); + hci_conn_drop(conn); + goto unlock; + } + + if (conn->state == BT_CONFIG) { + if (!ev->status) + conn->state = BT_CONNECTED; + + hci_connect_cfm(conn, ev->status); + hci_conn_drop(conn); + } else { + hci_auth_cfm(conn, ev->status); + + hci_conn_hold(conn); + conn->disc_timeout = HCI_DISCONN_TIMEOUT; + hci_conn_drop(conn); + } + +unlock: + hci_dev_unlock(hdev); +} + +static u8 hci_get_auth_req(struct hci_conn *conn) +{ + /* If remote requests no-bonding follow that lead */ + if (conn->remote_auth == HCI_AT_NO_BONDING || + conn->remote_auth == HCI_AT_NO_BONDING_MITM) + return conn->remote_auth | (conn->auth_type & 0x01); + + /* If both remote and local have enough IO capabilities, require + * MITM protection + */ + if (conn->remote_cap != HCI_IO_NO_INPUT_OUTPUT && + conn->io_capability != HCI_IO_NO_INPUT_OUTPUT) + return conn->remote_auth | 0x01; + + /* No MITM protection possible so ignore remote requirement */ + return (conn->remote_auth & ~0x01) | (conn->auth_type & 0x01); +} + +static u8 bredr_oob_data_present(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct oob_data *data; + + data = hci_find_remote_oob_data(hdev, &conn->dst, BDADDR_BREDR); + if (!data) + return 0x00; + + if (bredr_sc_enabled(hdev)) { + /* When Secure Connections is enabled, then just + * return the present value stored with the OOB + * data. The stored value contains the right present + * information. However it can only be trusted when + * not in Secure Connection Only mode. + */ + if (!hci_dev_test_flag(hdev, HCI_SC_ONLY)) + return data->present; + + /* When Secure Connections Only mode is enabled, then + * the P-256 values are required. If they are not + * available, then do not declare that OOB data is + * present. + */ + if (!crypto_memneq(data->rand256, ZERO_KEY, 16) || + !crypto_memneq(data->hash256, ZERO_KEY, 16)) + return 0x00; + + return 0x02; + } + + /* When Secure Connections is not enabled or actually + * not supported by the hardware, then check that if + * P-192 data values are present. + */ + if (!crypto_memneq(data->rand192, ZERO_KEY, 16) || + !crypto_memneq(data->hash192, ZERO_KEY, 16)) + return 0x00; + + return 0x01; +} + +static void hci_io_capa_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_io_capa_request *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn || !hci_conn_ssp_enabled(conn)) + goto unlock; + + hci_conn_hold(conn); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + goto unlock; + + /* Allow pairing if we're pairable, the initiators of the + * pairing or if the remote is not requesting bonding. 
+ */ + if (hci_dev_test_flag(hdev, HCI_BONDABLE) || + test_bit(HCI_CONN_AUTH_INITIATOR, &conn->flags) || + (conn->remote_auth & ~0x01) == HCI_AT_NO_BONDING) { + struct hci_cp_io_capability_reply cp; + + bacpy(&cp.bdaddr, &ev->bdaddr); + /* Change the IO capability from KeyboardDisplay + * to DisplayYesNo as it is not supported by BT spec. */ + cp.capability = (conn->io_capability == 0x04) ? + HCI_IO_DISPLAY_YESNO : conn->io_capability; + + /* If we are initiators, there is no remote information yet */ + if (conn->remote_auth == 0xff) { + /* Request MITM protection if our IO caps allow it + * except for the no-bonding case. + */ + if (conn->io_capability != HCI_IO_NO_INPUT_OUTPUT && + conn->auth_type != HCI_AT_NO_BONDING) + conn->auth_type |= 0x01; + } else { + conn->auth_type = hci_get_auth_req(conn); + } + + /* If we're not bondable, force one of the non-bondable + * authentication requirement values. + */ + if (!hci_dev_test_flag(hdev, HCI_BONDABLE)) + conn->auth_type &= HCI_AT_NO_BONDING_MITM; + + cp.authentication = conn->auth_type; + cp.oob_data = bredr_oob_data_present(conn); + + hci_send_cmd(hdev, HCI_OP_IO_CAPABILITY_REPLY, + sizeof(cp), &cp); + } else { + struct hci_cp_io_capability_neg_reply cp; + + bacpy(&cp.bdaddr, &ev->bdaddr); + cp.reason = HCI_ERROR_PAIRING_NOT_ALLOWED; + + hci_send_cmd(hdev, HCI_OP_IO_CAPABILITY_NEG_REPLY, + sizeof(cp), &cp); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_io_capa_reply_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_io_capa_reply *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + goto unlock; + + conn->remote_cap = ev->capability; + conn->remote_auth = ev->authentication; + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_user_confirm_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_user_confirm_req *ev = data; + int loc_mitm, rem_mitm, confirm_hint = 0; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + goto unlock; + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + goto unlock; + + loc_mitm = (conn->auth_type & 0x01); + rem_mitm = (conn->remote_auth & 0x01); + + /* If we require MITM but the remote device can't provide that + * (it has NoInputNoOutput) then reject the confirmation + * request. We check the security level here since it doesn't + * necessarily match conn->auth_type. + */ + if (conn->pending_sec_level > BT_SECURITY_MEDIUM && + conn->remote_cap == HCI_IO_NO_INPUT_OUTPUT) { + bt_dev_dbg(hdev, "Rejecting request: remote device can't provide MITM"); + hci_send_cmd(hdev, HCI_OP_USER_CONFIRM_NEG_REPLY, + sizeof(ev->bdaddr), &ev->bdaddr); + goto unlock; + } + + /* If no side requires MITM protection; auto-accept */ + if ((!loc_mitm || conn->remote_cap == HCI_IO_NO_INPUT_OUTPUT) && + (!rem_mitm || conn->io_capability == HCI_IO_NO_INPUT_OUTPUT)) { + + /* If we're not the initiators request authorization to + * proceed from user space (mgmt_user_confirm with + * confirm_hint set to 1). 
The exception is if neither + * side had MITM or if the local IO capability is + * NoInputNoOutput, in which case we do auto-accept + */ + if (!test_bit(HCI_CONN_AUTH_PEND, &conn->flags) && + conn->io_capability != HCI_IO_NO_INPUT_OUTPUT && + (loc_mitm || rem_mitm)) { + bt_dev_dbg(hdev, "Confirming auto-accept as acceptor"); + confirm_hint = 1; + goto confirm; + } + + /* If there already exists link key in local host, leave the + * decision to user space since the remote device could be + * legitimate or malicious. + */ + if (hci_find_link_key(hdev, &ev->bdaddr)) { + bt_dev_dbg(hdev, "Local host already has link key"); + confirm_hint = 1; + goto confirm; + } + + BT_DBG("Auto-accept of user confirmation with %ums delay", + hdev->auto_accept_delay); + + if (hdev->auto_accept_delay > 0) { + int delay = msecs_to_jiffies(hdev->auto_accept_delay); + queue_delayed_work(conn->hdev->workqueue, + &conn->auto_accept_work, delay); + goto unlock; + } + + hci_send_cmd(hdev, HCI_OP_USER_CONFIRM_REPLY, + sizeof(ev->bdaddr), &ev->bdaddr); + goto unlock; + } + +confirm: + mgmt_user_confirm_request(hdev, &ev->bdaddr, ACL_LINK, 0, + le32_to_cpu(ev->passkey), confirm_hint); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_user_passkey_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_user_passkey_req *ev = data; + + bt_dev_dbg(hdev, ""); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_passkey_request(hdev, &ev->bdaddr, ACL_LINK, 0); +} + +static void hci_user_passkey_notify_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_user_passkey_notify *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + return; + + conn->passkey_notify = __le32_to_cpu(ev->passkey); + conn->passkey_entered = 0; + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_passkey_notify(hdev, &conn->dst, conn->type, + conn->dst_type, conn->passkey_notify, + conn->passkey_entered); +} + +static void hci_keypress_notify_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_keypress_notify *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn) + return; + + switch (ev->type) { + case HCI_KEYPRESS_STARTED: + conn->passkey_entered = 0; + return; + + case HCI_KEYPRESS_ENTERED: + conn->passkey_entered++; + break; + + case HCI_KEYPRESS_ERASED: + conn->passkey_entered--; + break; + + case HCI_KEYPRESS_CLEARED: + conn->passkey_entered = 0; + break; + + case HCI_KEYPRESS_COMPLETED: + return; + } + + if (hci_dev_test_flag(hdev, HCI_MGMT)) + mgmt_user_passkey_notify(hdev, &conn->dst, conn->type, + conn->dst_type, conn->passkey_notify, + conn->passkey_entered); +} + +static void hci_simple_pair_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_simple_pair_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (!conn || !hci_conn_ssp_enabled(conn)) + goto unlock; + + /* Reset the authentication requirement to unknown */ + conn->remote_auth = 0xff; + + /* To avoid duplicate auth_failed events to user space we check + * the HCI_CONN_AUTH_PEND flag which will be set if we + * initiated the authentication. 
A traditional auth_complete + * event gets always produced as initiator and is also mapped to + * the mgmt_auth_failed event */ + if (!test_bit(HCI_CONN_AUTH_PEND, &conn->flags) && ev->status) + mgmt_auth_failed(conn, ev->status); + + hci_conn_drop(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_remote_host_features_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_remote_host_features *ev = data; + struct inquiry_entry *ie; + struct hci_conn *conn; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &ev->bdaddr); + if (conn) + memcpy(conn->features[1], ev->features, 8); + + ie = hci_inquiry_cache_lookup(hdev, &ev->bdaddr); + if (ie) + ie->data.ssp_mode = (ev->features[0] & LMP_HOST_SSP); + + hci_dev_unlock(hdev); +} + +static void hci_remote_oob_data_request_evt(struct hci_dev *hdev, void *edata, + struct sk_buff *skb) +{ + struct hci_ev_remote_oob_data_request *ev = edata; + struct oob_data *data; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + goto unlock; + + data = hci_find_remote_oob_data(hdev, &ev->bdaddr, BDADDR_BREDR); + if (!data) { + struct hci_cp_remote_oob_data_neg_reply cp; + + bacpy(&cp.bdaddr, &ev->bdaddr); + hci_send_cmd(hdev, HCI_OP_REMOTE_OOB_DATA_NEG_REPLY, + sizeof(cp), &cp); + goto unlock; + } + + if (bredr_sc_enabled(hdev)) { + struct hci_cp_remote_oob_ext_data_reply cp; + + bacpy(&cp.bdaddr, &ev->bdaddr); + if (hci_dev_test_flag(hdev, HCI_SC_ONLY)) { + memset(cp.hash192, 0, sizeof(cp.hash192)); + memset(cp.rand192, 0, sizeof(cp.rand192)); + } else { + memcpy(cp.hash192, data->hash192, sizeof(cp.hash192)); + memcpy(cp.rand192, data->rand192, sizeof(cp.rand192)); + } + memcpy(cp.hash256, data->hash256, sizeof(cp.hash256)); + memcpy(cp.rand256, data->rand256, sizeof(cp.rand256)); + + hci_send_cmd(hdev, HCI_OP_REMOTE_OOB_EXT_DATA_REPLY, + sizeof(cp), &cp); + } else { + struct hci_cp_remote_oob_data_reply cp; + + bacpy(&cp.bdaddr, &ev->bdaddr); + memcpy(cp.hash, data->hash192, sizeof(cp.hash)); + memcpy(cp.rand, data->rand192, sizeof(cp.rand)); + + hci_send_cmd(hdev, HCI_OP_REMOTE_OOB_DATA_REPLY, + sizeof(cp), &cp); + } + +unlock: + hci_dev_unlock(hdev); +} + +#if IS_ENABLED(CONFIG_BT_HS) +static void hci_chan_selected_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_channel_selected *ev = data; + struct hci_conn *hcon; + + bt_dev_dbg(hdev, "handle 0x%2.2x", ev->phy_handle); + + hcon = hci_conn_hash_lookup_handle(hdev, ev->phy_handle); + if (!hcon) + return; + + amp_read_loc_assoc_final_data(hdev, hcon); +} + +static void hci_phy_link_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_phy_link_complete *ev = data; + struct hci_conn *hcon, *bredr_hcon; + + bt_dev_dbg(hdev, "handle 0x%2.2x status 0x%2.2x", ev->phy_handle, + ev->status); + + hci_dev_lock(hdev); + + hcon = hci_conn_hash_lookup_handle(hdev, ev->phy_handle); + if (!hcon) + goto unlock; + + if (!hcon->amp_mgr) + goto unlock; + + if (ev->status) { + hci_conn_del(hcon); + goto unlock; + } + + bredr_hcon = hcon->amp_mgr->l2cap_conn->hcon; + + hcon->state = BT_CONNECTED; + bacpy(&hcon->dst, &bredr_hcon->dst); + + hci_conn_hold(hcon); + hcon->disc_timeout = HCI_DISCONN_TIMEOUT; + hci_conn_drop(hcon); + + hci_debugfs_create_conn(hcon); + hci_conn_add_sysfs(hcon); + + amp_physical_cfm(bredr_hcon, hcon); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_loglink_complete_evt(struct hci_dev *hdev, void 
*data, + struct sk_buff *skb) +{ + struct hci_ev_logical_link_complete *ev = data; + struct hci_conn *hcon; + struct hci_chan *hchan; + struct amp_mgr *mgr; + + bt_dev_dbg(hdev, "log_handle 0x%4.4x phy_handle 0x%2.2x status 0x%2.2x", + le16_to_cpu(ev->handle), ev->phy_handle, ev->status); + + hcon = hci_conn_hash_lookup_handle(hdev, ev->phy_handle); + if (!hcon) + return; + + /* Create AMP hchan */ + hchan = hci_chan_create(hcon); + if (!hchan) + return; + + hchan->handle = le16_to_cpu(ev->handle); + hchan->amp = true; + + BT_DBG("hcon %p mgr %p hchan %p", hcon, hcon->amp_mgr, hchan); + + mgr = hcon->amp_mgr; + if (mgr && mgr->bredr_chan) { + struct l2cap_chan *bredr_chan = mgr->bredr_chan; + + l2cap_chan_lock(bredr_chan); + + bredr_chan->conn->mtu = hdev->block_mtu; + l2cap_logical_cfm(bredr_chan, hchan, 0); + hci_conn_hold(hcon); + + l2cap_chan_unlock(bredr_chan); + } +} + +static void hci_disconn_loglink_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_disconn_logical_link_complete *ev = data; + struct hci_chan *hchan; + + bt_dev_dbg(hdev, "handle 0x%4.4x status 0x%2.2x", + le16_to_cpu(ev->handle), ev->status); + + if (ev->status) + return; + + hci_dev_lock(hdev); + + hchan = hci_chan_lookup_handle(hdev, le16_to_cpu(ev->handle)); + if (!hchan || !hchan->amp) + goto unlock; + + amp_destroy_logical_link(hchan, ev->reason); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_disconn_phylink_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_disconn_phy_link_complete *ev = data; + struct hci_conn *hcon; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + if (ev->status) + return; + + hci_dev_lock(hdev); + + hcon = hci_conn_hash_lookup_handle(hdev, ev->phy_handle); + if (hcon && hcon->type == AMP_LINK) { + hcon->state = BT_CLOSED; + hci_disconn_cfm(hcon, ev->reason); + hci_conn_del(hcon); + } + + hci_dev_unlock(hdev); +} +#endif + +static void le_conn_update_addr(struct hci_conn *conn, bdaddr_t *bdaddr, + u8 bdaddr_type, bdaddr_t *local_rpa) +{ + if (conn->out) { + conn->dst_type = bdaddr_type; + conn->resp_addr_type = bdaddr_type; + bacpy(&conn->resp_addr, bdaddr); + + /* Check if the controller has set a Local RPA then it must be + * used instead of hdev->rpa. + */ + if (local_rpa && bacmp(local_rpa, BDADDR_ANY)) { + conn->init_addr_type = ADDR_LE_DEV_RANDOM; + bacpy(&conn->init_addr, local_rpa); + } else if (hci_dev_test_flag(conn->hdev, HCI_PRIVACY)) { + conn->init_addr_type = ADDR_LE_DEV_RANDOM; + bacpy(&conn->init_addr, &conn->hdev->rpa); + } else { + hci_copy_identity_address(conn->hdev, &conn->init_addr, + &conn->init_addr_type); + } + } else { + conn->resp_addr_type = conn->hdev->adv_addr_type; + /* Check if the controller has set a Local RPA then it must be + * used instead of hdev->rpa. + */ + if (local_rpa && bacmp(local_rpa, BDADDR_ANY)) { + conn->resp_addr_type = ADDR_LE_DEV_RANDOM; + bacpy(&conn->resp_addr, local_rpa); + } else if (conn->hdev->adv_addr_type == ADDR_LE_DEV_RANDOM) { + /* In case of ext adv, resp_addr will be updated in + * Adv Terminated event. + */ + if (!ext_adv_capable(conn->hdev)) + bacpy(&conn->resp_addr, + &conn->hdev->random_addr); + } else { + bacpy(&conn->resp_addr, &conn->hdev->bdaddr); + } + + conn->init_addr_type = bdaddr_type; + bacpy(&conn->init_addr, bdaddr); + + /* For incoming connections, set the default minimum + * and maximum connection interval.
They will be used + * to check if the parameters are in range and if not + * trigger the connection update procedure. + */ + conn->le_conn_min_interval = conn->hdev->le_conn_min_interval; + conn->le_conn_max_interval = conn->hdev->le_conn_max_interval; + } +} + +static void le_conn_complete_evt(struct hci_dev *hdev, u8 status, + bdaddr_t *bdaddr, u8 bdaddr_type, + bdaddr_t *local_rpa, u8 role, u16 handle, + u16 interval, u16 latency, + u16 supervision_timeout) +{ + struct hci_conn_params *params; + struct hci_conn *conn; + struct smp_irk *irk; + u8 addr_type; + + hci_dev_lock(hdev); + + /* All controllers implicitly stop advertising in the event of a + * connection, so ensure that the state bit is cleared. + */ + hci_dev_clear_flag(hdev, HCI_LE_ADV); + + conn = hci_conn_hash_lookup_ba(hdev, LE_LINK, bdaddr); + if (!conn) { + /* In case of error status and there is no connection pending + * just unlock as there is nothing to cleanup. + */ + if (status) + goto unlock; + + conn = hci_conn_add(hdev, LE_LINK, bdaddr, role); + if (!conn) { + bt_dev_err(hdev, "no memory for new connection"); + goto unlock; + } + + conn->dst_type = bdaddr_type; + + /* If we didn't have a hci_conn object previously + * but we're in central role this must be something + * initiated using an accept list. Since accept list based + * connections are not "first class citizens" we don't + * have full tracking of them. Therefore, we go ahead + * with a "best effort" approach of determining the + * initiator address based on the HCI_PRIVACY flag. + */ + if (conn->out) { + conn->resp_addr_type = bdaddr_type; + bacpy(&conn->resp_addr, bdaddr); + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) { + conn->init_addr_type = ADDR_LE_DEV_RANDOM; + bacpy(&conn->init_addr, &hdev->rpa); + } else { + hci_copy_identity_address(hdev, + &conn->init_addr, + &conn->init_addr_type); + } + } + } else { + cancel_delayed_work(&conn->le_conn_timeout); + } + + /* The HCI_LE_Connection_Complete event is only sent once per connection. + * Processing it more than once per connection can corrupt kernel memory. + * + * As the connection handle is set here for the first time, it indicates + * whether the connection is already set up. + */ + if (conn->handle != HCI_CONN_HANDLE_UNSET) { + bt_dev_err(hdev, "Ignoring HCI_Connection_Complete for existing connection"); + goto unlock; + } + + le_conn_update_addr(conn, bdaddr, bdaddr_type, local_rpa); + + /* Lookup the identity address from the stored connection + * address and address type. + * + * When establishing connections to an identity address, the + * connection procedure will store the resolvable random + * address first. Now if it can be converted back into the + * identity address, start using the identity address from + * now on. + */ + irk = hci_get_irk(hdev, &conn->dst, conn->dst_type); + if (irk) { + bacpy(&conn->dst, &irk->bdaddr); + conn->dst_type = irk->addr_type; + } + + conn->dst_type = ev_bdaddr_type(hdev, conn->dst_type, NULL); + + if (handle > HCI_CONN_HANDLE_MAX) { + bt_dev_err(hdev, "Invalid handle: 0x%4.4x > 0x%4.4x", handle, + HCI_CONN_HANDLE_MAX); + status = HCI_ERROR_INVALID_PARAMETERS; + } + + /* All connection failure handling is taken care of by the + * hci_conn_failed function which is triggered by the HCI + * request completion callbacks used for connecting. 
+ */ + if (status) + goto unlock; + + /* Drop the connection if it has been aborted */ + if (test_bit(HCI_CONN_CANCEL, &conn->flags)) { + hci_conn_drop(conn); + goto unlock; + } + + if (conn->dst_type == ADDR_LE_DEV_PUBLIC) + addr_type = BDADDR_LE_PUBLIC; + else + addr_type = BDADDR_LE_RANDOM; + + /* Drop the connection if the device is blocked */ + if (hci_bdaddr_list_lookup(&hdev->reject_list, &conn->dst, addr_type)) { + hci_conn_drop(conn); + goto unlock; + } + + if (!test_and_set_bit(HCI_CONN_MGMT_CONNECTED, &conn->flags)) + mgmt_device_connected(hdev, conn, NULL, 0); + + conn->sec_level = BT_SECURITY_LOW; + conn->handle = handle; + conn->state = BT_CONFIG; + + /* Store current advertising instance as connection advertising instance + * when software rotation is in use so it can be re-enabled when + * disconnected. + */ + if (!ext_adv_capable(hdev)) + conn->adv_instance = hdev->cur_adv_instance; + + conn->le_conn_interval = interval; + conn->le_conn_latency = latency; + conn->le_supv_timeout = supervision_timeout; + + hci_debugfs_create_conn(conn); + hci_conn_add_sysfs(conn); + + /* The remote features procedure is defined for central + * role only. So only in case of an initiated connection + * request the remote features. + * + * If the local controller supports peripheral-initiated features + * exchange, then requesting the remote features in peripheral + * role is possible. Otherwise just transition into the + * connected state without requesting the remote features. + */ + if (conn->out || + (hdev->le_features[0] & HCI_LE_PERIPHERAL_FEATURES)) { + struct hci_cp_le_read_remote_features cp; + + cp.handle = __cpu_to_le16(conn->handle); + + hci_send_cmd(hdev, HCI_OP_LE_READ_REMOTE_FEATURES, + sizeof(cp), &cp); + + hci_conn_hold(conn); + } else { + conn->state = BT_CONNECTED; + hci_connect_cfm(conn, status); + } + + params = hci_pend_le_action_lookup(&hdev->pend_le_conns, &conn->dst, + conn->dst_type); + if (params) { + hci_pend_le_list_del_init(params); + if (params->conn) { + hci_conn_drop(params->conn); + hci_conn_put(params->conn); + params->conn = NULL; + } + } + +unlock: + hci_update_passive_scan(hdev); + hci_dev_unlock(hdev); +} + +static void hci_le_conn_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_conn_complete *ev = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + le_conn_complete_evt(hdev, ev->status, &ev->bdaddr, ev->bdaddr_type, + NULL, ev->role, le16_to_cpu(ev->handle), + le16_to_cpu(ev->interval), + le16_to_cpu(ev->latency), + le16_to_cpu(ev->supervision_timeout)); +} + +static void hci_le_enh_conn_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_enh_conn_complete *ev = data; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + le_conn_complete_evt(hdev, ev->status, &ev->bdaddr, ev->bdaddr_type, + &ev->local_rpa, ev->role, le16_to_cpu(ev->handle), + le16_to_cpu(ev->interval), + le16_to_cpu(ev->latency), + le16_to_cpu(ev->supervision_timeout)); +} + +static void hci_le_ext_adv_term_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_evt_le_ext_adv_set_term *ev = data; + struct hci_conn *conn; + struct adv_info *adv, *n; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + /* The Bluetooth Core 5.3 specification clearly states that this event + * shall not be sent when the Host disables the advertising set. So in + * case of HCI_ERROR_CANCELLED_BY_HOST, just ignore the event.
+ * + * When the Host disables an advertising set, all cleanup is done via + * its command callback and does not need to be duplicated here. + */ + if (ev->status == HCI_ERROR_CANCELLED_BY_HOST) { + bt_dev_warn_ratelimited(hdev, "Unexpected advertising set terminated event"); + return; + } + + hci_dev_lock(hdev); + + adv = hci_find_adv_instance(hdev, ev->handle); + + if (ev->status) { + if (!adv) + goto unlock; + + /* Remove advertising as it has been terminated */ + hci_remove_adv_instance(hdev, ev->handle); + mgmt_advertising_removed(NULL, hdev, ev->handle); + + list_for_each_entry_safe(adv, n, &hdev->adv_instances, list) { + if (adv->enabled) + goto unlock; + } + + /* We are no longer advertising, clear HCI_LE_ADV */ + hci_dev_clear_flag(hdev, HCI_LE_ADV); + goto unlock; + } + + if (adv) + adv->enabled = false; + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->conn_handle)); + if (conn) { + /* Store handle in the connection so the correct advertising + * instance can be re-enabled when disconnected. + */ + conn->adv_instance = ev->handle; + + if (hdev->adv_addr_type != ADDR_LE_DEV_RANDOM || + bacmp(&conn->resp_addr, BDADDR_ANY)) + goto unlock; + + if (!ev->handle) { + bacpy(&conn->resp_addr, &hdev->random_addr); + goto unlock; + } + + if (adv) + bacpy(&conn->resp_addr, &adv->random_addr); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_le_conn_update_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_conn_update_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + if (ev->status) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn) { + conn->le_conn_interval = le16_to_cpu(ev->interval); + conn->le_conn_latency = le16_to_cpu(ev->latency); + conn->le_supv_timeout = le16_to_cpu(ev->supervision_timeout); + } + + hci_dev_unlock(hdev); +} + +/* This function requires the caller holds hdev->lock */ +static struct hci_conn *check_pending_le_conn(struct hci_dev *hdev, + bdaddr_t *addr, + u8 addr_type, bool addr_resolved, + u8 adv_type) +{ + struct hci_conn *conn; + struct hci_conn_params *params; + + /* If the event is not connectable don't proceed further */ + if (adv_type != LE_ADV_IND && adv_type != LE_ADV_DIRECT_IND) + return NULL; + + /* Ignore if the device is blocked or hdev is suspended */ + if (hci_bdaddr_list_lookup(&hdev->reject_list, addr, addr_type) || + hdev->suspended) + return NULL; + + /* Most controllers will fail if we try to create new connections + * while we have an existing one in peripheral role. + */ + if (hdev->conn_hash.le_num_peripheral > 0 && + (!test_bit(HCI_QUIRK_VALID_LE_STATES, &hdev->quirks) || + !(hdev->le_states[3] & 0x10))) + return NULL; + + /* If we're not connectable only connect devices that we have in + * our pend_le_conns list. + */ + params = hci_pend_le_action_lookup(&hdev->pend_le_conns, addr, + addr_type); + if (!params) + return NULL; + + if (!params->explicit_connect) { + switch (params->auto_connect) { + case HCI_AUTO_CONN_DIRECT: + /* Only devices advertising with ADV_DIRECT_IND are + * triggering a connection attempt. This is allowing + * incoming connections from peripheral devices. + */ + if (adv_type != LE_ADV_DIRECT_IND) + return NULL; + break; + case HCI_AUTO_CONN_ALWAYS: + /* Devices advertising with ADV_IND or ADV_DIRECT_IND + * are triggering a connection attempt.
This means + * that incoming connections from peripheral device are + * accepted and also outgoing connections to peripheral + * devices are established when found. + */ + break; + default: + return NULL; + } + } + + conn = hci_connect_le(hdev, addr, addr_type, addr_resolved, + BT_SECURITY_LOW, hdev->def_le_autoconnect_timeout, + HCI_ROLE_MASTER); + if (!IS_ERR(conn)) { + /* If HCI_AUTO_CONN_EXPLICIT is set, conn is already owned + * by higher layer that tried to connect, if no then + * store the pointer since we don't really have any + * other owner of the object besides the params that + * triggered it. This way we can abort the connection if + * the parameters get removed and keep the reference + * count consistent once the connection is established. + */ + + if (!params->explicit_connect) + params->conn = hci_conn_get(conn); + + return conn; + } + + switch (PTR_ERR(conn)) { + case -EBUSY: + /* If hci_connect() returns -EBUSY it means there is already + * an LE connection attempt going on. Since controllers don't + * support more than one connection attempt at the time, we + * don't consider this an error case. + */ + break; + default: + BT_DBG("Failed to connect: err %ld", PTR_ERR(conn)); + return NULL; + } + + return NULL; +} + +static void process_adv_report(struct hci_dev *hdev, u8 type, bdaddr_t *bdaddr, + u8 bdaddr_type, bdaddr_t *direct_addr, + u8 direct_addr_type, s8 rssi, u8 *data, u8 len, + bool ext_adv, bool ctl_time, u64 instant) +{ + struct discovery_state *d = &hdev->discovery; + struct smp_irk *irk; + struct hci_conn *conn; + bool match, bdaddr_resolved; + u32 flags; + u8 *ptr; + + switch (type) { + case LE_ADV_IND: + case LE_ADV_DIRECT_IND: + case LE_ADV_SCAN_IND: + case LE_ADV_NONCONN_IND: + case LE_ADV_SCAN_RSP: + break; + default: + bt_dev_err_ratelimited(hdev, "unknown advertising packet " + "type: 0x%02x", type); + return; + } + + if (!ext_adv && len > HCI_MAX_AD_LENGTH) { + bt_dev_err_ratelimited(hdev, "legacy adv larger than 31 bytes"); + return; + } + + /* Find the end of the data in case the report contains padded zero + * bytes at the end causing an invalid length value. + * + * When data is NULL, len is 0 so there is no need for extra ptr + * check as 'ptr < data + 0' is already false in such case. + */ + for (ptr = data; ptr < data + len && *ptr; ptr += *ptr + 1) { + if (ptr + 1 + *ptr > data + len) + break; + } + + /* Adjust for actual length. This handles the case when remote + * device is advertising with incorrect data length. + */ + len = ptr - data; + + /* If the direct address is present, then this report is from + * a LE Direct Advertising Report event. In that case it is + * important to see if the address is matching the local + * controller address. + */ + if (!hci_dev_test_flag(hdev, HCI_MESH) && direct_addr) { + direct_addr_type = ev_bdaddr_type(hdev, direct_addr_type, + &bdaddr_resolved); + + /* Only resolvable random addresses are valid for these + * kind of reports and others can be ignored. + */ + if (!hci_bdaddr_is_rpa(direct_addr, direct_addr_type)) + return; + + /* If the controller is not using resolvable random + * addresses, then this report can be ignored. + */ + if (!hci_dev_test_flag(hdev, HCI_PRIVACY)) + return; + + /* If the local IRK of the controller does not match + * with the resolvable random address provided, then + * this report can be ignored. 
+ */ + if (!smp_irk_matches(hdev, hdev->irk, direct_addr)) + return; + } + + /* Check if we need to convert to identity address */ + irk = hci_get_irk(hdev, bdaddr, bdaddr_type); + if (irk) { + bdaddr = &irk->bdaddr; + bdaddr_type = irk->addr_type; + } + + bdaddr_type = ev_bdaddr_type(hdev, bdaddr_type, &bdaddr_resolved); + + /* Check if we have been requested to connect to this device. + * + * direct_addr is set only for directed advertising reports (it is NULL + * for advertising reports) and is already verified to be RPA above. + */ + conn = check_pending_le_conn(hdev, bdaddr, bdaddr_type, bdaddr_resolved, + type); + if (!ext_adv && conn && type == LE_ADV_IND && len <= HCI_MAX_AD_LENGTH) { + /* Store report for later inclusion by + * mgmt_device_connected + */ + memcpy(conn->le_adv_data, data, len); + conn->le_adv_data_len = len; + } + + if (type == LE_ADV_NONCONN_IND || type == LE_ADV_SCAN_IND) + flags = MGMT_DEV_FOUND_NOT_CONNECTABLE; + else + flags = 0; + + /* All scan results should be sent up for Mesh systems */ + if (hci_dev_test_flag(hdev, HCI_MESH)) { + mgmt_device_found(hdev, bdaddr, LE_LINK, bdaddr_type, NULL, + rssi, flags, data, len, NULL, 0, instant); + return; + } + + /* Passive scanning shouldn't trigger any device found events, + * except for devices marked as CONN_REPORT for which we do send + * device found events, or advertisement monitoring requested. + */ + if (hdev->le_scan_type == LE_SCAN_PASSIVE) { + if (type == LE_ADV_DIRECT_IND) + return; + + if (!hci_pend_le_action_lookup(&hdev->pend_le_reports, + bdaddr, bdaddr_type) && + idr_is_empty(&hdev->adv_monitors_idr)) + return; + + mgmt_device_found(hdev, bdaddr, LE_LINK, bdaddr_type, NULL, + rssi, flags, data, len, NULL, 0, 0); + return; + } + + /* When receiving a scan response, then there is no way to + * know if the remote device is connectable or not. However + * since scan responses are merged with a previously seen + * advertising report, the flags field from that report + * will be used. + * + * In the unlikely case that a controller just sends a scan + * response event that doesn't match the pending report, then + * it is marked as a standalone SCAN_RSP. + */ + if (type == LE_ADV_SCAN_RSP) + flags = MGMT_DEV_FOUND_SCAN_RSP; + + /* If there's nothing pending either store the data from this + * event or send an immediate device found event if the data + * should not be stored for later. + */ + if (!ext_adv && !has_pending_adv_report(hdev)) { + /* If the report will trigger a SCAN_REQ store it for + * later merging. + */ + if (type == LE_ADV_IND || type == LE_ADV_SCAN_IND) { + store_pending_adv_report(hdev, bdaddr, bdaddr_type, + rssi, flags, data, len); + return; + } + + mgmt_device_found(hdev, bdaddr, LE_LINK, bdaddr_type, NULL, + rssi, flags, data, len, NULL, 0, 0); + return; + } + + /* Check if the pending report is for the same device as the new one */ + match = (!bacmp(bdaddr, &d->last_adv_addr) && + bdaddr_type == d->last_adv_addr_type); + + /* If the pending data doesn't match this report or this isn't a + * scan response (e.g. we got a duplicate ADV_IND) then force + * sending of the pending data. + */ + if (type != LE_ADV_SCAN_RSP || !match) { + /* Send out whatever is in the cache, but skip duplicates */ + if (!match) + mgmt_device_found(hdev, &d->last_adv_addr, LE_LINK, + d->last_adv_addr_type, NULL, + d->last_adv_rssi, d->last_adv_flags, + d->last_adv_data, + d->last_adv_data_len, NULL, 0, 0); + + /* If the new report will trigger a SCAN_REQ store it for + * later merging. 
+ */ + if (!ext_adv && (type == LE_ADV_IND || + type == LE_ADV_SCAN_IND)) { + store_pending_adv_report(hdev, bdaddr, bdaddr_type, + rssi, flags, data, len); + return; + } + + /* The advertising reports cannot be merged, so clear + * the pending report and send out a device found event. + */ + clear_pending_adv_report(hdev); + mgmt_device_found(hdev, bdaddr, LE_LINK, bdaddr_type, NULL, + rssi, flags, data, len, NULL, 0, 0); + return; + } + + /* If we get here we've got a pending ADV_IND or ADV_SCAN_IND and + * the new event is a SCAN_RSP. We can therefore proceed with + * sending a merged device found event. + */ + mgmt_device_found(hdev, &d->last_adv_addr, LE_LINK, + d->last_adv_addr_type, NULL, rssi, d->last_adv_flags, + d->last_adv_data, d->last_adv_data_len, data, len, 0); + clear_pending_adv_report(hdev); +} + +static void hci_le_adv_report_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_advertising_report *ev = data; + u64 instant = jiffies; + + if (!ev->num) + return; + + hci_dev_lock(hdev); + + while (ev->num--) { + struct hci_ev_le_advertising_info *info; + s8 rssi; + + info = hci_le_ev_skb_pull(hdev, skb, + HCI_EV_LE_ADVERTISING_REPORT, + sizeof(*info)); + if (!info) + break; + + if (!hci_le_ev_skb_pull(hdev, skb, HCI_EV_LE_ADVERTISING_REPORT, + info->length + 1)) + break; + + if (info->length <= HCI_MAX_AD_LENGTH) { + rssi = info->data[info->length]; + process_adv_report(hdev, info->type, &info->bdaddr, + info->bdaddr_type, NULL, 0, rssi, + info->data, info->length, false, + false, instant); + } else { + bt_dev_err(hdev, "Dropping invalid advertising data"); + } + } + + hci_dev_unlock(hdev); +} + +static u8 ext_evt_type_to_legacy(struct hci_dev *hdev, u16 evt_type) +{ + if (evt_type & LE_EXT_ADV_LEGACY_PDU) { + switch (evt_type) { + case LE_LEGACY_ADV_IND: + return LE_ADV_IND; + case LE_LEGACY_ADV_DIRECT_IND: + return LE_ADV_DIRECT_IND; + case LE_LEGACY_ADV_SCAN_IND: + return LE_ADV_SCAN_IND; + case LE_LEGACY_NONCONN_IND: + return LE_ADV_NONCONN_IND; + case LE_LEGACY_SCAN_RSP_ADV: + case LE_LEGACY_SCAN_RSP_ADV_SCAN: + return LE_ADV_SCAN_RSP; + } + + goto invalid; + } + + if (evt_type & LE_EXT_ADV_CONN_IND) { + if (evt_type & LE_EXT_ADV_DIRECT_IND) + return LE_ADV_DIRECT_IND; + + return LE_ADV_IND; + } + + if (evt_type & LE_EXT_ADV_SCAN_RSP) + return LE_ADV_SCAN_RSP; + + if (evt_type & LE_EXT_ADV_SCAN_IND) + return LE_ADV_SCAN_IND; + + if (evt_type == LE_EXT_ADV_NON_CONN_IND || + evt_type & LE_EXT_ADV_DIRECT_IND) + return LE_ADV_NONCONN_IND; + +invalid: + bt_dev_err_ratelimited(hdev, "Unknown advertising packet type: 0x%02x", + evt_type); + + return LE_ADV_INVALID; +} + +static void hci_le_ext_adv_report_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_ext_adv_report *ev = data; + u64 instant = jiffies; + + if (!ev->num) + return; + + hci_dev_lock(hdev); + + while (ev->num--) { + struct hci_ev_le_ext_adv_info *info; + u8 legacy_evt_type; + u16 evt_type; + + info = hci_le_ev_skb_pull(hdev, skb, HCI_EV_LE_EXT_ADV_REPORT, + sizeof(*info)); + if (!info) + break; + + if (!hci_le_ev_skb_pull(hdev, skb, HCI_EV_LE_EXT_ADV_REPORT, + info->length)) + break; + + evt_type = __le16_to_cpu(info->type); + legacy_evt_type = ext_evt_type_to_legacy(hdev, evt_type); + if (legacy_evt_type != LE_ADV_INVALID) { + process_adv_report(hdev, legacy_evt_type, &info->bdaddr, + info->bdaddr_type, NULL, 0, + info->rssi, info->data, info->length, + !(evt_type & LE_EXT_ADV_LEGACY_PDU), + false, instant); + } + } + + hci_dev_unlock(hdev); +} + 
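+/* Helper shared by the periodic advertising handlers below: when no + * registered listener accepts a sync (HCI_LM_ACCEPT is not set in the mask + * returned by hci_proto_connect_ind), the sync is torn down again by sending + * HCI_OP_LE_PA_TERM_SYNC for the reported sync handle. + */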
+static int hci_le_pa_term_sync(struct hci_dev *hdev, __le16 handle) +{ + struct hci_cp_le_pa_term_sync cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = handle; + + return hci_send_cmd(hdev, HCI_OP_LE_PA_TERM_SYNC, sizeof(cp), &cp); +} + +static void hci_le_pa_sync_estabilished_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_pa_sync_established *ev = data; + int mask = hdev->link_mode; + __u8 flags = 0; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + if (ev->status) + return; + + hci_dev_lock(hdev); + + hci_dev_clear_flag(hdev, HCI_PA_SYNC); + + mask |= hci_proto_connect_ind(hdev, &ev->bdaddr, ISO_LINK, &flags); + if (!(mask & HCI_LM_ACCEPT)) + hci_le_pa_term_sync(hdev, ev->handle); + + hci_dev_unlock(hdev); +} + +static void hci_le_remote_feat_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_remote_feat_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn) { + if (!ev->status) + memcpy(conn->features[0], ev->features, 8); + + if (conn->state == BT_CONFIG) { + __u8 status; + + /* If the local controller supports peripheral-initiated + * features exchange, but the remote controller does + * not, then it is possible that the error code 0x1a + * for unsupported remote feature gets returned. + * + * In this specific case, allow the connection to + * transition into connected state and mark it as + * successful. + */ + if (!conn->out && ev->status == 0x1a && + (hdev->le_features[0] & HCI_LE_PERIPHERAL_FEATURES)) + status = 0x00; + else + status = ev->status; + + conn->state = BT_CONNECTED; + hci_connect_cfm(conn, status); + hci_conn_drop(conn); + } + } + + hci_dev_unlock(hdev); +} + +static void hci_le_ltk_request_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_ltk_req *ev = data; + struct hci_cp_le_ltk_reply cp; + struct hci_cp_le_ltk_neg_reply neg; + struct hci_conn *conn; + struct smp_ltk *ltk; + + bt_dev_dbg(hdev, "handle 0x%4.4x", __le16_to_cpu(ev->handle)); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (conn == NULL) + goto not_found; + + ltk = hci_find_ltk(hdev, &conn->dst, conn->dst_type, conn->role); + if (!ltk) + goto not_found; + + if (smp_ltk_is_sc(ltk)) { + /* With SC both EDiv and Rand are set to zero */ + if (ev->ediv || ev->rand) + goto not_found; + } else { + /* For non-SC keys check that EDiv and Rand match */ + if (ev->ediv != ltk->ediv || ev->rand != ltk->rand) + goto not_found; + } + + memcpy(cp.ltk, ltk->val, ltk->enc_size); + memset(cp.ltk + ltk->enc_size, 0, sizeof(cp.ltk) - ltk->enc_size); + cp.handle = cpu_to_le16(conn->handle); + + conn->pending_sec_level = smp_ltk_sec_level(ltk); + + conn->enc_key_size = ltk->enc_size; + + hci_send_cmd(hdev, HCI_OP_LE_LTK_REPLY, sizeof(cp), &cp); + + /* Ref. Bluetooth Core SPEC pages 1975 and 2004. STK is a + * temporary key used to encrypt a connection following + * pairing. It is used during the Encrypted Session Setup to + * distribute the keys. Later, security can be re-established + * using a distributed LTK. 
+ */ + if (ltk->type == SMP_STK) { + set_bit(HCI_CONN_STK_ENCRYPT, &conn->flags); + list_del_rcu(<k->list); + kfree_rcu(ltk, rcu); + } else { + clear_bit(HCI_CONN_STK_ENCRYPT, &conn->flags); + } + + hci_dev_unlock(hdev); + + return; + +not_found: + neg.handle = ev->handle; + hci_send_cmd(hdev, HCI_OP_LE_LTK_NEG_REPLY, sizeof(neg), &neg); + hci_dev_unlock(hdev); +} + +static void send_conn_param_neg_reply(struct hci_dev *hdev, u16 handle, + u8 reason) +{ + struct hci_cp_le_conn_param_req_neg_reply cp; + + cp.handle = cpu_to_le16(handle); + cp.reason = reason; + + hci_send_cmd(hdev, HCI_OP_LE_CONN_PARAM_REQ_NEG_REPLY, sizeof(cp), + &cp); +} + +static void hci_le_remote_conn_param_req_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_remote_conn_param_req *ev = data; + struct hci_cp_le_conn_param_req_reply cp; + struct hci_conn *hcon; + u16 handle, min, max, latency, timeout; + + bt_dev_dbg(hdev, "handle 0x%4.4x", __le16_to_cpu(ev->handle)); + + handle = le16_to_cpu(ev->handle); + min = le16_to_cpu(ev->interval_min); + max = le16_to_cpu(ev->interval_max); + latency = le16_to_cpu(ev->latency); + timeout = le16_to_cpu(ev->timeout); + + hcon = hci_conn_hash_lookup_handle(hdev, handle); + if (!hcon || hcon->state != BT_CONNECTED) + return send_conn_param_neg_reply(hdev, handle, + HCI_ERROR_UNKNOWN_CONN_ID); + + if (hci_check_conn_params(min, max, latency, timeout)) + return send_conn_param_neg_reply(hdev, handle, + HCI_ERROR_INVALID_LL_PARAMS); + + if (hcon->role == HCI_ROLE_MASTER) { + struct hci_conn_params *params; + u8 store_hint; + + hci_dev_lock(hdev); + + params = hci_conn_params_lookup(hdev, &hcon->dst, + hcon->dst_type); + if (params) { + params->conn_min_interval = min; + params->conn_max_interval = max; + params->conn_latency = latency; + params->supervision_timeout = timeout; + store_hint = 0x01; + } else { + store_hint = 0x00; + } + + hci_dev_unlock(hdev); + + mgmt_new_conn_param(hdev, &hcon->dst, hcon->dst_type, + store_hint, min, max, latency, timeout); + } + + cp.handle = ev->handle; + cp.interval_min = ev->interval_min; + cp.interval_max = ev->interval_max; + cp.latency = ev->latency; + cp.timeout = ev->timeout; + cp.min_ce_len = 0; + cp.max_ce_len = 0; + + hci_send_cmd(hdev, HCI_OP_LE_CONN_PARAM_REQ_REPLY, sizeof(cp), &cp); +} + +static void hci_le_direct_adv_report_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_direct_adv_report *ev = data; + u64 instant = jiffies; + int i; + + if (!hci_le_ev_skb_pull(hdev, skb, HCI_EV_LE_DIRECT_ADV_REPORT, + flex_array_size(ev, info, ev->num))) + return; + + if (!ev->num) + return; + + hci_dev_lock(hdev); + + for (i = 0; i < ev->num; i++) { + struct hci_ev_le_direct_adv_info *info = &ev->info[i]; + + process_adv_report(hdev, info->type, &info->bdaddr, + info->bdaddr_type, &info->direct_addr, + info->direct_addr_type, info->rssi, NULL, 0, + false, false, instant); + } + + hci_dev_unlock(hdev); +} + +static void hci_le_phy_update_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_ev_le_phy_update_complete *ev = data; + struct hci_conn *conn; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + if (ev->status) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, __le16_to_cpu(ev->handle)); + if (!conn) + goto unlock; + + conn->le_tx_phy = ev->tx_phy; + conn->le_rx_phy = ev->rx_phy; + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_le_cis_estabilished_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + 
struct hci_evt_le_cis_established *ev = data; + struct hci_conn *conn; + u16 handle = __le16_to_cpu(ev->handle); + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_handle(hdev, handle); + if (!conn) { + bt_dev_err(hdev, + "Unable to find connection with handle 0x%4.4x", + handle); + goto unlock; + } + + if (conn->type != ISO_LINK) { + bt_dev_err(hdev, + "Invalid connection link type handle 0x%4.4x", + handle); + goto unlock; + } + + if (conn->role == HCI_ROLE_SLAVE) { + __le32 interval; + + memset(&interval, 0, sizeof(interval)); + + memcpy(&interval, ev->c_latency, sizeof(ev->c_latency)); + conn->iso_qos.in.interval = le32_to_cpu(interval); + memcpy(&interval, ev->p_latency, sizeof(ev->p_latency)); + conn->iso_qos.out.interval = le32_to_cpu(interval); + conn->iso_qos.in.latency = le16_to_cpu(ev->interval); + conn->iso_qos.out.latency = le16_to_cpu(ev->interval); + conn->iso_qos.in.sdu = le16_to_cpu(ev->c_mtu); + conn->iso_qos.out.sdu = le16_to_cpu(ev->p_mtu); + conn->iso_qos.in.phy = ev->c_phy; + conn->iso_qos.out.phy = ev->p_phy; + } + + if (!ev->status) { + conn->state = BT_CONNECTED; + hci_debugfs_create_conn(conn); + hci_conn_add_sysfs(conn); + hci_iso_setup_path(conn); + goto unlock; + } + + hci_connect_cfm(conn, ev->status); + hci_conn_del(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_le_reject_cis(struct hci_dev *hdev, __le16 handle) +{ + struct hci_cp_le_reject_cis cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = handle; + cp.reason = HCI_ERROR_REJ_BAD_ADDR; + hci_send_cmd(hdev, HCI_OP_LE_REJECT_CIS, sizeof(cp), &cp); +} + +static void hci_le_accept_cis(struct hci_dev *hdev, __le16 handle) +{ + struct hci_cp_le_accept_cis cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = handle; + hci_send_cmd(hdev, HCI_OP_LE_ACCEPT_CIS, sizeof(cp), &cp); +} + +static void hci_le_cis_req_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_evt_le_cis_req *ev = data; + u16 acl_handle, cis_handle; + struct hci_conn *acl, *cis; + int mask; + __u8 flags = 0; + + acl_handle = __le16_to_cpu(ev->acl_handle); + cis_handle = __le16_to_cpu(ev->cis_handle); + + bt_dev_dbg(hdev, "acl 0x%4.4x handle 0x%4.4x cig 0x%2.2x cis 0x%2.2x", + acl_handle, cis_handle, ev->cig_id, ev->cis_id); + + hci_dev_lock(hdev); + + acl = hci_conn_hash_lookup_handle(hdev, acl_handle); + if (!acl) + goto unlock; + + mask = hci_proto_connect_ind(hdev, &acl->dst, ISO_LINK, &flags); + if (!(mask & HCI_LM_ACCEPT)) { + hci_le_reject_cis(hdev, ev->cis_handle); + goto unlock; + } + + cis = hci_conn_hash_lookup_handle(hdev, cis_handle); + if (!cis) { + cis = hci_conn_add(hdev, ISO_LINK, &acl->dst, HCI_ROLE_SLAVE); + if (!cis) { + hci_le_reject_cis(hdev, ev->cis_handle); + goto unlock; + } + cis->handle = cis_handle; + } + + cis->iso_qos.cig = ev->cig_id; + cis->iso_qos.cis = ev->cis_id; + + if (!(flags & HCI_PROTO_DEFER)) { + hci_le_accept_cis(hdev, ev->cis_handle); + } else { + cis->state = BT_CONNECT2; + hci_connect_cfm(cis, 0); + } + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_le_create_big_complete_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_evt_le_create_big_complete *ev = data; + struct hci_conn *conn; + + BT_DBG("%s status 0x%2.2x", hdev->name, ev->status); + + if (!hci_le_ev_skb_pull(hdev, skb, HCI_EVT_LE_CREATE_BIG_COMPLETE, + flex_array_size(ev, bis_handle, ev->num_bis))) + return; + + hci_dev_lock(hdev); + + conn = hci_conn_hash_lookup_big(hdev, ev->handle); + if (!conn) + goto unlock; + + 
if (conn->type != ISO_LINK) { + bt_dev_err(hdev, + "Invalid connection link type handle 0x%2.2x", + ev->handle); + goto unlock; + } + + if (ev->num_bis) + conn->handle = __le16_to_cpu(ev->bis_handle[0]); + + if (!ev->status) { + conn->state = BT_CONNECTED; + hci_debugfs_create_conn(conn); + hci_conn_add_sysfs(conn); + hci_iso_setup_path(conn); + goto unlock; + } + + hci_connect_cfm(conn, ev->status); + hci_conn_del(conn); + +unlock: + hci_dev_unlock(hdev); +} + +static void hci_le_big_sync_established_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_evt_le_big_sync_estabilished *ev = data; + struct hci_conn *bis; + int i; + + bt_dev_dbg(hdev, "status 0x%2.2x", ev->status); + + if (!hci_le_ev_skb_pull(hdev, skb, HCI_EVT_LE_BIG_SYNC_ESTABILISHED, + flex_array_size(ev, bis, ev->num_bis))) + return; + + if (ev->status) + return; + + hci_dev_lock(hdev); + + for (i = 0; i < ev->num_bis; i++) { + u16 handle = le16_to_cpu(ev->bis[i]); + __le32 interval; + + bis = hci_conn_hash_lookup_handle(hdev, handle); + if (!bis) { + bis = hci_conn_add(hdev, ISO_LINK, BDADDR_ANY, + HCI_ROLE_SLAVE); + if (!bis) + continue; + bis->handle = handle; + } + + bis->iso_qos.big = ev->handle; + memset(&interval, 0, sizeof(interval)); + memcpy(&interval, ev->latency, sizeof(ev->latency)); + bis->iso_qos.in.interval = le32_to_cpu(interval); + /* Convert ISO Interval (1.25 ms slots) to latency (ms) */ + bis->iso_qos.in.latency = le16_to_cpu(ev->interval) * 125 / 100; + bis->iso_qos.in.sdu = le16_to_cpu(ev->max_pdu); + + hci_iso_setup_path(bis); + } + + hci_dev_unlock(hdev); +} + +static void hci_le_big_info_adv_report_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) +{ + struct hci_evt_le_big_info_adv_report *ev = data; + int mask = hdev->link_mode; + __u8 flags = 0; + + bt_dev_dbg(hdev, "sync_handle 0x%4.4x", le16_to_cpu(ev->sync_handle)); + + hci_dev_lock(hdev); + + mask |= hci_proto_connect_ind(hdev, BDADDR_ANY, ISO_LINK, &flags); + if (!(mask & HCI_LM_ACCEPT)) + hci_le_pa_term_sync(hdev, ev->sync_handle); + + hci_dev_unlock(hdev); +} + +#define HCI_LE_EV_VL(_op, _func, _min_len, _max_len) \ +[_op] = { \ + .func = _func, \ + .min_len = _min_len, \ + .max_len = _max_len, \ +} + +#define HCI_LE_EV(_op, _func, _len) \ + HCI_LE_EV_VL(_op, _func, _len, _len) + +#define HCI_LE_EV_STATUS(_op, _func) \ + HCI_LE_EV(_op, _func, sizeof(struct hci_ev_status)) + +/* Entries in this table shall have their position according to the subevent + * opcode they handle, so the use of the macros above is recommended since it does + * attempt to initialize at its proper index using Designated Initializers; that + * way events without a callback function can be omitted.
+ */ +static const struct hci_le_ev { + void (*func)(struct hci_dev *hdev, void *data, struct sk_buff *skb); + u16 min_len; + u16 max_len; +} hci_le_ev_table[U8_MAX + 1] = { + /* [0x01 = HCI_EV_LE_CONN_COMPLETE] */ + HCI_LE_EV(HCI_EV_LE_CONN_COMPLETE, hci_le_conn_complete_evt, + sizeof(struct hci_ev_le_conn_complete)), + /* [0x02 = HCI_EV_LE_ADVERTISING_REPORT] */ + HCI_LE_EV_VL(HCI_EV_LE_ADVERTISING_REPORT, hci_le_adv_report_evt, + sizeof(struct hci_ev_le_advertising_report), + HCI_MAX_EVENT_SIZE), + /* [0x03 = HCI_EV_LE_CONN_UPDATE_COMPLETE] */ + HCI_LE_EV(HCI_EV_LE_CONN_UPDATE_COMPLETE, + hci_le_conn_update_complete_evt, + sizeof(struct hci_ev_le_conn_update_complete)), + /* [0x04 = HCI_EV_LE_REMOTE_FEAT_COMPLETE] */ + HCI_LE_EV(HCI_EV_LE_REMOTE_FEAT_COMPLETE, + hci_le_remote_feat_complete_evt, + sizeof(struct hci_ev_le_remote_feat_complete)), + /* [0x05 = HCI_EV_LE_LTK_REQ] */ + HCI_LE_EV(HCI_EV_LE_LTK_REQ, hci_le_ltk_request_evt, + sizeof(struct hci_ev_le_ltk_req)), + /* [0x06 = HCI_EV_LE_REMOTE_CONN_PARAM_REQ] */ + HCI_LE_EV(HCI_EV_LE_REMOTE_CONN_PARAM_REQ, + hci_le_remote_conn_param_req_evt, + sizeof(struct hci_ev_le_remote_conn_param_req)), + /* [0x0a = HCI_EV_LE_ENHANCED_CONN_COMPLETE] */ + HCI_LE_EV(HCI_EV_LE_ENHANCED_CONN_COMPLETE, + hci_le_enh_conn_complete_evt, + sizeof(struct hci_ev_le_enh_conn_complete)), + /* [0x0b = HCI_EV_LE_DIRECT_ADV_REPORT] */ + HCI_LE_EV_VL(HCI_EV_LE_DIRECT_ADV_REPORT, hci_le_direct_adv_report_evt, + sizeof(struct hci_ev_le_direct_adv_report), + HCI_MAX_EVENT_SIZE), + /* [0x0c = HCI_EV_LE_PHY_UPDATE_COMPLETE] */ + HCI_LE_EV(HCI_EV_LE_PHY_UPDATE_COMPLETE, hci_le_phy_update_evt, + sizeof(struct hci_ev_le_phy_update_complete)), + /* [0x0d = HCI_EV_LE_EXT_ADV_REPORT] */ + HCI_LE_EV_VL(HCI_EV_LE_EXT_ADV_REPORT, hci_le_ext_adv_report_evt, + sizeof(struct hci_ev_le_ext_adv_report), + HCI_MAX_EVENT_SIZE), + /* [0x0e = HCI_EV_LE_PA_SYNC_ESTABLISHED] */ + HCI_LE_EV(HCI_EV_LE_PA_SYNC_ESTABLISHED, + hci_le_pa_sync_estabilished_evt, + sizeof(struct hci_ev_le_pa_sync_established)), + /* [0x12 = HCI_EV_LE_EXT_ADV_SET_TERM] */ + HCI_LE_EV(HCI_EV_LE_EXT_ADV_SET_TERM, hci_le_ext_adv_term_evt, + sizeof(struct hci_evt_le_ext_adv_set_term)), + /* [0x19 = HCI_EVT_LE_CIS_ESTABLISHED] */ + HCI_LE_EV(HCI_EVT_LE_CIS_ESTABLISHED, hci_le_cis_estabilished_evt, + sizeof(struct hci_evt_le_cis_established)), + /* [0x1a = HCI_EVT_LE_CIS_REQ] */ + HCI_LE_EV(HCI_EVT_LE_CIS_REQ, hci_le_cis_req_evt, + sizeof(struct hci_evt_le_cis_req)), + /* [0x1b = HCI_EVT_LE_CREATE_BIG_COMPLETE] */ + HCI_LE_EV_VL(HCI_EVT_LE_CREATE_BIG_COMPLETE, + hci_le_create_big_complete_evt, + sizeof(struct hci_evt_le_create_big_complete), + HCI_MAX_EVENT_SIZE), + /* [0x1d = HCI_EV_LE_BIG_SYNC_ESTABILISHED] */ + HCI_LE_EV_VL(HCI_EVT_LE_BIG_SYNC_ESTABILISHED, + hci_le_big_sync_established_evt, + sizeof(struct hci_evt_le_big_sync_estabilished), + HCI_MAX_EVENT_SIZE), + /* [0x22 = HCI_EVT_LE_BIG_INFO_ADV_REPORT] */ + HCI_LE_EV_VL(HCI_EVT_LE_BIG_INFO_ADV_REPORT, + hci_le_big_info_adv_report_evt, + sizeof(struct hci_evt_le_big_info_adv_report), + HCI_MAX_EVENT_SIZE), +}; + +static void hci_le_meta_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb, u16 *opcode, u8 *status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb) +{ + struct hci_ev_le_meta *ev = data; + const struct hci_le_ev *subev; + + bt_dev_dbg(hdev, "subevent 0x%2.2x", ev->subevent); + + /* Only match event if command OGF is for LE */ + if (hdev->sent_cmd && + hci_opcode_ogf(hci_skb_opcode(hdev->sent_cmd)) == 0x08 && + 
hci_skb_event(hdev->sent_cmd) == ev->subevent) { + *opcode = hci_skb_opcode(hdev->sent_cmd); + hci_req_cmd_complete(hdev, *opcode, 0x00, req_complete, + req_complete_skb); + } + + subev = &hci_le_ev_table[ev->subevent]; + if (!subev->func) + return; + + if (skb->len < subev->min_len) { + bt_dev_err(hdev, "unexpected subevent 0x%2.2x length: %u < %u", + ev->subevent, skb->len, subev->min_len); + return; + } + + /* Just warn if the length is over max_len size it still be + * possible to partially parse the event so leave to callback to + * decide if that is acceptable. + */ + if (skb->len > subev->max_len) + bt_dev_warn(hdev, "unexpected subevent 0x%2.2x length: %u > %u", + ev->subevent, skb->len, subev->max_len); + data = hci_le_ev_skb_pull(hdev, skb, ev->subevent, subev->min_len); + if (!data) + return; + + subev->func(hdev, data, skb); +} + +static bool hci_get_cmd_complete(struct hci_dev *hdev, u16 opcode, + u8 event, struct sk_buff *skb) +{ + struct hci_ev_cmd_complete *ev; + struct hci_event_hdr *hdr; + + if (!skb) + return false; + + hdr = hci_ev_skb_pull(hdev, skb, event, sizeof(*hdr)); + if (!hdr) + return false; + + if (event) { + if (hdr->evt != event) + return false; + return true; + } + + /* Check if request ended in Command Status - no way to retrieve + * any extra parameters in this case. + */ + if (hdr->evt == HCI_EV_CMD_STATUS) + return false; + + if (hdr->evt != HCI_EV_CMD_COMPLETE) { + bt_dev_err(hdev, "last event is not cmd complete (0x%2.2x)", + hdr->evt); + return false; + } + + ev = hci_cc_skb_pull(hdev, skb, opcode, sizeof(*ev)); + if (!ev) + return false; + + if (opcode != __le16_to_cpu(ev->opcode)) { + BT_DBG("opcode doesn't match (0x%2.2x != 0x%2.2x)", opcode, + __le16_to_cpu(ev->opcode)); + return false; + } + + return true; +} + +static void hci_store_wake_reason(struct hci_dev *hdev, u8 event, + struct sk_buff *skb) +{ + struct hci_ev_le_advertising_info *adv; + struct hci_ev_le_direct_adv_info *direct_adv; + struct hci_ev_le_ext_adv_info *ext_adv; + const struct hci_ev_conn_complete *conn_complete = (void *)skb->data; + const struct hci_ev_conn_request *conn_request = (void *)skb->data; + + hci_dev_lock(hdev); + + /* If we are currently suspended and this is the first BT event seen, + * save the wake reason associated with the event. + */ + if (!hdev->suspended || hdev->wake_reason) + goto unlock; + + /* Default to remote wake. Values for wake_reason are documented in the + * Bluez mgmt api docs. + */ + hdev->wake_reason = MGMT_WAKE_REASON_REMOTE_WAKE; + + /* Once configured for remote wakeup, we should only wake up for + * reconnections. It's useful to see which device is waking us up so + * keep track of the bdaddr of the connection event that woke us up. 
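 *
 * Concretely, the handlers below cover the BR/EDR Connection Request and
 * Connection Complete events and, for LE, the legacy, directed and extended
 * advertising report subevents; any other wake-up event is recorded as
 * MGMT_WAKE_REASON_UNEXPECTED.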
+ */ + if (event == HCI_EV_CONN_REQUEST) { + bacpy(&hdev->wake_addr, &conn_request->bdaddr); + hdev->wake_addr_type = BDADDR_BREDR; + } else if (event == HCI_EV_CONN_COMPLETE) { + bacpy(&hdev->wake_addr, &conn_complete->bdaddr); + hdev->wake_addr_type = BDADDR_BREDR; + } else if (event == HCI_EV_LE_META) { + struct hci_ev_le_meta *le_ev = (void *)skb->data; + u8 subevent = le_ev->subevent; + u8 *ptr = &skb->data[sizeof(*le_ev)]; + u8 num_reports = *ptr; + + if ((subevent == HCI_EV_LE_ADVERTISING_REPORT || + subevent == HCI_EV_LE_DIRECT_ADV_REPORT || + subevent == HCI_EV_LE_EXT_ADV_REPORT) && + num_reports) { + adv = (void *)(ptr + 1); + direct_adv = (void *)(ptr + 1); + ext_adv = (void *)(ptr + 1); + + switch (subevent) { + case HCI_EV_LE_ADVERTISING_REPORT: + bacpy(&hdev->wake_addr, &adv->bdaddr); + hdev->wake_addr_type = adv->bdaddr_type; + break; + case HCI_EV_LE_DIRECT_ADV_REPORT: + bacpy(&hdev->wake_addr, &direct_adv->bdaddr); + hdev->wake_addr_type = direct_adv->bdaddr_type; + break; + case HCI_EV_LE_EXT_ADV_REPORT: + bacpy(&hdev->wake_addr, &ext_adv->bdaddr); + hdev->wake_addr_type = ext_adv->bdaddr_type; + break; + } + } + } else { + hdev->wake_reason = MGMT_WAKE_REASON_UNEXPECTED; + } + +unlock: + hci_dev_unlock(hdev); +} + +#define HCI_EV_VL(_op, _func, _min_len, _max_len) \ +[_op] = { \ + .req = false, \ + .func = _func, \ + .min_len = _min_len, \ + .max_len = _max_len, \ +} + +#define HCI_EV(_op, _func, _len) \ + HCI_EV_VL(_op, _func, _len, _len) + +#define HCI_EV_STATUS(_op, _func) \ + HCI_EV(_op, _func, sizeof(struct hci_ev_status)) + +#define HCI_EV_REQ_VL(_op, _func, _min_len, _max_len) \ +[_op] = { \ + .req = true, \ + .func_req = _func, \ + .min_len = _min_len, \ + .max_len = _max_len, \ +} + +#define HCI_EV_REQ(_op, _func, _len) \ + HCI_EV_REQ_VL(_op, _func, _len, _len) + +/* Entries in this table shall have their position according to the event opcode + * they handle, so use of the macros above is recommended: they initialize each + * entry at its proper index using designated initializers, so that events + * without a callback function simply have no entry.
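 *
 * As a small illustration of the pattern, a fixed-length entry such as
 *
 *	HCI_EV(HCI_EV_CONN_COMPLETE, hci_conn_complete_evt,
 *	       sizeof(struct hci_ev_conn_complete))
 *
 * expands to the designated initializer
 *
 *	[0x03] = {
 *		.req = false,
 *		.func = hci_conn_complete_evt,
 *		.min_len = sizeof(struct hci_ev_conn_complete),
 *		.max_len = sizeof(struct hci_ev_conn_complete),
 *	},
 *
 * placing the handler at index 0x03 of the table regardless of where the
 * macro appears in the initializer list.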
+ */ +static const struct hci_ev { + bool req; + union { + void (*func)(struct hci_dev *hdev, void *data, + struct sk_buff *skb); + void (*func_req)(struct hci_dev *hdev, void *data, + struct sk_buff *skb, u16 *opcode, u8 *status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb); + }; + u16 min_len; + u16 max_len; +} hci_ev_table[U8_MAX + 1] = { + /* [0x01 = HCI_EV_INQUIRY_COMPLETE] */ + HCI_EV_STATUS(HCI_EV_INQUIRY_COMPLETE, hci_inquiry_complete_evt), + /* [0x02 = HCI_EV_INQUIRY_RESULT] */ + HCI_EV_VL(HCI_EV_INQUIRY_RESULT, hci_inquiry_result_evt, + sizeof(struct hci_ev_inquiry_result), HCI_MAX_EVENT_SIZE), + /* [0x03 = HCI_EV_CONN_COMPLETE] */ + HCI_EV(HCI_EV_CONN_COMPLETE, hci_conn_complete_evt, + sizeof(struct hci_ev_conn_complete)), + /* [0x04 = HCI_EV_CONN_REQUEST] */ + HCI_EV(HCI_EV_CONN_REQUEST, hci_conn_request_evt, + sizeof(struct hci_ev_conn_request)), + /* [0x05 = HCI_EV_DISCONN_COMPLETE] */ + HCI_EV(HCI_EV_DISCONN_COMPLETE, hci_disconn_complete_evt, + sizeof(struct hci_ev_disconn_complete)), + /* [0x06 = HCI_EV_AUTH_COMPLETE] */ + HCI_EV(HCI_EV_AUTH_COMPLETE, hci_auth_complete_evt, + sizeof(struct hci_ev_auth_complete)), + /* [0x07 = HCI_EV_REMOTE_NAME] */ + HCI_EV(HCI_EV_REMOTE_NAME, hci_remote_name_evt, + sizeof(struct hci_ev_remote_name)), + /* [0x08 = HCI_EV_ENCRYPT_CHANGE] */ + HCI_EV(HCI_EV_ENCRYPT_CHANGE, hci_encrypt_change_evt, + sizeof(struct hci_ev_encrypt_change)), + /* [0x09 = HCI_EV_CHANGE_LINK_KEY_COMPLETE] */ + HCI_EV(HCI_EV_CHANGE_LINK_KEY_COMPLETE, + hci_change_link_key_complete_evt, + sizeof(struct hci_ev_change_link_key_complete)), + /* [0x0b = HCI_EV_REMOTE_FEATURES] */ + HCI_EV(HCI_EV_REMOTE_FEATURES, hci_remote_features_evt, + sizeof(struct hci_ev_remote_features)), + /* [0x0e = HCI_EV_CMD_COMPLETE] */ + HCI_EV_REQ_VL(HCI_EV_CMD_COMPLETE, hci_cmd_complete_evt, + sizeof(struct hci_ev_cmd_complete), HCI_MAX_EVENT_SIZE), + /* [0x0f = HCI_EV_CMD_STATUS] */ + HCI_EV_REQ(HCI_EV_CMD_STATUS, hci_cmd_status_evt, + sizeof(struct hci_ev_cmd_status)), + /* [0x10 = HCI_EV_CMD_STATUS] */ + HCI_EV(HCI_EV_HARDWARE_ERROR, hci_hardware_error_evt, + sizeof(struct hci_ev_hardware_error)), + /* [0x12 = HCI_EV_ROLE_CHANGE] */ + HCI_EV(HCI_EV_ROLE_CHANGE, hci_role_change_evt, + sizeof(struct hci_ev_role_change)), + /* [0x13 = HCI_EV_NUM_COMP_PKTS] */ + HCI_EV_VL(HCI_EV_NUM_COMP_PKTS, hci_num_comp_pkts_evt, + sizeof(struct hci_ev_num_comp_pkts), HCI_MAX_EVENT_SIZE), + /* [0x14 = HCI_EV_MODE_CHANGE] */ + HCI_EV(HCI_EV_MODE_CHANGE, hci_mode_change_evt, + sizeof(struct hci_ev_mode_change)), + /* [0x16 = HCI_EV_PIN_CODE_REQ] */ + HCI_EV(HCI_EV_PIN_CODE_REQ, hci_pin_code_request_evt, + sizeof(struct hci_ev_pin_code_req)), + /* [0x17 = HCI_EV_LINK_KEY_REQ] */ + HCI_EV(HCI_EV_LINK_KEY_REQ, hci_link_key_request_evt, + sizeof(struct hci_ev_link_key_req)), + /* [0x18 = HCI_EV_LINK_KEY_NOTIFY] */ + HCI_EV(HCI_EV_LINK_KEY_NOTIFY, hci_link_key_notify_evt, + sizeof(struct hci_ev_link_key_notify)), + /* [0x1c = HCI_EV_CLOCK_OFFSET] */ + HCI_EV(HCI_EV_CLOCK_OFFSET, hci_clock_offset_evt, + sizeof(struct hci_ev_clock_offset)), + /* [0x1d = HCI_EV_PKT_TYPE_CHANGE] */ + HCI_EV(HCI_EV_PKT_TYPE_CHANGE, hci_pkt_type_change_evt, + sizeof(struct hci_ev_pkt_type_change)), + /* [0x20 = HCI_EV_PSCAN_REP_MODE] */ + HCI_EV(HCI_EV_PSCAN_REP_MODE, hci_pscan_rep_mode_evt, + sizeof(struct hci_ev_pscan_rep_mode)), + /* [0x22 = HCI_EV_INQUIRY_RESULT_WITH_RSSI] */ + HCI_EV_VL(HCI_EV_INQUIRY_RESULT_WITH_RSSI, + hci_inquiry_result_with_rssi_evt, + sizeof(struct 
hci_ev_inquiry_result_rssi), + HCI_MAX_EVENT_SIZE), + /* [0x23 = HCI_EV_REMOTE_EXT_FEATURES] */ + HCI_EV(HCI_EV_REMOTE_EXT_FEATURES, hci_remote_ext_features_evt, + sizeof(struct hci_ev_remote_ext_features)), + /* [0x2c = HCI_EV_SYNC_CONN_COMPLETE] */ + HCI_EV(HCI_EV_SYNC_CONN_COMPLETE, hci_sync_conn_complete_evt, + sizeof(struct hci_ev_sync_conn_complete)), + /* [0x2d = HCI_EV_EXTENDED_INQUIRY_RESULT] */ + HCI_EV_VL(HCI_EV_EXTENDED_INQUIRY_RESULT, + hci_extended_inquiry_result_evt, + sizeof(struct hci_ev_ext_inquiry_result), HCI_MAX_EVENT_SIZE), + /* [0x30 = HCI_EV_KEY_REFRESH_COMPLETE] */ + HCI_EV(HCI_EV_KEY_REFRESH_COMPLETE, hci_key_refresh_complete_evt, + sizeof(struct hci_ev_key_refresh_complete)), + /* [0x31 = HCI_EV_IO_CAPA_REQUEST] */ + HCI_EV(HCI_EV_IO_CAPA_REQUEST, hci_io_capa_request_evt, + sizeof(struct hci_ev_io_capa_request)), + /* [0x32 = HCI_EV_IO_CAPA_REPLY] */ + HCI_EV(HCI_EV_IO_CAPA_REPLY, hci_io_capa_reply_evt, + sizeof(struct hci_ev_io_capa_reply)), + /* [0x33 = HCI_EV_USER_CONFIRM_REQUEST] */ + HCI_EV(HCI_EV_USER_CONFIRM_REQUEST, hci_user_confirm_request_evt, + sizeof(struct hci_ev_user_confirm_req)), + /* [0x34 = HCI_EV_USER_PASSKEY_REQUEST] */ + HCI_EV(HCI_EV_USER_PASSKEY_REQUEST, hci_user_passkey_request_evt, + sizeof(struct hci_ev_user_passkey_req)), + /* [0x35 = HCI_EV_REMOTE_OOB_DATA_REQUEST] */ + HCI_EV(HCI_EV_REMOTE_OOB_DATA_REQUEST, hci_remote_oob_data_request_evt, + sizeof(struct hci_ev_remote_oob_data_request)), + /* [0x36 = HCI_EV_SIMPLE_PAIR_COMPLETE] */ + HCI_EV(HCI_EV_SIMPLE_PAIR_COMPLETE, hci_simple_pair_complete_evt, + sizeof(struct hci_ev_simple_pair_complete)), + /* [0x3b = HCI_EV_USER_PASSKEY_NOTIFY] */ + HCI_EV(HCI_EV_USER_PASSKEY_NOTIFY, hci_user_passkey_notify_evt, + sizeof(struct hci_ev_user_passkey_notify)), + /* [0x3c = HCI_EV_KEYPRESS_NOTIFY] */ + HCI_EV(HCI_EV_KEYPRESS_NOTIFY, hci_keypress_notify_evt, + sizeof(struct hci_ev_keypress_notify)), + /* [0x3d = HCI_EV_REMOTE_HOST_FEATURES] */ + HCI_EV(HCI_EV_REMOTE_HOST_FEATURES, hci_remote_host_features_evt, + sizeof(struct hci_ev_remote_host_features)), + /* [0x3e = HCI_EV_LE_META] */ + HCI_EV_REQ_VL(HCI_EV_LE_META, hci_le_meta_evt, + sizeof(struct hci_ev_le_meta), HCI_MAX_EVENT_SIZE), +#if IS_ENABLED(CONFIG_BT_HS) + /* [0x40 = HCI_EV_PHY_LINK_COMPLETE] */ + HCI_EV(HCI_EV_PHY_LINK_COMPLETE, hci_phy_link_complete_evt, + sizeof(struct hci_ev_phy_link_complete)), + /* [0x41 = HCI_EV_CHANNEL_SELECTED] */ + HCI_EV(HCI_EV_CHANNEL_SELECTED, hci_chan_selected_evt, + sizeof(struct hci_ev_channel_selected)), + /* [0x42 = HCI_EV_DISCONN_PHY_LINK_COMPLETE] */ + HCI_EV(HCI_EV_DISCONN_LOGICAL_LINK_COMPLETE, + hci_disconn_loglink_complete_evt, + sizeof(struct hci_ev_disconn_logical_link_complete)), + /* [0x45 = HCI_EV_LOGICAL_LINK_COMPLETE] */ + HCI_EV(HCI_EV_LOGICAL_LINK_COMPLETE, hci_loglink_complete_evt, + sizeof(struct hci_ev_logical_link_complete)), + /* [0x46 = HCI_EV_DISCONN_LOGICAL_LINK_COMPLETE] */ + HCI_EV(HCI_EV_DISCONN_PHY_LINK_COMPLETE, + hci_disconn_phylink_complete_evt, + sizeof(struct hci_ev_disconn_phy_link_complete)), +#endif + /* [0x48 = HCI_EV_NUM_COMP_BLOCKS] */ + HCI_EV(HCI_EV_NUM_COMP_BLOCKS, hci_num_comp_blocks_evt, + sizeof(struct hci_ev_num_comp_blocks)), + /* [0xff = HCI_EV_VENDOR] */ + HCI_EV_VL(HCI_EV_VENDOR, msft_vendor_evt, 0, HCI_MAX_EVENT_SIZE), +}; + +static void hci_event_func(struct hci_dev *hdev, u8 event, struct sk_buff *skb, + u16 *opcode, u8 *status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb) +{ + const struct hci_ev *ev = 
&hci_ev_table[event]; + void *data; + + if (!ev->func) + return; + + if (skb->len < ev->min_len) { + bt_dev_err(hdev, "unexpected event 0x%2.2x length: %u < %u", + event, skb->len, ev->min_len); + return; + } + + /* Just warn if the length is over max_len size it still be + * possible to partially parse the event so leave to callback to + * decide if that is acceptable. + */ + if (skb->len > ev->max_len) + bt_dev_warn_ratelimited(hdev, + "unexpected event 0x%2.2x length: %u > %u", + event, skb->len, ev->max_len); + + data = hci_ev_skb_pull(hdev, skb, event, ev->min_len); + if (!data) + return; + + if (ev->req) + ev->func_req(hdev, data, skb, opcode, status, req_complete, + req_complete_skb); + else + ev->func(hdev, data, skb); +} + +void hci_event_packet(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct hci_event_hdr *hdr = (void *) skb->data; + hci_req_complete_t req_complete = NULL; + hci_req_complete_skb_t req_complete_skb = NULL; + struct sk_buff *orig_skb = NULL; + u8 status = 0, event, req_evt = 0; + u16 opcode = HCI_OP_NOP; + + if (skb->len < sizeof(*hdr)) { + bt_dev_err(hdev, "Malformed HCI Event"); + goto done; + } + + kfree_skb(hdev->recv_event); + hdev->recv_event = skb_clone(skb, GFP_KERNEL); + + event = hdr->evt; + if (!event) { + bt_dev_warn(hdev, "Received unexpected HCI Event 0x%2.2x", + event); + goto done; + } + + /* Only match event if command OGF is not for LE */ + if (hdev->sent_cmd && + hci_opcode_ogf(hci_skb_opcode(hdev->sent_cmd)) != 0x08 && + hci_skb_event(hdev->sent_cmd) == event) { + hci_req_cmd_complete(hdev, hci_skb_opcode(hdev->sent_cmd), + status, &req_complete, &req_complete_skb); + req_evt = event; + } + + /* If it looks like we might end up having to call + * req_complete_skb, store a pristine copy of the skb since the + * various handlers may modify the original one through + * skb_pull() calls, etc. + */ + if (req_complete_skb || event == HCI_EV_CMD_STATUS || + event == HCI_EV_CMD_COMPLETE) + orig_skb = skb_clone(skb, GFP_KERNEL); + + skb_pull(skb, HCI_EVENT_HDR_SIZE); + + /* Store wake reason if we're suspended */ + hci_store_wake_reason(hdev, event, skb); + + bt_dev_dbg(hdev, "event 0x%2.2x", event); + + hci_event_func(hdev, event, skb, &opcode, &status, &req_complete, + &req_complete_skb); + + if (req_complete) { + req_complete(hdev, status, opcode); + } else if (req_complete_skb) { + if (!hci_get_cmd_complete(hdev, opcode, req_evt, orig_skb)) { + kfree_skb(orig_skb); + orig_skb = NULL; + } + req_complete_skb(hdev, status, opcode, orig_skb); + } + +done: + kfree_skb(orig_skb); + kfree_skb(skb); + hdev->stat.evt_rx++; +} diff --git a/net/bluetooth/hci_request.c b/net/bluetooth/hci_request.c new file mode 100644 index 000000000..f7e006a36 --- /dev/null +++ b/net/bluetooth/hci_request.c @@ -0,0 +1,922 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. 
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#include +#include +#include + +#include "smp.h" +#include "hci_request.h" +#include "msft.h" +#include "eir.h" + +void hci_req_init(struct hci_request *req, struct hci_dev *hdev) +{ + skb_queue_head_init(&req->cmd_q); + req->hdev = hdev; + req->err = 0; +} + +void hci_req_purge(struct hci_request *req) +{ + skb_queue_purge(&req->cmd_q); +} + +bool hci_req_status_pend(struct hci_dev *hdev) +{ + return hdev->req_status == HCI_REQ_PEND; +} + +static int req_run(struct hci_request *req, hci_req_complete_t complete, + hci_req_complete_skb_t complete_skb) +{ + struct hci_dev *hdev = req->hdev; + struct sk_buff *skb; + unsigned long flags; + + bt_dev_dbg(hdev, "length %u", skb_queue_len(&req->cmd_q)); + + /* If an error occurred during request building, remove all HCI + * commands queued on the HCI request queue. + */ + if (req->err) { + skb_queue_purge(&req->cmd_q); + return req->err; + } + + /* Do not allow empty requests */ + if (skb_queue_empty(&req->cmd_q)) + return -ENODATA; + + skb = skb_peek_tail(&req->cmd_q); + if (complete) { + bt_cb(skb)->hci.req_complete = complete; + } else if (complete_skb) { + bt_cb(skb)->hci.req_complete_skb = complete_skb; + bt_cb(skb)->hci.req_flags |= HCI_REQ_SKB; + } + + spin_lock_irqsave(&hdev->cmd_q.lock, flags); + skb_queue_splice_tail(&req->cmd_q, &hdev->cmd_q); + spin_unlock_irqrestore(&hdev->cmd_q.lock, flags); + + queue_work(hdev->workqueue, &hdev->cmd_work); + + return 0; +} + +int hci_req_run(struct hci_request *req, hci_req_complete_t complete) +{ + return req_run(req, complete, NULL); +} + +int hci_req_run_skb(struct hci_request *req, hci_req_complete_skb_t complete) +{ + return req_run(req, NULL, complete); +} + +void hci_req_sync_complete(struct hci_dev *hdev, u8 result, u16 opcode, + struct sk_buff *skb) +{ + bt_dev_dbg(hdev, "result 0x%2.2x", result); + + if (hdev->req_status == HCI_REQ_PEND) { + hdev->req_result = result; + hdev->req_status = HCI_REQ_DONE; + if (skb) + hdev->req_skb = skb_get(skb); + wake_up_interruptible(&hdev->req_wait_q); + } +} + +/* Execute request and wait for completion. */ +int __hci_req_sync(struct hci_dev *hdev, int (*func)(struct hci_request *req, + unsigned long opt), + unsigned long opt, u32 timeout, u8 *hci_status) +{ + struct hci_request req; + int err = 0; + + bt_dev_dbg(hdev, "start"); + + hci_req_init(&req, hdev); + + hdev->req_status = HCI_REQ_PEND; + + err = func(&req, opt); + if (err) { + if (hci_status) + *hci_status = HCI_ERROR_UNSPECIFIED; + return err; + } + + err = hci_req_run_skb(&req, hci_req_sync_complete); + if (err < 0) { + hdev->req_status = 0; + + /* ENODATA means the HCI request command queue is empty. + * This can happen when a request with conditionals doesn't + * trigger any commands to be sent. This is normal behavior + * and should not trigger an error return. 
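 *
 * For illustration, a request-builder callback passed in here normally
 * looks like the sketch below (le_scan_disable_req is a hypothetical
 * name; the helpers it uses are the ones defined in this file). If the
 * builder queues nothing, for instance because scanning is paused, the
 * -ENODATA case described above is taken:
 *
 *	static int le_scan_disable_req(struct hci_request *req,
 *				       unsigned long opt)
 *	{
 *		hci_req_add_le_scan_disable(req, false);
 *		return 0;
 *	}
 *
 *	u8 status;
 *	int err = hci_req_sync(hdev, le_scan_disable_req, 0,
 *			       HCI_CMD_TIMEOUT, &status);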
+ */ + if (err == -ENODATA) { + if (hci_status) + *hci_status = 0; + return 0; + } + + if (hci_status) + *hci_status = HCI_ERROR_UNSPECIFIED; + + return err; + } + + err = wait_event_interruptible_timeout(hdev->req_wait_q, + hdev->req_status != HCI_REQ_PEND, timeout); + + if (err == -ERESTARTSYS) + return -EINTR; + + switch (hdev->req_status) { + case HCI_REQ_DONE: + err = -bt_to_errno(hdev->req_result); + if (hci_status) + *hci_status = hdev->req_result; + break; + + case HCI_REQ_CANCELED: + err = -hdev->req_result; + if (hci_status) + *hci_status = HCI_ERROR_UNSPECIFIED; + break; + + default: + err = -ETIMEDOUT; + if (hci_status) + *hci_status = HCI_ERROR_UNSPECIFIED; + break; + } + + kfree_skb(hdev->req_skb); + hdev->req_skb = NULL; + hdev->req_status = hdev->req_result = 0; + + bt_dev_dbg(hdev, "end: err %d", err); + + return err; +} + +int hci_req_sync(struct hci_dev *hdev, int (*req)(struct hci_request *req, + unsigned long opt), + unsigned long opt, u32 timeout, u8 *hci_status) +{ + int ret; + + /* Serialize all requests */ + hci_req_sync_lock(hdev); + /* check the state after obtaing the lock to protect the HCI_UP + * against any races from hci_dev_do_close when the controller + * gets removed. + */ + if (test_bit(HCI_UP, &hdev->flags)) + ret = __hci_req_sync(hdev, req, opt, timeout, hci_status); + else + ret = -ENETDOWN; + hci_req_sync_unlock(hdev); + + return ret; +} + +struct sk_buff *hci_prepare_cmd(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param) +{ + int len = HCI_COMMAND_HDR_SIZE + plen; + struct hci_command_hdr *hdr; + struct sk_buff *skb; + + skb = bt_skb_alloc(len, GFP_ATOMIC); + if (!skb) + return NULL; + + hdr = skb_put(skb, HCI_COMMAND_HDR_SIZE); + hdr->opcode = cpu_to_le16(opcode); + hdr->plen = plen; + + if (plen) + skb_put_data(skb, param, plen); + + bt_dev_dbg(hdev, "skb len %d", skb->len); + + hci_skb_pkt_type(skb) = HCI_COMMAND_PKT; + hci_skb_opcode(skb) = opcode; + + return skb; +} + +/* Queue a command to an asynchronous HCI request */ +void hci_req_add_ev(struct hci_request *req, u16 opcode, u32 plen, + const void *param, u8 event) +{ + struct hci_dev *hdev = req->hdev; + struct sk_buff *skb; + + bt_dev_dbg(hdev, "opcode 0x%4.4x plen %d", opcode, plen); + + /* If an error occurred during request building, there is no point in + * queueing the HCI command. We can simply return. 
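 *
 * For reference, the usual way these helpers are combined for an
 * asynchronous request is roughly the following sketch (my_complete
 * stands in for a caller-supplied hci_req_complete_t callback):
 *
 *	struct hci_request req;
 *	__u8 enable = 0x00;
 *
 *	hci_req_init(&req, hdev);
 *	hci_req_add(&req, HCI_OP_LE_SET_ADDR_RESOLV_ENABLE,
 *		    sizeof(enable), &enable);
 *	hci_req_run(&req, my_complete);
 *
 * Any error recorded in req->err while building causes hci_req_run()
 * (via req_run()) to purge the queued commands and return that error.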
+ */ + if (req->err) + return; + + skb = hci_prepare_cmd(hdev, opcode, plen, param); + if (!skb) { + bt_dev_err(hdev, "no memory for command (opcode 0x%4.4x)", + opcode); + req->err = -ENOMEM; + return; + } + + if (skb_queue_empty(&req->cmd_q)) + bt_cb(skb)->hci.req_flags |= HCI_REQ_START; + + hci_skb_event(skb) = event; + + skb_queue_tail(&req->cmd_q, skb); +} + +void hci_req_add(struct hci_request *req, u16 opcode, u32 plen, + const void *param) +{ + bt_dev_dbg(req->hdev, "HCI_REQ-0x%4.4x", opcode); + hci_req_add_ev(req, opcode, plen, param, 0); +} + +static void start_interleave_scan(struct hci_dev *hdev) +{ + hdev->interleave_scan_state = INTERLEAVE_SCAN_NO_FILTER; + queue_delayed_work(hdev->req_workqueue, + &hdev->interleave_scan, 0); +} + +static bool is_interleave_scanning(struct hci_dev *hdev) +{ + return hdev->interleave_scan_state != INTERLEAVE_SCAN_NONE; +} + +static void cancel_interleave_scan(struct hci_dev *hdev) +{ + bt_dev_dbg(hdev, "cancelling interleave scan"); + + cancel_delayed_work_sync(&hdev->interleave_scan); + + hdev->interleave_scan_state = INTERLEAVE_SCAN_NONE; +} + +/* Return true if interleave_scan wasn't started until exiting this function, + * otherwise, return false + */ +static bool __hci_update_interleaved_scan(struct hci_dev *hdev) +{ + /* Do interleaved scan only if all of the following are true: + * - There is at least one ADV monitor + * - At least one pending LE connection or one device to be scanned for + * - Monitor offloading is not supported + * If so, we should alternate between allowlist scan and one without + * any filters to save power. + */ + bool use_interleaving = hci_is_adv_monitoring(hdev) && + !(list_empty(&hdev->pend_le_conns) && + list_empty(&hdev->pend_le_reports)) && + hci_get_adv_monitor_offload_ext(hdev) == + HCI_ADV_MONITOR_EXT_NONE; + bool is_interleaving = is_interleave_scanning(hdev); + + if (use_interleaving && !is_interleaving) { + start_interleave_scan(hdev); + bt_dev_dbg(hdev, "starting interleave scan"); + return true; + } + + if (!use_interleaving && is_interleaving) + cancel_interleave_scan(hdev); + + return false; +} + +void hci_req_add_le_scan_disable(struct hci_request *req, bool rpa_le_conn) +{ + struct hci_dev *hdev = req->hdev; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return; + } + + if (use_ext_scan(hdev)) { + struct hci_cp_le_set_ext_scan_enable cp; + + memset(&cp, 0, sizeof(cp)); + cp.enable = LE_SCAN_DISABLE; + hci_req_add(req, HCI_OP_LE_SET_EXT_SCAN_ENABLE, sizeof(cp), + &cp); + } else { + struct hci_cp_le_set_scan_enable cp; + + memset(&cp, 0, sizeof(cp)); + cp.enable = LE_SCAN_DISABLE; + hci_req_add(req, HCI_OP_LE_SET_SCAN_ENABLE, sizeof(cp), &cp); + } + + /* Disable address resolution */ + if (hci_dev_test_flag(hdev, HCI_LL_RPA_RESOLUTION) && !rpa_le_conn) { + __u8 enable = 0x00; + + hci_req_add(req, HCI_OP_LE_SET_ADDR_RESOLV_ENABLE, 1, &enable); + } +} + +static void del_from_accept_list(struct hci_request *req, bdaddr_t *bdaddr, + u8 bdaddr_type) +{ + struct hci_cp_le_del_from_accept_list cp; + + cp.bdaddr_type = bdaddr_type; + bacpy(&cp.bdaddr, bdaddr); + + bt_dev_dbg(req->hdev, "Remove %pMR (0x%x) from accept list", &cp.bdaddr, + cp.bdaddr_type); + hci_req_add(req, HCI_OP_LE_DEL_FROM_ACCEPT_LIST, sizeof(cp), &cp); + + if (use_ll_privacy(req->hdev)) { + struct smp_irk *irk; + + irk = hci_find_irk_by_addr(req->hdev, bdaddr, bdaddr_type); + if (irk) { + struct hci_cp_le_del_from_resolv_list cp; + + cp.bdaddr_type = bdaddr_type; + bacpy(&cp.bdaddr, bdaddr); + + 
hci_req_add(req, HCI_OP_LE_DEL_FROM_RESOLV_LIST, + sizeof(cp), &cp); + } + } +} + +/* Adds connection to accept list if needed. On error, returns -1. */ +static int add_to_accept_list(struct hci_request *req, + struct hci_conn_params *params, u8 *num_entries, + bool allow_rpa) +{ + struct hci_cp_le_add_to_accept_list cp; + struct hci_dev *hdev = req->hdev; + + /* Already in accept list */ + if (hci_bdaddr_list_lookup(&hdev->le_accept_list, ¶ms->addr, + params->addr_type)) + return 0; + + /* Select filter policy to accept all advertising */ + if (*num_entries >= hdev->le_accept_list_size) + return -1; + + /* Accept list can not be used with RPAs */ + if (!allow_rpa && + !hci_dev_test_flag(hdev, HCI_ENABLE_LL_PRIVACY) && + hci_find_irk_by_addr(hdev, ¶ms->addr, params->addr_type)) { + return -1; + } + + /* During suspend, only wakeable devices can be in accept list */ + if (hdev->suspended && + !(params->flags & HCI_CONN_FLAG_REMOTE_WAKEUP)) + return 0; + + *num_entries += 1; + cp.bdaddr_type = params->addr_type; + bacpy(&cp.bdaddr, ¶ms->addr); + + bt_dev_dbg(hdev, "Add %pMR (0x%x) to accept list", &cp.bdaddr, + cp.bdaddr_type); + hci_req_add(req, HCI_OP_LE_ADD_TO_ACCEPT_LIST, sizeof(cp), &cp); + + if (use_ll_privacy(hdev)) { + struct smp_irk *irk; + + irk = hci_find_irk_by_addr(hdev, ¶ms->addr, + params->addr_type); + if (irk) { + struct hci_cp_le_add_to_resolv_list cp; + + cp.bdaddr_type = params->addr_type; + bacpy(&cp.bdaddr, ¶ms->addr); + memcpy(cp.peer_irk, irk->val, 16); + + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) + memcpy(cp.local_irk, hdev->irk, 16); + else + memset(cp.local_irk, 0, 16); + + hci_req_add(req, HCI_OP_LE_ADD_TO_RESOLV_LIST, + sizeof(cp), &cp); + } + } + + return 0; +} + +static u8 update_accept_list(struct hci_request *req) +{ + struct hci_dev *hdev = req->hdev; + struct hci_conn_params *params; + struct bdaddr_list *b; + u8 num_entries = 0; + bool pend_conn, pend_report; + /* We allow usage of accept list even with RPAs in suspend. In the worst + * case, we won't be able to wake from devices that use the privacy1.2 + * features. Additionally, once we support privacy1.2 and IRK + * offloading, we can update this to also check for those conditions. + */ + bool allow_rpa = hdev->suspended; + + if (use_ll_privacy(hdev)) + allow_rpa = true; + + /* Go through the current accept list programmed into the + * controller one by one and check if that address is still + * in the list of pending connections or list of devices to + * report. If not present in either list, then queue the + * command to remove it from the controller. + */ + list_for_each_entry(b, &hdev->le_accept_list, list) { + pend_conn = hci_pend_le_action_lookup(&hdev->pend_le_conns, + &b->bdaddr, + b->bdaddr_type); + pend_report = hci_pend_le_action_lookup(&hdev->pend_le_reports, + &b->bdaddr, + b->bdaddr_type); + + /* If the device is not likely to connect or report, + * remove it from the accept list. + */ + if (!pend_conn && !pend_report) { + del_from_accept_list(req, &b->bdaddr, b->bdaddr_type); + continue; + } + + /* Accept list can not be used with RPAs */ + if (!allow_rpa && + !hci_dev_test_flag(hdev, HCI_ENABLE_LL_PRIVACY) && + hci_find_irk_by_addr(hdev, &b->bdaddr, b->bdaddr_type)) { + return 0x00; + } + + num_entries++; + } + + /* Since all no longer valid accept list entries have been + * removed, walk through the list of pending connections + * and ensure that any new device gets programmed into + * the controller. 
+ * + * If the list of the devices is larger than the list of + * available accept list entries in the controller, then + * just abort and return filer policy value to not use the + * accept list. + */ + list_for_each_entry(params, &hdev->pend_le_conns, action) { + if (add_to_accept_list(req, params, &num_entries, allow_rpa)) + return 0x00; + } + + /* After adding all new pending connections, walk through + * the list of pending reports and also add these to the + * accept list if there is still space. Abort if space runs out. + */ + list_for_each_entry(params, &hdev->pend_le_reports, action) { + if (add_to_accept_list(req, params, &num_entries, allow_rpa)) + return 0x00; + } + + /* Use the allowlist unless the following conditions are all true: + * - We are not currently suspending + * - There are 1 or more ADV monitors registered and it's not offloaded + * - Interleaved scanning is not currently using the allowlist + */ + if (!idr_is_empty(&hdev->adv_monitors_idr) && !hdev->suspended && + hci_get_adv_monitor_offload_ext(hdev) == HCI_ADV_MONITOR_EXT_NONE && + hdev->interleave_scan_state != INTERLEAVE_SCAN_ALLOWLIST) + return 0x00; + + /* Select filter policy to use accept list */ + return 0x01; +} + +static bool scan_use_rpa(struct hci_dev *hdev) +{ + return hci_dev_test_flag(hdev, HCI_PRIVACY); +} + +static void hci_req_start_scan(struct hci_request *req, u8 type, u16 interval, + u16 window, u8 own_addr_type, u8 filter_policy, + bool filter_dup, bool addr_resolv) +{ + struct hci_dev *hdev = req->hdev; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return; + } + + if (use_ll_privacy(hdev) && addr_resolv) { + u8 enable = 0x01; + + hci_req_add(req, HCI_OP_LE_SET_ADDR_RESOLV_ENABLE, 1, &enable); + } + + /* Use ext scanning if set ext scan param and ext scan enable is + * supported + */ + if (use_ext_scan(hdev)) { + struct hci_cp_le_set_ext_scan_params *ext_param_cp; + struct hci_cp_le_set_ext_scan_enable ext_enable_cp; + struct hci_cp_le_scan_phy_params *phy_params; + u8 data[sizeof(*ext_param_cp) + sizeof(*phy_params) * 2]; + u32 plen; + + ext_param_cp = (void *)data; + phy_params = (void *)ext_param_cp->data; + + memset(ext_param_cp, 0, sizeof(*ext_param_cp)); + ext_param_cp->own_addr_type = own_addr_type; + ext_param_cp->filter_policy = filter_policy; + + plen = sizeof(*ext_param_cp); + + if (scan_1m(hdev) || scan_2m(hdev)) { + ext_param_cp->scanning_phys |= LE_SCAN_PHY_1M; + + memset(phy_params, 0, sizeof(*phy_params)); + phy_params->type = type; + phy_params->interval = cpu_to_le16(interval); + phy_params->window = cpu_to_le16(window); + + plen += sizeof(*phy_params); + phy_params++; + } + + if (scan_coded(hdev)) { + ext_param_cp->scanning_phys |= LE_SCAN_PHY_CODED; + + memset(phy_params, 0, sizeof(*phy_params)); + phy_params->type = type; + phy_params->interval = cpu_to_le16(interval); + phy_params->window = cpu_to_le16(window); + + plen += sizeof(*phy_params); + phy_params++; + } + + hci_req_add(req, HCI_OP_LE_SET_EXT_SCAN_PARAMS, + plen, ext_param_cp); + + memset(&ext_enable_cp, 0, sizeof(ext_enable_cp)); + ext_enable_cp.enable = LE_SCAN_ENABLE; + ext_enable_cp.filter_dup = filter_dup; + + hci_req_add(req, HCI_OP_LE_SET_EXT_SCAN_ENABLE, + sizeof(ext_enable_cp), &ext_enable_cp); + } else { + struct hci_cp_le_set_scan_param param_cp; + struct hci_cp_le_set_scan_enable enable_cp; + + memset(¶m_cp, 0, sizeof(param_cp)); + param_cp.type = type; + param_cp.interval = cpu_to_le16(interval); + param_cp.window = cpu_to_le16(window); + 
param_cp.own_address_type = own_addr_type; + param_cp.filter_policy = filter_policy; + hci_req_add(req, HCI_OP_LE_SET_SCAN_PARAM, sizeof(param_cp), + ¶m_cp); + + memset(&enable_cp, 0, sizeof(enable_cp)); + enable_cp.enable = LE_SCAN_ENABLE; + enable_cp.filter_dup = filter_dup; + hci_req_add(req, HCI_OP_LE_SET_SCAN_ENABLE, sizeof(enable_cp), + &enable_cp); + } +} + +/* Returns true if an le connection is in the scanning state */ +static inline bool hci_is_le_conn_scanning(struct hci_dev *hdev) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *c; + + rcu_read_lock(); + + list_for_each_entry_rcu(c, &h->list, list) { + if (c->type == LE_LINK && c->state == BT_CONNECT && + test_bit(HCI_CONN_SCANNING, &c->flags)) { + rcu_read_unlock(); + return true; + } + } + + rcu_read_unlock(); + + return false; +} + +static void set_random_addr(struct hci_request *req, bdaddr_t *rpa); +static int hci_update_random_address(struct hci_request *req, + bool require_privacy, bool use_rpa, + u8 *own_addr_type) +{ + struct hci_dev *hdev = req->hdev; + int err; + + /* If privacy is enabled use a resolvable private address. If + * current RPA has expired or there is something else than + * the current RPA in use, then generate a new one. + */ + if (use_rpa) { + /* If Controller supports LL Privacy use own address type is + * 0x03 + */ + if (use_ll_privacy(hdev)) + *own_addr_type = ADDR_LE_DEV_RANDOM_RESOLVED; + else + *own_addr_type = ADDR_LE_DEV_RANDOM; + + if (rpa_valid(hdev)) + return 0; + + err = smp_generate_rpa(hdev, hdev->irk, &hdev->rpa); + if (err < 0) { + bt_dev_err(hdev, "failed to generate new RPA"); + return err; + } + + set_random_addr(req, &hdev->rpa); + + return 0; + } + + /* In case of required privacy without resolvable private address, + * use an non-resolvable private address. This is useful for active + * scanning and non-connectable advertising. + */ + if (require_privacy) { + bdaddr_t nrpa; + + while (true) { + /* The non-resolvable private address is generated + * from random six bytes with the two most significant + * bits cleared. + */ + get_random_bytes(&nrpa, 6); + nrpa.b[5] &= 0x3f; + + /* The non-resolvable private address shall not be + * equal to the public address. + */ + if (bacmp(&hdev->bdaddr, &nrpa)) + break; + } + + *own_addr_type = ADDR_LE_DEV_RANDOM; + set_random_addr(req, &nrpa); + return 0; + } + + /* If forcing static address is in use or there is no public + * address use the static address as random address (but skip + * the HCI command if the current random address is already the + * static one. + * + * In case BR/EDR has been disabled on a dual-mode controller + * and a static address has been configured, then use that + * address instead of the public BR/EDR address. + */ + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + !bacmp(&hdev->bdaddr, BDADDR_ANY) || + (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + bacmp(&hdev->static_addr, BDADDR_ANY))) { + *own_addr_type = ADDR_LE_DEV_RANDOM; + if (bacmp(&hdev->static_addr, &hdev->random_addr)) + hci_req_add(req, HCI_OP_LE_SET_RANDOM_ADDR, 6, + &hdev->static_addr); + return 0; + } + + /* Neither privacy nor static address is being used so use a + * public address. + */ + *own_addr_type = ADDR_LE_DEV_PUBLIC; + + return 0; +} + +/* Ensure to call hci_req_add_le_scan_disable() first to disable the + * controller based address resolution to be able to reconfigure + * resolving list. 
+ */ +void hci_req_add_le_passive_scan(struct hci_request *req) +{ + struct hci_dev *hdev = req->hdev; + u8 own_addr_type; + u8 filter_policy; + u16 window, interval; + /* Default is to enable duplicates filter */ + u8 filter_dup = LE_SCAN_FILTER_DUP_ENABLE; + /* Background scanning should run with address resolution */ + bool addr_resolv = true; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return; + } + + /* Set require_privacy to false since no SCAN_REQ are send + * during passive scanning. Not using an non-resolvable address + * here is important so that peer devices using direct + * advertising with our address will be correctly reported + * by the controller. + */ + if (hci_update_random_address(req, false, scan_use_rpa(hdev), + &own_addr_type)) + return; + + if (hdev->enable_advmon_interleave_scan && + __hci_update_interleaved_scan(hdev)) + return; + + bt_dev_dbg(hdev, "interleave state %d", hdev->interleave_scan_state); + /* Adding or removing entries from the accept list must + * happen before enabling scanning. The controller does + * not allow accept list modification while scanning. + */ + filter_policy = update_accept_list(req); + + /* When the controller is using random resolvable addresses and + * with that having LE privacy enabled, then controllers with + * Extended Scanner Filter Policies support can now enable support + * for handling directed advertising. + * + * So instead of using filter polices 0x00 (no accept list) + * and 0x01 (accept list enabled) use the new filter policies + * 0x02 (no accept list) and 0x03 (accept list enabled). + */ + if (hci_dev_test_flag(hdev, HCI_PRIVACY) && + (hdev->le_features[0] & HCI_LE_EXT_SCAN_POLICY)) + filter_policy |= 0x02; + + if (hdev->suspended) { + window = hdev->le_scan_window_suspend; + interval = hdev->le_scan_int_suspend; + } else if (hci_is_le_conn_scanning(hdev)) { + window = hdev->le_scan_window_connect; + interval = hdev->le_scan_int_connect; + } else if (hci_is_adv_monitoring(hdev)) { + window = hdev->le_scan_window_adv_monitor; + interval = hdev->le_scan_int_adv_monitor; + + /* Disable duplicates filter when scanning for advertisement + * monitor for the following reasons. + * + * For HW pattern filtering (ex. MSFT), Realtek and Qualcomm + * controllers ignore RSSI_Sampling_Period when the duplicates + * filter is enabled. + * + * For SW pattern filtering, when we're not doing interleaved + * scanning, it is necessary to disable duplicates filter, + * otherwise hosts can only receive one advertisement and it's + * impossible to know if a peer is still in range. 
+ */ + filter_dup = LE_SCAN_FILTER_DUP_DISABLE; + } else { + window = hdev->le_scan_window; + interval = hdev->le_scan_interval; + } + + bt_dev_dbg(hdev, "LE passive scan with accept list = %d", + filter_policy); + hci_req_start_scan(req, LE_SCAN_PASSIVE, interval, window, + own_addr_type, filter_policy, filter_dup, + addr_resolv); +} + +static int hci_req_add_le_interleaved_scan(struct hci_request *req, + unsigned long opt) +{ + struct hci_dev *hdev = req->hdev; + int ret = 0; + + hci_dev_lock(hdev); + + if (hci_dev_test_flag(hdev, HCI_LE_SCAN)) + hci_req_add_le_scan_disable(req, false); + hci_req_add_le_passive_scan(req); + + switch (hdev->interleave_scan_state) { + case INTERLEAVE_SCAN_ALLOWLIST: + bt_dev_dbg(hdev, "next state: allowlist"); + hdev->interleave_scan_state = INTERLEAVE_SCAN_NO_FILTER; + break; + case INTERLEAVE_SCAN_NO_FILTER: + bt_dev_dbg(hdev, "next state: no filter"); + hdev->interleave_scan_state = INTERLEAVE_SCAN_ALLOWLIST; + break; + case INTERLEAVE_SCAN_NONE: + BT_ERR("unexpected error"); + ret = -1; + } + + hci_dev_unlock(hdev); + + return ret; +} + +static void interleave_scan_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + interleave_scan.work); + u8 status; + unsigned long timeout; + + if (hdev->interleave_scan_state == INTERLEAVE_SCAN_ALLOWLIST) { + timeout = msecs_to_jiffies(hdev->advmon_allowlist_duration); + } else if (hdev->interleave_scan_state == INTERLEAVE_SCAN_NO_FILTER) { + timeout = msecs_to_jiffies(hdev->advmon_no_filter_duration); + } else { + bt_dev_err(hdev, "unexpected error"); + return; + } + + hci_req_sync(hdev, hci_req_add_le_interleaved_scan, 0, + HCI_CMD_TIMEOUT, &status); + + /* Don't continue interleaving if it was canceled */ + if (is_interleave_scanning(hdev)) + queue_delayed_work(hdev->req_workqueue, + &hdev->interleave_scan, timeout); +} + +static void set_random_addr(struct hci_request *req, bdaddr_t *rpa) +{ + struct hci_dev *hdev = req->hdev; + + /* If we're advertising or initiating an LE connection we can't + * go ahead and change the random address at this time. This is + * because the eventual initiator address used for the + * subsequently created connection will be undefined (some + * controllers use the new address and others the one we had + * when the operation started). + * + * In this kind of scenario skip the update and let the random + * address be updated at the next cycle. 
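 *
 * The deferral is recorded by setting HCI_RPA_EXPIRED below, so a later
 * address update cycle will generate and program a fresh RPA once the
 * advertising or connection attempt has finished.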
+ */ + if (hci_dev_test_flag(hdev, HCI_LE_ADV) || + hci_lookup_le_connect(hdev)) { + bt_dev_dbg(hdev, "Deferring random address update"); + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + return; + } + + hci_req_add(req, HCI_OP_LE_SET_RANDOM_ADDR, 6, rpa); +} + +void hci_request_setup(struct hci_dev *hdev) +{ + INIT_DELAYED_WORK(&hdev->interleave_scan, interleave_scan_work); +} + +void hci_request_cancel_all(struct hci_dev *hdev) +{ + __hci_cmd_sync_cancel(hdev, ENODEV); + + cancel_interleave_scan(hdev); +} diff --git a/net/bluetooth/hci_request.h b/net/bluetooth/hci_request.h new file mode 100644 index 000000000..0be75cf0e --- /dev/null +++ b/net/bluetooth/hci_request.h @@ -0,0 +1,75 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#define HCI_REQ_DONE 0 +#define HCI_REQ_PEND 1 +#define HCI_REQ_CANCELED 2 + +#define hci_req_sync_lock(hdev) mutex_lock(&hdev->req_lock) +#define hci_req_sync_unlock(hdev) mutex_unlock(&hdev->req_lock) + +#define HCI_REQ_DONE 0 +#define HCI_REQ_PEND 1 +#define HCI_REQ_CANCELED 2 + +struct hci_request { + struct hci_dev *hdev; + struct sk_buff_head cmd_q; + + /* If something goes wrong when building the HCI request, the error + * value is stored in this field. 
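 *
 * It is set by hci_req_add_ev(), for example to -ENOMEM when allocating
 * a command fails, and a subsequent hci_req_run() discards the queued
 * commands and returns this value instead of submitting the request.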
+ */ + int err; +}; + +void hci_req_init(struct hci_request *req, struct hci_dev *hdev); +void hci_req_purge(struct hci_request *req); +bool hci_req_status_pend(struct hci_dev *hdev); +int hci_req_run(struct hci_request *req, hci_req_complete_t complete); +int hci_req_run_skb(struct hci_request *req, hci_req_complete_skb_t complete); +void hci_req_sync_complete(struct hci_dev *hdev, u8 result, u16 opcode, + struct sk_buff *skb); +void hci_req_add(struct hci_request *req, u16 opcode, u32 plen, + const void *param); +void hci_req_add_ev(struct hci_request *req, u16 opcode, u32 plen, + const void *param, u8 event); +void hci_req_cmd_complete(struct hci_dev *hdev, u16 opcode, u8 status, + hci_req_complete_t *req_complete, + hci_req_complete_skb_t *req_complete_skb); + +int hci_req_sync(struct hci_dev *hdev, int (*req)(struct hci_request *req, + unsigned long opt), + unsigned long opt, u32 timeout, u8 *hci_status); +int __hci_req_sync(struct hci_dev *hdev, int (*func)(struct hci_request *req, + unsigned long opt), + unsigned long opt, u32 timeout, u8 *hci_status); + +struct sk_buff *hci_prepare_cmd(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param); + +void hci_req_add_le_scan_disable(struct hci_request *req, bool rpa_le_conn); +void hci_req_add_le_passive_scan(struct hci_request *req); + +void hci_request_setup(struct hci_dev *hdev); +void hci_request_cancel_all(struct hci_dev *hdev); diff --git a/net/bluetooth/hci_sock.c b/net/bluetooth/hci_sock.c new file mode 100644 index 000000000..484fc2a8e --- /dev/null +++ b/net/bluetooth/hci_sock.c @@ -0,0 +1,2208 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth HCI sockets. 
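 *
 * Each socket is bound to a single HCI channel; the code below mainly
 * deals with the raw, user, monitor and control channels, filtering and
 * forwarding frames to bound sockets and emitting monitor records that
 * describe controller and socket activity.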
*/ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "mgmt_util.h" + +static LIST_HEAD(mgmt_chan_list); +static DEFINE_MUTEX(mgmt_chan_list_lock); + +static DEFINE_IDA(sock_cookie_ida); + +static atomic_t monitor_promisc = ATOMIC_INIT(0); + +/* ----- HCI socket interface ----- */ + +/* Socket info */ +#define hci_pi(sk) ((struct hci_pinfo *) sk) + +struct hci_pinfo { + struct bt_sock bt; + struct hci_dev *hdev; + struct hci_filter filter; + __u8 cmsg_mask; + unsigned short channel; + unsigned long flags; + __u32 cookie; + char comm[TASK_COMM_LEN]; + __u16 mtu; +}; + +static struct hci_dev *hci_hdev_from_sock(struct sock *sk) +{ + struct hci_dev *hdev = hci_pi(sk)->hdev; + + if (!hdev) + return ERR_PTR(-EBADFD); + if (hci_dev_test_flag(hdev, HCI_UNREGISTER)) + return ERR_PTR(-EPIPE); + return hdev; +} + +void hci_sock_set_flag(struct sock *sk, int nr) +{ + set_bit(nr, &hci_pi(sk)->flags); +} + +void hci_sock_clear_flag(struct sock *sk, int nr) +{ + clear_bit(nr, &hci_pi(sk)->flags); +} + +int hci_sock_test_flag(struct sock *sk, int nr) +{ + return test_bit(nr, &hci_pi(sk)->flags); +} + +unsigned short hci_sock_get_channel(struct sock *sk) +{ + return hci_pi(sk)->channel; +} + +u32 hci_sock_get_cookie(struct sock *sk) +{ + return hci_pi(sk)->cookie; +} + +static bool hci_sock_gen_cookie(struct sock *sk) +{ + int id = hci_pi(sk)->cookie; + + if (!id) { + id = ida_simple_get(&sock_cookie_ida, 1, 0, GFP_KERNEL); + if (id < 0) + id = 0xffffffff; + + hci_pi(sk)->cookie = id; + get_task_comm(hci_pi(sk)->comm, current); + return true; + } + + return false; +} + +static void hci_sock_free_cookie(struct sock *sk) +{ + int id = hci_pi(sk)->cookie; + + if (id) { + hci_pi(sk)->cookie = 0xffffffff; + ida_simple_remove(&sock_cookie_ida, id); + } +} + +static inline int hci_test_bit(int nr, const void *addr) +{ + return *((const __u32 *) addr + (nr >> 5)) & ((__u32) 1 << (nr & 31)); +} + +/* Security filter */ +#define HCI_SFLT_MAX_OGF 5 + +struct hci_sec_filter { + __u32 type_mask; + __u32 event_mask[2]; + __u32 ocf_mask[HCI_SFLT_MAX_OGF + 1][4]; +}; + +static const struct hci_sec_filter hci_sec_filter = { + /* Packet types */ + 0x10, + /* Events */ + { 0x1000d9fe, 0x0000b00c }, + /* Commands */ + { + { 0x0 }, + /* OGF_LINK_CTL */ + { 0xbe000006, 0x00000001, 0x00000000, 0x00 }, + /* OGF_LINK_POLICY */ + { 0x00005200, 0x00000000, 0x00000000, 0x00 }, + /* OGF_HOST_CTL */ + { 0xaab00200, 0x2b402aaa, 0x05220154, 0x00 }, + /* OGF_INFO_PARAM */ + { 0x000002be, 0x00000000, 0x00000000, 0x00 }, + /* OGF_STATUS_PARAM */ + { 0x000000ea, 0x00000000, 0x00000000, 0x00 } + } +}; + +static struct bt_sock_list hci_sk_list = { + .lock = __RW_LOCK_UNLOCKED(hci_sk_list.lock) +}; + +static bool is_filtered_packet(struct sock *sk, struct sk_buff *skb) +{ + struct hci_filter *flt; + int flt_type, flt_event; + + /* Apply filter */ + flt = &hci_pi(sk)->filter; + + flt_type = hci_skb_pkt_type(skb) & HCI_FLT_TYPE_BITS; + + if (!test_bit(flt_type, &flt->type_mask)) + return true; + + /* Extra filter for event packets only */ + if (hci_skb_pkt_type(skb) != HCI_EVENT_PKT) + return false; + + flt_event = (*(__u8 *)skb->data & HCI_FLT_EVENT_BITS); + + if (!hci_test_bit(flt_event, &flt->event_mask)) + return true; + + /* Check filter only when opcode is set */ + if (!flt->opcode) + return false; + + if (flt_event == HCI_EV_CMD_COMPLETE && + flt->opcode != get_unaligned((__le16 *)(skb->data + 3))) + return true; + + if (flt_event == HCI_EV_CMD_STATUS && + flt->opcode != 
get_unaligned((__le16 *)(skb->data + 4))) + return true; + + return false; +} + +/* Send frame to RAW socket */ +void hci_send_to_sock(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct sock *sk; + struct sk_buff *skb_copy = NULL; + + BT_DBG("hdev %p len %d", hdev, skb->len); + + read_lock(&hci_sk_list.lock); + + sk_for_each(sk, &hci_sk_list.head) { + struct sk_buff *nskb; + + if (sk->sk_state != BT_BOUND || hci_pi(sk)->hdev != hdev) + continue; + + /* Don't send frame to the socket it came from */ + if (skb->sk == sk) + continue; + + if (hci_pi(sk)->channel == HCI_CHANNEL_RAW) { + if (hci_skb_pkt_type(skb) != HCI_COMMAND_PKT && + hci_skb_pkt_type(skb) != HCI_EVENT_PKT && + hci_skb_pkt_type(skb) != HCI_ACLDATA_PKT && + hci_skb_pkt_type(skb) != HCI_SCODATA_PKT && + hci_skb_pkt_type(skb) != HCI_ISODATA_PKT) + continue; + if (is_filtered_packet(sk, skb)) + continue; + } else if (hci_pi(sk)->channel == HCI_CHANNEL_USER) { + if (!bt_cb(skb)->incoming) + continue; + if (hci_skb_pkt_type(skb) != HCI_EVENT_PKT && + hci_skb_pkt_type(skb) != HCI_ACLDATA_PKT && + hci_skb_pkt_type(skb) != HCI_SCODATA_PKT && + hci_skb_pkt_type(skb) != HCI_ISODATA_PKT) + continue; + } else { + /* Don't send frame to other channel types */ + continue; + } + + if (!skb_copy) { + /* Create a private copy with headroom */ + skb_copy = __pskb_copy_fclone(skb, 1, GFP_ATOMIC, true); + if (!skb_copy) + continue; + + /* Put type byte before the data */ + memcpy(skb_push(skb_copy, 1), &hci_skb_pkt_type(skb), 1); + } + + nskb = skb_clone(skb_copy, GFP_ATOMIC); + if (!nskb) + continue; + + if (sock_queue_rcv_skb(sk, nskb)) + kfree_skb(nskb); + } + + read_unlock(&hci_sk_list.lock); + + kfree_skb(skb_copy); +} + +/* Send frame to sockets with specific channel */ +static void __hci_send_to_channel(unsigned short channel, struct sk_buff *skb, + int flag, struct sock *skip_sk) +{ + struct sock *sk; + + BT_DBG("channel %u len %d", channel, skb->len); + + sk_for_each(sk, &hci_sk_list.head) { + struct sk_buff *nskb; + + /* Ignore socket without the flag set */ + if (!hci_sock_test_flag(sk, flag)) + continue; + + /* Skip the original socket */ + if (sk == skip_sk) + continue; + + if (sk->sk_state != BT_BOUND) + continue; + + if (hci_pi(sk)->channel != channel) + continue; + + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + continue; + + if (sock_queue_rcv_skb(sk, nskb)) + kfree_skb(nskb); + } + +} + +void hci_send_to_channel(unsigned short channel, struct sk_buff *skb, + int flag, struct sock *skip_sk) +{ + read_lock(&hci_sk_list.lock); + __hci_send_to_channel(channel, skb, flag, skip_sk); + read_unlock(&hci_sk_list.lock); +} + +/* Send frame to monitor socket */ +void hci_send_to_monitor(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct sk_buff *skb_copy = NULL; + struct hci_mon_hdr *hdr; + __le16 opcode; + + if (!atomic_read(&monitor_promisc)) + return; + + BT_DBG("hdev %p len %d", hdev, skb->len); + + switch (hci_skb_pkt_type(skb)) { + case HCI_COMMAND_PKT: + opcode = cpu_to_le16(HCI_MON_COMMAND_PKT); + break; + case HCI_EVENT_PKT: + opcode = cpu_to_le16(HCI_MON_EVENT_PKT); + break; + case HCI_ACLDATA_PKT: + if (bt_cb(skb)->incoming) + opcode = cpu_to_le16(HCI_MON_ACL_RX_PKT); + else + opcode = cpu_to_le16(HCI_MON_ACL_TX_PKT); + break; + case HCI_SCODATA_PKT: + if (bt_cb(skb)->incoming) + opcode = cpu_to_le16(HCI_MON_SCO_RX_PKT); + else + opcode = cpu_to_le16(HCI_MON_SCO_TX_PKT); + break; + case HCI_ISODATA_PKT: + if (bt_cb(skb)->incoming) + opcode = cpu_to_le16(HCI_MON_ISO_RX_PKT); + else + opcode = 
cpu_to_le16(HCI_MON_ISO_TX_PKT); + break; + case HCI_DIAG_PKT: + opcode = cpu_to_le16(HCI_MON_VENDOR_DIAG); + break; + default: + return; + } + + /* Create a private copy with headroom */ + skb_copy = __pskb_copy_fclone(skb, HCI_MON_HDR_SIZE, GFP_ATOMIC, true); + if (!skb_copy) + return; + + /* Put header before the data */ + hdr = skb_push(skb_copy, HCI_MON_HDR_SIZE); + hdr->opcode = opcode; + hdr->index = cpu_to_le16(hdev->id); + hdr->len = cpu_to_le16(skb->len); + + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb_copy, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb_copy); +} + +void hci_send_monitor_ctrl_event(struct hci_dev *hdev, u16 event, + void *data, u16 data_len, ktime_t tstamp, + int flag, struct sock *skip_sk) +{ + struct sock *sk; + __le16 index; + + if (hdev) + index = cpu_to_le16(hdev->id); + else + index = cpu_to_le16(MGMT_INDEX_NONE); + + read_lock(&hci_sk_list.lock); + + sk_for_each(sk, &hci_sk_list.head) { + struct hci_mon_hdr *hdr; + struct sk_buff *skb; + + if (hci_pi(sk)->channel != HCI_CHANNEL_CONTROL) + continue; + + /* Ignore socket without the flag set */ + if (!hci_sock_test_flag(sk, flag)) + continue; + + /* Skip the original socket */ + if (sk == skip_sk) + continue; + + skb = bt_skb_alloc(6 + data_len, GFP_ATOMIC); + if (!skb) + continue; + + put_unaligned_le32(hci_pi(sk)->cookie, skb_put(skb, 4)); + put_unaligned_le16(event, skb_put(skb, 2)); + + if (data) + skb_put_data(skb, data, data_len); + + skb->tstamp = tstamp; + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_CTRL_EVENT); + hdr->index = index; + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + __hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + + read_unlock(&hci_sk_list.lock); +} + +static struct sk_buff *create_monitor_event(struct hci_dev *hdev, int event) +{ + struct hci_mon_hdr *hdr; + struct hci_mon_new_index *ni; + struct hci_mon_index_info *ii; + struct sk_buff *skb; + __le16 opcode; + + switch (event) { + case HCI_DEV_REG: + skb = bt_skb_alloc(HCI_MON_NEW_INDEX_SIZE, GFP_ATOMIC); + if (!skb) + return NULL; + + ni = skb_put(skb, HCI_MON_NEW_INDEX_SIZE); + ni->type = hdev->dev_type; + ni->bus = hdev->bus; + bacpy(&ni->bdaddr, &hdev->bdaddr); + memcpy_and_pad(ni->name, sizeof(ni->name), hdev->name, + strnlen(hdev->name, sizeof(ni->name)), '\0'); + + opcode = cpu_to_le16(HCI_MON_NEW_INDEX); + break; + + case HCI_DEV_UNREG: + skb = bt_skb_alloc(0, GFP_ATOMIC); + if (!skb) + return NULL; + + opcode = cpu_to_le16(HCI_MON_DEL_INDEX); + break; + + case HCI_DEV_SETUP: + if (hdev->manufacturer == 0xffff) + return NULL; + fallthrough; + + case HCI_DEV_UP: + skb = bt_skb_alloc(HCI_MON_INDEX_INFO_SIZE, GFP_ATOMIC); + if (!skb) + return NULL; + + ii = skb_put(skb, HCI_MON_INDEX_INFO_SIZE); + bacpy(&ii->bdaddr, &hdev->bdaddr); + ii->manufacturer = cpu_to_le16(hdev->manufacturer); + + opcode = cpu_to_le16(HCI_MON_INDEX_INFO); + break; + + case HCI_DEV_OPEN: + skb = bt_skb_alloc(0, GFP_ATOMIC); + if (!skb) + return NULL; + + opcode = cpu_to_le16(HCI_MON_OPEN_INDEX); + break; + + case HCI_DEV_CLOSE: + skb = bt_skb_alloc(0, GFP_ATOMIC); + if (!skb) + return NULL; + + opcode = cpu_to_le16(HCI_MON_CLOSE_INDEX); + break; + + default: + return NULL; + } + + __net_timestamp(skb); + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = opcode; + hdr->index = cpu_to_le16(hdev->id); + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + return skb; +} + +static struct sk_buff *create_monitor_ctrl_open(struct sock *sk) +{ + struct 
hci_mon_hdr *hdr; + struct sk_buff *skb; + u16 format; + u8 ver[3]; + u32 flags; + + /* No message needed when cookie is not present */ + if (!hci_pi(sk)->cookie) + return NULL; + + switch (hci_pi(sk)->channel) { + case HCI_CHANNEL_RAW: + format = 0x0000; + ver[0] = BT_SUBSYS_VERSION; + put_unaligned_le16(BT_SUBSYS_REVISION, ver + 1); + break; + case HCI_CHANNEL_USER: + format = 0x0001; + ver[0] = BT_SUBSYS_VERSION; + put_unaligned_le16(BT_SUBSYS_REVISION, ver + 1); + break; + case HCI_CHANNEL_CONTROL: + format = 0x0002; + mgmt_fill_version_info(ver); + break; + default: + /* No message for unsupported format */ + return NULL; + } + + skb = bt_skb_alloc(14 + TASK_COMM_LEN , GFP_ATOMIC); + if (!skb) + return NULL; + + flags = hci_sock_test_flag(sk, HCI_SOCK_TRUSTED) ? 0x1 : 0x0; + + put_unaligned_le32(hci_pi(sk)->cookie, skb_put(skb, 4)); + put_unaligned_le16(format, skb_put(skb, 2)); + skb_put_data(skb, ver, sizeof(ver)); + put_unaligned_le32(flags, skb_put(skb, 4)); + skb_put_u8(skb, TASK_COMM_LEN); + skb_put_data(skb, hci_pi(sk)->comm, TASK_COMM_LEN); + + __net_timestamp(skb); + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_CTRL_OPEN); + if (hci_pi(sk)->hdev) + hdr->index = cpu_to_le16(hci_pi(sk)->hdev->id); + else + hdr->index = cpu_to_le16(HCI_DEV_NONE); + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + return skb; +} + +static struct sk_buff *create_monitor_ctrl_close(struct sock *sk) +{ + struct hci_mon_hdr *hdr; + struct sk_buff *skb; + + /* No message needed when cookie is not present */ + if (!hci_pi(sk)->cookie) + return NULL; + + switch (hci_pi(sk)->channel) { + case HCI_CHANNEL_RAW: + case HCI_CHANNEL_USER: + case HCI_CHANNEL_CONTROL: + break; + default: + /* No message for unsupported format */ + return NULL; + } + + skb = bt_skb_alloc(4, GFP_ATOMIC); + if (!skb) + return NULL; + + put_unaligned_le32(hci_pi(sk)->cookie, skb_put(skb, 4)); + + __net_timestamp(skb); + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_CTRL_CLOSE); + if (hci_pi(sk)->hdev) + hdr->index = cpu_to_le16(hci_pi(sk)->hdev->id); + else + hdr->index = cpu_to_le16(HCI_DEV_NONE); + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + return skb; +} + +static struct sk_buff *create_monitor_ctrl_command(struct sock *sk, u16 index, + u16 opcode, u16 len, + const void *buf) +{ + struct hci_mon_hdr *hdr; + struct sk_buff *skb; + + skb = bt_skb_alloc(6 + len, GFP_ATOMIC); + if (!skb) + return NULL; + + put_unaligned_le32(hci_pi(sk)->cookie, skb_put(skb, 4)); + put_unaligned_le16(opcode, skb_put(skb, 2)); + + if (buf) + skb_put_data(skb, buf, len); + + __net_timestamp(skb); + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_CTRL_COMMAND); + hdr->index = cpu_to_le16(index); + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + return skb; +} + +static void __printf(2, 3) +send_monitor_note(struct sock *sk, const char *fmt, ...) 
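/* Formats the message, appends a terminating NUL byte and queues it on
 * the given monitor socket as an HCI_MON_SYSTEM_NOTE record with no
 * controller index (HCI_DEV_NONE).
 */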
+{ + size_t len; + struct hci_mon_hdr *hdr; + struct sk_buff *skb; + va_list args; + + va_start(args, fmt); + len = vsnprintf(NULL, 0, fmt, args); + va_end(args); + + skb = bt_skb_alloc(len + 1, GFP_ATOMIC); + if (!skb) + return; + + va_start(args, fmt); + vsprintf(skb_put(skb, len), fmt, args); + *(u8 *)skb_put(skb, 1) = 0; + va_end(args); + + __net_timestamp(skb); + + hdr = (void *)skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_SYSTEM_NOTE); + hdr->index = cpu_to_le16(HCI_DEV_NONE); + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + if (sock_queue_rcv_skb(sk, skb)) + kfree_skb(skb); +} + +static void send_monitor_replay(struct sock *sk) +{ + struct hci_dev *hdev; + + read_lock(&hci_dev_list_lock); + + list_for_each_entry(hdev, &hci_dev_list, list) { + struct sk_buff *skb; + + skb = create_monitor_event(hdev, HCI_DEV_REG); + if (!skb) + continue; + + if (sock_queue_rcv_skb(sk, skb)) + kfree_skb(skb); + + if (!test_bit(HCI_RUNNING, &hdev->flags)) + continue; + + skb = create_monitor_event(hdev, HCI_DEV_OPEN); + if (!skb) + continue; + + if (sock_queue_rcv_skb(sk, skb)) + kfree_skb(skb); + + if (test_bit(HCI_UP, &hdev->flags)) + skb = create_monitor_event(hdev, HCI_DEV_UP); + else if (hci_dev_test_flag(hdev, HCI_SETUP)) + skb = create_monitor_event(hdev, HCI_DEV_SETUP); + else + skb = NULL; + + if (skb) { + if (sock_queue_rcv_skb(sk, skb)) + kfree_skb(skb); + } + } + + read_unlock(&hci_dev_list_lock); +} + +static void send_monitor_control_replay(struct sock *mon_sk) +{ + struct sock *sk; + + read_lock(&hci_sk_list.lock); + + sk_for_each(sk, &hci_sk_list.head) { + struct sk_buff *skb; + + skb = create_monitor_ctrl_open(sk); + if (!skb) + continue; + + if (sock_queue_rcv_skb(mon_sk, skb)) + kfree_skb(skb); + } + + read_unlock(&hci_sk_list.lock); +} + +/* Generate internal stack event */ +static void hci_si_event(struct hci_dev *hdev, int type, int dlen, void *data) +{ + struct hci_event_hdr *hdr; + struct hci_ev_stack_internal *ev; + struct sk_buff *skb; + + skb = bt_skb_alloc(HCI_EVENT_HDR_SIZE + sizeof(*ev) + dlen, GFP_ATOMIC); + if (!skb) + return; + + hdr = skb_put(skb, HCI_EVENT_HDR_SIZE); + hdr->evt = HCI_EV_STACK_INTERNAL; + hdr->plen = sizeof(*ev) + dlen; + + ev = skb_put(skb, sizeof(*ev) + dlen); + ev->type = type; + memcpy(ev->data, data, dlen); + + bt_cb(skb)->incoming = 1; + __net_timestamp(skb); + + hci_skb_pkt_type(skb) = HCI_EVENT_PKT; + hci_send_to_sock(hdev, skb); + kfree_skb(skb); +} + +void hci_sock_dev_event(struct hci_dev *hdev, int event) +{ + BT_DBG("hdev %s event %d", hdev->name, event); + + if (atomic_read(&monitor_promisc)) { + struct sk_buff *skb; + + /* Send event to monitor */ + skb = create_monitor_event(hdev, event); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + } + + if (event <= HCI_DEV_DOWN) { + struct hci_ev_si_device ev; + + /* Send event to sockets */ + ev.event = event; + ev.dev_id = hdev->id; + hci_si_event(NULL, HCI_EV_SI_DEVICE, sizeof(ev), &ev); + } + + if (event == HCI_DEV_UNREG) { + struct sock *sk; + + /* Wake up sockets using this dead device */ + read_lock(&hci_sk_list.lock); + sk_for_each(sk, &hci_sk_list.head) { + if (hci_pi(sk)->hdev == hdev) { + sk->sk_err = EPIPE; + sk->sk_state_change(sk); + } + } + read_unlock(&hci_sk_list.lock); + } +} + +static struct hci_mgmt_chan *__hci_mgmt_chan_find(unsigned short channel) +{ + struct hci_mgmt_chan *c; + + list_for_each_entry(c, &mgmt_chan_list, list) { + if (c->channel == channel) + return c; + } + + 
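+	/* No handler has been registered for this channel */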
return NULL; +} + +static struct hci_mgmt_chan *hci_mgmt_chan_find(unsigned short channel) +{ + struct hci_mgmt_chan *c; + + mutex_lock(&mgmt_chan_list_lock); + c = __hci_mgmt_chan_find(channel); + mutex_unlock(&mgmt_chan_list_lock); + + return c; +} + +int hci_mgmt_chan_register(struct hci_mgmt_chan *c) +{ + if (c->channel < HCI_CHANNEL_CONTROL) + return -EINVAL; + + mutex_lock(&mgmt_chan_list_lock); + if (__hci_mgmt_chan_find(c->channel)) { + mutex_unlock(&mgmt_chan_list_lock); + return -EALREADY; + } + + list_add_tail(&c->list, &mgmt_chan_list); + + mutex_unlock(&mgmt_chan_list_lock); + + return 0; +} +EXPORT_SYMBOL(hci_mgmt_chan_register); + +void hci_mgmt_chan_unregister(struct hci_mgmt_chan *c) +{ + mutex_lock(&mgmt_chan_list_lock); + list_del(&c->list); + mutex_unlock(&mgmt_chan_list_lock); +} +EXPORT_SYMBOL(hci_mgmt_chan_unregister); + +static int hci_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct hci_dev *hdev; + struct sk_buff *skb; + + BT_DBG("sock %p sk %p", sock, sk); + + if (!sk) + return 0; + + lock_sock(sk); + + switch (hci_pi(sk)->channel) { + case HCI_CHANNEL_MONITOR: + atomic_dec(&monitor_promisc); + break; + case HCI_CHANNEL_RAW: + case HCI_CHANNEL_USER: + case HCI_CHANNEL_CONTROL: + /* Send event to monitor */ + skb = create_monitor_ctrl_close(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + + hci_sock_free_cookie(sk); + break; + } + + bt_sock_unlink(&hci_sk_list, sk); + + hdev = hci_pi(sk)->hdev; + if (hdev) { + if (hci_pi(sk)->channel == HCI_CHANNEL_USER && + !hci_dev_test_flag(hdev, HCI_UNREGISTER)) { + /* When releasing a user channel exclusive access, + * call hci_dev_do_close directly instead of calling + * hci_dev_close to ensure the exclusive access will + * be released and the controller brought back down. + * + * The checking of HCI_AUTO_OFF is not needed in this + * case since it will have been cleared already when + * opening the user channel. + * + * Make sure to also check that we haven't already + * unregistered since all the cleanup will have already + * been complete and hdev will get released when we put + * below. 
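+ *
+ * The sequence below therefore is: close the device, clear the
+ * HCI_USER_CHANNEL flag and announce the index to management
+ * sockets again via mgmt_index_added().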
+ */ + hci_dev_do_close(hdev); + hci_dev_clear_flag(hdev, HCI_USER_CHANNEL); + mgmt_index_added(hdev); + } + + atomic_dec(&hdev->promisc); + hci_dev_put(hdev); + } + + sock_orphan(sk); + release_sock(sk); + sock_put(sk); + return 0; +} + +static int hci_sock_reject_list_add(struct hci_dev *hdev, void __user *arg) +{ + bdaddr_t bdaddr; + int err; + + if (copy_from_user(&bdaddr, arg, sizeof(bdaddr))) + return -EFAULT; + + hci_dev_lock(hdev); + + err = hci_bdaddr_list_add(&hdev->reject_list, &bdaddr, BDADDR_BREDR); + + hci_dev_unlock(hdev); + + return err; +} + +static int hci_sock_reject_list_del(struct hci_dev *hdev, void __user *arg) +{ + bdaddr_t bdaddr; + int err; + + if (copy_from_user(&bdaddr, arg, sizeof(bdaddr))) + return -EFAULT; + + hci_dev_lock(hdev); + + err = hci_bdaddr_list_del(&hdev->reject_list, &bdaddr, BDADDR_BREDR); + + hci_dev_unlock(hdev); + + return err; +} + +/* Ioctls that require bound socket */ +static int hci_sock_bound_ioctl(struct sock *sk, unsigned int cmd, + unsigned long arg) +{ + struct hci_dev *hdev = hci_hdev_from_sock(sk); + + if (IS_ERR(hdev)) + return PTR_ERR(hdev); + + if (hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) + return -EBUSY; + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + return -EOPNOTSUPP; + + if (hdev->dev_type != HCI_PRIMARY) + return -EOPNOTSUPP; + + switch (cmd) { + case HCISETRAW: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return -EOPNOTSUPP; + + case HCIGETCONNINFO: + return hci_get_conn_info(hdev, (void __user *)arg); + + case HCIGETAUTHINFO: + return hci_get_auth_info(hdev, (void __user *)arg); + + case HCIBLOCKADDR: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_sock_reject_list_add(hdev, (void __user *)arg); + + case HCIUNBLOCKADDR: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_sock_reject_list_del(hdev, (void __user *)arg); + } + + return -ENOIOCTLCMD; +} + +static int hci_sock_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + void __user *argp = (void __user *)arg; + struct sock *sk = sock->sk; + int err; + + BT_DBG("cmd %x arg %lx", cmd, arg); + + /* Make sure the cmd is valid before doing anything */ + switch (cmd) { + case HCIGETDEVLIST: + case HCIGETDEVINFO: + case HCIGETCONNLIST: + case HCIDEVUP: + case HCIDEVDOWN: + case HCIDEVRESET: + case HCIDEVRESTAT: + case HCISETSCAN: + case HCISETAUTH: + case HCISETENCRYPT: + case HCISETPTYPE: + case HCISETLINKPOL: + case HCISETLINKMODE: + case HCISETACLMTU: + case HCISETSCOMTU: + case HCIINQUIRY: + case HCISETRAW: + case HCIGETCONNINFO: + case HCIGETAUTHINFO: + case HCIBLOCKADDR: + case HCIUNBLOCKADDR: + break; + default: + return -ENOIOCTLCMD; + } + + lock_sock(sk); + + if (hci_pi(sk)->channel != HCI_CHANNEL_RAW) { + err = -EBADFD; + goto done; + } + + /* When calling an ioctl on an unbound raw socket, then ensure + * that the monitor gets informed. Ensure that the resulting event + * is only send once by checking if the cookie exists or not. The + * socket cookie will be only ever generated once for the lifetime + * of a given socket. + */ + if (hci_sock_gen_cookie(sk)) { + struct sk_buff *skb; + + /* Perform careful checks before setting the HCI_SOCK_TRUSTED + * flag. Make sure that not only the current task but also + * the socket opener has the required capability, since + * privileged programs can be tricked into making ioctl calls + * on HCI sockets, and the socket should not be marked as + * trusted simply because the ioctl caller is privileged. 
+ */ + if (sk_capable(sk, CAP_NET_ADMIN)) + hci_sock_set_flag(sk, HCI_SOCK_TRUSTED); + + /* Send event to monitor */ + skb = create_monitor_ctrl_open(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + } + + release_sock(sk); + + switch (cmd) { + case HCIGETDEVLIST: + return hci_get_dev_list(argp); + + case HCIGETDEVINFO: + return hci_get_dev_info(argp); + + case HCIGETCONNLIST: + return hci_get_conn_list(argp); + + case HCIDEVUP: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_dev_open(arg); + + case HCIDEVDOWN: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_dev_close(arg); + + case HCIDEVRESET: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_dev_reset(arg); + + case HCIDEVRESTAT: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_dev_reset_stat(arg); + + case HCISETSCAN: + case HCISETAUTH: + case HCISETENCRYPT: + case HCISETPTYPE: + case HCISETLINKPOL: + case HCISETLINKMODE: + case HCISETACLMTU: + case HCISETSCOMTU: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return hci_dev_cmd(cmd, argp); + + case HCIINQUIRY: + return hci_inquiry(argp); + } + + lock_sock(sk); + + err = hci_sock_bound_ioctl(sk, cmd, arg); + +done: + release_sock(sk); + return err; +} + +#ifdef CONFIG_COMPAT +static int hci_sock_compat_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + switch (cmd) { + case HCIDEVUP: + case HCIDEVDOWN: + case HCIDEVRESET: + case HCIDEVRESTAT: + return hci_sock_ioctl(sock, cmd, arg); + } + + return hci_sock_ioctl(sock, cmd, (unsigned long)compat_ptr(arg)); +} +#endif + +static int hci_sock_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sockaddr_hci haddr; + struct sock *sk = sock->sk; + struct hci_dev *hdev = NULL; + struct sk_buff *skb; + int len, err = 0; + + BT_DBG("sock %p sk %p", sock, sk); + + if (!addr) + return -EINVAL; + + memset(&haddr, 0, sizeof(haddr)); + len = min_t(unsigned int, sizeof(haddr), addr_len); + memcpy(&haddr, addr, len); + + if (haddr.hci_family != AF_BLUETOOTH) + return -EINVAL; + + lock_sock(sk); + + /* Allow detaching from dead device and attaching to alive device, if + * the caller wants to re-bind (instead of close) this socket in + * response to hci_sock_dev_event(HCI_DEV_UNREG) notification. + */ + hdev = hci_pi(sk)->hdev; + if (hdev && hci_dev_test_flag(hdev, HCI_UNREGISTER)) { + hci_pi(sk)->hdev = NULL; + sk->sk_state = BT_OPEN; + hci_dev_put(hdev); + } + hdev = NULL; + + if (sk->sk_state == BT_BOUND) { + err = -EALREADY; + goto done; + } + + switch (haddr.hci_channel) { + case HCI_CHANNEL_RAW: + if (hci_pi(sk)->hdev) { + err = -EALREADY; + goto done; + } + + if (haddr.hci_dev != HCI_DEV_NONE) { + hdev = hci_dev_get(haddr.hci_dev); + if (!hdev) { + err = -ENODEV; + goto done; + } + + atomic_inc(&hdev->promisc); + } + + hci_pi(sk)->channel = haddr.hci_channel; + + if (!hci_sock_gen_cookie(sk)) { + /* In the case when a cookie has already been assigned, + * then there has been already an ioctl issued against + * an unbound socket and with that triggered an open + * notification. Send a close notification first to + * allow the state transition to bounded. 
+ */ + skb = create_monitor_ctrl_close(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + } + + if (capable(CAP_NET_ADMIN)) + hci_sock_set_flag(sk, HCI_SOCK_TRUSTED); + + hci_pi(sk)->hdev = hdev; + + /* Send event to monitor */ + skb = create_monitor_ctrl_open(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + break; + + case HCI_CHANNEL_USER: + if (hci_pi(sk)->hdev) { + err = -EALREADY; + goto done; + } + + if (haddr.hci_dev == HCI_DEV_NONE) { + err = -EINVAL; + goto done; + } + + if (!capable(CAP_NET_ADMIN)) { + err = -EPERM; + goto done; + } + + hdev = hci_dev_get(haddr.hci_dev); + if (!hdev) { + err = -ENODEV; + goto done; + } + + if (test_bit(HCI_INIT, &hdev->flags) || + hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG) || + (!hci_dev_test_flag(hdev, HCI_AUTO_OFF) && + test_bit(HCI_UP, &hdev->flags))) { + err = -EBUSY; + hci_dev_put(hdev); + goto done; + } + + if (hci_dev_test_and_set_flag(hdev, HCI_USER_CHANNEL)) { + err = -EUSERS; + hci_dev_put(hdev); + goto done; + } + + mgmt_index_removed(hdev); + + err = hci_dev_open(hdev->id); + if (err) { + if (err == -EALREADY) { + /* In case the transport is already up and + * running, clear the error here. + * + * This can happen when opening a user + * channel and HCI_AUTO_OFF grace period + * is still active. + */ + err = 0; + } else { + hci_dev_clear_flag(hdev, HCI_USER_CHANNEL); + mgmt_index_added(hdev); + hci_dev_put(hdev); + goto done; + } + } + + hci_pi(sk)->channel = haddr.hci_channel; + + if (!hci_sock_gen_cookie(sk)) { + /* In the case when a cookie has already been assigned, + * this socket will transition from a raw socket into + * a user channel socket. For a clean transition, send + * the close notification first. + */ + skb = create_monitor_ctrl_close(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + } + + /* The user channel is restricted to CAP_NET_ADMIN + * capabilities and with that implicitly trusted. + */ + hci_sock_set_flag(sk, HCI_SOCK_TRUSTED); + + hci_pi(sk)->hdev = hdev; + + /* Send event to monitor */ + skb = create_monitor_ctrl_open(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + + atomic_inc(&hdev->promisc); + break; + + case HCI_CHANNEL_MONITOR: + if (haddr.hci_dev != HCI_DEV_NONE) { + err = -EINVAL; + goto done; + } + + if (!capable(CAP_NET_RAW)) { + err = -EPERM; + goto done; + } + + hci_pi(sk)->channel = haddr.hci_channel; + + /* The monitor interface is restricted to CAP_NET_RAW + * capabilities and with that implicitly trusted. 
+ */ + hci_sock_set_flag(sk, HCI_SOCK_TRUSTED); + + send_monitor_note(sk, "Linux version %s (%s)", + init_utsname()->release, + init_utsname()->machine); + send_monitor_note(sk, "Bluetooth subsystem version %u.%u", + BT_SUBSYS_VERSION, BT_SUBSYS_REVISION); + send_monitor_replay(sk); + send_monitor_control_replay(sk); + + atomic_inc(&monitor_promisc); + break; + + case HCI_CHANNEL_LOGGING: + if (haddr.hci_dev != HCI_DEV_NONE) { + err = -EINVAL; + goto done; + } + + if (!capable(CAP_NET_ADMIN)) { + err = -EPERM; + goto done; + } + + hci_pi(sk)->channel = haddr.hci_channel; + break; + + default: + if (!hci_mgmt_chan_find(haddr.hci_channel)) { + err = -EINVAL; + goto done; + } + + if (haddr.hci_dev != HCI_DEV_NONE) { + err = -EINVAL; + goto done; + } + + /* Users with CAP_NET_ADMIN capabilities are allowed + * access to all management commands and events. For + * untrusted users the interface is restricted and + * also only untrusted events are sent. + */ + if (capable(CAP_NET_ADMIN)) + hci_sock_set_flag(sk, HCI_SOCK_TRUSTED); + + hci_pi(sk)->channel = haddr.hci_channel; + + /* At the moment the index and unconfigured index events + * are enabled unconditionally. Setting them on each + * socket when binding keeps this functionality. They + * however might be cleared later and then sending of these + * events will be disabled, but that is then intentional. + * + * This also enables generic events that are safe to be + * received by untrusted users. Example for such events + * are changes to settings, class of device, name etc. + */ + if (hci_pi(sk)->channel == HCI_CHANNEL_CONTROL) { + if (!hci_sock_gen_cookie(sk)) { + /* In the case when a cookie has already been + * assigned, this socket will transition from + * a raw socket into a control socket. To + * allow for a clean transition, send the + * close notification first. 
+ */ + skb = create_monitor_ctrl_close(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + } + + /* Send event to monitor */ + skb = create_monitor_ctrl_open(sk); + if (skb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(skb); + } + + hci_sock_set_flag(sk, HCI_MGMT_INDEX_EVENTS); + hci_sock_set_flag(sk, HCI_MGMT_UNCONF_INDEX_EVENTS); + hci_sock_set_flag(sk, HCI_MGMT_OPTION_EVENTS); + hci_sock_set_flag(sk, HCI_MGMT_SETTING_EVENTS); + hci_sock_set_flag(sk, HCI_MGMT_DEV_CLASS_EVENTS); + hci_sock_set_flag(sk, HCI_MGMT_LOCAL_NAME_EVENTS); + } + break; + } + + /* Default MTU to HCI_MAX_FRAME_SIZE if not set */ + if (!hci_pi(sk)->mtu) + hci_pi(sk)->mtu = HCI_MAX_FRAME_SIZE; + + sk->sk_state = BT_BOUND; + +done: + release_sock(sk); + return err; +} + +static int hci_sock_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sockaddr_hci *haddr = (struct sockaddr_hci *)addr; + struct sock *sk = sock->sk; + struct hci_dev *hdev; + int err = 0; + + BT_DBG("sock %p sk %p", sock, sk); + + if (peer) + return -EOPNOTSUPP; + + lock_sock(sk); + + hdev = hci_hdev_from_sock(sk); + if (IS_ERR(hdev)) { + err = PTR_ERR(hdev); + goto done; + } + + haddr->hci_family = AF_BLUETOOTH; + haddr->hci_dev = hdev->id; + haddr->hci_channel= hci_pi(sk)->channel; + err = sizeof(*haddr); + +done: + release_sock(sk); + return err; +} + +static void hci_sock_cmsg(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ + __u8 mask = hci_pi(sk)->cmsg_mask; + + if (mask & HCI_CMSG_DIR) { + int incoming = bt_cb(skb)->incoming; + put_cmsg(msg, SOL_HCI, HCI_CMSG_DIR, sizeof(incoming), + &incoming); + } + + if (mask & HCI_CMSG_TSTAMP) { +#ifdef CONFIG_COMPAT + struct old_timeval32 ctv; +#endif + struct __kernel_old_timeval tv; + void *data; + int len; + + skb_get_timestamp(skb, &tv); + + data = &tv; + len = sizeof(tv); +#ifdef CONFIG_COMPAT + if (!COMPAT_USE_64BIT_TIME && + (msg->msg_flags & MSG_CMSG_COMPAT)) { + ctv.tv_sec = tv.tv_sec; + ctv.tv_usec = tv.tv_usec; + data = &ctv; + len = sizeof(ctv); + } +#endif + + put_cmsg(msg, SOL_HCI, HCI_CMSG_TSTAMP, len, data); + } +} + +static int hci_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int copied, err; + unsigned int skblen; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + if (hci_pi(sk)->channel == HCI_CHANNEL_LOGGING) + return -EOPNOTSUPP; + + if (sk->sk_state == BT_CLOSED) + return 0; + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + return err; + + skblen = skb->len; + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + skb_reset_transport_header(skb); + err = skb_copy_datagram_msg(skb, 0, msg, copied); + + switch (hci_pi(sk)->channel) { + case HCI_CHANNEL_RAW: + hci_sock_cmsg(sk, msg, skb); + break; + case HCI_CHANNEL_USER: + case HCI_CHANNEL_MONITOR: + sock_recv_timestamp(msg, sk, skb); + break; + default: + if (hci_mgmt_chan_find(hci_pi(sk)->channel)) + sock_recv_timestamp(msg, sk, skb); + break; + } + + skb_free_datagram(sk, skb); + + if (flags & MSG_TRUNC) + copied = skblen; + + return err ? 
: copied; +} + +static int hci_mgmt_cmd(struct hci_mgmt_chan *chan, struct sock *sk, + struct sk_buff *skb) +{ + u8 *cp; + struct mgmt_hdr *hdr; + u16 opcode, index, len; + struct hci_dev *hdev = NULL; + const struct hci_mgmt_handler *handler; + bool var_len, no_hdev; + int err; + + BT_DBG("got %d bytes", skb->len); + + if (skb->len < sizeof(*hdr)) + return -EINVAL; + + hdr = (void *)skb->data; + opcode = __le16_to_cpu(hdr->opcode); + index = __le16_to_cpu(hdr->index); + len = __le16_to_cpu(hdr->len); + + if (len != skb->len - sizeof(*hdr)) { + err = -EINVAL; + goto done; + } + + if (chan->channel == HCI_CHANNEL_CONTROL) { + struct sk_buff *cmd; + + /* Send event to monitor */ + cmd = create_monitor_ctrl_command(sk, index, opcode, len, + skb->data + sizeof(*hdr)); + if (cmd) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, cmd, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(cmd); + } + } + + if (opcode >= chan->handler_count || + chan->handlers[opcode].func == NULL) { + BT_DBG("Unknown op %u", opcode); + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_UNKNOWN_COMMAND); + goto done; + } + + handler = &chan->handlers[opcode]; + + if (!hci_sock_test_flag(sk, HCI_SOCK_TRUSTED) && + !(handler->flags & HCI_MGMT_UNTRUSTED)) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_PERMISSION_DENIED); + goto done; + } + + if (index != MGMT_INDEX_NONE) { + hdev = hci_dev_get(index); + if (!hdev) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_INVALID_INDEX); + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG) || + hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_INVALID_INDEX); + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED) && + !(handler->flags & HCI_MGMT_UNCONFIGURED)) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_INVALID_INDEX); + goto done; + } + } + + if (!(handler->flags & HCI_MGMT_HDEV_OPTIONAL)) { + no_hdev = (handler->flags & HCI_MGMT_NO_HDEV); + if (no_hdev != !hdev) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_INVALID_INDEX); + goto done; + } + } + + var_len = (handler->flags & HCI_MGMT_VAR_LEN); + if ((var_len && len < handler->data_len) || + (!var_len && len != handler->data_len)) { + err = mgmt_cmd_status(sk, index, opcode, + MGMT_STATUS_INVALID_PARAMS); + goto done; + } + + if (hdev && chan->hdev_init) + chan->hdev_init(sk, hdev); + + cp = skb->data + sizeof(*hdr); + + err = handler->func(sk, hdev, cp, len); + if (err < 0) + goto done; + + err = skb->len; + +done: + if (hdev) + hci_dev_put(hdev); + + return err; +} + +static int hci_logging_frame(struct sock *sk, struct sk_buff *skb, + unsigned int flags) +{ + struct hci_mon_hdr *hdr; + struct hci_dev *hdev; + u16 index; + int err; + + /* The logging frame consists at minimum of the standard header, + * the priority byte, the ident length byte and at least one string + * terminator NUL byte. Anything shorter are invalid packets. + */ + if (skb->len < sizeof(*hdr) + 3) + return -EINVAL; + + hdr = (void *)skb->data; + + if (__le16_to_cpu(hdr->len) != skb->len - sizeof(*hdr)) + return -EINVAL; + + if (__le16_to_cpu(hdr->opcode) == 0x0000) { + __u8 priority = skb->data[sizeof(*hdr)]; + __u8 ident_len = skb->data[sizeof(*hdr) + 1]; + + /* Only the priorities 0-7 are valid and with that any other + * value results in an invalid packet. + * + * The priority byte is followed by an ident length byte and + * the NUL terminated ident string. 
Check that the ident + * length is not overflowing the packet and also that the + * ident string itself is NUL terminated. In case the ident + * length is zero, the length value actually doubles as NUL + * terminator identifier. + * + * The message follows the ident string (if present) and + * must be NUL terminated. Otherwise it is not a valid packet. + */ + if (priority > 7 || skb->data[skb->len - 1] != 0x00 || + ident_len > skb->len - sizeof(*hdr) - 3 || + skb->data[sizeof(*hdr) + ident_len + 1] != 0x00) + return -EINVAL; + } else { + return -EINVAL; + } + + index = __le16_to_cpu(hdr->index); + + if (index != MGMT_INDEX_NONE) { + hdev = hci_dev_get(index); + if (!hdev) + return -ENODEV; + } else { + hdev = NULL; + } + + hdr->opcode = cpu_to_le16(HCI_MON_USER_LOGGING); + + hci_send_to_channel(HCI_CHANNEL_MONITOR, skb, HCI_SOCK_TRUSTED, NULL); + err = skb->len; + + if (hdev) + hci_dev_put(hdev); + + return err; +} + +static int hci_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct hci_mgmt_chan *chan; + struct hci_dev *hdev; + struct sk_buff *skb; + int err; + const unsigned int flags = msg->msg_flags; + + BT_DBG("sock %p sk %p", sock, sk); + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + if (flags & ~(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_ERRQUEUE | MSG_CMSG_COMPAT)) + return -EINVAL; + + if (len < 4 || len > hci_pi(sk)->mtu) + return -EINVAL; + + skb = bt_skb_sendmsg(sk, msg, len, len, 0, 0); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + lock_sock(sk); + + switch (hci_pi(sk)->channel) { + case HCI_CHANNEL_RAW: + case HCI_CHANNEL_USER: + break; + case HCI_CHANNEL_MONITOR: + err = -EOPNOTSUPP; + goto drop; + case HCI_CHANNEL_LOGGING: + err = hci_logging_frame(sk, skb, flags); + goto drop; + default: + mutex_lock(&mgmt_chan_list_lock); + chan = __hci_mgmt_chan_find(hci_pi(sk)->channel); + if (chan) + err = hci_mgmt_cmd(chan, sk, skb); + else + err = -EINVAL; + + mutex_unlock(&mgmt_chan_list_lock); + goto drop; + } + + hdev = hci_hdev_from_sock(sk); + if (IS_ERR(hdev)) { + err = PTR_ERR(hdev); + goto drop; + } + + if (!test_bit(HCI_UP, &hdev->flags)) { + err = -ENETDOWN; + goto drop; + } + + hci_skb_pkt_type(skb) = skb->data[0]; + skb_pull(skb, 1); + + if (hci_pi(sk)->channel == HCI_CHANNEL_USER) { + /* No permission check is needed for user channel + * since that gets enforced when binding the socket. + * + * However check that the packet type is valid. + */ + if (hci_skb_pkt_type(skb) != HCI_COMMAND_PKT && + hci_skb_pkt_type(skb) != HCI_ACLDATA_PKT && + hci_skb_pkt_type(skb) != HCI_SCODATA_PKT && + hci_skb_pkt_type(skb) != HCI_ISODATA_PKT) { + err = -EINVAL; + goto drop; + } + + skb_queue_tail(&hdev->raw_q, skb); + queue_work(hdev->workqueue, &hdev->tx_work); + } else if (hci_skb_pkt_type(skb) == HCI_COMMAND_PKT) { + u16 opcode = get_unaligned_le16(skb->data); + u16 ogf = hci_opcode_ogf(opcode); + u16 ocf = hci_opcode_ocf(opcode); + + if (((ogf > HCI_SFLT_MAX_OGF) || + !hci_test_bit(ocf & HCI_FLT_OCF_BITS, + &hci_sec_filter.ocf_mask[ogf])) && + !capable(CAP_NET_RAW)) { + err = -EPERM; + goto drop; + } + + /* Since the opcode has already been extracted here, store + * a copy of the value for later use by the drivers. + */ + hci_skb_opcode(skb) = opcode; + + if (ogf == 0x3f) { + skb_queue_tail(&hdev->raw_q, skb); + queue_work(hdev->workqueue, &hdev->tx_work); + } else { + /* Stand-alone HCI commands must be flagged as + * single-command requests. 
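+ *
+ * HCI_REQ_START marks the first command of a request; a
+ * one-off command from user space forms a request of its
+ * own, so the flag is always set here.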
+ */ + bt_cb(skb)->hci.req_flags |= HCI_REQ_START; + + skb_queue_tail(&hdev->cmd_q, skb); + queue_work(hdev->workqueue, &hdev->cmd_work); + } + } else { + if (!capable(CAP_NET_RAW)) { + err = -EPERM; + goto drop; + } + + if (hci_skb_pkt_type(skb) != HCI_ACLDATA_PKT && + hci_skb_pkt_type(skb) != HCI_SCODATA_PKT && + hci_skb_pkt_type(skb) != HCI_ISODATA_PKT) { + err = -EINVAL; + goto drop; + } + + skb_queue_tail(&hdev->raw_q, skb); + queue_work(hdev->workqueue, &hdev->tx_work); + } + + err = len; + +done: + release_sock(sk); + return err; + +drop: + kfree_skb(skb); + goto done; +} + +static int hci_sock_setsockopt_old(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int len) +{ + struct hci_ufilter uf = { .opcode = 0 }; + struct sock *sk = sock->sk; + int err = 0, opt = 0; + + BT_DBG("sk %p, opt %d", sk, optname); + + lock_sock(sk); + + if (hci_pi(sk)->channel != HCI_CHANNEL_RAW) { + err = -EBADFD; + goto done; + } + + switch (optname) { + case HCI_DATA_DIR: + if (copy_from_sockptr(&opt, optval, sizeof(opt))) { + err = -EFAULT; + break; + } + + if (opt) + hci_pi(sk)->cmsg_mask |= HCI_CMSG_DIR; + else + hci_pi(sk)->cmsg_mask &= ~HCI_CMSG_DIR; + break; + + case HCI_TIME_STAMP: + if (copy_from_sockptr(&opt, optval, sizeof(opt))) { + err = -EFAULT; + break; + } + + if (opt) + hci_pi(sk)->cmsg_mask |= HCI_CMSG_TSTAMP; + else + hci_pi(sk)->cmsg_mask &= ~HCI_CMSG_TSTAMP; + break; + + case HCI_FILTER: + { + struct hci_filter *f = &hci_pi(sk)->filter; + + uf.type_mask = f->type_mask; + uf.opcode = f->opcode; + uf.event_mask[0] = *((u32 *) f->event_mask + 0); + uf.event_mask[1] = *((u32 *) f->event_mask + 1); + } + + len = min_t(unsigned int, len, sizeof(uf)); + if (copy_from_sockptr(&uf, optval, len)) { + err = -EFAULT; + break; + } + + if (!capable(CAP_NET_RAW)) { + uf.type_mask &= hci_sec_filter.type_mask; + uf.event_mask[0] &= *((u32 *) hci_sec_filter.event_mask + 0); + uf.event_mask[1] &= *((u32 *) hci_sec_filter.event_mask + 1); + } + + { + struct hci_filter *f = &hci_pi(sk)->filter; + + f->type_mask = uf.type_mask; + f->opcode = uf.opcode; + *((u32 *) f->event_mask + 0) = uf.event_mask[0]; + *((u32 *) f->event_mask + 1) = uf.event_mask[1]; + } + break; + + default: + err = -ENOPROTOOPT; + break; + } + +done: + release_sock(sk); + return err; +} + +static int hci_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int len) +{ + struct sock *sk = sock->sk; + int err = 0; + u16 opt; + + BT_DBG("sk %p, opt %d", sk, optname); + + if (level == SOL_HCI) + return hci_sock_setsockopt_old(sock, level, optname, optval, + len); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + lock_sock(sk); + + switch (optname) { + case BT_SNDMTU: + case BT_RCVMTU: + switch (hci_pi(sk)->channel) { + /* Don't allow changing MTU for channels that are meant for HCI + * traffic only. 
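+ *
+ * Raw and user channel sockets keep the HCI_MAX_FRAME_SIZE
+ * default that was applied at bind time.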
+ */ + case HCI_CHANNEL_RAW: + case HCI_CHANNEL_USER: + err = -ENOPROTOOPT; + goto done; + } + + if (copy_from_sockptr(&opt, optval, sizeof(opt))) { + err = -EFAULT; + break; + } + + hci_pi(sk)->mtu = opt; + break; + + default: + err = -ENOPROTOOPT; + break; + } + +done: + release_sock(sk); + return err; +} + +static int hci_sock_getsockopt_old(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct hci_ufilter uf; + struct sock *sk = sock->sk; + int len, opt, err = 0; + + BT_DBG("sk %p, opt %d", sk, optname); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + if (hci_pi(sk)->channel != HCI_CHANNEL_RAW) { + err = -EBADFD; + goto done; + } + + switch (optname) { + case HCI_DATA_DIR: + if (hci_pi(sk)->cmsg_mask & HCI_CMSG_DIR) + opt = 1; + else + opt = 0; + + if (put_user(opt, optval)) + err = -EFAULT; + break; + + case HCI_TIME_STAMP: + if (hci_pi(sk)->cmsg_mask & HCI_CMSG_TSTAMP) + opt = 1; + else + opt = 0; + + if (put_user(opt, optval)) + err = -EFAULT; + break; + + case HCI_FILTER: + { + struct hci_filter *f = &hci_pi(sk)->filter; + + memset(&uf, 0, sizeof(uf)); + uf.type_mask = f->type_mask; + uf.opcode = f->opcode; + uf.event_mask[0] = *((u32 *) f->event_mask + 0); + uf.event_mask[1] = *((u32 *) f->event_mask + 1); + } + + len = min_t(unsigned int, len, sizeof(uf)); + if (copy_to_user(optval, &uf, len)) + err = -EFAULT; + break; + + default: + err = -ENOPROTOOPT; + break; + } + +done: + release_sock(sk); + return err; +} + +static int hci_sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sk %p, opt %d", sk, optname); + + if (level == SOL_HCI) + return hci_sock_getsockopt_old(sock, level, optname, optval, + optlen); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + lock_sock(sk); + + switch (optname) { + case BT_SNDMTU: + case BT_RCVMTU: + if (put_user(hci_pi(sk)->mtu, (u16 __user *)optval)) + err = -EFAULT; + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static void hci_sock_destruct(struct sock *sk) +{ + mgmt_cleanup(sk); + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); +} + +static const struct proto_ops hci_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = hci_sock_release, + .bind = hci_sock_bind, + .getname = hci_sock_getname, + .sendmsg = hci_sock_sendmsg, + .recvmsg = hci_sock_recvmsg, + .ioctl = hci_sock_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = hci_sock_compat_ioctl, +#endif + .poll = datagram_poll, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = hci_sock_setsockopt, + .getsockopt = hci_sock_getsockopt, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .mmap = sock_no_mmap +}; + +static struct proto hci_sk_proto = { + .name = "HCI", + .owner = THIS_MODULE, + .obj_size = sizeof(struct hci_pinfo) +}; + +static int hci_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + sock->ops = &hci_sock_ops; + + sk = sk_alloc(net, PF_BLUETOOTH, GFP_ATOMIC, &hci_sk_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = protocol; + + sock->state = SS_UNCONNECTED; + sk->sk_state = BT_OPEN; + sk->sk_destruct = 
hci_sock_destruct; + + bt_sock_link(&hci_sk_list, sk); + return 0; +} + +static const struct net_proto_family hci_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = hci_sock_create, +}; + +int __init hci_sock_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct sockaddr_hci) > sizeof(struct sockaddr)); + + err = proto_register(&hci_sk_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_HCI, &hci_sock_family_ops); + if (err < 0) { + BT_ERR("HCI socket registration failed"); + goto error; + } + + err = bt_procfs_init(&init_net, "hci", &hci_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create HCI proc file"); + bt_sock_unregister(BTPROTO_HCI); + goto error; + } + + BT_INFO("HCI socket layer initialized"); + + return 0; + +error: + proto_unregister(&hci_sk_proto); + return err; +} + +void hci_sock_cleanup(void) +{ + bt_procfs_cleanup(&init_net, "hci"); + bt_sock_unregister(BTPROTO_HCI); + proto_unregister(&hci_sk_proto); +} diff --git a/net/bluetooth/hci_sync.c b/net/bluetooth/hci_sync.c new file mode 100644 index 000000000..d74fe13f3 --- /dev/null +++ b/net/bluetooth/hci_sync.c @@ -0,0 +1,6316 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * BlueZ - Bluetooth protocol stack for Linux + * + * Copyright (C) 2021 Intel Corporation + */ + +#include + +#include +#include +#include + +#include "hci_request.h" +#include "hci_codec.h" +#include "hci_debugfs.h" +#include "smp.h" +#include "eir.h" +#include "msft.h" +#include "aosp.h" +#include "leds.h" + +static void hci_cmd_sync_complete(struct hci_dev *hdev, u8 result, u16 opcode, + struct sk_buff *skb) +{ + bt_dev_dbg(hdev, "result 0x%2.2x", result); + + if (hdev->req_status != HCI_REQ_PEND) + return; + + hdev->req_result = result; + hdev->req_status = HCI_REQ_DONE; + + if (skb) { + struct sock *sk = hci_skb_sk(skb); + + /* Drop sk reference if set */ + if (sk) + sock_put(sk); + + hdev->req_skb = skb_get(skb); + } + + wake_up_interruptible(&hdev->req_wait_q); +} + +static struct sk_buff *hci_cmd_sync_alloc(struct hci_dev *hdev, u16 opcode, + u32 plen, const void *param, + struct sock *sk) +{ + int len = HCI_COMMAND_HDR_SIZE + plen; + struct hci_command_hdr *hdr; + struct sk_buff *skb; + + skb = bt_skb_alloc(len, GFP_ATOMIC); + if (!skb) + return NULL; + + hdr = skb_put(skb, HCI_COMMAND_HDR_SIZE); + hdr->opcode = cpu_to_le16(opcode); + hdr->plen = plen; + + if (plen) + skb_put_data(skb, param, plen); + + bt_dev_dbg(hdev, "skb len %d", skb->len); + + hci_skb_pkt_type(skb) = HCI_COMMAND_PKT; + hci_skb_opcode(skb) = opcode; + + /* Grab a reference if command needs to be associated with a sock (e.g. + * likely mgmt socket that initiated the command). + */ + if (sk) { + hci_skb_sk(skb) = sk; + sock_hold(sk); + } + + return skb; +} + +static void hci_cmd_sync_add(struct hci_request *req, u16 opcode, u32 plen, + const void *param, u8 event, struct sock *sk) +{ + struct hci_dev *hdev = req->hdev; + struct sk_buff *skb; + + bt_dev_dbg(hdev, "opcode 0x%4.4x plen %d", opcode, plen); + + /* If an error occurred during request building, there is no point in + * queueing the HCI command. We can simply return. 
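+ *
+ * The error is reported later by hci_cmd_sync_run(), which
+ * purges the queue and hands req->err back to the caller.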
+ */ + if (req->err) + return; + + skb = hci_cmd_sync_alloc(hdev, opcode, plen, param, sk); + if (!skb) { + bt_dev_err(hdev, "no memory for command (opcode 0x%4.4x)", + opcode); + req->err = -ENOMEM; + return; + } + + if (skb_queue_empty(&req->cmd_q)) + bt_cb(skb)->hci.req_flags |= HCI_REQ_START; + + hci_skb_event(skb) = event; + + skb_queue_tail(&req->cmd_q, skb); +} + +static int hci_cmd_sync_run(struct hci_request *req) +{ + struct hci_dev *hdev = req->hdev; + struct sk_buff *skb; + unsigned long flags; + + bt_dev_dbg(hdev, "length %u", skb_queue_len(&req->cmd_q)); + + /* If an error occurred during request building, remove all HCI + * commands queued on the HCI request queue. + */ + if (req->err) { + skb_queue_purge(&req->cmd_q); + return req->err; + } + + /* Do not allow empty requests */ + if (skb_queue_empty(&req->cmd_q)) + return -ENODATA; + + skb = skb_peek_tail(&req->cmd_q); + bt_cb(skb)->hci.req_complete_skb = hci_cmd_sync_complete; + bt_cb(skb)->hci.req_flags |= HCI_REQ_SKB; + + spin_lock_irqsave(&hdev->cmd_q.lock, flags); + skb_queue_splice_tail(&req->cmd_q, &hdev->cmd_q); + spin_unlock_irqrestore(&hdev->cmd_q.lock, flags); + + queue_work(hdev->workqueue, &hdev->cmd_work); + + return 0; +} + +/* This function requires the caller holds hdev->req_lock. */ +struct sk_buff *__hci_cmd_sync_sk(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u8 event, u32 timeout, + struct sock *sk) +{ + struct hci_request req; + struct sk_buff *skb; + int err = 0; + + bt_dev_dbg(hdev, "Opcode 0x%4.4x", opcode); + + hci_req_init(&req, hdev); + + hci_cmd_sync_add(&req, opcode, plen, param, event, sk); + + hdev->req_status = HCI_REQ_PEND; + + err = hci_cmd_sync_run(&req); + if (err < 0) + return ERR_PTR(err); + + err = wait_event_interruptible_timeout(hdev->req_wait_q, + hdev->req_status != HCI_REQ_PEND, + timeout); + + if (err == -ERESTARTSYS) + return ERR_PTR(-EINTR); + + switch (hdev->req_status) { + case HCI_REQ_DONE: + err = -bt_to_errno(hdev->req_result); + break; + + case HCI_REQ_CANCELED: + err = -hdev->req_result; + break; + + default: + err = -ETIMEDOUT; + break; + } + + hdev->req_status = 0; + hdev->req_result = 0; + skb = hdev->req_skb; + hdev->req_skb = NULL; + + bt_dev_dbg(hdev, "end: err %d", err); + + if (err < 0) { + kfree_skb(skb); + return ERR_PTR(err); + } + + return skb; +} +EXPORT_SYMBOL(__hci_cmd_sync_sk); + +/* This function requires the caller holds hdev->req_lock. */ +struct sk_buff *__hci_cmd_sync(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u32 timeout) +{ + return __hci_cmd_sync_sk(hdev, opcode, plen, param, 0, timeout, NULL); +} +EXPORT_SYMBOL(__hci_cmd_sync); + +/* Send HCI command and wait for command complete event */ +struct sk_buff *hci_cmd_sync(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u32 timeout) +{ + struct sk_buff *skb; + + if (!test_bit(HCI_UP, &hdev->flags)) + return ERR_PTR(-ENETDOWN); + + bt_dev_dbg(hdev, "opcode 0x%4.4x plen %d", opcode, plen); + + hci_req_sync_lock(hdev); + skb = __hci_cmd_sync(hdev, opcode, plen, param, timeout); + hci_req_sync_unlock(hdev); + + return skb; +} +EXPORT_SYMBOL(hci_cmd_sync); + +/* This function requires the caller holds hdev->req_lock. */ +struct sk_buff *__hci_cmd_sync_ev(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u8 event, u32 timeout) +{ + return __hci_cmd_sync_sk(hdev, opcode, plen, param, event, timeout, + NULL); +} +EXPORT_SYMBOL(__hci_cmd_sync_ev); + +/* This function requires the caller holds hdev->req_lock. 
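+ *
+ * Returns 0 on success, a positive HCI status byte if the
+ * controller reported a failure, or a negative errno if the
+ * command could not be sent or timed out.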
*/ +int __hci_cmd_sync_status_sk(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u8 event, u32 timeout, + struct sock *sk) +{ + struct sk_buff *skb; + u8 status; + + skb = __hci_cmd_sync_sk(hdev, opcode, plen, param, event, timeout, sk); + if (IS_ERR(skb)) { + if (!event) + bt_dev_err(hdev, "Opcode 0x%4.4x failed: %ld", opcode, + PTR_ERR(skb)); + return PTR_ERR(skb); + } + + /* If command return a status event skb will be set to NULL as there are + * no parameters, in case of failure IS_ERR(skb) would have be set to + * the actual error would be found with PTR_ERR(skb). + */ + if (!skb) + return 0; + + status = skb->data[0]; + + kfree_skb(skb); + + return status; +} +EXPORT_SYMBOL(__hci_cmd_sync_status_sk); + +int __hci_cmd_sync_status(struct hci_dev *hdev, u16 opcode, u32 plen, + const void *param, u32 timeout) +{ + return __hci_cmd_sync_status_sk(hdev, opcode, plen, param, 0, timeout, + NULL); +} +EXPORT_SYMBOL(__hci_cmd_sync_status); + +static void hci_cmd_sync_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, cmd_sync_work); + + bt_dev_dbg(hdev, ""); + + /* Dequeue all entries and run them */ + while (1) { + struct hci_cmd_sync_work_entry *entry; + + mutex_lock(&hdev->cmd_sync_work_lock); + entry = list_first_entry_or_null(&hdev->cmd_sync_work_list, + struct hci_cmd_sync_work_entry, + list); + if (entry) + list_del(&entry->list); + mutex_unlock(&hdev->cmd_sync_work_lock); + + if (!entry) + break; + + bt_dev_dbg(hdev, "entry %p", entry); + + if (entry->func) { + int err; + + hci_req_sync_lock(hdev); + err = entry->func(hdev, entry->data); + if (entry->destroy) + entry->destroy(hdev, entry->data, err); + hci_req_sync_unlock(hdev); + } + + kfree(entry); + } +} + +static void hci_cmd_sync_cancel_work(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, cmd_sync_cancel_work); + + cancel_delayed_work_sync(&hdev->cmd_timer); + cancel_delayed_work_sync(&hdev->ncmd_timer); + atomic_set(&hdev->cmd_cnt, 1); + + wake_up_interruptible(&hdev->req_wait_q); +} + +static int hci_scan_disable_sync(struct hci_dev *hdev); +static int scan_disable_sync(struct hci_dev *hdev, void *data) +{ + return hci_scan_disable_sync(hdev); +} + +static int hci_inquiry_sync(struct hci_dev *hdev, u8 length); +static int interleaved_inquiry_sync(struct hci_dev *hdev, void *data) +{ + return hci_inquiry_sync(hdev, DISCOV_INTERLEAVED_INQUIRY_LEN); +} + +static void le_scan_disable(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + le_scan_disable.work); + int status; + + bt_dev_dbg(hdev, ""); + hci_dev_lock(hdev); + + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN)) + goto _return; + + cancel_delayed_work(&hdev->le_scan_restart); + + status = hci_cmd_sync_queue(hdev, scan_disable_sync, NULL, NULL); + if (status) { + bt_dev_err(hdev, "failed to disable LE scan: %d", status); + goto _return; + } + + hdev->discovery.scan_start = 0; + + /* If we were running LE only scan, change discovery state. If + * we were running both LE and BR/EDR inquiry simultaneously, + * and BR/EDR inquiry is already finished, stop discovery, + * otherwise BR/EDR inquiry will stop discovery when finished. + * If we will resolve remote device name, do not change + * discovery state. 
+ */ + + if (hdev->discovery.type == DISCOV_TYPE_LE) + goto discov_stopped; + + if (hdev->discovery.type != DISCOV_TYPE_INTERLEAVED) + goto _return; + + if (test_bit(HCI_QUIRK_SIMULTANEOUS_DISCOVERY, &hdev->quirks)) { + if (!test_bit(HCI_INQUIRY, &hdev->flags) && + hdev->discovery.state != DISCOVERY_RESOLVING) + goto discov_stopped; + + goto _return; + } + + status = hci_cmd_sync_queue(hdev, interleaved_inquiry_sync, NULL, NULL); + if (status) { + bt_dev_err(hdev, "inquiry failed: status %d", status); + goto discov_stopped; + } + + goto _return; + +discov_stopped: + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + +_return: + hci_dev_unlock(hdev); +} + +static int hci_le_set_scan_enable_sync(struct hci_dev *hdev, u8 val, + u8 filter_dup); +static int hci_le_scan_restart_sync(struct hci_dev *hdev) +{ + /* If controller is not scanning we are done. */ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN)) + return 0; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return 0; + } + + hci_le_set_scan_enable_sync(hdev, LE_SCAN_DISABLE, 0x00); + return hci_le_set_scan_enable_sync(hdev, LE_SCAN_ENABLE, + LE_SCAN_FILTER_DUP_ENABLE); +} + +static void le_scan_restart(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + le_scan_restart.work); + unsigned long timeout, duration, scan_start, now; + int status; + + bt_dev_dbg(hdev, ""); + + status = hci_le_scan_restart_sync(hdev); + if (status) { + bt_dev_err(hdev, "failed to restart LE scan: status %d", + status); + return; + } + + hci_dev_lock(hdev); + + if (!test_bit(HCI_QUIRK_STRICT_DUPLICATE_FILTER, &hdev->quirks) || + !hdev->discovery.scan_start) + goto unlock; + + /* When the scan was started, hdev->le_scan_disable has been queued + * after duration from scan_start. During scan restart this job + * has been canceled, and we need to queue it again after proper + * timeout, to make sure that scan does not run indefinitely. 
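+ *
+ * For example, if half of the scan duration had already
+ * elapsed when the restart completed, the disable work is
+ * re-armed with the remaining half; once the full duration
+ * has passed it is queued immediately (timeout 0).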
+ */ + duration = hdev->discovery.scan_duration; + scan_start = hdev->discovery.scan_start; + now = jiffies; + if (now - scan_start <= duration) { + int elapsed; + + if (now >= scan_start) + elapsed = now - scan_start; + else + elapsed = ULONG_MAX - scan_start + now; + + timeout = duration - elapsed; + } else { + timeout = 0; + } + + queue_delayed_work(hdev->req_workqueue, + &hdev->le_scan_disable, timeout); + +unlock: + hci_dev_unlock(hdev); +} + +static int reenable_adv_sync(struct hci_dev *hdev, void *data) +{ + bt_dev_dbg(hdev, ""); + + if (!hci_dev_test_flag(hdev, HCI_ADVERTISING) && + list_empty(&hdev->adv_instances)) + return 0; + + if (hdev->cur_adv_instance) { + return hci_schedule_adv_instance_sync(hdev, + hdev->cur_adv_instance, + true); + } else { + if (ext_adv_capable(hdev)) { + hci_start_ext_adv_sync(hdev, 0x00); + } else { + hci_update_adv_data_sync(hdev, 0x00); + hci_update_scan_rsp_data_sync(hdev, 0x00); + hci_enable_advertising_sync(hdev); + } + } + + return 0; +} + +static void reenable_adv(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + reenable_adv_work); + int status; + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + status = hci_cmd_sync_queue(hdev, reenable_adv_sync, NULL, NULL); + if (status) + bt_dev_err(hdev, "failed to reenable ADV: %d", status); + + hci_dev_unlock(hdev); +} + +static void cancel_adv_timeout(struct hci_dev *hdev) +{ + if (hdev->adv_instance_timeout) { + hdev->adv_instance_timeout = 0; + cancel_delayed_work(&hdev->adv_instance_expire); + } +} + +/* For a single instance: + * - force == true: The instance will be removed even when its remaining + * lifetime is not zero. + * - force == false: the instance will be deactivated but kept stored unless + * the remaining lifetime is zero. + * + * For instance == 0x00: + * - force == true: All instances will be removed regardless of their timeout + * setting. + * - force == false: Only instances that have a timeout will be removed. + */ +int hci_clear_adv_instance_sync(struct hci_dev *hdev, struct sock *sk, + u8 instance, bool force) +{ + struct adv_info *adv_instance, *n, *next_instance = NULL; + int err; + u8 rem_inst; + + /* Cancel any timeout concerning the removed instance(s). */ + if (!instance || hdev->cur_adv_instance == instance) + cancel_adv_timeout(hdev); + + /* Get the next instance to advertise BEFORE we remove + * the current one. This can be the same instance again + * if there is only one instance. + */ + if (instance && hdev->cur_adv_instance == instance) + next_instance = hci_get_next_instance(hdev, instance); + + if (instance == 0x00) { + list_for_each_entry_safe(adv_instance, n, &hdev->adv_instances, + list) { + if (!(force || adv_instance->timeout)) + continue; + + rem_inst = adv_instance->instance; + err = hci_remove_adv_instance(hdev, rem_inst); + if (!err) + mgmt_advertising_removed(sk, hdev, rem_inst); + } + } else { + adv_instance = hci_find_adv_instance(hdev, instance); + + if (force || (adv_instance && adv_instance->timeout && + !adv_instance->remaining_time)) { + /* Don't advertise a removed instance. 
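+ * next_instance was looked up before the removal and may
+ * still point at the instance being deleted here.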
*/ + if (next_instance && + next_instance->instance == instance) + next_instance = NULL; + + err = hci_remove_adv_instance(hdev, instance); + if (!err) + mgmt_advertising_removed(sk, hdev, instance); + } + } + + if (!hdev_is_powered(hdev) || hci_dev_test_flag(hdev, HCI_ADVERTISING)) + return 0; + + if (next_instance && !ext_adv_capable(hdev)) + return hci_schedule_adv_instance_sync(hdev, + next_instance->instance, + false); + + return 0; +} + +static int adv_timeout_expire_sync(struct hci_dev *hdev, void *data) +{ + u8 instance = *(u8 *)data; + + kfree(data); + + hci_clear_adv_instance_sync(hdev, NULL, instance, false); + + if (list_empty(&hdev->adv_instances)) + return hci_disable_advertising_sync(hdev); + + return 0; +} + +static void adv_timeout_expire(struct work_struct *work) +{ + u8 *inst_ptr; + struct hci_dev *hdev = container_of(work, struct hci_dev, + adv_instance_expire.work); + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + hdev->adv_instance_timeout = 0; + + if (hdev->cur_adv_instance == 0x00) + goto unlock; + + inst_ptr = kmalloc(1, GFP_KERNEL); + if (!inst_ptr) + goto unlock; + + *inst_ptr = hdev->cur_adv_instance; + hci_cmd_sync_queue(hdev, adv_timeout_expire_sync, inst_ptr, NULL); + +unlock: + hci_dev_unlock(hdev); +} + +void hci_cmd_sync_init(struct hci_dev *hdev) +{ + INIT_WORK(&hdev->cmd_sync_work, hci_cmd_sync_work); + INIT_LIST_HEAD(&hdev->cmd_sync_work_list); + mutex_init(&hdev->cmd_sync_work_lock); + mutex_init(&hdev->unregister_lock); + + INIT_WORK(&hdev->cmd_sync_cancel_work, hci_cmd_sync_cancel_work); + INIT_WORK(&hdev->reenable_adv_work, reenable_adv); + INIT_DELAYED_WORK(&hdev->le_scan_disable, le_scan_disable); + INIT_DELAYED_WORK(&hdev->le_scan_restart, le_scan_restart); + INIT_DELAYED_WORK(&hdev->adv_instance_expire, adv_timeout_expire); +} + +void hci_cmd_sync_clear(struct hci_dev *hdev) +{ + struct hci_cmd_sync_work_entry *entry, *tmp; + + cancel_work_sync(&hdev->cmd_sync_work); + cancel_work_sync(&hdev->reenable_adv_work); + + mutex_lock(&hdev->cmd_sync_work_lock); + list_for_each_entry_safe(entry, tmp, &hdev->cmd_sync_work_list, list) { + if (entry->destroy) + entry->destroy(hdev, entry->data, -ECANCELED); + + list_del(&entry->list); + kfree(entry); + } + mutex_unlock(&hdev->cmd_sync_work_lock); +} + +void __hci_cmd_sync_cancel(struct hci_dev *hdev, int err) +{ + bt_dev_dbg(hdev, "err 0x%2.2x", err); + + if (hdev->req_status == HCI_REQ_PEND) { + hdev->req_result = err; + hdev->req_status = HCI_REQ_CANCELED; + + cancel_delayed_work_sync(&hdev->cmd_timer); + cancel_delayed_work_sync(&hdev->ncmd_timer); + atomic_set(&hdev->cmd_cnt, 1); + + wake_up_interruptible(&hdev->req_wait_q); + } +} + +void hci_cmd_sync_cancel(struct hci_dev *hdev, int err) +{ + bt_dev_dbg(hdev, "err 0x%2.2x", err); + + if (hdev->req_status == HCI_REQ_PEND) { + hdev->req_result = err; + hdev->req_status = HCI_REQ_CANCELED; + + queue_work(hdev->workqueue, &hdev->cmd_sync_cancel_work); + } +} +EXPORT_SYMBOL(hci_cmd_sync_cancel); + +int hci_cmd_sync_queue(struct hci_dev *hdev, hci_cmd_sync_work_func_t func, + void *data, hci_cmd_sync_work_destroy_t destroy) +{ + struct hci_cmd_sync_work_entry *entry; + int err = 0; + + mutex_lock(&hdev->unregister_lock); + if (hci_dev_test_flag(hdev, HCI_UNREGISTER)) { + err = -ENODEV; + goto unlock; + } + + entry = kmalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) { + err = -ENOMEM; + goto unlock; + } + entry->func = func; + entry->data = data; + entry->destroy = destroy; + + mutex_lock(&hdev->cmd_sync_work_lock); + list_add_tail(&entry->list, 
&hdev->cmd_sync_work_list); + mutex_unlock(&hdev->cmd_sync_work_lock); + + queue_work(hdev->req_workqueue, &hdev->cmd_sync_work); + +unlock: + mutex_unlock(&hdev->unregister_lock); + return err; +} +EXPORT_SYMBOL(hci_cmd_sync_queue); + +int hci_update_eir_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_eir cp; + + bt_dev_dbg(hdev, ""); + + if (!hdev_is_powered(hdev)) + return 0; + + if (!lmp_ext_inq_capable(hdev)) + return 0; + + if (!hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + return 0; + + if (hci_dev_test_flag(hdev, HCI_SERVICE_CACHE)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + eir_create(hdev, cp.data); + + if (memcmp(cp.data, hdev->eir, sizeof(cp.data)) == 0) + return 0; + + memcpy(hdev->eir, cp.data, sizeof(cp.data)); + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_EIR, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); +} + +static u8 get_service_classes(struct hci_dev *hdev) +{ + struct bt_uuid *uuid; + u8 val = 0; + + list_for_each_entry(uuid, &hdev->uuids, list) + val |= uuid->svc_hint; + + return val; +} + +int hci_update_class_sync(struct hci_dev *hdev) +{ + u8 cod[3]; + + bt_dev_dbg(hdev, ""); + + if (!hdev_is_powered(hdev)) + return 0; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return 0; + + if (hci_dev_test_flag(hdev, HCI_SERVICE_CACHE)) + return 0; + + cod[0] = hdev->minor_class; + cod[1] = hdev->major_class; + cod[2] = get_service_classes(hdev); + + if (hci_dev_test_flag(hdev, HCI_LIMITED_DISCOVERABLE)) + cod[1] |= 0x20; + + if (memcmp(cod, hdev->dev_class, 3) == 0) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_CLASS_OF_DEV, + sizeof(cod), cod, HCI_CMD_TIMEOUT); +} + +static bool is_advertising_allowed(struct hci_dev *hdev, bool connectable) +{ + /* If there is no connection we are OK to advertise. */ + if (hci_conn_num(hdev, LE_LINK) == 0) + return true; + + /* Check le_states if there is any connection in peripheral role. */ + if (hdev->conn_hash.le_num_peripheral > 0) { + /* Peripheral connection state and non connectable mode + * bit 20. + */ + if (!connectable && !(hdev->le_states[2] & 0x10)) + return false; + + /* Peripheral connection state and connectable mode bit 38 + * and scannable bit 21. + */ + if (connectable && (!(hdev->le_states[4] & 0x40) || + !(hdev->le_states[2] & 0x20))) + return false; + } + + /* Check le_states if there is any connection in central role. */ + if (hci_conn_num(hdev, LE_LINK) != hdev->conn_hash.le_num_peripheral) { + /* Central connection state and non connectable mode bit 18. */ + if (!connectable && !(hdev->le_states[2] & 0x02)) + return false; + + /* Central connection state and connectable mode bit 35 and + * scannable 19. + */ + if (connectable && (!(hdev->le_states[4] & 0x08) || + !(hdev->le_states[2] & 0x08))) + return false; + } + + return true; +} + +static bool adv_use_rpa(struct hci_dev *hdev, uint32_t flags) +{ + /* If privacy is not enabled don't use RPA */ + if (!hci_dev_test_flag(hdev, HCI_PRIVACY)) + return false; + + /* If basic privacy mode is enabled use RPA */ + if (!hci_dev_test_flag(hdev, HCI_LIMITED_PRIVACY)) + return true; + + /* If limited privacy mode is enabled don't use RPA if we're + * both discoverable and bondable. + */ + if ((flags & MGMT_ADV_FLAG_DISCOV) && + hci_dev_test_flag(hdev, HCI_BONDABLE)) + return false; + + /* We're neither bondable nor discoverable in the limited + * privacy mode, therefore use RPA. 
+ */ + return true; +} + +static int hci_set_random_addr_sync(struct hci_dev *hdev, bdaddr_t *rpa) +{ + /* If we're advertising or initiating an LE connection we can't + * go ahead and change the random address at this time. This is + * because the eventual initiator address used for the + * subsequently created connection will be undefined (some + * controllers use the new address and others the one we had + * when the operation started). + * + * In this kind of scenario skip the update and let the random + * address be updated at the next cycle. + */ + if (hci_dev_test_flag(hdev, HCI_LE_ADV) || + hci_lookup_le_connect(hdev)) { + bt_dev_dbg(hdev, "Deferring random address update"); + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + return 0; + } + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_RANDOM_ADDR, + 6, rpa, HCI_CMD_TIMEOUT); +} + +int hci_update_random_address_sync(struct hci_dev *hdev, bool require_privacy, + bool rpa, u8 *own_addr_type) +{ + int err; + + /* If privacy is enabled use a resolvable private address. If + * current RPA has expired or there is something else than + * the current RPA in use, then generate a new one. + */ + if (rpa) { + /* If Controller supports LL Privacy use own address type is + * 0x03 + */ + if (use_ll_privacy(hdev)) + *own_addr_type = ADDR_LE_DEV_RANDOM_RESOLVED; + else + *own_addr_type = ADDR_LE_DEV_RANDOM; + + /* Check if RPA is valid */ + if (rpa_valid(hdev)) + return 0; + + err = smp_generate_rpa(hdev, hdev->irk, &hdev->rpa); + if (err < 0) { + bt_dev_err(hdev, "failed to generate new RPA"); + return err; + } + + err = hci_set_random_addr_sync(hdev, &hdev->rpa); + if (err) + return err; + + return 0; + } + + /* In case of required privacy without resolvable private address, + * use an non-resolvable private address. This is useful for active + * scanning and non-connectable advertising. + */ + if (require_privacy) { + bdaddr_t nrpa; + + while (true) { + /* The non-resolvable private address is generated + * from random six bytes with the two most significant + * bits cleared. + */ + get_random_bytes(&nrpa, 6); + nrpa.b[5] &= 0x3f; + + /* The non-resolvable private address shall not be + * equal to the public address. + */ + if (bacmp(&hdev->bdaddr, &nrpa)) + break; + } + + *own_addr_type = ADDR_LE_DEV_RANDOM; + + return hci_set_random_addr_sync(hdev, &nrpa); + } + + /* If forcing static address is in use or there is no public + * address use the static address as random address (but skip + * the HCI command if the current random address is already the + * static one. + * + * In case BR/EDR has been disabled on a dual-mode controller + * and a static address has been configured, then use that + * address instead of the public BR/EDR address. + */ + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + !bacmp(&hdev->bdaddr, BDADDR_ANY) || + (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + bacmp(&hdev->static_addr, BDADDR_ANY))) { + *own_addr_type = ADDR_LE_DEV_RANDOM; + if (bacmp(&hdev->static_addr, &hdev->random_addr)) + return hci_set_random_addr_sync(hdev, + &hdev->static_addr); + return 0; + } + + /* Neither privacy nor static address is being used so use a + * public address. 
+ */ + *own_addr_type = ADDR_LE_DEV_PUBLIC; + + return 0; +} + +static int hci_disable_ext_adv_instance_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_ext_adv_enable *cp; + struct hci_cp_ext_adv_set *set; + u8 data[sizeof(*cp) + sizeof(*set) * 1]; + u8 size; + + /* If request specifies an instance that doesn't exist, fail */ + if (instance > 0) { + struct adv_info *adv; + + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return -EINVAL; + + /* If not enabled there is nothing to do */ + if (!adv->enabled) + return 0; + } + + memset(data, 0, sizeof(data)); + + cp = (void *)data; + set = (void *)cp->data; + + /* Instance 0x00 indicates all advertising instances will be disabled */ + cp->num_of_sets = !!instance; + cp->enable = 0x00; + + set->handle = instance; + + size = sizeof(*cp) + sizeof(*set) * cp->num_of_sets; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_ADV_ENABLE, + size, data, HCI_CMD_TIMEOUT); +} + +static int hci_set_adv_set_random_addr_sync(struct hci_dev *hdev, u8 instance, + bdaddr_t *random_addr) +{ + struct hci_cp_le_set_adv_set_rand_addr cp; + int err; + + if (!instance) { + /* Instance 0x00 doesn't have an adv_info, instead it uses + * hdev->random_addr to track its address so whenever it needs + * to be updated this also set the random address since + * hdev->random_addr is shared with scan state machine. + */ + err = hci_set_random_addr_sync(hdev, random_addr); + if (err) + return err; + } + + memset(&cp, 0, sizeof(cp)); + + cp.handle = instance; + bacpy(&cp.bdaddr, random_addr); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_SET_RAND_ADDR, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_setup_ext_adv_instance_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_ext_adv_params cp; + bool connectable; + u32 flags; + bdaddr_t random_addr; + u8 own_addr_type; + int err; + struct adv_info *adv; + bool secondary_adv; + + if (instance > 0) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return -EINVAL; + } else { + adv = NULL; + } + + /* Updating parameters of an active instance will return a + * Command Disallowed error, so we must first disable the + * instance if it is active. + */ + if (adv && !adv->pending) { + err = hci_disable_ext_adv_instance_sync(hdev, instance); + if (err) + return err; + } + + flags = hci_adv_instance_flags(hdev, instance); + + /* If the "connectable" instance flag was not set, then choose between + * ADV_IND and ADV_NONCONN_IND based on the global connectable setting. + */ + connectable = (flags & MGMT_ADV_FLAG_CONNECTABLE) || + mgmt_get_connectable(hdev); + + if (!is_advertising_allowed(hdev, connectable)) + return -EPERM; + + /* Set require_privacy to true only when non-connectable + * advertising is used. In that case it is fine to use a + * non-resolvable private address. 
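+ * The hci_get_random_address() call below is expected to use these constraints to pick own_addr_type and, when needed, a random address for this advertising set.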
+ */ + err = hci_get_random_address(hdev, !connectable, + adv_use_rpa(hdev, flags), adv, + &own_addr_type, &random_addr); + if (err < 0) + return err; + + memset(&cp, 0, sizeof(cp)); + + if (adv) { + hci_cpu_to_le24(adv->min_interval, cp.min_interval); + hci_cpu_to_le24(adv->max_interval, cp.max_interval); + cp.tx_power = adv->tx_power; + } else { + hci_cpu_to_le24(hdev->le_adv_min_interval, cp.min_interval); + hci_cpu_to_le24(hdev->le_adv_max_interval, cp.max_interval); + cp.tx_power = HCI_ADV_TX_POWER_NO_PREFERENCE; + } + + secondary_adv = (flags & MGMT_ADV_FLAG_SEC_MASK); + + if (connectable) { + if (secondary_adv) + cp.evt_properties = cpu_to_le16(LE_EXT_ADV_CONN_IND); + else + cp.evt_properties = cpu_to_le16(LE_LEGACY_ADV_IND); + } else if (hci_adv_instance_is_scannable(hdev, instance) || + (flags & MGMT_ADV_PARAM_SCAN_RSP)) { + if (secondary_adv) + cp.evt_properties = cpu_to_le16(LE_EXT_ADV_SCAN_IND); + else + cp.evt_properties = cpu_to_le16(LE_LEGACY_ADV_SCAN_IND); + } else { + if (secondary_adv) + cp.evt_properties = cpu_to_le16(LE_EXT_ADV_NON_CONN_IND); + else + cp.evt_properties = cpu_to_le16(LE_LEGACY_NONCONN_IND); + } + + /* If Own_Address_Type equals 0x02 or 0x03, the Peer_Address parameter + * contains the peer’s Identity Address and the Peer_Address_Type + * parameter contains the peer’s Identity Type (i.e., 0x00 or 0x01). + * These parameters are used to locate the corresponding local IRK in + * the resolving list; this IRK is used to generate their own address + * used in the advertisement. + */ + if (own_addr_type == ADDR_LE_DEV_RANDOM_RESOLVED) + hci_copy_identity_address(hdev, &cp.peer_addr, + &cp.peer_addr_type); + + cp.own_addr_type = own_addr_type; + cp.channel_map = hdev->le_adv_channel_map; + cp.handle = instance; + + if (flags & MGMT_ADV_FLAG_SEC_2M) { + cp.primary_phy = HCI_ADV_PHY_1M; + cp.secondary_phy = HCI_ADV_PHY_2M; + } else if (flags & MGMT_ADV_FLAG_SEC_CODED) { + cp.primary_phy = HCI_ADV_PHY_CODED; + cp.secondary_phy = HCI_ADV_PHY_CODED; + } else { + /* In all other cases use 1M */ + cp.primary_phy = HCI_ADV_PHY_1M; + cp.secondary_phy = HCI_ADV_PHY_1M; + } + + err = __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_ADV_PARAMS, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (err) + return err; + + if ((own_addr_type == ADDR_LE_DEV_RANDOM || + own_addr_type == ADDR_LE_DEV_RANDOM_RESOLVED) && + bacmp(&random_addr, BDADDR_ANY)) { + /* Check if random address need to be updated */ + if (adv) { + if (!bacmp(&random_addr, &adv->random_addr)) + return 0; + } else { + if (!bacmp(&random_addr, &hdev->random_addr)) + return 0; + } + + return hci_set_adv_set_random_addr_sync(hdev, instance, + &random_addr); + } + + return 0; +} + +static int hci_set_ext_scan_rsp_data_sync(struct hci_dev *hdev, u8 instance) +{ + struct { + struct hci_cp_le_set_ext_scan_rsp_data cp; + u8 data[HCI_MAX_EXT_AD_LENGTH]; + } pdu; + u8 len; + struct adv_info *adv = NULL; + int err; + + memset(&pdu, 0, sizeof(pdu)); + + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv || !adv->scan_rsp_changed) + return 0; + } + + len = eir_create_scan_rsp(hdev, instance, pdu.data); + + pdu.cp.handle = instance; + pdu.cp.length = len; + pdu.cp.operation = LE_SET_ADV_DATA_OP_COMPLETE; + pdu.cp.frag_pref = LE_SET_ADV_DATA_NO_FRAG; + + err = __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_SCAN_RSP_DATA, + sizeof(pdu.cp) + len, &pdu.cp, + HCI_CMD_TIMEOUT); + if (err) + return err; + + if (adv) { + adv->scan_rsp_changed = false; + } else { + memcpy(hdev->scan_rsp_data, pdu.data, len); + 
hdev->scan_rsp_data_len = len; + } + + return 0; +} + +static int __hci_set_scan_rsp_data_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_scan_rsp_data cp; + u8 len; + + memset(&cp, 0, sizeof(cp)); + + len = eir_create_scan_rsp(hdev, instance, cp.data); + + if (hdev->scan_rsp_data_len == len && + !memcmp(cp.data, hdev->scan_rsp_data, len)) + return 0; + + memcpy(hdev->scan_rsp_data, cp.data, sizeof(cp.data)); + hdev->scan_rsp_data_len = len; + + cp.length = len; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_SCAN_RSP_DATA, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_update_scan_rsp_data_sync(struct hci_dev *hdev, u8 instance) +{ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return 0; + + if (ext_adv_capable(hdev)) + return hci_set_ext_scan_rsp_data_sync(hdev, instance); + + return __hci_set_scan_rsp_data_sync(hdev, instance); +} + +int hci_enable_ext_advertising_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_ext_adv_enable *cp; + struct hci_cp_ext_adv_set *set; + u8 data[sizeof(*cp) + sizeof(*set) * 1]; + struct adv_info *adv; + + if (instance > 0) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return -EINVAL; + /* If already enabled there is nothing to do */ + if (adv->enabled) + return 0; + } else { + adv = NULL; + } + + cp = (void *)data; + set = (void *)cp->data; + + memset(cp, 0, sizeof(*cp)); + + cp->enable = 0x01; + cp->num_of_sets = 0x01; + + memset(set, 0, sizeof(*set)); + + set->handle = instance; + + /* Set duration per instance since controller is responsible for + * scheduling it. + */ + if (adv && adv->timeout) { + u16 duration = adv->timeout * MSEC_PER_SEC; + + /* Time = N * 10 ms */ + set->duration = cpu_to_le16(duration / 10); + } + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_ADV_ENABLE, + sizeof(*cp) + + sizeof(*set) * cp->num_of_sets, + data, HCI_CMD_TIMEOUT); +} + +int hci_start_ext_adv_sync(struct hci_dev *hdev, u8 instance) +{ + int err; + + err = hci_setup_ext_adv_instance_sync(hdev, instance); + if (err) + return err; + + err = hci_set_ext_scan_rsp_data_sync(hdev, instance); + if (err) + return err; + + return hci_enable_ext_advertising_sync(hdev, instance); +} + +static int hci_disable_per_advertising_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_per_adv_enable cp; + + /* If periodic advertising already disabled there is nothing to do. 
*/ + if (!hci_dev_test_flag(hdev, HCI_LE_PER_ADV)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + cp.enable = 0x00; + cp.handle = instance; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_PER_ADV_ENABLE, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_set_per_adv_params_sync(struct hci_dev *hdev, u8 instance, + u16 min_interval, u16 max_interval) +{ + struct hci_cp_le_set_per_adv_params cp; + + memset(&cp, 0, sizeof(cp)); + + if (!min_interval) + min_interval = DISCOV_LE_PER_ADV_INT_MIN; + + if (!max_interval) + max_interval = DISCOV_LE_PER_ADV_INT_MAX; + + cp.handle = instance; + cp.min_interval = cpu_to_le16(min_interval); + cp.max_interval = cpu_to_le16(max_interval); + cp.periodic_properties = 0x0000; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_PER_ADV_PARAMS, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_set_per_adv_data_sync(struct hci_dev *hdev, u8 instance) +{ + struct { + struct hci_cp_le_set_per_adv_data cp; + u8 data[HCI_MAX_PER_AD_LENGTH]; + } pdu; + u8 len; + + memset(&pdu, 0, sizeof(pdu)); + + if (instance) { + struct adv_info *adv = hci_find_adv_instance(hdev, instance); + + if (!adv || !adv->periodic) + return 0; + } + + len = eir_create_per_adv_data(hdev, instance, pdu.data); + + pdu.cp.length = len; + pdu.cp.handle = instance; + pdu.cp.operation = LE_SET_ADV_DATA_OP_COMPLETE; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_PER_ADV_DATA, + sizeof(pdu.cp) + len, &pdu, + HCI_CMD_TIMEOUT); +} + +static int hci_enable_per_advertising_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_per_adv_enable cp; + + /* If periodic advertising is already enabled there is nothing to do. */ + if (hci_dev_test_flag(hdev, HCI_LE_PER_ADV)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + cp.enable = 0x01; + cp.handle = instance; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_PER_ADV_ENABLE, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* Checks if the periodic advertising data contains a Basic Audio Announcement and, + * if it does, generates a Broadcast ID and adds a Broadcast Announcement. + */ +static int hci_adv_bcast_annoucement(struct hci_dev *hdev, struct adv_info *adv) +{ + u8 bid[3]; + u8 ad[4 + 3]; + + /* Skip if NULL adv as instance 0x00 is used for general purpose + * advertising so it cannot be used for the likes of a Broadcast Announcement + * as it can be overwritten at any point. + */ + if (!adv) + return 0; + + /* If the PA data doesn't contain a Basic Audio Announcement there is + * nothing to do. + */ + if (!eir_get_service_data(adv->per_adv_data, adv->per_adv_data_len, + 0x1851, NULL)) + return 0; + + /* Check if advertising data already has a Broadcast Announcement since + * the process may want to control the Broadcast ID directly and in that + * case the kernel shall not interfere
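+ * (0x1851 and 0x1852 are the 16-bit UUIDs of the Basic Audio Announcement Service and the Broadcast Audio Announcement Service, respectively.)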
+ */ + if (eir_get_service_data(adv->adv_data, adv->adv_data_len, 0x1852, + NULL)) + return 0; + + /* Generate Broadcast ID */ + get_random_bytes(bid, sizeof(bid)); + eir_append_service_data(ad, 0, 0x1852, bid, sizeof(bid)); + hci_set_adv_instance_data(hdev, adv->instance, sizeof(ad), ad, 0, NULL); + + return hci_update_adv_data_sync(hdev, adv->instance); +} + +int hci_start_per_adv_sync(struct hci_dev *hdev, u8 instance, u8 data_len, + u8 *data, u32 flags, u16 min_interval, + u16 max_interval, u16 sync_interval) +{ + struct adv_info *adv = NULL; + int err; + bool added = false; + + hci_disable_per_advertising_sync(hdev, instance); + + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + /* Create an instance if that could not be found */ + if (!adv) { + adv = hci_add_per_instance(hdev, instance, flags, + data_len, data, + sync_interval, + sync_interval); + if (IS_ERR(adv)) + return PTR_ERR(adv); + added = true; + } + } + + /* Only start advertising if instance 0 or if a dedicated instance has + * been added. + */ + if (!adv || added) { + err = hci_start_ext_adv_sync(hdev, instance); + if (err < 0) + goto fail; + + err = hci_adv_bcast_annoucement(hdev, adv); + if (err < 0) + goto fail; + } + + err = hci_set_per_adv_params_sync(hdev, instance, min_interval, + max_interval); + if (err < 0) + goto fail; + + err = hci_set_per_adv_data_sync(hdev, instance); + if (err < 0) + goto fail; + + err = hci_enable_per_advertising_sync(hdev, instance); + if (err < 0) + goto fail; + + return 0; + +fail: + if (added) + hci_remove_adv_instance(hdev, instance); + + return err; +} + +static int hci_start_adv_sync(struct hci_dev *hdev, u8 instance) +{ + int err; + + if (ext_adv_capable(hdev)) + return hci_start_ext_adv_sync(hdev, instance); + + err = hci_update_adv_data_sync(hdev, instance); + if (err) + return err; + + err = hci_update_scan_rsp_data_sync(hdev, instance); + if (err) + return err; + + return hci_enable_advertising_sync(hdev); +} + +int hci_enable_advertising_sync(struct hci_dev *hdev) +{ + struct adv_info *adv_instance; + struct hci_cp_le_set_adv_param cp; + u8 own_addr_type, enable = 0x01; + bool connectable; + u16 adv_min_interval, adv_max_interval; + u32 flags; + u8 status; + + if (ext_adv_capable(hdev)) + return hci_enable_ext_advertising_sync(hdev, + hdev->cur_adv_instance); + + flags = hci_adv_instance_flags(hdev, hdev->cur_adv_instance); + adv_instance = hci_find_adv_instance(hdev, hdev->cur_adv_instance); + + /* If the "connectable" instance flag was not set, then choose between + * ADV_IND and ADV_NONCONN_IND based on the global connectable setting. + */ + connectable = (flags & MGMT_ADV_FLAG_CONNECTABLE) || + mgmt_get_connectable(hdev); + + if (!is_advertising_allowed(hdev, connectable)) + return -EINVAL; + + status = hci_disable_advertising_sync(hdev); + if (status) + return status; + + /* Clear the HCI_LE_ADV bit temporarily so that the + * hci_update_random_address knows that it's safe to go ahead + * and write a new random address. The flag will be set back on + * as soon as the SET_ADV_ENABLE HCI command completes. + */ + hci_dev_clear_flag(hdev, HCI_LE_ADV); + + /* Set require_privacy to true only when non-connectable + * advertising is used. In that case it is fine to use a + * non-resolvable private address. 
+ */ + status = hci_update_random_address_sync(hdev, !connectable, + adv_use_rpa(hdev, flags), + &own_addr_type); + if (status) + return status; + + memset(&cp, 0, sizeof(cp)); + + if (adv_instance) { + adv_min_interval = adv_instance->min_interval; + adv_max_interval = adv_instance->max_interval; + } else { + adv_min_interval = hdev->le_adv_min_interval; + adv_max_interval = hdev->le_adv_max_interval; + } + + if (connectable) { + cp.type = LE_ADV_IND; + } else { + if (hci_adv_instance_is_scannable(hdev, hdev->cur_adv_instance)) + cp.type = LE_ADV_SCAN_IND; + else + cp.type = LE_ADV_NONCONN_IND; + + if (!hci_dev_test_flag(hdev, HCI_DISCOVERABLE) || + hci_dev_test_flag(hdev, HCI_LIMITED_DISCOVERABLE)) { + adv_min_interval = DISCOV_LE_FAST_ADV_INT_MIN; + adv_max_interval = DISCOV_LE_FAST_ADV_INT_MAX; + } + } + + cp.min_interval = cpu_to_le16(adv_min_interval); + cp.max_interval = cpu_to_le16(adv_max_interval); + cp.own_address_type = own_addr_type; + cp.channel_map = hdev->le_adv_channel_map; + + status = __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_PARAM, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (status) + return status; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_ENABLE, + sizeof(enable), &enable, HCI_CMD_TIMEOUT); +} + +static int enable_advertising_sync(struct hci_dev *hdev, void *data) +{ + return hci_enable_advertising_sync(hdev); +} + +int hci_enable_advertising(struct hci_dev *hdev) +{ + if (!hci_dev_test_flag(hdev, HCI_ADVERTISING) && + list_empty(&hdev->adv_instances)) + return 0; + + return hci_cmd_sync_queue(hdev, enable_advertising_sync, NULL, NULL); +} + +int hci_remove_ext_adv_instance_sync(struct hci_dev *hdev, u8 instance, + struct sock *sk) +{ + int err; + + if (!ext_adv_capable(hdev)) + return 0; + + err = hci_disable_ext_adv_instance_sync(hdev, instance); + if (err) + return err; + + /* If request specifies an instance that doesn't exist, fail */ + if (instance > 0 && !hci_find_adv_instance(hdev, instance)) + return -EINVAL; + + return __hci_cmd_sync_status_sk(hdev, HCI_OP_LE_REMOVE_ADV_SET, + sizeof(instance), &instance, 0, + HCI_CMD_TIMEOUT, sk); +} + +static int remove_ext_adv_sync(struct hci_dev *hdev, void *data) +{ + struct adv_info *adv = data; + u8 instance = 0; + + if (adv) + instance = adv->instance; + + return hci_remove_ext_adv_instance_sync(hdev, instance, NULL); +} + +int hci_remove_ext_adv_instance(struct hci_dev *hdev, u8 instance) +{ + struct adv_info *adv = NULL; + + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return -EINVAL; + } + + return hci_cmd_sync_queue(hdev, remove_ext_adv_sync, adv, NULL); +} + +int hci_le_terminate_big_sync(struct hci_dev *hdev, u8 handle, u8 reason) +{ + struct hci_cp_le_term_big cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = handle; + cp.reason = reason; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_TERM_BIG, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_set_ext_adv_data_sync(struct hci_dev *hdev, u8 instance) +{ + struct { + struct hci_cp_le_set_ext_adv_data cp; + u8 data[HCI_MAX_EXT_AD_LENGTH]; + } pdu; + u8 len; + struct adv_info *adv = NULL; + int err; + + memset(&pdu, 0, sizeof(pdu)); + + if (instance) { + adv = hci_find_adv_instance(hdev, instance); + if (!adv || !adv->adv_data_changed) + return 0; + } + + len = eir_create_adv_data(hdev, instance, pdu.data); + + pdu.cp.length = len; + pdu.cp.handle = instance; + pdu.cp.operation = LE_SET_ADV_DATA_OP_COMPLETE; + pdu.cp.frag_pref = LE_SET_ADV_DATA_NO_FRAG; + + err = __hci_cmd_sync_status(hdev, 
HCI_OP_LE_SET_EXT_ADV_DATA, + sizeof(pdu.cp) + len, &pdu.cp, + HCI_CMD_TIMEOUT); + if (err) + return err; + + /* Update data if the command succeed */ + if (adv) { + adv->adv_data_changed = false; + } else { + memcpy(hdev->adv_data, pdu.data, len); + hdev->adv_data_len = len; + } + + return 0; +} + +static int hci_set_adv_data_sync(struct hci_dev *hdev, u8 instance) +{ + struct hci_cp_le_set_adv_data cp; + u8 len; + + memset(&cp, 0, sizeof(cp)); + + len = eir_create_adv_data(hdev, instance, cp.data); + + /* There's nothing to do if the data hasn't changed */ + if (hdev->adv_data_len == len && + memcmp(cp.data, hdev->adv_data, len) == 0) + return 0; + + memcpy(hdev->adv_data, cp.data, sizeof(cp.data)); + hdev->adv_data_len = len; + + cp.length = len; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_DATA, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_update_adv_data_sync(struct hci_dev *hdev, u8 instance) +{ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return 0; + + if (ext_adv_capable(hdev)) + return hci_set_ext_adv_data_sync(hdev, instance); + + return hci_set_adv_data_sync(hdev, instance); +} + +int hci_schedule_adv_instance_sync(struct hci_dev *hdev, u8 instance, + bool force) +{ + struct adv_info *adv = NULL; + u16 timeout; + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING) && !ext_adv_capable(hdev)) + return -EPERM; + + if (hdev->adv_instance_timeout) + return -EBUSY; + + adv = hci_find_adv_instance(hdev, instance); + if (!adv) + return -ENOENT; + + /* A zero timeout means unlimited advertising. As long as there is + * only one instance, duration should be ignored. We still set a timeout + * in case further instances are being added later on. + * + * If the remaining lifetime of the instance is more than the duration + * then the timeout corresponds to the duration, otherwise it will be + * reduced to the remaining instance lifetime. + */ + if (adv->timeout == 0 || adv->duration <= adv->remaining_time) + timeout = adv->duration; + else + timeout = adv->remaining_time; + + /* The remaining time is being reduced unless the instance is being + * advertised without time limit. + */ + if (adv->timeout) + adv->remaining_time = adv->remaining_time - timeout; + + /* Only use work for scheduling instances with legacy advertising */ + if (!ext_adv_capable(hdev)) { + hdev->adv_instance_timeout = timeout; + queue_delayed_work(hdev->req_workqueue, + &hdev->adv_instance_expire, + msecs_to_jiffies(timeout * 1000)); + } + + /* If we're just re-scheduling the same instance again then do not + * execute any HCI commands. This happens when a single instance is + * being advertised. 
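+ * (hdev->cur_adv_instance together with the HCI_LE_ADV flag identifies the instance that is currently being advertised.)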
+ */ + if (!force && hdev->cur_adv_instance == instance && + hci_dev_test_flag(hdev, HCI_LE_ADV)) + return 0; + + hdev->cur_adv_instance = instance; + + return hci_start_adv_sync(hdev, instance); +} + +static int hci_clear_adv_sets_sync(struct hci_dev *hdev, struct sock *sk) +{ + int err; + + if (!ext_adv_capable(hdev)) + return 0; + + /* Disable instance 0x00 to disable all instances */ + err = hci_disable_ext_adv_instance_sync(hdev, 0x00); + if (err) + return err; + + return __hci_cmd_sync_status_sk(hdev, HCI_OP_LE_CLEAR_ADV_SETS, + 0, NULL, 0, HCI_CMD_TIMEOUT, sk); +} + +static int hci_clear_adv_sync(struct hci_dev *hdev, struct sock *sk, bool force) +{ + struct adv_info *adv, *n; + int err = 0; + + if (ext_adv_capable(hdev)) + /* Remove all existing sets */ + err = hci_clear_adv_sets_sync(hdev, sk); + if (ext_adv_capable(hdev)) + return err; + + /* This is safe as long as there is no command send while the lock is + * held. + */ + hci_dev_lock(hdev); + + /* Cleanup non-ext instances */ + list_for_each_entry_safe(adv, n, &hdev->adv_instances, list) { + u8 instance = adv->instance; + int err; + + if (!(force || adv->timeout)) + continue; + + err = hci_remove_adv_instance(hdev, instance); + if (!err) + mgmt_advertising_removed(sk, hdev, instance); + } + + hci_dev_unlock(hdev); + + return 0; +} + +static int hci_remove_adv_sync(struct hci_dev *hdev, u8 instance, + struct sock *sk) +{ + int err = 0; + + /* If we use extended advertising, instance has to be removed first. */ + if (ext_adv_capable(hdev)) + err = hci_remove_ext_adv_instance_sync(hdev, instance, sk); + if (ext_adv_capable(hdev)) + return err; + + /* This is safe as long as there is no command send while the lock is + * held. + */ + hci_dev_lock(hdev); + + err = hci_remove_adv_instance(hdev, instance); + if (!err) + mgmt_advertising_removed(sk, hdev, instance); + + hci_dev_unlock(hdev); + + return err; +} + +/* For a single instance: + * - force == true: The instance will be removed even when its remaining + * lifetime is not zero. + * - force == false: the instance will be deactivated but kept stored unless + * the remaining lifetime is zero. + * + * For instance == 0x00: + * - force == true: All instances will be removed regardless of their timeout + * setting. + * - force == false: Only instances that have a timeout will be removed. + */ +int hci_remove_advertising_sync(struct hci_dev *hdev, struct sock *sk, + u8 instance, bool force) +{ + struct adv_info *next = NULL; + int err; + + /* Cancel any timeout concerning the removed instance(s). */ + if (!instance || hdev->cur_adv_instance == instance) + cancel_adv_timeout(hdev); + + /* Get the next instance to advertise BEFORE we remove + * the current one. This can be the same instance again + * if there is only one instance. + */ + if (hdev->cur_adv_instance == instance) + next = hci_get_next_instance(hdev, instance); + + if (!instance) { + err = hci_clear_adv_sync(hdev, sk, force); + if (err) + return err; + } else { + struct adv_info *adv = hci_find_adv_instance(hdev, instance); + + if (force || (adv && adv->timeout && !adv->remaining_time)) { + /* Don't advertise a removed instance. 
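+ * Clearing next here prevents the instance that is about to be removed from being immediately rescheduled below.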
*/ + if (next && next->instance == instance) + next = NULL; + + err = hci_remove_adv_sync(hdev, instance, sk); + if (err) + return err; + } + } + + if (!hdev_is_powered(hdev) || hci_dev_test_flag(hdev, HCI_ADVERTISING)) + return 0; + + if (next && !ext_adv_capable(hdev)) + hci_schedule_adv_instance_sync(hdev, next->instance, false); + + return 0; +} + +int hci_read_rssi_sync(struct hci_dev *hdev, __le16 handle) +{ + struct hci_cp_read_rssi cp; + + cp.handle = handle; + return __hci_cmd_sync_status(hdev, HCI_OP_READ_RSSI, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_read_clock_sync(struct hci_dev *hdev, struct hci_cp_read_clock *cp) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_CLOCK, + sizeof(*cp), cp, HCI_CMD_TIMEOUT); +} + +int hci_read_tx_power_sync(struct hci_dev *hdev, __le16 handle, u8 type) +{ + struct hci_cp_read_tx_power cp; + + cp.handle = handle; + cp.type = type; + return __hci_cmd_sync_status(hdev, HCI_OP_READ_TX_POWER, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_disable_advertising_sync(struct hci_dev *hdev) +{ + u8 enable = 0x00; + int err = 0; + + /* If controller is not advertising we are done. */ + if (!hci_dev_test_flag(hdev, HCI_LE_ADV)) + return 0; + + if (ext_adv_capable(hdev)) + err = hci_disable_ext_adv_instance_sync(hdev, 0x00); + if (ext_adv_capable(hdev)) + return err; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_ENABLE, + sizeof(enable), &enable, HCI_CMD_TIMEOUT); +} + +static int hci_le_set_ext_scan_enable_sync(struct hci_dev *hdev, u8 val, + u8 filter_dup) +{ + struct hci_cp_le_set_ext_scan_enable cp; + + memset(&cp, 0, sizeof(cp)); + cp.enable = val; + + if (hci_dev_test_flag(hdev, HCI_MESH)) + cp.filter_dup = LE_SCAN_FILTER_DUP_DISABLE; + else + cp.filter_dup = filter_dup; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_SCAN_ENABLE, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_le_set_scan_enable_sync(struct hci_dev *hdev, u8 val, + u8 filter_dup) +{ + struct hci_cp_le_set_scan_enable cp; + + if (use_ext_scan(hdev)) + return hci_le_set_ext_scan_enable_sync(hdev, val, filter_dup); + + memset(&cp, 0, sizeof(cp)); + cp.enable = val; + + if (val && hci_dev_test_flag(hdev, HCI_MESH)) + cp.filter_dup = LE_SCAN_FILTER_DUP_DISABLE; + else + cp.filter_dup = filter_dup; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_SCAN_ENABLE, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_le_set_addr_resolution_enable_sync(struct hci_dev *hdev, u8 val) +{ + if (!use_ll_privacy(hdev)) + return 0; + + /* If controller is not/already resolving we are done. */ + if (val == hci_dev_test_flag(hdev, HCI_LL_RPA_RESOLUTION)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADDR_RESOLV_ENABLE, + sizeof(val), &val, HCI_CMD_TIMEOUT); +} + +static int hci_scan_disable_sync(struct hci_dev *hdev) +{ + int err; + + /* If controller is not scanning we are done. 
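+ * (HCI_LE_SCAN is set while LE scanning is enabled on the controller.)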
*/ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN)) + return 0; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return 0; + } + + err = hci_le_set_scan_enable_sync(hdev, LE_SCAN_DISABLE, 0x00); + if (err) { + bt_dev_err(hdev, "Unable to disable scanning: %d", err); + return err; + } + + return err; +} + +static bool scan_use_rpa(struct hci_dev *hdev) +{ + return hci_dev_test_flag(hdev, HCI_PRIVACY); +} + +static void hci_start_interleave_scan(struct hci_dev *hdev) +{ + hdev->interleave_scan_state = INTERLEAVE_SCAN_NO_FILTER; + queue_delayed_work(hdev->req_workqueue, + &hdev->interleave_scan, 0); +} + +static bool is_interleave_scanning(struct hci_dev *hdev) +{ + return hdev->interleave_scan_state != INTERLEAVE_SCAN_NONE; +} + +static void cancel_interleave_scan(struct hci_dev *hdev) +{ + bt_dev_dbg(hdev, "cancelling interleave scan"); + + cancel_delayed_work_sync(&hdev->interleave_scan); + + hdev->interleave_scan_state = INTERLEAVE_SCAN_NONE; +} + +/* Return true if interleave_scan wasn't started until exiting this function, + * otherwise, return false + */ +static bool hci_update_interleaved_scan_sync(struct hci_dev *hdev) +{ + /* Do interleaved scan only if all of the following are true: + * - There is at least one ADV monitor + * - At least one pending LE connection or one device to be scanned for + * - Monitor offloading is not supported + * If so, we should alternate between allowlist scan and one without + * any filters to save power. + */ + bool use_interleaving = hci_is_adv_monitoring(hdev) && + !(list_empty(&hdev->pend_le_conns) && + list_empty(&hdev->pend_le_reports)) && + hci_get_adv_monitor_offload_ext(hdev) == + HCI_ADV_MONITOR_EXT_NONE; + bool is_interleaving = is_interleave_scanning(hdev); + + if (use_interleaving && !is_interleaving) { + hci_start_interleave_scan(hdev); + bt_dev_dbg(hdev, "starting interleave scan"); + return true; + } + + if (!use_interleaving && is_interleaving) + cancel_interleave_scan(hdev); + + return false; +} + +/* Removes connection to resolve list if needed.*/ +static int hci_le_del_resolve_list_sync(struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 bdaddr_type) +{ + struct hci_cp_le_del_from_resolv_list cp; + struct bdaddr_list_with_irk *entry; + + if (!use_ll_privacy(hdev)) + return 0; + + /* Check if the IRK has been programmed */ + entry = hci_bdaddr_list_lookup_with_irk(&hdev->le_resolv_list, bdaddr, + bdaddr_type); + if (!entry) + return 0; + + cp.bdaddr_type = bdaddr_type; + bacpy(&cp.bdaddr, bdaddr); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_DEL_FROM_RESOLV_LIST, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_le_del_accept_list_sync(struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 bdaddr_type) +{ + struct hci_cp_le_del_from_accept_list cp; + int err; + + /* Check if device is on accept list before removing it */ + if (!hci_bdaddr_list_lookup(&hdev->le_accept_list, bdaddr, bdaddr_type)) + return 0; + + cp.bdaddr_type = bdaddr_type; + bacpy(&cp.bdaddr, bdaddr); + + /* Ignore errors when removing from resolving list as that is likely + * that the device was never added. 
+ */ + hci_le_del_resolve_list_sync(hdev, &cp.bdaddr, cp.bdaddr_type); + + err = __hci_cmd_sync_status(hdev, HCI_OP_LE_DEL_FROM_ACCEPT_LIST, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (err) { + bt_dev_err(hdev, "Unable to remove from allow list: %d", err); + return err; + } + + bt_dev_dbg(hdev, "Remove %pMR (0x%x) from allow list", &cp.bdaddr, + cp.bdaddr_type); + + return 0; +} + +struct conn_params { + bdaddr_t addr; + u8 addr_type; + hci_conn_flags_t flags; + u8 privacy_mode; +}; + +/* Adds connection to resolve list if needed. + * Setting params to NULL programs local hdev->irk + */ +static int hci_le_add_resolve_list_sync(struct hci_dev *hdev, + struct conn_params *params) +{ + struct hci_cp_le_add_to_resolv_list cp; + struct smp_irk *irk; + struct bdaddr_list_with_irk *entry; + struct hci_conn_params *p; + + if (!use_ll_privacy(hdev)) + return 0; + + /* Attempt to program local identity address, type and irk if params is + * NULL. + */ + if (!params) { + if (!hci_dev_test_flag(hdev, HCI_PRIVACY)) + return 0; + + hci_copy_identity_address(hdev, &cp.bdaddr, &cp.bdaddr_type); + memcpy(cp.peer_irk, hdev->irk, 16); + goto done; + } + + irk = hci_find_irk_by_addr(hdev, ¶ms->addr, params->addr_type); + if (!irk) + return 0; + + /* Check if the IK has _not_ been programmed yet. */ + entry = hci_bdaddr_list_lookup_with_irk(&hdev->le_resolv_list, + ¶ms->addr, + params->addr_type); + if (entry) + return 0; + + cp.bdaddr_type = params->addr_type; + bacpy(&cp.bdaddr, ¶ms->addr); + memcpy(cp.peer_irk, irk->val, 16); + + /* Default privacy mode is always Network */ + params->privacy_mode = HCI_NETWORK_PRIVACY; + + rcu_read_lock(); + p = hci_pend_le_action_lookup(&hdev->pend_le_conns, + ¶ms->addr, params->addr_type); + if (!p) + p = hci_pend_le_action_lookup(&hdev->pend_le_reports, + ¶ms->addr, params->addr_type); + if (p) + WRITE_ONCE(p->privacy_mode, HCI_NETWORK_PRIVACY); + rcu_read_unlock(); + +done: + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) + memcpy(cp.local_irk, hdev->irk, 16); + else + memset(cp.local_irk, 0, 16); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_ADD_TO_RESOLV_LIST, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* Set Device Privacy Mode. */ +static int hci_le_set_privacy_mode_sync(struct hci_dev *hdev, + struct conn_params *params) +{ + struct hci_cp_le_set_privacy_mode cp; + struct smp_irk *irk; + + /* If device privacy mode has already been set there is nothing to do */ + if (params->privacy_mode == HCI_DEVICE_PRIVACY) + return 0; + + /* Check if HCI_CONN_FLAG_DEVICE_PRIVACY has been set as it also + * indicates that LL Privacy has been enabled and + * HCI_OP_LE_SET_PRIVACY_MODE is supported. + */ + if (!(params->flags & HCI_CONN_FLAG_DEVICE_PRIVACY)) + return 0; + + irk = hci_find_irk_by_addr(hdev, ¶ms->addr, params->addr_type); + if (!irk) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.bdaddr_type = irk->addr_type; + bacpy(&cp.bdaddr, &irk->bdaddr); + cp.mode = HCI_DEVICE_PRIVACY; + + /* Note: params->privacy_mode is not updated since it is a copy */ + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_PRIVACY_MODE, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* Adds connection to allow list if needed, if the device uses RPA (has IRK) + * this attempts to program the device in the resolving list as well and + * properly set the privacy mode. 
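+ * Returns -ENOSPC when the accept list is already full; the caller treats any error as a signal to fall back to a filter policy that does not use the accept list.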
+ */ +static int hci_le_add_accept_list_sync(struct hci_dev *hdev, + struct conn_params *params, + u8 *num_entries) +{ + struct hci_cp_le_add_to_accept_list cp; + int err; + + /* During suspend, only wakeable devices can be in acceptlist */ + if (hdev->suspended && + !(params->flags & HCI_CONN_FLAG_REMOTE_WAKEUP)) + return 0; + + /* Select filter policy to accept all advertising */ + if (*num_entries >= hdev->le_accept_list_size) + return -ENOSPC; + + /* Accept list can not be used with RPAs */ + if (!use_ll_privacy(hdev) && + hci_find_irk_by_addr(hdev, ¶ms->addr, params->addr_type)) + return -EINVAL; + + /* Attempt to program the device in the resolving list first to avoid + * having to rollback in case it fails since the resolving list is + * dynamic it can probably be smaller than the accept list. + */ + err = hci_le_add_resolve_list_sync(hdev, params); + if (err) { + bt_dev_err(hdev, "Unable to add to resolve list: %d", err); + return err; + } + + /* Set Privacy Mode */ + err = hci_le_set_privacy_mode_sync(hdev, params); + if (err) { + bt_dev_err(hdev, "Unable to set privacy mode: %d", err); + return err; + } + + /* Check if already in accept list */ + if (hci_bdaddr_list_lookup(&hdev->le_accept_list, ¶ms->addr, + params->addr_type)) + return 0; + + *num_entries += 1; + cp.bdaddr_type = params->addr_type; + bacpy(&cp.bdaddr, ¶ms->addr); + + err = __hci_cmd_sync_status(hdev, HCI_OP_LE_ADD_TO_ACCEPT_LIST, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (err) { + bt_dev_err(hdev, "Unable to add to allow list: %d", err); + /* Rollback the device from the resolving list */ + hci_le_del_resolve_list_sync(hdev, &cp.bdaddr, cp.bdaddr_type); + return err; + } + + bt_dev_dbg(hdev, "Add %pMR (0x%x) to allow list", &cp.bdaddr, + cp.bdaddr_type); + + return 0; +} + +/* This function disables/pause all advertising instances */ +static int hci_pause_advertising_sync(struct hci_dev *hdev) +{ + int err; + int old_state; + + /* If already been paused there is nothing to do. */ + if (hdev->advertising_paused) + return 0; + + bt_dev_dbg(hdev, "Pausing directed advertising"); + + /* Stop directed advertising */ + old_state = hci_dev_test_flag(hdev, HCI_ADVERTISING); + if (old_state) { + /* When discoverable timeout triggers, then just make sure + * the limited discoverable flag is cleared. Even in the case + * of a timeout triggered from general discoverable, it is + * safe to unconditionally clear the flag. + */ + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + hdev->discov_timeout = 0; + } + + bt_dev_dbg(hdev, "Pausing advertising instances"); + + /* Call to disable any advertisements active on the controller. + * This will succeed even if no advertisements are configured. + */ + err = hci_disable_advertising_sync(hdev); + if (err) + return err; + + /* If we are using software rotation, pause the loop */ + if (!ext_adv_capable(hdev)) + cancel_adv_timeout(hdev); + + hdev->advertising_paused = true; + hdev->advertising_old_state = old_state; + + return 0; +} + +/* This function enables all user advertising instances */ +static int hci_resume_advertising_sync(struct hci_dev *hdev) +{ + struct adv_info *adv, *tmp; + int err; + + /* If advertising has not been paused there is nothing to do. 
*/ + if (!hdev->advertising_paused) + return 0; + + /* Resume directed advertising */ + hdev->advertising_paused = false; + if (hdev->advertising_old_state) { + hci_dev_set_flag(hdev, HCI_ADVERTISING); + hdev->advertising_old_state = 0; + } + + bt_dev_dbg(hdev, "Resuming advertising instances"); + + if (ext_adv_capable(hdev)) { + /* Call for each tracked instance to be re-enabled */ + list_for_each_entry_safe(adv, tmp, &hdev->adv_instances, list) { + err = hci_enable_ext_advertising_sync(hdev, + adv->instance); + if (!err) + continue; + + /* If the instance cannot be resumed remove it */ + hci_remove_ext_adv_instance_sync(hdev, adv->instance, + NULL); + } + } else { + /* Schedule for most recent instance to be restarted and begin + * the software rotation loop + */ + err = hci_schedule_adv_instance_sync(hdev, + hdev->cur_adv_instance, + true); + } + + hdev->advertising_paused = false; + + return err; +} + +static int hci_pause_addr_resolution(struct hci_dev *hdev) +{ + int err; + + if (!use_ll_privacy(hdev)) + return 0; + + if (!hci_dev_test_flag(hdev, HCI_LL_RPA_RESOLUTION)) + return 0; + + /* Cannot disable addr resolution if scanning is enabled or + * when initiating an LE connection. + */ + if (hci_dev_test_flag(hdev, HCI_LE_SCAN) || + hci_lookup_le_connect(hdev)) { + bt_dev_err(hdev, "Command not allowed when scan/LE connect"); + return -EPERM; + } + + /* Cannot disable addr resolution if advertising is enabled. */ + err = hci_pause_advertising_sync(hdev); + if (err) { + bt_dev_err(hdev, "Pause advertising failed: %d", err); + return err; + } + + err = hci_le_set_addr_resolution_enable_sync(hdev, 0x00); + if (err) + bt_dev_err(hdev, "Unable to disable Address Resolution: %d", + err); + + /* Return if address resolution is disabled and RPA is not used. */ + if (!err && scan_use_rpa(hdev)) + return err; + + hci_resume_advertising_sync(hdev); + return err; +} + +struct sk_buff *hci_read_local_oob_data_sync(struct hci_dev *hdev, + bool extended, struct sock *sk) +{ + u16 opcode = extended ? HCI_OP_READ_LOCAL_OOB_EXT_DATA : + HCI_OP_READ_LOCAL_OOB_DATA; + + return __hci_cmd_sync_sk(hdev, opcode, 0, NULL, 0, HCI_CMD_TIMEOUT, sk); +} + +static struct conn_params *conn_params_copy(struct list_head *list, size_t *n) +{ + struct hci_conn_params *params; + struct conn_params *p; + size_t i; + + rcu_read_lock(); + + i = 0; + list_for_each_entry_rcu(params, list, action) + ++i; + *n = i; + + rcu_read_unlock(); + + p = kvcalloc(*n, sizeof(struct conn_params), GFP_KERNEL); + if (!p) + return NULL; + + rcu_read_lock(); + + i = 0; + list_for_each_entry_rcu(params, list, action) { + /* Racing adds are handled in next scan update */ + if (i >= *n) + break; + + /* No hdev->lock, but: addr, addr_type are immutable. + * privacy_mode is only written by us or in + * hci_cc_le_set_privacy_mode that we wait for. + * We should be idempotent so MGMT updating flags + * while we are processing is OK. + */ + bacpy(&p[i].addr, ¶ms->addr); + p[i].addr_type = params->addr_type; + p[i].flags = READ_ONCE(params->flags); + p[i].privacy_mode = READ_ONCE(params->privacy_mode); + ++i; + } + + rcu_read_unlock(); + + *n = i; + return p; +} + +/* Device must not be scanning when updating the accept list. 
+ * + * Update is done using the following sequence: + * + * use_ll_privacy((Disable Advertising) -> Disable Resolving List) -> + * Remove Devices From Accept List -> + * (has IRK && use_ll_privacy(Remove Devices From Resolving List))-> + * Add Devices to Accept List -> + * (has IRK && use_ll_privacy(Remove Devices From Resolving List)) -> + * use_ll_privacy(Enable Resolving List -> (Enable Advertising)) -> + * Enable Scanning + * + * In case of failure advertising shall be restored to its original state and + * return would disable accept list since either accept or resolving list could + * not be programmed. + * + */ +static u8 hci_update_accept_list_sync(struct hci_dev *hdev) +{ + struct conn_params *params; + struct bdaddr_list *b, *t; + u8 num_entries = 0; + bool pend_conn, pend_report; + u8 filter_policy; + size_t i, n; + int err; + + /* Pause advertising if resolving list can be used as controllers + * cannot accept resolving list modifications while advertising. + */ + if (use_ll_privacy(hdev)) { + err = hci_pause_advertising_sync(hdev); + if (err) { + bt_dev_err(hdev, "pause advertising failed: %d", err); + return 0x00; + } + } + + /* Disable address resolution while reprogramming accept list since + * devices that do have an IRK will be programmed in the resolving list + * when LL Privacy is enabled. + */ + err = hci_le_set_addr_resolution_enable_sync(hdev, 0x00); + if (err) { + bt_dev_err(hdev, "Unable to disable LL privacy: %d", err); + goto done; + } + + /* Go through the current accept list programmed into the + * controller one by one and check if that address is connected or is + * still in the list of pending connections or list of devices to + * report. If not present in either list, then remove it from + * the controller. + */ + list_for_each_entry_safe(b, t, &hdev->le_accept_list, list) { + if (hci_conn_hash_lookup_le(hdev, &b->bdaddr, b->bdaddr_type)) + continue; + + /* Pointers not dereferenced, no locks needed */ + pend_conn = hci_pend_le_action_lookup(&hdev->pend_le_conns, + &b->bdaddr, + b->bdaddr_type); + pend_report = hci_pend_le_action_lookup(&hdev->pend_le_reports, + &b->bdaddr, + b->bdaddr_type); + + /* If the device is not likely to connect or report, + * remove it from the acceptlist. + */ + if (!pend_conn && !pend_report) { + hci_le_del_accept_list_sync(hdev, &b->bdaddr, + b->bdaddr_type); + continue; + } + + num_entries++; + } + + /* Since all no longer valid accept list entries have been + * removed, walk through the list of pending connections + * and ensure that any new device gets programmed into + * the controller. + * + * If the list of the devices is larger than the list of + * available accept list entries in the controller, then + * just abort and return filer policy value to not use the + * accept list. + * + * The list and params may be mutated while we wait for events, + * so make a copy and iterate it. + */ + + params = conn_params_copy(&hdev->pend_le_conns, &n); + if (!params) { + err = -ENOMEM; + goto done; + } + + for (i = 0; i < n; ++i) { + err = hci_le_add_accept_list_sync(hdev, ¶ms[i], + &num_entries); + if (err) { + kvfree(params); + goto done; + } + } + + kvfree(params); + + /* After adding all new pending connections, walk through + * the list of pending reports and also add these to the + * accept list if there is still space. Abort if space runs out. 
+ */ + + params = conn_params_copy(&hdev->pend_le_reports, &n); + if (!params) { + err = -ENOMEM; + goto done; + } + + for (i = 0; i < n; ++i) { + err = hci_le_add_accept_list_sync(hdev, ¶ms[i], + &num_entries); + if (err) { + kvfree(params); + goto done; + } + } + + kvfree(params); + + /* Use the allowlist unless the following conditions are all true: + * - We are not currently suspending + * - There are 1 or more ADV monitors registered and it's not offloaded + * - Interleaved scanning is not currently using the allowlist + */ + if (!idr_is_empty(&hdev->adv_monitors_idr) && !hdev->suspended && + hci_get_adv_monitor_offload_ext(hdev) == HCI_ADV_MONITOR_EXT_NONE && + hdev->interleave_scan_state != INTERLEAVE_SCAN_ALLOWLIST) + err = -EINVAL; + +done: + filter_policy = err ? 0x00 : 0x01; + + /* Enable address resolution when LL Privacy is enabled. */ + err = hci_le_set_addr_resolution_enable_sync(hdev, 0x01); + if (err) + bt_dev_err(hdev, "Unable to enable LL privacy: %d", err); + + /* Resume advertising if it was paused */ + if (use_ll_privacy(hdev)) + hci_resume_advertising_sync(hdev); + + /* Select filter policy to use accept list */ + return filter_policy; +} + +/* Returns true if an le connection is in the scanning state */ +static inline bool hci_is_le_conn_scanning(struct hci_dev *hdev) +{ + struct hci_conn_hash *h = &hdev->conn_hash; + struct hci_conn *c; + + rcu_read_lock(); + + list_for_each_entry_rcu(c, &h->list, list) { + if (c->type == LE_LINK && c->state == BT_CONNECT && + test_bit(HCI_CONN_SCANNING, &c->flags)) { + rcu_read_unlock(); + return true; + } + } + + rcu_read_unlock(); + + return false; +} + +static int hci_le_set_ext_scan_param_sync(struct hci_dev *hdev, u8 type, + u16 interval, u16 window, + u8 own_addr_type, u8 filter_policy) +{ + struct hci_cp_le_set_ext_scan_params *cp; + struct hci_cp_le_scan_phy_params *phy; + u8 data[sizeof(*cp) + sizeof(*phy) * 2]; + u8 num_phy = 0; + + cp = (void *)data; + phy = (void *)cp->data; + + memset(data, 0, sizeof(data)); + + cp->own_addr_type = own_addr_type; + cp->filter_policy = filter_policy; + + if (scan_1m(hdev) || scan_2m(hdev)) { + cp->scanning_phys |= LE_SCAN_PHY_1M; + + phy->type = type; + phy->interval = cpu_to_le16(interval); + phy->window = cpu_to_le16(window); + + num_phy++; + phy++; + } + + if (scan_coded(hdev)) { + cp->scanning_phys |= LE_SCAN_PHY_CODED; + + phy->type = type; + phy->interval = cpu_to_le16(interval); + phy->window = cpu_to_le16(window); + + num_phy++; + phy++; + } + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_SCAN_PARAMS, + sizeof(*cp) + sizeof(*phy) * num_phy, + data, HCI_CMD_TIMEOUT); +} + +static int hci_le_set_scan_param_sync(struct hci_dev *hdev, u8 type, + u16 interval, u16 window, + u8 own_addr_type, u8 filter_policy) +{ + struct hci_cp_le_set_scan_param cp; + + if (use_ext_scan(hdev)) + return hci_le_set_ext_scan_param_sync(hdev, type, interval, + window, own_addr_type, + filter_policy); + + memset(&cp, 0, sizeof(cp)); + cp.type = type; + cp.interval = cpu_to_le16(interval); + cp.window = cpu_to_le16(window); + cp.own_address_type = own_addr_type; + cp.filter_policy = filter_policy; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_SCAN_PARAM, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_start_scan_sync(struct hci_dev *hdev, u8 type, u16 interval, + u16 window, u8 own_addr_type, u8 filter_policy, + u8 filter_dup) +{ + int err; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return 0; + } + + err = 
hci_le_set_scan_param_sync(hdev, type, interval, window, + own_addr_type, filter_policy); + if (err) + return err; + + return hci_le_set_scan_enable_sync(hdev, LE_SCAN_ENABLE, filter_dup); +} + +static int hci_passive_scan_sync(struct hci_dev *hdev) +{ + u8 own_addr_type; + u8 filter_policy; + u16 window, interval; + u8 filter_dups = LE_SCAN_FILTER_DUP_ENABLE; + int err; + + if (hdev->scanning_paused) { + bt_dev_dbg(hdev, "Scanning is paused for suspend"); + return 0; + } + + err = hci_scan_disable_sync(hdev); + if (err) { + bt_dev_err(hdev, "disable scanning failed: %d", err); + return err; + } + + /* Set require_privacy to false since no SCAN_REQ are send + * during passive scanning. Not using an non-resolvable address + * here is important so that peer devices using direct + * advertising with our address will be correctly reported + * by the controller. + */ + if (hci_update_random_address_sync(hdev, false, scan_use_rpa(hdev), + &own_addr_type)) + return 0; + + if (hdev->enable_advmon_interleave_scan && + hci_update_interleaved_scan_sync(hdev)) + return 0; + + bt_dev_dbg(hdev, "interleave state %d", hdev->interleave_scan_state); + + /* Adding or removing entries from the accept list must + * happen before enabling scanning. The controller does + * not allow accept list modification while scanning. + */ + filter_policy = hci_update_accept_list_sync(hdev); + + /* When the controller is using random resolvable addresses and + * with that having LE privacy enabled, then controllers with + * Extended Scanner Filter Policies support can now enable support + * for handling directed advertising. + * + * So instead of using filter polices 0x00 (no acceptlist) + * and 0x01 (acceptlist enabled) use the new filter policies + * 0x02 (no acceptlist) and 0x03 (acceptlist enabled). + */ + if (hci_dev_test_flag(hdev, HCI_PRIVACY) && + (hdev->le_features[0] & HCI_LE_EXT_SCAN_POLICY)) + filter_policy |= 0x02; + + if (hdev->suspended) { + window = hdev->le_scan_window_suspend; + interval = hdev->le_scan_int_suspend; + } else if (hci_is_le_conn_scanning(hdev)) { + window = hdev->le_scan_window_connect; + interval = hdev->le_scan_int_connect; + } else if (hci_is_adv_monitoring(hdev)) { + window = hdev->le_scan_window_adv_monitor; + interval = hdev->le_scan_int_adv_monitor; + } else { + window = hdev->le_scan_window; + interval = hdev->le_scan_interval; + } + + /* Disable all filtering for Mesh */ + if (hci_dev_test_flag(hdev, HCI_MESH)) { + filter_policy = 0; + filter_dups = LE_SCAN_FILTER_DUP_DISABLE; + } + + bt_dev_dbg(hdev, "LE passive scan with acceptlist = %d", filter_policy); + + return hci_start_scan_sync(hdev, LE_SCAN_PASSIVE, interval, window, + own_addr_type, filter_policy, filter_dups); +} + +/* This function controls the passive scanning based on hdev->pend_le_conns + * list. 
If there are pending LE connection we start the background scanning, + * otherwise we stop it in the following sequence: + * + * If there are devices to scan: + * + * Disable Scanning -> Update Accept List -> + * use_ll_privacy((Disable Advertising) -> Disable Resolving List -> + * Update Resolving List -> Enable Resolving List -> (Enable Advertising)) -> + * Enable Scanning + * + * Otherwise: + * + * Disable Scanning + */ +int hci_update_passive_scan_sync(struct hci_dev *hdev) +{ + int err; + + if (!test_bit(HCI_UP, &hdev->flags) || + test_bit(HCI_INIT, &hdev->flags) || + hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG) || + hci_dev_test_flag(hdev, HCI_AUTO_OFF) || + hci_dev_test_flag(hdev, HCI_UNREGISTER)) + return 0; + + /* No point in doing scanning if LE support hasn't been enabled */ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return 0; + + /* If discovery is active don't interfere with it */ + if (hdev->discovery.state != DISCOVERY_STOPPED) + return 0; + + /* Reset RSSI and UUID filters when starting background scanning + * since these filters are meant for service discovery only. + * + * The Start Discovery and Start Service Discovery operations + * ensure to set proper values for RSSI threshold and UUID + * filter list. So it is safe to just reset them here. + */ + hci_discovery_filter_clear(hdev); + + bt_dev_dbg(hdev, "ADV monitoring is %s", + hci_is_adv_monitoring(hdev) ? "on" : "off"); + + if (!hci_dev_test_flag(hdev, HCI_MESH) && + list_empty(&hdev->pend_le_conns) && + list_empty(&hdev->pend_le_reports) && + !hci_is_adv_monitoring(hdev) && + !hci_dev_test_flag(hdev, HCI_PA_SYNC)) { + /* If there is no pending LE connections or devices + * to be scanned for or no ADV monitors, we should stop the + * background scanning. + */ + + bt_dev_dbg(hdev, "stopping background scanning"); + + err = hci_scan_disable_sync(hdev); + if (err) + bt_dev_err(hdev, "stop background scanning failed: %d", + err); + } else { + /* If there is at least one pending LE connection, we should + * keep the background scan running. + */ + + /* If controller is connecting, we should not start scanning + * since some controllers are not able to scan and connect at + * the same time. 
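+ * Passive scanning is expected to be re-evaluated once the pending connection attempt has completed.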
+ */ + if (hci_lookup_le_connect(hdev)) + return 0; + + bt_dev_dbg(hdev, "start background scanning"); + + err = hci_passive_scan_sync(hdev); + if (err) + bt_dev_err(hdev, "start background scanning failed: %d", + err); + } + + return err; +} + +static int update_scan_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_scan_sync(hdev); +} + +int hci_update_scan(struct hci_dev *hdev) +{ + return hci_cmd_sync_queue(hdev, update_scan_sync, NULL, NULL); +} + +static int update_passive_scan_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_passive_scan_sync(hdev); +} + +int hci_update_passive_scan(struct hci_dev *hdev) +{ + /* Only queue if it would have any effect */ + if (!test_bit(HCI_UP, &hdev->flags) || + test_bit(HCI_INIT, &hdev->flags) || + hci_dev_test_flag(hdev, HCI_SETUP) || + hci_dev_test_flag(hdev, HCI_CONFIG) || + hci_dev_test_flag(hdev, HCI_AUTO_OFF) || + hci_dev_test_flag(hdev, HCI_UNREGISTER)) + return 0; + + return hci_cmd_sync_queue(hdev, update_passive_scan_sync, NULL, NULL); +} + +int hci_write_sc_support_sync(struct hci_dev *hdev, u8 val) +{ + int err; + + if (!bredr_sc_enabled(hdev) || lmp_host_sc_capable(hdev)) + return 0; + + err = __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SC_SUPPORT, + sizeof(val), &val, HCI_CMD_TIMEOUT); + + if (!err) { + if (val) { + hdev->features[1][0] |= LMP_HOST_SC; + hci_dev_set_flag(hdev, HCI_SC_ENABLED); + } else { + hdev->features[1][0] &= ~LMP_HOST_SC; + hci_dev_clear_flag(hdev, HCI_SC_ENABLED); + } + } + + return err; +} + +int hci_write_ssp_mode_sync(struct hci_dev *hdev, u8 mode) +{ + int err; + + if (!hci_dev_test_flag(hdev, HCI_SSP_ENABLED) || + lmp_host_ssp_capable(hdev)) + return 0; + + if (!mode && hci_dev_test_flag(hdev, HCI_USE_DEBUG_KEYS)) { + __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SSP_DEBUG_MODE, + sizeof(mode), &mode, HCI_CMD_TIMEOUT); + } + + err = __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SSP_MODE, + sizeof(mode), &mode, HCI_CMD_TIMEOUT); + if (err) + return err; + + return hci_write_sc_support_sync(hdev, 0x01); +} + +int hci_write_le_host_supported_sync(struct hci_dev *hdev, u8 le, u8 simul) +{ + struct hci_cp_write_le_host_supported cp; + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED) || + !lmp_bredr_capable(hdev)) + return 0; + + /* Check first if we already have the right host state + * (host features set) + */ + if (le == lmp_host_le_capable(hdev) && + simul == lmp_host_le_br_capable(hdev)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + cp.le = le; + cp.simul = simul; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_LE_HOST_SUPPORTED, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_powered_update_adv_sync(struct hci_dev *hdev) +{ + struct adv_info *adv, *tmp; + int err; + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return 0; + + /* If RPA Resolution has not been enable yet it means the + * resolving list is empty and we should attempt to program the + * local IRK in order to support using own_addr_type + * ADDR_LE_DEV_RANDOM_RESOLVED (0x03). + */ + if (!hci_dev_test_flag(hdev, HCI_LL_RPA_RESOLUTION)) { + hci_le_add_resolve_list_sync(hdev, NULL); + hci_le_set_addr_resolution_enable_sync(hdev, 0x01); + } + + /* Make sure the controller has a good default for + * advertising data. This also applies to the case + * where BR/EDR was toggled during the AUTO_OFF phase. 
+ */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING) || + list_empty(&hdev->adv_instances)) { + if (ext_adv_capable(hdev)) { + err = hci_setup_ext_adv_instance_sync(hdev, 0x00); + if (!err) + hci_update_scan_rsp_data_sync(hdev, 0x00); + } else { + err = hci_update_adv_data_sync(hdev, 0x00); + if (!err) + hci_update_scan_rsp_data_sync(hdev, 0x00); + } + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) + hci_enable_advertising_sync(hdev); + } + + /* Call for each tracked instance to be scheduled */ + list_for_each_entry_safe(adv, tmp, &hdev->adv_instances, list) + hci_schedule_adv_instance_sync(hdev, adv->instance, true); + + return 0; +} + +static int hci_write_auth_enable_sync(struct hci_dev *hdev) +{ + u8 link_sec; + + link_sec = hci_dev_test_flag(hdev, HCI_LINK_SECURITY); + if (link_sec == test_bit(HCI_AUTH, &hdev->flags)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_AUTH_ENABLE, + sizeof(link_sec), &link_sec, + HCI_CMD_TIMEOUT); +} + +int hci_write_fast_connectable_sync(struct hci_dev *hdev, bool enable) +{ + struct hci_cp_write_page_scan_activity cp; + u8 type; + int err = 0; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return 0; + + if (hdev->hci_ver < BLUETOOTH_VER_1_2) + return 0; + + memset(&cp, 0, sizeof(cp)); + + if (enable) { + type = PAGE_SCAN_TYPE_INTERLACED; + + /* 160 msec page scan interval */ + cp.interval = cpu_to_le16(0x0100); + } else { + type = hdev->def_page_scan_type; + cp.interval = cpu_to_le16(hdev->def_page_scan_int); + } + + cp.window = cpu_to_le16(hdev->def_page_scan_window); + + if (__cpu_to_le16(hdev->page_scan_interval) != cp.interval || + __cpu_to_le16(hdev->page_scan_window) != cp.window) { + err = __hci_cmd_sync_status(hdev, + HCI_OP_WRITE_PAGE_SCAN_ACTIVITY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (err) + return err; + } + + if (hdev->page_scan_type != type) + err = __hci_cmd_sync_status(hdev, + HCI_OP_WRITE_PAGE_SCAN_TYPE, + sizeof(type), &type, + HCI_CMD_TIMEOUT); + + return err; +} + +static bool disconnected_accept_list_entries(struct hci_dev *hdev) +{ + struct bdaddr_list *b; + + list_for_each_entry(b, &hdev->accept_list, list) { + struct hci_conn *conn; + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &b->bdaddr); + if (!conn) + return true; + + if (conn->state != BT_CONNECTED && conn->state != BT_CONFIG) + return true; + } + + return false; +} + +static int hci_write_scan_enable_sync(struct hci_dev *hdev, u8 val) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SCAN_ENABLE, + sizeof(val), &val, + HCI_CMD_TIMEOUT); +} + +int hci_update_scan_sync(struct hci_dev *hdev) +{ + u8 scan; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return 0; + + if (!hdev_is_powered(hdev)) + return 0; + + if (mgmt_powering_down(hdev)) + return 0; + + if (hdev->scanning_paused) + return 0; + + if (hci_dev_test_flag(hdev, HCI_CONNECTABLE) || + disconnected_accept_list_entries(hdev)) + scan = SCAN_PAGE; + else + scan = SCAN_DISABLED; + + if (hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) + scan |= SCAN_INQUIRY; + + if (test_bit(HCI_PSCAN, &hdev->flags) == !!(scan & SCAN_PAGE) && + test_bit(HCI_ISCAN, &hdev->flags) == !!(scan & SCAN_INQUIRY)) + return 0; + + return hci_write_scan_enable_sync(hdev, scan); +} + +int hci_update_name_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_local_name cp; + + memset(&cp, 0, sizeof(cp)); + + memcpy(cp.name, hdev->dev_name, sizeof(cp.name)); + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_LOCAL_NAME, + sizeof(cp), &cp, + HCI_CMD_TIMEOUT); +} + +/* This function perform powered update HCI 
command sequence after the HCI init
+ * sequence, which ends up resetting all states; the sequence is as follows:
+ *
+ * HCI_SSP_ENABLED(Enable SSP)
+ * HCI_LE_ENABLED(Enable LE)
+ * HCI_LE_ENABLED(use_ll_privacy(Add local IRK to Resolving List) ->
+ * Update adv data)
+ * Enable Authentication
+ * lmp_bredr_capable(Set Fast Connectable -> Set Scan Type -> Set Class ->
+ * Set Name -> Set EIR)
+ */
+int hci_powered_update_sync(struct hci_dev *hdev)
+{
+	int err;
+
+	/* Register the available SMP channels (BR/EDR and LE) only when
+	 * successfully powering on the controller. This late
+	 * registration is required so that LE SMP can clearly decide if
+	 * the public address or static address is used.
+	 */
+	smp_register(hdev);
+
+	err = hci_write_ssp_mode_sync(hdev, 0x01);
+	if (err)
+		return err;
+
+	err = hci_write_le_host_supported_sync(hdev, 0x01, 0x00);
+	if (err)
+		return err;
+
+	err = hci_powered_update_adv_sync(hdev);
+	if (err)
+		return err;
+
+	err = hci_write_auth_enable_sync(hdev);
+	if (err)
+		return err;
+
+	if (lmp_bredr_capable(hdev)) {
+		if (hci_dev_test_flag(hdev, HCI_FAST_CONNECTABLE))
+			hci_write_fast_connectable_sync(hdev, true);
+		else
+			hci_write_fast_connectable_sync(hdev, false);
+		hci_update_scan_sync(hdev);
+		hci_update_class_sync(hdev);
+		hci_update_name_sync(hdev);
+		hci_update_eir_sync(hdev);
+	}
+
+	return 0;
+}
+
+/**
+ * hci_dev_get_bd_addr_from_property - Get the Bluetooth Device Address
+ *                                     (BD_ADDR) for an HCI device from
+ *                                     a firmware node property.
+ * @hdev: The HCI device
+ *
+ * Search the firmware node for 'local-bd-address'.
+ *
+ * All-zero BD addresses are rejected, because those could be properties
+ * that exist in the firmware tables, but were not updated by the firmware. For
+ * example, the DTS could define 'local-bd-address', with zero BD addresses.
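+ *
+ * A valid (non-zero) address found here is copied to hdev->public_addr so
+ * that hci_dev_setup_sync() can later program it through the driver's
+ * set_bdaddr callback.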
+ */ +static void hci_dev_get_bd_addr_from_property(struct hci_dev *hdev) +{ + struct fwnode_handle *fwnode = dev_fwnode(hdev->dev.parent); + bdaddr_t ba; + int ret; + + ret = fwnode_property_read_u8_array(fwnode, "local-bd-address", + (u8 *)&ba, sizeof(ba)); + if (ret < 0 || !bacmp(&ba, BDADDR_ANY)) + return; + + bacpy(&hdev->public_addr, &ba); +} + +struct hci_init_stage { + int (*func)(struct hci_dev *hdev); +}; + +/* Run init stage NULL terminated function table */ +static int hci_init_stage_sync(struct hci_dev *hdev, + const struct hci_init_stage *stage) +{ + size_t i; + + for (i = 0; stage[i].func; i++) { + int err; + + err = stage[i].func(hdev); + if (err) + return err; + } + + return 0; +} + +/* Read Local Version */ +static int hci_read_local_version_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_VERSION, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read BD Address */ +static int hci_read_bd_addr_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_BD_ADDR, + 0, NULL, HCI_CMD_TIMEOUT); +} + +#define HCI_INIT(_func) \ +{ \ + .func = _func, \ +} + +static const struct hci_init_stage hci_init0[] = { + /* HCI_OP_READ_LOCAL_VERSION */ + HCI_INIT(hci_read_local_version_sync), + /* HCI_OP_READ_BD_ADDR */ + HCI_INIT(hci_read_bd_addr_sync), + {} +}; + +int hci_reset_sync(struct hci_dev *hdev) +{ + int err; + + set_bit(HCI_RESET, &hdev->flags); + + err = __hci_cmd_sync_status(hdev, HCI_OP_RESET, 0, NULL, + HCI_CMD_TIMEOUT); + if (err) + return err; + + return 0; +} + +static int hci_init0_sync(struct hci_dev *hdev) +{ + int err; + + bt_dev_dbg(hdev, ""); + + /* Reset */ + if (!test_bit(HCI_QUIRK_RESET_ON_CLOSE, &hdev->quirks)) { + err = hci_reset_sync(hdev); + if (err) + return err; + } + + return hci_init_stage_sync(hdev, hci_init0); +} + +static int hci_unconf_init_sync(struct hci_dev *hdev) +{ + int err; + + if (test_bit(HCI_QUIRK_RAW_DEVICE, &hdev->quirks)) + return 0; + + err = hci_init0_sync(hdev); + if (err < 0) + return err; + + if (hci_dev_test_flag(hdev, HCI_SETUP)) + hci_debugfs_create_basic(hdev); + + return 0; +} + +/* Read Local Supported Features. */ +static int hci_read_local_features_sync(struct hci_dev *hdev) +{ + /* Not all AMP controllers support this command */ + if (hdev->dev_type == HCI_AMP && !(hdev->commands[14] & 0x20)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_FEATURES, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* BR Controller init stage 1 command sequence */ +static const struct hci_init_stage br_init1[] = { + /* HCI_OP_READ_LOCAL_FEATURES */ + HCI_INIT(hci_read_local_features_sync), + /* HCI_OP_READ_LOCAL_VERSION */ + HCI_INIT(hci_read_local_version_sync), + /* HCI_OP_READ_BD_ADDR */ + HCI_INIT(hci_read_bd_addr_sync), + {} +}; + +/* Read Local Commands */ +static int hci_read_local_cmds_sync(struct hci_dev *hdev) +{ + /* All Bluetooth 1.2 and later controllers should support the + * HCI command for reading the local supported commands. + * + * Unfortunately some controllers indicate Bluetooth 1.2 support, + * but do not have support for this command. If that is the case, + * the driver can quirk the behavior and skip reading the local + * supported commands. 
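+	 * (That is the purpose of the HCI_QUIRK_BROKEN_LOCAL_COMMANDS
+	 * check below.)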
+ */ + if (hdev->hci_ver > BLUETOOTH_VER_1_1 && + !test_bit(HCI_QUIRK_BROKEN_LOCAL_COMMANDS, &hdev->quirks)) + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_COMMANDS, + 0, NULL, HCI_CMD_TIMEOUT); + + return 0; +} + +/* Read Local AMP Info */ +static int hci_read_local_amp_info_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_AMP_INFO, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Data Blk size */ +static int hci_read_data_block_size_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_DATA_BLOCK_SIZE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Flow Control Mode */ +static int hci_read_flow_control_mode_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_FLOW_CONTROL_MODE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Location Data */ +static int hci_read_location_data_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCATION_DATA, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* AMP Controller init stage 1 command sequence */ +static const struct hci_init_stage amp_init1[] = { + /* HCI_OP_READ_LOCAL_VERSION */ + HCI_INIT(hci_read_local_version_sync), + /* HCI_OP_READ_LOCAL_COMMANDS */ + HCI_INIT(hci_read_local_cmds_sync), + /* HCI_OP_READ_LOCAL_AMP_INFO */ + HCI_INIT(hci_read_local_amp_info_sync), + /* HCI_OP_READ_DATA_BLOCK_SIZE */ + HCI_INIT(hci_read_data_block_size_sync), + /* HCI_OP_READ_FLOW_CONTROL_MODE */ + HCI_INIT(hci_read_flow_control_mode_sync), + /* HCI_OP_READ_LOCATION_DATA */ + HCI_INIT(hci_read_location_data_sync), + {} +}; + +static int hci_init1_sync(struct hci_dev *hdev) +{ + int err; + + bt_dev_dbg(hdev, ""); + + /* Reset */ + if (!test_bit(HCI_QUIRK_RESET_ON_CLOSE, &hdev->quirks)) { + err = hci_reset_sync(hdev); + if (err) + return err; + } + + switch (hdev->dev_type) { + case HCI_PRIMARY: + hdev->flow_ctl_mode = HCI_FLOW_CTL_MODE_PACKET_BASED; + return hci_init_stage_sync(hdev, br_init1); + case HCI_AMP: + hdev->flow_ctl_mode = HCI_FLOW_CTL_MODE_BLOCK_BASED; + return hci_init_stage_sync(hdev, amp_init1); + default: + bt_dev_err(hdev, "Unknown device type %d", hdev->dev_type); + break; + } + + return 0; +} + +/* AMP Controller init stage 2 command sequence */ +static const struct hci_init_stage amp_init2[] = { + /* HCI_OP_READ_LOCAL_FEATURES */ + HCI_INIT(hci_read_local_features_sync), + {} +}; + +/* Read Buffer Size (ACL mtu, max pkt, etc.) 
*/ +static int hci_read_buffer_size_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_BUFFER_SIZE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Class of Device */ +static int hci_read_dev_class_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_CLASS_OF_DEV, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Local Name */ +static int hci_read_local_name_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_NAME, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Voice Setting */ +static int hci_read_voice_setting_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_VOICE_SETTING, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Number of Supported IAC */ +static int hci_read_num_supported_iac_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_NUM_SUPPORTED_IAC, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read Current IAC LAP */ +static int hci_read_current_iac_lap_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_READ_CURRENT_IAC_LAP, + 0, NULL, HCI_CMD_TIMEOUT); +} + +static int hci_set_event_filter_sync(struct hci_dev *hdev, u8 flt_type, + u8 cond_type, bdaddr_t *bdaddr, + u8 auto_accept) +{ + struct hci_cp_set_event_filter cp; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return 0; + + if (test_bit(HCI_QUIRK_BROKEN_FILTER_CLEAR_ALL, &hdev->quirks)) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.flt_type = flt_type; + + if (flt_type != HCI_FLT_CLEAR_ALL) { + cp.cond_type = cond_type; + bacpy(&cp.addr_conn_flt.bdaddr, bdaddr); + cp.addr_conn_flt.auto_accept = auto_accept; + } + + return __hci_cmd_sync_status(hdev, HCI_OP_SET_EVENT_FLT, + flt_type == HCI_FLT_CLEAR_ALL ? + sizeof(cp.flt_type) : sizeof(cp), &cp, + HCI_CMD_TIMEOUT); +} + +static int hci_clear_event_filter_sync(struct hci_dev *hdev) +{ + if (!hci_dev_test_flag(hdev, HCI_EVENT_FILTER_CONFIGURED)) + return 0; + + /* In theory the state machine should not reach here unless + * a hci_set_event_filter_sync() call succeeds, but we do + * the check both for parity and as a future reminder. + */ + if (test_bit(HCI_QUIRK_BROKEN_FILTER_CLEAR_ALL, &hdev->quirks)) + return 0; + + return hci_set_event_filter_sync(hdev, HCI_FLT_CLEAR_ALL, 0x00, + BDADDR_ANY, 0x00); +} + +/* Connection accept timeout ~20 secs */ +static int hci_write_ca_timeout_sync(struct hci_dev *hdev) +{ + __le16 param = cpu_to_le16(0x7d00); + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_CA_TIMEOUT, + sizeof(param), ¶m, HCI_CMD_TIMEOUT); +} + +/* BR Controller init stage 2 command sequence */ +static const struct hci_init_stage br_init2[] = { + /* HCI_OP_READ_BUFFER_SIZE */ + HCI_INIT(hci_read_buffer_size_sync), + /* HCI_OP_READ_CLASS_OF_DEV */ + HCI_INIT(hci_read_dev_class_sync), + /* HCI_OP_READ_LOCAL_NAME */ + HCI_INIT(hci_read_local_name_sync), + /* HCI_OP_READ_VOICE_SETTING */ + HCI_INIT(hci_read_voice_setting_sync), + /* HCI_OP_READ_NUM_SUPPORTED_IAC */ + HCI_INIT(hci_read_num_supported_iac_sync), + /* HCI_OP_READ_CURRENT_IAC_LAP */ + HCI_INIT(hci_read_current_iac_lap_sync), + /* HCI_OP_SET_EVENT_FLT */ + HCI_INIT(hci_clear_event_filter_sync), + /* HCI_OP_WRITE_CA_TIMEOUT */ + HCI_INIT(hci_write_ca_timeout_sync), + {} +}; + +static int hci_write_ssp_mode_1_sync(struct hci_dev *hdev) +{ + u8 mode = 0x01; + + if (!lmp_ssp_capable(hdev) || !hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + return 0; + + /* When SSP is available, then the host features page + * should also be available as well. 
However some + * controllers list the max_page as 0 as long as SSP + * has not been enabled. To achieve proper debugging + * output, force the minimum max_page to 1 at least. + */ + hdev->max_page = 0x01; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SSP_MODE, + sizeof(mode), &mode, HCI_CMD_TIMEOUT); +} + +static int hci_write_eir_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_eir cp; + + if (!lmp_ssp_capable(hdev) || hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + return 0; + + memset(hdev->eir, 0, sizeof(hdev->eir)); + memset(&cp, 0, sizeof(cp)); + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_EIR, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); +} + +static int hci_write_inquiry_mode_sync(struct hci_dev *hdev) +{ + u8 mode; + + if (!lmp_inq_rssi_capable(hdev) && + !test_bit(HCI_QUIRK_FIXUP_INQUIRY_MODE, &hdev->quirks)) + return 0; + + /* If Extended Inquiry Result events are supported, then + * they are clearly preferred over Inquiry Result with RSSI + * events. + */ + mode = lmp_ext_inq_capable(hdev) ? 0x02 : 0x01; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_INQUIRY_MODE, + sizeof(mode), &mode, HCI_CMD_TIMEOUT); +} + +static int hci_read_inq_rsp_tx_power_sync(struct hci_dev *hdev) +{ + if (!lmp_inq_tx_pwr_capable(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_INQ_RSP_TX_POWER, + 0, NULL, HCI_CMD_TIMEOUT); +} + +static int hci_read_local_ext_features_sync(struct hci_dev *hdev, u8 page) +{ + struct hci_cp_read_local_ext_features cp; + + if (!lmp_ext_feat_capable(hdev)) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.page = page; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_EXT_FEATURES, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_read_local_ext_features_1_sync(struct hci_dev *hdev) +{ + return hci_read_local_ext_features_sync(hdev, 0x01); +} + +/* HCI Controller init stage 2 command sequence */ +static const struct hci_init_stage hci_init2[] = { + /* HCI_OP_READ_LOCAL_COMMANDS */ + HCI_INIT(hci_read_local_cmds_sync), + /* HCI_OP_WRITE_SSP_MODE */ + HCI_INIT(hci_write_ssp_mode_1_sync), + /* HCI_OP_WRITE_EIR */ + HCI_INIT(hci_write_eir_sync), + /* HCI_OP_WRITE_INQUIRY_MODE */ + HCI_INIT(hci_write_inquiry_mode_sync), + /* HCI_OP_READ_INQ_RSP_TX_POWER */ + HCI_INIT(hci_read_inq_rsp_tx_power_sync), + /* HCI_OP_READ_LOCAL_EXT_FEATURES */ + HCI_INIT(hci_read_local_ext_features_1_sync), + /* HCI_OP_WRITE_AUTH_ENABLE */ + HCI_INIT(hci_write_auth_enable_sync), + {} +}; + +/* Read LE Buffer Size */ +static int hci_le_read_buffer_size_sync(struct hci_dev *hdev) +{ + /* Use Read LE Buffer Size V2 if supported */ + if (iso_capable(hdev) && hdev->commands[41] & 0x20) + return __hci_cmd_sync_status(hdev, + HCI_OP_LE_READ_BUFFER_SIZE_V2, + 0, NULL, HCI_CMD_TIMEOUT); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_BUFFER_SIZE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read LE Local Supported Features */ +static int hci_le_read_local_features_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_LOCAL_FEATURES, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read LE Supported States */ +static int hci_le_read_supported_states_sync(struct hci_dev *hdev) +{ + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_SUPPORTED_STATES, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* LE Controller init stage 2 command sequence */ +static const struct hci_init_stage le_init2[] = { + /* HCI_OP_LE_READ_LOCAL_FEATURES */ + HCI_INIT(hci_le_read_local_features_sync), + /* HCI_OP_LE_READ_BUFFER_SIZE */ + HCI_INIT(hci_le_read_buffer_size_sync), + /* 
HCI_OP_LE_READ_SUPPORTED_STATES */
+	HCI_INIT(hci_le_read_supported_states_sync),
+	{}
+};
+
+static int hci_init2_sync(struct hci_dev *hdev)
+{
+	int err;
+
+	bt_dev_dbg(hdev, "");
+
+	if (hdev->dev_type == HCI_AMP)
+		return hci_init_stage_sync(hdev, amp_init2);
+
+	err = hci_init_stage_sync(hdev, hci_init2);
+	if (err)
+		return err;
+
+	if (lmp_bredr_capable(hdev)) {
+		err = hci_init_stage_sync(hdev, br_init2);
+		if (err)
+			return err;
+	} else {
+		hci_dev_clear_flag(hdev, HCI_BREDR_ENABLED);
+	}
+
+	if (lmp_le_capable(hdev)) {
+		err = hci_init_stage_sync(hdev, le_init2);
+		if (err)
+			return err;
+		/* LE-only controllers have LE implicitly enabled */
+		if (!lmp_bredr_capable(hdev))
+			hci_dev_set_flag(hdev, HCI_LE_ENABLED);
+	}
+
+	return 0;
+}
+
+static int hci_set_event_mask_sync(struct hci_dev *hdev)
+{
+	/* The second byte is 0xff instead of 0x9f (two reserved bits
+	 * disabled) since a Broadcom 1.2 dongle doesn't respond to the
+	 * command otherwise.
+	 */
+	u8 events[8] = { 0xff, 0xff, 0xfb, 0xff, 0x00, 0x00, 0x00, 0x00 };
+
+	/* CSR 1.1 dongles do not accept any bitfield, so don't try to set
+	 * any event mask for pre 1.2 devices.
+	 */
+	if (hdev->hci_ver < BLUETOOTH_VER_1_2)
+		return 0;
+
+	if (lmp_bredr_capable(hdev)) {
+		events[4] |= 0x01; /* Flow Specification Complete */
+
+		/* Don't set Disconnect Complete when suspended as that
+		 * would wake up the host when disconnecting due to
+		 * suspend.
+		 */
+		if (hdev->suspended)
+			events[0] &= 0xef;
+	} else {
+		/* Use a different default for LE-only devices */
+		memset(events, 0, sizeof(events));
+		events[1] |= 0x20; /* Command Complete */
+		events[1] |= 0x40; /* Command Status */
+		events[1] |= 0x80; /* Hardware Error */
+
+		/* If the controller supports the Disconnect command, enable
+		 * the corresponding event. In addition enable packet flow
+		 * control related events.
+		 */
+		if (hdev->commands[0] & 0x20) {
+			/* Don't set Disconnect Complete when suspended as that
+			 * would wake up the host when disconnecting due to
+			 * suspend.
+			 */
+			if (!hdev->suspended)
+				events[0] |= 0x10; /* Disconnection Complete */
+			events[2] |= 0x04; /* Number of Completed Packets */
+			events[3] |= 0x02; /* Data Buffer Overflow */
+		}
+
+		/* If the controller supports the Read Remote Version
+		 * Information command, enable the corresponding event.
+ */ + if (hdev->commands[2] & 0x80) + events[1] |= 0x08; /* Read Remote Version Information + * Complete + */ + + if (hdev->le_features[0] & HCI_LE_ENCRYPTION) { + events[0] |= 0x80; /* Encryption Change */ + events[5] |= 0x80; /* Encryption Key Refresh Complete */ + } + } + + if (lmp_inq_rssi_capable(hdev) || + test_bit(HCI_QUIRK_FIXUP_INQUIRY_MODE, &hdev->quirks)) + events[4] |= 0x02; /* Inquiry Result with RSSI */ + + if (lmp_ext_feat_capable(hdev)) + events[4] |= 0x04; /* Read Remote Extended Features Complete */ + + if (lmp_esco_capable(hdev)) { + events[5] |= 0x08; /* Synchronous Connection Complete */ + events[5] |= 0x10; /* Synchronous Connection Changed */ + } + + if (lmp_sniffsubr_capable(hdev)) + events[5] |= 0x20; /* Sniff Subrating */ + + if (lmp_pause_enc_capable(hdev)) + events[5] |= 0x80; /* Encryption Key Refresh Complete */ + + if (lmp_ext_inq_capable(hdev)) + events[5] |= 0x40; /* Extended Inquiry Result */ + + if (lmp_no_flush_capable(hdev)) + events[7] |= 0x01; /* Enhanced Flush Complete */ + + if (lmp_lsto_capable(hdev)) + events[6] |= 0x80; /* Link Supervision Timeout Changed */ + + if (lmp_ssp_capable(hdev)) { + events[6] |= 0x01; /* IO Capability Request */ + events[6] |= 0x02; /* IO Capability Response */ + events[6] |= 0x04; /* User Confirmation Request */ + events[6] |= 0x08; /* User Passkey Request */ + events[6] |= 0x10; /* Remote OOB Data Request */ + events[6] |= 0x20; /* Simple Pairing Complete */ + events[7] |= 0x04; /* User Passkey Notification */ + events[7] |= 0x08; /* Keypress Notification */ + events[7] |= 0x10; /* Remote Host Supported + * Features Notification + */ + } + + if (lmp_le_capable(hdev)) + events[7] |= 0x20; /* LE Meta-Event */ + + return __hci_cmd_sync_status(hdev, HCI_OP_SET_EVENT_MASK, + sizeof(events), events, HCI_CMD_TIMEOUT); +} + +static int hci_read_stored_link_key_sync(struct hci_dev *hdev) +{ + struct hci_cp_read_stored_link_key cp; + + if (!(hdev->commands[6] & 0x20) || + test_bit(HCI_QUIRK_BROKEN_STORED_LINK_KEY, &hdev->quirks)) + return 0; + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, BDADDR_ANY); + cp.read_all = 0x01; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_STORED_LINK_KEY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_setup_link_policy_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_def_link_policy cp; + u16 link_policy = 0; + + if (!(hdev->commands[5] & 0x10)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + if (lmp_rswitch_capable(hdev)) + link_policy |= HCI_LP_RSWITCH; + if (lmp_hold_capable(hdev)) + link_policy |= HCI_LP_HOLD; + if (lmp_sniff_capable(hdev)) + link_policy |= HCI_LP_SNIFF; + if (lmp_park_capable(hdev)) + link_policy |= HCI_LP_PARK; + + cp.policy = cpu_to_le16(link_policy); + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_DEF_LINK_POLICY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_read_page_scan_activity_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[8] & 0x01)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_PAGE_SCAN_ACTIVITY, + 0, NULL, HCI_CMD_TIMEOUT); +} + +static int hci_read_def_err_data_reporting_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[18] & 0x04) || + !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING) || + test_bit(HCI_QUIRK_BROKEN_ERR_DATA_REPORTING, &hdev->quirks)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_DEF_ERR_DATA_REPORTING, + 0, NULL, HCI_CMD_TIMEOUT); +} + +static int hci_read_page_scan_type_sync(struct hci_dev *hdev) +{ + /* Some older Broadcom based Bluetooth 1.2 controllers do not + 
* support the Read Page Scan Type command. Check support for + * this command in the bit mask of supported commands. + */ + if (!(hdev->commands[13] & 0x01)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_PAGE_SCAN_TYPE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read features beyond page 1 if available */ +static int hci_read_local_ext_features_all_sync(struct hci_dev *hdev) +{ + u8 page; + int err; + + if (!lmp_ext_feat_capable(hdev)) + return 0; + + for (page = 2; page < HCI_MAX_PAGES && page <= hdev->max_page; + page++) { + err = hci_read_local_ext_features_sync(hdev, page); + if (err) + return err; + } + + return 0; +} + +/* HCI Controller init stage 3 command sequence */ +static const struct hci_init_stage hci_init3[] = { + /* HCI_OP_SET_EVENT_MASK */ + HCI_INIT(hci_set_event_mask_sync), + /* HCI_OP_READ_STORED_LINK_KEY */ + HCI_INIT(hci_read_stored_link_key_sync), + /* HCI_OP_WRITE_DEF_LINK_POLICY */ + HCI_INIT(hci_setup_link_policy_sync), + /* HCI_OP_READ_PAGE_SCAN_ACTIVITY */ + HCI_INIT(hci_read_page_scan_activity_sync), + /* HCI_OP_READ_DEF_ERR_DATA_REPORTING */ + HCI_INIT(hci_read_def_err_data_reporting_sync), + /* HCI_OP_READ_PAGE_SCAN_TYPE */ + HCI_INIT(hci_read_page_scan_type_sync), + /* HCI_OP_READ_LOCAL_EXT_FEATURES */ + HCI_INIT(hci_read_local_ext_features_all_sync), + {} +}; + +static int hci_le_set_event_mask_sync(struct hci_dev *hdev) +{ + u8 events[8]; + + if (!lmp_le_capable(hdev)) + return 0; + + memset(events, 0, sizeof(events)); + + if (hdev->le_features[0] & HCI_LE_ENCRYPTION) + events[0] |= 0x10; /* LE Long Term Key Request */ + + /* If controller supports the Connection Parameters Request + * Link Layer Procedure, enable the corresponding event. + */ + if (hdev->le_features[0] & HCI_LE_CONN_PARAM_REQ_PROC) + /* LE Remote Connection Parameter Request */ + events[0] |= 0x20; + + /* If the controller supports the Data Length Extension + * feature, enable the corresponding event. + */ + if (hdev->le_features[0] & HCI_LE_DATA_LEN_EXT) + events[0] |= 0x40; /* LE Data Length Change */ + + /* If the controller supports LL Privacy feature or LE Extended Adv, + * enable the corresponding event. + */ + if (use_enhanced_conn_complete(hdev)) + events[1] |= 0x02; /* LE Enhanced Connection Complete */ + + /* If the controller supports Extended Scanner Filter + * Policies, enable the corresponding event. + */ + if (hdev->le_features[0] & HCI_LE_EXT_SCAN_POLICY) + events[1] |= 0x04; /* LE Direct Advertising Report */ + + /* If the controller supports Channel Selection Algorithm #2 + * feature, enable the corresponding event. + */ + if (hdev->le_features[1] & HCI_LE_CHAN_SEL_ALG2) + events[2] |= 0x08; /* LE Channel Selection Algorithm */ + + /* If the controller supports the LE Set Scan Enable command, + * enable the corresponding advertising report event. + */ + if (hdev->commands[26] & 0x08) + events[0] |= 0x02; /* LE Advertising Report */ + + /* If the controller supports the LE Create Connection + * command, enable the corresponding event. + */ + if (hdev->commands[26] & 0x10) + events[0] |= 0x01; /* LE Connection Complete */ + + /* If the controller supports the LE Connection Update + * command, enable the corresponding event. + */ + if (hdev->commands[27] & 0x04) + events[0] |= 0x04; /* LE Connection Update Complete */ + + /* If the controller supports the LE Read Remote Used Features + * command, enable the corresponding event. 
+ */ + if (hdev->commands[27] & 0x20) + /* LE Read Remote Used Features Complete */ + events[0] |= 0x08; + + /* If the controller supports the LE Read Local P-256 + * Public Key command, enable the corresponding event. + */ + if (hdev->commands[34] & 0x02) + /* LE Read Local P-256 Public Key Complete */ + events[0] |= 0x80; + + /* If the controller supports the LE Generate DHKey + * command, enable the corresponding event. + */ + if (hdev->commands[34] & 0x04) + events[1] |= 0x01; /* LE Generate DHKey Complete */ + + /* If the controller supports the LE Set Default PHY or + * LE Set PHY commands, enable the corresponding event. + */ + if (hdev->commands[35] & (0x20 | 0x40)) + events[1] |= 0x08; /* LE PHY Update Complete */ + + /* If the controller supports LE Set Extended Scan Parameters + * and LE Set Extended Scan Enable commands, enable the + * corresponding event. + */ + if (use_ext_scan(hdev)) + events[1] |= 0x10; /* LE Extended Advertising Report */ + + /* If the controller supports the LE Extended Advertising + * command, enable the corresponding event. + */ + if (ext_adv_capable(hdev)) + events[2] |= 0x02; /* LE Advertising Set Terminated */ + + if (cis_capable(hdev)) { + events[3] |= 0x01; /* LE CIS Established */ + if (cis_peripheral_capable(hdev)) + events[3] |= 0x02; /* LE CIS Request */ + } + + if (bis_capable(hdev)) { + events[3] |= 0x04; /* LE Create BIG Complete */ + events[3] |= 0x08; /* LE Terminate BIG Complete */ + events[3] |= 0x10; /* LE BIG Sync Established */ + events[3] |= 0x20; /* LE BIG Sync Loss */ + } + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EVENT_MASK, + sizeof(events), events, HCI_CMD_TIMEOUT); +} + +/* Read LE Advertising Channel TX Power */ +static int hci_le_read_adv_tx_power_sync(struct hci_dev *hdev) +{ + if ((hdev->commands[25] & 0x40) && !ext_adv_capable(hdev)) { + /* HCI TS spec forbids mixing of legacy and extended + * advertising commands wherein READ_ADV_TX_POWER is + * also included. So do not call it if extended adv + * is supported otherwise controller will return + * COMMAND_DISALLOWED for extended commands. 
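+		 * (With extended advertising, the selected TX power is
+		 * reported in the command complete of LE Set Extended
+		 * Advertising Parameters instead.)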
+ */ + return __hci_cmd_sync_status(hdev, + HCI_OP_LE_READ_ADV_TX_POWER, + 0, NULL, HCI_CMD_TIMEOUT); + } + + return 0; +} + +/* Read LE Min/Max Tx Power*/ +static int hci_le_read_tx_power_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[38] & 0x80) || + test_bit(HCI_QUIRK_BROKEN_READ_TRANSMIT_POWER, &hdev->quirks)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_TRANSMIT_POWER, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Read LE Accept List Size */ +static int hci_le_read_accept_list_size_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[26] & 0x40)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_ACCEPT_LIST_SIZE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Clear LE Accept List */ +static int hci_le_clear_accept_list_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[26] & 0x80)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_CLEAR_ACCEPT_LIST, 0, NULL, + HCI_CMD_TIMEOUT); +} + +/* Read LE Resolving List Size */ +static int hci_le_read_resolv_list_size_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[34] & 0x40)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_RESOLV_LIST_SIZE, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Clear LE Resolving List */ +static int hci_le_clear_resolv_list_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[34] & 0x20)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_CLEAR_RESOLV_LIST, 0, NULL, + HCI_CMD_TIMEOUT); +} + +/* Set RPA timeout */ +static int hci_le_set_rpa_timeout_sync(struct hci_dev *hdev) +{ + __le16 timeout = cpu_to_le16(hdev->rpa_timeout); + + if (!(hdev->commands[35] & 0x04) || + test_bit(HCI_QUIRK_BROKEN_SET_RPA_TIMEOUT, &hdev->quirks)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_RPA_TIMEOUT, + sizeof(timeout), &timeout, + HCI_CMD_TIMEOUT); +} + +/* Read LE Maximum Data Length */ +static int hci_le_read_max_data_len_sync(struct hci_dev *hdev) +{ + if (!(hdev->le_features[0] & HCI_LE_DATA_LEN_EXT)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_MAX_DATA_LEN, 0, NULL, + HCI_CMD_TIMEOUT); +} + +/* Read LE Suggested Default Data Length */ +static int hci_le_read_def_data_len_sync(struct hci_dev *hdev) +{ + if (!(hdev->le_features[0] & HCI_LE_DATA_LEN_EXT)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_READ_DEF_DATA_LEN, 0, NULL, + HCI_CMD_TIMEOUT); +} + +/* Read LE Number of Supported Advertising Sets */ +static int hci_le_read_num_support_adv_sets_sync(struct hci_dev *hdev) +{ + if (!ext_adv_capable(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, + HCI_OP_LE_READ_NUM_SUPPORTED_ADV_SETS, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Write LE Host Supported */ +static int hci_set_le_support_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_le_host_supported cp; + + /* LE-only devices do not support explicit enablement */ + if (!lmp_bredr_capable(hdev)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + if (hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + cp.le = 0x01; + cp.simul = 0x00; + } + + if (cp.le == lmp_host_le_capable(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_LE_HOST_SUPPORTED, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* LE Set Host Feature */ +static int hci_le_set_host_feature_sync(struct hci_dev *hdev) +{ + struct hci_cp_le_set_host_feature cp; + + if (!iso_capable(hdev)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + /* Isochronous Channels (Host Support) */ + cp.bit_number = 32; + cp.bit_value = 1; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_HOST_FEATURE, + 
sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* LE Controller init stage 3 command sequence */ +static const struct hci_init_stage le_init3[] = { + /* HCI_OP_LE_SET_EVENT_MASK */ + HCI_INIT(hci_le_set_event_mask_sync), + /* HCI_OP_LE_READ_ADV_TX_POWER */ + HCI_INIT(hci_le_read_adv_tx_power_sync), + /* HCI_OP_LE_READ_TRANSMIT_POWER */ + HCI_INIT(hci_le_read_tx_power_sync), + /* HCI_OP_LE_READ_ACCEPT_LIST_SIZE */ + HCI_INIT(hci_le_read_accept_list_size_sync), + /* HCI_OP_LE_CLEAR_ACCEPT_LIST */ + HCI_INIT(hci_le_clear_accept_list_sync), + /* HCI_OP_LE_READ_RESOLV_LIST_SIZE */ + HCI_INIT(hci_le_read_resolv_list_size_sync), + /* HCI_OP_LE_CLEAR_RESOLV_LIST */ + HCI_INIT(hci_le_clear_resolv_list_sync), + /* HCI_OP_LE_SET_RPA_TIMEOUT */ + HCI_INIT(hci_le_set_rpa_timeout_sync), + /* HCI_OP_LE_READ_MAX_DATA_LEN */ + HCI_INIT(hci_le_read_max_data_len_sync), + /* HCI_OP_LE_READ_DEF_DATA_LEN */ + HCI_INIT(hci_le_read_def_data_len_sync), + /* HCI_OP_LE_READ_NUM_SUPPORTED_ADV_SETS */ + HCI_INIT(hci_le_read_num_support_adv_sets_sync), + /* HCI_OP_WRITE_LE_HOST_SUPPORTED */ + HCI_INIT(hci_set_le_support_sync), + /* HCI_OP_LE_SET_HOST_FEATURE */ + HCI_INIT(hci_le_set_host_feature_sync), + {} +}; + +static int hci_init3_sync(struct hci_dev *hdev) +{ + int err; + + bt_dev_dbg(hdev, ""); + + err = hci_init_stage_sync(hdev, hci_init3); + if (err) + return err; + + if (lmp_le_capable(hdev)) + return hci_init_stage_sync(hdev, le_init3); + + return 0; +} + +static int hci_delete_stored_link_key_sync(struct hci_dev *hdev) +{ + struct hci_cp_delete_stored_link_key cp; + + /* Some Broadcom based Bluetooth controllers do not support the + * Delete Stored Link Key command. They are clearly indicating its + * absence in the bit mask of supported commands. + * + * Check the supported commands and only if the command is marked + * as supported send it. If not supported assume that the controller + * does not have actual support for stored link keys which makes this + * command redundant anyway. + * + * Some controllers indicate that they support handling deleting + * stored link keys, but they don't. The quirk lets a driver + * just disable this command. + */ + if (!(hdev->commands[6] & 0x80) || + test_bit(HCI_QUIRK_BROKEN_STORED_LINK_KEY, &hdev->quirks)) + return 0; + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, BDADDR_ANY); + cp.delete_all = 0x01; + + return __hci_cmd_sync_status(hdev, HCI_OP_DELETE_STORED_LINK_KEY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_set_event_mask_page_2_sync(struct hci_dev *hdev) +{ + u8 events[8] = { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + bool changed = false; + + /* Set event mask page 2 if the HCI command for it is supported */ + if (!(hdev->commands[22] & 0x04)) + return 0; + + /* If Connectionless Peripheral Broadcast central role is supported + * enable all necessary events for it. + */ + if (lmp_cpb_central_capable(hdev)) { + events[1] |= 0x40; /* Triggered Clock Capture */ + events[1] |= 0x80; /* Synchronization Train Complete */ + events[2] |= 0x08; /* Truncated Page Complete */ + events[2] |= 0x20; /* CPB Channel Map Change */ + changed = true; + } + + /* If Connectionless Peripheral Broadcast peripheral role is supported + * enable all necessary events for it. 
+ */ + if (lmp_cpb_peripheral_capable(hdev)) { + events[2] |= 0x01; /* Synchronization Train Received */ + events[2] |= 0x02; /* CPB Receive */ + events[2] |= 0x04; /* CPB Timeout */ + events[2] |= 0x10; /* Peripheral Page Response Timeout */ + changed = true; + } + + /* Enable Authenticated Payload Timeout Expired event if supported */ + if (lmp_ping_capable(hdev) || hdev->le_features[0] & HCI_LE_PING) { + events[2] |= 0x80; + changed = true; + } + + /* Some Broadcom based controllers indicate support for Set Event + * Mask Page 2 command, but then actually do not support it. Since + * the default value is all bits set to zero, the command is only + * required if the event mask has to be changed. In case no change + * to the event mask is needed, skip this command. + */ + if (!changed) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_SET_EVENT_MASK_PAGE_2, + sizeof(events), events, HCI_CMD_TIMEOUT); +} + +/* Read local codec list if the HCI command is supported */ +static int hci_read_local_codecs_sync(struct hci_dev *hdev) +{ + if (hdev->commands[45] & 0x04) + hci_read_supported_codecs_v2(hdev); + else if (hdev->commands[29] & 0x20) + hci_read_supported_codecs(hdev); + + return 0; +} + +/* Read local pairing options if the HCI command is supported */ +static int hci_read_local_pairing_opts_sync(struct hci_dev *hdev) +{ + if (!(hdev->commands[41] & 0x08)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_LOCAL_PAIRING_OPTS, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Get MWS transport configuration if the HCI command is supported */ +static int hci_get_mws_transport_config_sync(struct hci_dev *hdev) +{ + if (!mws_transport_config_capable(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_GET_MWS_TRANSPORT_CONFIG, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Check for Synchronization Train support */ +static int hci_read_sync_train_params_sync(struct hci_dev *hdev) +{ + if (!lmp_sync_train_capable(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_READ_SYNC_TRAIN_PARAMS, + 0, NULL, HCI_CMD_TIMEOUT); +} + +/* Enable Secure Connections if supported and configured */ +static int hci_write_sc_support_1_sync(struct hci_dev *hdev) +{ + u8 support = 0x01; + + if (!hci_dev_test_flag(hdev, HCI_SSP_ENABLED) || + !bredr_sc_enabled(hdev)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_SC_SUPPORT, + sizeof(support), &support, + HCI_CMD_TIMEOUT); +} + +/* Set erroneous data reporting if supported to the wideband speech + * setting value + */ +static int hci_set_err_data_report_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_def_err_data_reporting cp; + bool enabled = hci_dev_test_flag(hdev, HCI_WIDEBAND_SPEECH_ENABLED); + + if (!(hdev->commands[18] & 0x08) || + !(hdev->features[0][6] & LMP_ERR_DATA_REPORTING) || + test_bit(HCI_QUIRK_BROKEN_ERR_DATA_REPORTING, &hdev->quirks)) + return 0; + + if (enabled == hdev->err_data_reporting) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.err_data_reporting = enabled ? 
ERR_DATA_REPORTING_ENABLED : + ERR_DATA_REPORTING_DISABLED; + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_DEF_ERR_DATA_REPORTING, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static const struct hci_init_stage hci_init4[] = { + /* HCI_OP_DELETE_STORED_LINK_KEY */ + HCI_INIT(hci_delete_stored_link_key_sync), + /* HCI_OP_SET_EVENT_MASK_PAGE_2 */ + HCI_INIT(hci_set_event_mask_page_2_sync), + /* HCI_OP_READ_LOCAL_CODECS */ + HCI_INIT(hci_read_local_codecs_sync), + /* HCI_OP_READ_LOCAL_PAIRING_OPTS */ + HCI_INIT(hci_read_local_pairing_opts_sync), + /* HCI_OP_GET_MWS_TRANSPORT_CONFIG */ + HCI_INIT(hci_get_mws_transport_config_sync), + /* HCI_OP_READ_SYNC_TRAIN_PARAMS */ + HCI_INIT(hci_read_sync_train_params_sync), + /* HCI_OP_WRITE_SC_SUPPORT */ + HCI_INIT(hci_write_sc_support_1_sync), + /* HCI_OP_WRITE_DEF_ERR_DATA_REPORTING */ + HCI_INIT(hci_set_err_data_report_sync), + {} +}; + +/* Set Suggested Default Data Length to maximum if supported */ +static int hci_le_set_write_def_data_len_sync(struct hci_dev *hdev) +{ + struct hci_cp_le_write_def_data_len cp; + + if (!(hdev->le_features[0] & HCI_LE_DATA_LEN_EXT)) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.tx_len = cpu_to_le16(hdev->le_max_tx_len); + cp.tx_time = cpu_to_le16(hdev->le_max_tx_time); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_WRITE_DEF_DATA_LEN, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +/* Set Default PHY parameters if command is supported */ +static int hci_le_set_default_phy_sync(struct hci_dev *hdev) +{ + struct hci_cp_le_set_default_phy cp; + + if (!(hdev->commands[35] & 0x20)) + return 0; + + memset(&cp, 0, sizeof(cp)); + cp.all_phys = 0x00; + cp.tx_phys = hdev->le_tx_def_phys; + cp.rx_phys = hdev->le_rx_def_phys; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_DEFAULT_PHY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static const struct hci_init_stage le_init4[] = { + /* HCI_OP_LE_WRITE_DEF_DATA_LEN */ + HCI_INIT(hci_le_set_write_def_data_len_sync), + /* HCI_OP_LE_SET_DEFAULT_PHY */ + HCI_INIT(hci_le_set_default_phy_sync), + {} +}; + +static int hci_init4_sync(struct hci_dev *hdev) +{ + int err; + + bt_dev_dbg(hdev, ""); + + err = hci_init_stage_sync(hdev, hci_init4); + if (err) + return err; + + if (lmp_le_capable(hdev)) + return hci_init_stage_sync(hdev, le_init4); + + return 0; +} + +static int hci_init_sync(struct hci_dev *hdev) +{ + int err; + + err = hci_init1_sync(hdev); + if (err < 0) + return err; + + if (hci_dev_test_flag(hdev, HCI_SETUP)) + hci_debugfs_create_basic(hdev); + + err = hci_init2_sync(hdev); + if (err < 0) + return err; + + /* HCI_PRIMARY covers both single-mode LE, BR/EDR and dual-mode + * BR/EDR/LE type controllers. AMP controllers only need the + * first two stages of init. + */ + if (hdev->dev_type != HCI_PRIMARY) + return 0; + + err = hci_init3_sync(hdev); + if (err < 0) + return err; + + err = hci_init4_sync(hdev); + if (err < 0) + return err; + + /* This function is only called when the controller is actually in + * configured state. When the controller is marked as unconfigured, + * this initialization procedure is not run. + * + * It means that it is possible that a controller runs through its + * setup phase and then discovers missing settings. If that is the + * case, then this function will not be called. It then will only + * be called during the config phase. + * + * So only when in setup phase or config phase, create the debugfs + * entries and register the SMP channels. 
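+	 * (The SMP channels themselves are registered from
+	 * hci_powered_update_sync() once the controller is powered on.)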
+ */ + if (!hci_dev_test_flag(hdev, HCI_SETUP) && + !hci_dev_test_flag(hdev, HCI_CONFIG)) + return 0; + + if (hci_dev_test_and_set_flag(hdev, HCI_DEBUGFS_CREATED)) + return 0; + + hci_debugfs_create_common(hdev); + + if (lmp_bredr_capable(hdev)) + hci_debugfs_create_bredr(hdev); + + if (lmp_le_capable(hdev)) + hci_debugfs_create_le(hdev); + + return 0; +} + +#define HCI_QUIRK_BROKEN(_quirk, _desc) { HCI_QUIRK_BROKEN_##_quirk, _desc } + +static const struct { + unsigned long quirk; + const char *desc; +} hci_broken_table[] = { + HCI_QUIRK_BROKEN(LOCAL_COMMANDS, + "HCI Read Local Supported Commands not supported"), + HCI_QUIRK_BROKEN(STORED_LINK_KEY, + "HCI Delete Stored Link Key command is advertised, " + "but not supported."), + HCI_QUIRK_BROKEN(ERR_DATA_REPORTING, + "HCI Read Default Erroneous Data Reporting command is " + "advertised, but not supported."), + HCI_QUIRK_BROKEN(READ_TRANSMIT_POWER, + "HCI Read Transmit Power Level command is advertised, " + "but not supported."), + HCI_QUIRK_BROKEN(FILTER_CLEAR_ALL, + "HCI Set Event Filter command not supported."), + HCI_QUIRK_BROKEN(ENHANCED_SETUP_SYNC_CONN, + "HCI Enhanced Setup Synchronous Connection command is " + "advertised, but not supported."), + HCI_QUIRK_BROKEN(SET_RPA_TIMEOUT, + "HCI LE Set Random Private Address Timeout command is " + "advertised, but not supported.") +}; + +/* This function handles hdev setup stage: + * + * Calls hdev->setup + * Setup address if HCI_QUIRK_USE_BDADDR_PROPERTY is set. + */ +static int hci_dev_setup_sync(struct hci_dev *hdev) +{ + int ret = 0; + bool invalid_bdaddr; + size_t i; + + if (!hci_dev_test_flag(hdev, HCI_SETUP) && + !test_bit(HCI_QUIRK_NON_PERSISTENT_SETUP, &hdev->quirks)) + return 0; + + bt_dev_dbg(hdev, ""); + + hci_sock_dev_event(hdev, HCI_DEV_SETUP); + + if (hdev->setup) + ret = hdev->setup(hdev); + + for (i = 0; i < ARRAY_SIZE(hci_broken_table); i++) { + if (test_bit(hci_broken_table[i].quirk, &hdev->quirks)) + bt_dev_warn(hdev, "%s", hci_broken_table[i].desc); + } + + /* The transport driver can set the quirk to mark the + * BD_ADDR invalid before creating the HCI device or in + * its setup callback. + */ + invalid_bdaddr = test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks); + + if (!ret) { + if (test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks) && + !bacmp(&hdev->public_addr, BDADDR_ANY)) + hci_dev_get_bd_addr_from_property(hdev); + + if ((invalid_bdaddr || + test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) && + bacmp(&hdev->public_addr, BDADDR_ANY) && + hdev->set_bdaddr) { + ret = hdev->set_bdaddr(hdev, &hdev->public_addr); + if (!ret) + invalid_bdaddr = false; + } + } + + /* The transport driver can set these quirks before + * creating the HCI device or in its setup callback. + * + * For the invalid BD_ADDR quirk it is possible that + * it becomes a valid address if the bootloader does + * provide it (see above). + * + * In case any of them is set, the controller has to + * start up as unconfigured. + */ + if (test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks) || + invalid_bdaddr) + hci_dev_set_flag(hdev, HCI_UNCONFIGURED); + + /* For an unconfigured controller it is required to + * read at least the version information provided by + * the Read Local Version Information command. + * + * If the set_bdaddr driver callback is provided, then + * also the original Bluetooth public device address + * will be read using the Read BD Address command. 
+ */ + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + return hci_unconf_init_sync(hdev); + + return ret; +} + +/* This function handles hdev init stage: + * + * Calls hci_dev_setup_sync to perform setup stage + * Calls hci_init_sync to perform HCI command init sequence + */ +static int hci_dev_init_sync(struct hci_dev *hdev) +{ + int ret; + + bt_dev_dbg(hdev, ""); + + atomic_set(&hdev->cmd_cnt, 1); + set_bit(HCI_INIT, &hdev->flags); + + ret = hci_dev_setup_sync(hdev); + + if (hci_dev_test_flag(hdev, HCI_CONFIG)) { + /* If public address change is configured, ensure that + * the address gets programmed. If the driver does not + * support changing the public address, fail the power + * on procedure. + */ + if (bacmp(&hdev->public_addr, BDADDR_ANY) && + hdev->set_bdaddr) + ret = hdev->set_bdaddr(hdev, &hdev->public_addr); + else + ret = -EADDRNOTAVAIL; + } + + if (!ret) { + if (!hci_dev_test_flag(hdev, HCI_UNCONFIGURED) && + !hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + ret = hci_init_sync(hdev); + if (!ret && hdev->post_init) + ret = hdev->post_init(hdev); + } + } + + /* If the HCI Reset command is clearing all diagnostic settings, + * then they need to be reprogrammed after the init procedure + * completed. + */ + if (test_bit(HCI_QUIRK_NON_PERSISTENT_DIAG, &hdev->quirks) && + !hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + hci_dev_test_flag(hdev, HCI_VENDOR_DIAG) && hdev->set_diag) + ret = hdev->set_diag(hdev, true); + + if (!hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + msft_do_open(hdev); + aosp_do_open(hdev); + } + + clear_bit(HCI_INIT, &hdev->flags); + + return ret; +} + +int hci_dev_open_sync(struct hci_dev *hdev) +{ + int ret; + + bt_dev_dbg(hdev, ""); + + if (hci_dev_test_flag(hdev, HCI_UNREGISTER)) { + ret = -ENODEV; + goto done; + } + + if (!hci_dev_test_flag(hdev, HCI_SETUP) && + !hci_dev_test_flag(hdev, HCI_CONFIG)) { + /* Check for rfkill but allow the HCI setup stage to + * proceed (which in itself doesn't cause any RF activity). + */ + if (hci_dev_test_flag(hdev, HCI_RFKILLED)) { + ret = -ERFKILL; + goto done; + } + + /* Check for valid public address or a configured static + * random address, but let the HCI setup proceed to + * be able to determine if there is a public address + * or not. + * + * In case of user channel usage, it is not important + * if a public address or static random address is + * available. + * + * This check is only valid for BR/EDR controllers + * since AMP controllers do not have an address. 
+ */ + if (!hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + hdev->dev_type == HCI_PRIMARY && + !bacmp(&hdev->bdaddr, BDADDR_ANY) && + !bacmp(&hdev->static_addr, BDADDR_ANY)) { + ret = -EADDRNOTAVAIL; + goto done; + } + } + + if (test_bit(HCI_UP, &hdev->flags)) { + ret = -EALREADY; + goto done; + } + + if (hdev->open(hdev)) { + ret = -EIO; + goto done; + } + + set_bit(HCI_RUNNING, &hdev->flags); + hci_sock_dev_event(hdev, HCI_DEV_OPEN); + + ret = hci_dev_init_sync(hdev); + if (!ret) { + hci_dev_hold(hdev); + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + hci_adv_instances_set_rpa_expired(hdev, true); + set_bit(HCI_UP, &hdev->flags); + hci_sock_dev_event(hdev, HCI_DEV_UP); + hci_leds_update_powered(hdev, true); + if (!hci_dev_test_flag(hdev, HCI_SETUP) && + !hci_dev_test_flag(hdev, HCI_CONFIG) && + !hci_dev_test_flag(hdev, HCI_UNCONFIGURED) && + !hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + hci_dev_test_flag(hdev, HCI_MGMT) && + hdev->dev_type == HCI_PRIMARY) { + ret = hci_powered_update_sync(hdev); + mgmt_power_on(hdev, ret); + } + } else { + /* Init failed, cleanup */ + flush_work(&hdev->tx_work); + + /* Since hci_rx_work() is possible to awake new cmd_work + * it should be flushed first to avoid unexpected call of + * hci_cmd_work() + */ + flush_work(&hdev->rx_work); + flush_work(&hdev->cmd_work); + + skb_queue_purge(&hdev->cmd_q); + skb_queue_purge(&hdev->rx_q); + + if (hdev->flush) + hdev->flush(hdev); + + if (hdev->sent_cmd) { + cancel_delayed_work_sync(&hdev->cmd_timer); + kfree_skb(hdev->sent_cmd); + hdev->sent_cmd = NULL; + } + + clear_bit(HCI_RUNNING, &hdev->flags); + hci_sock_dev_event(hdev, HCI_DEV_CLOSE); + + hdev->close(hdev); + hdev->flags &= BIT(HCI_RAW); + } + +done: + return ret; +} + +/* This function requires the caller holds hdev->lock */ +static void hci_pend_le_actions_clear(struct hci_dev *hdev) +{ + struct hci_conn_params *p; + + list_for_each_entry(p, &hdev->le_conn_params, list) { + hci_pend_le_list_del_init(p); + if (p->conn) { + hci_conn_drop(p->conn); + hci_conn_put(p->conn); + p->conn = NULL; + } + } + + BT_DBG("All LE pending actions cleared"); +} + +static int hci_dev_shutdown(struct hci_dev *hdev) +{ + int err = 0; + /* Similar to how we first do setup and then set the exclusive access + * bit for userspace, we must first unset userchannel and then clean up. + * Otherwise, the kernel can't properly use the hci channel to clean up + * the controller (some shutdown routines require sending additional + * commands to the controller for example). 
+ */ + bool was_userchannel = + hci_dev_test_and_clear_flag(hdev, HCI_USER_CHANNEL); + + if (!hci_dev_test_flag(hdev, HCI_UNREGISTER) && + test_bit(HCI_UP, &hdev->flags)) { + /* Execute vendor specific shutdown routine */ + if (hdev->shutdown) + err = hdev->shutdown(hdev); + } + + if (was_userchannel) + hci_dev_set_flag(hdev, HCI_USER_CHANNEL); + + return err; +} + +int hci_dev_close_sync(struct hci_dev *hdev) +{ + bool auto_off; + int err = 0; + + bt_dev_dbg(hdev, ""); + + cancel_delayed_work(&hdev->power_off); + cancel_delayed_work(&hdev->ncmd_timer); + cancel_delayed_work(&hdev->le_scan_disable); + cancel_delayed_work(&hdev->le_scan_restart); + + hci_request_cancel_all(hdev); + + if (hdev->adv_instance_timeout) { + cancel_delayed_work_sync(&hdev->adv_instance_expire); + hdev->adv_instance_timeout = 0; + } + + err = hci_dev_shutdown(hdev); + + if (!test_and_clear_bit(HCI_UP, &hdev->flags)) { + cancel_delayed_work_sync(&hdev->cmd_timer); + return err; + } + + hci_leds_update_powered(hdev, false); + + /* Flush RX and TX works */ + flush_work(&hdev->tx_work); + flush_work(&hdev->rx_work); + + if (hdev->discov_timeout > 0) { + hdev->discov_timeout = 0; + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + } + + if (hci_dev_test_and_clear_flag(hdev, HCI_SERVICE_CACHE)) + cancel_delayed_work(&hdev->service_cache); + + if (hci_dev_test_flag(hdev, HCI_MGMT)) { + struct adv_info *adv_instance; + + cancel_delayed_work_sync(&hdev->rpa_expired); + + list_for_each_entry(adv_instance, &hdev->adv_instances, list) + cancel_delayed_work_sync(&adv_instance->rpa_expired_cb); + } + + /* Avoid potential lockdep warnings from the *_flush() calls by + * ensuring the workqueue is empty up front. + */ + drain_workqueue(hdev->workqueue); + + hci_dev_lock(hdev); + + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + + auto_off = hci_dev_test_and_clear_flag(hdev, HCI_AUTO_OFF); + + if (!auto_off && hdev->dev_type == HCI_PRIMARY && + !hci_dev_test_flag(hdev, HCI_USER_CHANNEL) && + hci_dev_test_flag(hdev, HCI_MGMT)) + __mgmt_power_off(hdev); + + hci_inquiry_cache_flush(hdev); + hci_pend_le_actions_clear(hdev); + hci_conn_hash_flush(hdev); + /* Prevent data races on hdev->smp_data or hdev->smp_bredr_data */ + smp_unregister(hdev); + hci_dev_unlock(hdev); + + hci_sock_dev_event(hdev, HCI_DEV_DOWN); + + if (!hci_dev_test_flag(hdev, HCI_USER_CHANNEL)) { + aosp_do_close(hdev); + msft_do_close(hdev); + } + + if (hdev->flush) + hdev->flush(hdev); + + /* Reset device */ + skb_queue_purge(&hdev->cmd_q); + atomic_set(&hdev->cmd_cnt, 1); + if (test_bit(HCI_QUIRK_RESET_ON_CLOSE, &hdev->quirks) && + !auto_off && !hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + set_bit(HCI_INIT, &hdev->flags); + hci_reset_sync(hdev); + clear_bit(HCI_INIT, &hdev->flags); + } + + /* flush cmd work */ + flush_work(&hdev->cmd_work); + + /* Drop queues */ + skb_queue_purge(&hdev->rx_q); + skb_queue_purge(&hdev->cmd_q); + skb_queue_purge(&hdev->raw_q); + + /* Drop last sent command */ + if (hdev->sent_cmd) { + cancel_delayed_work_sync(&hdev->cmd_timer); + kfree_skb(hdev->sent_cmd); + hdev->sent_cmd = NULL; + } + + clear_bit(HCI_RUNNING, &hdev->flags); + hci_sock_dev_event(hdev, HCI_DEV_CLOSE); + + /* After this point our queues are empty and no tasks are scheduled. 
*/
+	hdev->close(hdev);
+
+	/* Clear flags */
+	hdev->flags &= BIT(HCI_RAW);
+	hci_dev_clear_volatile_flags(hdev);
+
+	/* Controller radio is available but is currently powered down */
+	hdev->amp_status = AMP_STATUS_POWERED_DOWN;
+
+	memset(hdev->eir, 0, sizeof(hdev->eir));
+	memset(hdev->dev_class, 0, sizeof(hdev->dev_class));
+	bacpy(&hdev->random_addr, BDADDR_ANY);
+	hci_codec_list_clear(&hdev->local_codecs);
+
+	hci_dev_put(hdev);
+	return err;
+}
+
+/* This function performs the power on HCI command sequence as follows:
+ *
+ * If the controller is already up (HCI_UP), perform the
+ * hci_powered_update_sync sequence; otherwise run hci_dev_open_sync, which
+ * is followed by hci_powered_update_sync once the init sequence completes.
+ */
+static int hci_power_on_sync(struct hci_dev *hdev)
+{
+	int err;
+
+	if (test_bit(HCI_UP, &hdev->flags) &&
+	    hci_dev_test_flag(hdev, HCI_MGMT) &&
+	    hci_dev_test_and_clear_flag(hdev, HCI_AUTO_OFF)) {
+		cancel_delayed_work(&hdev->power_off);
+		return hci_powered_update_sync(hdev);
+	}
+
+	err = hci_dev_open_sync(hdev);
+	if (err < 0)
+		return err;
+
+	/* During the HCI setup phase, a few error conditions are
+	 * ignored and they need to be checked now. If they are still
+	 * valid, it is important to return the device back off.
+	 */
+	if (hci_dev_test_flag(hdev, HCI_RFKILLED) ||
+	    hci_dev_test_flag(hdev, HCI_UNCONFIGURED) ||
+	    (hdev->dev_type == HCI_PRIMARY &&
+	     !bacmp(&hdev->bdaddr, BDADDR_ANY) &&
+	     !bacmp(&hdev->static_addr, BDADDR_ANY))) {
+		hci_dev_clear_flag(hdev, HCI_AUTO_OFF);
+		hci_dev_close_sync(hdev);
+	} else if (hci_dev_test_flag(hdev, HCI_AUTO_OFF)) {
+		queue_delayed_work(hdev->req_workqueue, &hdev->power_off,
+				   HCI_AUTO_OFF_TIMEOUT);
+	}
+
+	if (hci_dev_test_and_clear_flag(hdev, HCI_SETUP)) {
+		/* For unconfigured devices, set the HCI_RAW flag
+		 * so that userspace can easily identify them.
+		 */
+		if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED))
+			set_bit(HCI_RAW, &hdev->flags);
+
+		/* For fully configured devices, this will send
+		 * the Index Added event. For unconfigured devices,
+		 * it will send the Unconfigured Index Added event.
+		 *
+		 * Devices with HCI_QUIRK_RAW_DEVICE are ignored
+		 * and no event will be sent.
+		 */
+		mgmt_index_added(hdev);
+	} else if (hci_dev_test_and_clear_flag(hdev, HCI_CONFIG)) {
+		/* Now that the controller is configured, it is
+		 * important to clear the HCI_RAW flag.
+		 */
+		if (!hci_dev_test_flag(hdev, HCI_UNCONFIGURED))
+			clear_bit(HCI_RAW, &hdev->flags);
+
+		/* Powering on the controller with HCI_CONFIG set only
+		 * happens with the transition from unconfigured to
+		 * configured. This will send the Index Added event.
+ */ + mgmt_index_added(hdev); + } + + return 0; +} + +static int hci_remote_name_cancel_sync(struct hci_dev *hdev, bdaddr_t *addr) +{ + struct hci_cp_remote_name_req_cancel cp; + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, addr); + + return __hci_cmd_sync_status(hdev, HCI_OP_REMOTE_NAME_REQ_CANCEL, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_stop_discovery_sync(struct hci_dev *hdev) +{ + struct discovery_state *d = &hdev->discovery; + struct inquiry_entry *e; + int err; + + bt_dev_dbg(hdev, "state %u", hdev->discovery.state); + + if (d->state == DISCOVERY_FINDING || d->state == DISCOVERY_STOPPING) { + if (test_bit(HCI_INQUIRY, &hdev->flags)) { + err = __hci_cmd_sync_status(hdev, HCI_OP_INQUIRY_CANCEL, + 0, NULL, HCI_CMD_TIMEOUT); + if (err) + return err; + } + + if (hci_dev_test_flag(hdev, HCI_LE_SCAN)) { + cancel_delayed_work(&hdev->le_scan_disable); + cancel_delayed_work(&hdev->le_scan_restart); + + err = hci_scan_disable_sync(hdev); + if (err) + return err; + } + + } else { + err = hci_scan_disable_sync(hdev); + if (err) + return err; + } + + /* Resume advertising if it was paused */ + if (use_ll_privacy(hdev)) + hci_resume_advertising_sync(hdev); + + /* No further actions needed for LE-only discovery */ + if (d->type == DISCOV_TYPE_LE) + return 0; + + if (d->state == DISCOVERY_RESOLVING || d->state == DISCOVERY_STOPPING) { + e = hci_inquiry_cache_lookup_resolve(hdev, BDADDR_ANY, + NAME_PENDING); + if (!e) + return 0; + + return hci_remote_name_cancel_sync(hdev, &e->data.bdaddr); + } + + return 0; +} + +static int hci_disconnect_phy_link_sync(struct hci_dev *hdev, u16 handle, + u8 reason) +{ + struct hci_cp_disconn_phy_link cp; + + memset(&cp, 0, sizeof(cp)); + cp.phy_handle = HCI_PHY_HANDLE(handle); + cp.reason = reason; + + return __hci_cmd_sync_status(hdev, HCI_OP_DISCONN_PHY_LINK, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_disconnect_sync(struct hci_dev *hdev, struct hci_conn *conn, + u8 reason) +{ + struct hci_cp_disconnect cp; + + if (conn->type == AMP_LINK) + return hci_disconnect_phy_link_sync(hdev, conn->handle, reason); + + memset(&cp, 0, sizeof(cp)); + cp.handle = cpu_to_le16(conn->handle); + cp.reason = reason; + + /* Wait for HCI_EV_DISCONN_COMPLETE not HCI_EV_CMD_STATUS when not + * suspending. + */ + if (!hdev->suspended) + return __hci_cmd_sync_status_sk(hdev, HCI_OP_DISCONNECT, + sizeof(cp), &cp, + HCI_EV_DISCONN_COMPLETE, + HCI_CMD_TIMEOUT, NULL); + + return __hci_cmd_sync_status(hdev, HCI_OP_DISCONNECT, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); +} + +static int hci_le_connect_cancel_sync(struct hci_dev *hdev, + struct hci_conn *conn) +{ + if (test_bit(HCI_CONN_SCANNING, &conn->flags)) + return 0; + + if (test_and_set_bit(HCI_CONN_CANCEL, &conn->flags)) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_CREATE_CONN_CANCEL, + 0, NULL, HCI_CMD_TIMEOUT); +} + +static int hci_connect_cancel_sync(struct hci_dev *hdev, struct hci_conn *conn) +{ + if (conn->type == LE_LINK) + return hci_le_connect_cancel_sync(hdev, conn); + + if (hdev->hci_ver < BLUETOOTH_VER_1_2) + return 0; + + return __hci_cmd_sync_status(hdev, HCI_OP_CREATE_CONN_CANCEL, + 6, &conn->dst, HCI_CMD_TIMEOUT); +} + +static int hci_reject_sco_sync(struct hci_dev *hdev, struct hci_conn *conn, + u8 reason) +{ + struct hci_cp_reject_sync_conn_req cp; + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, &conn->dst); + cp.reason = reason; + + /* SCO rejection has its own limited set of + * allowed error values (0x0D-0x0F). 
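+	 * 0x0D: Connection Rejected due to Limited Resources
+	 * 0x0E: Connection Rejected due to Security Reasons
+	 * 0x0F: Connection Rejected due to Unacceptable BD_ADDR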
+ */ + if (reason < 0x0d || reason > 0x0f) + cp.reason = HCI_ERROR_REJ_LIMITED_RESOURCES; + + return __hci_cmd_sync_status(hdev, HCI_OP_REJECT_SYNC_CONN_REQ, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_reject_conn_sync(struct hci_dev *hdev, struct hci_conn *conn, + u8 reason) +{ + struct hci_cp_reject_conn_req cp; + + if (conn->type == SCO_LINK || conn->type == ESCO_LINK) + return hci_reject_sco_sync(hdev, conn, reason); + + memset(&cp, 0, sizeof(cp)); + bacpy(&cp.bdaddr, &conn->dst); + cp.reason = reason; + + return __hci_cmd_sync_status(hdev, HCI_OP_REJECT_CONN_REQ, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_abort_conn_sync(struct hci_dev *hdev, struct hci_conn *conn, u8 reason) +{ + int err; + + switch (conn->state) { + case BT_CONNECTED: + case BT_CONFIG: + return hci_disconnect_sync(hdev, conn, reason); + case BT_CONNECT: + err = hci_connect_cancel_sync(hdev, conn); + /* Cleanup hci_conn object if it cannot be cancelled as it + * likelly means the controller and host stack are out of sync. + */ + if (err) { + hci_dev_lock(hdev); + hci_conn_failed(conn, err); + hci_dev_unlock(hdev); + } + return err; + case BT_CONNECT2: + return hci_reject_conn_sync(hdev, conn, reason); + default: + conn->state = BT_CLOSED; + break; + } + + return 0; +} + +static int hci_disconnect_all_sync(struct hci_dev *hdev, u8 reason) +{ + struct hci_conn *conn, *tmp; + int err; + + list_for_each_entry_safe(conn, tmp, &hdev->conn_hash.list, list) { + err = hci_abort_conn_sync(hdev, conn, reason); + if (err) + return err; + } + + return 0; +} + +/* This function perform power off HCI command sequence as follows: + * + * Clear Advertising + * Stop Discovery + * Disconnect all connections + * hci_dev_close_sync + */ +static int hci_power_off_sync(struct hci_dev *hdev) +{ + int err; + + /* If controller is already down there is nothing to do */ + if (!test_bit(HCI_UP, &hdev->flags)) + return 0; + + if (test_bit(HCI_ISCAN, &hdev->flags) || + test_bit(HCI_PSCAN, &hdev->flags)) { + err = hci_write_scan_enable_sync(hdev, 0x00); + if (err) + return err; + } + + err = hci_clear_adv_sync(hdev, NULL, false); + if (err) + return err; + + err = hci_stop_discovery_sync(hdev); + if (err) + return err; + + /* Terminated due to Power Off */ + err = hci_disconnect_all_sync(hdev, HCI_ERROR_REMOTE_POWER_OFF); + if (err) + return err; + + return hci_dev_close_sync(hdev); +} + +int hci_set_powered_sync(struct hci_dev *hdev, u8 val) +{ + if (val) + return hci_power_on_sync(hdev); + + return hci_power_off_sync(hdev); +} + +static int hci_write_iac_sync(struct hci_dev *hdev) +{ + struct hci_cp_write_current_iac_lap cp; + + if (!hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) + return 0; + + memset(&cp, 0, sizeof(cp)); + + if (hci_dev_test_flag(hdev, HCI_LIMITED_DISCOVERABLE)) { + /* Limited discoverable mode */ + cp.num_iac = min_t(u8, hdev->num_iac, 2); + cp.iac_lap[0] = 0x00; /* LIAC */ + cp.iac_lap[1] = 0x8b; + cp.iac_lap[2] = 0x9e; + cp.iac_lap[3] = 0x33; /* GIAC */ + cp.iac_lap[4] = 0x8b; + cp.iac_lap[5] = 0x9e; + } else { + /* General discoverable mode */ + cp.num_iac = 1; + cp.iac_lap[0] = 0x33; /* GIAC */ + cp.iac_lap[1] = 0x8b; + cp.iac_lap[2] = 0x9e; + } + + return __hci_cmd_sync_status(hdev, HCI_OP_WRITE_CURRENT_IAC_LAP, + (cp.num_iac * 3) + 1, &cp, + HCI_CMD_TIMEOUT); +} + +int hci_update_discoverable_sync(struct hci_dev *hdev) +{ + int err = 0; + + if (hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + err = hci_write_iac_sync(hdev); + if (err) + return err; + + err = hci_update_scan_sync(hdev); + if (err) + 
return err; + + err = hci_update_class_sync(hdev); + if (err) + return err; + } + + /* Advertising instances don't use the global discoverable setting, so + * only update AD if advertising was enabled using Set Advertising. + */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) { + err = hci_update_adv_data_sync(hdev, 0x00); + if (err) + return err; + + /* Discoverable mode affects the local advertising + * address in limited privacy mode. + */ + if (hci_dev_test_flag(hdev, HCI_LIMITED_PRIVACY)) { + if (ext_adv_capable(hdev)) + err = hci_start_ext_adv_sync(hdev, 0x00); + else + err = hci_enable_advertising_sync(hdev); + } + } + + return err; +} + +static int update_discoverable_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_discoverable_sync(hdev); +} + +int hci_update_discoverable(struct hci_dev *hdev) +{ + /* Only queue if it would have any effect */ + if (hdev_is_powered(hdev) && + hci_dev_test_flag(hdev, HCI_ADVERTISING) && + hci_dev_test_flag(hdev, HCI_DISCOVERABLE) && + hci_dev_test_flag(hdev, HCI_LIMITED_PRIVACY)) + return hci_cmd_sync_queue(hdev, update_discoverable_sync, NULL, + NULL); + + return 0; +} + +int hci_update_connectable_sync(struct hci_dev *hdev) +{ + int err; + + err = hci_update_scan_sync(hdev); + if (err) + return err; + + /* If BR/EDR is not enabled and we disable advertising as a + * by-product of disabling connectable, we need to update the + * advertising flags. + */ + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + err = hci_update_adv_data_sync(hdev, hdev->cur_adv_instance); + + /* Update the advertising parameters if necessary */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING) || + !list_empty(&hdev->adv_instances)) { + if (ext_adv_capable(hdev)) + err = hci_start_ext_adv_sync(hdev, + hdev->cur_adv_instance); + else + err = hci_enable_advertising_sync(hdev); + + if (err) + return err; + } + + return hci_update_passive_scan_sync(hdev); +} + +static int hci_inquiry_sync(struct hci_dev *hdev, u8 length) +{ + const u8 giac[3] = { 0x33, 0x8b, 0x9e }; + const u8 liac[3] = { 0x00, 0x8b, 0x9e }; + struct hci_cp_inquiry cp; + + bt_dev_dbg(hdev, ""); + + if (hci_dev_test_flag(hdev, HCI_INQUIRY)) + return 0; + + hci_dev_lock(hdev); + hci_inquiry_cache_flush(hdev); + hci_dev_unlock(hdev); + + memset(&cp, 0, sizeof(cp)); + + if (hdev->discovery.limited) + memcpy(&cp.lap, liac, sizeof(cp.lap)); + else + memcpy(&cp.lap, giac, sizeof(cp.lap)); + + cp.length = length; + + return __hci_cmd_sync_status(hdev, HCI_OP_INQUIRY, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +static int hci_active_scan_sync(struct hci_dev *hdev, uint16_t interval) +{ + u8 own_addr_type; + /* Accept list is not used for discovery */ + u8 filter_policy = 0x00; + /* Default is to enable duplicates filter */ + u8 filter_dup = LE_SCAN_FILTER_DUP_ENABLE; + int err; + + bt_dev_dbg(hdev, ""); + + /* If controller is scanning, it means the passive scanning is + * running. Thus, we should temporarily stop it in order to set the + * discovery scanning parameters. + */ + err = hci_scan_disable_sync(hdev); + if (err) { + bt_dev_err(hdev, "Unable to disable scanning: %d", err); + return err; + } + + cancel_interleave_scan(hdev); + + /* Pause address resolution for active scan and stop advertising if + * privacy is enabled. + */ + err = hci_pause_addr_resolution(hdev); + if (err) + goto failed; + + /* All active scans will be done with either a resolvable private + * address (when privacy feature has been enabled) or non-resolvable + * private address. 
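+ *
+ * As a reminder of the random address sub-types defined by the Core
+ * Specification (Vol 6, Part B, Sec 1.3.2): the two most significant
+ * bits of the address distinguish them, 0b00 for a non-resolvable
+ * private address, 0b01 for a resolvable private address and 0b11
+ * for a static random address.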
+ */ + err = hci_update_random_address_sync(hdev, true, scan_use_rpa(hdev), + &own_addr_type); + if (err < 0) + own_addr_type = ADDR_LE_DEV_PUBLIC; + + if (hci_is_adv_monitoring(hdev)) { + /* Duplicate filter should be disabled when some advertisement + * monitor is activated, otherwise AdvMon can only receive one + * advertisement for one peer(*) during active scanning, and + * might report loss to these peers. + * + * Note that different controllers have different meanings of + * |duplicate|. Some of them consider packets with the same + * address as duplicate, and others consider packets with the + * same address and the same RSSI as duplicate. Although in the + * latter case we don't need to disable duplicate filter, but + * it is common to have active scanning for a short period of + * time, the power impact should be neglectable. + */ + filter_dup = LE_SCAN_FILTER_DUP_DISABLE; + } + + err = hci_start_scan_sync(hdev, LE_SCAN_ACTIVE, interval, + hdev->le_scan_window_discovery, + own_addr_type, filter_policy, filter_dup); + if (!err) + return err; + +failed: + /* Resume advertising if it was paused */ + if (use_ll_privacy(hdev)) + hci_resume_advertising_sync(hdev); + + /* Resume passive scanning */ + hci_update_passive_scan_sync(hdev); + return err; +} + +static int hci_start_interleaved_discovery_sync(struct hci_dev *hdev) +{ + int err; + + bt_dev_dbg(hdev, ""); + + err = hci_active_scan_sync(hdev, hdev->le_scan_int_discovery * 2); + if (err) + return err; + + return hci_inquiry_sync(hdev, DISCOV_BREDR_INQUIRY_LEN); +} + +int hci_start_discovery_sync(struct hci_dev *hdev) +{ + unsigned long timeout; + int err; + + bt_dev_dbg(hdev, "type %u", hdev->discovery.type); + + switch (hdev->discovery.type) { + case DISCOV_TYPE_BREDR: + return hci_inquiry_sync(hdev, DISCOV_BREDR_INQUIRY_LEN); + case DISCOV_TYPE_INTERLEAVED: + /* When running simultaneous discovery, the LE scanning time + * should occupy the whole discovery time sine BR/EDR inquiry + * and LE scanning are scheduled by the controller. + * + * For interleaving discovery in comparison, BR/EDR inquiry + * and LE scanning are done sequentially with separate + * timeouts. + */ + if (test_bit(HCI_QUIRK_SIMULTANEOUS_DISCOVERY, + &hdev->quirks)) { + timeout = msecs_to_jiffies(DISCOV_LE_TIMEOUT); + /* During simultaneous discovery, we double LE scan + * interval. We must leave some time for the controller + * to do BR/EDR inquiry. + */ + err = hci_start_interleaved_discovery_sync(hdev); + break; + } + + timeout = msecs_to_jiffies(hdev->discov_interleaved_timeout); + err = hci_active_scan_sync(hdev, hdev->le_scan_int_discovery); + break; + case DISCOV_TYPE_LE: + timeout = msecs_to_jiffies(DISCOV_LE_TIMEOUT); + err = hci_active_scan_sync(hdev, hdev->le_scan_int_discovery); + break; + default: + return -EINVAL; + } + + if (err) + return err; + + bt_dev_dbg(hdev, "timeout %u ms", jiffies_to_msecs(timeout)); + + /* When service discovery is used and the controller has a + * strict duplicate filter, it is important to remember the + * start and duration of the scan. This is required for + * restarting scanning during the discovery phase. 
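+ *
+ * As a rough illustration only (not code from this file), a restart
+ * helper could compute what is left of the scan window from these
+ * two fields along the lines of:
+ *
+ * unsigned long elapsed = jiffies - hdev->discovery.scan_start;
+ * unsigned long remaining = hdev->discovery.scan_duration > elapsed ?
+ *                           hdev->discovery.scan_duration - elapsed : 0;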
+ */ + if (test_bit(HCI_QUIRK_STRICT_DUPLICATE_FILTER, &hdev->quirks) && + hdev->discovery.result_filtering) { + hdev->discovery.scan_start = jiffies; + hdev->discovery.scan_duration = timeout; + } + + queue_delayed_work(hdev->req_workqueue, &hdev->le_scan_disable, + timeout); + return 0; +} + +static void hci_suspend_monitor_sync(struct hci_dev *hdev) +{ + switch (hci_get_adv_monitor_offload_ext(hdev)) { + case HCI_ADV_MONITOR_EXT_MSFT: + msft_suspend_sync(hdev); + break; + default: + return; + } +} + +/* This function disables discovery and mark it as paused */ +static int hci_pause_discovery_sync(struct hci_dev *hdev) +{ + int old_state = hdev->discovery.state; + int err; + + /* If discovery already stopped/stopping/paused there nothing to do */ + if (old_state == DISCOVERY_STOPPED || old_state == DISCOVERY_STOPPING || + hdev->discovery_paused) + return 0; + + hci_discovery_set_state(hdev, DISCOVERY_STOPPING); + err = hci_stop_discovery_sync(hdev); + if (err) + return err; + + hdev->discovery_paused = true; + hdev->discovery_old_state = old_state; + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); + + return 0; +} + +static int hci_update_event_filter_sync(struct hci_dev *hdev) +{ + struct bdaddr_list_with_flags *b; + u8 scan = SCAN_DISABLED; + bool scanning = test_bit(HCI_PSCAN, &hdev->flags); + int err; + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return 0; + + /* Some fake CSR controllers lock up after setting this type of + * filter, so avoid sending the request altogether. + */ + if (test_bit(HCI_QUIRK_BROKEN_FILTER_CLEAR_ALL, &hdev->quirks)) + return 0; + + /* Always clear event filter when starting */ + hci_clear_event_filter_sync(hdev); + + list_for_each_entry(b, &hdev->accept_list, list) { + if (!(b->flags & HCI_CONN_FLAG_REMOTE_WAKEUP)) + continue; + + bt_dev_dbg(hdev, "Adding event filters for %pMR", &b->bdaddr); + + err = hci_set_event_filter_sync(hdev, HCI_FLT_CONN_SETUP, + HCI_CONN_SETUP_ALLOW_BDADDR, + &b->bdaddr, + HCI_CONN_SETUP_AUTO_ON); + if (err) + bt_dev_dbg(hdev, "Failed to set event filter for %pMR", + &b->bdaddr); + else + scan = SCAN_PAGE; + } + + if (scan && !scanning) + hci_write_scan_enable_sync(hdev, scan); + else if (!scan && scanning) + hci_write_scan_enable_sync(hdev, scan); + + return 0; +} + +/* This function disables scan (BR and LE) and mark it as paused */ +static int hci_pause_scan_sync(struct hci_dev *hdev) +{ + if (hdev->scanning_paused) + return 0; + + /* Disable page scan if enabled */ + if (test_bit(HCI_PSCAN, &hdev->flags)) + hci_write_scan_enable_sync(hdev, SCAN_DISABLED); + + hci_scan_disable_sync(hdev); + + hdev->scanning_paused = true; + + return 0; +} + +/* This function performs the HCI suspend procedures in the follow order: + * + * Pause discovery (active scanning/inquiry) + * Pause Directed Advertising/Advertising + * Pause Scanning (passive scanning in case discovery was not active) + * Disconnect all connections + * Set suspend_status to BT_SUSPEND_DISCONNECT if hdev cannot wakeup + * otherwise: + * Update event mask (only set events that are allowed to wake up the host) + * Update event filter (with devices marked with HCI_CONN_FLAG_REMOTE_WAKEUP) + * Update passive scanning (lower duty cycle) + * Set suspend_status to BT_SUSPEND_CONFIGURE_WAKE + */ +int hci_suspend_sync(struct hci_dev *hdev) +{ + int err; + + /* If marked as suspended there nothing to do */ + if (hdev->suspended) + return 0; + + /* Mark device as suspended */ + hdev->suspended = true; + + /* Pause discovery if not already stopped */ + 
hci_pause_discovery_sync(hdev);
+
+ /* Pause other advertisements */
+ hci_pause_advertising_sync(hdev);
+
+ /* Suspend monitor filters */
+ hci_suspend_monitor_sync(hdev);
+
+ /* Prevent disconnects from causing scanning to be re-enabled */
+ hci_pause_scan_sync(hdev);
+
+ if (hci_conn_count(hdev)) {
+ /* Soft disconnect everything (power off) */
+ err = hci_disconnect_all_sync(hdev, HCI_ERROR_REMOTE_POWER_OFF);
+ if (err) {
+ /* Set state to BT_RUNNING so resume doesn't notify */
+ hdev->suspend_state = BT_RUNNING;
+ hci_resume_sync(hdev);
+ return err;
+ }
+
+ /* Update the event mask so only the allowed events can wake up
+ * the host.
+ */
+ hci_set_event_mask_sync(hdev);
+ }
+
+ /* Only configure accept list if disconnect succeeded and wake
+ * isn't being prevented.
+ */
+ if (!hdev->wakeup || !hdev->wakeup(hdev)) {
+ hdev->suspend_state = BT_SUSPEND_DISCONNECT;
+ return 0;
+ }
+
+ /* Unpause to take care of updating scanning params */
+ hdev->scanning_paused = false;
+
+ /* Enable event filter for paired devices */
+ hci_update_event_filter_sync(hdev);
+
+ /* Update LE passive scan if enabled */
+ hci_update_passive_scan_sync(hdev);
+
+ /* Pause scan changes again. */
+ hdev->scanning_paused = true;
+
+ hdev->suspend_state = BT_SUSPEND_CONFIGURE_WAKE;
+
+ return 0;
+}
+
+/* This function resumes discovery */
+static int hci_resume_discovery_sync(struct hci_dev *hdev)
+{
+ int err;
+
+ /* If discovery was not paused there is nothing to do */
+ if (!hdev->discovery_paused)
+ return 0;
+
+ hdev->discovery_paused = false;
+
+ hci_discovery_set_state(hdev, DISCOVERY_STARTING);
+
+ err = hci_start_discovery_sync(hdev);
+
+ hci_discovery_set_state(hdev, err ? DISCOVERY_STOPPED :
+ DISCOVERY_FINDING);
+
+ return err;
+}
+
+static void hci_resume_monitor_sync(struct hci_dev *hdev)
+{
+ switch (hci_get_adv_monitor_offload_ext(hdev)) {
+ case HCI_ADV_MONITOR_EXT_MSFT:
+ msft_resume_sync(hdev);
+ break;
+ default:
+ return;
+ }
+}
+
+/* This function resumes scanning and resets the paused flag */
+static int hci_resume_scan_sync(struct hci_dev *hdev)
+{
+ if (!hdev->scanning_paused)
+ return 0;
+
+ hdev->scanning_paused = false;
+
+ hci_update_scan_sync(hdev);
+
+ /* Reset passive scanning to normal */
+ hci_update_passive_scan_sync(hdev);
+
+ return 0;
+}
+
+/* This function performs the HCI resume procedures in the following order:
+ *
+ * Restore event mask
+ * Clear event filter
+ * Update passive scanning (normal duty cycle)
+ * Resume Directed Advertising/Advertising
+ * Resume discovery (active scanning/inquiry)
+ */
+int hci_resume_sync(struct hci_dev *hdev)
+{
+ /* If not marked as suspended there is nothing to do */
+ if (!hdev->suspended)
+ return 0;
+
+ hdev->suspended = false;
+
+ /* Restore event mask */
+ hci_set_event_mask_sync(hdev);
+
+ /* Clear any event filters and restore scan state */
+ hci_clear_event_filter_sync(hdev);
+
+ /* Resume scanning */
+ hci_resume_scan_sync(hdev);
+
+ /* Resume monitor filters */
+ hci_resume_monitor_sync(hdev);
+
+ /* Resume other advertisements */
+ hci_resume_advertising_sync(hdev);
+
+ /* Resume discovery */
+ hci_resume_discovery_sync(hdev);
+
+ return 0;
+}
+
+static bool conn_use_rpa(struct hci_conn *conn)
+{
+ struct hci_dev *hdev = conn->hdev;
+
+ return hci_dev_test_flag(hdev, HCI_PRIVACY);
+}
+
+static int hci_le_ext_directed_advertising_sync(struct hci_dev *hdev,
+ struct hci_conn *conn)
+{
+ struct hci_cp_le_set_ext_adv_params cp;
+ int err;
+ bdaddr_t random_addr;
+ u8 own_addr_type;
+
+ err = hci_update_random_address_sync(hdev, false, conn_use_rpa(conn),
+ &own_addr_type);
+ if (err)
+ return err;
+
+ /* Set require_privacy to false so that the remote device has a
+ * chance of identifying us.
+ */
+ err = hci_get_random_address(hdev, false, conn_use_rpa(conn), NULL,
+ &own_addr_type, &random_addr);
+ if (err)
+ return err;
+
+ memset(&cp, 0, sizeof(cp));
+
+ cp.evt_properties = cpu_to_le16(LE_LEGACY_ADV_DIRECT_IND);
+ cp.own_addr_type = own_addr_type;
+ cp.channel_map = hdev->le_adv_channel_map;
+ cp.tx_power = HCI_TX_POWER_INVALID;
+ cp.primary_phy = HCI_ADV_PHY_1M;
+ cp.secondary_phy = HCI_ADV_PHY_1M;
+ cp.handle = 0x00; /* Use instance 0 for directed adv */
+ cp.own_addr_type = own_addr_type;
+ cp.peer_addr_type = conn->dst_type;
+ bacpy(&cp.peer_addr, &conn->dst);
+
+ /* As per Core Spec 5.2 Vol 2, Part E, Sec 7.8.53, the advertising
+ * event property LE_LEGACY_ADV_DIRECT_IND does not support
+ * advertising data. If the advertising set already contains some,
+ * the controller shall return error code 'Invalid HCI Command
+ * Parameters' (0x12). It is therefore required to remove the adv
+ * set for handle 0x00, since we use instance 0 for directed adv.
+ */
+ err = hci_remove_ext_adv_instance_sync(hdev, cp.handle, NULL);
+ if (err)
+ return err;
+
+ err = __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_EXT_ADV_PARAMS,
+ sizeof(cp), &cp, HCI_CMD_TIMEOUT);
+ if (err)
+ return err;
+
+ /* Check if the random address needs to be updated */
+ if (own_addr_type == ADDR_LE_DEV_RANDOM &&
+ bacmp(&random_addr, BDADDR_ANY) &&
+ bacmp(&random_addr, &hdev->random_addr)) {
+ err = hci_set_adv_set_random_addr_sync(hdev, 0x00,
+ &random_addr);
+ if (err)
+ return err;
+ }
+
+ return hci_enable_ext_advertising_sync(hdev, 0x00);
+}
+
+static int hci_le_directed_advertising_sync(struct hci_dev *hdev,
+ struct hci_conn *conn)
+{
+ struct hci_cp_le_set_adv_param cp;
+ u8 status;
+ u8 own_addr_type;
+ u8 enable;
+
+ if (ext_adv_capable(hdev))
+ return hci_le_ext_directed_advertising_sync(hdev, conn);
+
+ /* Clear the HCI_LE_ADV bit temporarily so that the
+ * hci_update_random_address knows that it's safe to go ahead
+ * and write a new random address. The flag will be set back on
+ * as soon as the SET_ADV_ENABLE HCI command completes.
+ */
+ hci_dev_clear_flag(hdev, HCI_LE_ADV);
+
+ /* Set require_privacy to false so that the remote device has a
+ * chance of identifying us.
+ */
+ status = hci_update_random_address_sync(hdev, false, conn_use_rpa(conn),
+ &own_addr_type);
+ if (status)
+ return status;
+
+ memset(&cp, 0, sizeof(cp));
+
+ /* Some controllers might reject the command if the intervals are
+ * not within range for undirected advertising.
+ * BCM20702A0 is known to be affected by this.
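+ *
+ * 0x0020 (20 ms, in 0.625 ms units) is the lowest advertising
+ * interval the specification allows for undirected advertising
+ * (valid range 0x0020 to 0x4000), so using it here keeps the
+ * parameters within that range even for directed advertising.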
+ */ + cp.min_interval = cpu_to_le16(0x0020); + cp.max_interval = cpu_to_le16(0x0020); + + cp.type = LE_ADV_DIRECT_IND; + cp.own_address_type = own_addr_type; + cp.direct_addr_type = conn->dst_type; + bacpy(&cp.direct_addr, &conn->dst); + cp.channel_map = hdev->le_adv_channel_map; + + status = __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_PARAM, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); + if (status) + return status; + + enable = 0x01; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_SET_ADV_ENABLE, + sizeof(enable), &enable, HCI_CMD_TIMEOUT); +} + +static void set_ext_conn_params(struct hci_conn *conn, + struct hci_cp_le_ext_conn_param *p) +{ + struct hci_dev *hdev = conn->hdev; + + memset(p, 0, sizeof(*p)); + + p->scan_interval = cpu_to_le16(hdev->le_scan_int_connect); + p->scan_window = cpu_to_le16(hdev->le_scan_window_connect); + p->conn_interval_min = cpu_to_le16(conn->le_conn_min_interval); + p->conn_interval_max = cpu_to_le16(conn->le_conn_max_interval); + p->conn_latency = cpu_to_le16(conn->le_conn_latency); + p->supervision_timeout = cpu_to_le16(conn->le_supv_timeout); + p->min_ce_len = cpu_to_le16(0x0000); + p->max_ce_len = cpu_to_le16(0x0000); +} + +static int hci_le_ext_create_conn_sync(struct hci_dev *hdev, + struct hci_conn *conn, u8 own_addr_type) +{ + struct hci_cp_le_ext_create_conn *cp; + struct hci_cp_le_ext_conn_param *p; + u8 data[sizeof(*cp) + sizeof(*p) * 3]; + u32 plen; + + cp = (void *)data; + p = (void *)cp->data; + + memset(cp, 0, sizeof(*cp)); + + bacpy(&cp->peer_addr, &conn->dst); + cp->peer_addr_type = conn->dst_type; + cp->own_addr_type = own_addr_type; + + plen = sizeof(*cp); + + if (scan_1m(hdev)) { + cp->phys |= LE_SCAN_PHY_1M; + set_ext_conn_params(conn, p); + + p++; + plen += sizeof(*p); + } + + if (scan_2m(hdev)) { + cp->phys |= LE_SCAN_PHY_2M; + set_ext_conn_params(conn, p); + + p++; + plen += sizeof(*p); + } + + if (scan_coded(hdev)) { + cp->phys |= LE_SCAN_PHY_CODED; + set_ext_conn_params(conn, p); + + plen += sizeof(*p); + } + + return __hci_cmd_sync_status_sk(hdev, HCI_OP_LE_EXT_CREATE_CONN, + plen, data, + HCI_EV_LE_ENHANCED_CONN_COMPLETE, + conn->conn_timeout, NULL); +} + +int hci_le_create_conn_sync(struct hci_dev *hdev, struct hci_conn *conn) +{ + struct hci_cp_le_create_conn cp; + struct hci_conn_params *params; + u8 own_addr_type; + int err; + + /* If requested to connect as peripheral use directed advertising */ + if (conn->role == HCI_ROLE_SLAVE) { + /* If we're active scanning and simultaneous roles is not + * enabled simply reject the attempt. + */ + if (hci_dev_test_flag(hdev, HCI_LE_SCAN) && + hdev->le_scan_type == LE_SCAN_ACTIVE && + !hci_dev_test_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES)) { + hci_conn_del(conn); + return -EBUSY; + } + + /* Pause advertising while doing directed advertising. */ + hci_pause_advertising_sync(hdev); + + err = hci_le_directed_advertising_sync(hdev, conn); + goto done; + } + + /* Disable advertising if simultaneous roles is not in use. 
*/ + if (!hci_dev_test_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES)) + hci_pause_advertising_sync(hdev); + + params = hci_conn_params_lookup(hdev, &conn->dst, conn->dst_type); + if (params) { + conn->le_conn_min_interval = params->conn_min_interval; + conn->le_conn_max_interval = params->conn_max_interval; + conn->le_conn_latency = params->conn_latency; + conn->le_supv_timeout = params->supervision_timeout; + } else { + conn->le_conn_min_interval = hdev->le_conn_min_interval; + conn->le_conn_max_interval = hdev->le_conn_max_interval; + conn->le_conn_latency = hdev->le_conn_latency; + conn->le_supv_timeout = hdev->le_supv_timeout; + } + + /* If controller is scanning, we stop it since some controllers are + * not able to scan and connect at the same time. Also set the + * HCI_LE_SCAN_INTERRUPTED flag so that the command complete + * handler for scan disabling knows to set the correct discovery + * state. + */ + if (hci_dev_test_flag(hdev, HCI_LE_SCAN)) { + hci_scan_disable_sync(hdev); + hci_dev_set_flag(hdev, HCI_LE_SCAN_INTERRUPTED); + } + + /* Update random address, but set require_privacy to false so + * that we never connect with an non-resolvable address. + */ + err = hci_update_random_address_sync(hdev, false, conn_use_rpa(conn), + &own_addr_type); + if (err) + goto done; + + if (use_ext_conn(hdev)) { + err = hci_le_ext_create_conn_sync(hdev, conn, own_addr_type); + goto done; + } + + memset(&cp, 0, sizeof(cp)); + + cp.scan_interval = cpu_to_le16(hdev->le_scan_int_connect); + cp.scan_window = cpu_to_le16(hdev->le_scan_window_connect); + + bacpy(&cp.peer_addr, &conn->dst); + cp.peer_addr_type = conn->dst_type; + cp.own_address_type = own_addr_type; + cp.conn_interval_min = cpu_to_le16(conn->le_conn_min_interval); + cp.conn_interval_max = cpu_to_le16(conn->le_conn_max_interval); + cp.conn_latency = cpu_to_le16(conn->le_conn_latency); + cp.supervision_timeout = cpu_to_le16(conn->le_supv_timeout); + cp.min_ce_len = cpu_to_le16(0x0000); + cp.max_ce_len = cpu_to_le16(0x0000); + + /* BLUETOOTH CORE SPECIFICATION Version 5.3 | Vol 4, Part E page 2261: + * + * If this event is unmasked and the HCI_LE_Connection_Complete event + * is unmasked, only the HCI_LE_Enhanced_Connection_Complete event is + * sent when a new connection has been created. + */ + err = __hci_cmd_sync_status_sk(hdev, HCI_OP_LE_CREATE_CONN, + sizeof(cp), &cp, + use_enhanced_conn_complete(hdev) ? + HCI_EV_LE_ENHANCED_CONN_COMPLETE : + HCI_EV_LE_CONN_COMPLETE, + conn->conn_timeout, NULL); + +done: + if (err == -ETIMEDOUT) + hci_le_connect_cancel_sync(hdev, conn); + + /* Re-enable advertising after the connection attempt is finished. 
*/ + hci_resume_advertising_sync(hdev); + return err; +} + +int hci_le_remove_cig_sync(struct hci_dev *hdev, u8 handle) +{ + struct hci_cp_le_remove_cig cp; + + memset(&cp, 0, sizeof(cp)); + cp.cig_id = handle; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_REMOVE_CIG, sizeof(cp), + &cp, HCI_CMD_TIMEOUT); +} + +int hci_le_big_terminate_sync(struct hci_dev *hdev, u8 handle) +{ + struct hci_cp_le_big_term_sync cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = handle; + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_BIG_TERM_SYNC, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_le_pa_terminate_sync(struct hci_dev *hdev, u16 handle) +{ + struct hci_cp_le_pa_term_sync cp; + + memset(&cp, 0, sizeof(cp)); + cp.handle = cpu_to_le16(handle); + + return __hci_cmd_sync_status(hdev, HCI_OP_LE_PA_TERM_SYNC, + sizeof(cp), &cp, HCI_CMD_TIMEOUT); +} + +int hci_get_random_address(struct hci_dev *hdev, bool require_privacy, + bool use_rpa, struct adv_info *adv_instance, + u8 *own_addr_type, bdaddr_t *rand_addr) +{ + int err; + + bacpy(rand_addr, BDADDR_ANY); + + /* If privacy is enabled use a resolvable private address. If + * current RPA has expired then generate a new one. + */ + if (use_rpa) { + /* If Controller supports LL Privacy use own address type is + * 0x03 + */ + if (use_ll_privacy(hdev)) + *own_addr_type = ADDR_LE_DEV_RANDOM_RESOLVED; + else + *own_addr_type = ADDR_LE_DEV_RANDOM; + + if (adv_instance) { + if (adv_rpa_valid(adv_instance)) + return 0; + } else { + if (rpa_valid(hdev)) + return 0; + } + + err = smp_generate_rpa(hdev, hdev->irk, &hdev->rpa); + if (err < 0) { + bt_dev_err(hdev, "failed to generate new RPA"); + return err; + } + + bacpy(rand_addr, &hdev->rpa); + + return 0; + } + + /* In case of required privacy without resolvable private address, + * use an non-resolvable private address. This is useful for + * non-connectable advertising. + */ + if (require_privacy) { + bdaddr_t nrpa; + + while (true) { + /* The non-resolvable private address is generated + * from random six bytes with the two most significant + * bits cleared. + */ + get_random_bytes(&nrpa, 6); + nrpa.b[5] &= 0x3f; + + /* The non-resolvable private address shall not be + * equal to the public address. + */ + if (bacmp(&hdev->bdaddr, &nrpa)) + break; + } + + *own_addr_type = ADDR_LE_DEV_RANDOM; + bacpy(rand_addr, &nrpa); + + return 0; + } + + /* No privacy so use a public address. */ + *own_addr_type = ADDR_LE_DEV_PUBLIC; + + return 0; +} + +static int _update_adv_data_sync(struct hci_dev *hdev, void *data) +{ + u8 instance = PTR_ERR(data); + + return hci_update_adv_data_sync(hdev, instance); +} + +int hci_update_adv_data(struct hci_dev *hdev, u8 instance) +{ + return hci_cmd_sync_queue(hdev, _update_adv_data_sync, + ERR_PTR(instance), NULL); +} diff --git a/net/bluetooth/hci_sysfs.c b/net/bluetooth/hci_sysfs.c new file mode 100644 index 000000000..633b82d54 --- /dev/null +++ b/net/bluetooth/hci_sysfs.c @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Bluetooth HCI driver model support. */ + +#include + +#include +#include + +static struct class *bt_class; + +static void bt_link_release(struct device *dev) +{ + struct hci_conn *conn = to_hci_conn(dev); + kfree(conn); +} + +static const struct device_type bt_link = { + .name = "link", + .release = bt_link_release, +}; + +/* + * The rfcomm tty device will possibly retain even when conn + * is down, and sysfs doesn't support move zombie device, + * so we should move the device before conn device is destroyed. 
+ */ +static int __match_tty(struct device *dev, void *data) +{ + return !strncmp(dev_name(dev), "rfcomm", 6); +} + +void hci_conn_init_sysfs(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + bt_dev_dbg(hdev, "conn %p", conn); + + conn->dev.type = &bt_link; + conn->dev.class = bt_class; + conn->dev.parent = &hdev->dev; + + device_initialize(&conn->dev); +} + +void hci_conn_add_sysfs(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (device_is_registered(&conn->dev)) + return; + + dev_set_name(&conn->dev, "%s:%d", hdev->name, conn->handle); + + if (device_add(&conn->dev) < 0) + bt_dev_err(hdev, "failed to register connection device"); +} + +void hci_conn_del_sysfs(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (!device_is_registered(&conn->dev)) { + /* If device_add() has *not* succeeded, use *only* put_device() + * to drop the reference count. + */ + put_device(&conn->dev); + return; + } + + while (1) { + struct device *dev; + + dev = device_find_child(&conn->dev, NULL, __match_tty); + if (!dev) + break; + device_move(dev, NULL, DPM_ORDER_DEV_LAST); + put_device(dev); + } + + device_unregister(&conn->dev); +} + +static void bt_host_release(struct device *dev) +{ + struct hci_dev *hdev = to_hci_dev(dev); + + if (hci_dev_test_flag(hdev, HCI_UNREGISTER)) + hci_release_dev(hdev); + else + kfree(hdev); + module_put(THIS_MODULE); +} + +static const struct device_type bt_host = { + .name = "host", + .release = bt_host_release, +}; + +void hci_init_sysfs(struct hci_dev *hdev) +{ + struct device *dev = &hdev->dev; + + dev->type = &bt_host; + dev->class = bt_class; + + __module_get(THIS_MODULE); + device_initialize(dev); +} + +int __init bt_sysfs_init(void) +{ + bt_class = class_create(THIS_MODULE, "bluetooth"); + + return PTR_ERR_OR_ZERO(bt_class); +} + +void bt_sysfs_cleanup(void) +{ + class_destroy(bt_class); +} diff --git a/net/bluetooth/hidp/Kconfig b/net/bluetooth/hidp/Kconfig new file mode 100644 index 000000000..14100f341 --- /dev/null +++ b/net/bluetooth/hidp/Kconfig @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: GPL-2.0-only +config BT_HIDP + tristate "HIDP protocol support" + depends on BT_BREDR && INPUT + select HID + help + HIDP (Human Interface Device Protocol) is a transport layer + for HID reports. HIDP is required for the Bluetooth Human + Interface Device Profile. + + Say Y here to compile HIDP support into the kernel or say M to + compile it as module (hidp). + diff --git a/net/bluetooth/hidp/Makefile b/net/bluetooth/hidp/Makefile new file mode 100644 index 000000000..f41b0aa02 --- /dev/null +++ b/net/bluetooth/hidp/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Bluetooth HIDP layer +# + +obj-$(CONFIG_BT_HIDP) += hidp.o + +hidp-objs := core.o sock.o diff --git a/net/bluetooth/hidp/core.c b/net/bluetooth/hidp/core.c new file mode 100644 index 000000000..82cc15ad9 --- /dev/null +++ b/net/bluetooth/hidp/core.c @@ -0,0 +1,1480 @@ +/* + HIDP implementation for Linux Bluetooth stack (BlueZ). 
+ Copyright (C) 2003-2004 Marcel Holtmann + Copyright (C) 2013 David Herrmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "hidp.h" + +#define VERSION "1.2" + +static DECLARE_RWSEM(hidp_session_sem); +static DECLARE_WAIT_QUEUE_HEAD(hidp_session_wq); +static LIST_HEAD(hidp_session_list); + +static unsigned char hidp_keycode[256] = { + 0, 0, 0, 0, 30, 48, 46, 32, 18, 33, 34, 35, 23, 36, + 37, 38, 50, 49, 24, 25, 16, 19, 31, 20, 22, 47, 17, 45, + 21, 44, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 28, 1, + 14, 15, 57, 12, 13, 26, 27, 43, 43, 39, 40, 41, 51, 52, + 53, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 87, 88, + 99, 70, 119, 110, 102, 104, 111, 107, 109, 106, 105, 108, 103, 69, + 98, 55, 74, 78, 96, 79, 80, 81, 75, 76, 77, 71, 72, 73, + 82, 83, 86, 127, 116, 117, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 134, 138, 130, 132, 128, 129, 131, 137, 133, 135, + 136, 113, 115, 114, 0, 0, 0, 121, 0, 89, 93, 124, 92, 94, + 95, 0, 0, 0, 122, 123, 90, 91, 85, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 29, 42, 56, 125, 97, 54, 100, 126, 164, 166, 165, 163, 161, 115, + 114, 113, 150, 158, 159, 128, 136, 177, 178, 176, 142, 152, 173, 140 +}; + +static unsigned char hidp_mkeyspat[] = { 0x01, 0x01, 0x01, 0x01, 0x01, 0x01 }; + +static int hidp_session_probe(struct l2cap_conn *conn, + struct l2cap_user *user); +static void hidp_session_remove(struct l2cap_conn *conn, + struct l2cap_user *user); +static int hidp_session_thread(void *arg); +static void hidp_session_terminate(struct hidp_session *s); + +static void hidp_copy_session(struct hidp_session *session, struct hidp_conninfo *ci) +{ + u32 valid_flags = 0; + memset(ci, 0, sizeof(*ci)); + bacpy(&ci->bdaddr, &session->bdaddr); + + ci->flags = session->flags & valid_flags; + ci->state = BT_CONNECTED; + + if (session->input) { + ci->vendor = session->input->id.vendor; + ci->product = session->input->id.product; + ci->version = session->input->id.version; + if (session->input->name) + strscpy(ci->name, session->input->name, 128); + else + strscpy(ci->name, "HID Boot Device", 128); + } else if (session->hid) { + ci->vendor = session->hid->vendor; + ci->product = session->hid->product; + ci->version = session->hid->version; + strscpy(ci->name, session->hid->name, 128); + } +} + +/* assemble skb, queue message on @transmit and wake up the session thread */ +static int 
hidp_send_message(struct hidp_session *session, struct socket *sock, + struct sk_buff_head *transmit, unsigned char hdr, + const unsigned char *data, int size) +{ + struct sk_buff *skb; + struct sock *sk = sock->sk; + int ret; + + BT_DBG("session %p data %p size %d", session, data, size); + + if (atomic_read(&session->terminate)) + return -EIO; + + skb = alloc_skb(size + 1, GFP_ATOMIC); + if (!skb) { + BT_ERR("Can't allocate memory for new frame"); + return -ENOMEM; + } + + skb_put_u8(skb, hdr); + if (data && size > 0) { + skb_put_data(skb, data, size); + ret = size; + } else { + ret = 0; + } + + skb_queue_tail(transmit, skb); + wake_up_interruptible(sk_sleep(sk)); + + return ret; +} + +static int hidp_send_ctrl_message(struct hidp_session *session, + unsigned char hdr, const unsigned char *data, + int size) +{ + return hidp_send_message(session, session->ctrl_sock, + &session->ctrl_transmit, hdr, data, size); +} + +static int hidp_send_intr_message(struct hidp_session *session, + unsigned char hdr, const unsigned char *data, + int size) +{ + return hidp_send_message(session, session->intr_sock, + &session->intr_transmit, hdr, data, size); +} + +static int hidp_input_event(struct input_dev *dev, unsigned int type, + unsigned int code, int value) +{ + struct hidp_session *session = input_get_drvdata(dev); + unsigned char newleds; + unsigned char hdr, data[2]; + + BT_DBG("session %p type %d code %d value %d", + session, type, code, value); + + if (type != EV_LED) + return -1; + + newleds = (!!test_bit(LED_KANA, dev->led) << 3) | + (!!test_bit(LED_COMPOSE, dev->led) << 3) | + (!!test_bit(LED_SCROLLL, dev->led) << 2) | + (!!test_bit(LED_CAPSL, dev->led) << 1) | + (!!test_bit(LED_NUML, dev->led) << 0); + + if (session->leds == newleds) + return 0; + + session->leds = newleds; + + hdr = HIDP_TRANS_DATA | HIDP_DATA_RTYPE_OUPUT; + data[0] = 0x01; + data[1] = newleds; + + return hidp_send_intr_message(session, hdr, data, 2); +} + +static void hidp_input_report(struct hidp_session *session, struct sk_buff *skb) +{ + struct input_dev *dev = session->input; + unsigned char *keys = session->keys; + unsigned char *udata = skb->data + 1; + signed char *sdata = skb->data + 1; + int i, size = skb->len - 1; + + switch (skb->data[0]) { + case 0x01: /* Keyboard report */ + for (i = 0; i < 8; i++) + input_report_key(dev, hidp_keycode[i + 224], (udata[0] >> i) & 1); + + /* If all the key codes have been set to 0x01, it means + * too many keys were pressed at the same time. 
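+ *
+ * For context: after the report ID byte, a boot protocol keyboard
+ * report is laid out as one modifier bitmap byte, one reserved
+ * byte and six key code bytes, so udata[0] holds the modifiers
+ * and udata[2]..udata[7] hold the pressed keys. A value of 0x01
+ * in every key slot is the HID "phantom state" (rollover error).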
*/ + if (!memcmp(udata + 2, hidp_mkeyspat, 6)) + break; + + for (i = 2; i < 8; i++) { + if (keys[i] > 3 && memscan(udata + 2, keys[i], 6) == udata + 8) { + if (hidp_keycode[keys[i]]) + input_report_key(dev, hidp_keycode[keys[i]], 0); + else + BT_ERR("Unknown key (scancode %#x) released.", keys[i]); + } + + if (udata[i] > 3 && memscan(keys + 2, udata[i], 6) == keys + 8) { + if (hidp_keycode[udata[i]]) + input_report_key(dev, hidp_keycode[udata[i]], 1); + else + BT_ERR("Unknown key (scancode %#x) pressed.", udata[i]); + } + } + + memcpy(keys, udata, 8); + break; + + case 0x02: /* Mouse report */ + input_report_key(dev, BTN_LEFT, sdata[0] & 0x01); + input_report_key(dev, BTN_RIGHT, sdata[0] & 0x02); + input_report_key(dev, BTN_MIDDLE, sdata[0] & 0x04); + input_report_key(dev, BTN_SIDE, sdata[0] & 0x08); + input_report_key(dev, BTN_EXTRA, sdata[0] & 0x10); + + input_report_rel(dev, REL_X, sdata[1]); + input_report_rel(dev, REL_Y, sdata[2]); + + if (size > 3) + input_report_rel(dev, REL_WHEEL, sdata[3]); + break; + } + + input_sync(dev); +} + +static int hidp_get_raw_report(struct hid_device *hid, + unsigned char report_number, + unsigned char *data, size_t count, + unsigned char report_type) +{ + struct hidp_session *session = hid->driver_data; + struct sk_buff *skb; + size_t len; + int numbered_reports = hid->report_enum[report_type].numbered; + int ret; + + if (atomic_read(&session->terminate)) + return -EIO; + + switch (report_type) { + case HID_FEATURE_REPORT: + report_type = HIDP_TRANS_GET_REPORT | HIDP_DATA_RTYPE_FEATURE; + break; + case HID_INPUT_REPORT: + report_type = HIDP_TRANS_GET_REPORT | HIDP_DATA_RTYPE_INPUT; + break; + case HID_OUTPUT_REPORT: + report_type = HIDP_TRANS_GET_REPORT | HIDP_DATA_RTYPE_OUPUT; + break; + default: + return -EINVAL; + } + + if (mutex_lock_interruptible(&session->report_mutex)) + return -ERESTARTSYS; + + /* Set up our wait, and send the report request to the device. */ + session->waiting_report_type = report_type & HIDP_DATA_RTYPE_MASK; + session->waiting_report_number = numbered_reports ? report_number : -1; + set_bit(HIDP_WAITING_FOR_RETURN, &session->flags); + data[0] = report_number; + ret = hidp_send_ctrl_message(session, report_type, data, 1); + if (ret < 0) + goto err; + + /* Wait for the return of the report. The returned report + gets put in session->report_return. */ + while (test_bit(HIDP_WAITING_FOR_RETURN, &session->flags) && + !atomic_read(&session->terminate)) { + int res; + + res = wait_event_interruptible_timeout(session->report_queue, + !test_bit(HIDP_WAITING_FOR_RETURN, &session->flags) + || atomic_read(&session->terminate), + 5*HZ); + if (res == 0) { + /* timeout */ + ret = -EIO; + goto err; + } + if (res < 0) { + /* signal */ + ret = -ERESTARTSYS; + goto err; + } + } + + skb = session->report_return; + if (skb) { + len = skb->len < count ? skb->len : count; + memcpy(data, skb->data, len); + + kfree_skb(skb); + session->report_return = NULL; + } else { + /* Device returned a HANDSHAKE, indicating protocol error. 
*/ + len = -EIO; + } + + clear_bit(HIDP_WAITING_FOR_RETURN, &session->flags); + mutex_unlock(&session->report_mutex); + + return len; + +err: + clear_bit(HIDP_WAITING_FOR_RETURN, &session->flags); + mutex_unlock(&session->report_mutex); + return ret; +} + +static int hidp_set_raw_report(struct hid_device *hid, unsigned char reportnum, + unsigned char *data, size_t count, + unsigned char report_type) +{ + struct hidp_session *session = hid->driver_data; + int ret; + + switch (report_type) { + case HID_FEATURE_REPORT: + report_type = HIDP_TRANS_SET_REPORT | HIDP_DATA_RTYPE_FEATURE; + break; + case HID_INPUT_REPORT: + report_type = HIDP_TRANS_SET_REPORT | HIDP_DATA_RTYPE_INPUT; + break; + case HID_OUTPUT_REPORT: + report_type = HIDP_TRANS_SET_REPORT | HIDP_DATA_RTYPE_OUPUT; + break; + default: + return -EINVAL; + } + + if (mutex_lock_interruptible(&session->report_mutex)) + return -ERESTARTSYS; + + /* Set up our wait, and send the report request to the device. */ + data[0] = reportnum; + set_bit(HIDP_WAITING_FOR_SEND_ACK, &session->flags); + ret = hidp_send_ctrl_message(session, report_type, data, count); + if (ret < 0) + goto err; + + /* Wait for the ACK from the device. */ + while (test_bit(HIDP_WAITING_FOR_SEND_ACK, &session->flags) && + !atomic_read(&session->terminate)) { + int res; + + res = wait_event_interruptible_timeout(session->report_queue, + !test_bit(HIDP_WAITING_FOR_SEND_ACK, &session->flags) + || atomic_read(&session->terminate), + 10*HZ); + if (res == 0) { + /* timeout */ + ret = -EIO; + goto err; + } + if (res < 0) { + /* signal */ + ret = -ERESTARTSYS; + goto err; + } + } + + if (!session->output_report_success) { + ret = -EIO; + goto err; + } + + ret = count; + +err: + clear_bit(HIDP_WAITING_FOR_SEND_ACK, &session->flags); + mutex_unlock(&session->report_mutex); + return ret; +} + +static int hidp_output_report(struct hid_device *hid, __u8 *data, size_t count) +{ + struct hidp_session *session = hid->driver_data; + + return hidp_send_intr_message(session, + HIDP_TRANS_DATA | HIDP_DATA_RTYPE_OUPUT, + data, count); +} + +static int hidp_raw_request(struct hid_device *hid, unsigned char reportnum, + __u8 *buf, size_t len, unsigned char rtype, + int reqtype) +{ + switch (reqtype) { + case HID_REQ_GET_REPORT: + return hidp_get_raw_report(hid, reportnum, buf, len, rtype); + case HID_REQ_SET_REPORT: + return hidp_set_raw_report(hid, reportnum, buf, len, rtype); + default: + return -EIO; + } +} + +static void hidp_idle_timeout(struct timer_list *t) +{ + struct hidp_session *session = from_timer(session, t, timer); + + /* The HIDP user-space API only contains calls to add and remove + * devices. There is no way to forward events of any kind. Therefore, + * we have to forcefully disconnect a device on idle-timeouts. This is + * unfortunate and weird API design, but it is spec-compliant and + * required for backwards-compatibility. Hence, on idle-timeout, we + * signal driver-detach events, so poll() will be woken up with an + * error-condition on both sockets. 
+ */ + + session->intr_sock->sk->sk_err = EUNATCH; + session->ctrl_sock->sk->sk_err = EUNATCH; + wake_up_interruptible(sk_sleep(session->intr_sock->sk)); + wake_up_interruptible(sk_sleep(session->ctrl_sock->sk)); + + hidp_session_terminate(session); +} + +static void hidp_set_timer(struct hidp_session *session) +{ + if (session->idle_to > 0) + mod_timer(&session->timer, jiffies + HZ * session->idle_to); +} + +static void hidp_del_timer(struct hidp_session *session) +{ + if (session->idle_to > 0) + del_timer_sync(&session->timer); +} + +static void hidp_process_report(struct hidp_session *session, int type, + const u8 *data, unsigned int len, int intr) +{ + if (len > HID_MAX_BUFFER_SIZE) + len = HID_MAX_BUFFER_SIZE; + + memcpy(session->input_buf, data, len); + hid_input_report(session->hid, type, session->input_buf, len, intr); +} + +static void hidp_process_handshake(struct hidp_session *session, + unsigned char param) +{ + BT_DBG("session %p param 0x%02x", session, param); + session->output_report_success = 0; /* default condition */ + + switch (param) { + case HIDP_HSHK_SUCCESSFUL: + /* FIXME: Call into SET_ GET_ handlers here */ + session->output_report_success = 1; + break; + + case HIDP_HSHK_NOT_READY: + case HIDP_HSHK_ERR_INVALID_REPORT_ID: + case HIDP_HSHK_ERR_UNSUPPORTED_REQUEST: + case HIDP_HSHK_ERR_INVALID_PARAMETER: + if (test_and_clear_bit(HIDP_WAITING_FOR_RETURN, &session->flags)) + wake_up_interruptible(&session->report_queue); + + /* FIXME: Call into SET_ GET_ handlers here */ + break; + + case HIDP_HSHK_ERR_UNKNOWN: + break; + + case HIDP_HSHK_ERR_FATAL: + /* Device requests a reboot, as this is the only way this error + * can be recovered. */ + hidp_send_ctrl_message(session, + HIDP_TRANS_HID_CONTROL | HIDP_CTRL_SOFT_RESET, NULL, 0); + break; + + default: + hidp_send_ctrl_message(session, + HIDP_TRANS_HANDSHAKE | HIDP_HSHK_ERR_INVALID_PARAMETER, NULL, 0); + break; + } + + /* Wake up the waiting thread. */ + if (test_and_clear_bit(HIDP_WAITING_FOR_SEND_ACK, &session->flags)) + wake_up_interruptible(&session->report_queue); +} + +static void hidp_process_hid_control(struct hidp_session *session, + unsigned char param) +{ + BT_DBG("session %p param 0x%02x", session, param); + + if (param == HIDP_CTRL_VIRTUAL_CABLE_UNPLUG) { + /* Flush the transmit queues */ + skb_queue_purge(&session->ctrl_transmit); + skb_queue_purge(&session->intr_transmit); + + hidp_session_terminate(session); + } +} + +/* Returns true if the passed-in skb should be freed by the caller. */ +static int hidp_process_data(struct hidp_session *session, struct sk_buff *skb, + unsigned char param) +{ + int done_with_skb = 1; + BT_DBG("session %p skb %p len %u param 0x%02x", session, skb, skb->len, param); + + switch (param) { + case HIDP_DATA_RTYPE_INPUT: + hidp_set_timer(session); + + if (session->input) + hidp_input_report(session, skb); + + if (session->hid) + hidp_process_report(session, HID_INPUT_REPORT, + skb->data, skb->len, 0); + break; + + case HIDP_DATA_RTYPE_OTHER: + case HIDP_DATA_RTYPE_OUPUT: + case HIDP_DATA_RTYPE_FEATURE: + break; + + default: + hidp_send_ctrl_message(session, + HIDP_TRANS_HANDSHAKE | HIDP_HSHK_ERR_INVALID_PARAMETER, NULL, 0); + } + + if (test_bit(HIDP_WAITING_FOR_RETURN, &session->flags) && + param == session->waiting_report_type) { + if (session->waiting_report_number < 0 || + session->waiting_report_number == skb->data[0]) { + /* hidp_get_raw_report() is waiting on this report. 
*/ + session->report_return = skb; + done_with_skb = 0; + clear_bit(HIDP_WAITING_FOR_RETURN, &session->flags); + wake_up_interruptible(&session->report_queue); + } + } + + return done_with_skb; +} + +static void hidp_recv_ctrl_frame(struct hidp_session *session, + struct sk_buff *skb) +{ + unsigned char hdr, type, param; + int free_skb = 1; + + BT_DBG("session %p skb %p len %u", session, skb, skb->len); + + hdr = skb->data[0]; + skb_pull(skb, 1); + + type = hdr & HIDP_HEADER_TRANS_MASK; + param = hdr & HIDP_HEADER_PARAM_MASK; + + switch (type) { + case HIDP_TRANS_HANDSHAKE: + hidp_process_handshake(session, param); + break; + + case HIDP_TRANS_HID_CONTROL: + hidp_process_hid_control(session, param); + break; + + case HIDP_TRANS_DATA: + free_skb = hidp_process_data(session, skb, param); + break; + + default: + hidp_send_ctrl_message(session, + HIDP_TRANS_HANDSHAKE | HIDP_HSHK_ERR_UNSUPPORTED_REQUEST, NULL, 0); + break; + } + + if (free_skb) + kfree_skb(skb); +} + +static void hidp_recv_intr_frame(struct hidp_session *session, + struct sk_buff *skb) +{ + unsigned char hdr; + + BT_DBG("session %p skb %p len %u", session, skb, skb->len); + + hdr = skb->data[0]; + skb_pull(skb, 1); + + if (hdr == (HIDP_TRANS_DATA | HIDP_DATA_RTYPE_INPUT)) { + hidp_set_timer(session); + + if (session->input) + hidp_input_report(session, skb); + + if (session->hid) { + hidp_process_report(session, HID_INPUT_REPORT, + skb->data, skb->len, 1); + BT_DBG("report len %d", skb->len); + } + } else { + BT_DBG("Unsupported protocol header 0x%02x", hdr); + } + + kfree_skb(skb); +} + +static int hidp_send_frame(struct socket *sock, unsigned char *data, int len) +{ + struct kvec iv = { data, len }; + struct msghdr msg; + + BT_DBG("sock %p data %p len %d", sock, data, len); + + if (!len) + return 0; + + memset(&msg, 0, sizeof(msg)); + + return kernel_sendmsg(sock, &msg, &iv, 1, len); +} + +/* dequeue message from @transmit and send via @sock */ +static void hidp_process_transmit(struct hidp_session *session, + struct sk_buff_head *transmit, + struct socket *sock) +{ + struct sk_buff *skb; + int ret; + + BT_DBG("session %p", session); + + while ((skb = skb_dequeue(transmit))) { + ret = hidp_send_frame(sock, skb->data, skb->len); + if (ret == -EAGAIN) { + skb_queue_head(transmit, skb); + break; + } else if (ret < 0) { + hidp_session_terminate(session); + kfree_skb(skb); + break; + } + + hidp_set_timer(session); + kfree_skb(skb); + } +} + +static int hidp_setup_input(struct hidp_session *session, + const struct hidp_connadd_req *req) +{ + struct input_dev *input; + int i; + + input = input_allocate_device(); + if (!input) + return -ENOMEM; + + session->input = input; + + input_set_drvdata(input, session); + + input->name = "Bluetooth HID Boot Protocol Device"; + + input->id.bustype = BUS_BLUETOOTH; + input->id.vendor = req->vendor; + input->id.product = req->product; + input->id.version = req->version; + + if (req->subclass & 0x40) { + set_bit(EV_KEY, input->evbit); + set_bit(EV_LED, input->evbit); + set_bit(EV_REP, input->evbit); + + set_bit(LED_NUML, input->ledbit); + set_bit(LED_CAPSL, input->ledbit); + set_bit(LED_SCROLLL, input->ledbit); + set_bit(LED_COMPOSE, input->ledbit); + set_bit(LED_KANA, input->ledbit); + + for (i = 0; i < sizeof(hidp_keycode); i++) + set_bit(hidp_keycode[i], input->keybit); + clear_bit(0, input->keybit); + } + + if (req->subclass & 0x80) { + input->evbit[0] = BIT_MASK(EV_KEY) | BIT_MASK(EV_REL); + input->keybit[BIT_WORD(BTN_MOUSE)] = BIT_MASK(BTN_LEFT) | + BIT_MASK(BTN_RIGHT) | 
BIT_MASK(BTN_MIDDLE); + input->relbit[0] = BIT_MASK(REL_X) | BIT_MASK(REL_Y); + input->keybit[BIT_WORD(BTN_MOUSE)] |= BIT_MASK(BTN_SIDE) | + BIT_MASK(BTN_EXTRA); + input->relbit[0] |= BIT_MASK(REL_WHEEL); + } + + input->dev.parent = &session->conn->hcon->dev; + + input->event = hidp_input_event; + + return 0; +} + +static int hidp_open(struct hid_device *hid) +{ + return 0; +} + +static void hidp_close(struct hid_device *hid) +{ +} + +static int hidp_parse(struct hid_device *hid) +{ + struct hidp_session *session = hid->driver_data; + + return hid_parse_report(session->hid, session->rd_data, + session->rd_size); +} + +static int hidp_start(struct hid_device *hid) +{ + return 0; +} + +static void hidp_stop(struct hid_device *hid) +{ + struct hidp_session *session = hid->driver_data; + + skb_queue_purge(&session->ctrl_transmit); + skb_queue_purge(&session->intr_transmit); + + hid->claimed = 0; +} + +struct hid_ll_driver hidp_hid_driver = { + .parse = hidp_parse, + .start = hidp_start, + .stop = hidp_stop, + .open = hidp_open, + .close = hidp_close, + .raw_request = hidp_raw_request, + .output_report = hidp_output_report, +}; +EXPORT_SYMBOL_GPL(hidp_hid_driver); + +/* This function sets up the hid device. It does not add it + to the HID system. That is done in hidp_add_connection(). */ +static int hidp_setup_hid(struct hidp_session *session, + const struct hidp_connadd_req *req) +{ + struct hid_device *hid; + int err; + + session->rd_data = memdup_user(req->rd_data, req->rd_size); + if (IS_ERR(session->rd_data)) + return PTR_ERR(session->rd_data); + + session->rd_size = req->rd_size; + + hid = hid_allocate_device(); + if (IS_ERR(hid)) { + err = PTR_ERR(hid); + goto fault; + } + + session->hid = hid; + + hid->driver_data = session; + + hid->bus = BUS_BLUETOOTH; + hid->vendor = req->vendor; + hid->product = req->product; + hid->version = req->version; + hid->country = req->country; + + strscpy(hid->name, req->name, sizeof(hid->name)); + + snprintf(hid->phys, sizeof(hid->phys), "%pMR", + &l2cap_pi(session->ctrl_sock->sk)->chan->src); + + /* NOTE: Some device modules depend on the dst address being stored in + * uniq. Please be aware of this before making changes to this behavior. 
+ */ + snprintf(hid->uniq, sizeof(hid->uniq), "%pMR", + &l2cap_pi(session->ctrl_sock->sk)->chan->dst); + + hid->dev.parent = &session->conn->hcon->dev; + hid->ll_driver = &hidp_hid_driver; + + /* True if device is blocked in drivers/hid/hid-quirks.c */ + if (hid_ignore(hid)) { + hid_destroy_device(session->hid); + session->hid = NULL; + return -ENODEV; + } + + return 0; + +fault: + kfree(session->rd_data); + session->rd_data = NULL; + + return err; +} + +/* initialize session devices */ +static int hidp_session_dev_init(struct hidp_session *session, + const struct hidp_connadd_req *req) +{ + int ret; + + if (req->rd_size > 0) { + ret = hidp_setup_hid(session, req); + if (ret && ret != -ENODEV) + return ret; + } + + if (!session->hid) { + ret = hidp_setup_input(session, req); + if (ret < 0) + return ret; + } + + return 0; +} + +/* destroy session devices */ +static void hidp_session_dev_destroy(struct hidp_session *session) +{ + if (session->hid) + put_device(&session->hid->dev); + else if (session->input) + input_put_device(session->input); + + kfree(session->rd_data); + session->rd_data = NULL; +} + +/* add HID/input devices to their underlying bus systems */ +static int hidp_session_dev_add(struct hidp_session *session) +{ + int ret; + + /* Both HID and input systems drop a ref-count when unregistering the + * device but they don't take a ref-count when registering them. Work + * around this by explicitly taking a refcount during registration + * which is dropped automatically by unregistering the devices. */ + + if (session->hid) { + ret = hid_add_device(session->hid); + if (ret) + return ret; + get_device(&session->hid->dev); + } else if (session->input) { + ret = input_register_device(session->input); + if (ret) + return ret; + input_get_device(session->input); + } + + return 0; +} + +/* remove HID/input devices from their bus systems */ +static void hidp_session_dev_del(struct hidp_session *session) +{ + if (session->hid) + hid_destroy_device(session->hid); + else if (session->input) + input_unregister_device(session->input); +} + +/* + * Asynchronous device registration + * HID device drivers might want to perform I/O during initialization to + * detect device types. Therefore, call device registration in a separate + * worker so the HIDP thread can schedule I/O operations. + * Note that this must be called after the worker thread was initialized + * successfully. This will then add the devices and increase session state + * on success, otherwise it will terminate the session thread. + */ +static void hidp_session_dev_work(struct work_struct *work) +{ + struct hidp_session *session = container_of(work, + struct hidp_session, + dev_init); + int ret; + + ret = hidp_session_dev_add(session); + if (!ret) + atomic_inc(&session->state); + else + hidp_session_terminate(session); +} + +/* + * Create new session object + * Allocate session object, initialize static fields, copy input data into the + * object and take a reference to all sub-objects. + * This returns 0 on success and puts a pointer to the new session object in + * \out. Otherwise, an error code is returned. + * The new session object has an initial ref-count of 1. 
+ */ +static int hidp_session_new(struct hidp_session **out, const bdaddr_t *bdaddr, + struct socket *ctrl_sock, + struct socket *intr_sock, + const struct hidp_connadd_req *req, + struct l2cap_conn *conn) +{ + struct hidp_session *session; + int ret; + struct bt_sock *ctrl, *intr; + + ctrl = bt_sk(ctrl_sock->sk); + intr = bt_sk(intr_sock->sk); + + session = kzalloc(sizeof(*session), GFP_KERNEL); + if (!session) + return -ENOMEM; + + /* object and runtime management */ + kref_init(&session->ref); + atomic_set(&session->state, HIDP_SESSION_IDLING); + init_waitqueue_head(&session->state_queue); + session->flags = req->flags & BIT(HIDP_BLUETOOTH_VENDOR_ID); + + /* connection management */ + bacpy(&session->bdaddr, bdaddr); + session->conn = l2cap_conn_get(conn); + session->user.probe = hidp_session_probe; + session->user.remove = hidp_session_remove; + INIT_LIST_HEAD(&session->user.list); + session->ctrl_sock = ctrl_sock; + session->intr_sock = intr_sock; + skb_queue_head_init(&session->ctrl_transmit); + skb_queue_head_init(&session->intr_transmit); + session->ctrl_mtu = min_t(uint, l2cap_pi(ctrl)->chan->omtu, + l2cap_pi(ctrl)->chan->imtu); + session->intr_mtu = min_t(uint, l2cap_pi(intr)->chan->omtu, + l2cap_pi(intr)->chan->imtu); + session->idle_to = req->idle_to; + + /* device management */ + INIT_WORK(&session->dev_init, hidp_session_dev_work); + timer_setup(&session->timer, hidp_idle_timeout, 0); + + /* session data */ + mutex_init(&session->report_mutex); + init_waitqueue_head(&session->report_queue); + + ret = hidp_session_dev_init(session, req); + if (ret) + goto err_free; + + get_file(session->intr_sock->file); + get_file(session->ctrl_sock->file); + *out = session; + return 0; + +err_free: + l2cap_conn_put(session->conn); + kfree(session); + return ret; +} + +/* increase ref-count of the given session by one */ +static void hidp_session_get(struct hidp_session *session) +{ + kref_get(&session->ref); +} + +/* release callback */ +static void session_free(struct kref *ref) +{ + struct hidp_session *session = container_of(ref, struct hidp_session, + ref); + + hidp_session_dev_destroy(session); + skb_queue_purge(&session->ctrl_transmit); + skb_queue_purge(&session->intr_transmit); + fput(session->intr_sock->file); + fput(session->ctrl_sock->file); + l2cap_conn_put(session->conn); + kfree(session); +} + +/* decrease ref-count of the given session by one */ +static void hidp_session_put(struct hidp_session *session) +{ + kref_put(&session->ref, session_free); +} + +/* + * Search the list of active sessions for a session with target address + * \bdaddr. You must hold at least a read-lock on \hidp_session_sem. As long as + * you do not release this lock, the session objects cannot vanish and you can + * safely take a reference to the session yourself. + */ +static struct hidp_session *__hidp_session_find(const bdaddr_t *bdaddr) +{ + struct hidp_session *session; + + list_for_each_entry(session, &hidp_session_list, list) { + if (!bacmp(bdaddr, &session->bdaddr)) + return session; + } + + return NULL; +} + +/* + * Same as __hidp_session_find() but no locks must be held. This also takes a + * reference of the returned session (if non-NULL) so you must drop this + * reference if you no longer use the object. 
+ */ +static struct hidp_session *hidp_session_find(const bdaddr_t *bdaddr) +{ + struct hidp_session *session; + + down_read(&hidp_session_sem); + + session = __hidp_session_find(bdaddr); + if (session) + hidp_session_get(session); + + up_read(&hidp_session_sem); + + return session; +} + +/* + * Start session synchronously + * This starts a session thread and waits until initialization + * is done or returns an error if it couldn't be started. + * If this returns 0 the session thread is up and running. You must call + * hipd_session_stop_sync() before deleting any runtime resources. + */ +static int hidp_session_start_sync(struct hidp_session *session) +{ + unsigned int vendor, product; + + if (session->hid) { + vendor = session->hid->vendor; + product = session->hid->product; + } else if (session->input) { + vendor = session->input->id.vendor; + product = session->input->id.product; + } else { + vendor = 0x0000; + product = 0x0000; + } + + session->task = kthread_run(hidp_session_thread, session, + "khidpd_%04x%04x", vendor, product); + if (IS_ERR(session->task)) + return PTR_ERR(session->task); + + while (atomic_read(&session->state) <= HIDP_SESSION_IDLING) + wait_event(session->state_queue, + atomic_read(&session->state) > HIDP_SESSION_IDLING); + + return 0; +} + +/* + * Terminate session thread + * Wake up session thread and notify it to stop. This is asynchronous and + * returns immediately. Call this whenever a runtime error occurs and you want + * the session to stop. + * Note: wake_up_interruptible() performs any necessary memory-barriers for us. + */ +static void hidp_session_terminate(struct hidp_session *session) +{ + atomic_inc(&session->terminate); + /* + * See the comment preceding the call to wait_woken() + * in hidp_session_run(). + */ + wake_up_interruptible(&hidp_session_wq); +} + +/* + * Probe HIDP session + * This is called from the l2cap_conn core when our l2cap_user object is bound + * to the hci-connection. We get the session via the \user object and can now + * start the session thread, link it into the global session list and + * schedule HID/input device registration. + * The global session-list owns its own reference to the session object so you + * can drop your own reference after registering the l2cap_user object. + */ +static int hidp_session_probe(struct l2cap_conn *conn, + struct l2cap_user *user) +{ + struct hidp_session *session = container_of(user, + struct hidp_session, + user); + struct hidp_session *s; + int ret; + + down_write(&hidp_session_sem); + + /* check that no other session for this device exists */ + s = __hidp_session_find(&session->bdaddr); + if (s) { + ret = -EEXIST; + goto out_unlock; + } + + if (session->input) { + ret = hidp_session_dev_add(session); + if (ret) + goto out_unlock; + } + + ret = hidp_session_start_sync(session); + if (ret) + goto out_del; + + /* HID device registration is async to allow I/O during probe */ + if (session->input) + atomic_inc(&session->state); + else + schedule_work(&session->dev_init); + + hidp_session_get(session); + list_add(&session->list, &hidp_session_list); + ret = 0; + goto out_unlock; + +out_del: + if (session->input) + hidp_session_dev_del(session); +out_unlock: + up_write(&hidp_session_sem); + return ret; +} + +/* + * Remove HIDP session + * Called from the l2cap_conn core when either we explicitly unregistered + * the l2cap_user object or if the underlying connection is shut down. 
+ * We signal the hidp-session thread to shut down, unregister the HID/input + * devices and unlink the session from the global list. + * This drops the reference to the session that is owned by the global + * session-list. + * Note: We _must_ not synchronosly wait for the session-thread to shut down. + * This is, because the session-thread might be waiting for an HCI lock that is + * held while we are called. Therefore, we only unregister the devices and + * notify the session-thread to terminate. The thread itself owns a reference + * to the session object so it can safely shut down. + */ +static void hidp_session_remove(struct l2cap_conn *conn, + struct l2cap_user *user) +{ + struct hidp_session *session = container_of(user, + struct hidp_session, + user); + + down_write(&hidp_session_sem); + + hidp_session_terminate(session); + + cancel_work_sync(&session->dev_init); + if (session->input || + atomic_read(&session->state) > HIDP_SESSION_PREPARING) + hidp_session_dev_del(session); + + list_del(&session->list); + + up_write(&hidp_session_sem); + + hidp_session_put(session); +} + +/* + * Session Worker + * This performs the actual main-loop of the HIDP worker. We first check + * whether the underlying connection is still alive, then parse all pending + * messages and finally send all outstanding messages. + */ +static void hidp_session_run(struct hidp_session *session) +{ + struct sock *ctrl_sk = session->ctrl_sock->sk; + struct sock *intr_sk = session->intr_sock->sk; + struct sk_buff *skb; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(&hidp_session_wq, &wait); + for (;;) { + /* + * This thread can be woken up two ways: + * - You call hidp_session_terminate() which sets the + * session->terminate flag and wakes this thread up. + * - Via modifying the socket state of ctrl/intr_sock. This + * thread is woken up by ->sk_state_changed(). + */ + + if (atomic_read(&session->terminate)) + break; + + if (ctrl_sk->sk_state != BT_CONNECTED || + intr_sk->sk_state != BT_CONNECTED) + break; + + /* parse incoming intr-skbs */ + while ((skb = skb_dequeue(&intr_sk->sk_receive_queue))) { + skb_orphan(skb); + if (!skb_linearize(skb)) + hidp_recv_intr_frame(session, skb); + else + kfree_skb(skb); + } + + /* send pending intr-skbs */ + hidp_process_transmit(session, &session->intr_transmit, + session->intr_sock); + + /* parse incoming ctrl-skbs */ + while ((skb = skb_dequeue(&ctrl_sk->sk_receive_queue))) { + skb_orphan(skb); + if (!skb_linearize(skb)) + hidp_recv_ctrl_frame(session, skb); + else + kfree_skb(skb); + } + + /* send pending ctrl-skbs */ + hidp_process_transmit(session, &session->ctrl_transmit, + session->ctrl_sock); + + /* + * wait_woken() performs the necessary memory barriers + * for us; see the header comment for this primitive. + */ + wait_woken(&wait, TASK_INTERRUPTIBLE, MAX_SCHEDULE_TIMEOUT); + } + remove_wait_queue(&hidp_session_wq, &wait); + + atomic_inc(&session->terminate); +} + +static int hidp_session_wake_function(wait_queue_entry_t *wait, + unsigned int mode, + int sync, void *key) +{ + wake_up_interruptible(&hidp_session_wq); + return false; +} + +/* + * HIDP session thread + * This thread runs the I/O for a single HIDP session. Startup is synchronous + * which allows us to take references to ourself here instead of doing that in + * the caller. + * When we are ready to run we notify the caller and call hidp_session_run(). 
+ */ +static int hidp_session_thread(void *arg) +{ + struct hidp_session *session = arg; + DEFINE_WAIT_FUNC(ctrl_wait, hidp_session_wake_function); + DEFINE_WAIT_FUNC(intr_wait, hidp_session_wake_function); + + BT_DBG("session %p", session); + + /* initialize runtime environment */ + hidp_session_get(session); + __module_get(THIS_MODULE); + set_user_nice(current, -15); + hidp_set_timer(session); + + add_wait_queue(sk_sleep(session->ctrl_sock->sk), &ctrl_wait); + add_wait_queue(sk_sleep(session->intr_sock->sk), &intr_wait); + /* This memory barrier is paired with wq_has_sleeper(). See + * sock_poll_wait() for more information why this is needed. */ + smp_mb__before_atomic(); + + /* notify synchronous startup that we're ready */ + atomic_inc(&session->state); + wake_up(&session->state_queue); + + /* run session */ + hidp_session_run(session); + + /* cleanup runtime environment */ + remove_wait_queue(sk_sleep(session->intr_sock->sk), &intr_wait); + remove_wait_queue(sk_sleep(session->ctrl_sock->sk), &ctrl_wait); + wake_up_interruptible(&session->report_queue); + hidp_del_timer(session); + + /* + * If we stopped ourself due to any internal signal, we should try to + * unregister our own session here to avoid having it linger until the + * parent l2cap_conn dies or user-space cleans it up. + * This does not deadlock as we don't do any synchronous shutdown. + * Instead, this call has the same semantics as if user-space tried to + * delete the session. + */ + l2cap_unregister_user(session->conn, &session->user); + hidp_session_put(session); + + module_put_and_kthread_exit(0); + return 0; +} + +static int hidp_verify_sockets(struct socket *ctrl_sock, + struct socket *intr_sock) +{ + struct l2cap_chan *ctrl_chan, *intr_chan; + struct bt_sock *ctrl, *intr; + struct hidp_session *session; + + if (!l2cap_is_socket(ctrl_sock) || !l2cap_is_socket(intr_sock)) + return -EINVAL; + + ctrl_chan = l2cap_pi(ctrl_sock->sk)->chan; + intr_chan = l2cap_pi(intr_sock->sk)->chan; + + if (bacmp(&ctrl_chan->src, &intr_chan->src) || + bacmp(&ctrl_chan->dst, &intr_chan->dst)) + return -ENOTUNIQ; + + ctrl = bt_sk(ctrl_sock->sk); + intr = bt_sk(intr_sock->sk); + + if (ctrl->sk.sk_state != BT_CONNECTED || + intr->sk.sk_state != BT_CONNECTED) + return -EBADFD; + + /* early session check, we check again during session registration */ + session = hidp_session_find(&ctrl_chan->dst); + if (session) { + hidp_session_put(session); + return -EEXIST; + } + + return 0; +} + +int hidp_connection_add(const struct hidp_connadd_req *req, + struct socket *ctrl_sock, + struct socket *intr_sock) +{ + u32 valid_flags = BIT(HIDP_VIRTUAL_CABLE_UNPLUG) | + BIT(HIDP_BOOT_PROTOCOL_MODE); + struct hidp_session *session; + struct l2cap_conn *conn; + struct l2cap_chan *chan; + int ret; + + ret = hidp_verify_sockets(ctrl_sock, intr_sock); + if (ret) + return ret; + + if (req->flags & ~valid_flags) + return -EINVAL; + + chan = l2cap_pi(ctrl_sock->sk)->chan; + conn = NULL; + l2cap_chan_lock(chan); + if (chan->conn) + conn = l2cap_conn_get(chan->conn); + l2cap_chan_unlock(chan); + + if (!conn) + return -EBADFD; + + ret = hidp_session_new(&session, &chan->dst, ctrl_sock, + intr_sock, req, conn); + if (ret) + goto out_conn; + + ret = l2cap_register_user(conn, &session->user); + if (ret) + goto out_session; + + ret = 0; + +out_session: + hidp_session_put(session); +out_conn: + l2cap_conn_put(conn); + return ret; +} + +int hidp_connection_del(struct hidp_conndel_req *req) +{ + u32 valid_flags = BIT(HIDP_VIRTUAL_CABLE_UNPLUG); + struct hidp_session *session; 
+ + if (req->flags & ~valid_flags) + return -EINVAL; + + session = hidp_session_find(&req->bdaddr); + if (!session) + return -ENOENT; + + if (req->flags & BIT(HIDP_VIRTUAL_CABLE_UNPLUG)) + hidp_send_ctrl_message(session, + HIDP_TRANS_HID_CONTROL | + HIDP_CTRL_VIRTUAL_CABLE_UNPLUG, + NULL, 0); + else + l2cap_unregister_user(session->conn, &session->user); + + hidp_session_put(session); + + return 0; +} + +int hidp_get_connlist(struct hidp_connlist_req *req) +{ + struct hidp_session *session; + int err = 0, n = 0; + + BT_DBG(""); + + down_read(&hidp_session_sem); + + list_for_each_entry(session, &hidp_session_list, list) { + struct hidp_conninfo ci; + + hidp_copy_session(session, &ci); + + if (copy_to_user(req->ci, &ci, sizeof(ci))) { + err = -EFAULT; + break; + } + + if (++n >= req->cnum) + break; + + req->ci++; + } + req->cnum = n; + + up_read(&hidp_session_sem); + return err; +} + +int hidp_get_conninfo(struct hidp_conninfo *ci) +{ + struct hidp_session *session; + + session = hidp_session_find(&ci->bdaddr); + if (session) { + hidp_copy_session(session, ci); + hidp_session_put(session); + } + + return session ? 0 : -ENOENT; +} + +static int __init hidp_init(void) +{ + BT_INFO("HIDP (Human Interface Emulation) ver %s", VERSION); + + return hidp_init_sockets(); +} + +static void __exit hidp_exit(void) +{ + hidp_cleanup_sockets(); +} + +module_init(hidp_init); +module_exit(hidp_exit); + +MODULE_AUTHOR("Marcel Holtmann "); +MODULE_AUTHOR("David Herrmann "); +MODULE_DESCRIPTION("Bluetooth HIDP ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("bt-proto-6"); diff --git a/net/bluetooth/hidp/hidp.h b/net/bluetooth/hidp/hidp.h new file mode 100644 index 000000000..6ef88d0a1 --- /dev/null +++ b/net/bluetooth/hidp/hidp.h @@ -0,0 +1,192 @@ +/* + HIDP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2003-2004 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/ + +#ifndef __HIDP_H +#define __HIDP_H + +#include +#include +#include +#include +#include + +/* HIDP header masks */ +#define HIDP_HEADER_TRANS_MASK 0xf0 +#define HIDP_HEADER_PARAM_MASK 0x0f + +/* HIDP transaction types */ +#define HIDP_TRANS_HANDSHAKE 0x00 +#define HIDP_TRANS_HID_CONTROL 0x10 +#define HIDP_TRANS_GET_REPORT 0x40 +#define HIDP_TRANS_SET_REPORT 0x50 +#define HIDP_TRANS_GET_PROTOCOL 0x60 +#define HIDP_TRANS_SET_PROTOCOL 0x70 +#define HIDP_TRANS_GET_IDLE 0x80 +#define HIDP_TRANS_SET_IDLE 0x90 +#define HIDP_TRANS_DATA 0xa0 +#define HIDP_TRANS_DATC 0xb0 + +/* HIDP handshake results */ +#define HIDP_HSHK_SUCCESSFUL 0x00 +#define HIDP_HSHK_NOT_READY 0x01 +#define HIDP_HSHK_ERR_INVALID_REPORT_ID 0x02 +#define HIDP_HSHK_ERR_UNSUPPORTED_REQUEST 0x03 +#define HIDP_HSHK_ERR_INVALID_PARAMETER 0x04 +#define HIDP_HSHK_ERR_UNKNOWN 0x0e +#define HIDP_HSHK_ERR_FATAL 0x0f + +/* HIDP control operation parameters */ +#define HIDP_CTRL_NOP 0x00 +#define HIDP_CTRL_HARD_RESET 0x01 +#define HIDP_CTRL_SOFT_RESET 0x02 +#define HIDP_CTRL_SUSPEND 0x03 +#define HIDP_CTRL_EXIT_SUSPEND 0x04 +#define HIDP_CTRL_VIRTUAL_CABLE_UNPLUG 0x05 + +/* HIDP data transaction headers */ +#define HIDP_DATA_RTYPE_MASK 0x03 +#define HIDP_DATA_RSRVD_MASK 0x0c +#define HIDP_DATA_RTYPE_OTHER 0x00 +#define HIDP_DATA_RTYPE_INPUT 0x01 +#define HIDP_DATA_RTYPE_OUPUT 0x02 +#define HIDP_DATA_RTYPE_FEATURE 0x03 + +/* HIDP protocol header parameters */ +#define HIDP_PROTO_BOOT 0x00 +#define HIDP_PROTO_REPORT 0x01 + +/* HIDP ioctl defines */ +#define HIDPCONNADD _IOW('H', 200, int) +#define HIDPCONNDEL _IOW('H', 201, int) +#define HIDPGETCONNLIST _IOR('H', 210, int) +#define HIDPGETCONNINFO _IOR('H', 211, int) + +#define HIDP_VIRTUAL_CABLE_UNPLUG 0 +#define HIDP_BOOT_PROTOCOL_MODE 1 +#define HIDP_BLUETOOTH_VENDOR_ID 9 +#define HIDP_WAITING_FOR_RETURN 10 +#define HIDP_WAITING_FOR_SEND_ACK 11 + +struct hidp_connadd_req { + int ctrl_sock; /* Connected control socket */ + int intr_sock; /* Connected interrupt socket */ + __u16 parser; + __u16 rd_size; + __u8 __user *rd_data; + __u8 country; + __u8 subclass; + __u16 vendor; + __u16 product; + __u16 version; + __u32 flags; + __u32 idle_to; + char name[128]; +}; + +struct hidp_conndel_req { + bdaddr_t bdaddr; + __u32 flags; +}; + +struct hidp_conninfo { + bdaddr_t bdaddr; + __u32 flags; + __u16 state; + __u16 vendor; + __u16 product; + __u16 version; + char name[128]; +}; + +struct hidp_connlist_req { + __u32 cnum; + struct hidp_conninfo __user *ci; +}; + +int hidp_connection_add(const struct hidp_connadd_req *req, struct socket *ctrl_sock, struct socket *intr_sock); +int hidp_connection_del(struct hidp_conndel_req *req); +int hidp_get_connlist(struct hidp_connlist_req *req); +int hidp_get_conninfo(struct hidp_conninfo *ci); + +enum hidp_session_state { + HIDP_SESSION_IDLING, + HIDP_SESSION_PREPARING, + HIDP_SESSION_RUNNING, +}; + +/* HIDP session defines */ +struct hidp_session { + struct list_head list; + struct kref ref; + + /* runtime management */ + atomic_t state; + wait_queue_head_t state_queue; + atomic_t terminate; + struct task_struct *task; + unsigned long flags; + + /* connection management */ + bdaddr_t bdaddr; + struct l2cap_conn *conn; + struct l2cap_user user; + struct socket *ctrl_sock; + struct socket *intr_sock; + struct sk_buff_head ctrl_transmit; + struct sk_buff_head intr_transmit; + uint ctrl_mtu; + uint intr_mtu; + unsigned long idle_to; + + /* device management */ + struct work_struct dev_init; + struct input_dev *input; + struct hid_device *hid; + struct 
timer_list timer; + + /* Report descriptor */ + __u8 *rd_data; + uint rd_size; + + /* session data */ + unsigned char keys[8]; + unsigned char leds; + + /* Used in hidp_get_raw_report() */ + int waiting_report_type; /* HIDP_DATA_RTYPE_* */ + int waiting_report_number; /* -1 for not numbered */ + struct mutex report_mutex; + struct sk_buff *report_return; + wait_queue_head_t report_queue; + + /* Used in hidp_output_raw_report() */ + int output_report_success; /* boolean */ + + /* temporary input buffer */ + u8 input_buf[HID_MAX_BUFFER_SIZE]; +}; + +/* HIDP init defines */ +int __init hidp_init_sockets(void); +void __exit hidp_cleanup_sockets(void); + +#endif /* __HIDP_H */ diff --git a/net/bluetooth/hidp/sock.c b/net/bluetooth/hidp/sock.c new file mode 100644 index 000000000..369ed92da --- /dev/null +++ b/net/bluetooth/hidp/sock.c @@ -0,0 +1,320 @@ +/* + HIDP implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2003-2004 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/ + +#include +#include +#include + +#include "hidp.h" + +static struct bt_sock_list hidp_sk_list = { + .lock = __RW_LOCK_UNLOCKED(hidp_sk_list.lock) +}; + +static int hidp_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + BT_DBG("sock %p sk %p", sock, sk); + + if (!sk) + return 0; + + bt_sock_unlink(&hidp_sk_list, sk); + + sock_orphan(sk); + sock_put(sk); + + return 0; +} + +static int do_hidp_sock_ioctl(struct socket *sock, unsigned int cmd, void __user *argp) +{ + struct hidp_connadd_req ca; + struct hidp_conndel_req cd; + struct hidp_connlist_req cl; + struct hidp_conninfo ci; + struct socket *csock; + struct socket *isock; + int err; + + BT_DBG("cmd %x arg %p", cmd, argp); + + switch (cmd) { + case HIDPCONNADD: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ca, argp, sizeof(ca))) + return -EFAULT; + + csock = sockfd_lookup(ca.ctrl_sock, &err); + if (!csock) + return err; + + isock = sockfd_lookup(ca.intr_sock, &err); + if (!isock) { + sockfd_put(csock); + return err; + } + ca.name[sizeof(ca.name)-1] = 0; + + err = hidp_connection_add(&ca, csock, isock); + if (!err && copy_to_user(argp, &ca, sizeof(ca))) + err = -EFAULT; + + sockfd_put(csock); + sockfd_put(isock); + + return err; + + case HIDPCONNDEL: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&cd, argp, sizeof(cd))) + return -EFAULT; + + return hidp_connection_del(&cd); + + case HIDPGETCONNLIST: + if (copy_from_user(&cl, argp, sizeof(cl))) + return -EFAULT; + + if (cl.cnum <= 0) + return -EINVAL; + + err = hidp_get_connlist(&cl); + if (!err && copy_to_user(argp, &cl, sizeof(cl))) + return -EFAULT; + + return err; + + case HIDPGETCONNINFO: + if (copy_from_user(&ci, argp, sizeof(ci))) + return -EFAULT; + + err = hidp_get_conninfo(&ci); + if (!err && copy_to_user(argp, &ci, sizeof(ci))) + return -EFAULT; + + return err; + } + + return -EINVAL; +} + +static int hidp_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return do_hidp_sock_ioctl(sock, cmd, (void __user *)arg); +} + +#ifdef CONFIG_COMPAT +struct compat_hidp_connadd_req { + int ctrl_sock; /* Connected control socket */ + int intr_sock; /* Connected interrupt socket */ + __u16 parser; + __u16 rd_size; + compat_uptr_t rd_data; + __u8 country; + __u8 subclass; + __u16 vendor; + __u16 product; + __u16 version; + __u32 flags; + __u32 idle_to; + char name[128]; +}; + +static int hidp_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + int err; + + if (cmd == HIDPGETCONNLIST) { + struct hidp_connlist_req cl; + u32 __user *p = argp; + u32 uci; + + if (get_user(cl.cnum, p) || get_user(uci, p + 1)) + return -EFAULT; + + cl.ci = compat_ptr(uci); + + if (cl.cnum <= 0) + return -EINVAL; + + err = hidp_get_connlist(&cl); + + if (!err && put_user(cl.cnum, p)) + err = -EFAULT; + + return err; + } else if (cmd == HIDPCONNADD) { + struct compat_hidp_connadd_req ca32; + struct hidp_connadd_req ca; + struct socket *csock; + struct socket *isock; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ca32, (void __user *) arg, sizeof(ca32))) + return -EFAULT; + + ca.ctrl_sock = ca32.ctrl_sock; + ca.intr_sock = ca32.intr_sock; + ca.parser = ca32.parser; + ca.rd_size = ca32.rd_size; + ca.rd_data = compat_ptr(ca32.rd_data); + ca.country = ca32.country; + ca.subclass = ca32.subclass; + ca.vendor = ca32.vendor; + ca.product = ca32.product; + ca.version = ca32.version; + ca.flags = ca32.flags; + ca.idle_to = ca32.idle_to; 
+ ca32.name[sizeof(ca32.name) - 1] = '\0'; + memcpy(ca.name, ca32.name, 128); + + csock = sockfd_lookup(ca.ctrl_sock, &err); + if (!csock) + return err; + + isock = sockfd_lookup(ca.intr_sock, &err); + if (!isock) { + sockfd_put(csock); + return err; + } + + err = hidp_connection_add(&ca, csock, isock); + if (!err && copy_to_user(argp, &ca32, sizeof(ca32))) + err = -EFAULT; + + sockfd_put(csock); + sockfd_put(isock); + + return err; + } + + return hidp_sock_ioctl(sock, cmd, arg); +} +#endif + +static const struct proto_ops hidp_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = hidp_sock_release, + .ioctl = hidp_sock_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = hidp_sock_compat_ioctl, +#endif + .bind = sock_no_bind, + .getname = sock_no_getname, + .sendmsg = sock_no_sendmsg, + .recvmsg = sock_no_recvmsg, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .mmap = sock_no_mmap +}; + +static struct proto hidp_proto = { + .name = "HIDP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct bt_sock) +}; + +static int hidp_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + sk = sk_alloc(net, PF_BLUETOOTH, GFP_ATOMIC, &hidp_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + sock->ops = &hidp_sock_ops; + + sock->state = SS_UNCONNECTED; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = protocol; + sk->sk_state = BT_OPEN; + + bt_sock_link(&hidp_sk_list, sk); + + return 0; +} + +static const struct net_proto_family hidp_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = hidp_sock_create +}; + +int __init hidp_init_sockets(void) +{ + int err; + + err = proto_register(&hidp_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_HIDP, &hidp_sock_family_ops); + if (err < 0) { + BT_ERR("Can't register HIDP socket"); + goto error; + } + + err = bt_procfs_init(&init_net, "hidp", &hidp_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create HIDP proc file"); + bt_sock_unregister(BTPROTO_HIDP); + goto error; + } + + BT_INFO("HIDP socket layer initialized"); + + return 0; + +error: + proto_unregister(&hidp_proto); + return err; +} + +void __exit hidp_cleanup_sockets(void) +{ + bt_procfs_cleanup(&init_net, "hidp"); + bt_sock_unregister(BTPROTO_HIDP); + proto_unregister(&hidp_proto); +} diff --git a/net/bluetooth/iso.c b/net/bluetooth/iso.c new file mode 100644 index 000000000..91e990acc --- /dev/null +++ b/net/bluetooth/iso.c @@ -0,0 +1,1884 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * BlueZ - Bluetooth protocol stack for Linux + * + * Copyright (C) 2022 Intel Corporation + */ + +#include +#include +#include +#include + +#include +#include +#include + +static const struct proto_ops iso_sock_ops; + +static struct bt_sock_list iso_sk_list = { + .lock = __RW_LOCK_UNLOCKED(iso_sk_list.lock) +}; + +/* ---- ISO connections ---- */ +struct iso_conn { + struct hci_conn *hcon; + + /* @lock: spinlock protecting changes to iso_conn fields */ + spinlock_t lock; + struct sock *sk; + + struct delayed_work timeout_work; + + struct sk_buff *rx_skb; + __u32 rx_len; + __u16 tx_sn; +}; + +#define iso_conn_lock(c) spin_lock(&(c)->lock) +#define iso_conn_unlock(c) spin_unlock(&(c)->lock) + +static void iso_sock_close(struct sock *sk); +static void iso_sock_kill(struct sock 
*sk); + +/* ----- ISO socket info ----- */ +#define iso_pi(sk) ((struct iso_pinfo *)sk) + +#define EIR_SERVICE_DATA_LENGTH 4 +#define BASE_MAX_LENGTH (HCI_MAX_PER_AD_LENGTH - EIR_SERVICE_DATA_LENGTH) + +struct iso_pinfo { + struct bt_sock bt; + bdaddr_t src; + __u8 src_type; + bdaddr_t dst; + __u8 dst_type; + __u8 bc_sid; + __u8 bc_num_bis; + __u8 bc_bis[ISO_MAX_NUM_BIS]; + __u16 sync_handle; + __u32 flags; + struct bt_iso_qos qos; + __u8 base_len; + __u8 base[BASE_MAX_LENGTH]; + struct iso_conn *conn; +}; + +/* ---- ISO timers ---- */ +#define ISO_CONN_TIMEOUT (HZ * 40) +#define ISO_DISCONN_TIMEOUT (HZ * 2) + +static void iso_sock_timeout(struct work_struct *work) +{ + struct iso_conn *conn = container_of(work, struct iso_conn, + timeout_work.work); + struct sock *sk; + + iso_conn_lock(conn); + sk = conn->sk; + if (sk) + sock_hold(sk); + iso_conn_unlock(conn); + + if (!sk) + return; + + BT_DBG("sock %p state %d", sk, sk->sk_state); + + lock_sock(sk); + sk->sk_err = ETIMEDOUT; + sk->sk_state_change(sk); + release_sock(sk); + sock_put(sk); +} + +static void iso_sock_set_timer(struct sock *sk, long timeout) +{ + if (!iso_pi(sk)->conn) + return; + + BT_DBG("sock %p state %d timeout %ld", sk, sk->sk_state, timeout); + cancel_delayed_work(&iso_pi(sk)->conn->timeout_work); + schedule_delayed_work(&iso_pi(sk)->conn->timeout_work, timeout); +} + +static void iso_sock_clear_timer(struct sock *sk) +{ + if (!iso_pi(sk)->conn) + return; + + BT_DBG("sock %p state %d", sk, sk->sk_state); + cancel_delayed_work(&iso_pi(sk)->conn->timeout_work); +} + +/* ---- ISO connections ---- */ +static struct iso_conn *iso_conn_add(struct hci_conn *hcon) +{ + struct iso_conn *conn = hcon->iso_data; + + if (conn) { + if (!conn->hcon) + conn->hcon = hcon; + return conn; + } + + conn = kzalloc(sizeof(*conn), GFP_KERNEL); + if (!conn) + return NULL; + + spin_lock_init(&conn->lock); + INIT_DELAYED_WORK(&conn->timeout_work, iso_sock_timeout); + + hcon->iso_data = conn; + conn->hcon = hcon; + conn->tx_sn = 0; + + BT_DBG("hcon %p conn %p", hcon, conn); + + return conn; +} + +/* Delete channel. Must be called on the locked socket. */ +static void iso_chan_del(struct sock *sk, int err) +{ + struct iso_conn *conn; + struct sock *parent; + + conn = iso_pi(sk)->conn; + + BT_DBG("sk %p, conn %p, err %d", sk, conn, err); + + if (conn) { + iso_conn_lock(conn); + conn->sk = NULL; + iso_pi(sk)->conn = NULL; + iso_conn_unlock(conn); + + if (conn->hcon) + hci_conn_drop(conn->hcon); + } + + sk->sk_state = BT_CLOSED; + sk->sk_err = err; + + parent = bt_sk(sk)->parent; + if (parent) { + bt_accept_unlink(sk); + parent->sk_data_ready(parent); + } else { + sk->sk_state_change(sk); + } + + sock_set_flag(sk, SOCK_ZAPPED); +} + +static void iso_conn_del(struct hci_conn *hcon, int err) +{ + struct iso_conn *conn = hcon->iso_data; + struct sock *sk; + + if (!conn) + return; + + BT_DBG("hcon %p conn %p, err %d", hcon, conn, err); + + /* Kill socket */ + iso_conn_lock(conn); + sk = conn->sk; + if (sk) + sock_hold(sk); + iso_conn_unlock(conn); + + if (sk) { + lock_sock(sk); + iso_sock_clear_timer(sk); + iso_chan_del(sk, err); + release_sock(sk); + sock_put(sk); + } + + /* Ensure no more work items will run before freeing conn. 
*/ + cancel_delayed_work_sync(&conn->timeout_work); + + hcon->iso_data = NULL; + kfree(conn); +} + +static int __iso_chan_add(struct iso_conn *conn, struct sock *sk, + struct sock *parent) +{ + BT_DBG("conn %p", conn); + + if (iso_pi(sk)->conn == conn && conn->sk == sk) + return 0; + + if (conn->sk) { + BT_ERR("conn->sk already set"); + return -EBUSY; + } + + iso_pi(sk)->conn = conn; + conn->sk = sk; + + if (parent) + bt_accept_enqueue(parent, sk, true); + + return 0; +} + +static int iso_chan_add(struct iso_conn *conn, struct sock *sk, + struct sock *parent) +{ + int err; + + iso_conn_lock(conn); + err = __iso_chan_add(conn, sk, parent); + iso_conn_unlock(conn); + + return err; +} + +static inline u8 le_addr_type(u8 bdaddr_type) +{ + if (bdaddr_type == BDADDR_LE_PUBLIC) + return ADDR_LE_DEV_PUBLIC; + else + return ADDR_LE_DEV_RANDOM; +} + +static int iso_connect_bis(struct sock *sk) +{ + struct iso_conn *conn; + struct hci_conn *hcon; + struct hci_dev *hdev; + int err; + + BT_DBG("%pMR", &iso_pi(sk)->src); + + hdev = hci_get_route(&iso_pi(sk)->dst, &iso_pi(sk)->src, + iso_pi(sk)->src_type); + if (!hdev) + return -EHOSTUNREACH; + + hci_dev_lock(hdev); + + if (!bis_capable(hdev)) { + err = -EOPNOTSUPP; + goto unlock; + } + + /* Fail if out PHYs are marked as disabled */ + if (!iso_pi(sk)->qos.out.phy) { + err = -EINVAL; + goto unlock; + } + + hcon = hci_connect_bis(hdev, &iso_pi(sk)->dst, iso_pi(sk)->dst_type, + &iso_pi(sk)->qos, iso_pi(sk)->base_len, + iso_pi(sk)->base); + if (IS_ERR(hcon)) { + err = PTR_ERR(hcon); + goto unlock; + } + + conn = iso_conn_add(hcon); + if (!conn) { + hci_conn_drop(hcon); + err = -ENOMEM; + goto unlock; + } + + lock_sock(sk); + + err = iso_chan_add(conn, sk, NULL); + if (err) { + release_sock(sk); + goto unlock; + } + + /* Update source addr of the socket */ + bacpy(&iso_pi(sk)->src, &hcon->src); + + if (hcon->state == BT_CONNECTED) { + iso_sock_clear_timer(sk); + sk->sk_state = BT_CONNECTED; + } else { + sk->sk_state = BT_CONNECT; + iso_sock_set_timer(sk, sk->sk_sndtimeo); + } + + release_sock(sk); + +unlock: + hci_dev_unlock(hdev); + hci_dev_put(hdev); + return err; +} + +static int iso_connect_cis(struct sock *sk) +{ + struct iso_conn *conn; + struct hci_conn *hcon; + struct hci_dev *hdev; + int err; + + BT_DBG("%pMR -> %pMR", &iso_pi(sk)->src, &iso_pi(sk)->dst); + + hdev = hci_get_route(&iso_pi(sk)->dst, &iso_pi(sk)->src, + iso_pi(sk)->src_type); + if (!hdev) + return -EHOSTUNREACH; + + hci_dev_lock(hdev); + + if (!cis_central_capable(hdev)) { + err = -EOPNOTSUPP; + goto unlock; + } + + /* Fail if either PHYs are marked as disabled */ + if (!iso_pi(sk)->qos.in.phy && !iso_pi(sk)->qos.out.phy) { + err = -EINVAL; + goto unlock; + } + + /* Just bind if DEFER_SETUP has been set */ + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { + hcon = hci_bind_cis(hdev, &iso_pi(sk)->dst, + le_addr_type(iso_pi(sk)->dst_type), + &iso_pi(sk)->qos); + if (IS_ERR(hcon)) { + err = PTR_ERR(hcon); + goto unlock; + } + } else { + hcon = hci_connect_cis(hdev, &iso_pi(sk)->dst, + le_addr_type(iso_pi(sk)->dst_type), + &iso_pi(sk)->qos); + if (IS_ERR(hcon)) { + err = PTR_ERR(hcon); + goto unlock; + } + } + + conn = iso_conn_add(hcon); + if (!conn) { + hci_conn_drop(hcon); + err = -ENOMEM; + goto unlock; + } + + lock_sock(sk); + + err = iso_chan_add(conn, sk, NULL); + if (err) { + release_sock(sk); + goto unlock; + } + + /* Update source addr of the socket */ + bacpy(&iso_pi(sk)->src, &hcon->src); + + if (hcon->state == BT_CONNECTED) { + iso_sock_clear_timer(sk); + sk->sk_state = 
BT_CONNECTED; + } else if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { + iso_sock_clear_timer(sk); + sk->sk_state = BT_CONNECT; + } else { + sk->sk_state = BT_CONNECT; + iso_sock_set_timer(sk, sk->sk_sndtimeo); + } + + release_sock(sk); + +unlock: + hci_dev_unlock(hdev); + hci_dev_put(hdev); + return err; +} + +static struct bt_iso_qos *iso_sock_get_qos(struct sock *sk) +{ + if (sk->sk_state == BT_CONNECTED || sk->sk_state == BT_CONNECT2) + return &iso_pi(sk)->conn->hcon->iso_qos; + + return &iso_pi(sk)->qos; +} + +static int iso_send_frame(struct sock *sk, struct sk_buff *skb) +{ + struct iso_conn *conn = iso_pi(sk)->conn; + struct bt_iso_qos *qos = iso_sock_get_qos(sk); + struct hci_iso_data_hdr *hdr; + int len = 0; + + BT_DBG("sk %p len %d", sk, skb->len); + + if (skb->len > qos->out.sdu) + return -EMSGSIZE; + + len = skb->len; + + /* Push ISO data header */ + hdr = skb_push(skb, HCI_ISO_DATA_HDR_SIZE); + hdr->sn = cpu_to_le16(conn->tx_sn++); + hdr->slen = cpu_to_le16(hci_iso_data_len_pack(len, + HCI_ISO_STATUS_VALID)); + + if (sk->sk_state == BT_CONNECTED) + hci_send_iso(conn->hcon, skb); + else + len = -ENOTCONN; + + return len; +} + +static void iso_recv_frame(struct iso_conn *conn, struct sk_buff *skb) +{ + struct sock *sk; + + iso_conn_lock(conn); + sk = conn->sk; + iso_conn_unlock(conn); + + if (!sk) + goto drop; + + BT_DBG("sk %p len %d", sk, skb->len); + + if (sk->sk_state != BT_CONNECTED) + goto drop; + + if (!sock_queue_rcv_skb(sk, skb)) + return; + +drop: + kfree_skb(skb); +} + +/* -------- Socket interface ---------- */ +static struct sock *__iso_get_sock_listen_by_addr(bdaddr_t *src, bdaddr_t *dst) +{ + struct sock *sk; + + sk_for_each(sk, &iso_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + if (bacmp(&iso_pi(sk)->dst, dst)) + continue; + + if (!bacmp(&iso_pi(sk)->src, src)) + return sk; + } + + return NULL; +} + +static struct sock *__iso_get_sock_listen_by_sid(bdaddr_t *ba, bdaddr_t *bc, + __u8 sid) +{ + struct sock *sk; + + sk_for_each(sk, &iso_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + if (bacmp(&iso_pi(sk)->src, ba)) + continue; + + if (bacmp(&iso_pi(sk)->dst, bc)) + continue; + + if (iso_pi(sk)->bc_sid == sid) + return sk; + } + + return NULL; +} + +typedef bool (*iso_sock_match_t)(struct sock *sk, void *data); + +/* Find socket listening: + * source bdaddr (Unicast) + * destination bdaddr (Broadcast only) + * match func - pass NULL to ignore + * match func data - pass -1 to ignore + * Returns closest match. + */ +static struct sock *iso_get_sock_listen(bdaddr_t *src, bdaddr_t *dst, + iso_sock_match_t match, void *data) +{ + struct sock *sk = NULL, *sk1 = NULL; + + read_lock(&iso_sk_list.lock); + + sk_for_each(sk, &iso_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + /* Match Broadcast destination */ + if (bacmp(dst, BDADDR_ANY) && bacmp(&iso_pi(sk)->dst, dst)) + continue; + + /* Use Match function if provided */ + if (match && !match(sk, data)) + continue; + + /* Exact match. */ + if (!bacmp(&iso_pi(sk)->src, src)) + break; + + /* Closest match */ + if (!bacmp(&iso_pi(sk)->src, BDADDR_ANY)) + sk1 = sk; + } + + read_unlock(&iso_sk_list.lock); + + return sk ? 
sk : sk1; +} + +static void iso_sock_destruct(struct sock *sk) +{ + BT_DBG("sk %p", sk); + + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); +} + +static void iso_sock_cleanup_listen(struct sock *parent) +{ + struct sock *sk; + + BT_DBG("parent %p", parent); + + /* Close not yet accepted channels */ + while ((sk = bt_accept_dequeue(parent, NULL))) { + iso_sock_close(sk); + iso_sock_kill(sk); + } + + parent->sk_state = BT_CLOSED; + sock_set_flag(parent, SOCK_ZAPPED); +} + +/* Kill socket (only if zapped and orphan) + * Must be called on unlocked socket. + */ +static void iso_sock_kill(struct sock *sk) +{ + if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket || + sock_flag(sk, SOCK_DEAD)) + return; + + BT_DBG("sk %p state %d", sk, sk->sk_state); + + /* Kill poor orphan */ + bt_sock_unlink(&iso_sk_list, sk); + sock_set_flag(sk, SOCK_DEAD); + sock_put(sk); +} + +static void iso_conn_defer_reject(struct hci_conn *conn) +{ + struct hci_cp_le_reject_cis cp; + + BT_DBG("conn %p", conn); + + memset(&cp, 0, sizeof(cp)); + cp.handle = cpu_to_le16(conn->handle); + cp.reason = HCI_ERROR_REJ_BAD_ADDR; + hci_send_cmd(conn->hdev, HCI_OP_LE_REJECT_CIS, sizeof(cp), &cp); +} + +static void __iso_sock_close(struct sock *sk) +{ + BT_DBG("sk %p state %d socket %p", sk, sk->sk_state, sk->sk_socket); + + switch (sk->sk_state) { + case BT_LISTEN: + iso_sock_cleanup_listen(sk); + break; + + case BT_CONNECTED: + case BT_CONFIG: + if (iso_pi(sk)->conn->hcon) { + sk->sk_state = BT_DISCONN; + iso_sock_set_timer(sk, ISO_DISCONN_TIMEOUT); + iso_conn_lock(iso_pi(sk)->conn); + hci_conn_drop(iso_pi(sk)->conn->hcon); + iso_pi(sk)->conn->hcon = NULL; + iso_conn_unlock(iso_pi(sk)->conn); + } else { + iso_chan_del(sk, ECONNRESET); + } + break; + + case BT_CONNECT2: + if (iso_pi(sk)->conn->hcon) + iso_conn_defer_reject(iso_pi(sk)->conn->hcon); + iso_chan_del(sk, ECONNRESET); + break; + case BT_CONNECT: + /* In case of DEFER_SETUP the hcon would be bound to CIG which + * needs to be removed so just call hci_conn_del so the cleanup + * callback do what is needed. + */ + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags) && + iso_pi(sk)->conn->hcon) { + hci_conn_del(iso_pi(sk)->conn->hcon); + iso_pi(sk)->conn->hcon = NULL; + } + + iso_chan_del(sk, ECONNRESET); + break; + case BT_DISCONN: + iso_chan_del(sk, ECONNRESET); + break; + + default: + sock_set_flag(sk, SOCK_ZAPPED); + break; + } +} + +/* Must be called on unlocked socket. 
*/ +static void iso_sock_close(struct sock *sk) +{ + iso_sock_clear_timer(sk); + lock_sock(sk); + __iso_sock_close(sk); + release_sock(sk); + iso_sock_kill(sk); +} + +static void iso_sock_init(struct sock *sk, struct sock *parent) +{ + BT_DBG("sk %p", sk); + + if (parent) { + sk->sk_type = parent->sk_type; + bt_sk(sk)->flags = bt_sk(parent)->flags; + security_sk_clone(parent, sk); + } +} + +static struct proto iso_proto = { + .name = "ISO", + .owner = THIS_MODULE, + .obj_size = sizeof(struct iso_pinfo) +}; + +#define DEFAULT_IO_QOS \ +{ \ + .interval = 10000u, \ + .latency = 10u, \ + .sdu = 40u, \ + .phy = BT_ISO_PHY_2M, \ + .rtn = 2u, \ +} + +static struct bt_iso_qos default_qos = { + .cig = BT_ISO_QOS_CIG_UNSET, + .cis = BT_ISO_QOS_CIS_UNSET, + .sca = 0x00, + .packing = 0x00, + .framing = 0x00, + .in = DEFAULT_IO_QOS, + .out = DEFAULT_IO_QOS, +}; + +static struct sock *iso_sock_alloc(struct net *net, struct socket *sock, + int proto, gfp_t prio, int kern) +{ + struct sock *sk; + + sk = sk_alloc(net, PF_BLUETOOTH, prio, &iso_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + INIT_LIST_HEAD(&bt_sk(sk)->accept_q); + + sk->sk_destruct = iso_sock_destruct; + sk->sk_sndtimeo = ISO_CONN_TIMEOUT; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = proto; + sk->sk_state = BT_OPEN; + + /* Set address type as public as default src address is BDADDR_ANY */ + iso_pi(sk)->src_type = BDADDR_LE_PUBLIC; + + iso_pi(sk)->qos = default_qos; + + bt_sock_link(&iso_sk_list, sk); + return sk; +} + +static int iso_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + sock->state = SS_UNCONNECTED; + + if (sock->type != SOCK_SEQPACKET) + return -ESOCKTNOSUPPORT; + + sock->ops = &iso_sock_ops; + + sk = iso_sock_alloc(net, sock, protocol, GFP_ATOMIC, kern); + if (!sk) + return -ENOMEM; + + iso_sock_init(sk, NULL); + return 0; +} + +static int iso_sock_bind_bc(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sockaddr_iso *sa = (struct sockaddr_iso *)addr; + struct sock *sk = sock->sk; + int i; + + BT_DBG("sk %p bc_sid %u bc_num_bis %u", sk, sa->iso_bc->bc_sid, + sa->iso_bc->bc_num_bis); + + if (addr_len > sizeof(*sa) + sizeof(*sa->iso_bc) || + sa->iso_bc->bc_num_bis < 0x01 || sa->iso_bc->bc_num_bis > 0x1f) + return -EINVAL; + + bacpy(&iso_pi(sk)->dst, &sa->iso_bc->bc_bdaddr); + iso_pi(sk)->dst_type = sa->iso_bc->bc_bdaddr_type; + iso_pi(sk)->sync_handle = -1; + iso_pi(sk)->bc_sid = sa->iso_bc->bc_sid; + iso_pi(sk)->bc_num_bis = sa->iso_bc->bc_num_bis; + + for (i = 0; i < iso_pi(sk)->bc_num_bis; i++) { + if (sa->iso_bc->bc_bis[i] < 0x01 || + sa->iso_bc->bc_bis[i] > 0x1f) + return -EINVAL; + + memcpy(iso_pi(sk)->bc_bis, sa->iso_bc->bc_bis, + iso_pi(sk)->bc_num_bis); + } + + return 0; +} + +static int iso_sock_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sockaddr_iso *sa = (struct sockaddr_iso *)addr; + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sk %p %pMR type %u", sk, &sa->iso_bdaddr, sa->iso_bdaddr_type); + + if (!addr || addr_len < sizeof(struct sockaddr_iso) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + lock_sock(sk); + + if (sk->sk_state != BT_OPEN) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EINVAL; + goto done; + } + + /* Check if the address type is of LE type */ + if (!bdaddr_type_is_le(sa->iso_bdaddr_type)) { + err = -EINVAL; + goto done; + } + + bacpy(&iso_pi(sk)->src, 
&sa->iso_bdaddr); + iso_pi(sk)->src_type = sa->iso_bdaddr_type; + + /* Check for Broadcast address */ + if (addr_len > sizeof(*sa)) { + err = iso_sock_bind_bc(sock, addr, addr_len); + if (err) + goto done; + } + + sk->sk_state = BT_BOUND; + +done: + release_sock(sk); + return err; +} + +static int iso_sock_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + struct sockaddr_iso *sa = (struct sockaddr_iso *)addr; + struct sock *sk = sock->sk; + int err; + + BT_DBG("sk %p", sk); + + if (alen < sizeof(struct sockaddr_iso) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND) + return -EBADFD; + + if (sk->sk_type != SOCK_SEQPACKET) + return -EINVAL; + + /* Check if the address type is of LE type */ + if (!bdaddr_type_is_le(sa->iso_bdaddr_type)) + return -EINVAL; + + lock_sock(sk); + + bacpy(&iso_pi(sk)->dst, &sa->iso_bdaddr); + iso_pi(sk)->dst_type = sa->iso_bdaddr_type; + + release_sock(sk); + + if (bacmp(&iso_pi(sk)->dst, BDADDR_ANY)) + err = iso_connect_cis(sk); + else + err = iso_connect_bis(sk); + + if (err) + return err; + + lock_sock(sk); + + if (!test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { + err = bt_sock_wait_state(sk, BT_CONNECTED, + sock_sndtimeo(sk, flags & O_NONBLOCK)); + } + + release_sock(sk); + return err; +} + +static int iso_listen_bis(struct sock *sk) +{ + struct hci_dev *hdev; + int err = 0; + + BT_DBG("%pMR -> %pMR (SID 0x%2.2x)", &iso_pi(sk)->src, + &iso_pi(sk)->dst, iso_pi(sk)->bc_sid); + + write_lock(&iso_sk_list.lock); + + if (__iso_get_sock_listen_by_sid(&iso_pi(sk)->src, &iso_pi(sk)->dst, + iso_pi(sk)->bc_sid)) + err = -EADDRINUSE; + + write_unlock(&iso_sk_list.lock); + + if (err) + return err; + + hdev = hci_get_route(&iso_pi(sk)->dst, &iso_pi(sk)->src, + iso_pi(sk)->src_type); + if (!hdev) + return -EHOSTUNREACH; + + hci_dev_lock(hdev); + + err = hci_pa_create_sync(hdev, &iso_pi(sk)->dst, iso_pi(sk)->dst_type, + iso_pi(sk)->bc_sid); + + hci_dev_unlock(hdev); + hci_dev_put(hdev); + + return err; +} + +static int iso_listen_cis(struct sock *sk) +{ + int err = 0; + + BT_DBG("%pMR", &iso_pi(sk)->src); + + write_lock(&iso_sk_list.lock); + + if (__iso_get_sock_listen_by_addr(&iso_pi(sk)->src, &iso_pi(sk)->dst)) + err = -EADDRINUSE; + + write_unlock(&iso_sk_list.lock); + + return err; +} + +static int iso_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sk %p backlog %d", sk, backlog); + + lock_sock(sk); + + if (sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EINVAL; + goto done; + } + + if (!bacmp(&iso_pi(sk)->dst, BDADDR_ANY)) + err = iso_listen_cis(sk); + else + err = iso_listen_bis(sk); + + if (err) + goto done; + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + + sk->sk_state = BT_LISTEN; + +done: + release_sock(sk); + return err; +} + +static int iso_sock_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = sock->sk, *ch; + long timeo; + int err = 0; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + BT_DBG("sk %p timeo %ld", sk, timeo); + + /* Wait for an incoming connection. (wake-one). 
*/ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (1) { + if (sk->sk_state != BT_LISTEN) { + err = -EBADFD; + break; + } + + ch = bt_accept_dequeue(sk, newsock); + if (ch) + break; + + if (!timeo) { + err = -EAGAIN; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + lock_sock(sk); + } + remove_wait_queue(sk_sleep(sk), &wait); + + if (err) + goto done; + + newsock->state = SS_CONNECTED; + + BT_DBG("new socket %p", ch); + +done: + release_sock(sk); + return err; +} + +static int iso_sock_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sockaddr_iso *sa = (struct sockaddr_iso *)addr; + struct sock *sk = sock->sk; + + BT_DBG("sock %p, sk %p", sock, sk); + + addr->sa_family = AF_BLUETOOTH; + + if (peer) { + bacpy(&sa->iso_bdaddr, &iso_pi(sk)->dst); + sa->iso_bdaddr_type = iso_pi(sk)->dst_type; + } else { + bacpy(&sa->iso_bdaddr, &iso_pi(sk)->src); + sa->iso_bdaddr_type = iso_pi(sk)->src_type; + } + + return sizeof(struct sockaddr_iso); +} + +static int iso_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb, **frag; + size_t mtu; + int err; + + BT_DBG("sock %p, sk %p", sock, sk); + + err = sock_error(sk); + if (err) + return err; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + lock_sock(sk); + + if (sk->sk_state != BT_CONNECTED) { + release_sock(sk); + return -ENOTCONN; + } + + mtu = iso_pi(sk)->conn->hcon->hdev->iso_mtu; + + release_sock(sk); + + skb = bt_skb_sendmsg(sk, msg, len, mtu, HCI_ISO_DATA_HDR_SIZE, 0); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + len -= skb->len; + + BT_DBG("skb %p len %d", sk, skb->len); + + /* Continuation fragments */ + frag = &skb_shinfo(skb)->frag_list; + while (len) { + struct sk_buff *tmp; + + tmp = bt_skb_sendmsg(sk, msg, len, mtu, 0, 0); + if (IS_ERR(tmp)) { + kfree_skb(skb); + return PTR_ERR(tmp); + } + + *frag = tmp; + + len -= tmp->len; + + skb->len += tmp->len; + skb->data_len += tmp->len; + + BT_DBG("frag %p len %d", *frag, tmp->len); + + frag = &(*frag)->next; + } + + lock_sock(sk); + + if (sk->sk_state == BT_CONNECTED) + err = iso_send_frame(sk, skb); + else + err = -ENOTCONN; + + release_sock(sk); + + if (err < 0) + kfree_skb(skb); + return err; +} + +static void iso_conn_defer_accept(struct hci_conn *conn) +{ + struct hci_cp_le_accept_cis cp; + struct hci_dev *hdev = conn->hdev; + + BT_DBG("conn %p", conn); + + conn->state = BT_CONFIG; + + cp.handle = cpu_to_le16(conn->handle); + + hci_send_cmd(hdev, HCI_OP_LE_ACCEPT_CIS, sizeof(cp), &cp); +} + +static int iso_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct iso_pinfo *pi = iso_pi(sk); + + BT_DBG("sk %p", sk); + + if (test_and_clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { + lock_sock(sk); + switch (sk->sk_state) { + case BT_CONNECT2: + iso_conn_defer_accept(pi->conn->hcon); + sk->sk_state = BT_CONFIG; + release_sock(sk); + return 0; + case BT_CONNECT: + release_sock(sk); + return iso_connect_cis(sk); + default: + release_sock(sk); + break; + } + } + + return bt_sock_recvmsg(sock, msg, len, flags); +} + +static bool check_io_qos(struct bt_iso_io_qos *qos) +{ + /* If no PHY is enable SDU must be 0 */ + if (!qos->phy && qos->sdu) + return false; + + if (qos->interval && (qos->interval < 0xff || qos->interval > 0xfffff)) + return false; + + if (qos->latency && (qos->latency < 0x05 || 
qos->latency > 0xfa0)) + return false; + + if (qos->phy > BT_ISO_PHY_ANY) + return false; + + return true; +} + +static bool check_qos(struct bt_iso_qos *qos) +{ + if (qos->sca > 0x07) + return false; + + if (qos->packing > 0x01) + return false; + + if (qos->framing > 0x01) + return false; + + if (!check_io_qos(&qos->in)) + return false; + + if (!check_io_qos(&qos->out)) + return false; + + return true; +} + +static int iso_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + int len, err = 0; + struct bt_iso_qos qos; + u32 opt; + + BT_DBG("sk %p", sk); + + lock_sock(sk); + + switch (optname) { + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt) + set_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + else + clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + break; + + case BT_ISO_QOS: + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND && + sk->sk_state != BT_CONNECT2) { + err = -EINVAL; + break; + } + + len = min_t(unsigned int, sizeof(qos), optlen); + if (len != sizeof(qos)) { + err = -EINVAL; + break; + } + + memset(&qos, 0, sizeof(qos)); + + if (copy_from_sockptr(&qos, optval, len)) { + err = -EFAULT; + break; + } + + if (!check_qos(&qos)) { + err = -EINVAL; + break; + } + + iso_pi(sk)->qos = qos; + + break; + + case BT_ISO_BASE: + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND && + sk->sk_state != BT_CONNECT2) { + err = -EINVAL; + break; + } + + if (optlen > sizeof(iso_pi(sk)->base)) { + err = -EOVERFLOW; + break; + } + + len = min_t(unsigned int, sizeof(iso_pi(sk)->base), optlen); + + if (copy_from_sockptr(iso_pi(sk)->base, optval, len)) { + err = -EFAULT; + break; + } + + iso_pi(sk)->base_len = len; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int iso_sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int len, err = 0; + struct bt_iso_qos *qos; + u8 base_len; + u8 *base; + + BT_DBG("sk %p", sk); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case BT_DEFER_SETUP: + if (sk->sk_state == BT_CONNECTED) { + err = -EINVAL; + break; + } + + if (put_user(test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags), + (u32 __user *)optval)) + err = -EFAULT; + + break; + + case BT_ISO_QOS: + qos = iso_sock_get_qos(sk); + + len = min_t(unsigned int, len, sizeof(*qos)); + if (copy_to_user(optval, qos, len)) + err = -EFAULT; + + break; + + case BT_ISO_BASE: + if (sk->sk_state == BT_CONNECTED) { + base_len = iso_pi(sk)->conn->hcon->le_per_adv_data_len; + base = iso_pi(sk)->conn->hcon->le_per_adv_data; + } else { + base_len = iso_pi(sk)->base_len; + base = iso_pi(sk)->base; + } + + len = min_t(unsigned int, len, base_len); + if (copy_to_user(optval, base, len)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int iso_sock_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sock %p, sk %p, how %d", sock, sk, how); + + if (!sk) + return 0; + + sock_hold(sk); + lock_sock(sk); + + switch (how) { + case SHUT_RD: + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto unlock; + sk->sk_shutdown |= RCV_SHUTDOWN; + break; + case SHUT_WR: + if (sk->sk_shutdown & 
SEND_SHUTDOWN) + goto unlock; + sk->sk_shutdown |= SEND_SHUTDOWN; + break; + case SHUT_RDWR: + if (sk->sk_shutdown & SHUTDOWN_MASK) + goto unlock; + sk->sk_shutdown |= SHUTDOWN_MASK; + break; + } + + iso_sock_clear_timer(sk); + __iso_sock_close(sk); + + if (sock_flag(sk, SOCK_LINGER) && sk->sk_lingertime && + !(current->flags & PF_EXITING)) + err = bt_sock_wait_state(sk, BT_CLOSED, sk->sk_lingertime); + +unlock: + release_sock(sk); + sock_put(sk); + + return err; +} + +static int iso_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + iso_sock_close(sk); + + if (sock_flag(sk, SOCK_LINGER) && READ_ONCE(sk->sk_lingertime) && + !(current->flags & PF_EXITING)) { + lock_sock(sk); + err = bt_sock_wait_state(sk, BT_CLOSED, sk->sk_lingertime); + release_sock(sk); + } + + sock_orphan(sk); + iso_sock_kill(sk); + return err; +} + +static void iso_sock_ready(struct sock *sk) +{ + BT_DBG("sk %p", sk); + + if (!sk) + return; + + lock_sock(sk); + iso_sock_clear_timer(sk); + sk->sk_state = BT_CONNECTED; + sk->sk_state_change(sk); + release_sock(sk); +} + +struct iso_list_data { + struct hci_conn *hcon; + int count; +}; + +static bool iso_match_big(struct sock *sk, void *data) +{ + struct hci_evt_le_big_sync_estabilished *ev = data; + + return ev->handle == iso_pi(sk)->qos.big; +} + +static void iso_conn_ready(struct iso_conn *conn) +{ + struct sock *parent; + struct sock *sk = conn->sk; + struct hci_ev_le_big_sync_estabilished *ev; + struct hci_conn *hcon; + + BT_DBG("conn %p", conn); + + if (sk) { + iso_sock_ready(conn->sk); + } else { + hcon = conn->hcon; + if (!hcon) + return; + + ev = hci_recv_event_data(hcon->hdev, + HCI_EVT_LE_BIG_SYNC_ESTABILISHED); + if (ev) + parent = iso_get_sock_listen(&hcon->src, + &hcon->dst, + iso_match_big, ev); + else + parent = iso_get_sock_listen(&hcon->src, + BDADDR_ANY, NULL, NULL); + + if (!parent) + return; + + lock_sock(parent); + + sk = iso_sock_alloc(sock_net(parent), NULL, + BTPROTO_ISO, GFP_ATOMIC, 0); + if (!sk) { + release_sock(parent); + return; + } + + iso_sock_init(sk, parent); + + bacpy(&iso_pi(sk)->src, &hcon->src); + iso_pi(sk)->src_type = hcon->src_type; + + /* If hcon has no destination address (BDADDR_ANY) it means it + * was created by HCI_EV_LE_BIG_SYNC_ESTABILISHED so we need to + * initialize using the parent socket destination address. 
+ */ + if (!bacmp(&hcon->dst, BDADDR_ANY)) { + bacpy(&hcon->dst, &iso_pi(parent)->dst); + hcon->dst_type = iso_pi(parent)->dst_type; + hcon->sync_handle = iso_pi(parent)->sync_handle; + } + + bacpy(&iso_pi(sk)->dst, &hcon->dst); + iso_pi(sk)->dst_type = hcon->dst_type; + + hci_conn_hold(hcon); + iso_chan_add(conn, sk, parent); + + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags)) + sk->sk_state = BT_CONNECT2; + else + sk->sk_state = BT_CONNECTED; + + /* Wake up parent */ + parent->sk_data_ready(parent); + + release_sock(parent); + } +} + +static bool iso_match_sid(struct sock *sk, void *data) +{ + struct hci_ev_le_pa_sync_established *ev = data; + + return ev->sid == iso_pi(sk)->bc_sid; +} + +static bool iso_match_sync_handle(struct sock *sk, void *data) +{ + struct hci_evt_le_big_info_adv_report *ev = data; + + return le16_to_cpu(ev->sync_handle) == iso_pi(sk)->sync_handle; +} + +/* ----- ISO interface with lower layer (HCI) ----- */ + +int iso_connect_ind(struct hci_dev *hdev, bdaddr_t *bdaddr, __u8 *flags) +{ + struct hci_ev_le_pa_sync_established *ev1; + struct hci_evt_le_big_info_adv_report *ev2; + struct sock *sk; + int lm = 0; + + bt_dev_dbg(hdev, "bdaddr %pMR", bdaddr); + + /* Broadcast receiver requires handling of some events before it can + * proceed to establishing a BIG sync: + * + * 1. HCI_EV_LE_PA_SYNC_ESTABLISHED: The socket may specify a specific + * SID to listen to and once sync is estabilished its handle needs to + * be stored in iso_pi(sk)->sync_handle so it can be matched once + * receiving the BIG Info. + * 2. HCI_EVT_LE_BIG_INFO_ADV_REPORT: When connect_ind is triggered by a + * a BIG Info it attempts to check if there any listening socket with + * the same sync_handle and if it does then attempt to create a sync. + */ + ev1 = hci_recv_event_data(hdev, HCI_EV_LE_PA_SYNC_ESTABLISHED); + if (ev1) { + sk = iso_get_sock_listen(&hdev->bdaddr, bdaddr, iso_match_sid, + ev1); + if (sk) + iso_pi(sk)->sync_handle = le16_to_cpu(ev1->handle); + + goto done; + } + + ev2 = hci_recv_event_data(hdev, HCI_EVT_LE_BIG_INFO_ADV_REPORT); + if (ev2) { + sk = iso_get_sock_listen(&hdev->bdaddr, bdaddr, + iso_match_sync_handle, ev2); + if (sk) { + int err; + + if (ev2->num_bis < iso_pi(sk)->bc_num_bis) + iso_pi(sk)->bc_num_bis = ev2->num_bis; + + err = hci_le_big_create_sync(hdev, + &iso_pi(sk)->qos, + iso_pi(sk)->sync_handle, + iso_pi(sk)->bc_num_bis, + iso_pi(sk)->bc_bis); + if (err) { + bt_dev_err(hdev, "hci_le_big_create_sync: %d", + err); + sk = NULL; + } + } + } else { + sk = iso_get_sock_listen(&hdev->bdaddr, BDADDR_ANY, NULL, NULL); + } + +done: + if (!sk) + return lm; + + lm |= HCI_LM_ACCEPT; + + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) + *flags |= HCI_PROTO_DEFER; + + return lm; +} + +static void iso_connect_cfm(struct hci_conn *hcon, __u8 status) +{ + if (hcon->type != ISO_LINK) { + if (hcon->type != LE_LINK) + return; + + /* Check if LE link has failed */ + if (status) { + if (hcon->link) + iso_conn_del(hcon->link, bt_to_errno(status)); + return; + } + + /* Create CIS if pending */ + hci_le_create_cis(hcon); + return; + } + + BT_DBG("hcon %p bdaddr %pMR status %d", hcon, &hcon->dst, status); + + if (!status) { + struct iso_conn *conn; + + conn = iso_conn_add(hcon); + if (conn) + iso_conn_ready(conn); + } else { + iso_conn_del(hcon, bt_to_errno(status)); + } +} + +static void iso_disconn_cfm(struct hci_conn *hcon, __u8 reason) +{ + if (hcon->type != ISO_LINK) + return; + + BT_DBG("hcon %p reason %d", hcon, reason); + + iso_conn_del(hcon, bt_to_errno(reason)); +} 
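+
+/*
+ * Userspace usage (illustrative sketch): a central-role client is expected
+ * to open a SOCK_SEQPACKET/BTPROTO_ISO socket, bind it to a local LE
+ * address, adjust the QoS through BT_ISO_QOS if needed and then connect()
+ * before exchanging SDUs with send()/recv(). The snippet assumes userspace
+ * headers that mirror struct sockaddr_iso and struct bt_iso_qos (e.g. the
+ * BlueZ library headers); "remote_bdaddr" is a stand-in for the peer
+ * address and error handling is omitted for brevity.
+ *
+ *	int fd = socket(AF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_ISO);
+ *
+ *	struct sockaddr_iso addr = { .iso_family = AF_BLUETOOTH };
+ *	bacpy(&addr.iso_bdaddr, BDADDR_ANY);
+ *	addr.iso_bdaddr_type = BDADDR_LE_PUBLIC;
+ *	bind(fd, (struct sockaddr *)&addr, sizeof(addr));
+ *
+ *	struct bt_iso_qos qos;
+ *	socklen_t len = sizeof(qos);
+ *	getsockopt(fd, SOL_BLUETOOTH, BT_ISO_QOS, &qos, &len);
+ *	qos.out.sdu = 120;	/* keep the default PHY/interval/latency */
+ *	setsockopt(fd, SOL_BLUETOOTH, BT_ISO_QOS, &qos, sizeof(qos));
+ *
+ *	bacpy(&addr.iso_bdaddr, &remote_bdaddr);
+ *	connect(fd, (struct sockaddr *)&addr, sizeof(addr));
+ *
+ *	send(fd, sdu, sdu_len, 0);
+ */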
+ +void iso_recv(struct hci_conn *hcon, struct sk_buff *skb, u16 flags) +{ + struct iso_conn *conn = hcon->iso_data; + __u16 pb, ts, len; + + if (!conn) + goto drop; + + pb = hci_iso_flags_pb(flags); + ts = hci_iso_flags_ts(flags); + + BT_DBG("conn %p len %d pb 0x%x ts 0x%x", conn, skb->len, pb, ts); + + switch (pb) { + case ISO_START: + case ISO_SINGLE: + if (conn->rx_len) { + BT_ERR("Unexpected start frame (len %d)", skb->len); + kfree_skb(conn->rx_skb); + conn->rx_skb = NULL; + conn->rx_len = 0; + } + + if (ts) { + struct hci_iso_ts_data_hdr *hdr; + + /* TODO: add timestamp to the packet? */ + hdr = skb_pull_data(skb, HCI_ISO_TS_DATA_HDR_SIZE); + if (!hdr) { + BT_ERR("Frame is too short (len %d)", skb->len); + goto drop; + } + + len = __le16_to_cpu(hdr->slen); + } else { + struct hci_iso_data_hdr *hdr; + + hdr = skb_pull_data(skb, HCI_ISO_DATA_HDR_SIZE); + if (!hdr) { + BT_ERR("Frame is too short (len %d)", skb->len); + goto drop; + } + + len = __le16_to_cpu(hdr->slen); + } + + flags = hci_iso_data_flags(len); + len = hci_iso_data_len(len); + + BT_DBG("Start: total len %d, frag len %d flags 0x%4.4x", len, + skb->len, flags); + + if (len == skb->len) { + /* Complete frame received */ + iso_recv_frame(conn, skb); + return; + } + + if (pb == ISO_SINGLE) { + BT_ERR("Frame malformed (len %d, expected len %d)", + skb->len, len); + goto drop; + } + + if (skb->len > len) { + BT_ERR("Frame is too long (len %d, expected len %d)", + skb->len, len); + goto drop; + } + + /* Allocate skb for the complete frame (with header) */ + conn->rx_skb = bt_skb_alloc(len, GFP_KERNEL); + if (!conn->rx_skb) + goto drop; + + skb_copy_from_linear_data(skb, skb_put(conn->rx_skb, skb->len), + skb->len); + conn->rx_len = len - skb->len; + break; + + case ISO_CONT: + BT_DBG("Cont: frag len %d (expecting %d)", skb->len, + conn->rx_len); + + if (!conn->rx_len) { + BT_ERR("Unexpected continuation frame (len %d)", + skb->len); + goto drop; + } + + if (skb->len > conn->rx_len) { + BT_ERR("Fragment is too long (len %d, expected %d)", + skb->len, conn->rx_len); + kfree_skb(conn->rx_skb); + conn->rx_skb = NULL; + conn->rx_len = 0; + goto drop; + } + + skb_copy_from_linear_data(skb, skb_put(conn->rx_skb, skb->len), + skb->len); + conn->rx_len -= skb->len; + return; + + case ISO_END: + skb_copy_from_linear_data(skb, skb_put(conn->rx_skb, skb->len), + skb->len); + conn->rx_len -= skb->len; + + if (!conn->rx_len) { + struct sk_buff *rx_skb = conn->rx_skb; + + /* Complete frame received. iso_recv_frame + * takes ownership of the skb so set the global + * rx_skb pointer to NULL first. 
+ */ + conn->rx_skb = NULL; + iso_recv_frame(conn, rx_skb); + } + break; + } + +drop: + kfree_skb(skb); +} + +static struct hci_cb iso_cb = { + .name = "ISO", + .connect_cfm = iso_connect_cfm, + .disconn_cfm = iso_disconn_cfm, +}; + +static int iso_debugfs_show(struct seq_file *f, void *p) +{ + struct sock *sk; + + read_lock(&iso_sk_list.lock); + + sk_for_each(sk, &iso_sk_list.head) { + seq_printf(f, "%pMR %pMR %d\n", &iso_pi(sk)->src, + &iso_pi(sk)->dst, sk->sk_state); + } + + read_unlock(&iso_sk_list.lock); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(iso_debugfs); + +static struct dentry *iso_debugfs; + +static const struct proto_ops iso_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = iso_sock_release, + .bind = iso_sock_bind, + .connect = iso_sock_connect, + .listen = iso_sock_listen, + .accept = iso_sock_accept, + .getname = iso_sock_getname, + .sendmsg = iso_sock_sendmsg, + .recvmsg = iso_sock_recvmsg, + .poll = bt_sock_poll, + .ioctl = bt_sock_ioctl, + .mmap = sock_no_mmap, + .socketpair = sock_no_socketpair, + .shutdown = iso_sock_shutdown, + .setsockopt = iso_sock_setsockopt, + .getsockopt = iso_sock_getsockopt +}; + +static const struct net_proto_family iso_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = iso_sock_create, +}; + +static bool iso_inited; + +bool iso_enabled(void) +{ + return iso_inited; +} + +int iso_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct sockaddr_iso) > sizeof(struct sockaddr)); + + if (iso_inited) + return -EALREADY; + + err = proto_register(&iso_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_ISO, &iso_sock_family_ops); + if (err < 0) { + BT_ERR("ISO socket registration failed"); + goto error; + } + + err = bt_procfs_init(&init_net, "iso", &iso_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create ISO proc file"); + bt_sock_unregister(BTPROTO_ISO); + goto error; + } + + BT_INFO("ISO socket layer initialized"); + + hci_register_cb(&iso_cb); + + if (IS_ERR_OR_NULL(bt_debugfs)) + return 0; + + if (!iso_debugfs) { + iso_debugfs = debugfs_create_file("iso", 0444, bt_debugfs, + NULL, &iso_debugfs_fops); + } + + iso_inited = true; + + return 0; + +error: + proto_unregister(&iso_proto); + return err; +} + +int iso_exit(void) +{ + if (!iso_inited) + return -EALREADY; + + bt_procfs_cleanup(&init_net, "iso"); + + debugfs_remove(iso_debugfs); + iso_debugfs = NULL; + + hci_unregister_cb(&iso_cb); + + bt_sock_unregister(BTPROTO_ISO); + + proto_unregister(&iso_proto); + + iso_inited = false; + + return 0; +} diff --git a/net/bluetooth/l2cap_core.c b/net/bluetooth/l2cap_core.c new file mode 100644 index 000000000..4c5793053 --- /dev/null +++ b/net/bluetooth/l2cap_core.c @@ -0,0 +1,8657 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + Copyright (C) 2009-2010 Gustavo F. Padovan + Copyright (C) 2010 Google Inc. + Copyright (C) 2011 ProFUSION Embedded Systems + Copyright (c) 2012 Code Aurora Forum. All rights reserved. + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. 
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth L2CAP core. */ + +#include + +#include +#include +#include + +#include +#include +#include + +#include "smp.h" +#include "a2mp.h" +#include "amp.h" + +#define LE_FLOWCTL_MAX_CREDITS 65535 + +bool disable_ertm; +bool enable_ecred; + +static u32 l2cap_feat_mask = L2CAP_FEAT_FIXED_CHAN | L2CAP_FEAT_UCD; + +static LIST_HEAD(chan_list); +static DEFINE_RWLOCK(chan_list_lock); + +static struct sk_buff *l2cap_build_cmd(struct l2cap_conn *conn, + u8 code, u8 ident, u16 dlen, void *data); +static void l2cap_send_cmd(struct l2cap_conn *conn, u8 ident, u8 code, u16 len, + void *data); +static int l2cap_build_conf_req(struct l2cap_chan *chan, void *data, size_t data_size); +static void l2cap_send_disconn_req(struct l2cap_chan *chan, int err); + +static void l2cap_tx(struct l2cap_chan *chan, struct l2cap_ctrl *control, + struct sk_buff_head *skbs, u8 event); +static void l2cap_retrans_timeout(struct work_struct *work); +static void l2cap_monitor_timeout(struct work_struct *work); +static void l2cap_ack_timeout(struct work_struct *work); + +static inline u8 bdaddr_type(u8 link_type, u8 bdaddr_type) +{ + if (link_type == LE_LINK) { + if (bdaddr_type == ADDR_LE_DEV_PUBLIC) + return BDADDR_LE_PUBLIC; + else + return BDADDR_LE_RANDOM; + } + + return BDADDR_BREDR; +} + +static inline u8 bdaddr_src_type(struct hci_conn *hcon) +{ + return bdaddr_type(hcon->type, hcon->src_type); +} + +static inline u8 bdaddr_dst_type(struct hci_conn *hcon) +{ + return bdaddr_type(hcon->type, hcon->dst_type); +} + +/* ---- L2CAP channels ---- */ + +static struct l2cap_chan *__l2cap_get_chan_by_dcid(struct l2cap_conn *conn, + u16 cid) +{ + struct l2cap_chan *c; + + list_for_each_entry(c, &conn->chan_l, list) { + if (c->dcid == cid) + return c; + } + return NULL; +} + +static struct l2cap_chan *__l2cap_get_chan_by_scid(struct l2cap_conn *conn, + u16 cid) +{ + struct l2cap_chan *c; + + list_for_each_entry(c, &conn->chan_l, list) { + if (c->scid == cid) + return c; + } + return NULL; +} + +/* Find channel with given SCID. + * Returns a reference locked channel. + */ +static struct l2cap_chan *l2cap_get_chan_by_scid(struct l2cap_conn *conn, + u16 cid) +{ + struct l2cap_chan *c; + + mutex_lock(&conn->chan_lock); + c = __l2cap_get_chan_by_scid(conn, cid); + if (c) { + /* Only lock if chan reference is not 0 */ + c = l2cap_chan_hold_unless_zero(c); + if (c) + l2cap_chan_lock(c); + } + mutex_unlock(&conn->chan_lock); + + return c; +} + +/* Find channel with given DCID. + * Returns a reference locked channel. 
+ */ +static struct l2cap_chan *l2cap_get_chan_by_dcid(struct l2cap_conn *conn, + u16 cid) +{ + struct l2cap_chan *c; + + mutex_lock(&conn->chan_lock); + c = __l2cap_get_chan_by_dcid(conn, cid); + if (c) { + /* Only lock if chan reference is not 0 */ + c = l2cap_chan_hold_unless_zero(c); + if (c) + l2cap_chan_lock(c); + } + mutex_unlock(&conn->chan_lock); + + return c; +} + +static struct l2cap_chan *__l2cap_get_chan_by_ident(struct l2cap_conn *conn, + u8 ident) +{ + struct l2cap_chan *c; + + list_for_each_entry(c, &conn->chan_l, list) { + if (c->ident == ident) + return c; + } + return NULL; +} + +static struct l2cap_chan *l2cap_get_chan_by_ident(struct l2cap_conn *conn, + u8 ident) +{ + struct l2cap_chan *c; + + mutex_lock(&conn->chan_lock); + c = __l2cap_get_chan_by_ident(conn, ident); + if (c) { + /* Only lock if chan reference is not 0 */ + c = l2cap_chan_hold_unless_zero(c); + if (c) + l2cap_chan_lock(c); + } + mutex_unlock(&conn->chan_lock); + + return c; +} + +static struct l2cap_chan *__l2cap_global_chan_by_addr(__le16 psm, bdaddr_t *src, + u8 src_type) +{ + struct l2cap_chan *c; + + list_for_each_entry(c, &chan_list, global_l) { + if (src_type == BDADDR_BREDR && c->src_type != BDADDR_BREDR) + continue; + + if (src_type != BDADDR_BREDR && c->src_type == BDADDR_BREDR) + continue; + + if (c->sport == psm && !bacmp(&c->src, src)) + return c; + } + return NULL; +} + +int l2cap_add_psm(struct l2cap_chan *chan, bdaddr_t *src, __le16 psm) +{ + int err; + + write_lock(&chan_list_lock); + + if (psm && __l2cap_global_chan_by_addr(psm, src, chan->src_type)) { + err = -EADDRINUSE; + goto done; + } + + if (psm) { + chan->psm = psm; + chan->sport = psm; + err = 0; + } else { + u16 p, start, end, incr; + + if (chan->src_type == BDADDR_BREDR) { + start = L2CAP_PSM_DYN_START; + end = L2CAP_PSM_AUTO_END; + incr = 2; + } else { + start = L2CAP_PSM_LE_DYN_START; + end = L2CAP_PSM_LE_DYN_END; + incr = 1; + } + + err = -EINVAL; + for (p = start; p <= end; p += incr) + if (!__l2cap_global_chan_by_addr(cpu_to_le16(p), src, + chan->src_type)) { + chan->psm = cpu_to_le16(p); + chan->sport = cpu_to_le16(p); + err = 0; + break; + } + } + +done: + write_unlock(&chan_list_lock); + return err; +} +EXPORT_SYMBOL_GPL(l2cap_add_psm); + +int l2cap_add_scid(struct l2cap_chan *chan, __u16 scid) +{ + write_lock(&chan_list_lock); + + /* Override the defaults (which are for conn-oriented) */ + chan->omtu = L2CAP_DEFAULT_MTU; + chan->chan_type = L2CAP_CHAN_FIXED; + + chan->scid = scid; + + write_unlock(&chan_list_lock); + + return 0; +} + +static u16 l2cap_alloc_cid(struct l2cap_conn *conn) +{ + u16 cid, dyn_end; + + if (conn->hcon->type == LE_LINK) + dyn_end = L2CAP_CID_LE_DYN_END; + else + dyn_end = L2CAP_CID_DYN_END; + + for (cid = L2CAP_CID_DYN_START; cid <= dyn_end; cid++) { + if (!__l2cap_get_chan_by_scid(conn, cid)) + return cid; + } + + return 0; +} + +static void l2cap_state_change(struct l2cap_chan *chan, int state) +{ + BT_DBG("chan %p %s -> %s", chan, state_to_string(chan->state), + state_to_string(state)); + + chan->state = state; + chan->ops->state_change(chan, state, 0); +} + +static inline void l2cap_state_change_and_error(struct l2cap_chan *chan, + int state, int err) +{ + chan->state = state; + chan->ops->state_change(chan, chan->state, err); +} + +static inline void l2cap_chan_set_err(struct l2cap_chan *chan, int err) +{ + chan->ops->state_change(chan, chan->state, err); +} + +static void __set_retrans_timer(struct l2cap_chan *chan) +{ + if (!delayed_work_pending(&chan->monitor_timer) && + 
chan->retrans_timeout) { + l2cap_set_timer(chan, &chan->retrans_timer, + msecs_to_jiffies(chan->retrans_timeout)); + } +} + +static void __set_monitor_timer(struct l2cap_chan *chan) +{ + __clear_retrans_timer(chan); + if (chan->monitor_timeout) { + l2cap_set_timer(chan, &chan->monitor_timer, + msecs_to_jiffies(chan->monitor_timeout)); + } +} + +static struct sk_buff *l2cap_ertm_seq_in_queue(struct sk_buff_head *head, + u16 seq) +{ + struct sk_buff *skb; + + skb_queue_walk(head, skb) { + if (bt_cb(skb)->l2cap.txseq == seq) + return skb; + } + + return NULL; +} + +/* ---- L2CAP sequence number lists ---- */ + +/* For ERTM, ordered lists of sequence numbers must be tracked for + * SREJ requests that are received and for frames that are to be + * retransmitted. These seq_list functions implement a singly-linked + * list in an array, where membership in the list can also be checked + * in constant time. Items can also be added to the tail of the list + * and removed from the head in constant time, without further memory + * allocs or frees. + */ + +static int l2cap_seq_list_init(struct l2cap_seq_list *seq_list, u16 size) +{ + size_t alloc_size, i; + + /* Allocated size is a power of 2 to map sequence numbers + * (which may be up to 14 bits) in to a smaller array that is + * sized for the negotiated ERTM transmit windows. + */ + alloc_size = roundup_pow_of_two(size); + + seq_list->list = kmalloc_array(alloc_size, sizeof(u16), GFP_KERNEL); + if (!seq_list->list) + return -ENOMEM; + + seq_list->mask = alloc_size - 1; + seq_list->head = L2CAP_SEQ_LIST_CLEAR; + seq_list->tail = L2CAP_SEQ_LIST_CLEAR; + for (i = 0; i < alloc_size; i++) + seq_list->list[i] = L2CAP_SEQ_LIST_CLEAR; + + return 0; +} + +static inline void l2cap_seq_list_free(struct l2cap_seq_list *seq_list) +{ + kfree(seq_list->list); +} + +static inline bool l2cap_seq_list_contains(struct l2cap_seq_list *seq_list, + u16 seq) +{ + /* Constant-time check for list membership */ + return seq_list->list[seq & seq_list->mask] != L2CAP_SEQ_LIST_CLEAR; +} + +static inline u16 l2cap_seq_list_pop(struct l2cap_seq_list *seq_list) +{ + u16 seq = seq_list->head; + u16 mask = seq_list->mask; + + seq_list->head = seq_list->list[seq & mask]; + seq_list->list[seq & mask] = L2CAP_SEQ_LIST_CLEAR; + + if (seq_list->head == L2CAP_SEQ_LIST_TAIL) { + seq_list->head = L2CAP_SEQ_LIST_CLEAR; + seq_list->tail = L2CAP_SEQ_LIST_CLEAR; + } + + return seq; +} + +static void l2cap_seq_list_clear(struct l2cap_seq_list *seq_list) +{ + u16 i; + + if (seq_list->head == L2CAP_SEQ_LIST_CLEAR) + return; + + for (i = 0; i <= seq_list->mask; i++) + seq_list->list[i] = L2CAP_SEQ_LIST_CLEAR; + + seq_list->head = L2CAP_SEQ_LIST_CLEAR; + seq_list->tail = L2CAP_SEQ_LIST_CLEAR; +} + +static void l2cap_seq_list_append(struct l2cap_seq_list *seq_list, u16 seq) +{ + u16 mask = seq_list->mask; + + /* All appends happen in constant time */ + + if (seq_list->list[seq & mask] != L2CAP_SEQ_LIST_CLEAR) + return; + + if (seq_list->tail == L2CAP_SEQ_LIST_CLEAR) + seq_list->head = seq; + else + seq_list->list[seq_list->tail & mask] = seq; + + seq_list->tail = seq; + seq_list->list[seq & mask] = L2CAP_SEQ_LIST_TAIL; +} + +static void l2cap_chan_timeout(struct work_struct *work) +{ + struct l2cap_chan *chan = container_of(work, struct l2cap_chan, + chan_timer.work); + struct l2cap_conn *conn = chan->conn; + int reason; + + BT_DBG("chan %p state %s", chan, state_to_string(chan->state)); + + mutex_lock(&conn->chan_lock); + /* __set_chan_timer() calls l2cap_chan_hold(chan) while scheduling + * this 
work. No need to call l2cap_chan_hold(chan) here again. + */ + l2cap_chan_lock(chan); + + if (chan->state == BT_CONNECTED || chan->state == BT_CONFIG) + reason = ECONNREFUSED; + else if (chan->state == BT_CONNECT && + chan->sec_level != BT_SECURITY_SDP) + reason = ECONNREFUSED; + else + reason = ETIMEDOUT; + + l2cap_chan_close(chan, reason); + + chan->ops->close(chan); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + mutex_unlock(&conn->chan_lock); +} + +struct l2cap_chan *l2cap_chan_create(void) +{ + struct l2cap_chan *chan; + + chan = kzalloc(sizeof(*chan), GFP_ATOMIC); + if (!chan) + return NULL; + + skb_queue_head_init(&chan->tx_q); + skb_queue_head_init(&chan->srej_q); + mutex_init(&chan->lock); + + /* Set default lock nesting level */ + atomic_set(&chan->nesting, L2CAP_NESTING_NORMAL); + + write_lock(&chan_list_lock); + list_add(&chan->global_l, &chan_list); + write_unlock(&chan_list_lock); + + INIT_DELAYED_WORK(&chan->chan_timer, l2cap_chan_timeout); + INIT_DELAYED_WORK(&chan->retrans_timer, l2cap_retrans_timeout); + INIT_DELAYED_WORK(&chan->monitor_timer, l2cap_monitor_timeout); + INIT_DELAYED_WORK(&chan->ack_timer, l2cap_ack_timeout); + + chan->state = BT_OPEN; + + kref_init(&chan->kref); + + /* This flag is cleared in l2cap_chan_ready() */ + set_bit(CONF_NOT_COMPLETE, &chan->conf_state); + + BT_DBG("chan %p", chan); + + return chan; +} +EXPORT_SYMBOL_GPL(l2cap_chan_create); + +static void l2cap_chan_destroy(struct kref *kref) +{ + struct l2cap_chan *chan = container_of(kref, struct l2cap_chan, kref); + + BT_DBG("chan %p", chan); + + write_lock(&chan_list_lock); + list_del(&chan->global_l); + write_unlock(&chan_list_lock); + + kfree(chan); +} + +void l2cap_chan_hold(struct l2cap_chan *c) +{ + BT_DBG("chan %p orig refcnt %u", c, kref_read(&c->kref)); + + kref_get(&c->kref); +} + +struct l2cap_chan *l2cap_chan_hold_unless_zero(struct l2cap_chan *c) +{ + BT_DBG("chan %p orig refcnt %u", c, kref_read(&c->kref)); + + if (!kref_get_unless_zero(&c->kref)) + return NULL; + + return c; +} + +void l2cap_chan_put(struct l2cap_chan *c) +{ + BT_DBG("chan %p orig refcnt %u", c, kref_read(&c->kref)); + + kref_put(&c->kref, l2cap_chan_destroy); +} +EXPORT_SYMBOL_GPL(l2cap_chan_put); + +void l2cap_chan_set_defaults(struct l2cap_chan *chan) +{ + chan->fcs = L2CAP_FCS_CRC16; + chan->max_tx = L2CAP_DEFAULT_MAX_TX; + chan->tx_win = L2CAP_DEFAULT_TX_WINDOW; + chan->tx_win_max = L2CAP_DEFAULT_TX_WINDOW; + chan->remote_max_tx = chan->max_tx; + chan->remote_tx_win = chan->tx_win; + chan->ack_win = L2CAP_DEFAULT_TX_WINDOW; + chan->sec_level = BT_SECURITY_LOW; + chan->flush_to = L2CAP_DEFAULT_FLUSH_TO; + chan->retrans_timeout = L2CAP_DEFAULT_RETRANS_TO; + chan->monitor_timeout = L2CAP_DEFAULT_MONITOR_TO; + + chan->conf_state = 0; + set_bit(CONF_NOT_COMPLETE, &chan->conf_state); + + set_bit(FLAG_FORCE_ACTIVE, &chan->flags); +} +EXPORT_SYMBOL_GPL(l2cap_chan_set_defaults); + +static void l2cap_le_flowctl_init(struct l2cap_chan *chan, u16 tx_credits) +{ + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + chan->tx_credits = tx_credits; + /* Derive MPS from connection MTU to stop HCI fragmentation */ + chan->mps = min_t(u16, chan->imtu, chan->conn->mtu - L2CAP_HDR_SIZE); + /* Give enough credits for a full packet */ + chan->rx_credits = (chan->imtu / chan->mps) + 1; + + skb_queue_head_init(&chan->tx_q); +} + +static void l2cap_ecred_init(struct l2cap_chan *chan, u16 tx_credits) +{ + l2cap_le_flowctl_init(chan, tx_credits); + + /* L2CAP implementations shall support a minimum MPS of 64 
octets */ + if (chan->mps < L2CAP_ECRED_MIN_MPS) { + chan->mps = L2CAP_ECRED_MIN_MPS; + chan->rx_credits = (chan->imtu / chan->mps) + 1; + } +} + +void __l2cap_chan_add(struct l2cap_conn *conn, struct l2cap_chan *chan) +{ + BT_DBG("conn %p, psm 0x%2.2x, dcid 0x%4.4x", conn, + __le16_to_cpu(chan->psm), chan->dcid); + + conn->disc_reason = HCI_ERROR_REMOTE_USER_TERM; + + chan->conn = conn; + + switch (chan->chan_type) { + case L2CAP_CHAN_CONN_ORIENTED: + /* Alloc CID for connection-oriented socket */ + chan->scid = l2cap_alloc_cid(conn); + if (conn->hcon->type == ACL_LINK) + chan->omtu = L2CAP_DEFAULT_MTU; + break; + + case L2CAP_CHAN_CONN_LESS: + /* Connectionless socket */ + chan->scid = L2CAP_CID_CONN_LESS; + chan->dcid = L2CAP_CID_CONN_LESS; + chan->omtu = L2CAP_DEFAULT_MTU; + break; + + case L2CAP_CHAN_FIXED: + /* Caller will set CID and CID specific MTU values */ + break; + + default: + /* Raw socket can send/recv signalling messages only */ + chan->scid = L2CAP_CID_SIGNALING; + chan->dcid = L2CAP_CID_SIGNALING; + chan->omtu = L2CAP_DEFAULT_MTU; + } + + chan->local_id = L2CAP_BESTEFFORT_ID; + chan->local_stype = L2CAP_SERV_BESTEFFORT; + chan->local_msdu = L2CAP_DEFAULT_MAX_SDU_SIZE; + chan->local_sdu_itime = L2CAP_DEFAULT_SDU_ITIME; + chan->local_acc_lat = L2CAP_DEFAULT_ACC_LAT; + chan->local_flush_to = L2CAP_EFS_DEFAULT_FLUSH_TO; + + l2cap_chan_hold(chan); + + /* Only keep a reference for fixed channels if they requested it */ + if (chan->chan_type != L2CAP_CHAN_FIXED || + test_bit(FLAG_HOLD_HCI_CONN, &chan->flags)) + hci_conn_hold(conn->hcon); + + list_add(&chan->list, &conn->chan_l); +} + +void l2cap_chan_add(struct l2cap_conn *conn, struct l2cap_chan *chan) +{ + mutex_lock(&conn->chan_lock); + __l2cap_chan_add(conn, chan); + mutex_unlock(&conn->chan_lock); +} + +void l2cap_chan_del(struct l2cap_chan *chan, int err) +{ + struct l2cap_conn *conn = chan->conn; + + __clear_chan_timer(chan); + + BT_DBG("chan %p, conn %p, err %d, state %s", chan, conn, err, + state_to_string(chan->state)); + + chan->ops->teardown(chan, err); + + if (conn) { + struct amp_mgr *mgr = conn->hcon->amp_mgr; + /* Delete from channel list */ + list_del(&chan->list); + + l2cap_chan_put(chan); + + chan->conn = NULL; + + /* Reference was only held for non-fixed channels or + * fixed channels that explicitly requested it using the + * FLAG_HOLD_HCI_CONN flag. 
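+		 * The hci_conn_drop() below balances the hci_conn_hold()
+		 * taken under the same condition in __l2cap_chan_add().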
+ */ + if (chan->chan_type != L2CAP_CHAN_FIXED || + test_bit(FLAG_HOLD_HCI_CONN, &chan->flags)) + hci_conn_drop(conn->hcon); + + if (mgr && mgr->bredr_chan == chan) + mgr->bredr_chan = NULL; + } + + if (chan->hs_hchan) { + struct hci_chan *hs_hchan = chan->hs_hchan; + + BT_DBG("chan %p disconnect hs_hchan %p", chan, hs_hchan); + amp_disconnect_logical_link(hs_hchan); + } + + if (test_bit(CONF_NOT_COMPLETE, &chan->conf_state)) + return; + + switch (chan->mode) { + case L2CAP_MODE_BASIC: + break; + + case L2CAP_MODE_LE_FLOWCTL: + case L2CAP_MODE_EXT_FLOWCTL: + skb_queue_purge(&chan->tx_q); + break; + + case L2CAP_MODE_ERTM: + __clear_retrans_timer(chan); + __clear_monitor_timer(chan); + __clear_ack_timer(chan); + + skb_queue_purge(&chan->srej_q); + + l2cap_seq_list_free(&chan->srej_list); + l2cap_seq_list_free(&chan->retrans_list); + fallthrough; + + case L2CAP_MODE_STREAMING: + skb_queue_purge(&chan->tx_q); + break; + } +} +EXPORT_SYMBOL_GPL(l2cap_chan_del); + +static void __l2cap_chan_list_id(struct l2cap_conn *conn, u16 id, + l2cap_chan_func_t func, void *data) +{ + struct l2cap_chan *chan, *l; + + list_for_each_entry_safe(chan, l, &conn->chan_l, list) { + if (chan->ident == id) + func(chan, data); + } +} + +static void __l2cap_chan_list(struct l2cap_conn *conn, l2cap_chan_func_t func, + void *data) +{ + struct l2cap_chan *chan; + + list_for_each_entry(chan, &conn->chan_l, list) { + func(chan, data); + } +} + +void l2cap_chan_list(struct l2cap_conn *conn, l2cap_chan_func_t func, + void *data) +{ + if (!conn) + return; + + mutex_lock(&conn->chan_lock); + __l2cap_chan_list(conn, func, data); + mutex_unlock(&conn->chan_lock); +} + +EXPORT_SYMBOL_GPL(l2cap_chan_list); + +static void l2cap_conn_update_id_addr(struct work_struct *work) +{ + struct l2cap_conn *conn = container_of(work, struct l2cap_conn, + id_addr_update_work); + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan; + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + l2cap_chan_lock(chan); + bacpy(&chan->dst, &hcon->dst); + chan->dst_type = bdaddr_dst_type(hcon); + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); +} + +static void l2cap_chan_le_connect_reject(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_le_conn_rsp rsp; + u16 result; + + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) + result = L2CAP_CR_LE_AUTHORIZATION; + else + result = L2CAP_CR_LE_BAD_PSM; + + l2cap_state_change(chan, BT_DISCONN); + + rsp.dcid = cpu_to_le16(chan->scid); + rsp.mtu = cpu_to_le16(chan->imtu); + rsp.mps = cpu_to_le16(chan->mps); + rsp.credits = cpu_to_le16(chan->rx_credits); + rsp.result = cpu_to_le16(result); + + l2cap_send_cmd(conn, chan->ident, L2CAP_LE_CONN_RSP, sizeof(rsp), + &rsp); +} + +static void l2cap_chan_ecred_connect_reject(struct l2cap_chan *chan) +{ + l2cap_state_change(chan, BT_DISCONN); + + __l2cap_ecred_conn_rsp_defer(chan); +} + +static void l2cap_chan_connect_reject(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_conn_rsp rsp; + u16 result; + + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) + result = L2CAP_CR_SEC_BLOCK; + else + result = L2CAP_CR_BAD_PSM; + + l2cap_state_change(chan, BT_DISCONN); + + rsp.scid = cpu_to_le16(chan->dcid); + rsp.dcid = cpu_to_le16(chan->scid); + rsp.result = cpu_to_le16(result); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + + l2cap_send_cmd(conn, chan->ident, L2CAP_CONN_RSP, sizeof(rsp), &rsp); +} + +void l2cap_chan_close(struct l2cap_chan *chan, int reason) +{ + 
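+	/* Teardown depends on the channel state: a listening channel is
+	 * simply torn down; a connected or configuring connection-oriented
+	 * channel first sends a Disconnect Request and arms the channel
+	 * timer in case the peer never answers; a channel still in
+	 * BT_CONNECT2 rejects the pending connection (ACL, LE or
+	 * enhanced-credit variant) before being deleted.
+	 */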
struct l2cap_conn *conn = chan->conn; + + BT_DBG("chan %p state %s", chan, state_to_string(chan->state)); + + switch (chan->state) { + case BT_LISTEN: + chan->ops->teardown(chan, 0); + break; + + case BT_CONNECTED: + case BT_CONFIG: + if (chan->chan_type == L2CAP_CHAN_CONN_ORIENTED) { + __set_chan_timer(chan, chan->ops->get_sndtimeo(chan)); + l2cap_send_disconn_req(chan, reason); + } else + l2cap_chan_del(chan, reason); + break; + + case BT_CONNECT2: + if (chan->chan_type == L2CAP_CHAN_CONN_ORIENTED) { + if (conn->hcon->type == ACL_LINK) + l2cap_chan_connect_reject(chan); + else if (conn->hcon->type == LE_LINK) { + switch (chan->mode) { + case L2CAP_MODE_LE_FLOWCTL: + l2cap_chan_le_connect_reject(chan); + break; + case L2CAP_MODE_EXT_FLOWCTL: + l2cap_chan_ecred_connect_reject(chan); + return; + } + } + } + + l2cap_chan_del(chan, reason); + break; + + case BT_CONNECT: + case BT_DISCONN: + l2cap_chan_del(chan, reason); + break; + + default: + chan->ops->teardown(chan, 0); + break; + } +} +EXPORT_SYMBOL(l2cap_chan_close); + +static inline u8 l2cap_get_auth_type(struct l2cap_chan *chan) +{ + switch (chan->chan_type) { + case L2CAP_CHAN_RAW: + switch (chan->sec_level) { + case BT_SECURITY_HIGH: + case BT_SECURITY_FIPS: + return HCI_AT_DEDICATED_BONDING_MITM; + case BT_SECURITY_MEDIUM: + return HCI_AT_DEDICATED_BONDING; + default: + return HCI_AT_NO_BONDING; + } + break; + case L2CAP_CHAN_CONN_LESS: + if (chan->psm == cpu_to_le16(L2CAP_PSM_3DSP)) { + if (chan->sec_level == BT_SECURITY_LOW) + chan->sec_level = BT_SECURITY_SDP; + } + if (chan->sec_level == BT_SECURITY_HIGH || + chan->sec_level == BT_SECURITY_FIPS) + return HCI_AT_NO_BONDING_MITM; + else + return HCI_AT_NO_BONDING; + break; + case L2CAP_CHAN_CONN_ORIENTED: + if (chan->psm == cpu_to_le16(L2CAP_PSM_SDP)) { + if (chan->sec_level == BT_SECURITY_LOW) + chan->sec_level = BT_SECURITY_SDP; + + if (chan->sec_level == BT_SECURITY_HIGH || + chan->sec_level == BT_SECURITY_FIPS) + return HCI_AT_NO_BONDING_MITM; + else + return HCI_AT_NO_BONDING; + } + fallthrough; + + default: + switch (chan->sec_level) { + case BT_SECURITY_HIGH: + case BT_SECURITY_FIPS: + return HCI_AT_GENERAL_BONDING_MITM; + case BT_SECURITY_MEDIUM: + return HCI_AT_GENERAL_BONDING; + default: + return HCI_AT_NO_BONDING; + } + break; + } +} + +/* Service level security */ +int l2cap_chan_check_security(struct l2cap_chan *chan, bool initiator) +{ + struct l2cap_conn *conn = chan->conn; + __u8 auth_type; + + if (conn->hcon->type == LE_LINK) + return smp_conn_security(conn->hcon, chan->sec_level); + + auth_type = l2cap_get_auth_type(chan); + + return hci_conn_security(conn->hcon, chan->sec_level, auth_type, + initiator); +} + +static u8 l2cap_get_ident(struct l2cap_conn *conn) +{ + u8 id; + + /* Get next available identificator. + * 1 - 128 are used by kernel. + * 129 - 199 are reserved. + * 200 - 254 are used by utilities like l2ping, etc. 
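+	 * Identifier 0x00 is never valid on the wire, which is why the
+	 * counter below wraps from 128 back to 1 rather than to 0.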
+ */ + + mutex_lock(&conn->ident_lock); + + if (++conn->tx_ident > 128) + conn->tx_ident = 1; + + id = conn->tx_ident; + + mutex_unlock(&conn->ident_lock); + + return id; +} + +static void l2cap_send_cmd(struct l2cap_conn *conn, u8 ident, u8 code, u16 len, + void *data) +{ + struct sk_buff *skb = l2cap_build_cmd(conn, code, ident, len, data); + u8 flags; + + BT_DBG("code 0x%2.2x", code); + + if (!skb) + return; + + /* Use NO_FLUSH if supported or we have an LE link (which does + * not support auto-flushing packets) */ + if (lmp_no_flush_capable(conn->hcon->hdev) || + conn->hcon->type == LE_LINK) + flags = ACL_START_NO_FLUSH; + else + flags = ACL_START; + + bt_cb(skb)->force_active = BT_POWER_FORCE_ACTIVE_ON; + skb->priority = HCI_PRIO_MAX; + + hci_send_acl(conn->hchan, skb, flags); +} + +static bool __chan_is_moving(struct l2cap_chan *chan) +{ + return chan->move_state != L2CAP_MOVE_STABLE && + chan->move_state != L2CAP_MOVE_WAIT_PREPARE; +} + +static void l2cap_do_send(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct hci_conn *hcon = chan->conn->hcon; + u16 flags; + + BT_DBG("chan %p, skb %p len %d priority %u", chan, skb, skb->len, + skb->priority); + + if (chan->hs_hcon && !__chan_is_moving(chan)) { + if (chan->hs_hchan) + hci_send_acl(chan->hs_hchan, skb, ACL_COMPLETE); + else + kfree_skb(skb); + + return; + } + + /* Use NO_FLUSH for LE links (where this is the only option) or + * if the BR/EDR link supports it and flushing has not been + * explicitly requested (through FLAG_FLUSHABLE). + */ + if (hcon->type == LE_LINK || + (!test_bit(FLAG_FLUSHABLE, &chan->flags) && + lmp_no_flush_capable(hcon->hdev))) + flags = ACL_START_NO_FLUSH; + else + flags = ACL_START; + + bt_cb(skb)->force_active = test_bit(FLAG_FORCE_ACTIVE, &chan->flags); + hci_send_acl(chan->conn->hchan, skb, flags); +} + +static void __unpack_enhanced_control(u16 enh, struct l2cap_ctrl *control) +{ + control->reqseq = (enh & L2CAP_CTRL_REQSEQ) >> L2CAP_CTRL_REQSEQ_SHIFT; + control->final = (enh & L2CAP_CTRL_FINAL) >> L2CAP_CTRL_FINAL_SHIFT; + + if (enh & L2CAP_CTRL_FRAME_TYPE) { + /* S-Frame */ + control->sframe = 1; + control->poll = (enh & L2CAP_CTRL_POLL) >> L2CAP_CTRL_POLL_SHIFT; + control->super = (enh & L2CAP_CTRL_SUPERVISE) >> L2CAP_CTRL_SUPER_SHIFT; + + control->sar = 0; + control->txseq = 0; + } else { + /* I-Frame */ + control->sframe = 0; + control->sar = (enh & L2CAP_CTRL_SAR) >> L2CAP_CTRL_SAR_SHIFT; + control->txseq = (enh & L2CAP_CTRL_TXSEQ) >> L2CAP_CTRL_TXSEQ_SHIFT; + + control->poll = 0; + control->super = 0; + } +} + +static void __unpack_extended_control(u32 ext, struct l2cap_ctrl *control) +{ + control->reqseq = (ext & L2CAP_EXT_CTRL_REQSEQ) >> L2CAP_EXT_CTRL_REQSEQ_SHIFT; + control->final = (ext & L2CAP_EXT_CTRL_FINAL) >> L2CAP_EXT_CTRL_FINAL_SHIFT; + + if (ext & L2CAP_EXT_CTRL_FRAME_TYPE) { + /* S-Frame */ + control->sframe = 1; + control->poll = (ext & L2CAP_EXT_CTRL_POLL) >> L2CAP_EXT_CTRL_POLL_SHIFT; + control->super = (ext & L2CAP_EXT_CTRL_SUPERVISE) >> L2CAP_EXT_CTRL_SUPER_SHIFT; + + control->sar = 0; + control->txseq = 0; + } else { + /* I-Frame */ + control->sframe = 0; + control->sar = (ext & L2CAP_EXT_CTRL_SAR) >> L2CAP_EXT_CTRL_SAR_SHIFT; + control->txseq = (ext & L2CAP_EXT_CTRL_TXSEQ) >> L2CAP_EXT_CTRL_TXSEQ_SHIFT; + + control->poll = 0; + control->super = 0; + } +} + +static inline void __unpack_control(struct l2cap_chan *chan, + struct sk_buff *skb) +{ + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) { + __unpack_extended_control(get_unaligned_le32(skb->data), + &bt_cb(skb)->l2cap); 
+ skb_pull(skb, L2CAP_EXT_CTRL_SIZE); + } else { + __unpack_enhanced_control(get_unaligned_le16(skb->data), + &bt_cb(skb)->l2cap); + skb_pull(skb, L2CAP_ENH_CTRL_SIZE); + } +} + +static u32 __pack_extended_control(struct l2cap_ctrl *control) +{ + u32 packed; + + packed = control->reqseq << L2CAP_EXT_CTRL_REQSEQ_SHIFT; + packed |= control->final << L2CAP_EXT_CTRL_FINAL_SHIFT; + + if (control->sframe) { + packed |= control->poll << L2CAP_EXT_CTRL_POLL_SHIFT; + packed |= control->super << L2CAP_EXT_CTRL_SUPER_SHIFT; + packed |= L2CAP_EXT_CTRL_FRAME_TYPE; + } else { + packed |= control->sar << L2CAP_EXT_CTRL_SAR_SHIFT; + packed |= control->txseq << L2CAP_EXT_CTRL_TXSEQ_SHIFT; + } + + return packed; +} + +static u16 __pack_enhanced_control(struct l2cap_ctrl *control) +{ + u16 packed; + + packed = control->reqseq << L2CAP_CTRL_REQSEQ_SHIFT; + packed |= control->final << L2CAP_CTRL_FINAL_SHIFT; + + if (control->sframe) { + packed |= control->poll << L2CAP_CTRL_POLL_SHIFT; + packed |= control->super << L2CAP_CTRL_SUPER_SHIFT; + packed |= L2CAP_CTRL_FRAME_TYPE; + } else { + packed |= control->sar << L2CAP_CTRL_SAR_SHIFT; + packed |= control->txseq << L2CAP_CTRL_TXSEQ_SHIFT; + } + + return packed; +} + +static inline void __pack_control(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff *skb) +{ + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) { + put_unaligned_le32(__pack_extended_control(control), + skb->data + L2CAP_HDR_SIZE); + } else { + put_unaligned_le16(__pack_enhanced_control(control), + skb->data + L2CAP_HDR_SIZE); + } +} + +static inline unsigned int __ertm_hdr_size(struct l2cap_chan *chan) +{ + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + return L2CAP_EXT_HDR_SIZE; + else + return L2CAP_ENH_HDR_SIZE; +} + +static struct sk_buff *l2cap_create_sframe_pdu(struct l2cap_chan *chan, + u32 control) +{ + struct sk_buff *skb; + struct l2cap_hdr *lh; + int hlen = __ertm_hdr_size(chan); + + if (chan->fcs == L2CAP_FCS_CRC16) + hlen += L2CAP_FCS_SIZE; + + skb = bt_skb_alloc(hlen, GFP_KERNEL); + + if (!skb) + return ERR_PTR(-ENOMEM); + + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->len = cpu_to_le16(hlen - L2CAP_HDR_SIZE); + lh->cid = cpu_to_le16(chan->dcid); + + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + put_unaligned_le32(control, skb_put(skb, L2CAP_EXT_CTRL_SIZE)); + else + put_unaligned_le16(control, skb_put(skb, L2CAP_ENH_CTRL_SIZE)); + + if (chan->fcs == L2CAP_FCS_CRC16) { + u16 fcs = crc16(0, (u8 *)skb->data, skb->len); + put_unaligned_le16(fcs, skb_put(skb, L2CAP_FCS_SIZE)); + } + + skb->priority = HCI_PRIO_MAX; + return skb; +} + +static void l2cap_send_sframe(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + struct sk_buff *skb; + u32 control_field; + + BT_DBG("chan %p, control %p", chan, control); + + if (!control->sframe) + return; + + if (__chan_is_moving(chan)) + return; + + if (test_and_clear_bit(CONN_SEND_FBIT, &chan->conn_state) && + !control->poll) + control->final = 1; + + if (control->super == L2CAP_SUPER_RR) + clear_bit(CONN_RNR_SENT, &chan->conn_state); + else if (control->super == L2CAP_SUPER_RNR) + set_bit(CONN_RNR_SENT, &chan->conn_state); + + if (control->super != L2CAP_SUPER_SREJ) { + chan->last_acked_seq = control->reqseq; + __clear_ack_timer(chan); + } + + BT_DBG("reqseq %d, final %d, poll %d, super %d", control->reqseq, + control->final, control->poll, control->super); + + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + control_field = __pack_extended_control(control); + else + control_field = __pack_enhanced_control(control); + + skb = 
l2cap_create_sframe_pdu(chan, control_field); + if (!IS_ERR(skb)) + l2cap_do_send(chan, skb); +} + +static void l2cap_send_rr_or_rnr(struct l2cap_chan *chan, bool poll) +{ + struct l2cap_ctrl control; + + BT_DBG("chan %p, poll %d", chan, poll); + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + control.poll = poll; + + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) + control.super = L2CAP_SUPER_RNR; + else + control.super = L2CAP_SUPER_RR; + + control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &control); +} + +static inline int __l2cap_no_conn_pending(struct l2cap_chan *chan) +{ + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) + return true; + + return !test_bit(CONF_CONNECT_PEND, &chan->conf_state); +} + +static bool __amp_capable(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct hci_dev *hdev; + bool amp_available = false; + + if (!(conn->local_fixed_chan & L2CAP_FC_A2MP)) + return false; + + if (!(conn->remote_fixed_chan & L2CAP_FC_A2MP)) + return false; + + read_lock(&hci_dev_list_lock); + list_for_each_entry(hdev, &hci_dev_list, list) { + if (hdev->amp_type != AMP_TYPE_BREDR && + test_bit(HCI_UP, &hdev->flags)) { + amp_available = true; + break; + } + } + read_unlock(&hci_dev_list_lock); + + if (chan->chan_policy == BT_CHANNEL_POLICY_AMP_PREFERRED) + return amp_available; + + return false; +} + +static bool l2cap_check_efs(struct l2cap_chan *chan) +{ + /* Check EFS parameters */ + return true; +} + +void l2cap_send_conn_req(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_conn_req req; + + req.scid = cpu_to_le16(chan->scid); + req.psm = chan->psm; + + chan->ident = l2cap_get_ident(conn); + + set_bit(CONF_CONNECT_PEND, &chan->conf_state); + + l2cap_send_cmd(conn, chan->ident, L2CAP_CONN_REQ, sizeof(req), &req); +} + +static void l2cap_send_create_chan_req(struct l2cap_chan *chan, u8 amp_id) +{ + struct l2cap_create_chan_req req; + req.scid = cpu_to_le16(chan->scid); + req.psm = chan->psm; + req.amp_id = amp_id; + + chan->ident = l2cap_get_ident(chan->conn); + + l2cap_send_cmd(chan->conn, chan->ident, L2CAP_CREATE_CHAN_REQ, + sizeof(req), &req); +} + +static void l2cap_move_setup(struct l2cap_chan *chan) +{ + struct sk_buff *skb; + + BT_DBG("chan %p", chan); + + if (chan->mode != L2CAP_MODE_ERTM) + return; + + __clear_retrans_timer(chan); + __clear_monitor_timer(chan); + __clear_ack_timer(chan); + + chan->retry_count = 0; + skb_queue_walk(&chan->tx_q, skb) { + if (bt_cb(skb)->l2cap.retries) + bt_cb(skb)->l2cap.retries = 1; + else + break; + } + + chan->expected_tx_seq = chan->buffer_seq; + + clear_bit(CONN_REJ_ACT, &chan->conn_state); + clear_bit(CONN_SREJ_ACT, &chan->conn_state); + l2cap_seq_list_clear(&chan->retrans_list); + l2cap_seq_list_clear(&chan->srej_list); + skb_queue_purge(&chan->srej_q); + + chan->tx_state = L2CAP_TX_STATE_XMIT; + chan->rx_state = L2CAP_RX_STATE_MOVE; + + set_bit(CONN_REMOTE_BUSY, &chan->conn_state); +} + +static void l2cap_move_done(struct l2cap_chan *chan) +{ + u8 move_role = chan->move_role; + BT_DBG("chan %p", chan); + + chan->move_state = L2CAP_MOVE_STABLE; + chan->move_role = L2CAP_MOVE_ROLE_NONE; + + if (chan->mode != L2CAP_MODE_ERTM) + return; + + switch (move_role) { + case L2CAP_MOVE_ROLE_INITIATOR: + l2cap_tx(chan, NULL, NULL, L2CAP_EV_EXPLICIT_POLL); + chan->rx_state = L2CAP_RX_STATE_WAIT_F; + break; + case L2CAP_MOVE_ROLE_RESPONDER: + chan->rx_state = L2CAP_RX_STATE_WAIT_P; + break; + } +} + +static void l2cap_chan_ready(struct l2cap_chan *chan) +{ + /* 
The channel may have already been flagged as connected in + * case of receiving data before the L2CAP info req/rsp + * procedure is complete. + */ + if (chan->state == BT_CONNECTED) + return; + + /* This clears all conf flags, including CONF_NOT_COMPLETE */ + chan->conf_state = 0; + __clear_chan_timer(chan); + + switch (chan->mode) { + case L2CAP_MODE_LE_FLOWCTL: + case L2CAP_MODE_EXT_FLOWCTL: + if (!chan->tx_credits) + chan->ops->suspend(chan); + break; + } + + chan->state = BT_CONNECTED; + + chan->ops->ready(chan); +} + +static void l2cap_le_connect(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_le_conn_req req; + + if (test_and_set_bit(FLAG_LE_CONN_REQ_SENT, &chan->flags)) + return; + + if (!chan->imtu) + chan->imtu = chan->conn->mtu; + + l2cap_le_flowctl_init(chan, 0); + + memset(&req, 0, sizeof(req)); + req.psm = chan->psm; + req.scid = cpu_to_le16(chan->scid); + req.mtu = cpu_to_le16(chan->imtu); + req.mps = cpu_to_le16(chan->mps); + req.credits = cpu_to_le16(chan->rx_credits); + + chan->ident = l2cap_get_ident(conn); + + l2cap_send_cmd(conn, chan->ident, L2CAP_LE_CONN_REQ, + sizeof(req), &req); +} + +struct l2cap_ecred_conn_data { + struct { + struct l2cap_ecred_conn_req req; + __le16 scid[5]; + } __packed pdu; + struct l2cap_chan *chan; + struct pid *pid; + int count; +}; + +static void l2cap_ecred_defer_connect(struct l2cap_chan *chan, void *data) +{ + struct l2cap_ecred_conn_data *conn = data; + struct pid *pid; + + if (chan == conn->chan) + return; + + if (!test_and_clear_bit(FLAG_DEFER_SETUP, &chan->flags)) + return; + + pid = chan->ops->get_peer_pid(chan); + + /* Only add deferred channels with the same PID/PSM */ + if (conn->pid != pid || chan->psm != conn->chan->psm || chan->ident || + chan->mode != L2CAP_MODE_EXT_FLOWCTL || chan->state != BT_CONNECT) + return; + + if (test_and_set_bit(FLAG_ECRED_CONN_REQ_SENT, &chan->flags)) + return; + + l2cap_ecred_init(chan, 0); + + /* Set the same ident so we can match on the rsp */ + chan->ident = conn->chan->ident; + + /* Include all channels deferred */ + conn->pdu.scid[conn->count] = cpu_to_le16(chan->scid); + + conn->count++; +} + +static void l2cap_ecred_connect(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_ecred_conn_data data; + + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) + return; + + if (test_and_set_bit(FLAG_ECRED_CONN_REQ_SENT, &chan->flags)) + return; + + l2cap_ecred_init(chan, 0); + + memset(&data, 0, sizeof(data)); + data.pdu.req.psm = chan->psm; + data.pdu.req.mtu = cpu_to_le16(chan->imtu); + data.pdu.req.mps = cpu_to_le16(chan->mps); + data.pdu.req.credits = cpu_to_le16(chan->rx_credits); + data.pdu.scid[0] = cpu_to_le16(chan->scid); + + chan->ident = l2cap_get_ident(conn); + + data.count = 1; + data.chan = chan; + data.pid = chan->ops->get_peer_pid(chan); + + __l2cap_chan_list(conn, l2cap_ecred_defer_connect, &data); + + l2cap_send_cmd(conn, chan->ident, L2CAP_ECRED_CONN_REQ, + sizeof(data.pdu.req) + data.count * sizeof(__le16), + &data.pdu); +} + +static void l2cap_le_start(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + + if (!smp_conn_security(conn->hcon, chan->sec_level)) + return; + + if (!chan->psm) { + l2cap_chan_ready(chan); + return; + } + + if (chan->state == BT_CONNECT) { + if (chan->mode == L2CAP_MODE_EXT_FLOWCTL) + l2cap_ecred_connect(chan); + else + l2cap_le_connect(chan); + } +} + +static void l2cap_start_connection(struct l2cap_chan *chan) +{ + if (__amp_capable(chan)) { + BT_DBG("chan %p AMP capable: 
discover AMPs", chan); + a2mp_discover_amp(chan); + } else if (chan->conn->hcon->type == LE_LINK) { + l2cap_le_start(chan); + } else { + l2cap_send_conn_req(chan); + } +} + +static void l2cap_request_info(struct l2cap_conn *conn) +{ + struct l2cap_info_req req; + + if (conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_SENT) + return; + + req.type = cpu_to_le16(L2CAP_IT_FEAT_MASK); + + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_SENT; + conn->info_ident = l2cap_get_ident(conn); + + schedule_delayed_work(&conn->info_timer, L2CAP_INFO_TIMEOUT); + + l2cap_send_cmd(conn, conn->info_ident, L2CAP_INFO_REQ, + sizeof(req), &req); +} + +static bool l2cap_check_enc_key_size(struct hci_conn *hcon) +{ + /* The minimum encryption key size needs to be enforced by the + * host stack before establishing any L2CAP connections. The + * specification in theory allows a minimum of 1, but to align + * BR/EDR and LE transports, a minimum of 7 is chosen. + * + * This check might also be called for unencrypted connections + * that have no key size requirements. Ensure that the link is + * actually encrypted before enforcing a key size. + */ + int min_key_size = hcon->hdev->min_enc_key_size; + + /* On FIPS security level, key size must be 16 bytes */ + if (hcon->sec_level == BT_SECURITY_FIPS) + min_key_size = 16; + + return (!test_bit(HCI_CONN_ENCRYPT, &hcon->flags) || + hcon->enc_key_size >= min_key_size); +} + +static void l2cap_do_start(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + + if (conn->hcon->type == LE_LINK) { + l2cap_le_start(chan); + return; + } + + if (!(conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_SENT)) { + l2cap_request_info(conn); + return; + } + + if (!(conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_DONE)) + return; + + if (!l2cap_chan_check_security(chan, true) || + !__l2cap_no_conn_pending(chan)) + return; + + if (l2cap_check_enc_key_size(conn->hcon)) + l2cap_start_connection(chan); + else + __set_chan_timer(chan, L2CAP_DISC_TIMEOUT); +} + +static inline int l2cap_mode_supported(__u8 mode, __u32 feat_mask) +{ + u32 local_feat_mask = l2cap_feat_mask; + if (!disable_ertm) + local_feat_mask |= L2CAP_FEAT_ERTM | L2CAP_FEAT_STREAMING; + + switch (mode) { + case L2CAP_MODE_ERTM: + return L2CAP_FEAT_ERTM & feat_mask & local_feat_mask; + case L2CAP_MODE_STREAMING: + return L2CAP_FEAT_STREAMING & feat_mask & local_feat_mask; + default: + return 0x00; + } +} + +static void l2cap_send_disconn_req(struct l2cap_chan *chan, int err) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_disconn_req req; + + if (!conn) + return; + + if (chan->mode == L2CAP_MODE_ERTM && chan->state == BT_CONNECTED) { + __clear_retrans_timer(chan); + __clear_monitor_timer(chan); + __clear_ack_timer(chan); + } + + if (chan->scid == L2CAP_CID_A2MP) { + l2cap_state_change(chan, BT_DISCONN); + return; + } + + req.dcid = cpu_to_le16(chan->dcid); + req.scid = cpu_to_le16(chan->scid); + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_DISCONN_REQ, + sizeof(req), &req); + + l2cap_state_change_and_error(chan, BT_DISCONN, err); +} + +/* ---- L2CAP connections ---- */ +static void l2cap_conn_start(struct l2cap_conn *conn) +{ + struct l2cap_chan *chan, *tmp; + + BT_DBG("conn %p", conn); + + mutex_lock(&conn->chan_lock); + + list_for_each_entry_safe(chan, tmp, &conn->chan_l, list) { + l2cap_chan_lock(chan); + + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) { + l2cap_chan_ready(chan); + l2cap_chan_unlock(chan); + continue; + } + + if (chan->state == BT_CONNECT) { + if (!l2cap_chan_check_security(chan, true) || + 
!__l2cap_no_conn_pending(chan)) { + l2cap_chan_unlock(chan); + continue; + } + + if (!l2cap_mode_supported(chan->mode, conn->feat_mask) + && test_bit(CONF_STATE2_DEVICE, + &chan->conf_state)) { + l2cap_chan_close(chan, ECONNRESET); + l2cap_chan_unlock(chan); + continue; + } + + if (l2cap_check_enc_key_size(conn->hcon)) + l2cap_start_connection(chan); + else + l2cap_chan_close(chan, ECONNREFUSED); + + } else if (chan->state == BT_CONNECT2) { + struct l2cap_conn_rsp rsp; + char buf[128]; + rsp.scid = cpu_to_le16(chan->dcid); + rsp.dcid = cpu_to_le16(chan->scid); + + if (l2cap_chan_check_security(chan, false)) { + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) { + rsp.result = cpu_to_le16(L2CAP_CR_PEND); + rsp.status = cpu_to_le16(L2CAP_CS_AUTHOR_PEND); + chan->ops->defer(chan); + + } else { + l2cap_state_change(chan, BT_CONFIG); + rsp.result = cpu_to_le16(L2CAP_CR_SUCCESS); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + } + } else { + rsp.result = cpu_to_le16(L2CAP_CR_PEND); + rsp.status = cpu_to_le16(L2CAP_CS_AUTHEN_PEND); + } + + l2cap_send_cmd(conn, chan->ident, L2CAP_CONN_RSP, + sizeof(rsp), &rsp); + + if (test_bit(CONF_REQ_SENT, &chan->conf_state) || + rsp.result != L2CAP_CR_SUCCESS) { + l2cap_chan_unlock(chan); + continue; + } + + set_bit(CONF_REQ_SENT, &chan->conf_state); + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), buf); + chan->num_conf_req++; + } + + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); +} + +static void l2cap_le_conn_ready(struct l2cap_conn *conn) +{ + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + + BT_DBG("%s conn %p", hdev->name, conn); + + /* For outgoing pairing which doesn't necessarily have an + * associated socket (e.g. mgmt_pair_device). + */ + if (hcon->out) + smp_conn_security(hcon, hcon->pending_sec_level); + + /* For LE peripheral connections, make sure the connection interval + * is in the range of the minimum and maximum interval that has + * been configured for this connection. If not, then trigger + * the connection update procedure. 
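+	 * The values used here are in controller units: connection interval
+	 * in 1.25 ms steps, peripheral latency in connection events and
+	 * supervision timeout in 10 ms steps, so e.g. le_conn_min_interval
+	 * == 24 means 30 ms and le_supv_timeout == 42 means 420 ms.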
+ */ + if (hcon->role == HCI_ROLE_SLAVE && + (hcon->le_conn_interval < hcon->le_conn_min_interval || + hcon->le_conn_interval > hcon->le_conn_max_interval)) { + struct l2cap_conn_param_update_req req; + + req.min = cpu_to_le16(hcon->le_conn_min_interval); + req.max = cpu_to_le16(hcon->le_conn_max_interval); + req.latency = cpu_to_le16(hcon->le_conn_latency); + req.to_multiplier = cpu_to_le16(hcon->le_supv_timeout); + + l2cap_send_cmd(conn, l2cap_get_ident(conn), + L2CAP_CONN_PARAM_UPDATE_REQ, sizeof(req), &req); + } +} + +static void l2cap_conn_ready(struct l2cap_conn *conn) +{ + struct l2cap_chan *chan; + struct hci_conn *hcon = conn->hcon; + + BT_DBG("conn %p", conn); + + if (hcon->type == ACL_LINK) + l2cap_request_info(conn); + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + + l2cap_chan_lock(chan); + + if (chan->scid == L2CAP_CID_A2MP) { + l2cap_chan_unlock(chan); + continue; + } + + if (hcon->type == LE_LINK) { + l2cap_le_start(chan); + } else if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) { + if (conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_DONE) + l2cap_chan_ready(chan); + } else if (chan->state == BT_CONNECT) { + l2cap_do_start(chan); + } + + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); + + if (hcon->type == LE_LINK) + l2cap_le_conn_ready(conn); + + queue_work(hcon->hdev->workqueue, &conn->pending_rx_work); +} + +/* Notify sockets that we cannot guaranty reliability anymore */ +static void l2cap_conn_unreliable(struct l2cap_conn *conn, int err) +{ + struct l2cap_chan *chan; + + BT_DBG("conn %p", conn); + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + if (test_bit(FLAG_FORCE_RELIABLE, &chan->flags)) + l2cap_chan_set_err(chan, err); + } + + mutex_unlock(&conn->chan_lock); +} + +static void l2cap_info_timeout(struct work_struct *work) +{ + struct l2cap_conn *conn = container_of(work, struct l2cap_conn, + info_timer.work); + + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_DONE; + conn->info_ident = 0; + + l2cap_conn_start(conn); +} + +/* + * l2cap_user + * External modules can register l2cap_user objects on l2cap_conn. The ->probe + * callback is called during registration. The ->remove callback is called + * during unregistration. + * An l2cap_user object can either be explicitly unregistered or when the + * underlying l2cap_conn object is deleted. This guarantees that l2cap->hcon, + * l2cap->hchan, .. are valid as long as the remove callback hasn't been called. + * External modules must own a reference to the l2cap_conn object if they intend + * to call l2cap_unregister_user(). The l2cap_conn object might get destroyed at + * any time if they don't. + */ + +int l2cap_register_user(struct l2cap_conn *conn, struct l2cap_user *user) +{ + struct hci_dev *hdev = conn->hcon->hdev; + int ret; + + /* We need to check whether l2cap_conn is registered. If it is not, we + * must not register the l2cap_user. l2cap_conn_del() is unregisters + * l2cap_conn objects, but doesn't provide its own locking. Instead, it + * relies on the parent hci_conn object to be locked. This itself relies + * on the hci_dev object to be locked. So we must lock the hci device + * here, too. 
*/ + + hci_dev_lock(hdev); + + if (!list_empty(&user->list)) { + ret = -EINVAL; + goto out_unlock; + } + + /* conn->hchan is NULL after l2cap_conn_del() was called */ + if (!conn->hchan) { + ret = -ENODEV; + goto out_unlock; + } + + ret = user->probe(conn, user); + if (ret) + goto out_unlock; + + list_add(&user->list, &conn->users); + ret = 0; + +out_unlock: + hci_dev_unlock(hdev); + return ret; +} +EXPORT_SYMBOL(l2cap_register_user); + +void l2cap_unregister_user(struct l2cap_conn *conn, struct l2cap_user *user) +{ + struct hci_dev *hdev = conn->hcon->hdev; + + hci_dev_lock(hdev); + + if (list_empty(&user->list)) + goto out_unlock; + + list_del_init(&user->list); + user->remove(conn, user); + +out_unlock: + hci_dev_unlock(hdev); +} +EXPORT_SYMBOL(l2cap_unregister_user); + +static void l2cap_unregister_all_users(struct l2cap_conn *conn) +{ + struct l2cap_user *user; + + while (!list_empty(&conn->users)) { + user = list_first_entry(&conn->users, struct l2cap_user, list); + list_del_init(&user->list); + user->remove(conn, user); + } +} + +static void l2cap_conn_del(struct hci_conn *hcon, int err) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + struct l2cap_chan *chan, *l; + + if (!conn) + return; + + BT_DBG("hcon %p conn %p, err %d", hcon, conn, err); + + kfree_skb(conn->rx_skb); + + skb_queue_purge(&conn->pending_rx); + + /* We can not call flush_work(&conn->pending_rx_work) here since we + * might block if we are running on a worker from the same workqueue + * pending_rx_work is waiting on. + */ + if (work_pending(&conn->pending_rx_work)) + cancel_work_sync(&conn->pending_rx_work); + + if (work_pending(&conn->id_addr_update_work)) + cancel_work_sync(&conn->id_addr_update_work); + + l2cap_unregister_all_users(conn); + + /* Force the connection to be immediately dropped */ + hcon->disc_timeout = 0; + + mutex_lock(&conn->chan_lock); + + /* Kill channels */ + list_for_each_entry_safe(chan, l, &conn->chan_l, list) { + l2cap_chan_hold(chan); + l2cap_chan_lock(chan); + + l2cap_chan_del(chan, err); + + chan->ops->close(chan); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + } + + mutex_unlock(&conn->chan_lock); + + hci_chan_del(conn->hchan); + + if (conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_SENT) + cancel_delayed_work_sync(&conn->info_timer); + + hcon->l2cap_data = NULL; + conn->hchan = NULL; + l2cap_conn_put(conn); +} + +static void l2cap_conn_free(struct kref *ref) +{ + struct l2cap_conn *conn = container_of(ref, struct l2cap_conn, ref); + + hci_conn_put(conn->hcon); + kfree(conn); +} + +struct l2cap_conn *l2cap_conn_get(struct l2cap_conn *conn) +{ + kref_get(&conn->ref); + return conn; +} +EXPORT_SYMBOL(l2cap_conn_get); + +void l2cap_conn_put(struct l2cap_conn *conn) +{ + kref_put(&conn->ref, l2cap_conn_free); +} +EXPORT_SYMBOL(l2cap_conn_put); + +/* ---- Socket interface ---- */ + +/* Find socket with psm and source / destination bdaddr. + * Returns closest match. + */ +static struct l2cap_chan *l2cap_global_chan_by_psm(int state, __le16 psm, + bdaddr_t *src, + bdaddr_t *dst, + u8 link_type) +{ + struct l2cap_chan *c, *tmp, *c1 = NULL; + + read_lock(&chan_list_lock); + + list_for_each_entry_safe(c, tmp, &chan_list, global_l) { + if (state && c->state != state) + continue; + + if (link_type == ACL_LINK && c->src_type != BDADDR_BREDR) + continue; + + if (link_type == LE_LINK && c->src_type == BDADDR_BREDR) + continue; + + if (c->chan_type != L2CAP_CHAN_FIXED && c->psm == psm) { + int src_match, dst_match; + int src_any, dst_any; + + /* Exact match. 
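+			 * A channel matching both source and destination is
+			 * returned immediately; otherwise the loop remembers a
+			 * wildcard (BDADDR_ANY) candidate and falls back to it
+			 * once the whole list has been walked.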
*/ + src_match = !bacmp(&c->src, src); + dst_match = !bacmp(&c->dst, dst); + if (src_match && dst_match) { + if (!l2cap_chan_hold_unless_zero(c)) + continue; + + read_unlock(&chan_list_lock); + return c; + } + + /* Closest match */ + src_any = !bacmp(&c->src, BDADDR_ANY); + dst_any = !bacmp(&c->dst, BDADDR_ANY); + if ((src_match && dst_any) || (src_any && dst_match) || + (src_any && dst_any)) + c1 = c; + } + } + + if (c1) + c1 = l2cap_chan_hold_unless_zero(c1); + + read_unlock(&chan_list_lock); + + return c1; +} + +static void l2cap_monitor_timeout(struct work_struct *work) +{ + struct l2cap_chan *chan = container_of(work, struct l2cap_chan, + monitor_timer.work); + + BT_DBG("chan %p", chan); + + l2cap_chan_lock(chan); + + if (!chan->conn) { + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + return; + } + + l2cap_tx(chan, NULL, NULL, L2CAP_EV_MONITOR_TO); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +static void l2cap_retrans_timeout(struct work_struct *work) +{ + struct l2cap_chan *chan = container_of(work, struct l2cap_chan, + retrans_timer.work); + + BT_DBG("chan %p", chan); + + l2cap_chan_lock(chan); + + if (!chan->conn) { + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + return; + } + + l2cap_tx(chan, NULL, NULL, L2CAP_EV_RETRANS_TO); + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +static void l2cap_streaming_send(struct l2cap_chan *chan, + struct sk_buff_head *skbs) +{ + struct sk_buff *skb; + struct l2cap_ctrl *control; + + BT_DBG("chan %p, skbs %p", chan, skbs); + + if (__chan_is_moving(chan)) + return; + + skb_queue_splice_tail_init(skbs, &chan->tx_q); + + while (!skb_queue_empty(&chan->tx_q)) { + + skb = skb_dequeue(&chan->tx_q); + + bt_cb(skb)->l2cap.retries = 1; + control = &bt_cb(skb)->l2cap; + + control->reqseq = 0; + control->txseq = chan->next_tx_seq; + + __pack_control(chan, control, skb); + + if (chan->fcs == L2CAP_FCS_CRC16) { + u16 fcs = crc16(0, (u8 *) skb->data, skb->len); + put_unaligned_le16(fcs, skb_put(skb, L2CAP_FCS_SIZE)); + } + + l2cap_do_send(chan, skb); + + BT_DBG("Sent txseq %u", control->txseq); + + chan->next_tx_seq = __next_seq(chan, chan->next_tx_seq); + chan->frames_sent++; + } +} + +static int l2cap_ertm_send(struct l2cap_chan *chan) +{ + struct sk_buff *skb, *tx_skb; + struct l2cap_ctrl *control; + int sent = 0; + + BT_DBG("chan %p", chan); + + if (chan->state != BT_CONNECTED) + return -ENOTCONN; + + if (test_bit(CONN_REMOTE_BUSY, &chan->conn_state)) + return 0; + + if (__chan_is_moving(chan)) + return 0; + + while (chan->tx_send_head && + chan->unacked_frames < chan->remote_tx_win && + chan->tx_state == L2CAP_TX_STATE_XMIT) { + + skb = chan->tx_send_head; + + bt_cb(skb)->l2cap.retries = 1; + control = &bt_cb(skb)->l2cap; + + if (test_and_clear_bit(CONN_SEND_FBIT, &chan->conn_state)) + control->final = 1; + + control->reqseq = chan->buffer_seq; + chan->last_acked_seq = chan->buffer_seq; + control->txseq = chan->next_tx_seq; + + __pack_control(chan, control, skb); + + if (chan->fcs == L2CAP_FCS_CRC16) { + u16 fcs = crc16(0, (u8 *) skb->data, skb->len); + put_unaligned_le16(fcs, skb_put(skb, L2CAP_FCS_SIZE)); + } + + /* Clone after data has been modified. Data is assumed to be + read-only (for locking purposes) on cloned sk_buffs. 
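+		 * Because clones are treated as read-only, the retransmission
+		 * path in l2cap_ertm_resend() below switches to skb_copy()
+		 * whenever it has to rewrite a frame that is still cloned.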
+ */ + tx_skb = skb_clone(skb, GFP_KERNEL); + + if (!tx_skb) + break; + + __set_retrans_timer(chan); + + chan->next_tx_seq = __next_seq(chan, chan->next_tx_seq); + chan->unacked_frames++; + chan->frames_sent++; + sent++; + + if (skb_queue_is_last(&chan->tx_q, skb)) + chan->tx_send_head = NULL; + else + chan->tx_send_head = skb_queue_next(&chan->tx_q, skb); + + l2cap_do_send(chan, tx_skb); + BT_DBG("Sent txseq %u", control->txseq); + } + + BT_DBG("Sent %d, %u unacked, %u in ERTM queue", sent, + chan->unacked_frames, skb_queue_len(&chan->tx_q)); + + return sent; +} + +static void l2cap_ertm_resend(struct l2cap_chan *chan) +{ + struct l2cap_ctrl control; + struct sk_buff *skb; + struct sk_buff *tx_skb; + u16 seq; + + BT_DBG("chan %p", chan); + + if (test_bit(CONN_REMOTE_BUSY, &chan->conn_state)) + return; + + if (__chan_is_moving(chan)) + return; + + while (chan->retrans_list.head != L2CAP_SEQ_LIST_CLEAR) { + seq = l2cap_seq_list_pop(&chan->retrans_list); + + skb = l2cap_ertm_seq_in_queue(&chan->tx_q, seq); + if (!skb) { + BT_DBG("Error: Can't retransmit seq %d, frame missing", + seq); + continue; + } + + bt_cb(skb)->l2cap.retries++; + control = bt_cb(skb)->l2cap; + + if (chan->max_tx != 0 && + bt_cb(skb)->l2cap.retries > chan->max_tx) { + BT_DBG("Retry limit exceeded (%d)", chan->max_tx); + l2cap_send_disconn_req(chan, ECONNRESET); + l2cap_seq_list_clear(&chan->retrans_list); + break; + } + + control.reqseq = chan->buffer_seq; + if (test_and_clear_bit(CONN_SEND_FBIT, &chan->conn_state)) + control.final = 1; + else + control.final = 0; + + if (skb_cloned(skb)) { + /* Cloned sk_buffs are read-only, so we need a + * writeable copy + */ + tx_skb = skb_copy(skb, GFP_KERNEL); + } else { + tx_skb = skb_clone(skb, GFP_KERNEL); + } + + if (!tx_skb) { + l2cap_seq_list_clear(&chan->retrans_list); + break; + } + + /* Update skb contents */ + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) { + put_unaligned_le32(__pack_extended_control(&control), + tx_skb->data + L2CAP_HDR_SIZE); + } else { + put_unaligned_le16(__pack_enhanced_control(&control), + tx_skb->data + L2CAP_HDR_SIZE); + } + + /* Update FCS */ + if (chan->fcs == L2CAP_FCS_CRC16) { + u16 fcs = crc16(0, (u8 *) tx_skb->data, + tx_skb->len - L2CAP_FCS_SIZE); + put_unaligned_le16(fcs, skb_tail_pointer(tx_skb) - + L2CAP_FCS_SIZE); + } + + l2cap_do_send(chan, tx_skb); + + BT_DBG("Resent txseq %d", control.txseq); + + chan->last_acked_seq = chan->buffer_seq; + } +} + +static void l2cap_retransmit(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + BT_DBG("chan %p, control %p", chan, control); + + l2cap_seq_list_append(&chan->retrans_list, control->reqseq); + l2cap_ertm_resend(chan); +} + +static void l2cap_retransmit_all(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + struct sk_buff *skb; + + BT_DBG("chan %p, control %p", chan, control); + + if (control->poll) + set_bit(CONN_SEND_FBIT, &chan->conn_state); + + l2cap_seq_list_clear(&chan->retrans_list); + + if (test_bit(CONN_REMOTE_BUSY, &chan->conn_state)) + return; + + if (chan->unacked_frames) { + skb_queue_walk(&chan->tx_q, skb) { + if (bt_cb(skb)->l2cap.txseq == control->reqseq || + skb == chan->tx_send_head) + break; + } + + skb_queue_walk_from(&chan->tx_q, skb) { + if (skb == chan->tx_send_head) + break; + + l2cap_seq_list_append(&chan->retrans_list, + bt_cb(skb)->l2cap.txseq); + } + + l2cap_ertm_resend(chan); + } +} + +static void l2cap_send_ack(struct l2cap_chan *chan) +{ + struct l2cap_ctrl control; + u16 frames_to_ack = __seq_offset(chan, chan->buffer_seq, + 
chan->last_acked_seq); + int threshold; + + BT_DBG("chan %p last_acked_seq %d buffer_seq %d", + chan, chan->last_acked_seq, chan->buffer_seq); + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state) && + chan->rx_state == L2CAP_RX_STATE_RECV) { + __clear_ack_timer(chan); + control.super = L2CAP_SUPER_RNR; + control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &control); + } else { + if (!test_bit(CONN_REMOTE_BUSY, &chan->conn_state)) { + l2cap_ertm_send(chan); + /* If any i-frames were sent, they included an ack */ + if (chan->buffer_seq == chan->last_acked_seq) + frames_to_ack = 0; + } + + /* Ack now if the window is 3/4ths full. + * Calculate without mul or div + */ + threshold = chan->ack_win; + threshold += threshold << 1; + threshold >>= 2; + + BT_DBG("frames_to_ack %u, threshold %d", frames_to_ack, + threshold); + + if (frames_to_ack >= threshold) { + __clear_ack_timer(chan); + control.super = L2CAP_SUPER_RR; + control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &control); + frames_to_ack = 0; + } + + if (frames_to_ack) + __set_ack_timer(chan); + } +} + +static inline int l2cap_skbuff_fromiovec(struct l2cap_chan *chan, + struct msghdr *msg, int len, + int count, struct sk_buff *skb) +{ + struct l2cap_conn *conn = chan->conn; + struct sk_buff **frag; + int sent = 0; + + if (!copy_from_iter_full(skb_put(skb, count), count, &msg->msg_iter)) + return -EFAULT; + + sent += count; + len -= count; + + /* Continuation fragments (no L2CAP header) */ + frag = &skb_shinfo(skb)->frag_list; + while (len) { + struct sk_buff *tmp; + + count = min_t(unsigned int, conn->mtu, len); + + tmp = chan->ops->alloc_skb(chan, 0, count, + msg->msg_flags & MSG_DONTWAIT); + if (IS_ERR(tmp)) + return PTR_ERR(tmp); + + *frag = tmp; + + if (!copy_from_iter_full(skb_put(*frag, count), count, + &msg->msg_iter)) + return -EFAULT; + + sent += count; + len -= count; + + skb->len += (*frag)->len; + skb->data_len += (*frag)->len; + + frag = &(*frag)->next; + } + + return sent; +} + +static struct sk_buff *l2cap_create_connless_pdu(struct l2cap_chan *chan, + struct msghdr *msg, size_t len) +{ + struct l2cap_conn *conn = chan->conn; + struct sk_buff *skb; + int err, count, hlen = L2CAP_HDR_SIZE + L2CAP_PSMLEN_SIZE; + struct l2cap_hdr *lh; + + BT_DBG("chan %p psm 0x%2.2x len %zu", chan, + __le16_to_cpu(chan->psm), len); + + count = min_t(unsigned int, (conn->mtu - hlen), len); + + skb = chan->ops->alloc_skb(chan, hlen, count, + msg->msg_flags & MSG_DONTWAIT); + if (IS_ERR(skb)) + return skb; + + /* Create L2CAP header */ + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->cid = cpu_to_le16(chan->dcid); + lh->len = cpu_to_le16(len + L2CAP_PSMLEN_SIZE); + put_unaligned(chan->psm, (__le16 *) skb_put(skb, L2CAP_PSMLEN_SIZE)); + + err = l2cap_skbuff_fromiovec(chan, msg, len, count, skb); + if (unlikely(err < 0)) { + kfree_skb(skb); + return ERR_PTR(err); + } + return skb; +} + +static struct sk_buff *l2cap_create_basic_pdu(struct l2cap_chan *chan, + struct msghdr *msg, size_t len) +{ + struct l2cap_conn *conn = chan->conn; + struct sk_buff *skb; + int err, count; + struct l2cap_hdr *lh; + + BT_DBG("chan %p len %zu", chan, len); + + count = min_t(unsigned int, (conn->mtu - L2CAP_HDR_SIZE), len); + + skb = chan->ops->alloc_skb(chan, L2CAP_HDR_SIZE, count, + msg->msg_flags & MSG_DONTWAIT); + if (IS_ERR(skb)) + return skb; + + /* Create L2CAP header */ + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->cid = cpu_to_le16(chan->dcid); + lh->len = cpu_to_le16(len); + + err = 
l2cap_skbuff_fromiovec(chan, msg, len, count, skb); + if (unlikely(err < 0)) { + kfree_skb(skb); + return ERR_PTR(err); + } + return skb; +} + +static struct sk_buff *l2cap_create_iframe_pdu(struct l2cap_chan *chan, + struct msghdr *msg, size_t len, + u16 sdulen) +{ + struct l2cap_conn *conn = chan->conn; + struct sk_buff *skb; + int err, count, hlen; + struct l2cap_hdr *lh; + + BT_DBG("chan %p len %zu", chan, len); + + if (!conn) + return ERR_PTR(-ENOTCONN); + + hlen = __ertm_hdr_size(chan); + + if (sdulen) + hlen += L2CAP_SDULEN_SIZE; + + if (chan->fcs == L2CAP_FCS_CRC16) + hlen += L2CAP_FCS_SIZE; + + count = min_t(unsigned int, (conn->mtu - hlen), len); + + skb = chan->ops->alloc_skb(chan, hlen, count, + msg->msg_flags & MSG_DONTWAIT); + if (IS_ERR(skb)) + return skb; + + /* Create L2CAP header */ + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->cid = cpu_to_le16(chan->dcid); + lh->len = cpu_to_le16(len + (hlen - L2CAP_HDR_SIZE)); + + /* Control header is populated later */ + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + put_unaligned_le32(0, skb_put(skb, L2CAP_EXT_CTRL_SIZE)); + else + put_unaligned_le16(0, skb_put(skb, L2CAP_ENH_CTRL_SIZE)); + + if (sdulen) + put_unaligned_le16(sdulen, skb_put(skb, L2CAP_SDULEN_SIZE)); + + err = l2cap_skbuff_fromiovec(chan, msg, len, count, skb); + if (unlikely(err < 0)) { + kfree_skb(skb); + return ERR_PTR(err); + } + + bt_cb(skb)->l2cap.fcs = chan->fcs; + bt_cb(skb)->l2cap.retries = 0; + return skb; +} + +static int l2cap_segment_sdu(struct l2cap_chan *chan, + struct sk_buff_head *seg_queue, + struct msghdr *msg, size_t len) +{ + struct sk_buff *skb; + u16 sdu_len; + size_t pdu_len; + u8 sar; + + BT_DBG("chan %p, msg %p, len %zu", chan, msg, len); + + /* It is critical that ERTM PDUs fit in a single HCI fragment, + * so fragmented skbs are not used. The HCI layer's handling + * of fragmented skbs is not compatible with ERTM's queueing. + */ + + /* PDU size is derived from the HCI MTU */ + pdu_len = chan->conn->mtu; + + /* Constrain PDU size for BR/EDR connections */ + if (!chan->hs_hcon) + pdu_len = min_t(size_t, pdu_len, L2CAP_BREDR_MAX_PAYLOAD); + + /* Adjust for largest possible L2CAP overhead. 
*/ + if (chan->fcs) + pdu_len -= L2CAP_FCS_SIZE; + + pdu_len -= __ertm_hdr_size(chan); + + /* Remote device may have requested smaller PDUs */ + pdu_len = min_t(size_t, pdu_len, chan->remote_mps); + + if (len <= pdu_len) { + sar = L2CAP_SAR_UNSEGMENTED; + sdu_len = 0; + pdu_len = len; + } else { + sar = L2CAP_SAR_START; + sdu_len = len; + } + + while (len > 0) { + skb = l2cap_create_iframe_pdu(chan, msg, pdu_len, sdu_len); + + if (IS_ERR(skb)) { + __skb_queue_purge(seg_queue); + return PTR_ERR(skb); + } + + bt_cb(skb)->l2cap.sar = sar; + __skb_queue_tail(seg_queue, skb); + + len -= pdu_len; + if (sdu_len) + sdu_len = 0; + + if (len <= pdu_len) { + sar = L2CAP_SAR_END; + pdu_len = len; + } else { + sar = L2CAP_SAR_CONTINUE; + } + } + + return 0; +} + +static struct sk_buff *l2cap_create_le_flowctl_pdu(struct l2cap_chan *chan, + struct msghdr *msg, + size_t len, u16 sdulen) +{ + struct l2cap_conn *conn = chan->conn; + struct sk_buff *skb; + int err, count, hlen; + struct l2cap_hdr *lh; + + BT_DBG("chan %p len %zu", chan, len); + + if (!conn) + return ERR_PTR(-ENOTCONN); + + hlen = L2CAP_HDR_SIZE; + + if (sdulen) + hlen += L2CAP_SDULEN_SIZE; + + count = min_t(unsigned int, (conn->mtu - hlen), len); + + skb = chan->ops->alloc_skb(chan, hlen, count, + msg->msg_flags & MSG_DONTWAIT); + if (IS_ERR(skb)) + return skb; + + /* Create L2CAP header */ + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->cid = cpu_to_le16(chan->dcid); + lh->len = cpu_to_le16(len + (hlen - L2CAP_HDR_SIZE)); + + if (sdulen) + put_unaligned_le16(sdulen, skb_put(skb, L2CAP_SDULEN_SIZE)); + + err = l2cap_skbuff_fromiovec(chan, msg, len, count, skb); + if (unlikely(err < 0)) { + kfree_skb(skb); + return ERR_PTR(err); + } + + return skb; +} + +static int l2cap_segment_le_sdu(struct l2cap_chan *chan, + struct sk_buff_head *seg_queue, + struct msghdr *msg, size_t len) +{ + struct sk_buff *skb; + size_t pdu_len; + u16 sdu_len; + + BT_DBG("chan %p, msg %p, len %zu", chan, msg, len); + + sdu_len = len; + pdu_len = chan->remote_mps - L2CAP_SDULEN_SIZE; + + while (len > 0) { + if (len <= pdu_len) + pdu_len = len; + + skb = l2cap_create_le_flowctl_pdu(chan, msg, pdu_len, sdu_len); + if (IS_ERR(skb)) { + __skb_queue_purge(seg_queue); + return PTR_ERR(skb); + } + + __skb_queue_tail(seg_queue, skb); + + len -= pdu_len; + + if (sdu_len) { + sdu_len = 0; + pdu_len += L2CAP_SDULEN_SIZE; + } + } + + return 0; +} + +static void l2cap_le_flowctl_send(struct l2cap_chan *chan) +{ + int sent = 0; + + BT_DBG("chan %p", chan); + + while (chan->tx_credits && !skb_queue_empty(&chan->tx_q)) { + l2cap_do_send(chan, skb_dequeue(&chan->tx_q)); + chan->tx_credits--; + sent++; + } + + BT_DBG("Sent %d credits %u queued %u", sent, chan->tx_credits, + skb_queue_len(&chan->tx_q)); +} + +int l2cap_chan_send(struct l2cap_chan *chan, struct msghdr *msg, size_t len) +{ + struct sk_buff *skb; + int err; + struct sk_buff_head seg_queue; + + if (!chan->conn) + return -ENOTCONN; + + /* Connectionless channel */ + if (chan->chan_type == L2CAP_CHAN_CONN_LESS) { + skb = l2cap_create_connless_pdu(chan, msg, len); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + l2cap_do_send(chan, skb); + return len; + } + + switch (chan->mode) { + case L2CAP_MODE_LE_FLOWCTL: + case L2CAP_MODE_EXT_FLOWCTL: + /* Check outgoing MTU */ + if (len > chan->omtu) + return -EMSGSIZE; + + __skb_queue_head_init(&seg_queue); + + err = l2cap_segment_le_sdu(chan, &seg_queue, msg, len); + + if (chan->state != BT_CONNECTED) { + __skb_queue_purge(&seg_queue); + err = -ENOTCONN; + } + + if (err) + return err; + + 
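+ /* Queue the segmented SDU and transmit as many PDUs as the remote's
+  * remaining credits allow; if the credits run out, ask the channel
+  * owner to stop feeding us data until more credits arrive.
+  */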
skb_queue_splice_tail_init(&seg_queue, &chan->tx_q); + + l2cap_le_flowctl_send(chan); + + if (!chan->tx_credits) + chan->ops->suspend(chan); + + err = len; + + break; + + case L2CAP_MODE_BASIC: + /* Check outgoing MTU */ + if (len > chan->omtu) + return -EMSGSIZE; + + /* Create a basic PDU */ + skb = l2cap_create_basic_pdu(chan, msg, len); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + l2cap_do_send(chan, skb); + err = len; + break; + + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + /* Check outgoing MTU */ + if (len > chan->omtu) { + err = -EMSGSIZE; + break; + } + + __skb_queue_head_init(&seg_queue); + + /* Do segmentation before calling in to the state machine, + * since it's possible to block while waiting for memory + * allocation. + */ + err = l2cap_segment_sdu(chan, &seg_queue, msg, len); + + if (err) + break; + + if (chan->mode == L2CAP_MODE_ERTM) + l2cap_tx(chan, NULL, &seg_queue, L2CAP_EV_DATA_REQUEST); + else + l2cap_streaming_send(chan, &seg_queue); + + err = len; + + /* If the skbs were not queued for sending, they'll still be in + * seg_queue and need to be purged. + */ + __skb_queue_purge(&seg_queue); + break; + + default: + BT_DBG("bad state %1.1x", chan->mode); + err = -EBADFD; + } + + return err; +} +EXPORT_SYMBOL_GPL(l2cap_chan_send); + +static void l2cap_send_srej(struct l2cap_chan *chan, u16 txseq) +{ + struct l2cap_ctrl control; + u16 seq; + + BT_DBG("chan %p, txseq %u", chan, txseq); + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + control.super = L2CAP_SUPER_SREJ; + + for (seq = chan->expected_tx_seq; seq != txseq; + seq = __next_seq(chan, seq)) { + if (!l2cap_ertm_seq_in_queue(&chan->srej_q, seq)) { + control.reqseq = seq; + l2cap_send_sframe(chan, &control); + l2cap_seq_list_append(&chan->srej_list, seq); + } + } + + chan->expected_tx_seq = __next_seq(chan, txseq); +} + +static void l2cap_send_srej_tail(struct l2cap_chan *chan) +{ + struct l2cap_ctrl control; + + BT_DBG("chan %p", chan); + + if (chan->srej_list.tail == L2CAP_SEQ_LIST_CLEAR) + return; + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + control.super = L2CAP_SUPER_SREJ; + control.reqseq = chan->srej_list.tail; + l2cap_send_sframe(chan, &control); +} + +static void l2cap_send_srej_list(struct l2cap_chan *chan, u16 txseq) +{ + struct l2cap_ctrl control; + u16 initial_head; + u16 seq; + + BT_DBG("chan %p, txseq %u", chan, txseq); + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + control.super = L2CAP_SUPER_SREJ; + + /* Capture initial list head to allow only one pass through the list. 
*/ + initial_head = chan->srej_list.head; + + do { + seq = l2cap_seq_list_pop(&chan->srej_list); + if (seq == txseq || seq == L2CAP_SEQ_LIST_CLEAR) + break; + + control.reqseq = seq; + l2cap_send_sframe(chan, &control); + l2cap_seq_list_append(&chan->srej_list, seq); + } while (chan->srej_list.head != initial_head); +} + +static void l2cap_process_reqseq(struct l2cap_chan *chan, u16 reqseq) +{ + struct sk_buff *acked_skb; + u16 ackseq; + + BT_DBG("chan %p, reqseq %u", chan, reqseq); + + if (chan->unacked_frames == 0 || reqseq == chan->expected_ack_seq) + return; + + BT_DBG("expected_ack_seq %u, unacked_frames %u", + chan->expected_ack_seq, chan->unacked_frames); + + for (ackseq = chan->expected_ack_seq; ackseq != reqseq; + ackseq = __next_seq(chan, ackseq)) { + + acked_skb = l2cap_ertm_seq_in_queue(&chan->tx_q, ackseq); + if (acked_skb) { + skb_unlink(acked_skb, &chan->tx_q); + kfree_skb(acked_skb); + chan->unacked_frames--; + } + } + + chan->expected_ack_seq = reqseq; + + if (chan->unacked_frames == 0) + __clear_retrans_timer(chan); + + BT_DBG("unacked_frames %u", chan->unacked_frames); +} + +static void l2cap_abort_rx_srej_sent(struct l2cap_chan *chan) +{ + BT_DBG("chan %p", chan); + + chan->expected_tx_seq = chan->buffer_seq; + l2cap_seq_list_clear(&chan->srej_list); + skb_queue_purge(&chan->srej_q); + chan->rx_state = L2CAP_RX_STATE_RECV; +} + +static void l2cap_tx_state_xmit(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff_head *skbs, u8 event) +{ + BT_DBG("chan %p, control %p, skbs %p, event %d", chan, control, skbs, + event); + + switch (event) { + case L2CAP_EV_DATA_REQUEST: + if (chan->tx_send_head == NULL) + chan->tx_send_head = skb_peek(skbs); + + skb_queue_splice_tail_init(skbs, &chan->tx_q); + l2cap_ertm_send(chan); + break; + case L2CAP_EV_LOCAL_BUSY_DETECTED: + BT_DBG("Enter LOCAL_BUSY"); + set_bit(CONN_LOCAL_BUSY, &chan->conn_state); + + if (chan->rx_state == L2CAP_RX_STATE_SREJ_SENT) { + /* The SREJ_SENT state must be aborted if we are to + * enter the LOCAL_BUSY state. 
+ */ + l2cap_abort_rx_srej_sent(chan); + } + + l2cap_send_ack(chan); + + break; + case L2CAP_EV_LOCAL_BUSY_CLEAR: + BT_DBG("Exit LOCAL_BUSY"); + clear_bit(CONN_LOCAL_BUSY, &chan->conn_state); + + if (test_bit(CONN_RNR_SENT, &chan->conn_state)) { + struct l2cap_ctrl local_control; + + memset(&local_control, 0, sizeof(local_control)); + local_control.sframe = 1; + local_control.super = L2CAP_SUPER_RR; + local_control.poll = 1; + local_control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &local_control); + + chan->retry_count = 1; + __set_monitor_timer(chan); + chan->tx_state = L2CAP_TX_STATE_WAIT_F; + } + break; + case L2CAP_EV_RECV_REQSEQ_AND_FBIT: + l2cap_process_reqseq(chan, control->reqseq); + break; + case L2CAP_EV_EXPLICIT_POLL: + l2cap_send_rr_or_rnr(chan, 1); + chan->retry_count = 1; + __set_monitor_timer(chan); + __clear_ack_timer(chan); + chan->tx_state = L2CAP_TX_STATE_WAIT_F; + break; + case L2CAP_EV_RETRANS_TO: + l2cap_send_rr_or_rnr(chan, 1); + chan->retry_count = 1; + __set_monitor_timer(chan); + chan->tx_state = L2CAP_TX_STATE_WAIT_F; + break; + case L2CAP_EV_RECV_FBIT: + /* Nothing to process */ + break; + default: + break; + } +} + +static void l2cap_tx_state_wait_f(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff_head *skbs, u8 event) +{ + BT_DBG("chan %p, control %p, skbs %p, event %d", chan, control, skbs, + event); + + switch (event) { + case L2CAP_EV_DATA_REQUEST: + if (chan->tx_send_head == NULL) + chan->tx_send_head = skb_peek(skbs); + /* Queue data, but don't send. */ + skb_queue_splice_tail_init(skbs, &chan->tx_q); + break; + case L2CAP_EV_LOCAL_BUSY_DETECTED: + BT_DBG("Enter LOCAL_BUSY"); + set_bit(CONN_LOCAL_BUSY, &chan->conn_state); + + if (chan->rx_state == L2CAP_RX_STATE_SREJ_SENT) { + /* The SREJ_SENT state must be aborted if we are to + * enter the LOCAL_BUSY state. 
+ */ + l2cap_abort_rx_srej_sent(chan); + } + + l2cap_send_ack(chan); + + break; + case L2CAP_EV_LOCAL_BUSY_CLEAR: + BT_DBG("Exit LOCAL_BUSY"); + clear_bit(CONN_LOCAL_BUSY, &chan->conn_state); + + if (test_bit(CONN_RNR_SENT, &chan->conn_state)) { + struct l2cap_ctrl local_control; + memset(&local_control, 0, sizeof(local_control)); + local_control.sframe = 1; + local_control.super = L2CAP_SUPER_RR; + local_control.poll = 1; + local_control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &local_control); + + chan->retry_count = 1; + __set_monitor_timer(chan); + chan->tx_state = L2CAP_TX_STATE_WAIT_F; + } + break; + case L2CAP_EV_RECV_REQSEQ_AND_FBIT: + l2cap_process_reqseq(chan, control->reqseq); + fallthrough; + + case L2CAP_EV_RECV_FBIT: + if (control && control->final) { + __clear_monitor_timer(chan); + if (chan->unacked_frames > 0) + __set_retrans_timer(chan); + chan->retry_count = 0; + chan->tx_state = L2CAP_TX_STATE_XMIT; + BT_DBG("recv fbit tx_state 0x2.2%x", chan->tx_state); + } + break; + case L2CAP_EV_EXPLICIT_POLL: + /* Ignore */ + break; + case L2CAP_EV_MONITOR_TO: + if (chan->max_tx == 0 || chan->retry_count < chan->max_tx) { + l2cap_send_rr_or_rnr(chan, 1); + __set_monitor_timer(chan); + chan->retry_count++; + } else { + l2cap_send_disconn_req(chan, ECONNABORTED); + } + break; + default: + break; + } +} + +static void l2cap_tx(struct l2cap_chan *chan, struct l2cap_ctrl *control, + struct sk_buff_head *skbs, u8 event) +{ + BT_DBG("chan %p, control %p, skbs %p, event %d, state %d", + chan, control, skbs, event, chan->tx_state); + + switch (chan->tx_state) { + case L2CAP_TX_STATE_XMIT: + l2cap_tx_state_xmit(chan, control, skbs, event); + break; + case L2CAP_TX_STATE_WAIT_F: + l2cap_tx_state_wait_f(chan, control, skbs, event); + break; + default: + /* Ignore event */ + break; + } +} + +static void l2cap_pass_to_tx(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + BT_DBG("chan %p, control %p", chan, control); + l2cap_tx(chan, control, NULL, L2CAP_EV_RECV_REQSEQ_AND_FBIT); +} + +static void l2cap_pass_to_tx_fbit(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + BT_DBG("chan %p, control %p", chan, control); + l2cap_tx(chan, control, NULL, L2CAP_EV_RECV_FBIT); +} + +/* Copy frame to all raw sockets on that connection */ +static void l2cap_raw_recv(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct sk_buff *nskb; + struct l2cap_chan *chan; + + BT_DBG("conn %p", conn); + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + if (chan->chan_type != L2CAP_CHAN_RAW) + continue; + + /* Don't send frame to the channel it came from */ + if (bt_cb(skb)->l2cap.chan == chan) + continue; + + nskb = skb_clone(skb, GFP_KERNEL); + if (!nskb) + continue; + if (chan->ops->recv(chan, nskb)) + kfree_skb(nskb); + } + + mutex_unlock(&conn->chan_lock); +} + +/* ---- L2CAP signalling commands ---- */ +static struct sk_buff *l2cap_build_cmd(struct l2cap_conn *conn, u8 code, + u8 ident, u16 dlen, void *data) +{ + struct sk_buff *skb, **frag; + struct l2cap_cmd_hdr *cmd; + struct l2cap_hdr *lh; + int len, count; + + BT_DBG("conn %p, code 0x%2.2x, ident 0x%2.2x, len %u", + conn, code, ident, dlen); + + if (conn->mtu < L2CAP_HDR_SIZE + L2CAP_CMD_HDR_SIZE) + return NULL; + + len = L2CAP_HDR_SIZE + L2CAP_CMD_HDR_SIZE + dlen; + count = min_t(unsigned int, conn->mtu, len); + + skb = bt_skb_alloc(count, GFP_KERNEL); + if (!skb) + return NULL; + + lh = skb_put(skb, L2CAP_HDR_SIZE); + lh->len = cpu_to_le16(L2CAP_CMD_HDR_SIZE + dlen); + + if 
(conn->hcon->type == LE_LINK) + lh->cid = cpu_to_le16(L2CAP_CID_LE_SIGNALING); + else + lh->cid = cpu_to_le16(L2CAP_CID_SIGNALING); + + cmd = skb_put(skb, L2CAP_CMD_HDR_SIZE); + cmd->code = code; + cmd->ident = ident; + cmd->len = cpu_to_le16(dlen); + + if (dlen) { + count -= L2CAP_HDR_SIZE + L2CAP_CMD_HDR_SIZE; + skb_put_data(skb, data, count); + data += count; + } + + len -= skb->len; + + /* Continuation fragments (no L2CAP header) */ + frag = &skb_shinfo(skb)->frag_list; + while (len) { + count = min_t(unsigned int, conn->mtu, len); + + *frag = bt_skb_alloc(count, GFP_KERNEL); + if (!*frag) + goto fail; + + skb_put_data(*frag, data, count); + + len -= count; + data += count; + + frag = &(*frag)->next; + } + + return skb; + +fail: + kfree_skb(skb); + return NULL; +} + +static inline int l2cap_get_conf_opt(void **ptr, int *type, int *olen, + unsigned long *val) +{ + struct l2cap_conf_opt *opt = *ptr; + int len; + + len = L2CAP_CONF_OPT_SIZE + opt->len; + *ptr += len; + + *type = opt->type; + *olen = opt->len; + + switch (opt->len) { + case 1: + *val = *((u8 *) opt->val); + break; + + case 2: + *val = get_unaligned_le16(opt->val); + break; + + case 4: + *val = get_unaligned_le32(opt->val); + break; + + default: + *val = (unsigned long) opt->val; + break; + } + + BT_DBG("type 0x%2.2x len %u val 0x%lx", *type, opt->len, *val); + return len; +} + +static void l2cap_add_conf_opt(void **ptr, u8 type, u8 len, unsigned long val, size_t size) +{ + struct l2cap_conf_opt *opt = *ptr; + + BT_DBG("type 0x%2.2x len %u val 0x%lx", type, len, val); + + if (size < L2CAP_CONF_OPT_SIZE + len) + return; + + opt->type = type; + opt->len = len; + + switch (len) { + case 1: + *((u8 *) opt->val) = val; + break; + + case 2: + put_unaligned_le16(val, opt->val); + break; + + case 4: + put_unaligned_le32(val, opt->val); + break; + + default: + memcpy(opt->val, (void *) val, len); + break; + } + + *ptr += L2CAP_CONF_OPT_SIZE + len; +} + +static void l2cap_add_opt_efs(void **ptr, struct l2cap_chan *chan, size_t size) +{ + struct l2cap_conf_efs efs; + + switch (chan->mode) { + case L2CAP_MODE_ERTM: + efs.id = chan->local_id; + efs.stype = chan->local_stype; + efs.msdu = cpu_to_le16(chan->local_msdu); + efs.sdu_itime = cpu_to_le32(chan->local_sdu_itime); + efs.acc_lat = cpu_to_le32(L2CAP_DEFAULT_ACC_LAT); + efs.flush_to = cpu_to_le32(L2CAP_EFS_DEFAULT_FLUSH_TO); + break; + + case L2CAP_MODE_STREAMING: + efs.id = 1; + efs.stype = L2CAP_SERV_BESTEFFORT; + efs.msdu = cpu_to_le16(chan->local_msdu); + efs.sdu_itime = cpu_to_le32(chan->local_sdu_itime); + efs.acc_lat = 0; + efs.flush_to = 0; + break; + + default: + return; + } + + l2cap_add_conf_opt(ptr, L2CAP_CONF_EFS, sizeof(efs), + (unsigned long) &efs, size); +} + +static void l2cap_ack_timeout(struct work_struct *work) +{ + struct l2cap_chan *chan = container_of(work, struct l2cap_chan, + ack_timer.work); + u16 frames_to_ack; + + BT_DBG("chan %p", chan); + + l2cap_chan_lock(chan); + + frames_to_ack = __seq_offset(chan, chan->buffer_seq, + chan->last_acked_seq); + + if (frames_to_ack) + l2cap_send_rr_or_rnr(chan, 0); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +int l2cap_ertm_init(struct l2cap_chan *chan) +{ + int err; + + chan->next_tx_seq = 0; + chan->expected_tx_seq = 0; + chan->expected_ack_seq = 0; + chan->unacked_frames = 0; + chan->buffer_seq = 0; + chan->frames_sent = 0; + chan->last_acked_seq = 0; + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + + skb_queue_head_init(&chan->tx_q); + + chan->local_amp_id = AMP_ID_BREDR; + 
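+ /* A freshly initialised channel starts out on the BR/EDR controller
+  * with no AMP channel move in progress.
+  */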
chan->move_id = AMP_ID_BREDR; + chan->move_state = L2CAP_MOVE_STABLE; + chan->move_role = L2CAP_MOVE_ROLE_NONE; + + if (chan->mode != L2CAP_MODE_ERTM) + return 0; + + chan->rx_state = L2CAP_RX_STATE_RECV; + chan->tx_state = L2CAP_TX_STATE_XMIT; + + skb_queue_head_init(&chan->srej_q); + + err = l2cap_seq_list_init(&chan->srej_list, chan->tx_win); + if (err < 0) + return err; + + err = l2cap_seq_list_init(&chan->retrans_list, chan->remote_tx_win); + if (err < 0) + l2cap_seq_list_free(&chan->srej_list); + + return err; +} + +static inline __u8 l2cap_select_mode(__u8 mode, __u16 remote_feat_mask) +{ + switch (mode) { + case L2CAP_MODE_STREAMING: + case L2CAP_MODE_ERTM: + if (l2cap_mode_supported(mode, remote_feat_mask)) + return mode; + fallthrough; + default: + return L2CAP_MODE_BASIC; + } +} + +static inline bool __l2cap_ews_supported(struct l2cap_conn *conn) +{ + return ((conn->local_fixed_chan & L2CAP_FC_A2MP) && + (conn->feat_mask & L2CAP_FEAT_EXT_WINDOW)); +} + +static inline bool __l2cap_efs_supported(struct l2cap_conn *conn) +{ + return ((conn->local_fixed_chan & L2CAP_FC_A2MP) && + (conn->feat_mask & L2CAP_FEAT_EXT_FLOW)); +} + +static void __l2cap_set_ertm_timeouts(struct l2cap_chan *chan, + struct l2cap_conf_rfc *rfc) +{ + if (chan->local_amp_id != AMP_ID_BREDR && chan->hs_hcon) { + u64 ertm_to = chan->hs_hcon->hdev->amp_be_flush_to; + + /* Class 1 devices have must have ERTM timeouts + * exceeding the Link Supervision Timeout. The + * default Link Supervision Timeout for AMP + * controllers is 10 seconds. + * + * Class 1 devices use 0xffffffff for their + * best-effort flush timeout, so the clamping logic + * will result in a timeout that meets the above + * requirement. ERTM timeouts are 16-bit values, so + * the maximum timeout is 65.535 seconds. + */ + + /* Convert timeout to milliseconds and round */ + ertm_to = DIV_ROUND_UP_ULL(ertm_to, 1000); + + /* This is the recommended formula for class 2 devices + * that start ERTM timers when packets are sent to the + * controller. 
+ */ + ertm_to = 3 * ertm_to + 500; + + if (ertm_to > 0xffff) + ertm_to = 0xffff; + + rfc->retrans_timeout = cpu_to_le16((u16) ertm_to); + rfc->monitor_timeout = rfc->retrans_timeout; + } else { + rfc->retrans_timeout = cpu_to_le16(L2CAP_DEFAULT_RETRANS_TO); + rfc->monitor_timeout = cpu_to_le16(L2CAP_DEFAULT_MONITOR_TO); + } +} + +static inline void l2cap_txwin_setup(struct l2cap_chan *chan) +{ + if (chan->tx_win > L2CAP_DEFAULT_TX_WINDOW && + __l2cap_ews_supported(chan->conn)) { + /* use extended control field */ + set_bit(FLAG_EXT_CTRL, &chan->flags); + chan->tx_win_max = L2CAP_DEFAULT_EXT_WINDOW; + } else { + chan->tx_win = min_t(u16, chan->tx_win, + L2CAP_DEFAULT_TX_WINDOW); + chan->tx_win_max = L2CAP_DEFAULT_TX_WINDOW; + } + chan->ack_win = chan->tx_win; +} + +static void l2cap_mtu_auto(struct l2cap_chan *chan) +{ + struct hci_conn *conn = chan->conn->hcon; + + chan->imtu = L2CAP_DEFAULT_MIN_MTU; + + /* The 2-DH1 packet has between 2 and 56 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_2DH1)) + chan->imtu = 54; + + /* The 3-DH1 packet has between 2 and 85 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_3DH1)) + chan->imtu = 83; + + /* The 2-DH3 packet has between 2 and 369 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_2DH3)) + chan->imtu = 367; + + /* The 3-DH3 packet has between 2 and 554 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_3DH3)) + chan->imtu = 552; + + /* The 2-DH5 packet has between 2 and 681 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_2DH5)) + chan->imtu = 679; + + /* The 3-DH5 packet has between 2 and 1023 information bytes + * (including the 2-byte payload header) + */ + if (!(conn->pkt_type & HCI_3DH5)) + chan->imtu = 1021; +} + +static int l2cap_build_conf_req(struct l2cap_chan *chan, void *data, size_t data_size) +{ + struct l2cap_conf_req *req = data; + struct l2cap_conf_rfc rfc = { .mode = chan->mode }; + void *ptr = req->data; + void *endptr = data + data_size; + u16 size; + + BT_DBG("chan %p", chan); + + if (chan->num_conf_req || chan->num_conf_rsp) + goto done; + + switch (chan->mode) { + case L2CAP_MODE_STREAMING: + case L2CAP_MODE_ERTM: + if (test_bit(CONF_STATE2_DEVICE, &chan->conf_state)) + break; + + if (__l2cap_efs_supported(chan->conn)) + set_bit(FLAG_EFS_ENABLE, &chan->flags); + + fallthrough; + default: + chan->mode = l2cap_select_mode(rfc.mode, chan->conn->feat_mask); + break; + } + +done: + if (chan->imtu != L2CAP_DEFAULT_MTU) { + if (!chan->imtu) + l2cap_mtu_auto(chan); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_MTU, 2, chan->imtu, + endptr - ptr); + } + + switch (chan->mode) { + case L2CAP_MODE_BASIC: + if (disable_ertm) + break; + + if (!(chan->conn->feat_mask & L2CAP_FEAT_ERTM) && + !(chan->conn->feat_mask & L2CAP_FEAT_STREAMING)) + break; + + rfc.mode = L2CAP_MODE_BASIC; + rfc.txwin_size = 0; + rfc.max_transmit = 0; + rfc.retrans_timeout = 0; + rfc.monitor_timeout = 0; + rfc.max_pdu_size = 0; + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + break; + + case L2CAP_MODE_ERTM: + rfc.mode = L2CAP_MODE_ERTM; + rfc.max_transmit = chan->max_tx; + + __l2cap_set_ertm_timeouts(chan, &rfc); + + size = min_t(u16, L2CAP_DEFAULT_MAX_PDU_SIZE, chan->conn->mtu - + L2CAP_EXT_HDR_SIZE - L2CAP_SDULEN_SIZE - + L2CAP_FCS_SIZE); + rfc.max_pdu_size = cpu_to_le16(size); + + 
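+ /* Set up the transmission window before advertising it: the RFC
+  * option carries at most the default window size, while a larger
+  * extended window is signalled through a separate EWS option below.
+  */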
l2cap_txwin_setup(chan); + + rfc.txwin_size = min_t(u16, chan->tx_win, + L2CAP_DEFAULT_TX_WINDOW); + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + + if (test_bit(FLAG_EFS_ENABLE, &chan->flags)) + l2cap_add_opt_efs(&ptr, chan, endptr - ptr); + + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EWS, 2, + chan->tx_win, endptr - ptr); + + if (chan->conn->feat_mask & L2CAP_FEAT_FCS) + if (chan->fcs == L2CAP_FCS_NONE || + test_bit(CONF_RECV_NO_FCS, &chan->conf_state)) { + chan->fcs = L2CAP_FCS_NONE; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_FCS, 1, + chan->fcs, endptr - ptr); + } + break; + + case L2CAP_MODE_STREAMING: + l2cap_txwin_setup(chan); + rfc.mode = L2CAP_MODE_STREAMING; + rfc.txwin_size = 0; + rfc.max_transmit = 0; + rfc.retrans_timeout = 0; + rfc.monitor_timeout = 0; + + size = min_t(u16, L2CAP_DEFAULT_MAX_PDU_SIZE, chan->conn->mtu - + L2CAP_EXT_HDR_SIZE - L2CAP_SDULEN_SIZE - + L2CAP_FCS_SIZE); + rfc.max_pdu_size = cpu_to_le16(size); + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + + if (test_bit(FLAG_EFS_ENABLE, &chan->flags)) + l2cap_add_opt_efs(&ptr, chan, endptr - ptr); + + if (chan->conn->feat_mask & L2CAP_FEAT_FCS) + if (chan->fcs == L2CAP_FCS_NONE || + test_bit(CONF_RECV_NO_FCS, &chan->conf_state)) { + chan->fcs = L2CAP_FCS_NONE; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_FCS, 1, + chan->fcs, endptr - ptr); + } + break; + } + + req->dcid = cpu_to_le16(chan->dcid); + req->flags = cpu_to_le16(0); + + return ptr - data; +} + +static int l2cap_parse_conf_req(struct l2cap_chan *chan, void *data, size_t data_size) +{ + struct l2cap_conf_rsp *rsp = data; + void *ptr = rsp->data; + void *endptr = data + data_size; + void *req = chan->conf_req; + int len = chan->conf_len; + int type, hint, olen; + unsigned long val; + struct l2cap_conf_rfc rfc = { .mode = L2CAP_MODE_BASIC }; + struct l2cap_conf_efs efs; + u8 remote_efs = 0; + u16 mtu = L2CAP_DEFAULT_MTU; + u16 result = L2CAP_CONF_SUCCESS; + u16 size; + + BT_DBG("chan %p", chan); + + while (len >= L2CAP_CONF_OPT_SIZE) { + len -= l2cap_get_conf_opt(&req, &type, &olen, &val); + if (len < 0) + break; + + hint = type & L2CAP_CONF_HINT; + type &= L2CAP_CONF_MASK; + + switch (type) { + case L2CAP_CONF_MTU: + if (olen != 2) + break; + mtu = val; + break; + + case L2CAP_CONF_FLUSH_TO: + if (olen != 2) + break; + chan->flush_to = val; + break; + + case L2CAP_CONF_QOS: + break; + + case L2CAP_CONF_RFC: + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *) val, olen); + break; + + case L2CAP_CONF_FCS: + if (olen != 1) + break; + if (val == L2CAP_FCS_NONE) + set_bit(CONF_RECV_NO_FCS, &chan->conf_state); + break; + + case L2CAP_CONF_EFS: + if (olen != sizeof(efs)) + break; + remote_efs = 1; + memcpy(&efs, (void *) val, olen); + break; + + case L2CAP_CONF_EWS: + if (olen != 2) + break; + if (!(chan->conn->local_fixed_chan & L2CAP_FC_A2MP)) + return -ECONNREFUSED; + set_bit(FLAG_EXT_CTRL, &chan->flags); + set_bit(CONF_EWS_RECV, &chan->conf_state); + chan->tx_win_max = L2CAP_DEFAULT_EXT_WINDOW; + chan->remote_tx_win = val; + break; + + default: + if (hint) + break; + result = L2CAP_CONF_UNKNOWN; + l2cap_add_conf_opt(&ptr, (u8)type, sizeof(u8), type, endptr - ptr); + break; + } + } + + if (chan->num_conf_rsp || chan->num_conf_req > 1) + goto done; + + switch (chan->mode) { + case L2CAP_MODE_STREAMING: + case L2CAP_MODE_ERTM: + if (!test_bit(CONF_STATE2_DEVICE, &chan->conf_state)) { + chan->mode = l2cap_select_mode(rfc.mode, + 
chan->conn->feat_mask); + break; + } + + if (remote_efs) { + if (__l2cap_efs_supported(chan->conn)) + set_bit(FLAG_EFS_ENABLE, &chan->flags); + else + return -ECONNREFUSED; + } + + if (chan->mode != rfc.mode) + return -ECONNREFUSED; + + break; + } + +done: + if (chan->mode != rfc.mode) { + result = L2CAP_CONF_UNACCEPT; + rfc.mode = chan->mode; + + if (chan->num_conf_rsp == 1) + return -ECONNREFUSED; + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + } + + if (result == L2CAP_CONF_SUCCESS) { + /* Configure output options and let the other side know + * which ones we don't like. */ + + if (mtu < L2CAP_DEFAULT_MIN_MTU) + result = L2CAP_CONF_UNACCEPT; + else { + chan->omtu = mtu; + set_bit(CONF_MTU_DONE, &chan->conf_state); + } + l2cap_add_conf_opt(&ptr, L2CAP_CONF_MTU, 2, chan->omtu, endptr - ptr); + + if (remote_efs) { + if (chan->local_stype != L2CAP_SERV_NOTRAFIC && + efs.stype != L2CAP_SERV_NOTRAFIC && + efs.stype != chan->local_stype) { + + result = L2CAP_CONF_UNACCEPT; + + if (chan->num_conf_req >= 1) + return -ECONNREFUSED; + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EFS, + sizeof(efs), + (unsigned long) &efs, endptr - ptr); + } else { + /* Send PENDING Conf Rsp */ + result = L2CAP_CONF_PENDING; + set_bit(CONF_LOC_CONF_PEND, &chan->conf_state); + } + } + + switch (rfc.mode) { + case L2CAP_MODE_BASIC: + chan->fcs = L2CAP_FCS_NONE; + set_bit(CONF_MODE_DONE, &chan->conf_state); + break; + + case L2CAP_MODE_ERTM: + if (!test_bit(CONF_EWS_RECV, &chan->conf_state)) + chan->remote_tx_win = rfc.txwin_size; + else + rfc.txwin_size = L2CAP_DEFAULT_TX_WINDOW; + + chan->remote_max_tx = rfc.max_transmit; + + size = min_t(u16, le16_to_cpu(rfc.max_pdu_size), + chan->conn->mtu - L2CAP_EXT_HDR_SIZE - + L2CAP_SDULEN_SIZE - L2CAP_FCS_SIZE); + rfc.max_pdu_size = cpu_to_le16(size); + chan->remote_mps = size; + + __l2cap_set_ertm_timeouts(chan, &rfc); + + set_bit(CONF_MODE_DONE, &chan->conf_state); + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, + sizeof(rfc), (unsigned long) &rfc, endptr - ptr); + + if (remote_efs && + test_bit(FLAG_EFS_ENABLE, &chan->flags)) { + chan->remote_id = efs.id; + chan->remote_stype = efs.stype; + chan->remote_msdu = le16_to_cpu(efs.msdu); + chan->remote_flush_to = + le32_to_cpu(efs.flush_to); + chan->remote_acc_lat = + le32_to_cpu(efs.acc_lat); + chan->remote_sdu_itime = + le32_to_cpu(efs.sdu_itime); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EFS, + sizeof(efs), + (unsigned long) &efs, endptr - ptr); + } + break; + + case L2CAP_MODE_STREAMING: + size = min_t(u16, le16_to_cpu(rfc.max_pdu_size), + chan->conn->mtu - L2CAP_EXT_HDR_SIZE - + L2CAP_SDULEN_SIZE - L2CAP_FCS_SIZE); + rfc.max_pdu_size = cpu_to_le16(size); + chan->remote_mps = size; + + set_bit(CONF_MODE_DONE, &chan->conf_state); + + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + + break; + + default: + result = L2CAP_CONF_UNACCEPT; + + memset(&rfc, 0, sizeof(rfc)); + rfc.mode = chan->mode; + } + + if (result == L2CAP_CONF_SUCCESS) + set_bit(CONF_OUTPUT_DONE, &chan->conf_state); + } + rsp->scid = cpu_to_le16(chan->dcid); + rsp->result = cpu_to_le16(result); + rsp->flags = cpu_to_le16(0); + + return ptr - data; +} + +static int l2cap_parse_conf_rsp(struct l2cap_chan *chan, void *rsp, int len, + void *data, size_t size, u16 *result) +{ + struct l2cap_conf_req *req = data; + void *ptr = req->data; + void *endptr = data + size; + int type, olen; + unsigned long val; + struct l2cap_conf_rfc rfc = { .mode = L2CAP_MODE_BASIC }; + struct 
l2cap_conf_efs efs; + + BT_DBG("chan %p, rsp %p, len %d, req %p", chan, rsp, len, data); + + while (len >= L2CAP_CONF_OPT_SIZE) { + len -= l2cap_get_conf_opt(&rsp, &type, &olen, &val); + if (len < 0) + break; + + switch (type) { + case L2CAP_CONF_MTU: + if (olen != 2) + break; + if (val < L2CAP_DEFAULT_MIN_MTU) { + *result = L2CAP_CONF_UNACCEPT; + chan->imtu = L2CAP_DEFAULT_MIN_MTU; + } else + chan->imtu = val; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_MTU, 2, chan->imtu, + endptr - ptr); + break; + + case L2CAP_CONF_FLUSH_TO: + if (olen != 2) + break; + chan->flush_to = val; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_FLUSH_TO, 2, + chan->flush_to, endptr - ptr); + break; + + case L2CAP_CONF_RFC: + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *)val, olen); + if (test_bit(CONF_STATE2_DEVICE, &chan->conf_state) && + rfc.mode != chan->mode) + return -ECONNREFUSED; + chan->fcs = 0; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_RFC, sizeof(rfc), + (unsigned long) &rfc, endptr - ptr); + break; + + case L2CAP_CONF_EWS: + if (olen != 2) + break; + chan->ack_win = min_t(u16, val, chan->ack_win); + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EWS, 2, + chan->tx_win, endptr - ptr); + break; + + case L2CAP_CONF_EFS: + if (olen != sizeof(efs)) + break; + memcpy(&efs, (void *)val, olen); + if (chan->local_stype != L2CAP_SERV_NOTRAFIC && + efs.stype != L2CAP_SERV_NOTRAFIC && + efs.stype != chan->local_stype) + return -ECONNREFUSED; + l2cap_add_conf_opt(&ptr, L2CAP_CONF_EFS, sizeof(efs), + (unsigned long) &efs, endptr - ptr); + break; + + case L2CAP_CONF_FCS: + if (olen != 1) + break; + if (*result == L2CAP_CONF_PENDING) + if (val == L2CAP_FCS_NONE) + set_bit(CONF_RECV_NO_FCS, + &chan->conf_state); + break; + } + } + + if (chan->mode == L2CAP_MODE_BASIC && chan->mode != rfc.mode) + return -ECONNREFUSED; + + chan->mode = rfc.mode; + + if (*result == L2CAP_CONF_SUCCESS || *result == L2CAP_CONF_PENDING) { + switch (rfc.mode) { + case L2CAP_MODE_ERTM: + chan->retrans_timeout = le16_to_cpu(rfc.retrans_timeout); + chan->monitor_timeout = le16_to_cpu(rfc.monitor_timeout); + chan->mps = le16_to_cpu(rfc.max_pdu_size); + if (!test_bit(FLAG_EXT_CTRL, &chan->flags)) + chan->ack_win = min_t(u16, chan->ack_win, + rfc.txwin_size); + + if (test_bit(FLAG_EFS_ENABLE, &chan->flags)) { + chan->local_msdu = le16_to_cpu(efs.msdu); + chan->local_sdu_itime = + le32_to_cpu(efs.sdu_itime); + chan->local_acc_lat = le32_to_cpu(efs.acc_lat); + chan->local_flush_to = + le32_to_cpu(efs.flush_to); + } + break; + + case L2CAP_MODE_STREAMING: + chan->mps = le16_to_cpu(rfc.max_pdu_size); + } + } + + req->dcid = cpu_to_le16(chan->dcid); + req->flags = cpu_to_le16(0); + + return ptr - data; +} + +static int l2cap_build_conf_rsp(struct l2cap_chan *chan, void *data, + u16 result, u16 flags) +{ + struct l2cap_conf_rsp *rsp = data; + void *ptr = rsp->data; + + BT_DBG("chan %p", chan); + + rsp->scid = cpu_to_le16(chan->dcid); + rsp->result = cpu_to_le16(result); + rsp->flags = cpu_to_le16(flags); + + return ptr - data; +} + +void __l2cap_le_connect_rsp_defer(struct l2cap_chan *chan) +{ + struct l2cap_le_conn_rsp rsp; + struct l2cap_conn *conn = chan->conn; + + BT_DBG("chan %p", chan); + + rsp.dcid = cpu_to_le16(chan->scid); + rsp.mtu = cpu_to_le16(chan->imtu); + rsp.mps = cpu_to_le16(chan->mps); + rsp.credits = cpu_to_le16(chan->rx_credits); + rsp.result = cpu_to_le16(L2CAP_CR_LE_SUCCESS); + + l2cap_send_cmd(conn, chan->ident, L2CAP_LE_CONN_RSP, sizeof(rsp), + &rsp); +} + +static void l2cap_ecred_list_defer(struct l2cap_chan *chan, void *data) +{ + int *result = 
data; + + if (*result || test_bit(FLAG_ECRED_CONN_REQ_SENT, &chan->flags)) + return; + + switch (chan->state) { + case BT_CONNECT2: + /* If channel still pending accept add to result */ + (*result)++; + return; + case BT_CONNECTED: + return; + default: + /* If not connected or pending accept it has been refused */ + *result = -ECONNREFUSED; + return; + } +} + +struct l2cap_ecred_rsp_data { + struct { + struct l2cap_ecred_conn_rsp rsp; + __le16 scid[L2CAP_ECRED_MAX_CID]; + } __packed pdu; + int count; +}; + +static void l2cap_ecred_rsp_defer(struct l2cap_chan *chan, void *data) +{ + struct l2cap_ecred_rsp_data *rsp = data; + + if (test_bit(FLAG_ECRED_CONN_REQ_SENT, &chan->flags)) + return; + + /* Reset ident so only one response is sent */ + chan->ident = 0; + + /* Include all channels pending with the same ident */ + if (!rsp->pdu.rsp.result) + rsp->pdu.rsp.dcid[rsp->count++] = cpu_to_le16(chan->scid); + else + l2cap_chan_del(chan, ECONNRESET); +} + +void __l2cap_ecred_conn_rsp_defer(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_ecred_rsp_data data; + u16 id = chan->ident; + int result = 0; + + if (!id) + return; + + BT_DBG("chan %p id %d", chan, id); + + memset(&data, 0, sizeof(data)); + + data.pdu.rsp.mtu = cpu_to_le16(chan->imtu); + data.pdu.rsp.mps = cpu_to_le16(chan->mps); + data.pdu.rsp.credits = cpu_to_le16(chan->rx_credits); + data.pdu.rsp.result = cpu_to_le16(L2CAP_CR_LE_SUCCESS); + + /* Verify that all channels are ready */ + __l2cap_chan_list_id(conn, id, l2cap_ecred_list_defer, &result); + + if (result > 0) + return; + + if (result < 0) + data.pdu.rsp.result = cpu_to_le16(L2CAP_CR_LE_AUTHORIZATION); + + /* Build response */ + __l2cap_chan_list_id(conn, id, l2cap_ecred_rsp_defer, &data); + + l2cap_send_cmd(conn, id, L2CAP_ECRED_CONN_RSP, + sizeof(data.pdu.rsp) + (data.count * sizeof(__le16)), + &data.pdu); +} + +void __l2cap_connect_rsp_defer(struct l2cap_chan *chan) +{ + struct l2cap_conn_rsp rsp; + struct l2cap_conn *conn = chan->conn; + u8 buf[128]; + u8 rsp_code; + + rsp.scid = cpu_to_le16(chan->dcid); + rsp.dcid = cpu_to_le16(chan->scid); + rsp.result = cpu_to_le16(L2CAP_CR_SUCCESS); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + + if (chan->hs_hcon) + rsp_code = L2CAP_CREATE_CHAN_RSP; + else + rsp_code = L2CAP_CONN_RSP; + + BT_DBG("chan %p rsp_code %u", chan, rsp_code); + + l2cap_send_cmd(conn, chan->ident, rsp_code, sizeof(rsp), &rsp); + + if (test_and_set_bit(CONF_REQ_SENT, &chan->conf_state)) + return; + + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), buf); + chan->num_conf_req++; +} + +static void l2cap_conf_rfc_get(struct l2cap_chan *chan, void *rsp, int len) +{ + int type, olen; + unsigned long val; + /* Use sane default values in case a misbehaving remote device + * did not send an RFC or extended window size option. 
+ */ + u16 txwin_ext = chan->ack_win; + struct l2cap_conf_rfc rfc = { + .mode = chan->mode, + .retrans_timeout = cpu_to_le16(L2CAP_DEFAULT_RETRANS_TO), + .monitor_timeout = cpu_to_le16(L2CAP_DEFAULT_MONITOR_TO), + .max_pdu_size = cpu_to_le16(chan->imtu), + .txwin_size = min_t(u16, chan->ack_win, L2CAP_DEFAULT_TX_WINDOW), + }; + + BT_DBG("chan %p, rsp %p, len %d", chan, rsp, len); + + if ((chan->mode != L2CAP_MODE_ERTM) && (chan->mode != L2CAP_MODE_STREAMING)) + return; + + while (len >= L2CAP_CONF_OPT_SIZE) { + len -= l2cap_get_conf_opt(&rsp, &type, &olen, &val); + if (len < 0) + break; + + switch (type) { + case L2CAP_CONF_RFC: + if (olen != sizeof(rfc)) + break; + memcpy(&rfc, (void *)val, olen); + break; + case L2CAP_CONF_EWS: + if (olen != 2) + break; + txwin_ext = val; + break; + } + } + + switch (rfc.mode) { + case L2CAP_MODE_ERTM: + chan->retrans_timeout = le16_to_cpu(rfc.retrans_timeout); + chan->monitor_timeout = le16_to_cpu(rfc.monitor_timeout); + chan->mps = le16_to_cpu(rfc.max_pdu_size); + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + chan->ack_win = min_t(u16, chan->ack_win, txwin_ext); + else + chan->ack_win = min_t(u16, chan->ack_win, + rfc.txwin_size); + break; + case L2CAP_MODE_STREAMING: + chan->mps = le16_to_cpu(rfc.max_pdu_size); + } +} + +static inline int l2cap_command_rej(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_cmd_rej_unk *rej = (struct l2cap_cmd_rej_unk *) data; + + if (cmd_len < sizeof(*rej)) + return -EPROTO; + + if (rej->reason != L2CAP_REJ_NOT_UNDERSTOOD) + return 0; + + if ((conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_SENT) && + cmd->ident == conn->info_ident) { + cancel_delayed_work(&conn->info_timer); + + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_DONE; + conn->info_ident = 0; + + l2cap_conn_start(conn); + } + + return 0; +} + +static struct l2cap_chan *l2cap_connect(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u8 *data, u8 rsp_code, u8 amp_id) +{ + struct l2cap_conn_req *req = (struct l2cap_conn_req *) data; + struct l2cap_conn_rsp rsp; + struct l2cap_chan *chan = NULL, *pchan; + int result, status = L2CAP_CS_NO_INFO; + + u16 dcid = 0, scid = __le16_to_cpu(req->scid); + __le16 psm = req->psm; + + BT_DBG("psm 0x%2.2x scid 0x%4.4x", __le16_to_cpu(psm), scid); + + /* Check if we have socket listening on psm */ + pchan = l2cap_global_chan_by_psm(BT_LISTEN, psm, &conn->hcon->src, + &conn->hcon->dst, ACL_LINK); + if (!pchan) { + result = L2CAP_CR_BAD_PSM; + goto sendresp; + } + + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(pchan); + + /* Check if the ACL is secure enough (if not SDP) */ + if (psm != cpu_to_le16(L2CAP_PSM_SDP) && + !hci_conn_check_link_mode(conn->hcon)) { + conn->disc_reason = HCI_ERROR_AUTH_FAILURE; + result = L2CAP_CR_SEC_BLOCK; + goto response; + } + + result = L2CAP_CR_NO_MEM; + + /* Check for valid dynamic CID range (as per Erratum 3253) */ + if (scid < L2CAP_CID_DYN_START || scid > L2CAP_CID_DYN_END) { + result = L2CAP_CR_INVALID_SCID; + goto response; + } + + /* Check if we already have channel with that dcid */ + if (__l2cap_get_chan_by_dcid(conn, scid)) { + result = L2CAP_CR_SCID_IN_USE; + goto response; + } + + chan = pchan->ops->new_connection(pchan); + if (!chan) + goto response; + + /* For certain devices (ex: HID mouse), support for authentication, + * pairing and bonding is optional. For such devices, inorder to avoid + * the ACL alive for too long after L2CAP disconnection, reset the ACL + * disc_timeout back to HCI_DISCONN_TIMEOUT during L2CAP connect. 
+ */ + conn->hcon->disc_timeout = HCI_DISCONN_TIMEOUT; + + bacpy(&chan->src, &conn->hcon->src); + bacpy(&chan->dst, &conn->hcon->dst); + chan->src_type = bdaddr_src_type(conn->hcon); + chan->dst_type = bdaddr_dst_type(conn->hcon); + chan->psm = psm; + chan->dcid = scid; + chan->local_amp_id = amp_id; + + __l2cap_chan_add(conn, chan); + + dcid = chan->scid; + + __set_chan_timer(chan, chan->ops->get_sndtimeo(chan)); + + chan->ident = cmd->ident; + + if (conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_DONE) { + if (l2cap_chan_check_security(chan, false)) { + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) { + l2cap_state_change(chan, BT_CONNECT2); + result = L2CAP_CR_PEND; + status = L2CAP_CS_AUTHOR_PEND; + chan->ops->defer(chan); + } else { + /* Force pending result for AMP controllers. + * The connection will succeed after the + * physical link is up. + */ + if (amp_id == AMP_ID_BREDR) { + l2cap_state_change(chan, BT_CONFIG); + result = L2CAP_CR_SUCCESS; + } else { + l2cap_state_change(chan, BT_CONNECT2); + result = L2CAP_CR_PEND; + } + status = L2CAP_CS_NO_INFO; + } + } else { + l2cap_state_change(chan, BT_CONNECT2); + result = L2CAP_CR_PEND; + status = L2CAP_CS_AUTHEN_PEND; + } + } else { + l2cap_state_change(chan, BT_CONNECT2); + result = L2CAP_CR_PEND; + status = L2CAP_CS_NO_INFO; + } + +response: + l2cap_chan_unlock(pchan); + mutex_unlock(&conn->chan_lock); + l2cap_chan_put(pchan); + +sendresp: + rsp.scid = cpu_to_le16(scid); + rsp.dcid = cpu_to_le16(dcid); + rsp.result = cpu_to_le16(result); + rsp.status = cpu_to_le16(status); + l2cap_send_cmd(conn, cmd->ident, rsp_code, sizeof(rsp), &rsp); + + if (result == L2CAP_CR_PEND && status == L2CAP_CS_NO_INFO) { + struct l2cap_info_req info; + info.type = cpu_to_le16(L2CAP_IT_FEAT_MASK); + + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_SENT; + conn->info_ident = l2cap_get_ident(conn); + + schedule_delayed_work(&conn->info_timer, L2CAP_INFO_TIMEOUT); + + l2cap_send_cmd(conn, conn->info_ident, L2CAP_INFO_REQ, + sizeof(info), &info); + } + + if (chan && !test_bit(CONF_REQ_SENT, &chan->conf_state) && + result == L2CAP_CR_SUCCESS) { + u8 buf[128]; + set_bit(CONF_REQ_SENT, &chan->conf_state); + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), buf); + chan->num_conf_req++; + } + + return chan; +} + +static int l2cap_connect_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, u8 *data) +{ + struct hci_dev *hdev = conn->hcon->hdev; + struct hci_conn *hcon = conn->hcon; + + if (cmd_len < sizeof(struct l2cap_conn_req)) + return -EPROTO; + + hci_dev_lock(hdev); + if (hci_dev_test_flag(hdev, HCI_MGMT) && + !test_and_set_bit(HCI_CONN_MGMT_CONNECTED, &hcon->flags)) + mgmt_device_connected(hdev, hcon, NULL, 0); + hci_dev_unlock(hdev); + + l2cap_connect(conn, cmd, data, L2CAP_CONN_RSP, 0); + return 0; +} + +static int l2cap_connect_create_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_conn_rsp *rsp = (struct l2cap_conn_rsp *) data; + u16 scid, dcid, result, status; + struct l2cap_chan *chan; + u8 req[128]; + int err; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + scid = __le16_to_cpu(rsp->scid); + dcid = __le16_to_cpu(rsp->dcid); + result = __le16_to_cpu(rsp->result); + status = __le16_to_cpu(rsp->status); + + if (result == L2CAP_CR_SUCCESS && (dcid < L2CAP_CID_DYN_START || + dcid > L2CAP_CID_DYN_END)) + return -EPROTO; + + BT_DBG("dcid 0x%4.4x scid 0x%4.4x result 0x%2.2x status 0x%2.2x", + dcid, scid, result, status); + + 
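+ /* Find the channel this response refers to: by our source CID if the
+  * remote echoed one back, otherwise by the request identifier.
+  */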
mutex_lock(&conn->chan_lock); + + if (scid) { + chan = __l2cap_get_chan_by_scid(conn, scid); + if (!chan) { + err = -EBADSLT; + goto unlock; + } + } else { + chan = __l2cap_get_chan_by_ident(conn, cmd->ident); + if (!chan) { + err = -EBADSLT; + goto unlock; + } + } + + chan = l2cap_chan_hold_unless_zero(chan); + if (!chan) { + err = -EBADSLT; + goto unlock; + } + + err = 0; + + l2cap_chan_lock(chan); + + switch (result) { + case L2CAP_CR_SUCCESS: + if (__l2cap_get_chan_by_dcid(conn, dcid)) { + err = -EBADSLT; + break; + } + + l2cap_state_change(chan, BT_CONFIG); + chan->ident = 0; + chan->dcid = dcid; + clear_bit(CONF_CONNECT_PEND, &chan->conf_state); + + if (test_and_set_bit(CONF_REQ_SENT, &chan->conf_state)) + break; + + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, req, sizeof(req)), req); + chan->num_conf_req++; + break; + + case L2CAP_CR_PEND: + set_bit(CONF_CONNECT_PEND, &chan->conf_state); + break; + + default: + l2cap_chan_del(chan, ECONNREFUSED); + break; + } + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + +unlock: + mutex_unlock(&conn->chan_lock); + + return err; +} + +static inline void set_default_fcs(struct l2cap_chan *chan) +{ + /* FCS is enabled only in ERTM or streaming mode, if one or both + * sides request it. + */ + if (chan->mode != L2CAP_MODE_ERTM && chan->mode != L2CAP_MODE_STREAMING) + chan->fcs = L2CAP_FCS_NONE; + else if (!test_bit(CONF_RECV_NO_FCS, &chan->conf_state)) + chan->fcs = L2CAP_FCS_CRC16; +} + +static void l2cap_send_efs_conf_rsp(struct l2cap_chan *chan, void *data, + u8 ident, u16 flags) +{ + struct l2cap_conn *conn = chan->conn; + + BT_DBG("conn %p chan %p ident %d flags 0x%4.4x", conn, chan, ident, + flags); + + clear_bit(CONF_LOC_CONF_PEND, &chan->conf_state); + set_bit(CONF_OUTPUT_DONE, &chan->conf_state); + + l2cap_send_cmd(conn, ident, L2CAP_CONF_RSP, + l2cap_build_conf_rsp(chan, data, + L2CAP_CONF_SUCCESS, flags), data); +} + +static void cmd_reject_invalid_cid(struct l2cap_conn *conn, u8 ident, + u16 scid, u16 dcid) +{ + struct l2cap_cmd_rej_cid rej; + + rej.reason = cpu_to_le16(L2CAP_REJ_INVALID_CID); + rej.scid = __cpu_to_le16(scid); + rej.dcid = __cpu_to_le16(dcid); + + l2cap_send_cmd(conn, ident, L2CAP_COMMAND_REJ, sizeof(rej), &rej); +} + +static inline int l2cap_config_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_conf_req *req = (struct l2cap_conf_req *) data; + u16 dcid, flags; + u8 rsp[64]; + struct l2cap_chan *chan; + int len, err = 0; + + if (cmd_len < sizeof(*req)) + return -EPROTO; + + dcid = __le16_to_cpu(req->dcid); + flags = __le16_to_cpu(req->flags); + + BT_DBG("dcid 0x%4.4x flags 0x%2.2x", dcid, flags); + + chan = l2cap_get_chan_by_scid(conn, dcid); + if (!chan) { + cmd_reject_invalid_cid(conn, cmd->ident, dcid, 0); + return 0; + } + + if (chan->state != BT_CONFIG && chan->state != BT_CONNECT2 && + chan->state != BT_CONNECTED) { + cmd_reject_invalid_cid(conn, cmd->ident, chan->scid, + chan->dcid); + goto unlock; + } + + /* Reject if config buffer is too small. */ + len = cmd_len - sizeof(*req); + if (chan->conf_len + len > sizeof(chan->conf_req)) { + l2cap_send_cmd(conn, cmd->ident, L2CAP_CONF_RSP, + l2cap_build_conf_rsp(chan, rsp, + L2CAP_CONF_REJECT, flags), rsp); + goto unlock; + } + + /* Store config. */ + memcpy(chan->conf_req + chan->conf_len, req->data, len); + chan->conf_len += len; + + if (flags & L2CAP_CONF_FLAG_CONTINUATION) { + /* Incomplete config. Send empty response. 
*/ + l2cap_send_cmd(conn, cmd->ident, L2CAP_CONF_RSP, + l2cap_build_conf_rsp(chan, rsp, + L2CAP_CONF_SUCCESS, flags), rsp); + goto unlock; + } + + /* Complete config. */ + len = l2cap_parse_conf_req(chan, rsp, sizeof(rsp)); + if (len < 0) { + l2cap_send_disconn_req(chan, ECONNRESET); + goto unlock; + } + + chan->ident = cmd->ident; + l2cap_send_cmd(conn, cmd->ident, L2CAP_CONF_RSP, len, rsp); + if (chan->num_conf_rsp < L2CAP_CONF_MAX_CONF_RSP) + chan->num_conf_rsp++; + + /* Reset config buffer. */ + chan->conf_len = 0; + + if (!test_bit(CONF_OUTPUT_DONE, &chan->conf_state)) + goto unlock; + + if (test_bit(CONF_INPUT_DONE, &chan->conf_state)) { + set_default_fcs(chan); + + if (chan->mode == L2CAP_MODE_ERTM || + chan->mode == L2CAP_MODE_STREAMING) + err = l2cap_ertm_init(chan); + + if (err < 0) + l2cap_send_disconn_req(chan, -err); + else + l2cap_chan_ready(chan); + + goto unlock; + } + + if (!test_and_set_bit(CONF_REQ_SENT, &chan->conf_state)) { + u8 buf[64]; + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), buf); + chan->num_conf_req++; + } + + /* Got Conf Rsp PENDING from remote side and assume we sent + Conf Rsp PENDING in the code above */ + if (test_bit(CONF_REM_CONF_PEND, &chan->conf_state) && + test_bit(CONF_LOC_CONF_PEND, &chan->conf_state)) { + + /* check compatibility */ + + /* Send rsp for BR/EDR channel */ + if (!chan->hs_hcon) + l2cap_send_efs_conf_rsp(chan, rsp, cmd->ident, flags); + else + chan->ident = cmd->ident; + } + +unlock: + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + return err; +} + +static inline int l2cap_config_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_conf_rsp *rsp = (struct l2cap_conf_rsp *)data; + u16 scid, flags, result; + struct l2cap_chan *chan; + int len = cmd_len - sizeof(*rsp); + int err = 0; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + scid = __le16_to_cpu(rsp->scid); + flags = __le16_to_cpu(rsp->flags); + result = __le16_to_cpu(rsp->result); + + BT_DBG("scid 0x%4.4x flags 0x%2.2x result 0x%2.2x len %d", scid, flags, + result, len); + + chan = l2cap_get_chan_by_scid(conn, scid); + if (!chan) + return 0; + + switch (result) { + case L2CAP_CONF_SUCCESS: + l2cap_conf_rfc_get(chan, rsp->data, len); + clear_bit(CONF_REM_CONF_PEND, &chan->conf_state); + break; + + case L2CAP_CONF_PENDING: + set_bit(CONF_REM_CONF_PEND, &chan->conf_state); + + if (test_bit(CONF_LOC_CONF_PEND, &chan->conf_state)) { + char buf[64]; + + len = l2cap_parse_conf_rsp(chan, rsp->data, len, + buf, sizeof(buf), &result); + if (len < 0) { + l2cap_send_disconn_req(chan, ECONNRESET); + goto done; + } + + if (!chan->hs_hcon) { + l2cap_send_efs_conf_rsp(chan, buf, cmd->ident, + 0); + } else { + if (l2cap_check_efs(chan)) { + amp_create_logical_link(chan); + chan->ident = cmd->ident; + } + } + } + goto done; + + case L2CAP_CONF_UNKNOWN: + case L2CAP_CONF_UNACCEPT: + if (chan->num_conf_rsp <= L2CAP_CONF_MAX_CONF_RSP) { + char req[64]; + + if (len > sizeof(req) - sizeof(struct l2cap_conf_req)) { + l2cap_send_disconn_req(chan, ECONNRESET); + goto done; + } + + /* throw out any old stored conf requests */ + result = L2CAP_CONF_SUCCESS; + len = l2cap_parse_conf_rsp(chan, rsp->data, len, + req, sizeof(req), &result); + if (len < 0) { + l2cap_send_disconn_req(chan, ECONNRESET); + goto done; + } + + l2cap_send_cmd(conn, l2cap_get_ident(conn), + L2CAP_CONF_REQ, len, req); + chan->num_conf_req++; + if (result != L2CAP_CONF_SUCCESS) + goto done; + break; + } + fallthrough; 
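+ /* Negotiation failed too many times, or the remote returned a result
+  * we cannot act on: flag the error and tear the channel down.
+  */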
+ + default: + l2cap_chan_set_err(chan, ECONNRESET); + + __set_chan_timer(chan, L2CAP_DISC_REJ_TIMEOUT); + l2cap_send_disconn_req(chan, ECONNRESET); + goto done; + } + + if (flags & L2CAP_CONF_FLAG_CONTINUATION) + goto done; + + set_bit(CONF_INPUT_DONE, &chan->conf_state); + + if (test_bit(CONF_OUTPUT_DONE, &chan->conf_state)) { + set_default_fcs(chan); + + if (chan->mode == L2CAP_MODE_ERTM || + chan->mode == L2CAP_MODE_STREAMING) + err = l2cap_ertm_init(chan); + + if (err < 0) + l2cap_send_disconn_req(chan, -err); + else + l2cap_chan_ready(chan); + } + +done: + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + return err; +} + +static inline int l2cap_disconnect_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_disconn_req *req = (struct l2cap_disconn_req *) data; + struct l2cap_disconn_rsp rsp; + u16 dcid, scid; + struct l2cap_chan *chan; + + if (cmd_len != sizeof(*req)) + return -EPROTO; + + scid = __le16_to_cpu(req->scid); + dcid = __le16_to_cpu(req->dcid); + + BT_DBG("scid 0x%4.4x dcid 0x%4.4x", scid, dcid); + + chan = l2cap_get_chan_by_scid(conn, dcid); + if (!chan) { + cmd_reject_invalid_cid(conn, cmd->ident, dcid, scid); + return 0; + } + + rsp.dcid = cpu_to_le16(chan->scid); + rsp.scid = cpu_to_le16(chan->dcid); + l2cap_send_cmd(conn, cmd->ident, L2CAP_DISCONN_RSP, sizeof(rsp), &rsp); + + chan->ops->set_shutdown(chan); + + l2cap_chan_unlock(chan); + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(chan); + l2cap_chan_del(chan, ECONNRESET); + mutex_unlock(&conn->chan_lock); + + chan->ops->close(chan); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static inline int l2cap_disconnect_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_disconn_rsp *rsp = (struct l2cap_disconn_rsp *) data; + u16 dcid, scid; + struct l2cap_chan *chan; + + if (cmd_len != sizeof(*rsp)) + return -EPROTO; + + scid = __le16_to_cpu(rsp->scid); + dcid = __le16_to_cpu(rsp->dcid); + + BT_DBG("dcid 0x%4.4x scid 0x%4.4x", dcid, scid); + + chan = l2cap_get_chan_by_scid(conn, scid); + if (!chan) { + return 0; + } + + if (chan->state != BT_DISCONN) { + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + return 0; + } + + l2cap_chan_unlock(chan); + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(chan); + l2cap_chan_del(chan, 0); + mutex_unlock(&conn->chan_lock); + + chan->ops->close(chan); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static inline int l2cap_information_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_info_req *req = (struct l2cap_info_req *) data; + u16 type; + + if (cmd_len != sizeof(*req)) + return -EPROTO; + + type = __le16_to_cpu(req->type); + + BT_DBG("type 0x%4.4x", type); + + if (type == L2CAP_IT_FEAT_MASK) { + u8 buf[8]; + u32 feat_mask = l2cap_feat_mask; + struct l2cap_info_rsp *rsp = (struct l2cap_info_rsp *) buf; + rsp->type = cpu_to_le16(L2CAP_IT_FEAT_MASK); + rsp->result = cpu_to_le16(L2CAP_IR_SUCCESS); + if (!disable_ertm) + feat_mask |= L2CAP_FEAT_ERTM | L2CAP_FEAT_STREAMING + | L2CAP_FEAT_FCS; + if (conn->local_fixed_chan & L2CAP_FC_A2MP) + feat_mask |= L2CAP_FEAT_EXT_FLOW + | L2CAP_FEAT_EXT_WINDOW; + + put_unaligned_le32(feat_mask, rsp->data); + l2cap_send_cmd(conn, cmd->ident, L2CAP_INFO_RSP, sizeof(buf), + buf); + } else if (type == L2CAP_IT_FIXED_CHAN) { + u8 buf[12]; + struct l2cap_info_rsp *rsp = (struct l2cap_info_rsp *) buf; + + rsp->type = cpu_to_le16(L2CAP_IT_FIXED_CHAN); + 
rsp->result = cpu_to_le16(L2CAP_IR_SUCCESS); + rsp->data[0] = conn->local_fixed_chan; + memset(rsp->data + 1, 0, 7); + l2cap_send_cmd(conn, cmd->ident, L2CAP_INFO_RSP, sizeof(buf), + buf); + } else { + struct l2cap_info_rsp rsp; + rsp.type = cpu_to_le16(type); + rsp.result = cpu_to_le16(L2CAP_IR_NOTSUPP); + l2cap_send_cmd(conn, cmd->ident, L2CAP_INFO_RSP, sizeof(rsp), + &rsp); + } + + return 0; +} + +static inline int l2cap_information_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_info_rsp *rsp = (struct l2cap_info_rsp *) data; + u16 type, result; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + type = __le16_to_cpu(rsp->type); + result = __le16_to_cpu(rsp->result); + + BT_DBG("type 0x%4.4x result 0x%2.2x", type, result); + + /* L2CAP Info req/rsp are unbound to channels, add extra checks */ + if (cmd->ident != conn->info_ident || + conn->info_state & L2CAP_INFO_FEAT_MASK_REQ_DONE) + return 0; + + cancel_delayed_work(&conn->info_timer); + + if (result != L2CAP_IR_SUCCESS) { + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_DONE; + conn->info_ident = 0; + + l2cap_conn_start(conn); + + return 0; + } + + switch (type) { + case L2CAP_IT_FEAT_MASK: + conn->feat_mask = get_unaligned_le32(rsp->data); + + if (conn->feat_mask & L2CAP_FEAT_FIXED_CHAN) { + struct l2cap_info_req req; + req.type = cpu_to_le16(L2CAP_IT_FIXED_CHAN); + + conn->info_ident = l2cap_get_ident(conn); + + l2cap_send_cmd(conn, conn->info_ident, + L2CAP_INFO_REQ, sizeof(req), &req); + } else { + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_DONE; + conn->info_ident = 0; + + l2cap_conn_start(conn); + } + break; + + case L2CAP_IT_FIXED_CHAN: + conn->remote_fixed_chan = rsp->data[0]; + conn->info_state |= L2CAP_INFO_FEAT_MASK_REQ_DONE; + conn->info_ident = 0; + + l2cap_conn_start(conn); + break; + } + + return 0; +} + +static int l2cap_create_channel_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, void *data) +{ + struct l2cap_create_chan_req *req = data; + struct l2cap_create_chan_rsp rsp; + struct l2cap_chan *chan; + struct hci_dev *hdev; + u16 psm, scid; + + if (cmd_len != sizeof(*req)) + return -EPROTO; + + if (!(conn->local_fixed_chan & L2CAP_FC_A2MP)) + return -EINVAL; + + psm = le16_to_cpu(req->psm); + scid = le16_to_cpu(req->scid); + + BT_DBG("psm 0x%2.2x, scid 0x%4.4x, amp_id %d", psm, scid, req->amp_id); + + /* For controller id 0 make BR/EDR connection */ + if (req->amp_id == AMP_ID_BREDR) { + l2cap_connect(conn, cmd, data, L2CAP_CREATE_CHAN_RSP, + req->amp_id); + return 0; + } + + /* Validate AMP controller id */ + hdev = hci_dev_get(req->amp_id); + if (!hdev) + goto error; + + if (hdev->dev_type != HCI_AMP || !test_bit(HCI_UP, &hdev->flags)) { + hci_dev_put(hdev); + goto error; + } + + chan = l2cap_connect(conn, cmd, data, L2CAP_CREATE_CHAN_RSP, + req->amp_id); + if (chan) { + struct amp_mgr *mgr = conn->hcon->amp_mgr; + struct hci_conn *hs_hcon; + + hs_hcon = hci_conn_hash_lookup_ba(hdev, AMP_LINK, + &conn->hcon->dst); + if (!hs_hcon) { + hci_dev_put(hdev); + cmd_reject_invalid_cid(conn, cmd->ident, chan->scid, + chan->dcid); + return 0; + } + + BT_DBG("mgr %p bredr_chan %p hs_hcon %p", mgr, chan, hs_hcon); + + mgr->bredr_chan = chan; + chan->hs_hcon = hs_hcon; + chan->fcs = L2CAP_FCS_NONE; + conn->mtu = hdev->block_mtu; + } + + hci_dev_put(hdev); + + return 0; + +error: + rsp.dcid = 0; + rsp.scid = cpu_to_le16(scid); + rsp.result = cpu_to_le16(L2CAP_CR_BAD_AMP); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + + l2cap_send_cmd(conn, 
cmd->ident, L2CAP_CREATE_CHAN_RSP, + sizeof(rsp), &rsp); + + return 0; +} + +static void l2cap_send_move_chan_req(struct l2cap_chan *chan, u8 dest_amp_id) +{ + struct l2cap_move_chan_req req; + u8 ident; + + BT_DBG("chan %p, dest_amp_id %d", chan, dest_amp_id); + + ident = l2cap_get_ident(chan->conn); + chan->ident = ident; + + req.icid = cpu_to_le16(chan->scid); + req.dest_amp_id = dest_amp_id; + + l2cap_send_cmd(chan->conn, ident, L2CAP_MOVE_CHAN_REQ, sizeof(req), + &req); + + __set_chan_timer(chan, L2CAP_MOVE_TIMEOUT); +} + +static void l2cap_send_move_chan_rsp(struct l2cap_chan *chan, u16 result) +{ + struct l2cap_move_chan_rsp rsp; + + BT_DBG("chan %p, result 0x%4.4x", chan, result); + + rsp.icid = cpu_to_le16(chan->dcid); + rsp.result = cpu_to_le16(result); + + l2cap_send_cmd(chan->conn, chan->ident, L2CAP_MOVE_CHAN_RSP, + sizeof(rsp), &rsp); +} + +static void l2cap_send_move_chan_cfm(struct l2cap_chan *chan, u16 result) +{ + struct l2cap_move_chan_cfm cfm; + + BT_DBG("chan %p, result 0x%4.4x", chan, result); + + chan->ident = l2cap_get_ident(chan->conn); + + cfm.icid = cpu_to_le16(chan->scid); + cfm.result = cpu_to_le16(result); + + l2cap_send_cmd(chan->conn, chan->ident, L2CAP_MOVE_CHAN_CFM, + sizeof(cfm), &cfm); + + __set_chan_timer(chan, L2CAP_MOVE_TIMEOUT); +} + +static void l2cap_send_move_chan_cfm_icid(struct l2cap_conn *conn, u16 icid) +{ + struct l2cap_move_chan_cfm cfm; + + BT_DBG("conn %p, icid 0x%4.4x", conn, icid); + + cfm.icid = cpu_to_le16(icid); + cfm.result = cpu_to_le16(L2CAP_MC_UNCONFIRMED); + + l2cap_send_cmd(conn, l2cap_get_ident(conn), L2CAP_MOVE_CHAN_CFM, + sizeof(cfm), &cfm); +} + +static void l2cap_send_move_chan_cfm_rsp(struct l2cap_conn *conn, u8 ident, + u16 icid) +{ + struct l2cap_move_chan_cfm_rsp rsp; + + BT_DBG("icid 0x%4.4x", icid); + + rsp.icid = cpu_to_le16(icid); + l2cap_send_cmd(conn, ident, L2CAP_MOVE_CHAN_CFM_RSP, sizeof(rsp), &rsp); +} + +static void __release_logical_link(struct l2cap_chan *chan) +{ + chan->hs_hchan = NULL; + chan->hs_hcon = NULL; + + /* Placeholder - release the logical link */ +} + +static void l2cap_logical_fail(struct l2cap_chan *chan) +{ + /* Logical link setup failed */ + if (chan->state != BT_CONNECTED) { + /* Create channel failure, disconnect */ + l2cap_send_disconn_req(chan, ECONNRESET); + return; + } + + switch (chan->move_role) { + case L2CAP_MOVE_ROLE_RESPONDER: + l2cap_move_done(chan); + l2cap_send_move_chan_rsp(chan, L2CAP_MR_NOT_SUPP); + break; + case L2CAP_MOVE_ROLE_INITIATOR: + if (chan->move_state == L2CAP_MOVE_WAIT_LOGICAL_COMP || + chan->move_state == L2CAP_MOVE_WAIT_LOGICAL_CFM) { + /* Remote has only sent pending or + * success responses, clean up + */ + l2cap_move_done(chan); + } + + /* Other amp move states imply that the move + * has already aborted + */ + l2cap_send_move_chan_cfm(chan, L2CAP_MC_UNCONFIRMED); + break; + } +} + +static void l2cap_logical_finish_create(struct l2cap_chan *chan, + struct hci_chan *hchan) +{ + struct l2cap_conf_rsp rsp; + + chan->hs_hchan = hchan; + chan->hs_hcon->l2cap_data = chan->conn; + + l2cap_send_efs_conf_rsp(chan, &rsp, chan->ident, 0); + + if (test_bit(CONF_INPUT_DONE, &chan->conf_state)) { + int err; + + set_default_fcs(chan); + + err = l2cap_ertm_init(chan); + if (err < 0) + l2cap_send_disconn_req(chan, -err); + else + l2cap_chan_ready(chan); + } +} + +static void l2cap_logical_finish_move(struct l2cap_chan *chan, + struct hci_chan *hchan) +{ + chan->hs_hcon = hchan->conn; + chan->hs_hcon->l2cap_data = chan->conn; + + BT_DBG("move_state %d", chan->move_state); + 
+ switch (chan->move_state) { + case L2CAP_MOVE_WAIT_LOGICAL_COMP: + /* Move confirm will be sent after a success + * response is received + */ + chan->move_state = L2CAP_MOVE_WAIT_RSP_SUCCESS; + break; + case L2CAP_MOVE_WAIT_LOGICAL_CFM: + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + chan->move_state = L2CAP_MOVE_WAIT_LOCAL_BUSY; + } else if (chan->move_role == L2CAP_MOVE_ROLE_INITIATOR) { + chan->move_state = L2CAP_MOVE_WAIT_CONFIRM_RSP; + l2cap_send_move_chan_cfm(chan, L2CAP_MC_CONFIRMED); + } else if (chan->move_role == L2CAP_MOVE_ROLE_RESPONDER) { + chan->move_state = L2CAP_MOVE_WAIT_CONFIRM; + l2cap_send_move_chan_rsp(chan, L2CAP_MR_SUCCESS); + } + break; + default: + /* Move was not in expected state, free the channel */ + __release_logical_link(chan); + + chan->move_state = L2CAP_MOVE_STABLE; + } +} + +/* Call with chan locked */ +void l2cap_logical_cfm(struct l2cap_chan *chan, struct hci_chan *hchan, + u8 status) +{ + BT_DBG("chan %p, hchan %p, status %d", chan, hchan, status); + + if (status) { + l2cap_logical_fail(chan); + __release_logical_link(chan); + return; + } + + if (chan->state != BT_CONNECTED) { + /* Ignore logical link if channel is on BR/EDR */ + if (chan->local_amp_id != AMP_ID_BREDR) + l2cap_logical_finish_create(chan, hchan); + } else { + l2cap_logical_finish_move(chan, hchan); + } +} + +void l2cap_move_start(struct l2cap_chan *chan) +{ + BT_DBG("chan %p", chan); + + if (chan->local_amp_id == AMP_ID_BREDR) { + if (chan->chan_policy != BT_CHANNEL_POLICY_AMP_PREFERRED) + return; + chan->move_role = L2CAP_MOVE_ROLE_INITIATOR; + chan->move_state = L2CAP_MOVE_WAIT_PREPARE; + /* Placeholder - start physical link setup */ + } else { + chan->move_role = L2CAP_MOVE_ROLE_INITIATOR; + chan->move_state = L2CAP_MOVE_WAIT_RSP_SUCCESS; + chan->move_id = 0; + l2cap_move_setup(chan); + l2cap_send_move_chan_req(chan, 0); + } +} + +static void l2cap_do_create(struct l2cap_chan *chan, int result, + u8 local_amp_id, u8 remote_amp_id) +{ + BT_DBG("chan %p state %s %u -> %u", chan, state_to_string(chan->state), + local_amp_id, remote_amp_id); + + chan->fcs = L2CAP_FCS_NONE; + + /* Outgoing channel on AMP */ + if (chan->state == BT_CONNECT) { + if (result == L2CAP_CR_SUCCESS) { + chan->local_amp_id = local_amp_id; + l2cap_send_create_chan_req(chan, remote_amp_id); + } else { + /* Revert to BR/EDR connect */ + l2cap_send_conn_req(chan); + } + + return; + } + + /* Incoming channel on AMP */ + if (__l2cap_no_conn_pending(chan)) { + struct l2cap_conn_rsp rsp; + char buf[128]; + rsp.scid = cpu_to_le16(chan->dcid); + rsp.dcid = cpu_to_le16(chan->scid); + + if (result == L2CAP_CR_SUCCESS) { + /* Send successful response */ + rsp.result = cpu_to_le16(L2CAP_CR_SUCCESS); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + } else { + /* Send negative response */ + rsp.result = cpu_to_le16(L2CAP_CR_NO_MEM); + rsp.status = cpu_to_le16(L2CAP_CS_NO_INFO); + } + + l2cap_send_cmd(chan->conn, chan->ident, L2CAP_CREATE_CHAN_RSP, + sizeof(rsp), &rsp); + + if (result == L2CAP_CR_SUCCESS) { + l2cap_state_change(chan, BT_CONFIG); + set_bit(CONF_REQ_SENT, &chan->conf_state); + l2cap_send_cmd(chan->conn, l2cap_get_ident(chan->conn), + L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), buf); + chan->num_conf_req++; + } + } +} + +static void l2cap_do_move_initiate(struct l2cap_chan *chan, u8 local_amp_id, + u8 remote_amp_id) +{ + l2cap_move_setup(chan); + chan->move_id = local_amp_id; + chan->move_state = L2CAP_MOVE_WAIT_RSP; + + l2cap_send_move_chan_req(chan, remote_amp_id); +} + +static void 
l2cap_do_move_respond(struct l2cap_chan *chan, int result) +{ + struct hci_chan *hchan = NULL; + + /* Placeholder - get hci_chan for logical link */ + + if (hchan) { + if (hchan->state == BT_CONNECTED) { + /* Logical link is ready to go */ + chan->hs_hcon = hchan->conn; + chan->hs_hcon->l2cap_data = chan->conn; + chan->move_state = L2CAP_MOVE_WAIT_CONFIRM; + l2cap_send_move_chan_rsp(chan, L2CAP_MR_SUCCESS); + + l2cap_logical_cfm(chan, hchan, L2CAP_MR_SUCCESS); + } else { + /* Wait for logical link to be ready */ + chan->move_state = L2CAP_MOVE_WAIT_LOGICAL_CFM; + } + } else { + /* Logical link not available */ + l2cap_send_move_chan_rsp(chan, L2CAP_MR_NOT_ALLOWED); + } +} + +static void l2cap_do_move_cancel(struct l2cap_chan *chan, int result) +{ + if (chan->move_role == L2CAP_MOVE_ROLE_RESPONDER) { + u8 rsp_result; + if (result == -EINVAL) + rsp_result = L2CAP_MR_BAD_ID; + else + rsp_result = L2CAP_MR_NOT_ALLOWED; + + l2cap_send_move_chan_rsp(chan, rsp_result); + } + + chan->move_role = L2CAP_MOVE_ROLE_NONE; + chan->move_state = L2CAP_MOVE_STABLE; + + /* Restart data transmission */ + l2cap_ertm_send(chan); +} + +/* Invoke with locked chan */ +void __l2cap_physical_cfm(struct l2cap_chan *chan, int result) +{ + u8 local_amp_id = chan->local_amp_id; + u8 remote_amp_id = chan->remote_amp_id; + + BT_DBG("chan %p, result %d, local_amp_id %d, remote_amp_id %d", + chan, result, local_amp_id, remote_amp_id); + + if (chan->state == BT_DISCONN || chan->state == BT_CLOSED) + return; + + if (chan->state != BT_CONNECTED) { + l2cap_do_create(chan, result, local_amp_id, remote_amp_id); + } else if (result != L2CAP_MR_SUCCESS) { + l2cap_do_move_cancel(chan, result); + } else { + switch (chan->move_role) { + case L2CAP_MOVE_ROLE_INITIATOR: + l2cap_do_move_initiate(chan, local_amp_id, + remote_amp_id); + break; + case L2CAP_MOVE_ROLE_RESPONDER: + l2cap_do_move_respond(chan, result); + break; + default: + l2cap_do_move_cancel(chan, result); + break; + } + } +} + +static inline int l2cap_move_channel_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, void *data) +{ + struct l2cap_move_chan_req *req = data; + struct l2cap_move_chan_rsp rsp; + struct l2cap_chan *chan; + u16 icid = 0; + u16 result = L2CAP_MR_NOT_ALLOWED; + + if (cmd_len != sizeof(*req)) + return -EPROTO; + + icid = le16_to_cpu(req->icid); + + BT_DBG("icid 0x%4.4x, dest_amp_id %d", icid, req->dest_amp_id); + + if (!(conn->local_fixed_chan & L2CAP_FC_A2MP)) + return -EINVAL; + + chan = l2cap_get_chan_by_dcid(conn, icid); + if (!chan) { + rsp.icid = cpu_to_le16(icid); + rsp.result = cpu_to_le16(L2CAP_MR_NOT_ALLOWED); + l2cap_send_cmd(conn, cmd->ident, L2CAP_MOVE_CHAN_RSP, + sizeof(rsp), &rsp); + return 0; + } + + chan->ident = cmd->ident; + + if (chan->scid < L2CAP_CID_DYN_START || + chan->chan_policy == BT_CHANNEL_POLICY_BREDR_ONLY || + (chan->mode != L2CAP_MODE_ERTM && + chan->mode != L2CAP_MODE_STREAMING)) { + result = L2CAP_MR_NOT_ALLOWED; + goto send_move_response; + } + + if (chan->local_amp_id == req->dest_amp_id) { + result = L2CAP_MR_SAME_ID; + goto send_move_response; + } + + if (req->dest_amp_id != AMP_ID_BREDR) { + struct hci_dev *hdev; + hdev = hci_dev_get(req->dest_amp_id); + if (!hdev || hdev->dev_type != HCI_AMP || + !test_bit(HCI_UP, &hdev->flags)) { + if (hdev) + hci_dev_put(hdev); + + result = L2CAP_MR_BAD_ID; + goto send_move_response; + } + hci_dev_put(hdev); + } + + /* Detect a move collision. Only send a collision response + * if this side has "lost", otherwise proceed with the move. 
+ * The winner has the larger bd_addr. + */ + if ((__chan_is_moving(chan) || + chan->move_role != L2CAP_MOVE_ROLE_NONE) && + bacmp(&conn->hcon->src, &conn->hcon->dst) > 0) { + result = L2CAP_MR_COLLISION; + goto send_move_response; + } + + chan->move_role = L2CAP_MOVE_ROLE_RESPONDER; + l2cap_move_setup(chan); + chan->move_id = req->dest_amp_id; + + if (req->dest_amp_id == AMP_ID_BREDR) { + /* Moving to BR/EDR */ + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + chan->move_state = L2CAP_MOVE_WAIT_LOCAL_BUSY; + result = L2CAP_MR_PEND; + } else { + chan->move_state = L2CAP_MOVE_WAIT_CONFIRM; + result = L2CAP_MR_SUCCESS; + } + } else { + chan->move_state = L2CAP_MOVE_WAIT_PREPARE; + /* Placeholder - uncomment when amp functions are available */ + /*amp_accept_physical(chan, req->dest_amp_id);*/ + result = L2CAP_MR_PEND; + } + +send_move_response: + l2cap_send_move_chan_rsp(chan, result); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static void l2cap_move_continue(struct l2cap_conn *conn, u16 icid, u16 result) +{ + struct l2cap_chan *chan; + struct hci_chan *hchan = NULL; + + chan = l2cap_get_chan_by_scid(conn, icid); + if (!chan) { + l2cap_send_move_chan_cfm_icid(conn, icid); + return; + } + + __clear_chan_timer(chan); + if (result == L2CAP_MR_PEND) + __set_chan_timer(chan, L2CAP_MOVE_ERTX_TIMEOUT); + + switch (chan->move_state) { + case L2CAP_MOVE_WAIT_LOGICAL_COMP: + /* Move confirm will be sent when logical link + * is complete. + */ + chan->move_state = L2CAP_MOVE_WAIT_LOGICAL_CFM; + break; + case L2CAP_MOVE_WAIT_RSP_SUCCESS: + if (result == L2CAP_MR_PEND) { + break; + } else if (test_bit(CONN_LOCAL_BUSY, + &chan->conn_state)) { + chan->move_state = L2CAP_MOVE_WAIT_LOCAL_BUSY; + } else { + /* Logical link is up or moving to BR/EDR, + * proceed with move + */ + chan->move_state = L2CAP_MOVE_WAIT_CONFIRM_RSP; + l2cap_send_move_chan_cfm(chan, L2CAP_MC_CONFIRMED); + } + break; + case L2CAP_MOVE_WAIT_RSP: + /* Moving to AMP */ + if (result == L2CAP_MR_SUCCESS) { + /* Remote is ready, send confirm immediately + * after logical link is ready + */ + chan->move_state = L2CAP_MOVE_WAIT_LOGICAL_CFM; + } else { + /* Both logical link and move success + * are required to confirm + */ + chan->move_state = L2CAP_MOVE_WAIT_LOGICAL_COMP; + } + + /* Placeholder - get hci_chan for logical link */ + if (!hchan) { + /* Logical link not available */ + l2cap_send_move_chan_cfm(chan, L2CAP_MC_UNCONFIRMED); + break; + } + + /* If the logical link is not yet connected, do not + * send confirmation. + */ + if (hchan->state != BT_CONNECTED) + break; + + /* Logical link is already ready to go */ + + chan->hs_hcon = hchan->conn; + chan->hs_hcon->l2cap_data = chan->conn; + + if (result == L2CAP_MR_SUCCESS) { + /* Can confirm now */ + l2cap_send_move_chan_cfm(chan, L2CAP_MC_CONFIRMED); + } else { + /* Now only need move success + * to confirm + */ + chan->move_state = L2CAP_MOVE_WAIT_RSP_SUCCESS; + } + + l2cap_logical_cfm(chan, hchan, L2CAP_MR_SUCCESS); + break; + default: + /* Any other amp move state means the move failed. 
*/ + chan->move_id = chan->local_amp_id; + l2cap_move_done(chan); + l2cap_send_move_chan_cfm(chan, L2CAP_MC_UNCONFIRMED); + } + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +static void l2cap_move_fail(struct l2cap_conn *conn, u8 ident, u16 icid, + u16 result) +{ + struct l2cap_chan *chan; + + chan = l2cap_get_chan_by_ident(conn, ident); + if (!chan) { + /* Could not locate channel, icid is best guess */ + l2cap_send_move_chan_cfm_icid(conn, icid); + return; + } + + __clear_chan_timer(chan); + + if (chan->move_role == L2CAP_MOVE_ROLE_INITIATOR) { + if (result == L2CAP_MR_COLLISION) { + chan->move_role = L2CAP_MOVE_ROLE_RESPONDER; + } else { + /* Cleanup - cancel move */ + chan->move_id = chan->local_amp_id; + l2cap_move_done(chan); + } + } + + l2cap_send_move_chan_cfm(chan, L2CAP_MC_UNCONFIRMED); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +static int l2cap_move_channel_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, void *data) +{ + struct l2cap_move_chan_rsp *rsp = data; + u16 icid, result; + + if (cmd_len != sizeof(*rsp)) + return -EPROTO; + + icid = le16_to_cpu(rsp->icid); + result = le16_to_cpu(rsp->result); + + BT_DBG("icid 0x%4.4x, result 0x%4.4x", icid, result); + + if (result == L2CAP_MR_SUCCESS || result == L2CAP_MR_PEND) + l2cap_move_continue(conn, icid, result); + else + l2cap_move_fail(conn, cmd->ident, icid, result); + + return 0; +} + +static int l2cap_move_channel_confirm(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, void *data) +{ + struct l2cap_move_chan_cfm *cfm = data; + struct l2cap_chan *chan; + u16 icid, result; + + if (cmd_len != sizeof(*cfm)) + return -EPROTO; + + icid = le16_to_cpu(cfm->icid); + result = le16_to_cpu(cfm->result); + + BT_DBG("icid 0x%4.4x, result 0x%4.4x", icid, result); + + chan = l2cap_get_chan_by_dcid(conn, icid); + if (!chan) { + /* Spec requires a response even if the icid was not found */ + l2cap_send_move_chan_cfm_rsp(conn, cmd->ident, icid); + return 0; + } + + if (chan->move_state == L2CAP_MOVE_WAIT_CONFIRM) { + if (result == L2CAP_MC_CONFIRMED) { + chan->local_amp_id = chan->move_id; + if (chan->local_amp_id == AMP_ID_BREDR) + __release_logical_link(chan); + } else { + chan->move_id = chan->local_amp_id; + } + + l2cap_move_done(chan); + } + + l2cap_send_move_chan_cfm_rsp(conn, cmd->ident, icid); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static inline int l2cap_move_channel_confirm_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, void *data) +{ + struct l2cap_move_chan_cfm_rsp *rsp = data; + struct l2cap_chan *chan; + u16 icid; + + if (cmd_len != sizeof(*rsp)) + return -EPROTO; + + icid = le16_to_cpu(rsp->icid); + + BT_DBG("icid 0x%4.4x", icid); + + chan = l2cap_get_chan_by_scid(conn, icid); + if (!chan) + return 0; + + __clear_chan_timer(chan); + + if (chan->move_state == L2CAP_MOVE_WAIT_CONFIRM_RSP) { + chan->local_amp_id = chan->move_id; + + if (chan->local_amp_id == AMP_ID_BREDR && chan->hs_hchan) + __release_logical_link(chan); + + l2cap_move_done(chan); + } + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static inline int l2cap_conn_param_update_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, + u16 cmd_len, u8 *data) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_conn_param_update_req *req; + struct l2cap_conn_param_update_rsp rsp; + u16 min, max, latency, to_multiplier; + int err; + + if (hcon->role != HCI_ROLE_MASTER) + return -EINVAL; + + if (cmd_len != 
sizeof(struct l2cap_conn_param_update_req)) + return -EPROTO; + + req = (struct l2cap_conn_param_update_req *) data; + min = __le16_to_cpu(req->min); + max = __le16_to_cpu(req->max); + latency = __le16_to_cpu(req->latency); + to_multiplier = __le16_to_cpu(req->to_multiplier); + + BT_DBG("min 0x%4.4x max 0x%4.4x latency: 0x%4.4x Timeout: 0x%4.4x", + min, max, latency, to_multiplier); + + memset(&rsp, 0, sizeof(rsp)); + + err = hci_check_conn_params(min, max, latency, to_multiplier); + if (err) + rsp.result = cpu_to_le16(L2CAP_CONN_PARAM_REJECTED); + else + rsp.result = cpu_to_le16(L2CAP_CONN_PARAM_ACCEPTED); + + l2cap_send_cmd(conn, cmd->ident, L2CAP_CONN_PARAM_UPDATE_RSP, + sizeof(rsp), &rsp); + + if (!err) { + u8 store_hint; + + store_hint = hci_le_conn_update(hcon, min, max, latency, + to_multiplier); + mgmt_new_conn_param(hcon->hdev, &hcon->dst, hcon->dst_type, + store_hint, min, max, latency, + to_multiplier); + + } + + return 0; +} + +static int l2cap_le_connect_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_le_conn_rsp *rsp = (struct l2cap_le_conn_rsp *) data; + struct hci_conn *hcon = conn->hcon; + u16 dcid, mtu, mps, credits, result; + struct l2cap_chan *chan; + int err, sec_level; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + dcid = __le16_to_cpu(rsp->dcid); + mtu = __le16_to_cpu(rsp->mtu); + mps = __le16_to_cpu(rsp->mps); + credits = __le16_to_cpu(rsp->credits); + result = __le16_to_cpu(rsp->result); + + if (result == L2CAP_CR_LE_SUCCESS && (mtu < 23 || mps < 23 || + dcid < L2CAP_CID_DYN_START || + dcid > L2CAP_CID_LE_DYN_END)) + return -EPROTO; + + BT_DBG("dcid 0x%4.4x mtu %u mps %u credits %u result 0x%2.2x", + dcid, mtu, mps, credits, result); + + mutex_lock(&conn->chan_lock); + + chan = __l2cap_get_chan_by_ident(conn, cmd->ident); + if (!chan) { + err = -EBADSLT; + goto unlock; + } + + err = 0; + + l2cap_chan_lock(chan); + + switch (result) { + case L2CAP_CR_LE_SUCCESS: + if (__l2cap_get_chan_by_dcid(conn, dcid)) { + err = -EBADSLT; + break; + } + + chan->ident = 0; + chan->dcid = dcid; + chan->omtu = mtu; + chan->remote_mps = mps; + chan->tx_credits = credits; + l2cap_chan_ready(chan); + break; + + case L2CAP_CR_LE_AUTHENTICATION: + case L2CAP_CR_LE_ENCRYPTION: + /* If we already have MITM protection we can't do + * anything. 
+ */ + if (hcon->sec_level > BT_SECURITY_MEDIUM) { + l2cap_chan_del(chan, ECONNREFUSED); + break; + } + + sec_level = hcon->sec_level + 1; + if (chan->sec_level < sec_level) + chan->sec_level = sec_level; + + /* We'll need to send a new Connect Request */ + clear_bit(FLAG_LE_CONN_REQ_SENT, &chan->flags); + + smp_conn_security(hcon, chan->sec_level); + break; + + default: + l2cap_chan_del(chan, ECONNREFUSED); + break; + } + + l2cap_chan_unlock(chan); + +unlock: + mutex_unlock(&conn->chan_lock); + + return err; +} + +static inline int l2cap_bredr_sig_cmd(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + int err = 0; + + switch (cmd->code) { + case L2CAP_COMMAND_REJ: + l2cap_command_rej(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONN_REQ: + err = l2cap_connect_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONN_RSP: + case L2CAP_CREATE_CHAN_RSP: + l2cap_connect_create_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONF_REQ: + err = l2cap_config_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONF_RSP: + l2cap_config_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_DISCONN_REQ: + err = l2cap_disconnect_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_DISCONN_RSP: + l2cap_disconnect_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_ECHO_REQ: + l2cap_send_cmd(conn, cmd->ident, L2CAP_ECHO_RSP, cmd_len, data); + break; + + case L2CAP_ECHO_RSP: + break; + + case L2CAP_INFO_REQ: + err = l2cap_information_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_INFO_RSP: + l2cap_information_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_CREATE_CHAN_REQ: + err = l2cap_create_channel_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_MOVE_CHAN_REQ: + err = l2cap_move_channel_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_MOVE_CHAN_RSP: + l2cap_move_channel_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_MOVE_CHAN_CFM: + err = l2cap_move_channel_confirm(conn, cmd, cmd_len, data); + break; + + case L2CAP_MOVE_CHAN_CFM_RSP: + l2cap_move_channel_confirm_rsp(conn, cmd, cmd_len, data); + break; + + default: + BT_ERR("Unknown BR/EDR signaling command 0x%2.2x", cmd->code); + err = -EINVAL; + break; + } + + return err; +} + +static int l2cap_le_connect_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_le_conn_req *req = (struct l2cap_le_conn_req *) data; + struct l2cap_le_conn_rsp rsp; + struct l2cap_chan *chan, *pchan; + u16 dcid, scid, credits, mtu, mps; + __le16 psm; + u8 result; + + if (cmd_len != sizeof(*req)) + return -EPROTO; + + scid = __le16_to_cpu(req->scid); + mtu = __le16_to_cpu(req->mtu); + mps = __le16_to_cpu(req->mps); + psm = req->psm; + dcid = 0; + credits = 0; + + if (mtu < 23 || mps < 23) + return -EPROTO; + + BT_DBG("psm 0x%2.2x scid 0x%4.4x mtu %u mps %u", __le16_to_cpu(psm), + scid, mtu, mps); + + /* BLUETOOTH CORE SPECIFICATION Version 5.3 | Vol 3, Part A + * page 1059: + * + * Valid range: 0x0001-0x00ff + * + * Table 4.15: L2CAP_LE_CREDIT_BASED_CONNECTION_REQ SPSM ranges + */ + if (!psm || __le16_to_cpu(psm) > L2CAP_PSM_LE_DYN_END) { + result = L2CAP_CR_LE_BAD_PSM; + chan = NULL; + goto response; + } + + /* Check if we have socket listening on psm */ + pchan = l2cap_global_chan_by_psm(BT_LISTEN, psm, &conn->hcon->src, + &conn->hcon->dst, LE_LINK); + if (!pchan) { + result = L2CAP_CR_LE_BAD_PSM; + chan = NULL; + goto response; + } + + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(pchan); + + if (!smp_sufficient_security(conn->hcon, 
pchan->sec_level, + SMP_ALLOW_STK)) { + result = L2CAP_CR_LE_AUTHENTICATION; + chan = NULL; + goto response_unlock; + } + + /* Check for valid dynamic CID range */ + if (scid < L2CAP_CID_DYN_START || scid > L2CAP_CID_LE_DYN_END) { + result = L2CAP_CR_LE_INVALID_SCID; + chan = NULL; + goto response_unlock; + } + + /* Check if we already have channel with that dcid */ + if (__l2cap_get_chan_by_dcid(conn, scid)) { + result = L2CAP_CR_LE_SCID_IN_USE; + chan = NULL; + goto response_unlock; + } + + chan = pchan->ops->new_connection(pchan); + if (!chan) { + result = L2CAP_CR_LE_NO_MEM; + goto response_unlock; + } + + bacpy(&chan->src, &conn->hcon->src); + bacpy(&chan->dst, &conn->hcon->dst); + chan->src_type = bdaddr_src_type(conn->hcon); + chan->dst_type = bdaddr_dst_type(conn->hcon); + chan->psm = psm; + chan->dcid = scid; + chan->omtu = mtu; + chan->remote_mps = mps; + + __l2cap_chan_add(conn, chan); + + l2cap_le_flowctl_init(chan, __le16_to_cpu(req->credits)); + + dcid = chan->scid; + credits = chan->rx_credits; + + __set_chan_timer(chan, chan->ops->get_sndtimeo(chan)); + + chan->ident = cmd->ident; + + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) { + l2cap_state_change(chan, BT_CONNECT2); + /* The following result value is actually not defined + * for LE CoC but we use it to let the function know + * that it should bail out after doing its cleanup + * instead of sending a response. + */ + result = L2CAP_CR_PEND; + chan->ops->defer(chan); + } else { + l2cap_chan_ready(chan); + result = L2CAP_CR_LE_SUCCESS; + } + +response_unlock: + l2cap_chan_unlock(pchan); + mutex_unlock(&conn->chan_lock); + l2cap_chan_put(pchan); + + if (result == L2CAP_CR_PEND) + return 0; + +response: + if (chan) { + rsp.mtu = cpu_to_le16(chan->imtu); + rsp.mps = cpu_to_le16(chan->mps); + } else { + rsp.mtu = 0; + rsp.mps = 0; + } + + rsp.dcid = cpu_to_le16(dcid); + rsp.credits = cpu_to_le16(credits); + rsp.result = cpu_to_le16(result); + + l2cap_send_cmd(conn, cmd->ident, L2CAP_LE_CONN_RSP, sizeof(rsp), &rsp); + + return 0; +} + +static inline int l2cap_le_credits(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_le_credits *pkt; + struct l2cap_chan *chan; + u16 cid, credits, max_credits; + + if (cmd_len != sizeof(*pkt)) + return -EPROTO; + + pkt = (struct l2cap_le_credits *) data; + cid = __le16_to_cpu(pkt->cid); + credits = __le16_to_cpu(pkt->credits); + + BT_DBG("cid 0x%4.4x credits 0x%4.4x", cid, credits); + + chan = l2cap_get_chan_by_dcid(conn, cid); + if (!chan) + return -EBADSLT; + + max_credits = LE_FLOWCTL_MAX_CREDITS - chan->tx_credits; + if (credits > max_credits) { + BT_ERR("LE credits overflow"); + l2cap_send_disconn_req(chan, ECONNRESET); + + /* Return 0 so that we don't trigger an unnecessary + * command reject packet. 
+ */ + goto unlock; + } + + chan->tx_credits += credits; + + /* Resume sending */ + l2cap_le_flowctl_send(chan); + + if (chan->tx_credits) + chan->ops->resume(chan); + +unlock: + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return 0; +} + +static inline int l2cap_ecred_conn_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_ecred_conn_req *req = (void *) data; + struct { + struct l2cap_ecred_conn_rsp rsp; + __le16 dcid[L2CAP_ECRED_MAX_CID]; + } __packed pdu; + struct l2cap_chan *chan, *pchan; + u16 mtu, mps; + __le16 psm; + u8 result, len = 0; + int i, num_scid; + bool defer = false; + + if (!enable_ecred) + return -EINVAL; + + if (cmd_len < sizeof(*req) || (cmd_len - sizeof(*req)) % sizeof(u16)) { + result = L2CAP_CR_LE_INVALID_PARAMS; + goto response; + } + + cmd_len -= sizeof(*req); + num_scid = cmd_len / sizeof(u16); + + if (num_scid > ARRAY_SIZE(pdu.dcid)) { + result = L2CAP_CR_LE_INVALID_PARAMS; + goto response; + } + + mtu = __le16_to_cpu(req->mtu); + mps = __le16_to_cpu(req->mps); + + if (mtu < L2CAP_ECRED_MIN_MTU || mps < L2CAP_ECRED_MIN_MPS) { + result = L2CAP_CR_LE_UNACCEPT_PARAMS; + goto response; + } + + psm = req->psm; + + /* BLUETOOTH CORE SPECIFICATION Version 5.3 | Vol 3, Part A + * page 1059: + * + * Valid range: 0x0001-0x00ff + * + * Table 4.15: L2CAP_LE_CREDIT_BASED_CONNECTION_REQ SPSM ranges + */ + if (!psm || __le16_to_cpu(psm) > L2CAP_PSM_LE_DYN_END) { + result = L2CAP_CR_LE_BAD_PSM; + goto response; + } + + BT_DBG("psm 0x%2.2x mtu %u mps %u", __le16_to_cpu(psm), mtu, mps); + + memset(&pdu, 0, sizeof(pdu)); + + /* Check if we have socket listening on psm */ + pchan = l2cap_global_chan_by_psm(BT_LISTEN, psm, &conn->hcon->src, + &conn->hcon->dst, LE_LINK); + if (!pchan) { + result = L2CAP_CR_LE_BAD_PSM; + goto response; + } + + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(pchan); + + if (!smp_sufficient_security(conn->hcon, pchan->sec_level, + SMP_ALLOW_STK)) { + result = L2CAP_CR_LE_AUTHENTICATION; + goto unlock; + } + + result = L2CAP_CR_LE_SUCCESS; + + for (i = 0; i < num_scid; i++) { + u16 scid = __le16_to_cpu(req->scid[i]); + + BT_DBG("scid[%d] 0x%4.4x", i, scid); + + pdu.dcid[i] = 0x0000; + len += sizeof(*pdu.dcid); + + /* Check for valid dynamic CID range */ + if (scid < L2CAP_CID_DYN_START || scid > L2CAP_CID_LE_DYN_END) { + result = L2CAP_CR_LE_INVALID_SCID; + continue; + } + + /* Check if we already have channel with that dcid */ + if (__l2cap_get_chan_by_dcid(conn, scid)) { + result = L2CAP_CR_LE_SCID_IN_USE; + continue; + } + + chan = pchan->ops->new_connection(pchan); + if (!chan) { + result = L2CAP_CR_LE_NO_MEM; + continue; + } + + bacpy(&chan->src, &conn->hcon->src); + bacpy(&chan->dst, &conn->hcon->dst); + chan->src_type = bdaddr_src_type(conn->hcon); + chan->dst_type = bdaddr_dst_type(conn->hcon); + chan->psm = psm; + chan->dcid = scid; + chan->omtu = mtu; + chan->remote_mps = mps; + + __l2cap_chan_add(conn, chan); + + l2cap_ecred_init(chan, __le16_to_cpu(req->credits)); + + /* Init response */ + if (!pdu.rsp.credits) { + pdu.rsp.mtu = cpu_to_le16(chan->imtu); + pdu.rsp.mps = cpu_to_le16(chan->mps); + pdu.rsp.credits = cpu_to_le16(chan->rx_credits); + } + + pdu.dcid[i] = cpu_to_le16(chan->scid); + + __set_chan_timer(chan, chan->ops->get_sndtimeo(chan)); + + chan->ident = cmd->ident; + chan->mode = L2CAP_MODE_EXT_FLOWCTL; + + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) { + l2cap_state_change(chan, BT_CONNECT2); + defer = true; + chan->ops->defer(chan); + } else { + l2cap_chan_ready(chan); + 
} + } + +unlock: + l2cap_chan_unlock(pchan); + mutex_unlock(&conn->chan_lock); + l2cap_chan_put(pchan); + +response: + pdu.rsp.result = cpu_to_le16(result); + + if (defer) + return 0; + + l2cap_send_cmd(conn, cmd->ident, L2CAP_ECRED_CONN_RSP, + sizeof(pdu.rsp) + len, &pdu); + + return 0; +} + +static inline int l2cap_ecred_conn_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_ecred_conn_rsp *rsp = (void *) data; + struct hci_conn *hcon = conn->hcon; + u16 mtu, mps, credits, result; + struct l2cap_chan *chan, *tmp; + int err = 0, sec_level; + int i = 0; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + mtu = __le16_to_cpu(rsp->mtu); + mps = __le16_to_cpu(rsp->mps); + credits = __le16_to_cpu(rsp->credits); + result = __le16_to_cpu(rsp->result); + + BT_DBG("mtu %u mps %u credits %u result 0x%4.4x", mtu, mps, credits, + result); + + mutex_lock(&conn->chan_lock); + + cmd_len -= sizeof(*rsp); + + list_for_each_entry_safe(chan, tmp, &conn->chan_l, list) { + u16 dcid; + + if (chan->ident != cmd->ident || + chan->mode != L2CAP_MODE_EXT_FLOWCTL || + chan->state == BT_CONNECTED) + continue; + + l2cap_chan_lock(chan); + + /* Check that there is a dcid for each pending channel */ + if (cmd_len < sizeof(dcid)) { + l2cap_chan_del(chan, ECONNREFUSED); + l2cap_chan_unlock(chan); + continue; + } + + dcid = __le16_to_cpu(rsp->dcid[i++]); + cmd_len -= sizeof(u16); + + BT_DBG("dcid[%d] 0x%4.4x", i, dcid); + + /* Check if dcid is already in use */ + if (dcid && __l2cap_get_chan_by_dcid(conn, dcid)) { + /* If a device receives a + * L2CAP_CREDIT_BASED_CONNECTION_RSP packet with an + * already-assigned Destination CID, then both the + * original channel and the new channel shall be + * immediately discarded and not used. + */ + l2cap_chan_del(chan, ECONNREFUSED); + l2cap_chan_unlock(chan); + chan = __l2cap_get_chan_by_dcid(conn, dcid); + l2cap_chan_lock(chan); + l2cap_chan_del(chan, ECONNRESET); + l2cap_chan_unlock(chan); + continue; + } + + switch (result) { + case L2CAP_CR_LE_AUTHENTICATION: + case L2CAP_CR_LE_ENCRYPTION: + /* If we already have MITM protection we can't do + * anything. 
+ */ + if (hcon->sec_level > BT_SECURITY_MEDIUM) { + l2cap_chan_del(chan, ECONNREFUSED); + break; + } + + sec_level = hcon->sec_level + 1; + if (chan->sec_level < sec_level) + chan->sec_level = sec_level; + + /* We'll need to send a new Connect Request */ + clear_bit(FLAG_ECRED_CONN_REQ_SENT, &chan->flags); + + smp_conn_security(hcon, chan->sec_level); + break; + + case L2CAP_CR_LE_BAD_PSM: + l2cap_chan_del(chan, ECONNREFUSED); + break; + + default: + /* If dcid was not set it means channels was refused */ + if (!dcid) { + l2cap_chan_del(chan, ECONNREFUSED); + break; + } + + chan->ident = 0; + chan->dcid = dcid; + chan->omtu = mtu; + chan->remote_mps = mps; + chan->tx_credits = credits; + l2cap_chan_ready(chan); + break; + } + + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); + + return err; +} + +static inline int l2cap_ecred_reconf_req(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_ecred_reconf_req *req = (void *) data; + struct l2cap_ecred_reconf_rsp rsp; + u16 mtu, mps, result; + struct l2cap_chan *chan; + int i, num_scid; + + if (!enable_ecred) + return -EINVAL; + + if (cmd_len < sizeof(*req) || cmd_len - sizeof(*req) % sizeof(u16)) { + result = L2CAP_CR_LE_INVALID_PARAMS; + goto respond; + } + + mtu = __le16_to_cpu(req->mtu); + mps = __le16_to_cpu(req->mps); + + BT_DBG("mtu %u mps %u", mtu, mps); + + if (mtu < L2CAP_ECRED_MIN_MTU) { + result = L2CAP_RECONF_INVALID_MTU; + goto respond; + } + + if (mps < L2CAP_ECRED_MIN_MPS) { + result = L2CAP_RECONF_INVALID_MPS; + goto respond; + } + + cmd_len -= sizeof(*req); + num_scid = cmd_len / sizeof(u16); + result = L2CAP_RECONF_SUCCESS; + + for (i = 0; i < num_scid; i++) { + u16 scid; + + scid = __le16_to_cpu(req->scid[i]); + if (!scid) + return -EPROTO; + + chan = __l2cap_get_chan_by_dcid(conn, scid); + if (!chan) + continue; + + /* If the MTU value is decreased for any of the included + * channels, then the receiver shall disconnect all + * included channels. 
+ */ + if (chan->omtu > mtu) { + BT_ERR("chan %p decreased MTU %u -> %u", chan, + chan->omtu, mtu); + result = L2CAP_RECONF_INVALID_MTU; + } + + chan->omtu = mtu; + chan->remote_mps = mps; + } + +respond: + rsp.result = cpu_to_le16(result); + + l2cap_send_cmd(conn, cmd->ident, L2CAP_ECRED_RECONF_RSP, sizeof(rsp), + &rsp); + + return 0; +} + +static inline int l2cap_ecred_reconf_rsp(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_chan *chan, *tmp; + struct l2cap_ecred_conn_rsp *rsp = (void *) data; + u16 result; + + if (cmd_len < sizeof(*rsp)) + return -EPROTO; + + result = __le16_to_cpu(rsp->result); + + BT_DBG("result 0x%4.4x", rsp->result); + + if (!result) + return 0; + + list_for_each_entry_safe(chan, tmp, &conn->chan_l, list) { + if (chan->ident != cmd->ident) + continue; + + l2cap_chan_del(chan, ECONNRESET); + } + + return 0; +} + +static inline int l2cap_le_command_rej(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + struct l2cap_cmd_rej_unk *rej = (struct l2cap_cmd_rej_unk *) data; + struct l2cap_chan *chan; + + if (cmd_len < sizeof(*rej)) + return -EPROTO; + + mutex_lock(&conn->chan_lock); + + chan = __l2cap_get_chan_by_ident(conn, cmd->ident); + if (!chan) + goto done; + + chan = l2cap_chan_hold_unless_zero(chan); + if (!chan) + goto done; + + l2cap_chan_lock(chan); + l2cap_chan_del(chan, ECONNREFUSED); + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + +done: + mutex_unlock(&conn->chan_lock); + return 0; +} + +static inline int l2cap_le_sig_cmd(struct l2cap_conn *conn, + struct l2cap_cmd_hdr *cmd, u16 cmd_len, + u8 *data) +{ + int err = 0; + + switch (cmd->code) { + case L2CAP_COMMAND_REJ: + l2cap_le_command_rej(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONN_PARAM_UPDATE_REQ: + err = l2cap_conn_param_update_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_CONN_PARAM_UPDATE_RSP: + break; + + case L2CAP_LE_CONN_RSP: + l2cap_le_connect_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_LE_CONN_REQ: + err = l2cap_le_connect_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_LE_CREDITS: + err = l2cap_le_credits(conn, cmd, cmd_len, data); + break; + + case L2CAP_ECRED_CONN_REQ: + err = l2cap_ecred_conn_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_ECRED_CONN_RSP: + err = l2cap_ecred_conn_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_ECRED_RECONF_REQ: + err = l2cap_ecred_reconf_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_ECRED_RECONF_RSP: + err = l2cap_ecred_reconf_rsp(conn, cmd, cmd_len, data); + break; + + case L2CAP_DISCONN_REQ: + err = l2cap_disconnect_req(conn, cmd, cmd_len, data); + break; + + case L2CAP_DISCONN_RSP: + l2cap_disconnect_rsp(conn, cmd, cmd_len, data); + break; + + default: + BT_ERR("Unknown LE signaling command 0x%2.2x", cmd->code); + err = -EINVAL; + break; + } + + return err; +} + +static inline void l2cap_le_sig_channel(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_cmd_hdr *cmd; + u16 len; + int err; + + if (hcon->type != LE_LINK) + goto drop; + + if (skb->len < L2CAP_CMD_HDR_SIZE) + goto drop; + + cmd = (void *) skb->data; + skb_pull(skb, L2CAP_CMD_HDR_SIZE); + + len = le16_to_cpu(cmd->len); + + BT_DBG("code 0x%2.2x len %d id 0x%2.2x", cmd->code, len, cmd->ident); + + if (len != skb->len || !cmd->ident) { + BT_DBG("corrupted command"); + goto drop; + } + + err = l2cap_le_sig_cmd(conn, cmd, len, skb->data); + if (err) { + struct l2cap_cmd_rej_unk rej; + + BT_ERR("Wrong 
link type (%d)", err); + + rej.reason = cpu_to_le16(L2CAP_REJ_NOT_UNDERSTOOD); + l2cap_send_cmd(conn, cmd->ident, L2CAP_COMMAND_REJ, + sizeof(rej), &rej); + } + +drop: + kfree_skb(skb); +} + +static inline void l2cap_sig_send_rej(struct l2cap_conn *conn, u16 ident) +{ + struct l2cap_cmd_rej_unk rej; + + rej.reason = cpu_to_le16(L2CAP_REJ_NOT_UNDERSTOOD); + l2cap_send_cmd(conn, ident, L2CAP_COMMAND_REJ, sizeof(rej), &rej); +} + +static inline void l2cap_sig_channel(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_cmd_hdr *cmd; + int err; + + l2cap_raw_recv(conn, skb); + + if (hcon->type != ACL_LINK) + goto drop; + + while (skb->len >= L2CAP_CMD_HDR_SIZE) { + u16 len; + + cmd = (void *) skb->data; + skb_pull(skb, L2CAP_CMD_HDR_SIZE); + + len = le16_to_cpu(cmd->len); + + BT_DBG("code 0x%2.2x len %d id 0x%2.2x", cmd->code, len, + cmd->ident); + + if (len > skb->len || !cmd->ident) { + BT_DBG("corrupted command"); + l2cap_sig_send_rej(conn, cmd->ident); + break; + } + + err = l2cap_bredr_sig_cmd(conn, cmd, len, skb->data); + if (err) { + BT_ERR("Wrong link type (%d)", err); + l2cap_sig_send_rej(conn, cmd->ident); + } + + skb_pull(skb, len); + } + + if (skb->len > 0) { + BT_DBG("corrupted command"); + l2cap_sig_send_rej(conn, 0); + } + +drop: + kfree_skb(skb); +} + +static int l2cap_check_fcs(struct l2cap_chan *chan, struct sk_buff *skb) +{ + u16 our_fcs, rcv_fcs; + int hdr_size; + + if (test_bit(FLAG_EXT_CTRL, &chan->flags)) + hdr_size = L2CAP_EXT_HDR_SIZE; + else + hdr_size = L2CAP_ENH_HDR_SIZE; + + if (chan->fcs == L2CAP_FCS_CRC16) { + skb_trim(skb, skb->len - L2CAP_FCS_SIZE); + rcv_fcs = get_unaligned_le16(skb->data + skb->len); + our_fcs = crc16(0, skb->data - hdr_size, skb->len + hdr_size); + + if (our_fcs != rcv_fcs) + return -EBADMSG; + } + return 0; +} + +static void l2cap_send_i_or_rr_or_rnr(struct l2cap_chan *chan) +{ + struct l2cap_ctrl control; + + BT_DBG("chan %p", chan); + + memset(&control, 0, sizeof(control)); + control.sframe = 1; + control.final = 1; + control.reqseq = chan->buffer_seq; + set_bit(CONN_SEND_FBIT, &chan->conn_state); + + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + control.super = L2CAP_SUPER_RNR; + l2cap_send_sframe(chan, &control); + } + + if (test_and_clear_bit(CONN_REMOTE_BUSY, &chan->conn_state) && + chan->unacked_frames > 0) + __set_retrans_timer(chan); + + /* Send pending iframes */ + l2cap_ertm_send(chan); + + if (!test_bit(CONN_LOCAL_BUSY, &chan->conn_state) && + test_bit(CONN_SEND_FBIT, &chan->conn_state)) { + /* F-bit wasn't sent in an s-frame or i-frame yet, so + * send it now. 
+ */ + control.super = L2CAP_SUPER_RR; + l2cap_send_sframe(chan, &control); + } +} + +static void append_skb_frag(struct sk_buff *skb, struct sk_buff *new_frag, + struct sk_buff **last_frag) +{ + /* skb->len reflects data in skb as well as all fragments + * skb->data_len reflects only data in fragments + */ + if (!skb_has_frag_list(skb)) + skb_shinfo(skb)->frag_list = new_frag; + + new_frag->next = NULL; + + (*last_frag)->next = new_frag; + *last_frag = new_frag; + + skb->len += new_frag->len; + skb->data_len += new_frag->len; + skb->truesize += new_frag->truesize; +} + +static int l2cap_reassemble_sdu(struct l2cap_chan *chan, struct sk_buff *skb, + struct l2cap_ctrl *control) +{ + int err = -EINVAL; + + switch (control->sar) { + case L2CAP_SAR_UNSEGMENTED: + if (chan->sdu) + break; + + err = chan->ops->recv(chan, skb); + break; + + case L2CAP_SAR_START: + if (chan->sdu) + break; + + if (!pskb_may_pull(skb, L2CAP_SDULEN_SIZE)) + break; + + chan->sdu_len = get_unaligned_le16(skb->data); + skb_pull(skb, L2CAP_SDULEN_SIZE); + + if (chan->sdu_len > chan->imtu) { + err = -EMSGSIZE; + break; + } + + if (skb->len >= chan->sdu_len) + break; + + chan->sdu = skb; + chan->sdu_last_frag = skb; + + skb = NULL; + err = 0; + break; + + case L2CAP_SAR_CONTINUE: + if (!chan->sdu) + break; + + append_skb_frag(chan->sdu, skb, + &chan->sdu_last_frag); + skb = NULL; + + if (chan->sdu->len >= chan->sdu_len) + break; + + err = 0; + break; + + case L2CAP_SAR_END: + if (!chan->sdu) + break; + + append_skb_frag(chan->sdu, skb, + &chan->sdu_last_frag); + skb = NULL; + + if (chan->sdu->len != chan->sdu_len) + break; + + err = chan->ops->recv(chan, chan->sdu); + + if (!err) { + /* Reassembly complete */ + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + } + break; + } + + if (err) { + kfree_skb(skb); + kfree_skb(chan->sdu); + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + } + + return err; +} + +static int l2cap_resegment(struct l2cap_chan *chan) +{ + /* Placeholder */ + return 0; +} + +void l2cap_chan_busy(struct l2cap_chan *chan, int busy) +{ + u8 event; + + if (chan->mode != L2CAP_MODE_ERTM) + return; + + event = busy ? L2CAP_EV_LOCAL_BUSY_DETECTED : L2CAP_EV_LOCAL_BUSY_CLEAR; + l2cap_tx(chan, NULL, NULL, event); +} + +static int l2cap_rx_queued_iframes(struct l2cap_chan *chan) +{ + int err = 0; + /* Pass sequential frames to l2cap_reassemble_sdu() + * until a gap is encountered. 
+ */ + + BT_DBG("chan %p", chan); + + while (!test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + struct sk_buff *skb; + BT_DBG("Searching for skb with txseq %d (queue len %d)", + chan->buffer_seq, skb_queue_len(&chan->srej_q)); + + skb = l2cap_ertm_seq_in_queue(&chan->srej_q, chan->buffer_seq); + + if (!skb) + break; + + skb_unlink(skb, &chan->srej_q); + chan->buffer_seq = __next_seq(chan, chan->buffer_seq); + err = l2cap_reassemble_sdu(chan, skb, &bt_cb(skb)->l2cap); + if (err) + break; + } + + if (skb_queue_empty(&chan->srej_q)) { + chan->rx_state = L2CAP_RX_STATE_RECV; + l2cap_send_ack(chan); + } + + return err; +} + +static void l2cap_handle_srej(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + struct sk_buff *skb; + + BT_DBG("chan %p, control %p", chan, control); + + if (control->reqseq == chan->next_tx_seq) { + BT_DBG("Invalid reqseq %d, disconnecting", control->reqseq); + l2cap_send_disconn_req(chan, ECONNRESET); + return; + } + + skb = l2cap_ertm_seq_in_queue(&chan->tx_q, control->reqseq); + + if (skb == NULL) { + BT_DBG("Seq %d not available for retransmission", + control->reqseq); + return; + } + + if (chan->max_tx != 0 && bt_cb(skb)->l2cap.retries >= chan->max_tx) { + BT_DBG("Retry limit exceeded (%d)", chan->max_tx); + l2cap_send_disconn_req(chan, ECONNRESET); + return; + } + + clear_bit(CONN_REMOTE_BUSY, &chan->conn_state); + + if (control->poll) { + l2cap_pass_to_tx(chan, control); + + set_bit(CONN_SEND_FBIT, &chan->conn_state); + l2cap_retransmit(chan, control); + l2cap_ertm_send(chan); + + if (chan->tx_state == L2CAP_TX_STATE_WAIT_F) { + set_bit(CONN_SREJ_ACT, &chan->conn_state); + chan->srej_save_reqseq = control->reqseq; + } + } else { + l2cap_pass_to_tx_fbit(chan, control); + + if (control->final) { + if (chan->srej_save_reqseq != control->reqseq || + !test_and_clear_bit(CONN_SREJ_ACT, + &chan->conn_state)) + l2cap_retransmit(chan, control); + } else { + l2cap_retransmit(chan, control); + if (chan->tx_state == L2CAP_TX_STATE_WAIT_F) { + set_bit(CONN_SREJ_ACT, &chan->conn_state); + chan->srej_save_reqseq = control->reqseq; + } + } + } +} + +static void l2cap_handle_rej(struct l2cap_chan *chan, + struct l2cap_ctrl *control) +{ + struct sk_buff *skb; + + BT_DBG("chan %p, control %p", chan, control); + + if (control->reqseq == chan->next_tx_seq) { + BT_DBG("Invalid reqseq %d, disconnecting", control->reqseq); + l2cap_send_disconn_req(chan, ECONNRESET); + return; + } + + skb = l2cap_ertm_seq_in_queue(&chan->tx_q, control->reqseq); + + if (chan->max_tx && skb && + bt_cb(skb)->l2cap.retries >= chan->max_tx) { + BT_DBG("Retry limit exceeded (%d)", chan->max_tx); + l2cap_send_disconn_req(chan, ECONNRESET); + return; + } + + clear_bit(CONN_REMOTE_BUSY, &chan->conn_state); + + l2cap_pass_to_tx(chan, control); + + if (control->final) { + if (!test_and_clear_bit(CONN_REJ_ACT, &chan->conn_state)) + l2cap_retransmit_all(chan, control); + } else { + l2cap_retransmit_all(chan, control); + l2cap_ertm_send(chan); + if (chan->tx_state == L2CAP_TX_STATE_WAIT_F) + set_bit(CONN_REJ_ACT, &chan->conn_state); + } +} + +static u8 l2cap_classify_txseq(struct l2cap_chan *chan, u16 txseq) +{ + BT_DBG("chan %p, txseq %d", chan, txseq); + + BT_DBG("last_acked_seq %d, expected_tx_seq %d", chan->last_acked_seq, + chan->expected_tx_seq); + + if (chan->rx_state == L2CAP_RX_STATE_SREJ_SENT) { + if (__seq_offset(chan, txseq, chan->last_acked_seq) >= + chan->tx_win) { + /* See notes below regarding "double poll" and + * invalid packets. 
+ */ + if (chan->tx_win <= ((chan->tx_win_max + 1) >> 1)) { + BT_DBG("Invalid/Ignore - after SREJ"); + return L2CAP_TXSEQ_INVALID_IGNORE; + } else { + BT_DBG("Invalid - in window after SREJ sent"); + return L2CAP_TXSEQ_INVALID; + } + } + + if (chan->srej_list.head == txseq) { + BT_DBG("Expected SREJ"); + return L2CAP_TXSEQ_EXPECTED_SREJ; + } + + if (l2cap_ertm_seq_in_queue(&chan->srej_q, txseq)) { + BT_DBG("Duplicate SREJ - txseq already stored"); + return L2CAP_TXSEQ_DUPLICATE_SREJ; + } + + if (l2cap_seq_list_contains(&chan->srej_list, txseq)) { + BT_DBG("Unexpected SREJ - not requested"); + return L2CAP_TXSEQ_UNEXPECTED_SREJ; + } + } + + if (chan->expected_tx_seq == txseq) { + if (__seq_offset(chan, txseq, chan->last_acked_seq) >= + chan->tx_win) { + BT_DBG("Invalid - txseq outside tx window"); + return L2CAP_TXSEQ_INVALID; + } else { + BT_DBG("Expected"); + return L2CAP_TXSEQ_EXPECTED; + } + } + + if (__seq_offset(chan, txseq, chan->last_acked_seq) < + __seq_offset(chan, chan->expected_tx_seq, chan->last_acked_seq)) { + BT_DBG("Duplicate - expected_tx_seq later than txseq"); + return L2CAP_TXSEQ_DUPLICATE; + } + + if (__seq_offset(chan, txseq, chan->last_acked_seq) >= chan->tx_win) { + /* A source of invalid packets is a "double poll" condition, + * where delays cause us to send multiple poll packets. If + * the remote stack receives and processes both polls, + * sequence numbers can wrap around in such a way that a + * resent frame has a sequence number that looks like new data + * with a sequence gap. This would trigger an erroneous SREJ + * request. + * + * Fortunately, this is impossible with a tx window that's + * less than half of the maximum sequence number, which allows + * invalid frames to be safely ignored. + * + * With tx window sizes greater than half of the tx window + * maximum, the frame is invalid and cannot be ignored. This + * causes a disconnect. + */ + + if (chan->tx_win <= ((chan->tx_win_max + 1) >> 1)) { + BT_DBG("Invalid/Ignore - txseq outside tx window"); + return L2CAP_TXSEQ_INVALID_IGNORE; + } else { + BT_DBG("Invalid - txseq outside tx window"); + return L2CAP_TXSEQ_INVALID; + } + } else { + BT_DBG("Unexpected - txseq indicates missing frames"); + return L2CAP_TXSEQ_UNEXPECTED; + } +} + +static int l2cap_rx_state_recv(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff *skb, u8 event) +{ + struct l2cap_ctrl local_control; + int err = 0; + bool skb_in_use = false; + + BT_DBG("chan %p, control %p, skb %p, event %d", chan, control, skb, + event); + + switch (event) { + case L2CAP_EV_RECV_IFRAME: + switch (l2cap_classify_txseq(chan, control->txseq)) { + case L2CAP_TXSEQ_EXPECTED: + l2cap_pass_to_tx(chan, control); + + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + BT_DBG("Busy, discarding expected seq %d", + control->txseq); + break; + } + + chan->expected_tx_seq = __next_seq(chan, + control->txseq); + + chan->buffer_seq = chan->expected_tx_seq; + skb_in_use = true; + + /* l2cap_reassemble_sdu may free skb, hence invalidate + * control, so make a copy in advance to use it after + * l2cap_reassemble_sdu returns and to avoid the race + * condition, for example: + * + * The current thread calls: + * l2cap_reassemble_sdu + * chan->ops->recv == l2cap_sock_recv_cb + * __sock_queue_rcv_skb + * Another thread calls: + * bt_sock_recvmsg + * skb_recv_datagram + * skb_free_datagram + * Then the current thread tries to access control, but + * it was freed by skb_free_datagram. 
+ */ + local_control = *control; + err = l2cap_reassemble_sdu(chan, skb, control); + if (err) + break; + + if (local_control.final) { + if (!test_and_clear_bit(CONN_REJ_ACT, + &chan->conn_state)) { + local_control.final = 0; + l2cap_retransmit_all(chan, &local_control); + l2cap_ertm_send(chan); + } + } + + if (!test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) + l2cap_send_ack(chan); + break; + case L2CAP_TXSEQ_UNEXPECTED: + l2cap_pass_to_tx(chan, control); + + /* Can't issue SREJ frames in the local busy state. + * Drop this frame, it will be seen as missing + * when local busy is exited. + */ + if (test_bit(CONN_LOCAL_BUSY, &chan->conn_state)) { + BT_DBG("Busy, discarding unexpected seq %d", + control->txseq); + break; + } + + /* There was a gap in the sequence, so an SREJ + * must be sent for each missing frame. The + * current frame is stored for later use. + */ + skb_queue_tail(&chan->srej_q, skb); + skb_in_use = true; + BT_DBG("Queued %p (queue len %d)", skb, + skb_queue_len(&chan->srej_q)); + + clear_bit(CONN_SREJ_ACT, &chan->conn_state); + l2cap_seq_list_clear(&chan->srej_list); + l2cap_send_srej(chan, control->txseq); + + chan->rx_state = L2CAP_RX_STATE_SREJ_SENT; + break; + case L2CAP_TXSEQ_DUPLICATE: + l2cap_pass_to_tx(chan, control); + break; + case L2CAP_TXSEQ_INVALID_IGNORE: + break; + case L2CAP_TXSEQ_INVALID: + default: + l2cap_send_disconn_req(chan, ECONNRESET); + break; + } + break; + case L2CAP_EV_RECV_RR: + l2cap_pass_to_tx(chan, control); + if (control->final) { + clear_bit(CONN_REMOTE_BUSY, &chan->conn_state); + + if (!test_and_clear_bit(CONN_REJ_ACT, &chan->conn_state) && + !__chan_is_moving(chan)) { + control->final = 0; + l2cap_retransmit_all(chan, control); + } + + l2cap_ertm_send(chan); + } else if (control->poll) { + l2cap_send_i_or_rr_or_rnr(chan); + } else { + if (test_and_clear_bit(CONN_REMOTE_BUSY, + &chan->conn_state) && + chan->unacked_frames) + __set_retrans_timer(chan); + + l2cap_ertm_send(chan); + } + break; + case L2CAP_EV_RECV_RNR: + set_bit(CONN_REMOTE_BUSY, &chan->conn_state); + l2cap_pass_to_tx(chan, control); + if (control && control->poll) { + set_bit(CONN_SEND_FBIT, &chan->conn_state); + l2cap_send_rr_or_rnr(chan, 0); + } + __clear_retrans_timer(chan); + l2cap_seq_list_clear(&chan->retrans_list); + break; + case L2CAP_EV_RECV_REJ: + l2cap_handle_rej(chan, control); + break; + case L2CAP_EV_RECV_SREJ: + l2cap_handle_srej(chan, control); + break; + default: + break; + } + + if (skb && !skb_in_use) { + BT_DBG("Freeing %p", skb); + kfree_skb(skb); + } + + return err; +} + +static int l2cap_rx_state_srej_sent(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff *skb, u8 event) +{ + int err = 0; + u16 txseq = control->txseq; + bool skb_in_use = false; + + BT_DBG("chan %p, control %p, skb %p, event %d", chan, control, skb, + event); + + switch (event) { + case L2CAP_EV_RECV_IFRAME: + switch (l2cap_classify_txseq(chan, txseq)) { + case L2CAP_TXSEQ_EXPECTED: + /* Keep frame for reassembly later */ + l2cap_pass_to_tx(chan, control); + skb_queue_tail(&chan->srej_q, skb); + skb_in_use = true; + BT_DBG("Queued %p (queue len %d)", skb, + skb_queue_len(&chan->srej_q)); + + chan->expected_tx_seq = __next_seq(chan, txseq); + break; + case L2CAP_TXSEQ_EXPECTED_SREJ: + l2cap_seq_list_pop(&chan->srej_list); + + l2cap_pass_to_tx(chan, control); + skb_queue_tail(&chan->srej_q, skb); + skb_in_use = true; + BT_DBG("Queued %p (queue len %d)", skb, + skb_queue_len(&chan->srej_q)); + + err = l2cap_rx_queued_iframes(chan); + if (err) + break; + + break; + 
case L2CAP_TXSEQ_UNEXPECTED: + /* Got a frame that can't be reassembled yet. + * Save it for later, and send SREJs to cover + * the missing frames. + */ + skb_queue_tail(&chan->srej_q, skb); + skb_in_use = true; + BT_DBG("Queued %p (queue len %d)", skb, + skb_queue_len(&chan->srej_q)); + + l2cap_pass_to_tx(chan, control); + l2cap_send_srej(chan, control->txseq); + break; + case L2CAP_TXSEQ_UNEXPECTED_SREJ: + /* This frame was requested with an SREJ, but + * some expected retransmitted frames are + * missing. Request retransmission of missing + * SREJ'd frames. + */ + skb_queue_tail(&chan->srej_q, skb); + skb_in_use = true; + BT_DBG("Queued %p (queue len %d)", skb, + skb_queue_len(&chan->srej_q)); + + l2cap_pass_to_tx(chan, control); + l2cap_send_srej_list(chan, control->txseq); + break; + case L2CAP_TXSEQ_DUPLICATE_SREJ: + /* We've already queued this frame. Drop this copy. */ + l2cap_pass_to_tx(chan, control); + break; + case L2CAP_TXSEQ_DUPLICATE: + /* Expecting a later sequence number, so this frame + * was already received. Ignore it completely. + */ + break; + case L2CAP_TXSEQ_INVALID_IGNORE: + break; + case L2CAP_TXSEQ_INVALID: + default: + l2cap_send_disconn_req(chan, ECONNRESET); + break; + } + break; + case L2CAP_EV_RECV_RR: + l2cap_pass_to_tx(chan, control); + if (control->final) { + clear_bit(CONN_REMOTE_BUSY, &chan->conn_state); + + if (!test_and_clear_bit(CONN_REJ_ACT, + &chan->conn_state)) { + control->final = 0; + l2cap_retransmit_all(chan, control); + } + + l2cap_ertm_send(chan); + } else if (control->poll) { + if (test_and_clear_bit(CONN_REMOTE_BUSY, + &chan->conn_state) && + chan->unacked_frames) { + __set_retrans_timer(chan); + } + + set_bit(CONN_SEND_FBIT, &chan->conn_state); + l2cap_send_srej_tail(chan); + } else { + if (test_and_clear_bit(CONN_REMOTE_BUSY, + &chan->conn_state) && + chan->unacked_frames) + __set_retrans_timer(chan); + + l2cap_send_ack(chan); + } + break; + case L2CAP_EV_RECV_RNR: + set_bit(CONN_REMOTE_BUSY, &chan->conn_state); + l2cap_pass_to_tx(chan, control); + if (control->poll) { + l2cap_send_srej_tail(chan); + } else { + struct l2cap_ctrl rr_control; + memset(&rr_control, 0, sizeof(rr_control)); + rr_control.sframe = 1; + rr_control.super = L2CAP_SUPER_RR; + rr_control.reqseq = chan->buffer_seq; + l2cap_send_sframe(chan, &rr_control); + } + + break; + case L2CAP_EV_RECV_REJ: + l2cap_handle_rej(chan, control); + break; + case L2CAP_EV_RECV_SREJ: + l2cap_handle_srej(chan, control); + break; + } + + if (skb && !skb_in_use) { + BT_DBG("Freeing %p", skb); + kfree_skb(skb); + } + + return err; +} + +static int l2cap_finish_move(struct l2cap_chan *chan) +{ + BT_DBG("chan %p", chan); + + chan->rx_state = L2CAP_RX_STATE_RECV; + + if (chan->hs_hcon) + chan->conn->mtu = chan->hs_hcon->hdev->block_mtu; + else + chan->conn->mtu = chan->conn->hcon->hdev->acl_mtu; + + return l2cap_resegment(chan); +} + +static int l2cap_rx_state_wait_p(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff *skb, u8 event) +{ + int err; + + BT_DBG("chan %p, control %p, skb %p, event %d", chan, control, skb, + event); + + if (!control->poll) + return -EPROTO; + + l2cap_process_reqseq(chan, control->reqseq); + + if (!skb_queue_empty(&chan->tx_q)) + chan->tx_send_head = skb_peek(&chan->tx_q); + else + chan->tx_send_head = NULL; + + /* Rewind next_tx_seq to the point expected + * by the receiver. 
+ */ + chan->next_tx_seq = control->reqseq; + chan->unacked_frames = 0; + + err = l2cap_finish_move(chan); + if (err) + return err; + + set_bit(CONN_SEND_FBIT, &chan->conn_state); + l2cap_send_i_or_rr_or_rnr(chan); + + if (event == L2CAP_EV_RECV_IFRAME) + return -EPROTO; + + return l2cap_rx_state_recv(chan, control, NULL, event); +} + +static int l2cap_rx_state_wait_f(struct l2cap_chan *chan, + struct l2cap_ctrl *control, + struct sk_buff *skb, u8 event) +{ + int err; + + if (!control->final) + return -EPROTO; + + clear_bit(CONN_REMOTE_BUSY, &chan->conn_state); + + chan->rx_state = L2CAP_RX_STATE_RECV; + l2cap_process_reqseq(chan, control->reqseq); + + if (!skb_queue_empty(&chan->tx_q)) + chan->tx_send_head = skb_peek(&chan->tx_q); + else + chan->tx_send_head = NULL; + + /* Rewind next_tx_seq to the point expected + * by the receiver. + */ + chan->next_tx_seq = control->reqseq; + chan->unacked_frames = 0; + + if (chan->hs_hcon) + chan->conn->mtu = chan->hs_hcon->hdev->block_mtu; + else + chan->conn->mtu = chan->conn->hcon->hdev->acl_mtu; + + err = l2cap_resegment(chan); + + if (!err) + err = l2cap_rx_state_recv(chan, control, skb, event); + + return err; +} + +static bool __valid_reqseq(struct l2cap_chan *chan, u16 reqseq) +{ + /* Make sure reqseq is for a packet that has been sent but not acked */ + u16 unacked; + + unacked = __seq_offset(chan, chan->next_tx_seq, chan->expected_ack_seq); + return __seq_offset(chan, chan->next_tx_seq, reqseq) <= unacked; +} + +static int l2cap_rx(struct l2cap_chan *chan, struct l2cap_ctrl *control, + struct sk_buff *skb, u8 event) +{ + int err = 0; + + BT_DBG("chan %p, control %p, skb %p, event %d, state %d", chan, + control, skb, event, chan->rx_state); + + if (__valid_reqseq(chan, control->reqseq)) { + switch (chan->rx_state) { + case L2CAP_RX_STATE_RECV: + err = l2cap_rx_state_recv(chan, control, skb, event); + break; + case L2CAP_RX_STATE_SREJ_SENT: + err = l2cap_rx_state_srej_sent(chan, control, skb, + event); + break; + case L2CAP_RX_STATE_WAIT_P: + err = l2cap_rx_state_wait_p(chan, control, skb, event); + break; + case L2CAP_RX_STATE_WAIT_F: + err = l2cap_rx_state_wait_f(chan, control, skb, event); + break; + default: + /* shut it down */ + break; + } + } else { + BT_DBG("Invalid reqseq %d (next_tx_seq %d, expected_ack_seq %d", + control->reqseq, chan->next_tx_seq, + chan->expected_ack_seq); + l2cap_send_disconn_req(chan, ECONNRESET); + } + + return err; +} + +static int l2cap_stream_rx(struct l2cap_chan *chan, struct l2cap_ctrl *control, + struct sk_buff *skb) +{ + /* l2cap_reassemble_sdu may free skb, hence invalidate control, so store + * the txseq field in advance to use it after l2cap_reassemble_sdu + * returns and to avoid the race condition, for example: + * + * The current thread calls: + * l2cap_reassemble_sdu + * chan->ops->recv == l2cap_sock_recv_cb + * __sock_queue_rcv_skb + * Another thread calls: + * bt_sock_recvmsg + * skb_recv_datagram + * skb_free_datagram + * Then the current thread tries to access control, but it was freed by + * skb_free_datagram. 
+ */ + u16 txseq = control->txseq; + + BT_DBG("chan %p, control %p, skb %p, state %d", chan, control, skb, + chan->rx_state); + + if (l2cap_classify_txseq(chan, txseq) == L2CAP_TXSEQ_EXPECTED) { + l2cap_pass_to_tx(chan, control); + + BT_DBG("buffer_seq %u->%u", chan->buffer_seq, + __next_seq(chan, chan->buffer_seq)); + + chan->buffer_seq = __next_seq(chan, chan->buffer_seq); + + l2cap_reassemble_sdu(chan, skb, control); + } else { + if (chan->sdu) { + kfree_skb(chan->sdu); + chan->sdu = NULL; + } + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + + if (skb) { + BT_DBG("Freeing %p", skb); + kfree_skb(skb); + } + } + + chan->last_acked_seq = txseq; + chan->expected_tx_seq = __next_seq(chan, txseq); + + return 0; +} + +static int l2cap_data_rcv(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct l2cap_ctrl *control = &bt_cb(skb)->l2cap; + u16 len; + u8 event; + + __unpack_control(chan, skb); + + len = skb->len; + + /* + * We can just drop the corrupted I-frame here. + * Receiver will miss it and start proper recovery + * procedures and ask for retransmission. + */ + if (l2cap_check_fcs(chan, skb)) + goto drop; + + if (!control->sframe && control->sar == L2CAP_SAR_START) + len -= L2CAP_SDULEN_SIZE; + + if (chan->fcs == L2CAP_FCS_CRC16) + len -= L2CAP_FCS_SIZE; + + if (len > chan->mps) { + l2cap_send_disconn_req(chan, ECONNRESET); + goto drop; + } + + if (chan->ops->filter) { + if (chan->ops->filter(chan, skb)) + goto drop; + } + + if (!control->sframe) { + int err; + + BT_DBG("iframe sar %d, reqseq %d, final %d, txseq %d", + control->sar, control->reqseq, control->final, + control->txseq); + + /* Validate F-bit - F=0 always valid, F=1 only + * valid in TX WAIT_F + */ + if (control->final && chan->tx_state != L2CAP_TX_STATE_WAIT_F) + goto drop; + + if (chan->mode != L2CAP_MODE_STREAMING) { + event = L2CAP_EV_RECV_IFRAME; + err = l2cap_rx(chan, control, skb, event); + } else { + err = l2cap_stream_rx(chan, control, skb); + } + + if (err) + l2cap_send_disconn_req(chan, ECONNRESET); + } else { + const u8 rx_func_to_event[4] = { + L2CAP_EV_RECV_RR, L2CAP_EV_RECV_REJ, + L2CAP_EV_RECV_RNR, L2CAP_EV_RECV_SREJ + }; + + /* Only I-frames are expected in streaming mode */ + if (chan->mode == L2CAP_MODE_STREAMING) + goto drop; + + BT_DBG("sframe reqseq %d, final %d, poll %d, super %d", + control->reqseq, control->final, control->poll, + control->super); + + if (len != 0) { + BT_ERR("Trailing bytes: %d in sframe", len); + l2cap_send_disconn_req(chan, ECONNRESET); + goto drop; + } + + /* Validate F and P bits */ + if (control->final && (control->poll || + chan->tx_state != L2CAP_TX_STATE_WAIT_F)) + goto drop; + + event = rx_func_to_event[control->super]; + if (l2cap_rx(chan, control, skb, event)) + l2cap_send_disconn_req(chan, ECONNRESET); + } + + return 0; + +drop: + kfree_skb(skb); + return 0; +} + +static void l2cap_chan_le_send_credits(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct l2cap_le_credits pkt; + u16 return_credits; + + return_credits = (chan->imtu / chan->mps) + 1; + + if (chan->rx_credits >= return_credits) + return; + + return_credits -= chan->rx_credits; + + BT_DBG("chan %p returning %u credits to sender", chan, return_credits); + + chan->rx_credits += return_credits; + + pkt.cid = cpu_to_le16(chan->scid); + pkt.credits = cpu_to_le16(return_credits); + + chan->ident = l2cap_get_ident(conn); + + l2cap_send_cmd(conn, chan->ident, L2CAP_LE_CREDITS, sizeof(pkt), &pkt); +} + +static int l2cap_ecred_recv(struct l2cap_chan *chan, struct sk_buff *skb) +{ + int 
err; + + BT_DBG("SDU reassemble complete: chan %p skb->len %u", chan, skb->len); + + /* Wait recv to confirm reception before updating the credits */ + err = chan->ops->recv(chan, skb); + + /* Update credits whenever an SDU is received */ + l2cap_chan_le_send_credits(chan); + + return err; +} + +static int l2cap_ecred_data_rcv(struct l2cap_chan *chan, struct sk_buff *skb) +{ + int err; + + if (!chan->rx_credits) { + BT_ERR("No credits to receive LE L2CAP data"); + l2cap_send_disconn_req(chan, ECONNRESET); + return -ENOBUFS; + } + + if (chan->imtu < skb->len) { + BT_ERR("Too big LE L2CAP PDU"); + return -ENOBUFS; + } + + chan->rx_credits--; + BT_DBG("rx_credits %u -> %u", chan->rx_credits + 1, chan->rx_credits); + + /* Update if remote had run out of credits, this should only happens + * if the remote is not using the entire MPS. + */ + if (!chan->rx_credits) + l2cap_chan_le_send_credits(chan); + + err = 0; + + if (!chan->sdu) { + u16 sdu_len; + + sdu_len = get_unaligned_le16(skb->data); + skb_pull(skb, L2CAP_SDULEN_SIZE); + + BT_DBG("Start of new SDU. sdu_len %u skb->len %u imtu %u", + sdu_len, skb->len, chan->imtu); + + if (sdu_len > chan->imtu) { + BT_ERR("Too big LE L2CAP SDU length received"); + err = -EMSGSIZE; + goto failed; + } + + if (skb->len > sdu_len) { + BT_ERR("Too much LE L2CAP data received"); + err = -EINVAL; + goto failed; + } + + if (skb->len == sdu_len) + return l2cap_ecred_recv(chan, skb); + + chan->sdu = skb; + chan->sdu_len = sdu_len; + chan->sdu_last_frag = skb; + + /* Detect if remote is not able to use the selected MPS */ + if (skb->len + L2CAP_SDULEN_SIZE < chan->mps) { + u16 mps_len = skb->len + L2CAP_SDULEN_SIZE; + + /* Adjust the number of credits */ + BT_DBG("chan->mps %u -> %u", chan->mps, mps_len); + chan->mps = mps_len; + l2cap_chan_le_send_credits(chan); + } + + return 0; + } + + BT_DBG("SDU fragment. chan->sdu->len %u skb->len %u chan->sdu_len %u", + chan->sdu->len, skb->len, chan->sdu_len); + + if (chan->sdu->len + skb->len > chan->sdu_len) { + BT_ERR("Too much LE L2CAP data received"); + err = -EINVAL; + goto failed; + } + + append_skb_frag(chan->sdu, skb, &chan->sdu_last_frag); + skb = NULL; + + if (chan->sdu->len == chan->sdu_len) { + err = l2cap_ecred_recv(chan, chan->sdu); + if (!err) { + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + } + } + +failed: + if (err) { + kfree_skb(skb); + kfree_skb(chan->sdu); + chan->sdu = NULL; + chan->sdu_last_frag = NULL; + chan->sdu_len = 0; + } + + /* We can't return an error here since we took care of the skb + * freeing internally. An error return would cause the caller to + * do a double-free of the skb. + */ + return 0; +} + +static void l2cap_data_channel(struct l2cap_conn *conn, u16 cid, + struct sk_buff *skb) +{ + struct l2cap_chan *chan; + + chan = l2cap_get_chan_by_scid(conn, cid); + if (!chan) { + if (cid == L2CAP_CID_A2MP) { + chan = a2mp_channel_create(conn, skb); + if (!chan) { + kfree_skb(skb); + return; + } + + l2cap_chan_hold(chan); + l2cap_chan_lock(chan); + } else { + BT_DBG("unknown cid 0x%4.4x", cid); + /* Drop packet and return */ + kfree_skb(skb); + return; + } + } + + BT_DBG("chan %p, len %d", chan, skb->len); + + /* If we receive data on a fixed channel before the info req/rsp + * procedure is done simply assume that the channel is supported + * and mark it as ready. 
+ */ + if (chan->chan_type == L2CAP_CHAN_FIXED) + l2cap_chan_ready(chan); + + if (chan->state != BT_CONNECTED) + goto drop; + + switch (chan->mode) { + case L2CAP_MODE_LE_FLOWCTL: + case L2CAP_MODE_EXT_FLOWCTL: + if (l2cap_ecred_data_rcv(chan, skb) < 0) + goto drop; + + goto done; + + case L2CAP_MODE_BASIC: + /* If socket recv buffers overflows we drop data here + * which is *bad* because L2CAP has to be reliable. + * But we don't have any other choice. L2CAP doesn't + * provide flow control mechanism. */ + + if (chan->imtu < skb->len) { + BT_ERR("Dropping L2CAP data: receive buffer overflow"); + goto drop; + } + + if (!chan->ops->recv(chan, skb)) + goto done; + break; + + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + l2cap_data_rcv(chan, skb); + goto done; + + default: + BT_DBG("chan %p: bad mode 0x%2.2x", chan, chan->mode); + break; + } + +drop: + kfree_skb(skb); + +done: + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); +} + +static void l2cap_conless_channel(struct l2cap_conn *conn, __le16 psm, + struct sk_buff *skb) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan; + + if (hcon->type != ACL_LINK) + goto free_skb; + + chan = l2cap_global_chan_by_psm(0, psm, &hcon->src, &hcon->dst, + ACL_LINK); + if (!chan) + goto free_skb; + + BT_DBG("chan %p, len %d", chan, skb->len); + + if (chan->state != BT_BOUND && chan->state != BT_CONNECTED) + goto drop; + + if (chan->imtu < skb->len) + goto drop; + + /* Store remote BD_ADDR and PSM for msg_name */ + bacpy(&bt_cb(skb)->l2cap.bdaddr, &hcon->dst); + bt_cb(skb)->l2cap.psm = psm; + + if (!chan->ops->recv(chan, skb)) { + l2cap_chan_put(chan); + return; + } + +drop: + l2cap_chan_put(chan); +free_skb: + kfree_skb(skb); +} + +static void l2cap_recv_frame(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct l2cap_hdr *lh = (void *) skb->data; + struct hci_conn *hcon = conn->hcon; + u16 cid, len; + __le16 psm; + + if (hcon->state != BT_CONNECTED) { + BT_DBG("queueing pending rx skb"); + skb_queue_tail(&conn->pending_rx, skb); + return; + } + + skb_pull(skb, L2CAP_HDR_SIZE); + cid = __le16_to_cpu(lh->cid); + len = __le16_to_cpu(lh->len); + + if (len != skb->len) { + kfree_skb(skb); + return; + } + + /* Since we can't actively block incoming LE connections we must + * at least ensure that we ignore incoming data from them. 
+ */ + if (hcon->type == LE_LINK && + hci_bdaddr_list_lookup(&hcon->hdev->reject_list, &hcon->dst, + bdaddr_dst_type(hcon))) { + kfree_skb(skb); + return; + } + + BT_DBG("len %d, cid 0x%4.4x", len, cid); + + switch (cid) { + case L2CAP_CID_SIGNALING: + l2cap_sig_channel(conn, skb); + break; + + case L2CAP_CID_CONN_LESS: + psm = get_unaligned((__le16 *) skb->data); + skb_pull(skb, L2CAP_PSMLEN_SIZE); + l2cap_conless_channel(conn, psm, skb); + break; + + case L2CAP_CID_LE_SIGNALING: + l2cap_le_sig_channel(conn, skb); + break; + + default: + l2cap_data_channel(conn, cid, skb); + break; + } +} + +static void process_pending_rx(struct work_struct *work) +{ + struct l2cap_conn *conn = container_of(work, struct l2cap_conn, + pending_rx_work); + struct sk_buff *skb; + + BT_DBG(""); + + while ((skb = skb_dequeue(&conn->pending_rx))) + l2cap_recv_frame(conn, skb); +} + +static struct l2cap_conn *l2cap_conn_add(struct hci_conn *hcon) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + struct hci_chan *hchan; + + if (conn) + return conn; + + hchan = hci_chan_create(hcon); + if (!hchan) + return NULL; + + conn = kzalloc(sizeof(*conn), GFP_KERNEL); + if (!conn) { + hci_chan_del(hchan); + return NULL; + } + + kref_init(&conn->ref); + hcon->l2cap_data = conn; + conn->hcon = hci_conn_get(hcon); + conn->hchan = hchan; + + BT_DBG("hcon %p conn %p hchan %p", hcon, conn, hchan); + + switch (hcon->type) { + case LE_LINK: + if (hcon->hdev->le_mtu) { + conn->mtu = hcon->hdev->le_mtu; + break; + } + fallthrough; + default: + conn->mtu = hcon->hdev->acl_mtu; + break; + } + + conn->feat_mask = 0; + + conn->local_fixed_chan = L2CAP_FC_SIG_BREDR | L2CAP_FC_CONNLESS; + + if (hcon->type == ACL_LINK && + hci_dev_test_flag(hcon->hdev, HCI_HS_ENABLED)) + conn->local_fixed_chan |= L2CAP_FC_A2MP; + + if (hci_dev_test_flag(hcon->hdev, HCI_LE_ENABLED) && + (bredr_sc_enabled(hcon->hdev) || + hci_dev_test_flag(hcon->hdev, HCI_FORCE_BREDR_SMP))) + conn->local_fixed_chan |= L2CAP_FC_SMP_BREDR; + + mutex_init(&conn->ident_lock); + mutex_init(&conn->chan_lock); + + INIT_LIST_HEAD(&conn->chan_l); + INIT_LIST_HEAD(&conn->users); + + INIT_DELAYED_WORK(&conn->info_timer, l2cap_info_timeout); + + skb_queue_head_init(&conn->pending_rx); + INIT_WORK(&conn->pending_rx_work, process_pending_rx); + INIT_WORK(&conn->id_addr_update_work, l2cap_conn_update_id_addr); + + conn->disc_reason = HCI_ERROR_REMOTE_USER_TERM; + + return conn; +} + +static bool is_valid_psm(u16 psm, u8 dst_type) +{ + if (!psm) + return false; + + if (bdaddr_type_is_le(dst_type)) + return (psm <= 0x00ff); + + /* PSM must be odd and lsb of upper byte must be 0 */ + return ((psm & 0x0101) == 0x0001); +} + +struct l2cap_chan_data { + struct l2cap_chan *chan; + struct pid *pid; + int count; +}; + +static void l2cap_chan_by_pid(struct l2cap_chan *chan, void *data) +{ + struct l2cap_chan_data *d = data; + struct pid *pid; + + if (chan == d->chan) + return; + + if (!test_bit(FLAG_DEFER_SETUP, &chan->flags)) + return; + + pid = chan->ops->get_peer_pid(chan); + + /* Only count deferred channels with the same PID/PSM */ + if (d->pid != pid || chan->psm != d->chan->psm || chan->ident || + chan->mode != L2CAP_MODE_EXT_FLOWCTL || chan->state != BT_CONNECT) + return; + + d->count++; +} + +int l2cap_chan_connect(struct l2cap_chan *chan, __le16 psm, u16 cid, + bdaddr_t *dst, u8 dst_type) +{ + struct l2cap_conn *conn; + struct hci_conn *hcon; + struct hci_dev *hdev; + int err; + + BT_DBG("%pMR -> %pMR (type %u) psm 0x%4.4x mode 0x%2.2x", &chan->src, + dst, dst_type, __le16_to_cpu(psm), 
chan->mode); + + hdev = hci_get_route(dst, &chan->src, chan->src_type); + if (!hdev) + return -EHOSTUNREACH; + + hci_dev_lock(hdev); + + if (!is_valid_psm(__le16_to_cpu(psm), dst_type) && !cid && + chan->chan_type != L2CAP_CHAN_RAW) { + err = -EINVAL; + goto done; + } + + if (chan->chan_type == L2CAP_CHAN_CONN_ORIENTED && !psm) { + err = -EINVAL; + goto done; + } + + if (chan->chan_type == L2CAP_CHAN_FIXED && !cid) { + err = -EINVAL; + goto done; + } + + switch (chan->mode) { + case L2CAP_MODE_BASIC: + break; + case L2CAP_MODE_LE_FLOWCTL: + break; + case L2CAP_MODE_EXT_FLOWCTL: + if (!enable_ecred) { + err = -EOPNOTSUPP; + goto done; + } + break; + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + if (!disable_ertm) + break; + fallthrough; + default: + err = -EOPNOTSUPP; + goto done; + } + + switch (chan->state) { + case BT_CONNECT: + case BT_CONNECT2: + case BT_CONFIG: + /* Already connecting */ + err = 0; + goto done; + + case BT_CONNECTED: + /* Already connected */ + err = -EISCONN; + goto done; + + case BT_OPEN: + case BT_BOUND: + /* Can connect */ + break; + + default: + err = -EBADFD; + goto done; + } + + /* Set destination address and psm */ + bacpy(&chan->dst, dst); + chan->dst_type = dst_type; + + chan->psm = psm; + chan->dcid = cid; + + if (bdaddr_type_is_le(dst_type)) { + /* Convert from L2CAP channel address type to HCI address type + */ + if (dst_type == BDADDR_LE_PUBLIC) + dst_type = ADDR_LE_DEV_PUBLIC; + else + dst_type = ADDR_LE_DEV_RANDOM; + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) + hcon = hci_connect_le(hdev, dst, dst_type, false, + chan->sec_level, + HCI_LE_CONN_TIMEOUT, + HCI_ROLE_SLAVE); + else + hcon = hci_connect_le_scan(hdev, dst, dst_type, + chan->sec_level, + HCI_LE_CONN_TIMEOUT, + CONN_REASON_L2CAP_CHAN); + + } else { + u8 auth_type = l2cap_get_auth_type(chan); + hcon = hci_connect_acl(hdev, dst, chan->sec_level, auth_type, + CONN_REASON_L2CAP_CHAN); + } + + if (IS_ERR(hcon)) { + err = PTR_ERR(hcon); + goto done; + } + + conn = l2cap_conn_add(hcon); + if (!conn) { + hci_conn_drop(hcon); + err = -ENOMEM; + goto done; + } + + if (chan->mode == L2CAP_MODE_EXT_FLOWCTL) { + struct l2cap_chan_data data; + + data.chan = chan; + data.pid = chan->ops->get_peer_pid(chan); + data.count = 1; + + l2cap_chan_list(conn, l2cap_chan_by_pid, &data); + + /* Check if there isn't too many channels being connected */ + if (data.count > L2CAP_ECRED_CONN_SCID_MAX) { + hci_conn_drop(hcon); + err = -EPROTO; + goto done; + } + } + + mutex_lock(&conn->chan_lock); + l2cap_chan_lock(chan); + + if (cid && __l2cap_get_chan_by_dcid(conn, cid)) { + hci_conn_drop(hcon); + err = -EBUSY; + goto chan_unlock; + } + + /* Update source addr of the socket */ + bacpy(&chan->src, &hcon->src); + chan->src_type = bdaddr_src_type(hcon); + + __l2cap_chan_add(conn, chan); + + /* l2cap_chan_add takes its own ref so we can drop this one */ + hci_conn_drop(hcon); + + l2cap_state_change(chan, BT_CONNECT); + __set_chan_timer(chan, chan->ops->get_sndtimeo(chan)); + + /* Release chan->sport so that it can be reused by other + * sockets (as it's only used for listening sockets). 
+ */ + write_lock(&chan_list_lock); + chan->sport = 0; + write_unlock(&chan_list_lock); + + if (hcon->state == BT_CONNECTED) { + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) { + __clear_chan_timer(chan); + if (l2cap_chan_check_security(chan, true)) + l2cap_state_change(chan, BT_CONNECTED); + } else + l2cap_do_start(chan); + } + + err = 0; + +chan_unlock: + l2cap_chan_unlock(chan); + mutex_unlock(&conn->chan_lock); +done: + hci_dev_unlock(hdev); + hci_dev_put(hdev); + return err; +} +EXPORT_SYMBOL_GPL(l2cap_chan_connect); + +static void l2cap_ecred_reconfigure(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct { + struct l2cap_ecred_reconf_req req; + __le16 scid; + } pdu; + + pdu.req.mtu = cpu_to_le16(chan->imtu); + pdu.req.mps = cpu_to_le16(chan->mps); + pdu.scid = cpu_to_le16(chan->scid); + + chan->ident = l2cap_get_ident(conn); + + l2cap_send_cmd(conn, chan->ident, L2CAP_ECRED_RECONF_REQ, + sizeof(pdu), &pdu); +} + +int l2cap_chan_reconfigure(struct l2cap_chan *chan, __u16 mtu) +{ + if (chan->imtu > mtu) + return -EINVAL; + + BT_DBG("chan %p mtu 0x%4.4x", chan, mtu); + + chan->imtu = mtu; + + l2cap_ecred_reconfigure(chan); + + return 0; +} + +/* ---- L2CAP interface with lower layer (HCI) ---- */ + +int l2cap_connect_ind(struct hci_dev *hdev, bdaddr_t *bdaddr) +{ + int exact = 0, lm1 = 0, lm2 = 0; + struct l2cap_chan *c; + + BT_DBG("hdev %s, bdaddr %pMR", hdev->name, bdaddr); + + /* Find listening sockets and check their link_mode */ + read_lock(&chan_list_lock); + list_for_each_entry(c, &chan_list, global_l) { + if (c->state != BT_LISTEN) + continue; + + if (!bacmp(&c->src, &hdev->bdaddr)) { + lm1 |= HCI_LM_ACCEPT; + if (test_bit(FLAG_ROLE_SWITCH, &c->flags)) + lm1 |= HCI_LM_MASTER; + exact++; + } else if (!bacmp(&c->src, BDADDR_ANY)) { + lm2 |= HCI_LM_ACCEPT; + if (test_bit(FLAG_ROLE_SWITCH, &c->flags)) + lm2 |= HCI_LM_MASTER; + } + } + read_unlock(&chan_list_lock); + + return exact ? lm1 : lm2; +} + +/* Find the next fixed channel in BT_LISTEN state, continue iteration + * from an existing channel in the list or from the beginning of the + * global list (by passing NULL as first parameter). 
+ */ +static struct l2cap_chan *l2cap_global_fixed_chan(struct l2cap_chan *c, + struct hci_conn *hcon) +{ + u8 src_type = bdaddr_src_type(hcon); + + read_lock(&chan_list_lock); + + if (c) + c = list_next_entry(c, global_l); + else + c = list_entry(chan_list.next, typeof(*c), global_l); + + list_for_each_entry_from(c, &chan_list, global_l) { + if (c->chan_type != L2CAP_CHAN_FIXED) + continue; + if (c->state != BT_LISTEN) + continue; + if (bacmp(&c->src, &hcon->src) && bacmp(&c->src, BDADDR_ANY)) + continue; + if (src_type != c->src_type) + continue; + + c = l2cap_chan_hold_unless_zero(c); + read_unlock(&chan_list_lock); + return c; + } + + read_unlock(&chan_list_lock); + + return NULL; +} + +static void l2cap_connect_cfm(struct hci_conn *hcon, u8 status) +{ + struct hci_dev *hdev = hcon->hdev; + struct l2cap_conn *conn; + struct l2cap_chan *pchan; + u8 dst_type; + + if (hcon->type != ACL_LINK && hcon->type != LE_LINK) + return; + + BT_DBG("hcon %p bdaddr %pMR status %d", hcon, &hcon->dst, status); + + if (status) { + l2cap_conn_del(hcon, bt_to_errno(status)); + return; + } + + conn = l2cap_conn_add(hcon); + if (!conn) + return; + + dst_type = bdaddr_dst_type(hcon); + + /* If device is blocked, do not create channels for it */ + if (hci_bdaddr_list_lookup(&hdev->reject_list, &hcon->dst, dst_type)) + return; + + /* Find fixed channels and notify them of the new connection. We + * use multiple individual lookups, continuing each time where + * we left off, because the list lock would prevent calling the + * potentially sleeping l2cap_chan_lock() function. + */ + pchan = l2cap_global_fixed_chan(NULL, hcon); + while (pchan) { + struct l2cap_chan *chan, *next; + + /* Client fixed channels should override server ones */ + if (__l2cap_get_chan_by_dcid(conn, pchan->scid)) + goto next; + + l2cap_chan_lock(pchan); + chan = pchan->ops->new_connection(pchan); + if (chan) { + bacpy(&chan->src, &hcon->src); + bacpy(&chan->dst, &hcon->dst); + chan->src_type = bdaddr_src_type(hcon); + chan->dst_type = dst_type; + + __l2cap_chan_add(conn, chan); + } + + l2cap_chan_unlock(pchan); +next: + next = l2cap_global_fixed_chan(pchan, hcon); + l2cap_chan_put(pchan); + pchan = next; + } + + l2cap_conn_ready(conn); +} + +int l2cap_disconn_ind(struct hci_conn *hcon) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + + BT_DBG("hcon %p", hcon); + + if (!conn) + return HCI_ERROR_REMOTE_USER_TERM; + return conn->disc_reason; +} + +static void l2cap_disconn_cfm(struct hci_conn *hcon, u8 reason) +{ + if (hcon->type != ACL_LINK && hcon->type != LE_LINK) + return; + + BT_DBG("hcon %p reason %d", hcon, reason); + + l2cap_conn_del(hcon, bt_to_errno(reason)); +} + +static inline void l2cap_check_encryption(struct l2cap_chan *chan, u8 encrypt) +{ + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) + return; + + if (encrypt == 0x00) { + if (chan->sec_level == BT_SECURITY_MEDIUM) { + __set_chan_timer(chan, L2CAP_ENC_TIMEOUT); + } else if (chan->sec_level == BT_SECURITY_HIGH || + chan->sec_level == BT_SECURITY_FIPS) + l2cap_chan_close(chan, ECONNREFUSED); + } else { + if (chan->sec_level == BT_SECURITY_MEDIUM) + __clear_chan_timer(chan); + } +} + +static void l2cap_security_cfm(struct hci_conn *hcon, u8 status, u8 encrypt) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + struct l2cap_chan *chan; + + if (!conn) + return; + + BT_DBG("conn %p status 0x%2.2x encrypt %u", conn, status, encrypt); + + mutex_lock(&conn->chan_lock); + + list_for_each_entry(chan, &conn->chan_l, list) { + l2cap_chan_lock(chan); + + BT_DBG("chan %p scid 0x%4.4x 
state %s", chan, chan->scid, + state_to_string(chan->state)); + + if (chan->scid == L2CAP_CID_A2MP) { + l2cap_chan_unlock(chan); + continue; + } + + if (!status && encrypt) + chan->sec_level = hcon->sec_level; + + if (!__l2cap_no_conn_pending(chan)) { + l2cap_chan_unlock(chan); + continue; + } + + if (!status && (chan->state == BT_CONNECTED || + chan->state == BT_CONFIG)) { + chan->ops->resume(chan); + l2cap_check_encryption(chan, encrypt); + l2cap_chan_unlock(chan); + continue; + } + + if (chan->state == BT_CONNECT) { + if (!status && l2cap_check_enc_key_size(hcon)) + l2cap_start_connection(chan); + else + __set_chan_timer(chan, L2CAP_DISC_TIMEOUT); + } else if (chan->state == BT_CONNECT2 && + !(chan->mode == L2CAP_MODE_EXT_FLOWCTL || + chan->mode == L2CAP_MODE_LE_FLOWCTL)) { + struct l2cap_conn_rsp rsp; + __u16 res, stat; + + if (!status && l2cap_check_enc_key_size(hcon)) { + if (test_bit(FLAG_DEFER_SETUP, &chan->flags)) { + res = L2CAP_CR_PEND; + stat = L2CAP_CS_AUTHOR_PEND; + chan->ops->defer(chan); + } else { + l2cap_state_change(chan, BT_CONFIG); + res = L2CAP_CR_SUCCESS; + stat = L2CAP_CS_NO_INFO; + } + } else { + l2cap_state_change(chan, BT_DISCONN); + __set_chan_timer(chan, L2CAP_DISC_TIMEOUT); + res = L2CAP_CR_SEC_BLOCK; + stat = L2CAP_CS_NO_INFO; + } + + rsp.scid = cpu_to_le16(chan->dcid); + rsp.dcid = cpu_to_le16(chan->scid); + rsp.result = cpu_to_le16(res); + rsp.status = cpu_to_le16(stat); + l2cap_send_cmd(conn, chan->ident, L2CAP_CONN_RSP, + sizeof(rsp), &rsp); + + if (!test_bit(CONF_REQ_SENT, &chan->conf_state) && + res == L2CAP_CR_SUCCESS) { + char buf[128]; + set_bit(CONF_REQ_SENT, &chan->conf_state); + l2cap_send_cmd(conn, l2cap_get_ident(conn), + L2CAP_CONF_REQ, + l2cap_build_conf_req(chan, buf, sizeof(buf)), + buf); + chan->num_conf_req++; + } + } + + l2cap_chan_unlock(chan); + } + + mutex_unlock(&conn->chan_lock); +} + +/* Append fragment into frame respecting the maximum len of rx_skb */ +static int l2cap_recv_frag(struct l2cap_conn *conn, struct sk_buff *skb, + u16 len) +{ + if (!conn->rx_skb) { + /* Allocate skb for the complete frame (with header) */ + conn->rx_skb = bt_skb_alloc(len, GFP_KERNEL); + if (!conn->rx_skb) + return -ENOMEM; + /* Init rx_len */ + conn->rx_len = len; + } + + /* Copy as much as the rx_skb can hold */ + len = min_t(u16, len, skb->len); + skb_copy_from_linear_data(skb, skb_put(conn->rx_skb, len), len); + skb_pull(skb, len); + conn->rx_len -= len; + + return len; +} + +static int l2cap_recv_len(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct sk_buff *rx_skb; + int len; + + /* Append just enough to complete the header */ + len = l2cap_recv_frag(conn, skb, L2CAP_LEN_SIZE - conn->rx_skb->len); + + /* If header could not be read just continue */ + if (len < 0 || conn->rx_skb->len < L2CAP_LEN_SIZE) + return len; + + rx_skb = conn->rx_skb; + len = get_unaligned_le16(rx_skb->data); + + /* Check if rx_skb has enough space to received all fragments */ + if (len + (L2CAP_HDR_SIZE - L2CAP_LEN_SIZE) <= skb_tailroom(rx_skb)) { + /* Update expected len */ + conn->rx_len = len + (L2CAP_HDR_SIZE - L2CAP_LEN_SIZE); + return L2CAP_LEN_SIZE; + } + + /* Reset conn->rx_skb since it will need to be reallocated in order to + * fit all fragments. 
+ */ + conn->rx_skb = NULL; + + /* Reallocates rx_skb using the exact expected length */ + len = l2cap_recv_frag(conn, rx_skb, + len + (L2CAP_HDR_SIZE - L2CAP_LEN_SIZE)); + kfree_skb(rx_skb); + + return len; +} + +static void l2cap_recv_reset(struct l2cap_conn *conn) +{ + kfree_skb(conn->rx_skb); + conn->rx_skb = NULL; + conn->rx_len = 0; +} + +void l2cap_recv_acldata(struct hci_conn *hcon, struct sk_buff *skb, u16 flags) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + int len; + + /* For AMP controller do not create l2cap conn */ + if (!conn && hcon->hdev->dev_type != HCI_PRIMARY) + goto drop; + + if (!conn) + conn = l2cap_conn_add(hcon); + + if (!conn) + goto drop; + + BT_DBG("conn %p len %u flags 0x%x", conn, skb->len, flags); + + switch (flags) { + case ACL_START: + case ACL_START_NO_FLUSH: + case ACL_COMPLETE: + if (conn->rx_skb) { + BT_ERR("Unexpected start frame (len %d)", skb->len); + l2cap_recv_reset(conn); + l2cap_conn_unreliable(conn, ECOMM); + } + + /* Start fragment may not contain the L2CAP length so just + * copy the initial byte when that happens and use conn->mtu as + * expected length. + */ + if (skb->len < L2CAP_LEN_SIZE) { + l2cap_recv_frag(conn, skb, conn->mtu); + break; + } + + len = get_unaligned_le16(skb->data) + L2CAP_HDR_SIZE; + + if (len == skb->len) { + /* Complete frame received */ + l2cap_recv_frame(conn, skb); + return; + } + + BT_DBG("Start: total len %d, frag len %u", len, skb->len); + + if (skb->len > len) { + BT_ERR("Frame is too long (len %u, expected len %d)", + skb->len, len); + l2cap_conn_unreliable(conn, ECOMM); + goto drop; + } + + /* Append fragment into frame (with header) */ + if (l2cap_recv_frag(conn, skb, len) < 0) + goto drop; + + break; + + case ACL_CONT: + BT_DBG("Cont: frag len %u (expecting %u)", skb->len, conn->rx_len); + + if (!conn->rx_skb) { + BT_ERR("Unexpected continuation frame (len %d)", skb->len); + l2cap_conn_unreliable(conn, ECOMM); + goto drop; + } + + /* Complete the L2CAP length if it has not been read */ + if (conn->rx_skb->len < L2CAP_LEN_SIZE) { + if (l2cap_recv_len(conn, skb) < 0) { + l2cap_conn_unreliable(conn, ECOMM); + goto drop; + } + + /* Header still could not be read just continue */ + if (conn->rx_skb->len < L2CAP_LEN_SIZE) + break; + } + + if (skb->len > conn->rx_len) { + BT_ERR("Fragment is too long (len %u, expected %u)", + skb->len, conn->rx_len); + l2cap_recv_reset(conn); + l2cap_conn_unreliable(conn, ECOMM); + goto drop; + } + + /* Append fragment into frame (with header) */ + l2cap_recv_frag(conn, skb, skb->len); + + if (!conn->rx_len) { + /* Complete frame received. l2cap_recv_frame + * takes ownership of the skb so set the global + * rx_skb pointer to NULL first. 
+ */ + struct sk_buff *rx_skb = conn->rx_skb; + conn->rx_skb = NULL; + l2cap_recv_frame(conn, rx_skb); + } + break; + } + +drop: + kfree_skb(skb); +} + +static struct hci_cb l2cap_cb = { + .name = "L2CAP", + .connect_cfm = l2cap_connect_cfm, + .disconn_cfm = l2cap_disconn_cfm, + .security_cfm = l2cap_security_cfm, +}; + +static int l2cap_debugfs_show(struct seq_file *f, void *p) +{ + struct l2cap_chan *c; + + read_lock(&chan_list_lock); + + list_for_each_entry(c, &chan_list, global_l) { + seq_printf(f, "%pMR (%u) %pMR (%u) %d %d 0x%4.4x 0x%4.4x %d %d %d %d\n", + &c->src, c->src_type, &c->dst, c->dst_type, + c->state, __le16_to_cpu(c->psm), + c->scid, c->dcid, c->imtu, c->omtu, + c->sec_level, c->mode); + } + + read_unlock(&chan_list_lock); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(l2cap_debugfs); + +static struct dentry *l2cap_debugfs; + +int __init l2cap_init(void) +{ + int err; + + err = l2cap_init_sockets(); + if (err < 0) + return err; + + hci_register_cb(&l2cap_cb); + + if (IS_ERR_OR_NULL(bt_debugfs)) + return 0; + + l2cap_debugfs = debugfs_create_file("l2cap", 0444, bt_debugfs, + NULL, &l2cap_debugfs_fops); + + return 0; +} + +void l2cap_exit(void) +{ + debugfs_remove(l2cap_debugfs); + hci_unregister_cb(&l2cap_cb); + l2cap_cleanup_sockets(); +} + +module_param(disable_ertm, bool, 0644); +MODULE_PARM_DESC(disable_ertm, "Disable enhanced retransmission mode"); + +module_param(enable_ecred, bool, 0644); +MODULE_PARM_DESC(enable_ecred, "Enable enhanced credit flow control mode"); diff --git a/net/bluetooth/l2cap_sock.c b/net/bluetooth/l2cap_sock.c new file mode 100644 index 000000000..947ca580b --- /dev/null +++ b/net/bluetooth/l2cap_sock.c @@ -0,0 +1,1981 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + Copyright (C) 2009-2010 Gustavo F. Padovan + Copyright (C) 2010 Google Inc. + Copyright (C) 2011 ProFUSION Embedded Systems + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth L2CAP sockets. 
*/ + +#include +#include +#include +#include + +#include +#include +#include + +#include "smp.h" + +static struct bt_sock_list l2cap_sk_list = { + .lock = __RW_LOCK_UNLOCKED(l2cap_sk_list.lock) +}; + +static const struct proto_ops l2cap_sock_ops; +static void l2cap_sock_init(struct sock *sk, struct sock *parent); +static struct sock *l2cap_sock_alloc(struct net *net, struct socket *sock, + int proto, gfp_t prio, int kern); +static void l2cap_sock_cleanup_listen(struct sock *parent); + +bool l2cap_is_socket(struct socket *sock) +{ + return sock && sock->ops == &l2cap_sock_ops; +} +EXPORT_SYMBOL(l2cap_is_socket); + +static int l2cap_validate_bredr_psm(u16 psm) +{ + /* PSM must be odd and lsb of upper byte must be 0 */ + if ((psm & 0x0101) != 0x0001) + return -EINVAL; + + /* Restrict usage of well-known PSMs */ + if (psm < L2CAP_PSM_DYN_START && !capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + return 0; +} + +static int l2cap_validate_le_psm(u16 psm) +{ + /* Valid LE_PSM ranges are defined only until 0x00ff */ + if (psm > L2CAP_PSM_LE_DYN_END) + return -EINVAL; + + /* Restrict fixed, SIG assigned PSM values to CAP_NET_BIND_SERVICE */ + if (psm < L2CAP_PSM_LE_DYN_START && !capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + return 0; +} + +static int l2cap_sock_bind(struct socket *sock, struct sockaddr *addr, int alen) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct sockaddr_l2 la; + int len, err = 0; + + BT_DBG("sk %p", sk); + + if (!addr || alen < offsetofend(struct sockaddr, sa_family) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + memset(&la, 0, sizeof(la)); + len = min_t(unsigned int, sizeof(la), alen); + memcpy(&la, addr, len); + + if (la.l2_cid && la.l2_psm) + return -EINVAL; + + if (!bdaddr_type_is_valid(la.l2_bdaddr_type)) + return -EINVAL; + + if (bdaddr_type_is_le(la.l2_bdaddr_type)) { + /* We only allow ATT user space socket */ + if (la.l2_cid && + la.l2_cid != cpu_to_le16(L2CAP_CID_ATT)) + return -EINVAL; + } + + lock_sock(sk); + + if (sk->sk_state != BT_OPEN) { + err = -EBADFD; + goto done; + } + + if (la.l2_psm) { + __u16 psm = __le16_to_cpu(la.l2_psm); + + if (la.l2_bdaddr_type == BDADDR_BREDR) + err = l2cap_validate_bredr_psm(psm); + else + err = l2cap_validate_le_psm(psm); + + if (err) + goto done; + } + + bacpy(&chan->src, &la.l2_bdaddr); + chan->src_type = la.l2_bdaddr_type; + + if (la.l2_cid) + err = l2cap_add_scid(chan, __le16_to_cpu(la.l2_cid)); + else + err = l2cap_add_psm(chan, &la.l2_bdaddr, la.l2_psm); + + if (err < 0) + goto done; + + switch (chan->chan_type) { + case L2CAP_CHAN_CONN_LESS: + if (__le16_to_cpu(la.l2_psm) == L2CAP_PSM_3DSP) + chan->sec_level = BT_SECURITY_SDP; + break; + case L2CAP_CHAN_CONN_ORIENTED: + if (__le16_to_cpu(la.l2_psm) == L2CAP_PSM_SDP || + __le16_to_cpu(la.l2_psm) == L2CAP_PSM_RFCOMM) + chan->sec_level = BT_SECURITY_SDP; + break; + case L2CAP_CHAN_RAW: + chan->sec_level = BT_SECURITY_SDP; + break; + case L2CAP_CHAN_FIXED: + /* Fixed channels default to the L2CAP core not holding a + * hci_conn reference for them. For fixed channels mapping to + * L2CAP sockets we do want to hold a reference so set the + * appropriate flag to request it. + */ + set_bit(FLAG_HOLD_HCI_CONN, &chan->flags); + break; + } + + /* Use L2CAP_MODE_LE_FLOWCTL (CoC) in case of LE address and + * L2CAP_MODE_EXT_FLOWCTL (ECRED) has not been set. 
+ */ + if (chan->psm && bdaddr_type_is_le(chan->src_type) && + chan->mode != L2CAP_MODE_EXT_FLOWCTL) + chan->mode = L2CAP_MODE_LE_FLOWCTL; + + chan->state = BT_BOUND; + sk->sk_state = BT_BOUND; + +done: + release_sock(sk); + return err; +} + +static void l2cap_sock_init_pid(struct sock *sk) +{ + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + + /* Only L2CAP_MODE_EXT_FLOWCTL ever need to access the PID in order to + * group the channels being requested. + */ + if (chan->mode != L2CAP_MODE_EXT_FLOWCTL) + return; + + spin_lock(&sk->sk_peer_lock); + sk->sk_peer_pid = get_pid(task_tgid(current)); + spin_unlock(&sk->sk_peer_lock); +} + +static int l2cap_sock_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct sockaddr_l2 la; + int len, err = 0; + bool zapped; + + BT_DBG("sk %p", sk); + + lock_sock(sk); + zapped = sock_flag(sk, SOCK_ZAPPED); + release_sock(sk); + + if (zapped) + return -EINVAL; + + if (!addr || alen < offsetofend(struct sockaddr, sa_family) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + memset(&la, 0, sizeof(la)); + len = min_t(unsigned int, sizeof(la), alen); + memcpy(&la, addr, len); + + if (la.l2_cid && la.l2_psm) + return -EINVAL; + + if (!bdaddr_type_is_valid(la.l2_bdaddr_type)) + return -EINVAL; + + /* Check that the socket wasn't bound to something that + * conflicts with the address given to connect(). If chan->src + * is BDADDR_ANY it means bind() was never used, in which case + * chan->src_type and la.l2_bdaddr_type do not need to match. + */ + if (chan->src_type == BDADDR_BREDR && bacmp(&chan->src, BDADDR_ANY) && + bdaddr_type_is_le(la.l2_bdaddr_type)) { + /* Old user space versions will try to incorrectly bind + * the ATT socket using BDADDR_BREDR. We need to accept + * this and fix up the source address type only when + * both the source CID and destination CID indicate + * ATT. Anything else is an invalid combination. + */ + if (chan->scid != L2CAP_CID_ATT || + la.l2_cid != cpu_to_le16(L2CAP_CID_ATT)) + return -EINVAL; + + /* We don't have the hdev available here to make a + * better decision on random vs public, but since all + * user space versions that exhibit this issue anyway do + * not support random local addresses assuming public + * here is good enough. + */ + chan->src_type = BDADDR_LE_PUBLIC; + } + + if (chan->src_type != BDADDR_BREDR && la.l2_bdaddr_type == BDADDR_BREDR) + return -EINVAL; + + if (bdaddr_type_is_le(la.l2_bdaddr_type)) { + /* We only allow ATT user space socket */ + if (la.l2_cid && + la.l2_cid != cpu_to_le16(L2CAP_CID_ATT)) + return -EINVAL; + } + + /* Use L2CAP_MODE_LE_FLOWCTL (CoC) in case of LE address and + * L2CAP_MODE_EXT_FLOWCTL (ECRED) has not been set. 
+ */ + if (chan->psm && bdaddr_type_is_le(chan->src_type) && + chan->mode != L2CAP_MODE_EXT_FLOWCTL) + chan->mode = L2CAP_MODE_LE_FLOWCTL; + + l2cap_sock_init_pid(sk); + + err = l2cap_chan_connect(chan, la.l2_psm, __le16_to_cpu(la.l2_cid), + &la.l2_bdaddr, la.l2_bdaddr_type); + if (err) + return err; + + lock_sock(sk); + + err = bt_sock_wait_state(sk, BT_CONNECTED, + sock_sndtimeo(sk, flags & O_NONBLOCK)); + + release_sock(sk); + + return err; +} + +static int l2cap_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + int err = 0; + + BT_DBG("sk %p backlog %d", sk, backlog); + + lock_sock(sk); + + if (sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET && sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + goto done; + } + + switch (chan->mode) { + case L2CAP_MODE_BASIC: + case L2CAP_MODE_LE_FLOWCTL: + break; + case L2CAP_MODE_EXT_FLOWCTL: + if (!enable_ecred) { + err = -EOPNOTSUPP; + goto done; + } + break; + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + if (!disable_ertm) + break; + fallthrough; + default: + err = -EOPNOTSUPP; + goto done; + } + + l2cap_sock_init_pid(sk); + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + + /* Listening channels need to use nested locking in order not to + * cause lockdep warnings when the created child channels end up + * being locked in the same thread as the parent channel. + */ + atomic_set(&chan->nesting, L2CAP_NESTING_PARENT); + + chan->state = BT_LISTEN; + sk->sk_state = BT_LISTEN; + +done: + release_sock(sk); + return err; +} + +static int l2cap_sock_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = sock->sk, *nsk; + long timeo; + int err = 0; + + lock_sock_nested(sk, L2CAP_NESTING_PARENT); + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + BT_DBG("sk %p timeo %ld", sk, timeo); + + /* Wait for an incoming connection. (wake-one). 
*/ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (1) { + if (sk->sk_state != BT_LISTEN) { + err = -EBADFD; + break; + } + + nsk = bt_accept_dequeue(sk, newsock); + if (nsk) + break; + + if (!timeo) { + err = -EAGAIN; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + + lock_sock_nested(sk, L2CAP_NESTING_PARENT); + } + remove_wait_queue(sk_sleep(sk), &wait); + + if (err) + goto done; + + newsock->state = SS_CONNECTED; + + BT_DBG("new socket %p", nsk); + +done: + release_sock(sk); + return err; +} + +static int l2cap_sock_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sockaddr_l2 *la = (struct sockaddr_l2 *) addr; + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (peer && sk->sk_state != BT_CONNECTED && + sk->sk_state != BT_CONNECT && sk->sk_state != BT_CONNECT2 && + sk->sk_state != BT_CONFIG) + return -ENOTCONN; + + memset(la, 0, sizeof(struct sockaddr_l2)); + addr->sa_family = AF_BLUETOOTH; + + la->l2_psm = chan->psm; + + if (peer) { + bacpy(&la->l2_bdaddr, &chan->dst); + la->l2_cid = cpu_to_le16(chan->dcid); + la->l2_bdaddr_type = chan->dst_type; + } else { + bacpy(&la->l2_bdaddr, &chan->src); + la->l2_cid = cpu_to_le16(chan->scid); + la->l2_bdaddr_type = chan->src_type; + } + + return sizeof(struct sockaddr_l2); +} + +static int l2cap_get_mode(struct l2cap_chan *chan) +{ + switch (chan->mode) { + case L2CAP_MODE_BASIC: + return BT_MODE_BASIC; + case L2CAP_MODE_ERTM: + return BT_MODE_ERTM; + case L2CAP_MODE_STREAMING: + return BT_MODE_STREAMING; + case L2CAP_MODE_LE_FLOWCTL: + return BT_MODE_LE_FLOWCTL; + case L2CAP_MODE_EXT_FLOWCTL: + return BT_MODE_EXT_FLOWCTL; + } + + return -EINVAL; +} + +static int l2cap_sock_getsockopt_old(struct socket *sock, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct l2cap_options opts; + struct l2cap_conninfo cinfo; + int len, err = 0; + u32 opt; + + BT_DBG("sk %p", sk); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case L2CAP_OPTIONS: + /* LE sockets should use BT_SNDMTU/BT_RCVMTU, but since + * legacy ATT code depends on getsockopt for + * L2CAP_OPTIONS we need to let this pass. 
+ */ + if (bdaddr_type_is_le(chan->src_type) && + chan->scid != L2CAP_CID_ATT) { + err = -EINVAL; + break; + } + + /* Only BR/EDR modes are supported here */ + switch (chan->mode) { + case L2CAP_MODE_BASIC: + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + break; + default: + err = -EINVAL; + break; + } + + if (err < 0) + break; + + memset(&opts, 0, sizeof(opts)); + opts.imtu = chan->imtu; + opts.omtu = chan->omtu; + opts.flush_to = chan->flush_to; + opts.mode = chan->mode; + opts.fcs = chan->fcs; + opts.max_tx = chan->max_tx; + opts.txwin_size = chan->tx_win; + + BT_DBG("mode 0x%2.2x", chan->mode); + + len = min_t(unsigned int, len, sizeof(opts)); + if (copy_to_user(optval, (char *) &opts, len)) + err = -EFAULT; + + break; + + case L2CAP_LM: + switch (chan->sec_level) { + case BT_SECURITY_LOW: + opt = L2CAP_LM_AUTH; + break; + case BT_SECURITY_MEDIUM: + opt = L2CAP_LM_AUTH | L2CAP_LM_ENCRYPT; + break; + case BT_SECURITY_HIGH: + opt = L2CAP_LM_AUTH | L2CAP_LM_ENCRYPT | + L2CAP_LM_SECURE; + break; + case BT_SECURITY_FIPS: + opt = L2CAP_LM_AUTH | L2CAP_LM_ENCRYPT | + L2CAP_LM_SECURE | L2CAP_LM_FIPS; + break; + default: + opt = 0; + break; + } + + if (test_bit(FLAG_ROLE_SWITCH, &chan->flags)) + opt |= L2CAP_LM_MASTER; + + if (test_bit(FLAG_FORCE_RELIABLE, &chan->flags)) + opt |= L2CAP_LM_RELIABLE; + + if (put_user(opt, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case L2CAP_CONNINFO: + if (sk->sk_state != BT_CONNECTED && + !(sk->sk_state == BT_CONNECT2 && + test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags))) { + err = -ENOTCONN; + break; + } + + memset(&cinfo, 0, sizeof(cinfo)); + cinfo.hci_handle = chan->conn->hcon->handle; + memcpy(cinfo.dev_class, chan->conn->hcon->dev_class, 3); + + len = min_t(unsigned int, len, sizeof(cinfo)); + if (copy_to_user(optval, (char *) &cinfo, len)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int l2cap_sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct bt_security sec; + struct bt_power pwr; + u32 phys; + int len, mode, err = 0; + + BT_DBG("sk %p", sk); + + if (level == SOL_L2CAP) + return l2cap_sock_getsockopt_old(sock, optname, optval, optlen); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case BT_SECURITY: + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED && + chan->chan_type != L2CAP_CHAN_FIXED && + chan->chan_type != L2CAP_CHAN_RAW) { + err = -EINVAL; + break; + } + + memset(&sec, 0, sizeof(sec)); + if (chan->conn) { + sec.level = chan->conn->hcon->sec_level; + + if (sk->sk_state == BT_CONNECTED) + sec.key_size = chan->conn->hcon->enc_key_size; + } else { + sec.level = chan->sec_level; + } + + len = min_t(unsigned int, len, sizeof(sec)); + if (copy_to_user(optval, (char *) &sec, len)) + err = -EFAULT; + + break; + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (put_user(test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags), + (u32 __user *) optval)) + err = -EFAULT; + + break; + + case BT_FLUSHABLE: + if (put_user(test_bit(FLAG_FLUSHABLE, &chan->flags), + (u32 __user *) optval)) + err = -EFAULT; + + break; + + case BT_POWER: + if (sk->sk_type != SOCK_SEQPACKET && sk->sk_type != SOCK_STREAM + && sk->sk_type != SOCK_RAW) { + err = -EINVAL; + break; + } + + 
pwr.force_active = test_bit(FLAG_FORCE_ACTIVE, &chan->flags); + + len = min_t(unsigned int, len, sizeof(pwr)); + if (copy_to_user(optval, (char *) &pwr, len)) + err = -EFAULT; + + break; + + case BT_CHANNEL_POLICY: + if (put_user(chan->chan_policy, (u32 __user *) optval)) + err = -EFAULT; + break; + + case BT_SNDMTU: + if (!bdaddr_type_is_le(chan->src_type)) { + err = -EINVAL; + break; + } + + if (sk->sk_state != BT_CONNECTED) { + err = -ENOTCONN; + break; + } + + if (put_user(chan->omtu, (u16 __user *) optval)) + err = -EFAULT; + break; + + case BT_RCVMTU: + if (!bdaddr_type_is_le(chan->src_type)) { + err = -EINVAL; + break; + } + + if (put_user(chan->imtu, (u16 __user *) optval)) + err = -EFAULT; + break; + + case BT_PHY: + if (sk->sk_state != BT_CONNECTED) { + err = -ENOTCONN; + break; + } + + phys = hci_conn_get_phy(chan->conn->hcon); + + if (put_user(phys, (u32 __user *) optval)) + err = -EFAULT; + break; + + case BT_MODE: + if (!enable_ecred) { + err = -ENOPROTOOPT; + break; + } + + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) { + err = -EINVAL; + break; + } + + mode = l2cap_get_mode(chan); + if (mode < 0) { + err = mode; + break; + } + + if (put_user(mode, (u8 __user *) optval)) + err = -EFAULT; + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static bool l2cap_valid_mtu(struct l2cap_chan *chan, u16 mtu) +{ + switch (chan->scid) { + case L2CAP_CID_ATT: + if (mtu < L2CAP_LE_MIN_MTU) + return false; + break; + + default: + if (mtu < L2CAP_DEFAULT_MIN_MTU) + return false; + } + + return true; +} + +static int l2cap_sock_setsockopt_old(struct socket *sock, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct l2cap_options opts; + int len, err = 0; + u32 opt; + + BT_DBG("sk %p", sk); + + lock_sock(sk); + + switch (optname) { + case L2CAP_OPTIONS: + if (bdaddr_type_is_le(chan->src_type)) { + err = -EINVAL; + break; + } + + if (sk->sk_state == BT_CONNECTED) { + err = -EINVAL; + break; + } + + opts.imtu = chan->imtu; + opts.omtu = chan->omtu; + opts.flush_to = chan->flush_to; + opts.mode = chan->mode; + opts.fcs = chan->fcs; + opts.max_tx = chan->max_tx; + opts.txwin_size = chan->tx_win; + + len = min_t(unsigned int, sizeof(opts), optlen); + if (copy_from_sockptr(&opts, optval, len)) { + err = -EFAULT; + break; + } + + if (opts.txwin_size > L2CAP_DEFAULT_EXT_WINDOW) { + err = -EINVAL; + break; + } + + if (!l2cap_valid_mtu(chan, opts.imtu)) { + err = -EINVAL; + break; + } + + /* Only BR/EDR modes are supported here */ + switch (opts.mode) { + case L2CAP_MODE_BASIC: + clear_bit(CONF_STATE2_DEVICE, &chan->conf_state); + break; + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + if (!disable_ertm) + break; + fallthrough; + default: + err = -EINVAL; + break; + } + + if (err < 0) + break; + + chan->mode = opts.mode; + + BT_DBG("mode 0x%2.2x", chan->mode); + + chan->imtu = opts.imtu; + chan->omtu = opts.omtu; + chan->fcs = opts.fcs; + chan->max_tx = opts.max_tx; + chan->tx_win = opts.txwin_size; + chan->flush_to = opts.flush_to; + break; + + case L2CAP_LM: + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt & L2CAP_LM_FIPS) { + err = -EINVAL; + break; + } + + if (opt & L2CAP_LM_AUTH) + chan->sec_level = BT_SECURITY_LOW; + if (opt & L2CAP_LM_ENCRYPT) + chan->sec_level = BT_SECURITY_MEDIUM; + if (opt & L2CAP_LM_SECURE) + chan->sec_level = BT_SECURITY_HIGH; + + if (opt & L2CAP_LM_MASTER) + 
set_bit(FLAG_ROLE_SWITCH, &chan->flags); + else + clear_bit(FLAG_ROLE_SWITCH, &chan->flags); + + if (opt & L2CAP_LM_RELIABLE) + set_bit(FLAG_FORCE_RELIABLE, &chan->flags); + else + clear_bit(FLAG_FORCE_RELIABLE, &chan->flags); + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int l2cap_set_mode(struct l2cap_chan *chan, u8 mode) +{ + switch (mode) { + case BT_MODE_BASIC: + if (bdaddr_type_is_le(chan->src_type)) + return -EINVAL; + mode = L2CAP_MODE_BASIC; + clear_bit(CONF_STATE2_DEVICE, &chan->conf_state); + break; + case BT_MODE_ERTM: + if (!disable_ertm || bdaddr_type_is_le(chan->src_type)) + return -EINVAL; + mode = L2CAP_MODE_ERTM; + break; + case BT_MODE_STREAMING: + if (!disable_ertm || bdaddr_type_is_le(chan->src_type)) + return -EINVAL; + mode = L2CAP_MODE_STREAMING; + break; + case BT_MODE_LE_FLOWCTL: + if (!bdaddr_type_is_le(chan->src_type)) + return -EINVAL; + mode = L2CAP_MODE_LE_FLOWCTL; + break; + case BT_MODE_EXT_FLOWCTL: + /* TODO: Add support for ECRED PDUs to BR/EDR */ + if (!bdaddr_type_is_le(chan->src_type)) + return -EINVAL; + mode = L2CAP_MODE_EXT_FLOWCTL; + break; + default: + return -EINVAL; + } + + chan->mode = mode; + + return 0; +} + +static int l2cap_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + struct bt_security sec; + struct bt_power pwr; + struct l2cap_conn *conn; + int len, err = 0; + u32 opt; + u16 mtu; + u8 mode; + + BT_DBG("sk %p", sk); + + if (level == SOL_L2CAP) + return l2cap_sock_setsockopt_old(sock, optname, optval, optlen); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + lock_sock(sk); + + switch (optname) { + case BT_SECURITY: + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED && + chan->chan_type != L2CAP_CHAN_FIXED && + chan->chan_type != L2CAP_CHAN_RAW) { + err = -EINVAL; + break; + } + + sec.level = BT_SECURITY_LOW; + + len = min_t(unsigned int, sizeof(sec), optlen); + if (copy_from_sockptr(&sec, optval, len)) { + err = -EFAULT; + break; + } + + if (sec.level < BT_SECURITY_LOW || + sec.level > BT_SECURITY_FIPS) { + err = -EINVAL; + break; + } + + chan->sec_level = sec.level; + + if (!chan->conn) + break; + + conn = chan->conn; + + /* change security for LE channels */ + if (chan->scid == L2CAP_CID_ATT) { + if (smp_conn_security(conn->hcon, sec.level)) { + err = -EINVAL; + break; + } + + set_bit(FLAG_PENDING_SECURITY, &chan->flags); + sk->sk_state = BT_CONFIG; + chan->state = BT_CONFIG; + + /* or for ACL link */ + } else if ((sk->sk_state == BT_CONNECT2 && + test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) || + sk->sk_state == BT_CONNECTED) { + if (!l2cap_chan_check_security(chan, true)) + set_bit(BT_SK_SUSPEND, &bt_sk(sk)->flags); + else + sk->sk_state_change(sk); + } else { + err = -EINVAL; + } + break; + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt) { + set_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + set_bit(FLAG_DEFER_SETUP, &chan->flags); + } else { + clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + clear_bit(FLAG_DEFER_SETUP, &chan->flags); + } + break; + + case BT_FLUSHABLE: + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt > BT_FLUSHABLE_ON) { + err = -EINVAL; + break; + } + + if (opt == BT_FLUSHABLE_OFF) { + conn = 
chan->conn; + /* proceed further only when we have l2cap_conn and + No Flush support in the LM */ + if (!conn || !lmp_no_flush_capable(conn->hcon->hdev)) { + err = -EINVAL; + break; + } + } + + if (opt) + set_bit(FLAG_FLUSHABLE, &chan->flags); + else + clear_bit(FLAG_FLUSHABLE, &chan->flags); + break; + + case BT_POWER: + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED && + chan->chan_type != L2CAP_CHAN_RAW) { + err = -EINVAL; + break; + } + + pwr.force_active = BT_POWER_FORCE_ACTIVE_ON; + + len = min_t(unsigned int, sizeof(pwr), optlen); + if (copy_from_sockptr(&pwr, optval, len)) { + err = -EFAULT; + break; + } + + if (pwr.force_active) + set_bit(FLAG_FORCE_ACTIVE, &chan->flags); + else + clear_bit(FLAG_FORCE_ACTIVE, &chan->flags); + break; + + case BT_CHANNEL_POLICY: + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt > BT_CHANNEL_POLICY_AMP_PREFERRED) { + err = -EINVAL; + break; + } + + if (chan->mode != L2CAP_MODE_ERTM && + chan->mode != L2CAP_MODE_STREAMING) { + err = -EOPNOTSUPP; + break; + } + + chan->chan_policy = (u8) opt; + + if (sk->sk_state == BT_CONNECTED && + chan->move_role == L2CAP_MOVE_ROLE_NONE) + l2cap_move_start(chan); + + break; + + case BT_SNDMTU: + if (!bdaddr_type_is_le(chan->src_type)) { + err = -EINVAL; + break; + } + + /* Setting is not supported as it's the remote side that + * decides this. + */ + err = -EPERM; + break; + + case BT_RCVMTU: + if (!bdaddr_type_is_le(chan->src_type)) { + err = -EINVAL; + break; + } + + if (chan->mode == L2CAP_MODE_LE_FLOWCTL && + sk->sk_state == BT_CONNECTED) { + err = -EISCONN; + break; + } + + if (copy_from_sockptr(&mtu, optval, sizeof(u16))) { + err = -EFAULT; + break; + } + + if (chan->mode == L2CAP_MODE_EXT_FLOWCTL && + sk->sk_state == BT_CONNECTED) + err = l2cap_chan_reconfigure(chan, mtu); + else + chan->imtu = mtu; + + break; + + case BT_MODE: + if (!enable_ecred) { + err = -ENOPROTOOPT; + break; + } + + BT_DBG("sk->sk_state %u", sk->sk_state); + + if (sk->sk_state != BT_BOUND) { + err = -EINVAL; + break; + } + + if (chan->chan_type != L2CAP_CHAN_CONN_ORIENTED) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&mode, optval, sizeof(u8))) { + err = -EFAULT; + break; + } + + BT_DBG("mode %u", mode); + + err = l2cap_set_mode(chan, mode); + if (err) + break; + + BT_DBG("mode 0x%2.2x", chan->mode); + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int l2cap_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + int err; + + BT_DBG("sock %p, sk %p", sock, sk); + + err = sock_error(sk); + if (err) + return err; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + if (sk->sk_state != BT_CONNECTED) + return -ENOTCONN; + + lock_sock(sk); + err = bt_sock_wait_ready(sk, msg->msg_flags); + release_sock(sk); + if (err) + return err; + + l2cap_chan_lock(chan); + err = l2cap_chan_send(chan, msg, len); + l2cap_chan_unlock(chan); + + return err; +} + +static int l2cap_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct l2cap_pinfo *pi = l2cap_pi(sk); + int err; + + lock_sock(sk); + + if (sk->sk_state == BT_CONNECT2 && test_bit(BT_SK_DEFER_SETUP, + &bt_sk(sk)->flags)) { + if (pi->chan->mode == L2CAP_MODE_EXT_FLOWCTL) { + sk->sk_state = BT_CONNECTED; + pi->chan->state = BT_CONNECTED; + __l2cap_ecred_conn_rsp_defer(pi->chan); + } else if 
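The BT_MODE and BT_RCVMTU handling above dictates the order a client has to follow for an enhanced-credit (ECRED) LE channel: BT_MODE is only accepted while the socket is bound but not yet connected, and a BT_RCVMTU set before connect() simply becomes the initial incoming MTU. A hedged sketch of that sequence; it assumes a kernel with ECRED support enabled (the enable_ecred switch checked above), recent BlueZ headers for BT_MODE_EXT_FLOWCTL, and caller-supplied addresses:

#include <stdint.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>        /* BT_MODE, BT_MODE_EXT_FLOWCTL, BT_RCVMTU, bacpy(), htobs() */
#include <bluetooth/l2cap.h>            /* struct sockaddr_l2 */

/* Connect an LE channel in extended flow control mode with a 512 byte
 * receive MTU. BT_MODE is only honoured while the socket sits in BT_BOUND,
 * hence bind() first, then the two setsockopt() calls, then connect().
 */
static int connect_ecred(const bdaddr_t *src, const bdaddr_t *dst, uint16_t psm)
{
        struct sockaddr_l2 addr;
        uint8_t mode = BT_MODE_EXT_FLOWCTL;
        uint16_t mtu = 512;
        int fd;

        fd = socket(AF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_L2CAP);
        if (fd < 0)
                return -1;

        memset(&addr, 0, sizeof(addr));
        addr.l2_family = AF_BLUETOOTH;
        addr.l2_bdaddr_type = BDADDR_LE_PUBLIC;
        bacpy(&addr.l2_bdaddr, src);
        if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                goto fail;

        /* rejected with ENOPROTOOPT unless ECRED support (enable_ecred) is on */
        if (setsockopt(fd, SOL_BLUETOOTH, BT_MODE, &mode, sizeof(mode)) < 0)
                goto fail;

        /* before connect() this simply becomes chan->imtu */
        if (setsockopt(fd, SOL_BLUETOOTH, BT_RCVMTU, &mtu, sizeof(mtu)) < 0)
                goto fail;

        memset(&addr, 0, sizeof(addr));
        addr.l2_family = AF_BLUETOOTH;
        addr.l2_bdaddr_type = BDADDR_LE_PUBLIC;
        addr.l2_psm = htobs(psm);
        bacpy(&addr.l2_bdaddr, dst);
        if (connect(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                goto fail;

        return fd;

fail:
        close(fd);
        return -1;
}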
(bdaddr_type_is_le(pi->chan->src_type)) { + sk->sk_state = BT_CONNECTED; + pi->chan->state = BT_CONNECTED; + __l2cap_le_connect_rsp_defer(pi->chan); + } else { + sk->sk_state = BT_CONFIG; + pi->chan->state = BT_CONFIG; + __l2cap_connect_rsp_defer(pi->chan); + } + + err = 0; + goto done; + } + + release_sock(sk); + + if (sock->type == SOCK_STREAM) + err = bt_sock_stream_recvmsg(sock, msg, len, flags); + else + err = bt_sock_recvmsg(sock, msg, len, flags); + + if (pi->chan->mode != L2CAP_MODE_ERTM) + return err; + + /* Attempt to put pending rx data in the socket buffer */ + + lock_sock(sk); + + if (!test_bit(CONN_LOCAL_BUSY, &pi->chan->conn_state)) + goto done; + + if (pi->rx_busy_skb) { + if (!__sock_queue_rcv_skb(sk, pi->rx_busy_skb)) + pi->rx_busy_skb = NULL; + else + goto done; + } + + /* Restore data flow when half of the receive buffer is + * available. This avoids resending large numbers of + * frames. + */ + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf >> 1) + l2cap_chan_busy(pi->chan, 0); + +done: + release_sock(sk); + return err; +} + +/* Kill socket (only if zapped and orphan) + * Must be called on unlocked socket, with l2cap channel lock. + */ +static void l2cap_sock_kill(struct sock *sk) +{ + if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket) + return; + + BT_DBG("sk %p state %s", sk, state_to_string(sk->sk_state)); + + /* Kill poor orphan */ + + l2cap_chan_put(l2cap_pi(sk)->chan); + sock_set_flag(sk, SOCK_DEAD); + sock_put(sk); +} + +static int __l2cap_wait_ack(struct sock *sk, struct l2cap_chan *chan) +{ + DECLARE_WAITQUEUE(wait, current); + int err = 0; + int timeo = L2CAP_WAIT_ACK_POLL_PERIOD; + /* Timeout to prevent infinite loop */ + unsigned long timeout = jiffies + L2CAP_WAIT_ACK_TIMEOUT; + + add_wait_queue(sk_sleep(sk), &wait); + set_current_state(TASK_INTERRUPTIBLE); + do { + BT_DBG("Waiting for %d ACKs, timeout %04d ms", + chan->unacked_frames, time_after(jiffies, timeout) ? 
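The BT_CONNECT2 handling at the top of l2cap_sock_recvmsg() is what finishes a connection accepted with BT_DEFER_SETUP: the connect or configuration response is held back until the first read on the accepted socket. A server-side sketch of that pattern, with error handling trimmed; the listening socket is assumed to be bound already, and the flag is inherited by accepted children through l2cap_sock_init() further down:

#include <stdint.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>        /* SOL_BLUETOOTH, BT_DEFER_SETUP */

/* Accept a channel but hold back the connect response until we have looked
 * at the peer: the response only goes out on the first read, exactly the
 * BT_CONNECT2 + BT_SK_DEFER_SETUP branch above.
 */
static int accept_deferred(int listen_fd)
{
        uint32_t defer = 1;
        char c;
        int fd;

        /* allowed only while the socket is bound or listening */
        setsockopt(listen_fd, SOL_BLUETOOTH, BT_DEFER_SETUP, &defer, sizeof(defer));
        listen(listen_fd, 5);

        fd = accept(listen_fd, NULL, NULL);
        if (fd < 0)
                return -1;

        /* ... inspect getpeername(), security level, etc.; close(fd) to reject ... */

        /* the first read sends the deferred connect/configuration response */
        if (read(fd, &c, 1) < 0) {
                close(fd);
                return -1;
        }

        return fd;
}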
0 : + jiffies_to_msecs(timeout - jiffies)); + + if (!timeo) + timeo = L2CAP_WAIT_ACK_POLL_PERIOD; + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + set_current_state(TASK_INTERRUPTIBLE); + + err = sock_error(sk); + if (err) + break; + + if (time_after(jiffies, timeout)) { + err = -ENOLINK; + break; + } + + } while (chan->unacked_frames > 0 && + chan->state == BT_CONNECTED); + + set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + return err; +} + +static int l2cap_sock_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + struct l2cap_chan *chan; + struct l2cap_conn *conn; + int err = 0; + + BT_DBG("sock %p, sk %p, how %d", sock, sk, how); + + /* 'how' parameter is mapped to sk_shutdown as follows: + * SHUT_RD (0) --> RCV_SHUTDOWN (1) + * SHUT_WR (1) --> SEND_SHUTDOWN (2) + * SHUT_RDWR (2) --> SHUTDOWN_MASK (3) + */ + how++; + + if (!sk) + return 0; + + lock_sock(sk); + + if ((sk->sk_shutdown & how) == how) + goto shutdown_already; + + BT_DBG("Handling sock shutdown"); + + /* prevent sk structure from being freed whilst unlocked */ + sock_hold(sk); + + chan = l2cap_pi(sk)->chan; + /* prevent chan structure from being freed whilst unlocked */ + l2cap_chan_hold(chan); + + BT_DBG("chan %p state %s", chan, state_to_string(chan->state)); + + if (chan->mode == L2CAP_MODE_ERTM && + chan->unacked_frames > 0 && + chan->state == BT_CONNECTED) { + err = __l2cap_wait_ack(sk, chan); + + /* After waiting for ACKs, check whether shutdown + * has already been actioned to close the L2CAP + * link such as by l2cap_disconnection_req(). + */ + if ((sk->sk_shutdown & how) == how) + goto shutdown_matched; + } + + /* Try setting the RCV_SHUTDOWN bit, return early if SEND_SHUTDOWN + * is already set + */ + if ((how & RCV_SHUTDOWN) && !(sk->sk_shutdown & RCV_SHUTDOWN)) { + sk->sk_shutdown |= RCV_SHUTDOWN; + if ((sk->sk_shutdown & how) == how) + goto shutdown_matched; + } + + sk->sk_shutdown |= SEND_SHUTDOWN; + release_sock(sk); + + l2cap_chan_lock(chan); + conn = chan->conn; + if (conn) + /* prevent conn structure from being freed */ + l2cap_conn_get(conn); + l2cap_chan_unlock(chan); + + if (conn) + /* mutex lock must be taken before l2cap_chan_lock() */ + mutex_lock(&conn->chan_lock); + + l2cap_chan_lock(chan); + l2cap_chan_close(chan, 0); + l2cap_chan_unlock(chan); + + if (conn) { + mutex_unlock(&conn->chan_lock); + l2cap_conn_put(conn); + } + + lock_sock(sk); + + if (sock_flag(sk, SOCK_LINGER) && sk->sk_lingertime && + !(current->flags & PF_EXITING)) + err = bt_sock_wait_state(sk, BT_CLOSED, + sk->sk_lingertime); + +shutdown_matched: + l2cap_chan_put(chan); + sock_put(sk); + +shutdown_already: + if (!err && sk->sk_err) + err = -sk->sk_err; + + release_sock(sk); + + BT_DBG("Sock shutdown complete err: %d", err); + + return err; +} + +static int l2cap_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + int err; + struct l2cap_chan *chan; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + l2cap_sock_cleanup_listen(sk); + bt_sock_unlink(&l2cap_sk_list, sk); + + err = l2cap_sock_shutdown(sock, SHUT_RDWR); + chan = l2cap_pi(sk)->chan; + + l2cap_chan_hold(chan); + l2cap_chan_lock(chan); + + sock_orphan(sk); + l2cap_sock_kill(sk); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + + return err; +} + +static void l2cap_sock_cleanup_listen(struct sock *parent) +{ + struct sock *sk; + + BT_DBG("parent %p state %s", parent, + 
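l2cap_sock_shutdown() above first drains outstanding ERTM acknowledgements and, when SO_LINGER is armed, additionally waits up to the linger time for the channel to reach BT_CLOSED before returning. A small sketch of a writer relying on both guarantees; plain socket API, nothing assumed beyond a connected ERTM socket fd:

#include <unistd.h>
#include <sys/socket.h>

/* Flush and tear down an ERTM channel: with SO_LINGER armed, shutdown() not
 * only waits for outstanding I-frame acks (__l2cap_wait_ack above) but also
 * for the channel to reach BT_CLOSED, for at most l_linger seconds.
 */
static void drain_and_close(int fd)
{
        struct linger lin = { .l_onoff = 1, .l_linger = 5 };

        setsockopt(fd, SOL_SOCKET, SO_LINGER, &lin, sizeof(lin));
        shutdown(fd, SHUT_WR);          /* SHUT_WR maps to SEND_SHUTDOWN */
        close(fd);
}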
state_to_string(parent->sk_state)); + + /* Close not yet accepted channels */ + while ((sk = bt_accept_dequeue(parent, NULL))) { + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + + BT_DBG("child chan %p state %s", chan, + state_to_string(chan->state)); + + l2cap_chan_hold(chan); + l2cap_chan_lock(chan); + + __clear_chan_timer(chan); + l2cap_chan_close(chan, ECONNRESET); + l2cap_sock_kill(sk); + + l2cap_chan_unlock(chan); + l2cap_chan_put(chan); + } +} + +static struct l2cap_chan *l2cap_sock_new_connection_cb(struct l2cap_chan *chan) +{ + struct sock *sk, *parent = chan->data; + + lock_sock(parent); + + /* Check for backlog size */ + if (sk_acceptq_is_full(parent)) { + BT_DBG("backlog full %d", parent->sk_ack_backlog); + release_sock(parent); + return NULL; + } + + sk = l2cap_sock_alloc(sock_net(parent), NULL, BTPROTO_L2CAP, + GFP_ATOMIC, 0); + if (!sk) { + release_sock(parent); + return NULL; + } + + bt_sock_reclassify_lock(sk, BTPROTO_L2CAP); + + l2cap_sock_init(sk, parent); + + bt_accept_enqueue(parent, sk, false); + + release_sock(parent); + + return l2cap_pi(sk)->chan; +} + +static int l2cap_sock_recv_cb(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct sock *sk = chan->data; + int err; + + lock_sock(sk); + + if (l2cap_pi(sk)->rx_busy_skb) { + err = -ENOMEM; + goto done; + } + + if (chan->mode != L2CAP_MODE_ERTM && + chan->mode != L2CAP_MODE_STREAMING) { + /* Even if no filter is attached, we could potentially + * get errors from security modules, etc. + */ + err = sk_filter(sk, skb); + if (err) + goto done; + } + + err = __sock_queue_rcv_skb(sk, skb); + + /* For ERTM, handle one skb that doesn't fit into the recv + * buffer. This is important to do because the data frames + * have already been acked, so the skb cannot be discarded. + * + * Notify the l2cap core that the buffer is full, so the + * LOCAL_BUSY state is entered and no more frames are + * acked and reassembled until there is buffer space + * available. + */ + if (err < 0 && chan->mode == L2CAP_MODE_ERTM) { + l2cap_pi(sk)->rx_busy_skb = skb; + l2cap_chan_busy(chan, 1); + err = 0; + } + +done: + release_sock(sk); + + return err; +} + +static void l2cap_sock_close_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + if (!sk) + return; + + l2cap_sock_kill(sk); +} + +static void l2cap_sock_teardown_cb(struct l2cap_chan *chan, int err) +{ + struct sock *sk = chan->data; + struct sock *parent; + + if (!sk) + return; + + BT_DBG("chan %p state %s", chan, state_to_string(chan->state)); + + /* This callback can be called both for server (BT_LISTEN) + * sockets as well as "normal" ones. To avoid lockdep warnings + * with child socket locking (through l2cap_sock_cleanup_listen) + * we need separation into separate nesting levels. The simplest + * way to accomplish this is to inherit the nesting level used + * for the channel. 
+ */ + lock_sock_nested(sk, atomic_read(&chan->nesting)); + + parent = bt_sk(sk)->parent; + + switch (chan->state) { + case BT_OPEN: + case BT_BOUND: + case BT_CLOSED: + break; + case BT_LISTEN: + l2cap_sock_cleanup_listen(sk); + sk->sk_state = BT_CLOSED; + chan->state = BT_CLOSED; + + break; + default: + sk->sk_state = BT_CLOSED; + chan->state = BT_CLOSED; + + sk->sk_err = err; + + if (parent) { + bt_accept_unlink(sk); + parent->sk_data_ready(parent); + } else { + sk->sk_state_change(sk); + } + + break; + } + release_sock(sk); + + /* Only zap after cleanup to avoid use after free race */ + sock_set_flag(sk, SOCK_ZAPPED); + +} + +static void l2cap_sock_state_change_cb(struct l2cap_chan *chan, int state, + int err) +{ + struct sock *sk = chan->data; + + sk->sk_state = state; + + if (err) + sk->sk_err = err; +} + +static struct sk_buff *l2cap_sock_alloc_skb_cb(struct l2cap_chan *chan, + unsigned long hdr_len, + unsigned long len, int nb) +{ + struct sock *sk = chan->data; + struct sk_buff *skb; + int err; + + l2cap_chan_unlock(chan); + skb = bt_skb_send_alloc(sk, hdr_len + len, nb, &err); + l2cap_chan_lock(chan); + + if (!skb) + return ERR_PTR(err); + + /* Channel lock is released before requesting new skb and then + * reacquired thus we need to recheck channel state. + */ + if (chan->state != BT_CONNECTED) { + kfree_skb(skb); + return ERR_PTR(-ENOTCONN); + } + + skb->priority = sk->sk_priority; + + bt_cb(skb)->l2cap.chan = chan; + + return skb; +} + +static void l2cap_sock_ready_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + struct sock *parent; + + lock_sock(sk); + + parent = bt_sk(sk)->parent; + + BT_DBG("sk %p, parent %p", sk, parent); + + sk->sk_state = BT_CONNECTED; + sk->sk_state_change(sk); + + if (parent) + parent->sk_data_ready(parent); + + release_sock(sk); +} + +static void l2cap_sock_defer_cb(struct l2cap_chan *chan) +{ + struct sock *parent, *sk = chan->data; + + lock_sock(sk); + + parent = bt_sk(sk)->parent; + if (parent) + parent->sk_data_ready(parent); + + release_sock(sk); +} + +static void l2cap_sock_resume_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + if (test_and_clear_bit(FLAG_PENDING_SECURITY, &chan->flags)) { + sk->sk_state = BT_CONNECTED; + chan->state = BT_CONNECTED; + } + + clear_bit(BT_SK_SUSPEND, &bt_sk(sk)->flags); + sk->sk_state_change(sk); +} + +static void l2cap_sock_set_shutdown_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + lock_sock(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + release_sock(sk); +} + +static long l2cap_sock_get_sndtimeo_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + return sk->sk_sndtimeo; +} + +static struct pid *l2cap_sock_get_peer_pid_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + return sk->sk_peer_pid; +} + +static void l2cap_sock_suspend_cb(struct l2cap_chan *chan) +{ + struct sock *sk = chan->data; + + set_bit(BT_SK_SUSPEND, &bt_sk(sk)->flags); + sk->sk_state_change(sk); +} + +static int l2cap_sock_filter(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct sock *sk = chan->data; + + switch (chan->mode) { + case L2CAP_MODE_ERTM: + case L2CAP_MODE_STREAMING: + return sk_filter(sk, skb); + } + + return 0; +} + +static const struct l2cap_ops l2cap_chan_ops = { + .name = "L2CAP Socket Interface", + .new_connection = l2cap_sock_new_connection_cb, + .recv = l2cap_sock_recv_cb, + .close = l2cap_sock_close_cb, + .teardown = l2cap_sock_teardown_cb, + .state_change = l2cap_sock_state_change_cb, + .ready = l2cap_sock_ready_cb, + .defer = 
l2cap_sock_defer_cb, + .resume = l2cap_sock_resume_cb, + .suspend = l2cap_sock_suspend_cb, + .set_shutdown = l2cap_sock_set_shutdown_cb, + .get_sndtimeo = l2cap_sock_get_sndtimeo_cb, + .get_peer_pid = l2cap_sock_get_peer_pid_cb, + .alloc_skb = l2cap_sock_alloc_skb_cb, + .filter = l2cap_sock_filter, +}; + +static void l2cap_sock_destruct(struct sock *sk) +{ + BT_DBG("sk %p", sk); + + if (l2cap_pi(sk)->chan) { + l2cap_pi(sk)->chan->data = NULL; + l2cap_chan_put(l2cap_pi(sk)->chan); + } + + if (l2cap_pi(sk)->rx_busy_skb) { + kfree_skb(l2cap_pi(sk)->rx_busy_skb); + l2cap_pi(sk)->rx_busy_skb = NULL; + } + + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); +} + +static void l2cap_skb_msg_name(struct sk_buff *skb, void *msg_name, + int *msg_namelen) +{ + DECLARE_SOCKADDR(struct sockaddr_l2 *, la, msg_name); + + memset(la, 0, sizeof(struct sockaddr_l2)); + la->l2_family = AF_BLUETOOTH; + la->l2_psm = bt_cb(skb)->l2cap.psm; + bacpy(&la->l2_bdaddr, &bt_cb(skb)->l2cap.bdaddr); + + *msg_namelen = sizeof(struct sockaddr_l2); +} + +static void l2cap_sock_init(struct sock *sk, struct sock *parent) +{ + struct l2cap_chan *chan = l2cap_pi(sk)->chan; + + BT_DBG("sk %p", sk); + + if (parent) { + struct l2cap_chan *pchan = l2cap_pi(parent)->chan; + + sk->sk_type = parent->sk_type; + bt_sk(sk)->flags = bt_sk(parent)->flags; + + chan->chan_type = pchan->chan_type; + chan->imtu = pchan->imtu; + chan->omtu = pchan->omtu; + chan->conf_state = pchan->conf_state; + chan->mode = pchan->mode; + chan->fcs = pchan->fcs; + chan->max_tx = pchan->max_tx; + chan->tx_win = pchan->tx_win; + chan->tx_win_max = pchan->tx_win_max; + chan->sec_level = pchan->sec_level; + chan->flags = pchan->flags; + chan->tx_credits = pchan->tx_credits; + chan->rx_credits = pchan->rx_credits; + + if (chan->chan_type == L2CAP_CHAN_FIXED) { + chan->scid = pchan->scid; + chan->dcid = pchan->scid; + } + + security_sk_clone(parent, sk); + } else { + switch (sk->sk_type) { + case SOCK_RAW: + chan->chan_type = L2CAP_CHAN_RAW; + break; + case SOCK_DGRAM: + chan->chan_type = L2CAP_CHAN_CONN_LESS; + bt_sk(sk)->skb_msg_name = l2cap_skb_msg_name; + break; + case SOCK_SEQPACKET: + case SOCK_STREAM: + chan->chan_type = L2CAP_CHAN_CONN_ORIENTED; + break; + } + + chan->imtu = L2CAP_DEFAULT_MTU; + chan->omtu = 0; + if (!disable_ertm && sk->sk_type == SOCK_STREAM) { + chan->mode = L2CAP_MODE_ERTM; + set_bit(CONF_STATE2_DEVICE, &chan->conf_state); + } else { + chan->mode = L2CAP_MODE_BASIC; + } + + l2cap_chan_set_defaults(chan); + } + + /* Default config options */ + chan->flush_to = L2CAP_DEFAULT_FLUSH_TO; + + chan->data = sk; + chan->ops = &l2cap_chan_ops; +} + +static struct proto l2cap_proto = { + .name = "L2CAP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct l2cap_pinfo) +}; + +static struct sock *l2cap_sock_alloc(struct net *net, struct socket *sock, + int proto, gfp_t prio, int kern) +{ + struct sock *sk; + struct l2cap_chan *chan; + + sk = sk_alloc(net, PF_BLUETOOTH, prio, &l2cap_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + INIT_LIST_HEAD(&bt_sk(sk)->accept_q); + + sk->sk_destruct = l2cap_sock_destruct; + sk->sk_sndtimeo = L2CAP_CONN_TIMEOUT; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = proto; + sk->sk_state = BT_OPEN; + + chan = l2cap_chan_create(); + if (!chan) { + sk_free(sk); + return NULL; + } + + l2cap_chan_hold(chan); + + l2cap_pi(sk)->chan = chan; + + return sk; +} + +static int l2cap_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct 
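l2cap_sock_init() above is where the socket type picked by userspace turns into a channel type and default mode: SOCK_SEQPACKET starts in basic mode, SOCK_STREAM defaults to ERTM unless the module was loaded with disable_ertm, SOCK_DGRAM maps to the connectionless channel and SOCK_RAW to a raw channel. The three common creations, for reference:

#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>        /* AF_BLUETOOTH, BTPROTO_L2CAP */

static void l2cap_socket_flavours(void)
{
        /* basic mode by default; ERTM or streaming can still be selected later */
        int seqpacket = socket(AF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_L2CAP);

        /* defaults to ERTM unless the module was loaded with disable_ertm */
        int stream = socket(AF_BLUETOOTH, SOCK_STREAM, BTPROTO_L2CAP);

        /* connectionless channel, addressed per packet via PSM and bdaddr */
        int dgram = socket(AF_BLUETOOTH, SOCK_DGRAM, BTPROTO_L2CAP);

        close(seqpacket);
        close(stream);
        close(dgram);
}

A SOCK_RAW socket is also accepted but, as l2cap_sock_create() below enforces, only for kernel-internal callers or processes with CAP_NET_RAW.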
sock *sk; + + BT_DBG("sock %p", sock); + + sock->state = SS_UNCONNECTED; + + if (sock->type != SOCK_SEQPACKET && sock->type != SOCK_STREAM && + sock->type != SOCK_DGRAM && sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + if (sock->type == SOCK_RAW && !kern && !capable(CAP_NET_RAW)) + return -EPERM; + + sock->ops = &l2cap_sock_ops; + + sk = l2cap_sock_alloc(net, sock, protocol, GFP_ATOMIC, kern); + if (!sk) + return -ENOMEM; + + l2cap_sock_init(sk, NULL); + bt_sock_link(&l2cap_sk_list, sk); + return 0; +} + +static const struct proto_ops l2cap_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = l2cap_sock_release, + .bind = l2cap_sock_bind, + .connect = l2cap_sock_connect, + .listen = l2cap_sock_listen, + .accept = l2cap_sock_accept, + .getname = l2cap_sock_getname, + .sendmsg = l2cap_sock_sendmsg, + .recvmsg = l2cap_sock_recvmsg, + .poll = bt_sock_poll, + .ioctl = bt_sock_ioctl, + .gettstamp = sock_gettstamp, + .mmap = sock_no_mmap, + .socketpair = sock_no_socketpair, + .shutdown = l2cap_sock_shutdown, + .setsockopt = l2cap_sock_setsockopt, + .getsockopt = l2cap_sock_getsockopt +}; + +static const struct net_proto_family l2cap_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = l2cap_sock_create, +}; + +int __init l2cap_init_sockets(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct sockaddr_l2) > sizeof(struct sockaddr)); + + err = proto_register(&l2cap_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_L2CAP, &l2cap_sock_family_ops); + if (err < 0) { + BT_ERR("L2CAP socket registration failed"); + goto error; + } + + err = bt_procfs_init(&init_net, "l2cap", &l2cap_sk_list, + NULL); + if (err < 0) { + BT_ERR("Failed to create L2CAP proc file"); + bt_sock_unregister(BTPROTO_L2CAP); + goto error; + } + + BT_INFO("L2CAP socket layer initialized"); + + return 0; + +error: + proto_unregister(&l2cap_proto); + return err; +} + +void l2cap_cleanup_sockets(void) +{ + bt_procfs_cleanup(&init_net, "l2cap"); + bt_sock_unregister(BTPROTO_L2CAP); + proto_unregister(&l2cap_proto); +} diff --git a/net/bluetooth/leds.c b/net/bluetooth/leds.c new file mode 100644 index 000000000..f46847632 --- /dev/null +++ b/net/bluetooth/leds.c @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2015, Heiner Kallweit + */ + +#include +#include + +#include "leds.h" + +DEFINE_LED_TRIGGER(bt_power_led_trigger); + +struct hci_basic_led_trigger { + struct led_trigger led_trigger; + struct hci_dev *hdev; +}; + +#define to_hci_basic_led_trigger(arg) container_of(arg, \ + struct hci_basic_led_trigger, led_trigger) + +void hci_leds_update_powered(struct hci_dev *hdev, bool enabled) +{ + if (hdev->power_led) + led_trigger_event(hdev->power_led, + enabled ? LED_FULL : LED_OFF); + + if (!enabled) { + struct hci_dev *d; + + read_lock(&hci_dev_list_lock); + + list_for_each_entry(d, &hci_dev_list, list) { + if (test_bit(HCI_UP, &d->flags)) + enabled = true; + } + + read_unlock(&hci_dev_list_lock); + } + + led_trigger_event(bt_power_led_trigger, enabled ? LED_FULL : LED_OFF); +} + +static int power_activate(struct led_classdev *led_cdev) +{ + struct hci_basic_led_trigger *htrig; + bool powered; + + htrig = to_hci_basic_led_trigger(led_cdev->trigger); + powered = test_bit(HCI_UP, &htrig->hdev->flags); + + led_trigger_event(led_cdev->trigger, powered ? 
LED_FULL : LED_OFF); + + return 0; +} + +static struct led_trigger *led_allocate_basic(struct hci_dev *hdev, + int (*activate)(struct led_classdev *led_cdev), + const char *name) +{ + struct hci_basic_led_trigger *htrig; + + htrig = devm_kzalloc(&hdev->dev, sizeof(*htrig), GFP_KERNEL); + if (!htrig) + return NULL; + + htrig->hdev = hdev; + htrig->led_trigger.activate = activate; + htrig->led_trigger.name = devm_kasprintf(&hdev->dev, GFP_KERNEL, + "%s-%s", hdev->name, + name); + if (!htrig->led_trigger.name) + goto err_alloc; + + if (devm_led_trigger_register(&hdev->dev, &htrig->led_trigger)) + goto err_register; + + return &htrig->led_trigger; + +err_register: + devm_kfree(&hdev->dev, (void *)htrig->led_trigger.name); +err_alloc: + devm_kfree(&hdev->dev, htrig); + return NULL; +} + +void hci_leds_init(struct hci_dev *hdev) +{ + /* initialize power_led */ + hdev->power_led = led_allocate_basic(hdev, power_activate, "power"); +} + +void bt_leds_init(void) +{ + led_trigger_register_simple("bluetooth-power", &bt_power_led_trigger); +} + +void bt_leds_cleanup(void) +{ + led_trigger_unregister_simple(bt_power_led_trigger); +} diff --git a/net/bluetooth/leds.h b/net/bluetooth/leds.h new file mode 100644 index 000000000..bb5e09204 --- /dev/null +++ b/net/bluetooth/leds.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2015, Heiner Kallweit + */ + +#if IS_ENABLED(CONFIG_BT_LEDS) + +void hci_leds_update_powered(struct hci_dev *hdev, bool enabled); +void hci_leds_init(struct hci_dev *hdev); + +void bt_leds_init(void); +void bt_leds_cleanup(void); + +#else + +static inline void hci_leds_update_powered(struct hci_dev *hdev, + bool enabled) {} +static inline void hci_leds_init(struct hci_dev *hdev) {} + +static inline void bt_leds_init(void) {} +static inline void bt_leds_cleanup(void) {} + +#endif diff --git a/net/bluetooth/lib.c b/net/bluetooth/lib.c new file mode 100644 index 000000000..53a796ac0 --- /dev/null +++ b/net/bluetooth/lib.c @@ -0,0 +1,320 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth kernel library. 
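The triggers registered above are ordinary LED class triggers: one global "bluetooth-power" trigger plus a per-controller "<hci name>-power" one, for example "hci0-power". A hedged sketch of attaching the latter to an LED through the usual /sys/class/leds interface; the LED name "foo:blue:bt" is a made-up placeholder:

#include <stdio.h>

/* Drive a platform LED from the per-controller power trigger so that it is
 * lit whenever hci0 is powered on.
 */
static int attach_bt_power_led(void)
{
        FILE *f = fopen("/sys/class/leds/foo:blue:bt/trigger", "w");

        if (!f)
                return -1;

        fputs("hci0-power", f);
        fclose(f);
        return 0;
}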
*/ + +#define pr_fmt(fmt) "Bluetooth: " fmt + +#include + +#include + +void baswap(bdaddr_t *dst, const bdaddr_t *src) +{ + const unsigned char *s = (const unsigned char *)src; + unsigned char *d = (unsigned char *)dst; + unsigned int i; + + for (i = 0; i < 6; i++) + d[i] = s[5 - i]; +} +EXPORT_SYMBOL(baswap); + +/* Bluetooth error codes to Unix errno mapping */ +int bt_to_errno(__u16 code) +{ + switch (code) { + case 0: + return 0; + + case 0x01: + return EBADRQC; + + case 0x02: + return ENOTCONN; + + case 0x03: + return EIO; + + case 0x04: + case 0x3c: + return EHOSTDOWN; + + case 0x05: + return EACCES; + + case 0x06: + return EBADE; + + case 0x07: + return ENOMEM; + + case 0x08: + return ETIMEDOUT; + + case 0x09: + return EMLINK; + + case 0x0a: + return EMLINK; + + case 0x0b: + return EALREADY; + + case 0x0c: + return EBUSY; + + case 0x0d: + case 0x0e: + case 0x0f: + return ECONNREFUSED; + + case 0x10: + return ETIMEDOUT; + + case 0x11: + case 0x27: + case 0x29: + case 0x20: + return EOPNOTSUPP; + + case 0x12: + return EINVAL; + + case 0x13: + case 0x14: + case 0x15: + return ECONNRESET; + + case 0x16: + return ECONNABORTED; + + case 0x17: + return ELOOP; + + case 0x18: + return EACCES; + + case 0x1a: + return EPROTONOSUPPORT; + + case 0x1b: + return ECONNREFUSED; + + case 0x19: + case 0x1e: + case 0x23: + case 0x24: + case 0x25: + return EPROTO; + + default: + return ENOSYS; + } +} +EXPORT_SYMBOL(bt_to_errno); + +/* Unix errno to Bluetooth error codes mapping */ +__u8 bt_status(int err) +{ + /* Don't convert if already positive value */ + if (err >= 0) + return err; + + switch (err) { + case -EBADRQC: + return 0x01; + + case -ENOTCONN: + return 0x02; + + case -EIO: + return 0x03; + + case -EHOSTDOWN: + return 0x04; + + case -EACCES: + return 0x05; + + case -EBADE: + return 0x06; + + case -ENOMEM: + return 0x07; + + case -ETIMEDOUT: + return 0x08; + + case -EMLINK: + return 0x09; + + case -EALREADY: + return 0x0b; + + case -EBUSY: + return 0x0c; + + case -ECONNREFUSED: + return 0x0d; + + case -EOPNOTSUPP: + return 0x11; + + case -EINVAL: + return 0x12; + + case -ECONNRESET: + return 0x13; + + case -ECONNABORTED: + return 0x16; + + case -ELOOP: + return 0x17; + + case -EPROTONOSUPPORT: + return 0x1a; + + case -EPROTO: + return 0x19; + + default: + return 0x1f; + } +} +EXPORT_SYMBOL(bt_status); + +void bt_info(const char *format, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + pr_info("%pV", &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_info); + +void bt_warn(const char *format, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + pr_warn("%pV", &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_warn); + +void bt_err(const char *format, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + pr_err("%pV", &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_err); + +#ifdef CONFIG_BT_FEATURE_DEBUG +static bool debug_enable; + +void bt_dbg_set(bool enable) +{ + debug_enable = enable; +} + +bool bt_dbg_get(void) +{ + return debug_enable; +} + +void bt_dbg(const char *format, ...) +{ + struct va_format vaf; + va_list args; + + if (likely(!debug_enable)) + return; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + printk(KERN_DEBUG pr_fmt("%pV"), &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_dbg); +#endif + +void bt_warn_ratelimited(const char *format, ...) 
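bt_to_errno() and bt_status() above are intentionally not exact inverses, since several HCI codes collapse onto the same errno. A few concrete values read straight off the two tables:

/* Sample mappings, taken from the tables above:
 *
 *        bt_to_errno(0x02) == ENOTCONN         bt_status(-ENOTCONN)  == 0x02
 *        bt_to_errno(0x0c) == EBUSY            bt_status(-EBUSY)     == 0x0c
 *        bt_to_errno(0x08) == ETIMEDOUT        bt_status(-ETIMEDOUT) == 0x08
 *        bt_to_errno(0x10) == ETIMEDOUT        (0x08 and 0x10 both map to
 *                                               ETIMEDOUT, so the reverse
 *                                               mapping settles on 0x08)
 */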
+{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + pr_warn_ratelimited("%pV", &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_warn_ratelimited); + +void bt_err_ratelimited(const char *format, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + pr_err_ratelimited("%pV", &vaf); + + va_end(args); +} +EXPORT_SYMBOL(bt_err_ratelimited); diff --git a/net/bluetooth/mgmt.c b/net/bluetooth/mgmt.c new file mode 100644 index 000000000..6d631a2e6 --- /dev/null +++ b/net/bluetooth/mgmt.c @@ -0,0 +1,10571 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + + Copyright (C) 2010 Nokia Corporation + Copyright (C) 2011-2012 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth HCI Management interface */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "hci_request.h" +#include "smp.h" +#include "mgmt_util.h" +#include "mgmt_config.h" +#include "msft.h" +#include "eir.h" +#include "aosp.h" + +#define MGMT_VERSION 1 +#define MGMT_REVISION 22 + +static const u16 mgmt_commands[] = { + MGMT_OP_READ_INDEX_LIST, + MGMT_OP_READ_INFO, + MGMT_OP_SET_POWERED, + MGMT_OP_SET_DISCOVERABLE, + MGMT_OP_SET_CONNECTABLE, + MGMT_OP_SET_FAST_CONNECTABLE, + MGMT_OP_SET_BONDABLE, + MGMT_OP_SET_LINK_SECURITY, + MGMT_OP_SET_SSP, + MGMT_OP_SET_HS, + MGMT_OP_SET_LE, + MGMT_OP_SET_DEV_CLASS, + MGMT_OP_SET_LOCAL_NAME, + MGMT_OP_ADD_UUID, + MGMT_OP_REMOVE_UUID, + MGMT_OP_LOAD_LINK_KEYS, + MGMT_OP_LOAD_LONG_TERM_KEYS, + MGMT_OP_DISCONNECT, + MGMT_OP_GET_CONNECTIONS, + MGMT_OP_PIN_CODE_REPLY, + MGMT_OP_PIN_CODE_NEG_REPLY, + MGMT_OP_SET_IO_CAPABILITY, + MGMT_OP_PAIR_DEVICE, + MGMT_OP_CANCEL_PAIR_DEVICE, + MGMT_OP_UNPAIR_DEVICE, + MGMT_OP_USER_CONFIRM_REPLY, + MGMT_OP_USER_CONFIRM_NEG_REPLY, + MGMT_OP_USER_PASSKEY_REPLY, + MGMT_OP_USER_PASSKEY_NEG_REPLY, + MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_OP_ADD_REMOTE_OOB_DATA, + MGMT_OP_REMOVE_REMOTE_OOB_DATA, + MGMT_OP_START_DISCOVERY, + MGMT_OP_STOP_DISCOVERY, + MGMT_OP_CONFIRM_NAME, + MGMT_OP_BLOCK_DEVICE, + MGMT_OP_UNBLOCK_DEVICE, + MGMT_OP_SET_DEVICE_ID, + MGMT_OP_SET_ADVERTISING, + MGMT_OP_SET_BREDR, + MGMT_OP_SET_STATIC_ADDRESS, + MGMT_OP_SET_SCAN_PARAMS, + MGMT_OP_SET_SECURE_CONN, + MGMT_OP_SET_DEBUG_KEYS, + MGMT_OP_SET_PRIVACY, + MGMT_OP_LOAD_IRKS, + MGMT_OP_GET_CONN_INFO, + MGMT_OP_GET_CLOCK_INFO, + MGMT_OP_ADD_DEVICE, + MGMT_OP_REMOVE_DEVICE, + MGMT_OP_LOAD_CONN_PARAM, + MGMT_OP_READ_UNCONF_INDEX_LIST, + MGMT_OP_READ_CONFIG_INFO, + MGMT_OP_SET_EXTERNAL_CONFIG, + MGMT_OP_SET_PUBLIC_ADDRESS, + 
MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_OP_READ_LOCAL_OOB_EXT_DATA, + MGMT_OP_READ_EXT_INDEX_LIST, + MGMT_OP_READ_ADV_FEATURES, + MGMT_OP_ADD_ADVERTISING, + MGMT_OP_REMOVE_ADVERTISING, + MGMT_OP_GET_ADV_SIZE_INFO, + MGMT_OP_START_LIMITED_DISCOVERY, + MGMT_OP_READ_EXT_INFO, + MGMT_OP_SET_APPEARANCE, + MGMT_OP_GET_PHY_CONFIGURATION, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_OP_SET_BLOCKED_KEYS, + MGMT_OP_SET_WIDEBAND_SPEECH, + MGMT_OP_READ_CONTROLLER_CAP, + MGMT_OP_READ_EXP_FEATURES_INFO, + MGMT_OP_SET_EXP_FEATURE, + MGMT_OP_READ_DEF_SYSTEM_CONFIG, + MGMT_OP_SET_DEF_SYSTEM_CONFIG, + MGMT_OP_READ_DEF_RUNTIME_CONFIG, + MGMT_OP_SET_DEF_RUNTIME_CONFIG, + MGMT_OP_GET_DEVICE_FLAGS, + MGMT_OP_SET_DEVICE_FLAGS, + MGMT_OP_READ_ADV_MONITOR_FEATURES, + MGMT_OP_ADD_ADV_PATTERNS_MONITOR, + MGMT_OP_REMOVE_ADV_MONITOR, + MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_OP_ADD_ADV_PATTERNS_MONITOR_RSSI, + MGMT_OP_SET_MESH_RECEIVER, + MGMT_OP_MESH_READ_FEATURES, + MGMT_OP_MESH_SEND, + MGMT_OP_MESH_SEND_CANCEL, +}; + +static const u16 mgmt_events[] = { + MGMT_EV_CONTROLLER_ERROR, + MGMT_EV_INDEX_ADDED, + MGMT_EV_INDEX_REMOVED, + MGMT_EV_NEW_SETTINGS, + MGMT_EV_CLASS_OF_DEV_CHANGED, + MGMT_EV_LOCAL_NAME_CHANGED, + MGMT_EV_NEW_LINK_KEY, + MGMT_EV_NEW_LONG_TERM_KEY, + MGMT_EV_DEVICE_CONNECTED, + MGMT_EV_DEVICE_DISCONNECTED, + MGMT_EV_CONNECT_FAILED, + MGMT_EV_PIN_CODE_REQUEST, + MGMT_EV_USER_CONFIRM_REQUEST, + MGMT_EV_USER_PASSKEY_REQUEST, + MGMT_EV_AUTH_FAILED, + MGMT_EV_DEVICE_FOUND, + MGMT_EV_DISCOVERING, + MGMT_EV_DEVICE_BLOCKED, + MGMT_EV_DEVICE_UNBLOCKED, + MGMT_EV_DEVICE_UNPAIRED, + MGMT_EV_PASSKEY_NOTIFY, + MGMT_EV_NEW_IRK, + MGMT_EV_NEW_CSRK, + MGMT_EV_DEVICE_ADDED, + MGMT_EV_DEVICE_REMOVED, + MGMT_EV_NEW_CONN_PARAM, + MGMT_EV_UNCONF_INDEX_ADDED, + MGMT_EV_UNCONF_INDEX_REMOVED, + MGMT_EV_NEW_CONFIG_OPTIONS, + MGMT_EV_EXT_INDEX_ADDED, + MGMT_EV_EXT_INDEX_REMOVED, + MGMT_EV_LOCAL_OOB_DATA_UPDATED, + MGMT_EV_ADVERTISING_ADDED, + MGMT_EV_ADVERTISING_REMOVED, + MGMT_EV_EXT_INFO_CHANGED, + MGMT_EV_PHY_CONFIGURATION_CHANGED, + MGMT_EV_EXP_FEATURE_CHANGED, + MGMT_EV_DEVICE_FLAGS_CHANGED, + MGMT_EV_ADV_MONITOR_ADDED, + MGMT_EV_ADV_MONITOR_REMOVED, + MGMT_EV_CONTROLLER_SUSPEND, + MGMT_EV_CONTROLLER_RESUME, + MGMT_EV_ADV_MONITOR_DEVICE_FOUND, + MGMT_EV_ADV_MONITOR_DEVICE_LOST, +}; + +static const u16 mgmt_untrusted_commands[] = { + MGMT_OP_READ_INDEX_LIST, + MGMT_OP_READ_INFO, + MGMT_OP_READ_UNCONF_INDEX_LIST, + MGMT_OP_READ_CONFIG_INFO, + MGMT_OP_READ_EXT_INDEX_LIST, + MGMT_OP_READ_EXT_INFO, + MGMT_OP_READ_CONTROLLER_CAP, + MGMT_OP_READ_EXP_FEATURES_INFO, + MGMT_OP_READ_DEF_SYSTEM_CONFIG, + MGMT_OP_READ_DEF_RUNTIME_CONFIG, +}; + +static const u16 mgmt_untrusted_events[] = { + MGMT_EV_INDEX_ADDED, + MGMT_EV_INDEX_REMOVED, + MGMT_EV_NEW_SETTINGS, + MGMT_EV_CLASS_OF_DEV_CHANGED, + MGMT_EV_LOCAL_NAME_CHANGED, + MGMT_EV_UNCONF_INDEX_ADDED, + MGMT_EV_UNCONF_INDEX_REMOVED, + MGMT_EV_NEW_CONFIG_OPTIONS, + MGMT_EV_EXT_INDEX_ADDED, + MGMT_EV_EXT_INDEX_REMOVED, + MGMT_EV_EXT_INFO_CHANGED, + MGMT_EV_EXP_FEATURE_CHANGED, +}; + +#define CACHE_TIMEOUT msecs_to_jiffies(2 * 1000) + +#define ZERO_KEY "\x00\x00\x00\x00\x00\x00\x00\x00" \ + "\x00\x00\x00\x00\x00\x00\x00\x00" + +/* HCI to MGMT error code conversion table */ +static const u8 mgmt_status_table[] = { + MGMT_STATUS_SUCCESS, + MGMT_STATUS_UNKNOWN_COMMAND, /* Unknown Command */ + MGMT_STATUS_NOT_CONNECTED, /* No Connection */ + MGMT_STATUS_FAILED, /* Hardware Failure */ + MGMT_STATUS_CONNECT_FAILED, /* Page Timeout */ + MGMT_STATUS_AUTH_FAILED, /* 
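The opcode and event tables above describe what a management client may send and receive over the HCI control channel. A hedged sketch of issuing MGMT_OP_READ_VERSION from userspace; the packed header mirrors the kernel's struct mgmt_hdr (opcode, index, length, all little endian), and the two MGMT_* values are defined locally because the mgmt header is not always installed for userspace:

#include <stdint.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>        /* AF_BLUETOOTH, BTPROTO_HCI, htobs() */
#include <bluetooth/hci.h>              /* struct sockaddr_hci, HCI_DEV_NONE, HCI_CHANNEL_CONTROL */

/* Values as used by this interface; defined here in case no header provides them. */
#define MGMT_INDEX_NONE         0xFFFF
#define MGMT_OP_READ_VERSION    0x0001

struct mgmt_hdr {
        uint16_t opcode;
        uint16_t index;
        uint16_t len;
} __attribute__((packed));

/* Open the control channel and request the version/revision pair announced
 * above (MGMT_VERSION/MGMT_REVISION); the reply comes back on the same
 * socket as a command-complete event.
 */
static int mgmt_read_version(void)
{
        struct sockaddr_hci addr;
        struct mgmt_hdr hdr;
        int fd;

        fd = socket(AF_BLUETOOTH, SOCK_RAW | SOCK_CLOEXEC, BTPROTO_HCI);
        if (fd < 0)
                return -1;

        memset(&addr, 0, sizeof(addr));
        addr.hci_family = AF_BLUETOOTH;
        addr.hci_dev = HCI_DEV_NONE;            /* the control channel is not per device */
        addr.hci_channel = HCI_CHANNEL_CONTROL;
        if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0)
                goto fail;

        hdr.opcode = htobs(MGMT_OP_READ_VERSION);
        hdr.index = htobs(MGMT_INDEX_NONE);
        hdr.len = 0;
        if (write(fd, &hdr, sizeof(hdr)) != sizeof(hdr))
                goto fail;

        return fd;      /* caller read()s the command-complete reply */

fail:
        close(fd);
        return -1;
}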
Authentication Failed */ + MGMT_STATUS_AUTH_FAILED, /* PIN or Key Missing */ + MGMT_STATUS_NO_RESOURCES, /* Memory Full */ + MGMT_STATUS_TIMEOUT, /* Connection Timeout */ + MGMT_STATUS_NO_RESOURCES, /* Max Number of Connections */ + MGMT_STATUS_NO_RESOURCES, /* Max Number of SCO Connections */ + MGMT_STATUS_ALREADY_CONNECTED, /* ACL Connection Exists */ + MGMT_STATUS_BUSY, /* Command Disallowed */ + MGMT_STATUS_NO_RESOURCES, /* Rejected Limited Resources */ + MGMT_STATUS_REJECTED, /* Rejected Security */ + MGMT_STATUS_REJECTED, /* Rejected Personal */ + MGMT_STATUS_TIMEOUT, /* Host Timeout */ + MGMT_STATUS_NOT_SUPPORTED, /* Unsupported Feature */ + MGMT_STATUS_INVALID_PARAMS, /* Invalid Parameters */ + MGMT_STATUS_DISCONNECTED, /* OE User Ended Connection */ + MGMT_STATUS_NO_RESOURCES, /* OE Low Resources */ + MGMT_STATUS_DISCONNECTED, /* OE Power Off */ + MGMT_STATUS_DISCONNECTED, /* Connection Terminated */ + MGMT_STATUS_BUSY, /* Repeated Attempts */ + MGMT_STATUS_REJECTED, /* Pairing Not Allowed */ + MGMT_STATUS_FAILED, /* Unknown LMP PDU */ + MGMT_STATUS_NOT_SUPPORTED, /* Unsupported Remote Feature */ + MGMT_STATUS_REJECTED, /* SCO Offset Rejected */ + MGMT_STATUS_REJECTED, /* SCO Interval Rejected */ + MGMT_STATUS_REJECTED, /* Air Mode Rejected */ + MGMT_STATUS_INVALID_PARAMS, /* Invalid LMP Parameters */ + MGMT_STATUS_FAILED, /* Unspecified Error */ + MGMT_STATUS_NOT_SUPPORTED, /* Unsupported LMP Parameter Value */ + MGMT_STATUS_FAILED, /* Role Change Not Allowed */ + MGMT_STATUS_TIMEOUT, /* LMP Response Timeout */ + MGMT_STATUS_FAILED, /* LMP Error Transaction Collision */ + MGMT_STATUS_FAILED, /* LMP PDU Not Allowed */ + MGMT_STATUS_REJECTED, /* Encryption Mode Not Accepted */ + MGMT_STATUS_FAILED, /* Unit Link Key Used */ + MGMT_STATUS_NOT_SUPPORTED, /* QoS Not Supported */ + MGMT_STATUS_TIMEOUT, /* Instant Passed */ + MGMT_STATUS_NOT_SUPPORTED, /* Pairing Not Supported */ + MGMT_STATUS_FAILED, /* Transaction Collision */ + MGMT_STATUS_FAILED, /* Reserved for future use */ + MGMT_STATUS_INVALID_PARAMS, /* Unacceptable Parameter */ + MGMT_STATUS_REJECTED, /* QoS Rejected */ + MGMT_STATUS_NOT_SUPPORTED, /* Classification Not Supported */ + MGMT_STATUS_REJECTED, /* Insufficient Security */ + MGMT_STATUS_INVALID_PARAMS, /* Parameter Out Of Range */ + MGMT_STATUS_FAILED, /* Reserved for future use */ + MGMT_STATUS_BUSY, /* Role Switch Pending */ + MGMT_STATUS_FAILED, /* Reserved for future use */ + MGMT_STATUS_FAILED, /* Slot Violation */ + MGMT_STATUS_FAILED, /* Role Switch Failed */ + MGMT_STATUS_INVALID_PARAMS, /* EIR Too Large */ + MGMT_STATUS_NOT_SUPPORTED, /* Simple Pairing Not Supported */ + MGMT_STATUS_BUSY, /* Host Busy Pairing */ + MGMT_STATUS_REJECTED, /* Rejected, No Suitable Channel */ + MGMT_STATUS_BUSY, /* Controller Busy */ + MGMT_STATUS_INVALID_PARAMS, /* Unsuitable Connection Interval */ + MGMT_STATUS_TIMEOUT, /* Directed Advertising Timeout */ + MGMT_STATUS_AUTH_FAILED, /* Terminated Due to MIC Failure */ + MGMT_STATUS_CONNECT_FAILED, /* Connection Establishment Failed */ + MGMT_STATUS_CONNECT_FAILED, /* MAC Connection Failed */ +}; + +static u8 mgmt_errno_status(int err) +{ + switch (err) { + case 0: + return MGMT_STATUS_SUCCESS; + case -EPERM: + return MGMT_STATUS_REJECTED; + case -EINVAL: + return MGMT_STATUS_INVALID_PARAMS; + case -EOPNOTSUPP: + return MGMT_STATUS_NOT_SUPPORTED; + case -EBUSY: + return MGMT_STATUS_BUSY; + case -ETIMEDOUT: + return MGMT_STATUS_AUTH_FAILED; + case -ENOMEM: + return MGMT_STATUS_NO_RESOURCES; + case -EISCONN: + return 
MGMT_STATUS_ALREADY_CONNECTED; + case -ENOTCONN: + return MGMT_STATUS_DISCONNECTED; + } + + return MGMT_STATUS_FAILED; +} + +static u8 mgmt_status(int err) +{ + if (err < 0) + return mgmt_errno_status(err); + + if (err < ARRAY_SIZE(mgmt_status_table)) + return mgmt_status_table[err]; + + return MGMT_STATUS_FAILED; +} + +static int mgmt_index_event(u16 event, struct hci_dev *hdev, void *data, + u16 len, int flag) +{ + return mgmt_send_event(event, hdev, HCI_CHANNEL_CONTROL, data, len, + flag, NULL); +} + +static int mgmt_limited_event(u16 event, struct hci_dev *hdev, void *data, + u16 len, int flag, struct sock *skip_sk) +{ + return mgmt_send_event(event, hdev, HCI_CHANNEL_CONTROL, data, len, + flag, skip_sk); +} + +static int mgmt_event(u16 event, struct hci_dev *hdev, void *data, u16 len, + struct sock *skip_sk) +{ + return mgmt_send_event(event, hdev, HCI_CHANNEL_CONTROL, data, len, + HCI_SOCK_TRUSTED, skip_sk); +} + +static int mgmt_event_skb(struct sk_buff *skb, struct sock *skip_sk) +{ + return mgmt_send_event_skb(HCI_CHANNEL_CONTROL, skb, HCI_SOCK_TRUSTED, + skip_sk); +} + +static u8 le_addr_type(u8 mgmt_addr_type) +{ + if (mgmt_addr_type == BDADDR_LE_PUBLIC) + return ADDR_LE_DEV_PUBLIC; + else + return ADDR_LE_DEV_RANDOM; +} + +void mgmt_fill_version_info(void *ver) +{ + struct mgmt_rp_read_version *rp = ver; + + rp->version = MGMT_VERSION; + rp->revision = cpu_to_le16(MGMT_REVISION); +} + +static int read_version(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_rp_read_version rp; + + bt_dev_dbg(hdev, "sock %p", sk); + + mgmt_fill_version_info(&rp); + + return mgmt_cmd_complete(sk, MGMT_INDEX_NONE, MGMT_OP_READ_VERSION, 0, + &rp, sizeof(rp)); +} + +static int read_commands(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_rp_read_commands *rp; + u16 num_commands, num_events; + size_t rp_size; + int i, err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (hci_sock_test_flag(sk, HCI_SOCK_TRUSTED)) { + num_commands = ARRAY_SIZE(mgmt_commands); + num_events = ARRAY_SIZE(mgmt_events); + } else { + num_commands = ARRAY_SIZE(mgmt_untrusted_commands); + num_events = ARRAY_SIZE(mgmt_untrusted_events); + } + + rp_size = sizeof(*rp) + ((num_commands + num_events) * sizeof(u16)); + + rp = kmalloc(rp_size, GFP_KERNEL); + if (!rp) + return -ENOMEM; + + rp->num_commands = cpu_to_le16(num_commands); + rp->num_events = cpu_to_le16(num_events); + + if (hci_sock_test_flag(sk, HCI_SOCK_TRUSTED)) { + __le16 *opcode = rp->opcodes; + + for (i = 0; i < num_commands; i++, opcode++) + put_unaligned_le16(mgmt_commands[i], opcode); + + for (i = 0; i < num_events; i++, opcode++) + put_unaligned_le16(mgmt_events[i], opcode); + } else { + __le16 *opcode = rp->opcodes; + + for (i = 0; i < num_commands; i++, opcode++) + put_unaligned_le16(mgmt_untrusted_commands[i], opcode); + + for (i = 0; i < num_events; i++, opcode++) + put_unaligned_le16(mgmt_untrusted_events[i], opcode); + } + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, MGMT_OP_READ_COMMANDS, 0, + rp, rp_size); + kfree(rp); + + return err; +} + +static int read_index_list(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_rp_read_index_list *rp; + struct hci_dev *d; + size_t rp_len; + u16 count; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + read_lock(&hci_dev_list_lock); + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (d->dev_type == HCI_PRIMARY && + !hci_dev_test_flag(d, HCI_UNCONFIGURED)) + count++; + } + + rp_len = sizeof(*rp) + (2 
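mgmt_status() above accepts both conventions: negative errnos coming from the command machinery and raw HCI status codes indexing the table. A few concrete results, read off the two mappings:

/* Examples of the combined mapping in mgmt_status():
 *
 *        mgmt_status(-EINVAL) == MGMT_STATUS_INVALID_PARAMS   (errno path)
 *        mgmt_status(-ENOMEM) == MGMT_STATUS_NO_RESOURCES     (errno path)
 *        mgmt_status(0x05)    == MGMT_STATUS_AUTH_FAILED      (HCI "Authentication Failed")
 *        mgmt_status(0x0c)    == MGMT_STATUS_BUSY             (HCI "Command Disallowed")
 *        mgmt_status(0xff)    == MGMT_STATUS_FAILED           (beyond the table)
 */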
* count); + rp = kmalloc(rp_len, GFP_ATOMIC); + if (!rp) { + read_unlock(&hci_dev_list_lock); + return -ENOMEM; + } + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (hci_dev_test_flag(d, HCI_SETUP) || + hci_dev_test_flag(d, HCI_CONFIG) || + hci_dev_test_flag(d, HCI_USER_CHANNEL)) + continue; + + /* Devices marked as raw-only are neither configured + * nor unconfigured controllers. + */ + if (test_bit(HCI_QUIRK_RAW_DEVICE, &d->quirks)) + continue; + + if (d->dev_type == HCI_PRIMARY && + !hci_dev_test_flag(d, HCI_UNCONFIGURED)) { + rp->index[count++] = cpu_to_le16(d->id); + bt_dev_dbg(hdev, "Added hci%u", d->id); + } + } + + rp->num_controllers = cpu_to_le16(count); + rp_len = sizeof(*rp) + (2 * count); + + read_unlock(&hci_dev_list_lock); + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, MGMT_OP_READ_INDEX_LIST, + 0, rp, rp_len); + + kfree(rp); + + return err; +} + +static int read_unconf_index_list(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_unconf_index_list *rp; + struct hci_dev *d; + size_t rp_len; + u16 count; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + read_lock(&hci_dev_list_lock); + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (d->dev_type == HCI_PRIMARY && + hci_dev_test_flag(d, HCI_UNCONFIGURED)) + count++; + } + + rp_len = sizeof(*rp) + (2 * count); + rp = kmalloc(rp_len, GFP_ATOMIC); + if (!rp) { + read_unlock(&hci_dev_list_lock); + return -ENOMEM; + } + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (hci_dev_test_flag(d, HCI_SETUP) || + hci_dev_test_flag(d, HCI_CONFIG) || + hci_dev_test_flag(d, HCI_USER_CHANNEL)) + continue; + + /* Devices marked as raw-only are neither configured + * nor unconfigured controllers. + */ + if (test_bit(HCI_QUIRK_RAW_DEVICE, &d->quirks)) + continue; + + if (d->dev_type == HCI_PRIMARY && + hci_dev_test_flag(d, HCI_UNCONFIGURED)) { + rp->index[count++] = cpu_to_le16(d->id); + bt_dev_dbg(hdev, "Added hci%u", d->id); + } + } + + rp->num_controllers = cpu_to_le16(count); + rp_len = sizeof(*rp) + (2 * count); + + read_unlock(&hci_dev_list_lock); + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, + MGMT_OP_READ_UNCONF_INDEX_LIST, 0, rp, rp_len); + + kfree(rp); + + return err; +} + +static int read_ext_index_list(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_ext_index_list *rp; + struct hci_dev *d; + u16 count; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + read_lock(&hci_dev_list_lock); + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (d->dev_type == HCI_PRIMARY || d->dev_type == HCI_AMP) + count++; + } + + rp = kmalloc(struct_size(rp, entry, count), GFP_ATOMIC); + if (!rp) { + read_unlock(&hci_dev_list_lock); + return -ENOMEM; + } + + count = 0; + list_for_each_entry(d, &hci_dev_list, list) { + if (hci_dev_test_flag(d, HCI_SETUP) || + hci_dev_test_flag(d, HCI_CONFIG) || + hci_dev_test_flag(d, HCI_USER_CHANNEL)) + continue; + + /* Devices marked as raw-only are neither configured + * nor unconfigured controllers. 
+ */ + if (test_bit(HCI_QUIRK_RAW_DEVICE, &d->quirks)) + continue; + + if (d->dev_type == HCI_PRIMARY) { + if (hci_dev_test_flag(d, HCI_UNCONFIGURED)) + rp->entry[count].type = 0x01; + else + rp->entry[count].type = 0x00; + } else if (d->dev_type == HCI_AMP) { + rp->entry[count].type = 0x02; + } else { + continue; + } + + rp->entry[count].bus = d->bus; + rp->entry[count++].index = cpu_to_le16(d->id); + bt_dev_dbg(hdev, "Added hci%u", d->id); + } + + rp->num_controllers = cpu_to_le16(count); + + read_unlock(&hci_dev_list_lock); + + /* If this command is called at least once, then all the + * default index and unconfigured index events are disabled + * and from now on only extended index events are used. + */ + hci_sock_set_flag(sk, HCI_MGMT_EXT_INDEX_EVENTS); + hci_sock_clear_flag(sk, HCI_MGMT_INDEX_EVENTS); + hci_sock_clear_flag(sk, HCI_MGMT_UNCONF_INDEX_EVENTS); + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, + MGMT_OP_READ_EXT_INDEX_LIST, 0, rp, + struct_size(rp, entry, count)); + + kfree(rp); + + return err; +} + +static bool is_configured(struct hci_dev *hdev) +{ + if (test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks) && + !hci_dev_test_flag(hdev, HCI_EXT_CONFIGURED)) + return false; + + if ((test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) || + test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) && + !bacmp(&hdev->public_addr, BDADDR_ANY)) + return false; + + return true; +} + +static __le32 get_missing_options(struct hci_dev *hdev) +{ + u32 options = 0; + + if (test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks) && + !hci_dev_test_flag(hdev, HCI_EXT_CONFIGURED)) + options |= MGMT_OPTION_EXTERNAL_CONFIG; + + if ((test_bit(HCI_QUIRK_INVALID_BDADDR, &hdev->quirks) || + test_bit(HCI_QUIRK_USE_BDADDR_PROPERTY, &hdev->quirks)) && + !bacmp(&hdev->public_addr, BDADDR_ANY)) + options |= MGMT_OPTION_PUBLIC_ADDRESS; + + return cpu_to_le32(options); +} + +static int new_options(struct hci_dev *hdev, struct sock *skip) +{ + __le32 options = get_missing_options(hdev); + + return mgmt_limited_event(MGMT_EV_NEW_CONFIG_OPTIONS, hdev, &options, + sizeof(options), HCI_MGMT_OPTION_EVENTS, skip); +} + +static int send_options_rsp(struct sock *sk, u16 opcode, struct hci_dev *hdev) +{ + __le32 options = get_missing_options(hdev); + + return mgmt_cmd_complete(sk, hdev->id, opcode, 0, &options, + sizeof(options)); +} + +static int read_config_info(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_config_info rp; + u32 options = 0; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + memset(&rp, 0, sizeof(rp)); + rp.manufacturer = cpu_to_le16(hdev->manufacturer); + + if (test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks)) + options |= MGMT_OPTION_EXTERNAL_CONFIG; + + if (hdev->set_bdaddr) + options |= MGMT_OPTION_PUBLIC_ADDRESS; + + rp.supported_options = cpu_to_le32(options); + rp.missing_options = get_missing_options(hdev); + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_CONFIG_INFO, 0, + &rp, sizeof(rp)); +} + +static u32 get_supported_phys(struct hci_dev *hdev) +{ + u32 supported_phys = 0; + + if (lmp_bredr_capable(hdev)) { + supported_phys |= MGMT_PHY_BR_1M_1SLOT; + + if (hdev->features[0][0] & LMP_3SLOT) + supported_phys |= MGMT_PHY_BR_1M_3SLOT; + + if (hdev->features[0][0] & LMP_5SLOT) + supported_phys |= MGMT_PHY_BR_1M_5SLOT; + + if (lmp_edr_2m_capable(hdev)) { + supported_phys |= MGMT_PHY_EDR_2M_1SLOT; + + if (lmp_edr_3slot_capable(hdev)) + supported_phys |= MGMT_PHY_EDR_2M_3SLOT; + + if 
(lmp_edr_5slot_capable(hdev)) + supported_phys |= MGMT_PHY_EDR_2M_5SLOT; + + if (lmp_edr_3m_capable(hdev)) { + supported_phys |= MGMT_PHY_EDR_3M_1SLOT; + + if (lmp_edr_3slot_capable(hdev)) + supported_phys |= MGMT_PHY_EDR_3M_3SLOT; + + if (lmp_edr_5slot_capable(hdev)) + supported_phys |= MGMT_PHY_EDR_3M_5SLOT; + } + } + } + + if (lmp_le_capable(hdev)) { + supported_phys |= MGMT_PHY_LE_1M_TX; + supported_phys |= MGMT_PHY_LE_1M_RX; + + if (hdev->le_features[1] & HCI_LE_PHY_2M) { + supported_phys |= MGMT_PHY_LE_2M_TX; + supported_phys |= MGMT_PHY_LE_2M_RX; + } + + if (hdev->le_features[1] & HCI_LE_PHY_CODED) { + supported_phys |= MGMT_PHY_LE_CODED_TX; + supported_phys |= MGMT_PHY_LE_CODED_RX; + } + } + + return supported_phys; +} + +static u32 get_selected_phys(struct hci_dev *hdev) +{ + u32 selected_phys = 0; + + if (lmp_bredr_capable(hdev)) { + selected_phys |= MGMT_PHY_BR_1M_1SLOT; + + if (hdev->pkt_type & (HCI_DM3 | HCI_DH3)) + selected_phys |= MGMT_PHY_BR_1M_3SLOT; + + if (hdev->pkt_type & (HCI_DM5 | HCI_DH5)) + selected_phys |= MGMT_PHY_BR_1M_5SLOT; + + if (lmp_edr_2m_capable(hdev)) { + if (!(hdev->pkt_type & HCI_2DH1)) + selected_phys |= MGMT_PHY_EDR_2M_1SLOT; + + if (lmp_edr_3slot_capable(hdev) && + !(hdev->pkt_type & HCI_2DH3)) + selected_phys |= MGMT_PHY_EDR_2M_3SLOT; + + if (lmp_edr_5slot_capable(hdev) && + !(hdev->pkt_type & HCI_2DH5)) + selected_phys |= MGMT_PHY_EDR_2M_5SLOT; + + if (lmp_edr_3m_capable(hdev)) { + if (!(hdev->pkt_type & HCI_3DH1)) + selected_phys |= MGMT_PHY_EDR_3M_1SLOT; + + if (lmp_edr_3slot_capable(hdev) && + !(hdev->pkt_type & HCI_3DH3)) + selected_phys |= MGMT_PHY_EDR_3M_3SLOT; + + if (lmp_edr_5slot_capable(hdev) && + !(hdev->pkt_type & HCI_3DH5)) + selected_phys |= MGMT_PHY_EDR_3M_5SLOT; + } + } + } + + if (lmp_le_capable(hdev)) { + if (hdev->le_tx_def_phys & HCI_LE_SET_PHY_1M) + selected_phys |= MGMT_PHY_LE_1M_TX; + + if (hdev->le_rx_def_phys & HCI_LE_SET_PHY_1M) + selected_phys |= MGMT_PHY_LE_1M_RX; + + if (hdev->le_tx_def_phys & HCI_LE_SET_PHY_2M) + selected_phys |= MGMT_PHY_LE_2M_TX; + + if (hdev->le_rx_def_phys & HCI_LE_SET_PHY_2M) + selected_phys |= MGMT_PHY_LE_2M_RX; + + if (hdev->le_tx_def_phys & HCI_LE_SET_PHY_CODED) + selected_phys |= MGMT_PHY_LE_CODED_TX; + + if (hdev->le_rx_def_phys & HCI_LE_SET_PHY_CODED) + selected_phys |= MGMT_PHY_LE_CODED_RX; + } + + return selected_phys; +} + +static u32 get_configurable_phys(struct hci_dev *hdev) +{ + return (get_supported_phys(hdev) & ~MGMT_PHY_BR_1M_1SLOT & + ~MGMT_PHY_LE_1M_TX & ~MGMT_PHY_LE_1M_RX); +} + +static u32 get_supported_settings(struct hci_dev *hdev) +{ + u32 settings = 0; + + settings |= MGMT_SETTING_POWERED; + settings |= MGMT_SETTING_BONDABLE; + settings |= MGMT_SETTING_DEBUG_KEYS; + settings |= MGMT_SETTING_CONNECTABLE; + settings |= MGMT_SETTING_DISCOVERABLE; + + if (lmp_bredr_capable(hdev)) { + if (hdev->hci_ver >= BLUETOOTH_VER_1_2) + settings |= MGMT_SETTING_FAST_CONNECTABLE; + settings |= MGMT_SETTING_BREDR; + settings |= MGMT_SETTING_LINK_SECURITY; + + if (lmp_ssp_capable(hdev)) { + settings |= MGMT_SETTING_SSP; + if (IS_ENABLED(CONFIG_BT_HS)) + settings |= MGMT_SETTING_HS; + } + + if (lmp_sc_capable(hdev)) + settings |= MGMT_SETTING_SECURE_CONN; + + if (test_bit(HCI_QUIRK_WIDEBAND_SPEECH_SUPPORTED, + &hdev->quirks)) + settings |= MGMT_SETTING_WIDEBAND_SPEECH; + } + + if (lmp_le_capable(hdev)) { + settings |= MGMT_SETTING_LE; + settings |= MGMT_SETTING_SECURE_CONN; + settings |= MGMT_SETTING_PRIVACY; + settings |= MGMT_SETTING_STATIC_ADDRESS; + settings |= MGMT_SETTING_ADVERTISING; + 
} + + if (test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks) || + hdev->set_bdaddr) + settings |= MGMT_SETTING_CONFIGURATION; + + if (cis_central_capable(hdev)) + settings |= MGMT_SETTING_CIS_CENTRAL; + + if (cis_peripheral_capable(hdev)) + settings |= MGMT_SETTING_CIS_PERIPHERAL; + + settings |= MGMT_SETTING_PHY_CONFIGURATION; + + return settings; +} + +static u32 get_current_settings(struct hci_dev *hdev) +{ + u32 settings = 0; + + if (hdev_is_powered(hdev)) + settings |= MGMT_SETTING_POWERED; + + if (hci_dev_test_flag(hdev, HCI_CONNECTABLE)) + settings |= MGMT_SETTING_CONNECTABLE; + + if (hci_dev_test_flag(hdev, HCI_FAST_CONNECTABLE)) + settings |= MGMT_SETTING_FAST_CONNECTABLE; + + if (hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) + settings |= MGMT_SETTING_DISCOVERABLE; + + if (hci_dev_test_flag(hdev, HCI_BONDABLE)) + settings |= MGMT_SETTING_BONDABLE; + + if (hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + settings |= MGMT_SETTING_BREDR; + + if (hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + settings |= MGMT_SETTING_LE; + + if (hci_dev_test_flag(hdev, HCI_LINK_SECURITY)) + settings |= MGMT_SETTING_LINK_SECURITY; + + if (hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + settings |= MGMT_SETTING_SSP; + + if (hci_dev_test_flag(hdev, HCI_HS_ENABLED)) + settings |= MGMT_SETTING_HS; + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) + settings |= MGMT_SETTING_ADVERTISING; + + if (hci_dev_test_flag(hdev, HCI_SC_ENABLED)) + settings |= MGMT_SETTING_SECURE_CONN; + + if (hci_dev_test_flag(hdev, HCI_KEEP_DEBUG_KEYS)) + settings |= MGMT_SETTING_DEBUG_KEYS; + + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) + settings |= MGMT_SETTING_PRIVACY; + + /* The current setting for static address has two purposes. The + * first is to indicate if the static address will be used and + * the second is to indicate if it is actually set. + * + * This means if the static address is not configured, this flag + * will never be set. If the address is configured, then if the + * address is actually used decides if the flag is set or not. + * + * For single mode LE only controllers and dual-mode controllers + * with BR/EDR disabled, the existence of the static address will + * be evaluated. + */ + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + !hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) || + !bacmp(&hdev->bdaddr, BDADDR_ANY)) { + if (bacmp(&hdev->static_addr, BDADDR_ANY)) + settings |= MGMT_SETTING_STATIC_ADDRESS; + } + + if (hci_dev_test_flag(hdev, HCI_WIDEBAND_SPEECH_ENABLED)) + settings |= MGMT_SETTING_WIDEBAND_SPEECH; + + if (cis_central_capable(hdev)) + settings |= MGMT_SETTING_CIS_CENTRAL; + + if (cis_peripheral_capable(hdev)) + settings |= MGMT_SETTING_CIS_PERIPHERAL; + + return settings; +} + +static struct mgmt_pending_cmd *pending_find(u16 opcode, struct hci_dev *hdev) +{ + return mgmt_pending_find(HCI_CHANNEL_CONTROL, opcode, hdev); +} + +u8 mgmt_get_adv_discov_flags(struct hci_dev *hdev) +{ + struct mgmt_pending_cmd *cmd; + + /* If there's a pending mgmt command the flags will not yet have + * their final values, so check for this first. 
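get_current_settings() above packs the controller state into the 32-bit settings word returned by MGMT_OP_READ_INFO and carried in new-settings events. A hedged client-side sketch of testing a few bits; the MGMT_SETTING_* values are assumed to match the kernel's mgmt.h definitions and only their symbolic meaning is relied on here:

#include <stdint.h>
#include <stdio.h>

/* Assumed to match the kernel's mgmt.h bit definitions. */
#define MGMT_SETTING_POWERED            0x00000001
#define MGMT_SETTING_CONNECTABLE        0x00000002
#define MGMT_SETTING_LE                 0x00000200

/* 'settings' is the u32 from a MGMT_OP_READ_INFO reply, already converted
 * from little endian to host order.
 */
static void dump_settings(uint32_t settings)
{
        printf("powered:     %s\n", settings & MGMT_SETTING_POWERED ? "yes" : "no");
        printf("connectable: %s\n", settings & MGMT_SETTING_CONNECTABLE ? "yes" : "no");
        printf("low energy:  %s\n", settings & MGMT_SETTING_LE ? "yes" : "no");
}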
+ */ + cmd = pending_find(MGMT_OP_SET_DISCOVERABLE, hdev); + if (cmd) { + struct mgmt_mode *cp = cmd->param; + if (cp->val == 0x01) + return LE_AD_GENERAL; + else if (cp->val == 0x02) + return LE_AD_LIMITED; + } else { + if (hci_dev_test_flag(hdev, HCI_LIMITED_DISCOVERABLE)) + return LE_AD_LIMITED; + else if (hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) + return LE_AD_GENERAL; + } + + return 0; +} + +bool mgmt_get_connectable(struct hci_dev *hdev) +{ + struct mgmt_pending_cmd *cmd; + + /* If there's a pending mgmt command the flag will not yet have + * it's final value, so check for this first. + */ + cmd = pending_find(MGMT_OP_SET_CONNECTABLE, hdev); + if (cmd) { + struct mgmt_mode *cp = cmd->param; + + return cp->val; + } + + return hci_dev_test_flag(hdev, HCI_CONNECTABLE); +} + +static int service_cache_sync(struct hci_dev *hdev, void *data) +{ + hci_update_eir_sync(hdev); + hci_update_class_sync(hdev); + + return 0; +} + +static void service_cache_off(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + service_cache.work); + + if (!hci_dev_test_and_clear_flag(hdev, HCI_SERVICE_CACHE)) + return; + + hci_cmd_sync_queue(hdev, service_cache_sync, NULL, NULL); +} + +static int rpa_expired_sync(struct hci_dev *hdev, void *data) +{ + /* The generation of a new RPA and programming it into the + * controller happens in the hci_req_enable_advertising() + * function. + */ + if (ext_adv_capable(hdev)) + return hci_start_ext_adv_sync(hdev, hdev->cur_adv_instance); + else + return hci_enable_advertising_sync(hdev); +} + +static void rpa_expired(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + rpa_expired.work); + + bt_dev_dbg(hdev, ""); + + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + + if (!hci_dev_test_flag(hdev, HCI_ADVERTISING)) + return; + + hci_cmd_sync_queue(hdev, rpa_expired_sync, NULL, NULL); +} + +static void discov_off(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + discov_off.work); + + bt_dev_dbg(hdev, ""); + + hci_dev_lock(hdev); + + /* When discoverable timeout triggers, then just make sure + * the limited discoverable flag is cleared. Even in the case + * of a timeout triggered from general discoverable, it is + * safe to unconditionally clear the flag. 
+ */ + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + hdev->discov_timeout = 0; + + hci_update_discoverable(hdev); + + mgmt_new_settings(hdev); + + hci_dev_unlock(hdev); +} + +static int send_settings_rsp(struct sock *sk, u16 opcode, struct hci_dev *hdev); + +static void mesh_send_complete(struct hci_dev *hdev, + struct mgmt_mesh_tx *mesh_tx, bool silent) +{ + u8 handle = mesh_tx->handle; + + if (!silent) + mgmt_event(MGMT_EV_MESH_PACKET_CMPLT, hdev, &handle, + sizeof(handle), NULL); + + mgmt_mesh_remove(mesh_tx); +} + +static int mesh_send_done_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_mesh_tx *mesh_tx; + + hci_dev_clear_flag(hdev, HCI_MESH_SENDING); + hci_disable_advertising_sync(hdev); + mesh_tx = mgmt_mesh_next(hdev, NULL); + + if (mesh_tx) + mesh_send_complete(hdev, mesh_tx, false); + + return 0; +} + +static int mesh_send_sync(struct hci_dev *hdev, void *data); +static void mesh_send_start_complete(struct hci_dev *hdev, void *data, int err); +static void mesh_next(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_mesh_tx *mesh_tx = mgmt_mesh_next(hdev, NULL); + + if (!mesh_tx) + return; + + err = hci_cmd_sync_queue(hdev, mesh_send_sync, mesh_tx, + mesh_send_start_complete); + + if (err < 0) + mesh_send_complete(hdev, mesh_tx, false); + else + hci_dev_set_flag(hdev, HCI_MESH_SENDING); +} + +static void mesh_send_done(struct work_struct *work) +{ + struct hci_dev *hdev = container_of(work, struct hci_dev, + mesh_send_done.work); + + if (!hci_dev_test_flag(hdev, HCI_MESH_SENDING)) + return; + + hci_cmd_sync_queue(hdev, mesh_send_done_sync, NULL, mesh_next); +} + +static void mgmt_init_hdev(struct sock *sk, struct hci_dev *hdev) +{ + if (hci_dev_test_flag(hdev, HCI_MGMT)) + return; + + BT_INFO("MGMT ver %d.%d", MGMT_VERSION, MGMT_REVISION); + + INIT_DELAYED_WORK(&hdev->discov_off, discov_off); + INIT_DELAYED_WORK(&hdev->service_cache, service_cache_off); + INIT_DELAYED_WORK(&hdev->rpa_expired, rpa_expired); + INIT_DELAYED_WORK(&hdev->mesh_send_done, mesh_send_done); + + /* Non-mgmt controlled devices get this bit set + * implicitly so that pairing works for them, however + * for mgmt we require user-space to explicitly enable + * it + */ + hci_dev_clear_flag(hdev, HCI_BONDABLE); + + hci_dev_set_flag(hdev, HCI_MGMT); +} + +static int read_controller_info(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_info rp; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + memset(&rp, 0, sizeof(rp)); + + bacpy(&rp.bdaddr, &hdev->bdaddr); + + rp.version = hdev->hci_ver; + rp.manufacturer = cpu_to_le16(hdev->manufacturer); + + rp.supported_settings = cpu_to_le32(get_supported_settings(hdev)); + rp.current_settings = cpu_to_le32(get_current_settings(hdev)); + + memcpy(rp.dev_class, hdev->dev_class, 3); + + memcpy(rp.name, hdev->dev_name, sizeof(hdev->dev_name)); + memcpy(rp.short_name, hdev->short_name, sizeof(hdev->short_name)); + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_INFO, 0, &rp, + sizeof(rp)); +} + +static u16 append_eir_data_to_buf(struct hci_dev *hdev, u8 *eir) +{ + u16 eir_len = 0; + size_t name_len; + + if (hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + eir_len = eir_append_data(eir, eir_len, EIR_CLASS_OF_DEV, + hdev->dev_class, 3); + + if (hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + eir_len = eir_append_le16(eir, eir_len, EIR_APPEARANCE, + hdev->appearance); + + name_len = strnlen(hdev->dev_name, sizeof(hdev->dev_name)); + 
eir_len = eir_append_data(eir, eir_len, EIR_NAME_COMPLETE, + hdev->dev_name, name_len); + + name_len = strnlen(hdev->short_name, sizeof(hdev->short_name)); + eir_len = eir_append_data(eir, eir_len, EIR_NAME_SHORT, + hdev->short_name, name_len); + + return eir_len; +} + +static int read_ext_controller_info(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + char buf[512]; + struct mgmt_rp_read_ext_info *rp = (void *)buf; + u16 eir_len; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&buf, 0, sizeof(buf)); + + hci_dev_lock(hdev); + + bacpy(&rp->bdaddr, &hdev->bdaddr); + + rp->version = hdev->hci_ver; + rp->manufacturer = cpu_to_le16(hdev->manufacturer); + + rp->supported_settings = cpu_to_le32(get_supported_settings(hdev)); + rp->current_settings = cpu_to_le32(get_current_settings(hdev)); + + + eir_len = append_eir_data_to_buf(hdev, rp->eir); + rp->eir_len = cpu_to_le16(eir_len); + + hci_dev_unlock(hdev); + + /* If this command is called at least once, then the events + * for class of device and local name changes are disabled + * and only the new extended controller information event + * is used. + */ + hci_sock_set_flag(sk, HCI_MGMT_EXT_INFO_EVENTS); + hci_sock_clear_flag(sk, HCI_MGMT_DEV_CLASS_EVENTS); + hci_sock_clear_flag(sk, HCI_MGMT_LOCAL_NAME_EVENTS); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_EXT_INFO, 0, rp, + sizeof(*rp) + eir_len); +} + +static int ext_info_changed(struct hci_dev *hdev, struct sock *skip) +{ + char buf[512]; + struct mgmt_ev_ext_info_changed *ev = (void *)buf; + u16 eir_len; + + memset(buf, 0, sizeof(buf)); + + eir_len = append_eir_data_to_buf(hdev, ev->eir); + ev->eir_len = cpu_to_le16(eir_len); + + return mgmt_limited_event(MGMT_EV_EXT_INFO_CHANGED, hdev, ev, + sizeof(*ev) + eir_len, + HCI_MGMT_EXT_INFO_EVENTS, skip); +} + +static int send_settings_rsp(struct sock *sk, u16 opcode, struct hci_dev *hdev) +{ + __le32 settings = cpu_to_le32(get_current_settings(hdev)); + + return mgmt_cmd_complete(sk, hdev->id, opcode, 0, &settings, + sizeof(settings)); +} + +void mgmt_advertising_added(struct sock *sk, struct hci_dev *hdev, u8 instance) +{ + struct mgmt_ev_advertising_added ev; + + ev.instance = instance; + + mgmt_event(MGMT_EV_ADVERTISING_ADDED, hdev, &ev, sizeof(ev), sk); +} + +void mgmt_advertising_removed(struct sock *sk, struct hci_dev *hdev, + u8 instance) +{ + struct mgmt_ev_advertising_removed ev; + + ev.instance = instance; + + mgmt_event(MGMT_EV_ADVERTISING_REMOVED, hdev, &ev, sizeof(ev), sk); +} + +static void cancel_adv_timeout(struct hci_dev *hdev) +{ + if (hdev->adv_instance_timeout) { + hdev->adv_instance_timeout = 0; + cancel_delayed_work(&hdev->adv_instance_expire); + } +} + +/* This function requires the caller holds hdev->lock */ +static void restart_le_actions(struct hci_dev *hdev) +{ + struct hci_conn_params *p; + + list_for_each_entry(p, &hdev->le_conn_params, list) { + /* Needed for AUTO_OFF case where might not "really" + * have been powered off. 
+ */ + hci_pend_le_list_del_init(p); + + switch (p->auto_connect) { + case HCI_AUTO_CONN_DIRECT: + case HCI_AUTO_CONN_ALWAYS: + hci_pend_le_list_add(p, &hdev->pend_le_conns); + break; + case HCI_AUTO_CONN_REPORT: + hci_pend_le_list_add(p, &hdev->pend_le_reports); + break; + default: + break; + } + } +} + +static int new_settings(struct hci_dev *hdev, struct sock *skip) +{ + __le32 ev = cpu_to_le32(get_current_settings(hdev)); + + return mgmt_limited_event(MGMT_EV_NEW_SETTINGS, hdev, &ev, + sizeof(ev), HCI_MGMT_SETTING_EVENTS, skip); +} + +static void mgmt_set_powered_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp; + + /* Make sure cmd still outstanding. */ + if (cmd != pending_find(MGMT_OP_SET_POWERED, hdev)) + return; + + cp = cmd->param; + + bt_dev_dbg(hdev, "err %d", err); + + if (!err) { + if (cp->val) { + hci_dev_lock(hdev); + restart_le_actions(hdev); + hci_update_passive_scan(hdev); + hci_dev_unlock(hdev); + } + + send_settings_rsp(cmd->sk, cmd->opcode, hdev); + + /* Only call new_setting for power on as power off is deferred + * to hdev->power_off work which does call hci_dev_do_close. + */ + if (cp->val) + new_settings(hdev, cmd->sk); + } else { + mgmt_cmd_status(cmd->sk, hdev->id, MGMT_OP_SET_POWERED, + mgmt_status(err)); + } + + mgmt_pending_remove(cmd); +} + +static int set_powered_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + + BT_DBG("%s", hdev->name); + + return hci_set_powered_sync(hdev, cp->val); +} + +static int set_powered(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_POWERED, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (pending_find(MGMT_OP_SET_POWERED, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_POWERED, + MGMT_STATUS_BUSY); + goto failed; + } + + if (!!cp->val == hdev_is_powered(hdev)) { + err = send_settings_rsp(sk, MGMT_OP_SET_POWERED, hdev); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_POWERED, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + err = hci_cmd_sync_queue(hdev, set_powered_sync, cmd, + mgmt_set_powered_complete); + + if (err < 0) + mgmt_pending_remove(cmd); + +failed: + hci_dev_unlock(hdev); + return err; +} + +int mgmt_new_settings(struct hci_dev *hdev) +{ + return new_settings(hdev, NULL); +} + +struct cmd_lookup { + struct sock *sk; + struct hci_dev *hdev; + u8 mgmt_status; +}; + +static void settings_rsp(struct mgmt_pending_cmd *cmd, void *data) +{ + struct cmd_lookup *match = data; + + send_settings_rsp(cmd->sk, cmd->opcode, match->hdev); + + list_del(&cmd->list); + + if (match->sk == NULL) { + match->sk = cmd->sk; + sock_hold(match->sk); + } + + mgmt_pending_free(cmd); +} + +static void cmd_status_rsp(struct mgmt_pending_cmd *cmd, void *data) +{ + u8 *status = data; + + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, *status); + mgmt_pending_remove(cmd); +} + +static void cmd_complete_rsp(struct mgmt_pending_cmd *cmd, void *data) +{ + if (cmd->cmd_complete) { + u8 *status = data; + + cmd->cmd_complete(cmd, *status); + mgmt_pending_remove(cmd); + + return; + } + + cmd_status_rsp(cmd, data); +} + +static int generic_cmd_complete(struct mgmt_pending_cmd *cmd, u8 status) +{ + return mgmt_cmd_complete(cmd->sk, 
cmd->index, cmd->opcode, status, + cmd->param, cmd->param_len); +} + +static int addr_cmd_complete(struct mgmt_pending_cmd *cmd, u8 status) +{ + return mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, status, + cmd->param, sizeof(struct mgmt_addr_info)); +} + +static u8 mgmt_bredr_support(struct hci_dev *hdev) +{ + if (!lmp_bredr_capable(hdev)) + return MGMT_STATUS_NOT_SUPPORTED; + else if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return MGMT_STATUS_REJECTED; + else + return MGMT_STATUS_SUCCESS; +} + +static u8 mgmt_le_support(struct hci_dev *hdev) +{ + if (!lmp_le_capable(hdev)) + return MGMT_STATUS_NOT_SUPPORTED; + else if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return MGMT_STATUS_REJECTED; + else + return MGMT_STATUS_SUCCESS; +} + +static void mgmt_set_discoverable_complete(struct hci_dev *hdev, void *data, + int err) +{ + struct mgmt_pending_cmd *cmd = data; + + bt_dev_dbg(hdev, "err %d", err); + + /* Make sure cmd still outstanding. */ + if (cmd != pending_find(MGMT_OP_SET_DISCOVERABLE, hdev)) + return; + + hci_dev_lock(hdev); + + if (err) { + u8 mgmt_err = mgmt_status(err); + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, mgmt_err); + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_DISCOVERABLE) && + hdev->discov_timeout > 0) { + int to = msecs_to_jiffies(hdev->discov_timeout * 1000); + queue_delayed_work(hdev->req_workqueue, &hdev->discov_off, to); + } + + send_settings_rsp(cmd->sk, MGMT_OP_SET_DISCOVERABLE, hdev); + new_settings(hdev, cmd->sk); + +done: + mgmt_pending_remove(cmd); + hci_dev_unlock(hdev); +} + +static int set_discoverable_sync(struct hci_dev *hdev, void *data) +{ + BT_DBG("%s", hdev->name); + + return hci_update_discoverable_sync(hdev); +} + +static int set_discoverable(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_discoverable *cp = data; + struct mgmt_pending_cmd *cmd; + u16 timeout; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED) && + !hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_REJECTED); + + if (cp->val != 0x00 && cp->val != 0x01 && cp->val != 0x02) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_INVALID_PARAMS); + + timeout = __le16_to_cpu(cp->timeout); + + /* Disabling discoverable requires that no timeout is set, + * and enabling limited discoverable requires a timeout. 
+ */ + if ((cp->val == 0x00 && timeout > 0) || + (cp->val == 0x02 && timeout == 0)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev) && timeout > 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_NOT_POWERED); + goto failed; + } + + if (pending_find(MGMT_OP_SET_DISCOVERABLE, hdev) || + pending_find(MGMT_OP_SET_CONNECTABLE, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_BUSY); + goto failed; + } + + if (!hci_dev_test_flag(hdev, HCI_CONNECTABLE)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_REJECTED); + goto failed; + } + + if (hdev->advertising_paused) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DISCOVERABLE, + MGMT_STATUS_BUSY); + goto failed; + } + + if (!hdev_is_powered(hdev)) { + bool changed = false; + + /* Setting limited discoverable when powered off is + * not a valid operation since it requires a timeout + * and so no need to check HCI_LIMITED_DISCOVERABLE. + */ + if (!!cp->val != hci_dev_test_flag(hdev, HCI_DISCOVERABLE)) { + hci_dev_change_flag(hdev, HCI_DISCOVERABLE); + changed = true; + } + + err = send_settings_rsp(sk, MGMT_OP_SET_DISCOVERABLE, hdev); + if (err < 0) + goto failed; + + if (changed) + err = new_settings(hdev, sk); + + goto failed; + } + + /* If the current mode is the same, then just update the timeout + * value with the new value. And if only the timeout gets updated, + * then no need for any HCI transactions. + */ + if (!!cp->val == hci_dev_test_flag(hdev, HCI_DISCOVERABLE) && + (cp->val == 0x02) == hci_dev_test_flag(hdev, + HCI_LIMITED_DISCOVERABLE)) { + cancel_delayed_work(&hdev->discov_off); + hdev->discov_timeout = timeout; + + if (cp->val && hdev->discov_timeout > 0) { + int to = msecs_to_jiffies(hdev->discov_timeout * 1000); + queue_delayed_work(hdev->req_workqueue, + &hdev->discov_off, to); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_DISCOVERABLE, hdev); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_DISCOVERABLE, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + /* Cancel any potential discoverable timeout that might be + * still active and store new timeout value. The arming of + * the timeout happens in the complete handler. + */ + cancel_delayed_work(&hdev->discov_off); + hdev->discov_timeout = timeout; + + if (cp->val) + hci_dev_set_flag(hdev, HCI_DISCOVERABLE); + else + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + + /* Limited discoverable mode */ + if (cp->val == 0x02) + hci_dev_set_flag(hdev, HCI_LIMITED_DISCOVERABLE); + else + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + + err = hci_cmd_sync_queue(hdev, set_discoverable_sync, cmd, + mgmt_set_discoverable_complete); + + if (err < 0) + mgmt_pending_remove(cmd); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static void mgmt_set_connectable_complete(struct hci_dev *hdev, void *data, + int err) +{ + struct mgmt_pending_cmd *cmd = data; + + bt_dev_dbg(hdev, "err %d", err); + + /* Make sure cmd still outstanding. 
*/ + if (cmd != pending_find(MGMT_OP_SET_CONNECTABLE, hdev)) + return; + + hci_dev_lock(hdev); + + if (err) { + u8 mgmt_err = mgmt_status(err); + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, mgmt_err); + goto done; + } + + send_settings_rsp(cmd->sk, MGMT_OP_SET_CONNECTABLE, hdev); + new_settings(hdev, cmd->sk); + +done: + if (cmd) + mgmt_pending_remove(cmd); + + hci_dev_unlock(hdev); +} + +static int set_connectable_update_settings(struct hci_dev *hdev, + struct sock *sk, u8 val) +{ + bool changed = false; + int err; + + if (!!val != hci_dev_test_flag(hdev, HCI_CONNECTABLE)) + changed = true; + + if (val) { + hci_dev_set_flag(hdev, HCI_CONNECTABLE); + } else { + hci_dev_clear_flag(hdev, HCI_CONNECTABLE); + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_CONNECTABLE, hdev); + if (err < 0) + return err; + + if (changed) { + hci_update_scan(hdev); + hci_update_passive_scan(hdev); + return new_settings(hdev, sk); + } + + return 0; +} + +static int set_connectable_sync(struct hci_dev *hdev, void *data) +{ + BT_DBG("%s", hdev->name); + + return hci_update_connectable_sync(hdev); +} + +static int set_connectable(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED) && + !hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_CONNECTABLE, + MGMT_STATUS_REJECTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_CONNECTABLE, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = set_connectable_update_settings(hdev, sk, cp->val); + goto failed; + } + + if (pending_find(MGMT_OP_SET_DISCOVERABLE, hdev) || + pending_find(MGMT_OP_SET_CONNECTABLE, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_CONNECTABLE, + MGMT_STATUS_BUSY); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_CONNECTABLE, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + if (cp->val) { + hci_dev_set_flag(hdev, HCI_CONNECTABLE); + } else { + if (hdev->discov_timeout > 0) + cancel_delayed_work(&hdev->discov_off); + + hci_dev_clear_flag(hdev, HCI_LIMITED_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_CONNECTABLE); + } + + err = hci_cmd_sync_queue(hdev, set_connectable_sync, cmd, + mgmt_set_connectable_complete); + + if (err < 0) + mgmt_pending_remove(cmd); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int set_bondable(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_mode *cp = data; + bool changed; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BONDABLE, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (cp->val) + changed = !hci_dev_test_and_set_flag(hdev, HCI_BONDABLE); + else + changed = hci_dev_test_and_clear_flag(hdev, HCI_BONDABLE); + + err = send_settings_rsp(sk, MGMT_OP_SET_BONDABLE, hdev); + if (err < 0) + goto unlock; + + if (changed) { + /* In limited privacy mode the change of bondable mode + * may affect the local advertising address. 
+ */ + hci_update_discoverable(hdev); + + err = new_settings(hdev, sk); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_link_security(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + u8 val, status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + status = mgmt_bredr_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LINK_SECURITY, + status); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LINK_SECURITY, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + bool changed = false; + + if (!!cp->val != hci_dev_test_flag(hdev, HCI_LINK_SECURITY)) { + hci_dev_change_flag(hdev, HCI_LINK_SECURITY); + changed = true; + } + + err = send_settings_rsp(sk, MGMT_OP_SET_LINK_SECURITY, hdev); + if (err < 0) + goto failed; + + if (changed) + err = new_settings(hdev, sk); + + goto failed; + } + + if (pending_find(MGMT_OP_SET_LINK_SECURITY, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LINK_SECURITY, + MGMT_STATUS_BUSY); + goto failed; + } + + val = !!cp->val; + + if (test_bit(HCI_AUTH, &hdev->flags) == val) { + err = send_settings_rsp(sk, MGMT_OP_SET_LINK_SECURITY, hdev); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_LINK_SECURITY, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + err = hci_send_cmd(hdev, HCI_OP_WRITE_AUTH_ENABLE, sizeof(val), &val); + if (err < 0) { + mgmt_pending_remove(cmd); + goto failed; + } + +failed: + hci_dev_unlock(hdev); + return err; +} + +static void set_ssp_complete(struct hci_dev *hdev, void *data, int err) +{ + struct cmd_lookup match = { NULL, hdev }; + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + u8 enable = cp->val; + bool changed; + + /* Make sure cmd still outstanding. 
*/ + if (cmd != pending_find(MGMT_OP_SET_SSP, hdev)) + return; + + if (err) { + u8 mgmt_err = mgmt_status(err); + + if (enable && hci_dev_test_and_clear_flag(hdev, + HCI_SSP_ENABLED)) { + hci_dev_clear_flag(hdev, HCI_HS_ENABLED); + new_settings(hdev, NULL); + } + + mgmt_pending_foreach(MGMT_OP_SET_SSP, hdev, cmd_status_rsp, + &mgmt_err); + return; + } + + if (enable) { + changed = !hci_dev_test_and_set_flag(hdev, HCI_SSP_ENABLED); + } else { + changed = hci_dev_test_and_clear_flag(hdev, HCI_SSP_ENABLED); + + if (!changed) + changed = hci_dev_test_and_clear_flag(hdev, + HCI_HS_ENABLED); + else + hci_dev_clear_flag(hdev, HCI_HS_ENABLED); + } + + mgmt_pending_foreach(MGMT_OP_SET_SSP, hdev, settings_rsp, &match); + + if (changed) + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); + + hci_update_eir_sync(hdev); +} + +static int set_ssp_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + bool changed = false; + int err; + + if (cp->val) + changed = !hci_dev_test_and_set_flag(hdev, HCI_SSP_ENABLED); + + err = hci_write_ssp_mode_sync(hdev, cp->val); + + if (!err && changed) + hci_dev_clear_flag(hdev, HCI_SSP_ENABLED); + + return err; +} + +static int set_ssp(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + status = mgmt_bredr_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SSP, status); + + if (!lmp_ssp_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SSP, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SSP, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + bool changed; + + if (cp->val) { + changed = !hci_dev_test_and_set_flag(hdev, + HCI_SSP_ENABLED); + } else { + changed = hci_dev_test_and_clear_flag(hdev, + HCI_SSP_ENABLED); + if (!changed) + changed = hci_dev_test_and_clear_flag(hdev, + HCI_HS_ENABLED); + else + hci_dev_clear_flag(hdev, HCI_HS_ENABLED); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_SSP, hdev); + if (err < 0) + goto failed; + + if (changed) + err = new_settings(hdev, sk); + + goto failed; + } + + if (pending_find(MGMT_OP_SET_SSP, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SSP, + MGMT_STATUS_BUSY); + goto failed; + } + + if (!!cp->val == hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) { + err = send_settings_rsp(sk, MGMT_OP_SET_SSP, hdev); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_SSP, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_ssp_sync, cmd, + set_ssp_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SSP, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_remove(cmd); + } + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int set_hs(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_mode *cp = data; + bool changed; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!IS_ENABLED(CONFIG_BT_HS)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_NOT_SUPPORTED); + + status = mgmt_bredr_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, status); + + if (!lmp_ssp_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_NOT_SUPPORTED); + + if 
(!hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_REJECTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (pending_find(MGMT_OP_SET_SSP, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_BUSY); + goto unlock; + } + + if (cp->val) { + changed = !hci_dev_test_and_set_flag(hdev, HCI_HS_ENABLED); + } else { + if (hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_HS, + MGMT_STATUS_REJECTED); + goto unlock; + } + + changed = hci_dev_test_and_clear_flag(hdev, HCI_HS_ENABLED); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_HS, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static void set_le_complete(struct hci_dev *hdev, void *data, int err) +{ + struct cmd_lookup match = { NULL, hdev }; + u8 status = mgmt_status(err); + + bt_dev_dbg(hdev, "err %d", err); + + if (status) { + mgmt_pending_foreach(MGMT_OP_SET_LE, hdev, cmd_status_rsp, + &status); + return; + } + + mgmt_pending_foreach(MGMT_OP_SET_LE, hdev, settings_rsp, &match); + + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); +} + +static int set_le_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + u8 val = !!cp->val; + int err; + + if (!val) { + hci_clear_adv_instance_sync(hdev, NULL, 0x00, true); + + if (hci_dev_test_flag(hdev, HCI_LE_ADV)) + hci_disable_advertising_sync(hdev); + + if (ext_adv_capable(hdev)) + hci_remove_ext_adv_instance_sync(hdev, 0, cmd->sk); + } else { + hci_dev_set_flag(hdev, HCI_LE_ENABLED); + } + + err = hci_write_le_host_supported_sync(hdev, val, 0); + + /* Make sure the controller has a good default for + * advertising data. Restrict the update to when LE + * has actually been enabled. During power on, the + * update in powered_update_hci will take care of it. 
+ */ + if (!err && hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + if (ext_adv_capable(hdev)) { + int status; + + status = hci_setup_ext_adv_instance_sync(hdev, 0x00); + if (!status) + hci_update_scan_rsp_data_sync(hdev, 0x00); + } else { + hci_update_adv_data_sync(hdev, 0x00); + hci_update_scan_rsp_data_sync(hdev, 0x00); + } + + hci_update_passive_scan(hdev); + } + + return err; +} + +static void set_mesh_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + u8 status = mgmt_status(err); + struct sock *sk = cmd->sk; + + if (status) { + mgmt_pending_foreach(MGMT_OP_SET_MESH_RECEIVER, hdev, + cmd_status_rsp, &status); + return; + } + + mgmt_pending_remove(cmd); + mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_MESH_RECEIVER, 0, NULL, 0); +} + +static int set_mesh_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_set_mesh *cp = cmd->param; + size_t len = cmd->param_len; + + memset(hdev->mesh_ad_types, 0, sizeof(hdev->mesh_ad_types)); + + if (cp->enable) + hci_dev_set_flag(hdev, HCI_MESH); + else + hci_dev_clear_flag(hdev, HCI_MESH); + + len -= sizeof(*cp); + + /* If filters don't fit, forward all adv pkts */ + if (len <= sizeof(hdev->mesh_ad_types)) + memcpy(hdev->mesh_ad_types, cp->ad_types, len); + + hci_update_passive_scan_sync(hdev); + return 0; +} + +static int set_mesh(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_cp_set_mesh *cp = data; + struct mgmt_pending_cmd *cmd; + int err = 0; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev) || + !hci_dev_test_flag(hdev, HCI_MESH_EXPERIMENTAL)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_MESH_RECEIVER, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->enable != 0x00 && cp->enable != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_MESH_RECEIVER, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_MESH_RECEIVER, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_mesh_sync, cmd, + set_mesh_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_MESH_RECEIVER, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_remove(cmd); + } + + hci_dev_unlock(hdev); + return err; +} + +static void mesh_send_start_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_mesh_tx *mesh_tx = data; + struct mgmt_cp_mesh_send *send = (void *)mesh_tx->param; + unsigned long mesh_send_interval; + u8 mgmt_err = mgmt_status(err); + + /* Report any errors here, but don't report completion */ + + if (mgmt_err) { + hci_dev_clear_flag(hdev, HCI_MESH_SENDING); + /* Send Complete Error Code for handle */ + mesh_send_complete(hdev, mesh_tx, false); + return; + } + + mesh_send_interval = msecs_to_jiffies((send->cnt) * 25); + queue_delayed_work(hdev->req_workqueue, &hdev->mesh_send_done, + mesh_send_interval); +} + +static int mesh_send_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_mesh_tx *mesh_tx = data; + struct mgmt_cp_mesh_send *send = (void *)mesh_tx->param; + struct adv_info *adv, *next_instance; + u8 instance = hdev->le_num_of_adv_sets + 1; + u16 timeout, duration; + int err = 0; + + if (hdev->le_num_of_adv_sets <= hdev->adv_instance_cnt) + return MGMT_STATUS_BUSY; + + timeout = 1000; + duration = send->cnt * INTERVAL_TO_MS(hdev->le_adv_max_interval); + adv = hci_add_adv_instance(hdev, instance, 0, + send->adv_data_len, send->adv_data, + 0, NULL, + timeout, duration, + HCI_ADV_TX_POWER_NO_PREFERENCE, + 
hdev->le_adv_min_interval, + hdev->le_adv_max_interval, + mesh_tx->handle); + + if (!IS_ERR(adv)) + mesh_tx->instance = instance; + else + err = PTR_ERR(adv); + + if (hdev->cur_adv_instance == instance) { + /* If the currently advertised instance is being changed then + * cancel the current advertising and schedule the next + * instance. If there is only one instance then the overridden + * advertising data will be visible right away. + */ + cancel_adv_timeout(hdev); + + next_instance = hci_get_next_instance(hdev, instance); + if (next_instance) + instance = next_instance->instance; + else + instance = 0; + } else if (hdev->adv_instance_timeout) { + /* Immediately advertise the new instance if no other, or + * let it go naturally from queue if ADV is already happening + */ + instance = 0; + } + + if (instance) + return hci_schedule_adv_instance_sync(hdev, instance, true); + + return err; +} + +static void send_count(struct mgmt_mesh_tx *mesh_tx, void *data) +{ + struct mgmt_rp_mesh_read_features *rp = data; + + if (rp->used_handles >= rp->max_handles) + return; + + rp->handles[rp->used_handles++] = mesh_tx->handle; +} + +static int mesh_features(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_rp_mesh_read_features rp; + + if (!lmp_le_capable(hdev) || + !hci_dev_test_flag(hdev, HCI_MESH_EXPERIMENTAL)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_READ_FEATURES, + MGMT_STATUS_NOT_SUPPORTED); + + memset(&rp, 0, sizeof(rp)); + rp.index = cpu_to_le16(hdev->id); + if (hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + rp.max_handles = MESH_HANDLES_MAX; + + hci_dev_lock(hdev); + + if (rp.max_handles) + mgmt_mesh_foreach(hdev, send_count, &rp, sk); + + mgmt_cmd_complete(sk, hdev->id, MGMT_OP_MESH_READ_FEATURES, 0, &rp, + rp.used_handles + sizeof(rp) - MESH_HANDLES_MAX); + + hci_dev_unlock(hdev); + return 0; +} + +static int send_cancel(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_mesh_send_cancel *cancel = (void *)cmd->param; + struct mgmt_mesh_tx *mesh_tx; + + if (!cancel->handle) { + do { + mesh_tx = mgmt_mesh_next(hdev, cmd->sk); + + if (mesh_tx) + mesh_send_complete(hdev, mesh_tx, false); + } while (mesh_tx); + } else { + mesh_tx = mgmt_mesh_find(hdev, cancel->handle); + + if (mesh_tx && mesh_tx->sk == cmd->sk) + mesh_send_complete(hdev, mesh_tx, false); + } + + mgmt_cmd_complete(cmd->sk, hdev->id, MGMT_OP_MESH_SEND_CANCEL, + 0, NULL, 0); + mgmt_pending_free(cmd); + + return 0; +} + +static int mesh_send_cancel(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_pending_cmd *cmd; + int err; + + if (!lmp_le_capable(hdev) || + !hci_dev_test_flag(hdev, HCI_MESH_EXPERIMENTAL)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND_CANCEL, + MGMT_STATUS_NOT_SUPPORTED); + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND_CANCEL, + MGMT_STATUS_REJECTED); + + hci_dev_lock(hdev); + cmd = mgmt_pending_new(sk, MGMT_OP_MESH_SEND_CANCEL, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, send_cancel, cmd, NULL); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND_CANCEL, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_free(cmd); + } + + hci_dev_unlock(hdev); + return err; +} + +static int mesh_send(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_mesh_tx *mesh_tx; + struct mgmt_cp_mesh_send *send = data; + struct mgmt_rp_mesh_read_features rp; + bool sending; 
+ int err = 0; + + if (!lmp_le_capable(hdev) || + !hci_dev_test_flag(hdev, HCI_MESH_EXPERIMENTAL)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND, + MGMT_STATUS_NOT_SUPPORTED); + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED) || + len <= MGMT_MESH_SEND_SIZE || + len > (MGMT_MESH_SEND_SIZE + 31)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND, + MGMT_STATUS_REJECTED); + + hci_dev_lock(hdev); + + memset(&rp, 0, sizeof(rp)); + rp.max_handles = MESH_HANDLES_MAX; + + mgmt_mesh_foreach(hdev, send_count, &rp, sk); + + if (rp.max_handles <= rp.used_handles) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND, + MGMT_STATUS_BUSY); + goto done; + } + + sending = hci_dev_test_flag(hdev, HCI_MESH_SENDING); + mesh_tx = mgmt_mesh_add(sk, hdev, send, len); + + if (!mesh_tx) + err = -ENOMEM; + else if (!sending) + err = hci_cmd_sync_queue(hdev, mesh_send_sync, mesh_tx, + mesh_send_start_complete); + + if (err < 0) { + bt_dev_err(hdev, "Send Mesh Failed %d", err); + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_MESH_SEND, + MGMT_STATUS_FAILED); + + if (mesh_tx) { + if (sending) + mgmt_mesh_remove(mesh_tx); + } + } else { + hci_dev_set_flag(hdev, HCI_MESH_SENDING); + + mgmt_cmd_complete(sk, hdev->id, MGMT_OP_MESH_SEND, 0, + &mesh_tx->handle, 1); + } + +done: + hci_dev_unlock(hdev); + return err; +} + +static int set_le(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + u8 val, enabled; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LE, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LE, + MGMT_STATUS_INVALID_PARAMS); + + /* Bluetooth single mode LE only controllers or dual-mode + * controllers configured as LE only devices, do not allow + * switching LE off. These have either LE enabled explicitly + * or BR/EDR has been previously switched off. + * + * When trying to enable an already enabled LE, then gracefully + * send a positive response. Trying to disable it however will + * result into rejection. 
+ */ + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + if (cp->val == 0x01) + return send_settings_rsp(sk, MGMT_OP_SET_LE, hdev); + + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LE, + MGMT_STATUS_REJECTED); + } + + hci_dev_lock(hdev); + + val = !!cp->val; + enabled = lmp_host_le_capable(hdev); + + if (!hdev_is_powered(hdev) || val == enabled) { + bool changed = false; + + if (val != hci_dev_test_flag(hdev, HCI_LE_ENABLED)) { + hci_dev_change_flag(hdev, HCI_LE_ENABLED); + changed = true; + } + + if (!val && hci_dev_test_flag(hdev, HCI_ADVERTISING)) { + hci_dev_clear_flag(hdev, HCI_ADVERTISING); + changed = true; + } + + err = send_settings_rsp(sk, MGMT_OP_SET_LE, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + + goto unlock; + } + + if (pending_find(MGMT_OP_SET_LE, hdev) || + pending_find(MGMT_OP_SET_ADVERTISING, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LE, + MGMT_STATUS_BUSY); + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_LE, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_le_sync, cmd, + set_le_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LE, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_remove(cmd); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +/* This is a helper function to test for pending mgmt commands that can + * cause CoD or EIR HCI commands. We can only allow one such pending + * mgmt command at a time since otherwise we cannot easily track what + * the current values are, will be, and based on that calculate if a new + * HCI command needs to be sent and if yes with what value. + */ +static bool pending_eir_or_class(struct hci_dev *hdev) +{ + struct mgmt_pending_cmd *cmd; + + list_for_each_entry(cmd, &hdev->mgmt_pending, list) { + switch (cmd->opcode) { + case MGMT_OP_ADD_UUID: + case MGMT_OP_REMOVE_UUID: + case MGMT_OP_SET_DEV_CLASS: + case MGMT_OP_SET_POWERED: + return true; + } + } + + return false; +} + +static const u8 bluetooth_base_uuid[] = { + 0xfb, 0x34, 0x9b, 0x5f, 0x80, 0x00, 0x00, 0x80, + 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +}; + +static u8 get_uuid_size(const u8 *uuid) +{ + u32 val; + + if (memcmp(uuid, bluetooth_base_uuid, 12)) + return 128; + + val = get_unaligned_le32(&uuid[12]); + if (val > 0xffff) + return 32; + + return 16; +} + +static void mgmt_class_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + + bt_dev_dbg(hdev, "err %d", err); + + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err), hdev->dev_class, 3); + + mgmt_pending_free(cmd); +} + +static int add_uuid_sync(struct hci_dev *hdev, void *data) +{ + int err; + + err = hci_update_class_sync(hdev); + if (err) + return err; + + return hci_update_eir_sync(hdev); +} + +static int add_uuid(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_cp_add_uuid *cp = data; + struct mgmt_pending_cmd *cmd; + struct bt_uuid *uuid; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (pending_eir_or_class(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_UUID, + MGMT_STATUS_BUSY); + goto failed; + } + + uuid = kmalloc(sizeof(*uuid), GFP_KERNEL); + if (!uuid) { + err = -ENOMEM; + goto failed; + } + + memcpy(uuid->uuid, cp->uuid, 16); + uuid->svc_hint = cp->svc_hint; + uuid->size = get_uuid_size(cp->uuid); + + list_add_tail(&uuid->list, &hdev->uuids); + + cmd = mgmt_pending_new(sk, MGMT_OP_ADD_UUID, hdev, data, 
len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + err = hci_cmd_sync_queue(hdev, add_uuid_sync, cmd, mgmt_class_complete); + if (err < 0) { + mgmt_pending_free(cmd); + goto failed; + } + +failed: + hci_dev_unlock(hdev); + return err; +} + +static bool enable_service_cache(struct hci_dev *hdev) +{ + if (!hdev_is_powered(hdev)) + return false; + + if (!hci_dev_test_and_set_flag(hdev, HCI_SERVICE_CACHE)) { + queue_delayed_work(hdev->workqueue, &hdev->service_cache, + CACHE_TIMEOUT); + return true; + } + + return false; +} + +static int remove_uuid_sync(struct hci_dev *hdev, void *data) +{ + int err; + + err = hci_update_class_sync(hdev); + if (err) + return err; + + return hci_update_eir_sync(hdev); +} + +static int remove_uuid(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_remove_uuid *cp = data; + struct mgmt_pending_cmd *cmd; + struct bt_uuid *match, *tmp; + static const u8 bt_uuid_any[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; + int err, found; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (pending_eir_or_class(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_REMOVE_UUID, + MGMT_STATUS_BUSY); + goto unlock; + } + + if (memcmp(cp->uuid, bt_uuid_any, 16) == 0) { + hci_uuids_clear(hdev); + + if (enable_service_cache(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_UUID, + 0, hdev->dev_class, 3); + goto unlock; + } + + goto update_class; + } + + found = 0; + + list_for_each_entry_safe(match, tmp, &hdev->uuids, list) { + if (memcmp(match->uuid, cp->uuid, 16) != 0) + continue; + + list_del(&match->list); + kfree(match); + found++; + } + + if (found == 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_REMOVE_UUID, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + +update_class: + cmd = mgmt_pending_new(sk, MGMT_OP_REMOVE_UUID, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, remove_uuid_sync, cmd, + mgmt_class_complete); + if (err < 0) + mgmt_pending_free(cmd); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_class_sync(struct hci_dev *hdev, void *data) +{ + int err = 0; + + if (hci_dev_test_and_clear_flag(hdev, HCI_SERVICE_CACHE)) { + cancel_delayed_work_sync(&hdev->service_cache); + err = hci_update_eir_sync(hdev); + } + + if (err) + return err; + + return hci_update_class_sync(hdev); +} + +static int set_dev_class(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_dev_class *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_bredr_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEV_CLASS, + MGMT_STATUS_NOT_SUPPORTED); + + hci_dev_lock(hdev); + + if (pending_eir_or_class(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEV_CLASS, + MGMT_STATUS_BUSY); + goto unlock; + } + + if ((cp->minor & 0x03) != 0 || (cp->major & 0xe0) != 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEV_CLASS, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + hdev->major_class = cp->major; + hdev->minor_class = cp->minor; + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_DEV_CLASS, 0, + hdev->dev_class, 3); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_SET_DEV_CLASS, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, set_class_sync, cmd, + mgmt_class_complete); + if (err < 0) + mgmt_pending_free(cmd); + +unlock: + 
hci_dev_unlock(hdev); + return err; +} + +static int load_link_keys(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_load_link_keys *cp = data; + const u16 max_key_count = ((U16_MAX - sizeof(*cp)) / + sizeof(struct mgmt_link_key_info)); + u16 key_count, expected_len; + bool changed; + int i; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_bredr_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LINK_KEYS, + MGMT_STATUS_NOT_SUPPORTED); + + key_count = __le16_to_cpu(cp->key_count); + if (key_count > max_key_count) { + bt_dev_err(hdev, "load_link_keys: too big key_count value %u", + key_count); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LINK_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + expected_len = struct_size(cp, keys, key_count); + if (expected_len != len) { + bt_dev_err(hdev, "load_link_keys: expected %u bytes, got %u bytes", + expected_len, len); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LINK_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + if (cp->debug_keys != 0x00 && cp->debug_keys != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LINK_KEYS, + MGMT_STATUS_INVALID_PARAMS); + + bt_dev_dbg(hdev, "debug_keys %u key_count %u", cp->debug_keys, + key_count); + + for (i = 0; i < key_count; i++) { + struct mgmt_link_key_info *key = &cp->keys[i]; + + /* Considering SMP over BREDR/LE, there is no need to check addr_type */ + if (key->type > 0x08) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_LOAD_LINK_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + hci_dev_lock(hdev); + + hci_link_keys_clear(hdev); + + if (cp->debug_keys) + changed = !hci_dev_test_and_set_flag(hdev, HCI_KEEP_DEBUG_KEYS); + else + changed = hci_dev_test_and_clear_flag(hdev, + HCI_KEEP_DEBUG_KEYS); + + if (changed) + new_settings(hdev, NULL); + + for (i = 0; i < key_count; i++) { + struct mgmt_link_key_info *key = &cp->keys[i]; + + if (hci_is_blocked_key(hdev, + HCI_BLOCKED_KEY_TYPE_LINKKEY, + key->val)) { + bt_dev_warn(hdev, "Skipping blocked link key for %pMR", + &key->addr.bdaddr); + continue; + } + + /* Always ignore debug keys and require a new pairing if + * the user wants to use them. 
+ */ + if (key->type == HCI_LK_DEBUG_COMBINATION) + continue; + + hci_add_link_key(hdev, NULL, &key->addr.bdaddr, key->val, + key->type, key->pin_len, NULL); + } + + mgmt_cmd_complete(sk, hdev->id, MGMT_OP_LOAD_LINK_KEYS, 0, NULL, 0); + + hci_dev_unlock(hdev); + + return 0; +} + +static int device_unpaired(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type, struct sock *skip_sk) +{ + struct mgmt_ev_device_unpaired ev; + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = addr_type; + + return mgmt_event(MGMT_EV_DEVICE_UNPAIRED, hdev, &ev, sizeof(ev), + skip_sk); +} + +static void unpair_device_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_unpair_device *cp = cmd->param; + + if (!err) + device_unpaired(hdev, &cp->addr.bdaddr, cp->addr.type, cmd->sk); + + cmd->cmd_complete(cmd, err); + mgmt_pending_free(cmd); +} + +static int unpair_device_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_unpair_device *cp = cmd->param; + struct hci_conn *conn; + + if (cp->addr.type == BDADDR_BREDR) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + else + conn = hci_conn_hash_lookup_le(hdev, &cp->addr.bdaddr, + le_addr_type(cp->addr.type)); + + if (!conn) + return 0; + + return hci_abort_conn_sync(hdev, conn, HCI_ERROR_REMOTE_USER_TERM); +} + +static int unpair_device(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_unpair_device *cp = data; + struct mgmt_rp_unpair_device rp; + struct hci_conn_params *params; + struct mgmt_pending_cmd *cmd; + struct hci_conn *conn; + u8 addr_type; + int err; + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + if (cp->disconnect != 0x00 && cp->disconnect != 0x01) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, + MGMT_STATUS_NOT_POWERED, &rp, + sizeof(rp)); + goto unlock; + } + + if (cp->addr.type == BDADDR_BREDR) { + /* If disconnection is requested, then look up the + * connection. If the remote device is connected, it + * will be later used to terminate the link. + * + * Setting it to NULL explicitly will cause no + * termination of the link. + */ + if (cp->disconnect) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + else + conn = NULL; + + err = hci_remove_link_key(hdev, &cp->addr.bdaddr); + if (err < 0) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_UNPAIR_DEVICE, + MGMT_STATUS_NOT_PAIRED, &rp, + sizeof(rp)); + goto unlock; + } + + goto done; + } + + /* LE address type */ + addr_type = le_addr_type(cp->addr.type); + + /* Abort any ongoing SMP pairing. Removes ltk and irk if they exist. 
*/ + err = smp_cancel_and_remove_pairing(hdev, &cp->addr.bdaddr, addr_type); + if (err < 0) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, + MGMT_STATUS_NOT_PAIRED, &rp, + sizeof(rp)); + goto unlock; + } + + conn = hci_conn_hash_lookup_le(hdev, &cp->addr.bdaddr, addr_type); + if (!conn) { + hci_conn_params_del(hdev, &cp->addr.bdaddr, addr_type); + goto done; + } + + + /* Defer clearing up the connection parameters until closing to + * give a chance of keeping them if a repairing happens. + */ + set_bit(HCI_CONN_PARAM_REMOVAL_PEND, &conn->flags); + + /* Disable auto-connection parameters if present */ + params = hci_conn_params_lookup(hdev, &cp->addr.bdaddr, addr_type); + if (params) { + if (params->explicit_connect) + params->auto_connect = HCI_AUTO_CONN_EXPLICIT; + else + params->auto_connect = HCI_AUTO_CONN_DISABLED; + } + + /* If disconnection is not requested, then clear the connection + * variable so that the link is not terminated. + */ + if (!cp->disconnect) + conn = NULL; + +done: + /* If the connection variable is set, then termination of the + * link is requested. + */ + if (!conn) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNPAIR_DEVICE, 0, + &rp, sizeof(rp)); + device_unpaired(hdev, &cp->addr.bdaddr, cp->addr.type, sk); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_UNPAIR_DEVICE, hdev, cp, + sizeof(*cp)); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + cmd->cmd_complete = addr_cmd_complete; + + err = hci_cmd_sync_queue(hdev, unpair_device_sync, cmd, + unpair_device_complete); + if (err < 0) + mgmt_pending_free(cmd); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int disconnect(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_disconnect *cp = data; + struct mgmt_rp_disconnect rp; + struct mgmt_pending_cmd *cmd; + struct hci_conn *conn; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_DISCONNECT, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + hci_dev_lock(hdev); + + if (!test_bit(HCI_UP, &hdev->flags)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_DISCONNECT, + MGMT_STATUS_NOT_POWERED, &rp, + sizeof(rp)); + goto failed; + } + + if (pending_find(MGMT_OP_DISCONNECT, hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_DISCONNECT, + MGMT_STATUS_BUSY, &rp, sizeof(rp)); + goto failed; + } + + if (cp->addr.type == BDADDR_BREDR) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + else + conn = hci_conn_hash_lookup_le(hdev, &cp->addr.bdaddr, + le_addr_type(cp->addr.type)); + + if (!conn || conn->state == BT_OPEN || conn->state == BT_CLOSED) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_DISCONNECT, + MGMT_STATUS_NOT_CONNECTED, &rp, + sizeof(rp)); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_DISCONNECT, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + cmd->cmd_complete = generic_cmd_complete; + + err = hci_disconnect(conn, HCI_ERROR_REMOTE_USER_TERM); + if (err < 0) + mgmt_pending_remove(cmd); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static u8 link_to_bdaddr(u8 link_type, u8 addr_type) +{ + switch (link_type) { + case LE_LINK: + switch (addr_type) { + case ADDR_LE_DEV_PUBLIC: + return BDADDR_LE_PUBLIC; + + default: + /* Fallback to LE Random address type */ + return BDADDR_LE_RANDOM; + } + + default: + /* Fallback 
to BR/EDR type */ + return BDADDR_BREDR; + } +} + +static int get_connections(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_rp_get_connections *rp; + struct hci_conn *c; + int err; + u16 i; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_GET_CONNECTIONS, + MGMT_STATUS_NOT_POWERED); + goto unlock; + } + + i = 0; + list_for_each_entry(c, &hdev->conn_hash.list, list) { + if (test_bit(HCI_CONN_MGMT_CONNECTED, &c->flags)) + i++; + } + + rp = kmalloc(struct_size(rp, addr, i), GFP_KERNEL); + if (!rp) { + err = -ENOMEM; + goto unlock; + } + + i = 0; + list_for_each_entry(c, &hdev->conn_hash.list, list) { + if (!test_bit(HCI_CONN_MGMT_CONNECTED, &c->flags)) + continue; + bacpy(&rp->addr[i].bdaddr, &c->dst); + rp->addr[i].type = link_to_bdaddr(c->type, c->dst_type); + if (c->type == SCO_LINK || c->type == ESCO_LINK) + continue; + i++; + } + + rp->conn_count = cpu_to_le16(i); + + /* Recalculate length in case of filtered SCO connections, etc */ + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONNECTIONS, 0, rp, + struct_size(rp, addr, i)); + + kfree(rp); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int send_pin_code_neg_reply(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_pin_code_neg_reply *cp) +{ + struct mgmt_pending_cmd *cmd; + int err; + + cmd = mgmt_pending_add(sk, MGMT_OP_PIN_CODE_NEG_REPLY, hdev, cp, + sizeof(*cp)); + if (!cmd) + return -ENOMEM; + + cmd->cmd_complete = addr_cmd_complete; + + err = hci_send_cmd(hdev, HCI_OP_PIN_CODE_NEG_REPLY, + sizeof(cp->addr.bdaddr), &cp->addr.bdaddr); + if (err < 0) + mgmt_pending_remove(cmd); + + return err; +} + +static int pin_code_reply(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct hci_conn *conn; + struct mgmt_cp_pin_code_reply *cp = data; + struct hci_cp_pin_code_reply reply; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_PIN_CODE_REPLY, + MGMT_STATUS_NOT_POWERED); + goto failed; + } + + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->addr.bdaddr); + if (!conn) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_PIN_CODE_REPLY, + MGMT_STATUS_NOT_CONNECTED); + goto failed; + } + + if (conn->pending_sec_level == BT_SECURITY_HIGH && cp->pin_len != 16) { + struct mgmt_cp_pin_code_neg_reply ncp; + + memcpy(&ncp.addr, &cp->addr, sizeof(ncp.addr)); + + bt_dev_err(hdev, "PIN code is not 16 bytes long"); + + err = send_pin_code_neg_reply(sk, hdev, &ncp); + if (err >= 0) + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_PIN_CODE_REPLY, + MGMT_STATUS_INVALID_PARAMS); + + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_PIN_CODE_REPLY, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + cmd->cmd_complete = addr_cmd_complete; + + bacpy(&reply.bdaddr, &cp->addr.bdaddr); + reply.pin_len = cp->pin_len; + memcpy(reply.pin_code, cp->pin_code, sizeof(reply.pin_code)); + + err = hci_send_cmd(hdev, HCI_OP_PIN_CODE_REPLY, sizeof(reply), &reply); + if (err < 0) + mgmt_pending_remove(cmd); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int set_io_capability(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_io_capability *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (cp->io_capability > SMP_IO_KEYBOARD_DISPLAY) + return mgmt_cmd_status(sk, hdev->id, 
MGMT_OP_SET_IO_CAPABILITY, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + hdev->io_capability = cp->io_capability; + + bt_dev_dbg(hdev, "IO capability set to 0x%02x", hdev->io_capability); + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_IO_CAPABILITY, 0, + NULL, 0); +} + +static struct mgmt_pending_cmd *find_pairing(struct hci_conn *conn) +{ + struct hci_dev *hdev = conn->hdev; + struct mgmt_pending_cmd *cmd; + + list_for_each_entry(cmd, &hdev->mgmt_pending, list) { + if (cmd->opcode != MGMT_OP_PAIR_DEVICE) + continue; + + if (cmd->user_data != conn) + continue; + + return cmd; + } + + return NULL; +} + +static int pairing_complete(struct mgmt_pending_cmd *cmd, u8 status) +{ + struct mgmt_rp_pair_device rp; + struct hci_conn *conn = cmd->user_data; + int err; + + bacpy(&rp.addr.bdaddr, &conn->dst); + rp.addr.type = link_to_bdaddr(conn->type, conn->dst_type); + + err = mgmt_cmd_complete(cmd->sk, cmd->index, MGMT_OP_PAIR_DEVICE, + status, &rp, sizeof(rp)); + + /* So we don't get further callbacks for this connection */ + conn->connect_cfm_cb = NULL; + conn->security_cfm_cb = NULL; + conn->disconn_cfm_cb = NULL; + + hci_conn_drop(conn); + + /* The device is paired so there is no need to remove + * its connection parameters anymore. + */ + clear_bit(HCI_CONN_PARAM_REMOVAL_PEND, &conn->flags); + + hci_conn_put(conn); + + return err; +} + +void mgmt_smp_complete(struct hci_conn *conn, bool complete) +{ + u8 status = complete ? MGMT_STATUS_SUCCESS : MGMT_STATUS_FAILED; + struct mgmt_pending_cmd *cmd; + + cmd = find_pairing(conn); + if (cmd) { + cmd->cmd_complete(cmd, status); + mgmt_pending_remove(cmd); + } +} + +static void pairing_complete_cb(struct hci_conn *conn, u8 status) +{ + struct mgmt_pending_cmd *cmd; + + BT_DBG("status %u", status); + + cmd = find_pairing(conn); + if (!cmd) { + BT_DBG("Unable to find a pending command"); + return; + } + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); +} + +static void le_pairing_complete_cb(struct hci_conn *conn, u8 status) +{ + struct mgmt_pending_cmd *cmd; + + BT_DBG("status %u", status); + + if (!status) + return; + + cmd = find_pairing(conn); + if (!cmd) { + BT_DBG("Unable to find a pending command"); + return; + } + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); +} + +static int pair_device(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_pair_device *cp = data; + struct mgmt_rp_pair_device rp; + struct mgmt_pending_cmd *cmd; + u8 sec_level, auth_type; + struct hci_conn *conn; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + if (cp->io_cap > SMP_IO_KEYBOARD_DISPLAY) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + MGMT_STATUS_NOT_POWERED, &rp, + sizeof(rp)); + goto unlock; + } + + if (hci_bdaddr_is_paired(hdev, &cp->addr.bdaddr, cp->addr.type)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + MGMT_STATUS_ALREADY_PAIRED, &rp, + sizeof(rp)); + goto unlock; + } + + sec_level = BT_SECURITY_MEDIUM; + auth_type = HCI_AT_DEDICATED_BONDING; + + if (cp->addr.type 
== BDADDR_BREDR) { + conn = hci_connect_acl(hdev, &cp->addr.bdaddr, sec_level, + auth_type, CONN_REASON_PAIR_DEVICE); + } else { + u8 addr_type = le_addr_type(cp->addr.type); + struct hci_conn_params *p; + + /* When pairing a new device, it is expected to remember + * this device for future connections. Adding the connection + * parameter information ahead of time allows tracking + * of the peripheral preferred values and will speed up any + * further connection establishment. + * + * If connection parameters already exist, then they + * will be kept and this function does nothing. + */ + p = hci_conn_params_add(hdev, &cp->addr.bdaddr, addr_type); + + if (p->auto_connect == HCI_AUTO_CONN_EXPLICIT) + p->auto_connect = HCI_AUTO_CONN_DISABLED; + + conn = hci_connect_le_scan(hdev, &cp->addr.bdaddr, addr_type, + sec_level, HCI_LE_CONN_TIMEOUT, + CONN_REASON_PAIR_DEVICE); + } + + if (IS_ERR(conn)) { + int status; + + if (PTR_ERR(conn) == -EBUSY) + status = MGMT_STATUS_BUSY; + else if (PTR_ERR(conn) == -EOPNOTSUPP) + status = MGMT_STATUS_NOT_SUPPORTED; + else if (PTR_ERR(conn) == -ECONNREFUSED) + status = MGMT_STATUS_REJECTED; + else + status = MGMT_STATUS_CONNECT_FAILED; + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + status, &rp, sizeof(rp)); + goto unlock; + } + + if (conn->connect_cfm_cb) { + hci_conn_drop(conn); + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_PAIR_DEVICE, + MGMT_STATUS_BUSY, &rp, sizeof(rp)); + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_PAIR_DEVICE, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + hci_conn_drop(conn); + goto unlock; + } + + cmd->cmd_complete = pairing_complete; + + /* For LE, just connecting isn't a proof that the pairing finished */ + if (cp->addr.type == BDADDR_BREDR) { + conn->connect_cfm_cb = pairing_complete_cb; + conn->security_cfm_cb = pairing_complete_cb; + conn->disconn_cfm_cb = pairing_complete_cb; + } else { + conn->connect_cfm_cb = le_pairing_complete_cb; + conn->security_cfm_cb = le_pairing_complete_cb; + conn->disconn_cfm_cb = le_pairing_complete_cb; + } + + conn->io_capability = cp->io_cap; + cmd->user_data = hci_conn_get(conn); + + if ((conn->state == BT_CONNECTED || conn->state == BT_CONFIG) && + hci_conn_security(conn, sec_level, auth_type, true)) { + cmd->cmd_complete(cmd, 0); + mgmt_pending_remove(cmd); + } + + err = 0; + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int abort_conn_sync(struct hci_dev *hdev, void *data) +{ + struct hci_conn *conn; + u16 handle = PTR_ERR(data); + + conn = hci_conn_hash_lookup_handle(hdev, handle); + if (!conn) + return 0; + + return hci_abort_conn_sync(hdev, conn, HCI_ERROR_REMOTE_USER_TERM); +} + +static int cancel_pair_device(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_addr_info *addr = data; + struct mgmt_pending_cmd *cmd; + struct hci_conn *conn; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_CANCEL_PAIR_DEVICE, + MGMT_STATUS_NOT_POWERED); + goto unlock; + } + + cmd = pending_find(MGMT_OP_PAIR_DEVICE, hdev); + if (!cmd) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_CANCEL_PAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + conn = cmd->user_data; + + if (bacmp(&addr->bdaddr, &conn->dst) != 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_CANCEL_PAIR_DEVICE, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + cmd->cmd_complete(cmd, MGMT_STATUS_CANCELLED); + mgmt_pending_remove(cmd); + + err = 
mgmt_cmd_complete(sk, hdev->id, MGMT_OP_CANCEL_PAIR_DEVICE, 0, + addr, sizeof(*addr)); + + /* Since user doesn't want to proceed with the connection, abort any + * ongoing pairing and then terminate the link if it was created + * because of the pair device action. + */ + if (addr->type == BDADDR_BREDR) + hci_remove_link_key(hdev, &addr->bdaddr); + else + smp_cancel_and_remove_pairing(hdev, &addr->bdaddr, + le_addr_type(addr->type)); + + if (conn->conn_reason == CONN_REASON_PAIR_DEVICE) + hci_cmd_sync_queue(hdev, abort_conn_sync, ERR_PTR(conn->handle), + NULL); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int user_pairing_resp(struct sock *sk, struct hci_dev *hdev, + struct mgmt_addr_info *addr, u16 mgmt_op, + u16 hci_op, __le32 passkey) +{ + struct mgmt_pending_cmd *cmd; + struct hci_conn *conn; + int err; + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, mgmt_op, + MGMT_STATUS_NOT_POWERED, addr, + sizeof(*addr)); + goto done; + } + + if (addr->type == BDADDR_BREDR) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &addr->bdaddr); + else + conn = hci_conn_hash_lookup_le(hdev, &addr->bdaddr, + le_addr_type(addr->type)); + + if (!conn) { + err = mgmt_cmd_complete(sk, hdev->id, mgmt_op, + MGMT_STATUS_NOT_CONNECTED, addr, + sizeof(*addr)); + goto done; + } + + if (addr->type == BDADDR_LE_PUBLIC || addr->type == BDADDR_LE_RANDOM) { + err = smp_user_confirm_reply(conn, mgmt_op, passkey); + if (!err) + err = mgmt_cmd_complete(sk, hdev->id, mgmt_op, + MGMT_STATUS_SUCCESS, addr, + sizeof(*addr)); + else + err = mgmt_cmd_complete(sk, hdev->id, mgmt_op, + MGMT_STATUS_FAILED, addr, + sizeof(*addr)); + + goto done; + } + + cmd = mgmt_pending_add(sk, mgmt_op, hdev, addr, sizeof(*addr)); + if (!cmd) { + err = -ENOMEM; + goto done; + } + + cmd->cmd_complete = addr_cmd_complete; + + /* Continue with pairing via HCI */ + if (hci_op == HCI_OP_USER_PASSKEY_REPLY) { + struct hci_cp_user_passkey_reply cp; + + bacpy(&cp.bdaddr, &addr->bdaddr); + cp.passkey = passkey; + err = hci_send_cmd(hdev, hci_op, sizeof(cp), &cp); + } else + err = hci_send_cmd(hdev, hci_op, sizeof(addr->bdaddr), + &addr->bdaddr); + + if (err < 0) + mgmt_pending_remove(cmd); + +done: + hci_dev_unlock(hdev); + return err; +} + +static int pin_code_neg_reply(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_pin_code_neg_reply *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + return user_pairing_resp(sk, hdev, &cp->addr, + MGMT_OP_PIN_CODE_NEG_REPLY, + HCI_OP_PIN_CODE_NEG_REPLY, 0); +} + +static int user_confirm_reply(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_user_confirm_reply *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (len != sizeof(*cp)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_USER_CONFIRM_REPLY, + MGMT_STATUS_INVALID_PARAMS); + + return user_pairing_resp(sk, hdev, &cp->addr, + MGMT_OP_USER_CONFIRM_REPLY, + HCI_OP_USER_CONFIRM_REPLY, 0); +} + +static int user_confirm_neg_reply(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_user_confirm_neg_reply *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + return user_pairing_resp(sk, hdev, &cp->addr, + MGMT_OP_USER_CONFIRM_NEG_REPLY, + HCI_OP_USER_CONFIRM_NEG_REPLY, 0); +} + +static int user_passkey_reply(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_user_passkey_reply *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + return user_pairing_resp(sk, hdev, &cp->addr, + 
MGMT_OP_USER_PASSKEY_REPLY, + HCI_OP_USER_PASSKEY_REPLY, cp->passkey); +} + +static int user_passkey_neg_reply(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_user_passkey_neg_reply *cp = data; + + bt_dev_dbg(hdev, "sock %p", sk); + + return user_pairing_resp(sk, hdev, &cp->addr, + MGMT_OP_USER_PASSKEY_NEG_REPLY, + HCI_OP_USER_PASSKEY_NEG_REPLY, 0); +} + +static int adv_expire_sync(struct hci_dev *hdev, u32 flags) +{ + struct adv_info *adv_instance; + + adv_instance = hci_find_adv_instance(hdev, hdev->cur_adv_instance); + if (!adv_instance) + return 0; + + /* stop if current instance doesn't need to be changed */ + if (!(adv_instance->flags & flags)) + return 0; + + cancel_adv_timeout(hdev); + + adv_instance = hci_get_next_instance(hdev, adv_instance->instance); + if (!adv_instance) + return 0; + + hci_schedule_adv_instance_sync(hdev, adv_instance->instance, true); + + return 0; +} + +static int name_changed_sync(struct hci_dev *hdev, void *data) +{ + return adv_expire_sync(hdev, MGMT_ADV_FLAG_LOCAL_NAME); +} + +static void set_name_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_set_local_name *cp = cmd->param; + u8 status = mgmt_status(err); + + bt_dev_dbg(hdev, "err %d", err); + + if (cmd != pending_find(MGMT_OP_SET_LOCAL_NAME, hdev)) + return; + + if (status) { + mgmt_cmd_status(cmd->sk, hdev->id, MGMT_OP_SET_LOCAL_NAME, + status); + } else { + mgmt_cmd_complete(cmd->sk, hdev->id, MGMT_OP_SET_LOCAL_NAME, 0, + cp, sizeof(*cp)); + + if (hci_dev_test_flag(hdev, HCI_LE_ADV)) + hci_cmd_sync_queue(hdev, name_changed_sync, NULL, NULL); + } + + mgmt_pending_remove(cmd); +} + +static int set_name_sync(struct hci_dev *hdev, void *data) +{ + if (lmp_bredr_capable(hdev)) { + hci_update_name_sync(hdev); + hci_update_eir_sync(hdev); + } + + /* The name is stored in the scan response data and so + * no need to update the advertising data here. + */ + if (lmp_le_capable(hdev) && hci_dev_test_flag(hdev, HCI_ADVERTISING)) + hci_update_scan_rsp_data_sync(hdev, hdev->cur_adv_instance); + + return 0; +} + +static int set_local_name(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_local_name *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + /* If the old values are the same as the new ones just return a + * direct command complete event. 
+ */ + if (!memcmp(hdev->dev_name, cp->name, sizeof(hdev->dev_name)) && + !memcmp(hdev->short_name, cp->short_name, + sizeof(hdev->short_name))) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_LOCAL_NAME, 0, + data, len); + goto failed; + } + + memcpy(hdev->short_name, cp->short_name, sizeof(hdev->short_name)); + + if (!hdev_is_powered(hdev)) { + memcpy(hdev->dev_name, cp->name, sizeof(hdev->dev_name)); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_LOCAL_NAME, 0, + data, len); + if (err < 0) + goto failed; + + err = mgmt_limited_event(MGMT_EV_LOCAL_NAME_CHANGED, hdev, data, + len, HCI_MGMT_LOCAL_NAME_EVENTS, sk); + ext_info_changed(hdev, sk); + + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_LOCAL_NAME, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_name_sync, cmd, + set_name_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_LOCAL_NAME, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_remove(cmd); + + goto failed; + } + + memcpy(hdev->dev_name, cp->name, sizeof(hdev->dev_name)); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int appearance_changed_sync(struct hci_dev *hdev, void *data) +{ + return adv_expire_sync(hdev, MGMT_ADV_FLAG_APPEARANCE); +} + +static int set_appearance(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_appearance *cp = data; + u16 appearance; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_APPEARANCE, + MGMT_STATUS_NOT_SUPPORTED); + + appearance = le16_to_cpu(cp->appearance); + + hci_dev_lock(hdev); + + if (hdev->appearance != appearance) { + hdev->appearance = appearance; + + if (hci_dev_test_flag(hdev, HCI_LE_ADV)) + hci_cmd_sync_queue(hdev, appearance_changed_sync, NULL, + NULL); + + ext_info_changed(hdev, sk); + } + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_APPEARANCE, 0, NULL, + 0); + + hci_dev_unlock(hdev); + + return err; +} + +static int get_phy_configuration(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_rp_get_phy_configuration rp; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + memset(&rp, 0, sizeof(rp)); + + rp.supported_phys = cpu_to_le32(get_supported_phys(hdev)); + rp.selected_phys = cpu_to_le32(get_selected_phys(hdev)); + rp.configurable_phys = cpu_to_le32(get_configurable_phys(hdev)); + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_PHY_CONFIGURATION, 0, + &rp, sizeof(rp)); +} + +int mgmt_phy_configuration_changed(struct hci_dev *hdev, struct sock *skip) +{ + struct mgmt_ev_phy_configuration_changed ev; + + memset(&ev, 0, sizeof(ev)); + + ev.selected_phys = cpu_to_le32(get_selected_phys(hdev)); + + return mgmt_event(MGMT_EV_PHY_CONFIGURATION_CHANGED, hdev, &ev, + sizeof(ev), skip); +} + +static void set_default_phy_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct sk_buff *skb = cmd->skb; + u8 status = mgmt_status(err); + + if (cmd != pending_find(MGMT_OP_SET_PHY_CONFIGURATION, hdev)) + return; + + if (!status) { + if (!skb) + status = MGMT_STATUS_FAILED; + else if (IS_ERR(skb)) + status = mgmt_status(PTR_ERR(skb)); + else + status = mgmt_status(skb->data[0]); + } + + bt_dev_dbg(hdev, "status %d", status); + + if (status) { + mgmt_cmd_status(cmd->sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, status); + } else { + mgmt_cmd_complete(cmd->sk, hdev->id, + 
MGMT_OP_SET_PHY_CONFIGURATION, 0, + NULL, 0); + + mgmt_phy_configuration_changed(hdev, cmd->sk); + } + + if (skb && !IS_ERR(skb)) + kfree_skb(skb); + + mgmt_pending_remove(cmd); +} + +static int set_default_phy_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_set_phy_configuration *cp = cmd->param; + struct hci_cp_le_set_default_phy cp_phy; + u32 selected_phys = __le32_to_cpu(cp->selected_phys); + + memset(&cp_phy, 0, sizeof(cp_phy)); + + if (!(selected_phys & MGMT_PHY_LE_TX_MASK)) + cp_phy.all_phys |= 0x01; + + if (!(selected_phys & MGMT_PHY_LE_RX_MASK)) + cp_phy.all_phys |= 0x02; + + if (selected_phys & MGMT_PHY_LE_1M_TX) + cp_phy.tx_phys |= HCI_LE_SET_PHY_1M; + + if (selected_phys & MGMT_PHY_LE_2M_TX) + cp_phy.tx_phys |= HCI_LE_SET_PHY_2M; + + if (selected_phys & MGMT_PHY_LE_CODED_TX) + cp_phy.tx_phys |= HCI_LE_SET_PHY_CODED; + + if (selected_phys & MGMT_PHY_LE_1M_RX) + cp_phy.rx_phys |= HCI_LE_SET_PHY_1M; + + if (selected_phys & MGMT_PHY_LE_2M_RX) + cp_phy.rx_phys |= HCI_LE_SET_PHY_2M; + + if (selected_phys & MGMT_PHY_LE_CODED_RX) + cp_phy.rx_phys |= HCI_LE_SET_PHY_CODED; + + cmd->skb = __hci_cmd_sync(hdev, HCI_OP_LE_SET_DEFAULT_PHY, + sizeof(cp_phy), &cp_phy, HCI_CMD_TIMEOUT); + + return 0; +} + +static int set_phy_configuration(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_set_phy_configuration *cp = data; + struct mgmt_pending_cmd *cmd; + u32 selected_phys, configurable_phys, supported_phys, unconfigure_phys; + u16 pkt_type = (HCI_DH1 | HCI_DM1); + bool changed = false; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + configurable_phys = get_configurable_phys(hdev); + supported_phys = get_supported_phys(hdev); + selected_phys = __le32_to_cpu(cp->selected_phys); + + if (selected_phys & ~supported_phys) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_STATUS_INVALID_PARAMS); + + unconfigure_phys = supported_phys & ~configurable_phys; + + if ((selected_phys & unconfigure_phys) != unconfigure_phys) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_STATUS_INVALID_PARAMS); + + if (selected_phys == get_selected_phys(hdev)) + return mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + 0, NULL, 0); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_STATUS_REJECTED); + goto unlock; + } + + if (pending_find(MGMT_OP_SET_PHY_CONFIGURATION, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_STATUS_BUSY); + goto unlock; + } + + if (selected_phys & MGMT_PHY_BR_1M_3SLOT) + pkt_type |= (HCI_DH3 | HCI_DM3); + else + pkt_type &= ~(HCI_DH3 | HCI_DM3); + + if (selected_phys & MGMT_PHY_BR_1M_5SLOT) + pkt_type |= (HCI_DH5 | HCI_DM5); + else + pkt_type &= ~(HCI_DH5 | HCI_DM5); + + if (selected_phys & MGMT_PHY_EDR_2M_1SLOT) + pkt_type &= ~HCI_2DH1; + else + pkt_type |= HCI_2DH1; + + if (selected_phys & MGMT_PHY_EDR_2M_3SLOT) + pkt_type &= ~HCI_2DH3; + else + pkt_type |= HCI_2DH3; + + if (selected_phys & MGMT_PHY_EDR_2M_5SLOT) + pkt_type &= ~HCI_2DH5; + else + pkt_type |= HCI_2DH5; + + if (selected_phys & MGMT_PHY_EDR_3M_1SLOT) + pkt_type &= ~HCI_3DH1; + else + pkt_type |= HCI_3DH1; + + if (selected_phys & MGMT_PHY_EDR_3M_3SLOT) + pkt_type &= ~HCI_3DH3; + else + pkt_type |= HCI_3DH3; + + if (selected_phys & MGMT_PHY_EDR_3M_5SLOT) + pkt_type &= ~HCI_3DH5; + else + pkt_type |= HCI_3DH5; + + if (pkt_type != hdev->pkt_type) { + 
hdev->pkt_type = pkt_type; + changed = true; + } + + if ((selected_phys & MGMT_PHY_LE_MASK) == + (get_selected_phys(hdev) & MGMT_PHY_LE_MASK)) { + if (changed) + mgmt_phy_configuration_changed(hdev, sk); + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + 0, NULL, 0); + + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_PHY_CONFIGURATION, hdev, data, + len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_default_phy_sync, cmd, + set_default_phy_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_PHY_CONFIGURATION, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_remove(cmd); + } + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static int set_blocked_keys(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + int err = MGMT_STATUS_SUCCESS; + struct mgmt_cp_set_blocked_keys *keys = data; + const u16 max_key_count = ((U16_MAX - sizeof(*keys)) / + sizeof(struct mgmt_blocked_key_info)); + u16 key_count, expected_len; + int i; + + bt_dev_dbg(hdev, "sock %p", sk); + + key_count = __le16_to_cpu(keys->key_count); + if (key_count > max_key_count) { + bt_dev_err(hdev, "too big key_count value %u", key_count); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BLOCKED_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + expected_len = struct_size(keys, keys, key_count); + if (expected_len != len) { + bt_dev_err(hdev, "expected %u bytes, got %u bytes", + expected_len, len); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BLOCKED_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + hci_dev_lock(hdev); + + hci_blocked_keys_clear(hdev); + + for (i = 0; i < key_count; ++i) { + struct blocked_key *b = kzalloc(sizeof(*b), GFP_KERNEL); + + if (!b) { + err = MGMT_STATUS_NO_RESOURCES; + break; + } + + b->type = keys->keys[i].type; + memcpy(b->val, keys->keys[i].val, sizeof(b->val)); + list_add_rcu(&b->list, &hdev->blocked_keys); + } + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_BLOCKED_KEYS, + err, NULL, 0); +} + +static int set_wideband_speech(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_mode *cp = data; + int err; + bool changed = false; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!test_bit(HCI_QUIRK_WIDEBAND_SPEECH_SUPPORTED, &hdev->quirks)) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_WIDEBAND_SPEECH, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_WIDEBAND_SPEECH, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (hdev_is_powered(hdev) && + !!cp->val != hci_dev_test_flag(hdev, + HCI_WIDEBAND_SPEECH_ENABLED)) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_WIDEBAND_SPEECH, + MGMT_STATUS_REJECTED); + goto unlock; + } + + if (cp->val) + changed = !hci_dev_test_and_set_flag(hdev, + HCI_WIDEBAND_SPEECH_ENABLED); + else + changed = hci_dev_test_and_clear_flag(hdev, + HCI_WIDEBAND_SPEECH_ENABLED); + + err = send_settings_rsp(sk, MGMT_OP_SET_WIDEBAND_SPEECH, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int read_controller_cap(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + char buf[20]; + struct mgmt_rp_read_controller_cap *rp = (void *)buf; + u16 cap_len = 0; + u8 flags = 0; + u8 tx_power_range[2]; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&buf, 0, sizeof(buf)); + + hci_dev_lock(hdev); + + /* When the Read Simple 
Pairing Options command is supported, then + * the remote public key validation is supported. + * + * Alternatively, when Microsoft extensions are available, they can + * indicate support for public key validation as well. + */ + if ((hdev->commands[41] & 0x08) || msft_curve_validity(hdev)) + flags |= 0x01; /* Remote public key validation (BR/EDR) */ + + flags |= 0x02; /* Remote public key validation (LE) */ + + /* When the Read Encryption Key Size command is supported, then the + * encryption key size is enforced. + */ + if (hdev->commands[20] & 0x10) + flags |= 0x04; /* Encryption key size enforcement (BR/EDR) */ + + flags |= 0x08; /* Encryption key size enforcement (LE) */ + + cap_len = eir_append_data(rp->cap, cap_len, MGMT_CAP_SEC_FLAGS, + &flags, 1); + + /* When the Read Simple Pairing Options command is supported, then + * also max encryption key size information is provided. + */ + if (hdev->commands[41] & 0x08) + cap_len = eir_append_le16(rp->cap, cap_len, + MGMT_CAP_MAX_ENC_KEY_SIZE, + hdev->max_enc_key_size); + + cap_len = eir_append_le16(rp->cap, cap_len, + MGMT_CAP_SMP_MAX_ENC_KEY_SIZE, + SMP_MAX_ENC_KEY_SIZE); + + /* Append the min/max LE tx power parameters if we were able to fetch + * it from the controller + */ + if (hdev->commands[38] & 0x80) { + memcpy(&tx_power_range[0], &hdev->min_le_tx_power, 1); + memcpy(&tx_power_range[1], &hdev->max_le_tx_power, 1); + cap_len = eir_append_data(rp->cap, cap_len, MGMT_CAP_LE_TX_PWR, + tx_power_range, 2); + } + + rp->cap_len = cpu_to_le16(cap_len); + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_CONTROLLER_CAP, 0, + rp, sizeof(*rp) + cap_len); +} + +#ifdef CONFIG_BT_FEATURE_DEBUG +/* d4992530-b9ec-469f-ab01-6c481c47da1c */ +static const u8 debug_uuid[16] = { + 0x1c, 0xda, 0x47, 0x1c, 0x48, 0x6c, 0x01, 0xab, + 0x9f, 0x46, 0xec, 0xb9, 0x30, 0x25, 0x99, 0xd4, +}; +#endif + +/* 330859bc-7506-492d-9370-9a6f0614037f */ +static const u8 quality_report_uuid[16] = { + 0x7f, 0x03, 0x14, 0x06, 0x6f, 0x9a, 0x70, 0x93, + 0x2d, 0x49, 0x06, 0x75, 0xbc, 0x59, 0x08, 0x33, +}; + +/* a6695ace-ee7f-4fb9-881a-5fac66c629af */ +static const u8 offload_codecs_uuid[16] = { + 0xaf, 0x29, 0xc6, 0x66, 0xac, 0x5f, 0x1a, 0x88, + 0xb9, 0x4f, 0x7f, 0xee, 0xce, 0x5a, 0x69, 0xa6, +}; + +/* 671b10b5-42c0-4696-9227-eb28d1b049d6 */ +static const u8 le_simultaneous_roles_uuid[16] = { + 0xd6, 0x49, 0xb0, 0xd1, 0x28, 0xeb, 0x27, 0x92, + 0x96, 0x46, 0xc0, 0x42, 0xb5, 0x10, 0x1b, 0x67, +}; + +/* 15c0a148-c273-11ea-b3de-0242ac130004 */ +static const u8 rpa_resolution_uuid[16] = { + 0x04, 0x00, 0x13, 0xac, 0x42, 0x02, 0xde, 0xb3, + 0xea, 0x11, 0x73, 0xc2, 0x48, 0xa1, 0xc0, 0x15, +}; + +/* 6fbaf188-05e0-496a-9885-d6ddfdb4e03e */ +static const u8 iso_socket_uuid[16] = { + 0x3e, 0xe0, 0xb4, 0xfd, 0xdd, 0xd6, 0x85, 0x98, + 0x6a, 0x49, 0xe0, 0x05, 0x88, 0xf1, 0xba, 0x6f, +}; + +/* 2ce463d7-7a03-4d8d-bf05-5f24e8f36e76 */ +static const u8 mgmt_mesh_uuid[16] = { + 0x76, 0x6e, 0xf3, 0xe8, 0x24, 0x5f, 0x05, 0xbf, + 0x8d, 0x4d, 0x03, 0x7a, 0xd7, 0x63, 0xe4, 0x2c, +}; + +static int read_exp_features_info(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_exp_features_info *rp; + size_t len; + u16 idx = 0; + u32 flags; + int status; + + bt_dev_dbg(hdev, "sock %p", sk); + + /* Enough space for 7 features */ + len = sizeof(*rp) + (sizeof(rp->features[0]) * 7); + rp = kzalloc(len, GFP_KERNEL); + if (!rp) + return -ENOMEM; + +#ifdef CONFIG_BT_FEATURE_DEBUG + if (!hdev) { + flags = bt_dbg_get() ? 
BIT(0) : 0; + + memcpy(rp->features[idx].uuid, debug_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } +#endif + + if (hdev && hci_dev_le_state_simultaneous(hdev)) { + if (hci_dev_test_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES)) + flags = BIT(0); + else + flags = 0; + + memcpy(rp->features[idx].uuid, le_simultaneous_roles_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + if (hdev && ll_privacy_capable(hdev)) { + if (hci_dev_test_flag(hdev, HCI_ENABLE_LL_PRIVACY)) + flags = BIT(0) | BIT(1); + else + flags = BIT(1); + + memcpy(rp->features[idx].uuid, rpa_resolution_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + if (hdev && (aosp_has_quality_report(hdev) || + hdev->set_quality_report)) { + if (hci_dev_test_flag(hdev, HCI_QUALITY_REPORT)) + flags = BIT(0); + else + flags = 0; + + memcpy(rp->features[idx].uuid, quality_report_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + if (hdev && hdev->get_data_path_id) { + if (hci_dev_test_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED)) + flags = BIT(0); + else + flags = 0; + + memcpy(rp->features[idx].uuid, offload_codecs_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + if (IS_ENABLED(CONFIG_BT_LE)) { + flags = iso_enabled() ? BIT(0) : 0; + memcpy(rp->features[idx].uuid, iso_socket_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + if (hdev && lmp_le_capable(hdev)) { + if (hci_dev_test_flag(hdev, HCI_MESH_EXPERIMENTAL)) + flags = BIT(0); + else + flags = 0; + + memcpy(rp->features[idx].uuid, mgmt_mesh_uuid, 16); + rp->features[idx].flags = cpu_to_le32(flags); + idx++; + } + + rp->feature_count = cpu_to_le16(idx); + + /* After reading the experimental features information, enable + * the events to update client on any future change. + */ + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + status = mgmt_cmd_complete(sk, hdev ? hdev->id : MGMT_INDEX_NONE, + MGMT_OP_READ_EXP_FEATURES_INFO, + 0, rp, sizeof(*rp) + (20 * idx)); + + kfree(rp); + return status; +} + +static int exp_ll_privacy_feature_changed(bool enabled, struct hci_dev *hdev, + struct sock *skip) +{ + struct mgmt_ev_exp_feature_changed ev; + + memset(&ev, 0, sizeof(ev)); + memcpy(ev.uuid, rpa_resolution_uuid, 16); + ev.flags = cpu_to_le32((enabled ? BIT(0) : 0) | BIT(1)); + + // Do we need to be atomic with the conn_flags? + if (enabled && privacy_mode_capable(hdev)) + hdev->conn_flags |= HCI_CONN_FLAG_DEVICE_PRIVACY; + else + hdev->conn_flags &= ~HCI_CONN_FLAG_DEVICE_PRIVACY; + + return mgmt_limited_event(MGMT_EV_EXP_FEATURE_CHANGED, hdev, + &ev, sizeof(ev), + HCI_MGMT_EXP_FEATURE_EVENTS, skip); + +} + +static int exp_feature_changed(struct hci_dev *hdev, const u8 *uuid, + bool enabled, struct sock *skip) +{ + struct mgmt_ev_exp_feature_changed ev; + + memset(&ev, 0, sizeof(ev)); + memcpy(ev.uuid, uuid, 16); + ev.flags = cpu_to_le32(enabled ? BIT(0) : 0); + + return mgmt_limited_event(MGMT_EV_EXP_FEATURE_CHANGED, hdev, + &ev, sizeof(ev), + HCI_MGMT_EXP_FEATURE_EVENTS, skip); +} + +#define EXP_FEAT(_uuid, _set_func) \ +{ \ + .uuid = _uuid, \ + .set_func = _set_func, \ +} + +/* The zero key uuid is special. Multiple exp features are set through it. 
*/ +static int set_zero_key_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + + memset(rp.uuid, 0, 16); + rp.flags = cpu_to_le32(0); + +#ifdef CONFIG_BT_FEATURE_DEBUG + if (!hdev) { + bool changed = bt_dbg_get(); + + bt_dbg_set(false); + + if (changed) + exp_feature_changed(NULL, ZERO_KEY, false, sk); + } +#endif + + if (hdev && use_ll_privacy(hdev) && !hdev_is_powered(hdev)) { + bool changed; + + changed = hci_dev_test_and_clear_flag(hdev, + HCI_ENABLE_LL_PRIVACY); + if (changed) + exp_feature_changed(hdev, rpa_resolution_uuid, false, + sk); + } + + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + return mgmt_cmd_complete(sk, hdev ? hdev->id : MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); +} + +#ifdef CONFIG_BT_FEATURE_DEBUG +static int set_debug_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + + bool val, changed; + int err; + + /* Command requires to use the non-controller index */ + if (hdev) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = !!cp->param[0]; + changed = val ? !bt_dbg_get() : bt_dbg_get(); + bt_dbg_set(val); + + memcpy(rp.uuid, debug_uuid, 16); + rp.flags = cpu_to_le32(val ? BIT(0) : 0); + + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, debug_uuid, val, sk); + + return err; +} +#endif + +static int set_mgmt_mesh_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + bool val, changed; + int err; + + /* Command requires to use the controller index */ + if (!hdev) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = !!cp->param[0]; + + if (val) { + changed = !hci_dev_test_and_set_flag(hdev, + HCI_MESH_EXPERIMENTAL); + } else { + hci_dev_clear_flag(hdev, HCI_MESH); + changed = hci_dev_test_and_clear_flag(hdev, + HCI_MESH_EXPERIMENTAL); + } + + memcpy(rp.uuid, mgmt_mesh_uuid, 16); + rp.flags = cpu_to_le32(val ? 
BIT(0) : 0); + + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, mgmt_mesh_uuid, val, sk); + + return err; +} + +static int set_rpa_resolution_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, + u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + bool val, changed; + int err; + u32 flags; + + /* Command requires to use the controller index */ + if (!hdev) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Changes can only be made when controller is powered down */ + if (hdev_is_powered(hdev)) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_REJECTED); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = !!cp->param[0]; + + if (val) { + changed = !hci_dev_test_and_set_flag(hdev, + HCI_ENABLE_LL_PRIVACY); + hci_dev_clear_flag(hdev, HCI_ADVERTISING); + + /* Enable LL privacy + supported settings changed */ + flags = BIT(0) | BIT(1); + } else { + changed = hci_dev_test_and_clear_flag(hdev, + HCI_ENABLE_LL_PRIVACY); + + /* Disable LL privacy + supported settings changed */ + flags = BIT(1); + } + + memcpy(rp.uuid, rpa_resolution_uuid, 16); + rp.flags = cpu_to_le32(flags); + + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_ll_privacy_feature_changed(val, hdev, sk); + + return err; +} + +static int set_quality_report_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, + u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + bool val, changed; + int err; + + /* Command requires to use a valid controller index */ + if (!hdev) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + hci_req_sync_lock(hdev); + + val = !!cp->param[0]; + changed = (val != hci_dev_test_flag(hdev, HCI_QUALITY_REPORT)); + + if (!aosp_has_quality_report(hdev) && !hdev->set_quality_report) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_NOT_SUPPORTED); + goto unlock_quality_report; + } + + if (changed) { + if (hdev->set_quality_report) + err = hdev->set_quality_report(hdev, val); + else + err = aosp_set_quality_report(hdev, val); + + if (err) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_FAILED); + goto unlock_quality_report; + } + + if (val) + hci_dev_set_flag(hdev, HCI_QUALITY_REPORT); + else + hci_dev_clear_flag(hdev, HCI_QUALITY_REPORT); + } + + bt_dev_dbg(hdev, "quality report enable %d changed %d", val, changed); + + memcpy(rp.uuid, quality_report_uuid, 16); + rp.flags = 
cpu_to_le32(val ? BIT(0) : 0); + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, quality_report_uuid, val, sk); + +unlock_quality_report: + hci_req_sync_unlock(hdev); + return err; +} + +static int set_offload_codec_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, + u16 data_len) +{ + bool val, changed; + int err; + struct mgmt_rp_set_exp_feature rp; + + /* Command requires to use a valid controller index */ + if (!hdev) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = !!cp->param[0]; + changed = (val != hci_dev_test_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED)); + + if (!hdev->get_data_path_id) { + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_NOT_SUPPORTED); + } + + if (changed) { + if (val) + hci_dev_set_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED); + else + hci_dev_clear_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED); + } + + bt_dev_info(hdev, "offload codecs enable %d changed %d", + val, changed); + + memcpy(rp.uuid, offload_codecs_uuid, 16); + rp.flags = cpu_to_le32(val ? BIT(0) : 0); + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, offload_codecs_uuid, val, sk); + + return err; +} + +static int set_le_simultaneous_roles_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, + u16 data_len) +{ + bool val, changed; + int err; + struct mgmt_rp_set_exp_feature rp; + + /* Command requires to use a valid controller index */ + if (!hdev) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = !!cp->param[0]; + changed = (val != hci_dev_test_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES)); + + if (!hci_dev_le_state_simultaneous(hdev)) { + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_NOT_SUPPORTED); + } + + if (changed) { + if (val) + hci_dev_set_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES); + else + hci_dev_clear_flag(hdev, HCI_LE_SIMULTANEOUS_ROLES); + } + + bt_dev_info(hdev, "LE simultaneous roles enable %d changed %d", + val, changed); + + memcpy(rp.uuid, le_simultaneous_roles_uuid, 16); + rp.flags = cpu_to_le32(val ? 
BIT(0) : 0); + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, le_simultaneous_roles_uuid, val, sk); + + return err; +} + +#ifdef CONFIG_BT_LE +static int set_iso_socket_func(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, u16 data_len) +{ + struct mgmt_rp_set_exp_feature rp; + bool val, changed = false; + int err; + + /* Command requires to use the non-controller index */ + if (hdev) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_INDEX); + + /* Parameters are limited to a single octet */ + if (data_len != MGMT_SET_EXP_FEATURE_SIZE + 1) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + /* Only boolean on/off is supported */ + if (cp->param[0] != 0x00 && cp->param[0] != 0x01) + return mgmt_cmd_status(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_INVALID_PARAMS); + + val = cp->param[0] ? true : false; + if (val) + err = iso_init(); + else + err = iso_exit(); + + if (!err) + changed = true; + + memcpy(rp.uuid, iso_socket_uuid, 16); + rp.flags = cpu_to_le32(val ? BIT(0) : 0); + + hci_sock_set_flag(sk, HCI_MGMT_EXP_FEATURE_EVENTS); + + err = mgmt_cmd_complete(sk, MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, 0, + &rp, sizeof(rp)); + + if (changed) + exp_feature_changed(hdev, iso_socket_uuid, val, sk); + + return err; +} +#endif + +static const struct mgmt_exp_feature { + const u8 *uuid; + int (*set_func)(struct sock *sk, struct hci_dev *hdev, + struct mgmt_cp_set_exp_feature *cp, u16 data_len); +} exp_features[] = { + EXP_FEAT(ZERO_KEY, set_zero_key_func), +#ifdef CONFIG_BT_FEATURE_DEBUG + EXP_FEAT(debug_uuid, set_debug_func), +#endif + EXP_FEAT(mgmt_mesh_uuid, set_mgmt_mesh_func), + EXP_FEAT(rpa_resolution_uuid, set_rpa_resolution_func), + EXP_FEAT(quality_report_uuid, set_quality_report_func), + EXP_FEAT(offload_codecs_uuid, set_offload_codec_func), + EXP_FEAT(le_simultaneous_roles_uuid, set_le_simultaneous_roles_func), +#ifdef CONFIG_BT_LE + EXP_FEAT(iso_socket_uuid, set_iso_socket_func), +#endif + + /* end with a null feature */ + EXP_FEAT(NULL, NULL) +}; + +static int set_exp_feature(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_set_exp_feature *cp = data; + size_t i = 0; + + bt_dev_dbg(hdev, "sock %p", sk); + + for (i = 0; exp_features[i].uuid; i++) { + if (!memcmp(cp->uuid, exp_features[i].uuid, 16)) + return exp_features[i].set_func(sk, hdev, cp, data_len); + } + + return mgmt_cmd_status(sk, hdev ? hdev->id : MGMT_INDEX_NONE, + MGMT_OP_SET_EXP_FEATURE, + MGMT_STATUS_NOT_SUPPORTED); +} + +static u32 get_params_flags(struct hci_dev *hdev, + struct hci_conn_params *params) +{ + u32 flags = hdev->conn_flags; + + /* Devices using RPAs can only be programmed in the acceptlist if + * LL Privacy has been enable otherwise they cannot mark + * HCI_CONN_FLAG_REMOTE_WAKEUP. 
+ */ + if ((flags & HCI_CONN_FLAG_REMOTE_WAKEUP) && !use_ll_privacy(hdev) && + hci_find_irk_by_addr(hdev, &params->addr, params->addr_type)) + flags &= ~HCI_CONN_FLAG_REMOTE_WAKEUP; + + return flags; +} + +static int get_device_flags(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_cp_get_device_flags *cp = data; + struct mgmt_rp_get_device_flags rp; + struct bdaddr_list_with_flags *br_params; + struct hci_conn_params *params; + u32 supported_flags; + u32 current_flags = 0; + u8 status = MGMT_STATUS_INVALID_PARAMS; + + bt_dev_dbg(hdev, "Get device flags %pMR (type 0x%x)\n", + &cp->addr.bdaddr, cp->addr.type); + + hci_dev_lock(hdev); + + supported_flags = hdev->conn_flags; + + memset(&rp, 0, sizeof(rp)); + + if (cp->addr.type == BDADDR_BREDR) { + br_params = hci_bdaddr_list_lookup_with_flags(&hdev->accept_list, + &cp->addr.bdaddr, + cp->addr.type); + if (!br_params) + goto done; + + current_flags = br_params->flags; + } else { + params = hci_conn_params_lookup(hdev, &cp->addr.bdaddr, + le_addr_type(cp->addr.type)); + if (!params) + goto done; + + supported_flags = get_params_flags(hdev, params); + current_flags = params->flags; + } + + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + rp.supported_flags = cpu_to_le32(supported_flags); + rp.current_flags = cpu_to_le32(current_flags); + + status = MGMT_STATUS_SUCCESS; + +done: + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_DEVICE_FLAGS, status, + &rp, sizeof(rp)); +} + +static void device_flags_changed(struct sock *sk, struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 bdaddr_type, + u32 supported_flags, u32 current_flags) +{ + struct mgmt_ev_device_flags_changed ev; + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = bdaddr_type; + ev.supported_flags = cpu_to_le32(supported_flags); + ev.current_flags = cpu_to_le32(current_flags); + + mgmt_event(MGMT_EV_DEVICE_FLAGS_CHANGED, hdev, &ev, sizeof(ev), sk); +} + +static int set_device_flags(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_device_flags *cp = data; + struct bdaddr_list_with_flags *br_params; + struct hci_conn_params *params; + u8 status = MGMT_STATUS_INVALID_PARAMS; + u32 supported_flags; + u32 current_flags = __le32_to_cpu(cp->current_flags); + + bt_dev_dbg(hdev, "Set device flags %pMR (type 0x%x) = 0x%x", + &cp->addr.bdaddr, cp->addr.type, current_flags); + + // We should take hci_dev_lock() early, I think..
conn_flags can change + supported_flags = hdev->conn_flags; + + if ((supported_flags | current_flags) != supported_flags) { + bt_dev_warn(hdev, "Bad flag given (0x%x) vs supported (0x%0x)", + current_flags, supported_flags); + goto done; + } + + hci_dev_lock(hdev); + + if (cp->addr.type == BDADDR_BREDR) { + br_params = hci_bdaddr_list_lookup_with_flags(&hdev->accept_list, + &cp->addr.bdaddr, + cp->addr.type); + + if (br_params) { + br_params->flags = current_flags; + status = MGMT_STATUS_SUCCESS; + } else { + bt_dev_warn(hdev, "No such BR/EDR device %pMR (0x%x)", + &cp->addr.bdaddr, cp->addr.type); + } + + goto unlock; + } + + params = hci_conn_params_lookup(hdev, &cp->addr.bdaddr, + le_addr_type(cp->addr.type)); + if (!params) { + bt_dev_warn(hdev, "No such LE device %pMR (0x%x)", + &cp->addr.bdaddr, le_addr_type(cp->addr.type)); + goto unlock; + } + + supported_flags = get_params_flags(hdev, params); + + if ((supported_flags | current_flags) != supported_flags) { + bt_dev_warn(hdev, "Bad flag given (0x%x) vs supported (0x%0x)", + current_flags, supported_flags); + goto unlock; + } + + WRITE_ONCE(params->flags, current_flags); + status = MGMT_STATUS_SUCCESS; + + /* Update passive scan if HCI_CONN_FLAG_DEVICE_PRIVACY + * has been set. + */ + if (params->flags & HCI_CONN_FLAG_DEVICE_PRIVACY) + hci_update_passive_scan(hdev); + +unlock: + hci_dev_unlock(hdev); + +done: + if (status == MGMT_STATUS_SUCCESS) + device_flags_changed(sk, hdev, &cp->addr.bdaddr, cp->addr.type, + supported_flags, current_flags); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_DEVICE_FLAGS, status, + &cp->addr, sizeof(cp->addr)); +} + +static void mgmt_adv_monitor_added(struct sock *sk, struct hci_dev *hdev, + u16 handle) +{ + struct mgmt_ev_adv_monitor_added ev; + + ev.monitor_handle = cpu_to_le16(handle); + + mgmt_event(MGMT_EV_ADV_MONITOR_ADDED, hdev, &ev, sizeof(ev), sk); +} + +void mgmt_adv_monitor_removed(struct hci_dev *hdev, u16 handle) +{ + struct mgmt_ev_adv_monitor_removed ev; + struct mgmt_pending_cmd *cmd; + struct sock *sk_skip = NULL; + struct mgmt_cp_remove_adv_monitor *cp; + + cmd = pending_find(MGMT_OP_REMOVE_ADV_MONITOR, hdev); + if (cmd) { + cp = cmd->param; + + if (cp->monitor_handle) + sk_skip = cmd->sk; + } + + ev.monitor_handle = cpu_to_le16(handle); + + mgmt_event(MGMT_EV_ADV_MONITOR_REMOVED, hdev, &ev, sizeof(ev), sk_skip); +} + +static int read_adv_mon_features(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct adv_monitor *monitor = NULL; + struct mgmt_rp_read_adv_monitor_features *rp = NULL; + int handle, err; + size_t rp_size = 0; + __u32 supported = 0; + __u32 enabled = 0; + __u16 num_handles = 0; + __u16 handles[HCI_MAX_ADV_MONITOR_NUM_HANDLES]; + + BT_DBG("request for %s", hdev->name); + + hci_dev_lock(hdev); + + if (msft_monitor_supported(hdev)) + supported |= MGMT_ADV_MONITOR_FEATURE_MASK_OR_PATTERNS; + + idr_for_each_entry(&hdev->adv_monitors_idr, monitor, handle) + handles[num_handles++] = monitor->handle; + + hci_dev_unlock(hdev); + + rp_size = sizeof(*rp) + (num_handles * sizeof(u16)); + rp = kmalloc(rp_size, GFP_KERNEL); + if (!rp) + return -ENOMEM; + + /* All supported features are currently enabled */ + enabled = supported; + + rp->supported_features = cpu_to_le32(supported); + rp->enabled_features = cpu_to_le32(enabled); + rp->max_num_handles = cpu_to_le16(HCI_MAX_ADV_MONITOR_NUM_HANDLES); + rp->max_num_patterns = HCI_MAX_ADV_MONITOR_NUM_PATTERNS; + rp->num_handles = cpu_to_le16(num_handles); + if (num_handles) + memcpy(&rp->handles, &handles, 
(num_handles * sizeof(u16))); + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_READ_ADV_MONITOR_FEATURES, + MGMT_STATUS_SUCCESS, rp, rp_size); + + kfree(rp); + + return err; +} + +static void mgmt_add_adv_patterns_monitor_complete(struct hci_dev *hdev, + void *data, int status) +{ + struct mgmt_rp_add_adv_patterns_monitor rp; + struct mgmt_pending_cmd *cmd = data; + struct adv_monitor *monitor = cmd->user_data; + + hci_dev_lock(hdev); + + rp.monitor_handle = cpu_to_le16(monitor->handle); + + if (!status) { + mgmt_adv_monitor_added(cmd->sk, hdev, monitor->handle); + hdev->adv_monitors_cnt++; + if (monitor->state == ADV_MONITOR_STATE_NOT_REGISTERED) + monitor->state = ADV_MONITOR_STATE_REGISTERED; + hci_update_passive_scan(hdev); + } + + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(status), &rp, sizeof(rp)); + mgmt_pending_remove(cmd); + + hci_dev_unlock(hdev); + bt_dev_dbg(hdev, "add monitor %d complete, status %d", + rp.monitor_handle, status); +} + +static int mgmt_add_adv_patterns_monitor_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct adv_monitor *monitor = cmd->user_data; + + return hci_add_adv_monitor(hdev, monitor); +} + +static int __add_adv_patterns_monitor(struct sock *sk, struct hci_dev *hdev, + struct adv_monitor *m, u8 status, + void *data, u16 len, u16 op) +{ + struct mgmt_pending_cmd *cmd; + int err; + + hci_dev_lock(hdev); + + if (status) + goto unlock; + + if (pending_find(MGMT_OP_SET_LE, hdev) || + pending_find(MGMT_OP_ADD_ADV_PATTERNS_MONITOR, hdev) || + pending_find(MGMT_OP_ADD_ADV_PATTERNS_MONITOR_RSSI, hdev) || + pending_find(MGMT_OP_REMOVE_ADV_MONITOR, hdev)) { + status = MGMT_STATUS_BUSY; + goto unlock; + } + + cmd = mgmt_pending_add(sk, op, hdev, data, len); + if (!cmd) { + status = MGMT_STATUS_NO_RESOURCES; + goto unlock; + } + + cmd->user_data = m; + err = hci_cmd_sync_queue(hdev, mgmt_add_adv_patterns_monitor_sync, cmd, + mgmt_add_adv_patterns_monitor_complete); + if (err) { + if (err == -ENOMEM) + status = MGMT_STATUS_NO_RESOURCES; + else + status = MGMT_STATUS_FAILED; + + goto unlock; + } + + hci_dev_unlock(hdev); + + return 0; + +unlock: + hci_free_adv_monitor(hdev, m); + hci_dev_unlock(hdev); + return mgmt_cmd_status(sk, hdev->id, op, status); +} + +static void parse_adv_monitor_rssi(struct adv_monitor *m, + struct mgmt_adv_rssi_thresholds *rssi) +{ + if (rssi) { + m->rssi.low_threshold = rssi->low_threshold; + m->rssi.low_threshold_timeout = + __le16_to_cpu(rssi->low_threshold_timeout); + m->rssi.high_threshold = rssi->high_threshold; + m->rssi.high_threshold_timeout = + __le16_to_cpu(rssi->high_threshold_timeout); + m->rssi.sampling_period = rssi->sampling_period; + } else { + /* Default values. These numbers are the least constricting + * parameters for MSFT API to work, so it behaves as if there + * are no rssi parameter to consider. May need to be changed + * if other API are to be supported. 
+ */ + m->rssi.low_threshold = -127; + m->rssi.low_threshold_timeout = 60; + m->rssi.high_threshold = -127; + m->rssi.high_threshold_timeout = 0; + m->rssi.sampling_period = 0; + } +} + +static u8 parse_adv_monitor_pattern(struct adv_monitor *m, u8 pattern_count, + struct mgmt_adv_pattern *patterns) +{ + u8 offset = 0, length = 0; + struct adv_pattern *p = NULL; + int i; + + for (i = 0; i < pattern_count; i++) { + offset = patterns[i].offset; + length = patterns[i].length; + if (offset >= HCI_MAX_AD_LENGTH || + length > HCI_MAX_AD_LENGTH || + (offset + length) > HCI_MAX_AD_LENGTH) + return MGMT_STATUS_INVALID_PARAMS; + + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (!p) + return MGMT_STATUS_NO_RESOURCES; + + p->ad_type = patterns[i].ad_type; + p->offset = patterns[i].offset; + p->length = patterns[i].length; + memcpy(p->value, patterns[i].value, p->length); + + INIT_LIST_HEAD(&p->list); + list_add(&p->list, &m->patterns); + } + + return MGMT_STATUS_SUCCESS; +} + +static int add_adv_patterns_monitor(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_add_adv_patterns_monitor *cp = data; + struct adv_monitor *m = NULL; + u8 status = MGMT_STATUS_SUCCESS; + size_t expected_size = sizeof(*cp); + + BT_DBG("request for %s", hdev->name); + + if (len <= sizeof(*cp)) { + status = MGMT_STATUS_INVALID_PARAMS; + goto done; + } + + expected_size += cp->pattern_count * sizeof(struct mgmt_adv_pattern); + if (len != expected_size) { + status = MGMT_STATUS_INVALID_PARAMS; + goto done; + } + + m = kzalloc(sizeof(*m), GFP_KERNEL); + if (!m) { + status = MGMT_STATUS_NO_RESOURCES; + goto done; + } + + INIT_LIST_HEAD(&m->patterns); + + parse_adv_monitor_rssi(m, NULL); + status = parse_adv_monitor_pattern(m, cp->pattern_count, cp->patterns); + +done: + return __add_adv_patterns_monitor(sk, hdev, m, status, data, len, + MGMT_OP_ADD_ADV_PATTERNS_MONITOR); +} + +static int add_adv_patterns_monitor_rssi(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_add_adv_patterns_monitor_rssi *cp = data; + struct adv_monitor *m = NULL; + u8 status = MGMT_STATUS_SUCCESS; + size_t expected_size = sizeof(*cp); + + BT_DBG("request for %s", hdev->name); + + if (len <= sizeof(*cp)) { + status = MGMT_STATUS_INVALID_PARAMS; + goto done; + } + + expected_size += cp->pattern_count * sizeof(struct mgmt_adv_pattern); + if (len != expected_size) { + status = MGMT_STATUS_INVALID_PARAMS; + goto done; + } + + m = kzalloc(sizeof(*m), GFP_KERNEL); + if (!m) { + status = MGMT_STATUS_NO_RESOURCES; + goto done; + } + + INIT_LIST_HEAD(&m->patterns); + + parse_adv_monitor_rssi(m, &cp->rssi); + status = parse_adv_monitor_pattern(m, cp->pattern_count, cp->patterns); + +done: + return __add_adv_patterns_monitor(sk, hdev, m, status, data, len, + MGMT_OP_ADD_ADV_PATTERNS_MONITOR_RSSI); +} + +static void mgmt_remove_adv_monitor_complete(struct hci_dev *hdev, + void *data, int status) +{ + struct mgmt_rp_remove_adv_monitor rp; + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_remove_adv_monitor *cp = cmd->param; + + hci_dev_lock(hdev); + + rp.monitor_handle = cp->monitor_handle; + + if (!status) + hci_update_passive_scan(hdev); + + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(status), &rp, sizeof(rp)); + mgmt_pending_remove(cmd); + + hci_dev_unlock(hdev); + bt_dev_dbg(hdev, "remove monitor %d complete, status %d", + rp.monitor_handle, status); +} + +static int mgmt_remove_adv_monitor_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct 
mgmt_cp_remove_adv_monitor *cp = cmd->param; + u16 handle = __le16_to_cpu(cp->monitor_handle); + + if (!handle) + return hci_remove_all_adv_monitor(hdev); + + return hci_remove_single_adv_monitor(hdev, handle); +} + +static int remove_adv_monitor(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_pending_cmd *cmd; + int err, status; + + hci_dev_lock(hdev); + + if (pending_find(MGMT_OP_SET_LE, hdev) || + pending_find(MGMT_OP_REMOVE_ADV_MONITOR, hdev) || + pending_find(MGMT_OP_ADD_ADV_PATTERNS_MONITOR, hdev) || + pending_find(MGMT_OP_ADD_ADV_PATTERNS_MONITOR_RSSI, hdev)) { + status = MGMT_STATUS_BUSY; + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_REMOVE_ADV_MONITOR, hdev, data, len); + if (!cmd) { + status = MGMT_STATUS_NO_RESOURCES; + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, mgmt_remove_adv_monitor_sync, cmd, + mgmt_remove_adv_monitor_complete); + + if (err) { + mgmt_pending_remove(cmd); + + if (err == -ENOMEM) + status = MGMT_STATUS_NO_RESOURCES; + else + status = MGMT_STATUS_FAILED; + + goto unlock; + } + + hci_dev_unlock(hdev); + + return 0; + +unlock: + hci_dev_unlock(hdev); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_REMOVE_ADV_MONITOR, + status); +} + +static void read_local_oob_data_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_rp_read_local_oob_data mgmt_rp; + size_t rp_size = sizeof(mgmt_rp); + struct mgmt_pending_cmd *cmd = data; + struct sk_buff *skb = cmd->skb; + u8 status = mgmt_status(err); + + if (!status) { + if (!skb) + status = MGMT_STATUS_FAILED; + else if (IS_ERR(skb)) + status = mgmt_status(PTR_ERR(skb)); + else + status = mgmt_status(skb->data[0]); + } + + bt_dev_dbg(hdev, "status %d", status); + + if (status) { + mgmt_cmd_status(cmd->sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_DATA, status); + goto remove; + } + + memset(&mgmt_rp, 0, sizeof(mgmt_rp)); + + if (!bredr_sc_enabled(hdev)) { + struct hci_rp_read_local_oob_data *rp = (void *) skb->data; + + if (skb->len < sizeof(*rp)) { + mgmt_cmd_status(cmd->sk, hdev->id, + MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_FAILED); + goto remove; + } + + memcpy(mgmt_rp.hash192, rp->hash, sizeof(rp->hash)); + memcpy(mgmt_rp.rand192, rp->rand, sizeof(rp->rand)); + + rp_size -= sizeof(mgmt_rp.hash256) + sizeof(mgmt_rp.rand256); + } else { + struct hci_rp_read_local_oob_ext_data *rp = (void *) skb->data; + + if (skb->len < sizeof(*rp)) { + mgmt_cmd_status(cmd->sk, hdev->id, + MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_FAILED); + goto remove; + } + + memcpy(mgmt_rp.hash192, rp->hash192, sizeof(rp->hash192)); + memcpy(mgmt_rp.rand192, rp->rand192, sizeof(rp->rand192)); + + memcpy(mgmt_rp.hash256, rp->hash256, sizeof(rp->hash256)); + memcpy(mgmt_rp.rand256, rp->rand256, sizeof(rp->rand256)); + } + + mgmt_cmd_complete(cmd->sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_SUCCESS, &mgmt_rp, rp_size); + +remove: + if (skb && !IS_ERR(skb)) + kfree_skb(skb); + + mgmt_pending_free(cmd); +} + +static int read_local_oob_data_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + + if (bredr_sc_enabled(hdev)) + cmd->skb = hci_read_local_oob_data_sync(hdev, true, cmd->sk); + else + cmd->skb = hci_read_local_oob_data_sync(hdev, false, cmd->sk); + + if (IS_ERR(cmd->skb)) + return PTR_ERR(cmd->skb); + else + return 0; +} + +static int read_local_oob_data(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if 
(!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_NOT_POWERED); + goto unlock; + } + + if (!lmp_ssp_capable(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_NOT_SUPPORTED); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_READ_LOCAL_OOB_DATA, hdev, NULL, 0); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, read_local_oob_data_sync, cmd, + read_local_oob_data_complete); + + if (err < 0) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_DATA, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_free(cmd); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int add_remote_oob_data(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_addr_info *addr = data; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!bdaddr_type_is_valid(addr->type)) + return mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_REMOTE_OOB_DATA, + MGMT_STATUS_INVALID_PARAMS, + addr, sizeof(*addr)); + + hci_dev_lock(hdev); + + if (len == MGMT_ADD_REMOTE_OOB_DATA_SIZE) { + struct mgmt_cp_add_remote_oob_data *cp = data; + u8 status; + + if (cp->addr.type != BDADDR_BREDR) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_REMOTE_OOB_DATA, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + err = hci_add_remote_oob_data(hdev, &cp->addr.bdaddr, + cp->addr.type, cp->hash, + cp->rand, NULL, NULL); + if (err < 0) + status = MGMT_STATUS_FAILED; + else + status = MGMT_STATUS_SUCCESS; + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_REMOTE_OOB_DATA, status, + &cp->addr, sizeof(cp->addr)); + } else if (len == MGMT_ADD_REMOTE_OOB_EXT_DATA_SIZE) { + struct mgmt_cp_add_remote_oob_ext_data *cp = data; + u8 *rand192, *hash192, *rand256, *hash256; + u8 status; + + if (bdaddr_type_is_le(cp->addr.type)) { + /* Enforce zero-valued 192-bit parameters as + * long as legacy SMP OOB isn't implemented. + */ + if (memcmp(cp->rand192, ZERO_KEY, 16) || + memcmp(cp->hash192, ZERO_KEY, 16)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_REMOTE_OOB_DATA, + MGMT_STATUS_INVALID_PARAMS, + addr, sizeof(*addr)); + goto unlock; + } + + rand192 = NULL; + hash192 = NULL; + } else { + /* In case one of the P-192 values is set to zero, + * then just disable OOB data for P-192. + */ + if (!memcmp(cp->rand192, ZERO_KEY, 16) || + !memcmp(cp->hash192, ZERO_KEY, 16)) { + rand192 = NULL; + hash192 = NULL; + } else { + rand192 = cp->rand192; + hash192 = cp->hash192; + } + } + + /* In case one of the P-256 values is set to zero, then just + * disable OOB data for P-256. 
+ */ + if (!memcmp(cp->rand256, ZERO_KEY, 16) || + !memcmp(cp->hash256, ZERO_KEY, 16)) { + rand256 = NULL; + hash256 = NULL; + } else { + rand256 = cp->rand256; + hash256 = cp->hash256; + } + + err = hci_add_remote_oob_data(hdev, &cp->addr.bdaddr, + cp->addr.type, hash192, rand192, + hash256, rand256); + if (err < 0) + status = MGMT_STATUS_FAILED; + else + status = MGMT_STATUS_SUCCESS; + + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_REMOTE_OOB_DATA, + status, &cp->addr, sizeof(cp->addr)); + } else { + bt_dev_err(hdev, "add_remote_oob_data: invalid len of %u bytes", + len); + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_REMOTE_OOB_DATA, + MGMT_STATUS_INVALID_PARAMS); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int remove_remote_oob_data(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_remove_remote_oob_data *cp = data; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (cp->addr.type != BDADDR_BREDR) + return mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_REMOTE_OOB_DATA, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + + hci_dev_lock(hdev); + + if (!bacmp(&cp->addr.bdaddr, BDADDR_ANY)) { + hci_remote_oob_data_clear(hdev); + status = MGMT_STATUS_SUCCESS; + goto done; + } + + err = hci_remove_remote_oob_data(hdev, &cp->addr.bdaddr, cp->addr.type); + if (err < 0) + status = MGMT_STATUS_INVALID_PARAMS; + else + status = MGMT_STATUS_SUCCESS; + +done: + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_REMOVE_REMOTE_OOB_DATA, + status, &cp->addr, sizeof(cp->addr)); + + hci_dev_unlock(hdev); + return err; +} + +void mgmt_start_discovery_complete(struct hci_dev *hdev, u8 status) +{ + struct mgmt_pending_cmd *cmd; + + bt_dev_dbg(hdev, "status %u", status); + + hci_dev_lock(hdev); + + cmd = pending_find(MGMT_OP_START_DISCOVERY, hdev); + if (!cmd) + cmd = pending_find(MGMT_OP_START_SERVICE_DISCOVERY, hdev); + + if (!cmd) + cmd = pending_find(MGMT_OP_START_LIMITED_DISCOVERY, hdev); + + if (cmd) { + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); + } + + hci_dev_unlock(hdev); +} + +static bool discovery_type_is_valid(struct hci_dev *hdev, uint8_t type, + uint8_t *mgmt_status) +{ + switch (type) { + case DISCOV_TYPE_LE: + *mgmt_status = mgmt_le_support(hdev); + if (*mgmt_status) + return false; + break; + case DISCOV_TYPE_INTERLEAVED: + *mgmt_status = mgmt_le_support(hdev); + if (*mgmt_status) + return false; + fallthrough; + case DISCOV_TYPE_BREDR: + *mgmt_status = mgmt_bredr_support(hdev); + if (*mgmt_status) + return false; + break; + default: + *mgmt_status = MGMT_STATUS_INVALID_PARAMS; + return false; + } + + return true; +} + +static void start_discovery_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + + if (cmd != pending_find(MGMT_OP_START_DISCOVERY, hdev) && + cmd != pending_find(MGMT_OP_START_LIMITED_DISCOVERY, hdev) && + cmd != pending_find(MGMT_OP_START_SERVICE_DISCOVERY, hdev)) + return; + + bt_dev_dbg(hdev, "err %d", err); + + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, mgmt_status(err), + cmd->param, 1); + mgmt_pending_remove(cmd); + + hci_discovery_set_state(hdev, err ? 
DISCOVERY_STOPPED: + DISCOVERY_FINDING); +} + +static int start_discovery_sync(struct hci_dev *hdev, void *data) +{ + return hci_start_discovery_sync(hdev); +} + +static int start_discovery_internal(struct sock *sk, struct hci_dev *hdev, + u16 op, void *data, u16 len) +{ + struct mgmt_cp_start_discovery *cp = data; + struct mgmt_pending_cmd *cmd; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, op, + MGMT_STATUS_NOT_POWERED, + &cp->type, sizeof(cp->type)); + goto failed; + } + + if (hdev->discovery.state != DISCOVERY_STOPPED || + hci_dev_test_flag(hdev, HCI_PERIODIC_INQ)) { + err = mgmt_cmd_complete(sk, hdev->id, op, MGMT_STATUS_BUSY, + &cp->type, sizeof(cp->type)); + goto failed; + } + + if (!discovery_type_is_valid(hdev, cp->type, &status)) { + err = mgmt_cmd_complete(sk, hdev->id, op, status, + &cp->type, sizeof(cp->type)); + goto failed; + } + + /* Can't start discovery when it is paused */ + if (hdev->discovery_paused) { + err = mgmt_cmd_complete(sk, hdev->id, op, MGMT_STATUS_BUSY, + &cp->type, sizeof(cp->type)); + goto failed; + } + + /* Clear the discovery filter first to free any previously + * allocated memory for the UUID list. + */ + hci_discovery_filter_clear(hdev); + + hdev->discovery.type = cp->type; + hdev->discovery.report_invalid_rssi = false; + if (op == MGMT_OP_START_LIMITED_DISCOVERY) + hdev->discovery.limited = true; + else + hdev->discovery.limited = false; + + cmd = mgmt_pending_add(sk, op, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + err = hci_cmd_sync_queue(hdev, start_discovery_sync, cmd, + start_discovery_complete); + if (err < 0) { + mgmt_pending_remove(cmd); + goto failed; + } + + hci_discovery_set_state(hdev, DISCOVERY_STARTING); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int start_discovery(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + return start_discovery_internal(sk, hdev, MGMT_OP_START_DISCOVERY, + data, len); +} + +static int start_limited_discovery(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + return start_discovery_internal(sk, hdev, + MGMT_OP_START_LIMITED_DISCOVERY, + data, len); +} + +static int start_service_discovery(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_start_service_discovery *cp = data; + struct mgmt_pending_cmd *cmd; + const u16 max_uuid_count = ((U16_MAX - sizeof(*cp)) / 16); + u16 uuid_count, expected_len; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_NOT_POWERED, + &cp->type, sizeof(cp->type)); + goto failed; + } + + if (hdev->discovery.state != DISCOVERY_STOPPED || + hci_dev_test_flag(hdev, HCI_PERIODIC_INQ)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_BUSY, &cp->type, + sizeof(cp->type)); + goto failed; + } + + if (hdev->discovery_paused) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_BUSY, &cp->type, + sizeof(cp->type)); + goto failed; + } + + uuid_count = __le16_to_cpu(cp->uuid_count); + if (uuid_count > max_uuid_count) { + bt_dev_err(hdev, "service_discovery: too big uuid_count value %u", + uuid_count); + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_INVALID_PARAMS, &cp->type, + sizeof(cp->type)); + goto 
failed; + } + + expected_len = sizeof(*cp) + uuid_count * 16; + if (expected_len != len) { + bt_dev_err(hdev, "service_discovery: expected %u bytes, got %u bytes", + expected_len, len); + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_INVALID_PARAMS, &cp->type, + sizeof(cp->type)); + goto failed; + } + + if (!discovery_type_is_valid(hdev, cp->type, &status)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + status, &cp->type, sizeof(cp->type)); + goto failed; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_START_SERVICE_DISCOVERY, + hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto failed; + } + + /* Clear the discovery filter first to free any previously + * allocated memory for the UUID list. + */ + hci_discovery_filter_clear(hdev); + + hdev->discovery.result_filtering = true; + hdev->discovery.type = cp->type; + hdev->discovery.rssi = cp->rssi; + hdev->discovery.uuid_count = uuid_count; + + if (uuid_count > 0) { + hdev->discovery.uuids = kmemdup(cp->uuids, uuid_count * 16, + GFP_KERNEL); + if (!hdev->discovery.uuids) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_START_SERVICE_DISCOVERY, + MGMT_STATUS_FAILED, + &cp->type, sizeof(cp->type)); + mgmt_pending_remove(cmd); + goto failed; + } + } + + err = hci_cmd_sync_queue(hdev, start_discovery_sync, cmd, + start_discovery_complete); + if (err < 0) { + mgmt_pending_remove(cmd); + goto failed; + } + + hci_discovery_set_state(hdev, DISCOVERY_STARTING); + +failed: + hci_dev_unlock(hdev); + return err; +} + +void mgmt_stop_discovery_complete(struct hci_dev *hdev, u8 status) +{ + struct mgmt_pending_cmd *cmd; + + bt_dev_dbg(hdev, "status %u", status); + + hci_dev_lock(hdev); + + cmd = pending_find(MGMT_OP_STOP_DISCOVERY, hdev); + if (cmd) { + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); + } + + hci_dev_unlock(hdev); +} + +static void stop_discovery_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + + if (cmd != pending_find(MGMT_OP_STOP_DISCOVERY, hdev)) + return; + + bt_dev_dbg(hdev, "err %d", err); + + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, mgmt_status(err), + cmd->param, 1); + mgmt_pending_remove(cmd); + + if (!err) + hci_discovery_set_state(hdev, DISCOVERY_STOPPED); +} + +static int stop_discovery_sync(struct hci_dev *hdev, void *data) +{ + return hci_stop_discovery_sync(hdev); +} + +static int stop_discovery(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_stop_discovery *mgmt_cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hci_discovery_active(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_STOP_DISCOVERY, + MGMT_STATUS_REJECTED, &mgmt_cp->type, + sizeof(mgmt_cp->type)); + goto unlock; + } + + if (hdev->discovery.type != mgmt_cp->type) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_STOP_DISCOVERY, + MGMT_STATUS_INVALID_PARAMS, + &mgmt_cp->type, sizeof(mgmt_cp->type)); + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_STOP_DISCOVERY, hdev, data, len); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, stop_discovery_sync, cmd, + stop_discovery_complete); + if (err < 0) { + mgmt_pending_remove(cmd); + goto unlock; + } + + hci_discovery_set_state(hdev, DISCOVERY_STOPPING); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int confirm_name(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct 
mgmt_cp_confirm_name *cp = data; + struct inquiry_entry *e; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (!hci_discovery_active(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_CONFIRM_NAME, + MGMT_STATUS_FAILED, &cp->addr, + sizeof(cp->addr)); + goto failed; + } + + e = hci_inquiry_cache_lookup_unknown(hdev, &cp->addr.bdaddr); + if (!e) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_CONFIRM_NAME, + MGMT_STATUS_INVALID_PARAMS, &cp->addr, + sizeof(cp->addr)); + goto failed; + } + + if (cp->name_known) { + e->name_state = NAME_KNOWN; + list_del(&e->list); + } else { + e->name_state = NAME_NEEDED; + hci_inquiry_cache_update_resolve(hdev, e); + } + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_CONFIRM_NAME, 0, + &cp->addr, sizeof(cp->addr)); + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int block_device(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_block_device *cp = data; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_BLOCK_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + + hci_dev_lock(hdev); + + err = hci_bdaddr_list_add(&hdev->reject_list, &cp->addr.bdaddr, + cp->addr.type); + if (err < 0) { + status = MGMT_STATUS_FAILED; + goto done; + } + + mgmt_event(MGMT_EV_DEVICE_BLOCKED, hdev, &cp->addr, sizeof(cp->addr), + sk); + status = MGMT_STATUS_SUCCESS; + +done: + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_BLOCK_DEVICE, status, + &cp->addr, sizeof(cp->addr)); + + hci_dev_unlock(hdev); + + return err; +} + +static int unblock_device(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_unblock_device *cp = data; + u8 status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNBLOCK_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + + hci_dev_lock(hdev); + + err = hci_bdaddr_list_del(&hdev->reject_list, &cp->addr.bdaddr, + cp->addr.type); + if (err < 0) { + status = MGMT_STATUS_INVALID_PARAMS; + goto done; + } + + mgmt_event(MGMT_EV_DEVICE_UNBLOCKED, hdev, &cp->addr, sizeof(cp->addr), + sk); + status = MGMT_STATUS_SUCCESS; + +done: + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_UNBLOCK_DEVICE, status, + &cp->addr, sizeof(cp->addr)); + + hci_dev_unlock(hdev); + + return err; +} + +static int set_device_id_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_eir_sync(hdev); +} + +static int set_device_id(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_set_device_id *cp = data; + int err; + __u16 source; + + bt_dev_dbg(hdev, "sock %p", sk); + + source = __le16_to_cpu(cp->source); + + if (source > 0x0002) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEVICE_ID, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + hdev->devid_source = source; + hdev->devid_vendor = __le16_to_cpu(cp->vendor); + hdev->devid_product = __le16_to_cpu(cp->product); + hdev->devid_version = __le16_to_cpu(cp->version); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_DEVICE_ID, 0, + NULL, 0); + + hci_cmd_sync_queue(hdev, set_device_id_sync, NULL, NULL); + + hci_dev_unlock(hdev); + + return err; +} + +static void enable_advertising_instance(struct hci_dev *hdev, int err) +{ + if (err) + bt_dev_err(hdev, "failed to re-configure advertising %d", err); + else + bt_dev_dbg(hdev, "status %d", err); +} + 
+static void set_advertising_complete(struct hci_dev *hdev, void *data, int err) +{ + struct cmd_lookup match = { NULL, hdev }; + u8 instance; + struct adv_info *adv_instance; + u8 status = mgmt_status(err); + + if (status) { + mgmt_pending_foreach(MGMT_OP_SET_ADVERTISING, hdev, + cmd_status_rsp, &status); + return; + } + + if (hci_dev_test_flag(hdev, HCI_LE_ADV)) + hci_dev_set_flag(hdev, HCI_ADVERTISING); + else + hci_dev_clear_flag(hdev, HCI_ADVERTISING); + + mgmt_pending_foreach(MGMT_OP_SET_ADVERTISING, hdev, settings_rsp, + &match); + + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); + + /* If "Set Advertising" was just disabled and instance advertising was + * set up earlier, then re-enable multi-instance advertising. + */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING) || + list_empty(&hdev->adv_instances)) + return; + + instance = hdev->cur_adv_instance; + if (!instance) { + adv_instance = list_first_entry_or_null(&hdev->adv_instances, + struct adv_info, list); + if (!adv_instance) + return; + + instance = adv_instance->instance; + } + + err = hci_schedule_adv_instance_sync(hdev, instance, true); + + enable_advertising_instance(hdev, err); +} + +static int set_adv_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + u8 val = !!cp->val; + + if (cp->val == 0x02) + hci_dev_set_flag(hdev, HCI_ADVERTISING_CONNECTABLE); + else + hci_dev_clear_flag(hdev, HCI_ADVERTISING_CONNECTABLE); + + cancel_adv_timeout(hdev); + + if (val) { + /* Switch to instance "0" for the Set Advertising setting. + * We cannot use update_[adv|scan_rsp]_data() here as the + * HCI_ADVERTISING flag is not yet set. + */ + hdev->cur_adv_instance = 0x00; + + if (ext_adv_capable(hdev)) { + hci_start_ext_adv_sync(hdev, 0x00); + } else { + hci_update_adv_data_sync(hdev, 0x00); + hci_update_scan_rsp_data_sync(hdev, 0x00); + hci_enable_advertising_sync(hdev); + } + } else { + hci_disable_advertising_sync(hdev); + } + + return 0; +} + +static int set_advertising(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + u8 val, status; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + status = mgmt_le_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_ADVERTISING, + status); + + if (cp->val != 0x00 && cp->val != 0x01 && cp->val != 0x02) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + + if (hdev->advertising_paused) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_ADVERTISING, + MGMT_STATUS_BUSY); + + hci_dev_lock(hdev); + + val = !!cp->val; + + /* The following conditions are ones which mean that we should + * not do any HCI communication but directly send a mgmt + * response to user space (after toggling the flag if + * necessary). 
+ */ + if (!hdev_is_powered(hdev) || + (val == hci_dev_test_flag(hdev, HCI_ADVERTISING) && + (cp->val == 0x02) == hci_dev_test_flag(hdev, HCI_ADVERTISING_CONNECTABLE)) || + hci_dev_test_flag(hdev, HCI_MESH) || + hci_conn_num(hdev, LE_LINK) > 0 || + (hci_dev_test_flag(hdev, HCI_LE_SCAN) && + hdev->le_scan_type == LE_SCAN_ACTIVE)) { + bool changed; + + if (cp->val) { + hdev->cur_adv_instance = 0x00; + changed = !hci_dev_test_and_set_flag(hdev, HCI_ADVERTISING); + if (cp->val == 0x02) + hci_dev_set_flag(hdev, HCI_ADVERTISING_CONNECTABLE); + else + hci_dev_clear_flag(hdev, HCI_ADVERTISING_CONNECTABLE); + } else { + changed = hci_dev_test_and_clear_flag(hdev, HCI_ADVERTISING); + hci_dev_clear_flag(hdev, HCI_ADVERTISING_CONNECTABLE); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_ADVERTISING, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + + goto unlock; + } + + if (pending_find(MGMT_OP_SET_ADVERTISING, hdev) || + pending_find(MGMT_OP_SET_LE, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_ADVERTISING, + MGMT_STATUS_BUSY); + goto unlock; + } + + cmd = mgmt_pending_add(sk, MGMT_OP_SET_ADVERTISING, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_adv_sync, cmd, + set_advertising_complete); + + if (err < 0 && cmd) + mgmt_pending_remove(cmd); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_static_address(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_set_static_address *cp = data; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_STATIC_ADDRESS, + MGMT_STATUS_NOT_SUPPORTED); + + if (hdev_is_powered(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_STATIC_ADDRESS, + MGMT_STATUS_REJECTED); + + if (bacmp(&cp->bdaddr, BDADDR_ANY)) { + if (!bacmp(&cp->bdaddr, BDADDR_NONE)) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_STATIC_ADDRESS, + MGMT_STATUS_INVALID_PARAMS); + + /* Two most significant bits shall be set */ + if ((cp->bdaddr.b[5] & 0xc0) != 0xc0) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_STATIC_ADDRESS, + MGMT_STATUS_INVALID_PARAMS); + } + + hci_dev_lock(hdev); + + bacpy(&hdev->static_addr, &cp->bdaddr); + + err = send_settings_rsp(sk, MGMT_OP_SET_STATIC_ADDRESS, hdev); + if (err < 0) + goto unlock; + + err = new_settings(hdev, sk); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_scan_params(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_set_scan_params *cp = data; + __u16 interval, window; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SCAN_PARAMS, + MGMT_STATUS_NOT_SUPPORTED); + + interval = __le16_to_cpu(cp->interval); + + if (interval < 0x0004 || interval > 0x4000) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SCAN_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + window = __le16_to_cpu(cp->window); + + if (window < 0x0004 || window > 0x4000) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SCAN_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + if (window > interval) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SCAN_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + hdev->le_scan_interval = interval; + hdev->le_scan_window = window; + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_SET_SCAN_PARAMS, 0, + NULL, 0); + + /* If background scan is running, restart it so new parameters are + * 
loaded. + */ + if (hci_dev_test_flag(hdev, HCI_LE_SCAN) && + hdev->discovery.state == DISCOVERY_STOPPED) + hci_update_passive_scan(hdev); + + hci_dev_unlock(hdev); + + return err; +} + +static void fast_connectable_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + + bt_dev_dbg(hdev, "err %d", err); + + if (err) { + mgmt_cmd_status(cmd->sk, hdev->id, MGMT_OP_SET_FAST_CONNECTABLE, + mgmt_status(err)); + } else { + struct mgmt_mode *cp = cmd->param; + + if (cp->val) + hci_dev_set_flag(hdev, HCI_FAST_CONNECTABLE); + else + hci_dev_clear_flag(hdev, HCI_FAST_CONNECTABLE); + + send_settings_rsp(cmd->sk, MGMT_OP_SET_FAST_CONNECTABLE, hdev); + new_settings(hdev, cmd->sk); + } + + mgmt_pending_free(cmd); +} + +static int write_fast_connectable_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + + return hci_write_fast_connectable_sync(hdev, cp->val); +} + +static int set_fast_connectable(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) || + hdev->hci_ver < BLUETOOTH_VER_1_2) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_FAST_CONNECTABLE, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_FAST_CONNECTABLE, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!!cp->val == hci_dev_test_flag(hdev, HCI_FAST_CONNECTABLE)) { + err = send_settings_rsp(sk, MGMT_OP_SET_FAST_CONNECTABLE, hdev); + goto unlock; + } + + if (!hdev_is_powered(hdev)) { + hci_dev_change_flag(hdev, HCI_FAST_CONNECTABLE); + err = send_settings_rsp(sk, MGMT_OP_SET_FAST_CONNECTABLE, hdev); + new_settings(hdev, sk); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_SET_FAST_CONNECTABLE, hdev, data, + len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, write_fast_connectable_sync, cmd, + fast_connectable_complete); + + if (err < 0) { + mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_FAST_CONNECTABLE, + MGMT_STATUS_FAILED); + + if (cmd) + mgmt_pending_free(cmd); + } + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static void set_bredr_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + + bt_dev_dbg(hdev, "err %d", err); + + if (err) { + u8 mgmt_err = mgmt_status(err); + + /* We need to restore the flag if related HCI commands + * failed. + */ + hci_dev_clear_flag(hdev, HCI_BREDR_ENABLED); + + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, mgmt_err); + } else { + send_settings_rsp(cmd->sk, MGMT_OP_SET_BREDR, hdev); + new_settings(hdev, cmd->sk); + } + + mgmt_pending_free(cmd); +} + +static int set_bredr_sync(struct hci_dev *hdev, void *data) +{ + int status; + + status = hci_write_fast_connectable_sync(hdev, false); + + if (!status) + status = hci_update_scan_sync(hdev); + + /* Since only the advertising data flags will change, there + * is no need to update the scan response data. 
+ */ + if (!status) + status = hci_update_adv_data_sync(hdev, hdev->cur_adv_instance); + + return status; +} + +static int set_bredr(struct sock *sk, struct hci_dev *hdev, void *data, u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_bredr_capable(hdev) || !lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_NOT_SUPPORTED); + + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_REJECTED); + + if (cp->val != 0x00 && cp->val != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (cp->val == hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + err = send_settings_rsp(sk, MGMT_OP_SET_BREDR, hdev); + goto unlock; + } + + if (!hdev_is_powered(hdev)) { + if (!cp->val) { + hci_dev_clear_flag(hdev, HCI_DISCOVERABLE); + hci_dev_clear_flag(hdev, HCI_SSP_ENABLED); + hci_dev_clear_flag(hdev, HCI_LINK_SECURITY); + hci_dev_clear_flag(hdev, HCI_FAST_CONNECTABLE); + hci_dev_clear_flag(hdev, HCI_HS_ENABLED); + } + + hci_dev_change_flag(hdev, HCI_BREDR_ENABLED); + + err = send_settings_rsp(sk, MGMT_OP_SET_BREDR, hdev); + if (err < 0) + goto unlock; + + err = new_settings(hdev, sk); + goto unlock; + } + + /* Reject disabling when powered on */ + if (!cp->val) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_REJECTED); + goto unlock; + } else { + /* When configuring a dual-mode controller to operate + * with LE only and using a static address, then switching + * BR/EDR back on is not allowed. + * + * Dual-mode controllers shall operate with the public + * address as its identity address for BR/EDR and LE. So + * reject the attempt to create an invalid configuration. + * + * The same restrictions applies when secure connections + * has been enabled. For BR/EDR this is a controller feature + * while for LE it is a host stack feature. This means that + * switching BR/EDR back on when secure connections has been + * enabled is not a supported transaction. + */ + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + (bacmp(&hdev->static_addr, BDADDR_ANY) || + hci_dev_test_flag(hdev, HCI_SC_ENABLED))) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_REJECTED); + goto unlock; + } + } + + cmd = mgmt_pending_new(sk, MGMT_OP_SET_BREDR, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_bredr_sync, cmd, + set_bredr_complete); + + if (err < 0) { + mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_BREDR, + MGMT_STATUS_FAILED); + if (cmd) + mgmt_pending_free(cmd); + + goto unlock; + } + + /* We need to flip the bit already here so that + * hci_req_update_adv_data generates the correct flags. 
+ */ + hci_dev_set_flag(hdev, HCI_BREDR_ENABLED); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static void set_secure_conn_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp; + + bt_dev_dbg(hdev, "err %d", err); + + if (err) { + u8 mgmt_err = mgmt_status(err); + + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, mgmt_err); + goto done; + } + + cp = cmd->param; + + switch (cp->val) { + case 0x00: + hci_dev_clear_flag(hdev, HCI_SC_ENABLED); + hci_dev_clear_flag(hdev, HCI_SC_ONLY); + break; + case 0x01: + hci_dev_set_flag(hdev, HCI_SC_ENABLED); + hci_dev_clear_flag(hdev, HCI_SC_ONLY); + break; + case 0x02: + hci_dev_set_flag(hdev, HCI_SC_ENABLED); + hci_dev_set_flag(hdev, HCI_SC_ONLY); + break; + } + + send_settings_rsp(cmd->sk, cmd->opcode, hdev); + new_settings(hdev, cmd->sk); + +done: + mgmt_pending_free(cmd); +} + +static int set_secure_conn_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_mode *cp = cmd->param; + u8 val = !!cp->val; + + /* Force write of val */ + hci_dev_set_flag(hdev, HCI_SC_ENABLED); + + return hci_write_sc_support_sync(hdev, val); +} + +static int set_secure_conn(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_mode *cp = data; + struct mgmt_pending_cmd *cmd; + u8 val; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_sc_capable(hdev) && + !hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SECURE_CONN, + MGMT_STATUS_NOT_SUPPORTED); + + if (hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + lmp_sc_capable(hdev) && + !hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SECURE_CONN, + MGMT_STATUS_REJECTED); + + if (cp->val != 0x00 && cp->val != 0x01 && cp->val != 0x02) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SECURE_CONN, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev) || !lmp_sc_capable(hdev) || + !hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) { + bool changed; + + if (cp->val) { + changed = !hci_dev_test_and_set_flag(hdev, + HCI_SC_ENABLED); + if (cp->val == 0x02) + hci_dev_set_flag(hdev, HCI_SC_ONLY); + else + hci_dev_clear_flag(hdev, HCI_SC_ONLY); + } else { + changed = hci_dev_test_and_clear_flag(hdev, + HCI_SC_ENABLED); + hci_dev_clear_flag(hdev, HCI_SC_ONLY); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_SECURE_CONN, hdev); + if (err < 0) + goto failed; + + if (changed) + err = new_settings(hdev, sk); + + goto failed; + } + + val = !!cp->val; + + if (val == hci_dev_test_flag(hdev, HCI_SC_ENABLED) && + (cp->val == 0x02) == hci_dev_test_flag(hdev, HCI_SC_ONLY)) { + err = send_settings_rsp(sk, MGMT_OP_SET_SECURE_CONN, hdev); + goto failed; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_SET_SECURE_CONN, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, set_secure_conn_sync, cmd, + set_secure_conn_complete); + + if (err < 0) { + mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_SECURE_CONN, + MGMT_STATUS_FAILED); + if (cmd) + mgmt_pending_free(cmd); + } + +failed: + hci_dev_unlock(hdev); + return err; +} + +static int set_debug_keys(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_mode *cp = data; + bool changed, use_changed; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (cp->val != 0x00 && cp->val != 0x01 && cp->val != 0x02) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEBUG_KEYS, + MGMT_STATUS_INVALID_PARAMS); + 
+ hci_dev_lock(hdev); + + if (cp->val) + changed = !hci_dev_test_and_set_flag(hdev, HCI_KEEP_DEBUG_KEYS); + else + changed = hci_dev_test_and_clear_flag(hdev, + HCI_KEEP_DEBUG_KEYS); + + if (cp->val == 0x02) + use_changed = !hci_dev_test_and_set_flag(hdev, + HCI_USE_DEBUG_KEYS); + else + use_changed = hci_dev_test_and_clear_flag(hdev, + HCI_USE_DEBUG_KEYS); + + if (hdev_is_powered(hdev) && use_changed && + hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) { + u8 mode = (cp->val == 0x02) ? 0x01 : 0x00; + hci_send_cmd(hdev, HCI_OP_WRITE_SSP_DEBUG_MODE, + sizeof(mode), &mode); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_DEBUG_KEYS, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_privacy(struct sock *sk, struct hci_dev *hdev, void *cp_data, + u16 len) +{ + struct mgmt_cp_set_privacy *cp = cp_data; + bool changed; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PRIVACY, + MGMT_STATUS_NOT_SUPPORTED); + + if (cp->privacy != 0x00 && cp->privacy != 0x01 && cp->privacy != 0x02) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PRIVACY, + MGMT_STATUS_INVALID_PARAMS); + + if (hdev_is_powered(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PRIVACY, + MGMT_STATUS_REJECTED); + + hci_dev_lock(hdev); + + /* If user space supports this command it is also expected to + * handle IRKs. Therefore, set the HCI_RPA_RESOLVING flag. + */ + hci_dev_set_flag(hdev, HCI_RPA_RESOLVING); + + if (cp->privacy) { + changed = !hci_dev_test_and_set_flag(hdev, HCI_PRIVACY); + memcpy(hdev->irk, cp->irk, sizeof(hdev->irk)); + hci_dev_set_flag(hdev, HCI_RPA_EXPIRED); + hci_adv_instances_set_rpa_expired(hdev, true); + if (cp->privacy == 0x02) + hci_dev_set_flag(hdev, HCI_LIMITED_PRIVACY); + else + hci_dev_clear_flag(hdev, HCI_LIMITED_PRIVACY); + } else { + changed = hci_dev_test_and_clear_flag(hdev, HCI_PRIVACY); + memset(hdev->irk, 0, sizeof(hdev->irk)); + hci_dev_clear_flag(hdev, HCI_RPA_EXPIRED); + hci_adv_instances_set_rpa_expired(hdev, false); + hci_dev_clear_flag(hdev, HCI_LIMITED_PRIVACY); + } + + err = send_settings_rsp(sk, MGMT_OP_SET_PRIVACY, hdev); + if (err < 0) + goto unlock; + + if (changed) + err = new_settings(hdev, sk); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static bool irk_is_valid(struct mgmt_irk_info *irk) +{ + switch (irk->addr.type) { + case BDADDR_LE_PUBLIC: + return true; + + case BDADDR_LE_RANDOM: + /* Two most significant bits shall be set */ + if ((irk->addr.bdaddr.b[5] & 0xc0) != 0xc0) + return false; + return true; + } + + return false; +} + +static int load_irks(struct sock *sk, struct hci_dev *hdev, void *cp_data, + u16 len) +{ + struct mgmt_cp_load_irks *cp = cp_data; + const u16 max_irk_count = ((U16_MAX - sizeof(*cp)) / + sizeof(struct mgmt_irk_info)); + u16 irk_count, expected_len; + int i, err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_IRKS, + MGMT_STATUS_NOT_SUPPORTED); + + irk_count = __le16_to_cpu(cp->irk_count); + if (irk_count > max_irk_count) { + bt_dev_err(hdev, "load_irks: too big irk_count value %u", + irk_count); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_IRKS, + MGMT_STATUS_INVALID_PARAMS); + } + + expected_len = struct_size(cp, irks, irk_count); + if (expected_len != len) { + bt_dev_err(hdev, "load_irks: expected %u bytes, got %u bytes", + expected_len, len); + return mgmt_cmd_status(sk, 
hdev->id, MGMT_OP_LOAD_IRKS, + MGMT_STATUS_INVALID_PARAMS); + } + + bt_dev_dbg(hdev, "irk_count %u", irk_count); + + for (i = 0; i < irk_count; i++) { + struct mgmt_irk_info *key = &cp->irks[i]; + + if (!irk_is_valid(key)) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_LOAD_IRKS, + MGMT_STATUS_INVALID_PARAMS); + } + + hci_dev_lock(hdev); + + hci_smp_irks_clear(hdev); + + for (i = 0; i < irk_count; i++) { + struct mgmt_irk_info *irk = &cp->irks[i]; + u8 addr_type = le_addr_type(irk->addr.type); + + if (hci_is_blocked_key(hdev, + HCI_BLOCKED_KEY_TYPE_IRK, + irk->val)) { + bt_dev_warn(hdev, "Skipping blocked IRK for %pMR", + &irk->addr.bdaddr); + continue; + } + + /* When using SMP over BR/EDR, the addr type should be set to BREDR */ + if (irk->addr.type == BDADDR_BREDR) + addr_type = BDADDR_BREDR; + + hci_add_irk(hdev, &irk->addr.bdaddr, + addr_type, irk->val, + BDADDR_ANY); + } + + hci_dev_set_flag(hdev, HCI_RPA_RESOLVING); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_LOAD_IRKS, 0, NULL, 0); + + hci_dev_unlock(hdev); + + return err; +} + +static bool ltk_is_valid(struct mgmt_ltk_info *key) +{ + if (key->initiator != 0x00 && key->initiator != 0x01) + return false; + + switch (key->addr.type) { + case BDADDR_LE_PUBLIC: + return true; + + case BDADDR_LE_RANDOM: + /* Two most significant bits shall be set */ + if ((key->addr.bdaddr.b[5] & 0xc0) != 0xc0) + return false; + return true; + } + + return false; +} + +static int load_long_term_keys(struct sock *sk, struct hci_dev *hdev, + void *cp_data, u16 len) +{ + struct mgmt_cp_load_long_term_keys *cp = cp_data; + const u16 max_key_count = ((U16_MAX - sizeof(*cp)) / + sizeof(struct mgmt_ltk_info)); + u16 key_count, expected_len; + int i, err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LONG_TERM_KEYS, + MGMT_STATUS_NOT_SUPPORTED); + + key_count = __le16_to_cpu(cp->key_count); + if (key_count > max_key_count) { + bt_dev_err(hdev, "load_ltks: too big key_count value %u", + key_count); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LONG_TERM_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + expected_len = struct_size(cp, keys, key_count); + if (expected_len != len) { + bt_dev_err(hdev, "load_keys: expected %u bytes, got %u bytes", + expected_len, len); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_LONG_TERM_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + bt_dev_dbg(hdev, "key_count %u", key_count); + + for (i = 0; i < key_count; i++) { + struct mgmt_ltk_info *key = &cp->keys[i]; + + if (!ltk_is_valid(key)) + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_LOAD_LONG_TERM_KEYS, + MGMT_STATUS_INVALID_PARAMS); + } + + hci_dev_lock(hdev); + + hci_smp_ltks_clear(hdev); + + for (i = 0; i < key_count; i++) { + struct mgmt_ltk_info *key = &cp->keys[i]; + u8 type, authenticated; + u8 addr_type = le_addr_type(key->addr.type); + + if (hci_is_blocked_key(hdev, + HCI_BLOCKED_KEY_TYPE_LTK, + key->val)) { + bt_dev_warn(hdev, "Skipping blocked LTK for %pMR", + &key->addr.bdaddr); + continue; + } + + switch (key->type) { + case MGMT_LTK_UNAUTHENTICATED: + authenticated = 0x00; + type = key->initiator ? SMP_LTK : SMP_LTK_RESPONDER; + break; + case MGMT_LTK_AUTHENTICATED: + authenticated = 0x01; + type = key->initiator ? 
SMP_LTK : SMP_LTK_RESPONDER; + break; + case MGMT_LTK_P256_UNAUTH: + authenticated = 0x00; + type = SMP_LTK_P256; + break; + case MGMT_LTK_P256_AUTH: + authenticated = 0x01; + type = SMP_LTK_P256; + break; + case MGMT_LTK_P256_DEBUG: + authenticated = 0x00; + type = SMP_LTK_P256_DEBUG; + fallthrough; + default: + continue; + } + + /* When using SMP over BR/EDR, the addr type should be set to BREDR */ + if (key->addr.type == BDADDR_BREDR) + addr_type = BDADDR_BREDR; + + hci_add_ltk(hdev, &key->addr.bdaddr, + addr_type, type, authenticated, + key->val, key->enc_size, key->ediv, key->rand); + } + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_LOAD_LONG_TERM_KEYS, 0, + NULL, 0); + + hci_dev_unlock(hdev); + + return err; +} + +static void get_conn_info_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct hci_conn *conn = cmd->user_data; + struct mgmt_cp_get_conn_info *cp = cmd->param; + struct mgmt_rp_get_conn_info rp; + u8 status; + + bt_dev_dbg(hdev, "err %d", err); + + memcpy(&rp.addr, &cp->addr, sizeof(rp.addr)); + + status = mgmt_status(err); + if (status == MGMT_STATUS_SUCCESS) { + rp.rssi = conn->rssi; + rp.tx_power = conn->tx_power; + rp.max_tx_power = conn->max_tx_power; + } else { + rp.rssi = HCI_RSSI_INVALID; + rp.tx_power = HCI_TX_POWER_INVALID; + rp.max_tx_power = HCI_TX_POWER_INVALID; + } + + mgmt_cmd_complete(cmd->sk, cmd->index, MGMT_OP_GET_CONN_INFO, status, + &rp, sizeof(rp)); + + mgmt_pending_free(cmd); +} + +static int get_conn_info_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_get_conn_info *cp = cmd->param; + struct hci_conn *conn; + int err; + __le16 handle; + + /* Make sure we are still connected */ + if (cp->addr.type == BDADDR_BREDR) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + else + conn = hci_conn_hash_lookup_ba(hdev, LE_LINK, &cp->addr.bdaddr); + + if (!conn || conn->state != BT_CONNECTED) + return MGMT_STATUS_NOT_CONNECTED; + + cmd->user_data = conn; + handle = cpu_to_le16(conn->handle); + + /* Refresh RSSI each time */ + err = hci_read_rssi_sync(hdev, handle); + + /* For LE links TX power does not change thus we don't need to + * query for it once value is known. 
+ */ + if (!err && (!bdaddr_type_is_le(cp->addr.type) || + conn->tx_power == HCI_TX_POWER_INVALID)) + err = hci_read_tx_power_sync(hdev, handle, 0x00); + + /* Max TX power needs to be read only once per connection */ + if (!err && conn->max_tx_power == HCI_TX_POWER_INVALID) + err = hci_read_tx_power_sync(hdev, handle, 0x01); + + return err; +} + +static int get_conn_info(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_get_conn_info *cp = data; + struct mgmt_rp_get_conn_info rp; + struct hci_conn *conn; + unsigned long conn_info_age; + int err = 0; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (!bdaddr_type_is_valid(cp->addr.type)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONN_INFO, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONN_INFO, + MGMT_STATUS_NOT_POWERED, &rp, + sizeof(rp)); + goto unlock; + } + + if (cp->addr.type == BDADDR_BREDR) + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + else + conn = hci_conn_hash_lookup_ba(hdev, LE_LINK, &cp->addr.bdaddr); + + if (!conn || conn->state != BT_CONNECTED) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONN_INFO, + MGMT_STATUS_NOT_CONNECTED, &rp, + sizeof(rp)); + goto unlock; + } + + /* To avoid client trying to guess when to poll again for information we + * calculate conn info age as random value between min/max set in hdev. + */ + conn_info_age = hdev->conn_info_min_age + + prandom_u32_max(hdev->conn_info_max_age - + hdev->conn_info_min_age); + + /* Query controller to refresh cached values if they are too old or were + * never read. 
+ */ + if (time_after(jiffies, conn->conn_info_timestamp + + msecs_to_jiffies(conn_info_age)) || + !conn->conn_info_timestamp) { + struct mgmt_pending_cmd *cmd; + + cmd = mgmt_pending_new(sk, MGMT_OP_GET_CONN_INFO, hdev, data, + len); + if (!cmd) { + err = -ENOMEM; + } else { + err = hci_cmd_sync_queue(hdev, get_conn_info_sync, + cmd, get_conn_info_complete); + } + + if (err < 0) { + mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONN_INFO, + MGMT_STATUS_FAILED, &rp, sizeof(rp)); + + if (cmd) + mgmt_pending_free(cmd); + + goto unlock; + } + + conn->conn_info_timestamp = jiffies; + } else { + /* Cache is valid, just reply with values cached in hci_conn */ + rp.rssi = conn->rssi; + rp.tx_power = conn->tx_power; + rp.max_tx_power = conn->max_tx_power; + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CONN_INFO, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static void get_clock_info_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_get_clock_info *cp = cmd->param; + struct mgmt_rp_get_clock_info rp; + struct hci_conn *conn = cmd->user_data; + u8 status = mgmt_status(err); + + bt_dev_dbg(hdev, "err %d", err); + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (err) + goto complete; + + rp.local_clock = cpu_to_le32(hdev->clock); + + if (conn) { + rp.piconet_clock = cpu_to_le32(conn->clock); + rp.accuracy = cpu_to_le16(conn->clock_accuracy); + } + +complete: + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, status, &rp, + sizeof(rp)); + + mgmt_pending_free(cmd); +} + +static int get_clock_info_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_get_clock_info *cp = cmd->param; + struct hci_cp_read_clock hci_cp; + struct hci_conn *conn; + + memset(&hci_cp, 0, sizeof(hci_cp)); + hci_read_clock_sync(hdev, &hci_cp); + + /* Make sure connection still exists */ + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &cp->addr.bdaddr); + if (!conn || conn->state != BT_CONNECTED) + return MGMT_STATUS_NOT_CONNECTED; + + cmd->user_data = conn; + hci_cp.handle = cpu_to_le16(conn->handle); + hci_cp.which = 0x01; /* Piconet clock */ + + return hci_read_clock_sync(hdev, &hci_cp); +} + +static int get_clock_info(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_get_clock_info *cp = data; + struct mgmt_rp_get_clock_info rp; + struct mgmt_pending_cmd *cmd; + struct hci_conn *conn; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + memset(&rp, 0, sizeof(rp)); + bacpy(&rp.addr.bdaddr, &cp->addr.bdaddr); + rp.addr.type = cp->addr.type; + + if (cp->addr.type != BDADDR_BREDR) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CLOCK_INFO, + MGMT_STATUS_INVALID_PARAMS, + &rp, sizeof(rp)); + + hci_dev_lock(hdev); + + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CLOCK_INFO, + MGMT_STATUS_NOT_POWERED, &rp, + sizeof(rp)); + goto unlock; + } + + if (bacmp(&cp->addr.bdaddr, BDADDR_ANY)) { + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, + &cp->addr.bdaddr); + if (!conn || conn->state != BT_CONNECTED) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_GET_CLOCK_INFO, + MGMT_STATUS_NOT_CONNECTED, + &rp, sizeof(rp)); + goto unlock; + } + } else { + conn = NULL; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_GET_CLOCK_INFO, hdev, data, len); + if (!cmd) + err = -ENOMEM; + else + err = hci_cmd_sync_queue(hdev, get_clock_info_sync, cmd, + 
get_clock_info_complete); + + if (err < 0) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_CLOCK_INFO, + MGMT_STATUS_FAILED, &rp, sizeof(rp)); + + if (cmd) + mgmt_pending_free(cmd); + } + + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static bool is_connected(struct hci_dev *hdev, bdaddr_t *addr, u8 type) +{ + struct hci_conn *conn; + + conn = hci_conn_hash_lookup_ba(hdev, LE_LINK, addr); + if (!conn) + return false; + + if (conn->dst_type != type) + return false; + + if (conn->state != BT_CONNECTED) + return false; + + return true; +} + +/* This function requires the caller holds hdev->lock */ +static int hci_conn_params_set(struct hci_dev *hdev, bdaddr_t *addr, + u8 addr_type, u8 auto_connect) +{ + struct hci_conn_params *params; + + params = hci_conn_params_add(hdev, addr, addr_type); + if (!params) + return -EIO; + + if (params->auto_connect == auto_connect) + return 0; + + hci_pend_le_list_del_init(params); + + switch (auto_connect) { + case HCI_AUTO_CONN_DISABLED: + case HCI_AUTO_CONN_LINK_LOSS: + /* If auto connect is being disabled when we're trying to + * connect to device, keep connecting. + */ + if (params->explicit_connect) + hci_pend_le_list_add(params, &hdev->pend_le_conns); + break; + case HCI_AUTO_CONN_REPORT: + if (params->explicit_connect) + hci_pend_le_list_add(params, &hdev->pend_le_conns); + else + hci_pend_le_list_add(params, &hdev->pend_le_reports); + break; + case HCI_AUTO_CONN_DIRECT: + case HCI_AUTO_CONN_ALWAYS: + if (!is_connected(hdev, addr, addr_type)) + hci_pend_le_list_add(params, &hdev->pend_le_conns); + break; + } + + params->auto_connect = auto_connect; + + bt_dev_dbg(hdev, "addr %pMR (type %u) auto_connect %u", + addr, addr_type, auto_connect); + + return 0; +} + +static void device_added(struct sock *sk, struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 type, u8 action) +{ + struct mgmt_ev_device_added ev; + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = type; + ev.action = action; + + mgmt_event(MGMT_EV_DEVICE_ADDED, hdev, &ev, sizeof(ev), sk); +} + +static int add_device_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_passive_scan_sync(hdev); +} + +static int add_device(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_add_device *cp = data; + u8 auto_conn, addr_type; + struct hci_conn_params *params; + int err; + u32 current_flags = 0; + u32 supported_flags; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!bdaddr_type_is_valid(cp->addr.type) || + !bacmp(&cp->addr.bdaddr, BDADDR_ANY)) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + + if (cp->action != 0x00 && cp->action != 0x01 && cp->action != 0x02) + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + + hci_dev_lock(hdev); + + if (cp->addr.type == BDADDR_BREDR) { + /* Only incoming connections action is supported for now */ + if (cp->action != 0x01) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + err = hci_bdaddr_list_add_with_flags(&hdev->accept_list, + &cp->addr.bdaddr, + cp->addr.type, 0); + if (err) + goto unlock; + + hci_update_scan(hdev); + + goto added; + } + + addr_type = le_addr_type(cp->addr.type); + + if (cp->action == 0x02) + auto_conn = HCI_AUTO_CONN_ALWAYS; + else if (cp->action == 0x01) + auto_conn = HCI_AUTO_CONN_DIRECT; + else + auto_conn = HCI_AUTO_CONN_REPORT; + + /* Kernel 
internally uses conn_params with resolvable private + * address, but Add Device allows only identity addresses. + * Make sure it is enforced before calling + * hci_conn_params_lookup. + */ + if (!hci_is_identity_address(&cp->addr.bdaddr, addr_type)) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + /* If the connection parameters don't exist for this device, + * they will be created and configured with defaults. + */ + if (hci_conn_params_set(hdev, &cp->addr.bdaddr, addr_type, + auto_conn) < 0) { + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_DEVICE, + MGMT_STATUS_FAILED, &cp->addr, + sizeof(cp->addr)); + goto unlock; + } else { + params = hci_conn_params_lookup(hdev, &cp->addr.bdaddr, + addr_type); + if (params) + current_flags = params->flags; + } + + err = hci_cmd_sync_queue(hdev, add_device_sync, NULL, NULL); + if (err < 0) + goto unlock; + +added: + device_added(sk, hdev, &cp->addr.bdaddr, cp->addr.type, cp->action); + supported_flags = hdev->conn_flags; + device_flags_changed(NULL, hdev, &cp->addr.bdaddr, cp->addr.type, + supported_flags, current_flags); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_DEVICE, + MGMT_STATUS_SUCCESS, &cp->addr, + sizeof(cp->addr)); + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static void device_removed(struct sock *sk, struct hci_dev *hdev, + bdaddr_t *bdaddr, u8 type) +{ + struct mgmt_ev_device_removed ev; + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = type; + + mgmt_event(MGMT_EV_DEVICE_REMOVED, hdev, &ev, sizeof(ev), sk); +} + +static int remove_device_sync(struct hci_dev *hdev, void *data) +{ + return hci_update_passive_scan_sync(hdev); +} + +static int remove_device(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_remove_device *cp = data; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (bacmp(&cp->addr.bdaddr, BDADDR_ANY)) { + struct hci_conn_params *params; + u8 addr_type; + + if (!bdaddr_type_is_valid(cp->addr.type)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + if (cp->addr.type == BDADDR_BREDR) { + err = hci_bdaddr_list_del(&hdev->accept_list, + &cp->addr.bdaddr, + cp->addr.type); + if (err) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, + sizeof(cp->addr)); + goto unlock; + } + + hci_update_scan(hdev); + + device_removed(sk, hdev, &cp->addr.bdaddr, + cp->addr.type); + goto complete; + } + + addr_type = le_addr_type(cp->addr.type); + + /* Kernel internally uses conn_params with resolvable private + * address, but Remove Device allows only identity addresses. + * Make sure it is enforced before calling + * hci_conn_params_lookup. 
+ */ + if (!hci_is_identity_address(&cp->addr.bdaddr, addr_type)) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + params = hci_conn_params_lookup(hdev, &cp->addr.bdaddr, + addr_type); + if (!params) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + if (params->auto_connect == HCI_AUTO_CONN_DISABLED || + params->auto_connect == HCI_AUTO_CONN_EXPLICIT) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + hci_conn_params_free(params); + + device_removed(sk, hdev, &cp->addr.bdaddr, cp->addr.type); + } else { + struct hci_conn_params *p, *tmp; + struct bdaddr_list *b, *btmp; + + if (cp->addr.type) { + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_INVALID_PARAMS, + &cp->addr, sizeof(cp->addr)); + goto unlock; + } + + list_for_each_entry_safe(b, btmp, &hdev->accept_list, list) { + device_removed(sk, hdev, &b->bdaddr, b->bdaddr_type); + list_del(&b->list); + kfree(b); + } + + hci_update_scan(hdev); + + list_for_each_entry_safe(p, tmp, &hdev->le_conn_params, list) { + if (p->auto_connect == HCI_AUTO_CONN_DISABLED) + continue; + device_removed(sk, hdev, &p->addr, p->addr_type); + if (p->explicit_connect) { + p->auto_connect = HCI_AUTO_CONN_EXPLICIT; + continue; + } + hci_conn_params_free(p); + } + + bt_dev_dbg(hdev, "All LE connection parameters were removed"); + } + + hci_cmd_sync_queue(hdev, remove_device_sync, NULL, NULL); + +complete: + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_REMOVE_DEVICE, + MGMT_STATUS_SUCCESS, &cp->addr, + sizeof(cp->addr)); +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int load_conn_param(struct sock *sk, struct hci_dev *hdev, void *data, + u16 len) +{ + struct mgmt_cp_load_conn_param *cp = data; + const u16 max_param_count = ((U16_MAX - sizeof(*cp)) / + sizeof(struct mgmt_conn_param)); + u16 param_count, expected_len; + int i; + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_CONN_PARAM, + MGMT_STATUS_NOT_SUPPORTED); + + param_count = __le16_to_cpu(cp->param_count); + if (param_count > max_param_count) { + bt_dev_err(hdev, "load_conn_param: too big param_count value %u", + param_count); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_CONN_PARAM, + MGMT_STATUS_INVALID_PARAMS); + } + + expected_len = struct_size(cp, params, param_count); + if (expected_len != len) { + bt_dev_err(hdev, "load_conn_param: expected %u bytes, got %u bytes", + expected_len, len); + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_LOAD_CONN_PARAM, + MGMT_STATUS_INVALID_PARAMS); + } + + bt_dev_dbg(hdev, "param_count %u", param_count); + + hci_dev_lock(hdev); + + hci_conn_params_clear_disabled(hdev); + + for (i = 0; i < param_count; i++) { + struct mgmt_conn_param *param = &cp->params[i]; + struct hci_conn_params *hci_param; + u16 min, max, latency, timeout; + u8 addr_type; + + bt_dev_dbg(hdev, "Adding %pMR (type %u)", &param->addr.bdaddr, + param->addr.type); + + if (param->addr.type == BDADDR_LE_PUBLIC) { + addr_type = ADDR_LE_DEV_PUBLIC; + } else if (param->addr.type == BDADDR_LE_RANDOM) { + addr_type = ADDR_LE_DEV_RANDOM; + } else { + bt_dev_err(hdev, "ignoring invalid connection parameters"); + continue; + } + + min = le16_to_cpu(param->min_interval); + max = le16_to_cpu(param->max_interval); + latency = 
le16_to_cpu(param->latency); + timeout = le16_to_cpu(param->timeout); + + bt_dev_dbg(hdev, "min 0x%04x max 0x%04x latency 0x%04x timeout 0x%04x", + min, max, latency, timeout); + + if (hci_check_conn_params(min, max, latency, timeout) < 0) { + bt_dev_err(hdev, "ignoring invalid connection parameters"); + continue; + } + + hci_param = hci_conn_params_add(hdev, &param->addr.bdaddr, + addr_type); + if (!hci_param) { + bt_dev_err(hdev, "failed to add connection parameters"); + continue; + } + + hci_param->conn_min_interval = min; + hci_param->conn_max_interval = max; + hci_param->conn_latency = latency; + hci_param->supervision_timeout = timeout; + } + + hci_dev_unlock(hdev); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_LOAD_CONN_PARAM, 0, + NULL, 0); +} + +static int set_external_config(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_set_external_config *cp = data; + bool changed; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (hdev_is_powered(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_EXTERNAL_CONFIG, + MGMT_STATUS_REJECTED); + + if (cp->config != 0x00 && cp->config != 0x01) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_EXTERNAL_CONFIG, + MGMT_STATUS_INVALID_PARAMS); + + if (!test_bit(HCI_QUIRK_EXTERNAL_CONFIG, &hdev->quirks)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_EXTERNAL_CONFIG, + MGMT_STATUS_NOT_SUPPORTED); + + hci_dev_lock(hdev); + + if (cp->config) + changed = !hci_dev_test_and_set_flag(hdev, HCI_EXT_CONFIGURED); + else + changed = hci_dev_test_and_clear_flag(hdev, HCI_EXT_CONFIGURED); + + err = send_options_rsp(sk, MGMT_OP_SET_EXTERNAL_CONFIG, hdev); + if (err < 0) + goto unlock; + + if (!changed) + goto unlock; + + err = new_options(hdev, sk); + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED) == is_configured(hdev)) { + mgmt_index_removed(hdev); + + if (hci_dev_test_and_change_flag(hdev, HCI_UNCONFIGURED)) { + hci_dev_set_flag(hdev, HCI_CONFIG); + hci_dev_set_flag(hdev, HCI_AUTO_OFF); + + queue_work(hdev->req_workqueue, &hdev->power_on); + } else { + set_bit(HCI_RAW, &hdev->flags); + mgmt_index_added(hdev); + } + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static int set_public_address(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_cp_set_public_address *cp = data; + bool changed; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (hdev_is_powered(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PUBLIC_ADDRESS, + MGMT_STATUS_REJECTED); + + if (!bacmp(&cp->bdaddr, BDADDR_ANY)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PUBLIC_ADDRESS, + MGMT_STATUS_INVALID_PARAMS); + + if (!hdev->set_bdaddr) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_PUBLIC_ADDRESS, + MGMT_STATUS_NOT_SUPPORTED); + + hci_dev_lock(hdev); + + changed = !!bacmp(&hdev->public_addr, &cp->bdaddr); + bacpy(&hdev->public_addr, &cp->bdaddr); + + err = send_options_rsp(sk, MGMT_OP_SET_PUBLIC_ADDRESS, hdev); + if (err < 0) + goto unlock; + + if (!changed) + goto unlock; + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) + err = new_options(hdev, sk); + + if (is_configured(hdev)) { + mgmt_index_removed(hdev); + + hci_dev_clear_flag(hdev, HCI_UNCONFIGURED); + + hci_dev_set_flag(hdev, HCI_CONFIG); + hci_dev_set_flag(hdev, HCI_AUTO_OFF); + + queue_work(hdev->req_workqueue, &hdev->power_on); + } + +unlock: + hci_dev_unlock(hdev); + return err; +} + +static void read_local_oob_ext_data_complete(struct hci_dev *hdev, void *data, + int err) +{ + const struct mgmt_cp_read_local_oob_ext_data 
*mgmt_cp; + struct mgmt_rp_read_local_oob_ext_data *mgmt_rp; + u8 *h192, *r192, *h256, *r256; + struct mgmt_pending_cmd *cmd = data; + struct sk_buff *skb = cmd->skb; + u8 status = mgmt_status(err); + u16 eir_len; + + if (cmd != pending_find(MGMT_OP_READ_LOCAL_OOB_EXT_DATA, hdev)) + return; + + if (!status) { + if (!skb) + status = MGMT_STATUS_FAILED; + else if (IS_ERR(skb)) + status = mgmt_status(PTR_ERR(skb)); + else + status = mgmt_status(skb->data[0]); + } + + bt_dev_dbg(hdev, "status %u", status); + + mgmt_cp = cmd->param; + + if (status) { + status = mgmt_status(status); + eir_len = 0; + + h192 = NULL; + r192 = NULL; + h256 = NULL; + r256 = NULL; + } else if (!bredr_sc_enabled(hdev)) { + struct hci_rp_read_local_oob_data *rp; + + if (skb->len != sizeof(*rp)) { + status = MGMT_STATUS_FAILED; + eir_len = 0; + } else { + status = MGMT_STATUS_SUCCESS; + rp = (void *)skb->data; + + eir_len = 5 + 18 + 18; + h192 = rp->hash; + r192 = rp->rand; + h256 = NULL; + r256 = NULL; + } + } else { + struct hci_rp_read_local_oob_ext_data *rp; + + if (skb->len != sizeof(*rp)) { + status = MGMT_STATUS_FAILED; + eir_len = 0; + } else { + status = MGMT_STATUS_SUCCESS; + rp = (void *)skb->data; + + if (hci_dev_test_flag(hdev, HCI_SC_ONLY)) { + eir_len = 5 + 18 + 18; + h192 = NULL; + r192 = NULL; + } else { + eir_len = 5 + 18 + 18 + 18 + 18; + h192 = rp->hash192; + r192 = rp->rand192; + } + + h256 = rp->hash256; + r256 = rp->rand256; + } + } + + mgmt_rp = kmalloc(sizeof(*mgmt_rp) + eir_len, GFP_KERNEL); + if (!mgmt_rp) + goto done; + + if (eir_len == 0) + goto send_rsp; + + eir_len = eir_append_data(mgmt_rp->eir, 0, EIR_CLASS_OF_DEV, + hdev->dev_class, 3); + + if (h192 && r192) { + eir_len = eir_append_data(mgmt_rp->eir, eir_len, + EIR_SSP_HASH_C192, h192, 16); + eir_len = eir_append_data(mgmt_rp->eir, eir_len, + EIR_SSP_RAND_R192, r192, 16); + } + + if (h256 && r256) { + eir_len = eir_append_data(mgmt_rp->eir, eir_len, + EIR_SSP_HASH_C256, h256, 16); + eir_len = eir_append_data(mgmt_rp->eir, eir_len, + EIR_SSP_RAND_R256, r256, 16); + } + +send_rsp: + mgmt_rp->type = mgmt_cp->type; + mgmt_rp->eir_len = cpu_to_le16(eir_len); + + err = mgmt_cmd_complete(cmd->sk, hdev->id, + MGMT_OP_READ_LOCAL_OOB_EXT_DATA, status, + mgmt_rp, sizeof(*mgmt_rp) + eir_len); + if (err < 0 || status) + goto done; + + hci_sock_set_flag(cmd->sk, HCI_MGMT_OOB_DATA_EVENTS); + + err = mgmt_limited_event(MGMT_EV_LOCAL_OOB_DATA_UPDATED, hdev, + mgmt_rp, sizeof(*mgmt_rp) + eir_len, + HCI_MGMT_OOB_DATA_EVENTS, cmd->sk); +done: + if (skb && !IS_ERR(skb)) + kfree_skb(skb); + + kfree(mgmt_rp); + mgmt_pending_remove(cmd); +} + +static int read_local_ssp_oob_req(struct hci_dev *hdev, struct sock *sk, + struct mgmt_cp_read_local_oob_ext_data *cp) +{ + struct mgmt_pending_cmd *cmd; + int err; + + cmd = mgmt_pending_add(sk, MGMT_OP_READ_LOCAL_OOB_EXT_DATA, hdev, + cp, sizeof(*cp)); + if (!cmd) + return -ENOMEM; + + err = hci_cmd_sync_queue(hdev, read_local_oob_data_sync, cmd, + read_local_oob_ext_data_complete); + + if (err < 0) { + mgmt_pending_remove(cmd); + return err; + } + + return 0; +} + +static int read_local_oob_ext_data(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_read_local_oob_ext_data *cp = data; + struct mgmt_rp_read_local_oob_ext_data *rp; + size_t rp_len; + u16 eir_len; + u8 status, flags, role, addr[7], hash[16], rand[16]; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (hdev_is_powered(hdev)) { + switch (cp->type) { + case BIT(BDADDR_BREDR): + status = mgmt_bredr_support(hdev); + if 
(status) + eir_len = 0; + else + eir_len = 5; + break; + case (BIT(BDADDR_LE_PUBLIC) | BIT(BDADDR_LE_RANDOM)): + status = mgmt_le_support(hdev); + if (status) + eir_len = 0; + else + eir_len = 9 + 3 + 18 + 18 + 3; + break; + default: + status = MGMT_STATUS_INVALID_PARAMS; + eir_len = 0; + break; + } + } else { + status = MGMT_STATUS_NOT_POWERED; + eir_len = 0; + } + + rp_len = sizeof(*rp) + eir_len; + rp = kmalloc(rp_len, GFP_ATOMIC); + if (!rp) + return -ENOMEM; + + if (!status && !lmp_ssp_capable(hdev)) { + status = MGMT_STATUS_NOT_SUPPORTED; + eir_len = 0; + } + + if (status) + goto complete; + + hci_dev_lock(hdev); + + eir_len = 0; + switch (cp->type) { + case BIT(BDADDR_BREDR): + if (hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) { + err = read_local_ssp_oob_req(hdev, sk, cp); + hci_dev_unlock(hdev); + if (!err) + goto done; + + status = MGMT_STATUS_FAILED; + goto complete; + } else { + eir_len = eir_append_data(rp->eir, eir_len, + EIR_CLASS_OF_DEV, + hdev->dev_class, 3); + } + break; + case (BIT(BDADDR_LE_PUBLIC) | BIT(BDADDR_LE_RANDOM)): + if (hci_dev_test_flag(hdev, HCI_SC_ENABLED) && + smp_generate_oob(hdev, hash, rand) < 0) { + hci_dev_unlock(hdev); + status = MGMT_STATUS_FAILED; + goto complete; + } + + /* This should return the active RPA, but since the RPA + * is only programmed on demand, it is really hard to fill + * this in at the moment. For now disallow retrieving + * local out-of-band data when privacy is in use. + * + * Returning the identity address will not help here since + * pairing happens before the identity resolving key is + * known and thus the connection establishment happens + * based on the RPA and not the identity address. + */ + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) { + hci_dev_unlock(hdev); + status = MGMT_STATUS_REJECTED; + goto complete; + } + + if (hci_dev_test_flag(hdev, HCI_FORCE_STATIC_ADDR) || + !bacmp(&hdev->bdaddr, BDADDR_ANY) || + (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED) && + bacmp(&hdev->static_addr, BDADDR_ANY))) { + memcpy(addr, &hdev->static_addr, 6); + addr[6] = 0x01; + } else { + memcpy(addr, &hdev->bdaddr, 6); + addr[6] = 0x00; + } + + eir_len = eir_append_data(rp->eir, eir_len, EIR_LE_BDADDR, + addr, sizeof(addr)); + + if (hci_dev_test_flag(hdev, HCI_ADVERTISING)) + role = 0x02; + else + role = 0x01; + + eir_len = eir_append_data(rp->eir, eir_len, EIR_LE_ROLE, + &role, sizeof(role)); + + if (hci_dev_test_flag(hdev, HCI_SC_ENABLED)) { + eir_len = eir_append_data(rp->eir, eir_len, + EIR_LE_SC_CONFIRM, + hash, sizeof(hash)); + + eir_len = eir_append_data(rp->eir, eir_len, + EIR_LE_SC_RANDOM, + rand, sizeof(rand)); + } + + flags = mgmt_get_adv_discov_flags(hdev); + + if (!hci_dev_test_flag(hdev, HCI_BREDR_ENABLED)) + flags |= LE_AD_NO_BREDR; + + eir_len = eir_append_data(rp->eir, eir_len, EIR_FLAGS, + &flags, sizeof(flags)); + break; + } + + hci_dev_unlock(hdev); + + hci_sock_set_flag(sk, HCI_MGMT_OOB_DATA_EVENTS); + + status = MGMT_STATUS_SUCCESS; + +complete: + rp->type = cp->type; + rp->eir_len = cpu_to_le16(eir_len); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_LOCAL_OOB_EXT_DATA, + status, rp, sizeof(*rp) + eir_len); + if (err < 0 || status) + goto done; + + err = mgmt_limited_event(MGMT_EV_LOCAL_OOB_DATA_UPDATED, hdev, + rp, sizeof(*rp) + eir_len, + HCI_MGMT_OOB_DATA_EVENTS, sk); + +done: + kfree(rp); + + return err; +} + +static u32 get_supported_adv_flags(struct hci_dev *hdev) +{ + u32 flags = 0; + + flags |= MGMT_ADV_FLAG_CONNECTABLE; + flags |= MGMT_ADV_FLAG_DISCOV; + flags |= MGMT_ADV_FLAG_LIMITED_DISCOV; + flags |= 
MGMT_ADV_FLAG_MANAGED_FLAGS; + flags |= MGMT_ADV_FLAG_APPEARANCE; + flags |= MGMT_ADV_FLAG_LOCAL_NAME; + flags |= MGMT_ADV_PARAM_DURATION; + flags |= MGMT_ADV_PARAM_TIMEOUT; + flags |= MGMT_ADV_PARAM_INTERVALS; + flags |= MGMT_ADV_PARAM_TX_POWER; + flags |= MGMT_ADV_PARAM_SCAN_RSP; + + /* In extended adv TX_POWER returned from Set Adv Param + * will be always valid. + */ + if (hdev->adv_tx_power != HCI_TX_POWER_INVALID || ext_adv_capable(hdev)) + flags |= MGMT_ADV_FLAG_TX_POWER; + + if (ext_adv_capable(hdev)) { + flags |= MGMT_ADV_FLAG_SEC_1M; + flags |= MGMT_ADV_FLAG_HW_OFFLOAD; + flags |= MGMT_ADV_FLAG_CAN_SET_TX_POWER; + + if (hdev->le_features[1] & HCI_LE_PHY_2M) + flags |= MGMT_ADV_FLAG_SEC_2M; + + if (hdev->le_features[1] & HCI_LE_PHY_CODED) + flags |= MGMT_ADV_FLAG_SEC_CODED; + } + + return flags; +} + +static int read_adv_features(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_rp_read_adv_features *rp; + size_t rp_len; + int err; + struct adv_info *adv_instance; + u32 supported_flags; + u8 *instance; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_READ_ADV_FEATURES, + MGMT_STATUS_REJECTED); + + hci_dev_lock(hdev); + + rp_len = sizeof(*rp) + hdev->adv_instance_cnt; + rp = kmalloc(rp_len, GFP_ATOMIC); + if (!rp) { + hci_dev_unlock(hdev); + return -ENOMEM; + } + + supported_flags = get_supported_adv_flags(hdev); + + rp->supported_flags = cpu_to_le32(supported_flags); + rp->max_adv_data_len = HCI_MAX_AD_LENGTH; + rp->max_scan_rsp_len = HCI_MAX_AD_LENGTH; + rp->max_instances = hdev->le_num_of_adv_sets; + rp->num_instances = hdev->adv_instance_cnt; + + instance = rp->instance; + list_for_each_entry(adv_instance, &hdev->adv_instances, list) { + /* Only instances 1-le_num_of_adv_sets are externally visible */ + if (adv_instance->instance <= hdev->adv_instance_cnt) { + *instance = adv_instance->instance; + instance++; + } else { + rp->num_instances--; + rp_len--; + } + } + + hci_dev_unlock(hdev); + + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_READ_ADV_FEATURES, + MGMT_STATUS_SUCCESS, rp, rp_len); + + kfree(rp); + + return err; +} + +static u8 calculate_name_len(struct hci_dev *hdev) +{ + u8 buf[HCI_MAX_SHORT_NAME_LENGTH + 3]; + + return eir_append_local_name(hdev, buf, 0); +} + +static u8 tlv_data_max_len(struct hci_dev *hdev, u32 adv_flags, + bool is_adv_data) +{ + u8 max_len = HCI_MAX_AD_LENGTH; + + if (is_adv_data) { + if (adv_flags & (MGMT_ADV_FLAG_DISCOV | + MGMT_ADV_FLAG_LIMITED_DISCOV | + MGMT_ADV_FLAG_MANAGED_FLAGS)) + max_len -= 3; + + if (adv_flags & MGMT_ADV_FLAG_TX_POWER) + max_len -= 3; + } else { + if (adv_flags & MGMT_ADV_FLAG_LOCAL_NAME) + max_len -= calculate_name_len(hdev); + + if (adv_flags & (MGMT_ADV_FLAG_APPEARANCE)) + max_len -= 4; + } + + return max_len; +} + +static bool flags_managed(u32 adv_flags) +{ + return adv_flags & (MGMT_ADV_FLAG_DISCOV | + MGMT_ADV_FLAG_LIMITED_DISCOV | + MGMT_ADV_FLAG_MANAGED_FLAGS); +} + +static bool tx_power_managed(u32 adv_flags) +{ + return adv_flags & MGMT_ADV_FLAG_TX_POWER; +} + +static bool name_managed(u32 adv_flags) +{ + return adv_flags & MGMT_ADV_FLAG_LOCAL_NAME; +} + +static bool appearance_managed(u32 adv_flags) +{ + return adv_flags & MGMT_ADV_FLAG_APPEARANCE; +} + +static bool tlv_data_is_valid(struct hci_dev *hdev, u32 adv_flags, u8 *data, + u8 len, bool is_adv_data) +{ + int i, cur_len; + u8 max_len; + + max_len = tlv_data_max_len(hdev, adv_flags, is_adv_data); + + if (len > max_len) + return false; + + /* Make sure 
that the data is correctly formatted. */ + for (i = 0; i < len; i += (cur_len + 1)) { + cur_len = data[i]; + + if (!cur_len) + continue; + + if (data[i + 1] == EIR_FLAGS && + (!is_adv_data || flags_managed(adv_flags))) + return false; + + if (data[i + 1] == EIR_TX_POWER && tx_power_managed(adv_flags)) + return false; + + if (data[i + 1] == EIR_NAME_COMPLETE && name_managed(adv_flags)) + return false; + + if (data[i + 1] == EIR_NAME_SHORT && name_managed(adv_flags)) + return false; + + if (data[i + 1] == EIR_APPEARANCE && + appearance_managed(adv_flags)) + return false; + + /* If the current field length would exceed the total data + * length, then it's invalid. + */ + if (i + cur_len >= len) + return false; + } + + return true; +} + +static bool requested_adv_flags_are_valid(struct hci_dev *hdev, u32 adv_flags) +{ + u32 supported_flags, phy_flags; + + /* The current implementation only supports a subset of the specified + * flags. Also need to check mutual exclusiveness of sec flags. + */ + supported_flags = get_supported_adv_flags(hdev); + phy_flags = adv_flags & MGMT_ADV_FLAG_SEC_MASK; + if (adv_flags & ~supported_flags || + ((phy_flags && (phy_flags ^ (phy_flags & -phy_flags))))) + return false; + + return true; +} + +static bool adv_busy(struct hci_dev *hdev) +{ + return pending_find(MGMT_OP_SET_LE, hdev); +} + +static void add_adv_complete(struct hci_dev *hdev, struct sock *sk, u8 instance, + int err) +{ + struct adv_info *adv, *n; + + bt_dev_dbg(hdev, "err %d", err); + + hci_dev_lock(hdev); + + list_for_each_entry_safe(adv, n, &hdev->adv_instances, list) { + u8 instance; + + if (!adv->pending) + continue; + + if (!err) { + adv->pending = false; + continue; + } + + instance = adv->instance; + + if (hdev->cur_adv_instance == instance) + cancel_adv_timeout(hdev); + + hci_remove_adv_instance(hdev, instance); + mgmt_advertising_removed(sk, hdev, instance); + } + + hci_dev_unlock(hdev); +} + +static void add_advertising_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_advertising *cp = cmd->param; + struct mgmt_rp_add_advertising rp; + + memset(&rp, 0, sizeof(rp)); + + rp.instance = cp->instance; + + if (err) + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err)); + else + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err), &rp, sizeof(rp)); + + add_adv_complete(hdev, cmd->sk, cp->instance, err); + + mgmt_pending_free(cmd); +} + +static int add_advertising_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_advertising *cp = cmd->param; + + return hci_schedule_adv_instance_sync(hdev, cp->instance, true); +} + +static int add_advertising(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_add_advertising *cp = data; + struct mgmt_rp_add_advertising rp; + u32 flags; + u8 status; + u16 timeout, duration; + unsigned int prev_instance_cnt; + u8 schedule_instance = 0; + struct adv_info *adv, *next_instance; + int err; + struct mgmt_pending_cmd *cmd; + + bt_dev_dbg(hdev, "sock %p", sk); + + status = mgmt_le_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + status); + + if (cp->instance < 1 || cp->instance > hdev->le_num_of_adv_sets) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + + if (data_len != sizeof(*cp) + cp->adv_data_len + cp->scan_rsp_len) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + 
MGMT_STATUS_INVALID_PARAMS); + + flags = __le32_to_cpu(cp->flags); + timeout = __le16_to_cpu(cp->timeout); + duration = __le16_to_cpu(cp->duration); + + if (!requested_adv_flags_are_valid(hdev, flags)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + if (timeout && !hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_REJECTED); + goto unlock; + } + + if (adv_busy(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_BUSY); + goto unlock; + } + + if (!tlv_data_is_valid(hdev, flags, cp->data, cp->adv_data_len, true) || + !tlv_data_is_valid(hdev, flags, cp->data + cp->adv_data_len, + cp->scan_rsp_len, false)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + prev_instance_cnt = hdev->adv_instance_cnt; + + adv = hci_add_adv_instance(hdev, cp->instance, flags, + cp->adv_data_len, cp->data, + cp->scan_rsp_len, + cp->data + cp->adv_data_len, + timeout, duration, + HCI_ADV_TX_POWER_NO_PREFERENCE, + hdev->le_adv_min_interval, + hdev->le_adv_max_interval, 0); + if (IS_ERR(adv)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_FAILED); + goto unlock; + } + + /* Only trigger an advertising added event if a new instance was + * actually added. + */ + if (hdev->adv_instance_cnt > prev_instance_cnt) + mgmt_advertising_added(sk, hdev, cp->instance); + + if (hdev->cur_adv_instance == cp->instance) { + /* If the currently advertised instance is being changed then + * cancel the current advertising and schedule the next + * instance. If there is only one instance then the overridden + * advertising data will be visible right away. + */ + cancel_adv_timeout(hdev); + + next_instance = hci_get_next_instance(hdev, cp->instance); + if (next_instance) + schedule_instance = next_instance->instance; + } else if (!hdev->adv_instance_timeout) { + /* Immediately advertise the new instance if no other + * instance is currently being advertised. + */ + schedule_instance = cp->instance; + } + + /* If the HCI_ADVERTISING flag is set or the device isn't powered or + * there is no instance to be advertised then we have no HCI + * communication to make. Simply return. + */ + if (!hdev_is_powered(hdev) || + hci_dev_test_flag(hdev, HCI_ADVERTISING) || + !schedule_instance) { + rp.instance = cp->instance; + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_ADVERTISING, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); + goto unlock; + } + + /* We're good to go, update advertising data, parameters, and start + * advertising. 
+ */ + cmd = mgmt_pending_new(sk, MGMT_OP_ADD_ADVERTISING, hdev, data, + data_len); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + cp->instance = schedule_instance; + + err = hci_cmd_sync_queue(hdev, add_advertising_sync, cmd, + add_advertising_complete); + if (err < 0) + mgmt_pending_free(cmd); + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static void add_ext_adv_params_complete(struct hci_dev *hdev, void *data, + int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_ext_adv_params *cp = cmd->param; + struct mgmt_rp_add_ext_adv_params rp; + struct adv_info *adv; + u32 flags; + + BT_DBG("%s", hdev->name); + + hci_dev_lock(hdev); + + adv = hci_find_adv_instance(hdev, cp->instance); + if (!adv) + goto unlock; + + rp.instance = cp->instance; + rp.tx_power = adv->tx_power; + + /* While we're at it, inform userspace of the available space for this + * advertisement, given the flags that will be used. + */ + flags = __le32_to_cpu(cp->flags); + rp.max_adv_data_len = tlv_data_max_len(hdev, flags, true); + rp.max_scan_rsp_len = tlv_data_max_len(hdev, flags, false); + + if (err) { + /* If this advertisement was previously advertising and we + * failed to update it, we signal that it has been removed and + * delete its structure + */ + if (!adv->pending) + mgmt_advertising_removed(cmd->sk, hdev, cp->instance); + + hci_remove_adv_instance(hdev, cp->instance); + + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err)); + } else { + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err), &rp, sizeof(rp)); + } + +unlock: + if (cmd) + mgmt_pending_free(cmd); + + hci_dev_unlock(hdev); +} + +static int add_ext_adv_params_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_ext_adv_params *cp = cmd->param; + + return hci_setup_ext_adv_instance_sync(hdev, cp->instance); +} + +static int add_ext_adv_params(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_add_ext_adv_params *cp = data; + struct mgmt_rp_add_ext_adv_params rp; + struct mgmt_pending_cmd *cmd = NULL; + struct adv_info *adv; + u32 flags, min_interval, max_interval; + u16 timeout, duration; + u8 status; + s8 tx_power; + int err; + + BT_DBG("%s", hdev->name); + + status = mgmt_le_support(hdev); + if (status) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + status); + + if (cp->instance < 1 || cp->instance > hdev->le_num_of_adv_sets) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + /* The purpose of breaking add_advertising into two separate MGMT calls + * for params and data is to allow more parameters to be added to this + * structure in the future. For this reason, we verify that we have the + * bare minimum structure we know of when the interface was defined. Any + * extra parameters we don't know about will be ignored in this request. 
+ */ + if (data_len < MGMT_ADD_EXT_ADV_PARAMS_MIN_SIZE) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + flags = __le32_to_cpu(cp->flags); + + if (!requested_adv_flags_are_valid(hdev, flags)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_INVALID_PARAMS); + + hci_dev_lock(hdev); + + /* In new interface, we require that we are powered to register */ + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_REJECTED); + goto unlock; + } + + if (adv_busy(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_BUSY); + goto unlock; + } + + /* Parse defined parameters from request, use defaults otherwise */ + timeout = (flags & MGMT_ADV_PARAM_TIMEOUT) ? + __le16_to_cpu(cp->timeout) : 0; + + duration = (flags & MGMT_ADV_PARAM_DURATION) ? + __le16_to_cpu(cp->duration) : + hdev->def_multi_adv_rotation_duration; + + min_interval = (flags & MGMT_ADV_PARAM_INTERVALS) ? + __le32_to_cpu(cp->min_interval) : + hdev->le_adv_min_interval; + + max_interval = (flags & MGMT_ADV_PARAM_INTERVALS) ? + __le32_to_cpu(cp->max_interval) : + hdev->le_adv_max_interval; + + tx_power = (flags & MGMT_ADV_PARAM_TX_POWER) ? + cp->tx_power : + HCI_ADV_TX_POWER_NO_PREFERENCE; + + /* Create advertising instance with no advertising or response data */ + adv = hci_add_adv_instance(hdev, cp->instance, flags, 0, NULL, 0, NULL, + timeout, duration, tx_power, min_interval, + max_interval, 0); + + if (IS_ERR(adv)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_FAILED); + goto unlock; + } + + /* Submit request for advertising params if ext adv available */ + if (ext_adv_capable(hdev)) { + cmd = mgmt_pending_new(sk, MGMT_OP_ADD_EXT_ADV_PARAMS, hdev, + data, data_len); + if (!cmd) { + err = -ENOMEM; + hci_remove_adv_instance(hdev, cp->instance); + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, add_ext_adv_params_sync, cmd, + add_ext_adv_params_complete); + if (err < 0) + mgmt_pending_free(cmd); + } else { + rp.instance = cp->instance; + rp.tx_power = HCI_ADV_TX_POWER_NO_PREFERENCE; + rp.max_adv_data_len = tlv_data_max_len(hdev, flags, true); + rp.max_scan_rsp_len = tlv_data_max_len(hdev, flags, false); + err = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_ADD_EXT_ADV_PARAMS, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); + } + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static void add_ext_adv_data_complete(struct hci_dev *hdev, void *data, int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_ext_adv_data *cp = cmd->param; + struct mgmt_rp_add_advertising rp; + + add_adv_complete(hdev, cmd->sk, cp->instance, err); + + memset(&rp, 0, sizeof(rp)); + + rp.instance = cp->instance; + + if (err) + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err)); + else + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err), &rp, sizeof(rp)); + + mgmt_pending_free(cmd); +} + +static int add_ext_adv_data_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_add_ext_adv_data *cp = cmd->param; + int err; + + if (ext_adv_capable(hdev)) { + err = hci_update_adv_data_sync(hdev, cp->instance); + if (err) + return err; + + err = hci_update_scan_rsp_data_sync(hdev, cp->instance); + if (err) + return err; + + return hci_enable_ext_advertising_sync(hdev, cp->instance); + } + + return hci_schedule_adv_instance_sync(hdev, cp->instance, true); +} + +static int 
add_ext_adv_data(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + struct mgmt_cp_add_ext_adv_data *cp = data; + struct mgmt_rp_add_ext_adv_data rp; + u8 schedule_instance = 0; + struct adv_info *next_instance; + struct adv_info *adv_instance; + int err = 0; + struct mgmt_pending_cmd *cmd; + + BT_DBG("%s", hdev->name); + + hci_dev_lock(hdev); + + adv_instance = hci_find_adv_instance(hdev, cp->instance); + + if (!adv_instance) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + /* In new interface, we require that we are powered to register */ + if (!hdev_is_powered(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_STATUS_REJECTED); + goto clear_new_instance; + } + + if (adv_busy(hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_STATUS_BUSY); + goto clear_new_instance; + } + + /* Validate new data */ + if (!tlv_data_is_valid(hdev, adv_instance->flags, cp->data, + cp->adv_data_len, true) || + !tlv_data_is_valid(hdev, adv_instance->flags, cp->data + + cp->adv_data_len, cp->scan_rsp_len, false)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_STATUS_INVALID_PARAMS); + goto clear_new_instance; + } + + /* Set the data in the advertising instance */ + hci_set_adv_instance_data(hdev, cp->instance, cp->adv_data_len, + cp->data, cp->scan_rsp_len, + cp->data + cp->adv_data_len); + + /* If using software rotation, determine next instance to use */ + if (hdev->cur_adv_instance == cp->instance) { + /* If the currently advertised instance is being changed + * then cancel the current advertising and schedule the + * next instance. If there is only one instance then the + * overridden advertising data will be visible right + * away + */ + cancel_adv_timeout(hdev); + + next_instance = hci_get_next_instance(hdev, cp->instance); + if (next_instance) + schedule_instance = next_instance->instance; + } else if (!hdev->adv_instance_timeout) { + /* Immediately advertise the new instance if no other + * instance is currently being advertised. + */ + schedule_instance = cp->instance; + } + + /* If the HCI_ADVERTISING flag is set or there is no instance to + * be advertised then we have no HCI communication to make. + * Simply return. + */ + if (hci_dev_test_flag(hdev, HCI_ADVERTISING) || !schedule_instance) { + if (adv_instance->pending) { + mgmt_advertising_added(sk, hdev, cp->instance); + adv_instance->pending = false; + } + rp.instance = cp->instance; + err = mgmt_cmd_complete(sk, hdev->id, MGMT_OP_ADD_EXT_ADV_DATA, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_ADD_EXT_ADV_DATA, hdev, data, + data_len); + if (!cmd) { + err = -ENOMEM; + goto clear_new_instance; + } + + err = hci_cmd_sync_queue(hdev, add_ext_adv_data_sync, cmd, + add_ext_adv_data_complete); + if (err < 0) { + mgmt_pending_free(cmd); + goto clear_new_instance; + } + + /* We were successful in updating data, so trigger advertising_added + * event if this is an instance that wasn't previously advertising. 
If + * a failure occurs in the requests we initiated, we will remove the + * instance again in add_advertising_complete + */ + if (adv_instance->pending) + mgmt_advertising_added(sk, hdev, cp->instance); + + goto unlock; + +clear_new_instance: + hci_remove_adv_instance(hdev, cp->instance); + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static void remove_advertising_complete(struct hci_dev *hdev, void *data, + int err) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_remove_advertising *cp = cmd->param; + struct mgmt_rp_remove_advertising rp; + + bt_dev_dbg(hdev, "err %d", err); + + memset(&rp, 0, sizeof(rp)); + rp.instance = cp->instance; + + if (err) + mgmt_cmd_status(cmd->sk, cmd->index, cmd->opcode, + mgmt_status(err)); + else + mgmt_cmd_complete(cmd->sk, cmd->index, cmd->opcode, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); + + mgmt_pending_free(cmd); +} + +static int remove_advertising_sync(struct hci_dev *hdev, void *data) +{ + struct mgmt_pending_cmd *cmd = data; + struct mgmt_cp_remove_advertising *cp = cmd->param; + int err; + + err = hci_remove_advertising_sync(hdev, cmd->sk, cp->instance, true); + if (err) + return err; + + if (list_empty(&hdev->adv_instances)) + err = hci_disable_advertising_sync(hdev); + + return err; +} + +static int remove_advertising(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_remove_advertising *cp = data; + struct mgmt_pending_cmd *cmd; + int err; + + bt_dev_dbg(hdev, "sock %p", sk); + + hci_dev_lock(hdev); + + if (cp->instance && !hci_find_adv_instance(hdev, cp->instance)) { + err = mgmt_cmd_status(sk, hdev->id, + MGMT_OP_REMOVE_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + if (pending_find(MGMT_OP_SET_LE, hdev)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_REMOVE_ADVERTISING, + MGMT_STATUS_BUSY); + goto unlock; + } + + if (list_empty(&hdev->adv_instances)) { + err = mgmt_cmd_status(sk, hdev->id, MGMT_OP_REMOVE_ADVERTISING, + MGMT_STATUS_INVALID_PARAMS); + goto unlock; + } + + cmd = mgmt_pending_new(sk, MGMT_OP_REMOVE_ADVERTISING, hdev, data, + data_len); + if (!cmd) { + err = -ENOMEM; + goto unlock; + } + + err = hci_cmd_sync_queue(hdev, remove_advertising_sync, cmd, + remove_advertising_complete); + if (err < 0) + mgmt_pending_free(cmd); + +unlock: + hci_dev_unlock(hdev); + + return err; +} + +static int get_adv_size_info(struct sock *sk, struct hci_dev *hdev, + void *data, u16 data_len) +{ + struct mgmt_cp_get_adv_size_info *cp = data; + struct mgmt_rp_get_adv_size_info rp; + u32 flags, supported_flags; + + bt_dev_dbg(hdev, "sock %p", sk); + + if (!lmp_le_capable(hdev)) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_GET_ADV_SIZE_INFO, + MGMT_STATUS_REJECTED); + + if (cp->instance < 1 || cp->instance > hdev->le_num_of_adv_sets) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_GET_ADV_SIZE_INFO, + MGMT_STATUS_INVALID_PARAMS); + + flags = __le32_to_cpu(cp->flags); + + /* The current implementation only supports a subset of the specified + * flags. 
+ */ + supported_flags = get_supported_adv_flags(hdev); + if (flags & ~supported_flags) + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_GET_ADV_SIZE_INFO, + MGMT_STATUS_INVALID_PARAMS); + + rp.instance = cp->instance; + rp.flags = cp->flags; + rp.max_adv_data_len = tlv_data_max_len(hdev, flags, true); + rp.max_scan_rsp_len = tlv_data_max_len(hdev, flags, false); + + return mgmt_cmd_complete(sk, hdev->id, MGMT_OP_GET_ADV_SIZE_INFO, + MGMT_STATUS_SUCCESS, &rp, sizeof(rp)); +} + +static const struct hci_mgmt_handler mgmt_handlers[] = { + { NULL }, /* 0x0000 (no command) */ + { read_version, MGMT_READ_VERSION_SIZE, + HCI_MGMT_NO_HDEV | + HCI_MGMT_UNTRUSTED }, + { read_commands, MGMT_READ_COMMANDS_SIZE, + HCI_MGMT_NO_HDEV | + HCI_MGMT_UNTRUSTED }, + { read_index_list, MGMT_READ_INDEX_LIST_SIZE, + HCI_MGMT_NO_HDEV | + HCI_MGMT_UNTRUSTED }, + { read_controller_info, MGMT_READ_INFO_SIZE, + HCI_MGMT_UNTRUSTED }, + { set_powered, MGMT_SETTING_SIZE }, + { set_discoverable, MGMT_SET_DISCOVERABLE_SIZE }, + { set_connectable, MGMT_SETTING_SIZE }, + { set_fast_connectable, MGMT_SETTING_SIZE }, + { set_bondable, MGMT_SETTING_SIZE }, + { set_link_security, MGMT_SETTING_SIZE }, + { set_ssp, MGMT_SETTING_SIZE }, + { set_hs, MGMT_SETTING_SIZE }, + { set_le, MGMT_SETTING_SIZE }, + { set_dev_class, MGMT_SET_DEV_CLASS_SIZE }, + { set_local_name, MGMT_SET_LOCAL_NAME_SIZE }, + { add_uuid, MGMT_ADD_UUID_SIZE }, + { remove_uuid, MGMT_REMOVE_UUID_SIZE }, + { load_link_keys, MGMT_LOAD_LINK_KEYS_SIZE, + HCI_MGMT_VAR_LEN }, + { load_long_term_keys, MGMT_LOAD_LONG_TERM_KEYS_SIZE, + HCI_MGMT_VAR_LEN }, + { disconnect, MGMT_DISCONNECT_SIZE }, + { get_connections, MGMT_GET_CONNECTIONS_SIZE }, + { pin_code_reply, MGMT_PIN_CODE_REPLY_SIZE }, + { pin_code_neg_reply, MGMT_PIN_CODE_NEG_REPLY_SIZE }, + { set_io_capability, MGMT_SET_IO_CAPABILITY_SIZE }, + { pair_device, MGMT_PAIR_DEVICE_SIZE }, + { cancel_pair_device, MGMT_CANCEL_PAIR_DEVICE_SIZE }, + { unpair_device, MGMT_UNPAIR_DEVICE_SIZE }, + { user_confirm_reply, MGMT_USER_CONFIRM_REPLY_SIZE }, + { user_confirm_neg_reply, MGMT_USER_CONFIRM_NEG_REPLY_SIZE }, + { user_passkey_reply, MGMT_USER_PASSKEY_REPLY_SIZE }, + { user_passkey_neg_reply, MGMT_USER_PASSKEY_NEG_REPLY_SIZE }, + { read_local_oob_data, MGMT_READ_LOCAL_OOB_DATA_SIZE }, + { add_remote_oob_data, MGMT_ADD_REMOTE_OOB_DATA_SIZE, + HCI_MGMT_VAR_LEN }, + { remove_remote_oob_data, MGMT_REMOVE_REMOTE_OOB_DATA_SIZE }, + { start_discovery, MGMT_START_DISCOVERY_SIZE }, + { stop_discovery, MGMT_STOP_DISCOVERY_SIZE }, + { confirm_name, MGMT_CONFIRM_NAME_SIZE }, + { block_device, MGMT_BLOCK_DEVICE_SIZE }, + { unblock_device, MGMT_UNBLOCK_DEVICE_SIZE }, + { set_device_id, MGMT_SET_DEVICE_ID_SIZE }, + { set_advertising, MGMT_SETTING_SIZE }, + { set_bredr, MGMT_SETTING_SIZE }, + { set_static_address, MGMT_SET_STATIC_ADDRESS_SIZE }, + { set_scan_params, MGMT_SET_SCAN_PARAMS_SIZE }, + { set_secure_conn, MGMT_SETTING_SIZE }, + { set_debug_keys, MGMT_SETTING_SIZE }, + { set_privacy, MGMT_SET_PRIVACY_SIZE }, + { load_irks, MGMT_LOAD_IRKS_SIZE, + HCI_MGMT_VAR_LEN }, + { get_conn_info, MGMT_GET_CONN_INFO_SIZE }, + { get_clock_info, MGMT_GET_CLOCK_INFO_SIZE }, + { add_device, MGMT_ADD_DEVICE_SIZE }, + { remove_device, MGMT_REMOVE_DEVICE_SIZE }, + { load_conn_param, MGMT_LOAD_CONN_PARAM_SIZE, + HCI_MGMT_VAR_LEN }, + { read_unconf_index_list, MGMT_READ_UNCONF_INDEX_LIST_SIZE, + HCI_MGMT_NO_HDEV | + HCI_MGMT_UNTRUSTED }, + { read_config_info, MGMT_READ_CONFIG_INFO_SIZE, + HCI_MGMT_UNCONFIGURED | + HCI_MGMT_UNTRUSTED }, + { 
set_external_config, MGMT_SET_EXTERNAL_CONFIG_SIZE, + HCI_MGMT_UNCONFIGURED }, + { set_public_address, MGMT_SET_PUBLIC_ADDRESS_SIZE, + HCI_MGMT_UNCONFIGURED }, + { start_service_discovery, MGMT_START_SERVICE_DISCOVERY_SIZE, + HCI_MGMT_VAR_LEN }, + { read_local_oob_ext_data, MGMT_READ_LOCAL_OOB_EXT_DATA_SIZE }, + { read_ext_index_list, MGMT_READ_EXT_INDEX_LIST_SIZE, + HCI_MGMT_NO_HDEV | + HCI_MGMT_UNTRUSTED }, + { read_adv_features, MGMT_READ_ADV_FEATURES_SIZE }, + { add_advertising, MGMT_ADD_ADVERTISING_SIZE, + HCI_MGMT_VAR_LEN }, + { remove_advertising, MGMT_REMOVE_ADVERTISING_SIZE }, + { get_adv_size_info, MGMT_GET_ADV_SIZE_INFO_SIZE }, + { start_limited_discovery, MGMT_START_DISCOVERY_SIZE }, + { read_ext_controller_info,MGMT_READ_EXT_INFO_SIZE, + HCI_MGMT_UNTRUSTED }, + { set_appearance, MGMT_SET_APPEARANCE_SIZE }, + { get_phy_configuration, MGMT_GET_PHY_CONFIGURATION_SIZE }, + { set_phy_configuration, MGMT_SET_PHY_CONFIGURATION_SIZE }, + { set_blocked_keys, MGMT_OP_SET_BLOCKED_KEYS_SIZE, + HCI_MGMT_VAR_LEN }, + { set_wideband_speech, MGMT_SETTING_SIZE }, + { read_controller_cap, MGMT_READ_CONTROLLER_CAP_SIZE, + HCI_MGMT_UNTRUSTED }, + { read_exp_features_info, MGMT_READ_EXP_FEATURES_INFO_SIZE, + HCI_MGMT_UNTRUSTED | + HCI_MGMT_HDEV_OPTIONAL }, + { set_exp_feature, MGMT_SET_EXP_FEATURE_SIZE, + HCI_MGMT_VAR_LEN | + HCI_MGMT_HDEV_OPTIONAL }, + { read_def_system_config, MGMT_READ_DEF_SYSTEM_CONFIG_SIZE, + HCI_MGMT_UNTRUSTED }, + { set_def_system_config, MGMT_SET_DEF_SYSTEM_CONFIG_SIZE, + HCI_MGMT_VAR_LEN }, + { read_def_runtime_config, MGMT_READ_DEF_RUNTIME_CONFIG_SIZE, + HCI_MGMT_UNTRUSTED }, + { set_def_runtime_config, MGMT_SET_DEF_RUNTIME_CONFIG_SIZE, + HCI_MGMT_VAR_LEN }, + { get_device_flags, MGMT_GET_DEVICE_FLAGS_SIZE }, + { set_device_flags, MGMT_SET_DEVICE_FLAGS_SIZE }, + { read_adv_mon_features, MGMT_READ_ADV_MONITOR_FEATURES_SIZE }, + { add_adv_patterns_monitor,MGMT_ADD_ADV_PATTERNS_MONITOR_SIZE, + HCI_MGMT_VAR_LEN }, + { remove_adv_monitor, MGMT_REMOVE_ADV_MONITOR_SIZE }, + { add_ext_adv_params, MGMT_ADD_EXT_ADV_PARAMS_MIN_SIZE, + HCI_MGMT_VAR_LEN }, + { add_ext_adv_data, MGMT_ADD_EXT_ADV_DATA_SIZE, + HCI_MGMT_VAR_LEN }, + { add_adv_patterns_monitor_rssi, + MGMT_ADD_ADV_PATTERNS_MONITOR_RSSI_SIZE, + HCI_MGMT_VAR_LEN }, + { set_mesh, MGMT_SET_MESH_RECEIVER_SIZE, + HCI_MGMT_VAR_LEN }, + { mesh_features, MGMT_MESH_READ_FEATURES_SIZE }, + { mesh_send, MGMT_MESH_SEND_SIZE, + HCI_MGMT_VAR_LEN }, + { mesh_send_cancel, MGMT_MESH_SEND_CANCEL_SIZE }, +}; + +void mgmt_index_added(struct hci_dev *hdev) +{ + struct mgmt_ev_ext_index ev; + + if (test_bit(HCI_QUIRK_RAW_DEVICE, &hdev->quirks)) + return; + + switch (hdev->dev_type) { + case HCI_PRIMARY: + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + mgmt_index_event(MGMT_EV_UNCONF_INDEX_ADDED, hdev, + NULL, 0, HCI_MGMT_UNCONF_INDEX_EVENTS); + ev.type = 0x01; + } else { + mgmt_index_event(MGMT_EV_INDEX_ADDED, hdev, NULL, 0, + HCI_MGMT_INDEX_EVENTS); + ev.type = 0x00; + } + break; + case HCI_AMP: + ev.type = 0x02; + break; + default: + return; + } + + ev.bus = hdev->bus; + + mgmt_index_event(MGMT_EV_EXT_INDEX_ADDED, hdev, &ev, sizeof(ev), + HCI_MGMT_EXT_INDEX_EVENTS); +} + +void mgmt_index_removed(struct hci_dev *hdev) +{ + struct mgmt_ev_ext_index ev; + u8 status = MGMT_STATUS_INVALID_INDEX; + + if (test_bit(HCI_QUIRK_RAW_DEVICE, &hdev->quirks)) + return; + + switch (hdev->dev_type) { + case HCI_PRIMARY: + mgmt_pending_foreach(0, hdev, cmd_complete_rsp, &status); + + if (hci_dev_test_flag(hdev, HCI_UNCONFIGURED)) { + 
mgmt_index_event(MGMT_EV_UNCONF_INDEX_REMOVED, hdev, + NULL, 0, HCI_MGMT_UNCONF_INDEX_EVENTS); + ev.type = 0x01; + } else { + mgmt_index_event(MGMT_EV_INDEX_REMOVED, hdev, NULL, 0, + HCI_MGMT_INDEX_EVENTS); + ev.type = 0x00; + } + break; + case HCI_AMP: + ev.type = 0x02; + break; + default: + return; + } + + ev.bus = hdev->bus; + + mgmt_index_event(MGMT_EV_EXT_INDEX_REMOVED, hdev, &ev, sizeof(ev), + HCI_MGMT_EXT_INDEX_EVENTS); + + /* Cancel any remaining timed work */ + if (!hci_dev_test_flag(hdev, HCI_MGMT)) + return; + cancel_delayed_work_sync(&hdev->discov_off); + cancel_delayed_work_sync(&hdev->service_cache); + cancel_delayed_work_sync(&hdev->rpa_expired); +} + +void mgmt_power_on(struct hci_dev *hdev, int err) +{ + struct cmd_lookup match = { NULL, hdev }; + + bt_dev_dbg(hdev, "err %d", err); + + hci_dev_lock(hdev); + + if (!err) { + restart_le_actions(hdev); + hci_update_passive_scan(hdev); + } + + mgmt_pending_foreach(MGMT_OP_SET_POWERED, hdev, settings_rsp, &match); + + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); + + hci_dev_unlock(hdev); +} + +void __mgmt_power_off(struct hci_dev *hdev) +{ + struct cmd_lookup match = { NULL, hdev }; + u8 status, zero_cod[] = { 0, 0, 0 }; + + mgmt_pending_foreach(MGMT_OP_SET_POWERED, hdev, settings_rsp, &match); + + /* If the power off is because of hdev unregistration let + * use the appropriate INVALID_INDEX status. Otherwise use + * NOT_POWERED. We cover both scenarios here since later in + * mgmt_index_removed() any hci_conn callbacks will have already + * been triggered, potentially causing misleading DISCONNECTED + * status responses. + */ + if (hci_dev_test_flag(hdev, HCI_UNREGISTER)) + status = MGMT_STATUS_INVALID_INDEX; + else + status = MGMT_STATUS_NOT_POWERED; + + mgmt_pending_foreach(0, hdev, cmd_complete_rsp, &status); + + if (memcmp(hdev->dev_class, zero_cod, sizeof(zero_cod)) != 0) { + mgmt_limited_event(MGMT_EV_CLASS_OF_DEV_CHANGED, hdev, + zero_cod, sizeof(zero_cod), + HCI_MGMT_DEV_CLASS_EVENTS, NULL); + ext_info_changed(hdev, NULL); + } + + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); +} + +void mgmt_set_powered_failed(struct hci_dev *hdev, int err) +{ + struct mgmt_pending_cmd *cmd; + u8 status; + + cmd = pending_find(MGMT_OP_SET_POWERED, hdev); + if (!cmd) + return; + + if (err == -ERFKILL) + status = MGMT_STATUS_RFKILLED; + else + status = MGMT_STATUS_FAILED; + + mgmt_cmd_status(cmd->sk, hdev->id, MGMT_OP_SET_POWERED, status); + + mgmt_pending_remove(cmd); +} + +void mgmt_new_link_key(struct hci_dev *hdev, struct link_key *key, + bool persistent) +{ + struct mgmt_ev_new_link_key ev; + + memset(&ev, 0, sizeof(ev)); + + ev.store_hint = persistent; + bacpy(&ev.key.addr.bdaddr, &key->bdaddr); + ev.key.addr.type = link_to_bdaddr(key->link_type, key->bdaddr_type); + ev.key.type = key->type; + memcpy(ev.key.val, key->val, HCI_LINK_KEY_SIZE); + ev.key.pin_len = key->pin_len; + + mgmt_event(MGMT_EV_NEW_LINK_KEY, hdev, &ev, sizeof(ev), NULL); +} + +static u8 mgmt_ltk_type(struct smp_ltk *ltk) +{ + switch (ltk->type) { + case SMP_LTK: + case SMP_LTK_RESPONDER: + if (ltk->authenticated) + return MGMT_LTK_AUTHENTICATED; + return MGMT_LTK_UNAUTHENTICATED; + case SMP_LTK_P256: + if (ltk->authenticated) + return MGMT_LTK_P256_AUTH; + return MGMT_LTK_P256_UNAUTH; + case SMP_LTK_P256_DEBUG: + return MGMT_LTK_P256_DEBUG; + } + + return MGMT_LTK_UNAUTHENTICATED; +} + +void mgmt_new_ltk(struct hci_dev *hdev, struct smp_ltk *key, bool persistent) +{ + struct mgmt_ev_new_long_term_key ev; + + 
memset(&ev, 0, sizeof(ev)); + + /* Devices using resolvable or non-resolvable random addresses + * without providing an identity resolving key don't require + * to store long term keys. Their addresses will change the + * next time around. + * + * Only when a remote device provides an identity address + * make sure the long term key is stored. If the remote + * identity is known, the long term keys are internally + * mapped to the identity address. So allow static random + * and public addresses here. + */ + if (key->bdaddr_type == ADDR_LE_DEV_RANDOM && + (key->bdaddr.b[5] & 0xc0) != 0xc0) + ev.store_hint = 0x00; + else + ev.store_hint = persistent; + + bacpy(&ev.key.addr.bdaddr, &key->bdaddr); + ev.key.addr.type = link_to_bdaddr(key->link_type, key->bdaddr_type); + ev.key.type = mgmt_ltk_type(key); + ev.key.enc_size = key->enc_size; + ev.key.ediv = key->ediv; + ev.key.rand = key->rand; + + if (key->type == SMP_LTK) + ev.key.initiator = 1; + + /* Make sure we copy only the significant bytes based on the + * encryption key size, and set the rest of the value to zeroes. + */ + memcpy(ev.key.val, key->val, key->enc_size); + memset(ev.key.val + key->enc_size, 0, + sizeof(ev.key.val) - key->enc_size); + + mgmt_event(MGMT_EV_NEW_LONG_TERM_KEY, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_new_irk(struct hci_dev *hdev, struct smp_irk *irk, bool persistent) +{ + struct mgmt_ev_new_irk ev; + + memset(&ev, 0, sizeof(ev)); + + ev.store_hint = persistent; + + bacpy(&ev.rpa, &irk->rpa); + bacpy(&ev.irk.addr.bdaddr, &irk->bdaddr); + ev.irk.addr.type = link_to_bdaddr(irk->link_type, irk->addr_type); + memcpy(ev.irk.val, irk->val, sizeof(irk->val)); + + mgmt_event(MGMT_EV_NEW_IRK, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_new_csrk(struct hci_dev *hdev, struct smp_csrk *csrk, + bool persistent) +{ + struct mgmt_ev_new_csrk ev; + + memset(&ev, 0, sizeof(ev)); + + /* Devices using resolvable or non-resolvable random addresses + * without providing an identity resolving key don't require + * to store signature resolving keys. Their addresses will change + * the next time around. + * + * Only when a remote device provides an identity address + * make sure the signature resolving key is stored. So allow + * static random and public addresses here. 
+ */ + if (csrk->bdaddr_type == ADDR_LE_DEV_RANDOM && + (csrk->bdaddr.b[5] & 0xc0) != 0xc0) + ev.store_hint = 0x00; + else + ev.store_hint = persistent; + + bacpy(&ev.key.addr.bdaddr, &csrk->bdaddr); + ev.key.addr.type = link_to_bdaddr(csrk->link_type, csrk->bdaddr_type); + ev.key.type = csrk->type; + memcpy(ev.key.val, csrk->val, sizeof(csrk->val)); + + mgmt_event(MGMT_EV_NEW_CSRK, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_new_conn_param(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 bdaddr_type, u8 store_hint, u16 min_interval, + u16 max_interval, u16 latency, u16 timeout) +{ + struct mgmt_ev_new_conn_param ev; + + if (!hci_is_identity_address(bdaddr, bdaddr_type)) + return; + + memset(&ev, 0, sizeof(ev)); + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(LE_LINK, bdaddr_type); + ev.store_hint = store_hint; + ev.min_interval = cpu_to_le16(min_interval); + ev.max_interval = cpu_to_le16(max_interval); + ev.latency = cpu_to_le16(latency); + ev.timeout = cpu_to_le16(timeout); + + mgmt_event(MGMT_EV_NEW_CONN_PARAM, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_device_connected(struct hci_dev *hdev, struct hci_conn *conn, + u8 *name, u8 name_len) +{ + struct sk_buff *skb; + struct mgmt_ev_device_connected *ev; + u16 eir_len = 0; + u32 flags = 0; + + /* allocate buff for LE or BR/EDR adv */ + if (conn->le_adv_data_len > 0) + skb = mgmt_alloc_skb(hdev, MGMT_EV_DEVICE_CONNECTED, + sizeof(*ev) + conn->le_adv_data_len); + else + skb = mgmt_alloc_skb(hdev, MGMT_EV_DEVICE_CONNECTED, + sizeof(*ev) + (name ? eir_precalc_len(name_len) : 0) + + eir_precalc_len(sizeof(conn->dev_class))); + + ev = skb_put(skb, sizeof(*ev)); + bacpy(&ev->addr.bdaddr, &conn->dst); + ev->addr.type = link_to_bdaddr(conn->type, conn->dst_type); + + if (conn->out) + flags |= MGMT_DEV_FOUND_INITIATED_CONN; + + ev->flags = __cpu_to_le32(flags); + + /* We must ensure that the EIR Data fields are ordered and + * unique. Keep it simple for now and avoid the problem by not + * adding any BR/EDR data to the LE adv. 
+ */ + if (conn->le_adv_data_len > 0) { + skb_put_data(skb, conn->le_adv_data, conn->le_adv_data_len); + eir_len = conn->le_adv_data_len; + } else { + if (name) + eir_len += eir_skb_put_data(skb, EIR_NAME_COMPLETE, name, name_len); + + if (memcmp(conn->dev_class, "\0\0\0", sizeof(conn->dev_class))) + eir_len += eir_skb_put_data(skb, EIR_CLASS_OF_DEV, + conn->dev_class, sizeof(conn->dev_class)); + } + + ev->eir_len = cpu_to_le16(eir_len); + + mgmt_event_skb(skb, NULL); +} + +static void disconnect_rsp(struct mgmt_pending_cmd *cmd, void *data) +{ + struct sock **sk = data; + + cmd->cmd_complete(cmd, 0); + + *sk = cmd->sk; + sock_hold(*sk); + + mgmt_pending_remove(cmd); +} + +static void unpair_device_rsp(struct mgmt_pending_cmd *cmd, void *data) +{ + struct hci_dev *hdev = data; + struct mgmt_cp_unpair_device *cp = cmd->param; + + device_unpaired(hdev, &cp->addr.bdaddr, cp->addr.type, cmd->sk); + + cmd->cmd_complete(cmd, 0); + mgmt_pending_remove(cmd); +} + +bool mgmt_powering_down(struct hci_dev *hdev) +{ + struct mgmt_pending_cmd *cmd; + struct mgmt_mode *cp; + + cmd = pending_find(MGMT_OP_SET_POWERED, hdev); + if (!cmd) + return false; + + cp = cmd->param; + if (!cp->val) + return true; + + return false; +} + +void mgmt_device_disconnected(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 reason, + bool mgmt_connected) +{ + struct mgmt_ev_device_disconnected ev; + struct sock *sk = NULL; + + /* The connection is still in hci_conn_hash so test for 1 + * instead of 0 to know if this is the last one. + */ + if (mgmt_powering_down(hdev) && hci_conn_count(hdev) == 1) { + cancel_delayed_work(&hdev->power_off); + queue_work(hdev->req_workqueue, &hdev->power_off.work); + } + + if (!mgmt_connected) + return; + + if (link_type != ACL_LINK && link_type != LE_LINK) + return; + + mgmt_pending_foreach(MGMT_OP_DISCONNECT, hdev, disconnect_rsp, &sk); + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(link_type, addr_type); + ev.reason = reason; + + /* Report disconnects due to suspend */ + if (hdev->suspended) + ev.reason = MGMT_DEV_DISCONN_LOCAL_HOST_SUSPEND; + + mgmt_event(MGMT_EV_DEVICE_DISCONNECTED, hdev, &ev, sizeof(ev), sk); + + if (sk) + sock_put(sk); + + mgmt_pending_foreach(MGMT_OP_UNPAIR_DEVICE, hdev, unpair_device_rsp, + hdev); +} + +void mgmt_disconnect_failed(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status) +{ + u8 bdaddr_type = link_to_bdaddr(link_type, addr_type); + struct mgmt_cp_disconnect *cp; + struct mgmt_pending_cmd *cmd; + + mgmt_pending_foreach(MGMT_OP_UNPAIR_DEVICE, hdev, unpair_device_rsp, + hdev); + + cmd = pending_find(MGMT_OP_DISCONNECT, hdev); + if (!cmd) + return; + + cp = cmd->param; + + if (bacmp(bdaddr, &cp->addr.bdaddr)) + return; + + if (cp->addr.type != bdaddr_type) + return; + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); +} + +void mgmt_connect_failed(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 link_type, + u8 addr_type, u8 status) +{ + struct mgmt_ev_connect_failed ev; + + /* The connection is still in hci_conn_hash so test for 1 + * instead of 0 to know if this is the last one. 
+ */ + if (mgmt_powering_down(hdev) && hci_conn_count(hdev) == 1) { + cancel_delayed_work(&hdev->power_off); + queue_work(hdev->req_workqueue, &hdev->power_off.work); + } + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(link_type, addr_type); + ev.status = mgmt_status(status); + + mgmt_event(MGMT_EV_CONNECT_FAILED, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_pin_code_request(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 secure) +{ + struct mgmt_ev_pin_code_request ev; + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = BDADDR_BREDR; + ev.secure = secure; + + mgmt_event(MGMT_EV_PIN_CODE_REQUEST, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_pin_code_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 status) +{ + struct mgmt_pending_cmd *cmd; + + cmd = pending_find(MGMT_OP_PIN_CODE_REPLY, hdev); + if (!cmd) + return; + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); +} + +void mgmt_pin_code_neg_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 status) +{ + struct mgmt_pending_cmd *cmd; + + cmd = pending_find(MGMT_OP_PIN_CODE_NEG_REPLY, hdev); + if (!cmd) + return; + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); +} + +int mgmt_user_confirm_request(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u32 value, + u8 confirm_hint) +{ + struct mgmt_ev_user_confirm_request ev; + + bt_dev_dbg(hdev, "bdaddr %pMR", bdaddr); + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(link_type, addr_type); + ev.confirm_hint = confirm_hint; + ev.value = cpu_to_le32(value); + + return mgmt_event(MGMT_EV_USER_CONFIRM_REQUEST, hdev, &ev, sizeof(ev), + NULL); +} + +int mgmt_user_passkey_request(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type) +{ + struct mgmt_ev_user_passkey_request ev; + + bt_dev_dbg(hdev, "bdaddr %pMR", bdaddr); + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(link_type, addr_type); + + return mgmt_event(MGMT_EV_USER_PASSKEY_REQUEST, hdev, &ev, sizeof(ev), + NULL); +} + +static int user_pairing_resp_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status, + u8 opcode) +{ + struct mgmt_pending_cmd *cmd; + + cmd = pending_find(opcode, hdev); + if (!cmd) + return -ENOENT; + + cmd->cmd_complete(cmd, mgmt_status(status)); + mgmt_pending_remove(cmd); + + return 0; +} + +int mgmt_user_confirm_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status) +{ + return user_pairing_resp_complete(hdev, bdaddr, link_type, addr_type, + status, MGMT_OP_USER_CONFIRM_REPLY); +} + +int mgmt_user_confirm_neg_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status) +{ + return user_pairing_resp_complete(hdev, bdaddr, link_type, addr_type, + status, + MGMT_OP_USER_CONFIRM_NEG_REPLY); +} + +int mgmt_user_passkey_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status) +{ + return user_pairing_resp_complete(hdev, bdaddr, link_type, addr_type, + status, MGMT_OP_USER_PASSKEY_REPLY); +} + +int mgmt_user_passkey_neg_reply_complete(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u8 status) +{ + return user_pairing_resp_complete(hdev, bdaddr, link_type, addr_type, + status, + MGMT_OP_USER_PASSKEY_NEG_REPLY); +} + +int mgmt_user_passkey_notify(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 link_type, u8 addr_type, u32 passkey, + u8 entered) +{ + struct mgmt_ev_passkey_notify ev; + + 
bt_dev_dbg(hdev, "bdaddr %pMR", bdaddr); + + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = link_to_bdaddr(link_type, addr_type); + ev.passkey = __cpu_to_le32(passkey); + ev.entered = entered; + + return mgmt_event(MGMT_EV_PASSKEY_NOTIFY, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_auth_failed(struct hci_conn *conn, u8 hci_status) +{ + struct mgmt_ev_auth_failed ev; + struct mgmt_pending_cmd *cmd; + u8 status = mgmt_status(hci_status); + + bacpy(&ev.addr.bdaddr, &conn->dst); + ev.addr.type = link_to_bdaddr(conn->type, conn->dst_type); + ev.status = status; + + cmd = find_pairing(conn); + + mgmt_event(MGMT_EV_AUTH_FAILED, conn->hdev, &ev, sizeof(ev), + cmd ? cmd->sk : NULL); + + if (cmd) { + cmd->cmd_complete(cmd, status); + mgmt_pending_remove(cmd); + } +} + +void mgmt_auth_enable_complete(struct hci_dev *hdev, u8 status) +{ + struct cmd_lookup match = { NULL, hdev }; + bool changed; + + if (status) { + u8 mgmt_err = mgmt_status(status); + mgmt_pending_foreach(MGMT_OP_SET_LINK_SECURITY, hdev, + cmd_status_rsp, &mgmt_err); + return; + } + + if (test_bit(HCI_AUTH, &hdev->flags)) + changed = !hci_dev_test_and_set_flag(hdev, HCI_LINK_SECURITY); + else + changed = hci_dev_test_and_clear_flag(hdev, HCI_LINK_SECURITY); + + mgmt_pending_foreach(MGMT_OP_SET_LINK_SECURITY, hdev, settings_rsp, + &match); + + if (changed) + new_settings(hdev, match.sk); + + if (match.sk) + sock_put(match.sk); +} + +static void sk_lookup(struct mgmt_pending_cmd *cmd, void *data) +{ + struct cmd_lookup *match = data; + + if (match->sk == NULL) { + match->sk = cmd->sk; + sock_hold(match->sk); + } +} + +void mgmt_set_class_of_dev_complete(struct hci_dev *hdev, u8 *dev_class, + u8 status) +{ + struct cmd_lookup match = { NULL, hdev, mgmt_status(status) }; + + mgmt_pending_foreach(MGMT_OP_SET_DEV_CLASS, hdev, sk_lookup, &match); + mgmt_pending_foreach(MGMT_OP_ADD_UUID, hdev, sk_lookup, &match); + mgmt_pending_foreach(MGMT_OP_REMOVE_UUID, hdev, sk_lookup, &match); + + if (!status) { + mgmt_limited_event(MGMT_EV_CLASS_OF_DEV_CHANGED, hdev, dev_class, + 3, HCI_MGMT_DEV_CLASS_EVENTS, NULL); + ext_info_changed(hdev, NULL); + } + + if (match.sk) + sock_put(match.sk); +} + +void mgmt_set_local_name_complete(struct hci_dev *hdev, u8 *name, u8 status) +{ + struct mgmt_cp_set_local_name ev; + struct mgmt_pending_cmd *cmd; + + if (status) + return; + + memset(&ev, 0, sizeof(ev)); + memcpy(ev.name, name, HCI_MAX_NAME_LENGTH); + memcpy(ev.short_name, hdev->short_name, HCI_MAX_SHORT_NAME_LENGTH); + + cmd = pending_find(MGMT_OP_SET_LOCAL_NAME, hdev); + if (!cmd) { + memcpy(hdev->dev_name, name, sizeof(hdev->dev_name)); + + /* If this is a HCI command related to powering on the + * HCI dev don't send any mgmt signals. + */ + if (pending_find(MGMT_OP_SET_POWERED, hdev)) + return; + } + + mgmt_limited_event(MGMT_EV_LOCAL_NAME_CHANGED, hdev, &ev, sizeof(ev), + HCI_MGMT_LOCAL_NAME_EVENTS, cmd ? cmd->sk : NULL); + ext_info_changed(hdev, cmd ? 
cmd->sk : NULL); +} + +static inline bool has_uuid(u8 *uuid, u16 uuid_count, u8 (*uuids)[16]) +{ + int i; + + for (i = 0; i < uuid_count; i++) { + if (!memcmp(uuid, uuids[i], 16)) + return true; + } + + return false; +} + +static bool eir_has_uuids(u8 *eir, u16 eir_len, u16 uuid_count, u8 (*uuids)[16]) +{ + u16 parsed = 0; + + while (parsed < eir_len) { + u8 field_len = eir[0]; + u8 uuid[16]; + int i; + + if (field_len == 0) + break; + + if (eir_len - parsed < field_len + 1) + break; + + switch (eir[1]) { + case EIR_UUID16_ALL: + case EIR_UUID16_SOME: + for (i = 0; i + 3 <= field_len; i += 2) { + memcpy(uuid, bluetooth_base_uuid, 16); + uuid[13] = eir[i + 3]; + uuid[12] = eir[i + 2]; + if (has_uuid(uuid, uuid_count, uuids)) + return true; + } + break; + case EIR_UUID32_ALL: + case EIR_UUID32_SOME: + for (i = 0; i + 5 <= field_len; i += 4) { + memcpy(uuid, bluetooth_base_uuid, 16); + uuid[15] = eir[i + 5]; + uuid[14] = eir[i + 4]; + uuid[13] = eir[i + 3]; + uuid[12] = eir[i + 2]; + if (has_uuid(uuid, uuid_count, uuids)) + return true; + } + break; + case EIR_UUID128_ALL: + case EIR_UUID128_SOME: + for (i = 0; i + 17 <= field_len; i += 16) { + memcpy(uuid, eir + i + 2, 16); + if (has_uuid(uuid, uuid_count, uuids)) + return true; + } + break; + } + + parsed += field_len + 1; + eir += field_len + 1; + } + + return false; +} + +static void restart_le_scan(struct hci_dev *hdev) +{ + /* If controller is not scanning we are done. */ + if (!hci_dev_test_flag(hdev, HCI_LE_SCAN)) + return; + + if (time_after(jiffies + DISCOV_LE_RESTART_DELAY, + hdev->discovery.scan_start + + hdev->discovery.scan_duration)) + return; + + queue_delayed_work(hdev->req_workqueue, &hdev->le_scan_restart, + DISCOV_LE_RESTART_DELAY); +} + +static bool is_filter_match(struct hci_dev *hdev, s8 rssi, u8 *eir, + u16 eir_len, u8 *scan_rsp, u8 scan_rsp_len) +{ + /* If a RSSI threshold has been specified, and + * HCI_QUIRK_STRICT_DUPLICATE_FILTER is not set, then all results with + * a RSSI smaller than the RSSI threshold will be dropped. If the quirk + * is set, let it through for further processing, as we might need to + * restart the scan. + * + * For BR/EDR devices (pre 1.2) providing no RSSI during inquiry, + * the results are also dropped. + */ + if (hdev->discovery.rssi != HCI_RSSI_INVALID && + (rssi == HCI_RSSI_INVALID || + (rssi < hdev->discovery.rssi && + !test_bit(HCI_QUIRK_STRICT_DUPLICATE_FILTER, &hdev->quirks)))) + return false; + + if (hdev->discovery.uuid_count != 0) { + /* If a list of UUIDs is provided in filter, results with no + * matching UUID should be dropped. + */ + if (!eir_has_uuids(eir, eir_len, hdev->discovery.uuid_count, + hdev->discovery.uuids) && + !eir_has_uuids(scan_rsp, scan_rsp_len, + hdev->discovery.uuid_count, + hdev->discovery.uuids)) + return false; + } + + /* If duplicate filtering does not report RSSI changes, then restart + * scanning to ensure updated result with updated RSSI values. + */ + if (test_bit(HCI_QUIRK_STRICT_DUPLICATE_FILTER, &hdev->quirks)) { + restart_le_scan(hdev); + + /* Validate RSSI value against the RSSI threshold once more. 
*/ + if (hdev->discovery.rssi != HCI_RSSI_INVALID && + rssi < hdev->discovery.rssi) + return false; + } + + return true; +} + +void mgmt_adv_monitor_device_lost(struct hci_dev *hdev, u16 handle, + bdaddr_t *bdaddr, u8 addr_type) +{ + struct mgmt_ev_adv_monitor_device_lost ev; + + ev.monitor_handle = cpu_to_le16(handle); + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = addr_type; + + mgmt_event(MGMT_EV_ADV_MONITOR_DEVICE_LOST, hdev, &ev, sizeof(ev), + NULL); +} + +static void mgmt_send_adv_monitor_device_found(struct hci_dev *hdev, + struct sk_buff *skb, + struct sock *skip_sk, + u16 handle) +{ + struct sk_buff *advmon_skb; + size_t advmon_skb_len; + __le16 *monitor_handle; + + if (!skb) + return; + + advmon_skb_len = (sizeof(struct mgmt_ev_adv_monitor_device_found) - + sizeof(struct mgmt_ev_device_found)) + skb->len; + advmon_skb = mgmt_alloc_skb(hdev, MGMT_EV_ADV_MONITOR_DEVICE_FOUND, + advmon_skb_len); + if (!advmon_skb) + return; + + /* ADV_MONITOR_DEVICE_FOUND is similar to DEVICE_FOUND event except + * that it also has 'monitor_handle'. Make a copy of DEVICE_FOUND and + * store monitor_handle of the matched monitor. + */ + monitor_handle = skb_put(advmon_skb, sizeof(*monitor_handle)); + *monitor_handle = cpu_to_le16(handle); + skb_put_data(advmon_skb, skb->data, skb->len); + + mgmt_event_skb(advmon_skb, skip_sk); +} + +static void mgmt_adv_monitor_device_found(struct hci_dev *hdev, + bdaddr_t *bdaddr, bool report_device, + struct sk_buff *skb, + struct sock *skip_sk) +{ + struct monitored_device *dev, *tmp; + bool matched = false; + bool notified = false; + + /* We have received the Advertisement Report because: + * 1. the kernel has initiated active discovery + * 2. if not, we have pend_le_reports > 0 in which case we are doing + * passive scanning + * 3. if none of the above is true, we have one or more active + * Advertisement Monitor + * + * For case 1 and 2, report all advertisements via MGMT_EV_DEVICE_FOUND + * and report ONLY one advertisement per device for the matched Monitor + * via MGMT_EV_ADV_MONITOR_DEVICE_FOUND event. + * + * For case 3, since we are not active scanning and all advertisements + * received are due to a matched Advertisement Monitor, report all + * advertisements ONLY via MGMT_EV_ADV_MONITOR_DEVICE_FOUND event. + */ + if (report_device && !hdev->advmon_pend_notify) { + mgmt_event_skb(skb, skip_sk); + return; + } + + hdev->advmon_pend_notify = false; + + list_for_each_entry_safe(dev, tmp, &hdev->monitored_devices, list) { + if (!bacmp(&dev->bdaddr, bdaddr)) { + matched = true; + + if (!dev->notified) { + mgmt_send_adv_monitor_device_found(hdev, skb, + skip_sk, + dev->handle); + notified = true; + dev->notified = true; + } + } + + if (!dev->notified) + hdev->advmon_pend_notify = true; + } + + if (!report_device && + ((matched && !notified) || !msft_monitor_supported(hdev))) { + /* Handle 0 indicates that we are not active scanning and this + * is a subsequent advertisement report for an already matched + * Advertisement Monitor or the controller offloading support + * is not available. 
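
mgmt_send_adv_monitor_device_found() above builds the ADV_MONITOR_DEVICE_FOUND event by writing the matched monitor handle first and then copying the DEVICE_FOUND payload behind it. A standalone sketch of that framing over plain buffers instead of skbs; the payload bytes are placeholders:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Prefix a copied DEVICE_FOUND payload with a little-endian u16 monitor
 * handle, mirroring the layout produced above (handle first, then the
 * original event bytes).  Returns bytes written, or 0 if out is too small. */
static size_t build_advmon_found(uint16_t handle,
				 const uint8_t *dev_found, size_t dev_found_len,
				 uint8_t *out, size_t out_len)
{
	if (out_len < 2 + dev_found_len)
		return 0;

	out[0] = handle & 0xff;		/* cpu_to_le16(handle) */
	out[1] = handle >> 8;
	memcpy(out + 2, dev_found, dev_found_len);

	return 2 + dev_found_len;
}

int main(void)
{
	const uint8_t dev_found[4] = { 0xde, 0xad, 0xbe, 0xef }; /* placeholder */
	uint8_t buf[32];
	size_t n = build_advmon_found(0x0001, dev_found, sizeof(dev_found),
				      buf, sizeof(buf));

	printf("event length: %zu (2-byte handle + %zu payload)\n",
	       n, sizeof(dev_found));
	return 0;
}
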
+ */ + mgmt_send_adv_monitor_device_found(hdev, skb, skip_sk, 0); + } + + if (report_device) + mgmt_event_skb(skb, skip_sk); + else + kfree_skb(skb); +} + +static void mesh_device_found(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type, s8 rssi, u32 flags, u8 *eir, + u16 eir_len, u8 *scan_rsp, u8 scan_rsp_len, + u64 instant) +{ + struct sk_buff *skb; + struct mgmt_ev_mesh_device_found *ev; + int i, j; + + if (!hdev->mesh_ad_types[0]) + goto accepted; + + /* Scan for requested AD types */ + if (eir_len > 0) { + for (i = 0; i + 1 < eir_len; i += eir[i] + 1) { + for (j = 0; j < sizeof(hdev->mesh_ad_types); j++) { + if (!hdev->mesh_ad_types[j]) + break; + + if (hdev->mesh_ad_types[j] == eir[i + 1]) + goto accepted; + } + } + } + + if (scan_rsp_len > 0) { + for (i = 0; i + 1 < scan_rsp_len; i += scan_rsp[i] + 1) { + for (j = 0; j < sizeof(hdev->mesh_ad_types); j++) { + if (!hdev->mesh_ad_types[j]) + break; + + if (hdev->mesh_ad_types[j] == scan_rsp[i + 1]) + goto accepted; + } + } + } + + return; + +accepted: + skb = mgmt_alloc_skb(hdev, MGMT_EV_MESH_DEVICE_FOUND, + sizeof(*ev) + eir_len + scan_rsp_len); + if (!skb) + return; + + ev = skb_put(skb, sizeof(*ev)); + + bacpy(&ev->addr.bdaddr, bdaddr); + ev->addr.type = link_to_bdaddr(LE_LINK, addr_type); + ev->rssi = rssi; + ev->flags = cpu_to_le32(flags); + ev->instant = cpu_to_le64(instant); + + if (eir_len > 0) + /* Copy EIR or advertising data into event */ + skb_put_data(skb, eir, eir_len); + + if (scan_rsp_len > 0) + /* Append scan response data to event */ + skb_put_data(skb, scan_rsp, scan_rsp_len); + + ev->eir_len = cpu_to_le16(eir_len + scan_rsp_len); + + mgmt_event_skb(skb, NULL); +} + +void mgmt_device_found(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 link_type, + u8 addr_type, u8 *dev_class, s8 rssi, u32 flags, + u8 *eir, u16 eir_len, u8 *scan_rsp, u8 scan_rsp_len, + u64 instant) +{ + struct sk_buff *skb; + struct mgmt_ev_device_found *ev; + bool report_device = hci_discovery_active(hdev); + + if (hci_dev_test_flag(hdev, HCI_MESH) && link_type == LE_LINK) + mesh_device_found(hdev, bdaddr, addr_type, rssi, flags, + eir, eir_len, scan_rsp, scan_rsp_len, + instant); + + /* Don't send events for a non-kernel initiated discovery. With + * LE one exception is if we have pend_le_reports > 0 in which + * case we're doing passive scanning and want these events. + */ + if (!hci_discovery_active(hdev)) { + if (link_type == ACL_LINK) + return; + if (link_type == LE_LINK && !list_empty(&hdev->pend_le_reports)) + report_device = true; + else if (!hci_is_adv_monitoring(hdev)) + return; + } + + if (hdev->discovery.result_filtering) { + /* We are using service discovery */ + if (!is_filter_match(hdev, rssi, eir, eir_len, scan_rsp, + scan_rsp_len)) + return; + } + + if (hdev->discovery.limited) { + /* Check for limited discoverable bit */ + if (dev_class) { + if (!(dev_class[1] & 0x20)) + return; + } else { + u8 *flags = eir_get_data(eir, eir_len, EIR_FLAGS, NULL); + if (!flags || !(flags[0] & LE_AD_LIMITED)) + return; + } + } + + /* Allocate skb. The 5 extra bytes are for the potential CoD field */ + skb = mgmt_alloc_skb(hdev, MGMT_EV_DEVICE_FOUND, + sizeof(*ev) + eir_len + scan_rsp_len + 5); + if (!skb) + return; + + ev = skb_put(skb, sizeof(*ev)); + + /* In case of device discovery with BR/EDR devices (pre 1.2), the + * RSSI value was reported as 0 when not available. This behavior + * is kept when using device discovery. This is required for full + * backwards compatibility with the API. 
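
mesh_device_found() walks advertising data as a chain of length/type/value structures, where each length byte counts the AD type byte plus the value. A small self-contained walk in the same style; the sample data is illustrative:

#include <stdint.h>
#include <stdio.h>

/* Walk advertising/EIR data as length-type-value structures: eir[i] is the
 * structure length (AD type byte + value), eir[i + 1] the AD type, and
 * eir[i + 2 ..] the value.  Same stepping as the loop in mesh_device_found(). */
static void walk_ad_structures(const uint8_t *eir, size_t eir_len)
{
	size_t i;

	for (i = 0; i + 1 < eir_len; i += eir[i] + 1) {
		uint8_t field_len = eir[i];

		if (field_len == 0)			/* padding, stop */
			break;
		if (i + field_len + 1 > eir_len)	/* truncated field */
			break;

		printf("AD type 0x%02x, %u value bytes\n",
		       eir[i + 1], field_len - 1);
	}
}

int main(void)
{
	/* Flags (0x01) = 0x06, Complete Local Name (0x09) = "ab" */
	const uint8_t adv[] = { 0x02, 0x01, 0x06, 0x03, 0x09, 'a', 'b' };

	walk_ad_structures(adv, sizeof(adv));
	return 0;
}
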
+ * + * However when using service discovery, the value 127 will be + * returned when the RSSI is not available. + */ + if (rssi == HCI_RSSI_INVALID && !hdev->discovery.report_invalid_rssi && + link_type == ACL_LINK) + rssi = 0; + + bacpy(&ev->addr.bdaddr, bdaddr); + ev->addr.type = link_to_bdaddr(link_type, addr_type); + ev->rssi = rssi; + ev->flags = cpu_to_le32(flags); + + if (eir_len > 0) + /* Copy EIR or advertising data into event */ + skb_put_data(skb, eir, eir_len); + + if (dev_class && !eir_get_data(eir, eir_len, EIR_CLASS_OF_DEV, NULL)) { + u8 eir_cod[5]; + + eir_len += eir_append_data(eir_cod, 0, EIR_CLASS_OF_DEV, + dev_class, 3); + skb_put_data(skb, eir_cod, sizeof(eir_cod)); + } + + if (scan_rsp_len > 0) + /* Append scan response data to event */ + skb_put_data(skb, scan_rsp, scan_rsp_len); + + ev->eir_len = cpu_to_le16(eir_len + scan_rsp_len); + + mgmt_adv_monitor_device_found(hdev, bdaddr, report_device, skb, NULL); +} + +void mgmt_remote_name(struct hci_dev *hdev, bdaddr_t *bdaddr, u8 link_type, + u8 addr_type, s8 rssi, u8 *name, u8 name_len) +{ + struct sk_buff *skb; + struct mgmt_ev_device_found *ev; + u16 eir_len = 0; + u32 flags = 0; + + skb = mgmt_alloc_skb(hdev, MGMT_EV_DEVICE_FOUND, + sizeof(*ev) + (name ? eir_precalc_len(name_len) : 0)); + + ev = skb_put(skb, sizeof(*ev)); + bacpy(&ev->addr.bdaddr, bdaddr); + ev->addr.type = link_to_bdaddr(link_type, addr_type); + ev->rssi = rssi; + + if (name) + eir_len += eir_skb_put_data(skb, EIR_NAME_COMPLETE, name, name_len); + else + flags = MGMT_DEV_FOUND_NAME_REQUEST_FAILED; + + ev->eir_len = cpu_to_le16(eir_len); + ev->flags = cpu_to_le32(flags); + + mgmt_event_skb(skb, NULL); +} + +void mgmt_discovering(struct hci_dev *hdev, u8 discovering) +{ + struct mgmt_ev_discovering ev; + + bt_dev_dbg(hdev, "discovering %u", discovering); + + memset(&ev, 0, sizeof(ev)); + ev.type = hdev->discovery.type; + ev.discovering = discovering; + + mgmt_event(MGMT_EV_DISCOVERING, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_suspending(struct hci_dev *hdev, u8 state) +{ + struct mgmt_ev_controller_suspend ev; + + ev.suspend_state = state; + mgmt_event(MGMT_EV_CONTROLLER_SUSPEND, hdev, &ev, sizeof(ev), NULL); +} + +void mgmt_resuming(struct hci_dev *hdev, u8 reason, bdaddr_t *bdaddr, + u8 addr_type) +{ + struct mgmt_ev_controller_resume ev; + + ev.wake_reason = reason; + if (bdaddr) { + bacpy(&ev.addr.bdaddr, bdaddr); + ev.addr.type = addr_type; + } else { + memset(&ev.addr, 0, sizeof(ev.addr)); + } + + mgmt_event(MGMT_EV_CONTROLLER_RESUME, hdev, &ev, sizeof(ev), NULL); +} + +static struct hci_mgmt_chan chan = { + .channel = HCI_CHANNEL_CONTROL, + .handler_count = ARRAY_SIZE(mgmt_handlers), + .handlers = mgmt_handlers, + .hdev_init = mgmt_init_hdev, +}; + +int mgmt_init(void) +{ + return hci_mgmt_chan_register(&chan); +} + +void mgmt_exit(void) +{ + hci_mgmt_chan_unregister(&chan); +} + +void mgmt_cleanup(struct sock *sk) +{ + struct mgmt_mesh_tx *mesh_tx; + struct hci_dev *hdev; + + read_lock(&hci_dev_list_lock); + + list_for_each_entry(hdev, &hci_dev_list, list) { + do { + mesh_tx = mgmt_mesh_next(hdev, sk); + + if (mesh_tx) + mesh_send_complete(hdev, mesh_tx, true); + } while (mesh_tx); + } + + read_unlock(&hci_dev_list_lock); +} diff --git a/net/bluetooth/mgmt_config.c b/net/bluetooth/mgmt_config.c new file mode 100644 index 000000000..6ef701c27 --- /dev/null +++ b/net/bluetooth/mgmt_config.c @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: GPL-2.0-only + +/* + * Copyright (C) 2020 Google Corporation + */ + +#include +#include +#include + 
+#include "mgmt_util.h" +#include "mgmt_config.h" + +#define HDEV_PARAM_U16(_param_name_) \ + struct {\ + struct mgmt_tlv entry; \ + __le16 value; \ + } __packed _param_name_ + +#define HDEV_PARAM_U8(_param_name_) \ + struct {\ + struct mgmt_tlv entry; \ + __u8 value; \ + } __packed _param_name_ + +#define TLV_SET_U16(_param_code_, _param_name_) \ + { \ + { cpu_to_le16(_param_code_), sizeof(__u16) }, \ + cpu_to_le16(hdev->_param_name_) \ + } + +#define TLV_SET_U8(_param_code_, _param_name_) \ + { \ + { cpu_to_le16(_param_code_), sizeof(__u8) }, \ + hdev->_param_name_ \ + } + +#define TLV_SET_U16_JIFFIES_TO_MSECS(_param_code_, _param_name_) \ + { \ + { cpu_to_le16(_param_code_), sizeof(__u16) }, \ + cpu_to_le16(jiffies_to_msecs(hdev->_param_name_)) \ + } + +int read_def_system_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + int ret; + struct mgmt_rp_read_def_system_config { + /* Please see mgmt-api.txt for documentation of these values */ + HDEV_PARAM_U16(def_page_scan_type); + HDEV_PARAM_U16(def_page_scan_int); + HDEV_PARAM_U16(def_page_scan_window); + HDEV_PARAM_U16(def_inq_scan_type); + HDEV_PARAM_U16(def_inq_scan_int); + HDEV_PARAM_U16(def_inq_scan_window); + HDEV_PARAM_U16(def_br_lsto); + HDEV_PARAM_U16(def_page_timeout); + HDEV_PARAM_U16(sniff_min_interval); + HDEV_PARAM_U16(sniff_max_interval); + HDEV_PARAM_U16(le_adv_min_interval); + HDEV_PARAM_U16(le_adv_max_interval); + HDEV_PARAM_U16(def_multi_adv_rotation_duration); + HDEV_PARAM_U16(le_scan_interval); + HDEV_PARAM_U16(le_scan_window); + HDEV_PARAM_U16(le_scan_int_suspend); + HDEV_PARAM_U16(le_scan_window_suspend); + HDEV_PARAM_U16(le_scan_int_discovery); + HDEV_PARAM_U16(le_scan_window_discovery); + HDEV_PARAM_U16(le_scan_int_adv_monitor); + HDEV_PARAM_U16(le_scan_window_adv_monitor); + HDEV_PARAM_U16(le_scan_int_connect); + HDEV_PARAM_U16(le_scan_window_connect); + HDEV_PARAM_U16(le_conn_min_interval); + HDEV_PARAM_U16(le_conn_max_interval); + HDEV_PARAM_U16(le_conn_latency); + HDEV_PARAM_U16(le_supv_timeout); + HDEV_PARAM_U16(def_le_autoconnect_timeout); + HDEV_PARAM_U16(advmon_allowlist_duration); + HDEV_PARAM_U16(advmon_no_filter_duration); + HDEV_PARAM_U8(enable_advmon_interleave_scan); + } __packed rp = { + TLV_SET_U16(0x0000, def_page_scan_type), + TLV_SET_U16(0x0001, def_page_scan_int), + TLV_SET_U16(0x0002, def_page_scan_window), + TLV_SET_U16(0x0003, def_inq_scan_type), + TLV_SET_U16(0x0004, def_inq_scan_int), + TLV_SET_U16(0x0005, def_inq_scan_window), + TLV_SET_U16(0x0006, def_br_lsto), + TLV_SET_U16(0x0007, def_page_timeout), + TLV_SET_U16(0x0008, sniff_min_interval), + TLV_SET_U16(0x0009, sniff_max_interval), + TLV_SET_U16(0x000a, le_adv_min_interval), + TLV_SET_U16(0x000b, le_adv_max_interval), + TLV_SET_U16(0x000c, def_multi_adv_rotation_duration), + TLV_SET_U16(0x000d, le_scan_interval), + TLV_SET_U16(0x000e, le_scan_window), + TLV_SET_U16(0x000f, le_scan_int_suspend), + TLV_SET_U16(0x0010, le_scan_window_suspend), + TLV_SET_U16(0x0011, le_scan_int_discovery), + TLV_SET_U16(0x0012, le_scan_window_discovery), + TLV_SET_U16(0x0013, le_scan_int_adv_monitor), + TLV_SET_U16(0x0014, le_scan_window_adv_monitor), + TLV_SET_U16(0x0015, le_scan_int_connect), + TLV_SET_U16(0x0016, le_scan_window_connect), + TLV_SET_U16(0x0017, le_conn_min_interval), + TLV_SET_U16(0x0018, le_conn_max_interval), + TLV_SET_U16(0x0019, le_conn_latency), + TLV_SET_U16(0x001a, le_supv_timeout), + TLV_SET_U16_JIFFIES_TO_MSECS(0x001b, + def_le_autoconnect_timeout), + TLV_SET_U16(0x001d, advmon_allowlist_duration), + 
TLV_SET_U16(0x001e, advmon_no_filter_duration), + TLV_SET_U8(0x001f, enable_advmon_interleave_scan), + }; + + bt_dev_dbg(hdev, "sock %p", sk); + + ret = mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_READ_DEF_SYSTEM_CONFIG, + 0, &rp, sizeof(rp)); + return ret; +} + +#define TO_TLV(x) ((struct mgmt_tlv *)(x)) +#define TLV_GET_LE16(tlv) le16_to_cpu(*((__le16 *)(TO_TLV(tlv)->value))) +#define TLV_GET_U8(tlv) (*((__u8 *)(TO_TLV(tlv)->value))) + +int set_def_system_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + u16 buffer_left = data_len; + u8 *buffer = data; + + if (buffer_left < sizeof(struct mgmt_tlv)) { + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_DEF_SYSTEM_CONFIG, + MGMT_STATUS_INVALID_PARAMS); + } + + /* First pass to validate the tlv */ + while (buffer_left >= sizeof(struct mgmt_tlv)) { + const u8 len = TO_TLV(buffer)->length; + size_t exp_type_len; + const u16 exp_len = sizeof(struct mgmt_tlv) + + len; + const u16 type = le16_to_cpu(TO_TLV(buffer)->type); + + if (buffer_left < exp_len) { + bt_dev_warn(hdev, "invalid len left %u, exp >= %u", + buffer_left, exp_len); + + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_DEF_SYSTEM_CONFIG, + MGMT_STATUS_INVALID_PARAMS); + } + + /* Please see mgmt-api.txt for documentation of these values */ + switch (type) { + case 0x0000: + case 0x0001: + case 0x0002: + case 0x0003: + case 0x0004: + case 0x0005: + case 0x0006: + case 0x0007: + case 0x0008: + case 0x0009: + case 0x000a: + case 0x000b: + case 0x000c: + case 0x000d: + case 0x000e: + case 0x000f: + case 0x0010: + case 0x0011: + case 0x0012: + case 0x0013: + case 0x0014: + case 0x0015: + case 0x0016: + case 0x0017: + case 0x0018: + case 0x0019: + case 0x001a: + case 0x001b: + case 0x001d: + case 0x001e: + exp_type_len = sizeof(u16); + break; + case 0x001f: + exp_type_len = sizeof(u8); + break; + default: + exp_type_len = 0; + bt_dev_warn(hdev, "unsupported parameter %u", type); + break; + } + + if (exp_type_len && len != exp_type_len) { + bt_dev_warn(hdev, "invalid length %d, exp %zu for type %u", + len, exp_type_len, type); + + return mgmt_cmd_status(sk, hdev->id, + MGMT_OP_SET_DEF_SYSTEM_CONFIG, + MGMT_STATUS_INVALID_PARAMS); + } + + buffer_left -= exp_len; + buffer += exp_len; + } + + buffer_left = data_len; + buffer = data; + while (buffer_left >= sizeof(struct mgmt_tlv)) { + const u8 len = TO_TLV(buffer)->length; + const u16 exp_len = sizeof(struct mgmt_tlv) + + len; + const u16 type = le16_to_cpu(TO_TLV(buffer)->type); + + switch (type) { + case 0x0000: + hdev->def_page_scan_type = TLV_GET_LE16(buffer); + break; + case 0x0001: + hdev->def_page_scan_int = TLV_GET_LE16(buffer); + break; + case 0x0002: + hdev->def_page_scan_window = TLV_GET_LE16(buffer); + break; + case 0x0003: + hdev->def_inq_scan_type = TLV_GET_LE16(buffer); + break; + case 0x0004: + hdev->def_inq_scan_int = TLV_GET_LE16(buffer); + break; + case 0x0005: + hdev->def_inq_scan_window = TLV_GET_LE16(buffer); + break; + case 0x0006: + hdev->def_br_lsto = TLV_GET_LE16(buffer); + break; + case 0x0007: + hdev->def_page_timeout = TLV_GET_LE16(buffer); + break; + case 0x0008: + hdev->sniff_min_interval = TLV_GET_LE16(buffer); + break; + case 0x0009: + hdev->sniff_max_interval = TLV_GET_LE16(buffer); + break; + case 0x000a: + hdev->le_adv_min_interval = TLV_GET_LE16(buffer); + break; + case 0x000b: + hdev->le_adv_max_interval = TLV_GET_LE16(buffer); + break; + case 0x000c: + hdev->def_multi_adv_rotation_duration = + TLV_GET_LE16(buffer); + break; + case 0x000d: + hdev->le_scan_interval = 
TLV_GET_LE16(buffer); + break; + case 0x000e: + hdev->le_scan_window = TLV_GET_LE16(buffer); + break; + case 0x000f: + hdev->le_scan_int_suspend = TLV_GET_LE16(buffer); + break; + case 0x0010: + hdev->le_scan_window_suspend = TLV_GET_LE16(buffer); + break; + case 0x0011: + hdev->le_scan_int_discovery = TLV_GET_LE16(buffer); + break; + case 0x00012: + hdev->le_scan_window_discovery = TLV_GET_LE16(buffer); + break; + case 0x00013: + hdev->le_scan_int_adv_monitor = TLV_GET_LE16(buffer); + break; + case 0x00014: + hdev->le_scan_window_adv_monitor = TLV_GET_LE16(buffer); + break; + case 0x00015: + hdev->le_scan_int_connect = TLV_GET_LE16(buffer); + break; + case 0x00016: + hdev->le_scan_window_connect = TLV_GET_LE16(buffer); + break; + case 0x00017: + hdev->le_conn_min_interval = TLV_GET_LE16(buffer); + break; + case 0x00018: + hdev->le_conn_max_interval = TLV_GET_LE16(buffer); + break; + case 0x00019: + hdev->le_conn_latency = TLV_GET_LE16(buffer); + break; + case 0x0001a: + hdev->le_supv_timeout = TLV_GET_LE16(buffer); + break; + case 0x0001b: + hdev->def_le_autoconnect_timeout = + msecs_to_jiffies(TLV_GET_LE16(buffer)); + break; + case 0x0001d: + hdev->advmon_allowlist_duration = TLV_GET_LE16(buffer); + break; + case 0x0001e: + hdev->advmon_no_filter_duration = TLV_GET_LE16(buffer); + break; + case 0x0001f: + hdev->enable_advmon_interleave_scan = TLV_GET_U8(buffer); + break; + default: + bt_dev_warn(hdev, "unsupported parameter %u", type); + break; + } + + buffer_left -= exp_len; + buffer += exp_len; + } + + return mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_SET_DEF_SYSTEM_CONFIG, 0, NULL, 0); +} + +int read_def_runtime_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + bt_dev_dbg(hdev, "sock %p", sk); + + return mgmt_cmd_complete(sk, hdev->id, + MGMT_OP_READ_DEF_RUNTIME_CONFIG, 0, NULL, 0); +} + +int set_def_runtime_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len) +{ + bt_dev_dbg(hdev, "sock %p", sk); + + return mgmt_cmd_status(sk, hdev->id, MGMT_OP_SET_DEF_SYSTEM_CONFIG, + MGMT_STATUS_INVALID_PARAMS); +} diff --git a/net/bluetooth/mgmt_config.h b/net/bluetooth/mgmt_config.h new file mode 100644 index 000000000..a4965f107 --- /dev/null +++ b/net/bluetooth/mgmt_config.h @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: GPL-2.0-only + +/* + * Copyright (C) 2020 Google Corporation + */ + +int read_def_system_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len); + +int set_def_system_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len); + +int read_def_runtime_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len); + +int set_def_runtime_config(struct sock *sk, struct hci_dev *hdev, void *data, + u16 data_len); diff --git a/net/bluetooth/mgmt_util.c b/net/bluetooth/mgmt_util.c new file mode 100644 index 000000000..0115f783b --- /dev/null +++ b/net/bluetooth/mgmt_util.c @@ -0,0 +1,390 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + + Copyright (C) 2015 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. 
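
set_def_system_config() parses the Type-Length-Value blob twice: a first pass validates the framing and per-type lengths, and only then a second pass applies the values, so a malformed entry never leaves the configuration half-updated. A user-space sketch of the same two-pass idea; the parameter table here is made up and much smaller than the real one:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Wire format of one entry: le16 type, u8 length, then `length` value bytes. */
struct tlv_hdr {
	uint16_t type;	/* stored little-endian on the wire */
	uint8_t len;
} __attribute__((packed));

/* Illustrative expected lengths: two u16 parameters, everything else unknown. */
static int expected_len(uint16_t type)
{
	switch (type) {
	case 0x0000:
	case 0x0001:
		return 2;
	default:
		return -1;	/* unknown: skip but do not reject */
	}
}

static int parse_tlvs(const uint8_t *buf, size_t len, int apply)
{
	size_t off = 0;

	while (len - off >= sizeof(struct tlv_hdr)) {
		uint16_t type = buf[off] | (buf[off + 1] << 8);
		uint8_t vlen = buf[off + 2];
		int exp = expected_len(type);

		if (len - off < sizeof(struct tlv_hdr) + vlen)
			return -1;			/* truncated value */
		if (exp >= 0 && vlen != exp)
			return -1;			/* bad length */

		if (apply && exp == 2) {
			uint16_t val = buf[off + 3] | (buf[off + 4] << 8);
			printf("param 0x%04x = %u\n", type, val);
		}

		off += sizeof(struct tlv_hdr) + vlen;
	}

	return 0;
}

int main(void)
{
	const uint8_t blob[] = { 0x00, 0x00, 0x02, 0x00, 0x08,   /* type 0, 2048 */
				 0x01, 0x00, 0x02, 0x34, 0x12 }; /* type 1, 0x1234 */

	if (parse_tlvs(blob, sizeof(blob), 0) == 0)	/* validate first */
		parse_tlvs(blob, sizeof(blob), 1);	/* then apply */
	return 0;
}
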
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#include +#include +#include +#include + +#include "mgmt_util.h" + +static struct sk_buff *create_monitor_ctrl_event(__le16 index, u32 cookie, + u16 opcode, u16 len, void *buf) +{ + struct hci_mon_hdr *hdr; + struct sk_buff *skb; + + skb = bt_skb_alloc(6 + len, GFP_ATOMIC); + if (!skb) + return NULL; + + put_unaligned_le32(cookie, skb_put(skb, 4)); + put_unaligned_le16(opcode, skb_put(skb, 2)); + + if (buf) + skb_put_data(skb, buf, len); + + __net_timestamp(skb); + + hdr = skb_push(skb, HCI_MON_HDR_SIZE); + hdr->opcode = cpu_to_le16(HCI_MON_CTRL_EVENT); + hdr->index = index; + hdr->len = cpu_to_le16(skb->len - HCI_MON_HDR_SIZE); + + return skb; +} + +struct sk_buff *mgmt_alloc_skb(struct hci_dev *hdev, u16 opcode, + unsigned int size) +{ + struct sk_buff *skb; + + skb = alloc_skb(sizeof(struct mgmt_hdr) + size, GFP_KERNEL); + if (!skb) + return skb; + + skb_reserve(skb, sizeof(struct mgmt_hdr)); + bt_cb(skb)->mgmt.hdev = hdev; + bt_cb(skb)->mgmt.opcode = opcode; + + return skb; +} + +int mgmt_send_event_skb(unsigned short channel, struct sk_buff *skb, int flag, + struct sock *skip_sk) +{ + struct hci_dev *hdev; + struct mgmt_hdr *hdr; + int len; + + if (!skb) + return -EINVAL; + + len = skb->len; + hdev = bt_cb(skb)->mgmt.hdev; + + /* Time stamp */ + __net_timestamp(skb); + + /* Send just the data, without headers, to the monitor */ + if (channel == HCI_CHANNEL_CONTROL) + hci_send_monitor_ctrl_event(hdev, bt_cb(skb)->mgmt.opcode, + skb->data, skb->len, + skb_get_ktime(skb), flag, skip_sk); + + hdr = skb_push(skb, sizeof(*hdr)); + hdr->opcode = cpu_to_le16(bt_cb(skb)->mgmt.opcode); + if (hdev) + hdr->index = cpu_to_le16(hdev->id); + else + hdr->index = cpu_to_le16(MGMT_INDEX_NONE); + hdr->len = cpu_to_le16(len); + + hci_send_to_channel(channel, skb, flag, skip_sk); + + kfree_skb(skb); + return 0; +} + +int mgmt_send_event(u16 event, struct hci_dev *hdev, unsigned short channel, + void *data, u16 data_len, int flag, struct sock *skip_sk) +{ + struct sk_buff *skb; + + skb = mgmt_alloc_skb(hdev, event, data_len); + if (!skb) + return -ENOMEM; + + if (data) + skb_put_data(skb, data, data_len); + + return mgmt_send_event_skb(channel, skb, flag, skip_sk); +} + +int mgmt_cmd_status(struct sock *sk, u16 index, u16 cmd, u8 status) +{ + struct sk_buff *skb, *mskb; + struct mgmt_hdr *hdr; + struct mgmt_ev_cmd_status *ev; + int err; + + BT_DBG("sock %p, index %u, cmd %u, status %u", sk, index, cmd, status); + + skb = alloc_skb(sizeof(*hdr) + sizeof(*ev), GFP_KERNEL); + if (!skb) + return -ENOMEM; + + hdr = skb_put(skb, sizeof(*hdr)); + + hdr->opcode = cpu_to_le16(MGMT_EV_CMD_STATUS); + hdr->index = cpu_to_le16(index); + hdr->len = cpu_to_le16(sizeof(*ev)); + + ev = skb_put(skb, sizeof(*ev)); + ev->status = status; + ev->opcode = cpu_to_le16(cmd); + + mskb = create_monitor_ctrl_event(hdr->index, hci_sock_get_cookie(sk), + MGMT_EV_CMD_STATUS, sizeof(*ev), ev); + if (mskb) + skb->tstamp = mskb->tstamp; + else + __net_timestamp(skb); + + err = sock_queue_rcv_skb(sk, skb); 
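
mgmt_send_event_skb() and mgmt_cmd_status() above both prepend the same 6-byte management header: little-endian opcode, controller index and payload length, with MGMT_INDEX_NONE (0xffff) when no controller applies. A sketch of that framing into a flat buffer; the opcode used here is a placeholder:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define MGMT_INDEX_NONE 0xffff

static void put_le16(uint8_t *p, uint16_t v)
{
	p[0] = v & 0xff;
	p[1] = v >> 8;
}

/* Prepend the 6-byte mgmt header (le16 opcode, le16 index, le16 len) to a
 * payload, as mgmt_send_event_skb() does with skb_push(). */
static size_t mgmt_frame(uint16_t opcode, uint16_t index,
			 const uint8_t *payload, uint16_t plen,
			 uint8_t *out, size_t out_len)
{
	if (out_len < 6u + plen)
		return 0;

	put_le16(out + 0, opcode);
	put_le16(out + 2, index);
	put_le16(out + 4, plen);
	memcpy(out + 6, payload, plen);

	return 6u + plen;
}

int main(void)
{
	const uint8_t ev[] = { 0x01 };	/* placeholder event payload */
	uint8_t buf[64];
	size_t n;

	/* An event not tied to any controller uses MGMT_INDEX_NONE. */
	n = mgmt_frame(0x0042 /* placeholder opcode */, MGMT_INDEX_NONE,
		       ev, sizeof(ev), buf, sizeof(buf));
	printf("framed %zu bytes\n", n);
	return 0;
}
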
+ if (err < 0) + kfree_skb(skb); + + if (mskb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, mskb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(mskb); + } + + return err; +} + +int mgmt_cmd_complete(struct sock *sk, u16 index, u16 cmd, u8 status, + void *rp, size_t rp_len) +{ + struct sk_buff *skb, *mskb; + struct mgmt_hdr *hdr; + struct mgmt_ev_cmd_complete *ev; + int err; + + BT_DBG("sock %p", sk); + + skb = alloc_skb(sizeof(*hdr) + sizeof(*ev) + rp_len, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + hdr = skb_put(skb, sizeof(*hdr)); + + hdr->opcode = cpu_to_le16(MGMT_EV_CMD_COMPLETE); + hdr->index = cpu_to_le16(index); + hdr->len = cpu_to_le16(sizeof(*ev) + rp_len); + + ev = skb_put(skb, sizeof(*ev) + rp_len); + ev->opcode = cpu_to_le16(cmd); + ev->status = status; + + if (rp) + memcpy(ev->data, rp, rp_len); + + mskb = create_monitor_ctrl_event(hdr->index, hci_sock_get_cookie(sk), + MGMT_EV_CMD_COMPLETE, + sizeof(*ev) + rp_len, ev); + if (mskb) + skb->tstamp = mskb->tstamp; + else + __net_timestamp(skb); + + err = sock_queue_rcv_skb(sk, skb); + if (err < 0) + kfree_skb(skb); + + if (mskb) { + hci_send_to_channel(HCI_CHANNEL_MONITOR, mskb, + HCI_SOCK_TRUSTED, NULL); + kfree_skb(mskb); + } + + return err; +} + +struct mgmt_pending_cmd *mgmt_pending_find(unsigned short channel, u16 opcode, + struct hci_dev *hdev) +{ + struct mgmt_pending_cmd *cmd; + + list_for_each_entry(cmd, &hdev->mgmt_pending, list) { + if (hci_sock_get_channel(cmd->sk) != channel) + continue; + if (cmd->opcode == opcode) + return cmd; + } + + return NULL; +} + +struct mgmt_pending_cmd *mgmt_pending_find_data(unsigned short channel, + u16 opcode, + struct hci_dev *hdev, + const void *data) +{ + struct mgmt_pending_cmd *cmd; + + list_for_each_entry(cmd, &hdev->mgmt_pending, list) { + if (cmd->user_data != data) + continue; + if (cmd->opcode == opcode) + return cmd; + } + + return NULL; +} + +void mgmt_pending_foreach(u16 opcode, struct hci_dev *hdev, + void (*cb)(struct mgmt_pending_cmd *cmd, void *data), + void *data) +{ + struct mgmt_pending_cmd *cmd, *tmp; + + list_for_each_entry_safe(cmd, tmp, &hdev->mgmt_pending, list) { + if (opcode > 0 && cmd->opcode != opcode) + continue; + + cb(cmd, data); + } +} + +struct mgmt_pending_cmd *mgmt_pending_new(struct sock *sk, u16 opcode, + struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_pending_cmd *cmd; + + cmd = kzalloc(sizeof(*cmd), GFP_KERNEL); + if (!cmd) + return NULL; + + cmd->opcode = opcode; + cmd->index = hdev->id; + + cmd->param = kmemdup(data, len, GFP_KERNEL); + if (!cmd->param) { + kfree(cmd); + return NULL; + } + + cmd->param_len = len; + + cmd->sk = sk; + sock_hold(sk); + + return cmd; +} + +struct mgmt_pending_cmd *mgmt_pending_add(struct sock *sk, u16 opcode, + struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_pending_cmd *cmd; + + cmd = mgmt_pending_new(sk, opcode, hdev, data, len); + if (!cmd) + return NULL; + + list_add_tail(&cmd->list, &hdev->mgmt_pending); + + return cmd; +} + +void mgmt_pending_free(struct mgmt_pending_cmd *cmd) +{ + sock_put(cmd->sk); + kfree(cmd->param); + kfree(cmd); +} + +void mgmt_pending_remove(struct mgmt_pending_cmd *cmd) +{ + list_del(&cmd->list); + mgmt_pending_free(cmd); +} + +void mgmt_mesh_foreach(struct hci_dev *hdev, + void (*cb)(struct mgmt_mesh_tx *mesh_tx, void *data), + void *data, struct sock *sk) +{ + struct mgmt_mesh_tx *mesh_tx, *tmp; + + list_for_each_entry_safe(mesh_tx, tmp, &hdev->mgmt_pending, list) { + if (!sk || mesh_tx->sk == sk) + cb(mesh_tx, data); + } +} + +struct mgmt_mesh_tx 
*mgmt_mesh_next(struct hci_dev *hdev, struct sock *sk) +{ + struct mgmt_mesh_tx *mesh_tx; + + if (list_empty(&hdev->mesh_pending)) + return NULL; + + list_for_each_entry(mesh_tx, &hdev->mesh_pending, list) { + if (!sk || mesh_tx->sk == sk) + return mesh_tx; + } + + return NULL; +} + +struct mgmt_mesh_tx *mgmt_mesh_find(struct hci_dev *hdev, u8 handle) +{ + struct mgmt_mesh_tx *mesh_tx; + + if (list_empty(&hdev->mesh_pending)) + return NULL; + + list_for_each_entry(mesh_tx, &hdev->mesh_pending, list) { + if (mesh_tx->handle == handle) + return mesh_tx; + } + + return NULL; +} + +struct mgmt_mesh_tx *mgmt_mesh_add(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len) +{ + struct mgmt_mesh_tx *mesh_tx; + + mesh_tx = kzalloc(sizeof(*mesh_tx), GFP_KERNEL); + if (!mesh_tx) + return NULL; + + hdev->mesh_send_ref++; + if (!hdev->mesh_send_ref) + hdev->mesh_send_ref++; + + mesh_tx->handle = hdev->mesh_send_ref; + mesh_tx->index = hdev->id; + memcpy(mesh_tx->param, data, len); + mesh_tx->param_len = len; + mesh_tx->sk = sk; + sock_hold(sk); + + list_add_tail(&mesh_tx->list, &hdev->mesh_pending); + + return mesh_tx; +} + +void mgmt_mesh_remove(struct mgmt_mesh_tx *mesh_tx) +{ + list_del(&mesh_tx->list); + sock_put(mesh_tx->sk); + kfree(mesh_tx); +} diff --git a/net/bluetooth/mgmt_util.h b/net/bluetooth/mgmt_util.h new file mode 100644 index 000000000..bdf978605 --- /dev/null +++ b/net/bluetooth/mgmt_util.h @@ -0,0 +1,79 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2015 Intel Coropration + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
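
mgmt_mesh_add() above draws the transmit handle from a wrapping 8-bit counter and bumps it a second time when it wraps, so handle 0 is never handed out. The same allocator as a standalone helper:

#include <stdint.h>
#include <stdio.h>

/* Wrapping 8-bit handle allocator that never returns 0, mirroring the
 * mesh_send_ref logic in mgmt_mesh_add() above. */
static uint8_t next_mesh_handle(uint8_t *ref)
{
	(*ref)++;
	if (!*ref)		/* wrapped to 0: 0 is not a valid handle */
		(*ref)++;
	return *ref;
}

int main(void)
{
	uint8_t ref = 0xfe;

	printf("%u\n", next_mesh_handle(&ref));	/* 255 */
	printf("%u\n", next_mesh_handle(&ref));	/* wraps, then 1 */
	return 0;
}
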
+*/ + +struct mgmt_mesh_tx { + struct list_head list; + int index; + size_t param_len; + struct sock *sk; + u8 handle; + u8 instance; + u8 param[sizeof(struct mgmt_cp_mesh_send) + 31]; +}; + +struct mgmt_pending_cmd { + struct list_head list; + u16 opcode; + int index; + void *param; + size_t param_len; + struct sock *sk; + struct sk_buff *skb; + void *user_data; + int (*cmd_complete)(struct mgmt_pending_cmd *cmd, u8 status); +}; + +struct sk_buff *mgmt_alloc_skb(struct hci_dev *hdev, u16 opcode, + unsigned int size); +int mgmt_send_event_skb(unsigned short channel, struct sk_buff *skb, int flag, + struct sock *skip_sk); +int mgmt_send_event(u16 event, struct hci_dev *hdev, unsigned short channel, + void *data, u16 data_len, int flag, struct sock *skip_sk); +int mgmt_cmd_status(struct sock *sk, u16 index, u16 cmd, u8 status); +int mgmt_cmd_complete(struct sock *sk, u16 index, u16 cmd, u8 status, + void *rp, size_t rp_len); + +struct mgmt_pending_cmd *mgmt_pending_find(unsigned short channel, u16 opcode, + struct hci_dev *hdev); +struct mgmt_pending_cmd *mgmt_pending_find_data(unsigned short channel, + u16 opcode, + struct hci_dev *hdev, + const void *data); +void mgmt_pending_foreach(u16 opcode, struct hci_dev *hdev, + void (*cb)(struct mgmt_pending_cmd *cmd, void *data), + void *data); +struct mgmt_pending_cmd *mgmt_pending_add(struct sock *sk, u16 opcode, + struct hci_dev *hdev, + void *data, u16 len); +struct mgmt_pending_cmd *mgmt_pending_new(struct sock *sk, u16 opcode, + struct hci_dev *hdev, + void *data, u16 len); +void mgmt_pending_free(struct mgmt_pending_cmd *cmd); +void mgmt_pending_remove(struct mgmt_pending_cmd *cmd); +void mgmt_mesh_foreach(struct hci_dev *hdev, + void (*cb)(struct mgmt_mesh_tx *mesh_tx, void *data), + void *data, struct sock *sk); +struct mgmt_mesh_tx *mgmt_mesh_find(struct hci_dev *hdev, u8 handle); +struct mgmt_mesh_tx *mgmt_mesh_next(struct hci_dev *hdev, struct sock *sk); +struct mgmt_mesh_tx *mgmt_mesh_add(struct sock *sk, struct hci_dev *hdev, + void *data, u16 len); +void mgmt_mesh_remove(struct mgmt_mesh_tx *mesh_tx); diff --git a/net/bluetooth/msft.c b/net/bluetooth/msft.c new file mode 100644 index 000000000..bee6a4c65 --- /dev/null +++ b/net/bluetooth/msft.c @@ -0,0 +1,837 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2020 Google Corporation + */ + +#include +#include +#include + +#include "hci_request.h" +#include "mgmt_util.h" +#include "msft.h" + +#define MSFT_RSSI_THRESHOLD_VALUE_MIN -127 +#define MSFT_RSSI_THRESHOLD_VALUE_MAX 20 +#define MSFT_RSSI_LOW_TIMEOUT_MAX 0x3C + +#define MSFT_OP_READ_SUPPORTED_FEATURES 0x00 +struct msft_cp_read_supported_features { + __u8 sub_opcode; +} __packed; + +struct msft_rp_read_supported_features { + __u8 status; + __u8 sub_opcode; + __le64 features; + __u8 evt_prefix_len; + __u8 evt_prefix[]; +} __packed; + +#define MSFT_OP_LE_MONITOR_ADVERTISEMENT 0x03 +#define MSFT_MONITOR_ADVERTISEMENT_TYPE_PATTERN 0x01 +struct msft_le_monitor_advertisement_pattern { + __u8 length; + __u8 data_type; + __u8 start_byte; + __u8 pattern[]; +}; + +struct msft_le_monitor_advertisement_pattern_data { + __u8 count; + __u8 data[]; +}; + +struct msft_cp_le_monitor_advertisement { + __u8 sub_opcode; + __s8 rssi_high; + __s8 rssi_low; + __u8 rssi_low_interval; + __u8 rssi_sampling_period; + __u8 cond_type; + __u8 data[]; +} __packed; + +struct msft_rp_le_monitor_advertisement { + __u8 status; + __u8 sub_opcode; + __u8 handle; +} __packed; + +#define MSFT_OP_LE_CANCEL_MONITOR_ADVERTISEMENT 0x04 +struct 
msft_cp_le_cancel_monitor_advertisement { + __u8 sub_opcode; + __u8 handle; +} __packed; + +struct msft_rp_le_cancel_monitor_advertisement { + __u8 status; + __u8 sub_opcode; +} __packed; + +#define MSFT_OP_LE_SET_ADVERTISEMENT_FILTER_ENABLE 0x05 +struct msft_cp_le_set_advertisement_filter_enable { + __u8 sub_opcode; + __u8 enable; +} __packed; + +struct msft_rp_le_set_advertisement_filter_enable { + __u8 status; + __u8 sub_opcode; +} __packed; + +#define MSFT_EV_LE_MONITOR_DEVICE 0x02 +struct msft_ev_le_monitor_device { + __u8 addr_type; + bdaddr_t bdaddr; + __u8 monitor_handle; + __u8 monitor_state; +} __packed; + +struct msft_monitor_advertisement_handle_data { + __u8 msft_handle; + __u16 mgmt_handle; + struct list_head list; +}; + +struct msft_data { + __u64 features; + __u8 evt_prefix_len; + __u8 *evt_prefix; + struct list_head handle_map; + __u8 resuming; + __u8 suspending; + __u8 filter_enabled; +}; + +bool msft_monitor_supported(struct hci_dev *hdev) +{ + return !!(msft_get_features(hdev) & MSFT_FEATURE_MASK_LE_ADV_MONITOR); +} + +static bool read_supported_features(struct hci_dev *hdev, + struct msft_data *msft) +{ + struct msft_cp_read_supported_features cp; + struct msft_rp_read_supported_features *rp; + struct sk_buff *skb; + + cp.sub_opcode = MSFT_OP_READ_SUPPORTED_FEATURES; + + skb = __hci_cmd_sync(hdev, hdev->msft_opcode, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + skb = ERR_PTR(-EIO); + + bt_dev_err(hdev, "Failed to read MSFT supported features (%ld)", + PTR_ERR(skb)); + return false; + } + + if (skb->len < sizeof(*rp)) { + bt_dev_err(hdev, "MSFT supported features length mismatch"); + goto failed; + } + + rp = (struct msft_rp_read_supported_features *)skb->data; + + if (rp->sub_opcode != MSFT_OP_READ_SUPPORTED_FEATURES) + goto failed; + + if (rp->evt_prefix_len > 0) { + msft->evt_prefix = kmemdup(rp->evt_prefix, rp->evt_prefix_len, + GFP_KERNEL); + if (!msft->evt_prefix) + goto failed; + } + + msft->evt_prefix_len = rp->evt_prefix_len; + msft->features = __le64_to_cpu(rp->features); + + if (msft->features & MSFT_FEATURE_MASK_CURVE_VALIDITY) + hdev->msft_curve_validity = true; + + kfree_skb(skb); + return true; + +failed: + kfree_skb(skb); + return false; +} + +/* is_mgmt = true matches the handle exposed to userspace via mgmt. + * is_mgmt = false matches the handle used by the msft controller. + * This function requires the caller holds hdev->lock + */ +static struct msft_monitor_advertisement_handle_data *msft_find_handle_data + (struct hci_dev *hdev, u16 handle, bool is_mgmt) +{ + struct msft_monitor_advertisement_handle_data *entry; + struct msft_data *msft = hdev->msft_data; + + list_for_each_entry(entry, &msft->handle_map, list) { + if (is_mgmt && entry->mgmt_handle == handle) + return entry; + if (!is_mgmt && entry->msft_handle == handle) + return entry; + } + + return NULL; +} + +/* This function requires the caller holds hdev->lock */ +static int msft_monitor_device_del(struct hci_dev *hdev, __u16 mgmt_handle, + bdaddr_t *bdaddr, __u8 addr_type, + bool notify) +{ + struct monitored_device *dev, *tmp; + int count = 0; + + list_for_each_entry_safe(dev, tmp, &hdev->monitored_devices, list) { + /* mgmt_handle == 0 indicates remove all devices, whereas, + * bdaddr == NULL indicates remove all devices matching the + * mgmt_handle. 
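
read_supported_features() keeps the __le64 features word from the vendor reply, and msft_monitor_supported() then only tests MSFT_FEATURE_MASK_LE_ADV_MONITOR (bit 3, per the masks in msft.h further down). A user-space sketch of decoding that word from the reply bytes; the reply fragment is hypothetical:

#include <stdint.h>
#include <stdio.h>

#define MSFT_FEATURE_MASK_LE_ADV_MONITOR (1ULL << 3)

/* Decode a little-endian 64-bit features word, as __le64_to_cpu() does for
 * the msft_rp_read_supported_features reply. */
static uint64_t get_le64(const uint8_t *p)
{
	uint64_t v = 0;
	int i;

	for (i = 7; i >= 0; i--)
		v = (v << 8) | p[i];
	return v;
}

int main(void)
{
	/* Hypothetical reply fragment: only the features field, LSB first,
	 * with bit 3 (LE advertisement monitor offload) set. */
	const uint8_t features_le[8] = { 0x08, 0, 0, 0, 0, 0, 0, 0 };
	uint64_t features = get_le64(features_le);

	printf("LE adv monitor offload: %s\n",
	       (features & MSFT_FEATURE_MASK_LE_ADV_MONITOR) ? "yes" : "no");
	return 0;
}
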
+ */ + if ((!mgmt_handle || dev->handle == mgmt_handle) && + (!bdaddr || (!bacmp(bdaddr, &dev->bdaddr) && + addr_type == dev->addr_type))) { + if (notify && dev->notified) { + mgmt_adv_monitor_device_lost(hdev, dev->handle, + &dev->bdaddr, + dev->addr_type); + } + + list_del(&dev->list); + kfree(dev); + count++; + } + } + + return count; +} + +static int msft_le_monitor_advertisement_cb(struct hci_dev *hdev, u16 opcode, + struct adv_monitor *monitor, + struct sk_buff *skb) +{ + struct msft_rp_le_monitor_advertisement *rp; + struct msft_monitor_advertisement_handle_data *handle_data; + struct msft_data *msft = hdev->msft_data; + int status = 0; + + hci_dev_lock(hdev); + + rp = (struct msft_rp_le_monitor_advertisement *)skb->data; + if (skb->len < sizeof(*rp)) { + status = HCI_ERROR_UNSPECIFIED; + goto unlock; + } + + status = rp->status; + if (status) + goto unlock; + + handle_data = kmalloc(sizeof(*handle_data), GFP_KERNEL); + if (!handle_data) { + status = HCI_ERROR_UNSPECIFIED; + goto unlock; + } + + handle_data->mgmt_handle = monitor->handle; + handle_data->msft_handle = rp->handle; + INIT_LIST_HEAD(&handle_data->list); + list_add(&handle_data->list, &msft->handle_map); + + monitor->state = ADV_MONITOR_STATE_OFFLOADED; + +unlock: + if (status) + hci_free_adv_monitor(hdev, monitor); + + hci_dev_unlock(hdev); + + return status; +} + +static int msft_le_cancel_monitor_advertisement_cb(struct hci_dev *hdev, + u16 opcode, + struct adv_monitor *monitor, + struct sk_buff *skb) +{ + struct msft_rp_le_cancel_monitor_advertisement *rp; + struct msft_monitor_advertisement_handle_data *handle_data; + struct msft_data *msft = hdev->msft_data; + int status = 0; + + rp = (struct msft_rp_le_cancel_monitor_advertisement *)skb->data; + if (skb->len < sizeof(*rp)) { + status = HCI_ERROR_UNSPECIFIED; + goto done; + } + + status = rp->status; + if (status) + goto done; + + hci_dev_lock(hdev); + + handle_data = msft_find_handle_data(hdev, monitor->handle, true); + + if (handle_data) { + if (monitor->state == ADV_MONITOR_STATE_OFFLOADED) + monitor->state = ADV_MONITOR_STATE_REGISTERED; + + /* Do not free the monitor if it is being removed due to + * suspend. It will be re-monitored on resume. 
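
msft_monitor_device_del() treats mgmt_handle == 0 as "any monitor" and a NULL bdaddr as "any device of the selected monitor", as its comment spells out. A toy version of that predicate, with a stand-in for bdaddr_t:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Stand-in for bdaddr_t: 6 address bytes. */
struct toy_bdaddr {
	uint8_t b[6];
};

/* Same wildcard rules as msft_monitor_device_del(): handle 0 matches every
 * monitor, a NULL address matches every device of the selected monitor. */
static bool monitored_dev_matches(uint16_t want_handle,
				  const struct toy_bdaddr *want_addr,
				  uint8_t want_addr_type,
				  uint16_t dev_handle,
				  const struct toy_bdaddr *dev_addr,
				  uint8_t dev_addr_type)
{
	if (want_handle && dev_handle != want_handle)
		return false;

	if (want_addr &&
	    (memcmp(want_addr->b, dev_addr->b, 6) != 0 ||
	     want_addr_type != dev_addr_type))
		return false;

	return true;
}

int main(void)
{
	struct toy_bdaddr a = { { 1, 2, 3, 4, 5, 6 } };

	/* handle 0 + NULL address: matches everything (the "remove all" case) */
	printf("%d\n", monitored_dev_matches(0, NULL, 0, 7, &a, 1));	/* 1 */
	/* specific handle, NULL address: any device of monitor 7 */
	printf("%d\n", monitored_dev_matches(7, NULL, 0, 7, &a, 1));	/* 1 */
	printf("%d\n", monitored_dev_matches(8, NULL, 0, 7, &a, 1));	/* 0 */
	return 0;
}
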
+ */ + if (!msft->suspending) { + hci_free_adv_monitor(hdev, monitor); + + /* Clear any monitored devices by this Adv Monitor */ + msft_monitor_device_del(hdev, handle_data->mgmt_handle, + NULL, 0, false); + } + + list_del(&handle_data->list); + kfree(handle_data); + } + + hci_dev_unlock(hdev); + +done: + return status; +} + +/* This function requires the caller holds hci_req_sync_lock */ +static int msft_remove_monitor_sync(struct hci_dev *hdev, + struct adv_monitor *monitor) +{ + struct msft_cp_le_cancel_monitor_advertisement cp; + struct msft_monitor_advertisement_handle_data *handle_data; + struct sk_buff *skb; + + handle_data = msft_find_handle_data(hdev, monitor->handle, true); + + /* If no matched handle, just remove without telling controller */ + if (!handle_data) + return -ENOENT; + + cp.sub_opcode = MSFT_OP_LE_CANCEL_MONITOR_ADVERTISEMENT; + cp.handle = handle_data->msft_handle; + + skb = __hci_cmd_sync(hdev, hdev->msft_opcode, sizeof(cp), &cp, + HCI_CMD_TIMEOUT); + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + return -EIO; + return PTR_ERR(skb); + } + + return msft_le_cancel_monitor_advertisement_cb(hdev, hdev->msft_opcode, + monitor, skb); +} + +/* This function requires the caller holds hci_req_sync_lock */ +int msft_suspend_sync(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + struct adv_monitor *monitor; + int handle = 0; + + if (!msft || !msft_monitor_supported(hdev)) + return 0; + + msft->suspending = true; + + while (1) { + monitor = idr_get_next(&hdev->adv_monitors_idr, &handle); + if (!monitor) + break; + + msft_remove_monitor_sync(hdev, monitor); + + handle++; + } + + /* All monitors have been removed */ + msft->suspending = false; + + return 0; +} + +static bool msft_monitor_rssi_valid(struct adv_monitor *monitor) +{ + struct adv_rssi_thresholds *r = &monitor->rssi; + + if (r->high_threshold < MSFT_RSSI_THRESHOLD_VALUE_MIN || + r->high_threshold > MSFT_RSSI_THRESHOLD_VALUE_MAX || + r->low_threshold < MSFT_RSSI_THRESHOLD_VALUE_MIN || + r->low_threshold > MSFT_RSSI_THRESHOLD_VALUE_MAX) + return false; + + /* High_threshold_timeout is not supported, + * once high_threshold is reached, events are immediately reported. 
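
msft_monitor_rssi_valid() enforces the vendor limits declared at the top of msft.c: both thresholds within -127..20 dBm, no high-threshold timeout, a low-threshold timeout of at most 0x3C, and any sampling period. The same check as a standalone function:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define MSFT_RSSI_THRESHOLD_VALUE_MIN	(-127)
#define MSFT_RSSI_THRESHOLD_VALUE_MAX	20
#define MSFT_RSSI_LOW_TIMEOUT_MAX	0x3C

/* Same constraints as msft_monitor_rssi_valid(). */
static bool rssi_params_valid(int8_t high, int8_t low,
			      uint16_t high_timeout, uint16_t low_timeout)
{
	if (high < MSFT_RSSI_THRESHOLD_VALUE_MIN ||
	    high > MSFT_RSSI_THRESHOLD_VALUE_MAX ||
	    low < MSFT_RSSI_THRESHOLD_VALUE_MIN ||
	    low > MSFT_RSSI_THRESHOLD_VALUE_MAX)
		return false;

	/* High-threshold timeout is not supported by the extension. */
	if (high_timeout != 0)
		return false;

	if (low_timeout > MSFT_RSSI_LOW_TIMEOUT_MAX)
		return false;

	return true;
}

int main(void)
{
	printf("%d\n", rssi_params_valid(-30, -70, 0, 30));	/* 1 */
	printf("%d\n", rssi_params_valid(-30, -70, 5, 30));	/* 0: high timeout */
	printf("%d\n", rssi_params_valid(30, -70, 0, 30));	/* 0: above 20 dBm */
	return 0;
}
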
+ */ + if (r->high_threshold_timeout != 0) + return false; + + if (r->low_threshold_timeout > MSFT_RSSI_LOW_TIMEOUT_MAX) + return false; + + /* Sampling period from 0x00 to 0xFF are all allowed */ + return true; +} + +static bool msft_monitor_pattern_valid(struct adv_monitor *monitor) +{ + return msft_monitor_rssi_valid(monitor); + /* No additional check needed for pattern-based monitor */ +} + +static int msft_add_monitor_sync(struct hci_dev *hdev, + struct adv_monitor *monitor) +{ + struct msft_cp_le_monitor_advertisement *cp; + struct msft_le_monitor_advertisement_pattern_data *pattern_data; + struct msft_le_monitor_advertisement_pattern *pattern; + struct adv_pattern *entry; + size_t total_size = sizeof(*cp) + sizeof(*pattern_data); + ptrdiff_t offset = 0; + u8 pattern_count = 0; + struct sk_buff *skb; + + if (!msft_monitor_pattern_valid(monitor)) + return -EINVAL; + + list_for_each_entry(entry, &monitor->patterns, list) { + pattern_count++; + total_size += sizeof(*pattern) + entry->length; + } + + cp = kmalloc(total_size, GFP_KERNEL); + if (!cp) + return -ENOMEM; + + cp->sub_opcode = MSFT_OP_LE_MONITOR_ADVERTISEMENT; + cp->rssi_high = monitor->rssi.high_threshold; + cp->rssi_low = monitor->rssi.low_threshold; + cp->rssi_low_interval = (u8)monitor->rssi.low_threshold_timeout; + cp->rssi_sampling_period = monitor->rssi.sampling_period; + + cp->cond_type = MSFT_MONITOR_ADVERTISEMENT_TYPE_PATTERN; + + pattern_data = (void *)cp->data; + pattern_data->count = pattern_count; + + list_for_each_entry(entry, &monitor->patterns, list) { + pattern = (void *)(pattern_data->data + offset); + /* the length also includes data_type and offset */ + pattern->length = entry->length + 2; + pattern->data_type = entry->ad_type; + pattern->start_byte = entry->offset; + memcpy(pattern->pattern, entry->value, entry->length); + offset += sizeof(*pattern) + entry->length; + } + + skb = __hci_cmd_sync(hdev, hdev->msft_opcode, total_size, cp, + HCI_CMD_TIMEOUT); + kfree(cp); + + if (IS_ERR_OR_NULL(skb)) { + if (!skb) + return -EIO; + return PTR_ERR(skb); + } + + return msft_le_monitor_advertisement_cb(hdev, hdev->msft_opcode, + monitor, skb); +} + +/* This function requires the caller holds hci_req_sync_lock */ +static void reregister_monitor(struct hci_dev *hdev) +{ + struct adv_monitor *monitor; + struct msft_data *msft = hdev->msft_data; + int handle = 0; + + if (!msft) + return; + + msft->resuming = true; + + while (1) { + monitor = idr_get_next(&hdev->adv_monitors_idr, &handle); + if (!monitor) + break; + + msft_add_monitor_sync(hdev, monitor); + + handle++; + } + + /* All monitors have been reregistered */ + msft->resuming = false; +} + +/* This function requires the caller holds hci_req_sync_lock */ +int msft_resume_sync(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + + if (!msft || !msft_monitor_supported(hdev)) + return 0; + + hci_dev_lock(hdev); + + /* Clear already tracked devices on resume. Once the monitors are + * reregistered, devices in range will be found again after resume. 
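
msft_add_monitor_sync() above sizes the command as header plus a pattern-count byte plus one (length, type, start byte, value) record per pattern, where each record's length field also counts the type and start-byte octets. A sketch that packs two hypothetical patterns the same way:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* One pattern, described the way adv_pattern entries are consumed above. */
struct toy_pattern {
	uint8_t ad_type;
	uint8_t offset;		/* start byte within the AD value */
	uint8_t length;		/* number of value bytes */
	const uint8_t *value;
};

static size_t pack_patterns(const struct toy_pattern *p, int count,
			    uint8_t *out, size_t out_len)
{
	size_t off = 0;
	int i;

	if (out_len < 1)
		return 0;
	out[off++] = (uint8_t)count;	/* pattern_data->count */

	for (i = 0; i < count; i++) {
		if (out_len - off < 3u + p[i].length)
			return 0;
		out[off++] = p[i].length + 2;	/* includes type + start byte */
		out[off++] = p[i].ad_type;
		out[off++] = p[i].offset;
		memcpy(out + off, p[i].value, p[i].length);
		off += p[i].length;
	}

	return off;
}

int main(void)
{
	const uint8_t svc[2] = { 0x12, 0x18 };	/* hypothetical value bytes */
	const uint8_t name[3] = { 'a', 'b', 'c' };
	const struct toy_pattern pats[2] = {
		{ 0x16, 0, sizeof(svc), svc },	 /* Service Data AD type */
		{ 0x09, 0, sizeof(name), name }, /* Complete Local Name */
	};
	uint8_t buf[64];

	printf("packed %zu bytes\n", pack_patterns(pats, 2, buf, sizeof(buf)));
	return 0;
}
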
+ */ + hdev->advmon_pend_notify = false; + msft_monitor_device_del(hdev, 0, NULL, 0, true); + + hci_dev_unlock(hdev); + + reregister_monitor(hdev); + + return 0; +} + +/* This function requires the caller holds hci_req_sync_lock */ +void msft_do_open(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + + if (hdev->msft_opcode == HCI_OP_NOP) + return; + + if (!msft) { + bt_dev_err(hdev, "MSFT extension not registered"); + return; + } + + bt_dev_dbg(hdev, "Initialize MSFT extension"); + + /* Reset existing MSFT data before re-reading */ + kfree(msft->evt_prefix); + msft->evt_prefix = NULL; + msft->evt_prefix_len = 0; + msft->features = 0; + + if (!read_supported_features(hdev, msft)) { + hdev->msft_data = NULL; + kfree(msft); + return; + } + + if (msft_monitor_supported(hdev)) { + msft->resuming = true; + msft_set_filter_enable(hdev, true); + /* Monitors get removed on power off, so we need to explicitly + * tell the controller to re-monitor. + */ + reregister_monitor(hdev); + } +} + +void msft_do_close(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + struct msft_monitor_advertisement_handle_data *handle_data, *tmp; + struct adv_monitor *monitor; + + if (!msft) + return; + + bt_dev_dbg(hdev, "Cleanup of MSFT extension"); + + /* The controller will silently remove all monitors on power off. + * Therefore, remove handle_data mapping and reset monitor state. + */ + list_for_each_entry_safe(handle_data, tmp, &msft->handle_map, list) { + monitor = idr_find(&hdev->adv_monitors_idr, + handle_data->mgmt_handle); + + if (monitor && monitor->state == ADV_MONITOR_STATE_OFFLOADED) + monitor->state = ADV_MONITOR_STATE_REGISTERED; + + list_del(&handle_data->list); + kfree(handle_data); + } + + hci_dev_lock(hdev); + + /* Clear any devices that are being monitored and notify device lost */ + hdev->advmon_pend_notify = false; + msft_monitor_device_del(hdev, 0, NULL, 0, true); + + hci_dev_unlock(hdev); +} + +void msft_register(struct hci_dev *hdev) +{ + struct msft_data *msft = NULL; + + bt_dev_dbg(hdev, "Register MSFT extension"); + + msft = kzalloc(sizeof(*msft), GFP_KERNEL); + if (!msft) { + bt_dev_err(hdev, "Failed to register MSFT extension"); + return; + } + + INIT_LIST_HEAD(&msft->handle_map); + hdev->msft_data = msft; +} + +void msft_unregister(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + + if (!msft) + return; + + bt_dev_dbg(hdev, "Unregister MSFT extension"); + + hdev->msft_data = NULL; + + kfree(msft->evt_prefix); + kfree(msft); +} + +/* This function requires the caller holds hdev->lock */ +static void msft_device_found(struct hci_dev *hdev, bdaddr_t *bdaddr, + __u8 addr_type, __u16 mgmt_handle) +{ + struct monitored_device *dev; + + dev = kmalloc(sizeof(*dev), GFP_KERNEL); + if (!dev) { + bt_dev_err(hdev, "MSFT vendor event %u: no memory", + MSFT_EV_LE_MONITOR_DEVICE); + return; + } + + bacpy(&dev->bdaddr, bdaddr); + dev->addr_type = addr_type; + dev->handle = mgmt_handle; + dev->notified = false; + + INIT_LIST_HEAD(&dev->list); + list_add(&dev->list, &hdev->monitored_devices); + hdev->advmon_pend_notify = true; +} + +/* This function requires the caller holds hdev->lock */ +static void msft_device_lost(struct hci_dev *hdev, bdaddr_t *bdaddr, + __u8 addr_type, __u16 mgmt_handle) +{ + if (!msft_monitor_device_del(hdev, mgmt_handle, bdaddr, addr_type, + true)) { + bt_dev_err(hdev, "MSFT vendor event %u: dev %pMR not in list", + MSFT_EV_LE_MONITOR_DEVICE, bdaddr); + } +} + +static void *msft_skb_pull(struct hci_dev *hdev, struct 
sk_buff *skb, + u8 ev, size_t len) +{ + void *data; + + data = skb_pull_data(skb, len); + if (!data) + bt_dev_err(hdev, "Malformed MSFT vendor event: 0x%02x", ev); + + return data; +} + +/* This function requires the caller holds hdev->lock */ +static void msft_monitor_device_evt(struct hci_dev *hdev, struct sk_buff *skb) +{ + struct msft_ev_le_monitor_device *ev; + struct msft_monitor_advertisement_handle_data *handle_data; + u8 addr_type; + + ev = msft_skb_pull(hdev, skb, MSFT_EV_LE_MONITOR_DEVICE, sizeof(*ev)); + if (!ev) + return; + + bt_dev_dbg(hdev, + "MSFT vendor event 0x%02x: handle 0x%04x state %d addr %pMR", + MSFT_EV_LE_MONITOR_DEVICE, ev->monitor_handle, + ev->monitor_state, &ev->bdaddr); + + handle_data = msft_find_handle_data(hdev, ev->monitor_handle, false); + if (!handle_data) + return; + + switch (ev->addr_type) { + case ADDR_LE_DEV_PUBLIC: + addr_type = BDADDR_LE_PUBLIC; + break; + + case ADDR_LE_DEV_RANDOM: + addr_type = BDADDR_LE_RANDOM; + break; + + default: + bt_dev_err(hdev, + "MSFT vendor event 0x%02x: unknown addr type 0x%02x", + MSFT_EV_LE_MONITOR_DEVICE, ev->addr_type); + return; + } + + if (ev->monitor_state) + msft_device_found(hdev, &ev->bdaddr, addr_type, + handle_data->mgmt_handle); + else + msft_device_lost(hdev, &ev->bdaddr, addr_type, + handle_data->mgmt_handle); +} + +void msft_vendor_evt(struct hci_dev *hdev, void *data, struct sk_buff *skb) +{ + struct msft_data *msft = hdev->msft_data; + u8 *evt_prefix; + u8 *evt; + + if (!msft) + return; + + /* When the extension has defined an event prefix, check that it + * matches, and otherwise just return. + */ + if (msft->evt_prefix_len > 0) { + evt_prefix = msft_skb_pull(hdev, skb, 0, msft->evt_prefix_len); + if (!evt_prefix) + return; + + if (memcmp(evt_prefix, msft->evt_prefix, msft->evt_prefix_len)) + return; + } + + /* Every event starts at least with an event code and the rest of + * the data is variable and depends on the event code. + */ + if (skb->len < 1) + return; + + evt = msft_skb_pull(hdev, skb, 0, sizeof(*evt)); + if (!evt) + return; + + hci_dev_lock(hdev); + + switch (*evt) { + case MSFT_EV_LE_MONITOR_DEVICE: + msft_monitor_device_evt(hdev, skb); + break; + + default: + bt_dev_dbg(hdev, "MSFT vendor event 0x%02x", *evt); + break; + } + + hci_dev_unlock(hdev); +} + +__u64 msft_get_features(struct hci_dev *hdev) +{ + struct msft_data *msft = hdev->msft_data; + + return msft ? msft->features : 0; +} + +static void msft_le_set_advertisement_filter_enable_cb(struct hci_dev *hdev, + u8 status, u16 opcode, + struct sk_buff *skb) +{ + struct msft_cp_le_set_advertisement_filter_enable *cp; + struct msft_rp_le_set_advertisement_filter_enable *rp; + struct msft_data *msft = hdev->msft_data; + + rp = (struct msft_rp_le_set_advertisement_filter_enable *)skb->data; + if (skb->len < sizeof(*rp)) + return; + + /* Error 0x0C would be returned if the filter enabled status is + * already set to whatever we were trying to set. + * Although the default state should be disabled, some controller set + * the initial value to enabled. Because there is no way to know the + * actual initial value before sending this command, here we also treat + * error 0x0C as success. + */ + if (status != 0x00 && status != 0x0C) + return; + + hci_dev_lock(hdev); + + cp = hci_sent_cmd_data(hdev, hdev->msft_opcode); + msft->filter_enabled = cp->enable; + + if (status == 0x0C) + bt_dev_warn(hdev, "MSFT filter_enable is already %s", + cp->enable ? 
"on" : "off"); + + hci_dev_unlock(hdev); +} + +/* This function requires the caller holds hci_req_sync_lock */ +int msft_add_monitor_pattern(struct hci_dev *hdev, struct adv_monitor *monitor) +{ + struct msft_data *msft = hdev->msft_data; + + if (!msft) + return -EOPNOTSUPP; + + if (msft->resuming || msft->suspending) + return -EBUSY; + + return msft_add_monitor_sync(hdev, monitor); +} + +/* This function requires the caller holds hci_req_sync_lock */ +int msft_remove_monitor(struct hci_dev *hdev, struct adv_monitor *monitor) +{ + struct msft_data *msft = hdev->msft_data; + + if (!msft) + return -EOPNOTSUPP; + + if (msft->resuming || msft->suspending) + return -EBUSY; + + return msft_remove_monitor_sync(hdev, monitor); +} + +void msft_req_add_set_filter_enable(struct hci_request *req, bool enable) +{ + struct hci_dev *hdev = req->hdev; + struct msft_cp_le_set_advertisement_filter_enable cp; + + cp.sub_opcode = MSFT_OP_LE_SET_ADVERTISEMENT_FILTER_ENABLE; + cp.enable = enable; + + hci_req_add(req, hdev->msft_opcode, sizeof(cp), &cp); +} + +int msft_set_filter_enable(struct hci_dev *hdev, bool enable) +{ + struct hci_request req; + struct msft_data *msft = hdev->msft_data; + int err; + + if (!msft) + return -EOPNOTSUPP; + + hci_req_init(&req, hdev); + msft_req_add_set_filter_enable(&req, enable); + err = hci_req_run_skb(&req, msft_le_set_advertisement_filter_enable_cb); + + return err; +} + +bool msft_curve_validity(struct hci_dev *hdev) +{ + return hdev->msft_curve_validity; +} diff --git a/net/bluetooth/msft.h b/net/bluetooth/msft.h new file mode 100644 index 000000000..2a63205b3 --- /dev/null +++ b/net/bluetooth/msft.h @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2020 Google Corporation + */ + +#define MSFT_FEATURE_MASK_BREDR_RSSI_MONITOR BIT(0) +#define MSFT_FEATURE_MASK_LE_CONN_RSSI_MONITOR BIT(1) +#define MSFT_FEATURE_MASK_LE_ADV_RSSI_MONITOR BIT(2) +#define MSFT_FEATURE_MASK_LE_ADV_MONITOR BIT(3) +#define MSFT_FEATURE_MASK_CURVE_VALIDITY BIT(4) +#define MSFT_FEATURE_MASK_CONCURRENT_ADV_MONITOR BIT(5) + +#if IS_ENABLED(CONFIG_BT_MSFTEXT) + +bool msft_monitor_supported(struct hci_dev *hdev); +void msft_register(struct hci_dev *hdev); +void msft_unregister(struct hci_dev *hdev); +void msft_do_open(struct hci_dev *hdev); +void msft_do_close(struct hci_dev *hdev); +void msft_vendor_evt(struct hci_dev *hdev, void *data, struct sk_buff *skb); +__u64 msft_get_features(struct hci_dev *hdev); +int msft_add_monitor_pattern(struct hci_dev *hdev, struct adv_monitor *monitor); +int msft_remove_monitor(struct hci_dev *hdev, struct adv_monitor *monitor); +void msft_req_add_set_filter_enable(struct hci_request *req, bool enable); +int msft_set_filter_enable(struct hci_dev *hdev, bool enable); +int msft_suspend_sync(struct hci_dev *hdev); +int msft_resume_sync(struct hci_dev *hdev); +bool msft_curve_validity(struct hci_dev *hdev); + +#else + +static inline bool msft_monitor_supported(struct hci_dev *hdev) +{ + return false; +} + +static inline void msft_register(struct hci_dev *hdev) {} +static inline void msft_unregister(struct hci_dev *hdev) {} +static inline void msft_do_open(struct hci_dev *hdev) {} +static inline void msft_do_close(struct hci_dev *hdev) {} +static inline void msft_vendor_evt(struct hci_dev *hdev, void *data, + struct sk_buff *skb) {} +static inline __u64 msft_get_features(struct hci_dev *hdev) { return 0; } +static inline int msft_add_monitor_pattern(struct hci_dev *hdev, + struct adv_monitor *monitor) +{ + return -EOPNOTSUPP; +} + +static inline 
int msft_remove_monitor(struct hci_dev *hdev, + struct adv_monitor *monitor) +{ + return -EOPNOTSUPP; +} + +static inline void msft_req_add_set_filter_enable(struct hci_request *req, + bool enable) {} +static inline int msft_set_filter_enable(struct hci_dev *hdev, bool enable) +{ + return -EOPNOTSUPP; +} + +static inline int msft_suspend_sync(struct hci_dev *hdev) +{ + return -EOPNOTSUPP; +} + +static inline int msft_resume_sync(struct hci_dev *hdev) +{ + return -EOPNOTSUPP; +} + +static inline bool msft_curve_validity(struct hci_dev *hdev) +{ + return false; +} + +#endif diff --git a/net/bluetooth/rfcomm/Kconfig b/net/bluetooth/rfcomm/Kconfig new file mode 100644 index 000000000..9b9953ebf --- /dev/null +++ b/net/bluetooth/rfcomm/Kconfig @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: GPL-2.0-only +config BT_RFCOMM + tristate "RFCOMM protocol support" + depends on BT_BREDR + help + RFCOMM provides connection oriented stream transport. RFCOMM + support is required for Dialup Networking, OBEX and other Bluetooth + applications. + + Say Y here to compile RFCOMM support into the kernel or say M to + compile it as module (rfcomm). + +config BT_RFCOMM_TTY + bool "RFCOMM TTY support" + depends on BT_RFCOMM + depends on TTY + help + This option enables TTY emulation support for RFCOMM channels. + diff --git a/net/bluetooth/rfcomm/Makefile b/net/bluetooth/rfcomm/Makefile new file mode 100644 index 000000000..593e5c48c --- /dev/null +++ b/net/bluetooth/rfcomm/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Bluetooth RFCOMM layer. +# + +obj-$(CONFIG_BT_RFCOMM) += rfcomm.o + +rfcomm-y := core.o sock.o +rfcomm-$(CONFIG_BT_RFCOMM_TTY) += tty.o diff --git a/net/bluetooth/rfcomm/core.c b/net/bluetooth/rfcomm/core.c new file mode 100644 index 000000000..8d6fce900 --- /dev/null +++ b/net/bluetooth/rfcomm/core.c @@ -0,0 +1,2281 @@ +/* + RFCOMM implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002 Maxim Krasnyansky + Copyright (C) 2002 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* + * Bluetooth RFCOMM core. 
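
The rfcomm_crc_table[] that opens the RFCOMM core below is the reflected CRC-8 table for polynomial 0x07, used for the frame FCS. A short sketch that regenerates it and spot-checks the first row against the listed values:

#include <stdint.h>
#include <stdio.h>

/* Regenerate the reversed (reflected) CRC-8 table for poly x^8+x^2+x+1
 * (0x07); 0xE0 is that polynomial bit-reversed.  The first entries should
 * come out as 0x00, 0x91, 0xe3, 0x72, ... exactly as in rfcomm_crc_table[]. */
static void gen_crc8_rev_table(uint8_t table[256])
{
	unsigned int i, bit;

	for (i = 0; i < 256; i++) {
		uint8_t crc = (uint8_t)i;

		for (bit = 0; bit < 8; bit++)
			crc = (crc & 1) ? (crc >> 1) ^ 0xE0 : crc >> 1;

		table[i] = crc;
	}
}

int main(void)
{
	uint8_t table[256];
	unsigned int i;

	gen_crc8_rev_table(table);
	for (i = 0; i < 8; i++)
		printf("0x%02x%s", table[i], i == 7 ? "\n" : ", ");
	/* prints: 0x00, 0x91, 0xe3, 0x72, 0x07, 0x96, 0xe4, 0x75 */
	return 0;
}
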
+ */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#define VERSION "1.11" + +static bool disable_cfc; +static bool l2cap_ertm; +static int channel_mtu = -1; + +static struct task_struct *rfcomm_thread; + +static DEFINE_MUTEX(rfcomm_mutex); +#define rfcomm_lock() mutex_lock(&rfcomm_mutex) +#define rfcomm_unlock() mutex_unlock(&rfcomm_mutex) + + +static LIST_HEAD(session_list); + +static int rfcomm_send_frame(struct rfcomm_session *s, u8 *data, int len); +static int rfcomm_send_sabm(struct rfcomm_session *s, u8 dlci); +static int rfcomm_send_disc(struct rfcomm_session *s, u8 dlci); +static int rfcomm_queue_disc(struct rfcomm_dlc *d); +static int rfcomm_send_nsc(struct rfcomm_session *s, int cr, u8 type); +static int rfcomm_send_pn(struct rfcomm_session *s, int cr, struct rfcomm_dlc *d); +static int rfcomm_send_msc(struct rfcomm_session *s, int cr, u8 dlci, u8 v24_sig); +static int rfcomm_send_test(struct rfcomm_session *s, int cr, u8 *pattern, int len); +static int rfcomm_send_credits(struct rfcomm_session *s, u8 addr, u8 credits); +static void rfcomm_make_uih(struct sk_buff *skb, u8 addr); + +static void rfcomm_process_connect(struct rfcomm_session *s); + +static struct rfcomm_session *rfcomm_session_create(bdaddr_t *src, + bdaddr_t *dst, + u8 sec_level, + int *err); +static struct rfcomm_session *rfcomm_session_get(bdaddr_t *src, bdaddr_t *dst); +static struct rfcomm_session *rfcomm_session_del(struct rfcomm_session *s); + +/* ---- RFCOMM frame parsing macros ---- */ +#define __get_dlci(b) ((b & 0xfc) >> 2) +#define __get_type(b) ((b & 0xef)) + +#define __test_ea(b) ((b & 0x01)) +#define __test_cr(b) (!!(b & 0x02)) +#define __test_pf(b) (!!(b & 0x10)) + +#define __session_dir(s) ((s)->initiator ? 0x00 : 0x01) + +#define __addr(cr, dlci) (((dlci & 0x3f) << 2) | (cr << 1) | 0x01) +#define __ctrl(type, pf) (((type & 0xef) | (pf << 4))) +#define __dlci(dir, chn) (((chn & 0x1f) << 1) | dir) +#define __srv_channel(dlci) (dlci >> 1) + +#define __len8(len) (((len) << 1) | 1) +#define __len16(len) ((len) << 1) + +/* MCC macros */ +#define __mcc_type(cr, type) (((type << 2) | (cr << 1) | 0x01)) +#define __get_mcc_type(b) ((b & 0xfc) >> 2) +#define __get_mcc_len(b) ((b & 0xfe) >> 1) + +/* RPN macros */ +#define __rpn_line_settings(data, stop, parity) ((data & 0x3) | ((stop & 0x1) << 2) | ((parity & 0x7) << 3)) +#define __get_rpn_data_bits(line) ((line) & 0x3) +#define __get_rpn_stop_bits(line) (((line) >> 2) & 0x1) +#define __get_rpn_parity(line) (((line) >> 3) & 0x7) + +static DECLARE_WAIT_QUEUE_HEAD(rfcomm_wq); + +static void rfcomm_schedule(void) +{ + wake_up_all(&rfcomm_wq); +} + +/* ---- RFCOMM FCS computation ---- */ + +/* reversed, 8-bit, poly=0x07 */ +static unsigned char rfcomm_crc_table[256] = { + 0x00, 0x91, 0xe3, 0x72, 0x07, 0x96, 0xe4, 0x75, + 0x0e, 0x9f, 0xed, 0x7c, 0x09, 0x98, 0xea, 0x7b, + 0x1c, 0x8d, 0xff, 0x6e, 0x1b, 0x8a, 0xf8, 0x69, + 0x12, 0x83, 0xf1, 0x60, 0x15, 0x84, 0xf6, 0x67, + + 0x38, 0xa9, 0xdb, 0x4a, 0x3f, 0xae, 0xdc, 0x4d, + 0x36, 0xa7, 0xd5, 0x44, 0x31, 0xa0, 0xd2, 0x43, + 0x24, 0xb5, 0xc7, 0x56, 0x23, 0xb2, 0xc0, 0x51, + 0x2a, 0xbb, 0xc9, 0x58, 0x2d, 0xbc, 0xce, 0x5f, + + 0x70, 0xe1, 0x93, 0x02, 0x77, 0xe6, 0x94, 0x05, + 0x7e, 0xef, 0x9d, 0x0c, 0x79, 0xe8, 0x9a, 0x0b, + 0x6c, 0xfd, 0x8f, 0x1e, 0x6b, 0xfa, 0x88, 0x19, + 0x62, 0xf3, 0x81, 0x10, 0x65, 0xf4, 0x86, 0x17, + + 0x48, 0xd9, 0xab, 0x3a, 0x4f, 0xde, 0xac, 0x3d, + 0x46, 0xd7, 0xa5, 0x34, 0x41, 0xd0, 0xa2, 0x33, + 0x54, 0xc5, 0xb7, 0x26, 0x53, 0xc2, 0xb0, 0x21, + 0x5a, 0xcb, 
0xb9, 0x28, 0x5d, 0xcc, 0xbe, 0x2f, + + 0xe0, 0x71, 0x03, 0x92, 0xe7, 0x76, 0x04, 0x95, + 0xee, 0x7f, 0x0d, 0x9c, 0xe9, 0x78, 0x0a, 0x9b, + 0xfc, 0x6d, 0x1f, 0x8e, 0xfb, 0x6a, 0x18, 0x89, + 0xf2, 0x63, 0x11, 0x80, 0xf5, 0x64, 0x16, 0x87, + + 0xd8, 0x49, 0x3b, 0xaa, 0xdf, 0x4e, 0x3c, 0xad, + 0xd6, 0x47, 0x35, 0xa4, 0xd1, 0x40, 0x32, 0xa3, + 0xc4, 0x55, 0x27, 0xb6, 0xc3, 0x52, 0x20, 0xb1, + 0xca, 0x5b, 0x29, 0xb8, 0xcd, 0x5c, 0x2e, 0xbf, + + 0x90, 0x01, 0x73, 0xe2, 0x97, 0x06, 0x74, 0xe5, + 0x9e, 0x0f, 0x7d, 0xec, 0x99, 0x08, 0x7a, 0xeb, + 0x8c, 0x1d, 0x6f, 0xfe, 0x8b, 0x1a, 0x68, 0xf9, + 0x82, 0x13, 0x61, 0xf0, 0x85, 0x14, 0x66, 0xf7, + + 0xa8, 0x39, 0x4b, 0xda, 0xaf, 0x3e, 0x4c, 0xdd, + 0xa6, 0x37, 0x45, 0xd4, 0xa1, 0x30, 0x42, 0xd3, + 0xb4, 0x25, 0x57, 0xc6, 0xb3, 0x22, 0x50, 0xc1, + 0xba, 0x2b, 0x59, 0xc8, 0xbd, 0x2c, 0x5e, 0xcf +}; + +/* CRC on 2 bytes */ +#define __crc(data) (rfcomm_crc_table[rfcomm_crc_table[0xff ^ data[0]] ^ data[1]]) + +/* FCS on 2 bytes */ +static inline u8 __fcs(u8 *data) +{ + return 0xff - __crc(data); +} + +/* FCS on 3 bytes */ +static inline u8 __fcs2(u8 *data) +{ + return 0xff - rfcomm_crc_table[__crc(data) ^ data[2]]; +} + +/* Check FCS */ +static inline int __check_fcs(u8 *data, int type, u8 fcs) +{ + u8 f = __crc(data); + + if (type != RFCOMM_UIH) + f = rfcomm_crc_table[f ^ data[2]]; + + return rfcomm_crc_table[f ^ fcs] != 0xcf; +} + +/* ---- L2CAP callbacks ---- */ +static void rfcomm_l2state_change(struct sock *sk) +{ + BT_DBG("%p state %d", sk, sk->sk_state); + rfcomm_schedule(); +} + +static void rfcomm_l2data_ready(struct sock *sk) +{ + BT_DBG("%p", sk); + rfcomm_schedule(); +} + +static int rfcomm_l2sock_create(struct socket **sock) +{ + int err; + + BT_DBG(""); + + err = sock_create_kern(&init_net, PF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_L2CAP, sock); + if (!err) { + struct sock *sk = (*sock)->sk; + sk->sk_data_ready = rfcomm_l2data_ready; + sk->sk_state_change = rfcomm_l2state_change; + } + return err; +} + +static int rfcomm_check_security(struct rfcomm_dlc *d) +{ + struct sock *sk = d->session->sock->sk; + struct l2cap_conn *conn = l2cap_pi(sk)->chan->conn; + + __u8 auth_type; + + switch (d->sec_level) { + case BT_SECURITY_HIGH: + case BT_SECURITY_FIPS: + auth_type = HCI_AT_GENERAL_BONDING_MITM; + break; + case BT_SECURITY_MEDIUM: + auth_type = HCI_AT_GENERAL_BONDING; + break; + default: + auth_type = HCI_AT_NO_BONDING; + break; + } + + return hci_conn_security(conn->hcon, d->sec_level, auth_type, + d->out); +} + +static void rfcomm_session_timeout(struct timer_list *t) +{ + struct rfcomm_session *s = from_timer(s, t, timer); + + BT_DBG("session %p state %ld", s, s->state); + + set_bit(RFCOMM_TIMED_OUT, &s->flags); + rfcomm_schedule(); +} + +static void rfcomm_session_set_timer(struct rfcomm_session *s, long timeout) +{ + BT_DBG("session %p state %ld timeout %ld", s, s->state, timeout); + + mod_timer(&s->timer, jiffies + timeout); +} + +static void rfcomm_session_clear_timer(struct rfcomm_session *s) +{ + BT_DBG("session %p state %ld", s, s->state); + + del_timer_sync(&s->timer); +} + +/* ---- RFCOMM DLCs ---- */ +static void rfcomm_dlc_timeout(struct timer_list *t) +{ + struct rfcomm_dlc *d = from_timer(d, t, timer); + + BT_DBG("dlc %p state %ld", d, d->state); + + set_bit(RFCOMM_TIMED_OUT, &d->flags); + rfcomm_dlc_put(d); + rfcomm_schedule(); +} + +static void rfcomm_dlc_set_timer(struct rfcomm_dlc *d, long timeout) +{ + BT_DBG("dlc %p state %ld timeout %ld", d, d->state, timeout); + + if (!mod_timer(&d->timer, jiffies + timeout)) + 
rfcomm_dlc_hold(d); +} + +static void rfcomm_dlc_clear_timer(struct rfcomm_dlc *d) +{ + BT_DBG("dlc %p state %ld", d, d->state); + + if (del_timer(&d->timer)) + rfcomm_dlc_put(d); +} + +static void rfcomm_dlc_clear_state(struct rfcomm_dlc *d) +{ + BT_DBG("%p", d); + + d->state = BT_OPEN; + d->flags = 0; + d->mscex = 0; + d->sec_level = BT_SECURITY_LOW; + d->mtu = RFCOMM_DEFAULT_MTU; + d->v24_sig = RFCOMM_V24_RTC | RFCOMM_V24_RTR | RFCOMM_V24_DV; + + d->cfc = RFCOMM_CFC_DISABLED; + d->rx_credits = RFCOMM_DEFAULT_CREDITS; +} + +struct rfcomm_dlc *rfcomm_dlc_alloc(gfp_t prio) +{ + struct rfcomm_dlc *d = kzalloc(sizeof(*d), prio); + + if (!d) + return NULL; + + timer_setup(&d->timer, rfcomm_dlc_timeout, 0); + + skb_queue_head_init(&d->tx_queue); + mutex_init(&d->lock); + refcount_set(&d->refcnt, 1); + + rfcomm_dlc_clear_state(d); + + BT_DBG("%p", d); + + return d; +} + +void rfcomm_dlc_free(struct rfcomm_dlc *d) +{ + BT_DBG("%p", d); + + skb_queue_purge(&d->tx_queue); + kfree(d); +} + +static void rfcomm_dlc_link(struct rfcomm_session *s, struct rfcomm_dlc *d) +{ + BT_DBG("dlc %p session %p", d, s); + + rfcomm_session_clear_timer(s); + rfcomm_dlc_hold(d); + list_add(&d->list, &s->dlcs); + d->session = s; +} + +static void rfcomm_dlc_unlink(struct rfcomm_dlc *d) +{ + struct rfcomm_session *s = d->session; + + BT_DBG("dlc %p refcnt %d session %p", d, refcount_read(&d->refcnt), s); + + list_del(&d->list); + d->session = NULL; + rfcomm_dlc_put(d); + + if (list_empty(&s->dlcs)) + rfcomm_session_set_timer(s, RFCOMM_IDLE_TIMEOUT); +} + +static struct rfcomm_dlc *rfcomm_dlc_get(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_dlc *d; + + list_for_each_entry(d, &s->dlcs, list) + if (d->dlci == dlci) + return d; + + return NULL; +} + +static int rfcomm_check_channel(u8 channel) +{ + return channel < 1 || channel > 30; +} + +static int __rfcomm_dlc_open(struct rfcomm_dlc *d, bdaddr_t *src, bdaddr_t *dst, u8 channel) +{ + struct rfcomm_session *s; + int err = 0; + u8 dlci; + + BT_DBG("dlc %p state %ld %pMR -> %pMR channel %d", + d, d->state, src, dst, channel); + + if (rfcomm_check_channel(channel)) + return -EINVAL; + + if (d->state != BT_OPEN && d->state != BT_CLOSED) + return 0; + + s = rfcomm_session_get(src, dst); + if (!s) { + s = rfcomm_session_create(src, dst, d->sec_level, &err); + if (!s) + return err; + } + + dlci = __dlci(__session_dir(s), channel); + + /* Check if DLCI already exists */ + if (rfcomm_dlc_get(s, dlci)) + return -EBUSY; + + rfcomm_dlc_clear_state(d); + + d->dlci = dlci; + d->addr = __addr(s->initiator, dlci); + d->priority = 7; + + d->state = BT_CONFIG; + rfcomm_dlc_link(s, d); + + d->out = 1; + + d->mtu = s->mtu; + d->cfc = (s->cfc == RFCOMM_CFC_UNKNOWN) ? 
0 : s->cfc; + + if (s->state == BT_CONNECTED) { + if (rfcomm_check_security(d)) + rfcomm_send_pn(s, 1, d); + else + set_bit(RFCOMM_AUTH_PENDING, &d->flags); + } + + rfcomm_dlc_set_timer(d, RFCOMM_CONN_TIMEOUT); + + return 0; +} + +int rfcomm_dlc_open(struct rfcomm_dlc *d, bdaddr_t *src, bdaddr_t *dst, u8 channel) +{ + int r; + + rfcomm_lock(); + + r = __rfcomm_dlc_open(d, src, dst, channel); + + rfcomm_unlock(); + return r; +} + +static void __rfcomm_dlc_disconn(struct rfcomm_dlc *d) +{ + struct rfcomm_session *s = d->session; + + d->state = BT_DISCONN; + if (skb_queue_empty(&d->tx_queue)) { + rfcomm_send_disc(s, d->dlci); + rfcomm_dlc_set_timer(d, RFCOMM_DISC_TIMEOUT); + } else { + rfcomm_queue_disc(d); + rfcomm_dlc_set_timer(d, RFCOMM_DISC_TIMEOUT * 2); + } +} + +static int __rfcomm_dlc_close(struct rfcomm_dlc *d, int err) +{ + struct rfcomm_session *s = d->session; + if (!s) + return 0; + + BT_DBG("dlc %p state %ld dlci %d err %d session %p", + d, d->state, d->dlci, err, s); + + switch (d->state) { + case BT_CONNECT: + case BT_CONFIG: + case BT_OPEN: + case BT_CONNECT2: + if (test_and_clear_bit(RFCOMM_DEFER_SETUP, &d->flags)) { + set_bit(RFCOMM_AUTH_REJECT, &d->flags); + rfcomm_schedule(); + return 0; + } + } + + switch (d->state) { + case BT_CONNECT: + case BT_CONNECTED: + __rfcomm_dlc_disconn(d); + break; + + case BT_CONFIG: + if (s->state != BT_BOUND) { + __rfcomm_dlc_disconn(d); + break; + } + /* if closing a dlc in a session that hasn't been started, + * just close and unlink the dlc + */ + fallthrough; + + default: + rfcomm_dlc_clear_timer(d); + + rfcomm_dlc_lock(d); + d->state = BT_CLOSED; + d->state_change(d, err); + rfcomm_dlc_unlock(d); + + skb_queue_purge(&d->tx_queue); + rfcomm_dlc_unlink(d); + } + + return 0; +} + +int rfcomm_dlc_close(struct rfcomm_dlc *d, int err) +{ + int r = 0; + struct rfcomm_dlc *d_list; + struct rfcomm_session *s, *s_list; + + BT_DBG("dlc %p state %ld dlci %d err %d", d, d->state, d->dlci, err); + + rfcomm_lock(); + + s = d->session; + if (!s) + goto no_session; + + /* after waiting on the mutex check the session still exists + * then check the dlc still exists + */ + list_for_each_entry(s_list, &session_list, list) { + if (s_list == s) { + list_for_each_entry(d_list, &s->dlcs, list) { + if (d_list == d) { + r = __rfcomm_dlc_close(d, err); + break; + } + } + break; + } + } + +no_session: + rfcomm_unlock(); + return r; +} + +struct rfcomm_dlc *rfcomm_dlc_exists(bdaddr_t *src, bdaddr_t *dst, u8 channel) +{ + struct rfcomm_session *s; + struct rfcomm_dlc *dlc = NULL; + u8 dlci; + + if (rfcomm_check_channel(channel)) + return ERR_PTR(-EINVAL); + + rfcomm_lock(); + s = rfcomm_session_get(src, dst); + if (s) { + dlci = __dlci(__session_dir(s), channel); + dlc = rfcomm_dlc_get(s, dlci); + } + rfcomm_unlock(); + return dlc; +} + +static int rfcomm_dlc_send_frag(struct rfcomm_dlc *d, struct sk_buff *frag) +{ + int len = frag->len; + + BT_DBG("dlc %p mtu %d len %d", d, d->mtu, len); + + if (len > d->mtu) + return -EINVAL; + + rfcomm_make_uih(frag, d->addr); + __skb_queue_tail(&d->tx_queue, frag); + + return len; +} + +int rfcomm_dlc_send(struct rfcomm_dlc *d, struct sk_buff *skb) +{ + unsigned long flags; + struct sk_buff *frag, *next; + int len; + + if (d->state != BT_CONNECTED) + return -ENOTCONN; + + frag = skb_shinfo(skb)->frag_list; + skb_shinfo(skb)->frag_list = NULL; + + /* Queue all fragments atomically. 
*/ + spin_lock_irqsave(&d->tx_queue.lock, flags); + + len = rfcomm_dlc_send_frag(d, skb); + if (len < 0 || !frag) + goto unlock; + + for (; frag; frag = next) { + int ret; + + next = frag->next; + + ret = rfcomm_dlc_send_frag(d, frag); + if (ret < 0) { + dev_kfree_skb_irq(frag); + goto unlock; + } + + len += ret; + } + +unlock: + spin_unlock_irqrestore(&d->tx_queue.lock, flags); + + if (len > 0 && !test_bit(RFCOMM_TX_THROTTLED, &d->flags)) + rfcomm_schedule(); + return len; +} + +void rfcomm_dlc_send_noerror(struct rfcomm_dlc *d, struct sk_buff *skb) +{ + int len = skb->len; + + BT_DBG("dlc %p mtu %d len %d", d, d->mtu, len); + + rfcomm_make_uih(skb, d->addr); + skb_queue_tail(&d->tx_queue, skb); + + if (d->state == BT_CONNECTED && + !test_bit(RFCOMM_TX_THROTTLED, &d->flags)) + rfcomm_schedule(); +} + +void __rfcomm_dlc_throttle(struct rfcomm_dlc *d) +{ + BT_DBG("dlc %p state %ld", d, d->state); + + if (!d->cfc) { + d->v24_sig |= RFCOMM_V24_FC; + set_bit(RFCOMM_MSC_PENDING, &d->flags); + } + rfcomm_schedule(); +} + +void __rfcomm_dlc_unthrottle(struct rfcomm_dlc *d) +{ + BT_DBG("dlc %p state %ld", d, d->state); + + if (!d->cfc) { + d->v24_sig &= ~RFCOMM_V24_FC; + set_bit(RFCOMM_MSC_PENDING, &d->flags); + } + rfcomm_schedule(); +} + +/* + Set/get modem status functions use _local_ status i.e. what we report + to the other side. + Remote status is provided by dlc->modem_status() callback. + */ +int rfcomm_dlc_set_modem_status(struct rfcomm_dlc *d, u8 v24_sig) +{ + BT_DBG("dlc %p state %ld v24_sig 0x%x", + d, d->state, v24_sig); + + if (test_bit(RFCOMM_RX_THROTTLED, &d->flags)) + v24_sig |= RFCOMM_V24_FC; + else + v24_sig &= ~RFCOMM_V24_FC; + + d->v24_sig = v24_sig; + + if (!test_and_set_bit(RFCOMM_MSC_PENDING, &d->flags)) + rfcomm_schedule(); + + return 0; +} + +int rfcomm_dlc_get_modem_status(struct rfcomm_dlc *d, u8 *v24_sig) +{ + BT_DBG("dlc %p state %ld v24_sig 0x%x", + d, d->state, d->v24_sig); + + *v24_sig = d->v24_sig; + return 0; +} + +/* ---- RFCOMM sessions ---- */ +static struct rfcomm_session *rfcomm_session_add(struct socket *sock, int state) +{ + struct rfcomm_session *s = kzalloc(sizeof(*s), GFP_KERNEL); + + if (!s) + return NULL; + + BT_DBG("session %p sock %p", s, sock); + + timer_setup(&s->timer, rfcomm_session_timeout, 0); + + INIT_LIST_HEAD(&s->dlcs); + s->state = state; + s->sock = sock; + + s->mtu = RFCOMM_DEFAULT_MTU; + s->cfc = disable_cfc ? RFCOMM_CFC_DISABLED : RFCOMM_CFC_UNKNOWN; + + /* Do not increment module usage count for listening sessions. + * Otherwise we won't be able to unload the module. 
*/ + if (state != BT_LISTEN) + if (!try_module_get(THIS_MODULE)) { + kfree(s); + return NULL; + } + + list_add(&s->list, &session_list); + + return s; +} + +static struct rfcomm_session *rfcomm_session_del(struct rfcomm_session *s) +{ + int state = s->state; + + BT_DBG("session %p state %ld", s, s->state); + + list_del(&s->list); + + rfcomm_session_clear_timer(s); + sock_release(s->sock); + kfree(s); + + if (state != BT_LISTEN) + module_put(THIS_MODULE); + + return NULL; +} + +static struct rfcomm_session *rfcomm_session_get(bdaddr_t *src, bdaddr_t *dst) +{ + struct rfcomm_session *s, *n; + struct l2cap_chan *chan; + list_for_each_entry_safe(s, n, &session_list, list) { + chan = l2cap_pi(s->sock->sk)->chan; + + if ((!bacmp(src, BDADDR_ANY) || !bacmp(&chan->src, src)) && + !bacmp(&chan->dst, dst)) + return s; + } + return NULL; +} + +static struct rfcomm_session *rfcomm_session_close(struct rfcomm_session *s, + int err) +{ + struct rfcomm_dlc *d, *n; + + s->state = BT_CLOSED; + + BT_DBG("session %p state %ld err %d", s, s->state, err); + + /* Close all dlcs */ + list_for_each_entry_safe(d, n, &s->dlcs, list) { + d->state = BT_CLOSED; + __rfcomm_dlc_close(d, err); + } + + rfcomm_session_clear_timer(s); + return rfcomm_session_del(s); +} + +static struct rfcomm_session *rfcomm_session_create(bdaddr_t *src, + bdaddr_t *dst, + u8 sec_level, + int *err) +{ + struct rfcomm_session *s = NULL; + struct sockaddr_l2 addr; + struct socket *sock; + struct sock *sk; + + BT_DBG("%pMR -> %pMR", src, dst); + + *err = rfcomm_l2sock_create(&sock); + if (*err < 0) + return NULL; + + bacpy(&addr.l2_bdaddr, src); + addr.l2_family = AF_BLUETOOTH; + addr.l2_psm = 0; + addr.l2_cid = 0; + addr.l2_bdaddr_type = BDADDR_BREDR; + *err = kernel_bind(sock, (struct sockaddr *) &addr, sizeof(addr)); + if (*err < 0) + goto failed; + + /* Set L2CAP options */ + sk = sock->sk; + lock_sock(sk); + /* Set MTU to 0 so L2CAP can auto select the MTU */ + l2cap_pi(sk)->chan->imtu = 0; + l2cap_pi(sk)->chan->sec_level = sec_level; + if (l2cap_ertm) + l2cap_pi(sk)->chan->mode = L2CAP_MODE_ERTM; + release_sock(sk); + + s = rfcomm_session_add(sock, BT_BOUND); + if (!s) { + *err = -ENOMEM; + goto failed; + } + + s->initiator = 1; + + bacpy(&addr.l2_bdaddr, dst); + addr.l2_family = AF_BLUETOOTH; + addr.l2_psm = cpu_to_le16(L2CAP_PSM_RFCOMM); + addr.l2_cid = 0; + addr.l2_bdaddr_type = BDADDR_BREDR; + *err = kernel_connect(sock, (struct sockaddr *) &addr, sizeof(addr), O_NONBLOCK); + if (*err == 0 || *err == -EINPROGRESS) + return s; + + return rfcomm_session_del(s); + +failed: + sock_release(sock); + return NULL; +} + +void rfcomm_session_getaddr(struct rfcomm_session *s, bdaddr_t *src, bdaddr_t *dst) +{ + struct l2cap_chan *chan = l2cap_pi(s->sock->sk)->chan; + if (src) + bacpy(src, &chan->src); + if (dst) + bacpy(dst, &chan->dst); +} + +/* ---- RFCOMM frame sending ---- */ +static int rfcomm_send_frame(struct rfcomm_session *s, u8 *data, int len) +{ + struct kvec iv = { data, len }; + struct msghdr msg; + + BT_DBG("session %p len %d", s, len); + + memset(&msg, 0, sizeof(msg)); + + return kernel_sendmsg(s->sock, &msg, &iv, 1, len); +} + +static int rfcomm_send_cmd(struct rfcomm_session *s, struct rfcomm_cmd *cmd) +{ + BT_DBG("%p cmd %u", s, cmd->ctrl); + + return rfcomm_send_frame(s, (void *) cmd, sizeof(*cmd)); +} + +static int rfcomm_send_sabm(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_cmd cmd; + + BT_DBG("%p dlci %d", s, dlci); + + cmd.addr = __addr(s->initiator, dlci); + cmd.ctrl = __ctrl(RFCOMM_SABM, 1); + cmd.len = 
__len8(0); + cmd.fcs = __fcs2((u8 *) &cmd); + + return rfcomm_send_cmd(s, &cmd); +} + +static int rfcomm_send_ua(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_cmd cmd; + + BT_DBG("%p dlci %d", s, dlci); + + cmd.addr = __addr(!s->initiator, dlci); + cmd.ctrl = __ctrl(RFCOMM_UA, 1); + cmd.len = __len8(0); + cmd.fcs = __fcs2((u8 *) &cmd); + + return rfcomm_send_cmd(s, &cmd); +} + +static int rfcomm_send_disc(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_cmd cmd; + + BT_DBG("%p dlci %d", s, dlci); + + cmd.addr = __addr(s->initiator, dlci); + cmd.ctrl = __ctrl(RFCOMM_DISC, 1); + cmd.len = __len8(0); + cmd.fcs = __fcs2((u8 *) &cmd); + + return rfcomm_send_cmd(s, &cmd); +} + +static int rfcomm_queue_disc(struct rfcomm_dlc *d) +{ + struct rfcomm_cmd *cmd; + struct sk_buff *skb; + + BT_DBG("dlc %p dlci %d", d, d->dlci); + + skb = alloc_skb(sizeof(*cmd), GFP_KERNEL); + if (!skb) + return -ENOMEM; + + cmd = __skb_put(skb, sizeof(*cmd)); + cmd->addr = d->addr; + cmd->ctrl = __ctrl(RFCOMM_DISC, 1); + cmd->len = __len8(0); + cmd->fcs = __fcs2((u8 *) cmd); + + skb_queue_tail(&d->tx_queue, skb); + rfcomm_schedule(); + return 0; +} + +static int rfcomm_send_dm(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_cmd cmd; + + BT_DBG("%p dlci %d", s, dlci); + + cmd.addr = __addr(!s->initiator, dlci); + cmd.ctrl = __ctrl(RFCOMM_DM, 1); + cmd.len = __len8(0); + cmd.fcs = __fcs2((u8 *) &cmd); + + return rfcomm_send_cmd(s, &cmd); +} + +static int rfcomm_send_nsc(struct rfcomm_session *s, int cr, u8 type) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d type %d", s, cr, type); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc) + 1); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(0, RFCOMM_NSC); + mcc->len = __len8(1); + + /* Type that we didn't like */ + *ptr = __mcc_type(cr, type); ptr++; + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_pn(struct rfcomm_session *s, int cr, struct rfcomm_dlc *d) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + struct rfcomm_pn *pn; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d dlci %d mtu %d", s, cr, d->dlci, d->mtu); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc) + sizeof(*pn)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_PN); + mcc->len = __len8(sizeof(*pn)); + + pn = (void *) ptr; ptr += sizeof(*pn); + pn->dlci = d->dlci; + pn->priority = d->priority; + pn->ack_timer = 0; + pn->max_retrans = 0; + + if (s->cfc) { + pn->flow_ctrl = cr ? 
0xf0 : 0xe0; + pn->credits = RFCOMM_DEFAULT_CREDITS; + } else { + pn->flow_ctrl = 0; + pn->credits = 0; + } + + if (cr && channel_mtu >= 0) + pn->mtu = cpu_to_le16(channel_mtu); + else + pn->mtu = cpu_to_le16(d->mtu); + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +int rfcomm_send_rpn(struct rfcomm_session *s, int cr, u8 dlci, + u8 bit_rate, u8 data_bits, u8 stop_bits, + u8 parity, u8 flow_ctrl_settings, + u8 xon_char, u8 xoff_char, u16 param_mask) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + struct rfcomm_rpn *rpn; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d dlci %d bit_r 0x%x data_b 0x%x stop_b 0x%x parity 0x%x" + " flwc_s 0x%x xon_c 0x%x xoff_c 0x%x p_mask 0x%x", + s, cr, dlci, bit_rate, data_bits, stop_bits, parity, + flow_ctrl_settings, xon_char, xoff_char, param_mask); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc) + sizeof(*rpn)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_RPN); + mcc->len = __len8(sizeof(*rpn)); + + rpn = (void *) ptr; ptr += sizeof(*rpn); + rpn->dlci = __addr(1, dlci); + rpn->bit_rate = bit_rate; + rpn->line_settings = __rpn_line_settings(data_bits, stop_bits, parity); + rpn->flow_ctrl = flow_ctrl_settings; + rpn->xon_char = xon_char; + rpn->xoff_char = xoff_char; + rpn->param_mask = cpu_to_le16(param_mask); + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_rls(struct rfcomm_session *s, int cr, u8 dlci, u8 status) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + struct rfcomm_rls *rls; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d status 0x%x", s, cr, status); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc) + sizeof(*rls)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_RLS); + mcc->len = __len8(sizeof(*rls)); + + rls = (void *) ptr; ptr += sizeof(*rls); + rls->dlci = __addr(1, dlci); + rls->status = status; + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_msc(struct rfcomm_session *s, int cr, u8 dlci, u8 v24_sig) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + struct rfcomm_msc *msc; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d v24 0x%x", s, cr, v24_sig); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc) + sizeof(*msc)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_MSC); + mcc->len = __len8(sizeof(*msc)); + + msc = (void *) ptr; ptr += sizeof(*msc); + msc->dlci = __addr(1, dlci); + msc->v24_sig = v24_sig | 0x01; + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_fcoff(struct rfcomm_session *s, int cr) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d", s, cr); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_FCOFF); + mcc->len = __len8(0); + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_fcon(struct 
rfcomm_session *s, int cr) +{ + struct rfcomm_hdr *hdr; + struct rfcomm_mcc *mcc; + u8 buf[16], *ptr = buf; + + BT_DBG("%p cr %d", s, cr); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = __addr(s->initiator, 0); + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + hdr->len = __len8(sizeof(*mcc)); + + mcc = (void *) ptr; ptr += sizeof(*mcc); + mcc->type = __mcc_type(cr, RFCOMM_FCON); + mcc->len = __len8(0); + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static int rfcomm_send_test(struct rfcomm_session *s, int cr, u8 *pattern, int len) +{ + struct socket *sock = s->sock; + struct kvec iv[3]; + struct msghdr msg; + unsigned char hdr[5], crc[1]; + + if (len > 125) + return -EINVAL; + + BT_DBG("%p cr %d", s, cr); + + hdr[0] = __addr(s->initiator, 0); + hdr[1] = __ctrl(RFCOMM_UIH, 0); + hdr[2] = 0x01 | ((len + 2) << 1); + hdr[3] = 0x01 | ((cr & 0x01) << 1) | (RFCOMM_TEST << 2); + hdr[4] = 0x01 | (len << 1); + + crc[0] = __fcs(hdr); + + iv[0].iov_base = hdr; + iv[0].iov_len = 5; + iv[1].iov_base = pattern; + iv[1].iov_len = len; + iv[2].iov_base = crc; + iv[2].iov_len = 1; + + memset(&msg, 0, sizeof(msg)); + + return kernel_sendmsg(sock, &msg, iv, 3, 6 + len); +} + +static int rfcomm_send_credits(struct rfcomm_session *s, u8 addr, u8 credits) +{ + struct rfcomm_hdr *hdr; + u8 buf[16], *ptr = buf; + + BT_DBG("%p addr %d credits %d", s, addr, credits); + + hdr = (void *) ptr; ptr += sizeof(*hdr); + hdr->addr = addr; + hdr->ctrl = __ctrl(RFCOMM_UIH, 1); + hdr->len = __len8(0); + + *ptr = credits; ptr++; + + *ptr = __fcs(buf); ptr++; + + return rfcomm_send_frame(s, buf, ptr - buf); +} + +static void rfcomm_make_uih(struct sk_buff *skb, u8 addr) +{ + struct rfcomm_hdr *hdr; + int len = skb->len; + u8 *crc; + + if (len > 127) { + hdr = skb_push(skb, 4); + put_unaligned(cpu_to_le16(__len16(len)), (__le16 *) &hdr->len); + } else { + hdr = skb_push(skb, 3); + hdr->len = __len8(len); + } + hdr->addr = addr; + hdr->ctrl = __ctrl(RFCOMM_UIH, 0); + + crc = skb_put(skb, 1); + *crc = __fcs((void *) hdr); +} + +/* ---- RFCOMM frame reception ---- */ +static struct rfcomm_session *rfcomm_recv_ua(struct rfcomm_session *s, u8 dlci) +{ + BT_DBG("session %p state %ld dlci %d", s, s->state, dlci); + + if (dlci) { + /* Data channel */ + struct rfcomm_dlc *d = rfcomm_dlc_get(s, dlci); + if (!d) { + rfcomm_send_dm(s, dlci); + return s; + } + + switch (d->state) { + case BT_CONNECT: + rfcomm_dlc_clear_timer(d); + + rfcomm_dlc_lock(d); + d->state = BT_CONNECTED; + d->state_change(d, 0); + rfcomm_dlc_unlock(d); + + rfcomm_send_msc(s, 1, dlci, d->v24_sig); + break; + + case BT_DISCONN: + d->state = BT_CLOSED; + __rfcomm_dlc_close(d, 0); + + if (list_empty(&s->dlcs)) { + s->state = BT_DISCONN; + rfcomm_send_disc(s, 0); + rfcomm_session_clear_timer(s); + } + + break; + } + } else { + /* Control channel */ + switch (s->state) { + case BT_CONNECT: + s->state = BT_CONNECTED; + rfcomm_process_connect(s); + break; + + case BT_DISCONN: + s = rfcomm_session_close(s, ECONNRESET); + break; + } + } + return s; +} + +static struct rfcomm_session *rfcomm_recv_dm(struct rfcomm_session *s, u8 dlci) +{ + int err = 0; + + BT_DBG("session %p state %ld dlci %d", s, s->state, dlci); + + if (dlci) { + /* Data DLC */ + struct rfcomm_dlc *d = rfcomm_dlc_get(s, dlci); + if (d) { + if (d->state == BT_CONNECT || d->state == BT_CONFIG) + err = ECONNREFUSED; + else + err = ECONNRESET; + + d->state = BT_CLOSED; + __rfcomm_dlc_close(d, err); + } + } else { + if (s->state == BT_CONNECT) + err = ECONNREFUSED; + else + err 
= ECONNRESET; + + s = rfcomm_session_close(s, err); + } + return s; +} + +static struct rfcomm_session *rfcomm_recv_disc(struct rfcomm_session *s, + u8 dlci) +{ + int err = 0; + + BT_DBG("session %p state %ld dlci %d", s, s->state, dlci); + + if (dlci) { + struct rfcomm_dlc *d = rfcomm_dlc_get(s, dlci); + if (d) { + rfcomm_send_ua(s, dlci); + + if (d->state == BT_CONNECT || d->state == BT_CONFIG) + err = ECONNREFUSED; + else + err = ECONNRESET; + + d->state = BT_CLOSED; + __rfcomm_dlc_close(d, err); + } else + rfcomm_send_dm(s, dlci); + + } else { + rfcomm_send_ua(s, 0); + + if (s->state == BT_CONNECT) + err = ECONNREFUSED; + else + err = ECONNRESET; + + s = rfcomm_session_close(s, err); + } + return s; +} + +void rfcomm_dlc_accept(struct rfcomm_dlc *d) +{ + struct sock *sk = d->session->sock->sk; + struct l2cap_conn *conn = l2cap_pi(sk)->chan->conn; + + BT_DBG("dlc %p", d); + + rfcomm_send_ua(d->session, d->dlci); + + rfcomm_dlc_clear_timer(d); + + rfcomm_dlc_lock(d); + d->state = BT_CONNECTED; + d->state_change(d, 0); + rfcomm_dlc_unlock(d); + + if (d->role_switch) + hci_conn_switch_role(conn->hcon, 0x00); + + rfcomm_send_msc(d->session, 1, d->dlci, d->v24_sig); +} + +static void rfcomm_check_accept(struct rfcomm_dlc *d) +{ + if (rfcomm_check_security(d)) { + if (d->defer_setup) { + set_bit(RFCOMM_DEFER_SETUP, &d->flags); + rfcomm_dlc_set_timer(d, RFCOMM_AUTH_TIMEOUT); + + rfcomm_dlc_lock(d); + d->state = BT_CONNECT2; + d->state_change(d, 0); + rfcomm_dlc_unlock(d); + } else + rfcomm_dlc_accept(d); + } else { + set_bit(RFCOMM_AUTH_PENDING, &d->flags); + rfcomm_dlc_set_timer(d, RFCOMM_AUTH_TIMEOUT); + } +} + +static int rfcomm_recv_sabm(struct rfcomm_session *s, u8 dlci) +{ + struct rfcomm_dlc *d; + u8 channel; + + BT_DBG("session %p state %ld dlci %d", s, s->state, dlci); + + if (!dlci) { + rfcomm_send_ua(s, 0); + + if (s->state == BT_OPEN) { + s->state = BT_CONNECTED; + rfcomm_process_connect(s); + } + return 0; + } + + /* Check if DLC exists */ + d = rfcomm_dlc_get(s, dlci); + if (d) { + if (d->state == BT_OPEN) { + /* DLC was previously opened by PN request */ + rfcomm_check_accept(d); + } + return 0; + } + + /* Notify socket layer about incoming connection */ + channel = __srv_channel(dlci); + if (rfcomm_connect_ind(s, channel, &d)) { + d->dlci = dlci; + d->addr = __addr(s->initiator, dlci); + rfcomm_dlc_link(s, d); + + rfcomm_check_accept(d); + } else { + rfcomm_send_dm(s, dlci); + } + + return 0; +} + +static int rfcomm_apply_pn(struct rfcomm_dlc *d, int cr, struct rfcomm_pn *pn) +{ + struct rfcomm_session *s = d->session; + + BT_DBG("dlc %p state %ld dlci %d mtu %d fc 0x%x credits %d", + d, d->state, d->dlci, pn->mtu, pn->flow_ctrl, pn->credits); + + if ((pn->flow_ctrl == 0xf0 && s->cfc != RFCOMM_CFC_DISABLED) || + pn->flow_ctrl == 0xe0) { + d->cfc = RFCOMM_CFC_ENABLED; + d->tx_credits = pn->credits; + } else { + d->cfc = RFCOMM_CFC_DISABLED; + set_bit(RFCOMM_TX_THROTTLED, &d->flags); + } + + if (s->cfc == RFCOMM_CFC_UNKNOWN) + s->cfc = d->cfc; + + d->priority = pn->priority; + + d->mtu = __le16_to_cpu(pn->mtu); + + if (cr && d->mtu > s->mtu) + d->mtu = s->mtu; + + return 0; +} + +static int rfcomm_recv_pn(struct rfcomm_session *s, int cr, struct sk_buff *skb) +{ + struct rfcomm_pn *pn = (void *) skb->data; + struct rfcomm_dlc *d; + u8 dlci = pn->dlci; + + BT_DBG("session %p state %ld dlci %d", s, s->state, dlci); + + if (!dlci) + return 0; + + d = rfcomm_dlc_get(s, dlci); + if (d) { + if (cr) { + /* PN request */ + rfcomm_apply_pn(d, cr, pn); + rfcomm_send_pn(s, 0, d); + } else { 
+ /* PN response */ + switch (d->state) { + case BT_CONFIG: + rfcomm_apply_pn(d, cr, pn); + + d->state = BT_CONNECT; + rfcomm_send_sabm(s, d->dlci); + break; + } + } + } else { + u8 channel = __srv_channel(dlci); + + if (!cr) + return 0; + + /* PN request for non existing DLC. + * Assume incoming connection. */ + if (rfcomm_connect_ind(s, channel, &d)) { + d->dlci = dlci; + d->addr = __addr(s->initiator, dlci); + rfcomm_dlc_link(s, d); + + rfcomm_apply_pn(d, cr, pn); + + d->state = BT_OPEN; + rfcomm_send_pn(s, 0, d); + } else { + rfcomm_send_dm(s, dlci); + } + } + return 0; +} + +static int rfcomm_recv_rpn(struct rfcomm_session *s, int cr, int len, struct sk_buff *skb) +{ + struct rfcomm_rpn *rpn = (void *) skb->data; + u8 dlci = __get_dlci(rpn->dlci); + + u8 bit_rate = 0; + u8 data_bits = 0; + u8 stop_bits = 0; + u8 parity = 0; + u8 flow_ctrl = 0; + u8 xon_char = 0; + u8 xoff_char = 0; + u16 rpn_mask = RFCOMM_RPN_PM_ALL; + + BT_DBG("dlci %d cr %d len 0x%x bitr 0x%x line 0x%x flow 0x%x xonc 0x%x xoffc 0x%x pm 0x%x", + dlci, cr, len, rpn->bit_rate, rpn->line_settings, rpn->flow_ctrl, + rpn->xon_char, rpn->xoff_char, rpn->param_mask); + + if (!cr) + return 0; + + if (len == 1) { + /* This is a request, return default (according to ETSI TS 07.10) settings */ + bit_rate = RFCOMM_RPN_BR_9600; + data_bits = RFCOMM_RPN_DATA_8; + stop_bits = RFCOMM_RPN_STOP_1; + parity = RFCOMM_RPN_PARITY_NONE; + flow_ctrl = RFCOMM_RPN_FLOW_NONE; + xon_char = RFCOMM_RPN_XON_CHAR; + xoff_char = RFCOMM_RPN_XOFF_CHAR; + goto rpn_out; + } + + /* Check for sane values, ignore/accept bit_rate, 8 bits, 1 stop bit, + * no parity, no flow control lines, normal XON/XOFF chars */ + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_BITRATE)) { + bit_rate = rpn->bit_rate; + if (bit_rate > RFCOMM_RPN_BR_230400) { + BT_DBG("RPN bit rate mismatch 0x%x", bit_rate); + bit_rate = RFCOMM_RPN_BR_9600; + rpn_mask ^= RFCOMM_RPN_PM_BITRATE; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_DATA)) { + data_bits = __get_rpn_data_bits(rpn->line_settings); + if (data_bits != RFCOMM_RPN_DATA_8) { + BT_DBG("RPN data bits mismatch 0x%x", data_bits); + data_bits = RFCOMM_RPN_DATA_8; + rpn_mask ^= RFCOMM_RPN_PM_DATA; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_STOP)) { + stop_bits = __get_rpn_stop_bits(rpn->line_settings); + if (stop_bits != RFCOMM_RPN_STOP_1) { + BT_DBG("RPN stop bits mismatch 0x%x", stop_bits); + stop_bits = RFCOMM_RPN_STOP_1; + rpn_mask ^= RFCOMM_RPN_PM_STOP; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_PARITY)) { + parity = __get_rpn_parity(rpn->line_settings); + if (parity != RFCOMM_RPN_PARITY_NONE) { + BT_DBG("RPN parity mismatch 0x%x", parity); + parity = RFCOMM_RPN_PARITY_NONE; + rpn_mask ^= RFCOMM_RPN_PM_PARITY; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_FLOW)) { + flow_ctrl = rpn->flow_ctrl; + if (flow_ctrl != RFCOMM_RPN_FLOW_NONE) { + BT_DBG("RPN flow ctrl mismatch 0x%x", flow_ctrl); + flow_ctrl = RFCOMM_RPN_FLOW_NONE; + rpn_mask ^= RFCOMM_RPN_PM_FLOW; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_XON)) { + xon_char = rpn->xon_char; + if (xon_char != RFCOMM_RPN_XON_CHAR) { + BT_DBG("RPN XON char mismatch 0x%x", xon_char); + xon_char = RFCOMM_RPN_XON_CHAR; + rpn_mask ^= RFCOMM_RPN_PM_XON; + } + } + + if (rpn->param_mask & cpu_to_le16(RFCOMM_RPN_PM_XOFF)) { + xoff_char = rpn->xoff_char; + if (xoff_char != RFCOMM_RPN_XOFF_CHAR) { + BT_DBG("RPN XOFF char mismatch 0x%x", xoff_char); + xoff_char = RFCOMM_RPN_XOFF_CHAR; + rpn_mask ^= RFCOMM_RPN_PM_XOFF; + 
} + } + +rpn_out: + rfcomm_send_rpn(s, 0, dlci, bit_rate, data_bits, stop_bits, + parity, flow_ctrl, xon_char, xoff_char, rpn_mask); + + return 0; +} + +static int rfcomm_recv_rls(struct rfcomm_session *s, int cr, struct sk_buff *skb) +{ + struct rfcomm_rls *rls = (void *) skb->data; + u8 dlci = __get_dlci(rls->dlci); + + BT_DBG("dlci %d cr %d status 0x%x", dlci, cr, rls->status); + + if (!cr) + return 0; + + /* We should probably do something with this information here. But + * for now it's sufficient just to reply -- Bluetooth 1.1 says it's + * mandatory to recognise and respond to RLS */ + + rfcomm_send_rls(s, 0, dlci, rls->status); + + return 0; +} + +static int rfcomm_recv_msc(struct rfcomm_session *s, int cr, struct sk_buff *skb) +{ + struct rfcomm_msc *msc = (void *) skb->data; + struct rfcomm_dlc *d; + u8 dlci = __get_dlci(msc->dlci); + + BT_DBG("dlci %d cr %d v24 0x%x", dlci, cr, msc->v24_sig); + + d = rfcomm_dlc_get(s, dlci); + if (!d) + return 0; + + if (cr) { + if (msc->v24_sig & RFCOMM_V24_FC && !d->cfc) + set_bit(RFCOMM_TX_THROTTLED, &d->flags); + else + clear_bit(RFCOMM_TX_THROTTLED, &d->flags); + + rfcomm_dlc_lock(d); + + d->remote_v24_sig = msc->v24_sig; + + if (d->modem_status) + d->modem_status(d, msc->v24_sig); + + rfcomm_dlc_unlock(d); + + rfcomm_send_msc(s, 0, dlci, msc->v24_sig); + + d->mscex |= RFCOMM_MSCEX_RX; + } else + d->mscex |= RFCOMM_MSCEX_TX; + + return 0; +} + +static int rfcomm_recv_mcc(struct rfcomm_session *s, struct sk_buff *skb) +{ + struct rfcomm_mcc *mcc = (void *) skb->data; + u8 type, cr, len; + + cr = __test_cr(mcc->type); + type = __get_mcc_type(mcc->type); + len = __get_mcc_len(mcc->len); + + BT_DBG("%p type 0x%x cr %d", s, type, cr); + + skb_pull(skb, 2); + + switch (type) { + case RFCOMM_PN: + rfcomm_recv_pn(s, cr, skb); + break; + + case RFCOMM_RPN: + rfcomm_recv_rpn(s, cr, len, skb); + break; + + case RFCOMM_RLS: + rfcomm_recv_rls(s, cr, skb); + break; + + case RFCOMM_MSC: + rfcomm_recv_msc(s, cr, skb); + break; + + case RFCOMM_FCOFF: + if (cr) { + set_bit(RFCOMM_TX_THROTTLED, &s->flags); + rfcomm_send_fcoff(s, 0); + } + break; + + case RFCOMM_FCON: + if (cr) { + clear_bit(RFCOMM_TX_THROTTLED, &s->flags); + rfcomm_send_fcon(s, 0); + } + break; + + case RFCOMM_TEST: + if (cr) + rfcomm_send_test(s, 0, skb->data, skb->len); + break; + + case RFCOMM_NSC: + break; + + default: + BT_ERR("Unknown control type 0x%02x", type); + rfcomm_send_nsc(s, cr, type); + break; + } + return 0; +} + +static int rfcomm_recv_data(struct rfcomm_session *s, u8 dlci, int pf, struct sk_buff *skb) +{ + struct rfcomm_dlc *d; + + BT_DBG("session %p state %ld dlci %d pf %d", s, s->state, dlci, pf); + + d = rfcomm_dlc_get(s, dlci); + if (!d) { + rfcomm_send_dm(s, dlci); + goto drop; + } + + if (pf && d->cfc) { + u8 credits = *(u8 *) skb->data; skb_pull(skb, 1); + + d->tx_credits += credits; + if (d->tx_credits) + clear_bit(RFCOMM_TX_THROTTLED, &d->flags); + } + + if (skb->len && d->state == BT_CONNECTED) { + rfcomm_dlc_lock(d); + d->rx_credits--; + d->data_ready(d, skb); + rfcomm_dlc_unlock(d); + return 0; + } + +drop: + kfree_skb(skb); + return 0; +} + +static struct rfcomm_session *rfcomm_recv_frame(struct rfcomm_session *s, + struct sk_buff *skb) +{ + struct rfcomm_hdr *hdr = (void *) skb->data; + u8 type, dlci, fcs; + + if (!s) { + /* no session, so free socket data */ + kfree_skb(skb); + return s; + } + + dlci = __get_dlci(hdr->addr); + type = __get_type(hdr->ctrl); + + /* Trim FCS */ + skb->len--; skb->tail--; + fcs = *(u8 *)skb_tail_pointer(skb); + + if 
(__check_fcs(skb->data, type, fcs)) { + BT_ERR("bad checksum in packet"); + kfree_skb(skb); + return s; + } + + if (__test_ea(hdr->len)) + skb_pull(skb, 3); + else + skb_pull(skb, 4); + + switch (type) { + case RFCOMM_SABM: + if (__test_pf(hdr->ctrl)) + rfcomm_recv_sabm(s, dlci); + break; + + case RFCOMM_DISC: + if (__test_pf(hdr->ctrl)) + s = rfcomm_recv_disc(s, dlci); + break; + + case RFCOMM_UA: + if (__test_pf(hdr->ctrl)) + s = rfcomm_recv_ua(s, dlci); + break; + + case RFCOMM_DM: + s = rfcomm_recv_dm(s, dlci); + break; + + case RFCOMM_UIH: + if (dlci) { + rfcomm_recv_data(s, dlci, __test_pf(hdr->ctrl), skb); + return s; + } + rfcomm_recv_mcc(s, skb); + break; + + default: + BT_ERR("Unknown packet type 0x%02x", type); + break; + } + kfree_skb(skb); + return s; +} + +/* ---- Connection and data processing ---- */ + +static void rfcomm_process_connect(struct rfcomm_session *s) +{ + struct rfcomm_dlc *d, *n; + + BT_DBG("session %p state %ld", s, s->state); + + list_for_each_entry_safe(d, n, &s->dlcs, list) { + if (d->state == BT_CONFIG) { + d->mtu = s->mtu; + if (rfcomm_check_security(d)) { + rfcomm_send_pn(s, 1, d); + } else { + set_bit(RFCOMM_AUTH_PENDING, &d->flags); + rfcomm_dlc_set_timer(d, RFCOMM_AUTH_TIMEOUT); + } + } + } +} + +/* Send data queued for the DLC. + * Return number of frames left in the queue. + */ +static int rfcomm_process_tx(struct rfcomm_dlc *d) +{ + struct sk_buff *skb; + int err; + + BT_DBG("dlc %p state %ld cfc %d rx_credits %d tx_credits %d", + d, d->state, d->cfc, d->rx_credits, d->tx_credits); + + /* Send pending MSC */ + if (test_and_clear_bit(RFCOMM_MSC_PENDING, &d->flags)) + rfcomm_send_msc(d->session, 1, d->dlci, d->v24_sig); + + if (d->cfc) { + /* CFC enabled. + * Give them some credits */ + if (!test_bit(RFCOMM_RX_THROTTLED, &d->flags) && + d->rx_credits <= (d->cfc >> 2)) { + rfcomm_send_credits(d->session, d->addr, d->cfc - d->rx_credits); + d->rx_credits = d->cfc; + } + } else { + /* CFC disabled. + * Give ourselves some credits */ + d->tx_credits = 5; + } + + if (test_bit(RFCOMM_TX_THROTTLED, &d->flags)) + return skb_queue_len(&d->tx_queue); + + while (d->tx_credits && (skb = skb_dequeue(&d->tx_queue))) { + err = rfcomm_send_frame(d->session, skb->data, skb->len); + if (err < 0) { + skb_queue_head(&d->tx_queue, skb); + break; + } + kfree_skb(skb); + d->tx_credits--; + } + + if (d->cfc && !d->tx_credits) { + /* We're out of TX credits. + * Set TX_THROTTLED flag to avoid unnesary wakeups by dlc_send. 
*/ + set_bit(RFCOMM_TX_THROTTLED, &d->flags); + } + + return skb_queue_len(&d->tx_queue); +} + +static void rfcomm_process_dlcs(struct rfcomm_session *s) +{ + struct rfcomm_dlc *d, *n; + + BT_DBG("session %p state %ld", s, s->state); + + list_for_each_entry_safe(d, n, &s->dlcs, list) { + if (test_bit(RFCOMM_TIMED_OUT, &d->flags)) { + __rfcomm_dlc_close(d, ETIMEDOUT); + continue; + } + + if (test_bit(RFCOMM_ENC_DROP, &d->flags)) { + __rfcomm_dlc_close(d, ECONNREFUSED); + continue; + } + + if (test_and_clear_bit(RFCOMM_AUTH_ACCEPT, &d->flags)) { + rfcomm_dlc_clear_timer(d); + if (d->out) { + rfcomm_send_pn(s, 1, d); + rfcomm_dlc_set_timer(d, RFCOMM_CONN_TIMEOUT); + } else { + if (d->defer_setup) { + set_bit(RFCOMM_DEFER_SETUP, &d->flags); + rfcomm_dlc_set_timer(d, RFCOMM_AUTH_TIMEOUT); + + rfcomm_dlc_lock(d); + d->state = BT_CONNECT2; + d->state_change(d, 0); + rfcomm_dlc_unlock(d); + } else + rfcomm_dlc_accept(d); + } + continue; + } else if (test_and_clear_bit(RFCOMM_AUTH_REJECT, &d->flags)) { + rfcomm_dlc_clear_timer(d); + if (!d->out) + rfcomm_send_dm(s, d->dlci); + else + d->state = BT_CLOSED; + __rfcomm_dlc_close(d, ECONNREFUSED); + continue; + } + + if (test_bit(RFCOMM_SEC_PENDING, &d->flags)) + continue; + + if (test_bit(RFCOMM_TX_THROTTLED, &s->flags)) + continue; + + if ((d->state == BT_CONNECTED || d->state == BT_DISCONN) && + d->mscex == RFCOMM_MSCEX_OK) + rfcomm_process_tx(d); + } +} + +static struct rfcomm_session *rfcomm_process_rx(struct rfcomm_session *s) +{ + struct socket *sock = s->sock; + struct sock *sk = sock->sk; + struct sk_buff *skb; + + BT_DBG("session %p state %ld qlen %d", s, s->state, skb_queue_len(&sk->sk_receive_queue)); + + /* Get data directly from socket receive queue without copying it. */ + while ((skb = skb_dequeue(&sk->sk_receive_queue))) { + skb_orphan(skb); + if (!skb_linearize(skb)) { + s = rfcomm_recv_frame(s, skb); + if (!s) + break; + } else { + kfree_skb(skb); + } + } + + if (s && (sk->sk_state == BT_CLOSED)) + s = rfcomm_session_close(s, sk->sk_err); + + return s; +} + +static void rfcomm_accept_connection(struct rfcomm_session *s) +{ + struct socket *sock = s->sock, *nsock; + int err; + + /* Fast check for a new connection. + * Avoids unnesesary socket allocations. */ + if (list_empty(&bt_sk(sock->sk)->accept_q)) + return; + + BT_DBG("session %p", s); + + err = kernel_accept(sock, &nsock, O_NONBLOCK); + if (err < 0) + return; + + /* Set our callbacks */ + nsock->sk->sk_data_ready = rfcomm_l2data_ready; + nsock->sk->sk_state_change = rfcomm_l2state_change; + + s = rfcomm_session_add(nsock, BT_OPEN); + if (s) { + /* We should adjust MTU on incoming sessions. + * L2CAP MTU minus UIH header and FCS. */ + s->mtu = min(l2cap_pi(nsock->sk)->chan->omtu, + l2cap_pi(nsock->sk)->chan->imtu) - 5; + + rfcomm_schedule(); + } else + sock_release(nsock); +} + +static struct rfcomm_session *rfcomm_check_connection(struct rfcomm_session *s) +{ + struct sock *sk = s->sock->sk; + + BT_DBG("%p state %ld", s, s->state); + + switch (sk->sk_state) { + case BT_CONNECTED: + s->state = BT_CONNECT; + + /* We can adjust MTU on outgoing sessions. + * L2CAP MTU minus UIH header and FCS. 
*/ + s->mtu = min(l2cap_pi(sk)->chan->omtu, l2cap_pi(sk)->chan->imtu) - 5; + + rfcomm_send_sabm(s, 0); + break; + + case BT_CLOSED: + s = rfcomm_session_close(s, sk->sk_err); + break; + } + return s; +} + +static void rfcomm_process_sessions(void) +{ + struct rfcomm_session *s, *n; + + rfcomm_lock(); + + list_for_each_entry_safe(s, n, &session_list, list) { + if (test_and_clear_bit(RFCOMM_TIMED_OUT, &s->flags)) { + s->state = BT_DISCONN; + rfcomm_send_disc(s, 0); + continue; + } + + switch (s->state) { + case BT_LISTEN: + rfcomm_accept_connection(s); + continue; + + case BT_BOUND: + s = rfcomm_check_connection(s); + break; + + default: + s = rfcomm_process_rx(s); + break; + } + + if (s) + rfcomm_process_dlcs(s); + } + + rfcomm_unlock(); +} + +static int rfcomm_add_listener(bdaddr_t *ba) +{ + struct sockaddr_l2 addr; + struct socket *sock; + struct sock *sk; + struct rfcomm_session *s; + int err = 0; + + /* Create socket */ + err = rfcomm_l2sock_create(&sock); + if (err < 0) { + BT_ERR("Create socket failed %d", err); + return err; + } + + /* Bind socket */ + bacpy(&addr.l2_bdaddr, ba); + addr.l2_family = AF_BLUETOOTH; + addr.l2_psm = cpu_to_le16(L2CAP_PSM_RFCOMM); + addr.l2_cid = 0; + addr.l2_bdaddr_type = BDADDR_BREDR; + err = kernel_bind(sock, (struct sockaddr *) &addr, sizeof(addr)); + if (err < 0) { + BT_ERR("Bind failed %d", err); + goto failed; + } + + /* Set L2CAP options */ + sk = sock->sk; + lock_sock(sk); + /* Set MTU to 0 so L2CAP can auto select the MTU */ + l2cap_pi(sk)->chan->imtu = 0; + release_sock(sk); + + /* Start listening on the socket */ + err = kernel_listen(sock, 10); + if (err) { + BT_ERR("Listen failed %d", err); + goto failed; + } + + /* Add listening session */ + s = rfcomm_session_add(sock, BT_LISTEN); + if (!s) { + err = -ENOMEM; + goto failed; + } + + return 0; +failed: + sock_release(sock); + return err; +} + +static void rfcomm_kill_listener(void) +{ + struct rfcomm_session *s, *n; + + BT_DBG(""); + + list_for_each_entry_safe(s, n, &session_list, list) + rfcomm_session_del(s); +} + +static int rfcomm_run(void *unused) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + BT_DBG(""); + + set_user_nice(current, -10); + + rfcomm_add_listener(BDADDR_ANY); + + add_wait_queue(&rfcomm_wq, &wait); + while (!kthread_should_stop()) { + + /* Process stuff */ + rfcomm_process_sessions(); + + wait_woken(&wait, TASK_INTERRUPTIBLE, MAX_SCHEDULE_TIMEOUT); + } + remove_wait_queue(&rfcomm_wq, &wait); + + rfcomm_kill_listener(); + + return 0; +} + +static void rfcomm_security_cfm(struct hci_conn *conn, u8 status, u8 encrypt) +{ + struct rfcomm_session *s; + struct rfcomm_dlc *d, *n; + + BT_DBG("conn %p status 0x%02x encrypt 0x%02x", conn, status, encrypt); + + s = rfcomm_session_get(&conn->hdev->bdaddr, &conn->dst); + if (!s) + return; + + list_for_each_entry_safe(d, n, &s->dlcs, list) { + if (test_and_clear_bit(RFCOMM_SEC_PENDING, &d->flags)) { + rfcomm_dlc_clear_timer(d); + if (status || encrypt == 0x00) { + set_bit(RFCOMM_ENC_DROP, &d->flags); + continue; + } + } + + if (d->state == BT_CONNECTED && !status && encrypt == 0x00) { + if (d->sec_level == BT_SECURITY_MEDIUM) { + set_bit(RFCOMM_SEC_PENDING, &d->flags); + rfcomm_dlc_set_timer(d, RFCOMM_AUTH_TIMEOUT); + continue; + } else if (d->sec_level == BT_SECURITY_HIGH || + d->sec_level == BT_SECURITY_FIPS) { + set_bit(RFCOMM_ENC_DROP, &d->flags); + continue; + } + } + + if (!test_and_clear_bit(RFCOMM_AUTH_PENDING, &d->flags)) + continue; + + if (!status && hci_conn_check_secure(conn, d->sec_level)) + set_bit(RFCOMM_AUTH_ACCEPT, 
&d->flags); + else + set_bit(RFCOMM_AUTH_REJECT, &d->flags); + } + + rfcomm_schedule(); +} + +static struct hci_cb rfcomm_cb = { + .name = "RFCOMM", + .security_cfm = rfcomm_security_cfm +}; + +static int rfcomm_dlc_debugfs_show(struct seq_file *f, void *x) +{ + struct rfcomm_session *s; + + rfcomm_lock(); + + list_for_each_entry(s, &session_list, list) { + struct l2cap_chan *chan = l2cap_pi(s->sock->sk)->chan; + struct rfcomm_dlc *d; + list_for_each_entry(d, &s->dlcs, list) { + seq_printf(f, "%pMR %pMR %ld %d %d %d %d\n", + &chan->src, &chan->dst, + d->state, d->dlci, d->mtu, + d->rx_credits, d->tx_credits); + } + } + + rfcomm_unlock(); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(rfcomm_dlc_debugfs); + +static struct dentry *rfcomm_dlc_debugfs; + +/* ---- Initialization ---- */ +static int __init rfcomm_init(void) +{ + int err; + + hci_register_cb(&rfcomm_cb); + + rfcomm_thread = kthread_run(rfcomm_run, NULL, "krfcommd"); + if (IS_ERR(rfcomm_thread)) { + err = PTR_ERR(rfcomm_thread); + goto unregister; + } + + err = rfcomm_init_ttys(); + if (err < 0) + goto stop; + + err = rfcomm_init_sockets(); + if (err < 0) + goto cleanup; + + BT_INFO("RFCOMM ver %s", VERSION); + + if (IS_ERR_OR_NULL(bt_debugfs)) + return 0; + + rfcomm_dlc_debugfs = debugfs_create_file("rfcomm_dlc", 0444, + bt_debugfs, NULL, + &rfcomm_dlc_debugfs_fops); + + return 0; + +cleanup: + rfcomm_cleanup_ttys(); + +stop: + kthread_stop(rfcomm_thread); + +unregister: + hci_unregister_cb(&rfcomm_cb); + + return err; +} + +static void __exit rfcomm_exit(void) +{ + debugfs_remove(rfcomm_dlc_debugfs); + + hci_unregister_cb(&rfcomm_cb); + + kthread_stop(rfcomm_thread); + + rfcomm_cleanup_ttys(); + + rfcomm_cleanup_sockets(); +} + +module_init(rfcomm_init); +module_exit(rfcomm_exit); + +module_param(disable_cfc, bool, 0644); +MODULE_PARM_DESC(disable_cfc, "Disable credit based flow control"); + +module_param(channel_mtu, int, 0644); +MODULE_PARM_DESC(channel_mtu, "Default MTU for the RFCOMM channel"); + +module_param(l2cap_ertm, bool, 0644); +MODULE_PARM_DESC(l2cap_ertm, "Use L2CAP ERTM mode for connection"); + +MODULE_AUTHOR("Marcel Holtmann "); +MODULE_DESCRIPTION("Bluetooth RFCOMM ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("bt-proto-3"); diff --git a/net/bluetooth/rfcomm/sock.c b/net/bluetooth/rfcomm/sock.c new file mode 100644 index 000000000..4397e14ff --- /dev/null +++ b/net/bluetooth/rfcomm/sock.c @@ -0,0 +1,1093 @@ +/* + RFCOMM implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002 Maxim Krasnyansky + Copyright (C) 2002 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. 
+*/
+
+/*
+ * RFCOMM sockets.
+ */
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+static const struct proto_ops rfcomm_sock_ops;
+
+static struct bt_sock_list rfcomm_sk_list = {
+	.lock = __RW_LOCK_UNLOCKED(rfcomm_sk_list.lock)
+};
+
+static void rfcomm_sock_close(struct sock *sk);
+static void rfcomm_sock_kill(struct sock *sk);
+
+/* ---- DLC callbacks ----
+ *
+ * called under rfcomm_dlc_lock()
+ */
+static void rfcomm_sk_data_ready(struct rfcomm_dlc *d, struct sk_buff *skb)
+{
+	struct sock *sk = d->owner;
+	if (!sk)
+		return;
+
+	atomic_add(skb->len, &sk->sk_rmem_alloc);
+	skb_queue_tail(&sk->sk_receive_queue, skb);
+	sk->sk_data_ready(sk);
+
+	if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf)
+		rfcomm_dlc_throttle(d);
+}
+
+static void rfcomm_sk_state_change(struct rfcomm_dlc *d, int err)
+{
+	struct sock *sk = d->owner, *parent;
+
+	if (!sk)
+		return;
+
+	BT_DBG("dlc %p state %ld err %d", d, d->state, err);
+
+	lock_sock(sk);
+
+	if (err)
+		sk->sk_err = err;
+
+	sk->sk_state = d->state;
+
+	parent = bt_sk(sk)->parent;
+	if (parent) {
+		if (d->state == BT_CLOSED) {
+			sock_set_flag(sk, SOCK_ZAPPED);
+			bt_accept_unlink(sk);
+		}
+		parent->sk_data_ready(parent);
+	} else {
+		if (d->state == BT_CONNECTED)
+			rfcomm_session_getaddr(d->session,
+					       &rfcomm_pi(sk)->src, NULL);
+		sk->sk_state_change(sk);
+	}
+
+	release_sock(sk);
+
+	if (parent && sock_flag(sk, SOCK_ZAPPED)) {
+		/* We have to drop DLC lock here, otherwise
+		 * rfcomm_sock_destruct() will dead lock. */
+		rfcomm_dlc_unlock(d);
+		rfcomm_sock_kill(sk);
+		rfcomm_dlc_lock(d);
+	}
+}
+
+/* ---- Socket functions ---- */
+static struct sock *__rfcomm_get_listen_sock_by_addr(u8 channel, bdaddr_t *src)
+{
+	struct sock *sk = NULL;
+
+	sk_for_each(sk, &rfcomm_sk_list.head) {
+		if (rfcomm_pi(sk)->channel != channel)
+			continue;
+
+		if (bacmp(&rfcomm_pi(sk)->src, src))
+			continue;
+
+		if (sk->sk_state == BT_BOUND || sk->sk_state == BT_LISTEN)
+			break;
+	}
+
+	return sk ? sk : NULL;
+}
+
+/* Find socket with channel and source bdaddr.
+ * Returns closest match.
+ */
+static struct sock *rfcomm_get_sock_by_channel(int state, u8 channel, bdaddr_t *src)
+{
+	struct sock *sk = NULL, *sk1 = NULL;
+
+	read_lock(&rfcomm_sk_list.lock);
+
+	sk_for_each(sk, &rfcomm_sk_list.head) {
+		if (state && sk->sk_state != state)
+			continue;
+
+		if (rfcomm_pi(sk)->channel == channel) {
+			/* Exact match. */
+			if (!bacmp(&rfcomm_pi(sk)->src, src))
+				break;
+
+			/* Closest match */
+			if (!bacmp(&rfcomm_pi(sk)->src, BDADDR_ANY))
+				sk1 = sk;
+		}
+	}
+
+	read_unlock(&rfcomm_sk_list.lock);
+
+	return sk ? sk : sk1;
+}
+
+static void rfcomm_sock_destruct(struct sock *sk)
+{
+	struct rfcomm_dlc *d = rfcomm_pi(sk)->dlc;
+
+	BT_DBG("sk %p dlc %p", sk, d);
+
+	skb_queue_purge(&sk->sk_receive_queue);
+	skb_queue_purge(&sk->sk_write_queue);
+
+	rfcomm_dlc_lock(d);
+	rfcomm_pi(sk)->dlc = NULL;
+
+	/* Detach DLC if it's owned by this socket */
+	if (d->owner == sk)
+		d->owner = NULL;
+	rfcomm_dlc_unlock(d);
+
+	rfcomm_dlc_put(d);
+}
+
+static void rfcomm_sock_cleanup_listen(struct sock *parent)
+{
+	struct sock *sk;
+
+	BT_DBG("parent %p", parent);
+
+	/* Close not yet accepted dlcs */
+	while ((sk = bt_accept_dequeue(parent, NULL))) {
+		rfcomm_sock_close(sk);
+		rfcomm_sock_kill(sk);
+	}
+
+	parent->sk_state  = BT_CLOSED;
+	sock_set_flag(parent, SOCK_ZAPPED);
+}
+
+/* Kill socket (only if zapped and orphan)
+ * Must be called on unlocked socket.
+ */ +static void rfcomm_sock_kill(struct sock *sk) +{ + if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket) + return; + + BT_DBG("sk %p state %d refcnt %d", sk, sk->sk_state, refcount_read(&sk->sk_refcnt)); + + /* Kill poor orphan */ + bt_sock_unlink(&rfcomm_sk_list, sk); + sock_set_flag(sk, SOCK_DEAD); + sock_put(sk); +} + +static void __rfcomm_sock_close(struct sock *sk) +{ + struct rfcomm_dlc *d = rfcomm_pi(sk)->dlc; + + BT_DBG("sk %p state %d socket %p", sk, sk->sk_state, sk->sk_socket); + + switch (sk->sk_state) { + case BT_LISTEN: + rfcomm_sock_cleanup_listen(sk); + break; + + case BT_CONNECT: + case BT_CONNECT2: + case BT_CONFIG: + case BT_CONNECTED: + rfcomm_dlc_close(d, 0); + fallthrough; + + default: + sock_set_flag(sk, SOCK_ZAPPED); + break; + } +} + +/* Close socket. + * Must be called on unlocked socket. + */ +static void rfcomm_sock_close(struct sock *sk) +{ + lock_sock(sk); + __rfcomm_sock_close(sk); + release_sock(sk); +} + +static void rfcomm_sock_init(struct sock *sk, struct sock *parent) +{ + struct rfcomm_pinfo *pi = rfcomm_pi(sk); + + BT_DBG("sk %p", sk); + + if (parent) { + sk->sk_type = parent->sk_type; + pi->dlc->defer_setup = test_bit(BT_SK_DEFER_SETUP, + &bt_sk(parent)->flags); + + pi->sec_level = rfcomm_pi(parent)->sec_level; + pi->role_switch = rfcomm_pi(parent)->role_switch; + + security_sk_clone(parent, sk); + } else { + pi->dlc->defer_setup = 0; + + pi->sec_level = BT_SECURITY_LOW; + pi->role_switch = 0; + } + + pi->dlc->sec_level = pi->sec_level; + pi->dlc->role_switch = pi->role_switch; +} + +static struct proto rfcomm_proto = { + .name = "RFCOMM", + .owner = THIS_MODULE, + .obj_size = sizeof(struct rfcomm_pinfo) +}; + +static struct sock *rfcomm_sock_alloc(struct net *net, struct socket *sock, int proto, gfp_t prio, int kern) +{ + struct rfcomm_dlc *d; + struct sock *sk; + + sk = sk_alloc(net, PF_BLUETOOTH, prio, &rfcomm_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + INIT_LIST_HEAD(&bt_sk(sk)->accept_q); + + d = rfcomm_dlc_alloc(prio); + if (!d) { + sk_free(sk); + return NULL; + } + + d->data_ready = rfcomm_sk_data_ready; + d->state_change = rfcomm_sk_state_change; + + rfcomm_pi(sk)->dlc = d; + d->owner = sk; + + sk->sk_destruct = rfcomm_sock_destruct; + sk->sk_sndtimeo = RFCOMM_CONN_TIMEOUT; + + sk->sk_sndbuf = RFCOMM_MAX_CREDITS * RFCOMM_DEFAULT_MTU * 10; + sk->sk_rcvbuf = RFCOMM_MAX_CREDITS * RFCOMM_DEFAULT_MTU * 10; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = proto; + sk->sk_state = BT_OPEN; + + bt_sock_link(&rfcomm_sk_list, sk); + + BT_DBG("sk %p", sk); + return sk; +} + +static int rfcomm_sock_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + sock->state = SS_UNCONNECTED; + + if (sock->type != SOCK_STREAM && sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + sock->ops = &rfcomm_sock_ops; + + sk = rfcomm_sock_alloc(net, sock, protocol, GFP_ATOMIC, kern); + if (!sk) + return -ENOMEM; + + rfcomm_sock_init(sk, NULL); + return 0; +} + +static int rfcomm_sock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + struct sockaddr_rc sa; + struct sock *sk = sock->sk; + int len, err = 0; + + if (!addr || addr_len < offsetofend(struct sockaddr, sa_family) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + memset(&sa, 0, sizeof(sa)); + len = min_t(unsigned int, sizeof(sa), addr_len); + memcpy(&sa, addr, len); + + BT_DBG("sk %p %pMR", sk, &sa.rc_bdaddr); + + lock_sock(sk); + + if (sk->sk_state != BT_OPEN) { + err = 
-EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + goto done; + } + + write_lock(&rfcomm_sk_list.lock); + + if (sa.rc_channel && + __rfcomm_get_listen_sock_by_addr(sa.rc_channel, &sa.rc_bdaddr)) { + err = -EADDRINUSE; + } else { + /* Save source address */ + bacpy(&rfcomm_pi(sk)->src, &sa.rc_bdaddr); + rfcomm_pi(sk)->channel = sa.rc_channel; + sk->sk_state = BT_BOUND; + } + + write_unlock(&rfcomm_sk_list.lock); + +done: + release_sock(sk); + return err; +} + +static int rfcomm_sock_connect(struct socket *sock, struct sockaddr *addr, int alen, int flags) +{ + struct sockaddr_rc *sa = (struct sockaddr_rc *) addr; + struct sock *sk = sock->sk; + struct rfcomm_dlc *d = rfcomm_pi(sk)->dlc; + int err = 0; + + BT_DBG("sk %p", sk); + + if (alen < sizeof(struct sockaddr_rc) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + sock_hold(sk); + lock_sock(sk); + + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + goto done; + } + + sk->sk_state = BT_CONNECT; + bacpy(&rfcomm_pi(sk)->dst, &sa->rc_bdaddr); + rfcomm_pi(sk)->channel = sa->rc_channel; + + d->sec_level = rfcomm_pi(sk)->sec_level; + d->role_switch = rfcomm_pi(sk)->role_switch; + + /* Drop sock lock to avoid potential deadlock with the RFCOMM lock */ + release_sock(sk); + err = rfcomm_dlc_open(d, &rfcomm_pi(sk)->src, &sa->rc_bdaddr, + sa->rc_channel); + lock_sock(sk); + if (!err && !sock_flag(sk, SOCK_ZAPPED)) + err = bt_sock_wait_state(sk, BT_CONNECTED, + sock_sndtimeo(sk, flags & O_NONBLOCK)); + +done: + release_sock(sk); + sock_put(sk); + return err; +} + +static int rfcomm_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sk %p backlog %d", sk, backlog); + + lock_sock(sk); + + if (sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + goto done; + } + + if (!rfcomm_pi(sk)->channel) { + bdaddr_t *src = &rfcomm_pi(sk)->src; + u8 channel; + + err = -EINVAL; + + write_lock(&rfcomm_sk_list.lock); + + for (channel = 1; channel < 31; channel++) + if (!__rfcomm_get_listen_sock_by_addr(channel, src)) { + rfcomm_pi(sk)->channel = channel; + err = 0; + break; + } + + write_unlock(&rfcomm_sk_list.lock); + + if (err < 0) + goto done; + } + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + sk->sk_state = BT_LISTEN; + +done: + release_sock(sk); + return err; +} + +static int rfcomm_sock_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = sock->sk, *nsk; + long timeo; + int err = 0; + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + goto done; + } + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + BT_DBG("sk %p timeo %ld", sk, timeo); + + /* Wait for an incoming connection. (wake-one). 
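
The bind(), connect(), listen() and accept() handlers above are what a BlueZ userspace RFCOMM client or server talks to. As a non-normative illustration (not part of the patch itself), a minimal client could look like the sketch below; the peer address and channel are placeholders.

/* Minimal RFCOMM client sketch; assumes the libbluetooth headers. */
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/rfcomm.h>

int main(void)
{
        struct sockaddr_rc addr;
        int s, err;

        s = socket(AF_BLUETOOTH, SOCK_STREAM, BTPROTO_RFCOMM);
        if (s < 0)
                return 1;

        memset(&addr, 0, sizeof(addr));
        addr.rc_family = AF_BLUETOOTH;
        addr.rc_channel = 3;                            /* hypothetical server channel */
        str2ba("00:11:22:33:44:55", &addr.rc_bdaddr);   /* placeholder peer address */

        err = connect(s, (struct sockaddr *)&addr, sizeof(addr));
        if (!err)
                err = write(s, "hello", 5) < 0;

        close(s);
        return err ? 1 : 0;
}

A server does the mirror image: bind() with rc_channel set (or 0, letting rfcomm_sock_listen() above pick a free channel), then listen() and accept().
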
*/ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (1) { + if (sk->sk_state != BT_LISTEN) { + err = -EBADFD; + break; + } + + nsk = bt_accept_dequeue(sk, newsock); + if (nsk) + break; + + if (!timeo) { + err = -EAGAIN; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + } + remove_wait_queue(sk_sleep(sk), &wait); + + if (err) + goto done; + + newsock->state = SS_CONNECTED; + + BT_DBG("new socket %p", nsk); + +done: + release_sock(sk); + return err; +} + +static int rfcomm_sock_getname(struct socket *sock, struct sockaddr *addr, int peer) +{ + struct sockaddr_rc *sa = (struct sockaddr_rc *) addr; + struct sock *sk = sock->sk; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (peer && sk->sk_state != BT_CONNECTED && + sk->sk_state != BT_CONNECT && sk->sk_state != BT_CONNECT2) + return -ENOTCONN; + + memset(sa, 0, sizeof(*sa)); + sa->rc_family = AF_BLUETOOTH; + sa->rc_channel = rfcomm_pi(sk)->channel; + if (peer) + bacpy(&sa->rc_bdaddr, &rfcomm_pi(sk)->dst); + else + bacpy(&sa->rc_bdaddr, &rfcomm_pi(sk)->src); + + return sizeof(struct sockaddr_rc); +} + +static int rfcomm_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct rfcomm_dlc *d = rfcomm_pi(sk)->dlc; + struct sk_buff *skb; + int sent; + + if (test_bit(RFCOMM_DEFER_SETUP, &d->flags)) + return -ENOTCONN; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + if (sk->sk_shutdown & SEND_SHUTDOWN) + return -EPIPE; + + BT_DBG("sock %p, sk %p", sock, sk); + + lock_sock(sk); + + sent = bt_sock_wait_ready(sk, msg->msg_flags); + + release_sock(sk); + + if (sent) + return sent; + + skb = bt_skb_sendmmsg(sk, msg, len, d->mtu, RFCOMM_SKB_HEAD_RESERVE, + RFCOMM_SKB_TAIL_RESERVE); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + sent = rfcomm_dlc_send(d, skb); + if (sent < 0) + kfree_skb(skb); + + return sent; +} + +static int rfcomm_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + struct rfcomm_dlc *d = rfcomm_pi(sk)->dlc; + int len; + + if (test_and_clear_bit(RFCOMM_DEFER_SETUP, &d->flags)) { + rfcomm_dlc_accept(d); + return 0; + } + + len = bt_sock_stream_recvmsg(sock, msg, size, flags); + + lock_sock(sk); + if (!(flags & MSG_PEEK) && len > 0) + atomic_sub(len, &sk->sk_rmem_alloc); + + if (atomic_read(&sk->sk_rmem_alloc) <= (sk->sk_rcvbuf >> 2)) + rfcomm_dlc_unthrottle(rfcomm_pi(sk)->dlc); + release_sock(sk); + + return len; +} + +static int rfcomm_sock_setsockopt_old(struct socket *sock, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + int err = 0; + u32 opt; + + BT_DBG("sk %p", sk); + + lock_sock(sk); + + switch (optname) { + case RFCOMM_LM: + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt & RFCOMM_LM_FIPS) { + err = -EINVAL; + break; + } + + if (opt & RFCOMM_LM_AUTH) + rfcomm_pi(sk)->sec_level = BT_SECURITY_LOW; + if (opt & RFCOMM_LM_ENCRYPT) + rfcomm_pi(sk)->sec_level = BT_SECURITY_MEDIUM; + if (opt & RFCOMM_LM_SECURE) + rfcomm_pi(sk)->sec_level = BT_SECURITY_HIGH; + + rfcomm_pi(sk)->role_switch = (opt & RFCOMM_LM_MASTER); + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int rfcomm_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = 
sock->sk; + struct bt_security sec; + int err = 0; + size_t len; + u32 opt; + + BT_DBG("sk %p", sk); + + if (level == SOL_RFCOMM) + return rfcomm_sock_setsockopt_old(sock, optname, optval, optlen); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + lock_sock(sk); + + switch (optname) { + case BT_SECURITY: + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + break; + } + + sec.level = BT_SECURITY_LOW; + + len = min_t(unsigned int, sizeof(sec), optlen); + if (copy_from_sockptr(&sec, optval, len)) { + err = -EFAULT; + break; + } + + if (sec.level > BT_SECURITY_HIGH) { + err = -EINVAL; + break; + } + + rfcomm_pi(sk)->sec_level = sec.level; + break; + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt) + set_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + else + clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int rfcomm_sock_getsockopt_old(struct socket *sock, int optname, char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct sock *l2cap_sk; + struct l2cap_conn *conn; + struct rfcomm_conninfo cinfo; + int len, err = 0; + u32 opt; + + BT_DBG("sk %p", sk); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case RFCOMM_LM: + switch (rfcomm_pi(sk)->sec_level) { + case BT_SECURITY_LOW: + opt = RFCOMM_LM_AUTH; + break; + case BT_SECURITY_MEDIUM: + opt = RFCOMM_LM_AUTH | RFCOMM_LM_ENCRYPT; + break; + case BT_SECURITY_HIGH: + opt = RFCOMM_LM_AUTH | RFCOMM_LM_ENCRYPT | + RFCOMM_LM_SECURE; + break; + case BT_SECURITY_FIPS: + opt = RFCOMM_LM_AUTH | RFCOMM_LM_ENCRYPT | + RFCOMM_LM_SECURE | RFCOMM_LM_FIPS; + break; + default: + opt = 0; + break; + } + + if (rfcomm_pi(sk)->role_switch) + opt |= RFCOMM_LM_MASTER; + + if (put_user(opt, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case RFCOMM_CONNINFO: + if (sk->sk_state != BT_CONNECTED && + !rfcomm_pi(sk)->dlc->defer_setup) { + err = -ENOTCONN; + break; + } + + l2cap_sk = rfcomm_pi(sk)->dlc->session->sock->sk; + conn = l2cap_pi(l2cap_sk)->chan->conn; + + memset(&cinfo, 0, sizeof(cinfo)); + cinfo.hci_handle = conn->hcon->handle; + memcpy(cinfo.dev_class, conn->hcon->dev_class, 3); + + len = min_t(unsigned int, len, sizeof(cinfo)); + if (copy_to_user(optval, (char *) &cinfo, len)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int rfcomm_sock_getsockopt(struct socket *sock, int level, int optname, char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct bt_security sec; + int len, err = 0; + + BT_DBG("sk %p", sk); + + if (level == SOL_RFCOMM) + return rfcomm_sock_getsockopt_old(sock, optname, optval, optlen); + + if (level != SOL_BLUETOOTH) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case BT_SECURITY: + if (sk->sk_type != SOCK_STREAM) { + err = -EINVAL; + break; + } + + sec.level = rfcomm_pi(sk)->sec_level; + sec.key_size = 0; + + len = min_t(unsigned int, len, sizeof(sec)); + if (copy_to_user(optval, (char *) &sec, len)) + err = -EFAULT; + + break; + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (put_user(test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags), + (u32 
__user *) optval)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int rfcomm_sock_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk __maybe_unused = sock->sk; + int err; + + BT_DBG("sk %p cmd %x arg %lx", sk, cmd, arg); + + err = bt_sock_ioctl(sock, cmd, arg); + + if (err == -ENOIOCTLCMD) { +#ifdef CONFIG_BT_RFCOMM_TTY + lock_sock(sk); + err = rfcomm_dev_ioctl(sk, cmd, (void __user *) arg); + release_sock(sk); +#else + err = -EOPNOTSUPP; +#endif + } + + return err; +} + +#ifdef CONFIG_COMPAT +static int rfcomm_sock_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return rfcomm_sock_ioctl(sock, cmd, (unsigned long)compat_ptr(arg)); +} +#endif + +static int rfcomm_sock_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + lock_sock(sk); + if (!sk->sk_shutdown) { + sk->sk_shutdown = SHUTDOWN_MASK; + + release_sock(sk); + __rfcomm_sock_close(sk); + lock_sock(sk); + + if (sock_flag(sk, SOCK_LINGER) && sk->sk_lingertime && + !(current->flags & PF_EXITING)) + err = bt_sock_wait_state(sk, BT_CLOSED, sk->sk_lingertime); + } + release_sock(sk); + return err; +} + +static int rfcomm_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + int err; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + err = rfcomm_sock_shutdown(sock, 2); + + sock_orphan(sk); + rfcomm_sock_kill(sk); + return err; +} + +/* ---- RFCOMM core layer callbacks ---- + * + * called under rfcomm_lock() + */ +int rfcomm_connect_ind(struct rfcomm_session *s, u8 channel, struct rfcomm_dlc **d) +{ + struct sock *sk, *parent; + bdaddr_t src, dst; + int result = 0; + + BT_DBG("session %p channel %d", s, channel); + + rfcomm_session_getaddr(s, &src, &dst); + + /* Check if we have socket listening on channel */ + parent = rfcomm_get_sock_by_channel(BT_LISTEN, channel, &src); + if (!parent) + return 0; + + lock_sock(parent); + + /* Check for backlog size */ + if (sk_acceptq_is_full(parent)) { + BT_DBG("backlog full %d", parent->sk_ack_backlog); + goto done; + } + + sk = rfcomm_sock_alloc(sock_net(parent), NULL, BTPROTO_RFCOMM, GFP_ATOMIC, 0); + if (!sk) + goto done; + + bt_sock_reclassify_lock(sk, BTPROTO_RFCOMM); + + rfcomm_sock_init(sk, parent); + bacpy(&rfcomm_pi(sk)->src, &src); + bacpy(&rfcomm_pi(sk)->dst, &dst); + rfcomm_pi(sk)->channel = channel; + + sk->sk_state = BT_CONFIG; + bt_accept_enqueue(parent, sk, true); + + /* Accept connection and return socket DLC */ + *d = rfcomm_pi(sk)->dlc; + result = 1; + +done: + release_sock(parent); + + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags)) + parent->sk_state_change(parent); + + return result; +} + +static int rfcomm_sock_debugfs_show(struct seq_file *f, void *p) +{ + struct sock *sk; + + read_lock(&rfcomm_sk_list.lock); + + sk_for_each(sk, &rfcomm_sk_list.head) { + seq_printf(f, "%pMR %pMR %d %d\n", + &rfcomm_pi(sk)->src, &rfcomm_pi(sk)->dst, + sk->sk_state, rfcomm_pi(sk)->channel); + } + + read_unlock(&rfcomm_sk_list.lock); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(rfcomm_sock_debugfs); + +static struct dentry *rfcomm_sock_debugfs; + +static const struct proto_ops rfcomm_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = rfcomm_sock_release, + .bind = rfcomm_sock_bind, + .connect = rfcomm_sock_connect, + .listen = rfcomm_sock_listen, + .accept = rfcomm_sock_accept, + .getname = 
rfcomm_sock_getname, + .sendmsg = rfcomm_sock_sendmsg, + .recvmsg = rfcomm_sock_recvmsg, + .shutdown = rfcomm_sock_shutdown, + .setsockopt = rfcomm_sock_setsockopt, + .getsockopt = rfcomm_sock_getsockopt, + .ioctl = rfcomm_sock_ioctl, + .gettstamp = sock_gettstamp, + .poll = bt_sock_poll, + .socketpair = sock_no_socketpair, + .mmap = sock_no_mmap, +#ifdef CONFIG_COMPAT + .compat_ioctl = rfcomm_sock_compat_ioctl, +#endif +}; + +static const struct net_proto_family rfcomm_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = rfcomm_sock_create +}; + +int __init rfcomm_init_sockets(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct sockaddr_rc) > sizeof(struct sockaddr)); + + err = proto_register(&rfcomm_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_RFCOMM, &rfcomm_sock_family_ops); + if (err < 0) { + BT_ERR("RFCOMM socket layer registration failed"); + goto error; + } + + err = bt_procfs_init(&init_net, "rfcomm", &rfcomm_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create RFCOMM proc file"); + bt_sock_unregister(BTPROTO_RFCOMM); + goto error; + } + + BT_INFO("RFCOMM socket layer initialized"); + + if (IS_ERR_OR_NULL(bt_debugfs)) + return 0; + + rfcomm_sock_debugfs = debugfs_create_file("rfcomm", 0444, + bt_debugfs, NULL, + &rfcomm_sock_debugfs_fops); + + return 0; + +error: + proto_unregister(&rfcomm_proto); + return err; +} + +void __exit rfcomm_cleanup_sockets(void) +{ + bt_procfs_cleanup(&init_net, "rfcomm"); + + debugfs_remove(rfcomm_sock_debugfs); + + bt_sock_unregister(BTPROTO_RFCOMM); + + proto_unregister(&rfcomm_proto); +} diff --git a/net/bluetooth/rfcomm/tty.c b/net/bluetooth/rfcomm/tty.c new file mode 100644 index 000000000..8009e0e93 --- /dev/null +++ b/net/bluetooth/rfcomm/tty.c @@ -0,0 +1,1162 @@ +/* + RFCOMM implementation for Linux Bluetooth stack (BlueZ). + Copyright (C) 2002 Maxim Krasnyansky + Copyright (C) 2002 Marcel Holtmann + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* + * RFCOMM TTY. 
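
Before moving on to the TTY layer, here is a short sketch of how the RFCOMM_LM and RFCOMM_CONNINFO handlers above (alongside the newer SOL_BLUETOOTH BT_SECURITY path) are typically driven from userspace; the values are illustrative assumptions only.

/* Sketch: raise the link mode before connect() and read the connection
 * info once connected; assumes the libbluetooth headers.
 */
#include <stdio.h>
#include <stdint.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/rfcomm.h>

static void rfcomm_lm_demo(int s)
{
        /* AUTH | ENCRYPT maps to BT_SECURITY_MEDIUM in the handler above. */
        uint32_t lm = RFCOMM_LM_AUTH | RFCOMM_LM_ENCRYPT;
        struct rfcomm_conninfo ci;
        socklen_t len = sizeof(ci);

        setsockopt(s, SOL_RFCOMM, RFCOMM_LM, &lm, sizeof(lm));

        /* Valid once the socket is connected (or deferred setup is used). */
        if (getsockopt(s, SOL_RFCOMM, RFCOMM_CONNINFO, &ci, &len) == 0)
                printf("ACL handle %u\n", ci.hci_handle);
}
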
+ */ + +#include + +#include +#include +#include + +#include +#include +#include + +#define RFCOMM_TTY_PORTS RFCOMM_MAX_DEV /* whole lotta rfcomm devices */ +#define RFCOMM_TTY_MAJOR 216 /* device node major id of the usb/bluetooth.c driver */ +#define RFCOMM_TTY_MINOR 0 + +static DEFINE_MUTEX(rfcomm_ioctl_mutex); +static struct tty_driver *rfcomm_tty_driver; + +struct rfcomm_dev { + struct tty_port port; + struct list_head list; + + char name[12]; + int id; + unsigned long flags; + int err; + + unsigned long status; /* don't export to userspace */ + + bdaddr_t src; + bdaddr_t dst; + u8 channel; + + uint modem_status; + + struct rfcomm_dlc *dlc; + + struct device *tty_dev; + + atomic_t wmem_alloc; + + struct sk_buff_head pending; +}; + +static LIST_HEAD(rfcomm_dev_list); +static DEFINE_MUTEX(rfcomm_dev_lock); + +static void rfcomm_dev_data_ready(struct rfcomm_dlc *dlc, struct sk_buff *skb); +static void rfcomm_dev_state_change(struct rfcomm_dlc *dlc, int err); +static void rfcomm_dev_modem_status(struct rfcomm_dlc *dlc, u8 v24_sig); + +/* ---- Device functions ---- */ + +static void rfcomm_dev_destruct(struct tty_port *port) +{ + struct rfcomm_dev *dev = container_of(port, struct rfcomm_dev, port); + struct rfcomm_dlc *dlc = dev->dlc; + + BT_DBG("dev %p dlc %p", dev, dlc); + + rfcomm_dlc_lock(dlc); + /* Detach DLC if it's owned by this dev */ + if (dlc->owner == dev) + dlc->owner = NULL; + rfcomm_dlc_unlock(dlc); + + rfcomm_dlc_put(dlc); + + if (dev->tty_dev) + tty_unregister_device(rfcomm_tty_driver, dev->id); + + mutex_lock(&rfcomm_dev_lock); + list_del(&dev->list); + mutex_unlock(&rfcomm_dev_lock); + + kfree(dev); + + /* It's safe to call module_put() here because socket still + holds reference to this module. */ + module_put(THIS_MODULE); +} + +/* device-specific initialization: open the dlc */ +static int rfcomm_dev_activate(struct tty_port *port, struct tty_struct *tty) +{ + struct rfcomm_dev *dev = container_of(port, struct rfcomm_dev, port); + int err; + + err = rfcomm_dlc_open(dev->dlc, &dev->src, &dev->dst, dev->channel); + if (err) + set_bit(TTY_IO_ERROR, &tty->flags); + return err; +} + +/* we block the open until the dlc->state becomes BT_CONNECTED */ +static int rfcomm_dev_carrier_raised(struct tty_port *port) +{ + struct rfcomm_dev *dev = container_of(port, struct rfcomm_dev, port); + + return (dev->dlc->state == BT_CONNECTED); +} + +/* device-specific cleanup: close the dlc */ +static void rfcomm_dev_shutdown(struct tty_port *port) +{ + struct rfcomm_dev *dev = container_of(port, struct rfcomm_dev, port); + + if (dev->tty_dev->parent) + device_move(dev->tty_dev, NULL, DPM_ORDER_DEV_LAST); + + /* close the dlc */ + rfcomm_dlc_close(dev->dlc, 0); +} + +static const struct tty_port_operations rfcomm_port_ops = { + .destruct = rfcomm_dev_destruct, + .activate = rfcomm_dev_activate, + .shutdown = rfcomm_dev_shutdown, + .carrier_raised = rfcomm_dev_carrier_raised, +}; + +static struct rfcomm_dev *__rfcomm_dev_lookup(int id) +{ + struct rfcomm_dev *dev; + + list_for_each_entry(dev, &rfcomm_dev_list, list) + if (dev->id == id) + return dev; + + return NULL; +} + +static struct rfcomm_dev *rfcomm_dev_get(int id) +{ + struct rfcomm_dev *dev; + + mutex_lock(&rfcomm_dev_lock); + + dev = __rfcomm_dev_lookup(id); + + if (dev && !tty_port_get(&dev->port)) + dev = NULL; + + mutex_unlock(&rfcomm_dev_lock); + + return dev; +} + +static void rfcomm_reparent_device(struct rfcomm_dev *dev) +{ + struct hci_dev *hdev; + struct hci_conn *conn; + + hdev = hci_get_route(&dev->dst, &dev->src, 
BDADDR_BREDR); + if (!hdev) + return; + + /* The lookup results are unsafe to access without the + * hci device lock (FIXME: why is this not documented?) + */ + hci_dev_lock(hdev); + conn = hci_conn_hash_lookup_ba(hdev, ACL_LINK, &dev->dst); + + /* Just because the acl link is in the hash table is no + * guarantee the sysfs device has been added ... + */ + if (conn && device_is_registered(&conn->dev)) + device_move(dev->tty_dev, &conn->dev, DPM_ORDER_DEV_AFTER_PARENT); + + hci_dev_unlock(hdev); + hci_dev_put(hdev); +} + +static ssize_t address_show(struct device *tty_dev, + struct device_attribute *attr, char *buf) +{ + struct rfcomm_dev *dev = dev_get_drvdata(tty_dev); + return sprintf(buf, "%pMR\n", &dev->dst); +} + +static ssize_t channel_show(struct device *tty_dev, + struct device_attribute *attr, char *buf) +{ + struct rfcomm_dev *dev = dev_get_drvdata(tty_dev); + return sprintf(buf, "%d\n", dev->channel); +} + +static DEVICE_ATTR_RO(address); +static DEVICE_ATTR_RO(channel); + +static struct rfcomm_dev *__rfcomm_dev_add(struct rfcomm_dev_req *req, + struct rfcomm_dlc *dlc) +{ + struct rfcomm_dev *dev, *entry; + struct list_head *head = &rfcomm_dev_list; + int err = 0; + + dev = kzalloc(sizeof(struct rfcomm_dev), GFP_KERNEL); + if (!dev) + return ERR_PTR(-ENOMEM); + + mutex_lock(&rfcomm_dev_lock); + + if (req->dev_id < 0) { + dev->id = 0; + + list_for_each_entry(entry, &rfcomm_dev_list, list) { + if (entry->id != dev->id) + break; + + dev->id++; + head = &entry->list; + } + } else { + dev->id = req->dev_id; + + list_for_each_entry(entry, &rfcomm_dev_list, list) { + if (entry->id == dev->id) { + err = -EADDRINUSE; + goto out; + } + + if (entry->id > dev->id - 1) + break; + + head = &entry->list; + } + } + + if ((dev->id < 0) || (dev->id > RFCOMM_MAX_DEV - 1)) { + err = -ENFILE; + goto out; + } + + sprintf(dev->name, "rfcomm%d", dev->id); + + list_add(&dev->list, head); + + bacpy(&dev->src, &req->src); + bacpy(&dev->dst, &req->dst); + dev->channel = req->channel; + + dev->flags = req->flags & + ((1 << RFCOMM_RELEASE_ONHUP) | (1 << RFCOMM_REUSE_DLC)); + + tty_port_init(&dev->port); + dev->port.ops = &rfcomm_port_ops; + + skb_queue_head_init(&dev->pending); + + rfcomm_dlc_lock(dlc); + + if (req->flags & (1 << RFCOMM_REUSE_DLC)) { + struct sock *sk = dlc->owner; + struct sk_buff *skb; + + BUG_ON(!sk); + + rfcomm_dlc_throttle(dlc); + + while ((skb = skb_dequeue(&sk->sk_receive_queue))) { + skb_orphan(skb); + skb_queue_tail(&dev->pending, skb); + atomic_sub(skb->len, &sk->sk_rmem_alloc); + } + } + + dlc->data_ready = rfcomm_dev_data_ready; + dlc->state_change = rfcomm_dev_state_change; + dlc->modem_status = rfcomm_dev_modem_status; + + dlc->owner = dev; + dev->dlc = dlc; + + rfcomm_dev_modem_status(dlc, dlc->remote_v24_sig); + + rfcomm_dlc_unlock(dlc); + + /* It's safe to call __module_get() here because socket already + holds reference to this module. 
*/ + __module_get(THIS_MODULE); + + mutex_unlock(&rfcomm_dev_lock); + return dev; + +out: + mutex_unlock(&rfcomm_dev_lock); + kfree(dev); + return ERR_PTR(err); +} + +static int rfcomm_dev_add(struct rfcomm_dev_req *req, struct rfcomm_dlc *dlc) +{ + struct rfcomm_dev *dev; + struct device *tty; + + BT_DBG("id %d channel %d", req->dev_id, req->channel); + + dev = __rfcomm_dev_add(req, dlc); + if (IS_ERR(dev)) { + rfcomm_dlc_put(dlc); + return PTR_ERR(dev); + } + + tty = tty_port_register_device(&dev->port, rfcomm_tty_driver, + dev->id, NULL); + if (IS_ERR(tty)) { + tty_port_put(&dev->port); + return PTR_ERR(tty); + } + + dev->tty_dev = tty; + rfcomm_reparent_device(dev); + dev_set_drvdata(dev->tty_dev, dev); + + if (device_create_file(dev->tty_dev, &dev_attr_address) < 0) + BT_ERR("Failed to create address attribute"); + + if (device_create_file(dev->tty_dev, &dev_attr_channel) < 0) + BT_ERR("Failed to create channel attribute"); + + return dev->id; +} + +/* ---- Send buffer ---- */ +static inline unsigned int rfcomm_room(struct rfcomm_dev *dev) +{ + struct rfcomm_dlc *dlc = dev->dlc; + + /* Limit the outstanding number of packets not yet sent to 40 */ + int pending = 40 - atomic_read(&dev->wmem_alloc); + + return max(0, pending) * dlc->mtu; +} + +static void rfcomm_wfree(struct sk_buff *skb) +{ + struct rfcomm_dev *dev = (void *) skb->sk; + atomic_dec(&dev->wmem_alloc); + if (test_bit(RFCOMM_TTY_ATTACHED, &dev->flags)) + tty_port_tty_wakeup(&dev->port); + tty_port_put(&dev->port); +} + +static void rfcomm_set_owner_w(struct sk_buff *skb, struct rfcomm_dev *dev) +{ + tty_port_get(&dev->port); + atomic_inc(&dev->wmem_alloc); + skb->sk = (void *) dev; + skb->destructor = rfcomm_wfree; +} + +static struct sk_buff *rfcomm_wmalloc(struct rfcomm_dev *dev, unsigned long size, gfp_t priority) +{ + struct sk_buff *skb = alloc_skb(size, priority); + if (skb) + rfcomm_set_owner_w(skb, dev); + return skb; +} + +/* ---- Device IOCTLs ---- */ + +#define NOCAP_FLAGS ((1 << RFCOMM_REUSE_DLC) | (1 << RFCOMM_RELEASE_ONHUP)) + +static int __rfcomm_create_dev(struct sock *sk, void __user *arg) +{ + struct rfcomm_dev_req req; + struct rfcomm_dlc *dlc; + int id; + + if (copy_from_user(&req, arg, sizeof(req))) + return -EFAULT; + + BT_DBG("sk %p dev_id %d flags 0x%x", sk, req.dev_id, req.flags); + + if (req.flags != NOCAP_FLAGS && !capable(CAP_NET_ADMIN)) + return -EPERM; + + if (req.flags & (1 << RFCOMM_REUSE_DLC)) { + /* Socket must be connected */ + if (sk->sk_state != BT_CONNECTED) + return -EBADFD; + + dlc = rfcomm_pi(sk)->dlc; + rfcomm_dlc_hold(dlc); + } else { + /* Validate the channel is unused */ + dlc = rfcomm_dlc_exists(&req.src, &req.dst, req.channel); + if (IS_ERR(dlc)) + return PTR_ERR(dlc); + if (dlc) + return -EBUSY; + dlc = rfcomm_dlc_alloc(GFP_KERNEL); + if (!dlc) + return -ENOMEM; + } + + id = rfcomm_dev_add(&req, dlc); + if (id < 0) + return id; + + if (req.flags & (1 << RFCOMM_REUSE_DLC)) { + /* DLC is now used by device. 
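
The address and channel device attributes registered above are readable from userspace; the exact sysfs location below is an assumption (the tty class directory for the new node), so treat this only as a sketch.

/* Sketch: print the peer address of an rfcomm%d node via sysfs. */
#include <stdio.h>

static void show_rfcomm_peer(int id)
{
        char path[64], line[32];
        FILE *f;

        snprintf(path, sizeof(path), "/sys/class/tty/rfcomm%d/address", id);
        f = fopen(path, "r");
        if (!f)
                return;
        if (fgets(line, sizeof(line), f))
                printf("rfcomm%d peer %s", id, line);
        fclose(f);
}
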
+ * Socket must be disconnected */ + sk->sk_state = BT_CLOSED; + } + + return id; +} + +static int __rfcomm_release_dev(void __user *arg) +{ + struct rfcomm_dev_req req; + struct rfcomm_dev *dev; + struct tty_struct *tty; + + if (copy_from_user(&req, arg, sizeof(req))) + return -EFAULT; + + BT_DBG("dev_id %d flags 0x%x", req.dev_id, req.flags); + + dev = rfcomm_dev_get(req.dev_id); + if (!dev) + return -ENODEV; + + if (dev->flags != NOCAP_FLAGS && !capable(CAP_NET_ADMIN)) { + tty_port_put(&dev->port); + return -EPERM; + } + + /* only release once */ + if (test_and_set_bit(RFCOMM_DEV_RELEASED, &dev->status)) { + tty_port_put(&dev->port); + return -EALREADY; + } + + if (req.flags & (1 << RFCOMM_HANGUP_NOW)) + rfcomm_dlc_close(dev->dlc, 0); + + /* Shut down TTY synchronously before freeing rfcomm_dev */ + tty = tty_port_tty_get(&dev->port); + if (tty) { + tty_vhangup(tty); + tty_kref_put(tty); + } + + if (!test_bit(RFCOMM_TTY_OWNED, &dev->status)) + tty_port_put(&dev->port); + + tty_port_put(&dev->port); + return 0; +} + +static int rfcomm_create_dev(struct sock *sk, void __user *arg) +{ + int ret; + + mutex_lock(&rfcomm_ioctl_mutex); + ret = __rfcomm_create_dev(sk, arg); + mutex_unlock(&rfcomm_ioctl_mutex); + + return ret; +} + +static int rfcomm_release_dev(void __user *arg) +{ + int ret; + + mutex_lock(&rfcomm_ioctl_mutex); + ret = __rfcomm_release_dev(arg); + mutex_unlock(&rfcomm_ioctl_mutex); + + return ret; +} + +static int rfcomm_get_dev_list(void __user *arg) +{ + struct rfcomm_dev *dev; + struct rfcomm_dev_list_req *dl; + struct rfcomm_dev_info *di; + int n = 0, size, err; + u16 dev_num; + + BT_DBG(""); + + if (get_user(dev_num, (u16 __user *) arg)) + return -EFAULT; + + if (!dev_num || dev_num > (PAGE_SIZE * 4) / sizeof(*di)) + return -EINVAL; + + size = sizeof(*dl) + dev_num * sizeof(*di); + + dl = kzalloc(size, GFP_KERNEL); + if (!dl) + return -ENOMEM; + + di = dl->dev_info; + + mutex_lock(&rfcomm_dev_lock); + + list_for_each_entry(dev, &rfcomm_dev_list, list) { + if (!tty_port_get(&dev->port)) + continue; + (di + n)->id = dev->id; + (di + n)->flags = dev->flags; + (di + n)->state = dev->dlc->state; + (di + n)->channel = dev->channel; + bacpy(&(di + n)->src, &dev->src); + bacpy(&(di + n)->dst, &dev->dst); + tty_port_put(&dev->port); + if (++n >= dev_num) + break; + } + + mutex_unlock(&rfcomm_dev_lock); + + dl->dev_num = n; + size = sizeof(*dl) + n * sizeof(*di); + + err = copy_to_user(arg, dl, size); + kfree(dl); + + return err ? 
-EFAULT : 0; +} + +static int rfcomm_get_dev_info(void __user *arg) +{ + struct rfcomm_dev *dev; + struct rfcomm_dev_info di; + int err = 0; + + BT_DBG(""); + + if (copy_from_user(&di, arg, sizeof(di))) + return -EFAULT; + + dev = rfcomm_dev_get(di.id); + if (!dev) + return -ENODEV; + + di.flags = dev->flags; + di.channel = dev->channel; + di.state = dev->dlc->state; + bacpy(&di.src, &dev->src); + bacpy(&di.dst, &dev->dst); + + if (copy_to_user(arg, &di, sizeof(di))) + err = -EFAULT; + + tty_port_put(&dev->port); + return err; +} + +int rfcomm_dev_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) +{ + BT_DBG("cmd %d arg %p", cmd, arg); + + switch (cmd) { + case RFCOMMCREATEDEV: + return rfcomm_create_dev(sk, arg); + + case RFCOMMRELEASEDEV: + return rfcomm_release_dev(arg); + + case RFCOMMGETDEVLIST: + return rfcomm_get_dev_list(arg); + + case RFCOMMGETDEVINFO: + return rfcomm_get_dev_info(arg); + } + + return -EINVAL; +} + +/* ---- DLC callbacks ---- */ +static void rfcomm_dev_data_ready(struct rfcomm_dlc *dlc, struct sk_buff *skb) +{ + struct rfcomm_dev *dev = dlc->owner; + + if (!dev) { + kfree_skb(skb); + return; + } + + if (!skb_queue_empty(&dev->pending)) { + skb_queue_tail(&dev->pending, skb); + return; + } + + BT_DBG("dlc %p len %d", dlc, skb->len); + + tty_insert_flip_string(&dev->port, skb->data, skb->len); + tty_flip_buffer_push(&dev->port); + + kfree_skb(skb); +} + +static void rfcomm_dev_state_change(struct rfcomm_dlc *dlc, int err) +{ + struct rfcomm_dev *dev = dlc->owner; + if (!dev) + return; + + BT_DBG("dlc %p dev %p err %d", dlc, dev, err); + + dev->err = err; + if (dlc->state == BT_CONNECTED) { + rfcomm_reparent_device(dev); + + wake_up_interruptible(&dev->port.open_wait); + } else if (dlc->state == BT_CLOSED) + tty_port_tty_hangup(&dev->port, false); +} + +static void rfcomm_dev_modem_status(struct rfcomm_dlc *dlc, u8 v24_sig) +{ + struct rfcomm_dev *dev = dlc->owner; + if (!dev) + return; + + BT_DBG("dlc %p dev %p v24_sig 0x%02x", dlc, dev, v24_sig); + + if ((dev->modem_status & TIOCM_CD) && !(v24_sig & RFCOMM_V24_DV)) + tty_port_tty_hangup(&dev->port, true); + + dev->modem_status = + ((v24_sig & RFCOMM_V24_RTC) ? (TIOCM_DSR | TIOCM_DTR) : 0) | + ((v24_sig & RFCOMM_V24_RTR) ? (TIOCM_RTS | TIOCM_CTS) : 0) | + ((v24_sig & RFCOMM_V24_IC) ? TIOCM_RI : 0) | + ((v24_sig & RFCOMM_V24_DV) ? TIOCM_CD : 0); +} + +/* ---- TTY functions ---- */ +static void rfcomm_tty_copy_pending(struct rfcomm_dev *dev) +{ + struct sk_buff *skb; + int inserted = 0; + + BT_DBG("dev %p", dev); + + rfcomm_dlc_lock(dev->dlc); + + while ((skb = skb_dequeue(&dev->pending))) { + inserted += tty_insert_flip_string(&dev->port, skb->data, + skb->len); + kfree_skb(skb); + } + + rfcomm_dlc_unlock(dev->dlc); + + if (inserted > 0) + tty_flip_buffer_push(&dev->port); +} + +/* do the reverse of install, clearing the tty fields and releasing the + * reference to tty_port + */ +static void rfcomm_tty_cleanup(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = tty->driver_data; + + clear_bit(RFCOMM_TTY_ATTACHED, &dev->flags); + + rfcomm_dlc_lock(dev->dlc); + tty->driver_data = NULL; + rfcomm_dlc_unlock(dev->dlc); + + /* + * purge the dlc->tx_queue to avoid circular dependencies + * between dev and dlc + */ + skb_queue_purge(&dev->dlc->tx_queue); + + tty_port_put(&dev->port); +} + +/* we acquire the tty_port reference since it's here the tty is first used + * by setting the termios. 
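
The RFCOMMCREATEDEV/RFCOMMRELEASEDEV ioctls handled above are normally issued on an RFCOMM socket by userspace tools. The following sketch (a hypothetical helper, libbluetooth headers assumed) exercises the RFCOMM_REUSE_DLC path of __rfcomm_create_dev() above, where a /dev/rfcommN node takes over an already-connected socket.

/* Sketch: bind a TTY to a connected RFCOMM socket "s". */
#include <stdint.h>
#include <string.h>
#include <sys/ioctl.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/rfcomm.h>

static int rfcomm_bind_tty(int s, const bdaddr_t *dst, uint8_t channel)
{
        struct rfcomm_dev_req req;

        memset(&req, 0, sizeof(req));
        req.dev_id  = -1;                       /* let the kernel pick rfcomm%d */
        req.flags   = (1 << RFCOMM_REUSE_DLC) |
                      (1 << RFCOMM_RELEASE_ONHUP);
        bacpy(&req.dst, dst);
        req.channel = channel;

        /* Returns the new device id (>= 0), or -1 with errno set. */
        return ioctl(s, RFCOMMCREATEDEV, &req);
}

Requesting exactly (1 << RFCOMM_REUSE_DLC) | (1 << RFCOMM_RELEASE_ONHUP) matches NOCAP_FLAGS above, so no CAP_NET_ADMIN is needed for this combination.
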
We also populate the driver_data field and install + * the tty port + */ +static int rfcomm_tty_install(struct tty_driver *driver, struct tty_struct *tty) +{ + struct rfcomm_dev *dev; + struct rfcomm_dlc *dlc; + int err; + + dev = rfcomm_dev_get(tty->index); + if (!dev) + return -ENODEV; + + dlc = dev->dlc; + + /* Attach TTY and open DLC */ + rfcomm_dlc_lock(dlc); + tty->driver_data = dev; + rfcomm_dlc_unlock(dlc); + set_bit(RFCOMM_TTY_ATTACHED, &dev->flags); + + /* install the tty_port */ + err = tty_port_install(&dev->port, driver, tty); + if (err) { + rfcomm_tty_cleanup(tty); + return err; + } + + /* take over the tty_port reference if the port was created with the + * flag RFCOMM_RELEASE_ONHUP. This will force the release of the port + * when the last process closes the tty. The behaviour is expected by + * userspace. + */ + if (test_bit(RFCOMM_RELEASE_ONHUP, &dev->flags)) { + set_bit(RFCOMM_TTY_OWNED, &dev->status); + tty_port_put(&dev->port); + } + + return 0; +} + +static int rfcomm_tty_open(struct tty_struct *tty, struct file *filp) +{ + struct rfcomm_dev *dev = tty->driver_data; + int err; + + BT_DBG("tty %p id %d", tty, tty->index); + + BT_DBG("dev %p dst %pMR channel %d opened %d", dev, &dev->dst, + dev->channel, dev->port.count); + + err = tty_port_open(&dev->port, tty, filp); + if (err) + return err; + + /* + * FIXME: rfcomm should use proper flow control for + * received data. This hack will be unnecessary and can + * be removed when that's implemented + */ + rfcomm_tty_copy_pending(dev); + + rfcomm_dlc_unthrottle(dev->dlc); + + return 0; +} + +static void rfcomm_tty_close(struct tty_struct *tty, struct file *filp) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p dlc %p opened %d", tty, dev, dev->dlc, + dev->port.count); + + tty_port_close(&dev->port, tty, filp); +} + +static int rfcomm_tty_write(struct tty_struct *tty, const unsigned char *buf, int count) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + struct rfcomm_dlc *dlc = dev->dlc; + struct sk_buff *skb; + int sent = 0, size; + + BT_DBG("tty %p count %d", tty, count); + + while (count) { + size = min_t(uint, count, dlc->mtu); + + skb = rfcomm_wmalloc(dev, size + RFCOMM_SKB_RESERVE, GFP_ATOMIC); + if (!skb) + break; + + skb_reserve(skb, RFCOMM_SKB_HEAD_RESERVE); + + skb_put_data(skb, buf + sent, size); + + rfcomm_dlc_send_noerror(dlc, skb); + + sent += size; + count -= size; + } + + return sent; +} + +static unsigned int rfcomm_tty_write_room(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + int room = 0; + + if (dev && dev->dlc) + room = rfcomm_room(dev); + + BT_DBG("tty %p room %d", tty, room); + + return room; +} + +static int rfcomm_tty_ioctl(struct tty_struct *tty, unsigned int cmd, unsigned long arg) +{ + BT_DBG("tty %p cmd 0x%02x", tty, cmd); + + switch (cmd) { + case TCGETS: + BT_DBG("TCGETS is not supported"); + return -ENOIOCTLCMD; + + case TCSETS: + BT_DBG("TCSETS is not supported"); + return -ENOIOCTLCMD; + + case TIOCMIWAIT: + BT_DBG("TIOCMIWAIT"); + break; + + case TIOCSERGETLSR: + BT_ERR("TIOCSERGETLSR is not supported"); + return -ENOIOCTLCMD; + + case TIOCSERCONFIG: + BT_ERR("TIOCSERCONFIG is not supported"); + return -ENOIOCTLCMD; + + default: + return -ENOIOCTLCMD; /* ioctls which we must ignore */ + + } + + return -ENOIOCTLCMD; +} + +static void rfcomm_tty_set_termios(struct tty_struct *tty, + const struct ktermios *old) +{ + struct ktermios *new = &tty->termios; + int old_baud_rate = 
tty_termios_baud_rate(old); + int new_baud_rate = tty_termios_baud_rate(new); + + u8 baud, data_bits, stop_bits, parity, x_on, x_off; + u16 changes = 0; + + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p termios %p", tty, old); + + if (!dev || !dev->dlc || !dev->dlc->session) + return; + + /* Handle turning off CRTSCTS */ + if ((old->c_cflag & CRTSCTS) && !(new->c_cflag & CRTSCTS)) + BT_DBG("Turning off CRTSCTS unsupported"); + + /* Parity on/off and when on, odd/even */ + if (((old->c_cflag & PARENB) != (new->c_cflag & PARENB)) || + ((old->c_cflag & PARODD) != (new->c_cflag & PARODD))) { + changes |= RFCOMM_RPN_PM_PARITY; + BT_DBG("Parity change detected."); + } + + /* Mark and space parity are not supported! */ + if (new->c_cflag & PARENB) { + if (new->c_cflag & PARODD) { + BT_DBG("Parity is ODD"); + parity = RFCOMM_RPN_PARITY_ODD; + } else { + BT_DBG("Parity is EVEN"); + parity = RFCOMM_RPN_PARITY_EVEN; + } + } else { + BT_DBG("Parity is OFF"); + parity = RFCOMM_RPN_PARITY_NONE; + } + + /* Setting the x_on / x_off characters */ + if (old->c_cc[VSTOP] != new->c_cc[VSTOP]) { + BT_DBG("XOFF custom"); + x_on = new->c_cc[VSTOP]; + changes |= RFCOMM_RPN_PM_XON; + } else { + BT_DBG("XOFF default"); + x_on = RFCOMM_RPN_XON_CHAR; + } + + if (old->c_cc[VSTART] != new->c_cc[VSTART]) { + BT_DBG("XON custom"); + x_off = new->c_cc[VSTART]; + changes |= RFCOMM_RPN_PM_XOFF; + } else { + BT_DBG("XON default"); + x_off = RFCOMM_RPN_XOFF_CHAR; + } + + /* Handle setting of stop bits */ + if ((old->c_cflag & CSTOPB) != (new->c_cflag & CSTOPB)) + changes |= RFCOMM_RPN_PM_STOP; + + /* POSIX does not support 1.5 stop bits and RFCOMM does not + * support 2 stop bits. So a request for 2 stop bits gets + * translated to 1.5 stop bits */ + if (new->c_cflag & CSTOPB) + stop_bits = RFCOMM_RPN_STOP_15; + else + stop_bits = RFCOMM_RPN_STOP_1; + + /* Handle number of data bits [5-8] */ + if ((old->c_cflag & CSIZE) != (new->c_cflag & CSIZE)) + changes |= RFCOMM_RPN_PM_DATA; + + switch (new->c_cflag & CSIZE) { + case CS5: + data_bits = RFCOMM_RPN_DATA_5; + break; + case CS6: + data_bits = RFCOMM_RPN_DATA_6; + break; + case CS7: + data_bits = RFCOMM_RPN_DATA_7; + break; + case CS8: + data_bits = RFCOMM_RPN_DATA_8; + break; + default: + data_bits = RFCOMM_RPN_DATA_8; + break; + } + + /* Handle baudrate settings */ + if (old_baud_rate != new_baud_rate) + changes |= RFCOMM_RPN_PM_BITRATE; + + switch (new_baud_rate) { + case 2400: + baud = RFCOMM_RPN_BR_2400; + break; + case 4800: + baud = RFCOMM_RPN_BR_4800; + break; + case 7200: + baud = RFCOMM_RPN_BR_7200; + break; + case 9600: + baud = RFCOMM_RPN_BR_9600; + break; + case 19200: + baud = RFCOMM_RPN_BR_19200; + break; + case 38400: + baud = RFCOMM_RPN_BR_38400; + break; + case 57600: + baud = RFCOMM_RPN_BR_57600; + break; + case 115200: + baud = RFCOMM_RPN_BR_115200; + break; + case 230400: + baud = RFCOMM_RPN_BR_230400; + break; + default: + /* 9600 is standard accordinag to the RFCOMM specification */ + baud = RFCOMM_RPN_BR_9600; + break; + + } + + if (changes) + rfcomm_send_rpn(dev->dlc->session, 1, dev->dlc->dlci, baud, + data_bits, stop_bits, parity, + RFCOMM_RPN_FLOW_NONE, x_on, x_off, changes); +} + +static void rfcomm_tty_throttle(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p", tty, dev); + + rfcomm_dlc_throttle(dev->dlc); +} + +static void rfcomm_tty_unthrottle(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + 
+ BT_DBG("tty %p dev %p", tty, dev); + + rfcomm_dlc_unthrottle(dev->dlc); +} + +static unsigned int rfcomm_tty_chars_in_buffer(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p", tty, dev); + + if (!dev || !dev->dlc) + return 0; + + if (!skb_queue_empty(&dev->dlc->tx_queue)) + return dev->dlc->mtu; + + return 0; +} + +static void rfcomm_tty_flush_buffer(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p", tty, dev); + + if (!dev || !dev->dlc) + return; + + skb_queue_purge(&dev->dlc->tx_queue); + tty_wakeup(tty); +} + +static void rfcomm_tty_send_xchar(struct tty_struct *tty, char ch) +{ + BT_DBG("tty %p ch %c", tty, ch); +} + +static void rfcomm_tty_wait_until_sent(struct tty_struct *tty, int timeout) +{ + BT_DBG("tty %p timeout %d", tty, timeout); +} + +static void rfcomm_tty_hangup(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p", tty, dev); + + tty_port_hangup(&dev->port); +} + +static int rfcomm_tty_tiocmget(struct tty_struct *tty) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + + BT_DBG("tty %p dev %p", tty, dev); + + return dev->modem_status; +} + +static int rfcomm_tty_tiocmset(struct tty_struct *tty, unsigned int set, unsigned int clear) +{ + struct rfcomm_dev *dev = (struct rfcomm_dev *) tty->driver_data; + struct rfcomm_dlc *dlc = dev->dlc; + u8 v24_sig; + + BT_DBG("tty %p dev %p set 0x%02x clear 0x%02x", tty, dev, set, clear); + + rfcomm_dlc_get_modem_status(dlc, &v24_sig); + + if (set & TIOCM_DSR || set & TIOCM_DTR) + v24_sig |= RFCOMM_V24_RTC; + if (set & TIOCM_RTS || set & TIOCM_CTS) + v24_sig |= RFCOMM_V24_RTR; + if (set & TIOCM_RI) + v24_sig |= RFCOMM_V24_IC; + if (set & TIOCM_CD) + v24_sig |= RFCOMM_V24_DV; + + if (clear & TIOCM_DSR || clear & TIOCM_DTR) + v24_sig &= ~RFCOMM_V24_RTC; + if (clear & TIOCM_RTS || clear & TIOCM_CTS) + v24_sig &= ~RFCOMM_V24_RTR; + if (clear & TIOCM_RI) + v24_sig &= ~RFCOMM_V24_IC; + if (clear & TIOCM_CD) + v24_sig &= ~RFCOMM_V24_DV; + + rfcomm_dlc_set_modem_status(dlc, v24_sig); + + return 0; +} + +/* ---- TTY structure ---- */ + +static const struct tty_operations rfcomm_ops = { + .open = rfcomm_tty_open, + .close = rfcomm_tty_close, + .write = rfcomm_tty_write, + .write_room = rfcomm_tty_write_room, + .chars_in_buffer = rfcomm_tty_chars_in_buffer, + .flush_buffer = rfcomm_tty_flush_buffer, + .ioctl = rfcomm_tty_ioctl, + .throttle = rfcomm_tty_throttle, + .unthrottle = rfcomm_tty_unthrottle, + .set_termios = rfcomm_tty_set_termios, + .send_xchar = rfcomm_tty_send_xchar, + .hangup = rfcomm_tty_hangup, + .wait_until_sent = rfcomm_tty_wait_until_sent, + .tiocmget = rfcomm_tty_tiocmget, + .tiocmset = rfcomm_tty_tiocmset, + .install = rfcomm_tty_install, + .cleanup = rfcomm_tty_cleanup, +}; + +int __init rfcomm_init_ttys(void) +{ + int error; + + rfcomm_tty_driver = tty_alloc_driver(RFCOMM_TTY_PORTS, + TTY_DRIVER_REAL_RAW | TTY_DRIVER_DYNAMIC_DEV); + if (IS_ERR(rfcomm_tty_driver)) + return PTR_ERR(rfcomm_tty_driver); + + rfcomm_tty_driver->driver_name = "rfcomm"; + rfcomm_tty_driver->name = "rfcomm"; + rfcomm_tty_driver->major = RFCOMM_TTY_MAJOR; + rfcomm_tty_driver->minor_start = RFCOMM_TTY_MINOR; + rfcomm_tty_driver->type = TTY_DRIVER_TYPE_SERIAL; + rfcomm_tty_driver->subtype = SERIAL_TYPE_NORMAL; + rfcomm_tty_driver->init_termios = tty_std_termios; + rfcomm_tty_driver->init_termios.c_cflag = B9600 | CS8 | CREAD | 
HUPCL; + rfcomm_tty_driver->init_termios.c_lflag &= ~ICANON; + tty_set_operations(rfcomm_tty_driver, &rfcomm_ops); + + error = tty_register_driver(rfcomm_tty_driver); + if (error) { + BT_ERR("Can't register RFCOMM TTY driver"); + tty_driver_kref_put(rfcomm_tty_driver); + return error; + } + + BT_INFO("RFCOMM TTY layer initialized"); + + return 0; +} + +void rfcomm_cleanup_ttys(void) +{ + tty_unregister_driver(rfcomm_tty_driver); + tty_driver_kref_put(rfcomm_tty_driver); +} diff --git a/net/bluetooth/sco.c b/net/bluetooth/sco.c new file mode 100644 index 000000000..6d4168cfe --- /dev/null +++ b/net/bluetooth/sco.c @@ -0,0 +1,1509 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2000-2001 Qualcomm Incorporated + + Written 2000,2001 by Maxim Krasnyansky + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +/* Bluetooth SCO sockets. */ + +#include +#include +#include +#include + +#include +#include +#include + +static bool disable_esco; + +static const struct proto_ops sco_sock_ops; + +static struct bt_sock_list sco_sk_list = { + .lock = __RW_LOCK_UNLOCKED(sco_sk_list.lock) +}; + +/* ---- SCO connections ---- */ +struct sco_conn { + struct hci_conn *hcon; + + spinlock_t lock; + struct sock *sk; + + struct delayed_work timeout_work; + + unsigned int mtu; +}; + +#define sco_conn_lock(c) spin_lock(&c->lock) +#define sco_conn_unlock(c) spin_unlock(&c->lock) + +static void sco_sock_close(struct sock *sk); +static void sco_sock_kill(struct sock *sk); + +/* ----- SCO socket info ----- */ +#define sco_pi(sk) ((struct sco_pinfo *) sk) + +struct sco_pinfo { + struct bt_sock bt; + bdaddr_t src; + bdaddr_t dst; + __u32 flags; + __u16 setting; + __u8 cmsg_mask; + struct bt_codec codec; + struct sco_conn *conn; +}; + +/* ---- SCO timers ---- */ +#define SCO_CONN_TIMEOUT (HZ * 40) +#define SCO_DISCONN_TIMEOUT (HZ * 2) + +static void sco_sock_timeout(struct work_struct *work) +{ + struct sco_conn *conn = container_of(work, struct sco_conn, + timeout_work.work); + struct sock *sk; + + sco_conn_lock(conn); + sk = conn->sk; + if (sk) + sock_hold(sk); + sco_conn_unlock(conn); + + if (!sk) + return; + + BT_DBG("sock %p state %d", sk, sk->sk_state); + + lock_sock(sk); + sk->sk_err = ETIMEDOUT; + sk->sk_state_change(sk); + release_sock(sk); + sock_put(sk); +} + +static void sco_sock_set_timer(struct sock *sk, long timeout) +{ + if (!sco_pi(sk)->conn) + return; + + BT_DBG("sock %p state %d timeout %ld", sk, sk->sk_state, timeout); + cancel_delayed_work(&sco_pi(sk)->conn->timeout_work); + schedule_delayed_work(&sco_pi(sk)->conn->timeout_work, timeout); +} + +static void sco_sock_clear_timer(struct sock *sk) +{ + if 
(!sco_pi(sk)->conn) + return; + + BT_DBG("sock %p state %d", sk, sk->sk_state); + cancel_delayed_work(&sco_pi(sk)->conn->timeout_work); +} + +/* ---- SCO connections ---- */ +static struct sco_conn *sco_conn_add(struct hci_conn *hcon) +{ + struct hci_dev *hdev = hcon->hdev; + struct sco_conn *conn = hcon->sco_data; + + if (conn) + return conn; + + conn = kzalloc(sizeof(struct sco_conn), GFP_KERNEL); + if (!conn) + return NULL; + + spin_lock_init(&conn->lock); + INIT_DELAYED_WORK(&conn->timeout_work, sco_sock_timeout); + + hcon->sco_data = conn; + conn->hcon = hcon; + + if (hdev->sco_mtu > 0) + conn->mtu = hdev->sco_mtu; + else + conn->mtu = 60; + + BT_DBG("hcon %p conn %p", hcon, conn); + + return conn; +} + +/* Delete channel. + * Must be called on the locked socket. */ +static void sco_chan_del(struct sock *sk, int err) +{ + struct sco_conn *conn; + + conn = sco_pi(sk)->conn; + + BT_DBG("sk %p, conn %p, err %d", sk, conn, err); + + if (conn) { + sco_conn_lock(conn); + conn->sk = NULL; + sco_pi(sk)->conn = NULL; + sco_conn_unlock(conn); + + if (conn->hcon) + hci_conn_drop(conn->hcon); + } + + sk->sk_state = BT_CLOSED; + sk->sk_err = err; + sk->sk_state_change(sk); + + sock_set_flag(sk, SOCK_ZAPPED); +} + +static void sco_conn_del(struct hci_conn *hcon, int err) +{ + struct sco_conn *conn = hcon->sco_data; + struct sock *sk; + + if (!conn) + return; + + BT_DBG("hcon %p conn %p, err %d", hcon, conn, err); + + /* Kill socket */ + sco_conn_lock(conn); + sk = conn->sk; + if (sk) + sock_hold(sk); + sco_conn_unlock(conn); + + if (sk) { + lock_sock(sk); + sco_sock_clear_timer(sk); + sco_chan_del(sk, err); + release_sock(sk); + sock_put(sk); + } + + /* Ensure no more work items will run before freeing conn. */ + cancel_delayed_work_sync(&conn->timeout_work); + + hcon->sco_data = NULL; + kfree(conn); +} + +static void __sco_chan_add(struct sco_conn *conn, struct sock *sk, + struct sock *parent) +{ + BT_DBG("conn %p", conn); + + sco_pi(sk)->conn = conn; + conn->sk = sk; + + if (parent) + bt_accept_enqueue(parent, sk, true); +} + +static int sco_chan_add(struct sco_conn *conn, struct sock *sk, + struct sock *parent) +{ + int err = 0; + + sco_conn_lock(conn); + if (conn->sk) + err = -EBUSY; + else + __sco_chan_add(conn, sk, parent); + + sco_conn_unlock(conn); + return err; +} + +static int sco_connect(struct hci_dev *hdev, struct sock *sk) +{ + struct sco_conn *conn; + struct hci_conn *hcon; + int err, type; + + BT_DBG("%pMR -> %pMR", &sco_pi(sk)->src, &sco_pi(sk)->dst); + + if (lmp_esco_capable(hdev) && !disable_esco) + type = ESCO_LINK; + else + type = SCO_LINK; + + if (sco_pi(sk)->setting == BT_VOICE_TRANSPARENT && + (!lmp_transp_capable(hdev) || !lmp_esco_capable(hdev))) + return -EOPNOTSUPP; + + hcon = hci_connect_sco(hdev, type, &sco_pi(sk)->dst, + sco_pi(sk)->setting, &sco_pi(sk)->codec); + if (IS_ERR(hcon)) + return PTR_ERR(hcon); + + conn = sco_conn_add(hcon); + if (!conn) { + hci_conn_drop(hcon); + return -ENOMEM; + } + + /* Update source addr of the socket */ + bacpy(&sco_pi(sk)->src, &hcon->src); + + err = sco_chan_add(conn, sk, NULL); + if (err) + return err; + + if (hcon->state == BT_CONNECTED) { + sco_sock_clear_timer(sk); + sk->sk_state = BT_CONNECTED; + } else { + sk->sk_state = BT_CONNECT; + sco_sock_set_timer(sk, sk->sk_sndtimeo); + } + + return err; +} + +static int sco_send_frame(struct sock *sk, struct sk_buff *skb) +{ + struct sco_conn *conn = sco_pi(sk)->conn; + int len = skb->len; + + /* Check outgoing MTU */ + if (len > conn->mtu) + return -EINVAL; + + BT_DBG("sk %p len %d", 
sk, len); + + hci_send_sco(conn->hcon, skb); + + return len; +} + +static void sco_recv_frame(struct sco_conn *conn, struct sk_buff *skb) +{ + struct sock *sk; + + sco_conn_lock(conn); + sk = conn->sk; + sco_conn_unlock(conn); + + if (!sk) + goto drop; + + BT_DBG("sk %p len %u", sk, skb->len); + + if (sk->sk_state != BT_CONNECTED) + goto drop; + + if (!sock_queue_rcv_skb(sk, skb)) + return; + +drop: + kfree_skb(skb); +} + +/* -------- Socket interface ---------- */ +static struct sock *__sco_get_sock_listen_by_addr(bdaddr_t *ba) +{ + struct sock *sk; + + sk_for_each(sk, &sco_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + if (!bacmp(&sco_pi(sk)->src, ba)) + return sk; + } + + return NULL; +} + +/* Find socket listening on source bdaddr. + * Returns closest match. + */ +static struct sock *sco_get_sock_listen(bdaddr_t *src) +{ + struct sock *sk = NULL, *sk1 = NULL; + + read_lock(&sco_sk_list.lock); + + sk_for_each(sk, &sco_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + /* Exact match. */ + if (!bacmp(&sco_pi(sk)->src, src)) + break; + + /* Closest match */ + if (!bacmp(&sco_pi(sk)->src, BDADDR_ANY)) + sk1 = sk; + } + + read_unlock(&sco_sk_list.lock); + + return sk ? sk : sk1; +} + +static void sco_sock_destruct(struct sock *sk) +{ + BT_DBG("sk %p", sk); + + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); +} + +static void sco_sock_cleanup_listen(struct sock *parent) +{ + struct sock *sk; + + BT_DBG("parent %p", parent); + + /* Close not yet accepted channels */ + while ((sk = bt_accept_dequeue(parent, NULL))) { + sco_sock_close(sk); + sco_sock_kill(sk); + } + + parent->sk_state = BT_CLOSED; + sock_set_flag(parent, SOCK_ZAPPED); +} + +/* Kill socket (only if zapped and orphan) + * Must be called on unlocked socket. + */ +static void sco_sock_kill(struct sock *sk) +{ + if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket) + return; + + BT_DBG("sk %p state %d", sk, sk->sk_state); + + /* Kill poor orphan */ + bt_sock_unlink(&sco_sk_list, sk); + sock_set_flag(sk, SOCK_DEAD); + sock_put(sk); +} + +static void __sco_sock_close(struct sock *sk) +{ + BT_DBG("sk %p state %d socket %p", sk, sk->sk_state, sk->sk_socket); + + switch (sk->sk_state) { + case BT_LISTEN: + sco_sock_cleanup_listen(sk); + break; + + case BT_CONNECTED: + case BT_CONFIG: + if (sco_pi(sk)->conn->hcon) { + sk->sk_state = BT_DISCONN; + sco_sock_set_timer(sk, SCO_DISCONN_TIMEOUT); + sco_conn_lock(sco_pi(sk)->conn); + hci_conn_drop(sco_pi(sk)->conn->hcon); + sco_pi(sk)->conn->hcon = NULL; + sco_conn_unlock(sco_pi(sk)->conn); + } else + sco_chan_del(sk, ECONNRESET); + break; + + case BT_CONNECT2: + case BT_CONNECT: + case BT_DISCONN: + sco_chan_del(sk, ECONNRESET); + break; + + default: + sock_set_flag(sk, SOCK_ZAPPED); + break; + } + +} + +/* Must be called on unlocked socket. 
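
For orientation, the socket-level bind()/connect() handlers that follow sit on top of sco_connect() above; from userspace the whole path is driven roughly as in this sketch (the peer address is a placeholder, libbluetooth headers assumed).

/* Sketch: open a SCO audio link to a peer. */
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/sco.h>

static int sco_open(const char *peer)
{
        struct sockaddr_sco addr;
        int s;

        s = socket(AF_BLUETOOTH, SOCK_SEQPACKET, BTPROTO_SCO);
        if (s < 0)
                return -1;

        memset(&addr, 0, sizeof(addr));
        addr.sco_family = AF_BLUETOOTH;
        str2ba(peer, &addr.sco_bdaddr);         /* e.g. "00:11:22:33:44:55" */

        if (connect(s, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
                close(s);
                return -1;
        }

        return s;       /* caller then read()s/write()s audio frames */
}
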
*/ +static void sco_sock_close(struct sock *sk) +{ + lock_sock(sk); + sco_sock_clear_timer(sk); + __sco_sock_close(sk); + release_sock(sk); +} + +static void sco_skb_put_cmsg(struct sk_buff *skb, struct msghdr *msg, + struct sock *sk) +{ + if (sco_pi(sk)->cmsg_mask & SCO_CMSG_PKT_STATUS) + put_cmsg(msg, SOL_BLUETOOTH, BT_SCM_PKT_STATUS, + sizeof(bt_cb(skb)->sco.pkt_status), + &bt_cb(skb)->sco.pkt_status); +} + +static void sco_sock_init(struct sock *sk, struct sock *parent) +{ + BT_DBG("sk %p", sk); + + if (parent) { + sk->sk_type = parent->sk_type; + bt_sk(sk)->flags = bt_sk(parent)->flags; + security_sk_clone(parent, sk); + } else { + bt_sk(sk)->skb_put_cmsg = sco_skb_put_cmsg; + } +} + +static struct proto sco_proto = { + .name = "SCO", + .owner = THIS_MODULE, + .obj_size = sizeof(struct sco_pinfo) +}; + +static struct sock *sco_sock_alloc(struct net *net, struct socket *sock, + int proto, gfp_t prio, int kern) +{ + struct sock *sk; + + sk = sk_alloc(net, PF_BLUETOOTH, prio, &sco_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + INIT_LIST_HEAD(&bt_sk(sk)->accept_q); + + sk->sk_destruct = sco_sock_destruct; + sk->sk_sndtimeo = SCO_CONN_TIMEOUT; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = proto; + sk->sk_state = BT_OPEN; + + sco_pi(sk)->setting = BT_VOICE_CVSD_16BIT; + sco_pi(sk)->codec.id = BT_CODEC_CVSD; + sco_pi(sk)->codec.cid = 0xffff; + sco_pi(sk)->codec.vid = 0xffff; + sco_pi(sk)->codec.data_path = 0x00; + + bt_sock_link(&sco_sk_list, sk); + return sk; +} + +static int sco_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + BT_DBG("sock %p", sock); + + sock->state = SS_UNCONNECTED; + + if (sock->type != SOCK_SEQPACKET) + return -ESOCKTNOSUPPORT; + + sock->ops = &sco_sock_ops; + + sk = sco_sock_alloc(net, sock, protocol, GFP_ATOMIC, kern); + if (!sk) + return -ENOMEM; + + sco_sock_init(sk, NULL); + return 0; +} + +static int sco_sock_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sockaddr_sco *sa = (struct sockaddr_sco *) addr; + struct sock *sk = sock->sk; + int err = 0; + + if (!addr || addr_len < sizeof(struct sockaddr_sco) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + BT_DBG("sk %p %pMR", sk, &sa->sco_bdaddr); + + lock_sock(sk); + + if (sk->sk_state != BT_OPEN) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EINVAL; + goto done; + } + + bacpy(&sco_pi(sk)->src, &sa->sco_bdaddr); + + sk->sk_state = BT_BOUND; + +done: + release_sock(sk); + return err; +} + +static int sco_sock_connect(struct socket *sock, struct sockaddr *addr, int alen, int flags) +{ + struct sockaddr_sco *sa = (struct sockaddr_sco *) addr; + struct sock *sk = sock->sk; + struct hci_dev *hdev; + int err; + + BT_DBG("sk %p", sk); + + if (alen < sizeof(struct sockaddr_sco) || + addr->sa_family != AF_BLUETOOTH) + return -EINVAL; + + lock_sock(sk); + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EINVAL; + goto done; + } + + hdev = hci_get_route(&sa->sco_bdaddr, &sco_pi(sk)->src, BDADDR_BREDR); + if (!hdev) { + err = -EHOSTUNREACH; + goto done; + } + hci_dev_lock(hdev); + + /* Set destination address and psm */ + bacpy(&sco_pi(sk)->dst, &sa->sco_bdaddr); + + err = sco_connect(hdev, sk); + hci_dev_unlock(hdev); + hci_dev_put(hdev); + if (err) + goto done; + + err = bt_sock_wait_state(sk, BT_CONNECTED, + sock_sndtimeo(sk, flags & O_NONBLOCK)); + +done: + 
release_sock(sk); + return err; +} + +static int sco_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + bdaddr_t *src = &sco_pi(sk)->src; + int err = 0; + + BT_DBG("sk %p backlog %d", sk, backlog); + + lock_sock(sk); + + if (sk->sk_state != BT_BOUND) { + err = -EBADFD; + goto done; + } + + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EINVAL; + goto done; + } + + write_lock(&sco_sk_list.lock); + + if (__sco_get_sock_listen_by_addr(src)) { + err = -EADDRINUSE; + goto unlock; + } + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + + sk->sk_state = BT_LISTEN; + +unlock: + write_unlock(&sco_sk_list.lock); + +done: + release_sock(sk); + return err; +} + +static int sco_sock_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = sock->sk, *ch; + long timeo; + int err = 0; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + BT_DBG("sk %p timeo %ld", sk, timeo); + + /* Wait for an incoming connection. (wake-one). */ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (1) { + if (sk->sk_state != BT_LISTEN) { + err = -EBADFD; + break; + } + + ch = bt_accept_dequeue(sk, newsock); + if (ch) + break; + + if (!timeo) { + err = -EAGAIN; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + lock_sock(sk); + } + remove_wait_queue(sk_sleep(sk), &wait); + + if (err) + goto done; + + newsock->state = SS_CONNECTED; + + BT_DBG("new socket %p", ch); + +done: + release_sock(sk); + return err; +} + +static int sco_sock_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sockaddr_sco *sa = (struct sockaddr_sco *) addr; + struct sock *sk = sock->sk; + + BT_DBG("sock %p, sk %p", sock, sk); + + addr->sa_family = AF_BLUETOOTH; + + if (peer) + bacpy(&sa->sco_bdaddr, &sco_pi(sk)->dst); + else + bacpy(&sa->sco_bdaddr, &sco_pi(sk)->src); + + return sizeof(struct sockaddr_sco); +} + +static int sco_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int err; + + BT_DBG("sock %p, sk %p", sock, sk); + + err = sock_error(sk); + if (err) + return err; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + skb = bt_skb_sendmsg(sk, msg, len, len, 0, 0); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + lock_sock(sk); + + if (sk->sk_state == BT_CONNECTED) + err = sco_send_frame(sk, skb); + else + err = -ENOTCONN; + + release_sock(sk); + + if (err < 0) + kfree_skb(skb); + return err; +} + +static void sco_conn_defer_accept(struct hci_conn *conn, u16 setting) +{ + struct hci_dev *hdev = conn->hdev; + + BT_DBG("conn %p", conn); + + conn->state = BT_CONFIG; + + if (!lmp_esco_capable(hdev)) { + struct hci_cp_accept_conn_req cp; + + bacpy(&cp.bdaddr, &conn->dst); + cp.role = 0x00; /* Ignored */ + + hci_send_cmd(hdev, HCI_OP_ACCEPT_CONN_REQ, sizeof(cp), &cp); + } else { + struct hci_cp_accept_sync_conn_req cp; + + bacpy(&cp.bdaddr, &conn->dst); + cp.pkt_type = cpu_to_le16(conn->pkt_type); + + cp.tx_bandwidth = cpu_to_le32(0x00001f40); + cp.rx_bandwidth = cpu_to_le32(0x00001f40); + cp.content_format = cpu_to_le16(setting); + + switch (setting & SCO_AIRMODE_MASK) { + case SCO_AIRMODE_TRANSP: + if (conn->pkt_type & ESCO_2EV3) + cp.max_latency = cpu_to_le16(0x0008); + else + cp.max_latency = cpu_to_le16(0x000D); + cp.retrans_effort = 0x02; + break; + case 
SCO_AIRMODE_CVSD: + cp.max_latency = cpu_to_le16(0xffff); + cp.retrans_effort = 0xff; + break; + default: + /* use CVSD settings as fallback */ + cp.max_latency = cpu_to_le16(0xffff); + cp.retrans_effort = 0xff; + break; + } + + hci_send_cmd(hdev, HCI_OP_ACCEPT_SYNC_CONN_REQ, + sizeof(cp), &cp); + } +} + +static int sco_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct sco_pinfo *pi = sco_pi(sk); + + lock_sock(sk); + + if (sk->sk_state == BT_CONNECT2 && + test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) { + sco_conn_defer_accept(pi->conn->hcon, pi->setting); + sk->sk_state = BT_CONFIG; + + release_sock(sk); + return 0; + } + + release_sock(sk); + + return bt_sock_recvmsg(sock, msg, len, flags); +} + +static int sco_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + int len, err = 0; + struct bt_voice voice; + u32 opt; + struct bt_codecs *codecs; + struct hci_dev *hdev; + __u8 buffer[255]; + + BT_DBG("sk %p", sk); + + lock_sock(sk); + + switch (optname) { + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt) + set_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + else + clear_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags); + break; + + case BT_VOICE: + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND && + sk->sk_state != BT_CONNECT2) { + err = -EINVAL; + break; + } + + voice.setting = sco_pi(sk)->setting; + + len = min_t(unsigned int, sizeof(voice), optlen); + if (copy_from_sockptr(&voice, optval, len)) { + err = -EFAULT; + break; + } + + /* Explicitly check for these values */ + if (voice.setting != BT_VOICE_TRANSPARENT && + voice.setting != BT_VOICE_CVSD_16BIT) { + err = -EINVAL; + break; + } + + sco_pi(sk)->setting = voice.setting; + hdev = hci_get_route(&sco_pi(sk)->dst, &sco_pi(sk)->src, + BDADDR_BREDR); + if (!hdev) { + err = -EBADFD; + break; + } + if (enhanced_sync_conn_capable(hdev) && + voice.setting == BT_VOICE_TRANSPARENT) + sco_pi(sk)->codec.id = BT_CODEC_TRANSPARENT; + hci_dev_put(hdev); + break; + + case BT_PKT_STATUS: + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt) + sco_pi(sk)->cmsg_mask |= SCO_CMSG_PKT_STATUS; + else + sco_pi(sk)->cmsg_mask &= SCO_CMSG_PKT_STATUS; + break; + + case BT_CODEC: + if (sk->sk_state != BT_OPEN && sk->sk_state != BT_BOUND && + sk->sk_state != BT_CONNECT2) { + err = -EINVAL; + break; + } + + hdev = hci_get_route(&sco_pi(sk)->dst, &sco_pi(sk)->src, + BDADDR_BREDR); + if (!hdev) { + err = -EBADFD; + break; + } + + if (!hci_dev_test_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED)) { + hci_dev_put(hdev); + err = -EOPNOTSUPP; + break; + } + + if (!hdev->get_data_path_id) { + hci_dev_put(hdev); + err = -EOPNOTSUPP; + break; + } + + if (optlen < sizeof(struct bt_codecs) || + optlen > sizeof(buffer)) { + hci_dev_put(hdev); + err = -EINVAL; + break; + } + + if (copy_from_sockptr(buffer, optval, optlen)) { + hci_dev_put(hdev); + err = -EFAULT; + break; + } + + codecs = (void *)buffer; + + if (codecs->num_codecs > 1) { + hci_dev_put(hdev); + err = -EINVAL; + break; + } + + sco_pi(sk)->codec = codecs->codecs[0]; + hci_dev_put(hdev); + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int sco_sock_getsockopt_old(struct socket *sock, int optname, + char 
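/*
 * Editor's illustrative sketch, not part of this patch: how userspace is
 * expected to combine BT_DEFER_SETUP, BT_VOICE and the deferred accept in
 * sco_sock_recvmsg() shown above. Assumes the BlueZ libbluetooth headers;
 * listen_sk is a bound, listening SCO socket supplied by the caller.
 */
#include <stdint.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/sco.h>

static int accept_transparent_sco(int listen_sk)
{
        struct bt_voice voice;
        uint32_t defer = 1;
        char dummy;
        int sk;

        /* Defer setup so the incoming eSCO request is not accepted until we
         * explicitly authorize it (allowed in BT_BOUND/BT_LISTEN state).
         */
        if (setsockopt(listen_sk, SOL_BLUETOOTH, BT_DEFER_SETUP,
                       &defer, sizeof(defer)) < 0)
                return -1;

        sk = accept(listen_sk, NULL, NULL);     /* child sits in BT_CONNECT2 */
        if (sk < 0)
                return -1;

        /* Request transparent air mode (e.g. for mSBC) while the child is
         * still in BT_CONNECT2, which sco_sock_setsockopt() permits.
         */
        memset(&voice, 0, sizeof(voice));
        voice.setting = BT_VOICE_TRANSPARENT;
        if (setsockopt(sk, SOL_BLUETOOTH, BT_VOICE,
                       &voice, sizeof(voice)) < 0) {
                close(sk);
                return -1;
        }

        /* The first read triggers sco_conn_defer_accept() and returns 0
         * while the connection is being configured.
         */
        if (read(sk, &dummy, sizeof(dummy)) < 0) {
                close(sk);
                return -1;
        }

        return sk;
}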
__user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct sco_options opts; + struct sco_conninfo cinfo; + int len, err = 0; + + BT_DBG("sk %p", sk); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case SCO_OPTIONS: + if (sk->sk_state != BT_CONNECTED && + !(sk->sk_state == BT_CONNECT2 && + test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags))) { + err = -ENOTCONN; + break; + } + + opts.mtu = sco_pi(sk)->conn->mtu; + + BT_DBG("mtu %u", opts.mtu); + + len = min_t(unsigned int, len, sizeof(opts)); + if (copy_to_user(optval, (char *)&opts, len)) + err = -EFAULT; + + break; + + case SCO_CONNINFO: + if (sk->sk_state != BT_CONNECTED && + !(sk->sk_state == BT_CONNECT2 && + test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags))) { + err = -ENOTCONN; + break; + } + + memset(&cinfo, 0, sizeof(cinfo)); + cinfo.hci_handle = sco_pi(sk)->conn->hcon->handle; + memcpy(cinfo.dev_class, sco_pi(sk)->conn->hcon->dev_class, 3); + + len = min_t(unsigned int, len, sizeof(cinfo)); + if (copy_to_user(optval, (char *)&cinfo, len)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int sco_sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int len, err = 0; + struct bt_voice voice; + u32 phys; + int pkt_status; + int buf_len; + struct codec_list *c; + u8 num_codecs, i, __user *ptr; + struct hci_dev *hdev; + struct hci_codec_caps *caps; + struct bt_codec codec; + + BT_DBG("sk %p", sk); + + if (level == SOL_SCO) + return sco_sock_getsockopt_old(sock, optname, optval, optlen); + + if (get_user(len, optlen)) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + + case BT_DEFER_SETUP: + if (sk->sk_state != BT_BOUND && sk->sk_state != BT_LISTEN) { + err = -EINVAL; + break; + } + + if (put_user(test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags), + (u32 __user *)optval)) + err = -EFAULT; + + break; + + case BT_VOICE: + voice.setting = sco_pi(sk)->setting; + + len = min_t(unsigned int, len, sizeof(voice)); + if (copy_to_user(optval, (char *)&voice, len)) + err = -EFAULT; + + break; + + case BT_PHY: + if (sk->sk_state != BT_CONNECTED) { + err = -ENOTCONN; + break; + } + + phys = hci_conn_get_phy(sco_pi(sk)->conn->hcon); + + if (put_user(phys, (u32 __user *) optval)) + err = -EFAULT; + break; + + case BT_PKT_STATUS: + pkt_status = (sco_pi(sk)->cmsg_mask & SCO_CMSG_PKT_STATUS); + + if (put_user(pkt_status, (int __user *)optval)) + err = -EFAULT; + break; + + case BT_SNDMTU: + case BT_RCVMTU: + if (sk->sk_state != BT_CONNECTED) { + err = -ENOTCONN; + break; + } + + if (put_user(sco_pi(sk)->conn->mtu, (u32 __user *)optval)) + err = -EFAULT; + break; + + case BT_CODEC: + num_codecs = 0; + buf_len = 0; + + hdev = hci_get_route(&sco_pi(sk)->dst, &sco_pi(sk)->src, BDADDR_BREDR); + if (!hdev) { + err = -EBADFD; + break; + } + + if (!hci_dev_test_flag(hdev, HCI_OFFLOAD_CODECS_ENABLED)) { + hci_dev_put(hdev); + err = -EOPNOTSUPP; + break; + } + + if (!hdev->get_data_path_id) { + hci_dev_put(hdev); + err = -EOPNOTSUPP; + break; + } + + release_sock(sk); + + /* find total buffer size required to copy codec + caps */ + hci_dev_lock(hdev); + list_for_each_entry(c, &hdev->local_codecs, list) { + if (c->transport != HCI_TRANSPORT_SCO_ESCO) + continue; + num_codecs++; + for (i = 0, caps = c->caps; i < c->num_caps; i++) { + buf_len += 1 + caps->len; + caps = (void *)&caps->data[caps->len]; + } + buf_len += sizeof(struct bt_codec); 
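/*
 * Editor's illustrative sketch, not part of this patch: querying the SCO MTU
 * served by the SCO_OPTIONS case of sco_sock_getsockopt_old() above, so that
 * audio reads and writes use correctly sized buffers. Assumes the BlueZ
 * libbluetooth headers.
 */
#include <stdio.h>
#include <sys/socket.h>
#include <bluetooth/bluetooth.h>
#include <bluetooth/sco.h>

static int sco_mtu(int sk)
{
        struct sco_options opts;
        socklen_t optlen = sizeof(opts);

        /* Valid once the socket is BT_CONNECTED (or a deferred setup is in
         * progress); otherwise the kernel returns -ENOTCONN.
         */
        if (getsockopt(sk, SOL_SCO, SCO_OPTIONS, &opts, &optlen) < 0)
                return -1;

        printf("SCO MTU: %u bytes\n", opts.mtu);
        return opts.mtu;
}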
+ } + hci_dev_unlock(hdev); + + buf_len += sizeof(struct bt_codecs); + if (buf_len > len) { + hci_dev_put(hdev); + return -ENOBUFS; + } + ptr = optval; + + if (put_user(num_codecs, ptr)) { + hci_dev_put(hdev); + return -EFAULT; + } + ptr += sizeof(num_codecs); + + /* Iterate all the codecs supported over SCO and populate + * codec data + */ + hci_dev_lock(hdev); + list_for_each_entry(c, &hdev->local_codecs, list) { + if (c->transport != HCI_TRANSPORT_SCO_ESCO) + continue; + + codec.id = c->id; + codec.cid = c->cid; + codec.vid = c->vid; + err = hdev->get_data_path_id(hdev, &codec.data_path); + if (err < 0) + break; + codec.num_caps = c->num_caps; + if (copy_to_user(ptr, &codec, sizeof(codec))) { + err = -EFAULT; + break; + } + ptr += sizeof(codec); + + /* find codec capabilities data length */ + len = 0; + for (i = 0, caps = c->caps; i < c->num_caps; i++) { + len += 1 + caps->len; + caps = (void *)&caps->data[caps->len]; + } + + /* copy codec capabilities data */ + if (len && copy_to_user(ptr, c->caps, len)) { + err = -EFAULT; + break; + } + ptr += len; + } + + hci_dev_unlock(hdev); + hci_dev_put(hdev); + + lock_sock(sk); + + if (!err && put_user(buf_len, optlen)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static int sco_sock_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + sock_hold(sk); + lock_sock(sk); + + if (!sk->sk_shutdown) { + sk->sk_shutdown = SHUTDOWN_MASK; + sco_sock_clear_timer(sk); + __sco_sock_close(sk); + + if (sock_flag(sk, SOCK_LINGER) && sk->sk_lingertime && + !(current->flags & PF_EXITING)) + err = bt_sock_wait_state(sk, BT_CLOSED, + sk->sk_lingertime); + } + + release_sock(sk); + sock_put(sk); + + return err; +} + +static int sco_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + int err = 0; + + BT_DBG("sock %p, sk %p", sock, sk); + + if (!sk) + return 0; + + sco_sock_close(sk); + + if (sock_flag(sk, SOCK_LINGER) && READ_ONCE(sk->sk_lingertime) && + !(current->flags & PF_EXITING)) { + lock_sock(sk); + err = bt_sock_wait_state(sk, BT_CLOSED, sk->sk_lingertime); + release_sock(sk); + } + + sock_orphan(sk); + sco_sock_kill(sk); + return err; +} + +static void sco_conn_ready(struct sco_conn *conn) +{ + struct sock *parent; + struct sock *sk = conn->sk; + + BT_DBG("conn %p", conn); + + if (sk) { + lock_sock(sk); + sco_sock_clear_timer(sk); + sk->sk_state = BT_CONNECTED; + sk->sk_state_change(sk); + release_sock(sk); + } else { + sco_conn_lock(conn); + + if (!conn->hcon) { + sco_conn_unlock(conn); + return; + } + + parent = sco_get_sock_listen(&conn->hcon->src); + if (!parent) { + sco_conn_unlock(conn); + return; + } + + lock_sock(parent); + + sk = sco_sock_alloc(sock_net(parent), NULL, + BTPROTO_SCO, GFP_ATOMIC, 0); + if (!sk) { + release_sock(parent); + sco_conn_unlock(conn); + return; + } + + sco_sock_init(sk, parent); + + bacpy(&sco_pi(sk)->src, &conn->hcon->src); + bacpy(&sco_pi(sk)->dst, &conn->hcon->dst); + + hci_conn_hold(conn->hcon); + __sco_chan_add(conn, sk, parent); + + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(parent)->flags)) + sk->sk_state = BT_CONNECT2; + else + sk->sk_state = BT_CONNECTED; + + /* Wake up parent */ + parent->sk_data_ready(parent); + + release_sock(parent); + + sco_conn_unlock(conn); + } +} + +/* ----- SCO interface with lower layer (HCI) ----- */ +int sco_connect_ind(struct hci_dev *hdev, bdaddr_t *bdaddr, __u8 *flags) +{ + struct sock *sk; + int lm 
= 0; + + BT_DBG("hdev %s, bdaddr %pMR", hdev->name, bdaddr); + + /* Find listening sockets */ + read_lock(&sco_sk_list.lock); + sk_for_each(sk, &sco_sk_list.head) { + if (sk->sk_state != BT_LISTEN) + continue; + + if (!bacmp(&sco_pi(sk)->src, &hdev->bdaddr) || + !bacmp(&sco_pi(sk)->src, BDADDR_ANY)) { + lm |= HCI_LM_ACCEPT; + + if (test_bit(BT_SK_DEFER_SETUP, &bt_sk(sk)->flags)) + *flags |= HCI_PROTO_DEFER; + break; + } + } + read_unlock(&sco_sk_list.lock); + + return lm; +} + +static void sco_connect_cfm(struct hci_conn *hcon, __u8 status) +{ + if (hcon->type != SCO_LINK && hcon->type != ESCO_LINK) + return; + + BT_DBG("hcon %p bdaddr %pMR status %u", hcon, &hcon->dst, status); + + if (!status) { + struct sco_conn *conn; + + conn = sco_conn_add(hcon); + if (conn) + sco_conn_ready(conn); + } else + sco_conn_del(hcon, bt_to_errno(status)); +} + +static void sco_disconn_cfm(struct hci_conn *hcon, __u8 reason) +{ + if (hcon->type != SCO_LINK && hcon->type != ESCO_LINK) + return; + + BT_DBG("hcon %p reason %d", hcon, reason); + + sco_conn_del(hcon, bt_to_errno(reason)); +} + +void sco_recv_scodata(struct hci_conn *hcon, struct sk_buff *skb) +{ + struct sco_conn *conn = hcon->sco_data; + + if (!conn) + goto drop; + + BT_DBG("conn %p len %u", conn, skb->len); + + if (skb->len) { + sco_recv_frame(conn, skb); + return; + } + +drop: + kfree_skb(skb); +} + +static struct hci_cb sco_cb = { + .name = "SCO", + .connect_cfm = sco_connect_cfm, + .disconn_cfm = sco_disconn_cfm, +}; + +static int sco_debugfs_show(struct seq_file *f, void *p) +{ + struct sock *sk; + + read_lock(&sco_sk_list.lock); + + sk_for_each(sk, &sco_sk_list.head) { + seq_printf(f, "%pMR %pMR %d\n", &sco_pi(sk)->src, + &sco_pi(sk)->dst, sk->sk_state); + } + + read_unlock(&sco_sk_list.lock); + + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(sco_debugfs); + +static struct dentry *sco_debugfs; + +static const struct proto_ops sco_sock_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .release = sco_sock_release, + .bind = sco_sock_bind, + .connect = sco_sock_connect, + .listen = sco_sock_listen, + .accept = sco_sock_accept, + .getname = sco_sock_getname, + .sendmsg = sco_sock_sendmsg, + .recvmsg = sco_sock_recvmsg, + .poll = bt_sock_poll, + .ioctl = bt_sock_ioctl, + .gettstamp = sock_gettstamp, + .mmap = sock_no_mmap, + .socketpair = sock_no_socketpair, + .shutdown = sco_sock_shutdown, + .setsockopt = sco_sock_setsockopt, + .getsockopt = sco_sock_getsockopt +}; + +static const struct net_proto_family sco_sock_family_ops = { + .family = PF_BLUETOOTH, + .owner = THIS_MODULE, + .create = sco_sock_create, +}; + +int __init sco_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct sockaddr_sco) > sizeof(struct sockaddr)); + + err = proto_register(&sco_proto, 0); + if (err < 0) + return err; + + err = bt_sock_register(BTPROTO_SCO, &sco_sock_family_ops); + if (err < 0) { + BT_ERR("SCO socket registration failed"); + goto error; + } + + err = bt_procfs_init(&init_net, "sco", &sco_sk_list, NULL); + if (err < 0) { + BT_ERR("Failed to create SCO proc file"); + bt_sock_unregister(BTPROTO_SCO); + goto error; + } + + BT_INFO("SCO socket layer initialized"); + + hci_register_cb(&sco_cb); + + if (IS_ERR_OR_NULL(bt_debugfs)) + return 0; + + sco_debugfs = debugfs_create_file("sco", 0444, bt_debugfs, + NULL, &sco_debugfs_fops); + + return 0; + +error: + proto_unregister(&sco_proto); + return err; +} + +void sco_exit(void) +{ + bt_procfs_cleanup(&init_net, "sco"); + + debugfs_remove(sco_debugfs); + + hci_unregister_cb(&sco_cb); + + 
bt_sock_unregister(BTPROTO_SCO); + + proto_unregister(&sco_proto); +} + +module_param(disable_esco, bool, 0644); +MODULE_PARM_DESC(disable_esco, "Disable eSCO connection creation"); diff --git a/net/bluetooth/selftest.c b/net/bluetooth/selftest.c new file mode 100644 index 000000000..f49604d44 --- /dev/null +++ b/net/bluetooth/selftest.c @@ -0,0 +1,309 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include + +#include +#include + +#include "ecdh_helper.h" +#include "smp.h" +#include "selftest.h" + +#if IS_ENABLED(CONFIG_BT_SELFTEST_ECDH) + +static const u8 priv_a_1[32] __initconst = { + 0xbd, 0x1a, 0x3c, 0xcd, 0xa6, 0xb8, 0x99, 0x58, + 0x99, 0xb7, 0x40, 0xeb, 0x7b, 0x60, 0xff, 0x4a, + 0x50, 0x3f, 0x10, 0xd2, 0xe3, 0xb3, 0xc9, 0x74, + 0x38, 0x5f, 0xc5, 0xa3, 0xd4, 0xf6, 0x49, 0x3f, +}; +static const u8 priv_b_1[32] __initconst = { + 0xfd, 0xc5, 0x7f, 0xf4, 0x49, 0xdd, 0x4f, 0x6b, + 0xfb, 0x7c, 0x9d, 0xf1, 0xc2, 0x9a, 0xcb, 0x59, + 0x2a, 0xe7, 0xd4, 0xee, 0xfb, 0xfc, 0x0a, 0x90, + 0x9a, 0xbb, 0xf6, 0x32, 0x3d, 0x8b, 0x18, 0x55, +}; +static const u8 pub_a_1[64] __initconst = { + 0xe6, 0x9d, 0x35, 0x0e, 0x48, 0x01, 0x03, 0xcc, + 0xdb, 0xfd, 0xf4, 0xac, 0x11, 0x91, 0xf4, 0xef, + 0xb9, 0xa5, 0xf9, 0xe9, 0xa7, 0x83, 0x2c, 0x5e, + 0x2c, 0xbe, 0x97, 0xf2, 0xd2, 0x03, 0xb0, 0x20, + + 0x8b, 0xd2, 0x89, 0x15, 0xd0, 0x8e, 0x1c, 0x74, + 0x24, 0x30, 0xed, 0x8f, 0xc2, 0x45, 0x63, 0x76, + 0x5c, 0x15, 0x52, 0x5a, 0xbf, 0x9a, 0x32, 0x63, + 0x6d, 0xeb, 0x2a, 0x65, 0x49, 0x9c, 0x80, 0xdc, +}; +static const u8 pub_b_1[64] __initconst = { + 0x90, 0xa1, 0xaa, 0x2f, 0xb2, 0x77, 0x90, 0x55, + 0x9f, 0xa6, 0x15, 0x86, 0xfd, 0x8a, 0xb5, 0x47, + 0x00, 0x4c, 0x9e, 0xf1, 0x84, 0x22, 0x59, 0x09, + 0x96, 0x1d, 0xaf, 0x1f, 0xf0, 0xf0, 0xa1, 0x1e, + + 0x4a, 0x21, 0xb1, 0x15, 0xf9, 0xaf, 0x89, 0x5f, + 0x76, 0x36, 0x8e, 0xe2, 0x30, 0x11, 0x2d, 0x47, + 0x60, 0x51, 0xb8, 0x9a, 0x3a, 0x70, 0x56, 0x73, + 0x37, 0xad, 0x9d, 0x42, 0x3e, 0xf3, 0x55, 0x4c, +}; +static const u8 dhkey_1[32] __initconst = { + 0x98, 0xa6, 0xbf, 0x73, 0xf3, 0x34, 0x8d, 0x86, + 0xf1, 0x66, 0xf8, 0xb4, 0x13, 0x6b, 0x79, 0x99, + 0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34, + 0x05, 0xad, 0xc8, 0x57, 0xa3, 0x34, 0x02, 0xec, +}; + +static const u8 priv_a_2[32] __initconst = { + 0x63, 0x76, 0x45, 0xd0, 0xf7, 0x73, 0xac, 0xb7, + 0xff, 0xdd, 0x03, 0x72, 0xb9, 0x72, 0x85, 0xb4, + 0x41, 0xb6, 0x5d, 0x0c, 0x5d, 0x54, 0x84, 0x60, + 0x1a, 0xa3, 0x9a, 0x3c, 0x69, 0x16, 0xa5, 0x06, +}; +static const u8 priv_b_2[32] __initconst = { + 0xba, 0x30, 0x55, 0x50, 0x19, 0xa2, 0xca, 0xa3, + 0xa5, 
0x29, 0x08, 0xc6, 0xb5, 0x03, 0x88, 0x7e, + 0x03, 0x2b, 0x50, 0x73, 0xd4, 0x2e, 0x50, 0x97, + 0x64, 0xcd, 0x72, 0x0d, 0x67, 0xa0, 0x9a, 0x52, +}; +static const u8 pub_a_2[64] __initconst = { + 0xdd, 0x78, 0x5c, 0x74, 0x03, 0x9b, 0x7e, 0x98, + 0xcb, 0x94, 0x87, 0x4a, 0xad, 0xfa, 0xf8, 0xd5, + 0x43, 0x3e, 0x5c, 0xaf, 0xea, 0xb5, 0x4c, 0xf4, + 0x9e, 0x80, 0x79, 0x57, 0x7b, 0xa4, 0x31, 0x2c, + + 0x4f, 0x5d, 0x71, 0x43, 0x77, 0x43, 0xf8, 0xea, + 0xd4, 0x3e, 0xbd, 0x17, 0x91, 0x10, 0x21, 0xd0, + 0x1f, 0x87, 0x43, 0x8e, 0x40, 0xe2, 0x52, 0xcd, + 0xbe, 0xdf, 0x98, 0x38, 0x18, 0x12, 0x95, 0x91, +}; +static const u8 pub_b_2[64] __initconst = { + 0xcc, 0x00, 0x65, 0xe1, 0xf5, 0x6c, 0x0d, 0xcf, + 0xec, 0x96, 0x47, 0x20, 0x66, 0xc9, 0xdb, 0x84, + 0x81, 0x75, 0xa8, 0x4d, 0xc0, 0xdf, 0xc7, 0x9d, + 0x1b, 0x3f, 0x3d, 0xf2, 0x3f, 0xe4, 0x65, 0xf4, + + 0x79, 0xb2, 0xec, 0xd8, 0xca, 0x55, 0xa1, 0xa8, + 0x43, 0x4d, 0x6b, 0xca, 0x10, 0xb0, 0xc2, 0x01, + 0xc2, 0x33, 0x4e, 0x16, 0x24, 0xc4, 0xef, 0xee, + 0x99, 0xd8, 0xbb, 0xbc, 0x48, 0xd0, 0x01, 0x02, +}; +static const u8 dhkey_2[32] __initconst = { + 0x69, 0xeb, 0x21, 0x32, 0xf2, 0xc6, 0x05, 0x41, + 0x60, 0x19, 0xcd, 0x5e, 0x94, 0xe1, 0xe6, 0x5f, + 0x33, 0x07, 0xe3, 0x38, 0x4b, 0x68, 0xe5, 0x62, + 0x3f, 0x88, 0x6d, 0x2f, 0x3a, 0x84, 0x85, 0xab, +}; + +static const u8 priv_a_3[32] __initconst = { + 0xbd, 0x1a, 0x3c, 0xcd, 0xa6, 0xb8, 0x99, 0x58, + 0x99, 0xb7, 0x40, 0xeb, 0x7b, 0x60, 0xff, 0x4a, + 0x50, 0x3f, 0x10, 0xd2, 0xe3, 0xb3, 0xc9, 0x74, + 0x38, 0x5f, 0xc5, 0xa3, 0xd4, 0xf6, 0x49, 0x3f, +}; +static const u8 pub_a_3[64] __initconst = { + 0xe6, 0x9d, 0x35, 0x0e, 0x48, 0x01, 0x03, 0xcc, + 0xdb, 0xfd, 0xf4, 0xac, 0x11, 0x91, 0xf4, 0xef, + 0xb9, 0xa5, 0xf9, 0xe9, 0xa7, 0x83, 0x2c, 0x5e, + 0x2c, 0xbe, 0x97, 0xf2, 0xd2, 0x03, 0xb0, 0x20, + + 0x8b, 0xd2, 0x89, 0x15, 0xd0, 0x8e, 0x1c, 0x74, + 0x24, 0x30, 0xed, 0x8f, 0xc2, 0x45, 0x63, 0x76, + 0x5c, 0x15, 0x52, 0x5a, 0xbf, 0x9a, 0x32, 0x63, + 0x6d, 0xeb, 0x2a, 0x65, 0x49, 0x9c, 0x80, 0xdc, +}; +static const u8 dhkey_3[32] __initconst = { + 0x2d, 0xab, 0x00, 0x48, 0xcb, 0xb3, 0x7b, 0xda, + 0x55, 0x7b, 0x8b, 0x72, 0xa8, 0x57, 0x87, 0xc3, + 0x87, 0x27, 0x99, 0x32, 0xfc, 0x79, 0x5f, 0xae, + 0x7c, 0x1c, 0xf9, 0x49, 0xe6, 0xd7, 0xaa, 0x70, +}; + +static int __init test_ecdh_sample(struct crypto_kpp *tfm, const u8 priv_a[32], + const u8 priv_b[32], const u8 pub_a[64], + const u8 pub_b[64], const u8 dhkey[32]) +{ + u8 *tmp, *dhkey_a, *dhkey_b; + int ret; + + tmp = kmalloc(64, GFP_KERNEL); + if (!tmp) + return -EINVAL; + + dhkey_a = &tmp[0]; + dhkey_b = &tmp[32]; + + ret = set_ecdh_privkey(tfm, priv_a); + if (ret) + goto out; + + ret = compute_ecdh_secret(tfm, pub_b, dhkey_a); + if (ret) + goto out; + + if (memcmp(dhkey_a, dhkey, 32)) { + ret = -EINVAL; + goto out; + } + + ret = set_ecdh_privkey(tfm, priv_b); + if (ret) + goto out; + + ret = compute_ecdh_secret(tfm, pub_a, dhkey_b); + if (ret) + goto out; + + if (memcmp(dhkey_b, dhkey, 32)) + ret = -EINVAL; + /* fall through*/ +out: + kfree(tmp); + return ret; +} + +static char test_ecdh_buffer[32]; + +static ssize_t test_ecdh_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + return simple_read_from_buffer(user_buf, count, ppos, test_ecdh_buffer, + strlen(test_ecdh_buffer)); +} + +static const struct file_operations test_ecdh_fops = { + .open = simple_open, + .read = test_ecdh_read, + .llseek = default_llseek, +}; + +static int __init test_ecdh(void) +{ + struct crypto_kpp *tfm; + ktime_t calltime, delta, rettime; + unsigned long long duration = 
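/*
 * Editor's illustrative sketch, not part of this patch: reading the result
 * string exposed through test_ecdh_read() above. It assumes debugfs is
 * mounted at /sys/kernel/debug; the file contains either "PASS (...)" or
 * "FAIL".
 */
#include <stdio.h>

int main(void)
{
        char buf[32] = "";
        FILE *f = fopen("/sys/kernel/debug/bluetooth/selftest_ecdh", "r");

        if (!f)
                return 1;

        if (fgets(buf, sizeof(buf), f))
                printf("ECDH selftest: %s", buf);

        fclose(f);
        return 0;
}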
0; + int err; + + calltime = ktime_get(); + + tfm = crypto_alloc_kpp("ecdh-nist-p256", 0, 0); + if (IS_ERR(tfm)) { + BT_ERR("Unable to create ECDH crypto context"); + err = PTR_ERR(tfm); + goto done; + } + + err = test_ecdh_sample(tfm, priv_a_1, priv_b_1, pub_a_1, pub_b_1, + dhkey_1); + if (err) { + BT_ERR("ECDH sample 1 failed"); + goto done; + } + + err = test_ecdh_sample(tfm, priv_a_2, priv_b_2, pub_a_2, pub_b_2, + dhkey_2); + if (err) { + BT_ERR("ECDH sample 2 failed"); + goto done; + } + + err = test_ecdh_sample(tfm, priv_a_3, priv_a_3, pub_a_3, pub_a_3, + dhkey_3); + if (err) { + BT_ERR("ECDH sample 3 failed"); + goto done; + } + + crypto_free_kpp(tfm); + + rettime = ktime_get(); + delta = ktime_sub(rettime, calltime); + duration = (unsigned long long) ktime_to_ns(delta) >> 10; + + BT_INFO("ECDH test passed in %llu usecs", duration); + +done: + if (!err) + snprintf(test_ecdh_buffer, sizeof(test_ecdh_buffer), + "PASS (%llu usecs)\n", duration); + else + snprintf(test_ecdh_buffer, sizeof(test_ecdh_buffer), "FAIL\n"); + + debugfs_create_file("selftest_ecdh", 0444, bt_debugfs, NULL, + &test_ecdh_fops); + + return err; +} + +#else + +static inline int test_ecdh(void) +{ + return 0; +} + +#endif + +static int __init run_selftest(void) +{ + int err; + + BT_INFO("Starting self testing"); + + err = test_ecdh(); + if (err) + goto done; + + err = bt_selftest_smp(); + +done: + BT_INFO("Finished self testing"); + + return err; +} + +#if IS_MODULE(CONFIG_BT) + +/* This is run when CONFIG_BT_SELFTEST=y and CONFIG_BT=m and is just a + * wrapper to allow running this at module init. + * + * If CONFIG_BT_SELFTEST=n, then this code is not compiled at all. + */ +int __init bt_selftest(void) +{ + return run_selftest(); +} + +#else + +/* This is run when CONFIG_BT_SELFTEST=y and CONFIG_BT=y and is run + * via late_initcall() as last item in the initialization sequence. + * + * If CONFIG_BT_SELFTEST=n, then this code is not compiled at all. + */ +static int __init bt_selftest_init(void) +{ + return run_selftest(); +} +late_initcall(bt_selftest_init); + +#endif diff --git a/net/bluetooth/selftest.h b/net/bluetooth/selftest.h new file mode 100644 index 000000000..2aa0a346a --- /dev/null +++ b/net/bluetooth/selftest.h @@ -0,0 +1,45 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2014 Intel Corporation + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#if IS_ENABLED(CONFIG_BT_SELFTEST) && IS_MODULE(CONFIG_BT) + +/* When CONFIG_BT_SELFTEST=y and the CONFIG_BT=m, then the self testing + * is run at module loading time. 
+ */ +int bt_selftest(void); + +#else + +/* When CONFIG_BT_SELFTEST=y and CONFIG_BT=y, then the self testing + * is run via late_initcall() to make sure that subsys_initcall() of + * the Bluetooth subsystem and device_initcall() of the Crypto subsystem + * do not clash. + * + * When CONFIG_BT_SELFTEST=n, then this turns into an empty call that + * has no impact. + */ +static inline int bt_selftest(void) +{ + return 0; +} + +#endif diff --git a/net/bluetooth/smp.c b/net/bluetooth/smp.c new file mode 100644 index 000000000..ecb005bce --- /dev/null +++ b/net/bluetooth/smp.c @@ -0,0 +1,3848 @@ +/* + BlueZ - Bluetooth protocol stack for Linux + Copyright (C) 2011 Nokia Corporation and/or its subsidiary(-ies). + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "ecdh_helper.h" +#include "smp.h" + +#define SMP_DEV(hdev) \ + ((struct smp_dev *)((struct l2cap_chan *)((hdev)->smp_data))->data) + +/* Low-level debug macros to be used for stuff that we don't want + * accidentally in dmesg, i.e. the values of the various crypto keys + * and the inputs & outputs of crypto functions. + */ +#ifdef DEBUG +#define SMP_DBG(fmt, ...) printk(KERN_DEBUG "%s: " fmt, __func__, \ + ##__VA_ARGS__) +#else +#define SMP_DBG(fmt, ...) no_printk(KERN_DEBUG "%s: " fmt, __func__, \ + ##__VA_ARGS__) +#endif + +#define SMP_ALLOW_CMD(smp, code) set_bit(code, &smp->allow_cmd) + +/* Keys which are not distributed with Secure Connections */ +#define SMP_SC_NO_DIST (SMP_DIST_ENC_KEY | SMP_DIST_LINK_KEY) + +#define SMP_TIMEOUT msecs_to_jiffies(30000) + +#define AUTH_REQ_MASK(dev) (hci_dev_test_flag(dev, HCI_SC_ENABLED) ? 
\ + 0x3f : 0x07) +#define KEY_DIST_MASK 0x07 + +/* Maximum message length that can be passed to aes_cmac */ +#define CMAC_MSG_MAX 80 + +enum { + SMP_FLAG_TK_VALID, + SMP_FLAG_CFM_PENDING, + SMP_FLAG_MITM_AUTH, + SMP_FLAG_COMPLETE, + SMP_FLAG_INITIATOR, + SMP_FLAG_SC, + SMP_FLAG_REMOTE_PK, + SMP_FLAG_DEBUG_KEY, + SMP_FLAG_WAIT_USER, + SMP_FLAG_DHKEY_PENDING, + SMP_FLAG_REMOTE_OOB, + SMP_FLAG_LOCAL_OOB, + SMP_FLAG_CT2, +}; + +struct smp_dev { + /* Secure Connections OOB data */ + bool local_oob; + u8 local_pk[64]; + u8 local_rand[16]; + bool debug_key; + + struct crypto_shash *tfm_cmac; + struct crypto_kpp *tfm_ecdh; +}; + +struct smp_chan { + struct l2cap_conn *conn; + struct delayed_work security_timer; + unsigned long allow_cmd; /* Bitmask of allowed commands */ + + u8 preq[7]; /* SMP Pairing Request */ + u8 prsp[7]; /* SMP Pairing Response */ + u8 prnd[16]; /* SMP Pairing Random (local) */ + u8 rrnd[16]; /* SMP Pairing Random (remote) */ + u8 pcnf[16]; /* SMP Pairing Confirm */ + u8 tk[16]; /* SMP Temporary Key */ + u8 rr[16]; /* Remote OOB ra/rb value */ + u8 lr[16]; /* Local OOB ra/rb value */ + u8 enc_key_size; + u8 remote_key_dist; + bdaddr_t id_addr; + u8 id_addr_type; + u8 irk[16]; + struct smp_csrk *csrk; + struct smp_csrk *responder_csrk; + struct smp_ltk *ltk; + struct smp_ltk *responder_ltk; + struct smp_irk *remote_irk; + u8 *link_key; + unsigned long flags; + u8 method; + u8 passkey_round; + + /* Secure Connections variables */ + u8 local_pk[64]; + u8 remote_pk[64]; + u8 dhkey[32]; + u8 mackey[16]; + + struct crypto_shash *tfm_cmac; + struct crypto_kpp *tfm_ecdh; +}; + +/* These debug key values are defined in the SMP section of the core + * specification. debug_pk is the public debug key and debug_sk the + * private debug key. + */ +static const u8 debug_pk[64] = { + 0xe6, 0x9d, 0x35, 0x0e, 0x48, 0x01, 0x03, 0xcc, + 0xdb, 0xfd, 0xf4, 0xac, 0x11, 0x91, 0xf4, 0xef, + 0xb9, 0xa5, 0xf9, 0xe9, 0xa7, 0x83, 0x2c, 0x5e, + 0x2c, 0xbe, 0x97, 0xf2, 0xd2, 0x03, 0xb0, 0x20, + + 0x8b, 0xd2, 0x89, 0x15, 0xd0, 0x8e, 0x1c, 0x74, + 0x24, 0x30, 0xed, 0x8f, 0xc2, 0x45, 0x63, 0x76, + 0x5c, 0x15, 0x52, 0x5a, 0xbf, 0x9a, 0x32, 0x63, + 0x6d, 0xeb, 0x2a, 0x65, 0x49, 0x9c, 0x80, 0xdc, +}; + +static const u8 debug_sk[32] = { + 0xbd, 0x1a, 0x3c, 0xcd, 0xa6, 0xb8, 0x99, 0x58, + 0x99, 0xb7, 0x40, 0xeb, 0x7b, 0x60, 0xff, 0x4a, + 0x50, 0x3f, 0x10, 0xd2, 0xe3, 0xb3, 0xc9, 0x74, + 0x38, 0x5f, 0xc5, 0xa3, 0xd4, 0xf6, 0x49, 0x3f, +}; + +static inline void swap_buf(const u8 *src, u8 *dst, size_t len) +{ + size_t i; + + for (i = 0; i < len; i++) + dst[len - 1 - i] = src[i]; +} + +/* The following functions map to the LE SC SMP crypto functions + * AES-CMAC, f4, f5, f6, g2 and h6. 
+ */ + +static int aes_cmac(struct crypto_shash *tfm, const u8 k[16], const u8 *m, + size_t len, u8 mac[16]) +{ + uint8_t tmp[16], mac_msb[16], msg_msb[CMAC_MSG_MAX]; + int err; + + if (len > CMAC_MSG_MAX) + return -EFBIG; + + if (!tfm) { + BT_ERR("tfm %p", tfm); + return -EINVAL; + } + + /* Swap key and message from LSB to MSB */ + swap_buf(k, tmp, 16); + swap_buf(m, msg_msb, len); + + SMP_DBG("msg (len %zu) %*phN", len, (int) len, m); + SMP_DBG("key %16phN", k); + + err = crypto_shash_setkey(tfm, tmp, 16); + if (err) { + BT_ERR("cipher setkey failed: %d", err); + return err; + } + + err = crypto_shash_tfm_digest(tfm, msg_msb, len, mac_msb); + if (err) { + BT_ERR("Hash computation error %d", err); + return err; + } + + swap_buf(mac_msb, mac, 16); + + SMP_DBG("mac %16phN", mac); + + return 0; +} + +static int smp_f4(struct crypto_shash *tfm_cmac, const u8 u[32], + const u8 v[32], const u8 x[16], u8 z, u8 res[16]) +{ + u8 m[65]; + int err; + + SMP_DBG("u %32phN", u); + SMP_DBG("v %32phN", v); + SMP_DBG("x %16phN z %02x", x, z); + + m[0] = z; + memcpy(m + 1, v, 32); + memcpy(m + 33, u, 32); + + err = aes_cmac(tfm_cmac, x, m, sizeof(m), res); + if (err) + return err; + + SMP_DBG("res %16phN", res); + + return err; +} + +static int smp_f5(struct crypto_shash *tfm_cmac, const u8 w[32], + const u8 n1[16], const u8 n2[16], const u8 a1[7], + const u8 a2[7], u8 mackey[16], u8 ltk[16]) +{ + /* The btle, salt and length "magic" values are as defined in + * the SMP section of the Bluetooth core specification. In ASCII + * the btle value ends up being 'btle'. The salt is just a + * random number whereas length is the value 256 in little + * endian format. + */ + const u8 btle[4] = { 0x65, 0x6c, 0x74, 0x62 }; + const u8 salt[16] = { 0xbe, 0x83, 0x60, 0x5a, 0xdb, 0x0b, 0x37, 0x60, + 0x38, 0xa5, 0xf5, 0xaa, 0x91, 0x83, 0x88, 0x6c }; + const u8 length[2] = { 0x00, 0x01 }; + u8 m[53], t[16]; + int err; + + SMP_DBG("w %32phN", w); + SMP_DBG("n1 %16phN n2 %16phN", n1, n2); + SMP_DBG("a1 %7phN a2 %7phN", a1, a2); + + err = aes_cmac(tfm_cmac, salt, w, 32, t); + if (err) + return err; + + SMP_DBG("t %16phN", t); + + memcpy(m, length, 2); + memcpy(m + 2, a2, 7); + memcpy(m + 9, a1, 7); + memcpy(m + 16, n2, 16); + memcpy(m + 32, n1, 16); + memcpy(m + 48, btle, 4); + + m[52] = 0; /* Counter */ + + err = aes_cmac(tfm_cmac, t, m, sizeof(m), mackey); + if (err) + return err; + + SMP_DBG("mackey %16phN", mackey); + + m[52] = 1; /* Counter */ + + err = aes_cmac(tfm_cmac, t, m, sizeof(m), ltk); + if (err) + return err; + + SMP_DBG("ltk %16phN", ltk); + + return 0; +} + +static int smp_f6(struct crypto_shash *tfm_cmac, const u8 w[16], + const u8 n1[16], const u8 n2[16], const u8 r[16], + const u8 io_cap[3], const u8 a1[7], const u8 a2[7], + u8 res[16]) +{ + u8 m[65]; + int err; + + SMP_DBG("w %16phN", w); + SMP_DBG("n1 %16phN n2 %16phN", n1, n2); + SMP_DBG("r %16phN io_cap %3phN a1 %7phN a2 %7phN", r, io_cap, a1, a2); + + memcpy(m, a2, 7); + memcpy(m + 7, a1, 7); + memcpy(m + 14, io_cap, 3); + memcpy(m + 17, r, 16); + memcpy(m + 33, n2, 16); + memcpy(m + 49, n1, 16); + + err = aes_cmac(tfm_cmac, w, m, sizeof(m), res); + if (err) + return err; + + SMP_DBG("res %16phN", res); + + return err; +} + +static int smp_g2(struct crypto_shash *tfm_cmac, const u8 u[32], const u8 v[32], + const u8 x[16], const u8 y[16], u32 *val) +{ + u8 m[80], tmp[16]; + int err; + + SMP_DBG("u %32phN", u); + SMP_DBG("v %32phN", v); + SMP_DBG("x %16phN y %16phN", x, y); + + memcpy(m, y, 16); + memcpy(m + 16, v, 32); + memcpy(m + 48, u, 32); + + err 
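/*
 * Editor's illustrative sketch, not part of this patch: the last step of
 * smp_g2(), whose body continues just below. The CMAC output is stored
 * least-significant byte first, so the low 32 bits are read little-endian
 * and reduced mod 10^6 to give the six-digit numeric comparison value that
 * both devices display. The sample bytes here are arbitrary.
 */
#include <stdint.h>
#include <stdio.h>

static uint32_t g2_user_value(const uint8_t tmp[16])
{
        uint32_t val = tmp[0] | (tmp[1] << 8) | (tmp[2] << 16) |
                       ((uint32_t)tmp[3] << 24);

        return val % 1000000;
}

int main(void)
{
        const uint8_t tmp[16] = { 0x12, 0x34, 0x56, 0x78 };

        printf("%06u\n", g2_user_value(tmp));   /* six-digit passkey */
        return 0;
}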
= aes_cmac(tfm_cmac, x, m, sizeof(m), tmp); + if (err) + return err; + + *val = get_unaligned_le32(tmp); + *val %= 1000000; + + SMP_DBG("val %06u", *val); + + return 0; +} + +static int smp_h6(struct crypto_shash *tfm_cmac, const u8 w[16], + const u8 key_id[4], u8 res[16]) +{ + int err; + + SMP_DBG("w %16phN key_id %4phN", w, key_id); + + err = aes_cmac(tfm_cmac, w, key_id, 4, res); + if (err) + return err; + + SMP_DBG("res %16phN", res); + + return err; +} + +static int smp_h7(struct crypto_shash *tfm_cmac, const u8 w[16], + const u8 salt[16], u8 res[16]) +{ + int err; + + SMP_DBG("w %16phN salt %16phN", w, salt); + + err = aes_cmac(tfm_cmac, salt, w, 16, res); + if (err) + return err; + + SMP_DBG("res %16phN", res); + + return err; +} + +/* The following functions map to the legacy SMP crypto functions e, c1, + * s1 and ah. + */ + +static int smp_e(const u8 *k, u8 *r) +{ + struct crypto_aes_ctx ctx; + uint8_t tmp[16], data[16]; + int err; + + SMP_DBG("k %16phN r %16phN", k, r); + + /* The most significant octet of key corresponds to k[0] */ + swap_buf(k, tmp, 16); + + err = aes_expandkey(&ctx, tmp, 16); + if (err) { + BT_ERR("cipher setkey failed: %d", err); + return err; + } + + /* Most significant octet of plaintextData corresponds to data[0] */ + swap_buf(r, data, 16); + + aes_encrypt(&ctx, data, data); + + /* Most significant octet of encryptedData corresponds to data[0] */ + swap_buf(data, r, 16); + + SMP_DBG("r %16phN", r); + + memzero_explicit(&ctx, sizeof(ctx)); + return err; +} + +static int smp_c1(const u8 k[16], + const u8 r[16], const u8 preq[7], const u8 pres[7], u8 _iat, + const bdaddr_t *ia, u8 _rat, const bdaddr_t *ra, u8 res[16]) +{ + u8 p1[16], p2[16]; + int err; + + SMP_DBG("k %16phN r %16phN", k, r); + SMP_DBG("iat %u ia %6phN rat %u ra %6phN", _iat, ia, _rat, ra); + SMP_DBG("preq %7phN pres %7phN", preq, pres); + + memset(p1, 0, 16); + + /* p1 = pres || preq || _rat || _iat */ + p1[0] = _iat; + p1[1] = _rat; + memcpy(p1 + 2, preq, 7); + memcpy(p1 + 9, pres, 7); + + SMP_DBG("p1 %16phN", p1); + + /* res = r XOR p1 */ + crypto_xor_cpy(res, r, p1, sizeof(p1)); + + /* res = e(k, res) */ + err = smp_e(k, res); + if (err) { + BT_ERR("Encrypt data error"); + return err; + } + + /* p2 = padding || ia || ra */ + memcpy(p2, ra, 6); + memcpy(p2 + 6, ia, 6); + memset(p2 + 12, 0, 4); + + SMP_DBG("p2 %16phN", p2); + + /* res = res XOR p2 */ + crypto_xor(res, p2, sizeof(p2)); + + /* res = e(k, res) */ + err = smp_e(k, res); + if (err) + BT_ERR("Encrypt data error"); + + return err; +} + +static int smp_s1(const u8 k[16], + const u8 r1[16], const u8 r2[16], u8 _r[16]) +{ + int err; + + /* Just least significant octets from r1 and r2 are considered */ + memcpy(_r, r2, 8); + memcpy(_r + 8, r1, 8); + + err = smp_e(k, _r); + if (err) + BT_ERR("Encrypt data error"); + + return err; +} + +static int smp_ah(const u8 irk[16], const u8 r[3], u8 res[3]) +{ + u8 _res[16]; + int err; + + /* r' = padding || r */ + memcpy(_res, r, 3); + memset(_res + 3, 0, 13); + + err = smp_e(irk, _res); + if (err) { + BT_ERR("Encrypt error"); + return err; + } + + /* The output of the random address function ah is: + * ah(k, r) = e(k, r') mod 2^24 + * The output of the security function e is then truncated to 24 bits + * by taking the least significant 24 bits of the output of e as the + * result of ah. 
+ */ + memcpy(res, _res, 3); + + return 0; +} + +bool smp_irk_matches(struct hci_dev *hdev, const u8 irk[16], + const bdaddr_t *bdaddr) +{ + struct l2cap_chan *chan = hdev->smp_data; + u8 hash[3]; + int err; + + if (!chan || !chan->data) + return false; + + bt_dev_dbg(hdev, "RPA %pMR IRK %*phN", bdaddr, 16, irk); + + err = smp_ah(irk, &bdaddr->b[3], hash); + if (err) + return false; + + return !crypto_memneq(bdaddr->b, hash, 3); +} + +int smp_generate_rpa(struct hci_dev *hdev, const u8 irk[16], bdaddr_t *rpa) +{ + struct l2cap_chan *chan = hdev->smp_data; + int err; + + if (!chan || !chan->data) + return -EOPNOTSUPP; + + get_random_bytes(&rpa->b[3], 3); + + rpa->b[5] &= 0x3f; /* Clear two most significant bits */ + rpa->b[5] |= 0x40; /* Set second most significant bit */ + + err = smp_ah(irk, &rpa->b[3], rpa->b); + if (err < 0) + return err; + + bt_dev_dbg(hdev, "RPA %pMR", rpa); + + return 0; +} + +int smp_generate_oob(struct hci_dev *hdev, u8 hash[16], u8 rand[16]) +{ + struct l2cap_chan *chan = hdev->smp_data; + struct smp_dev *smp; + int err; + + if (!chan || !chan->data) + return -EOPNOTSUPP; + + smp = chan->data; + + if (hci_dev_test_flag(hdev, HCI_USE_DEBUG_KEYS)) { + bt_dev_dbg(hdev, "Using debug keys"); + err = set_ecdh_privkey(smp->tfm_ecdh, debug_sk); + if (err) + return err; + memcpy(smp->local_pk, debug_pk, 64); + smp->debug_key = true; + } else { + while (true) { + /* Generate key pair for Secure Connections */ + err = generate_ecdh_keys(smp->tfm_ecdh, smp->local_pk); + if (err) + return err; + + /* This is unlikely, but we need to check that + * we didn't accidentally generate a debug key. + */ + if (crypto_memneq(smp->local_pk, debug_pk, 64)) + break; + } + smp->debug_key = false; + } + + SMP_DBG("OOB Public Key X: %32phN", smp->local_pk); + SMP_DBG("OOB Public Key Y: %32phN", smp->local_pk + 32); + + get_random_bytes(smp->local_rand, 16); + + err = smp_f4(smp->tfm_cmac, smp->local_pk, smp->local_pk, + smp->local_rand, 0, hash); + if (err < 0) + return err; + + memcpy(rand, smp->local_rand, 16); + + smp->local_oob = true; + + return 0; +} + +static void smp_send_cmd(struct l2cap_conn *conn, u8 code, u16 len, void *data) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp; + struct kvec iv[2]; + struct msghdr msg; + + if (!chan) + return; + + bt_dev_dbg(conn->hcon->hdev, "code 0x%2.2x", code); + + iv[0].iov_base = &code; + iv[0].iov_len = 1; + + iv[1].iov_base = data; + iv[1].iov_len = len; + + memset(&msg, 0, sizeof(msg)); + + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, iv, 2, 1 + len); + + l2cap_chan_send(chan, &msg, 1 + len); + + if (!chan->data) + return; + + smp = chan->data; + + cancel_delayed_work_sync(&smp->security_timer); + schedule_delayed_work(&smp->security_timer, SMP_TIMEOUT); +} + +static u8 authreq_to_seclevel(u8 authreq) +{ + if (authreq & SMP_AUTH_MITM) { + if (authreq & SMP_AUTH_SC) + return BT_SECURITY_FIPS; + else + return BT_SECURITY_HIGH; + } else { + return BT_SECURITY_MEDIUM; + } +} + +static __u8 seclevel_to_authreq(__u8 sec_level) +{ + switch (sec_level) { + case BT_SECURITY_FIPS: + case BT_SECURITY_HIGH: + return SMP_AUTH_MITM | SMP_AUTH_BONDING; + case BT_SECURITY_MEDIUM: + return SMP_AUTH_BONDING; + default: + return SMP_AUTH_NONE; + } +} + +static void build_pairing_cmd(struct l2cap_conn *conn, + struct smp_cmd_pairing *req, + struct smp_cmd_pairing *rsp, __u8 authreq) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + u8 
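/*
 * Editor's illustrative sketch, not part of this patch: the address layout
 * that smp_generate_rpa() and smp_irk_matches() above rely on. A resolvable
 * private address keeps the 3-byte prand in b[3..5] with its two top bits
 * forced to 0b01, and ah(irk, prand) in b[0..2]. The AES-based ah() itself
 * is left out; the sample bytes are arbitrary.
 */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

static bool rpa_layout_ok(const uint8_t b[6])
{
        /* second most significant bit set, most significant bit clear */
        return (b[5] & 0xc0) == 0x40;
}

static bool rpa_hash_matches(const uint8_t b[6], const uint8_t hash[3])
{
        return rpa_layout_ok(b) && memcmp(b, hash, 3) == 0;
}

int main(void)
{
        const uint8_t rpa[6]  = { 0x11, 0x22, 0x33, 0xaa, 0xbb, 0x7c };
        const uint8_t hash[3] = { 0x11, 0x22, 0x33 };   /* pretend ah() output */

        return rpa_hash_matches(rpa, hash) ? 0 : 1;
}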
local_dist = 0, remote_dist = 0, oob_flag = SMP_OOB_NOT_PRESENT; + + if (hci_dev_test_flag(hdev, HCI_BONDABLE)) { + local_dist = SMP_DIST_ENC_KEY | SMP_DIST_SIGN; + remote_dist = SMP_DIST_ENC_KEY | SMP_DIST_SIGN; + authreq |= SMP_AUTH_BONDING; + } else { + authreq &= ~SMP_AUTH_BONDING; + } + + if (hci_dev_test_flag(hdev, HCI_RPA_RESOLVING)) + remote_dist |= SMP_DIST_ID_KEY; + + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) + local_dist |= SMP_DIST_ID_KEY; + + if (hci_dev_test_flag(hdev, HCI_SC_ENABLED) && + (authreq & SMP_AUTH_SC)) { + struct oob_data *oob_data; + u8 bdaddr_type; + + if (hci_dev_test_flag(hdev, HCI_SSP_ENABLED)) { + local_dist |= SMP_DIST_LINK_KEY; + remote_dist |= SMP_DIST_LINK_KEY; + } + + if (hcon->dst_type == ADDR_LE_DEV_PUBLIC) + bdaddr_type = BDADDR_LE_PUBLIC; + else + bdaddr_type = BDADDR_LE_RANDOM; + + oob_data = hci_find_remote_oob_data(hdev, &hcon->dst, + bdaddr_type); + if (oob_data && oob_data->present) { + set_bit(SMP_FLAG_REMOTE_OOB, &smp->flags); + oob_flag = SMP_OOB_PRESENT; + memcpy(smp->rr, oob_data->rand256, 16); + memcpy(smp->pcnf, oob_data->hash256, 16); + SMP_DBG("OOB Remote Confirmation: %16phN", smp->pcnf); + SMP_DBG("OOB Remote Random: %16phN", smp->rr); + } + + } else { + authreq &= ~SMP_AUTH_SC; + } + + if (rsp == NULL) { + req->io_capability = conn->hcon->io_capability; + req->oob_flag = oob_flag; + req->max_key_size = hdev->le_max_key_size; + req->init_key_dist = local_dist; + req->resp_key_dist = remote_dist; + req->auth_req = (authreq & AUTH_REQ_MASK(hdev)); + + smp->remote_key_dist = remote_dist; + return; + } + + rsp->io_capability = conn->hcon->io_capability; + rsp->oob_flag = oob_flag; + rsp->max_key_size = hdev->le_max_key_size; + rsp->init_key_dist = req->init_key_dist & remote_dist; + rsp->resp_key_dist = req->resp_key_dist & local_dist; + rsp->auth_req = (authreq & AUTH_REQ_MASK(hdev)); + + smp->remote_key_dist = rsp->init_key_dist; +} + +static u8 check_enc_key_size(struct l2cap_conn *conn, __u8 max_key_size) +{ + struct l2cap_chan *chan = conn->smp; + struct hci_dev *hdev = conn->hcon->hdev; + struct smp_chan *smp = chan->data; + + if (conn->hcon->pending_sec_level == BT_SECURITY_FIPS && + max_key_size != SMP_MAX_ENC_KEY_SIZE) + return SMP_ENC_KEY_SIZE; + + if (max_key_size > hdev->le_max_key_size || + max_key_size < SMP_MIN_ENC_KEY_SIZE) + return SMP_ENC_KEY_SIZE; + + smp->enc_key_size = max_key_size; + + return 0; +} + +static void smp_chan_destroy(struct l2cap_conn *conn) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + bool complete; + + BUG_ON(!smp); + + cancel_delayed_work_sync(&smp->security_timer); + + complete = test_bit(SMP_FLAG_COMPLETE, &smp->flags); + mgmt_smp_complete(hcon, complete); + + kfree_sensitive(smp->csrk); + kfree_sensitive(smp->responder_csrk); + kfree_sensitive(smp->link_key); + + crypto_free_shash(smp->tfm_cmac); + crypto_free_kpp(smp->tfm_ecdh); + + /* Ensure that we don't leave any debug key around if debug key + * support hasn't been explicitly enabled. 
+ */ + if (smp->ltk && smp->ltk->type == SMP_LTK_P256_DEBUG && + !hci_dev_test_flag(hcon->hdev, HCI_KEEP_DEBUG_KEYS)) { + list_del_rcu(&smp->ltk->list); + kfree_rcu(smp->ltk, rcu); + smp->ltk = NULL; + } + + /* If pairing failed clean up any keys we might have */ + if (!complete) { + if (smp->ltk) { + list_del_rcu(&smp->ltk->list); + kfree_rcu(smp->ltk, rcu); + } + + if (smp->responder_ltk) { + list_del_rcu(&smp->responder_ltk->list); + kfree_rcu(smp->responder_ltk, rcu); + } + + if (smp->remote_irk) { + list_del_rcu(&smp->remote_irk->list); + kfree_rcu(smp->remote_irk, rcu); + } + } + + chan->data = NULL; + kfree_sensitive(smp); + hci_conn_drop(hcon); +} + +static void smp_failure(struct l2cap_conn *conn, u8 reason) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan = conn->smp; + + if (reason) + smp_send_cmd(conn, SMP_CMD_PAIRING_FAIL, sizeof(reason), + &reason); + + mgmt_auth_failed(hcon, HCI_ERROR_AUTH_FAILURE); + + if (chan->data) + smp_chan_destroy(conn); +} + +#define JUST_WORKS 0x00 +#define JUST_CFM 0x01 +#define REQ_PASSKEY 0x02 +#define CFM_PASSKEY 0x03 +#define REQ_OOB 0x04 +#define DSP_PASSKEY 0x05 +#define OVERLAP 0xFF + +static const u8 gen_method[5][5] = { + { JUST_WORKS, JUST_CFM, REQ_PASSKEY, JUST_WORKS, REQ_PASSKEY }, + { JUST_WORKS, JUST_CFM, REQ_PASSKEY, JUST_WORKS, REQ_PASSKEY }, + { CFM_PASSKEY, CFM_PASSKEY, REQ_PASSKEY, JUST_WORKS, CFM_PASSKEY }, + { JUST_WORKS, JUST_CFM, JUST_WORKS, JUST_WORKS, JUST_CFM }, + { CFM_PASSKEY, CFM_PASSKEY, REQ_PASSKEY, JUST_WORKS, OVERLAP }, +}; + +static const u8 sc_method[5][5] = { + { JUST_WORKS, JUST_CFM, REQ_PASSKEY, JUST_WORKS, REQ_PASSKEY }, + { JUST_WORKS, CFM_PASSKEY, REQ_PASSKEY, JUST_WORKS, CFM_PASSKEY }, + { DSP_PASSKEY, DSP_PASSKEY, REQ_PASSKEY, JUST_WORKS, DSP_PASSKEY }, + { JUST_WORKS, JUST_CFM, JUST_WORKS, JUST_WORKS, JUST_CFM }, + { DSP_PASSKEY, CFM_PASSKEY, REQ_PASSKEY, JUST_WORKS, CFM_PASSKEY }, +}; + +static u8 get_auth_method(struct smp_chan *smp, u8 local_io, u8 remote_io) +{ + /* If either side has unknown io_caps, use JUST_CFM (which gets + * converted later to JUST_WORKS if we're initiators. + */ + if (local_io > SMP_IO_KEYBOARD_DISPLAY || + remote_io > SMP_IO_KEYBOARD_DISPLAY) + return JUST_CFM; + + if (test_bit(SMP_FLAG_SC, &smp->flags)) + return sc_method[remote_io][local_io]; + + return gen_method[remote_io][local_io]; +} + +static int tk_request(struct l2cap_conn *conn, u8 remote_oob, u8 auth, + u8 local_io, u8 remote_io) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + u32 passkey = 0; + int ret; + + /* Initialize key for JUST WORKS */ + memset(smp->tk, 0, sizeof(smp->tk)); + clear_bit(SMP_FLAG_TK_VALID, &smp->flags); + + bt_dev_dbg(hcon->hdev, "auth:%u lcl:%u rem:%u", auth, local_io, + remote_io); + + /* If neither side wants MITM, either "just" confirm an incoming + * request or use just-works for outgoing ones. The JUST_CFM + * will be converted to JUST_WORKS if necessary later in this + * function. If either side has MITM look up the method from the + * table. 
+ */ + if (!(auth & SMP_AUTH_MITM)) + smp->method = JUST_CFM; + else + smp->method = get_auth_method(smp, local_io, remote_io); + + /* Don't confirm locally initiated pairing attempts */ + if (smp->method == JUST_CFM && test_bit(SMP_FLAG_INITIATOR, + &smp->flags)) + smp->method = JUST_WORKS; + + /* Don't bother user space with no IO capabilities */ + if (smp->method == JUST_CFM && + hcon->io_capability == HCI_IO_NO_INPUT_OUTPUT) + smp->method = JUST_WORKS; + + /* If Just Works, Continue with Zero TK and ask user-space for + * confirmation */ + if (smp->method == JUST_WORKS) { + ret = mgmt_user_confirm_request(hcon->hdev, &hcon->dst, + hcon->type, + hcon->dst_type, + passkey, 1); + if (ret) + return ret; + set_bit(SMP_FLAG_WAIT_USER, &smp->flags); + return 0; + } + + /* If this function is used for SC -> legacy fallback we + * can only recover the just-works case. + */ + if (test_bit(SMP_FLAG_SC, &smp->flags)) + return -EINVAL; + + /* Not Just Works/Confirm results in MITM Authentication */ + if (smp->method != JUST_CFM) { + set_bit(SMP_FLAG_MITM_AUTH, &smp->flags); + if (hcon->pending_sec_level < BT_SECURITY_HIGH) + hcon->pending_sec_level = BT_SECURITY_HIGH; + } + + /* If both devices have Keyboard-Display I/O, the initiator + * Confirms and the responder Enters the passkey. + */ + if (smp->method == OVERLAP) { + if (hcon->role == HCI_ROLE_MASTER) + smp->method = CFM_PASSKEY; + else + smp->method = REQ_PASSKEY; + } + + /* Generate random passkey. */ + if (smp->method == CFM_PASSKEY) { + memset(smp->tk, 0, sizeof(smp->tk)); + get_random_bytes(&passkey, sizeof(passkey)); + passkey %= 1000000; + put_unaligned_le32(passkey, smp->tk); + bt_dev_dbg(hcon->hdev, "PassKey: %u", passkey); + set_bit(SMP_FLAG_TK_VALID, &smp->flags); + } + + if (smp->method == REQ_PASSKEY) + ret = mgmt_user_passkey_request(hcon->hdev, &hcon->dst, + hcon->type, hcon->dst_type); + else if (smp->method == JUST_CFM) + ret = mgmt_user_confirm_request(hcon->hdev, &hcon->dst, + hcon->type, hcon->dst_type, + passkey, 1); + else + ret = mgmt_user_passkey_notify(hcon->hdev, &hcon->dst, + hcon->type, hcon->dst_type, + passkey, 0); + + return ret; +} + +static u8 smp_confirm(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + struct smp_cmd_pairing_confirm cp; + int ret; + + bt_dev_dbg(conn->hcon->hdev, "conn %p", conn); + + ret = smp_c1(smp->tk, smp->prnd, smp->preq, smp->prsp, + conn->hcon->init_addr_type, &conn->hcon->init_addr, + conn->hcon->resp_addr_type, &conn->hcon->resp_addr, + cp.confirm_val); + if (ret) + return SMP_UNSPECIFIED; + + clear_bit(SMP_FLAG_CFM_PENDING, &smp->flags); + + smp_send_cmd(smp->conn, SMP_CMD_PAIRING_CONFIRM, sizeof(cp), &cp); + + if (conn->hcon->out) + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + else + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + + return 0; +} + +static u8 smp_random(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + u8 confirm[16]; + int ret; + + bt_dev_dbg(conn->hcon->hdev, "conn %p %s", conn, + conn->hcon->out ? 
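/*
 * Editor's illustrative sketch, not part of this patch: the displayed passkey
 * chosen by the CFM_PASSKEY branch of tk_request() above. A random 32-bit
 * value is reduced mod 10^6 so it fits in six decimal digits, and the legacy
 * TK is that number stored little-endian in an otherwise zeroed 16-byte key.
 * getrandom(2) stands in for the kernel's get_random_bytes().
 */
#include <stdint.h>
#include <stdio.h>
#include <sys/random.h>

int main(void)
{
        uint8_t tk[16] = { 0 };
        uint32_t passkey = 0;

        if (getrandom(&passkey, sizeof(passkey), 0) != sizeof(passkey))
                return 1;

        passkey %= 1000000;

        /* put_unaligned_le32() equivalent */
        tk[0] = passkey & 0xff;
        tk[1] = (passkey >> 8) & 0xff;
        tk[2] = (passkey >> 16) & 0xff;
        tk[3] = (passkey >> 24) & 0xff;

        printf("display passkey: %06u\n", passkey);
        return 0;
}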
"initiator" : "responder"); + + ret = smp_c1(smp->tk, smp->rrnd, smp->preq, smp->prsp, + hcon->init_addr_type, &hcon->init_addr, + hcon->resp_addr_type, &hcon->resp_addr, confirm); + if (ret) + return SMP_UNSPECIFIED; + + if (crypto_memneq(smp->pcnf, confirm, sizeof(smp->pcnf))) { + bt_dev_err(hcon->hdev, "pairing failed " + "(confirmation values mismatch)"); + return SMP_CONFIRM_FAILED; + } + + if (hcon->out) { + u8 stk[16]; + __le64 rand = 0; + __le16 ediv = 0; + + smp_s1(smp->tk, smp->rrnd, smp->prnd, stk); + + if (test_and_set_bit(HCI_CONN_ENCRYPT_PEND, &hcon->flags)) + return SMP_UNSPECIFIED; + + hci_le_start_enc(hcon, ediv, rand, stk, smp->enc_key_size); + hcon->enc_key_size = smp->enc_key_size; + set_bit(HCI_CONN_STK_ENCRYPT, &hcon->flags); + } else { + u8 stk[16], auth; + __le64 rand = 0; + __le16 ediv = 0; + + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd), + smp->prnd); + + smp_s1(smp->tk, smp->prnd, smp->rrnd, stk); + + if (hcon->pending_sec_level == BT_SECURITY_HIGH) + auth = 1; + else + auth = 0; + + /* Even though there's no _RESPONDER suffix this is the + * responder STK we're adding for later lookup (the initiator + * STK never needs to be stored). + */ + hci_add_ltk(hcon->hdev, &hcon->dst, hcon->dst_type, + SMP_STK, auth, stk, smp->enc_key_size, ediv, rand); + } + + return 0; +} + +static void smp_notify_keys(struct l2cap_conn *conn) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + struct smp_cmd_pairing *req = (void *) &smp->preq[1]; + struct smp_cmd_pairing *rsp = (void *) &smp->prsp[1]; + bool persistent; + + if (hcon->type == ACL_LINK) { + if (hcon->key_type == HCI_LK_DEBUG_COMBINATION) + persistent = false; + else + persistent = !test_bit(HCI_CONN_FLUSH_KEY, + &hcon->flags); + } else { + /* The LTKs, IRKs and CSRKs should be persistent only if + * both sides had the bonding bit set in their + * authentication requests. + */ + persistent = !!((req->auth_req & rsp->auth_req) & + SMP_AUTH_BONDING); + } + + if (smp->remote_irk) { + smp->remote_irk->link_type = hcon->type; + mgmt_new_irk(hdev, smp->remote_irk, persistent); + + /* Now that user space can be considered to know the + * identity address track the connection based on it + * from now on (assuming this is an LE link). 
+ */ + if (hcon->type == LE_LINK) { + bacpy(&hcon->dst, &smp->remote_irk->bdaddr); + hcon->dst_type = smp->remote_irk->addr_type; + queue_work(hdev->workqueue, &conn->id_addr_update_work); + } + } + + if (smp->csrk) { + smp->csrk->link_type = hcon->type; + smp->csrk->bdaddr_type = hcon->dst_type; + bacpy(&smp->csrk->bdaddr, &hcon->dst); + mgmt_new_csrk(hdev, smp->csrk, persistent); + } + + if (smp->responder_csrk) { + smp->responder_csrk->link_type = hcon->type; + smp->responder_csrk->bdaddr_type = hcon->dst_type; + bacpy(&smp->responder_csrk->bdaddr, &hcon->dst); + mgmt_new_csrk(hdev, smp->responder_csrk, persistent); + } + + if (smp->ltk) { + smp->ltk->link_type = hcon->type; + smp->ltk->bdaddr_type = hcon->dst_type; + bacpy(&smp->ltk->bdaddr, &hcon->dst); + mgmt_new_ltk(hdev, smp->ltk, persistent); + } + + if (smp->responder_ltk) { + smp->responder_ltk->link_type = hcon->type; + smp->responder_ltk->bdaddr_type = hcon->dst_type; + bacpy(&smp->responder_ltk->bdaddr, &hcon->dst); + mgmt_new_ltk(hdev, smp->responder_ltk, persistent); + } + + if (smp->link_key) { + struct link_key *key; + u8 type; + + if (test_bit(SMP_FLAG_DEBUG_KEY, &smp->flags)) + type = HCI_LK_DEBUG_COMBINATION; + else if (hcon->sec_level == BT_SECURITY_FIPS) + type = HCI_LK_AUTH_COMBINATION_P256; + else + type = HCI_LK_UNAUTH_COMBINATION_P256; + + key = hci_add_link_key(hdev, smp->conn->hcon, &hcon->dst, + smp->link_key, type, 0, &persistent); + if (key) { + key->link_type = hcon->type; + key->bdaddr_type = hcon->dst_type; + mgmt_new_link_key(hdev, key, persistent); + + /* Don't keep debug keys around if the relevant + * flag is not set. + */ + if (!hci_dev_test_flag(hdev, HCI_KEEP_DEBUG_KEYS) && + key->type == HCI_LK_DEBUG_COMBINATION) { + list_del_rcu(&key->list); + kfree_rcu(key, rcu); + } + } + } +} + +static void sc_add_ltk(struct smp_chan *smp) +{ + struct hci_conn *hcon = smp->conn->hcon; + u8 key_type, auth; + + if (test_bit(SMP_FLAG_DEBUG_KEY, &smp->flags)) + key_type = SMP_LTK_P256_DEBUG; + else + key_type = SMP_LTK_P256; + + if (hcon->pending_sec_level == BT_SECURITY_FIPS) + auth = 1; + else + auth = 0; + + smp->ltk = hci_add_ltk(hcon->hdev, &hcon->dst, hcon->dst_type, + key_type, auth, smp->tk, smp->enc_key_size, + 0, 0); +} + +static void sc_generate_link_key(struct smp_chan *smp) +{ + /* From core spec. Spells out in ASCII as 'lebr'. */ + const u8 lebr[4] = { 0x72, 0x62, 0x65, 0x6c }; + + smp->link_key = kzalloc(16, GFP_KERNEL); + if (!smp->link_key) + return; + + if (test_bit(SMP_FLAG_CT2, &smp->flags)) { + /* SALT = 0x000000000000000000000000746D7031 */ + const u8 salt[16] = { 0x31, 0x70, 0x6d, 0x74 }; + + if (smp_h7(smp->tfm_cmac, smp->tk, salt, smp->link_key)) { + kfree_sensitive(smp->link_key); + smp->link_key = NULL; + return; + } + } else { + /* From core spec. Spells out in ASCII as 'tmp1'. */ + const u8 tmp1[4] = { 0x31, 0x70, 0x6d, 0x74 }; + + if (smp_h6(smp->tfm_cmac, smp->tk, tmp1, smp->link_key)) { + kfree_sensitive(smp->link_key); + smp->link_key = NULL; + return; + } + } + + if (smp_h6(smp->tfm_cmac, smp->link_key, lebr, smp->link_key)) { + kfree_sensitive(smp->link_key); + smp->link_key = NULL; + return; + } +} + +static void smp_allow_key_dist(struct smp_chan *smp) +{ + /* Allow the first expected phase 3 PDU. The rest of the PDUs + * will be allowed in each PDU handler to ensure we receive + * them in the correct order. 
+ */ + if (smp->remote_key_dist & SMP_DIST_ENC_KEY) + SMP_ALLOW_CMD(smp, SMP_CMD_ENCRYPT_INFO); + else if (smp->remote_key_dist & SMP_DIST_ID_KEY) + SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_INFO); + else if (smp->remote_key_dist & SMP_DIST_SIGN) + SMP_ALLOW_CMD(smp, SMP_CMD_SIGN_INFO); +} + +static void sc_generate_ltk(struct smp_chan *smp) +{ + /* From core spec. Spells out in ASCII as 'brle'. */ + const u8 brle[4] = { 0x65, 0x6c, 0x72, 0x62 }; + struct hci_conn *hcon = smp->conn->hcon; + struct hci_dev *hdev = hcon->hdev; + struct link_key *key; + + key = hci_find_link_key(hdev, &hcon->dst); + if (!key) { + bt_dev_err(hdev, "no Link Key found to generate LTK"); + return; + } + + if (key->type == HCI_LK_DEBUG_COMBINATION) + set_bit(SMP_FLAG_DEBUG_KEY, &smp->flags); + + if (test_bit(SMP_FLAG_CT2, &smp->flags)) { + /* SALT = 0x000000000000000000000000746D7032 */ + const u8 salt[16] = { 0x32, 0x70, 0x6d, 0x74 }; + + if (smp_h7(smp->tfm_cmac, key->val, salt, smp->tk)) + return; + } else { + /* From core spec. Spells out in ASCII as 'tmp2'. */ + const u8 tmp2[4] = { 0x32, 0x70, 0x6d, 0x74 }; + + if (smp_h6(smp->tfm_cmac, key->val, tmp2, smp->tk)) + return; + } + + if (smp_h6(smp->tfm_cmac, smp->tk, brle, smp->tk)) + return; + + sc_add_ltk(smp); +} + +static void smp_distribute_keys(struct smp_chan *smp) +{ + struct smp_cmd_pairing *req, *rsp; + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + __u8 *keydist; + + bt_dev_dbg(hdev, "conn %p", conn); + + rsp = (void *) &smp->prsp[1]; + + /* The responder sends its keys first */ + if (hcon->out && (smp->remote_key_dist & KEY_DIST_MASK)) { + smp_allow_key_dist(smp); + return; + } + + req = (void *) &smp->preq[1]; + + if (hcon->out) { + keydist = &rsp->init_key_dist; + *keydist &= req->init_key_dist; + } else { + keydist = &rsp->resp_key_dist; + *keydist &= req->resp_key_dist; + } + + if (test_bit(SMP_FLAG_SC, &smp->flags)) { + if (hcon->type == LE_LINK && (*keydist & SMP_DIST_LINK_KEY)) + sc_generate_link_key(smp); + if (hcon->type == ACL_LINK && (*keydist & SMP_DIST_ENC_KEY)) + sc_generate_ltk(smp); + + /* Clear the keys which are generated but not distributed */ + *keydist &= ~SMP_SC_NO_DIST; + } + + bt_dev_dbg(hdev, "keydist 0x%x", *keydist); + + if (*keydist & SMP_DIST_ENC_KEY) { + struct smp_cmd_encrypt_info enc; + struct smp_cmd_initiator_ident ident; + struct smp_ltk *ltk; + u8 authenticated; + __le16 ediv; + __le64 rand; + + /* Make sure we generate only the significant amount of + * bytes based on the encryption key size, and set the rest + * of the value to zeroes. 
+ */ + get_random_bytes(enc.ltk, smp->enc_key_size); + memset(enc.ltk + smp->enc_key_size, 0, + sizeof(enc.ltk) - smp->enc_key_size); + + get_random_bytes(&ediv, sizeof(ediv)); + get_random_bytes(&rand, sizeof(rand)); + + smp_send_cmd(conn, SMP_CMD_ENCRYPT_INFO, sizeof(enc), &enc); + + authenticated = hcon->sec_level == BT_SECURITY_HIGH; + ltk = hci_add_ltk(hdev, &hcon->dst, hcon->dst_type, + SMP_LTK_RESPONDER, authenticated, enc.ltk, + smp->enc_key_size, ediv, rand); + smp->responder_ltk = ltk; + + ident.ediv = ediv; + ident.rand = rand; + + smp_send_cmd(conn, SMP_CMD_INITIATOR_IDENT, sizeof(ident), + &ident); + + *keydist &= ~SMP_DIST_ENC_KEY; + } + + if (*keydist & SMP_DIST_ID_KEY) { + struct smp_cmd_ident_addr_info addrinfo; + struct smp_cmd_ident_info idinfo; + + memcpy(idinfo.irk, hdev->irk, sizeof(idinfo.irk)); + + smp_send_cmd(conn, SMP_CMD_IDENT_INFO, sizeof(idinfo), &idinfo); + + /* The hci_conn contains the local identity address + * after the connection has been established. + * + * This is true even when the connection has been + * established using a resolvable random address. + */ + bacpy(&addrinfo.bdaddr, &hcon->src); + addrinfo.addr_type = hcon->src_type; + + smp_send_cmd(conn, SMP_CMD_IDENT_ADDR_INFO, sizeof(addrinfo), + &addrinfo); + + *keydist &= ~SMP_DIST_ID_KEY; + } + + if (*keydist & SMP_DIST_SIGN) { + struct smp_cmd_sign_info sign; + struct smp_csrk *csrk; + + /* Generate a new random key */ + get_random_bytes(sign.csrk, sizeof(sign.csrk)); + + csrk = kzalloc(sizeof(*csrk), GFP_KERNEL); + if (csrk) { + if (hcon->sec_level > BT_SECURITY_MEDIUM) + csrk->type = MGMT_CSRK_LOCAL_AUTHENTICATED; + else + csrk->type = MGMT_CSRK_LOCAL_UNAUTHENTICATED; + memcpy(csrk->val, sign.csrk, sizeof(csrk->val)); + } + smp->responder_csrk = csrk; + + smp_send_cmd(conn, SMP_CMD_SIGN_INFO, sizeof(sign), &sign); + + *keydist &= ~SMP_DIST_SIGN; + } + + /* If there are still keys to be received wait for them */ + if (smp->remote_key_dist & KEY_DIST_MASK) { + smp_allow_key_dist(smp); + return; + } + + set_bit(SMP_FLAG_COMPLETE, &smp->flags); + smp_notify_keys(conn); + + smp_chan_destroy(conn); +} + +static void smp_timeout(struct work_struct *work) +{ + struct smp_chan *smp = container_of(work, struct smp_chan, + security_timer.work); + struct l2cap_conn *conn = smp->conn; + + bt_dev_dbg(conn->hcon->hdev, "conn %p", conn); + + hci_disconnect(conn->hcon, HCI_ERROR_REMOTE_USER_TERM); +} + +static struct smp_chan *smp_chan_create(struct l2cap_conn *conn) +{ + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp; + + smp = kzalloc(sizeof(*smp), GFP_ATOMIC); + if (!smp) + return NULL; + + smp->tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); + if (IS_ERR(smp->tfm_cmac)) { + bt_dev_err(hcon->hdev, "Unable to create CMAC crypto context"); + goto zfree_smp; + } + + smp->tfm_ecdh = crypto_alloc_kpp("ecdh-nist-p256", 0, 0); + if (IS_ERR(smp->tfm_ecdh)) { + bt_dev_err(hcon->hdev, "Unable to create ECDH crypto context"); + goto free_shash; + } + + smp->conn = conn; + chan->data = smp; + + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_FAIL); + + INIT_DELAYED_WORK(&smp->security_timer, smp_timeout); + + hci_conn_hold(hcon); + + return smp; + +free_shash: + crypto_free_shash(smp->tfm_cmac); +zfree_smp: + kfree_sensitive(smp); + return NULL; +} + +static int sc_mackey_and_ltk(struct smp_chan *smp, u8 mackey[16], u8 ltk[16]) +{ + struct hci_conn *hcon = smp->conn->hcon; + u8 *na, *nb, a[7], b[7]; + + if (hcon->out) { + na = smp->prnd; + nb = smp->rrnd; + } else { + na = 
smp->rrnd; + nb = smp->prnd; + } + + memcpy(a, &hcon->init_addr, 6); + memcpy(b, &hcon->resp_addr, 6); + a[6] = hcon->init_addr_type; + b[6] = hcon->resp_addr_type; + + return smp_f5(smp->tfm_cmac, smp->dhkey, na, nb, a, b, mackey, ltk); +} + +static void sc_dhkey_check(struct smp_chan *smp) +{ + struct hci_conn *hcon = smp->conn->hcon; + struct smp_cmd_dhkey_check check; + u8 a[7], b[7], *local_addr, *remote_addr; + u8 io_cap[3], r[16]; + + memcpy(a, &hcon->init_addr, 6); + memcpy(b, &hcon->resp_addr, 6); + a[6] = hcon->init_addr_type; + b[6] = hcon->resp_addr_type; + + if (hcon->out) { + local_addr = a; + remote_addr = b; + memcpy(io_cap, &smp->preq[1], 3); + } else { + local_addr = b; + remote_addr = a; + memcpy(io_cap, &smp->prsp[1], 3); + } + + memset(r, 0, sizeof(r)); + + if (smp->method == REQ_PASSKEY || smp->method == DSP_PASSKEY) + put_unaligned_le32(hcon->passkey_notify, r); + + if (smp->method == REQ_OOB) + memcpy(r, smp->rr, 16); + + smp_f6(smp->tfm_cmac, smp->mackey, smp->prnd, smp->rrnd, r, io_cap, + local_addr, remote_addr, check.e); + + smp_send_cmd(smp->conn, SMP_CMD_DHKEY_CHECK, sizeof(check), &check); +} + +static u8 sc_passkey_send_confirm(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + struct smp_cmd_pairing_confirm cfm; + u8 r; + + r = ((hcon->passkey_notify >> smp->passkey_round) & 0x01); + r |= 0x80; + + get_random_bytes(smp->prnd, sizeof(smp->prnd)); + + if (smp_f4(smp->tfm_cmac, smp->local_pk, smp->remote_pk, smp->prnd, r, + cfm.confirm_val)) + return SMP_UNSPECIFIED; + + smp_send_cmd(conn, SMP_CMD_PAIRING_CONFIRM, sizeof(cfm), &cfm); + + return 0; +} + +static u8 sc_passkey_round(struct smp_chan *smp, u8 smp_op) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + u8 cfm[16], r; + + /* Ignore the PDU if we've already done 20 rounds (0 - 19) */ + if (smp->passkey_round >= 20) + return 0; + + switch (smp_op) { + case SMP_CMD_PAIRING_RANDOM: + r = ((hcon->passkey_notify >> smp->passkey_round) & 0x01); + r |= 0x80; + + if (smp_f4(smp->tfm_cmac, smp->remote_pk, smp->local_pk, + smp->rrnd, r, cfm)) + return SMP_UNSPECIFIED; + + if (crypto_memneq(smp->pcnf, cfm, 16)) + return SMP_CONFIRM_FAILED; + + smp->passkey_round++; + + if (smp->passkey_round == 20) { + /* Generate MacKey and LTK */ + if (sc_mackey_and_ltk(smp, smp->mackey, smp->tk)) + return SMP_UNSPECIFIED; + } + + /* The round is only complete when the initiator + * receives pairing random. 
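+ *
+ * Each of the 20 rounds commits one bit of the passkey to the f4
+ * confirm value via r = 0x80 | ((hcon->passkey_notify >> round) & 1);
+ * for example passkey 5 (binary ...101) yields r = 0x81, 0x80, 0x81
+ * for rounds 0, 1 and 2, and r = 0x80 for the remaining rounds.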
+ */ + if (!hcon->out) { + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, + sizeof(smp->prnd), smp->prnd); + if (smp->passkey_round == 20) + SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + else + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + return 0; + } + + /* Start the next round */ + if (smp->passkey_round != 20) + return sc_passkey_round(smp, 0); + + /* Passkey rounds are complete - start DHKey Check */ + sc_dhkey_check(smp); + SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + + break; + + case SMP_CMD_PAIRING_CONFIRM: + if (test_bit(SMP_FLAG_WAIT_USER, &smp->flags)) { + set_bit(SMP_FLAG_CFM_PENDING, &smp->flags); + return 0; + } + + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + + if (hcon->out) { + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, + sizeof(smp->prnd), smp->prnd); + return 0; + } + + return sc_passkey_send_confirm(smp); + + case SMP_CMD_PUBLIC_KEY: + default: + /* Initiating device starts the round */ + if (!hcon->out) + return 0; + + bt_dev_dbg(hdev, "Starting passkey round %u", + smp->passkey_round + 1); + + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + + return sc_passkey_send_confirm(smp); + } + + return 0; +} + +static int sc_user_reply(struct smp_chan *smp, u16 mgmt_op, __le32 passkey) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + u8 smp_op; + + clear_bit(SMP_FLAG_WAIT_USER, &smp->flags); + + switch (mgmt_op) { + case MGMT_OP_USER_PASSKEY_NEG_REPLY: + smp_failure(smp->conn, SMP_PASSKEY_ENTRY_FAILED); + return 0; + case MGMT_OP_USER_CONFIRM_NEG_REPLY: + smp_failure(smp->conn, SMP_NUMERIC_COMP_FAILED); + return 0; + case MGMT_OP_USER_PASSKEY_REPLY: + hcon->passkey_notify = le32_to_cpu(passkey); + smp->passkey_round = 0; + + if (test_and_clear_bit(SMP_FLAG_CFM_PENDING, &smp->flags)) + smp_op = SMP_CMD_PAIRING_CONFIRM; + else + smp_op = 0; + + if (sc_passkey_round(smp, smp_op)) + return -EIO; + + return 0; + } + + /* Initiator sends DHKey check first */ + if (hcon->out) { + sc_dhkey_check(smp); + SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + } else if (test_and_clear_bit(SMP_FLAG_DHKEY_PENDING, &smp->flags)) { + sc_dhkey_check(smp); + sc_add_ltk(smp); + } + + return 0; +} + +int smp_user_confirm_reply(struct hci_conn *hcon, u16 mgmt_op, __le32 passkey) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + struct l2cap_chan *chan; + struct smp_chan *smp; + u32 value; + int err; + + if (!conn) + return -ENOTCONN; + + bt_dev_dbg(conn->hcon->hdev, ""); + + chan = conn->smp; + if (!chan) + return -ENOTCONN; + + l2cap_chan_lock(chan); + if (!chan->data) { + err = -ENOTCONN; + goto unlock; + } + + smp = chan->data; + + if (test_bit(SMP_FLAG_SC, &smp->flags)) { + err = sc_user_reply(smp, mgmt_op, passkey); + goto unlock; + } + + switch (mgmt_op) { + case MGMT_OP_USER_PASSKEY_REPLY: + value = le32_to_cpu(passkey); + memset(smp->tk, 0, sizeof(smp->tk)); + bt_dev_dbg(conn->hcon->hdev, "PassKey: %u", value); + put_unaligned_le32(value, smp->tk); + fallthrough; + case MGMT_OP_USER_CONFIRM_REPLY: + set_bit(SMP_FLAG_TK_VALID, &smp->flags); + break; + case MGMT_OP_USER_PASSKEY_NEG_REPLY: + case MGMT_OP_USER_CONFIRM_NEG_REPLY: + smp_failure(conn, SMP_PASSKEY_ENTRY_FAILED); + err = 0; + goto unlock; + default: + smp_failure(conn, SMP_PASSKEY_ENTRY_FAILED); + err = -EOPNOTSUPP; + goto unlock; + } + + err = 0; + + /* If it is our turn to send Pairing Confirm, do so now */ + if (test_bit(SMP_FLAG_CFM_PENDING, &smp->flags)) { + u8 rsp = smp_confirm(smp); + if (rsp) + smp_failure(conn, rsp); + } + +unlock: + l2cap_chan_unlock(chan); + return err; +} + +static void 
build_bredr_pairing_cmd(struct smp_chan *smp, + struct smp_cmd_pairing *req, + struct smp_cmd_pairing *rsp) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_dev *hdev = conn->hcon->hdev; + u8 local_dist = 0, remote_dist = 0; + + if (hci_dev_test_flag(hdev, HCI_BONDABLE)) { + local_dist = SMP_DIST_ENC_KEY | SMP_DIST_SIGN; + remote_dist = SMP_DIST_ENC_KEY | SMP_DIST_SIGN; + } + + if (hci_dev_test_flag(hdev, HCI_RPA_RESOLVING)) + remote_dist |= SMP_DIST_ID_KEY; + + if (hci_dev_test_flag(hdev, HCI_PRIVACY)) + local_dist |= SMP_DIST_ID_KEY; + + if (!rsp) { + memset(req, 0, sizeof(*req)); + + req->auth_req = SMP_AUTH_CT2; + req->init_key_dist = local_dist; + req->resp_key_dist = remote_dist; + req->max_key_size = conn->hcon->enc_key_size; + + smp->remote_key_dist = remote_dist; + + return; + } + + memset(rsp, 0, sizeof(*rsp)); + + rsp->auth_req = SMP_AUTH_CT2; + rsp->max_key_size = conn->hcon->enc_key_size; + rsp->init_key_dist = req->init_key_dist & remote_dist; + rsp->resp_key_dist = req->resp_key_dist & local_dist; + + smp->remote_key_dist = rsp->init_key_dist; +} + +static u8 smp_cmd_pairing_req(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_pairing rsp, *req = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct hci_dev *hdev = conn->hcon->hdev; + struct smp_chan *smp; + u8 key_size, auth, sec_level; + int ret; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (skb->len < sizeof(*req)) + return SMP_INVALID_PARAMS; + + if (conn->hcon->role != HCI_ROLE_SLAVE) + return SMP_CMD_NOTSUPP; + + if (!chan->data) + smp = smp_chan_create(conn); + else + smp = chan->data; + + if (!smp) + return SMP_UNSPECIFIED; + + /* We didn't start the pairing, so match remote */ + auth = req->auth_req & AUTH_REQ_MASK(hdev); + + if (!hci_dev_test_flag(hdev, HCI_BONDABLE) && + (auth & SMP_AUTH_BONDING)) + return SMP_PAIRING_NOTSUPP; + + if (hci_dev_test_flag(hdev, HCI_SC_ONLY) && !(auth & SMP_AUTH_SC)) + return SMP_AUTH_REQUIREMENTS; + + smp->preq[0] = SMP_CMD_PAIRING_REQ; + memcpy(&smp->preq[1], req, sizeof(*req)); + skb_pull(skb, sizeof(*req)); + + /* If the remote side's OOB flag is set it means it has + * successfully received our local OOB data - therefore set the + * flag to indicate that local OOB is in use. 
+ */ + if (req->oob_flag == SMP_OOB_PRESENT && SMP_DEV(hdev)->local_oob) + set_bit(SMP_FLAG_LOCAL_OOB, &smp->flags); + + /* SMP over BR/EDR requires special treatment */ + if (conn->hcon->type == ACL_LINK) { + /* We must have a BR/EDR SC link */ + if (!test_bit(HCI_CONN_AES_CCM, &conn->hcon->flags) && + !hci_dev_test_flag(hdev, HCI_FORCE_BREDR_SMP)) + return SMP_CROSS_TRANSP_NOT_ALLOWED; + + set_bit(SMP_FLAG_SC, &smp->flags); + + build_bredr_pairing_cmd(smp, req, &rsp); + + if (req->auth_req & SMP_AUTH_CT2) + set_bit(SMP_FLAG_CT2, &smp->flags); + + key_size = min(req->max_key_size, rsp.max_key_size); + if (check_enc_key_size(conn, key_size)) + return SMP_ENC_KEY_SIZE; + + /* Clear bits which are generated but not distributed */ + smp->remote_key_dist &= ~SMP_SC_NO_DIST; + + smp->prsp[0] = SMP_CMD_PAIRING_RSP; + memcpy(&smp->prsp[1], &rsp, sizeof(rsp)); + smp_send_cmd(conn, SMP_CMD_PAIRING_RSP, sizeof(rsp), &rsp); + + smp_distribute_keys(smp); + return 0; + } + + build_pairing_cmd(conn, req, &rsp, auth); + + if (rsp.auth_req & SMP_AUTH_SC) { + set_bit(SMP_FLAG_SC, &smp->flags); + + if (rsp.auth_req & SMP_AUTH_CT2) + set_bit(SMP_FLAG_CT2, &smp->flags); + } + + if (conn->hcon->io_capability == HCI_IO_NO_INPUT_OUTPUT) + sec_level = BT_SECURITY_MEDIUM; + else + sec_level = authreq_to_seclevel(auth); + + if (sec_level > conn->hcon->pending_sec_level) + conn->hcon->pending_sec_level = sec_level; + + /* If we need MITM check that it can be achieved */ + if (conn->hcon->pending_sec_level >= BT_SECURITY_HIGH) { + u8 method; + + method = get_auth_method(smp, conn->hcon->io_capability, + req->io_capability); + if (method == JUST_WORKS || method == JUST_CFM) + return SMP_AUTH_REQUIREMENTS; + } + + key_size = min(req->max_key_size, rsp.max_key_size); + if (check_enc_key_size(conn, key_size)) + return SMP_ENC_KEY_SIZE; + + get_random_bytes(smp->prnd, sizeof(smp->prnd)); + + smp->prsp[0] = SMP_CMD_PAIRING_RSP; + memcpy(&smp->prsp[1], &rsp, sizeof(rsp)); + + smp_send_cmd(conn, SMP_CMD_PAIRING_RSP, sizeof(rsp), &rsp); + + clear_bit(SMP_FLAG_INITIATOR, &smp->flags); + + /* Strictly speaking we shouldn't allow Pairing Confirm for the + * SC case, however some implementations incorrectly copy RFU auth + * req bits from our security request, which may create a false + * positive SC enablement. 
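+ *
+ * If such a Pairing Confirm does arrive while SMP_FLAG_SC is set but
+ * no Public Key has been exchanged, smp_cmd_pairing_confirm() calls
+ * fixup_sc_false_positive() to drop back to legacy pairing.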
+ */ + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + + if (test_bit(SMP_FLAG_SC, &smp->flags)) { + SMP_ALLOW_CMD(smp, SMP_CMD_PUBLIC_KEY); + /* Clear bits which are generated but not distributed */ + smp->remote_key_dist &= ~SMP_SC_NO_DIST; + /* Wait for Public Key from Initiating Device */ + return 0; + } + + /* Request setup of TK */ + ret = tk_request(conn, 0, auth, rsp.io_capability, req->io_capability); + if (ret) + return SMP_UNSPECIFIED; + + return 0; +} + +static u8 sc_send_public_key(struct smp_chan *smp) +{ + struct hci_dev *hdev = smp->conn->hcon->hdev; + + bt_dev_dbg(hdev, ""); + + if (test_bit(SMP_FLAG_LOCAL_OOB, &smp->flags)) { + struct l2cap_chan *chan = hdev->smp_data; + struct smp_dev *smp_dev; + + if (!chan || !chan->data) + return SMP_UNSPECIFIED; + + smp_dev = chan->data; + + memcpy(smp->local_pk, smp_dev->local_pk, 64); + memcpy(smp->lr, smp_dev->local_rand, 16); + + if (smp_dev->debug_key) + set_bit(SMP_FLAG_DEBUG_KEY, &smp->flags); + + goto done; + } + + if (hci_dev_test_flag(hdev, HCI_USE_DEBUG_KEYS)) { + bt_dev_dbg(hdev, "Using debug keys"); + if (set_ecdh_privkey(smp->tfm_ecdh, debug_sk)) + return SMP_UNSPECIFIED; + memcpy(smp->local_pk, debug_pk, 64); + set_bit(SMP_FLAG_DEBUG_KEY, &smp->flags); + } else { + while (true) { + /* Generate key pair for Secure Connections */ + if (generate_ecdh_keys(smp->tfm_ecdh, smp->local_pk)) + return SMP_UNSPECIFIED; + + /* This is unlikely, but we need to check that + * we didn't accidentally generate a debug key. + */ + if (crypto_memneq(smp->local_pk, debug_pk, 64)) + break; + } + } + +done: + SMP_DBG("Local Public Key X: %32phN", smp->local_pk); + SMP_DBG("Local Public Key Y: %32phN", smp->local_pk + 32); + + smp_send_cmd(smp->conn, SMP_CMD_PUBLIC_KEY, 64, smp->local_pk); + + return 0; +} + +static u8 smp_cmd_pairing_rsp(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_pairing *req, *rsp = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_dev *hdev = conn->hcon->hdev; + u8 key_size, auth; + int ret; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (skb->len < sizeof(*rsp)) + return SMP_INVALID_PARAMS; + + if (conn->hcon->role != HCI_ROLE_MASTER) + return SMP_CMD_NOTSUPP; + + skb_pull(skb, sizeof(*rsp)); + + req = (void *) &smp->preq[1]; + + key_size = min(req->max_key_size, rsp->max_key_size); + if (check_enc_key_size(conn, key_size)) + return SMP_ENC_KEY_SIZE; + + auth = rsp->auth_req & AUTH_REQ_MASK(hdev); + + if (hci_dev_test_flag(hdev, HCI_SC_ONLY) && !(auth & SMP_AUTH_SC)) + return SMP_AUTH_REQUIREMENTS; + + /* If the remote side's OOB flag is set it means it has + * successfully received our local OOB data - therefore set the + * flag to indicate that local OOB is in use. + */ + if (rsp->oob_flag == SMP_OOB_PRESENT && SMP_DEV(hdev)->local_oob) + set_bit(SMP_FLAG_LOCAL_OOB, &smp->flags); + + smp->prsp[0] = SMP_CMD_PAIRING_RSP; + memcpy(&smp->prsp[1], rsp, sizeof(*rsp)); + + /* Update remote key distribution in case the remote cleared + * some bits that we had enabled in our request. 
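+ *
+ * E.g. if our request asked for ENC_KEY | ID_KEY | SIGN but the
+ * response only sets ENC_KEY | ID_KEY in resp_key_dist, then
+ * SMP_DIST_SIGN is cleared here and no Signing Information PDU will
+ * be expected from the peer.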
+ */ + smp->remote_key_dist &= rsp->resp_key_dist; + + if ((req->auth_req & SMP_AUTH_CT2) && (auth & SMP_AUTH_CT2)) + set_bit(SMP_FLAG_CT2, &smp->flags); + + /* For BR/EDR this means we're done and can start phase 3 */ + if (conn->hcon->type == ACL_LINK) { + /* Clear bits which are generated but not distributed */ + smp->remote_key_dist &= ~SMP_SC_NO_DIST; + smp_distribute_keys(smp); + return 0; + } + + if ((req->auth_req & SMP_AUTH_SC) && (auth & SMP_AUTH_SC)) + set_bit(SMP_FLAG_SC, &smp->flags); + else if (conn->hcon->pending_sec_level > BT_SECURITY_HIGH) + conn->hcon->pending_sec_level = BT_SECURITY_HIGH; + + /* If we need MITM check that it can be achieved */ + if (conn->hcon->pending_sec_level >= BT_SECURITY_HIGH) { + u8 method; + + method = get_auth_method(smp, req->io_capability, + rsp->io_capability); + if (method == JUST_WORKS || method == JUST_CFM) + return SMP_AUTH_REQUIREMENTS; + } + + get_random_bytes(smp->prnd, sizeof(smp->prnd)); + + /* Update remote key distribution in case the remote cleared + * some bits that we had enabled in our request. + */ + smp->remote_key_dist &= rsp->resp_key_dist; + + if (test_bit(SMP_FLAG_SC, &smp->flags)) { + /* Clear bits which are generated but not distributed */ + smp->remote_key_dist &= ~SMP_SC_NO_DIST; + SMP_ALLOW_CMD(smp, SMP_CMD_PUBLIC_KEY); + return sc_send_public_key(smp); + } + + auth |= req->auth_req; + + ret = tk_request(conn, 0, auth, req->io_capability, rsp->io_capability); + if (ret) + return SMP_UNSPECIFIED; + + set_bit(SMP_FLAG_CFM_PENDING, &smp->flags); + + /* Can't compose response until we have been confirmed */ + if (test_bit(SMP_FLAG_TK_VALID, &smp->flags)) + return smp_confirm(smp); + + return 0; +} + +static u8 sc_check_confirm(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + + bt_dev_dbg(conn->hcon->hdev, ""); + + if (smp->method == REQ_PASSKEY || smp->method == DSP_PASSKEY) + return sc_passkey_round(smp, SMP_CMD_PAIRING_CONFIRM); + + if (conn->hcon->out) { + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd), + smp->prnd); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + } + + return 0; +} + +/* Work-around for some implementations that incorrectly copy RFU bits + * from our security request and thereby create the impression that + * we're doing SC when in fact the remote doesn't support it. 
+ */ +static int fixup_sc_false_positive(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + struct smp_cmd_pairing *req, *rsp; + u8 auth; + + /* The issue is only observed when we're in responder role */ + if (hcon->out) + return SMP_UNSPECIFIED; + + if (hci_dev_test_flag(hdev, HCI_SC_ONLY)) { + bt_dev_err(hdev, "refusing legacy fallback in SC-only mode"); + return SMP_UNSPECIFIED; + } + + bt_dev_err(hdev, "trying to fall back to legacy SMP"); + + req = (void *) &smp->preq[1]; + rsp = (void *) &smp->prsp[1]; + + /* Rebuild key dist flags which may have been cleared for SC */ + smp->remote_key_dist = (req->init_key_dist & rsp->resp_key_dist); + + auth = req->auth_req & AUTH_REQ_MASK(hdev); + + if (tk_request(conn, 0, auth, rsp->io_capability, req->io_capability)) { + bt_dev_err(hdev, "failed to fall back to legacy SMP"); + return SMP_UNSPECIFIED; + } + + clear_bit(SMP_FLAG_SC, &smp->flags); + + return 0; +} + +static u8 smp_cmd_pairing_confirm(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + + bt_dev_dbg(hdev, "conn %p %s", conn, + hcon->out ? "initiator" : "responder"); + + if (skb->len < sizeof(smp->pcnf)) + return SMP_INVALID_PARAMS; + + memcpy(smp->pcnf, skb->data, sizeof(smp->pcnf)); + skb_pull(skb, sizeof(smp->pcnf)); + + if (test_bit(SMP_FLAG_SC, &smp->flags)) { + int ret; + + /* Public Key exchange must happen before any other steps */ + if (test_bit(SMP_FLAG_REMOTE_PK, &smp->flags)) + return sc_check_confirm(smp); + + bt_dev_err(hdev, "Unexpected SMP Pairing Confirm"); + + ret = fixup_sc_false_positive(smp); + if (ret) + return ret; + } + + if (conn->hcon->out) { + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd), + smp->prnd); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + return 0; + } + + if (test_bit(SMP_FLAG_TK_VALID, &smp->flags)) + return smp_confirm(smp); + + set_bit(SMP_FLAG_CFM_PENDING, &smp->flags); + + return 0; +} + +static u8 smp_cmd_pairing_random(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + u8 *pkax, *pkbx, *na, *nb, confirm_hint; + u32 passkey; + int err; + + bt_dev_dbg(hcon->hdev, "conn %p", conn); + + if (skb->len < sizeof(smp->rrnd)) + return SMP_INVALID_PARAMS; + + memcpy(smp->rrnd, skb->data, sizeof(smp->rrnd)); + skb_pull(skb, sizeof(smp->rrnd)); + + if (!test_bit(SMP_FLAG_SC, &smp->flags)) + return smp_random(smp); + + if (hcon->out) { + pkax = smp->local_pk; + pkbx = smp->remote_pk; + na = smp->prnd; + nb = smp->rrnd; + } else { + pkax = smp->remote_pk; + pkbx = smp->local_pk; + na = smp->rrnd; + nb = smp->prnd; + } + + if (smp->method == REQ_OOB) { + if (!hcon->out) + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, + sizeof(smp->prnd), smp->prnd); + SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + goto mackey_and_ltk; + } + + /* Passkey entry has special treatment */ + if (smp->method == REQ_PASSKEY || smp->method == DSP_PASSKEY) + return sc_passkey_round(smp, SMP_CMD_PAIRING_RANDOM); + + if (hcon->out) { + u8 cfm[16]; + + err = smp_f4(smp->tfm_cmac, smp->remote_pk, smp->local_pk, + smp->rrnd, 0, cfm); + if (err) + return SMP_UNSPECIFIED; + + if (crypto_memneq(smp->pcnf, cfm, 16)) + return SMP_CONFIRM_FAILED; + } else { + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, sizeof(smp->prnd), + smp->prnd); + 
SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + + /* Only Just-Works pairing requires extra checks */ + if (smp->method != JUST_WORKS) + goto mackey_and_ltk; + + /* If there already exists long term key in local host, leave + * the decision to user space since the remote device could + * be legitimate or malicious. + */ + if (hci_find_ltk(hcon->hdev, &hcon->dst, hcon->dst_type, + hcon->role)) { + /* Set passkey to 0. The value can be any number since + * it'll be ignored anyway. + */ + passkey = 0; + confirm_hint = 1; + goto confirm; + } + } + +mackey_and_ltk: + /* Generate MacKey and LTK */ + err = sc_mackey_and_ltk(smp, smp->mackey, smp->tk); + if (err) + return SMP_UNSPECIFIED; + + if (smp->method == REQ_OOB) { + if (hcon->out) { + sc_dhkey_check(smp); + SMP_ALLOW_CMD(smp, SMP_CMD_DHKEY_CHECK); + } + return 0; + } + + err = smp_g2(smp->tfm_cmac, pkax, pkbx, na, nb, &passkey); + if (err) + return SMP_UNSPECIFIED; + + confirm_hint = 0; + +confirm: + if (smp->method == JUST_WORKS) + confirm_hint = 1; + + err = mgmt_user_confirm_request(hcon->hdev, &hcon->dst, hcon->type, + hcon->dst_type, passkey, confirm_hint); + if (err) + return SMP_UNSPECIFIED; + + set_bit(SMP_FLAG_WAIT_USER, &smp->flags); + + return 0; +} + +static bool smp_ltk_encrypt(struct l2cap_conn *conn, u8 sec_level) +{ + struct smp_ltk *key; + struct hci_conn *hcon = conn->hcon; + + key = hci_find_ltk(hcon->hdev, &hcon->dst, hcon->dst_type, hcon->role); + if (!key) + return false; + + if (smp_ltk_sec_level(key) < sec_level) + return false; + + if (test_and_set_bit(HCI_CONN_ENCRYPT_PEND, &hcon->flags)) + return true; + + hci_le_start_enc(hcon, key->ediv, key->rand, key->val, key->enc_size); + hcon->enc_key_size = key->enc_size; + + /* We never store STKs for initiator role, so clear this flag */ + clear_bit(HCI_CONN_STK_ENCRYPT, &hcon->flags); + + return true; +} + +bool smp_sufficient_security(struct hci_conn *hcon, u8 sec_level, + enum smp_key_pref key_pref) +{ + if (sec_level == BT_SECURITY_LOW) + return true; + + /* If we're encrypted with an STK but the caller prefers using + * LTK claim insufficient security. This way we allow the + * connection to be re-encrypted with an LTK, even if the LTK + * provides the same level of security. Only exception is if we + * don't have an LTK (e.g. because of key distribution bits). 
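+ *
+ * E.g. right after legacy pairing the link may still be encrypted
+ * with the STK (HCI_CONN_STK_ENCRYPT set); if a distributed LTK is
+ * stored for the peer we report the security as insufficient so the
+ * link gets re-encrypted with that LTK, otherwise the STK-based
+ * security level is accepted as-is.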
+ */ + if (key_pref == SMP_USE_LTK && + test_bit(HCI_CONN_STK_ENCRYPT, &hcon->flags) && + hci_find_ltk(hcon->hdev, &hcon->dst, hcon->dst_type, hcon->role)) + return false; + + if (hcon->sec_level >= sec_level) + return true; + + return false; +} + +static u8 smp_cmd_security_req(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_security_req *rp = (void *) skb->data; + struct smp_cmd_pairing cp; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + struct smp_chan *smp; + u8 sec_level, auth; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (skb->len < sizeof(*rp)) + return SMP_INVALID_PARAMS; + + if (hcon->role != HCI_ROLE_MASTER) + return SMP_CMD_NOTSUPP; + + auth = rp->auth_req & AUTH_REQ_MASK(hdev); + + if (hci_dev_test_flag(hdev, HCI_SC_ONLY) && !(auth & SMP_AUTH_SC)) + return SMP_AUTH_REQUIREMENTS; + + if (hcon->io_capability == HCI_IO_NO_INPUT_OUTPUT) + sec_level = BT_SECURITY_MEDIUM; + else + sec_level = authreq_to_seclevel(auth); + + if (smp_sufficient_security(hcon, sec_level, SMP_USE_LTK)) { + /* If link is already encrypted with sufficient security we + * still need refresh encryption as per Core Spec 5.0 Vol 3, + * Part H 2.4.6 + */ + smp_ltk_encrypt(conn, hcon->sec_level); + return 0; + } + + if (sec_level > hcon->pending_sec_level) + hcon->pending_sec_level = sec_level; + + if (smp_ltk_encrypt(conn, hcon->pending_sec_level)) + return 0; + + smp = smp_chan_create(conn); + if (!smp) + return SMP_UNSPECIFIED; + + if (!hci_dev_test_flag(hdev, HCI_BONDABLE) && + (auth & SMP_AUTH_BONDING)) + return SMP_PAIRING_NOTSUPP; + + skb_pull(skb, sizeof(*rp)); + + memset(&cp, 0, sizeof(cp)); + build_pairing_cmd(conn, &cp, NULL, auth); + + smp->preq[0] = SMP_CMD_PAIRING_REQ; + memcpy(&smp->preq[1], &cp, sizeof(cp)); + + smp_send_cmd(conn, SMP_CMD_PAIRING_REQ, sizeof(cp), &cp); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RSP); + + return 0; +} + +int smp_conn_security(struct hci_conn *hcon, __u8 sec_level) +{ + struct l2cap_conn *conn = hcon->l2cap_data; + struct l2cap_chan *chan; + struct smp_chan *smp; + __u8 authreq; + int ret; + + bt_dev_dbg(hcon->hdev, "conn %p hcon %p level 0x%2.2x", conn, hcon, + sec_level); + + /* This may be NULL if there's an unexpected disconnection */ + if (!conn) + return 1; + + if (!hci_dev_test_flag(hcon->hdev, HCI_LE_ENABLED)) + return 1; + + if (smp_sufficient_security(hcon, sec_level, SMP_USE_LTK)) + return 1; + + if (sec_level > hcon->pending_sec_level) + hcon->pending_sec_level = sec_level; + + if (hcon->role == HCI_ROLE_MASTER) + if (smp_ltk_encrypt(conn, hcon->pending_sec_level)) + return 0; + + chan = conn->smp; + if (!chan) { + bt_dev_err(hcon->hdev, "security requested but not available"); + return 1; + } + + l2cap_chan_lock(chan); + + /* If SMP is already in progress ignore this request */ + if (chan->data) { + ret = 0; + goto unlock; + } + + smp = smp_chan_create(conn); + if (!smp) { + ret = 1; + goto unlock; + } + + authreq = seclevel_to_authreq(sec_level); + + if (hci_dev_test_flag(hcon->hdev, HCI_SC_ENABLED)) { + authreq |= SMP_AUTH_SC; + if (hci_dev_test_flag(hcon->hdev, HCI_SSP_ENABLED)) + authreq |= SMP_AUTH_CT2; + } + + /* Don't attempt to set MITM if setting is overridden by debugfs + * Needed to pass certification test SM/MAS/PKE/BV-01-C + */ + if (!hci_dev_test_flag(hcon->hdev, HCI_FORCE_NO_MITM)) { + /* Require MITM if IO Capability allows or the security level + * requires it. 
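+ *
+ * E.g. a KeyboardDisplay capable device, or any request with a
+ * pending security level above BT_SECURITY_MEDIUM, sets
+ * SMP_AUTH_MITM here, while NoInputNoOutput at medium security
+ * leaves it cleared.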
+ */ + if (hcon->io_capability != HCI_IO_NO_INPUT_OUTPUT || + hcon->pending_sec_level > BT_SECURITY_MEDIUM) + authreq |= SMP_AUTH_MITM; + } + + if (hcon->role == HCI_ROLE_MASTER) { + struct smp_cmd_pairing cp; + + build_pairing_cmd(conn, &cp, NULL, authreq); + smp->preq[0] = SMP_CMD_PAIRING_REQ; + memcpy(&smp->preq[1], &cp, sizeof(cp)); + + smp_send_cmd(conn, SMP_CMD_PAIRING_REQ, sizeof(cp), &cp); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RSP); + } else { + struct smp_cmd_security_req cp; + cp.auth_req = authreq; + smp_send_cmd(conn, SMP_CMD_SECURITY_REQ, sizeof(cp), &cp); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_REQ); + } + + set_bit(SMP_FLAG_INITIATOR, &smp->flags); + ret = 0; + +unlock: + l2cap_chan_unlock(chan); + return ret; +} + +int smp_cancel_and_remove_pairing(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type) +{ + struct hci_conn *hcon; + struct l2cap_conn *conn; + struct l2cap_chan *chan; + struct smp_chan *smp; + int err; + + err = hci_remove_ltk(hdev, bdaddr, addr_type); + hci_remove_irk(hdev, bdaddr, addr_type); + + hcon = hci_conn_hash_lookup_le(hdev, bdaddr, addr_type); + if (!hcon) + goto done; + + conn = hcon->l2cap_data; + if (!conn) + goto done; + + chan = conn->smp; + if (!chan) + goto done; + + l2cap_chan_lock(chan); + + smp = chan->data; + if (smp) { + /* Set keys to NULL to make sure smp_failure() does not try to + * remove and free already invalidated rcu list entries. */ + smp->ltk = NULL; + smp->responder_ltk = NULL; + smp->remote_irk = NULL; + + if (test_bit(SMP_FLAG_COMPLETE, &smp->flags)) + smp_failure(conn, 0); + else + smp_failure(conn, SMP_UNSPECIFIED); + err = 0; + } + + l2cap_chan_unlock(chan); + +done: + return err; +} + +static int smp_cmd_encrypt_info(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_encrypt_info *rp = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + + bt_dev_dbg(conn->hcon->hdev, "conn %p", conn); + + if (skb->len < sizeof(*rp)) + return SMP_INVALID_PARAMS; + + /* Pairing is aborted if any blocked keys are distributed */ + if (hci_is_blocked_key(conn->hcon->hdev, HCI_BLOCKED_KEY_TYPE_LTK, + rp->ltk)) { + bt_dev_warn_ratelimited(conn->hcon->hdev, + "LTK blocked for %pMR", + &conn->hcon->dst); + return SMP_INVALID_PARAMS; + } + + SMP_ALLOW_CMD(smp, SMP_CMD_INITIATOR_IDENT); + + skb_pull(skb, sizeof(*rp)); + + memcpy(smp->tk, rp->ltk, sizeof(smp->tk)); + + return 0; +} + +static int smp_cmd_initiator_ident(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_initiator_ident *rp = (void *)skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_dev *hdev = conn->hcon->hdev; + struct hci_conn *hcon = conn->hcon; + struct smp_ltk *ltk; + u8 authenticated; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (skb->len < sizeof(*rp)) + return SMP_INVALID_PARAMS; + + /* Mark the information as received */ + smp->remote_key_dist &= ~SMP_DIST_ENC_KEY; + + if (smp->remote_key_dist & SMP_DIST_ID_KEY) + SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_INFO); + else if (smp->remote_key_dist & SMP_DIST_SIGN) + SMP_ALLOW_CMD(smp, SMP_CMD_SIGN_INFO); + + skb_pull(skb, sizeof(*rp)); + + authenticated = (hcon->sec_level == BT_SECURITY_HIGH); + ltk = hci_add_ltk(hdev, &hcon->dst, hcon->dst_type, SMP_LTK, + authenticated, smp->tk, smp->enc_key_size, + rp->ediv, rp->rand); + smp->ltk = ltk; + if (!(smp->remote_key_dist & KEY_DIST_MASK)) + smp_distribute_keys(smp); + + return 0; +} + +static int smp_cmd_ident_info(struct l2cap_conn *conn, struct sk_buff *skb) +{ + 
struct smp_cmd_ident_info *info = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + + bt_dev_dbg(conn->hcon->hdev, ""); + + if (skb->len < sizeof(*info)) + return SMP_INVALID_PARAMS; + + /* Pairing is aborted if any blocked keys are distributed */ + if (hci_is_blocked_key(conn->hcon->hdev, HCI_BLOCKED_KEY_TYPE_IRK, + info->irk)) { + bt_dev_warn_ratelimited(conn->hcon->hdev, + "Identity key blocked for %pMR", + &conn->hcon->dst); + return SMP_INVALID_PARAMS; + } + + SMP_ALLOW_CMD(smp, SMP_CMD_IDENT_ADDR_INFO); + + skb_pull(skb, sizeof(*info)); + + memcpy(smp->irk, info->irk, 16); + + return 0; +} + +static int smp_cmd_ident_addr_info(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + struct smp_cmd_ident_addr_info *info = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_conn *hcon = conn->hcon; + bdaddr_t rpa; + + bt_dev_dbg(hcon->hdev, ""); + + if (skb->len < sizeof(*info)) + return SMP_INVALID_PARAMS; + + /* Mark the information as received */ + smp->remote_key_dist &= ~SMP_DIST_ID_KEY; + + if (smp->remote_key_dist & SMP_DIST_SIGN) + SMP_ALLOW_CMD(smp, SMP_CMD_SIGN_INFO); + + skb_pull(skb, sizeof(*info)); + + /* Strictly speaking the Core Specification (4.1) allows sending + * an empty address which would force us to rely on just the IRK + * as "identity information". However, since such + * implementations are not known of and in order to not over + * complicate our implementation, simply pretend that we never + * received an IRK for such a device. + * + * The Identity Address must also be a Static Random or Public + * Address, which hci_is_identity_address() checks for. + */ + if (!bacmp(&info->bdaddr, BDADDR_ANY) || + !hci_is_identity_address(&info->bdaddr, info->addr_type)) { + bt_dev_err(hcon->hdev, "ignoring IRK with no identity address"); + goto distribute; + } + + /* Drop IRK if peer is using identity address during pairing but is + * providing different address as identity information. + * + * Microsoft Surface Precision Mouse is known to have this bug. 
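+ *
+ * E.g. a peer that connected using public address A but then
+ * distributes identity address B != A (or a different address type)
+ * has its IRK ignored here; the remaining keys are still accepted
+ * and distributed as usual.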
+ */ + if (hci_is_identity_address(&hcon->dst, hcon->dst_type) && + (bacmp(&info->bdaddr, &hcon->dst) || + info->addr_type != hcon->dst_type)) { + bt_dev_err(hcon->hdev, + "ignoring IRK with invalid identity address"); + goto distribute; + } + + bacpy(&smp->id_addr, &info->bdaddr); + smp->id_addr_type = info->addr_type; + + if (hci_bdaddr_is_rpa(&hcon->dst, hcon->dst_type)) + bacpy(&rpa, &hcon->dst); + else + bacpy(&rpa, BDADDR_ANY); + + smp->remote_irk = hci_add_irk(conn->hcon->hdev, &smp->id_addr, + smp->id_addr_type, smp->irk, &rpa); + +distribute: + if (!(smp->remote_key_dist & KEY_DIST_MASK)) + smp_distribute_keys(smp); + + return 0; +} + +static int smp_cmd_sign_info(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_sign_info *rp = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct smp_csrk *csrk; + + bt_dev_dbg(conn->hcon->hdev, "conn %p", conn); + + if (skb->len < sizeof(*rp)) + return SMP_INVALID_PARAMS; + + /* Mark the information as received */ + smp->remote_key_dist &= ~SMP_DIST_SIGN; + + skb_pull(skb, sizeof(*rp)); + + csrk = kzalloc(sizeof(*csrk), GFP_KERNEL); + if (csrk) { + if (conn->hcon->sec_level > BT_SECURITY_MEDIUM) + csrk->type = MGMT_CSRK_REMOTE_AUTHENTICATED; + else + csrk->type = MGMT_CSRK_REMOTE_UNAUTHENTICATED; + memcpy(csrk->val, rp->csrk, sizeof(csrk->val)); + } + smp->csrk = csrk; + smp_distribute_keys(smp); + + return 0; +} + +static u8 sc_select_method(struct smp_chan *smp) +{ + struct l2cap_conn *conn = smp->conn; + struct hci_conn *hcon = conn->hcon; + struct smp_cmd_pairing *local, *remote; + u8 local_mitm, remote_mitm, local_io, remote_io, method; + + if (test_bit(SMP_FLAG_REMOTE_OOB, &smp->flags) || + test_bit(SMP_FLAG_LOCAL_OOB, &smp->flags)) + return REQ_OOB; + + /* The preq/prsp contain the raw Pairing Request/Response PDUs + * which are needed as inputs to some crypto functions. To get + * the "struct smp_cmd_pairing" from them we need to skip the + * first byte which contains the opcode. + */ + if (hcon->out) { + local = (void *) &smp->preq[1]; + remote = (void *) &smp->prsp[1]; + } else { + local = (void *) &smp->prsp[1]; + remote = (void *) &smp->preq[1]; + } + + local_io = local->io_capability; + remote_io = remote->io_capability; + + local_mitm = (local->auth_req & SMP_AUTH_MITM); + remote_mitm = (remote->auth_req & SMP_AUTH_MITM); + + /* If either side wants MITM, look up the method from the table, + * otherwise use JUST WORKS. + */ + if (local_mitm || remote_mitm) + method = get_auth_method(smp, local_io, remote_io); + else + method = JUST_WORKS; + + /* Don't confirm locally initiated pairing attempts */ + if (method == JUST_CFM && test_bit(SMP_FLAG_INITIATOR, &smp->flags)) + method = JUST_WORKS; + + return method; +} + +static int smp_cmd_public_key(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_public_key *key = (void *) skb->data; + struct hci_conn *hcon = conn->hcon; + struct l2cap_chan *chan = conn->smp; + struct smp_chan *smp = chan->data; + struct hci_dev *hdev = hcon->hdev; + struct crypto_kpp *tfm_ecdh; + struct smp_cmd_pairing_confirm cfm; + int err; + + bt_dev_dbg(hdev, "conn %p", conn); + + if (skb->len < sizeof(*key)) + return SMP_INVALID_PARAMS; + + /* Check if remote and local public keys are the same and debug key is + * not in use. 
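+ *
+ * This rejects a peer that simply reflects our own P-256 public key
+ * back at us (SMP_UNSPECIFIED); the check is skipped when the
+ * well-known debug key is in use, since both sides then legitimately
+ * present the same key.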
+ */ + if (!test_bit(SMP_FLAG_DEBUG_KEY, &smp->flags) && + !crypto_memneq(key, smp->local_pk, 64)) { + bt_dev_err(hdev, "Remote and local public keys are identical"); + return SMP_UNSPECIFIED; + } + + memcpy(smp->remote_pk, key, 64); + + if (test_bit(SMP_FLAG_REMOTE_OOB, &smp->flags)) { + err = smp_f4(smp->tfm_cmac, smp->remote_pk, smp->remote_pk, + smp->rr, 0, cfm.confirm_val); + if (err) + return SMP_UNSPECIFIED; + + if (crypto_memneq(cfm.confirm_val, smp->pcnf, 16)) + return SMP_CONFIRM_FAILED; + } + + /* Non-initiating device sends its public key after receiving + * the key from the initiating device. + */ + if (!hcon->out) { + err = sc_send_public_key(smp); + if (err) + return err; + } + + SMP_DBG("Remote Public Key X: %32phN", smp->remote_pk); + SMP_DBG("Remote Public Key Y: %32phN", smp->remote_pk + 32); + + /* Compute the shared secret on the same crypto tfm on which the private + * key was set/generated. + */ + if (test_bit(SMP_FLAG_LOCAL_OOB, &smp->flags)) { + struct l2cap_chan *hchan = hdev->smp_data; + struct smp_dev *smp_dev; + + if (!hchan || !hchan->data) + return SMP_UNSPECIFIED; + + smp_dev = hchan->data; + + tfm_ecdh = smp_dev->tfm_ecdh; + } else { + tfm_ecdh = smp->tfm_ecdh; + } + + if (compute_ecdh_secret(tfm_ecdh, smp->remote_pk, smp->dhkey)) + return SMP_UNSPECIFIED; + + SMP_DBG("DHKey %32phN", smp->dhkey); + + set_bit(SMP_FLAG_REMOTE_PK, &smp->flags); + + smp->method = sc_select_method(smp); + + bt_dev_dbg(hdev, "selected method 0x%02x", smp->method); + + /* JUST_WORKS and JUST_CFM result in an unauthenticated key */ + if (smp->method == JUST_WORKS || smp->method == JUST_CFM) + hcon->pending_sec_level = BT_SECURITY_MEDIUM; + else + hcon->pending_sec_level = BT_SECURITY_FIPS; + + if (!crypto_memneq(debug_pk, smp->remote_pk, 64)) + set_bit(SMP_FLAG_DEBUG_KEY, &smp->flags); + + if (smp->method == DSP_PASSKEY) { + get_random_bytes(&hcon->passkey_notify, + sizeof(hcon->passkey_notify)); + hcon->passkey_notify %= 1000000; + hcon->passkey_entered = 0; + smp->passkey_round = 0; + if (mgmt_user_passkey_notify(hdev, &hcon->dst, hcon->type, + hcon->dst_type, + hcon->passkey_notify, + hcon->passkey_entered)) + return SMP_UNSPECIFIED; + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + return sc_passkey_round(smp, SMP_CMD_PUBLIC_KEY); + } + + if (smp->method == REQ_OOB) { + if (hcon->out) + smp_send_cmd(conn, SMP_CMD_PAIRING_RANDOM, + sizeof(smp->prnd), smp->prnd); + + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + + return 0; + } + + if (hcon->out) + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + + if (smp->method == REQ_PASSKEY) { + if (mgmt_user_passkey_request(hdev, &hcon->dst, hcon->type, + hcon->dst_type)) + return SMP_UNSPECIFIED; + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_CONFIRM); + set_bit(SMP_FLAG_WAIT_USER, &smp->flags); + return 0; + } + + /* The Initiating device waits for the non-initiating device to + * send the confirm value. 
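+ *
+ * As responder we therefore compute f4(PKb, PKa, Nb, 0) below and
+ * send it as Pairing Confirm before allowing Pairing Random; as
+ * initiator we simply return here and wait for that confirm.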
+ */ + if (conn->hcon->out) + return 0; + + err = smp_f4(smp->tfm_cmac, smp->local_pk, smp->remote_pk, smp->prnd, + 0, cfm.confirm_val); + if (err) + return SMP_UNSPECIFIED; + + smp_send_cmd(conn, SMP_CMD_PAIRING_CONFIRM, sizeof(cfm), &cfm); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RANDOM); + + return 0; +} + +static int smp_cmd_dhkey_check(struct l2cap_conn *conn, struct sk_buff *skb) +{ + struct smp_cmd_dhkey_check *check = (void *) skb->data; + struct l2cap_chan *chan = conn->smp; + struct hci_conn *hcon = conn->hcon; + struct smp_chan *smp = chan->data; + u8 a[7], b[7], *local_addr, *remote_addr; + u8 io_cap[3], r[16], e[16]; + int err; + + bt_dev_dbg(hcon->hdev, "conn %p", conn); + + if (skb->len < sizeof(*check)) + return SMP_INVALID_PARAMS; + + memcpy(a, &hcon->init_addr, 6); + memcpy(b, &hcon->resp_addr, 6); + a[6] = hcon->init_addr_type; + b[6] = hcon->resp_addr_type; + + if (hcon->out) { + local_addr = a; + remote_addr = b; + memcpy(io_cap, &smp->prsp[1], 3); + } else { + local_addr = b; + remote_addr = a; + memcpy(io_cap, &smp->preq[1], 3); + } + + memset(r, 0, sizeof(r)); + + if (smp->method == REQ_PASSKEY || smp->method == DSP_PASSKEY) + put_unaligned_le32(hcon->passkey_notify, r); + else if (smp->method == REQ_OOB) + memcpy(r, smp->lr, 16); + + err = smp_f6(smp->tfm_cmac, smp->mackey, smp->rrnd, smp->prnd, r, + io_cap, remote_addr, local_addr, e); + if (err) + return SMP_UNSPECIFIED; + + if (crypto_memneq(check->e, e, 16)) + return SMP_DHKEY_CHECK_FAILED; + + if (!hcon->out) { + if (test_bit(SMP_FLAG_WAIT_USER, &smp->flags)) { + set_bit(SMP_FLAG_DHKEY_PENDING, &smp->flags); + return 0; + } + + /* Responder sends DHKey check as response to initiator */ + sc_dhkey_check(smp); + } + + sc_add_ltk(smp); + + if (hcon->out) { + hci_le_start_enc(hcon, 0, 0, smp->tk, smp->enc_key_size); + hcon->enc_key_size = smp->enc_key_size; + } + + return 0; +} + +static int smp_cmd_keypress_notify(struct l2cap_conn *conn, + struct sk_buff *skb) +{ + struct smp_cmd_keypress_notify *kp = (void *) skb->data; + + bt_dev_dbg(conn->hcon->hdev, "value 0x%02x", kp->value); + + return 0; +} + +static int smp_sig_channel(struct l2cap_chan *chan, struct sk_buff *skb) +{ + struct l2cap_conn *conn = chan->conn; + struct hci_conn *hcon = conn->hcon; + struct smp_chan *smp; + __u8 code, reason; + int err = 0; + + if (skb->len < 1) + return -EILSEQ; + + if (!hci_dev_test_flag(hcon->hdev, HCI_LE_ENABLED)) { + reason = SMP_PAIRING_NOTSUPP; + goto done; + } + + code = skb->data[0]; + skb_pull(skb, sizeof(code)); + + smp = chan->data; + + if (code > SMP_CMD_MAX) + goto drop; + + if (smp && !test_and_clear_bit(code, &smp->allow_cmd)) + goto drop; + + /* If we don't have a context the only allowed commands are + * pairing request and security request. 
+ */ + if (!smp && code != SMP_CMD_PAIRING_REQ && code != SMP_CMD_SECURITY_REQ) + goto drop; + + switch (code) { + case SMP_CMD_PAIRING_REQ: + reason = smp_cmd_pairing_req(conn, skb); + break; + + case SMP_CMD_PAIRING_FAIL: + smp_failure(conn, 0); + err = -EPERM; + break; + + case SMP_CMD_PAIRING_RSP: + reason = smp_cmd_pairing_rsp(conn, skb); + break; + + case SMP_CMD_SECURITY_REQ: + reason = smp_cmd_security_req(conn, skb); + break; + + case SMP_CMD_PAIRING_CONFIRM: + reason = smp_cmd_pairing_confirm(conn, skb); + break; + + case SMP_CMD_PAIRING_RANDOM: + reason = smp_cmd_pairing_random(conn, skb); + break; + + case SMP_CMD_ENCRYPT_INFO: + reason = smp_cmd_encrypt_info(conn, skb); + break; + + case SMP_CMD_INITIATOR_IDENT: + reason = smp_cmd_initiator_ident(conn, skb); + break; + + case SMP_CMD_IDENT_INFO: + reason = smp_cmd_ident_info(conn, skb); + break; + + case SMP_CMD_IDENT_ADDR_INFO: + reason = smp_cmd_ident_addr_info(conn, skb); + break; + + case SMP_CMD_SIGN_INFO: + reason = smp_cmd_sign_info(conn, skb); + break; + + case SMP_CMD_PUBLIC_KEY: + reason = smp_cmd_public_key(conn, skb); + break; + + case SMP_CMD_DHKEY_CHECK: + reason = smp_cmd_dhkey_check(conn, skb); + break; + + case SMP_CMD_KEYPRESS_NOTIFY: + reason = smp_cmd_keypress_notify(conn, skb); + break; + + default: + bt_dev_dbg(hcon->hdev, "Unknown command code 0x%2.2x", code); + reason = SMP_CMD_NOTSUPP; + goto done; + } + +done: + if (!err) { + if (reason) + smp_failure(conn, reason); + kfree_skb(skb); + } + + return err; + +drop: + bt_dev_err(hcon->hdev, "unexpected SMP command 0x%02x from %pMR", + code, &hcon->dst); + kfree_skb(skb); + return 0; +} + +static void smp_teardown_cb(struct l2cap_chan *chan, int err) +{ + struct l2cap_conn *conn = chan->conn; + + bt_dev_dbg(conn->hcon->hdev, "chan %p", chan); + + if (chan->data) + smp_chan_destroy(conn); + + conn->smp = NULL; + l2cap_chan_put(chan); +} + +static void bredr_pairing(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct hci_conn *hcon = conn->hcon; + struct hci_dev *hdev = hcon->hdev; + struct smp_cmd_pairing req; + struct smp_chan *smp; + + bt_dev_dbg(hdev, "chan %p", chan); + + /* Only new pairings are interesting */ + if (!test_bit(HCI_CONN_NEW_LINK_KEY, &hcon->flags)) + return; + + /* Don't bother if we're not encrypted */ + if (!test_bit(HCI_CONN_ENCRYPT, &hcon->flags)) + return; + + /* Only initiator may initiate SMP over BR/EDR */ + if (hcon->role != HCI_ROLE_MASTER) + return; + + /* Secure Connections support must be enabled */ + if (!hci_dev_test_flag(hdev, HCI_SC_ENABLED)) + return; + + /* BR/EDR must use Secure Connections for SMP */ + if (!test_bit(HCI_CONN_AES_CCM, &hcon->flags) && + !hci_dev_test_flag(hdev, HCI_FORCE_BREDR_SMP)) + return; + + /* If our LE support is not enabled don't do anything */ + if (!hci_dev_test_flag(hdev, HCI_LE_ENABLED)) + return; + + /* Don't bother if remote LE support is not enabled */ + if (!lmp_host_le_capable(hcon)) + return; + + /* Remote must support SMP fixed chan for BR/EDR */ + if (!(conn->remote_fixed_chan & L2CAP_FC_SMP_BREDR)) + return; + + /* Don't bother if SMP is already ongoing */ + if (chan->data) + return; + + smp = smp_chan_create(conn); + if (!smp) { + bt_dev_err(hdev, "unable to create SMP context for BR/EDR"); + return; + } + + set_bit(SMP_FLAG_SC, &smp->flags); + + bt_dev_dbg(hdev, "starting SMP over BR/EDR"); + + /* Prepare and send the BR/EDR SMP Pairing Request */ + build_bredr_pairing_cmd(smp, &req, NULL); + + smp->preq[0] = SMP_CMD_PAIRING_REQ; + memcpy(&smp->preq[1], 
&req, sizeof(req)); + + smp_send_cmd(conn, SMP_CMD_PAIRING_REQ, sizeof(req), &req); + SMP_ALLOW_CMD(smp, SMP_CMD_PAIRING_RSP); +} + +static void smp_resume_cb(struct l2cap_chan *chan) +{ + struct smp_chan *smp = chan->data; + struct l2cap_conn *conn = chan->conn; + struct hci_conn *hcon = conn->hcon; + + bt_dev_dbg(hcon->hdev, "chan %p", chan); + + if (hcon->type == ACL_LINK) { + bredr_pairing(chan); + return; + } + + if (!smp) + return; + + if (!test_bit(HCI_CONN_ENCRYPT, &hcon->flags)) + return; + + cancel_delayed_work(&smp->security_timer); + + smp_distribute_keys(smp); +} + +static void smp_ready_cb(struct l2cap_chan *chan) +{ + struct l2cap_conn *conn = chan->conn; + struct hci_conn *hcon = conn->hcon; + + bt_dev_dbg(hcon->hdev, "chan %p", chan); + + /* No need to call l2cap_chan_hold() here since we already own + * the reference taken in smp_new_conn_cb(). This is just the + * first time that we tie it to a specific pointer. The code in + * l2cap_core.c ensures that there's no risk this function wont + * get called if smp_new_conn_cb was previously called. + */ + conn->smp = chan; + + if (hcon->type == ACL_LINK && test_bit(HCI_CONN_ENCRYPT, &hcon->flags)) + bredr_pairing(chan); +} + +static int smp_recv_cb(struct l2cap_chan *chan, struct sk_buff *skb) +{ + int err; + + bt_dev_dbg(chan->conn->hcon->hdev, "chan %p", chan); + + err = smp_sig_channel(chan, skb); + if (err) { + struct smp_chan *smp = chan->data; + + if (smp) + cancel_delayed_work_sync(&smp->security_timer); + + hci_disconnect(chan->conn->hcon, HCI_ERROR_AUTH_FAILURE); + } + + return err; +} + +static struct sk_buff *smp_alloc_skb_cb(struct l2cap_chan *chan, + unsigned long hdr_len, + unsigned long len, int nb) +{ + struct sk_buff *skb; + + skb = bt_skb_alloc(hdr_len + len, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + skb->priority = HCI_PRIO_MAX; + bt_cb(skb)->l2cap.chan = chan; + + return skb; +} + +static const struct l2cap_ops smp_chan_ops = { + .name = "Security Manager", + .ready = smp_ready_cb, + .recv = smp_recv_cb, + .alloc_skb = smp_alloc_skb_cb, + .teardown = smp_teardown_cb, + .resume = smp_resume_cb, + + .new_connection = l2cap_chan_no_new_connection, + .state_change = l2cap_chan_no_state_change, + .close = l2cap_chan_no_close, + .defer = l2cap_chan_no_defer, + .suspend = l2cap_chan_no_suspend, + .set_shutdown = l2cap_chan_no_set_shutdown, + .get_sndtimeo = l2cap_chan_no_get_sndtimeo, +}; + +static inline struct l2cap_chan *smp_new_conn_cb(struct l2cap_chan *pchan) +{ + struct l2cap_chan *chan; + + BT_DBG("pchan %p", pchan); + + chan = l2cap_chan_create(); + if (!chan) + return NULL; + + chan->chan_type = pchan->chan_type; + chan->ops = &smp_chan_ops; + chan->scid = pchan->scid; + chan->dcid = chan->scid; + chan->imtu = pchan->imtu; + chan->omtu = pchan->omtu; + chan->mode = pchan->mode; + + /* Other L2CAP channels may request SMP routines in order to + * change the security level. This means that the SMP channel + * lock must be considered in its own category to avoid lockdep + * warnings. 
+ */ + atomic_set(&chan->nesting, L2CAP_NESTING_SMP); + + BT_DBG("created chan %p", chan); + + return chan; +} + +static const struct l2cap_ops smp_root_chan_ops = { + .name = "Security Manager Root", + .new_connection = smp_new_conn_cb, + + /* None of these are implemented for the root channel */ + .close = l2cap_chan_no_close, + .alloc_skb = l2cap_chan_no_alloc_skb, + .recv = l2cap_chan_no_recv, + .state_change = l2cap_chan_no_state_change, + .teardown = l2cap_chan_no_teardown, + .ready = l2cap_chan_no_ready, + .defer = l2cap_chan_no_defer, + .suspend = l2cap_chan_no_suspend, + .resume = l2cap_chan_no_resume, + .set_shutdown = l2cap_chan_no_set_shutdown, + .get_sndtimeo = l2cap_chan_no_get_sndtimeo, +}; + +static struct l2cap_chan *smp_add_cid(struct hci_dev *hdev, u16 cid) +{ + struct l2cap_chan *chan; + struct smp_dev *smp; + struct crypto_shash *tfm_cmac; + struct crypto_kpp *tfm_ecdh; + + if (cid == L2CAP_CID_SMP_BREDR) { + smp = NULL; + goto create_chan; + } + + smp = kzalloc(sizeof(*smp), GFP_KERNEL); + if (!smp) + return ERR_PTR(-ENOMEM); + + tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); + if (IS_ERR(tfm_cmac)) { + bt_dev_err(hdev, "Unable to create CMAC crypto context"); + kfree_sensitive(smp); + return ERR_CAST(tfm_cmac); + } + + tfm_ecdh = crypto_alloc_kpp("ecdh-nist-p256", 0, 0); + if (IS_ERR(tfm_ecdh)) { + bt_dev_err(hdev, "Unable to create ECDH crypto context"); + crypto_free_shash(tfm_cmac); + kfree_sensitive(smp); + return ERR_CAST(tfm_ecdh); + } + + smp->local_oob = false; + smp->tfm_cmac = tfm_cmac; + smp->tfm_ecdh = tfm_ecdh; + +create_chan: + chan = l2cap_chan_create(); + if (!chan) { + if (smp) { + crypto_free_shash(smp->tfm_cmac); + crypto_free_kpp(smp->tfm_ecdh); + kfree_sensitive(smp); + } + return ERR_PTR(-ENOMEM); + } + + chan->data = smp; + + l2cap_add_scid(chan, cid); + + l2cap_chan_set_defaults(chan); + + if (cid == L2CAP_CID_SMP) { + u8 bdaddr_type; + + hci_copy_identity_address(hdev, &chan->src, &bdaddr_type); + + if (bdaddr_type == ADDR_LE_DEV_PUBLIC) + chan->src_type = BDADDR_LE_PUBLIC; + else + chan->src_type = BDADDR_LE_RANDOM; + } else { + bacpy(&chan->src, &hdev->bdaddr); + chan->src_type = BDADDR_BREDR; + } + + chan->state = BT_LISTEN; + chan->mode = L2CAP_MODE_BASIC; + chan->imtu = L2CAP_DEFAULT_MTU; + chan->ops = &smp_root_chan_ops; + + /* Set correct nesting level for a parent/listening channel */ + atomic_set(&chan->nesting, L2CAP_NESTING_PARENT); + + return chan; +} + +static void smp_del_chan(struct l2cap_chan *chan) +{ + struct smp_dev *smp; + + BT_DBG("chan %p", chan); + + smp = chan->data; + if (smp) { + chan->data = NULL; + crypto_free_shash(smp->tfm_cmac); + crypto_free_kpp(smp->tfm_ecdh); + kfree_sensitive(smp); + } + + l2cap_chan_put(chan); +} + +int smp_force_bredr(struct hci_dev *hdev, bool enable) +{ + if (enable == hci_dev_test_flag(hdev, HCI_FORCE_BREDR_SMP)) + return -EALREADY; + + if (enable) { + struct l2cap_chan *chan; + + chan = smp_add_cid(hdev, L2CAP_CID_SMP_BREDR); + if (IS_ERR(chan)) + return PTR_ERR(chan); + + hdev->smp_bredr_data = chan; + } else { + struct l2cap_chan *chan; + + chan = hdev->smp_bredr_data; + hdev->smp_bredr_data = NULL; + smp_del_chan(chan); + } + + hci_dev_change_flag(hdev, HCI_FORCE_BREDR_SMP); + + return 0; +} + +int smp_register(struct hci_dev *hdev) +{ + struct l2cap_chan *chan; + + bt_dev_dbg(hdev, ""); + + /* If the controller does not support Low Energy operation, then + * there is also no need to register any SMP channel. 
+ */ + if (!lmp_le_capable(hdev)) + return 0; + + if (WARN_ON(hdev->smp_data)) { + chan = hdev->smp_data; + hdev->smp_data = NULL; + smp_del_chan(chan); + } + + chan = smp_add_cid(hdev, L2CAP_CID_SMP); + if (IS_ERR(chan)) + return PTR_ERR(chan); + + hdev->smp_data = chan; + + if (!lmp_sc_capable(hdev)) { + /* Flag can be already set here (due to power toggle) */ + if (!hci_dev_test_flag(hdev, HCI_FORCE_BREDR_SMP)) + return 0; + } + + if (WARN_ON(hdev->smp_bredr_data)) { + chan = hdev->smp_bredr_data; + hdev->smp_bredr_data = NULL; + smp_del_chan(chan); + } + + chan = smp_add_cid(hdev, L2CAP_CID_SMP_BREDR); + if (IS_ERR(chan)) { + int err = PTR_ERR(chan); + chan = hdev->smp_data; + hdev->smp_data = NULL; + smp_del_chan(chan); + return err; + } + + hdev->smp_bredr_data = chan; + + return 0; +} + +void smp_unregister(struct hci_dev *hdev) +{ + struct l2cap_chan *chan; + + if (hdev->smp_bredr_data) { + chan = hdev->smp_bredr_data; + hdev->smp_bredr_data = NULL; + smp_del_chan(chan); + } + + if (hdev->smp_data) { + chan = hdev->smp_data; + hdev->smp_data = NULL; + smp_del_chan(chan); + } +} + +#if IS_ENABLED(CONFIG_BT_SELFTEST_SMP) + +static int __init test_debug_key(struct crypto_kpp *tfm_ecdh) +{ + u8 pk[64]; + int err; + + err = set_ecdh_privkey(tfm_ecdh, debug_sk); + if (err) + return err; + + err = generate_ecdh_public_key(tfm_ecdh, pk); + if (err) + return err; + + if (crypto_memneq(pk, debug_pk, 64)) + return -EINVAL; + + return 0; +} + +static int __init test_ah(void) +{ + const u8 irk[16] = { + 0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34, + 0x05, 0xad, 0xc8, 0x57, 0xa3, 0x34, 0x02, 0xec }; + const u8 r[3] = { 0x94, 0x81, 0x70 }; + const u8 exp[3] = { 0xaa, 0xfb, 0x0d }; + u8 res[3]; + int err; + + err = smp_ah(irk, r, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 3)) + return -EINVAL; + + return 0; +} + +static int __init test_c1(void) +{ + const u8 k[16] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + const u8 r[16] = { + 0xe0, 0x2e, 0x70, 0xc6, 0x4e, 0x27, 0x88, 0x63, + 0x0e, 0x6f, 0xad, 0x56, 0x21, 0xd5, 0x83, 0x57 }; + const u8 preq[7] = { 0x01, 0x01, 0x00, 0x00, 0x10, 0x07, 0x07 }; + const u8 pres[7] = { 0x02, 0x03, 0x00, 0x00, 0x08, 0x00, 0x05 }; + const u8 _iat = 0x01; + const u8 _rat = 0x00; + const bdaddr_t ra = { { 0xb6, 0xb5, 0xb4, 0xb3, 0xb2, 0xb1 } }; + const bdaddr_t ia = { { 0xa6, 0xa5, 0xa4, 0xa3, 0xa2, 0xa1 } }; + const u8 exp[16] = { + 0x86, 0x3b, 0xf1, 0xbe, 0xc5, 0x4d, 0xa7, 0xd2, + 0xea, 0x88, 0x89, 0x87, 0xef, 0x3f, 0x1e, 0x1e }; + u8 res[16]; + int err; + + err = smp_c1(k, r, preq, pres, _iat, &ia, _rat, &ra, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 16)) + return -EINVAL; + + return 0; +} + +static int __init test_s1(void) +{ + const u8 k[16] = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + const u8 r1[16] = { + 0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11 }; + const u8 r2[16] = { + 0x00, 0xff, 0xee, 0xdd, 0xcc, 0xbb, 0xaa, 0x99 }; + const u8 exp[16] = { + 0x62, 0xa0, 0x6d, 0x79, 0xae, 0x16, 0x42, 0x5b, + 0x9b, 0xf4, 0xb0, 0xe8, 0xf0, 0xe1, 0x1f, 0x9a }; + u8 res[16]; + int err; + + err = smp_s1(k, r1, r2, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 16)) + return -EINVAL; + + return 0; +} + +static int __init test_f4(struct crypto_shash *tfm_cmac) +{ + const u8 u[32] = { + 0xe6, 0x9d, 0x35, 0x0e, 0x48, 0x01, 0x03, 0xcc, + 0xdb, 0xfd, 0xf4, 0xac, 0x11, 0x91, 0xf4, 0xef, + 0xb9, 0xa5, 
0xf9, 0xe9, 0xa7, 0x83, 0x2c, 0x5e, + 0x2c, 0xbe, 0x97, 0xf2, 0xd2, 0x03, 0xb0, 0x20 }; + const u8 v[32] = { + 0xfd, 0xc5, 0x7f, 0xf4, 0x49, 0xdd, 0x4f, 0x6b, + 0xfb, 0x7c, 0x9d, 0xf1, 0xc2, 0x9a, 0xcb, 0x59, + 0x2a, 0xe7, 0xd4, 0xee, 0xfb, 0xfc, 0x0a, 0x90, + 0x9a, 0xbb, 0xf6, 0x32, 0x3d, 0x8b, 0x18, 0x55 }; + const u8 x[16] = { + 0xab, 0xae, 0x2b, 0x71, 0xec, 0xb2, 0xff, 0xff, + 0x3e, 0x73, 0x77, 0xd1, 0x54, 0x84, 0xcb, 0xd5 }; + const u8 z = 0x00; + const u8 exp[16] = { + 0x2d, 0x87, 0x74, 0xa9, 0xbe, 0xa1, 0xed, 0xf1, + 0x1c, 0xbd, 0xa9, 0x07, 0xf1, 0x16, 0xc9, 0xf2 }; + u8 res[16]; + int err; + + err = smp_f4(tfm_cmac, u, v, x, z, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 16)) + return -EINVAL; + + return 0; +} + +static int __init test_f5(struct crypto_shash *tfm_cmac) +{ + const u8 w[32] = { + 0x98, 0xa6, 0xbf, 0x73, 0xf3, 0x34, 0x8d, 0x86, + 0xf1, 0x66, 0xf8, 0xb4, 0x13, 0x6b, 0x79, 0x99, + 0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34, + 0x05, 0xad, 0xc8, 0x57, 0xa3, 0x34, 0x02, 0xec }; + const u8 n1[16] = { + 0xab, 0xae, 0x2b, 0x71, 0xec, 0xb2, 0xff, 0xff, + 0x3e, 0x73, 0x77, 0xd1, 0x54, 0x84, 0xcb, 0xd5 }; + const u8 n2[16] = { + 0xcf, 0xc4, 0x3d, 0xff, 0xf7, 0x83, 0x65, 0x21, + 0x6e, 0x5f, 0xa7, 0x25, 0xcc, 0xe7, 0xe8, 0xa6 }; + const u8 a1[7] = { 0xce, 0xbf, 0x37, 0x37, 0x12, 0x56, 0x00 }; + const u8 a2[7] = { 0xc1, 0xcf, 0x2d, 0x70, 0x13, 0xa7, 0x00 }; + const u8 exp_ltk[16] = { + 0x38, 0x0a, 0x75, 0x94, 0xb5, 0x22, 0x05, 0x98, + 0x23, 0xcd, 0xd7, 0x69, 0x11, 0x79, 0x86, 0x69 }; + const u8 exp_mackey[16] = { + 0x20, 0x6e, 0x63, 0xce, 0x20, 0x6a, 0x3f, 0xfd, + 0x02, 0x4a, 0x08, 0xa1, 0x76, 0xf1, 0x65, 0x29 }; + u8 mackey[16], ltk[16]; + int err; + + err = smp_f5(tfm_cmac, w, n1, n2, a1, a2, mackey, ltk); + if (err) + return err; + + if (crypto_memneq(mackey, exp_mackey, 16)) + return -EINVAL; + + if (crypto_memneq(ltk, exp_ltk, 16)) + return -EINVAL; + + return 0; +} + +static int __init test_f6(struct crypto_shash *tfm_cmac) +{ + const u8 w[16] = { + 0x20, 0x6e, 0x63, 0xce, 0x20, 0x6a, 0x3f, 0xfd, + 0x02, 0x4a, 0x08, 0xa1, 0x76, 0xf1, 0x65, 0x29 }; + const u8 n1[16] = { + 0xab, 0xae, 0x2b, 0x71, 0xec, 0xb2, 0xff, 0xff, + 0x3e, 0x73, 0x77, 0xd1, 0x54, 0x84, 0xcb, 0xd5 }; + const u8 n2[16] = { + 0xcf, 0xc4, 0x3d, 0xff, 0xf7, 0x83, 0x65, 0x21, + 0x6e, 0x5f, 0xa7, 0x25, 0xcc, 0xe7, 0xe8, 0xa6 }; + const u8 r[16] = { + 0xc8, 0x0f, 0x2d, 0x0c, 0xd2, 0x42, 0xda, 0x08, + 0x54, 0xbb, 0x53, 0xb4, 0x3b, 0x34, 0xa3, 0x12 }; + const u8 io_cap[3] = { 0x02, 0x01, 0x01 }; + const u8 a1[7] = { 0xce, 0xbf, 0x37, 0x37, 0x12, 0x56, 0x00 }; + const u8 a2[7] = { 0xc1, 0xcf, 0x2d, 0x70, 0x13, 0xa7, 0x00 }; + const u8 exp[16] = { + 0x61, 0x8f, 0x95, 0xda, 0x09, 0x0b, 0x6c, 0xd2, + 0xc5, 0xe8, 0xd0, 0x9c, 0x98, 0x73, 0xc4, 0xe3 }; + u8 res[16]; + int err; + + err = smp_f6(tfm_cmac, w, n1, n2, r, io_cap, a1, a2, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 16)) + return -EINVAL; + + return 0; +} + +static int __init test_g2(struct crypto_shash *tfm_cmac) +{ + const u8 u[32] = { + 0xe6, 0x9d, 0x35, 0x0e, 0x48, 0x01, 0x03, 0xcc, + 0xdb, 0xfd, 0xf4, 0xac, 0x11, 0x91, 0xf4, 0xef, + 0xb9, 0xa5, 0xf9, 0xe9, 0xa7, 0x83, 0x2c, 0x5e, + 0x2c, 0xbe, 0x97, 0xf2, 0xd2, 0x03, 0xb0, 0x20 }; + const u8 v[32] = { + 0xfd, 0xc5, 0x7f, 0xf4, 0x49, 0xdd, 0x4f, 0x6b, + 0xfb, 0x7c, 0x9d, 0xf1, 0xc2, 0x9a, 0xcb, 0x59, + 0x2a, 0xe7, 0xd4, 0xee, 0xfb, 0xfc, 0x0a, 0x90, + 0x9a, 0xbb, 0xf6, 0x32, 0x3d, 0x8b, 0x18, 0x55 }; + const u8 x[16] = { + 0xab, 0xae, 0x2b, 0x71, 0xec, 0xb2, 0xff, 
0xff, + 0x3e, 0x73, 0x77, 0xd1, 0x54, 0x84, 0xcb, 0xd5 }; + const u8 y[16] = { + 0xcf, 0xc4, 0x3d, 0xff, 0xf7, 0x83, 0x65, 0x21, + 0x6e, 0x5f, 0xa7, 0x25, 0xcc, 0xe7, 0xe8, 0xa6 }; + const u32 exp_val = 0x2f9ed5ba % 1000000; + u32 val; + int err; + + err = smp_g2(tfm_cmac, u, v, x, y, &val); + if (err) + return err; + + if (val != exp_val) + return -EINVAL; + + return 0; +} + +static int __init test_h6(struct crypto_shash *tfm_cmac) +{ + const u8 w[16] = { + 0x9b, 0x7d, 0x39, 0x0a, 0xa6, 0x10, 0x10, 0x34, + 0x05, 0xad, 0xc8, 0x57, 0xa3, 0x34, 0x02, 0xec }; + const u8 key_id[4] = { 0x72, 0x62, 0x65, 0x6c }; + const u8 exp[16] = { + 0x99, 0x63, 0xb1, 0x80, 0xe2, 0xa9, 0xd3, 0xe8, + 0x1c, 0xc9, 0x6d, 0xe7, 0x02, 0xe1, 0x9a, 0x2d }; + u8 res[16]; + int err; + + err = smp_h6(tfm_cmac, w, key_id, res); + if (err) + return err; + + if (crypto_memneq(res, exp, 16)) + return -EINVAL; + + return 0; +} + +static char test_smp_buffer[32]; + +static ssize_t test_smp_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + return simple_read_from_buffer(user_buf, count, ppos, test_smp_buffer, + strlen(test_smp_buffer)); +} + +static const struct file_operations test_smp_fops = { + .open = simple_open, + .read = test_smp_read, + .llseek = default_llseek, +}; + +static int __init run_selftests(struct crypto_shash *tfm_cmac, + struct crypto_kpp *tfm_ecdh) +{ + ktime_t calltime, delta, rettime; + unsigned long long duration; + int err; + + calltime = ktime_get(); + + err = test_debug_key(tfm_ecdh); + if (err) { + BT_ERR("debug_key test failed"); + goto done; + } + + err = test_ah(); + if (err) { + BT_ERR("smp_ah test failed"); + goto done; + } + + err = test_c1(); + if (err) { + BT_ERR("smp_c1 test failed"); + goto done; + } + + err = test_s1(); + if (err) { + BT_ERR("smp_s1 test failed"); + goto done; + } + + err = test_f4(tfm_cmac); + if (err) { + BT_ERR("smp_f4 test failed"); + goto done; + } + + err = test_f5(tfm_cmac); + if (err) { + BT_ERR("smp_f5 test failed"); + goto done; + } + + err = test_f6(tfm_cmac); + if (err) { + BT_ERR("smp_f6 test failed"); + goto done; + } + + err = test_g2(tfm_cmac); + if (err) { + BT_ERR("smp_g2 test failed"); + goto done; + } + + err = test_h6(tfm_cmac); + if (err) { + BT_ERR("smp_h6 test failed"); + goto done; + } + + rettime = ktime_get(); + delta = ktime_sub(rettime, calltime); + duration = (unsigned long long) ktime_to_ns(delta) >> 10; + + BT_INFO("SMP test passed in %llu usecs", duration); + +done: + if (!err) + snprintf(test_smp_buffer, sizeof(test_smp_buffer), + "PASS (%llu usecs)\n", duration); + else + snprintf(test_smp_buffer, sizeof(test_smp_buffer), "FAIL\n"); + + debugfs_create_file("selftest_smp", 0444, bt_debugfs, NULL, + &test_smp_fops); + + return err; +} + +int __init bt_selftest_smp(void) +{ + struct crypto_shash *tfm_cmac; + struct crypto_kpp *tfm_ecdh; + int err; + + tfm_cmac = crypto_alloc_shash("cmac(aes)", 0, 0); + if (IS_ERR(tfm_cmac)) { + BT_ERR("Unable to create CMAC crypto context"); + return PTR_ERR(tfm_cmac); + } + + tfm_ecdh = crypto_alloc_kpp("ecdh-nist-p256", 0, 0); + if (IS_ERR(tfm_ecdh)) { + BT_ERR("Unable to create ECDH crypto context"); + crypto_free_shash(tfm_cmac); + return PTR_ERR(tfm_ecdh); + } + + err = run_selftests(tfm_cmac, tfm_ecdh); + + crypto_free_shash(tfm_cmac); + crypto_free_kpp(tfm_ecdh); + + return err; +} + +#endif diff --git a/net/bluetooth/smp.h b/net/bluetooth/smp.h new file mode 100644 index 000000000..87a59ec2c --- /dev/null +++ b/net/bluetooth/smp.h @@ -0,0 +1,214 @@ +/* + BlueZ - 
Bluetooth protocol stack for Linux + Copyright (C) 2011 Nokia Corporation and/or its subsidiary(-ies). + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License version 2 as + published by the Free Software Foundation; + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS. + IN NO EVENT SHALL THE COPYRIGHT HOLDER(S) AND AUTHOR(S) BE LIABLE FOR ANY + CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + + ALL LIABILITY, INCLUDING LIABILITY FOR INFRINGEMENT OF ANY PATENTS, + COPYRIGHTS, TRADEMARKS OR OTHER RIGHTS, RELATING TO USE OF THIS + SOFTWARE IS DISCLAIMED. +*/ + +#ifndef __SMP_H +#define __SMP_H + +struct smp_command_hdr { + __u8 code; +} __packed; + +#define SMP_CMD_PAIRING_REQ 0x01 +#define SMP_CMD_PAIRING_RSP 0x02 +struct smp_cmd_pairing { + __u8 io_capability; + __u8 oob_flag; + __u8 auth_req; + __u8 max_key_size; + __u8 init_key_dist; + __u8 resp_key_dist; +} __packed; + +#define SMP_IO_DISPLAY_ONLY 0x00 +#define SMP_IO_DISPLAY_YESNO 0x01 +#define SMP_IO_KEYBOARD_ONLY 0x02 +#define SMP_IO_NO_INPUT_OUTPUT 0x03 +#define SMP_IO_KEYBOARD_DISPLAY 0x04 + +#define SMP_OOB_NOT_PRESENT 0x00 +#define SMP_OOB_PRESENT 0x01 + +#define SMP_DIST_ENC_KEY 0x01 +#define SMP_DIST_ID_KEY 0x02 +#define SMP_DIST_SIGN 0x04 +#define SMP_DIST_LINK_KEY 0x08 + +#define SMP_AUTH_NONE 0x00 +#define SMP_AUTH_BONDING 0x01 +#define SMP_AUTH_MITM 0x04 +#define SMP_AUTH_SC 0x08 +#define SMP_AUTH_KEYPRESS 0x10 +#define SMP_AUTH_CT2 0x20 + +#define SMP_CMD_PAIRING_CONFIRM 0x03 +struct smp_cmd_pairing_confirm { + __u8 confirm_val[16]; +} __packed; + +#define SMP_CMD_PAIRING_RANDOM 0x04 +struct smp_cmd_pairing_random { + __u8 rand_val[16]; +} __packed; + +#define SMP_CMD_PAIRING_FAIL 0x05 +struct smp_cmd_pairing_fail { + __u8 reason; +} __packed; + +#define SMP_CMD_ENCRYPT_INFO 0x06 +struct smp_cmd_encrypt_info { + __u8 ltk[16]; +} __packed; + +#define SMP_CMD_INITIATOR_IDENT 0x07 +struct smp_cmd_initiator_ident { + __le16 ediv; + __le64 rand; +} __packed; + +#define SMP_CMD_IDENT_INFO 0x08 +struct smp_cmd_ident_info { + __u8 irk[16]; +} __packed; + +#define SMP_CMD_IDENT_ADDR_INFO 0x09 +struct smp_cmd_ident_addr_info { + __u8 addr_type; + bdaddr_t bdaddr; +} __packed; + +#define SMP_CMD_SIGN_INFO 0x0a +struct smp_cmd_sign_info { + __u8 csrk[16]; +} __packed; + +#define SMP_CMD_SECURITY_REQ 0x0b +struct smp_cmd_security_req { + __u8 auth_req; +} __packed; + +#define SMP_CMD_PUBLIC_KEY 0x0c +struct smp_cmd_public_key { + __u8 x[32]; + __u8 y[32]; +} __packed; + +#define SMP_CMD_DHKEY_CHECK 0x0d +struct smp_cmd_dhkey_check { + __u8 e[16]; +} __packed; + +#define SMP_CMD_KEYPRESS_NOTIFY 0x0e +struct smp_cmd_keypress_notify { + __u8 value; +} __packed; + +#define SMP_CMD_MAX 0x0e + +#define SMP_PASSKEY_ENTRY_FAILED 0x01 +#define SMP_OOB_NOT_AVAIL 0x02 +#define SMP_AUTH_REQUIREMENTS 0x03 +#define SMP_CONFIRM_FAILED 0x04 +#define SMP_PAIRING_NOTSUPP 0x05 +#define SMP_ENC_KEY_SIZE 0x06 +#define SMP_CMD_NOTSUPP 0x07 +#define SMP_UNSPECIFIED 0x08 +#define SMP_REPEATED_ATTEMPTS 0x09 +#define SMP_INVALID_PARAMS 0x0a +#define SMP_DHKEY_CHECK_FAILED 0x0b 
+#define SMP_NUMERIC_COMP_FAILED 0x0c +#define SMP_BREDR_PAIRING_IN_PROGRESS 0x0d +#define SMP_CROSS_TRANSP_NOT_ALLOWED 0x0e + +#define SMP_MIN_ENC_KEY_SIZE 7 +#define SMP_MAX_ENC_KEY_SIZE 16 + +/* LTK types used in internal storage (struct smp_ltk) */ +enum { + SMP_STK, + SMP_LTK, + SMP_LTK_RESPONDER, + SMP_LTK_P256, + SMP_LTK_P256_DEBUG, +}; + +static inline bool smp_ltk_is_sc(struct smp_ltk *key) +{ + switch (key->type) { + case SMP_LTK_P256: + case SMP_LTK_P256_DEBUG: + return true; + } + + return false; +} + +static inline u8 smp_ltk_sec_level(struct smp_ltk *key) +{ + if (key->authenticated) { + if (smp_ltk_is_sc(key)) + return BT_SECURITY_FIPS; + else + return BT_SECURITY_HIGH; + } + + return BT_SECURITY_MEDIUM; +} + +/* Key preferences for smp_sufficient security */ +enum smp_key_pref { + SMP_ALLOW_STK, + SMP_USE_LTK, +}; + +/* SMP Commands */ +int smp_cancel_and_remove_pairing(struct hci_dev *hdev, bdaddr_t *bdaddr, + u8 addr_type); +bool smp_sufficient_security(struct hci_conn *hcon, u8 sec_level, + enum smp_key_pref key_pref); +int smp_conn_security(struct hci_conn *hcon, __u8 sec_level); +int smp_user_confirm_reply(struct hci_conn *conn, u16 mgmt_op, __le32 passkey); + +bool smp_irk_matches(struct hci_dev *hdev, const u8 irk[16], + const bdaddr_t *bdaddr); +int smp_generate_rpa(struct hci_dev *hdev, const u8 irk[16], bdaddr_t *rpa); +int smp_generate_oob(struct hci_dev *hdev, u8 hash[16], u8 rand[16]); + +int smp_force_bredr(struct hci_dev *hdev, bool enable); + +int smp_register(struct hci_dev *hdev); +void smp_unregister(struct hci_dev *hdev); + +#if IS_ENABLED(CONFIG_BT_SELFTEST_SMP) + +int bt_selftest_smp(void); + +#else + +static inline int bt_selftest_smp(void) +{ + return 0; +} + +#endif + +#endif /* __SMP_H */ diff --git a/net/bpf/Makefile b/net/bpf/Makefile new file mode 100644 index 000000000..1ebe270bd --- /dev/null +++ b/net/bpf/Makefile @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_BPF_SYSCALL) := test_run.o +ifeq ($(CONFIG_BPF_JIT),y) +obj-$(CONFIG_BPF_SYSCALL) += bpf_dummy_struct_ops.o +endif diff --git a/net/bpf/bpf_dummy_struct_ops.c b/net/bpf/bpf_dummy_struct_ops.c new file mode 100644 index 000000000..e78dadfc5 --- /dev/null +++ b/net/bpf/bpf_dummy_struct_ops.c @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (C) 2021. 
Huawei Technologies Co., Ltd + */ +#include +#include +#include +#include + +extern struct bpf_struct_ops bpf_bpf_dummy_ops; + +/* A common type for test_N with return value in bpf_dummy_ops */ +typedef int (*dummy_ops_test_ret_fn)(struct bpf_dummy_ops_state *state, ...); + +struct bpf_dummy_ops_test_args { + u64 args[MAX_BPF_FUNC_ARGS]; + struct bpf_dummy_ops_state state; +}; + +static struct bpf_dummy_ops_test_args * +dummy_ops_init_args(const union bpf_attr *kattr, unsigned int nr) +{ + __u32 size_in; + struct bpf_dummy_ops_test_args *args; + void __user *ctx_in; + void __user *u_state; + + size_in = kattr->test.ctx_size_in; + if (size_in != sizeof(u64) * nr) + return ERR_PTR(-EINVAL); + + args = kzalloc(sizeof(*args), GFP_KERNEL); + if (!args) + return ERR_PTR(-ENOMEM); + + ctx_in = u64_to_user_ptr(kattr->test.ctx_in); + if (copy_from_user(args->args, ctx_in, size_in)) + goto out; + + /* args[0] is 0 means state argument of test_N will be NULL */ + u_state = u64_to_user_ptr(args->args[0]); + if (u_state && copy_from_user(&args->state, u_state, + sizeof(args->state))) + goto out; + + return args; +out: + kfree(args); + return ERR_PTR(-EFAULT); +} + +static int dummy_ops_copy_args(struct bpf_dummy_ops_test_args *args) +{ + void __user *u_state; + + u_state = u64_to_user_ptr(args->args[0]); + if (u_state && copy_to_user(u_state, &args->state, sizeof(args->state))) + return -EFAULT; + + return 0; +} + +static int dummy_ops_call_op(void *image, struct bpf_dummy_ops_test_args *args) +{ + dummy_ops_test_ret_fn test = (void *)image; + struct bpf_dummy_ops_state *state = NULL; + + /* state needs to be NULL if args[0] is 0 */ + if (args->args[0]) + state = &args->state; + return test(state, args->args[1], args->args[2], + args->args[3], args->args[4]); +} + +extern const struct bpf_link_ops bpf_struct_ops_link_lops; + +int bpf_struct_ops_test_run(struct bpf_prog *prog, const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + const struct bpf_struct_ops *st_ops = &bpf_bpf_dummy_ops; + const struct btf_type *func_proto; + struct bpf_dummy_ops_test_args *args; + struct bpf_tramp_links *tlinks; + struct bpf_tramp_link *link = NULL; + void *image = NULL; + unsigned int op_idx; + int prog_ret; + int err; + + if (prog->aux->attach_btf_id != st_ops->type_id) + return -EOPNOTSUPP; + + func_proto = prog->aux->attach_func_proto; + args = dummy_ops_init_args(kattr, btf_type_vlen(func_proto)); + if (IS_ERR(args)) + return PTR_ERR(args); + + tlinks = kcalloc(BPF_TRAMP_MAX, sizeof(*tlinks), GFP_KERNEL); + if (!tlinks) { + err = -ENOMEM; + goto out; + } + + image = bpf_jit_alloc_exec(PAGE_SIZE); + if (!image) { + err = -ENOMEM; + goto out; + } + set_vm_flush_reset_perms(image); + + link = kzalloc(sizeof(*link), GFP_USER); + if (!link) { + err = -ENOMEM; + goto out; + } + /* prog doesn't take the ownership of the reference from caller */ + bpf_prog_inc(prog); + bpf_link_init(&link->link, BPF_LINK_TYPE_STRUCT_OPS, &bpf_struct_ops_link_lops, prog); + + op_idx = prog->expected_attach_type; + err = bpf_struct_ops_prepare_trampoline(tlinks, link, + &st_ops->func_models[op_idx], + image, image + PAGE_SIZE); + if (err < 0) + goto out; + + set_memory_ro((long)image, 1); + set_memory_x((long)image, 1); + prog_ret = dummy_ops_call_op(image, args); + + err = dummy_ops_copy_args(args); + if (err) + goto out; + if (put_user(prog_ret, &uattr->test.retval)) + err = -EFAULT; +out: + kfree(args); + bpf_jit_free_exec(image); + if (link) + bpf_link_put(&link->link); + kfree(tlinks); + return err; +} + +static int 
bpf_dummy_init(struct btf *btf) +{ + return 0; +} + +static bool bpf_dummy_ops_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + return bpf_tracing_btf_ctx_access(off, size, type, prog, info); +} + +static int bpf_dummy_ops_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id, + enum bpf_type_flag *flag) +{ + const struct btf_type *state; + s32 type_id; + int err; + + type_id = btf_find_by_name_kind(btf, "bpf_dummy_ops_state", + BTF_KIND_STRUCT); + if (type_id < 0) + return -EINVAL; + + state = btf_type_by_id(btf, type_id); + if (t != state) { + bpf_log(log, "only access to bpf_dummy_ops_state is supported\n"); + return -EACCES; + } + + err = btf_struct_access(log, btf, t, off, size, atype, next_btf_id, + flag); + if (err < 0) + return err; + + return atype == BPF_READ ? err : NOT_INIT; +} + +static const struct bpf_verifier_ops bpf_dummy_verifier_ops = { + .is_valid_access = bpf_dummy_ops_is_valid_access, + .btf_struct_access = bpf_dummy_ops_btf_struct_access, +}; + +static int bpf_dummy_init_member(const struct btf_type *t, + const struct btf_member *member, + void *kdata, const void *udata) +{ + return -EOPNOTSUPP; +} + +static int bpf_dummy_reg(void *kdata) +{ + return -EOPNOTSUPP; +} + +static void bpf_dummy_unreg(void *kdata) +{ +} + +struct bpf_struct_ops bpf_bpf_dummy_ops = { + .verifier_ops = &bpf_dummy_verifier_ops, + .init = bpf_dummy_init, + .init_member = bpf_dummy_init_member, + .reg = bpf_dummy_reg, + .unreg = bpf_dummy_unreg, + .name = "bpf_dummy_ops", +}; diff --git a/net/bpf/test_run.c b/net/bpf/test_run.c new file mode 100644 index 000000000..6094ef7cf --- /dev/null +++ b/net/bpf/test_run.c @@ -0,0 +1,1676 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (c) 2017 Facebook + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CREATE_TRACE_POINTS +#include + +struct bpf_test_timer { + enum { NO_PREEMPT, NO_MIGRATE } mode; + u32 i; + u64 time_start, time_spent; +}; + +static void bpf_test_timer_enter(struct bpf_test_timer *t) + __acquires(rcu) +{ + rcu_read_lock(); + if (t->mode == NO_PREEMPT) + preempt_disable(); + else + migrate_disable(); + + t->time_start = ktime_get_ns(); +} + +static void bpf_test_timer_leave(struct bpf_test_timer *t) + __releases(rcu) +{ + t->time_start = 0; + + if (t->mode == NO_PREEMPT) + preempt_enable(); + else + migrate_enable(); + rcu_read_unlock(); +} + +static bool bpf_test_timer_continue(struct bpf_test_timer *t, int iterations, + u32 repeat, int *err, u32 *duration) + __must_hold(rcu) +{ + t->i += iterations; + if (t->i >= repeat) { + /* We're done. */ + t->time_spent += ktime_get_ns() - t->time_start; + do_div(t->time_spent, t->i); + *duration = t->time_spent > U32_MAX ? U32_MAX : (u32)t->time_spent; + *err = 0; + goto reset; + } + + if (signal_pending(current)) { + /* During iteration: we've been cancelled, abort. */ + *err = -EINTR; + goto reset; + } + + if (need_resched()) { + /* During iteration: we need to reschedule between runs. */ + t->time_spent += ktime_get_ns() - t->time_start; + bpf_test_timer_leave(t); + cond_resched(); + bpf_test_timer_enter(t); + } + + /* Do another round. 
*/ + return true; + +reset: + t->i = 0; + return false; +} + +/* We put this struct at the head of each page with a context and frame + * initialised when the page is allocated, so we don't have to do this on each + * repetition of the test run. + */ +struct xdp_page_head { + struct xdp_buff orig_ctx; + struct xdp_buff ctx; + struct xdp_frame frm; + u8 data[]; +}; + +struct xdp_test_data { + struct xdp_buff *orig_ctx; + struct xdp_rxq_info rxq; + struct net_device *dev; + struct page_pool *pp; + struct xdp_frame **frames; + struct sk_buff **skbs; + struct xdp_mem_info mem; + u32 batch_size; + u32 frame_cnt; +}; + +#define TEST_XDP_FRAME_SIZE (PAGE_SIZE - sizeof(struct xdp_page_head)) +#define TEST_XDP_MAX_BATCH 256 + +static void xdp_test_run_init_page(struct page *page, void *arg) +{ + struct xdp_page_head *head = phys_to_virt(page_to_phys(page)); + struct xdp_buff *new_ctx, *orig_ctx; + u32 headroom = XDP_PACKET_HEADROOM; + struct xdp_test_data *xdp = arg; + size_t frm_len, meta_len; + struct xdp_frame *frm; + void *data; + + orig_ctx = xdp->orig_ctx; + frm_len = orig_ctx->data_end - orig_ctx->data_meta; + meta_len = orig_ctx->data - orig_ctx->data_meta; + headroom -= meta_len; + + new_ctx = &head->ctx; + frm = &head->frm; + data = &head->data; + memcpy(data + headroom, orig_ctx->data_meta, frm_len); + + xdp_init_buff(new_ctx, TEST_XDP_FRAME_SIZE, &xdp->rxq); + xdp_prepare_buff(new_ctx, data, headroom, frm_len, true); + new_ctx->data = new_ctx->data_meta + meta_len; + + xdp_update_frame_from_buff(new_ctx, frm); + frm->mem = new_ctx->rxq->mem; + + memcpy(&head->orig_ctx, new_ctx, sizeof(head->orig_ctx)); +} + +static int xdp_test_run_setup(struct xdp_test_data *xdp, struct xdp_buff *orig_ctx) +{ + struct page_pool *pp; + int err = -ENOMEM; + struct page_pool_params pp_params = { + .order = 0, + .flags = 0, + .pool_size = xdp->batch_size, + .nid = NUMA_NO_NODE, + .init_callback = xdp_test_run_init_page, + .init_arg = xdp, + }; + + xdp->frames = kvmalloc_array(xdp->batch_size, sizeof(void *), GFP_KERNEL); + if (!xdp->frames) + return -ENOMEM; + + xdp->skbs = kvmalloc_array(xdp->batch_size, sizeof(void *), GFP_KERNEL); + if (!xdp->skbs) + goto err_skbs; + + pp = page_pool_create(&pp_params); + if (IS_ERR(pp)) { + err = PTR_ERR(pp); + goto err_pp; + } + + /* will copy 'mem.id' into pp->xdp_mem_id */ + err = xdp_reg_mem_model(&xdp->mem, MEM_TYPE_PAGE_POOL, pp); + if (err) + goto err_mmodel; + + xdp->pp = pp; + + /* We create a 'fake' RXQ referencing the original dev, but with an + * xdp_mem_info pointing to our page_pool + */ + xdp_rxq_info_reg(&xdp->rxq, orig_ctx->rxq->dev, 0, 0); + xdp->rxq.mem.type = MEM_TYPE_PAGE_POOL; + xdp->rxq.mem.id = pp->xdp_mem_id; + xdp->dev = orig_ctx->rxq->dev; + xdp->orig_ctx = orig_ctx; + + return 0; + +err_mmodel: + page_pool_destroy(pp); +err_pp: + kvfree(xdp->skbs); +err_skbs: + kvfree(xdp->frames); + return err; +} + +static void xdp_test_run_teardown(struct xdp_test_data *xdp) +{ + xdp_unreg_mem_model(&xdp->mem); + page_pool_destroy(xdp->pp); + kfree(xdp->frames); + kfree(xdp->skbs); +} + +static bool ctx_was_changed(struct xdp_page_head *head) +{ + return head->orig_ctx.data != head->ctx.data || + head->orig_ctx.data_meta != head->ctx.data_meta || + head->orig_ctx.data_end != head->ctx.data_end; +} + +static void reset_ctx(struct xdp_page_head *head) +{ + if (likely(!ctx_was_changed(head))) + return; + + head->ctx.data = head->orig_ctx.data; + head->ctx.data_meta = head->orig_ctx.data_meta; + head->ctx.data_end = head->orig_ctx.data_end; + 
xdp_update_frame_from_buff(&head->ctx, &head->frm); +} + +static int xdp_recv_frames(struct xdp_frame **frames, int nframes, + struct sk_buff **skbs, + struct net_device *dev) +{ + gfp_t gfp = __GFP_ZERO | GFP_ATOMIC; + int i, n; + LIST_HEAD(list); + + n = kmem_cache_alloc_bulk(skbuff_head_cache, gfp, nframes, (void **)skbs); + if (unlikely(n == 0)) { + for (i = 0; i < nframes; i++) + xdp_return_frame(frames[i]); + return -ENOMEM; + } + + for (i = 0; i < nframes; i++) { + struct xdp_frame *xdpf = frames[i]; + struct sk_buff *skb = skbs[i]; + + skb = __xdp_build_skb_from_frame(xdpf, skb, dev); + if (!skb) { + xdp_return_frame(xdpf); + continue; + } + + list_add_tail(&skb->list, &list); + } + netif_receive_skb_list(&list); + + return 0; +} + +static int xdp_test_run_batch(struct xdp_test_data *xdp, struct bpf_prog *prog, + u32 repeat) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + int err = 0, act, ret, i, nframes = 0, batch_sz; + struct xdp_frame **frames = xdp->frames; + struct xdp_page_head *head; + struct xdp_frame *frm; + bool redirect = false; + struct xdp_buff *ctx; + struct page *page; + + batch_sz = min_t(u32, repeat, xdp->batch_size); + + local_bh_disable(); + xdp_set_return_frame_no_direct(); + + for (i = 0; i < batch_sz; i++) { + page = page_pool_dev_alloc_pages(xdp->pp); + if (!page) { + err = -ENOMEM; + goto out; + } + + head = phys_to_virt(page_to_phys(page)); + reset_ctx(head); + ctx = &head->ctx; + frm = &head->frm; + xdp->frame_cnt++; + + act = bpf_prog_run_xdp(prog, ctx); + + /* if program changed pkt bounds we need to update the xdp_frame */ + if (unlikely(ctx_was_changed(head))) { + ret = xdp_update_frame_from_buff(ctx, frm); + if (ret) { + xdp_return_buff(ctx); + continue; + } + } + + switch (act) { + case XDP_TX: + /* we can't do a real XDP_TX since we're not in the + * driver, so turn it into a REDIRECT back to the same + * index + */ + ri->tgt_index = xdp->dev->ifindex; + ri->map_id = INT_MAX; + ri->map_type = BPF_MAP_TYPE_UNSPEC; + fallthrough; + case XDP_REDIRECT: + redirect = true; + ret = xdp_do_redirect_frame(xdp->dev, ctx, frm, prog); + if (ret) + xdp_return_buff(ctx); + break; + case XDP_PASS: + frames[nframes++] = frm; + break; + default: + bpf_warn_invalid_xdp_action(NULL, prog, act); + fallthrough; + case XDP_DROP: + xdp_return_buff(ctx); + break; + } + } + +out: + if (redirect) + xdp_do_flush(); + if (nframes) { + ret = xdp_recv_frames(frames, nframes, xdp->skbs, xdp->dev); + if (ret) + err = ret; + } + + xdp_clear_return_frame_no_direct(); + local_bh_enable(); + return err; +} + +static int bpf_test_run_xdp_live(struct bpf_prog *prog, struct xdp_buff *ctx, + u32 repeat, u32 batch_size, u32 *time) + +{ + struct xdp_test_data xdp = { .batch_size = batch_size }; + struct bpf_test_timer t = { .mode = NO_MIGRATE }; + int ret; + + if (!repeat) + repeat = 1; + + ret = xdp_test_run_setup(&xdp, ctx); + if (ret) + return ret; + + bpf_test_timer_enter(&t); + do { + xdp.frame_cnt = 0; + ret = xdp_test_run_batch(&xdp, prog, repeat - t.i); + if (unlikely(ret < 0)) + break; + } while (bpf_test_timer_continue(&t, xdp.frame_cnt, repeat, &ret, time)); + bpf_test_timer_leave(&t); + + xdp_test_run_teardown(&xdp); + return ret; +} + +static int bpf_test_run(struct bpf_prog *prog, void *ctx, u32 repeat, + u32 *retval, u32 *time, bool xdp) +{ + struct bpf_prog_array_item item = {.prog = prog}; + struct bpf_run_ctx *old_ctx; + struct bpf_cg_run_ctx run_ctx; + struct bpf_test_timer t = { NO_MIGRATE }; + enum bpf_cgroup_storage_type stype; + int ret; + + 
for_each_cgroup_storage_type(stype) { + item.cgroup_storage[stype] = bpf_cgroup_storage_alloc(prog, stype); + if (IS_ERR(item.cgroup_storage[stype])) { + item.cgroup_storage[stype] = NULL; + for_each_cgroup_storage_type(stype) + bpf_cgroup_storage_free(item.cgroup_storage[stype]); + return -ENOMEM; + } + } + + if (!repeat) + repeat = 1; + + bpf_test_timer_enter(&t); + old_ctx = bpf_set_run_ctx(&run_ctx.run_ctx); + do { + run_ctx.prog_item = &item; + if (xdp) + *retval = bpf_prog_run_xdp(prog, ctx); + else + *retval = bpf_prog_run(prog, ctx); + } while (bpf_test_timer_continue(&t, 1, repeat, &ret, time)); + bpf_reset_run_ctx(old_ctx); + bpf_test_timer_leave(&t); + + for_each_cgroup_storage_type(stype) + bpf_cgroup_storage_free(item.cgroup_storage[stype]); + + return ret; +} + +static int bpf_test_finish(const union bpf_attr *kattr, + union bpf_attr __user *uattr, const void *data, + struct skb_shared_info *sinfo, u32 size, + u32 retval, u32 duration) +{ + void __user *data_out = u64_to_user_ptr(kattr->test.data_out); + int err = -EFAULT; + u32 copy_size = size; + + /* Clamp copy if the user has provided a size hint, but copy the full + * buffer if not to retain old behaviour. + */ + if (kattr->test.data_size_out && + copy_size > kattr->test.data_size_out) { + copy_size = kattr->test.data_size_out; + err = -ENOSPC; + } + + if (data_out) { + int len = sinfo ? copy_size - sinfo->xdp_frags_size : copy_size; + + if (len < 0) { + err = -ENOSPC; + goto out; + } + + if (copy_to_user(data_out, data, len)) + goto out; + + if (sinfo) { + int i, offset = len; + u32 data_len; + + for (i = 0; i < sinfo->nr_frags; i++) { + skb_frag_t *frag = &sinfo->frags[i]; + + if (offset >= copy_size) { + err = -ENOSPC; + break; + } + + data_len = min_t(u32, copy_size - offset, + skb_frag_size(frag)); + + if (copy_to_user(data_out + offset, + skb_frag_address(frag), + data_len)) + goto out; + + offset += data_len; + } + } + } + + if (copy_to_user(&uattr->test.data_size_out, &size, sizeof(size))) + goto out; + if (copy_to_user(&uattr->test.retval, &retval, sizeof(retval))) + goto out; + if (copy_to_user(&uattr->test.duration, &duration, sizeof(duration))) + goto out; + if (err != -ENOSPC) + err = 0; +out: + trace_bpf_test_finish(&err); + return err; +} + +/* Integer types of various sizes and pointer combinations cover variety of + * architecture dependent calling conventions. 7+ can be supported in the + * future. 
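+ *
+ * These stubs are exercised via BPF_PROG_TEST_RUN (see
+ * bpf_prog_test_run_tracing() below).  A rough user-space sketch using
+ * libbpf, assuming an already-loaded object where "prog" is a fentry
+ * program attached to bpf_fentry_test1, could look like:
+ *
+ *     LIBBPF_OPTS(bpf_test_run_opts, opts);
+ *     int err = bpf_prog_test_run_opts(bpf_program__fd(prog), &opts);
+ *
+ * where err == 0 means the kernel invoked the bpf_fentry_test*()
+ * functions once with the fixed arguments used further below.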
+ */ +__diag_push(); +__diag_ignore_all("-Wmissing-prototypes", + "Global functions as their definitions will be in vmlinux BTF"); +int noinline bpf_fentry_test1(int a) +{ + return a + 1; +} +EXPORT_SYMBOL_GPL(bpf_fentry_test1); +ALLOW_ERROR_INJECTION(bpf_fentry_test1, ERRNO); + +int noinline bpf_fentry_test2(int a, u64 b) +{ + return a + b; +} + +int noinline bpf_fentry_test3(char a, int b, u64 c) +{ + return a + b + c; +} + +int noinline bpf_fentry_test4(void *a, char b, int c, u64 d) +{ + return (long)a + b + c + d; +} + +int noinline bpf_fentry_test5(u64 a, void *b, short c, int d, u64 e) +{ + return a + (long)b + c + d + e; +} + +int noinline bpf_fentry_test6(u64 a, void *b, short c, int d, void *e, u64 f) +{ + return a + (long)b + c + d + (long)e + f; +} + +struct bpf_fentry_test_t { + struct bpf_fentry_test_t *a; +}; + +int noinline bpf_fentry_test7(struct bpf_fentry_test_t *arg) +{ + return (long)arg; +} + +int noinline bpf_fentry_test8(struct bpf_fentry_test_t *arg) +{ + return (long)arg->a; +} + +int noinline bpf_modify_return_test(int a, int *b) +{ + *b += 1; + return a + *b; +} + +u64 noinline bpf_kfunc_call_test1(struct sock *sk, u32 a, u64 b, u32 c, u64 d) +{ + return a + b + c + d; +} + +int noinline bpf_kfunc_call_test2(struct sock *sk, u32 a, u32 b) +{ + return a + b; +} + +struct sock * noinline bpf_kfunc_call_test3(struct sock *sk) +{ + return sk; +} + +struct prog_test_member1 { + int a; +}; + +struct prog_test_member { + struct prog_test_member1 m; + int c; +}; + +struct prog_test_ref_kfunc { + int a; + int b; + struct prog_test_member memb; + struct prog_test_ref_kfunc *next; + refcount_t cnt; +}; + +static struct prog_test_ref_kfunc prog_test_struct = { + .a = 42, + .b = 108, + .next = &prog_test_struct, + .cnt = REFCOUNT_INIT(1), +}; + +noinline struct prog_test_ref_kfunc * +bpf_kfunc_call_test_acquire(unsigned long *scalar_ptr) +{ + refcount_inc(&prog_test_struct.cnt); + return &prog_test_struct; +} + +noinline struct prog_test_member * +bpf_kfunc_call_memb_acquire(void) +{ + WARN_ON_ONCE(1); + return NULL; +} + +noinline void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) +{ + if (!p) + return; + + refcount_dec(&p->cnt); +} + +noinline void bpf_kfunc_call_memb_release(struct prog_test_member *p) +{ +} + +noinline void bpf_kfunc_call_memb1_release(struct prog_test_member1 *p) +{ + WARN_ON_ONCE(1); +} + +static int *__bpf_kfunc_call_test_get_mem(struct prog_test_ref_kfunc *p, const int size) +{ + if (size > 2 * sizeof(int)) + return NULL; + + return (int *)p; +} + +noinline int *bpf_kfunc_call_test_get_rdwr_mem(struct prog_test_ref_kfunc *p, const int rdwr_buf_size) +{ + return __bpf_kfunc_call_test_get_mem(p, rdwr_buf_size); +} + +noinline int *bpf_kfunc_call_test_get_rdonly_mem(struct prog_test_ref_kfunc *p, const int rdonly_buf_size) +{ + return __bpf_kfunc_call_test_get_mem(p, rdonly_buf_size); +} + +/* the next 2 ones can't be really used for testing expect to ensure + * that the verifier rejects the call. + * Acquire functions must return struct pointers, so these ones are + * failing. 
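+ *
+ * A hypothetical BPF-side call such as:
+ *
+ *     int *m = bpf_kfunc_call_test_acq_rdonly_mem(p, 2 * sizeof(int));
+ *
+ * is therefore expected to fail verification, since an acquire kfunc
+ * has to return a pointer to a struct rather than to an int.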
+ */ +noinline int *bpf_kfunc_call_test_acq_rdonly_mem(struct prog_test_ref_kfunc *p, const int rdonly_buf_size) +{ + return __bpf_kfunc_call_test_get_mem(p, rdonly_buf_size); +} + +noinline void bpf_kfunc_call_int_mem_release(int *p) +{ +} + +noinline struct prog_test_ref_kfunc * +bpf_kfunc_call_test_kptr_get(struct prog_test_ref_kfunc **pp, int a, int b) +{ + struct prog_test_ref_kfunc *p = READ_ONCE(*pp); + + if (!p) + return NULL; + refcount_inc(&p->cnt); + return p; +} + +struct prog_test_pass1 { + int x0; + struct { + int x1; + struct { + int x2; + struct { + int x3; + }; + }; + }; +}; + +struct prog_test_pass2 { + int len; + short arr1[4]; + struct { + char arr2[4]; + unsigned long arr3[8]; + } x; +}; + +struct prog_test_fail1 { + void *p; + int x; +}; + +struct prog_test_fail2 { + int x8; + struct prog_test_pass1 x; +}; + +struct prog_test_fail3 { + int len; + char arr1[2]; + char arr2[]; +}; + +noinline void bpf_kfunc_call_test_pass_ctx(struct __sk_buff *skb) +{ +} + +noinline void bpf_kfunc_call_test_pass1(struct prog_test_pass1 *p) +{ +} + +noinline void bpf_kfunc_call_test_pass2(struct prog_test_pass2 *p) +{ +} + +noinline void bpf_kfunc_call_test_fail1(struct prog_test_fail1 *p) +{ +} + +noinline void bpf_kfunc_call_test_fail2(struct prog_test_fail2 *p) +{ +} + +noinline void bpf_kfunc_call_test_fail3(struct prog_test_fail3 *p) +{ +} + +noinline void bpf_kfunc_call_test_mem_len_pass1(void *mem, int mem__sz) +{ +} + +noinline void bpf_kfunc_call_test_mem_len_fail1(void *mem, int len) +{ +} + +noinline void bpf_kfunc_call_test_mem_len_fail2(u64 *mem, int len) +{ +} + +noinline void bpf_kfunc_call_test_ref(struct prog_test_ref_kfunc *p) +{ +} + +noinline void bpf_kfunc_call_test_destructive(void) +{ +} + +__diag_pop(); + +ALLOW_ERROR_INJECTION(bpf_modify_return_test, ERRNO); + +BTF_SET8_START(test_sk_check_kfunc_ids) +BTF_ID_FLAGS(func, bpf_kfunc_call_test1) +BTF_ID_FLAGS(func, bpf_kfunc_call_test2) +BTF_ID_FLAGS(func, bpf_kfunc_call_test3) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_acquire, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_kfunc_call_memb_acquire, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_release, KF_RELEASE) +BTF_ID_FLAGS(func, bpf_kfunc_call_memb_release, KF_RELEASE) +BTF_ID_FLAGS(func, bpf_kfunc_call_memb1_release, KF_RELEASE) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_get_rdwr_mem, KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_get_rdonly_mem, KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_acq_rdonly_mem, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_kfunc_call_int_mem_release, KF_RELEASE) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_kptr_get, KF_ACQUIRE | KF_RET_NULL | KF_KPTR_GET) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_pass_ctx) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_pass1) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_pass2) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_fail1) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_fail2) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_fail3) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_mem_len_pass1) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_mem_len_fail1) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_mem_len_fail2) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_ref, KF_TRUSTED_ARGS) +BTF_ID_FLAGS(func, bpf_kfunc_call_test_destructive, KF_DESTRUCTIVE) +BTF_SET8_END(test_sk_check_kfunc_ids) + +static void *bpf_test_init(const union bpf_attr *kattr, u32 user_size, + u32 size, u32 headroom, u32 tailroom) +{ + void __user *data_in = u64_to_user_ptr(kattr->test.data_in); + void *data; + + if (size < ETH_HLEN || size > PAGE_SIZE - 
headroom - tailroom) + return ERR_PTR(-EINVAL); + + if (user_size > size) + return ERR_PTR(-EMSGSIZE); + + size = SKB_DATA_ALIGN(size); + data = kzalloc(size + headroom + tailroom, GFP_USER); + if (!data) + return ERR_PTR(-ENOMEM); + + if (copy_from_user(data + headroom, data_in, user_size)) { + kfree(data); + return ERR_PTR(-EFAULT); + } + + return data; +} + +int bpf_prog_test_run_tracing(struct bpf_prog *prog, + const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + struct bpf_fentry_test_t arg = {}; + u16 side_effect = 0, ret = 0; + int b = 2, err = -EFAULT; + u32 retval = 0; + + if (kattr->test.flags || kattr->test.cpu || kattr->test.batch_size) + return -EINVAL; + + switch (prog->expected_attach_type) { + case BPF_TRACE_FENTRY: + case BPF_TRACE_FEXIT: + if (bpf_fentry_test1(1) != 2 || + bpf_fentry_test2(2, 3) != 5 || + bpf_fentry_test3(4, 5, 6) != 15 || + bpf_fentry_test4((void *)7, 8, 9, 10) != 34 || + bpf_fentry_test5(11, (void *)12, 13, 14, 15) != 65 || + bpf_fentry_test6(16, (void *)17, 18, 19, (void *)20, 21) != 111 || + bpf_fentry_test7((struct bpf_fentry_test_t *)0) != 0 || + bpf_fentry_test8(&arg) != 0) + goto out; + break; + case BPF_MODIFY_RETURN: + ret = bpf_modify_return_test(1, &b); + if (b != 2) + side_effect = 1; + break; + default: + goto out; + } + + retval = ((u32)side_effect << 16) | ret; + if (copy_to_user(&uattr->test.retval, &retval, sizeof(retval))) + goto out; + + err = 0; +out: + trace_bpf_test_finish(&err); + return err; +} + +struct bpf_raw_tp_test_run_info { + struct bpf_prog *prog; + void *ctx; + u32 retval; +}; + +static void +__bpf_prog_test_run_raw_tp(void *data) +{ + struct bpf_raw_tp_test_run_info *info = data; + + rcu_read_lock(); + info->retval = bpf_prog_run(info->prog, info->ctx); + rcu_read_unlock(); +} + +int bpf_prog_test_run_raw_tp(struct bpf_prog *prog, + const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + void __user *ctx_in = u64_to_user_ptr(kattr->test.ctx_in); + __u32 ctx_size_in = kattr->test.ctx_size_in; + struct bpf_raw_tp_test_run_info info; + int cpu = kattr->test.cpu, err = 0; + int current_cpu; + + /* doesn't support data_in/out, ctx_out, duration, or repeat */ + if (kattr->test.data_in || kattr->test.data_out || + kattr->test.ctx_out || kattr->test.duration || + kattr->test.repeat || kattr->test.batch_size) + return -EINVAL; + + if (ctx_size_in < prog->aux->max_ctx_offset || + ctx_size_in > MAX_BPF_FUNC_ARGS * sizeof(u64)) + return -EINVAL; + + if ((kattr->test.flags & BPF_F_TEST_RUN_ON_CPU) == 0 && cpu != 0) + return -EINVAL; + + if (ctx_size_in) { + info.ctx = memdup_user(ctx_in, ctx_size_in); + if (IS_ERR(info.ctx)) + return PTR_ERR(info.ctx); + } else { + info.ctx = NULL; + } + + info.prog = prog; + + current_cpu = get_cpu(); + if ((kattr->test.flags & BPF_F_TEST_RUN_ON_CPU) == 0 || + cpu == current_cpu) { + __bpf_prog_test_run_raw_tp(&info); + } else if (cpu >= nr_cpu_ids || !cpu_online(cpu)) { + /* smp_call_function_single() also checks cpu_online() + * after csd_lock(). However, since cpu is from user + * space, let's do an extra quick check to filter out + * invalid value before smp_call_function_single(). 
+ */ + err = -ENXIO; + } else { + err = smp_call_function_single(cpu, __bpf_prog_test_run_raw_tp, + &info, 1); + } + put_cpu(); + + if (!err && + copy_to_user(&uattr->test.retval, &info.retval, sizeof(u32))) + err = -EFAULT; + + kfree(info.ctx); + return err; +} + +static void *bpf_ctx_init(const union bpf_attr *kattr, u32 max_size) +{ + void __user *data_in = u64_to_user_ptr(kattr->test.ctx_in); + void __user *data_out = u64_to_user_ptr(kattr->test.ctx_out); + u32 size = kattr->test.ctx_size_in; + void *data; + int err; + + if (!data_in && !data_out) + return NULL; + + data = kzalloc(max_size, GFP_USER); + if (!data) + return ERR_PTR(-ENOMEM); + + if (data_in) { + err = bpf_check_uarg_tail_zero(USER_BPFPTR(data_in), max_size, size); + if (err) { + kfree(data); + return ERR_PTR(err); + } + + size = min_t(u32, max_size, size); + if (copy_from_user(data, data_in, size)) { + kfree(data); + return ERR_PTR(-EFAULT); + } + } + return data; +} + +static int bpf_ctx_finish(const union bpf_attr *kattr, + union bpf_attr __user *uattr, const void *data, + u32 size) +{ + void __user *data_out = u64_to_user_ptr(kattr->test.ctx_out); + int err = -EFAULT; + u32 copy_size = size; + + if (!data || !data_out) + return 0; + + if (copy_size > kattr->test.ctx_size_out) { + copy_size = kattr->test.ctx_size_out; + err = -ENOSPC; + } + + if (copy_to_user(data_out, data, copy_size)) + goto out; + if (copy_to_user(&uattr->test.ctx_size_out, &size, sizeof(size))) + goto out; + if (err != -ENOSPC) + err = 0; +out: + return err; +} + +/** + * range_is_zero - test whether buffer is initialized + * @buf: buffer to check + * @from: check from this position + * @to: check up until (excluding) this position + * + * This function returns true if the there is a non-zero byte + * in the buf in the range [from,to). 
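+ *
+ * (More precisely: the return value is true only when the range
+ * [from, to) contains no non-zero byte, i.e. it is entirely zeroed;
+ * e.g. range_is_zero(__skb, 0, offsetof(struct __sk_buff, mark))
+ * holds when every byte before the mark field is zero.)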
+ */ +static inline bool range_is_zero(void *buf, size_t from, size_t to) +{ + return !memchr_inv((u8 *)buf + from, 0, to - from); +} + +static int convert___skb_to_skb(struct sk_buff *skb, struct __sk_buff *__skb) +{ + struct qdisc_skb_cb *cb = (struct qdisc_skb_cb *)skb->cb; + + if (!__skb) + return 0; + + /* make sure the fields we don't use are zeroed */ + if (!range_is_zero(__skb, 0, offsetof(struct __sk_buff, mark))) + return -EINVAL; + + /* mark is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, mark), + offsetof(struct __sk_buff, priority))) + return -EINVAL; + + /* priority is allowed */ + /* ingress_ifindex is allowed */ + /* ifindex is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, ifindex), + offsetof(struct __sk_buff, cb))) + return -EINVAL; + + /* cb is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, cb), + offsetof(struct __sk_buff, tstamp))) + return -EINVAL; + + /* tstamp is allowed */ + /* wire_len is allowed */ + /* gso_segs is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, gso_segs), + offsetof(struct __sk_buff, gso_size))) + return -EINVAL; + + /* gso_size is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, gso_size), + offsetof(struct __sk_buff, hwtstamp))) + return -EINVAL; + + /* hwtstamp is allowed */ + + if (!range_is_zero(__skb, offsetofend(struct __sk_buff, hwtstamp), + sizeof(struct __sk_buff))) + return -EINVAL; + + skb->mark = __skb->mark; + skb->priority = __skb->priority; + skb->skb_iif = __skb->ingress_ifindex; + skb->tstamp = __skb->tstamp; + memcpy(&cb->data, __skb->cb, QDISC_CB_PRIV_LEN); + + if (__skb->wire_len == 0) { + cb->pkt_len = skb->len; + } else { + if (__skb->wire_len < skb->len || + __skb->wire_len > GSO_LEGACY_MAX_SIZE) + return -EINVAL; + cb->pkt_len = __skb->wire_len; + } + + if (__skb->gso_segs > GSO_MAX_SEGS) + return -EINVAL; + skb_shinfo(skb)->gso_segs = __skb->gso_segs; + skb_shinfo(skb)->gso_size = __skb->gso_size; + skb_shinfo(skb)->hwtstamps.hwtstamp = __skb->hwtstamp; + + return 0; +} + +static void convert_skb_to___skb(struct sk_buff *skb, struct __sk_buff *__skb) +{ + struct qdisc_skb_cb *cb = (struct qdisc_skb_cb *)skb->cb; + + if (!__skb) + return; + + __skb->mark = skb->mark; + __skb->priority = skb->priority; + __skb->ingress_ifindex = skb->skb_iif; + __skb->ifindex = skb->dev->ifindex; + __skb->tstamp = skb->tstamp; + memcpy(__skb->cb, &cb->data, QDISC_CB_PRIV_LEN); + __skb->wire_len = cb->pkt_len; + __skb->gso_segs = skb_shinfo(skb)->gso_segs; + __skb->hwtstamp = skb_shinfo(skb)->hwtstamps.hwtstamp; +} + +static struct proto bpf_dummy_proto = { + .name = "bpf_dummy", + .owner = THIS_MODULE, + .obj_size = sizeof(struct sock), +}; + +int bpf_prog_test_run_skb(struct bpf_prog *prog, const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + bool is_l2 = false, is_direct_pkt_access = false; + struct net *net = current->nsproxy->net_ns; + struct net_device *dev = net->loopback_dev; + u32 size = kattr->test.data_size_in; + u32 repeat = kattr->test.repeat; + struct __sk_buff *ctx = NULL; + u32 retval, duration; + int hh_len = ETH_HLEN; + struct sk_buff *skb; + struct sock *sk; + void *data; + int ret; + + if (kattr->test.flags || kattr->test.cpu || kattr->test.batch_size) + return -EINVAL; + + data = bpf_test_init(kattr, kattr->test.data_size_in, + size, NET_SKB_PAD + NET_IP_ALIGN, + SKB_DATA_ALIGN(sizeof(struct skb_shared_info))); + if (IS_ERR(data)) + return PTR_ERR(data); + + ctx = bpf_ctx_init(kattr, 
sizeof(struct __sk_buff)); + if (IS_ERR(ctx)) { + kfree(data); + return PTR_ERR(ctx); + } + + switch (prog->type) { + case BPF_PROG_TYPE_SCHED_CLS: + case BPF_PROG_TYPE_SCHED_ACT: + is_l2 = true; + fallthrough; + case BPF_PROG_TYPE_LWT_IN: + case BPF_PROG_TYPE_LWT_OUT: + case BPF_PROG_TYPE_LWT_XMIT: + is_direct_pkt_access = true; + break; + default: + break; + } + + sk = sk_alloc(net, AF_UNSPEC, GFP_USER, &bpf_dummy_proto, 1); + if (!sk) { + kfree(data); + kfree(ctx); + return -ENOMEM; + } + sock_init_data(NULL, sk); + + skb = build_skb(data, 0); + if (!skb) { + kfree(data); + kfree(ctx); + sk_free(sk); + return -ENOMEM; + } + skb->sk = sk; + + skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN); + __skb_put(skb, size); + if (ctx && ctx->ifindex > 1) { + dev = dev_get_by_index(net, ctx->ifindex); + if (!dev) { + ret = -ENODEV; + goto out; + } + } + skb->protocol = eth_type_trans(skb, dev); + skb_reset_network_header(skb); + + switch (skb->protocol) { + case htons(ETH_P_IP): + sk->sk_family = AF_INET; + if (sizeof(struct iphdr) <= skb_headlen(skb)) { + sk->sk_rcv_saddr = ip_hdr(skb)->saddr; + sk->sk_daddr = ip_hdr(skb)->daddr; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + sk->sk_family = AF_INET6; + if (sizeof(struct ipv6hdr) <= skb_headlen(skb)) { + sk->sk_v6_rcv_saddr = ipv6_hdr(skb)->saddr; + sk->sk_v6_daddr = ipv6_hdr(skb)->daddr; + } + break; +#endif + default: + break; + } + + if (is_l2) + __skb_push(skb, hh_len); + if (is_direct_pkt_access) + bpf_compute_data_pointers(skb); + ret = convert___skb_to_skb(skb, ctx); + if (ret) + goto out; + ret = bpf_test_run(prog, skb, repeat, &retval, &duration, false); + if (ret) + goto out; + if (!is_l2) { + if (skb_headroom(skb) < hh_len) { + int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb)); + + if (pskb_expand_head(skb, nhead, 0, GFP_USER)) { + ret = -ENOMEM; + goto out; + } + } + memset(__skb_push(skb, hh_len), 0, hh_len); + } + convert_skb_to___skb(skb, ctx); + + size = skb->len; + /* bpf program can never convert linear skb to non-linear */ + if (WARN_ON_ONCE(skb_is_nonlinear(skb))) + size = skb_headlen(skb); + ret = bpf_test_finish(kattr, uattr, skb->data, NULL, size, retval, + duration); + if (!ret) + ret = bpf_ctx_finish(kattr, uattr, ctx, + sizeof(struct __sk_buff)); +out: + if (dev && dev != net->loopback_dev) + dev_put(dev); + kfree_skb(skb); + sk_free(sk); + kfree(ctx); + return ret; +} + +static int xdp_convert_md_to_buff(struct xdp_md *xdp_md, struct xdp_buff *xdp) +{ + unsigned int ingress_ifindex, rx_queue_index; + struct netdev_rx_queue *rxqueue; + struct net_device *device; + + if (!xdp_md) + return 0; + + if (xdp_md->egress_ifindex != 0) + return -EINVAL; + + ingress_ifindex = xdp_md->ingress_ifindex; + rx_queue_index = xdp_md->rx_queue_index; + + if (!ingress_ifindex && rx_queue_index) + return -EINVAL; + + if (ingress_ifindex) { + device = dev_get_by_index(current->nsproxy->net_ns, + ingress_ifindex); + if (!device) + return -ENODEV; + + if (rx_queue_index >= device->real_num_rx_queues) + goto free_dev; + + rxqueue = __netif_get_rx_queue(device, rx_queue_index); + + if (!xdp_rxq_info_is_reg(&rxqueue->xdp_rxq)) + goto free_dev; + + xdp->rxq = &rxqueue->xdp_rxq; + /* The device is now tracked in the xdp->rxq for later + * dev_put() + */ + } + + xdp->data = xdp->data_meta + xdp_md->data; + return 0; + +free_dev: + dev_put(device); + return -EINVAL; +} + +static void xdp_convert_buff_to_md(struct xdp_buff *xdp, struct xdp_md *xdp_md) +{ + if (!xdp_md) + return; + + xdp_md->data = xdp->data - xdp->data_meta; + 
xdp_md->data_end = xdp->data_end - xdp->data_meta; + + if (xdp_md->ingress_ifindex) + dev_put(xdp->rxq->dev); +} + +int bpf_prog_test_run_xdp(struct bpf_prog *prog, const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + bool do_live = (kattr->test.flags & BPF_F_TEST_XDP_LIVE_FRAMES); + u32 tailroom = SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + u32 batch_size = kattr->test.batch_size; + u32 retval = 0, duration, max_data_sz; + u32 size = kattr->test.data_size_in; + u32 headroom = XDP_PACKET_HEADROOM; + u32 repeat = kattr->test.repeat; + struct netdev_rx_queue *rxqueue; + struct skb_shared_info *sinfo; + struct xdp_buff xdp = {}; + int i, ret = -EINVAL; + struct xdp_md *ctx; + void *data; + + if (prog->expected_attach_type == BPF_XDP_DEVMAP || + prog->expected_attach_type == BPF_XDP_CPUMAP) + return -EINVAL; + + if (kattr->test.flags & ~BPF_F_TEST_XDP_LIVE_FRAMES) + return -EINVAL; + + if (do_live) { + if (!batch_size) + batch_size = NAPI_POLL_WEIGHT; + else if (batch_size > TEST_XDP_MAX_BATCH) + return -E2BIG; + + headroom += sizeof(struct xdp_page_head); + } else if (batch_size) { + return -EINVAL; + } + + ctx = bpf_ctx_init(kattr, sizeof(struct xdp_md)); + if (IS_ERR(ctx)) + return PTR_ERR(ctx); + + if (ctx) { + /* There can't be user provided data before the meta data */ + if (ctx->data_meta || ctx->data_end != size || + ctx->data > ctx->data_end || + unlikely(xdp_metalen_invalid(ctx->data)) || + (do_live && (kattr->test.data_out || kattr->test.ctx_out))) + goto free_ctx; + /* Meta data is allocated from the headroom */ + headroom -= ctx->data; + } + + max_data_sz = 4096 - headroom - tailroom; + if (size > max_data_sz) { + /* disallow live data mode for jumbo frames */ + if (do_live) + goto free_ctx; + size = max_data_sz; + } + + data = bpf_test_init(kattr, size, max_data_sz, headroom, tailroom); + if (IS_ERR(data)) { + ret = PTR_ERR(data); + goto free_ctx; + } + + rxqueue = __netif_get_rx_queue(current->nsproxy->net_ns->loopback_dev, 0); + rxqueue->xdp_rxq.frag_size = headroom + max_data_sz + tailroom; + xdp_init_buff(&xdp, rxqueue->xdp_rxq.frag_size, &rxqueue->xdp_rxq); + xdp_prepare_buff(&xdp, data, headroom, size, true); + sinfo = xdp_get_shared_info_from_buff(&xdp); + + ret = xdp_convert_md_to_buff(ctx, &xdp); + if (ret) + goto free_data; + + if (unlikely(kattr->test.data_size_in > size)) { + void __user *data_in = u64_to_user_ptr(kattr->test.data_in); + + while (size < kattr->test.data_size_in) { + struct page *page; + skb_frag_t *frag; + u32 data_len; + + if (sinfo->nr_frags == MAX_SKB_FRAGS) { + ret = -ENOMEM; + goto out; + } + + page = alloc_page(GFP_KERNEL); + if (!page) { + ret = -ENOMEM; + goto out; + } + + frag = &sinfo->frags[sinfo->nr_frags++]; + __skb_frag_set_page(frag, page); + + data_len = min_t(u32, kattr->test.data_size_in - size, + PAGE_SIZE); + skb_frag_size_set(frag, data_len); + + if (copy_from_user(page_address(page), data_in + size, + data_len)) { + ret = -EFAULT; + goto out; + } + sinfo->xdp_frags_size += data_len; + size += data_len; + } + xdp_buff_set_frags_flag(&xdp); + } + + if (repeat > 1) + bpf_prog_change_xdp(NULL, prog); + + if (do_live) + ret = bpf_test_run_xdp_live(prog, &xdp, repeat, batch_size, &duration); + else + ret = bpf_test_run(prog, &xdp, repeat, &retval, &duration, true); + /* We convert the xdp_buff back to an xdp_md before checking the return + * code so the reference count of any held netdevice will be decremented + * even if the test run failed. 
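+ *
+ * (The reference in question is taken by dev_get_by_index() in
+ * xdp_convert_md_to_buff() when an ingress_ifindex is supplied, and is
+ * dropped again via dev_put() in xdp_convert_buff_to_md().)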
+ */ + xdp_convert_buff_to_md(&xdp, ctx); + if (ret) + goto out; + + size = xdp.data_end - xdp.data_meta + sinfo->xdp_frags_size; + ret = bpf_test_finish(kattr, uattr, xdp.data_meta, sinfo, size, + retval, duration); + if (!ret) + ret = bpf_ctx_finish(kattr, uattr, ctx, + sizeof(struct xdp_md)); + +out: + if (repeat > 1) + bpf_prog_change_xdp(prog, NULL); +free_data: + for (i = 0; i < sinfo->nr_frags; i++) + __free_page(skb_frag_page(&sinfo->frags[i])); + kfree(data); +free_ctx: + kfree(ctx); + return ret; +} + +static int verify_user_bpf_flow_keys(struct bpf_flow_keys *ctx) +{ + /* make sure the fields we don't use are zeroed */ + if (!range_is_zero(ctx, 0, offsetof(struct bpf_flow_keys, flags))) + return -EINVAL; + + /* flags is allowed */ + + if (!range_is_zero(ctx, offsetofend(struct bpf_flow_keys, flags), + sizeof(struct bpf_flow_keys))) + return -EINVAL; + + return 0; +} + +int bpf_prog_test_run_flow_dissector(struct bpf_prog *prog, + const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + struct bpf_test_timer t = { NO_PREEMPT }; + u32 size = kattr->test.data_size_in; + struct bpf_flow_dissector ctx = {}; + u32 repeat = kattr->test.repeat; + struct bpf_flow_keys *user_ctx; + struct bpf_flow_keys flow_keys; + const struct ethhdr *eth; + unsigned int flags = 0; + u32 retval, duration; + void *data; + int ret; + + if (kattr->test.flags || kattr->test.cpu || kattr->test.batch_size) + return -EINVAL; + + if (size < ETH_HLEN) + return -EINVAL; + + data = bpf_test_init(kattr, kattr->test.data_size_in, size, 0, 0); + if (IS_ERR(data)) + return PTR_ERR(data); + + eth = (struct ethhdr *)data; + + if (!repeat) + repeat = 1; + + user_ctx = bpf_ctx_init(kattr, sizeof(struct bpf_flow_keys)); + if (IS_ERR(user_ctx)) { + kfree(data); + return PTR_ERR(user_ctx); + } + if (user_ctx) { + ret = verify_user_bpf_flow_keys(user_ctx); + if (ret) + goto out; + flags = user_ctx->flags; + } + + ctx.flow_keys = &flow_keys; + ctx.data = data; + ctx.data_end = (__u8 *)data + size; + + bpf_test_timer_enter(&t); + do { + retval = bpf_flow_dissect(prog, &ctx, eth->h_proto, ETH_HLEN, + size, flags); + } while (bpf_test_timer_continue(&t, 1, repeat, &ret, &duration)); + bpf_test_timer_leave(&t); + + if (ret < 0) + goto out; + + ret = bpf_test_finish(kattr, uattr, &flow_keys, NULL, + sizeof(flow_keys), retval, duration); + if (!ret) + ret = bpf_ctx_finish(kattr, uattr, user_ctx, + sizeof(struct bpf_flow_keys)); + +out: + kfree(user_ctx); + kfree(data); + return ret; +} + +int bpf_prog_test_run_sk_lookup(struct bpf_prog *prog, const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + struct bpf_test_timer t = { NO_PREEMPT }; + struct bpf_prog_array *progs = NULL; + struct bpf_sk_lookup_kern ctx = {}; + u32 repeat = kattr->test.repeat; + struct bpf_sk_lookup *user_ctx; + u32 retval, duration; + int ret = -EINVAL; + + if (kattr->test.flags || kattr->test.cpu || kattr->test.batch_size) + return -EINVAL; + + if (kattr->test.data_in || kattr->test.data_size_in || kattr->test.data_out || + kattr->test.data_size_out) + return -EINVAL; + + if (!repeat) + repeat = 1; + + user_ctx = bpf_ctx_init(kattr, sizeof(*user_ctx)); + if (IS_ERR(user_ctx)) + return PTR_ERR(user_ctx); + + if (!user_ctx) + return -EINVAL; + + if (user_ctx->sk) + goto out; + + if (!range_is_zero(user_ctx, offsetofend(typeof(*user_ctx), local_port), sizeof(*user_ctx))) + goto out; + + if (user_ctx->local_port > U16_MAX) { + ret = -ERANGE; + goto out; + } + + ctx.family = (u16)user_ctx->family; + ctx.protocol = (u16)user_ctx->protocol; + 
ctx.dport = (u16)user_ctx->local_port; + ctx.sport = user_ctx->remote_port; + + switch (ctx.family) { + case AF_INET: + ctx.v4.daddr = (__force __be32)user_ctx->local_ip4; + ctx.v4.saddr = (__force __be32)user_ctx->remote_ip4; + break; + +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ctx.v6.daddr = (struct in6_addr *)user_ctx->local_ip6; + ctx.v6.saddr = (struct in6_addr *)user_ctx->remote_ip6; + break; +#endif + + default: + ret = -EAFNOSUPPORT; + goto out; + } + + progs = bpf_prog_array_alloc(1, GFP_KERNEL); + if (!progs) { + ret = -ENOMEM; + goto out; + } + + progs->items[0].prog = prog; + + bpf_test_timer_enter(&t); + do { + ctx.selected_sk = NULL; + retval = BPF_PROG_SK_LOOKUP_RUN_ARRAY(progs, ctx, bpf_prog_run); + } while (bpf_test_timer_continue(&t, 1, repeat, &ret, &duration)); + bpf_test_timer_leave(&t); + + if (ret < 0) + goto out; + + user_ctx->cookie = 0; + if (ctx.selected_sk) { + if (ctx.selected_sk->sk_reuseport && !ctx.no_reuseport) { + ret = -EOPNOTSUPP; + goto out; + } + + user_ctx->cookie = sock_gen_cookie(ctx.selected_sk); + } + + ret = bpf_test_finish(kattr, uattr, NULL, NULL, 0, retval, duration); + if (!ret) + ret = bpf_ctx_finish(kattr, uattr, user_ctx, sizeof(*user_ctx)); + +out: + bpf_prog_array_free(progs); + kfree(user_ctx); + return ret; +} + +int bpf_prog_test_run_syscall(struct bpf_prog *prog, + const union bpf_attr *kattr, + union bpf_attr __user *uattr) +{ + void __user *ctx_in = u64_to_user_ptr(kattr->test.ctx_in); + __u32 ctx_size_in = kattr->test.ctx_size_in; + void *ctx = NULL; + u32 retval; + int err = 0; + + /* doesn't support data_in/out, ctx_out, duration, or repeat or flags */ + if (kattr->test.data_in || kattr->test.data_out || + kattr->test.ctx_out || kattr->test.duration || + kattr->test.repeat || kattr->test.flags || + kattr->test.batch_size) + return -EINVAL; + + if (ctx_size_in < prog->aux->max_ctx_offset || + ctx_size_in > U16_MAX) + return -EINVAL; + + if (ctx_size_in) { + ctx = memdup_user(ctx_in, ctx_size_in); + if (IS_ERR(ctx)) + return PTR_ERR(ctx); + } + + rcu_read_lock_trace(); + retval = bpf_prog_run_pin_on_cpu(prog, ctx); + rcu_read_unlock_trace(); + + if (copy_to_user(&uattr->test.retval, &retval, sizeof(u32))) { + err = -EFAULT; + goto out; + } + if (ctx_size_in) + if (copy_to_user(ctx_in, ctx, ctx_size_in)) + err = -EFAULT; +out: + kfree(ctx); + return err; +} + +static const struct btf_kfunc_id_set bpf_prog_test_kfunc_set = { + .owner = THIS_MODULE, + .set = &test_sk_check_kfunc_ids, +}; + +BTF_ID_LIST(bpf_prog_test_dtor_kfunc_ids) +BTF_ID(struct, prog_test_ref_kfunc) +BTF_ID(func, bpf_kfunc_call_test_release) +BTF_ID(struct, prog_test_member) +BTF_ID(func, bpf_kfunc_call_memb_release) + +static int __init bpf_prog_test_run_init(void) +{ + const struct btf_id_dtor_kfunc bpf_prog_test_dtor_kfunc[] = { + { + .btf_id = bpf_prog_test_dtor_kfunc_ids[0], + .kfunc_btf_id = bpf_prog_test_dtor_kfunc_ids[1] + }, + { + .btf_id = bpf_prog_test_dtor_kfunc_ids[2], + .kfunc_btf_id = bpf_prog_test_dtor_kfunc_ids[3], + }, + }; + int ret; + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, &bpf_prog_test_kfunc_set); + ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_TRACING, &bpf_prog_test_kfunc_set); + ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_SYSCALL, &bpf_prog_test_kfunc_set); + return ret ?: register_btf_id_dtor_kfuncs(bpf_prog_test_dtor_kfunc, + ARRAY_SIZE(bpf_prog_test_dtor_kfunc), + THIS_MODULE); +} +late_initcall(bpf_prog_test_run_init); diff --git a/net/bpfilter/.gitignore b/net/bpfilter/.gitignore new file mode 
100644 index 000000000..f34e85ee8 --- /dev/null +++ b/net/bpfilter/.gitignore @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0-only +bpfilter_umh diff --git a/net/bpfilter/Kconfig b/net/bpfilter/Kconfig new file mode 100644 index 000000000..3d4a21462 --- /dev/null +++ b/net/bpfilter/Kconfig @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig BPFILTER + bool "BPF based packet filtering framework (BPFILTER)" + depends on BPF && INET + select USERMODE_DRIVER + help + This builds experimental bpfilter framework that is aiming to + provide netfilter compatible functionality via BPF + +if BPFILTER +config BPFILTER_UMH + tristate "bpfilter kernel module with user mode helper" + depends on CC_CAN_LINK + depends on m || CC_CAN_LINK_STATIC + default m + help + This builds bpfilter kernel module with embedded user mode helper + + Note: To compile this as built-in, your toolchain must support + building static binaries, since rootfs isn't mounted at the time + when __init functions are called and do_execv won't be able to find + the elf interpreter. +endif diff --git a/net/bpfilter/Makefile b/net/bpfilter/Makefile new file mode 100644 index 000000000..cdac82b8c --- /dev/null +++ b/net/bpfilter/Makefile @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux BPFILTER layer. +# + +userprogs := bpfilter_umh +bpfilter_umh-objs := main.o +userccflags += -I $(srctree)/tools/include/ -I $(srctree)/tools/include/uapi + +ifeq ($(CONFIG_BPFILTER_UMH), y) +# builtin bpfilter_umh should be linked with -static +# since rootfs isn't mounted at the time of __init +# function is called and do_execv won't find elf interpreter +userldflags += -static +endif + +$(obj)/bpfilter_umh_blob.o: $(obj)/bpfilter_umh + +obj-$(CONFIG_BPFILTER_UMH) += bpfilter.o +bpfilter-objs += bpfilter_kern.o bpfilter_umh_blob.o diff --git a/net/bpfilter/bpfilter_kern.c b/net/bpfilter/bpfilter_kern.c new file mode 100644 index 000000000..422ec6e7c --- /dev/null +++ b/net/bpfilter/bpfilter_kern.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0 +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include "msgfmt.h" + +extern char bpfilter_umh_start; +extern char bpfilter_umh_end; + +static void shutdown_umh(void) +{ + struct umd_info *info = &bpfilter_ops.info; + struct pid *tgid = info->tgid; + + if (tgid) { + kill_pid(tgid, SIGKILL, 1); + wait_event(tgid->wait_pidfd, thread_group_exited(tgid)); + bpfilter_umh_cleanup(info); + } +} + +static void __stop_umh(void) +{ + if (IS_ENABLED(CONFIG_INET)) + shutdown_umh(); +} + +static int bpfilter_send_req(struct mbox_request *req) +{ + struct mbox_reply reply; + loff_t pos = 0; + ssize_t n; + + if (!bpfilter_ops.info.tgid) + return -EFAULT; + pos = 0; + n = kernel_write(bpfilter_ops.info.pipe_to_umh, req, sizeof(*req), + &pos); + if (n != sizeof(*req)) { + pr_err("write fail %zd\n", n); + goto stop; + } + pos = 0; + n = kernel_read(bpfilter_ops.info.pipe_from_umh, &reply, sizeof(reply), + &pos); + if (n != sizeof(reply)) { + pr_err("read fail %zd\n", n); + goto stop; + } + return reply.status; +stop: + __stop_umh(); + return -EFAULT; +} + +static int bpfilter_process_sockopt(struct sock *sk, int optname, + sockptr_t optval, unsigned int optlen, + bool is_set) +{ + struct mbox_request req = { + .is_set = is_set, + .pid = current->pid, + .cmd = optname, + .addr = (uintptr_t)optval.user, + .len = optlen, + }; + if (sockptr_is_kernel(optval)) { + pr_err("kernel access not 
supported\n"); + return -EFAULT; + } + return bpfilter_send_req(&req); +} + +static int start_umh(void) +{ + struct mbox_request req = { .pid = current->pid }; + int err; + + /* fork usermode process */ + err = fork_usermode_driver(&bpfilter_ops.info); + if (err) + return err; + pr_info("Loaded bpfilter_umh pid %d\n", pid_nr(bpfilter_ops.info.tgid)); + + /* health check that usermode process started correctly */ + if (bpfilter_send_req(&req) != 0) { + shutdown_umh(); + return -EFAULT; + } + + return 0; +} + +static int __init load_umh(void) +{ + int err; + + err = umd_load_blob(&bpfilter_ops.info, + &bpfilter_umh_start, + &bpfilter_umh_end - &bpfilter_umh_start); + if (err) + return err; + + mutex_lock(&bpfilter_ops.lock); + err = start_umh(); + if (!err && IS_ENABLED(CONFIG_INET)) { + bpfilter_ops.sockopt = &bpfilter_process_sockopt; + bpfilter_ops.start = &start_umh; + } + mutex_unlock(&bpfilter_ops.lock); + if (err) + umd_unload_blob(&bpfilter_ops.info); + return err; +} + +static void __exit fini_umh(void) +{ + mutex_lock(&bpfilter_ops.lock); + if (IS_ENABLED(CONFIG_INET)) { + shutdown_umh(); + bpfilter_ops.start = NULL; + bpfilter_ops.sockopt = NULL; + } + mutex_unlock(&bpfilter_ops.lock); + + umd_unload_blob(&bpfilter_ops.info); +} +module_init(load_umh); +module_exit(fini_umh); +MODULE_LICENSE("GPL"); diff --git a/net/bpfilter/bpfilter_umh_blob.S b/net/bpfilter/bpfilter_umh_blob.S new file mode 100644 index 000000000..40311d10d --- /dev/null +++ b/net/bpfilter/bpfilter_umh_blob.S @@ -0,0 +1,7 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + .section .init.rodata, "a" + .global bpfilter_umh_start +bpfilter_umh_start: + .incbin "net/bpfilter/bpfilter_umh" + .global bpfilter_umh_end +bpfilter_umh_end: diff --git a/net/bpfilter/main.c b/net/bpfilter/main.c new file mode 100644 index 000000000..291a92546 --- /dev/null +++ b/net/bpfilter/main.c @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: GPL-2.0 +#define _GNU_SOURCE +#include +#include +#include +#include +#include +#include +#include "../../include/uapi/linux/bpf.h" +#include +#include "msgfmt.h" + +FILE *debug_f; + +static int handle_get_cmd(struct mbox_request *cmd) +{ + switch (cmd->cmd) { + case 0: + return 0; + default: + break; + } + return -ENOPROTOOPT; +} + +static int handle_set_cmd(struct mbox_request *cmd) +{ + return -ENOPROTOOPT; +} + +static void loop(void) +{ + while (1) { + struct mbox_request req; + struct mbox_reply reply; + int n; + + n = read(0, &req, sizeof(req)); + if (n != sizeof(req)) { + fprintf(debug_f, "invalid request %d\n", n); + return; + } + + reply.status = req.is_set ? 
+ handle_set_cmd(&req) : + handle_get_cmd(&req); + + n = write(1, &reply, sizeof(reply)); + if (n != sizeof(reply)) { + fprintf(debug_f, "reply failed %d\n", n); + return; + } + } +} + +int main(void) +{ + debug_f = fopen("/dev/kmsg", "w"); + setvbuf(debug_f, 0, _IOLBF, 0); + fprintf(debug_f, "<5>Started bpfilter\n"); + loop(); + fclose(debug_f); + return 0; +} diff --git a/net/bpfilter/msgfmt.h b/net/bpfilter/msgfmt.h new file mode 100644 index 000000000..98d121c62 --- /dev/null +++ b/net/bpfilter/msgfmt.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _NET_BPFILTER_MSGFMT_H +#define _NET_BPFILTER_MSGFMT_H + +struct mbox_request { + __u64 addr; + __u32 len; + __u32 is_set; + __u32 cmd; + __u32 pid; +}; + +struct mbox_reply { + __u32 status; +}; + +#endif diff --git a/net/bridge/Kconfig b/net/bridge/Kconfig new file mode 100644 index 000000000..3c8ded7d3 --- /dev/null +++ b/net/bridge/Kconfig @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# 802.1d Ethernet Bridging +# + +config BRIDGE + tristate "802.1d Ethernet Bridging" + select LLC + select STP + depends on IPV6 || IPV6=n + help + If you say Y here, then your Linux box will be able to act as an + Ethernet bridge, which means that the different Ethernet segments it + is connected to will appear as one Ethernet to the participants. + Several such bridges can work together to create even larger + networks of Ethernets using the IEEE 802.1 spanning tree algorithm. + As this is a standard, Linux bridges will cooperate properly with + other third party bridge products. + + In order to use the Ethernet bridge, you'll need the bridge + configuration tools; see + for location. Please read the Bridge mini-HOWTO for more + information. + + If you enable iptables support along with the bridge support then you + turn your bridge into a bridging IP firewall. + iptables will then see the IP packets being bridged, so you need to + take this into account when setting up your firewall rules. + Enabling arptables support when bridging will let arptables see + bridged ARP traffic in the arptables FORWARD chain. + + To compile this code as a module, choose M here: the module + will be called bridge. + + If unsure, say N. + +config BRIDGE_IGMP_SNOOPING + bool "IGMP/MLD snooping" + depends on BRIDGE + depends on INET + default y + help + If you say Y here, then the Ethernet bridge will be able selectively + forward multicast traffic based on IGMP/MLD traffic received from + each port. + + Say N to exclude this support and reduce the binary size. + + If unsure, say Y. + +config BRIDGE_VLAN_FILTERING + bool "VLAN filtering" + depends on BRIDGE + depends on VLAN_8021Q + default n + help + If you say Y here, then the Ethernet bridge will be able selectively + receive and forward traffic based on VLAN information in the packet + any VLAN information configured on the bridge port or bridge device. + + Say N to exclude this support and reduce the binary size. + + If unsure, say Y. + +config BRIDGE_MRP + bool "MRP protocol" + depends on BRIDGE + default n + help + If you say Y here, then the Ethernet bridge will be able to run MRP + protocol to detect loops + + Say N to exclude this support and reduce the binary size. + + If unsure, say N. + +config BRIDGE_CFM + bool "CFM protocol" + depends on BRIDGE + help + If you say Y here, then the Ethernet bridge will be able to run CFM + protocol according to 802.1Q section 12.14 + + Say N to exclude this support and reduce the binary size. + + If unsure, say N. 
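Taken together, bpfilter_kern.c, main.c and msgfmt.h define a very small mailbox protocol: the kernel writes exactly one struct mbox_request down the usermode helper's stdin pipe and then blocks until exactly one struct mbox_reply comes back on its stdout (bpfilter_send_req() on the kernel side, loop() in the helper). The standalone userspace sketch below reproduces that framing over a pipe pair purely for illustration; it is not part of the patch and simply mirrors the structures from msgfmt.h.

/* Standalone sketch of the bpfilter kernel<->UMH mailbox framing:
 * one struct mbox_request down, one struct mbox_reply back, each
 * written and read in full.  The "helper" child plays the role of
 * loop() in main.c; the parent plays bpfilter_send_req(). */
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <sys/wait.h>
#include <unistd.h>

struct mbox_request {		/* mirrors net/bpfilter/msgfmt.h */
	uint64_t addr;
	uint32_t len;
	uint32_t is_set;
	uint32_t cmd;
	uint32_t pid;
};

struct mbox_reply {
	uint32_t status;
};

static void helper_loop(int rfd, int wfd)
{
	struct mbox_request req;
	struct mbox_reply reply;

	while (read(rfd, &req, sizeof(req)) == sizeof(req)) {
		/* cmd 0 with is_set == 0 is the health check from start_umh() */
		reply.status = (!req.is_set && req.cmd == 0) ?
			       0 : (uint32_t)-ENOPROTOOPT;
		if (write(wfd, &reply, sizeof(reply)) != sizeof(reply))
			return;
	}
}

int main(void)
{
	struct mbox_request req = { .pid = getpid() };	/* health check */
	struct mbox_reply reply;
	int to_umh[2], from_umh[2];
	pid_t pid;

	if (pipe(to_umh) || pipe(from_umh))
		return 1;

	pid = fork();
	if (pid < 0)
		return 1;
	if (pid == 0) {
		close(to_umh[1]);
		close(from_umh[0]);
		helper_loop(to_umh[0], from_umh[1]);
		_exit(0);
	}

	close(to_umh[0]);
	close(from_umh[1]);

	/* "kernel" side: one request out, one reply back, like
	 * bpfilter_send_req(). */
	if (write(to_umh[1], &req, sizeof(req)) != sizeof(req))
		return 1;
	if (read(from_umh[0], &reply, sizeof(reply)) != sizeof(reply))
		return 1;

	printf("reply.status = %d\n", (int)reply.status);
	close(to_umh[1]);
	wait(NULL);
	return 0;
}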
diff --git a/net/bridge/Makefile b/net/bridge/Makefile new file mode 100644 index 000000000..24bd1c0a9 --- /dev/null +++ b/net/bridge/Makefile @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the IEEE 802.1d ethernet bridging layer. +# + +obj-$(CONFIG_BRIDGE) += bridge.o + +bridge-y := br.o br_device.o br_fdb.o br_forward.o br_if.o br_input.o \ + br_ioctl.o br_stp.o br_stp_bpdu.o \ + br_stp_if.o br_stp_timer.o br_netlink.o \ + br_netlink_tunnel.o br_arp_nd_proxy.o + +bridge-$(CONFIG_SYSFS) += br_sysfs_if.o br_sysfs_br.o + +bridge-$(subst m,y,$(CONFIG_BRIDGE_NETFILTER)) += br_nf_core.o + +br_netfilter-y := br_netfilter_hooks.o +br_netfilter-$(subst m,y,$(CONFIG_IPV6)) += br_netfilter_ipv6.o +obj-$(CONFIG_BRIDGE_NETFILTER) += br_netfilter.o + +bridge-$(CONFIG_BRIDGE_IGMP_SNOOPING) += br_multicast.o br_mdb.o br_multicast_eht.o + +bridge-$(CONFIG_BRIDGE_VLAN_FILTERING) += br_vlan.o br_vlan_tunnel.o br_vlan_options.o br_mst.o + +bridge-$(CONFIG_NET_SWITCHDEV) += br_switchdev.o + +obj-$(CONFIG_NETFILTER) += netfilter/ + +bridge-$(CONFIG_BRIDGE_MRP) += br_mrp_switchdev.o br_mrp.o br_mrp_netlink.o + +bridge-$(CONFIG_BRIDGE_CFM) += br_cfm.o br_cfm_netlink.o diff --git a/net/bridge/br.c b/net/bridge/br.c new file mode 100644 index 000000000..96e91d69a --- /dev/null +++ b/net/bridge/br.c @@ -0,0 +1,470 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Generic parts + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" + +/* + * Handle changes in state of network devices enslaved to a bridge. + * + * Note: don't care about up/down if bridge itself is down, because + * port state is checked when bridge is brought up. 
+ */ +static int br_device_event(struct notifier_block *unused, unsigned long event, void *ptr) +{ + struct netlink_ext_ack *extack = netdev_notifier_info_to_extack(ptr); + struct netdev_notifier_pre_changeaddr_info *prechaddr_info; + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net_bridge_port *p; + struct net_bridge *br; + bool notified = false; + bool changed_addr; + int err; + + if (netif_is_bridge_master(dev)) { + err = br_vlan_bridge_event(dev, event, ptr); + if (err) + return notifier_from_errno(err); + + if (event == NETDEV_REGISTER) { + /* register of bridge completed, add sysfs entries */ + err = br_sysfs_addbr(dev); + if (err) + return notifier_from_errno(err); + + return NOTIFY_DONE; + } + } + + /* not a port of a bridge */ + p = br_port_get_rtnl(dev); + if (!p) + return NOTIFY_DONE; + + br = p->br; + + switch (event) { + case NETDEV_CHANGEMTU: + br_mtu_auto_adjust(br); + break; + + case NETDEV_PRE_CHANGEADDR: + if (br->dev->addr_assign_type == NET_ADDR_SET) + break; + prechaddr_info = ptr; + err = dev_pre_changeaddr_notify(br->dev, + prechaddr_info->dev_addr, + extack); + if (err) + return notifier_from_errno(err); + break; + + case NETDEV_CHANGEADDR: + spin_lock_bh(&br->lock); + br_fdb_changeaddr(p, dev->dev_addr); + changed_addr = br_stp_recalculate_bridge_id(br); + spin_unlock_bh(&br->lock); + + if (changed_addr) + call_netdevice_notifiers(NETDEV_CHANGEADDR, br->dev); + + break; + + case NETDEV_CHANGE: + br_port_carrier_check(p, ¬ified); + break; + + case NETDEV_FEAT_CHANGE: + netdev_update_features(br->dev); + break; + + case NETDEV_DOWN: + spin_lock_bh(&br->lock); + if (br->dev->flags & IFF_UP) { + br_stp_disable_port(p); + notified = true; + } + spin_unlock_bh(&br->lock); + break; + + case NETDEV_UP: + if (netif_running(br->dev) && netif_oper_up(dev)) { + spin_lock_bh(&br->lock); + br_stp_enable_port(p); + notified = true; + spin_unlock_bh(&br->lock); + } + break; + + case NETDEV_UNREGISTER: + br_del_if(br, dev); + break; + + case NETDEV_CHANGENAME: + err = br_sysfs_renameif(p); + if (err) + return notifier_from_errno(err); + break; + + case NETDEV_PRE_TYPE_CHANGE: + /* Forbid underlying device to change its type. 
*/ + return NOTIFY_BAD; + + case NETDEV_RESEND_IGMP: + /* Propagate to master device */ + call_netdevice_notifiers(event, br->dev); + break; + } + + if (event != NETDEV_UNREGISTER) + br_vlan_port_event(p, event); + + /* Events that may cause spanning tree to refresh */ + if (!notified && (event == NETDEV_CHANGEADDR || event == NETDEV_UP || + event == NETDEV_CHANGE || event == NETDEV_DOWN)) + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + return NOTIFY_DONE; +} + +static struct notifier_block br_device_notifier = { + .notifier_call = br_device_event +}; + +/* called with RTNL or RCU */ +static int br_switchdev_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = switchdev_notifier_info_to_dev(ptr); + struct net_bridge_port *p; + struct net_bridge *br; + struct switchdev_notifier_fdb_info *fdb_info; + int err = NOTIFY_DONE; + + p = br_port_get_rtnl_rcu(dev); + if (!p) + goto out; + + br = p->br; + + switch (event) { + case SWITCHDEV_FDB_ADD_TO_BRIDGE: + fdb_info = ptr; + err = br_fdb_external_learn_add(br, p, fdb_info->addr, + fdb_info->vid, false); + if (err) { + err = notifier_from_errno(err); + break; + } + br_fdb_offloaded_set(br, p, fdb_info->addr, + fdb_info->vid, true); + break; + case SWITCHDEV_FDB_DEL_TO_BRIDGE: + fdb_info = ptr; + err = br_fdb_external_learn_del(br, p, fdb_info->addr, + fdb_info->vid, false); + if (err) + err = notifier_from_errno(err); + break; + case SWITCHDEV_FDB_OFFLOADED: + fdb_info = ptr; + br_fdb_offloaded_set(br, p, fdb_info->addr, + fdb_info->vid, fdb_info->offloaded); + break; + case SWITCHDEV_FDB_FLUSH_TO_BRIDGE: + fdb_info = ptr; + /* Don't delete static entries */ + br_fdb_delete_by_port(br, p, fdb_info->vid, 0); + break; + } + +out: + return err; +} + +static struct notifier_block br_switchdev_notifier = { + .notifier_call = br_switchdev_event, +}; + +/* called under rtnl_mutex */ +static int br_switchdev_blocking_event(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct netlink_ext_ack *extack = netdev_notifier_info_to_extack(ptr); + struct net_device *dev = switchdev_notifier_info_to_dev(ptr); + struct switchdev_notifier_brport_info *brport_info; + const struct switchdev_brport *b; + struct net_bridge_port *p; + int err = NOTIFY_DONE; + + p = br_port_get_rtnl(dev); + if (!p) + goto out; + + switch (event) { + case SWITCHDEV_BRPORT_OFFLOADED: + brport_info = ptr; + b = &brport_info->brport; + + err = br_switchdev_port_offload(p, b->dev, b->ctx, + b->atomic_nb, b->blocking_nb, + b->tx_fwd_offload, extack); + err = notifier_from_errno(err); + break; + case SWITCHDEV_BRPORT_UNOFFLOADED: + brport_info = ptr; + b = &brport_info->brport; + + br_switchdev_port_unoffload(p, b->ctx, b->atomic_nb, + b->blocking_nb); + break; + } + +out: + return err; +} + +static struct notifier_block br_switchdev_blocking_notifier = { + .notifier_call = br_switchdev_blocking_event, +}; + +/* br_boolopt_toggle - change user-controlled boolean option + * + * @br: bridge device + * @opt: id of the option to change + * @on: new option value + * @extack: extack for error messages + * + * Changes the value of the respective boolean option to @on taking care of + * any internal option value mapping and configuration. 
+ */ +int br_boolopt_toggle(struct net_bridge *br, enum br_boolopt_id opt, bool on, + struct netlink_ext_ack *extack) +{ + int err = 0; + + switch (opt) { + case BR_BOOLOPT_NO_LL_LEARN: + br_opt_toggle(br, BROPT_NO_LL_LEARN, on); + break; + case BR_BOOLOPT_MCAST_VLAN_SNOOPING: + err = br_multicast_toggle_vlan_snooping(br, on, extack); + break; + case BR_BOOLOPT_MST_ENABLE: + err = br_mst_set_enabled(br, on, extack); + break; + default: + /* shouldn't be called with unsupported options */ + WARN_ON(1); + break; + } + + return err; +} + +int br_boolopt_get(const struct net_bridge *br, enum br_boolopt_id opt) +{ + switch (opt) { + case BR_BOOLOPT_NO_LL_LEARN: + return br_opt_get(br, BROPT_NO_LL_LEARN); + case BR_BOOLOPT_MCAST_VLAN_SNOOPING: + return br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED); + case BR_BOOLOPT_MST_ENABLE: + return br_opt_get(br, BROPT_MST_ENABLED); + default: + /* shouldn't be called with unsupported options */ + WARN_ON(1); + break; + } + + return 0; +} + +int br_boolopt_multi_toggle(struct net_bridge *br, + struct br_boolopt_multi *bm, + struct netlink_ext_ack *extack) +{ + unsigned long bitmap = bm->optmask; + int err = 0; + int opt_id; + + for_each_set_bit(opt_id, &bitmap, BR_BOOLOPT_MAX) { + bool on = !!(bm->optval & BIT(opt_id)); + + err = br_boolopt_toggle(br, opt_id, on, extack); + if (err) { + br_debug(br, "boolopt multi-toggle error: option: %d current: %d new: %d error: %d\n", + opt_id, br_boolopt_get(br, opt_id), on, err); + break; + } + } + + return err; +} + +void br_boolopt_multi_get(const struct net_bridge *br, + struct br_boolopt_multi *bm) +{ + u32 optval = 0; + int opt_id; + + for (opt_id = 0; opt_id < BR_BOOLOPT_MAX; opt_id++) + optval |= (br_boolopt_get(br, opt_id) << opt_id); + + bm->optval = optval; + bm->optmask = GENMASK((BR_BOOLOPT_MAX - 1), 0); +} + +/* private bridge options, controlled by the kernel */ +void br_opt_toggle(struct net_bridge *br, enum net_bridge_opts opt, bool on) +{ + bool cur = !!br_opt_get(br, opt); + + br_debug(br, "toggle option: %d state: %d -> %d\n", + opt, cur, on); + + if (cur == on) + return; + + if (on) + set_bit(opt, &br->options); + else + clear_bit(opt, &br->options); +} + +static void __net_exit br_net_exit_batch(struct list_head *net_list) +{ + struct net_device *dev; + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + + list_for_each_entry(net, net_list, exit_list) + for_each_netdev(net, dev) + if (netif_is_bridge_master(dev)) + br_dev_delete(dev, &list); + + unregister_netdevice_many(&list); + + rtnl_unlock(); +} + +static struct pernet_operations br_net_ops = { + .exit_batch = br_net_exit_batch, +}; + +static const struct stp_proto br_stp_proto = { + .rcv = br_stp_rcv, +}; + +static int __init br_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct br_input_skb_cb) > sizeof_field(struct sk_buff, cb)); + + err = stp_proto_register(&br_stp_proto); + if (err < 0) { + pr_err("bridge: can't register sap for STP\n"); + return err; + } + + err = br_fdb_init(); + if (err) + goto err_out; + + err = register_pernet_subsys(&br_net_ops); + if (err) + goto err_out1; + + err = br_nf_core_init(); + if (err) + goto err_out2; + + err = register_netdevice_notifier(&br_device_notifier); + if (err) + goto err_out3; + + err = register_switchdev_notifier(&br_switchdev_notifier); + if (err) + goto err_out4; + + err = register_switchdev_blocking_notifier(&br_switchdev_blocking_notifier); + if (err) + goto err_out5; + + err = br_netlink_init(); + if (err) + goto err_out6; + + brioctl_set(br_ioctl_stub); + +#if 
IS_ENABLED(CONFIG_ATM_LANE) + br_fdb_test_addr_hook = br_fdb_test_addr; +#endif + +#if IS_MODULE(CONFIG_BRIDGE_NETFILTER) + pr_info("bridge: filtering via arp/ip/ip6tables is no longer available " + "by default. Update your scripts to load br_netfilter if you " + "need this.\n"); +#endif + + return 0; + +err_out6: + unregister_switchdev_blocking_notifier(&br_switchdev_blocking_notifier); +err_out5: + unregister_switchdev_notifier(&br_switchdev_notifier); +err_out4: + unregister_netdevice_notifier(&br_device_notifier); +err_out3: + br_nf_core_fini(); +err_out2: + unregister_pernet_subsys(&br_net_ops); +err_out1: + br_fdb_fini(); +err_out: + stp_proto_unregister(&br_stp_proto); + return err; +} + +static void __exit br_deinit(void) +{ + stp_proto_unregister(&br_stp_proto); + br_netlink_fini(); + unregister_switchdev_blocking_notifier(&br_switchdev_blocking_notifier); + unregister_switchdev_notifier(&br_switchdev_notifier); + unregister_netdevice_notifier(&br_device_notifier); + brioctl_set(NULL); + unregister_pernet_subsys(&br_net_ops); + + rcu_barrier(); /* Wait for completion of call_rcu()'s */ + + br_nf_core_fini(); +#if IS_ENABLED(CONFIG_ATM_LANE) + br_fdb_test_addr_hook = NULL; +#endif + br_fdb_fini(); +} + +module_init(br_init) +module_exit(br_deinit) +MODULE_LICENSE("GPL"); +MODULE_VERSION(BR_VERSION); +MODULE_ALIAS_RTNL_LINK("bridge"); diff --git a/net/bridge/br_arp_nd_proxy.c b/net/bridge/br_arp_nd_proxy.c new file mode 100644 index 000000000..b45c00c01 --- /dev/null +++ b/net/bridge/br_arp_nd_proxy.c @@ -0,0 +1,485 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handle bridge arp/nd proxy/suppress + * + * Copyright (C) 2017 Cumulus Networks + * Copyright (c) 2017 Roopa Prabhu + * + * Authors: + * Roopa Prabhu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif + +#include "br_private.h" + +void br_recalculate_neigh_suppress_enabled(struct net_bridge *br) +{ + struct net_bridge_port *p; + bool neigh_suppress = false; + + list_for_each_entry(p, &br->port_list, list) { + if (p->flags & BR_NEIGH_SUPPRESS) { + neigh_suppress = true; + break; + } + } + + br_opt_toggle(br, BROPT_NEIGH_SUPPRESS_ENABLED, neigh_suppress); +} + +#if IS_ENABLED(CONFIG_INET) +static void br_arp_send(struct net_bridge *br, struct net_bridge_port *p, + struct net_device *dev, __be32 dest_ip, __be32 src_ip, + const unsigned char *dest_hw, + const unsigned char *src_hw, + const unsigned char *target_hw, + __be16 vlan_proto, u16 vlan_tci) +{ + struct net_bridge_vlan_group *vg; + struct sk_buff *skb; + u16 pvid; + + netdev_dbg(dev, "arp send dev %s dst %pI4 dst_hw %pM src %pI4 src_hw %pM\n", + dev->name, &dest_ip, dest_hw, &src_ip, src_hw); + + if (!vlan_tci) { + arp_send(ARPOP_REPLY, ETH_P_ARP, dest_ip, dev, src_ip, + dest_hw, src_hw, target_hw); + return; + } + + skb = arp_create(ARPOP_REPLY, ETH_P_ARP, dest_ip, dev, src_ip, + dest_hw, src_hw, target_hw); + if (!skb) + return; + + if (p) + vg = nbp_vlan_group_rcu(p); + else + vg = br_vlan_group_rcu(br); + pvid = br_get_pvid(vg); + if (pvid == (vlan_tci & VLAN_VID_MASK)) + vlan_tci = 0; + + if (vlan_tci) + __vlan_hwaccel_put_tag(skb, vlan_proto, vlan_tci); + + if (p) { + arp_xmit(skb); + } else { + skb_reset_mac_header(skb); + __skb_pull(skb, skb_network_offset(skb)); + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->pkt_type = PACKET_HOST; + + netif_rx(skb); + } +} + +static int br_chk_addr_ip(struct net_device *dev, + struct netdev_nested_priv *priv) +{ + __be32 ip = 
*(__be32 *)priv->data; + struct in_device *in_dev; + __be32 addr = 0; + + in_dev = __in_dev_get_rcu(dev); + if (in_dev) + addr = inet_confirm_addr(dev_net(dev), in_dev, 0, ip, + RT_SCOPE_HOST); + + if (addr == ip) + return 1; + + return 0; +} + +static bool br_is_local_ip(struct net_device *dev, __be32 ip) +{ + struct netdev_nested_priv priv = { + .data = (void *)&ip, + }; + + if (br_chk_addr_ip(dev, &priv)) + return true; + + /* check if ip is configured on upper dev */ + if (netdev_walk_all_upper_dev_rcu(dev, br_chk_addr_ip, &priv)) + return true; + + return false; +} + +void br_do_proxy_suppress_arp(struct sk_buff *skb, struct net_bridge *br, + u16 vid, struct net_bridge_port *p) +{ + struct net_device *dev = br->dev; + struct net_device *vlandev = dev; + struct neighbour *n; + struct arphdr *parp; + u8 *arpptr, *sha; + __be32 sip, tip; + + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 0; + + if ((dev->flags & IFF_NOARP) || + !pskb_may_pull(skb, arp_hdr_len(dev))) + return; + + parp = arp_hdr(skb); + + if (parp->ar_pro != htons(ETH_P_IP) || + parp->ar_hln != dev->addr_len || + parp->ar_pln != 4) + return; + + arpptr = (u8 *)parp + sizeof(struct arphdr); + sha = arpptr; + arpptr += dev->addr_len; /* sha */ + memcpy(&sip, arpptr, sizeof(sip)); + arpptr += sizeof(sip); + arpptr += dev->addr_len; /* tha */ + memcpy(&tip, arpptr, sizeof(tip)); + + if (ipv4_is_loopback(tip) || + ipv4_is_multicast(tip)) + return; + + if (br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED)) { + if (p && (p->flags & BR_NEIGH_SUPPRESS)) + return; + if (parp->ar_op != htons(ARPOP_RREQUEST) && + parp->ar_op != htons(ARPOP_RREPLY) && + (ipv4_is_zeronet(sip) || sip == tip)) { + /* prevent flooding to neigh suppress ports */ + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + return; + } + } + + if (parp->ar_op != htons(ARPOP_REQUEST)) + return; + + if (vid != 0) { + vlandev = __vlan_find_dev_deep_rcu(br->dev, skb->vlan_proto, + vid); + if (!vlandev) + return; + } + + if (br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED) && + br_is_local_ip(vlandev, tip)) { + /* its our local ip, so don't proxy reply + * and don't forward to neigh suppress ports + */ + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + return; + } + + n = neigh_lookup(&arp_tbl, &tip, vlandev); + if (n) { + struct net_bridge_fdb_entry *f; + + if (!(READ_ONCE(n->nud_state) & NUD_VALID)) { + neigh_release(n); + return; + } + + f = br_fdb_find_rcu(br, n->ha, vid); + if (f) { + bool replied = false; + + if ((p && (p->flags & BR_PROXYARP)) || + (f->dst && (f->dst->flags & (BR_PROXYARP_WIFI | + BR_NEIGH_SUPPRESS)))) { + if (!vid) + br_arp_send(br, p, skb->dev, sip, tip, + sha, n->ha, sha, 0, 0); + else + br_arp_send(br, p, skb->dev, sip, tip, + sha, n->ha, sha, + skb->vlan_proto, + skb_vlan_tag_get(skb)); + replied = true; + } + + /* If we have replied or as long as we know the + * mac, indicate to arp replied + */ + if (replied || + br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED)) + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + } + + neigh_release(n); + } +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +struct nd_msg *br_is_nd_neigh_msg(struct sk_buff *skb, struct nd_msg *msg) +{ + struct nd_msg *m; + + m = skb_header_pointer(skb, skb_network_offset(skb) + + sizeof(struct ipv6hdr), sizeof(*msg), msg); + if (!m) + return NULL; + + if (m->icmph.icmp6_code != 0 || + (m->icmph.icmp6_type != NDISC_NEIGHBOUR_SOLICITATION && + m->icmph.icmp6_type != NDISC_NEIGHBOUR_ADVERTISEMENT)) + return NULL; + + return m; +} + +static void br_nd_send(struct net_bridge *br, struct net_bridge_port *p, + struct 
sk_buff *request, struct neighbour *n, + __be16 vlan_proto, u16 vlan_tci, struct nd_msg *ns) +{ + struct net_device *dev = request->dev; + struct net_bridge_vlan_group *vg; + struct sk_buff *reply; + struct nd_msg *na; + struct ipv6hdr *pip6; + int na_olen = 8; /* opt hdr + ETH_ALEN for target */ + int ns_olen; + int i, len; + u8 *daddr; + u16 pvid; + + if (!dev) + return; + + len = LL_RESERVED_SPACE(dev) + sizeof(struct ipv6hdr) + + sizeof(*na) + na_olen + dev->needed_tailroom; + + reply = alloc_skb(len, GFP_ATOMIC); + if (!reply) + return; + + reply->protocol = htons(ETH_P_IPV6); + reply->dev = dev; + skb_reserve(reply, LL_RESERVED_SPACE(dev)); + skb_push(reply, sizeof(struct ethhdr)); + skb_set_mac_header(reply, 0); + + daddr = eth_hdr(request)->h_source; + + /* Do we need option processing ? */ + ns_olen = request->len - (skb_network_offset(request) + + sizeof(struct ipv6hdr)) - sizeof(*ns); + for (i = 0; i < ns_olen - 1; i += (ns->opt[i + 1] << 3)) { + if (!ns->opt[i + 1]) { + kfree_skb(reply); + return; + } + if (ns->opt[i] == ND_OPT_SOURCE_LL_ADDR) { + daddr = ns->opt + i + sizeof(struct nd_opt_hdr); + break; + } + } + + /* Ethernet header */ + ether_addr_copy(eth_hdr(reply)->h_dest, daddr); + ether_addr_copy(eth_hdr(reply)->h_source, n->ha); + eth_hdr(reply)->h_proto = htons(ETH_P_IPV6); + reply->protocol = htons(ETH_P_IPV6); + + skb_pull(reply, sizeof(struct ethhdr)); + skb_set_network_header(reply, 0); + skb_put(reply, sizeof(struct ipv6hdr)); + + /* IPv6 header */ + pip6 = ipv6_hdr(reply); + memset(pip6, 0, sizeof(struct ipv6hdr)); + pip6->version = 6; + pip6->priority = ipv6_hdr(request)->priority; + pip6->nexthdr = IPPROTO_ICMPV6; + pip6->hop_limit = 255; + pip6->daddr = ipv6_hdr(request)->saddr; + pip6->saddr = *(struct in6_addr *)n->primary_key; + + skb_pull(reply, sizeof(struct ipv6hdr)); + skb_set_transport_header(reply, 0); + + na = (struct nd_msg *)skb_put(reply, sizeof(*na) + na_olen); + + /* Neighbor Advertisement */ + memset(na, 0, sizeof(*na) + na_olen); + na->icmph.icmp6_type = NDISC_NEIGHBOUR_ADVERTISEMENT; + na->icmph.icmp6_router = (n->flags & NTF_ROUTER) ? 
1 : 0; + na->icmph.icmp6_override = 1; + na->icmph.icmp6_solicited = 1; + na->target = ns->target; + ether_addr_copy(&na->opt[2], n->ha); + na->opt[0] = ND_OPT_TARGET_LL_ADDR; + na->opt[1] = na_olen >> 3; + + na->icmph.icmp6_cksum = csum_ipv6_magic(&pip6->saddr, + &pip6->daddr, + sizeof(*na) + na_olen, + IPPROTO_ICMPV6, + csum_partial(na, sizeof(*na) + na_olen, 0)); + + pip6->payload_len = htons(sizeof(*na) + na_olen); + + skb_push(reply, sizeof(struct ipv6hdr)); + skb_push(reply, sizeof(struct ethhdr)); + + reply->ip_summed = CHECKSUM_UNNECESSARY; + + if (p) + vg = nbp_vlan_group_rcu(p); + else + vg = br_vlan_group_rcu(br); + pvid = br_get_pvid(vg); + if (pvid == (vlan_tci & VLAN_VID_MASK)) + vlan_tci = 0; + + if (vlan_tci) + __vlan_hwaccel_put_tag(reply, vlan_proto, vlan_tci); + + netdev_dbg(dev, "nd send dev %s dst %pI6 dst_hw %pM src %pI6 src_hw %pM\n", + dev->name, &pip6->daddr, daddr, &pip6->saddr, n->ha); + + if (p) { + dev_queue_xmit(reply); + } else { + skb_reset_mac_header(reply); + __skb_pull(reply, skb_network_offset(reply)); + reply->ip_summed = CHECKSUM_UNNECESSARY; + reply->pkt_type = PACKET_HOST; + + netif_rx(reply); + } +} + +static int br_chk_addr_ip6(struct net_device *dev, + struct netdev_nested_priv *priv) +{ + struct in6_addr *addr = (struct in6_addr *)priv->data; + + if (ipv6_chk_addr(dev_net(dev), addr, dev, 0)) + return 1; + + return 0; +} + +static bool br_is_local_ip6(struct net_device *dev, struct in6_addr *addr) + +{ + struct netdev_nested_priv priv = { + .data = (void *)addr, + }; + + if (br_chk_addr_ip6(dev, &priv)) + return true; + + /* check if ip is configured on upper dev */ + if (netdev_walk_all_upper_dev_rcu(dev, br_chk_addr_ip6, &priv)) + return true; + + return false; +} + +void br_do_suppress_nd(struct sk_buff *skb, struct net_bridge *br, + u16 vid, struct net_bridge_port *p, struct nd_msg *msg) +{ + struct net_device *dev = br->dev; + struct net_device *vlandev = NULL; + struct in6_addr *saddr, *daddr; + struct ipv6hdr *iphdr; + struct neighbour *n; + + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 0; + + if (p && (p->flags & BR_NEIGH_SUPPRESS)) + return; + + if (msg->icmph.icmp6_type == NDISC_NEIGHBOUR_ADVERTISEMENT && + !msg->icmph.icmp6_solicited) { + /* prevent flooding to neigh suppress ports */ + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + return; + } + + if (msg->icmph.icmp6_type != NDISC_NEIGHBOUR_SOLICITATION) + return; + + iphdr = ipv6_hdr(skb); + saddr = &iphdr->saddr; + daddr = &iphdr->daddr; + + if (ipv6_addr_any(saddr) || !ipv6_addr_cmp(saddr, daddr)) { + /* prevent flooding to neigh suppress ports */ + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + return; + } + + if (vid != 0) { + /* build neigh table lookup on the vlan device */ + vlandev = __vlan_find_dev_deep_rcu(br->dev, skb->vlan_proto, + vid); + if (!vlandev) + return; + } else { + vlandev = dev; + } + + if (br_is_local_ip6(vlandev, &msg->target)) { + /* its our own ip, so don't proxy reply + * and don't forward to arp suppress ports + */ + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + return; + } + + n = neigh_lookup(ipv6_stub->nd_tbl, &msg->target, vlandev); + if (n) { + struct net_bridge_fdb_entry *f; + + if (!(READ_ONCE(n->nud_state) & NUD_VALID)) { + neigh_release(n); + return; + } + + f = br_fdb_find_rcu(br, n->ha, vid); + if (f) { + bool replied = false; + + if (f->dst && (f->dst->flags & BR_NEIGH_SUPPRESS)) { + if (vid != 0) + br_nd_send(br, p, skb, n, + skb->vlan_proto, + skb_vlan_tag_get(skb), msg); + else + br_nd_send(br, p, skb, n, 0, 0, msg); + replied = true; + } + + 
/* If we have replied or as long as we know the + * mac, indicate to NEIGH_SUPPRESS ports that we + * have replied + */ + if (replied || + br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED)) + BR_INPUT_SKB_CB(skb)->proxyarp_replied = 1; + } + neigh_release(n); + } +} +#endif diff --git a/net/bridge/br_cfm.c b/net/bridge/br_cfm.c new file mode 100644 index 000000000..a3c755d0a --- /dev/null +++ b/net/bridge/br_cfm.c @@ -0,0 +1,867 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include +#include +#include "br_private_cfm.h" + +static struct br_cfm_mep *br_mep_find(struct net_bridge *br, u32 instance) +{ + struct br_cfm_mep *mep; + + hlist_for_each_entry(mep, &br->mep_list, head) + if (mep->instance == instance) + return mep; + + return NULL; +} + +static struct br_cfm_mep *br_mep_find_ifindex(struct net_bridge *br, + u32 ifindex) +{ + struct br_cfm_mep *mep; + + hlist_for_each_entry_rcu(mep, &br->mep_list, head, + lockdep_rtnl_is_held()) + if (mep->create.ifindex == ifindex) + return mep; + + return NULL; +} + +static struct br_cfm_peer_mep *br_peer_mep_find(struct br_cfm_mep *mep, + u32 mepid) +{ + struct br_cfm_peer_mep *peer_mep; + + hlist_for_each_entry_rcu(peer_mep, &mep->peer_mep_list, head, + lockdep_rtnl_is_held()) + if (peer_mep->mepid == mepid) + return peer_mep; + + return NULL; +} + +static struct net_bridge_port *br_mep_get_port(struct net_bridge *br, + u32 ifindex) +{ + struct net_bridge_port *port; + + list_for_each_entry(port, &br->port_list, list) + if (port->dev->ifindex == ifindex) + return port; + + return NULL; +} + +/* Calculate the CCM interval in us. */ +static u32 interval_to_us(enum br_cfm_ccm_interval interval) +{ + switch (interval) { + case BR_CFM_CCM_INTERVAL_NONE: + return 0; + case BR_CFM_CCM_INTERVAL_3_3_MS: + return 3300; + case BR_CFM_CCM_INTERVAL_10_MS: + return 10 * 1000; + case BR_CFM_CCM_INTERVAL_100_MS: + return 100 * 1000; + case BR_CFM_CCM_INTERVAL_1_SEC: + return 1000 * 1000; + case BR_CFM_CCM_INTERVAL_10_SEC: + return 10 * 1000 * 1000; + case BR_CFM_CCM_INTERVAL_1_MIN: + return 60 * 1000 * 1000; + case BR_CFM_CCM_INTERVAL_10_MIN: + return 10 * 60 * 1000 * 1000; + } + return 0; +} + +/* Convert the interface interval to CCM PDU value. */ +static u32 interval_to_pdu(enum br_cfm_ccm_interval interval) +{ + switch (interval) { + case BR_CFM_CCM_INTERVAL_NONE: + return 0; + case BR_CFM_CCM_INTERVAL_3_3_MS: + return 1; + case BR_CFM_CCM_INTERVAL_10_MS: + return 2; + case BR_CFM_CCM_INTERVAL_100_MS: + return 3; + case BR_CFM_CCM_INTERVAL_1_SEC: + return 4; + case BR_CFM_CCM_INTERVAL_10_SEC: + return 5; + case BR_CFM_CCM_INTERVAL_1_MIN: + return 6; + case BR_CFM_CCM_INTERVAL_10_MIN: + return 7; + } + return 0; +} + +/* Convert the CCM PDU value to interval on interface. */ +static u32 pdu_to_interval(u32 value) +{ + switch (value) { + case 0: + return BR_CFM_CCM_INTERVAL_NONE; + case 1: + return BR_CFM_CCM_INTERVAL_3_3_MS; + case 2: + return BR_CFM_CCM_INTERVAL_10_MS; + case 3: + return BR_CFM_CCM_INTERVAL_100_MS; + case 4: + return BR_CFM_CCM_INTERVAL_1_SEC; + case 5: + return BR_CFM_CCM_INTERVAL_10_SEC; + case 6: + return BR_CFM_CCM_INTERVAL_1_MIN; + case 7: + return BR_CFM_CCM_INTERVAL_10_MIN; + } + return BR_CFM_CCM_INTERVAL_NONE; +} + +static void ccm_rx_timer_start(struct br_cfm_peer_mep *peer_mep) +{ + u32 interval_us; + + interval_us = interval_to_us(peer_mep->mep->cc_config.exp_interval); + /* Function ccm_rx_dwork must be called with 1/4 + * of the configured CC 'expected_interval' + * in order to detect CCM defect after 3.25 interval. 
+ */ + queue_delayed_work(system_wq, &peer_mep->ccm_rx_dwork, + usecs_to_jiffies(interval_us / 4)); +} + +static void br_cfm_notify(int event, const struct net_bridge_port *port) +{ + u32 filter = RTEXT_FILTER_CFM_STATUS; + + br_info_notify(event, port->br, NULL, filter); +} + +static void cc_peer_enable(struct br_cfm_peer_mep *peer_mep) +{ + memset(&peer_mep->cc_status, 0, sizeof(peer_mep->cc_status)); + peer_mep->ccm_rx_count_miss = 0; + + ccm_rx_timer_start(peer_mep); +} + +static void cc_peer_disable(struct br_cfm_peer_mep *peer_mep) +{ + cancel_delayed_work_sync(&peer_mep->ccm_rx_dwork); +} + +static struct sk_buff *ccm_frame_build(struct br_cfm_mep *mep, + const struct br_cfm_cc_ccm_tx_info *const tx_info) + +{ + struct br_cfm_common_hdr *common_hdr; + struct net_bridge_port *b_port; + struct br_cfm_maid *maid; + u8 *itu_reserved, *e_tlv; + struct ethhdr *eth_hdr; + struct sk_buff *skb; + __be32 *status_tlv; + __be32 *snumber; + __be16 *mepid; + + skb = dev_alloc_skb(CFM_CCM_MAX_FRAME_LENGTH); + if (!skb) + return NULL; + + rcu_read_lock(); + b_port = rcu_dereference(mep->b_port); + if (!b_port) { + kfree_skb(skb); + rcu_read_unlock(); + return NULL; + } + skb->dev = b_port->dev; + rcu_read_unlock(); + /* The device cannot be deleted until the work_queue functions has + * completed. This function is called from ccm_tx_work_expired() + * that is a work_queue functions. + */ + + skb->protocol = htons(ETH_P_CFM); + skb->priority = CFM_FRAME_PRIO; + + /* Ethernet header */ + eth_hdr = skb_put(skb, sizeof(*eth_hdr)); + ether_addr_copy(eth_hdr->h_dest, tx_info->dmac.addr); + ether_addr_copy(eth_hdr->h_source, mep->config.unicast_mac.addr); + eth_hdr->h_proto = htons(ETH_P_CFM); + + /* Common CFM Header */ + common_hdr = skb_put(skb, sizeof(*common_hdr)); + common_hdr->mdlevel_version = mep->config.mdlevel << 5; + common_hdr->opcode = BR_CFM_OPCODE_CCM; + common_hdr->flags = (mep->rdi << 7) | + interval_to_pdu(mep->cc_config.exp_interval); + common_hdr->tlv_offset = CFM_CCM_TLV_OFFSET; + + /* Sequence number */ + snumber = skb_put(skb, sizeof(*snumber)); + if (tx_info->seq_no_update) { + *snumber = cpu_to_be32(mep->ccm_tx_snumber); + mep->ccm_tx_snumber += 1; + } else { + *snumber = 0; + } + + mepid = skb_put(skb, sizeof(*mepid)); + *mepid = cpu_to_be16((u16)mep->config.mepid); + + maid = skb_put(skb, sizeof(*maid)); + memcpy(maid->data, mep->cc_config.exp_maid.data, sizeof(maid->data)); + + /* ITU reserved (CFM_CCM_ITU_RESERVED_SIZE octets) */ + itu_reserved = skb_put(skb, CFM_CCM_ITU_RESERVED_SIZE); + memset(itu_reserved, 0, CFM_CCM_ITU_RESERVED_SIZE); + + /* Generel CFM TLV format: + * TLV type: one byte + * TLV value length: two bytes + * TLV value: 'TLV value length' bytes + */ + + /* Port status TLV. The value length is 1. Total of 4 bytes. */ + if (tx_info->port_tlv) { + status_tlv = skb_put(skb, sizeof(*status_tlv)); + *status_tlv = cpu_to_be32((CFM_PORT_STATUS_TLV_TYPE << 24) | + (1 << 8) | /* Value length */ + (tx_info->port_tlv_value & 0xFF)); + } + + /* Interface status TLV. The value length is 1. Total of 4 bytes. 
*/ + if (tx_info->if_tlv) { + status_tlv = skb_put(skb, sizeof(*status_tlv)); + *status_tlv = cpu_to_be32((CFM_IF_STATUS_TLV_TYPE << 24) | + (1 << 8) | /* Value length */ + (tx_info->if_tlv_value & 0xFF)); + } + + /* End TLV */ + e_tlv = skb_put(skb, sizeof(*e_tlv)); + *e_tlv = CFM_ENDE_TLV_TYPE; + + return skb; +} + +static void ccm_frame_tx(struct sk_buff *skb) +{ + skb_reset_network_header(skb); + dev_queue_xmit(skb); +} + +/* This function is called with the configured CC 'expected_interval' + * in order to drive CCM transmission when enabled. + */ +static void ccm_tx_work_expired(struct work_struct *work) +{ + struct delayed_work *del_work; + struct br_cfm_mep *mep; + struct sk_buff *skb; + u32 interval_us; + + del_work = to_delayed_work(work); + mep = container_of(del_work, struct br_cfm_mep, ccm_tx_dwork); + + if (time_before_eq(mep->ccm_tx_end, jiffies)) { + /* Transmission period has ended */ + mep->cc_ccm_tx_info.period = 0; + return; + } + + skb = ccm_frame_build(mep, &mep->cc_ccm_tx_info); + if (skb) + ccm_frame_tx(skb); + + interval_us = interval_to_us(mep->cc_config.exp_interval); + queue_delayed_work(system_wq, &mep->ccm_tx_dwork, + usecs_to_jiffies(interval_us)); +} + +/* This function is called with 1/4 of the configured CC 'expected_interval' + * in order to detect CCM defect after 3.25 interval. + */ +static void ccm_rx_work_expired(struct work_struct *work) +{ + struct br_cfm_peer_mep *peer_mep; + struct net_bridge_port *b_port; + struct delayed_work *del_work; + + del_work = to_delayed_work(work); + peer_mep = container_of(del_work, struct br_cfm_peer_mep, ccm_rx_dwork); + + /* After 13 counts (4 * 3,25) then 3.25 intervals are expired */ + if (peer_mep->ccm_rx_count_miss < 13) { + /* 3.25 intervals are NOT expired without CCM reception */ + peer_mep->ccm_rx_count_miss++; + + /* Start timer again */ + ccm_rx_timer_start(peer_mep); + } else { + /* 3.25 intervals are expired without CCM reception. + * CCM defect detected + */ + peer_mep->cc_status.ccm_defect = true; + + /* Change in CCM defect status - notify */ + rcu_read_lock(); + b_port = rcu_dereference(peer_mep->mep->b_port); + if (b_port) + br_cfm_notify(RTM_NEWLINK, b_port); + rcu_read_unlock(); + } +} + +static u32 ccm_tlv_extract(struct sk_buff *skb, u32 index, + struct br_cfm_peer_mep *peer_mep) +{ + __be32 *s_tlv; + __be32 _s_tlv; + u32 h_s_tlv; + u8 *e_tlv; + u8 _e_tlv; + + e_tlv = skb_header_pointer(skb, index, sizeof(_e_tlv), &_e_tlv); + if (!e_tlv) + return 0; + + /* TLV is present - get the status TLV */ + s_tlv = skb_header_pointer(skb, + index, + sizeof(_s_tlv), &_s_tlv); + if (!s_tlv) + return 0; + + h_s_tlv = ntohl(*s_tlv); + if ((h_s_tlv >> 24) == CFM_IF_STATUS_TLV_TYPE) { + /* Interface status TLV */ + peer_mep->cc_status.tlv_seen = true; + peer_mep->cc_status.if_tlv_value = (h_s_tlv & 0xFF); + } + + if ((h_s_tlv >> 24) == CFM_PORT_STATUS_TLV_TYPE) { + /* Port status TLV */ + peer_mep->cc_status.tlv_seen = true; + peer_mep->cc_status.port_tlv_value = (h_s_tlv & 0xFF); + } + + /* The Sender ID TLV is not handled */ + /* The Organization-Specific TLV is not handled */ + + /* Return the length of this tlv. 
+ * This is the length of the value field plus 3 bytes for size of type + * field and length field + */ + return ((h_s_tlv >> 8) & 0xFFFF) + 3; +} + +/* note: already called with rcu_read_lock */ +static int br_cfm_frame_rx(struct net_bridge_port *port, struct sk_buff *skb) +{ + u32 mdlevel, interval, size, index, max; + const struct br_cfm_common_hdr *hdr; + struct br_cfm_peer_mep *peer_mep; + const struct br_cfm_maid *maid; + struct br_cfm_common_hdr _hdr; + struct br_cfm_maid _maid; + struct br_cfm_mep *mep; + struct net_bridge *br; + __be32 *snumber; + __be32 _snumber; + __be16 *mepid; + __be16 _mepid; + + if (port->state == BR_STATE_DISABLED) + return 0; + + hdr = skb_header_pointer(skb, 0, sizeof(_hdr), &_hdr); + if (!hdr) + return 1; + + br = port->br; + mep = br_mep_find_ifindex(br, port->dev->ifindex); + if (unlikely(!mep)) + /* No MEP on this port - must be forwarded */ + return 0; + + mdlevel = hdr->mdlevel_version >> 5; + if (mdlevel > mep->config.mdlevel) + /* The level is above this MEP level - must be forwarded */ + return 0; + + if ((hdr->mdlevel_version & 0x1F) != 0) { + /* Invalid version */ + mep->status.version_unexp_seen = true; + return 1; + } + + if (mdlevel < mep->config.mdlevel) { + /* The level is below this MEP level */ + mep->status.rx_level_low_seen = true; + return 1; + } + + if (hdr->opcode == BR_CFM_OPCODE_CCM) { + /* CCM PDU received. */ + /* MA ID is after common header + sequence number + MEP ID */ + maid = skb_header_pointer(skb, + CFM_CCM_PDU_MAID_OFFSET, + sizeof(_maid), &_maid); + if (!maid) + return 1; + if (memcmp(maid->data, mep->cc_config.exp_maid.data, + sizeof(maid->data))) + /* MA ID not as expected */ + return 1; + + /* MEP ID is after common header + sequence number */ + mepid = skb_header_pointer(skb, + CFM_CCM_PDU_MEPID_OFFSET, + sizeof(_mepid), &_mepid); + if (!mepid) + return 1; + peer_mep = br_peer_mep_find(mep, (u32)ntohs(*mepid)); + if (!peer_mep) + return 1; + + /* Interval is in common header flags */ + interval = hdr->flags & 0x07; + if (mep->cc_config.exp_interval != pdu_to_interval(interval)) + /* Interval not as expected */ + return 1; + + /* A valid CCM frame is received */ + if (peer_mep->cc_status.ccm_defect) { + peer_mep->cc_status.ccm_defect = false; + + /* Change in CCM defect status - notify */ + br_cfm_notify(RTM_NEWLINK, port); + + /* Start CCM RX timer */ + ccm_rx_timer_start(peer_mep); + } + + peer_mep->cc_status.seen = true; + peer_mep->ccm_rx_count_miss = 0; + + /* RDI is in common header flags */ + peer_mep->cc_status.rdi = (hdr->flags & 0x80) ? 
true : false; + + /* Sequence number is after common header */ + snumber = skb_header_pointer(skb, + CFM_CCM_PDU_SEQNR_OFFSET, + sizeof(_snumber), &_snumber); + if (!snumber) + return 1; + if (ntohl(*snumber) != (mep->ccm_rx_snumber + 1)) + /* Unexpected sequence number */ + peer_mep->cc_status.seq_unexp_seen = true; + + mep->ccm_rx_snumber = ntohl(*snumber); + + /* TLV end is after common header + sequence number + MEP ID + + * MA ID + ITU reserved + */ + index = CFM_CCM_PDU_TLV_OFFSET; + max = 0; + do { /* Handle all TLVs */ + size = ccm_tlv_extract(skb, index, peer_mep); + index += size; + max += 1; + } while (size != 0 && max < 4); /* Max four TLVs possible */ + + return 1; + } + + mep->status.opcode_unexp_seen = true; + + return 1; +} + +static struct br_frame_type cfm_frame_type __read_mostly = { + .type = cpu_to_be16(ETH_P_CFM), + .frame_handler = br_cfm_frame_rx, +}; + +int br_cfm_mep_create(struct net_bridge *br, + const u32 instance, + struct br_cfm_mep_create *const create, + struct netlink_ext_ack *extack) +{ + struct net_bridge_port *p; + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + if (create->domain == BR_CFM_VLAN) { + NL_SET_ERR_MSG_MOD(extack, + "VLAN domain not supported"); + return -EINVAL; + } + if (create->domain != BR_CFM_PORT) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid domain value"); + return -EINVAL; + } + if (create->direction == BR_CFM_MEP_DIRECTION_UP) { + NL_SET_ERR_MSG_MOD(extack, + "Up-MEP not supported"); + return -EINVAL; + } + if (create->direction != BR_CFM_MEP_DIRECTION_DOWN) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid direction value"); + return -EINVAL; + } + p = br_mep_get_port(br, create->ifindex); + if (!p) { + NL_SET_ERR_MSG_MOD(extack, + "Port is not related to bridge"); + return -EINVAL; + } + mep = br_mep_find(br, instance); + if (mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance already exists"); + return -EEXIST; + } + + /* In PORT domain only one instance can be created per port */ + if (create->domain == BR_CFM_PORT) { + mep = br_mep_find_ifindex(br, create->ifindex); + if (mep) { + NL_SET_ERR_MSG_MOD(extack, + "Only one Port MEP on a port allowed"); + return -EINVAL; + } + } + + mep = kzalloc(sizeof(*mep), GFP_KERNEL); + if (!mep) + return -ENOMEM; + + mep->create = *create; + mep->instance = instance; + rcu_assign_pointer(mep->b_port, p); + + INIT_HLIST_HEAD(&mep->peer_mep_list); + INIT_DELAYED_WORK(&mep->ccm_tx_dwork, ccm_tx_work_expired); + + if (hlist_empty(&br->mep_list)) + br_add_frame(br, &cfm_frame_type); + + hlist_add_tail_rcu(&mep->head, &br->mep_list); + + return 0; +} + +static void mep_delete_implementation(struct net_bridge *br, + struct br_cfm_mep *mep) +{ + struct br_cfm_peer_mep *peer_mep; + struct hlist_node *n_store; + + ASSERT_RTNL(); + + /* Empty and free peer MEP list */ + hlist_for_each_entry_safe(peer_mep, n_store, &mep->peer_mep_list, head) { + cancel_delayed_work_sync(&peer_mep->ccm_rx_dwork); + hlist_del_rcu(&peer_mep->head); + kfree_rcu(peer_mep, rcu); + } + + cancel_delayed_work_sync(&mep->ccm_tx_dwork); + + RCU_INIT_POINTER(mep->b_port, NULL); + hlist_del_rcu(&mep->head); + kfree_rcu(mep, rcu); + + if (hlist_empty(&br->mep_list)) + br_del_frame(br, &cfm_frame_type); +} + +int br_cfm_mep_delete(struct net_bridge *br, + const u32 instance, + struct netlink_ext_ack *extack) +{ + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + mep_delete_implementation(br, mep); + + return 0; +} + 
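The CCM status TLVs built in ccm_frame_build() and consumed in ccm_tlv_extract() follow the generic layout described above, a one-byte type, a two-byte length and then the value, which is why the parser advances by ((h_s_tlv >> 8) & 0xFFFF) + 3 per TLV. The standalone sketch below packs two status TLVs plus an End TLV and walks them with that same stride; the numeric type codes are taken from IEEE 802.1Q (the kernel macros live in br_private_cfm.h, which is not shown here), so treat them as assumptions rather than values quoted from this patch.

/* Illustrative sketch of the CCM status-TLV layout used above:
 * type (1 byte), big-endian length (2 bytes), then the value.
 * The 32-bit packing (type << 24 | len << 8 | value) and the
 * "value length + 3" stride match ccm_frame_build()/ccm_tlv_extract(). */
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>		/* htonl()/ntohl() */

#define CFM_PORT_STATUS_TLV_TYPE	2	/* assumed, per IEEE 802.1Q */
#define CFM_IF_STATUS_TLV_TYPE		4	/* assumed, per IEEE 802.1Q */
#define CFM_ENDE_TLV_TYPE		0	/* End TLV has type 0 only */

static size_t put_status_tlv(uint8_t *buf, uint8_t type, uint8_t value)
{
	/* type << 24 | length(=1) << 8 | value, stored big endian */
	uint32_t tlv = htonl(((uint32_t)type << 24) | (1 << 8) | value);

	memcpy(buf, &tlv, sizeof(tlv));
	return sizeof(tlv);
}

static void walk_tlvs(const uint8_t *buf, size_t len)
{
	size_t index = 0;
	int max = 0;

	while (index < len && max < 4) {	/* at most four TLVs, as above */
		uint32_t tlv;

		if (buf[index] == CFM_ENDE_TLV_TYPE)	/* End TLV terminates */
			break;
		memcpy(&tlv, buf + index, sizeof(tlv));
		tlv = ntohl(tlv);
		printf("type %u value 0x%02x\n", tlv >> 24, tlv & 0xFF);
		/* value length plus 3 bytes for the type and length fields */
		index += ((tlv >> 8) & 0xFFFF) + 3;
		max++;
	}
}

int main(void)
{
	uint8_t buf[16];
	size_t off = 0;

	off += put_status_tlv(buf + off, CFM_PORT_STATUS_TLV_TYPE, 2);
	off += put_status_tlv(buf + off, CFM_IF_STATUS_TLV_TYPE, 1);
	buf[off++] = CFM_ENDE_TLV_TYPE;

	walk_tlvs(buf, off);
	return 0;
}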
+int br_cfm_mep_config_set(struct net_bridge *br, + const u32 instance, + const struct br_cfm_mep_config *const config, + struct netlink_ext_ack *extack) +{ + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + mep->config = *config; + + return 0; +} + +int br_cfm_cc_config_set(struct net_bridge *br, + const u32 instance, + const struct br_cfm_cc_config *const config, + struct netlink_ext_ack *extack) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + /* Check for no change in configuration */ + if (memcmp(config, &mep->cc_config, sizeof(*config)) == 0) + return 0; + + if (config->enable && !mep->cc_config.enable) + /* CC is enabled */ + hlist_for_each_entry(peer_mep, &mep->peer_mep_list, head) + cc_peer_enable(peer_mep); + + if (!config->enable && mep->cc_config.enable) + /* CC is disabled */ + hlist_for_each_entry(peer_mep, &mep->peer_mep_list, head) + cc_peer_disable(peer_mep); + + mep->cc_config = *config; + mep->ccm_rx_snumber = 0; + mep->ccm_tx_snumber = 1; + + return 0; +} + +int br_cfm_cc_peer_mep_add(struct net_bridge *br, const u32 instance, + u32 mepid, + struct netlink_ext_ack *extack) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + peer_mep = br_peer_mep_find(mep, mepid); + if (peer_mep) { + NL_SET_ERR_MSG_MOD(extack, + "Peer MEP-ID already exists"); + return -EEXIST; + } + + peer_mep = kzalloc(sizeof(*peer_mep), GFP_KERNEL); + if (!peer_mep) + return -ENOMEM; + + peer_mep->mepid = mepid; + peer_mep->mep = mep; + INIT_DELAYED_WORK(&peer_mep->ccm_rx_dwork, ccm_rx_work_expired); + + if (mep->cc_config.enable) + cc_peer_enable(peer_mep); + + hlist_add_tail_rcu(&peer_mep->head, &mep->peer_mep_list); + + return 0; +} + +int br_cfm_cc_peer_mep_remove(struct net_bridge *br, const u32 instance, + u32 mepid, + struct netlink_ext_ack *extack) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + peer_mep = br_peer_mep_find(mep, mepid); + if (!peer_mep) { + NL_SET_ERR_MSG_MOD(extack, + "Peer MEP-ID does not exists"); + return -ENOENT; + } + + cc_peer_disable(peer_mep); + + hlist_del_rcu(&peer_mep->head); + kfree_rcu(peer_mep, rcu); + + return 0; +} + +int br_cfm_cc_rdi_set(struct net_bridge *br, const u32 instance, + const bool rdi, struct netlink_ext_ack *extack) +{ + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + mep->rdi = rdi; + + return 0; +} + +int br_cfm_cc_ccm_tx(struct net_bridge *br, const u32 instance, + const struct br_cfm_cc_ccm_tx_info *const tx_info, + struct netlink_ext_ack *extack) +{ + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + mep = br_mep_find(br, instance); + if (!mep) { + NL_SET_ERR_MSG_MOD(extack, + "MEP instance does not exists"); + return -ENOENT; + } + + if (memcmp(tx_info, &mep->cc_ccm_tx_info, sizeof(*tx_info)) == 0) { + /* No change in tx_info. 
*/ + if (mep->cc_ccm_tx_info.period == 0) + /* Transmission is not enabled - just return */ + return 0; + + /* Transmission is ongoing, the end time is recalculated */ + mep->ccm_tx_end = jiffies + + usecs_to_jiffies(tx_info->period * 1000000); + return 0; + } + + if (tx_info->period == 0 && mep->cc_ccm_tx_info.period == 0) + /* Some change in info and transmission is not ongoing */ + goto save; + + if (tx_info->period != 0 && mep->cc_ccm_tx_info.period != 0) { + /* Some change in info and transmission is ongoing + * The end time is recalculated + */ + mep->ccm_tx_end = jiffies + + usecs_to_jiffies(tx_info->period * 1000000); + + goto save; + } + + if (tx_info->period == 0 && mep->cc_ccm_tx_info.period != 0) { + cancel_delayed_work_sync(&mep->ccm_tx_dwork); + goto save; + } + + /* Start delayed work to transmit CCM frames. It is done with zero delay + * to send first frame immediately + */ + mep->ccm_tx_end = jiffies + usecs_to_jiffies(tx_info->period * 1000000); + queue_delayed_work(system_wq, &mep->ccm_tx_dwork, 0); + +save: + mep->cc_ccm_tx_info = *tx_info; + + return 0; +} + +int br_cfm_mep_count(struct net_bridge *br, u32 *count) +{ + struct br_cfm_mep *mep; + + *count = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(mep, &br->mep_list, head) + *count += 1; + rcu_read_unlock(); + + return 0; +} + +int br_cfm_peer_mep_count(struct net_bridge *br, u32 *count) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + + *count = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(mep, &br->mep_list, head) + hlist_for_each_entry_rcu(peer_mep, &mep->peer_mep_list, head) + *count += 1; + rcu_read_unlock(); + + return 0; +} + +bool br_cfm_created(struct net_bridge *br) +{ + return !hlist_empty(&br->mep_list); +} + +/* Deletes the CFM instances on a specific bridge port + */ +void br_cfm_port_del(struct net_bridge *br, struct net_bridge_port *port) +{ + struct hlist_node *n_store; + struct br_cfm_mep *mep; + + ASSERT_RTNL(); + + hlist_for_each_entry_safe(mep, n_store, &br->mep_list, head) + if (mep->create.ifindex == port->dev->ifindex) + mep_delete_implementation(br, mep); +} diff --git a/net/bridge/br_cfm_netlink.c b/net/bridge/br_cfm_netlink.c new file mode 100644 index 000000000..5c4c369f8 --- /dev/null +++ b/net/bridge/br_cfm_netlink.c @@ -0,0 +1,726 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include + +#include "br_private.h" +#include "br_private_cfm.h" + +static const struct nla_policy +br_cfm_mep_create_policy[IFLA_BRIDGE_CFM_MEP_CREATE_MAX + 1] = { + [IFLA_BRIDGE_CFM_MEP_CREATE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_MEP_CREATE_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_MEP_CREATE_DOMAIN] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_MEP_CREATE_DIRECTION] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_MEP_CREATE_IFINDEX] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +br_cfm_mep_delete_policy[IFLA_BRIDGE_CFM_MEP_DELETE_MAX + 1] = { + [IFLA_BRIDGE_CFM_MEP_DELETE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_MEP_DELETE_INSTANCE] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +br_cfm_mep_config_policy[IFLA_BRIDGE_CFM_MEP_CONFIG_MAX + 1] = { + [IFLA_BRIDGE_CFM_MEP_CONFIG_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_MEP_CONFIG_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_MEP_CONFIG_UNICAST_MAC] = NLA_POLICY_ETH_ADDR, + [IFLA_BRIDGE_CFM_MEP_CONFIG_MDLEVEL] = NLA_POLICY_MAX(NLA_U32, 7), + [IFLA_BRIDGE_CFM_MEP_CONFIG_MEPID] = NLA_POLICY_MAX(NLA_U32, 0x1FFF), +}; + +static const struct nla_policy 
+br_cfm_cc_config_policy[IFLA_BRIDGE_CFM_CC_CONFIG_MAX + 1] = { + [IFLA_BRIDGE_CFM_CC_CONFIG_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_CC_CONFIG_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CONFIG_ENABLE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CONFIG_EXP_INTERVAL] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CONFIG_EXP_MAID] = { + .type = NLA_BINARY, .len = CFM_MAID_LENGTH }, +}; + +static const struct nla_policy +br_cfm_cc_peer_mep_policy[IFLA_BRIDGE_CFM_CC_PEER_MEP_MAX + 1] = { + [IFLA_BRIDGE_CFM_CC_PEER_MEP_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_PEER_MEPID] = NLA_POLICY_MAX(NLA_U32, 0x1FFF), +}; + +static const struct nla_policy +br_cfm_cc_rdi_policy[IFLA_BRIDGE_CFM_CC_RDI_MAX + 1] = { + [IFLA_BRIDGE_CFM_CC_RDI_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_CC_RDI_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_RDI_RDI] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +br_cfm_cc_ccm_tx_policy[IFLA_BRIDGE_CFM_CC_CCM_TX_MAX + 1] = { + [IFLA_BRIDGE_CFM_CC_CCM_TX_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_INSTANCE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_DMAC] = NLA_POLICY_ETH_ADDR, + [IFLA_BRIDGE_CFM_CC_CCM_TX_SEQ_NO_UPDATE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_PERIOD] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV_VALUE] = { .type = NLA_U8 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV] = { .type = NLA_U32 }, + [IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV_VALUE] = { .type = NLA_U8 }, +}; + +static const struct nla_policy +br_cfm_policy[IFLA_BRIDGE_CFM_MAX + 1] = { + [IFLA_BRIDGE_CFM_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_CFM_MEP_CREATE] = + NLA_POLICY_NESTED(br_cfm_mep_create_policy), + [IFLA_BRIDGE_CFM_MEP_DELETE] = + NLA_POLICY_NESTED(br_cfm_mep_delete_policy), + [IFLA_BRIDGE_CFM_MEP_CONFIG] = + NLA_POLICY_NESTED(br_cfm_mep_config_policy), + [IFLA_BRIDGE_CFM_CC_CONFIG] = + NLA_POLICY_NESTED(br_cfm_cc_config_policy), + [IFLA_BRIDGE_CFM_CC_PEER_MEP_ADD] = + NLA_POLICY_NESTED(br_cfm_cc_peer_mep_policy), + [IFLA_BRIDGE_CFM_CC_PEER_MEP_REMOVE] = + NLA_POLICY_NESTED(br_cfm_cc_peer_mep_policy), + [IFLA_BRIDGE_CFM_CC_RDI] = + NLA_POLICY_NESTED(br_cfm_cc_rdi_policy), + [IFLA_BRIDGE_CFM_CC_CCM_TX] = + NLA_POLICY_NESTED(br_cfm_cc_ccm_tx_policy), +}; + +static int br_mep_create_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_MEP_CREATE_MAX + 1]; + struct br_cfm_mep_create create; + u32 instance; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_MEP_CREATE_MAX, attr, + br_cfm_mep_create_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_MEP_CREATE_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CREATE_DOMAIN]) { + NL_SET_ERR_MSG_MOD(extack, "Missing DOMAIN attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CREATE_DIRECTION]) { + NL_SET_ERR_MSG_MOD(extack, "Missing DIRECTION attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CREATE_IFINDEX]) { + NL_SET_ERR_MSG_MOD(extack, "Missing IFINDEX attribute"); + return -EINVAL; + } + + memset(&create, 0, sizeof(create)); + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CREATE_INSTANCE]); + create.domain = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CREATE_DOMAIN]); + create.direction = 
nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CREATE_DIRECTION]); + create.ifindex = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CREATE_IFINDEX]); + + return br_cfm_mep_create(br, instance, &create, extack); +} + +static int br_mep_delete_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_MEP_DELETE_MAX + 1]; + u32 instance; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_MEP_DELETE_MAX, attr, + br_cfm_mep_delete_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_MEP_DELETE_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing INSTANCE attribute"); + return -EINVAL; + } + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_DELETE_INSTANCE]); + + return br_cfm_mep_delete(br, instance, extack); +} + +static int br_mep_config_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_MEP_CONFIG_MAX + 1]; + struct br_cfm_mep_config config; + u32 instance; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_MEP_CONFIG_MAX, attr, + br_cfm_mep_config_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_MEP_CONFIG_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CONFIG_UNICAST_MAC]) { + NL_SET_ERR_MSG_MOD(extack, "Missing UNICAST_MAC attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CONFIG_MDLEVEL]) { + NL_SET_ERR_MSG_MOD(extack, "Missing MDLEVEL attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_MEP_CONFIG_MEPID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing MEPID attribute"); + return -EINVAL; + } + + memset(&config, 0, sizeof(config)); + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CONFIG_INSTANCE]); + nla_memcpy(&config.unicast_mac.addr, + tb[IFLA_BRIDGE_CFM_MEP_CONFIG_UNICAST_MAC], + sizeof(config.unicast_mac.addr)); + config.mdlevel = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CONFIG_MDLEVEL]); + config.mepid = nla_get_u32(tb[IFLA_BRIDGE_CFM_MEP_CONFIG_MEPID]); + + return br_cfm_mep_config_set(br, instance, &config, extack); +} + +static int br_cc_config_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_CC_CONFIG_MAX + 1]; + struct br_cfm_cc_config config; + u32 instance; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_CC_CONFIG_MAX, attr, + br_cfm_cc_config_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_CC_CONFIG_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CONFIG_ENABLE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing ENABLE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CONFIG_EXP_INTERVAL]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INTERVAL attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CONFIG_EXP_MAID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing MAID attribute"); + return -EINVAL; + } + + memset(&config, 0, sizeof(config)); + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CONFIG_INSTANCE]); + config.enable = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CONFIG_ENABLE]); + config.exp_interval = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CONFIG_EXP_INTERVAL]); + nla_memcpy(&config.exp_maid.data, tb[IFLA_BRIDGE_CFM_CC_CONFIG_EXP_MAID], + sizeof(config.exp_maid.data)); + + return br_cfm_cc_config_set(br, instance, &config, extack); +} + +static int br_cc_peer_mep_add_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + 
struct nlattr *tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_MAX + 1]; + u32 instance, peer_mep_id; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_CC_PEER_MEP_MAX, attr, + br_cfm_cc_peer_mep_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_PEER_MEPID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing PEER_MEP_ID attribute"); + return -EINVAL; + } + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE]); + peer_mep_id = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_PEER_MEPID]); + + return br_cfm_cc_peer_mep_add(br, instance, peer_mep_id, extack); +} + +static int br_cc_peer_mep_remove_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_MAX + 1]; + u32 instance, peer_mep_id; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_CC_PEER_MEP_MAX, attr, + br_cfm_cc_peer_mep_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_PEER_MEPID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing PEER_MEP_ID attribute"); + return -EINVAL; + } + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE]); + peer_mep_id = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_PEER_MEPID]); + + return br_cfm_cc_peer_mep_remove(br, instance, peer_mep_id, extack); +} + +static int br_cc_rdi_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_CC_RDI_MAX + 1]; + u32 instance, rdi; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_CC_RDI_MAX, attr, + br_cfm_cc_rdi_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_CC_RDI_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_RDI_RDI]) { + NL_SET_ERR_MSG_MOD(extack, "Missing RDI attribute"); + return -EINVAL; + } + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_RDI_INSTANCE]); + rdi = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_RDI_RDI]); + + return br_cfm_cc_rdi_set(br, instance, rdi, extack); +} + +static int br_cc_ccm_tx_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_CC_CCM_TX_MAX + 1]; + struct br_cfm_cc_ccm_tx_info tx_info; + u32 instance; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_CC_CCM_TX_MAX, attr, + br_cfm_cc_ccm_tx_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_INSTANCE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing INSTANCE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_DMAC]) { + NL_SET_ERR_MSG_MOD(extack, "Missing DMAC attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_SEQ_NO_UPDATE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing SEQ_NO_UPDATE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PERIOD]) { + NL_SET_ERR_MSG_MOD(extack, "Missing PERIOD attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV]) { + NL_SET_ERR_MSG_MOD(extack, "Missing IF_TLV attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV_VALUE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing IF_TLV_VALUE attribute"); + return -EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV]) { + NL_SET_ERR_MSG_MOD(extack, "Missing PORT_TLV attribute"); + return 
-EINVAL; + } + if (!tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV_VALUE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing PORT_TLV_VALUE attribute"); + return -EINVAL; + } + + memset(&tx_info, 0, sizeof(tx_info)); + + instance = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_RDI_INSTANCE]); + nla_memcpy(&tx_info.dmac.addr, + tb[IFLA_BRIDGE_CFM_CC_CCM_TX_DMAC], + sizeof(tx_info.dmac.addr)); + tx_info.seq_no_update = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_SEQ_NO_UPDATE]); + tx_info.period = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PERIOD]); + tx_info.if_tlv = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV]); + tx_info.if_tlv_value = nla_get_u8(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV_VALUE]); + tx_info.port_tlv = nla_get_u32(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV]); + tx_info.port_tlv_value = nla_get_u8(tb[IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV_VALUE]); + + return br_cfm_cc_ccm_tx(br, instance, &tx_info, extack); +} + +int br_cfm_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_CFM_MAX + 1]; + int err; + + /* When this function is called for a port then the br pointer is + * invalid, therefor set the br to point correctly + */ + if (p) + br = p->br; + + err = nla_parse_nested(tb, IFLA_BRIDGE_CFM_MAX, attr, + br_cfm_policy, extack); + if (err) + return err; + + if (tb[IFLA_BRIDGE_CFM_MEP_CREATE]) { + err = br_mep_create_parse(br, tb[IFLA_BRIDGE_CFM_MEP_CREATE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_MEP_DELETE]) { + err = br_mep_delete_parse(br, tb[IFLA_BRIDGE_CFM_MEP_DELETE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_MEP_CONFIG]) { + err = br_mep_config_parse(br, tb[IFLA_BRIDGE_CFM_MEP_CONFIG], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_CC_CONFIG]) { + err = br_cc_config_parse(br, tb[IFLA_BRIDGE_CFM_CC_CONFIG], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_ADD]) { + err = br_cc_peer_mep_add_parse(br, tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_ADD], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_REMOVE]) { + err = br_cc_peer_mep_remove_parse(br, tb[IFLA_BRIDGE_CFM_CC_PEER_MEP_REMOVE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_CC_RDI]) { + err = br_cc_rdi_parse(br, tb[IFLA_BRIDGE_CFM_CC_RDI], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_CFM_CC_CCM_TX]) { + err = br_cc_ccm_tx_parse(br, tb[IFLA_BRIDGE_CFM_CC_CCM_TX], + extack); + if (err) + return err; + } + + return 0; +} + +int br_cfm_config_fill_info(struct sk_buff *skb, struct net_bridge *br) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + struct nlattr *tb; + + hlist_for_each_entry_rcu(mep, &br->mep_list, head) { + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_MEP_CREATE_INFO); + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CREATE_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CREATE_DOMAIN, + mep->create.domain)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CREATE_DIRECTION, + mep->create.direction)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CREATE_IFINDEX, + mep->create.ifindex)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_MEP_CONFIG_INFO); + + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CONFIG_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put(skb, 
IFLA_BRIDGE_CFM_MEP_CONFIG_UNICAST_MAC, + sizeof(mep->config.unicast_mac.addr), + mep->config.unicast_mac.addr)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CONFIG_MDLEVEL, + mep->config.mdlevel)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_CONFIG_MEPID, + mep->config.mepid)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_CC_CONFIG_INFO); + + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CONFIG_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CONFIG_ENABLE, + mep->cc_config.enable)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CONFIG_EXP_INTERVAL, + mep->cc_config.exp_interval)) + goto nla_put_failure; + + if (nla_put(skb, IFLA_BRIDGE_CFM_CC_CONFIG_EXP_MAID, + sizeof(mep->cc_config.exp_maid.data), + mep->cc_config.exp_maid.data)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_CC_RDI_INFO); + + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_RDI_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_RDI_RDI, + mep->rdi)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_INFO); + + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_DMAC, + sizeof(mep->cc_ccm_tx_info.dmac), + mep->cc_ccm_tx_info.dmac.addr)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_SEQ_NO_UPDATE, + mep->cc_ccm_tx_info.seq_no_update)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_PERIOD, + mep->cc_ccm_tx_info.period)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV, + mep->cc_ccm_tx_info.if_tlv)) + goto nla_put_failure; + + if (nla_put_u8(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_IF_TLV_VALUE, + mep->cc_ccm_tx_info.if_tlv_value)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV, + mep->cc_ccm_tx_info.port_tlv)) + goto nla_put_failure; + + if (nla_put_u8(skb, IFLA_BRIDGE_CFM_CC_CCM_TX_PORT_TLV_VALUE, + mep->cc_ccm_tx_info.port_tlv_value)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + + hlist_for_each_entry_rcu(peer_mep, &mep->peer_mep_list, head) { + tb = nla_nest_start(skb, + IFLA_BRIDGE_CFM_CC_PEER_MEP_INFO); + + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_MEP_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_PEER_MEPID, + peer_mep->mepid)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + } + } + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, tb); + +nla_info_failure: + return -EMSGSIZE; +} + +int br_cfm_status_fill_info(struct sk_buff *skb, + struct net_bridge *br, + bool getlink) +{ + struct br_cfm_peer_mep *peer_mep; + struct br_cfm_mep *mep; + struct nlattr *tb; + + hlist_for_each_entry_rcu(mep, &br->mep_list, head) { + tb = nla_nest_start(skb, IFLA_BRIDGE_CFM_MEP_STATUS_INFO); + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_MEP_STATUS_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_MEP_STATUS_OPCODE_UNEXP_SEEN, + mep->status.opcode_unexp_seen)) + goto nla_put_failure; + + if (nla_put_u32(skb, + 
IFLA_BRIDGE_CFM_MEP_STATUS_VERSION_UNEXP_SEEN, + mep->status.version_unexp_seen)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_MEP_STATUS_RX_LEVEL_LOW_SEEN, + mep->status.rx_level_low_seen)) + goto nla_put_failure; + + /* Only clear if this is a GETLINK */ + if (getlink) { + /* Clear all 'seen' indications */ + mep->status.opcode_unexp_seen = false; + mep->status.version_unexp_seen = false; + mep->status.rx_level_low_seen = false; + } + + nla_nest_end(skb, tb); + + hlist_for_each_entry_rcu(peer_mep, &mep->peer_mep_list, head) { + tb = nla_nest_start(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_INFO); + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_INSTANCE, + mep->instance)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_PEER_MEPID, + peer_mep->mepid)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_CCM_DEFECT, + peer_mep->cc_status.ccm_defect)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_CFM_CC_PEER_STATUS_RDI, + peer_mep->cc_status.rdi)) + goto nla_put_failure; + + if (nla_put_u8(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_PORT_TLV_VALUE, + peer_mep->cc_status.port_tlv_value)) + goto nla_put_failure; + + if (nla_put_u8(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_IF_TLV_VALUE, + peer_mep->cc_status.if_tlv_value)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_SEEN, + peer_mep->cc_status.seen)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_TLV_SEEN, + peer_mep->cc_status.tlv_seen)) + goto nla_put_failure; + + if (nla_put_u32(skb, + IFLA_BRIDGE_CFM_CC_PEER_STATUS_SEQ_UNEXP_SEEN, + peer_mep->cc_status.seq_unexp_seen)) + goto nla_put_failure; + + if (getlink) { /* Only clear if this is a GETLINK */ + /* Clear all 'seen' indications */ + peer_mep->cc_status.seen = false; + peer_mep->cc_status.tlv_seen = false; + peer_mep->cc_status.seq_unexp_seen = false; + } + + nla_nest_end(skb, tb); + } + } + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, tb); + +nla_info_failure: + return -EMSGSIZE; +} diff --git a/net/bridge/br_device.c b/net/bridge/br_device.c new file mode 100644 index 000000000..b82906fc9 --- /dev/null +++ b/net/bridge/br_device.c @@ -0,0 +1,534 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Device handling code + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "br_private.h" + +#define COMMON_FEATURES (NETIF_F_SG | NETIF_F_FRAGLIST | NETIF_F_HIGHDMA | \ + NETIF_F_GSO_MASK | NETIF_F_HW_CSUM) + +const struct nf_br_ops __rcu *nf_br_ops __read_mostly; +EXPORT_SYMBOL_GPL(nf_br_ops); + +/* net device transmit always called with BH disabled */ +netdev_tx_t br_dev_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct net_bridge_mcast_port *pmctx_null = NULL; + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_mcast *brmctx = &br->multicast_ctx; + struct net_bridge_fdb_entry *dst; + struct net_bridge_mdb_entry *mdst; + const struct nf_br_ops *nf_ops; + u8 state = BR_STATE_FORWARDING; + struct net_bridge_vlan *vlan; + const unsigned char *dest; + u16 vid = 0; + + memset(skb->cb, 0, sizeof(struct br_input_skb_cb)); + + rcu_read_lock(); + nf_ops = rcu_dereference(nf_br_ops); + if (nf_ops && nf_ops->br_dev_xmit_hook(skb)) { + rcu_read_unlock(); + return NETDEV_TX_OK; + } + + dev_sw_netstats_tx_add(dev, 1, skb->len); + + 
br_switchdev_frame_unmark(skb); + BR_INPUT_SKB_CB(skb)->brdev = dev; + BR_INPUT_SKB_CB(skb)->frag_max_size = 0; + + skb_reset_mac_header(skb); + skb_pull(skb, ETH_HLEN); + + if (!br_allowed_ingress(br, br_vlan_group_rcu(br), skb, &vid, + &state, &vlan)) + goto out; + + if (IS_ENABLED(CONFIG_INET) && + (eth_hdr(skb)->h_proto == htons(ETH_P_ARP) || + eth_hdr(skb)->h_proto == htons(ETH_P_RARP)) && + br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED)) { + br_do_proxy_suppress_arp(skb, br, vid, NULL); + } else if (IS_ENABLED(CONFIG_IPV6) && + skb->protocol == htons(ETH_P_IPV6) && + br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED) && + pskb_may_pull(skb, sizeof(struct ipv6hdr) + + sizeof(struct nd_msg)) && + ipv6_hdr(skb)->nexthdr == IPPROTO_ICMPV6) { + struct nd_msg *msg, _msg; + + msg = br_is_nd_neigh_msg(skb, &_msg); + if (msg) + br_do_suppress_nd(skb, br, vid, NULL, msg); + } + + dest = eth_hdr(skb)->h_dest; + if (is_broadcast_ether_addr(dest)) { + br_flood(br, skb, BR_PKT_BROADCAST, false, true); + } else if (is_multicast_ether_addr(dest)) { + if (unlikely(netpoll_tx_running(dev))) { + br_flood(br, skb, BR_PKT_MULTICAST, false, true); + goto out; + } + if (br_multicast_rcv(&brmctx, &pmctx_null, vlan, skb, vid)) { + kfree_skb(skb); + goto out; + } + + mdst = br_mdb_get(brmctx, skb, vid); + if ((mdst || BR_INPUT_SKB_CB_MROUTERS_ONLY(skb)) && + br_multicast_querier_exists(brmctx, eth_hdr(skb), mdst)) + br_multicast_flood(mdst, skb, brmctx, false, true); + else + br_flood(br, skb, BR_PKT_MULTICAST, false, true); + } else if ((dst = br_fdb_find_rcu(br, dest, vid)) != NULL) { + br_forward(dst->dst, skb, false, true); + } else { + br_flood(br, skb, BR_PKT_UNICAST, false, true); + } +out: + rcu_read_unlock(); + return NETDEV_TX_OK; +} + +static struct lock_class_key bridge_netdev_addr_lock_key; + +static void br_set_lockdep_class(struct net_device *dev) +{ + lockdep_set_class(&dev->addr_list_lock, &bridge_netdev_addr_lock_key); +} + +static int br_dev_init(struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + int err; + + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + err = br_fdb_hash_init(br); + if (err) { + free_percpu(dev->tstats); + return err; + } + + err = br_mdb_hash_init(br); + if (err) { + free_percpu(dev->tstats); + br_fdb_hash_fini(br); + return err; + } + + err = br_vlan_init(br); + if (err) { + free_percpu(dev->tstats); + br_mdb_hash_fini(br); + br_fdb_hash_fini(br); + return err; + } + + err = br_multicast_init_stats(br); + if (err) { + free_percpu(dev->tstats); + br_vlan_flush(br); + br_mdb_hash_fini(br); + br_fdb_hash_fini(br); + } + + br_set_lockdep_class(dev); + return err; +} + +static void br_dev_uninit(struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + + br_multicast_dev_del(br); + br_multicast_uninit_stats(br); + br_vlan_flush(br); + br_mdb_hash_fini(br); + br_fdb_hash_fini(br); + free_percpu(dev->tstats); +} + +static int br_dev_open(struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + + netdev_update_features(dev); + netif_start_queue(dev); + br_stp_enable_bridge(br); + br_multicast_open(br); + + if (br_opt_get(br, BROPT_MULTICAST_ENABLED)) + br_multicast_join_snoopers(br); + + return 0; +} + +static void br_dev_set_multicast_list(struct net_device *dev) +{ +} + +static void br_dev_change_rx_flags(struct net_device *dev, int change) +{ + if (change & IFF_PROMISC) + br_manage_promisc(netdev_priv(dev)); +} + +static int br_dev_stop(struct net_device *dev) +{ + struct 
net_bridge *br = netdev_priv(dev); + + br_stp_disable_bridge(br); + br_multicast_stop(br); + + if (br_opt_get(br, BROPT_MULTICAST_ENABLED)) + br_multicast_leave_snoopers(br); + + netif_stop_queue(dev); + + return 0; +} + +static int br_change_mtu(struct net_device *dev, int new_mtu) +{ + struct net_bridge *br = netdev_priv(dev); + + dev->mtu = new_mtu; + + /* this flag will be cleared if the MTU was automatically adjusted */ + br_opt_toggle(br, BROPT_MTU_SET_BY_USER, true); +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + /* remember the MTU in the rtable for PMTU */ + dst_metric_set(&br->fake_rtable.dst, RTAX_MTU, new_mtu); +#endif + + return 0; +} + +/* Allow setting mac address to any valid ethernet address. */ +static int br_set_mac_address(struct net_device *dev, void *p) +{ + struct net_bridge *br = netdev_priv(dev); + struct sockaddr *addr = p; + + if (!is_valid_ether_addr(addr->sa_data)) + return -EADDRNOTAVAIL; + + /* dev_set_mac_addr() can be called by a master device on bridge's + * NETDEV_UNREGISTER, but since it's being destroyed do nothing + */ + if (dev->reg_state != NETREG_REGISTERED) + return -EBUSY; + + spin_lock_bh(&br->lock); + if (!ether_addr_equal(dev->dev_addr, addr->sa_data)) { + /* Mac address will be changed in br_stp_change_bridge_id(). */ + br_stp_change_bridge_id(br, addr->sa_data); + } + spin_unlock_bh(&br->lock); + + return 0; +} + +static void br_getinfo(struct net_device *dev, struct ethtool_drvinfo *info) +{ + strscpy(info->driver, "bridge", sizeof(info->driver)); + strscpy(info->version, BR_VERSION, sizeof(info->version)); + strscpy(info->fw_version, "N/A", sizeof(info->fw_version)); + strscpy(info->bus_info, "N/A", sizeof(info->bus_info)); +} + +static int br_get_link_ksettings(struct net_device *dev, + struct ethtool_link_ksettings *cmd) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_port *p; + + cmd->base.duplex = DUPLEX_UNKNOWN; + cmd->base.port = PORT_OTHER; + cmd->base.speed = SPEED_UNKNOWN; + + list_for_each_entry(p, &br->port_list, list) { + struct ethtool_link_ksettings ecmd; + struct net_device *pdev = p->dev; + + if (!netif_running(pdev) || !netif_oper_up(pdev)) + continue; + + if (__ethtool_get_link_ksettings(pdev, &ecmd)) + continue; + + if (ecmd.base.speed == (__u32)SPEED_UNKNOWN) + continue; + + if (cmd->base.speed == (__u32)SPEED_UNKNOWN || + cmd->base.speed < ecmd.base.speed) + cmd->base.speed = ecmd.base.speed; + } + + return 0; +} + +static netdev_features_t br_fix_features(struct net_device *dev, + netdev_features_t features) +{ + struct net_bridge *br = netdev_priv(dev); + + return br_features_recompute(br, features); +} + +#ifdef CONFIG_NET_POLL_CONTROLLER +static void br_poll_controller(struct net_device *br_dev) +{ +} + +static void br_netpoll_cleanup(struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) + br_netpoll_disable(p); +} + +static int __br_netpoll_enable(struct net_bridge_port *p) +{ + struct netpoll *np; + int err; + + np = kzalloc(sizeof(*p->np), GFP_KERNEL); + if (!np) + return -ENOMEM; + + err = __netpoll_setup(np, p->dev); + if (err) { + kfree(np); + return err; + } + + p->np = np; + return err; +} + +int br_netpoll_enable(struct net_bridge_port *p) +{ + if (!p->br->dev->npinfo) + return 0; + + return __br_netpoll_enable(p); +} + +static int br_netpoll_setup(struct net_device *dev, struct netpoll_info *ni) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_port *p; + int err = 0; + + 
list_for_each_entry(p, &br->port_list, list) { + if (!p->dev) + continue; + err = __br_netpoll_enable(p); + if (err) + goto fail; + } + +out: + return err; + +fail: + br_netpoll_cleanup(dev); + goto out; +} + +void br_netpoll_disable(struct net_bridge_port *p) +{ + struct netpoll *np = p->np; + + if (!np) + return; + + p->np = NULL; + + __netpoll_free(np); +} + +#endif + +static int br_add_slave(struct net_device *dev, struct net_device *slave_dev, + struct netlink_ext_ack *extack) + +{ + struct net_bridge *br = netdev_priv(dev); + + return br_add_if(br, slave_dev, extack); +} + +static int br_del_slave(struct net_device *dev, struct net_device *slave_dev) +{ + struct net_bridge *br = netdev_priv(dev); + + return br_del_if(br, slave_dev); +} + +static int br_fill_forward_path(struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + struct net_bridge_fdb_entry *f; + struct net_bridge_port *dst; + struct net_bridge *br; + + if (netif_is_bridge_port(ctx->dev)) + return -1; + + br = netdev_priv(ctx->dev); + + br_vlan_fill_forward_path_pvid(br, ctx, path); + + f = br_fdb_find_rcu(br, ctx->daddr, path->bridge.vlan_id); + if (!f || !f->dst) + return -1; + + dst = READ_ONCE(f->dst); + if (!dst) + return -1; + + if (br_vlan_fill_forward_path_mode(br, dst, path)) + return -1; + + path->type = DEV_PATH_BRIDGE; + path->dev = dst->br->dev; + ctx->dev = dst->dev; + + switch (path->bridge.vlan_mode) { + case DEV_PATH_BR_VLAN_TAG: + if (ctx->num_vlans >= ARRAY_SIZE(ctx->vlan)) + return -ENOSPC; + ctx->vlan[ctx->num_vlans].id = path->bridge.vlan_id; + ctx->vlan[ctx->num_vlans].proto = path->bridge.vlan_proto; + ctx->num_vlans++; + break; + case DEV_PATH_BR_VLAN_UNTAG_HW: + case DEV_PATH_BR_VLAN_UNTAG: + ctx->num_vlans--; + break; + case DEV_PATH_BR_VLAN_KEEP: + break; + } + + return 0; +} + +static const struct ethtool_ops br_ethtool_ops = { + .get_drvinfo = br_getinfo, + .get_link = ethtool_op_get_link, + .get_link_ksettings = br_get_link_ksettings, +}; + +static const struct net_device_ops br_netdev_ops = { + .ndo_open = br_dev_open, + .ndo_stop = br_dev_stop, + .ndo_init = br_dev_init, + .ndo_uninit = br_dev_uninit, + .ndo_start_xmit = br_dev_xmit, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_set_mac_address = br_set_mac_address, + .ndo_set_rx_mode = br_dev_set_multicast_list, + .ndo_change_rx_flags = br_dev_change_rx_flags, + .ndo_change_mtu = br_change_mtu, + .ndo_siocdevprivate = br_dev_siocdevprivate, +#ifdef CONFIG_NET_POLL_CONTROLLER + .ndo_netpoll_setup = br_netpoll_setup, + .ndo_netpoll_cleanup = br_netpoll_cleanup, + .ndo_poll_controller = br_poll_controller, +#endif + .ndo_add_slave = br_add_slave, + .ndo_del_slave = br_del_slave, + .ndo_fix_features = br_fix_features, + .ndo_fdb_add = br_fdb_add, + .ndo_fdb_del = br_fdb_delete, + .ndo_fdb_del_bulk = br_fdb_delete_bulk, + .ndo_fdb_dump = br_fdb_dump, + .ndo_fdb_get = br_fdb_get, + .ndo_bridge_getlink = br_getlink, + .ndo_bridge_setlink = br_setlink, + .ndo_bridge_dellink = br_dellink, + .ndo_features_check = passthru_features_check, + .ndo_fill_forward_path = br_fill_forward_path, +}; + +static struct device_type br_type = { + .name = "bridge", +}; + +void br_dev_setup(struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + + eth_hw_addr_random(dev); + ether_setup(dev); + + dev->netdev_ops = &br_netdev_ops; + dev->needs_free_netdev = true; + dev->ethtool_ops = &br_ethtool_ops; + SET_NETDEV_DEVTYPE(dev, &br_type); + dev->priv_flags = IFF_EBRIDGE | IFF_NO_QUEUE; + + dev->features = COMMON_FEATURES | NETIF_F_LLTX | 
NETIF_F_NETNS_LOCAL | + NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_HW_VLAN_STAG_TX; + dev->hw_features = COMMON_FEATURES | NETIF_F_HW_VLAN_CTAG_TX | + NETIF_F_HW_VLAN_STAG_TX; + dev->vlan_features = COMMON_FEATURES; + + br->dev = dev; + spin_lock_init(&br->lock); + INIT_LIST_HEAD(&br->port_list); + INIT_HLIST_HEAD(&br->fdb_list); + INIT_HLIST_HEAD(&br->frame_type_list); +#if IS_ENABLED(CONFIG_BRIDGE_MRP) + INIT_HLIST_HEAD(&br->mrp_list); +#endif +#if IS_ENABLED(CONFIG_BRIDGE_CFM) + INIT_HLIST_HEAD(&br->mep_list); +#endif + spin_lock_init(&br->hash_lock); + + br->bridge_id.prio[0] = 0x80; + br->bridge_id.prio[1] = 0x00; + + ether_addr_copy(br->group_addr, eth_stp_addr); + + br->stp_enabled = BR_NO_STP; + br->group_fwd_mask = BR_GROUPFWD_DEFAULT; + br->group_fwd_mask_required = BR_GROUPFWD_DEFAULT; + + br->designated_root = br->bridge_id; + br->bridge_max_age = br->max_age = 20 * HZ; + br->bridge_hello_time = br->hello_time = 2 * HZ; + br->bridge_forward_delay = br->forward_delay = 15 * HZ; + br->bridge_ageing_time = br->ageing_time = BR_DEFAULT_AGEING_TIME; + dev->max_mtu = ETH_MAX_MTU; + + br_netfilter_rtable_init(br); + br_stp_timer_init(br); + br_multicast_init(br); + INIT_DELAYED_WORK(&br->gc_work, br_fdb_cleanup); +} diff --git a/net/bridge/br_fdb.c b/net/bridge/br_fdb.c new file mode 100644 index 000000000..e7f4fccb6 --- /dev/null +++ b/net/bridge/br_fdb.c @@ -0,0 +1,1468 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Forwarding database + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "br_private.h" + +static const struct rhashtable_params br_fdb_rht_params = { + .head_offset = offsetof(struct net_bridge_fdb_entry, rhnode), + .key_offset = offsetof(struct net_bridge_fdb_entry, key), + .key_len = sizeof(struct net_bridge_fdb_key), + .automatic_shrinking = true, +}; + +static struct kmem_cache *br_fdb_cache __read_mostly; + +int __init br_fdb_init(void) +{ + br_fdb_cache = kmem_cache_create("bridge_fdb_cache", + sizeof(struct net_bridge_fdb_entry), + 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!br_fdb_cache) + return -ENOMEM; + + return 0; +} + +void br_fdb_fini(void) +{ + kmem_cache_destroy(br_fdb_cache); +} + +int br_fdb_hash_init(struct net_bridge *br) +{ + return rhashtable_init(&br->fdb_hash_tbl, &br_fdb_rht_params); +} + +void br_fdb_hash_fini(struct net_bridge *br) +{ + rhashtable_destroy(&br->fdb_hash_tbl); +} + +/* if topology_changing then use forward_delay (default 15 sec) + * otherwise keep longer (default 5 minutes) + */ +static inline unsigned long hold_time(const struct net_bridge *br) +{ + return br->topology_change ? 
br->forward_delay : br->ageing_time; +} + +static inline int has_expired(const struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb) +{ + return !test_bit(BR_FDB_STATIC, &fdb->flags) && + !test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags) && + time_before_eq(fdb->updated + hold_time(br), jiffies); +} + +static void fdb_rcu_free(struct rcu_head *head) +{ + struct net_bridge_fdb_entry *ent + = container_of(head, struct net_bridge_fdb_entry, rcu); + kmem_cache_free(br_fdb_cache, ent); +} + +static int fdb_to_nud(const struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb) +{ + if (test_bit(BR_FDB_LOCAL, &fdb->flags)) + return NUD_PERMANENT; + else if (test_bit(BR_FDB_STATIC, &fdb->flags)) + return NUD_NOARP; + else if (has_expired(br, fdb)) + return NUD_STALE; + else + return NUD_REACHABLE; +} + +static int fdb_fill_info(struct sk_buff *skb, const struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb, + u32 portid, u32 seq, int type, unsigned int flags) +{ + const struct net_bridge_port *dst = READ_ONCE(fdb->dst); + unsigned long now = jiffies; + struct nda_cacheinfo ci; + struct nlmsghdr *nlh; + struct ndmsg *ndm; + + nlh = nlmsg_put(skb, portid, seq, type, sizeof(*ndm), flags); + if (nlh == NULL) + return -EMSGSIZE; + + ndm = nlmsg_data(nlh); + ndm->ndm_family = AF_BRIDGE; + ndm->ndm_pad1 = 0; + ndm->ndm_pad2 = 0; + ndm->ndm_flags = 0; + ndm->ndm_type = 0; + ndm->ndm_ifindex = dst ? dst->dev->ifindex : br->dev->ifindex; + ndm->ndm_state = fdb_to_nud(br, fdb); + + if (test_bit(BR_FDB_OFFLOADED, &fdb->flags)) + ndm->ndm_flags |= NTF_OFFLOADED; + if (test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags)) + ndm->ndm_flags |= NTF_EXT_LEARNED; + if (test_bit(BR_FDB_STICKY, &fdb->flags)) + ndm->ndm_flags |= NTF_STICKY; + + if (nla_put(skb, NDA_LLADDR, ETH_ALEN, &fdb->key.addr)) + goto nla_put_failure; + if (nla_put_u32(skb, NDA_MASTER, br->dev->ifindex)) + goto nla_put_failure; + ci.ndm_used = jiffies_to_clock_t(now - fdb->used); + ci.ndm_confirmed = 0; + ci.ndm_updated = jiffies_to_clock_t(now - fdb->updated); + ci.ndm_refcnt = 0; + if (nla_put(skb, NDA_CACHEINFO, sizeof(ci), &ci)) + goto nla_put_failure; + + if (fdb->key.vlan_id && nla_put(skb, NDA_VLAN, sizeof(u16), + &fdb->key.vlan_id)) + goto nla_put_failure; + + if (test_bit(BR_FDB_NOTIFY, &fdb->flags)) { + struct nlattr *nest = nla_nest_start(skb, NDA_FDB_EXT_ATTRS); + u8 notify_bits = FDB_NOTIFY_BIT; + + if (!nest) + goto nla_put_failure; + if (test_bit(BR_FDB_NOTIFY_INACTIVE, &fdb->flags)) + notify_bits |= FDB_NOTIFY_INACTIVE_BIT; + + if (nla_put_u8(skb, NFEA_ACTIVITY_NOTIFY, notify_bits)) { + nla_nest_cancel(skb, nest); + goto nla_put_failure; + } + + nla_nest_end(skb, nest); + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static inline size_t fdb_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ndmsg)) + + nla_total_size(ETH_ALEN) /* NDA_LLADDR */ + + nla_total_size(sizeof(u32)) /* NDA_MASTER */ + + nla_total_size(sizeof(u16)) /* NDA_VLAN */ + + nla_total_size(sizeof(struct nda_cacheinfo)) + + nla_total_size(0) /* NDA_FDB_EXT_ATTRS */ + + nla_total_size(sizeof(u8)); /* NFEA_ACTIVITY_NOTIFY */ +} + +static void fdb_notify(struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb, int type, + bool swdev_notify) +{ + struct net *net = dev_net(br->dev); + struct sk_buff *skb; + int err = -ENOBUFS; + + if (swdev_notify) + br_switchdev_fdb_notify(br, fdb, type); + + skb = nlmsg_new(fdb_nlmsg_size(), GFP_ATOMIC); + if (skb == NULL) + goto 
errout; + + err = fdb_fill_info(skb, br, fdb, 0, 0, type, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in fdb_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_NEIGH, NULL, GFP_ATOMIC); + return; +errout: + rtnl_set_sk_err(net, RTNLGRP_NEIGH, err); +} + +static struct net_bridge_fdb_entry *fdb_find_rcu(struct rhashtable *tbl, + const unsigned char *addr, + __u16 vid) +{ + struct net_bridge_fdb_key key; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + key.vlan_id = vid; + memcpy(key.addr.addr, addr, sizeof(key.addr.addr)); + + return rhashtable_lookup(tbl, &key, br_fdb_rht_params); +} + +/* requires bridge hash_lock */ +static struct net_bridge_fdb_entry *br_fdb_find(struct net_bridge *br, + const unsigned char *addr, + __u16 vid) +{ + struct net_bridge_fdb_entry *fdb; + + lockdep_assert_held_once(&br->hash_lock); + + rcu_read_lock(); + fdb = fdb_find_rcu(&br->fdb_hash_tbl, addr, vid); + rcu_read_unlock(); + + return fdb; +} + +struct net_device *br_fdb_find_port(const struct net_device *br_dev, + const unsigned char *addr, + __u16 vid) +{ + struct net_bridge_fdb_entry *f; + struct net_device *dev = NULL; + struct net_bridge *br; + + ASSERT_RTNL(); + + if (!netif_is_bridge_master(br_dev)) + return NULL; + + br = netdev_priv(br_dev); + rcu_read_lock(); + f = br_fdb_find_rcu(br, addr, vid); + if (f && f->dst) + dev = f->dst->dev; + rcu_read_unlock(); + + return dev; +} +EXPORT_SYMBOL_GPL(br_fdb_find_port); + +struct net_bridge_fdb_entry *br_fdb_find_rcu(struct net_bridge *br, + const unsigned char *addr, + __u16 vid) +{ + return fdb_find_rcu(&br->fdb_hash_tbl, addr, vid); +} + +/* When a static FDB entry is added, the mac address from the entry is + * added to the bridge private HW address list and all required ports + * are then updated with the new information. + * Called under RTNL. + */ +static void fdb_add_hw_addr(struct net_bridge *br, const unsigned char *addr) +{ + int err; + struct net_bridge_port *p; + + ASSERT_RTNL(); + + list_for_each_entry(p, &br->port_list, list) { + if (!br_promisc_port(p)) { + err = dev_uc_add(p->dev, addr); + if (err) + goto undo; + } + } + + return; +undo: + list_for_each_entry_continue_reverse(p, &br->port_list, list) { + if (!br_promisc_port(p)) + dev_uc_del(p->dev, addr); + } +} + +/* When a static FDB entry is deleted, the HW address from that entry is + * also removed from the bridge private HW address list and updates all + * the ports with needed information. + * Called under RTNL. + */ +static void fdb_del_hw_addr(struct net_bridge *br, const unsigned char *addr) +{ + struct net_bridge_port *p; + + ASSERT_RTNL(); + + list_for_each_entry(p, &br->port_list, list) { + if (!br_promisc_port(p)) + dev_uc_del(p->dev, addr); + } +} + +static void fdb_delete(struct net_bridge *br, struct net_bridge_fdb_entry *f, + bool swdev_notify) +{ + trace_fdb_delete(br, f); + + if (test_bit(BR_FDB_STATIC, &f->flags)) + fdb_del_hw_addr(br, f->key.addr.addr); + + hlist_del_init_rcu(&f->fdb_node); + rhashtable_remove_fast(&br->fdb_hash_tbl, &f->rhnode, + br_fdb_rht_params); + fdb_notify(br, f, RTM_DELNEIGH, swdev_notify); + call_rcu(&f->rcu, fdb_rcu_free); +} + +/* Delete a local entry if no other port had the same address. 
*/ +static void fdb_delete_local(struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_fdb_entry *f) +{ + const unsigned char *addr = f->key.addr.addr; + struct net_bridge_vlan_group *vg; + const struct net_bridge_vlan *v; + struct net_bridge_port *op; + u16 vid = f->key.vlan_id; + + /* Maybe another port has same hw addr? */ + list_for_each_entry(op, &br->port_list, list) { + vg = nbp_vlan_group(op); + if (op != p && ether_addr_equal(op->dev->dev_addr, addr) && + (!vid || br_vlan_find(vg, vid))) { + f->dst = op; + clear_bit(BR_FDB_ADDED_BY_USER, &f->flags); + return; + } + } + + vg = br_vlan_group(br); + v = br_vlan_find(vg, vid); + /* Maybe bridge device has same hw addr? */ + if (p && ether_addr_equal(br->dev->dev_addr, addr) && + (!vid || (v && br_vlan_should_use(v)))) { + f->dst = NULL; + clear_bit(BR_FDB_ADDED_BY_USER, &f->flags); + return; + } + + fdb_delete(br, f, true); +} + +void br_fdb_find_delete_local(struct net_bridge *br, + const struct net_bridge_port *p, + const unsigned char *addr, u16 vid) +{ + struct net_bridge_fdb_entry *f; + + spin_lock_bh(&br->hash_lock); + f = br_fdb_find(br, addr, vid); + if (f && test_bit(BR_FDB_LOCAL, &f->flags) && + !test_bit(BR_FDB_ADDED_BY_USER, &f->flags) && f->dst == p) + fdb_delete_local(br, p, f); + spin_unlock_bh(&br->hash_lock); +} + +static struct net_bridge_fdb_entry *fdb_create(struct net_bridge *br, + struct net_bridge_port *source, + const unsigned char *addr, + __u16 vid, + unsigned long flags) +{ + struct net_bridge_fdb_entry *fdb; + int err; + + fdb = kmem_cache_alloc(br_fdb_cache, GFP_ATOMIC); + if (!fdb) + return NULL; + + memcpy(fdb->key.addr.addr, addr, ETH_ALEN); + WRITE_ONCE(fdb->dst, source); + fdb->key.vlan_id = vid; + fdb->flags = flags; + fdb->updated = fdb->used = jiffies; + err = rhashtable_lookup_insert_fast(&br->fdb_hash_tbl, &fdb->rhnode, + br_fdb_rht_params); + if (err) { + kmem_cache_free(br_fdb_cache, fdb); + return NULL; + } + + hlist_add_head_rcu(&fdb->fdb_node, &br->fdb_list); + + return fdb; +} + +static int fdb_add_local(struct net_bridge *br, struct net_bridge_port *source, + const unsigned char *addr, u16 vid) +{ + struct net_bridge_fdb_entry *fdb; + + if (!is_valid_ether_addr(addr)) + return -EINVAL; + + fdb = br_fdb_find(br, addr, vid); + if (fdb) { + /* it is okay to have multiple ports with same + * address, just use the first one. + */ + if (test_bit(BR_FDB_LOCAL, &fdb->flags)) + return 0; + br_warn(br, "adding interface %s with same address as a received packet (addr:%pM, vlan:%u)\n", + source ? source->dev->name : br->dev->name, addr, vid); + fdb_delete(br, fdb, true); + } + + fdb = fdb_create(br, source, addr, vid, + BIT(BR_FDB_LOCAL) | BIT(BR_FDB_STATIC)); + if (!fdb) + return -ENOMEM; + + fdb_add_hw_addr(br, addr); + fdb_notify(br, fdb, RTM_NEWNEIGH, true); + return 0; +} + +void br_fdb_changeaddr(struct net_bridge_port *p, const unsigned char *newaddr) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_fdb_entry *f; + struct net_bridge *br = p->br; + struct net_bridge_vlan *v; + + spin_lock_bh(&br->hash_lock); + vg = nbp_vlan_group(p); + hlist_for_each_entry(f, &br->fdb_list, fdb_node) { + if (f->dst == p && test_bit(BR_FDB_LOCAL, &f->flags) && + !test_bit(BR_FDB_ADDED_BY_USER, &f->flags)) { + /* delete old one */ + fdb_delete_local(br, p, f); + + /* if this port has no vlan information + * configured, we can safely be done at + * this point. 
+ */ + if (!vg || !vg->num_vlans) + goto insert; + } + } + +insert: + /* insert new address, may fail if invalid address or dup. */ + fdb_add_local(br, p, newaddr, 0); + + if (!vg || !vg->num_vlans) + goto done; + + /* Now add entries for every VLAN configured on the port. + * This function runs under RTNL so the bitmap will not change + * from under us. + */ + list_for_each_entry(v, &vg->vlan_list, vlist) + fdb_add_local(br, p, newaddr, v->vid); + +done: + spin_unlock_bh(&br->hash_lock); +} + +void br_fdb_change_mac_address(struct net_bridge *br, const u8 *newaddr) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_fdb_entry *f; + struct net_bridge_vlan *v; + + spin_lock_bh(&br->hash_lock); + + /* If old entry was unassociated with any port, then delete it. */ + f = br_fdb_find(br, br->dev->dev_addr, 0); + if (f && test_bit(BR_FDB_LOCAL, &f->flags) && + !f->dst && !test_bit(BR_FDB_ADDED_BY_USER, &f->flags)) + fdb_delete_local(br, NULL, f); + + fdb_add_local(br, NULL, newaddr, 0); + vg = br_vlan_group(br); + if (!vg || !vg->num_vlans) + goto out; + /* Now remove and add entries for every VLAN configured on the + * bridge. This function runs under RTNL so the bitmap will not + * change from under us. + */ + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (!br_vlan_should_use(v)) + continue; + f = br_fdb_find(br, br->dev->dev_addr, v->vid); + if (f && test_bit(BR_FDB_LOCAL, &f->flags) && + !f->dst && !test_bit(BR_FDB_ADDED_BY_USER, &f->flags)) + fdb_delete_local(br, NULL, f); + fdb_add_local(br, NULL, newaddr, v->vid); + } +out: + spin_unlock_bh(&br->hash_lock); +} + +void br_fdb_cleanup(struct work_struct *work) +{ + struct net_bridge *br = container_of(work, struct net_bridge, + gc_work.work); + struct net_bridge_fdb_entry *f = NULL; + unsigned long delay = hold_time(br); + unsigned long work_delay = delay; + unsigned long now = jiffies; + + /* this part is tricky, in order to avoid blocking learning and + * consequently forwarding, we rely on rcu to delete objects with + * delayed freeing allowing us to continue traversing + */ + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + unsigned long this_timer = f->updated + delay; + + if (test_bit(BR_FDB_STATIC, &f->flags) || + test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &f->flags)) { + if (test_bit(BR_FDB_NOTIFY, &f->flags)) { + if (time_after(this_timer, now)) + work_delay = min(work_delay, + this_timer - now); + else if (!test_and_set_bit(BR_FDB_NOTIFY_INACTIVE, + &f->flags)) + fdb_notify(br, f, RTM_NEWNEIGH, false); + } + continue; + } + + if (time_after(this_timer, now)) { + work_delay = min(work_delay, this_timer - now); + } else { + spin_lock_bh(&br->hash_lock); + if (!hlist_unhashed(&f->fdb_node)) + fdb_delete(br, f, true); + spin_unlock_bh(&br->hash_lock); + } + } + rcu_read_unlock(); + + /* Cleanup minimum 10 milliseconds apart */ + work_delay = max_t(unsigned long, work_delay, msecs_to_jiffies(10)); + mod_delayed_work(system_long_wq, &br->gc_work, work_delay); +} + +static bool __fdb_flush_matches(const struct net_bridge *br, + const struct net_bridge_fdb_entry *f, + const struct net_bridge_fdb_flush_desc *desc) +{ + const struct net_bridge_port *dst = READ_ONCE(f->dst); + int port_ifidx = dst ? 
dst->dev->ifindex : br->dev->ifindex; + + if (desc->vlan_id && desc->vlan_id != f->key.vlan_id) + return false; + if (desc->port_ifindex && desc->port_ifindex != port_ifidx) + return false; + if (desc->flags_mask && (f->flags & desc->flags_mask) != desc->flags) + return false; + + return true; +} + +/* Flush forwarding database entries matching the description */ +void br_fdb_flush(struct net_bridge *br, + const struct net_bridge_fdb_flush_desc *desc) +{ + struct net_bridge_fdb_entry *f; + + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + if (!__fdb_flush_matches(br, f, desc)) + continue; + + spin_lock_bh(&br->hash_lock); + if (!hlist_unhashed(&f->fdb_node)) + fdb_delete(br, f, true); + spin_unlock_bh(&br->hash_lock); + } + rcu_read_unlock(); +} + +static unsigned long __ndm_state_to_fdb_flags(u16 ndm_state) +{ + unsigned long flags = 0; + + if (ndm_state & NUD_PERMANENT) + __set_bit(BR_FDB_LOCAL, &flags); + if (ndm_state & NUD_NOARP) + __set_bit(BR_FDB_STATIC, &flags); + + return flags; +} + +static unsigned long __ndm_flags_to_fdb_flags(u8 ndm_flags) +{ + unsigned long flags = 0; + + if (ndm_flags & NTF_USE) + __set_bit(BR_FDB_ADDED_BY_USER, &flags); + if (ndm_flags & NTF_EXT_LEARNED) + __set_bit(BR_FDB_ADDED_BY_EXT_LEARN, &flags); + if (ndm_flags & NTF_OFFLOADED) + __set_bit(BR_FDB_OFFLOADED, &flags); + if (ndm_flags & NTF_STICKY) + __set_bit(BR_FDB_STICKY, &flags); + + return flags; +} + +static int __fdb_flush_validate_ifindex(const struct net_bridge *br, + int ifindex, + struct netlink_ext_ack *extack) +{ + const struct net_device *dev; + + dev = __dev_get_by_index(dev_net(br->dev), ifindex); + if (!dev) { + NL_SET_ERR_MSG_MOD(extack, "Unknown flush device ifindex"); + return -ENODEV; + } + if (!netif_is_bridge_master(dev) && !netif_is_bridge_port(dev)) { + NL_SET_ERR_MSG_MOD(extack, "Flush device is not a bridge or bridge port"); + return -EINVAL; + } + if (netif_is_bridge_master(dev) && dev != br->dev) { + NL_SET_ERR_MSG_MOD(extack, + "Flush bridge device does not match target bridge device"); + return -EINVAL; + } + if (netif_is_bridge_port(dev)) { + struct net_bridge_port *p = br_port_get_rtnl(dev); + + if (p->br != br) { + NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device"); + return -EINVAL; + } + } + + return 0; +} + +int br_fdb_delete_bulk(struct ndmsg *ndm, struct nlattr *tb[], + struct net_device *dev, u16 vid, + struct netlink_ext_ack *extack) +{ + u8 ndm_flags = ndm->ndm_flags & ~FDB_FLUSH_IGNORED_NDM_FLAGS; + struct net_bridge_fdb_flush_desc desc = { .vlan_id = vid }; + struct net_bridge_port *p = NULL; + struct net_bridge *br; + + if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + } else { + p = br_port_get_rtnl(dev); + if (!p) { + NL_SET_ERR_MSG_MOD(extack, "Device is not a bridge port"); + return -EINVAL; + } + br = p->br; + } + + if (ndm_flags & ~FDB_FLUSH_ALLOWED_NDM_FLAGS) { + NL_SET_ERR_MSG(extack, "Unsupported fdb flush ndm flag bits set"); + return -EINVAL; + } + if (ndm->ndm_state & ~FDB_FLUSH_ALLOWED_NDM_STATES) { + NL_SET_ERR_MSG(extack, "Unsupported fdb flush ndm state bits set"); + return -EINVAL; + } + + desc.flags |= __ndm_state_to_fdb_flags(ndm->ndm_state); + desc.flags |= __ndm_flags_to_fdb_flags(ndm_flags); + if (tb[NDA_NDM_STATE_MASK]) { + u16 ndm_state_mask = nla_get_u16(tb[NDA_NDM_STATE_MASK]); + + desc.flags_mask |= __ndm_state_to_fdb_flags(ndm_state_mask); + } + if (tb[NDA_NDM_FLAGS_MASK]) { + u8 ndm_flags_mask = nla_get_u8(tb[NDA_NDM_FLAGS_MASK]); + + desc.flags_mask |= 
__ndm_flags_to_fdb_flags(ndm_flags_mask); + } + if (tb[NDA_IFINDEX]) { + int err, ifidx = nla_get_s32(tb[NDA_IFINDEX]); + + err = __fdb_flush_validate_ifindex(br, ifidx, extack); + if (err) + return err; + desc.port_ifindex = ifidx; + } else if (p) { + /* flush was invoked with port device and NTF_MASTER */ + desc.port_ifindex = p->dev->ifindex; + } + + br_debug(br, "flushing port ifindex: %d vlan id: %u flags: 0x%lx flags mask: 0x%lx\n", + desc.port_ifindex, desc.vlan_id, desc.flags, desc.flags_mask); + + br_fdb_flush(br, &desc); + + return 0; +} + +/* Flush all entries referring to a specific port. + * if do_all is set also flush static entries + * if vid is set delete all entries that match the vlan_id + */ +void br_fdb_delete_by_port(struct net_bridge *br, + const struct net_bridge_port *p, + u16 vid, + int do_all) +{ + struct net_bridge_fdb_entry *f; + struct hlist_node *tmp; + + spin_lock_bh(&br->hash_lock); + hlist_for_each_entry_safe(f, tmp, &br->fdb_list, fdb_node) { + if (f->dst != p) + continue; + + if (!do_all) + if (test_bit(BR_FDB_STATIC, &f->flags) || + (test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &f->flags) && + !test_bit(BR_FDB_OFFLOADED, &f->flags)) || + (vid && f->key.vlan_id != vid)) + continue; + + if (test_bit(BR_FDB_LOCAL, &f->flags)) + fdb_delete_local(br, p, f); + else + fdb_delete(br, f, true); + } + spin_unlock_bh(&br->hash_lock); +} + +#if IS_ENABLED(CONFIG_ATM_LANE) +/* Interface used by ATM LANE hook to test + * if an addr is on some other bridge port */ +int br_fdb_test_addr(struct net_device *dev, unsigned char *addr) +{ + struct net_bridge_fdb_entry *fdb; + struct net_bridge_port *port; + int ret; + + rcu_read_lock(); + port = br_port_get_rcu(dev); + if (!port) + ret = 0; + else { + const struct net_bridge_port *dst = NULL; + + fdb = br_fdb_find_rcu(port->br, addr, 0); + if (fdb) + dst = READ_ONCE(fdb->dst); + + ret = dst && dst->dev != dev && + dst->state == BR_STATE_FORWARDING; + } + rcu_read_unlock(); + + return ret; +} +#endif /* CONFIG_ATM_LANE */ + +/* + * Fill buffer with forwarding table records in + * the API format. 
+ */ +int br_fdb_fillbuf(struct net_bridge *br, void *buf, + unsigned long maxnum, unsigned long skip) +{ + struct net_bridge_fdb_entry *f; + struct __fdb_entry *fe = buf; + int num = 0; + + memset(buf, 0, maxnum*sizeof(struct __fdb_entry)); + + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + if (num >= maxnum) + break; + + if (has_expired(br, f)) + continue; + + /* ignore pseudo entry for local MAC address */ + if (!f->dst) + continue; + + if (skip) { + --skip; + continue; + } + + /* convert from internal format to API */ + memcpy(fe->mac_addr, f->key.addr.addr, ETH_ALEN); + + /* due to ABI compat need to split into hi/lo */ + fe->port_no = f->dst->port_no; + fe->port_hi = f->dst->port_no >> 8; + + fe->is_local = test_bit(BR_FDB_LOCAL, &f->flags); + if (!test_bit(BR_FDB_STATIC, &f->flags)) + fe->ageing_timer_value = jiffies_delta_to_clock_t(jiffies - f->updated); + ++fe; + ++num; + } + rcu_read_unlock(); + + return num; +} + +/* Add entry for local address of interface */ +int br_fdb_add_local(struct net_bridge *br, struct net_bridge_port *source, + const unsigned char *addr, u16 vid) +{ + int ret; + + spin_lock_bh(&br->hash_lock); + ret = fdb_add_local(br, source, addr, vid); + spin_unlock_bh(&br->hash_lock); + return ret; +} + +/* returns true if the fdb was modified */ +static bool __fdb_mark_active(struct net_bridge_fdb_entry *fdb) +{ + return !!(test_bit(BR_FDB_NOTIFY_INACTIVE, &fdb->flags) && + test_and_clear_bit(BR_FDB_NOTIFY_INACTIVE, &fdb->flags)); +} + +void br_fdb_update(struct net_bridge *br, struct net_bridge_port *source, + const unsigned char *addr, u16 vid, unsigned long flags) +{ + struct net_bridge_fdb_entry *fdb; + + /* some users want to always flood. */ + if (hold_time(br) == 0) + return; + + fdb = fdb_find_rcu(&br->fdb_hash_tbl, addr, vid); + if (likely(fdb)) { + /* attempt to update an entry for a local interface */ + if (unlikely(test_bit(BR_FDB_LOCAL, &fdb->flags))) { + if (net_ratelimit()) + br_warn(br, "received packet on %s with own address as source address (addr:%pM, vlan:%u)\n", + source->dev->name, addr, vid); + } else { + unsigned long now = jiffies; + bool fdb_modified = false; + + if (now != fdb->updated) { + fdb->updated = now; + fdb_modified = __fdb_mark_active(fdb); + } + + /* fastpath: update of existing entry */ + if (unlikely(source != READ_ONCE(fdb->dst) && + !test_bit(BR_FDB_STICKY, &fdb->flags))) { + br_switchdev_fdb_notify(br, fdb, RTM_DELNEIGH); + WRITE_ONCE(fdb->dst, source); + fdb_modified = true; + /* Take over HW learned entry */ + if (unlikely(test_bit(BR_FDB_ADDED_BY_EXT_LEARN, + &fdb->flags))) + clear_bit(BR_FDB_ADDED_BY_EXT_LEARN, + &fdb->flags); + } + + if (unlikely(test_bit(BR_FDB_ADDED_BY_USER, &flags))) + set_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); + if (unlikely(fdb_modified)) { + trace_br_fdb_update(br, source, addr, vid, flags); + fdb_notify(br, fdb, RTM_NEWNEIGH, true); + } + } + } else { + spin_lock(&br->hash_lock); + fdb = fdb_create(br, source, addr, vid, flags); + if (fdb) { + trace_br_fdb_update(br, source, addr, vid, flags); + fdb_notify(br, fdb, RTM_NEWNEIGH, true); + } + /* else we lose race and someone else inserts + * it first, don't bother updating + */ + spin_unlock(&br->hash_lock); + } +} + +/* Dump information about entries, in response to GETNEIGH */ +int br_fdb_dump(struct sk_buff *skb, + struct netlink_callback *cb, + struct net_device *dev, + struct net_device *filter_dev, + int *idx) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_fdb_entry *f; + int err = 0; + + 
if (!netif_is_bridge_master(dev)) + return err; + + if (!filter_dev) { + err = ndo_dflt_fdb_dump(skb, cb, dev, NULL, idx); + if (err < 0) + return err; + } + + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + if (*idx < cb->args[2]) + goto skip; + if (filter_dev && (!f->dst || f->dst->dev != filter_dev)) { + if (filter_dev != dev) + goto skip; + /* !f->dst is a special case for bridge + * It means the MAC belongs to the bridge + * Therefore need a little more filtering + * we only want to dump the !f->dst case + */ + if (f->dst) + goto skip; + } + if (!filter_dev && f->dst) + goto skip; + + err = fdb_fill_info(skb, br, f, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWNEIGH, + NLM_F_MULTI); + if (err < 0) + break; +skip: + *idx += 1; + } + rcu_read_unlock(); + + return err; +} + +int br_fdb_get(struct sk_buff *skb, + struct nlattr *tb[], + struct net_device *dev, + const unsigned char *addr, + u16 vid, u32 portid, u32 seq, + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_fdb_entry *f; + int err = 0; + + rcu_read_lock(); + f = br_fdb_find_rcu(br, addr, vid); + if (!f) { + NL_SET_ERR_MSG(extack, "Fdb entry not found"); + err = -ENOENT; + goto errout; + } + + err = fdb_fill_info(skb, br, f, portid, seq, + RTM_NEWNEIGH, 0); +errout: + rcu_read_unlock(); + return err; +} + +/* returns true if the fdb is modified */ +static bool fdb_handle_notify(struct net_bridge_fdb_entry *fdb, u8 notify) +{ + bool modified = false; + + /* allow to mark an entry as inactive, usually done on creation */ + if ((notify & FDB_NOTIFY_INACTIVE_BIT) && + !test_and_set_bit(BR_FDB_NOTIFY_INACTIVE, &fdb->flags)) + modified = true; + + if ((notify & FDB_NOTIFY_BIT) && + !test_and_set_bit(BR_FDB_NOTIFY, &fdb->flags)) { + /* enabled activity tracking */ + modified = true; + } else if (!(notify & FDB_NOTIFY_BIT) && + test_and_clear_bit(BR_FDB_NOTIFY, &fdb->flags)) { + /* disabled activity tracking, clear notify state */ + clear_bit(BR_FDB_NOTIFY_INACTIVE, &fdb->flags); + modified = true; + } + + return modified; +} + +/* Update (create or replace) forwarding database entry */ +static int fdb_add_entry(struct net_bridge *br, struct net_bridge_port *source, + const u8 *addr, struct ndmsg *ndm, u16 flags, u16 vid, + struct nlattr *nfea_tb[]) +{ + bool is_sticky = !!(ndm->ndm_flags & NTF_STICKY); + bool refresh = !nfea_tb[NFEA_DONT_REFRESH]; + struct net_bridge_fdb_entry *fdb; + u16 state = ndm->ndm_state; + bool modified = false; + u8 notify = 0; + + /* If the port cannot learn allow only local and static entries */ + if (source && !(state & NUD_PERMANENT) && !(state & NUD_NOARP) && + !(source->state == BR_STATE_LEARNING || + source->state == BR_STATE_FORWARDING)) + return -EPERM; + + if (!source && !(state & NUD_PERMANENT)) { + pr_info("bridge: RTM_NEWNEIGH %s without NUD_PERMANENT\n", + br->dev->name); + return -EINVAL; + } + + if (is_sticky && (state & NUD_PERMANENT)) + return -EINVAL; + + if (nfea_tb[NFEA_ACTIVITY_NOTIFY]) { + notify = nla_get_u8(nfea_tb[NFEA_ACTIVITY_NOTIFY]); + if ((notify & ~BR_FDB_NOTIFY_SETTABLE_BITS) || + (notify & BR_FDB_NOTIFY_SETTABLE_BITS) == FDB_NOTIFY_INACTIVE_BIT) + return -EINVAL; + } + + fdb = br_fdb_find(br, addr, vid); + if (fdb == NULL) { + if (!(flags & NLM_F_CREATE)) + return -ENOENT; + + fdb = fdb_create(br, source, addr, vid, 0); + if (!fdb) + return -ENOMEM; + + modified = true; + } else { + if (flags & NLM_F_EXCL) + return -EEXIST; + + if (READ_ONCE(fdb->dst) != source) { + 
WRITE_ONCE(fdb->dst, source); + modified = true; + } + } + + if (fdb_to_nud(br, fdb) != state) { + if (state & NUD_PERMANENT) { + set_bit(BR_FDB_LOCAL, &fdb->flags); + if (!test_and_set_bit(BR_FDB_STATIC, &fdb->flags)) + fdb_add_hw_addr(br, addr); + } else if (state & NUD_NOARP) { + clear_bit(BR_FDB_LOCAL, &fdb->flags); + if (!test_and_set_bit(BR_FDB_STATIC, &fdb->flags)) + fdb_add_hw_addr(br, addr); + } else { + clear_bit(BR_FDB_LOCAL, &fdb->flags); + if (test_and_clear_bit(BR_FDB_STATIC, &fdb->flags)) + fdb_del_hw_addr(br, addr); + } + + modified = true; + } + + if (is_sticky != test_bit(BR_FDB_STICKY, &fdb->flags)) { + change_bit(BR_FDB_STICKY, &fdb->flags); + modified = true; + } + + if (fdb_handle_notify(fdb, notify)) + modified = true; + + set_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); + + fdb->used = jiffies; + if (modified) { + if (refresh) + fdb->updated = jiffies; + fdb_notify(br, fdb, RTM_NEWNEIGH, true); + } + + return 0; +} + +static int __br_fdb_add(struct ndmsg *ndm, struct net_bridge *br, + struct net_bridge_port *p, const unsigned char *addr, + u16 nlh_flags, u16 vid, struct nlattr *nfea_tb[], + struct netlink_ext_ack *extack) +{ + int err = 0; + + if (ndm->ndm_flags & NTF_USE) { + if (!p) { + pr_info("bridge: RTM_NEWNEIGH %s with NTF_USE is not supported\n", + br->dev->name); + return -EINVAL; + } + if (!nbp_state_should_learn(p)) + return 0; + + local_bh_disable(); + rcu_read_lock(); + br_fdb_update(br, p, addr, vid, BIT(BR_FDB_ADDED_BY_USER)); + rcu_read_unlock(); + local_bh_enable(); + } else if (ndm->ndm_flags & NTF_EXT_LEARNED) { + if (!p && !(ndm->ndm_state & NUD_PERMANENT)) { + NL_SET_ERR_MSG_MOD(extack, + "FDB entry towards bridge must be permanent"); + return -EINVAL; + } + err = br_fdb_external_learn_add(br, p, addr, vid, true); + } else { + spin_lock_bh(&br->hash_lock); + err = fdb_add_entry(br, p, addr, ndm, nlh_flags, vid, nfea_tb); + spin_unlock_bh(&br->hash_lock); + } + + return err; +} + +static const struct nla_policy br_nda_fdb_pol[NFEA_MAX + 1] = { + [NFEA_ACTIVITY_NOTIFY] = { .type = NLA_U8 }, + [NFEA_DONT_REFRESH] = { .type = NLA_FLAG }, +}; + +/* Add new permanent fdb entry with RTM_NEWNEIGH */ +int br_fdb_add(struct ndmsg *ndm, struct nlattr *tb[], + struct net_device *dev, + const unsigned char *addr, u16 vid, u16 nlh_flags, + struct netlink_ext_ack *extack) +{ + struct nlattr *nfea_tb[NFEA_MAX + 1], *attr; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + struct net_bridge_vlan *v; + struct net_bridge *br = NULL; + int err = 0; + + trace_br_fdb_add(ndm, dev, addr, vid, nlh_flags); + + if (!(ndm->ndm_state & (NUD_PERMANENT|NUD_NOARP|NUD_REACHABLE))) { + pr_info("bridge: RTM_NEWNEIGH with invalid state %#x\n", ndm->ndm_state); + return -EINVAL; + } + + if (is_zero_ether_addr(addr)) { + pr_info("bridge: RTM_NEWNEIGH with invalid ether address\n"); + return -EINVAL; + } + + if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + vg = br_vlan_group(br); + } else { + p = br_port_get_rtnl(dev); + if (!p) { + pr_info("bridge: RTM_NEWNEIGH %s not a bridge port\n", + dev->name); + return -EINVAL; + } + br = p->br; + vg = nbp_vlan_group(p); + } + + if (tb[NDA_FDB_EXT_ATTRS]) { + attr = tb[NDA_FDB_EXT_ATTRS]; + err = nla_parse_nested(nfea_tb, NFEA_MAX, attr, + br_nda_fdb_pol, extack); + if (err) + return err; + } else { + memset(nfea_tb, 0, sizeof(struct nlattr *) * (NFEA_MAX + 1)); + } + + if (vid) { + v = br_vlan_find(vg, vid); + if (!v || !br_vlan_should_use(v)) { + pr_info("bridge: RTM_NEWNEIGH with unconfigured vlan %d on 
%s\n", vid, dev->name); + return -EINVAL; + } + + /* VID was specified, so use it. */ + err = __br_fdb_add(ndm, br, p, addr, nlh_flags, vid, nfea_tb, + extack); + } else { + err = __br_fdb_add(ndm, br, p, addr, nlh_flags, 0, nfea_tb, + extack); + if (err || !vg || !vg->num_vlans) + goto out; + + /* We have vlans configured on this port and user didn't + * specify a VLAN. To be nice, add/update entry for every + * vlan on this port. + */ + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (!br_vlan_should_use(v)) + continue; + err = __br_fdb_add(ndm, br, p, addr, nlh_flags, v->vid, + nfea_tb, extack); + if (err) + goto out; + } + } + +out: + return err; +} + +static int fdb_delete_by_addr_and_port(struct net_bridge *br, + const struct net_bridge_port *p, + const u8 *addr, u16 vlan) +{ + struct net_bridge_fdb_entry *fdb; + + fdb = br_fdb_find(br, addr, vlan); + if (!fdb || READ_ONCE(fdb->dst) != p) + return -ENOENT; + + fdb_delete(br, fdb, true); + + return 0; +} + +static int __br_fdb_delete(struct net_bridge *br, + const struct net_bridge_port *p, + const unsigned char *addr, u16 vid) +{ + int err; + + spin_lock_bh(&br->hash_lock); + err = fdb_delete_by_addr_and_port(br, p, addr, vid); + spin_unlock_bh(&br->hash_lock); + + return err; +} + +/* Remove neighbor entry with RTM_DELNEIGH */ +int br_fdb_delete(struct ndmsg *ndm, struct nlattr *tb[], + struct net_device *dev, + const unsigned char *addr, u16 vid, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + struct net_bridge_vlan *v; + struct net_bridge *br; + int err; + + if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + vg = br_vlan_group(br); + } else { + p = br_port_get_rtnl(dev); + if (!p) { + pr_info("bridge: RTM_DELNEIGH %s not a bridge port\n", + dev->name); + return -EINVAL; + } + vg = nbp_vlan_group(p); + br = p->br; + } + + if (vid) { + v = br_vlan_find(vg, vid); + if (!v) { + pr_info("bridge: RTM_DELNEIGH with unconfigured vlan %d on %s\n", vid, dev->name); + return -EINVAL; + } + + err = __br_fdb_delete(br, p, addr, vid); + } else { + err = -ENOENT; + err &= __br_fdb_delete(br, p, addr, 0); + if (!vg || !vg->num_vlans) + return err; + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (!br_vlan_should_use(v)) + continue; + err &= __br_fdb_delete(br, p, addr, v->vid); + } + } + + return err; +} + +int br_fdb_sync_static(struct net_bridge *br, struct net_bridge_port *p) +{ + struct net_bridge_fdb_entry *f, *tmp; + int err = 0; + + ASSERT_RTNL(); + + /* the key here is that static entries change only under rtnl */ + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + /* We only care for static entries */ + if (!test_bit(BR_FDB_STATIC, &f->flags)) + continue; + err = dev_uc_add(p->dev, f->key.addr.addr); + if (err) + goto rollback; + } +done: + rcu_read_unlock(); + + return err; + +rollback: + hlist_for_each_entry_rcu(tmp, &br->fdb_list, fdb_node) { + /* We only care for static entries */ + if (!test_bit(BR_FDB_STATIC, &tmp->flags)) + continue; + if (tmp == f) + break; + dev_uc_del(p->dev, tmp->key.addr.addr); + } + + goto done; +} + +void br_fdb_unsync_static(struct net_bridge *br, struct net_bridge_port *p) +{ + struct net_bridge_fdb_entry *f; + + ASSERT_RTNL(); + + rcu_read_lock(); + hlist_for_each_entry_rcu(f, &br->fdb_list, fdb_node) { + /* We only care for static entries */ + if (!test_bit(BR_FDB_STATIC, &f->flags)) + continue; + + dev_uc_del(p->dev, f->key.addr.addr); + } + rcu_read_unlock(); +} + +int 
br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, + bool swdev_notify) +{ + struct net_bridge_fdb_entry *fdb; + bool modified = false; + int err = 0; + + trace_br_fdb_external_learn_add(br, p, addr, vid); + + spin_lock_bh(&br->hash_lock); + + fdb = br_fdb_find(br, addr, vid); + if (!fdb) { + unsigned long flags = BIT(BR_FDB_ADDED_BY_EXT_LEARN); + + if (swdev_notify) + flags |= BIT(BR_FDB_ADDED_BY_USER); + + if (!p) + flags |= BIT(BR_FDB_LOCAL); + + fdb = fdb_create(br, p, addr, vid, flags); + if (!fdb) { + err = -ENOMEM; + goto err_unlock; + } + fdb_notify(br, fdb, RTM_NEWNEIGH, swdev_notify); + } else { + fdb->updated = jiffies; + + if (READ_ONCE(fdb->dst) != p) { + WRITE_ONCE(fdb->dst, p); + modified = true; + } + + if (test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags)) { + /* Refresh entry */ + fdb->used = jiffies; + } else if (!test_bit(BR_FDB_ADDED_BY_USER, &fdb->flags)) { + /* Take over SW learned entry */ + set_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags); + modified = true; + } + + if (swdev_notify) + set_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); + + if (!p) + set_bit(BR_FDB_LOCAL, &fdb->flags); + + if (modified) + fdb_notify(br, fdb, RTM_NEWNEIGH, swdev_notify); + } + +err_unlock: + spin_unlock_bh(&br->hash_lock); + + return err; +} + +int br_fdb_external_learn_del(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, + bool swdev_notify) +{ + struct net_bridge_fdb_entry *fdb; + int err = 0; + + spin_lock_bh(&br->hash_lock); + + fdb = br_fdb_find(br, addr, vid); + if (fdb && test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags)) + fdb_delete(br, fdb, swdev_notify); + else + err = -ENOENT; + + spin_unlock_bh(&br->hash_lock); + + return err; +} + +void br_fdb_offloaded_set(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, bool offloaded) +{ + struct net_bridge_fdb_entry *fdb; + + spin_lock_bh(&br->hash_lock); + + fdb = br_fdb_find(br, addr, vid); + if (fdb && offloaded != test_bit(BR_FDB_OFFLOADED, &fdb->flags)) + change_bit(BR_FDB_OFFLOADED, &fdb->flags); + + spin_unlock_bh(&br->hash_lock); +} + +void br_fdb_clear_offload(const struct net_device *dev, u16 vid) +{ + struct net_bridge_fdb_entry *f; + struct net_bridge_port *p; + + ASSERT_RTNL(); + + p = br_port_get_rtnl(dev); + if (!p) + return; + + spin_lock_bh(&p->br->hash_lock); + hlist_for_each_entry(f, &p->br->fdb_list, fdb_node) { + if (f->dst == p && f->key.vlan_id == vid) + clear_bit(BR_FDB_OFFLOADED, &f->flags); + } + spin_unlock_bh(&p->br->hash_lock); +} +EXPORT_SYMBOL_GPL(br_fdb_clear_offload); diff --git a/net/bridge/br_forward.c b/net/bridge/br_forward.c new file mode 100644 index 000000000..4e3394a7d --- /dev/null +++ b/net/bridge/br_forward.c @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Forwarding decision + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "br_private.h" + +/* Don't forward packets to originating port or forwarding disabled */ +static inline int should_deliver(const struct net_bridge_port *p, + const struct sk_buff *skb) +{ + struct net_bridge_vlan_group *vg; + + vg = nbp_vlan_group_rcu(p); + return ((p->flags & BR_HAIRPIN_MODE) || skb->dev != p->dev) && + p->state == BR_STATE_FORWARDING && br_allowed_egress(vg, skb) && + nbp_switchdev_allowed_egress(p, skb) && + !br_skb_isolated(p, skb); +} + +int br_dev_queue_push_xmit(struct net *net, 
struct sock *sk, struct sk_buff *skb) +{ + skb_push(skb, ETH_HLEN); + if (!is_skb_forwardable(skb->dev, skb)) + goto drop; + + br_drop_fake_rtable(skb); + + if (skb->ip_summed == CHECKSUM_PARTIAL && + eth_type_vlan(skb->protocol)) { + int depth; + + if (!vlan_get_protocol_and_depth(skb, skb->protocol, &depth)) + goto drop; + + skb_set_network_header(skb, depth); + } + + br_switchdev_frame_set_offload_fwd_mark(skb); + + dev_queue_xmit(skb); + + return 0; + +drop: + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL_GPL(br_dev_queue_push_xmit); + +int br_forward_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb_clear_tstamp(skb); + return NF_HOOK(NFPROTO_BRIDGE, NF_BR_POST_ROUTING, + net, sk, skb, NULL, skb->dev, + br_dev_queue_push_xmit); + +} +EXPORT_SYMBOL_GPL(br_forward_finish); + +static void __br_forward(const struct net_bridge_port *to, + struct sk_buff *skb, bool local_orig) +{ + struct net_bridge_vlan_group *vg; + struct net_device *indev; + struct net *net; + int br_hook; + + /* Mark the skb for forwarding offload early so that br_handle_vlan() + * can know whether to pop the VLAN header on egress or keep it. + */ + nbp_switchdev_frame_mark_tx_fwd_offload(to, skb); + + vg = nbp_vlan_group_rcu(to); + skb = br_handle_vlan(to->br, to, vg, skb); + if (!skb) + return; + + indev = skb->dev; + skb->dev = to->dev; + if (!local_orig) { + if (skb_warn_if_lro(skb)) { + kfree_skb(skb); + return; + } + br_hook = NF_BR_FORWARD; + skb_forward_csum(skb); + net = dev_net(indev); + } else { + if (unlikely(netpoll_tx_running(to->br->dev))) { + skb_push(skb, ETH_HLEN); + if (!is_skb_forwardable(skb->dev, skb)) + kfree_skb(skb); + else + br_netpoll_send_skb(to, skb); + return; + } + br_hook = NF_BR_LOCAL_OUT; + net = dev_net(skb->dev); + indev = NULL; + } + + NF_HOOK(NFPROTO_BRIDGE, br_hook, + net, NULL, skb, indev, skb->dev, + br_forward_finish); +} + +static int deliver_clone(const struct net_bridge_port *prev, + struct sk_buff *skb, bool local_orig) +{ + struct net_device *dev = BR_INPUT_SKB_CB(skb)->brdev; + + skb = skb_clone(skb, GFP_ATOMIC); + if (!skb) { + DEV_STATS_INC(dev, tx_dropped); + return -ENOMEM; + } + + __br_forward(prev, skb, local_orig); + return 0; +} + +/** + * br_forward - forward a packet to a specific port + * @to: destination port + * @skb: packet being forwarded + * @local_rcv: packet will be received locally after forwarding + * @local_orig: packet is locally originated + * + * Should be called with rcu_read_lock. 
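 *
 * A minimal caller sketch (illustrative only; the real call sites are in
 * br_input.c and br_device.c), assuming a unicast FDB lookup has already
 * been done under rcu_read_lock():
 *
 *	dst = br_fdb_find_rcu(br, eth_hdr(skb)->h_dest, vid);
 *	if (dst)
 *		br_forward(dst->dst, skb, local_rcv, false);
 *	else
 *		br_flood(br, skb, BR_PKT_UNICAST, local_rcv, false);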
+ */ +void br_forward(const struct net_bridge_port *to, + struct sk_buff *skb, bool local_rcv, bool local_orig) +{ + if (unlikely(!to)) + goto out; + + /* redirect to backup link if the destination port is down */ + if (rcu_access_pointer(to->backup_port) && !netif_carrier_ok(to->dev)) { + struct net_bridge_port *backup_port; + + backup_port = rcu_dereference(to->backup_port); + if (unlikely(!backup_port)) + goto out; + to = backup_port; + } + + if (should_deliver(to, skb)) { + if (local_rcv) + deliver_clone(to, skb, local_orig); + else + __br_forward(to, skb, local_orig); + return; + } + +out: + if (!local_rcv) + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(br_forward); + +static struct net_bridge_port *maybe_deliver( + struct net_bridge_port *prev, struct net_bridge_port *p, + struct sk_buff *skb, bool local_orig) +{ + u8 igmp_type = br_multicast_igmp_type(skb); + int err; + + if (!should_deliver(p, skb)) + return prev; + + nbp_switchdev_frame_mark_tx_fwd_to_hwdom(p, skb); + + if (!prev) + goto out; + + err = deliver_clone(prev, skb, local_orig); + if (err) + return ERR_PTR(err); +out: + br_multicast_count(p->br, p, skb, igmp_type, BR_MCAST_DIR_TX); + + return p; +} + +/* called under rcu_read_lock */ +void br_flood(struct net_bridge *br, struct sk_buff *skb, + enum br_pkt_type pkt_type, bool local_rcv, bool local_orig) +{ + struct net_bridge_port *prev = NULL; + struct net_bridge_port *p; + + list_for_each_entry_rcu(p, &br->port_list, list) { + /* Do not flood unicast traffic to ports that turn it off, nor + * other traffic if flood off, except for traffic we originate + */ + switch (pkt_type) { + case BR_PKT_UNICAST: + if (!(p->flags & BR_FLOOD)) + continue; + break; + case BR_PKT_MULTICAST: + if (!(p->flags & BR_MCAST_FLOOD) && skb->dev != br->dev) + continue; + break; + case BR_PKT_BROADCAST: + if (!(p->flags & BR_BCAST_FLOOD) && skb->dev != br->dev) + continue; + break; + } + + /* Do not flood to ports that enable proxy ARP */ + if (p->flags & BR_PROXYARP) + continue; + if ((p->flags & (BR_PROXYARP_WIFI | BR_NEIGH_SUPPRESS)) && + BR_INPUT_SKB_CB(skb)->proxyarp_replied) + continue; + + prev = maybe_deliver(prev, p, skb, local_orig); + if (IS_ERR(prev)) + goto out; + } + + if (!prev) + goto out; + + if (local_rcv) + deliver_clone(prev, skb, local_orig); + else + __br_forward(prev, skb, local_orig); + return; + +out: + if (!local_rcv) + kfree_skb(skb); +} + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +static void maybe_deliver_addr(struct net_bridge_port *p, struct sk_buff *skb, + const unsigned char *addr, bool local_orig) +{ + struct net_device *dev = BR_INPUT_SKB_CB(skb)->brdev; + const unsigned char *src = eth_hdr(skb)->h_source; + + if (!should_deliver(p, skb)) + return; + + /* Even with hairpin, no soliloquies - prevent breaking IPv6 DAD */ + if (skb->dev == p->dev && ether_addr_equal(src, addr)) + return; + + skb = skb_copy(skb, GFP_ATOMIC); + if (!skb) { + DEV_STATS_INC(dev, tx_dropped); + return; + } + + if (!is_broadcast_ether_addr(addr)) + memcpy(eth_hdr(skb)->h_dest, addr, ETH_ALEN); + + __br_forward(p, skb, local_orig); +} + +/* called with rcu_read_lock */ +void br_multicast_flood(struct net_bridge_mdb_entry *mdst, + struct sk_buff *skb, + struct net_bridge_mcast *brmctx, + bool local_rcv, bool local_orig) +{ + struct net_bridge_port *prev = NULL; + struct net_bridge_port_group *p; + bool allow_mode_include = true; + struct hlist_node *rp; + + rp = br_multicast_get_first_rport_node(brmctx, skb); + + if (mdst) { + p = rcu_dereference(mdst->ports); + if 
(br_multicast_should_handle_mode(brmctx, mdst->addr.proto) && + br_multicast_is_star_g(&mdst->addr)) + allow_mode_include = false; + } else { + p = NULL; + } + + while (p || rp) { + struct net_bridge_port *port, *lport, *rport; + + lport = p ? p->key.port : NULL; + rport = br_multicast_rport_from_node_skb(rp, skb); + + if ((unsigned long)lport > (unsigned long)rport) { + port = lport; + + if (port->flags & BR_MULTICAST_TO_UNICAST) { + maybe_deliver_addr(lport, skb, p->eth_addr, + local_orig); + goto delivered; + } + if ((!allow_mode_include && + p->filter_mode == MCAST_INCLUDE) || + (p->flags & MDB_PG_FLAGS_BLOCKED)) + goto delivered; + } else { + port = rport; + } + + prev = maybe_deliver(prev, port, skb, local_orig); + if (IS_ERR(prev)) + goto out; +delivered: + if ((unsigned long)lport >= (unsigned long)port) + p = rcu_dereference(p->next); + if ((unsigned long)rport >= (unsigned long)port) + rp = rcu_dereference(hlist_next_rcu(rp)); + } + + if (!prev) + goto out; + + if (local_rcv) + deliver_clone(prev, skb, local_orig); + else + __br_forward(prev, skb, local_orig); + return; + +out: + if (!local_rcv) + kfree_skb(skb); +} +#endif diff --git a/net/bridge/br_if.c b/net/bridge/br_if.c new file mode 100644 index 000000000..0989074f3 --- /dev/null +++ b/net/bridge/br_if.c @@ -0,0 +1,777 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Userspace interface + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" + +/* + * Determine initial path cost based on speed. + * using recommendations from 802.1d standard + * + * Since driver might sleep need to not be holding any locks. + */ +static int port_cost(struct net_device *dev) +{ + struct ethtool_link_ksettings ecmd; + + if (!__ethtool_get_link_ksettings(dev, &ecmd)) { + switch (ecmd.base.speed) { + case SPEED_10000: + return 2; + case SPEED_5000: + return 3; + case SPEED_2500: + return 4; + case SPEED_1000: + return 5; + case SPEED_100: + return 19; + case SPEED_10: + return 100; + case SPEED_UNKNOWN: + return 100; + default: + if (ecmd.base.speed > SPEED_10000) + return 1; + } + } + + /* Old silly heuristics based on name */ + if (!strncmp(dev->name, "lec", 3)) + return 7; + + if (!strncmp(dev->name, "plip", 4)) + return 2500; + + return 100; /* assume old 10Mbps */ +} + + +/* Check for port carrier transitions. 
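 *
 * The cost computed here is only applied while the administrator has not
 * pinned it: a manually configured cost (e.g. via the BRCTL_SET_PATH_COST
 * ioctl handled in br_ioctl.c below) is expected to set BR_ADMIN_COST,
 * which the first check in this function honours, so carrier flaps do not
 * overwrite an explicit configuration (assumes br_stp_set_path_cost()
 * sets BR_ADMIN_COST, as in br_stp_if.c).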
*/ +void br_port_carrier_check(struct net_bridge_port *p, bool *notified) +{ + struct net_device *dev = p->dev; + struct net_bridge *br = p->br; + + if (!(p->flags & BR_ADMIN_COST) && + netif_running(dev) && netif_oper_up(dev)) + p->path_cost = port_cost(dev); + + *notified = false; + if (!netif_running(br->dev)) + return; + + spin_lock_bh(&br->lock); + if (netif_running(dev) && netif_oper_up(dev)) { + if (p->state == BR_STATE_DISABLED) { + br_stp_enable_port(p); + *notified = true; + } + } else { + if (p->state != BR_STATE_DISABLED) { + br_stp_disable_port(p); + *notified = true; + } + } + spin_unlock_bh(&br->lock); +} + +static void br_port_set_promisc(struct net_bridge_port *p) +{ + int err = 0; + + if (br_promisc_port(p)) + return; + + err = dev_set_promiscuity(p->dev, 1); + if (err) + return; + + br_fdb_unsync_static(p->br, p); + p->flags |= BR_PROMISC; +} + +static void br_port_clear_promisc(struct net_bridge_port *p) +{ + int err; + + /* Check if the port is already non-promisc or if it doesn't + * support UNICAST filtering. Without unicast filtering support + * we'll end up re-enabling promisc mode anyway, so just check for + * it here. + */ + if (!br_promisc_port(p) || !(p->dev->priv_flags & IFF_UNICAST_FLT)) + return; + + /* Since we'll be clearing the promisc mode, program the port + * first so that we don't have interruption in traffic. + */ + err = br_fdb_sync_static(p->br, p); + if (err) + return; + + dev_set_promiscuity(p->dev, -1); + p->flags &= ~BR_PROMISC; +} + +/* When a port is added or removed or when certain port flags + * change, this function is called to automatically manage + * promiscuity setting of all the bridge ports. We are always called + * under RTNL so can skip using rcu primitives. + */ +void br_manage_promisc(struct net_bridge *br) +{ + struct net_bridge_port *p; + bool set_all = false; + + /* If vlan filtering is disabled or bridge interface is placed + * into promiscuous mode, place all ports in promiscuous mode. + */ + if ((br->dev->flags & IFF_PROMISC) || !br_vlan_enabled(br->dev)) + set_all = true; + + list_for_each_entry(p, &br->port_list, list) { + if (set_all) { + br_port_set_promisc(p); + } else { + /* If the number of auto-ports is <= 1, then all other + * ports will have their output configuration + * statically specified through fdbs. Since ingress + * on the auto-port becomes forwarding/egress to other + * ports and egress configuration is statically known, + * we can say that ingress configuration of the + * auto-port is also statically known. + * This lets us disable promiscuous mode and write + * this config to hw. 
+ */ + if ((p->dev->priv_flags & IFF_UNICAST_FLT) && + (br->auto_cnt == 0 || + (br->auto_cnt == 1 && br_auto_port(p)))) + br_port_clear_promisc(p); + else + br_port_set_promisc(p); + } + } +} + +int nbp_backup_change(struct net_bridge_port *p, + struct net_device *backup_dev) +{ + struct net_bridge_port *old_backup = rtnl_dereference(p->backup_port); + struct net_bridge_port *backup_p = NULL; + + ASSERT_RTNL(); + + if (backup_dev) { + if (!netif_is_bridge_port(backup_dev)) + return -ENOENT; + + backup_p = br_port_get_rtnl(backup_dev); + if (backup_p->br != p->br) + return -EINVAL; + } + + if (p == backup_p) + return -EINVAL; + + if (old_backup == backup_p) + return 0; + + /* if the backup link is already set, clear it */ + if (old_backup) + old_backup->backup_redirected_cnt--; + + if (backup_p) + backup_p->backup_redirected_cnt++; + rcu_assign_pointer(p->backup_port, backup_p); + + return 0; +} + +static void nbp_backup_clear(struct net_bridge_port *p) +{ + nbp_backup_change(p, NULL); + if (p->backup_redirected_cnt) { + struct net_bridge_port *cur_p; + + list_for_each_entry(cur_p, &p->br->port_list, list) { + struct net_bridge_port *backup_p; + + backup_p = rtnl_dereference(cur_p->backup_port); + if (backup_p == p) + nbp_backup_change(cur_p, NULL); + } + } + + WARN_ON(rcu_access_pointer(p->backup_port) || p->backup_redirected_cnt); +} + +static void nbp_update_port_count(struct net_bridge *br) +{ + struct net_bridge_port *p; + u32 cnt = 0; + + list_for_each_entry(p, &br->port_list, list) { + if (br_auto_port(p)) + cnt++; + } + if (br->auto_cnt != cnt) { + br->auto_cnt = cnt; + br_manage_promisc(br); + } +} + +static void nbp_delete_promisc(struct net_bridge_port *p) +{ + /* If port is currently promiscuous, unset promiscuity. + * Otherwise, it is a static port so remove all addresses + * from it. + */ + dev_set_allmulti(p->dev, -1); + if (br_promisc_port(p)) + dev_set_promiscuity(p->dev, -1); + else + br_fdb_unsync_static(p->br, p); +} + +static void release_nbp(struct kobject *kobj) +{ + struct net_bridge_port *p + = container_of(kobj, struct net_bridge_port, kobj); + kfree(p); +} + +static void brport_get_ownership(struct kobject *kobj, kuid_t *uid, kgid_t *gid) +{ + struct net_bridge_port *p = kobj_to_brport(kobj); + + net_ns_get_ownership(dev_net(p->dev), uid, gid); +} + +static struct kobj_type brport_ktype = { +#ifdef CONFIG_SYSFS + .sysfs_ops = &brport_sysfs_ops, +#endif + .release = release_nbp, + .get_ownership = brport_get_ownership, +}; + +static void destroy_nbp(struct net_bridge_port *p) +{ + struct net_device *dev = p->dev; + + p->br = NULL; + p->dev = NULL; + netdev_put(dev, &p->dev_tracker); + + kobject_put(&p->kobj); +} + +static void destroy_nbp_rcu(struct rcu_head *head) +{ + struct net_bridge_port *p = + container_of(head, struct net_bridge_port, rcu); + destroy_nbp(p); +} + +static unsigned get_max_headroom(struct net_bridge *br) +{ + unsigned max_headroom = 0; + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) { + unsigned dev_headroom = netdev_get_fwd_headroom(p->dev); + + if (dev_headroom > max_headroom) + max_headroom = dev_headroom; + } + + return max_headroom; +} + +static void update_headroom(struct net_bridge *br, int new_hr) +{ + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) + netdev_set_rx_headroom(p->dev, new_hr); + + br->dev->needed_headroom = new_hr; +} + +/* Delete port(interface) from bridge is done in two steps. + * via RCU. First step, marks device as down. 
That deletes + * all the timers and stops new packets from flowing through. + * + * Final cleanup doesn't occur until after all CPU's finished + * processing packets. + * + * Protected from multiple admin operations by RTNL mutex + */ +static void del_nbp(struct net_bridge_port *p) +{ + struct net_bridge *br = p->br; + struct net_device *dev = p->dev; + + sysfs_remove_link(br->ifobj, p->dev->name); + + nbp_delete_promisc(p); + + spin_lock_bh(&br->lock); + br_stp_disable_port(p); + spin_unlock_bh(&br->lock); + + br_mrp_port_del(br, p); + br_cfm_port_del(br, p); + + br_ifinfo_notify(RTM_DELLINK, NULL, p); + + list_del_rcu(&p->list); + if (netdev_get_fwd_headroom(dev) == br->dev->needed_headroom) + update_headroom(br, get_max_headroom(br)); + netdev_reset_rx_headroom(dev); + + nbp_vlan_flush(p); + br_fdb_delete_by_port(br, p, 0, 1); + switchdev_deferred_process(); + nbp_backup_clear(p); + + nbp_update_port_count(br); + + netdev_upper_dev_unlink(dev, br->dev); + + dev->priv_flags &= ~IFF_BRIDGE_PORT; + + netdev_rx_handler_unregister(dev); + + br_multicast_del_port(p); + + kobject_uevent(&p->kobj, KOBJ_REMOVE); + kobject_del(&p->kobj); + + br_netpoll_disable(p); + + call_rcu(&p->rcu, destroy_nbp_rcu); +} + +/* Delete bridge device */ +void br_dev_delete(struct net_device *dev, struct list_head *head) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_port *p, *n; + + list_for_each_entry_safe(p, n, &br->port_list, list) { + del_nbp(p); + } + + br_recalculate_neigh_suppress_enabled(br); + + br_fdb_delete_by_port(br, NULL, 0, 1); + + cancel_delayed_work_sync(&br->gc_work); + + br_sysfs_delbr(br->dev); + unregister_netdevice_queue(br->dev, head); +} + +/* find an available port number */ +static int find_portno(struct net_bridge *br) +{ + int index; + struct net_bridge_port *p; + unsigned long *inuse; + + inuse = bitmap_zalloc(BR_MAX_PORTS, GFP_KERNEL); + if (!inuse) + return -ENOMEM; + + __set_bit(0, inuse); /* zero is reserved */ + list_for_each_entry(p, &br->port_list, list) + __set_bit(p->port_no, inuse); + + index = find_first_zero_bit(inuse, BR_MAX_PORTS); + bitmap_free(inuse); + + return (index >= BR_MAX_PORTS) ? 
-EXFULL : index; +} + +/* called with RTNL but without bridge lock */ +static struct net_bridge_port *new_nbp(struct net_bridge *br, + struct net_device *dev) +{ + struct net_bridge_port *p; + int index, err; + + index = find_portno(br); + if (index < 0) + return ERR_PTR(index); + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (p == NULL) + return ERR_PTR(-ENOMEM); + + p->br = br; + netdev_hold(dev, &p->dev_tracker, GFP_KERNEL); + p->dev = dev; + p->path_cost = port_cost(dev); + p->priority = 0x8000 >> BR_PORT_BITS; + p->port_no = index; + p->flags = BR_LEARNING | BR_FLOOD | BR_MCAST_FLOOD | BR_BCAST_FLOOD; + br_init_port(p); + br_set_state(p, BR_STATE_DISABLED); + br_stp_port_timer_init(p); + err = br_multicast_add_port(p); + if (err) { + netdev_put(dev, &p->dev_tracker); + kfree(p); + p = ERR_PTR(err); + } + + return p; +} + +int br_add_bridge(struct net *net, const char *name) +{ + struct net_device *dev; + int res; + + dev = alloc_netdev(sizeof(struct net_bridge), name, NET_NAME_UNKNOWN, + br_dev_setup); + + if (!dev) + return -ENOMEM; + + dev_net_set(dev, net); + dev->rtnl_link_ops = &br_link_ops; + + res = register_netdevice(dev); + if (res) + free_netdev(dev); + return res; +} + +int br_del_bridge(struct net *net, const char *name) +{ + struct net_device *dev; + int ret = 0; + + dev = __dev_get_by_name(net, name); + if (dev == NULL) + ret = -ENXIO; /* Could not find device */ + + else if (!netif_is_bridge_master(dev)) { + /* Attempt to delete non bridge device! */ + ret = -EPERM; + } + + else if (dev->flags & IFF_UP) { + /* Not shutdown yet. */ + ret = -EBUSY; + } + + else + br_dev_delete(dev, NULL); + + return ret; +} + +/* MTU of the bridge pseudo-device: ETH_DATA_LEN or the minimum of the ports */ +static int br_mtu_min(const struct net_bridge *br) +{ + const struct net_bridge_port *p; + int ret_mtu = 0; + + list_for_each_entry(p, &br->port_list, list) + if (!ret_mtu || ret_mtu > p->dev->mtu) + ret_mtu = p->dev->mtu; + + return ret_mtu ? 
ret_mtu : ETH_DATA_LEN; +} + +void br_mtu_auto_adjust(struct net_bridge *br) +{ + ASSERT_RTNL(); + + /* if the bridge MTU was manually configured don't mess with it */ + if (br_opt_get(br, BROPT_MTU_SET_BY_USER)) + return; + + /* change to the minimum MTU and clear the flag which was set by + * the bridge ndo_change_mtu callback + */ + dev_set_mtu(br->dev, br_mtu_min(br)); + br_opt_toggle(br, BROPT_MTU_SET_BY_USER, false); +} + +static void br_set_gso_limits(struct net_bridge *br) +{ + unsigned int tso_max_size = TSO_MAX_SIZE; + const struct net_bridge_port *p; + u16 tso_max_segs = TSO_MAX_SEGS; + + list_for_each_entry(p, &br->port_list, list) { + tso_max_size = min(tso_max_size, p->dev->tso_max_size); + tso_max_segs = min(tso_max_segs, p->dev->tso_max_segs); + } + netif_set_tso_max_size(br->dev, tso_max_size); + netif_set_tso_max_segs(br->dev, tso_max_segs); +} + +/* + * Recomputes features using slave's features + */ +netdev_features_t br_features_recompute(struct net_bridge *br, + netdev_features_t features) +{ + struct net_bridge_port *p; + netdev_features_t mask; + + if (list_empty(&br->port_list)) + return features; + + mask = features; + features &= ~NETIF_F_ONE_FOR_ALL; + + list_for_each_entry(p, &br->port_list, list) { + features = netdev_increment_features(features, + p->dev->features, mask); + } + features = netdev_add_tso_features(features, mask); + + return features; +} + +/* called with RTNL */ +int br_add_if(struct net_bridge *br, struct net_device *dev, + struct netlink_ext_ack *extack) +{ + struct net_bridge_port *p; + int err = 0; + unsigned br_hr, dev_hr; + bool changed_addr, fdb_synced = false; + + /* Don't allow bridging non-ethernet like devices. */ + if ((dev->flags & IFF_LOOPBACK) || + dev->type != ARPHRD_ETHER || dev->addr_len != ETH_ALEN || + !is_valid_ether_addr(dev->dev_addr)) + return -EINVAL; + + /* No bridging of bridges */ + if (dev->netdev_ops->ndo_start_xmit == br_dev_xmit) { + NL_SET_ERR_MSG(extack, + "Can not enslave a bridge to a bridge"); + return -ELOOP; + } + + /* Device has master upper dev */ + if (netdev_master_upper_dev_get(dev)) + return -EBUSY; + + /* No bridging devices that dislike that (e.g. wireless) */ + if (dev->priv_flags & IFF_DONT_BRIDGE) { + NL_SET_ERR_MSG(extack, + "Device does not allow enslaving to a bridge"); + return -EOPNOTSUPP; + } + + p = new_nbp(br, dev); + if (IS_ERR(p)) + return PTR_ERR(p); + + call_netdevice_notifiers(NETDEV_JOIN, dev); + + err = dev_set_allmulti(dev, 1); + if (err) { + br_multicast_del_port(p); + netdev_put(dev, &p->dev_tracker); + kfree(p); /* kobject not yet init'd, manually free */ + goto err1; + } + + err = kobject_init_and_add(&p->kobj, &brport_ktype, &(dev->dev.kobj), + SYSFS_BRIDGE_PORT_ATTR); + if (err) + goto err2; + + err = br_sysfs_addif(p); + if (err) + goto err2; + + err = br_netpoll_enable(p); + if (err) + goto err3; + + err = netdev_rx_handler_register(dev, br_get_rx_handler(dev), p); + if (err) + goto err4; + + dev->priv_flags |= IFF_BRIDGE_PORT; + + err = netdev_master_upper_dev_link(dev, br->dev, NULL, NULL, extack); + if (err) + goto err5; + + dev_disable_lro(dev); + + list_add_rcu(&p->list, &br->port_list); + + nbp_update_port_count(br); + if (!br_promisc_port(p) && (p->dev->priv_flags & IFF_UNICAST_FLT)) { + /* When updating the port count we also update all ports' + * promiscuous mode. 
+ * A port leaving promiscuous mode normally gets the bridge's + * fdb synced to the unicast filter (if supported), however, + * `br_port_clear_promisc` does not distinguish between + * non-promiscuous ports and *new* ports, so we need to + * sync explicitly here. + */ + fdb_synced = br_fdb_sync_static(br, p) == 0; + if (!fdb_synced) + netdev_err(dev, "failed to sync bridge static fdb addresses to this port\n"); + } + + netdev_update_features(br->dev); + + br_hr = br->dev->needed_headroom; + dev_hr = netdev_get_fwd_headroom(dev); + if (br_hr < dev_hr) + update_headroom(br, dev_hr); + else + netdev_set_rx_headroom(dev, br_hr); + + if (br_fdb_add_local(br, p, dev->dev_addr, 0)) + netdev_err(dev, "failed insert local address bridge forwarding table\n"); + + if (br->dev->addr_assign_type != NET_ADDR_SET) { + /* Ask for permission to use this MAC address now, even if we + * don't end up choosing it below. + */ + err = dev_pre_changeaddr_notify(br->dev, dev->dev_addr, extack); + if (err) + goto err6; + } + + err = nbp_vlan_init(p, extack); + if (err) { + netdev_err(dev, "failed to initialize vlan filtering on this port\n"); + goto err6; + } + + spin_lock_bh(&br->lock); + changed_addr = br_stp_recalculate_bridge_id(br); + + if (netif_running(dev) && netif_oper_up(dev) && + (br->dev->flags & IFF_UP)) + br_stp_enable_port(p); + spin_unlock_bh(&br->lock); + + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + if (changed_addr) + call_netdevice_notifiers(NETDEV_CHANGEADDR, br->dev); + + br_mtu_auto_adjust(br); + br_set_gso_limits(br); + + kobject_uevent(&p->kobj, KOBJ_ADD); + + return 0; + +err6: + if (fdb_synced) + br_fdb_unsync_static(br, p); + list_del_rcu(&p->list); + br_fdb_delete_by_port(br, p, 0, 1); + nbp_update_port_count(br); + netdev_upper_dev_unlink(dev, br->dev); +err5: + dev->priv_flags &= ~IFF_BRIDGE_PORT; + netdev_rx_handler_unregister(dev); +err4: + br_netpoll_disable(p); +err3: + sysfs_remove_link(br->ifobj, p->dev->name); +err2: + br_multicast_del_port(p); + netdev_put(dev, &p->dev_tracker); + kobject_put(&p->kobj); + dev_set_allmulti(dev, -1); +err1: + return err; +} + +/* called with RTNL */ +int br_del_if(struct net_bridge *br, struct net_device *dev) +{ + struct net_bridge_port *p; + bool changed_addr; + + p = br_port_get_rtnl(dev); + if (!p || p->br != br) + return -EINVAL; + + /* Since more than one interface can be attached to a bridge, + * there still maybe an alternate path for netconsole to use; + * therefore there is no reason for a NETDEV_RELEASE event. 
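 *
 * For reference, a minimal sketch (illustrative only) of the legacy ioctl
 * path that ends up here; see add_del_if() in br_ioctl.c below:
 *
 *	dev = __dev_get_by_index(net, ifindex);
 *	if (!dev)
 *		return -EINVAL;
 *	return br_del_if(br, dev);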
+ */ + del_nbp(p); + + br_mtu_auto_adjust(br); + br_set_gso_limits(br); + + spin_lock_bh(&br->lock); + changed_addr = br_stp_recalculate_bridge_id(br); + spin_unlock_bh(&br->lock); + + if (changed_addr) + call_netdevice_notifiers(NETDEV_CHANGEADDR, br->dev); + + netdev_update_features(br->dev); + + return 0; +} + +void br_port_flags_change(struct net_bridge_port *p, unsigned long mask) +{ + struct net_bridge *br = p->br; + + if (mask & BR_AUTO_MASK) + nbp_update_port_count(br); + + if (mask & BR_NEIGH_SUPPRESS) + br_recalculate_neigh_suppress_enabled(br); +} + +bool br_port_flag_is_set(const struct net_device *dev, unsigned long flag) +{ + struct net_bridge_port *p; + + p = br_port_get_rtnl_rcu(dev); + if (!p) + return false; + + return p->flags & flag; +} +EXPORT_SYMBOL_GPL(br_port_flag_is_set); diff --git a/net/bridge/br_input.c b/net/bridge/br_input.c new file mode 100644 index 000000000..6bb272894 --- /dev/null +++ b/net/bridge/br_input.c @@ -0,0 +1,441 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handle incoming frames + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE +#include +#endif +#include +#include +#include +#include +#include +#include "br_private.h" +#include "br_private_tunnel.h" + +static int +br_netif_receive_skb(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + br_drop_fake_rtable(skb); + return netif_receive_skb(skb); +} + +static int br_pass_frame_up(struct sk_buff *skb) +{ + struct net_device *indev, *brdev = BR_INPUT_SKB_CB(skb)->brdev; + struct net_bridge *br = netdev_priv(brdev); + struct net_bridge_vlan_group *vg; + + dev_sw_netstats_rx_add(brdev, skb->len); + + vg = br_vlan_group_rcu(br); + + /* Reset the offload_fwd_mark because there could be a stacked + * bridge above, and it should not think this bridge it doing + * that bridge's work forwarding out its ports. + */ + br_switchdev_frame_unmark(skb); + + /* Bridge is just like any other port. Make sure the + * packet is allowed except in promisc mode when someone + * may be running packet capture. 
+ */ + if (!(brdev->flags & IFF_PROMISC) && + !br_allowed_egress(vg, skb)) { + kfree_skb(skb); + return NET_RX_DROP; + } + + indev = skb->dev; + skb->dev = brdev; + skb = br_handle_vlan(br, NULL, vg, skb); + if (!skb) + return NET_RX_DROP; + /* update the multicast stats if the packet is IGMP/MLD */ + br_multicast_count(br, NULL, skb, br_multicast_igmp_type(skb), + BR_MCAST_DIR_TX); + + return NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_IN, + dev_net(indev), NULL, skb, indev, NULL, + br_netif_receive_skb); +} + +/* note: already called with rcu_read_lock */ +int br_handle_frame_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_bridge_port *p = br_port_get_rcu(skb->dev); + enum br_pkt_type pkt_type = BR_PKT_UNICAST; + struct net_bridge_fdb_entry *dst = NULL; + struct net_bridge_mcast_port *pmctx; + struct net_bridge_mdb_entry *mdst; + bool local_rcv, mcast_hit = false; + struct net_bridge_mcast *brmctx; + struct net_bridge_vlan *vlan; + struct net_bridge *br; + u16 vid = 0; + u8 state; + + if (!p) + goto drop; + + br = p->br; + + if (br_mst_is_enabled(br)) { + state = BR_STATE_FORWARDING; + } else { + if (p->state == BR_STATE_DISABLED) + goto drop; + + state = p->state; + } + + brmctx = &p->br->multicast_ctx; + pmctx = &p->multicast_ctx; + if (!br_allowed_ingress(p->br, nbp_vlan_group_rcu(p), skb, &vid, + &state, &vlan)) + goto out; + + if (p->flags & BR_PORT_LOCKED) { + struct net_bridge_fdb_entry *fdb_src = + br_fdb_find_rcu(br, eth_hdr(skb)->h_source, vid); + + if (!fdb_src || READ_ONCE(fdb_src->dst) != p || + test_bit(BR_FDB_LOCAL, &fdb_src->flags)) + goto drop; + } + + nbp_switchdev_frame_mark(p, skb); + + /* insert into forwarding database after filtering to avoid spoofing */ + if (p->flags & BR_LEARNING) + br_fdb_update(br, p, eth_hdr(skb)->h_source, vid, 0); + + local_rcv = !!(br->dev->flags & IFF_PROMISC); + if (is_multicast_ether_addr(eth_hdr(skb)->h_dest)) { + /* by definition the broadcast is also a multicast address */ + if (is_broadcast_ether_addr(eth_hdr(skb)->h_dest)) { + pkt_type = BR_PKT_BROADCAST; + local_rcv = true; + } else { + pkt_type = BR_PKT_MULTICAST; + if (br_multicast_rcv(&brmctx, &pmctx, vlan, skb, vid)) + goto drop; + } + } + + if (state == BR_STATE_LEARNING) + goto drop; + + BR_INPUT_SKB_CB(skb)->brdev = br->dev; + BR_INPUT_SKB_CB(skb)->src_port_isolated = !!(p->flags & BR_ISOLATED); + + if (IS_ENABLED(CONFIG_INET) && + (skb->protocol == htons(ETH_P_ARP) || + skb->protocol == htons(ETH_P_RARP))) { + br_do_proxy_suppress_arp(skb, br, vid, p); + } else if (IS_ENABLED(CONFIG_IPV6) && + skb->protocol == htons(ETH_P_IPV6) && + br_opt_get(br, BROPT_NEIGH_SUPPRESS_ENABLED) && + pskb_may_pull(skb, sizeof(struct ipv6hdr) + + sizeof(struct nd_msg)) && + ipv6_hdr(skb)->nexthdr == IPPROTO_ICMPV6) { + struct nd_msg *msg, _msg; + + msg = br_is_nd_neigh_msg(skb, &_msg); + if (msg) + br_do_suppress_nd(skb, br, vid, p, msg); + } + + switch (pkt_type) { + case BR_PKT_MULTICAST: + mdst = br_mdb_get(brmctx, skb, vid); + if ((mdst || BR_INPUT_SKB_CB_MROUTERS_ONLY(skb)) && + br_multicast_querier_exists(brmctx, eth_hdr(skb), mdst)) { + if ((mdst && mdst->host_joined) || + br_multicast_is_router(brmctx, skb)) { + local_rcv = true; + DEV_STATS_INC(br->dev, multicast); + } + mcast_hit = true; + } else { + local_rcv = true; + DEV_STATS_INC(br->dev, multicast); + } + break; + case BR_PKT_UNICAST: + dst = br_fdb_find_rcu(br, eth_hdr(skb)->h_dest, vid); + break; + default: + break; + } + + if (dst) { + unsigned long now = jiffies; + + if (test_bit(BR_FDB_LOCAL, 
&dst->flags)) + return br_pass_frame_up(skb); + + if (now != dst->used) + dst->used = now; + br_forward(dst->dst, skb, local_rcv, false); + } else { + if (!mcast_hit) + br_flood(br, skb, pkt_type, local_rcv, false); + else + br_multicast_flood(mdst, skb, brmctx, local_rcv, false); + } + + if (local_rcv) + return br_pass_frame_up(skb); + +out: + return 0; +drop: + kfree_skb(skb); + goto out; +} +EXPORT_SYMBOL_GPL(br_handle_frame_finish); + +static void __br_handle_local_finish(struct sk_buff *skb) +{ + struct net_bridge_port *p = br_port_get_rcu(skb->dev); + u16 vid = 0; + + /* check if vlan is allowed, to avoid spoofing */ + if ((p->flags & BR_LEARNING) && + nbp_state_should_learn(p) && + !br_opt_get(p->br, BROPT_NO_LL_LEARN) && + br_should_learn(p, skb, &vid)) + br_fdb_update(p->br, p, eth_hdr(skb)->h_source, vid, 0); +} + +/* note: already called with rcu_read_lock */ +static int br_handle_local_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + __br_handle_local_finish(skb); + + /* return 1 to signal the okfn() was called so it's ok to use the skb */ + return 1; +} + +static int nf_hook_bridge_pre(struct sk_buff *skb, struct sk_buff **pskb) +{ +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + struct nf_hook_entries *e = NULL; + struct nf_hook_state state; + unsigned int verdict, i; + struct net *net; + int ret; + + net = dev_net(skb->dev); +#ifdef HAVE_JUMP_LABEL + if (!static_key_false(&nf_hooks_needed[NFPROTO_BRIDGE][NF_BR_PRE_ROUTING])) + goto frame_finish; +#endif + + e = rcu_dereference(net->nf.hooks_bridge[NF_BR_PRE_ROUTING]); + if (!e) + goto frame_finish; + + nf_hook_state_init(&state, NF_BR_PRE_ROUTING, + NFPROTO_BRIDGE, skb->dev, NULL, NULL, + net, br_handle_frame_finish); + + for (i = 0; i < e->num_hook_entries; i++) { + verdict = nf_hook_entry_hookfn(&e->hooks[i], skb, &state); + switch (verdict & NF_VERDICT_MASK) { + case NF_ACCEPT: + if (BR_INPUT_SKB_CB(skb)->br_netfilter_broute) { + *pskb = skb; + return RX_HANDLER_PASS; + } + break; + case NF_DROP: + kfree_skb(skb); + return RX_HANDLER_CONSUMED; + case NF_QUEUE: + ret = nf_queue(skb, &state, i, verdict); + if (ret == 1) + continue; + return RX_HANDLER_CONSUMED; + default: /* STOLEN */ + return RX_HANDLER_CONSUMED; + } + } +frame_finish: + net = dev_net(skb->dev); + br_handle_frame_finish(net, NULL, skb); +#else + br_handle_frame_finish(dev_net(skb->dev), NULL, skb); +#endif + return RX_HANDLER_CONSUMED; +} + +/* Return 0 if the frame was not processed otherwise 1 + * note: already called with rcu_read_lock + */ +static int br_process_frame_type(struct net_bridge_port *p, + struct sk_buff *skb) +{ + struct br_frame_type *tmp; + + hlist_for_each_entry_rcu(tmp, &p->br->frame_type_list, list) + if (unlikely(tmp->type == skb->protocol)) + return tmp->frame_handler(p, skb); + + return 0; +} + +/* + * Return NULL if skb is handled + * note: already called with rcu_read_lock + */ +static rx_handler_result_t br_handle_frame(struct sk_buff **pskb) +{ + struct net_bridge_port *p; + struct sk_buff *skb = *pskb; + const unsigned char *dest = eth_hdr(skb)->h_dest; + + if (unlikely(skb->pkt_type == PACKET_LOOPBACK)) + return RX_HANDLER_PASS; + + if (!is_valid_ether_addr(eth_hdr(skb)->h_source)) + goto drop; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return RX_HANDLER_CONSUMED; + + memset(skb->cb, 0, sizeof(struct br_input_skb_cb)); + + p = br_port_get_rcu(skb->dev); + if (p->flags & BR_VLAN_TUNNEL) + br_handle_ingress_vlan_tunnel(skb, p, nbp_vlan_group_rcu(p)); + + if (unlikely(is_link_local_ether_addr(dest))) { + 
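		/* Worked example for the mask test below (illustrative): to have
		 * the bridge forward LLDP frames (01-80-C2-00-00-0E), bit
		 * 1u << 0x0E (0x4000) must be set in the group forward mask, so
		 * that fwd_mask & (1u << dest[5]) is non-zero for dest[5] == 0x0E.
		 */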
u16 fwd_mask = p->br->group_fwd_mask_required; + + /* + * See IEEE 802.1D Table 7-10 Reserved addresses + * + * Assignment Value + * Bridge Group Address 01-80-C2-00-00-00 + * (MAC Control) 802.3 01-80-C2-00-00-01 + * (Link Aggregation) 802.3 01-80-C2-00-00-02 + * 802.1X PAE address 01-80-C2-00-00-03 + * + * 802.1AB LLDP 01-80-C2-00-00-0E + * + * Others reserved for future standardization + */ + fwd_mask |= p->group_fwd_mask; + switch (dest[5]) { + case 0x00: /* Bridge Group Address */ + /* If STP is turned off, + then must forward to keep loop detection */ + if (p->br->stp_enabled == BR_NO_STP || + fwd_mask & (1u << dest[5])) + goto forward; + *pskb = skb; + __br_handle_local_finish(skb); + return RX_HANDLER_PASS; + + case 0x01: /* IEEE MAC (Pause) */ + goto drop; + + case 0x0E: /* 802.1AB LLDP */ + fwd_mask |= p->br->group_fwd_mask; + if (fwd_mask & (1u << dest[5])) + goto forward; + *pskb = skb; + __br_handle_local_finish(skb); + return RX_HANDLER_PASS; + + default: + /* Allow selective forwarding for most other protocols */ + fwd_mask |= p->br->group_fwd_mask; + if (fwd_mask & (1u << dest[5])) + goto forward; + } + + /* The else clause should be hit when nf_hook(): + * - returns < 0 (drop/error) + * - returns = 0 (stolen/nf_queue) + * Thus return 1 from the okfn() to signal the skb is ok to pass + */ + if (NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_IN, + dev_net(skb->dev), NULL, skb, skb->dev, NULL, + br_handle_local_finish) == 1) { + return RX_HANDLER_PASS; + } else { + return RX_HANDLER_CONSUMED; + } + } + + if (unlikely(br_process_frame_type(p, skb))) + return RX_HANDLER_PASS; + +forward: + if (br_mst_is_enabled(p->br)) + goto defer_stp_filtering; + + switch (p->state) { + case BR_STATE_FORWARDING: + case BR_STATE_LEARNING: +defer_stp_filtering: + if (ether_addr_equal(p->br->dev->dev_addr, dest)) + skb->pkt_type = PACKET_HOST; + + return nf_hook_bridge_pre(skb, pskb); + default: +drop: + kfree_skb(skb); + } + return RX_HANDLER_CONSUMED; +} + +/* This function has no purpose other than to appease the br_port_get_rcu/rtnl + * helpers which identify bridged ports according to the rx_handler installed + * on them (so there _needs_ to be a bridge rx_handler even if we don't need it + * to do anything useful). This bridge won't support traffic to/from the stack, + * but only hardware bridging. So return RX_HANDLER_PASS so we don't steal + * frames from the ETH_P_XDSA packet_type handler. 
+ */ +static rx_handler_result_t br_handle_frame_dummy(struct sk_buff **pskb) +{ + return RX_HANDLER_PASS; +} + +rx_handler_func_t *br_get_rx_handler(const struct net_device *dev) +{ + if (netdev_uses_dsa(dev)) + return br_handle_frame_dummy; + + return br_handle_frame; +} + +void br_add_frame(struct net_bridge *br, struct br_frame_type *ft) +{ + hlist_add_head_rcu(&ft->list, &br->frame_type_list); +} + +void br_del_frame(struct net_bridge *br, struct br_frame_type *ft) +{ + struct br_frame_type *tmp; + + hlist_for_each_entry(tmp, &br->frame_type_list, list) + if (ft == tmp) { + hlist_del_rcu(&ft->list); + return; + } +} diff --git a/net/bridge/br_ioctl.c b/net/bridge/br_ioctl.c new file mode 100644 index 000000000..f213ed108 --- /dev/null +++ b/net/bridge/br_ioctl.c @@ -0,0 +1,440 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Ioctl handler + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "br_private.h" + +static int get_bridge_ifindices(struct net *net, int *indices, int num) +{ + struct net_device *dev; + int i = 0; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + if (i >= num) + break; + if (netif_is_bridge_master(dev)) + indices[i++] = dev->ifindex; + } + rcu_read_unlock(); + + return i; +} + +/* called with RTNL */ +static void get_port_ifindices(struct net_bridge *br, int *ifindices, int num) +{ + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) { + if (p->port_no < num) + ifindices[p->port_no] = p->dev->ifindex; + } +} + +/* + * Format up to a page worth of forwarding table entries + * userbuf -- where to copy result + * maxnum -- maximum number of entries desired + * (limited to a page for sanity) + * offset -- number of records to skip + */ +static int get_fdb_entries(struct net_bridge *br, void __user *userbuf, + unsigned long maxnum, unsigned long offset) +{ + int num; + void *buf; + size_t size; + + /* Clamp size to PAGE_SIZE, test maxnum to avoid overflow */ + if (maxnum > PAGE_SIZE/sizeof(struct __fdb_entry)) + maxnum = PAGE_SIZE/sizeof(struct __fdb_entry); + + size = maxnum * sizeof(struct __fdb_entry); + + buf = kmalloc(size, GFP_USER); + if (!buf) + return -ENOMEM; + + num = br_fdb_fillbuf(br, buf, maxnum, offset); + if (num > 0) { + if (copy_to_user(userbuf, buf, + array_size(num, sizeof(struct __fdb_entry)))) + num = -EFAULT; + } + kfree(buf); + + return num; +} + +/* called with RTNL */ +static int add_del_if(struct net_bridge *br, int ifindex, int isadd) +{ + struct net *net = dev_net(br->dev); + struct net_device *dev; + int ret; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + dev = __dev_get_by_index(net, ifindex); + if (dev == NULL) + return -EINVAL; + + if (isadd) + ret = br_add_if(br, dev, NULL); + else + ret = br_del_if(br, dev); + + return ret; +} + +#define BR_UARGS_MAX 4 +static int br_dev_read_uargs(unsigned long *args, size_t nr_args, + void __user **argp, void __user *data) +{ + int ret; + + if (nr_args < 2 || nr_args > BR_UARGS_MAX) + return -EINVAL; + + if (in_compat_syscall()) { + unsigned int cargs[BR_UARGS_MAX]; + int i; + + ret = copy_from_user(cargs, data, nr_args * sizeof(*cargs)); + if (ret) + goto fault; + + for (i = 0; i < nr_args; ++i) + args[i] = cargs[i]; + + *argp = compat_ptr(args[1]); + } else { + ret = copy_from_user(args, data, nr_args * sizeof(*args)); + if (ret) + goto fault; + *argp = (void __user *)args[1]; + } + + return 0; +fault: + return 
-EFAULT; +} + +/* + * Legacy ioctl's through SIOCDEVPRIVATE + * This interface is deprecated because it was too difficult + * to do the translation for 32/64bit ioctl compatibility. + */ +int br_dev_siocdevprivate(struct net_device *dev, struct ifreq *rq, + void __user *data, int cmd) +{ + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_port *p = NULL; + unsigned long args[4]; + void __user *argp; + int ret; + + ret = br_dev_read_uargs(args, ARRAY_SIZE(args), &argp, data); + if (ret) + return ret; + + switch (args[0]) { + case BRCTL_ADD_IF: + case BRCTL_DEL_IF: + return add_del_if(br, args[1], args[0] == BRCTL_ADD_IF); + + case BRCTL_GET_BRIDGE_INFO: + { + struct __bridge_info b; + + memset(&b, 0, sizeof(struct __bridge_info)); + rcu_read_lock(); + memcpy(&b.designated_root, &br->designated_root, 8); + memcpy(&b.bridge_id, &br->bridge_id, 8); + b.root_path_cost = br->root_path_cost; + b.max_age = jiffies_to_clock_t(br->max_age); + b.hello_time = jiffies_to_clock_t(br->hello_time); + b.forward_delay = br->forward_delay; + b.bridge_max_age = br->bridge_max_age; + b.bridge_hello_time = br->bridge_hello_time; + b.bridge_forward_delay = jiffies_to_clock_t(br->bridge_forward_delay); + b.topology_change = br->topology_change; + b.topology_change_detected = br->topology_change_detected; + b.root_port = br->root_port; + + b.stp_enabled = (br->stp_enabled != BR_NO_STP); + b.ageing_time = jiffies_to_clock_t(br->ageing_time); + b.hello_timer_value = br_timer_value(&br->hello_timer); + b.tcn_timer_value = br_timer_value(&br->tcn_timer); + b.topology_change_timer_value = br_timer_value(&br->topology_change_timer); + b.gc_timer_value = br_timer_value(&br->gc_work.timer); + rcu_read_unlock(); + + if (copy_to_user((void __user *)args[1], &b, sizeof(b))) + return -EFAULT; + + return 0; + } + + case BRCTL_GET_PORT_LIST: + { + int num, *indices; + + num = args[2]; + if (num < 0) + return -EINVAL; + if (num == 0) + num = 256; + if (num > BR_MAX_PORTS) + num = BR_MAX_PORTS; + + indices = kcalloc(num, sizeof(int), GFP_KERNEL); + if (indices == NULL) + return -ENOMEM; + + get_port_ifindices(br, indices, num); + if (copy_to_user(argp, indices, array_size(num, sizeof(int)))) + num = -EFAULT; + kfree(indices); + return num; + } + + case BRCTL_SET_BRIDGE_FORWARD_DELAY: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = br_set_forward_delay(br, args[1]); + break; + + case BRCTL_SET_BRIDGE_HELLO_TIME: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = br_set_hello_time(br, args[1]); + break; + + case BRCTL_SET_BRIDGE_MAX_AGE: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = br_set_max_age(br, args[1]); + break; + + case BRCTL_SET_AGEING_TIME: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = br_set_ageing_time(br, args[1]); + break; + + case BRCTL_GET_PORT_INFO: + { + struct __port_info p; + struct net_bridge_port *pt; + + rcu_read_lock(); + if ((pt = br_get_port(br, args[2])) == NULL) { + rcu_read_unlock(); + return -EINVAL; + } + + memset(&p, 0, sizeof(struct __port_info)); + memcpy(&p.designated_root, &pt->designated_root, 8); + memcpy(&p.designated_bridge, &pt->designated_bridge, 8); + p.port_id = pt->port_id; + p.designated_port = pt->designated_port; + p.path_cost = pt->path_cost; + p.designated_cost = pt->designated_cost; + p.state = pt->state; + p.top_change_ack = pt->topology_change_ack; + p.config_pending = pt->config_pending; + p.message_age_timer_value = 
br_timer_value(&pt->message_age_timer); + p.forward_delay_timer_value = br_timer_value(&pt->forward_delay_timer); + p.hold_timer_value = br_timer_value(&pt->hold_timer); + + rcu_read_unlock(); + + if (copy_to_user(argp, &p, sizeof(p))) + return -EFAULT; + + return 0; + } + + case BRCTL_SET_BRIDGE_STP_STATE: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = br_stp_set_enabled(br, args[1], NULL); + break; + + case BRCTL_SET_BRIDGE_PRIORITY: + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + br_stp_set_bridge_priority(br, args[1]); + ret = 0; + break; + + case BRCTL_SET_PORT_PRIORITY: + { + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + spin_lock_bh(&br->lock); + if ((p = br_get_port(br, args[1])) == NULL) + ret = -EINVAL; + else + ret = br_stp_set_port_priority(p, args[2]); + spin_unlock_bh(&br->lock); + break; + } + + case BRCTL_SET_PATH_COST: + { + if (!ns_capable(dev_net(dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + spin_lock_bh(&br->lock); + if ((p = br_get_port(br, args[1])) == NULL) + ret = -EINVAL; + else + ret = br_stp_set_path_cost(p, args[2]); + spin_unlock_bh(&br->lock); + break; + } + + case BRCTL_GET_FDB_ENTRIES: + return get_fdb_entries(br, argp, args[2], args[3]); + + default: + ret = -EOPNOTSUPP; + } + + if (!ret) { + if (p) + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + else + netdev_state_change(br->dev); + } + + return ret; +} + +static int old_deviceless(struct net *net, void __user *data) +{ + unsigned long args[3]; + void __user *argp; + int ret; + + ret = br_dev_read_uargs(args, ARRAY_SIZE(args), &argp, data); + if (ret) + return ret; + + switch (args[0]) { + case BRCTL_GET_VERSION: + return BRCTL_VERSION; + + case BRCTL_GET_BRIDGES: + { + int *indices; + int ret = 0; + + if (args[2] >= 2048) + return -ENOMEM; + indices = kcalloc(args[2], sizeof(int), GFP_KERNEL); + if (indices == NULL) + return -ENOMEM; + + args[2] = get_bridge_ifindices(net, indices, args[2]); + + ret = copy_to_user(argp, indices, + array_size(args[2], sizeof(int))) + ? 
-EFAULT : args[2]; + + kfree(indices); + return ret; + } + + case BRCTL_ADD_BRIDGE: + case BRCTL_DEL_BRIDGE: + { + char buf[IFNAMSIZ]; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(buf, argp, IFNAMSIZ)) + return -EFAULT; + + buf[IFNAMSIZ-1] = 0; + + if (args[0] == BRCTL_ADD_BRIDGE) + return br_add_bridge(net, buf); + + return br_del_bridge(net, buf); + } + } + + return -EOPNOTSUPP; +} + +int br_ioctl_stub(struct net *net, struct net_bridge *br, unsigned int cmd, + struct ifreq *ifr, void __user *uarg) +{ + int ret = -EOPNOTSUPP; + + rtnl_lock(); + + switch (cmd) { + case SIOCGIFBR: + case SIOCSIFBR: + ret = old_deviceless(net, uarg); + break; + case SIOCBRADDBR: + case SIOCBRDELBR: + { + char buf[IFNAMSIZ]; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) { + ret = -EPERM; + break; + } + + if (copy_from_user(buf, uarg, IFNAMSIZ)) { + ret = -EFAULT; + break; + } + + buf[IFNAMSIZ-1] = 0; + if (cmd == SIOCBRADDBR) + ret = br_add_bridge(net, buf); + else + ret = br_del_bridge(net, buf); + } + break; + case SIOCBRADDIF: + case SIOCBRDELIF: + ret = add_del_if(br, ifr->ifr_ifindex, cmd == SIOCBRADDIF); + break; + } + + rtnl_unlock(); + + return ret; +} diff --git a/net/bridge/br_mdb.c b/net/bridge/br_mdb.c new file mode 100644 index 000000000..589ff497d --- /dev/null +++ b/net/bridge/br_mdb.c @@ -0,0 +1,1164 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#endif + +#include "br_private.h" + +static bool +br_ip4_rports_get_timer(struct net_bridge_mcast_port *pmctx, + unsigned long *timer) +{ + *timer = br_timer_value(&pmctx->ip4_mc_router_timer); + return !hlist_unhashed(&pmctx->ip4_rlist); +} + +static bool +br_ip6_rports_get_timer(struct net_bridge_mcast_port *pmctx, + unsigned long *timer) +{ +#if IS_ENABLED(CONFIG_IPV6) + *timer = br_timer_value(&pmctx->ip6_mc_router_timer); + return !hlist_unhashed(&pmctx->ip6_rlist); +#else + *timer = 0; + return false; +#endif +} + +static size_t __br_rports_one_size(void) +{ + return nla_total_size(sizeof(u32)) + /* MDBA_ROUTER_PORT */ + nla_total_size(sizeof(u32)) + /* MDBA_ROUTER_PATTR_TIMER */ + nla_total_size(sizeof(u8)) + /* MDBA_ROUTER_PATTR_TYPE */ + nla_total_size(sizeof(u32)) + /* MDBA_ROUTER_PATTR_INET_TIMER */ + nla_total_size(sizeof(u32)) + /* MDBA_ROUTER_PATTR_INET6_TIMER */ + nla_total_size(sizeof(u32)); /* MDBA_ROUTER_PATTR_VID */ +} + +size_t br_rports_size(const struct net_bridge_mcast *brmctx) +{ + struct net_bridge_mcast_port *pmctx; + size_t size = nla_total_size(0); /* MDBA_ROUTER */ + + rcu_read_lock(); + hlist_for_each_entry_rcu(pmctx, &brmctx->ip4_mc_router_list, + ip4_rlist) + size += __br_rports_one_size(); + +#if IS_ENABLED(CONFIG_IPV6) + hlist_for_each_entry_rcu(pmctx, &brmctx->ip6_mc_router_list, + ip6_rlist) + size += __br_rports_one_size(); +#endif + rcu_read_unlock(); + + return size; +} + +int br_rports_fill_info(struct sk_buff *skb, + const struct net_bridge_mcast *brmctx) +{ + u16 vid = brmctx->vlan ? 
brmctx->vlan->vid : 0; + bool have_ip4_mc_rtr, have_ip6_mc_rtr; + unsigned long ip4_timer, ip6_timer; + struct nlattr *nest, *port_nest; + struct net_bridge_port *p; + + if (!brmctx->multicast_router || !br_rports_have_mc_router(brmctx)) + return 0; + + nest = nla_nest_start_noflag(skb, MDBA_ROUTER); + if (nest == NULL) + return -EMSGSIZE; + + list_for_each_entry_rcu(p, &brmctx->br->port_list, list) { + struct net_bridge_mcast_port *pmctx; + + if (vid) { + struct net_bridge_vlan *v; + + v = br_vlan_find(nbp_vlan_group(p), vid); + if (!v) + continue; + pmctx = &v->port_mcast_ctx; + } else { + pmctx = &p->multicast_ctx; + } + + have_ip4_mc_rtr = br_ip4_rports_get_timer(pmctx, &ip4_timer); + have_ip6_mc_rtr = br_ip6_rports_get_timer(pmctx, &ip6_timer); + + if (!have_ip4_mc_rtr && !have_ip6_mc_rtr) + continue; + + port_nest = nla_nest_start_noflag(skb, MDBA_ROUTER_PORT); + if (!port_nest) + goto fail; + + if (nla_put_nohdr(skb, sizeof(u32), &p->dev->ifindex) || + nla_put_u32(skb, MDBA_ROUTER_PATTR_TIMER, + max(ip4_timer, ip6_timer)) || + nla_put_u8(skb, MDBA_ROUTER_PATTR_TYPE, + p->multicast_ctx.multicast_router) || + (have_ip4_mc_rtr && + nla_put_u32(skb, MDBA_ROUTER_PATTR_INET_TIMER, + ip4_timer)) || + (have_ip6_mc_rtr && + nla_put_u32(skb, MDBA_ROUTER_PATTR_INET6_TIMER, + ip6_timer)) || + (vid && nla_put_u16(skb, MDBA_ROUTER_PATTR_VID, vid))) { + nla_nest_cancel(skb, port_nest); + goto fail; + } + nla_nest_end(skb, port_nest); + } + + nla_nest_end(skb, nest); + return 0; +fail: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static void __mdb_entry_fill_flags(struct br_mdb_entry *e, unsigned char flags) +{ + e->state = flags & MDB_PG_FLAGS_PERMANENT; + e->flags = 0; + if (flags & MDB_PG_FLAGS_OFFLOAD) + e->flags |= MDB_FLAGS_OFFLOAD; + if (flags & MDB_PG_FLAGS_FAST_LEAVE) + e->flags |= MDB_FLAGS_FAST_LEAVE; + if (flags & MDB_PG_FLAGS_STAR_EXCL) + e->flags |= MDB_FLAGS_STAR_EXCL; + if (flags & MDB_PG_FLAGS_BLOCKED) + e->flags |= MDB_FLAGS_BLOCKED; +} + +static void __mdb_entry_to_br_ip(struct br_mdb_entry *entry, struct br_ip *ip, + struct nlattr **mdb_attrs) +{ + memset(ip, 0, sizeof(struct br_ip)); + ip->vid = entry->vid; + ip->proto = entry->addr.proto; + switch (ip->proto) { + case htons(ETH_P_IP): + ip->dst.ip4 = entry->addr.u.ip4; + if (mdb_attrs && mdb_attrs[MDBE_ATTR_SOURCE]) + ip->src.ip4 = nla_get_in_addr(mdb_attrs[MDBE_ATTR_SOURCE]); + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + ip->dst.ip6 = entry->addr.u.ip6; + if (mdb_attrs && mdb_attrs[MDBE_ATTR_SOURCE]) + ip->src.ip6 = nla_get_in6_addr(mdb_attrs[MDBE_ATTR_SOURCE]); + break; +#endif + default: + ether_addr_copy(ip->dst.mac_addr, entry->addr.u.mac_addr); + } + +} + +static int __mdb_fill_srcs(struct sk_buff *skb, + struct net_bridge_port_group *p) +{ + struct net_bridge_group_src *ent; + struct nlattr *nest, *nest_ent; + + if (hlist_empty(&p->src_list)) + return 0; + + nest = nla_nest_start(skb, MDBA_MDB_EATTR_SRC_LIST); + if (!nest) + return -EMSGSIZE; + + hlist_for_each_entry_rcu(ent, &p->src_list, node, + lockdep_is_held(&p->key.port->br->multicast_lock)) { + nest_ent = nla_nest_start(skb, MDBA_MDB_SRCLIST_ENTRY); + if (!nest_ent) + goto out_cancel_err; + switch (ent->addr.proto) { + case htons(ETH_P_IP): + if (nla_put_in_addr(skb, MDBA_MDB_SRCATTR_ADDRESS, + ent->addr.src.ip4)) { + nla_nest_cancel(skb, nest_ent); + goto out_cancel_err; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + if (nla_put_in6_addr(skb, MDBA_MDB_SRCATTR_ADDRESS, + &ent->addr.src.ip6)) { + 
nla_nest_cancel(skb, nest_ent); + goto out_cancel_err; + } + break; +#endif + default: + nla_nest_cancel(skb, nest_ent); + continue; + } + if (nla_put_u32(skb, MDBA_MDB_SRCATTR_TIMER, + br_timer_value(&ent->timer))) { + nla_nest_cancel(skb, nest_ent); + goto out_cancel_err; + } + nla_nest_end(skb, nest_ent); + } + + nla_nest_end(skb, nest); + + return 0; + +out_cancel_err: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int __mdb_fill_info(struct sk_buff *skb, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *p) +{ + bool dump_srcs_mode = false; + struct timer_list *mtimer; + struct nlattr *nest_ent; + struct br_mdb_entry e; + u8 flags = 0; + int ifindex; + + memset(&e, 0, sizeof(e)); + if (p) { + ifindex = p->key.port->dev->ifindex; + mtimer = &p->timer; + flags = p->flags; + } else { + ifindex = mp->br->dev->ifindex; + mtimer = &mp->timer; + } + + __mdb_entry_fill_flags(&e, flags); + e.ifindex = ifindex; + e.vid = mp->addr.vid; + if (mp->addr.proto == htons(ETH_P_IP)) { + e.addr.u.ip4 = mp->addr.dst.ip4; +#if IS_ENABLED(CONFIG_IPV6) + } else if (mp->addr.proto == htons(ETH_P_IPV6)) { + e.addr.u.ip6 = mp->addr.dst.ip6; +#endif + } else { + ether_addr_copy(e.addr.u.mac_addr, mp->addr.dst.mac_addr); + e.state = MDB_PG_FLAGS_PERMANENT; + } + e.addr.proto = mp->addr.proto; + nest_ent = nla_nest_start_noflag(skb, + MDBA_MDB_ENTRY_INFO); + if (!nest_ent) + return -EMSGSIZE; + + if (nla_put_nohdr(skb, sizeof(e), &e) || + nla_put_u32(skb, + MDBA_MDB_EATTR_TIMER, + br_timer_value(mtimer))) + goto nest_err; + + switch (mp->addr.proto) { + case htons(ETH_P_IP): + dump_srcs_mode = !!(mp->br->multicast_ctx.multicast_igmp_version == 3); + if (mp->addr.src.ip4) { + if (nla_put_in_addr(skb, MDBA_MDB_EATTR_SOURCE, + mp->addr.src.ip4)) + goto nest_err; + break; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + dump_srcs_mode = !!(mp->br->multicast_ctx.multicast_mld_version == 2); + if (!ipv6_addr_any(&mp->addr.src.ip6)) { + if (nla_put_in6_addr(skb, MDBA_MDB_EATTR_SOURCE, + &mp->addr.src.ip6)) + goto nest_err; + break; + } + break; +#endif + default: + ether_addr_copy(e.addr.u.mac_addr, mp->addr.dst.mac_addr); + } + if (p) { + if (nla_put_u8(skb, MDBA_MDB_EATTR_RTPROT, p->rt_protocol)) + goto nest_err; + if (dump_srcs_mode && + (__mdb_fill_srcs(skb, p) || + nla_put_u8(skb, MDBA_MDB_EATTR_GROUP_MODE, + p->filter_mode))) + goto nest_err; + } + nla_nest_end(skb, nest_ent); + + return 0; + +nest_err: + nla_nest_cancel(skb, nest_ent); + return -EMSGSIZE; +} + +static int br_mdb_fill_info(struct sk_buff *skb, struct netlink_callback *cb, + struct net_device *dev) +{ + int idx = 0, s_idx = cb->args[1], err = 0, pidx = 0, s_pidx = cb->args[2]; + struct net_bridge *br = netdev_priv(dev); + struct net_bridge_mdb_entry *mp; + struct nlattr *nest, *nest2; + + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) + return 0; + + nest = nla_nest_start_noflag(skb, MDBA_MDB); + if (nest == NULL) + return -EMSGSIZE; + + hlist_for_each_entry_rcu(mp, &br->mdb_list, mdb_node) { + struct net_bridge_port_group *p; + struct net_bridge_port_group __rcu **pp; + + if (idx < s_idx) + goto skip; + + nest2 = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY); + if (!nest2) { + err = -EMSGSIZE; + break; + } + + if (!s_pidx && mp->host_joined) { + err = __mdb_fill_info(skb, mp, NULL); + if (err) { + nla_nest_cancel(skb, nest2); + break; + } + } + + for (pp = &mp->ports; (p = rcu_dereference(*pp)) != NULL; + pp = &p->next) { + if (!p->key.port) + continue; + if (pidx < s_pidx) + goto skip_pg; + + err = 
__mdb_fill_info(skb, mp, p); + if (err) { + nla_nest_end(skb, nest2); + goto out; + } +skip_pg: + pidx++; + } + pidx = 0; + s_pidx = 0; + nla_nest_end(skb, nest2); +skip: + idx++; + } + +out: + cb->args[1] = idx; + cb->args[2] = pidx; + nla_nest_end(skb, nest); + return err; +} + +static int br_mdb_valid_dump_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct br_port_msg *bpm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*bpm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for mdb dump request"); + return -EINVAL; + } + + bpm = nlmsg_data(nlh); + if (bpm->ifindex) { + NL_SET_ERR_MSG_MOD(extack, "Filtering by device index is not supported for mdb dump request"); + return -EINVAL; + } + if (nlmsg_attrlen(nlh, sizeof(*bpm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in mdb dump request"); + return -EINVAL; + } + + return 0; +} + +static int br_mdb_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net_device *dev; + struct net *net = sock_net(skb->sk); + struct nlmsghdr *nlh = NULL; + int idx = 0, s_idx; + + if (cb->strict_check) { + int err = br_mdb_valid_dump_req(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + + s_idx = cb->args[0]; + + rcu_read_lock(); + + cb->seq = net->dev_base_seq; + + for_each_netdev_rcu(net, dev) { + if (netif_is_bridge_master(dev)) { + struct net_bridge *br = netdev_priv(dev); + struct br_port_msg *bpm; + + if (idx < s_idx) + goto skip; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_GETMDB, + sizeof(*bpm), NLM_F_MULTI); + if (nlh == NULL) + break; + + bpm = nlmsg_data(nlh); + memset(bpm, 0, sizeof(*bpm)); + bpm->ifindex = dev->ifindex; + if (br_mdb_fill_info(skb, cb, dev) < 0) + goto out; + if (br_rports_fill_info(skb, &br->multicast_ctx) < 0) + goto out; + + cb->args[1] = 0; + nlmsg_end(skb, nlh); + skip: + idx++; + } + } + +out: + if (nlh) + nlmsg_end(skb, nlh); + rcu_read_unlock(); + cb->args[0] = idx; + return skb->len; +} + +static int nlmsg_populate_mdb_fill(struct sk_buff *skb, + struct net_device *dev, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + int type) +{ + struct nlmsghdr *nlh; + struct br_port_msg *bpm; + struct nlattr *nest, *nest2; + + nlh = nlmsg_put(skb, 0, 0, type, sizeof(*bpm), 0); + if (!nlh) + return -EMSGSIZE; + + bpm = nlmsg_data(nlh); + memset(bpm, 0, sizeof(*bpm)); + bpm->family = AF_BRIDGE; + bpm->ifindex = dev->ifindex; + nest = nla_nest_start_noflag(skb, MDBA_MDB); + if (nest == NULL) + goto cancel; + nest2 = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY); + if (nest2 == NULL) + goto end; + + if (__mdb_fill_info(skb, mp, pg)) + goto end; + + nla_nest_end(skb, nest2); + nla_nest_end(skb, nest); + nlmsg_end(skb, nlh); + return 0; + +end: + nla_nest_end(skb, nest); +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static size_t rtnl_mdb_nlmsg_size(struct net_bridge_port_group *pg) +{ + size_t nlmsg_size = NLMSG_ALIGN(sizeof(struct br_port_msg)) + + nla_total_size(sizeof(struct br_mdb_entry)) + + nla_total_size(sizeof(u32)); + struct net_bridge_group_src *ent; + size_t addr_size = 0; + + if (!pg) + goto out; + + /* MDBA_MDB_EATTR_RTPROT */ + nlmsg_size += nla_total_size(sizeof(u8)); + + switch (pg->key.addr.proto) { + case htons(ETH_P_IP): + /* MDBA_MDB_EATTR_SOURCE */ + if (pg->key.addr.src.ip4) + nlmsg_size += nla_total_size(sizeof(__be32)); + if (pg->key.port->br->multicast_ctx.multicast_igmp_version == 2) + goto out; + addr_size = sizeof(__be32); + break; +#if IS_ENABLED(CONFIG_IPV6) + case 
htons(ETH_P_IPV6): + /* MDBA_MDB_EATTR_SOURCE */ + if (!ipv6_addr_any(&pg->key.addr.src.ip6)) + nlmsg_size += nla_total_size(sizeof(struct in6_addr)); + if (pg->key.port->br->multicast_ctx.multicast_mld_version == 1) + goto out; + addr_size = sizeof(struct in6_addr); + break; +#endif + } + + /* MDBA_MDB_EATTR_GROUP_MODE */ + nlmsg_size += nla_total_size(sizeof(u8)); + + /* MDBA_MDB_EATTR_SRC_LIST nested attr */ + if (!hlist_empty(&pg->src_list)) + nlmsg_size += nla_total_size(0); + + hlist_for_each_entry(ent, &pg->src_list, node) { + /* MDBA_MDB_SRCLIST_ENTRY nested attr + + * MDBA_MDB_SRCATTR_ADDRESS + MDBA_MDB_SRCATTR_TIMER + */ + nlmsg_size += nla_total_size(0) + + nla_total_size(addr_size) + + nla_total_size(sizeof(u32)); + } +out: + return nlmsg_size; +} + +void br_mdb_notify(struct net_device *dev, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + int type) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -ENOBUFS; + + br_switchdev_mdb_notify(dev, mp, pg, type); + + skb = nlmsg_new(rtnl_mdb_nlmsg_size(pg), GFP_ATOMIC); + if (!skb) + goto errout; + + err = nlmsg_populate_mdb_fill(skb, dev, mp, pg, type); + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_ATOMIC); + return; +errout: + rtnl_set_sk_err(net, RTNLGRP_MDB, err); +} + +static int nlmsg_populate_rtr_fill(struct sk_buff *skb, + struct net_device *dev, + int ifindex, u16 vid, u32 pid, + u32 seq, int type, unsigned int flags) +{ + struct nlattr *nest, *port_nest; + struct br_port_msg *bpm; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*bpm), 0); + if (!nlh) + return -EMSGSIZE; + + bpm = nlmsg_data(nlh); + memset(bpm, 0, sizeof(*bpm)); + bpm->family = AF_BRIDGE; + bpm->ifindex = dev->ifindex; + nest = nla_nest_start_noflag(skb, MDBA_ROUTER); + if (!nest) + goto cancel; + + port_nest = nla_nest_start_noflag(skb, MDBA_ROUTER_PORT); + if (!port_nest) + goto end; + if (nla_put_nohdr(skb, sizeof(u32), &ifindex)) { + nla_nest_cancel(skb, port_nest); + goto end; + } + if (vid && nla_put_u16(skb, MDBA_ROUTER_PATTR_VID, vid)) { + nla_nest_cancel(skb, port_nest); + goto end; + } + nla_nest_end(skb, port_nest); + + nla_nest_end(skb, nest); + nlmsg_end(skb, nlh); + return 0; + +end: + nla_nest_end(skb, nest); +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static inline size_t rtnl_rtr_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct br_port_msg)) + + nla_total_size(sizeof(__u32)) + + nla_total_size(sizeof(u16)); +} + +void br_rtr_notify(struct net_device *dev, struct net_bridge_mcast_port *pmctx, + int type) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -ENOBUFS; + int ifindex; + u16 vid; + + ifindex = pmctx ? pmctx->port->dev->ifindex : 0; + vid = pmctx && br_multicast_port_ctx_is_vlan(pmctx) ? 
pmctx->vlan->vid : + 0; + skb = nlmsg_new(rtnl_rtr_nlmsg_size(), GFP_ATOMIC); + if (!skb) + goto errout; + + err = nlmsg_populate_rtr_fill(skb, dev, ifindex, vid, 0, 0, type, + NTF_SELF); + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_ATOMIC); + return; + +errout: + rtnl_set_sk_err(net, RTNLGRP_MDB, err); +} + +static bool is_valid_mdb_entry(struct br_mdb_entry *entry, + struct netlink_ext_ack *extack) +{ + if (entry->ifindex == 0) { + NL_SET_ERR_MSG_MOD(extack, "Zero entry ifindex is not allowed"); + return false; + } + + if (entry->addr.proto == htons(ETH_P_IP)) { + if (!ipv4_is_multicast(entry->addr.u.ip4)) { + NL_SET_ERR_MSG_MOD(extack, "IPv4 entry group address is not multicast"); + return false; + } + if (ipv4_is_local_multicast(entry->addr.u.ip4)) { + NL_SET_ERR_MSG_MOD(extack, "IPv4 entry group address is local multicast"); + return false; + } +#if IS_ENABLED(CONFIG_IPV6) + } else if (entry->addr.proto == htons(ETH_P_IPV6)) { + if (ipv6_addr_is_ll_all_nodes(&entry->addr.u.ip6)) { + NL_SET_ERR_MSG_MOD(extack, "IPv6 entry group address is link-local all nodes"); + return false; + } +#endif + } else if (entry->addr.proto == 0) { + /* L2 mdb */ + if (!is_multicast_ether_addr(entry->addr.u.mac_addr)) { + NL_SET_ERR_MSG_MOD(extack, "L2 entry group is not multicast"); + return false; + } + } else { + NL_SET_ERR_MSG_MOD(extack, "Unknown entry protocol"); + return false; + } + + if (entry->state != MDB_PERMANENT && entry->state != MDB_TEMPORARY) { + NL_SET_ERR_MSG_MOD(extack, "Unknown entry state"); + return false; + } + if (entry->vid >= VLAN_VID_MASK) { + NL_SET_ERR_MSG_MOD(extack, "Invalid entry VLAN id"); + return false; + } + + return true; +} + +static bool is_valid_mdb_source(struct nlattr *attr, __be16 proto, + struct netlink_ext_ack *extack) +{ + switch (proto) { + case htons(ETH_P_IP): + if (nla_len(attr) != sizeof(struct in_addr)) { + NL_SET_ERR_MSG_MOD(extack, "IPv4 invalid source address length"); + return false; + } + if (ipv4_is_multicast(nla_get_in_addr(attr))) { + NL_SET_ERR_MSG_MOD(extack, "IPv4 multicast source address is not allowed"); + return false; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): { + struct in6_addr src; + + if (nla_len(attr) != sizeof(struct in6_addr)) { + NL_SET_ERR_MSG_MOD(extack, "IPv6 invalid source address length"); + return false; + } + src = nla_get_in6_addr(attr); + if (ipv6_addr_is_multicast(&src)) { + NL_SET_ERR_MSG_MOD(extack, "IPv6 multicast source address is not allowed"); + return false; + } + break; + } +#endif + default: + NL_SET_ERR_MSG_MOD(extack, "Invalid protocol used with source address"); + return false; + } + + return true; +} + +static const struct nla_policy br_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = { + [MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY, + sizeof(struct in_addr), + sizeof(struct in6_addr)), +}; + +static int br_mdb_parse(struct sk_buff *skb, struct nlmsghdr *nlh, + struct net_device **pdev, struct br_mdb_entry **pentry, + struct nlattr **mdb_attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct br_mdb_entry *entry; + struct br_port_msg *bpm; + struct nlattr *tb[MDBA_SET_ENTRY_MAX+1]; + struct net_device *dev; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*bpm), tb, + MDBA_SET_ENTRY_MAX, NULL, NULL); + if (err < 0) + return err; + + bpm = nlmsg_data(nlh); + if (bpm->ifindex == 0) { + NL_SET_ERR_MSG_MOD(extack, "Invalid bridge ifindex"); + return -EINVAL; + } + + dev = __dev_get_by_index(net, 
bpm->ifindex); + if (dev == NULL) { + NL_SET_ERR_MSG_MOD(extack, "Bridge device doesn't exist"); + return -ENODEV; + } + + if (!netif_is_bridge_master(dev)) { + NL_SET_ERR_MSG_MOD(extack, "Device is not a bridge"); + return -EOPNOTSUPP; + } + + *pdev = dev; + + if (!tb[MDBA_SET_ENTRY]) { + NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY attribute"); + return -EINVAL; + } + if (nla_len(tb[MDBA_SET_ENTRY]) != sizeof(struct br_mdb_entry)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid MDBA_SET_ENTRY attribute length"); + return -EINVAL; + } + + entry = nla_data(tb[MDBA_SET_ENTRY]); + if (!is_valid_mdb_entry(entry, extack)) + return -EINVAL; + *pentry = entry; + + if (tb[MDBA_SET_ENTRY_ATTRS]) { + err = nla_parse_nested(mdb_attrs, MDBE_ATTR_MAX, + tb[MDBA_SET_ENTRY_ATTRS], + br_mdbe_attrs_pol, extack); + if (err) + return err; + if (mdb_attrs[MDBE_ATTR_SOURCE] && + !is_valid_mdb_source(mdb_attrs[MDBE_ATTR_SOURCE], + entry->addr.proto, extack)) + return -EINVAL; + } else { + memset(mdb_attrs, 0, + sizeof(struct nlattr *) * (MDBE_ATTR_MAX + 1)); + } + + return 0; +} + +static struct net_bridge_mcast * +__br_mdb_choose_context(struct net_bridge *br, + const struct br_mdb_entry *entry, + struct netlink_ext_ack *extack) +{ + struct net_bridge_mcast *brmctx = NULL; + struct net_bridge_vlan *v; + + if (!br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED)) { + brmctx = &br->multicast_ctx; + goto out; + } + + if (!entry->vid) { + NL_SET_ERR_MSG_MOD(extack, "Cannot add an entry without a vlan when vlan snooping is enabled"); + goto out; + } + + v = br_vlan_find(br_vlan_group(br), entry->vid); + if (!v) { + NL_SET_ERR_MSG_MOD(extack, "Vlan is not configured"); + goto out; + } + if (br_multicast_ctx_vlan_global_disabled(&v->br_mcast_ctx)) { + NL_SET_ERR_MSG_MOD(extack, "Vlan's multicast processing is disabled"); + goto out; + } + brmctx = &v->br_mcast_ctx; +out: + return brmctx; +} + +static int br_mdb_add_group(struct net_bridge *br, struct net_bridge_port *port, + struct br_mdb_entry *entry, + struct nlattr **mdb_attrs, + struct netlink_ext_ack *extack) +{ + struct net_bridge_mdb_entry *mp, *star_mp; + struct net_bridge_port_group __rcu **pp; + struct net_bridge_port_group *p; + struct net_bridge_mcast *brmctx; + struct br_ip group, star_group; + unsigned long now = jiffies; + unsigned char flags = 0; + u8 filter_mode; + int err; + + __mdb_entry_to_br_ip(entry, &group, mdb_attrs); + + brmctx = __br_mdb_choose_context(br, entry, extack); + if (!brmctx) + return -EINVAL; + + /* host join errors which can happen before creating the group */ + if (!port && !br_group_is_l2(&group)) { + /* don't allow any flags for host-joined IP groups */ + if (entry->state) { + NL_SET_ERR_MSG_MOD(extack, "Flags are not allowed for host groups"); + return -EINVAL; + } + if (!br_multicast_is_star_g(&group)) { + NL_SET_ERR_MSG_MOD(extack, "Groups with sources cannot be manually host joined"); + return -EINVAL; + } + } + + if (br_group_is_l2(&group) && entry->state != MDB_PERMANENT) { + NL_SET_ERR_MSG_MOD(extack, "Only permanent L2 entries allowed"); + return -EINVAL; + } + + mp = br_mdb_ip_get(br, &group); + if (!mp) { + mp = br_multicast_new_group(br, &group); + err = PTR_ERR_OR_ZERO(mp); + if (err) + return err; + } + + /* host join */ + if (!port) { + if (mp->host_joined) { + NL_SET_ERR_MSG_MOD(extack, "Group is already joined by host"); + return -EEXIST; + } + + br_multicast_host_join(brmctx, mp, false); + br_mdb_notify(br->dev, mp, NULL, RTM_NEWMDB); + + return 0; + } + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, br)) 
!= NULL; + pp = &p->next) { + if (p->key.port == port) { + NL_SET_ERR_MSG_MOD(extack, "Group is already joined by port"); + return -EEXIST; + } + if ((unsigned long)p->key.port < (unsigned long)port) + break; + } + + filter_mode = br_multicast_is_star_g(&group) ? MCAST_EXCLUDE : + MCAST_INCLUDE; + + if (entry->state == MDB_PERMANENT) + flags |= MDB_PG_FLAGS_PERMANENT; + + p = br_multicast_new_port_group(port, &group, *pp, flags, NULL, + filter_mode, RTPROT_STATIC); + if (unlikely(!p)) { + NL_SET_ERR_MSG_MOD(extack, "Couldn't allocate new port group"); + return -ENOMEM; + } + rcu_assign_pointer(*pp, p); + if (entry->state == MDB_TEMPORARY) + mod_timer(&p->timer, + now + brmctx->multicast_membership_interval); + br_mdb_notify(br->dev, mp, p, RTM_NEWMDB); + /* if we are adding a new EXCLUDE port group (*,G) it needs to be also + * added to all S,G entries for proper replication, if we are adding + * a new INCLUDE port (S,G) then all of *,G EXCLUDE ports need to be + * added to it for proper replication + */ + if (br_multicast_should_handle_mode(brmctx, group.proto)) { + switch (filter_mode) { + case MCAST_EXCLUDE: + br_multicast_star_g_handle_mode(p, MCAST_EXCLUDE); + break; + case MCAST_INCLUDE: + star_group = p->key.addr; + memset(&star_group.src, 0, sizeof(star_group.src)); + star_mp = br_mdb_ip_get(br, &star_group); + if (star_mp) + br_multicast_sg_add_exclude_ports(star_mp, p); + break; + } + } + + return 0; +} + +static int __br_mdb_add(struct net *net, struct net_bridge *br, + struct net_bridge_port *p, + struct br_mdb_entry *entry, + struct nlattr **mdb_attrs, + struct netlink_ext_ack *extack) +{ + int ret; + + spin_lock_bh(&br->multicast_lock); + ret = br_mdb_add_group(br, p, entry, mdb_attrs, extack); + spin_unlock_bh(&br->multicast_lock); + + return ret; +} + +static int br_mdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + struct net_device *dev, *pdev; + struct br_mdb_entry *entry; + struct net_bridge_vlan *v; + struct net_bridge *br; + int err; + + err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack); + if (err < 0) + return err; + + br = netdev_priv(dev); + + if (!netif_running(br->dev)) { + NL_SET_ERR_MSG_MOD(extack, "Bridge device is not running"); + return -EINVAL; + } + + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) { + NL_SET_ERR_MSG_MOD(extack, "Bridge's multicast processing is disabled"); + return -EINVAL; + } + + if (entry->ifindex != br->dev->ifindex) { + pdev = __dev_get_by_index(net, entry->ifindex); + if (!pdev) { + NL_SET_ERR_MSG_MOD(extack, "Port net device doesn't exist"); + return -ENODEV; + } + + p = br_port_get_rtnl(pdev); + if (!p) { + NL_SET_ERR_MSG_MOD(extack, "Net device is not a bridge port"); + return -EINVAL; + } + + if (p->br != br) { + NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device"); + return -EINVAL; + } + if (p->state == BR_STATE_DISABLED && entry->state != MDB_PERMANENT) { + NL_SET_ERR_MSG_MOD(extack, "Port is in disabled state and entry is not permanent"); + return -EINVAL; + } + vg = nbp_vlan_group(p); + } else { + vg = br_vlan_group(br); + } + + /* If vlan filtering is enabled and VLAN is not specified + * install mdb entry on all vlans configured on the port. 
+ */ + if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) { + list_for_each_entry(v, &vg->vlan_list, vlist) { + entry->vid = v->vid; + err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack); + if (err) + break; + } + } else { + err = __br_mdb_add(net, br, p, entry, mdb_attrs, extack); + } + + return err; +} + +static int __br_mdb_del(struct net_bridge *br, struct br_mdb_entry *entry, + struct nlattr **mdb_attrs) +{ + struct net_bridge_mdb_entry *mp; + struct net_bridge_port_group *p; + struct net_bridge_port_group __rcu **pp; + struct br_ip ip; + int err = -EINVAL; + + if (!netif_running(br->dev) || !br_opt_get(br, BROPT_MULTICAST_ENABLED)) + return -EINVAL; + + __mdb_entry_to_br_ip(entry, &ip, mdb_attrs); + + spin_lock_bh(&br->multicast_lock); + mp = br_mdb_ip_get(br, &ip); + if (!mp) + goto unlock; + + /* host leave */ + if (entry->ifindex == mp->br->dev->ifindex && mp->host_joined) { + br_multicast_host_leave(mp, false); + err = 0; + br_mdb_notify(br->dev, mp, NULL, RTM_DELMDB); + if (!mp->ports && netif_running(br->dev)) + mod_timer(&mp->timer, jiffies); + goto unlock; + } + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, br)) != NULL; + pp = &p->next) { + if (!p->key.port || p->key.port->dev->ifindex != entry->ifindex) + continue; + + br_multicast_del_pg(mp, p, pp); + err = 0; + break; + } + +unlock: + spin_unlock_bh(&br->multicast_lock); + return err; +} + +static int br_mdb_del(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct nlattr *mdb_attrs[MDBE_ATTR_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + struct net_device *dev, *pdev; + struct br_mdb_entry *entry; + struct net_bridge_vlan *v; + struct net_bridge *br; + int err; + + err = br_mdb_parse(skb, nlh, &dev, &entry, mdb_attrs, extack); + if (err < 0) + return err; + + br = netdev_priv(dev); + + if (entry->ifindex != br->dev->ifindex) { + pdev = __dev_get_by_index(net, entry->ifindex); + if (!pdev) + return -ENODEV; + + p = br_port_get_rtnl(pdev); + if (!p) { + NL_SET_ERR_MSG_MOD(extack, "Net device is not a bridge port"); + return -EINVAL; + } + if (p->br != br) { + NL_SET_ERR_MSG_MOD(extack, "Port belongs to a different bridge device"); + return -EINVAL; + } + vg = nbp_vlan_group(p); + } else { + vg = br_vlan_group(br); + } + + /* If vlan filtering is enabled and VLAN is not specified + * delete mdb entry on all vlans configured on the port. 
+ */ + if (br_vlan_enabled(br->dev) && vg && entry->vid == 0) { + list_for_each_entry(v, &vg->vlan_list, vlist) { + entry->vid = v->vid; + err = __br_mdb_del(br, entry, mdb_attrs); + } + } else { + err = __br_mdb_del(br, entry, mdb_attrs); + } + + return err; +} + +void br_mdb_init(void) +{ + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETMDB, NULL, br_mdb_dump, 0); + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWMDB, br_mdb_add, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELMDB, br_mdb_del, NULL, 0); +} + +void br_mdb_uninit(void) +{ + rtnl_unregister(PF_BRIDGE, RTM_GETMDB); + rtnl_unregister(PF_BRIDGE, RTM_NEWMDB); + rtnl_unregister(PF_BRIDGE, RTM_DELMDB); +} diff --git a/net/bridge/br_mrp.c b/net/bridge/br_mrp.c new file mode 100644 index 000000000..fd2de35ff --- /dev/null +++ b/net/bridge/br_mrp.c @@ -0,0 +1,1260 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include +#include "br_private_mrp.h" + +static const u8 mrp_test_dmac[ETH_ALEN] = { 0x1, 0x15, 0x4e, 0x0, 0x0, 0x1 }; +static const u8 mrp_in_test_dmac[ETH_ALEN] = { 0x1, 0x15, 0x4e, 0x0, 0x0, 0x3 }; + +static int br_mrp_process(struct net_bridge_port *p, struct sk_buff *skb); + +static struct br_frame_type mrp_frame_type __read_mostly = { + .type = cpu_to_be16(ETH_P_MRP), + .frame_handler = br_mrp_process, +}; + +static bool br_mrp_is_ring_port(struct net_bridge_port *p_port, + struct net_bridge_port *s_port, + struct net_bridge_port *port) +{ + if (port == p_port || + port == s_port) + return true; + + return false; +} + +static bool br_mrp_is_in_port(struct net_bridge_port *i_port, + struct net_bridge_port *port) +{ + if (port == i_port) + return true; + + return false; +} + +static struct net_bridge_port *br_mrp_get_port(struct net_bridge *br, + u32 ifindex) +{ + struct net_bridge_port *res = NULL; + struct net_bridge_port *port; + + list_for_each_entry(port, &br->port_list, list) { + if (port->dev->ifindex == ifindex) { + res = port; + break; + } + } + + return res; +} + +static struct br_mrp *br_mrp_find_id(struct net_bridge *br, u32 ring_id) +{ + struct br_mrp *res = NULL; + struct br_mrp *mrp; + + hlist_for_each_entry_rcu(mrp, &br->mrp_list, list, + lockdep_rtnl_is_held()) { + if (mrp->ring_id == ring_id) { + res = mrp; + break; + } + } + + return res; +} + +static struct br_mrp *br_mrp_find_in_id(struct net_bridge *br, u32 in_id) +{ + struct br_mrp *res = NULL; + struct br_mrp *mrp; + + hlist_for_each_entry_rcu(mrp, &br->mrp_list, list, + lockdep_rtnl_is_held()) { + if (mrp->in_id == in_id) { + res = mrp; + break; + } + } + + return res; +} + +static bool br_mrp_unique_ifindex(struct net_bridge *br, u32 ifindex) +{ + struct br_mrp *mrp; + + hlist_for_each_entry_rcu(mrp, &br->mrp_list, list, + lockdep_rtnl_is_held()) { + struct net_bridge_port *p; + + p = rtnl_dereference(mrp->p_port); + if (p && p->dev->ifindex == ifindex) + return false; + + p = rtnl_dereference(mrp->s_port); + if (p && p->dev->ifindex == ifindex) + return false; + + p = rtnl_dereference(mrp->i_port); + if (p && p->dev->ifindex == ifindex) + return false; + } + + return true; +} + +static struct br_mrp *br_mrp_find_port(struct net_bridge *br, + struct net_bridge_port *p) +{ + struct br_mrp *res = NULL; + struct br_mrp *mrp; + + hlist_for_each_entry_rcu(mrp, &br->mrp_list, list, + lockdep_rtnl_is_held()) { + if (rcu_access_pointer(mrp->p_port) == p || + rcu_access_pointer(mrp->s_port) == p || + rcu_access_pointer(mrp->i_port) == p) { + res = mrp; + break; + } + } + + return res; +} + +static int br_mrp_next_seq(struct 
br_mrp *mrp) +{ + mrp->seq_id++; + return mrp->seq_id; +} + +static struct sk_buff *br_mrp_skb_alloc(struct net_bridge_port *p, + const u8 *src, const u8 *dst) +{ + struct ethhdr *eth_hdr; + struct sk_buff *skb; + __be16 *version; + + skb = dev_alloc_skb(MRP_MAX_FRAME_LENGTH); + if (!skb) + return NULL; + + skb->dev = p->dev; + skb->protocol = htons(ETH_P_MRP); + skb->priority = MRP_FRAME_PRIO; + skb_reserve(skb, sizeof(*eth_hdr)); + + eth_hdr = skb_push(skb, sizeof(*eth_hdr)); + ether_addr_copy(eth_hdr->h_dest, dst); + ether_addr_copy(eth_hdr->h_source, src); + eth_hdr->h_proto = htons(ETH_P_MRP); + + version = skb_put(skb, sizeof(*version)); + *version = cpu_to_be16(MRP_VERSION); + + return skb; +} + +static void br_mrp_skb_tlv(struct sk_buff *skb, + enum br_mrp_tlv_header_type type, + u8 length) +{ + struct br_mrp_tlv_hdr *hdr; + + hdr = skb_put(skb, sizeof(*hdr)); + hdr->type = type; + hdr->length = length; +} + +static void br_mrp_skb_common(struct sk_buff *skb, struct br_mrp *mrp) +{ + struct br_mrp_common_hdr *hdr; + + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_COMMON, sizeof(*hdr)); + + hdr = skb_put(skb, sizeof(*hdr)); + hdr->seq_id = cpu_to_be16(br_mrp_next_seq(mrp)); + memset(hdr->domain, 0xff, MRP_DOMAIN_UUID_LENGTH); +} + +static struct sk_buff *br_mrp_alloc_test_skb(struct br_mrp *mrp, + struct net_bridge_port *p, + enum br_mrp_port_role_type port_role) +{ + struct br_mrp_ring_test_hdr *hdr = NULL; + struct sk_buff *skb = NULL; + + if (!p) + return NULL; + + skb = br_mrp_skb_alloc(p, p->dev->dev_addr, mrp_test_dmac); + if (!skb) + return NULL; + + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_RING_TEST, sizeof(*hdr)); + hdr = skb_put(skb, sizeof(*hdr)); + + hdr->prio = cpu_to_be16(mrp->prio); + ether_addr_copy(hdr->sa, p->br->dev->dev_addr); + hdr->port_role = cpu_to_be16(port_role); + hdr->state = cpu_to_be16(mrp->ring_state); + hdr->transitions = cpu_to_be16(mrp->ring_transitions); + hdr->timestamp = cpu_to_be32(jiffies_to_msecs(jiffies)); + + br_mrp_skb_common(skb, mrp); + + /* In case the node behaves as MRA then the Test frame needs to have + * an Option TLV which includes eventually a sub-option TLV that has + * the type AUTO_MGR + */ + if (mrp->ring_role == BR_MRP_RING_ROLE_MRA) { + struct br_mrp_sub_option1_hdr *sub_opt = NULL; + struct br_mrp_tlv_hdr *sub_tlv = NULL; + struct br_mrp_oui_hdr *oui = NULL; + u8 length; + + length = sizeof(*sub_opt) + sizeof(*sub_tlv) + sizeof(oui) + + MRP_OPT_PADDING; + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_OPTION, length); + + oui = skb_put(skb, sizeof(*oui)); + memset(oui, 0x0, sizeof(*oui)); + sub_opt = skb_put(skb, sizeof(*sub_opt)); + memset(sub_opt, 0x0, sizeof(*sub_opt)); + + sub_tlv = skb_put(skb, sizeof(*sub_tlv)); + sub_tlv->type = BR_MRP_SUB_TLV_HEADER_TEST_AUTO_MGR; + + /* 32 bit alligment shall be ensured therefore add 2 bytes */ + skb_put(skb, MRP_OPT_PADDING); + } + + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_END, 0x0); + + return skb; +} + +static struct sk_buff *br_mrp_alloc_in_test_skb(struct br_mrp *mrp, + struct net_bridge_port *p, + enum br_mrp_port_role_type port_role) +{ + struct br_mrp_in_test_hdr *hdr = NULL; + struct sk_buff *skb = NULL; + + if (!p) + return NULL; + + skb = br_mrp_skb_alloc(p, p->dev->dev_addr, mrp_in_test_dmac); + if (!skb) + return NULL; + + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_IN_TEST, sizeof(*hdr)); + hdr = skb_put(skb, sizeof(*hdr)); + + hdr->id = cpu_to_be16(mrp->in_id); + ether_addr_copy(hdr->sa, p->br->dev->dev_addr); + hdr->port_role = cpu_to_be16(port_role); + hdr->state = cpu_to_be16(mrp->in_state); 
+ hdr->transitions = cpu_to_be16(mrp->in_transitions); + hdr->timestamp = cpu_to_be32(jiffies_to_msecs(jiffies)); + + br_mrp_skb_common(skb, mrp); + br_mrp_skb_tlv(skb, BR_MRP_TLV_HEADER_END, 0x0); + + return skb; +} + +/* This function is continuously called in the following cases: + * - when node role is MRM, in this case test_monitor is always set to false + * because it needs to notify the userspace that the ring is open and needs to + * send MRP_Test frames + * - when node role is MRA, there are 2 subcases: + * - when MRA behaves as MRM, in this case is similar with MRM role + * - when MRA behaves as MRC, in this case test_monitor is set to true, + * because it needs to detect when it stops seeing MRP_Test frames + * from MRM node but it doesn't need to send MRP_Test frames. + */ +static void br_mrp_test_work_expired(struct work_struct *work) +{ + struct delayed_work *del_work = to_delayed_work(work); + struct br_mrp *mrp = container_of(del_work, struct br_mrp, test_work); + struct net_bridge_port *p; + bool notify_open = false; + struct sk_buff *skb; + + if (time_before_eq(mrp->test_end, jiffies)) + return; + + if (mrp->test_count_miss < mrp->test_max_miss) { + mrp->test_count_miss++; + } else { + /* Notify that the ring is open only if the ring state is + * closed, otherwise it would continue to notify at every + * interval. + * Also notify that the ring is open when the node has the + * role MRA and behaves as MRC. The reason is that the + * userspace needs to know when the MRM stopped sending + * MRP_Test frames so that the current node to try to take + * the role of a MRM. + */ + if (mrp->ring_state == BR_MRP_RING_STATE_CLOSED || + mrp->test_monitor) + notify_open = true; + } + + rcu_read_lock(); + + p = rcu_dereference(mrp->p_port); + if (p) { + if (!mrp->test_monitor) { + skb = br_mrp_alloc_test_skb(mrp, p, + BR_MRP_PORT_ROLE_PRIMARY); + if (!skb) + goto out; + + skb_reset_network_header(skb); + dev_queue_xmit(skb); + } + + if (notify_open && !mrp->ring_role_offloaded) + br_mrp_ring_port_open(p->dev, true); + } + + p = rcu_dereference(mrp->s_port); + if (p) { + if (!mrp->test_monitor) { + skb = br_mrp_alloc_test_skb(mrp, p, + BR_MRP_PORT_ROLE_SECONDARY); + if (!skb) + goto out; + + skb_reset_network_header(skb); + dev_queue_xmit(skb); + } + + if (notify_open && !mrp->ring_role_offloaded) + br_mrp_ring_port_open(p->dev, true); + } + +out: + rcu_read_unlock(); + + queue_delayed_work(system_wq, &mrp->test_work, + usecs_to_jiffies(mrp->test_interval)); +} + +/* This function is continuously called when the node has the interconnect role + * MIM. It would generate interconnect test frames and will send them on all 3 + * ports. But will also check if it stop receiving interconnect test frames. + */ +static void br_mrp_in_test_work_expired(struct work_struct *work) +{ + struct delayed_work *del_work = to_delayed_work(work); + struct br_mrp *mrp = container_of(del_work, struct br_mrp, in_test_work); + struct net_bridge_port *p; + bool notify_open = false; + struct sk_buff *skb; + + if (time_before_eq(mrp->in_test_end, jiffies)) + return; + + if (mrp->in_test_count_miss < mrp->in_test_max_miss) { + mrp->in_test_count_miss++; + } else { + /* Notify that the interconnect ring is open only if the + * interconnect ring state is closed, otherwise it would + * continue to notify at every interval. 
+ */ + if (mrp->in_state == BR_MRP_IN_STATE_CLOSED) + notify_open = true; + } + + rcu_read_lock(); + + p = rcu_dereference(mrp->p_port); + if (p) { + skb = br_mrp_alloc_in_test_skb(mrp, p, + BR_MRP_PORT_ROLE_PRIMARY); + if (!skb) + goto out; + + skb_reset_network_header(skb); + dev_queue_xmit(skb); + + if (notify_open && !mrp->in_role_offloaded) + br_mrp_in_port_open(p->dev, true); + } + + p = rcu_dereference(mrp->s_port); + if (p) { + skb = br_mrp_alloc_in_test_skb(mrp, p, + BR_MRP_PORT_ROLE_SECONDARY); + if (!skb) + goto out; + + skb_reset_network_header(skb); + dev_queue_xmit(skb); + + if (notify_open && !mrp->in_role_offloaded) + br_mrp_in_port_open(p->dev, true); + } + + p = rcu_dereference(mrp->i_port); + if (p) { + skb = br_mrp_alloc_in_test_skb(mrp, p, + BR_MRP_PORT_ROLE_INTER); + if (!skb) + goto out; + + skb_reset_network_header(skb); + dev_queue_xmit(skb); + + if (notify_open && !mrp->in_role_offloaded) + br_mrp_in_port_open(p->dev, true); + } + +out: + rcu_read_unlock(); + + queue_delayed_work(system_wq, &mrp->in_test_work, + usecs_to_jiffies(mrp->in_test_interval)); +} + +/* Deletes the MRP instance. + * note: called under rtnl_lock + */ +static void br_mrp_del_impl(struct net_bridge *br, struct br_mrp *mrp) +{ + struct net_bridge_port *p; + u8 state; + + /* Stop sending MRP_Test frames */ + cancel_delayed_work_sync(&mrp->test_work); + br_mrp_switchdev_send_ring_test(br, mrp, 0, 0, 0, 0); + + /* Stop sending MRP_InTest frames if has an interconnect role */ + cancel_delayed_work_sync(&mrp->in_test_work); + br_mrp_switchdev_send_in_test(br, mrp, 0, 0, 0); + + /* Disable the roles */ + br_mrp_switchdev_set_ring_role(br, mrp, BR_MRP_RING_ROLE_DISABLED); + p = rtnl_dereference(mrp->i_port); + if (p) + br_mrp_switchdev_set_in_role(br, mrp, mrp->in_id, mrp->ring_id, + BR_MRP_IN_ROLE_DISABLED); + + br_mrp_switchdev_del(br, mrp); + + /* Reset the ports */ + p = rtnl_dereference(mrp->p_port); + if (p) { + spin_lock_bh(&br->lock); + state = netif_running(br->dev) ? + BR_STATE_FORWARDING : BR_STATE_DISABLED; + p->state = state; + p->flags &= ~BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + br_mrp_port_switchdev_set_state(p, state); + rcu_assign_pointer(mrp->p_port, NULL); + } + + p = rtnl_dereference(mrp->s_port); + if (p) { + spin_lock_bh(&br->lock); + state = netif_running(br->dev) ? + BR_STATE_FORWARDING : BR_STATE_DISABLED; + p->state = state; + p->flags &= ~BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + br_mrp_port_switchdev_set_state(p, state); + rcu_assign_pointer(mrp->s_port, NULL); + } + + p = rtnl_dereference(mrp->i_port); + if (p) { + spin_lock_bh(&br->lock); + state = netif_running(br->dev) ? + BR_STATE_FORWARDING : BR_STATE_DISABLED; + p->state = state; + p->flags &= ~BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + br_mrp_port_switchdev_set_state(p, state); + rcu_assign_pointer(mrp->i_port, NULL); + } + + hlist_del_rcu(&mrp->list); + kfree_rcu(mrp, rcu); + + if (hlist_empty(&br->mrp_list)) + br_del_frame(br, &mrp_frame_type); +} + +/* Adds a new MRP instance. 
+ * note: called under rtnl_lock + */ +int br_mrp_add(struct net_bridge *br, struct br_mrp_instance *instance) +{ + struct net_bridge_port *p; + struct br_mrp *mrp; + int err; + + /* If the ring exists, it is not possible to create another one with the + * same ring_id + */ + mrp = br_mrp_find_id(br, instance->ring_id); + if (mrp) + return -EINVAL; + + if (!br_mrp_get_port(br, instance->p_ifindex) || + !br_mrp_get_port(br, instance->s_ifindex)) + return -EINVAL; + + /* It is not possible to have the same port part of multiple rings */ + if (!br_mrp_unique_ifindex(br, instance->p_ifindex) || + !br_mrp_unique_ifindex(br, instance->s_ifindex)) + return -EINVAL; + + mrp = kzalloc(sizeof(*mrp), GFP_KERNEL); + if (!mrp) + return -ENOMEM; + + mrp->ring_id = instance->ring_id; + mrp->prio = instance->prio; + + p = br_mrp_get_port(br, instance->p_ifindex); + spin_lock_bh(&br->lock); + p->state = BR_STATE_FORWARDING; + p->flags |= BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + rcu_assign_pointer(mrp->p_port, p); + + p = br_mrp_get_port(br, instance->s_ifindex); + spin_lock_bh(&br->lock); + p->state = BR_STATE_FORWARDING; + p->flags |= BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + rcu_assign_pointer(mrp->s_port, p); + + if (hlist_empty(&br->mrp_list)) + br_add_frame(br, &mrp_frame_type); + + INIT_DELAYED_WORK(&mrp->test_work, br_mrp_test_work_expired); + INIT_DELAYED_WORK(&mrp->in_test_work, br_mrp_in_test_work_expired); + hlist_add_tail_rcu(&mrp->list, &br->mrp_list); + + err = br_mrp_switchdev_add(br, mrp); + if (err) + goto delete_mrp; + + return 0; + +delete_mrp: + br_mrp_del_impl(br, mrp); + + return err; +} + +/* Deletes the MRP instance from which the port is part of + * note: called under rtnl_lock + */ +void br_mrp_port_del(struct net_bridge *br, struct net_bridge_port *p) +{ + struct br_mrp *mrp = br_mrp_find_port(br, p); + + /* If the port is not part of a MRP instance just bail out */ + if (!mrp) + return; + + br_mrp_del_impl(br, mrp); +} + +/* Deletes existing MRP instance based on ring_id + * note: called under rtnl_lock + */ +int br_mrp_del(struct net_bridge *br, struct br_mrp_instance *instance) +{ + struct br_mrp *mrp = br_mrp_find_id(br, instance->ring_id); + + if (!mrp) + return -EINVAL; + + br_mrp_del_impl(br, mrp); + + return 0; +} + +/* Set port state, port state can be forwarding, blocked or disabled + * note: already called with rtnl_lock + */ +int br_mrp_set_port_state(struct net_bridge_port *p, + enum br_mrp_port_state_type state) +{ + u32 port_state; + + if (!p || !(p->flags & BR_MRP_AWARE)) + return -EINVAL; + + spin_lock_bh(&p->br->lock); + + if (state == BR_MRP_PORT_STATE_FORWARDING) + port_state = BR_STATE_FORWARDING; + else + port_state = BR_STATE_BLOCKING; + + p->state = port_state; + spin_unlock_bh(&p->br->lock); + + br_mrp_port_switchdev_set_state(p, port_state); + + return 0; +} + +/* Set port role, port role can be primary or secondary + * note: already called with rtnl_lock + */ +int br_mrp_set_port_role(struct net_bridge_port *p, + enum br_mrp_port_role_type role) +{ + struct br_mrp *mrp; + + if (!p || !(p->flags & BR_MRP_AWARE)) + return -EINVAL; + + mrp = br_mrp_find_port(p->br, p); + + if (!mrp) + return -EINVAL; + + switch (role) { + case BR_MRP_PORT_ROLE_PRIMARY: + rcu_assign_pointer(mrp->p_port, p); + break; + case BR_MRP_PORT_ROLE_SECONDARY: + rcu_assign_pointer(mrp->s_port, p); + break; + default: + return -EINVAL; + } + + br_mrp_port_switchdev_set_role(p, role); + + return 0; +} + +/* Set ring state, ring state can be only Open or Closed + * note: already 
called with rtnl_lock + */ +int br_mrp_set_ring_state(struct net_bridge *br, + struct br_mrp_ring_state *state) +{ + struct br_mrp *mrp = br_mrp_find_id(br, state->ring_id); + + if (!mrp) + return -EINVAL; + + if (mrp->ring_state != state->ring_state) + mrp->ring_transitions++; + + mrp->ring_state = state->ring_state; + + br_mrp_switchdev_set_ring_state(br, mrp, state->ring_state); + + return 0; +} + +/* Set ring role, ring role can be only MRM(Media Redundancy Manager) or + * MRC(Media Redundancy Client). + * note: already called with rtnl_lock + */ +int br_mrp_set_ring_role(struct net_bridge *br, + struct br_mrp_ring_role *role) +{ + struct br_mrp *mrp = br_mrp_find_id(br, role->ring_id); + enum br_mrp_hw_support support; + + if (!mrp) + return -EINVAL; + + mrp->ring_role = role->ring_role; + + /* If there is an error just bailed out */ + support = br_mrp_switchdev_set_ring_role(br, mrp, role->ring_role); + if (support == BR_MRP_NONE) + return -EOPNOTSUPP; + + /* Now detect if the HW actually applied the role or not. If the HW + * applied the role it means that the SW will not to do those operations + * anymore. For example if the role ir MRM then the HW will notify the + * SW when ring is open, but if the is not pushed to the HW the SW will + * need to detect when the ring is open + */ + mrp->ring_role_offloaded = support == BR_MRP_SW ? 0 : 1; + + return 0; +} + +/* Start to generate or monitor MRP test frames, the frames are generated by + * HW and if it fails, they are generated by the SW. + * note: already called with rtnl_lock + */ +int br_mrp_start_test(struct net_bridge *br, + struct br_mrp_start_test *test) +{ + struct br_mrp *mrp = br_mrp_find_id(br, test->ring_id); + enum br_mrp_hw_support support; + + if (!mrp) + return -EINVAL; + + /* Try to push it to the HW and if it fails then continue with SW + * implementation and if that also fails then return error. + */ + support = br_mrp_switchdev_send_ring_test(br, mrp, test->interval, + test->max_miss, test->period, + test->monitor); + if (support == BR_MRP_NONE) + return -EOPNOTSUPP; + + if (support == BR_MRP_HW) + return 0; + + mrp->test_interval = test->interval; + mrp->test_end = jiffies + usecs_to_jiffies(test->period); + mrp->test_max_miss = test->max_miss; + mrp->test_monitor = test->monitor; + mrp->test_count_miss = 0; + queue_delayed_work(system_wq, &mrp->test_work, + usecs_to_jiffies(test->interval)); + + return 0; +} + +/* Set in state, int state can be only Open or Closed + * note: already called with rtnl_lock + */ +int br_mrp_set_in_state(struct net_bridge *br, struct br_mrp_in_state *state) +{ + struct br_mrp *mrp = br_mrp_find_in_id(br, state->in_id); + + if (!mrp) + return -EINVAL; + + if (mrp->in_state != state->in_state) + mrp->in_transitions++; + + mrp->in_state = state->in_state; + + br_mrp_switchdev_set_in_state(br, mrp, state->in_state); + + return 0; +} + +/* Set in role, in role can be only MIM(Media Interconnection Manager) or + * MIC(Media Interconnection Client). 
+ * note: already called with rtnl_lock + */ +int br_mrp_set_in_role(struct net_bridge *br, struct br_mrp_in_role *role) +{ + struct br_mrp *mrp = br_mrp_find_id(br, role->ring_id); + enum br_mrp_hw_support support; + struct net_bridge_port *p; + + if (!mrp) + return -EINVAL; + + if (!br_mrp_get_port(br, role->i_ifindex)) + return -EINVAL; + + if (role->in_role == BR_MRP_IN_ROLE_DISABLED) { + u8 state; + + /* It is not allowed to disable a port that doesn't exist */ + p = rtnl_dereference(mrp->i_port); + if (!p) + return -EINVAL; + + /* Stop the generating MRP_InTest frames */ + cancel_delayed_work_sync(&mrp->in_test_work); + br_mrp_switchdev_send_in_test(br, mrp, 0, 0, 0); + + /* Remove the port */ + spin_lock_bh(&br->lock); + state = netif_running(br->dev) ? + BR_STATE_FORWARDING : BR_STATE_DISABLED; + p->state = state; + p->flags &= ~BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + br_mrp_port_switchdev_set_state(p, state); + rcu_assign_pointer(mrp->i_port, NULL); + + mrp->in_role = role->in_role; + mrp->in_id = 0; + + return 0; + } + + /* It is not possible to have the same port part of multiple rings */ + if (!br_mrp_unique_ifindex(br, role->i_ifindex)) + return -EINVAL; + + /* It is not allowed to set a different interconnect port if the mrp + * instance has already one. First it needs to be disabled and after + * that set the new port + */ + if (rcu_access_pointer(mrp->i_port)) + return -EINVAL; + + p = br_mrp_get_port(br, role->i_ifindex); + spin_lock_bh(&br->lock); + p->state = BR_STATE_FORWARDING; + p->flags |= BR_MRP_AWARE; + spin_unlock_bh(&br->lock); + rcu_assign_pointer(mrp->i_port, p); + + mrp->in_role = role->in_role; + mrp->in_id = role->in_id; + + /* If there is an error just bailed out */ + support = br_mrp_switchdev_set_in_role(br, mrp, role->in_id, + role->ring_id, role->in_role); + if (support == BR_MRP_NONE) + return -EOPNOTSUPP; + + /* Now detect if the HW actually applied the role or not. If the HW + * applied the role it means that the SW will not to do those operations + * anymore. For example if the role is MIM then the HW will notify the + * SW when interconnect ring is open, but if the is not pushed to the HW + * the SW will need to detect when the interconnect ring is open. + */ + mrp->in_role_offloaded = support == BR_MRP_SW ? 0 : 1; + + return 0; +} + +/* Start to generate MRP_InTest frames, the frames are generated by + * HW and if it fails, they are generated by the SW. + * note: already called with rtnl_lock + */ +int br_mrp_start_in_test(struct net_bridge *br, + struct br_mrp_start_in_test *in_test) +{ + struct br_mrp *mrp = br_mrp_find_in_id(br, in_test->in_id); + enum br_mrp_hw_support support; + + if (!mrp) + return -EINVAL; + + if (mrp->in_role != BR_MRP_IN_ROLE_MIM) + return -EINVAL; + + /* Try to push it to the HW and if it fails then continue with SW + * implementation and if that also fails then return error. 
+ */ + support = br_mrp_switchdev_send_in_test(br, mrp, in_test->interval, + in_test->max_miss, + in_test->period); + if (support == BR_MRP_NONE) + return -EOPNOTSUPP; + + if (support == BR_MRP_HW) + return 0; + + mrp->in_test_interval = in_test->interval; + mrp->in_test_end = jiffies + usecs_to_jiffies(in_test->period); + mrp->in_test_max_miss = in_test->max_miss; + mrp->in_test_count_miss = 0; + queue_delayed_work(system_wq, &mrp->in_test_work, + usecs_to_jiffies(in_test->interval)); + + return 0; +} + +/* Determine if the frame type is a ring frame */ +static bool br_mrp_ring_frame(struct sk_buff *skb) +{ + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return false; + + if (hdr->type == BR_MRP_TLV_HEADER_RING_TEST || + hdr->type == BR_MRP_TLV_HEADER_RING_TOPO || + hdr->type == BR_MRP_TLV_HEADER_RING_LINK_DOWN || + hdr->type == BR_MRP_TLV_HEADER_RING_LINK_UP || + hdr->type == BR_MRP_TLV_HEADER_OPTION) + return true; + + return false; +} + +/* Determine if the frame type is an interconnect frame */ +static bool br_mrp_in_frame(struct sk_buff *skb) +{ + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return false; + + if (hdr->type == BR_MRP_TLV_HEADER_IN_TEST || + hdr->type == BR_MRP_TLV_HEADER_IN_TOPO || + hdr->type == BR_MRP_TLV_HEADER_IN_LINK_DOWN || + hdr->type == BR_MRP_TLV_HEADER_IN_LINK_UP || + hdr->type == BR_MRP_TLV_HEADER_IN_LINK_STATUS) + return true; + + return false; +} + +/* Process only MRP Test frame. All the other MRP frames are processed by + * userspace application + * note: already called with rcu_read_lock + */ +static void br_mrp_mrm_process(struct br_mrp *mrp, struct net_bridge_port *port, + struct sk_buff *skb) +{ + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + /* Each MRP header starts with a version field which is 16 bits. + * Therefore skip the version and get directly the TLV header. + */ + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return; + + if (hdr->type != BR_MRP_TLV_HEADER_RING_TEST) + return; + + mrp->test_count_miss = 0; + + /* Notify the userspace that the ring is closed only when the ring is + * not closed + */ + if (mrp->ring_state != BR_MRP_RING_STATE_CLOSED) + br_mrp_ring_port_open(port->dev, false); +} + +/* Determine if the test hdr has a better priority than the node */ +static bool br_mrp_test_better_than_own(struct br_mrp *mrp, + struct net_bridge *br, + const struct br_mrp_ring_test_hdr *hdr) +{ + u16 prio = be16_to_cpu(hdr->prio); + + if (prio < mrp->prio || + (prio == mrp->prio && + ether_addr_to_u64(hdr->sa) < ether_addr_to_u64(br->dev->dev_addr))) + return true; + + return false; +} + +/* Process only MRP Test frame. All the other MRP frames are processed by + * userspace application + * note: already called with rcu_read_lock + */ +static void br_mrp_mra_process(struct br_mrp *mrp, struct net_bridge *br, + struct net_bridge_port *port, + struct sk_buff *skb) +{ + const struct br_mrp_ring_test_hdr *test_hdr; + struct br_mrp_ring_test_hdr _test_hdr; + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + /* Each MRP header starts with a version field which is 16 bits. + * Therefore skip the version and get directly the TLV header. 
+ */ + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return; + + if (hdr->type != BR_MRP_TLV_HEADER_RING_TEST) + return; + + test_hdr = skb_header_pointer(skb, sizeof(uint16_t) + sizeof(_hdr), + sizeof(_test_hdr), &_test_hdr); + if (!test_hdr) + return; + + /* Only frames that have a better priority than the node will + * clear the miss counter because otherwise the node will need to behave + * as MRM. + */ + if (br_mrp_test_better_than_own(mrp, br, test_hdr)) + mrp->test_count_miss = 0; +} + +/* Process only MRP InTest frame. All the other MRP frames are processed by + * userspace application + * note: already called with rcu_read_lock + */ +static bool br_mrp_mim_process(struct br_mrp *mrp, struct net_bridge_port *port, + struct sk_buff *skb) +{ + const struct br_mrp_in_test_hdr *in_hdr; + struct br_mrp_in_test_hdr _in_hdr; + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + /* Each MRP header starts with a version field which is 16 bits. + * Therefore skip the version and get directly the TLV header. + */ + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return false; + + /* The check for InTest frame type was already done */ + in_hdr = skb_header_pointer(skb, sizeof(uint16_t) + sizeof(_hdr), + sizeof(_in_hdr), &_in_hdr); + if (!in_hdr) + return false; + + /* It needs to process only it's own InTest frames. */ + if (mrp->in_id != ntohs(in_hdr->id)) + return false; + + mrp->in_test_count_miss = 0; + + /* Notify the userspace that the ring is closed only when the ring is + * not closed + */ + if (mrp->in_state != BR_MRP_IN_STATE_CLOSED) + br_mrp_in_port_open(port->dev, false); + + return true; +} + +/* Get the MRP frame type + * note: already called with rcu_read_lock + */ +static u8 br_mrp_get_frame_type(struct sk_buff *skb) +{ + const struct br_mrp_tlv_hdr *hdr; + struct br_mrp_tlv_hdr _hdr; + + /* Each MRP header starts with a version field which is 16 bits. + * Therefore skip the version and get directly the TLV header. 
+ */ + hdr = skb_header_pointer(skb, sizeof(uint16_t), sizeof(_hdr), &_hdr); + if (!hdr) + return 0xff; + + return hdr->type; +} + +static bool br_mrp_mrm_behaviour(struct br_mrp *mrp) +{ + if (mrp->ring_role == BR_MRP_RING_ROLE_MRM || + (mrp->ring_role == BR_MRP_RING_ROLE_MRA && !mrp->test_monitor)) + return true; + + return false; +} + +static bool br_mrp_mrc_behaviour(struct br_mrp *mrp) +{ + if (mrp->ring_role == BR_MRP_RING_ROLE_MRC || + (mrp->ring_role == BR_MRP_RING_ROLE_MRA && mrp->test_monitor)) + return true; + + return false; +} + +/* This will just forward the frame to the other mrp ring ports, depending on + * the frame type, ring role and interconnect role + * note: already called with rcu_read_lock + */ +static int br_mrp_rcv(struct net_bridge_port *p, + struct sk_buff *skb, struct net_device *dev) +{ + struct net_bridge_port *p_port, *s_port, *i_port = NULL; + struct net_bridge_port *p_dst, *s_dst, *i_dst = NULL; + struct net_bridge *br; + struct br_mrp *mrp; + + /* If port is disabled don't accept any frames */ + if (p->state == BR_STATE_DISABLED) + return 0; + + br = p->br; + mrp = br_mrp_find_port(br, p); + if (unlikely(!mrp)) + return 0; + + p_port = rcu_dereference(mrp->p_port); + if (!p_port) + return 0; + p_dst = p_port; + + s_port = rcu_dereference(mrp->s_port); + if (!s_port) + return 0; + s_dst = s_port; + + /* If the frame is a ring frame then it is not required to check the + * interconnect role and ports to process or forward the frame + */ + if (br_mrp_ring_frame(skb)) { + /* If the role is MRM then don't forward the frames */ + if (mrp->ring_role == BR_MRP_RING_ROLE_MRM) { + br_mrp_mrm_process(mrp, p, skb); + goto no_forward; + } + + /* If the role is MRA then don't forward the frames if it + * behaves as MRM node + */ + if (mrp->ring_role == BR_MRP_RING_ROLE_MRA) { + if (!mrp->test_monitor) { + br_mrp_mrm_process(mrp, p, skb); + goto no_forward; + } + + br_mrp_mra_process(mrp, br, p, skb); + } + + goto forward; + } + + if (br_mrp_in_frame(skb)) { + u8 in_type = br_mrp_get_frame_type(skb); + + i_port = rcu_dereference(mrp->i_port); + i_dst = i_port; + + /* If the ring port is in block state it should not forward + * In_Test frames + */ + if (br_mrp_is_ring_port(p_port, s_port, p) && + p->state == BR_STATE_BLOCKING && + in_type == BR_MRP_TLV_HEADER_IN_TEST) + goto no_forward; + + /* Nodes that behaves as MRM needs to stop forwarding the + * frames in case the ring is closed, otherwise will be a loop. + * In this case the frame is no forward between the ring ports. 
+ */ + if (br_mrp_mrm_behaviour(mrp) && + br_mrp_is_ring_port(p_port, s_port, p) && + (s_port->state != BR_STATE_FORWARDING || + p_port->state != BR_STATE_FORWARDING)) { + p_dst = NULL; + s_dst = NULL; + } + + /* A node that behaves as MRC and doesn't have a interconnect + * role then it should forward all frames between the ring ports + * because it doesn't have an interconnect port + */ + if (br_mrp_mrc_behaviour(mrp) && + mrp->in_role == BR_MRP_IN_ROLE_DISABLED) + goto forward; + + if (mrp->in_role == BR_MRP_IN_ROLE_MIM) { + if (in_type == BR_MRP_TLV_HEADER_IN_TEST) { + /* MIM should not forward it's own InTest + * frames + */ + if (br_mrp_mim_process(mrp, p, skb)) { + goto no_forward; + } else { + if (br_mrp_is_ring_port(p_port, s_port, + p)) + i_dst = NULL; + + if (br_mrp_is_in_port(i_port, p)) + goto no_forward; + } + } else { + /* MIM should forward IntLinkChange/Status and + * IntTopoChange between ring ports but MIM + * should not forward IntLinkChange/Status and + * IntTopoChange if the frame was received at + * the interconnect port + */ + if (br_mrp_is_ring_port(p_port, s_port, p)) + i_dst = NULL; + + if (br_mrp_is_in_port(i_port, p)) + goto no_forward; + } + } + + if (mrp->in_role == BR_MRP_IN_ROLE_MIC) { + /* MIC should forward InTest frames on all ports + * regardless of the received port + */ + if (in_type == BR_MRP_TLV_HEADER_IN_TEST) + goto forward; + + /* MIC should forward IntLinkChange frames only if they + * are received on ring ports to all the ports + */ + if (br_mrp_is_ring_port(p_port, s_port, p) && + (in_type == BR_MRP_TLV_HEADER_IN_LINK_UP || + in_type == BR_MRP_TLV_HEADER_IN_LINK_DOWN)) + goto forward; + + /* MIC should forward IntLinkStatus frames only to + * interconnect port if it was received on a ring port. + * If it is received on interconnect port then, it + * should be forward on both ring ports + */ + if (br_mrp_is_ring_port(p_port, s_port, p) && + in_type == BR_MRP_TLV_HEADER_IN_LINK_STATUS) { + p_dst = NULL; + s_dst = NULL; + } + + /* Should forward the InTopo frames only between the + * ring ports + */ + if (in_type == BR_MRP_TLV_HEADER_IN_TOPO) { + i_dst = NULL; + goto forward; + } + + /* In all the other cases don't forward the frames */ + goto no_forward; + } + } + +forward: + if (p_dst) + br_forward(p_dst, skb, true, false); + if (s_dst) + br_forward(s_dst, skb, true, false); + if (i_dst) + br_forward(i_dst, skb, true, false); + +no_forward: + return 1; +} + +/* Check if the frame was received on a port that is part of MRP ring + * and if the frame has MRP eth. In that case process the frame otherwise do + * normal forwarding. 
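+ * The return value follows the convention visible below: non-zero means
+ * MRP consumed the frame (it was terminated here or already sent to the
+ * right ring/interconnect ports by br_mrp_rcv()), 0 means the port is not
+ * MRP aware and normal bridge forwarding should continue.  A purely
+ * illustrative caller, not the actual input path:
+ *
+ *      if (br_mrp_enabled(br) && br_mrp_process(p, skb))
+ *              return;         (skip normal bridge forwarding)
+ *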
+ * note: already called with rcu_read_lock + */ +static int br_mrp_process(struct net_bridge_port *p, struct sk_buff *skb) +{ + /* If there is no MRP instance do normal forwarding */ + if (likely(!(p->flags & BR_MRP_AWARE))) + goto out; + + return br_mrp_rcv(p, skb, p->dev); +out: + return 0; +} + +bool br_mrp_enabled(struct net_bridge *br) +{ + return !hlist_empty(&br->mrp_list); +} diff --git a/net/bridge/br_mrp_netlink.c b/net/bridge/br_mrp_netlink.c new file mode 100644 index 000000000..ce6f63c77 --- /dev/null +++ b/net/bridge/br_mrp_netlink.c @@ -0,0 +1,571 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include + +#include +#include "br_private.h" +#include "br_private_mrp.h" + +static const struct nla_policy br_mrp_policy[IFLA_BRIDGE_MRP_MAX + 1] = { + [IFLA_BRIDGE_MRP_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_INSTANCE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_PORT_STATE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_PORT_ROLE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_RING_STATE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_RING_ROLE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_START_TEST] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_IN_ROLE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_IN_STATE] = { .type = NLA_NESTED }, + [IFLA_BRIDGE_MRP_START_IN_TEST] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy +br_mrp_instance_policy[IFLA_BRIDGE_MRP_INSTANCE_MAX + 1] = { + [IFLA_BRIDGE_MRP_INSTANCE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_INSTANCE_RING_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_INSTANCE_P_IFINDEX] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_INSTANCE_S_IFINDEX] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_INSTANCE_PRIO] = { .type = NLA_U16 }, +}; + +static int br_mrp_instance_parse(struct net_bridge *br, struct nlattr *attr, + int cmd, struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_INSTANCE_MAX + 1]; + struct br_mrp_instance inst; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_INSTANCE_MAX, attr, + br_mrp_instance_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_INSTANCE_RING_ID] || + !tb[IFLA_BRIDGE_MRP_INSTANCE_P_IFINDEX] || + !tb[IFLA_BRIDGE_MRP_INSTANCE_S_IFINDEX]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or P_IFINDEX or S_IFINDEX"); + return -EINVAL; + } + + memset(&inst, 0, sizeof(inst)); + + inst.ring_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_INSTANCE_RING_ID]); + inst.p_ifindex = nla_get_u32(tb[IFLA_BRIDGE_MRP_INSTANCE_P_IFINDEX]); + inst.s_ifindex = nla_get_u32(tb[IFLA_BRIDGE_MRP_INSTANCE_S_IFINDEX]); + inst.prio = MRP_DEFAULT_PRIO; + + if (tb[IFLA_BRIDGE_MRP_INSTANCE_PRIO]) + inst.prio = nla_get_u16(tb[IFLA_BRIDGE_MRP_INSTANCE_PRIO]); + + if (cmd == RTM_SETLINK) + return br_mrp_add(br, &inst); + else + return br_mrp_del(br, &inst); + + return 0; +} + +static const struct nla_policy +br_mrp_port_state_policy[IFLA_BRIDGE_MRP_PORT_STATE_MAX + 1] = { + [IFLA_BRIDGE_MRP_PORT_STATE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_PORT_STATE_STATE] = { .type = NLA_U32 }, +}; + +static int br_mrp_port_state_parse(struct net_bridge_port *p, + struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_PORT_STATE_MAX + 1]; + enum br_mrp_port_state_type state; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_PORT_STATE_MAX, attr, + br_mrp_port_state_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_PORT_STATE_STATE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing attribute: STATE"); + return 
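+ /* (For orientation: the nested layout br_mrp_instance_parse() above
+  *  expects is the write-side mirror of br_mrp_fill_info() later in this
+  *  file.  Shown with the same nla_* helpers purely to illustrate the
+  *  attribute nesting, not as actual request-building code:
+  *
+  *      nest = nla_nest_start_noflag(skb, IFLA_BRIDGE_MRP_INSTANCE);
+  *      nla_put_u32(skb, IFLA_BRIDGE_MRP_INSTANCE_RING_ID,   ring_id);
+  *      nla_put_u32(skb, IFLA_BRIDGE_MRP_INSTANCE_P_IFINDEX, p_ifindex);
+  *      nla_put_u32(skb, IFLA_BRIDGE_MRP_INSTANCE_S_IFINDEX, s_ifindex);
+  *      nla_put_u16(skb, IFLA_BRIDGE_MRP_INSTANCE_PRIO,      prio);
+  *      nla_nest_end(skb, nest);
+  *
+  *  RING_ID, P_IFINDEX and S_IFINDEX are mandatory, PRIO is optional and
+  *  defaults to MRP_DEFAULT_PRIO; cmd == RTM_SETLINK creates the instance,
+  *  any other cmd deletes it.) */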
-EINVAL; + } + + state = nla_get_u32(tb[IFLA_BRIDGE_MRP_PORT_STATE_STATE]); + + return br_mrp_set_port_state(p, state); +} + +static const struct nla_policy +br_mrp_port_role_policy[IFLA_BRIDGE_MRP_PORT_ROLE_MAX + 1] = { + [IFLA_BRIDGE_MRP_PORT_ROLE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_PORT_ROLE_ROLE] = { .type = NLA_U32 }, +}; + +static int br_mrp_port_role_parse(struct net_bridge_port *p, + struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_PORT_ROLE_MAX + 1]; + enum br_mrp_port_role_type role; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_PORT_ROLE_MAX, attr, + br_mrp_port_role_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_PORT_ROLE_ROLE]) { + NL_SET_ERR_MSG_MOD(extack, "Missing attribute: ROLE"); + return -EINVAL; + } + + role = nla_get_u32(tb[IFLA_BRIDGE_MRP_PORT_ROLE_ROLE]); + + return br_mrp_set_port_role(p, role); +} + +static const struct nla_policy +br_mrp_ring_state_policy[IFLA_BRIDGE_MRP_RING_STATE_MAX + 1] = { + [IFLA_BRIDGE_MRP_RING_STATE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_RING_STATE_RING_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_RING_STATE_STATE] = { .type = NLA_U32 }, +}; + +static int br_mrp_ring_state_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_RING_STATE_MAX + 1]; + struct br_mrp_ring_state state; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_RING_STATE_MAX, attr, + br_mrp_ring_state_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_RING_STATE_RING_ID] || + !tb[IFLA_BRIDGE_MRP_RING_STATE_STATE]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or STATE"); + return -EINVAL; + } + + memset(&state, 0x0, sizeof(state)); + + state.ring_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_RING_STATE_RING_ID]); + state.ring_state = nla_get_u32(tb[IFLA_BRIDGE_MRP_RING_STATE_STATE]); + + return br_mrp_set_ring_state(br, &state); +} + +static const struct nla_policy +br_mrp_ring_role_policy[IFLA_BRIDGE_MRP_RING_ROLE_MAX + 1] = { + [IFLA_BRIDGE_MRP_RING_ROLE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_RING_ROLE_RING_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_RING_ROLE_ROLE] = { .type = NLA_U32 }, +}; + +static int br_mrp_ring_role_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_RING_ROLE_MAX + 1]; + struct br_mrp_ring_role role; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_RING_ROLE_MAX, attr, + br_mrp_ring_role_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_RING_ROLE_RING_ID] || + !tb[IFLA_BRIDGE_MRP_RING_ROLE_ROLE]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or ROLE"); + return -EINVAL; + } + + memset(&role, 0x0, sizeof(role)); + + role.ring_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_RING_ROLE_RING_ID]); + role.ring_role = nla_get_u32(tb[IFLA_BRIDGE_MRP_RING_ROLE_ROLE]); + + return br_mrp_set_ring_role(br, &role); +} + +static const struct nla_policy +br_mrp_start_test_policy[IFLA_BRIDGE_MRP_START_TEST_MAX + 1] = { + [IFLA_BRIDGE_MRP_START_TEST_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_START_TEST_RING_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_TEST_INTERVAL] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_TEST_MAX_MISS] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_TEST_PERIOD] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_TEST_MONITOR] = { .type = NLA_U32 }, +}; + +static int 
br_mrp_start_test_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_START_TEST_MAX + 1]; + struct br_mrp_start_test test; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_START_TEST_MAX, attr, + br_mrp_start_test_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_START_TEST_RING_ID] || + !tb[IFLA_BRIDGE_MRP_START_TEST_INTERVAL] || + !tb[IFLA_BRIDGE_MRP_START_TEST_MAX_MISS] || + !tb[IFLA_BRIDGE_MRP_START_TEST_PERIOD]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or INTERVAL or MAX_MISS or PERIOD"); + return -EINVAL; + } + + memset(&test, 0x0, sizeof(test)); + + test.ring_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_TEST_RING_ID]); + test.interval = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_TEST_INTERVAL]); + test.max_miss = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_TEST_MAX_MISS]); + test.period = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_TEST_PERIOD]); + test.monitor = false; + + if (tb[IFLA_BRIDGE_MRP_START_TEST_MONITOR]) + test.monitor = + nla_get_u32(tb[IFLA_BRIDGE_MRP_START_TEST_MONITOR]); + + return br_mrp_start_test(br, &test); +} + +static const struct nla_policy +br_mrp_in_state_policy[IFLA_BRIDGE_MRP_IN_STATE_MAX + 1] = { + [IFLA_BRIDGE_MRP_IN_STATE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_IN_STATE_IN_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_IN_STATE_STATE] = { .type = NLA_U32 }, +}; + +static int br_mrp_in_state_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_IN_STATE_MAX + 1]; + struct br_mrp_in_state state; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_IN_STATE_MAX, attr, + br_mrp_in_state_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_IN_STATE_IN_ID] || + !tb[IFLA_BRIDGE_MRP_IN_STATE_STATE]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: IN_ID or STATE"); + return -EINVAL; + } + + memset(&state, 0x0, sizeof(state)); + + state.in_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_IN_STATE_IN_ID]); + state.in_state = nla_get_u32(tb[IFLA_BRIDGE_MRP_IN_STATE_STATE]); + + return br_mrp_set_in_state(br, &state); +} + +static const struct nla_policy +br_mrp_in_role_policy[IFLA_BRIDGE_MRP_IN_ROLE_MAX + 1] = { + [IFLA_BRIDGE_MRP_IN_ROLE_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_IN_ROLE_RING_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_IN_ROLE_IN_ID] = { .type = NLA_U16 }, + [IFLA_BRIDGE_MRP_IN_ROLE_ROLE] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_IN_ROLE_I_IFINDEX] = { .type = NLA_U32 }, +}; + +static int br_mrp_in_role_parse(struct net_bridge *br, struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_IN_ROLE_MAX + 1]; + struct br_mrp_in_role role; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_IN_ROLE_MAX, attr, + br_mrp_in_role_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_IN_ROLE_RING_ID] || + !tb[IFLA_BRIDGE_MRP_IN_ROLE_IN_ID] || + !tb[IFLA_BRIDGE_MRP_IN_ROLE_I_IFINDEX] || + !tb[IFLA_BRIDGE_MRP_IN_ROLE_ROLE]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or ROLE or IN_ID or I_IFINDEX"); + return -EINVAL; + } + + memset(&role, 0x0, sizeof(role)); + + role.ring_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_IN_ROLE_RING_ID]); + role.in_id = nla_get_u16(tb[IFLA_BRIDGE_MRP_IN_ROLE_IN_ID]); + role.i_ifindex = nla_get_u32(tb[IFLA_BRIDGE_MRP_IN_ROLE_I_IFINDEX]); + role.in_role = nla_get_u32(tb[IFLA_BRIDGE_MRP_IN_ROLE_ROLE]); + + return br_mrp_set_in_role(br, &role); +} + +static const 
struct nla_policy +br_mrp_start_in_test_policy[IFLA_BRIDGE_MRP_START_IN_TEST_MAX + 1] = { + [IFLA_BRIDGE_MRP_START_IN_TEST_UNSPEC] = { .type = NLA_REJECT }, + [IFLA_BRIDGE_MRP_START_IN_TEST_IN_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_IN_TEST_INTERVAL] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_IN_TEST_MAX_MISS] = { .type = NLA_U32 }, + [IFLA_BRIDGE_MRP_START_IN_TEST_PERIOD] = { .type = NLA_U32 }, +}; + +static int br_mrp_start_in_test_parse(struct net_bridge *br, + struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_START_IN_TEST_MAX + 1]; + struct br_mrp_start_in_test test; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_START_IN_TEST_MAX, attr, + br_mrp_start_in_test_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MRP_START_IN_TEST_IN_ID] || + !tb[IFLA_BRIDGE_MRP_START_IN_TEST_INTERVAL] || + !tb[IFLA_BRIDGE_MRP_START_IN_TEST_MAX_MISS] || + !tb[IFLA_BRIDGE_MRP_START_IN_TEST_PERIOD]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing attribute: RING_ID or INTERVAL or MAX_MISS or PERIOD"); + return -EINVAL; + } + + memset(&test, 0x0, sizeof(test)); + + test.in_id = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_IN_TEST_IN_ID]); + test.interval = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_IN_TEST_INTERVAL]); + test.max_miss = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_IN_TEST_MAX_MISS]); + test.period = nla_get_u32(tb[IFLA_BRIDGE_MRP_START_IN_TEST_PERIOD]); + + return br_mrp_start_in_test(br, &test); +} + +int br_mrp_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MRP_MAX + 1]; + int err; + + /* When this function is called for a port then the br pointer is + * invalid, therefor set the br to point correctly + */ + if (p) + br = p->br; + + if (br->stp_enabled != BR_NO_STP) { + NL_SET_ERR_MSG_MOD(extack, "MRP can't be enabled if STP is already enabled"); + return -EINVAL; + } + + err = nla_parse_nested(tb, IFLA_BRIDGE_MRP_MAX, attr, + br_mrp_policy, extack); + if (err) + return err; + + if (tb[IFLA_BRIDGE_MRP_INSTANCE]) { + err = br_mrp_instance_parse(br, tb[IFLA_BRIDGE_MRP_INSTANCE], + cmd, extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_PORT_STATE]) { + err = br_mrp_port_state_parse(p, tb[IFLA_BRIDGE_MRP_PORT_STATE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_PORT_ROLE]) { + err = br_mrp_port_role_parse(p, tb[IFLA_BRIDGE_MRP_PORT_ROLE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_RING_STATE]) { + err = br_mrp_ring_state_parse(br, + tb[IFLA_BRIDGE_MRP_RING_STATE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_RING_ROLE]) { + err = br_mrp_ring_role_parse(br, tb[IFLA_BRIDGE_MRP_RING_ROLE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_START_TEST]) { + err = br_mrp_start_test_parse(br, + tb[IFLA_BRIDGE_MRP_START_TEST], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_IN_STATE]) { + err = br_mrp_in_state_parse(br, tb[IFLA_BRIDGE_MRP_IN_STATE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_IN_ROLE]) { + err = br_mrp_in_role_parse(br, tb[IFLA_BRIDGE_MRP_IN_ROLE], + extack); + if (err) + return err; + } + + if (tb[IFLA_BRIDGE_MRP_START_IN_TEST]) { + err = br_mrp_start_in_test_parse(br, + tb[IFLA_BRIDGE_MRP_START_IN_TEST], + extack); + if (err) + return err; + } + + return 0; +} + +int br_mrp_fill_info(struct sk_buff *skb, struct net_bridge *br) +{ + struct nlattr *tb, *mrp_tb; + struct br_mrp 
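+ /* (Summary of br_mrp_parse() above, for orientation: each nested
+  *  attribute maps onto one configuration call and they are processed in
+  *  the order shown, with the whole request rejected up front when STP is
+  *  enabled on the bridge:
+  *
+  *      IFLA_BRIDGE_MRP_INSTANCE       -> br_mrp_add() / br_mrp_del()
+  *      IFLA_BRIDGE_MRP_PORT_STATE     -> br_mrp_set_port_state()
+  *      IFLA_BRIDGE_MRP_PORT_ROLE      -> br_mrp_set_port_role()
+  *      IFLA_BRIDGE_MRP_RING_STATE     -> br_mrp_set_ring_state()
+  *      IFLA_BRIDGE_MRP_RING_ROLE      -> br_mrp_set_ring_role()
+  *      IFLA_BRIDGE_MRP_START_TEST     -> br_mrp_start_test()
+  *      IFLA_BRIDGE_MRP_IN_STATE       -> br_mrp_set_in_state()
+  *      IFLA_BRIDGE_MRP_IN_ROLE        -> br_mrp_set_in_role()
+  *      IFLA_BRIDGE_MRP_START_IN_TEST  -> br_mrp_start_in_test()) */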
*mrp; + + mrp_tb = nla_nest_start_noflag(skb, IFLA_BRIDGE_MRP); + if (!mrp_tb) + return -EMSGSIZE; + + hlist_for_each_entry_rcu(mrp, &br->mrp_list, list) { + struct net_bridge_port *p; + + tb = nla_nest_start_noflag(skb, IFLA_BRIDGE_MRP_INFO); + if (!tb) + goto nla_info_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_RING_ID, + mrp->ring_id)) + goto nla_put_failure; + + p = rcu_dereference(mrp->p_port); + if (p && nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_P_IFINDEX, + p->dev->ifindex)) + goto nla_put_failure; + + p = rcu_dereference(mrp->s_port); + if (p && nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_S_IFINDEX, + p->dev->ifindex)) + goto nla_put_failure; + + p = rcu_dereference(mrp->i_port); + if (p && nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_I_IFINDEX, + p->dev->ifindex)) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_BRIDGE_MRP_INFO_PRIO, + mrp->prio)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_RING_STATE, + mrp->ring_state)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_RING_ROLE, + mrp->ring_role)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_TEST_INTERVAL, + mrp->test_interval)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_TEST_MAX_MISS, + mrp->test_max_miss)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_TEST_MONITOR, + mrp->test_monitor)) + goto nla_put_failure; + + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_IN_STATE, + mrp->in_state)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_IN_ROLE, + mrp->in_role)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_IN_TEST_INTERVAL, + mrp->in_test_interval)) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_BRIDGE_MRP_INFO_IN_TEST_MAX_MISS, + mrp->in_test_max_miss)) + goto nla_put_failure; + + nla_nest_end(skb, tb); + } + nla_nest_end(skb, mrp_tb); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, tb); + +nla_info_failure: + nla_nest_cancel(skb, mrp_tb); + + return -EMSGSIZE; +} + +int br_mrp_ring_port_open(struct net_device *dev, u8 loc) +{ + struct net_bridge_port *p; + int err = 0; + + p = br_port_get_rcu(dev); + if (!p) { + err = -EINVAL; + goto out; + } + + if (loc) + p->flags |= BR_MRP_LOST_CONT; + else + p->flags &= ~BR_MRP_LOST_CONT; + + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + +out: + return err; +} + +int br_mrp_in_port_open(struct net_device *dev, u8 loc) +{ + struct net_bridge_port *p; + int err = 0; + + p = br_port_get_rcu(dev); + if (!p) { + err = -EINVAL; + goto out; + } + + if (loc) + p->flags |= BR_MRP_LOST_IN_CONT; + else + p->flags &= ~BR_MRP_LOST_IN_CONT; + + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + +out: + return err; +} diff --git a/net/bridge/br_mrp_switchdev.c b/net/bridge/br_mrp_switchdev.c new file mode 100644 index 000000000..cb54b324f --- /dev/null +++ b/net/bridge/br_mrp_switchdev.c @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include + +#include "br_private_mrp.h" + +static enum br_mrp_hw_support +br_mrp_switchdev_port_obj(struct net_bridge *br, + const struct switchdev_obj *obj, bool add) +{ + int err; + + if (add) + err = switchdev_port_obj_add(br->dev, obj, NULL); + else + err = switchdev_port_obj_del(br->dev, obj); + + /* In case of success just return and notify the SW that doesn't need + * to do anything + */ + if (!err) + return BR_MRP_HW; + + if (err != -EOPNOTSUPP) + return BR_MRP_NONE; + + /* Continue with SW backup */ + return BR_MRP_SW; +} + +int br_mrp_switchdev_add(struct net_bridge *br, struct br_mrp *mrp) +{ + 
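+ /* (Note on the helper above, which sets the pattern for this file: the
+  *  object is first offered to the driver; success means the hardware runs
+  *  that part of MRP on its own (BR_MRP_HW), -EOPNOTSUPP means fall back
+  *  to running it in software (BR_MRP_SW), and any other error means the
+  *  operation cannot be supported at all (BR_MRP_NONE).) */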
struct switchdev_obj_mrp mrp_obj = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_MRP, + .p_port = rtnl_dereference(mrp->p_port)->dev, + .s_port = rtnl_dereference(mrp->s_port)->dev, + .ring_id = mrp->ring_id, + .prio = mrp->prio, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_obj_add(br->dev, &mrp_obj.obj, NULL); +} + +int br_mrp_switchdev_del(struct net_bridge *br, struct br_mrp *mrp) +{ + struct switchdev_obj_mrp mrp_obj = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_MRP, + .p_port = NULL, + .s_port = NULL, + .ring_id = mrp->ring_id, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_obj_del(br->dev, &mrp_obj.obj); +} + +enum br_mrp_hw_support +br_mrp_switchdev_set_ring_role(struct net_bridge *br, struct br_mrp *mrp, + enum br_mrp_ring_role_type role) +{ + struct switchdev_obj_ring_role_mrp mrp_role = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_RING_ROLE_MRP, + .ring_role = role, + .ring_id = mrp->ring_id, + .sw_backup = false, + }; + enum br_mrp_hw_support support; + int err; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return BR_MRP_SW; + + support = br_mrp_switchdev_port_obj(br, &mrp_role.obj, + role != BR_MRP_RING_ROLE_DISABLED); + if (support != BR_MRP_SW) + return support; + + /* If the driver can't configure to run completely the protocol in HW, + * then try again to configure the HW so the SW can run the protocol. + */ + mrp_role.sw_backup = true; + if (role != BR_MRP_RING_ROLE_DISABLED) + err = switchdev_port_obj_add(br->dev, &mrp_role.obj, NULL); + else + err = switchdev_port_obj_del(br->dev, &mrp_role.obj); + + if (!err) + return BR_MRP_SW; + + return BR_MRP_NONE; +} + +enum br_mrp_hw_support +br_mrp_switchdev_send_ring_test(struct net_bridge *br, struct br_mrp *mrp, + u32 interval, u8 max_miss, u32 period, + bool monitor) +{ + struct switchdev_obj_ring_test_mrp test = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_RING_TEST_MRP, + .interval = interval, + .max_miss = max_miss, + .ring_id = mrp->ring_id, + .period = period, + .monitor = monitor, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return BR_MRP_SW; + + return br_mrp_switchdev_port_obj(br, &test.obj, interval != 0); +} + +int br_mrp_switchdev_set_ring_state(struct net_bridge *br, + struct br_mrp *mrp, + enum br_mrp_ring_state_type state) +{ + struct switchdev_obj_ring_state_mrp mrp_state = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_RING_STATE_MRP, + .ring_state = state, + .ring_id = mrp->ring_id, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_obj_add(br->dev, &mrp_state.obj, NULL); +} + +enum br_mrp_hw_support +br_mrp_switchdev_set_in_role(struct net_bridge *br, struct br_mrp *mrp, + u16 in_id, u32 ring_id, + enum br_mrp_in_role_type role) +{ + struct switchdev_obj_in_role_mrp mrp_role = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_IN_ROLE_MRP, + .in_role = role, + .in_id = mrp->in_id, + .ring_id = mrp->ring_id, + .i_port = rtnl_dereference(mrp->i_port)->dev, + .sw_backup = false, + }; + enum br_mrp_hw_support support; + int err; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return BR_MRP_SW; + + support = br_mrp_switchdev_port_obj(br, &mrp_role.obj, + role != BR_MRP_IN_ROLE_DISABLED); + if (support != BR_MRP_NONE) + return support; + + /* If the driver can't configure to run completely the protocol in HW, + * then try again to configure the HW so the SW can run the protocol. 
+ */ + mrp_role.sw_backup = true; + if (role != BR_MRP_IN_ROLE_DISABLED) + err = switchdev_port_obj_add(br->dev, &mrp_role.obj, NULL); + else + err = switchdev_port_obj_del(br->dev, &mrp_role.obj); + + if (!err) + return BR_MRP_SW; + + return BR_MRP_NONE; +} + +int br_mrp_switchdev_set_in_state(struct net_bridge *br, struct br_mrp *mrp, + enum br_mrp_in_state_type state) +{ + struct switchdev_obj_in_state_mrp mrp_state = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_IN_STATE_MRP, + .in_state = state, + .in_id = mrp->in_id, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_obj_add(br->dev, &mrp_state.obj, NULL); +} + +enum br_mrp_hw_support +br_mrp_switchdev_send_in_test(struct net_bridge *br, struct br_mrp *mrp, + u32 interval, u8 max_miss, u32 period) +{ + struct switchdev_obj_in_test_mrp test = { + .obj.orig_dev = br->dev, + .obj.id = SWITCHDEV_OBJ_ID_IN_TEST_MRP, + .interval = interval, + .max_miss = max_miss, + .in_id = mrp->in_id, + .period = period, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return BR_MRP_SW; + + return br_mrp_switchdev_port_obj(br, &test.obj, interval != 0); +} + +int br_mrp_port_switchdev_set_state(struct net_bridge_port *p, u32 state) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + .id = SWITCHDEV_ATTR_ID_PORT_STP_STATE, + .u.stp_state = state, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_attr_set(p->dev, &attr, NULL); +} + +int br_mrp_port_switchdev_set_role(struct net_bridge_port *p, + enum br_mrp_port_role_type role) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + .id = SWITCHDEV_ATTR_ID_MRP_PORT_ROLE, + .u.mrp_port_role = role, + }; + + if (!IS_ENABLED(CONFIG_NET_SWITCHDEV)) + return 0; + + return switchdev_port_attr_set(p->dev, &attr, NULL); +} diff --git a/net/bridge/br_mst.c b/net/bridge/br_mst.c new file mode 100644 index 000000000..ee680adce --- /dev/null +++ b/net/bridge/br_mst.c @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Bridge Multiple Spanning Tree Support + * + * Authors: + * Tobias Waldekranz + */ + +#include +#include + +#include "br_private.h" + +DEFINE_STATIC_KEY_FALSE(br_mst_used); + +bool br_mst_enabled(const struct net_device *dev) +{ + if (!netif_is_bridge_master(dev)) + return false; + + return br_opt_get(netdev_priv(dev), BROPT_MST_ENABLED); +} +EXPORT_SYMBOL_GPL(br_mst_enabled); + +int br_mst_get_info(const struct net_device *dev, u16 msti, unsigned long *vids) +{ + const struct net_bridge_vlan_group *vg; + const struct net_bridge_vlan *v; + const struct net_bridge *br; + + ASSERT_RTNL(); + + if (!netif_is_bridge_master(dev)) + return -EINVAL; + + br = netdev_priv(dev); + if (!br_opt_get(br, BROPT_MST_ENABLED)) + return -EINVAL; + + vg = br_vlan_group(br); + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (v->msti == msti) + __set_bit(v->vid, vids); + } + + return 0; +} +EXPORT_SYMBOL_GPL(br_mst_get_info); + +int br_mst_get_state(const struct net_device *dev, u16 msti, u8 *state) +{ + const struct net_bridge_port *p = NULL; + const struct net_bridge_vlan_group *vg; + const struct net_bridge_vlan *v; + + ASSERT_RTNL(); + + p = br_port_get_check_rtnl(dev); + if (!p || !br_opt_get(p->br, BROPT_MST_ENABLED)) + return -EINVAL; + + vg = nbp_vlan_group(p); + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (v->brvlan->msti == msti) { + *state = v->state; + return 0; + } + } + + return -ENOENT; +} +EXPORT_SYMBOL_GPL(br_mst_get_state); + +static void br_mst_vlan_set_state(struct 
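+ /* (Illustrative sketch of how a switchdev driver might consume the two
+  *  exported getters above; rtnl is assumed to be held, as both require,
+  *  and program_vlan_to_msti()/program_port_msti_state() are hypothetical
+  *  driver callbacks, not kernel API:
+  *
+  *      DECLARE_BITMAP(vids, VLAN_N_VID) = { 0 };
+  *      unsigned int vid;
+  *      u8 state;
+  *
+  *      if (!br_mst_get_info(br_dev, msti, vids))
+  *              for_each_set_bit(vid, vids, VLAN_N_VID)
+  *                      program_vlan_to_msti(vid, msti);
+  *
+  *      if (!br_mst_get_state(port_dev, msti, &state))
+  *              program_port_msti_state(port_dev, msti, state);) */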
net_bridge_port *p, struct net_bridge_vlan *v, + u8 state) +{ + struct net_bridge_vlan_group *vg = nbp_vlan_group(p); + + if (v->state == state) + return; + + br_vlan_set_state(v, state); + + if (v->vid == vg->pvid) + br_vlan_set_pvid_state(vg, state); +} + +int br_mst_set_state(struct net_bridge_port *p, u16 msti, u8 state, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .id = SWITCHDEV_ATTR_ID_PORT_MST_STATE, + .orig_dev = p->dev, + .u.mst_state = { + .msti = msti, + .state = state, + }, + }; + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + int err; + + vg = nbp_vlan_group(p); + if (!vg) + return 0; + + /* MSTI 0 (CST) state changes are notified via the regular + * SWITCHDEV_ATTR_ID_PORT_STP_STATE. + */ + if (msti) { + err = switchdev_port_attr_set(p->dev, &attr, extack); + if (err && err != -EOPNOTSUPP) + return err; + } + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (v->brvlan->msti != msti) + continue; + + br_mst_vlan_set_state(p, v, state); + } + + return 0; +} + +static void br_mst_vlan_sync_state(struct net_bridge_vlan *pv, u16 msti) +{ + struct net_bridge_vlan_group *vg = nbp_vlan_group(pv->port); + struct net_bridge_vlan *v; + + list_for_each_entry(v, &vg->vlan_list, vlist) { + /* If this port already has a defined state in this + * MSTI (through some other VLAN membership), inherit + * it. + */ + if (v != pv && v->brvlan->msti == msti) { + br_mst_vlan_set_state(pv->port, pv, v->state); + return; + } + } + + /* Otherwise, start out in a new MSTI with all ports disabled. */ + return br_mst_vlan_set_state(pv->port, pv, BR_STATE_DISABLED); +} + +int br_mst_vlan_set_msti(struct net_bridge_vlan *mv, u16 msti) +{ + struct switchdev_attr attr = { + .id = SWITCHDEV_ATTR_ID_VLAN_MSTI, + .orig_dev = mv->br->dev, + .u.vlan_msti = { + .vid = mv->vid, + .msti = msti, + }, + }; + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *pv; + struct net_bridge_port *p; + int err; + + if (mv->msti == msti) + return 0; + + err = switchdev_port_attr_set(mv->br->dev, &attr, NULL); + if (err && err != -EOPNOTSUPP) + return err; + + mv->msti = msti; + + list_for_each_entry(p, &mv->br->port_list, list) { + vg = nbp_vlan_group(p); + + pv = br_vlan_find(vg, mv->vid); + if (pv) + br_mst_vlan_sync_state(pv, msti); + } + + return 0; +} + +void br_mst_vlan_init_state(struct net_bridge_vlan *v) +{ + /* VLANs always start out in MSTI 0 (CST) */ + v->msti = 0; + + if (br_vlan_is_master(v)) + v->state = BR_STATE_FORWARDING; + else + v->state = v->port->state; +} + +int br_mst_set_enabled(struct net_bridge *br, bool on, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .id = SWITCHDEV_ATTR_ID_BRIDGE_MST, + .orig_dev = br->dev, + .u.mst = on, + }; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p; + int err; + + list_for_each_entry(p, &br->port_list, list) { + vg = nbp_vlan_group(p); + + if (!vg->num_vlans) + continue; + + NL_SET_ERR_MSG(extack, + "MST mode can't be changed while VLANs exist"); + return -EBUSY; + } + + if (br_opt_get(br, BROPT_MST_ENABLED) == on) + return 0; + + err = switchdev_port_attr_set(br->dev, &attr, extack); + if (err && err != -EOPNOTSUPP) + return err; + + if (on) + static_branch_enable(&br_mst_used); + else + static_branch_disable(&br_mst_used); + + br_opt_toggle(br, BROPT_MST_ENABLED, on); + return 0; +} + +size_t br_mst_info_size(const struct net_bridge_vlan_group *vg) +{ + DECLARE_BITMAP(seen, VLAN_N_VID) = { 0 }; + const struct net_bridge_vlan *v; + size_t sz; + + /* IFLA_BRIDGE_MST */ + sz = 
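+ /* (Summary of the model implemented above, for orientation: every bridge
+  *  VLAN carries an msti field, 0 being the CST whose state simply tracks
+  *  the regular per-port STP state.  br_mst_vlan_set_msti() moves a VLAN
+  *  into another MSTI and re-syncs each port's state for that VLAN,
+  *  inheriting from VLANs already in the target MSTI or starting out
+  *  disabled, and br_mst_set_state() then applies one STP state to all of
+  *  a port's VLANs in a given MSTI, notifying drivers through
+  *  SWITCHDEV_ATTR_ID_PORT_MST_STATE for msti != 0.  MST mode itself can
+  *  only be toggled while the bridge ports have no VLANs configured.) */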
nla_total_size(0); + + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + if (test_bit(v->brvlan->msti, seen)) + continue; + + /* IFLA_BRIDGE_MST_ENTRY */ + sz += nla_total_size(0) + + /* IFLA_BRIDGE_MST_ENTRY_MSTI */ + nla_total_size(sizeof(u16)) + + /* IFLA_BRIDGE_MST_ENTRY_STATE */ + nla_total_size(sizeof(u8)); + + __set_bit(v->brvlan->msti, seen); + } + + return sz; +} + +int br_mst_fill_info(struct sk_buff *skb, + const struct net_bridge_vlan_group *vg) +{ + DECLARE_BITMAP(seen, VLAN_N_VID) = { 0 }; + const struct net_bridge_vlan *v; + struct nlattr *nest; + int err = 0; + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (test_bit(v->brvlan->msti, seen)) + continue; + + nest = nla_nest_start_noflag(skb, IFLA_BRIDGE_MST_ENTRY); + if (!nest || + nla_put_u16(skb, IFLA_BRIDGE_MST_ENTRY_MSTI, v->brvlan->msti) || + nla_put_u8(skb, IFLA_BRIDGE_MST_ENTRY_STATE, v->state)) { + err = -EMSGSIZE; + break; + } + nla_nest_end(skb, nest); + + __set_bit(v->brvlan->msti, seen); + } + + return err; +} + +static const struct nla_policy br_mst_nl_policy[IFLA_BRIDGE_MST_ENTRY_MAX + 1] = { + [IFLA_BRIDGE_MST_ENTRY_MSTI] = NLA_POLICY_RANGE(NLA_U16, + 1, /* 0 reserved for CST */ + VLAN_N_VID - 1), + [IFLA_BRIDGE_MST_ENTRY_STATE] = NLA_POLICY_RANGE(NLA_U8, + BR_STATE_DISABLED, + BR_STATE_BLOCKING), +}; + +static int br_mst_process_one(struct net_bridge_port *p, + const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_BRIDGE_MST_ENTRY_MAX + 1]; + u16 msti; + u8 state; + int err; + + err = nla_parse_nested(tb, IFLA_BRIDGE_MST_ENTRY_MAX, attr, + br_mst_nl_policy, extack); + if (err) + return err; + + if (!tb[IFLA_BRIDGE_MST_ENTRY_MSTI]) { + NL_SET_ERR_MSG_MOD(extack, "MSTI not specified"); + return -EINVAL; + } + + if (!tb[IFLA_BRIDGE_MST_ENTRY_STATE]) { + NL_SET_ERR_MSG_MOD(extack, "State not specified"); + return -EINVAL; + } + + msti = nla_get_u16(tb[IFLA_BRIDGE_MST_ENTRY_MSTI]); + state = nla_get_u8(tb[IFLA_BRIDGE_MST_ENTRY_STATE]); + + return br_mst_set_state(p, msti, state, extack); +} + +int br_mst_process(struct net_bridge_port *p, const struct nlattr *mst_attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *attr; + int err, msts = 0; + int rem; + + if (!br_opt_get(p->br, BROPT_MST_ENABLED)) { + NL_SET_ERR_MSG_MOD(extack, "Can't modify MST state when MST is disabled"); + return -EBUSY; + } + + nla_for_each_nested(attr, mst_attr, rem) { + switch (nla_type(attr)) { + case IFLA_BRIDGE_MST_ENTRY: + err = br_mst_process_one(p, attr, extack); + break; + default: + continue; + } + + msts++; + if (err) + break; + } + + if (!msts) { + NL_SET_ERR_MSG_MOD(extack, "Found no MST entries to process"); + err = -EINVAL; + } + + return err; +} diff --git a/net/bridge/br_multicast.c b/net/bridge/br_multicast.c new file mode 100644 index 000000000..db4f2641d --- /dev/null +++ b/net/bridge/br_multicast.c @@ -0,0 +1,4946 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Bridge multicast support. 
+ * + * Copyright (c) 2010 Herbert Xu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#include +#include +#include +#endif + +#include "br_private.h" +#include "br_private_mcast_eht.h" + +static const struct rhashtable_params br_mdb_rht_params = { + .head_offset = offsetof(struct net_bridge_mdb_entry, rhnode), + .key_offset = offsetof(struct net_bridge_mdb_entry, addr), + .key_len = sizeof(struct br_ip), + .automatic_shrinking = true, +}; + +static const struct rhashtable_params br_sg_port_rht_params = { + .head_offset = offsetof(struct net_bridge_port_group, rhnode), + .key_offset = offsetof(struct net_bridge_port_group, key), + .key_len = sizeof(struct net_bridge_port_group_sg_key), + .automatic_shrinking = true, +}; + +static void br_multicast_start_querier(struct net_bridge_mcast *brmctx, + struct bridge_mcast_own_query *query); +static void br_ip4_multicast_add_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx); +static void br_ip4_multicast_leave_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + __be32 group, + __u16 vid, + const unsigned char *src); +static void br_multicast_port_group_rexmit(struct timer_list *t); + +static void +br_multicast_rport_del_notify(struct net_bridge_mcast_port *pmctx, bool deleted); +static void br_ip6_multicast_add_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx); +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_leave_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + const struct in6_addr *group, + __u16 vid, const unsigned char *src); +#endif +static struct net_bridge_port_group * +__br_multicast_add_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct br_ip *group, + const unsigned char *src, + u8 filter_mode, + bool igmpv2_mldv1, + bool blocked); +static void br_multicast_find_del_pg(struct net_bridge *br, + struct net_bridge_port_group *pg); +static void __br_multicast_stop(struct net_bridge_mcast *brmctx); + +static int br_mc_disabled_update(struct net_device *dev, bool value, + struct netlink_ext_ack *extack); + +static struct net_bridge_port_group * +br_sg_port_find(struct net_bridge *br, + struct net_bridge_port_group_sg_key *sg_p) +{ + lockdep_assert_held_once(&br->multicast_lock); + + return rhashtable_lookup_fast(&br->sg_port_tbl, sg_p, + br_sg_port_rht_params); +} + +static struct net_bridge_mdb_entry *br_mdb_ip_get_rcu(struct net_bridge *br, + struct br_ip *dst) +{ + return rhashtable_lookup(&br->mdb_hash_tbl, dst, br_mdb_rht_params); +} + +struct net_bridge_mdb_entry *br_mdb_ip_get(struct net_bridge *br, + struct br_ip *dst) +{ + struct net_bridge_mdb_entry *ent; + + lockdep_assert_held_once(&br->multicast_lock); + + rcu_read_lock(); + ent = rhashtable_lookup(&br->mdb_hash_tbl, dst, br_mdb_rht_params); + rcu_read_unlock(); + + return ent; +} + +static struct net_bridge_mdb_entry *br_mdb_ip4_get(struct net_bridge *br, + __be32 dst, __u16 vid) +{ + struct br_ip br_dst; + + memset(&br_dst, 0, sizeof(br_dst)); + br_dst.dst.ip4 = dst; + br_dst.proto = htons(ETH_P_IP); + br_dst.vid = vid; + + return br_mdb_ip_get(br, &br_dst); +} + +#if IS_ENABLED(CONFIG_IPV6) +static struct net_bridge_mdb_entry *br_mdb_ip6_get(struct net_bridge *br, + const struct in6_addr *dst, + __u16 vid) +{ + struct 
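+ /* (Note on the two lookup tables used above: br->mdb_hash_tbl is keyed by
+  *  struct br_ip -- protocol, vid, destination and, for IGMPv3/MLDv2
+  *  source-specific entries, the source -- and yields the bridge-wide mdb
+  *  entry, while br->sg_port_tbl is keyed by {port, br_ip} and yields the
+  *  per-port S,G port group, so the same group can be resolved either
+  *  bridge-wide or per port.) */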
br_ip br_dst; + + memset(&br_dst, 0, sizeof(br_dst)); + br_dst.dst.ip6 = *dst; + br_dst.proto = htons(ETH_P_IPV6); + br_dst.vid = vid; + + return br_mdb_ip_get(br, &br_dst); +} +#endif + +struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge_mcast *brmctx, + struct sk_buff *skb, u16 vid) +{ + struct net_bridge *br = brmctx->br; + struct br_ip ip; + + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED) || + br_multicast_ctx_vlan_global_disabled(brmctx)) + return NULL; + + if (BR_INPUT_SKB_CB(skb)->igmp) + return NULL; + + memset(&ip, 0, sizeof(ip)); + ip.proto = skb->protocol; + ip.vid = vid; + + switch (skb->protocol) { + case htons(ETH_P_IP): + ip.dst.ip4 = ip_hdr(skb)->daddr; + if (brmctx->multicast_igmp_version == 3) { + struct net_bridge_mdb_entry *mdb; + + ip.src.ip4 = ip_hdr(skb)->saddr; + mdb = br_mdb_ip_get_rcu(br, &ip); + if (mdb) + return mdb; + ip.src.ip4 = 0; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + ip.dst.ip6 = ipv6_hdr(skb)->daddr; + if (brmctx->multicast_mld_version == 2) { + struct net_bridge_mdb_entry *mdb; + + ip.src.ip6 = ipv6_hdr(skb)->saddr; + mdb = br_mdb_ip_get_rcu(br, &ip); + if (mdb) + return mdb; + memset(&ip.src.ip6, 0, sizeof(ip.src.ip6)); + } + break; +#endif + default: + ip.proto = 0; + ether_addr_copy(ip.dst.mac_addr, eth_hdr(skb)->h_dest); + } + + return br_mdb_ip_get_rcu(br, &ip); +} + +/* IMPORTANT: this function must be used only when the contexts cannot be + * passed down (e.g. timer) and must be used for read-only purposes because + * the vlan snooping option can change, so it can return any context + * (non-vlan or vlan). Its initial intended purpose is to read timer values + * from the *current* context based on the option. At worst that could lead + * to inconsistent timers when the contexts are changed, i.e. 
src timer + * which needs to re-arm with a specific delay taken from the old context + */ +static struct net_bridge_mcast_port * +br_multicast_pg_to_port_ctx(const struct net_bridge_port_group *pg) +{ + struct net_bridge_mcast_port *pmctx = &pg->key.port->multicast_ctx; + struct net_bridge_vlan *vlan; + + lockdep_assert_held_once(&pg->key.port->br->multicast_lock); + + /* if vlan snooping is disabled use the port's multicast context */ + if (!pg->key.addr.vid || + !br_opt_get(pg->key.port->br, BROPT_MCAST_VLAN_SNOOPING_ENABLED)) + goto out; + + /* locking is tricky here, due to different rules for multicast and + * vlans we need to take rcu to find the vlan and make sure it has + * the BR_VLFLAG_MCAST_ENABLED flag set, it can only change under + * multicast_lock which must be already held here, so the vlan's pmctx + * can safely be used on return + */ + rcu_read_lock(); + vlan = br_vlan_find(nbp_vlan_group_rcu(pg->key.port), pg->key.addr.vid); + if (vlan && !br_multicast_port_ctx_vlan_disabled(&vlan->port_mcast_ctx)) + pmctx = &vlan->port_mcast_ctx; + else + pmctx = NULL; + rcu_read_unlock(); +out: + return pmctx; +} + +/* when snooping we need to check if the contexts should be used + * in the following order: + * - if pmctx is non-NULL (port), check if it should be used + * - if pmctx is NULL (bridge), check if brmctx should be used + */ +static bool +br_multicast_ctx_should_use(const struct net_bridge_mcast *brmctx, + const struct net_bridge_mcast_port *pmctx) +{ + if (!netif_running(brmctx->br->dev)) + return false; + + if (pmctx) + return !br_multicast_port_ctx_state_disabled(pmctx); + else + return !br_multicast_ctx_vlan_disabled(brmctx); +} + +static bool br_port_group_equal(struct net_bridge_port_group *p, + struct net_bridge_port *port, + const unsigned char *src) +{ + if (p->key.port != port) + return false; + + if (!(port->flags & BR_MULTICAST_TO_UNICAST)) + return true; + + return ether_addr_equal(src, p->eth_addr); +} + +static void __fwd_add_star_excl(struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, + struct br_ip *sg_ip) +{ + struct net_bridge_port_group_sg_key sg_key; + struct net_bridge_port_group *src_pg; + struct net_bridge_mcast *brmctx; + + memset(&sg_key, 0, sizeof(sg_key)); + brmctx = br_multicast_port_ctx_get_global(pmctx); + sg_key.port = pg->key.port; + sg_key.addr = *sg_ip; + if (br_sg_port_find(brmctx->br, &sg_key)) + return; + + src_pg = __br_multicast_add_group(brmctx, pmctx, + sg_ip, pg->eth_addr, + MCAST_INCLUDE, false, false); + if (IS_ERR_OR_NULL(src_pg) || + src_pg->rt_protocol != RTPROT_KERNEL) + return; + + src_pg->flags |= MDB_PG_FLAGS_STAR_EXCL; +} + +static void __fwd_del_star_excl(struct net_bridge_port_group *pg, + struct br_ip *sg_ip) +{ + struct net_bridge_port_group_sg_key sg_key; + struct net_bridge *br = pg->key.port->br; + struct net_bridge_port_group *src_pg; + + memset(&sg_key, 0, sizeof(sg_key)); + sg_key.port = pg->key.port; + sg_key.addr = *sg_ip; + src_pg = br_sg_port_find(br, &sg_key); + if (!src_pg || !(src_pg->flags & MDB_PG_FLAGS_STAR_EXCL) || + src_pg->rt_protocol != RTPROT_KERNEL) + return; + + br_multicast_find_del_pg(br, src_pg); +} + +/* When a port group transitions to (or is added as) EXCLUDE we need to add it + * to all other ports' S,G entries which are not blocked by the current group + * for proper replication, the assumption is that any S,G blocked entries + * are already added so the S,G,port lookup should skip them. 
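+ * A concrete example, added purely for illustration: if port A has joined
+ * (*, G) in EXCLUDE {} mode and port B then gets an (S1, G) entry, an
+ * (S1, G) entry is also auto-installed for port A and flagged
+ * MDB_PG_FLAGS_STAR_EXCL, so traffic from S1 keeps reaching the EXCLUDE
+ * receiver once source-specific forwarding entries exist for the group.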
+ * When a port group transitions from EXCLUDE -> INCLUDE mode or is being + * deleted we need to remove it from all ports' S,G entries where it was + * automatically installed before (i.e. where it's MDB_PG_FLAGS_STAR_EXCL). + */ +void br_multicast_star_g_handle_mode(struct net_bridge_port_group *pg, + u8 filter_mode) +{ + struct net_bridge *br = pg->key.port->br; + struct net_bridge_port_group *pg_lst; + struct net_bridge_mcast_port *pmctx; + struct net_bridge_mdb_entry *mp; + struct br_ip sg_ip; + + if (WARN_ON(!br_multicast_is_star_g(&pg->key.addr))) + return; + + mp = br_mdb_ip_get(br, &pg->key.addr); + if (!mp) + return; + pmctx = br_multicast_pg_to_port_ctx(pg); + if (!pmctx) + return; + + memset(&sg_ip, 0, sizeof(sg_ip)); + sg_ip = pg->key.addr; + + for (pg_lst = mlock_dereference(mp->ports, br); + pg_lst; + pg_lst = mlock_dereference(pg_lst->next, br)) { + struct net_bridge_group_src *src_ent; + + if (pg_lst == pg) + continue; + hlist_for_each_entry(src_ent, &pg_lst->src_list, node) { + if (!(src_ent->flags & BR_SGRP_F_INSTALLED)) + continue; + sg_ip.src = src_ent->addr.src; + switch (filter_mode) { + case MCAST_INCLUDE: + __fwd_del_star_excl(pg, &sg_ip); + break; + case MCAST_EXCLUDE: + __fwd_add_star_excl(pmctx, pg, &sg_ip); + break; + } + } + } +} + +/* called when adding a new S,G with host_joined == false by default */ +static void br_multicast_sg_host_state(struct net_bridge_mdb_entry *star_mp, + struct net_bridge_port_group *sg) +{ + struct net_bridge_mdb_entry *sg_mp; + + if (WARN_ON(!br_multicast_is_star_g(&star_mp->addr))) + return; + if (!star_mp->host_joined) + return; + + sg_mp = br_mdb_ip_get(star_mp->br, &sg->key.addr); + if (!sg_mp) + return; + sg_mp->host_joined = true; +} + +/* set the host_joined state of all of *,G's S,G entries */ +static void br_multicast_star_g_host_state(struct net_bridge_mdb_entry *star_mp) +{ + struct net_bridge *br = star_mp->br; + struct net_bridge_mdb_entry *sg_mp; + struct net_bridge_port_group *pg; + struct br_ip sg_ip; + + if (WARN_ON(!br_multicast_is_star_g(&star_mp->addr))) + return; + + memset(&sg_ip, 0, sizeof(sg_ip)); + sg_ip = star_mp->addr; + for (pg = mlock_dereference(star_mp->ports, br); + pg; + pg = mlock_dereference(pg->next, br)) { + struct net_bridge_group_src *src_ent; + + hlist_for_each_entry(src_ent, &pg->src_list, node) { + if (!(src_ent->flags & BR_SGRP_F_INSTALLED)) + continue; + sg_ip.src = src_ent->addr.src; + sg_mp = br_mdb_ip_get(br, &sg_ip); + if (!sg_mp) + continue; + sg_mp->host_joined = star_mp->host_joined; + } + } +} + +static void br_multicast_sg_del_exclude_ports(struct net_bridge_mdb_entry *sgmp) +{ + struct net_bridge_port_group __rcu **pp; + struct net_bridge_port_group *p; + + /* *,G exclude ports are only added to S,G entries */ + if (WARN_ON(br_multicast_is_star_g(&sgmp->addr))) + return; + + /* we need the STAR_EXCLUDE ports if there are non-STAR_EXCLUDE ports + * we should ignore perm entries since they're managed by user-space + */ + for (pp = &sgmp->ports; + (p = mlock_dereference(*pp, sgmp->br)) != NULL; + pp = &p->next) + if (!(p->flags & (MDB_PG_FLAGS_STAR_EXCL | + MDB_PG_FLAGS_PERMANENT))) + return; + + /* currently the host can only have joined the *,G which means + * we treat it as EXCLUDE {}, so for an S,G it's considered a + * STAR_EXCLUDE entry and we can safely leave it + */ + sgmp->host_joined = false; + + for (pp = &sgmp->ports; + (p = mlock_dereference(*pp, sgmp->br)) != NULL;) { + if (!(p->flags & MDB_PG_FLAGS_PERMANENT)) + br_multicast_del_pg(sgmp, p, pp); + else + pp = 
&p->next; + } +} + +void br_multicast_sg_add_exclude_ports(struct net_bridge_mdb_entry *star_mp, + struct net_bridge_port_group *sg) +{ + struct net_bridge_port_group_sg_key sg_key; + struct net_bridge *br = star_mp->br; + struct net_bridge_mcast_port *pmctx; + struct net_bridge_port_group *pg; + struct net_bridge_mcast *brmctx; + + if (WARN_ON(br_multicast_is_star_g(&sg->key.addr))) + return; + if (WARN_ON(!br_multicast_is_star_g(&star_mp->addr))) + return; + + br_multicast_sg_host_state(star_mp, sg); + memset(&sg_key, 0, sizeof(sg_key)); + sg_key.addr = sg->key.addr; + /* we need to add all exclude ports to the S,G */ + for (pg = mlock_dereference(star_mp->ports, br); + pg; + pg = mlock_dereference(pg->next, br)) { + struct net_bridge_port_group *src_pg; + + if (pg == sg || pg->filter_mode == MCAST_INCLUDE) + continue; + + sg_key.port = pg->key.port; + if (br_sg_port_find(br, &sg_key)) + continue; + + pmctx = br_multicast_pg_to_port_ctx(pg); + if (!pmctx) + continue; + brmctx = br_multicast_port_ctx_get_global(pmctx); + + src_pg = __br_multicast_add_group(brmctx, pmctx, + &sg->key.addr, + sg->eth_addr, + MCAST_INCLUDE, false, false); + if (IS_ERR_OR_NULL(src_pg) || + src_pg->rt_protocol != RTPROT_KERNEL) + continue; + src_pg->flags |= MDB_PG_FLAGS_STAR_EXCL; + } +} + +static void br_multicast_fwd_src_add(struct net_bridge_group_src *src) +{ + struct net_bridge_mdb_entry *star_mp; + struct net_bridge_mcast_port *pmctx; + struct net_bridge_port_group *sg; + struct net_bridge_mcast *brmctx; + struct br_ip sg_ip; + + if (src->flags & BR_SGRP_F_INSTALLED) + return; + + memset(&sg_ip, 0, sizeof(sg_ip)); + pmctx = br_multicast_pg_to_port_ctx(src->pg); + if (!pmctx) + return; + brmctx = br_multicast_port_ctx_get_global(pmctx); + sg_ip = src->pg->key.addr; + sg_ip.src = src->addr.src; + + sg = __br_multicast_add_group(brmctx, pmctx, &sg_ip, + src->pg->eth_addr, MCAST_INCLUDE, false, + !timer_pending(&src->timer)); + if (IS_ERR_OR_NULL(sg)) + return; + src->flags |= BR_SGRP_F_INSTALLED; + sg->flags &= ~MDB_PG_FLAGS_STAR_EXCL; + + /* if it was added by user-space as perm we can skip next steps */ + if (sg->rt_protocol != RTPROT_KERNEL && + (sg->flags & MDB_PG_FLAGS_PERMANENT)) + return; + + /* the kernel is now responsible for removing this S,G */ + del_timer(&sg->timer); + star_mp = br_mdb_ip_get(src->br, &src->pg->key.addr); + if (!star_mp) + return; + + br_multicast_sg_add_exclude_ports(star_mp, sg); +} + +static void br_multicast_fwd_src_remove(struct net_bridge_group_src *src, + bool fastleave) +{ + struct net_bridge_port_group *p, *pg = src->pg; + struct net_bridge_port_group __rcu **pp; + struct net_bridge_mdb_entry *mp; + struct br_ip sg_ip; + + memset(&sg_ip, 0, sizeof(sg_ip)); + sg_ip = pg->key.addr; + sg_ip.src = src->addr.src; + + mp = br_mdb_ip_get(src->br, &sg_ip); + if (!mp) + return; + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, src->br)) != NULL; + pp = &p->next) { + if (!br_port_group_equal(p, pg->key.port, pg->eth_addr)) + continue; + + if (p->rt_protocol != RTPROT_KERNEL && + (p->flags & MDB_PG_FLAGS_PERMANENT)) + break; + + if (fastleave) + p->flags |= MDB_PG_FLAGS_FAST_LEAVE; + br_multicast_del_pg(mp, p, pp); + break; + } + src->flags &= ~BR_SGRP_F_INSTALLED; +} + +/* install S,G and based on src's timer enable or disable forwarding */ +static void br_multicast_fwd_src_handle(struct net_bridge_group_src *src) +{ + struct net_bridge_port_group_sg_key sg_key; + struct net_bridge_port_group *sg; + u8 old_flags; + + br_multicast_fwd_src_add(src); + + memset(&sg_key, 0, 
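+ /* (Note on the handling below, for orientation: the per-source timer is
+  *  what drives forwarding of the matching S,G port group.  While the
+  *  timer is pending MDB_PG_FLAGS_BLOCKED is cleared and traffic from that
+  *  source is forwarded; once it expires the flag is set again and the
+  *  entry stops forwarding.  Entries added as permanent by user space are
+  *  left alone, and any flag change is announced with RTM_NEWMDB.) */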
sizeof(sg_key)); + sg_key.addr = src->pg->key.addr; + sg_key.addr.src = src->addr.src; + sg_key.port = src->pg->key.port; + + sg = br_sg_port_find(src->br, &sg_key); + if (!sg || (sg->flags & MDB_PG_FLAGS_PERMANENT)) + return; + + old_flags = sg->flags; + if (timer_pending(&src->timer)) + sg->flags &= ~MDB_PG_FLAGS_BLOCKED; + else + sg->flags |= MDB_PG_FLAGS_BLOCKED; + + if (old_flags != sg->flags) { + struct net_bridge_mdb_entry *sg_mp; + + sg_mp = br_mdb_ip_get(src->br, &sg_key.addr); + if (!sg_mp) + return; + br_mdb_notify(src->br->dev, sg_mp, sg, RTM_NEWMDB); + } +} + +static void br_multicast_destroy_mdb_entry(struct net_bridge_mcast_gc *gc) +{ + struct net_bridge_mdb_entry *mp; + + mp = container_of(gc, struct net_bridge_mdb_entry, mcast_gc); + WARN_ON(!hlist_unhashed(&mp->mdb_node)); + WARN_ON(mp->ports); + + del_timer_sync(&mp->timer); + kfree_rcu(mp, rcu); +} + +static void br_multicast_del_mdb_entry(struct net_bridge_mdb_entry *mp) +{ + struct net_bridge *br = mp->br; + + rhashtable_remove_fast(&br->mdb_hash_tbl, &mp->rhnode, + br_mdb_rht_params); + hlist_del_init_rcu(&mp->mdb_node); + hlist_add_head(&mp->mcast_gc.gc_node, &br->mcast_gc_list); + queue_work(system_long_wq, &br->mcast_gc_work); +} + +static void br_multicast_group_expired(struct timer_list *t) +{ + struct net_bridge_mdb_entry *mp = from_timer(mp, t, timer); + struct net_bridge *br = mp->br; + + spin_lock(&br->multicast_lock); + if (hlist_unhashed(&mp->mdb_node) || !netif_running(br->dev) || + timer_pending(&mp->timer)) + goto out; + + br_multicast_host_leave(mp, true); + + if (mp->ports) + goto out; + br_multicast_del_mdb_entry(mp); +out: + spin_unlock(&br->multicast_lock); +} + +static void br_multicast_destroy_group_src(struct net_bridge_mcast_gc *gc) +{ + struct net_bridge_group_src *src; + + src = container_of(gc, struct net_bridge_group_src, mcast_gc); + WARN_ON(!hlist_unhashed(&src->node)); + + del_timer_sync(&src->timer); + kfree_rcu(src, rcu); +} + +void br_multicast_del_group_src(struct net_bridge_group_src *src, + bool fastleave) +{ + struct net_bridge *br = src->pg->key.port->br; + + br_multicast_fwd_src_remove(src, fastleave); + hlist_del_init_rcu(&src->node); + src->pg->src_ents--; + hlist_add_head(&src->mcast_gc.gc_node, &br->mcast_gc_list); + queue_work(system_long_wq, &br->mcast_gc_work); +} + +static void br_multicast_destroy_port_group(struct net_bridge_mcast_gc *gc) +{ + struct net_bridge_port_group *pg; + + pg = container_of(gc, struct net_bridge_port_group, mcast_gc); + WARN_ON(!hlist_unhashed(&pg->mglist)); + WARN_ON(!hlist_empty(&pg->src_list)); + + del_timer_sync(&pg->rexmit_timer); + del_timer_sync(&pg->timer); + kfree_rcu(pg, rcu); +} + +void br_multicast_del_pg(struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + struct net_bridge_port_group __rcu **pp) +{ + struct net_bridge *br = pg->key.port->br; + struct net_bridge_group_src *ent; + struct hlist_node *tmp; + + rcu_assign_pointer(*pp, pg->next); + hlist_del_init(&pg->mglist); + br_multicast_eht_clean_sets(pg); + hlist_for_each_entry_safe(ent, tmp, &pg->src_list, node) + br_multicast_del_group_src(ent, false); + br_mdb_notify(br->dev, mp, pg, RTM_DELMDB); + if (!br_multicast_is_star_g(&mp->addr)) { + rhashtable_remove_fast(&br->sg_port_tbl, &pg->rhnode, + br_sg_port_rht_params); + br_multicast_sg_del_exclude_ports(mp); + } else { + br_multicast_star_g_handle_mode(pg, MCAST_INCLUDE); + } + hlist_add_head(&pg->mcast_gc.gc_node, &br->mcast_gc_list); + queue_work(system_long_wq, &br->mcast_gc_work); + + if (!mp->ports && 
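+ /* (Note on the teardown pattern above, for orientation: it is two-phase.
+  *  Under br->multicast_lock an object is only unlinked and queued on
+  *  br->mcast_gc_list; br_multicast_gc(), run later from br->mcast_gc_work,
+  *  calls the per-object destroy callback which does del_timer_sync() and
+  *  kfree_rcu(), so the synchronous timer teardown never happens under the
+  *  multicast_lock that the timer handlers themselves take.) */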
!mp->host_joined && netif_running(br->dev)) + mod_timer(&mp->timer, jiffies); +} + +static void br_multicast_find_del_pg(struct net_bridge *br, + struct net_bridge_port_group *pg) +{ + struct net_bridge_port_group __rcu **pp; + struct net_bridge_mdb_entry *mp; + struct net_bridge_port_group *p; + + mp = br_mdb_ip_get(br, &pg->key.addr); + if (WARN_ON(!mp)) + return; + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, br)) != NULL; + pp = &p->next) { + if (p != pg) + continue; + + br_multicast_del_pg(mp, pg, pp); + return; + } + + WARN_ON(1); +} + +static void br_multicast_port_group_expired(struct timer_list *t) +{ + struct net_bridge_port_group *pg = from_timer(pg, t, timer); + struct net_bridge_group_src *src_ent; + struct net_bridge *br = pg->key.port->br; + struct hlist_node *tmp; + bool changed; + + spin_lock(&br->multicast_lock); + if (!netif_running(br->dev) || timer_pending(&pg->timer) || + hlist_unhashed(&pg->mglist) || pg->flags & MDB_PG_FLAGS_PERMANENT) + goto out; + + changed = !!(pg->filter_mode == MCAST_EXCLUDE); + pg->filter_mode = MCAST_INCLUDE; + hlist_for_each_entry_safe(src_ent, tmp, &pg->src_list, node) { + if (!timer_pending(&src_ent->timer)) { + br_multicast_del_group_src(src_ent, false); + changed = true; + } + } + + if (hlist_empty(&pg->src_list)) { + br_multicast_find_del_pg(br, pg); + } else if (changed) { + struct net_bridge_mdb_entry *mp = br_mdb_ip_get(br, &pg->key.addr); + + if (changed && br_multicast_is_star_g(&pg->key.addr)) + br_multicast_star_g_handle_mode(pg, MCAST_INCLUDE); + + if (WARN_ON(!mp)) + goto out; + br_mdb_notify(br->dev, mp, pg, RTM_NEWMDB); + } +out: + spin_unlock(&br->multicast_lock); +} + +static void br_multicast_gc(struct hlist_head *head) +{ + struct net_bridge_mcast_gc *gcent; + struct hlist_node *tmp; + + hlist_for_each_entry_safe(gcent, tmp, head, gc_node) { + hlist_del_init(&gcent->gc_node); + gcent->destroy(gcent); + } +} + +static void __br_multicast_query_handle_vlan(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb) +{ + struct net_bridge_vlan *vlan = NULL; + + if (pmctx && br_multicast_port_ctx_is_vlan(pmctx)) + vlan = pmctx->vlan; + else if (br_multicast_ctx_is_vlan(brmctx)) + vlan = brmctx->vlan; + + if (vlan && !(vlan->flags & BRIDGE_VLAN_INFO_UNTAGGED)) { + u16 vlan_proto; + + if (br_vlan_get_proto(brmctx->br->dev, &vlan_proto) != 0) + return; + __vlan_hwaccel_put_tag(skb, htons(vlan_proto), vlan->vid); + } +} + +static struct sk_buff *br_ip4_multicast_alloc_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, + __be32 ip_dst, __be32 group, + bool with_srcs, bool over_lmqt, + u8 sflag, u8 *igmp_type, + bool *need_rexmit) +{ + struct net_bridge_port *p = pg ? 
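+ /* (Note on the packet built below, for orientation: the query is
+  *  Ethernet + IPv4 + IGMP, where the IPv4 part is 24 bytes -- a 20 byte
+  *  header plus a 4 byte Router Alert option, hence iph->ihl = 6, the
+  *  "+ 4" in pkt_size and the skb_put(skb, 24).  The IGMP part is either a
+  *  plain igmphdr (IGMPv2) or an igmpv3_query carrying the sources that
+  *  still need to be queried, and its checksum is computed over the whole
+  *  IGMP header once the source list has been filled in.) */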
pg->key.port : NULL; + struct net_bridge_group_src *ent; + size_t pkt_size, igmp_hdr_size; + unsigned long now = jiffies; + struct igmpv3_query *ihv3; + void *csum_start = NULL; + __sum16 *csum = NULL; + struct sk_buff *skb; + struct igmphdr *ih; + struct ethhdr *eth; + unsigned long lmqt; + struct iphdr *iph; + u16 lmqt_srcs = 0; + + igmp_hdr_size = sizeof(*ih); + if (brmctx->multicast_igmp_version == 3) { + igmp_hdr_size = sizeof(*ihv3); + if (pg && with_srcs) { + lmqt = now + (brmctx->multicast_last_member_interval * + brmctx->multicast_last_member_count); + hlist_for_each_entry(ent, &pg->src_list, node) { + if (over_lmqt == time_after(ent->timer.expires, + lmqt) && + ent->src_query_rexmit_cnt > 0) + lmqt_srcs++; + } + + if (!lmqt_srcs) + return NULL; + igmp_hdr_size += lmqt_srcs * sizeof(__be32); + } + } + + pkt_size = sizeof(*eth) + sizeof(*iph) + 4 + igmp_hdr_size; + if ((p && pkt_size > p->dev->mtu) || + pkt_size > brmctx->br->dev->mtu) + return NULL; + + skb = netdev_alloc_skb_ip_align(brmctx->br->dev, pkt_size); + if (!skb) + goto out; + + __br_multicast_query_handle_vlan(brmctx, pmctx, skb); + skb->protocol = htons(ETH_P_IP); + + skb_reset_mac_header(skb); + eth = eth_hdr(skb); + + ether_addr_copy(eth->h_source, brmctx->br->dev->dev_addr); + ip_eth_mc_map(ip_dst, eth->h_dest); + eth->h_proto = htons(ETH_P_IP); + skb_put(skb, sizeof(*eth)); + + skb_set_network_header(skb, skb->len); + iph = ip_hdr(skb); + iph->tot_len = htons(pkt_size - sizeof(*eth)); + + iph->version = 4; + iph->ihl = 6; + iph->tos = 0xc0; + iph->id = 0; + iph->frag_off = htons(IP_DF); + iph->ttl = 1; + iph->protocol = IPPROTO_IGMP; + iph->saddr = br_opt_get(brmctx->br, BROPT_MULTICAST_QUERY_USE_IFADDR) ? + inet_select_addr(brmctx->br->dev, 0, RT_SCOPE_LINK) : 0; + iph->daddr = ip_dst; + ((u8 *)&iph[1])[0] = IPOPT_RA; + ((u8 *)&iph[1])[1] = 4; + ((u8 *)&iph[1])[2] = 0; + ((u8 *)&iph[1])[3] = 0; + ip_send_check(iph); + skb_put(skb, 24); + + skb_set_transport_header(skb, skb->len); + *igmp_type = IGMP_HOST_MEMBERSHIP_QUERY; + + switch (brmctx->multicast_igmp_version) { + case 2: + ih = igmp_hdr(skb); + ih->type = IGMP_HOST_MEMBERSHIP_QUERY; + ih->code = (group ? brmctx->multicast_last_member_interval : + brmctx->multicast_query_response_interval) / + (HZ / IGMP_TIMER_SCALE); + ih->group = group; + ih->csum = 0; + csum = &ih->csum; + csum_start = (void *)ih; + break; + case 3: + ihv3 = igmpv3_query_hdr(skb); + ihv3->type = IGMP_HOST_MEMBERSHIP_QUERY; + ihv3->code = (group ? 
brmctx->multicast_last_member_interval : + brmctx->multicast_query_response_interval) / + (HZ / IGMP_TIMER_SCALE); + ihv3->group = group; + ihv3->qqic = brmctx->multicast_query_interval / HZ; + ihv3->nsrcs = htons(lmqt_srcs); + ihv3->resv = 0; + ihv3->suppress = sflag; + ihv3->qrv = 2; + ihv3->csum = 0; + csum = &ihv3->csum; + csum_start = (void *)ihv3; + if (!pg || !with_srcs) + break; + + lmqt_srcs = 0; + hlist_for_each_entry(ent, &pg->src_list, node) { + if (over_lmqt == time_after(ent->timer.expires, + lmqt) && + ent->src_query_rexmit_cnt > 0) { + ihv3->srcs[lmqt_srcs++] = ent->addr.src.ip4; + ent->src_query_rexmit_cnt--; + if (need_rexmit && ent->src_query_rexmit_cnt) + *need_rexmit = true; + } + } + if (WARN_ON(lmqt_srcs != ntohs(ihv3->nsrcs))) { + kfree_skb(skb); + return NULL; + } + break; + } + + if (WARN_ON(!csum || !csum_start)) { + kfree_skb(skb); + return NULL; + } + + *csum = ip_compute_csum(csum_start, igmp_hdr_size); + skb_put(skb, igmp_hdr_size); + __skb_pull(skb, sizeof(*eth)); + +out: + return skb; +} + +#if IS_ENABLED(CONFIG_IPV6) +static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, + const struct in6_addr *ip6_dst, + const struct in6_addr *group, + bool with_srcs, bool over_llqt, + u8 sflag, u8 *igmp_type, + bool *need_rexmit) +{ + struct net_bridge_port *p = pg ? pg->key.port : NULL; + struct net_bridge_group_src *ent; + size_t pkt_size, mld_hdr_size; + unsigned long now = jiffies; + struct mld2_query *mld2q; + void *csum_start = NULL; + unsigned long interval; + __sum16 *csum = NULL; + struct ipv6hdr *ip6h; + struct mld_msg *mldq; + struct sk_buff *skb; + unsigned long llqt; + struct ethhdr *eth; + u16 llqt_srcs = 0; + u8 *hopopt; + + mld_hdr_size = sizeof(*mldq); + if (brmctx->multicast_mld_version == 2) { + mld_hdr_size = sizeof(*mld2q); + if (pg && with_srcs) { + llqt = now + (brmctx->multicast_last_member_interval * + brmctx->multicast_last_member_count); + hlist_for_each_entry(ent, &pg->src_list, node) { + if (over_llqt == time_after(ent->timer.expires, + llqt) && + ent->src_query_rexmit_cnt > 0) + llqt_srcs++; + } + + if (!llqt_srcs) + return NULL; + mld_hdr_size += llqt_srcs * sizeof(struct in6_addr); + } + } + + pkt_size = sizeof(*eth) + sizeof(*ip6h) + 8 + mld_hdr_size; + if ((p && pkt_size > p->dev->mtu) || + pkt_size > brmctx->br->dev->mtu) + return NULL; + + skb = netdev_alloc_skb_ip_align(brmctx->br->dev, pkt_size); + if (!skb) + goto out; + + __br_multicast_query_handle_vlan(brmctx, pmctx, skb); + skb->protocol = htons(ETH_P_IPV6); + + /* Ethernet header */ + skb_reset_mac_header(skb); + eth = eth_hdr(skb); + + ether_addr_copy(eth->h_source, brmctx->br->dev->dev_addr); + eth->h_proto = htons(ETH_P_IPV6); + skb_put(skb, sizeof(*eth)); + + /* IPv6 header + HbH option */ + skb_set_network_header(skb, skb->len); + ip6h = ipv6_hdr(skb); + + *(__force __be32 *)ip6h = htonl(0x60000000); + ip6h->payload_len = htons(8 + mld_hdr_size); + ip6h->nexthdr = IPPROTO_HOPOPTS; + ip6h->hop_limit = 1; + ip6h->daddr = *ip6_dst; + if (ipv6_dev_get_saddr(dev_net(brmctx->br->dev), brmctx->br->dev, + &ip6h->daddr, 0, &ip6h->saddr)) { + kfree_skb(skb); + br_opt_toggle(brmctx->br, BROPT_HAS_IPV6_ADDR, false); + return NULL; + } + + br_opt_toggle(brmctx->br, BROPT_HAS_IPV6_ADDR, true); + ipv6_eth_mc_map(&ip6h->daddr, eth->h_dest); + + hopopt = (u8 *)(ip6h + 1); + hopopt[0] = IPPROTO_ICMPV6; /* next hdr */ + hopopt[1] = 0; /* length of HbH */ + hopopt[2] = 
IPV6_TLV_ROUTERALERT; /* Router Alert */ + hopopt[3] = 2; /* Length of RA Option */ + hopopt[4] = 0; /* Type = 0x0000 (MLD) */ + hopopt[5] = 0; + hopopt[6] = IPV6_TLV_PAD1; /* Pad1 */ + hopopt[7] = IPV6_TLV_PAD1; /* Pad1 */ + + skb_put(skb, sizeof(*ip6h) + 8); + + /* ICMPv6 */ + skb_set_transport_header(skb, skb->len); + interval = ipv6_addr_any(group) ? + brmctx->multicast_query_response_interval : + brmctx->multicast_last_member_interval; + *igmp_type = ICMPV6_MGM_QUERY; + switch (brmctx->multicast_mld_version) { + case 1: + mldq = (struct mld_msg *)icmp6_hdr(skb); + mldq->mld_type = ICMPV6_MGM_QUERY; + mldq->mld_code = 0; + mldq->mld_cksum = 0; + mldq->mld_maxdelay = htons((u16)jiffies_to_msecs(interval)); + mldq->mld_reserved = 0; + mldq->mld_mca = *group; + csum = &mldq->mld_cksum; + csum_start = (void *)mldq; + break; + case 2: + mld2q = (struct mld2_query *)icmp6_hdr(skb); + mld2q->mld2q_mrc = htons((u16)jiffies_to_msecs(interval)); + mld2q->mld2q_type = ICMPV6_MGM_QUERY; + mld2q->mld2q_code = 0; + mld2q->mld2q_cksum = 0; + mld2q->mld2q_resv1 = 0; + mld2q->mld2q_resv2 = 0; + mld2q->mld2q_suppress = sflag; + mld2q->mld2q_qrv = 2; + mld2q->mld2q_nsrcs = htons(llqt_srcs); + mld2q->mld2q_qqic = brmctx->multicast_query_interval / HZ; + mld2q->mld2q_mca = *group; + csum = &mld2q->mld2q_cksum; + csum_start = (void *)mld2q; + if (!pg || !with_srcs) + break; + + llqt_srcs = 0; + hlist_for_each_entry(ent, &pg->src_list, node) { + if (over_llqt == time_after(ent->timer.expires, + llqt) && + ent->src_query_rexmit_cnt > 0) { + mld2q->mld2q_srcs[llqt_srcs++] = ent->addr.src.ip6; + ent->src_query_rexmit_cnt--; + if (need_rexmit && ent->src_query_rexmit_cnt) + *need_rexmit = true; + } + } + if (WARN_ON(llqt_srcs != ntohs(mld2q->mld2q_nsrcs))) { + kfree_skb(skb); + return NULL; + } + break; + } + + if (WARN_ON(!csum || !csum_start)) { + kfree_skb(skb); + return NULL; + } + + *csum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, mld_hdr_size, + IPPROTO_ICMPV6, + csum_partial(csum_start, mld_hdr_size, 0)); + skb_put(skb, mld_hdr_size); + __skb_pull(skb, sizeof(*eth)); + +out: + return skb; +} +#endif + +static struct sk_buff *br_multicast_alloc_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, + struct br_ip *ip_dst, + struct br_ip *group, + bool with_srcs, bool over_lmqt, + u8 sflag, u8 *igmp_type, + bool *need_rexmit) +{ + __be32 ip4_dst; + + switch (group->proto) { + case htons(ETH_P_IP): + ip4_dst = ip_dst ? 
ip_dst->dst.ip4 : htonl(INADDR_ALLHOSTS_GROUP); + return br_ip4_multicast_alloc_query(brmctx, pmctx, pg, + ip4_dst, group->dst.ip4, + with_srcs, over_lmqt, + sflag, igmp_type, + need_rexmit); +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): { + struct in6_addr ip6_dst; + + if (ip_dst) + ip6_dst = ip_dst->dst.ip6; + else + ipv6_addr_set(&ip6_dst, htonl(0xff020000), 0, 0, + htonl(1)); + + return br_ip6_multicast_alloc_query(brmctx, pmctx, pg, + &ip6_dst, &group->dst.ip6, + with_srcs, over_lmqt, + sflag, igmp_type, + need_rexmit); + } +#endif + } + return NULL; +} + +struct net_bridge_mdb_entry *br_multicast_new_group(struct net_bridge *br, + struct br_ip *group) +{ + struct net_bridge_mdb_entry *mp; + int err; + + mp = br_mdb_ip_get(br, group); + if (mp) + return mp; + + if (atomic_read(&br->mdb_hash_tbl.nelems) >= br->hash_max) { + br_mc_disabled_update(br->dev, false, NULL); + br_opt_toggle(br, BROPT_MULTICAST_ENABLED, false); + return ERR_PTR(-E2BIG); + } + + mp = kzalloc(sizeof(*mp), GFP_ATOMIC); + if (unlikely(!mp)) + return ERR_PTR(-ENOMEM); + + mp->br = br; + mp->addr = *group; + mp->mcast_gc.destroy = br_multicast_destroy_mdb_entry; + timer_setup(&mp->timer, br_multicast_group_expired, 0); + err = rhashtable_lookup_insert_fast(&br->mdb_hash_tbl, &mp->rhnode, + br_mdb_rht_params); + if (err) { + kfree(mp); + mp = ERR_PTR(err); + } else { + hlist_add_head_rcu(&mp->mdb_node, &br->mdb_list); + } + + return mp; +} + +static void br_multicast_group_src_expired(struct timer_list *t) +{ + struct net_bridge_group_src *src = from_timer(src, t, timer); + struct net_bridge_port_group *pg; + struct net_bridge *br = src->br; + + spin_lock(&br->multicast_lock); + if (hlist_unhashed(&src->node) || !netif_running(br->dev) || + timer_pending(&src->timer)) + goto out; + + pg = src->pg; + if (pg->filter_mode == MCAST_INCLUDE) { + br_multicast_del_group_src(src, false); + if (!hlist_empty(&pg->src_list)) + goto out; + br_multicast_find_del_pg(br, pg); + } else { + br_multicast_fwd_src_handle(src); + } + +out: + spin_unlock(&br->multicast_lock); +} + +struct net_bridge_group_src * +br_multicast_find_group_src(struct net_bridge_port_group *pg, struct br_ip *ip) +{ + struct net_bridge_group_src *ent; + + switch (ip->proto) { + case htons(ETH_P_IP): + hlist_for_each_entry(ent, &pg->src_list, node) + if (ip->src.ip4 == ent->addr.src.ip4) + return ent; + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + hlist_for_each_entry(ent, &pg->src_list, node) + if (!ipv6_addr_cmp(&ent->addr.src.ip6, &ip->src.ip6)) + return ent; + break; +#endif + } + + return NULL; +} + +static struct net_bridge_group_src * +br_multicast_new_group_src(struct net_bridge_port_group *pg, struct br_ip *src_ip) +{ + struct net_bridge_group_src *grp_src; + + if (unlikely(pg->src_ents >= PG_SRC_ENT_LIMIT)) + return NULL; + + switch (src_ip->proto) { + case htons(ETH_P_IP): + if (ipv4_is_zeronet(src_ip->src.ip4) || + ipv4_is_multicast(src_ip->src.ip4)) + return NULL; + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + if (ipv6_addr_any(&src_ip->src.ip6) || + ipv6_addr_is_multicast(&src_ip->src.ip6)) + return NULL; + break; +#endif + } + + grp_src = kzalloc(sizeof(*grp_src), GFP_ATOMIC); + if (unlikely(!grp_src)) + return NULL; + + grp_src->pg = pg; + grp_src->br = pg->key.port->br; + grp_src->addr = *src_ip; + grp_src->mcast_gc.destroy = br_multicast_destroy_group_src; + timer_setup(&grp_src->timer, br_multicast_group_src_expired, 0); + + hlist_add_head_rcu(&grp_src->node, &pg->src_list); + pg->src_ents++; + + 
return grp_src; +} + +struct net_bridge_port_group *br_multicast_new_port_group( + struct net_bridge_port *port, + struct br_ip *group, + struct net_bridge_port_group __rcu *next, + unsigned char flags, + const unsigned char *src, + u8 filter_mode, + u8 rt_protocol) +{ + struct net_bridge_port_group *p; + + p = kzalloc(sizeof(*p), GFP_ATOMIC); + if (unlikely(!p)) + return NULL; + + p->key.addr = *group; + p->key.port = port; + p->flags = flags; + p->filter_mode = filter_mode; + p->rt_protocol = rt_protocol; + p->eht_host_tree = RB_ROOT; + p->eht_set_tree = RB_ROOT; + p->mcast_gc.destroy = br_multicast_destroy_port_group; + INIT_HLIST_HEAD(&p->src_list); + + if (!br_multicast_is_star_g(group) && + rhashtable_lookup_insert_fast(&port->br->sg_port_tbl, &p->rhnode, + br_sg_port_rht_params)) { + kfree(p); + return NULL; + } + + rcu_assign_pointer(p->next, next); + timer_setup(&p->timer, br_multicast_port_group_expired, 0); + timer_setup(&p->rexmit_timer, br_multicast_port_group_rexmit, 0); + hlist_add_head(&p->mglist, &port->mglist); + + if (src) + memcpy(p->eth_addr, src, ETH_ALEN); + else + eth_broadcast_addr(p->eth_addr); + + return p; +} + +void br_multicast_host_join(const struct net_bridge_mcast *brmctx, + struct net_bridge_mdb_entry *mp, bool notify) +{ + if (!mp->host_joined) { + mp->host_joined = true; + if (br_multicast_is_star_g(&mp->addr)) + br_multicast_star_g_host_state(mp); + if (notify) + br_mdb_notify(mp->br->dev, mp, NULL, RTM_NEWMDB); + } + + if (br_group_is_l2(&mp->addr)) + return; + + mod_timer(&mp->timer, jiffies + brmctx->multicast_membership_interval); +} + +void br_multicast_host_leave(struct net_bridge_mdb_entry *mp, bool notify) +{ + if (!mp->host_joined) + return; + + mp->host_joined = false; + if (br_multicast_is_star_g(&mp->addr)) + br_multicast_star_g_host_state(mp); + if (notify) + br_mdb_notify(mp->br->dev, mp, NULL, RTM_DELMDB); +} + +static struct net_bridge_port_group * +__br_multicast_add_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct br_ip *group, + const unsigned char *src, + u8 filter_mode, + bool igmpv2_mldv1, + bool blocked) +{ + struct net_bridge_port_group __rcu **pp; + struct net_bridge_port_group *p = NULL; + struct net_bridge_mdb_entry *mp; + unsigned long now = jiffies; + + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto out; + + mp = br_multicast_new_group(brmctx->br, group); + if (IS_ERR(mp)) + return ERR_CAST(mp); + + if (!pmctx) { + br_multicast_host_join(brmctx, mp, true); + goto out; + } + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, brmctx->br)) != NULL; + pp = &p->next) { + if (br_port_group_equal(p, pmctx->port, src)) + goto found; + if ((unsigned long)p->key.port < (unsigned long)pmctx->port) + break; + } + + p = br_multicast_new_port_group(pmctx->port, group, *pp, 0, src, + filter_mode, RTPROT_KERNEL); + if (unlikely(!p)) { + p = ERR_PTR(-ENOMEM); + goto out; + } + rcu_assign_pointer(*pp, p); + if (blocked) + p->flags |= MDB_PG_FLAGS_BLOCKED; + br_mdb_notify(brmctx->br->dev, mp, p, RTM_NEWMDB); + +found: + if (igmpv2_mldv1) + mod_timer(&p->timer, + now + brmctx->multicast_membership_interval); + +out: + return p; +} + +static int br_multicast_add_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct br_ip *group, + const unsigned char *src, + u8 filter_mode, + bool igmpv2_mldv1) +{ + struct net_bridge_port_group *pg; + int err; + + spin_lock(&brmctx->br->multicast_lock); + pg = __br_multicast_add_group(brmctx, pmctx, group, src, filter_mode, + 
igmpv2_mldv1, false); + /* NULL is considered valid for host joined groups */ + err = PTR_ERR_OR_ZERO(pg); + spin_unlock(&brmctx->br->multicast_lock); + + return err; +} + +static int br_ip4_multicast_add_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + __be32 group, + __u16 vid, + const unsigned char *src, + bool igmpv2) +{ + struct br_ip br_group; + u8 filter_mode; + + if (ipv4_is_local_multicast(group)) + return 0; + + memset(&br_group, 0, sizeof(br_group)); + br_group.dst.ip4 = group; + br_group.proto = htons(ETH_P_IP); + br_group.vid = vid; + filter_mode = igmpv2 ? MCAST_EXCLUDE : MCAST_INCLUDE; + + return br_multicast_add_group(brmctx, pmctx, &br_group, src, + filter_mode, igmpv2); +} + +#if IS_ENABLED(CONFIG_IPV6) +static int br_ip6_multicast_add_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + const struct in6_addr *group, + __u16 vid, + const unsigned char *src, + bool mldv1) +{ + struct br_ip br_group; + u8 filter_mode; + + if (ipv6_addr_is_ll_all_nodes(group)) + return 0; + + memset(&br_group, 0, sizeof(br_group)); + br_group.dst.ip6 = *group; + br_group.proto = htons(ETH_P_IPV6); + br_group.vid = vid; + filter_mode = mldv1 ? MCAST_EXCLUDE : MCAST_INCLUDE; + + return br_multicast_add_group(brmctx, pmctx, &br_group, src, + filter_mode, mldv1); +} +#endif + +static bool br_multicast_rport_del(struct hlist_node *rlist) +{ + if (hlist_unhashed(rlist)) + return false; + + hlist_del_init_rcu(rlist); + return true; +} + +static bool br_ip4_multicast_rport_del(struct net_bridge_mcast_port *pmctx) +{ + return br_multicast_rport_del(&pmctx->ip4_rlist); +} + +static bool br_ip6_multicast_rport_del(struct net_bridge_mcast_port *pmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + return br_multicast_rport_del(&pmctx->ip6_rlist); +#else + return false; +#endif +} + +static void br_multicast_router_expired(struct net_bridge_mcast_port *pmctx, + struct timer_list *t, + struct hlist_node *rlist) +{ + struct net_bridge *br = pmctx->port->br; + bool del; + + spin_lock(&br->multicast_lock); + if (pmctx->multicast_router == MDB_RTR_TYPE_DISABLED || + pmctx->multicast_router == MDB_RTR_TYPE_PERM || + timer_pending(t)) + goto out; + + del = br_multicast_rport_del(rlist); + br_multicast_rport_del_notify(pmctx, del); +out: + spin_unlock(&br->multicast_lock); +} + +static void br_ip4_multicast_router_expired(struct timer_list *t) +{ + struct net_bridge_mcast_port *pmctx = from_timer(pmctx, t, + ip4_mc_router_timer); + + br_multicast_router_expired(pmctx, t, &pmctx->ip4_rlist); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_router_expired(struct timer_list *t) +{ + struct net_bridge_mcast_port *pmctx = from_timer(pmctx, t, + ip6_mc_router_timer); + + br_multicast_router_expired(pmctx, t, &pmctx->ip6_rlist); +} +#endif + +static void br_mc_router_state_change(struct net_bridge *p, + bool is_mc_router) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_MROUTER, + .flags = SWITCHDEV_F_DEFER, + .u.mrouter = is_mc_router, + }; + + switchdev_port_attr_set(p->dev, &attr, NULL); +} + +static void br_multicast_local_router_expired(struct net_bridge_mcast *brmctx, + struct timer_list *timer) +{ + spin_lock(&brmctx->br->multicast_lock); + if (brmctx->multicast_router == MDB_RTR_TYPE_DISABLED || + brmctx->multicast_router == MDB_RTR_TYPE_PERM || + br_ip4_multicast_is_router(brmctx) || + br_ip6_multicast_is_router(brmctx)) + goto out; + + br_mc_router_state_change(brmctx->br, false); +out: + 
spin_unlock(&brmctx->br->multicast_lock); +} + +static void br_ip4_multicast_local_router_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip4_mc_router_timer); + + br_multicast_local_router_expired(brmctx, t); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_local_router_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip6_mc_router_timer); + + br_multicast_local_router_expired(brmctx, t); +} +#endif + +static void br_multicast_querier_expired(struct net_bridge_mcast *brmctx, + struct bridge_mcast_own_query *query) +{ + spin_lock(&brmctx->br->multicast_lock); + if (!netif_running(brmctx->br->dev) || + br_multicast_ctx_vlan_global_disabled(brmctx) || + !br_opt_get(brmctx->br, BROPT_MULTICAST_ENABLED)) + goto out; + + br_multicast_start_querier(brmctx, query); + +out: + spin_unlock(&brmctx->br->multicast_lock); +} + +static void br_ip4_multicast_querier_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip4_other_query.timer); + + br_multicast_querier_expired(brmctx, &brmctx->ip4_own_query); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_querier_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip6_other_query.timer); + + br_multicast_querier_expired(brmctx, &brmctx->ip6_own_query); +} +#endif + +static void br_multicast_select_own_querier(struct net_bridge_mcast *brmctx, + struct br_ip *ip, + struct sk_buff *skb) +{ + if (ip->proto == htons(ETH_P_IP)) + brmctx->ip4_querier.addr.src.ip4 = ip_hdr(skb)->saddr; +#if IS_ENABLED(CONFIG_IPV6) + else + brmctx->ip6_querier.addr.src.ip6 = ipv6_hdr(skb)->saddr; +#endif +} + +static void __br_multicast_send_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, + struct br_ip *ip_dst, + struct br_ip *group, + bool with_srcs, + u8 sflag, + bool *need_rexmit) +{ + bool over_lmqt = !!sflag; + struct sk_buff *skb; + u8 igmp_type; + + if (!br_multicast_ctx_should_use(brmctx, pmctx) || + !br_multicast_ctx_matches_vlan_snooping(brmctx)) + return; + +again_under_lmqt: + skb = br_multicast_alloc_query(brmctx, pmctx, pg, ip_dst, group, + with_srcs, over_lmqt, sflag, &igmp_type, + need_rexmit); + if (!skb) + return; + + if (pmctx) { + skb->dev = pmctx->port->dev; + br_multicast_count(brmctx->br, pmctx->port, skb, igmp_type, + BR_MCAST_DIR_TX); + NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_OUT, + dev_net(pmctx->port->dev), NULL, skb, NULL, skb->dev, + br_dev_queue_push_xmit); + + if (over_lmqt && with_srcs && sflag) { + over_lmqt = false; + goto again_under_lmqt; + } + } else { + br_multicast_select_own_querier(brmctx, group, skb); + br_multicast_count(brmctx->br, NULL, skb, igmp_type, + BR_MCAST_DIR_RX); + netif_rx(skb); + } +} + +static void br_multicast_read_querier(const struct bridge_mcast_querier *querier, + struct bridge_mcast_querier *dest) +{ + unsigned int seq; + + memset(dest, 0, sizeof(*dest)); + do { + seq = read_seqcount_begin(&querier->seq); + dest->port_ifidx = querier->port_ifidx; + memcpy(&dest->addr, &querier->addr, sizeof(struct br_ip)); + } while (read_seqcount_retry(&querier->seq, seq)); +} + +static void br_multicast_update_querier(struct net_bridge_mcast *brmctx, + struct bridge_mcast_querier *querier, + int ifindex, + struct br_ip *saddr) +{ + write_seqcount_begin(&querier->seq); + querier->port_ifidx = ifindex; + memcpy(&querier->addr, saddr, sizeof(*saddr)); + 
write_seqcount_end(&querier->seq); +} + +static void br_multicast_send_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct bridge_mcast_own_query *own_query) +{ + struct bridge_mcast_other_query *other_query = NULL; + struct bridge_mcast_querier *querier; + struct br_ip br_group; + unsigned long time; + + if (!br_multicast_ctx_should_use(brmctx, pmctx) || + !br_opt_get(brmctx->br, BROPT_MULTICAST_ENABLED) || + !brmctx->multicast_querier) + return; + + memset(&br_group.dst, 0, sizeof(br_group.dst)); + + if (pmctx ? (own_query == &pmctx->ip4_own_query) : + (own_query == &brmctx->ip4_own_query)) { + querier = &brmctx->ip4_querier; + other_query = &brmctx->ip4_other_query; + br_group.proto = htons(ETH_P_IP); +#if IS_ENABLED(CONFIG_IPV6) + } else { + querier = &brmctx->ip6_querier; + other_query = &brmctx->ip6_other_query; + br_group.proto = htons(ETH_P_IPV6); +#endif + } + + if (!other_query || timer_pending(&other_query->timer)) + return; + + /* we're about to select ourselves as querier */ + if (!pmctx && querier->port_ifidx) { + struct br_ip zeroip = {}; + + br_multicast_update_querier(brmctx, querier, 0, &zeroip); + } + + __br_multicast_send_query(brmctx, pmctx, NULL, NULL, &br_group, false, + 0, NULL); + + time = jiffies; + time += own_query->startup_sent < brmctx->multicast_startup_query_count ? + brmctx->multicast_startup_query_interval : + brmctx->multicast_query_interval; + mod_timer(&own_query->timer, time); +} + +static void +br_multicast_port_query_expired(struct net_bridge_mcast_port *pmctx, + struct bridge_mcast_own_query *query) +{ + struct net_bridge *br = pmctx->port->br; + struct net_bridge_mcast *brmctx; + + spin_lock(&br->multicast_lock); + if (br_multicast_port_ctx_state_stopped(pmctx)) + goto out; + + brmctx = br_multicast_port_ctx_get_global(pmctx); + if (query->startup_sent < brmctx->multicast_startup_query_count) + query->startup_sent++; + + br_multicast_send_query(brmctx, pmctx, query); + +out: + spin_unlock(&br->multicast_lock); +} + +static void br_ip4_multicast_port_query_expired(struct timer_list *t) +{ + struct net_bridge_mcast_port *pmctx = from_timer(pmctx, t, + ip4_own_query.timer); + + br_multicast_port_query_expired(pmctx, &pmctx->ip4_own_query); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_port_query_expired(struct timer_list *t) +{ + struct net_bridge_mcast_port *pmctx = from_timer(pmctx, t, + ip6_own_query.timer); + + br_multicast_port_query_expired(pmctx, &pmctx->ip6_own_query); +} +#endif + +static void br_multicast_port_group_rexmit(struct timer_list *t) +{ + struct net_bridge_port_group *pg = from_timer(pg, t, rexmit_timer); + struct bridge_mcast_other_query *other_query = NULL; + struct net_bridge *br = pg->key.port->br; + struct net_bridge_mcast_port *pmctx; + struct net_bridge_mcast *brmctx; + bool need_rexmit = false; + + spin_lock(&br->multicast_lock); + if (!netif_running(br->dev) || hlist_unhashed(&pg->mglist) || + !br_opt_get(br, BROPT_MULTICAST_ENABLED)) + goto out; + + pmctx = br_multicast_pg_to_port_ctx(pg); + if (!pmctx) + goto out; + brmctx = br_multicast_port_ctx_get_global(pmctx); + if (!brmctx->multicast_querier) + goto out; + + if (pg->key.addr.proto == htons(ETH_P_IP)) + other_query = &brmctx->ip4_other_query; +#if IS_ENABLED(CONFIG_IPV6) + else + other_query = &brmctx->ip6_other_query; +#endif + + if (!other_query || timer_pending(&other_query->timer)) + goto out; + + if (pg->grp_query_rexmit_cnt) { + pg->grp_query_rexmit_cnt--; + __br_multicast_send_query(brmctx, pmctx, pg, 
&pg->key.addr, + &pg->key.addr, false, 1, NULL); + } + __br_multicast_send_query(brmctx, pmctx, pg, &pg->key.addr, + &pg->key.addr, true, 0, &need_rexmit); + + if (pg->grp_query_rexmit_cnt || need_rexmit) + mod_timer(&pg->rexmit_timer, jiffies + + brmctx->multicast_last_member_interval); +out: + spin_unlock(&br->multicast_lock); +} + +static int br_mc_disabled_update(struct net_device *dev, bool value, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .orig_dev = dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_MC_DISABLED, + .flags = SWITCHDEV_F_DEFER, + .u.mc_disabled = !value, + }; + + return switchdev_port_attr_set(dev, &attr, extack); +} + +void br_multicast_port_ctx_init(struct net_bridge_port *port, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast_port *pmctx) +{ + pmctx->port = port; + pmctx->vlan = vlan; + pmctx->multicast_router = MDB_RTR_TYPE_TEMP_QUERY; + timer_setup(&pmctx->ip4_mc_router_timer, + br_ip4_multicast_router_expired, 0); + timer_setup(&pmctx->ip4_own_query.timer, + br_ip4_multicast_port_query_expired, 0); +#if IS_ENABLED(CONFIG_IPV6) + timer_setup(&pmctx->ip6_mc_router_timer, + br_ip6_multicast_router_expired, 0); + timer_setup(&pmctx->ip6_own_query.timer, + br_ip6_multicast_port_query_expired, 0); +#endif +} + +void br_multicast_port_ctx_deinit(struct net_bridge_mcast_port *pmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + del_timer_sync(&pmctx->ip6_mc_router_timer); +#endif + del_timer_sync(&pmctx->ip4_mc_router_timer); +} + +int br_multicast_add_port(struct net_bridge_port *port) +{ + int err; + + port->multicast_eht_hosts_limit = BR_MCAST_DEFAULT_EHT_HOSTS_LIMIT; + br_multicast_port_ctx_init(port, NULL, &port->multicast_ctx); + + err = br_mc_disabled_update(port->dev, + br_opt_get(port->br, + BROPT_MULTICAST_ENABLED), + NULL); + if (err && err != -EOPNOTSUPP) + return err; + + port->mcast_stats = netdev_alloc_pcpu_stats(struct bridge_mcast_stats); + if (!port->mcast_stats) + return -ENOMEM; + + return 0; +} + +void br_multicast_del_port(struct net_bridge_port *port) +{ + struct net_bridge *br = port->br; + struct net_bridge_port_group *pg; + HLIST_HEAD(deleted_head); + struct hlist_node *n; + + /* Take care of the remaining groups, only perm ones should be left */ + spin_lock_bh(&br->multicast_lock); + hlist_for_each_entry_safe(pg, n, &port->mglist, mglist) + br_multicast_find_del_pg(br, pg); + hlist_move_list(&br->mcast_gc_list, &deleted_head); + spin_unlock_bh(&br->multicast_lock); + br_multicast_gc(&deleted_head); + br_multicast_port_ctx_deinit(&port->multicast_ctx); + free_percpu(port->mcast_stats); +} + +static void br_multicast_enable(struct bridge_mcast_own_query *query) +{ + query->startup_sent = 0; + + if (try_to_del_timer_sync(&query->timer) >= 0 || + del_timer(&query->timer)) + mod_timer(&query->timer, jiffies); +} + +static void __br_multicast_enable_port_ctx(struct net_bridge_mcast_port *pmctx) +{ + struct net_bridge *br = pmctx->port->br; + struct net_bridge_mcast *brmctx; + + brmctx = br_multicast_port_ctx_get_global(pmctx); + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED) || + !netif_running(br->dev)) + return; + + br_multicast_enable(&pmctx->ip4_own_query); +#if IS_ENABLED(CONFIG_IPV6) + br_multicast_enable(&pmctx->ip6_own_query); +#endif + if (pmctx->multicast_router == MDB_RTR_TYPE_PERM) { + br_ip4_multicast_add_router(brmctx, pmctx); + br_ip6_multicast_add_router(brmctx, pmctx); + } +} + +void br_multicast_enable_port(struct net_bridge_port *port) +{ + struct net_bridge *br = port->br; + + spin_lock_bh(&br->multicast_lock); + 
__br_multicast_enable_port_ctx(&port->multicast_ctx); + spin_unlock_bh(&br->multicast_lock); +} + +static void __br_multicast_disable_port_ctx(struct net_bridge_mcast_port *pmctx) +{ + struct net_bridge_port_group *pg; + struct hlist_node *n; + bool del = false; + + hlist_for_each_entry_safe(pg, n, &pmctx->port->mglist, mglist) + if (!(pg->flags & MDB_PG_FLAGS_PERMANENT) && + (!br_multicast_port_ctx_is_vlan(pmctx) || + pg->key.addr.vid == pmctx->vlan->vid)) + br_multicast_find_del_pg(pmctx->port->br, pg); + + del |= br_ip4_multicast_rport_del(pmctx); + del_timer(&pmctx->ip4_mc_router_timer); + del_timer(&pmctx->ip4_own_query.timer); + del |= br_ip6_multicast_rport_del(pmctx); +#if IS_ENABLED(CONFIG_IPV6) + del_timer(&pmctx->ip6_mc_router_timer); + del_timer(&pmctx->ip6_own_query.timer); +#endif + br_multicast_rport_del_notify(pmctx, del); +} + +void br_multicast_disable_port(struct net_bridge_port *port) +{ + spin_lock_bh(&port->br->multicast_lock); + __br_multicast_disable_port_ctx(&port->multicast_ctx); + spin_unlock_bh(&port->br->multicast_lock); +} + +static int __grp_src_delete_marked(struct net_bridge_port_group *pg) +{ + struct net_bridge_group_src *ent; + struct hlist_node *tmp; + int deleted = 0; + + hlist_for_each_entry_safe(ent, tmp, &pg->src_list, node) + if (ent->flags & BR_SGRP_F_DELETE) { + br_multicast_del_group_src(ent, false); + deleted++; + } + + return deleted; +} + +static void __grp_src_mod_timer(struct net_bridge_group_src *src, + unsigned long expires) +{ + mod_timer(&src->timer, expires); + br_multicast_fwd_src_handle(src); +} + +static void __grp_src_query_marked_and_rexmit(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg) +{ + struct bridge_mcast_other_query *other_query = NULL; + u32 lmqc = brmctx->multicast_last_member_count; + unsigned long lmqt, lmi, now = jiffies; + struct net_bridge_group_src *ent; + + if (!netif_running(brmctx->br->dev) || + !br_opt_get(brmctx->br, BROPT_MULTICAST_ENABLED)) + return; + + if (pg->key.addr.proto == htons(ETH_P_IP)) + other_query = &brmctx->ip4_other_query; +#if IS_ENABLED(CONFIG_IPV6) + else + other_query = &brmctx->ip6_other_query; +#endif + + lmqt = now + br_multicast_lmqt(brmctx); + hlist_for_each_entry(ent, &pg->src_list, node) { + if (ent->flags & BR_SGRP_F_SEND) { + ent->flags &= ~BR_SGRP_F_SEND; + if (ent->timer.expires > lmqt) { + if (brmctx->multicast_querier && + other_query && + !timer_pending(&other_query->timer)) + ent->src_query_rexmit_cnt = lmqc; + __grp_src_mod_timer(ent, lmqt); + } + } + } + + if (!brmctx->multicast_querier || + !other_query || timer_pending(&other_query->timer)) + return; + + __br_multicast_send_query(brmctx, pmctx, pg, &pg->key.addr, + &pg->key.addr, true, 1, NULL); + + lmi = now + brmctx->multicast_last_member_interval; + if (!timer_pending(&pg->rexmit_timer) || + time_after(pg->rexmit_timer.expires, lmi)) + mod_timer(&pg->rexmit_timer, lmi); +} + +static void __grp_send_query_and_rexmit(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg) +{ + struct bridge_mcast_other_query *other_query = NULL; + unsigned long now = jiffies, lmi; + + if (!netif_running(brmctx->br->dev) || + !br_opt_get(brmctx->br, BROPT_MULTICAST_ENABLED)) + return; + + if (pg->key.addr.proto == htons(ETH_P_IP)) + other_query = &brmctx->ip4_other_query; +#if IS_ENABLED(CONFIG_IPV6) + else + other_query = &brmctx->ip6_other_query; +#endif + + if (brmctx->multicast_querier && + other_query && 
!timer_pending(&other_query->timer)) { + lmi = now + brmctx->multicast_last_member_interval; + pg->grp_query_rexmit_cnt = brmctx->multicast_last_member_count - 1; + __br_multicast_send_query(brmctx, pmctx, pg, &pg->key.addr, + &pg->key.addr, false, 0, NULL); + if (!timer_pending(&pg->rexmit_timer) || + time_after(pg->rexmit_timer.expires, lmi)) + mod_timer(&pg->rexmit_timer, lmi); + } + + if (pg->filter_mode == MCAST_EXCLUDE && + (!timer_pending(&pg->timer) || + time_after(pg->timer.expires, now + br_multicast_lmqt(brmctx)))) + mod_timer(&pg->timer, now + br_multicast_lmqt(brmctx)); +} + +/* State Msg type New state Actions + * INCLUDE (A) IS_IN (B) INCLUDE (A+B) (B)=GMI + * INCLUDE (A) ALLOW (B) INCLUDE (A+B) (B)=GMI + * EXCLUDE (X,Y) ALLOW (A) EXCLUDE (X+A,Y-A) (A)=GMI + */ +static bool br_multicast_isinc_allow(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + struct net_bridge_group_src *ent; + unsigned long now = jiffies; + bool changed = false; + struct br_ip src_ip; + u32 src_idx; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (!ent) { + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) + changed = true; + } + + if (ent) + __grp_src_mod_timer(ent, now + br_multicast_gmi(brmctx)); + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + return changed; +} + +/* State Msg type New state Actions + * INCLUDE (A) IS_EX (B) EXCLUDE (A*B,B-A) (B-A)=0 + * Delete (A-B) + * Group Timer=GMI + */ +static void __grp_src_isexc_incl(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + struct net_bridge_group_src *ent; + struct br_ip src_ip; + u32 src_idx; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags |= BR_SGRP_F_DELETE; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) + ent->flags &= ~BR_SGRP_F_DELETE; + else + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) + br_multicast_fwd_src_handle(ent); + } + + br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type); + + __grp_src_delete_marked(pg); +} + +/* State Msg type New state Actions + * EXCLUDE (X,Y) IS_EX (A) EXCLUDE (A-Y,Y*A) (A-X-Y)=GMI + * Delete (X-A) + * Delete (Y-A) + * Group Timer=GMI + */ +static bool __grp_src_isexc_excl(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + struct net_bridge_group_src *ent; + unsigned long now = jiffies; + bool changed = false; + struct br_ip src_ip; + u32 src_idx; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags |= BR_SGRP_F_DELETE; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + ent->flags &= ~BR_SGRP_F_DELETE; + } else { + ent = br_multicast_new_group_src(pg, &src_ip); + if 
(ent) { + __grp_src_mod_timer(ent, + now + br_multicast_gmi(brmctx)); + changed = true; + } + } + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (__grp_src_delete_marked(pg)) + changed = true; + + return changed; +} + +static bool br_multicast_isexc(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + bool changed = false; + + switch (pg->filter_mode) { + case MCAST_INCLUDE: + __grp_src_isexc_incl(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type); + br_multicast_star_g_handle_mode(pg, MCAST_EXCLUDE); + changed = true; + break; + case MCAST_EXCLUDE: + changed = __grp_src_isexc_excl(brmctx, pg, h_addr, srcs, nsrcs, + addr_size, grec_type); + break; + } + + pg->filter_mode = MCAST_EXCLUDE; + mod_timer(&pg->timer, jiffies + br_multicast_gmi(brmctx)); + + return changed; +} + +/* State Msg type New state Actions + * INCLUDE (A) TO_IN (B) INCLUDE (A+B) (B)=GMI + * Send Q(G,A-B) + */ +static bool __grp_src_toin_incl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + u32 src_idx, to_send = pg->src_ents; + struct net_bridge_group_src *ent; + unsigned long now = jiffies; + bool changed = false; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags |= BR_SGRP_F_SEND; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + ent->flags &= ~BR_SGRP_F_SEND; + to_send--; + } else { + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) + changed = true; + } + if (ent) + __grp_src_mod_timer(ent, now + br_multicast_gmi(brmctx)); + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); + + return changed; +} + +/* State Msg type New state Actions + * EXCLUDE (X,Y) TO_IN (A) EXCLUDE (X+A,Y-A) (A)=GMI + * Send Q(G,X-A) + * Send Q(G) + */ +static bool __grp_src_toin_excl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + u32 src_idx, to_send = pg->src_ents; + struct net_bridge_group_src *ent; + unsigned long now = jiffies; + bool changed = false; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + if (timer_pending(&ent->timer)) + ent->flags |= BR_SGRP_F_SEND; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + if (timer_pending(&ent->timer)) { + ent->flags &= ~BR_SGRP_F_SEND; + to_send--; + } + } else { + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) + changed = true; + } + if (ent) + __grp_src_mod_timer(ent, now + br_multicast_gmi(brmctx)); + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); + + __grp_send_query_and_rexmit(brmctx, pmctx, pg); + 
+ return changed; +} + +static bool br_multicast_toin(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + bool changed = false; + + switch (pg->filter_mode) { + case MCAST_INCLUDE: + changed = __grp_src_toin_incl(brmctx, pmctx, pg, h_addr, srcs, + nsrcs, addr_size, grec_type); + break; + case MCAST_EXCLUDE: + changed = __grp_src_toin_excl(brmctx, pmctx, pg, h_addr, srcs, + nsrcs, addr_size, grec_type); + break; + } + + if (br_multicast_eht_should_del_pg(pg)) { + pg->flags |= MDB_PG_FLAGS_FAST_LEAVE; + br_multicast_find_del_pg(pg->key.port->br, pg); + /* a notification has already been sent and we shouldn't + * access pg after the delete so we have to return false + */ + changed = false; + } + + return changed; +} + +/* State Msg type New state Actions + * INCLUDE (A) TO_EX (B) EXCLUDE (A*B,B-A) (B-A)=0 + * Delete (A-B) + * Send Q(G,A*B) + * Group Timer=GMI + */ +static void __grp_src_toex_incl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + struct net_bridge_group_src *ent; + u32 src_idx, to_send = 0; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags = (ent->flags & ~BR_SGRP_F_SEND) | BR_SGRP_F_DELETE; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + ent->flags = (ent->flags & ~BR_SGRP_F_DELETE) | + BR_SGRP_F_SEND; + to_send++; + } else { + ent = br_multicast_new_group_src(pg, &src_ip); + } + if (ent) + br_multicast_fwd_src_handle(ent); + } + + br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type); + + __grp_src_delete_marked(pg); + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); +} + +/* State Msg type New state Actions + * EXCLUDE (X,Y) TO_EX (A) EXCLUDE (A-Y,Y*A) (A-X-Y)=Group Timer + * Delete (X-A) + * Delete (Y-A) + * Send Q(G,A-Y) + * Group Timer=GMI + */ +static bool __grp_src_toex_excl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + struct net_bridge_group_src *ent; + u32 src_idx, to_send = 0; + bool changed = false; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags = (ent->flags & ~BR_SGRP_F_SEND) | BR_SGRP_F_DELETE; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + ent->flags &= ~BR_SGRP_F_DELETE; + } else { + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) { + __grp_src_mod_timer(ent, pg->timer.expires); + changed = true; + } + } + if (ent && timer_pending(&ent->timer)) { + ent->flags |= BR_SGRP_F_SEND; + to_send++; + } + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (__grp_src_delete_marked(pg)) + changed = true; + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); + + return changed; +} + +static bool br_multicast_toex(struct net_bridge_mcast 
*brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, + int grec_type) +{ + bool changed = false; + + switch (pg->filter_mode) { + case MCAST_INCLUDE: + __grp_src_toex_incl(brmctx, pmctx, pg, h_addr, srcs, nsrcs, + addr_size, grec_type); + br_multicast_star_g_handle_mode(pg, MCAST_EXCLUDE); + changed = true; + break; + case MCAST_EXCLUDE: + changed = __grp_src_toex_excl(brmctx, pmctx, pg, h_addr, srcs, + nsrcs, addr_size, grec_type); + break; + } + + pg->filter_mode = MCAST_EXCLUDE; + mod_timer(&pg->timer, jiffies + br_multicast_gmi(brmctx)); + + return changed; +} + +/* State Msg type New state Actions + * INCLUDE (A) BLOCK (B) INCLUDE (A) Send Q(G,A*B) + */ +static bool __grp_src_block_incl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, int grec_type) +{ + struct net_bridge_group_src *ent; + u32 src_idx, to_send = 0; + bool changed = false; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags &= ~BR_SGRP_F_SEND; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (ent) { + ent->flags |= BR_SGRP_F_SEND; + to_send++; + } + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); + + return changed; +} + +/* State Msg type New state Actions + * EXCLUDE (X,Y) BLOCK (A) EXCLUDE (X+(A-Y),Y) (A-X-Y)=Group Timer + * Send Q(G,A-Y) + */ +static bool __grp_src_block_excl(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, int grec_type) +{ + struct net_bridge_group_src *ent; + u32 src_idx, to_send = 0; + bool changed = false; + struct br_ip src_ip; + + hlist_for_each_entry(ent, &pg->src_list, node) + ent->flags &= ~BR_SGRP_F_SEND; + + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&src_ip.src, srcs + (src_idx * addr_size), addr_size); + ent = br_multicast_find_group_src(pg, &src_ip); + if (!ent) { + ent = br_multicast_new_group_src(pg, &src_ip); + if (ent) { + __grp_src_mod_timer(ent, pg->timer.expires); + changed = true; + } + } + if (ent && timer_pending(&ent->timer)) { + ent->flags |= BR_SGRP_F_SEND; + to_send++; + } + } + + if (br_multicast_eht_handle(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + grec_type)) + changed = true; + + if (to_send) + __grp_src_query_marked_and_rexmit(brmctx, pmctx, pg); + + return changed; +} + +static bool br_multicast_block(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct net_bridge_port_group *pg, void *h_addr, + void *srcs, u32 nsrcs, size_t addr_size, int grec_type) +{ + bool changed = false; + + switch (pg->filter_mode) { + case MCAST_INCLUDE: + changed = __grp_src_block_incl(brmctx, pmctx, pg, h_addr, srcs, + nsrcs, addr_size, grec_type); + break; + case MCAST_EXCLUDE: + changed = __grp_src_block_excl(brmctx, pmctx, pg, h_addr, srcs, + nsrcs, addr_size, grec_type); + break; + } + + if ((pg->filter_mode == MCAST_INCLUDE && hlist_empty(&pg->src_list)) || + 
br_multicast_eht_should_del_pg(pg)) { + if (br_multicast_eht_should_del_pg(pg)) + pg->flags |= MDB_PG_FLAGS_FAST_LEAVE; + br_multicast_find_del_pg(pg->key.port->br, pg); + /* a notification has already been sent and we shouldn't + * access pg after the delete so we have to return false + */ + changed = false; + } + + return changed; +} + +static struct net_bridge_port_group * +br_multicast_find_port(struct net_bridge_mdb_entry *mp, + struct net_bridge_port *p, + const unsigned char *src) +{ + struct net_bridge *br __maybe_unused = mp->br; + struct net_bridge_port_group *pg; + + for (pg = mlock_dereference(mp->ports, br); + pg; + pg = mlock_dereference(pg->next, br)) + if (br_port_group_equal(pg, p, src)) + return pg; + + return NULL; +} + +static int br_ip4_multicast_igmp3_report(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + bool igmpv2 = brmctx->multicast_igmp_version == 2; + struct net_bridge_mdb_entry *mdst; + struct net_bridge_port_group *pg; + const unsigned char *src; + struct igmpv3_report *ih; + struct igmpv3_grec *grec; + int i, len, num, type; + __be32 group, *h_addr; + bool changed = false; + int err = 0; + u16 nsrcs; + + ih = igmpv3_report_hdr(skb); + num = ntohs(ih->ngrec); + len = skb_transport_offset(skb) + sizeof(*ih); + + for (i = 0; i < num; i++) { + len += sizeof(*grec); + if (!ip_mc_may_pull(skb, len)) + return -EINVAL; + + grec = (void *)(skb->data + len - sizeof(*grec)); + group = grec->grec_mca; + type = grec->grec_type; + nsrcs = ntohs(grec->grec_nsrcs); + + len += nsrcs * 4; + if (!ip_mc_may_pull(skb, len)) + return -EINVAL; + + switch (type) { + case IGMPV3_MODE_IS_INCLUDE: + case IGMPV3_MODE_IS_EXCLUDE: + case IGMPV3_CHANGE_TO_INCLUDE: + case IGMPV3_CHANGE_TO_EXCLUDE: + case IGMPV3_ALLOW_NEW_SOURCES: + case IGMPV3_BLOCK_OLD_SOURCES: + break; + + default: + continue; + } + + src = eth_hdr(skb)->h_source; + if (nsrcs == 0 && + (type == IGMPV3_CHANGE_TO_INCLUDE || + type == IGMPV3_MODE_IS_INCLUDE)) { + if (!pmctx || igmpv2) { + br_ip4_multicast_leave_group(brmctx, pmctx, + group, vid, src); + continue; + } + } else { + err = br_ip4_multicast_add_group(brmctx, pmctx, group, + vid, src, igmpv2); + if (err) + break; + } + + if (!pmctx || igmpv2) + continue; + + spin_lock_bh(&brmctx->br->multicast_lock); + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto unlock_continue; + + mdst = br_mdb_ip4_get(brmctx->br, group, vid); + if (!mdst) + goto unlock_continue; + pg = br_multicast_find_port(mdst, pmctx->port, src); + if (!pg || (pg->flags & MDB_PG_FLAGS_PERMANENT)) + goto unlock_continue; + /* reload grec and host addr */ + grec = (void *)(skb->data + len - sizeof(*grec) - (nsrcs * 4)); + h_addr = &ip_hdr(skb)->saddr; + switch (type) { + case IGMPV3_ALLOW_NEW_SOURCES: + changed = br_multicast_isinc_allow(brmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + case IGMPV3_MODE_IS_INCLUDE: + changed = br_multicast_isinc_allow(brmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + case IGMPV3_MODE_IS_EXCLUDE: + changed = br_multicast_isexc(brmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + case IGMPV3_CHANGE_TO_INCLUDE: + changed = br_multicast_toin(brmctx, pmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + case IGMPV3_CHANGE_TO_EXCLUDE: + changed = br_multicast_toex(brmctx, pmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + case IGMPV3_BLOCK_OLD_SOURCES: + changed = 
br_multicast_block(brmctx, pmctx, pg, h_addr, + grec->grec_src, + nsrcs, sizeof(__be32), type); + break; + } + if (changed) + br_mdb_notify(brmctx->br->dev, mdst, pg, RTM_NEWMDB); +unlock_continue: + spin_unlock_bh(&brmctx->br->multicast_lock); + } + + return err; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int br_ip6_multicast_mld2_report(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + bool mldv1 = brmctx->multicast_mld_version == 1; + struct net_bridge_mdb_entry *mdst; + struct net_bridge_port_group *pg; + unsigned int nsrcs_offset; + struct mld2_report *mld2r; + const unsigned char *src; + struct in6_addr *h_addr; + struct mld2_grec *grec; + unsigned int grec_len; + bool changed = false; + int i, len, num; + int err = 0; + + if (!ipv6_mc_may_pull(skb, sizeof(*mld2r))) + return -EINVAL; + + mld2r = (struct mld2_report *)icmp6_hdr(skb); + num = ntohs(mld2r->mld2r_ngrec); + len = skb_transport_offset(skb) + sizeof(*mld2r); + + for (i = 0; i < num; i++) { + __be16 *_nsrcs, __nsrcs; + u16 nsrcs; + + nsrcs_offset = len + offsetof(struct mld2_grec, grec_nsrcs); + + if (skb_transport_offset(skb) + ipv6_transport_len(skb) < + nsrcs_offset + sizeof(__nsrcs)) + return -EINVAL; + + _nsrcs = skb_header_pointer(skb, nsrcs_offset, + sizeof(__nsrcs), &__nsrcs); + if (!_nsrcs) + return -EINVAL; + + nsrcs = ntohs(*_nsrcs); + grec_len = struct_size(grec, grec_src, nsrcs); + + if (!ipv6_mc_may_pull(skb, len + grec_len)) + return -EINVAL; + + grec = (struct mld2_grec *)(skb->data + len); + len += grec_len; + + switch (grec->grec_type) { + case MLD2_MODE_IS_INCLUDE: + case MLD2_MODE_IS_EXCLUDE: + case MLD2_CHANGE_TO_INCLUDE: + case MLD2_CHANGE_TO_EXCLUDE: + case MLD2_ALLOW_NEW_SOURCES: + case MLD2_BLOCK_OLD_SOURCES: + break; + + default: + continue; + } + + src = eth_hdr(skb)->h_source; + if ((grec->grec_type == MLD2_CHANGE_TO_INCLUDE || + grec->grec_type == MLD2_MODE_IS_INCLUDE) && + nsrcs == 0) { + if (!pmctx || mldv1) { + br_ip6_multicast_leave_group(brmctx, pmctx, + &grec->grec_mca, + vid, src); + continue; + } + } else { + err = br_ip6_multicast_add_group(brmctx, pmctx, + &grec->grec_mca, vid, + src, mldv1); + if (err) + break; + } + + if (!pmctx || mldv1) + continue; + + spin_lock_bh(&brmctx->br->multicast_lock); + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto unlock_continue; + + mdst = br_mdb_ip6_get(brmctx->br, &grec->grec_mca, vid); + if (!mdst) + goto unlock_continue; + pg = br_multicast_find_port(mdst, pmctx->port, src); + if (!pg || (pg->flags & MDB_PG_FLAGS_PERMANENT)) + goto unlock_continue; + h_addr = &ipv6_hdr(skb)->saddr; + switch (grec->grec_type) { + case MLD2_ALLOW_NEW_SOURCES: + changed = br_multicast_isinc_allow(brmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + case MLD2_MODE_IS_INCLUDE: + changed = br_multicast_isinc_allow(brmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + case MLD2_MODE_IS_EXCLUDE: + changed = br_multicast_isexc(brmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + case MLD2_CHANGE_TO_INCLUDE: + changed = br_multicast_toin(brmctx, pmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + case MLD2_CHANGE_TO_EXCLUDE: + changed = br_multicast_toex(brmctx, pmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + case MLD2_BLOCK_OLD_SOURCES: + changed = 
br_multicast_block(brmctx, pmctx, pg, h_addr, + grec->grec_src, nsrcs, + sizeof(struct in6_addr), + grec->grec_type); + break; + } + if (changed) + br_mdb_notify(brmctx->br->dev, mdst, pg, RTM_NEWMDB); +unlock_continue: + spin_unlock_bh(&brmctx->br->multicast_lock); + } + + return err; +} +#endif + +static bool br_multicast_select_querier(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct br_ip *saddr) +{ + int port_ifidx = pmctx ? pmctx->port->dev->ifindex : 0; + struct timer_list *own_timer, *other_timer; + struct bridge_mcast_querier *querier; + + switch (saddr->proto) { + case htons(ETH_P_IP): + querier = &brmctx->ip4_querier; + own_timer = &brmctx->ip4_own_query.timer; + other_timer = &brmctx->ip4_other_query.timer; + if (!querier->addr.src.ip4 || + ntohl(saddr->src.ip4) <= ntohl(querier->addr.src.ip4)) + goto update; + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + querier = &brmctx->ip6_querier; + own_timer = &brmctx->ip6_own_query.timer; + other_timer = &brmctx->ip6_other_query.timer; + if (ipv6_addr_cmp(&saddr->src.ip6, &querier->addr.src.ip6) <= 0) + goto update; + break; +#endif + default: + return false; + } + + if (!timer_pending(own_timer) && !timer_pending(other_timer)) + goto update; + + return false; + +update: + br_multicast_update_querier(brmctx, querier, port_ifidx, saddr); + + return true; +} + +static struct net_bridge_port * +__br_multicast_get_querier_port(struct net_bridge *br, + const struct bridge_mcast_querier *querier) +{ + int port_ifidx = READ_ONCE(querier->port_ifidx); + struct net_bridge_port *p; + struct net_device *dev; + + if (port_ifidx == 0) + return NULL; + + dev = dev_get_by_index_rcu(dev_net(br->dev), port_ifidx); + if (!dev) + return NULL; + p = br_port_get_rtnl_rcu(dev); + if (!p || p->br != br) + return NULL; + + return p; +} + +size_t br_multicast_querier_state_size(void) +{ + return nla_total_size(0) + /* nest attribute */ + nla_total_size(sizeof(__be32)) + /* BRIDGE_QUERIER_IP_ADDRESS */ + nla_total_size(sizeof(int)) + /* BRIDGE_QUERIER_IP_PORT */ + nla_total_size_64bit(sizeof(u64)) + /* BRIDGE_QUERIER_IP_OTHER_TIMER */ +#if IS_ENABLED(CONFIG_IPV6) + nla_total_size(sizeof(struct in6_addr)) + /* BRIDGE_QUERIER_IPV6_ADDRESS */ + nla_total_size(sizeof(int)) + /* BRIDGE_QUERIER_IPV6_PORT */ + nla_total_size_64bit(sizeof(u64)) + /* BRIDGE_QUERIER_IPV6_OTHER_TIMER */ +#endif + 0; +} + +/* protected by rtnl or rcu */ +int br_multicast_dump_querier_state(struct sk_buff *skb, + const struct net_bridge_mcast *brmctx, + int nest_attr) +{ + struct bridge_mcast_querier querier = {}; + struct net_bridge_port *p; + struct nlattr *nest; + + if (!br_opt_get(brmctx->br, BROPT_MULTICAST_ENABLED) || + br_multicast_ctx_vlan_global_disabled(brmctx)) + return 0; + + nest = nla_nest_start(skb, nest_attr); + if (!nest) + return -EMSGSIZE; + + rcu_read_lock(); + if (!brmctx->multicast_querier && + !timer_pending(&brmctx->ip4_other_query.timer)) + goto out_v6; + + br_multicast_read_querier(&brmctx->ip4_querier, &querier); + if (nla_put_in_addr(skb, BRIDGE_QUERIER_IP_ADDRESS, + querier.addr.src.ip4)) { + rcu_read_unlock(); + goto out_err; + } + + p = __br_multicast_get_querier_port(brmctx->br, &querier); + if (timer_pending(&brmctx->ip4_other_query.timer) && + (nla_put_u64_64bit(skb, BRIDGE_QUERIER_IP_OTHER_TIMER, + br_timer_value(&brmctx->ip4_other_query.timer), + BRIDGE_QUERIER_PAD) || + (p && nla_put_u32(skb, BRIDGE_QUERIER_IP_PORT, p->dev->ifindex)))) { + rcu_read_unlock(); + goto out_err; + } + +out_v6: +#if 
IS_ENABLED(CONFIG_IPV6) + if (!brmctx->multicast_querier && + !timer_pending(&brmctx->ip6_other_query.timer)) + goto out; + + br_multicast_read_querier(&brmctx->ip6_querier, &querier); + if (nla_put_in6_addr(skb, BRIDGE_QUERIER_IPV6_ADDRESS, + &querier.addr.src.ip6)) { + rcu_read_unlock(); + goto out_err; + } + + p = __br_multicast_get_querier_port(brmctx->br, &querier); + if (timer_pending(&brmctx->ip6_other_query.timer) && + (nla_put_u64_64bit(skb, BRIDGE_QUERIER_IPV6_OTHER_TIMER, + br_timer_value(&brmctx->ip6_other_query.timer), + BRIDGE_QUERIER_PAD) || + (p && nla_put_u32(skb, BRIDGE_QUERIER_IPV6_PORT, + p->dev->ifindex)))) { + rcu_read_unlock(); + goto out_err; + } +out: +#endif + rcu_read_unlock(); + nla_nest_end(skb, nest); + if (!nla_len(nest)) + nla_nest_cancel(skb, nest); + + return 0; + +out_err: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static void +br_multicast_update_query_timer(struct net_bridge_mcast *brmctx, + struct bridge_mcast_other_query *query, + unsigned long max_delay) +{ + if (!timer_pending(&query->timer)) + query->delay_time = jiffies + max_delay; + + mod_timer(&query->timer, jiffies + brmctx->multicast_querier_interval); +} + +static void br_port_mc_router_state_change(struct net_bridge_port *p, + bool is_mc_router) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + .id = SWITCHDEV_ATTR_ID_PORT_MROUTER, + .flags = SWITCHDEV_F_DEFER, + .u.mrouter = is_mc_router, + }; + + switchdev_port_attr_set(p->dev, &attr, NULL); +} + +static struct net_bridge_port * +br_multicast_rport_from_node(struct net_bridge_mcast *brmctx, + struct hlist_head *mc_router_list, + struct hlist_node *rlist) +{ + struct net_bridge_mcast_port *pmctx; + +#if IS_ENABLED(CONFIG_IPV6) + if (mc_router_list == &brmctx->ip6_mc_router_list) + pmctx = hlist_entry(rlist, struct net_bridge_mcast_port, + ip6_rlist); + else +#endif + pmctx = hlist_entry(rlist, struct net_bridge_mcast_port, + ip4_rlist); + + return pmctx->port; +} + +static struct hlist_node * +br_multicast_get_rport_slot(struct net_bridge_mcast *brmctx, + struct net_bridge_port *port, + struct hlist_head *mc_router_list) + +{ + struct hlist_node *slot = NULL; + struct net_bridge_port *p; + struct hlist_node *rlist; + + hlist_for_each(rlist, mc_router_list) { + p = br_multicast_rport_from_node(brmctx, mc_router_list, rlist); + + if ((unsigned long)port >= (unsigned long)p) + break; + + slot = rlist; + } + + return slot; +} + +static bool br_multicast_no_router_otherpf(struct net_bridge_mcast_port *pmctx, + struct hlist_node *rnode) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (rnode != &pmctx->ip6_rlist) + return hlist_unhashed(&pmctx->ip6_rlist); + else + return hlist_unhashed(&pmctx->ip4_rlist); +#else + return true; +#endif +} + +/* Add port to router_list + * list is maintained ordered by pointer value + * and locked by br->multicast_lock and RCU + */ +static void br_multicast_add_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct hlist_node *rlist, + struct hlist_head *mc_router_list) +{ + struct hlist_node *slot; + + if (!hlist_unhashed(rlist)) + return; + + slot = br_multicast_get_rport_slot(brmctx, pmctx->port, mc_router_list); + + if (slot) + hlist_add_behind_rcu(rlist, slot); + else + hlist_add_head_rcu(rlist, mc_router_list); + + /* For backwards compatibility for now, only notify if we + * switched from no IPv4/IPv6 multicast router to a new + * IPv4 or IPv6 multicast router. 
+ */ + if (br_multicast_no_router_otherpf(pmctx, rlist)) { + br_rtr_notify(pmctx->port->br->dev, pmctx, RTM_NEWMDB); + br_port_mc_router_state_change(pmctx->port, true); + } +} + +/* Add port to router_list + * list is maintained ordered by pointer value + * and locked by br->multicast_lock and RCU + */ +static void br_ip4_multicast_add_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx) +{ + br_multicast_add_router(brmctx, pmctx, &pmctx->ip4_rlist, + &brmctx->ip4_mc_router_list); +} + +/* Add port to router_list + * list is maintained ordered by pointer value + * and locked by br->multicast_lock and RCU + */ +static void br_ip6_multicast_add_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + br_multicast_add_router(brmctx, pmctx, &pmctx->ip6_rlist, + &brmctx->ip6_mc_router_list); +#endif +} + +static void br_multicast_mark_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct timer_list *timer, + struct hlist_node *rlist, + struct hlist_head *mc_router_list) +{ + unsigned long now = jiffies; + + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + return; + + if (!pmctx) { + if (brmctx->multicast_router == MDB_RTR_TYPE_TEMP_QUERY) { + if (!br_ip4_multicast_is_router(brmctx) && + !br_ip6_multicast_is_router(brmctx)) + br_mc_router_state_change(brmctx->br, true); + mod_timer(timer, now + brmctx->multicast_querier_interval); + } + return; + } + + if (pmctx->multicast_router == MDB_RTR_TYPE_DISABLED || + pmctx->multicast_router == MDB_RTR_TYPE_PERM) + return; + + br_multicast_add_router(brmctx, pmctx, rlist, mc_router_list); + mod_timer(timer, now + brmctx->multicast_querier_interval); +} + +static void br_ip4_multicast_mark_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx) +{ + struct timer_list *timer = &brmctx->ip4_mc_router_timer; + struct hlist_node *rlist = NULL; + + if (pmctx) { + timer = &pmctx->ip4_mc_router_timer; + rlist = &pmctx->ip4_rlist; + } + + br_multicast_mark_router(brmctx, pmctx, timer, rlist, + &brmctx->ip4_mc_router_list); +} + +static void br_ip6_multicast_mark_router(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct timer_list *timer = &brmctx->ip6_mc_router_timer; + struct hlist_node *rlist = NULL; + + if (pmctx) { + timer = &pmctx->ip6_mc_router_timer; + rlist = &pmctx->ip6_rlist; + } + + br_multicast_mark_router(brmctx, pmctx, timer, rlist, + &brmctx->ip6_mc_router_list); +#endif +} + +static void +br_ip4_multicast_query_received(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct bridge_mcast_other_query *query, + struct br_ip *saddr, + unsigned long max_delay) +{ + if (!br_multicast_select_querier(brmctx, pmctx, saddr)) + return; + + br_multicast_update_query_timer(brmctx, query, max_delay); + br_ip4_multicast_mark_router(brmctx, pmctx); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void +br_ip6_multicast_query_received(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct bridge_mcast_other_query *query, + struct br_ip *saddr, + unsigned long max_delay) +{ + if (!br_multicast_select_querier(brmctx, pmctx, saddr)) + return; + + br_multicast_update_query_timer(brmctx, query, max_delay); + br_ip6_multicast_mark_router(brmctx, pmctx); +} +#endif + +static void br_ip4_multicast_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + unsigned int 
transport_len = ip_transport_len(skb); + const struct iphdr *iph = ip_hdr(skb); + struct igmphdr *ih = igmp_hdr(skb); + struct net_bridge_mdb_entry *mp; + struct igmpv3_query *ih3; + struct net_bridge_port_group *p; + struct net_bridge_port_group __rcu **pp; + struct br_ip saddr = {}; + unsigned long max_delay; + unsigned long now = jiffies; + __be32 group; + + spin_lock(&brmctx->br->multicast_lock); + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto out; + + group = ih->group; + + if (transport_len == sizeof(*ih)) { + max_delay = ih->code * (HZ / IGMP_TIMER_SCALE); + + if (!max_delay) { + max_delay = 10 * HZ; + group = 0; + } + } else if (transport_len >= sizeof(*ih3)) { + ih3 = igmpv3_query_hdr(skb); + if (ih3->nsrcs || + (brmctx->multicast_igmp_version == 3 && group && + ih3->suppress)) + goto out; + + max_delay = ih3->code ? + IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1; + } else { + goto out; + } + + if (!group) { + saddr.proto = htons(ETH_P_IP); + saddr.src.ip4 = iph->saddr; + + br_ip4_multicast_query_received(brmctx, pmctx, + &brmctx->ip4_other_query, + &saddr, max_delay); + goto out; + } + + mp = br_mdb_ip4_get(brmctx->br, group, vid); + if (!mp) + goto out; + + max_delay *= brmctx->multicast_last_member_count; + + if (mp->host_joined && + (timer_pending(&mp->timer) ? + time_after(mp->timer.expires, now + max_delay) : + try_to_del_timer_sync(&mp->timer) >= 0)) + mod_timer(&mp->timer, now + max_delay); + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, brmctx->br)) != NULL; + pp = &p->next) { + if (timer_pending(&p->timer) ? + time_after(p->timer.expires, now + max_delay) : + try_to_del_timer_sync(&p->timer) >= 0 && + (brmctx->multicast_igmp_version == 2 || + p->filter_mode == MCAST_EXCLUDE)) + mod_timer(&p->timer, now + max_delay); + } + +out: + spin_unlock(&brmctx->br->multicast_lock); +} + +#if IS_ENABLED(CONFIG_IPV6) +static int br_ip6_multicast_query(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + unsigned int transport_len = ipv6_transport_len(skb); + struct mld_msg *mld; + struct net_bridge_mdb_entry *mp; + struct mld2_query *mld2q; + struct net_bridge_port_group *p; + struct net_bridge_port_group __rcu **pp; + struct br_ip saddr = {}; + unsigned long max_delay; + unsigned long now = jiffies; + unsigned int offset = skb_transport_offset(skb); + const struct in6_addr *group = NULL; + bool is_general_query; + int err = 0; + + spin_lock(&brmctx->br->multicast_lock); + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto out; + + if (transport_len == sizeof(*mld)) { + if (!pskb_may_pull(skb, offset + sizeof(*mld))) { + err = -EINVAL; + goto out; + } + mld = (struct mld_msg *) icmp6_hdr(skb); + max_delay = msecs_to_jiffies(ntohs(mld->mld_maxdelay)); + if (max_delay) + group = &mld->mld_mca; + } else { + if (!pskb_may_pull(skb, offset + sizeof(*mld2q))) { + err = -EINVAL; + goto out; + } + mld2q = (struct mld2_query *)icmp6_hdr(skb); + if (!mld2q->mld2q_nsrcs) + group = &mld2q->mld2q_mca; + if (brmctx->multicast_mld_version == 2 && + !ipv6_addr_any(&mld2q->mld2q_mca) && + mld2q->mld2q_suppress) + goto out; + + max_delay = max(msecs_to_jiffies(mldv2_mrc(mld2q)), 1UL); + } + + is_general_query = group && ipv6_addr_any(group); + + if (is_general_query) { + saddr.proto = htons(ETH_P_IPV6); + saddr.src.ip6 = ipv6_hdr(skb)->saddr; + + br_ip6_multicast_query_received(brmctx, pmctx, + &brmctx->ip6_other_query, + &saddr, max_delay); + goto out; + } else if (!group) { + goto out; + } + + mp = 
br_mdb_ip6_get(brmctx->br, group, vid); + if (!mp) + goto out; + + max_delay *= brmctx->multicast_last_member_count; + if (mp->host_joined && + (timer_pending(&mp->timer) ? + time_after(mp->timer.expires, now + max_delay) : + try_to_del_timer_sync(&mp->timer) >= 0)) + mod_timer(&mp->timer, now + max_delay); + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, brmctx->br)) != NULL; + pp = &p->next) { + if (timer_pending(&p->timer) ? + time_after(p->timer.expires, now + max_delay) : + try_to_del_timer_sync(&p->timer) >= 0 && + (brmctx->multicast_mld_version == 1 || + p->filter_mode == MCAST_EXCLUDE)) + mod_timer(&p->timer, now + max_delay); + } + +out: + spin_unlock(&brmctx->br->multicast_lock); + return err; +} +#endif + +static void +br_multicast_leave_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct br_ip *group, + struct bridge_mcast_other_query *other_query, + struct bridge_mcast_own_query *own_query, + const unsigned char *src) +{ + struct net_bridge_mdb_entry *mp; + struct net_bridge_port_group *p; + unsigned long now; + unsigned long time; + + spin_lock(&brmctx->br->multicast_lock); + if (!br_multicast_ctx_should_use(brmctx, pmctx)) + goto out; + + mp = br_mdb_ip_get(brmctx->br, group); + if (!mp) + goto out; + + if (pmctx && (pmctx->port->flags & BR_MULTICAST_FAST_LEAVE)) { + struct net_bridge_port_group __rcu **pp; + + for (pp = &mp->ports; + (p = mlock_dereference(*pp, brmctx->br)) != NULL; + pp = &p->next) { + if (!br_port_group_equal(p, pmctx->port, src)) + continue; + + if (p->flags & MDB_PG_FLAGS_PERMANENT) + break; + + p->flags |= MDB_PG_FLAGS_FAST_LEAVE; + br_multicast_del_pg(mp, p, pp); + } + goto out; + } + + if (timer_pending(&other_query->timer)) + goto out; + + if (brmctx->multicast_querier) { + __br_multicast_send_query(brmctx, pmctx, NULL, NULL, &mp->addr, + false, 0, NULL); + + time = jiffies + brmctx->multicast_last_member_count * + brmctx->multicast_last_member_interval; + + mod_timer(&own_query->timer, time); + + for (p = mlock_dereference(mp->ports, brmctx->br); + p != NULL && pmctx != NULL; + p = mlock_dereference(p->next, brmctx->br)) { + if (!br_port_group_equal(p, pmctx->port, src)) + continue; + + if (!hlist_unhashed(&p->mglist) && + (timer_pending(&p->timer) ? + time_after(p->timer.expires, time) : + try_to_del_timer_sync(&p->timer) >= 0)) { + mod_timer(&p->timer, time); + } + + break; + } + } + + now = jiffies; + time = now + brmctx->multicast_last_member_count * + brmctx->multicast_last_member_interval; + + if (!pmctx) { + if (mp->host_joined && + (timer_pending(&mp->timer) ? + time_after(mp->timer.expires, time) : + try_to_del_timer_sync(&mp->timer) >= 0)) { + mod_timer(&mp->timer, time); + } + + goto out; + } + + for (p = mlock_dereference(mp->ports, brmctx->br); + p != NULL; + p = mlock_dereference(p->next, brmctx->br)) { + if (p->key.port != pmctx->port) + continue; + + if (!hlist_unhashed(&p->mglist) && + (timer_pending(&p->timer) ? + time_after(p->timer.expires, time) : + try_to_del_timer_sync(&p->timer) >= 0)) { + mod_timer(&p->timer, time); + } + + break; + } +out: + spin_unlock(&brmctx->br->multicast_lock); +} + +static void br_ip4_multicast_leave_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + __be32 group, + __u16 vid, + const unsigned char *src) +{ + struct br_ip br_group; + struct bridge_mcast_own_query *own_query; + + if (ipv4_is_local_multicast(group)) + return; + + own_query = pmctx ? 
&pmctx->ip4_own_query : &brmctx->ip4_own_query; + + memset(&br_group, 0, sizeof(br_group)); + br_group.dst.ip4 = group; + br_group.proto = htons(ETH_P_IP); + br_group.vid = vid; + + br_multicast_leave_group(brmctx, pmctx, &br_group, + &brmctx->ip4_other_query, + own_query, src); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_leave_group(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + const struct in6_addr *group, + __u16 vid, + const unsigned char *src) +{ + struct br_ip br_group; + struct bridge_mcast_own_query *own_query; + + if (ipv6_addr_is_ll_all_nodes(group)) + return; + + own_query = pmctx ? &pmctx->ip6_own_query : &brmctx->ip6_own_query; + + memset(&br_group, 0, sizeof(br_group)); + br_group.dst.ip6 = *group; + br_group.proto = htons(ETH_P_IPV6); + br_group.vid = vid; + + br_multicast_leave_group(brmctx, pmctx, &br_group, + &brmctx->ip6_other_query, + own_query, src); +} +#endif + +static void br_multicast_err_count(const struct net_bridge *br, + const struct net_bridge_port *p, + __be16 proto) +{ + struct bridge_mcast_stats __percpu *stats; + struct bridge_mcast_stats *pstats; + + if (!br_opt_get(br, BROPT_MULTICAST_STATS_ENABLED)) + return; + + if (p) + stats = p->mcast_stats; + else + stats = br->mcast_stats; + if (WARN_ON(!stats)) + return; + + pstats = this_cpu_ptr(stats); + + u64_stats_update_begin(&pstats->syncp); + switch (proto) { + case htons(ETH_P_IP): + pstats->mstats.igmp_parse_errors++; + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + pstats->mstats.mld_parse_errors++; + break; +#endif + } + u64_stats_update_end(&pstats->syncp); +} + +static void br_multicast_pim(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + const struct sk_buff *skb) +{ + unsigned int offset = skb_transport_offset(skb); + struct pimhdr *pimhdr, _pimhdr; + + pimhdr = skb_header_pointer(skb, offset, sizeof(_pimhdr), &_pimhdr); + if (!pimhdr || pim_hdr_version(pimhdr) != PIM_VERSION || + pim_hdr_type(pimhdr) != PIM_TYPE_HELLO) + return; + + spin_lock(&brmctx->br->multicast_lock); + br_ip4_multicast_mark_router(brmctx, pmctx); + spin_unlock(&brmctx->br->multicast_lock); +} + +static int br_ip4_multicast_mrd_rcv(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb) +{ + if (ip_hdr(skb)->protocol != IPPROTO_IGMP || + igmp_hdr(skb)->type != IGMP_MRDISC_ADV) + return -ENOMSG; + + spin_lock(&brmctx->br->multicast_lock); + br_ip4_multicast_mark_router(brmctx, pmctx); + spin_unlock(&brmctx->br->multicast_lock); + + return 0; +} + +static int br_multicast_ipv4_rcv(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + struct net_bridge_port *p = pmctx ? 
pmctx->port : NULL; + const unsigned char *src; + struct igmphdr *ih; + int err; + + err = ip_mc_check_igmp(skb); + + if (err == -ENOMSG) { + if (!ipv4_is_local_multicast(ip_hdr(skb)->daddr)) { + BR_INPUT_SKB_CB(skb)->mrouters_only = 1; + } else if (pim_ipv4_all_pim_routers(ip_hdr(skb)->daddr)) { + if (ip_hdr(skb)->protocol == IPPROTO_PIM) + br_multicast_pim(brmctx, pmctx, skb); + } else if (ipv4_is_all_snoopers(ip_hdr(skb)->daddr)) { + br_ip4_multicast_mrd_rcv(brmctx, pmctx, skb); + } + + return 0; + } else if (err < 0) { + br_multicast_err_count(brmctx->br, p, skb->protocol); + return err; + } + + ih = igmp_hdr(skb); + src = eth_hdr(skb)->h_source; + BR_INPUT_SKB_CB(skb)->igmp = ih->type; + + switch (ih->type) { + case IGMP_HOST_MEMBERSHIP_REPORT: + case IGMPV2_HOST_MEMBERSHIP_REPORT: + BR_INPUT_SKB_CB(skb)->mrouters_only = 1; + err = br_ip4_multicast_add_group(brmctx, pmctx, ih->group, vid, + src, true); + break; + case IGMPV3_HOST_MEMBERSHIP_REPORT: + err = br_ip4_multicast_igmp3_report(brmctx, pmctx, skb, vid); + break; + case IGMP_HOST_MEMBERSHIP_QUERY: + br_ip4_multicast_query(brmctx, pmctx, skb, vid); + break; + case IGMP_HOST_LEAVE_MESSAGE: + br_ip4_multicast_leave_group(brmctx, pmctx, ih->group, vid, src); + break; + } + + br_multicast_count(brmctx->br, p, skb, BR_INPUT_SKB_CB(skb)->igmp, + BR_MCAST_DIR_RX); + + return err; +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_mrd_rcv(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb) +{ + if (icmp6_hdr(skb)->icmp6_type != ICMPV6_MRDISC_ADV) + return; + + spin_lock(&brmctx->br->multicast_lock); + br_ip6_multicast_mark_router(brmctx, pmctx); + spin_unlock(&brmctx->br->multicast_lock); +} + +static int br_multicast_ipv6_rcv(struct net_bridge_mcast *brmctx, + struct net_bridge_mcast_port *pmctx, + struct sk_buff *skb, + u16 vid) +{ + struct net_bridge_port *p = pmctx ? 
pmctx->port : NULL; + const unsigned char *src; + struct mld_msg *mld; + int err; + + err = ipv6_mc_check_mld(skb); + + if (err == -ENOMSG || err == -ENODATA) { + if (!ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr)) + BR_INPUT_SKB_CB(skb)->mrouters_only = 1; + if (err == -ENODATA && + ipv6_addr_is_all_snoopers(&ipv6_hdr(skb)->daddr)) + br_ip6_multicast_mrd_rcv(brmctx, pmctx, skb); + + return 0; + } else if (err < 0) { + br_multicast_err_count(brmctx->br, p, skb->protocol); + return err; + } + + mld = (struct mld_msg *)skb_transport_header(skb); + BR_INPUT_SKB_CB(skb)->igmp = mld->mld_type; + + switch (mld->mld_type) { + case ICMPV6_MGM_REPORT: + src = eth_hdr(skb)->h_source; + BR_INPUT_SKB_CB(skb)->mrouters_only = 1; + err = br_ip6_multicast_add_group(brmctx, pmctx, &mld->mld_mca, + vid, src, true); + break; + case ICMPV6_MLD2_REPORT: + err = br_ip6_multicast_mld2_report(brmctx, pmctx, skb, vid); + break; + case ICMPV6_MGM_QUERY: + err = br_ip6_multicast_query(brmctx, pmctx, skb, vid); + break; + case ICMPV6_MGM_REDUCTION: + src = eth_hdr(skb)->h_source; + br_ip6_multicast_leave_group(brmctx, pmctx, &mld->mld_mca, vid, + src); + break; + } + + br_multicast_count(brmctx->br, p, skb, BR_INPUT_SKB_CB(skb)->igmp, + BR_MCAST_DIR_RX); + + return err; +} +#endif + +int br_multicast_rcv(struct net_bridge_mcast **brmctx, + struct net_bridge_mcast_port **pmctx, + struct net_bridge_vlan *vlan, + struct sk_buff *skb, u16 vid) +{ + int ret = 0; + + BR_INPUT_SKB_CB(skb)->igmp = 0; + BR_INPUT_SKB_CB(skb)->mrouters_only = 0; + + if (!br_opt_get((*brmctx)->br, BROPT_MULTICAST_ENABLED)) + return 0; + + if (br_opt_get((*brmctx)->br, BROPT_MCAST_VLAN_SNOOPING_ENABLED) && vlan) { + const struct net_bridge_vlan *masterv; + + /* the vlan has the master flag set only when transmitting + * through the bridge device + */ + if (br_vlan_is_master(vlan)) { + masterv = vlan; + *brmctx = &vlan->br_mcast_ctx; + *pmctx = NULL; + } else { + masterv = vlan->brvlan; + *brmctx = &vlan->brvlan->br_mcast_ctx; + *pmctx = &vlan->port_mcast_ctx; + } + + if (!(masterv->priv_flags & BR_VLFLAG_GLOBAL_MCAST_ENABLED)) + return 0; + } + + switch (skb->protocol) { + case htons(ETH_P_IP): + ret = br_multicast_ipv4_rcv(*brmctx, *pmctx, skb, vid); + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + ret = br_multicast_ipv6_rcv(*brmctx, *pmctx, skb, vid); + break; +#endif + } + + return ret; +} + +static void br_multicast_query_expired(struct net_bridge_mcast *brmctx, + struct bridge_mcast_own_query *query, + struct bridge_mcast_querier *querier) +{ + spin_lock(&brmctx->br->multicast_lock); + if (br_multicast_ctx_vlan_disabled(brmctx)) + goto out; + + if (query->startup_sent < brmctx->multicast_startup_query_count) + query->startup_sent++; + + br_multicast_send_query(brmctx, NULL, query); +out: + spin_unlock(&brmctx->br->multicast_lock); +} + +static void br_ip4_multicast_query_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip4_own_query.timer); + + br_multicast_query_expired(brmctx, &brmctx->ip4_own_query, + &brmctx->ip4_querier); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_query_expired(struct timer_list *t) +{ + struct net_bridge_mcast *brmctx = from_timer(brmctx, t, + ip6_own_query.timer); + + br_multicast_query_expired(brmctx, &brmctx->ip6_own_query, + &brmctx->ip6_querier); +} +#endif + +static void br_multicast_gc_work(struct work_struct *work) +{ + struct net_bridge *br = container_of(work, struct net_bridge, + mcast_gc_work); + HLIST_HEAD(deleted_head); 
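+ /* Splice the pending entries onto a local list under the multicast + * lock, then destroy them outside of it so the destroy callbacks can + * safely synchronize with their timers. + */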
+ + spin_lock_bh(&br->multicast_lock); + hlist_move_list(&br->mcast_gc_list, &deleted_head); + spin_unlock_bh(&br->multicast_lock); + + br_multicast_gc(&deleted_head); +} + +void br_multicast_ctx_init(struct net_bridge *br, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast *brmctx) +{ + brmctx->br = br; + brmctx->vlan = vlan; + brmctx->multicast_router = MDB_RTR_TYPE_TEMP_QUERY; + brmctx->multicast_last_member_count = 2; + brmctx->multicast_startup_query_count = 2; + + brmctx->multicast_last_member_interval = HZ; + brmctx->multicast_query_response_interval = 10 * HZ; + brmctx->multicast_startup_query_interval = 125 * HZ / 4; + brmctx->multicast_query_interval = 125 * HZ; + brmctx->multicast_querier_interval = 255 * HZ; + brmctx->multicast_membership_interval = 260 * HZ; + + brmctx->ip4_other_query.delay_time = 0; + brmctx->ip4_querier.port_ifidx = 0; + seqcount_spinlock_init(&brmctx->ip4_querier.seq, &br->multicast_lock); + brmctx->multicast_igmp_version = 2; +#if IS_ENABLED(CONFIG_IPV6) + brmctx->multicast_mld_version = 1; + brmctx->ip6_other_query.delay_time = 0; + brmctx->ip6_querier.port_ifidx = 0; + seqcount_spinlock_init(&brmctx->ip6_querier.seq, &br->multicast_lock); +#endif + + timer_setup(&brmctx->ip4_mc_router_timer, + br_ip4_multicast_local_router_expired, 0); + timer_setup(&brmctx->ip4_other_query.timer, + br_ip4_multicast_querier_expired, 0); + timer_setup(&brmctx->ip4_own_query.timer, + br_ip4_multicast_query_expired, 0); +#if IS_ENABLED(CONFIG_IPV6) + timer_setup(&brmctx->ip6_mc_router_timer, + br_ip6_multicast_local_router_expired, 0); + timer_setup(&brmctx->ip6_other_query.timer, + br_ip6_multicast_querier_expired, 0); + timer_setup(&brmctx->ip6_own_query.timer, + br_ip6_multicast_query_expired, 0); +#endif +} + +void br_multicast_ctx_deinit(struct net_bridge_mcast *brmctx) +{ + __br_multicast_stop(brmctx); +} + +void br_multicast_init(struct net_bridge *br) +{ + br->hash_max = BR_MULTICAST_DEFAULT_HASH_MAX; + + br_multicast_ctx_init(br, NULL, &br->multicast_ctx); + + br_opt_toggle(br, BROPT_MULTICAST_ENABLED, true); + br_opt_toggle(br, BROPT_HAS_IPV6_ADDR, true); + + spin_lock_init(&br->multicast_lock); + INIT_HLIST_HEAD(&br->mdb_list); + INIT_HLIST_HEAD(&br->mcast_gc_list); + INIT_WORK(&br->mcast_gc_work, br_multicast_gc_work); +} + +static void br_ip4_multicast_join_snoopers(struct net_bridge *br) +{ + struct in_device *in_dev = in_dev_get(br->dev); + + if (!in_dev) + return; + + __ip_mc_inc_group(in_dev, htonl(INADDR_ALLSNOOPERS_GROUP), GFP_ATOMIC); + in_dev_put(in_dev); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_join_snoopers(struct net_bridge *br) +{ + struct in6_addr addr; + + ipv6_addr_set(&addr, htonl(0xff020000), 0, 0, htonl(0x6a)); + ipv6_dev_mc_inc(br->dev, &addr); +} +#else +static inline void br_ip6_multicast_join_snoopers(struct net_bridge *br) +{ +} +#endif + +void br_multicast_join_snoopers(struct net_bridge *br) +{ + br_ip4_multicast_join_snoopers(br); + br_ip6_multicast_join_snoopers(br); +} + +static void br_ip4_multicast_leave_snoopers(struct net_bridge *br) +{ + struct in_device *in_dev = in_dev_get(br->dev); + + if (WARN_ON(!in_dev)) + return; + + __ip_mc_dec_group(in_dev, htonl(INADDR_ALLSNOOPERS_GROUP), GFP_ATOMIC); + in_dev_put(in_dev); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void br_ip6_multicast_leave_snoopers(struct net_bridge *br) +{ + struct in6_addr addr; + + ipv6_addr_set(&addr, htonl(0xff020000), 0, 0, htonl(0x6a)); + ipv6_dev_mc_dec(br->dev, &addr); +} +#else +static inline void 
br_ip6_multicast_leave_snoopers(struct net_bridge *br) +{ +} +#endif + +void br_multicast_leave_snoopers(struct net_bridge *br) +{ + br_ip4_multicast_leave_snoopers(br); + br_ip6_multicast_leave_snoopers(br); +} + +static void __br_multicast_open_query(struct net_bridge *br, + struct bridge_mcast_own_query *query) +{ + query->startup_sent = 0; + + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) + return; + + mod_timer(&query->timer, jiffies); +} + +static void __br_multicast_open(struct net_bridge_mcast *brmctx) +{ + __br_multicast_open_query(brmctx->br, &brmctx->ip4_own_query); +#if IS_ENABLED(CONFIG_IPV6) + __br_multicast_open_query(brmctx->br, &brmctx->ip6_own_query); +#endif +} + +void br_multicast_open(struct net_bridge *br) +{ + ASSERT_RTNL(); + + if (br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED)) { + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *vlan; + + vg = br_vlan_group(br); + if (vg) { + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + struct net_bridge_mcast *brmctx; + + brmctx = &vlan->br_mcast_ctx; + if (br_vlan_is_brentry(vlan) && + !br_multicast_ctx_vlan_disabled(brmctx)) + __br_multicast_open(&vlan->br_mcast_ctx); + } + } + } else { + __br_multicast_open(&br->multicast_ctx); + } +} + +static void __br_multicast_stop(struct net_bridge_mcast *brmctx) +{ + del_timer_sync(&brmctx->ip4_mc_router_timer); + del_timer_sync(&brmctx->ip4_other_query.timer); + del_timer_sync(&brmctx->ip4_own_query.timer); +#if IS_ENABLED(CONFIG_IPV6) + del_timer_sync(&brmctx->ip6_mc_router_timer); + del_timer_sync(&brmctx->ip6_other_query.timer); + del_timer_sync(&brmctx->ip6_own_query.timer); +#endif +} + +void br_multicast_toggle_one_vlan(struct net_bridge_vlan *vlan, bool on) +{ + struct net_bridge *br; + + /* it's okay to check for the flag without the multicast lock because it + * can only change under RTNL -> multicast_lock, we need the latter to + * sync with timers and packets + */ + if (on == !!(vlan->priv_flags & BR_VLFLAG_MCAST_ENABLED)) + return; + + if (br_vlan_is_master(vlan)) { + br = vlan->br; + + if (!br_vlan_is_brentry(vlan) || + (on && + br_multicast_ctx_vlan_global_disabled(&vlan->br_mcast_ctx))) + return; + + spin_lock_bh(&br->multicast_lock); + vlan->priv_flags ^= BR_VLFLAG_MCAST_ENABLED; + spin_unlock_bh(&br->multicast_lock); + + if (on) + __br_multicast_open(&vlan->br_mcast_ctx); + else + __br_multicast_stop(&vlan->br_mcast_ctx); + } else { + struct net_bridge_mcast *brmctx; + + brmctx = br_multicast_port_ctx_get_global(&vlan->port_mcast_ctx); + if (on && br_multicast_ctx_vlan_global_disabled(brmctx)) + return; + + br = vlan->port->br; + spin_lock_bh(&br->multicast_lock); + vlan->priv_flags ^= BR_VLFLAG_MCAST_ENABLED; + if (on) + __br_multicast_enable_port_ctx(&vlan->port_mcast_ctx); + else + __br_multicast_disable_port_ctx(&vlan->port_mcast_ctx); + spin_unlock_bh(&br->multicast_lock); + } +} + +static void br_multicast_toggle_vlan(struct net_bridge_vlan *vlan, bool on) +{ + struct net_bridge_port *p; + + if (WARN_ON_ONCE(!br_vlan_is_master(vlan))) + return; + + list_for_each_entry(p, &vlan->br->port_list, list) { + struct net_bridge_vlan *vport; + + vport = br_vlan_find(nbp_vlan_group(p), vlan->vid); + if (!vport) + continue; + br_multicast_toggle_one_vlan(vport, on); + } + + if (br_vlan_is_brentry(vlan)) + br_multicast_toggle_one_vlan(vlan, on); +} + +int br_multicast_toggle_vlan_snooping(struct net_bridge *br, bool on, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *vlan; + struct net_bridge_port *p; + 
+ if (br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED) == on) + return 0; + + if (on && !br_opt_get(br, BROPT_VLAN_ENABLED)) { + NL_SET_ERR_MSG_MOD(extack, "Cannot enable multicast vlan snooping with vlan filtering disabled"); + return -EINVAL; + } + + vg = br_vlan_group(br); + if (!vg) + return 0; + + br_opt_toggle(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED, on); + + /* disable/enable non-vlan mcast contexts based on vlan snooping */ + if (on) + __br_multicast_stop(&br->multicast_ctx); + else + __br_multicast_open(&br->multicast_ctx); + list_for_each_entry(p, &br->port_list, list) { + if (on) + br_multicast_disable_port(p); + else + br_multicast_enable_port(p); + } + + list_for_each_entry(vlan, &vg->vlan_list, vlist) + br_multicast_toggle_vlan(vlan, on); + + return 0; +} + +bool br_multicast_toggle_global_vlan(struct net_bridge_vlan *vlan, bool on) +{ + ASSERT_RTNL(); + + /* BR_VLFLAG_GLOBAL_MCAST_ENABLED relies on eventual consistency and + * requires only RTNL to change + */ + if (on == !!(vlan->priv_flags & BR_VLFLAG_GLOBAL_MCAST_ENABLED)) + return false; + + vlan->priv_flags ^= BR_VLFLAG_GLOBAL_MCAST_ENABLED; + br_multicast_toggle_vlan(vlan, on); + + return true; +} + +void br_multicast_stop(struct net_bridge *br) +{ + ASSERT_RTNL(); + + if (br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED)) { + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *vlan; + + vg = br_vlan_group(br); + if (vg) { + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + struct net_bridge_mcast *brmctx; + + brmctx = &vlan->br_mcast_ctx; + if (br_vlan_is_brentry(vlan) && + !br_multicast_ctx_vlan_disabled(brmctx)) + __br_multicast_stop(&vlan->br_mcast_ctx); + } + } + } else { + __br_multicast_stop(&br->multicast_ctx); + } +} + +void br_multicast_dev_del(struct net_bridge *br) +{ + struct net_bridge_mdb_entry *mp; + HLIST_HEAD(deleted_head); + struct hlist_node *tmp; + + spin_lock_bh(&br->multicast_lock); + hlist_for_each_entry_safe(mp, tmp, &br->mdb_list, mdb_node) + br_multicast_del_mdb_entry(mp); + hlist_move_list(&br->mcast_gc_list, &deleted_head); + spin_unlock_bh(&br->multicast_lock); + + br_multicast_ctx_deinit(&br->multicast_ctx); + br_multicast_gc(&deleted_head); + cancel_work_sync(&br->mcast_gc_work); + + rcu_barrier(); +} + +int br_multicast_set_router(struct net_bridge_mcast *brmctx, unsigned long val) +{ + int err = -EINVAL; + + spin_lock_bh(&brmctx->br->multicast_lock); + + switch (val) { + case MDB_RTR_TYPE_DISABLED: + case MDB_RTR_TYPE_PERM: + br_mc_router_state_change(brmctx->br, val == MDB_RTR_TYPE_PERM); + del_timer(&brmctx->ip4_mc_router_timer); +#if IS_ENABLED(CONFIG_IPV6) + del_timer(&brmctx->ip6_mc_router_timer); +#endif + brmctx->multicast_router = val; + err = 0; + break; + case MDB_RTR_TYPE_TEMP_QUERY: + if (brmctx->multicast_router != MDB_RTR_TYPE_TEMP_QUERY) + br_mc_router_state_change(brmctx->br, false); + brmctx->multicast_router = val; + err = 0; + break; + } + + spin_unlock_bh(&brmctx->br->multicast_lock); + + return err; +} + +static void +br_multicast_rport_del_notify(struct net_bridge_mcast_port *pmctx, bool deleted) +{ + if (!deleted) + return; + + /* For backwards compatibility for now, only notify if there is + * no multicast router anymore for both IPv4 and IPv6. 
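+ * I.e. bail out while the port is still linked on either the IPv4 or + * the IPv6 multicast router list.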
+ */ + if (!hlist_unhashed(&pmctx->ip4_rlist)) + return; +#if IS_ENABLED(CONFIG_IPV6) + if (!hlist_unhashed(&pmctx->ip6_rlist)) + return; +#endif + + br_rtr_notify(pmctx->port->br->dev, pmctx, RTM_DELMDB); + br_port_mc_router_state_change(pmctx->port, false); + + /* don't allow timer refresh */ + if (pmctx->multicast_router == MDB_RTR_TYPE_TEMP) + pmctx->multicast_router = MDB_RTR_TYPE_TEMP_QUERY; +} + +int br_multicast_set_port_router(struct net_bridge_mcast_port *pmctx, + unsigned long val) +{ + struct net_bridge_mcast *brmctx; + unsigned long now = jiffies; + int err = -EINVAL; + bool del = false; + + brmctx = br_multicast_port_ctx_get_global(pmctx); + spin_lock_bh(&brmctx->br->multicast_lock); + if (pmctx->multicast_router == val) { + /* Refresh the temp router port timer */ + if (pmctx->multicast_router == MDB_RTR_TYPE_TEMP) { + mod_timer(&pmctx->ip4_mc_router_timer, + now + brmctx->multicast_querier_interval); +#if IS_ENABLED(CONFIG_IPV6) + mod_timer(&pmctx->ip6_mc_router_timer, + now + brmctx->multicast_querier_interval); +#endif + } + err = 0; + goto unlock; + } + switch (val) { + case MDB_RTR_TYPE_DISABLED: + pmctx->multicast_router = MDB_RTR_TYPE_DISABLED; + del |= br_ip4_multicast_rport_del(pmctx); + del_timer(&pmctx->ip4_mc_router_timer); + del |= br_ip6_multicast_rport_del(pmctx); +#if IS_ENABLED(CONFIG_IPV6) + del_timer(&pmctx->ip6_mc_router_timer); +#endif + br_multicast_rport_del_notify(pmctx, del); + break; + case MDB_RTR_TYPE_TEMP_QUERY: + pmctx->multicast_router = MDB_RTR_TYPE_TEMP_QUERY; + del |= br_ip4_multicast_rport_del(pmctx); + del |= br_ip6_multicast_rport_del(pmctx); + br_multicast_rport_del_notify(pmctx, del); + break; + case MDB_RTR_TYPE_PERM: + pmctx->multicast_router = MDB_RTR_TYPE_PERM; + del_timer(&pmctx->ip4_mc_router_timer); + br_ip4_multicast_add_router(brmctx, pmctx); +#if IS_ENABLED(CONFIG_IPV6) + del_timer(&pmctx->ip6_mc_router_timer); +#endif + br_ip6_multicast_add_router(brmctx, pmctx); + break; + case MDB_RTR_TYPE_TEMP: + pmctx->multicast_router = MDB_RTR_TYPE_TEMP; + br_ip4_multicast_mark_router(brmctx, pmctx); + br_ip6_multicast_mark_router(brmctx, pmctx); + break; + default: + goto unlock; + } + err = 0; +unlock: + spin_unlock_bh(&brmctx->br->multicast_lock); + + return err; +} + +int br_multicast_set_vlan_router(struct net_bridge_vlan *v, u8 mcast_router) +{ + int err; + + if (br_vlan_is_master(v)) + err = br_multicast_set_router(&v->br_mcast_ctx, mcast_router); + else + err = br_multicast_set_port_router(&v->port_mcast_ctx, + mcast_router); + + return err; +} + +static void br_multicast_start_querier(struct net_bridge_mcast *brmctx, + struct bridge_mcast_own_query *query) +{ + struct net_bridge_port *port; + + if (!br_multicast_ctx_matches_vlan_snooping(brmctx)) + return; + + __br_multicast_open_query(brmctx->br, query); + + rcu_read_lock(); + list_for_each_entry_rcu(port, &brmctx->br->port_list, list) { + struct bridge_mcast_own_query *ip4_own_query; +#if IS_ENABLED(CONFIG_IPV6) + struct bridge_mcast_own_query *ip6_own_query; +#endif + + if (br_multicast_port_ctx_state_stopped(&port->multicast_ctx)) + continue; + + if (br_multicast_ctx_is_vlan(brmctx)) { + struct net_bridge_vlan *vlan; + + vlan = br_vlan_find(nbp_vlan_group_rcu(port), + brmctx->vlan->vid); + if (!vlan || + br_multicast_port_ctx_state_stopped(&vlan->port_mcast_ctx)) + continue; + + ip4_own_query = &vlan->port_mcast_ctx.ip4_own_query; +#if IS_ENABLED(CONFIG_IPV6) + ip6_own_query = &vlan->port_mcast_ctx.ip6_own_query; +#endif + } else { + ip4_own_query = 
&port->multicast_ctx.ip4_own_query; +#if IS_ENABLED(CONFIG_IPV6) + ip6_own_query = &port->multicast_ctx.ip6_own_query; +#endif + } + + if (query == &brmctx->ip4_own_query) + br_multicast_enable(ip4_own_query); +#if IS_ENABLED(CONFIG_IPV6) + else + br_multicast_enable(ip6_own_query); +#endif + } + rcu_read_unlock(); +} + +int br_multicast_toggle(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + struct net_bridge_port *port; + bool change_snoopers = false; + int err = 0; + + spin_lock_bh(&br->multicast_lock); + if (!!br_opt_get(br, BROPT_MULTICAST_ENABLED) == !!val) + goto unlock; + + err = br_mc_disabled_update(br->dev, val, extack); + if (err == -EOPNOTSUPP) + err = 0; + if (err) + goto unlock; + + br_opt_toggle(br, BROPT_MULTICAST_ENABLED, !!val); + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) { + change_snoopers = true; + goto unlock; + } + + if (!netif_running(br->dev)) + goto unlock; + + br_multicast_open(br); + list_for_each_entry(port, &br->port_list, list) + __br_multicast_enable_port_ctx(&port->multicast_ctx); + + change_snoopers = true; + +unlock: + spin_unlock_bh(&br->multicast_lock); + + /* br_multicast_join_snoopers has the potential to cause + * an MLD Report/Leave to be delivered to br_multicast_rcv, + * which would in turn call br_multicast_add_group, which would + * attempt to acquire multicast_lock. This function should be + * called after the lock has been released to avoid deadlocks on + * multicast_lock. + * + * br_multicast_leave_snoopers does not have the problem since + * br_multicast_rcv first checks BROPT_MULTICAST_ENABLED, and + * returns without calling br_multicast_ipv4/6_rcv if it's not + * enabled. Moved both functions out just for symmetry. + */ + if (change_snoopers) { + if (br_opt_get(br, BROPT_MULTICAST_ENABLED)) + br_multicast_join_snoopers(br); + else + br_multicast_leave_snoopers(br); + } + + return err; +} + +bool br_multicast_enabled(const struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + + return !!br_opt_get(br, BROPT_MULTICAST_ENABLED); +} +EXPORT_SYMBOL_GPL(br_multicast_enabled); + +bool br_multicast_router(const struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + bool is_router; + + spin_lock_bh(&br->multicast_lock); + is_router = br_multicast_is_router(&br->multicast_ctx, NULL); + spin_unlock_bh(&br->multicast_lock); + return is_router; +} +EXPORT_SYMBOL_GPL(br_multicast_router); + +int br_multicast_set_querier(struct net_bridge_mcast *brmctx, unsigned long val) +{ + unsigned long max_delay; + + val = !!val; + + spin_lock_bh(&brmctx->br->multicast_lock); + if (brmctx->multicast_querier == val) + goto unlock; + + WRITE_ONCE(brmctx->multicast_querier, val); + if (!val) + goto unlock; + + max_delay = brmctx->multicast_query_response_interval; + + if (!timer_pending(&brmctx->ip4_other_query.timer)) + brmctx->ip4_other_query.delay_time = jiffies + max_delay; + + br_multicast_start_querier(brmctx, &brmctx->ip4_own_query); + +#if IS_ENABLED(CONFIG_IPV6) + if (!timer_pending(&brmctx->ip6_other_query.timer)) + brmctx->ip6_other_query.delay_time = jiffies + max_delay; + + br_multicast_start_querier(brmctx, &brmctx->ip6_own_query); +#endif + +unlock: + spin_unlock_bh(&brmctx->br->multicast_lock); + + return 0; +} + +int br_multicast_set_igmp_version(struct net_bridge_mcast *brmctx, + unsigned long val) +{ + /* Currently we support only version 2 and 3 */ + switch (val) { + case 2: + case 3: + break; + default: + return -EINVAL; + } + + spin_lock_bh(&brmctx->br->multicast_lock); + 
brmctx->multicast_igmp_version = val; + spin_unlock_bh(&brmctx->br->multicast_lock); + + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6) +int br_multicast_set_mld_version(struct net_bridge_mcast *brmctx, + unsigned long val) +{ + /* Currently we support version 1 and 2 */ + switch (val) { + case 1: + case 2: + break; + default: + return -EINVAL; + } + + spin_lock_bh(&brmctx->br->multicast_lock); + brmctx->multicast_mld_version = val; + spin_unlock_bh(&brmctx->br->multicast_lock); + + return 0; +} +#endif + +void br_multicast_set_query_intvl(struct net_bridge_mcast *brmctx, + unsigned long val) +{ + unsigned long intvl_jiffies = clock_t_to_jiffies(val); + + if (intvl_jiffies < BR_MULTICAST_QUERY_INTVL_MIN) { + br_info(brmctx->br, + "trying to set multicast query interval below minimum, setting to %lu (%ums)\n", + jiffies_to_clock_t(BR_MULTICAST_QUERY_INTVL_MIN), + jiffies_to_msecs(BR_MULTICAST_QUERY_INTVL_MIN)); + intvl_jiffies = BR_MULTICAST_QUERY_INTVL_MIN; + } + + brmctx->multicast_query_interval = intvl_jiffies; +} + +void br_multicast_set_startup_query_intvl(struct net_bridge_mcast *brmctx, + unsigned long val) +{ + unsigned long intvl_jiffies = clock_t_to_jiffies(val); + + if (intvl_jiffies < BR_MULTICAST_STARTUP_QUERY_INTVL_MIN) { + br_info(brmctx->br, + "trying to set multicast startup query interval below minimum, setting to %lu (%ums)\n", + jiffies_to_clock_t(BR_MULTICAST_STARTUP_QUERY_INTVL_MIN), + jiffies_to_msecs(BR_MULTICAST_STARTUP_QUERY_INTVL_MIN)); + intvl_jiffies = BR_MULTICAST_STARTUP_QUERY_INTVL_MIN; + } + + brmctx->multicast_startup_query_interval = intvl_jiffies; +} + +/** + * br_multicast_list_adjacent - Returns snooped multicast addresses + * @dev: The bridge port adjacent to which to retrieve addresses + * @br_ip_list: The list to store found, snooped multicast IP addresses in + * + * Creates a list of IP addresses (struct br_ip_list) sensed by the multicast + * snooping feature on all bridge ports of dev's bridge device, excluding + * the addresses from dev itself. + * + * Returns the number of items added to br_ip_list. + * + * Notes: + * - br_ip_list needs to be initialized by caller + * - br_ip_list might contain duplicates in the end + * (needs to be taken care of by caller) + * - br_ip_list needs to be freed by caller + */ +int br_multicast_list_adjacent(struct net_device *dev, + struct list_head *br_ip_list) +{ + struct net_bridge *br; + struct net_bridge_port *port; + struct net_bridge_port_group *group; + struct br_ip_list *entry; + int count = 0; + + rcu_read_lock(); + if (!br_ip_list || !netif_is_bridge_port(dev)) + goto unlock; + + port = br_port_get_rcu(dev); + if (!port || !port->br) + goto unlock; + + br = port->br; + + list_for_each_entry_rcu(port, &br->port_list, list) { + if (!port->dev || port->dev == dev) + continue; + + hlist_for_each_entry_rcu(group, &port->mglist, mglist) { + entry = kmalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + goto unlock; + + entry->addr = group->key.addr; + list_add(&entry->list, br_ip_list); + count++; + } + } + +unlock: + rcu_read_unlock(); + return count; +} +EXPORT_SYMBOL_GPL(br_multicast_list_adjacent); + +/** + * br_multicast_has_querier_anywhere - Checks for a querier on a bridge + * @dev: The bridge port providing the bridge on which to check for a querier + * @proto: The protocol family to check for: IGMP -> ETH_P_IP, MLD -> ETH_P_IPV6 + * + * Checks whether the given interface has a bridge on top and if so returns + * true if a valid querier exists anywhere on the bridged link layer. 
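+ * (That is, either the bridge itself is acting as the querier or another + * querier has been heard on one of the bridge ports.)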
+ * Otherwise returns false. + */ +bool br_multicast_has_querier_anywhere(struct net_device *dev, int proto) +{ + struct net_bridge *br; + struct net_bridge_port *port; + struct ethhdr eth; + bool ret = false; + + rcu_read_lock(); + if (!netif_is_bridge_port(dev)) + goto unlock; + + port = br_port_get_rcu(dev); + if (!port || !port->br) + goto unlock; + + br = port->br; + + memset(&eth, 0, sizeof(eth)); + eth.h_proto = htons(proto); + + ret = br_multicast_querier_exists(&br->multicast_ctx, &eth, NULL); + +unlock: + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(br_multicast_has_querier_anywhere); + +/** + * br_multicast_has_querier_adjacent - Checks for a querier behind a bridge port + * @dev: The bridge port adjacent to which to check for a querier + * @proto: The protocol family to check for: IGMP -> ETH_P_IP, MLD -> ETH_P_IPV6 + * + * Checks whether the given interface has a bridge on top and if so returns + * true if a selected querier is behind one of the other ports of this + * bridge. Otherwise returns false. + */ +bool br_multicast_has_querier_adjacent(struct net_device *dev, int proto) +{ + struct net_bridge_mcast *brmctx; + struct net_bridge *br; + struct net_bridge_port *port; + bool ret = false; + int port_ifidx; + + rcu_read_lock(); + if (!netif_is_bridge_port(dev)) + goto unlock; + + port = br_port_get_rcu(dev); + if (!port || !port->br) + goto unlock; + + br = port->br; + brmctx = &br->multicast_ctx; + + switch (proto) { + case ETH_P_IP: + port_ifidx = brmctx->ip4_querier.port_ifidx; + if (!timer_pending(&brmctx->ip4_other_query.timer) || + port_ifidx == port->dev->ifindex) + goto unlock; + break; +#if IS_ENABLED(CONFIG_IPV6) + case ETH_P_IPV6: + port_ifidx = brmctx->ip6_querier.port_ifidx; + if (!timer_pending(&brmctx->ip6_other_query.timer) || + port_ifidx == port->dev->ifindex) + goto unlock; + break; +#endif + default: + goto unlock; + } + + ret = true; +unlock: + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(br_multicast_has_querier_adjacent); + +/** + * br_multicast_has_router_adjacent - Checks for a router behind a bridge port + * @dev: The bridge port adjacent to which to check for a multicast router + * @proto: The protocol family to check for: IGMP -> ETH_P_IP, MLD -> ETH_P_IPV6 + * + * Checks whether the given interface has a bridge on top and if so returns + * true if a multicast router is behind one of the other ports of this + * bridge. Otherwise returns false. 
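+ * + * A minimal usage sketch (hypothetical caller, names assumed): + * + * if (!br_multicast_has_router_adjacent(bridge_port_dev, ETH_P_IP)) + * pr_debug("no IPv4 mrouter behind other ports of %s\n", + * bridge_port_dev->name);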
+ */ +bool br_multicast_has_router_adjacent(struct net_device *dev, int proto) +{ + struct net_bridge_mcast_port *pmctx; + struct net_bridge_mcast *brmctx; + struct net_bridge_port *port; + bool ret = false; + + rcu_read_lock(); + port = br_port_get_check_rcu(dev); + if (!port) + goto unlock; + + brmctx = &port->br->multicast_ctx; + switch (proto) { + case ETH_P_IP: + hlist_for_each_entry_rcu(pmctx, &brmctx->ip4_mc_router_list, + ip4_rlist) { + if (pmctx->port == port) + continue; + + ret = true; + goto unlock; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case ETH_P_IPV6: + hlist_for_each_entry_rcu(pmctx, &brmctx->ip6_mc_router_list, + ip6_rlist) { + if (pmctx->port == port) + continue; + + ret = true; + goto unlock; + } + break; +#endif + default: + /* when compiled without IPv6 support, be conservative and + * always assume presence of an IPv6 multicast router + */ + ret = true; + } + +unlock: + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(br_multicast_has_router_adjacent); + +static void br_mcast_stats_add(struct bridge_mcast_stats __percpu *stats, + const struct sk_buff *skb, u8 type, u8 dir) +{ + struct bridge_mcast_stats *pstats = this_cpu_ptr(stats); + __be16 proto = skb->protocol; + unsigned int t_len; + + u64_stats_update_begin(&pstats->syncp); + switch (proto) { + case htons(ETH_P_IP): + t_len = ntohs(ip_hdr(skb)->tot_len) - ip_hdrlen(skb); + switch (type) { + case IGMP_HOST_MEMBERSHIP_REPORT: + pstats->mstats.igmp_v1reports[dir]++; + break; + case IGMPV2_HOST_MEMBERSHIP_REPORT: + pstats->mstats.igmp_v2reports[dir]++; + break; + case IGMPV3_HOST_MEMBERSHIP_REPORT: + pstats->mstats.igmp_v3reports[dir]++; + break; + case IGMP_HOST_MEMBERSHIP_QUERY: + if (t_len != sizeof(struct igmphdr)) { + pstats->mstats.igmp_v3queries[dir]++; + } else { + unsigned int offset = skb_transport_offset(skb); + struct igmphdr *ih, _ihdr; + + ih = skb_header_pointer(skb, offset, + sizeof(_ihdr), &_ihdr); + if (!ih) + break; + if (!ih->code) + pstats->mstats.igmp_v1queries[dir]++; + else + pstats->mstats.igmp_v2queries[dir]++; + } + break; + case IGMP_HOST_LEAVE_MESSAGE: + pstats->mstats.igmp_leaves[dir]++; + break; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + t_len = ntohs(ipv6_hdr(skb)->payload_len) + + sizeof(struct ipv6hdr); + t_len -= skb_network_header_len(skb); + switch (type) { + case ICMPV6_MGM_REPORT: + pstats->mstats.mld_v1reports[dir]++; + break; + case ICMPV6_MLD2_REPORT: + pstats->mstats.mld_v2reports[dir]++; + break; + case ICMPV6_MGM_QUERY: + if (t_len != sizeof(struct mld_msg)) + pstats->mstats.mld_v2queries[dir]++; + else + pstats->mstats.mld_v1queries[dir]++; + break; + case ICMPV6_MGM_REDUCTION: + pstats->mstats.mld_leaves[dir]++; + break; + } + break; +#endif /* CONFIG_IPV6 */ + } + u64_stats_update_end(&pstats->syncp); +} + +void br_multicast_count(struct net_bridge *br, + const struct net_bridge_port *p, + const struct sk_buff *skb, u8 type, u8 dir) +{ + struct bridge_mcast_stats __percpu *stats; + + /* if multicast_disabled is true then igmp type can't be set */ + if (!type || !br_opt_get(br, BROPT_MULTICAST_STATS_ENABLED)) + return; + + if (p) + stats = p->mcast_stats; + else + stats = br->mcast_stats; + if (WARN_ON(!stats)) + return; + + br_mcast_stats_add(stats, skb, type, dir); +} + +int br_multicast_init_stats(struct net_bridge *br) +{ + br->mcast_stats = netdev_alloc_pcpu_stats(struct bridge_mcast_stats); + if (!br->mcast_stats) + return -ENOMEM; + + return 0; +} + +void br_multicast_uninit_stats(struct net_bridge *br) +{ + 
free_percpu(br->mcast_stats); +} + +/* noinline for https://bugs.llvm.org/show_bug.cgi?id=45802#c9 */ +static noinline_for_stack void mcast_stats_add_dir(u64 *dst, u64 *src) +{ + dst[BR_MCAST_DIR_RX] += src[BR_MCAST_DIR_RX]; + dst[BR_MCAST_DIR_TX] += src[BR_MCAST_DIR_TX]; +} + +void br_multicast_get_stats(const struct net_bridge *br, + const struct net_bridge_port *p, + struct br_mcast_stats *dest) +{ + struct bridge_mcast_stats __percpu *stats; + struct br_mcast_stats tdst; + int i; + + memset(dest, 0, sizeof(*dest)); + if (p) + stats = p->mcast_stats; + else + stats = br->mcast_stats; + if (WARN_ON(!stats)) + return; + + memset(&tdst, 0, sizeof(tdst)); + for_each_possible_cpu(i) { + struct bridge_mcast_stats *cpu_stats = per_cpu_ptr(stats, i); + struct br_mcast_stats temp; + unsigned int start; + + do { + start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + memcpy(&temp, &cpu_stats->mstats, sizeof(temp)); + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + + mcast_stats_add_dir(tdst.igmp_v1queries, temp.igmp_v1queries); + mcast_stats_add_dir(tdst.igmp_v2queries, temp.igmp_v2queries); + mcast_stats_add_dir(tdst.igmp_v3queries, temp.igmp_v3queries); + mcast_stats_add_dir(tdst.igmp_leaves, temp.igmp_leaves); + mcast_stats_add_dir(tdst.igmp_v1reports, temp.igmp_v1reports); + mcast_stats_add_dir(tdst.igmp_v2reports, temp.igmp_v2reports); + mcast_stats_add_dir(tdst.igmp_v3reports, temp.igmp_v3reports); + tdst.igmp_parse_errors += temp.igmp_parse_errors; + + mcast_stats_add_dir(tdst.mld_v1queries, temp.mld_v1queries); + mcast_stats_add_dir(tdst.mld_v2queries, temp.mld_v2queries); + mcast_stats_add_dir(tdst.mld_leaves, temp.mld_leaves); + mcast_stats_add_dir(tdst.mld_v1reports, temp.mld_v1reports); + mcast_stats_add_dir(tdst.mld_v2reports, temp.mld_v2reports); + tdst.mld_parse_errors += temp.mld_parse_errors; + } + memcpy(dest, &tdst, sizeof(*dest)); +} + +int br_mdb_hash_init(struct net_bridge *br) +{ + int err; + + err = rhashtable_init(&br->sg_port_tbl, &br_sg_port_rht_params); + if (err) + return err; + + err = rhashtable_init(&br->mdb_hash_tbl, &br_mdb_rht_params); + if (err) { + rhashtable_destroy(&br->sg_port_tbl); + return err; + } + + return 0; +} + +void br_mdb_hash_fini(struct net_bridge *br) +{ + rhashtable_destroy(&br->sg_port_tbl); + rhashtable_destroy(&br->mdb_hash_tbl); +} diff --git a/net/bridge/br_multicast_eht.c b/net/bridge/br_multicast_eht.c new file mode 100644 index 000000000..f91c071d1 --- /dev/null +++ b/net/bridge/br_multicast_eht.c @@ -0,0 +1,819 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +// Copyright (c) 2020, Nikolay Aleksandrov +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#include +#include +#include +#endif + +#include "br_private.h" +#include "br_private_mcast_eht.h" + +static bool br_multicast_del_eht_set_entry(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr, + union net_bridge_eht_addr *h_addr); +static void br_multicast_create_eht_set_entry(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr, + union net_bridge_eht_addr *h_addr, + int filter_mode, + bool allow_zero_src); + +static struct net_bridge_group_eht_host * +br_multicast_eht_host_lookup(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr) +{ + struct rb_node *node = 
pg->eht_host_tree.rb_node; + + while (node) { + struct net_bridge_group_eht_host *this; + int result; + + this = rb_entry(node, struct net_bridge_group_eht_host, + rb_node); + result = memcmp(h_addr, &this->h_addr, sizeof(*h_addr)); + if (result < 0) + node = node->rb_left; + else if (result > 0) + node = node->rb_right; + else + return this; + } + + return NULL; +} + +static int br_multicast_eht_host_filter_mode(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr) +{ + struct net_bridge_group_eht_host *eht_host; + + eht_host = br_multicast_eht_host_lookup(pg, h_addr); + if (!eht_host) + return MCAST_INCLUDE; + + return eht_host->filter_mode; +} + +static struct net_bridge_group_eht_set_entry * +br_multicast_eht_set_entry_lookup(struct net_bridge_group_eht_set *eht_set, + union net_bridge_eht_addr *h_addr) +{ + struct rb_node *node = eht_set->entry_tree.rb_node; + + while (node) { + struct net_bridge_group_eht_set_entry *this; + int result; + + this = rb_entry(node, struct net_bridge_group_eht_set_entry, + rb_node); + result = memcmp(h_addr, &this->h_addr, sizeof(*h_addr)); + if (result < 0) + node = node->rb_left; + else if (result > 0) + node = node->rb_right; + else + return this; + } + + return NULL; +} + +static struct net_bridge_group_eht_set * +br_multicast_eht_set_lookup(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr) +{ + struct rb_node *node = pg->eht_set_tree.rb_node; + + while (node) { + struct net_bridge_group_eht_set *this; + int result; + + this = rb_entry(node, struct net_bridge_group_eht_set, + rb_node); + result = memcmp(src_addr, &this->src_addr, sizeof(*src_addr)); + if (result < 0) + node = node->rb_left; + else if (result > 0) + node = node->rb_right; + else + return this; + } + + return NULL; +} + +static void __eht_destroy_host(struct net_bridge_group_eht_host *eht_host) +{ + WARN_ON(!hlist_empty(&eht_host->set_entries)); + + br_multicast_eht_hosts_dec(eht_host->pg); + + rb_erase(&eht_host->rb_node, &eht_host->pg->eht_host_tree); + RB_CLEAR_NODE(&eht_host->rb_node); + kfree(eht_host); +} + +static void br_multicast_destroy_eht_set_entry(struct net_bridge_mcast_gc *gc) +{ + struct net_bridge_group_eht_set_entry *set_h; + + set_h = container_of(gc, struct net_bridge_group_eht_set_entry, mcast_gc); + WARN_ON(!RB_EMPTY_NODE(&set_h->rb_node)); + + del_timer_sync(&set_h->timer); + kfree(set_h); +} + +static void br_multicast_destroy_eht_set(struct net_bridge_mcast_gc *gc) +{ + struct net_bridge_group_eht_set *eht_set; + + eht_set = container_of(gc, struct net_bridge_group_eht_set, mcast_gc); + WARN_ON(!RB_EMPTY_NODE(&eht_set->rb_node)); + WARN_ON(!RB_EMPTY_ROOT(&eht_set->entry_tree)); + + del_timer_sync(&eht_set->timer); + kfree(eht_set); +} + +static void __eht_del_set_entry(struct net_bridge_group_eht_set_entry *set_h) +{ + struct net_bridge_group_eht_host *eht_host = set_h->h_parent; + union net_bridge_eht_addr zero_addr; + + rb_erase(&set_h->rb_node, &set_h->eht_set->entry_tree); + RB_CLEAR_NODE(&set_h->rb_node); + hlist_del_init(&set_h->host_list); + memset(&zero_addr, 0, sizeof(zero_addr)); + if (memcmp(&set_h->h_addr, &zero_addr, sizeof(zero_addr))) + eht_host->num_entries--; + hlist_add_head(&set_h->mcast_gc.gc_node, &set_h->br->mcast_gc_list); + queue_work(system_long_wq, &set_h->br->mcast_gc_work); + + if (hlist_empty(&eht_host->set_entries)) + __eht_destroy_host(eht_host); +} + +static void br_multicast_del_eht_set(struct net_bridge_group_eht_set *eht_set) +{ + struct net_bridge_group_eht_set_entry *set_h; + struct 
rb_node *node; + + while ((node = rb_first(&eht_set->entry_tree))) { + set_h = rb_entry(node, struct net_bridge_group_eht_set_entry, + rb_node); + __eht_del_set_entry(set_h); + } + + rb_erase(&eht_set->rb_node, &eht_set->pg->eht_set_tree); + RB_CLEAR_NODE(&eht_set->rb_node); + hlist_add_head(&eht_set->mcast_gc.gc_node, &eht_set->br->mcast_gc_list); + queue_work(system_long_wq, &eht_set->br->mcast_gc_work); +} + +void br_multicast_eht_clean_sets(struct net_bridge_port_group *pg) +{ + struct net_bridge_group_eht_set *eht_set; + struct rb_node *node; + + while ((node = rb_first(&pg->eht_set_tree))) { + eht_set = rb_entry(node, struct net_bridge_group_eht_set, + rb_node); + br_multicast_del_eht_set(eht_set); + } +} + +static void br_multicast_eht_set_entry_expired(struct timer_list *t) +{ + struct net_bridge_group_eht_set_entry *set_h = from_timer(set_h, t, timer); + struct net_bridge *br = set_h->br; + + spin_lock(&br->multicast_lock); + if (RB_EMPTY_NODE(&set_h->rb_node) || timer_pending(&set_h->timer)) + goto out; + + br_multicast_del_eht_set_entry(set_h->eht_set->pg, + &set_h->eht_set->src_addr, + &set_h->h_addr); +out: + spin_unlock(&br->multicast_lock); +} + +static void br_multicast_eht_set_expired(struct timer_list *t) +{ + struct net_bridge_group_eht_set *eht_set = from_timer(eht_set, t, + timer); + struct net_bridge *br = eht_set->br; + + spin_lock(&br->multicast_lock); + if (RB_EMPTY_NODE(&eht_set->rb_node) || timer_pending(&eht_set->timer)) + goto out; + + br_multicast_del_eht_set(eht_set); +out: + spin_unlock(&br->multicast_lock); +} + +static struct net_bridge_group_eht_host * +__eht_lookup_create_host(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + unsigned char filter_mode) +{ + struct rb_node **link = &pg->eht_host_tree.rb_node, *parent = NULL; + struct net_bridge_group_eht_host *eht_host; + + while (*link) { + struct net_bridge_group_eht_host *this; + int result; + + this = rb_entry(*link, struct net_bridge_group_eht_host, + rb_node); + result = memcmp(h_addr, &this->h_addr, sizeof(*h_addr)); + parent = *link; + if (result < 0) + link = &((*link)->rb_left); + else if (result > 0) + link = &((*link)->rb_right); + else + return this; + } + + if (br_multicast_eht_hosts_over_limit(pg)) + return NULL; + + eht_host = kzalloc(sizeof(*eht_host), GFP_ATOMIC); + if (!eht_host) + return NULL; + + memcpy(&eht_host->h_addr, h_addr, sizeof(*h_addr)); + INIT_HLIST_HEAD(&eht_host->set_entries); + eht_host->pg = pg; + eht_host->filter_mode = filter_mode; + + rb_link_node(&eht_host->rb_node, parent, link); + rb_insert_color(&eht_host->rb_node, &pg->eht_host_tree); + + br_multicast_eht_hosts_inc(pg); + + return eht_host; +} + +static struct net_bridge_group_eht_set_entry * +__eht_lookup_create_set_entry(struct net_bridge *br, + struct net_bridge_group_eht_set *eht_set, + struct net_bridge_group_eht_host *eht_host, + bool allow_zero_src) +{ + struct rb_node **link = &eht_set->entry_tree.rb_node, *parent = NULL; + struct net_bridge_group_eht_set_entry *set_h; + + while (*link) { + struct net_bridge_group_eht_set_entry *this; + int result; + + this = rb_entry(*link, struct net_bridge_group_eht_set_entry, + rb_node); + result = memcmp(&eht_host->h_addr, &this->h_addr, + sizeof(union net_bridge_eht_addr)); + parent = *link; + if (result < 0) + link = &((*link)->rb_left); + else if (result > 0) + link = &((*link)->rb_right); + else + return this; + } + + /* always allow auto-created zero entry */ + if (!allow_zero_src && eht_host->num_entries >= PG_SRC_ENT_LIMIT) + return 
NULL; + + set_h = kzalloc(sizeof(*set_h), GFP_ATOMIC); + if (!set_h) + return NULL; + + memcpy(&set_h->h_addr, &eht_host->h_addr, + sizeof(union net_bridge_eht_addr)); + set_h->mcast_gc.destroy = br_multicast_destroy_eht_set_entry; + set_h->eht_set = eht_set; + set_h->h_parent = eht_host; + set_h->br = br; + timer_setup(&set_h->timer, br_multicast_eht_set_entry_expired, 0); + + hlist_add_head(&set_h->host_list, &eht_host->set_entries); + rb_link_node(&set_h->rb_node, parent, link); + rb_insert_color(&set_h->rb_node, &eht_set->entry_tree); + /* we must not count the auto-created zero entry otherwise we won't be + * able to track the full list of PG_SRC_ENT_LIMIT entries + */ + if (!allow_zero_src) + eht_host->num_entries++; + + return set_h; +} + +static struct net_bridge_group_eht_set * +__eht_lookup_create_set(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr) +{ + struct rb_node **link = &pg->eht_set_tree.rb_node, *parent = NULL; + struct net_bridge_group_eht_set *eht_set; + + while (*link) { + struct net_bridge_group_eht_set *this; + int result; + + this = rb_entry(*link, struct net_bridge_group_eht_set, + rb_node); + result = memcmp(src_addr, &this->src_addr, sizeof(*src_addr)); + parent = *link; + if (result < 0) + link = &((*link)->rb_left); + else if (result > 0) + link = &((*link)->rb_right); + else + return this; + } + + eht_set = kzalloc(sizeof(*eht_set), GFP_ATOMIC); + if (!eht_set) + return NULL; + + memcpy(&eht_set->src_addr, src_addr, sizeof(*src_addr)); + eht_set->mcast_gc.destroy = br_multicast_destroy_eht_set; + eht_set->pg = pg; + eht_set->br = pg->key.port->br; + eht_set->entry_tree = RB_ROOT; + timer_setup(&eht_set->timer, br_multicast_eht_set_expired, 0); + + rb_link_node(&eht_set->rb_node, parent, link); + rb_insert_color(&eht_set->rb_node, &pg->eht_set_tree); + + return eht_set; +} + +static void br_multicast_ip_src_to_eht_addr(const struct br_ip *src, + union net_bridge_eht_addr *dest) +{ + switch (src->proto) { + case htons(ETH_P_IP): + dest->ip4 = src->src.ip4; + break; +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + memcpy(&dest->ip6, &src->src.ip6, sizeof(struct in6_addr)); + break; +#endif + } +} + +static void br_eht_convert_host_filter_mode(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + int filter_mode) +{ + struct net_bridge_group_eht_host *eht_host; + union net_bridge_eht_addr zero_addr; + + eht_host = br_multicast_eht_host_lookup(pg, h_addr); + if (eht_host) + eht_host->filter_mode = filter_mode; + + memset(&zero_addr, 0, sizeof(zero_addr)); + switch (filter_mode) { + case MCAST_INCLUDE: + br_multicast_del_eht_set_entry(pg, &zero_addr, h_addr); + break; + case MCAST_EXCLUDE: + br_multicast_create_eht_set_entry(brmctx, pg, &zero_addr, + h_addr, MCAST_EXCLUDE, + true); + break; + } +} + +static void br_multicast_create_eht_set_entry(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr, + union net_bridge_eht_addr *h_addr, + int filter_mode, + bool allow_zero_src) +{ + struct net_bridge_group_eht_set_entry *set_h; + struct net_bridge_group_eht_host *eht_host; + struct net_bridge *br = pg->key.port->br; + struct net_bridge_group_eht_set *eht_set; + union net_bridge_eht_addr zero_addr; + + memset(&zero_addr, 0, sizeof(zero_addr)); + if (!allow_zero_src && !memcmp(src_addr, &zero_addr, sizeof(zero_addr))) + return; + + eht_set = __eht_lookup_create_set(pg, src_addr); + if (!eht_set) + return; + + eht_host = 
__eht_lookup_create_host(pg, h_addr, filter_mode); + if (!eht_host) + goto fail_host; + + set_h = __eht_lookup_create_set_entry(br, eht_set, eht_host, + allow_zero_src); + if (!set_h) + goto fail_set_entry; + + mod_timer(&set_h->timer, jiffies + br_multicast_gmi(brmctx)); + mod_timer(&eht_set->timer, jiffies + br_multicast_gmi(brmctx)); + + return; + +fail_set_entry: + if (hlist_empty(&eht_host->set_entries)) + __eht_destroy_host(eht_host); +fail_host: + if (RB_EMPTY_ROOT(&eht_set->entry_tree)) + br_multicast_del_eht_set(eht_set); +} + +static bool br_multicast_del_eht_set_entry(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *src_addr, + union net_bridge_eht_addr *h_addr) +{ + struct net_bridge_group_eht_set_entry *set_h; + struct net_bridge_group_eht_set *eht_set; + bool set_deleted = false; + + eht_set = br_multicast_eht_set_lookup(pg, src_addr); + if (!eht_set) + goto out; + + set_h = br_multicast_eht_set_entry_lookup(eht_set, h_addr); + if (!set_h) + goto out; + + __eht_del_set_entry(set_h); + + if (RB_EMPTY_ROOT(&eht_set->entry_tree)) { + br_multicast_del_eht_set(eht_set); + set_deleted = true; + } + +out: + return set_deleted; +} + +static void br_multicast_del_eht_host(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr) +{ + struct net_bridge_group_eht_set_entry *set_h; + struct net_bridge_group_eht_host *eht_host; + struct hlist_node *tmp; + + eht_host = br_multicast_eht_host_lookup(pg, h_addr); + if (!eht_host) + return; + + hlist_for_each_entry_safe(set_h, tmp, &eht_host->set_entries, host_list) + br_multicast_del_eht_set_entry(set_h->eht_set->pg, + &set_h->eht_set->src_addr, + &set_h->h_addr); +} + +/* create new set entries from reports */ +static void __eht_create_set_entries(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + int filter_mode) +{ + union net_bridge_eht_addr eht_src_addr; + u32 src_idx; + + memset(&eht_src_addr, 0, sizeof(eht_src_addr)); + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&eht_src_addr, srcs + (src_idx * addr_size), addr_size); + br_multicast_create_eht_set_entry(brmctx, pg, &eht_src_addr, + h_addr, filter_mode, + false); + } +} + +/* delete existing set entries and their (S,G) entries if they were the last */ +static bool __eht_del_set_entries(struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size) +{ + union net_bridge_eht_addr eht_src_addr; + struct net_bridge_group_src *src_ent; + bool changed = false; + struct br_ip src_ip; + u32 src_idx; + + memset(&eht_src_addr, 0, sizeof(eht_src_addr)); + memset(&src_ip, 0, sizeof(src_ip)); + src_ip.proto = pg->key.addr.proto; + for (src_idx = 0; src_idx < nsrcs; src_idx++) { + memcpy(&eht_src_addr, srcs + (src_idx * addr_size), addr_size); + if (!br_multicast_del_eht_set_entry(pg, &eht_src_addr, h_addr)) + continue; + memcpy(&src_ip, srcs + (src_idx * addr_size), addr_size); + src_ent = br_multicast_find_group_src(pg, &src_ip); + if (!src_ent) + continue; + br_multicast_del_group_src(src_ent, true); + changed = true; + } + + return changed; +} + +static bool br_multicast_eht_allow(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size) +{ + bool changed = false; + + switch (br_multicast_eht_host_filter_mode(pg, h_addr)) { + case MCAST_INCLUDE: + __eht_create_set_entries(brmctx, pg, h_addr, 
srcs, nsrcs, + addr_size, MCAST_INCLUDE); + break; + case MCAST_EXCLUDE: + changed = __eht_del_set_entries(pg, h_addr, srcs, nsrcs, + addr_size); + break; + } + + return changed; +} + +static bool br_multicast_eht_block(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size) +{ + bool changed = false; + + switch (br_multicast_eht_host_filter_mode(pg, h_addr)) { + case MCAST_INCLUDE: + changed = __eht_del_set_entries(pg, h_addr, srcs, nsrcs, + addr_size); + break; + case MCAST_EXCLUDE: + __eht_create_set_entries(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + MCAST_EXCLUDE); + break; + } + + return changed; +} + +/* flush_entries is true when changing mode */ +static bool __eht_inc_exc(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + unsigned char filter_mode, + bool to_report) +{ + bool changed = false, flush_entries = to_report; + union net_bridge_eht_addr eht_src_addr; + + if (br_multicast_eht_host_filter_mode(pg, h_addr) != filter_mode) + flush_entries = true; + + memset(&eht_src_addr, 0, sizeof(eht_src_addr)); + /* if we're changing mode del host and its entries */ + if (flush_entries) + br_multicast_del_eht_host(pg, h_addr); + __eht_create_set_entries(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + filter_mode); + /* we can be missing sets only if we've deleted some entries */ + if (flush_entries) { + struct net_bridge_group_eht_set *eht_set; + struct net_bridge_group_src *src_ent; + struct hlist_node *tmp; + + hlist_for_each_entry_safe(src_ent, tmp, &pg->src_list, node) { + br_multicast_ip_src_to_eht_addr(&src_ent->addr, + &eht_src_addr); + if (!br_multicast_eht_set_lookup(pg, &eht_src_addr)) { + br_multicast_del_group_src(src_ent, true); + changed = true; + continue; + } + /* this is an optimization for TO_INCLUDE where we lower + * the set's timeout to LMQT to catch timeout hosts: + * - host A (timing out): set entries X, Y + * - host B: set entry Z (new from current TO_INCLUDE) + * sends BLOCK Z after LMQT but host A's EHT + * entries still exist (unless lowered to LMQT + * so they can timeout with the S,Gs) + * => we wait another LMQT, when we can just delete the + * group immediately + */ + if (!(src_ent->flags & BR_SGRP_F_SEND) || + filter_mode != MCAST_INCLUDE || + !to_report) + continue; + eht_set = br_multicast_eht_set_lookup(pg, + &eht_src_addr); + if (!eht_set) + continue; + mod_timer(&eht_set->timer, jiffies + br_multicast_lmqt(brmctx)); + } + } + + return changed; +} + +static bool br_multicast_eht_inc(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + bool to_report) +{ + bool changed; + + changed = __eht_inc_exc(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + MCAST_INCLUDE, to_report); + br_eht_convert_host_filter_mode(brmctx, pg, h_addr, MCAST_INCLUDE); + + return changed; +} + +static bool br_multicast_eht_exc(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + bool to_report) +{ + bool changed; + + changed = __eht_inc_exc(brmctx, pg, h_addr, srcs, nsrcs, addr_size, + MCAST_EXCLUDE, to_report); + br_eht_convert_host_filter_mode(brmctx, pg, h_addr, MCAST_EXCLUDE); + + return changed; +} + +static bool __eht_ip4_handle(const struct net_bridge_mcast 
*brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + int grec_type) +{ + bool changed = false, to_report = false; + + switch (grec_type) { + case IGMPV3_ALLOW_NEW_SOURCES: + br_multicast_eht_allow(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(__be32)); + break; + case IGMPV3_BLOCK_OLD_SOURCES: + changed = br_multicast_eht_block(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(__be32)); + break; + case IGMPV3_CHANGE_TO_INCLUDE: + to_report = true; + fallthrough; + case IGMPV3_MODE_IS_INCLUDE: + changed = br_multicast_eht_inc(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(__be32), to_report); + break; + case IGMPV3_CHANGE_TO_EXCLUDE: + to_report = true; + fallthrough; + case IGMPV3_MODE_IS_EXCLUDE: + changed = br_multicast_eht_exc(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(__be32), to_report); + break; + } + + return changed; +} + +#if IS_ENABLED(CONFIG_IPV6) +static bool __eht_ip6_handle(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + union net_bridge_eht_addr *h_addr, + void *srcs, + u32 nsrcs, + int grec_type) +{ + bool changed = false, to_report = false; + + switch (grec_type) { + case MLD2_ALLOW_NEW_SOURCES: + br_multicast_eht_allow(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(struct in6_addr)); + break; + case MLD2_BLOCK_OLD_SOURCES: + changed = br_multicast_eht_block(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(struct in6_addr)); + break; + case MLD2_CHANGE_TO_INCLUDE: + to_report = true; + fallthrough; + case MLD2_MODE_IS_INCLUDE: + changed = br_multicast_eht_inc(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(struct in6_addr), + to_report); + break; + case MLD2_CHANGE_TO_EXCLUDE: + to_report = true; + fallthrough; + case MLD2_MODE_IS_EXCLUDE: + changed = br_multicast_eht_exc(brmctx, pg, h_addr, srcs, nsrcs, + sizeof(struct in6_addr), + to_report); + break; + } + + return changed; +} +#endif + +/* true means an entry was deleted */ +bool br_multicast_eht_handle(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + void *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + int grec_type) +{ + bool eht_enabled = !!(pg->key.port->flags & BR_MULTICAST_FAST_LEAVE); + union net_bridge_eht_addr eht_host_addr; + bool changed = false; + + if (!eht_enabled) + goto out; + + memset(&eht_host_addr, 0, sizeof(eht_host_addr)); + memcpy(&eht_host_addr, h_addr, addr_size); + if (addr_size == sizeof(__be32)) + changed = __eht_ip4_handle(brmctx, pg, &eht_host_addr, srcs, + nsrcs, grec_type); +#if IS_ENABLED(CONFIG_IPV6) + else + changed = __eht_ip6_handle(brmctx, pg, &eht_host_addr, srcs, + nsrcs, grec_type); +#endif + +out: + return changed; +} + +int br_multicast_eht_set_hosts_limit(struct net_bridge_port *p, + u32 eht_hosts_limit) +{ + struct net_bridge *br = p->br; + + if (!eht_hosts_limit) + return -EINVAL; + + spin_lock_bh(&br->multicast_lock); + p->multicast_eht_hosts_limit = eht_hosts_limit; + spin_unlock_bh(&br->multicast_lock); + + return 0; +} diff --git a/net/bridge/br_netfilter_hooks.c b/net/bridge/br_netfilter_hooks.c new file mode 100644 index 000000000..202ad43e3 --- /dev/null +++ b/net/bridge/br_netfilter_hooks.c @@ -0,0 +1,1247 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handle firewalling + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + * Bart De Schuymer + * + * Lennert dedicates this file to Kerstin Wurdinger. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include "br_private.h" +#ifdef CONFIG_SYSCTL +#include +#endif + +static unsigned int brnf_net_id __read_mostly; + +struct brnf_net { + bool enabled; + +#ifdef CONFIG_SYSCTL + struct ctl_table_header *ctl_hdr; +#endif + + /* default value is 1 */ + int call_iptables; + int call_ip6tables; + int call_arptables; + + /* default value is 0 */ + int filter_vlan_tagged; + int filter_pppoe_tagged; + int pass_vlan_indev; +}; + +#define IS_IP(skb) \ + (!skb_vlan_tag_present(skb) && skb->protocol == htons(ETH_P_IP)) + +#define IS_IPV6(skb) \ + (!skb_vlan_tag_present(skb) && skb->protocol == htons(ETH_P_IPV6)) + +#define IS_ARP(skb) \ + (!skb_vlan_tag_present(skb) && skb->protocol == htons(ETH_P_ARP)) + +static inline __be16 vlan_proto(const struct sk_buff *skb) +{ + if (skb_vlan_tag_present(skb)) + return skb->protocol; + else if (skb->protocol == htons(ETH_P_8021Q)) + return vlan_eth_hdr(skb)->h_vlan_encapsulated_proto; + else + return 0; +} + +static inline bool is_vlan_ip(const struct sk_buff *skb, const struct net *net) +{ + struct brnf_net *brnet = net_generic(net, brnf_net_id); + + return vlan_proto(skb) == htons(ETH_P_IP) && brnet->filter_vlan_tagged; +} + +static inline bool is_vlan_ipv6(const struct sk_buff *skb, + const struct net *net) +{ + struct brnf_net *brnet = net_generic(net, brnf_net_id); + + return vlan_proto(skb) == htons(ETH_P_IPV6) && + brnet->filter_vlan_tagged; +} + +static inline bool is_vlan_arp(const struct sk_buff *skb, const struct net *net) +{ + struct brnf_net *brnet = net_generic(net, brnf_net_id); + + return vlan_proto(skb) == htons(ETH_P_ARP) && brnet->filter_vlan_tagged; +} + +static inline __be16 pppoe_proto(const struct sk_buff *skb) +{ + return *((__be16 *)(skb_mac_header(skb) + ETH_HLEN + + sizeof(struct pppoe_hdr))); +} + +static inline bool is_pppoe_ip(const struct sk_buff *skb, const struct net *net) +{ + struct brnf_net *brnet = net_generic(net, brnf_net_id); + + return skb->protocol == htons(ETH_P_PPP_SES) && + pppoe_proto(skb) == htons(PPP_IP) && brnet->filter_pppoe_tagged; +} + +static inline bool is_pppoe_ipv6(const struct sk_buff *skb, + const struct net *net) +{ + struct brnf_net *brnet = net_generic(net, brnf_net_id); + + return skb->protocol == htons(ETH_P_PPP_SES) && + pppoe_proto(skb) == htons(PPP_IPV6) && + brnet->filter_pppoe_tagged; +} + +/* largest possible L2 header, see br_nf_dev_queue_xmit() */ +#define NF_BRIDGE_MAX_MAC_HEADER_LENGTH (PPPOE_SES_HLEN + ETH_HLEN) + +struct brnf_frag_data { + char mac[NF_BRIDGE_MAX_MAC_HEADER_LENGTH]; + u8 encap_size; + u8 size; + u16 vlan_tci; + __be16 vlan_proto; +}; + +static DEFINE_PER_CPU(struct brnf_frag_data, brnf_frag_data_storage); + +static void nf_bridge_info_free(struct sk_buff *skb) +{ + skb_ext_del(skb, SKB_EXT_BRIDGE_NF); +} + +static inline struct net_device *bridge_parent(const struct net_device *dev) +{ + struct net_bridge_port *port; + + port = br_port_get_rcu(dev); + return port ? 
port->br->dev : NULL; +} + +static inline struct nf_bridge_info *nf_bridge_unshare(struct sk_buff *skb) +{ + return skb_ext_add(skb, SKB_EXT_BRIDGE_NF); +} + +unsigned int nf_bridge_encap_header_len(const struct sk_buff *skb) +{ + switch (skb->protocol) { + case __cpu_to_be16(ETH_P_8021Q): + return VLAN_HLEN; + case __cpu_to_be16(ETH_P_PPP_SES): + return PPPOE_SES_HLEN; + default: + return 0; + } +} + +static inline void nf_bridge_pull_encap_header(struct sk_buff *skb) +{ + unsigned int len = nf_bridge_encap_header_len(skb); + + skb_pull(skb, len); + skb->network_header += len; +} + +static inline void nf_bridge_pull_encap_header_rcsum(struct sk_buff *skb) +{ + unsigned int len = nf_bridge_encap_header_len(skb); + + skb_pull_rcsum(skb, len); + skb->network_header += len; +} + +/* When handing a packet over to the IP layer + * check whether we have a skb that is in the + * expected format + */ + +static int br_validate_ipv4(struct net *net, struct sk_buff *skb) +{ + const struct iphdr *iph; + u32 len; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto inhdr_error; + + iph = ip_hdr(skb); + + /* Basic sanity checks */ + if (iph->ihl < 5 || iph->version != 4) + goto inhdr_error; + + if (!pskb_may_pull(skb, iph->ihl*4)) + goto inhdr_error; + + iph = ip_hdr(skb); + if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl))) + goto csum_error; + + len = ntohs(iph->tot_len); + if (skb->len < len) { + __IP_INC_STATS(net, IPSTATS_MIB_INTRUNCATEDPKTS); + goto drop; + } else if (len < (iph->ihl*4)) + goto inhdr_error; + + if (pskb_trim_rcsum(skb, len)) { + __IP_INC_STATS(net, IPSTATS_MIB_INDISCARDS); + goto drop; + } + + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + /* We should really parse IP options here but until + * somebody who actually uses IP options complains to + * us we'll just silently ignore the options because + * we're lazy! + */ + return 0; + +csum_error: + __IP_INC_STATS(net, IPSTATS_MIB_CSUMERRORS); +inhdr_error: + __IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); +drop: + return -1; +} + +void nf_bridge_update_protocol(struct sk_buff *skb) +{ + const struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + + switch (nf_bridge->orig_proto) { + case BRNF_PROTO_8021Q: + skb->protocol = htons(ETH_P_8021Q); + break; + case BRNF_PROTO_PPPOE: + skb->protocol = htons(ETH_P_PPP_SES); + break; + case BRNF_PROTO_UNCHANGED: + break; + } +} + +/* Obtain the correct destination MAC address, while preserving the original + * source MAC address. If we already know this address, we just copy it. If we + * don't, we use the neighbour framework to find out. In both cases, we make + * sure that br_handle_frame_finish() is called afterwards. + */ +int br_nf_pre_routing_finish_bridge(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct neighbour *neigh; + struct dst_entry *dst; + + skb->dev = bridge_parent(skb->dev); + if (!skb->dev) + goto free_skb; + dst = skb_dst(skb); + neigh = dst_neigh_lookup_skb(dst, skb); + if (neigh) { + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + int ret; + + if ((READ_ONCE(neigh->nud_state) & NUD_CONNECTED) && + READ_ONCE(neigh->hh.hh_len)) { + struct net_device *br_indev; + + br_indev = nf_bridge_get_physindev(skb, net); + if (!br_indev) { + neigh_release(neigh); + goto free_skb; + } + + neigh_hh_bridge(&neigh->hh, skb); + skb->dev = br_indev; + + ret = br_handle_frame_finish(net, sk, skb); + } else { + /* the neighbour function below overwrites the complete + * MAC header, so we save the Ethernet source address and + * protocol number. 
+ */ + skb_copy_from_linear_data_offset(skb, + -(ETH_HLEN-ETH_ALEN), + nf_bridge->neigh_header, + ETH_HLEN-ETH_ALEN); + /* tell br_dev_xmit to continue with forwarding */ + nf_bridge->bridged_dnat = 1; + /* FIXME Need to refragment */ + ret = READ_ONCE(neigh->output)(neigh, skb); + } + neigh_release(neigh); + return ret; + } +free_skb: + kfree_skb(skb); + return 0; +} + +static inline bool +br_nf_ipv4_daddr_was_changed(const struct sk_buff *skb, + const struct nf_bridge_info *nf_bridge) +{ + return ip_hdr(skb)->daddr != nf_bridge->ipv4_daddr; +} + +/* This requires some explaining. If DNAT has taken place, + * we will need to fix up the destination Ethernet address. + * This is also true when SNAT takes place (for the reply direction). + * + * There are two cases to consider: + * 1. The packet was DNAT'ed to a device in the same bridge + * port group as it was received on. We can still bridge + * the packet. + * 2. The packet was DNAT'ed to a different device, either + * a non-bridged device or another bridge port group. + * The packet will need to be routed. + * + * The correct way of distinguishing between these two cases is to + * call ip_route_input() and to look at skb->dst->dev, which is + * changed to the destination device if ip_route_input() succeeds. + * + * Let's first consider the case that ip_route_input() succeeds: + * + * If the output device equals the logical bridge device the packet + * came in on, we can consider this bridging. The corresponding MAC + * address will be obtained in br_nf_pre_routing_finish_bridge. + * Otherwise, the packet is considered to be routed and we just + * change the destination MAC address so that the packet will + * later be passed up to the IP stack to be routed. For a redirected + * packet, ip_route_input() will give back the localhost as output device, + * which differs from the bridge device. + * + * Let's now consider the case that ip_route_input() fails: + * + * This can be because the destination address is martian, in which case + * the packet will be dropped. + * If IP forwarding is disabled, ip_route_input() will fail, while + * ip_route_output_key() can return success. The source + * address for ip_route_output_key() is set to zero, so ip_route_output_key() + * thinks we're handling a locally generated packet and won't care + * if IP forwarding is enabled. If the output device equals the logical bridge + * device, we proceed as if ip_route_input() succeeded. If it differs from the + * logical bridge port or if ip_route_output_key() fails we drop the packet. + */ +static int br_nf_pre_routing_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev, *br_indev; + struct iphdr *iph = ip_hdr(skb); + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + struct rtable *rt; + int err; + + br_indev = nf_bridge_get_physindev(skb, net); + if (!br_indev) { + kfree_skb(skb); + return 0; + } + + nf_bridge->frag_max_size = IPCB(skb)->frag_max_size; + + if (nf_bridge->pkt_otherhost) { + skb->pkt_type = PACKET_OTHERHOST; + nf_bridge->pkt_otherhost = false; + } + nf_bridge->in_prerouting = 0; + if (br_nf_ipv4_daddr_was_changed(skb, nf_bridge)) { + if ((err = ip_route_input(skb, iph->daddr, iph->saddr, iph->tos, dev))) { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + /* If err equals -EHOSTUNREACH the error is due to a + * martian destination or due to the fact that + * forwarding is disabled. For most martian packets, + * ip_route_output_key() will fail. 
It won't fail for 2 types of
+			 * martian destinations: loopback destinations and destination
+			 * 0.0.0.0. In both cases the packet will be dropped because the
+			 * destination is the loopback device and not the bridge. */
+			if (err != -EHOSTUNREACH || !in_dev || IN_DEV_FORWARD(in_dev))
+				goto free_skb;
+
+			rt = ip_route_output(net, iph->daddr, 0,
+					     RT_TOS(iph->tos), 0);
+			if (!IS_ERR(rt)) {
+				/* - Bridged-and-DNAT'ed traffic doesn't
+				 * require ip_forwarding. */
+				if (rt->dst.dev == dev) {
+					skb_dst_drop(skb);
+					skb_dst_set(skb, &rt->dst);
+					goto bridged_dnat;
+				}
+				ip_rt_put(rt);
+			}
+free_skb:
+			kfree_skb(skb);
+			return 0;
+		} else {
+			if (skb_dst(skb)->dev == dev) {
+bridged_dnat:
+				skb->dev = br_indev;
+				nf_bridge_update_protocol(skb);
+				nf_bridge_push_encap_header(skb);
+				br_nf_hook_thresh(NF_BR_PRE_ROUTING,
+						  net, sk, skb, skb->dev,
+						  NULL,
+						  br_nf_pre_routing_finish_bridge);
+				return 0;
+			}
+			ether_addr_copy(eth_hdr(skb)->h_dest, dev->dev_addr);
+			skb->pkt_type = PACKET_HOST;
+		}
+	} else {
+		rt = bridge_parent_rtable(br_indev);
+		if (!rt) {
+			kfree_skb(skb);
+			return 0;
+		}
+		skb_dst_drop(skb);
+		skb_dst_set_noref(skb, &rt->dst);
+	}
+
+	skb->dev = br_indev;
+	nf_bridge_update_protocol(skb);
+	nf_bridge_push_encap_header(skb);
+	br_nf_hook_thresh(NF_BR_PRE_ROUTING, net, sk, skb, skb->dev, NULL,
+			  br_handle_frame_finish);
+	return 0;
+}
+
+static struct net_device *brnf_get_logical_dev(struct sk_buff *skb,
+					       const struct net_device *dev,
+					       const struct net *net)
+{
+	struct net_device *vlan, *br;
+	struct brnf_net *brnet = net_generic(net, brnf_net_id);
+
+	br = bridge_parent(dev);
+
+	if (brnet->pass_vlan_indev == 0 || !skb_vlan_tag_present(skb))
+		return br;
+
+	vlan = __vlan_find_dev_deep_rcu(br, skb->vlan_proto,
+					skb_vlan_tag_get(skb) & VLAN_VID_MASK);
+
+	return vlan ? vlan : br;
+}
+
+/* Some common code for IPv4/IPv6 */
+struct net_device *setup_pre_routing(struct sk_buff *skb, const struct net *net)
+{
+	struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb);
+
+	if (skb->pkt_type == PACKET_OTHERHOST) {
+		skb->pkt_type = PACKET_HOST;
+		nf_bridge->pkt_otherhost = true;
+	}
+
+	nf_bridge->in_prerouting = 1;
+	nf_bridge->physinif = skb->dev->ifindex;
+	skb->dev = brnf_get_logical_dev(skb, skb->dev, net);
+
+	if (skb->protocol == htons(ETH_P_8021Q))
+		nf_bridge->orig_proto = BRNF_PROTO_8021Q;
+	else if (skb->protocol == htons(ETH_P_PPP_SES))
+		nf_bridge->orig_proto = BRNF_PROTO_PPPOE;
+
+	/* Must drop socket now because of tproxy. */
+	skb_orphan(skb);
+	return skb->dev;
+}
+
+/* Direct IPv6 traffic to br_nf_pre_routing_ipv6.
+ * Replicate the checks that IPv4 does on packet reception.
+ * Set skb->dev to the bridge device (i.e. parent of the
+ * receiving device) to make netfilter happy, the REDIRECT
+ * target in particular. Save the original destination IP
+ * address to be able to detect DNAT afterwards.
*/ +static unsigned int br_nf_pre_routing(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_bridge_info *nf_bridge; + struct net_bridge_port *p; + struct net_bridge *br; + __u32 len = nf_bridge_encap_header_len(skb); + struct brnf_net *brnet; + + if (unlikely(!pskb_may_pull(skb, len))) + return NF_DROP; + + p = br_port_get_rcu(state->in); + if (p == NULL) + return NF_DROP; + br = p->br; + + brnet = net_generic(state->net, brnf_net_id); + if (IS_IPV6(skb) || is_vlan_ipv6(skb, state->net) || + is_pppoe_ipv6(skb, state->net)) { + if (!brnet->call_ip6tables && + !br_opt_get(br, BROPT_NF_CALL_IP6TABLES)) + return NF_ACCEPT; + if (!ipv6_mod_enabled()) { + pr_warn_once("Module ipv6 is disabled, so call_ip6tables is not supported."); + return NF_DROP; + } + + nf_bridge_pull_encap_header_rcsum(skb); + return br_nf_pre_routing_ipv6(priv, skb, state); + } + + if (!brnet->call_iptables && !br_opt_get(br, BROPT_NF_CALL_IPTABLES)) + return NF_ACCEPT; + + if (!IS_IP(skb) && !is_vlan_ip(skb, state->net) && + !is_pppoe_ip(skb, state->net)) + return NF_ACCEPT; + + nf_bridge_pull_encap_header_rcsum(skb); + + if (br_validate_ipv4(state->net, skb)) + return NF_DROP; + + if (!nf_bridge_alloc(skb)) + return NF_DROP; + if (!setup_pre_routing(skb, state->net)) + return NF_DROP; + + nf_bridge = nf_bridge_info_get(skb); + nf_bridge->ipv4_daddr = ip_hdr(skb)->daddr; + + skb->protocol = htons(ETH_P_IP); + skb->transport_header = skb->network_header + ip_hdr(skb)->ihl * 4; + + NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, state->net, state->sk, skb, + skb->dev, NULL, + br_nf_pre_routing_finish); + + return NF_STOLEN; +} + + +/* PF_BRIDGE/FORWARD *************************************************/ +static int br_nf_forward_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + struct net_device *in; + + if (!IS_ARP(skb) && !is_vlan_arp(skb, net)) { + + if (skb->protocol == htons(ETH_P_IP)) + nf_bridge->frag_max_size = IPCB(skb)->frag_max_size; + + if (skb->protocol == htons(ETH_P_IPV6)) + nf_bridge->frag_max_size = IP6CB(skb)->frag_max_size; + + in = nf_bridge_get_physindev(skb, net); + if (!in) { + kfree_skb(skb); + return 0; + } + if (nf_bridge->pkt_otherhost) { + skb->pkt_type = PACKET_OTHERHOST; + nf_bridge->pkt_otherhost = false; + } + nf_bridge_update_protocol(skb); + } else { + in = *((struct net_device **)(skb->cb)); + } + nf_bridge_push_encap_header(skb); + + br_nf_hook_thresh(NF_BR_FORWARD, net, sk, skb, in, skb->dev, + br_forward_finish); + return 0; +} + + +/* This is the 'purely bridged' case. For IP, we pass the packet to + * netfilter with indev and outdev set to the bridge device, + * but we are still able to filter on the 'real' indev/outdev + * because of the physdev module. For ARP, indev and outdev are the + * bridge ports. */ +static unsigned int br_nf_forward_ip(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_bridge_info *nf_bridge; + struct net_device *parent; + u_int8_t pf; + + nf_bridge = nf_bridge_info_get(skb); + if (!nf_bridge) + return NF_ACCEPT; + + /* Need exclusive nf_bridge_info since we might have multiple + * different physoutdevs. 
*/ + if (!nf_bridge_unshare(skb)) + return NF_DROP; + + nf_bridge = nf_bridge_info_get(skb); + if (!nf_bridge) + return NF_DROP; + + parent = bridge_parent(state->out); + if (!parent) + return NF_DROP; + + if (IS_IP(skb) || is_vlan_ip(skb, state->net) || + is_pppoe_ip(skb, state->net)) + pf = NFPROTO_IPV4; + else if (IS_IPV6(skb) || is_vlan_ipv6(skb, state->net) || + is_pppoe_ipv6(skb, state->net)) + pf = NFPROTO_IPV6; + else + return NF_ACCEPT; + + nf_bridge_pull_encap_header(skb); + + if (skb->pkt_type == PACKET_OTHERHOST) { + skb->pkt_type = PACKET_HOST; + nf_bridge->pkt_otherhost = true; + } + + if (pf == NFPROTO_IPV4) { + if (br_validate_ipv4(state->net, skb)) + return NF_DROP; + IPCB(skb)->frag_max_size = nf_bridge->frag_max_size; + } + + if (pf == NFPROTO_IPV6) { + if (br_validate_ipv6(state->net, skb)) + return NF_DROP; + IP6CB(skb)->frag_max_size = nf_bridge->frag_max_size; + } + + nf_bridge->physoutdev = skb->dev; + if (pf == NFPROTO_IPV4) + skb->protocol = htons(ETH_P_IP); + else + skb->protocol = htons(ETH_P_IPV6); + + NF_HOOK(pf, NF_INET_FORWARD, state->net, NULL, skb, + brnf_get_logical_dev(skb, state->in, state->net), + parent, br_nf_forward_finish); + + return NF_STOLEN; +} + +static unsigned int br_nf_forward_arp(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct net_bridge_port *p; + struct net_bridge *br; + struct net_device **d = (struct net_device **)(skb->cb); + struct brnf_net *brnet; + + p = br_port_get_rcu(state->out); + if (p == NULL) + return NF_ACCEPT; + br = p->br; + + brnet = net_generic(state->net, brnf_net_id); + if (!brnet->call_arptables && !br_opt_get(br, BROPT_NF_CALL_ARPTABLES)) + return NF_ACCEPT; + + if (!IS_ARP(skb)) { + if (!is_vlan_arp(skb, state->net)) + return NF_ACCEPT; + nf_bridge_pull_encap_header(skb); + } + + if (unlikely(!pskb_may_pull(skb, sizeof(struct arphdr)))) + return NF_DROP; + + if (arp_hdr(skb)->ar_pln != 4) { + if (is_vlan_arp(skb, state->net)) + nf_bridge_push_encap_header(skb); + return NF_ACCEPT; + } + *d = state->in; + NF_HOOK(NFPROTO_ARP, NF_ARP_FORWARD, state->net, state->sk, skb, + state->in, state->out, br_nf_forward_finish); + + return NF_STOLEN; +} + +static int br_nf_push_frag_xmit(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct brnf_frag_data *data; + int err; + + data = this_cpu_ptr(&brnf_frag_data_storage); + err = skb_cow_head(skb, data->size); + + if (err) { + kfree_skb(skb); + return 0; + } + + if (data->vlan_proto) + __vlan_hwaccel_put_tag(skb, data->vlan_proto, data->vlan_tci); + + skb_copy_to_linear_data_offset(skb, -data->size, data->mac, data->size); + __skb_push(skb, data->encap_size); + + nf_bridge_info_free(skb); + return br_dev_queue_push_xmit(net, sk, skb); +} + +static int +br_nf_ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + int (*output)(struct net *, struct sock *, struct sk_buff *)) +{ + unsigned int mtu = ip_skb_dst_mtu(sk, skb); + struct iphdr *iph = ip_hdr(skb); + + if (unlikely(((iph->frag_off & htons(IP_DF)) && !skb->ignore_df) || + (IPCB(skb)->frag_max_size && + IPCB(skb)->frag_max_size > mtu))) { + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + kfree_skb(skb); + return -EMSGSIZE; + } + + return ip_do_fragment(net, sk, skb, output); +} + +static unsigned int nf_bridge_mtu_reduction(const struct sk_buff *skb) +{ + const struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + + if (nf_bridge->orig_proto == BRNF_PROTO_PPPOE) + return PPPOE_SES_HLEN; + return 0; +} + +static int br_nf_dev_queue_xmit(struct net *net, 
struct sock *sk, struct sk_buff *skb) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + unsigned int mtu, mtu_reserved; + + mtu_reserved = nf_bridge_mtu_reduction(skb); + mtu = skb->dev->mtu; + + if (nf_bridge->pkt_otherhost) { + skb->pkt_type = PACKET_OTHERHOST; + nf_bridge->pkt_otherhost = false; + } + + if (nf_bridge->frag_max_size && nf_bridge->frag_max_size < mtu) + mtu = nf_bridge->frag_max_size; + + nf_bridge_update_protocol(skb); + nf_bridge_push_encap_header(skb); + + if (skb_is_gso(skb) || skb->len + mtu_reserved <= mtu) { + nf_bridge_info_free(skb); + return br_dev_queue_push_xmit(net, sk, skb); + } + + /* This is wrong! We should preserve the original fragment + * boundaries by preserving frag_list rather than refragmenting. + */ + if (IS_ENABLED(CONFIG_NF_DEFRAG_IPV4) && + skb->protocol == htons(ETH_P_IP)) { + struct brnf_frag_data *data; + + if (br_validate_ipv4(net, skb)) + goto drop; + + IPCB(skb)->frag_max_size = nf_bridge->frag_max_size; + + data = this_cpu_ptr(&brnf_frag_data_storage); + + if (skb_vlan_tag_present(skb)) { + data->vlan_tci = skb->vlan_tci; + data->vlan_proto = skb->vlan_proto; + } else { + data->vlan_proto = 0; + } + + data->encap_size = nf_bridge_encap_header_len(skb); + data->size = ETH_HLEN + data->encap_size; + + skb_copy_from_linear_data_offset(skb, -data->size, data->mac, + data->size); + + return br_nf_ip_fragment(net, sk, skb, br_nf_push_frag_xmit); + } + if (IS_ENABLED(CONFIG_NF_DEFRAG_IPV6) && + skb->protocol == htons(ETH_P_IPV6)) { + const struct nf_ipv6_ops *v6ops = nf_get_ipv6_ops(); + struct brnf_frag_data *data; + + if (br_validate_ipv6(net, skb)) + goto drop; + + IP6CB(skb)->frag_max_size = nf_bridge->frag_max_size; + + data = this_cpu_ptr(&brnf_frag_data_storage); + data->encap_size = nf_bridge_encap_header_len(skb); + data->size = ETH_HLEN + data->encap_size; + + skb_copy_from_linear_data_offset(skb, -data->size, data->mac, + data->size); + + if (v6ops) + return v6ops->fragment(net, sk, skb, br_nf_push_frag_xmit); + + kfree_skb(skb); + return -EMSGSIZE; + } + nf_bridge_info_free(skb); + return br_dev_queue_push_xmit(net, sk, skb); + drop: + kfree_skb(skb); + return 0; +} + +/* PF_BRIDGE/POST_ROUTING ********************************************/ +static unsigned int br_nf_post_routing(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + struct net_device *realoutdev = bridge_parent(skb->dev); + u_int8_t pf; + + /* if nf_bridge is set, but ->physoutdev is NULL, this packet came in + * on a bridge, but was delivered locally and is now being routed: + * + * POST_ROUTING was already invoked from the ip stack. 
+ */ + if (!nf_bridge || !nf_bridge->physoutdev) + return NF_ACCEPT; + + if (!realoutdev) + return NF_DROP; + + if (IS_IP(skb) || is_vlan_ip(skb, state->net) || + is_pppoe_ip(skb, state->net)) + pf = NFPROTO_IPV4; + else if (IS_IPV6(skb) || is_vlan_ipv6(skb, state->net) || + is_pppoe_ipv6(skb, state->net)) + pf = NFPROTO_IPV6; + else + return NF_ACCEPT; + + if (skb->pkt_type == PACKET_OTHERHOST) { + skb->pkt_type = PACKET_HOST; + nf_bridge->pkt_otherhost = true; + } + + nf_bridge_pull_encap_header(skb); + if (pf == NFPROTO_IPV4) + skb->protocol = htons(ETH_P_IP); + else + skb->protocol = htons(ETH_P_IPV6); + + NF_HOOK(pf, NF_INET_POST_ROUTING, state->net, state->sk, skb, + NULL, realoutdev, + br_nf_dev_queue_xmit); + + return NF_STOLEN; +} + +/* IP/SABOTAGE *****************************************************/ +/* Don't hand locally destined packets to PF_INET(6)/PRE_ROUTING + * for the second time. */ +static unsigned int ip_sabotage_in(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + + if (nf_bridge) { + if (nf_bridge->sabotage_in_done) + return NF_ACCEPT; + + if (!nf_bridge->in_prerouting && + !netif_is_l3_master(skb->dev) && + !netif_is_l3_slave(skb->dev)) { + nf_bridge->sabotage_in_done = 1; + state->okfn(state->net, state->sk, skb); + return NF_STOLEN; + } + } + + return NF_ACCEPT; +} + +/* This is called when br_netfilter has called into iptables/netfilter, + * and DNAT has taken place on a bridge-forwarded packet. + * + * neigh->output has created a new MAC header, with local br0 MAC + * as saddr. + * + * This restores the original MAC saddr of the bridged packet + * before invoking bridge forward logic to transmit the packet. + */ +static void br_nf_pre_routing_finish_bridge_slow(struct sk_buff *skb) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + struct net_device *br_indev; + + br_indev = nf_bridge_get_physindev(skb, dev_net(skb->dev)); + if (!br_indev) { + kfree_skb(skb); + return; + } + + skb_pull(skb, ETH_HLEN); + nf_bridge->bridged_dnat = 0; + + BUILD_BUG_ON(sizeof(nf_bridge->neigh_header) != (ETH_HLEN - ETH_ALEN)); + + skb_copy_to_linear_data_offset(skb, -(ETH_HLEN - ETH_ALEN), + nf_bridge->neigh_header, + ETH_HLEN - ETH_ALEN); + skb->dev = br_indev; + + nf_bridge->physoutdev = NULL; + br_handle_frame_finish(dev_net(skb->dev), NULL, skb); +} + +static int br_nf_dev_xmit(struct sk_buff *skb) +{ + const struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + + if (nf_bridge && nf_bridge->bridged_dnat) { + br_nf_pre_routing_finish_bridge_slow(skb); + return 1; + } + return 0; +} + +static const struct nf_br_ops br_ops = { + .br_dev_xmit_hook = br_nf_dev_xmit, +}; + +/* For br_nf_post_routing, we need (prio = NF_BR_PRI_LAST), because + * br_dev_queue_push_xmit is called afterwards */ +static const struct nf_hook_ops br_nf_ops[] = { + { + .hook = br_nf_pre_routing, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_PRE_ROUTING, + .priority = NF_BR_PRI_BRNF, + }, + { + .hook = br_nf_forward_ip, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_FORWARD, + .priority = NF_BR_PRI_BRNF - 1, + }, + { + .hook = br_nf_forward_arp, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_FORWARD, + .priority = NF_BR_PRI_BRNF, + }, + { + .hook = br_nf_post_routing, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_POST_ROUTING, + .priority = NF_BR_PRI_LAST, + }, + { + .hook = ip_sabotage_in, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_FIRST, + }, + { + .hook = ip_sabotage_in, 
+		.pf = NFPROTO_IPV6,
+		.hooknum = NF_INET_PRE_ROUTING,
+		.priority = NF_IP6_PRI_FIRST,
+	},
+};
+
+static int brnf_device_event(struct notifier_block *unused, unsigned long event,
+			     void *ptr)
+{
+	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
+	struct brnf_net *brnet;
+	struct net *net;
+	int ret;
+
+	if (event != NETDEV_REGISTER || !netif_is_bridge_master(dev))
+		return NOTIFY_DONE;
+
+	ASSERT_RTNL();
+
+	net = dev_net(dev);
+	brnet = net_generic(net, brnf_net_id);
+	if (brnet->enabled)
+		return NOTIFY_OK;
+
+	ret = nf_register_net_hooks(net, br_nf_ops, ARRAY_SIZE(br_nf_ops));
+	if (ret)
+		return NOTIFY_BAD;
+
+	brnet->enabled = true;
+	return NOTIFY_OK;
+}
+
+static struct notifier_block brnf_notifier __read_mostly = {
+	.notifier_call = brnf_device_event,
+};
+
+/* recursively invokes nf_hook_slow (again), skipping already-called
+ * hooks (< NF_BR_PRI_BRNF).
+ *
+ * Called with rcu read lock held.
+ */
+int br_nf_hook_thresh(unsigned int hook, struct net *net,
+		      struct sock *sk, struct sk_buff *skb,
+		      struct net_device *indev,
+		      struct net_device *outdev,
+		      int (*okfn)(struct net *, struct sock *,
+				  struct sk_buff *))
+{
+	const struct nf_hook_entries *e;
+	struct nf_hook_state state;
+	struct nf_hook_ops **ops;
+	unsigned int i;
+	int ret;
+
+	e = rcu_dereference(net->nf.hooks_bridge[hook]);
+	if (!e)
+		return okfn(net, sk, skb);
+
+	ops = nf_hook_entries_get_hook_ops(e);
+	for (i = 0; i < e->num_hook_entries; i++) {
+		/* These hooks have already been called */
+		if (ops[i]->priority < NF_BR_PRI_BRNF)
+			continue;
+
+		/* These hooks have not been called yet, run them. */
+		if (ops[i]->priority > NF_BR_PRI_BRNF)
+			break;
+
+		/* take a closer look at NF_BR_PRI_BRNF. */
+		if (ops[i]->hook == br_nf_pre_routing) {
+			/* This hook diverted the skb to this function,
+			 * hooks after this have not been run yet.
+ */ + i++; + break; + } + } + + nf_hook_state_init(&state, hook, NFPROTO_BRIDGE, indev, outdev, + sk, net, okfn); + + ret = nf_hook_slow(skb, &state, e, i); + if (ret == 1) + ret = okfn(net, sk, skb); + + return ret; +} + +#ifdef CONFIG_SYSCTL +static +int brnf_sysctl_call_tables(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int ret; + + ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + + if (write && *(int *)(ctl->data)) + *(int *)(ctl->data) = 1; + return ret; +} + +static struct ctl_table brnf_table[] = { + { + .procname = "bridge-nf-call-arptables", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { + .procname = "bridge-nf-call-iptables", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { + .procname = "bridge-nf-call-ip6tables", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { + .procname = "bridge-nf-filter-vlan-tagged", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { + .procname = "bridge-nf-filter-pppoe-tagged", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { + .procname = "bridge-nf-pass-vlan-input-dev", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = brnf_sysctl_call_tables, + }, + { } +}; + +static inline void br_netfilter_sysctl_default(struct brnf_net *brnf) +{ + brnf->call_iptables = 1; + brnf->call_ip6tables = 1; + brnf->call_arptables = 1; + brnf->filter_vlan_tagged = 0; + brnf->filter_pppoe_tagged = 0; + brnf->pass_vlan_indev = 0; +} + +static int br_netfilter_sysctl_init_net(struct net *net) +{ + struct ctl_table *table = brnf_table; + struct brnf_net *brnet; + + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(brnf_table), GFP_KERNEL); + if (!table) + return -ENOMEM; + } + + brnet = net_generic(net, brnf_net_id); + table[0].data = &brnet->call_arptables; + table[1].data = &brnet->call_iptables; + table[2].data = &brnet->call_ip6tables; + table[3].data = &brnet->filter_vlan_tagged; + table[4].data = &brnet->filter_pppoe_tagged; + table[5].data = &brnet->pass_vlan_indev; + + br_netfilter_sysctl_default(brnet); + + brnet->ctl_hdr = register_net_sysctl(net, "net/bridge", table); + if (!brnet->ctl_hdr) { + if (!net_eq(net, &init_net)) + kfree(table); + + return -ENOMEM; + } + + return 0; +} + +static void br_netfilter_sysctl_exit_net(struct net *net, + struct brnf_net *brnet) +{ + struct ctl_table *table = brnet->ctl_hdr->ctl_table_arg; + + unregister_net_sysctl_table(brnet->ctl_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} + +static int __net_init brnf_init_net(struct net *net) +{ + return br_netfilter_sysctl_init_net(net); +} +#endif + +static void __net_exit brnf_exit_net(struct net *net) +{ + struct brnf_net *brnet; + + brnet = net_generic(net, brnf_net_id); + if (brnet->enabled) { + nf_unregister_net_hooks(net, br_nf_ops, ARRAY_SIZE(br_nf_ops)); + brnet->enabled = false; + } + +#ifdef CONFIG_SYSCTL + br_netfilter_sysctl_exit_net(net, brnet); +#endif +} + +static struct pernet_operations brnf_net_ops __read_mostly = { +#ifdef CONFIG_SYSCTL + .init = brnf_init_net, +#endif + .exit = brnf_exit_net, + .id = &brnf_net_id, + .size = sizeof(struct brnf_net), +}; + +static int __init br_netfilter_init(void) +{ + int ret; + + ret = register_pernet_subsys(&brnf_net_ops); + if (ret < 0) + return ret; + + ret = register_netdevice_notifier(&brnf_notifier); + if (ret < 0) { + 
unregister_pernet_subsys(&brnf_net_ops); + return ret; + } + + RCU_INIT_POINTER(nf_br_ops, &br_ops); + printk(KERN_NOTICE "Bridge firewalling registered\n"); + return 0; +} + +static void __exit br_netfilter_fini(void) +{ + RCU_INIT_POINTER(nf_br_ops, NULL); + unregister_netdevice_notifier(&brnf_notifier); + unregister_pernet_subsys(&brnf_net_ops); +} + +module_init(br_netfilter_init); +module_exit(br_netfilter_fini); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Lennert Buytenhek "); +MODULE_AUTHOR("Bart De Schuymer "); +MODULE_DESCRIPTION("Linux ethernet netfilter firewall bridge"); diff --git a/net/bridge/br_netfilter_ipv6.c b/net/bridge/br_netfilter_ipv6.c new file mode 100644 index 000000000..cd24ab9bb --- /dev/null +++ b/net/bridge/br_netfilter_ipv6.c @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handle firewalling + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + * Bart De Schuymer + * + * Lennert dedicates this file to Kerstin Wurdinger. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include "br_private.h" +#ifdef CONFIG_SYSCTL +#include +#endif + +/* We only check the length. A bridge shouldn't do any hop-by-hop stuff + * anyway + */ +static int br_nf_check_hbh_len(struct sk_buff *skb) +{ + unsigned char *raw = (u8 *)(ipv6_hdr(skb) + 1); + u32 pkt_len; + const unsigned char *nh = skb_network_header(skb); + int off = raw - nh; + int len = (raw[1] + 1) << 3; + + if ((raw + len) - skb->data > skb_headlen(skb)) + goto bad; + + off += 2; + len -= 2; + + while (len > 0) { + int optlen = nh[off + 1] + 2; + + switch (nh[off]) { + case IPV6_TLV_PAD1: + optlen = 1; + break; + + case IPV6_TLV_PADN: + break; + + case IPV6_TLV_JUMBO: + if (nh[off + 1] != 4 || (off & 3) != 2) + goto bad; + pkt_len = ntohl(*(__be32 *)(nh + off + 2)); + if (pkt_len <= IPV6_MAXPLEN || + ipv6_hdr(skb)->payload_len) + goto bad; + if (pkt_len > skb->len - sizeof(struct ipv6hdr)) + goto bad; + if (pskb_trim_rcsum(skb, + pkt_len + sizeof(struct ipv6hdr))) + goto bad; + nh = skb_network_header(skb); + break; + default: + if (optlen > len) + goto bad; + break; + } + off += optlen; + len -= optlen; + } + if (len == 0) + return 0; +bad: + return -1; +} + +int br_validate_ipv6(struct net *net, struct sk_buff *skb) +{ + const struct ipv6hdr *hdr; + struct inet6_dev *idev = __in6_dev_get(skb->dev); + u32 pkt_len; + u8 ip6h_len = sizeof(struct ipv6hdr); + + if (!pskb_may_pull(skb, ip6h_len)) + goto inhdr_error; + + if (skb->len < ip6h_len) + goto drop; + + hdr = ipv6_hdr(skb); + + if (hdr->version != 6) + goto inhdr_error; + + pkt_len = ntohs(hdr->payload_len); + + if (pkt_len || hdr->nexthdr != NEXTHDR_HOP) { + if (pkt_len + ip6h_len > skb->len) { + __IP6_INC_STATS(net, idev, + IPSTATS_MIB_INTRUNCATEDPKTS); + goto drop; + } + if (pskb_trim_rcsum(skb, pkt_len + ip6h_len)) { + __IP6_INC_STATS(net, idev, + IPSTATS_MIB_INDISCARDS); + goto drop; + } + hdr = ipv6_hdr(skb); + } + if (hdr->nexthdr == NEXTHDR_HOP && br_nf_check_hbh_len(skb)) + goto drop; + + memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm)); + /* No IP options in IPv6 header; however it should be + * checked if some next headers need special treatment + */ + return 0; + +inhdr_error: + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); +drop: + return -1; +} + +static inline bool +br_nf_ipv6_daddr_was_changed(const struct sk_buff 
*skb, + const struct nf_bridge_info *nf_bridge) +{ + return memcmp(&nf_bridge->ipv6_daddr, &ipv6_hdr(skb)->daddr, + sizeof(ipv6_hdr(skb)->daddr)) != 0; +} + +/* PF_BRIDGE/PRE_ROUTING: Undo the changes made for ip6tables + * PREROUTING and continue the bridge PRE_ROUTING hook. See comment + * for br_nf_pre_routing_finish(), same logic is used here but + * equivalent IPv6 function ip6_route_input() called indirectly. + */ +static int br_nf_pre_routing_finish_ipv6(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct nf_bridge_info *nf_bridge = nf_bridge_info_get(skb); + struct rtable *rt; + struct net_device *dev = skb->dev, *br_indev; + const struct nf_ipv6_ops *v6ops = nf_get_ipv6_ops(); + + br_indev = nf_bridge_get_physindev(skb, net); + if (!br_indev) { + kfree_skb(skb); + return 0; + } + + nf_bridge->frag_max_size = IP6CB(skb)->frag_max_size; + + if (nf_bridge->pkt_otherhost) { + skb->pkt_type = PACKET_OTHERHOST; + nf_bridge->pkt_otherhost = false; + } + nf_bridge->in_prerouting = 0; + if (br_nf_ipv6_daddr_was_changed(skb, nf_bridge)) { + skb_dst_drop(skb); + v6ops->route_input(skb); + + if (skb_dst(skb)->error) { + kfree_skb(skb); + return 0; + } + + if (skb_dst(skb)->dev == dev) { + skb->dev = br_indev; + nf_bridge_update_protocol(skb); + nf_bridge_push_encap_header(skb); + br_nf_hook_thresh(NF_BR_PRE_ROUTING, + net, sk, skb, skb->dev, NULL, + br_nf_pre_routing_finish_bridge); + return 0; + } + ether_addr_copy(eth_hdr(skb)->h_dest, dev->dev_addr); + skb->pkt_type = PACKET_HOST; + } else { + rt = bridge_parent_rtable(br_indev); + if (!rt) { + kfree_skb(skb); + return 0; + } + skb_dst_drop(skb); + skb_dst_set_noref(skb, &rt->dst); + } + + skb->dev = br_indev; + nf_bridge_update_protocol(skb); + nf_bridge_push_encap_header(skb); + br_nf_hook_thresh(NF_BR_PRE_ROUTING, net, sk, skb, + skb->dev, NULL, br_handle_frame_finish); + + return 0; +} + +/* Replicate the checks that IPv6 does on packet reception and pass the packet + * to ip6tables. 
+ */ +unsigned int br_nf_pre_routing_ipv6(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_bridge_info *nf_bridge; + + if (br_validate_ipv6(state->net, skb)) + return NF_DROP; + + nf_bridge = nf_bridge_alloc(skb); + if (!nf_bridge) + return NF_DROP; + if (!setup_pre_routing(skb, state->net)) + return NF_DROP; + + nf_bridge = nf_bridge_info_get(skb); + nf_bridge->ipv6_daddr = ipv6_hdr(skb)->daddr; + + skb->protocol = htons(ETH_P_IPV6); + skb->transport_header = skb->network_header + sizeof(struct ipv6hdr); + + NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, state->net, state->sk, skb, + skb->dev, NULL, + br_nf_pre_routing_finish_ipv6); + + return NF_STOLEN; +} diff --git a/net/bridge/br_netlink.c b/net/bridge/br_netlink.c new file mode 100644 index 000000000..d087fd4c7 --- /dev/null +++ b/net/bridge/br_netlink.c @@ -0,0 +1,1875 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Bridge netlink control interface + * + * Authors: + * Stephen Hemminger + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_stp.h" +#include "br_private_cfm.h" +#include "br_private_tunnel.h" +#include "br_private_mcast_eht.h" + +static int __get_num_vlan_infos(struct net_bridge_vlan_group *vg, + u32 filter_mask) +{ + struct net_bridge_vlan *v; + u16 vid_range_start = 0, vid_range_end = 0, vid_range_flags = 0; + u16 flags, pvid; + int num_vlans = 0; + + if (!(filter_mask & RTEXT_FILTER_BRVLAN_COMPRESSED)) + return 0; + + pvid = br_get_pvid(vg); + /* Count number of vlan infos */ + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + flags = 0; + /* only a context, bridge vlan not activated */ + if (!br_vlan_should_use(v)) + continue; + if (v->vid == pvid) + flags |= BRIDGE_VLAN_INFO_PVID; + + if (v->flags & BRIDGE_VLAN_INFO_UNTAGGED) + flags |= BRIDGE_VLAN_INFO_UNTAGGED; + + if (vid_range_start == 0) { + goto initvars; + } else if ((v->vid - vid_range_end) == 1 && + flags == vid_range_flags) { + vid_range_end = v->vid; + continue; + } else { + if ((vid_range_end - vid_range_start) > 0) + num_vlans += 2; + else + num_vlans += 1; + } +initvars: + vid_range_start = v->vid; + vid_range_end = v->vid; + vid_range_flags = flags; + } + + if (vid_range_start != 0) { + if ((vid_range_end - vid_range_start) > 0) + num_vlans += 2; + else + num_vlans += 1; + } + + return num_vlans; +} + +static int br_get_num_vlan_infos(struct net_bridge_vlan_group *vg, + u32 filter_mask) +{ + int num_vlans; + + if (!vg) + return 0; + + if (filter_mask & RTEXT_FILTER_BRVLAN) + return vg->num_vlans; + + rcu_read_lock(); + num_vlans = __get_num_vlan_infos(vg, filter_mask); + rcu_read_unlock(); + + return num_vlans; +} + +static size_t br_get_link_af_size_filtered(const struct net_device *dev, + u32 filter_mask) +{ + struct net_bridge_vlan_group *vg = NULL; + struct net_bridge_port *p = NULL; + struct net_bridge *br = NULL; + u32 num_cfm_peer_mep_infos; + u32 num_cfm_mep_infos; + size_t vinfo_sz = 0; + int num_vlan_infos; + + rcu_read_lock(); + if (netif_is_bridge_port(dev)) { + p = br_port_get_check_rcu(dev); + if (p) + vg = nbp_vlan_group_rcu(p); + } else if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + vg = br_vlan_group_rcu(br); + } + num_vlan_infos = br_get_num_vlan_infos(vg, filter_mask); + rcu_read_unlock(); + + if (p && (p->flags & BR_VLAN_TUNNEL)) + vinfo_sz += br_get_vlan_tunnel_info_size(vg); + + /* Each VLAN is returned in bridge_vlan_info along with flags */ + vinfo_sz += num_vlan_infos * 
nla_total_size(sizeof(struct bridge_vlan_info)); + + if (p && vg && (filter_mask & RTEXT_FILTER_MST)) + vinfo_sz += br_mst_info_size(vg); + + if (!(filter_mask & RTEXT_FILTER_CFM_STATUS)) + return vinfo_sz; + + if (!br) + return vinfo_sz; + + /* CFM status info must be added */ + br_cfm_mep_count(br, &num_cfm_mep_infos); + br_cfm_peer_mep_count(br, &num_cfm_peer_mep_infos); + + vinfo_sz += nla_total_size(0); /* IFLA_BRIDGE_CFM */ + /* For each status struct the MEP instance (u32) is added */ + /* MEP instance (u32) + br_cfm_mep_status */ + vinfo_sz += num_cfm_mep_infos * + /*IFLA_BRIDGE_CFM_MEP_STATUS_INSTANCE */ + (nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_MEP_STATUS_OPCODE_UNEXP_SEEN */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_MEP_STATUS_VERSION_UNEXP_SEEN */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_MEP_STATUS_RX_LEVEL_LOW_SEEN */ + + nla_total_size(sizeof(u32))); + /* MEP instance (u32) + br_cfm_cc_peer_status */ + vinfo_sz += num_cfm_peer_mep_infos * + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_INSTANCE */ + (nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_PEER_MEPID */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_CCM_DEFECT */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_RDI */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_PORT_TLV_VALUE */ + + nla_total_size(sizeof(u8)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_IF_TLV_VALUE */ + + nla_total_size(sizeof(u8)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_SEEN */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_TLV_SEEN */ + + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_CFM_CC_PEER_STATUS_SEQ_UNEXP_SEEN */ + + nla_total_size(sizeof(u32))); + + return vinfo_sz; +} + +static inline size_t br_port_info_size(void) +{ + return nla_total_size(1) /* IFLA_BRPORT_STATE */ + + nla_total_size(2) /* IFLA_BRPORT_PRIORITY */ + + nla_total_size(4) /* IFLA_BRPORT_COST */ + + nla_total_size(1) /* IFLA_BRPORT_MODE */ + + nla_total_size(1) /* IFLA_BRPORT_GUARD */ + + nla_total_size(1) /* IFLA_BRPORT_PROTECT */ + + nla_total_size(1) /* IFLA_BRPORT_FAST_LEAVE */ + + nla_total_size(1) /* IFLA_BRPORT_MCAST_TO_UCAST */ + + nla_total_size(1) /* IFLA_BRPORT_LEARNING */ + + nla_total_size(1) /* IFLA_BRPORT_UNICAST_FLOOD */ + + nla_total_size(1) /* IFLA_BRPORT_MCAST_FLOOD */ + + nla_total_size(1) /* IFLA_BRPORT_BCAST_FLOOD */ + + nla_total_size(1) /* IFLA_BRPORT_PROXYARP */ + + nla_total_size(1) /* IFLA_BRPORT_PROXYARP_WIFI */ + + nla_total_size(1) /* IFLA_BRPORT_VLAN_TUNNEL */ + + nla_total_size(1) /* IFLA_BRPORT_NEIGH_SUPPRESS */ + + nla_total_size(1) /* IFLA_BRPORT_ISOLATED */ + + nla_total_size(1) /* IFLA_BRPORT_LOCKED */ + + nla_total_size(sizeof(struct ifla_bridge_id)) /* IFLA_BRPORT_ROOT_ID */ + + nla_total_size(sizeof(struct ifla_bridge_id)) /* IFLA_BRPORT_BRIDGE_ID */ + + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_DESIGNATED_PORT */ + + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_DESIGNATED_COST */ + + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_ID */ + + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_NO */ + + nla_total_size(sizeof(u8)) /* IFLA_BRPORT_TOPOLOGY_CHANGE_ACK */ + + nla_total_size(sizeof(u8)) /* IFLA_BRPORT_CONFIG_PENDING */ + + nla_total_size_64bit(sizeof(u64)) /* IFLA_BRPORT_MESSAGE_AGE_TIMER */ + + nla_total_size_64bit(sizeof(u64)) /* IFLA_BRPORT_FORWARD_DELAY_TIMER */ + + nla_total_size_64bit(sizeof(u64)) /* IFLA_BRPORT_HOLD_TIMER */ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + + nla_total_size(sizeof(u8)) /* IFLA_BRPORT_MULTICAST_ROUTER */ 
+#endif + + nla_total_size(sizeof(u16)) /* IFLA_BRPORT_GROUP_FWD_MASK */ + + nla_total_size(sizeof(u8)) /* IFLA_BRPORT_MRP_RING_OPEN */ + + nla_total_size(sizeof(u8)) /* IFLA_BRPORT_MRP_IN_OPEN */ + + nla_total_size(sizeof(u32)) /* IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT */ + + nla_total_size(sizeof(u32)) /* IFLA_BRPORT_MCAST_EHT_HOSTS_CNT */ + + 0; +} + +static inline size_t br_nlmsg_size(struct net_device *dev, u32 filter_mask) +{ + return NLMSG_ALIGN(sizeof(struct ifinfomsg)) + + nla_total_size(IFNAMSIZ) /* IFLA_IFNAME */ + + nla_total_size(MAX_ADDR_LEN) /* IFLA_ADDRESS */ + + nla_total_size(4) /* IFLA_MASTER */ + + nla_total_size(4) /* IFLA_MTU */ + + nla_total_size(4) /* IFLA_LINK */ + + nla_total_size(1) /* IFLA_OPERSTATE */ + + nla_total_size(br_port_info_size()) /* IFLA_PROTINFO */ + + nla_total_size(br_get_link_af_size_filtered(dev, + filter_mask)) /* IFLA_AF_SPEC */ + + nla_total_size(4); /* IFLA_BRPORT_BACKUP_PORT */ +} + +static int br_port_fill_attrs(struct sk_buff *skb, + const struct net_bridge_port *p) +{ + u8 mode = !!(p->flags & BR_HAIRPIN_MODE); + struct net_bridge_port *backup_p; + u64 timerval; + + if (nla_put_u8(skb, IFLA_BRPORT_STATE, p->state) || + nla_put_u16(skb, IFLA_BRPORT_PRIORITY, p->priority) || + nla_put_u32(skb, IFLA_BRPORT_COST, p->path_cost) || + nla_put_u8(skb, IFLA_BRPORT_MODE, mode) || + nla_put_u8(skb, IFLA_BRPORT_GUARD, !!(p->flags & BR_BPDU_GUARD)) || + nla_put_u8(skb, IFLA_BRPORT_PROTECT, + !!(p->flags & BR_ROOT_BLOCK)) || + nla_put_u8(skb, IFLA_BRPORT_FAST_LEAVE, + !!(p->flags & BR_MULTICAST_FAST_LEAVE)) || + nla_put_u8(skb, IFLA_BRPORT_MCAST_TO_UCAST, + !!(p->flags & BR_MULTICAST_TO_UNICAST)) || + nla_put_u8(skb, IFLA_BRPORT_LEARNING, !!(p->flags & BR_LEARNING)) || + nla_put_u8(skb, IFLA_BRPORT_UNICAST_FLOOD, + !!(p->flags & BR_FLOOD)) || + nla_put_u8(skb, IFLA_BRPORT_MCAST_FLOOD, + !!(p->flags & BR_MCAST_FLOOD)) || + nla_put_u8(skb, IFLA_BRPORT_BCAST_FLOOD, + !!(p->flags & BR_BCAST_FLOOD)) || + nla_put_u8(skb, IFLA_BRPORT_PROXYARP, !!(p->flags & BR_PROXYARP)) || + nla_put_u8(skb, IFLA_BRPORT_PROXYARP_WIFI, + !!(p->flags & BR_PROXYARP_WIFI)) || + nla_put(skb, IFLA_BRPORT_ROOT_ID, sizeof(struct ifla_bridge_id), + &p->designated_root) || + nla_put(skb, IFLA_BRPORT_BRIDGE_ID, sizeof(struct ifla_bridge_id), + &p->designated_bridge) || + nla_put_u16(skb, IFLA_BRPORT_DESIGNATED_PORT, p->designated_port) || + nla_put_u16(skb, IFLA_BRPORT_DESIGNATED_COST, p->designated_cost) || + nla_put_u16(skb, IFLA_BRPORT_ID, p->port_id) || + nla_put_u16(skb, IFLA_BRPORT_NO, p->port_no) || + nla_put_u8(skb, IFLA_BRPORT_TOPOLOGY_CHANGE_ACK, + p->topology_change_ack) || + nla_put_u8(skb, IFLA_BRPORT_CONFIG_PENDING, p->config_pending) || + nla_put_u8(skb, IFLA_BRPORT_VLAN_TUNNEL, !!(p->flags & + BR_VLAN_TUNNEL)) || + nla_put_u16(skb, IFLA_BRPORT_GROUP_FWD_MASK, p->group_fwd_mask) || + nla_put_u8(skb, IFLA_BRPORT_NEIGH_SUPPRESS, + !!(p->flags & BR_NEIGH_SUPPRESS)) || + nla_put_u8(skb, IFLA_BRPORT_MRP_RING_OPEN, !!(p->flags & + BR_MRP_LOST_CONT)) || + nla_put_u8(skb, IFLA_BRPORT_MRP_IN_OPEN, + !!(p->flags & BR_MRP_LOST_IN_CONT)) || + nla_put_u8(skb, IFLA_BRPORT_ISOLATED, !!(p->flags & BR_ISOLATED)) || + nla_put_u8(skb, IFLA_BRPORT_LOCKED, !!(p->flags & BR_PORT_LOCKED))) + return -EMSGSIZE; + + timerval = br_timer_value(&p->message_age_timer); + if (nla_put_u64_64bit(skb, IFLA_BRPORT_MESSAGE_AGE_TIMER, timerval, + IFLA_BRPORT_PAD)) + return -EMSGSIZE; + timerval = br_timer_value(&p->forward_delay_timer); + if (nla_put_u64_64bit(skb, IFLA_BRPORT_FORWARD_DELAY_TIMER, timerval, + 
IFLA_BRPORT_PAD)) + return -EMSGSIZE; + timerval = br_timer_value(&p->hold_timer); + if (nla_put_u64_64bit(skb, IFLA_BRPORT_HOLD_TIMER, timerval, + IFLA_BRPORT_PAD)) + return -EMSGSIZE; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (nla_put_u8(skb, IFLA_BRPORT_MULTICAST_ROUTER, + p->multicast_ctx.multicast_router) || + nla_put_u32(skb, IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT, + p->multicast_eht_hosts_limit) || + nla_put_u32(skb, IFLA_BRPORT_MCAST_EHT_HOSTS_CNT, + p->multicast_eht_hosts_cnt)) + return -EMSGSIZE; +#endif + + /* we might be called only with br->lock */ + rcu_read_lock(); + backup_p = rcu_dereference(p->backup_port); + if (backup_p) + nla_put_u32(skb, IFLA_BRPORT_BACKUP_PORT, + backup_p->dev->ifindex); + rcu_read_unlock(); + + return 0; +} + +static int br_fill_ifvlaninfo_range(struct sk_buff *skb, u16 vid_start, + u16 vid_end, u16 flags) +{ + struct bridge_vlan_info vinfo; + + if ((vid_end - vid_start) > 0) { + /* add range to skb */ + vinfo.vid = vid_start; + vinfo.flags = flags | BRIDGE_VLAN_INFO_RANGE_BEGIN; + if (nla_put(skb, IFLA_BRIDGE_VLAN_INFO, + sizeof(vinfo), &vinfo)) + goto nla_put_failure; + + vinfo.vid = vid_end; + vinfo.flags = flags | BRIDGE_VLAN_INFO_RANGE_END; + if (nla_put(skb, IFLA_BRIDGE_VLAN_INFO, + sizeof(vinfo), &vinfo)) + goto nla_put_failure; + } else { + vinfo.vid = vid_start; + vinfo.flags = flags; + if (nla_put(skb, IFLA_BRIDGE_VLAN_INFO, + sizeof(vinfo), &vinfo)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int br_fill_ifvlaninfo_compressed(struct sk_buff *skb, + struct net_bridge_vlan_group *vg) +{ + struct net_bridge_vlan *v; + u16 vid_range_start = 0, vid_range_end = 0, vid_range_flags = 0; + u16 flags, pvid; + int err = 0; + + /* Pack IFLA_BRIDGE_VLAN_INFO's for every vlan + * and mark vlan info with begin and end flags + * if vlaninfo represents a range + */ + pvid = br_get_pvid(vg); + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + flags = 0; + if (!br_vlan_should_use(v)) + continue; + if (v->vid == pvid) + flags |= BRIDGE_VLAN_INFO_PVID; + + if (v->flags & BRIDGE_VLAN_INFO_UNTAGGED) + flags |= BRIDGE_VLAN_INFO_UNTAGGED; + + if (vid_range_start == 0) { + goto initvars; + } else if ((v->vid - vid_range_end) == 1 && + flags == vid_range_flags) { + vid_range_end = v->vid; + continue; + } else { + err = br_fill_ifvlaninfo_range(skb, vid_range_start, + vid_range_end, + vid_range_flags); + if (err) + return err; + } + +initvars: + vid_range_start = v->vid; + vid_range_end = v->vid; + vid_range_flags = flags; + } + + if (vid_range_start != 0) { + /* Call it once more to send any left over vlans */ + err = br_fill_ifvlaninfo_range(skb, vid_range_start, + vid_range_end, + vid_range_flags); + if (err) + return err; + } + + return 0; +} + +static int br_fill_ifvlaninfo(struct sk_buff *skb, + struct net_bridge_vlan_group *vg) +{ + struct bridge_vlan_info vinfo; + struct net_bridge_vlan *v; + u16 pvid; + + pvid = br_get_pvid(vg); + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + if (!br_vlan_should_use(v)) + continue; + + vinfo.vid = v->vid; + vinfo.flags = 0; + if (v->vid == pvid) + vinfo.flags |= BRIDGE_VLAN_INFO_PVID; + + if (v->flags & BRIDGE_VLAN_INFO_UNTAGGED) + vinfo.flags |= BRIDGE_VLAN_INFO_UNTAGGED; + + if (nla_put(skb, IFLA_BRIDGE_VLAN_INFO, + sizeof(vinfo), &vinfo)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +/* + * Create one netlink message for one interface + * Contains port and master info as well as carrier and bridge state. 
+ */ +static int br_fill_ifinfo(struct sk_buff *skb, + const struct net_bridge_port *port, + u32 pid, u32 seq, int event, unsigned int flags, + u32 filter_mask, const struct net_device *dev, + bool getlink) +{ + u8 operstate = netif_running(dev) ? dev->operstate : IF_OPER_DOWN; + struct nlattr *af = NULL; + struct net_bridge *br; + struct ifinfomsg *hdr; + struct nlmsghdr *nlh; + + if (port) + br = port->br; + else + br = netdev_priv(dev); + + br_debug(br, "br_fill_info event %d port %s master %s\n", + event, dev->name, br->dev->name); + + nlh = nlmsg_put(skb, pid, seq, event, sizeof(*hdr), flags); + if (nlh == NULL) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ifi_family = AF_BRIDGE; + hdr->__ifi_pad = 0; + hdr->ifi_type = dev->type; + hdr->ifi_index = dev->ifindex; + hdr->ifi_flags = dev_get_flags(dev); + hdr->ifi_change = 0; + + if (nla_put_string(skb, IFLA_IFNAME, dev->name) || + nla_put_u32(skb, IFLA_MASTER, br->dev->ifindex) || + nla_put_u32(skb, IFLA_MTU, dev->mtu) || + nla_put_u8(skb, IFLA_OPERSTATE, operstate) || + (dev->addr_len && + nla_put(skb, IFLA_ADDRESS, dev->addr_len, dev->dev_addr)) || + (dev->ifindex != dev_get_iflink(dev) && + nla_put_u32(skb, IFLA_LINK, dev_get_iflink(dev)))) + goto nla_put_failure; + + if (event == RTM_NEWLINK && port) { + struct nlattr *nest; + + nest = nla_nest_start(skb, IFLA_PROTINFO); + if (nest == NULL || br_port_fill_attrs(skb, port) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + } + + if (filter_mask & (RTEXT_FILTER_BRVLAN | + RTEXT_FILTER_BRVLAN_COMPRESSED | + RTEXT_FILTER_MRP | + RTEXT_FILTER_CFM_CONFIG | + RTEXT_FILTER_CFM_STATUS | + RTEXT_FILTER_MST)) { + af = nla_nest_start_noflag(skb, IFLA_AF_SPEC); + if (!af) + goto nla_put_failure; + } + + /* Check if the VID information is requested */ + if ((filter_mask & RTEXT_FILTER_BRVLAN) || + (filter_mask & RTEXT_FILTER_BRVLAN_COMPRESSED)) { + struct net_bridge_vlan_group *vg; + int err; + + /* RCU needed because of the VLAN locking rules (rcu || rtnl) */ + rcu_read_lock(); + if (port) + vg = nbp_vlan_group_rcu(port); + else + vg = br_vlan_group_rcu(br); + + if (!vg || !vg->num_vlans) { + rcu_read_unlock(); + goto done; + } + if (filter_mask & RTEXT_FILTER_BRVLAN_COMPRESSED) + err = br_fill_ifvlaninfo_compressed(skb, vg); + else + err = br_fill_ifvlaninfo(skb, vg); + + if (port && (port->flags & BR_VLAN_TUNNEL)) + err = br_fill_vlan_tunnel_info(skb, vg); + rcu_read_unlock(); + if (err) + goto nla_put_failure; + } + + if (filter_mask & RTEXT_FILTER_MRP) { + int err; + + if (!br_mrp_enabled(br) || port) + goto done; + + rcu_read_lock(); + err = br_mrp_fill_info(skb, br); + rcu_read_unlock(); + + if (err) + goto nla_put_failure; + } + + if (filter_mask & (RTEXT_FILTER_CFM_CONFIG | RTEXT_FILTER_CFM_STATUS)) { + struct nlattr *cfm_nest = NULL; + int err; + + if (!br_cfm_created(br) || port) + goto done; + + cfm_nest = nla_nest_start(skb, IFLA_BRIDGE_CFM); + if (!cfm_nest) + goto nla_put_failure; + + if (filter_mask & RTEXT_FILTER_CFM_CONFIG) { + rcu_read_lock(); + err = br_cfm_config_fill_info(skb, br); + rcu_read_unlock(); + if (err) + goto nla_put_failure; + } + + if (filter_mask & RTEXT_FILTER_CFM_STATUS) { + rcu_read_lock(); + err = br_cfm_status_fill_info(skb, br, getlink); + rcu_read_unlock(); + if (err) + goto nla_put_failure; + } + + nla_nest_end(skb, cfm_nest); + } + + if ((filter_mask & RTEXT_FILTER_MST) && + br_opt_get(br, BROPT_MST_ENABLED) && port) { + const struct net_bridge_vlan_group *vg = nbp_vlan_group(port); + struct nlattr *mst_nest; + int err; + + if (!vg || 
!vg->num_vlans) + goto done; + + mst_nest = nla_nest_start(skb, IFLA_BRIDGE_MST); + if (!mst_nest) + goto nla_put_failure; + + err = br_mst_fill_info(skb, vg); + if (err) + goto nla_put_failure; + + nla_nest_end(skb, mst_nest); + } + +done: + if (af) { + if (nlmsg_get_pos(skb) - (void *)af > nla_attr_size(0)) + nla_nest_end(skb, af); + else + nla_nest_cancel(skb, af); + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +void br_info_notify(int event, const struct net_bridge *br, + const struct net_bridge_port *port, u32 filter) +{ + struct net_device *dev; + struct sk_buff *skb; + int err = -ENOBUFS; + struct net *net; + u16 port_no = 0; + + if (WARN_ON(!port && !br)) + return; + + if (port) { + dev = port->dev; + br = port->br; + port_no = port->port_no; + } else { + dev = br->dev; + } + + net = dev_net(dev); + br_debug(br, "port %u(%s) event %d\n", port_no, dev->name, event); + + skb = nlmsg_new(br_nlmsg_size(dev, filter), GFP_ATOMIC); + if (skb == NULL) + goto errout; + + err = br_fill_ifinfo(skb, port, 0, 0, event, 0, filter, dev, false); + if (err < 0) { + /* -EMSGSIZE implies BUG in br_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_LINK, NULL, GFP_ATOMIC); + return; +errout: + rtnl_set_sk_err(net, RTNLGRP_LINK, err); +} + +/* Notify listeners of a change in bridge or port information */ +void br_ifinfo_notify(int event, const struct net_bridge *br, + const struct net_bridge_port *port) +{ + u32 filter = RTEXT_FILTER_BRVLAN_COMPRESSED; + + return br_info_notify(event, br, port, filter); +} + +/* + * Dump information about all ports, in response to GETLINK + */ +int br_getlink(struct sk_buff *skb, u32 pid, u32 seq, + struct net_device *dev, u32 filter_mask, int nlflags) +{ + struct net_bridge_port *port = br_port_get_rtnl(dev); + + if (!port && !(filter_mask & RTEXT_FILTER_BRVLAN) && + !(filter_mask & RTEXT_FILTER_BRVLAN_COMPRESSED) && + !(filter_mask & RTEXT_FILTER_MRP) && + !(filter_mask & RTEXT_FILTER_CFM_CONFIG) && + !(filter_mask & RTEXT_FILTER_CFM_STATUS)) + return 0; + + return br_fill_ifinfo(skb, port, pid, seq, RTM_NEWLINK, nlflags, + filter_mask, dev, true); +} + +static int br_vlan_info(struct net_bridge *br, struct net_bridge_port *p, + int cmd, struct bridge_vlan_info *vinfo, bool *changed, + struct netlink_ext_ack *extack) +{ + bool curr_change; + int err = 0; + + switch (cmd) { + case RTM_SETLINK: + if (p) { + /* if the MASTER flag is set this will act on the global + * per-VLAN entry as well + */ + err = nbp_vlan_add(p, vinfo->vid, vinfo->flags, + &curr_change, extack); + } else { + vinfo->flags |= BRIDGE_VLAN_INFO_BRENTRY; + err = br_vlan_add(br, vinfo->vid, vinfo->flags, + &curr_change, extack); + } + if (curr_change) + *changed = true; + break; + + case RTM_DELLINK: + if (p) { + if (!nbp_vlan_delete(p, vinfo->vid)) + *changed = true; + + if ((vinfo->flags & BRIDGE_VLAN_INFO_MASTER) && + !br_vlan_delete(p->br, vinfo->vid)) + *changed = true; + } else if (!br_vlan_delete(br, vinfo->vid)) { + *changed = true; + } + break; + } + + return err; +} + +int br_process_vlan_info(struct net_bridge *br, + struct net_bridge_port *p, int cmd, + struct bridge_vlan_info *vinfo_curr, + struct bridge_vlan_info **vinfo_last, + bool *changed, + struct netlink_ext_ack *extack) +{ + int err, rtm_cmd; + + if (!br_vlan_valid_id(vinfo_curr->vid, extack)) + return -EINVAL; + + /* needed for vlan-only NEWVLAN/DELVLAN notifications */ + rtm_cmd = 
br_afspec_cmd_to_rtm(cmd); + + if (vinfo_curr->flags & BRIDGE_VLAN_INFO_RANGE_BEGIN) { + if (!br_vlan_valid_range(vinfo_curr, *vinfo_last, extack)) + return -EINVAL; + *vinfo_last = vinfo_curr; + return 0; + } + + if (*vinfo_last) { + struct bridge_vlan_info tmp_vinfo; + int v, v_change_start = 0; + + if (!br_vlan_valid_range(vinfo_curr, *vinfo_last, extack)) + return -EINVAL; + + memcpy(&tmp_vinfo, *vinfo_last, + sizeof(struct bridge_vlan_info)); + for (v = (*vinfo_last)->vid; v <= vinfo_curr->vid; v++) { + bool curr_change = false; + + tmp_vinfo.vid = v; + err = br_vlan_info(br, p, cmd, &tmp_vinfo, &curr_change, + extack); + if (err) + break; + if (curr_change) { + *changed = curr_change; + if (!v_change_start) + v_change_start = v; + } else { + /* nothing to notify yet */ + if (!v_change_start) + continue; + br_vlan_notify(br, p, v_change_start, + v - 1, rtm_cmd); + v_change_start = 0; + } + cond_resched(); + } + /* v_change_start is set only if the last/whole range changed */ + if (v_change_start) + br_vlan_notify(br, p, v_change_start, + v - 1, rtm_cmd); + + *vinfo_last = NULL; + + return err; + } + + err = br_vlan_info(br, p, cmd, vinfo_curr, changed, extack); + if (*changed) + br_vlan_notify(br, p, vinfo_curr->vid, 0, rtm_cmd); + + return err; +} + +static int br_afspec(struct net_bridge *br, + struct net_bridge_port *p, + struct nlattr *af_spec, + int cmd, bool *changed, + struct netlink_ext_ack *extack) +{ + struct bridge_vlan_info *vinfo_curr = NULL; + struct bridge_vlan_info *vinfo_last = NULL; + struct nlattr *attr; + struct vtunnel_info tinfo_last = {}; + struct vtunnel_info tinfo_curr = {}; + int err = 0, rem; + + nla_for_each_nested(attr, af_spec, rem) { + err = 0; + switch (nla_type(attr)) { + case IFLA_BRIDGE_VLAN_TUNNEL_INFO: + if (!p || !(p->flags & BR_VLAN_TUNNEL)) + return -EINVAL; + err = br_parse_vlan_tunnel_info(attr, &tinfo_curr); + if (err) + return err; + err = br_process_vlan_tunnel_info(br, p, cmd, + &tinfo_curr, + &tinfo_last, + changed); + if (err) + return err; + break; + case IFLA_BRIDGE_VLAN_INFO: + if (nla_len(attr) != sizeof(struct bridge_vlan_info)) + return -EINVAL; + vinfo_curr = nla_data(attr); + err = br_process_vlan_info(br, p, cmd, vinfo_curr, + &vinfo_last, changed, + extack); + if (err) + return err; + break; + case IFLA_BRIDGE_MRP: + err = br_mrp_parse(br, p, attr, cmd, extack); + if (err) + return err; + break; + case IFLA_BRIDGE_CFM: + err = br_cfm_parse(br, p, attr, cmd, extack); + if (err) + return err; + break; + case IFLA_BRIDGE_MST: + if (!p) { + NL_SET_ERR_MSG(extack, + "MST states can only be set on bridge ports"); + return -EINVAL; + } + + if (cmd != RTM_SETLINK) { + NL_SET_ERR_MSG(extack, + "MST states can only be set through RTM_SETLINK"); + return -EINVAL; + } + + err = br_mst_process(p, attr, extack); + if (err) + return err; + break; + } + } + + return err; +} + +static const struct nla_policy br_port_policy[IFLA_BRPORT_MAX + 1] = { + [IFLA_BRPORT_STATE] = { .type = NLA_U8 }, + [IFLA_BRPORT_COST] = { .type = NLA_U32 }, + [IFLA_BRPORT_PRIORITY] = { .type = NLA_U16 }, + [IFLA_BRPORT_MODE] = { .type = NLA_U8 }, + [IFLA_BRPORT_GUARD] = { .type = NLA_U8 }, + [IFLA_BRPORT_PROTECT] = { .type = NLA_U8 }, + [IFLA_BRPORT_FAST_LEAVE]= { .type = NLA_U8 }, + [IFLA_BRPORT_LEARNING] = { .type = NLA_U8 }, + [IFLA_BRPORT_UNICAST_FLOOD] = { .type = NLA_U8 }, + [IFLA_BRPORT_PROXYARP] = { .type = NLA_U8 }, + [IFLA_BRPORT_PROXYARP_WIFI] = { .type = NLA_U8 }, + [IFLA_BRPORT_MULTICAST_ROUTER] = { .type = NLA_U8 }, + [IFLA_BRPORT_MCAST_TO_UCAST] = { 
.type = NLA_U8 }, + [IFLA_BRPORT_MCAST_FLOOD] = { .type = NLA_U8 }, + [IFLA_BRPORT_BCAST_FLOOD] = { .type = NLA_U8 }, + [IFLA_BRPORT_VLAN_TUNNEL] = { .type = NLA_U8 }, + [IFLA_BRPORT_GROUP_FWD_MASK] = { .type = NLA_U16 }, + [IFLA_BRPORT_NEIGH_SUPPRESS] = { .type = NLA_U8 }, + [IFLA_BRPORT_ISOLATED] = { .type = NLA_U8 }, + [IFLA_BRPORT_LOCKED] = { .type = NLA_U8 }, + [IFLA_BRPORT_BACKUP_PORT] = { .type = NLA_U32 }, + [IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT] = { .type = NLA_U32 }, +}; + +/* Change the state of the port and notify spanning tree */ +static int br_set_port_state(struct net_bridge_port *p, u8 state) +{ + if (state > BR_STATE_BLOCKING) + return -EINVAL; + + /* if kernel STP is running, don't allow changes */ + if (p->br->stp_enabled == BR_KERNEL_STP) + return -EBUSY; + + /* if device is not up, change is not allowed + * if link is not present, only allowable state is disabled + */ + if (!netif_running(p->dev) || + (!netif_oper_up(p->dev) && state != BR_STATE_DISABLED)) + return -ENETDOWN; + + br_set_state(p, state); + br_port_state_selection(p->br); + return 0; +} + +/* Set/clear or port flags based on attribute */ +static void br_set_port_flag(struct net_bridge_port *p, struct nlattr *tb[], + int attrtype, unsigned long mask) +{ + if (!tb[attrtype]) + return; + + if (nla_get_u8(tb[attrtype])) + p->flags |= mask; + else + p->flags &= ~mask; +} + +/* Process bridge protocol info on port */ +static int br_setport(struct net_bridge_port *p, struct nlattr *tb[], + struct netlink_ext_ack *extack) +{ + unsigned long old_flags, changed_mask; + bool br_vlan_tunnel_old; + int err; + + old_flags = p->flags; + br_vlan_tunnel_old = (old_flags & BR_VLAN_TUNNEL) ? true : false; + + br_set_port_flag(p, tb, IFLA_BRPORT_MODE, BR_HAIRPIN_MODE); + br_set_port_flag(p, tb, IFLA_BRPORT_GUARD, BR_BPDU_GUARD); + br_set_port_flag(p, tb, IFLA_BRPORT_FAST_LEAVE, + BR_MULTICAST_FAST_LEAVE); + br_set_port_flag(p, tb, IFLA_BRPORT_PROTECT, BR_ROOT_BLOCK); + br_set_port_flag(p, tb, IFLA_BRPORT_LEARNING, BR_LEARNING); + br_set_port_flag(p, tb, IFLA_BRPORT_UNICAST_FLOOD, BR_FLOOD); + br_set_port_flag(p, tb, IFLA_BRPORT_MCAST_FLOOD, BR_MCAST_FLOOD); + br_set_port_flag(p, tb, IFLA_BRPORT_MCAST_TO_UCAST, + BR_MULTICAST_TO_UNICAST); + br_set_port_flag(p, tb, IFLA_BRPORT_BCAST_FLOOD, BR_BCAST_FLOOD); + br_set_port_flag(p, tb, IFLA_BRPORT_PROXYARP, BR_PROXYARP); + br_set_port_flag(p, tb, IFLA_BRPORT_PROXYARP_WIFI, BR_PROXYARP_WIFI); + br_set_port_flag(p, tb, IFLA_BRPORT_VLAN_TUNNEL, BR_VLAN_TUNNEL); + br_set_port_flag(p, tb, IFLA_BRPORT_NEIGH_SUPPRESS, BR_NEIGH_SUPPRESS); + br_set_port_flag(p, tb, IFLA_BRPORT_ISOLATED, BR_ISOLATED); + br_set_port_flag(p, tb, IFLA_BRPORT_LOCKED, BR_PORT_LOCKED); + + changed_mask = old_flags ^ p->flags; + + err = br_switchdev_set_port_flag(p, p->flags, changed_mask, extack); + if (err) { + p->flags = old_flags; + return err; + } + + if (br_vlan_tunnel_old && !(p->flags & BR_VLAN_TUNNEL)) + nbp_vlan_tunnel_info_flush(p); + + br_port_flags_change(p, changed_mask); + + if (tb[IFLA_BRPORT_COST]) { + err = br_stp_set_path_cost(p, nla_get_u32(tb[IFLA_BRPORT_COST])); + if (err) + return err; + } + + if (tb[IFLA_BRPORT_PRIORITY]) { + err = br_stp_set_port_priority(p, nla_get_u16(tb[IFLA_BRPORT_PRIORITY])); + if (err) + return err; + } + + if (tb[IFLA_BRPORT_STATE]) { + err = br_set_port_state(p, nla_get_u8(tb[IFLA_BRPORT_STATE])); + if (err) + return err; + } + + if (tb[IFLA_BRPORT_FLUSH]) + br_fdb_delete_by_port(p->br, p, 0, 0); + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if 
(tb[IFLA_BRPORT_MULTICAST_ROUTER]) { + u8 mcast_router = nla_get_u8(tb[IFLA_BRPORT_MULTICAST_ROUTER]); + + err = br_multicast_set_port_router(&p->multicast_ctx, + mcast_router); + if (err) + return err; + } + + if (tb[IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT]) { + u32 hlimit; + + hlimit = nla_get_u32(tb[IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT]); + err = br_multicast_eht_set_hosts_limit(p, hlimit); + if (err) + return err; + } +#endif + + if (tb[IFLA_BRPORT_GROUP_FWD_MASK]) { + u16 fwd_mask = nla_get_u16(tb[IFLA_BRPORT_GROUP_FWD_MASK]); + + if (fwd_mask & BR_GROUPFWD_MACPAUSE) + return -EINVAL; + p->group_fwd_mask = fwd_mask; + } + + if (tb[IFLA_BRPORT_BACKUP_PORT]) { + struct net_device *backup_dev = NULL; + u32 backup_ifindex; + + backup_ifindex = nla_get_u32(tb[IFLA_BRPORT_BACKUP_PORT]); + if (backup_ifindex) { + backup_dev = __dev_get_by_index(dev_net(p->dev), + backup_ifindex); + if (!backup_dev) + return -ENOENT; + } + + err = nbp_backup_change(p, backup_dev); + if (err) + return err; + } + + return 0; +} + +/* Change state and parameters on port. */ +int br_setlink(struct net_device *dev, struct nlmsghdr *nlh, u16 flags, + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = (struct net_bridge *)netdev_priv(dev); + struct nlattr *tb[IFLA_BRPORT_MAX + 1]; + struct net_bridge_port *p; + struct nlattr *protinfo; + struct nlattr *afspec; + bool changed = false; + int err = 0; + + protinfo = nlmsg_find_attr(nlh, sizeof(struct ifinfomsg), IFLA_PROTINFO); + afspec = nlmsg_find_attr(nlh, sizeof(struct ifinfomsg), IFLA_AF_SPEC); + if (!protinfo && !afspec) + return 0; + + p = br_port_get_rtnl(dev); + /* We want to accept dev as bridge itself if the AF_SPEC + * is set to see if someone is setting vlan info on the bridge + */ + if (!p && !afspec) + return -EINVAL; + + if (p && protinfo) { + if (protinfo->nla_type & NLA_F_NESTED) { + err = nla_parse_nested_deprecated(tb, IFLA_BRPORT_MAX, + protinfo, + br_port_policy, + NULL); + if (err) + return err; + + spin_lock_bh(&p->br->lock); + err = br_setport(p, tb, extack); + spin_unlock_bh(&p->br->lock); + } else { + /* Binary compatibility with old RSTP */ + if (nla_len(protinfo) < sizeof(u8)) + return -EINVAL; + + spin_lock_bh(&p->br->lock); + err = br_set_port_state(p, nla_get_u8(protinfo)); + spin_unlock_bh(&p->br->lock); + } + if (err) + goto out; + changed = true; + } + + if (afspec) + err = br_afspec(br, p, afspec, RTM_SETLINK, &changed, extack); + + if (changed) + br_ifinfo_notify(RTM_NEWLINK, br, p); +out: + return err; +} + +/* Delete port information */ +int br_dellink(struct net_device *dev, struct nlmsghdr *nlh, u16 flags) +{ + struct net_bridge *br = (struct net_bridge *)netdev_priv(dev); + struct net_bridge_port *p; + struct nlattr *afspec; + bool changed = false; + int err = 0; + + afspec = nlmsg_find_attr(nlh, sizeof(struct ifinfomsg), IFLA_AF_SPEC); + if (!afspec) + return 0; + + p = br_port_get_rtnl(dev); + /* We want to accept dev as bridge itself as well */ + if (!p && !netif_is_bridge_master(dev)) + return -EINVAL; + + err = br_afspec(br, p, afspec, RTM_DELLINK, &changed, NULL); + if (changed) + /* Send RTM_NEWLINK because userspace + * expects RTM_NEWLINK for vlan dels + */ + br_ifinfo_notify(RTM_NEWLINK, br, p); + + return err; +} + +static int br_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + if (tb[IFLA_ADDRESS]) { + if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN) + return -EINVAL; + if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS]))) + return -EADDRNOTAVAIL; + } + + if (!data) + return 0; 
+ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + if (data[IFLA_BR_VLAN_PROTOCOL] && + !eth_type_vlan(nla_get_be16(data[IFLA_BR_VLAN_PROTOCOL]))) + return -EPROTONOSUPPORT; + + if (data[IFLA_BR_VLAN_DEFAULT_PVID]) { + __u16 defpvid = nla_get_u16(data[IFLA_BR_VLAN_DEFAULT_PVID]); + + if (defpvid >= VLAN_VID_MASK) + return -EINVAL; + } +#endif + + return 0; +} + +static int br_port_slave_changelink(struct net_device *brdev, + struct net_device *dev, + struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = netdev_priv(brdev); + int ret; + + if (!data) + return 0; + + spin_lock_bh(&br->lock); + ret = br_setport(br_port_get_rtnl(dev), data, extack); + spin_unlock_bh(&br->lock); + + return ret; +} + +static int br_port_fill_slave_info(struct sk_buff *skb, + const struct net_device *brdev, + const struct net_device *dev) +{ + return br_port_fill_attrs(skb, br_port_get_rtnl(dev)); +} + +static size_t br_port_get_slave_size(const struct net_device *brdev, + const struct net_device *dev) +{ + return br_port_info_size(); +} + +static const struct nla_policy br_policy[IFLA_BR_MAX + 1] = { + [IFLA_BR_FORWARD_DELAY] = { .type = NLA_U32 }, + [IFLA_BR_HELLO_TIME] = { .type = NLA_U32 }, + [IFLA_BR_MAX_AGE] = { .type = NLA_U32 }, + [IFLA_BR_AGEING_TIME] = { .type = NLA_U32 }, + [IFLA_BR_STP_STATE] = { .type = NLA_U32 }, + [IFLA_BR_PRIORITY] = { .type = NLA_U16 }, + [IFLA_BR_VLAN_FILTERING] = { .type = NLA_U8 }, + [IFLA_BR_VLAN_PROTOCOL] = { .type = NLA_U16 }, + [IFLA_BR_GROUP_FWD_MASK] = { .type = NLA_U16 }, + [IFLA_BR_GROUP_ADDR] = { .type = NLA_BINARY, + .len = ETH_ALEN }, + [IFLA_BR_MCAST_ROUTER] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_SNOOPING] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_QUERY_USE_IFADDR] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_QUERIER] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_HASH_ELASTICITY] = { .type = NLA_U32 }, + [IFLA_BR_MCAST_HASH_MAX] = { .type = NLA_U32 }, + [IFLA_BR_MCAST_LAST_MEMBER_CNT] = { .type = NLA_U32 }, + [IFLA_BR_MCAST_STARTUP_QUERY_CNT] = { .type = NLA_U32 }, + [IFLA_BR_MCAST_LAST_MEMBER_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_MCAST_MEMBERSHIP_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_MCAST_QUERIER_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_MCAST_QUERY_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_MCAST_QUERY_RESPONSE_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_MCAST_STARTUP_QUERY_INTVL] = { .type = NLA_U64 }, + [IFLA_BR_NF_CALL_IPTABLES] = { .type = NLA_U8 }, + [IFLA_BR_NF_CALL_IP6TABLES] = { .type = NLA_U8 }, + [IFLA_BR_NF_CALL_ARPTABLES] = { .type = NLA_U8 }, + [IFLA_BR_VLAN_DEFAULT_PVID] = { .type = NLA_U16 }, + [IFLA_BR_VLAN_STATS_ENABLED] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_STATS_ENABLED] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_IGMP_VERSION] = { .type = NLA_U8 }, + [IFLA_BR_MCAST_MLD_VERSION] = { .type = NLA_U8 }, + [IFLA_BR_VLAN_STATS_PER_PORT] = { .type = NLA_U8 }, + [IFLA_BR_MULTI_BOOLOPT] = + NLA_POLICY_EXACT_LEN(sizeof(struct br_boolopt_multi)), +}; + +static int br_changelink(struct net_device *brdev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = netdev_priv(brdev); + int err; + + if (!data) + return 0; + + if (data[IFLA_BR_FORWARD_DELAY]) { + err = br_set_forward_delay(br, nla_get_u32(data[IFLA_BR_FORWARD_DELAY])); + if (err) + return err; + } + + if (data[IFLA_BR_HELLO_TIME]) { + err = br_set_hello_time(br, nla_get_u32(data[IFLA_BR_HELLO_TIME])); + if (err) + return err; + } + + if (data[IFLA_BR_MAX_AGE]) { + err = br_set_max_age(br, 
nla_get_u32(data[IFLA_BR_MAX_AGE])); + if (err) + return err; + } + + if (data[IFLA_BR_AGEING_TIME]) { + err = br_set_ageing_time(br, nla_get_u32(data[IFLA_BR_AGEING_TIME])); + if (err) + return err; + } + + if (data[IFLA_BR_STP_STATE]) { + u32 stp_enabled = nla_get_u32(data[IFLA_BR_STP_STATE]); + + err = br_stp_set_enabled(br, stp_enabled, extack); + if (err) + return err; + } + + if (data[IFLA_BR_PRIORITY]) { + u32 priority = nla_get_u16(data[IFLA_BR_PRIORITY]); + + br_stp_set_bridge_priority(br, priority); + } + + if (data[IFLA_BR_VLAN_FILTERING]) { + u8 vlan_filter = nla_get_u8(data[IFLA_BR_VLAN_FILTERING]); + + err = br_vlan_filter_toggle(br, vlan_filter, extack); + if (err) + return err; + } + +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + if (data[IFLA_BR_VLAN_PROTOCOL]) { + __be16 vlan_proto = nla_get_be16(data[IFLA_BR_VLAN_PROTOCOL]); + + err = __br_vlan_set_proto(br, vlan_proto, extack); + if (err) + return err; + } + + if (data[IFLA_BR_VLAN_DEFAULT_PVID]) { + __u16 defpvid = nla_get_u16(data[IFLA_BR_VLAN_DEFAULT_PVID]); + + err = __br_vlan_set_default_pvid(br, defpvid, extack); + if (err) + return err; + } + + if (data[IFLA_BR_VLAN_STATS_ENABLED]) { + __u8 vlan_stats = nla_get_u8(data[IFLA_BR_VLAN_STATS_ENABLED]); + + err = br_vlan_set_stats(br, vlan_stats); + if (err) + return err; + } + + if (data[IFLA_BR_VLAN_STATS_PER_PORT]) { + __u8 per_port = nla_get_u8(data[IFLA_BR_VLAN_STATS_PER_PORT]); + + err = br_vlan_set_stats_per_port(br, per_port); + if (err) + return err; + } +#endif + + if (data[IFLA_BR_GROUP_FWD_MASK]) { + u16 fwd_mask = nla_get_u16(data[IFLA_BR_GROUP_FWD_MASK]); + + if (fwd_mask & BR_GROUPFWD_RESTRICTED) + return -EINVAL; + br->group_fwd_mask = fwd_mask; + } + + if (data[IFLA_BR_GROUP_ADDR]) { + u8 new_addr[ETH_ALEN]; + + if (nla_len(data[IFLA_BR_GROUP_ADDR]) != ETH_ALEN) + return -EINVAL; + memcpy(new_addr, nla_data(data[IFLA_BR_GROUP_ADDR]), ETH_ALEN); + if (!is_link_local_ether_addr(new_addr)) + return -EINVAL; + if (new_addr[5] == 1 || /* 802.3x Pause address */ + new_addr[5] == 2 || /* 802.3ad Slow protocols */ + new_addr[5] == 3) /* 802.1X PAE address */ + return -EINVAL; + spin_lock_bh(&br->lock); + memcpy(br->group_addr, new_addr, sizeof(br->group_addr)); + spin_unlock_bh(&br->lock); + br_opt_toggle(br, BROPT_GROUP_ADDR_SET, true); + br_recalculate_fwd_mask(br); + } + + if (data[IFLA_BR_FDB_FLUSH]) { + struct net_bridge_fdb_flush_desc desc = { + .flags_mask = BIT(BR_FDB_STATIC) + }; + + br_fdb_flush(br, &desc); + } + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (data[IFLA_BR_MCAST_ROUTER]) { + u8 multicast_router = nla_get_u8(data[IFLA_BR_MCAST_ROUTER]); + + err = br_multicast_set_router(&br->multicast_ctx, + multicast_router); + if (err) + return err; + } + + if (data[IFLA_BR_MCAST_SNOOPING]) { + u8 mcast_snooping = nla_get_u8(data[IFLA_BR_MCAST_SNOOPING]); + + err = br_multicast_toggle(br, mcast_snooping, extack); + if (err) + return err; + } + + if (data[IFLA_BR_MCAST_QUERY_USE_IFADDR]) { + u8 val; + + val = nla_get_u8(data[IFLA_BR_MCAST_QUERY_USE_IFADDR]); + br_opt_toggle(br, BROPT_MULTICAST_QUERY_USE_IFADDR, !!val); + } + + if (data[IFLA_BR_MCAST_QUERIER]) { + u8 mcast_querier = nla_get_u8(data[IFLA_BR_MCAST_QUERIER]); + + err = br_multicast_set_querier(&br->multicast_ctx, + mcast_querier); + if (err) + return err; + } + + if (data[IFLA_BR_MCAST_HASH_ELASTICITY]) + br_warn(br, "the hash_elasticity option has been deprecated and is always %u\n", + RHT_ELASTICITY); + + if (data[IFLA_BR_MCAST_HASH_MAX]) + br->hash_max = nla_get_u32(data[IFLA_BR_MCAST_HASH_MAX]); 
+ + if (data[IFLA_BR_MCAST_LAST_MEMBER_CNT]) { + u32 val = nla_get_u32(data[IFLA_BR_MCAST_LAST_MEMBER_CNT]); + + br->multicast_ctx.multicast_last_member_count = val; + } + + if (data[IFLA_BR_MCAST_STARTUP_QUERY_CNT]) { + u32 val = nla_get_u32(data[IFLA_BR_MCAST_STARTUP_QUERY_CNT]); + + br->multicast_ctx.multicast_startup_query_count = val; + } + + if (data[IFLA_BR_MCAST_LAST_MEMBER_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_LAST_MEMBER_INTVL]); + + br->multicast_ctx.multicast_last_member_interval = clock_t_to_jiffies(val); + } + + if (data[IFLA_BR_MCAST_MEMBERSHIP_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_MEMBERSHIP_INTVL]); + + br->multicast_ctx.multicast_membership_interval = clock_t_to_jiffies(val); + } + + if (data[IFLA_BR_MCAST_QUERIER_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_QUERIER_INTVL]); + + br->multicast_ctx.multicast_querier_interval = clock_t_to_jiffies(val); + } + + if (data[IFLA_BR_MCAST_QUERY_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_QUERY_INTVL]); + + br_multicast_set_query_intvl(&br->multicast_ctx, val); + } + + if (data[IFLA_BR_MCAST_QUERY_RESPONSE_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_QUERY_RESPONSE_INTVL]); + + br->multicast_ctx.multicast_query_response_interval = clock_t_to_jiffies(val); + } + + if (data[IFLA_BR_MCAST_STARTUP_QUERY_INTVL]) { + u64 val = nla_get_u64(data[IFLA_BR_MCAST_STARTUP_QUERY_INTVL]); + + br_multicast_set_startup_query_intvl(&br->multicast_ctx, val); + } + + if (data[IFLA_BR_MCAST_STATS_ENABLED]) { + __u8 mcast_stats; + + mcast_stats = nla_get_u8(data[IFLA_BR_MCAST_STATS_ENABLED]); + br_opt_toggle(br, BROPT_MULTICAST_STATS_ENABLED, !!mcast_stats); + } + + if (data[IFLA_BR_MCAST_IGMP_VERSION]) { + __u8 igmp_version; + + igmp_version = nla_get_u8(data[IFLA_BR_MCAST_IGMP_VERSION]); + err = br_multicast_set_igmp_version(&br->multicast_ctx, + igmp_version); + if (err) + return err; + } + +#if IS_ENABLED(CONFIG_IPV6) + if (data[IFLA_BR_MCAST_MLD_VERSION]) { + __u8 mld_version; + + mld_version = nla_get_u8(data[IFLA_BR_MCAST_MLD_VERSION]); + err = br_multicast_set_mld_version(&br->multicast_ctx, + mld_version); + if (err) + return err; + } +#endif +#endif +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (data[IFLA_BR_NF_CALL_IPTABLES]) { + u8 val = nla_get_u8(data[IFLA_BR_NF_CALL_IPTABLES]); + + br_opt_toggle(br, BROPT_NF_CALL_IPTABLES, !!val); + } + + if (data[IFLA_BR_NF_CALL_IP6TABLES]) { + u8 val = nla_get_u8(data[IFLA_BR_NF_CALL_IP6TABLES]); + + br_opt_toggle(br, BROPT_NF_CALL_IP6TABLES, !!val); + } + + if (data[IFLA_BR_NF_CALL_ARPTABLES]) { + u8 val = nla_get_u8(data[IFLA_BR_NF_CALL_ARPTABLES]); + + br_opt_toggle(br, BROPT_NF_CALL_ARPTABLES, !!val); + } +#endif + + if (data[IFLA_BR_MULTI_BOOLOPT]) { + struct br_boolopt_multi *bm; + + bm = nla_data(data[IFLA_BR_MULTI_BOOLOPT]); + err = br_boolopt_multi_toggle(br, bm, extack); + if (err) + return err; + } + + return 0; +} + +static int br_dev_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = netdev_priv(dev); + int err; + + err = register_netdevice(dev); + if (err) + return err; + + if (tb[IFLA_ADDRESS]) { + spin_lock_bh(&br->lock); + br_stp_change_bridge_id(br, nla_data(tb[IFLA_ADDRESS])); + spin_unlock_bh(&br->lock); + } + + err = br_changelink(dev, tb, data, extack); + if (err) + br_dev_delete(dev, NULL); + + return err; +} + +static size_t br_get_size(const struct net_device *brdev) +{ + return nla_total_size(sizeof(u32)) + /* 
IFLA_BR_FORWARD_DELAY */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_HELLO_TIME */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_MAX_AGE */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_AGEING_TIME */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_STP_STATE */ + nla_total_size(sizeof(u16)) + /* IFLA_BR_PRIORITY */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_VLAN_FILTERING */ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + nla_total_size(sizeof(__be16)) + /* IFLA_BR_VLAN_PROTOCOL */ + nla_total_size(sizeof(u16)) + /* IFLA_BR_VLAN_DEFAULT_PVID */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_VLAN_STATS_ENABLED */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_VLAN_STATS_PER_PORT */ +#endif + nla_total_size(sizeof(u16)) + /* IFLA_BR_GROUP_FWD_MASK */ + nla_total_size(sizeof(struct ifla_bridge_id)) + /* IFLA_BR_ROOT_ID */ + nla_total_size(sizeof(struct ifla_bridge_id)) + /* IFLA_BR_BRIDGE_ID */ + nla_total_size(sizeof(u16)) + /* IFLA_BR_ROOT_PORT */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_ROOT_PATH_COST */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_TOPOLOGY_CHANGE */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_TOPOLOGY_CHANGE_DETECTED */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_HELLO_TIMER */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_TCN_TIMER */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_TOPOLOGY_CHANGE_TIMER */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_GC_TIMER */ + nla_total_size(ETH_ALEN) + /* IFLA_BR_GROUP_ADDR */ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_ROUTER */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_SNOOPING */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_QUERY_USE_IFADDR */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_QUERIER */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_STATS_ENABLED */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_MCAST_HASH_ELASTICITY */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_MCAST_HASH_MAX */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_MCAST_LAST_MEMBER_CNT */ + nla_total_size(sizeof(u32)) + /* IFLA_BR_MCAST_STARTUP_QUERY_CNT */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_LAST_MEMBER_INTVL */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_MEMBERSHIP_INTVL */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_QUERIER_INTVL */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_QUERY_INTVL */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_QUERY_RESPONSE_INTVL */ + nla_total_size_64bit(sizeof(u64)) + /* IFLA_BR_MCAST_STARTUP_QUERY_INTVL */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_IGMP_VERSION */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_MCAST_MLD_VERSION */ + br_multicast_querier_state_size() + /* IFLA_BR_MCAST_QUERIER_STATE */ +#endif +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + nla_total_size(sizeof(u8)) + /* IFLA_BR_NF_CALL_IPTABLES */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_NF_CALL_IP6TABLES */ + nla_total_size(sizeof(u8)) + /* IFLA_BR_NF_CALL_ARPTABLES */ +#endif + nla_total_size(sizeof(struct br_boolopt_multi)) + /* IFLA_BR_MULTI_BOOLOPT */ + 0; +} + +static int br_fill_info(struct sk_buff *skb, const struct net_device *brdev) +{ + struct net_bridge *br = netdev_priv(brdev); + u32 forward_delay = jiffies_to_clock_t(br->forward_delay); + u32 hello_time = jiffies_to_clock_t(br->hello_time); + u32 age_time = jiffies_to_clock_t(br->max_age); + u32 ageing_time = jiffies_to_clock_t(br->ageing_time); + u32 stp_enabled = br->stp_enabled; + u16 priority = (br->bridge_id.prio[0] << 8) | br->bridge_id.prio[1]; + u8 vlan_enabled = br_vlan_enabled(br->dev); + struct 
br_boolopt_multi bm; + u64 clockval; + + clockval = br_timer_value(&br->hello_timer); + if (nla_put_u64_64bit(skb, IFLA_BR_HELLO_TIMER, clockval, IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = br_timer_value(&br->tcn_timer); + if (nla_put_u64_64bit(skb, IFLA_BR_TCN_TIMER, clockval, IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = br_timer_value(&br->topology_change_timer); + if (nla_put_u64_64bit(skb, IFLA_BR_TOPOLOGY_CHANGE_TIMER, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = br_timer_value(&br->gc_work.timer); + if (nla_put_u64_64bit(skb, IFLA_BR_GC_TIMER, clockval, IFLA_BR_PAD)) + return -EMSGSIZE; + + br_boolopt_multi_get(br, &bm); + if (nla_put_u32(skb, IFLA_BR_FORWARD_DELAY, forward_delay) || + nla_put_u32(skb, IFLA_BR_HELLO_TIME, hello_time) || + nla_put_u32(skb, IFLA_BR_MAX_AGE, age_time) || + nla_put_u32(skb, IFLA_BR_AGEING_TIME, ageing_time) || + nla_put_u32(skb, IFLA_BR_STP_STATE, stp_enabled) || + nla_put_u16(skb, IFLA_BR_PRIORITY, priority) || + nla_put_u8(skb, IFLA_BR_VLAN_FILTERING, vlan_enabled) || + nla_put_u16(skb, IFLA_BR_GROUP_FWD_MASK, br->group_fwd_mask) || + nla_put(skb, IFLA_BR_BRIDGE_ID, sizeof(struct ifla_bridge_id), + &br->bridge_id) || + nla_put(skb, IFLA_BR_ROOT_ID, sizeof(struct ifla_bridge_id), + &br->designated_root) || + nla_put_u16(skb, IFLA_BR_ROOT_PORT, br->root_port) || + nla_put_u32(skb, IFLA_BR_ROOT_PATH_COST, br->root_path_cost) || + nla_put_u8(skb, IFLA_BR_TOPOLOGY_CHANGE, br->topology_change) || + nla_put_u8(skb, IFLA_BR_TOPOLOGY_CHANGE_DETECTED, + br->topology_change_detected) || + nla_put(skb, IFLA_BR_GROUP_ADDR, ETH_ALEN, br->group_addr) || + nla_put(skb, IFLA_BR_MULTI_BOOLOPT, sizeof(bm), &bm)) + return -EMSGSIZE; + +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + if (nla_put_be16(skb, IFLA_BR_VLAN_PROTOCOL, br->vlan_proto) || + nla_put_u16(skb, IFLA_BR_VLAN_DEFAULT_PVID, br->default_pvid) || + nla_put_u8(skb, IFLA_BR_VLAN_STATS_ENABLED, + br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) || + nla_put_u8(skb, IFLA_BR_VLAN_STATS_PER_PORT, + br_opt_get(br, BROPT_VLAN_STATS_PER_PORT))) + return -EMSGSIZE; +#endif +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (nla_put_u8(skb, IFLA_BR_MCAST_ROUTER, + br->multicast_ctx.multicast_router) || + nla_put_u8(skb, IFLA_BR_MCAST_SNOOPING, + br_opt_get(br, BROPT_MULTICAST_ENABLED)) || + nla_put_u8(skb, IFLA_BR_MCAST_QUERY_USE_IFADDR, + br_opt_get(br, BROPT_MULTICAST_QUERY_USE_IFADDR)) || + nla_put_u8(skb, IFLA_BR_MCAST_QUERIER, + br->multicast_ctx.multicast_querier) || + nla_put_u8(skb, IFLA_BR_MCAST_STATS_ENABLED, + br_opt_get(br, BROPT_MULTICAST_STATS_ENABLED)) || + nla_put_u32(skb, IFLA_BR_MCAST_HASH_ELASTICITY, RHT_ELASTICITY) || + nla_put_u32(skb, IFLA_BR_MCAST_HASH_MAX, br->hash_max) || + nla_put_u32(skb, IFLA_BR_MCAST_LAST_MEMBER_CNT, + br->multicast_ctx.multicast_last_member_count) || + nla_put_u32(skb, IFLA_BR_MCAST_STARTUP_QUERY_CNT, + br->multicast_ctx.multicast_startup_query_count) || + nla_put_u8(skb, IFLA_BR_MCAST_IGMP_VERSION, + br->multicast_ctx.multicast_igmp_version) || + br_multicast_dump_querier_state(skb, &br->multicast_ctx, + IFLA_BR_MCAST_QUERIER_STATE)) + return -EMSGSIZE; +#if IS_ENABLED(CONFIG_IPV6) + if (nla_put_u8(skb, IFLA_BR_MCAST_MLD_VERSION, + br->multicast_ctx.multicast_mld_version)) + return -EMSGSIZE; +#endif + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_last_member_interval); + if (nla_put_u64_64bit(skb, IFLA_BR_MCAST_LAST_MEMBER_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_membership_interval); + if 
(nla_put_u64_64bit(skb, IFLA_BR_MCAST_MEMBERSHIP_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_querier_interval); + if (nla_put_u64_64bit(skb, IFLA_BR_MCAST_QUERIER_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_query_interval); + if (nla_put_u64_64bit(skb, IFLA_BR_MCAST_QUERY_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_query_response_interval); + if (nla_put_u64_64bit(skb, IFLA_BR_MCAST_QUERY_RESPONSE_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; + clockval = jiffies_to_clock_t(br->multicast_ctx.multicast_startup_query_interval); + if (nla_put_u64_64bit(skb, IFLA_BR_MCAST_STARTUP_QUERY_INTVL, clockval, + IFLA_BR_PAD)) + return -EMSGSIZE; +#endif +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (nla_put_u8(skb, IFLA_BR_NF_CALL_IPTABLES, + br_opt_get(br, BROPT_NF_CALL_IPTABLES) ? 1 : 0) || + nla_put_u8(skb, IFLA_BR_NF_CALL_IP6TABLES, + br_opt_get(br, BROPT_NF_CALL_IP6TABLES) ? 1 : 0) || + nla_put_u8(skb, IFLA_BR_NF_CALL_ARPTABLES, + br_opt_get(br, BROPT_NF_CALL_ARPTABLES) ? 1 : 0)) + return -EMSGSIZE; +#endif + + return 0; +} + +static size_t br_get_linkxstats_size(const struct net_device *dev, int attr) +{ + struct net_bridge_port *p = NULL; + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + struct net_bridge *br; + int numvls = 0; + + switch (attr) { + case IFLA_STATS_LINK_XSTATS: + br = netdev_priv(dev); + vg = br_vlan_group(br); + break; + case IFLA_STATS_LINK_XSTATS_SLAVE: + p = br_port_get_rtnl(dev); + if (!p) + return 0; + vg = nbp_vlan_group(p); + break; + default: + return 0; + } + + if (vg) { + /* we need to count all, even placeholder entries */ + list_for_each_entry(v, &vg->vlan_list, vlist) + numvls++; + } + + return numvls * nla_total_size(sizeof(struct bridge_vlan_xstats)) + + nla_total_size_64bit(sizeof(struct br_mcast_stats)) + + (p ? 
nla_total_size_64bit(sizeof(p->stp_xstats)) : 0) + + nla_total_size(0); +} + +static int br_fill_linkxstats(struct sk_buff *skb, + const struct net_device *dev, + int *prividx, int attr) +{ + struct nlattr *nla __maybe_unused; + struct net_bridge_port *p = NULL; + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + struct net_bridge *br; + struct nlattr *nest; + int vl_idx = 0; + + switch (attr) { + case IFLA_STATS_LINK_XSTATS: + br = netdev_priv(dev); + vg = br_vlan_group(br); + break; + case IFLA_STATS_LINK_XSTATS_SLAVE: + p = br_port_get_rtnl(dev); + if (!p) + return 0; + br = p->br; + vg = nbp_vlan_group(p); + break; + default: + return -EINVAL; + } + + nest = nla_nest_start_noflag(skb, LINK_XSTATS_TYPE_BRIDGE); + if (!nest) + return -EMSGSIZE; + + if (vg) { + u16 pvid; + + pvid = br_get_pvid(vg); + list_for_each_entry(v, &vg->vlan_list, vlist) { + struct bridge_vlan_xstats vxi; + struct pcpu_sw_netstats stats; + + if (++vl_idx < *prividx) + continue; + memset(&vxi, 0, sizeof(vxi)); + vxi.vid = v->vid; + vxi.flags = v->flags; + if (v->vid == pvid) + vxi.flags |= BRIDGE_VLAN_INFO_PVID; + br_vlan_get_stats(v, &stats); + vxi.rx_bytes = u64_stats_read(&stats.rx_bytes); + vxi.rx_packets = u64_stats_read(&stats.rx_packets); + vxi.tx_bytes = u64_stats_read(&stats.tx_bytes); + vxi.tx_packets = u64_stats_read(&stats.tx_packets); + + if (nla_put(skb, BRIDGE_XSTATS_VLAN, sizeof(vxi), &vxi)) + goto nla_put_failure; + } + } + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (++vl_idx >= *prividx) { + nla = nla_reserve_64bit(skb, BRIDGE_XSTATS_MCAST, + sizeof(struct br_mcast_stats), + BRIDGE_XSTATS_PAD); + if (!nla) + goto nla_put_failure; + br_multicast_get_stats(br, p, nla_data(nla)); + } +#endif + + if (p) { + nla = nla_reserve_64bit(skb, BRIDGE_XSTATS_STP, + sizeof(p->stp_xstats), + BRIDGE_XSTATS_PAD); + if (!nla) + goto nla_put_failure; + + spin_lock_bh(&br->lock); + memcpy(nla_data(nla), &p->stp_xstats, sizeof(p->stp_xstats)); + spin_unlock_bh(&br->lock); + } + + nla_nest_end(skb, nest); + *prividx = 0; + + return 0; + +nla_put_failure: + nla_nest_end(skb, nest); + *prividx = vl_idx; + + return -EMSGSIZE; +} + +static struct rtnl_af_ops br_af_ops __read_mostly = { + .family = AF_BRIDGE, + .get_link_af_size = br_get_link_af_size_filtered, +}; + +struct rtnl_link_ops br_link_ops __read_mostly = { + .kind = "bridge", + .priv_size = sizeof(struct net_bridge), + .setup = br_dev_setup, + .maxtype = IFLA_BR_MAX, + .policy = br_policy, + .validate = br_validate, + .newlink = br_dev_newlink, + .changelink = br_changelink, + .dellink = br_dev_delete, + .get_size = br_get_size, + .fill_info = br_fill_info, + .fill_linkxstats = br_fill_linkxstats, + .get_linkxstats_size = br_get_linkxstats_size, + + .slave_maxtype = IFLA_BRPORT_MAX, + .slave_policy = br_port_policy, + .slave_changelink = br_port_slave_changelink, + .get_slave_size = br_port_get_slave_size, + .fill_slave_info = br_port_fill_slave_info, +}; + +int __init br_netlink_init(void) +{ + int err; + + br_mdb_init(); + br_vlan_rtnl_init(); + rtnl_af_register(&br_af_ops); + + err = rtnl_link_register(&br_link_ops); + if (err) + goto out_af; + + return 0; + +out_af: + rtnl_af_unregister(&br_af_ops); + br_mdb_uninit(); + return err; +} + +void br_netlink_fini(void) +{ + br_mdb_uninit(); + br_vlan_rtnl_uninit(); + rtnl_af_unregister(&br_af_ops); + rtnl_link_unregister(&br_link_ops); +} diff --git a/net/bridge/br_netlink_tunnel.c b/net/bridge/br_netlink_tunnel.c new file mode 100644 index 000000000..8914290c7 --- /dev/null +++ 
b/net/bridge/br_netlink_tunnel.c @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Bridge per vlan tunnel port dst_metadata netlink control interface + * + * Authors: + * Roopa Prabhu + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_tunnel.h" + +static size_t __get_vlan_tinfo_size(void) +{ + return nla_total_size(0) + /* nest IFLA_BRIDGE_VLAN_TUNNEL_INFO */ + nla_total_size(sizeof(u32)) + /* IFLA_BRIDGE_VLAN_TUNNEL_ID */ + nla_total_size(sizeof(u16)) + /* IFLA_BRIDGE_VLAN_TUNNEL_VID */ + nla_total_size(sizeof(u16)); /* IFLA_BRIDGE_VLAN_TUNNEL_FLAGS */ +} + +bool vlan_tunid_inrange(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *v_last) +{ + __be32 tunid_curr = tunnel_id_to_key32(v_curr->tinfo.tunnel_id); + __be32 tunid_last = tunnel_id_to_key32(v_last->tinfo.tunnel_id); + + return (be32_to_cpu(tunid_curr) - be32_to_cpu(tunid_last)) == 1; +} + +static int __get_num_vlan_tunnel_infos(struct net_bridge_vlan_group *vg) +{ + struct net_bridge_vlan *v, *vtbegin = NULL, *vtend = NULL; + int num_tinfos = 0; + + /* Count number of vlan infos */ + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + /* only a context, bridge vlan not activated */ + if (!br_vlan_should_use(v) || !v->tinfo.tunnel_id) + continue; + + if (!vtbegin) { + goto initvars; + } else if ((v->vid - vtend->vid) == 1 && + vlan_tunid_inrange(v, vtend)) { + vtend = v; + continue; + } else { + if ((vtend->vid - vtbegin->vid) > 0) + num_tinfos += 2; + else + num_tinfos += 1; + } +initvars: + vtbegin = v; + vtend = v; + } + + if (vtbegin && vtend) { + if ((vtend->vid - vtbegin->vid) > 0) + num_tinfos += 2; + else + num_tinfos += 1; + } + + return num_tinfos; +} + +int br_get_vlan_tunnel_info_size(struct net_bridge_vlan_group *vg) +{ + int num_tinfos; + + if (!vg) + return 0; + + rcu_read_lock(); + num_tinfos = __get_num_vlan_tunnel_infos(vg); + rcu_read_unlock(); + + return num_tinfos * __get_vlan_tinfo_size(); +} + +static int br_fill_vlan_tinfo(struct sk_buff *skb, u16 vid, + __be64 tunnel_id, u16 flags) +{ + __be32 tid = tunnel_id_to_key32(tunnel_id); + struct nlattr *tmap; + + tmap = nla_nest_start_noflag(skb, IFLA_BRIDGE_VLAN_TUNNEL_INFO); + if (!tmap) + return -EMSGSIZE; + if (nla_put_u32(skb, IFLA_BRIDGE_VLAN_TUNNEL_ID, + be32_to_cpu(tid))) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_BRIDGE_VLAN_TUNNEL_VID, + vid)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_BRIDGE_VLAN_TUNNEL_FLAGS, + flags)) + goto nla_put_failure; + nla_nest_end(skb, tmap); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, tmap); + + return -EMSGSIZE; +} + +static int br_fill_vlan_tinfo_range(struct sk_buff *skb, + struct net_bridge_vlan *vtbegin, + struct net_bridge_vlan *vtend) +{ + int err; + + if (vtend && (vtend->vid - vtbegin->vid) > 0) { + /* add range to skb */ + err = br_fill_vlan_tinfo(skb, vtbegin->vid, + vtbegin->tinfo.tunnel_id, + BRIDGE_VLAN_INFO_RANGE_BEGIN); + if (err) + return err; + + err = br_fill_vlan_tinfo(skb, vtend->vid, + vtend->tinfo.tunnel_id, + BRIDGE_VLAN_INFO_RANGE_END); + if (err) + return err; + } else { + err = br_fill_vlan_tinfo(skb, vtbegin->vid, + vtbegin->tinfo.tunnel_id, + 0); + if (err) + return err; + } + + return 0; +} + +int br_fill_vlan_tunnel_info(struct sk_buff *skb, + struct net_bridge_vlan_group *vg) +{ + struct net_bridge_vlan *vtbegin = NULL; + struct net_bridge_vlan *vtend = NULL; + struct net_bridge_vlan *v; + int err; + + /* Count number of vlan infos */ + 
list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + /* only a context, bridge vlan not activated */ + if (!br_vlan_should_use(v)) + continue; + + if (!v->tinfo.tunnel_dst) + continue; + + if (!vtbegin) { + goto initvars; + } else if ((v->vid - vtend->vid) == 1 && + vlan_tunid_inrange(v, vtend)) { + vtend = v; + continue; + } else { + err = br_fill_vlan_tinfo_range(skb, vtbegin, vtend); + if (err) + return err; + } +initvars: + vtbegin = v; + vtend = v; + } + + if (vtbegin) { + err = br_fill_vlan_tinfo_range(skb, vtbegin, vtend); + if (err) + return err; + } + + return 0; +} + +static const struct nla_policy vlan_tunnel_policy[IFLA_BRIDGE_VLAN_TUNNEL_MAX + 1] = { + [IFLA_BRIDGE_VLAN_TUNNEL_ID] = { .type = NLA_U32 }, + [IFLA_BRIDGE_VLAN_TUNNEL_VID] = { .type = NLA_U16 }, + [IFLA_BRIDGE_VLAN_TUNNEL_FLAGS] = { .type = NLA_U16 }, +}; + +int br_vlan_tunnel_info(const struct net_bridge_port *p, int cmd, + u16 vid, u32 tun_id, bool *changed) +{ + int err = 0; + + if (!p) + return -EINVAL; + + switch (cmd) { + case RTM_SETLINK: + err = nbp_vlan_tunnel_info_add(p, vid, tun_id); + if (!err) + *changed = true; + break; + case RTM_DELLINK: + if (!nbp_vlan_tunnel_info_delete(p, vid)) + *changed = true; + break; + } + + return err; +} + +int br_parse_vlan_tunnel_info(struct nlattr *attr, + struct vtunnel_info *tinfo) +{ + struct nlattr *tb[IFLA_BRIDGE_VLAN_TUNNEL_MAX + 1]; + u32 tun_id; + u16 vid, flags = 0; + int err; + + memset(tinfo, 0, sizeof(*tinfo)); + + err = nla_parse_nested_deprecated(tb, IFLA_BRIDGE_VLAN_TUNNEL_MAX, + attr, vlan_tunnel_policy, NULL); + if (err < 0) + return err; + + if (!tb[IFLA_BRIDGE_VLAN_TUNNEL_ID] || + !tb[IFLA_BRIDGE_VLAN_TUNNEL_VID]) + return -EINVAL; + + tun_id = nla_get_u32(tb[IFLA_BRIDGE_VLAN_TUNNEL_ID]); + vid = nla_get_u16(tb[IFLA_BRIDGE_VLAN_TUNNEL_VID]); + if (vid >= VLAN_VID_MASK) + return -ERANGE; + + if (tb[IFLA_BRIDGE_VLAN_TUNNEL_FLAGS]) + flags = nla_get_u16(tb[IFLA_BRIDGE_VLAN_TUNNEL_FLAGS]); + + tinfo->tunid = tun_id; + tinfo->vid = vid; + tinfo->flags = flags; + + return 0; +} + +/* send a notification if v_curr can't enter the range and start a new one */ +static void __vlan_tunnel_handle_range(const struct net_bridge_port *p, + struct net_bridge_vlan **v_start, + struct net_bridge_vlan **v_end, + int v_curr, bool curr_change) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + + vg = nbp_vlan_group(p); + if (!vg) + return; + + v = br_vlan_find(vg, v_curr); + + if (!*v_start) + goto out_init; + + if (v && curr_change && br_vlan_can_enter_range(v, *v_end)) { + *v_end = v; + return; + } + + br_vlan_notify(p->br, p, (*v_start)->vid, (*v_end)->vid, RTM_NEWVLAN); +out_init: + /* we start a range only if there are any changes to notify about */ + *v_start = curr_change ? 
v : NULL; + *v_end = *v_start; +} + +int br_process_vlan_tunnel_info(const struct net_bridge *br, + const struct net_bridge_port *p, int cmd, + struct vtunnel_info *tinfo_curr, + struct vtunnel_info *tinfo_last, + bool *changed) +{ + int err; + + if (tinfo_curr->flags & BRIDGE_VLAN_INFO_RANGE_BEGIN) { + if (tinfo_last->flags & BRIDGE_VLAN_INFO_RANGE_BEGIN) + return -EINVAL; + memcpy(tinfo_last, tinfo_curr, sizeof(struct vtunnel_info)); + } else if (tinfo_curr->flags & BRIDGE_VLAN_INFO_RANGE_END) { + struct net_bridge_vlan *v_start = NULL, *v_end = NULL; + int t, v; + + if (!(tinfo_last->flags & BRIDGE_VLAN_INFO_RANGE_BEGIN)) + return -EINVAL; + if ((tinfo_curr->vid - tinfo_last->vid) != + (tinfo_curr->tunid - tinfo_last->tunid)) + return -EINVAL; + t = tinfo_last->tunid; + for (v = tinfo_last->vid; v <= tinfo_curr->vid; v++) { + bool curr_change = false; + + err = br_vlan_tunnel_info(p, cmd, v, t, &curr_change); + if (err) + break; + t++; + + if (curr_change) + *changed = curr_change; + __vlan_tunnel_handle_range(p, &v_start, &v_end, v, + curr_change); + } + if (v_start && v_end) + br_vlan_notify(br, p, v_start->vid, v_end->vid, + RTM_NEWVLAN); + if (err) + return err; + + memset(tinfo_last, 0, sizeof(struct vtunnel_info)); + memset(tinfo_curr, 0, sizeof(struct vtunnel_info)); + } else { + if (tinfo_last->flags) + return -EINVAL; + err = br_vlan_tunnel_info(p, cmd, tinfo_curr->vid, + tinfo_curr->tunid, changed); + if (err) + return err; + br_vlan_notify(br, p, tinfo_curr->vid, 0, RTM_NEWVLAN); + memset(tinfo_last, 0, sizeof(struct vtunnel_info)); + memset(tinfo_curr, 0, sizeof(struct vtunnel_info)); + } + + return 0; +} diff --git a/net/bridge/br_nf_core.c b/net/bridge/br_nf_core.c new file mode 100644 index 000000000..8c69f0c95 --- /dev/null +++ b/net/bridge/br_nf_core.c @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handle firewalling core + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + * Bart De Schuymer + * + * Lennert dedicates this file to Kerstin Wurdinger. + */ + +#include +#include +#include +#include +#include + +#include "br_private.h" +#ifdef CONFIG_SYSCTL +#include +#endif + +static void fake_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ +} + +static void fake_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb) +{ +} + +static u32 *fake_cow_metrics(struct dst_entry *dst, unsigned long old) +{ + return NULL; +} + +static struct neighbour *fake_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr) +{ + return NULL; +} + +static unsigned int fake_mtu(const struct dst_entry *dst) +{ + return dst->dev->mtu; +} + +static struct dst_ops fake_dst_ops = { + .family = AF_INET, + .update_pmtu = fake_update_pmtu, + .redirect = fake_redirect, + .cow_metrics = fake_cow_metrics, + .neigh_lookup = fake_neigh_lookup, + .mtu = fake_mtu, +}; + +/* + * Initialize bogus route table used to keep netfilter happy. + * Currently, we fill in the PMTU entry because netfilter + * refragmentation needs it, and the rt_flags entry because + * ipt_REJECT needs it. Future netfilter modules might + * require us to fill additional fields. 
+ */ +static const u32 br_dst_default_metrics[RTAX_MAX] = { + [RTAX_MTU - 1] = 1500, +}; + +void br_netfilter_rtable_init(struct net_bridge *br) +{ + struct rtable *rt = &br->fake_rtable; + + atomic_set(&rt->dst.__refcnt, 1); + rt->dst.dev = br->dev; + dst_init_metrics(&rt->dst, br_dst_default_metrics, true); + rt->dst.flags = DST_NOXFRM | DST_FAKE_RTABLE; + rt->dst.ops = &fake_dst_ops; +} + +int __init br_nf_core_init(void) +{ + return dst_entries_init(&fake_dst_ops); +} + +void br_nf_core_fini(void) +{ + dst_entries_destroy(&fake_dst_ops); +} diff --git a/net/bridge/br_private.h b/net/bridge/br_private.h new file mode 100644 index 000000000..06e5f6faa --- /dev/null +++ b/net/bridge/br_private.h @@ -0,0 +1,2175 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#ifndef _BR_PRIVATE_H +#define _BR_PRIVATE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define BR_HASH_BITS 8 +#define BR_HASH_SIZE (1 << BR_HASH_BITS) + +#define BR_HOLD_TIME (1*HZ) + +#define BR_PORT_BITS 10 +#define BR_MAX_PORTS (1<flags & BR_AUTO_MASK) +#define br_promisc_port(p) ((p)->flags & BR_PROMISC) + +static inline struct net_bridge_port *br_port_get_rcu(const struct net_device *dev) +{ + return rcu_dereference(dev->rx_handler_data); +} + +static inline struct net_bridge_port *br_port_get_rtnl(const struct net_device *dev) +{ + return netif_is_bridge_port(dev) ? + rtnl_dereference(dev->rx_handler_data) : NULL; +} + +static inline struct net_bridge_port *br_port_get_rtnl_rcu(const struct net_device *dev) +{ + return netif_is_bridge_port(dev) ? + rcu_dereference_rtnl(dev->rx_handler_data) : NULL; +} + +enum net_bridge_opts { + BROPT_VLAN_ENABLED, + BROPT_VLAN_STATS_ENABLED, + BROPT_NF_CALL_IPTABLES, + BROPT_NF_CALL_IP6TABLES, + BROPT_NF_CALL_ARPTABLES, + BROPT_GROUP_ADDR_SET, + BROPT_MULTICAST_ENABLED, + BROPT_MULTICAST_QUERY_USE_IFADDR, + BROPT_MULTICAST_STATS_ENABLED, + BROPT_HAS_IPV6_ADDR, + BROPT_NEIGH_SUPPRESS_ENABLED, + BROPT_MTU_SET_BY_USER, + BROPT_VLAN_STATS_PER_PORT, + BROPT_NO_LL_LEARN, + BROPT_VLAN_BRIDGE_BINDING, + BROPT_MCAST_VLAN_SNOOPING_ENABLED, + BROPT_MST_ENABLED, +}; + +struct net_bridge { + spinlock_t lock; + spinlock_t hash_lock; + struct hlist_head frame_type_list; + struct net_device *dev; + unsigned long options; + /* These fields are accessed on each packet */ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + __be16 vlan_proto; + u16 default_pvid; + struct net_bridge_vlan_group __rcu *vlgrp; +#endif + + struct rhashtable fdb_hash_tbl; + struct list_head port_list; +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + union { + struct rtable fake_rtable; + struct rt6_info fake_rt6_info; + }; +#endif + u16 group_fwd_mask; + u16 group_fwd_mask_required; + + /* STP */ + bridge_id designated_root; + bridge_id bridge_id; + unsigned char topology_change; + unsigned char topology_change_detected; + u16 root_port; + unsigned long max_age; + unsigned long hello_time; + unsigned long forward_delay; + unsigned long ageing_time; + unsigned long bridge_max_age; + unsigned long bridge_hello_time; + unsigned long bridge_forward_delay; + unsigned long bridge_ageing_time; + u32 root_path_cost; + + u8 group_addr[ETH_ALEN]; + + enum { + BR_NO_STP, /* no spanning tree */ + BR_KERNEL_STP, /* old STP in kernel */ + BR_USER_STP, /* new RSTP in userspace */ + } stp_enabled; + + struct net_bridge_mcast multicast_ctx; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + struct bridge_mcast_stats __percpu *mcast_stats; + + u32 
hash_max; + + spinlock_t multicast_lock; + + struct rhashtable mdb_hash_tbl; + struct rhashtable sg_port_tbl; + + struct hlist_head mcast_gc_list; + struct hlist_head mdb_list; + + struct work_struct mcast_gc_work; +#endif + + struct timer_list hello_timer; + struct timer_list tcn_timer; + struct timer_list topology_change_timer; + struct delayed_work gc_work; + struct kobject *ifobj; + u32 auto_cnt; + +#ifdef CONFIG_NET_SWITCHDEV + /* Counter used to make sure that hardware domains get unique + * identifiers in case a bridge spans multiple switchdev instances. + */ + int last_hwdom; + /* Bit mask of hardware domain numbers in use */ + unsigned long busy_hwdoms; +#endif + struct hlist_head fdb_list; + +#if IS_ENABLED(CONFIG_BRIDGE_MRP) + struct hlist_head mrp_list; +#endif +#if IS_ENABLED(CONFIG_BRIDGE_CFM) + struct hlist_head mep_list; +#endif +}; + +struct br_input_skb_cb { + struct net_device *brdev; + + u16 frag_max_size; +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + u8 igmp; + u8 mrouters_only:1; +#endif + u8 proxyarp_replied:1; + u8 src_port_isolated:1; +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + u8 vlan_filtered:1; +#endif +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + u8 br_netfilter_broute:1; +#endif + +#ifdef CONFIG_NET_SWITCHDEV + /* Set if TX data plane offloading is used towards at least one + * hardware domain. + */ + u8 tx_fwd_offload:1; + /* The switchdev hardware domain from which this packet was received. + * If skb->offload_fwd_mark was set, then this packet was already + * forwarded by hardware to the other ports in the source hardware + * domain, otherwise it wasn't. + */ + int src_hwdom; + /* Bit mask of hardware domains towards this packet has already been + * transmitted using the TX data plane offload. + */ + unsigned long fwd_hwdoms; +#endif +}; + +#define BR_INPUT_SKB_CB(__skb) ((struct br_input_skb_cb *)(__skb)->cb) + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +# define BR_INPUT_SKB_CB_MROUTERS_ONLY(__skb) (BR_INPUT_SKB_CB(__skb)->mrouters_only) +#else +# define BR_INPUT_SKB_CB_MROUTERS_ONLY(__skb) (0) +#endif + +#define br_printk(level, br, format, args...) \ + printk(level "%s: " format, (br)->dev->name, ##args) + +#define br_err(__br, format, args...) \ + br_printk(KERN_ERR, __br, format, ##args) +#define br_warn(__br, format, args...) \ + br_printk(KERN_WARNING, __br, format, ##args) +#define br_notice(__br, format, args...) \ + br_printk(KERN_NOTICE, __br, format, ##args) +#define br_info(__br, format, args...) \ + br_printk(KERN_INFO, __br, format, ##args) + +#define br_debug(br, format, args...) 
\ + pr_debug("%s: " format, (br)->dev->name, ##args) + +/* called under bridge lock */ +static inline int br_is_root_bridge(const struct net_bridge *br) +{ + return !memcmp(&br->bridge_id, &br->designated_root, 8); +} + +/* check if a VLAN entry is global */ +static inline bool br_vlan_is_master(const struct net_bridge_vlan *v) +{ + return v->flags & BRIDGE_VLAN_INFO_MASTER; +} + +/* check if a VLAN entry is used by the bridge */ +static inline bool br_vlan_is_brentry(const struct net_bridge_vlan *v) +{ + return v->flags & BRIDGE_VLAN_INFO_BRENTRY; +} + +/* check if we should use the vlan entry, returns false if it's only context */ +static inline bool br_vlan_should_use(const struct net_bridge_vlan *v) +{ + if (br_vlan_is_master(v)) { + if (br_vlan_is_brentry(v)) + return true; + else + return false; + } + + return true; +} + +static inline bool nbp_state_should_learn(const struct net_bridge_port *p) +{ + return p->state == BR_STATE_LEARNING || p->state == BR_STATE_FORWARDING; +} + +static inline bool br_vlan_valid_id(u16 vid, struct netlink_ext_ack *extack) +{ + bool ret = vid > 0 && vid < VLAN_VID_MASK; + + if (!ret) + NL_SET_ERR_MSG_MOD(extack, "Vlan id is invalid"); + + return ret; +} + +static inline bool br_vlan_valid_range(const struct bridge_vlan_info *cur, + const struct bridge_vlan_info *last, + struct netlink_ext_ack *extack) +{ + /* pvid flag is not allowed in ranges */ + if (cur->flags & BRIDGE_VLAN_INFO_PVID) { + NL_SET_ERR_MSG_MOD(extack, "Pvid isn't allowed in a range"); + return false; + } + + /* when cur is the range end, check if: + * - it has range start flag + * - range ids are invalid (end is equal to or before start) + */ + if (last) { + if (cur->flags & BRIDGE_VLAN_INFO_RANGE_BEGIN) { + NL_SET_ERR_MSG_MOD(extack, "Found a new vlan range start while processing one"); + return false; + } else if (!(cur->flags & BRIDGE_VLAN_INFO_RANGE_END)) { + NL_SET_ERR_MSG_MOD(extack, "Vlan range end flag is missing"); + return false; + } else if (cur->vid <= last->vid) { + NL_SET_ERR_MSG_MOD(extack, "End vlan id is less than or equal to start vlan id"); + return false; + } + } + + /* check for required range flags */ + if (!(cur->flags & (BRIDGE_VLAN_INFO_RANGE_BEGIN | + BRIDGE_VLAN_INFO_RANGE_END))) { + NL_SET_ERR_MSG_MOD(extack, "Both vlan range flags are missing"); + return false; + } + + return true; +} + +static inline u8 br_vlan_multicast_router(const struct net_bridge_vlan *v) +{ + u8 mcast_router = MDB_RTR_TYPE_DISABLED; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (!br_vlan_is_master(v)) + mcast_router = v->port_mcast_ctx.multicast_router; + else + mcast_router = v->br_mcast_ctx.multicast_router; +#endif + + return mcast_router; +} + +static inline int br_afspec_cmd_to_rtm(int cmd) +{ + switch (cmd) { + case RTM_SETLINK: + return RTM_NEWVLAN; + case RTM_DELLINK: + return RTM_DELVLAN; + } + + return 0; +} + +static inline int br_opt_get(const struct net_bridge *br, + enum net_bridge_opts opt) +{ + return test_bit(opt, &br->options); +} + +int br_boolopt_toggle(struct net_bridge *br, enum br_boolopt_id opt, bool on, + struct netlink_ext_ack *extack); +int br_boolopt_get(const struct net_bridge *br, enum br_boolopt_id opt); +int br_boolopt_multi_toggle(struct net_bridge *br, + struct br_boolopt_multi *bm, + struct netlink_ext_ack *extack); +void br_boolopt_multi_get(const struct net_bridge *br, + struct br_boolopt_multi *bm); +void br_opt_toggle(struct net_bridge *br, enum net_bridge_opts opt, bool on); + +/* br_device.c */ +void br_dev_setup(struct net_device *dev); +void 
br_dev_delete(struct net_device *dev, struct list_head *list); +netdev_tx_t br_dev_xmit(struct sk_buff *skb, struct net_device *dev); +#ifdef CONFIG_NET_POLL_CONTROLLER +static inline void br_netpoll_send_skb(const struct net_bridge_port *p, + struct sk_buff *skb) +{ + netpoll_send_skb(p->np, skb); +} + +int br_netpoll_enable(struct net_bridge_port *p); +void br_netpoll_disable(struct net_bridge_port *p); +#else +static inline void br_netpoll_send_skb(const struct net_bridge_port *p, + struct sk_buff *skb) +{ +} + +static inline int br_netpoll_enable(struct net_bridge_port *p) +{ + return 0; +} + +static inline void br_netpoll_disable(struct net_bridge_port *p) +{ +} +#endif + +/* br_fdb.c */ +#define FDB_FLUSH_IGNORED_NDM_FLAGS (NTF_MASTER | NTF_SELF) +#define FDB_FLUSH_ALLOWED_NDM_STATES (NUD_PERMANENT | NUD_NOARP) +#define FDB_FLUSH_ALLOWED_NDM_FLAGS (NTF_USE | NTF_EXT_LEARNED | \ + NTF_STICKY | NTF_OFFLOADED) + +int br_fdb_init(void); +void br_fdb_fini(void); +int br_fdb_hash_init(struct net_bridge *br); +void br_fdb_hash_fini(struct net_bridge *br); +void br_fdb_flush(struct net_bridge *br, + const struct net_bridge_fdb_flush_desc *desc); +void br_fdb_find_delete_local(struct net_bridge *br, + const struct net_bridge_port *p, + const unsigned char *addr, u16 vid); +void br_fdb_changeaddr(struct net_bridge_port *p, const unsigned char *newaddr); +void br_fdb_change_mac_address(struct net_bridge *br, const u8 *newaddr); +void br_fdb_cleanup(struct work_struct *work); +void br_fdb_delete_by_port(struct net_bridge *br, + const struct net_bridge_port *p, u16 vid, int do_all); +struct net_bridge_fdb_entry *br_fdb_find_rcu(struct net_bridge *br, + const unsigned char *addr, + __u16 vid); +int br_fdb_test_addr(struct net_device *dev, unsigned char *addr); +int br_fdb_fillbuf(struct net_bridge *br, void *buf, unsigned long count, + unsigned long off); +int br_fdb_add_local(struct net_bridge *br, struct net_bridge_port *source, + const unsigned char *addr, u16 vid); +void br_fdb_update(struct net_bridge *br, struct net_bridge_port *source, + const unsigned char *addr, u16 vid, unsigned long flags); + +int br_fdb_delete(struct ndmsg *ndm, struct nlattr *tb[], + struct net_device *dev, const unsigned char *addr, u16 vid, + struct netlink_ext_ack *extack); +int br_fdb_delete_bulk(struct ndmsg *ndm, struct nlattr *tb[], + struct net_device *dev, u16 vid, + struct netlink_ext_ack *extack); +int br_fdb_add(struct ndmsg *nlh, struct nlattr *tb[], struct net_device *dev, + const unsigned char *addr, u16 vid, u16 nlh_flags, + struct netlink_ext_ack *extack); +int br_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb, + struct net_device *dev, struct net_device *fdev, int *idx); +int br_fdb_get(struct sk_buff *skb, struct nlattr *tb[], struct net_device *dev, + const unsigned char *addr, u16 vid, u32 portid, u32 seq, + struct netlink_ext_ack *extack); +int br_fdb_sync_static(struct net_bridge *br, struct net_bridge_port *p); +void br_fdb_unsync_static(struct net_bridge *br, struct net_bridge_port *p); +int br_fdb_external_learn_add(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, + bool swdev_notify); +int br_fdb_external_learn_del(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, + bool swdev_notify); +void br_fdb_offloaded_set(struct net_bridge *br, struct net_bridge_port *p, + const unsigned char *addr, u16 vid, bool offloaded); + +/* br_forward.c */ +enum br_pkt_type { + BR_PKT_UNICAST, + BR_PKT_MULTICAST, + 
BR_PKT_BROADCAST +}; +int br_dev_queue_push_xmit(struct net *net, struct sock *sk, struct sk_buff *skb); +void br_forward(const struct net_bridge_port *to, struct sk_buff *skb, + bool local_rcv, bool local_orig); +int br_forward_finish(struct net *net, struct sock *sk, struct sk_buff *skb); +void br_flood(struct net_bridge *br, struct sk_buff *skb, + enum br_pkt_type pkt_type, bool local_rcv, bool local_orig); + +/* return true if both source port and dest port are isolated */ +static inline bool br_skb_isolated(const struct net_bridge_port *to, + const struct sk_buff *skb) +{ + return BR_INPUT_SKB_CB(skb)->src_port_isolated && + (to->flags & BR_ISOLATED); +} + +/* br_if.c */ +void br_port_carrier_check(struct net_bridge_port *p, bool *notified); +int br_add_bridge(struct net *net, const char *name); +int br_del_bridge(struct net *net, const char *name); +int br_add_if(struct net_bridge *br, struct net_device *dev, + struct netlink_ext_ack *extack); +int br_del_if(struct net_bridge *br, struct net_device *dev); +void br_mtu_auto_adjust(struct net_bridge *br); +netdev_features_t br_features_recompute(struct net_bridge *br, + netdev_features_t features); +void br_port_flags_change(struct net_bridge_port *port, unsigned long mask); +void br_manage_promisc(struct net_bridge *br); +int nbp_backup_change(struct net_bridge_port *p, struct net_device *backup_dev); + +/* br_input.c */ +int br_handle_frame_finish(struct net *net, struct sock *sk, struct sk_buff *skb); +rx_handler_func_t *br_get_rx_handler(const struct net_device *dev); + +struct br_frame_type { + __be16 type; + int (*frame_handler)(struct net_bridge_port *port, + struct sk_buff *skb); + struct hlist_node list; +}; + +void br_add_frame(struct net_bridge *br, struct br_frame_type *ft); +void br_del_frame(struct net_bridge *br, struct br_frame_type *ft); + +static inline bool br_rx_handler_check_rcu(const struct net_device *dev) +{ + return rcu_dereference(dev->rx_handler) == br_get_rx_handler(dev); +} + +static inline bool br_rx_handler_check_rtnl(const struct net_device *dev) +{ + return rcu_dereference_rtnl(dev->rx_handler) == br_get_rx_handler(dev); +} + +static inline struct net_bridge_port *br_port_get_check_rcu(const struct net_device *dev) +{ + return br_rx_handler_check_rcu(dev) ? br_port_get_rcu(dev) : NULL; +} + +static inline struct net_bridge_port * +br_port_get_check_rtnl(const struct net_device *dev) +{ + return br_rx_handler_check_rtnl(dev) ? 
br_port_get_rtnl_rcu(dev) : NULL; +} + +/* br_ioctl.c */ +int br_dev_siocdevprivate(struct net_device *dev, struct ifreq *rq, + void __user *data, int cmd); +int br_ioctl_stub(struct net *net, struct net_bridge *br, unsigned int cmd, + struct ifreq *ifr, void __user *uarg); + +/* br_multicast.c */ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +int br_multicast_rcv(struct net_bridge_mcast **brmctx, + struct net_bridge_mcast_port **pmctx, + struct net_bridge_vlan *vlan, + struct sk_buff *skb, u16 vid); +struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge_mcast *brmctx, + struct sk_buff *skb, u16 vid); +int br_multicast_add_port(struct net_bridge_port *port); +void br_multicast_del_port(struct net_bridge_port *port); +void br_multicast_enable_port(struct net_bridge_port *port); +void br_multicast_disable_port(struct net_bridge_port *port); +void br_multicast_init(struct net_bridge *br); +void br_multicast_join_snoopers(struct net_bridge *br); +void br_multicast_leave_snoopers(struct net_bridge *br); +void br_multicast_open(struct net_bridge *br); +void br_multicast_stop(struct net_bridge *br); +void br_multicast_dev_del(struct net_bridge *br); +void br_multicast_flood(struct net_bridge_mdb_entry *mdst, struct sk_buff *skb, + struct net_bridge_mcast *brmctx, + bool local_rcv, bool local_orig); +int br_multicast_set_router(struct net_bridge_mcast *brmctx, unsigned long val); +int br_multicast_set_port_router(struct net_bridge_mcast_port *pmctx, + unsigned long val); +int br_multicast_set_vlan_router(struct net_bridge_vlan *v, u8 mcast_router); +int br_multicast_toggle(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack); +int br_multicast_set_querier(struct net_bridge_mcast *brmctx, unsigned long val); +int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val); +int br_multicast_set_igmp_version(struct net_bridge_mcast *brmctx, + unsigned long val); +#if IS_ENABLED(CONFIG_IPV6) +int br_multicast_set_mld_version(struct net_bridge_mcast *brmctx, + unsigned long val); +#endif +struct net_bridge_mdb_entry * +br_mdb_ip_get(struct net_bridge *br, struct br_ip *dst); +struct net_bridge_mdb_entry * +br_multicast_new_group(struct net_bridge *br, struct br_ip *group); +struct net_bridge_port_group * +br_multicast_new_port_group(struct net_bridge_port *port, struct br_ip *group, + struct net_bridge_port_group __rcu *next, + unsigned char flags, const unsigned char *src, + u8 filter_mode, u8 rt_protocol); +int br_mdb_hash_init(struct net_bridge *br); +void br_mdb_hash_fini(struct net_bridge *br); +void br_mdb_notify(struct net_device *dev, struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, int type); +void br_rtr_notify(struct net_device *dev, struct net_bridge_mcast_port *pmctx, + int type); +void br_multicast_del_pg(struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + struct net_bridge_port_group __rcu **pp); +void br_multicast_count(struct net_bridge *br, + const struct net_bridge_port *p, + const struct sk_buff *skb, u8 type, u8 dir); +int br_multicast_init_stats(struct net_bridge *br); +void br_multicast_uninit_stats(struct net_bridge *br); +void br_multicast_get_stats(const struct net_bridge *br, + const struct net_bridge_port *p, + struct br_mcast_stats *dest); +void br_mdb_init(void); +void br_mdb_uninit(void); +void br_multicast_host_join(const struct net_bridge_mcast *brmctx, + struct net_bridge_mdb_entry *mp, bool notify); +void br_multicast_host_leave(struct net_bridge_mdb_entry *mp, bool notify); +void 
br_multicast_star_g_handle_mode(struct net_bridge_port_group *pg, + u8 filter_mode); +void br_multicast_sg_add_exclude_ports(struct net_bridge_mdb_entry *star_mp, + struct net_bridge_port_group *sg); +struct net_bridge_group_src * +br_multicast_find_group_src(struct net_bridge_port_group *pg, struct br_ip *ip); +void br_multicast_del_group_src(struct net_bridge_group_src *src, + bool fastleave); +void br_multicast_ctx_init(struct net_bridge *br, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast *brmctx); +void br_multicast_ctx_deinit(struct net_bridge_mcast *brmctx); +void br_multicast_port_ctx_init(struct net_bridge_port *port, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast_port *pmctx); +void br_multicast_port_ctx_deinit(struct net_bridge_mcast_port *pmctx); +void br_multicast_toggle_one_vlan(struct net_bridge_vlan *vlan, bool on); +int br_multicast_toggle_vlan_snooping(struct net_bridge *br, bool on, + struct netlink_ext_ack *extack); +bool br_multicast_toggle_global_vlan(struct net_bridge_vlan *vlan, bool on); + +int br_rports_fill_info(struct sk_buff *skb, + const struct net_bridge_mcast *brmctx); +int br_multicast_dump_querier_state(struct sk_buff *skb, + const struct net_bridge_mcast *brmctx, + int nest_attr); +size_t br_multicast_querier_state_size(void); +size_t br_rports_size(const struct net_bridge_mcast *brmctx); +void br_multicast_set_query_intvl(struct net_bridge_mcast *brmctx, + unsigned long val); +void br_multicast_set_startup_query_intvl(struct net_bridge_mcast *brmctx, + unsigned long val); + +static inline bool br_group_is_l2(const struct br_ip *group) +{ + return group->proto == 0; +} + +#define mlock_dereference(X, br) \ + rcu_dereference_protected(X, lockdep_is_held(&br->multicast_lock)) + +static inline struct hlist_node * +br_multicast_get_first_rport_node(struct net_bridge_mcast *brmctx, + struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (skb->protocol == htons(ETH_P_IPV6)) + return rcu_dereference(hlist_first_rcu(&brmctx->ip6_mc_router_list)); +#endif + return rcu_dereference(hlist_first_rcu(&brmctx->ip4_mc_router_list)); +} + +static inline struct net_bridge_port * +br_multicast_rport_from_node_skb(struct hlist_node *rp, struct sk_buff *skb) +{ + struct net_bridge_mcast_port *mctx; + +#if IS_ENABLED(CONFIG_IPV6) + if (skb->protocol == htons(ETH_P_IPV6)) + mctx = hlist_entry_safe(rp, struct net_bridge_mcast_port, + ip6_rlist); + else +#endif + mctx = hlist_entry_safe(rp, struct net_bridge_mcast_port, + ip4_rlist); + + if (mctx) + return mctx->port; + else + return NULL; +} + +static inline bool br_ip4_multicast_is_router(struct net_bridge_mcast *brmctx) +{ + return timer_pending(&brmctx->ip4_mc_router_timer); +} + +static inline bool br_ip6_multicast_is_router(struct net_bridge_mcast *brmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + return timer_pending(&brmctx->ip6_mc_router_timer); +#else + return false; +#endif +} + +static inline bool +br_multicast_is_router(struct net_bridge_mcast *brmctx, struct sk_buff *skb) +{ + switch (brmctx->multicast_router) { + case MDB_RTR_TYPE_PERM: + return true; + case MDB_RTR_TYPE_TEMP_QUERY: + if (skb) { + if (skb->protocol == htons(ETH_P_IP)) + return br_ip4_multicast_is_router(brmctx); + else if (skb->protocol == htons(ETH_P_IPV6)) + return br_ip6_multicast_is_router(brmctx); + } else { + return br_ip4_multicast_is_router(brmctx) || + br_ip6_multicast_is_router(brmctx); + } + fallthrough; + default: + return false; + } +} + +static inline bool +__br_multicast_querier_exists(struct net_bridge_mcast 
*brmctx, + struct bridge_mcast_other_query *querier, + const bool is_ipv6) +{ + bool own_querier_enabled; + + if (brmctx->multicast_querier) { + if (is_ipv6 && !br_opt_get(brmctx->br, BROPT_HAS_IPV6_ADDR)) + own_querier_enabled = false; + else + own_querier_enabled = true; + } else { + own_querier_enabled = false; + } + + return time_is_before_jiffies(querier->delay_time) && + (own_querier_enabled || timer_pending(&querier->timer)); +} + +static inline bool br_multicast_querier_exists(struct net_bridge_mcast *brmctx, + struct ethhdr *eth, + const struct net_bridge_mdb_entry *mdb) +{ + switch (eth->h_proto) { + case (htons(ETH_P_IP)): + return __br_multicast_querier_exists(brmctx, + &brmctx->ip4_other_query, false); +#if IS_ENABLED(CONFIG_IPV6) + case (htons(ETH_P_IPV6)): + return __br_multicast_querier_exists(brmctx, + &brmctx->ip6_other_query, true); +#endif + default: + return !!mdb && br_group_is_l2(&mdb->addr); + } +} + +static inline bool br_multicast_is_star_g(const struct br_ip *ip) +{ + switch (ip->proto) { + case htons(ETH_P_IP): + return ipv4_is_zeronet(ip->src.ip4); +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + return ipv6_addr_any(&ip->src.ip6); +#endif + default: + return false; + } +} + +static inline bool +br_multicast_should_handle_mode(const struct net_bridge_mcast *brmctx, + __be16 proto) +{ + switch (proto) { + case htons(ETH_P_IP): + return !!(brmctx->multicast_igmp_version == 3); +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + return !!(brmctx->multicast_mld_version == 2); +#endif + default: + return false; + } +} + +static inline int br_multicast_igmp_type(const struct sk_buff *skb) +{ + return BR_INPUT_SKB_CB(skb)->igmp; +} + +static inline unsigned long br_multicast_lmqt(const struct net_bridge_mcast *brmctx) +{ + return brmctx->multicast_last_member_interval * + brmctx->multicast_last_member_count; +} + +static inline unsigned long br_multicast_gmi(const struct net_bridge_mcast *brmctx) +{ + return brmctx->multicast_membership_interval; +} + +static inline bool +br_multicast_ctx_is_vlan(const struct net_bridge_mcast *brmctx) +{ + return !!brmctx->vlan; +} + +static inline bool +br_multicast_port_ctx_is_vlan(const struct net_bridge_mcast_port *pmctx) +{ + return !!pmctx->vlan; +} + +static inline struct net_bridge_mcast * +br_multicast_port_ctx_get_global(const struct net_bridge_mcast_port *pmctx) +{ + if (!br_multicast_port_ctx_is_vlan(pmctx)) + return &pmctx->port->br->multicast_ctx; + else + return &pmctx->vlan->brvlan->br_mcast_ctx; +} + +static inline bool +br_multicast_ctx_vlan_global_disabled(const struct net_bridge_mcast *brmctx) +{ + return br_multicast_ctx_is_vlan(brmctx) && + (!br_opt_get(brmctx->br, BROPT_MCAST_VLAN_SNOOPING_ENABLED) || + !(brmctx->vlan->priv_flags & BR_VLFLAG_GLOBAL_MCAST_ENABLED)); +} + +static inline bool +br_multicast_ctx_vlan_disabled(const struct net_bridge_mcast *brmctx) +{ + return br_multicast_ctx_is_vlan(brmctx) && + !(brmctx->vlan->priv_flags & BR_VLFLAG_MCAST_ENABLED); +} + +static inline bool +br_multicast_port_ctx_vlan_disabled(const struct net_bridge_mcast_port *pmctx) +{ + return br_multicast_port_ctx_is_vlan(pmctx) && + !(pmctx->vlan->priv_flags & BR_VLFLAG_MCAST_ENABLED); +} + +static inline bool +br_multicast_port_ctx_state_disabled(const struct net_bridge_mcast_port *pmctx) +{ + return pmctx->port->state == BR_STATE_DISABLED || + (br_multicast_port_ctx_is_vlan(pmctx) && + (br_multicast_port_ctx_vlan_disabled(pmctx) || + pmctx->vlan->state == BR_STATE_DISABLED)); +} + +static inline bool 
+br_multicast_port_ctx_state_stopped(const struct net_bridge_mcast_port *pmctx) +{ + return br_multicast_port_ctx_state_disabled(pmctx) || + pmctx->port->state == BR_STATE_BLOCKING || + (br_multicast_port_ctx_is_vlan(pmctx) && + pmctx->vlan->state == BR_STATE_BLOCKING); +} + +static inline bool +br_rports_have_mc_router(const struct net_bridge_mcast *brmctx) +{ +#if IS_ENABLED(CONFIG_IPV6) + return !hlist_empty(&brmctx->ip4_mc_router_list) || + !hlist_empty(&brmctx->ip6_mc_router_list); +#else + return !hlist_empty(&brmctx->ip4_mc_router_list); +#endif +} + +static inline bool +br_multicast_ctx_options_equal(const struct net_bridge_mcast *brmctx1, + const struct net_bridge_mcast *brmctx2) +{ + return brmctx1->multicast_igmp_version == + brmctx2->multicast_igmp_version && + brmctx1->multicast_last_member_count == + brmctx2->multicast_last_member_count && + brmctx1->multicast_startup_query_count == + brmctx2->multicast_startup_query_count && + brmctx1->multicast_last_member_interval == + brmctx2->multicast_last_member_interval && + brmctx1->multicast_membership_interval == + brmctx2->multicast_membership_interval && + brmctx1->multicast_querier_interval == + brmctx2->multicast_querier_interval && + brmctx1->multicast_query_interval == + brmctx2->multicast_query_interval && + brmctx1->multicast_query_response_interval == + brmctx2->multicast_query_response_interval && + brmctx1->multicast_startup_query_interval == + brmctx2->multicast_startup_query_interval && + brmctx1->multicast_querier == brmctx2->multicast_querier && + brmctx1->multicast_router == brmctx2->multicast_router && + !br_rports_have_mc_router(brmctx1) && + !br_rports_have_mc_router(brmctx2) && +#if IS_ENABLED(CONFIG_IPV6) + brmctx1->multicast_mld_version == + brmctx2->multicast_mld_version && +#endif + true; +} + +static inline bool +br_multicast_ctx_matches_vlan_snooping(const struct net_bridge_mcast *brmctx) +{ + bool vlan_snooping_enabled; + + vlan_snooping_enabled = !!br_opt_get(brmctx->br, + BROPT_MCAST_VLAN_SNOOPING_ENABLED); + + return !!(vlan_snooping_enabled == br_multicast_ctx_is_vlan(brmctx)); +} +#else +static inline int br_multicast_rcv(struct net_bridge_mcast **brmctx, + struct net_bridge_mcast_port **pmctx, + struct net_bridge_vlan *vlan, + struct sk_buff *skb, + u16 vid) +{ + return 0; +} + +static inline struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge_mcast *brmctx, + struct sk_buff *skb, u16 vid) +{ + return NULL; +} + +static inline int br_multicast_add_port(struct net_bridge_port *port) +{ + return 0; +} + +static inline void br_multicast_del_port(struct net_bridge_port *port) +{ +} + +static inline void br_multicast_enable_port(struct net_bridge_port *port) +{ +} + +static inline void br_multicast_disable_port(struct net_bridge_port *port) +{ +} + +static inline void br_multicast_init(struct net_bridge *br) +{ +} + +static inline void br_multicast_join_snoopers(struct net_bridge *br) +{ +} + +static inline void br_multicast_leave_snoopers(struct net_bridge *br) +{ +} + +static inline void br_multicast_open(struct net_bridge *br) +{ +} + +static inline void br_multicast_stop(struct net_bridge *br) +{ +} + +static inline void br_multicast_dev_del(struct net_bridge *br) +{ +} + +static inline void br_multicast_flood(struct net_bridge_mdb_entry *mdst, + struct sk_buff *skb, + struct net_bridge_mcast *brmctx, + bool local_rcv, bool local_orig) +{ +} + +static inline bool br_multicast_is_router(struct net_bridge_mcast *brmctx, + struct sk_buff *skb) +{ + return false; +} + +static inline bool 
br_multicast_querier_exists(struct net_bridge_mcast *brmctx, + struct ethhdr *eth, + const struct net_bridge_mdb_entry *mdb) +{ + return false; +} + +static inline void br_mdb_init(void) +{ +} + +static inline void br_mdb_uninit(void) +{ +} + +static inline int br_mdb_hash_init(struct net_bridge *br) +{ + return 0; +} + +static inline void br_mdb_hash_fini(struct net_bridge *br) +{ +} + +static inline void br_multicast_count(struct net_bridge *br, + const struct net_bridge_port *p, + const struct sk_buff *skb, + u8 type, u8 dir) +{ +} + +static inline int br_multicast_init_stats(struct net_bridge *br) +{ + return 0; +} + +static inline void br_multicast_uninit_stats(struct net_bridge *br) +{ +} + +static inline int br_multicast_igmp_type(const struct sk_buff *skb) +{ + return 0; +} + +static inline void br_multicast_ctx_init(struct net_bridge *br, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast *brmctx) +{ +} + +static inline void br_multicast_ctx_deinit(struct net_bridge_mcast *brmctx) +{ +} + +static inline void br_multicast_port_ctx_init(struct net_bridge_port *port, + struct net_bridge_vlan *vlan, + struct net_bridge_mcast_port *pmctx) +{ +} + +static inline void br_multicast_port_ctx_deinit(struct net_bridge_mcast_port *pmctx) +{ +} + +static inline void br_multicast_toggle_one_vlan(struct net_bridge_vlan *vlan, + bool on) +{ +} + +static inline int br_multicast_toggle_vlan_snooping(struct net_bridge *br, + bool on, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline bool br_multicast_toggle_global_vlan(struct net_bridge_vlan *vlan, + bool on) +{ + return false; +} + +static inline bool +br_multicast_ctx_options_equal(const struct net_bridge_mcast *brmctx1, + const struct net_bridge_mcast *brmctx2) +{ + return true; +} +#endif + +/* br_vlan.c */ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING +bool br_allowed_ingress(const struct net_bridge *br, + struct net_bridge_vlan_group *vg, struct sk_buff *skb, + u16 *vid, u8 *state, + struct net_bridge_vlan **vlan); +bool br_allowed_egress(struct net_bridge_vlan_group *vg, + const struct sk_buff *skb); +bool br_should_learn(struct net_bridge_port *p, struct sk_buff *skb, u16 *vid); +struct sk_buff *br_handle_vlan(struct net_bridge *br, + const struct net_bridge_port *port, + struct net_bridge_vlan_group *vg, + struct sk_buff *skb); +int br_vlan_add(struct net_bridge *br, u16 vid, u16 flags, + bool *changed, struct netlink_ext_ack *extack); +int br_vlan_delete(struct net_bridge *br, u16 vid); +void br_vlan_flush(struct net_bridge *br); +struct net_bridge_vlan *br_vlan_find(struct net_bridge_vlan_group *vg, u16 vid); +void br_recalculate_fwd_mask(struct net_bridge *br); +int br_vlan_filter_toggle(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack); +int __br_vlan_set_proto(struct net_bridge *br, __be16 proto, + struct netlink_ext_ack *extack); +int br_vlan_set_proto(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack); +int br_vlan_set_stats(struct net_bridge *br, unsigned long val); +int br_vlan_set_stats_per_port(struct net_bridge *br, unsigned long val); +int br_vlan_init(struct net_bridge *br); +int br_vlan_set_default_pvid(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack); +int __br_vlan_set_default_pvid(struct net_bridge *br, u16 pvid, + struct netlink_ext_ack *extack); +int nbp_vlan_add(struct net_bridge_port *port, u16 vid, u16 flags, + bool *changed, struct netlink_ext_ack *extack); +int nbp_vlan_delete(struct net_bridge_port *port, u16 
vid); +void nbp_vlan_flush(struct net_bridge_port *port); +int nbp_vlan_init(struct net_bridge_port *port, struct netlink_ext_ack *extack); +int nbp_get_num_vlan_infos(struct net_bridge_port *p, u32 filter_mask); +void br_vlan_get_stats(const struct net_bridge_vlan *v, + struct pcpu_sw_netstats *stats); +void br_vlan_port_event(struct net_bridge_port *p, unsigned long event); +int br_vlan_bridge_event(struct net_device *dev, unsigned long event, + void *ptr); +void br_vlan_rtnl_init(void); +void br_vlan_rtnl_uninit(void); +void br_vlan_notify(const struct net_bridge *br, + const struct net_bridge_port *p, + u16 vid, u16 vid_range, + int cmd); +bool br_vlan_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end); + +void br_vlan_fill_forward_path_pvid(struct net_bridge *br, + struct net_device_path_ctx *ctx, + struct net_device_path *path); +int br_vlan_fill_forward_path_mode(struct net_bridge *br, + struct net_bridge_port *dst, + struct net_device_path *path); + +static inline struct net_bridge_vlan_group *br_vlan_group( + const struct net_bridge *br) +{ + return rtnl_dereference(br->vlgrp); +} + +static inline struct net_bridge_vlan_group *nbp_vlan_group( + const struct net_bridge_port *p) +{ + return rtnl_dereference(p->vlgrp); +} + +static inline struct net_bridge_vlan_group *br_vlan_group_rcu( + const struct net_bridge *br) +{ + return rcu_dereference(br->vlgrp); +} + +static inline struct net_bridge_vlan_group *nbp_vlan_group_rcu( + const struct net_bridge_port *p) +{ + return rcu_dereference(p->vlgrp); +} + +/* Since bridge now depends on 8021Q module, but the time bridge sees the + * skb, the vlan tag will always be present if the frame was tagged. + */ +static inline int br_vlan_get_tag(const struct sk_buff *skb, u16 *vid) +{ + int err = 0; + + if (skb_vlan_tag_present(skb)) { + *vid = skb_vlan_tag_get_id(skb); + } else { + *vid = 0; + err = -EINVAL; + } + + return err; +} + +static inline u16 br_get_pvid(const struct net_bridge_vlan_group *vg) +{ + if (!vg) + return 0; + + smp_rmb(); + return vg->pvid; +} + +static inline u16 br_vlan_flags(const struct net_bridge_vlan *v, u16 pvid) +{ + return v->vid == pvid ? 
v->flags | BRIDGE_VLAN_INFO_PVID : v->flags; +} +#else +static inline bool br_allowed_ingress(const struct net_bridge *br, + struct net_bridge_vlan_group *vg, + struct sk_buff *skb, + u16 *vid, u8 *state, + struct net_bridge_vlan **vlan) + +{ + *vlan = NULL; + return true; +} + +static inline bool br_allowed_egress(struct net_bridge_vlan_group *vg, + const struct sk_buff *skb) +{ + return true; +} + +static inline bool br_should_learn(struct net_bridge_port *p, + struct sk_buff *skb, u16 *vid) +{ + return true; +} + +static inline struct sk_buff *br_handle_vlan(struct net_bridge *br, + const struct net_bridge_port *port, + struct net_bridge_vlan_group *vg, + struct sk_buff *skb) +{ + return skb; +} + +static inline int br_vlan_add(struct net_bridge *br, u16 vid, u16 flags, + bool *changed, struct netlink_ext_ack *extack) +{ + *changed = false; + return -EOPNOTSUPP; +} + +static inline int br_vlan_delete(struct net_bridge *br, u16 vid) +{ + return -EOPNOTSUPP; +} + +static inline void br_vlan_flush(struct net_bridge *br) +{ +} + +static inline void br_recalculate_fwd_mask(struct net_bridge *br) +{ +} + +static inline int br_vlan_init(struct net_bridge *br) +{ + return 0; +} + +static inline int nbp_vlan_add(struct net_bridge_port *port, u16 vid, u16 flags, + bool *changed, struct netlink_ext_ack *extack) +{ + *changed = false; + return -EOPNOTSUPP; +} + +static inline int nbp_vlan_delete(struct net_bridge_port *port, u16 vid) +{ + return -EOPNOTSUPP; +} + +static inline void nbp_vlan_flush(struct net_bridge_port *port) +{ +} + +static inline struct net_bridge_vlan *br_vlan_find(struct net_bridge_vlan_group *vg, + u16 vid) +{ + return NULL; +} + +static inline int nbp_vlan_init(struct net_bridge_port *port, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static inline u16 br_vlan_get_tag(const struct sk_buff *skb, u16 *tag) +{ + return 0; +} + +static inline u16 br_get_pvid(const struct net_bridge_vlan_group *vg) +{ + return 0; +} + +static inline int br_vlan_filter_toggle(struct net_bridge *br, + unsigned long val, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline int nbp_get_num_vlan_infos(struct net_bridge_port *p, + u32 filter_mask) +{ + return 0; +} + +static inline void br_vlan_fill_forward_path_pvid(struct net_bridge *br, + struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ +} + +static inline int br_vlan_fill_forward_path_mode(struct net_bridge *br, + struct net_bridge_port *dst, + struct net_device_path *path) +{ + return 0; +} + +static inline struct net_bridge_vlan_group *br_vlan_group( + const struct net_bridge *br) +{ + return NULL; +} + +static inline struct net_bridge_vlan_group *nbp_vlan_group( + const struct net_bridge_port *p) +{ + return NULL; +} + +static inline struct net_bridge_vlan_group *br_vlan_group_rcu( + const struct net_bridge *br) +{ + return NULL; +} + +static inline struct net_bridge_vlan_group *nbp_vlan_group_rcu( + const struct net_bridge_port *p) +{ + return NULL; +} + +static inline void br_vlan_get_stats(const struct net_bridge_vlan *v, + struct pcpu_sw_netstats *stats) +{ +} + +static inline void br_vlan_port_event(struct net_bridge_port *p, + unsigned long event) +{ +} + +static inline int br_vlan_bridge_event(struct net_device *dev, + unsigned long event, void *ptr) +{ + return 0; +} + +static inline void br_vlan_rtnl_init(void) +{ +} + +static inline void br_vlan_rtnl_uninit(void) +{ +} + +static inline void br_vlan_notify(const struct net_bridge *br, + const struct net_bridge_port *p, + u16 vid, 
u16 vid_range, + int cmd) +{ +} + +static inline bool br_vlan_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end) +{ + return true; +} + +static inline u16 br_vlan_flags(const struct net_bridge_vlan *v, u16 pvid) +{ + return 0; +} + +#endif + +/* br_vlan_options.c */ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING +bool br_vlan_opts_eq_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end); +bool br_vlan_opts_fill(struct sk_buff *skb, const struct net_bridge_vlan *v); +size_t br_vlan_opts_nl_size(void); +int br_vlan_process_options(const struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_vlan *range_start, + struct net_bridge_vlan *range_end, + struct nlattr **tb, + struct netlink_ext_ack *extack); +int br_vlan_rtm_process_global_options(struct net_device *dev, + const struct nlattr *attr, + int cmd, + struct netlink_ext_ack *extack); +bool br_vlan_global_opts_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *r_end); +bool br_vlan_global_opts_fill(struct sk_buff *skb, u16 vid, u16 vid_range, + const struct net_bridge_vlan *v_opts); + +/* vlan state manipulation helpers using *_ONCE to annotate lock-free access */ +static inline u8 br_vlan_get_state(const struct net_bridge_vlan *v) +{ + return READ_ONCE(v->state); +} + +static inline void br_vlan_set_state(struct net_bridge_vlan *v, u8 state) +{ + WRITE_ONCE(v->state, state); +} + +static inline u8 br_vlan_get_pvid_state(const struct net_bridge_vlan_group *vg) +{ + return READ_ONCE(vg->pvid_state); +} + +static inline void br_vlan_set_pvid_state(struct net_bridge_vlan_group *vg, + u8 state) +{ + WRITE_ONCE(vg->pvid_state, state); +} + +/* learn_allow is true at ingress and false at egress */ +static inline bool br_vlan_state_allowed(u8 state, bool learn_allow) +{ + switch (state) { + case BR_STATE_LEARNING: + return learn_allow; + case BR_STATE_FORWARDING: + return true; + default: + return false; + } +} +#endif + +/* br_mst.c */ +#ifdef CONFIG_BRIDGE_VLAN_FILTERING +DECLARE_STATIC_KEY_FALSE(br_mst_used); +static inline bool br_mst_is_enabled(struct net_bridge *br) +{ + return static_branch_unlikely(&br_mst_used) && + br_opt_get(br, BROPT_MST_ENABLED); +} + +int br_mst_set_state(struct net_bridge_port *p, u16 msti, u8 state, + struct netlink_ext_ack *extack); +int br_mst_vlan_set_msti(struct net_bridge_vlan *v, u16 msti); +void br_mst_vlan_init_state(struct net_bridge_vlan *v); +int br_mst_set_enabled(struct net_bridge *br, bool on, + struct netlink_ext_ack *extack); +size_t br_mst_info_size(const struct net_bridge_vlan_group *vg); +int br_mst_fill_info(struct sk_buff *skb, + const struct net_bridge_vlan_group *vg); +int br_mst_process(struct net_bridge_port *p, const struct nlattr *mst_attr, + struct netlink_ext_ack *extack); +#else +static inline bool br_mst_is_enabled(struct net_bridge *br) +{ + return false; +} + +static inline int br_mst_set_state(struct net_bridge_port *p, u16 msti, + u8 state, struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline int br_mst_set_enabled(struct net_bridge *br, bool on, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline size_t br_mst_info_size(const struct net_bridge_vlan_group *vg) +{ + return 0; +} + +static inline int br_mst_fill_info(struct sk_buff *skb, + const struct net_bridge_vlan_group *vg) +{ + return -EOPNOTSUPP; +} + +static inline int br_mst_process(struct net_bridge_port *p, + const struct nlattr *mst_attr, + 
struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} +#endif + +struct nf_br_ops { + int (*br_dev_xmit_hook)(struct sk_buff *skb); +}; +extern const struct nf_br_ops __rcu *nf_br_ops; + +/* br_netfilter.c */ +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +int br_nf_core_init(void); +void br_nf_core_fini(void); +void br_netfilter_rtable_init(struct net_bridge *); +#else +static inline int br_nf_core_init(void) { return 0; } +static inline void br_nf_core_fini(void) {} +#define br_netfilter_rtable_init(x) +#endif + +/* br_stp.c */ +void br_set_state(struct net_bridge_port *p, unsigned int state); +struct net_bridge_port *br_get_port(struct net_bridge *br, u16 port_no); +void br_init_port(struct net_bridge_port *p); +void br_become_designated_port(struct net_bridge_port *p); + +void __br_set_forward_delay(struct net_bridge *br, unsigned long t); +int br_set_forward_delay(struct net_bridge *br, unsigned long x); +int br_set_hello_time(struct net_bridge *br, unsigned long x); +int br_set_max_age(struct net_bridge *br, unsigned long x); +int __set_ageing_time(struct net_device *dev, unsigned long t); +int br_set_ageing_time(struct net_bridge *br, clock_t ageing_time); + + +/* br_stp_if.c */ +void br_stp_enable_bridge(struct net_bridge *br); +void br_stp_disable_bridge(struct net_bridge *br); +int br_stp_set_enabled(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack); +void br_stp_enable_port(struct net_bridge_port *p); +void br_stp_disable_port(struct net_bridge_port *p); +bool br_stp_recalculate_bridge_id(struct net_bridge *br); +void br_stp_change_bridge_id(struct net_bridge *br, const unsigned char *a); +void br_stp_set_bridge_priority(struct net_bridge *br, u16 newprio); +int br_stp_set_port_priority(struct net_bridge_port *p, unsigned long newprio); +int br_stp_set_path_cost(struct net_bridge_port *p, unsigned long path_cost); +ssize_t br_show_bridge_id(char *buf, const struct bridge_id *id); + +/* br_stp_bpdu.c */ +struct stp_proto; +void br_stp_rcv(const struct stp_proto *proto, struct sk_buff *skb, + struct net_device *dev); + +/* br_stp_timer.c */ +void br_stp_timer_init(struct net_bridge *br); +void br_stp_port_timer_init(struct net_bridge_port *p); +unsigned long br_timer_value(const struct timer_list *timer); + +/* br.c */ +#if IS_ENABLED(CONFIG_ATM_LANE) +extern int (*br_fdb_test_addr_hook)(struct net_device *dev, unsigned char *addr); +#endif + +/* br_mrp.c */ +#if IS_ENABLED(CONFIG_BRIDGE_MRP) +int br_mrp_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, struct netlink_ext_ack *extack); +bool br_mrp_enabled(struct net_bridge *br); +void br_mrp_port_del(struct net_bridge *br, struct net_bridge_port *p); +int br_mrp_fill_info(struct sk_buff *skb, struct net_bridge *br); +#else +static inline int br_mrp_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline bool br_mrp_enabled(struct net_bridge *br) +{ + return false; +} + +static inline void br_mrp_port_del(struct net_bridge *br, + struct net_bridge_port *p) +{ +} + +static inline int br_mrp_fill_info(struct sk_buff *skb, struct net_bridge *br) +{ + return 0; +} + +#endif + +/* br_cfm.c */ +#if IS_ENABLED(CONFIG_BRIDGE_CFM) +int br_cfm_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, struct netlink_ext_ack *extack); +bool br_cfm_created(struct net_bridge *br); +void br_cfm_port_del(struct net_bridge *br, struct 
net_bridge_port *p); +int br_cfm_config_fill_info(struct sk_buff *skb, struct net_bridge *br); +int br_cfm_status_fill_info(struct sk_buff *skb, + struct net_bridge *br, + bool getlink); +int br_cfm_mep_count(struct net_bridge *br, u32 *count); +int br_cfm_peer_mep_count(struct net_bridge *br, u32 *count); +#else +static inline int br_cfm_parse(struct net_bridge *br, struct net_bridge_port *p, + struct nlattr *attr, int cmd, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline bool br_cfm_created(struct net_bridge *br) +{ + return false; +} + +static inline void br_cfm_port_del(struct net_bridge *br, + struct net_bridge_port *p) +{ +} + +static inline int br_cfm_config_fill_info(struct sk_buff *skb, struct net_bridge *br) +{ + return -EOPNOTSUPP; +} + +static inline int br_cfm_status_fill_info(struct sk_buff *skb, + struct net_bridge *br, + bool getlink) +{ + return -EOPNOTSUPP; +} + +static inline int br_cfm_mep_count(struct net_bridge *br, u32 *count) +{ + *count = 0; + return -EOPNOTSUPP; +} + +static inline int br_cfm_peer_mep_count(struct net_bridge *br, u32 *count) +{ + *count = 0; + return -EOPNOTSUPP; +} +#endif + +/* br_netlink.c */ +extern struct rtnl_link_ops br_link_ops; +int br_netlink_init(void); +void br_netlink_fini(void); +void br_ifinfo_notify(int event, const struct net_bridge *br, + const struct net_bridge_port *port); +void br_info_notify(int event, const struct net_bridge *br, + const struct net_bridge_port *port, u32 filter); +int br_setlink(struct net_device *dev, struct nlmsghdr *nlmsg, u16 flags, + struct netlink_ext_ack *extack); +int br_dellink(struct net_device *dev, struct nlmsghdr *nlmsg, u16 flags); +int br_getlink(struct sk_buff *skb, u32 pid, u32 seq, struct net_device *dev, + u32 filter_mask, int nlflags); +int br_process_vlan_info(struct net_bridge *br, + struct net_bridge_port *p, int cmd, + struct bridge_vlan_info *vinfo_curr, + struct bridge_vlan_info **vinfo_last, + bool *changed, + struct netlink_ext_ack *extack); + +#ifdef CONFIG_SYSFS +/* br_sysfs_if.c */ +extern const struct sysfs_ops brport_sysfs_ops; +int br_sysfs_addif(struct net_bridge_port *p); +int br_sysfs_renameif(struct net_bridge_port *p); + +/* br_sysfs_br.c */ +int br_sysfs_addbr(struct net_device *dev); +void br_sysfs_delbr(struct net_device *dev); + +#else + +static inline int br_sysfs_addif(struct net_bridge_port *p) { return 0; } +static inline int br_sysfs_renameif(struct net_bridge_port *p) { return 0; } +static inline int br_sysfs_addbr(struct net_device *dev) { return 0; } +static inline void br_sysfs_delbr(struct net_device *dev) { return; } +#endif /* CONFIG_SYSFS */ + +/* br_switchdev.c */ +#ifdef CONFIG_NET_SWITCHDEV +int br_switchdev_port_offload(struct net_bridge_port *p, + struct net_device *dev, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb, + bool tx_fwd_offload, + struct netlink_ext_ack *extack); + +void br_switchdev_port_unoffload(struct net_bridge_port *p, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb); + +bool br_switchdev_frame_uses_tx_fwd_offload(struct sk_buff *skb); + +void br_switchdev_frame_set_offload_fwd_mark(struct sk_buff *skb); + +void nbp_switchdev_frame_mark_tx_fwd_offload(const struct net_bridge_port *p, + struct sk_buff *skb); +void nbp_switchdev_frame_mark_tx_fwd_to_hwdom(const struct net_bridge_port *p, + struct sk_buff *skb); +void nbp_switchdev_frame_mark(const struct net_bridge_port *p, + struct sk_buff *skb); +bool 
nbp_switchdev_allowed_egress(const struct net_bridge_port *p, + const struct sk_buff *skb); +int br_switchdev_set_port_flag(struct net_bridge_port *p, + unsigned long flags, + unsigned long mask, + struct netlink_ext_ack *extack); +void br_switchdev_fdb_notify(struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb, int type); +void br_switchdev_mdb_notify(struct net_device *dev, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + int type); +int br_switchdev_port_vlan_add(struct net_device *dev, u16 vid, u16 flags, + bool changed, struct netlink_ext_ack *extack); +int br_switchdev_port_vlan_del(struct net_device *dev, u16 vid); +void br_switchdev_init(struct net_bridge *br); + +static inline void br_switchdev_frame_unmark(struct sk_buff *skb) +{ + skb->offload_fwd_mark = 0; +} +#else +static inline int +br_switchdev_port_offload(struct net_bridge_port *p, + struct net_device *dev, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb, + bool tx_fwd_offload, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline void +br_switchdev_port_unoffload(struct net_bridge_port *p, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb) +{ +} + +static inline bool br_switchdev_frame_uses_tx_fwd_offload(struct sk_buff *skb) +{ + return false; +} + +static inline void br_switchdev_frame_set_offload_fwd_mark(struct sk_buff *skb) +{ +} + +static inline void +nbp_switchdev_frame_mark_tx_fwd_offload(const struct net_bridge_port *p, + struct sk_buff *skb) +{ +} + +static inline void +nbp_switchdev_frame_mark_tx_fwd_to_hwdom(const struct net_bridge_port *p, + struct sk_buff *skb) +{ +} + +static inline void nbp_switchdev_frame_mark(const struct net_bridge_port *p, + struct sk_buff *skb) +{ +} + +static inline bool nbp_switchdev_allowed_egress(const struct net_bridge_port *p, + const struct sk_buff *skb) +{ + return true; +} + +static inline int br_switchdev_set_port_flag(struct net_bridge_port *p, + unsigned long flags, + unsigned long mask, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static inline int br_switchdev_port_vlan_add(struct net_device *dev, u16 vid, + u16 flags, bool changed, + struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static inline int br_switchdev_port_vlan_del(struct net_device *dev, u16 vid) +{ + return -EOPNOTSUPP; +} + +static inline void +br_switchdev_fdb_notify(struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb, int type) +{ +} + +static inline void br_switchdev_mdb_notify(struct net_device *dev, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + int type) +{ +} + +static inline void br_switchdev_frame_unmark(struct sk_buff *skb) +{ +} + +static inline void br_switchdev_init(struct net_bridge *br) +{ +} + +#endif /* CONFIG_NET_SWITCHDEV */ + +/* br_arp_nd_proxy.c */ +void br_recalculate_neigh_suppress_enabled(struct net_bridge *br); +void br_do_proxy_suppress_arp(struct sk_buff *skb, struct net_bridge *br, + u16 vid, struct net_bridge_port *p); +void br_do_suppress_nd(struct sk_buff *skb, struct net_bridge *br, + u16 vid, struct net_bridge_port *p, struct nd_msg *msg); +struct nd_msg *br_is_nd_neigh_msg(struct sk_buff *skb, struct nd_msg *m); +#endif diff --git a/net/bridge/br_private_cfm.h b/net/bridge/br_private_cfm.h new file mode 100644 index 000000000..a43a5e7fa --- /dev/null +++ b/net/bridge/br_private_cfm.h @@ -0,0 +1,147 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef 
_BR_PRIVATE_CFM_H_ +#define _BR_PRIVATE_CFM_H_ + +#include "br_private.h" +#include + +struct br_cfm_mep_create { + enum br_cfm_domain domain; /* Domain for this MEP */ + enum br_cfm_mep_direction direction; /* Up or Down MEP direction */ + u32 ifindex; /* Residence port */ +}; + +int br_cfm_mep_create(struct net_bridge *br, + const u32 instance, + struct br_cfm_mep_create *const create, + struct netlink_ext_ack *extack); + +int br_cfm_mep_delete(struct net_bridge *br, + const u32 instance, + struct netlink_ext_ack *extack); + +struct br_cfm_mep_config { + u32 mdlevel; + u32 mepid; /* MEPID for this MEP */ + struct mac_addr unicast_mac; /* The MEP unicast MAC */ +}; + +int br_cfm_mep_config_set(struct net_bridge *br, + const u32 instance, + const struct br_cfm_mep_config *const config, + struct netlink_ext_ack *extack); + +struct br_cfm_maid { + u8 data[CFM_MAID_LENGTH]; +}; + +struct br_cfm_cc_config { + /* Expected received CCM PDU MAID. */ + struct br_cfm_maid exp_maid; + + /* Expected received CCM PDU interval. */ + /* Transmitting CCM PDU interval when CCM tx is enabled. */ + enum br_cfm_ccm_interval exp_interval; + + bool enable; /* Enable/disable CCM PDU handling */ +}; + +int br_cfm_cc_config_set(struct net_bridge *br, + const u32 instance, + const struct br_cfm_cc_config *const config, + struct netlink_ext_ack *extack); + +int br_cfm_cc_peer_mep_add(struct net_bridge *br, const u32 instance, + u32 peer_mep_id, + struct netlink_ext_ack *extack); +int br_cfm_cc_peer_mep_remove(struct net_bridge *br, const u32 instance, + u32 peer_mep_id, + struct netlink_ext_ack *extack); + +/* Transmitted CCM Remote Defect Indication status set. + * This RDI is inserted in transmitted CCM PDUs if CCM transmission is enabled. + * See br_cfm_cc_ccm_tx() with interval != BR_CFM_CCM_INTERVAL_NONE + */ +int br_cfm_cc_rdi_set(struct net_bridge *br, const u32 instance, + const bool rdi, struct netlink_ext_ack *extack); + +/* OAM PDU Tx information */ +struct br_cfm_cc_ccm_tx_info { + struct mac_addr dmac; + /* The CCM will be transmitted for this period in seconds. + * Call br_cfm_cc_ccm_tx before timeout to keep transmission alive. + * When period is zero any ongoing transmission will be stopped. + */ + u32 period; + + bool seq_no_update; /* Update Tx CCM sequence number */ + bool if_tlv; /* Insert Interface Status TLV */ + u8 if_tlv_value; /* Interface Status TLV value */ + bool port_tlv; /* Insert Port Status TLV */ + u8 port_tlv_value; /* Port Status TLV value */ + /* Sender ID TLV ?? + * Organization-Specific TLV ?? + */ +}; + +int br_cfm_cc_ccm_tx(struct net_bridge *br, const u32 instance, + const struct br_cfm_cc_ccm_tx_info *const tx_info, + struct netlink_ext_ack *extack); + +struct br_cfm_mep_status { + /* Indications that an OAM PDU has been seen. */ + bool opcode_unexp_seen; /* RX of OAM PDU with unexpected opcode */ + bool version_unexp_seen; /* RX of OAM PDU with unexpected version */ + bool rx_level_low_seen; /* Rx of OAM PDU with level low */ +}; + +struct br_cfm_cc_peer_status { + /* This CCM related status is based on the latest received CCM PDU. */ + u8 port_tlv_value; /* Port Status TLV value */ + u8 if_tlv_value; /* Interface Status TLV value */ + + /* CCM has not been received for 3.25 intervals */ + u8 ccm_defect:1; + + /* (RDI == 1) for last received CCM PDU */ + u8 rdi:1; + + /* Indications that a CCM PDU has been seen. 
*/ + u8 seen:1; /* CCM PDU received */ + u8 tlv_seen:1; /* CCM PDU with TLV received */ + /* CCM PDU with unexpected sequence number received */ + u8 seq_unexp_seen:1; +}; + +struct br_cfm_mep { + /* list header of MEP instances */ + struct hlist_node head; + u32 instance; + struct br_cfm_mep_create create; + struct br_cfm_mep_config config; + struct br_cfm_cc_config cc_config; + struct br_cfm_cc_ccm_tx_info cc_ccm_tx_info; + /* List of multiple peer MEPs */ + struct hlist_head peer_mep_list; + struct net_bridge_port __rcu *b_port; + unsigned long ccm_tx_end; + struct delayed_work ccm_tx_dwork; + u32 ccm_tx_snumber; + u32 ccm_rx_snumber; + struct br_cfm_mep_status status; + bool rdi; + struct rcu_head rcu; +}; + +struct br_cfm_peer_mep { + struct hlist_node head; + struct br_cfm_mep *mep; + struct delayed_work ccm_rx_dwork; + u32 mepid; + struct br_cfm_cc_peer_status cc_status; + u32 ccm_rx_count_miss; + struct rcu_head rcu; +}; + +#endif /* _BR_PRIVATE_CFM_H_ */ diff --git a/net/bridge/br_private_mcast_eht.h b/net/bridge/br_private_mcast_eht.h new file mode 100644 index 000000000..adf82a055 --- /dev/null +++ b/net/bridge/br_private_mcast_eht.h @@ -0,0 +1,94 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later + * Copyright (c) 2020, Nikolay Aleksandrov + */ +#ifndef _BR_PRIVATE_MCAST_EHT_H_ +#define _BR_PRIVATE_MCAST_EHT_H_ + +#define BR_MCAST_DEFAULT_EHT_HOSTS_LIMIT 512 + +union net_bridge_eht_addr { + __be32 ip4; +#if IS_ENABLED(CONFIG_IPV6) + struct in6_addr ip6; +#endif +}; + +/* single host's list of set entries and filter_mode */ +struct net_bridge_group_eht_host { + struct rb_node rb_node; + + union net_bridge_eht_addr h_addr; + struct hlist_head set_entries; + unsigned int num_entries; + unsigned char filter_mode; + struct net_bridge_port_group *pg; +}; + +/* (host, src entry) added to a per-src set and host's list */ +struct net_bridge_group_eht_set_entry { + struct rb_node rb_node; + struct hlist_node host_list; + + union net_bridge_eht_addr h_addr; + struct timer_list timer; + struct net_bridge *br; + struct net_bridge_group_eht_set *eht_set; + struct net_bridge_group_eht_host *h_parent; + struct net_bridge_mcast_gc mcast_gc; +}; + +/* per-src set */ +struct net_bridge_group_eht_set { + struct rb_node rb_node; + + union net_bridge_eht_addr src_addr; + struct rb_root entry_tree; + struct timer_list timer; + struct net_bridge_port_group *pg; + struct net_bridge *br; + struct net_bridge_mcast_gc mcast_gc; +}; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +void br_multicast_eht_clean_sets(struct net_bridge_port_group *pg); +bool br_multicast_eht_handle(const struct net_bridge_mcast *brmctx, + struct net_bridge_port_group *pg, + void *h_addr, + void *srcs, + u32 nsrcs, + size_t addr_size, + int grec_type); +int br_multicast_eht_set_hosts_limit(struct net_bridge_port *p, + u32 eht_hosts_limit); + +static inline bool +br_multicast_eht_should_del_pg(const struct net_bridge_port_group *pg) +{ + return !!((pg->key.port->flags & BR_MULTICAST_FAST_LEAVE) && + RB_EMPTY_ROOT(&pg->eht_host_tree)); +} + +static inline bool +br_multicast_eht_hosts_over_limit(const struct net_bridge_port_group *pg) +{ + const struct net_bridge_port *p = pg->key.port; + + return !!(p->multicast_eht_hosts_cnt >= p->multicast_eht_hosts_limit); +} + +static inline void br_multicast_eht_hosts_inc(struct net_bridge_port_group *pg) +{ + struct net_bridge_port *p = pg->key.port; + + p->multicast_eht_hosts_cnt++; +} + +static inline void br_multicast_eht_hosts_dec(struct net_bridge_port_group *pg) +{ + struct net_bridge_port *p = 
pg->key.port; + + p->multicast_eht_hosts_cnt--; +} +#endif /* CONFIG_BRIDGE_IGMP_SNOOPING */ + +#endif /* _BR_PRIVATE_MCAST_EHT_H_ */ diff --git a/net/bridge/br_private_mrp.h b/net/bridge/br_private_mrp.h new file mode 100644 index 000000000..bda8e1896 --- /dev/null +++ b/net/bridge/br_private_mrp.h @@ -0,0 +1,148 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +#ifndef _BR_PRIVATE_MRP_H_ +#define _BR_PRIVATE_MRP_H_ + +#include "br_private.h" +#include + +#define MRP_OPT_PADDING 0x2 + +struct br_mrp { + /* list of mrp instances */ + struct hlist_node list; + + struct net_bridge_port __rcu *p_port; + struct net_bridge_port __rcu *s_port; + struct net_bridge_port __rcu *i_port; + + u32 ring_id; + u16 in_id; + u16 prio; + + enum br_mrp_ring_role_type ring_role; + u8 ring_role_offloaded; + enum br_mrp_ring_state_type ring_state; + u32 ring_transitions; + + enum br_mrp_in_role_type in_role; + u8 in_role_offloaded; + enum br_mrp_in_state_type in_state; + u32 in_transitions; + + struct delayed_work test_work; + u32 test_interval; + unsigned long test_end; + u32 test_count_miss; + u32 test_max_miss; + bool test_monitor; + + struct delayed_work in_test_work; + u32 in_test_interval; + unsigned long in_test_end; + u32 in_test_count_miss; + u32 in_test_max_miss; + + u32 seq_id; + + struct rcu_head rcu; +}; + +/* This type is returned by br_mrp_switchdev functions that allow to have a SW + * backup in case the HW can't implement completely the protocol. + * BR_MRP_NONE - means the HW can't run at all the protocol, so the SW stops + * configuring the node anymore. + * BR_MRP_SW - the HW can help the SW to run the protocol, by redirecting MRP + * frames to CPU. + * BR_MRP_HW - the HW can implement completely the protocol. + */ +enum br_mrp_hw_support { + BR_MRP_NONE, + BR_MRP_SW, + BR_MRP_HW, +}; + +/* br_mrp.c */ +int br_mrp_add(struct net_bridge *br, struct br_mrp_instance *instance); +int br_mrp_del(struct net_bridge *br, struct br_mrp_instance *instance); +int br_mrp_set_port_state(struct net_bridge_port *p, + enum br_mrp_port_state_type state); +int br_mrp_set_port_role(struct net_bridge_port *p, + enum br_mrp_port_role_type role); +int br_mrp_set_ring_state(struct net_bridge *br, + struct br_mrp_ring_state *state); +int br_mrp_set_ring_role(struct net_bridge *br, struct br_mrp_ring_role *role); +int br_mrp_start_test(struct net_bridge *br, struct br_mrp_start_test *test); +int br_mrp_set_in_state(struct net_bridge *br, struct br_mrp_in_state *state); +int br_mrp_set_in_role(struct net_bridge *br, struct br_mrp_in_role *role); +int br_mrp_start_in_test(struct net_bridge *br, + struct br_mrp_start_in_test *test); + +/* br_mrp_switchdev.c */ +int br_mrp_switchdev_add(struct net_bridge *br, struct br_mrp *mrp); +int br_mrp_switchdev_del(struct net_bridge *br, struct br_mrp *mrp); +enum br_mrp_hw_support +br_mrp_switchdev_set_ring_role(struct net_bridge *br, struct br_mrp *mrp, + enum br_mrp_ring_role_type role); +int br_mrp_switchdev_set_ring_state(struct net_bridge *br, struct br_mrp *mrp, + enum br_mrp_ring_state_type state); +enum br_mrp_hw_support +br_mrp_switchdev_send_ring_test(struct net_bridge *br, struct br_mrp *mrp, + u32 interval, u8 max_miss, u32 period, + bool monitor); +int br_mrp_port_switchdev_set_state(struct net_bridge_port *p, u32 state); +int br_mrp_port_switchdev_set_role(struct net_bridge_port *p, + enum br_mrp_port_role_type role); +enum br_mrp_hw_support +br_mrp_switchdev_set_in_role(struct net_bridge *br, struct br_mrp *mrp, + u16 in_id, u32 ring_id, + enum 
br_mrp_in_role_type role); +int br_mrp_switchdev_set_in_state(struct net_bridge *br, struct br_mrp *mrp, + enum br_mrp_in_state_type state); +enum br_mrp_hw_support +br_mrp_switchdev_send_in_test(struct net_bridge *br, struct br_mrp *mrp, + u32 interval, u8 max_miss, u32 period); + +/* br_mrp_netlink.c */ +int br_mrp_ring_port_open(struct net_device *dev, u8 loc); +int br_mrp_in_port_open(struct net_device *dev, u8 loc); + +/* MRP protocol data units */ +struct br_mrp_tlv_hdr { + __u8 type; + __u8 length; +}; + +struct br_mrp_common_hdr { + __be16 seq_id; + __u8 domain[MRP_DOMAIN_UUID_LENGTH]; +}; + +struct br_mrp_ring_test_hdr { + __be16 prio; + __u8 sa[ETH_ALEN]; + __be16 port_role; + __be16 state; + __be16 transitions; + __be32 timestamp; +} __attribute__((__packed__)); + +struct br_mrp_in_test_hdr { + __be16 id; + __u8 sa[ETH_ALEN]; + __be16 port_role; + __be16 state; + __be16 transitions; + __be32 timestamp; +} __attribute__((__packed__)); + +struct br_mrp_oui_hdr { + __u8 oui[MRP_OUI_LENGTH]; +}; + +struct br_mrp_sub_option1_hdr { + __u8 type; + __u8 data[MRP_MANUFACTURE_DATA_LENGTH]; +}; + +#endif /* _BR_PRIVATE_MRP_H */ diff --git a/net/bridge/br_private_stp.h b/net/bridge/br_private_stp.h new file mode 100644 index 000000000..814cf1364 --- /dev/null +++ b/net/bridge/br_private_stp.h @@ -0,0 +1,66 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#ifndef _BR_PRIVATE_STP_H +#define _BR_PRIVATE_STP_H + +#define BPDU_TYPE_CONFIG 0 +#define BPDU_TYPE_TCN 0x80 + +/* IEEE 802.1D-1998 timer values */ +#define BR_MIN_HELLO_TIME (1*HZ) +#define BR_MAX_HELLO_TIME (10*HZ) + +#define BR_MIN_FORWARD_DELAY (2*HZ) +#define BR_MAX_FORWARD_DELAY (30*HZ) + +#define BR_MIN_MAX_AGE (6*HZ) +#define BR_MAX_MAX_AGE (40*HZ) + +#define BR_MIN_PATH_COST 1 +#define BR_MAX_PATH_COST 65535 + +struct br_config_bpdu { + unsigned int topology_change:1; + unsigned int topology_change_ack:1; + bridge_id root; + int root_path_cost; + bridge_id bridge_id; + port_id port_id; + int message_age; + int max_age; + int hello_time; + int forward_delay; +}; + +/* called under bridge lock */ +static inline int br_is_designated_port(const struct net_bridge_port *p) +{ + return !memcmp(&p->designated_bridge, &p->br->bridge_id, 8) && + (p->designated_port == p->port_id); +} + + +/* br_stp.c */ +void br_become_root_bridge(struct net_bridge *br); +void br_config_bpdu_generation(struct net_bridge *); +void br_configuration_update(struct net_bridge *); +void br_port_state_selection(struct net_bridge *); +void br_received_config_bpdu(struct net_bridge_port *p, + const struct br_config_bpdu *bpdu); +void br_received_tcn_bpdu(struct net_bridge_port *p); +void br_transmit_config(struct net_bridge_port *p); +void br_transmit_tcn(struct net_bridge *br); +void br_topology_change_detection(struct net_bridge *br); +void __br_set_topology_change(struct net_bridge *br, unsigned char val); + +/* br_stp_bpdu.c */ +void br_send_config_bpdu(struct net_bridge_port *, struct br_config_bpdu *); +void br_send_tcn_bpdu(struct net_bridge_port *); + +#endif diff --git a/net/bridge/br_private_tunnel.h b/net/bridge/br_private_tunnel.h new file mode 100644 index 000000000..efb096025 --- /dev/null +++ b/net/bridge/br_private_tunnel.h @@ -0,0 +1,85 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Bridge per vlan tunnels + * + * Authors: + * Roopa Prabhu + */ + +#ifndef _BR_PRIVATE_TUNNEL_H +#define _BR_PRIVATE_TUNNEL_H + +struct vtunnel_info { + u32 tunid; + u16 vid; + 
u16 flags; +}; + +/* br_netlink_tunnel.c */ +int br_parse_vlan_tunnel_info(struct nlattr *attr, + struct vtunnel_info *tinfo); +int br_process_vlan_tunnel_info(const struct net_bridge *br, + const struct net_bridge_port *p, + int cmd, + struct vtunnel_info *tinfo_curr, + struct vtunnel_info *tinfo_last, + bool *changed); +int br_get_vlan_tunnel_info_size(struct net_bridge_vlan_group *vg); +int br_fill_vlan_tunnel_info(struct sk_buff *skb, + struct net_bridge_vlan_group *vg); +bool vlan_tunid_inrange(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *v_last); +int br_vlan_tunnel_info(const struct net_bridge_port *p, int cmd, + u16 vid, u32 tun_id, bool *changed); + +#ifdef CONFIG_BRIDGE_VLAN_FILTERING +/* br_vlan_tunnel.c */ +int vlan_tunnel_init(struct net_bridge_vlan_group *vg); +void vlan_tunnel_deinit(struct net_bridge_vlan_group *vg); +int nbp_vlan_tunnel_info_delete(const struct net_bridge_port *port, u16 vid); +int nbp_vlan_tunnel_info_add(const struct net_bridge_port *port, u16 vid, + u32 tun_id); +void nbp_vlan_tunnel_info_flush(struct net_bridge_port *port); +void vlan_tunnel_info_del(struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *vlan); +void br_handle_ingress_vlan_tunnel(struct sk_buff *skb, + struct net_bridge_port *p, + struct net_bridge_vlan_group *vg); +int br_handle_egress_vlan_tunnel(struct sk_buff *skb, + struct net_bridge_vlan *vlan); +#else +static inline int vlan_tunnel_init(struct net_bridge_vlan_group *vg) +{ + return 0; +} + +static inline int nbp_vlan_tunnel_info_delete(const struct net_bridge_port *port, + u16 vid) +{ + return 0; +} + +static inline int nbp_vlan_tunnel_info_add(const struct net_bridge_port *port, + u16 vid, u32 tun_id) +{ + return 0; +} + +static inline void nbp_vlan_tunnel_info_flush(struct net_bridge_port *port) +{ +} + +static inline void vlan_tunnel_info_del(struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *vlan) +{ +} + +static inline int br_handle_ingress_vlan_tunnel(struct sk_buff *skb, + struct net_bridge_port *p, + struct net_bridge_vlan_group *vg) +{ + return 0; +} +#endif + +#endif diff --git a/net/bridge/br_stp.c b/net/bridge/br_stp.c new file mode 100644 index 000000000..7d27b2e60 --- /dev/null +++ b/net/bridge/br_stp.c @@ -0,0 +1,713 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Spanning tree protocol; generic parts + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ +#include +#include +#include + +#include "br_private.h" +#include "br_private_stp.h" + +/* since time values in bpdu are in jiffies and then scaled (1/256) + * before sending, make sure that is at least one STP tick. + */ +#define MESSAGE_AGE_INCR ((HZ / 256) + 1) + +static const char *const br_port_state_names[] = { + [BR_STATE_DISABLED] = "disabled", + [BR_STATE_LISTENING] = "listening", + [BR_STATE_LEARNING] = "learning", + [BR_STATE_FORWARDING] = "forwarding", + [BR_STATE_BLOCKING] = "blocking", +}; + +void br_set_state(struct net_bridge_port *p, unsigned int state) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + .id = SWITCHDEV_ATTR_ID_PORT_STP_STATE, + .flags = SWITCHDEV_F_DEFER, + .u.stp_state = state, + }; + int err; + + /* Don't change the state of the ports if they are driven by a different + * protocol. 
+ */ + if (p->flags & BR_MRP_AWARE) + return; + + p->state = state; + if (br_opt_get(p->br, BROPT_MST_ENABLED)) { + err = br_mst_set_state(p, 0, state, NULL); + if (err) + br_warn(p->br, "error setting MST state on port %u(%s)\n", + p->port_no, netdev_name(p->dev)); + } + err = switchdev_port_attr_set(p->dev, &attr, NULL); + if (err && err != -EOPNOTSUPP) + br_warn(p->br, "error setting offload STP state on port %u(%s)\n", + (unsigned int) p->port_no, p->dev->name); + else + br_info(p->br, "port %u(%s) entered %s state\n", + (unsigned int) p->port_no, p->dev->name, + br_port_state_names[p->state]); + + if (p->br->stp_enabled == BR_KERNEL_STP) { + switch (p->state) { + case BR_STATE_BLOCKING: + p->stp_xstats.transition_blk++; + break; + case BR_STATE_FORWARDING: + p->stp_xstats.transition_fwd++; + break; + } + } +} + +u8 br_port_get_stp_state(const struct net_device *dev) +{ + struct net_bridge_port *p; + + ASSERT_RTNL(); + + p = br_port_get_rtnl(dev); + if (!p) + return BR_STATE_DISABLED; + + return p->state; +} +EXPORT_SYMBOL_GPL(br_port_get_stp_state); + +/* called under bridge lock */ +struct net_bridge_port *br_get_port(struct net_bridge *br, u16 port_no) +{ + struct net_bridge_port *p; + + list_for_each_entry_rcu(p, &br->port_list, list, + lockdep_is_held(&br->lock)) { + if (p->port_no == port_no) + return p; + } + + return NULL; +} + +/* called under bridge lock */ +static int br_should_become_root_port(const struct net_bridge_port *p, + u16 root_port) +{ + struct net_bridge *br; + struct net_bridge_port *rp; + int t; + + br = p->br; + if (p->state == BR_STATE_DISABLED || + br_is_designated_port(p)) + return 0; + + if (memcmp(&br->bridge_id, &p->designated_root, 8) <= 0) + return 0; + + if (!root_port) + return 1; + + rp = br_get_port(br, root_port); + + t = memcmp(&p->designated_root, &rp->designated_root, 8); + if (t < 0) + return 1; + else if (t > 0) + return 0; + + if (p->designated_cost + p->path_cost < + rp->designated_cost + rp->path_cost) + return 1; + else if (p->designated_cost + p->path_cost > + rp->designated_cost + rp->path_cost) + return 0; + + t = memcmp(&p->designated_bridge, &rp->designated_bridge, 8); + if (t < 0) + return 1; + else if (t > 0) + return 0; + + if (p->designated_port < rp->designated_port) + return 1; + else if (p->designated_port > rp->designated_port) + return 0; + + if (p->port_id < rp->port_id) + return 1; + + return 0; +} + +static void br_root_port_block(const struct net_bridge *br, + struct net_bridge_port *p) +{ + + br_notice(br, "port %u(%s) tried to become root port (blocked)", + (unsigned int) p->port_no, p->dev->name); + + br_set_state(p, BR_STATE_LISTENING); + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + if (br->forward_delay > 0) + mod_timer(&p->forward_delay_timer, jiffies + br->forward_delay); +} + +/* called under bridge lock */ +static void br_root_selection(struct net_bridge *br) +{ + struct net_bridge_port *p; + u16 root_port = 0; + + list_for_each_entry(p, &br->port_list, list) { + if (!br_should_become_root_port(p, root_port)) + continue; + + if (p->flags & BR_ROOT_BLOCK) + br_root_port_block(br, p); + else + root_port = p->port_no; + } + + br->root_port = root_port; + + if (!root_port) { + br->designated_root = br->bridge_id; + br->root_path_cost = 0; + } else { + p = br_get_port(br, root_port); + br->designated_root = p->designated_root; + br->root_path_cost = p->designated_cost + p->path_cost; + } +} + +/* called under bridge lock */ +void br_become_root_bridge(struct net_bridge *br) +{ + br->max_age = br->bridge_max_age; + 
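
The root-port election that br_should_become_root_port() above implements is a plain lexicographic comparison: candidates are ranked by the designated root they report, then by accumulated path cost (designated_cost + path_cost), then by designated bridge, designated port and finally the local port id, and the lower value wins at each step. The standalone C sketch below is illustrative only (it is not part of this patch and every name in it is invented); it models that ordering outside the kernel:

#include <stdio.h>
#include <stdint.h>
#include <string.h>

/* 8-byte bridge id: two priority bytes followed by the 6-byte MAC */
struct candidate {
	uint8_t  designated_root[8];
	uint32_t path_cost;		/* designated_cost + port path_cost */
	uint8_t  designated_bridge[8];
	uint16_t designated_port;
	uint16_t port_id;
};

/* Returns nonzero when 'a' is the better root-port candidate. */
static int better_root_port(const struct candidate *a,
			    const struct candidate *b)
{
	int t;

	t = memcmp(a->designated_root, b->designated_root, 8);
	if (t)
		return t < 0;
	if (a->path_cost != b->path_cost)
		return a->path_cost < b->path_cost;
	t = memcmp(a->designated_bridge, b->designated_bridge, 8);
	if (t)
		return t < 0;
	if (a->designated_port != b->designated_port)
		return a->designated_port < b->designated_port;
	return a->port_id < b->port_id;
}

int main(void)
{
	struct candidate a = {
		.designated_root = { 0x80, 0x00, 0, 0, 0, 0, 0, 1 },
		.path_cost	 = 4,
		.designated_port = 0x8001,
		.port_id	 = 0x8001,
	};
	struct candidate b = a;

	b.path_cost = 19;	/* same root, higher cost: loses the tie-break */
	printf("a is the better candidate: %d\n", better_root_port(&a, &b));
	return 0;
}
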
br->hello_time = br->bridge_hello_time; + br->forward_delay = br->bridge_forward_delay; + br_topology_change_detection(br); + del_timer(&br->tcn_timer); + + if (br->dev->flags & IFF_UP) { + br_config_bpdu_generation(br); + mod_timer(&br->hello_timer, jiffies + br->hello_time); + } +} + +/* called under bridge lock */ +void br_transmit_config(struct net_bridge_port *p) +{ + struct br_config_bpdu bpdu; + struct net_bridge *br; + + if (timer_pending(&p->hold_timer)) { + p->config_pending = 1; + return; + } + + br = p->br; + + bpdu.topology_change = br->topology_change; + bpdu.topology_change_ack = p->topology_change_ack; + bpdu.root = br->designated_root; + bpdu.root_path_cost = br->root_path_cost; + bpdu.bridge_id = br->bridge_id; + bpdu.port_id = p->port_id; + if (br_is_root_bridge(br)) + bpdu.message_age = 0; + else { + struct net_bridge_port *root + = br_get_port(br, br->root_port); + bpdu.message_age = (jiffies - root->designated_age) + + MESSAGE_AGE_INCR; + } + bpdu.max_age = br->max_age; + bpdu.hello_time = br->hello_time; + bpdu.forward_delay = br->forward_delay; + + if (bpdu.message_age < br->max_age) { + br_send_config_bpdu(p, &bpdu); + p->topology_change_ack = 0; + p->config_pending = 0; + if (p->br->stp_enabled == BR_KERNEL_STP) + mod_timer(&p->hold_timer, + round_jiffies(jiffies + BR_HOLD_TIME)); + } +} + +/* called under bridge lock */ +static void br_record_config_information(struct net_bridge_port *p, + const struct br_config_bpdu *bpdu) +{ + p->designated_root = bpdu->root; + p->designated_cost = bpdu->root_path_cost; + p->designated_bridge = bpdu->bridge_id; + p->designated_port = bpdu->port_id; + p->designated_age = jiffies - bpdu->message_age; + + mod_timer(&p->message_age_timer, jiffies + + (bpdu->max_age - bpdu->message_age)); +} + +/* called under bridge lock */ +static void br_record_config_timeout_values(struct net_bridge *br, + const struct br_config_bpdu *bpdu) +{ + br->max_age = bpdu->max_age; + br->hello_time = bpdu->hello_time; + br->forward_delay = bpdu->forward_delay; + __br_set_topology_change(br, bpdu->topology_change); +} + +/* called under bridge lock */ +void br_transmit_tcn(struct net_bridge *br) +{ + struct net_bridge_port *p; + + p = br_get_port(br, br->root_port); + if (p) + br_send_tcn_bpdu(p); + else + br_notice(br, "root port %u not found for topology notice\n", + br->root_port); +} + +/* called under bridge lock */ +static int br_should_become_designated_port(const struct net_bridge_port *p) +{ + struct net_bridge *br; + int t; + + br = p->br; + if (br_is_designated_port(p)) + return 1; + + if (memcmp(&p->designated_root, &br->designated_root, 8)) + return 1; + + if (br->root_path_cost < p->designated_cost) + return 1; + else if (br->root_path_cost > p->designated_cost) + return 0; + + t = memcmp(&br->bridge_id, &p->designated_bridge, 8); + if (t < 0) + return 1; + else if (t > 0) + return 0; + + if (p->port_id < p->designated_port) + return 1; + + return 0; +} + +/* called under bridge lock */ +static void br_designated_port_selection(struct net_bridge *br) +{ + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) { + if (p->state != BR_STATE_DISABLED && + br_should_become_designated_port(p)) + br_become_designated_port(p); + + } +} + +/* called under bridge lock */ +static int br_supersedes_port_info(const struct net_bridge_port *p, + const struct br_config_bpdu *bpdu) +{ + int t; + + t = memcmp(&bpdu->root, &p->designated_root, 8); + if (t < 0) + return 1; + else if (t > 0) + return 0; + + if (bpdu->root_path_cost < 
p->designated_cost) + return 1; + else if (bpdu->root_path_cost > p->designated_cost) + return 0; + + t = memcmp(&bpdu->bridge_id, &p->designated_bridge, 8); + if (t < 0) + return 1; + else if (t > 0) + return 0; + + if (memcmp(&bpdu->bridge_id, &p->br->bridge_id, 8)) + return 1; + + if (bpdu->port_id <= p->designated_port) + return 1; + + return 0; +} + +/* called under bridge lock */ +static void br_topology_change_acknowledged(struct net_bridge *br) +{ + br->topology_change_detected = 0; + del_timer(&br->tcn_timer); +} + +/* called under bridge lock */ +void br_topology_change_detection(struct net_bridge *br) +{ + int isroot = br_is_root_bridge(br); + + if (br->stp_enabled != BR_KERNEL_STP) + return; + + br_info(br, "topology change detected, %s\n", + isroot ? "propagating" : "sending tcn bpdu"); + + if (isroot) { + __br_set_topology_change(br, 1); + mod_timer(&br->topology_change_timer, jiffies + + br->bridge_forward_delay + br->bridge_max_age); + } else if (!br->topology_change_detected) { + br_transmit_tcn(br); + mod_timer(&br->tcn_timer, jiffies + br->bridge_hello_time); + } + + br->topology_change_detected = 1; +} + +/* called under bridge lock */ +void br_config_bpdu_generation(struct net_bridge *br) +{ + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) { + if (p->state != BR_STATE_DISABLED && + br_is_designated_port(p)) + br_transmit_config(p); + } +} + +/* called under bridge lock */ +static void br_reply(struct net_bridge_port *p) +{ + br_transmit_config(p); +} + +/* called under bridge lock */ +void br_configuration_update(struct net_bridge *br) +{ + br_root_selection(br); + br_designated_port_selection(br); +} + +/* called under bridge lock */ +void br_become_designated_port(struct net_bridge_port *p) +{ + struct net_bridge *br; + + br = p->br; + p->designated_root = br->designated_root; + p->designated_cost = br->root_path_cost; + p->designated_bridge = br->bridge_id; + p->designated_port = p->port_id; +} + + +/* called under bridge lock */ +static void br_make_blocking(struct net_bridge_port *p) +{ + if (p->state != BR_STATE_DISABLED && + p->state != BR_STATE_BLOCKING) { + if (p->state == BR_STATE_FORWARDING || + p->state == BR_STATE_LEARNING) + br_topology_change_detection(p->br); + + br_set_state(p, BR_STATE_BLOCKING); + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + del_timer(&p->forward_delay_timer); + } +} + +/* called under bridge lock */ +static void br_make_forwarding(struct net_bridge_port *p) +{ + struct net_bridge *br = p->br; + + if (p->state != BR_STATE_BLOCKING) + return; + + if (br->stp_enabled == BR_NO_STP || br->forward_delay == 0) { + br_set_state(p, BR_STATE_FORWARDING); + br_topology_change_detection(br); + del_timer(&p->forward_delay_timer); + } else if (br->stp_enabled == BR_KERNEL_STP) + br_set_state(p, BR_STATE_LISTENING); + else + br_set_state(p, BR_STATE_LEARNING); + + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + if (br->forward_delay != 0) + mod_timer(&p->forward_delay_timer, jiffies + br->forward_delay); +} + +/* called under bridge lock */ +void br_port_state_selection(struct net_bridge *br) +{ + struct net_bridge_port *p; + unsigned int liveports = 0; + + list_for_each_entry(p, &br->port_list, list) { + if (p->state == BR_STATE_DISABLED) + continue; + + /* Don't change port states if userspace is handling STP */ + if (br->stp_enabled != BR_USER_STP) { + if (p->port_no == br->root_port) { + p->config_pending = 0; + p->topology_change_ack = 0; + br_make_forwarding(p); + } else if (br_is_designated_port(p)) { + 
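
br_make_forwarding() and br_port_state_selection() above drive the classic 802.1D port state progression: with kernel STP, a blocked port that becomes root or designated first enters LISTENING and is later advanced by the forward-delay timer through LEARNING to FORWARDING (see br_forward_delay_timer_expired() in br_stp_timer.c further down), while a bridge without STP, or with forward_delay == 0, jumps straight to FORWARDING. The following self-contained sketch is a simplified model only (it is not kernel code, ignores the userspace-STP case and replaces the timers with a loop); it uses the same state ordering as the BR_STATE_* constants:

#include <stdio.h>

enum port_state { DISABLED, LISTENING, LEARNING, FORWARDING, BLOCKING };

static enum port_state next_state(enum port_state cur, int kernel_stp,
				  unsigned long forward_delay)
{
	if (cur == BLOCKING)
		return (!kernel_stp || !forward_delay) ? FORWARDING : LISTENING;
	if (cur == LISTENING)
		return LEARNING;	/* forward-delay timer fired */
	if (cur == LEARNING)
		return FORWARDING;	/* forward-delay timer fired */
	return cur;
}

int main(void)
{
	static const char *const names[] = {
		"disabled", "listening", "learning", "forwarding", "blocking",
	};
	enum port_state s = BLOCKING;

	while (s != FORWARDING) {
		s = next_state(s, 1, 15);
		printf("-> %s\n", names[s]);
	}
	return 0;
}
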
del_timer(&p->message_age_timer); + br_make_forwarding(p); + } else { + p->config_pending = 0; + p->topology_change_ack = 0; + br_make_blocking(p); + } + } + + if (p->state != BR_STATE_BLOCKING) + br_multicast_enable_port(p); + /* Multicast is not disabled for the port when it goes in + * blocking state because the timers will expire and stop by + * themselves without sending more queries. + */ + if (p->state == BR_STATE_FORWARDING) + ++liveports; + } + + if (liveports == 0) + netif_carrier_off(br->dev); + else + netif_carrier_on(br->dev); +} + +/* called under bridge lock */ +static void br_topology_change_acknowledge(struct net_bridge_port *p) +{ + p->topology_change_ack = 1; + br_transmit_config(p); +} + +/* called under bridge lock */ +void br_received_config_bpdu(struct net_bridge_port *p, + const struct br_config_bpdu *bpdu) +{ + struct net_bridge *br; + int was_root; + + p->stp_xstats.rx_bpdu++; + + br = p->br; + was_root = br_is_root_bridge(br); + + if (br_supersedes_port_info(p, bpdu)) { + br_record_config_information(p, bpdu); + br_configuration_update(br); + br_port_state_selection(br); + + if (!br_is_root_bridge(br) && was_root) { + del_timer(&br->hello_timer); + if (br->topology_change_detected) { + del_timer(&br->topology_change_timer); + br_transmit_tcn(br); + + mod_timer(&br->tcn_timer, + jiffies + br->bridge_hello_time); + } + } + + if (p->port_no == br->root_port) { + br_record_config_timeout_values(br, bpdu); + br_config_bpdu_generation(br); + if (bpdu->topology_change_ack) + br_topology_change_acknowledged(br); + } + } else if (br_is_designated_port(p)) { + br_reply(p); + } +} + +/* called under bridge lock */ +void br_received_tcn_bpdu(struct net_bridge_port *p) +{ + p->stp_xstats.rx_tcn++; + + if (br_is_designated_port(p)) { + br_info(p->br, "port %u(%s) received tcn bpdu\n", + (unsigned int) p->port_no, p->dev->name); + + br_topology_change_detection(p->br); + br_topology_change_acknowledge(p); + } +} + +/* Change bridge STP parameter */ +int br_set_hello_time(struct net_bridge *br, unsigned long val) +{ + unsigned long t = clock_t_to_jiffies(val); + + if (t < BR_MIN_HELLO_TIME || t > BR_MAX_HELLO_TIME) + return -ERANGE; + + spin_lock_bh(&br->lock); + br->bridge_hello_time = t; + if (br_is_root_bridge(br)) + br->hello_time = br->bridge_hello_time; + spin_unlock_bh(&br->lock); + return 0; +} + +int br_set_max_age(struct net_bridge *br, unsigned long val) +{ + unsigned long t = clock_t_to_jiffies(val); + + if (t < BR_MIN_MAX_AGE || t > BR_MAX_MAX_AGE) + return -ERANGE; + + spin_lock_bh(&br->lock); + br->bridge_max_age = t; + if (br_is_root_bridge(br)) + br->max_age = br->bridge_max_age; + spin_unlock_bh(&br->lock); + return 0; + +} + +/* called under bridge lock */ +int __set_ageing_time(struct net_device *dev, unsigned long t) +{ + struct switchdev_attr attr = { + .orig_dev = dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_AGEING_TIME, + .flags = SWITCHDEV_F_SKIP_EOPNOTSUPP | SWITCHDEV_F_DEFER, + .u.ageing_time = jiffies_to_clock_t(t), + }; + int err; + + err = switchdev_port_attr_set(dev, &attr, NULL); + if (err && err != -EOPNOTSUPP) + return err; + + return 0; +} + +/* Set time interval that dynamic forwarding entries live + * For pure software bridge, allow values outside the 802.1 + * standard specification for special cases: + * 0 - entry never ages (all permanent) + * 1 - entry disappears (no persistence) + * + * Offloaded switch entries maybe more restrictive + */ +int br_set_ageing_time(struct net_bridge *br, clock_t ageing_time) +{ + unsigned long t = 
clock_t_to_jiffies(ageing_time); + int err; + + err = __set_ageing_time(br->dev, t); + if (err) + return err; + + spin_lock_bh(&br->lock); + br->bridge_ageing_time = t; + br->ageing_time = t; + spin_unlock_bh(&br->lock); + + mod_delayed_work(system_long_wq, &br->gc_work, 0); + + return 0; +} + +clock_t br_get_ageing_time(const struct net_device *br_dev) +{ + const struct net_bridge *br; + + if (!netif_is_bridge_master(br_dev)) + return 0; + + br = netdev_priv(br_dev); + + return jiffies_to_clock_t(br->ageing_time); +} +EXPORT_SYMBOL_GPL(br_get_ageing_time); + +/* called under bridge lock */ +void __br_set_topology_change(struct net_bridge *br, unsigned char val) +{ + unsigned long t; + int err; + + if (br->stp_enabled == BR_KERNEL_STP && br->topology_change != val) { + /* On topology change, set the bridge ageing time to twice the + * forward delay. Otherwise, restore its default ageing time. + */ + + if (val) { + t = 2 * br->forward_delay; + br_debug(br, "decreasing ageing time to %lu\n", t); + } else { + t = br->bridge_ageing_time; + br_debug(br, "restoring ageing time to %lu\n", t); + } + + err = __set_ageing_time(br->dev, t); + if (err) + br_warn(br, "error offloading ageing time\n"); + else + br->ageing_time = t; + } + + br->topology_change = val; +} + +void __br_set_forward_delay(struct net_bridge *br, unsigned long t) +{ + br->bridge_forward_delay = t; + if (br_is_root_bridge(br)) + br->forward_delay = br->bridge_forward_delay; +} + +int br_set_forward_delay(struct net_bridge *br, unsigned long val) +{ + unsigned long t = clock_t_to_jiffies(val); + int err = -ERANGE; + + spin_lock_bh(&br->lock); + if (br->stp_enabled != BR_NO_STP && + (t < BR_MIN_FORWARD_DELAY || t > BR_MAX_FORWARD_DELAY)) + goto unlock; + + __br_set_forward_delay(br, t); + err = 0; + +unlock: + spin_unlock_bh(&br->lock); + return err; +} diff --git a/net/bridge/br_stp_bpdu.c b/net/bridge/br_stp_bpdu.c new file mode 100644 index 000000000..0e4572f31 --- /dev/null +++ b/net/bridge/br_stp_bpdu.c @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Spanning tree protocol; BPDU handling + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_stp.h" + +#define STP_HZ 256 + +#define LLC_RESERVE sizeof(struct llc_pdu_un) + +static int br_send_bpdu_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + return dev_queue_xmit(skb); +} + +static void br_send_bpdu(struct net_bridge_port *p, + const unsigned char *data, int length) +{ + struct sk_buff *skb; + + skb = dev_alloc_skb(length+LLC_RESERVE); + if (!skb) + return; + + skb->dev = p->dev; + skb->protocol = htons(ETH_P_802_2); + skb->priority = TC_PRIO_CONTROL; + + skb_reserve(skb, LLC_RESERVE); + __skb_put_data(skb, data, length); + + llc_pdu_header_init(skb, LLC_PDU_TYPE_U, LLC_SAP_BSPAN, + LLC_SAP_BSPAN, LLC_PDU_CMD); + llc_pdu_init_as_ui_cmd(skb); + + llc_mac_hdr_init(skb, p->dev->dev_addr, p->br->group_addr); + + skb_reset_mac_header(skb); + + NF_HOOK(NFPROTO_BRIDGE, NF_BR_LOCAL_OUT, + dev_net(p->dev), NULL, skb, NULL, skb->dev, + br_send_bpdu_finish); +} + +static inline void br_set_ticks(unsigned char *dest, int j) +{ + unsigned long ticks = (STP_HZ * j)/ HZ; + + put_unaligned_be16(ticks, dest); +} + +static inline int br_get_ticks(const unsigned char *src) +{ + unsigned long ticks = get_unaligned_be16(src); + + return DIV_ROUND_UP(ticks * HZ, STP_HZ); +} + +/* 
called under bridge lock */ +void br_send_config_bpdu(struct net_bridge_port *p, struct br_config_bpdu *bpdu) +{ + unsigned char buf[35]; + + if (p->br->stp_enabled != BR_KERNEL_STP) + return; + + buf[0] = 0; + buf[1] = 0; + buf[2] = 0; + buf[3] = BPDU_TYPE_CONFIG; + buf[4] = (bpdu->topology_change ? 0x01 : 0) | + (bpdu->topology_change_ack ? 0x80 : 0); + buf[5] = bpdu->root.prio[0]; + buf[6] = bpdu->root.prio[1]; + buf[7] = bpdu->root.addr[0]; + buf[8] = bpdu->root.addr[1]; + buf[9] = bpdu->root.addr[2]; + buf[10] = bpdu->root.addr[3]; + buf[11] = bpdu->root.addr[4]; + buf[12] = bpdu->root.addr[5]; + buf[13] = (bpdu->root_path_cost >> 24) & 0xFF; + buf[14] = (bpdu->root_path_cost >> 16) & 0xFF; + buf[15] = (bpdu->root_path_cost >> 8) & 0xFF; + buf[16] = bpdu->root_path_cost & 0xFF; + buf[17] = bpdu->bridge_id.prio[0]; + buf[18] = bpdu->bridge_id.prio[1]; + buf[19] = bpdu->bridge_id.addr[0]; + buf[20] = bpdu->bridge_id.addr[1]; + buf[21] = bpdu->bridge_id.addr[2]; + buf[22] = bpdu->bridge_id.addr[3]; + buf[23] = bpdu->bridge_id.addr[4]; + buf[24] = bpdu->bridge_id.addr[5]; + buf[25] = (bpdu->port_id >> 8) & 0xFF; + buf[26] = bpdu->port_id & 0xFF; + + br_set_ticks(buf+27, bpdu->message_age); + br_set_ticks(buf+29, bpdu->max_age); + br_set_ticks(buf+31, bpdu->hello_time); + br_set_ticks(buf+33, bpdu->forward_delay); + + br_send_bpdu(p, buf, 35); + + p->stp_xstats.tx_bpdu++; +} + +/* called under bridge lock */ +void br_send_tcn_bpdu(struct net_bridge_port *p) +{ + unsigned char buf[4]; + + if (p->br->stp_enabled != BR_KERNEL_STP) + return; + + buf[0] = 0; + buf[1] = 0; + buf[2] = 0; + buf[3] = BPDU_TYPE_TCN; + br_send_bpdu(p, buf, 4); + + p->stp_xstats.tx_tcn++; +} + +/* + * Called from llc. + * + * NO locks, but rcu_read_lock + */ +void br_stp_rcv(const struct stp_proto *proto, struct sk_buff *skb, + struct net_device *dev) +{ + struct net_bridge_port *p; + struct net_bridge *br; + const unsigned char *buf; + + if (!pskb_may_pull(skb, 4)) + goto err; + + /* compare of protocol id and version */ + buf = skb->data; + if (buf[0] != 0 || buf[1] != 0 || buf[2] != 0) + goto err; + + p = br_port_get_check_rcu(dev); + if (!p) + goto err; + + br = p->br; + spin_lock(&br->lock); + + if (br->stp_enabled != BR_KERNEL_STP) + goto out; + + if (!(br->dev->flags & IFF_UP)) + goto out; + + if (p->state == BR_STATE_DISABLED) + goto out; + + if (!ether_addr_equal(eth_hdr(skb)->h_dest, br->group_addr)) + goto out; + + if (p->flags & BR_BPDU_GUARD) { + br_notice(br, "BPDU received on blocked port %u(%s)\n", + (unsigned int) p->port_no, p->dev->name); + br_stp_disable_port(p); + goto out; + } + + buf = skb_pull(skb, 3); + + if (buf[0] == BPDU_TYPE_CONFIG) { + struct br_config_bpdu bpdu; + + if (!pskb_may_pull(skb, 32)) + goto out; + + buf = skb->data; + bpdu.topology_change = (buf[1] & 0x01) ? 1 : 0; + bpdu.topology_change_ack = (buf[1] & 0x80) ? 
1 : 0; + + bpdu.root.prio[0] = buf[2]; + bpdu.root.prio[1] = buf[3]; + bpdu.root.addr[0] = buf[4]; + bpdu.root.addr[1] = buf[5]; + bpdu.root.addr[2] = buf[6]; + bpdu.root.addr[3] = buf[7]; + bpdu.root.addr[4] = buf[8]; + bpdu.root.addr[5] = buf[9]; + bpdu.root_path_cost = + (buf[10] << 24) | + (buf[11] << 16) | + (buf[12] << 8) | + buf[13]; + bpdu.bridge_id.prio[0] = buf[14]; + bpdu.bridge_id.prio[1] = buf[15]; + bpdu.bridge_id.addr[0] = buf[16]; + bpdu.bridge_id.addr[1] = buf[17]; + bpdu.bridge_id.addr[2] = buf[18]; + bpdu.bridge_id.addr[3] = buf[19]; + bpdu.bridge_id.addr[4] = buf[20]; + bpdu.bridge_id.addr[5] = buf[21]; + bpdu.port_id = (buf[22] << 8) | buf[23]; + + bpdu.message_age = br_get_ticks(buf+24); + bpdu.max_age = br_get_ticks(buf+26); + bpdu.hello_time = br_get_ticks(buf+28); + bpdu.forward_delay = br_get_ticks(buf+30); + + if (bpdu.message_age > bpdu.max_age) { + if (net_ratelimit()) + br_notice(p->br, + "port %u config from %pM" + " (message_age %ul > max_age %ul)\n", + p->port_no, + eth_hdr(skb)->h_source, + bpdu.message_age, bpdu.max_age); + goto out; + } + + br_received_config_bpdu(p, &bpdu); + } else if (buf[0] == BPDU_TYPE_TCN) { + br_received_tcn_bpdu(p); + } + out: + spin_unlock(&br->lock); + err: + kfree_skb(skb); +} diff --git a/net/bridge/br_stp_if.c b/net/bridge/br_stp_if.c new file mode 100644 index 000000000..75204d36d --- /dev/null +++ b/net/bridge/br_stp_if.c @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Spanning tree protocol; interface code + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_stp.h" + + +/* Port id is composed of priority and port number. + * NB: some bits of priority are dropped to + * make room for more ports. 
+ */ +static inline port_id br_make_port_id(__u8 priority, __u16 port_no) +{ + return ((u16)priority << BR_PORT_BITS) + | (port_no & ((1<> BR_PORT_BITS) + +/* called under bridge lock */ +void br_init_port(struct net_bridge_port *p) +{ + int err; + + p->port_id = br_make_port_id(p->priority, p->port_no); + br_become_designated_port(p); + br_set_state(p, BR_STATE_BLOCKING); + p->topology_change_ack = 0; + p->config_pending = 0; + + err = __set_ageing_time(p->dev, p->br->ageing_time); + if (err) + netdev_err(p->dev, "failed to offload ageing time\n"); +} + +/* NO locks held */ +void br_stp_enable_bridge(struct net_bridge *br) +{ + struct net_bridge_port *p; + + spin_lock_bh(&br->lock); + if (br->stp_enabled == BR_KERNEL_STP) + mod_timer(&br->hello_timer, jiffies + br->hello_time); + mod_delayed_work(system_long_wq, &br->gc_work, HZ / 10); + + br_config_bpdu_generation(br); + + list_for_each_entry(p, &br->port_list, list) { + if (netif_running(p->dev) && netif_oper_up(p->dev)) + br_stp_enable_port(p); + + } + spin_unlock_bh(&br->lock); +} + +/* NO locks held */ +void br_stp_disable_bridge(struct net_bridge *br) +{ + struct net_bridge_port *p; + + spin_lock_bh(&br->lock); + list_for_each_entry(p, &br->port_list, list) { + if (p->state != BR_STATE_DISABLED) + br_stp_disable_port(p); + + } + + __br_set_topology_change(br, 0); + br->topology_change_detected = 0; + spin_unlock_bh(&br->lock); + + del_timer_sync(&br->hello_timer); + del_timer_sync(&br->topology_change_timer); + del_timer_sync(&br->tcn_timer); + cancel_delayed_work_sync(&br->gc_work); +} + +/* called under bridge lock */ +void br_stp_enable_port(struct net_bridge_port *p) +{ + br_init_port(p); + br_port_state_selection(p->br); + br_ifinfo_notify(RTM_NEWLINK, NULL, p); +} + +/* called under bridge lock */ +void br_stp_disable_port(struct net_bridge_port *p) +{ + struct net_bridge *br = p->br; + int wasroot; + + wasroot = br_is_root_bridge(br); + br_become_designated_port(p); + br_set_state(p, BR_STATE_DISABLED); + p->topology_change_ack = 0; + p->config_pending = 0; + + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + + del_timer(&p->message_age_timer); + del_timer(&p->forward_delay_timer); + del_timer(&p->hold_timer); + + if (!rcu_access_pointer(p->backup_port)) + br_fdb_delete_by_port(br, p, 0, 0); + br_multicast_disable_port(p); + + br_configuration_update(br); + + br_port_state_selection(br); + + if (br_is_root_bridge(br) && !wasroot) + br_become_root_bridge(br); +} + +static int br_stp_call_user(struct net_bridge *br, char *arg) +{ + char *argv[] = { BR_STP_PROG, br->dev->name, arg, NULL }; + char *envp[] = { NULL }; + int rc; + + /* call userspace STP and report program errors */ + rc = call_usermodehelper(BR_STP_PROG, argv, envp, UMH_WAIT_PROC); + if (rc > 0) { + if (rc & 0xff) + br_debug(br, BR_STP_PROG " received signal %d\n", + rc & 0x7f); + else + br_debug(br, BR_STP_PROG " exited with code %d\n", + (rc >> 8) & 0xff); + } + + return rc; +} + +static void br_stp_start(struct net_bridge *br) +{ + int err = -ENOENT; + + if (net_eq(dev_net(br->dev), &init_net)) + err = br_stp_call_user(br, "start"); + + if (err && err != -ENOENT) + br_err(br, "failed to start userspace STP (%d)\n", err); + + spin_lock_bh(&br->lock); + + if (br->bridge_forward_delay < BR_MIN_FORWARD_DELAY) + __br_set_forward_delay(br, BR_MIN_FORWARD_DELAY); + else if (br->bridge_forward_delay > BR_MAX_FORWARD_DELAY) + __br_set_forward_delay(br, BR_MAX_FORWARD_DELAY); + + if (!err) { + br->stp_enabled = BR_USER_STP; + br_debug(br, "userspace STP started\n"); + } else { 
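
br_make_port_id() just above packs the configured port priority into the upper bits of the 16-bit STP port id and keeps the port number in the low BR_PORT_BITS bits, which is what the preceding comment means by dropping some priority bits to make room for more ports. A minimal userspace model of that packing, assuming BR_PORT_BITS is 10 as defined in br_private.h in this tree (illustrative only, not part of the patch):

#include <stdio.h>
#include <stdint.h>

#define BR_PORT_BITS 10		/* assumed, mirrors br_private.h */

static uint16_t make_port_id(uint8_t priority, uint16_t port_no)
{
	return ((uint16_t)priority << BR_PORT_BITS) |
	       (port_no & ((1u << BR_PORT_BITS) - 1));
}

int main(void)
{
	uint16_t id = make_port_id(0x20, 5);	/* default port priority 32 */

	printf("port id  : 0x%04x\n", (unsigned int)id);
	printf("priority : 0x%02x\n", (unsigned int)(id >> BR_PORT_BITS));
	printf("port no  : %u\n",
	       (unsigned int)(id & ((1u << BR_PORT_BITS) - 1)));
	return 0;
}
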
+ br->stp_enabled = BR_KERNEL_STP; + br_debug(br, "using kernel STP\n"); + + /* To start timers on any ports left in blocking */ + if (br->dev->flags & IFF_UP) + mod_timer(&br->hello_timer, jiffies + br->hello_time); + br_port_state_selection(br); + } + + spin_unlock_bh(&br->lock); +} + +static void br_stp_stop(struct net_bridge *br) +{ + int err; + + if (br->stp_enabled == BR_USER_STP) { + err = br_stp_call_user(br, "stop"); + if (err) + br_err(br, "failed to stop userspace STP (%d)\n", err); + + /* To start timers on any ports left in blocking */ + spin_lock_bh(&br->lock); + br_port_state_selection(br); + spin_unlock_bh(&br->lock); + } + + br->stp_enabled = BR_NO_STP; +} + +int br_stp_set_enabled(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + ASSERT_RTNL(); + + if (br_mrp_enabled(br)) { + NL_SET_ERR_MSG_MOD(extack, + "STP can't be enabled if MRP is already enabled"); + return -EINVAL; + } + + if (val) { + if (br->stp_enabled == BR_NO_STP) + br_stp_start(br); + } else { + if (br->stp_enabled != BR_NO_STP) + br_stp_stop(br); + } + + return 0; +} + +/* called under bridge lock */ +void br_stp_change_bridge_id(struct net_bridge *br, const unsigned char *addr) +{ + /* should be aligned on 2 bytes for ether_addr_equal() */ + unsigned short oldaddr_aligned[ETH_ALEN >> 1]; + unsigned char *oldaddr = (unsigned char *)oldaddr_aligned; + struct net_bridge_port *p; + int wasroot; + + wasroot = br_is_root_bridge(br); + + br_fdb_change_mac_address(br, addr); + + memcpy(oldaddr, br->bridge_id.addr, ETH_ALEN); + memcpy(br->bridge_id.addr, addr, ETH_ALEN); + eth_hw_addr_set(br->dev, addr); + + list_for_each_entry(p, &br->port_list, list) { + if (ether_addr_equal(p->designated_bridge.addr, oldaddr)) + memcpy(p->designated_bridge.addr, addr, ETH_ALEN); + + if (ether_addr_equal(p->designated_root.addr, oldaddr)) + memcpy(p->designated_root.addr, addr, ETH_ALEN); + } + + br_configuration_update(br); + br_port_state_selection(br); + if (br_is_root_bridge(br) && !wasroot) + br_become_root_bridge(br); +} + +/* should be aligned on 2 bytes for ether_addr_equal() */ +static const unsigned short br_mac_zero_aligned[ETH_ALEN >> 1]; + +/* called under bridge lock */ +bool br_stp_recalculate_bridge_id(struct net_bridge *br) +{ + const unsigned char *br_mac_zero = + (const unsigned char *)br_mac_zero_aligned; + const unsigned char *addr = br_mac_zero; + struct net_bridge_port *p; + + /* user has chosen a value so keep it */ + if (br->dev->addr_assign_type == NET_ADDR_SET) + return false; + + list_for_each_entry(p, &br->port_list, list) { + if (addr == br_mac_zero || + memcmp(p->dev->dev_addr, addr, ETH_ALEN) < 0) + addr = p->dev->dev_addr; + + } + + if (ether_addr_equal(br->bridge_id.addr, addr)) + return false; /* no change */ + + br_stp_change_bridge_id(br, addr); + return true; +} + +/* Acquires and releases bridge lock */ +void br_stp_set_bridge_priority(struct net_bridge *br, u16 newprio) +{ + struct net_bridge_port *p; + int wasroot; + + spin_lock_bh(&br->lock); + wasroot = br_is_root_bridge(br); + + list_for_each_entry(p, &br->port_list, list) { + if (p->state != BR_STATE_DISABLED && + br_is_designated_port(p)) { + p->designated_bridge.prio[0] = (newprio >> 8) & 0xFF; + p->designated_bridge.prio[1] = newprio & 0xFF; + } + + } + + br->bridge_id.prio[0] = (newprio >> 8) & 0xFF; + br->bridge_id.prio[1] = newprio & 0xFF; + br_configuration_update(br); + br_port_state_selection(br); + if (br_is_root_bridge(br) && !wasroot) + br_become_root_bridge(br); + spin_unlock_bh(&br->lock); +} 
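
br_stp_set_bridge_priority() above writes the new priority into the two leading bytes of the 8-byte bridge id, ahead of the bridge MAC address, so the memcmp()-based comparisons used during root selection order bridges exactly as 802.1D intends: the numerically lower priority wins and the MAC address breaks ties. A small standalone sketch of that layout and comparison (illustrative only, not part of the patch):

#include <stdio.h>
#include <stdint.h>
#include <string.h>

/* Mirrors the bridge_id layout used by the bridge code: priority first. */
struct bridge_id {
	uint8_t prio[2];
	uint8_t addr[6];
};

static void set_bridge_id(struct bridge_id *id, uint16_t prio,
			  const uint8_t mac[6])
{
	id->prio[0] = (prio >> 8) & 0xff;
	id->prio[1] = prio & 0xff;
	memcpy(id->addr, mac, 6);
}

int main(void)
{
	const uint8_t mac_a[6] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x01 };
	const uint8_t mac_b[6] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x02 };
	struct bridge_id a, b;

	set_bridge_id(&a, 0x8000, mac_a);	/* default priority 32768 */
	set_bridge_id(&b, 0x7000, mac_b);	/* lower priority value */

	/* b wins the root election despite its higher MAC address */
	printf("root is %s\n", memcmp(&a, &b, 8) < 0 ? "a" : "b");
	return 0;
}
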
+ +/* called under bridge lock */ +int br_stp_set_port_priority(struct net_bridge_port *p, unsigned long newprio) +{ + port_id new_port_id; + + if (newprio > BR_MAX_PORT_PRIORITY) + return -ERANGE; + + new_port_id = br_make_port_id(newprio, p->port_no); + if (br_is_designated_port(p)) + p->designated_port = new_port_id; + + p->port_id = new_port_id; + p->priority = newprio; + if (!memcmp(&p->br->bridge_id, &p->designated_bridge, 8) && + p->port_id < p->designated_port) { + br_become_designated_port(p); + br_port_state_selection(p->br); + } + + return 0; +} + +/* called under bridge lock */ +int br_stp_set_path_cost(struct net_bridge_port *p, unsigned long path_cost) +{ + if (path_cost < BR_MIN_PATH_COST || + path_cost > BR_MAX_PATH_COST) + return -ERANGE; + + p->flags |= BR_ADMIN_COST; + p->path_cost = path_cost; + br_configuration_update(p->br); + br_port_state_selection(p->br); + return 0; +} + +ssize_t br_show_bridge_id(char *buf, const struct bridge_id *id) +{ + return sprintf(buf, "%.2x%.2x.%.2x%.2x%.2x%.2x%.2x%.2x\n", + id->prio[0], id->prio[1], + id->addr[0], id->addr[1], id->addr[2], + id->addr[3], id->addr[4], id->addr[5]); +} diff --git a/net/bridge/br_stp_timer.c b/net/bridge/br_stp_timer.c new file mode 100644 index 000000000..27bf1979b --- /dev/null +++ b/net/bridge/br_stp_timer.c @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Spanning tree protocol; timer-related code + * Linux ethernet bridge + * + * Authors: + * Lennert Buytenhek + */ + +#include +#include + +#include "br_private.h" +#include "br_private_stp.h" + +/* called under bridge lock */ +static int br_is_designated_for_some_port(const struct net_bridge *br) +{ + struct net_bridge_port *p; + + list_for_each_entry(p, &br->port_list, list) { + if (p->state != BR_STATE_DISABLED && + !memcmp(&p->designated_bridge, &br->bridge_id, 8)) + return 1; + } + + return 0; +} + +static void br_hello_timer_expired(struct timer_list *t) +{ + struct net_bridge *br = from_timer(br, t, hello_timer); + + br_debug(br, "hello timer expired\n"); + spin_lock(&br->lock); + if (br->dev->flags & IFF_UP) { + br_config_bpdu_generation(br); + + if (br->stp_enabled == BR_KERNEL_STP) + mod_timer(&br->hello_timer, + round_jiffies(jiffies + br->hello_time)); + } + spin_unlock(&br->lock); +} + +static void br_message_age_timer_expired(struct timer_list *t) +{ + struct net_bridge_port *p = from_timer(p, t, message_age_timer); + struct net_bridge *br = p->br; + const bridge_id *id = &p->designated_bridge; + int was_root; + + if (p->state == BR_STATE_DISABLED) + return; + + br_info(br, "port %u(%s) neighbor %.2x%.2x.%pM lost\n", + (unsigned int) p->port_no, p->dev->name, + id->prio[0], id->prio[1], &id->addr); + + /* + * According to the spec, the message age timer cannot be + * running when we are the root bridge. So.. this was_root + * check is redundant. I'm leaving it in for now, though. 
+ */ + spin_lock(&br->lock); + if (p->state == BR_STATE_DISABLED) + goto unlock; + was_root = br_is_root_bridge(br); + + br_become_designated_port(p); + br_configuration_update(br); + br_port_state_selection(br); + if (br_is_root_bridge(br) && !was_root) + br_become_root_bridge(br); + unlock: + spin_unlock(&br->lock); +} + +static void br_forward_delay_timer_expired(struct timer_list *t) +{ + struct net_bridge_port *p = from_timer(p, t, forward_delay_timer); + struct net_bridge *br = p->br; + + br_debug(br, "port %u(%s) forward delay timer\n", + (unsigned int) p->port_no, p->dev->name); + spin_lock(&br->lock); + if (p->state == BR_STATE_LISTENING) { + br_set_state(p, BR_STATE_LEARNING); + mod_timer(&p->forward_delay_timer, + jiffies + br->forward_delay); + } else if (p->state == BR_STATE_LEARNING) { + br_set_state(p, BR_STATE_FORWARDING); + if (br_is_designated_for_some_port(br)) + br_topology_change_detection(br); + netif_carrier_on(br->dev); + } + rcu_read_lock(); + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + rcu_read_unlock(); + spin_unlock(&br->lock); +} + +static void br_tcn_timer_expired(struct timer_list *t) +{ + struct net_bridge *br = from_timer(br, t, tcn_timer); + + br_debug(br, "tcn timer expired\n"); + spin_lock(&br->lock); + if (!br_is_root_bridge(br) && (br->dev->flags & IFF_UP)) { + br_transmit_tcn(br); + + mod_timer(&br->tcn_timer, jiffies + br->bridge_hello_time); + } + spin_unlock(&br->lock); +} + +static void br_topology_change_timer_expired(struct timer_list *t) +{ + struct net_bridge *br = from_timer(br, t, topology_change_timer); + + br_debug(br, "topo change timer expired\n"); + spin_lock(&br->lock); + br->topology_change_detected = 0; + __br_set_topology_change(br, 0); + spin_unlock(&br->lock); +} + +static void br_hold_timer_expired(struct timer_list *t) +{ + struct net_bridge_port *p = from_timer(p, t, hold_timer); + + br_debug(p->br, "port %u(%s) hold timer expired\n", + (unsigned int) p->port_no, p->dev->name); + + spin_lock(&p->br->lock); + if (p->config_pending) + br_transmit_config(p); + spin_unlock(&p->br->lock); +} + +void br_stp_timer_init(struct net_bridge *br) +{ + timer_setup(&br->hello_timer, br_hello_timer_expired, 0); + timer_setup(&br->tcn_timer, br_tcn_timer_expired, 0); + timer_setup(&br->topology_change_timer, + br_topology_change_timer_expired, 0); +} + +void br_stp_port_timer_init(struct net_bridge_port *p) +{ + timer_setup(&p->message_age_timer, br_message_age_timer_expired, 0); + timer_setup(&p->forward_delay_timer, br_forward_delay_timer_expired, 0); + timer_setup(&p->hold_timer, br_hold_timer_expired, 0); +} + +/* Report ticks left (in USER_HZ) used for API */ +unsigned long br_timer_value(const struct timer_list *timer) +{ + return timer_pending(timer) + ? 
jiffies_delta_to_clock_t(timer->expires - jiffies) : 0; +} diff --git a/net/bridge/br_switchdev.c b/net/bridge/br_switchdev.c new file mode 100644 index 000000000..4b3982c36 --- /dev/null +++ b/net/bridge/br_switchdev.c @@ -0,0 +1,825 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" + +static struct static_key_false br_switchdev_tx_fwd_offload; + +static bool nbp_switchdev_can_offload_tx_fwd(const struct net_bridge_port *p, + const struct sk_buff *skb) +{ + if (!static_branch_unlikely(&br_switchdev_tx_fwd_offload)) + return false; + + return (p->flags & BR_TX_FWD_OFFLOAD) && + (p->hwdom != BR_INPUT_SKB_CB(skb)->src_hwdom); +} + +bool br_switchdev_frame_uses_tx_fwd_offload(struct sk_buff *skb) +{ + if (!static_branch_unlikely(&br_switchdev_tx_fwd_offload)) + return false; + + return BR_INPUT_SKB_CB(skb)->tx_fwd_offload; +} + +void br_switchdev_frame_set_offload_fwd_mark(struct sk_buff *skb) +{ + skb->offload_fwd_mark = br_switchdev_frame_uses_tx_fwd_offload(skb); +} + +/* Mark the frame for TX forwarding offload if this egress port supports it */ +void nbp_switchdev_frame_mark_tx_fwd_offload(const struct net_bridge_port *p, + struct sk_buff *skb) +{ + if (nbp_switchdev_can_offload_tx_fwd(p, skb)) + BR_INPUT_SKB_CB(skb)->tx_fwd_offload = true; +} + +/* Lazily adds the hwdom of the egress bridge port to the bit mask of hwdoms + * that the skb has been already forwarded to, to avoid further cloning to + * other ports in the same hwdom by making nbp_switchdev_allowed_egress() + * return false. + */ +void nbp_switchdev_frame_mark_tx_fwd_to_hwdom(const struct net_bridge_port *p, + struct sk_buff *skb) +{ + if (nbp_switchdev_can_offload_tx_fwd(p, skb)) + set_bit(p->hwdom, &BR_INPUT_SKB_CB(skb)->fwd_hwdoms); +} + +void nbp_switchdev_frame_mark(const struct net_bridge_port *p, + struct sk_buff *skb) +{ + if (p->hwdom) + BR_INPUT_SKB_CB(skb)->src_hwdom = p->hwdom; +} + +bool nbp_switchdev_allowed_egress(const struct net_bridge_port *p, + const struct sk_buff *skb) +{ + struct br_input_skb_cb *cb = BR_INPUT_SKB_CB(skb); + + return !test_bit(p->hwdom, &cb->fwd_hwdoms) && + (!skb->offload_fwd_mark || cb->src_hwdom != p->hwdom); +} + +/* Flags that can be offloaded to hardware */ +#define BR_PORT_FLAGS_HW_OFFLOAD (BR_LEARNING | BR_FLOOD | \ + BR_MCAST_FLOOD | BR_BCAST_FLOOD | BR_PORT_LOCKED | \ + BR_HAIRPIN_MODE | BR_ISOLATED | BR_MULTICAST_TO_UNICAST) + +int br_switchdev_set_port_flag(struct net_bridge_port *p, + unsigned long flags, + unsigned long mask, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .orig_dev = p->dev, + }; + struct switchdev_notifier_port_attr_info info = { + .attr = &attr, + }; + int err; + + mask &= BR_PORT_FLAGS_HW_OFFLOAD; + if (!mask) + return 0; + + attr.id = SWITCHDEV_ATTR_ID_PORT_PRE_BRIDGE_FLAGS; + attr.u.brport_flags.val = flags; + attr.u.brport_flags.mask = mask; + + /* We run from atomic context here */ + err = call_switchdev_notifiers(SWITCHDEV_PORT_ATTR_SET, p->dev, + &info.info, extack); + err = notifier_to_errno(err); + if (err == -EOPNOTSUPP) + return 0; + + if (err) { + if (extack && !extack->_msg) + NL_SET_ERR_MSG_MOD(extack, + "bridge flag offload is not supported"); + return -EOPNOTSUPP; + } + + attr.id = SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS; + attr.flags = SWITCHDEV_F_DEFER; + + err = switchdev_port_attr_set(p->dev, &attr, extack); + if (err) { + if (extack && !extack->_msg) + NL_SET_ERR_MSG_MOD(extack, + "error setting offload flag on port"); + return 
err; + } + + return 0; +} + +static void br_switchdev_fdb_populate(struct net_bridge *br, + struct switchdev_notifier_fdb_info *item, + const struct net_bridge_fdb_entry *fdb, + const void *ctx) +{ + const struct net_bridge_port *p = READ_ONCE(fdb->dst); + + item->addr = fdb->key.addr.addr; + item->vid = fdb->key.vlan_id; + item->added_by_user = test_bit(BR_FDB_ADDED_BY_USER, &fdb->flags); + item->offloaded = test_bit(BR_FDB_OFFLOADED, &fdb->flags); + item->is_local = test_bit(BR_FDB_LOCAL, &fdb->flags); + item->info.dev = (!p || item->is_local) ? br->dev : p->dev; + item->info.ctx = ctx; +} + +void +br_switchdev_fdb_notify(struct net_bridge *br, + const struct net_bridge_fdb_entry *fdb, int type) +{ + struct switchdev_notifier_fdb_info item; + + /* Entries with these flags were created using ndm_state == NUD_REACHABLE, + * ndm_flags == NTF_MASTER( | NTF_STICKY), ext_flags == 0 by something + * equivalent to 'bridge fdb add ... master dynamic (sticky)'. + * Drivers don't know how to deal with these, so don't notify them to + * avoid confusing them. + */ + if (test_bit(BR_FDB_ADDED_BY_USER, &fdb->flags) && + !test_bit(BR_FDB_STATIC, &fdb->flags) && + !test_bit(BR_FDB_ADDED_BY_EXT_LEARN, &fdb->flags)) + return; + + br_switchdev_fdb_populate(br, &item, fdb, NULL); + + switch (type) { + case RTM_DELNEIGH: + call_switchdev_notifiers(SWITCHDEV_FDB_DEL_TO_DEVICE, + item.info.dev, &item.info, NULL); + break; + case RTM_NEWNEIGH: + call_switchdev_notifiers(SWITCHDEV_FDB_ADD_TO_DEVICE, + item.info.dev, &item.info, NULL); + break; + } +} + +int br_switchdev_port_vlan_add(struct net_device *dev, u16 vid, u16 flags, + bool changed, struct netlink_ext_ack *extack) +{ + struct switchdev_obj_port_vlan v = { + .obj.orig_dev = dev, + .obj.id = SWITCHDEV_OBJ_ID_PORT_VLAN, + .flags = flags, + .vid = vid, + .changed = changed, + }; + + return switchdev_port_obj_add(dev, &v.obj, extack); +} + +int br_switchdev_port_vlan_del(struct net_device *dev, u16 vid) +{ + struct switchdev_obj_port_vlan v = { + .obj.orig_dev = dev, + .obj.id = SWITCHDEV_OBJ_ID_PORT_VLAN, + .vid = vid, + }; + + return switchdev_port_obj_del(dev, &v.obj); +} + +static int nbp_switchdev_hwdom_set(struct net_bridge_port *joining) +{ + struct net_bridge *br = joining->br; + struct net_bridge_port *p; + int hwdom; + + /* joining is yet to be added to the port list. */ + list_for_each_entry(p, &br->port_list, list) { + if (netdev_phys_item_id_same(&joining->ppid, &p->ppid)) { + joining->hwdom = p->hwdom; + return 0; + } + } + + hwdom = find_next_zero_bit(&br->busy_hwdoms, BR_HWDOM_MAX, 1); + if (hwdom >= BR_HWDOM_MAX) + return -EBUSY; + + set_bit(hwdom, &br->busy_hwdoms); + joining->hwdom = hwdom; + return 0; +} + +static void nbp_switchdev_hwdom_put(struct net_bridge_port *leaving) +{ + struct net_bridge *br = leaving->br; + struct net_bridge_port *p; + + /* leaving is no longer in the port list. */ + list_for_each_entry(p, &br->port_list, list) { + if (p->hwdom == leaving->hwdom) + return; + } + + clear_bit(leaving->hwdom, &br->busy_hwdoms); +} + +static int nbp_switchdev_add(struct net_bridge_port *p, + struct netdev_phys_item_id ppid, + bool tx_fwd_offload, + struct netlink_ext_ack *extack) +{ + int err; + + if (p->offload_count) { + /* Prevent unsupported configurations such as a bridge port + * which is a bonding interface, and the member ports are from + * different hardware switches. 
+ */ + if (!netdev_phys_item_id_same(&p->ppid, &ppid)) { + NL_SET_ERR_MSG_MOD(extack, + "Same bridge port cannot be offloaded by two physical switches"); + return -EBUSY; + } + + /* Tolerate drivers that call switchdev_bridge_port_offload() + * more than once for the same bridge port, such as when the + * bridge port is an offloaded bonding/team interface. + */ + p->offload_count++; + + return 0; + } + + p->ppid = ppid; + p->offload_count = 1; + + err = nbp_switchdev_hwdom_set(p); + if (err) + return err; + + if (tx_fwd_offload) { + p->flags |= BR_TX_FWD_OFFLOAD; + static_branch_inc(&br_switchdev_tx_fwd_offload); + } + + return 0; +} + +static void nbp_switchdev_del(struct net_bridge_port *p) +{ + if (WARN_ON(!p->offload_count)) + return; + + p->offload_count--; + + if (p->offload_count) + return; + + if (p->hwdom) + nbp_switchdev_hwdom_put(p); + + if (p->flags & BR_TX_FWD_OFFLOAD) { + p->flags &= ~BR_TX_FWD_OFFLOAD; + static_branch_dec(&br_switchdev_tx_fwd_offload); + } +} + +static int +br_switchdev_fdb_replay_one(struct net_bridge *br, struct notifier_block *nb, + const struct net_bridge_fdb_entry *fdb, + unsigned long action, const void *ctx) +{ + struct switchdev_notifier_fdb_info item; + int err; + + br_switchdev_fdb_populate(br, &item, fdb, ctx); + + err = nb->notifier_call(nb, action, &item); + return notifier_to_errno(err); +} + +static int +br_switchdev_fdb_replay(const struct net_device *br_dev, const void *ctx, + bool adding, struct notifier_block *nb) +{ + struct net_bridge_fdb_entry *fdb; + struct net_bridge *br; + unsigned long action; + int err = 0; + + if (!nb) + return 0; + + if (!netif_is_bridge_master(br_dev)) + return -EINVAL; + + br = netdev_priv(br_dev); + + if (adding) + action = SWITCHDEV_FDB_ADD_TO_DEVICE; + else + action = SWITCHDEV_FDB_DEL_TO_DEVICE; + + rcu_read_lock(); + + hlist_for_each_entry_rcu(fdb, &br->fdb_list, fdb_node) { + err = br_switchdev_fdb_replay_one(br, nb, fdb, action, ctx); + if (err) + break; + } + + rcu_read_unlock(); + + return err; +} + +static int br_switchdev_vlan_attr_replay(struct net_device *br_dev, + const void *ctx, + struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct switchdev_notifier_port_attr_info attr_info = { + .info = { + .dev = br_dev, + .extack = extack, + .ctx = ctx, + }, + }; + struct net_bridge *br = netdev_priv(br_dev); + struct net_bridge_vlan_group *vg; + struct switchdev_attr attr; + struct net_bridge_vlan *v; + int err; + + attr_info.attr = &attr; + attr.orig_dev = br_dev; + + vg = br_vlan_group(br); + if (!vg) + return 0; + + list_for_each_entry(v, &vg->vlan_list, vlist) { + if (v->msti) { + attr.id = SWITCHDEV_ATTR_ID_VLAN_MSTI; + attr.u.vlan_msti.vid = v->vid; + attr.u.vlan_msti.msti = v->msti; + + err = nb->notifier_call(nb, SWITCHDEV_PORT_ATTR_SET, + &attr_info); + err = notifier_to_errno(err); + if (err) + return err; + } + } + + return 0; +} + +static int +br_switchdev_vlan_replay_one(struct notifier_block *nb, + struct net_device *dev, + struct switchdev_obj_port_vlan *vlan, + const void *ctx, unsigned long action, + struct netlink_ext_ack *extack) +{ + struct switchdev_notifier_port_obj_info obj_info = { + .info = { + .dev = dev, + .extack = extack, + .ctx = ctx, + }, + .obj = &vlan->obj, + }; + int err; + + err = nb->notifier_call(nb, action, &obj_info); + return notifier_to_errno(err); +} + +static int br_switchdev_vlan_replay_group(struct notifier_block *nb, + struct net_device *dev, + struct net_bridge_vlan_group *vg, + const void *ctx, unsigned long action, + struct 
netlink_ext_ack *extack) +{ + struct net_bridge_vlan *v; + int err = 0; + u16 pvid; + + if (!vg) + return 0; + + pvid = br_get_pvid(vg); + + list_for_each_entry(v, &vg->vlan_list, vlist) { + struct switchdev_obj_port_vlan vlan = { + .obj.orig_dev = dev, + .obj.id = SWITCHDEV_OBJ_ID_PORT_VLAN, + .flags = br_vlan_flags(v, pvid), + .vid = v->vid, + }; + + if (!br_vlan_should_use(v)) + continue; + + err = br_switchdev_vlan_replay_one(nb, dev, &vlan, ctx, + action, extack); + if (err) + return err; + } + + return 0; +} + +static int br_switchdev_vlan_replay(struct net_device *br_dev, + const void *ctx, bool adding, + struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct net_bridge *br = netdev_priv(br_dev); + struct net_bridge_port *p; + unsigned long action; + int err; + + ASSERT_RTNL(); + + if (!nb) + return 0; + + if (!netif_is_bridge_master(br_dev)) + return -EINVAL; + + if (adding) + action = SWITCHDEV_PORT_OBJ_ADD; + else + action = SWITCHDEV_PORT_OBJ_DEL; + + err = br_switchdev_vlan_replay_group(nb, br_dev, br_vlan_group(br), + ctx, action, extack); + if (err) + return err; + + list_for_each_entry(p, &br->port_list, list) { + struct net_device *dev = p->dev; + + err = br_switchdev_vlan_replay_group(nb, dev, + nbp_vlan_group(p), + ctx, action, extack); + if (err) + return err; + } + + if (adding) { + err = br_switchdev_vlan_attr_replay(br_dev, ctx, nb, extack); + if (err) + return err; + } + + return 0; +} + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +struct br_switchdev_mdb_complete_info { + struct net_bridge_port *port; + struct br_ip ip; +}; + +static void br_switchdev_mdb_complete(struct net_device *dev, int err, void *priv) +{ + struct br_switchdev_mdb_complete_info *data = priv; + struct net_bridge_port_group __rcu **pp; + struct net_bridge_port_group *p; + struct net_bridge_mdb_entry *mp; + struct net_bridge_port *port = data->port; + struct net_bridge *br = port->br; + + if (err) + goto err; + + spin_lock_bh(&br->multicast_lock); + mp = br_mdb_ip_get(br, &data->ip); + if (!mp) + goto out; + for (pp = &mp->ports; (p = mlock_dereference(*pp, br)) != NULL; + pp = &p->next) { + if (p->key.port != port) + continue; + p->flags |= MDB_PG_FLAGS_OFFLOAD; + } +out: + spin_unlock_bh(&br->multicast_lock); +err: + kfree(priv); +} + +static void br_switchdev_mdb_populate(struct switchdev_obj_port_mdb *mdb, + const struct net_bridge_mdb_entry *mp) +{ + if (mp->addr.proto == htons(ETH_P_IP)) + ip_eth_mc_map(mp->addr.dst.ip4, mdb->addr); +#if IS_ENABLED(CONFIG_IPV6) + else if (mp->addr.proto == htons(ETH_P_IPV6)) + ipv6_eth_mc_map(&mp->addr.dst.ip6, mdb->addr); +#endif + else + ether_addr_copy(mdb->addr, mp->addr.dst.mac_addr); + + mdb->vid = mp->addr.vid; +} + +static void br_switchdev_host_mdb_one(struct net_device *dev, + struct net_device *lower_dev, + struct net_bridge_mdb_entry *mp, + int type) +{ + struct switchdev_obj_port_mdb mdb = { + .obj = { + .id = SWITCHDEV_OBJ_ID_HOST_MDB, + .flags = SWITCHDEV_F_DEFER, + .orig_dev = dev, + }, + }; + + br_switchdev_mdb_populate(&mdb, mp); + + switch (type) { + case RTM_NEWMDB: + switchdev_port_obj_add(lower_dev, &mdb.obj, NULL); + break; + case RTM_DELMDB: + switchdev_port_obj_del(lower_dev, &mdb.obj); + break; + } +} + +static void br_switchdev_host_mdb(struct net_device *dev, + struct net_bridge_mdb_entry *mp, int type) +{ + struct net_device *lower_dev; + struct list_head *iter; + + netdev_for_each_lower_dev(dev, lower_dev, iter) + br_switchdev_host_mdb_one(dev, lower_dev, mp, type); +} + +static int +br_switchdev_mdb_replay_one(struct 
notifier_block *nb, struct net_device *dev, + const struct switchdev_obj_port_mdb *mdb, + unsigned long action, const void *ctx, + struct netlink_ext_ack *extack) +{ + struct switchdev_notifier_port_obj_info obj_info = { + .info = { + .dev = dev, + .extack = extack, + .ctx = ctx, + }, + .obj = &mdb->obj, + }; + int err; + + err = nb->notifier_call(nb, action, &obj_info); + return notifier_to_errno(err); +} + +static int br_switchdev_mdb_queue_one(struct list_head *mdb_list, + enum switchdev_obj_id id, + const struct net_bridge_mdb_entry *mp, + struct net_device *orig_dev) +{ + struct switchdev_obj_port_mdb *mdb; + + mdb = kzalloc(sizeof(*mdb), GFP_ATOMIC); + if (!mdb) + return -ENOMEM; + + mdb->obj.id = id; + mdb->obj.orig_dev = orig_dev; + br_switchdev_mdb_populate(mdb, mp); + list_add_tail(&mdb->obj.list, mdb_list); + + return 0; +} + +void br_switchdev_mdb_notify(struct net_device *dev, + struct net_bridge_mdb_entry *mp, + struct net_bridge_port_group *pg, + int type) +{ + struct br_switchdev_mdb_complete_info *complete_info; + struct switchdev_obj_port_mdb mdb = { + .obj = { + .id = SWITCHDEV_OBJ_ID_PORT_MDB, + .flags = SWITCHDEV_F_DEFER, + }, + }; + + if (!pg) + return br_switchdev_host_mdb(dev, mp, type); + + br_switchdev_mdb_populate(&mdb, mp); + + mdb.obj.orig_dev = pg->key.port->dev; + switch (type) { + case RTM_NEWMDB: + complete_info = kmalloc(sizeof(*complete_info), GFP_ATOMIC); + if (!complete_info) + break; + complete_info->port = pg->key.port; + complete_info->ip = mp->addr; + mdb.obj.complete_priv = complete_info; + mdb.obj.complete = br_switchdev_mdb_complete; + if (switchdev_port_obj_add(pg->key.port->dev, &mdb.obj, NULL)) + kfree(complete_info); + break; + case RTM_DELMDB: + switchdev_port_obj_del(pg->key.port->dev, &mdb.obj); + break; + } +} +#endif + +static int +br_switchdev_mdb_replay(struct net_device *br_dev, struct net_device *dev, + const void *ctx, bool adding, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + const struct net_bridge_mdb_entry *mp; + struct switchdev_obj *obj, *tmp; + struct net_bridge *br; + unsigned long action; + LIST_HEAD(mdb_list); + int err = 0; + + ASSERT_RTNL(); + + if (!nb) + return 0; + + if (!netif_is_bridge_master(br_dev) || !netif_is_bridge_port(dev)) + return -EINVAL; + + br = netdev_priv(br_dev); + + if (!br_opt_get(br, BROPT_MULTICAST_ENABLED)) + return 0; + + /* We cannot walk over br->mdb_list protected just by the rtnl_mutex, + * because the write-side protection is br->multicast_lock. But we + * need to emulate the [ blocking ] calling context of a regular + * switchdev event, so since both br->multicast_lock and RCU read side + * critical sections are atomic, we have no choice but to pick the RCU + * read side lock, queue up all our events, leave the critical section + * and notify switchdev from blocking context. 
+ */ + rcu_read_lock(); + + hlist_for_each_entry_rcu(mp, &br->mdb_list, mdb_node) { + struct net_bridge_port_group __rcu * const *pp; + const struct net_bridge_port_group *p; + + if (mp->host_joined) { + err = br_switchdev_mdb_queue_one(&mdb_list, + SWITCHDEV_OBJ_ID_HOST_MDB, + mp, br_dev); + if (err) { + rcu_read_unlock(); + goto out_free_mdb; + } + } + + for (pp = &mp->ports; (p = rcu_dereference(*pp)) != NULL; + pp = &p->next) { + if (p->key.port->dev != dev) + continue; + + err = br_switchdev_mdb_queue_one(&mdb_list, + SWITCHDEV_OBJ_ID_PORT_MDB, + mp, dev); + if (err) { + rcu_read_unlock(); + goto out_free_mdb; + } + } + } + + rcu_read_unlock(); + + if (adding) + action = SWITCHDEV_PORT_OBJ_ADD; + else + action = SWITCHDEV_PORT_OBJ_DEL; + + list_for_each_entry(obj, &mdb_list, list) { + err = br_switchdev_mdb_replay_one(nb, dev, + SWITCHDEV_OBJ_PORT_MDB(obj), + action, ctx, extack); + if (err) + goto out_free_mdb; + } + +out_free_mdb: + list_for_each_entry_safe(obj, tmp, &mdb_list, list) { + list_del(&obj->list); + kfree(SWITCHDEV_OBJ_PORT_MDB(obj)); + } + + if (err) + return err; +#endif + + return 0; +} + +static int nbp_switchdev_sync_objs(struct net_bridge_port *p, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb, + struct netlink_ext_ack *extack) +{ + struct net_device *br_dev = p->br->dev; + struct net_device *dev = p->dev; + int err; + + err = br_switchdev_vlan_replay(br_dev, ctx, true, blocking_nb, extack); + if (err && err != -EOPNOTSUPP) + return err; + + err = br_switchdev_mdb_replay(br_dev, dev, ctx, true, blocking_nb, + extack); + if (err && err != -EOPNOTSUPP) + return err; + + err = br_switchdev_fdb_replay(br_dev, ctx, true, atomic_nb); + if (err && err != -EOPNOTSUPP) + return err; + + return 0; +} + +static void nbp_switchdev_unsync_objs(struct net_bridge_port *p, + const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb) +{ + struct net_device *br_dev = p->br->dev; + struct net_device *dev = p->dev; + + br_switchdev_fdb_replay(br_dev, ctx, false, atomic_nb); + + br_switchdev_mdb_replay(br_dev, dev, ctx, false, blocking_nb, NULL); + + br_switchdev_vlan_replay(br_dev, ctx, false, blocking_nb, NULL); +} + +/* Let the bridge know that this port is offloaded, so that it can assign a + * switchdev hardware domain to it. 
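+ *
+ * Concretely: the port parent ID of @dev is resolved, the port is attached
+ * to a hardware domain via nbp_switchdev_add(), and the bridge's current
+ * VLANs, MDB entries and FDB entries are replayed to the caller's
+ * notifiers so the driver starts out in sync with the software state; if
+ * that replay fails, nbp_switchdev_add() is rolled back. This is normally
+ * reached via switchdev_bridge_port_offload() while a port joins the
+ * bridge (an assumption about the usual caller, not something enforced
+ * here).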
+ */ +int br_switchdev_port_offload(struct net_bridge_port *p, + struct net_device *dev, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb, + bool tx_fwd_offload, + struct netlink_ext_ack *extack) +{ + struct netdev_phys_item_id ppid; + int err; + + err = dev_get_port_parent_id(dev, &ppid, false); + if (err) + return err; + + err = nbp_switchdev_add(p, ppid, tx_fwd_offload, extack); + if (err) + return err; + + err = nbp_switchdev_sync_objs(p, ctx, atomic_nb, blocking_nb, extack); + if (err) + goto out_switchdev_del; + + return 0; + +out_switchdev_del: + nbp_switchdev_del(p); + + return err; +} + +void br_switchdev_port_unoffload(struct net_bridge_port *p, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb) +{ + nbp_switchdev_unsync_objs(p, ctx, atomic_nb, blocking_nb); + + nbp_switchdev_del(p); +} diff --git a/net/bridge/br_sysfs_br.c b/net/bridge/br_sysfs_br.c new file mode 100644 index 000000000..ea7335422 --- /dev/null +++ b/net/bridge/br_sysfs_br.c @@ -0,0 +1,1088 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Sysfs attributes of bridge + * Linux ethernet bridge + * + * Authors: + * Stephen Hemminger + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" + +/* IMPORTANT: new bridge options must be added with netlink support only + * please do not add new sysfs entries + */ + +#define to_bridge(cd) ((struct net_bridge *)netdev_priv(to_net_dev(cd))) + +/* + * Common code for storing bridge parameters. + */ +static ssize_t store_bridge_parm(struct device *d, + const char *buf, size_t len, + int (*set)(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack)) +{ + struct net_bridge *br = to_bridge(d); + struct netlink_ext_ack extack = {0}; + unsigned long val; + int err; + + if (!ns_capable(dev_net(br->dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + err = kstrtoul(buf, 0, &val); + if (err != 0) + return err; + + if (!rtnl_trylock()) + return restart_syscall(); + + err = (*set)(br, val, &extack); + if (!err) + netdev_state_change(br->dev); + if (extack._msg) { + if (err) + br_err(br, "%s\n", extack._msg); + else + br_warn(br, "%s\n", extack._msg); + } + rtnl_unlock(); + + return err ? 
err : len; +} + + +static ssize_t forward_delay_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", jiffies_to_clock_t(br->forward_delay)); +} + +static int set_forward_delay(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_set_forward_delay(br, val); +} + +static ssize_t forward_delay_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_forward_delay); +} +static DEVICE_ATTR_RW(forward_delay); + +static ssize_t hello_time_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(to_bridge(d)->hello_time)); +} + +static int set_hello_time(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_set_hello_time(br, val); +} + +static ssize_t hello_time_store(struct device *d, + struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_hello_time); +} +static DEVICE_ATTR_RW(hello_time); + +static ssize_t max_age_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(to_bridge(d)->max_age)); +} + +static int set_max_age(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_set_max_age(br, val); +} + +static ssize_t max_age_store(struct device *d, struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_max_age); +} +static DEVICE_ATTR_RW(max_age); + +static ssize_t ageing_time_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", jiffies_to_clock_t(br->ageing_time)); +} + +static int set_ageing_time(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_set_ageing_time(br, val); +} + +static ssize_t ageing_time_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_ageing_time); +} +static DEVICE_ATTR_RW(ageing_time); + +static ssize_t stp_state_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br->stp_enabled); +} + + +static int set_stp_state(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_stp_set_enabled(br, val, extack); +} + +static ssize_t stp_state_store(struct device *d, + struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_stp_state); +} +static DEVICE_ATTR_RW(stp_state); + +static ssize_t group_fwd_mask_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%#x\n", br->group_fwd_mask); +} + +static int set_group_fwd_mask(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + if (val & BR_GROUPFWD_RESTRICTED) + return -EINVAL; + + br->group_fwd_mask = val; + + return 0; +} + +static ssize_t group_fwd_mask_store(struct device *d, + struct device_attribute *attr, + const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_group_fwd_mask); +} +static DEVICE_ATTR_RW(group_fwd_mask); + +static ssize_t priority_show(struct device *d, struct device_attribute *attr, + char 
*buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", + (br->bridge_id.prio[0] << 8) | br->bridge_id.prio[1]); +} + +static int set_priority(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_stp_set_bridge_priority(br, (u16) val); + return 0; +} + +static ssize_t priority_store(struct device *d, struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_priority); +} +static DEVICE_ATTR_RW(priority); + +static ssize_t root_id_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + return br_show_bridge_id(buf, &to_bridge(d)->designated_root); +} +static DEVICE_ATTR_RO(root_id); + +static ssize_t bridge_id_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + return br_show_bridge_id(buf, &to_bridge(d)->bridge_id); +} +static DEVICE_ATTR_RO(bridge_id); + +static ssize_t root_port_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + return sprintf(buf, "%d\n", to_bridge(d)->root_port); +} +static DEVICE_ATTR_RO(root_port); + +static ssize_t root_path_cost_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + return sprintf(buf, "%d\n", to_bridge(d)->root_path_cost); +} +static DEVICE_ATTR_RO(root_path_cost); + +static ssize_t topology_change_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + return sprintf(buf, "%d\n", to_bridge(d)->topology_change); +} +static DEVICE_ATTR_RO(topology_change); + +static ssize_t topology_change_detected_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br->topology_change_detected); +} +static DEVICE_ATTR_RO(topology_change_detected); + +static ssize_t hello_timer_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%ld\n", br_timer_value(&br->hello_timer)); +} +static DEVICE_ATTR_RO(hello_timer); + +static ssize_t tcn_timer_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%ld\n", br_timer_value(&br->tcn_timer)); +} +static DEVICE_ATTR_RO(tcn_timer); + +static ssize_t topology_change_timer_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%ld\n", br_timer_value(&br->topology_change_timer)); +} +static DEVICE_ATTR_RO(topology_change_timer); + +static ssize_t gc_timer_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%ld\n", br_timer_value(&br->gc_work.timer)); +} +static DEVICE_ATTR_RO(gc_timer); + +static ssize_t group_addr_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%pM\n", br->group_addr); +} + +static ssize_t group_addr_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + struct net_bridge *br = to_bridge(d); + u8 new_addr[6]; + + if (!ns_capable(dev_net(br->dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!mac_pton(buf, new_addr)) + return -EINVAL; + + if (!is_link_local_ether_addr(new_addr)) + return -EINVAL; + + if (new_addr[5] == 1 || /* 802.3x Pause address */ + new_addr[5] == 2 || /* 802.3ad Slow protocols */ + new_addr[5] == 3) /* 802.1X PAE address */ + return -EINVAL; + + if (!rtnl_trylock()) + return 
restart_syscall(); + + spin_lock_bh(&br->lock); + ether_addr_copy(br->group_addr, new_addr); + spin_unlock_bh(&br->lock); + + br_opt_toggle(br, BROPT_GROUP_ADDR_SET, true); + br_recalculate_fwd_mask(br); + netdev_state_change(br->dev); + + rtnl_unlock(); + + return len; +} + +static DEVICE_ATTR_RW(group_addr); + +static int set_flush(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + struct net_bridge_fdb_flush_desc desc = { + .flags_mask = BIT(BR_FDB_STATIC) + }; + + br_fdb_flush(br, &desc); + return 0; +} + +static ssize_t flush_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_flush); +} +static DEVICE_ATTR_WO(flush); + +static ssize_t no_linklocal_learn_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br_boolopt_get(br, BR_BOOLOPT_NO_LL_LEARN)); +} + +static int set_no_linklocal_learn(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_boolopt_toggle(br, BR_BOOLOPT_NO_LL_LEARN, !!val, extack); +} + +static ssize_t no_linklocal_learn_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_no_linklocal_learn); +} +static DEVICE_ATTR_RW(no_linklocal_learn); + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +static ssize_t multicast_router_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br->multicast_ctx.multicast_router); +} + +static int set_multicast_router(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_multicast_set_router(&br->multicast_ctx, val); +} + +static ssize_t multicast_router_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_multicast_router); +} +static DEVICE_ATTR_RW(multicast_router); + +static ssize_t multicast_snooping_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br_opt_get(br, BROPT_MULTICAST_ENABLED)); +} + +static ssize_t multicast_snooping_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, br_multicast_toggle); +} +static DEVICE_ATTR_RW(multicast_snooping); + +static ssize_t multicast_query_use_ifaddr_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", + br_opt_get(br, BROPT_MULTICAST_QUERY_USE_IFADDR)); +} + +static int set_query_use_ifaddr(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_opt_toggle(br, BROPT_MULTICAST_QUERY_USE_IFADDR, !!val); + return 0; +} + +static ssize_t +multicast_query_use_ifaddr_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_query_use_ifaddr); +} +static DEVICE_ATTR_RW(multicast_query_use_ifaddr); + +static ssize_t multicast_querier_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br->multicast_ctx.multicast_querier); +} + +static int set_multicast_querier(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return 
br_multicast_set_querier(&br->multicast_ctx, val); +} + +static ssize_t multicast_querier_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_multicast_querier); +} +static DEVICE_ATTR_RW(multicast_querier); + +static ssize_t hash_elasticity_show(struct device *d, + struct device_attribute *attr, char *buf) +{ + return sprintf(buf, "%u\n", RHT_ELASTICITY); +} + +static int set_elasticity(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + /* 16 is RHT_ELASTICITY */ + NL_SET_ERR_MSG_MOD(extack, + "the hash_elasticity option has been deprecated and is always 16"); + return 0; +} + +static ssize_t hash_elasticity_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_elasticity); +} +static DEVICE_ATTR_RW(hash_elasticity); + +static ssize_t hash_max_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br->hash_max); +} + +static int set_hash_max(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->hash_max = val; + return 0; +} + +static ssize_t hash_max_store(struct device *d, struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_hash_max); +} +static DEVICE_ATTR_RW(hash_max); + +static ssize_t multicast_igmp_version_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + + return sprintf(buf, "%u\n", br->multicast_ctx.multicast_igmp_version); +} + +static int set_multicast_igmp_version(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_multicast_set_igmp_version(&br->multicast_ctx, val); +} + +static ssize_t multicast_igmp_version_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_multicast_igmp_version); +} +static DEVICE_ATTR_RW(multicast_igmp_version); + +static ssize_t multicast_last_member_count_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br->multicast_ctx.multicast_last_member_count); +} + +static int set_last_member_count(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_last_member_count = val; + return 0; +} + +static ssize_t multicast_last_member_count_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_last_member_count); +} +static DEVICE_ATTR_RW(multicast_last_member_count); + +static ssize_t multicast_startup_query_count_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br->multicast_ctx.multicast_startup_query_count); +} + +static int set_startup_query_count(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_startup_query_count = val; + return 0; +} + +static ssize_t multicast_startup_query_count_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_startup_query_count); +} +static DEVICE_ATTR_RW(multicast_startup_query_count); + +static ssize_t 
multicast_last_member_interval_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_last_member_interval)); +} + +static int set_last_member_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_last_member_interval = clock_t_to_jiffies(val); + return 0; +} + +static ssize_t multicast_last_member_interval_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_last_member_interval); +} +static DEVICE_ATTR_RW(multicast_last_member_interval); + +static ssize_t multicast_membership_interval_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_membership_interval)); +} + +static int set_membership_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_membership_interval = clock_t_to_jiffies(val); + return 0; +} + +static ssize_t multicast_membership_interval_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_membership_interval); +} +static DEVICE_ATTR_RW(multicast_membership_interval); + +static ssize_t multicast_querier_interval_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_querier_interval)); +} + +static int set_querier_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_querier_interval = clock_t_to_jiffies(val); + return 0; +} + +static ssize_t multicast_querier_interval_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_querier_interval); +} +static DEVICE_ATTR_RW(multicast_querier_interval); + +static ssize_t multicast_query_interval_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_query_interval)); +} + +static int set_query_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_multicast_set_query_intvl(&br->multicast_ctx, val); + return 0; +} + +static ssize_t multicast_query_interval_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_query_interval); +} +static DEVICE_ATTR_RW(multicast_query_interval); + +static ssize_t multicast_query_response_interval_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf( + buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_query_response_interval)); +} + +static int set_query_response_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br->multicast_ctx.multicast_query_response_interval = clock_t_to_jiffies(val); + return 0; +} + +static ssize_t multicast_query_response_interval_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, 
set_query_response_interval); +} +static DEVICE_ATTR_RW(multicast_query_response_interval); + +static ssize_t multicast_startup_query_interval_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf( + buf, "%lu\n", + jiffies_to_clock_t(br->multicast_ctx.multicast_startup_query_interval)); +} + +static int set_startup_query_interval(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_multicast_set_startup_query_intvl(&br->multicast_ctx, val); + return 0; +} + +static ssize_t multicast_startup_query_interval_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_startup_query_interval); +} +static DEVICE_ATTR_RW(multicast_startup_query_interval); + +static ssize_t multicast_stats_enabled_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + + return sprintf(buf, "%d\n", + br_opt_get(br, BROPT_MULTICAST_STATS_ENABLED)); +} + +static int set_stats_enabled(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_opt_toggle(br, BROPT_MULTICAST_STATS_ENABLED, !!val); + return 0; +} + +static ssize_t multicast_stats_enabled_store(struct device *d, + struct device_attribute *attr, + const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_stats_enabled); +} +static DEVICE_ATTR_RW(multicast_stats_enabled); + +#if IS_ENABLED(CONFIG_IPV6) +static ssize_t multicast_mld_version_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + + return sprintf(buf, "%u\n", br->multicast_ctx.multicast_mld_version); +} + +static int set_multicast_mld_version(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_multicast_set_mld_version(&br->multicast_ctx, val); +} + +static ssize_t multicast_mld_version_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_multicast_mld_version); +} +static DEVICE_ATTR_RW(multicast_mld_version); +#endif +#endif +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +static ssize_t nf_call_iptables_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_NF_CALL_IPTABLES)); +} + +static int set_nf_call_iptables(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_opt_toggle(br, BROPT_NF_CALL_IPTABLES, !!val); + return 0; +} + +static ssize_t nf_call_iptables_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_nf_call_iptables); +} +static DEVICE_ATTR_RW(nf_call_iptables); + +static ssize_t nf_call_ip6tables_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_NF_CALL_IP6TABLES)); +} + +static int set_nf_call_ip6tables(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_opt_toggle(br, BROPT_NF_CALL_IP6TABLES, !!val); + return 0; +} + +static ssize_t nf_call_ip6tables_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_nf_call_ip6tables); +} +static DEVICE_ATTR_RW(nf_call_ip6tables); + +static ssize_t 
nf_call_arptables_show( + struct device *d, struct device_attribute *attr, char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_NF_CALL_ARPTABLES)); +} + +static int set_nf_call_arptables(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + br_opt_toggle(br, BROPT_NF_CALL_ARPTABLES, !!val); + return 0; +} + +static ssize_t nf_call_arptables_store( + struct device *d, struct device_attribute *attr, const char *buf, + size_t len) +{ + return store_bridge_parm(d, buf, len, set_nf_call_arptables); +} +static DEVICE_ATTR_RW(nf_call_arptables); +#endif +#ifdef CONFIG_BRIDGE_VLAN_FILTERING +static ssize_t vlan_filtering_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br_opt_get(br, BROPT_VLAN_ENABLED)); +} + +static ssize_t vlan_filtering_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, br_vlan_filter_toggle); +} +static DEVICE_ATTR_RW(vlan_filtering); + +static ssize_t vlan_protocol_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%#06x\n", ntohs(br->vlan_proto)); +} + +static ssize_t vlan_protocol_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, br_vlan_set_proto); +} +static DEVICE_ATTR_RW(vlan_protocol); + +static ssize_t default_pvid_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%d\n", br->default_pvid); +} + +static ssize_t default_pvid_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, br_vlan_set_default_pvid); +} +static DEVICE_ATTR_RW(default_pvid); + +static ssize_t vlan_stats_enabled_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_VLAN_STATS_ENABLED)); +} + +static int set_vlan_stats_enabled(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_vlan_set_stats(br, val); +} + +static ssize_t vlan_stats_enabled_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_vlan_stats_enabled); +} +static DEVICE_ATTR_RW(vlan_stats_enabled); + +static ssize_t vlan_stats_per_port_show(struct device *d, + struct device_attribute *attr, + char *buf) +{ + struct net_bridge *br = to_bridge(d); + return sprintf(buf, "%u\n", br_opt_get(br, BROPT_VLAN_STATS_PER_PORT)); +} + +static int set_vlan_stats_per_port(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + return br_vlan_set_stats_per_port(br, val); +} + +static ssize_t vlan_stats_per_port_store(struct device *d, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return store_bridge_parm(d, buf, len, set_vlan_stats_per_port); +} +static DEVICE_ATTR_RW(vlan_stats_per_port); +#endif + +static struct attribute *bridge_attrs[] = { + &dev_attr_forward_delay.attr, + &dev_attr_hello_time.attr, + &dev_attr_max_age.attr, + &dev_attr_ageing_time.attr, + &dev_attr_stp_state.attr, + &dev_attr_group_fwd_mask.attr, + &dev_attr_priority.attr, + &dev_attr_bridge_id.attr, + &dev_attr_root_id.attr, + 
&dev_attr_root_path_cost.attr, + &dev_attr_root_port.attr, + &dev_attr_topology_change.attr, + &dev_attr_topology_change_detected.attr, + &dev_attr_hello_timer.attr, + &dev_attr_tcn_timer.attr, + &dev_attr_topology_change_timer.attr, + &dev_attr_gc_timer.attr, + &dev_attr_group_addr.attr, + &dev_attr_flush.attr, + &dev_attr_no_linklocal_learn.attr, +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + &dev_attr_multicast_router.attr, + &dev_attr_multicast_snooping.attr, + &dev_attr_multicast_querier.attr, + &dev_attr_multicast_query_use_ifaddr.attr, + &dev_attr_hash_elasticity.attr, + &dev_attr_hash_max.attr, + &dev_attr_multicast_last_member_count.attr, + &dev_attr_multicast_startup_query_count.attr, + &dev_attr_multicast_last_member_interval.attr, + &dev_attr_multicast_membership_interval.attr, + &dev_attr_multicast_querier_interval.attr, + &dev_attr_multicast_query_interval.attr, + &dev_attr_multicast_query_response_interval.attr, + &dev_attr_multicast_startup_query_interval.attr, + &dev_attr_multicast_stats_enabled.attr, + &dev_attr_multicast_igmp_version.attr, +#if IS_ENABLED(CONFIG_IPV6) + &dev_attr_multicast_mld_version.attr, +#endif +#endif +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + &dev_attr_nf_call_iptables.attr, + &dev_attr_nf_call_ip6tables.attr, + &dev_attr_nf_call_arptables.attr, +#endif +#ifdef CONFIG_BRIDGE_VLAN_FILTERING + &dev_attr_vlan_filtering.attr, + &dev_attr_vlan_protocol.attr, + &dev_attr_default_pvid.attr, + &dev_attr_vlan_stats_enabled.attr, + &dev_attr_vlan_stats_per_port.attr, +#endif + NULL +}; + +static const struct attribute_group bridge_group = { + .name = SYSFS_BRIDGE_ATTR, + .attrs = bridge_attrs, +}; + +/* + * Export the forwarding information table as a binary file + * The records are struct __fdb_entry. + * + * Returns the number of bytes read. + */ +static ssize_t brforward_read(struct file *filp, struct kobject *kobj, + struct bin_attribute *bin_attr, + char *buf, loff_t off, size_t count) +{ + struct device *dev = kobj_to_dev(kobj); + struct net_bridge *br = to_bridge(dev); + int n; + + /* must read whole records */ + if (off % sizeof(struct __fdb_entry) != 0) + return -EINVAL; + + n = br_fdb_fillbuf(br, buf, + count / sizeof(struct __fdb_entry), + off / sizeof(struct __fdb_entry)); + + if (n > 0) + n *= sizeof(struct __fdb_entry); + + return n; +} + +static struct bin_attribute bridge_forward = { + .attr = { .name = SYSFS_BRIDGE_FDB, + .mode = 0444, }, + .read = brforward_read, +}; + +/* + * Add entries in sysfs onto the existing network class device + * for the bridge. + * Adds a attribute group "bridge" containing tuning parameters. + * Binary attribute containing the forward table + * Sub directory to hold links to interfaces. + * + * Note: the ifobj exists only to be a subdirectory + * to hold links. The ifobj exists in same data structure + * as it's parent the bridge so reference counting works. 
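+ *
+ * Assuming the usual SYSFS_BRIDGE_* names, the resulting layout for a
+ * bridge "br0" is roughly:
+ *
+ *   /sys/class/net/br0/bridge/    - tuning attributes (bridge_group)
+ *   /sys/class/net/br0/brforward  - binary dump of struct __fdb_entry
+ *   /sys/class/net/br0/brif/      - symlinks to the enslaved ports
+ *
+ * Reads of brforward must start at a multiple of
+ * sizeof(struct __fdb_entry); brforward_read() returns whole records only.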
+ */ +int br_sysfs_addbr(struct net_device *dev) +{ + struct kobject *brobj = &dev->dev.kobj; + struct net_bridge *br = netdev_priv(dev); + int err; + + err = sysfs_create_group(brobj, &bridge_group); + if (err) { + pr_info("%s: can't create group %s/%s\n", + __func__, dev->name, bridge_group.name); + goto out1; + } + + err = sysfs_create_bin_file(brobj, &bridge_forward); + if (err) { + pr_info("%s: can't create attribute file %s/%s\n", + __func__, dev->name, bridge_forward.attr.name); + goto out2; + } + + br->ifobj = kobject_create_and_add(SYSFS_BRIDGE_PORT_SUBDIR, brobj); + if (!br->ifobj) { + pr_info("%s: can't add kobject (directory) %s/%s\n", + __func__, dev->name, SYSFS_BRIDGE_PORT_SUBDIR); + err = -ENOMEM; + goto out3; + } + return 0; + out3: + sysfs_remove_bin_file(&dev->dev.kobj, &bridge_forward); + out2: + sysfs_remove_group(&dev->dev.kobj, &bridge_group); + out1: + return err; + +} + +void br_sysfs_delbr(struct net_device *dev) +{ + struct kobject *kobj = &dev->dev.kobj; + struct net_bridge *br = netdev_priv(dev); + + kobject_put(br->ifobj); + sysfs_remove_bin_file(kobj, &bridge_forward); + sysfs_remove_group(kobj, &bridge_group); +} diff --git a/net/bridge/br_sysfs_if.c b/net/bridge/br_sysfs_if.c new file mode 100644 index 000000000..74fdd8105 --- /dev/null +++ b/net/bridge/br_sysfs_if.c @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Sysfs attributes of bridge ports + * Linux ethernet bridge + * + * Authors: + * Stephen Hemminger + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "br_private.h" + +/* IMPORTANT: new bridge port options must be added with netlink support only + * please do not add new sysfs entries + */ + +struct brport_attribute { + struct attribute attr; + ssize_t (*show)(struct net_bridge_port *, char *); + int (*store)(struct net_bridge_port *, unsigned long); + int (*store_raw)(struct net_bridge_port *, char *); +}; + +#define BRPORT_ATTR_RAW(_name, _mode, _show, _store) \ +const struct brport_attribute brport_attr_##_name = { \ + .attr = {.name = __stringify(_name), \ + .mode = _mode }, \ + .show = _show, \ + .store_raw = _store, \ +}; + +#define BRPORT_ATTR(_name, _mode, _show, _store) \ +const struct brport_attribute brport_attr_##_name = { \ + .attr = {.name = __stringify(_name), \ + .mode = _mode }, \ + .show = _show, \ + .store = _store, \ +}; + +#define BRPORT_ATTR_FLAG(_name, _mask) \ +static ssize_t show_##_name(struct net_bridge_port *p, char *buf) \ +{ \ + return sprintf(buf, "%d\n", !!(p->flags & _mask)); \ +} \ +static int store_##_name(struct net_bridge_port *p, unsigned long v) \ +{ \ + return store_flag(p, v, _mask); \ +} \ +static BRPORT_ATTR(_name, 0644, \ + show_##_name, store_##_name) + +static int store_flag(struct net_bridge_port *p, unsigned long v, + unsigned long mask) +{ + struct netlink_ext_ack extack = {0}; + unsigned long flags = p->flags; + int err; + + if (v) + flags |= mask; + else + flags &= ~mask; + + if (flags != p->flags) { + err = br_switchdev_set_port_flag(p, flags, mask, &extack); + if (err) { + netdev_err(p->dev, "%s\n", extack._msg); + return err; + } + + p->flags = flags; + br_port_flags_change(p, mask); + } + return 0; +} + +static ssize_t show_path_cost(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->path_cost); +} + +static BRPORT_ATTR(path_cost, 0644, + show_path_cost, br_stp_set_path_cost); + +static ssize_t show_priority(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->priority); +} + +static 
BRPORT_ATTR(priority, 0644, + show_priority, br_stp_set_port_priority); + +static ssize_t show_designated_root(struct net_bridge_port *p, char *buf) +{ + return br_show_bridge_id(buf, &p->designated_root); +} +static BRPORT_ATTR(designated_root, 0444, show_designated_root, NULL); + +static ssize_t show_designated_bridge(struct net_bridge_port *p, char *buf) +{ + return br_show_bridge_id(buf, &p->designated_bridge); +} +static BRPORT_ATTR(designated_bridge, 0444, show_designated_bridge, NULL); + +static ssize_t show_designated_port(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->designated_port); +} +static BRPORT_ATTR(designated_port, 0444, show_designated_port, NULL); + +static ssize_t show_designated_cost(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->designated_cost); +} +static BRPORT_ATTR(designated_cost, 0444, show_designated_cost, NULL); + +static ssize_t show_port_id(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "0x%x\n", p->port_id); +} +static BRPORT_ATTR(port_id, 0444, show_port_id, NULL); + +static ssize_t show_port_no(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "0x%x\n", p->port_no); +} + +static BRPORT_ATTR(port_no, 0444, show_port_no, NULL); + +static ssize_t show_change_ack(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->topology_change_ack); +} +static BRPORT_ATTR(change_ack, 0444, show_change_ack, NULL); + +static ssize_t show_config_pending(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->config_pending); +} +static BRPORT_ATTR(config_pending, 0444, show_config_pending, NULL); + +static ssize_t show_port_state(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->state); +} +static BRPORT_ATTR(state, 0444, show_port_state, NULL); + +static ssize_t show_message_age_timer(struct net_bridge_port *p, + char *buf) +{ + return sprintf(buf, "%ld\n", br_timer_value(&p->message_age_timer)); +} +static BRPORT_ATTR(message_age_timer, 0444, show_message_age_timer, NULL); + +static ssize_t show_forward_delay_timer(struct net_bridge_port *p, + char *buf) +{ + return sprintf(buf, "%ld\n", br_timer_value(&p->forward_delay_timer)); +} +static BRPORT_ATTR(forward_delay_timer, 0444, show_forward_delay_timer, NULL); + +static ssize_t show_hold_timer(struct net_bridge_port *p, + char *buf) +{ + return sprintf(buf, "%ld\n", br_timer_value(&p->hold_timer)); +} +static BRPORT_ATTR(hold_timer, 0444, show_hold_timer, NULL); + +static int store_flush(struct net_bridge_port *p, unsigned long v) +{ + br_fdb_delete_by_port(p->br, p, 0, 0); // Don't delete local entry + return 0; +} +static BRPORT_ATTR(flush, 0200, NULL, store_flush); + +static ssize_t show_group_fwd_mask(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%#x\n", p->group_fwd_mask); +} + +static int store_group_fwd_mask(struct net_bridge_port *p, + unsigned long v) +{ + if (v & BR_GROUPFWD_MACPAUSE) + return -EINVAL; + p->group_fwd_mask = v; + + return 0; +} +static BRPORT_ATTR(group_fwd_mask, 0644, show_group_fwd_mask, + store_group_fwd_mask); + +static ssize_t show_backup_port(struct net_bridge_port *p, char *buf) +{ + struct net_bridge_port *backup_p; + int ret = 0; + + rcu_read_lock(); + backup_p = rcu_dereference(p->backup_port); + if (backup_p) + ret = sprintf(buf, "%s\n", backup_p->dev->name); + rcu_read_unlock(); + + return ret; +} + +static int store_backup_port(struct net_bridge_port *p, char *buf) +{ + struct net_device *backup_dev = NULL; + char *nl = 
strchr(buf, '\n'); + + if (nl) + *nl = '\0'; + + if (strlen(buf) > 0) { + backup_dev = __dev_get_by_name(dev_net(p->dev), buf); + if (!backup_dev) + return -ENOENT; + } + + return nbp_backup_change(p, backup_dev); +} +static BRPORT_ATTR_RAW(backup_port, 0644, show_backup_port, store_backup_port); + +BRPORT_ATTR_FLAG(hairpin_mode, BR_HAIRPIN_MODE); +BRPORT_ATTR_FLAG(bpdu_guard, BR_BPDU_GUARD); +BRPORT_ATTR_FLAG(root_block, BR_ROOT_BLOCK); +BRPORT_ATTR_FLAG(learning, BR_LEARNING); +BRPORT_ATTR_FLAG(unicast_flood, BR_FLOOD); +BRPORT_ATTR_FLAG(proxyarp, BR_PROXYARP); +BRPORT_ATTR_FLAG(proxyarp_wifi, BR_PROXYARP_WIFI); +BRPORT_ATTR_FLAG(multicast_flood, BR_MCAST_FLOOD); +BRPORT_ATTR_FLAG(broadcast_flood, BR_BCAST_FLOOD); +BRPORT_ATTR_FLAG(neigh_suppress, BR_NEIGH_SUPPRESS); +BRPORT_ATTR_FLAG(isolated, BR_ISOLATED); + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING +static ssize_t show_multicast_router(struct net_bridge_port *p, char *buf) +{ + return sprintf(buf, "%d\n", p->multicast_ctx.multicast_router); +} + +static int store_multicast_router(struct net_bridge_port *p, + unsigned long v) +{ + return br_multicast_set_port_router(&p->multicast_ctx, v); +} +static BRPORT_ATTR(multicast_router, 0644, show_multicast_router, + store_multicast_router); + +BRPORT_ATTR_FLAG(multicast_fast_leave, BR_MULTICAST_FAST_LEAVE); +BRPORT_ATTR_FLAG(multicast_to_unicast, BR_MULTICAST_TO_UNICAST); +#endif + +static const struct brport_attribute *brport_attrs[] = { + &brport_attr_path_cost, + &brport_attr_priority, + &brport_attr_port_id, + &brport_attr_port_no, + &brport_attr_designated_root, + &brport_attr_designated_bridge, + &brport_attr_designated_port, + &brport_attr_designated_cost, + &brport_attr_state, + &brport_attr_change_ack, + &brport_attr_config_pending, + &brport_attr_message_age_timer, + &brport_attr_forward_delay_timer, + &brport_attr_hold_timer, + &brport_attr_flush, + &brport_attr_hairpin_mode, + &brport_attr_bpdu_guard, + &brport_attr_root_block, + &brport_attr_learning, + &brport_attr_unicast_flood, +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + &brport_attr_multicast_router, + &brport_attr_multicast_fast_leave, + &brport_attr_multicast_to_unicast, +#endif + &brport_attr_proxyarp, + &brport_attr_proxyarp_wifi, + &brport_attr_multicast_flood, + &brport_attr_broadcast_flood, + &brport_attr_group_fwd_mask, + &brport_attr_neigh_suppress, + &brport_attr_isolated, + &brport_attr_backup_port, + NULL +}; + +#define to_brport_attr(_at) container_of(_at, struct brport_attribute, attr) + +static ssize_t brport_show(struct kobject *kobj, + struct attribute *attr, char *buf) +{ + struct brport_attribute *brport_attr = to_brport_attr(attr); + struct net_bridge_port *p = kobj_to_brport(kobj); + + if (!brport_attr->show) + return -EINVAL; + + return brport_attr->show(p, buf); +} + +static ssize_t brport_store(struct kobject *kobj, + struct attribute *attr, + const char *buf, size_t count) +{ + struct brport_attribute *brport_attr = to_brport_attr(attr); + struct net_bridge_port *p = kobj_to_brport(kobj); + ssize_t ret = -EINVAL; + unsigned long val; + char *endp; + + if (!ns_capable(dev_net(p->dev)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (brport_attr->store_raw) { + char *buf_copy; + + buf_copy = kstrndup(buf, count, GFP_KERNEL); + if (!buf_copy) { + ret = -ENOMEM; + goto out_unlock; + } + spin_lock_bh(&p->br->lock); + ret = brport_attr->store_raw(p, buf_copy); + spin_unlock_bh(&p->br->lock); + kfree(buf_copy); + } else if (brport_attr->store) { + val = 
simple_strtoul(buf, &endp, 0); + if (endp == buf) + goto out_unlock; + spin_lock_bh(&p->br->lock); + ret = brport_attr->store(p, val); + spin_unlock_bh(&p->br->lock); + } + + if (!ret) { + br_ifinfo_notify(RTM_NEWLINK, NULL, p); + ret = count; + } +out_unlock: + rtnl_unlock(); + + return ret; +} + +const struct sysfs_ops brport_sysfs_ops = { + .show = brport_show, + .store = brport_store, +}; + +/* + * Add sysfs entries to ethernet device added to a bridge. + * Creates a brport subdirectory with bridge attributes. + * Puts symlink in bridge's brif subdirectory + */ +int br_sysfs_addif(struct net_bridge_port *p) +{ + struct net_bridge *br = p->br; + const struct brport_attribute **a; + int err; + + err = sysfs_create_link(&p->kobj, &br->dev->dev.kobj, + SYSFS_BRIDGE_PORT_LINK); + if (err) + return err; + + for (a = brport_attrs; *a; ++a) { + err = sysfs_create_file(&p->kobj, &((*a)->attr)); + if (err) + return err; + } + + strscpy(p->sysfs_name, p->dev->name, IFNAMSIZ); + return sysfs_create_link(br->ifobj, &p->kobj, p->sysfs_name); +} + +/* Rename bridge's brif symlink */ +int br_sysfs_renameif(struct net_bridge_port *p) +{ + struct net_bridge *br = p->br; + int err; + + /* If a rename fails, the rollback will cause another + * rename call with the existing name. + */ + if (!strncmp(p->sysfs_name, p->dev->name, IFNAMSIZ)) + return 0; + + err = sysfs_rename_link(br->ifobj, &p->kobj, + p->sysfs_name, p->dev->name); + if (err) + netdev_notice(br->dev, "unable to rename link %s to %s", + p->sysfs_name, p->dev->name); + else + strscpy(p->sysfs_name, p->dev->name, IFNAMSIZ); + + return err; +} diff --git a/net/bridge/br_vlan.c b/net/bridge/br_vlan.c new file mode 100644 index 000000000..9ffd40b82 --- /dev/null +++ b/net/bridge/br_vlan.c @@ -0,0 +1,2310 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_tunnel.h" + +static void nbp_vlan_set_vlan_dev_state(struct net_bridge_port *p, u16 vid); + +static inline int br_vlan_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct net_bridge_vlan *vle = ptr; + u16 vid = *(u16 *)arg->key; + + return vle->vid != vid; +} + +static const struct rhashtable_params br_vlan_rht_params = { + .head_offset = offsetof(struct net_bridge_vlan, vnode), + .key_offset = offsetof(struct net_bridge_vlan, vid), + .key_len = sizeof(u16), + .nelem_hint = 3, + .max_size = VLAN_N_VID, + .obj_cmpfn = br_vlan_cmp, + .automatic_shrinking = true, +}; + +static struct net_bridge_vlan *br_vlan_lookup(struct rhashtable *tbl, u16 vid) +{ + return rhashtable_lookup_fast(tbl, &vid, br_vlan_rht_params); +} + +static void __vlan_add_pvid(struct net_bridge_vlan_group *vg, + const struct net_bridge_vlan *v) +{ + if (vg->pvid == v->vid) + return; + + smp_wmb(); + br_vlan_set_pvid_state(vg, v->state); + vg->pvid = v->vid; +} + +static void __vlan_delete_pvid(struct net_bridge_vlan_group *vg, u16 vid) +{ + if (vg->pvid != vid) + return; + + smp_wmb(); + vg->pvid = 0; +} + +/* Update the BRIDGE_VLAN_INFO_PVID and BRIDGE_VLAN_INFO_UNTAGGED flags of @v. + * If @commit is false, return just whether the BRIDGE_VLAN_INFO_PVID and + * BRIDGE_VLAN_INFO_UNTAGGED bits of @flags would produce any change onto @v. 
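+ * Only the PVID and UNTAGGED bits are considered. For example, if @v is
+ * currently the pvid of its group and @flags does not carry
+ * BRIDGE_VLAN_INFO_PVID, a commit would clear the group's pvid, so a
+ * change is reported even though that bit lives in the vlan group rather
+ * than in @v->flags.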
+ */ +static bool __vlan_flags_update(struct net_bridge_vlan *v, u16 flags, + bool commit) +{ + struct net_bridge_vlan_group *vg; + bool change; + + if (br_vlan_is_master(v)) + vg = br_vlan_group(v->br); + else + vg = nbp_vlan_group(v->port); + + /* check if anything would be changed on commit */ + change = !!(flags & BRIDGE_VLAN_INFO_PVID) == !!(vg->pvid != v->vid) || + ((flags ^ v->flags) & BRIDGE_VLAN_INFO_UNTAGGED); + + if (!commit) + goto out; + + if (flags & BRIDGE_VLAN_INFO_PVID) + __vlan_add_pvid(vg, v); + else + __vlan_delete_pvid(vg, v->vid); + + if (flags & BRIDGE_VLAN_INFO_UNTAGGED) + v->flags |= BRIDGE_VLAN_INFO_UNTAGGED; + else + v->flags &= ~BRIDGE_VLAN_INFO_UNTAGGED; + +out: + return change; +} + +static bool __vlan_flags_would_change(struct net_bridge_vlan *v, u16 flags) +{ + return __vlan_flags_update(v, flags, false); +} + +static void __vlan_flags_commit(struct net_bridge_vlan *v, u16 flags) +{ + __vlan_flags_update(v, flags, true); +} + +static int __vlan_vid_add(struct net_device *dev, struct net_bridge *br, + struct net_bridge_vlan *v, u16 flags, + struct netlink_ext_ack *extack) +{ + int err; + + /* Try switchdev op first. In case it is not supported, fallback to + * 8021q add. + */ + err = br_switchdev_port_vlan_add(dev, v->vid, flags, false, extack); + if (err == -EOPNOTSUPP) + return vlan_vid_add(dev, br->vlan_proto, v->vid); + v->priv_flags |= BR_VLFLAG_ADDED_BY_SWITCHDEV; + return err; +} + +static void __vlan_add_list(struct net_bridge_vlan *v) +{ + struct net_bridge_vlan_group *vg; + struct list_head *headp, *hpos; + struct net_bridge_vlan *vent; + + if (br_vlan_is_master(v)) + vg = br_vlan_group(v->br); + else + vg = nbp_vlan_group(v->port); + + headp = &vg->vlan_list; + list_for_each_prev(hpos, headp) { + vent = list_entry(hpos, struct net_bridge_vlan, vlist); + if (v->vid >= vent->vid) + break; + } + list_add_rcu(&v->vlist, hpos); +} + +static void __vlan_del_list(struct net_bridge_vlan *v) +{ + list_del_rcu(&v->vlist); +} + +static int __vlan_vid_del(struct net_device *dev, struct net_bridge *br, + const struct net_bridge_vlan *v) +{ + int err; + + /* Try switchdev op first. In case it is not supported, fallback to + * 8021q del. + */ + err = br_switchdev_port_vlan_del(dev, v->vid); + if (!(v->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV)) + vlan_vid_del(dev, br->vlan_proto, v->vid); + return err == -EOPNOTSUPP ? 0 : err; +} + +/* Returns a master vlan, if it didn't exist it gets created. In all cases + * a reference is taken to the master vlan before returning. 
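+ *
+ * When the master has to be created here, it is added via br_vlan_add()
+ * with no flags, i.e. as a context-only entry without
+ * BRIDGE_VLAN_INFO_BRENTRY (case 3 of the __vlan_add() comment below), and
+ * its refcount starts at 1 on behalf of the caller; br_vlan_put_master()
+ * is the matching release.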
+ */ +static struct net_bridge_vlan * +br_vlan_get_master(struct net_bridge *br, u16 vid, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *masterv; + + vg = br_vlan_group(br); + masterv = br_vlan_find(vg, vid); + if (!masterv) { + bool changed; + + /* missing global ctx, create it now */ + if (br_vlan_add(br, vid, 0, &changed, extack)) + return NULL; + masterv = br_vlan_find(vg, vid); + if (WARN_ON(!masterv)) + return NULL; + refcount_set(&masterv->refcnt, 1); + return masterv; + } + refcount_inc(&masterv->refcnt); + + return masterv; +} + +static void br_master_vlan_rcu_free(struct rcu_head *rcu) +{ + struct net_bridge_vlan *v; + + v = container_of(rcu, struct net_bridge_vlan, rcu); + WARN_ON(!br_vlan_is_master(v)); + free_percpu(v->stats); + v->stats = NULL; + kfree(v); +} + +static void br_vlan_put_master(struct net_bridge_vlan *masterv) +{ + struct net_bridge_vlan_group *vg; + + if (!br_vlan_is_master(masterv)) + return; + + vg = br_vlan_group(masterv->br); + if (refcount_dec_and_test(&masterv->refcnt)) { + rhashtable_remove_fast(&vg->vlan_hash, + &masterv->vnode, br_vlan_rht_params); + __vlan_del_list(masterv); + br_multicast_toggle_one_vlan(masterv, false); + br_multicast_ctx_deinit(&masterv->br_mcast_ctx); + call_rcu(&masterv->rcu, br_master_vlan_rcu_free); + } +} + +static void nbp_vlan_rcu_free(struct rcu_head *rcu) +{ + struct net_bridge_vlan *v; + + v = container_of(rcu, struct net_bridge_vlan, rcu); + WARN_ON(br_vlan_is_master(v)); + /* if we had per-port stats configured then free them here */ + if (v->priv_flags & BR_VLFLAG_PER_PORT_STATS) + free_percpu(v->stats); + v->stats = NULL; + kfree(v); +} + +static void br_vlan_init_state(struct net_bridge_vlan *v) +{ + struct net_bridge *br; + + if (br_vlan_is_master(v)) + br = v->br; + else + br = v->port->br; + + if (br_opt_get(br, BROPT_MST_ENABLED)) { + br_mst_vlan_init_state(v); + return; + } + + v->state = BR_STATE_FORWARDING; + v->msti = 0; +} + +/* This is the shared VLAN add function which works for both ports and bridge + * devices. There are four possible calls to this function in terms of the + * vlan entry type: + * 1. vlan is being added on a port (no master flags, global entry exists) + * 2. vlan is being added on a bridge (both master and brentry flags) + * 3. vlan is being added on a port, but a global entry didn't exist which + * is being created right now (master flag set, brentry flag unset), the + * global entry is used for global per-vlan features, but not for filtering + * 4. same as 3 but with both master and brentry flags set so the entry + * will be used for filtering in both the port and the bridge + */ +static int __vlan_add(struct net_bridge_vlan *v, u16 flags, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan *masterv = NULL; + struct net_bridge_port *p = NULL; + struct net_bridge_vlan_group *vg; + struct net_device *dev; + struct net_bridge *br; + int err; + + if (br_vlan_is_master(v)) { + br = v->br; + dev = br->dev; + vg = br_vlan_group(br); + } else { + p = v->port; + br = p->br; + dev = p->dev; + vg = nbp_vlan_group(p); + } + + if (p) { + /* Add VLAN to the device filter if it is supported. + * This ensures tagged traffic enters the bridge when + * promiscuous mode is disabled by br_manage_promisc(). 
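+ *
+ * The rest of the port path below then makes sure the bridge-global
+ * (master) entry exists and takes a reference on it, and allocates
+ * per-port counters only when BROPT_VLAN_STATS_PER_PORT is set, sharing
+ * the master's counters otherwise.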
+ */ + err = __vlan_vid_add(dev, br, v, flags, extack); + if (err) + goto out; + + /* need to work on the master vlan too */ + if (flags & BRIDGE_VLAN_INFO_MASTER) { + bool changed; + + err = br_vlan_add(br, v->vid, + flags | BRIDGE_VLAN_INFO_BRENTRY, + &changed, extack); + if (err) + goto out_filt; + + if (changed) + br_vlan_notify(br, NULL, v->vid, 0, + RTM_NEWVLAN); + } + + masterv = br_vlan_get_master(br, v->vid, extack); + if (!masterv) { + err = -ENOMEM; + goto out_filt; + } + v->brvlan = masterv; + if (br_opt_get(br, BROPT_VLAN_STATS_PER_PORT)) { + v->stats = + netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!v->stats) { + err = -ENOMEM; + goto out_filt; + } + v->priv_flags |= BR_VLFLAG_PER_PORT_STATS; + } else { + v->stats = masterv->stats; + } + br_multicast_port_ctx_init(p, v, &v->port_mcast_ctx); + } else { + if (br_vlan_should_use(v)) { + err = br_switchdev_port_vlan_add(dev, v->vid, flags, + false, extack); + if (err && err != -EOPNOTSUPP) + goto out; + } + br_multicast_ctx_init(br, v, &v->br_mcast_ctx); + v->priv_flags |= BR_VLFLAG_GLOBAL_MCAST_ENABLED; + } + + /* Add the dev mac and count the vlan only if it's usable */ + if (br_vlan_should_use(v)) { + err = br_fdb_add_local(br, p, dev->dev_addr, v->vid); + if (err) { + br_err(br, "failed insert local address into bridge forwarding table\n"); + goto out_filt; + } + vg->num_vlans++; + } + + /* set the state before publishing */ + br_vlan_init_state(v); + + err = rhashtable_lookup_insert_fast(&vg->vlan_hash, &v->vnode, + br_vlan_rht_params); + if (err) + goto out_fdb_insert; + + __vlan_add_list(v); + __vlan_flags_commit(v, flags); + br_multicast_toggle_one_vlan(v, true); + + if (p) + nbp_vlan_set_vlan_dev_state(p, v->vid); +out: + return err; + +out_fdb_insert: + if (br_vlan_should_use(v)) { + br_fdb_find_delete_local(br, p, dev->dev_addr, v->vid); + vg->num_vlans--; + } + +out_filt: + if (p) { + __vlan_vid_del(dev, br, v); + if (masterv) { + if (v->stats && masterv->stats != v->stats) + free_percpu(v->stats); + v->stats = NULL; + + br_vlan_put_master(masterv); + v->brvlan = NULL; + } + } else { + br_switchdev_port_vlan_del(dev, v->vid); + } + + goto out; +} + +static int __vlan_del(struct net_bridge_vlan *v) +{ + struct net_bridge_vlan *masterv = v; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + int err = 0; + + if (br_vlan_is_master(v)) { + vg = br_vlan_group(v->br); + } else { + p = v->port; + vg = nbp_vlan_group(v->port); + masterv = v->brvlan; + } + + __vlan_delete_pvid(vg, v->vid); + if (p) { + err = __vlan_vid_del(p->dev, p->br, v); + if (err) + goto out; + } else { + err = br_switchdev_port_vlan_del(v->br->dev, v->vid); + if (err && err != -EOPNOTSUPP) + goto out; + err = 0; + } + + if (br_vlan_should_use(v)) { + v->flags &= ~BRIDGE_VLAN_INFO_BRENTRY; + vg->num_vlans--; + } + + if (masterv != v) { + vlan_tunnel_info_del(vg, v); + rhashtable_remove_fast(&vg->vlan_hash, &v->vnode, + br_vlan_rht_params); + __vlan_del_list(v); + nbp_vlan_set_vlan_dev_state(p, v->vid); + br_multicast_toggle_one_vlan(v, false); + br_multicast_port_ctx_deinit(&v->port_mcast_ctx); + call_rcu(&v->rcu, nbp_vlan_rcu_free); + } + + br_vlan_put_master(masterv); +out: + return err; +} + +static void __vlan_group_free(struct net_bridge_vlan_group *vg) +{ + WARN_ON(!list_empty(&vg->vlan_list)); + rhashtable_destroy(&vg->vlan_hash); + vlan_tunnel_deinit(vg); + kfree(vg); +} + +static void __vlan_flush(const struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_vlan_group *vg) +{ + struct 
net_bridge_vlan *vlan, *tmp; + u16 v_start = 0, v_end = 0; + int err; + + __vlan_delete_pvid(vg, vg->pvid); + list_for_each_entry_safe(vlan, tmp, &vg->vlan_list, vlist) { + /* take care of disjoint ranges */ + if (!v_start) { + v_start = vlan->vid; + } else if (vlan->vid - v_end != 1) { + /* found range end, notify and start next one */ + br_vlan_notify(br, p, v_start, v_end, RTM_DELVLAN); + v_start = vlan->vid; + } + v_end = vlan->vid; + + err = __vlan_del(vlan); + if (err) { + br_err(br, + "port %u(%s) failed to delete vlan %d: %pe\n", + (unsigned int) p->port_no, p->dev->name, + vlan->vid, ERR_PTR(err)); + } + } + + /* notify about the last/whole vlan range */ + if (v_start) + br_vlan_notify(br, p, v_start, v_end, RTM_DELVLAN); +} + +struct sk_buff *br_handle_vlan(struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_vlan_group *vg, + struct sk_buff *skb) +{ + struct pcpu_sw_netstats *stats; + struct net_bridge_vlan *v; + u16 vid; + + /* If this packet was not filtered at input, let it pass */ + if (!BR_INPUT_SKB_CB(skb)->vlan_filtered) + goto out; + + /* At this point, we know that the frame was filtered and contains + * a valid vlan id. If the vlan id has untagged flag set, + * send untagged; otherwise, send tagged. + */ + br_vlan_get_tag(skb, &vid); + v = br_vlan_find(vg, vid); + /* Vlan entry must be configured at this point. The + * only exception is the bridge is set in promisc mode and the + * packet is destined for the bridge device. In this case + * pass the packet as is. + */ + if (!v || !br_vlan_should_use(v)) { + if ((br->dev->flags & IFF_PROMISC) && skb->dev == br->dev) { + goto out; + } else { + kfree_skb(skb); + return NULL; + } + } + if (br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) { + stats = this_cpu_ptr(v->stats); + u64_stats_update_begin(&stats->syncp); + u64_stats_add(&stats->tx_bytes, skb->len); + u64_stats_inc(&stats->tx_packets); + u64_stats_update_end(&stats->syncp); + } + + /* If the skb will be sent using forwarding offload, the assumption is + * that the switchdev will inject the packet into hardware together + * with the bridge VLAN, so that it can be forwarded according to that + * VLAN. The switchdev should deal with popping the VLAN header in + * hardware on each egress port as appropriate. So only strip the VLAN + * header if forwarding offload is not being used. + */ + if (v->flags & BRIDGE_VLAN_INFO_UNTAGGED && + !br_switchdev_frame_uses_tx_fwd_offload(skb)) + __vlan_hwaccel_clear_tag(skb); + + if (p && (p->flags & BR_VLAN_TUNNEL) && + br_handle_egress_vlan_tunnel(skb, v)) { + kfree_skb(skb); + return NULL; + } +out: + return skb; +} + +/* Called under RCU */ +static bool __allowed_ingress(const struct net_bridge *br, + struct net_bridge_vlan_group *vg, + struct sk_buff *skb, u16 *vid, + u8 *state, + struct net_bridge_vlan **vlan) +{ + struct pcpu_sw_netstats *stats; + struct net_bridge_vlan *v; + bool tagged; + + BR_INPUT_SKB_CB(skb)->vlan_filtered = true; + /* If vlan tx offload is disabled on bridge device and frame was + * sent from vlan device on the bridge device, it does not have + * HW accelerated vlan tag. 
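+ *
+ * skb_vlan_untag() below pulls that in-frame tag into skb->vlan_tci, so
+ * the rest of this function can use br_vlan_get_tag() no matter whether
+ * the tag arrived hardware-accelerated or inside the frame data.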
+ */ + if (unlikely(!skb_vlan_tag_present(skb) && + skb->protocol == br->vlan_proto)) { + skb = skb_vlan_untag(skb); + if (unlikely(!skb)) + return false; + } + + if (!br_vlan_get_tag(skb, vid)) { + /* Tagged frame */ + if (skb->vlan_proto != br->vlan_proto) { + /* Protocol-mismatch, empty out vlan_tci for new tag */ + skb_push(skb, ETH_HLEN); + skb = vlan_insert_tag_set_proto(skb, skb->vlan_proto, + skb_vlan_tag_get(skb)); + if (unlikely(!skb)) + return false; + + skb_pull(skb, ETH_HLEN); + skb_reset_mac_len(skb); + *vid = 0; + tagged = false; + } else { + tagged = true; + } + } else { + /* Untagged frame */ + tagged = false; + } + + if (!*vid) { + u16 pvid = br_get_pvid(vg); + + /* Frame had a tag with VID 0 or did not have a tag. + * See if pvid is set on this port. That tells us which + * vlan untagged or priority-tagged traffic belongs to. + */ + if (!pvid) + goto drop; + + /* PVID is set on this port. Any untagged or priority-tagged + * ingress frame is considered to belong to this vlan. + */ + *vid = pvid; + if (likely(!tagged)) + /* Untagged Frame. */ + __vlan_hwaccel_put_tag(skb, br->vlan_proto, pvid); + else + /* Priority-tagged Frame. + * At this point, we know that skb->vlan_tci VID + * field was 0. + * We update only VID field and preserve PCP field. + */ + skb->vlan_tci |= pvid; + + /* if snooping and stats are disabled we can avoid the lookup */ + if (!br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED) && + !br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) { + if (*state == BR_STATE_FORWARDING) { + *state = br_vlan_get_pvid_state(vg); + if (!br_vlan_state_allowed(*state, true)) + goto drop; + } + return true; + } + } + v = br_vlan_find(vg, *vid); + if (!v || !br_vlan_should_use(v)) + goto drop; + + if (*state == BR_STATE_FORWARDING) { + *state = br_vlan_get_state(v); + if (!br_vlan_state_allowed(*state, true)) + goto drop; + } + + if (br_opt_get(br, BROPT_VLAN_STATS_ENABLED)) { + stats = this_cpu_ptr(v->stats); + u64_stats_update_begin(&stats->syncp); + u64_stats_add(&stats->rx_bytes, skb->len); + u64_stats_inc(&stats->rx_packets); + u64_stats_update_end(&stats->syncp); + } + + *vlan = v; + + return true; + +drop: + kfree_skb(skb); + return false; +} + +bool br_allowed_ingress(const struct net_bridge *br, + struct net_bridge_vlan_group *vg, struct sk_buff *skb, + u16 *vid, u8 *state, + struct net_bridge_vlan **vlan) +{ + /* If VLAN filtering is disabled on the bridge, all packets are + * permitted. + */ + *vlan = NULL; + if (!br_opt_get(br, BROPT_VLAN_ENABLED)) { + BR_INPUT_SKB_CB(skb)->vlan_filtered = false; + return true; + } + + return __allowed_ingress(br, vg, skb, vid, state, vlan); +} + +/* Called under RCU. */ +bool br_allowed_egress(struct net_bridge_vlan_group *vg, + const struct sk_buff *skb) +{ + const struct net_bridge_vlan *v; + u16 vid; + + /* If this packet was not filtered at input, let it pass */ + if (!BR_INPUT_SKB_CB(skb)->vlan_filtered) + return true; + + br_vlan_get_tag(skb, &vid); + v = br_vlan_find(vg, vid); + if (v && br_vlan_should_use(v) && + br_vlan_state_allowed(br_vlan_get_state(v), false)) + return true; + + return false; +} + +/* Called under RCU */ +bool br_should_learn(struct net_bridge_port *p, struct sk_buff *skb, u16 *vid) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge *br = p->br; + struct net_bridge_vlan *v; + + /* If filtering was disabled at input, let it pass. 
*/ + if (!br_opt_get(br, BROPT_VLAN_ENABLED)) + return true; + + vg = nbp_vlan_group_rcu(p); + if (!vg || !vg->num_vlans) + return false; + + if (!br_vlan_get_tag(skb, vid) && skb->vlan_proto != br->vlan_proto) + *vid = 0; + + if (!*vid) { + *vid = br_get_pvid(vg); + if (!*vid || + !br_vlan_state_allowed(br_vlan_get_pvid_state(vg), true)) + return false; + + return true; + } + + v = br_vlan_find(vg, *vid); + if (v && br_vlan_state_allowed(br_vlan_get_state(v), true)) + return true; + + return false; +} + +static int br_vlan_add_existing(struct net_bridge *br, + struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *vlan, + u16 flags, bool *changed, + struct netlink_ext_ack *extack) +{ + bool would_change = __vlan_flags_would_change(vlan, flags); + bool becomes_brentry = false; + int err; + + if (!br_vlan_is_brentry(vlan)) { + /* Trying to change flags of non-existent bridge vlan */ + if (!(flags & BRIDGE_VLAN_INFO_BRENTRY)) + return -EINVAL; + + becomes_brentry = true; + } + + /* Master VLANs that aren't brentries weren't notified before, + * time to notify them now. + */ + if (becomes_brentry || would_change) { + err = br_switchdev_port_vlan_add(br->dev, vlan->vid, flags, + would_change, extack); + if (err && err != -EOPNOTSUPP) + return err; + } + + if (becomes_brentry) { + /* It was only kept for port vlans, now make it real */ + err = br_fdb_add_local(br, NULL, br->dev->dev_addr, vlan->vid); + if (err) { + br_err(br, "failed to insert local address into bridge forwarding table\n"); + goto err_fdb_insert; + } + + refcount_inc(&vlan->refcnt); + vlan->flags |= BRIDGE_VLAN_INFO_BRENTRY; + vg->num_vlans++; + *changed = true; + br_multicast_toggle_one_vlan(vlan, true); + } + + __vlan_flags_commit(vlan, flags); + if (would_change) + *changed = true; + + return 0; + +err_fdb_insert: + br_switchdev_port_vlan_del(br->dev, vlan->vid); + return err; +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. + * changed must be true only if the vlan was created or updated + */ +int br_vlan_add(struct net_bridge *br, u16 vid, u16 flags, bool *changed, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *vlan; + int ret; + + ASSERT_RTNL(); + + *changed = false; + vg = br_vlan_group(br); + vlan = br_vlan_find(vg, vid); + if (vlan) + return br_vlan_add_existing(br, vg, vlan, flags, changed, + extack); + + vlan = kzalloc(sizeof(*vlan), GFP_KERNEL); + if (!vlan) + return -ENOMEM; + + vlan->stats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!vlan->stats) { + kfree(vlan); + return -ENOMEM; + } + vlan->vid = vid; + vlan->flags = flags | BRIDGE_VLAN_INFO_MASTER; + vlan->flags &= ~BRIDGE_VLAN_INFO_PVID; + vlan->br = br; + if (flags & BRIDGE_VLAN_INFO_BRENTRY) + refcount_set(&vlan->refcnt, 1); + ret = __vlan_add(vlan, flags, extack); + if (ret) { + free_percpu(vlan->stats); + kfree(vlan); + } else { + *changed = true; + } + + return ret; +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. 
+ */ +int br_vlan_delete(struct net_bridge *br, u16 vid) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + + ASSERT_RTNL(); + + vg = br_vlan_group(br); + v = br_vlan_find(vg, vid); + if (!v || !br_vlan_is_brentry(v)) + return -ENOENT; + + br_fdb_find_delete_local(br, NULL, br->dev->dev_addr, vid); + br_fdb_delete_by_port(br, NULL, vid, 0); + + vlan_tunnel_info_del(vg, v); + + return __vlan_del(v); +} + +void br_vlan_flush(struct net_bridge *br) +{ + struct net_bridge_vlan_group *vg; + + ASSERT_RTNL(); + + vg = br_vlan_group(br); + __vlan_flush(br, NULL, vg); + RCU_INIT_POINTER(br->vlgrp, NULL); + synchronize_rcu(); + __vlan_group_free(vg); +} + +struct net_bridge_vlan *br_vlan_find(struct net_bridge_vlan_group *vg, u16 vid) +{ + if (!vg) + return NULL; + + return br_vlan_lookup(&vg->vlan_hash, vid); +} + +/* Must be protected by RTNL. */ +static void recalculate_group_addr(struct net_bridge *br) +{ + if (br_opt_get(br, BROPT_GROUP_ADDR_SET)) + return; + + spin_lock_bh(&br->lock); + if (!br_opt_get(br, BROPT_VLAN_ENABLED) || + br->vlan_proto == htons(ETH_P_8021Q)) { + /* Bridge Group Address */ + br->group_addr[5] = 0x00; + } else { /* vlan_enabled && ETH_P_8021AD */ + /* Provider Bridge Group Address */ + br->group_addr[5] = 0x08; + } + spin_unlock_bh(&br->lock); +} + +/* Must be protected by RTNL. */ +void br_recalculate_fwd_mask(struct net_bridge *br) +{ + if (!br_opt_get(br, BROPT_VLAN_ENABLED) || + br->vlan_proto == htons(ETH_P_8021Q)) + br->group_fwd_mask_required = BR_GROUPFWD_DEFAULT; + else /* vlan_enabled && ETH_P_8021AD */ + br->group_fwd_mask_required = BR_GROUPFWD_8021AD & + ~(1u << br->group_addr[5]); +} + +int br_vlan_filter_toggle(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .orig_dev = br->dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_VLAN_FILTERING, + .flags = SWITCHDEV_F_SKIP_EOPNOTSUPP, + .u.vlan_filtering = val, + }; + int err; + + if (br_opt_get(br, BROPT_VLAN_ENABLED) == !!val) + return 0; + + br_opt_toggle(br, BROPT_VLAN_ENABLED, !!val); + + err = switchdev_port_attr_set(br->dev, &attr, extack); + if (err && err != -EOPNOTSUPP) { + br_opt_toggle(br, BROPT_VLAN_ENABLED, !val); + return err; + } + + br_manage_promisc(br); + recalculate_group_addr(br); + br_recalculate_fwd_mask(br); + if (!val && br_opt_get(br, BROPT_MCAST_VLAN_SNOOPING_ENABLED)) { + br_info(br, "vlan filtering disabled, automatically disabling multicast vlan snooping\n"); + br_multicast_toggle_vlan_snooping(br, false, NULL); + } + + return 0; +} + +bool br_vlan_enabled(const struct net_device *dev) +{ + struct net_bridge *br = netdev_priv(dev); + + return br_opt_get(br, BROPT_VLAN_ENABLED); +} +EXPORT_SYMBOL_GPL(br_vlan_enabled); + +int br_vlan_get_proto(const struct net_device *dev, u16 *p_proto) +{ + struct net_bridge *br = netdev_priv(dev); + + *p_proto = ntohs(br->vlan_proto); + + return 0; +} +EXPORT_SYMBOL_GPL(br_vlan_get_proto); + +int __br_vlan_set_proto(struct net_bridge *br, __be16 proto, + struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .orig_dev = br->dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_VLAN_PROTOCOL, + .flags = SWITCHDEV_F_SKIP_EOPNOTSUPP, + .u.vlan_protocol = ntohs(proto), + }; + int err = 0; + struct net_bridge_port *p; + struct net_bridge_vlan *vlan; + struct net_bridge_vlan_group *vg; + __be16 oldproto = br->vlan_proto; + + if (br->vlan_proto == proto) + return 0; + + err = switchdev_port_attr_set(br->dev, &attr, extack); + if (err && err != -EOPNOTSUPP) + return err; + + /* Add 
VLANs for the new proto to the device filter. */ + list_for_each_entry(p, &br->port_list, list) { + vg = nbp_vlan_group(p); + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; + err = vlan_vid_add(p->dev, proto, vlan->vid); + if (err) + goto err_filt; + } + } + + br->vlan_proto = proto; + + recalculate_group_addr(br); + br_recalculate_fwd_mask(br); + + /* Delete VLANs for the old proto from the device filter. */ + list_for_each_entry(p, &br->port_list, list) { + vg = nbp_vlan_group(p); + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; + vlan_vid_del(p->dev, oldproto, vlan->vid); + } + } + + return 0; + +err_filt: + attr.u.vlan_protocol = ntohs(oldproto); + switchdev_port_attr_set(br->dev, &attr, NULL); + + list_for_each_entry_continue_reverse(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; + vlan_vid_del(p->dev, proto, vlan->vid); + } + + list_for_each_entry_continue_reverse(p, &br->port_list, list) { + vg = nbp_vlan_group(p); + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + if (vlan->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + continue; + vlan_vid_del(p->dev, proto, vlan->vid); + } + } + + return err; +} + +int br_vlan_set_proto(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + if (!eth_type_vlan(htons(val))) + return -EPROTONOSUPPORT; + + return __br_vlan_set_proto(br, htons(val), extack); +} + +int br_vlan_set_stats(struct net_bridge *br, unsigned long val) +{ + switch (val) { + case 0: + case 1: + br_opt_toggle(br, BROPT_VLAN_STATS_ENABLED, !!val); + break; + default: + return -EINVAL; + } + + return 0; +} + +int br_vlan_set_stats_per_port(struct net_bridge *br, unsigned long val) +{ + struct net_bridge_port *p; + + /* allow to change the option if there are no port vlans configured */ + list_for_each_entry(p, &br->port_list, list) { + struct net_bridge_vlan_group *vg = nbp_vlan_group(p); + + if (vg->num_vlans) + return -EBUSY; + } + + switch (val) { + case 0: + case 1: + br_opt_toggle(br, BROPT_VLAN_STATS_PER_PORT, !!val); + break; + default: + return -EINVAL; + } + + return 0; +} + +static bool vlan_default_pvid(struct net_bridge_vlan_group *vg, u16 vid) +{ + struct net_bridge_vlan *v; + + if (vid != vg->pvid) + return false; + + v = br_vlan_lookup(&vg->vlan_hash, vid); + if (v && br_vlan_should_use(v) && + (v->flags & BRIDGE_VLAN_INFO_UNTAGGED)) + return true; + + return false; +} + +static void br_vlan_disable_default_pvid(struct net_bridge *br) +{ + struct net_bridge_port *p; + u16 pvid = br->default_pvid; + + /* Disable default_pvid on all ports where it is still + * configured. 
+ */ + if (vlan_default_pvid(br_vlan_group(br), pvid)) { + if (!br_vlan_delete(br, pvid)) + br_vlan_notify(br, NULL, pvid, 0, RTM_DELVLAN); + } + + list_for_each_entry(p, &br->port_list, list) { + if (vlan_default_pvid(nbp_vlan_group(p), pvid) && + !nbp_vlan_delete(p, pvid)) + br_vlan_notify(br, p, pvid, 0, RTM_DELVLAN); + } + + br->default_pvid = 0; +} + +int __br_vlan_set_default_pvid(struct net_bridge *br, u16 pvid, + struct netlink_ext_ack *extack) +{ + const struct net_bridge_vlan *pvent; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p; + unsigned long *changed; + bool vlchange; + u16 old_pvid; + int err = 0; + + if (!pvid) { + br_vlan_disable_default_pvid(br); + return 0; + } + + changed = bitmap_zalloc(BR_MAX_PORTS, GFP_KERNEL); + if (!changed) + return -ENOMEM; + + old_pvid = br->default_pvid; + + /* Update default_pvid config only if we do not conflict with + * user configuration. + */ + vg = br_vlan_group(br); + pvent = br_vlan_find(vg, pvid); + if ((!old_pvid || vlan_default_pvid(vg, old_pvid)) && + (!pvent || !br_vlan_should_use(pvent))) { + err = br_vlan_add(br, pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED | + BRIDGE_VLAN_INFO_BRENTRY, + &vlchange, extack); + if (err) + goto out; + + if (br_vlan_delete(br, old_pvid)) + br_vlan_notify(br, NULL, old_pvid, 0, RTM_DELVLAN); + br_vlan_notify(br, NULL, pvid, 0, RTM_NEWVLAN); + __set_bit(0, changed); + } + + list_for_each_entry(p, &br->port_list, list) { + /* Update default_pvid config only if we do not conflict with + * user configuration. + */ + vg = nbp_vlan_group(p); + if ((old_pvid && + !vlan_default_pvid(vg, old_pvid)) || + br_vlan_find(vg, pvid)) + continue; + + err = nbp_vlan_add(p, pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED, + &vlchange, extack); + if (err) + goto err_port; + if (nbp_vlan_delete(p, old_pvid)) + br_vlan_notify(br, p, old_pvid, 0, RTM_DELVLAN); + br_vlan_notify(p->br, p, pvid, 0, RTM_NEWVLAN); + __set_bit(p->port_no, changed); + } + + br->default_pvid = pvid; + +out: + bitmap_free(changed); + return err; + +err_port: + list_for_each_entry_continue_reverse(p, &br->port_list, list) { + if (!test_bit(p->port_no, changed)) + continue; + + if (old_pvid) { + nbp_vlan_add(p, old_pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED, + &vlchange, NULL); + br_vlan_notify(p->br, p, old_pvid, 0, RTM_NEWVLAN); + } + nbp_vlan_delete(p, pvid); + br_vlan_notify(br, p, pvid, 0, RTM_DELVLAN); + } + + if (test_bit(0, changed)) { + if (old_pvid) { + br_vlan_add(br, old_pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED | + BRIDGE_VLAN_INFO_BRENTRY, + &vlchange, NULL); + br_vlan_notify(br, NULL, old_pvid, 0, RTM_NEWVLAN); + } + br_vlan_delete(br, pvid); + br_vlan_notify(br, NULL, pvid, 0, RTM_DELVLAN); + } + goto out; +} + +int br_vlan_set_default_pvid(struct net_bridge *br, unsigned long val, + struct netlink_ext_ack *extack) +{ + u16 pvid = val; + int err = 0; + + if (val >= VLAN_VID_MASK) + return -EINVAL; + + if (pvid == br->default_pvid) + goto out; + + /* Only allow default pvid change when filtering is disabled */ + if (br_opt_get(br, BROPT_VLAN_ENABLED)) { + pr_info_once("Please disable vlan filtering to change default_pvid\n"); + err = -EPERM; + goto out; + } + err = __br_vlan_set_default_pvid(br, pvid, extack); +out: + return err; +} + +int br_vlan_init(struct net_bridge *br) +{ + struct net_bridge_vlan_group *vg; + int ret = -ENOMEM; + + vg = kzalloc(sizeof(*vg), GFP_KERNEL); + if (!vg) + goto out; + ret = rhashtable_init(&vg->vlan_hash, 
&br_vlan_rht_params); + if (ret) + goto err_rhtbl; + ret = vlan_tunnel_init(vg); + if (ret) + goto err_tunnel_init; + INIT_LIST_HEAD(&vg->vlan_list); + br->vlan_proto = htons(ETH_P_8021Q); + br->default_pvid = 1; + rcu_assign_pointer(br->vlgrp, vg); + +out: + return ret; + +err_tunnel_init: + rhashtable_destroy(&vg->vlan_hash); +err_rhtbl: + kfree(vg); + + goto out; +} + +int nbp_vlan_init(struct net_bridge_port *p, struct netlink_ext_ack *extack) +{ + struct switchdev_attr attr = { + .orig_dev = p->br->dev, + .id = SWITCHDEV_ATTR_ID_BRIDGE_VLAN_FILTERING, + .flags = SWITCHDEV_F_SKIP_EOPNOTSUPP, + .u.vlan_filtering = br_opt_get(p->br, BROPT_VLAN_ENABLED), + }; + struct net_bridge_vlan_group *vg; + int ret = -ENOMEM; + + vg = kzalloc(sizeof(struct net_bridge_vlan_group), GFP_KERNEL); + if (!vg) + goto out; + + ret = switchdev_port_attr_set(p->dev, &attr, extack); + if (ret && ret != -EOPNOTSUPP) + goto err_vlan_enabled; + + ret = rhashtable_init(&vg->vlan_hash, &br_vlan_rht_params); + if (ret) + goto err_rhtbl; + ret = vlan_tunnel_init(vg); + if (ret) + goto err_tunnel_init; + INIT_LIST_HEAD(&vg->vlan_list); + rcu_assign_pointer(p->vlgrp, vg); + if (p->br->default_pvid) { + bool changed; + + ret = nbp_vlan_add(p, p->br->default_pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED, + &changed, extack); + if (ret) + goto err_vlan_add; + br_vlan_notify(p->br, p, p->br->default_pvid, 0, RTM_NEWVLAN); + } +out: + return ret; + +err_vlan_add: + RCU_INIT_POINTER(p->vlgrp, NULL); + synchronize_rcu(); + vlan_tunnel_deinit(vg); +err_tunnel_init: + rhashtable_destroy(&vg->vlan_hash); +err_rhtbl: +err_vlan_enabled: + kfree(vg); + + goto out; +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. + * changed must be true only if the vlan was created or updated + */ +int nbp_vlan_add(struct net_bridge_port *port, u16 vid, u16 flags, + bool *changed, struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan *vlan; + int ret; + + ASSERT_RTNL(); + + *changed = false; + vlan = br_vlan_find(nbp_vlan_group(port), vid); + if (vlan) { + bool would_change = __vlan_flags_would_change(vlan, flags); + + if (would_change) { + /* Pass the flags to the hardware bridge */ + ret = br_switchdev_port_vlan_add(port->dev, vid, flags, + true, extack); + if (ret && ret != -EOPNOTSUPP) + return ret; + } + + __vlan_flags_commit(vlan, flags); + *changed = would_change; + + return 0; + } + + vlan = kzalloc(sizeof(*vlan), GFP_KERNEL); + if (!vlan) + return -ENOMEM; + + vlan->vid = vid; + vlan->port = port; + ret = __vlan_add(vlan, flags, extack); + if (ret) + kfree(vlan); + else + *changed = true; + + return ret; +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. 
+ */ +int nbp_vlan_delete(struct net_bridge_port *port, u16 vid) +{ + struct net_bridge_vlan *v; + + ASSERT_RTNL(); + + v = br_vlan_find(nbp_vlan_group(port), vid); + if (!v) + return -ENOENT; + br_fdb_find_delete_local(port->br, port, port->dev->dev_addr, vid); + br_fdb_delete_by_port(port->br, port, vid, 0); + + return __vlan_del(v); +} + +void nbp_vlan_flush(struct net_bridge_port *port) +{ + struct net_bridge_vlan_group *vg; + + ASSERT_RTNL(); + + vg = nbp_vlan_group(port); + __vlan_flush(port->br, port, vg); + RCU_INIT_POINTER(port->vlgrp, NULL); + synchronize_rcu(); + __vlan_group_free(vg); +} + +void br_vlan_get_stats(const struct net_bridge_vlan *v, + struct pcpu_sw_netstats *stats) +{ + int i; + + memset(stats, 0, sizeof(*stats)); + for_each_possible_cpu(i) { + u64 rxpackets, rxbytes, txpackets, txbytes; + struct pcpu_sw_netstats *cpu_stats; + unsigned int start; + + cpu_stats = per_cpu_ptr(v->stats, i); + do { + start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + rxpackets = u64_stats_read(&cpu_stats->rx_packets); + rxbytes = u64_stats_read(&cpu_stats->rx_bytes); + txbytes = u64_stats_read(&cpu_stats->tx_bytes); + txpackets = u64_stats_read(&cpu_stats->tx_packets); + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + + u64_stats_add(&stats->rx_packets, rxpackets); + u64_stats_add(&stats->rx_bytes, rxbytes); + u64_stats_add(&stats->tx_bytes, txbytes); + u64_stats_add(&stats->tx_packets, txpackets); + } +} + +int br_vlan_get_pvid(const struct net_device *dev, u16 *p_pvid) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p; + + ASSERT_RTNL(); + p = br_port_get_check_rtnl(dev); + if (p) + vg = nbp_vlan_group(p); + else if (netif_is_bridge_master(dev)) + vg = br_vlan_group(netdev_priv(dev)); + else + return -EINVAL; + + *p_pvid = br_get_pvid(vg); + return 0; +} +EXPORT_SYMBOL_GPL(br_vlan_get_pvid); + +int br_vlan_get_pvid_rcu(const struct net_device *dev, u16 *p_pvid) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p; + + p = br_port_get_check_rcu(dev); + if (p) + vg = nbp_vlan_group_rcu(p); + else if (netif_is_bridge_master(dev)) + vg = br_vlan_group_rcu(netdev_priv(dev)); + else + return -EINVAL; + + *p_pvid = br_get_pvid(vg); + return 0; +} +EXPORT_SYMBOL_GPL(br_vlan_get_pvid_rcu); + +void br_vlan_fill_forward_path_pvid(struct net_bridge *br, + struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + struct net_bridge_vlan_group *vg; + int idx = ctx->num_vlans - 1; + u16 vid; + + path->bridge.vlan_mode = DEV_PATH_BR_VLAN_KEEP; + + if (!br_opt_get(br, BROPT_VLAN_ENABLED)) + return; + + vg = br_vlan_group(br); + + if (idx >= 0 && + ctx->vlan[idx].proto == br->vlan_proto) { + vid = ctx->vlan[idx].id; + } else { + path->bridge.vlan_mode = DEV_PATH_BR_VLAN_TAG; + vid = br_get_pvid(vg); + } + + path->bridge.vlan_id = vid; + path->bridge.vlan_proto = br->vlan_proto; +} + +int br_vlan_fill_forward_path_mode(struct net_bridge *br, + struct net_bridge_port *dst, + struct net_device_path *path) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + + if (!br_opt_get(br, BROPT_VLAN_ENABLED)) + return 0; + + vg = nbp_vlan_group_rcu(dst); + v = br_vlan_find(vg, path->bridge.vlan_id); + if (!v || !br_vlan_should_use(v)) + return -EINVAL; + + if (!(v->flags & BRIDGE_VLAN_INFO_UNTAGGED)) + return 0; + + if (path->bridge.vlan_mode == DEV_PATH_BR_VLAN_TAG) + path->bridge.vlan_mode = DEV_PATH_BR_VLAN_KEEP; + else if (v->priv_flags & BR_VLFLAG_ADDED_BY_SWITCHDEV) + path->bridge.vlan_mode = DEV_PATH_BR_VLAN_UNTAG_HW; + else + 
path->bridge.vlan_mode = DEV_PATH_BR_VLAN_UNTAG; + + return 0; +} + +int br_vlan_get_info(const struct net_device *dev, u16 vid, + struct bridge_vlan_info *p_vinfo) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + struct net_bridge_port *p; + + ASSERT_RTNL(); + p = br_port_get_check_rtnl(dev); + if (p) + vg = nbp_vlan_group(p); + else if (netif_is_bridge_master(dev)) + vg = br_vlan_group(netdev_priv(dev)); + else + return -EINVAL; + + v = br_vlan_find(vg, vid); + if (!v) + return -ENOENT; + + p_vinfo->vid = vid; + p_vinfo->flags = v->flags; + if (vid == br_get_pvid(vg)) + p_vinfo->flags |= BRIDGE_VLAN_INFO_PVID; + return 0; +} +EXPORT_SYMBOL_GPL(br_vlan_get_info); + +int br_vlan_get_info_rcu(const struct net_device *dev, u16 vid, + struct bridge_vlan_info *p_vinfo) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + struct net_bridge_port *p; + + p = br_port_get_check_rcu(dev); + if (p) + vg = nbp_vlan_group_rcu(p); + else if (netif_is_bridge_master(dev)) + vg = br_vlan_group_rcu(netdev_priv(dev)); + else + return -EINVAL; + + v = br_vlan_find(vg, vid); + if (!v) + return -ENOENT; + + p_vinfo->vid = vid; + p_vinfo->flags = v->flags; + if (vid == br_get_pvid(vg)) + p_vinfo->flags |= BRIDGE_VLAN_INFO_PVID; + return 0; +} +EXPORT_SYMBOL_GPL(br_vlan_get_info_rcu); + +static int br_vlan_is_bind_vlan_dev(const struct net_device *dev) +{ + return is_vlan_dev(dev) && + !!(vlan_dev_priv(dev)->flags & VLAN_FLAG_BRIDGE_BINDING); +} + +static int br_vlan_is_bind_vlan_dev_fn(struct net_device *dev, + __always_unused struct netdev_nested_priv *priv) +{ + return br_vlan_is_bind_vlan_dev(dev); +} + +static bool br_vlan_has_upper_bind_vlan_dev(struct net_device *dev) +{ + int found; + + rcu_read_lock(); + found = netdev_walk_all_upper_dev_rcu(dev, br_vlan_is_bind_vlan_dev_fn, + NULL); + rcu_read_unlock(); + + return !!found; +} + +struct br_vlan_bind_walk_data { + u16 vid; + struct net_device *result; +}; + +static int br_vlan_match_bind_vlan_dev_fn(struct net_device *dev, + struct netdev_nested_priv *priv) +{ + struct br_vlan_bind_walk_data *data = priv->data; + int found = 0; + + if (br_vlan_is_bind_vlan_dev(dev) && + vlan_dev_priv(dev)->vlan_id == data->vid) { + data->result = dev; + found = 1; + } + + return found; +} + +static struct net_device * +br_vlan_get_upper_bind_vlan_dev(struct net_device *dev, u16 vid) +{ + struct br_vlan_bind_walk_data data = { + .vid = vid, + }; + struct netdev_nested_priv priv = { + .data = (void *)&data, + }; + + rcu_read_lock(); + netdev_walk_all_upper_dev_rcu(dev, br_vlan_match_bind_vlan_dev_fn, + &priv); + rcu_read_unlock(); + + return data.result; +} + +static bool br_vlan_is_dev_up(const struct net_device *dev) +{ + return !!(dev->flags & IFF_UP) && netif_oper_up(dev); +} + +static void br_vlan_set_vlan_dev_state(const struct net_bridge *br, + struct net_device *vlan_dev) +{ + u16 vid = vlan_dev_priv(vlan_dev)->vlan_id; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p; + bool has_carrier = false; + + if (!netif_carrier_ok(br->dev)) { + netif_carrier_off(vlan_dev); + return; + } + + list_for_each_entry(p, &br->port_list, list) { + vg = nbp_vlan_group(p); + if (br_vlan_find(vg, vid) && br_vlan_is_dev_up(p->dev)) { + has_carrier = true; + break; + } + } + + if (has_carrier) + netif_carrier_on(vlan_dev); + else + netif_carrier_off(vlan_dev); +} + +static void br_vlan_set_all_vlan_dev_state(struct net_bridge_port *p) +{ + struct net_bridge_vlan_group *vg = nbp_vlan_group(p); + struct net_bridge_vlan *vlan; + struct 
net_device *vlan_dev; + + list_for_each_entry(vlan, &vg->vlan_list, vlist) { + vlan_dev = br_vlan_get_upper_bind_vlan_dev(p->br->dev, + vlan->vid); + if (vlan_dev) { + if (br_vlan_is_dev_up(p->dev)) { + if (netif_carrier_ok(p->br->dev)) + netif_carrier_on(vlan_dev); + } else { + br_vlan_set_vlan_dev_state(p->br, vlan_dev); + } + } + } +} + +static void br_vlan_upper_change(struct net_device *dev, + struct net_device *upper_dev, + bool linking) +{ + struct net_bridge *br = netdev_priv(dev); + + if (!br_vlan_is_bind_vlan_dev(upper_dev)) + return; + + if (linking) { + br_vlan_set_vlan_dev_state(br, upper_dev); + br_opt_toggle(br, BROPT_VLAN_BRIDGE_BINDING, true); + } else { + br_opt_toggle(br, BROPT_VLAN_BRIDGE_BINDING, + br_vlan_has_upper_bind_vlan_dev(dev)); + } +} + +struct br_vlan_link_state_walk_data { + struct net_bridge *br; +}; + +static int br_vlan_link_state_change_fn(struct net_device *vlan_dev, + struct netdev_nested_priv *priv) +{ + struct br_vlan_link_state_walk_data *data = priv->data; + + if (br_vlan_is_bind_vlan_dev(vlan_dev)) + br_vlan_set_vlan_dev_state(data->br, vlan_dev); + + return 0; +} + +static void br_vlan_link_state_change(struct net_device *dev, + struct net_bridge *br) +{ + struct br_vlan_link_state_walk_data data = { + .br = br + }; + struct netdev_nested_priv priv = { + .data = (void *)&data, + }; + + rcu_read_lock(); + netdev_walk_all_upper_dev_rcu(dev, br_vlan_link_state_change_fn, + &priv); + rcu_read_unlock(); +} + +/* Must be protected by RTNL. */ +static void nbp_vlan_set_vlan_dev_state(struct net_bridge_port *p, u16 vid) +{ + struct net_device *vlan_dev; + + if (!br_opt_get(p->br, BROPT_VLAN_BRIDGE_BINDING)) + return; + + vlan_dev = br_vlan_get_upper_bind_vlan_dev(p->br->dev, vid); + if (vlan_dev) + br_vlan_set_vlan_dev_state(p->br, vlan_dev); +} + +/* Must be protected by RTNL. */ +int br_vlan_bridge_event(struct net_device *dev, unsigned long event, void *ptr) +{ + struct netdev_notifier_changeupper_info *info; + struct net_bridge *br = netdev_priv(dev); + int vlcmd = 0, ret = 0; + bool changed = false; + + switch (event) { + case NETDEV_REGISTER: + ret = br_vlan_add(br, br->default_pvid, + BRIDGE_VLAN_INFO_PVID | + BRIDGE_VLAN_INFO_UNTAGGED | + BRIDGE_VLAN_INFO_BRENTRY, &changed, NULL); + vlcmd = RTM_NEWVLAN; + break; + case NETDEV_UNREGISTER: + changed = !br_vlan_delete(br, br->default_pvid); + vlcmd = RTM_DELVLAN; + break; + case NETDEV_CHANGEUPPER: + info = ptr; + br_vlan_upper_change(dev, info->upper_dev, info->linking); + break; + + case NETDEV_CHANGE: + case NETDEV_UP: + if (!br_opt_get(br, BROPT_VLAN_BRIDGE_BINDING)) + break; + br_vlan_link_state_change(dev, br); + break; + } + if (changed) + br_vlan_notify(br, NULL, br->default_pvid, 0, vlcmd); + + return ret; +} + +/* Must be protected by RTNL. 
*/ +void br_vlan_port_event(struct net_bridge_port *p, unsigned long event) +{ + if (!br_opt_get(p->br, BROPT_VLAN_BRIDGE_BINDING)) + return; + + switch (event) { + case NETDEV_CHANGE: + case NETDEV_DOWN: + case NETDEV_UP: + br_vlan_set_all_vlan_dev_state(p); + break; + } +} + +static bool br_vlan_stats_fill(struct sk_buff *skb, + const struct net_bridge_vlan *v) +{ + struct pcpu_sw_netstats stats; + struct nlattr *nest; + + nest = nla_nest_start(skb, BRIDGE_VLANDB_ENTRY_STATS); + if (!nest) + return false; + + br_vlan_get_stats(v, &stats); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_STATS_RX_BYTES, + u64_stats_read(&stats.rx_bytes), + BRIDGE_VLANDB_STATS_PAD) || + nla_put_u64_64bit(skb, BRIDGE_VLANDB_STATS_RX_PACKETS, + u64_stats_read(&stats.rx_packets), + BRIDGE_VLANDB_STATS_PAD) || + nla_put_u64_64bit(skb, BRIDGE_VLANDB_STATS_TX_BYTES, + u64_stats_read(&stats.tx_bytes), + BRIDGE_VLANDB_STATS_PAD) || + nla_put_u64_64bit(skb, BRIDGE_VLANDB_STATS_TX_PACKETS, + u64_stats_read(&stats.tx_packets), + BRIDGE_VLANDB_STATS_PAD)) + goto out_err; + + nla_nest_end(skb, nest); + + return true; + +out_err: + nla_nest_cancel(skb, nest); + return false; +} + +/* v_opts is used to dump the options which must be equal in the whole range */ +static bool br_vlan_fill_vids(struct sk_buff *skb, u16 vid, u16 vid_range, + const struct net_bridge_vlan *v_opts, + u16 flags, + bool dump_stats) +{ + struct bridge_vlan_info info; + struct nlattr *nest; + + nest = nla_nest_start(skb, BRIDGE_VLANDB_ENTRY); + if (!nest) + return false; + + memset(&info, 0, sizeof(info)); + info.vid = vid; + if (flags & BRIDGE_VLAN_INFO_UNTAGGED) + info.flags |= BRIDGE_VLAN_INFO_UNTAGGED; + if (flags & BRIDGE_VLAN_INFO_PVID) + info.flags |= BRIDGE_VLAN_INFO_PVID; + + if (nla_put(skb, BRIDGE_VLANDB_ENTRY_INFO, sizeof(info), &info)) + goto out_err; + + if (vid_range && vid < vid_range && + !(flags & BRIDGE_VLAN_INFO_PVID) && + nla_put_u16(skb, BRIDGE_VLANDB_ENTRY_RANGE, vid_range)) + goto out_err; + + if (v_opts) { + if (!br_vlan_opts_fill(skb, v_opts)) + goto out_err; + + if (dump_stats && !br_vlan_stats_fill(skb, v_opts)) + goto out_err; + } + + nla_nest_end(skb, nest); + + return true; + +out_err: + nla_nest_cancel(skb, nest); + return false; +} + +static size_t rtnl_vlan_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct br_vlan_msg)) + + nla_total_size(0) /* BRIDGE_VLANDB_ENTRY */ + + nla_total_size(sizeof(u16)) /* BRIDGE_VLANDB_ENTRY_RANGE */ + + nla_total_size(sizeof(struct bridge_vlan_info)) /* BRIDGE_VLANDB_ENTRY_INFO */ + + br_vlan_opts_nl_size(); /* bridge vlan options */ +} + +void br_vlan_notify(const struct net_bridge *br, + const struct net_bridge_port *p, + u16 vid, u16 vid_range, + int cmd) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v = NULL; + struct br_vlan_msg *bvm; + struct nlmsghdr *nlh; + struct sk_buff *skb; + int err = -ENOBUFS; + struct net *net; + u16 flags = 0; + int ifindex; + + /* right now notifications are done only with rtnl held */ + ASSERT_RTNL(); + + if (p) { + ifindex = p->dev->ifindex; + vg = nbp_vlan_group(p); + net = dev_net(p->dev); + } else { + ifindex = br->dev->ifindex; + vg = br_vlan_group(br); + net = dev_net(br->dev); + } + + skb = nlmsg_new(rtnl_vlan_nlmsg_size(), GFP_KERNEL); + if (!skb) + goto out_err; + + err = -EMSGSIZE; + nlh = nlmsg_put(skb, 0, 0, cmd, sizeof(*bvm), 0); + if (!nlh) + goto out_err; + bvm = nlmsg_data(nlh); + memset(bvm, 0, sizeof(*bvm)); + bvm->family = AF_BRIDGE; + bvm->ifindex = ifindex; + + switch (cmd) { + case RTM_NEWVLAN: + /* need to find 
the vlan due to flags/options */ + v = br_vlan_find(vg, vid); + if (!v || !br_vlan_should_use(v)) + goto out_kfree; + + flags = v->flags; + if (br_get_pvid(vg) == v->vid) + flags |= BRIDGE_VLAN_INFO_PVID; + break; + case RTM_DELVLAN: + break; + default: + goto out_kfree; + } + + if (!br_vlan_fill_vids(skb, vid, vid_range, v, flags, false)) + goto out_err; + + nlmsg_end(skb, nlh); + rtnl_notify(skb, net, 0, RTNLGRP_BRVLAN, NULL, GFP_KERNEL); + return; + +out_err: + rtnl_set_sk_err(net, RTNLGRP_BRVLAN, err); +out_kfree: + kfree_skb(skb); +} + +/* check if v_curr can enter a range ending in range_end */ +bool br_vlan_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end) +{ + return v_curr->vid - range_end->vid == 1 && + range_end->flags == v_curr->flags && + br_vlan_opts_eq_range(v_curr, range_end); +} + +static int br_vlan_dump_dev(const struct net_device *dev, + struct sk_buff *skb, + struct netlink_callback *cb, + u32 dump_flags) +{ + struct net_bridge_vlan *v, *range_start = NULL, *range_end = NULL; + bool dump_global = !!(dump_flags & BRIDGE_VLANDB_DUMPF_GLOBAL); + bool dump_stats = !!(dump_flags & BRIDGE_VLANDB_DUMPF_STATS); + struct net_bridge_vlan_group *vg; + int idx = 0, s_idx = cb->args[1]; + struct nlmsghdr *nlh = NULL; + struct net_bridge_port *p; + struct br_vlan_msg *bvm; + struct net_bridge *br; + int err = 0; + u16 pvid; + + if (!netif_is_bridge_master(dev) && !netif_is_bridge_port(dev)) + return -EINVAL; + + if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + vg = br_vlan_group_rcu(br); + p = NULL; + } else { + /* global options are dumped only for bridge devices */ + if (dump_global) + return 0; + + p = br_port_get_rcu(dev); + if (WARN_ON(!p)) + return -EINVAL; + vg = nbp_vlan_group_rcu(p); + br = p->br; + } + + if (!vg) + return 0; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + RTM_NEWVLAN, sizeof(*bvm), NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + bvm = nlmsg_data(nlh); + memset(bvm, 0, sizeof(*bvm)); + bvm->family = PF_BRIDGE; + bvm->ifindex = dev->ifindex; + pvid = br_get_pvid(vg); + + /* idx must stay at range's beginning until it is filled in */ + list_for_each_entry_rcu(v, &vg->vlan_list, vlist) { + if (!dump_global && !br_vlan_should_use(v)) + continue; + if (idx < s_idx) { + idx++; + continue; + } + + if (!range_start) { + range_start = v; + range_end = v; + continue; + } + + if (dump_global) { + if (br_vlan_global_opts_can_enter_range(v, range_end)) + goto update_end; + if (!br_vlan_global_opts_fill(skb, range_start->vid, + range_end->vid, + range_start)) { + err = -EMSGSIZE; + break; + } + /* advance number of filled vlans */ + idx += range_end->vid - range_start->vid + 1; + + range_start = v; + } else if (dump_stats || v->vid == pvid || + !br_vlan_can_enter_range(v, range_end)) { + u16 vlan_flags = br_vlan_flags(range_start, pvid); + + if (!br_vlan_fill_vids(skb, range_start->vid, + range_end->vid, range_start, + vlan_flags, dump_stats)) { + err = -EMSGSIZE; + break; + } + /* advance number of filled vlans */ + idx += range_end->vid - range_start->vid + 1; + + range_start = v; + } +update_end: + range_end = v; + } + + /* err will be 0 and range_start will be set in 3 cases here: + * - first vlan (range_start == range_end) + * - last vlan (range_start == range_end, not in range) + * - last vlan range (range_start != range_end, in range) + */ + if (!err && range_start) { + if (dump_global && + !br_vlan_global_opts_fill(skb, range_start->vid, + range_end->vid, range_start)) + err = 
-EMSGSIZE; + else if (!dump_global && + !br_vlan_fill_vids(skb, range_start->vid, + range_end->vid, range_start, + br_vlan_flags(range_start, pvid), + dump_stats)) + err = -EMSGSIZE; + } + + cb->args[1] = err ? idx : 0; + + nlmsg_end(skb, nlh); + + return err; +} + +static const struct nla_policy br_vlan_db_dump_pol[BRIDGE_VLANDB_DUMP_MAX + 1] = { + [BRIDGE_VLANDB_DUMP_FLAGS] = { .type = NLA_U32 }, +}; + +static int br_vlan_rtm_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nlattr *dtb[BRIDGE_VLANDB_DUMP_MAX + 1]; + int idx = 0, err = 0, s_idx = cb->args[0]; + struct net *net = sock_net(skb->sk); + struct br_vlan_msg *bvm; + struct net_device *dev; + u32 dump_flags = 0; + + err = nlmsg_parse(cb->nlh, sizeof(*bvm), dtb, BRIDGE_VLANDB_DUMP_MAX, + br_vlan_db_dump_pol, cb->extack); + if (err < 0) + return err; + + bvm = nlmsg_data(cb->nlh); + if (dtb[BRIDGE_VLANDB_DUMP_FLAGS]) + dump_flags = nla_get_u32(dtb[BRIDGE_VLANDB_DUMP_FLAGS]); + + rcu_read_lock(); + if (bvm->ifindex) { + dev = dev_get_by_index_rcu(net, bvm->ifindex); + if (!dev) { + err = -ENODEV; + goto out_err; + } + err = br_vlan_dump_dev(dev, skb, cb, dump_flags); + /* if the dump completed without an error we return 0 here */ + if (err != -EMSGSIZE) + goto out_err; + } else { + for_each_netdev_rcu(net, dev) { + if (idx < s_idx) + goto skip; + + err = br_vlan_dump_dev(dev, skb, cb, dump_flags); + if (err == -EMSGSIZE) + break; +skip: + idx++; + } + } + cb->args[0] = idx; + rcu_read_unlock(); + + return skb->len; + +out_err: + rcu_read_unlock(); + + return err; +} + +static const struct nla_policy br_vlan_db_policy[BRIDGE_VLANDB_ENTRY_MAX + 1] = { + [BRIDGE_VLANDB_ENTRY_INFO] = + NLA_POLICY_EXACT_LEN(sizeof(struct bridge_vlan_info)), + [BRIDGE_VLANDB_ENTRY_RANGE] = { .type = NLA_U16 }, + [BRIDGE_VLANDB_ENTRY_STATE] = { .type = NLA_U8 }, + [BRIDGE_VLANDB_ENTRY_TUNNEL_INFO] = { .type = NLA_NESTED }, + [BRIDGE_VLANDB_ENTRY_MCAST_ROUTER] = { .type = NLA_U8 }, +}; + +static int br_vlan_rtm_process_one(struct net_device *dev, + const struct nlattr *attr, + int cmd, struct netlink_ext_ack *extack) +{ + struct bridge_vlan_info *vinfo, vrange_end, *vinfo_last = NULL; + struct nlattr *tb[BRIDGE_VLANDB_ENTRY_MAX + 1]; + bool changed = false, skip_processing = false; + struct net_bridge_vlan_group *vg; + struct net_bridge_port *p = NULL; + int err = 0, cmdmap = 0; + struct net_bridge *br; + + if (netif_is_bridge_master(dev)) { + br = netdev_priv(dev); + vg = br_vlan_group(br); + } else { + p = br_port_get_rtnl(dev); + if (WARN_ON(!p)) + return -ENODEV; + br = p->br; + vg = nbp_vlan_group(p); + } + + if (WARN_ON(!vg)) + return -ENODEV; + + err = nla_parse_nested(tb, BRIDGE_VLANDB_ENTRY_MAX, attr, + br_vlan_db_policy, extack); + if (err) + return err; + + if (!tb[BRIDGE_VLANDB_ENTRY_INFO]) { + NL_SET_ERR_MSG_MOD(extack, "Missing vlan entry info"); + return -EINVAL; + } + memset(&vrange_end, 0, sizeof(vrange_end)); + + vinfo = nla_data(tb[BRIDGE_VLANDB_ENTRY_INFO]); + if (vinfo->flags & (BRIDGE_VLAN_INFO_RANGE_BEGIN | + BRIDGE_VLAN_INFO_RANGE_END)) { + NL_SET_ERR_MSG_MOD(extack, "Old-style vlan ranges are not allowed when using RTM vlan calls"); + return -EINVAL; + } + if (!br_vlan_valid_id(vinfo->vid, extack)) + return -EINVAL; + + if (tb[BRIDGE_VLANDB_ENTRY_RANGE]) { + vrange_end.vid = nla_get_u16(tb[BRIDGE_VLANDB_ENTRY_RANGE]); + /* validate user-provided flags without RANGE_BEGIN */ + vrange_end.flags = BRIDGE_VLAN_INFO_RANGE_END | vinfo->flags; + vinfo->flags |= BRIDGE_VLAN_INFO_RANGE_BEGIN; + + /* vinfo_last is the range 
start, vinfo the range end */ + vinfo_last = vinfo; + vinfo = &vrange_end; + + if (!br_vlan_valid_id(vinfo->vid, extack) || + !br_vlan_valid_range(vinfo, vinfo_last, extack)) + return -EINVAL; + } + + switch (cmd) { + case RTM_NEWVLAN: + cmdmap = RTM_SETLINK; + skip_processing = !!(vinfo->flags & BRIDGE_VLAN_INFO_ONLY_OPTS); + break; + case RTM_DELVLAN: + cmdmap = RTM_DELLINK; + break; + } + + if (!skip_processing) { + struct bridge_vlan_info *tmp_last = vinfo_last; + + /* br_process_vlan_info may overwrite vinfo_last */ + err = br_process_vlan_info(br, p, cmdmap, vinfo, &tmp_last, + &changed, extack); + + /* notify first if anything changed */ + if (changed) + br_ifinfo_notify(cmdmap, br, p); + + if (err) + return err; + } + + /* deal with options */ + if (cmd == RTM_NEWVLAN) { + struct net_bridge_vlan *range_start, *range_end; + + if (vinfo_last) { + range_start = br_vlan_find(vg, vinfo_last->vid); + range_end = br_vlan_find(vg, vinfo->vid); + } else { + range_start = br_vlan_find(vg, vinfo->vid); + range_end = range_start; + } + + err = br_vlan_process_options(br, p, range_start, range_end, + tb, extack); + } + + return err; +} + +static int br_vlan_rtm_process(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct br_vlan_msg *bvm; + struct net_device *dev; + struct nlattr *attr; + int err, vlans = 0; + int rem; + + /* this should validate the header and check for remaining bytes */ + err = nlmsg_parse(nlh, sizeof(*bvm), NULL, BRIDGE_VLANDB_MAX, NULL, + extack); + if (err < 0) + return err; + + bvm = nlmsg_data(nlh); + dev = __dev_get_by_index(net, bvm->ifindex); + if (!dev) + return -ENODEV; + + if (!netif_is_bridge_master(dev) && !netif_is_bridge_port(dev)) { + NL_SET_ERR_MSG_MOD(extack, "The device is not a valid bridge or bridge port"); + return -EINVAL; + } + + nlmsg_for_each_attr(attr, nlh, sizeof(*bvm), rem) { + switch (nla_type(attr)) { + case BRIDGE_VLANDB_ENTRY: + err = br_vlan_rtm_process_one(dev, attr, + nlh->nlmsg_type, + extack); + break; + case BRIDGE_VLANDB_GLOBAL_OPTIONS: + err = br_vlan_rtm_process_global_options(dev, attr, + nlh->nlmsg_type, + extack); + break; + default: + continue; + } + + vlans++; + if (err) + break; + } + if (!vlans) { + NL_SET_ERR_MSG_MOD(extack, "No vlans found to process"); + err = -EINVAL; + } + + return err; +} + +void br_vlan_rtnl_init(void) +{ + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_GETVLAN, NULL, + br_vlan_rtm_dump, 0); + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_NEWVLAN, + br_vlan_rtm_process, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_BRIDGE, RTM_DELVLAN, + br_vlan_rtm_process, NULL, 0); +} + +void br_vlan_rtnl_uninit(void) +{ + rtnl_unregister(PF_BRIDGE, RTM_GETVLAN); + rtnl_unregister(PF_BRIDGE, RTM_NEWVLAN); + rtnl_unregister(PF_BRIDGE, RTM_DELVLAN); +} diff --git a/net/bridge/br_vlan_options.c b/net/bridge/br_vlan_options.c new file mode 100644 index 000000000..a2724d032 --- /dev/null +++ b/net/bridge/br_vlan_options.c @@ -0,0 +1,697 @@ +// SPDX-License-Identifier: GPL-2.0-only +// Copyright (c) 2020, Nikolay Aleksandrov +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_tunnel.h" + +static bool __vlan_tun_put(struct sk_buff *skb, const struct net_bridge_vlan *v) +{ + __be32 tid = tunnel_id_to_key32(v->tinfo.tunnel_id); + struct nlattr *nest; + + if (!v->tinfo.tunnel_dst) + return true; + + nest = nla_nest_start(skb, BRIDGE_VLANDB_ENTRY_TUNNEL_INFO); + if (!nest) + return false; + if 
(nla_put_u32(skb, BRIDGE_VLANDB_TINFO_ID, be32_to_cpu(tid))) { + nla_nest_cancel(skb, nest); + return false; + } + nla_nest_end(skb, nest); + + return true; +} + +static bool __vlan_tun_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end) +{ + return (!v_curr->tinfo.tunnel_dst && !range_end->tinfo.tunnel_dst) || + vlan_tunid_inrange(v_curr, range_end); +} + +/* check if the options' state of v_curr allow it to enter the range */ +bool br_vlan_opts_eq_range(const struct net_bridge_vlan *v_curr, + const struct net_bridge_vlan *range_end) +{ + u8 range_mc_rtr = br_vlan_multicast_router(range_end); + u8 curr_mc_rtr = br_vlan_multicast_router(v_curr); + + return v_curr->state == range_end->state && + __vlan_tun_can_enter_range(v_curr, range_end) && + curr_mc_rtr == range_mc_rtr; +} + +bool br_vlan_opts_fill(struct sk_buff *skb, const struct net_bridge_vlan *v) +{ + if (nla_put_u8(skb, BRIDGE_VLANDB_ENTRY_STATE, br_vlan_get_state(v)) || + !__vlan_tun_put(skb, v)) + return false; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (nla_put_u8(skb, BRIDGE_VLANDB_ENTRY_MCAST_ROUTER, + br_vlan_multicast_router(v))) + return false; +#endif + + return true; +} + +size_t br_vlan_opts_nl_size(void) +{ + return nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_ENTRY_STATE */ + + nla_total_size(0) /* BRIDGE_VLANDB_ENTRY_TUNNEL_INFO */ + + nla_total_size(sizeof(u32)) /* BRIDGE_VLANDB_TINFO_ID */ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + + nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_ENTRY_MCAST_ROUTER */ +#endif + + 0; +} + +static int br_vlan_modify_state(struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *v, + u8 state, + bool *changed, + struct netlink_ext_ack *extack) +{ + struct net_bridge *br; + + ASSERT_RTNL(); + + if (state > BR_STATE_BLOCKING) { + NL_SET_ERR_MSG_MOD(extack, "Invalid vlan state"); + return -EINVAL; + } + + if (br_vlan_is_brentry(v)) + br = v->br; + else + br = v->port->br; + + if (br->stp_enabled == BR_KERNEL_STP) { + NL_SET_ERR_MSG_MOD(extack, "Can't modify vlan state when using kernel STP"); + return -EBUSY; + } + + if (br_opt_get(br, BROPT_MST_ENABLED)) { + NL_SET_ERR_MSG_MOD(extack, "Can't modify vlan state directly when MST is enabled"); + return -EBUSY; + } + + if (v->state == state) + return 0; + + if (v->vid == br_get_pvid(vg)) + br_vlan_set_pvid_state(vg, state); + + br_vlan_set_state(v, state); + *changed = true; + + return 0; +} + +static const struct nla_policy br_vlandb_tinfo_pol[BRIDGE_VLANDB_TINFO_MAX + 1] = { + [BRIDGE_VLANDB_TINFO_ID] = { .type = NLA_U32 }, + [BRIDGE_VLANDB_TINFO_CMD] = { .type = NLA_U32 }, +}; + +static int br_vlan_modify_tunnel(const struct net_bridge_port *p, + struct net_bridge_vlan *v, + struct nlattr **tb, + bool *changed, + struct netlink_ext_ack *extack) +{ + struct nlattr *tun_tb[BRIDGE_VLANDB_TINFO_MAX + 1], *attr; + struct bridge_vlan_info *vinfo; + u32 tun_id = 0; + int cmd, err; + + if (!p) { + NL_SET_ERR_MSG_MOD(extack, "Can't modify tunnel mapping of non-port vlans"); + return -EINVAL; + } + if (!(p->flags & BR_VLAN_TUNNEL)) { + NL_SET_ERR_MSG_MOD(extack, "Port doesn't have tunnel flag set"); + return -EINVAL; + } + + attr = tb[BRIDGE_VLANDB_ENTRY_TUNNEL_INFO]; + err = nla_parse_nested(tun_tb, BRIDGE_VLANDB_TINFO_MAX, attr, + br_vlandb_tinfo_pol, extack); + if (err) + return err; + + if (!tun_tb[BRIDGE_VLANDB_TINFO_CMD]) { + NL_SET_ERR_MSG_MOD(extack, "Missing tunnel command attribute"); + return -ENOENT; + } + cmd = nla_get_u32(tun_tb[BRIDGE_VLANDB_TINFO_CMD]); + switch (cmd) { + case RTM_SETLINK: + 
if (!tun_tb[BRIDGE_VLANDB_TINFO_ID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing tunnel id attribute"); + return -ENOENT; + } + /* when working on vlan ranges this is the starting tunnel id */ + tun_id = nla_get_u32(tun_tb[BRIDGE_VLANDB_TINFO_ID]); + /* vlan info attr is guaranteed by br_vlan_rtm_process_one */ + vinfo = nla_data(tb[BRIDGE_VLANDB_ENTRY_INFO]); + /* tunnel ids are mapped to each vlan in increasing order, + * the starting vlan is in BRIDGE_VLANDB_ENTRY_INFO and v is the + * current vlan, so we compute: tun_id + v - vinfo->vid + */ + tun_id += v->vid - vinfo->vid; + break; + case RTM_DELLINK: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported tunnel command"); + return -EINVAL; + } + + return br_vlan_tunnel_info(p, cmd, v->vid, tun_id, changed); +} + +static int br_vlan_process_one_opts(const struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *v, + struct nlattr **tb, + bool *changed, + struct netlink_ext_ack *extack) +{ + int err; + + *changed = false; + if (tb[BRIDGE_VLANDB_ENTRY_STATE]) { + u8 state = nla_get_u8(tb[BRIDGE_VLANDB_ENTRY_STATE]); + + err = br_vlan_modify_state(vg, v, state, changed, extack); + if (err) + return err; + } + if (tb[BRIDGE_VLANDB_ENTRY_TUNNEL_INFO]) { + err = br_vlan_modify_tunnel(p, v, tb, changed, extack); + if (err) + return err; + } + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (tb[BRIDGE_VLANDB_ENTRY_MCAST_ROUTER]) { + u8 val; + + val = nla_get_u8(tb[BRIDGE_VLANDB_ENTRY_MCAST_ROUTER]); + err = br_multicast_set_vlan_router(v, val); + if (err) + return err; + *changed = true; + } +#endif + + return 0; +} + +int br_vlan_process_options(const struct net_bridge *br, + const struct net_bridge_port *p, + struct net_bridge_vlan *range_start, + struct net_bridge_vlan *range_end, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan *v, *curr_start = NULL, *curr_end = NULL; + struct net_bridge_vlan_group *vg; + int vid, err = 0; + u16 pvid; + + if (p) + vg = nbp_vlan_group(p); + else + vg = br_vlan_group(br); + + if (!range_start || !br_vlan_should_use(range_start)) { + NL_SET_ERR_MSG_MOD(extack, "Vlan range start doesn't exist, can't process options"); + return -ENOENT; + } + if (!range_end || !br_vlan_should_use(range_end)) { + NL_SET_ERR_MSG_MOD(extack, "Vlan range end doesn't exist, can't process options"); + return -ENOENT; + } + + pvid = br_get_pvid(vg); + for (vid = range_start->vid; vid <= range_end->vid; vid++) { + bool changed = false; + + v = br_vlan_find(vg, vid); + if (!v || !br_vlan_should_use(v)) { + NL_SET_ERR_MSG_MOD(extack, "Vlan in range doesn't exist, can't process options"); + err = -ENOENT; + break; + } + + err = br_vlan_process_one_opts(br, p, vg, v, tb, &changed, + extack); + if (err) + break; + + if (changed) { + /* vlan options changed, check for range */ + if (!curr_start) { + curr_start = v; + curr_end = v; + continue; + } + + if (v->vid == pvid || + !br_vlan_can_enter_range(v, curr_end)) { + br_vlan_notify(br, p, curr_start->vid, + curr_end->vid, RTM_NEWVLAN); + curr_start = v; + } + curr_end = v; + } else { + /* nothing changed and nothing to notify yet */ + if (!curr_start) + continue; + + br_vlan_notify(br, p, curr_start->vid, curr_end->vid, + RTM_NEWVLAN); + curr_start = NULL; + curr_end = NULL; + } + } + if (curr_start) + br_vlan_notify(br, p, curr_start->vid, curr_end->vid, + RTM_NEWVLAN); + + return err; +} + +bool br_vlan_global_opts_can_enter_range(const struct net_bridge_vlan *v_curr, + const struct 
net_bridge_vlan *r_end) +{ + return v_curr->vid - r_end->vid == 1 && + v_curr->msti == r_end->msti && + ((v_curr->priv_flags ^ r_end->priv_flags) & + BR_VLFLAG_GLOBAL_MCAST_ENABLED) == 0 && + br_multicast_ctx_options_equal(&v_curr->br_mcast_ctx, + &r_end->br_mcast_ctx); +} + +bool br_vlan_global_opts_fill(struct sk_buff *skb, u16 vid, u16 vid_range, + const struct net_bridge_vlan *v_opts) +{ + struct nlattr *nest2 __maybe_unused; + u64 clockval __maybe_unused; + struct nlattr *nest; + + nest = nla_nest_start(skb, BRIDGE_VLANDB_GLOBAL_OPTIONS); + if (!nest) + return false; + + if (nla_put_u16(skb, BRIDGE_VLANDB_GOPTS_ID, vid)) + goto out_err; + + if (vid_range && vid < vid_range && + nla_put_u16(skb, BRIDGE_VLANDB_GOPTS_RANGE, vid_range)) + goto out_err; + +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (nla_put_u8(skb, BRIDGE_VLANDB_GOPTS_MCAST_SNOOPING, + !!(v_opts->priv_flags & BR_VLFLAG_GLOBAL_MCAST_ENABLED)) || + nla_put_u8(skb, BRIDGE_VLANDB_GOPTS_MCAST_IGMP_VERSION, + v_opts->br_mcast_ctx.multicast_igmp_version) || + nla_put_u32(skb, BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_CNT, + v_opts->br_mcast_ctx.multicast_last_member_count) || + nla_put_u32(skb, BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_CNT, + v_opts->br_mcast_ctx.multicast_startup_query_count) || + nla_put_u8(skb, BRIDGE_VLANDB_GOPTS_MCAST_QUERIER, + v_opts->br_mcast_ctx.multicast_querier) || + br_multicast_dump_querier_state(skb, &v_opts->br_mcast_ctx, + BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_STATE)) + goto out_err; + + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_last_member_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_membership_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_MEMBERSHIP_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_querier_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_query_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_QUERY_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_query_response_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_QUERY_RESPONSE_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + clockval = jiffies_to_clock_t(v_opts->br_mcast_ctx.multicast_startup_query_interval); + if (nla_put_u64_64bit(skb, BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_INTVL, + clockval, BRIDGE_VLANDB_GOPTS_PAD)) + goto out_err; + + if (br_rports_have_mc_router(&v_opts->br_mcast_ctx)) { + nest2 = nla_nest_start(skb, + BRIDGE_VLANDB_GOPTS_MCAST_ROUTER_PORTS); + if (!nest2) + goto out_err; + + rcu_read_lock(); + if (br_rports_fill_info(skb, &v_opts->br_mcast_ctx)) { + rcu_read_unlock(); + nla_nest_cancel(skb, nest2); + goto out_err; + } + rcu_read_unlock(); + + nla_nest_end(skb, nest2); + } + +#if IS_ENABLED(CONFIG_IPV6) + if (nla_put_u8(skb, BRIDGE_VLANDB_GOPTS_MCAST_MLD_VERSION, + v_opts->br_mcast_ctx.multicast_mld_version)) + goto out_err; +#endif +#endif + + if (nla_put_u16(skb, BRIDGE_VLANDB_GOPTS_MSTI, v_opts->msti)) + goto out_err; + + nla_nest_end(skb, nest); + + return true; + +out_err: + nla_nest_cancel(skb, nest); + return false; +} + +static size_t rtnl_vlan_global_opts_nlmsg_size(const struct 
net_bridge_vlan *v) +{ + return NLMSG_ALIGN(sizeof(struct br_vlan_msg)) + + nla_total_size(0) /* BRIDGE_VLANDB_GLOBAL_OPTIONS */ + + nla_total_size(sizeof(u16)) /* BRIDGE_VLANDB_GOPTS_ID */ +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + + nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_GOPTS_MCAST_SNOOPING */ + + nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_GOPTS_MCAST_IGMP_VERSION */ + + nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_GOPTS_MCAST_MLD_VERSION */ + + nla_total_size(sizeof(u32)) /* BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_CNT */ + + nla_total_size(sizeof(u32)) /* BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_CNT */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_INTVL */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_MEMBERSHIP_INTVL */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_INTVL */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_QUERY_INTVL */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_QUERY_RESPONSE_INTVL */ + + nla_total_size(sizeof(u64)) /* BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_INTVL */ + + nla_total_size(sizeof(u8)) /* BRIDGE_VLANDB_GOPTS_MCAST_QUERIER */ + + br_multicast_querier_state_size() /* BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_STATE */ + + nla_total_size(0) /* BRIDGE_VLANDB_GOPTS_MCAST_ROUTER_PORTS */ + + br_rports_size(&v->br_mcast_ctx) /* BRIDGE_VLANDB_GOPTS_MCAST_ROUTER_PORTS */ +#endif + + nla_total_size(sizeof(u16)) /* BRIDGE_VLANDB_GOPTS_MSTI */ + + nla_total_size(sizeof(u16)); /* BRIDGE_VLANDB_GOPTS_RANGE */ +} + +static void br_vlan_global_opts_notify(const struct net_bridge *br, + u16 vid, u16 vid_range) +{ + struct net_bridge_vlan *v; + struct br_vlan_msg *bvm; + struct nlmsghdr *nlh; + struct sk_buff *skb; + int err = -ENOBUFS; + + /* right now notifications are done only with rtnl held */ + ASSERT_RTNL(); + + /* need to find the vlan due to flags/options */ + v = br_vlan_find(br_vlan_group(br), vid); + if (!v) + return; + + skb = nlmsg_new(rtnl_vlan_global_opts_nlmsg_size(v), GFP_KERNEL); + if (!skb) + goto out_err; + + err = -EMSGSIZE; + nlh = nlmsg_put(skb, 0, 0, RTM_NEWVLAN, sizeof(*bvm), 0); + if (!nlh) + goto out_err; + bvm = nlmsg_data(nlh); + memset(bvm, 0, sizeof(*bvm)); + bvm->family = AF_BRIDGE; + bvm->ifindex = br->dev->ifindex; + + if (!br_vlan_global_opts_fill(skb, vid, vid_range, v)) + goto out_err; + + nlmsg_end(skb, nlh); + rtnl_notify(skb, dev_net(br->dev), 0, RTNLGRP_BRVLAN, NULL, GFP_KERNEL); + return; + +out_err: + rtnl_set_sk_err(dev_net(br->dev), RTNLGRP_BRVLAN, err); + kfree_skb(skb); +} + +static int br_vlan_process_global_one_opts(const struct net_bridge *br, + struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *v, + struct nlattr **tb, + bool *changed, + struct netlink_ext_ack *extack) +{ + int err __maybe_unused; + + *changed = false; +#ifdef CONFIG_BRIDGE_IGMP_SNOOPING + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_SNOOPING]) { + u8 mc_snooping; + + mc_snooping = nla_get_u8(tb[BRIDGE_VLANDB_GOPTS_MCAST_SNOOPING]); + if (br_multicast_toggle_global_vlan(v, !!mc_snooping)) + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_IGMP_VERSION]) { + u8 ver; + + ver = nla_get_u8(tb[BRIDGE_VLANDB_GOPTS_MCAST_IGMP_VERSION]); + err = br_multicast_set_igmp_version(&v->br_mcast_ctx, ver); + if (err) + return err; + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_CNT]) { + u32 cnt; + + cnt = nla_get_u32(tb[BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_CNT]); + v->br_mcast_ctx.multicast_last_member_count = cnt; + *changed = true; + } + if 
(tb[BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_CNT]) { + u32 cnt; + + cnt = nla_get_u32(tb[BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_CNT]); + v->br_mcast_ctx.multicast_startup_query_count = cnt; + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_INTVL]); + v->br_mcast_ctx.multicast_last_member_interval = clock_t_to_jiffies(val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_MEMBERSHIP_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_MEMBERSHIP_INTVL]); + v->br_mcast_ctx.multicast_membership_interval = clock_t_to_jiffies(val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_INTVL]); + v->br_mcast_ctx.multicast_querier_interval = clock_t_to_jiffies(val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERY_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERY_INTVL]); + br_multicast_set_query_intvl(&v->br_mcast_ctx, val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERY_RESPONSE_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERY_RESPONSE_INTVL]); + v->br_mcast_ctx.multicast_query_response_interval = clock_t_to_jiffies(val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_INTVL]) { + u64 val; + + val = nla_get_u64(tb[BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_INTVL]); + br_multicast_set_startup_query_intvl(&v->br_mcast_ctx, val); + *changed = true; + } + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERIER]) { + u8 val; + + val = nla_get_u8(tb[BRIDGE_VLANDB_GOPTS_MCAST_QUERIER]); + err = br_multicast_set_querier(&v->br_mcast_ctx, val); + if (err) + return err; + *changed = true; + } +#if IS_ENABLED(CONFIG_IPV6) + if (tb[BRIDGE_VLANDB_GOPTS_MCAST_MLD_VERSION]) { + u8 ver; + + ver = nla_get_u8(tb[BRIDGE_VLANDB_GOPTS_MCAST_MLD_VERSION]); + err = br_multicast_set_mld_version(&v->br_mcast_ctx, ver); + if (err) + return err; + *changed = true; + } +#endif +#endif + if (tb[BRIDGE_VLANDB_GOPTS_MSTI]) { + u16 msti; + + msti = nla_get_u16(tb[BRIDGE_VLANDB_GOPTS_MSTI]); + err = br_mst_vlan_set_msti(v, msti); + if (err) + return err; + *changed = true; + } + + return 0; +} + +static const struct nla_policy br_vlan_db_gpol[BRIDGE_VLANDB_GOPTS_MAX + 1] = { + [BRIDGE_VLANDB_GOPTS_ID] = { .type = NLA_U16 }, + [BRIDGE_VLANDB_GOPTS_RANGE] = { .type = NLA_U16 }, + [BRIDGE_VLANDB_GOPTS_MCAST_SNOOPING] = { .type = NLA_U8 }, + [BRIDGE_VLANDB_GOPTS_MCAST_MLD_VERSION] = { .type = NLA_U8 }, + [BRIDGE_VLANDB_GOPTS_MCAST_QUERY_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MCAST_QUERIER] = { .type = NLA_U8 }, + [BRIDGE_VLANDB_GOPTS_MCAST_IGMP_VERSION] = { .type = NLA_U8 }, + [BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_CNT] = { .type = NLA_U32 }, + [BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_CNT] = { .type = NLA_U32 }, + [BRIDGE_VLANDB_GOPTS_MCAST_LAST_MEMBER_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MCAST_MEMBERSHIP_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MCAST_QUERIER_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MCAST_STARTUP_QUERY_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MCAST_QUERY_RESPONSE_INTVL] = { .type = NLA_U64 }, + [BRIDGE_VLANDB_GOPTS_MSTI] = NLA_POLICY_MAX(NLA_U16, VLAN_N_VID - 1), +}; + +int br_vlan_rtm_process_global_options(struct net_device *dev, + const struct nlattr *attr, + int cmd, + struct netlink_ext_ack *extack) +{ + struct net_bridge_vlan *v, 
*curr_start = NULL, *curr_end = NULL; + struct nlattr *tb[BRIDGE_VLANDB_GOPTS_MAX + 1]; + struct net_bridge_vlan_group *vg; + u16 vid, vid_range = 0; + struct net_bridge *br; + int err = 0; + + if (cmd != RTM_NEWVLAN) { + NL_SET_ERR_MSG_MOD(extack, "Global vlan options support only set operation"); + return -EINVAL; + } + if (!netif_is_bridge_master(dev)) { + NL_SET_ERR_MSG_MOD(extack, "Global vlan options can only be set on bridge device"); + return -EINVAL; + } + br = netdev_priv(dev); + vg = br_vlan_group(br); + if (WARN_ON(!vg)) + return -ENODEV; + + err = nla_parse_nested(tb, BRIDGE_VLANDB_GOPTS_MAX, attr, + br_vlan_db_gpol, extack); + if (err) + return err; + + if (!tb[BRIDGE_VLANDB_GOPTS_ID]) { + NL_SET_ERR_MSG_MOD(extack, "Missing vlan entry id"); + return -EINVAL; + } + vid = nla_get_u16(tb[BRIDGE_VLANDB_GOPTS_ID]); + if (!br_vlan_valid_id(vid, extack)) + return -EINVAL; + + if (tb[BRIDGE_VLANDB_GOPTS_RANGE]) { + vid_range = nla_get_u16(tb[BRIDGE_VLANDB_GOPTS_RANGE]); + if (!br_vlan_valid_id(vid_range, extack)) + return -EINVAL; + if (vid >= vid_range) { + NL_SET_ERR_MSG_MOD(extack, "End vlan id is less than or equal to start vlan id"); + return -EINVAL; + } + } else { + vid_range = vid; + } + + for (; vid <= vid_range; vid++) { + bool changed = false; + + v = br_vlan_find(vg, vid); + if (!v) { + NL_SET_ERR_MSG_MOD(extack, "Vlan in range doesn't exist, can't process global options"); + err = -ENOENT; + break; + } + + err = br_vlan_process_global_one_opts(br, vg, v, tb, &changed, + extack); + if (err) + break; + + if (changed) { + /* vlan options changed, check for range */ + if (!curr_start) { + curr_start = v; + curr_end = v; + continue; + } + + if (!br_vlan_global_opts_can_enter_range(v, curr_end)) { + br_vlan_global_opts_notify(br, curr_start->vid, + curr_end->vid); + curr_start = v; + } + curr_end = v; + } else { + /* nothing changed and nothing to notify yet */ + if (!curr_start) + continue; + + br_vlan_global_opts_notify(br, curr_start->vid, + curr_end->vid); + curr_start = NULL; + curr_end = NULL; + } + } + if (curr_start) + br_vlan_global_opts_notify(br, curr_start->vid, curr_end->vid); + + return err; +} diff --git a/net/bridge/br_vlan_tunnel.c b/net/bridge/br_vlan_tunnel.c new file mode 100644 index 000000000..6399a8a69 --- /dev/null +++ b/net/bridge/br_vlan_tunnel.c @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Bridge per vlan tunnel port dst_metadata handling code + * + * Authors: + * Roopa Prabhu + */ + +#include +#include +#include +#include +#include +#include + +#include "br_private.h" +#include "br_private_tunnel.h" + +static inline int br_vlan_tunid_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct net_bridge_vlan *vle = ptr; + __be64 tunid = *(__be64 *)arg->key; + + return vle->tinfo.tunnel_id != tunid; +} + +static const struct rhashtable_params br_vlan_tunnel_rht_params = { + .head_offset = offsetof(struct net_bridge_vlan, tnode), + .key_offset = offsetof(struct net_bridge_vlan, tinfo.tunnel_id), + .key_len = sizeof(__be64), + .nelem_hint = 3, + .obj_cmpfn = br_vlan_tunid_cmp, + .automatic_shrinking = true, +}; + +static struct net_bridge_vlan *br_vlan_tunnel_lookup(struct rhashtable *tbl, + __be64 tunnel_id) +{ + return rhashtable_lookup_fast(tbl, &tunnel_id, + br_vlan_tunnel_rht_params); +} + +static void vlan_tunnel_info_release(struct net_bridge_vlan *vlan) +{ + struct metadata_dst *tdst = rtnl_dereference(vlan->tinfo.tunnel_dst); + + WRITE_ONCE(vlan->tinfo.tunnel_id, 0); + 
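The processing loop above batches consecutive VLANs whose new options still compare equal into a single start/end notification. A standalone sketch of that coalescing pattern, with a plain integer "opt" standing in for the real option comparison:

/* Coalesce consecutive IDs with equal options into ranges, mirroring
 * the curr_start/curr_end logic of the notification loop above.
 */
#include <stdbool.h>
#include <stdio.h>

struct entry { int vid; int opt; };     /* opt stands in for the vlan options */

static void notify(int start, int end)
{
        printf("notify range %d-%d\n", start, end);
}

/* adjacent id and equal options, like br_vlan_global_opts_can_enter_range() */
static bool can_enter_range(const struct entry *cur, const struct entry *range_end)
{
        return cur->vid - range_end->vid == 1 && cur->opt == range_end->opt;
}

int main(void)
{
        const struct entry entries[] = {
                { 10, 1 }, { 11, 1 }, { 12, 2 }, { 13, 2 }, { 15, 2 },
        };
        const struct entry *start = NULL, *end = NULL;

        for (int i = 0; i < 5; i++) {
                const struct entry *e = &entries[i];

                if (start && !can_enter_range(e, end)) {
                        notify(start->vid, end->vid);
                        start = NULL;
                }
                if (!start)
                        start = e;
                end = e;
        }
        if (start)
                notify(start->vid, end->vid);   /* prints 10-11, 12-13, 15-15 */
        return 0;
}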
RCU_INIT_POINTER(vlan->tinfo.tunnel_dst, NULL); + dst_release(&tdst->dst); +} + +void vlan_tunnel_info_del(struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *vlan) +{ + if (!rcu_access_pointer(vlan->tinfo.tunnel_dst)) + return; + rhashtable_remove_fast(&vg->tunnel_hash, &vlan->tnode, + br_vlan_tunnel_rht_params); + vlan_tunnel_info_release(vlan); +} + +static int __vlan_tunnel_info_add(struct net_bridge_vlan_group *vg, + struct net_bridge_vlan *vlan, u32 tun_id) +{ + struct metadata_dst *metadata = rtnl_dereference(vlan->tinfo.tunnel_dst); + __be64 key = key32_to_tunnel_id(cpu_to_be32(tun_id)); + int err; + + if (metadata) + return -EEXIST; + + metadata = __ip_tun_set_dst(0, 0, 0, 0, 0, TUNNEL_KEY, + key, 0); + if (!metadata) + return -EINVAL; + + metadata->u.tun_info.mode |= IP_TUNNEL_INFO_TX | IP_TUNNEL_INFO_BRIDGE; + rcu_assign_pointer(vlan->tinfo.tunnel_dst, metadata); + WRITE_ONCE(vlan->tinfo.tunnel_id, key); + + err = rhashtable_lookup_insert_fast(&vg->tunnel_hash, &vlan->tnode, + br_vlan_tunnel_rht_params); + if (err) + goto out; + + return 0; +out: + vlan_tunnel_info_release(vlan); + + return err; +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. + */ +int nbp_vlan_tunnel_info_add(const struct net_bridge_port *port, u16 vid, + u32 tun_id) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *vlan; + + ASSERT_RTNL(); + + vg = nbp_vlan_group(port); + vlan = br_vlan_find(vg, vid); + if (!vlan) + return -EINVAL; + + return __vlan_tunnel_info_add(vg, vlan, tun_id); +} + +/* Must be protected by RTNL. + * Must be called with vid in range from 1 to 4094 inclusive. + */ +int nbp_vlan_tunnel_info_delete(const struct net_bridge_port *port, u16 vid) +{ + struct net_bridge_vlan_group *vg; + struct net_bridge_vlan *v; + + ASSERT_RTNL(); + + vg = nbp_vlan_group(port); + v = br_vlan_find(vg, vid); + if (!v) + return -ENOENT; + + vlan_tunnel_info_del(vg, v); + + return 0; +} + +static void __vlan_tunnel_info_flush(struct net_bridge_vlan_group *vg) +{ + struct net_bridge_vlan *vlan, *tmp; + + list_for_each_entry_safe(vlan, tmp, &vg->vlan_list, vlist) + vlan_tunnel_info_del(vg, vlan); +} + +void nbp_vlan_tunnel_info_flush(struct net_bridge_port *port) +{ + struct net_bridge_vlan_group *vg; + + ASSERT_RTNL(); + + vg = nbp_vlan_group(port); + __vlan_tunnel_info_flush(vg); +} + +int vlan_tunnel_init(struct net_bridge_vlan_group *vg) +{ + return rhashtable_init(&vg->tunnel_hash, &br_vlan_tunnel_rht_params); +} + +void vlan_tunnel_deinit(struct net_bridge_vlan_group *vg) +{ + rhashtable_destroy(&vg->tunnel_hash); +} + +void br_handle_ingress_vlan_tunnel(struct sk_buff *skb, + struct net_bridge_port *p, + struct net_bridge_vlan_group *vg) +{ + struct ip_tunnel_info *tinfo = skb_tunnel_info(skb); + struct net_bridge_vlan *vlan; + + if (!vg || !tinfo) + return; + + /* if already tagged, ignore */ + if (skb_vlan_tagged(skb)) + return; + + /* lookup vid, given tunnel id */ + vlan = br_vlan_tunnel_lookup(&vg->tunnel_hash, tinfo->key.tun_id); + if (!vlan) + return; + + skb_dst_drop(skb); + + __vlan_hwaccel_put_tag(skb, p->br->vlan_proto, vlan->vid); +} + +int br_handle_egress_vlan_tunnel(struct sk_buff *skb, + struct net_bridge_vlan *vlan) +{ + struct metadata_dst *tunnel_dst; + __be64 tunnel_id; + int err; + + if (!vlan) + return 0; + + tunnel_id = READ_ONCE(vlan->tinfo.tunnel_id); + if (!tunnel_id || unlikely(!skb_vlan_tag_present(skb))) + return 0; + + skb_dst_drop(skb); + err = skb_vlan_pop(skb); + if (err) + return err; + + tunnel_dst = 
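On ingress the bridge resolves the incoming tunnel key to a VLAN through the rhashtable keyed on tinfo.tunnel_id. The sketch below shows only that key-to-vid direction, with a trivial fixed-size open-addressed table standing in for the resizable rhashtable the kernel actually uses.

/* Standalone sketch of the tunnel_id -> vid mapping consulted on
 * bridge ingress (a toy open-addressed table replaces the rhashtable).
 */
#include <stdint.h>
#include <stdio.h>

#define SLOTS 64u

struct slot {
        uint64_t tunnel_id;     /* 0 means "empty slot" */
        uint16_t vid;
};

static struct slot map[SLOTS];

static int tunnel_map_add(uint64_t tunnel_id, uint16_t vid)
{
        for (unsigned i = 0; i < SLOTS; i++) {
                struct slot *s = &map[(tunnel_id + i) % SLOTS];

                if (s->tunnel_id == tunnel_id)
                        return -1;              /* already mapped (-EEXIST) */
                if (s->tunnel_id == 0) {
                        s->tunnel_id = tunnel_id;
                        s->vid = vid;
                        return 0;
                }
        }
        return -1;                              /* table full */
}

static int tunnel_map_lookup(uint64_t tunnel_id, uint16_t *vid)
{
        for (unsigned i = 0; i < SLOTS; i++) {
                const struct slot *s = &map[(tunnel_id + i) % SLOTS];

                if (s->tunnel_id == 0)
                        break;                  /* not found */
                if (s->tunnel_id == tunnel_id) {
                        *vid = s->vid;
                        return 0;
                }
        }
        return -1;
}

int main(void)
{
        uint16_t vid;

        tunnel_map_add(10100, 100);
        tunnel_map_add(10200, 200);
        if (tunnel_map_lookup(10200, &vid) == 0)
                printf("tunnel 10200 -> vid %u\n", vid);        /* vid 200 */
        return 0;
}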
rcu_dereference(vlan->tinfo.tunnel_dst); + if (tunnel_dst && dst_hold_safe(&tunnel_dst->dst)) + skb_dst_set(skb, &tunnel_dst->dst); + + return 0; +} diff --git a/net/bridge/netfilter/Kconfig b/net/bridge/netfilter/Kconfig new file mode 100644 index 000000000..7f304a19a --- /dev/null +++ b/net/bridge/netfilter/Kconfig @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Bridge netfilter configuration +# +# +menuconfig NF_TABLES_BRIDGE + depends on BRIDGE && NETFILTER && NF_TABLES + select NETFILTER_FAMILY_BRIDGE + tristate "Ethernet Bridge nf_tables support" + +if NF_TABLES_BRIDGE + +config NFT_BRIDGE_META + tristate "Netfilter nf_table bridge meta support" + help + Add support for bridge dedicated meta key. + +config NFT_BRIDGE_REJECT + tristate "Netfilter nf_tables bridge reject support" + depends on NFT_REJECT + depends on NF_REJECT_IPV4 + depends on NF_REJECT_IPV6 + help + Add support to reject packets. + +endif # NF_TABLES_BRIDGE + +config NF_CONNTRACK_BRIDGE + tristate "IPv4/IPV6 bridge connection tracking support" + depends on NF_CONNTRACK + default n + help + Connection tracking keeps a record of what packets have passed + through your machine, in order to figure out how they are related + into connections. This is used to enhance packet filtering via + stateful policies. Enable this if you want native tracking from + the bridge. This provides a replacement for the `br_netfilter' + infrastructure. + + To compile it as a module, choose M here. If unsure, say N. + +menuconfig BRIDGE_NF_EBTABLES + tristate "Ethernet Bridge tables (ebtables) support" + depends on BRIDGE && NETFILTER && NETFILTER_XTABLES + select NETFILTER_FAMILY_BRIDGE + help + ebtables is a general, extensible frame/packet identification + framework. Say 'Y' or 'M' here if you want to do Ethernet + filtering/NAT/brouting on the Ethernet bridge. + +if BRIDGE_NF_EBTABLES + +# +# tables +# +config BRIDGE_EBT_BROUTE + tristate "ebt: broute table support" + help + The ebtables broute table is used to define rules that decide between + bridging and routing frames, giving Linux the functionality of a + brouter. See the man page for ebtables(8) and examples on the ebtables + website. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_T_FILTER + tristate "ebt: filter table support" + help + The ebtables filter table is used to define frame filtering rules at + local input, forwarding and local output. See the man page for + ebtables(8). + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_T_NAT + tristate "ebt: nat table support" + help + The ebtables nat table is used to define rules that alter the MAC + source address (MAC SNAT) or the MAC destination address (MAC DNAT). + See the man page for ebtables(8). + + To compile it as a module, choose M here. If unsure, say N. +# +# matches +# +config BRIDGE_EBT_802_3 + tristate "ebt: 802.3 filter support" + help + This option adds matching support for 802.3 Ethernet frames. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_AMONG + tristate "ebt: among filter support" + help + This option adds the among match, which allows matching the MAC source + and/or destination address on a list of addresses. Optionally, + MAC/IP address pairs can be matched, f.e. for anti-spoofing rules. + + To compile it as a module, choose M here. If unsure, say N. 
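Conceptually, the "among" match described above answers one question per direction: does this frame's MAC (optionally paired with an IP) appear in the configured list? A flat, linear version of that test in standalone C; the kernel keeps the list in the hashed "wormhash" layout shown further below so the lookup stays cheap.

/* Linear-scan version of the MAC[/IP] membership test performed by
 * the "among" match; addresses and IPs here are made-up examples.
 */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <stdio.h>

struct mac_ip_pair {
        uint8_t  mac[6];
        uint32_t ip;            /* 0 = "any IP" for this MAC */
};

static bool list_contains(const struct mac_ip_pair *list, int n,
                          const uint8_t *mac, uint32_t ip)
{
        for (int i = 0; i < n; i++) {
                if (memcmp(list[i].mac, mac, 6) != 0)
                        continue;
                if (list[i].ip == 0 || list[i].ip == ip)
                        return true;    /* MAC matches, IP wildcard or equal */
        }
        return false;
}

int main(void)
{
        const struct mac_ip_pair allowed[] = {
                { { 0x02, 0, 0, 0, 0, 0x01 }, 0xc0a80001 },     /* 192.168.0.1 */
                { { 0x02, 0, 0, 0, 0, 0x02 }, 0 },              /* any IP */
        };
        const uint8_t mac[6] = { 0x02, 0, 0, 0, 0, 0x02 };

        printf("%s\n", list_contains(allowed, 2, mac, 0x0a000001) ?
               "matched" : "not matched");
        return 0;
}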
+ +config BRIDGE_EBT_ARP + tristate "ebt: ARP filter support" + help + This option adds the ARP match, which allows ARP and RARP header field + filtering. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_IP + tristate "ebt: IP filter support" + help + This option adds the IP match, which allows basic IP header field + filtering. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_IP6 + tristate "ebt: IP6 filter support" + depends on BRIDGE_NF_EBTABLES && IPV6 + help + This option adds the IP6 match, which allows basic IPV6 header field + filtering. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_LIMIT + tristate "ebt: limit match support" + help + This option adds the limit match, which allows you to control + the rate at which a rule can be matched. This match is the + equivalent of the iptables limit match. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config BRIDGE_EBT_MARK + tristate "ebt: mark filter support" + help + This option adds the mark match, which allows matching frames based on + the 'nfmark' value in the frame. This can be set by the mark target. + This value is the same as the one used in the iptables mark match and + target. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_PKTTYPE + tristate "ebt: packet type filter support" + help + This option adds the packet type match, which allows matching on the + type of packet based on its Ethernet "class" (as determined by + the generic networking code): broadcast, multicast, + for this host alone or for another host. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_STP + tristate "ebt: STP filter support" + help + This option adds the Spanning Tree Protocol match, which + allows STP header field filtering. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_VLAN + tristate "ebt: 802.1Q VLAN filter support" + help + This option adds the 802.1Q vlan match, which allows the filtering of + 802.1Q vlan fields. + + To compile it as a module, choose M here. If unsure, say N. +# +# targets +# +config BRIDGE_EBT_ARPREPLY + tristate "ebt: arp reply target support" + depends on BRIDGE_NF_EBTABLES && INET + help + This option adds the arp reply target, which allows + automatically sending arp replies to arp requests. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_DNAT + tristate "ebt: dnat target support" + help + This option adds the MAC DNAT target, which allows altering the MAC + destination address of frames. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_MARK_T + tristate "ebt: mark target support" + help + This option adds the mark target, which allows marking frames by + setting the 'nfmark' value in the frame. + This value is the same as the one used in the iptables mark match and + target. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_REDIRECT + tristate "ebt: redirect target support" + help + This option adds the MAC redirect target, which allows altering the MAC + destination address of a frame to that of the device it arrived on. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_SNAT + tristate "ebt: snat target support" + help + This option adds the MAC SNAT target, which allows altering the MAC + source address of frames. 
+ + To compile it as a module, choose M here. If unsure, say N. +# +# watchers +# +config BRIDGE_EBT_LOG + tristate "ebt: log support" + help + This option adds the log watcher, that you can use in any rule + in any ebtables table. It records info about the frame header + to the syslog. + + To compile it as a module, choose M here. If unsure, say N. + +config BRIDGE_EBT_NFLOG + tristate "ebt: nflog support" + help + This option enables the nflog watcher, which allows to LOG + messages through the netfilter logging API, which can use + either the old LOG target, the old ULOG target or nfnetlink_log + as backend. + + This option adds the nflog watcher, that you can use in any rule + in any ebtables table. + + To compile it as a module, choose M here. If unsure, say N. + +endif # BRIDGE_NF_EBTABLES diff --git a/net/bridge/netfilter/Makefile b/net/bridge/netfilter/Makefile new file mode 100644 index 000000000..1c9ce49ab --- /dev/null +++ b/net/bridge/netfilter/Makefile @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the netfilter modules for Link Layer filtering on a bridge. +# + +obj-$(CONFIG_NFT_BRIDGE_META) += nft_meta_bridge.o +obj-$(CONFIG_NFT_BRIDGE_REJECT) += nft_reject_bridge.o + +# connection tracking +obj-$(CONFIG_NF_CONNTRACK_BRIDGE) += nf_conntrack_bridge.o + +obj-$(CONFIG_BRIDGE_NF_EBTABLES) += ebtables.o + +# tables +obj-$(CONFIG_BRIDGE_EBT_BROUTE) += ebtable_broute.o +obj-$(CONFIG_BRIDGE_EBT_T_FILTER) += ebtable_filter.o +obj-$(CONFIG_BRIDGE_EBT_T_NAT) += ebtable_nat.o + +#matches +obj-$(CONFIG_BRIDGE_EBT_802_3) += ebt_802_3.o +obj-$(CONFIG_BRIDGE_EBT_AMONG) += ebt_among.o +obj-$(CONFIG_BRIDGE_EBT_ARP) += ebt_arp.o +obj-$(CONFIG_BRIDGE_EBT_IP) += ebt_ip.o +obj-$(CONFIG_BRIDGE_EBT_IP6) += ebt_ip6.o +obj-$(CONFIG_BRIDGE_EBT_LIMIT) += ebt_limit.o +obj-$(CONFIG_BRIDGE_EBT_MARK) += ebt_mark_m.o +obj-$(CONFIG_BRIDGE_EBT_PKTTYPE) += ebt_pkttype.o +obj-$(CONFIG_BRIDGE_EBT_STP) += ebt_stp.o +obj-$(CONFIG_BRIDGE_EBT_VLAN) += ebt_vlan.o + +# targets +obj-$(CONFIG_BRIDGE_EBT_ARPREPLY) += ebt_arpreply.o +obj-$(CONFIG_BRIDGE_EBT_MARK_T) += ebt_mark.o +obj-$(CONFIG_BRIDGE_EBT_DNAT) += ebt_dnat.o +obj-$(CONFIG_BRIDGE_EBT_REDIRECT) += ebt_redirect.o +obj-$(CONFIG_BRIDGE_EBT_SNAT) += ebt_snat.o + +# watchers +obj-$(CONFIG_BRIDGE_EBT_LOG) += ebt_log.o +obj-$(CONFIG_BRIDGE_EBT_NFLOG) += ebt_nflog.o diff --git a/net/bridge/netfilter/ebt_802_3.c b/net/bridge/netfilter/ebt_802_3.c new file mode 100644 index 000000000..68c2519bd --- /dev/null +++ b/net/bridge/netfilter/ebt_802_3.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 802_3 + * + * Author: + * Chris Vitale csv@bluetail.com + * + * May 2003 + * + */ +#include +#include +#include +#include +#include + +static struct ebt_802_3_hdr *ebt_802_3_hdr(const struct sk_buff *skb) +{ + return (struct ebt_802_3_hdr *)skb_mac_header(skb); +} + +static bool +ebt_802_3_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_802_3_info *info = par->matchinfo; + const struct ebt_802_3_hdr *hdr = ebt_802_3_hdr(skb); + __be16 type = hdr->llc.ui.ctrl & IS_UI ? 
hdr->llc.ui.type : hdr->llc.ni.type; + + if (info->bitmask & EBT_802_3_SAP) { + if (NF_INVF(info, EBT_802_3_SAP, info->sap != hdr->llc.ui.ssap)) + return false; + if (NF_INVF(info, EBT_802_3_SAP, info->sap != hdr->llc.ui.dsap)) + return false; + } + + if (info->bitmask & EBT_802_3_TYPE) { + if (!(hdr->llc.ui.dsap == CHECK_TYPE && hdr->llc.ui.ssap == CHECK_TYPE)) + return false; + if (NF_INVF(info, EBT_802_3_TYPE, info->type != type)) + return false; + } + + return true; +} + +static int ebt_802_3_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_802_3_info *info = par->matchinfo; + + if (info->bitmask & ~EBT_802_3_MASK || info->invflags & ~EBT_802_3_MASK) + return -EINVAL; + + return 0; +} + +static struct xt_match ebt_802_3_mt_reg __read_mostly = { + .name = "802_3", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_802_3_mt, + .checkentry = ebt_802_3_mt_check, + .matchsize = sizeof(struct ebt_802_3_info), + .me = THIS_MODULE, +}; + +static int __init ebt_802_3_init(void) +{ + return xt_register_match(&ebt_802_3_mt_reg); +} + +static void __exit ebt_802_3_fini(void) +{ + xt_unregister_match(&ebt_802_3_mt_reg); +} + +module_init(ebt_802_3_init); +module_exit(ebt_802_3_fini); +MODULE_DESCRIPTION("Ebtables: DSAP/SSAP field and SNAP type matching"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_among.c b/net/bridge/netfilter/ebt_among.c new file mode 100644 index 000000000..96f7243b6 --- /dev/null +++ b/net/bridge/netfilter/ebt_among.c @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_among + * + * Authors: + * Grzegorz Borowiak + * + * August, 2003 + * + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +static bool ebt_mac_wormhash_contains(const struct ebt_mac_wormhash *wh, + const char *mac, __be32 ip) +{ + /* You may be puzzled as to how this code works. + * Some tricks were used, refer to + * include/linux/netfilter_bridge/ebt_among.h + * as there you can find a solution of this mystery. 
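ebt_802_3 above, like every ebt match that follows, runs its per-field comparisons through NF_INVF() so a user-supplied invert flag can flip the sense of each test. Below is a minimal standalone restatement of that pattern; the macro body is written from memory of the netfilter headers and is only meant to show the XOR-with-invert-bit idea.

/* "condition XOR (this field's invert bit is set)": a true result
 * means the field fails the (possibly inverted) test and the match
 * rejects the frame.
 */
#include <stdbool.h>
#include <stdio.h>

#define FLAG_SAP        0x01

struct match_info {
        unsigned int invflags;          /* which field tests are inverted */
        unsigned char sap;
};

#define MY_INVF(info, flag, cond) \
        ((bool)(cond) ^ !!((info)->invflags & (flag)))

int main(void)
{
        struct match_info info = { .invflags = FLAG_SAP, .sap = 0x42 };
        unsigned char pkt_sap = 0x42;

        /* The SAP matches, so the raw test (sap != pkt_sap) is false and
         * would pass; FLAG_SAP in invflags flips that into a reject.
         */
        if (MY_INVF(&info, FLAG_SAP, info.sap != pkt_sap))
                printf("rejected (inverted match)\n");
        else
                printf("accepted\n");
        return 0;
}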
+ */ + const struct ebt_mac_wormhash_tuple *p; + int start, limit, i; + uint32_t cmp[2] = { 0, 0 }; + int key = ((const unsigned char *)mac)[5]; + + ether_addr_copy(((char *) cmp) + 2, mac); + start = wh->table[key]; + limit = wh->table[key + 1]; + if (ip) { + for (i = start; i < limit; i++) { + p = &wh->pool[i]; + if (cmp[1] == p->cmp[1] && cmp[0] == p->cmp[0]) + if (p->ip == 0 || p->ip == ip) + return true; + } + } else { + for (i = start; i < limit; i++) { + p = &wh->pool[i]; + if (cmp[1] == p->cmp[1] && cmp[0] == p->cmp[0]) + if (p->ip == 0) + return true; + } + } + return false; +} + +static int ebt_mac_wormhash_check_integrity(const struct ebt_mac_wormhash + *wh) +{ + int i; + + for (i = 0; i < 256; i++) { + if (wh->table[i] > wh->table[i + 1]) + return -0x100 - i; + if (wh->table[i] < 0) + return -0x200 - i; + if (wh->table[i] > wh->poolsize) + return -0x300 - i; + } + if (wh->table[256] > wh->poolsize) + return -0xc00; + return 0; +} + +static int get_ip_dst(const struct sk_buff *skb, __be32 *addr) +{ + if (eth_hdr(skb)->h_proto == htons(ETH_P_IP)) { + const struct iphdr *ih; + struct iphdr _iph; + + ih = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (ih == NULL) + return -1; + *addr = ih->daddr; + } else if (eth_hdr(skb)->h_proto == htons(ETH_P_ARP)) { + const struct arphdr *ah; + struct arphdr _arph; + const __be32 *bp; + __be32 buf; + + ah = skb_header_pointer(skb, 0, sizeof(_arph), &_arph); + if (ah == NULL || + ah->ar_pln != sizeof(__be32) || + ah->ar_hln != ETH_ALEN) + return -1; + bp = skb_header_pointer(skb, sizeof(struct arphdr) + + 2 * ETH_ALEN + sizeof(__be32), + sizeof(__be32), &buf); + if (bp == NULL) + return -1; + *addr = *bp; + } + return 0; +} + +static int get_ip_src(const struct sk_buff *skb, __be32 *addr) +{ + if (eth_hdr(skb)->h_proto == htons(ETH_P_IP)) { + const struct iphdr *ih; + struct iphdr _iph; + + ih = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (ih == NULL) + return -1; + *addr = ih->saddr; + } else if (eth_hdr(skb)->h_proto == htons(ETH_P_ARP)) { + const struct arphdr *ah; + struct arphdr _arph; + const __be32 *bp; + __be32 buf; + + ah = skb_header_pointer(skb, 0, sizeof(_arph), &_arph); + if (ah == NULL || + ah->ar_pln != sizeof(__be32) || + ah->ar_hln != ETH_ALEN) + return -1; + bp = skb_header_pointer(skb, sizeof(struct arphdr) + + ETH_ALEN, sizeof(__be32), &buf); + if (bp == NULL) + return -1; + *addr = *bp; + } + return 0; +} + +static bool +ebt_among_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_among_info *info = par->matchinfo; + const char *dmac, *smac; + const struct ebt_mac_wormhash *wh_dst, *wh_src; + __be32 dip = 0, sip = 0; + + wh_dst = ebt_among_wh_dst(info); + wh_src = ebt_among_wh_src(info); + + if (wh_src) { + smac = eth_hdr(skb)->h_source; + if (get_ip_src(skb, &sip)) + return false; + if (!(info->bitmask & EBT_AMONG_SRC_NEG)) { + /* we match only if it contains */ + if (!ebt_mac_wormhash_contains(wh_src, smac, sip)) + return false; + } else { + /* we match only if it DOES NOT contain */ + if (ebt_mac_wormhash_contains(wh_src, smac, sip)) + return false; + } + } + + if (wh_dst) { + dmac = eth_hdr(skb)->h_dest; + if (get_ip_dst(skb, &dip)) + return false; + if (!(info->bitmask & EBT_AMONG_DST_NEG)) { + /* we match only if it contains */ + if (!ebt_mac_wormhash_contains(wh_dst, dmac, dip)) + return false; + } else { + /* we match only if it DOES NOT contain */ + if (ebt_mac_wormhash_contains(wh_dst, dmac, dip)) + return false; + } + } + + return true; +} + +static bool 
poolsize_invalid(const struct ebt_mac_wormhash *w) +{ + return w && w->poolsize >= (INT_MAX / sizeof(struct ebt_mac_wormhash_tuple)); +} + +static bool wormhash_offset_invalid(int off, unsigned int len) +{ + if (off == 0) /* not present */ + return false; + + if (off < (int)sizeof(struct ebt_among_info) || + off % __alignof__(struct ebt_mac_wormhash)) + return true; + + off += sizeof(struct ebt_mac_wormhash); + + return off > len; +} + +static bool wormhash_sizes_valid(const struct ebt_mac_wormhash *wh, int a, int b) +{ + if (a == 0) + a = sizeof(struct ebt_among_info); + + return ebt_mac_wormhash_size(wh) + a == b; +} + +static int ebt_among_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_among_info *info = par->matchinfo; + const struct ebt_entry_match *em = + container_of(par->matchinfo, const struct ebt_entry_match, data); + unsigned int expected_length = sizeof(struct ebt_among_info); + const struct ebt_mac_wormhash *wh_dst, *wh_src; + int err; + + if (expected_length > em->match_size) + return -EINVAL; + + if (wormhash_offset_invalid(info->wh_dst_ofs, em->match_size) || + wormhash_offset_invalid(info->wh_src_ofs, em->match_size)) + return -EINVAL; + + wh_dst = ebt_among_wh_dst(info); + if (poolsize_invalid(wh_dst)) + return -EINVAL; + + expected_length += ebt_mac_wormhash_size(wh_dst); + if (expected_length > em->match_size) + return -EINVAL; + + wh_src = ebt_among_wh_src(info); + if (poolsize_invalid(wh_src)) + return -EINVAL; + + if (info->wh_src_ofs < info->wh_dst_ofs) { + if (!wormhash_sizes_valid(wh_src, info->wh_src_ofs, info->wh_dst_ofs)) + return -EINVAL; + } else { + if (!wormhash_sizes_valid(wh_dst, info->wh_dst_ofs, info->wh_src_ofs)) + return -EINVAL; + } + + expected_length += ebt_mac_wormhash_size(wh_src); + + if (em->match_size != EBT_ALIGN(expected_length)) { + pr_err_ratelimited("wrong size: %d against expected %d, rounded to %zd\n", + em->match_size, expected_length, + EBT_ALIGN(expected_length)); + return -EINVAL; + } + if (wh_dst && (err = ebt_mac_wormhash_check_integrity(wh_dst))) { + pr_err_ratelimited("dst integrity fail: %x\n", -err); + return -EINVAL; + } + if (wh_src && (err = ebt_mac_wormhash_check_integrity(wh_src))) { + pr_err_ratelimited("src integrity fail: %x\n", -err); + return -EINVAL; + } + return 0; +} + +static struct xt_match ebt_among_mt_reg __read_mostly = { + .name = "among", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_among_mt, + .checkentry = ebt_among_mt_check, + .matchsize = -1, /* special case */ + .me = THIS_MODULE, +}; + +static int __init ebt_among_init(void) +{ + return xt_register_match(&ebt_among_mt_reg); +} + +static void __exit ebt_among_fini(void) +{ + xt_unregister_match(&ebt_among_mt_reg); +} + +module_init(ebt_among_init); +module_exit(ebt_among_fini); +MODULE_DESCRIPTION("Ebtables: Combined MAC/IP address list matching"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_arp.c b/net/bridge/netfilter/ebt_arp.c new file mode 100644 index 000000000..0707cc00f --- /dev/null +++ b/net/bridge/netfilter/ebt_arp.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_arp + * + * Authors: + * Bart De Schuymer + * Tim Gardner + * + * April, 2002 + * + */ +#include +#include +#include +#include +#include +#include + +static bool +ebt_arp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_arp_info *info = par->matchinfo; + const struct arphdr *ah; + struct arphdr _arph; + + ah = skb_header_pointer(skb, 0, sizeof(_arph), &_arph); + if (ah == NULL) 
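ebt_mac_wormhash_contains() above buckets the pool by the last byte of the MAC: table[] is a 257-entry prefix-sum array, so bucket k occupies pool[table[k] .. table[k+1]). A standalone sketch of that layout together with the monotonicity and bounds invariants that ebt_mac_wormhash_check_integrity() enforces:

/* 257-entry prefix-sum table indexed by mac[5] delimits each bucket
 * inside a flat pool; data below is a made-up two-entry example.
 */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <stdio.h>

struct tuple { uint8_t mac[6]; uint32_t ip; };

struct wormhash {
        int table[257];         /* bucket k = pool[table[k] .. table[k+1]) */
        int poolsize;
        const struct tuple *pool;
};

static bool wh_sane(const struct wormhash *wh)
{
        for (int i = 0; i < 256; i++)
                if (wh->table[i] < 0 || wh->table[i] > wh->table[i + 1] ||
                    wh->table[i] > wh->poolsize)
                        return false;
        return wh->table[256] <= wh->poolsize;
}

static bool wh_contains(const struct wormhash *wh, const uint8_t *mac, uint32_t ip)
{
        int key = mac[5];

        for (int i = wh->table[key]; i < wh->table[key + 1]; i++)
                if (memcmp(wh->pool[i].mac, mac, 6) == 0 &&
                    (wh->pool[i].ip == 0 || wh->pool[i].ip == ip))
                        return true;
        return false;
}

int main(void)
{
        /* two entries, both with mac[5] == 0x01, i.e. bucket 1 */
        static const struct tuple pool[] = {
                { { 0x02, 0, 0, 0, 0, 0x01 }, 0 },
                { { 0x04, 0, 0, 0, 0, 0x01 }, 0x0a000001 },
        };
        struct wormhash wh = { .poolsize = 2, .pool = pool };

        wh.table[1] = 0;                /* bucket 1 spans pool[0..2) */
        for (int k = 2; k <= 256; k++)
                wh.table[k] = 2;

        const uint8_t mac[6] = { 0x04, 0, 0, 0, 0, 0x01 };
        printf("sane=%d contains=%d\n",
               wh_sane(&wh), wh_contains(&wh, mac, 0x0a000001));
        return 0;
}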
+ return false; + if ((info->bitmask & EBT_ARP_OPCODE) && + NF_INVF(info, EBT_ARP_OPCODE, info->opcode != ah->ar_op)) + return false; + if ((info->bitmask & EBT_ARP_HTYPE) && + NF_INVF(info, EBT_ARP_HTYPE, info->htype != ah->ar_hrd)) + return false; + if ((info->bitmask & EBT_ARP_PTYPE) && + NF_INVF(info, EBT_ARP_PTYPE, info->ptype != ah->ar_pro)) + return false; + + if (info->bitmask & (EBT_ARP_SRC_IP | EBT_ARP_DST_IP | EBT_ARP_GRAT)) { + const __be32 *sap, *dap; + __be32 saddr, daddr; + + if (ah->ar_pln != sizeof(__be32) || ah->ar_pro != htons(ETH_P_IP)) + return false; + sap = skb_header_pointer(skb, sizeof(struct arphdr) + + ah->ar_hln, sizeof(saddr), + &saddr); + if (sap == NULL) + return false; + dap = skb_header_pointer(skb, sizeof(struct arphdr) + + 2*ah->ar_hln+sizeof(saddr), + sizeof(daddr), &daddr); + if (dap == NULL) + return false; + if ((info->bitmask & EBT_ARP_SRC_IP) && + NF_INVF(info, EBT_ARP_SRC_IP, + info->saddr != (*sap & info->smsk))) + return false; + if ((info->bitmask & EBT_ARP_DST_IP) && + NF_INVF(info, EBT_ARP_DST_IP, + info->daddr != (*dap & info->dmsk))) + return false; + if ((info->bitmask & EBT_ARP_GRAT) && + NF_INVF(info, EBT_ARP_GRAT, *dap != *sap)) + return false; + } + + if (info->bitmask & (EBT_ARP_SRC_MAC | EBT_ARP_DST_MAC)) { + const unsigned char *mp; + unsigned char _mac[ETH_ALEN]; + + if (ah->ar_hln != ETH_ALEN || ah->ar_hrd != htons(ARPHRD_ETHER)) + return false; + if (info->bitmask & EBT_ARP_SRC_MAC) { + mp = skb_header_pointer(skb, sizeof(struct arphdr), + sizeof(_mac), &_mac); + if (mp == NULL) + return false; + if (NF_INVF(info, EBT_ARP_SRC_MAC, + !ether_addr_equal_masked(mp, info->smaddr, + info->smmsk))) + return false; + } + + if (info->bitmask & EBT_ARP_DST_MAC) { + mp = skb_header_pointer(skb, sizeof(struct arphdr) + + ah->ar_hln + ah->ar_pln, + sizeof(_mac), &_mac); + if (mp == NULL) + return false; + if (NF_INVF(info, EBT_ARP_DST_MAC, + !ether_addr_equal_masked(mp, info->dmaddr, + info->dmmsk))) + return false; + } + } + + return true; +} + +static int ebt_arp_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_arp_info *info = par->matchinfo; + const struct ebt_entry *e = par->entryinfo; + + if ((e->ethproto != htons(ETH_P_ARP) && + e->ethproto != htons(ETH_P_RARP)) || + e->invflags & EBT_IPROTO) + return -EINVAL; + if (info->bitmask & ~EBT_ARP_MASK || info->invflags & ~EBT_ARP_MASK) + return -EINVAL; + return 0; +} + +static struct xt_match ebt_arp_mt_reg __read_mostly = { + .name = "arp", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_arp_mt, + .checkentry = ebt_arp_mt_check, + .matchsize = sizeof(struct ebt_arp_info), + .me = THIS_MODULE, +}; + +static int __init ebt_arp_init(void) +{ + return xt_register_match(&ebt_arp_mt_reg); +} + +static void __exit ebt_arp_fini(void) +{ + xt_unregister_match(&ebt_arp_mt_reg); +} + +module_init(ebt_arp_init); +module_exit(ebt_arp_fini); +MODULE_DESCRIPTION("Ebtables: ARP protocol packet match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_arpreply.c b/net/bridge/netfilter/ebt_arpreply.c new file mode 100644 index 000000000..d9e77e250 --- /dev/null +++ b/net/bridge/netfilter/ebt_arpreply.c @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_arpreply + * + * Authors: + * Grzegorz Borowiak + * Bart De Schuymer + * + * August, 2003 + * + */ +#include +#include +#include +#include +#include +#include + +static unsigned int +ebt_arpreply_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_arpreply_info *info = 
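ebt_arp above never assumes a fixed ARP layout; it derives the sender/target address offsets from ar_hln and ar_pln before each skb_header_pointer() read. The arithmetic for the Ethernet/IPv4 case, worked as a standalone sketch:

/* ARP payload offsets: after the fixed 8-byte arphdr come sender MAC,
 * sender IP, target MAC, target IP, whose positions depend on
 * ar_hln (hardware address length) and ar_pln (protocol address length).
 */
#include <stdio.h>

int main(void)
{
        const unsigned arphdr_len = 8;  /* ar_hrd .. ar_op */
        const unsigned hln = 6;         /* Ethernet hardware address */
        const unsigned pln = 4;         /* IPv4 protocol address */

        unsigned sender_mac = arphdr_len;
        unsigned sender_ip  = arphdr_len + hln;
        unsigned target_mac = arphdr_len + hln + pln;
        unsigned target_ip  = arphdr_len + 2 * hln + pln;

        printf("sha@%u spa@%u tha@%u tpa@%u\n",
               sender_mac, sender_ip, target_mac, target_ip);   /* 8 14 18 24 */
        return 0;
}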
par->targinfo; + const __be32 *siptr, *diptr; + __be32 _sip, _dip; + const struct arphdr *ap; + struct arphdr _ah; + const unsigned char *shp; + unsigned char _sha[ETH_ALEN]; + + ap = skb_header_pointer(skb, 0, sizeof(_ah), &_ah); + if (ap == NULL) + return EBT_DROP; + + if (ap->ar_op != htons(ARPOP_REQUEST) || + ap->ar_hln != ETH_ALEN || + ap->ar_pro != htons(ETH_P_IP) || + ap->ar_pln != 4) + return EBT_CONTINUE; + + shp = skb_header_pointer(skb, sizeof(_ah), ETH_ALEN, &_sha); + if (shp == NULL) + return EBT_DROP; + + siptr = skb_header_pointer(skb, sizeof(_ah) + ETH_ALEN, + sizeof(_sip), &_sip); + if (siptr == NULL) + return EBT_DROP; + + diptr = skb_header_pointer(skb, + sizeof(_ah) + 2 * ETH_ALEN + sizeof(_sip), + sizeof(_dip), &_dip); + if (diptr == NULL) + return EBT_DROP; + + arp_send(ARPOP_REPLY, ETH_P_ARP, *siptr, + (struct net_device *)xt_in(par), + *diptr, shp, info->mac, shp); + + return info->target; +} + +static int ebt_arpreply_tg_check(const struct xt_tgchk_param *par) +{ + const struct ebt_arpreply_info *info = par->targinfo; + const struct ebt_entry *e = par->entryinfo; + + if (BASE_CHAIN && info->target == EBT_RETURN) + return -EINVAL; + if (e->ethproto != htons(ETH_P_ARP) || + e->invflags & EBT_IPROTO) + return -EINVAL; + if (ebt_invalid_target(info->target)) + return -EINVAL; + + return 0; +} + +static struct xt_target ebt_arpreply_tg_reg __read_mostly = { + .name = "arpreply", + .revision = 0, + .family = NFPROTO_BRIDGE, + .table = "nat", + .hooks = (1 << NF_BR_NUMHOOKS) | (1 << NF_BR_PRE_ROUTING), + .target = ebt_arpreply_tg, + .checkentry = ebt_arpreply_tg_check, + .targetsize = sizeof(struct ebt_arpreply_info), + .me = THIS_MODULE, +}; + +static int __init ebt_arpreply_init(void) +{ + return xt_register_target(&ebt_arpreply_tg_reg); +} + +static void __exit ebt_arpreply_fini(void) +{ + xt_unregister_target(&ebt_arpreply_tg_reg); +} + +module_init(ebt_arpreply_init); +module_exit(ebt_arpreply_fini); +MODULE_DESCRIPTION("Ebtables: ARP reply target"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_dnat.c b/net/bridge/netfilter/ebt_dnat.c new file mode 100644 index 000000000..3fda71a85 --- /dev/null +++ b/net/bridge/netfilter/ebt_dnat.c @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_dnat + * + * Authors: + * Bart De Schuymer + * + * June, 2002 + * + */ +#include +#include +#include "../br_private.h" +#include +#include +#include +#include + +static unsigned int +ebt_dnat_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_nat_info *info = par->targinfo; + + if (skb_ensure_writable(skb, 0)) + return EBT_DROP; + + ether_addr_copy(eth_hdr(skb)->h_dest, info->mac); + + if (is_multicast_ether_addr(info->mac)) { + if (is_broadcast_ether_addr(info->mac)) + skb->pkt_type = PACKET_BROADCAST; + else + skb->pkt_type = PACKET_MULTICAST; + } else { + const struct net_device *dev; + + switch (xt_hooknum(par)) { + case NF_BR_BROUTING: + dev = xt_in(par); + break; + case NF_BR_PRE_ROUTING: + dev = br_port_get_rcu(xt_in(par))->br->dev; + break; + default: + dev = NULL; + break; + } + + if (!dev) /* NF_BR_LOCAL_OUT */ + return info->target; + + if (ether_addr_equal(info->mac, dev->dev_addr)) + skb->pkt_type = PACKET_HOST; + else + skb->pkt_type = PACKET_OTHERHOST; + } + + return info->target; +} + +static int ebt_dnat_tg_check(const struct xt_tgchk_param *par) +{ + const struct ebt_nat_info *info = par->targinfo; + unsigned int hook_mask; + + if (BASE_CHAIN && info->target == EBT_RETURN) + return -EINVAL; + + hook_mask 
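After rewriting the destination MAC, ebt_dnat re-derives skb->pkt_type from the new address: broadcast, multicast, for-this-host or for-another-host. A standalone restatement of that classification, with the broadcast/multicast helpers written out by hand:

/* Classify a destination MAC the way ebt_dnat re-derives pkt_type. */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <stdio.h>

enum pkt_type { PKT_HOST, PKT_BROADCAST, PKT_MULTICAST, PKT_OTHERHOST };

static bool is_broadcast(const uint8_t *mac)
{
        static const uint8_t bcast[6] = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff };
        return memcmp(mac, bcast, 6) == 0;
}

static bool is_multicast(const uint8_t *mac)
{
        return mac[0] & 0x01;   /* group bit; also set for broadcast */
}

static enum pkt_type classify(const uint8_t *new_dst, const uint8_t *dev_addr)
{
        if (is_multicast(new_dst))
                return is_broadcast(new_dst) ? PKT_BROADCAST : PKT_MULTICAST;
        return memcmp(new_dst, dev_addr, 6) == 0 ? PKT_HOST : PKT_OTHERHOST;
}

int main(void)
{
        const uint8_t dev[6] = { 0x02, 0, 0, 0, 0, 0x10 };
        const uint8_t dst[6] = { 0x01, 0x00, 0x5e, 0x00, 0x00, 0x01 };

        printf("type=%d\n", classify(dst, dev));        /* PKT_MULTICAST */
        return 0;
}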
= par->hook_mask & ~(1 << NF_BR_NUMHOOKS); + if ((strcmp(par->table, "nat") != 0 || + (hook_mask & ~((1 << NF_BR_PRE_ROUTING) | + (1 << NF_BR_LOCAL_OUT)))) && + (strcmp(par->table, "broute") != 0 || + hook_mask & ~(1 << NF_BR_BROUTING))) + return -EINVAL; + if (ebt_invalid_target(info->target)) + return -EINVAL; + return 0; +} + +static struct xt_target ebt_dnat_tg_reg __read_mostly = { + .name = "dnat", + .revision = 0, + .family = NFPROTO_BRIDGE, + .hooks = (1 << NF_BR_NUMHOOKS) | (1 << NF_BR_PRE_ROUTING) | + (1 << NF_BR_LOCAL_OUT) | (1 << NF_BR_BROUTING), + .target = ebt_dnat_tg, + .checkentry = ebt_dnat_tg_check, + .targetsize = sizeof(struct ebt_nat_info), + .me = THIS_MODULE, +}; + +static int __init ebt_dnat_init(void) +{ + return xt_register_target(&ebt_dnat_tg_reg); +} + +static void __exit ebt_dnat_fini(void) +{ + xt_unregister_target(&ebt_dnat_tg_reg); +} + +module_init(ebt_dnat_init); +module_exit(ebt_dnat_fini); +MODULE_DESCRIPTION("Ebtables: Destination MAC address translation"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_ip.c b/net/bridge/netfilter/ebt_ip.c new file mode 100644 index 000000000..df372496c --- /dev/null +++ b/net/bridge/netfilter/ebt_ip.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_ip + * + * Authors: + * Bart De Schuymer + * + * April, 2002 + * + * Changes: + * added ip-sport and ip-dport + * Innominate Security Technologies AG + * September, 2002 + */ +#include +#include +#include +#include +#include +#include +#include + +union pkthdr { + struct { + __be16 src; + __be16 dst; + } tcpudphdr; + struct { + u8 type; + u8 code; + } icmphdr; + struct { + u8 type; + } igmphdr; +}; + +static bool +ebt_ip_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_ip_info *info = par->matchinfo; + const struct iphdr *ih; + struct iphdr _iph; + const union pkthdr *pptr; + union pkthdr _pkthdr; + + ih = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (ih == NULL) + return false; + if ((info->bitmask & EBT_IP_TOS) && + NF_INVF(info, EBT_IP_TOS, info->tos != ih->tos)) + return false; + if ((info->bitmask & EBT_IP_SOURCE) && + NF_INVF(info, EBT_IP_SOURCE, + (ih->saddr & info->smsk) != info->saddr)) + return false; + if ((info->bitmask & EBT_IP_DEST) && + NF_INVF(info, EBT_IP_DEST, + (ih->daddr & info->dmsk) != info->daddr)) + return false; + if (info->bitmask & EBT_IP_PROTO) { + if (NF_INVF(info, EBT_IP_PROTO, info->protocol != ih->protocol)) + return false; + if (!(info->bitmask & (EBT_IP_DPORT | EBT_IP_SPORT | + EBT_IP_ICMP | EBT_IP_IGMP))) + return true; + if (ntohs(ih->frag_off) & IP_OFFSET) + return false; + + /* min icmp/igmp headersize is 4, so sizeof(_pkthdr) is ok. 
*/ + pptr = skb_header_pointer(skb, ih->ihl*4, + sizeof(_pkthdr), &_pkthdr); + if (pptr == NULL) + return false; + if (info->bitmask & EBT_IP_DPORT) { + u32 dst = ntohs(pptr->tcpudphdr.dst); + if (NF_INVF(info, EBT_IP_DPORT, + dst < info->dport[0] || + dst > info->dport[1])) + return false; + } + if (info->bitmask & EBT_IP_SPORT) { + u32 src = ntohs(pptr->tcpudphdr.src); + if (NF_INVF(info, EBT_IP_SPORT, + src < info->sport[0] || + src > info->sport[1])) + return false; + } + if ((info->bitmask & EBT_IP_ICMP) && + NF_INVF(info, EBT_IP_ICMP, + pptr->icmphdr.type < info->icmp_type[0] || + pptr->icmphdr.type > info->icmp_type[1] || + pptr->icmphdr.code < info->icmp_code[0] || + pptr->icmphdr.code > info->icmp_code[1])) + return false; + if ((info->bitmask & EBT_IP_IGMP) && + NF_INVF(info, EBT_IP_IGMP, + pptr->igmphdr.type < info->igmp_type[0] || + pptr->igmphdr.type > info->igmp_type[1])) + return false; + } + return true; +} + +static int ebt_ip_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_ip_info *info = par->matchinfo; + const struct ebt_entry *e = par->entryinfo; + + if (e->ethproto != htons(ETH_P_IP) || + e->invflags & EBT_IPROTO) + return -EINVAL; + if (info->bitmask & ~EBT_IP_MASK || info->invflags & ~EBT_IP_MASK) + return -EINVAL; + if (info->bitmask & (EBT_IP_DPORT | EBT_IP_SPORT)) { + if (info->invflags & EBT_IP_PROTO) + return -EINVAL; + if (info->protocol != IPPROTO_TCP && + info->protocol != IPPROTO_UDP && + info->protocol != IPPROTO_UDPLITE && + info->protocol != IPPROTO_SCTP && + info->protocol != IPPROTO_DCCP) + return -EINVAL; + } + if (info->bitmask & EBT_IP_DPORT && info->dport[0] > info->dport[1]) + return -EINVAL; + if (info->bitmask & EBT_IP_SPORT && info->sport[0] > info->sport[1]) + return -EINVAL; + if (info->bitmask & EBT_IP_ICMP) { + if ((info->invflags & EBT_IP_PROTO) || + info->protocol != IPPROTO_ICMP) + return -EINVAL; + if (info->icmp_type[0] > info->icmp_type[1] || + info->icmp_code[0] > info->icmp_code[1]) + return -EINVAL; + } + if (info->bitmask & EBT_IP_IGMP) { + if ((info->invflags & EBT_IP_PROTO) || + info->protocol != IPPROTO_IGMP) + return -EINVAL; + if (info->igmp_type[0] > info->igmp_type[1]) + return -EINVAL; + } + return 0; +} + +static struct xt_match ebt_ip_mt_reg __read_mostly = { + .name = "ip", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_ip_mt, + .checkentry = ebt_ip_mt_check, + .matchsize = sizeof(struct ebt_ip_info), + .me = THIS_MODULE, +}; + +static int __init ebt_ip_init(void) +{ + return xt_register_match(&ebt_ip_mt_reg); +} + +static void __exit ebt_ip_fini(void) +{ + xt_unregister_match(&ebt_ip_mt_reg); +} + +module_init(ebt_ip_init); +module_exit(ebt_ip_fini); +MODULE_DESCRIPTION("Ebtables: IPv4 protocol packet match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_ip6.c b/net/bridge/netfilter/ebt_ip6.c new file mode 100644 index 000000000..f3225bc31 --- /dev/null +++ b/net/bridge/netfilter/ebt_ip6.c @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_ip6 + * + * Authors: + * Manohar Castelino + * Kuo-Lang Tseng + * Jan Engelhardt + * + * Summary: + * This is just a modification of the IPv4 code written by + * Bart De Schuymer + * with the changes required to support IPv6 + * + * Jan, 2008 + */ +#include +#include +#include +#include +#include +#include +#include +#include + +union pkthdr { + struct { + __be16 src; + __be16 dst; + } tcpudphdr; + struct { + u8 type; + u8 code; + } icmphdr; +}; + +static bool +ebt_ip6_mt(const struct sk_buff *skb, struct 
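Before it will look at ports or ICMP/IGMP types, ebt_ip above insists the packet is a first fragment and locates the transport header at ihl * 4. Both preconditions on a raw IPv4 header, as a standalone sketch (the header bytes are a made-up example):

/* First-fragment check plus transport-header offset on a raw IPv4 header. */
#include <stdint.h>
#include <stdio.h>

#define IP_OFFSET 0x1FFF        /* fragment offset mask, host order */

int main(void)
{
        const uint8_t hdr[20] = {
                0x45, 0x00, 0x00, 0x28,  /* version=4, ihl=5, tot_len=40 */
                0x12, 0x34, 0x00, 0x00,  /* id, frag_off=0 (first fragment) */
                0x40, 0x06, 0x00, 0x00,  /* ttl, proto=6 (TCP), checksum */
                10, 0, 0, 1, 10, 0, 0, 2,
        };
        unsigned ihl = hdr[0] & 0x0f;
        unsigned frag_off = (((unsigned)hdr[6] << 8) | hdr[7]) & IP_OFFSET;

        if (frag_off) {
                printf("non-first fragment: no transport header here\n");
                return 0;
        }
        printf("proto=%u, transport header at offset %u\n",
               (unsigned)hdr[9], ihl * 4);      /* proto=6, offset 20 */
        return 0;
}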
xt_action_param *par) +{ + const struct ebt_ip6_info *info = par->matchinfo; + const struct ipv6hdr *ih6; + struct ipv6hdr _ip6h; + const union pkthdr *pptr; + union pkthdr _pkthdr; + + ih6 = skb_header_pointer(skb, 0, sizeof(_ip6h), &_ip6h); + if (ih6 == NULL) + return false; + if ((info->bitmask & EBT_IP6_TCLASS) && + NF_INVF(info, EBT_IP6_TCLASS, + info->tclass != ipv6_get_dsfield(ih6))) + return false; + if (((info->bitmask & EBT_IP6_SOURCE) && + NF_INVF(info, EBT_IP6_SOURCE, + ipv6_masked_addr_cmp(&ih6->saddr, &info->smsk, + &info->saddr))) || + ((info->bitmask & EBT_IP6_DEST) && + NF_INVF(info, EBT_IP6_DEST, + ipv6_masked_addr_cmp(&ih6->daddr, &info->dmsk, + &info->daddr)))) + return false; + if (info->bitmask & EBT_IP6_PROTO) { + uint8_t nexthdr = ih6->nexthdr; + __be16 frag_off; + int offset_ph; + + offset_ph = ipv6_skip_exthdr(skb, sizeof(_ip6h), &nexthdr, &frag_off); + if (offset_ph == -1) + return false; + if (NF_INVF(info, EBT_IP6_PROTO, info->protocol != nexthdr)) + return false; + if (!(info->bitmask & (EBT_IP6_DPORT | + EBT_IP6_SPORT | EBT_IP6_ICMP6))) + return true; + + /* min icmpv6 headersize is 4, so sizeof(_pkthdr) is ok. */ + pptr = skb_header_pointer(skb, offset_ph, sizeof(_pkthdr), + &_pkthdr); + if (pptr == NULL) + return false; + if (info->bitmask & EBT_IP6_DPORT) { + u16 dst = ntohs(pptr->tcpudphdr.dst); + if (NF_INVF(info, EBT_IP6_DPORT, + dst < info->dport[0] || + dst > info->dport[1])) + return false; + } + if (info->bitmask & EBT_IP6_SPORT) { + u16 src = ntohs(pptr->tcpudphdr.src); + if (NF_INVF(info, EBT_IP6_SPORT, + src < info->sport[0] || + src > info->sport[1])) + return false; + } + if ((info->bitmask & EBT_IP6_ICMP6) && + NF_INVF(info, EBT_IP6_ICMP6, + pptr->icmphdr.type < info->icmpv6_type[0] || + pptr->icmphdr.type > info->icmpv6_type[1] || + pptr->icmphdr.code < info->icmpv6_code[0] || + pptr->icmphdr.code > info->icmpv6_code[1])) + return false; + } + return true; +} + +static int ebt_ip6_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_entry *e = par->entryinfo; + struct ebt_ip6_info *info = par->matchinfo; + + if (e->ethproto != htons(ETH_P_IPV6) || e->invflags & EBT_IPROTO) + return -EINVAL; + if (info->bitmask & ~EBT_IP6_MASK || info->invflags & ~EBT_IP6_MASK) + return -EINVAL; + if (info->bitmask & (EBT_IP6_DPORT | EBT_IP6_SPORT)) { + if (info->invflags & EBT_IP6_PROTO) + return -EINVAL; + if (info->protocol != IPPROTO_TCP && + info->protocol != IPPROTO_UDP && + info->protocol != IPPROTO_UDPLITE && + info->protocol != IPPROTO_SCTP && + info->protocol != IPPROTO_DCCP) + return -EINVAL; + } + if (info->bitmask & EBT_IP6_DPORT && info->dport[0] > info->dport[1]) + return -EINVAL; + if (info->bitmask & EBT_IP6_SPORT && info->sport[0] > info->sport[1]) + return -EINVAL; + if (info->bitmask & EBT_IP6_ICMP6) { + if ((info->invflags & EBT_IP6_PROTO) || + info->protocol != IPPROTO_ICMPV6) + return -EINVAL; + if (info->icmpv6_type[0] > info->icmpv6_type[1] || + info->icmpv6_code[0] > info->icmpv6_code[1]) + return -EINVAL; + } + return 0; +} + +static struct xt_match ebt_ip6_mt_reg __read_mostly = { + .name = "ip6", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_ip6_mt, + .checkentry = ebt_ip6_mt_check, + .matchsize = sizeof(struct ebt_ip6_info), + .me = THIS_MODULE, +}; + +static int __init ebt_ip6_init(void) +{ + return xt_register_match(&ebt_ip6_mt_reg); +} + +static void __exit ebt_ip6_fini(void) +{ + xt_unregister_match(&ebt_ip6_mt_reg); +} + +module_init(ebt_ip6_init); +module_exit(ebt_ip6_fini); 
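ebt_ip6 leans on ipv6_skip_exthdr() to hop over any extension headers before comparing protocol and ports. The sketch below is a deliberately simplified walk that only understands the common extension headers; the real helper also handles AH sizing and fragment offsets.

/* Simplified extension-header walk: follow nexthdr through hop-by-hop
 * (0), routing (43), fragment (44) and destination options (60) until
 * a transport protocol is reached, returning its offset.
 */
#include <stdint.h>
#include <stdio.h>

static int skip_exthdr(const uint8_t *pkt, unsigned len, unsigned off,
                       uint8_t *nexthdr)
{
        for (;;) {
                switch (*nexthdr) {
                case 0: case 43: case 44: case 60:
                        if (off + 8 > len)
                                return -1;              /* truncated packet */
                        *nexthdr = pkt[off];
                        if (*nexthdr == 44 || pkt[off] == 0)
                                ;                       /* see length rule below */
                        off += (pkt[off - 0] == pkt[off]) /* keep off unchanged */
                               ? ((pkt[off + 1] + 1u) * 8)
                               : 8;
                        break;
                default:
                        return (int)off;                /* transport header offset */
                }
        }
}

int main(void)
{
        /* 40-byte IPv6 header (only its nexthdr matters here) followed by
         * one 8-byte hop-by-hop header that says the payload is TCP (6).
         */
        uint8_t pkt[48] = { 0 };
        uint8_t nexthdr = 0;            /* IPv6 header: hop-by-hop next */

        pkt[40] = 6;                    /* hop-by-hop: next header = TCP */
        pkt[41] = 0;                    /* hdrlen = 0 -> 8 bytes total */

        int off = skip_exthdr(pkt, sizeof(pkt), 40, &nexthdr);
        printf("transport proto %u at offset %d\n", nexthdr, off);
        return 0;
}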
+MODULE_DESCRIPTION("Ebtables: IPv6 protocol packet match"); +MODULE_AUTHOR("Kuo-Lang Tseng "); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_limit.c b/net/bridge/netfilter/ebt_limit.c new file mode 100644 index 000000000..e16183bd1 --- /dev/null +++ b/net/bridge/netfilter/ebt_limit.c @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_limit + * + * Authors: + * Tom Marshall + * + * Mostly copied from netfilter's ipt_limit.c, see that file for + * more explanation + * + * September, 2003 + * + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(limit_lock); + +#define MAX_CPJ (0xFFFFFFFF / (HZ*60*60*24)) + +#define _POW2_BELOW2(x) ((x)|((x)>>1)) +#define _POW2_BELOW4(x) (_POW2_BELOW2(x)|_POW2_BELOW2((x)>>2)) +#define _POW2_BELOW8(x) (_POW2_BELOW4(x)|_POW2_BELOW4((x)>>4)) +#define _POW2_BELOW16(x) (_POW2_BELOW8(x)|_POW2_BELOW8((x)>>8)) +#define _POW2_BELOW32(x) (_POW2_BELOW16(x)|_POW2_BELOW16((x)>>16)) +#define POW2_BELOW32(x) ((_POW2_BELOW32(x)>>1) + 1) + +#define CREDITS_PER_JIFFY POW2_BELOW32(MAX_CPJ) + +static bool +ebt_limit_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ebt_limit_info *info = (void *)par->matchinfo; + unsigned long now = jiffies; + + spin_lock_bh(&limit_lock); + info->credit += (now - xchg(&info->prev, now)) * CREDITS_PER_JIFFY; + if (info->credit > info->credit_cap) + info->credit = info->credit_cap; + + if (info->credit >= info->cost) { + /* We're not limited. */ + info->credit -= info->cost; + spin_unlock_bh(&limit_lock); + return true; + } + + spin_unlock_bh(&limit_lock); + return false; +} + +/* Precision saver. */ +static u_int32_t +user2credits(u_int32_t user) +{ + /* If multiplying would overflow... */ + if (user > 0xFFFFFFFF / (HZ*CREDITS_PER_JIFFY)) + /* Divide first. */ + return (user / EBT_LIMIT_SCALE) * HZ * CREDITS_PER_JIFFY; + + return (user * HZ * CREDITS_PER_JIFFY) / EBT_LIMIT_SCALE; +} + +static int ebt_limit_mt_check(const struct xt_mtchk_param *par) +{ + struct ebt_limit_info *info = par->matchinfo; + + /* Check for overflow. */ + if (info->burst == 0 || + user2credits(info->avg * info->burst) < user2credits(info->avg)) { + pr_info_ratelimited("overflow, try lower: %u/%u\n", + info->avg, info->burst); + return -EINVAL; + } + + /* User avg in seconds * EBT_LIMIT_SCALE: convert to jiffies * 128. */ + info->prev = jiffies; + info->credit = user2credits(info->avg * info->burst); + info->credit_cap = user2credits(info->avg * info->burst); + info->cost = user2credits(info->avg); + return 0; +} + + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +/* + * no conversion function needed -- + * only avg/burst have meaningful values in userspace. 
+ */ +struct ebt_compat_limit_info { + compat_uint_t avg, burst; + compat_ulong_t prev; + compat_uint_t credit, credit_cap, cost; +}; +#endif + +static struct xt_match ebt_limit_mt_reg __read_mostly = { + .name = "limit", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_limit_mt, + .checkentry = ebt_limit_mt_check, + .matchsize = sizeof(struct ebt_limit_info), + .usersize = offsetof(struct ebt_limit_info, prev), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(struct ebt_compat_limit_info), +#endif + .me = THIS_MODULE, +}; + +static int __init ebt_limit_init(void) +{ + return xt_register_match(&ebt_limit_mt_reg); +} + +static void __exit ebt_limit_fini(void) +{ + xt_unregister_match(&ebt_limit_mt_reg); +} + +module_init(ebt_limit_init); +module_exit(ebt_limit_fini); +MODULE_DESCRIPTION("Ebtables: Rate-limit match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_log.c b/net/bridge/netfilter/ebt_log.c new file mode 100644 index 000000000..e2eea1daa --- /dev/null +++ b/net/bridge/netfilter/ebt_log.c @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_log + * + * Authors: + * Bart De Schuymer + * Harald Welte + * + * April, 2002 + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(ebt_log_lock); + +static int ebt_log_tg_check(const struct xt_tgchk_param *par) +{ + struct ebt_log_info *info = par->targinfo; + + if (info->bitmask & ~EBT_LOG_MASK) + return -EINVAL; + if (info->loglevel >= 8) + return -EINVAL; + info->prefix[EBT_LOG_PREFIX_SIZE - 1] = '\0'; + return 0; +} + +struct tcpudphdr { + __be16 src; + __be16 dst; +}; + +struct arppayload { + unsigned char mac_src[ETH_ALEN]; + unsigned char ip_src[4]; + unsigned char mac_dst[ETH_ALEN]; + unsigned char ip_dst[4]; +}; + +static void +print_ports(const struct sk_buff *skb, uint8_t protocol, int offset) +{ + if (protocol == IPPROTO_TCP || + protocol == IPPROTO_UDP || + protocol == IPPROTO_UDPLITE || + protocol == IPPROTO_SCTP || + protocol == IPPROTO_DCCP) { + const struct tcpudphdr *pptr; + struct tcpudphdr _ports; + + pptr = skb_header_pointer(skb, offset, + sizeof(_ports), &_ports); + if (pptr == NULL) { + pr_cont(" INCOMPLETE TCP/UDP header"); + return; + } + pr_cont(" SPT=%u DPT=%u", ntohs(pptr->src), ntohs(pptr->dst)); + } +} + +static void +ebt_log_packet(struct net *net, u_int8_t pf, unsigned int hooknum, + const struct sk_buff *skb, const struct net_device *in, + const struct net_device *out, const struct nf_loginfo *loginfo, + const char *prefix) +{ + unsigned int bitmask; + + /* FIXME: Disabled from containers until syslog ns is supported */ + if (!net_eq(net, &init_net) && !sysctl_nf_log_all_netns) + return; + + spin_lock_bh(&ebt_log_lock); + printk(KERN_SOH "%c%s IN=%s OUT=%s MAC source = %pM MAC dest = %pM proto = 0x%04x", + '0' + loginfo->u.log.level, prefix, + in ? in->name : "", out ? 
out->name : "", + eth_hdr(skb)->h_source, eth_hdr(skb)->h_dest, + ntohs(eth_hdr(skb)->h_proto)); + + if (loginfo->type == NF_LOG_TYPE_LOG) + bitmask = loginfo->u.log.logflags; + else + bitmask = NF_LOG_DEFAULT_MASK; + + if ((bitmask & EBT_LOG_IP) && eth_hdr(skb)->h_proto == + htons(ETH_P_IP)) { + const struct iphdr *ih; + struct iphdr _iph; + + ih = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (ih == NULL) { + pr_cont(" INCOMPLETE IP header"); + goto out; + } + pr_cont(" IP SRC=%pI4 IP DST=%pI4, IP tos=0x%02X, IP proto=%d", + &ih->saddr, &ih->daddr, ih->tos, ih->protocol); + print_ports(skb, ih->protocol, ih->ihl*4); + goto out; + } + +#if IS_ENABLED(CONFIG_BRIDGE_EBT_IP6) + if ((bitmask & EBT_LOG_IP6) && eth_hdr(skb)->h_proto == + htons(ETH_P_IPV6)) { + const struct ipv6hdr *ih; + struct ipv6hdr _iph; + uint8_t nexthdr; + __be16 frag_off; + int offset_ph; + + ih = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (ih == NULL) { + pr_cont(" INCOMPLETE IPv6 header"); + goto out; + } + pr_cont(" IPv6 SRC=%pI6 IPv6 DST=%pI6, IPv6 priority=0x%01X, Next Header=%d", + &ih->saddr, &ih->daddr, ih->priority, ih->nexthdr); + nexthdr = ih->nexthdr; + offset_ph = ipv6_skip_exthdr(skb, sizeof(_iph), &nexthdr, &frag_off); + if (offset_ph == -1) + goto out; + print_ports(skb, nexthdr, offset_ph); + goto out; + } +#endif + + if ((bitmask & EBT_LOG_ARP) && + ((eth_hdr(skb)->h_proto == htons(ETH_P_ARP)) || + (eth_hdr(skb)->h_proto == htons(ETH_P_RARP)))) { + const struct arphdr *ah; + struct arphdr _arph; + + ah = skb_header_pointer(skb, 0, sizeof(_arph), &_arph); + if (ah == NULL) { + pr_cont(" INCOMPLETE ARP header"); + goto out; + } + pr_cont(" ARP HTYPE=%d, PTYPE=0x%04x, OPCODE=%d", + ntohs(ah->ar_hrd), ntohs(ah->ar_pro), + ntohs(ah->ar_op)); + + /* If it's for Ethernet and the lengths are OK, + * then log the ARP payload + */ + if (ah->ar_hrd == htons(1) && + ah->ar_hln == ETH_ALEN && + ah->ar_pln == sizeof(__be32)) { + const struct arppayload *ap; + struct arppayload _arpp; + + ap = skb_header_pointer(skb, sizeof(_arph), + sizeof(_arpp), &_arpp); + if (ap == NULL) { + pr_cont(" INCOMPLETE ARP payload"); + goto out; + } + pr_cont(" ARP MAC SRC=%pM ARP IP SRC=%pI4 ARP MAC DST=%pM ARP IP DST=%pI4", + ap->mac_src, ap->ip_src, + ap->mac_dst, ap->ip_dst); + } + } +out: + pr_cont("\n"); + spin_unlock_bh(&ebt_log_lock); +} + +static unsigned int +ebt_log_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_log_info *info = par->targinfo; + struct nf_loginfo li; + struct net *net = xt_net(par); + + li.type = NF_LOG_TYPE_LOG; + li.u.log.level = info->loglevel; + li.u.log.logflags = info->bitmask; + + /* Remember that we have to use ebt_log_packet() not to break backward + * compatibility. We cannot use the default bridge packet logger via + * nf_log_packet() with NFT_LOG_TYPE_LOG here. 
--Pablo + */ + if (info->bitmask & EBT_LOG_NFLOG) + nf_log_packet(net, NFPROTO_BRIDGE, xt_hooknum(par), skb, + xt_in(par), xt_out(par), &li, "%s", + info->prefix); + else + ebt_log_packet(net, NFPROTO_BRIDGE, xt_hooknum(par), skb, + xt_in(par), xt_out(par), &li, info->prefix); + return EBT_CONTINUE; +} + +static struct xt_target ebt_log_tg_reg __read_mostly = { + .name = "log", + .revision = 0, + .family = NFPROTO_BRIDGE, + .target = ebt_log_tg, + .checkentry = ebt_log_tg_check, + .targetsize = sizeof(struct ebt_log_info), + .me = THIS_MODULE, +}; + +static int __init ebt_log_init(void) +{ + return xt_register_target(&ebt_log_tg_reg); +} + +static void __exit ebt_log_fini(void) +{ + xt_unregister_target(&ebt_log_tg_reg); +} + +module_init(ebt_log_init); +module_exit(ebt_log_fini); +MODULE_DESCRIPTION("Ebtables: Packet logging to syslog"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_mark.c b/net/bridge/netfilter/ebt_mark.c new file mode 100644 index 000000000..8cf653c72 --- /dev/null +++ b/net/bridge/netfilter/ebt_mark.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_mark + * + * Authors: + * Bart De Schuymer + * + * July, 2002 + * + */ + +/* The mark target can be used in any chain, + * I believe adding a mangle table just for marking is total overkill. + * Marking a frame doesn't really change anything in the frame anyway. + */ + +#include +#include +#include +#include + +static unsigned int +ebt_mark_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_mark_t_info *info = par->targinfo; + int action = info->target & -16; + + if (action == MARK_SET_VALUE) + skb->mark = info->mark; + else if (action == MARK_OR_VALUE) + skb->mark |= info->mark; + else if (action == MARK_AND_VALUE) + skb->mark &= info->mark; + else + skb->mark ^= info->mark; + + return info->target | ~EBT_VERDICT_BITS; +} + +static int ebt_mark_tg_check(const struct xt_tgchk_param *par) +{ + const struct ebt_mark_t_info *info = par->targinfo; + int tmp; + + tmp = info->target | ~EBT_VERDICT_BITS; + if (BASE_CHAIN && tmp == EBT_RETURN) + return -EINVAL; + if (ebt_invalid_target(tmp)) + return -EINVAL; + tmp = info->target & ~EBT_VERDICT_BITS; + if (tmp != MARK_SET_VALUE && tmp != MARK_OR_VALUE && + tmp != MARK_AND_VALUE && tmp != MARK_XOR_VALUE) + return -EINVAL; + return 0; +} +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_ebt_mark_t_info { + compat_ulong_t mark; + compat_uint_t target; +}; + +static void mark_tg_compat_from_user(void *dst, const void *src) +{ + const struct compat_ebt_mark_t_info *user = src; + struct ebt_mark_t_info *kern = dst; + + kern->mark = user->mark; + kern->target = user->target; +} + +static int mark_tg_compat_to_user(void __user *dst, const void *src) +{ + struct compat_ebt_mark_t_info __user *user = dst; + const struct ebt_mark_t_info *kern = src; + + if (put_user(kern->mark, &user->mark) || + put_user(kern->target, &user->target)) + return -EFAULT; + return 0; +} +#endif + +static struct xt_target ebt_mark_tg_reg __read_mostly = { + .name = "mark", + .revision = 0, + .family = NFPROTO_BRIDGE, + .target = ebt_mark_tg, + .checkentry = ebt_mark_tg_check, + .targetsize = sizeof(struct ebt_mark_t_info), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(struct compat_ebt_mark_t_info), + .compat_from_user = mark_tg_compat_from_user, + .compat_to_user = mark_tg_compat_to_user, +#endif + .me = THIS_MODULE, +}; + +static int __init ebt_mark_init(void) +{ + return xt_register_target(&ebt_mark_tg_reg); +} + +static 
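ebt_mark applies one of four operations to skb->mark; the operation is packed into the bits of info->target above the 4-bit verdict, hence the "info->target & -16" above. The mark arithmetic on its own, with a plain enum standing in for that packed encoding:

/* The four mark operations (set/or/and/xor), selected by a plain enum
 * instead of the packed target field the kernel uses.
 */
#include <stdint.h>
#include <stdio.h>

enum mark_op { MARK_SET, MARK_OR, MARK_AND, MARK_XOR };

static uint32_t apply_mark(uint32_t mark, enum mark_op op, uint32_t value)
{
        switch (op) {
        case MARK_SET: return value;
        case MARK_OR:  return mark | value;
        case MARK_AND: return mark & value;
        case MARK_XOR: return mark ^ value;
        }
        return mark;
}

int main(void)
{
        uint32_t mark = 0x00f0;

        printf("or : %#x\n", (unsigned)apply_mark(mark, MARK_OR, 0x000f));  /* 0xff */
        printf("and: %#x\n", (unsigned)apply_mark(mark, MARK_AND, 0x0030)); /* 0x30 */
        return 0;
}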
void __exit ebt_mark_fini(void) +{ + xt_unregister_target(&ebt_mark_tg_reg); +} + +module_init(ebt_mark_init); +module_exit(ebt_mark_fini); +MODULE_DESCRIPTION("Ebtables: Packet mark modification"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_mark_m.c b/net/bridge/netfilter/ebt_mark_m.c new file mode 100644 index 000000000..5872e73c7 --- /dev/null +++ b/net/bridge/netfilter/ebt_mark_m.c @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_mark_m + * + * Authors: + * Bart De Schuymer + * + * July, 2002 + * + */ +#include +#include +#include +#include + +static bool +ebt_mark_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_mark_m_info *info = par->matchinfo; + + if (info->bitmask & EBT_MARK_OR) + return !!(skb->mark & info->mask) ^ info->invert; + return ((skb->mark & info->mask) == info->mark) ^ info->invert; +} + +static int ebt_mark_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_mark_m_info *info = par->matchinfo; + + if (info->bitmask & ~EBT_MARK_MASK) + return -EINVAL; + if ((info->bitmask & EBT_MARK_OR) && (info->bitmask & EBT_MARK_AND)) + return -EINVAL; + if (!info->bitmask) + return -EINVAL; + return 0; +} + + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_ebt_mark_m_info { + compat_ulong_t mark, mask; + uint8_t invert, bitmask; +}; + +static void mark_mt_compat_from_user(void *dst, const void *src) +{ + const struct compat_ebt_mark_m_info *user = src; + struct ebt_mark_m_info *kern = dst; + + kern->mark = user->mark; + kern->mask = user->mask; + kern->invert = user->invert; + kern->bitmask = user->bitmask; +} + +static int mark_mt_compat_to_user(void __user *dst, const void *src) +{ + struct compat_ebt_mark_m_info __user *user = dst; + const struct ebt_mark_m_info *kern = src; + + if (put_user(kern->mark, &user->mark) || + put_user(kern->mask, &user->mask) || + put_user(kern->invert, &user->invert) || + put_user(kern->bitmask, &user->bitmask)) + return -EFAULT; + return 0; +} +#endif + +static struct xt_match ebt_mark_mt_reg __read_mostly = { + .name = "mark_m", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_mark_mt, + .checkentry = ebt_mark_mt_check, + .matchsize = sizeof(struct ebt_mark_m_info), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(struct compat_ebt_mark_m_info), + .compat_from_user = mark_mt_compat_from_user, + .compat_to_user = mark_mt_compat_to_user, +#endif + .me = THIS_MODULE, +}; + +static int __init ebt_mark_m_init(void) +{ + return xt_register_match(&ebt_mark_mt_reg); +} + +static void __exit ebt_mark_m_fini(void) +{ + xt_unregister_match(&ebt_mark_mt_reg); +} + +module_init(ebt_mark_m_init); +module_exit(ebt_mark_m_fini); +MODULE_DESCRIPTION("Ebtables: Packet mark match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_nflog.c b/net/bridge/netfilter/ebt_nflog.c new file mode 100644 index 000000000..61bf8f446 --- /dev/null +++ b/net/bridge/netfilter/ebt_nflog.c @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_nflog + * + * Author: + * Peter Warasin + * + * February, 2008 + * + * Based on: + * xt_NFLOG.c, (C) 2006 by Patrick McHardy + * ebt_ulog.c, (C) 2004 by Bart De Schuymer + * + */ + +#include +#include +#include +#include +#include +#include + +static unsigned int +ebt_nflog_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_nflog_info *info = par->targinfo; + struct net *net = xt_net(par); + struct nf_loginfo li; + + li.type = NF_LOG_TYPE_ULOG; + li.u.ulog.copy_len = 
info->len; + li.u.ulog.group = info->group; + li.u.ulog.qthreshold = info->threshold; + li.u.ulog.flags = 0; + + nf_log_packet(net, PF_BRIDGE, xt_hooknum(par), skb, xt_in(par), + xt_out(par), &li, "%s", info->prefix); + return EBT_CONTINUE; +} + +static int ebt_nflog_tg_check(const struct xt_tgchk_param *par) +{ + struct ebt_nflog_info *info = par->targinfo; + + if (info->flags & ~EBT_NFLOG_MASK) + return -EINVAL; + info->prefix[EBT_NFLOG_PREFIX_SIZE - 1] = '\0'; + return 0; +} + +static struct xt_target ebt_nflog_tg_reg __read_mostly = { + .name = "nflog", + .revision = 0, + .family = NFPROTO_BRIDGE, + .target = ebt_nflog_tg, + .checkentry = ebt_nflog_tg_check, + .targetsize = sizeof(struct ebt_nflog_info), + .me = THIS_MODULE, +}; + +static int __init ebt_nflog_init(void) +{ + return xt_register_target(&ebt_nflog_tg_reg); +} + +static void __exit ebt_nflog_fini(void) +{ + xt_unregister_target(&ebt_nflog_tg_reg); +} + +module_init(ebt_nflog_init); +module_exit(ebt_nflog_fini); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Peter Warasin "); +MODULE_DESCRIPTION("ebtables NFLOG netfilter logging module"); diff --git a/net/bridge/netfilter/ebt_pkttype.c b/net/bridge/netfilter/ebt_pkttype.c new file mode 100644 index 000000000..c9e306119 --- /dev/null +++ b/net/bridge/netfilter/ebt_pkttype.c @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_pkttype + * + * Authors: + * Bart De Schuymer + * + * April, 2003 + * + */ +#include +#include +#include +#include + +static bool +ebt_pkttype_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_pkttype_info *info = par->matchinfo; + + return (skb->pkt_type == info->pkt_type) ^ info->invert; +} + +static int ebt_pkttype_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_pkttype_info *info = par->matchinfo; + + if (info->invert != 0 && info->invert != 1) + return -EINVAL; + /* Allow any pkt_type value */ + return 0; +} + +static struct xt_match ebt_pkttype_mt_reg __read_mostly = { + .name = "pkttype", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_pkttype_mt, + .checkentry = ebt_pkttype_mt_check, + .matchsize = sizeof(struct ebt_pkttype_info), + .me = THIS_MODULE, +}; + +static int __init ebt_pkttype_init(void) +{ + return xt_register_match(&ebt_pkttype_mt_reg); +} + +static void __exit ebt_pkttype_fini(void) +{ + xt_unregister_match(&ebt_pkttype_mt_reg); +} + +module_init(ebt_pkttype_init); +module_exit(ebt_pkttype_fini); +MODULE_DESCRIPTION("Ebtables: Link layer packet type match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_redirect.c b/net/bridge/netfilter/ebt_redirect.c new file mode 100644 index 000000000..307790562 --- /dev/null +++ b/net/bridge/netfilter/ebt_redirect.c @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_redirect + * + * Authors: + * Bart De Schuymer + * + * April, 2002 + * + */ +#include +#include +#include "../br_private.h" +#include +#include +#include +#include + +static unsigned int +ebt_redirect_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_redirect_info *info = par->targinfo; + + if (skb_ensure_writable(skb, 0)) + return EBT_DROP; + + if (xt_hooknum(par) != NF_BR_BROUTING) + /* rcu_read_lock()ed by nf_hook_thresh */ + ether_addr_copy(eth_hdr(skb)->h_dest, + br_port_get_rcu(xt_in(par))->br->dev->dev_addr); + else + ether_addr_copy(eth_hdr(skb)->h_dest, xt_in(par)->dev_addr); + skb->pkt_type = PACKET_HOST; + return info->target; +} + +static int ebt_redirect_tg_check(const struct 
xt_tgchk_param *par) +{ + const struct ebt_redirect_info *info = par->targinfo; + unsigned int hook_mask; + + if (BASE_CHAIN && info->target == EBT_RETURN) + return -EINVAL; + + hook_mask = par->hook_mask & ~(1 << NF_BR_NUMHOOKS); + if ((strcmp(par->table, "nat") != 0 || + hook_mask & ~(1 << NF_BR_PRE_ROUTING)) && + (strcmp(par->table, "broute") != 0 || + hook_mask & ~(1 << NF_BR_BROUTING))) + return -EINVAL; + if (ebt_invalid_target(info->target)) + return -EINVAL; + return 0; +} + +static struct xt_target ebt_redirect_tg_reg __read_mostly = { + .name = "redirect", + .revision = 0, + .family = NFPROTO_BRIDGE, + .hooks = (1 << NF_BR_NUMHOOKS) | (1 << NF_BR_PRE_ROUTING) | + (1 << NF_BR_BROUTING), + .target = ebt_redirect_tg, + .checkentry = ebt_redirect_tg_check, + .targetsize = sizeof(struct ebt_redirect_info), + .me = THIS_MODULE, +}; + +static int __init ebt_redirect_init(void) +{ + return xt_register_target(&ebt_redirect_tg_reg); +} + +static void __exit ebt_redirect_fini(void) +{ + xt_unregister_target(&ebt_redirect_tg_reg); +} + +module_init(ebt_redirect_init); +module_exit(ebt_redirect_fini); +MODULE_DESCRIPTION("Ebtables: Packet redirection to localhost"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_snat.c b/net/bridge/netfilter/ebt_snat.c new file mode 100644 index 000000000..7dfbcdfc3 --- /dev/null +++ b/net/bridge/netfilter/ebt_snat.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_snat + * + * Authors: + * Bart De Schuymer + * + * June, 2002 + * + */ +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int +ebt_snat_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ebt_nat_info *info = par->targinfo; + + if (skb_ensure_writable(skb, 0)) + return EBT_DROP; + + ether_addr_copy(eth_hdr(skb)->h_source, info->mac); + if (!(info->target & NAT_ARP_BIT) && + eth_hdr(skb)->h_proto == htons(ETH_P_ARP)) { + const struct arphdr *ap; + struct arphdr _ah; + + ap = skb_header_pointer(skb, 0, sizeof(_ah), &_ah); + if (ap == NULL) + return EBT_DROP; + if (ap->ar_hln != ETH_ALEN) + goto out; + if (skb_store_bits(skb, sizeof(_ah), info->mac, ETH_ALEN)) + return EBT_DROP; + } +out: + return info->target | ~EBT_VERDICT_BITS; +} + +static int ebt_snat_tg_check(const struct xt_tgchk_param *par) +{ + const struct ebt_nat_info *info = par->targinfo; + int tmp; + + tmp = info->target | ~EBT_VERDICT_BITS; + if (BASE_CHAIN && tmp == EBT_RETURN) + return -EINVAL; + + if (ebt_invalid_target(tmp)) + return -EINVAL; + tmp = info->target | EBT_VERDICT_BITS; + if ((tmp & ~NAT_ARP_BIT) != ~NAT_ARP_BIT) + return -EINVAL; + return 0; +} + +static struct xt_target ebt_snat_tg_reg __read_mostly = { + .name = "snat", + .revision = 0, + .family = NFPROTO_BRIDGE, + .table = "nat", + .hooks = (1 << NF_BR_NUMHOOKS) | (1 << NF_BR_POST_ROUTING), + .target = ebt_snat_tg, + .checkentry = ebt_snat_tg_check, + .targetsize = sizeof(struct ebt_nat_info), + .me = THIS_MODULE, +}; + +static int __init ebt_snat_init(void) +{ + return xt_register_target(&ebt_snat_tg_reg); +} + +static void __exit ebt_snat_fini(void) +{ + xt_unregister_target(&ebt_snat_tg_reg); +} + +module_init(ebt_snat_init); +module_exit(ebt_snat_fini); +MODULE_DESCRIPTION("Ebtables: Source MAC address translation"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_stp.c b/net/bridge/netfilter/ebt_stp.c new file mode 100644 index 000000000..8f68afda5 --- /dev/null +++ b/net/bridge/netfilter/ebt_stp.c @@ -0,0 +1,194 @@ +// 
SPDX-License-Identifier: GPL-2.0-only +/* + * ebt_stp + * + * Authors: + * Bart De Schuymer + * Stephen Hemminger + * + * July, 2003 + */ +#include +#include +#include +#include +#include + +#define BPDU_TYPE_CONFIG 0 + +struct stp_header { + u8 dsap; + u8 ssap; + u8 ctrl; + u8 pid; + u8 vers; + u8 type; +}; + +struct stp_config_pdu { + u8 flags; + u8 root[8]; + u8 root_cost[4]; + u8 sender[8]; + u8 port[2]; + u8 msg_age[2]; + u8 max_age[2]; + u8 hello_time[2]; + u8 forward_delay[2]; +}; + +#define NR16(p) (p[0] << 8 | p[1]) +#define NR32(p) ((p[0] << 24) | (p[1] << 16) | (p[2] << 8) | p[3]) + +static bool ebt_filter_config(const struct ebt_stp_info *info, + const struct stp_config_pdu *stpc) +{ + const struct ebt_stp_config_info *c; + u16 v16; + u32 v32; + + c = &info->config; + if ((info->bitmask & EBT_STP_FLAGS) && + NF_INVF(info, EBT_STP_FLAGS, c->flags != stpc->flags)) + return false; + if (info->bitmask & EBT_STP_ROOTPRIO) { + v16 = NR16(stpc->root); + if (NF_INVF(info, EBT_STP_ROOTPRIO, + v16 < c->root_priol || v16 > c->root_priou)) + return false; + } + if (info->bitmask & EBT_STP_ROOTADDR) { + if (NF_INVF(info, EBT_STP_ROOTADDR, + !ether_addr_equal_masked(&stpc->root[2], + c->root_addr, + c->root_addrmsk))) + return false; + } + if (info->bitmask & EBT_STP_ROOTCOST) { + v32 = NR32(stpc->root_cost); + if (NF_INVF(info, EBT_STP_ROOTCOST, + v32 < c->root_costl || v32 > c->root_costu)) + return false; + } + if (info->bitmask & EBT_STP_SENDERPRIO) { + v16 = NR16(stpc->sender); + if (NF_INVF(info, EBT_STP_SENDERPRIO, + v16 < c->sender_priol || v16 > c->sender_priou)) + return false; + } + if (info->bitmask & EBT_STP_SENDERADDR) { + if (NF_INVF(info, EBT_STP_SENDERADDR, + !ether_addr_equal_masked(&stpc->sender[2], + c->sender_addr, + c->sender_addrmsk))) + return false; + } + if (info->bitmask & EBT_STP_PORT) { + v16 = NR16(stpc->port); + if (NF_INVF(info, EBT_STP_PORT, + v16 < c->portl || v16 > c->portu)) + return false; + } + if (info->bitmask & EBT_STP_MSGAGE) { + v16 = NR16(stpc->msg_age); + if (NF_INVF(info, EBT_STP_MSGAGE, + v16 < c->msg_agel || v16 > c->msg_ageu)) + return false; + } + if (info->bitmask & EBT_STP_MAXAGE) { + v16 = NR16(stpc->max_age); + if (NF_INVF(info, EBT_STP_MAXAGE, + v16 < c->max_agel || v16 > c->max_ageu)) + return false; + } + if (info->bitmask & EBT_STP_HELLOTIME) { + v16 = NR16(stpc->hello_time); + if (NF_INVF(info, EBT_STP_HELLOTIME, + v16 < c->hello_timel || v16 > c->hello_timeu)) + return false; + } + if (info->bitmask & EBT_STP_FWDD) { + v16 = NR16(stpc->forward_delay); + if (NF_INVF(info, EBT_STP_FWDD, + v16 < c->forward_delayl || v16 > c->forward_delayu)) + return false; + } + return true; +} + +static bool +ebt_stp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_stp_info *info = par->matchinfo; + const struct stp_header *sp; + struct stp_header _stph; + const u8 header[6] = {0x42, 0x42, 0x03, 0x00, 0x00, 0x00}; + + sp = skb_header_pointer(skb, 0, sizeof(_stph), &_stph); + if (sp == NULL) + return false; + + /* The stp code only considers these */ + if (memcmp(sp, header, sizeof(header))) + return false; + + if ((info->bitmask & EBT_STP_TYPE) && + NF_INVF(info, EBT_STP_TYPE, info->type != sp->type)) + return false; + + if (sp->type == BPDU_TYPE_CONFIG && + info->bitmask & EBT_STP_CONFIG_MASK) { + const struct stp_config_pdu *st; + struct stp_config_pdu _stpc; + + st = skb_header_pointer(skb, sizeof(_stph), + sizeof(_stpc), &_stpc); + if (st == NULL) + return false; + return ebt_filter_config(info, st); + } + return 
true; +} + +static int ebt_stp_mt_check(const struct xt_mtchk_param *par) +{ + const struct ebt_stp_info *info = par->matchinfo; + const struct ebt_entry *e = par->entryinfo; + + if (info->bitmask & ~EBT_STP_MASK || info->invflags & ~EBT_STP_MASK || + !(info->bitmask & EBT_STP_MASK)) + return -EINVAL; + /* Make sure the match only receives stp frames */ + if (!par->nft_compat && + (!ether_addr_equal(e->destmac, eth_stp_addr) || + !(e->bitmask & EBT_DESTMAC) || + !is_broadcast_ether_addr(e->destmsk))) + return -EINVAL; + + return 0; +} + +static struct xt_match ebt_stp_mt_reg __read_mostly = { + .name = "stp", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_stp_mt, + .checkentry = ebt_stp_mt_check, + .matchsize = sizeof(struct ebt_stp_info), + .me = THIS_MODULE, +}; + +static int __init ebt_stp_init(void) +{ + return xt_register_match(&ebt_stp_mt_reg); +} + +static void __exit ebt_stp_fini(void) +{ + xt_unregister_match(&ebt_stp_mt_reg); +} + +module_init(ebt_stp_init); +module_exit(ebt_stp_fini); +MODULE_DESCRIPTION("Ebtables: Spanning Tree Protocol packet match"); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebt_vlan.c b/net/bridge/netfilter/ebt_vlan.c new file mode 100644 index 000000000..80ede370a --- /dev/null +++ b/net/bridge/netfilter/ebt_vlan.c @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Description: EBTables 802.1Q match extension kernelspace module. + * Authors: Nick Fedchik + * Bart De Schuymer + */ + +#include +#include +#include +#include +#include +#include +#include + +#define MODULE_VERS "0.6" + +MODULE_AUTHOR("Nick Fedchik "); +MODULE_DESCRIPTION("Ebtables: 802.1Q VLAN tag match"); +MODULE_LICENSE("GPL"); + +#define GET_BITMASK(_BIT_MASK_) info->bitmask & _BIT_MASK_ +#define EXIT_ON_MISMATCH(_MATCH_,_MASK_) {if (!((info->_MATCH_ == _MATCH_)^!!(info->invflags & _MASK_))) return false; } + +static bool +ebt_vlan_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ebt_vlan_info *info = par->matchinfo; + + unsigned short TCI; /* Whole TCI, given from parsed frame */ + unsigned short id; /* VLAN ID, given from frame TCI */ + unsigned char prio; /* user_priority, given from frame TCI */ + /* VLAN encapsulated Type/Length field, given from orig frame */ + __be16 encap; + + if (skb_vlan_tag_present(skb)) { + TCI = skb_vlan_tag_get(skb); + encap = skb->protocol; + } else { + const struct vlan_hdr *fp; + struct vlan_hdr _frame; + + fp = skb_header_pointer(skb, 0, sizeof(_frame), &_frame); + if (fp == NULL) + return false; + + TCI = ntohs(fp->h_vlan_TCI); + encap = fp->h_vlan_encapsulated_proto; + } + + /* Tag Control Information (TCI) consists of the following elements: + * - User_priority. The user_priority field is three bits in length, + * interpreted as a binary number. + * - Canonical Format Indicator (CFI). The Canonical Format Indicator + * (CFI) is a single bit flag value. Currently ignored. + * - VLAN Identifier (VID). The VID is encoded as + * an unsigned binary number. 
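+ * For example, a TCI of 0x6005 decodes to user_priority 3 (bits 15-13), CFI 0 (bit 12) and VID 5 (bits 11-0), which is exactly what the code below extracts.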
+ */ + id = TCI & VLAN_VID_MASK; + prio = (TCI >> 13) & 0x7; + + /* Checking VLAN Identifier (VID) */ + if (GET_BITMASK(EBT_VLAN_ID)) + EXIT_ON_MISMATCH(id, EBT_VLAN_ID); + + /* Checking user_priority */ + if (GET_BITMASK(EBT_VLAN_PRIO)) + EXIT_ON_MISMATCH(prio, EBT_VLAN_PRIO); + + /* Checking Encapsulated Proto (Length/Type) field */ + if (GET_BITMASK(EBT_VLAN_ENCAP)) + EXIT_ON_MISMATCH(encap, EBT_VLAN_ENCAP); + + return true; +} + +static int ebt_vlan_mt_check(const struct xt_mtchk_param *par) +{ + struct ebt_vlan_info *info = par->matchinfo; + const struct ebt_entry *e = par->entryinfo; + + /* Is it 802.1Q frame checked? */ + if (e->ethproto != htons(ETH_P_8021Q)) { + pr_debug("passed entry proto %2.4X is not 802.1Q (8100)\n", + ntohs(e->ethproto)); + return -EINVAL; + } + + /* Check for bitmask range + * True if even one bit is out of mask + */ + if (info->bitmask & ~EBT_VLAN_MASK) { + pr_debug("bitmask %2X is out of mask (%2X)\n", + info->bitmask, EBT_VLAN_MASK); + return -EINVAL; + } + + /* Check for inversion flags range */ + if (info->invflags & ~EBT_VLAN_MASK) { + pr_debug("inversion flags %2X is out of mask (%2X)\n", + info->invflags, EBT_VLAN_MASK); + return -EINVAL; + } + + /* Reserved VLAN ID (VID) values + * ----------------------------- + * 0 - The null VLAN ID. + * 1 - The default Port VID (PVID) + * 0x0FFF - Reserved for implementation use. + * if_vlan.h: VLAN_N_VID 4096. + */ + if (GET_BITMASK(EBT_VLAN_ID)) { + if (!!info->id) { /* if id!=0 => check vid range */ + if (info->id > VLAN_N_VID) { + pr_debug("id %d is out of range (1-4096)\n", + info->id); + return -EINVAL; + } + /* Note: This is valid VLAN-tagged frame point. + * Any value of user_priority are acceptable, + * but should be ignored according to 802.1Q Std. + * So we just drop the prio flag. + */ + info->bitmask &= ~EBT_VLAN_PRIO; + } + /* Else, id=0 (null VLAN ID) => user_priority range (any?) */ + } + + if (GET_BITMASK(EBT_VLAN_PRIO)) { + if ((unsigned char) info->prio > 7) { + pr_debug("prio %d is out of range (0-7)\n", + info->prio); + return -EINVAL; + } + } + /* Check for encapsulated proto range - it is possible to be + * any value for u_short range. + * if_ether.h: ETH_ZLEN 60 - Min. octets in frame sans FCS + */ + if (GET_BITMASK(EBT_VLAN_ENCAP)) { + if ((unsigned short) ntohs(info->encap) < ETH_ZLEN) { + pr_debug("encap frame length %d is less than " + "minimal\n", ntohs(info->encap)); + return -EINVAL; + } + } + + return 0; +} + +static struct xt_match ebt_vlan_mt_reg __read_mostly = { + .name = "vlan", + .revision = 0, + .family = NFPROTO_BRIDGE, + .match = ebt_vlan_mt, + .checkentry = ebt_vlan_mt_check, + .matchsize = sizeof(struct ebt_vlan_info), + .me = THIS_MODULE, +}; + +static int __init ebt_vlan_init(void) +{ + pr_debug("ebtables 802.1Q extension module v" MODULE_VERS "\n"); + return xt_register_match(&ebt_vlan_mt_reg); +} + +static void __exit ebt_vlan_fini(void) +{ + xt_unregister_match(&ebt_vlan_mt_reg); +} + +module_init(ebt_vlan_init); +module_exit(ebt_vlan_fini); diff --git a/net/bridge/netfilter/ebtable_broute.c b/net/bridge/netfilter/ebtable_broute.c new file mode 100644 index 000000000..8f1925302 --- /dev/null +++ b/net/bridge/netfilter/ebtable_broute.c @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebtable_broute + * + * Authors: + * Bart De Schuymer + * + * April, 2002 + * + * This table lets you choose between routing and bridging for frames + * entering on a bridge enslaved nic. This table is traversed before any + * other ebtables table. 
See net/bridge/br_input.c. + */ + +#include +#include +#include + +#include "../br_private.h" + +/* EBT_ACCEPT means the frame will be bridged + * EBT_DROP means the frame will be routed + */ +static struct ebt_entries initial_chain = { + .name = "BROUTING", + .policy = EBT_ACCEPT, +}; + +static struct ebt_replace_kernel initial_table = { + .name = "broute", + .valid_hooks = 1 << NF_BR_BROUTING, + .entries_size = sizeof(struct ebt_entries), + .hook_entry = { + [NF_BR_BROUTING] = &initial_chain, + }, + .entries = (char *)&initial_chain, +}; + +static const struct ebt_table broute_table = { + .name = "broute", + .table = &initial_table, + .valid_hooks = 1 << NF_BR_BROUTING, + .me = THIS_MODULE, +}; + +static unsigned int ebt_broute(void *priv, struct sk_buff *skb, + const struct nf_hook_state *s) +{ + struct net_bridge_port *p = br_port_get_rcu(skb->dev); + struct nf_hook_state state; + unsigned char *dest; + int ret; + + if (!p || p->state != BR_STATE_FORWARDING) + return NF_ACCEPT; + + nf_hook_state_init(&state, NF_BR_BROUTING, + NFPROTO_BRIDGE, s->in, NULL, NULL, + s->net, NULL); + + ret = ebt_do_table(priv, skb, &state); + if (ret != NF_DROP) + return ret; + + /* DROP in ebtables -t broute means that the + * skb should be routed, not bridged. + * This is awkward, but can't be changed for compatibility + * reasons. + * + * We map DROP to ACCEPT and set the ->br_netfilter_broute flag. + */ + BR_INPUT_SKB_CB(skb)->br_netfilter_broute = 1; + + /* undo PACKET_HOST mangling done in br_input in case the dst + * address matches the logical bridge but not the port. + */ + dest = eth_hdr(skb)->h_dest; + if (skb->pkt_type == PACKET_HOST && + !ether_addr_equal(skb->dev->dev_addr, dest) && + ether_addr_equal(p->br->dev->dev_addr, dest)) + skb->pkt_type = PACKET_OTHERHOST; + + return NF_ACCEPT; +} + +static const struct nf_hook_ops ebt_ops_broute = { + .hook = ebt_broute, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_PRE_ROUTING, + .priority = NF_BR_PRI_FIRST, +}; + +static int broute_table_init(struct net *net) +{ + return ebt_register_table(net, &broute_table, &ebt_ops_broute); +} + +static void __net_exit broute_net_pre_exit(struct net *net) +{ + ebt_unregister_table_pre_exit(net, "broute"); +} + +static void __net_exit broute_net_exit(struct net *net) +{ + ebt_unregister_table(net, "broute"); +} + +static struct pernet_operations broute_net_ops = { + .exit = broute_net_exit, + .pre_exit = broute_net_pre_exit, +}; + +static int __init ebtable_broute_init(void) +{ + int ret = ebt_register_template(&broute_table, broute_table_init); + + if (ret) + return ret; + + ret = register_pernet_subsys(&broute_net_ops); + if (ret) { + ebt_unregister_template(&broute_table); + return ret; + } + + return 0; +} + +static void __exit ebtable_broute_fini(void) +{ + unregister_pernet_subsys(&broute_net_ops); + ebt_unregister_template(&broute_table); +} + +module_init(ebtable_broute_init); +module_exit(ebtable_broute_fini); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebtable_filter.c b/net/bridge/netfilter/ebtable_filter.c new file mode 100644 index 000000000..278f324e6 --- /dev/null +++ b/net/bridge/netfilter/ebtable_filter.c @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebtable_filter + * + * Authors: + * Bart De Schuymer + * + * April, 2002 + * + */ + +#include +#include +#include + +#define FILTER_VALID_HOOKS ((1 << NF_BR_LOCAL_IN) | (1 << NF_BR_FORWARD) | \ + (1 << NF_BR_LOCAL_OUT)) + +static struct ebt_entries initial_chains[] = { + { + .name = "INPUT", + .policy = EBT_ACCEPT, + 
}, + { + .name = "FORWARD", + .policy = EBT_ACCEPT, + }, + { + .name = "OUTPUT", + .policy = EBT_ACCEPT, + }, +}; + +static struct ebt_replace_kernel initial_table = { + .name = "filter", + .valid_hooks = FILTER_VALID_HOOKS, + .entries_size = 3 * sizeof(struct ebt_entries), + .hook_entry = { + [NF_BR_LOCAL_IN] = &initial_chains[0], + [NF_BR_FORWARD] = &initial_chains[1], + [NF_BR_LOCAL_OUT] = &initial_chains[2], + }, + .entries = (char *)initial_chains, +}; + +static const struct ebt_table frame_filter = { + .name = "filter", + .table = &initial_table, + .valid_hooks = FILTER_VALID_HOOKS, + .me = THIS_MODULE, +}; + +static const struct nf_hook_ops ebt_ops_filter[] = { + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_LOCAL_IN, + .priority = NF_BR_PRI_FILTER_BRIDGED, + }, + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_FORWARD, + .priority = NF_BR_PRI_FILTER_BRIDGED, + }, + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_LOCAL_OUT, + .priority = NF_BR_PRI_FILTER_OTHER, + }, +}; + +static int frame_filter_table_init(struct net *net) +{ + return ebt_register_table(net, &frame_filter, ebt_ops_filter); +} + +static void __net_exit frame_filter_net_pre_exit(struct net *net) +{ + ebt_unregister_table_pre_exit(net, "filter"); +} + +static void __net_exit frame_filter_net_exit(struct net *net) +{ + ebt_unregister_table(net, "filter"); +} + +static struct pernet_operations frame_filter_net_ops = { + .exit = frame_filter_net_exit, + .pre_exit = frame_filter_net_pre_exit, +}; + +static int __init ebtable_filter_init(void) +{ + int ret = ebt_register_template(&frame_filter, frame_filter_table_init); + + if (ret) + return ret; + + ret = register_pernet_subsys(&frame_filter_net_ops); + if (ret) { + ebt_unregister_template(&frame_filter); + return ret; + } + + return 0; +} + +static void __exit ebtable_filter_fini(void) +{ + unregister_pernet_subsys(&frame_filter_net_ops); + ebt_unregister_template(&frame_filter); +} + +module_init(ebtable_filter_init); +module_exit(ebtable_filter_fini); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebtable_nat.c b/net/bridge/netfilter/ebtable_nat.c new file mode 100644 index 000000000..9066f7f37 --- /dev/null +++ b/net/bridge/netfilter/ebtable_nat.c @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ebtable_nat + * + * Authors: + * Bart De Schuymer + * + * April, 2002 + * + */ + +#include +#include +#include + +#define NAT_VALID_HOOKS ((1 << NF_BR_PRE_ROUTING) | (1 << NF_BR_LOCAL_OUT) | \ + (1 << NF_BR_POST_ROUTING)) + +static struct ebt_entries initial_chains[] = { + { + .name = "PREROUTING", + .policy = EBT_ACCEPT, + }, + { + .name = "OUTPUT", + .policy = EBT_ACCEPT, + }, + { + .name = "POSTROUTING", + .policy = EBT_ACCEPT, + } +}; + +static struct ebt_replace_kernel initial_table = { + .name = "nat", + .valid_hooks = NAT_VALID_HOOKS, + .entries_size = 3 * sizeof(struct ebt_entries), + .hook_entry = { + [NF_BR_PRE_ROUTING] = &initial_chains[0], + [NF_BR_LOCAL_OUT] = &initial_chains[1], + [NF_BR_POST_ROUTING] = &initial_chains[2], + }, + .entries = (char *)initial_chains, +}; + +static const struct ebt_table frame_nat = { + .name = "nat", + .table = &initial_table, + .valid_hooks = NAT_VALID_HOOKS, + .me = THIS_MODULE, +}; + +static const struct nf_hook_ops ebt_ops_nat[] = { + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_LOCAL_OUT, + .priority = NF_BR_PRI_NAT_DST_OTHER, + }, + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = 
NF_BR_POST_ROUTING, + .priority = NF_BR_PRI_NAT_SRC, + }, + { + .hook = ebt_do_table, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_PRE_ROUTING, + .priority = NF_BR_PRI_NAT_DST_BRIDGED, + }, +}; + +static int frame_nat_table_init(struct net *net) +{ + return ebt_register_table(net, &frame_nat, ebt_ops_nat); +} + +static void __net_exit frame_nat_net_pre_exit(struct net *net) +{ + ebt_unregister_table_pre_exit(net, "nat"); +} + +static void __net_exit frame_nat_net_exit(struct net *net) +{ + ebt_unregister_table(net, "nat"); +} + +static struct pernet_operations frame_nat_net_ops = { + .exit = frame_nat_net_exit, + .pre_exit = frame_nat_net_pre_exit, +}; + +static int __init ebtable_nat_init(void) +{ + int ret = ebt_register_template(&frame_nat, frame_nat_table_init); + + if (ret) + return ret; + + ret = register_pernet_subsys(&frame_nat_net_ops); + if (ret) { + ebt_unregister_template(&frame_nat); + return ret; + } + + return ret; +} + +static void __exit ebtable_nat_fini(void) +{ + unregister_pernet_subsys(&frame_nat_net_ops); + ebt_unregister_template(&frame_nat); +} + +module_init(ebtable_nat_init); +module_exit(ebtable_nat_fini); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/ebtables.c b/net/bridge/netfilter/ebtables.c new file mode 100644 index 000000000..aa23479b2 --- /dev/null +++ b/net/bridge/netfilter/ebtables.c @@ -0,0 +1,2597 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ebtables + * + * Author: + * Bart De Schuymer + * + * ebtables.c,v 2.0, July, 2002 + * + * This code is strongly inspired by the iptables code which is + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +/* needed for logical [in,out]-dev filtering */ +#include "../br_private.h" + +/* Each cpu has its own set of counters, so there is no need for write_lock in + * the softirq + * For reading or updating the counters, the user context needs to + * get a write_lock + */ + +/* The size of each set of counters is altered to get cache alignment */ +#define SMP_ALIGN(x) (((x) + SMP_CACHE_BYTES-1) & ~(SMP_CACHE_BYTES-1)) +#define COUNTER_OFFSET(n) (SMP_ALIGN(n * sizeof(struct ebt_counter))) +#define COUNTER_BASE(c, n, cpu) ((struct ebt_counter *)(((char *)c) + \ + COUNTER_OFFSET(n) * cpu)) + +struct ebt_pernet { + struct list_head tables; +}; + +struct ebt_template { + struct list_head list; + char name[EBT_TABLE_MAXNAMELEN]; + struct module *owner; + /* called when table is needed in the given netns */ + int (*table_init)(struct net *net); +}; + +static unsigned int ebt_pernet_id __read_mostly; +static LIST_HEAD(template_tables); +static DEFINE_MUTEX(ebt_mutex); + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +static void ebt_standard_compat_from_user(void *dst, const void *src) +{ + int v = *(compat_int_t *)src; + + if (v >= 0) + v += xt_compat_calc_jump(NFPROTO_BRIDGE, v); + memcpy(dst, &v, sizeof(v)); +} + +static int ebt_standard_compat_to_user(void __user *dst, const void *src) +{ + compat_int_t cv = *(int *)src; + + if (cv >= 0) + cv -= xt_compat_calc_jump(NFPROTO_BRIDGE, cv); + return copy_to_user(dst, &cv, sizeof(cv)) ? 
-EFAULT : 0; +} +#endif + + +static struct xt_target ebt_standard_target = { + .name = "standard", + .revision = 0, + .family = NFPROTO_BRIDGE, + .targetsize = sizeof(int), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(compat_int_t), + .compat_from_user = ebt_standard_compat_from_user, + .compat_to_user = ebt_standard_compat_to_user, +#endif +}; + +static inline int +ebt_do_watcher(const struct ebt_entry_watcher *w, struct sk_buff *skb, + struct xt_action_param *par) +{ + par->target = w->u.watcher; + par->targinfo = w->data; + w->u.watcher->target(skb, par); + /* watchers don't give a verdict */ + return 0; +} + +static inline int +ebt_do_match(struct ebt_entry_match *m, const struct sk_buff *skb, + struct xt_action_param *par) +{ + par->match = m->u.match; + par->matchinfo = m->data; + return !m->u.match->match(skb, par); +} + +static inline int +ebt_dev_check(const char *entry, const struct net_device *device) +{ + int i = 0; + const char *devname; + + if (*entry == '\0') + return 0; + if (!device) + return 1; + devname = device->name; + /* 1 is the wildcard token */ + while (entry[i] != '\0' && entry[i] != 1 && entry[i] == devname[i]) + i++; + return devname[i] != entry[i] && entry[i] != 1; +} + +/* process standard matches */ +static inline int +ebt_basic_match(const struct ebt_entry *e, const struct sk_buff *skb, + const struct net_device *in, const struct net_device *out) +{ + const struct ethhdr *h = eth_hdr(skb); + const struct net_bridge_port *p; + __be16 ethproto; + + if (skb_vlan_tag_present(skb)) + ethproto = htons(ETH_P_8021Q); + else + ethproto = h->h_proto; + + if (e->bitmask & EBT_802_3) { + if (NF_INVF(e, EBT_IPROTO, eth_proto_is_802_3(ethproto))) + return 1; + } else if (!(e->bitmask & EBT_NOPROTO) && + NF_INVF(e, EBT_IPROTO, e->ethproto != ethproto)) + return 1; + + if (NF_INVF(e, EBT_IIN, ebt_dev_check(e->in, in))) + return 1; + if (NF_INVF(e, EBT_IOUT, ebt_dev_check(e->out, out))) + return 1; + /* rcu_read_lock()ed by nf_hook_thresh */ + if (in && (p = br_port_get_rcu(in)) != NULL && + NF_INVF(e, EBT_ILOGICALIN, + ebt_dev_check(e->logical_in, p->br->dev))) + return 1; + if (out && (p = br_port_get_rcu(out)) != NULL && + NF_INVF(e, EBT_ILOGICALOUT, + ebt_dev_check(e->logical_out, p->br->dev))) + return 1; + + if (e->bitmask & EBT_SOURCEMAC) { + if (NF_INVF(e, EBT_ISOURCE, + !ether_addr_equal_masked(h->h_source, e->sourcemac, + e->sourcemsk))) + return 1; + } + if (e->bitmask & EBT_DESTMAC) { + if (NF_INVF(e, EBT_IDEST, + !ether_addr_equal_masked(h->h_dest, e->destmac, + e->destmsk))) + return 1; + } + return 0; +} + +static inline +struct ebt_entry *ebt_next_entry(const struct ebt_entry *entry) +{ + return (void *)entry + entry->next_offset; +} + +static inline const struct ebt_entry_target * +ebt_get_target_c(const struct ebt_entry *e) +{ + return ebt_get_target((struct ebt_entry *)e); +} + +/* Do some firewalling */ +unsigned int ebt_do_table(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct ebt_table *table = priv; + unsigned int hook = state->hook; + int i, nentries; + struct ebt_entry *point; + struct ebt_counter *counter_base, *cb_base; + const struct ebt_entry_target *t; + int verdict, sp = 0; + struct ebt_chainstack *cs; + struct ebt_entries *chaininfo; + const char *base; + const struct ebt_table_info *private; + struct xt_action_param acpar; + + acpar.state = state; + acpar.hotdrop = false; + + read_lock_bh(&table->lock); + private = table->private; + cb_base = COUNTER_BASE(private->counters, private->nentries, 
+ smp_processor_id()); + if (private->chainstack) + cs = private->chainstack[smp_processor_id()]; + else + cs = NULL; + chaininfo = private->hook_entry[hook]; + nentries = private->hook_entry[hook]->nentries; + point = (struct ebt_entry *)(private->hook_entry[hook]->data); + counter_base = cb_base + private->hook_entry[hook]->counter_offset; + /* base for chain jumps */ + base = private->entries; + i = 0; + while (i < nentries) { + if (ebt_basic_match(point, skb, state->in, state->out)) + goto letscontinue; + + if (EBT_MATCH_ITERATE(point, ebt_do_match, skb, &acpar) != 0) + goto letscontinue; + if (acpar.hotdrop) { + read_unlock_bh(&table->lock); + return NF_DROP; + } + + ADD_COUNTER(*(counter_base + i), skb->len, 1); + + /* these should only watch: not modify, nor tell us + * what to do with the packet + */ + EBT_WATCHER_ITERATE(point, ebt_do_watcher, skb, &acpar); + + t = ebt_get_target_c(point); + /* standard target */ + if (!t->u.target->target) + verdict = ((struct ebt_standard_target *)t)->verdict; + else { + acpar.target = t->u.target; + acpar.targinfo = t->data; + verdict = t->u.target->target(skb, &acpar); + } + if (verdict == EBT_ACCEPT) { + read_unlock_bh(&table->lock); + return NF_ACCEPT; + } + if (verdict == EBT_DROP) { + read_unlock_bh(&table->lock); + return NF_DROP; + } + if (verdict == EBT_RETURN) { +letsreturn: + if (WARN(sp == 0, "RETURN on base chain")) { + /* act like this is EBT_CONTINUE */ + goto letscontinue; + } + + sp--; + /* put all the local variables right */ + i = cs[sp].n; + chaininfo = cs[sp].chaininfo; + nentries = chaininfo->nentries; + point = cs[sp].e; + counter_base = cb_base + + chaininfo->counter_offset; + continue; + } + if (verdict == EBT_CONTINUE) + goto letscontinue; + + if (WARN(verdict < 0, "bogus standard verdict\n")) { + read_unlock_bh(&table->lock); + return NF_DROP; + } + + /* jump to a udc */ + cs[sp].n = i + 1; + cs[sp].chaininfo = chaininfo; + cs[sp].e = ebt_next_entry(point); + i = 0; + chaininfo = (struct ebt_entries *) (base + verdict); + + if (WARN(chaininfo->distinguisher, "jump to non-chain\n")) { + read_unlock_bh(&table->lock); + return NF_DROP; + } + + nentries = chaininfo->nentries; + point = (struct ebt_entry *)chaininfo->data; + counter_base = cb_base + chaininfo->counter_offset; + sp++; + continue; +letscontinue: + point = ebt_next_entry(point); + i++; + } + + /* I actually like this :) */ + if (chaininfo->policy == EBT_RETURN) + goto letsreturn; + if (chaininfo->policy == EBT_ACCEPT) { + read_unlock_bh(&table->lock); + return NF_ACCEPT; + } + read_unlock_bh(&table->lock); + return NF_DROP; +} + +/* If it succeeds, returns element and locks mutex */ +static inline void * +find_inlist_lock_noload(struct net *net, const char *name, int *error, + struct mutex *mutex) +{ + struct ebt_pernet *ebt_net = net_generic(net, ebt_pernet_id); + struct ebt_template *tmpl; + struct ebt_table *table; + + mutex_lock(mutex); + list_for_each_entry(table, &ebt_net->tables, list) { + if (strcmp(table->name, name) == 0) + return table; + } + + list_for_each_entry(tmpl, &template_tables, list) { + if (strcmp(name, tmpl->name) == 0) { + struct module *owner = tmpl->owner; + + if (!try_module_get(owner)) + goto out; + + mutex_unlock(mutex); + + *error = tmpl->table_init(net); + if (*error) { + module_put(owner); + return NULL; + } + + mutex_lock(mutex); + module_put(owner); + break; + } + } + + list_for_each_entry(table, &ebt_net->tables, list) { + if (strcmp(table->name, name) == 0) + return table; + } + +out: + *error = -ENOENT; + 
mutex_unlock(mutex); + return NULL; +} + +static void * +find_inlist_lock(struct net *net, const char *name, const char *prefix, + int *error, struct mutex *mutex) +{ + return try_then_request_module( + find_inlist_lock_noload(net, name, error, mutex), + "%s%s", prefix, name); +} + +static inline struct ebt_table * +find_table_lock(struct net *net, const char *name, int *error, + struct mutex *mutex) +{ + return find_inlist_lock(net, name, "ebtable_", error, mutex); +} + +static inline void ebt_free_table_info(struct ebt_table_info *info) +{ + int i; + + if (info->chainstack) { + for_each_possible_cpu(i) + vfree(info->chainstack[i]); + vfree(info->chainstack); + } +} +static inline int +ebt_check_match(struct ebt_entry_match *m, struct xt_mtchk_param *par, + unsigned int *cnt) +{ + const struct ebt_entry *e = par->entryinfo; + struct xt_match *match; + size_t left = ((char *)e + e->watchers_offset) - (char *)m; + int ret; + + if (left < sizeof(struct ebt_entry_match) || + left - sizeof(struct ebt_entry_match) < m->match_size) + return -EINVAL; + + match = xt_find_match(NFPROTO_BRIDGE, m->u.name, m->u.revision); + if (IS_ERR(match) || match->family != NFPROTO_BRIDGE) { + if (!IS_ERR(match)) + module_put(match->me); + request_module("ebt_%s", m->u.name); + match = xt_find_match(NFPROTO_BRIDGE, m->u.name, m->u.revision); + } + if (IS_ERR(match)) + return PTR_ERR(match); + m->u.match = match; + + par->match = match; + par->matchinfo = m->data; + ret = xt_check_match(par, m->match_size, + ntohs(e->ethproto), e->invflags & EBT_IPROTO); + if (ret < 0) { + module_put(match->me); + return ret; + } + + (*cnt)++; + return 0; +} + +static inline int +ebt_check_watcher(struct ebt_entry_watcher *w, struct xt_tgchk_param *par, + unsigned int *cnt) +{ + const struct ebt_entry *e = par->entryinfo; + struct xt_target *watcher; + size_t left = ((char *)e + e->target_offset) - (char *)w; + int ret; + + if (left < sizeof(struct ebt_entry_watcher) || + left - sizeof(struct ebt_entry_watcher) < w->watcher_size) + return -EINVAL; + + watcher = xt_request_find_target(NFPROTO_BRIDGE, w->u.name, 0); + if (IS_ERR(watcher)) + return PTR_ERR(watcher); + + if (watcher->family != NFPROTO_BRIDGE) { + module_put(watcher->me); + return -ENOENT; + } + + w->u.watcher = watcher; + + par->target = watcher; + par->targinfo = w->data; + ret = xt_check_target(par, w->watcher_size, + ntohs(e->ethproto), e->invflags & EBT_IPROTO); + if (ret < 0) { + module_put(watcher->me); + return ret; + } + + (*cnt)++; + return 0; +} + +static int ebt_verify_pointers(const struct ebt_replace *repl, + struct ebt_table_info *newinfo) +{ + unsigned int limit = repl->entries_size; + unsigned int valid_hooks = repl->valid_hooks; + unsigned int offset = 0; + int i; + + for (i = 0; i < NF_BR_NUMHOOKS; i++) + newinfo->hook_entry[i] = NULL; + + newinfo->entries_size = repl->entries_size; + newinfo->nentries = repl->nentries; + + while (offset < limit) { + size_t left = limit - offset; + struct ebt_entry *e = (void *)newinfo->entries + offset; + + if (left < sizeof(unsigned int)) + break; + + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if ((valid_hooks & (1 << i)) == 0) + continue; + if ((char __user *)repl->hook_entry[i] == + repl->entries + offset) + break; + } + + if (i != NF_BR_NUMHOOKS || !(e->bitmask & EBT_ENTRY_OR_ENTRIES)) { + if (e->bitmask != 0) { + /* we make userspace set this right, + * so there is no misunderstanding + */ + return -EINVAL; + } + if (i != NF_BR_NUMHOOKS) + newinfo->hook_entry[i] = (struct ebt_entries *)e; + if (left < 
sizeof(struct ebt_entries)) + break; + offset += sizeof(struct ebt_entries); + } else { + if (left < sizeof(struct ebt_entry)) + break; + if (left < e->next_offset) + break; + if (e->next_offset < sizeof(struct ebt_entry)) + return -EINVAL; + offset += e->next_offset; + } + } + if (offset != limit) + return -EINVAL; + + /* check if all valid hooks have a chain */ + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if (!newinfo->hook_entry[i] && + (valid_hooks & (1 << i))) + return -EINVAL; + } + return 0; +} + +/* this one is very careful, as it is the first function + * to parse the userspace data + */ +static inline int +ebt_check_entry_size_and_hooks(const struct ebt_entry *e, + const struct ebt_table_info *newinfo, + unsigned int *n, unsigned int *cnt, + unsigned int *totalcnt, unsigned int *udc_cnt) +{ + int i; + + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if ((void *)e == (void *)newinfo->hook_entry[i]) + break; + } + /* beginning of a new chain + * if i == NF_BR_NUMHOOKS it must be a user defined chain + */ + if (i != NF_BR_NUMHOOKS || !e->bitmask) { + /* this checks if the previous chain has as many entries + * as it said it has + */ + if (*n != *cnt) + return -EINVAL; + + if (((struct ebt_entries *)e)->policy != EBT_DROP && + ((struct ebt_entries *)e)->policy != EBT_ACCEPT) { + /* only RETURN from udc */ + if (i != NF_BR_NUMHOOKS || + ((struct ebt_entries *)e)->policy != EBT_RETURN) + return -EINVAL; + } + if (i == NF_BR_NUMHOOKS) /* it's a user defined chain */ + (*udc_cnt)++; + if (((struct ebt_entries *)e)->counter_offset != *totalcnt) + return -EINVAL; + *n = ((struct ebt_entries *)e)->nentries; + *cnt = 0; + return 0; + } + /* a plain old entry, heh */ + if (sizeof(struct ebt_entry) > e->watchers_offset || + e->watchers_offset > e->target_offset || + e->target_offset >= e->next_offset) + return -EINVAL; + + /* this is not checked anywhere else */ + if (e->next_offset - e->target_offset < sizeof(struct ebt_entry_target)) + return -EINVAL; + + (*cnt)++; + (*totalcnt)++; + return 0; +} + +struct ebt_cl_stack { + struct ebt_chainstack cs; + int from; + unsigned int hookmask; +}; + +/* We need these positions to check that the jumps to a different part of the + * entries is a jump to the beginning of a new chain. 
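+ * (a standard verdict >= 0 is an offset into the entries area; check_chainloops() uses the positions gathered here to verify that such a jump lands on a user defined chain header)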
+ */ +static inline int +ebt_get_udc_positions(struct ebt_entry *e, struct ebt_table_info *newinfo, + unsigned int *n, struct ebt_cl_stack *udc) +{ + int i; + + /* we're only interested in chain starts */ + if (e->bitmask) + return 0; + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if (newinfo->hook_entry[i] == (struct ebt_entries *)e) + break; + } + /* only care about udc */ + if (i != NF_BR_NUMHOOKS) + return 0; + + udc[*n].cs.chaininfo = (struct ebt_entries *)e; + /* these initialisations are depended on later in check_chainloops() */ + udc[*n].cs.n = 0; + udc[*n].hookmask = 0; + + (*n)++; + return 0; +} + +static inline int +ebt_cleanup_match(struct ebt_entry_match *m, struct net *net, unsigned int *i) +{ + struct xt_mtdtor_param par; + + if (i && (*i)-- == 0) + return 1; + + par.net = net; + par.match = m->u.match; + par.matchinfo = m->data; + par.family = NFPROTO_BRIDGE; + if (par.match->destroy != NULL) + par.match->destroy(&par); + module_put(par.match->me); + return 0; +} + +static inline int +ebt_cleanup_watcher(struct ebt_entry_watcher *w, struct net *net, unsigned int *i) +{ + struct xt_tgdtor_param par; + + if (i && (*i)-- == 0) + return 1; + + par.net = net; + par.target = w->u.watcher; + par.targinfo = w->data; + par.family = NFPROTO_BRIDGE; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); + return 0; +} + +static inline int +ebt_cleanup_entry(struct ebt_entry *e, struct net *net, unsigned int *cnt) +{ + struct xt_tgdtor_param par; + struct ebt_entry_target *t; + + if (e->bitmask == 0) + return 0; + /* we're done */ + if (cnt && (*cnt)-- == 0) + return 1; + EBT_WATCHER_ITERATE(e, ebt_cleanup_watcher, net, NULL); + EBT_MATCH_ITERATE(e, ebt_cleanup_match, net, NULL); + t = ebt_get_target(e); + + par.net = net; + par.target = t->u.target; + par.targinfo = t->data; + par.family = NFPROTO_BRIDGE; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); + return 0; +} + +static inline int +ebt_check_entry(struct ebt_entry *e, struct net *net, + const struct ebt_table_info *newinfo, + const char *name, unsigned int *cnt, + struct ebt_cl_stack *cl_s, unsigned int udc_cnt) +{ + struct ebt_entry_target *t; + struct xt_target *target; + unsigned int i, j, hook = 0, hookmask = 0; + size_t gap; + int ret; + struct xt_mtchk_param mtpar; + struct xt_tgchk_param tgpar; + + /* don't mess with the struct ebt_entries */ + if (e->bitmask == 0) + return 0; + + if (e->bitmask & ~EBT_F_MASK) + return -EINVAL; + + if (e->invflags & ~EBT_INV_MASK) + return -EINVAL; + + if ((e->bitmask & EBT_NOPROTO) && (e->bitmask & EBT_802_3)) + return -EINVAL; + + /* what hook do we belong to? 
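+ * (i.e. find the last base chain header that starts before this entry)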
*/ + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if (!newinfo->hook_entry[i]) + continue; + if ((char *)newinfo->hook_entry[i] < (char *)e) + hook = i; + else + break; + } + /* (1 << NF_BR_NUMHOOKS) tells the check functions the rule is on + * a base chain + */ + if (i < NF_BR_NUMHOOKS) + hookmask = (1 << hook) | (1 << NF_BR_NUMHOOKS); + else { + for (i = 0; i < udc_cnt; i++) + if ((char *)(cl_s[i].cs.chaininfo) > (char *)e) + break; + if (i == 0) + hookmask = (1 << hook) | (1 << NF_BR_NUMHOOKS); + else + hookmask = cl_s[i - 1].hookmask; + } + i = 0; + + memset(&mtpar, 0, sizeof(mtpar)); + memset(&tgpar, 0, sizeof(tgpar)); + mtpar.net = tgpar.net = net; + mtpar.table = tgpar.table = name; + mtpar.entryinfo = tgpar.entryinfo = e; + mtpar.hook_mask = tgpar.hook_mask = hookmask; + mtpar.family = tgpar.family = NFPROTO_BRIDGE; + ret = EBT_MATCH_ITERATE(e, ebt_check_match, &mtpar, &i); + if (ret != 0) + goto cleanup_matches; + j = 0; + ret = EBT_WATCHER_ITERATE(e, ebt_check_watcher, &tgpar, &j); + if (ret != 0) + goto cleanup_watchers; + t = ebt_get_target(e); + gap = e->next_offset - e->target_offset; + + target = xt_request_find_target(NFPROTO_BRIDGE, t->u.name, 0); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto cleanup_watchers; + } + + /* Reject UNSPEC, xtables verdicts/return values are incompatible */ + if (target->family != NFPROTO_BRIDGE) { + module_put(target->me); + ret = -ENOENT; + goto cleanup_watchers; + } + + t->u.target = target; + if (t->u.target == &ebt_standard_target) { + if (gap < sizeof(struct ebt_standard_target)) { + ret = -EFAULT; + goto cleanup_watchers; + } + if (((struct ebt_standard_target *)t)->verdict < + -NUM_STANDARD_TARGETS) { + ret = -EFAULT; + goto cleanup_watchers; + } + } else if (t->target_size > gap - sizeof(struct ebt_entry_target)) { + module_put(t->u.target->me); + ret = -EFAULT; + goto cleanup_watchers; + } + + tgpar.target = target; + tgpar.targinfo = t->data; + ret = xt_check_target(&tgpar, t->target_size, + ntohs(e->ethproto), e->invflags & EBT_IPROTO); + if (ret < 0) { + module_put(target->me); + goto cleanup_watchers; + } + (*cnt)++; + return 0; +cleanup_watchers: + EBT_WATCHER_ITERATE(e, ebt_cleanup_watcher, net, &j); +cleanup_matches: + EBT_MATCH_ITERATE(e, ebt_cleanup_match, net, &i); + return ret; +} + +/* checks for loops and sets the hook mask for udc + * the hook mask for udc tells us from which base chains the udc can be + * accessed. 
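+ * (for example, a user defined chain that is only jumped to from the PREROUTING base chain ends up with hookmask 1 << NF_BR_PRE_ROUTING)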
This mask is a parameter to the check() functions of the extensions + */ +static int check_chainloops(const struct ebt_entries *chain, struct ebt_cl_stack *cl_s, + unsigned int udc_cnt, unsigned int hooknr, char *base) +{ + int i, chain_nr = -1, pos = 0, nentries = chain->nentries, verdict; + const struct ebt_entry *e = (struct ebt_entry *)chain->data; + const struct ebt_entry_target *t; + + while (pos < nentries || chain_nr != -1) { + /* end of udc, go back one 'recursion' step */ + if (pos == nentries) { + /* put back values of the time when this chain was called */ + e = cl_s[chain_nr].cs.e; + if (cl_s[chain_nr].from != -1) + nentries = + cl_s[cl_s[chain_nr].from].cs.chaininfo->nentries; + else + nentries = chain->nentries; + pos = cl_s[chain_nr].cs.n; + /* make sure we won't see a loop that isn't one */ + cl_s[chain_nr].cs.n = 0; + chain_nr = cl_s[chain_nr].from; + if (pos == nentries) + continue; + } + t = ebt_get_target_c(e); + if (strcmp(t->u.name, EBT_STANDARD_TARGET)) + goto letscontinue; + if (e->target_offset + sizeof(struct ebt_standard_target) > + e->next_offset) + return -1; + + verdict = ((struct ebt_standard_target *)t)->verdict; + if (verdict >= 0) { /* jump to another chain */ + struct ebt_entries *hlp2 = + (struct ebt_entries *)(base + verdict); + for (i = 0; i < udc_cnt; i++) + if (hlp2 == cl_s[i].cs.chaininfo) + break; + /* bad destination or loop */ + if (i == udc_cnt) + return -1; + + if (cl_s[i].cs.n) + return -1; + + if (cl_s[i].hookmask & (1 << hooknr)) + goto letscontinue; + /* this can't be 0, so the loop test is correct */ + cl_s[i].cs.n = pos + 1; + pos = 0; + cl_s[i].cs.e = ebt_next_entry(e); + e = (struct ebt_entry *)(hlp2->data); + nentries = hlp2->nentries; + cl_s[i].from = chain_nr; + chain_nr = i; + /* this udc is accessible from the base chain for hooknr */ + cl_s[i].hookmask |= (1 << hooknr); + continue; + } +letscontinue: + e = ebt_next_entry(e); + pos++; + } + return 0; +} + +/* do the parsing of the table/chains/entries/matches/watchers/targets, heh */ +static int translate_table(struct net *net, const char *name, + struct ebt_table_info *newinfo) +{ + unsigned int i, j, k, udc_cnt; + int ret; + struct ebt_cl_stack *cl_s = NULL; /* used in the checking for chain loops */ + + i = 0; + while (i < NF_BR_NUMHOOKS && !newinfo->hook_entry[i]) + i++; + if (i == NF_BR_NUMHOOKS) + return -EINVAL; + + if (newinfo->hook_entry[i] != (struct ebt_entries *)newinfo->entries) + return -EINVAL; + + /* make sure chains are ordered after each other in same order + * as their corresponding hooks + */ + for (j = i + 1; j < NF_BR_NUMHOOKS; j++) { + if (!newinfo->hook_entry[j]) + continue; + if (newinfo->hook_entry[j] <= newinfo->hook_entry[i]) + return -EINVAL; + + i = j; + } + + /* do some early checkings and initialize some things */ + i = 0; /* holds the expected nr. of entries for the chain */ + j = 0; /* holds the up to now counted entries for the chain */ + k = 0; /* holds the total nr. of entries, should equal + * newinfo->nentries afterwards + */ + udc_cnt = 0; /* will hold the nr. 
of user defined chains (udc) */ + ret = EBT_ENTRY_ITERATE(newinfo->entries, newinfo->entries_size, + ebt_check_entry_size_and_hooks, newinfo, + &i, &j, &k, &udc_cnt); + + if (ret != 0) + return ret; + + if (i != j) + return -EINVAL; + + if (k != newinfo->nentries) + return -EINVAL; + + /* get the location of the udc, put them in an array + * while we're at it, allocate the chainstack + */ + if (udc_cnt) { + /* this will get free'd in do_replace()/ebt_register_table() + * if an error occurs + */ + newinfo->chainstack = + vmalloc(array_size(nr_cpu_ids, + sizeof(*(newinfo->chainstack)))); + if (!newinfo->chainstack) + return -ENOMEM; + for_each_possible_cpu(i) { + newinfo->chainstack[i] = + vmalloc_node(array_size(udc_cnt, + sizeof(*(newinfo->chainstack[0]))), + cpu_to_node(i)); + if (!newinfo->chainstack[i]) { + while (i) + vfree(newinfo->chainstack[--i]); + vfree(newinfo->chainstack); + newinfo->chainstack = NULL; + return -ENOMEM; + } + } + + cl_s = vmalloc(array_size(udc_cnt, sizeof(*cl_s))); + if (!cl_s) + return -ENOMEM; + i = 0; /* the i'th udc */ + EBT_ENTRY_ITERATE(newinfo->entries, newinfo->entries_size, + ebt_get_udc_positions, newinfo, &i, cl_s); + /* sanity check */ + if (i != udc_cnt) { + vfree(cl_s); + return -EFAULT; + } + } + + /* Check for loops */ + for (i = 0; i < NF_BR_NUMHOOKS; i++) + if (newinfo->hook_entry[i]) + if (check_chainloops(newinfo->hook_entry[i], + cl_s, udc_cnt, i, newinfo->entries)) { + vfree(cl_s); + return -EINVAL; + } + + /* we now know the following (along with E=mc²): + * - the nr of entries in each chain is right + * - the size of the allocated space is right + * - all valid hooks have a corresponding chain + * - there are no loops + * - wrong data can still be on the level of a single entry + * - could be there are jumps to places that are not the + * beginning of a chain. This can only occur in chains that + * are not accessible from any base chains, so we don't care. 
+ */ + + /* used to know what we need to clean up if something goes wrong */ + i = 0; + ret = EBT_ENTRY_ITERATE(newinfo->entries, newinfo->entries_size, + ebt_check_entry, net, newinfo, name, &i, cl_s, udc_cnt); + if (ret != 0) { + EBT_ENTRY_ITERATE(newinfo->entries, newinfo->entries_size, + ebt_cleanup_entry, net, &i); + } + vfree(cl_s); + return ret; +} + +/* called under write_lock */ +static void get_counters(const struct ebt_counter *oldcounters, + struct ebt_counter *counters, unsigned int nentries) +{ + int i, cpu; + struct ebt_counter *counter_base; + + /* counters of cpu 0 */ + memcpy(counters, oldcounters, + sizeof(struct ebt_counter) * nentries); + + /* add other counters to those of cpu 0 */ + for_each_possible_cpu(cpu) { + if (cpu == 0) + continue; + counter_base = COUNTER_BASE(oldcounters, nentries, cpu); + for (i = 0; i < nentries; i++) + ADD_COUNTER(counters[i], counter_base[i].bcnt, + counter_base[i].pcnt); + } +} + +static int do_replace_finish(struct net *net, struct ebt_replace *repl, + struct ebt_table_info *newinfo) +{ + int ret; + struct ebt_counter *counterstmp = NULL; + /* used to be able to unlock earlier */ + struct ebt_table_info *table; + struct ebt_table *t; + + /* the user wants counters back + * the check on the size is done later, when we have the lock + */ + if (repl->num_counters) { + unsigned long size = repl->num_counters * sizeof(*counterstmp); + counterstmp = vmalloc(size); + if (!counterstmp) + return -ENOMEM; + } + + newinfo->chainstack = NULL; + ret = ebt_verify_pointers(repl, newinfo); + if (ret != 0) + goto free_counterstmp; + + ret = translate_table(net, repl->name, newinfo); + + if (ret != 0) + goto free_counterstmp; + + t = find_table_lock(net, repl->name, &ret, &ebt_mutex); + if (!t) { + ret = -ENOENT; + goto free_iterate; + } + + if (repl->valid_hooks != t->valid_hooks) { + ret = -EINVAL; + goto free_unlock; + } + + if (repl->num_counters && repl->num_counters != t->private->nentries) { + ret = -EINVAL; + goto free_unlock; + } + + /* we have the mutex lock, so no danger in reading this pointer */ + table = t->private; + /* make sure the table can only be rmmod'ed if it contains no rules */ + if (!table->nentries && newinfo->nentries && !try_module_get(t->me)) { + ret = -ENOENT; + goto free_unlock; + } else if (table->nentries && !newinfo->nentries) + module_put(t->me); + /* we need an atomic snapshot of the counters */ + write_lock_bh(&t->lock); + if (repl->num_counters) + get_counters(t->private->counters, counterstmp, + t->private->nentries); + + t->private = newinfo; + write_unlock_bh(&t->lock); + mutex_unlock(&ebt_mutex); + /* so, a user can change the chains while having messed up her counter + * allocation. Only reason why this is done is because this way the lock + * is held only once, while this doesn't bring the kernel into a + * dangerous state. 
+ */ + if (repl->num_counters && + copy_to_user(repl->counters, counterstmp, + array_size(repl->num_counters, sizeof(struct ebt_counter)))) { + /* Silent error, can't fail, new table is already in place */ + net_warn_ratelimited("ebtables: counters copy to user failed while replacing table\n"); + } + + /* decrease module count and free resources */ + EBT_ENTRY_ITERATE(table->entries, table->entries_size, + ebt_cleanup_entry, net, NULL); + + vfree(table->entries); + ebt_free_table_info(table); + vfree(table); + vfree(counterstmp); + + audit_log_nfcfg(repl->name, AF_BRIDGE, repl->nentries, + AUDIT_XT_OP_REPLACE, GFP_KERNEL); + return 0; + +free_unlock: + mutex_unlock(&ebt_mutex); +free_iterate: + EBT_ENTRY_ITERATE(newinfo->entries, newinfo->entries_size, + ebt_cleanup_entry, net, NULL); +free_counterstmp: + vfree(counterstmp); + /* can be initialized in translate_table() */ + ebt_free_table_info(newinfo); + return ret; +} + +/* replace the table */ +static int do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret, countersize; + struct ebt_table_info *newinfo; + struct ebt_replace tmp; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + if (len != sizeof(tmp) + tmp.entries_size) + return -EINVAL; + + if (tmp.entries_size == 0) + return -EINVAL; + + /* overflow check */ + if (tmp.nentries >= ((INT_MAX - sizeof(struct ebt_table_info)) / + NR_CPUS - SMP_CACHE_BYTES) / sizeof(struct ebt_counter)) + return -ENOMEM; + if (tmp.num_counters >= INT_MAX / sizeof(struct ebt_counter)) + return -ENOMEM; + + tmp.name[sizeof(tmp.name) - 1] = 0; + + countersize = COUNTER_OFFSET(tmp.nentries) * nr_cpu_ids; + newinfo = __vmalloc(sizeof(*newinfo) + countersize, GFP_KERNEL_ACCOUNT); + if (!newinfo) + return -ENOMEM; + + if (countersize) + memset(newinfo->counters, 0, countersize); + + newinfo->entries = __vmalloc(tmp.entries_size, GFP_KERNEL_ACCOUNT); + if (!newinfo->entries) { + ret = -ENOMEM; + goto free_newinfo; + } + if (copy_from_user( + newinfo->entries, tmp.entries, tmp.entries_size) != 0) { + ret = -EFAULT; + goto free_entries; + } + + ret = do_replace_finish(net, &tmp, newinfo); + if (ret == 0) + return ret; +free_entries: + vfree(newinfo->entries); +free_newinfo: + vfree(newinfo); + return ret; +} + +static void __ebt_unregister_table(struct net *net, struct ebt_table *table) +{ + mutex_lock(&ebt_mutex); + list_del(&table->list); + mutex_unlock(&ebt_mutex); + audit_log_nfcfg(table->name, AF_BRIDGE, table->private->nentries, + AUDIT_XT_OP_UNREGISTER, GFP_KERNEL); + EBT_ENTRY_ITERATE(table->private->entries, table->private->entries_size, + ebt_cleanup_entry, net, NULL); + if (table->private->nentries) + module_put(table->me); + vfree(table->private->entries); + ebt_free_table_info(table->private); + vfree(table->private); + kfree(table->ops); + kfree(table); +} + +int ebt_register_table(struct net *net, const struct ebt_table *input_table, + const struct nf_hook_ops *template_ops) +{ + struct ebt_pernet *ebt_net = net_generic(net, ebt_pernet_id); + struct ebt_table_info *newinfo; + struct ebt_table *t, *table; + struct nf_hook_ops *ops; + unsigned int num_ops; + struct ebt_replace_kernel *repl; + int ret, i, countersize; + void *p; + + if (input_table == NULL || (repl = input_table->table) == NULL || + repl->entries == NULL || repl->entries_size == 0 || + repl->counters != NULL || input_table->private != NULL) + return -EINVAL; + + /* Don't add one table to multiple lists. 
*/ + table = kmemdup(input_table, sizeof(struct ebt_table), GFP_KERNEL); + if (!table) { + ret = -ENOMEM; + goto out; + } + + countersize = COUNTER_OFFSET(repl->nentries) * nr_cpu_ids; + newinfo = vmalloc(sizeof(*newinfo) + countersize); + ret = -ENOMEM; + if (!newinfo) + goto free_table; + + p = vmalloc(repl->entries_size); + if (!p) + goto free_newinfo; + + memcpy(p, repl->entries, repl->entries_size); + newinfo->entries = p; + + newinfo->entries_size = repl->entries_size; + newinfo->nentries = repl->nentries; + + if (countersize) + memset(newinfo->counters, 0, countersize); + + /* fill in newinfo and parse the entries */ + newinfo->chainstack = NULL; + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + if ((repl->valid_hooks & (1 << i)) == 0) + newinfo->hook_entry[i] = NULL; + else + newinfo->hook_entry[i] = p + + ((char *)repl->hook_entry[i] - repl->entries); + } + ret = translate_table(net, repl->name, newinfo); + if (ret != 0) + goto free_chainstack; + + table->private = newinfo; + rwlock_init(&table->lock); + mutex_lock(&ebt_mutex); + list_for_each_entry(t, &ebt_net->tables, list) { + if (strcmp(t->name, table->name) == 0) { + ret = -EEXIST; + goto free_unlock; + } + } + + /* Hold a reference count if the chains aren't empty */ + if (newinfo->nentries && !try_module_get(table->me)) { + ret = -ENOENT; + goto free_unlock; + } + + num_ops = hweight32(table->valid_hooks); + if (num_ops == 0) { + ret = -EINVAL; + goto free_unlock; + } + + ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + if (!ops) { + ret = -ENOMEM; + if (newinfo->nentries) + module_put(table->me); + goto free_unlock; + } + + for (i = 0; i < num_ops; i++) + ops[i].priv = table; + + list_add(&table->list, &ebt_net->tables); + mutex_unlock(&ebt_mutex); + + table->ops = ops; + ret = nf_register_net_hooks(net, ops, num_ops); + if (ret) + __ebt_unregister_table(net, table); + + audit_log_nfcfg(repl->name, AF_BRIDGE, repl->nentries, + AUDIT_XT_OP_REGISTER, GFP_KERNEL); + return ret; +free_unlock: + mutex_unlock(&ebt_mutex); +free_chainstack: + ebt_free_table_info(newinfo); + vfree(newinfo->entries); +free_newinfo: + vfree(newinfo); +free_table: + kfree(table); +out: + return ret; +} + +int ebt_register_template(const struct ebt_table *t, int (*table_init)(struct net *net)) +{ + struct ebt_template *tmpl; + + mutex_lock(&ebt_mutex); + list_for_each_entry(tmpl, &template_tables, list) { + if (WARN_ON_ONCE(strcmp(t->name, tmpl->name) == 0)) { + mutex_unlock(&ebt_mutex); + return -EEXIST; + } + } + + tmpl = kzalloc(sizeof(*tmpl), GFP_KERNEL); + if (!tmpl) { + mutex_unlock(&ebt_mutex); + return -ENOMEM; + } + + tmpl->table_init = table_init; + strscpy(tmpl->name, t->name, sizeof(tmpl->name)); + tmpl->owner = t->me; + list_add(&tmpl->list, &template_tables); + + mutex_unlock(&ebt_mutex); + return 0; +} +EXPORT_SYMBOL(ebt_register_template); + +void ebt_unregister_template(const struct ebt_table *t) +{ + struct ebt_template *tmpl; + + mutex_lock(&ebt_mutex); + list_for_each_entry(tmpl, &template_tables, list) { + if (strcmp(t->name, tmpl->name)) + continue; + + list_del(&tmpl->list); + mutex_unlock(&ebt_mutex); + kfree(tmpl); + return; + } + + mutex_unlock(&ebt_mutex); + WARN_ON_ONCE(1); +} +EXPORT_SYMBOL(ebt_unregister_template); + +static struct ebt_table *__ebt_find_table(struct net *net, const char *name) +{ + struct ebt_pernet *ebt_net = net_generic(net, ebt_pernet_id); + struct ebt_table *t; + + mutex_lock(&ebt_mutex); + + list_for_each_entry(t, &ebt_net->tables, list) { + if (strcmp(t->name, name) == 0) { + 
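+			/* name matches: return the live table with ebt_mutex released */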
mutex_unlock(&ebt_mutex); + return t; + } + } + + mutex_unlock(&ebt_mutex); + return NULL; +} + +void ebt_unregister_table_pre_exit(struct net *net, const char *name) +{ + struct ebt_table *table = __ebt_find_table(net, name); + + if (table) + nf_unregister_net_hooks(net, table->ops, hweight32(table->valid_hooks)); +} +EXPORT_SYMBOL(ebt_unregister_table_pre_exit); + +void ebt_unregister_table(struct net *net, const char *name) +{ + struct ebt_table *table = __ebt_find_table(net, name); + + if (table) + __ebt_unregister_table(net, table); +} + +/* userspace just supplied us with counters */ +static int do_update_counters(struct net *net, const char *name, + struct ebt_counter __user *counters, + unsigned int num_counters, unsigned int len) +{ + int i, ret; + struct ebt_counter *tmp; + struct ebt_table *t; + + if (num_counters == 0) + return -EINVAL; + + tmp = vmalloc(array_size(num_counters, sizeof(*tmp))); + if (!tmp) + return -ENOMEM; + + t = find_table_lock(net, name, &ret, &ebt_mutex); + if (!t) + goto free_tmp; + + if (num_counters != t->private->nentries) { + ret = -EINVAL; + goto unlock_mutex; + } + + if (copy_from_user(tmp, counters, + array_size(num_counters, sizeof(*counters)))) { + ret = -EFAULT; + goto unlock_mutex; + } + + /* we want an atomic add of the counters */ + write_lock_bh(&t->lock); + + /* we add to the counters of the first cpu */ + for (i = 0; i < num_counters; i++) + ADD_COUNTER(t->private->counters[i], tmp[i].bcnt, tmp[i].pcnt); + + write_unlock_bh(&t->lock); + ret = 0; +unlock_mutex: + mutex_unlock(&ebt_mutex); +free_tmp: + vfree(tmp); + return ret; +} + +static int update_counters(struct net *net, sockptr_t arg, unsigned int len) +{ + struct ebt_replace hlp; + + if (copy_from_sockptr(&hlp, arg, sizeof(hlp))) + return -EFAULT; + + if (len != sizeof(hlp) + hlp.num_counters * sizeof(struct ebt_counter)) + return -EINVAL; + + return do_update_counters(net, hlp.name, hlp.counters, + hlp.num_counters, len); +} + +static inline int ebt_obj_to_user(char __user *um, const char *_name, + const char *data, int entrysize, + int usersize, int datasize, u8 revision) +{ + char name[EBT_EXTENSION_MAXNAMELEN] = {0}; + + /* ebtables expects 31 bytes long names but xt_match names are 29 bytes + * long. Copy 29 bytes and fill remaining bytes with zeroes. 
+ */ + strscpy(name, _name, sizeof(name)); + if (copy_to_user(um, name, EBT_EXTENSION_MAXNAMELEN) || + put_user(revision, (u8 __user *)(um + EBT_EXTENSION_MAXNAMELEN)) || + put_user(datasize, (int __user *)(um + EBT_EXTENSION_MAXNAMELEN + 1)) || + xt_data_to_user(um + entrysize, data, usersize, datasize, + XT_ALIGN(datasize))) + return -EFAULT; + + return 0; +} + +static inline int ebt_match_to_user(const struct ebt_entry_match *m, + const char *base, char __user *ubase) +{ + return ebt_obj_to_user(ubase + ((char *)m - base), + m->u.match->name, m->data, sizeof(*m), + m->u.match->usersize, m->match_size, + m->u.match->revision); +} + +static inline int ebt_watcher_to_user(const struct ebt_entry_watcher *w, + const char *base, char __user *ubase) +{ + return ebt_obj_to_user(ubase + ((char *)w - base), + w->u.watcher->name, w->data, sizeof(*w), + w->u.watcher->usersize, w->watcher_size, + w->u.watcher->revision); +} + +static inline int ebt_entry_to_user(struct ebt_entry *e, const char *base, + char __user *ubase) +{ + int ret; + char __user *hlp; + const struct ebt_entry_target *t; + + if (e->bitmask == 0) { + /* special case !EBT_ENTRY_OR_ENTRIES */ + if (copy_to_user(ubase + ((char *)e - base), e, + sizeof(struct ebt_entries))) + return -EFAULT; + return 0; + } + + if (copy_to_user(ubase + ((char *)e - base), e, sizeof(*e))) + return -EFAULT; + + hlp = ubase + (((char *)e + e->target_offset) - base); + t = ebt_get_target_c(e); + + ret = EBT_MATCH_ITERATE(e, ebt_match_to_user, base, ubase); + if (ret != 0) + return ret; + ret = EBT_WATCHER_ITERATE(e, ebt_watcher_to_user, base, ubase); + if (ret != 0) + return ret; + ret = ebt_obj_to_user(hlp, t->u.target->name, t->data, sizeof(*t), + t->u.target->usersize, t->target_size, + t->u.target->revision); + if (ret != 0) + return ret; + + return 0; +} + +static int copy_counters_to_user(struct ebt_table *t, + const struct ebt_counter *oldcounters, + void __user *user, unsigned int num_counters, + unsigned int nentries) +{ + struct ebt_counter *counterstmp; + int ret = 0; + + /* userspace might not need the counters */ + if (num_counters == 0) + return 0; + + if (num_counters != nentries) + return -EINVAL; + + counterstmp = vmalloc(array_size(nentries, sizeof(*counterstmp))); + if (!counterstmp) + return -ENOMEM; + + write_lock_bh(&t->lock); + get_counters(oldcounters, counterstmp, nentries); + write_unlock_bh(&t->lock); + + if (copy_to_user(user, counterstmp, + array_size(nentries, sizeof(struct ebt_counter)))) + ret = -EFAULT; + vfree(counterstmp); + return ret; +} + +/* called with ebt_mutex locked */ +static int copy_everything_to_user(struct ebt_table *t, void __user *user, + const int *len, int cmd) +{ + struct ebt_replace tmp; + const struct ebt_counter *oldcounters; + unsigned int entries_size, nentries; + int ret; + char *entries; + + if (cmd == EBT_SO_GET_ENTRIES) { + entries_size = t->private->entries_size; + nentries = t->private->nentries; + entries = t->private->entries; + oldcounters = t->private->counters; + } else { + entries_size = t->table->entries_size; + nentries = t->table->nentries; + entries = t->table->entries; + oldcounters = t->table->counters; + } + + if (copy_from_user(&tmp, user, sizeof(tmp))) + return -EFAULT; + + if (*len != sizeof(struct ebt_replace) + entries_size + + (tmp.num_counters ? 
nentries * sizeof(struct ebt_counter) : 0)) + return -EINVAL; + + if (tmp.nentries != nentries) + return -EINVAL; + + if (tmp.entries_size != entries_size) + return -EINVAL; + + ret = copy_counters_to_user(t, oldcounters, tmp.counters, + tmp.num_counters, nentries); + if (ret) + return ret; + + /* set the match/watcher/target names right */ + return EBT_ENTRY_ITERATE(entries, entries_size, + ebt_entry_to_user, entries, tmp.entries); +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +/* 32 bit-userspace compatibility definitions. */ +struct compat_ebt_replace { + char name[EBT_TABLE_MAXNAMELEN]; + compat_uint_t valid_hooks; + compat_uint_t nentries; + compat_uint_t entries_size; + /* start of the chains */ + compat_uptr_t hook_entry[NF_BR_NUMHOOKS]; + /* nr of counters userspace expects back */ + compat_uint_t num_counters; + /* where the kernel will put the old counters. */ + compat_uptr_t counters; + compat_uptr_t entries; +}; + +/* struct ebt_entry_match, _target and _watcher have same layout */ +struct compat_ebt_entry_mwt { + union { + struct { + char name[EBT_EXTENSION_MAXNAMELEN]; + u8 revision; + }; + compat_uptr_t ptr; + } u; + compat_uint_t match_size; + compat_uint_t data[] __aligned(__alignof__(struct compat_ebt_replace)); +}; + +/* account for possible padding between match_size and ->data */ +static int ebt_compat_entry_padsize(void) +{ + BUILD_BUG_ON(sizeof(struct ebt_entry_match) < + sizeof(struct compat_ebt_entry_mwt)); + return (int) sizeof(struct ebt_entry_match) - + sizeof(struct compat_ebt_entry_mwt); +} + +static int ebt_compat_match_offset(const struct xt_match *match, + unsigned int userlen) +{ + /* ebt_among needs special handling. The kernel .matchsize is + * set to -1 at registration time; at runtime an EBT_ALIGN()ed + * value is expected. + * Example: userspace sends 4500, ebt_among.c wants 4504. 
+ */ + if (unlikely(match->matchsize == -1)) + return XT_ALIGN(userlen) - COMPAT_XT_ALIGN(userlen); + return xt_compat_match_offset(match); +} + +static int compat_match_to_user(struct ebt_entry_match *m, void __user **dstptr, + unsigned int *size) +{ + const struct xt_match *match = m->u.match; + struct compat_ebt_entry_mwt __user *cm = *dstptr; + int off = ebt_compat_match_offset(match, m->match_size); + compat_uint_t msize = m->match_size - off; + + if (WARN_ON(off >= m->match_size)) + return -EINVAL; + + if (copy_to_user(cm->u.name, match->name, strlen(match->name) + 1) || + put_user(match->revision, &cm->u.revision) || + put_user(msize, &cm->match_size)) + return -EFAULT; + + if (match->compat_to_user) { + if (match->compat_to_user(cm->data, m->data)) + return -EFAULT; + } else { + if (xt_data_to_user(cm->data, m->data, match->usersize, msize, + COMPAT_XT_ALIGN(msize))) + return -EFAULT; + } + + *size -= ebt_compat_entry_padsize() + off; + *dstptr = cm->data; + *dstptr += msize; + return 0; +} + +static int compat_target_to_user(struct ebt_entry_target *t, + void __user **dstptr, + unsigned int *size) +{ + const struct xt_target *target = t->u.target; + struct compat_ebt_entry_mwt __user *cm = *dstptr; + int off = xt_compat_target_offset(target); + compat_uint_t tsize = t->target_size - off; + + if (WARN_ON(off >= t->target_size)) + return -EINVAL; + + if (copy_to_user(cm->u.name, target->name, strlen(target->name) + 1) || + put_user(target->revision, &cm->u.revision) || + put_user(tsize, &cm->match_size)) + return -EFAULT; + + if (target->compat_to_user) { + if (target->compat_to_user(cm->data, t->data)) + return -EFAULT; + } else { + if (xt_data_to_user(cm->data, t->data, target->usersize, tsize, + COMPAT_XT_ALIGN(tsize))) + return -EFAULT; + } + + *size -= ebt_compat_entry_padsize() + off; + *dstptr = cm->data; + *dstptr += tsize; + return 0; +} + +static int compat_watcher_to_user(struct ebt_entry_watcher *w, + void __user **dstptr, + unsigned int *size) +{ + return compat_target_to_user((struct ebt_entry_target *)w, + dstptr, size); +} + +static int compat_copy_entry_to_user(struct ebt_entry *e, void __user **dstptr, + unsigned int *size) +{ + struct ebt_entry_target *t; + struct ebt_entry __user *ce; + u32 watchers_offset, target_offset, next_offset; + compat_uint_t origsize; + int ret; + + if (e->bitmask == 0) { + if (*size < sizeof(struct ebt_entries)) + return -EINVAL; + if (copy_to_user(*dstptr, e, sizeof(struct ebt_entries))) + return -EFAULT; + + *dstptr += sizeof(struct ebt_entries); + *size -= sizeof(struct ebt_entries); + return 0; + } + + if (*size < sizeof(*ce)) + return -EINVAL; + + ce = *dstptr; + if (copy_to_user(ce, e, sizeof(*ce))) + return -EFAULT; + + origsize = *size; + *dstptr += sizeof(*ce); + + ret = EBT_MATCH_ITERATE(e, compat_match_to_user, dstptr, size); + if (ret) + return ret; + watchers_offset = e->watchers_offset - (origsize - *size); + + ret = EBT_WATCHER_ITERATE(e, compat_watcher_to_user, dstptr, size); + if (ret) + return ret; + target_offset = e->target_offset - (origsize - *size); + + t = ebt_get_target(e); + + ret = compat_target_to_user(t, dstptr, size); + if (ret) + return ret; + next_offset = e->next_offset - (origsize - *size); + + if (put_user(watchers_offset, &ce->watchers_offset) || + put_user(target_offset, &ce->target_offset) || + put_user(next_offset, &ce->next_offset)) + return -EFAULT; + + *size -= sizeof(*ce); + return 0; +} + +static int compat_calc_match(struct ebt_entry_match *m, int *off) +{ + *off += 
ebt_compat_match_offset(m->u.match, m->match_size); + *off += ebt_compat_entry_padsize(); + return 0; +} + +static int compat_calc_watcher(struct ebt_entry_watcher *w, int *off) +{ + *off += xt_compat_target_offset(w->u.watcher); + *off += ebt_compat_entry_padsize(); + return 0; +} + +static int compat_calc_entry(const struct ebt_entry *e, + const struct ebt_table_info *info, + const void *base, + struct compat_ebt_replace *newinfo) +{ + const struct ebt_entry_target *t; + unsigned int entry_offset; + int off, ret, i; + + if (e->bitmask == 0) + return 0; + + off = 0; + entry_offset = (void *)e - base; + + EBT_MATCH_ITERATE(e, compat_calc_match, &off); + EBT_WATCHER_ITERATE(e, compat_calc_watcher, &off); + + t = ebt_get_target_c(e); + + off += xt_compat_target_offset(t->u.target); + off += ebt_compat_entry_padsize(); + + newinfo->entries_size -= off; + + ret = xt_compat_add_offset(NFPROTO_BRIDGE, entry_offset, off); + if (ret) + return ret; + + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + const void *hookptr = info->hook_entry[i]; + if (info->hook_entry[i] && + (e < (struct ebt_entry *)(base - hookptr))) { + newinfo->hook_entry[i] -= off; + pr_debug("0x%08X -> 0x%08X\n", + newinfo->hook_entry[i] + off, + newinfo->hook_entry[i]); + } + } + + return 0; +} + +static int ebt_compat_init_offsets(unsigned int number) +{ + if (number > INT_MAX) + return -EINVAL; + + /* also count the base chain policies */ + number += NF_BR_NUMHOOKS; + + return xt_compat_init_offsets(NFPROTO_BRIDGE, number); +} + +static int compat_table_info(const struct ebt_table_info *info, + struct compat_ebt_replace *newinfo) +{ + unsigned int size = info->entries_size; + const void *entries = info->entries; + int ret; + + newinfo->entries_size = size; + ret = ebt_compat_init_offsets(info->nentries); + if (ret) + return ret; + + return EBT_ENTRY_ITERATE(entries, size, compat_calc_entry, info, + entries, newinfo); +} + +static int compat_copy_everything_to_user(struct ebt_table *t, + void __user *user, int *len, int cmd) +{ + struct compat_ebt_replace repl, tmp; + struct ebt_counter *oldcounters; + struct ebt_table_info tinfo; + int ret; + void __user *pos; + + memset(&tinfo, 0, sizeof(tinfo)); + + if (cmd == EBT_SO_GET_ENTRIES) { + tinfo.entries_size = t->private->entries_size; + tinfo.nentries = t->private->nentries; + tinfo.entries = t->private->entries; + oldcounters = t->private->counters; + } else { + tinfo.entries_size = t->table->entries_size; + tinfo.nentries = t->table->nentries; + tinfo.entries = t->table->entries; + oldcounters = t->table->counters; + } + + if (copy_from_user(&tmp, user, sizeof(tmp))) + return -EFAULT; + + if (tmp.nentries != tinfo.nentries || + (tmp.num_counters && tmp.num_counters != tinfo.nentries)) + return -EINVAL; + + memcpy(&repl, &tmp, sizeof(repl)); + if (cmd == EBT_SO_GET_ENTRIES) + ret = compat_table_info(t->private, &repl); + else + ret = compat_table_info(&tinfo, &repl); + if (ret) + return ret; + + if (*len != sizeof(tmp) + repl.entries_size + + (tmp.num_counters? 
tinfo.nentries * sizeof(struct ebt_counter): 0)) { + pr_err("wrong size: *len %d, entries_size %u, replsz %d\n", + *len, tinfo.entries_size, repl.entries_size); + return -EINVAL; + } + + /* userspace might not need the counters */ + ret = copy_counters_to_user(t, oldcounters, compat_ptr(tmp.counters), + tmp.num_counters, tinfo.nentries); + if (ret) + return ret; + + pos = compat_ptr(tmp.entries); + return EBT_ENTRY_ITERATE(tinfo.entries, tinfo.entries_size, + compat_copy_entry_to_user, &pos, &tmp.entries_size); +} + +struct ebt_entries_buf_state { + char *buf_kern_start; /* kernel buffer to copy (translated) data to */ + u32 buf_kern_len; /* total size of kernel buffer */ + u32 buf_kern_offset; /* amount of data copied so far */ + u32 buf_user_offset; /* read position in userspace buffer */ +}; + +static int ebt_buf_count(struct ebt_entries_buf_state *state, unsigned int sz) +{ + state->buf_kern_offset += sz; + return state->buf_kern_offset >= sz ? 0 : -EINVAL; +} + +static int ebt_buf_add(struct ebt_entries_buf_state *state, + const void *data, unsigned int sz) +{ + if (state->buf_kern_start == NULL) + goto count_only; + + if (WARN_ON(state->buf_kern_offset + sz > state->buf_kern_len)) + return -EINVAL; + + memcpy(state->buf_kern_start + state->buf_kern_offset, data, sz); + + count_only: + state->buf_user_offset += sz; + return ebt_buf_count(state, sz); +} + +static int ebt_buf_add_pad(struct ebt_entries_buf_state *state, unsigned int sz) +{ + char *b = state->buf_kern_start; + + if (WARN_ON(b && state->buf_kern_offset > state->buf_kern_len)) + return -EINVAL; + + if (b != NULL && sz > 0) + memset(b + state->buf_kern_offset, 0, sz); + /* do not adjust ->buf_user_offset here, we added kernel-side padding */ + return ebt_buf_count(state, sz); +} + +enum compat_mwt { + EBT_COMPAT_MATCH, + EBT_COMPAT_WATCHER, + EBT_COMPAT_TARGET, +}; + +static int compat_mtw_from_user(const struct compat_ebt_entry_mwt *mwt, + enum compat_mwt compat_mwt, + struct ebt_entries_buf_state *state, + const unsigned char *base) +{ + char name[EBT_EXTENSION_MAXNAMELEN]; + struct xt_match *match; + struct xt_target *wt; + void *dst = NULL; + int off, pad = 0; + unsigned int size_kern, match_size = mwt->match_size; + + if (strscpy(name, mwt->u.name, sizeof(name)) < 0) + return -EINVAL; + + if (state->buf_kern_start) + dst = state->buf_kern_start + state->buf_kern_offset; + + switch (compat_mwt) { + case EBT_COMPAT_MATCH: + match = xt_request_find_match(NFPROTO_BRIDGE, name, + mwt->u.revision); + if (IS_ERR(match)) + return PTR_ERR(match); + + off = ebt_compat_match_offset(match, match_size); + if (dst) { + if (match->compat_from_user) + match->compat_from_user(dst, mwt->data); + else + memcpy(dst, mwt->data, match_size); + } + + size_kern = match->matchsize; + if (unlikely(size_kern == -1)) + size_kern = match_size; + module_put(match->me); + break; + case EBT_COMPAT_WATCHER: + case EBT_COMPAT_TARGET: + wt = xt_request_find_target(NFPROTO_BRIDGE, name, + mwt->u.revision); + if (IS_ERR(wt)) + return PTR_ERR(wt); + off = xt_compat_target_offset(wt); + + if (dst) { + if (wt->compat_from_user) + wt->compat_from_user(dst, mwt->data); + else + memcpy(dst, mwt->data, match_size); + } + + size_kern = wt->targetsize; + module_put(wt->me); + break; + + default: + return -EINVAL; + } + + state->buf_kern_offset += match_size + off; + state->buf_user_offset += match_size; + pad = XT_ALIGN(size_kern) - size_kern; + + if (pad > 0 && dst) { + if (WARN_ON(state->buf_kern_len <= pad)) + return -EINVAL; + if 
(WARN_ON(state->buf_kern_offset - (match_size + off) + size_kern > state->buf_kern_len - pad)) + return -EINVAL; + memset(dst + size_kern, 0, pad); + } + return off + match_size; +} + +/* return size of all matches, watchers or target, including necessary + * alignment and padding. + */ +static int ebt_size_mwt(const struct compat_ebt_entry_mwt *match32, + unsigned int size_left, enum compat_mwt type, + struct ebt_entries_buf_state *state, const void *base) +{ + const char *buf = (const char *)match32; + int growth = 0; + + if (size_left == 0) + return 0; + + do { + struct ebt_entry_match *match_kern; + int ret; + + if (size_left < sizeof(*match32)) + return -EINVAL; + + match_kern = (struct ebt_entry_match *) state->buf_kern_start; + if (match_kern) { + char *tmp; + tmp = state->buf_kern_start + state->buf_kern_offset; + match_kern = (struct ebt_entry_match *) tmp; + } + ret = ebt_buf_add(state, buf, sizeof(*match32)); + if (ret < 0) + return ret; + size_left -= sizeof(*match32); + + /* add padding before match->data (if any) */ + ret = ebt_buf_add_pad(state, ebt_compat_entry_padsize()); + if (ret < 0) + return ret; + + if (match32->match_size > size_left) + return -EINVAL; + + size_left -= match32->match_size; + + ret = compat_mtw_from_user(match32, type, state, base); + if (ret < 0) + return ret; + + if (WARN_ON(ret < match32->match_size)) + return -EINVAL; + growth += ret - match32->match_size; + growth += ebt_compat_entry_padsize(); + + buf += sizeof(*match32); + buf += match32->match_size; + + if (match_kern) + match_kern->match_size = ret; + + match32 = (struct compat_ebt_entry_mwt *) buf; + } while (size_left); + + return growth; +} + +/* called for all ebt_entry structures. */ +static int size_entry_mwt(const struct ebt_entry *entry, const unsigned char *base, + unsigned int *total, + struct ebt_entries_buf_state *state) +{ + unsigned int i, j, startoff, next_expected_off, new_offset = 0; + /* stores match/watchers/targets & offset of next struct ebt_entry: */ + unsigned int offsets[4]; + unsigned int *offsets_update = NULL; + int ret; + char *buf_start; + + if (*total < sizeof(struct ebt_entries)) + return -EINVAL; + + if (!entry->bitmask) { + *total -= sizeof(struct ebt_entries); + return ebt_buf_add(state, entry, sizeof(struct ebt_entries)); + } + if (*total < sizeof(*entry) || entry->next_offset < sizeof(*entry)) + return -EINVAL; + + startoff = state->buf_user_offset; + /* pull in most part of ebt_entry, it does not need to be changed. */ + ret = ebt_buf_add(state, entry, + offsetof(struct ebt_entry, watchers_offset)); + if (ret < 0) + return ret; + + offsets[0] = sizeof(struct ebt_entry); /* matches come first */ + memcpy(&offsets[1], &entry->offsets, sizeof(entry->offsets)); + + if (state->buf_kern_start) { + buf_start = state->buf_kern_start + state->buf_kern_offset; + offsets_update = (unsigned int *) buf_start; + } + ret = ebt_buf_add(state, &offsets[1], + sizeof(offsets) - sizeof(offsets[0])); + if (ret < 0) + return ret; + buf_start = (char *) entry; + /* 0: matches offset, always follows ebt_entry. + * 1: watchers offset, from ebt_entry structure + * 2: target offset, from ebt_entry structure + * 3: next ebt_entry offset, from ebt_entry structure + * + * offsets are relative to beginning of struct ebt_entry (i.e., 0). 
+ */ + for (i = 0; i < 4 ; ++i) { + if (offsets[i] > *total) + return -EINVAL; + + if (i < 3 && offsets[i] == *total) + return -EINVAL; + + if (i == 0) + continue; + if (offsets[i-1] > offsets[i]) + return -EINVAL; + } + + for (i = 0, j = 1 ; j < 4 ; j++, i++) { + struct compat_ebt_entry_mwt *match32; + unsigned int size; + char *buf = buf_start + offsets[i]; + + if (offsets[i] > offsets[j]) + return -EINVAL; + + match32 = (struct compat_ebt_entry_mwt *) buf; + size = offsets[j] - offsets[i]; + ret = ebt_size_mwt(match32, size, i, state, base); + if (ret < 0) + return ret; + new_offset += ret; + if (offsets_update && new_offset) { + pr_debug("change offset %d to %d\n", + offsets_update[i], offsets[j] + new_offset); + offsets_update[i] = offsets[j] + new_offset; + } + } + + if (state->buf_kern_start == NULL) { + unsigned int offset = buf_start - (char *) base; + + ret = xt_compat_add_offset(NFPROTO_BRIDGE, offset, new_offset); + if (ret < 0) + return ret; + } + + next_expected_off = state->buf_user_offset - startoff; + if (next_expected_off != entry->next_offset) + return -EINVAL; + + if (*total < entry->next_offset) + return -EINVAL; + *total -= entry->next_offset; + return 0; +} + +/* repl->entries_size is the size of the ebt_entry blob in userspace. + * It might need more memory when copied to a 64 bit kernel in case + * userspace is 32-bit. So, first task: find out how much memory is needed. + * + * Called before validation is performed. + */ +static int compat_copy_entries(unsigned char *data, unsigned int size_user, + struct ebt_entries_buf_state *state) +{ + unsigned int size_remaining = size_user; + int ret; + + ret = EBT_ENTRY_ITERATE(data, size_user, size_entry_mwt, data, + &size_remaining, state); + if (ret < 0) + return ret; + + if (size_remaining) + return -EINVAL; + + return state->buf_kern_offset; +} + + +static int compat_copy_ebt_replace_from_user(struct ebt_replace *repl, + sockptr_t arg, unsigned int len) +{ + struct compat_ebt_replace tmp; + int i; + + if (len < sizeof(tmp)) + return -EINVAL; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp))) + return -EFAULT; + + if (len != sizeof(tmp) + tmp.entries_size) + return -EINVAL; + + if (tmp.entries_size == 0) + return -EINVAL; + + if (tmp.nentries >= ((INT_MAX - sizeof(struct ebt_table_info)) / + NR_CPUS - SMP_CACHE_BYTES) / sizeof(struct ebt_counter)) + return -ENOMEM; + if (tmp.num_counters >= INT_MAX / sizeof(struct ebt_counter)) + return -ENOMEM; + + memcpy(repl, &tmp, offsetof(struct ebt_replace, hook_entry)); + + /* starting with hook_entry, 32 vs. 
64 bit structures are different */ + for (i = 0; i < NF_BR_NUMHOOKS; i++) + repl->hook_entry[i] = compat_ptr(tmp.hook_entry[i]); + + repl->num_counters = tmp.num_counters; + repl->counters = compat_ptr(tmp.counters); + repl->entries = compat_ptr(tmp.entries); + return 0; +} + +static int compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret, i, countersize, size64; + struct ebt_table_info *newinfo; + struct ebt_replace tmp; + struct ebt_entries_buf_state state; + void *entries_tmp; + + ret = compat_copy_ebt_replace_from_user(&tmp, arg, len); + if (ret) { + /* try real handler in case userland supplied needed padding */ + if (ret == -EINVAL && do_replace(net, arg, len) == 0) + ret = 0; + return ret; + } + + countersize = COUNTER_OFFSET(tmp.nentries) * nr_cpu_ids; + newinfo = vmalloc(sizeof(*newinfo) + countersize); + if (!newinfo) + return -ENOMEM; + + if (countersize) + memset(newinfo->counters, 0, countersize); + + memset(&state, 0, sizeof(state)); + + newinfo->entries = vmalloc(tmp.entries_size); + if (!newinfo->entries) { + ret = -ENOMEM; + goto free_newinfo; + } + if (copy_from_user( + newinfo->entries, tmp.entries, tmp.entries_size) != 0) { + ret = -EFAULT; + goto free_entries; + } + + entries_tmp = newinfo->entries; + + xt_compat_lock(NFPROTO_BRIDGE); + + ret = ebt_compat_init_offsets(tmp.nentries); + if (ret < 0) + goto out_unlock; + + ret = compat_copy_entries(entries_tmp, tmp.entries_size, &state); + if (ret < 0) + goto out_unlock; + + pr_debug("tmp.entries_size %d, kern off %d, user off %d delta %d\n", + tmp.entries_size, state.buf_kern_offset, state.buf_user_offset, + xt_compat_calc_jump(NFPROTO_BRIDGE, tmp.entries_size)); + + size64 = ret; + newinfo->entries = vmalloc(size64); + if (!newinfo->entries) { + vfree(entries_tmp); + ret = -ENOMEM; + goto out_unlock; + } + + memset(&state, 0, sizeof(state)); + state.buf_kern_start = newinfo->entries; + state.buf_kern_len = size64; + + ret = compat_copy_entries(entries_tmp, tmp.entries_size, &state); + if (WARN_ON(ret < 0)) { + vfree(entries_tmp); + goto out_unlock; + } + + vfree(entries_tmp); + tmp.entries_size = size64; + + for (i = 0; i < NF_BR_NUMHOOKS; i++) { + char __user *usrptr; + if (tmp.hook_entry[i]) { + unsigned int delta; + usrptr = (char __user *) tmp.hook_entry[i]; + delta = usrptr - tmp.entries; + usrptr += xt_compat_calc_jump(NFPROTO_BRIDGE, delta); + tmp.hook_entry[i] = (struct ebt_entries __user *)usrptr; + } + } + + xt_compat_flush_offsets(NFPROTO_BRIDGE); + xt_compat_unlock(NFPROTO_BRIDGE); + + ret = do_replace_finish(net, &tmp, newinfo); + if (ret == 0) + return ret; +free_entries: + vfree(newinfo->entries); +free_newinfo: + vfree(newinfo); + return ret; +out_unlock: + xt_compat_flush_offsets(NFPROTO_BRIDGE); + xt_compat_unlock(NFPROTO_BRIDGE); + goto free_entries; +} + +static int compat_update_counters(struct net *net, sockptr_t arg, + unsigned int len) +{ + struct compat_ebt_replace hlp; + + if (copy_from_sockptr(&hlp, arg, sizeof(hlp))) + return -EFAULT; + + /* try real handler in case userland supplied needed padding */ + if (len != sizeof(hlp) + hlp.num_counters * sizeof(struct ebt_counter)) + return update_counters(net, arg, len); + + return do_update_counters(net, hlp.name, compat_ptr(hlp.counters), + hlp.num_counters, len); +} + +static int compat_do_ebt_get_ctl(struct sock *sk, int cmd, + void __user *user, int *len) +{ + int ret; + struct compat_ebt_replace tmp; + struct ebt_table *t; + struct net *net = sock_net(sk); + + if ((cmd == EBT_SO_GET_INFO || cmd == 
EBT_SO_GET_INIT_INFO) && + *len != sizeof(struct compat_ebt_replace)) + return -EINVAL; + + if (copy_from_user(&tmp, user, sizeof(tmp))) + return -EFAULT; + + tmp.name[sizeof(tmp.name) - 1] = '\0'; + + t = find_table_lock(net, tmp.name, &ret, &ebt_mutex); + if (!t) + return ret; + + xt_compat_lock(NFPROTO_BRIDGE); + switch (cmd) { + case EBT_SO_GET_INFO: + tmp.nentries = t->private->nentries; + ret = compat_table_info(t->private, &tmp); + if (ret) + goto out; + tmp.valid_hooks = t->valid_hooks; + + if (copy_to_user(user, &tmp, *len) != 0) { + ret = -EFAULT; + break; + } + ret = 0; + break; + case EBT_SO_GET_INIT_INFO: + tmp.nentries = t->table->nentries; + tmp.entries_size = t->table->entries_size; + tmp.valid_hooks = t->table->valid_hooks; + + if (copy_to_user(user, &tmp, *len) != 0) { + ret = -EFAULT; + break; + } + ret = 0; + break; + case EBT_SO_GET_ENTRIES: + case EBT_SO_GET_INIT_ENTRIES: + /* try real handler first in case of userland-side padding. + * in case we are dealing with an 'ordinary' 32 bit binary + * without 64bit compatibility padding, this will fail right + * after copy_from_user when the *len argument is validated. + * + * the compat_ variant needs to do one pass over the kernel + * data set to adjust for size differences before it the check. + */ + if (copy_everything_to_user(t, user, len, cmd) == 0) + ret = 0; + else + ret = compat_copy_everything_to_user(t, user, len, cmd); + break; + default: + ret = -EINVAL; + } + out: + xt_compat_flush_offsets(NFPROTO_BRIDGE); + xt_compat_unlock(NFPROTO_BRIDGE); + mutex_unlock(&ebt_mutex); + return ret; +} +#endif + +static int do_ebt_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) +{ + struct net *net = sock_net(sk); + struct ebt_replace tmp; + struct ebt_table *t; + int ret; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + /* try real handler in case userland supplied needed padding */ + if (in_compat_syscall() && + ((cmd != EBT_SO_GET_INFO && cmd != EBT_SO_GET_INIT_INFO) || + *len != sizeof(tmp))) + return compat_do_ebt_get_ctl(sk, cmd, user, len); +#endif + + if (copy_from_user(&tmp, user, sizeof(tmp))) + return -EFAULT; + + tmp.name[sizeof(tmp.name) - 1] = '\0'; + + t = find_table_lock(net, tmp.name, &ret, &ebt_mutex); + if (!t) + return ret; + + switch (cmd) { + case EBT_SO_GET_INFO: + case EBT_SO_GET_INIT_INFO: + if (*len != sizeof(struct ebt_replace)) { + ret = -EINVAL; + mutex_unlock(&ebt_mutex); + break; + } + if (cmd == EBT_SO_GET_INFO) { + tmp.nentries = t->private->nentries; + tmp.entries_size = t->private->entries_size; + tmp.valid_hooks = t->valid_hooks; + } else { + tmp.nentries = t->table->nentries; + tmp.entries_size = t->table->entries_size; + tmp.valid_hooks = t->table->valid_hooks; + } + mutex_unlock(&ebt_mutex); + if (copy_to_user(user, &tmp, *len) != 0) { + ret = -EFAULT; + break; + } + ret = 0; + break; + + case EBT_SO_GET_ENTRIES: + case EBT_SO_GET_INIT_ENTRIES: + ret = copy_everything_to_user(t, user, len, cmd); + mutex_unlock(&ebt_mutex); + break; + + default: + mutex_unlock(&ebt_mutex); + ret = -EINVAL; + } + + return ret; +} + +static int do_ebt_set_ctl(struct sock *sk, int cmd, sockptr_t arg, + unsigned int len) +{ + struct net *net = sock_net(sk); + int ret; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case EBT_SO_SET_ENTRIES: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_do_replace(net, arg, len); + else +#endif + ret = do_replace(net, arg, len); 
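+		/* both paths funnel into do_replace_finish(), which validates the
+		 * new entry blob and swaps it in for t->private under ebt_mutex */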
+ break; + case EBT_SO_SET_COUNTERS: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_update_counters(net, arg, len); + else +#endif + ret = update_counters(net, arg, len); + break; + default: + ret = -EINVAL; + } + return ret; +} + +static struct nf_sockopt_ops ebt_sockopts = { + .pf = PF_INET, + .set_optmin = EBT_BASE_CTL, + .set_optmax = EBT_SO_SET_MAX + 1, + .set = do_ebt_set_ctl, + .get_optmin = EBT_BASE_CTL, + .get_optmax = EBT_SO_GET_MAX + 1, + .get = do_ebt_get_ctl, + .owner = THIS_MODULE, +}; + +static int __net_init ebt_pernet_init(struct net *net) +{ + struct ebt_pernet *ebt_net = net_generic(net, ebt_pernet_id); + + INIT_LIST_HEAD(&ebt_net->tables); + return 0; +} + +static struct pernet_operations ebt_net_ops = { + .init = ebt_pernet_init, + .id = &ebt_pernet_id, + .size = sizeof(struct ebt_pernet), +}; + +static int __init ebtables_init(void) +{ + int ret; + + ret = xt_register_target(&ebt_standard_target); + if (ret < 0) + return ret; + ret = nf_register_sockopt(&ebt_sockopts); + if (ret < 0) { + xt_unregister_target(&ebt_standard_target); + return ret; + } + + ret = register_pernet_subsys(&ebt_net_ops); + if (ret < 0) { + nf_unregister_sockopt(&ebt_sockopts); + xt_unregister_target(&ebt_standard_target); + return ret; + } + + return 0; +} + +static void ebtables_fini(void) +{ + nf_unregister_sockopt(&ebt_sockopts); + xt_unregister_target(&ebt_standard_target); + unregister_pernet_subsys(&ebt_net_ops); +} + +EXPORT_SYMBOL(ebt_register_table); +EXPORT_SYMBOL(ebt_unregister_table); +EXPORT_SYMBOL(ebt_do_table); +module_init(ebtables_init); +module_exit(ebtables_fini); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/nf_conntrack_bridge.c b/net/bridge/netfilter/nf_conntrack_bridge.c new file mode 100644 index 000000000..06d94b2c6 --- /dev/null +++ b/net/bridge/netfilter/nf_conntrack_bridge.c @@ -0,0 +1,448 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include "../br_private.h" + +/* Best effort variant of ip_do_fragment which preserves geometry, unless skbuff + * has been linearized or cloned. 
+ */ +static int nf_br_ip_fragment(struct net *net, struct sock *sk, + struct sk_buff *skb, + struct nf_bridge_frag_data *data, + int (*output)(struct net *, struct sock *sk, + const struct nf_bridge_frag_data *data, + struct sk_buff *)) +{ + int frag_max_size = BR_INPUT_SKB_CB(skb)->frag_max_size; + bool mono_delivery_time = skb->mono_delivery_time; + unsigned int hlen, ll_rs, mtu; + ktime_t tstamp = skb->tstamp; + struct ip_frag_state state; + struct iphdr *iph; + int err = 0; + + /* for offloaded checksums cleanup checksum before fragmentation */ + if (skb->ip_summed == CHECKSUM_PARTIAL && + (err = skb_checksum_help(skb))) + goto blackhole; + + iph = ip_hdr(skb); + + /* + * Setup starting values + */ + + hlen = iph->ihl * 4; + frag_max_size -= hlen; + ll_rs = LL_RESERVED_SPACE(skb->dev); + mtu = skb->dev->mtu; + + if (skb_has_frag_list(skb)) { + unsigned int first_len = skb_pagelen(skb); + struct ip_fraglist_iter iter; + struct sk_buff *frag; + + if (first_len - hlen > mtu || + skb_headroom(skb) < ll_rs) + goto blackhole; + + if (skb_cloned(skb)) + goto slow_path; + + skb_walk_frags(skb, frag) { + if (frag->len > mtu || + skb_headroom(frag) < hlen + ll_rs) + goto blackhole; + + if (skb_shared(frag)) + goto slow_path; + } + + ip_fraglist_init(skb, iph, hlen, &iter); + + for (;;) { + if (iter.frag) + ip_fraglist_prepare(skb, &iter); + + skb_set_delivery_time(skb, tstamp, mono_delivery_time); + err = output(net, sk, data, skb); + if (err || !iter.frag) + break; + + skb = ip_fraglist_next(&iter); + } + + if (!err) + return 0; + + kfree_skb_list(iter.frag); + + return err; + } +slow_path: + /* This is a linearized skbuff, the original geometry is lost for us. + * This may also be a clone skbuff, we could preserve the geometry for + * the copies but probably not worth the effort. + */ + ip_frag_init(skb, hlen, ll_rs, frag_max_size, false, &state); + + while (state.left > 0) { + struct sk_buff *skb2; + + skb2 = ip_frag_next(skb, &state); + if (IS_ERR(skb2)) { + err = PTR_ERR(skb2); + goto blackhole; + } + + skb_set_delivery_time(skb2, tstamp, mono_delivery_time); + err = output(net, sk, data, skb2); + if (err) + goto blackhole; + } + consume_skb(skb); + return err; + +blackhole: + kfree_skb(skb); + return 0; +} + +/* ip_defrag() expects IPCB() in place. 
*/ +static void br_skb_cb_save(struct sk_buff *skb, struct br_input_skb_cb *cb, + size_t inet_skb_parm_size) +{ + memcpy(cb, skb->cb, sizeof(*cb)); + memset(skb->cb, 0, inet_skb_parm_size); +} + +static void br_skb_cb_restore(struct sk_buff *skb, + const struct br_input_skb_cb *cb, + u16 fragsz) +{ + memcpy(skb->cb, cb, sizeof(*cb)); + BR_INPUT_SKB_CB(skb)->frag_max_size = fragsz; +} + +static unsigned int nf_ct_br_defrag4(struct sk_buff *skb, + const struct nf_hook_state *state) +{ + u16 zone_id = NF_CT_DEFAULT_ZONE_ID; + enum ip_conntrack_info ctinfo; + struct br_input_skb_cb cb; + const struct nf_conn *ct; + int err; + + if (!ip_is_fragment(ip_hdr(skb))) + return NF_ACCEPT; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) + zone_id = nf_ct_zone_id(nf_ct_zone(ct), CTINFO2DIR(ctinfo)); + + br_skb_cb_save(skb, &cb, sizeof(struct inet_skb_parm)); + local_bh_disable(); + err = ip_defrag(state->net, skb, + IP_DEFRAG_CONNTRACK_BRIDGE_IN + zone_id); + local_bh_enable(); + if (!err) { + br_skb_cb_restore(skb, &cb, IPCB(skb)->frag_max_size); + skb->ignore_df = 1; + return NF_ACCEPT; + } + + return NF_STOLEN; +} + +static unsigned int nf_ct_br_defrag6(struct sk_buff *skb, + const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NF_DEFRAG_IPV6) + u16 zone_id = NF_CT_DEFAULT_ZONE_ID; + enum ip_conntrack_info ctinfo; + struct br_input_skb_cb cb; + const struct nf_conn *ct; + int err; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) + zone_id = nf_ct_zone_id(nf_ct_zone(ct), CTINFO2DIR(ctinfo)); + + br_skb_cb_save(skb, &cb, sizeof(struct inet6_skb_parm)); + + err = nf_ct_frag6_gather(state->net, skb, + IP_DEFRAG_CONNTRACK_BRIDGE_IN + zone_id); + /* queued */ + if (err == -EINPROGRESS) + return NF_STOLEN; + + br_skb_cb_restore(skb, &cb, IP6CB(skb)->frag_max_size); + return err == 0 ? 
NF_ACCEPT : NF_DROP; +#else + return NF_ACCEPT; +#endif +} + +static int nf_ct_br_ip_check(const struct sk_buff *skb) +{ + const struct iphdr *iph; + int nhoff, len; + + nhoff = skb_network_offset(skb); + iph = ip_hdr(skb); + if (iph->ihl < 5 || + iph->version != 4) + return -1; + + len = ntohs(iph->tot_len); + if (skb->len < nhoff + len || + len < (iph->ihl * 4)) + return -1; + + return 0; +} + +static int nf_ct_br_ipv6_check(const struct sk_buff *skb) +{ + const struct ipv6hdr *hdr; + int nhoff, len; + + nhoff = skb_network_offset(skb); + hdr = ipv6_hdr(skb); + if (hdr->version != 6) + return -1; + + len = ntohs(hdr->payload_len) + sizeof(struct ipv6hdr) + nhoff; + if (skb->len < len) + return -1; + + return 0; +} + +static unsigned int nf_ct_bridge_pre(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_hook_state bridge_state = *state; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + u32 len; + int ret; + + ct = nf_ct_get(skb, &ctinfo); + if ((ct && !nf_ct_is_template(ct)) || + ctinfo == IP_CT_UNTRACKED) + return NF_ACCEPT; + + switch (skb->protocol) { + case htons(ETH_P_IP): + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + return NF_ACCEPT; + + len = ntohs(ip_hdr(skb)->tot_len); + if (pskb_trim_rcsum(skb, len)) + return NF_ACCEPT; + + if (nf_ct_br_ip_check(skb)) + return NF_ACCEPT; + + bridge_state.pf = NFPROTO_IPV4; + ret = nf_ct_br_defrag4(skb, &bridge_state); + break; + case htons(ETH_P_IPV6): + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + return NF_ACCEPT; + + len = sizeof(struct ipv6hdr) + ntohs(ipv6_hdr(skb)->payload_len); + if (pskb_trim_rcsum(skb, len)) + return NF_ACCEPT; + + if (nf_ct_br_ipv6_check(skb)) + return NF_ACCEPT; + + bridge_state.pf = NFPROTO_IPV6; + ret = nf_ct_br_defrag6(skb, &bridge_state); + break; + default: + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + return NF_ACCEPT; + } + + if (ret != NF_ACCEPT) + return ret; + + return nf_conntrack_in(skb, &bridge_state); +} + +static void nf_ct_bridge_frag_save(struct sk_buff *skb, + struct nf_bridge_frag_data *data) +{ + if (skb_vlan_tag_present(skb)) { + data->vlan_present = true; + data->vlan_tci = skb->vlan_tci; + data->vlan_proto = skb->vlan_proto; + } else { + data->vlan_present = false; + } + skb_copy_from_linear_data_offset(skb, -ETH_HLEN, data->mac, ETH_HLEN); +} + +static unsigned int +nf_ct_bridge_refrag(struct sk_buff *skb, const struct nf_hook_state *state, + int (*output)(struct net *, struct sock *sk, + const struct nf_bridge_frag_data *data, + struct sk_buff *)) +{ + struct nf_bridge_frag_data data; + + if (!BR_INPUT_SKB_CB(skb)->frag_max_size) + return NF_ACCEPT; + + nf_ct_bridge_frag_save(skb, &data); + switch (skb->protocol) { + case htons(ETH_P_IP): + nf_br_ip_fragment(state->net, state->sk, skb, &data, output); + break; + case htons(ETH_P_IPV6): + nf_br_ip6_fragment(state->net, state->sk, skb, &data, output); + break; + default: + WARN_ON_ONCE(1); + return NF_DROP; + } + + return NF_STOLEN; +} + +/* Actually only slow path refragmentation needs this. 
*/ +static int nf_ct_bridge_frag_restore(struct sk_buff *skb, + const struct nf_bridge_frag_data *data) +{ + int err; + + err = skb_cow_head(skb, ETH_HLEN); + if (err) { + kfree_skb(skb); + return -ENOMEM; + } + if (data->vlan_present) + __vlan_hwaccel_put_tag(skb, data->vlan_proto, data->vlan_tci); + else if (skb_vlan_tag_present(skb)) + __vlan_hwaccel_clear_tag(skb); + + skb_copy_to_linear_data_offset(skb, -ETH_HLEN, data->mac, ETH_HLEN); + skb_reset_mac_header(skb); + + return 0; +} + +static int nf_ct_bridge_refrag_post(struct net *net, struct sock *sk, + const struct nf_bridge_frag_data *data, + struct sk_buff *skb) +{ + int err; + + err = nf_ct_bridge_frag_restore(skb, data); + if (err < 0) + return err; + + return br_dev_queue_push_xmit(net, sk, skb); +} + +static unsigned int nf_ct_bridge_confirm(struct sk_buff *skb) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + int protoff; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || ctinfo == IP_CT_RELATED_REPLY) + return nf_conntrack_confirm(skb); + + switch (skb->protocol) { + case htons(ETH_P_IP): + protoff = skb_network_offset(skb) + ip_hdrlen(skb); + break; + case htons(ETH_P_IPV6): { + unsigned char pnum = ipv6_hdr(skb)->nexthdr; + __be16 frag_off; + + protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, + &frag_off); + if (protoff < 0 || (frag_off & htons(~0x7)) != 0) + return nf_conntrack_confirm(skb); + } + break; + default: + return NF_ACCEPT; + } + return nf_confirm(skb, protoff, ct, ctinfo); +} + +static unsigned int nf_ct_bridge_post(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + int ret; + + ret = nf_ct_bridge_confirm(skb); + if (ret != NF_ACCEPT) + return ret; + + return nf_ct_bridge_refrag(skb, state, nf_ct_bridge_refrag_post); +} + +static struct nf_hook_ops nf_ct_bridge_hook_ops[] __read_mostly = { + { + .hook = nf_ct_bridge_pre, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_PRE_ROUTING, + .priority = NF_IP_PRI_CONNTRACK, + }, + { + .hook = nf_ct_bridge_post, + .pf = NFPROTO_BRIDGE, + .hooknum = NF_BR_POST_ROUTING, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM, + }, +}; + +static struct nf_ct_bridge_info bridge_info = { + .ops = nf_ct_bridge_hook_ops, + .ops_size = ARRAY_SIZE(nf_ct_bridge_hook_ops), + .me = THIS_MODULE, +}; + +static int __init nf_conntrack_l3proto_bridge_init(void) +{ + nf_ct_bridge_register(&bridge_info); + + return 0; +} + +static void __exit nf_conntrack_l3proto_bridge_fini(void) +{ + nf_ct_bridge_unregister(&bridge_info); +} + +module_init(nf_conntrack_l3proto_bridge_init); +module_exit(nf_conntrack_l3proto_bridge_fini); + +MODULE_ALIAS("nf_conntrack-" __stringify(AF_BRIDGE)); +MODULE_LICENSE("GPL"); diff --git a/net/bridge/netfilter/nft_meta_bridge.c b/net/bridge/netfilter/nft_meta_bridge.c new file mode 100644 index 000000000..c3ecd77e2 --- /dev/null +++ b/net/bridge/netfilter/nft_meta_bridge.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const struct net_device * +nft_meta_get_bridge(const struct net_device *dev) +{ + if (dev && netif_is_bridge_port(dev)) + return netdev_master_upper_dev_get_rcu((struct net_device *)dev); + + return NULL; +} + +static void nft_meta_bridge_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + const struct net_device *in = nft_in(pkt), *out = nft_out(pkt); + u32 *dest = ®s->data[priv->dreg]; + const struct net_device 
*br_dev; + + switch (priv->key) { + case NFT_META_BRI_IIFNAME: + br_dev = nft_meta_get_bridge(in); + break; + case NFT_META_BRI_OIFNAME: + br_dev = nft_meta_get_bridge(out); + break; + case NFT_META_BRI_IIFPVID: { + u16 p_pvid; + + br_dev = nft_meta_get_bridge(in); + if (!br_dev || !br_vlan_enabled(br_dev)) + goto err; + + br_vlan_get_pvid_rcu(in, &p_pvid); + nft_reg_store16(dest, p_pvid); + return; + } + case NFT_META_BRI_IIFVPROTO: { + u16 p_proto; + + br_dev = nft_meta_get_bridge(in); + if (!br_dev || !br_vlan_enabled(br_dev)) + goto err; + + br_vlan_get_proto(br_dev, &p_proto); + nft_reg_store_be16(dest, htons(p_proto)); + return; + } + default: + return nft_meta_get_eval(expr, regs, pkt); + } + + strncpy((char *)dest, br_dev ? br_dev->name : "", IFNAMSIZ); + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static int nft_meta_bridge_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_meta *priv = nft_expr_priv(expr); + unsigned int len; + + priv->key = ntohl(nla_get_be32(tb[NFTA_META_KEY])); + switch (priv->key) { + case NFT_META_BRI_IIFNAME: + case NFT_META_BRI_OIFNAME: + len = IFNAMSIZ; + break; + case NFT_META_BRI_IIFPVID: + case NFT_META_BRI_IIFVPROTO: + len = sizeof(u16); + break; + default: + return nft_meta_get_init(ctx, expr, tb); + } + + priv->len = len; + return nft_parse_register_store(ctx, tb[NFTA_META_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} + +static struct nft_expr_type nft_meta_bridge_type; +static const struct nft_expr_ops nft_meta_bridge_get_ops = { + .type = &nft_meta_bridge_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_meta)), + .eval = nft_meta_bridge_get_eval, + .init = nft_meta_bridge_get_init, + .dump = nft_meta_get_dump, + .reduce = nft_meta_get_reduce, +}; + +static bool nft_meta_bridge_set_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + int i; + + for (i = 0; i < NFT_REG32_NUM; i++) { + if (!track->regs[i].selector) + continue; + + if (track->regs[i].selector->ops != &nft_meta_bridge_get_ops) + continue; + + __nft_reg_track_cancel(track, i); + } + + return false; +} + +static const struct nft_expr_ops nft_meta_bridge_set_ops = { + .type = &nft_meta_bridge_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_meta)), + .eval = nft_meta_set_eval, + .init = nft_meta_set_init, + .destroy = nft_meta_set_destroy, + .dump = nft_meta_set_dump, + .reduce = nft_meta_bridge_set_reduce, + .validate = nft_meta_set_validate, +}; + +static const struct nft_expr_ops * +nft_meta_bridge_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_META_KEY] == NULL) + return ERR_PTR(-EINVAL); + + if (tb[NFTA_META_DREG] && tb[NFTA_META_SREG]) + return ERR_PTR(-EINVAL); + + if (tb[NFTA_META_DREG]) + return &nft_meta_bridge_get_ops; + + if (tb[NFTA_META_SREG]) + return &nft_meta_bridge_set_ops; + + return ERR_PTR(-EINVAL); +} + +static struct nft_expr_type nft_meta_bridge_type __read_mostly = { + .family = NFPROTO_BRIDGE, + .name = "meta", + .select_ops = nft_meta_bridge_select_ops, + .policy = nft_meta_policy, + .maxattr = NFTA_META_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_meta_bridge_module_init(void) +{ + return nft_register_expr(&nft_meta_bridge_type); +} + +static void __exit nft_meta_bridge_module_exit(void) +{ + nft_unregister_expr(&nft_meta_bridge_type); +} + +module_init(nft_meta_bridge_module_init); +module_exit(nft_meta_bridge_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("wenxu "); 
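+/* Summary of the bridge-specific meta keys implemented above:
+ * NFT_META_BRI_IIFNAME / NFT_META_BRI_OIFNAME load the name of the master
+ * bridge device of the in/out port (IFNAMSIZ bytes); NFT_META_BRI_IIFPVID
+ * loads the ingress port's PVID and NFT_META_BRI_IIFVPROTO the bridge's
+ * VLAN protocol, both as 16-bit values. Any other key is delegated to the
+ * generic nft_meta get/set implementation.
+ */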
+MODULE_ALIAS_NFT_AF_EXPR(AF_BRIDGE, "meta"); +MODULE_DESCRIPTION("Support for bridge dedicated meta key"); diff --git a/net/bridge/netfilter/nft_reject_bridge.c b/net/bridge/netfilter/nft_reject_bridge.c new file mode 100644 index 000000000..71b54fed7 --- /dev/null +++ b/net/bridge/netfilter/nft_reject_bridge.c @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2014 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../br_private.h" + +static void nft_reject_br_push_etherhdr(struct sk_buff *oldskb, + struct sk_buff *nskb) +{ + struct ethhdr *eth; + + eth = skb_push(nskb, ETH_HLEN); + skb_reset_mac_header(nskb); + ether_addr_copy(eth->h_source, eth_hdr(oldskb)->h_dest); + ether_addr_copy(eth->h_dest, eth_hdr(oldskb)->h_source); + eth->h_proto = eth_hdr(oldskb)->h_proto; + skb_pull(nskb, ETH_HLEN); + + if (skb_vlan_tag_present(oldskb)) { + u16 vid = skb_vlan_tag_get(oldskb); + + __vlan_hwaccel_put_tag(nskb, oldskb->vlan_proto, vid); + } +} + +/* We cannot use oldskb->dev, it can be either bridge device (NF_BRIDGE INPUT) + * or the bridge port (NF_BRIDGE PREROUTING). + */ +static void nft_reject_br_send_v4_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v4_tcp_reset(net, oldskb, NULL, hook); + if (!nskb) + return; + + nft_reject_br_push_etherhdr(oldskb, nskb); + + br_forward(br_port_get_rcu(dev), nskb, false, true); +} + +static void nft_reject_br_send_v4_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v4_unreach(net, oldskb, NULL, hook, code); + if (!nskb) + return; + + nft_reject_br_push_etherhdr(oldskb, nskb); + + br_forward(br_port_get_rcu(dev), nskb, false, true); +} + +static void nft_reject_br_send_v6_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v6_tcp_reset(net, oldskb, NULL, hook); + if (!nskb) + return; + + nft_reject_br_push_etherhdr(oldskb, nskb); + + br_forward(br_port_get_rcu(dev), nskb, false, true); +} + + +static void nft_reject_br_send_v6_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v6_unreach(net, oldskb, NULL, hook, code); + if (!nskb) + return; + + nft_reject_br_push_etherhdr(oldskb, nskb); + + br_forward(br_port_get_rcu(dev), nskb, false, true); +} + +static void nft_reject_bridge_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_reject *priv = nft_expr_priv(expr); + const unsigned char *dest = eth_hdr(pkt->skb)->h_dest; + + if (is_broadcast_ether_addr(dest) || + is_multicast_ether_addr(dest)) + goto out; + + switch (eth_hdr(pkt->skb)->h_proto) { + case htons(ETH_P_IP): + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nft_reject_br_send_v4_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + priv->icmp_code); + break; + case NFT_REJECT_TCP_RST: + nft_reject_br_send_v4_tcp_reset(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nft_reject_br_send_v4_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + nft_reject_icmp_code(priv->icmp_code)); + break; + 
} + break; + case htons(ETH_P_IPV6): + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nft_reject_br_send_v6_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + priv->icmp_code); + break; + case NFT_REJECT_TCP_RST: + nft_reject_br_send_v6_tcp_reset(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nft_reject_br_send_v6_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + nft_reject_icmpv6_code(priv->icmp_code)); + break; + } + break; + default: + /* No explicit way to reject this protocol, drop it. */ + break; + } +out: + regs->verdict.code = NF_DROP; +} + +static int nft_reject_bridge_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, (1 << NF_BR_PRE_ROUTING) | + (1 << NF_BR_LOCAL_IN)); +} + +static struct nft_expr_type nft_reject_bridge_type; +static const struct nft_expr_ops nft_reject_bridge_ops = { + .type = &nft_reject_bridge_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_reject)), + .eval = nft_reject_bridge_eval, + .init = nft_reject_init, + .dump = nft_reject_dump, + .validate = nft_reject_bridge_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_reject_bridge_type __read_mostly = { + .family = NFPROTO_BRIDGE, + .name = "reject", + .ops = &nft_reject_bridge_ops, + .policy = nft_reject_policy, + .maxattr = NFTA_REJECT_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_reject_bridge_module_init(void) +{ + return nft_register_expr(&nft_reject_bridge_type); +} + +static void __exit nft_reject_bridge_module_exit(void) +{ + nft_unregister_expr(&nft_reject_bridge_type); +} + +module_init(nft_reject_bridge_module_init); +module_exit(nft_reject_bridge_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_AF_EXPR(AF_BRIDGE, "reject"); +MODULE_DESCRIPTION("Reject packets from bridge via nftables"); diff --git a/net/caif/Kconfig b/net/caif/Kconfig new file mode 100644 index 000000000..87205251c --- /dev/null +++ b/net/caif/Kconfig @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# CAIF net configurations +# + +menuconfig CAIF + tristate "CAIF support" + select CRC_CCITT + default n + help + The "Communication CPU to Application CPU Interface" (CAIF) is a packet + based connection-oriented MUX protocol developed by ST-Ericsson for use + with its modems. It is accessed from user space as sockets (PF_CAIF). + + Say Y (or M) here if you build for a phone product (e.g. Android or + MeeGo) that uses CAIF as transport. If unsure say N. + + If you select to build it as module then CAIF_NETDEV also needs to be + built as a module. You will also need to say Y (or M) to any CAIF + physical devices that your platform requires. + + See Documentation/networking/caif for a further explanation on how to + use and configure CAIF. + +config CAIF_DEBUG + bool "Enable Debug" + depends on CAIF + default n + help + Enable the inclusion of debug code in the CAIF stack. + Be aware that doing this will impact performance. + If unsure say N. + +config CAIF_NETDEV + tristate "CAIF GPRS Network device" + depends on CAIF + default CAIF + help + Say Y if you will be using a CAIF based GPRS network device. + This can be either built-in or a loadable module. + If you select to build it as a built-in then the main CAIF device must + also be a built-in. + If unsure say Y. 
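For illustration only (not part of the upstream Kconfig): following the dependency spelled out in the help text above, a fully modular CAIF build would end up with a .config fragment along these lines:

    CONFIG_CAIF=m
    CONFIG_CAIF_NETDEV=m   # must also be modular when CAIF itself is a module
    CONFIG_CAIF_USB=m      # optional, CAIF over USB CDC NCM (see the entry below)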
+ +config CAIF_USB + tristate "CAIF USB support" + depends on CAIF + default n + help + Say Y if you are using CAIF over USB CDC NCM. + This can be either built-in or a loadable module. + If you select to build it as a built-in then the main CAIF device must + also be a built-in. + If unsure say N. diff --git a/net/caif/Makefile b/net/caif/Makefile new file mode 100644 index 000000000..4f6c0517c --- /dev/null +++ b/net/caif/Makefile @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: GPL-2.0 +ccflags-$(CONFIG_CAIF_DEBUG) := -DDEBUG + +caif-y := caif_dev.o \ + cfcnfg.o cfmuxl.o cfctrl.o \ + cffrml.o cfveil.o cfdbgl.o\ + cfserl.o cfdgml.o \ + cfrfml.o cfvidl.o cfutill.o \ + cfsrvl.o cfpkt_skbuff.o + +obj-$(CONFIG_CAIF) += caif.o +obj-$(CONFIG_CAIF_NETDEV) += chnl_net.o +obj-$(CONFIG_CAIF) += caif_socket.o +obj-$(CONFIG_CAIF_USB) += caif_usb.o + +export-y := caif.o diff --git a/net/caif/caif_dev.c b/net/caif/caif_dev.c new file mode 100644 index 000000000..6a0cba4fc --- /dev/null +++ b/net/caif/caif_dev.c @@ -0,0 +1,585 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * CAIF Interface registration. + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + * + * Borrowed heavily from file: pn_dev.c. Thanks to Remi Denis-Courmont + * and Sakari Ailus + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); + +/* Used for local tracking of the CAIF net devices */ +struct caif_device_entry { + struct cflayer layer; + struct list_head list; + struct net_device *netdev; + int __percpu *pcpu_refcnt; + spinlock_t flow_lock; + struct sk_buff *xoff_skb; + void (*xoff_skb_dtor)(struct sk_buff *skb); + bool xoff; +}; + +struct caif_device_entry_list { + struct list_head list; + /* Protects simulanous deletes in list */ + struct mutex lock; +}; + +struct caif_net { + struct cfcnfg *cfg; + struct caif_device_entry_list caifdevs; +}; + +static unsigned int caif_net_id; +static int q_high = 50; /* Percent */ + +struct cfcnfg *get_cfcnfg(struct net *net) +{ + struct caif_net *caifn; + caifn = net_generic(net, caif_net_id); + return caifn->cfg; +} +EXPORT_SYMBOL(get_cfcnfg); + +static struct caif_device_entry_list *caif_device_list(struct net *net) +{ + struct caif_net *caifn; + caifn = net_generic(net, caif_net_id); + return &caifn->caifdevs; +} + +static void caifd_put(struct caif_device_entry *e) +{ + this_cpu_dec(*e->pcpu_refcnt); +} + +static void caifd_hold(struct caif_device_entry *e) +{ + this_cpu_inc(*e->pcpu_refcnt); +} + +static int caifd_refcnt_read(struct caif_device_entry *e) +{ + int i, refcnt = 0; + for_each_possible_cpu(i) + refcnt += *per_cpu_ptr(e->pcpu_refcnt, i); + return refcnt; +} + +/* Allocate new CAIF device. 
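+ * Holds a reference on the underlying netdev and allocates the per-CPU refcount; returns NULL if either allocation fails.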
*/ +static struct caif_device_entry *caif_device_alloc(struct net_device *dev) +{ + struct caif_device_entry *caifd; + + caifd = kzalloc(sizeof(*caifd), GFP_KERNEL); + if (!caifd) + return NULL; + caifd->pcpu_refcnt = alloc_percpu(int); + if (!caifd->pcpu_refcnt) { + kfree(caifd); + return NULL; + } + caifd->netdev = dev; + dev_hold(dev); + return caifd; +} + +static struct caif_device_entry *caif_get(struct net_device *dev) +{ + struct caif_device_entry_list *caifdevs = + caif_device_list(dev_net(dev)); + struct caif_device_entry *caifd; + + list_for_each_entry_rcu(caifd, &caifdevs->list, list, + lockdep_rtnl_is_held()) { + if (caifd->netdev == dev) + return caifd; + } + return NULL; +} + +static void caif_flow_cb(struct sk_buff *skb) +{ + struct caif_device_entry *caifd; + void (*dtor)(struct sk_buff *skb) = NULL; + bool send_xoff; + + WARN_ON(skb->dev == NULL); + + rcu_read_lock(); + caifd = caif_get(skb->dev); + + WARN_ON(caifd == NULL); + if (!caifd) { + rcu_read_unlock(); + return; + } + + caifd_hold(caifd); + rcu_read_unlock(); + + spin_lock_bh(&caifd->flow_lock); + send_xoff = caifd->xoff; + caifd->xoff = false; + dtor = caifd->xoff_skb_dtor; + + if (WARN_ON(caifd->xoff_skb != skb)) + skb = NULL; + + caifd->xoff_skb = NULL; + caifd->xoff_skb_dtor = NULL; + + spin_unlock_bh(&caifd->flow_lock); + + if (dtor && skb) + dtor(skb); + + if (send_xoff) + caifd->layer.up-> + ctrlcmd(caifd->layer.up, + _CAIF_CTRLCMD_PHYIF_FLOW_ON_IND, + caifd->layer.id); + caifd_put(caifd); +} + +static int transmit(struct cflayer *layer, struct cfpkt *pkt) +{ + int err, high = 0, qlen = 0; + struct caif_device_entry *caifd = + container_of(layer, struct caif_device_entry, layer); + struct sk_buff *skb; + struct netdev_queue *txq; + + rcu_read_lock_bh(); + + skb = cfpkt_tonative(pkt); + skb->dev = caifd->netdev; + skb_reset_network_header(skb); + skb->protocol = htons(ETH_P_CAIF); + + /* Check if we need to handle xoff */ + if (likely(caifd->netdev->priv_flags & IFF_NO_QUEUE)) + goto noxoff; + + if (unlikely(caifd->xoff)) + goto noxoff; + + if (likely(!netif_queue_stopped(caifd->netdev))) { + struct Qdisc *sch; + + /* If we run with a TX queue, check if the queue is too long*/ + txq = netdev_get_tx_queue(skb->dev, 0); + sch = rcu_dereference_bh(txq->qdisc); + if (likely(qdisc_is_empty(sch))) + goto noxoff; + + /* can check for explicit qdisc len value only !NOLOCK, + * always set flow off otherwise + */ + high = (caifd->netdev->tx_queue_len * q_high) / 100; + if (!(sch->flags & TCQ_F_NOLOCK) && likely(sch->q.qlen < high)) + goto noxoff; + } + + /* Hold lock while accessing xoff */ + spin_lock_bh(&caifd->flow_lock); + if (caifd->xoff) { + spin_unlock_bh(&caifd->flow_lock); + goto noxoff; + } + + /* + * Handle flow off, we do this by temporary hi-jacking this + * skb's destructor function, and replace it with our own + * flow-on callback. The callback will set flow-on and call + * the original destructor. + */ + + pr_debug("queue has stopped(%d) or is full (%d > %d)\n", + netif_queue_stopped(caifd->netdev), + qlen, high); + caifd->xoff = true; + caifd->xoff_skb = skb; + caifd->xoff_skb_dtor = skb->destructor; + skb->destructor = caif_flow_cb; + spin_unlock_bh(&caifd->flow_lock); + + caifd->layer.up->ctrlcmd(caifd->layer.up, + _CAIF_CTRLCMD_PHYIF_FLOW_OFF_IND, + caifd->layer.id); +noxoff: + rcu_read_unlock_bh(); + + err = dev_queue_xmit(skb); + if (err > 0) + err = -EIO; + + return err; +} + +/* + * Stuff received packets into the CAIF stack. + * On error, returns non-zero and releases the skb. 
+ */ +static int receive(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pkttype, struct net_device *orig_dev) +{ + struct cfpkt *pkt; + struct caif_device_entry *caifd; + int err; + + pkt = cfpkt_fromnative(CAIF_DIR_IN, skb); + + rcu_read_lock(); + caifd = caif_get(dev); + + if (!caifd || !caifd->layer.up || !caifd->layer.up->receive || + !netif_oper_up(caifd->netdev)) { + rcu_read_unlock(); + kfree_skb(skb); + return NET_RX_DROP; + } + + /* Hold reference to netdevice while using CAIF stack */ + caifd_hold(caifd); + rcu_read_unlock(); + + err = caifd->layer.up->receive(caifd->layer.up, pkt); + + /* For -EILSEQ the packet is not freed so free it now */ + if (err == -EILSEQ) + cfpkt_destroy(pkt); + + /* Release reference to stack upwards */ + caifd_put(caifd); + + if (err != 0) + err = NET_RX_DROP; + return err; +} + +static struct packet_type caif_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_CAIF), + .func = receive, +}; + +static void dev_flowctrl(struct net_device *dev, int on) +{ + struct caif_device_entry *caifd; + + rcu_read_lock(); + + caifd = caif_get(dev); + if (!caifd || !caifd->layer.up || !caifd->layer.up->ctrlcmd) { + rcu_read_unlock(); + return; + } + + caifd_hold(caifd); + rcu_read_unlock(); + + caifd->layer.up->ctrlcmd(caifd->layer.up, + on ? + _CAIF_CTRLCMD_PHYIF_FLOW_ON_IND : + _CAIF_CTRLCMD_PHYIF_FLOW_OFF_IND, + caifd->layer.id); + caifd_put(caifd); +} + +int caif_enroll_dev(struct net_device *dev, struct caif_dev_common *caifdev, + struct cflayer *link_support, int head_room, + struct cflayer **layer, + int (**rcv_func)(struct sk_buff *, struct net_device *, + struct packet_type *, + struct net_device *)) +{ + struct caif_device_entry *caifd; + enum cfcnfg_phy_preference pref; + struct cfcnfg *cfg = get_cfcnfg(dev_net(dev)); + struct caif_device_entry_list *caifdevs; + int res; + + caifdevs = caif_device_list(dev_net(dev)); + caifd = caif_device_alloc(dev); + if (!caifd) + return -ENOMEM; + *layer = &caifd->layer; + spin_lock_init(&caifd->flow_lock); + + switch (caifdev->link_select) { + case CAIF_LINK_HIGH_BANDW: + pref = CFPHYPREF_HIGH_BW; + break; + case CAIF_LINK_LOW_LATENCY: + pref = CFPHYPREF_LOW_LAT; + break; + default: + pref = CFPHYPREF_HIGH_BW; + break; + } + mutex_lock(&caifdevs->lock); + list_add_rcu(&caifd->list, &caifdevs->list); + + strscpy(caifd->layer.name, dev->name, + sizeof(caifd->layer.name)); + caifd->layer.transmit = transmit; + res = cfcnfg_add_phy_layer(cfg, + dev, + &caifd->layer, + pref, + link_support, + caifdev->use_fcs, + head_room); + mutex_unlock(&caifdevs->lock); + if (rcv_func) + *rcv_func = receive; + return res; +} +EXPORT_SYMBOL(caif_enroll_dev); + +/* notify Caif of device events */ +static int caif_device_notify(struct notifier_block *me, unsigned long what, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct caif_device_entry *caifd = NULL; + struct caif_dev_common *caifdev; + struct cfcnfg *cfg; + struct cflayer *layer, *link_support; + int head_room = 0; + struct caif_device_entry_list *caifdevs; + int res; + + cfg = get_cfcnfg(dev_net(dev)); + caifdevs = caif_device_list(dev_net(dev)); + + caifd = caif_get(dev); + if (caifd == NULL && dev->type != ARPHRD_CAIF) + return 0; + + switch (what) { + case NETDEV_REGISTER: + if (caifd != NULL) + break; + + caifdev = netdev_priv(dev); + + link_support = NULL; + if (caifdev->use_frag) { + head_room = 1; + link_support = cfserl_create(dev->ifindex, + caifdev->use_stx); + if (!link_support) { + pr_warn("Out of memory\n"); + 
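+ /* Without the link support layer the device is not enrolled. */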
break; + } + } + res = caif_enroll_dev(dev, caifdev, link_support, head_room, + &layer, NULL); + if (res) + cfserl_release(link_support); + caifdev->flowctrl = dev_flowctrl; + break; + + case NETDEV_UP: + rcu_read_lock(); + + caifd = caif_get(dev); + if (caifd == NULL) { + rcu_read_unlock(); + break; + } + + caifd->xoff = false; + cfcnfg_set_phy_state(cfg, &caifd->layer, true); + rcu_read_unlock(); + + break; + + case NETDEV_DOWN: + rcu_read_lock(); + + caifd = caif_get(dev); + if (!caifd || !caifd->layer.up || !caifd->layer.up->ctrlcmd) { + rcu_read_unlock(); + return -EINVAL; + } + + cfcnfg_set_phy_state(cfg, &caifd->layer, false); + caifd_hold(caifd); + rcu_read_unlock(); + + caifd->layer.up->ctrlcmd(caifd->layer.up, + _CAIF_CTRLCMD_PHYIF_DOWN_IND, + caifd->layer.id); + + spin_lock_bh(&caifd->flow_lock); + + /* + * Replace our xoff-destructor with original destructor. + * We trust that skb->destructor *always* is called before + * the skb reference is invalid. The hijacked SKB destructor + * takes the flow_lock so manipulating the skb->destructor here + * should be safe. + */ + if (caifd->xoff_skb_dtor != NULL && caifd->xoff_skb != NULL) + caifd->xoff_skb->destructor = caifd->xoff_skb_dtor; + + caifd->xoff = false; + caifd->xoff_skb_dtor = NULL; + caifd->xoff_skb = NULL; + + spin_unlock_bh(&caifd->flow_lock); + caifd_put(caifd); + break; + + case NETDEV_UNREGISTER: + mutex_lock(&caifdevs->lock); + + caifd = caif_get(dev); + if (caifd == NULL) { + mutex_unlock(&caifdevs->lock); + break; + } + list_del_rcu(&caifd->list); + + /* + * NETDEV_UNREGISTER is called repeatedly until all reference + * counts for the net-device are released. If references to + * caifd is taken, simply ignore NETDEV_UNREGISTER and wait for + * the next call to NETDEV_UNREGISTER. + * + * If any packets are in flight down the CAIF Stack, + * cfcnfg_del_phy_layer will return nonzero. + * If no packets are in flight, the CAIF Stack associated + * with the net-device un-registering is freed. 
+ */ + + if (caifd_refcnt_read(caifd) != 0 || + cfcnfg_del_phy_layer(cfg, &caifd->layer) != 0) { + + pr_info("Wait for device inuse\n"); + /* Enrole device if CAIF Stack is still in use */ + list_add_rcu(&caifd->list, &caifdevs->list); + mutex_unlock(&caifdevs->lock); + break; + } + + synchronize_rcu(); + dev_put(caifd->netdev); + free_percpu(caifd->pcpu_refcnt); + kfree(caifd); + + mutex_unlock(&caifdevs->lock); + break; + } + return 0; +} + +static struct notifier_block caif_device_notifier = { + .notifier_call = caif_device_notify, + .priority = 0, +}; + +/* Per-namespace Caif devices handling */ +static int caif_init_net(struct net *net) +{ + struct caif_net *caifn = net_generic(net, caif_net_id); + INIT_LIST_HEAD(&caifn->caifdevs.list); + mutex_init(&caifn->caifdevs.lock); + + caifn->cfg = cfcnfg_create(); + if (!caifn->cfg) + return -ENOMEM; + + return 0; +} + +static void caif_exit_net(struct net *net) +{ + struct caif_device_entry *caifd, *tmp; + struct caif_device_entry_list *caifdevs = + caif_device_list(net); + struct cfcnfg *cfg = get_cfcnfg(net); + + rtnl_lock(); + mutex_lock(&caifdevs->lock); + + list_for_each_entry_safe(caifd, tmp, &caifdevs->list, list) { + int i = 0; + list_del_rcu(&caifd->list); + cfcnfg_set_phy_state(cfg, &caifd->layer, false); + + while (i < 10 && + (caifd_refcnt_read(caifd) != 0 || + cfcnfg_del_phy_layer(cfg, &caifd->layer) != 0)) { + + pr_info("Wait for device inuse\n"); + msleep(250); + i++; + } + synchronize_rcu(); + dev_put(caifd->netdev); + free_percpu(caifd->pcpu_refcnt); + kfree(caifd); + } + cfcnfg_remove(cfg); + + mutex_unlock(&caifdevs->lock); + rtnl_unlock(); +} + +static struct pernet_operations caif_net_ops = { + .init = caif_init_net, + .exit = caif_exit_net, + .id = &caif_net_id, + .size = sizeof(struct caif_net), +}; + +/* Initialize Caif devices list */ +static int __init caif_device_init(void) +{ + int result; + + result = register_pernet_subsys(&caif_net_ops); + + if (result) + return result; + + register_netdevice_notifier(&caif_device_notifier); + dev_add_pack(&caif_packet_type); + + return result; +} + +static void __exit caif_device_exit(void) +{ + unregister_netdevice_notifier(&caif_device_notifier); + dev_remove_pack(&caif_packet_type); + unregister_pernet_subsys(&caif_net_ops); +} + +module_init(caif_device_init); +module_exit(caif_device_exit); diff --git a/net/caif/caif_socket.c b/net/caif/caif_socket.c new file mode 100644 index 000000000..78c9729a6 --- /dev/null +++ b/net/caif/caif_socket.c @@ -0,0 +1,1119 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(AF_CAIF); + +/* + * CAIF state is re-using the TCP socket states. + * caif_states stored in sk_state reflect the state as reported by + * the CAIF stack, while sk_socket->state is the state of the socket. 
+ */ +enum caif_states { + CAIF_CONNECTED = TCP_ESTABLISHED, + CAIF_CONNECTING = TCP_SYN_SENT, + CAIF_DISCONNECTED = TCP_CLOSE +}; + +#define TX_FLOW_ON_BIT 1 +#define RX_FLOW_ON_BIT 2 + +struct caifsock { + struct sock sk; /* must be first member */ + struct cflayer layer; + unsigned long flow_state; + struct caif_connect_request conn_req; + struct mutex readlock; + struct dentry *debugfs_socket_dir; + int headroom, tailroom, maxframe; +}; + +static int rx_flow_is_on(struct caifsock *cf_sk) +{ + return test_bit(RX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static int tx_flow_is_on(struct caifsock *cf_sk) +{ + return test_bit(TX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static void set_rx_flow_off(struct caifsock *cf_sk) +{ + clear_bit(RX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static void set_rx_flow_on(struct caifsock *cf_sk) +{ + set_bit(RX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static void set_tx_flow_off(struct caifsock *cf_sk) +{ + clear_bit(TX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static void set_tx_flow_on(struct caifsock *cf_sk) +{ + set_bit(TX_FLOW_ON_BIT, &cf_sk->flow_state); +} + +static void caif_read_lock(struct sock *sk) +{ + struct caifsock *cf_sk; + cf_sk = container_of(sk, struct caifsock, sk); + mutex_lock(&cf_sk->readlock); +} + +static void caif_read_unlock(struct sock *sk) +{ + struct caifsock *cf_sk; + cf_sk = container_of(sk, struct caifsock, sk); + mutex_unlock(&cf_sk->readlock); +} + +static int sk_rcvbuf_lowwater(struct caifsock *cf_sk) +{ + /* A quarter of full buffer is used a low water mark */ + return cf_sk->sk.sk_rcvbuf / 4; +} + +static void caif_flow_ctrl(struct sock *sk, int mode) +{ + struct caifsock *cf_sk; + cf_sk = container_of(sk, struct caifsock, sk); + if (cf_sk->layer.dn && cf_sk->layer.dn->modemcmd) + cf_sk->layer.dn->modemcmd(cf_sk->layer.dn, mode); +} + +/* + * Copied from sock.c:sock_queue_rcv_skb(), but changed so packets are + * not dropped, but CAIF is sending flow off instead. 
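+ * Flow is switched back on from caif_check_flow_release() once the receive queue drains below the low-water mark.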
+ */ +static void caif_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int err; + unsigned long flags; + struct sk_buff_head *list = &sk->sk_receive_queue; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + bool queued = false; + + if (atomic_read(&sk->sk_rmem_alloc) + skb->truesize >= + (unsigned int)sk->sk_rcvbuf && rx_flow_is_on(cf_sk)) { + net_dbg_ratelimited("sending flow OFF (queue len = %d %d)\n", + atomic_read(&cf_sk->sk.sk_rmem_alloc), + sk_rcvbuf_lowwater(cf_sk)); + set_rx_flow_off(cf_sk); + caif_flow_ctrl(sk, CAIF_MODEMCMD_FLOW_OFF_REQ); + } + + err = sk_filter(sk, skb); + if (err) + goto out; + + if (!sk_rmem_schedule(sk, skb, skb->truesize) && rx_flow_is_on(cf_sk)) { + set_rx_flow_off(cf_sk); + net_dbg_ratelimited("sending flow OFF due to rmem_schedule\n"); + caif_flow_ctrl(sk, CAIF_MODEMCMD_FLOW_OFF_REQ); + } + skb->dev = NULL; + skb_set_owner_r(skb, sk); + spin_lock_irqsave(&list->lock, flags); + queued = !sock_flag(sk, SOCK_DEAD); + if (queued) + __skb_queue_tail(list, skb); + spin_unlock_irqrestore(&list->lock, flags); +out: + if (queued) + sk->sk_data_ready(sk); + else + kfree_skb(skb); +} + +/* Packet Receive Callback function called from CAIF Stack */ +static int caif_sktrecv_cb(struct cflayer *layr, struct cfpkt *pkt) +{ + struct caifsock *cf_sk; + struct sk_buff *skb; + + cf_sk = container_of(layr, struct caifsock, layer); + skb = cfpkt_tonative(pkt); + + if (unlikely(cf_sk->sk.sk_state != CAIF_CONNECTED)) { + kfree_skb(skb); + return 0; + } + caif_queue_rcv_skb(&cf_sk->sk, skb); + return 0; +} + +static void cfsk_hold(struct cflayer *layr) +{ + struct caifsock *cf_sk = container_of(layr, struct caifsock, layer); + sock_hold(&cf_sk->sk); +} + +static void cfsk_put(struct cflayer *layr) +{ + struct caifsock *cf_sk = container_of(layr, struct caifsock, layer); + sock_put(&cf_sk->sk); +} + +/* Packet Control Callback function called from CAIF */ +static void caif_ctrl_cb(struct cflayer *layr, + enum caif_ctrlcmd flow, + int phyid) +{ + struct caifsock *cf_sk = container_of(layr, struct caifsock, layer); + switch (flow) { + case CAIF_CTRLCMD_FLOW_ON_IND: + /* OK from modem to start sending again */ + set_tx_flow_on(cf_sk); + cf_sk->sk.sk_state_change(&cf_sk->sk); + break; + + case CAIF_CTRLCMD_FLOW_OFF_IND: + /* Modem asks us to shut up */ + set_tx_flow_off(cf_sk); + cf_sk->sk.sk_state_change(&cf_sk->sk); + break; + + case CAIF_CTRLCMD_INIT_RSP: + /* We're now connected */ + caif_client_register_refcnt(&cf_sk->layer, + cfsk_hold, cfsk_put); + cf_sk->sk.sk_state = CAIF_CONNECTED; + set_tx_flow_on(cf_sk); + cf_sk->sk.sk_shutdown = 0; + cf_sk->sk.sk_state_change(&cf_sk->sk); + break; + + case CAIF_CTRLCMD_DEINIT_RSP: + /* We're now disconnected */ + cf_sk->sk.sk_state = CAIF_DISCONNECTED; + cf_sk->sk.sk_state_change(&cf_sk->sk); + break; + + case CAIF_CTRLCMD_INIT_FAIL_RSP: + /* Connect request failed */ + cf_sk->sk.sk_err = ECONNREFUSED; + cf_sk->sk.sk_state = CAIF_DISCONNECTED; + cf_sk->sk.sk_shutdown = SHUTDOWN_MASK; + /* + * Socket "standards" seems to require POLLOUT to + * be set at connect failure. + */ + set_tx_flow_on(cf_sk); + cf_sk->sk.sk_state_change(&cf_sk->sk); + break; + + case CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND: + /* Modem has closed this connection, or device is down. 
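+ * Mark both directions as shut down, report ECONNRESET and wake up the socket.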
*/ + cf_sk->sk.sk_shutdown = SHUTDOWN_MASK; + cf_sk->sk.sk_err = ECONNRESET; + set_rx_flow_on(cf_sk); + sk_error_report(&cf_sk->sk); + break; + + default: + pr_debug("Unexpected flow command %d\n", flow); + } +} + +static void caif_check_flow_release(struct sock *sk) +{ + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + + if (rx_flow_is_on(cf_sk)) + return; + + if (atomic_read(&sk->sk_rmem_alloc) <= sk_rcvbuf_lowwater(cf_sk)) { + set_rx_flow_on(cf_sk); + caif_flow_ctrl(sk, CAIF_MODEMCMD_FLOW_ON_REQ); + } +} + +/* + * Copied from unix_dgram_recvmsg, but removed credit checks, + * changed locking, address handling and added MSG_TRUNC. + */ +static int caif_seqpkt_recvmsg(struct socket *sock, struct msghdr *m, + size_t len, int flags) + +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int ret; + int copylen; + + ret = -EOPNOTSUPP; + if (flags & MSG_OOB) + goto read_error; + + skb = skb_recv_datagram(sk, flags, &ret); + if (!skb) + goto read_error; + copylen = skb->len; + if (len < copylen) { + m->msg_flags |= MSG_TRUNC; + copylen = len; + } + + ret = skb_copy_datagram_msg(skb, 0, m, copylen); + if (ret) + goto out_free; + + ret = (flags & MSG_TRUNC) ? skb->len : copylen; +out_free: + skb_free_datagram(sk, skb); + caif_check_flow_release(sk); + return ret; + +read_error: + return ret; +} + + +/* Copied from unix_stream_wait_data, identical except for lock call. */ +static long caif_stream_data_wait(struct sock *sk, long timeo) +{ + DEFINE_WAIT(wait); + lock_sock(sk); + + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + if (!skb_queue_empty(&sk->sk_receive_queue) || + sk->sk_err || + sk->sk_state != CAIF_CONNECTED || + sock_flag(sk, SOCK_DEAD) || + (sk->sk_shutdown & RCV_SHUTDOWN) || + signal_pending(current) || + !timeo) + break; + + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + + if (sock_flag(sk, SOCK_DEAD)) + break; + + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + } + + finish_wait(sk_sleep(sk), &wait); + release_sock(sk); + return timeo; +} + + +/* + * Copied from unix_stream_recvmsg, but removed credit checks, + * changed locking calls, changed address handling. + */ +static int caif_stream_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + int copied = 0; + int target; + int err = 0; + long timeo; + + err = -EOPNOTSUPP; + if (flags&MSG_OOB) + goto out; + + /* + * Lock the socket to prevent queue disordering + * while sleeps in memcpy_tomsg + */ + err = -EAGAIN; + if (sk->sk_state == CAIF_CONNECTING) + goto out; + + caif_read_lock(sk); + target = sock_rcvlowat(sk, flags&MSG_WAITALL, size); + timeo = sock_rcvtimeo(sk, flags&MSG_DONTWAIT); + + do { + int chunk; + struct sk_buff *skb; + + lock_sock(sk); + if (sock_flag(sk, SOCK_DEAD)) { + err = -ECONNRESET; + goto unlock; + } + skb = skb_dequeue(&sk->sk_receive_queue); + caif_check_flow_release(sk); + + if (skb == NULL) { + if (copied >= target) + goto unlock; + /* + * POSIX 1003.1g mandates this order. 
+ */ + err = sock_error(sk); + if (err) + goto unlock; + err = -ECONNRESET; + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto unlock; + + err = -EPIPE; + if (sk->sk_state != CAIF_CONNECTED) + goto unlock; + if (sock_flag(sk, SOCK_DEAD)) + goto unlock; + + release_sock(sk); + + err = -EAGAIN; + if (!timeo) + break; + + caif_read_unlock(sk); + + timeo = caif_stream_data_wait(sk, timeo); + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + goto out; + } + caif_read_lock(sk); + continue; +unlock: + release_sock(sk); + break; + } + release_sock(sk); + chunk = min_t(unsigned int, skb->len, size); + if (memcpy_to_msg(msg, skb->data, chunk)) { + skb_queue_head(&sk->sk_receive_queue, skb); + if (copied == 0) + copied = -EFAULT; + break; + } + copied += chunk; + size -= chunk; + + /* Mark read part of skb as used */ + if (!(flags & MSG_PEEK)) { + skb_pull(skb, chunk); + + /* put the skb back if we didn't use it up. */ + if (skb->len) { + skb_queue_head(&sk->sk_receive_queue, skb); + break; + } + kfree_skb(skb); + + } else { + /* + * It is questionable, see note in unix_dgram_recvmsg. + */ + /* put message back and return */ + skb_queue_head(&sk->sk_receive_queue, skb); + break; + } + } while (size); + caif_read_unlock(sk); + +out: + return copied ? : err; +} + +/* + * Copied from sock.c:sock_wait_for_wmem, but change to wait for + * CAIF flow-on and sock_writable. + */ +static long caif_wait_for_flow_on(struct caifsock *cf_sk, + int wait_writeable, long timeo, int *err) +{ + struct sock *sk = &cf_sk->sk; + DEFINE_WAIT(wait); + for (;;) { + *err = 0; + if (tx_flow_is_on(cf_sk) && + (!wait_writeable || sock_writeable(&cf_sk->sk))) + break; + *err = -ETIMEDOUT; + if (!timeo) + break; + *err = -ERESTARTSYS; + if (signal_pending(current)) + break; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + *err = -ECONNRESET; + if (sk->sk_shutdown & SHUTDOWN_MASK) + break; + *err = -sk->sk_err; + if (sk->sk_err) + break; + *err = -EPIPE; + if (cf_sk->sk.sk_state != CAIF_CONNECTED) + break; + timeo = schedule_timeout(timeo); + } + finish_wait(sk_sleep(sk), &wait); + return timeo; +} + +/* + * Transmit a SKB. The device may temporarily request re-transmission + * by returning EAGAIN. 
+ */ +static int transmit_skb(struct sk_buff *skb, struct caifsock *cf_sk, + int noblock, long timeo) +{ + struct cfpkt *pkt; + + pkt = cfpkt_fromnative(CAIF_DIR_OUT, skb); + memset(skb->cb, 0, sizeof(struct caif_payload_info)); + cfpkt_set_prio(pkt, cf_sk->sk.sk_priority); + + if (cf_sk->layer.dn == NULL) { + kfree_skb(skb); + return -EINVAL; + } + + return cf_sk->layer.dn->transmit(cf_sk->layer.dn, pkt); +} + +/* Copied from af_unix:unix_dgram_sendmsg, and adapted to CAIF */ +static int caif_seqpkt_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + int buffer_size; + int ret = 0; + struct sk_buff *skb = NULL; + int noblock; + long timeo; + caif_assert(cf_sk); + ret = sock_error(sk); + if (ret) + goto err; + + ret = -EOPNOTSUPP; + if (msg->msg_flags&MSG_OOB) + goto err; + + ret = -EOPNOTSUPP; + if (msg->msg_namelen) + goto err; + + ret = -EINVAL; + if (unlikely(msg->msg_iter.nr_segs == 0) || + unlikely(msg->msg_iter.iov->iov_base == NULL)) + goto err; + noblock = msg->msg_flags & MSG_DONTWAIT; + + timeo = sock_sndtimeo(sk, noblock); + timeo = caif_wait_for_flow_on(container_of(sk, struct caifsock, sk), + 1, timeo, &ret); + + if (ret) + goto err; + ret = -EPIPE; + if (cf_sk->sk.sk_state != CAIF_CONNECTED || + sock_flag(sk, SOCK_DEAD) || + (sk->sk_shutdown & RCV_SHUTDOWN)) + goto err; + + /* Error if trying to write more than maximum frame size. */ + ret = -EMSGSIZE; + if (len > cf_sk->maxframe && cf_sk->sk.sk_protocol != CAIFPROTO_RFM) + goto err; + + buffer_size = len + cf_sk->headroom + cf_sk->tailroom; + + ret = -ENOMEM; + skb = sock_alloc_send_skb(sk, buffer_size, noblock, &ret); + + if (!skb || skb_tailroom(skb) < buffer_size) + goto err; + + skb_reserve(skb, cf_sk->headroom); + + ret = memcpy_from_msg(skb_put(skb, len), msg, len); + + if (ret) + goto err; + ret = transmit_skb(skb, cf_sk, noblock, timeo); + if (ret < 0) + /* skb is already freed */ + return ret; + + return len; +err: + kfree_skb(skb); + return ret; +} + +/* + * Copied from unix_stream_sendmsg and adapted to CAIF: + * Changed removed permission handling and added waiting for flow on + * and other minor adaptations. + */ +static int caif_stream_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + int err, size; + struct sk_buff *skb; + int sent = 0; + long timeo; + + err = -EOPNOTSUPP; + if (unlikely(msg->msg_flags&MSG_OOB)) + goto out_err; + + if (unlikely(msg->msg_namelen)) + goto out_err; + + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + timeo = caif_wait_for_flow_on(cf_sk, 1, timeo, &err); + + if (unlikely(sk->sk_shutdown & SEND_SHUTDOWN)) + goto pipe_err; + + while (sent < len) { + + size = len-sent; + + if (size > cf_sk->maxframe) + size = cf_sk->maxframe; + + /* If size is more than half of sndbuf, chop up message */ + if (size > ((sk->sk_sndbuf >> 1) - 64)) + size = (sk->sk_sndbuf >> 1) - 64; + + if (size > SKB_MAX_ALLOC) + size = SKB_MAX_ALLOC; + + skb = sock_alloc_send_skb(sk, + size + cf_sk->headroom + + cf_sk->tailroom, + msg->msg_flags&MSG_DONTWAIT, + &err); + if (skb == NULL) + goto out_err; + + skb_reserve(skb, cf_sk->headroom); + /* + * If you pass two values to the sock_alloc_send_skb + * it tries to grab the large buffer with GFP_NOFS + * (which can fail easily), and if it fails grab the + * fallback size buffer which is under a page and will + * succeed. 
[Alan] + */ + size = min_t(int, size, skb_tailroom(skb)); + + err = memcpy_from_msg(skb_put(skb, size), msg, size); + if (err) { + kfree_skb(skb); + goto out_err; + } + err = transmit_skb(skb, cf_sk, + msg->msg_flags&MSG_DONTWAIT, timeo); + if (err < 0) + /* skb is already freed */ + goto pipe_err; + + sent += size; + } + + return sent; + +pipe_err: + if (sent == 0 && !(msg->msg_flags&MSG_NOSIGNAL)) + send_sig(SIGPIPE, current, 0); + err = -EPIPE; +out_err: + return sent ? : err; +} + +static int setsockopt(struct socket *sock, int lvl, int opt, sockptr_t ov, + unsigned int ol) +{ + struct sock *sk = sock->sk; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + int linksel; + + if (cf_sk->sk.sk_socket->state != SS_UNCONNECTED) + return -ENOPROTOOPT; + + switch (opt) { + case CAIFSO_LINK_SELECT: + if (ol < sizeof(int)) + return -EINVAL; + if (lvl != SOL_CAIF) + goto bad_sol; + if (copy_from_sockptr(&linksel, ov, sizeof(int))) + return -EINVAL; + lock_sock(&(cf_sk->sk)); + cf_sk->conn_req.link_selector = linksel; + release_sock(&cf_sk->sk); + return 0; + + case CAIFSO_REQ_PARAM: + if (lvl != SOL_CAIF) + goto bad_sol; + if (cf_sk->sk.sk_protocol != CAIFPROTO_UTIL) + return -ENOPROTOOPT; + lock_sock(&(cf_sk->sk)); + if (ol > sizeof(cf_sk->conn_req.param.data) || + copy_from_sockptr(&cf_sk->conn_req.param.data, ov, ol)) { + release_sock(&cf_sk->sk); + return -EINVAL; + } + cf_sk->conn_req.param.size = ol; + release_sock(&cf_sk->sk); + return 0; + + default: + return -ENOPROTOOPT; + } + + return 0; +bad_sol: + return -ENOPROTOOPT; + +} + +/* + * caif_connect() - Connect a CAIF Socket + * Copied and modified af_irda.c:irda_connect(). + * + * Note : by consulting "errno", the user space caller may learn the cause + * of the failure. Most of them are visible in the function, others may come + * from subroutines called and are listed here : + * o -EAFNOSUPPORT: bad socket family or type. + * o -ESOCKTNOSUPPORT: bad socket type or protocol + * o -EINVAL: bad socket address, or CAIF link type + * o -ECONNREFUSED: remote end refused the connection. + * o -EINPROGRESS: connect request sent but timed out (or non-blocking) + * o -EISCONN: already connected. + * o -ETIMEDOUT: Connection timed out (send timeout) + * o -ENODEV: No link layer to send request + * o -ECONNRESET: Received Shutdown indication or lost link layer + * o -ENOMEM: Out of memory + * + * State Strategy: + * o sk_state: holds the CAIF_* protocol state, it's updated by + * caif_ctrl_cb. + * o sock->state: holds the SS_* socket state and is updated by connect and + * disconnect. 
+ */ +static int caif_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + long timeo; + int err; + int ifindex, headroom, tailroom; + unsigned int mtu; + struct net_device *dev; + + lock_sock(sk); + + err = -EINVAL; + if (addr_len < offsetofend(struct sockaddr, sa_family)) + goto out; + + err = -EAFNOSUPPORT; + if (uaddr->sa_family != AF_CAIF) + goto out; + + switch (sock->state) { + case SS_UNCONNECTED: + /* Normal case, a fresh connect */ + caif_assert(sk->sk_state == CAIF_DISCONNECTED); + break; + case SS_CONNECTING: + switch (sk->sk_state) { + case CAIF_CONNECTED: + sock->state = SS_CONNECTED; + err = -EISCONN; + goto out; + case CAIF_DISCONNECTED: + /* Reconnect allowed */ + break; + case CAIF_CONNECTING: + err = -EALREADY; + if (flags & O_NONBLOCK) + goto out; + goto wait_connect; + } + break; + case SS_CONNECTED: + caif_assert(sk->sk_state == CAIF_CONNECTED || + sk->sk_state == CAIF_DISCONNECTED); + if (sk->sk_shutdown & SHUTDOWN_MASK) { + /* Allow re-connect after SHUTDOWN_IND */ + caif_disconnect_client(sock_net(sk), &cf_sk->layer); + caif_free_client(&cf_sk->layer); + break; + } + /* No reconnect on a seqpacket socket */ + err = -EISCONN; + goto out; + case SS_DISCONNECTING: + case SS_FREE: + caif_assert(1); /*Should never happen */ + break; + } + sk->sk_state = CAIF_DISCONNECTED; + sock->state = SS_UNCONNECTED; + sk_stream_kill_queues(&cf_sk->sk); + + err = -EINVAL; + if (addr_len != sizeof(struct sockaddr_caif)) + goto out; + + memcpy(&cf_sk->conn_req.sockaddr, uaddr, + sizeof(struct sockaddr_caif)); + + /* Move to connecting socket, start sending Connect Requests */ + sock->state = SS_CONNECTING; + sk->sk_state = CAIF_CONNECTING; + + /* Check priority value comming from socket */ + /* if priority value is out of range it will be ajusted */ + if (cf_sk->sk.sk_priority > CAIF_PRIO_MAX) + cf_sk->conn_req.priority = CAIF_PRIO_MAX; + else if (cf_sk->sk.sk_priority < CAIF_PRIO_MIN) + cf_sk->conn_req.priority = CAIF_PRIO_MIN; + else + cf_sk->conn_req.priority = cf_sk->sk.sk_priority; + + /*ifindex = id of the interface.*/ + cf_sk->conn_req.ifindex = cf_sk->sk.sk_bound_dev_if; + + cf_sk->layer.receive = caif_sktrecv_cb; + + err = caif_connect_client(sock_net(sk), &cf_sk->conn_req, + &cf_sk->layer, &ifindex, &headroom, &tailroom); + + if (err < 0) { + cf_sk->sk.sk_socket->state = SS_UNCONNECTED; + cf_sk->sk.sk_state = CAIF_DISCONNECTED; + goto out; + } + + err = -ENODEV; + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), ifindex); + if (!dev) { + rcu_read_unlock(); + goto out; + } + cf_sk->headroom = LL_RESERVED_SPACE_EXTRA(dev, headroom); + mtu = dev->mtu; + rcu_read_unlock(); + + cf_sk->tailroom = tailroom; + cf_sk->maxframe = mtu - (headroom + tailroom); + if (cf_sk->maxframe < 1) { + pr_warn("CAIF Interface MTU too small (%d)\n", dev->mtu); + err = -ENODEV; + goto out; + } + + err = -EINPROGRESS; +wait_connect: + + if (sk->sk_state != CAIF_CONNECTED && (flags & O_NONBLOCK)) + goto out; + + timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); + + release_sock(sk); + err = -ERESTARTSYS; + timeo = wait_event_interruptible_timeout(*sk_sleep(sk), + sk->sk_state != CAIF_CONNECTING, + timeo); + lock_sock(sk); + if (timeo < 0) + goto out; /* -ERESTARTSYS */ + + err = -ETIMEDOUT; + if (timeo == 0 && sk->sk_state != CAIF_CONNECTED) + goto out; + if (sk->sk_state != CAIF_CONNECTED) { + sock->state = SS_UNCONNECTED; + err = sock_error(sk); + if (!err) + err = 
-ECONNREFUSED; + goto out; + } + sock->state = SS_CONNECTED; + err = 0; +out: + release_sock(sk); + return err; +} + +/* + * caif_release() - Disconnect a CAIF Socket + * Copied and modified af_irda.c:irda_release(). + */ +static int caif_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + + if (!sk) + return 0; + + set_tx_flow_off(cf_sk); + + /* + * Ensure that packets are not queued after this point in time. + * caif_queue_rcv_skb checks SOCK_DEAD holding the queue lock, + * this ensures no packets when sock is dead. + */ + spin_lock_bh(&sk->sk_receive_queue.lock); + sock_set_flag(sk, SOCK_DEAD); + spin_unlock_bh(&sk->sk_receive_queue.lock); + sock->sk = NULL; + + WARN_ON(IS_ERR(cf_sk->debugfs_socket_dir)); + debugfs_remove_recursive(cf_sk->debugfs_socket_dir); + + lock_sock(&(cf_sk->sk)); + sk->sk_state = CAIF_DISCONNECTED; + sk->sk_shutdown = SHUTDOWN_MASK; + + caif_disconnect_client(sock_net(sk), &cf_sk->layer); + cf_sk->sk.sk_socket->state = SS_DISCONNECTING; + wake_up_interruptible_poll(sk_sleep(sk), EPOLLERR|EPOLLHUP); + + sock_orphan(sk); + sk_stream_kill_queues(&cf_sk->sk); + release_sock(sk); + sock_put(sk); + return 0; +} + +/* Copied from af_unix.c:unix_poll(), added CAIF tx_flow handling */ +static __poll_t caif_poll(struct file *file, + struct socket *sock, poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask; + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + + sock_poll_wait(file, sock, wait); + mask = 0; + + /* exceptional events? */ + if (sk->sk_err) + mask |= EPOLLERR; + if (sk->sk_shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP; + + /* readable? */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue) || + (sk->sk_shutdown & RCV_SHUTDOWN)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* + * we set writable also when the other side has shut down the + * connection. This prevents stuck sockets. + */ + if (sock_writeable(sk) && tx_flow_is_on(cf_sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + return mask; +} + +static const struct proto_ops caif_seqpacket_ops = { + .family = PF_CAIF, + .owner = THIS_MODULE, + .release = caif_release, + .bind = sock_no_bind, + .connect = caif_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = caif_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = setsockopt, + .sendmsg = caif_seqpkt_sendmsg, + .recvmsg = caif_seqpkt_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static const struct proto_ops caif_stream_ops = { + .family = PF_CAIF, + .owner = THIS_MODULE, + .release = caif_release, + .bind = sock_no_bind, + .connect = caif_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = caif_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = setsockopt, + .sendmsg = caif_stream_sendmsg, + .recvmsg = caif_stream_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +/* This function is called when a socket is finally destroyed. 
*/ +static void caif_sock_destructor(struct sock *sk) +{ + struct caifsock *cf_sk = container_of(sk, struct caifsock, sk); + caif_assert(!refcount_read(&sk->sk_wmem_alloc)); + caif_assert(sk_unhashed(sk)); + caif_assert(!sk->sk_socket); + if (!sock_flag(sk, SOCK_DEAD)) { + pr_debug("Attempt to release alive CAIF socket: %p\n", sk); + return; + } + sk_stream_kill_queues(&cf_sk->sk); + WARN_ON_ONCE(sk->sk_forward_alloc); + caif_free_client(&cf_sk->layer); +} + +static int caif_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk = NULL; + struct caifsock *cf_sk = NULL; + static struct proto prot = {.name = "PF_CAIF", + .owner = THIS_MODULE, + .obj_size = sizeof(struct caifsock), + .useroffset = offsetof(struct caifsock, conn_req.param), + .usersize = sizeof_field(struct caifsock, conn_req.param) + }; + + if (!capable(CAP_SYS_ADMIN) && !capable(CAP_NET_ADMIN)) + return -EPERM; + /* + * The sock->type specifies the socket type to use. + * The CAIF socket is a packet stream in the sense + * that it is packet based. CAIF trusts the reliability + * of the link, no resending is implemented. + */ + if (sock->type == SOCK_SEQPACKET) + sock->ops = &caif_seqpacket_ops; + else if (sock->type == SOCK_STREAM) + sock->ops = &caif_stream_ops; + else + return -ESOCKTNOSUPPORT; + + if (protocol < 0 || protocol >= CAIFPROTO_MAX) + return -EPROTONOSUPPORT; + /* + * Set the socket state to unconnected. The socket state + * is really not used at all in the net/core or socket.c but the + * initialization makes sure that sock->state is not uninitialized. + */ + sk = sk_alloc(net, PF_CAIF, GFP_KERNEL, &prot, kern); + if (!sk) + return -ENOMEM; + + cf_sk = container_of(sk, struct caifsock, sk); + + /* Store the protocol */ + sk->sk_protocol = (unsigned char) protocol; + + /* Initialize default priority for well-known cases */ + switch (protocol) { + case CAIFPROTO_AT: + sk->sk_priority = TC_PRIO_CONTROL; + break; + case CAIFPROTO_RFM: + sk->sk_priority = TC_PRIO_INTERACTIVE_BULK; + break; + default: + sk->sk_priority = TC_PRIO_BESTEFFORT; + } + + /* + * Lock in order to try to stop someone from opening the socket + * too early. + */ + lock_sock(&(cf_sk->sk)); + + /* Initialize the nozero default sock structure data. 
*/ + sock_init_data(sock, sk); + sk->sk_destruct = caif_sock_destructor; + + mutex_init(&cf_sk->readlock); /* single task reading lock */ + cf_sk->layer.ctrlcmd = caif_ctrl_cb; + cf_sk->sk.sk_socket->state = SS_UNCONNECTED; + cf_sk->sk.sk_state = CAIF_DISCONNECTED; + + set_tx_flow_off(cf_sk); + set_rx_flow_on(cf_sk); + + /* Set default options on configuration */ + cf_sk->conn_req.link_selector = CAIF_LINK_LOW_LATENCY; + cf_sk->conn_req.protocol = protocol; + release_sock(&cf_sk->sk); + return 0; +} + + +static const struct net_proto_family caif_family_ops = { + .family = PF_CAIF, + .create = caif_create, + .owner = THIS_MODULE, +}; + +static int __init caif_sktinit_module(void) +{ + return sock_register(&caif_family_ops); +} + +static void __exit caif_sktexit_module(void) +{ + sock_unregister(PF_CAIF); +} +module_init(caif_sktinit_module); +module_exit(caif_sktexit_module); diff --git a/net/caif/caif_usb.c b/net/caif/caif_usb.c new file mode 100644 index 000000000..bf61ea4b8 --- /dev/null +++ b/net/caif/caif_usb.c @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * CAIF USB handler + * Copyright (C) ST-Ericsson AB 2011 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); + +#define CFUSB_PAD_DESCR_SZ 1 /* Alignment descriptor length */ +#define CFUSB_ALIGNMENT 4 /* Number of bytes to align. */ +#define CFUSB_MAX_HEADLEN (CFUSB_PAD_DESCR_SZ + CFUSB_ALIGNMENT-1) +#define STE_USB_VID 0x04cc /* USB Product ID for ST-Ericsson */ +#define STE_USB_PID_CAIF 0x230f /* Product id for CAIF Modems */ + +struct cfusbl { + struct cflayer layer; + u8 tx_eth_hdr[ETH_HLEN]; +}; + +static bool pack_added; + +static int cfusbl_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 hpad; + + /* Remove padding. 
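+ * The first byte holds the pad length; strip it and then the padding itself before handing the packet upwards.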
*/ + cfpkt_extr_head(pkt, &hpad, 1); + cfpkt_extr_head(pkt, NULL, hpad); + return layr->up->receive(layr->up, pkt); +} + +static int cfusbl_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + struct caif_payload_info *info; + u8 hpad; + u8 zeros[CFUSB_ALIGNMENT]; + struct sk_buff *skb; + struct cfusbl *usbl = container_of(layr, struct cfusbl, layer); + + skb = cfpkt_tonative(pkt); + + skb_reset_network_header(skb); + skb->protocol = htons(ETH_P_IP); + + info = cfpkt_info(pkt); + hpad = (info->hdr_len + CFUSB_PAD_DESCR_SZ) & (CFUSB_ALIGNMENT - 1); + + if (skb_headroom(skb) < ETH_HLEN + CFUSB_PAD_DESCR_SZ + hpad) { + pr_warn("Headroom too small\n"); + kfree_skb(skb); + return -EIO; + } + memset(zeros, 0, hpad); + + cfpkt_add_head(pkt, zeros, hpad); + cfpkt_add_head(pkt, &hpad, 1); + cfpkt_add_head(pkt, usbl->tx_eth_hdr, sizeof(usbl->tx_eth_hdr)); + return layr->dn->transmit(layr->dn, pkt); +} + +static void cfusbl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + if (layr->up && layr->up->ctrlcmd) + layr->up->ctrlcmd(layr->up, ctrl, layr->id); +} + +static struct cflayer *cfusbl_create(int phyid, const u8 ethaddr[ETH_ALEN], + u8 braddr[ETH_ALEN]) +{ + struct cfusbl *this = kmalloc(sizeof(struct cfusbl), GFP_ATOMIC); + + if (!this) + return NULL; + + caif_assert(offsetof(struct cfusbl, layer) == 0); + + memset(&this->layer, 0, sizeof(this->layer)); + this->layer.receive = cfusbl_receive; + this->layer.transmit = cfusbl_transmit; + this->layer.ctrlcmd = cfusbl_ctrlcmd; + snprintf(this->layer.name, CAIF_LAYER_NAME_SZ, "usb%d", phyid); + this->layer.id = phyid; + + /* + * Construct TX ethernet header: + * 0-5 destination address + * 5-11 source address + * 12-13 protocol type + */ + ether_addr_copy(&this->tx_eth_hdr[ETH_ALEN], braddr); + ether_addr_copy(&this->tx_eth_hdr[ETH_ALEN], ethaddr); + this->tx_eth_hdr[12] = cpu_to_be16(ETH_P_802_EX1) & 0xff; + this->tx_eth_hdr[13] = (cpu_to_be16(ETH_P_802_EX1) >> 8) & 0xff; + pr_debug("caif ethernet TX-header dst:%pM src:%pM type:%02x%02x\n", + this->tx_eth_hdr, this->tx_eth_hdr + ETH_ALEN, + this->tx_eth_hdr[12], this->tx_eth_hdr[13]); + + return (struct cflayer *) this; +} + +static void cfusbl_release(struct cflayer *layer) +{ + kfree(layer); +} + +static struct packet_type caif_usb_type __read_mostly = { + .type = cpu_to_be16(ETH_P_802_EX1), +}; + +static int cfusbl_device_notify(struct notifier_block *me, unsigned long what, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct caif_dev_common common; + struct cflayer *layer, *link_support; + struct usbnet *usbnet; + struct usb_device *usbdev; + int res; + + if (what == NETDEV_UNREGISTER && dev->reg_state >= NETREG_UNREGISTERED) + return 0; + + /* Check whether we have a NCM device, and find its VID/PID. 
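+ * Only network devices backed by the cdc_ncm driver are considered.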
*/ + if (!(dev->dev.parent && dev->dev.parent->driver && + strcmp(dev->dev.parent->driver->name, "cdc_ncm") == 0)) + return 0; + + usbnet = netdev_priv(dev); + usbdev = usbnet->udev; + + pr_debug("USB CDC NCM device VID:0x%4x PID:0x%4x\n", + le16_to_cpu(usbdev->descriptor.idVendor), + le16_to_cpu(usbdev->descriptor.idProduct)); + + /* Check for VID/PID that supports CAIF */ + if (!(le16_to_cpu(usbdev->descriptor.idVendor) == STE_USB_VID && + le16_to_cpu(usbdev->descriptor.idProduct) == STE_USB_PID_CAIF)) + return 0; + + if (what == NETDEV_UNREGISTER) + module_put(THIS_MODULE); + + if (what != NETDEV_REGISTER) + return 0; + + __module_get(THIS_MODULE); + + memset(&common, 0, sizeof(common)); + common.use_frag = false; + common.use_fcs = false; + common.use_stx = false; + common.link_select = CAIF_LINK_HIGH_BANDW; + common.flowctrl = NULL; + + link_support = cfusbl_create(dev->ifindex, dev->dev_addr, + dev->broadcast); + + if (!link_support) + return -ENOMEM; + + if (dev->num_tx_queues > 1) + pr_warn("USB device uses more than one tx queue\n"); + + res = caif_enroll_dev(dev, &common, link_support, CFUSB_MAX_HEADLEN, + &layer, &caif_usb_type.func); + if (res) + goto err; + + if (!pack_added) + dev_add_pack(&caif_usb_type); + pack_added = true; + + strscpy(layer->name, dev->name, sizeof(layer->name)); + + return 0; +err: + cfusbl_release(link_support); + return res; +} + +static struct notifier_block caif_device_notifier = { + .notifier_call = cfusbl_device_notify, + .priority = 0, +}; + +static int __init cfusbl_init(void) +{ + return register_netdevice_notifier(&caif_device_notifier); +} + +static void __exit cfusbl_exit(void) +{ + unregister_netdevice_notifier(&caif_device_notifier); + dev_remove_pack(&caif_usb_type); +} + +module_init(cfusbl_init); +module_exit(cfusbl_exit); diff --git a/net/caif/cfcnfg.c b/net/caif/cfcnfg.c new file mode 100644 index 000000000..52509e185 --- /dev/null +++ b/net/caif/cfcnfg.c @@ -0,0 +1,612 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) container_of(layr, struct cfcnfg, layer) + +/* Information about CAIF physical interfaces held by Config Module in order + * to manage physical interfaces + */ +struct cfcnfg_phyinfo { + struct list_head node; + bool up; + + /* Pointer to the layer below the MUX (framing layer) */ + struct cflayer *frm_layer; + /* Pointer to the lowest actual physical layer */ + struct cflayer *phy_layer; + /* Unique identifier of the physical interface */ + unsigned int id; + /* Preference of the physical in interface */ + enum cfcnfg_phy_preference pref; + + /* Information about the physical device */ + struct dev_info dev_info; + + /* Interface index */ + int ifindex; + + /* Protocol head room added for CAIF link layer */ + int head_room; + + /* Use Start of frame checksum */ + bool use_fcs; +}; + +struct cfcnfg { + struct cflayer layer; + struct cflayer *ctrl; + struct cflayer *mux; + struct list_head phys; + struct mutex lock; +}; + +static void cfcnfg_linkup_rsp(struct cflayer *layer, u8 channel_id, + enum cfctrl_srv serv, u8 phyid, + struct cflayer *adapt_layer); +static void cfcnfg_linkdestroy_rsp(struct cflayer *layer, u8 channel_id); +static void cfcnfg_reject_rsp(struct cflayer *layer, u8 channel_id, + struct cflayer 
*adapt_layer); +static void cfctrl_resp_func(void); +static void cfctrl_enum_resp(void); + +struct cfcnfg *cfcnfg_create(void) +{ + struct cfcnfg *this; + struct cfctrl_rsp *resp; + + might_sleep(); + + /* Initiate this layer */ + this = kzalloc(sizeof(struct cfcnfg), GFP_ATOMIC); + if (!this) + return NULL; + this->mux = cfmuxl_create(); + if (!this->mux) + goto out_of_mem; + this->ctrl = cfctrl_create(); + if (!this->ctrl) + goto out_of_mem; + /* Initiate response functions */ + resp = cfctrl_get_respfuncs(this->ctrl); + resp->enum_rsp = cfctrl_enum_resp; + resp->linkerror_ind = cfctrl_resp_func; + resp->linkdestroy_rsp = cfcnfg_linkdestroy_rsp; + resp->sleep_rsp = cfctrl_resp_func; + resp->wake_rsp = cfctrl_resp_func; + resp->restart_rsp = cfctrl_resp_func; + resp->radioset_rsp = cfctrl_resp_func; + resp->linksetup_rsp = cfcnfg_linkup_rsp; + resp->reject_rsp = cfcnfg_reject_rsp; + INIT_LIST_HEAD(&this->phys); + + cfmuxl_set_uplayer(this->mux, this->ctrl, 0); + layer_set_dn(this->ctrl, this->mux); + layer_set_up(this->ctrl, this); + mutex_init(&this->lock); + + return this; +out_of_mem: + synchronize_rcu(); + + kfree(this->mux); + kfree(this->ctrl); + kfree(this); + return NULL; +} + +void cfcnfg_remove(struct cfcnfg *cfg) +{ + might_sleep(); + if (cfg) { + synchronize_rcu(); + + kfree(cfg->mux); + cfctrl_remove(cfg->ctrl); + kfree(cfg); + } +} + +static void cfctrl_resp_func(void) +{ +} + +static struct cfcnfg_phyinfo *cfcnfg_get_phyinfo_rcu(struct cfcnfg *cnfg, + u8 phyid) +{ + struct cfcnfg_phyinfo *phy; + + list_for_each_entry_rcu(phy, &cnfg->phys, node) + if (phy->id == phyid) + return phy; + return NULL; +} + +static void cfctrl_enum_resp(void) +{ +} + +static struct dev_info *cfcnfg_get_phyid(struct cfcnfg *cnfg, + enum cfcnfg_phy_preference phy_pref) +{ + /* Try to match with specified preference */ + struct cfcnfg_phyinfo *phy; + + list_for_each_entry_rcu(phy, &cnfg->phys, node) { + if (phy->up && phy->pref == phy_pref && + phy->frm_layer != NULL) + + return &phy->dev_info; + } + + /* Otherwise just return something */ + list_for_each_entry_rcu(phy, &cnfg->phys, node) + if (phy->up) + return &phy->dev_info; + + return NULL; +} + +static int cfcnfg_get_id_from_ifi(struct cfcnfg *cnfg, int ifi) +{ + struct cfcnfg_phyinfo *phy; + + list_for_each_entry_rcu(phy, &cnfg->phys, node) + if (phy->ifindex == ifi && phy->up) + return phy->id; + return -ENODEV; +} + +int caif_disconnect_client(struct net *net, struct cflayer *adap_layer) +{ + u8 channel_id; + struct cfcnfg *cfg = get_cfcnfg(net); + + caif_assert(adap_layer != NULL); + cfctrl_cancel_req(cfg->ctrl, adap_layer); + channel_id = adap_layer->id; + if (channel_id != 0) { + struct cflayer *servl; + servl = cfmuxl_remove_uplayer(cfg->mux, channel_id); + cfctrl_linkdown_req(cfg->ctrl, channel_id, adap_layer); + if (servl != NULL) + layer_set_up(servl, NULL); + } else + pr_debug("nothing to disconnect\n"); + + /* Do RCU sync before initiating cleanup */ + synchronize_rcu(); + if (adap_layer->ctrlcmd != NULL) + adap_layer->ctrlcmd(adap_layer, CAIF_CTRLCMD_DEINIT_RSP, 0); + return 0; + +} +EXPORT_SYMBOL(caif_disconnect_client); + +static void cfcnfg_linkdestroy_rsp(struct cflayer *layer, u8 channel_id) +{ +} + +static const int protohead[CFCTRL_SRV_MASK] = { + [CFCTRL_SRV_VEI] = 4, + [CFCTRL_SRV_DATAGRAM] = 7, + [CFCTRL_SRV_UTIL] = 4, + [CFCTRL_SRV_RFM] = 3, + [CFCTRL_SRV_DBG] = 3, +}; + + +static int caif_connect_req_to_link_param(struct cfcnfg *cnfg, + struct caif_connect_request *s, + struct cfctrl_link_param *l) +{ + struct dev_info 
*dev_info; + enum cfcnfg_phy_preference pref; + int res; + + memset(l, 0, sizeof(*l)); + /* In caif protocol low value is high priority */ + l->priority = CAIF_PRIO_MAX - s->priority + 1; + + if (s->ifindex != 0) { + res = cfcnfg_get_id_from_ifi(cnfg, s->ifindex); + if (res < 0) + return res; + l->phyid = res; + } else { + switch (s->link_selector) { + case CAIF_LINK_HIGH_BANDW: + pref = CFPHYPREF_HIGH_BW; + break; + case CAIF_LINK_LOW_LATENCY: + pref = CFPHYPREF_LOW_LAT; + break; + default: + return -EINVAL; + } + dev_info = cfcnfg_get_phyid(cnfg, pref); + if (dev_info == NULL) + return -ENODEV; + l->phyid = dev_info->id; + } + switch (s->protocol) { + case CAIFPROTO_AT: + l->linktype = CFCTRL_SRV_VEI; + l->endpoint = (s->sockaddr.u.at.type >> 2) & 0x3; + l->chtype = s->sockaddr.u.at.type & 0x3; + break; + case CAIFPROTO_DATAGRAM: + l->linktype = CFCTRL_SRV_DATAGRAM; + l->chtype = 0x00; + l->u.datagram.connid = s->sockaddr.u.dgm.connection_id; + break; + case CAIFPROTO_DATAGRAM_LOOP: + l->linktype = CFCTRL_SRV_DATAGRAM; + l->chtype = 0x03; + l->endpoint = 0x00; + l->u.datagram.connid = s->sockaddr.u.dgm.connection_id; + break; + case CAIFPROTO_RFM: + l->linktype = CFCTRL_SRV_RFM; + l->u.datagram.connid = s->sockaddr.u.rfm.connection_id; + strscpy(l->u.rfm.volume, s->sockaddr.u.rfm.volume, + sizeof(l->u.rfm.volume)); + break; + case CAIFPROTO_UTIL: + l->linktype = CFCTRL_SRV_UTIL; + l->endpoint = 0x00; + l->chtype = 0x00; + strscpy(l->u.utility.name, s->sockaddr.u.util.service, + sizeof(l->u.utility.name)); + caif_assert(sizeof(l->u.utility.name) > 10); + l->u.utility.paramlen = s->param.size; + if (l->u.utility.paramlen > sizeof(l->u.utility.params)) + l->u.utility.paramlen = sizeof(l->u.utility.params); + + memcpy(l->u.utility.params, s->param.data, + l->u.utility.paramlen); + + break; + case CAIFPROTO_DEBUG: + l->linktype = CFCTRL_SRV_DBG; + l->endpoint = s->sockaddr.u.dbg.service; + l->chtype = s->sockaddr.u.dbg.type; + break; + default: + return -EINVAL; + } + return 0; +} + +int caif_connect_client(struct net *net, struct caif_connect_request *conn_req, + struct cflayer *adap_layer, int *ifindex, + int *proto_head, int *proto_tail) +{ + struct cflayer *frml; + struct cfcnfg_phyinfo *phy; + int err; + struct cfctrl_link_param param; + struct cfcnfg *cfg = get_cfcnfg(net); + + rcu_read_lock(); + err = caif_connect_req_to_link_param(cfg, conn_req, ¶m); + if (err) + goto unlock; + + phy = cfcnfg_get_phyinfo_rcu(cfg, param.phyid); + if (!phy) { + err = -ENODEV; + goto unlock; + } + err = -EINVAL; + + if (adap_layer == NULL) { + pr_err("adap_layer is zero\n"); + goto unlock; + } + if (adap_layer->receive == NULL) { + pr_err("adap_layer->receive is NULL\n"); + goto unlock; + } + if (adap_layer->ctrlcmd == NULL) { + pr_err("adap_layer->ctrlcmd == NULL\n"); + goto unlock; + } + + err = -ENODEV; + frml = phy->frm_layer; + if (frml == NULL) { + pr_err("Specified PHY type does not exist!\n"); + goto unlock; + } + caif_assert(param.phyid == phy->id); + caif_assert(phy->frm_layer->id == + param.phyid); + caif_assert(phy->phy_layer->id == + param.phyid); + + *ifindex = phy->ifindex; + *proto_tail = 2; + *proto_head = protohead[param.linktype] + phy->head_room; + + rcu_read_unlock(); + + /* FIXME: ENUMERATE INITIALLY WHEN ACTIVATING PHYSICAL INTERFACE */ + cfctrl_enum_req(cfg->ctrl, param.phyid); + return cfctrl_linkup_request(cfg->ctrl, ¶m, adap_layer); + +unlock: + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(caif_connect_client); + +static void cfcnfg_reject_rsp(struct cflayer *layer, u8 
channel_id, + struct cflayer *adapt_layer) +{ + if (adapt_layer != NULL && adapt_layer->ctrlcmd != NULL) + adapt_layer->ctrlcmd(adapt_layer, + CAIF_CTRLCMD_INIT_FAIL_RSP, 0); +} + +static void +cfcnfg_linkup_rsp(struct cflayer *layer, u8 channel_id, enum cfctrl_srv serv, + u8 phyid, struct cflayer *adapt_layer) +{ + struct cfcnfg *cnfg = container_obj(layer); + struct cflayer *servicel = NULL; + struct cfcnfg_phyinfo *phyinfo; + struct net_device *netdev; + + if (channel_id == 0) { + pr_warn("received channel_id zero\n"); + if (adapt_layer != NULL && adapt_layer->ctrlcmd != NULL) + adapt_layer->ctrlcmd(adapt_layer, + CAIF_CTRLCMD_INIT_FAIL_RSP, 0); + return; + } + + rcu_read_lock(); + + if (adapt_layer == NULL) { + pr_debug("link setup response but no client exist, send linkdown back\n"); + cfctrl_linkdown_req(cnfg->ctrl, channel_id, NULL); + goto unlock; + } + + caif_assert(cnfg != NULL); + caif_assert(phyid != 0); + + phyinfo = cfcnfg_get_phyinfo_rcu(cnfg, phyid); + if (phyinfo == NULL) { + pr_err("ERROR: Link Layer Device disappeared while connecting\n"); + goto unlock; + } + + caif_assert(phyinfo != NULL); + caif_assert(phyinfo->id == phyid); + caif_assert(phyinfo->phy_layer != NULL); + caif_assert(phyinfo->phy_layer->id == phyid); + + adapt_layer->id = channel_id; + + switch (serv) { + case CFCTRL_SRV_VEI: + servicel = cfvei_create(channel_id, &phyinfo->dev_info); + break; + case CFCTRL_SRV_DATAGRAM: + servicel = cfdgml_create(channel_id, + &phyinfo->dev_info); + break; + case CFCTRL_SRV_RFM: + netdev = phyinfo->dev_info.dev; + servicel = cfrfml_create(channel_id, &phyinfo->dev_info, + netdev->mtu); + break; + case CFCTRL_SRV_UTIL: + servicel = cfutill_create(channel_id, &phyinfo->dev_info); + break; + case CFCTRL_SRV_VIDEO: + servicel = cfvidl_create(channel_id, &phyinfo->dev_info); + break; + case CFCTRL_SRV_DBG: + servicel = cfdbgl_create(channel_id, &phyinfo->dev_info); + break; + default: + pr_err("Protocol error. 
Link setup response - unknown channel type\n"); + goto unlock; + } + if (!servicel) + goto unlock; + layer_set_dn(servicel, cnfg->mux); + cfmuxl_set_uplayer(cnfg->mux, servicel, channel_id); + layer_set_up(servicel, adapt_layer); + layer_set_dn(adapt_layer, servicel); + + rcu_read_unlock(); + + servicel->ctrlcmd(servicel, CAIF_CTRLCMD_INIT_RSP, 0); + return; +unlock: + rcu_read_unlock(); +} + +int +cfcnfg_add_phy_layer(struct cfcnfg *cnfg, + struct net_device *dev, struct cflayer *phy_layer, + enum cfcnfg_phy_preference pref, + struct cflayer *link_support, + bool fcs, int head_room) +{ + struct cflayer *frml; + struct cfcnfg_phyinfo *phyinfo = NULL; + int i, res = 0; + u8 phyid; + + mutex_lock(&cnfg->lock); + + /* CAIF protocol allow maximum 6 link-layers */ + for (i = 0; i < 7; i++) { + phyid = (dev->ifindex + i) & 0x7; + if (phyid == 0) + continue; + if (cfcnfg_get_phyinfo_rcu(cnfg, phyid) == NULL) + goto got_phyid; + } + pr_warn("Too many CAIF Link Layers (max 6)\n"); + res = -EEXIST; + goto out; + +got_phyid: + phyinfo = kzalloc(sizeof(struct cfcnfg_phyinfo), GFP_ATOMIC); + if (!phyinfo) { + res = -ENOMEM; + goto out; + } + + phy_layer->id = phyid; + phyinfo->pref = pref; + phyinfo->id = phyid; + phyinfo->dev_info.id = phyid; + phyinfo->dev_info.dev = dev; + phyinfo->phy_layer = phy_layer; + phyinfo->ifindex = dev->ifindex; + phyinfo->head_room = head_room; + phyinfo->use_fcs = fcs; + + frml = cffrml_create(phyid, fcs); + + if (!frml) { + res = -ENOMEM; + goto out_err; + } + phyinfo->frm_layer = frml; + layer_set_up(frml, cnfg->mux); + + if (link_support != NULL) { + link_support->id = phyid; + layer_set_dn(frml, link_support); + layer_set_up(link_support, frml); + layer_set_dn(link_support, phy_layer); + layer_set_up(phy_layer, link_support); + } else { + layer_set_dn(frml, phy_layer); + layer_set_up(phy_layer, frml); + } + + list_add_rcu(&phyinfo->node, &cnfg->phys); +out: + mutex_unlock(&cnfg->lock); + return res; + +out_err: + kfree(phyinfo); + mutex_unlock(&cnfg->lock); + return res; +} +EXPORT_SYMBOL(cfcnfg_add_phy_layer); + +int cfcnfg_set_phy_state(struct cfcnfg *cnfg, struct cflayer *phy_layer, + bool up) +{ + struct cfcnfg_phyinfo *phyinfo; + + rcu_read_lock(); + phyinfo = cfcnfg_get_phyinfo_rcu(cnfg, phy_layer->id); + if (phyinfo == NULL) { + rcu_read_unlock(); + return -ENODEV; + } + + if (phyinfo->up == up) { + rcu_read_unlock(); + return 0; + } + phyinfo->up = up; + + if (up) { + cffrml_hold(phyinfo->frm_layer); + cfmuxl_set_dnlayer(cnfg->mux, phyinfo->frm_layer, + phy_layer->id); + } else { + cfmuxl_remove_dnlayer(cnfg->mux, phy_layer->id); + cffrml_put(phyinfo->frm_layer); + } + + rcu_read_unlock(); + return 0; +} +EXPORT_SYMBOL(cfcnfg_set_phy_state); + +int cfcnfg_del_phy_layer(struct cfcnfg *cnfg, struct cflayer *phy_layer) +{ + struct cflayer *frml, *frml_dn; + u16 phyid; + struct cfcnfg_phyinfo *phyinfo; + + might_sleep(); + + mutex_lock(&cnfg->lock); + + phyid = phy_layer->id; + phyinfo = cfcnfg_get_phyinfo_rcu(cnfg, phyid); + + if (phyinfo == NULL) { + mutex_unlock(&cnfg->lock); + return 0; + } + caif_assert(phyid == phyinfo->id); + caif_assert(phy_layer == phyinfo->phy_layer); + caif_assert(phy_layer->id == phyid); + caif_assert(phyinfo->frm_layer->id == phyid); + + list_del_rcu(&phyinfo->node); + synchronize_rcu(); + + /* Fail if reference count is not zero */ + if (cffrml_refcnt_read(phyinfo->frm_layer) != 0) { + pr_info("Wait for device inuse\n"); + list_add_rcu(&phyinfo->node, &cnfg->phys); + mutex_unlock(&cnfg->lock); + return -EAGAIN; + } + + frml = 
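+	/*
+	 * What is being taken apart here is the per-PHY stack that
+	 * cfcnfg_add_phy_layer() wired up:
+	 *
+	 *        cnfg->mux
+	 *            |
+	 *      framing layer (cffrml, one per PHY)
+	 *            |
+	 *      [link support layer, optional]
+	 *            |
+	 *        phy_layer (the CAIF link layer driver)
+	 *
+	 * The phyinfo has already been unlinked from cnfg->phys under
+	 * RCU and the framing layer refcount verified to be zero, so
+	 * the layer pointers can now be detached and the layers freed.
+	 */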
phyinfo->frm_layer; + frml_dn = frml->dn; + cffrml_set_uplayer(frml, NULL); + cffrml_set_dnlayer(frml, NULL); + if (phy_layer != frml_dn) { + layer_set_up(frml_dn, NULL); + layer_set_dn(frml_dn, NULL); + } + layer_set_up(phy_layer, NULL); + + if (phyinfo->phy_layer != frml_dn) + kfree(frml_dn); + + cffrml_free(frml); + kfree(phyinfo); + mutex_unlock(&cnfg->lock); + + return 0; +} +EXPORT_SYMBOL(cfcnfg_del_phy_layer); diff --git a/net/caif/cfctrl.c b/net/caif/cfctrl.c new file mode 100644 index 000000000..8480684f2 --- /dev/null +++ b/net/caif/cfctrl.c @@ -0,0 +1,639 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) container_of(layr, struct cfctrl, serv.layer) +#define UTILITY_NAME_LENGTH 16 +#define CFPKT_CTRL_PKT_LEN 20 + +#ifdef CAIF_NO_LOOP +static int handle_loop(struct cfctrl *ctrl, + int cmd, struct cfpkt *pkt){ + return -1; +} +#else +static int handle_loop(struct cfctrl *ctrl, + int cmd, struct cfpkt *pkt); +#endif +static int cfctrl_recv(struct cflayer *layr, struct cfpkt *pkt); +static void cfctrl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid); + + +struct cflayer *cfctrl_create(void) +{ + struct dev_info dev_info; + struct cfctrl *this = + kzalloc(sizeof(struct cfctrl), GFP_ATOMIC); + if (!this) + return NULL; + caif_assert(offsetof(struct cfctrl, serv.layer) == 0); + memset(&dev_info, 0, sizeof(dev_info)); + dev_info.id = 0xff; + cfsrvl_init(&this->serv, 0, &dev_info, false); + atomic_set(&this->req_seq_no, 1); + atomic_set(&this->rsp_seq_no, 1); + this->serv.layer.receive = cfctrl_recv; + sprintf(this->serv.layer.name, "ctrl"); + this->serv.layer.ctrlcmd = cfctrl_ctrlcmd; +#ifndef CAIF_NO_LOOP + spin_lock_init(&this->loop_linkid_lock); + this->loop_linkid = 1; +#endif + spin_lock_init(&this->info_list_lock); + INIT_LIST_HEAD(&this->list); + return &this->serv.layer; +} + +void cfctrl_remove(struct cflayer *layer) +{ + struct cfctrl_request_info *p, *tmp; + struct cfctrl *ctrl = container_obj(layer); + + spin_lock_bh(&ctrl->info_list_lock); + list_for_each_entry_safe(p, tmp, &ctrl->list, list) { + list_del(&p->list); + kfree(p); + } + spin_unlock_bh(&ctrl->info_list_lock); + kfree(layer); +} + +static bool param_eq(const struct cfctrl_link_param *p1, + const struct cfctrl_link_param *p2) +{ + bool eq = + p1->linktype == p2->linktype && + p1->priority == p2->priority && + p1->phyid == p2->phyid && + p1->endpoint == p2->endpoint && p1->chtype == p2->chtype; + + if (!eq) + return false; + + switch (p1->linktype) { + case CFCTRL_SRV_VEI: + return true; + case CFCTRL_SRV_DATAGRAM: + return p1->u.datagram.connid == p2->u.datagram.connid; + case CFCTRL_SRV_RFM: + return + p1->u.rfm.connid == p2->u.rfm.connid && + strcmp(p1->u.rfm.volume, p2->u.rfm.volume) == 0; + case CFCTRL_SRV_UTIL: + return + p1->u.utility.fifosize_kb == p2->u.utility.fifosize_kb + && p1->u.utility.fifosize_bufs == + p2->u.utility.fifosize_bufs + && strcmp(p1->u.utility.name, p2->u.utility.name) == 0 + && p1->u.utility.paramlen == p2->u.utility.paramlen + && memcmp(p1->u.utility.params, p2->u.utility.params, + p1->u.utility.paramlen) == 0; + + case CFCTRL_SRV_VIDEO: + return p1->u.video.connid == p2->u.video.connid; + case CFCTRL_SRV_DBG: + return true; + case CFCTRL_SRV_DECM: + return false; + default: + return false; + } + return false; +} + +static bool 
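+/*
+ * Outstanding requests are tracked in ctrl->list as
+ * cfctrl_request_info entries and paired with responses by the helper
+ * below: a LINK_SETUP response is matched on the full parameter set
+ * (via param_eq() above), since the channel id is only assigned by
+ * the modem in its reply, while all other commands are matched on
+ * channel id.
+ *
+ *   cfctrl_linkup_request()  ->  cfctrl_insert_req()   queue request
+ *   cfctrl_recv()            ->  cfctrl_remove_req()   pair response
+ *                                                      with request
+ */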
cfctrl_req_eq(const struct cfctrl_request_info *r1, + const struct cfctrl_request_info *r2) +{ + if (r1->cmd != r2->cmd) + return false; + if (r1->cmd == CFCTRL_CMD_LINK_SETUP) + return param_eq(&r1->param, &r2->param); + else + return r1->channel_id == r2->channel_id; +} + +/* Insert request at the end */ +static void cfctrl_insert_req(struct cfctrl *ctrl, + struct cfctrl_request_info *req) +{ + spin_lock_bh(&ctrl->info_list_lock); + atomic_inc(&ctrl->req_seq_no); + req->sequence_no = atomic_read(&ctrl->req_seq_no); + list_add_tail(&req->list, &ctrl->list); + spin_unlock_bh(&ctrl->info_list_lock); +} + +/* Compare and remove request */ +static struct cfctrl_request_info *cfctrl_remove_req(struct cfctrl *ctrl, + struct cfctrl_request_info *req) +{ + struct cfctrl_request_info *p, *tmp, *first; + + first = list_first_entry(&ctrl->list, struct cfctrl_request_info, list); + + list_for_each_entry_safe(p, tmp, &ctrl->list, list) { + if (cfctrl_req_eq(req, p)) { + if (p != first) + pr_warn("Requests are not received in order\n"); + + atomic_set(&ctrl->rsp_seq_no, + p->sequence_no); + list_del(&p->list); + goto out; + } + } + p = NULL; +out: + return p; +} + +struct cfctrl_rsp *cfctrl_get_respfuncs(struct cflayer *layer) +{ + struct cfctrl *this = container_obj(layer); + return &this->res; +} + +static void init_info(struct caif_payload_info *info, struct cfctrl *cfctrl) +{ + info->hdr_len = 0; + info->channel_id = cfctrl->serv.layer.id; + info->dev_info = &cfctrl->serv.dev_info; +} + +void cfctrl_enum_req(struct cflayer *layer, u8 physlinkid) +{ + struct cfpkt *pkt; + struct cfctrl *cfctrl = container_obj(layer); + struct cflayer *dn = cfctrl->serv.layer.dn; + + if (!dn) { + pr_debug("not able to send enum request\n"); + return; + } + pkt = cfpkt_create(CFPKT_CTRL_PKT_LEN); + if (!pkt) + return; + caif_assert(offsetof(struct cfctrl, serv.layer) == 0); + init_info(cfpkt_info(pkt), cfctrl); + cfpkt_info(pkt)->dev_info->id = physlinkid; + cfctrl->serv.dev_info.id = physlinkid; + cfpkt_addbdy(pkt, CFCTRL_CMD_ENUM); + cfpkt_addbdy(pkt, physlinkid); + cfpkt_set_prio(pkt, TC_PRIO_CONTROL); + dn->transmit(dn, pkt); +} + +int cfctrl_linkup_request(struct cflayer *layer, + struct cfctrl_link_param *param, + struct cflayer *user_layer) +{ + struct cfctrl *cfctrl = container_obj(layer); + u32 tmp32; + u16 tmp16; + u8 tmp8; + struct cfctrl_request_info *req; + int ret; + char utility_name[16]; + struct cfpkt *pkt; + struct cflayer *dn = cfctrl->serv.layer.dn; + + if (!dn) { + pr_debug("not able to send linkup request\n"); + return -ENODEV; + } + + if (cfctrl_cancel_req(layer, user_layer) > 0) { + /* Slight Paranoia, check if already connecting */ + pr_err("Duplicate connect request for same client\n"); + WARN_ON(1); + return -EALREADY; + } + + pkt = cfpkt_create(CFPKT_CTRL_PKT_LEN); + if (!pkt) + return -ENOMEM; + cfpkt_addbdy(pkt, CFCTRL_CMD_LINK_SETUP); + cfpkt_addbdy(pkt, (param->chtype << 4) | param->linktype); + cfpkt_addbdy(pkt, (param->priority << 3) | param->phyid); + cfpkt_addbdy(pkt, param->endpoint & 0x03); + + switch (param->linktype) { + case CFCTRL_SRV_VEI: + break; + case CFCTRL_SRV_VIDEO: + cfpkt_addbdy(pkt, (u8) param->u.video.connid); + break; + case CFCTRL_SRV_DBG: + break; + case CFCTRL_SRV_DATAGRAM: + tmp32 = cpu_to_le32(param->u.datagram.connid); + cfpkt_add_body(pkt, &tmp32, 4); + break; + case CFCTRL_SRV_RFM: + /* Construct a frame, convert DatagramConnectionID to network + * format long and copy it out... 
+ */ + tmp32 = cpu_to_le32(param->u.rfm.connid); + cfpkt_add_body(pkt, &tmp32, 4); + /* Add volume name, including zero termination... */ + cfpkt_add_body(pkt, param->u.rfm.volume, + strlen(param->u.rfm.volume) + 1); + break; + case CFCTRL_SRV_UTIL: + tmp16 = cpu_to_le16(param->u.utility.fifosize_kb); + cfpkt_add_body(pkt, &tmp16, 2); + tmp16 = cpu_to_le16(param->u.utility.fifosize_bufs); + cfpkt_add_body(pkt, &tmp16, 2); + memset(utility_name, 0, sizeof(utility_name)); + strscpy(utility_name, param->u.utility.name, + UTILITY_NAME_LENGTH); + cfpkt_add_body(pkt, utility_name, UTILITY_NAME_LENGTH); + tmp8 = param->u.utility.paramlen; + cfpkt_add_body(pkt, &tmp8, 1); + cfpkt_add_body(pkt, param->u.utility.params, + param->u.utility.paramlen); + break; + default: + pr_warn("Request setup of bad link type = %d\n", + param->linktype); + cfpkt_destroy(pkt); + return -EINVAL; + } + req = kzalloc(sizeof(*req), GFP_KERNEL); + if (!req) { + cfpkt_destroy(pkt); + return -ENOMEM; + } + + req->client_layer = user_layer; + req->cmd = CFCTRL_CMD_LINK_SETUP; + req->param = *param; + cfctrl_insert_req(cfctrl, req); + init_info(cfpkt_info(pkt), cfctrl); + /* + * NOTE:Always send linkup and linkdown request on the same + * device as the payload. Otherwise old queued up payload + * might arrive with the newly allocated channel ID. + */ + cfpkt_info(pkt)->dev_info->id = param->phyid; + cfpkt_set_prio(pkt, TC_PRIO_CONTROL); + ret = + dn->transmit(dn, pkt); + if (ret < 0) { + int count; + + count = cfctrl_cancel_req(&cfctrl->serv.layer, + user_layer); + if (count != 1) { + pr_err("Could not remove request (%d)", count); + return -ENODEV; + } + } + return 0; +} + +int cfctrl_linkdown_req(struct cflayer *layer, u8 channelid, + struct cflayer *client) +{ + int ret; + struct cfpkt *pkt; + struct cfctrl *cfctrl = container_obj(layer); + struct cflayer *dn = cfctrl->serv.layer.dn; + + if (!dn) { + pr_debug("not able to send link-down request\n"); + return -ENODEV; + } + pkt = cfpkt_create(CFPKT_CTRL_PKT_LEN); + if (!pkt) + return -ENOMEM; + cfpkt_addbdy(pkt, CFCTRL_CMD_LINK_DESTROY); + cfpkt_addbdy(pkt, channelid); + init_info(cfpkt_info(pkt), cfctrl); + cfpkt_set_prio(pkt, TC_PRIO_CONTROL); + ret = + dn->transmit(dn, pkt); +#ifndef CAIF_NO_LOOP + cfctrl->loop_linkused[channelid] = 0; +#endif + return ret; +} + +int cfctrl_cancel_req(struct cflayer *layr, struct cflayer *adap_layer) +{ + struct cfctrl_request_info *p, *tmp; + struct cfctrl *ctrl = container_obj(layr); + int found = 0; + spin_lock_bh(&ctrl->info_list_lock); + + list_for_each_entry_safe(p, tmp, &ctrl->list, list) { + if (p->client_layer == adap_layer) { + list_del(&p->list); + kfree(p); + found++; + } + } + + spin_unlock_bh(&ctrl->info_list_lock); + return found; +} + +static int cfctrl_recv(struct cflayer *layer, struct cfpkt *pkt) +{ + u8 cmdrsp; + u8 cmd; + int ret = -1; + u8 len; + u8 param[255]; + u8 linkid = 0; + struct cfctrl *cfctrl = container_obj(layer); + struct cfctrl_request_info rsp, *req; + + + cmdrsp = cfpkt_extr_head_u8(pkt); + cmd = cmdrsp & CFCTRL_CMD_MASK; + if (cmd != CFCTRL_CMD_LINK_ERR + && CFCTRL_RSP_BIT != (CFCTRL_RSP_BIT & cmdrsp) + && CFCTRL_ERR_BIT != (CFCTRL_ERR_BIT & cmdrsp)) { + if (handle_loop(cfctrl, cmd, pkt) != 0) + cmdrsp |= CFCTRL_ERR_BIT; + } + + switch (cmd) { + case CFCTRL_CMD_LINK_SETUP: + { + enum cfctrl_srv serv; + enum cfctrl_srv servtype; + u8 endpoint; + u8 physlinkid; + u8 prio; + u8 tmp; + u8 *cp; + int i; + struct cfctrl_link_param linkparam; + memset(&linkparam, 0, sizeof(linkparam)); + + tmp = 
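+			/*
+			 * The wire layout parsed here mirrors what
+			 * cfctrl_linkup_request() emits:
+			 *
+			 *   byte 0: command (RSP/ERR bits set in replies)
+			 *   byte 1: channel type << 4 | service type
+			 *   byte 2: priority << 3     | physical link id
+			 *   byte 3: endpoint (2 bits)
+			 *   then  : service specific part, e.g. a 4-byte
+			 *           little-endian connection id for
+			 *           DATAGRAM, or fifo sizes, name and
+			 *           parameters for UTIL
+			 *   last  : the assigned link id, present only
+			 *           in successful responses
+			 */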
cfpkt_extr_head_u8(pkt); + + serv = tmp & CFCTRL_SRV_MASK; + linkparam.linktype = serv; + + servtype = tmp >> 4; + linkparam.chtype = servtype; + + tmp = cfpkt_extr_head_u8(pkt); + physlinkid = tmp & 0x07; + prio = tmp >> 3; + + linkparam.priority = prio; + linkparam.phyid = physlinkid; + endpoint = cfpkt_extr_head_u8(pkt); + linkparam.endpoint = endpoint & 0x03; + + switch (serv) { + case CFCTRL_SRV_VEI: + case CFCTRL_SRV_DBG: + if (CFCTRL_ERR_BIT & cmdrsp) + break; + /* Link ID */ + linkid = cfpkt_extr_head_u8(pkt); + break; + case CFCTRL_SRV_VIDEO: + tmp = cfpkt_extr_head_u8(pkt); + linkparam.u.video.connid = tmp; + if (CFCTRL_ERR_BIT & cmdrsp) + break; + /* Link ID */ + linkid = cfpkt_extr_head_u8(pkt); + break; + + case CFCTRL_SRV_DATAGRAM: + linkparam.u.datagram.connid = + cfpkt_extr_head_u32(pkt); + if (CFCTRL_ERR_BIT & cmdrsp) + break; + /* Link ID */ + linkid = cfpkt_extr_head_u8(pkt); + break; + case CFCTRL_SRV_RFM: + /* Construct a frame, convert + * DatagramConnectionID + * to network format long and copy it out... + */ + linkparam.u.rfm.connid = + cfpkt_extr_head_u32(pkt); + cp = (u8 *) linkparam.u.rfm.volume; + for (tmp = cfpkt_extr_head_u8(pkt); + cfpkt_more(pkt) && tmp != '\0'; + tmp = cfpkt_extr_head_u8(pkt)) + *cp++ = tmp; + *cp = '\0'; + + if (CFCTRL_ERR_BIT & cmdrsp) + break; + /* Link ID */ + linkid = cfpkt_extr_head_u8(pkt); + + break; + case CFCTRL_SRV_UTIL: + /* Construct a frame, convert + * DatagramConnectionID + * to network format long and copy it out... + */ + /* Fifosize KB */ + linkparam.u.utility.fifosize_kb = + cfpkt_extr_head_u16(pkt); + /* Fifosize bufs */ + linkparam.u.utility.fifosize_bufs = + cfpkt_extr_head_u16(pkt); + /* name */ + cp = (u8 *) linkparam.u.utility.name; + caif_assert(sizeof(linkparam.u.utility.name) + >= UTILITY_NAME_LENGTH); + for (i = 0; + i < UTILITY_NAME_LENGTH + && cfpkt_more(pkt); i++) { + tmp = cfpkt_extr_head_u8(pkt); + *cp++ = tmp; + } + /* Length */ + len = cfpkt_extr_head_u8(pkt); + linkparam.u.utility.paramlen = len; + /* Param Data */ + cp = linkparam.u.utility.params; + while (cfpkt_more(pkt) && len--) { + tmp = cfpkt_extr_head_u8(pkt); + *cp++ = tmp; + } + if (CFCTRL_ERR_BIT & cmdrsp) + break; + /* Link ID */ + linkid = cfpkt_extr_head_u8(pkt); + /* Length */ + len = cfpkt_extr_head_u8(pkt); + /* Param Data */ + cfpkt_extr_head(pkt, ¶m, len); + break; + default: + pr_warn("Request setup, invalid type (%d)\n", + serv); + goto error; + } + + rsp.cmd = cmd; + rsp.param = linkparam; + spin_lock_bh(&cfctrl->info_list_lock); + req = cfctrl_remove_req(cfctrl, &rsp); + + if (CFCTRL_ERR_BIT == (CFCTRL_ERR_BIT & cmdrsp) || + cfpkt_erroneous(pkt)) { + pr_err("Invalid O/E bit or parse error " + "on CAIF control channel\n"); + cfctrl->res.reject_rsp(cfctrl->serv.layer.up, + 0, + req ? req->client_layer + : NULL); + } else { + cfctrl->res.linksetup_rsp(cfctrl->serv. + layer.up, linkid, + serv, physlinkid, + req ? 
req-> + client_layer : NULL); + } + + kfree(req); + + spin_unlock_bh(&cfctrl->info_list_lock); + } + break; + case CFCTRL_CMD_LINK_DESTROY: + linkid = cfpkt_extr_head_u8(pkt); + cfctrl->res.linkdestroy_rsp(cfctrl->serv.layer.up, linkid); + break; + case CFCTRL_CMD_LINK_ERR: + pr_err("Frame Error Indication received\n"); + cfctrl->res.linkerror_ind(); + break; + case CFCTRL_CMD_ENUM: + cfctrl->res.enum_rsp(); + break; + case CFCTRL_CMD_SLEEP: + cfctrl->res.sleep_rsp(); + break; + case CFCTRL_CMD_WAKE: + cfctrl->res.wake_rsp(); + break; + case CFCTRL_CMD_LINK_RECONF: + cfctrl->res.restart_rsp(); + break; + case CFCTRL_CMD_RADIO_SET: + cfctrl->res.radioset_rsp(); + break; + default: + pr_err("Unrecognized Control Frame\n"); + goto error; + } + ret = 0; +error: + cfpkt_destroy(pkt); + return ret; +} + +static void cfctrl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + struct cfctrl *this = container_obj(layr); + switch (ctrl) { + case _CAIF_CTRLCMD_PHYIF_FLOW_OFF_IND: + case CAIF_CTRLCMD_FLOW_OFF_IND: + spin_lock_bh(&this->info_list_lock); + if (!list_empty(&this->list)) + pr_debug("Received flow off in control layer\n"); + spin_unlock_bh(&this->info_list_lock); + break; + case _CAIF_CTRLCMD_PHYIF_DOWN_IND: { + struct cfctrl_request_info *p, *tmp; + + /* Find all connect request and report failure */ + spin_lock_bh(&this->info_list_lock); + list_for_each_entry_safe(p, tmp, &this->list, list) { + if (p->param.phyid == phyid) { + list_del(&p->list); + p->client_layer->ctrlcmd(p->client_layer, + CAIF_CTRLCMD_INIT_FAIL_RSP, + phyid); + kfree(p); + } + } + spin_unlock_bh(&this->info_list_lock); + break; + } + default: + break; + } +} + +#ifndef CAIF_NO_LOOP +static int handle_loop(struct cfctrl *ctrl, int cmd, struct cfpkt *pkt) +{ + static int last_linkid; + static int dec; + u8 linkid, linktype, tmp; + switch (cmd) { + case CFCTRL_CMD_LINK_SETUP: + spin_lock_bh(&ctrl->loop_linkid_lock); + if (!dec) { + for (linkid = last_linkid + 1; linkid < 254; linkid++) + if (!ctrl->loop_linkused[linkid]) + goto found; + } + dec = 1; + for (linkid = last_linkid - 1; linkid > 1; linkid--) + if (!ctrl->loop_linkused[linkid]) + goto found; + spin_unlock_bh(&ctrl->loop_linkid_lock); + return -1; +found: + if (linkid < 10) + dec = 0; + + if (!ctrl->loop_linkused[linkid]) + ctrl->loop_linkused[linkid] = 1; + + last_linkid = linkid; + + cfpkt_add_trail(pkt, &linkid, 1); + spin_unlock_bh(&ctrl->loop_linkid_lock); + cfpkt_peek_head(pkt, &linktype, 1); + if (linktype == CFCTRL_SRV_UTIL) { + tmp = 0x01; + cfpkt_add_trail(pkt, &tmp, 1); + cfpkt_add_trail(pkt, &tmp, 1); + } + break; + + case CFCTRL_CMD_LINK_DESTROY: + spin_lock_bh(&ctrl->loop_linkid_lock); + cfpkt_peek_head(pkt, &linkid, 1); + ctrl->loop_linkused[linkid] = 0; + spin_unlock_bh(&ctrl->loop_linkid_lock); + break; + default: + break; + } + return 0; +} +#endif diff --git a/net/caif/cfdbgl.c b/net/caif/cfdbgl.c new file mode 100644 index 000000000..77f428428 --- /dev/null +++ b/net/caif/cfdbgl.c @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include + +#define container_obj(layr) ((struct cfsrvl *) layr) + +static int cfdbgl_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfdbgl_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cflayer *cfdbgl_create(u8 channel_id, struct dev_info *dev_info) +{ + struct cfsrvl *dbg = 
kzalloc(sizeof(struct cfsrvl), GFP_ATOMIC); + if (!dbg) + return NULL; + caif_assert(offsetof(struct cfsrvl, layer) == 0); + cfsrvl_init(dbg, channel_id, dev_info, false); + dbg->layer.receive = cfdbgl_receive; + dbg->layer.transmit = cfdbgl_transmit; + snprintf(dbg->layer.name, CAIF_LAYER_NAME_SZ, "dbg%d", channel_id); + return &dbg->layer; +} + +static int cfdbgl_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + return layr->up->receive(layr->up, pkt); +} + +static int cfdbgl_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + struct cfsrvl *service = container_obj(layr); + struct caif_payload_info *info; + int ret; + + if (!cfsrvl_ready(service, &ret)) { + cfpkt_destroy(pkt); + return ret; + } + + /* Add info for MUX-layer to route the packet out */ + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + info->dev_info = &service->dev_info; + + return layr->dn->transmit(layr->dn, pkt); +} diff --git a/net/caif/cfdgml.c b/net/caif/cfdgml.c new file mode 100644 index 000000000..eb6f8ef47 --- /dev/null +++ b/net/caif/cfdgml.c @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include + + +#define container_obj(layr) ((struct cfsrvl *) layr) + +#define DGM_CMD_BIT 0x80 +#define DGM_FLOW_OFF 0x81 +#define DGM_FLOW_ON 0x80 +#define DGM_MTU 1500 + +static int cfdgml_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfdgml_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cflayer *cfdgml_create(u8 channel_id, struct dev_info *dev_info) +{ + struct cfsrvl *dgm = kzalloc(sizeof(struct cfsrvl), GFP_ATOMIC); + if (!dgm) + return NULL; + caif_assert(offsetof(struct cfsrvl, layer) == 0); + cfsrvl_init(dgm, channel_id, dev_info, true); + dgm->layer.receive = cfdgml_receive; + dgm->layer.transmit = cfdgml_transmit; + snprintf(dgm->layer.name, CAIF_LAYER_NAME_SZ, "dgm%d", channel_id); + return &dgm->layer; +} + +static int cfdgml_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 cmd = -1; + u8 dgmhdr[3]; + int ret; + caif_assert(layr->up != NULL); + caif_assert(layr->receive != NULL); + caif_assert(layr->ctrlcmd != NULL); + + if (cfpkt_extr_head(pkt, &cmd, 1) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + + if ((cmd & DGM_CMD_BIT) == 0) { + if (cfpkt_extr_head(pkt, &dgmhdr, 3) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + ret = layr->up->receive(layr->up, pkt); + return ret; + } + + switch (cmd) { + case DGM_FLOW_OFF: /* FLOW OFF */ + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_OFF_IND, 0); + cfpkt_destroy(pkt); + return 0; + case DGM_FLOW_ON: /* FLOW ON */ + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_ON_IND, 0); + cfpkt_destroy(pkt); + return 0; + default: + cfpkt_destroy(pkt); + pr_info("Unknown datagram control %d (0x%x)\n", cmd, cmd); + return -EPROTO; + } +} + +static int cfdgml_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 packet_type; + u32 zero = 0; + struct caif_payload_info *info; + struct cfsrvl *service = container_obj(layr); + int ret; + + if (!cfsrvl_ready(service, &ret)) { + cfpkt_destroy(pkt); + return ret; + } + + /* STE Modem cannot handle more than 1500 bytes datagrams */ + if (cfpkt_getlen(pkt) > DGM_MTU) { + cfpkt_destroy(pkt); + return -EMSGSIZE; + } + + cfpkt_add_head(pkt, &zero, 3); + packet_type = 0x08; /* B9 set - UNCLASSIFIED */ + 
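+	/*
+	 * Header built at this point (and stripped again in
+	 * cfdgml_receive()): three zero bytes were just prepended and
+	 * the packet type byte goes in front of them, so the datagram
+	 * leaves this layer as
+	 *
+	 *	[0x08][0x00][0x00][0x00][payload ...]
+	 *
+	 * which is the 4-byte header accounted for in info->hdr_len
+	 * below.
+	 */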
cfpkt_add_head(pkt, &packet_type, 1); + + /* Add info for MUX-layer to route the packet out. */ + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + /* To optimize alignment, we add up the size of CAIF header + * before payload. + */ + info->hdr_len = 4; + info->dev_info = &service->dev_info; + return layr->dn->transmit(layr->dn, pkt); +} diff --git a/net/caif/cffrml.c b/net/caif/cffrml.c new file mode 100644 index 000000000..6651a8dc6 --- /dev/null +++ b/net/caif/cffrml.c @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * CAIF Framing Layer. + * + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) container_of(layr, struct cffrml, layer) + +struct cffrml { + struct cflayer layer; + bool dofcs; /* !< FCS active */ + int __percpu *pcpu_refcnt; +}; + +static int cffrml_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cffrml_transmit(struct cflayer *layr, struct cfpkt *pkt); +static void cffrml_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid); + +static u32 cffrml_rcv_error; +static u32 cffrml_rcv_checsum_error; +struct cflayer *cffrml_create(u16 phyid, bool use_fcs) +{ + struct cffrml *this = kzalloc(sizeof(struct cffrml), GFP_ATOMIC); + if (!this) + return NULL; + this->pcpu_refcnt = alloc_percpu(int); + if (this->pcpu_refcnt == NULL) { + kfree(this); + return NULL; + } + + caif_assert(offsetof(struct cffrml, layer) == 0); + + this->layer.receive = cffrml_receive; + this->layer.transmit = cffrml_transmit; + this->layer.ctrlcmd = cffrml_ctrlcmd; + snprintf(this->layer.name, CAIF_LAYER_NAME_SZ, "frm%d", phyid); + this->dofcs = use_fcs; + this->layer.id = phyid; + return (struct cflayer *) this; +} + +void cffrml_free(struct cflayer *layer) +{ + struct cffrml *this = container_obj(layer); + free_percpu(this->pcpu_refcnt); + kfree(layer); +} + +void cffrml_set_uplayer(struct cflayer *this, struct cflayer *up) +{ + this->up = up; +} + +void cffrml_set_dnlayer(struct cflayer *this, struct cflayer *dn) +{ + this->dn = dn; +} + +static u16 cffrml_checksum(u16 chks, void *buf, u16 len) +{ + /* FIXME: FCS should be moved to glue in order to use OS-Specific + * solutions + */ + return crc_ccitt(chks, buf, len); +} + +static int cffrml_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u16 tmp; + u16 len; + u16 hdrchks; + int pktchks; + struct cffrml *this; + this = container_obj(layr); + + cfpkt_extr_head(pkt, &tmp, 2); + len = le16_to_cpu(tmp); + + /* Subtract for FCS on length if FCS is not used. */ + if (!this->dofcs) + len -= 2; + + if (cfpkt_setlen(pkt, len) < 0) { + ++cffrml_rcv_error; + pr_err("Framing length error (%d)\n", len); + cfpkt_destroy(pkt); + return -EPROTO; + } + /* + * Don't do extract if FCS is false, rather do setlen - then we don't + * get a cache-miss. 
+ */ + if (this->dofcs) { + cfpkt_extr_trail(pkt, &tmp, 2); + hdrchks = le16_to_cpu(tmp); + pktchks = cfpkt_iterate(pkt, cffrml_checksum, 0xffff); + if (pktchks != hdrchks) { + cfpkt_add_trail(pkt, &tmp, 2); + ++cffrml_rcv_error; + ++cffrml_rcv_checsum_error; + pr_info("Frame checksum error (0x%x != 0x%x)\n", + hdrchks, pktchks); + return -EILSEQ; + } + } + if (cfpkt_erroneous(pkt)) { + ++cffrml_rcv_error; + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + + if (layr->up == NULL) { + pr_err("Layr up is missing!\n"); + cfpkt_destroy(pkt); + return -EINVAL; + } + + return layr->up->receive(layr->up, pkt); +} + +static int cffrml_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + u16 chks; + u16 len; + __le16 data; + + struct cffrml *this = container_obj(layr); + if (this->dofcs) { + chks = cfpkt_iterate(pkt, cffrml_checksum, 0xffff); + data = cpu_to_le16(chks); + cfpkt_add_trail(pkt, &data, 2); + } else { + cfpkt_pad_trail(pkt, 2); + } + len = cfpkt_getlen(pkt); + data = cpu_to_le16(len); + cfpkt_add_head(pkt, &data, 2); + cfpkt_info(pkt)->hdr_len += 2; + if (cfpkt_erroneous(pkt)) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + + if (layr->dn == NULL) { + cfpkt_destroy(pkt); + return -ENODEV; + + } + return layr->dn->transmit(layr->dn, pkt); +} + +static void cffrml_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + if (layr->up && layr->up->ctrlcmd) + layr->up->ctrlcmd(layr->up, ctrl, layr->id); +} + +void cffrml_put(struct cflayer *layr) +{ + struct cffrml *this = container_obj(layr); + if (layr != NULL && this->pcpu_refcnt != NULL) + this_cpu_dec(*this->pcpu_refcnt); +} + +void cffrml_hold(struct cflayer *layr) +{ + struct cffrml *this = container_obj(layr); + if (layr != NULL && this->pcpu_refcnt != NULL) + this_cpu_inc(*this->pcpu_refcnt); +} + +int cffrml_refcnt_read(struct cflayer *layr) +{ + int i, refcnt = 0; + struct cffrml *this = container_obj(layr); + for_each_possible_cpu(i) + refcnt += *per_cpu_ptr(this->pcpu_refcnt, i); + return refcnt; +} diff --git a/net/caif/cfmuxl.c b/net/caif/cfmuxl.c new file mode 100644 index 000000000..4172b0d0d --- /dev/null +++ b/net/caif/cfmuxl.c @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) container_of(layr, struct cfmuxl, layer) + +#define CAIF_CTRL_CHANNEL 0 +#define UP_CACHE_SIZE 8 +#define DN_CACHE_SIZE 8 + +struct cfmuxl { + struct cflayer layer; + struct list_head srvl_list; + struct list_head frml_list; + struct cflayer *up_cache[UP_CACHE_SIZE]; + struct cflayer *dn_cache[DN_CACHE_SIZE]; + /* + * Set when inserting or removing downwards layers. + */ + spinlock_t transmit_lock; + + /* + * Set when inserting or removing upwards layers. 
+ */ + spinlock_t receive_lock; + +}; + +static int cfmuxl_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfmuxl_transmit(struct cflayer *layr, struct cfpkt *pkt); +static void cfmuxl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid); +static struct cflayer *get_up(struct cfmuxl *muxl, u16 id); + +struct cflayer *cfmuxl_create(void) +{ + struct cfmuxl *this = kzalloc(sizeof(struct cfmuxl), GFP_ATOMIC); + + if (!this) + return NULL; + this->layer.receive = cfmuxl_receive; + this->layer.transmit = cfmuxl_transmit; + this->layer.ctrlcmd = cfmuxl_ctrlcmd; + INIT_LIST_HEAD(&this->srvl_list); + INIT_LIST_HEAD(&this->frml_list); + spin_lock_init(&this->transmit_lock); + spin_lock_init(&this->receive_lock); + snprintf(this->layer.name, CAIF_LAYER_NAME_SZ, "mux"); + return &this->layer; +} + +int cfmuxl_set_dnlayer(struct cflayer *layr, struct cflayer *dn, u8 phyid) +{ + struct cfmuxl *muxl = (struct cfmuxl *) layr; + + spin_lock_bh(&muxl->transmit_lock); + list_add_rcu(&dn->node, &muxl->frml_list); + spin_unlock_bh(&muxl->transmit_lock); + return 0; +} + +static struct cflayer *get_from_id(struct list_head *list, u16 id) +{ + struct cflayer *lyr; + list_for_each_entry_rcu(lyr, list, node) { + if (lyr->id == id) + return lyr; + } + + return NULL; +} + +int cfmuxl_set_uplayer(struct cflayer *layr, struct cflayer *up, u8 linkid) +{ + struct cfmuxl *muxl = container_obj(layr); + struct cflayer *old; + + spin_lock_bh(&muxl->receive_lock); + + /* Two entries with same id is wrong, so remove old layer from mux */ + old = get_from_id(&muxl->srvl_list, linkid); + if (old != NULL) + list_del_rcu(&old->node); + + list_add_rcu(&up->node, &muxl->srvl_list); + spin_unlock_bh(&muxl->receive_lock); + + return 0; +} + +struct cflayer *cfmuxl_remove_dnlayer(struct cflayer *layr, u8 phyid) +{ + struct cfmuxl *muxl = container_obj(layr); + struct cflayer *dn; + int idx = phyid % DN_CACHE_SIZE; + + spin_lock_bh(&muxl->transmit_lock); + RCU_INIT_POINTER(muxl->dn_cache[idx], NULL); + dn = get_from_id(&muxl->frml_list, phyid); + if (dn == NULL) + goto out; + + list_del_rcu(&dn->node); + caif_assert(dn != NULL); +out: + spin_unlock_bh(&muxl->transmit_lock); + return dn; +} + +static struct cflayer *get_up(struct cfmuxl *muxl, u16 id) +{ + struct cflayer *up; + int idx = id % UP_CACHE_SIZE; + up = rcu_dereference(muxl->up_cache[idx]); + if (up == NULL || up->id != id) { + spin_lock_bh(&muxl->receive_lock); + up = get_from_id(&muxl->srvl_list, id); + rcu_assign_pointer(muxl->up_cache[idx], up); + spin_unlock_bh(&muxl->receive_lock); + } + return up; +} + +static struct cflayer *get_dn(struct cfmuxl *muxl, struct dev_info *dev_info) +{ + struct cflayer *dn; + int idx = dev_info->id % DN_CACHE_SIZE; + dn = rcu_dereference(muxl->dn_cache[idx]); + if (dn == NULL || dn->id != dev_info->id) { + spin_lock_bh(&muxl->transmit_lock); + dn = get_from_id(&muxl->frml_list, dev_info->id); + rcu_assign_pointer(muxl->dn_cache[idx], dn); + spin_unlock_bh(&muxl->transmit_lock); + } + return dn; +} + +struct cflayer *cfmuxl_remove_uplayer(struct cflayer *layr, u8 id) +{ + struct cflayer *up; + struct cfmuxl *muxl = container_obj(layr); + int idx = id % UP_CACHE_SIZE; + + if (id == 0) { + pr_warn("Trying to remove control layer\n"); + return NULL; + } + + spin_lock_bh(&muxl->receive_lock); + up = get_from_id(&muxl->srvl_list, id); + if (up == NULL) + goto out; + + RCU_INIT_POINTER(muxl->up_cache[idx], NULL); + list_del_rcu(&up->node); +out: + spin_unlock_bh(&muxl->receive_lock); + return up; +} + +static int 
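+/*
+ * Routing model implemented by the two handlers below: on receive the
+ * first byte of every frame is the CAIF channel id and selects an
+ * upper (service) layer from srvl_list, via the up_cache indexed by
+ * id % UP_CACHE_SIZE; on transmit the physical link id found in the
+ * packet's payload info selects a framing layer from frml_list via
+ * dn_cache.  Both lists are RCU protected, so lookups are lock-free,
+ * and the caches are cleared with RCU_INIT_POINTER() when a layer is
+ * removed.
+ */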
cfmuxl_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + int ret; + struct cfmuxl *muxl = container_obj(layr); + u8 id; + struct cflayer *up; + if (cfpkt_extr_head(pkt, &id, 1) < 0) { + pr_err("erroneous Caif Packet\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + rcu_read_lock(); + up = get_up(muxl, id); + + if (up == NULL) { + pr_debug("Received data on unknown link ID = %d (0x%x)" + " up == NULL", id, id); + cfpkt_destroy(pkt); + /* + * Don't return ERROR, since modem misbehaves and sends out + * flow on before linksetup response. + */ + + rcu_read_unlock(); + return /* CFGLU_EPROT; */ 0; + } + + /* We can't hold rcu_lock during receive, so take a ref count instead */ + cfsrvl_get(up); + rcu_read_unlock(); + + ret = up->receive(up, pkt); + + cfsrvl_put(up); + return ret; +} + +static int cfmuxl_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + struct cfmuxl *muxl = container_obj(layr); + int err; + u8 linkid; + struct cflayer *dn; + struct caif_payload_info *info = cfpkt_info(pkt); + BUG_ON(!info); + + rcu_read_lock(); + + dn = get_dn(muxl, info->dev_info); + if (dn == NULL) { + pr_debug("Send data on unknown phy ID = %d (0x%x)\n", + info->dev_info->id, info->dev_info->id); + rcu_read_unlock(); + cfpkt_destroy(pkt); + return -ENOTCONN; + } + + info->hdr_len += 1; + linkid = info->channel_id; + cfpkt_add_head(pkt, &linkid, 1); + + /* We can't hold rcu_lock during receive, so take a ref count instead */ + cffrml_hold(dn); + + rcu_read_unlock(); + + err = dn->transmit(dn, pkt); + + cffrml_put(dn); + return err; +} + +static void cfmuxl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + struct cfmuxl *muxl = container_obj(layr); + struct cflayer *layer; + + rcu_read_lock(); + list_for_each_entry_rcu(layer, &muxl->srvl_list, node) { + + if (cfsrvl_phyid_match(layer, phyid) && layer->ctrlcmd) { + + if ((ctrl == _CAIF_CTRLCMD_PHYIF_DOWN_IND || + ctrl == CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND) && + layer->id != 0) + cfmuxl_remove_uplayer(layr, layer->id); + + /* NOTE: ctrlcmd is not allowed to block */ + layer->ctrlcmd(layer, ctrl, phyid); + } + } + rcu_read_unlock(); +} diff --git a/net/caif/cfpkt_skbuff.c b/net/caif/cfpkt_skbuff.c new file mode 100644 index 000000000..7796414d4 --- /dev/null +++ b/net/caif/cfpkt_skbuff.c @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include + +#define PKT_PREFIX 48 +#define PKT_POSTFIX 2 +#define PKT_LEN_WHEN_EXTENDING 128 +#define PKT_ERROR(pkt, errmsg) \ +do { \ + cfpkt_priv(pkt)->erronous = true; \ + skb_reset_tail_pointer(&pkt->skb); \ + pr_warn(errmsg); \ +} while (0) + +struct cfpktq { + struct sk_buff_head head; + atomic_t count; + /* Lock protects count updates */ + spinlock_t lock; +}; + +/* + * net/caif/ is generic and does not + * understand SKB, so we do this typecast + */ +struct cfpkt { + struct sk_buff skb; +}; + +/* Private data inside SKB */ +struct cfpkt_priv_data { + struct dev_info dev_info; + bool erronous; +}; + +static inline struct cfpkt_priv_data *cfpkt_priv(struct cfpkt *pkt) +{ + return (struct cfpkt_priv_data *) pkt->skb.cb; +} + +static inline bool is_erronous(struct cfpkt *pkt) +{ + return cfpkt_priv(pkt)->erronous; +} + +static inline struct sk_buff *pkt_to_skb(struct cfpkt *pkt) +{ + return &pkt->skb; +} + +static inline struct cfpkt *skb_to_pkt(struct sk_buff *skb) +{ + return (struct cfpkt *) skb; +} + +struct cfpkt 
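+/*
+ * Minimal usage sketch of the cfpkt API defined in this file
+ * (illustrative only):
+ *
+ *	struct cfpkt *pkt = cfpkt_create(len);
+ *					allocates len + PKT_POSTFIX bytes
+ *					and reserves PKT_PREFIX (48) bytes
+ *					of headroom
+ *	cfpkt_add_body(pkt, payload, payload_len);
+ *					append payload at the tail
+ *	cfpkt_addbdy(pkt, cmd_byte);	append a single byte
+ *	cfpkt_add_head(pkt, &hdr, 1);	push a header byte in front,
+ *					using the reserved headroom
+ *	cfpkt_destroy(pkt);		frees the underlying skb
+ */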
*cfpkt_fromnative(enum caif_direction dir, void *nativepkt) +{ + struct cfpkt *pkt = skb_to_pkt(nativepkt); + cfpkt_priv(pkt)->erronous = false; + return pkt; +} +EXPORT_SYMBOL(cfpkt_fromnative); + +void *cfpkt_tonative(struct cfpkt *pkt) +{ + return (void *) pkt; +} +EXPORT_SYMBOL(cfpkt_tonative); + +static struct cfpkt *cfpkt_create_pfx(u16 len, u16 pfx) +{ + struct sk_buff *skb; + + skb = alloc_skb(len + pfx, GFP_ATOMIC); + if (unlikely(skb == NULL)) + return NULL; + + skb_reserve(skb, pfx); + return skb_to_pkt(skb); +} + +inline struct cfpkt *cfpkt_create(u16 len) +{ + return cfpkt_create_pfx(len + PKT_POSTFIX, PKT_PREFIX); +} + +void cfpkt_destroy(struct cfpkt *pkt) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + kfree_skb(skb); +} + +inline bool cfpkt_more(struct cfpkt *pkt) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + return skb->len > 0; +} + +int cfpkt_peek_head(struct cfpkt *pkt, void *data, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + if (skb_headlen(skb) >= len) { + memcpy(data, skb->data, len); + return 0; + } + return !cfpkt_extr_head(pkt, data, len) && + !cfpkt_add_head(pkt, data, len); +} + +int cfpkt_extr_head(struct cfpkt *pkt, void *data, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + u8 *from; + if (unlikely(is_erronous(pkt))) + return -EPROTO; + + if (unlikely(len > skb->len)) { + PKT_ERROR(pkt, "read beyond end of packet\n"); + return -EPROTO; + } + + if (unlikely(len > skb_headlen(skb))) { + if (unlikely(skb_linearize(skb) != 0)) { + PKT_ERROR(pkt, "linearize failed\n"); + return -EPROTO; + } + } + from = skb_pull(skb, len); + from -= len; + if (data) + memcpy(data, from, len); + return 0; +} +EXPORT_SYMBOL(cfpkt_extr_head); + +int cfpkt_extr_trail(struct cfpkt *pkt, void *dta, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + u8 *data = dta; + u8 *from; + if (unlikely(is_erronous(pkt))) + return -EPROTO; + + if (unlikely(skb_linearize(skb) != 0)) { + PKT_ERROR(pkt, "linearize failed\n"); + return -EPROTO; + } + if (unlikely(skb->data + len > skb_tail_pointer(skb))) { + PKT_ERROR(pkt, "read beyond end of packet\n"); + return -EPROTO; + } + from = skb_tail_pointer(skb) - len; + skb_trim(skb, skb->len - len); + memcpy(data, from, len); + return 0; +} + +int cfpkt_pad_trail(struct cfpkt *pkt, u16 len) +{ + return cfpkt_add_body(pkt, NULL, len); +} + +int cfpkt_add_body(struct cfpkt *pkt, const void *data, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + struct sk_buff *lastskb; + u8 *to; + u16 addlen = 0; + + + if (unlikely(is_erronous(pkt))) + return -EPROTO; + + lastskb = skb; + + /* Check whether we need to add space at the tail */ + if (unlikely(skb_tailroom(skb) < len)) { + if (likely(len < PKT_LEN_WHEN_EXTENDING)) + addlen = PKT_LEN_WHEN_EXTENDING; + else + addlen = len; + } + + /* Check whether we need to change the SKB before writing to the tail */ + if (unlikely((addlen > 0) || skb_cloned(skb) || skb_shared(skb))) { + + /* Make sure data is writable */ + if (unlikely(skb_cow_data(skb, addlen, &lastskb) < 0)) { + PKT_ERROR(pkt, "cow failed\n"); + return -EPROTO; + } + } + + /* All set to put the last SKB and optionally write data there. 
*/ + to = pskb_put(skb, lastskb, len); + if (likely(data)) + memcpy(to, data, len); + return 0; +} + +inline int cfpkt_addbdy(struct cfpkt *pkt, u8 data) +{ + return cfpkt_add_body(pkt, &data, 1); +} + +int cfpkt_add_head(struct cfpkt *pkt, const void *data2, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + struct sk_buff *lastskb; + u8 *to; + const u8 *data = data2; + int ret; + if (unlikely(is_erronous(pkt))) + return -EPROTO; + if (unlikely(skb_headroom(skb) < len)) { + PKT_ERROR(pkt, "no headroom\n"); + return -EPROTO; + } + + /* Make sure data is writable */ + ret = skb_cow_data(skb, 0, &lastskb); + if (unlikely(ret < 0)) { + PKT_ERROR(pkt, "cow failed\n"); + return ret; + } + + to = skb_push(skb, len); + memcpy(to, data, len); + return 0; +} +EXPORT_SYMBOL(cfpkt_add_head); + +inline int cfpkt_add_trail(struct cfpkt *pkt, const void *data, u16 len) +{ + return cfpkt_add_body(pkt, data, len); +} + +inline u16 cfpkt_getlen(struct cfpkt *pkt) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + return skb->len; +} + +int cfpkt_iterate(struct cfpkt *pkt, + u16 (*iter_func)(u16, void *, u16), + u16 data) +{ + /* + * Don't care about the performance hit of linearizing, + * Checksum should not be used on high-speed interfaces anyway. + */ + if (unlikely(is_erronous(pkt))) + return -EPROTO; + if (unlikely(skb_linearize(&pkt->skb) != 0)) { + PKT_ERROR(pkt, "linearize failed\n"); + return -EPROTO; + } + return iter_func(data, pkt->skb.data, cfpkt_getlen(pkt)); +} + +int cfpkt_setlen(struct cfpkt *pkt, u16 len) +{ + struct sk_buff *skb = pkt_to_skb(pkt); + + + if (unlikely(is_erronous(pkt))) + return -EPROTO; + + if (likely(len <= skb->len)) { + if (unlikely(skb->data_len)) + ___pskb_trim(skb, len); + else + skb_trim(skb, len); + + return cfpkt_getlen(pkt); + } + + /* Need to expand SKB */ + if (unlikely(!cfpkt_pad_trail(pkt, len - skb->len))) + PKT_ERROR(pkt, "skb_pad_trail failed\n"); + + return cfpkt_getlen(pkt); +} + +struct cfpkt *cfpkt_append(struct cfpkt *dstpkt, + struct cfpkt *addpkt, + u16 expectlen) +{ + struct sk_buff *dst = pkt_to_skb(dstpkt); + struct sk_buff *add = pkt_to_skb(addpkt); + u16 addlen = skb_headlen(add); + u16 neededtailspace; + struct sk_buff *tmp; + u16 dstlen; + u16 createlen; + if (unlikely(is_erronous(dstpkt) || is_erronous(addpkt))) { + return dstpkt; + } + if (expectlen > addlen) + neededtailspace = expectlen; + else + neededtailspace = addlen; + + if (dst->tail + neededtailspace > dst->end) { + /* Create a dumplicate of 'dst' with more tail space */ + struct cfpkt *tmppkt; + dstlen = skb_headlen(dst); + createlen = dstlen + neededtailspace; + tmppkt = cfpkt_create(createlen + PKT_PREFIX + PKT_POSTFIX); + if (tmppkt == NULL) + return NULL; + tmp = pkt_to_skb(tmppkt); + skb_put_data(tmp, dst->data, dstlen); + cfpkt_destroy(dstpkt); + dst = tmp; + } + skb_put_data(dst, add->data, skb_headlen(add)); + cfpkt_destroy(addpkt); + return skb_to_pkt(dst); +} + +struct cfpkt *cfpkt_split(struct cfpkt *pkt, u16 pos) +{ + struct sk_buff *skb2; + struct sk_buff *skb = pkt_to_skb(pkt); + struct cfpkt *tmppkt; + u8 *split = skb->data + pos; + u16 len2nd = skb_tail_pointer(skb) - split; + + if (unlikely(is_erronous(pkt))) + return NULL; + + if (skb->data + pos > skb_tail_pointer(skb)) { + PKT_ERROR(pkt, "trying to split beyond end of packet\n"); + return NULL; + } + + /* Create a new packet for the second part of the data */ + tmppkt = cfpkt_create_pfx(len2nd + PKT_PREFIX + PKT_POSTFIX, + PKT_PREFIX); + if (tmppkt == NULL) + return NULL; + skb2 = pkt_to_skb(tmppkt); + + + if 
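+	/*
+	 * Split semantics, as used by the serial and RFM layers: for a
+	 * packet holding bytes b0..b9, cfpkt_split(pkt, 4) trims pkt
+	 * down to b0..b3 and returns a new packet carrying b4..b9 with
+	 * the same priority.  Together with cfpkt_append() above, this
+	 * is what frame reassembly and segmentation are built from.
+	 */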
(skb2 == NULL) + return NULL; + + skb_put_data(skb2, split, len2nd); + + /* Reduce the length of the original packet */ + skb_trim(skb, pos); + + skb2->priority = skb->priority; + return skb_to_pkt(skb2); +} + +bool cfpkt_erroneous(struct cfpkt *pkt) +{ + return cfpkt_priv(pkt)->erronous; +} + +struct caif_payload_info *cfpkt_info(struct cfpkt *pkt) +{ + return (struct caif_payload_info *)&pkt_to_skb(pkt)->cb; +} +EXPORT_SYMBOL(cfpkt_info); + +void cfpkt_set_prio(struct cfpkt *pkt, int prio) +{ + pkt_to_skb(pkt)->priority = prio; +} +EXPORT_SYMBOL(cfpkt_set_prio); diff --git a/net/caif/cfrfml.c b/net/caif/cfrfml.c new file mode 100644 index 000000000..7b0af33bd --- /dev/null +++ b/net/caif/cfrfml.c @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) container_of(layr, struct cfrfml, serv.layer) +#define RFM_SEGMENTATION_BIT 0x01 +#define RFM_HEAD_SIZE 7 + +static int cfrfml_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfrfml_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cfrfml { + struct cfsrvl serv; + struct cfpkt *incomplete_frm; + int fragment_size; + u8 seghead[6]; + u16 pdu_size; + /* Protects serialized processing of packets */ + spinlock_t sync; +}; + +static void cfrfml_release(struct cflayer *layer) +{ + struct cfsrvl *srvl = container_of(layer, struct cfsrvl, layer); + struct cfrfml *rfml = container_obj(&srvl->layer); + + if (rfml->incomplete_frm) + cfpkt_destroy(rfml->incomplete_frm); + + kfree(srvl); +} + +struct cflayer *cfrfml_create(u8 channel_id, struct dev_info *dev_info, + int mtu_size) +{ + int tmp; + struct cfrfml *this = kzalloc(sizeof(struct cfrfml), GFP_ATOMIC); + + if (!this) + return NULL; + + cfsrvl_init(&this->serv, channel_id, dev_info, false); + this->serv.release = cfrfml_release; + this->serv.layer.receive = cfrfml_receive; + this->serv.layer.transmit = cfrfml_transmit; + + /* Round down to closest multiple of 16 */ + tmp = (mtu_size - RFM_HEAD_SIZE - 6) / 16; + tmp *= 16; + + this->fragment_size = tmp; + spin_lock_init(&this->sync); + snprintf(this->serv.layer.name, CAIF_LAYER_NAME_SZ, + "rfm%d", channel_id); + + return &this->serv.layer; +} + +static struct cfpkt *rfm_append(struct cfrfml *rfml, char *seghead, + struct cfpkt *pkt, int *err) +{ + struct cfpkt *tmppkt; + *err = -EPROTO; + /* n-th but not last segment */ + + if (cfpkt_extr_head(pkt, seghead, 6) < 0) + return NULL; + + /* Verify correct header */ + if (memcmp(seghead, rfml->seghead, 6) != 0) + return NULL; + + tmppkt = cfpkt_append(rfml->incomplete_frm, pkt, + rfml->pdu_size + RFM_HEAD_SIZE); + + /* If cfpkt_append failes input pkts are not freed */ + *err = -ENOMEM; + if (tmppkt == NULL) + return NULL; + + *err = 0; + return tmppkt; +} + +static int cfrfml_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 tmp; + bool segmented; + int err; + u8 seghead[6]; + struct cfrfml *rfml; + struct cfpkt *tmppkt = NULL; + + caif_assert(layr->up != NULL); + caif_assert(layr->receive != NULL); + rfml = container_obj(layr); + spin_lock(&rfml->sync); + + err = -EPROTO; + if (cfpkt_extr_head(pkt, &tmp, 1) < 0) + goto out; + segmented = tmp & RFM_SEGMENTATION_BIT; + + if (segmented) { + if (rfml->incomplete_frm == NULL) { + /* Initial Segment */ + if (cfpkt_peek_head(pkt, rfml->seghead, 6) != 0) + goto out; + + rfml->pdu_size = 
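+			/*
+			 * Reassembly state set up here: the 6-byte
+			 * segmentation head peeked above is cached in
+			 * rfml->seghead, and its last two bytes carry
+			 * the total PDU size, little endian.  For
+			 * reference, cfrfml_create() derives
+			 * fragment_size from the link MTU, rounded down
+			 * to a multiple of 16, e.g. an MTU of 1500
+			 * gives (1500 - 7 - 6) / 16 * 16 = 1472 bytes.
+			 */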
get_unaligned_le16(rfml->seghead+4); + + if (cfpkt_erroneous(pkt)) + goto out; + rfml->incomplete_frm = pkt; + pkt = NULL; + } else { + + tmppkt = rfm_append(rfml, seghead, pkt, &err); + if (tmppkt == NULL) + goto out; + + if (cfpkt_erroneous(tmppkt)) + goto out; + + rfml->incomplete_frm = tmppkt; + + + if (cfpkt_erroneous(tmppkt)) + goto out; + } + err = 0; + goto out; + } + + if (rfml->incomplete_frm) { + + /* Last Segment */ + tmppkt = rfm_append(rfml, seghead, pkt, &err); + if (tmppkt == NULL) + goto out; + + if (cfpkt_erroneous(tmppkt)) + goto out; + + rfml->incomplete_frm = NULL; + pkt = tmppkt; + tmppkt = NULL; + + /* Verify that length is correct */ + err = -EPROTO; + if (rfml->pdu_size != cfpkt_getlen(pkt) - RFM_HEAD_SIZE + 1) + goto out; + } + + err = rfml->serv.layer.up->receive(rfml->serv.layer.up, pkt); + +out: + + if (err != 0) { + if (tmppkt) + cfpkt_destroy(tmppkt); + if (pkt) + cfpkt_destroy(pkt); + if (rfml->incomplete_frm) + cfpkt_destroy(rfml->incomplete_frm); + rfml->incomplete_frm = NULL; + + pr_info("Connection error %d triggered on RFM link\n", err); + + /* Trigger connection error upon failure.*/ + layr->up->ctrlcmd(layr->up, CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND, + rfml->serv.dev_info.id); + } + spin_unlock(&rfml->sync); + + if (unlikely(err == -EAGAIN)) + /* It is not possible to recover after drop of a fragment */ + err = -EIO; + + return err; +} + + +static int cfrfml_transmit_segment(struct cfrfml *rfml, struct cfpkt *pkt) +{ + caif_assert(cfpkt_getlen(pkt) < rfml->fragment_size + RFM_HEAD_SIZE); + + /* Add info for MUX-layer to route the packet out. */ + cfpkt_info(pkt)->channel_id = rfml->serv.layer.id; + + /* + * To optimize alignment, we add up the size of CAIF header before + * payload. + */ + cfpkt_info(pkt)->hdr_len = RFM_HEAD_SIZE; + cfpkt_info(pkt)->dev_info = &rfml->serv.dev_info; + + return rfml->serv.layer.dn->transmit(rfml->serv.layer.dn, pkt); +} + +static int cfrfml_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + int err; + u8 seg; + u8 head[6]; + struct cfpkt *rearpkt = NULL; + struct cfpkt *frontpkt = pkt; + struct cfrfml *rfml = container_obj(layr); + + caif_assert(layr->dn != NULL); + caif_assert(layr->dn->transmit != NULL); + + if (!cfsrvl_ready(&rfml->serv, &err)) + goto out; + + err = -EPROTO; + if (cfpkt_getlen(pkt) <= RFM_HEAD_SIZE-1) + goto out; + + err = 0; + if (cfpkt_getlen(pkt) > rfml->fragment_size + RFM_HEAD_SIZE) + err = cfpkt_peek_head(pkt, head, 6); + + if (err != 0) + goto out; + + while (cfpkt_getlen(frontpkt) > rfml->fragment_size + RFM_HEAD_SIZE) { + + seg = 1; + err = -EPROTO; + + if (cfpkt_add_head(frontpkt, &seg, 1) < 0) + goto out; + /* + * On OOM error cfpkt_split returns NULL. + * + * NOTE: Segmented pdu is not correctly aligned. + * This has negative performance impact. 
+ */ + + rearpkt = cfpkt_split(frontpkt, rfml->fragment_size); + if (rearpkt == NULL) + goto out; + + err = cfrfml_transmit_segment(rfml, frontpkt); + + if (err != 0) { + frontpkt = NULL; + goto out; + } + + frontpkt = rearpkt; + rearpkt = NULL; + + err = -EPROTO; + if (cfpkt_add_head(frontpkt, head, 6) < 0) + goto out; + + } + + seg = 0; + err = -EPROTO; + + if (cfpkt_add_head(frontpkt, &seg, 1) < 0) + goto out; + + err = cfrfml_transmit_segment(rfml, frontpkt); + + frontpkt = NULL; +out: + + if (err != 0) { + pr_info("Connection error %d triggered on RFM link\n", err); + /* Trigger connection error upon failure.*/ + + layr->up->ctrlcmd(layr->up, CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND, + rfml->serv.dev_info.id); + + if (rearpkt) + cfpkt_destroy(rearpkt); + + if (frontpkt) + cfpkt_destroy(frontpkt); + } + + return err; +} diff --git a/net/caif/cfserl.c b/net/caif/cfserl.c new file mode 100644 index 000000000..aee11c74d --- /dev/null +++ b/net/caif/cfserl.c @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) ((struct cfserl *) layr) + +#define CFSERL_STX 0x02 +#define SERIAL_MINIUM_PACKET_SIZE 4 +#define SERIAL_MAX_FRAMESIZE 4096 +struct cfserl { + struct cflayer layer; + struct cfpkt *incomplete_frm; + /* Protects parallel processing of incoming packets */ + spinlock_t sync; + bool usestx; +}; + +static int cfserl_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfserl_transmit(struct cflayer *layr, struct cfpkt *pkt); +static void cfserl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid); + +void cfserl_release(struct cflayer *layer) +{ + kfree(layer); +} + +struct cflayer *cfserl_create(int instance, bool use_stx) +{ + struct cfserl *this = kzalloc(sizeof(struct cfserl), GFP_ATOMIC); + if (!this) + return NULL; + caif_assert(offsetof(struct cfserl, layer) == 0); + this->layer.receive = cfserl_receive; + this->layer.transmit = cfserl_transmit; + this->layer.ctrlcmd = cfserl_ctrlcmd; + this->usestx = use_stx; + spin_lock_init(&this->sync); + snprintf(this->layer.name, CAIF_LAYER_NAME_SZ, "ser1"); + return &this->layer; +} + +static int cfserl_receive(struct cflayer *l, struct cfpkt *newpkt) +{ + struct cfserl *layr = container_obj(l); + u16 pkt_len; + struct cfpkt *pkt = NULL; + struct cfpkt *tail_pkt = NULL; + u8 tmp8; + u16 tmp; + u8 stx = CFSERL_STX; + int ret; + u16 expectlen = 0; + + caif_assert(newpkt != NULL); + spin_lock(&layr->sync); + + if (layr->incomplete_frm != NULL) { + layr->incomplete_frm = + cfpkt_append(layr->incomplete_frm, newpkt, expectlen); + pkt = layr->incomplete_frm; + if (pkt == NULL) { + spin_unlock(&layr->sync); + return -ENOMEM; + } + } else { + pkt = newpkt; + } + layr->incomplete_frm = NULL; + + do { + /* Search for STX at start of pkt if STX is used */ + if (layr->usestx) { + cfpkt_extr_head(pkt, &tmp8, 1); + if (tmp8 != CFSERL_STX) { + while (cfpkt_more(pkt) + && tmp8 != CFSERL_STX) { + cfpkt_extr_head(pkt, &tmp8, 1); + } + if (!cfpkt_more(pkt)) { + cfpkt_destroy(pkt); + layr->incomplete_frm = NULL; + spin_unlock(&layr->sync); + return -EPROTO; + } + } + } + + pkt_len = cfpkt_getlen(pkt); + + /* + * pkt_len is the accumulated length of the packet data + * we have received so far. + * Exit if frame doesn't hold length. 
+ */ + + if (pkt_len < 2) { + if (layr->usestx) + cfpkt_add_head(pkt, &stx, 1); + layr->incomplete_frm = pkt; + spin_unlock(&layr->sync); + return 0; + } + + /* + * Find length of frame. + * expectlen is the length we need for a full frame. + */ + cfpkt_peek_head(pkt, &tmp, 2); + expectlen = le16_to_cpu(tmp) + 2; + /* + * Frame error handling + */ + if (expectlen < SERIAL_MINIUM_PACKET_SIZE + || expectlen > SERIAL_MAX_FRAMESIZE) { + if (!layr->usestx) { + if (pkt != NULL) + cfpkt_destroy(pkt); + layr->incomplete_frm = NULL; + spin_unlock(&layr->sync); + return -EPROTO; + } + continue; + } + + if (pkt_len < expectlen) { + /* Too little received data */ + if (layr->usestx) + cfpkt_add_head(pkt, &stx, 1); + layr->incomplete_frm = pkt; + spin_unlock(&layr->sync); + return 0; + } + + /* + * Enough data for at least one frame. + * Split the frame, if too long + */ + if (pkt_len > expectlen) + tail_pkt = cfpkt_split(pkt, expectlen); + else + tail_pkt = NULL; + + /* Send the first part of packet upwards.*/ + spin_unlock(&layr->sync); + ret = layr->layer.up->receive(layr->layer.up, pkt); + spin_lock(&layr->sync); + if (ret == -EILSEQ) { + if (layr->usestx) { + if (tail_pkt != NULL) + pkt = cfpkt_append(pkt, tail_pkt, 0); + /* Start search for next STX if frame failed */ + continue; + } else { + cfpkt_destroy(pkt); + pkt = NULL; + } + } + + pkt = tail_pkt; + + } while (pkt != NULL); + + spin_unlock(&layr->sync); + return 0; +} + +static int cfserl_transmit(struct cflayer *layer, struct cfpkt *newpkt) +{ + struct cfserl *layr = container_obj(layer); + u8 tmp8 = CFSERL_STX; + if (layr->usestx) + cfpkt_add_head(newpkt, &tmp8, 1); + return layer->dn->transmit(layer->dn, newpkt); +} + +static void cfserl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + layr->up->ctrlcmd(layr->up, ctrl, phyid); +} diff --git a/net/caif/cfsrvl.c b/net/caif/cfsrvl.c new file mode 100644 index 000000000..9cef9496a --- /dev/null +++ b/net/caif/cfsrvl.c @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define SRVL_CTRL_PKT_SIZE 1 +#define SRVL_FLOW_OFF 0x81 +#define SRVL_FLOW_ON 0x80 +#define SRVL_SET_PIN 0x82 + +#define container_obj(layr) container_of(layr, struct cfsrvl, layer) + +static void cfservl_ctrlcmd(struct cflayer *layr, enum caif_ctrlcmd ctrl, + int phyid) +{ + struct cfsrvl *service = container_obj(layr); + + if (layr->up == NULL || layr->up->ctrlcmd == NULL) + return; + + switch (ctrl) { + case CAIF_CTRLCMD_INIT_RSP: + service->open = true; + layr->up->ctrlcmd(layr->up, ctrl, phyid); + break; + case CAIF_CTRLCMD_DEINIT_RSP: + case CAIF_CTRLCMD_INIT_FAIL_RSP: + service->open = false; + layr->up->ctrlcmd(layr->up, ctrl, phyid); + break; + case _CAIF_CTRLCMD_PHYIF_FLOW_OFF_IND: + if (phyid != service->dev_info.id) + break; + if (service->modem_flow_on) + layr->up->ctrlcmd(layr->up, + CAIF_CTRLCMD_FLOW_OFF_IND, phyid); + service->phy_flow_on = false; + break; + case _CAIF_CTRLCMD_PHYIF_FLOW_ON_IND: + if (phyid != service->dev_info.id) + return; + if (service->modem_flow_on) { + layr->up->ctrlcmd(layr->up, + CAIF_CTRLCMD_FLOW_ON_IND, + phyid); + } + service->phy_flow_on = true; + break; + case CAIF_CTRLCMD_FLOW_OFF_IND: + if (service->phy_flow_on) { + layr->up->ctrlcmd(layr->up, + CAIF_CTRLCMD_FLOW_OFF_IND, phyid); + } + service->modem_flow_on = 
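+		/*
+		 * Flow control is gated by two flags: modem_flow_on,
+		 * driven by the per-channel FLOW_ON/OFF indications as
+		 * in this case, and phy_flow_on, driven by the _PHYIF_
+		 * variants above.  An indication is only forwarded to
+		 * the layer above when the other gate is open, so the
+		 * client sees a single combined flow state rather than
+		 * two independent ones.
+		 */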
false; + break; + case CAIF_CTRLCMD_FLOW_ON_IND: + if (service->phy_flow_on) { + layr->up->ctrlcmd(layr->up, + CAIF_CTRLCMD_FLOW_ON_IND, phyid); + } + service->modem_flow_on = true; + break; + case _CAIF_CTRLCMD_PHYIF_DOWN_IND: + /* In case interface is down, let's fake a remove shutdown */ + layr->up->ctrlcmd(layr->up, + CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND, phyid); + break; + case CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND: + layr->up->ctrlcmd(layr->up, ctrl, phyid); + break; + default: + pr_warn("Unexpected ctrl in cfsrvl (%d)\n", ctrl); + /* We have both modem and phy flow on, send flow on */ + layr->up->ctrlcmd(layr->up, ctrl, phyid); + service->phy_flow_on = true; + break; + } +} + +static int cfservl_modemcmd(struct cflayer *layr, enum caif_modemcmd ctrl) +{ + struct cfsrvl *service = container_obj(layr); + + caif_assert(layr != NULL); + caif_assert(layr->dn != NULL); + caif_assert(layr->dn->transmit != NULL); + + if (!service->supports_flowctrl) + return 0; + + switch (ctrl) { + case CAIF_MODEMCMD_FLOW_ON_REQ: + { + struct cfpkt *pkt; + struct caif_payload_info *info; + u8 flow_on = SRVL_FLOW_ON; + pkt = cfpkt_create(SRVL_CTRL_PKT_SIZE); + if (!pkt) + return -ENOMEM; + + if (cfpkt_add_head(pkt, &flow_on, 1) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + info->hdr_len = 1; + info->dev_info = &service->dev_info; + cfpkt_set_prio(pkt, TC_PRIO_CONTROL); + return layr->dn->transmit(layr->dn, pkt); + } + case CAIF_MODEMCMD_FLOW_OFF_REQ: + { + struct cfpkt *pkt; + struct caif_payload_info *info; + u8 flow_off = SRVL_FLOW_OFF; + pkt = cfpkt_create(SRVL_CTRL_PKT_SIZE); + if (!pkt) + return -ENOMEM; + + if (cfpkt_add_head(pkt, &flow_off, 1) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + info->hdr_len = 1; + info->dev_info = &service->dev_info; + cfpkt_set_prio(pkt, TC_PRIO_CONTROL); + return layr->dn->transmit(layr->dn, pkt); + } + default: + break; + } + return -EINVAL; +} + +static void cfsrvl_release(struct cflayer *layer) +{ + struct cfsrvl *service = container_of(layer, struct cfsrvl, layer); + kfree(service); +} + +void cfsrvl_init(struct cfsrvl *service, + u8 channel_id, + struct dev_info *dev_info, + bool supports_flowctrl) +{ + caif_assert(offsetof(struct cfsrvl, layer) == 0); + service->open = false; + service->modem_flow_on = true; + service->phy_flow_on = true; + service->layer.id = channel_id; + service->layer.ctrlcmd = cfservl_ctrlcmd; + service->layer.modemcmd = cfservl_modemcmd; + service->dev_info = *dev_info; + service->supports_flowctrl = supports_flowctrl; + service->release = cfsrvl_release; +} + +bool cfsrvl_ready(struct cfsrvl *service, int *err) +{ + if (!service->open) { + *err = -ENOTCONN; + return false; + } + return true; +} + +u8 cfsrvl_getphyid(struct cflayer *layer) +{ + struct cfsrvl *servl = container_obj(layer); + return servl->dev_info.id; +} + +bool cfsrvl_phyid_match(struct cflayer *layer, int phyid) +{ + struct cfsrvl *servl = container_obj(layer); + return servl->dev_info.id == phyid; +} + +void caif_free_client(struct cflayer *adap_layer) +{ + struct cfsrvl *servl; + if (adap_layer == NULL || adap_layer->dn == NULL) + return; + servl = container_obj(adap_layer->dn); + servl->release(&servl->layer); +} +EXPORT_SYMBOL(caif_free_client); + +void caif_client_register_refcnt(struct cflayer *adapt_layer, + void (*hold)(struct cflayer *lyr), + void 
(*put)(struct cflayer *lyr)) +{ + struct cfsrvl *service; + + if (WARN_ON(adapt_layer == NULL || adapt_layer->dn == NULL)) + return; + service = container_of(adapt_layer->dn, struct cfsrvl, layer); + service->hold = hold; + service->put = put; +} +EXPORT_SYMBOL(caif_client_register_refcnt); diff --git a/net/caif/cfutill.c b/net/caif/cfutill.c new file mode 100644 index 000000000..b2e47ede9 --- /dev/null +++ b/net/caif/cfutill.c @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) ((struct cfsrvl *) layr) +#define UTIL_PAYLOAD 0x00 +#define UTIL_CMD_BIT 0x80 +#define UTIL_REMOTE_SHUTDOWN 0x82 +#define UTIL_FLOW_OFF 0x81 +#define UTIL_FLOW_ON 0x80 + +static int cfutill_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfutill_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cflayer *cfutill_create(u8 channel_id, struct dev_info *dev_info) +{ + struct cfsrvl *util = kzalloc(sizeof(struct cfsrvl), GFP_ATOMIC); + if (!util) + return NULL; + caif_assert(offsetof(struct cfsrvl, layer) == 0); + cfsrvl_init(util, channel_id, dev_info, true); + util->layer.receive = cfutill_receive; + util->layer.transmit = cfutill_transmit; + snprintf(util->layer.name, CAIF_LAYER_NAME_SZ, "util1"); + return &util->layer; +} + +static int cfutill_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 cmd = -1; + struct cfsrvl *service = container_obj(layr); + caif_assert(layr != NULL); + caif_assert(layr->up != NULL); + caif_assert(layr->up->receive != NULL); + caif_assert(layr->up->ctrlcmd != NULL); + if (cfpkt_extr_head(pkt, &cmd, 1) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + + switch (cmd) { + case UTIL_PAYLOAD: + return layr->up->receive(layr->up, pkt); + case UTIL_FLOW_OFF: + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_OFF_IND, 0); + cfpkt_destroy(pkt); + return 0; + case UTIL_FLOW_ON: + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_ON_IND, 0); + cfpkt_destroy(pkt); + return 0; + case UTIL_REMOTE_SHUTDOWN: /* Remote Shutdown Request */ + pr_err("REMOTE SHUTDOWN REQUEST RECEIVED\n"); + layr->ctrlcmd(layr, CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND, 0); + service->open = false; + cfpkt_destroy(pkt); + return 0; + default: + cfpkt_destroy(pkt); + pr_warn("Unknown service control %d (0x%x)\n", cmd, cmd); + return -EPROTO; + } +} + +static int cfutill_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 zero = 0; + struct caif_payload_info *info; + int ret; + struct cfsrvl *service = container_obj(layr); + caif_assert(layr != NULL); + caif_assert(layr->dn != NULL); + caif_assert(layr->dn->transmit != NULL); + + if (!cfsrvl_ready(service, &ret)) { + cfpkt_destroy(pkt); + return ret; + } + + cfpkt_add_head(pkt, &zero, 1); + /* Add info for MUX-layer to route the packet out. */ + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + /* + * To optimize alignment, we add up the size of CAIF header before + * payload. 
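+ * Here that header is the single UTIL_PAYLOAD zero byte prepended above,
+ * which is why hdr_len is set to 1 just below.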
+ */ + info->hdr_len = 1; + info->dev_info = &service->dev_info; + return layr->dn->transmit(layr->dn, pkt); +} diff --git a/net/caif/cfveil.c b/net/caif/cfveil.c new file mode 100644 index 000000000..db2274b94 --- /dev/null +++ b/net/caif/cfveil.c @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include + +#define VEI_PAYLOAD 0x00 +#define VEI_CMD_BIT 0x80 +#define VEI_FLOW_OFF 0x81 +#define VEI_FLOW_ON 0x80 +#define VEI_SET_PIN 0x82 + +#define container_obj(layr) container_of(layr, struct cfsrvl, layer) + +static int cfvei_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfvei_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cflayer *cfvei_create(u8 channel_id, struct dev_info *dev_info) +{ + struct cfsrvl *vei = kzalloc(sizeof(struct cfsrvl), GFP_ATOMIC); + if (!vei) + return NULL; + caif_assert(offsetof(struct cfsrvl, layer) == 0); + cfsrvl_init(vei, channel_id, dev_info, true); + vei->layer.receive = cfvei_receive; + vei->layer.transmit = cfvei_transmit; + snprintf(vei->layer.name, CAIF_LAYER_NAME_SZ, "vei%d", channel_id); + return &vei->layer; +} + +static int cfvei_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 cmd; + int ret; + caif_assert(layr->up != NULL); + caif_assert(layr->receive != NULL); + caif_assert(layr->ctrlcmd != NULL); + + + if (cfpkt_extr_head(pkt, &cmd, 1) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + switch (cmd) { + case VEI_PAYLOAD: + ret = layr->up->receive(layr->up, pkt); + return ret; + case VEI_FLOW_OFF: + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_OFF_IND, 0); + cfpkt_destroy(pkt); + return 0; + case VEI_FLOW_ON: + layr->ctrlcmd(layr, CAIF_CTRLCMD_FLOW_ON_IND, 0); + cfpkt_destroy(pkt); + return 0; + case VEI_SET_PIN: /* SET RS232 PIN */ + cfpkt_destroy(pkt); + return 0; + default: /* SET RS232 PIN */ + pr_warn("Unknown VEI control packet %d (0x%x)!\n", cmd, cmd); + cfpkt_destroy(pkt); + return -EPROTO; + } +} + +static int cfvei_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + u8 tmp = 0; + struct caif_payload_info *info; + int ret; + struct cfsrvl *service = container_obj(layr); + if (!cfsrvl_ready(service, &ret)) + goto err; + caif_assert(layr->dn != NULL); + caif_assert(layr->dn->transmit != NULL); + + if (cfpkt_add_head(pkt, &tmp, 1) < 0) { + pr_err("Packet is erroneous!\n"); + ret = -EPROTO; + goto err; + } + + /* Add info-> for MUX-layer to route the packet out. 
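+ * The VEI_PAYLOAD byte prepended above is accounted for in hdr_len.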
*/ + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + info->hdr_len = 1; + info->dev_info = &service->dev_info; + return layr->dn->transmit(layr->dn, pkt); +err: + cfpkt_destroy(pkt); + return ret; +} diff --git a/net/caif/cfvidl.c b/net/caif/cfvidl.c new file mode 100644 index 000000000..134bad431 --- /dev/null +++ b/net/caif/cfvidl.c @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Author: Sjur Brendeland + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include + +#define container_obj(layr) ((struct cfsrvl *) layr) + +static int cfvidl_receive(struct cflayer *layr, struct cfpkt *pkt); +static int cfvidl_transmit(struct cflayer *layr, struct cfpkt *pkt); + +struct cflayer *cfvidl_create(u8 channel_id, struct dev_info *dev_info) +{ + struct cfsrvl *vid = kzalloc(sizeof(struct cfsrvl), GFP_ATOMIC); + if (!vid) + return NULL; + caif_assert(offsetof(struct cfsrvl, layer) == 0); + + cfsrvl_init(vid, channel_id, dev_info, false); + vid->layer.receive = cfvidl_receive; + vid->layer.transmit = cfvidl_transmit; + snprintf(vid->layer.name, CAIF_LAYER_NAME_SZ, "vid1"); + return &vid->layer; +} + +static int cfvidl_receive(struct cflayer *layr, struct cfpkt *pkt) +{ + u32 videoheader; + if (cfpkt_extr_head(pkt, &videoheader, 4) < 0) { + pr_err("Packet is erroneous!\n"); + cfpkt_destroy(pkt); + return -EPROTO; + } + return layr->up->receive(layr->up, pkt); +} + +static int cfvidl_transmit(struct cflayer *layr, struct cfpkt *pkt) +{ + struct cfsrvl *service = container_obj(layr); + struct caif_payload_info *info; + u32 videoheader = 0; + int ret; + + if (!cfsrvl_ready(service, &ret)) { + cfpkt_destroy(pkt); + return ret; + } + + cfpkt_add_head(pkt, &videoheader, 4); + /* Add info for MUX-layer to route the packet out */ + info = cfpkt_info(pkt); + info->channel_id = service->layer.id; + info->dev_info = &service->dev_info; + return layr->dn->transmit(layr->dn, pkt); +} diff --git a/net/caif/chnl_net.c b/net/caif/chnl_net.c new file mode 100644 index 000000000..f35fc87c4 --- /dev/null +++ b/net/caif/chnl_net.c @@ -0,0 +1,531 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) ST-Ericsson AB 2010 + * Authors: Sjur Brendeland + * Daniel Martensson + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ":%s(): " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* GPRS PDP connection has MTU to 1500 */ +#define GPRS_PDP_MTU 1500 +/* 5 sec. connect timeout */ +#define CONNECT_TIMEOUT (5 * HZ) +#define CAIF_NET_DEFAULT_QUEUE_LEN 500 +#define UNDEF_CONNID 0xffffffff + +/*This list is protected by the rtnl lock. */ +static LIST_HEAD(chnl_net_list); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("caif"); + +enum caif_states { + CAIF_CONNECTED = 1, + CAIF_CONNECTING, + CAIF_DISCONNECTED, + CAIF_SHUTDOWN +}; + +struct chnl_net { + struct cflayer chnl; + struct caif_connect_request conn_req; + struct list_head list_field; + struct net_device *netdev; + char name[256]; + wait_queue_head_t netmgmt_wq; + /* Flow status to remember and control the transmission. 
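+ * Set by chnl_flowctrl_cb() and checked in chnl_net_start_xmit().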
*/ + bool flowenabled; + enum caif_states state; +}; + +static int chnl_recv_cb(struct cflayer *layr, struct cfpkt *pkt) +{ + struct sk_buff *skb; + struct chnl_net *priv; + int pktlen; + const u8 *ip_version; + u8 buf; + + priv = container_of(layr, struct chnl_net, chnl); + + skb = (struct sk_buff *) cfpkt_tonative(pkt); + + /* Get length of CAIF packet. */ + pktlen = skb->len; + + /* Pass some minimum information and + * send the packet to the net stack. + */ + skb->dev = priv->netdev; + + /* check the version of IP */ + ip_version = skb_header_pointer(skb, 0, 1, &buf); + if (!ip_version) { + kfree_skb(skb); + return -EINVAL; + } + + switch (*ip_version >> 4) { + case 4: + skb->protocol = htons(ETH_P_IP); + break; + case 6: + skb->protocol = htons(ETH_P_IPV6); + break; + default: + kfree_skb(skb); + priv->netdev->stats.rx_errors++; + return -EINVAL; + } + + /* If we change the header in loop mode, the checksum is corrupted. */ + if (priv->conn_req.protocol == CAIFPROTO_DATAGRAM_LOOP) + skb->ip_summed = CHECKSUM_UNNECESSARY; + else + skb->ip_summed = CHECKSUM_NONE; + + netif_rx(skb); + + /* Update statistics. */ + priv->netdev->stats.rx_packets++; + priv->netdev->stats.rx_bytes += pktlen; + + return 0; +} + +static int delete_device(struct chnl_net *dev) +{ + ASSERT_RTNL(); + if (dev->netdev) + unregister_netdevice(dev->netdev); + return 0; +} + +static void close_work(struct work_struct *work) +{ + struct chnl_net *dev = NULL; + struct list_head *list_node; + struct list_head *_tmp; + + rtnl_lock(); + list_for_each_safe(list_node, _tmp, &chnl_net_list) { + dev = list_entry(list_node, struct chnl_net, list_field); + if (dev->state == CAIF_SHUTDOWN) + dev_close(dev->netdev); + } + rtnl_unlock(); +} +static DECLARE_WORK(close_worker, close_work); + +static void chnl_hold(struct cflayer *lyr) +{ + struct chnl_net *priv = container_of(lyr, struct chnl_net, chnl); + dev_hold(priv->netdev); +} + +static void chnl_put(struct cflayer *lyr) +{ + struct chnl_net *priv = container_of(lyr, struct chnl_net, chnl); + dev_put(priv->netdev); +} + +static void chnl_flowctrl_cb(struct cflayer *layr, enum caif_ctrlcmd flow, + int phyid) +{ + struct chnl_net *priv = container_of(layr, struct chnl_net, chnl); + pr_debug("NET flowctrl func called flow: %s\n", + flow == CAIF_CTRLCMD_FLOW_ON_IND ? "ON" : + flow == CAIF_CTRLCMD_INIT_RSP ? "INIT" : + flow == CAIF_CTRLCMD_FLOW_OFF_IND ? "OFF" : + flow == CAIF_CTRLCMD_DEINIT_RSP ? "CLOSE/DEINIT" : + flow == CAIF_CTRLCMD_INIT_FAIL_RSP ? "OPEN_FAIL" : + flow == CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND ? 
+ "REMOTE_SHUTDOWN" : "UNKNOWN CTRL COMMAND"); + + + + switch (flow) { + case CAIF_CTRLCMD_FLOW_OFF_IND: + priv->flowenabled = false; + netif_stop_queue(priv->netdev); + break; + case CAIF_CTRLCMD_DEINIT_RSP: + priv->state = CAIF_DISCONNECTED; + break; + case CAIF_CTRLCMD_INIT_FAIL_RSP: + priv->state = CAIF_DISCONNECTED; + wake_up_interruptible(&priv->netmgmt_wq); + break; + case CAIF_CTRLCMD_REMOTE_SHUTDOWN_IND: + priv->state = CAIF_SHUTDOWN; + netif_tx_disable(priv->netdev); + schedule_work(&close_worker); + break; + case CAIF_CTRLCMD_FLOW_ON_IND: + priv->flowenabled = true; + netif_wake_queue(priv->netdev); + break; + case CAIF_CTRLCMD_INIT_RSP: + caif_client_register_refcnt(&priv->chnl, chnl_hold, chnl_put); + priv->state = CAIF_CONNECTED; + priv->flowenabled = true; + netif_wake_queue(priv->netdev); + wake_up_interruptible(&priv->netmgmt_wq); + break; + default: + break; + } +} + +static netdev_tx_t chnl_net_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct chnl_net *priv; + struct cfpkt *pkt = NULL; + int len; + int result = -1; + /* Get our private data. */ + priv = netdev_priv(dev); + + if (skb->len > priv->netdev->mtu) { + pr_warn("Size of skb exceeded MTU\n"); + kfree_skb(skb); + dev->stats.tx_errors++; + return NETDEV_TX_OK; + } + + if (!priv->flowenabled) { + pr_debug("dropping packets flow off\n"); + kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; + } + + if (priv->conn_req.protocol == CAIFPROTO_DATAGRAM_LOOP) + swap(ip_hdr(skb)->saddr, ip_hdr(skb)->daddr); + + /* Store original SKB length. */ + len = skb->len; + + pkt = cfpkt_fromnative(CAIF_DIR_OUT, (void *) skb); + + /* Send the packet down the stack. */ + result = priv->chnl.dn->transmit(priv->chnl.dn, pkt); + if (result) { + dev->stats.tx_dropped++; + return NETDEV_TX_OK; + } + + /* Update statistics. */ + dev->stats.tx_packets++; + dev->stats.tx_bytes += len; + + return NETDEV_TX_OK; +} + +static int chnl_net_open(struct net_device *dev) +{ + struct chnl_net *priv = NULL; + int result = -1; + int llifindex, headroom, tailroom, mtu; + struct net_device *lldev; + ASSERT_RTNL(); + priv = netdev_priv(dev); + if (!priv) { + pr_debug("chnl_net_open: no priv\n"); + return -ENODEV; + } + + if (priv->state != CAIF_CONNECTING) { + priv->state = CAIF_CONNECTING; + result = caif_connect_client(dev_net(dev), &priv->conn_req, + &priv->chnl, &llifindex, + &headroom, &tailroom); + if (result != 0) { + pr_debug("err: " + "Unable to register and open device," + " Err:%d\n", + result); + goto error; + } + + lldev = __dev_get_by_index(dev_net(dev), llifindex); + + if (lldev == NULL) { + pr_debug("no interface?\n"); + result = -ENODEV; + goto error; + } + + dev->needed_tailroom = tailroom + lldev->needed_tailroom; + dev->hard_header_len = headroom + lldev->hard_header_len + + lldev->needed_tailroom; + + /* + * MTU, head-room etc is not know before we have a + * CAIF link layer device available. MTU calculation may + * override initial RTNL configuration. + * MTU is minimum of current mtu, link layer mtu pluss + * CAIF head and tail, and PDP GPRS contexts max MTU. 
+ */ + mtu = min_t(int, dev->mtu, lldev->mtu - (headroom + tailroom)); + mtu = min_t(int, GPRS_PDP_MTU, mtu); + dev_set_mtu(dev, mtu); + + if (mtu < 100) { + pr_warn("CAIF Interface MTU too small (%d)\n", mtu); + result = -ENODEV; + goto error; + } + } + + rtnl_unlock(); /* Release RTNL lock during connect wait */ + + result = wait_event_interruptible_timeout(priv->netmgmt_wq, + priv->state != CAIF_CONNECTING, + CONNECT_TIMEOUT); + + rtnl_lock(); + + if (result == -ERESTARTSYS) { + pr_debug("wait_event_interruptible woken by a signal\n"); + result = -ERESTARTSYS; + goto error; + } + + if (result == 0) { + pr_debug("connect timeout\n"); + result = -ETIMEDOUT; + goto error; + } + + if (priv->state != CAIF_CONNECTED) { + pr_debug("connect failed\n"); + result = -ECONNREFUSED; + goto error; + } + pr_debug("CAIF Netdevice connected\n"); + return 0; + +error: + caif_disconnect_client(dev_net(dev), &priv->chnl); + priv->state = CAIF_DISCONNECTED; + pr_debug("state disconnected\n"); + return result; + +} + +static int chnl_net_stop(struct net_device *dev) +{ + struct chnl_net *priv; + + ASSERT_RTNL(); + priv = netdev_priv(dev); + priv->state = CAIF_DISCONNECTED; + caif_disconnect_client(dev_net(dev), &priv->chnl); + return 0; +} + +static int chnl_net_init(struct net_device *dev) +{ + struct chnl_net *priv; + ASSERT_RTNL(); + priv = netdev_priv(dev); + strncpy(priv->name, dev->name, sizeof(priv->name)); + INIT_LIST_HEAD(&priv->list_field); + return 0; +} + +static void chnl_net_uninit(struct net_device *dev) +{ + struct chnl_net *priv; + ASSERT_RTNL(); + priv = netdev_priv(dev); + list_del_init(&priv->list_field); +} + +static const struct net_device_ops netdev_ops = { + .ndo_open = chnl_net_open, + .ndo_stop = chnl_net_stop, + .ndo_init = chnl_net_init, + .ndo_uninit = chnl_net_uninit, + .ndo_start_xmit = chnl_net_start_xmit, +}; + +static void chnl_net_destructor(struct net_device *dev) +{ + struct chnl_net *priv = netdev_priv(dev); + caif_free_client(&priv->chnl); +} + +static void ipcaif_net_setup(struct net_device *dev) +{ + struct chnl_net *priv; + dev->netdev_ops = &netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = chnl_net_destructor; + dev->flags |= IFF_NOARP; + dev->flags |= IFF_POINTOPOINT; + dev->mtu = GPRS_PDP_MTU; + dev->tx_queue_len = CAIF_NET_DEFAULT_QUEUE_LEN; + + priv = netdev_priv(dev); + priv->chnl.receive = chnl_recv_cb; + priv->chnl.ctrlcmd = chnl_flowctrl_cb; + priv->netdev = dev; + priv->conn_req.protocol = CAIFPROTO_DATAGRAM; + priv->conn_req.link_selector = CAIF_LINK_HIGH_BANDW; + priv->conn_req.priority = CAIF_PRIO_LOW; + /* Insert illegal value */ + priv->conn_req.sockaddr.u.dgm.connection_id = UNDEF_CONNID; + priv->flowenabled = false; + + init_waitqueue_head(&priv->netmgmt_wq); +} + + +static int ipcaif_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct chnl_net *priv; + u8 loop; + priv = netdev_priv(dev); + if (nla_put_u32(skb, IFLA_CAIF_IPV4_CONNID, + priv->conn_req.sockaddr.u.dgm.connection_id) || + nla_put_u32(skb, IFLA_CAIF_IPV6_CONNID, + priv->conn_req.sockaddr.u.dgm.connection_id)) + goto nla_put_failure; + loop = priv->conn_req.protocol == CAIFPROTO_DATAGRAM_LOOP; + if (nla_put_u8(skb, IFLA_CAIF_LOOPBACK, loop)) + goto nla_put_failure; + return 0; +nla_put_failure: + return -EMSGSIZE; + +} + +static void caif_netlink_parms(struct nlattr *data[], + struct caif_connect_request *conn_req) +{ + if (!data) { + pr_warn("no params data found\n"); + return; + } + if (data[IFLA_CAIF_IPV4_CONNID]) + 
conn_req->sockaddr.u.dgm.connection_id = + nla_get_u32(data[IFLA_CAIF_IPV4_CONNID]); + if (data[IFLA_CAIF_IPV6_CONNID]) + conn_req->sockaddr.u.dgm.connection_id = + nla_get_u32(data[IFLA_CAIF_IPV6_CONNID]); + if (data[IFLA_CAIF_LOOPBACK]) { + if (nla_get_u8(data[IFLA_CAIF_LOOPBACK])) + conn_req->protocol = CAIFPROTO_DATAGRAM_LOOP; + else + conn_req->protocol = CAIFPROTO_DATAGRAM; + } +} + +static int ipcaif_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + int ret; + struct chnl_net *caifdev; + ASSERT_RTNL(); + caifdev = netdev_priv(dev); + caif_netlink_parms(data, &caifdev->conn_req); + + ret = register_netdevice(dev); + if (ret) + pr_warn("device rtml registration failed\n"); + else + list_add(&caifdev->list_field, &chnl_net_list); + + /* Use ifindex as connection id, and use loopback channel default. */ + if (caifdev->conn_req.sockaddr.u.dgm.connection_id == UNDEF_CONNID) { + caifdev->conn_req.sockaddr.u.dgm.connection_id = dev->ifindex; + caifdev->conn_req.protocol = CAIFPROTO_DATAGRAM_LOOP; + } + return ret; +} + +static int ipcaif_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct chnl_net *caifdev; + ASSERT_RTNL(); + caifdev = netdev_priv(dev); + caif_netlink_parms(data, &caifdev->conn_req); + netdev_state_change(dev); + return 0; +} + +static size_t ipcaif_get_size(const struct net_device *dev) +{ + return + /* IFLA_CAIF_IPV4_CONNID */ + nla_total_size(4) + + /* IFLA_CAIF_IPV6_CONNID */ + nla_total_size(4) + + /* IFLA_CAIF_LOOPBACK */ + nla_total_size(2) + + 0; +} + +static const struct nla_policy ipcaif_policy[IFLA_CAIF_MAX + 1] = { + [IFLA_CAIF_IPV4_CONNID] = { .type = NLA_U32 }, + [IFLA_CAIF_IPV6_CONNID] = { .type = NLA_U32 }, + [IFLA_CAIF_LOOPBACK] = { .type = NLA_U8 } +}; + + +static struct rtnl_link_ops ipcaif_link_ops __read_mostly = { + .kind = "caif", + .priv_size = sizeof(struct chnl_net), + .setup = ipcaif_net_setup, + .maxtype = IFLA_CAIF_MAX, + .policy = ipcaif_policy, + .newlink = ipcaif_newlink, + .changelink = ipcaif_changelink, + .get_size = ipcaif_get_size, + .fill_info = ipcaif_fill_info, + +}; + +static int __init chnl_init_module(void) +{ + return rtnl_link_register(&ipcaif_link_ops); +} + +static void __exit chnl_exit_module(void) +{ + struct chnl_net *dev = NULL; + struct list_head *list_node; + struct list_head *_tmp; + rtnl_link_unregister(&ipcaif_link_ops); + rtnl_lock(); + list_for_each_safe(list_node, _tmp, &chnl_net_list) { + dev = list_entry(list_node, struct chnl_net, list_field); + list_del_init(list_node); + delete_device(dev); + } + rtnl_unlock(); +} + +module_init(chnl_init_module); +module_exit(chnl_exit_module); diff --git a/net/can/Kconfig b/net/can/Kconfig new file mode 100644 index 000000000..cb56be8e3 --- /dev/null +++ b/net/can/Kconfig @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Controller Area Network (CAN) network layer core configuration +# + +menuconfig CAN + tristate "CAN bus subsystem support" + help + Controller Area Network (CAN) is a slow (up to 1Mbit/s) serial + communications protocol. Development of the CAN bus started in + 1983 at Robert Bosch GmbH, and the protocol was officially + released in 1986. The CAN bus was originally mainly for automotive, + but is now widely used in marine (NMEA2000), industrial, and medical + applications. More information on the CAN network protocol family + PF_CAN is contained in . 
+ + If you want CAN support you should say Y here and also to the + specific driver for your controller(s) under the Network device + support section. + +if CAN + +config CAN_RAW + tristate "Raw CAN Protocol (raw access with CAN-ID filtering)" + default y + help + The raw CAN protocol option offers access to the CAN bus via + the BSD socket API. You probably want to use the raw socket in + most cases where no higher level protocol is being used. The raw + socket has several filter options e.g. ID masking / error frames. + To receive/send raw CAN messages, use AF_CAN with protocol CAN_RAW. + +config CAN_BCM + tristate "Broadcast Manager CAN Protocol (with content filtering)" + default y + help + The Broadcast Manager offers content filtering, timeout monitoring, + sending of RTR frames, and cyclic CAN messages without permanent user + interaction. The BCM can be 'programmed' via the BSD socket API and + informs you on demand e.g. only on content updates / timeouts. + You probably want to use the bcm socket in most cases where cyclic + CAN messages are used on the bus (e.g. in automotive environments). + To use the Broadcast Manager, use AF_CAN with protocol CAN_BCM. + +config CAN_GW + tristate "CAN Gateway/Router (with netlink configuration)" + default y + help + The CAN Gateway/Router is used to route (and modify) CAN frames. + It is based on the PF_CAN core infrastructure for msg filtering and + msg sending and can optionally modify routed CAN frames on the fly. + CAN frames can be routed between CAN network interfaces (one hop). + They can be modified with AND/OR/XOR/SET operations as configured + by the netlink configuration interface known e.g. from iptables. + +source "net/can/j1939/Kconfig" + +config CAN_ISOTP + tristate "ISO 15765-2:2016 CAN transport protocol" + help + CAN Transport Protocols offer support for segmented Point-to-Point + communication between CAN nodes via two defined CAN Identifiers. + As CAN frames can only transport a small amount of data bytes + (max. 8 bytes for 'classic' CAN and max. 64 bytes for CAN FD) this + segmentation is needed to transport longer Protocol Data Units (PDU) + as needed e.g. for vehicle diagnosis (UDS, ISO 14229) or IP-over-CAN + traffic. + This protocol driver implements data transfers according to + ISO 15765-2:2016 for 'classic' CAN and CAN FD frame types. + If you want to perform automotive vehicle diagnostic services (UDS), + say 'y'. + +endif diff --git a/net/can/Makefile b/net/can/Makefile new file mode 100644 index 000000000..58f2c31c1 --- /dev/null +++ b/net/can/Makefile @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux Controller Area Network core. +# + +obj-$(CONFIG_CAN) += can.o +can-y := af_can.o +can-$(CONFIG_PROC_FS) += proc.o + +obj-$(CONFIG_CAN_RAW) += can-raw.o +can-raw-y := raw.o + +obj-$(CONFIG_CAN_BCM) += can-bcm.o +can-bcm-y := bcm.o + +obj-$(CONFIG_CAN_GW) += can-gw.o +can-gw-y := gw.o + +obj-$(CONFIG_CAN_J1939) += j1939/ + +obj-$(CONFIG_CAN_ISOTP) += can-isotp.o +can-isotp-y := isotp.o diff --git a/net/can/af_can.c b/net/can/af_can.c new file mode 100644 index 000000000..c69168f11 --- /dev/null +++ b/net/can/af_can.c @@ -0,0 +1,918 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* af_can.c - Protocol family CAN core module + * (used by different CAN protocol modules) + * + * Copyright (c) 2002-2017 Volkswagen Group Electronic Research + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. 
+ * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "af_can.h" + +MODULE_DESCRIPTION("Controller Area Network PF_CAN core"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_AUTHOR("Urs Thuermann , " + "Oliver Hartkopp "); + +MODULE_ALIAS_NETPROTO(PF_CAN); + +static int stats_timer __read_mostly = 1; +module_param(stats_timer, int, 0444); +MODULE_PARM_DESC(stats_timer, "enable timer for statistics (default:on)"); + +static struct kmem_cache *rcv_cache __read_mostly; + +/* table of registered CAN protocols */ +static const struct can_proto __rcu *proto_tab[CAN_NPROTO] __read_mostly; +static DEFINE_MUTEX(proto_tab_lock); + +static atomic_t skbcounter = ATOMIC_INIT(0); + +/* af_can socket functions */ + +void can_sock_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_error_queue); +} +EXPORT_SYMBOL(can_sock_destruct); + +static const struct can_proto *can_get_proto(int protocol) +{ + const struct can_proto *cp; + + rcu_read_lock(); + cp = rcu_dereference(proto_tab[protocol]); + if (cp && !try_module_get(cp->prot->owner)) + cp = NULL; + rcu_read_unlock(); + + return cp; +} + +static inline void can_put_proto(const struct can_proto *cp) +{ + module_put(cp->prot->owner); +} + +static int can_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + const struct can_proto *cp; + int err = 0; + + sock->state = SS_UNCONNECTED; + + if (protocol < 0 || protocol >= CAN_NPROTO) + return -EINVAL; + + cp = can_get_proto(protocol); + +#ifdef CONFIG_MODULES + if (!cp) { + /* try to load protocol module if kernel is modular */ + + err = request_module("can-proto-%d", protocol); + + /* In case of error we only print a message but don't + * return the error code immediately. Below we will + * return -EPROTONOSUPPORT + */ + if (err) + pr_err_ratelimited("can: request_module (can-proto-%d) failed.\n", + protocol); + + cp = can_get_proto(protocol); + } +#endif + + /* check for available protocol and correct usage */ + + if (!cp) + return -EPROTONOSUPPORT; + + if (cp->type != sock->type) { + err = -EPROTOTYPE; + goto errout; + } + + sock->ops = cp->ops; + + sk = sk_alloc(net, PF_CAN, GFP_KERNEL, cp->prot, kern); + if (!sk) { + err = -ENOMEM; + goto errout; + } + + sock_init_data(sock, sk); + sk->sk_destruct = can_sock_destruct; + + if (sk->sk_prot->init) + err = sk->sk_prot->init(sk); + + if (err) { + /* release sk on errors */ + sock_orphan(sk); + sock_put(sk); + } + + errout: + can_put_proto(cp); + return err; +} + +/* af_can tx path */ + +/** + * can_send - transmit a CAN frame (optional with local loopback) + * @skb: pointer to socket buffer with CAN frame in data section + * @loop: loopback for listeners on local CAN sockets (recommended default!) + * + * Due to the loopback this routine must not be called from hardirq context. 
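+ * When @loop is set and the device does not advertise IFF_ECHO, the frame is
+ * cloned here and looped back to local listeners via netif_rx().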
+ * + * Return: + * 0 on success + * -ENETDOWN when the selected interface is down + * -ENOBUFS on full driver queue (see net_xmit_errno()) + * -ENOMEM when local loopback failed at calling skb_clone() + * -EPERM when trying to send on a non-CAN interface + * -EMSGSIZE CAN frame size is bigger than CAN interface MTU + * -EINVAL when the skb->data does not contain a valid CAN frame + */ +int can_send(struct sk_buff *skb, int loop) +{ + struct sk_buff *newskb = NULL; + struct can_pkg_stats *pkg_stats = dev_net(skb->dev)->can.pkg_stats; + int err = -EINVAL; + + if (can_is_canxl_skb(skb)) { + skb->protocol = htons(ETH_P_CANXL); + } else if (can_is_can_skb(skb)) { + skb->protocol = htons(ETH_P_CAN); + } else if (can_is_canfd_skb(skb)) { + struct canfd_frame *cfd = (struct canfd_frame *)skb->data; + + skb->protocol = htons(ETH_P_CANFD); + + /* set CAN FD flag for CAN FD frames by default */ + cfd->flags |= CANFD_FDF; + } else { + goto inval_skb; + } + + /* Make sure the CAN frame can pass the selected CAN netdevice. */ + if (unlikely(skb->len > skb->dev->mtu)) { + err = -EMSGSIZE; + goto inval_skb; + } + + if (unlikely(skb->dev->type != ARPHRD_CAN)) { + err = -EPERM; + goto inval_skb; + } + + if (unlikely(!(skb->dev->flags & IFF_UP))) { + err = -ENETDOWN; + goto inval_skb; + } + + skb->ip_summed = CHECKSUM_UNNECESSARY; + + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + + if (loop) { + /* local loopback of sent CAN frames */ + + /* indication for the CAN driver: do loopback */ + skb->pkt_type = PACKET_LOOPBACK; + + /* The reference to the originating sock may be required + * by the receiving socket to check whether the frame is + * its own. Example: can_raw sockopt CAN_RAW_RECV_OWN_MSGS + * Therefore we have to ensure that skb->sk remains the + * reference to the originating sock by restoring skb->sk + * after each skb_clone() or skb_orphan() usage. + */ + + if (!(skb->dev->flags & IFF_ECHO)) { + /* If the interface is not capable to do loopback + * itself, we do it here. + */ + newskb = skb_clone(skb, GFP_ATOMIC); + if (!newskb) { + kfree_skb(skb); + return -ENOMEM; + } + + can_skb_set_owner(newskb, skb->sk); + newskb->ip_summed = CHECKSUM_UNNECESSARY; + newskb->pkt_type = PACKET_BROADCAST; + } + } else { + /* indication for the CAN driver: no loopback required */ + skb->pkt_type = PACKET_HOST; + } + + /* send to netdevice */ + err = dev_queue_xmit(skb); + if (err > 0) + err = net_xmit_errno(err); + + if (err) { + kfree_skb(newskb); + return err; + } + + if (newskb) + netif_rx(newskb); + + /* update statistics */ + pkg_stats->tx_frames++; + pkg_stats->tx_frames_delta++; + + return 0; + +inval_skb: + kfree_skb(skb); + return err; +} +EXPORT_SYMBOL(can_send); + +/* af_can rx path */ + +static struct can_dev_rcv_lists *can_dev_rcv_lists_find(struct net *net, + struct net_device *dev) +{ + if (dev) { + struct can_ml_priv *can_ml = can_get_ml_priv(dev); + return &can_ml->dev_rcv_lists; + } else { + return net->can.rx_alldev_list; + } +} + +/** + * effhash - hash function for 29 bit CAN identifier reduction + * @can_id: 29 bit CAN identifier + * + * Description: + * To reduce the linear traversal in one linked list of _single_ EFF CAN + * frame subscriptions the 29 bit identifier is mapped to 10 bits. 
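+ * The mapping XORs the identifier with itself shifted right by 10 and by
+ * 20 bits and masks the result down to 10 bits.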
+ * (see CAN_EFF_RCV_HASH_BITS definition) + * + * Return: + * Hash value from 0x000 - 0x3FF ( enforced by CAN_EFF_RCV_HASH_BITS mask ) + */ +static unsigned int effhash(canid_t can_id) +{ + unsigned int hash; + + hash = can_id; + hash ^= can_id >> CAN_EFF_RCV_HASH_BITS; + hash ^= can_id >> (2 * CAN_EFF_RCV_HASH_BITS); + + return hash & ((1 << CAN_EFF_RCV_HASH_BITS) - 1); +} + +/** + * can_rcv_list_find - determine optimal filterlist inside device filter struct + * @can_id: pointer to CAN identifier of a given can_filter + * @mask: pointer to CAN mask of a given can_filter + * @dev_rcv_lists: pointer to the device filter struct + * + * Description: + * Returns the optimal filterlist to reduce the filter handling in the + * receive path. This function is called by service functions that need + * to register or unregister a can_filter in the filter lists. + * + * A filter matches in general, when + * + * & mask == can_id & mask + * + * so every bit set in the mask (even CAN_EFF_FLAG, CAN_RTR_FLAG) describe + * relevant bits for the filter. + * + * The filter can be inverted (CAN_INV_FILTER bit set in can_id) or it can + * filter for error messages (CAN_ERR_FLAG bit set in mask). For error msg + * frames there is a special filterlist and a special rx path filter handling. + * + * Return: + * Pointer to optimal filterlist for the given can_id/mask pair. + * Consistency checked mask. + * Reduced can_id to have a preprocessed filter compare value. + */ +static struct hlist_head *can_rcv_list_find(canid_t *can_id, canid_t *mask, + struct can_dev_rcv_lists *dev_rcv_lists) +{ + canid_t inv = *can_id & CAN_INV_FILTER; /* save flag before masking */ + + /* filter for error message frames in extra filterlist */ + if (*mask & CAN_ERR_FLAG) { + /* clear CAN_ERR_FLAG in filter entry */ + *mask &= CAN_ERR_MASK; + return &dev_rcv_lists->rx[RX_ERR]; + } + + /* with cleared CAN_ERR_FLAG we have a simple mask/value filterpair */ + +#define CAN_EFF_RTR_FLAGS (CAN_EFF_FLAG | CAN_RTR_FLAG) + + /* ensure valid values in can_mask for 'SFF only' frame filtering */ + if ((*mask & CAN_EFF_FLAG) && !(*can_id & CAN_EFF_FLAG)) + *mask &= (CAN_SFF_MASK | CAN_EFF_RTR_FLAGS); + + /* reduce condition testing at receive time */ + *can_id &= *mask; + + /* inverse can_id/can_mask filter */ + if (inv) + return &dev_rcv_lists->rx[RX_INV]; + + /* mask == 0 => no condition testing at receive time */ + if (!(*mask)) + return &dev_rcv_lists->rx[RX_ALL]; + + /* extra filterlists for the subscription of a single non-RTR can_id */ + if (((*mask & CAN_EFF_RTR_FLAGS) == CAN_EFF_RTR_FLAGS) && + !(*can_id & CAN_RTR_FLAG)) { + if (*can_id & CAN_EFF_FLAG) { + if (*mask == (CAN_EFF_MASK | CAN_EFF_RTR_FLAGS)) + return &dev_rcv_lists->rx_eff[effhash(*can_id)]; + } else { + if (*mask == (CAN_SFF_MASK | CAN_EFF_RTR_FLAGS)) + return &dev_rcv_lists->rx_sff[*can_id]; + } + } + + /* default: filter via can_id/can_mask */ + return &dev_rcv_lists->rx[RX_FIL]; +} + +/** + * can_rx_register - subscribe CAN frames from a specific interface + * @net: the applicable net namespace + * @dev: pointer to netdevice (NULL => subscribe from 'all' CAN devices list) + * @can_id: CAN identifier (see description) + * @mask: CAN mask (see description) + * @func: callback function on filter match + * @data: returned parameter for callback function + * @ident: string for calling module identification + * @sk: socket pointer (might be NULL) + * + * Description: + * Invokes the callback function with the received sk_buff and the given + * parameter 'data' on a matching 
receive filter. A filter matches, when + * + * & mask == can_id & mask + * + * The filter can be inverted (CAN_INV_FILTER bit set in can_id) or it can + * filter for error message frames (CAN_ERR_FLAG bit set in mask). + * + * The provided pointer to the sk_buff is guaranteed to be valid as long as + * the callback function is running. The callback function must *not* free + * the given sk_buff while processing it's task. When the given sk_buff is + * needed after the end of the callback function it must be cloned inside + * the callback function with skb_clone(). + * + * Return: + * 0 on success + * -ENOMEM on missing cache mem to create subscription entry + * -ENODEV unknown device + */ +int can_rx_register(struct net *net, struct net_device *dev, canid_t can_id, + canid_t mask, void (*func)(struct sk_buff *, void *), + void *data, char *ident, struct sock *sk) +{ + struct receiver *rcv; + struct hlist_head *rcv_list; + struct can_dev_rcv_lists *dev_rcv_lists; + struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; + int err = 0; + + /* insert new receiver (dev,canid,mask) -> (func,data) */ + + if (dev && (dev->type != ARPHRD_CAN || !can_get_ml_priv(dev))) + return -ENODEV; + + if (dev && !net_eq(net, dev_net(dev))) + return -ENODEV; + + rcv = kmem_cache_alloc(rcv_cache, GFP_KERNEL); + if (!rcv) + return -ENOMEM; + + spin_lock_bh(&net->can.rcvlists_lock); + + dev_rcv_lists = can_dev_rcv_lists_find(net, dev); + rcv_list = can_rcv_list_find(&can_id, &mask, dev_rcv_lists); + + rcv->can_id = can_id; + rcv->mask = mask; + rcv->matches = 0; + rcv->func = func; + rcv->data = data; + rcv->ident = ident; + rcv->sk = sk; + + hlist_add_head_rcu(&rcv->list, rcv_list); + dev_rcv_lists->entries++; + + rcv_lists_stats->rcv_entries++; + rcv_lists_stats->rcv_entries_max = max(rcv_lists_stats->rcv_entries_max, + rcv_lists_stats->rcv_entries); + spin_unlock_bh(&net->can.rcvlists_lock); + + return err; +} +EXPORT_SYMBOL(can_rx_register); + +/* can_rx_delete_receiver - rcu callback for single receiver entry removal */ +static void can_rx_delete_receiver(struct rcu_head *rp) +{ + struct receiver *rcv = container_of(rp, struct receiver, rcu); + struct sock *sk = rcv->sk; + + kmem_cache_free(rcv_cache, rcv); + if (sk) + sock_put(sk); +} + +/** + * can_rx_unregister - unsubscribe CAN frames from a specific interface + * @net: the applicable net namespace + * @dev: pointer to netdevice (NULL => unsubscribe from 'all' CAN devices list) + * @can_id: CAN identifier + * @mask: CAN mask + * @func: callback function on filter match + * @data: returned parameter for callback function + * + * Description: + * Removes subscription entry depending on given (subscription) values. + */ +void can_rx_unregister(struct net *net, struct net_device *dev, canid_t can_id, + canid_t mask, void (*func)(struct sk_buff *, void *), + void *data) +{ + struct receiver *rcv = NULL; + struct hlist_head *rcv_list; + struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; + struct can_dev_rcv_lists *dev_rcv_lists; + + if (dev && dev->type != ARPHRD_CAN) + return; + + if (dev && !net_eq(net, dev_net(dev))) + return; + + spin_lock_bh(&net->can.rcvlists_lock); + + dev_rcv_lists = can_dev_rcv_lists_find(net, dev); + rcv_list = can_rcv_list_find(&can_id, &mask, dev_rcv_lists); + + /* Search the receiver list for the item to delete. This should + * exist, since no receiver may be unregistered that hasn't + * been registered before. 
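+ * The search and removal run under net->can.rcvlists_lock, the same lock
+ * taken by can_rx_register().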
+ */ + hlist_for_each_entry_rcu(rcv, rcv_list, list) { + if (rcv->can_id == can_id && rcv->mask == mask && + rcv->func == func && rcv->data == data) + break; + } + + /* Check for bugs in CAN protocol implementations using af_can.c: + * 'rcv' will be NULL if no matching list item was found for removal. + * As this case may potentially happen when closing a socket while + * the notifier for removing the CAN netdev is running we just print + * a warning here. + */ + if (!rcv) { + pr_warn("can: receive list entry not found for dev %s, id %03X, mask %03X\n", + DNAME(dev), can_id, mask); + goto out; + } + + hlist_del_rcu(&rcv->list); + dev_rcv_lists->entries--; + + if (rcv_lists_stats->rcv_entries > 0) + rcv_lists_stats->rcv_entries--; + + out: + spin_unlock_bh(&net->can.rcvlists_lock); + + /* schedule the receiver item for deletion */ + if (rcv) { + if (rcv->sk) + sock_hold(rcv->sk); + call_rcu(&rcv->rcu, can_rx_delete_receiver); + } +} +EXPORT_SYMBOL(can_rx_unregister); + +static inline void deliver(struct sk_buff *skb, struct receiver *rcv) +{ + rcv->func(skb, rcv->data); + rcv->matches++; +} + +static int can_rcv_filter(struct can_dev_rcv_lists *dev_rcv_lists, struct sk_buff *skb) +{ + struct receiver *rcv; + int matches = 0; + struct can_frame *cf = (struct can_frame *)skb->data; + canid_t can_id = cf->can_id; + + if (dev_rcv_lists->entries == 0) + return 0; + + if (can_id & CAN_ERR_FLAG) { + /* check for error message frame entries only */ + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx[RX_ERR], list) { + if (can_id & rcv->mask) { + deliver(skb, rcv); + matches++; + } + } + return matches; + } + + /* check for unfiltered entries */ + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx[RX_ALL], list) { + deliver(skb, rcv); + matches++; + } + + /* check for can_id/mask entries */ + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx[RX_FIL], list) { + if ((can_id & rcv->mask) == rcv->can_id) { + deliver(skb, rcv); + matches++; + } + } + + /* check for inverted can_id/mask entries */ + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx[RX_INV], list) { + if ((can_id & rcv->mask) != rcv->can_id) { + deliver(skb, rcv); + matches++; + } + } + + /* check filterlists for single non-RTR can_ids */ + if (can_id & CAN_RTR_FLAG) + return matches; + + if (can_id & CAN_EFF_FLAG) { + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx_eff[effhash(can_id)], list) { + if (rcv->can_id == can_id) { + deliver(skb, rcv); + matches++; + } + } + } else { + can_id &= CAN_SFF_MASK; + hlist_for_each_entry_rcu(rcv, &dev_rcv_lists->rx_sff[can_id], list) { + deliver(skb, rcv); + matches++; + } + } + + return matches; +} + +static void can_receive(struct sk_buff *skb, struct net_device *dev) +{ + struct can_dev_rcv_lists *dev_rcv_lists; + struct net *net = dev_net(dev); + struct can_pkg_stats *pkg_stats = net->can.pkg_stats; + int matches; + + /* update statistics */ + pkg_stats->rx_frames++; + pkg_stats->rx_frames_delta++; + + /* create non-zero unique skb identifier together with *skb */ + while (!(can_skb_prv(skb)->skbcnt)) + can_skb_prv(skb)->skbcnt = atomic_inc_return(&skbcounter); + + rcu_read_lock(); + + /* deliver the packet to sockets listening on all devices */ + matches = can_rcv_filter(net->can.rx_alldev_list, skb); + + /* find receive list for this device */ + dev_rcv_lists = can_dev_rcv_lists_find(net, dev); + matches += can_rcv_filter(dev_rcv_lists, skb); + + rcu_read_unlock(); + + /* consume the skbuff allocated by the netdevice driver */ + consume_skb(skb); + + if (matches > 0) { + pkg_stats->matches++; + 
pkg_stats->matches_delta++; + } +} + +static int can_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_can_skb(skb))) { + pr_warn_once("PF_CAN: dropped non conform CAN skbuff: dev type %d, len %d\n", + dev->type, skb->len); + + kfree_skb(skb); + return NET_RX_DROP; + } + + can_receive(skb, dev); + return NET_RX_SUCCESS; +} + +static int canfd_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_canfd_skb(skb))) { + pr_warn_once("PF_CAN: dropped non conform CAN FD skbuff: dev type %d, len %d\n", + dev->type, skb->len); + + kfree_skb(skb); + return NET_RX_DROP; + } + + can_receive(skb, dev); + return NET_RX_SUCCESS; +} + +static int canxl_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + if (unlikely(dev->type != ARPHRD_CAN || !can_get_ml_priv(dev) || !can_is_canxl_skb(skb))) { + pr_warn_once("PF_CAN: dropped non conform CAN XL skbuff: dev type %d, len %d\n", + dev->type, skb->len); + + kfree_skb(skb); + return NET_RX_DROP; + } + + can_receive(skb, dev); + return NET_RX_SUCCESS; +} + +/* af_can protocol functions */ + +/** + * can_proto_register - register CAN transport protocol + * @cp: pointer to CAN protocol structure + * + * Return: + * 0 on success + * -EINVAL invalid (out of range) protocol number + * -EBUSY protocol already in use + * -ENOBUF if proto_register() fails + */ +int can_proto_register(const struct can_proto *cp) +{ + int proto = cp->protocol; + int err = 0; + + if (proto < 0 || proto >= CAN_NPROTO) { + pr_err("can: protocol number %d out of range\n", proto); + return -EINVAL; + } + + err = proto_register(cp->prot, 0); + if (err < 0) + return err; + + mutex_lock(&proto_tab_lock); + + if (rcu_access_pointer(proto_tab[proto])) { + pr_err("can: protocol %d already registered\n", proto); + err = -EBUSY; + } else { + RCU_INIT_POINTER(proto_tab[proto], cp); + } + + mutex_unlock(&proto_tab_lock); + + if (err < 0) + proto_unregister(cp->prot); + + return err; +} +EXPORT_SYMBOL(can_proto_register); + +/** + * can_proto_unregister - unregister CAN transport protocol + * @cp: pointer to CAN protocol structure + */ +void can_proto_unregister(const struct can_proto *cp) +{ + int proto = cp->protocol; + + mutex_lock(&proto_tab_lock); + BUG_ON(rcu_access_pointer(proto_tab[proto]) != cp); + RCU_INIT_POINTER(proto_tab[proto], NULL); + mutex_unlock(&proto_tab_lock); + + synchronize_rcu(); + + proto_unregister(cp->prot); +} +EXPORT_SYMBOL(can_proto_unregister); + +static int can_pernet_init(struct net *net) +{ + spin_lock_init(&net->can.rcvlists_lock); + net->can.rx_alldev_list = + kzalloc(sizeof(*net->can.rx_alldev_list), GFP_KERNEL); + if (!net->can.rx_alldev_list) + goto out; + net->can.pkg_stats = kzalloc(sizeof(*net->can.pkg_stats), GFP_KERNEL); + if (!net->can.pkg_stats) + goto out_free_rx_alldev_list; + net->can.rcv_lists_stats = kzalloc(sizeof(*net->can.rcv_lists_stats), GFP_KERNEL); + if (!net->can.rcv_lists_stats) + goto out_free_pkg_stats; + + if (IS_ENABLED(CONFIG_PROC_FS)) { + /* the statistics are updated every second (timer triggered) */ + if (stats_timer) { + timer_setup(&net->can.stattimer, can_stat_update, + 0); + mod_timer(&net->can.stattimer, + round_jiffies(jiffies + HZ)); + } + net->can.pkg_stats->jiffies_init = jiffies; + can_init_proc(net); + } + + return 
0; + + out_free_pkg_stats: + kfree(net->can.pkg_stats); + out_free_rx_alldev_list: + kfree(net->can.rx_alldev_list); + out: + return -ENOMEM; +} + +static void can_pernet_exit(struct net *net) +{ + if (IS_ENABLED(CONFIG_PROC_FS)) { + can_remove_proc(net); + if (stats_timer) + del_timer_sync(&net->can.stattimer); + } + + kfree(net->can.rx_alldev_list); + kfree(net->can.pkg_stats); + kfree(net->can.rcv_lists_stats); +} + +/* af_can module init/exit functions */ + +static struct packet_type can_packet __read_mostly = { + .type = cpu_to_be16(ETH_P_CAN), + .func = can_rcv, +}; + +static struct packet_type canfd_packet __read_mostly = { + .type = cpu_to_be16(ETH_P_CANFD), + .func = canfd_rcv, +}; + +static struct packet_type canxl_packet __read_mostly = { + .type = cpu_to_be16(ETH_P_CANXL), + .func = canxl_rcv, +}; + +static const struct net_proto_family can_family_ops = { + .family = PF_CAN, + .create = can_create, + .owner = THIS_MODULE, +}; + +static struct pernet_operations can_pernet_ops __read_mostly = { + .init = can_pernet_init, + .exit = can_pernet_exit, +}; + +static __init int can_init(void) +{ + int err; + + /* check for correct padding to be able to use the structs similarly */ + BUILD_BUG_ON(offsetof(struct can_frame, len) != + offsetof(struct canfd_frame, len) || + offsetof(struct can_frame, data) != + offsetof(struct canfd_frame, data)); + + pr_info("can: controller area network core\n"); + + rcv_cache = kmem_cache_create("can_receiver", sizeof(struct receiver), + 0, 0, NULL); + if (!rcv_cache) + return -ENOMEM; + + err = register_pernet_subsys(&can_pernet_ops); + if (err) + goto out_pernet; + + /* protocol register */ + err = sock_register(&can_family_ops); + if (err) + goto out_sock; + + dev_add_pack(&can_packet); + dev_add_pack(&canfd_packet); + dev_add_pack(&canxl_packet); + + return 0; + +out_sock: + unregister_pernet_subsys(&can_pernet_ops); +out_pernet: + kmem_cache_destroy(rcv_cache); + + return err; +} + +static __exit void can_exit(void) +{ + /* protocol unregister */ + dev_remove_pack(&canxl_packet); + dev_remove_pack(&canfd_packet); + dev_remove_pack(&can_packet); + sock_unregister(PF_CAN); + + unregister_pernet_subsys(&can_pernet_ops); + + rcu_barrier(); /* Wait for completion of call_rcu()'s */ + + kmem_cache_destroy(rcv_cache); +} + +module_init(can_init); +module_exit(can_exit); diff --git a/net/can/af_can.h b/net/can/af_can.h new file mode 100644 index 000000000..7c2d9161e --- /dev/null +++ b/net/can/af_can.h @@ -0,0 +1,103 @@ +/* SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) */ +/* Copyright (c) 2002-2007 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. 
+ * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#ifndef AF_CAN_H +#define AF_CAN_H + +#include +#include +#include +#include +#include + +/* af_can rx dispatcher structures */ + +struct receiver { + struct hlist_node list; + canid_t can_id; + canid_t mask; + unsigned long matches; + void (*func)(struct sk_buff *skb, void *data); + void *data; + char *ident; + struct sock *sk; + struct rcu_head rcu; +}; + +/* statistic structures */ + +/* can be reset e.g. by can_init_stats() */ +struct can_pkg_stats { + unsigned long jiffies_init; + + unsigned long rx_frames; + unsigned long tx_frames; + unsigned long matches; + + unsigned long total_rx_rate; + unsigned long total_tx_rate; + unsigned long total_rx_match_ratio; + + unsigned long current_rx_rate; + unsigned long current_tx_rate; + unsigned long current_rx_match_ratio; + + unsigned long max_rx_rate; + unsigned long max_tx_rate; + unsigned long max_rx_match_ratio; + + unsigned long rx_frames_delta; + unsigned long tx_frames_delta; + unsigned long matches_delta; +}; + +/* persistent statistics */ +struct can_rcv_lists_stats { + unsigned long stats_reset; + unsigned long user_reset; + unsigned long rcv_entries; + unsigned long rcv_entries_max; +}; + +/* function prototypes for the CAN networklayer procfs (proc.c) */ +void can_init_proc(struct net *net); +void can_remove_proc(struct net *net); +void can_stat_update(struct timer_list *t); + +#endif /* AF_CAN_H */ diff --git a/net/can/bcm.c b/net/can/bcm.c new file mode 100644 index 000000000..925d48cc5 --- /dev/null +++ b/net/can/bcm.c @@ -0,0 +1,1788 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* + * bcm.c - Broadcast Manager to filter/send (cyclic) CAN content + * + * Copyright (c) 2002-2017 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. 
Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * To send multiple CAN frame content within TX_SETUP or to filter + * CAN messages with multiplex index within RX_SETUP, the number of + * different filters is limited to 256 due to the one byte index value. + */ +#define MAX_NFRAMES 256 + +/* limit timers to 400 days for sending/timeouts */ +#define BCM_TIMER_SEC_MAX (400 * 24 * 60 * 60) + +/* use of last_frames[index].flags */ +#define RX_RECV 0x40 /* received data for this element */ +#define RX_THR 0x80 /* element not been sent due to throttle feature */ +#define BCM_CAN_FLAGS_MASK 0x3F /* to clean private flags after usage */ + +/* get best masking value for can_rx_register() for a given single can_id */ +#define REGMASK(id) ((id & CAN_EFF_FLAG) ? \ + (CAN_EFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG) : \ + (CAN_SFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG)) + +MODULE_DESCRIPTION("PF_CAN broadcast manager protocol"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_AUTHOR("Oliver Hartkopp "); +MODULE_ALIAS("can-proto-2"); + +#define BCM_MIN_NAMELEN CAN_REQUIRED_SIZE(struct sockaddr_can, can_ifindex) + +/* + * easy access to the first 64 bit of can(fd)_frame payload. cp->data is + * 64 bit aligned so the offset has to be multiples of 8 which is ensured + * by the only callers in bcm_rx_cmp_to_index() bcm_rx_handler(). 
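As a worked example of the REGMASK() macro above (constants from the linux/can.h UAPI): the mask always includes CAN_EFF_FLAG and CAN_RTR_FLAG, so a subscription for a plain 11-bit id matches neither an extended id that happens to share the low bits nor an RTR frame.

/*
 * REGMASK(0x123):
 *	CAN_SFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG
 *	= 0x000007FF | 0x80000000  | 0x40000000  = 0xC00007FF
 *
 * REGMASK(0x12345678 | CAN_EFF_FLAG):
 *	CAN_EFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG
 *	= 0x1FFFFFFF | 0x80000000  | 0x40000000  = 0xDFFFFFFF
 */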
+ */ +static inline u64 get_u64(const struct canfd_frame *cp, int offset) +{ + return *(u64 *)(cp->data + offset); +} + +struct bcm_op { + struct list_head list; + struct rcu_head rcu; + int ifindex; + canid_t can_id; + u32 flags; + unsigned long frames_abs, frames_filtered; + struct bcm_timeval ival1, ival2; + struct hrtimer timer, thrtimer; + ktime_t rx_stamp, kt_ival1, kt_ival2, kt_lastmsg; + int rx_ifindex; + int cfsiz; + u32 count; + u32 nframes; + u32 currframe; + /* void pointers to arrays of struct can[fd]_frame */ + void *frames; + void *last_frames; + struct canfd_frame sframe; + struct canfd_frame last_sframe; + struct sock *sk; + struct net_device *rx_reg_dev; +}; + +struct bcm_sock { + struct sock sk; + int bound; + int ifindex; + struct list_head notifier; + struct list_head rx_ops; + struct list_head tx_ops; + unsigned long dropped_usr_msgs; + struct proc_dir_entry *bcm_proc_read; + char procname [32]; /* inode number in decimal with \0 */ +}; + +static LIST_HEAD(bcm_notifier_list); +static DEFINE_SPINLOCK(bcm_notifier_lock); +static struct bcm_sock *bcm_busy_notifier; + +static inline struct bcm_sock *bcm_sk(const struct sock *sk) +{ + return (struct bcm_sock *)sk; +} + +static inline ktime_t bcm_timeval_to_ktime(struct bcm_timeval tv) +{ + return ktime_set(tv.tv_sec, tv.tv_usec * NSEC_PER_USEC); +} + +/* check limitations for timeval provided by user */ +static bool bcm_is_invalid_tv(struct bcm_msg_head *msg_head) +{ + if ((msg_head->ival1.tv_sec < 0) || + (msg_head->ival1.tv_sec > BCM_TIMER_SEC_MAX) || + (msg_head->ival1.tv_usec < 0) || + (msg_head->ival1.tv_usec >= USEC_PER_SEC) || + (msg_head->ival2.tv_sec < 0) || + (msg_head->ival2.tv_sec > BCM_TIMER_SEC_MAX) || + (msg_head->ival2.tv_usec < 0) || + (msg_head->ival2.tv_usec >= USEC_PER_SEC)) + return true; + + return false; +} + +#define CFSIZ(flags) ((flags & CAN_FD_FRAME) ? CANFD_MTU : CAN_MTU) +#define OPSIZ sizeof(struct bcm_op) +#define MHSIZ sizeof(struct bcm_msg_head) + +/* + * procfs functions + */ +#if IS_ENABLED(CONFIG_PROC_FS) +static char *bcm_proc_getifname(struct net *net, char *result, int ifindex) +{ + struct net_device *dev; + + if (!ifindex) + return "any"; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, ifindex); + if (dev) + strcpy(result, dev->name); + else + strcpy(result, "???"); + rcu_read_unlock(); + + return result; +} + +static int bcm_proc_show(struct seq_file *m, void *v) +{ + char ifname[IFNAMSIZ]; + struct net *net = m->private; + struct sock *sk = (struct sock *)pde_data(m->file->f_inode); + struct bcm_sock *bo = bcm_sk(sk); + struct bcm_op *op; + + seq_printf(m, ">>> socket %pK", sk->sk_socket); + seq_printf(m, " / sk %pK", sk); + seq_printf(m, " / bo %pK", bo); + seq_printf(m, " / dropped %lu", bo->dropped_usr_msgs); + seq_printf(m, " / bound %s", bcm_proc_getifname(net, ifname, bo->ifindex)); + seq_printf(m, " <<<\n"); + + list_for_each_entry(op, &bo->rx_ops, list) { + + unsigned long reduction; + + /* print only active entries & prevent division by zero */ + if (!op->frames_abs) + continue; + + seq_printf(m, "rx_op: %03X %-5s ", op->can_id, + bcm_proc_getifname(net, ifname, op->ifindex)); + + if (op->flags & CAN_FD_FRAME) + seq_printf(m, "(%u)", op->nframes); + else + seq_printf(m, "[%u]", op->nframes); + + seq_printf(m, "%c ", (op->flags & RX_CHECK_DLC) ? 
'd' : ' '); + + if (op->kt_ival1) + seq_printf(m, "timeo=%lld ", + (long long)ktime_to_us(op->kt_ival1)); + + if (op->kt_ival2) + seq_printf(m, "thr=%lld ", + (long long)ktime_to_us(op->kt_ival2)); + + seq_printf(m, "# recv %ld (%ld) => reduction: ", + op->frames_filtered, op->frames_abs); + + reduction = 100 - (op->frames_filtered * 100) / op->frames_abs; + + seq_printf(m, "%s%ld%%\n", + (reduction == 100) ? "near " : "", reduction); + } + + list_for_each_entry(op, &bo->tx_ops, list) { + + seq_printf(m, "tx_op: %03X %s ", op->can_id, + bcm_proc_getifname(net, ifname, op->ifindex)); + + if (op->flags & CAN_FD_FRAME) + seq_printf(m, "(%u) ", op->nframes); + else + seq_printf(m, "[%u] ", op->nframes); + + if (op->kt_ival1) + seq_printf(m, "t1=%lld ", + (long long)ktime_to_us(op->kt_ival1)); + + if (op->kt_ival2) + seq_printf(m, "t2=%lld ", + (long long)ktime_to_us(op->kt_ival2)); + + seq_printf(m, "# sent %ld\n", op->frames_abs); + } + seq_putc(m, '\n'); + return 0; +} +#endif /* CONFIG_PROC_FS */ + +/* + * bcm_can_tx - send the (next) CAN frame to the appropriate CAN interface + * of the given bcm tx op + */ +static void bcm_can_tx(struct bcm_op *op) +{ + struct sk_buff *skb; + struct net_device *dev; + struct canfd_frame *cf = op->frames + op->cfsiz * op->currframe; + int err; + + /* no target device? => exit */ + if (!op->ifindex) + return; + + dev = dev_get_by_index(sock_net(op->sk), op->ifindex); + if (!dev) { + /* RFC: should this bcm_op remove itself here? */ + return; + } + + skb = alloc_skb(op->cfsiz + sizeof(struct can_skb_priv), gfp_any()); + if (!skb) + goto out; + + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = dev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + + skb_put_data(skb, cf, op->cfsiz); + + /* send with loopback */ + skb->dev = dev; + can_skb_set_owner(skb, op->sk); + err = can_send(skb, 1); + if (!err) + op->frames_abs++; + + op->currframe++; + + /* reached last frame? */ + if (op->currframe >= op->nframes) + op->currframe = 0; +out: + dev_put(dev); +} + +/* + * bcm_send_to_user - send a BCM message to the userspace + * (consisting of bcm_msg_head + x CAN frames) + */ +static void bcm_send_to_user(struct bcm_op *op, struct bcm_msg_head *head, + struct canfd_frame *frames, int has_timestamp) +{ + struct sk_buff *skb; + struct canfd_frame *firstframe; + struct sockaddr_can *addr; + struct sock *sk = op->sk; + unsigned int datalen = head->nframes * op->cfsiz; + int err; + + skb = alloc_skb(sizeof(*head) + datalen, gfp_any()); + if (!skb) + return; + + skb_put_data(skb, head, sizeof(*head)); + + if (head->nframes) { + /* CAN frames starting here */ + firstframe = (struct canfd_frame *)skb_tail_pointer(skb); + + skb_put_data(skb, frames, datalen); + + /* + * the BCM uses the flags-element of the canfd_frame + * structure for internal purposes. This is only + * relevant for updates that are generated by the + * BCM, where nframes is 1 + */ + if (head->nframes == 1) + firstframe->flags &= BCM_CAN_FLAGS_MASK; + } + + if (has_timestamp) { + /* restore rx timestamp */ + skb->tstamp = op->rx_stamp; + } + + /* + * Put the datagram to the queue so that bcm_recvmsg() can + * get it from there. We need to pass the interface index to + * bcm_recvmsg(). We pass a whole struct sockaddr_can in skb->cb + * containing the interface index. 
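On the receiving end, userspace consumes exactly this layout: a struct bcm_msg_head immediately followed by head.nframes frames. A hedged sketch for the classic-CAN case, assuming 's' is a CAN_BCM socket that has been connect()ed to a CAN interface (struct and opcode definitions from linux/can/bcm.h):

#include <unistd.h>
#include <linux/can.h>
#include <linux/can/bcm.h>

struct bcm_msg {
	struct bcm_msg_head head;
	struct can_frame frame[257];	/* MUX mask + up to 256 frames */
};

/* returns the BCM opcode of the next event, or -1 on error */
static int read_bcm_event(int s, struct bcm_msg *msg)
{
	ssize_t nbytes = read(s, msg, sizeof(*msg));

	if (nbytes < (ssize_t)sizeof(struct bcm_msg_head))
		return -1;

	/*
	 * Typical opcodes delivered here: RX_CHANGED (updated content in
	 * msg->frame[0]), RX_TIMEOUT (cyclic sender went silent, nframes
	 * is 0) and TX_EXPIRED (counted transmission finished).
	 */
	return msg->head.opcode;
}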
+ */ + + sock_skb_cb_check_size(sizeof(struct sockaddr_can)); + addr = (struct sockaddr_can *)skb->cb; + memset(addr, 0, sizeof(*addr)); + addr->can_family = AF_CAN; + addr->can_ifindex = op->rx_ifindex; + + err = sock_queue_rcv_skb(sk, skb); + if (err < 0) { + struct bcm_sock *bo = bcm_sk(sk); + + kfree_skb(skb); + /* don't care about overflows in this statistic */ + bo->dropped_usr_msgs++; + } +} + +static bool bcm_tx_set_expiry(struct bcm_op *op, struct hrtimer *hrt) +{ + ktime_t ival; + + if (op->kt_ival1 && op->count) + ival = op->kt_ival1; + else if (op->kt_ival2) + ival = op->kt_ival2; + else + return false; + + hrtimer_set_expires(hrt, ktime_add(ktime_get(), ival)); + return true; +} + +static void bcm_tx_start_timer(struct bcm_op *op) +{ + if (bcm_tx_set_expiry(op, &op->timer)) + hrtimer_start_expires(&op->timer, HRTIMER_MODE_ABS_SOFT); +} + +/* bcm_tx_timeout_handler - performs cyclic CAN frame transmissions */ +static enum hrtimer_restart bcm_tx_timeout_handler(struct hrtimer *hrtimer) +{ + struct bcm_op *op = container_of(hrtimer, struct bcm_op, timer); + struct bcm_msg_head msg_head; + + if (op->kt_ival1 && (op->count > 0)) { + op->count--; + if (!op->count && (op->flags & TX_COUNTEVT)) { + + /* create notification to user */ + memset(&msg_head, 0, sizeof(msg_head)); + msg_head.opcode = TX_EXPIRED; + msg_head.flags = op->flags; + msg_head.count = op->count; + msg_head.ival1 = op->ival1; + msg_head.ival2 = op->ival2; + msg_head.can_id = op->can_id; + msg_head.nframes = 0; + + bcm_send_to_user(op, &msg_head, NULL, 0); + } + bcm_can_tx(op); + + } else if (op->kt_ival2) { + bcm_can_tx(op); + } + + return bcm_tx_set_expiry(op, &op->timer) ? + HRTIMER_RESTART : HRTIMER_NORESTART; +} + +/* + * bcm_rx_changed - create a RX_CHANGED notification due to changed content + */ +static void bcm_rx_changed(struct bcm_op *op, struct canfd_frame *data) +{ + struct bcm_msg_head head; + + /* update statistics */ + op->frames_filtered++; + + /* prevent statistics overflow */ + if (op->frames_filtered > ULONG_MAX/100) + op->frames_filtered = op->frames_abs = 0; + + /* this element is not throttled anymore */ + data->flags &= (BCM_CAN_FLAGS_MASK|RX_RECV); + + memset(&head, 0, sizeof(head)); + head.opcode = RX_CHANGED; + head.flags = op->flags; + head.count = op->count; + head.ival1 = op->ival1; + head.ival2 = op->ival2; + head.can_id = op->can_id; + head.nframes = 1; + + bcm_send_to_user(op, &head, data, 1); +} + +/* + * bcm_rx_update_and_send - process a detected relevant receive content change + * 1. update the last received data + * 2. send a notification to the user (if possible) + */ +static void bcm_rx_update_and_send(struct bcm_op *op, + struct canfd_frame *lastdata, + const struct canfd_frame *rxdata) +{ + memcpy(lastdata, rxdata, op->cfsiz); + + /* mark as used and throttled by default */ + lastdata->flags |= (RX_RECV|RX_THR); + + /* throttling mode inactive ? */ + if (!op->kt_ival2) { + /* send RX_CHANGED to the user immediately */ + bcm_rx_changed(op, lastdata); + return; + } + + /* with active throttling timer we are just done here */ + if (hrtimer_active(&op->thrtimer)) + return; + + /* first reception with enabled throttling mode */ + if (!op->kt_lastmsg) + goto rx_changed_settime; + + /* got a second frame inside a potential throttle period? 
*/ + if (ktime_us_delta(ktime_get(), op->kt_lastmsg) < + ktime_to_us(op->kt_ival2)) { + /* do not send the saved data - only start throttle timer */ + hrtimer_start(&op->thrtimer, + ktime_add(op->kt_lastmsg, op->kt_ival2), + HRTIMER_MODE_ABS_SOFT); + return; + } + + /* the gap was that big, that throttling was not needed here */ +rx_changed_settime: + bcm_rx_changed(op, lastdata); + op->kt_lastmsg = ktime_get(); +} + +/* + * bcm_rx_cmp_to_index - (bit)compares the currently received data to formerly + * received data stored in op->last_frames[] + */ +static void bcm_rx_cmp_to_index(struct bcm_op *op, unsigned int index, + const struct canfd_frame *rxdata) +{ + struct canfd_frame *cf = op->frames + op->cfsiz * index; + struct canfd_frame *lcf = op->last_frames + op->cfsiz * index; + int i; + + /* + * no one uses the MSBs of flags for comparison, + * so we use it here to detect the first time of reception + */ + + if (!(lcf->flags & RX_RECV)) { + /* received data for the first time => send update to user */ + bcm_rx_update_and_send(op, lcf, rxdata); + return; + } + + /* do a real check in CAN frame data section */ + for (i = 0; i < rxdata->len; i += 8) { + if ((get_u64(cf, i) & get_u64(rxdata, i)) != + (get_u64(cf, i) & get_u64(lcf, i))) { + bcm_rx_update_and_send(op, lcf, rxdata); + return; + } + } + + if (op->flags & RX_CHECK_DLC) { + /* do a real check in CAN frame length */ + if (rxdata->len != lcf->len) { + bcm_rx_update_and_send(op, lcf, rxdata); + return; + } + } +} + +/* + * bcm_rx_starttimer - enable timeout monitoring for CAN frame reception + */ +static void bcm_rx_starttimer(struct bcm_op *op) +{ + if (op->flags & RX_NO_AUTOTIMER) + return; + + if (op->kt_ival1) + hrtimer_start(&op->timer, op->kt_ival1, HRTIMER_MODE_REL_SOFT); +} + +/* bcm_rx_timeout_handler - when the (cyclic) CAN frame reception timed out */ +static enum hrtimer_restart bcm_rx_timeout_handler(struct hrtimer *hrtimer) +{ + struct bcm_op *op = container_of(hrtimer, struct bcm_op, timer); + struct bcm_msg_head msg_head; + + /* if user wants to be informed, when cyclic CAN-Messages come back */ + if ((op->flags & RX_ANNOUNCE_RESUME) && op->last_frames) { + /* clear received CAN frames to indicate 'nothing received' */ + memset(op->last_frames, 0, op->nframes * op->cfsiz); + } + + /* create notification to user */ + memset(&msg_head, 0, sizeof(msg_head)); + msg_head.opcode = RX_TIMEOUT; + msg_head.flags = op->flags; + msg_head.count = op->count; + msg_head.ival1 = op->ival1; + msg_head.ival2 = op->ival2; + msg_head.can_id = op->can_id; + msg_head.nframes = 0; + + bcm_send_to_user(op, &msg_head, NULL, 0); + + return HRTIMER_NORESTART; +} + +/* + * bcm_rx_do_flush - helper for bcm_rx_thr_flush + */ +static inline int bcm_rx_do_flush(struct bcm_op *op, unsigned int index) +{ + struct canfd_frame *lcf = op->last_frames + op->cfsiz * index; + + if ((op->last_frames) && (lcf->flags & RX_THR)) { + bcm_rx_changed(op, lcf); + return 1; + } + return 0; +} + +/* + * bcm_rx_thr_flush - Check for throttled data and send it to the userspace + */ +static int bcm_rx_thr_flush(struct bcm_op *op) +{ + int updated = 0; + + if (op->nframes > 1) { + unsigned int i; + + /* for MUX filter we start at index 1 */ + for (i = 1; i < op->nframes; i++) + updated += bcm_rx_do_flush(op, i); + + } else { + /* for RX_FILTER_ID and simple filter */ + updated += bcm_rx_do_flush(op, 0); + } + + return updated; +} + +/* + * bcm_rx_thr_handler - the time for blocked content updates is over now: + * Check for throttled data and send it to the userspace + 
*/ +static enum hrtimer_restart bcm_rx_thr_handler(struct hrtimer *hrtimer) +{ + struct bcm_op *op = container_of(hrtimer, struct bcm_op, thrtimer); + + if (bcm_rx_thr_flush(op)) { + hrtimer_forward_now(hrtimer, op->kt_ival2); + return HRTIMER_RESTART; + } else { + /* rearm throttle handling */ + op->kt_lastmsg = 0; + return HRTIMER_NORESTART; + } +} + +/* + * bcm_rx_handler - handle a CAN frame reception + */ +static void bcm_rx_handler(struct sk_buff *skb, void *data) +{ + struct bcm_op *op = (struct bcm_op *)data; + const struct canfd_frame *rxframe = (struct canfd_frame *)skb->data; + unsigned int i; + + if (op->can_id != rxframe->can_id) + return; + + /* make sure to handle the correct frame type (CAN / CAN FD) */ + if (op->flags & CAN_FD_FRAME) { + if (!can_is_canfd_skb(skb)) + return; + } else { + if (!can_is_can_skb(skb)) + return; + } + + /* disable timeout */ + hrtimer_cancel(&op->timer); + + /* save rx timestamp */ + op->rx_stamp = skb->tstamp; + /* save originator for recvfrom() */ + op->rx_ifindex = skb->dev->ifindex; + /* update statistics */ + op->frames_abs++; + + if (op->flags & RX_RTR_FRAME) { + /* send reply for RTR-request (placed in op->frames[0]) */ + bcm_can_tx(op); + return; + } + + if (op->flags & RX_FILTER_ID) { + /* the easiest case */ + bcm_rx_update_and_send(op, op->last_frames, rxframe); + goto rx_starttimer; + } + + if (op->nframes == 1) { + /* simple compare with index 0 */ + bcm_rx_cmp_to_index(op, 0, rxframe); + goto rx_starttimer; + } + + if (op->nframes > 1) { + /* + * multiplex compare + * + * find the first multiplex mask that fits. + * Remark: The MUX-mask is stored in index 0 - but only the + * first 64 bits of the frame data[] are relevant (CAN FD) + */ + + for (i = 1; i < op->nframes; i++) { + if ((get_u64(op->frames, 0) & get_u64(rxframe, 0)) == + (get_u64(op->frames, 0) & + get_u64(op->frames + op->cfsiz * i, 0))) { + bcm_rx_cmp_to_index(op, i, rxframe); + break; + } + } + } + +rx_starttimer: + bcm_rx_starttimer(op); +} + +/* + * helpers for bcm_op handling: find & delete bcm [rx|tx] op elements + */ +static struct bcm_op *bcm_find_op(struct list_head *ops, + struct bcm_msg_head *mh, int ifindex) +{ + struct bcm_op *op; + + list_for_each_entry(op, ops, list) { + if ((op->can_id == mh->can_id) && (op->ifindex == ifindex) && + (op->flags & CAN_FD_FRAME) == (mh->flags & CAN_FD_FRAME)) + return op; + } + + return NULL; +} + +static void bcm_free_op_rcu(struct rcu_head *rcu_head) +{ + struct bcm_op *op = container_of(rcu_head, struct bcm_op, rcu); + + if ((op->frames) && (op->frames != &op->sframe)) + kfree(op->frames); + + if ((op->last_frames) && (op->last_frames != &op->last_sframe)) + kfree(op->last_frames); + + kfree(op); +} + +static void bcm_remove_op(struct bcm_op *op) +{ + hrtimer_cancel(&op->timer); + hrtimer_cancel(&op->thrtimer); + + call_rcu(&op->rcu, bcm_free_op_rcu); +} + +static void bcm_rx_unreg(struct net_device *dev, struct bcm_op *op) +{ + if (op->rx_reg_dev == dev) { + can_rx_unregister(dev_net(dev), dev, op->can_id, + REGMASK(op->can_id), bcm_rx_handler, op); + + /* mark as removed subscription */ + op->rx_reg_dev = NULL; + } else + printk(KERN_ERR "can-bcm: bcm_rx_unreg: registered device " + "mismatch %p %p\n", op->rx_reg_dev, dev); +} + +/* + * bcm_delete_rx_op - find and remove a rx op (returns number of removed ops) + */ +static int bcm_delete_rx_op(struct list_head *ops, struct bcm_msg_head *mh, + int ifindex) +{ + struct bcm_op *op, *n; + + list_for_each_entry_safe(op, n, ops, list) { + if ((op->can_id == mh->can_id) && 
(op->ifindex == ifindex) && + (op->flags & CAN_FD_FRAME) == (mh->flags & CAN_FD_FRAME)) { + + /* disable automatic timer on frame reception */ + op->flags |= RX_NO_AUTOTIMER; + + /* + * Don't care if we're bound or not (due to netdev + * problems) can_rx_unregister() is always a save + * thing to do here. + */ + if (op->ifindex) { + /* + * Only remove subscriptions that had not + * been removed due to NETDEV_UNREGISTER + * in bcm_notifier() + */ + if (op->rx_reg_dev) { + struct net_device *dev; + + dev = dev_get_by_index(sock_net(op->sk), + op->ifindex); + if (dev) { + bcm_rx_unreg(dev, op); + dev_put(dev); + } + } + } else + can_rx_unregister(sock_net(op->sk), NULL, + op->can_id, + REGMASK(op->can_id), + bcm_rx_handler, op); + + list_del(&op->list); + bcm_remove_op(op); + return 1; /* done */ + } + } + + return 0; /* not found */ +} + +/* + * bcm_delete_tx_op - find and remove a tx op (returns number of removed ops) + */ +static int bcm_delete_tx_op(struct list_head *ops, struct bcm_msg_head *mh, + int ifindex) +{ + struct bcm_op *op, *n; + + list_for_each_entry_safe(op, n, ops, list) { + if ((op->can_id == mh->can_id) && (op->ifindex == ifindex) && + (op->flags & CAN_FD_FRAME) == (mh->flags & CAN_FD_FRAME)) { + list_del(&op->list); + bcm_remove_op(op); + return 1; /* done */ + } + } + + return 0; /* not found */ +} + +/* + * bcm_read_op - read out a bcm_op and send it to the user (for bcm_sendmsg) + */ +static int bcm_read_op(struct list_head *ops, struct bcm_msg_head *msg_head, + int ifindex) +{ + struct bcm_op *op = bcm_find_op(ops, msg_head, ifindex); + + if (!op) + return -EINVAL; + + /* put current values into msg_head */ + msg_head->flags = op->flags; + msg_head->count = op->count; + msg_head->ival1 = op->ival1; + msg_head->ival2 = op->ival2; + msg_head->nframes = op->nframes; + + bcm_send_to_user(op, msg_head, op->frames, 0); + + return MHSIZ; +} + +/* + * bcm_tx_setup - create or update a bcm tx op (for bcm_sendmsg) + */ +static int bcm_tx_setup(struct bcm_msg_head *msg_head, struct msghdr *msg, + int ifindex, struct sock *sk) +{ + struct bcm_sock *bo = bcm_sk(sk); + struct bcm_op *op; + struct canfd_frame *cf; + unsigned int i; + int err; + + /* we need a real device to send frames */ + if (!ifindex) + return -ENODEV; + + /* check nframes boundaries - we need at least one CAN frame */ + if (msg_head->nframes < 1 || msg_head->nframes > MAX_NFRAMES) + return -EINVAL; + + /* check timeval limitations */ + if ((msg_head->flags & SETTIMER) && bcm_is_invalid_tv(msg_head)) + return -EINVAL; + + /* check the given can_id */ + op = bcm_find_op(&bo->tx_ops, msg_head, ifindex); + if (op) { + /* update existing BCM operation */ + + /* + * Do we need more space for the CAN frames than currently + * allocated? -> This is a _really_ unusual use-case and + * therefore (complexity / locking) it is not supported. 
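From userspace, the TX_SETUP path handled here is driven by writing a struct bcm_msg_head followed by nframes frames to the BCM socket. A minimal sketch for one cyclic classic-CAN frame, assuming 's' is a connect()ed CAN_BCM socket; the id 0x123, the payload and the 100 ms interval are arbitrary example values:

#include <string.h>
#include <unistd.h>
#include <linux/can.h>
#include <linux/can/bcm.h>

static int start_cyclic_tx(int s)
{
	struct {
		struct bcm_msg_head head;
		struct can_frame frame[1];
	} msg;

	memset(&msg, 0, sizeof(msg));
	msg.head.opcode  = TX_SETUP;
	msg.head.flags   = SETTIMER | STARTTIMER | TX_CP_CAN_ID;
	msg.head.count   = 0;			/* skip the ival1 phase, use ival2 only */
	msg.head.ival2.tv_usec = 100000;	/* repeat every 100 ms */
	msg.head.can_id  = 0x123;
	msg.head.nframes = 1;

	msg.frame[0].len = 2;
	msg.frame[0].data[0] = 0xde;
	msg.frame[0].data[1] = 0xad;

	/* on success bcm_sendmsg() returns MHSIZ + nframes * cfsiz */
	return write(s, &msg, sizeof(msg)) == (ssize_t)sizeof(msg) ? 0 : -1;
}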
+ */ + if (msg_head->nframes > op->nframes) + return -E2BIG; + + /* update CAN frames content */ + for (i = 0; i < msg_head->nframes; i++) { + + cf = op->frames + op->cfsiz * i; + err = memcpy_from_msg((u8 *)cf, msg, op->cfsiz); + + if (op->flags & CAN_FD_FRAME) { + if (cf->len > 64) + err = -EINVAL; + } else { + if (cf->len > 8) + err = -EINVAL; + } + + if (err < 0) + return err; + + if (msg_head->flags & TX_CP_CAN_ID) { + /* copy can_id into frame */ + cf->can_id = msg_head->can_id; + } + } + op->flags = msg_head->flags; + + } else { + /* insert new BCM operation for the given can_id */ + + op = kzalloc(OPSIZ, GFP_KERNEL); + if (!op) + return -ENOMEM; + + op->can_id = msg_head->can_id; + op->cfsiz = CFSIZ(msg_head->flags); + op->flags = msg_head->flags; + + /* create array for CAN frames and copy the data */ + if (msg_head->nframes > 1) { + op->frames = kmalloc_array(msg_head->nframes, + op->cfsiz, + GFP_KERNEL); + if (!op->frames) { + kfree(op); + return -ENOMEM; + } + } else + op->frames = &op->sframe; + + for (i = 0; i < msg_head->nframes; i++) { + + cf = op->frames + op->cfsiz * i; + err = memcpy_from_msg((u8 *)cf, msg, op->cfsiz); + if (err < 0) + goto free_op; + + if (op->flags & CAN_FD_FRAME) { + if (cf->len > 64) + err = -EINVAL; + } else { + if (cf->len > 8) + err = -EINVAL; + } + + if (err < 0) + goto free_op; + + if (msg_head->flags & TX_CP_CAN_ID) { + /* copy can_id into frame */ + cf->can_id = msg_head->can_id; + } + } + + /* tx_ops never compare with previous received messages */ + op->last_frames = NULL; + + /* bcm_can_tx / bcm_tx_timeout_handler needs this */ + op->sk = sk; + op->ifindex = ifindex; + + /* initialize uninitialized (kzalloc) structure */ + hrtimer_init(&op->timer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + op->timer.function = bcm_tx_timeout_handler; + + /* currently unused in tx_ops */ + hrtimer_init(&op->thrtimer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + + /* add this bcm_op to the list of the tx_ops */ + list_add(&op->list, &bo->tx_ops); + + } /* if ((op = bcm_find_op(&bo->tx_ops, msg_head->can_id, ifindex))) */ + + if (op->nframes != msg_head->nframes) { + op->nframes = msg_head->nframes; + /* start multiple frame transmission with index 0 */ + op->currframe = 0; + } + + /* check flags */ + + if (op->flags & TX_RESET_MULTI_IDX) { + /* start multiple frame transmission with index 0 */ + op->currframe = 0; + } + + if (op->flags & SETTIMER) { + /* set timer values */ + op->count = msg_head->count; + op->ival1 = msg_head->ival1; + op->ival2 = msg_head->ival2; + op->kt_ival1 = bcm_timeval_to_ktime(msg_head->ival1); + op->kt_ival2 = bcm_timeval_to_ktime(msg_head->ival2); + + /* disable an active timer due to zero values? 
*/ + if (!op->kt_ival1 && !op->kt_ival2) + hrtimer_cancel(&op->timer); + } + + if (op->flags & STARTTIMER) { + hrtimer_cancel(&op->timer); + /* spec: send CAN frame when starting timer */ + op->flags |= TX_ANNOUNCE; + } + + if (op->flags & TX_ANNOUNCE) { + bcm_can_tx(op); + if (op->count) + op->count--; + } + + if (op->flags & STARTTIMER) + bcm_tx_start_timer(op); + + return msg_head->nframes * op->cfsiz + MHSIZ; + +free_op: + if (op->frames != &op->sframe) + kfree(op->frames); + kfree(op); + return err; +} + +/* + * bcm_rx_setup - create or update a bcm rx op (for bcm_sendmsg) + */ +static int bcm_rx_setup(struct bcm_msg_head *msg_head, struct msghdr *msg, + int ifindex, struct sock *sk) +{ + struct bcm_sock *bo = bcm_sk(sk); + struct bcm_op *op; + int do_rx_register; + int err = 0; + + if ((msg_head->flags & RX_FILTER_ID) || (!(msg_head->nframes))) { + /* be robust against wrong usage ... */ + msg_head->flags |= RX_FILTER_ID; + /* ignore trailing garbage */ + msg_head->nframes = 0; + } + + /* the first element contains the mux-mask => MAX_NFRAMES + 1 */ + if (msg_head->nframes > MAX_NFRAMES + 1) + return -EINVAL; + + if ((msg_head->flags & RX_RTR_FRAME) && + ((msg_head->nframes != 1) || + (!(msg_head->can_id & CAN_RTR_FLAG)))) + return -EINVAL; + + /* check timeval limitations */ + if ((msg_head->flags & SETTIMER) && bcm_is_invalid_tv(msg_head)) + return -EINVAL; + + /* check the given can_id */ + op = bcm_find_op(&bo->rx_ops, msg_head, ifindex); + if (op) { + /* update existing BCM operation */ + + /* + * Do we need more space for the CAN frames than currently + * allocated? -> This is a _really_ unusual use-case and + * therefore (complexity / locking) it is not supported. + */ + if (msg_head->nframes > op->nframes) + return -E2BIG; + + if (msg_head->nframes) { + /* update CAN frames content */ + err = memcpy_from_msg(op->frames, msg, + msg_head->nframes * op->cfsiz); + if (err < 0) + return err; + + /* clear last_frames to indicate 'nothing received' */ + memset(op->last_frames, 0, msg_head->nframes * op->cfsiz); + } + + op->nframes = msg_head->nframes; + op->flags = msg_head->flags; + + /* Only an update -> do not call can_rx_register() */ + do_rx_register = 0; + + } else { + /* insert new BCM operation for the given can_id */ + op = kzalloc(OPSIZ, GFP_KERNEL); + if (!op) + return -ENOMEM; + + op->can_id = msg_head->can_id; + op->nframes = msg_head->nframes; + op->cfsiz = CFSIZ(msg_head->flags); + op->flags = msg_head->flags; + + if (msg_head->nframes > 1) { + /* create array for CAN frames and copy the data */ + op->frames = kmalloc_array(msg_head->nframes, + op->cfsiz, + GFP_KERNEL); + if (!op->frames) { + kfree(op); + return -ENOMEM; + } + + /* create and init array for received CAN frames */ + op->last_frames = kcalloc(msg_head->nframes, + op->cfsiz, + GFP_KERNEL); + if (!op->last_frames) { + kfree(op->frames); + kfree(op); + return -ENOMEM; + } + + } else { + op->frames = &op->sframe; + op->last_frames = &op->last_sframe; + } + + if (msg_head->nframes) { + err = memcpy_from_msg(op->frames, msg, + msg_head->nframes * op->cfsiz); + if (err < 0) { + if (op->frames != &op->sframe) + kfree(op->frames); + if (op->last_frames != &op->last_sframe) + kfree(op->last_frames); + kfree(op); + return err; + } + } + + /* bcm_can_tx / bcm_tx_timeout_handler needs this */ + op->sk = sk; + op->ifindex = ifindex; + + /* ifindex for timeout events w/o previous frame reception */ + op->rx_ifindex = ifindex; + + /* initialize uninitialized (kzalloc) structure */ + hrtimer_init(&op->timer, 
CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + op->timer.function = bcm_rx_timeout_handler; + + hrtimer_init(&op->thrtimer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + op->thrtimer.function = bcm_rx_thr_handler; + + /* add this bcm_op to the list of the rx_ops */ + list_add(&op->list, &bo->rx_ops); + + /* call can_rx_register() */ + do_rx_register = 1; + + } /* if ((op = bcm_find_op(&bo->rx_ops, msg_head->can_id, ifindex))) */ + + /* check flags */ + + if (op->flags & RX_RTR_FRAME) { + struct canfd_frame *frame0 = op->frames; + + /* no timers in RTR-mode */ + hrtimer_cancel(&op->thrtimer); + hrtimer_cancel(&op->timer); + + /* + * funny feature in RX(!)_SETUP only for RTR-mode: + * copy can_id into frame BUT without RTR-flag to + * prevent a full-load-loopback-test ... ;-] + */ + if ((op->flags & TX_CP_CAN_ID) || + (frame0->can_id == op->can_id)) + frame0->can_id = op->can_id & ~CAN_RTR_FLAG; + + } else { + if (op->flags & SETTIMER) { + + /* set timer value */ + op->ival1 = msg_head->ival1; + op->ival2 = msg_head->ival2; + op->kt_ival1 = bcm_timeval_to_ktime(msg_head->ival1); + op->kt_ival2 = bcm_timeval_to_ktime(msg_head->ival2); + + /* disable an active timer due to zero value? */ + if (!op->kt_ival1) + hrtimer_cancel(&op->timer); + + /* + * In any case cancel the throttle timer, flush + * potentially blocked msgs and reset throttle handling + */ + op->kt_lastmsg = 0; + hrtimer_cancel(&op->thrtimer); + bcm_rx_thr_flush(op); + } + + if ((op->flags & STARTTIMER) && op->kt_ival1) + hrtimer_start(&op->timer, op->kt_ival1, + HRTIMER_MODE_REL_SOFT); + } + + /* now we can register for can_ids, if we added a new bcm_op */ + if (do_rx_register) { + if (ifindex) { + struct net_device *dev; + + dev = dev_get_by_index(sock_net(sk), ifindex); + if (dev) { + err = can_rx_register(sock_net(sk), dev, + op->can_id, + REGMASK(op->can_id), + bcm_rx_handler, op, + "bcm", sk); + + op->rx_reg_dev = dev; + dev_put(dev); + } + + } else + err = can_rx_register(sock_net(sk), NULL, op->can_id, + REGMASK(op->can_id), + bcm_rx_handler, op, "bcm", sk); + if (err) { + /* this bcm rx op is broken -> remove it */ + list_del(&op->list); + bcm_remove_op(op); + return err; + } + } + + return msg_head->nframes * op->cfsiz + MHSIZ; +} + +/* + * bcm_tx_send - send a single CAN frame to the CAN interface (for bcm_sendmsg) + */ +static int bcm_tx_send(struct msghdr *msg, int ifindex, struct sock *sk, + int cfsiz) +{ + struct sk_buff *skb; + struct net_device *dev; + int err; + + /* we need a real device to send frames */ + if (!ifindex) + return -ENODEV; + + skb = alloc_skb(cfsiz + sizeof(struct can_skb_priv), GFP_KERNEL); + if (!skb) + return -ENOMEM; + + can_skb_reserve(skb); + + err = memcpy_from_msg(skb_put(skb, cfsiz), msg, cfsiz); + if (err < 0) { + kfree_skb(skb); + return err; + } + + dev = dev_get_by_index(sock_net(sk), ifindex); + if (!dev) { + kfree_skb(skb); + return -ENODEV; + } + + can_skb_prv(skb)->ifindex = dev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + skb->dev = dev; + can_skb_set_owner(skb, sk); + err = can_send(skb, 1); /* send with loopback */ + dev_put(dev); + + if (err) + return err; + + return cfsiz + MHSIZ; +} + +/* + * bcm_sendmsg - process BCM commands (opcodes) from the userspace + */ +static int bcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) +{ + struct sock *sk = sock->sk; + struct bcm_sock *bo = bcm_sk(sk); + int ifindex = bo->ifindex; /* default ifindex for this bcm_op */ + struct bcm_msg_head msg_head; + int cfsiz; + int ret; /* read bytes or error codes as return value */ + + 
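The RX_SETUP command handled by bcm_rx_setup() above (and dispatched from the switch below) is issued the same way from userspace. A sketch that subscribes to content changes of a single classic-CAN id and asks for a timeout notification after one second of silence; 's', the id 0x123 and the all-ones payload mask are example assumptions:

#include <string.h>
#include <unistd.h>
#include <linux/can.h>
#include <linux/can/bcm.h>

static int setup_rx_filter(int s)
{
	struct {
		struct bcm_msg_head head;
		struct can_frame frame[1];
	} msg;

	memset(&msg, 0, sizeof(msg));
	msg.head.opcode  = RX_SETUP;
	msg.head.flags   = SETTIMER | RX_ANNOUNCE_RESUME;
	msg.head.ival1.tv_sec = 1;	/* RX_TIMEOUT after 1 s without frames */
	msg.head.can_id  = 0x123;
	msg.head.nframes = 1;

	/* frame[0] acts as a bit mask: report a change in any payload bit */
	msg.frame[0].len = 8;
	memset(msg.frame[0].data, 0xff, 8);

	return write(s, &msg, sizeof(msg)) == (ssize_t)sizeof(msg) ? 0 : -1;
}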
if (!bo->bound) + return -ENOTCONN; + + /* check for valid message length from userspace */ + if (size < MHSIZ) + return -EINVAL; + + /* read message head information */ + ret = memcpy_from_msg((u8 *)&msg_head, msg, MHSIZ); + if (ret < 0) + return ret; + + cfsiz = CFSIZ(msg_head.flags); + if ((size - MHSIZ) % cfsiz) + return -EINVAL; + + /* check for alternative ifindex for this bcm_op */ + + if (!ifindex && msg->msg_name) { + /* no bound device as default => check msg_name */ + DECLARE_SOCKADDR(struct sockaddr_can *, addr, msg->msg_name); + + if (msg->msg_namelen < BCM_MIN_NAMELEN) + return -EINVAL; + + if (addr->can_family != AF_CAN) + return -EINVAL; + + /* ifindex from sendto() */ + ifindex = addr->can_ifindex; + + if (ifindex) { + struct net_device *dev; + + dev = dev_get_by_index(sock_net(sk), ifindex); + if (!dev) + return -ENODEV; + + if (dev->type != ARPHRD_CAN) { + dev_put(dev); + return -ENODEV; + } + + dev_put(dev); + } + } + + lock_sock(sk); + + switch (msg_head.opcode) { + + case TX_SETUP: + ret = bcm_tx_setup(&msg_head, msg, ifindex, sk); + break; + + case RX_SETUP: + ret = bcm_rx_setup(&msg_head, msg, ifindex, sk); + break; + + case TX_DELETE: + if (bcm_delete_tx_op(&bo->tx_ops, &msg_head, ifindex)) + ret = MHSIZ; + else + ret = -EINVAL; + break; + + case RX_DELETE: + if (bcm_delete_rx_op(&bo->rx_ops, &msg_head, ifindex)) + ret = MHSIZ; + else + ret = -EINVAL; + break; + + case TX_READ: + /* reuse msg_head for the reply to TX_READ */ + msg_head.opcode = TX_STATUS; + ret = bcm_read_op(&bo->tx_ops, &msg_head, ifindex); + break; + + case RX_READ: + /* reuse msg_head for the reply to RX_READ */ + msg_head.opcode = RX_STATUS; + ret = bcm_read_op(&bo->rx_ops, &msg_head, ifindex); + break; + + case TX_SEND: + /* we need exactly one CAN frame behind the msg head */ + if ((msg_head.nframes != 1) || (size != cfsiz + MHSIZ)) + ret = -EINVAL; + else + ret = bcm_tx_send(msg, ifindex, sk, cfsiz); + break; + + default: + ret = -EINVAL; + break; + } + + release_sock(sk); + + return ret; +} + +/* + * notification handler for netdevice status changes + */ +static void bcm_notify(struct bcm_sock *bo, unsigned long msg, + struct net_device *dev) +{ + struct sock *sk = &bo->sk; + struct bcm_op *op; + int notify_enodev = 0; + + if (!net_eq(dev_net(dev), sock_net(sk))) + return; + + switch (msg) { + + case NETDEV_UNREGISTER: + lock_sock(sk); + + /* remove device specific receive entries */ + list_for_each_entry(op, &bo->rx_ops, list) + if (op->rx_reg_dev == dev) + bcm_rx_unreg(dev, op); + + /* remove device reference, if this is our bound device */ + if (bo->bound && bo->ifindex == dev->ifindex) { + bo->bound = 0; + bo->ifindex = 0; + notify_enodev = 1; + } + + release_sock(sk); + + if (notify_enodev) { + sk->sk_err = ENODEV; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + break; + + case NETDEV_DOWN: + if (bo->bound && bo->ifindex == dev->ifindex) { + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + } +} + +static int bcm_notifier(struct notifier_block *nb, unsigned long msg, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (dev->type != ARPHRD_CAN) + return NOTIFY_DONE; + if (msg != NETDEV_UNREGISTER && msg != NETDEV_DOWN) + return NOTIFY_DONE; + if (unlikely(bcm_busy_notifier)) /* Check for reentrant bug. 
*/ + return NOTIFY_DONE; + + spin_lock(&bcm_notifier_lock); + list_for_each_entry(bcm_busy_notifier, &bcm_notifier_list, notifier) { + spin_unlock(&bcm_notifier_lock); + bcm_notify(bcm_busy_notifier, msg, dev); + spin_lock(&bcm_notifier_lock); + } + bcm_busy_notifier = NULL; + spin_unlock(&bcm_notifier_lock); + return NOTIFY_DONE; +} + +/* + * initial settings for all BCM sockets to be set at socket creation time + */ +static int bcm_init(struct sock *sk) +{ + struct bcm_sock *bo = bcm_sk(sk); + + bo->bound = 0; + bo->ifindex = 0; + bo->dropped_usr_msgs = 0; + bo->bcm_proc_read = NULL; + + INIT_LIST_HEAD(&bo->tx_ops); + INIT_LIST_HEAD(&bo->rx_ops); + + /* set notifier */ + spin_lock(&bcm_notifier_lock); + list_add_tail(&bo->notifier, &bcm_notifier_list); + spin_unlock(&bcm_notifier_lock); + + return 0; +} + +/* + * standard socket functions + */ +static int bcm_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct net *net; + struct bcm_sock *bo; + struct bcm_op *op, *next; + + if (!sk) + return 0; + + net = sock_net(sk); + bo = bcm_sk(sk); + + /* remove bcm_ops, timer, rx_unregister(), etc. */ + + spin_lock(&bcm_notifier_lock); + while (bcm_busy_notifier == bo) { + spin_unlock(&bcm_notifier_lock); + schedule_timeout_uninterruptible(1); + spin_lock(&bcm_notifier_lock); + } + list_del(&bo->notifier); + spin_unlock(&bcm_notifier_lock); + + lock_sock(sk); + +#if IS_ENABLED(CONFIG_PROC_FS) + /* remove procfs entry */ + if (net->can.bcmproc_dir && bo->bcm_proc_read) + remove_proc_entry(bo->procname, net->can.bcmproc_dir); +#endif /* CONFIG_PROC_FS */ + + list_for_each_entry_safe(op, next, &bo->tx_ops, list) + bcm_remove_op(op); + + list_for_each_entry_safe(op, next, &bo->rx_ops, list) { + /* + * Don't care if we're bound or not (due to netdev problems) + * can_rx_unregister() is always a save thing to do here. 
+ */ + if (op->ifindex) { + /* + * Only remove subscriptions that had not + * been removed due to NETDEV_UNREGISTER + * in bcm_notifier() + */ + if (op->rx_reg_dev) { + struct net_device *dev; + + dev = dev_get_by_index(net, op->ifindex); + if (dev) { + bcm_rx_unreg(dev, op); + dev_put(dev); + } + } + } else + can_rx_unregister(net, NULL, op->can_id, + REGMASK(op->can_id), + bcm_rx_handler, op); + + } + + synchronize_rcu(); + + list_for_each_entry_safe(op, next, &bo->rx_ops, list) + bcm_remove_op(op); + + /* remove device reference */ + if (bo->bound) { + bo->bound = 0; + bo->ifindex = 0; + } + + sock_orphan(sk); + sock->sk = NULL; + + release_sock(sk); + sock_put(sk); + + return 0; +} + +static int bcm_connect(struct socket *sock, struct sockaddr *uaddr, int len, + int flags) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct bcm_sock *bo = bcm_sk(sk); + struct net *net = sock_net(sk); + int ret = 0; + + if (len < BCM_MIN_NAMELEN) + return -EINVAL; + + lock_sock(sk); + + if (bo->bound) { + ret = -EISCONN; + goto fail; + } + + /* bind a device to this socket */ + if (addr->can_ifindex) { + struct net_device *dev; + + dev = dev_get_by_index(net, addr->can_ifindex); + if (!dev) { + ret = -ENODEV; + goto fail; + } + if (dev->type != ARPHRD_CAN) { + dev_put(dev); + ret = -ENODEV; + goto fail; + } + + bo->ifindex = dev->ifindex; + dev_put(dev); + + } else { + /* no interface reference for ifindex = 0 ('any' CAN device) */ + bo->ifindex = 0; + } + +#if IS_ENABLED(CONFIG_PROC_FS) + if (net->can.bcmproc_dir) { + /* unique socket address as filename */ + sprintf(bo->procname, "%lu", sock_i_ino(sk)); + bo->bcm_proc_read = proc_create_net_single(bo->procname, 0644, + net->can.bcmproc_dir, + bcm_proc_show, sk); + if (!bo->bcm_proc_read) { + ret = -ENOMEM; + goto fail; + } + } +#endif /* CONFIG_PROC_FS */ + + bo->bound = 1; + +fail: + release_sock(sk); + + return ret; +} + +static int bcm_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int error = 0; + int err; + + skb = skb_recv_datagram(sk, flags, &error); + if (!skb) + return error; + + if (skb->len < size) + size = skb->len; + + err = memcpy_to_msg(msg, skb->data, size); + if (err < 0) { + skb_free_datagram(sk, skb); + return err; + } + + sock_recv_cmsgs(msg, sk, skb); + + if (msg->msg_name) { + __sockaddr_check_size(BCM_MIN_NAMELEN); + msg->msg_namelen = BCM_MIN_NAMELEN; + memcpy(msg->msg_name, skb->cb, msg->msg_namelen); + } + + skb_free_datagram(sk, skb); + + return size; +} + +static int bcm_sock_no_ioctlcmd(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* no ioctls for socket layer -> hand it down to NIC layer */ + return -ENOIOCTLCMD; +} + +static const struct proto_ops bcm_ops = { + .family = PF_CAN, + .release = bcm_release, + .bind = sock_no_bind, + .connect = bcm_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = bcm_sock_no_ioctlcmd, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = bcm_sendmsg, + .recvmsg = bcm_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct proto bcm_proto __read_mostly = { + .name = "CAN_BCM", + .owner = THIS_MODULE, + .obj_size = sizeof(struct bcm_sock), + .init = bcm_init, +}; + +static const struct can_proto bcm_can_proto = { + .type = SOCK_DGRAM, + .protocol = CAN_BCM, + .ops = 
&bcm_ops, + .prot = &bcm_proto, +}; + +static int canbcm_pernet_init(struct net *net) +{ +#if IS_ENABLED(CONFIG_PROC_FS) + /* create /proc/net/can-bcm directory */ + net->can.bcmproc_dir = proc_net_mkdir(net, "can-bcm", net->proc_net); +#endif /* CONFIG_PROC_FS */ + + return 0; +} + +static void canbcm_pernet_exit(struct net *net) +{ +#if IS_ENABLED(CONFIG_PROC_FS) + /* remove /proc/net/can-bcm directory */ + if (net->can.bcmproc_dir) + remove_proc_entry("can-bcm", net->proc_net); +#endif /* CONFIG_PROC_FS */ +} + +static struct pernet_operations canbcm_pernet_ops __read_mostly = { + .init = canbcm_pernet_init, + .exit = canbcm_pernet_exit, +}; + +static struct notifier_block canbcm_notifier = { + .notifier_call = bcm_notifier +}; + +static int __init bcm_module_init(void) +{ + int err; + + pr_info("can: broadcast manager protocol\n"); + + err = register_pernet_subsys(&canbcm_pernet_ops); + if (err) + return err; + + err = register_netdevice_notifier(&canbcm_notifier); + if (err) + goto register_notifier_failed; + + err = can_proto_register(&bcm_can_proto); + if (err < 0) { + printk(KERN_ERR "can: registration of bcm protocol failed\n"); + goto register_proto_failed; + } + + return 0; + +register_proto_failed: + unregister_netdevice_notifier(&canbcm_notifier); +register_notifier_failed: + unregister_pernet_subsys(&canbcm_pernet_ops); + return err; +} + +static void __exit bcm_module_exit(void) +{ + can_proto_unregister(&bcm_can_proto); + unregister_netdevice_notifier(&canbcm_notifier); + unregister_pernet_subsys(&canbcm_pernet_ops); +} + +module_init(bcm_module_init); +module_exit(bcm_module_exit); diff --git a/net/can/gw.c b/net/can/gw.c new file mode 100644 index 000000000..23a3d89ca --- /dev/null +++ b/net/can/gw.c @@ -0,0 +1,1329 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* gw.c - CAN frame Gateway/Router/Bridge with netlink interface + * + * Copyright (c) 2019 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CAN_GW_NAME "can-gw" + +MODULE_DESCRIPTION("PF_CAN netlink gateway"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_AUTHOR("Oliver Hartkopp "); +MODULE_ALIAS(CAN_GW_NAME); + +#define CGW_MIN_HOPS 1 +#define CGW_MAX_HOPS 6 +#define CGW_DEFAULT_HOPS 1 + +static unsigned int max_hops __read_mostly = CGW_DEFAULT_HOPS; +module_param(max_hops, uint, 0444); +MODULE_PARM_DESC(max_hops, + "maximum " CAN_GW_NAME " routing hops for CAN frames " + "(valid values: " __stringify(CGW_MIN_HOPS) "-" + __stringify(CGW_MAX_HOPS) " hops, " + "default: " __stringify(CGW_DEFAULT_HOPS) ")"); + +static struct notifier_block notifier; +static struct kmem_cache *cgw_cache __read_mostly; + +/* structure that contains the (on-the-fly) CAN frame modifications */ +struct cf_mod { + struct { + struct canfd_frame and; + struct canfd_frame or; + struct canfd_frame xor; + struct canfd_frame set; + } modframe; + struct { + u8 and; + u8 or; + u8 xor; + u8 set; + } modtype; + void (*modfunc[MAX_MODFUNCTIONS])(struct canfd_frame *cf, + struct cf_mod *mod); + + /* CAN frame checksum calculation after CAN frame modifications */ + struct { + struct cgw_csum_xor xor; + struct cgw_csum_crc8 crc8; + } csum; + struct { + void (*xor)(struct canfd_frame *cf, + struct cgw_csum_xor *xor); + void (*crc8)(struct canfd_frame *cf, + struct cgw_csum_crc8 *crc8); + } csumfunc; + u32 uid; +}; + +/* So far we just support CAN -> CAN routing and frame modifications. + * + * The internal can_can_gw structure contains data and attributes for + * a CAN -> CAN gateway job. 
+ */ +struct can_can_gw { + struct can_filter filter; + int src_idx; + int dst_idx; +}; + +/* list entry for CAN gateways jobs */ +struct cgw_job { + struct hlist_node list; + struct rcu_head rcu; + u32 handled_frames; + u32 dropped_frames; + u32 deleted_frames; + struct cf_mod mod; + union { + /* CAN frame data source */ + struct net_device *dev; + } src; + union { + /* CAN frame data destination */ + struct net_device *dev; + } dst; + union { + struct can_can_gw ccgw; + /* tbc */ + }; + u8 gwtype; + u8 limit_hops; + u16 flags; +}; + +/* modification functions that are invoked in the hot path in can_can_gw_rcv */ + +#define MODFUNC(func, op) static void func(struct canfd_frame *cf, \ + struct cf_mod *mod) { op ; } + +MODFUNC(mod_and_id, cf->can_id &= mod->modframe.and.can_id) +MODFUNC(mod_and_len, cf->len &= mod->modframe.and.len) +MODFUNC(mod_and_flags, cf->flags &= mod->modframe.and.flags) +MODFUNC(mod_and_data, *(u64 *)cf->data &= *(u64 *)mod->modframe.and.data) +MODFUNC(mod_or_id, cf->can_id |= mod->modframe.or.can_id) +MODFUNC(mod_or_len, cf->len |= mod->modframe.or.len) +MODFUNC(mod_or_flags, cf->flags |= mod->modframe.or.flags) +MODFUNC(mod_or_data, *(u64 *)cf->data |= *(u64 *)mod->modframe.or.data) +MODFUNC(mod_xor_id, cf->can_id ^= mod->modframe.xor.can_id) +MODFUNC(mod_xor_len, cf->len ^= mod->modframe.xor.len) +MODFUNC(mod_xor_flags, cf->flags ^= mod->modframe.xor.flags) +MODFUNC(mod_xor_data, *(u64 *)cf->data ^= *(u64 *)mod->modframe.xor.data) +MODFUNC(mod_set_id, cf->can_id = mod->modframe.set.can_id) +MODFUNC(mod_set_len, cf->len = mod->modframe.set.len) +MODFUNC(mod_set_flags, cf->flags = mod->modframe.set.flags) +MODFUNC(mod_set_data, *(u64 *)cf->data = *(u64 *)mod->modframe.set.data) + +static void mod_and_fddata(struct canfd_frame *cf, struct cf_mod *mod) +{ + int i; + + for (i = 0; i < CANFD_MAX_DLEN; i += 8) + *(u64 *)(cf->data + i) &= *(u64 *)(mod->modframe.and.data + i); +} + +static void mod_or_fddata(struct canfd_frame *cf, struct cf_mod *mod) +{ + int i; + + for (i = 0; i < CANFD_MAX_DLEN; i += 8) + *(u64 *)(cf->data + i) |= *(u64 *)(mod->modframe.or.data + i); +} + +static void mod_xor_fddata(struct canfd_frame *cf, struct cf_mod *mod) +{ + int i; + + for (i = 0; i < CANFD_MAX_DLEN; i += 8) + *(u64 *)(cf->data + i) ^= *(u64 *)(mod->modframe.xor.data + i); +} + +static void mod_set_fddata(struct canfd_frame *cf, struct cf_mod *mod) +{ + memcpy(cf->data, mod->modframe.set.data, CANFD_MAX_DLEN); +} + +/* retrieve valid CC DLC value and store it into 'len' */ +static void mod_retrieve_ccdlc(struct canfd_frame *cf) +{ + struct can_frame *ccf = (struct can_frame *)cf; + + /* len8_dlc is only valid if len == CAN_MAX_DLEN */ + if (ccf->len != CAN_MAX_DLEN) + return; + + /* do we have a valid len8_dlc value from 9 .. 15 ? */ + if (ccf->len8_dlc > CAN_MAX_DLEN && ccf->len8_dlc <= CAN_MAX_RAW_DLC) + ccf->len = ccf->len8_dlc; +} + +/* convert valid CC DLC value in 'len' into struct can_frame elements */ +static void mod_store_ccdlc(struct canfd_frame *cf) +{ + struct can_frame *ccf = (struct can_frame *)cf; + + /* clear potential leftovers */ + ccf->len8_dlc = 0; + + /* plain data length 0 .. 8 - that was easy */ + if (ccf->len <= CAN_MAX_DLEN) + return; + + /* potentially broken values are caught in can_can_gw_rcv() */ + if (ccf->len > CAN_MAX_RAW_DLC) + return; + + /* we have a valid dlc value from 9 .. 
15 in ccf->len */ + ccf->len8_dlc = ccf->len; + ccf->len = CAN_MAX_DLEN; +} + +static void mod_and_ccdlc(struct canfd_frame *cf, struct cf_mod *mod) +{ + mod_retrieve_ccdlc(cf); + mod_and_len(cf, mod); + mod_store_ccdlc(cf); +} + +static void mod_or_ccdlc(struct canfd_frame *cf, struct cf_mod *mod) +{ + mod_retrieve_ccdlc(cf); + mod_or_len(cf, mod); + mod_store_ccdlc(cf); +} + +static void mod_xor_ccdlc(struct canfd_frame *cf, struct cf_mod *mod) +{ + mod_retrieve_ccdlc(cf); + mod_xor_len(cf, mod); + mod_store_ccdlc(cf); +} + +static void mod_set_ccdlc(struct canfd_frame *cf, struct cf_mod *mod) +{ + mod_set_len(cf, mod); + mod_store_ccdlc(cf); +} + +static void canframecpy(struct canfd_frame *dst, struct can_frame *src) +{ + /* Copy the struct members separately to ensure that no uninitialized + * data are copied in the 3 bytes hole of the struct. This is needed + * to make easy compares of the data in the struct cf_mod. + */ + + dst->can_id = src->can_id; + dst->len = src->len; + *(u64 *)dst->data = *(u64 *)src->data; +} + +static void canfdframecpy(struct canfd_frame *dst, struct canfd_frame *src) +{ + /* Copy the struct members separately to ensure that no uninitialized + * data are copied in the 2 bytes hole of the struct. This is needed + * to make easy compares of the data in the struct cf_mod. + */ + + dst->can_id = src->can_id; + dst->flags = src->flags; + dst->len = src->len; + memcpy(dst->data, src->data, CANFD_MAX_DLEN); +} + +static int cgw_chk_csum_parms(s8 fr, s8 to, s8 re, struct rtcanmsg *r) +{ + s8 dlen = CAN_MAX_DLEN; + + if (r->flags & CGW_FLAGS_CAN_FD) + dlen = CANFD_MAX_DLEN; + + /* absolute dlc values 0 .. 7 => 0 .. 7, e.g. data [0] + * relative to received dlc -1 .. -8 : + * e.g. for received dlc = 8 + * -1 => index = 7 (data[7]) + * -3 => index = 5 (data[5]) + * -8 => index = 0 (data[0]) + */ + + if (fr >= -dlen && fr < dlen && + to >= -dlen && to < dlen && + re >= -dlen && re < dlen) + return 0; + else + return -EINVAL; +} + +static inline int calc_idx(int idx, int rx_len) +{ + if (idx < 0) + return rx_len + idx; + else + return idx; +} + +static void cgw_csum_xor_rel(struct canfd_frame *cf, struct cgw_csum_xor *xor) +{ + int from = calc_idx(xor->from_idx, cf->len); + int to = calc_idx(xor->to_idx, cf->len); + int res = calc_idx(xor->result_idx, cf->len); + u8 val = xor->init_xor_val; + int i; + + if (from < 0 || to < 0 || res < 0) + return; + + if (from <= to) { + for (i = from; i <= to; i++) + val ^= cf->data[i]; + } else { + for (i = from; i >= to; i--) + val ^= cf->data[i]; + } + + cf->data[res] = val; +} + +static void cgw_csum_xor_pos(struct canfd_frame *cf, struct cgw_csum_xor *xor) +{ + u8 val = xor->init_xor_val; + int i; + + for (i = xor->from_idx; i <= xor->to_idx; i++) + val ^= cf->data[i]; + + cf->data[xor->result_idx] = val; +} + +static void cgw_csum_xor_neg(struct canfd_frame *cf, struct cgw_csum_xor *xor) +{ + u8 val = xor->init_xor_val; + int i; + + for (i = xor->from_idx; i >= xor->to_idx; i--) + val ^= cf->data[i]; + + cf->data[xor->result_idx] = val; +} + +static void cgw_csum_crc8_rel(struct canfd_frame *cf, + struct cgw_csum_crc8 *crc8) +{ + int from = calc_idx(crc8->from_idx, cf->len); + int to = calc_idx(crc8->to_idx, cf->len); + int res = calc_idx(crc8->result_idx, cf->len); + u8 crc = crc8->init_crc_val; + int i; + + if (from < 0 || to < 0 || res < 0) + return; + + if (from <= to) { + for (i = crc8->from_idx; i <= crc8->to_idx; i++) + crc = crc8->crctab[crc ^ cf->data[i]]; + } else { + for (i = crc8->from_idx; i >= crc8->to_idx; i--) + 
crc = crc8->crctab[crc ^ cf->data[i]]; + } + + switch (crc8->profile) { + case CGW_CRC8PRF_1U8: + crc = crc8->crctab[crc ^ crc8->profile_data[0]]; + break; + + case CGW_CRC8PRF_16U8: + crc = crc8->crctab[crc ^ crc8->profile_data[cf->data[1] & 0xF]]; + break; + + case CGW_CRC8PRF_SFFID_XOR: + crc = crc8->crctab[crc ^ (cf->can_id & 0xFF) ^ + (cf->can_id >> 8 & 0xFF)]; + break; + } + + cf->data[crc8->result_idx] = crc ^ crc8->final_xor_val; +} + +static void cgw_csum_crc8_pos(struct canfd_frame *cf, + struct cgw_csum_crc8 *crc8) +{ + u8 crc = crc8->init_crc_val; + int i; + + for (i = crc8->from_idx; i <= crc8->to_idx; i++) + crc = crc8->crctab[crc ^ cf->data[i]]; + + switch (crc8->profile) { + case CGW_CRC8PRF_1U8: + crc = crc8->crctab[crc ^ crc8->profile_data[0]]; + break; + + case CGW_CRC8PRF_16U8: + crc = crc8->crctab[crc ^ crc8->profile_data[cf->data[1] & 0xF]]; + break; + + case CGW_CRC8PRF_SFFID_XOR: + crc = crc8->crctab[crc ^ (cf->can_id & 0xFF) ^ + (cf->can_id >> 8 & 0xFF)]; + break; + } + + cf->data[crc8->result_idx] = crc ^ crc8->final_xor_val; +} + +static void cgw_csum_crc8_neg(struct canfd_frame *cf, + struct cgw_csum_crc8 *crc8) +{ + u8 crc = crc8->init_crc_val; + int i; + + for (i = crc8->from_idx; i >= crc8->to_idx; i--) + crc = crc8->crctab[crc ^ cf->data[i]]; + + switch (crc8->profile) { + case CGW_CRC8PRF_1U8: + crc = crc8->crctab[crc ^ crc8->profile_data[0]]; + break; + + case CGW_CRC8PRF_16U8: + crc = crc8->crctab[crc ^ crc8->profile_data[cf->data[1] & 0xF]]; + break; + + case CGW_CRC8PRF_SFFID_XOR: + crc = crc8->crctab[crc ^ (cf->can_id & 0xFF) ^ + (cf->can_id >> 8 & 0xFF)]; + break; + } + + cf->data[crc8->result_idx] = crc ^ crc8->final_xor_val; +} + +/* the receive & process & send function */ +static void can_can_gw_rcv(struct sk_buff *skb, void *data) +{ + struct cgw_job *gwj = (struct cgw_job *)data; + struct canfd_frame *cf; + struct sk_buff *nskb; + int modidx = 0; + + /* process strictly Classic CAN or CAN FD frames */ + if (gwj->flags & CGW_FLAGS_CAN_FD) { + if (!can_is_canfd_skb(skb)) + return; + } else { + if (!can_is_can_skb(skb)) + return; + } + + /* Do not handle CAN frames routed more than 'max_hops' times. + * In general we should never catch this delimiter which is intended + * to cover a misconfiguration protection (e.g. circular CAN routes). + * + * The Controller Area Network controllers only accept CAN frames with + * correct CRCs - which are not visible in the controller registers. + * According to skbuff.h documentation the csum_start element for IP + * checksums is undefined/unused when ip_summed == CHECKSUM_UNNECESSARY. + * Only CAN skbs can be processed here which already have this property. + */ + +#define cgw_hops(skb) ((skb)->csum_start) + + BUG_ON(skb->ip_summed != CHECKSUM_UNNECESSARY); + + if (cgw_hops(skb) >= max_hops) { + /* indicate deleted frames due to misconfiguration */ + gwj->deleted_frames++; + return; + } + + if (!(gwj->dst.dev->flags & IFF_UP)) { + gwj->dropped_frames++; + return; + } + + /* is sending the skb back to the incoming interface not allowed? */ + if (!(gwj->flags & CGW_FLAGS_CAN_IIF_TX_OK) && + can_skb_prv(skb)->ifindex == gwj->dst.dev->ifindex) + return; + + /* clone the given skb, which has not been done in can_rcv() + * + * When there is at least one modification function activated, + * we need to copy the skb as we want to modify skb->data. 
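A worked example of the cgw_hops() accounting (the max_hops check above and the per-job limit_hops adjustment just below), under the assumption that a freshly received CAN skb starts with cgw_hops() == 0: with the default max_hops of 1 a frame is forwarded once and then dropped by any further can-gw job, while a job-private limit_hops rewrites the counter on first contact so that only the remaining budget is left.

/*
 * Illustrative hop budget with max_hops=6 and limit_hops=2 on job A:
 *
 *   incoming frame           cgw_hops = 0
 *   job A forwards it        cgw_hops = 1, rewritten to 6 - 2 + 1 = 5
 *   job B forwards it        cgw_hops = 6
 *   job C sees 6 >= 6        frame is counted in deleted_frames
 *
 * i.e. frames entering through job A traverse at most two gateway hops.
 */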
+ */ + if (gwj->mod.modfunc[0]) + nskb = skb_copy(skb, GFP_ATOMIC); + else + nskb = skb_clone(skb, GFP_ATOMIC); + + if (!nskb) { + gwj->dropped_frames++; + return; + } + + /* put the incremented hop counter in the cloned skb */ + cgw_hops(nskb) = cgw_hops(skb) + 1; + + /* first processing of this CAN frame -> adjust to private hop limit */ + if (gwj->limit_hops && cgw_hops(nskb) == 1) + cgw_hops(nskb) = max_hops - gwj->limit_hops + 1; + + nskb->dev = gwj->dst.dev; + + /* pointer to modifiable CAN frame */ + cf = (struct canfd_frame *)nskb->data; + + /* perform preprocessed modification functions if there are any */ + while (modidx < MAX_MODFUNCTIONS && gwj->mod.modfunc[modidx]) + (*gwj->mod.modfunc[modidx++])(cf, &gwj->mod); + + /* Has the CAN frame been modified? */ + if (modidx) { + /* get available space for the processed CAN frame type */ + int max_len = nskb->len - offsetof(struct canfd_frame, data); + + /* dlc may have changed, make sure it fits to the CAN frame */ + if (cf->len > max_len) { + /* delete frame due to misconfiguration */ + gwj->deleted_frames++; + kfree_skb(nskb); + return; + } + + /* check for checksum updates */ + if (gwj->mod.csumfunc.crc8) + (*gwj->mod.csumfunc.crc8)(cf, &gwj->mod.csum.crc8); + + if (gwj->mod.csumfunc.xor) + (*gwj->mod.csumfunc.xor)(cf, &gwj->mod.csum.xor); + } + + /* clear the skb timestamp if not configured the other way */ + if (!(gwj->flags & CGW_FLAGS_CAN_SRC_TSTAMP)) + nskb->tstamp = 0; + + /* send to netdevice */ + if (can_send(nskb, gwj->flags & CGW_FLAGS_CAN_ECHO)) + gwj->dropped_frames++; + else + gwj->handled_frames++; +} + +static inline int cgw_register_filter(struct net *net, struct cgw_job *gwj) +{ + return can_rx_register(net, gwj->src.dev, gwj->ccgw.filter.can_id, + gwj->ccgw.filter.can_mask, can_can_gw_rcv, + gwj, "gw", NULL); +} + +static inline void cgw_unregister_filter(struct net *net, struct cgw_job *gwj) +{ + can_rx_unregister(net, gwj->src.dev, gwj->ccgw.filter.can_id, + gwj->ccgw.filter.can_mask, can_can_gw_rcv, gwj); +} + +static void cgw_job_free_rcu(struct rcu_head *rcu_head) +{ + struct cgw_job *gwj = container_of(rcu_head, struct cgw_job, rcu); + + kmem_cache_free(cgw_cache, gwj); +} + +static int cgw_notifier(struct notifier_block *nb, + unsigned long msg, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + + if (dev->type != ARPHRD_CAN) + return NOTIFY_DONE; + + if (msg == NETDEV_UNREGISTER) { + struct cgw_job *gwj = NULL; + struct hlist_node *nx; + + ASSERT_RTNL(); + + hlist_for_each_entry_safe(gwj, nx, &net->can.cgw_list, list) { + if (gwj->src.dev == dev || gwj->dst.dev == dev) { + hlist_del(&gwj->list); + cgw_unregister_filter(net, gwj); + call_rcu(&gwj->rcu, cgw_job_free_rcu); + } + } + } + + return NOTIFY_DONE; +} + +static int cgw_put_job(struct sk_buff *skb, struct cgw_job *gwj, int type, + u32 pid, u32 seq, int flags) +{ + struct rtcanmsg *rtcan; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*rtcan), flags); + if (!nlh) + return -EMSGSIZE; + + rtcan = nlmsg_data(nlh); + rtcan->can_family = AF_CAN; + rtcan->gwtype = gwj->gwtype; + rtcan->flags = gwj->flags; + + /* add statistics if available */ + + if (gwj->handled_frames) { + if (nla_put_u32(skb, CGW_HANDLED, gwj->handled_frames) < 0) + goto cancel; + } + + if (gwj->dropped_frames) { + if (nla_put_u32(skb, CGW_DROPPED, gwj->dropped_frames) < 0) + goto cancel; + } + + if (gwj->deleted_frames) { + if (nla_put_u32(skb, CGW_DELETED, gwj->deleted_frames) < 0) + goto cancel; + } 
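/* Worked example of the hop accounting above, assuming the module parameter
 * max_hops = 4 and a gateway job configured with limit_hops = 2 (both values
 * chosen only for illustration):
 *
 *   job 1: cgw_hops(skb) == 0  -> 0 < max_hops, frame is routed
 *          cgw_hops(nskb) = 1  -> first hop, adjusted to
 *                                 max_hops - limit_hops + 1 = 4 - 2 + 1 = 3
 *   job 2: cgw_hops(skb) == 3  -> 3 < max_hops, frame is routed
 *          cgw_hops(nskb) = 3 + 1 = 4
 *   job 3: cgw_hops(skb) == 4  -> 4 >= max_hops, frame is counted in
 *                                 deleted_frames and dropped
 *
 * So a frame traverses at most limit_hops = 2 gateway jobs, independent of
 * the globally allowed max_hops.
 */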
+ + /* check non default settings of attributes */ + + if (gwj->limit_hops) { + if (nla_put_u8(skb, CGW_LIM_HOPS, gwj->limit_hops) < 0) + goto cancel; + } + + if (gwj->flags & CGW_FLAGS_CAN_FD) { + struct cgw_fdframe_mod mb; + + if (gwj->mod.modtype.and) { + memcpy(&mb.cf, &gwj->mod.modframe.and, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.and; + if (nla_put(skb, CGW_FDMOD_AND, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.or) { + memcpy(&mb.cf, &gwj->mod.modframe.or, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.or; + if (nla_put(skb, CGW_FDMOD_OR, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.xor) { + memcpy(&mb.cf, &gwj->mod.modframe.xor, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.xor; + if (nla_put(skb, CGW_FDMOD_XOR, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.set) { + memcpy(&mb.cf, &gwj->mod.modframe.set, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.set; + if (nla_put(skb, CGW_FDMOD_SET, sizeof(mb), &mb) < 0) + goto cancel; + } + } else { + struct cgw_frame_mod mb; + + if (gwj->mod.modtype.and) { + memcpy(&mb.cf, &gwj->mod.modframe.and, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.and; + if (nla_put(skb, CGW_MOD_AND, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.or) { + memcpy(&mb.cf, &gwj->mod.modframe.or, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.or; + if (nla_put(skb, CGW_MOD_OR, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.xor) { + memcpy(&mb.cf, &gwj->mod.modframe.xor, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.xor; + if (nla_put(skb, CGW_MOD_XOR, sizeof(mb), &mb) < 0) + goto cancel; + } + + if (gwj->mod.modtype.set) { + memcpy(&mb.cf, &gwj->mod.modframe.set, sizeof(mb.cf)); + mb.modtype = gwj->mod.modtype.set; + if (nla_put(skb, CGW_MOD_SET, sizeof(mb), &mb) < 0) + goto cancel; + } + } + + if (gwj->mod.uid) { + if (nla_put_u32(skb, CGW_MOD_UID, gwj->mod.uid) < 0) + goto cancel; + } + + if (gwj->mod.csumfunc.crc8) { + if (nla_put(skb, CGW_CS_CRC8, CGW_CS_CRC8_LEN, + &gwj->mod.csum.crc8) < 0) + goto cancel; + } + + if (gwj->mod.csumfunc.xor) { + if (nla_put(skb, CGW_CS_XOR, CGW_CS_XOR_LEN, + &gwj->mod.csum.xor) < 0) + goto cancel; + } + + if (gwj->gwtype == CGW_TYPE_CAN_CAN) { + if (gwj->ccgw.filter.can_id || gwj->ccgw.filter.can_mask) { + if (nla_put(skb, CGW_FILTER, sizeof(struct can_filter), + &gwj->ccgw.filter) < 0) + goto cancel; + } + + if (nla_put_u32(skb, CGW_SRC_IF, gwj->ccgw.src_idx) < 0) + goto cancel; + + if (nla_put_u32(skb, CGW_DST_IF, gwj->ccgw.dst_idx) < 0) + goto cancel; + } + + nlmsg_end(skb, nlh); + return 0; + +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +/* Dump information about all CAN gateway jobs, in response to RTM_GETROUTE */ +static int cgw_dump_jobs(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct cgw_job *gwj = NULL; + int idx = 0; + int s_idx = cb->args[0]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(gwj, &net->can.cgw_list, list) { + if (idx < s_idx) + goto cont; + + if (cgw_put_job(skb, gwj, RTM_NEWROUTE, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI) < 0) + break; +cont: + idx++; + } + rcu_read_unlock(); + + cb->args[0] = idx; + + return skb->len; +} + +static const struct nla_policy cgw_policy[CGW_MAX + 1] = { + [CGW_MOD_AND] = { .len = sizeof(struct cgw_frame_mod) }, + [CGW_MOD_OR] = { .len = sizeof(struct cgw_frame_mod) }, + [CGW_MOD_XOR] = { .len = sizeof(struct cgw_frame_mod) }, + [CGW_MOD_SET] = { .len = sizeof(struct 
cgw_frame_mod) }, + [CGW_CS_XOR] = { .len = sizeof(struct cgw_csum_xor) }, + [CGW_CS_CRC8] = { .len = sizeof(struct cgw_csum_crc8) }, + [CGW_SRC_IF] = { .type = NLA_U32 }, + [CGW_DST_IF] = { .type = NLA_U32 }, + [CGW_FILTER] = { .len = sizeof(struct can_filter) }, + [CGW_LIM_HOPS] = { .type = NLA_U8 }, + [CGW_MOD_UID] = { .type = NLA_U32 }, + [CGW_FDMOD_AND] = { .len = sizeof(struct cgw_fdframe_mod) }, + [CGW_FDMOD_OR] = { .len = sizeof(struct cgw_fdframe_mod) }, + [CGW_FDMOD_XOR] = { .len = sizeof(struct cgw_fdframe_mod) }, + [CGW_FDMOD_SET] = { .len = sizeof(struct cgw_fdframe_mod) }, +}; + +/* check for common and gwtype specific attributes */ +static int cgw_parse_attr(struct nlmsghdr *nlh, struct cf_mod *mod, + u8 gwtype, void *gwtypeattr, u8 *limhops) +{ + struct nlattr *tb[CGW_MAX + 1]; + struct rtcanmsg *r = nlmsg_data(nlh); + int modidx = 0; + int err = 0; + + /* initialize modification & checksum data space */ + memset(mod, 0, sizeof(*mod)); + + err = nlmsg_parse_deprecated(nlh, sizeof(struct rtcanmsg), tb, + CGW_MAX, cgw_policy, NULL); + if (err < 0) + return err; + + if (tb[CGW_LIM_HOPS]) { + *limhops = nla_get_u8(tb[CGW_LIM_HOPS]); + + if (*limhops < 1 || *limhops > max_hops) + return -EINVAL; + } + + /* check for AND/OR/XOR/SET modifications */ + if (r->flags & CGW_FLAGS_CAN_FD) { + struct cgw_fdframe_mod mb; + + if (tb[CGW_FDMOD_AND]) { + nla_memcpy(&mb, tb[CGW_FDMOD_AND], CGW_FDMODATTR_LEN); + + canfdframecpy(&mod->modframe.and, &mb.cf); + mod->modtype.and = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_and_id; + + if (mb.modtype & CGW_MOD_LEN) + mod->modfunc[modidx++] = mod_and_len; + + if (mb.modtype & CGW_MOD_FLAGS) + mod->modfunc[modidx++] = mod_and_flags; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_and_fddata; + } + + if (tb[CGW_FDMOD_OR]) { + nla_memcpy(&mb, tb[CGW_FDMOD_OR], CGW_FDMODATTR_LEN); + + canfdframecpy(&mod->modframe.or, &mb.cf); + mod->modtype.or = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_or_id; + + if (mb.modtype & CGW_MOD_LEN) + mod->modfunc[modidx++] = mod_or_len; + + if (mb.modtype & CGW_MOD_FLAGS) + mod->modfunc[modidx++] = mod_or_flags; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_or_fddata; + } + + if (tb[CGW_FDMOD_XOR]) { + nla_memcpy(&mb, tb[CGW_FDMOD_XOR], CGW_FDMODATTR_LEN); + + canfdframecpy(&mod->modframe.xor, &mb.cf); + mod->modtype.xor = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_xor_id; + + if (mb.modtype & CGW_MOD_LEN) + mod->modfunc[modidx++] = mod_xor_len; + + if (mb.modtype & CGW_MOD_FLAGS) + mod->modfunc[modidx++] = mod_xor_flags; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_xor_fddata; + } + + if (tb[CGW_FDMOD_SET]) { + nla_memcpy(&mb, tb[CGW_FDMOD_SET], CGW_FDMODATTR_LEN); + + canfdframecpy(&mod->modframe.set, &mb.cf); + mod->modtype.set = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_set_id; + + if (mb.modtype & CGW_MOD_LEN) + mod->modfunc[modidx++] = mod_set_len; + + if (mb.modtype & CGW_MOD_FLAGS) + mod->modfunc[modidx++] = mod_set_flags; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_set_fddata; + } + } else { + struct cgw_frame_mod mb; + + if (tb[CGW_MOD_AND]) { + nla_memcpy(&mb, tb[CGW_MOD_AND], CGW_MODATTR_LEN); + + canframecpy(&mod->modframe.and, &mb.cf); + mod->modtype.and = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_and_id; + + if (mb.modtype & CGW_MOD_DLC) + mod->modfunc[modidx++] = 
mod_and_ccdlc; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_and_data; + } + + if (tb[CGW_MOD_OR]) { + nla_memcpy(&mb, tb[CGW_MOD_OR], CGW_MODATTR_LEN); + + canframecpy(&mod->modframe.or, &mb.cf); + mod->modtype.or = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_or_id; + + if (mb.modtype & CGW_MOD_DLC) + mod->modfunc[modidx++] = mod_or_ccdlc; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_or_data; + } + + if (tb[CGW_MOD_XOR]) { + nla_memcpy(&mb, tb[CGW_MOD_XOR], CGW_MODATTR_LEN); + + canframecpy(&mod->modframe.xor, &mb.cf); + mod->modtype.xor = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_xor_id; + + if (mb.modtype & CGW_MOD_DLC) + mod->modfunc[modidx++] = mod_xor_ccdlc; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_xor_data; + } + + if (tb[CGW_MOD_SET]) { + nla_memcpy(&mb, tb[CGW_MOD_SET], CGW_MODATTR_LEN); + + canframecpy(&mod->modframe.set, &mb.cf); + mod->modtype.set = mb.modtype; + + if (mb.modtype & CGW_MOD_ID) + mod->modfunc[modidx++] = mod_set_id; + + if (mb.modtype & CGW_MOD_DLC) + mod->modfunc[modidx++] = mod_set_ccdlc; + + if (mb.modtype & CGW_MOD_DATA) + mod->modfunc[modidx++] = mod_set_data; + } + } + + /* check for checksum operations after CAN frame modifications */ + if (modidx) { + if (tb[CGW_CS_CRC8]) { + struct cgw_csum_crc8 *c = nla_data(tb[CGW_CS_CRC8]); + + err = cgw_chk_csum_parms(c->from_idx, c->to_idx, + c->result_idx, r); + if (err) + return err; + + nla_memcpy(&mod->csum.crc8, tb[CGW_CS_CRC8], + CGW_CS_CRC8_LEN); + + /* select dedicated processing function to reduce + * runtime operations in receive hot path. + */ + if (c->from_idx < 0 || c->to_idx < 0 || + c->result_idx < 0) + mod->csumfunc.crc8 = cgw_csum_crc8_rel; + else if (c->from_idx <= c->to_idx) + mod->csumfunc.crc8 = cgw_csum_crc8_pos; + else + mod->csumfunc.crc8 = cgw_csum_crc8_neg; + } + + if (tb[CGW_CS_XOR]) { + struct cgw_csum_xor *c = nla_data(tb[CGW_CS_XOR]); + + err = cgw_chk_csum_parms(c->from_idx, c->to_idx, + c->result_idx, r); + if (err) + return err; + + nla_memcpy(&mod->csum.xor, tb[CGW_CS_XOR], + CGW_CS_XOR_LEN); + + /* select dedicated processing function to reduce + * runtime operations in receive hot path. 
+ */ + if (c->from_idx < 0 || c->to_idx < 0 || + c->result_idx < 0) + mod->csumfunc.xor = cgw_csum_xor_rel; + else if (c->from_idx <= c->to_idx) + mod->csumfunc.xor = cgw_csum_xor_pos; + else + mod->csumfunc.xor = cgw_csum_xor_neg; + } + + if (tb[CGW_MOD_UID]) + nla_memcpy(&mod->uid, tb[CGW_MOD_UID], sizeof(u32)); + } + + if (gwtype == CGW_TYPE_CAN_CAN) { + /* check CGW_TYPE_CAN_CAN specific attributes */ + struct can_can_gw *ccgw = (struct can_can_gw *)gwtypeattr; + + memset(ccgw, 0, sizeof(*ccgw)); + + /* check for can_filter in attributes */ + if (tb[CGW_FILTER]) + nla_memcpy(&ccgw->filter, tb[CGW_FILTER], + sizeof(struct can_filter)); + + err = -ENODEV; + + /* specifying two interfaces is mandatory */ + if (!tb[CGW_SRC_IF] || !tb[CGW_DST_IF]) + return err; + + ccgw->src_idx = nla_get_u32(tb[CGW_SRC_IF]); + ccgw->dst_idx = nla_get_u32(tb[CGW_DST_IF]); + + /* both indices set to 0 for flushing all routing entries */ + if (!ccgw->src_idx && !ccgw->dst_idx) + return 0; + + /* only one index set to 0 is an error */ + if (!ccgw->src_idx || !ccgw->dst_idx) + return err; + } + + /* add the checks for other gwtypes here */ + + return 0; +} + +static int cgw_create_job(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct rtcanmsg *r; + struct cgw_job *gwj; + struct cf_mod mod; + struct can_can_gw ccgw; + u8 limhops = 0; + int err = 0; + + if (!netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (nlmsg_len(nlh) < sizeof(*r)) + return -EINVAL; + + r = nlmsg_data(nlh); + if (r->can_family != AF_CAN) + return -EPFNOSUPPORT; + + /* so far we only support CAN -> CAN routings */ + if (r->gwtype != CGW_TYPE_CAN_CAN) + return -EINVAL; + + err = cgw_parse_attr(nlh, &mod, CGW_TYPE_CAN_CAN, &ccgw, &limhops); + if (err < 0) + return err; + + if (mod.uid) { + ASSERT_RTNL(); + + /* check for updating an existing job with identical uid */ + hlist_for_each_entry(gwj, &net->can.cgw_list, list) { + if (gwj->mod.uid != mod.uid) + continue; + + /* interfaces & filters must be identical */ + if (memcmp(&gwj->ccgw, &ccgw, sizeof(ccgw))) + return -EINVAL; + + /* update modifications with disabled softirq & quit */ + local_bh_disable(); + memcpy(&gwj->mod, &mod, sizeof(mod)); + local_bh_enable(); + return 0; + } + } + + /* ifindex == 0 is not allowed for job creation */ + if (!ccgw.src_idx || !ccgw.dst_idx) + return -ENODEV; + + gwj = kmem_cache_alloc(cgw_cache, GFP_KERNEL); + if (!gwj) + return -ENOMEM; + + gwj->handled_frames = 0; + gwj->dropped_frames = 0; + gwj->deleted_frames = 0; + gwj->flags = r->flags; + gwj->gwtype = r->gwtype; + gwj->limit_hops = limhops; + + /* insert already parsed information */ + memcpy(&gwj->mod, &mod, sizeof(mod)); + memcpy(&gwj->ccgw, &ccgw, sizeof(ccgw)); + + err = -ENODEV; + + gwj->src.dev = __dev_get_by_index(net, gwj->ccgw.src_idx); + + if (!gwj->src.dev) + goto out; + + if (gwj->src.dev->type != ARPHRD_CAN) + goto out; + + gwj->dst.dev = __dev_get_by_index(net, gwj->ccgw.dst_idx); + + if (!gwj->dst.dev) + goto out; + + if (gwj->dst.dev->type != ARPHRD_CAN) + goto out; + + ASSERT_RTNL(); + + err = cgw_register_filter(net, gwj); + if (!err) + hlist_add_head_rcu(&gwj->list, &net->can.cgw_list); +out: + if (err) + kmem_cache_free(cgw_cache, gwj); + + return err; +} + +static void cgw_remove_all_jobs(struct net *net) +{ + struct cgw_job *gwj = NULL; + struct hlist_node *nx; + + ASSERT_RTNL(); + + hlist_for_each_entry_safe(gwj, nx, &net->can.cgw_list, list) { + hlist_del(&gwj->list); + 
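/* Illustrative userspace sketch of how cgw_create_job() above is reached via
 * rtnetlink, roughly what the cangw tool from can-utils does: an RTM_NEWROUTE
 * message with a struct rtcanmsg header plus CGW_SRC_IF / CGW_DST_IF
 * attributes, sent by a CAP_NET_ADMIN process. The interface names
 * "can0"/"can1", the NLM_F_ACK flag and the omitted error/ACK handling are
 * assumptions made only for this sketch.
 *
 *   #include <linux/rtnetlink.h>
 *   #include <linux/can/gw.h>
 *   #include <net/if.h>
 *   #include <string.h>
 *   #include <sys/socket.h>
 *   #include <unistd.h>
 *
 *   int main(void)
 *   {
 *           struct {
 *                   struct nlmsghdr nh;
 *                   struct rtcanmsg rtcan;
 *                   char attrs[64];
 *           } req;
 *           struct sockaddr_nl nladdr = { .nl_family = AF_NETLINK };
 *           __u32 src = if_nametoindex("can0");
 *           __u32 dst = if_nametoindex("can1");
 *           struct rtattr *rta;
 *           int s = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
 *
 *           memset(&req, 0, sizeof(req));
 *           req.nh.nlmsg_len = NLMSG_LENGTH(sizeof(req.rtcan));
 *           req.nh.nlmsg_type = RTM_NEWROUTE;
 *           req.nh.nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
 *           req.rtcan.can_family = AF_CAN;
 *           req.rtcan.gwtype = CGW_TYPE_CAN_CAN;
 *           req.rtcan.flags = CGW_FLAGS_CAN_ECHO;
 *
 *           // CGW_SRC_IF attribute (u32 ifindex)
 *           rta = (struct rtattr *)((char *)&req + NLMSG_ALIGN(req.nh.nlmsg_len));
 *           rta->rta_type = CGW_SRC_IF;
 *           rta->rta_len = RTA_LENGTH(sizeof(src));
 *           memcpy(RTA_DATA(rta), &src, sizeof(src));
 *           req.nh.nlmsg_len = NLMSG_ALIGN(req.nh.nlmsg_len) + RTA_ALIGN(rta->rta_len);
 *
 *           // CGW_DST_IF attribute (u32 ifindex)
 *           rta = (struct rtattr *)((char *)&req + NLMSG_ALIGN(req.nh.nlmsg_len));
 *           rta->rta_type = CGW_DST_IF;
 *           rta->rta_len = RTA_LENGTH(sizeof(dst));
 *           memcpy(RTA_DATA(rta), &dst, sizeof(dst));
 *           req.nh.nlmsg_len = NLMSG_ALIGN(req.nh.nlmsg_len) + RTA_ALIGN(rta->rta_len);
 *
 *           sendto(s, &req, req.nh.nlmsg_len, 0,
 *                  (struct sockaddr *)&nladdr, sizeof(nladdr));
 *           close(s);
 *           return 0;
 *   }
 */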
cgw_unregister_filter(net, gwj); + call_rcu(&gwj->rcu, cgw_job_free_rcu); + } +} + +static int cgw_remove_job(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct cgw_job *gwj = NULL; + struct hlist_node *nx; + struct rtcanmsg *r; + struct cf_mod mod; + struct can_can_gw ccgw; + u8 limhops = 0; + int err = 0; + + if (!netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (nlmsg_len(nlh) < sizeof(*r)) + return -EINVAL; + + r = nlmsg_data(nlh); + if (r->can_family != AF_CAN) + return -EPFNOSUPPORT; + + /* so far we only support CAN -> CAN routings */ + if (r->gwtype != CGW_TYPE_CAN_CAN) + return -EINVAL; + + err = cgw_parse_attr(nlh, &mod, CGW_TYPE_CAN_CAN, &ccgw, &limhops); + if (err < 0) + return err; + + /* two interface indices both set to 0 => remove all entries */ + if (!ccgw.src_idx && !ccgw.dst_idx) { + cgw_remove_all_jobs(net); + return 0; + } + + err = -EINVAL; + + ASSERT_RTNL(); + + /* remove only the first matching entry */ + hlist_for_each_entry_safe(gwj, nx, &net->can.cgw_list, list) { + if (gwj->flags != r->flags) + continue; + + if (gwj->limit_hops != limhops) + continue; + + /* we have a match when uid is enabled and identical */ + if (gwj->mod.uid || mod.uid) { + if (gwj->mod.uid != mod.uid) + continue; + } else { + /* no uid => check for identical modifications */ + if (memcmp(&gwj->mod, &mod, sizeof(mod))) + continue; + } + + /* if (r->gwtype == CGW_TYPE_CAN_CAN) - is made sure here */ + if (memcmp(&gwj->ccgw, &ccgw, sizeof(ccgw))) + continue; + + hlist_del(&gwj->list); + cgw_unregister_filter(net, gwj); + call_rcu(&gwj->rcu, cgw_job_free_rcu); + err = 0; + break; + } + + return err; +} + +static int __net_init cangw_pernet_init(struct net *net) +{ + INIT_HLIST_HEAD(&net->can.cgw_list); + return 0; +} + +static void __net_exit cangw_pernet_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + cgw_remove_all_jobs(net); + rtnl_unlock(); +} + +static struct pernet_operations cangw_pernet_ops = { + .init = cangw_pernet_init, + .exit_batch = cangw_pernet_exit_batch, +}; + +static __init int cgw_module_init(void) +{ + int ret; + + /* sanitize given module parameter */ + max_hops = clamp_t(unsigned int, max_hops, CGW_MIN_HOPS, CGW_MAX_HOPS); + + pr_info("can: netlink gateway - max_hops=%d\n", max_hops); + + ret = register_pernet_subsys(&cangw_pernet_ops); + if (ret) + return ret; + + ret = -ENOMEM; + cgw_cache = kmem_cache_create("can_gw", sizeof(struct cgw_job), + 0, 0, NULL); + if (!cgw_cache) + goto out_cache_create; + + /* set notifier */ + notifier.notifier_call = cgw_notifier; + ret = register_netdevice_notifier(&notifier); + if (ret) + goto out_register_notifier; + + ret = rtnl_register_module(THIS_MODULE, PF_CAN, RTM_GETROUTE, + NULL, cgw_dump_jobs, 0); + if (ret) + goto out_rtnl_register1; + + ret = rtnl_register_module(THIS_MODULE, PF_CAN, RTM_NEWROUTE, + cgw_create_job, NULL, 0); + if (ret) + goto out_rtnl_register2; + ret = rtnl_register_module(THIS_MODULE, PF_CAN, RTM_DELROUTE, + cgw_remove_job, NULL, 0); + if (ret) + goto out_rtnl_register3; + + return 0; + +out_rtnl_register3: + rtnl_unregister(PF_CAN, RTM_NEWROUTE); +out_rtnl_register2: + rtnl_unregister(PF_CAN, RTM_GETROUTE); +out_rtnl_register1: + unregister_netdevice_notifier(&notifier); +out_register_notifier: + kmem_cache_destroy(cgw_cache); +out_cache_create: + unregister_pernet_subsys(&cangw_pernet_ops); + + return ret; +} + +static __exit void
cgw_module_exit(void) +{ + rtnl_unregister_all(PF_CAN); + + unregister_netdevice_notifier(&notifier); + + unregister_pernet_subsys(&cangw_pernet_ops); + rcu_barrier(); /* Wait for completion of call_rcu()'s */ + + kmem_cache_destroy(cgw_cache); +} + +module_init(cgw_module_init); +module_exit(cgw_module_exit); diff --git a/net/can/isotp.c b/net/can/isotp.c new file mode 100644 index 000000000..545889935 --- /dev/null +++ b/net/can/isotp.c @@ -0,0 +1,1691 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* isotp.c - ISO 15765-2 CAN transport protocol for protocol family CAN + * + * This implementation does not provide ISO-TP specific return values to the + * userspace. + * + * - RX path timeout of data reception leads to -ETIMEDOUT + * - RX path SN mismatch leads to -EILSEQ + * - RX path data reception with wrong padding leads to -EBADMSG + * - TX path flowcontrol reception timeout leads to -ECOMM + * - TX path flowcontrol reception overflow leads to -EMSGSIZE + * - TX path flowcontrol reception with wrong layout/padding leads to -EBADMSG + * - when a transfer (tx) is on the run the next write() blocks until it's done + * - use CAN_ISOTP_WAIT_TX_DONE flag to block the caller until the PDU is sent + * - as we have static buffers the check whether the PDU fits into the buffer + * is done at FF reception time (no support for sending 'wait frames') + * + * Copyright (c) 2020 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE.
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_DESCRIPTION("PF_CAN isotp 15765-2:2016 protocol"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_AUTHOR("Oliver Hartkopp "); +MODULE_ALIAS("can-proto-6"); + +#define ISOTP_MIN_NAMELEN CAN_REQUIRED_SIZE(struct sockaddr_can, can_addr.tp) + +#define SINGLE_MASK(id) (((id) & CAN_EFF_FLAG) ? \ + (CAN_EFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG) : \ + (CAN_SFF_MASK | CAN_EFF_FLAG | CAN_RTR_FLAG)) + +/* ISO 15765-2:2016 supports more than 4095 byte per ISO PDU as the FF_DL can + * take full 32 bit values (4 Gbyte). We would need some good concept to handle + * this between user space and kernel space. For now increase the static buffer + * to something about 64 kbyte to be able to test this new functionality. + */ +#define MAX_MSG_LENGTH 66000 + +/* N_PCI type values in bits 7-4 of N_PCI bytes */ +#define N_PCI_SF 0x00 /* single frame */ +#define N_PCI_FF 0x10 /* first frame */ +#define N_PCI_CF 0x20 /* consecutive frame */ +#define N_PCI_FC 0x30 /* flow control */ + +#define N_PCI_SZ 1 /* size of the PCI byte #1 */ +#define SF_PCI_SZ4 1 /* size of SingleFrame PCI including 4 bit SF_DL */ +#define SF_PCI_SZ8 2 /* size of SingleFrame PCI including 8 bit SF_DL */ +#define FF_PCI_SZ12 2 /* size of FirstFrame PCI including 12 bit FF_DL */ +#define FF_PCI_SZ32 6 /* size of FirstFrame PCI including 32 bit FF_DL */ +#define FC_CONTENT_SZ 3 /* flow control content size in byte (FS/BS/STmin) */ + +#define ISOTP_CHECK_PADDING (CAN_ISOTP_CHK_PAD_LEN | CAN_ISOTP_CHK_PAD_DATA) +#define ISOTP_ALL_BC_FLAGS (CAN_ISOTP_SF_BROADCAST | CAN_ISOTP_CF_BROADCAST) + +/* Flow Status given in FC frame */ +#define ISOTP_FC_CTS 0 /* clear to send */ +#define ISOTP_FC_WT 1 /* wait */ +#define ISOTP_FC_OVFLW 2 /* overflow */ + +#define ISOTP_FC_TIMEOUT 1 /* 1 sec */ +#define ISOTP_ECHO_TIMEOUT 2 /* 2 secs */ + +enum { + ISOTP_IDLE = 0, + ISOTP_WAIT_FIRST_FC, + ISOTP_WAIT_FC, + ISOTP_WAIT_DATA, + ISOTP_SENDING, + ISOTP_SHUTDOWN, +}; + +struct tpcon { + unsigned int idx; + unsigned int len; + u32 state; + u8 bs; + u8 sn; + u8 ll_dl; + u8 buf[MAX_MSG_LENGTH + 1]; +}; + +struct isotp_sock { + struct sock sk; + int bound; + int ifindex; + canid_t txid; + canid_t rxid; + ktime_t tx_gap; + ktime_t lastrxcf_tstamp; + struct hrtimer rxtimer, txtimer, txfrtimer; + struct can_isotp_options opt; + struct can_isotp_fc_options rxfc, txfc; + struct can_isotp_ll_options ll; + u32 frame_txtime; + u32 force_tx_stmin; + u32 force_rx_stmin; + u32 cfecho; /* consecutive frame echo tag */ + struct tpcon rx, tx; + struct list_head notifier; + wait_queue_head_t wait; + spinlock_t rx_lock; /* protect single thread state machine */ +}; + +static LIST_HEAD(isotp_notifier_list); +static DEFINE_SPINLOCK(isotp_notifier_lock); +static struct isotp_sock *isotp_busy_notifier; + +static inline struct isotp_sock *isotp_sk(const struct sock *sk) +{ + return (struct isotp_sock *)sk; +} + +static u32 isotp_bc_flags(struct isotp_sock *so) +{ + return so->opt.flags & ISOTP_ALL_BC_FLAGS; +} + +static bool isotp_register_rxid(struct isotp_sock *so) +{ + /* no broadcast modes => register rx_id for FC frame reception */ + return (isotp_bc_flags(so) == 0); +} + +static enum hrtimer_restart isotp_rx_timer_handler(struct hrtimer *hrtimer) +{ + struct isotp_sock *so = container_of(hrtimer, struct isotp_sock, + rxtimer); + struct sock *sk = &so->sk; + + if (so->rx.state == 
ISOTP_WAIT_DATA) { + /* we did not get new data frames in time */ + + /* report 'connection timed out' */ + sk->sk_err = ETIMEDOUT; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + + /* reset rx state */ + so->rx.state = ISOTP_IDLE; + } + + return HRTIMER_NORESTART; +} + +static int isotp_send_fc(struct sock *sk, int ae, u8 flowstatus) +{ + struct net_device *dev; + struct sk_buff *nskb; + struct canfd_frame *ncf; + struct isotp_sock *so = isotp_sk(sk); + int can_send_ret; + + nskb = alloc_skb(so->ll.mtu + sizeof(struct can_skb_priv), gfp_any()); + if (!nskb) + return 1; + + dev = dev_get_by_index(sock_net(sk), so->ifindex); + if (!dev) { + kfree_skb(nskb); + return 1; + } + + can_skb_reserve(nskb); + can_skb_prv(nskb)->ifindex = dev->ifindex; + can_skb_prv(nskb)->skbcnt = 0; + + nskb->dev = dev; + can_skb_set_owner(nskb, sk); + ncf = (struct canfd_frame *)nskb->data; + skb_put_zero(nskb, so->ll.mtu); + + /* create & send flow control reply */ + ncf->can_id = so->txid; + + if (so->opt.flags & CAN_ISOTP_TX_PADDING) { + memset(ncf->data, so->opt.txpad_content, CAN_MAX_DLEN); + ncf->len = CAN_MAX_DLEN; + } else { + ncf->len = ae + FC_CONTENT_SZ; + } + + ncf->data[ae] = N_PCI_FC | flowstatus; + ncf->data[ae + 1] = so->rxfc.bs; + ncf->data[ae + 2] = so->rxfc.stmin; + + if (ae) + ncf->data[0] = so->opt.ext_address; + + ncf->flags = so->ll.tx_flags; + + can_send_ret = can_send(nskb, 1); + if (can_send_ret) + pr_notice_once("can-isotp: %s: can_send_ret %pe\n", + __func__, ERR_PTR(can_send_ret)); + + dev_put(dev); + + /* reset blocksize counter */ + so->rx.bs = 0; + + /* reset last CF frame rx timestamp for rx stmin enforcement */ + so->lastrxcf_tstamp = ktime_set(0, 0); + + /* start rx timeout watchdog */ + hrtimer_start(&so->rxtimer, ktime_set(ISOTP_FC_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + return 0; +} + +static void isotp_rcv_skb(struct sk_buff *skb, struct sock *sk) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)skb->cb; + + BUILD_BUG_ON(sizeof(skb->cb) < sizeof(struct sockaddr_can)); + + memset(addr, 0, sizeof(*addr)); + addr->can_family = AF_CAN; + addr->can_ifindex = skb->dev->ifindex; + + if (sock_queue_rcv_skb(sk, skb) < 0) + kfree_skb(skb); +} + +static u8 padlen(u8 datalen) +{ + static const u8 plen[] = { + 8, 8, 8, 8, 8, 8, 8, 8, 8, /* 0 - 8 */ + 12, 12, 12, 12, /* 9 - 12 */ + 16, 16, 16, 16, /* 13 - 16 */ + 20, 20, 20, 20, /* 17 - 20 */ + 24, 24, 24, 24, /* 21 - 24 */ + 32, 32, 32, 32, 32, 32, 32, 32, /* 25 - 32 */ + 48, 48, 48, 48, 48, 48, 48, 48, /* 33 - 40 */ + 48, 48, 48, 48, 48, 48, 48, 48 /* 41 - 48 */ + }; + + if (datalen > 48) + return 64; + + return plen[datalen]; +} + +/* check for length optimization and return 1/true when the check fails */ +static int check_optimized(struct canfd_frame *cf, int start_index) +{ + /* for CAN_DL <= 8 the start_index is equal to the CAN_DL as the + * padding would start at this point. E.g. if the padding would + * start at cf.data[7] cf->len has to be 7 to be optimal. + * Note: The data[] index starts with zero. + */ + if (cf->len <= CAN_MAX_DLEN) + return (cf->len != start_index); + + /* This relation is also valid in the non-linear DLC range, where + * we need to take care of the minimal next possible CAN_DL. + * The correct check would be (padlen(cf->len) != padlen(start_index)). + * But as cf->len can only take discrete values from 12, .., 64 at this + * point the padlen(cf->len) is always equal to cf->len. 
+ */ + return (cf->len != padlen(start_index)); +} + +/* check padding and return 1/true when the check fails */ +static int check_pad(struct isotp_sock *so, struct canfd_frame *cf, + int start_index, u8 content) +{ + int i; + + /* no RX_PADDING value => check length of optimized frame length */ + if (!(so->opt.flags & CAN_ISOTP_RX_PADDING)) { + if (so->opt.flags & CAN_ISOTP_CHK_PAD_LEN) + return check_optimized(cf, start_index); + + /* no valid test against empty value => ignore frame */ + return 1; + } + + /* check datalength of correctly padded CAN frame */ + if ((so->opt.flags & CAN_ISOTP_CHK_PAD_LEN) && + cf->len != padlen(cf->len)) + return 1; + + /* check padding content */ + if (so->opt.flags & CAN_ISOTP_CHK_PAD_DATA) { + for (i = start_index; i < cf->len; i++) + if (cf->data[i] != content) + return 1; + } + return 0; +} + +static void isotp_send_cframe(struct isotp_sock *so); + +static int isotp_rcv_fc(struct isotp_sock *so, struct canfd_frame *cf, int ae) +{ + struct sock *sk = &so->sk; + + if (so->tx.state != ISOTP_WAIT_FC && + so->tx.state != ISOTP_WAIT_FIRST_FC) + return 0; + + hrtimer_cancel(&so->txtimer); + + if ((cf->len < ae + FC_CONTENT_SZ) || + ((so->opt.flags & ISOTP_CHECK_PADDING) && + check_pad(so, cf, ae + FC_CONTENT_SZ, so->opt.rxpad_content))) { + /* malformed PDU - report 'not a data message' */ + sk->sk_err = EBADMSG; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); + return 1; + } + + /* get communication parameters only from the first FC frame */ + if (so->tx.state == ISOTP_WAIT_FIRST_FC) { + so->txfc.bs = cf->data[ae + 1]; + so->txfc.stmin = cf->data[ae + 2]; + + /* fix wrong STmin values according spec */ + if (so->txfc.stmin > 0x7F && + (so->txfc.stmin < 0xF1 || so->txfc.stmin > 0xF9)) + so->txfc.stmin = 0x7F; + + so->tx_gap = ktime_set(0, 0); + /* add transmission time for CAN frame N_As */ + so->tx_gap = ktime_add_ns(so->tx_gap, so->frame_txtime); + /* add waiting time for consecutive frames N_Cs */ + if (so->opt.flags & CAN_ISOTP_FORCE_TXSTMIN) + so->tx_gap = ktime_add_ns(so->tx_gap, + so->force_tx_stmin); + else if (so->txfc.stmin < 0x80) + so->tx_gap = ktime_add_ns(so->tx_gap, + so->txfc.stmin * 1000000); + else + so->tx_gap = ktime_add_ns(so->tx_gap, + (so->txfc.stmin - 0xF0) + * 100000); + so->tx.state = ISOTP_WAIT_FC; + } + + switch (cf->data[ae] & 0x0F) { + case ISOTP_FC_CTS: + so->tx.bs = 0; + so->tx.state = ISOTP_SENDING; + /* send CF frame and enable echo timeout handling */ + hrtimer_start(&so->txtimer, ktime_set(ISOTP_ECHO_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + isotp_send_cframe(so); + break; + + case ISOTP_FC_WT: + /* start timer to wait for next FC frame */ + hrtimer_start(&so->txtimer, ktime_set(ISOTP_FC_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + break; + + case ISOTP_FC_OVFLW: + /* overflow on receiver side - report 'message too long' */ + sk->sk_err = EMSGSIZE; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + fallthrough; + + default: + /* stop this tx job */ + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); + } + return 0; +} + +static int isotp_rcv_sf(struct sock *sk, struct canfd_frame *cf, int pcilen, + struct sk_buff *skb, int len) +{ + struct isotp_sock *so = isotp_sk(sk); + struct sk_buff *nskb; + + hrtimer_cancel(&so->rxtimer); + so->rx.state = ISOTP_IDLE; + + if (!len || len > cf->len - pcilen) + return 1; + + if ((so->opt.flags & ISOTP_CHECK_PADDING) && + check_pad(so, cf, pcilen + len, so->opt.rxpad_content)) { + /* malformed PDU - 
report 'not a data message' */ + sk->sk_err = EBADMSG; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + return 1; + } + + nskb = alloc_skb(len, gfp_any()); + if (!nskb) + return 1; + + memcpy(skb_put(nskb, len), &cf->data[pcilen], len); + + nskb->tstamp = skb->tstamp; + nskb->dev = skb->dev; + isotp_rcv_skb(nskb, sk); + return 0; +} + +static int isotp_rcv_ff(struct sock *sk, struct canfd_frame *cf, int ae) +{ + struct isotp_sock *so = isotp_sk(sk); + int i; + int off; + int ff_pci_sz; + + hrtimer_cancel(&so->rxtimer); + so->rx.state = ISOTP_IDLE; + + /* get the used sender LL_DL from the (first) CAN frame data length */ + so->rx.ll_dl = padlen(cf->len); + + /* the first frame has to use the entire frame up to LL_DL length */ + if (cf->len != so->rx.ll_dl) + return 1; + + /* get the FF_DL */ + so->rx.len = (cf->data[ae] & 0x0F) << 8; + so->rx.len += cf->data[ae + 1]; + + /* Check for FF_DL escape sequence supporting 32 bit PDU length */ + if (so->rx.len) { + ff_pci_sz = FF_PCI_SZ12; + } else { + /* FF_DL = 0 => get real length from next 4 bytes */ + so->rx.len = cf->data[ae + 2] << 24; + so->rx.len += cf->data[ae + 3] << 16; + so->rx.len += cf->data[ae + 4] << 8; + so->rx.len += cf->data[ae + 5]; + ff_pci_sz = FF_PCI_SZ32; + } + + /* take care of a potential SF_DL ESC offset for TX_DL > 8 */ + off = (so->rx.ll_dl > CAN_MAX_DLEN) ? 1 : 0; + + if (so->rx.len + ae + off + ff_pci_sz < so->rx.ll_dl) + return 1; + + if (so->rx.len > MAX_MSG_LENGTH) { + /* send FC frame with overflow status */ + isotp_send_fc(sk, ae, ISOTP_FC_OVFLW); + return 1; + } + + /* copy the first received data bytes */ + so->rx.idx = 0; + for (i = ae + ff_pci_sz; i < so->rx.ll_dl; i++) + so->rx.buf[so->rx.idx++] = cf->data[i]; + + /* initial setup for this pdu reception */ + so->rx.sn = 1; + so->rx.state = ISOTP_WAIT_DATA; + + /* no creation of flow control frames */ + if (so->opt.flags & CAN_ISOTP_LISTEN_MODE) + return 0; + + /* send our first FC frame */ + isotp_send_fc(sk, ae, ISOTP_FC_CTS); + return 0; +} + +static int isotp_rcv_cf(struct sock *sk, struct canfd_frame *cf, int ae, + struct sk_buff *skb) +{ + struct isotp_sock *so = isotp_sk(sk); + struct sk_buff *nskb; + int i; + + if (so->rx.state != ISOTP_WAIT_DATA) + return 0; + + /* drop if timestamp gap is less than force_rx_stmin nano secs */ + if (so->opt.flags & CAN_ISOTP_FORCE_RXSTMIN) { + if (ktime_to_ns(ktime_sub(skb->tstamp, so->lastrxcf_tstamp)) < + so->force_rx_stmin) + return 0; + + so->lastrxcf_tstamp = skb->tstamp; + } + + hrtimer_cancel(&so->rxtimer); + + /* CFs are never longer than the FF */ + if (cf->len > so->rx.ll_dl) + return 1; + + /* CFs have usually the LL_DL length */ + if (cf->len < so->rx.ll_dl) { + /* this is only allowed for the last CF */ + if (so->rx.len - so->rx.idx > so->rx.ll_dl - ae - N_PCI_SZ) + return 1; + } + + if ((cf->data[ae] & 0x0F) != so->rx.sn) { + /* wrong sn detected - report 'illegal byte sequence' */ + sk->sk_err = EILSEQ; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + + /* reset rx state */ + so->rx.state = ISOTP_IDLE; + return 1; + } + so->rx.sn++; + so->rx.sn %= 16; + + for (i = ae + N_PCI_SZ; i < cf->len; i++) { + so->rx.buf[so->rx.idx++] = cf->data[i]; + if (so->rx.idx >= so->rx.len) + break; + } + + if (so->rx.idx >= so->rx.len) { + /* we are done */ + so->rx.state = ISOTP_IDLE; + + if ((so->opt.flags & ISOTP_CHECK_PADDING) && + check_pad(so, cf, i + 1, so->opt.rxpad_content)) { + /* malformed PDU - report 'not a data message' */ + sk->sk_err = EBADMSG; + if (!sock_flag(sk, SOCK_DEAD)) + 
sk_error_report(sk); + return 1; + } + + nskb = alloc_skb(so->rx.len, gfp_any()); + if (!nskb) + return 1; + + memcpy(skb_put(nskb, so->rx.len), so->rx.buf, + so->rx.len); + + nskb->tstamp = skb->tstamp; + nskb->dev = skb->dev; + isotp_rcv_skb(nskb, sk); + return 0; + } + + /* perform blocksize handling, if enabled */ + if (!so->rxfc.bs || ++so->rx.bs < so->rxfc.bs) { + /* start rx timeout watchdog */ + hrtimer_start(&so->rxtimer, ktime_set(ISOTP_FC_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + return 0; + } + + /* no creation of flow control frames */ + if (so->opt.flags & CAN_ISOTP_LISTEN_MODE) + return 0; + + /* we reached the specified blocksize so->rxfc.bs */ + isotp_send_fc(sk, ae, ISOTP_FC_CTS); + return 0; +} + +static void isotp_rcv(struct sk_buff *skb, void *data) +{ + struct sock *sk = (struct sock *)data; + struct isotp_sock *so = isotp_sk(sk); + struct canfd_frame *cf; + int ae = (so->opt.flags & CAN_ISOTP_EXTEND_ADDR) ? 1 : 0; + u8 n_pci_type, sf_dl; + + /* Strictly receive only frames with the configured MTU size + * => clear separation of CAN2.0 / CAN FD transport channels + */ + if (skb->len != so->ll.mtu) + return; + + cf = (struct canfd_frame *)skb->data; + + /* if enabled: check reception of my configured extended address */ + if (ae && cf->data[0] != so->opt.rx_ext_address) + return; + + n_pci_type = cf->data[ae] & 0xF0; + + /* Make sure the state changes and data structures stay consistent at + * CAN frame reception time. This locking is not needed in real world + * use cases but the inconsistency can be triggered with syzkaller. + */ + spin_lock(&so->rx_lock); + + if (so->opt.flags & CAN_ISOTP_HALF_DUPLEX) { + /* check rx/tx path half duplex expectations */ + if ((so->tx.state != ISOTP_IDLE && n_pci_type != N_PCI_FC) || + (so->rx.state != ISOTP_IDLE && n_pci_type == N_PCI_FC)) + goto out_unlock; + } + + switch (n_pci_type) { + case N_PCI_FC: + /* tx path: flow control frame containing the FC parameters */ + isotp_rcv_fc(so, cf, ae); + break; + + case N_PCI_SF: + /* rx path: single frame + * + * As we do not have a rx.ll_dl configuration, we can only test + * if the CAN frames payload length matches the LL_DL == 8 + * requirements - no matter if it's CAN 2.0 or CAN FD + */ + + /* get the SF_DL from the N_PCI byte */ + sf_dl = cf->data[ae] & 0x0F; + + if (cf->len <= CAN_MAX_DLEN) { + isotp_rcv_sf(sk, cf, SF_PCI_SZ4 + ae, skb, sf_dl); + } else { + if (can_is_canfd_skb(skb)) { + /* We have a CAN FD frame and CAN_DL is greater than 8: + * Only frames with the SF_DL == 0 ESC value are valid. + * + * If so take care of the increased SF PCI size + * (SF_PCI_SZ8) to point to the message content behind + * the extended SF PCI info and get the real SF_DL + * length value from the formerly first data byte. 
+ */ + if (sf_dl == 0) + isotp_rcv_sf(sk, cf, SF_PCI_SZ8 + ae, skb, + cf->data[SF_PCI_SZ4 + ae]); + } + } + break; + + case N_PCI_FF: + /* rx path: first frame */ + isotp_rcv_ff(sk, cf, ae); + break; + + case N_PCI_CF: + /* rx path: consecutive frame */ + isotp_rcv_cf(sk, cf, ae, skb); + break; + } + +out_unlock: + spin_unlock(&so->rx_lock); +} + +static void isotp_fill_dataframe(struct canfd_frame *cf, struct isotp_sock *so, + int ae, int off) +{ + int pcilen = N_PCI_SZ + ae + off; + int space = so->tx.ll_dl - pcilen; + int num = min_t(int, so->tx.len - so->tx.idx, space); + int i; + + cf->can_id = so->txid; + cf->len = num + pcilen; + + if (num < space) { + if (so->opt.flags & CAN_ISOTP_TX_PADDING) { + /* user requested padding */ + cf->len = padlen(cf->len); + memset(cf->data, so->opt.txpad_content, cf->len); + } else if (cf->len > CAN_MAX_DLEN) { + /* mandatory padding for CAN FD frames */ + cf->len = padlen(cf->len); + memset(cf->data, CAN_ISOTP_DEFAULT_PAD_CONTENT, + cf->len); + } + } + + for (i = 0; i < num; i++) + cf->data[pcilen + i] = so->tx.buf[so->tx.idx++]; + + if (ae) + cf->data[0] = so->opt.ext_address; +} + +static void isotp_send_cframe(struct isotp_sock *so) +{ + struct sock *sk = &so->sk; + struct sk_buff *skb; + struct net_device *dev; + struct canfd_frame *cf; + int can_send_ret; + int ae = (so->opt.flags & CAN_ISOTP_EXTEND_ADDR) ? 1 : 0; + + dev = dev_get_by_index(sock_net(sk), so->ifindex); + if (!dev) + return; + + skb = alloc_skb(so->ll.mtu + sizeof(struct can_skb_priv), GFP_ATOMIC); + if (!skb) { + dev_put(dev); + return; + } + + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = dev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + + cf = (struct canfd_frame *)skb->data; + skb_put_zero(skb, so->ll.mtu); + + /* create consecutive frame */ + isotp_fill_dataframe(cf, so, ae, 0); + + /* place consecutive frame N_PCI in appropriate index */ + cf->data[ae] = N_PCI_CF | so->tx.sn++; + so->tx.sn %= 16; + so->tx.bs++; + + cf->flags = so->ll.tx_flags; + + skb->dev = dev; + can_skb_set_owner(skb, sk); + + /* cfecho should have been zero'ed by init/isotp_rcv_echo() */ + if (so->cfecho) + pr_notice_once("can-isotp: cfecho is %08X != 0\n", so->cfecho); + + /* set consecutive frame echo tag */ + so->cfecho = *(u32 *)cf->data; + + /* send frame with local echo enabled */ + can_send_ret = can_send(skb, 1); + if (can_send_ret) { + pr_notice_once("can-isotp: %s: can_send_ret %pe\n", + __func__, ERR_PTR(can_send_ret)); + if (can_send_ret == -ENOBUFS) + pr_notice_once("can-isotp: tx queue is full\n"); + } + dev_put(dev); +} + +static void isotp_create_fframe(struct canfd_frame *cf, struct isotp_sock *so, + int ae) +{ + int i; + int ff_pci_sz; + + cf->can_id = so->txid; + cf->len = so->tx.ll_dl; + if (ae) + cf->data[0] = so->opt.ext_address; + + /* create N_PCI bytes with 12/32 bit FF_DL data length */ + if (so->tx.len > 4095) { + /* use 32 bit FF_DL notation */ + cf->data[ae] = N_PCI_FF; + cf->data[ae + 1] = 0; + cf->data[ae + 2] = (u8)(so->tx.len >> 24) & 0xFFU; + cf->data[ae + 3] = (u8)(so->tx.len >> 16) & 0xFFU; + cf->data[ae + 4] = (u8)(so->tx.len >> 8) & 0xFFU; + cf->data[ae + 5] = (u8)so->tx.len & 0xFFU; + ff_pci_sz = FF_PCI_SZ32; + } else { + /* use 12 bit FF_DL notation */ + cf->data[ae] = (u8)(so->tx.len >> 8) | N_PCI_FF; + cf->data[ae + 1] = (u8)so->tx.len & 0xFFU; + ff_pci_sz = FF_PCI_SZ12; + } + + /* add first data bytes depending on ae */ + for (i = ae + ff_pci_sz; i < so->tx.ll_dl; i++) + cf->data[i] = so->tx.buf[so->tx.idx++]; + + so->tx.sn = 1; +} + +static void 
isotp_rcv_echo(struct sk_buff *skb, void *data) +{ + struct sock *sk = (struct sock *)data; + struct isotp_sock *so = isotp_sk(sk); + struct canfd_frame *cf = (struct canfd_frame *)skb->data; + + /* only handle my own local echo CF/SF skb's (no FF!) */ + if (skb->sk != sk || so->cfecho != *(u32 *)cf->data) + return; + + /* cancel local echo timeout */ + hrtimer_cancel(&so->txtimer); + + /* local echo skb with consecutive frame has been consumed */ + so->cfecho = 0; + + if (so->tx.idx >= so->tx.len) { + /* we are done */ + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); + return; + } + + if (so->txfc.bs && so->tx.bs >= so->txfc.bs) { + /* stop and wait for FC with timeout */ + so->tx.state = ISOTP_WAIT_FC; + hrtimer_start(&so->txtimer, ktime_set(ISOTP_FC_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + return; + } + + /* no gap between data frames needed => use burst mode */ + if (!so->tx_gap) { + /* enable echo timeout handling */ + hrtimer_start(&so->txtimer, ktime_set(ISOTP_ECHO_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + isotp_send_cframe(so); + return; + } + + /* start timer to send next consecutive frame with correct delay */ + hrtimer_start(&so->txfrtimer, so->tx_gap, HRTIMER_MODE_REL_SOFT); +} + +static enum hrtimer_restart isotp_tx_timer_handler(struct hrtimer *hrtimer) +{ + struct isotp_sock *so = container_of(hrtimer, struct isotp_sock, + txtimer); + struct sock *sk = &so->sk; + + /* don't handle timeouts in IDLE or SHUTDOWN state */ + if (so->tx.state == ISOTP_IDLE || so->tx.state == ISOTP_SHUTDOWN) + return HRTIMER_NORESTART; + + /* we did not get any flow control or echo frame in time */ + + /* report 'communication error on send' */ + sk->sk_err = ECOMM; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + + /* reset tx state */ + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); + + return HRTIMER_NORESTART; +} + +static enum hrtimer_restart isotp_txfr_timer_handler(struct hrtimer *hrtimer) +{ + struct isotp_sock *so = container_of(hrtimer, struct isotp_sock, + txfrtimer); + + /* start echo timeout handling and cover below protocol error */ + hrtimer_start(&so->txtimer, ktime_set(ISOTP_ECHO_TIMEOUT, 0), + HRTIMER_MODE_REL_SOFT); + + /* cfecho should be consumed by isotp_rcv_echo() here */ + if (so->tx.state == ISOTP_SENDING && !so->cfecho) + isotp_send_cframe(so); + + return HRTIMER_NORESTART; +} + +static int isotp_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) +{ + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + struct sk_buff *skb; + struct net_device *dev; + struct canfd_frame *cf; + int ae = (so->opt.flags & CAN_ISOTP_EXTEND_ADDR) ? 1 : 0; + int wait_tx_done = (so->opt.flags & CAN_ISOTP_WAIT_TX_DONE) ? 1 : 0; + s64 hrtimer_sec = ISOTP_ECHO_TIMEOUT; + int off; + int err; + + if (!so->bound || so->tx.state == ISOTP_SHUTDOWN) + return -EADDRNOTAVAIL; + + while (cmpxchg(&so->tx.state, ISOTP_IDLE, ISOTP_SENDING) != ISOTP_IDLE) { + /* we do not support multiple buffers - for now */ + if (msg->msg_flags & MSG_DONTWAIT) + return -EAGAIN; + + if (so->tx.state == ISOTP_SHUTDOWN) + return -EADDRNOTAVAIL; + + /* wait for complete transmission of current pdu */ + err = wait_event_interruptible(so->wait, so->tx.state == ISOTP_IDLE); + if (err) + goto err_event_drop; + } + + if (!size || size > MAX_MSG_LENGTH) { + err = -EINVAL; + goto err_out_drop; + } + + /* take care of a potential SF_DL ESC offset for TX_DL > 8 */ + off = (so->tx.ll_dl > CAN_MAX_DLEN) ? 
1 : 0; + + /* does the given data fit into a single frame for SF_BROADCAST? */ + if ((isotp_bc_flags(so) == CAN_ISOTP_SF_BROADCAST) && + (size > so->tx.ll_dl - SF_PCI_SZ4 - ae - off)) { + err = -EINVAL; + goto err_out_drop; + } + + err = memcpy_from_msg(so->tx.buf, msg, size); + if (err < 0) + goto err_out_drop; + + dev = dev_get_by_index(sock_net(sk), so->ifindex); + if (!dev) { + err = -ENXIO; + goto err_out_drop; + } + + skb = sock_alloc_send_skb(sk, so->ll.mtu + sizeof(struct can_skb_priv), + msg->msg_flags & MSG_DONTWAIT, &err); + if (!skb) { + dev_put(dev); + goto err_out_drop; + } + + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = dev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + + so->tx.len = size; + so->tx.idx = 0; + + cf = (struct canfd_frame *)skb->data; + skb_put_zero(skb, so->ll.mtu); + + /* cfecho should have been zero'ed by init / former isotp_rcv_echo() */ + if (so->cfecho) + pr_notice_once("can-isotp: uninit cfecho %08X\n", so->cfecho); + + /* check for single frame transmission depending on TX_DL */ + if (size <= so->tx.ll_dl - SF_PCI_SZ4 - ae - off) { + /* The message size generally fits into a SingleFrame - good. + * + * SF_DL ESC offset optimization: + * + * When TX_DL is greater 8 but the message would still fit + * into a 8 byte CAN frame, we can omit the offset. + * This prevents a protocol caused length extension from + * CAN_DL = 8 to CAN_DL = 12 due to the SF_SL ESC handling. + */ + if (size <= CAN_MAX_DLEN - SF_PCI_SZ4 - ae) + off = 0; + + isotp_fill_dataframe(cf, so, ae, off); + + /* place single frame N_PCI w/o length in appropriate index */ + cf->data[ae] = N_PCI_SF; + + /* place SF_DL size value depending on the SF_DL ESC offset */ + if (off) + cf->data[SF_PCI_SZ4 + ae] = size; + else + cf->data[ae] |= size; + + /* set CF echo tag for isotp_rcv_echo() (SF-mode) */ + so->cfecho = *(u32 *)cf->data; + } else { + /* send first frame */ + + isotp_create_fframe(cf, so, ae); + + if (isotp_bc_flags(so) == CAN_ISOTP_CF_BROADCAST) { + /* set timer for FC-less operation (STmin = 0) */ + if (so->opt.flags & CAN_ISOTP_FORCE_TXSTMIN) + so->tx_gap = ktime_set(0, so->force_tx_stmin); + else + so->tx_gap = ktime_set(0, so->frame_txtime); + + /* disable wait for FCs due to activated block size */ + so->txfc.bs = 0; + + /* set CF echo tag for isotp_rcv_echo() (CF-mode) */ + so->cfecho = *(u32 *)cf->data; + } else { + /* standard flow control check */ + so->tx.state = ISOTP_WAIT_FIRST_FC; + + /* start timeout for FC */ + hrtimer_sec = ISOTP_FC_TIMEOUT; + + /* no CF echo tag for isotp_rcv_echo() (FF-mode) */ + so->cfecho = 0; + } + } + + hrtimer_start(&so->txtimer, ktime_set(hrtimer_sec, 0), + HRTIMER_MODE_REL_SOFT); + + /* send the first or only CAN frame */ + cf->flags = so->ll.tx_flags; + + skb->dev = dev; + skb->sk = sk; + err = can_send(skb, 1); + dev_put(dev); + if (err) { + pr_notice_once("can-isotp: %s: can_send_ret %pe\n", + __func__, ERR_PTR(err)); + + /* no transmission -> no timeout monitoring */ + hrtimer_cancel(&so->txtimer); + + /* reset consecutive frame echo tag */ + so->cfecho = 0; + + goto err_out_drop; + } + + if (wait_tx_done) { + /* wait for complete transmission of current pdu */ + err = wait_event_interruptible(so->wait, so->tx.state == ISOTP_IDLE); + if (err) + goto err_event_drop; + + err = sock_error(sk); + if (err) + return err; + } + + return size; + +err_event_drop: + /* got signal: force tx state machine to be idle */ + so->tx.state = ISOTP_IDLE; + hrtimer_cancel(&so->txfrtimer); + hrtimer_cancel(&so->txtimer); +err_out_drop: + /* drop this PDU 
and unlock a potential wait queue */ + so->tx.state = ISOTP_IDLE; + wake_up_interruptible(&so->wait); + + return err; +} + +static int isotp_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + struct isotp_sock *so = isotp_sk(sk); + int ret = 0; + + if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK | MSG_CMSG_COMPAT)) + return -EINVAL; + + if (!so->bound) + return -EADDRNOTAVAIL; + + skb = skb_recv_datagram(sk, flags, &ret); + if (!skb) + return ret; + + if (size < skb->len) + msg->msg_flags |= MSG_TRUNC; + else + size = skb->len; + + ret = memcpy_to_msg(msg, skb->data, size); + if (ret < 0) + goto out_err; + + sock_recv_cmsgs(msg, sk, skb); + + if (msg->msg_name) { + __sockaddr_check_size(ISOTP_MIN_NAMELEN); + msg->msg_namelen = ISOTP_MIN_NAMELEN; + memcpy(msg->msg_name, skb->cb, msg->msg_namelen); + } + + /* set length of return value */ + ret = (flags & MSG_TRUNC) ? skb->len : size; + +out_err: + skb_free_datagram(sk, skb); + + return ret; +} + +static int isotp_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct isotp_sock *so; + struct net *net; + + if (!sk) + return 0; + + so = isotp_sk(sk); + net = sock_net(sk); + + /* wait for complete transmission of current pdu */ + while (wait_event_interruptible(so->wait, so->tx.state == ISOTP_IDLE) == 0 && + cmpxchg(&so->tx.state, ISOTP_IDLE, ISOTP_SHUTDOWN) != ISOTP_IDLE) + ; + + /* force state machines to be idle also when a signal occurred */ + so->tx.state = ISOTP_SHUTDOWN; + so->rx.state = ISOTP_IDLE; + + spin_lock(&isotp_notifier_lock); + while (isotp_busy_notifier == so) { + spin_unlock(&isotp_notifier_lock); + schedule_timeout_uninterruptible(1); + spin_lock(&isotp_notifier_lock); + } + list_del(&so->notifier); + spin_unlock(&isotp_notifier_lock); + + lock_sock(sk); + + /* remove current filters & unregister */ + if (so->bound) { + if (so->ifindex) { + struct net_device *dev; + + dev = dev_get_by_index(net, so->ifindex); + if (dev) { + if (isotp_register_rxid(so)) + can_rx_unregister(net, dev, so->rxid, + SINGLE_MASK(so->rxid), + isotp_rcv, sk); + + can_rx_unregister(net, dev, so->txid, + SINGLE_MASK(so->txid), + isotp_rcv_echo, sk); + dev_put(dev); + synchronize_rcu(); + } + } + } + + hrtimer_cancel(&so->txfrtimer); + hrtimer_cancel(&so->txtimer); + hrtimer_cancel(&so->rxtimer); + + so->ifindex = 0; + so->bound = 0; + + sock_orphan(sk); + sock->sk = NULL; + + release_sock(sk); + sock_put(sk); + + return 0; +} + +static int isotp_bind(struct socket *sock, struct sockaddr *uaddr, int len) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + struct net *net = sock_net(sk); + int ifindex; + struct net_device *dev; + canid_t tx_id = addr->can_addr.tp.tx_id; + canid_t rx_id = addr->can_addr.tp.rx_id; + int err = 0; + int notify_enetdown = 0; + + if (len < ISOTP_MIN_NAMELEN) + return -EINVAL; + + if (addr->can_family != AF_CAN) + return -EINVAL; + + /* sanitize tx CAN identifier */ + if (tx_id & CAN_EFF_FLAG) + tx_id &= (CAN_EFF_FLAG | CAN_EFF_MASK); + else + tx_id &= CAN_SFF_MASK; + + /* give feedback on wrong CAN-ID value */ + if (tx_id != addr->can_addr.tp.tx_id) + return -EINVAL; + + /* sanitize rx CAN identifier (if needed) */ + if (isotp_register_rxid(so)) { + if (rx_id & CAN_EFF_FLAG) + rx_id &= (CAN_EFF_FLAG | CAN_EFF_MASK); + else + rx_id &= CAN_SFF_MASK; + + /* give feedback on wrong CAN-ID value */ + if (rx_id != addr->can_addr.tp.rx_id) + return 
-EINVAL; + } + + if (!addr->can_ifindex) + return -ENODEV; + + lock_sock(sk); + + if (so->bound) { + err = -EINVAL; + goto out; + } + + /* ensure different CAN IDs when the rx_id is to be registered */ + if (isotp_register_rxid(so) && rx_id == tx_id) { + err = -EADDRNOTAVAIL; + goto out; + } + + dev = dev_get_by_index(net, addr->can_ifindex); + if (!dev) { + err = -ENODEV; + goto out; + } + if (dev->type != ARPHRD_CAN) { + dev_put(dev); + err = -ENODEV; + goto out; + } + if (dev->mtu < so->ll.mtu) { + dev_put(dev); + err = -EINVAL; + goto out; + } + if (!(dev->flags & IFF_UP)) + notify_enetdown = 1; + + ifindex = dev->ifindex; + + if (isotp_register_rxid(so)) + can_rx_register(net, dev, rx_id, SINGLE_MASK(rx_id), + isotp_rcv, sk, "isotp", sk); + + /* no consecutive frame echo skb in flight */ + so->cfecho = 0; + + /* register for echo skb's */ + can_rx_register(net, dev, tx_id, SINGLE_MASK(tx_id), + isotp_rcv_echo, sk, "isotpe", sk); + + dev_put(dev); + + /* switch to new settings */ + so->ifindex = ifindex; + so->rxid = rx_id; + so->txid = tx_id; + so->bound = 1; + +out: + release_sock(sk); + + if (notify_enetdown) { + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + + return err; +} + +static int isotp_getname(struct socket *sock, struct sockaddr *uaddr, int peer) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + + if (peer) + return -EOPNOTSUPP; + + memset(addr, 0, ISOTP_MIN_NAMELEN); + addr->can_family = AF_CAN; + addr->can_ifindex = so->ifindex; + addr->can_addr.tp.rx_id = so->rxid; + addr->can_addr.tp.tx_id = so->txid; + + return ISOTP_MIN_NAMELEN; +} + +static int isotp_setsockopt_locked(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + int ret = 0; + + if (so->bound) + return -EISCONN; + + switch (optname) { + case CAN_ISOTP_OPTS: + if (optlen != sizeof(struct can_isotp_options)) + return -EINVAL; + + if (copy_from_sockptr(&so->opt, optval, optlen)) + return -EFAULT; + + /* no separate rx_ext_address is given => use ext_address */ + if (!(so->opt.flags & CAN_ISOTP_RX_EXT_ADDR)) + so->opt.rx_ext_address = so->opt.ext_address; + + /* these broadcast flags are not allowed together */ + if (isotp_bc_flags(so) == ISOTP_ALL_BC_FLAGS) { + /* CAN_ISOTP_SF_BROADCAST is prioritized */ + so->opt.flags &= ~CAN_ISOTP_CF_BROADCAST; + + /* give user feedback on wrong config attempt */ + ret = -EINVAL; + } + + /* check for frame_txtime changes (0 => no changes) */ + if (so->opt.frame_txtime) { + if (so->opt.frame_txtime == CAN_ISOTP_FRAME_TXTIME_ZERO) + so->frame_txtime = 0; + else + so->frame_txtime = so->opt.frame_txtime; + } + break; + + case CAN_ISOTP_RECV_FC: + if (optlen != sizeof(struct can_isotp_fc_options)) + return -EINVAL; + + if (copy_from_sockptr(&so->rxfc, optval, optlen)) + return -EFAULT; + break; + + case CAN_ISOTP_TX_STMIN: + if (optlen != sizeof(u32)) + return -EINVAL; + + if (copy_from_sockptr(&so->force_tx_stmin, optval, optlen)) + return -EFAULT; + break; + + case CAN_ISOTP_RX_STMIN: + if (optlen != sizeof(u32)) + return -EINVAL; + + if (copy_from_sockptr(&so->force_rx_stmin, optval, optlen)) + return -EFAULT; + break; + + case CAN_ISOTP_LL_OPTS: + if (optlen == sizeof(struct can_isotp_ll_options)) { + struct can_isotp_ll_options ll; + + if (copy_from_sockptr(&ll, optval, optlen)) + return -EFAULT; + + /* check for correct ISO 11898-1 DLC data 
length */ + if (ll.tx_dl != padlen(ll.tx_dl)) + return -EINVAL; + + if (ll.mtu != CAN_MTU && ll.mtu != CANFD_MTU) + return -EINVAL; + + if (ll.mtu == CAN_MTU && + (ll.tx_dl > CAN_MAX_DLEN || ll.tx_flags != 0)) + return -EINVAL; + + memcpy(&so->ll, &ll, sizeof(ll)); + + /* set ll_dl for tx path to similar place as for rx */ + so->tx.ll_dl = ll.tx_dl; + } else { + return -EINVAL; + } + break; + + default: + ret = -ENOPROTOOPT; + } + + return ret; +} + +static int isotp_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) + +{ + struct sock *sk = sock->sk; + int ret; + + if (level != SOL_CAN_ISOTP) + return -EINVAL; + + lock_sock(sk); + ret = isotp_setsockopt_locked(sock, level, optname, optval, optlen); + release_sock(sk); + return ret; +} + +static int isotp_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + int len; + void *val; + + if (level != SOL_CAN_ISOTP) + return -EINVAL; + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + switch (optname) { + case CAN_ISOTP_OPTS: + len = min_t(int, len, sizeof(struct can_isotp_options)); + val = &so->opt; + break; + + case CAN_ISOTP_RECV_FC: + len = min_t(int, len, sizeof(struct can_isotp_fc_options)); + val = &so->rxfc; + break; + + case CAN_ISOTP_TX_STMIN: + len = min_t(int, len, sizeof(u32)); + val = &so->force_tx_stmin; + break; + + case CAN_ISOTP_RX_STMIN: + len = min_t(int, len, sizeof(u32)); + val = &so->force_rx_stmin; + break; + + case CAN_ISOTP_LL_OPTS: + len = min_t(int, len, sizeof(struct can_isotp_ll_options)); + val = &so->ll; + break; + + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, val, len)) + return -EFAULT; + return 0; +} + +static void isotp_notify(struct isotp_sock *so, unsigned long msg, + struct net_device *dev) +{ + struct sock *sk = &so->sk; + + if (!net_eq(dev_net(dev), sock_net(sk))) + return; + + if (so->ifindex != dev->ifindex) + return; + + switch (msg) { + case NETDEV_UNREGISTER: + lock_sock(sk); + /* remove current filters & unregister */ + if (so->bound) { + if (isotp_register_rxid(so)) + can_rx_unregister(dev_net(dev), dev, so->rxid, + SINGLE_MASK(so->rxid), + isotp_rcv, sk); + + can_rx_unregister(dev_net(dev), dev, so->txid, + SINGLE_MASK(so->txid), + isotp_rcv_echo, sk); + } + + so->ifindex = 0; + so->bound = 0; + release_sock(sk); + + sk->sk_err = ENODEV; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + break; + + case NETDEV_DOWN: + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + break; + } +} + +static int isotp_notifier(struct notifier_block *nb, unsigned long msg, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (dev->type != ARPHRD_CAN) + return NOTIFY_DONE; + if (msg != NETDEV_UNREGISTER && msg != NETDEV_DOWN) + return NOTIFY_DONE; + if (unlikely(isotp_busy_notifier)) /* Check for reentrant bug. 
*/ + return NOTIFY_DONE; + + spin_lock(&isotp_notifier_lock); + list_for_each_entry(isotp_busy_notifier, &isotp_notifier_list, notifier) { + spin_unlock(&isotp_notifier_lock); + isotp_notify(isotp_busy_notifier, msg, dev); + spin_lock(&isotp_notifier_lock); + } + isotp_busy_notifier = NULL; + spin_unlock(&isotp_notifier_lock); + return NOTIFY_DONE; +} + +static int isotp_init(struct sock *sk) +{ + struct isotp_sock *so = isotp_sk(sk); + + so->ifindex = 0; + so->bound = 0; + + so->opt.flags = CAN_ISOTP_DEFAULT_FLAGS; + so->opt.ext_address = CAN_ISOTP_DEFAULT_EXT_ADDRESS; + so->opt.rx_ext_address = CAN_ISOTP_DEFAULT_EXT_ADDRESS; + so->opt.rxpad_content = CAN_ISOTP_DEFAULT_PAD_CONTENT; + so->opt.txpad_content = CAN_ISOTP_DEFAULT_PAD_CONTENT; + so->opt.frame_txtime = CAN_ISOTP_DEFAULT_FRAME_TXTIME; + so->frame_txtime = CAN_ISOTP_DEFAULT_FRAME_TXTIME; + so->rxfc.bs = CAN_ISOTP_DEFAULT_RECV_BS; + so->rxfc.stmin = CAN_ISOTP_DEFAULT_RECV_STMIN; + so->rxfc.wftmax = CAN_ISOTP_DEFAULT_RECV_WFTMAX; + so->ll.mtu = CAN_ISOTP_DEFAULT_LL_MTU; + so->ll.tx_dl = CAN_ISOTP_DEFAULT_LL_TX_DL; + so->ll.tx_flags = CAN_ISOTP_DEFAULT_LL_TX_FLAGS; + + /* set ll_dl for tx path to similar place as for rx */ + so->tx.ll_dl = so->ll.tx_dl; + + so->rx.state = ISOTP_IDLE; + so->tx.state = ISOTP_IDLE; + + hrtimer_init(&so->rxtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); + so->rxtimer.function = isotp_rx_timer_handler; + hrtimer_init(&so->txtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); + so->txtimer.function = isotp_tx_timer_handler; + hrtimer_init(&so->txfrtimer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); + so->txfrtimer.function = isotp_txfr_timer_handler; + + init_waitqueue_head(&so->wait); + spin_lock_init(&so->rx_lock); + + spin_lock(&isotp_notifier_lock); + list_add_tail(&so->notifier, &isotp_notifier_list); + spin_unlock(&isotp_notifier_lock); + + return 0; +} + +static __poll_t isotp_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + struct sock *sk = sock->sk; + struct isotp_sock *so = isotp_sk(sk); + + __poll_t mask = datagram_poll(file, sock, wait); + poll_wait(file, &so->wait, wait); + + /* Check for false positives due to TX state */ + if ((mask & EPOLLWRNORM) && (so->tx.state != ISOTP_IDLE)) + mask &= ~(EPOLLOUT | EPOLLWRNORM); + + return mask; +} + +static int isotp_sock_no_ioctlcmd(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* no ioctls for socket layer -> hand it down to NIC layer */ + return -ENOIOCTLCMD; +} + +static const struct proto_ops isotp_ops = { + .family = PF_CAN, + .release = isotp_release, + .bind = isotp_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = isotp_getname, + .poll = isotp_poll, + .ioctl = isotp_sock_no_ioctlcmd, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = isotp_setsockopt, + .getsockopt = isotp_getsockopt, + .sendmsg = isotp_sendmsg, + .recvmsg = isotp_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct proto isotp_proto __read_mostly = { + .name = "CAN_ISOTP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct isotp_sock), + .init = isotp_init, +}; + +static const struct can_proto isotp_can_proto = { + .type = SOCK_DGRAM, + .protocol = CAN_ISOTP, + .ops = &isotp_ops, + .prot = &isotp_proto, +}; + +static struct notifier_block canisotp_notifier = { + .notifier_call = isotp_notifier +}; + +static __init int isotp_module_init(void) +{ + int err; + + pr_info("can: isotp protocol\n"); + 
+ err = can_proto_register(&isotp_can_proto); + if (err < 0) + pr_err("can: registration of isotp protocol failed %pe\n", ERR_PTR(err)); + else + register_netdevice_notifier(&canisotp_notifier); + + return err; +} + +static __exit void isotp_module_exit(void) +{ + can_proto_unregister(&isotp_can_proto); + unregister_netdevice_notifier(&canisotp_notifier); +} + +module_init(isotp_module_init); +module_exit(isotp_module_exit); diff --git a/net/can/j1939/Kconfig b/net/can/j1939/Kconfig new file mode 100644 index 000000000..2998298b7 --- /dev/null +++ b/net/can/j1939/Kconfig @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# SAE J1939 network layer core configuration +# + +config CAN_J1939 + tristate "SAE J1939" + depends on CAN + help + SAE J1939 + Say Y to have in-kernel support for j1939 socket type. This + allows communication according to SAE j1939. + The relevant parts in kernel are + SAE j1939-21 (datalink & transport protocol) + & SAE j1939-81 (network management). diff --git a/net/can/j1939/Makefile b/net/can/j1939/Makefile new file mode 100644 index 000000000..19181bdae --- /dev/null +++ b/net/can/j1939/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0 + +obj-$(CONFIG_CAN_J1939) += can-j1939.o + +can-j1939-objs := \ + address-claim.o \ + bus.o \ + main.o \ + socket.o \ + transport.o diff --git a/net/can/j1939/address-claim.c b/net/can/j1939/address-claim.c new file mode 100644 index 000000000..ca4ad6cdd --- /dev/null +++ b/net/can/j1939/address-claim.c @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (c) 2010-2011 EIA Electronics, +// Kurt Van Dijck +// Copyright (c) 2010-2011 EIA Electronics, +// Pieter Beyens +// Copyright (c) 2017-2019 Pengutronix, +// Marc Kleine-Budde +// Copyright (c) 2017-2019 Pengutronix, +// Oleksij Rempel + +/* J1939 Address Claiming. + * Address Claiming in the kernel + * - keeps track of the AC states of ECU's, + * - resolves NAME<=>SA taking into account the AC states of ECU's. + * + * All Address Claim msgs (including host-originated msg) are processed + * at the receive path (a sent msg is always received again via CAN echo). + * As such, the processing of AC msgs is done in the order on which msgs + * are sent on the bus. + * + * This module doesn't send msgs itself (e.g. replies on Address Claims), + * this is the responsibility of a user space application or daemon. 
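 *
 * A minimal sketch of such a user space claim transmission, assuming the
 * <linux/can/j1939.h> UAPI (the NAME, the source address and "can0" are
 * placeholders), binds a CAN_J1939 socket to the NAME and the SA to be
 * claimed and then broadcasts the NAME as 8 byte little endian payload of
 * the Address Claimed PGN, matching the checks in
 * j1939_ac_verify_outgoing() below:
 *
 *	struct sockaddr_can baddr = {
 *		.can_family = AF_CAN,
 *		.can_ifindex = if_nametoindex("can0"),
 *		.can_addr.j1939 = {
 *			.name = 0x0123456789abcdefULL,
 *			.addr = 0x42,
 *			.pgn  = J1939_NO_PGN,
 *		},
 *	}, caddr = {
 *		.can_family = AF_CAN,
 *		.can_addr.j1939 = {
 *			.name = J1939_NO_NAME,
 *			.addr = J1939_NO_ADDR,
 *			.pgn  = J1939_PGN_ADDRESS_CLAIMED,
 *		},
 *	};
 *	uint64_t le_name = htole64(baddr.can_addr.j1939.name);
 *	int bcast = 1, s = socket(PF_CAN, SOCK_DGRAM, CAN_J1939);
 *
 *	setsockopt(s, SOL_SOCKET, SO_BROADCAST, &bcast, sizeof(bcast));
 *	bind(s, (struct sockaddr *)&baddr, sizeof(baddr));
 *	sendto(s, &le_name, sizeof(le_name), 0,
 *	       (struct sockaddr *)&caddr, sizeof(caddr));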
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include + +#include "j1939-priv.h" + +static inline name_t j1939_skb_to_name(const struct sk_buff *skb) +{ + return le64_to_cpup((__le64 *)skb->data); +} + +static inline bool j1939_ac_msg_is_request(struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + int req_pgn; + + if (skb->len < 3 || skcb->addr.pgn != J1939_PGN_REQUEST) + return false; + + req_pgn = skb->data[0] | (skb->data[1] << 8) | (skb->data[2] << 16); + + return req_pgn == J1939_PGN_ADDRESS_CLAIMED; +} + +static int j1939_ac_verify_outgoing(struct j1939_priv *priv, + struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + + if (skb->len != 8) { + netdev_notice(priv->ndev, "tx address claim with dlc %i\n", + skb->len); + return -EPROTO; + } + + if (skcb->addr.src_name != j1939_skb_to_name(skb)) { + netdev_notice(priv->ndev, "tx address claim with different name\n"); + return -EPROTO; + } + + if (skcb->addr.sa == J1939_NO_ADDR) { + netdev_notice(priv->ndev, "tx address claim with broadcast sa\n"); + return -EPROTO; + } + + /* ac must always be a broadcast */ + if (skcb->addr.dst_name || skcb->addr.da != J1939_NO_ADDR) { + netdev_notice(priv->ndev, "tx address claim with dest, not broadcast\n"); + return -EPROTO; + } + return 0; +} + +int j1939_ac_fixup(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + int ret; + u8 addr; + + /* network mgmt: address claiming msgs */ + if (skcb->addr.pgn == J1939_PGN_ADDRESS_CLAIMED) { + struct j1939_ecu *ecu; + + ret = j1939_ac_verify_outgoing(priv, skb); + /* return both when failure & when successful */ + if (ret < 0) + return ret; + ecu = j1939_ecu_get_by_name(priv, skcb->addr.src_name); + if (!ecu) + return -ENODEV; + + if (ecu->addr != skcb->addr.sa) + /* hold further traffic for ecu, remove from parent */ + j1939_ecu_unmap(ecu); + j1939_ecu_put(ecu); + } else if (skcb->addr.src_name) { + /* assign source address */ + addr = j1939_name_to_addr(priv, skcb->addr.src_name); + if (!j1939_address_is_unicast(addr) && + !j1939_ac_msg_is_request(skb)) { + netdev_notice(priv->ndev, "tx drop: invalid sa for name 0x%016llx\n", + skcb->addr.src_name); + return -EADDRNOTAVAIL; + } + skcb->addr.sa = addr; + } + + /* assign destination address */ + if (skcb->addr.dst_name) { + addr = j1939_name_to_addr(priv, skcb->addr.dst_name); + if (!j1939_address_is_unicast(addr)) { + netdev_notice(priv->ndev, "tx drop: invalid da for name 0x%016llx\n", + skcb->addr.dst_name); + return -EADDRNOTAVAIL; + } + skcb->addr.da = addr; + } + return 0; +} + +static void j1939_ac_process(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_ecu *ecu, *prev; + name_t name; + + if (skb->len != 8) { + netdev_notice(priv->ndev, "rx address claim with wrong dlc %i\n", + skb->len); + return; + } + + name = j1939_skb_to_name(skb); + skcb->addr.src_name = name; + if (!name) { + netdev_notice(priv->ndev, "rx address claim without name\n"); + return; + } + + if (!j1939_address_is_valid(skcb->addr.sa)) { + netdev_notice(priv->ndev, "rx address claim with broadcast sa\n"); + return; + } + + write_lock_bh(&priv->lock); + + /* Few words on the ECU ref counting: + * + * First we get an ECU handle, either with + * j1939_ecu_get_by_name_locked() (increments the ref counter) + * or j1939_ecu_create_locked() (initializes an ECU object + * with a ref counter of 1). 
+ * + * j1939_ecu_unmap_locked() will decrement the ref counter, + * but only if the ECU was mapped before. So "ecu" still + * belongs to us. + * + * j1939_ecu_timer_start() will increment the ref counter + * before it starts the timer, so we can put the ecu when + * leaving this function. + */ + ecu = j1939_ecu_get_by_name_locked(priv, name); + + if (ecu && ecu->addr == skcb->addr.sa) { + /* The ISO 11783-5 standard, in "4.5.2 - Address claim + * requirements", states: + * d) No CF shall begin, or resume, transmission on the + * network until 250 ms after it has successfully claimed + * an address except when responding to a request for + * address-claimed. + * + * But "Figure 6" and "Figure 7" in "4.5.4.2 - Address-claim + * prioritization" show that the CF begins the transmission + * after 250 ms from the first AC (address-claimed) message + * even if it sends another AC message during that time window + * to resolve the address contention with another CF. + * + * As stated in "4.4.2.3 - Address-claimed message": + * In order to successfully claim an address, the CF sending + * an address claimed message shall not receive a contending + * claim from another CF for at least 250 ms. + * + * As stated in "4.4.3.2 - NAME management (NM) message": + * 1) A commanding CF can + * d) request that a CF with a specified NAME transmit + * the address-claimed message with its current NAME. + * 2) A target CF shall + * d) send an address-claimed message in response to a + * request for a matching NAME + * + * Taking the above arguments into account, the 250 ms wait is + * requested only during network initialization. + * + * Do not restart the timer on AC message if both the NAME and + * the address match and so if the address has already been + * claimed (timer has expired) or the AC message has been sent + * to resolve the contention with another CF (timer is still + * running). 
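 *
 * The contention case, where the same SA is claimed by a different NAME,
 * is handled further below: the ECU with the numerically lower NAME keeps
 * the address (see the comparison against "prev"), the other one is
 * unmapped.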
+ */ + goto out_ecu_put; + } + + if (!ecu && j1939_address_is_unicast(skcb->addr.sa)) + ecu = j1939_ecu_create_locked(priv, name); + + if (IS_ERR_OR_NULL(ecu)) + goto out_unlock_bh; + + /* cancel pending (previous) address claim */ + j1939_ecu_timer_cancel(ecu); + + if (j1939_address_is_idle(skcb->addr.sa)) { + j1939_ecu_unmap_locked(ecu); + goto out_ecu_put; + } + + /* save new addr */ + if (ecu->addr != skcb->addr.sa) + j1939_ecu_unmap_locked(ecu); + ecu->addr = skcb->addr.sa; + + prev = j1939_ecu_get_by_addr_locked(priv, skcb->addr.sa); + if (prev) { + if (ecu->name > prev->name) { + j1939_ecu_unmap_locked(ecu); + j1939_ecu_put(prev); + goto out_ecu_put; + } else { + /* kick prev if less or equal */ + j1939_ecu_unmap_locked(prev); + j1939_ecu_put(prev); + } + } + + j1939_ecu_timer_start(ecu); + out_ecu_put: + j1939_ecu_put(ecu); + out_unlock_bh: + write_unlock_bh(&priv->lock); +} + +void j1939_ac_recv(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_ecu *ecu; + + /* network mgmt */ + if (skcb->addr.pgn == J1939_PGN_ADDRESS_CLAIMED) { + j1939_ac_process(priv, skb); + } else if (j1939_address_is_unicast(skcb->addr.sa)) { + /* assign source name */ + ecu = j1939_ecu_get_by_addr(priv, skcb->addr.sa); + if (ecu) { + skcb->addr.src_name = ecu->name; + j1939_ecu_put(ecu); + } + } + + /* assign destination name */ + ecu = j1939_ecu_get_by_addr(priv, skcb->addr.da); + if (ecu) { + skcb->addr.dst_name = ecu->name; + j1939_ecu_put(ecu); + } +} diff --git a/net/can/j1939/bus.c b/net/can/j1939/bus.c new file mode 100644 index 000000000..486687901 --- /dev/null +++ b/net/can/j1939/bus.c @@ -0,0 +1,333 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (c) 2010-2011 EIA Electronics, +// Kurt Van Dijck +// Copyright (c) 2017-2019 Pengutronix, +// Marc Kleine-Budde +// Copyright (c) 2017-2019 Pengutronix, +// Oleksij Rempel + +/* bus for j1939 remote devices + * Since rtnetlink, no real bus is used. + */ + +#include + +#include "j1939-priv.h" + +static void __j1939_ecu_release(struct kref *kref) +{ + struct j1939_ecu *ecu = container_of(kref, struct j1939_ecu, kref); + struct j1939_priv *priv = ecu->priv; + + list_del(&ecu->list); + kfree(ecu); + j1939_priv_put(priv); +} + +void j1939_ecu_put(struct j1939_ecu *ecu) +{ + kref_put(&ecu->kref, __j1939_ecu_release); +} + +static void j1939_ecu_get(struct j1939_ecu *ecu) +{ + kref_get(&ecu->kref); +} + +static bool j1939_ecu_is_mapped_locked(struct j1939_ecu *ecu) +{ + struct j1939_priv *priv = ecu->priv; + + lockdep_assert_held(&priv->lock); + + return j1939_ecu_find_by_addr_locked(priv, ecu->addr) == ecu; +} + +/* ECU device interface */ +/* map ECU to a bus address space */ +static void j1939_ecu_map_locked(struct j1939_ecu *ecu) +{ + struct j1939_priv *priv = ecu->priv; + struct j1939_addr_ent *ent; + + lockdep_assert_held(&priv->lock); + + if (!j1939_address_is_unicast(ecu->addr)) + return; + + ent = &priv->ents[ecu->addr]; + + if (ent->ecu) { + netdev_warn(priv->ndev, "Trying to map already mapped ECU, addr: 0x%02x, name: 0x%016llx. 
Skip it.\n", + ecu->addr, ecu->name); + return; + } + + j1939_ecu_get(ecu); + ent->ecu = ecu; + ent->nusers += ecu->nusers; +} + +/* unmap ECU from a bus address space */ +void j1939_ecu_unmap_locked(struct j1939_ecu *ecu) +{ + struct j1939_priv *priv = ecu->priv; + struct j1939_addr_ent *ent; + + lockdep_assert_held(&priv->lock); + + if (!j1939_address_is_unicast(ecu->addr)) + return; + + if (!j1939_ecu_is_mapped_locked(ecu)) + return; + + ent = &priv->ents[ecu->addr]; + ent->ecu = NULL; + ent->nusers -= ecu->nusers; + j1939_ecu_put(ecu); +} + +void j1939_ecu_unmap(struct j1939_ecu *ecu) +{ + write_lock_bh(&ecu->priv->lock); + j1939_ecu_unmap_locked(ecu); + write_unlock_bh(&ecu->priv->lock); +} + +void j1939_ecu_unmap_all(struct j1939_priv *priv) +{ + int i; + + write_lock_bh(&priv->lock); + for (i = 0; i < ARRAY_SIZE(priv->ents); i++) + if (priv->ents[i].ecu) + j1939_ecu_unmap_locked(priv->ents[i].ecu); + write_unlock_bh(&priv->lock); +} + +void j1939_ecu_timer_start(struct j1939_ecu *ecu) +{ + /* The ECU is held here and released in the + * j1939_ecu_timer_handler() or j1939_ecu_timer_cancel(). + */ + j1939_ecu_get(ecu); + + /* Schedule timer in 250 msec to commit address change. */ + hrtimer_start(&ecu->ac_timer, ms_to_ktime(250), + HRTIMER_MODE_REL_SOFT); +} + +void j1939_ecu_timer_cancel(struct j1939_ecu *ecu) +{ + if (hrtimer_cancel(&ecu->ac_timer)) + j1939_ecu_put(ecu); +} + +static enum hrtimer_restart j1939_ecu_timer_handler(struct hrtimer *hrtimer) +{ + struct j1939_ecu *ecu = + container_of(hrtimer, struct j1939_ecu, ac_timer); + struct j1939_priv *priv = ecu->priv; + + write_lock_bh(&priv->lock); + /* TODO: can we test if ecu->addr is unicast before starting + * the timer? + */ + j1939_ecu_map_locked(ecu); + + /* The corresponding j1939_ecu_get() is in + * j1939_ecu_timer_start(). 
+ */ + j1939_ecu_put(ecu); + write_unlock_bh(&priv->lock); + + return HRTIMER_NORESTART; +} + +struct j1939_ecu *j1939_ecu_create_locked(struct j1939_priv *priv, name_t name) +{ + struct j1939_ecu *ecu; + + lockdep_assert_held(&priv->lock); + + ecu = kzalloc(sizeof(*ecu), gfp_any()); + if (!ecu) + return ERR_PTR(-ENOMEM); + kref_init(&ecu->kref); + ecu->addr = J1939_IDLE_ADDR; + ecu->name = name; + + hrtimer_init(&ecu->ac_timer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_SOFT); + ecu->ac_timer.function = j1939_ecu_timer_handler; + INIT_LIST_HEAD(&ecu->list); + + j1939_priv_get(priv); + ecu->priv = priv; + list_add_tail(&ecu->list, &priv->ecus); + + return ecu; +} + +struct j1939_ecu *j1939_ecu_find_by_addr_locked(struct j1939_priv *priv, + u8 addr) +{ + lockdep_assert_held(&priv->lock); + + return priv->ents[addr].ecu; +} + +struct j1939_ecu *j1939_ecu_get_by_addr_locked(struct j1939_priv *priv, u8 addr) +{ + struct j1939_ecu *ecu; + + lockdep_assert_held(&priv->lock); + + if (!j1939_address_is_unicast(addr)) + return NULL; + + ecu = j1939_ecu_find_by_addr_locked(priv, addr); + if (ecu) + j1939_ecu_get(ecu); + + return ecu; +} + +struct j1939_ecu *j1939_ecu_get_by_addr(struct j1939_priv *priv, u8 addr) +{ + struct j1939_ecu *ecu; + + read_lock_bh(&priv->lock); + ecu = j1939_ecu_get_by_addr_locked(priv, addr); + read_unlock_bh(&priv->lock); + + return ecu; +} + +/* get pointer to ecu without increasing ref counter */ +static struct j1939_ecu *j1939_ecu_find_by_name_locked(struct j1939_priv *priv, + name_t name) +{ + struct j1939_ecu *ecu; + + lockdep_assert_held(&priv->lock); + + list_for_each_entry(ecu, &priv->ecus, list) { + if (ecu->name == name) + return ecu; + } + + return NULL; +} + +struct j1939_ecu *j1939_ecu_get_by_name_locked(struct j1939_priv *priv, + name_t name) +{ + struct j1939_ecu *ecu; + + lockdep_assert_held(&priv->lock); + + if (!name) + return NULL; + + ecu = j1939_ecu_find_by_name_locked(priv, name); + if (ecu) + j1939_ecu_get(ecu); + + return ecu; +} + +struct j1939_ecu *j1939_ecu_get_by_name(struct j1939_priv *priv, name_t name) +{ + struct j1939_ecu *ecu; + + read_lock_bh(&priv->lock); + ecu = j1939_ecu_get_by_name_locked(priv, name); + read_unlock_bh(&priv->lock); + + return ecu; +} + +u8 j1939_name_to_addr(struct j1939_priv *priv, name_t name) +{ + struct j1939_ecu *ecu; + int addr = J1939_IDLE_ADDR; + + if (!name) + return J1939_NO_ADDR; + + read_lock_bh(&priv->lock); + ecu = j1939_ecu_find_by_name_locked(priv, name); + if (ecu && j1939_ecu_is_mapped_locked(ecu)) + /* ecu's SA is registered */ + addr = ecu->addr; + + read_unlock_bh(&priv->lock); + + return addr; +} + +/* TX addr/name accounting + * Transport protocol needs to know if a SA is local or not + * These functions originate from userspace manipulating sockets, + * so locking is straigforward + */ + +int j1939_local_ecu_get(struct j1939_priv *priv, name_t name, u8 sa) +{ + struct j1939_ecu *ecu; + int err = 0; + + write_lock_bh(&priv->lock); + + if (j1939_address_is_unicast(sa)) + priv->ents[sa].nusers++; + + if (!name) + goto done; + + ecu = j1939_ecu_get_by_name_locked(priv, name); + if (!ecu) + ecu = j1939_ecu_create_locked(priv, name); + err = PTR_ERR_OR_ZERO(ecu); + if (err) + goto done; + + ecu->nusers++; + /* TODO: do we care if ecu->addr != sa? 
*/ + if (j1939_ecu_is_mapped_locked(ecu)) + /* ecu's sa is active already */ + priv->ents[ecu->addr].nusers++; + + done: + write_unlock_bh(&priv->lock); + + return err; +} + +void j1939_local_ecu_put(struct j1939_priv *priv, name_t name, u8 sa) +{ + struct j1939_ecu *ecu; + + write_lock_bh(&priv->lock); + + if (j1939_address_is_unicast(sa)) + priv->ents[sa].nusers--; + + if (!name) + goto done; + + ecu = j1939_ecu_find_by_name_locked(priv, name); + if (WARN_ON_ONCE(!ecu)) + goto done; + + ecu->nusers--; + /* TODO: do we care if ecu->addr != sa? */ + if (j1939_ecu_is_mapped_locked(ecu)) + /* ecu's sa is active already */ + priv->ents[ecu->addr].nusers--; + j1939_ecu_put(ecu); + + done: + write_unlock_bh(&priv->lock); +} diff --git a/net/can/j1939/j1939-priv.h b/net/can/j1939/j1939-priv.h new file mode 100644 index 000000000..16af1a7f8 --- /dev/null +++ b/net/can/j1939/j1939-priv.h @@ -0,0 +1,343 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +// Copyright (c) 2010-2011 EIA Electronics, +// Kurt Van Dijck +// Copyright (c) 2017-2019 Pengutronix, +// Marc Kleine-Budde +// Copyright (c) 2017-2019 Pengutronix, +// Oleksij Rempel + +#ifndef _J1939_PRIV_H_ +#define _J1939_PRIV_H_ + +#include +#include + +/* Timeout to receive the abort signal over loop back. In case CAN + * bus is open, the timeout should be triggered. + */ +#define J1939_XTP_ABORT_TIMEOUT_MS 500 +#define J1939_SIMPLE_ECHO_TIMEOUT_MS (10 * 1000) + +struct j1939_session; +enum j1939_sk_errqueue_type { + J1939_ERRQUEUE_TX_ACK, + J1939_ERRQUEUE_TX_SCHED, + J1939_ERRQUEUE_TX_ABORT, + J1939_ERRQUEUE_RX_RTS, + J1939_ERRQUEUE_RX_DPO, + J1939_ERRQUEUE_RX_ABORT, +}; + +/* j1939 devices */ +struct j1939_ecu { + struct list_head list; + name_t name; + u8 addr; + + /* indicates that this ecu successfully claimed @sa as its address */ + struct hrtimer ac_timer; + struct kref kref; + struct j1939_priv *priv; + + /* count users, to help transport protocol decide for interaction */ + int nusers; +}; + +struct j1939_priv { + struct list_head ecus; + /* local list entry in priv + * These allow irq (& softirq) context lookups on j1939 devices + * This approach (separate lists) is done as the other 2 alternatives + * are not easier or even wrong + * 1) using the pure kobject methods involves mutexes, which are not + * allowed in irq context. + * 2) duplicating data structures would require a lot of synchronization + * code + * usage: + */ + + /* segments need a lock to protect the above list */ + rwlock_t lock; + + struct net_device *ndev; + + /* list of 256 ecu ptrs, that cache the claimed addresses. + * also protected by the above lock + */ + struct j1939_addr_ent { + struct j1939_ecu *ecu; + /* count users, to help transport protocol */ + int nusers; + } ents[256]; + + struct kref kref; + + /* List of active sessions to prevent start of conflicting + * one. + * + * Do not start two sessions of same type, addresses and + * direction. 
+ */ + struct list_head active_session_list; + + /* protects active_session_list */ + spinlock_t active_session_list_lock; + + unsigned int tp_max_packet_size; + + /* lock for j1939_socks list */ + spinlock_t j1939_socks_lock; + struct list_head j1939_socks; + + struct kref rx_kref; + u32 rx_tskey; +}; + +void j1939_ecu_put(struct j1939_ecu *ecu); + +/* keep the cache of what is local */ +int j1939_local_ecu_get(struct j1939_priv *priv, name_t name, u8 sa); +void j1939_local_ecu_put(struct j1939_priv *priv, name_t name, u8 sa); + +static inline bool j1939_address_is_unicast(u8 addr) +{ + return addr <= J1939_MAX_UNICAST_ADDR; +} + +static inline bool j1939_address_is_idle(u8 addr) +{ + return addr == J1939_IDLE_ADDR; +} + +static inline bool j1939_address_is_valid(u8 addr) +{ + return addr != J1939_NO_ADDR; +} + +static inline bool j1939_pgn_is_pdu1(pgn_t pgn) +{ + /* ignore dp & res bits for this */ + return (pgn & 0xff00) < 0xf000; +} + +/* utility to correctly unmap an ECU */ +void j1939_ecu_unmap_locked(struct j1939_ecu *ecu); +void j1939_ecu_unmap(struct j1939_ecu *ecu); + +u8 j1939_name_to_addr(struct j1939_priv *priv, name_t name); +struct j1939_ecu *j1939_ecu_find_by_addr_locked(struct j1939_priv *priv, + u8 addr); +struct j1939_ecu *j1939_ecu_get_by_addr(struct j1939_priv *priv, u8 addr); +struct j1939_ecu *j1939_ecu_get_by_addr_locked(struct j1939_priv *priv, + u8 addr); +struct j1939_ecu *j1939_ecu_get_by_name(struct j1939_priv *priv, name_t name); +struct j1939_ecu *j1939_ecu_get_by_name_locked(struct j1939_priv *priv, + name_t name); + +enum j1939_transfer_type { + J1939_TP, + J1939_ETP, + J1939_SIMPLE, +}; + +struct j1939_addr { + name_t src_name; + name_t dst_name; + pgn_t pgn; + + u8 sa; + u8 da; + + u8 type; +}; + +/* control buffer of the sk_buff */ +struct j1939_sk_buff_cb { + /* Offset in bytes within one ETP session */ + u32 offset; + + /* for tx, MSG_SYN will be used to sync on sockets */ + u32 msg_flags; + u32 tskey; + + struct j1939_addr addr; + + /* Flags for quick lookups during skb processing. + * These are set in the receive path only. 
+ */ +#define J1939_ECU_LOCAL_SRC BIT(0) +#define J1939_ECU_LOCAL_DST BIT(1) + u8 flags; + + priority_t priority; +}; + +static inline +struct j1939_sk_buff_cb *j1939_skb_to_cb(const struct sk_buff *skb) +{ + BUILD_BUG_ON(sizeof(struct j1939_sk_buff_cb) > sizeof(skb->cb)); + + return (struct j1939_sk_buff_cb *)skb->cb; +} + +int j1939_send_one(struct j1939_priv *priv, struct sk_buff *skb); +void j1939_sk_recv(struct j1939_priv *priv, struct sk_buff *skb); +bool j1939_sk_recv_match(struct j1939_priv *priv, + struct j1939_sk_buff_cb *skcb); +void j1939_sk_send_loop_abort(struct sock *sk, int err); +void j1939_sk_errqueue(struct j1939_session *session, + enum j1939_sk_errqueue_type type); +void j1939_sk_queue_activate_next(struct j1939_session *session); + +/* stack entries */ +struct j1939_session *j1939_tp_send(struct j1939_priv *priv, + struct sk_buff *skb, size_t size); +int j1939_tp_recv(struct j1939_priv *priv, struct sk_buff *skb); +int j1939_ac_fixup(struct j1939_priv *priv, struct sk_buff *skb); +void j1939_ac_recv(struct j1939_priv *priv, struct sk_buff *skb); +void j1939_simple_recv(struct j1939_priv *priv, struct sk_buff *skb); + +/* network management */ +struct j1939_ecu *j1939_ecu_create_locked(struct j1939_priv *priv, name_t name); + +void j1939_ecu_timer_start(struct j1939_ecu *ecu); +void j1939_ecu_timer_cancel(struct j1939_ecu *ecu); +void j1939_ecu_unmap_all(struct j1939_priv *priv); + +struct j1939_priv *j1939_netdev_start(struct net_device *ndev); +void j1939_netdev_stop(struct j1939_priv *priv); + +void j1939_priv_put(struct j1939_priv *priv); +void j1939_priv_get(struct j1939_priv *priv); + +/* notify/alert all j1939 sockets bound to ifindex */ +void j1939_sk_netdev_event_netdown(struct j1939_priv *priv); +int j1939_cancel_active_session(struct j1939_priv *priv, struct sock *sk); +void j1939_tp_init(struct j1939_priv *priv); + +/* decrement pending skb for a j1939 socket */ +void j1939_sock_pending_del(struct sock *sk); + +enum j1939_session_state { + J1939_SESSION_NEW, + J1939_SESSION_ACTIVE, + /* waiting for abort signal on the bus */ + J1939_SESSION_WAITING_ABORT, + J1939_SESSION_ACTIVE_MAX, + J1939_SESSION_DONE, +}; + +struct j1939_session { + struct j1939_priv *priv; + struct list_head active_session_list_entry; + struct list_head sk_session_queue_entry; + struct kref kref; + struct sock *sk; + + /* ifindex, src, dst, pgn define the session block + * the are _never_ modified after insertion in the list + * this decreases locking problems a _lot_ + */ + struct j1939_sk_buff_cb skcb; + struct sk_buff_head skb_queue; + + /* all tx related stuff (last_txcmd, pkt.tx) + * is protected (modified only) with the txtimer hrtimer + * 'total' & 'block' are never changed, + * last_cmd, last & block are protected by ->lock + * this means that the tx may run after cts is received that should + * have stopped tx, but this time discrepancy is never avoided anyhow + */ + u8 last_cmd, last_txcmd; + bool transmission; + bool extd; + /* Total message size, number of bytes */ + unsigned int total_message_size; + /* Total number of bytes queue from socket to the session */ + unsigned int total_queued_size; + unsigned int tx_retry; + + int err; + u32 tskey; + enum j1939_session_state state; + + /* Packets counters for a (extended) transfer session. The packet is + * maximal of 7 bytes. 
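 * With 7 data bytes per packet, a maximal classic TP transfer of
 * J1939_MAX_TP_PACKET_SIZE = 7 * 0xff = 1785 bytes therefore consists
 * of 255 packets, i.e. pkt.total = DIV_ROUND_UP(total_message_size, 7).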
+ */ + struct { + /* total - total number of packets for this session */ + unsigned int total; + /* last - last packet of a transfer block after which + * responder should send ETP.CM_CTS and originator + * ETP.CM_DPO + */ + unsigned int last; + /* tx - number of packets send by originator node. + * this counter can be set back if responder node + * didn't received all packets send by originator. + */ + unsigned int tx; + unsigned int tx_acked; + /* rx - number of packets received */ + unsigned int rx; + /* block - amount of packets expected in one block */ + unsigned int block; + /* dpo - ETP.CM_DPO, Data Packet Offset */ + unsigned int dpo; + } pkt; + struct hrtimer txtimer, rxtimer; +}; + +struct j1939_sock { + struct sock sk; /* must be first to skip with memset */ + struct j1939_priv *priv; + struct list_head list; + +#define J1939_SOCK_BOUND BIT(0) +#define J1939_SOCK_CONNECTED BIT(1) +#define J1939_SOCK_PROMISC BIT(2) +#define J1939_SOCK_ERRQUEUE BIT(3) + int state; + + int ifindex; + struct j1939_addr addr; + struct j1939_filter *filters; + int nfilters; + pgn_t pgn_rx_filter; + + /* j1939 may emit equal PGN (!= equal CAN-id's) out of order + * when transport protocol comes in. + * To allow emitting in order, keep a 'pending' nr. of packets + */ + atomic_t skb_pending; + wait_queue_head_t waitq; + + /* lock for the sk_session_queue list */ + spinlock_t sk_session_queue_lock; + struct list_head sk_session_queue; +}; + +static inline struct j1939_sock *j1939_sk(const struct sock *sk) +{ + return container_of(sk, struct j1939_sock, sk); +} + +void j1939_session_get(struct j1939_session *session); +void j1939_session_put(struct j1939_session *session); +void j1939_session_skb_queue(struct j1939_session *session, + struct sk_buff *skb); +int j1939_session_activate(struct j1939_session *session); +void j1939_tp_schedule_txtimer(struct j1939_session *session, int msec); +void j1939_session_timers_cancel(struct j1939_session *session); + +#define J1939_MIN_TP_PACKET_SIZE 9 +#define J1939_MAX_TP_PACKET_SIZE (7 * 0xff) +#define J1939_MAX_ETP_PACKET_SIZE (7 * 0x00ffffff) + +#define J1939_REGULAR 0 +#define J1939_EXTENDED 1 + +/* CAN protocol */ +extern const struct can_proto j1939_can_proto; + +#endif /* _J1939_PRIV_H_ */ diff --git a/net/can/j1939/main.c b/net/can/j1939/main.c new file mode 100644 index 000000000..ecff1c947 --- /dev/null +++ b/net/can/j1939/main.c @@ -0,0 +1,429 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (c) 2010-2011 EIA Electronics, +// Pieter Beyens +// Copyright (c) 2010-2011 EIA Electronics, +// Kurt Van Dijck +// Copyright (c) 2018 Protonic, +// Robin van der Gracht +// Copyright (c) 2017-2019 Pengutronix, +// Marc Kleine-Budde +// Copyright (c) 2017-2019 Pengutronix, +// Oleksij Rempel + +/* Core of can-j1939 that links j1939 to CAN. 
*/ + +#include +#include +#include +#include +#include + +#include "j1939-priv.h" + +MODULE_DESCRIPTION("PF_CAN SAE J1939"); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("EIA Electronics (Kurt Van Dijck & Pieter Beyens)"); +MODULE_ALIAS("can-proto-" __stringify(CAN_J1939)); + +/* LOWLEVEL CAN interface */ + +/* CAN_HDR: #bytes before can_frame data part */ +#define J1939_CAN_HDR (offsetof(struct can_frame, data)) + +/* CAN_FTR: #bytes beyond data part */ +#define J1939_CAN_FTR (sizeof(struct can_frame) - J1939_CAN_HDR - \ + sizeof(((struct can_frame *)0)->data)) + +/* lowest layer */ +static void j1939_can_recv(struct sk_buff *iskb, void *data) +{ + struct j1939_priv *priv = data; + struct sk_buff *skb; + struct j1939_sk_buff_cb *skcb, *iskcb; + struct can_frame *cf; + + /* make sure we only get Classical CAN frames */ + if (!can_is_can_skb(iskb)) + return; + + /* create a copy of the skb + * j1939 only delivers the real data bytes, + * the header goes into sockaddr. + * j1939 may not touch the incoming skb in such way + */ + skb = skb_clone(iskb, GFP_ATOMIC); + if (!skb) + return; + + j1939_priv_get(priv); + can_skb_set_owner(skb, iskb->sk); + + /* get a pointer to the header of the skb + * the skb payload (pointer) is moved, so that the next skb_data + * returns the actual payload + */ + cf = (void *)skb->data; + skb_pull(skb, J1939_CAN_HDR); + + /* fix length, set to dlc, with 8 maximum */ + skb_trim(skb, min_t(uint8_t, cf->len, 8)); + + /* set addr */ + skcb = j1939_skb_to_cb(skb); + memset(skcb, 0, sizeof(*skcb)); + + iskcb = j1939_skb_to_cb(iskb); + skcb->tskey = iskcb->tskey; + skcb->priority = (cf->can_id >> 26) & 0x7; + skcb->addr.sa = cf->can_id; + skcb->addr.pgn = (cf->can_id >> 8) & J1939_PGN_MAX; + /* set default message type */ + skcb->addr.type = J1939_TP; + + if (!j1939_address_is_valid(skcb->addr.sa)) { + netdev_err_once(priv->ndev, "%s: sa is broadcast address, ignoring!\n", + __func__); + goto done; + } + + if (j1939_pgn_is_pdu1(skcb->addr.pgn)) { + /* Type 1: with destination address */ + skcb->addr.da = skcb->addr.pgn; + /* normalize pgn: strip dst address */ + skcb->addr.pgn &= 0x3ff00; + } else { + /* set broadcast address */ + skcb->addr.da = J1939_NO_ADDR; + } + + /* update localflags */ + read_lock_bh(&priv->lock); + if (j1939_address_is_unicast(skcb->addr.sa) && + priv->ents[skcb->addr.sa].nusers) + skcb->flags |= J1939_ECU_LOCAL_SRC; + if (j1939_address_is_unicast(skcb->addr.da) && + priv->ents[skcb->addr.da].nusers) + skcb->flags |= J1939_ECU_LOCAL_DST; + read_unlock_bh(&priv->lock); + + /* deliver into the j1939 stack ... 
*/ + j1939_ac_recv(priv, skb); + + if (j1939_tp_recv(priv, skb)) + /* this means the transport layer processed the message */ + goto done; + + j1939_simple_recv(priv, skb); + j1939_sk_recv(priv, skb); + done: + j1939_priv_put(priv); + kfree_skb(skb); +} + +/* NETDEV MANAGEMENT */ + +/* values for can_rx_(un)register */ +#define J1939_CAN_ID CAN_EFF_FLAG +#define J1939_CAN_MASK (CAN_EFF_FLAG | CAN_RTR_FLAG) + +static DEFINE_MUTEX(j1939_netdev_lock); + +static struct j1939_priv *j1939_priv_create(struct net_device *ndev) +{ + struct j1939_priv *priv; + + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (!priv) + return NULL; + + rwlock_init(&priv->lock); + INIT_LIST_HEAD(&priv->ecus); + priv->ndev = ndev; + kref_init(&priv->kref); + kref_init(&priv->rx_kref); + dev_hold(ndev); + + netdev_dbg(priv->ndev, "%s : 0x%p\n", __func__, priv); + + return priv; +} + +static inline void j1939_priv_set(struct net_device *ndev, + struct j1939_priv *priv) +{ + struct can_ml_priv *can_ml = can_get_ml_priv(ndev); + + can_ml->j1939_priv = priv; +} + +static void __j1939_priv_release(struct kref *kref) +{ + struct j1939_priv *priv = container_of(kref, struct j1939_priv, kref); + struct net_device *ndev = priv->ndev; + + netdev_dbg(priv->ndev, "%s: 0x%p\n", __func__, priv); + + WARN_ON_ONCE(!list_empty(&priv->active_session_list)); + WARN_ON_ONCE(!list_empty(&priv->ecus)); + WARN_ON_ONCE(!list_empty(&priv->j1939_socks)); + + dev_put(ndev); + kfree(priv); +} + +void j1939_priv_put(struct j1939_priv *priv) +{ + kref_put(&priv->kref, __j1939_priv_release); +} + +void j1939_priv_get(struct j1939_priv *priv) +{ + kref_get(&priv->kref); +} + +static int j1939_can_rx_register(struct j1939_priv *priv) +{ + struct net_device *ndev = priv->ndev; + int ret; + + j1939_priv_get(priv); + ret = can_rx_register(dev_net(ndev), ndev, J1939_CAN_ID, J1939_CAN_MASK, + j1939_can_recv, priv, "j1939", NULL); + if (ret < 0) { + j1939_priv_put(priv); + return ret; + } + + return 0; +} + +static void j1939_can_rx_unregister(struct j1939_priv *priv) +{ + struct net_device *ndev = priv->ndev; + + can_rx_unregister(dev_net(ndev), ndev, J1939_CAN_ID, J1939_CAN_MASK, + j1939_can_recv, priv); + + /* The last reference of priv is dropped by the RCU deferred + * j1939_sk_sock_destruct() of the last socket, so we can + * safely drop this reference here. 
+ */ + j1939_priv_put(priv); +} + +static void __j1939_rx_release(struct kref *kref) + __releases(&j1939_netdev_lock) +{ + struct j1939_priv *priv = container_of(kref, struct j1939_priv, + rx_kref); + + j1939_can_rx_unregister(priv); + j1939_ecu_unmap_all(priv); + j1939_priv_set(priv->ndev, NULL); + mutex_unlock(&j1939_netdev_lock); +} + +/* get pointer to priv without increasing ref counter */ +static inline struct j1939_priv *j1939_ndev_to_priv(struct net_device *ndev) +{ + struct can_ml_priv *can_ml = can_get_ml_priv(ndev); + + return can_ml->j1939_priv; +} + +static struct j1939_priv *j1939_priv_get_by_ndev_locked(struct net_device *ndev) +{ + struct j1939_priv *priv; + + lockdep_assert_held(&j1939_netdev_lock); + + priv = j1939_ndev_to_priv(ndev); + if (priv) + j1939_priv_get(priv); + + return priv; +} + +static struct j1939_priv *j1939_priv_get_by_ndev(struct net_device *ndev) +{ + struct j1939_priv *priv; + + mutex_lock(&j1939_netdev_lock); + priv = j1939_priv_get_by_ndev_locked(ndev); + mutex_unlock(&j1939_netdev_lock); + + return priv; +} + +struct j1939_priv *j1939_netdev_start(struct net_device *ndev) +{ + struct j1939_priv *priv, *priv_new; + int ret; + + mutex_lock(&j1939_netdev_lock); + priv = j1939_priv_get_by_ndev_locked(ndev); + if (priv) { + kref_get(&priv->rx_kref); + mutex_unlock(&j1939_netdev_lock); + return priv; + } + mutex_unlock(&j1939_netdev_lock); + + priv = j1939_priv_create(ndev); + if (!priv) + return ERR_PTR(-ENOMEM); + + j1939_tp_init(priv); + spin_lock_init(&priv->j1939_socks_lock); + INIT_LIST_HEAD(&priv->j1939_socks); + + mutex_lock(&j1939_netdev_lock); + priv_new = j1939_priv_get_by_ndev_locked(ndev); + if (priv_new) { + /* Someone was faster than us, use their priv and roll + * back our's. + */ + kref_get(&priv_new->rx_kref); + mutex_unlock(&j1939_netdev_lock); + dev_put(ndev); + kfree(priv); + return priv_new; + } + j1939_priv_set(ndev, priv); + + ret = j1939_can_rx_register(priv); + if (ret < 0) + goto out_priv_put; + + mutex_unlock(&j1939_netdev_lock); + return priv; + + out_priv_put: + j1939_priv_set(ndev, NULL); + mutex_unlock(&j1939_netdev_lock); + + dev_put(ndev); + kfree(priv); + + return ERR_PTR(ret); +} + +void j1939_netdev_stop(struct j1939_priv *priv) +{ + kref_put_mutex(&priv->rx_kref, __j1939_rx_release, &j1939_netdev_lock); + j1939_priv_put(priv); +} + +int j1939_send_one(struct j1939_priv *priv, struct sk_buff *skb) +{ + int ret, dlc; + canid_t canid; + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct can_frame *cf; + + /* apply sanity checks */ + if (j1939_pgn_is_pdu1(skcb->addr.pgn)) + skcb->addr.pgn &= J1939_PGN_PDU1_MAX; + else + skcb->addr.pgn &= J1939_PGN_MAX; + + if (skcb->priority > 7) + skcb->priority = 6; + + ret = j1939_ac_fixup(priv, skb); + if (unlikely(ret)) + goto failed; + dlc = skb->len; + + /* re-claim the CAN_HDR from the SKB */ + cf = skb_push(skb, J1939_CAN_HDR); + + /* initialize header structure */ + memset(cf, 0, J1939_CAN_HDR); + + /* make it a full can frame again */ + skb_put(skb, J1939_CAN_FTR + (8 - dlc)); + + canid = CAN_EFF_FLAG | + (skcb->priority << 26) | + (skcb->addr.pgn << 8) | + skcb->addr.sa; + if (j1939_pgn_is_pdu1(skcb->addr.pgn)) + canid |= skcb->addr.da << 8; + + cf->can_id = canid; + cf->len = dlc; + + return can_send(skb, 1); + + failed: + kfree_skb(skb); + return ret; +} + +static int j1939_netdev_notify(struct notifier_block *nb, + unsigned long msg, void *data) +{ + struct net_device *ndev = netdev_notifier_info_to_dev(data); + struct can_ml_priv *can_ml = 
can_get_ml_priv(ndev); + struct j1939_priv *priv; + + if (!can_ml) + goto notify_done; + + priv = j1939_priv_get_by_ndev(ndev); + if (!priv) + goto notify_done; + + switch (msg) { + case NETDEV_DOWN: + j1939_cancel_active_session(priv, NULL); + j1939_sk_netdev_event_netdown(priv); + j1939_ecu_unmap_all(priv); + break; + } + + j1939_priv_put(priv); + +notify_done: + return NOTIFY_DONE; +} + +static struct notifier_block j1939_netdev_notifier = { + .notifier_call = j1939_netdev_notify, +}; + +/* MODULE interface */ +static __init int j1939_module_init(void) +{ + int ret; + + pr_info("can: SAE J1939\n"); + + ret = register_netdevice_notifier(&j1939_netdev_notifier); + if (ret) + goto fail_notifier; + + ret = can_proto_register(&j1939_can_proto); + if (ret < 0) { + pr_err("can: registration of j1939 protocol failed\n"); + goto fail_sk; + } + + return 0; + + fail_sk: + unregister_netdevice_notifier(&j1939_netdev_notifier); + fail_notifier: + return ret; +} + +static __exit void j1939_module_exit(void) +{ + can_proto_unregister(&j1939_can_proto); + + unregister_netdevice_notifier(&j1939_netdev_notifier); +} + +module_init(j1939_module_init); +module_exit(j1939_module_exit); diff --git a/net/can/j1939/socket.c b/net/can/j1939/socket.c new file mode 100644 index 000000000..b0be23559 --- /dev/null +++ b/net/can/j1939/socket.c @@ -0,0 +1,1326 @@ +// SPDX-License-Identifier: GPL-2.0 +// Copyright (c) 2010-2011 EIA Electronics, +// Pieter Beyens +// Copyright (c) 2010-2011 EIA Electronics, +// Kurt Van Dijck +// Copyright (c) 2018 Protonic, +// Robin van der Gracht +// Copyright (c) 2017-2019 Pengutronix, +// Marc Kleine-Budde +// Copyright (c) 2017-2019 Pengutronix, +// Oleksij Rempel + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include "j1939-priv.h" + +#define J1939_MIN_NAMELEN CAN_REQUIRED_SIZE(struct sockaddr_can, can_addr.j1939) + +/* conversion function between struct sock::sk_priority from linux and + * j1939 priority field + */ +static inline priority_t j1939_prio(u32 sk_priority) +{ + sk_priority = min(sk_priority, 7U); + + return 7 - sk_priority; +} + +static inline u32 j1939_to_sk_priority(priority_t prio) +{ + return 7 - prio; +} + +/* function to see if pgn is to be evaluated */ +static inline bool j1939_pgn_is_valid(pgn_t pgn) +{ + return pgn <= J1939_PGN_MAX; +} + +/* test function to avoid non-zero DA placeholder for pdu1 pgn's */ +static inline bool j1939_pgn_is_clean_pdu(pgn_t pgn) +{ + if (j1939_pgn_is_pdu1(pgn)) + return !(pgn & 0xff); + else + return true; +} + +static inline void j1939_sock_pending_add(struct sock *sk) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + atomic_inc(&jsk->skb_pending); +} + +static int j1939_sock_pending_get(struct sock *sk) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + return atomic_read(&jsk->skb_pending); +} + +void j1939_sock_pending_del(struct sock *sk) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + /* atomic_dec_return returns the new value */ + if (!atomic_dec_return(&jsk->skb_pending)) + wake_up(&jsk->waitq); /* no pending SKB's */ +} + +static void j1939_jsk_add(struct j1939_priv *priv, struct j1939_sock *jsk) +{ + jsk->state |= J1939_SOCK_BOUND; + j1939_priv_get(priv); + + spin_lock_bh(&priv->j1939_socks_lock); + list_add_tail(&jsk->list, &priv->j1939_socks); + spin_unlock_bh(&priv->j1939_socks_lock); +} + +static void j1939_jsk_del(struct j1939_priv *priv, struct j1939_sock *jsk) +{ + spin_lock_bh(&priv->j1939_socks_lock); + list_del_init(&jsk->list); + 
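/* Illustrative user-space sketch of the bind/sendto/recvfrom paths
 * implemented in this file; a minimal sketch only, assuming the standard
 * <linux/can/j1939.h> UAPI. "can0", the addresses and the PGN 0x12300 are
 * placeholders.
 */
#include <stdint.h>
#include <stdio.h>
#include <unistd.h>
#include <net/if.h>
#include <sys/socket.h>
#include <linux/can.h>
#include <linux/can/j1939.h>

int j1939_usage_sketch(void)
{
	struct sockaddr_can baddr = {
		.can_family = AF_CAN,
		.can_ifindex = if_nametoindex("can0"),
		.can_addr.j1939 = {
			.name = J1939_NO_NAME,	/* static addressing, no NAME */
			.addr = 0x20,		/* our source address */
			.pgn  = J1939_NO_PGN,	/* no RX PGN filter */
		},
	};
	struct sockaddr_can peer = {
		.can_family = AF_CAN,
		.can_addr.j1939 = {
			.name = J1939_NO_NAME,
			.addr = 0x30,		/* destination address */
			.pgn  = 0x12300,	/* PGN used for this datagram */
		},
	};
	struct sockaddr_can src;
	socklen_t slen = sizeof(src);
	uint8_t dat[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
	int s, n;

	s = socket(PF_CAN, SOCK_DGRAM, CAN_J1939);
	if (s < 0)
		return -1;

	/* bind() selects interface, source NAME/SA and optional RX PGN filter */
	if (bind(s, (struct sockaddr *)&baddr, sizeof(baddr)) < 0) {
		close(s);
		return -1;
	}

	/* one datagram per sendto(); payloads > 8 bytes use TP/ETP sessions */
	sendto(s, dat, sizeof(dat), 0, (struct sockaddr *)&peer, sizeof(peer));

	/* sender SA, NAME and PGN are returned in the sockaddr, see recvmsg() */
	n = recvfrom(s, dat, sizeof(dat), 0, (struct sockaddr *)&src, &slen);
	if (n > 0)
		printf("%d bytes, pgn 0x%05x from sa 0x%02x\n", n,
		       src.can_addr.j1939.pgn, src.can_addr.j1939.addr);

	close(s);
	return 0;
}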
spin_unlock_bh(&priv->j1939_socks_lock); + + j1939_priv_put(priv); + jsk->state &= ~J1939_SOCK_BOUND; +} + +static bool j1939_sk_queue_session(struct j1939_session *session) +{ + struct j1939_sock *jsk = j1939_sk(session->sk); + bool empty; + + spin_lock_bh(&jsk->sk_session_queue_lock); + empty = list_empty(&jsk->sk_session_queue); + j1939_session_get(session); + list_add_tail(&session->sk_session_queue_entry, &jsk->sk_session_queue); + spin_unlock_bh(&jsk->sk_session_queue_lock); + j1939_sock_pending_add(&jsk->sk); + + return empty; +} + +static struct +j1939_session *j1939_sk_get_incomplete_session(struct j1939_sock *jsk) +{ + struct j1939_session *session = NULL; + + spin_lock_bh(&jsk->sk_session_queue_lock); + if (!list_empty(&jsk->sk_session_queue)) { + session = list_last_entry(&jsk->sk_session_queue, + struct j1939_session, + sk_session_queue_entry); + if (session->total_queued_size == session->total_message_size) + session = NULL; + else + j1939_session_get(session); + } + spin_unlock_bh(&jsk->sk_session_queue_lock); + + return session; +} + +static void j1939_sk_queue_drop_all(struct j1939_priv *priv, + struct j1939_sock *jsk, int err) +{ + struct j1939_session *session, *tmp; + + netdev_dbg(priv->ndev, "%s: err: %i\n", __func__, err); + spin_lock_bh(&jsk->sk_session_queue_lock); + list_for_each_entry_safe(session, tmp, &jsk->sk_session_queue, + sk_session_queue_entry) { + list_del_init(&session->sk_session_queue_entry); + session->err = err; + j1939_session_put(session); + } + spin_unlock_bh(&jsk->sk_session_queue_lock); +} + +static void j1939_sk_queue_activate_next_locked(struct j1939_session *session) +{ + struct j1939_sock *jsk; + struct j1939_session *first; + int err; + + /* RX-Session don't have a socket (yet) */ + if (!session->sk) + return; + + jsk = j1939_sk(session->sk); + lockdep_assert_held(&jsk->sk_session_queue_lock); + + err = session->err; + + first = list_first_entry_or_null(&jsk->sk_session_queue, + struct j1939_session, + sk_session_queue_entry); + + /* Some else has already activated the next session */ + if (first != session) + return; + +activate_next: + list_del_init(&first->sk_session_queue_entry); + j1939_session_put(first); + first = list_first_entry_or_null(&jsk->sk_session_queue, + struct j1939_session, + sk_session_queue_entry); + if (!first) + return; + + if (j1939_session_activate(first)) { + netdev_warn_once(first->priv->ndev, + "%s: 0x%p: Identical session is already activated.\n", + __func__, first); + first->err = -EBUSY; + goto activate_next; + } else { + /* Give receiver some time (arbitrary chosen) to recover */ + int time_ms = 0; + + if (err) + time_ms = 10 + prandom_u32_max(16); + + j1939_tp_schedule_txtimer(first, time_ms); + } +} + +void j1939_sk_queue_activate_next(struct j1939_session *session) +{ + struct j1939_sock *jsk; + + if (!session->sk) + return; + + jsk = j1939_sk(session->sk); + + spin_lock_bh(&jsk->sk_session_queue_lock); + j1939_sk_queue_activate_next_locked(session); + spin_unlock_bh(&jsk->sk_session_queue_lock); +} + +static bool j1939_sk_match_dst(struct j1939_sock *jsk, + const struct j1939_sk_buff_cb *skcb) +{ + if ((jsk->state & J1939_SOCK_PROMISC)) + return true; + + /* Destination address filter */ + if (jsk->addr.src_name && skcb->addr.dst_name) { + if (jsk->addr.src_name != skcb->addr.dst_name) + return false; + } else { + /* receive (all sockets) if + * - all packages that match our bind() address + * - all broadcast on a socket if SO_BROADCAST + * is set + */ + if (j1939_address_is_unicast(skcb->addr.da)) { + if 
(jsk->addr.sa != skcb->addr.da) + return false; + } else if (!sock_flag(&jsk->sk, SOCK_BROADCAST)) { + /* receiving broadcast without SO_BROADCAST + * flag is not allowed + */ + return false; + } + } + + /* Source address filter */ + if (jsk->state & J1939_SOCK_CONNECTED) { + /* receive (all sockets) if + * - all packages that match our connect() name or address + */ + if (jsk->addr.dst_name && skcb->addr.src_name) { + if (jsk->addr.dst_name != skcb->addr.src_name) + return false; + } else { + if (jsk->addr.da != skcb->addr.sa) + return false; + } + } + + /* PGN filter */ + if (j1939_pgn_is_valid(jsk->pgn_rx_filter) && + jsk->pgn_rx_filter != skcb->addr.pgn) + return false; + + return true; +} + +/* matches skb control buffer (addr) with a j1939 filter */ +static bool j1939_sk_match_filter(struct j1939_sock *jsk, + const struct j1939_sk_buff_cb *skcb) +{ + const struct j1939_filter *f = jsk->filters; + int nfilter = jsk->nfilters; + + if (!nfilter) + /* receive all when no filters are assigned */ + return true; + + for (; nfilter; ++f, --nfilter) { + if ((skcb->addr.pgn & f->pgn_mask) != f->pgn) + continue; + if ((skcb->addr.sa & f->addr_mask) != f->addr) + continue; + if ((skcb->addr.src_name & f->name_mask) != f->name) + continue; + return true; + } + return false; +} + +static bool j1939_sk_recv_match_one(struct j1939_sock *jsk, + const struct j1939_sk_buff_cb *skcb) +{ + if (!(jsk->state & J1939_SOCK_BOUND)) + return false; + + if (!j1939_sk_match_dst(jsk, skcb)) + return false; + + if (!j1939_sk_match_filter(jsk, skcb)) + return false; + + return true; +} + +static void j1939_sk_recv_one(struct j1939_sock *jsk, struct sk_buff *oskb) +{ + const struct j1939_sk_buff_cb *oskcb = j1939_skb_to_cb(oskb); + struct j1939_sk_buff_cb *skcb; + struct sk_buff *skb; + + if (oskb->sk == &jsk->sk) + return; + + if (!j1939_sk_recv_match_one(jsk, oskcb)) + return; + + skb = skb_clone(oskb, GFP_ATOMIC); + if (!skb) { + pr_warn("skb clone failed\n"); + return; + } + can_skb_set_owner(skb, oskb->sk); + + skcb = j1939_skb_to_cb(skb); + skcb->msg_flags &= ~(MSG_DONTROUTE); + if (skb->sk) + skcb->msg_flags |= MSG_DONTROUTE; + + if (sock_queue_rcv_skb(&jsk->sk, skb) < 0) + kfree_skb(skb); +} + +bool j1939_sk_recv_match(struct j1939_priv *priv, struct j1939_sk_buff_cb *skcb) +{ + struct j1939_sock *jsk; + bool match = false; + + spin_lock_bh(&priv->j1939_socks_lock); + list_for_each_entry(jsk, &priv->j1939_socks, list) { + match = j1939_sk_recv_match_one(jsk, skcb); + if (match) + break; + } + spin_unlock_bh(&priv->j1939_socks_lock); + + return match; +} + +void j1939_sk_recv(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sock *jsk; + + spin_lock_bh(&priv->j1939_socks_lock); + list_for_each_entry(jsk, &priv->j1939_socks, list) { + j1939_sk_recv_one(jsk, skb); + } + spin_unlock_bh(&priv->j1939_socks_lock); +} + +static void j1939_sk_sock_destruct(struct sock *sk) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + /* This function will be called by the generic networking code, when + * the socket is ultimately closed (sk->sk_destruct). + * + * The race between + * - processing a received CAN frame + * (can_receive -> j1939_can_recv) + * and accessing j1939_priv + * ... and ... + * - closing a socket + * (j1939_can_rx_unregister -> can_rx_unregister) + * and calling the final j1939_priv_put() + * + * is avoided by calling the final j1939_priv_put() from this + * RCU deferred cleanup call. 
+ */ + if (jsk->priv) { + j1939_priv_put(jsk->priv); + jsk->priv = NULL; + } + + /* call generic CAN sock destruct */ + can_sock_destruct(sk); +} + +static int j1939_sk_init(struct sock *sk) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + /* Ensure that "sk" is first member in "struct j1939_sock", so that we + * can skip it during memset(). + */ + BUILD_BUG_ON(offsetof(struct j1939_sock, sk) != 0); + memset((void *)jsk + sizeof(jsk->sk), 0x0, + sizeof(*jsk) - sizeof(jsk->sk)); + + INIT_LIST_HEAD(&jsk->list); + init_waitqueue_head(&jsk->waitq); + jsk->sk.sk_priority = j1939_to_sk_priority(6); + jsk->sk.sk_reuse = 1; /* per default */ + jsk->addr.sa = J1939_NO_ADDR; + jsk->addr.da = J1939_NO_ADDR; + jsk->addr.pgn = J1939_NO_PGN; + jsk->pgn_rx_filter = J1939_NO_PGN; + atomic_set(&jsk->skb_pending, 0); + spin_lock_init(&jsk->sk_session_queue_lock); + INIT_LIST_HEAD(&jsk->sk_session_queue); + + /* j1939_sk_sock_destruct() depends on SOCK_RCU_FREE flag */ + sock_set_flag(sk, SOCK_RCU_FREE); + sk->sk_destruct = j1939_sk_sock_destruct; + sk->sk_protocol = CAN_J1939; + + return 0; +} + +static int j1939_sk_sanity_check(struct sockaddr_can *addr, int len) +{ + if (!addr) + return -EDESTADDRREQ; + if (len < J1939_MIN_NAMELEN) + return -EINVAL; + if (addr->can_family != AF_CAN) + return -EINVAL; + if (!addr->can_ifindex) + return -ENODEV; + if (j1939_pgn_is_valid(addr->can_addr.j1939.pgn) && + !j1939_pgn_is_clean_pdu(addr->can_addr.j1939.pgn)) + return -EINVAL; + + return 0; +} + +static int j1939_sk_bind(struct socket *sock, struct sockaddr *uaddr, int len) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct j1939_sock *jsk = j1939_sk(sock->sk); + struct j1939_priv *priv; + struct sock *sk; + struct net *net; + int ret = 0; + + ret = j1939_sk_sanity_check(addr, len); + if (ret) + return ret; + + lock_sock(sock->sk); + + priv = jsk->priv; + sk = sock->sk; + net = sock_net(sk); + + /* Already bound to an interface? */ + if (jsk->state & J1939_SOCK_BOUND) { + /* A re-bind() to a different interface is not + * supported. 
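 * Re-binding on the same interface, e.g. to change the NAME, the source
 * address or the RX PGN filter, is accepted: the old references are
 * dropped here and new ones are acquired further down.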
+ */ + if (jsk->ifindex != addr->can_ifindex) { + ret = -EINVAL; + goto out_release_sock; + } + + /* drop old references */ + j1939_jsk_del(priv, jsk); + j1939_local_ecu_put(priv, jsk->addr.src_name, jsk->addr.sa); + } else { + struct can_ml_priv *can_ml; + struct net_device *ndev; + + ndev = dev_get_by_index(net, addr->can_ifindex); + if (!ndev) { + ret = -ENODEV; + goto out_release_sock; + } + + can_ml = can_get_ml_priv(ndev); + if (!can_ml) { + dev_put(ndev); + ret = -ENODEV; + goto out_release_sock; + } + + if (!(ndev->flags & IFF_UP)) { + dev_put(ndev); + ret = -ENETDOWN; + goto out_release_sock; + } + + priv = j1939_netdev_start(ndev); + dev_put(ndev); + if (IS_ERR(priv)) { + ret = PTR_ERR(priv); + goto out_release_sock; + } + + jsk->ifindex = addr->can_ifindex; + + /* the corresponding j1939_priv_put() is called via + * sk->sk_destruct, which points to j1939_sk_sock_destruct() + */ + j1939_priv_get(priv); + jsk->priv = priv; + } + + /* set default transmit pgn */ + if (j1939_pgn_is_valid(addr->can_addr.j1939.pgn)) + jsk->pgn_rx_filter = addr->can_addr.j1939.pgn; + jsk->addr.src_name = addr->can_addr.j1939.name; + jsk->addr.sa = addr->can_addr.j1939.addr; + + /* get new references */ + ret = j1939_local_ecu_get(priv, jsk->addr.src_name, jsk->addr.sa); + if (ret) { + j1939_netdev_stop(priv); + goto out_release_sock; + } + + j1939_jsk_add(priv, jsk); + + out_release_sock: /* fall through */ + release_sock(sock->sk); + + return ret; +} + +static int j1939_sk_connect(struct socket *sock, struct sockaddr *uaddr, + int len, int flags) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct j1939_sock *jsk = j1939_sk(sock->sk); + int ret = 0; + + ret = j1939_sk_sanity_check(addr, len); + if (ret) + return ret; + + lock_sock(sock->sk); + + /* bind() before connect() is mandatory */ + if (!(jsk->state & J1939_SOCK_BOUND)) { + ret = -EINVAL; + goto out_release_sock; + } + + /* A connect() to a different interface is not supported. */ + if (jsk->ifindex != addr->can_ifindex) { + ret = -EINVAL; + goto out_release_sock; + } + + if (!addr->can_addr.j1939.name && + addr->can_addr.j1939.addr == J1939_NO_ADDR && + !sock_flag(&jsk->sk, SOCK_BROADCAST)) { + /* broadcast, but SO_BROADCAST not set */ + ret = -EACCES; + goto out_release_sock; + } + + jsk->addr.dst_name = addr->can_addr.j1939.name; + jsk->addr.da = addr->can_addr.j1939.addr; + + if (j1939_pgn_is_valid(addr->can_addr.j1939.pgn)) + jsk->addr.pgn = addr->can_addr.j1939.pgn; + + jsk->state |= J1939_SOCK_CONNECTED; + + out_release_sock: /* fall through */ + release_sock(sock->sk); + + return ret; +} + +static void j1939_sk_sock2sockaddr_can(struct sockaddr_can *addr, + const struct j1939_sock *jsk, int peer) +{ + /* There are two holes (2 bytes and 3 bytes) to clear to avoid + * leaking kernel information to user space. 
+ */ + memset(addr, 0, J1939_MIN_NAMELEN); + + addr->can_family = AF_CAN; + addr->can_ifindex = jsk->ifindex; + addr->can_addr.j1939.pgn = jsk->addr.pgn; + if (peer) { + addr->can_addr.j1939.name = jsk->addr.dst_name; + addr->can_addr.j1939.addr = jsk->addr.da; + } else { + addr->can_addr.j1939.name = jsk->addr.src_name; + addr->can_addr.j1939.addr = jsk->addr.sa; + } +} + +static int j1939_sk_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct j1939_sock *jsk = j1939_sk(sk); + int ret = 0; + + lock_sock(sk); + + if (peer && !(jsk->state & J1939_SOCK_CONNECTED)) { + ret = -EADDRNOTAVAIL; + goto failure; + } + + j1939_sk_sock2sockaddr_can(addr, jsk, peer); + ret = J1939_MIN_NAMELEN; + + failure: + release_sock(sk); + + return ret; +} + +static int j1939_sk_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct j1939_sock *jsk; + + if (!sk) + return 0; + + lock_sock(sk); + jsk = j1939_sk(sk); + + if (jsk->state & J1939_SOCK_BOUND) { + struct j1939_priv *priv = jsk->priv; + + if (wait_event_interruptible(jsk->waitq, + !j1939_sock_pending_get(&jsk->sk))) { + j1939_cancel_active_session(priv, sk); + j1939_sk_queue_drop_all(priv, jsk, ESHUTDOWN); + } + + j1939_jsk_del(priv, jsk); + + j1939_local_ecu_put(priv, jsk->addr.src_name, + jsk->addr.sa); + + j1939_netdev_stop(priv); + } + + kfree(jsk->filters); + sock_orphan(sk); + sock->sk = NULL; + + release_sock(sk); + sock_put(sk); + + return 0; +} + +static int j1939_sk_setsockopt_flag(struct j1939_sock *jsk, sockptr_t optval, + unsigned int optlen, int flag) +{ + int tmp; + + if (optlen != sizeof(tmp)) + return -EINVAL; + if (copy_from_sockptr(&tmp, optval, optlen)) + return -EFAULT; + lock_sock(&jsk->sk); + if (tmp) + jsk->state |= flag; + else + jsk->state &= ~flag; + release_sock(&jsk->sk); + return tmp; +} + +static int j1939_sk_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct j1939_sock *jsk = j1939_sk(sk); + int tmp, count = 0, ret = 0; + struct j1939_filter *filters = NULL, *ofilters; + + if (level != SOL_CAN_J1939) + return -EINVAL; + + switch (optname) { + case SO_J1939_FILTER: + if (!sockptr_is_null(optval) && optlen != 0) { + struct j1939_filter *f; + int c; + + if (optlen % sizeof(*filters) != 0) + return -EINVAL; + + if (optlen > J1939_FILTER_MAX * + sizeof(struct j1939_filter)) + return -EINVAL; + + count = optlen / sizeof(*filters); + filters = memdup_sockptr(optval, optlen); + if (IS_ERR(filters)) + return PTR_ERR(filters); + + for (f = filters, c = count; c; f++, c--) { + f->name &= f->name_mask; + f->pgn &= f->pgn_mask; + f->addr &= f->addr_mask; + } + } + + lock_sock(&jsk->sk); + ofilters = jsk->filters; + jsk->filters = filters; + jsk->nfilters = count; + release_sock(&jsk->sk); + kfree(ofilters); + return 0; + case SO_J1939_PROMISC: + return j1939_sk_setsockopt_flag(jsk, optval, optlen, + J1939_SOCK_PROMISC); + case SO_J1939_ERRQUEUE: + ret = j1939_sk_setsockopt_flag(jsk, optval, optlen, + J1939_SOCK_ERRQUEUE); + if (ret < 0) + return ret; + + if (!(jsk->state & J1939_SOCK_ERRQUEUE)) + skb_queue_purge(&sk->sk_error_queue); + return ret; + case SO_J1939_SEND_PRIO: + if (optlen != sizeof(tmp)) + return -EINVAL; + if (copy_from_sockptr(&tmp, optval, optlen)) + return -EFAULT; + if (tmp < 0 || tmp > 7) + return -EDOM; + if (tmp < 2 && !capable(CAP_NET_ADMIN)) + return -EPERM; + lock_sock(&jsk->sk); + 
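+ /* store as generic socket priority; j1939_prio() converts it back for outgoing skbs and getsockopt() */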
jsk->sk.sk_priority = j1939_to_sk_priority(tmp); + release_sock(&jsk->sk); + return 0; + default: + return -ENOPROTOOPT; + } +} + +static int j1939_sk_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct j1939_sock *jsk = j1939_sk(sk); + int ret, ulen; + /* set defaults for using 'int' properties */ + int tmp = 0; + int len = sizeof(tmp); + void *val = &tmp; + + if (level != SOL_CAN_J1939) + return -EINVAL; + if (get_user(ulen, optlen)) + return -EFAULT; + if (ulen < 0) + return -EINVAL; + + lock_sock(&jsk->sk); + switch (optname) { + case SO_J1939_PROMISC: + tmp = (jsk->state & J1939_SOCK_PROMISC) ? 1 : 0; + break; + case SO_J1939_ERRQUEUE: + tmp = (jsk->state & J1939_SOCK_ERRQUEUE) ? 1 : 0; + break; + case SO_J1939_SEND_PRIO: + tmp = j1939_prio(jsk->sk.sk_priority); + break; + default: + ret = -ENOPROTOOPT; + goto no_copy; + } + + /* copy to user, based on 'len' & 'val' + * but most sockopt's are 'int' properties, and have 'len' & 'val' + * left unchanged, but instead modified 'tmp' + */ + if (len > ulen) + ret = -EFAULT; + else if (put_user(len, optlen)) + ret = -EFAULT; + else if (copy_to_user(optval, val, len)) + ret = -EFAULT; + else + ret = 0; + no_copy: + release_sock(&jsk->sk); + return ret; +} + +static int j1939_sk_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + struct j1939_sk_buff_cb *skcb; + int ret = 0; + + if (flags & ~(MSG_DONTWAIT | MSG_ERRQUEUE | MSG_CMSG_COMPAT)) + return -EINVAL; + + if (flags & MSG_ERRQUEUE) + return sock_recv_errqueue(sock->sk, msg, size, SOL_CAN_J1939, + SCM_J1939_ERRQUEUE); + + skb = skb_recv_datagram(sk, flags, &ret); + if (!skb) + return ret; + + if (size < skb->len) + msg->msg_flags |= MSG_TRUNC; + else + size = skb->len; + + ret = memcpy_to_msg(msg, skb->data, size); + if (ret < 0) { + skb_free_datagram(sk, skb); + return ret; + } + + skcb = j1939_skb_to_cb(skb); + if (j1939_address_is_valid(skcb->addr.da)) + put_cmsg(msg, SOL_CAN_J1939, SCM_J1939_DEST_ADDR, + sizeof(skcb->addr.da), &skcb->addr.da); + + if (skcb->addr.dst_name) + put_cmsg(msg, SOL_CAN_J1939, SCM_J1939_DEST_NAME, + sizeof(skcb->addr.dst_name), &skcb->addr.dst_name); + + put_cmsg(msg, SOL_CAN_J1939, SCM_J1939_PRIO, + sizeof(skcb->priority), &skcb->priority); + + if (msg->msg_name) { + struct sockaddr_can *paddr = msg->msg_name; + + msg->msg_namelen = J1939_MIN_NAMELEN; + memset(msg->msg_name, 0, msg->msg_namelen); + paddr->can_family = AF_CAN; + paddr->can_ifindex = skb->skb_iif; + paddr->can_addr.j1939.name = skcb->addr.src_name; + paddr->can_addr.j1939.addr = skcb->addr.sa; + paddr->can_addr.j1939.pgn = skcb->addr.pgn; + } + + sock_recv_cmsgs(msg, sk, skb); + msg->msg_flags |= skcb->msg_flags; + skb_free_datagram(sk, skb); + + return size; +} + +static struct sk_buff *j1939_sk_alloc_skb(struct net_device *ndev, + struct sock *sk, + struct msghdr *msg, size_t size, + int *errcode) +{ + struct j1939_sock *jsk = j1939_sk(sk); + struct j1939_sk_buff_cb *skcb; + struct sk_buff *skb; + int ret; + + skb = sock_alloc_send_skb(sk, + size + + sizeof(struct can_frame) - + sizeof(((struct can_frame *)NULL)->data) + + sizeof(struct can_skb_priv), + msg->msg_flags & MSG_DONTWAIT, &ret); + if (!skb) + goto failure; + + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = ndev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + skb_reserve(skb, offsetof(struct can_frame, data)); + + ret = memcpy_from_msg(skb_put(skb, size), msg, 
size); + if (ret < 0) + goto free_skb; + + skb->dev = ndev; + + skcb = j1939_skb_to_cb(skb); + memset(skcb, 0, sizeof(*skcb)); + skcb->addr = jsk->addr; + skcb->priority = j1939_prio(sk->sk_priority); + + if (msg->msg_name) { + struct sockaddr_can *addr = msg->msg_name; + + if (addr->can_addr.j1939.name || + addr->can_addr.j1939.addr != J1939_NO_ADDR) { + skcb->addr.dst_name = addr->can_addr.j1939.name; + skcb->addr.da = addr->can_addr.j1939.addr; + } + if (j1939_pgn_is_valid(addr->can_addr.j1939.pgn)) + skcb->addr.pgn = addr->can_addr.j1939.pgn; + } + + *errcode = ret; + return skb; + +free_skb: + kfree_skb(skb); +failure: + *errcode = ret; + return NULL; +} + +static size_t j1939_sk_opt_stats_get_size(enum j1939_sk_errqueue_type type) +{ + switch (type) { + case J1939_ERRQUEUE_RX_RTS: + return + nla_total_size(sizeof(u32)) + /* J1939_NLA_TOTAL_SIZE */ + nla_total_size(sizeof(u32)) + /* J1939_NLA_PGN */ + nla_total_size(sizeof(u64)) + /* J1939_NLA_SRC_NAME */ + nla_total_size(sizeof(u64)) + /* J1939_NLA_DEST_NAME */ + nla_total_size(sizeof(u8)) + /* J1939_NLA_SRC_ADDR */ + nla_total_size(sizeof(u8)) + /* J1939_NLA_DEST_ADDR */ + 0; + default: + return + nla_total_size(sizeof(u32)) + /* J1939_NLA_BYTES_ACKED */ + 0; + } +} + +static struct sk_buff * +j1939_sk_get_timestamping_opt_stats(struct j1939_session *session, + enum j1939_sk_errqueue_type type) +{ + struct sk_buff *stats; + u32 size; + + stats = alloc_skb(j1939_sk_opt_stats_get_size(type), GFP_ATOMIC); + if (!stats) + return NULL; + + if (session->skcb.addr.type == J1939_SIMPLE) + size = session->total_message_size; + else + size = min(session->pkt.tx_acked * 7, + session->total_message_size); + + switch (type) { + case J1939_ERRQUEUE_RX_RTS: + nla_put_u32(stats, J1939_NLA_TOTAL_SIZE, + session->total_message_size); + nla_put_u32(stats, J1939_NLA_PGN, + session->skcb.addr.pgn); + nla_put_u64_64bit(stats, J1939_NLA_SRC_NAME, + session->skcb.addr.src_name, J1939_NLA_PAD); + nla_put_u64_64bit(stats, J1939_NLA_DEST_NAME, + session->skcb.addr.dst_name, J1939_NLA_PAD); + nla_put_u8(stats, J1939_NLA_SRC_ADDR, + session->skcb.addr.sa); + nla_put_u8(stats, J1939_NLA_DEST_ADDR, + session->skcb.addr.da); + break; + default: + nla_put_u32(stats, J1939_NLA_BYTES_ACKED, size); + } + + return stats; +} + +static void __j1939_sk_errqueue(struct j1939_session *session, struct sock *sk, + enum j1939_sk_errqueue_type type) +{ + struct j1939_priv *priv = session->priv; + struct j1939_sock *jsk; + struct sock_exterr_skb *serr; + struct sk_buff *skb; + char *state = "UNK"; + u32 tsflags; + int err; + + jsk = j1939_sk(sk); + + if (!(jsk->state & J1939_SOCK_ERRQUEUE)) + return; + + tsflags = READ_ONCE(sk->sk_tsflags); + switch (type) { + case J1939_ERRQUEUE_TX_ACK: + if (!(tsflags & SOF_TIMESTAMPING_TX_ACK)) + return; + break; + case J1939_ERRQUEUE_TX_SCHED: + if (!(tsflags & SOF_TIMESTAMPING_TX_SCHED)) + return; + break; + case J1939_ERRQUEUE_TX_ABORT: + break; + case J1939_ERRQUEUE_RX_RTS: + fallthrough; + case J1939_ERRQUEUE_RX_DPO: + fallthrough; + case J1939_ERRQUEUE_RX_ABORT: + if (!(tsflags & SOF_TIMESTAMPING_RX_SOFTWARE)) + return; + break; + default: + netdev_err(priv->ndev, "Unknown errqueue type %i\n", type); + } + + skb = j1939_sk_get_timestamping_opt_stats(session, type); + if (!skb) + return; + + skb->tstamp = ktime_get_real(); + + BUILD_BUG_ON(sizeof(struct sock_exterr_skb) > sizeof(skb->cb)); + + serr = SKB_EXT_ERR(skb); + memset(serr, 0, sizeof(*serr)); + switch (type) { + case J1939_ERRQUEUE_TX_ACK: + serr->ee.ee_errno = ENOMSG; + 
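+ /* ENOMSG is the usual errno for timestamp-only error-queue entries; it does not signal a failure */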
serr->ee.ee_origin = SO_EE_ORIGIN_TIMESTAMPING; + serr->ee.ee_info = SCM_TSTAMP_ACK; + state = "TX ACK"; + break; + case J1939_ERRQUEUE_TX_SCHED: + serr->ee.ee_errno = ENOMSG; + serr->ee.ee_origin = SO_EE_ORIGIN_TIMESTAMPING; + serr->ee.ee_info = SCM_TSTAMP_SCHED; + state = "TX SCH"; + break; + case J1939_ERRQUEUE_TX_ABORT: + serr->ee.ee_errno = session->err; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_info = J1939_EE_INFO_TX_ABORT; + state = "TX ABT"; + break; + case J1939_ERRQUEUE_RX_RTS: + serr->ee.ee_errno = ENOMSG; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_info = J1939_EE_INFO_RX_RTS; + state = "RX RTS"; + break; + case J1939_ERRQUEUE_RX_DPO: + serr->ee.ee_errno = ENOMSG; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_info = J1939_EE_INFO_RX_DPO; + state = "RX DPO"; + break; + case J1939_ERRQUEUE_RX_ABORT: + serr->ee.ee_errno = session->err; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_info = J1939_EE_INFO_RX_ABORT; + state = "RX ABT"; + break; + } + + serr->opt_stats = true; + if (tsflags & SOF_TIMESTAMPING_OPT_ID) + serr->ee.ee_data = session->tskey; + + netdev_dbg(session->priv->ndev, "%s: 0x%p tskey: %i, state: %s\n", + __func__, session, session->tskey, state); + err = sock_queue_err_skb(sk, skb); + + if (err) + kfree_skb(skb); +}; + +void j1939_sk_errqueue(struct j1939_session *session, + enum j1939_sk_errqueue_type type) +{ + struct j1939_priv *priv = session->priv; + struct j1939_sock *jsk; + + if (session->sk) { + /* send TX notifications to the socket of origin */ + __j1939_sk_errqueue(session, session->sk, type); + return; + } + + /* spread RX notifications to all sockets subscribed to this session */ + spin_lock_bh(&priv->j1939_socks_lock); + list_for_each_entry(jsk, &priv->j1939_socks, list) { + if (j1939_sk_recv_match_one(jsk, &session->skcb)) + __j1939_sk_errqueue(session, &jsk->sk, type); + } + spin_unlock_bh(&priv->j1939_socks_lock); +}; + +void j1939_sk_send_loop_abort(struct sock *sk, int err) +{ + struct j1939_sock *jsk = j1939_sk(sk); + + if (jsk->state & J1939_SOCK_ERRQUEUE) + return; + + sk->sk_err = err; + + sk_error_report(sk); +} + +static int j1939_sk_send_loop(struct j1939_priv *priv, struct sock *sk, + struct msghdr *msg, size_t size) + +{ + struct j1939_sock *jsk = j1939_sk(sk); + struct j1939_session *session = j1939_sk_get_incomplete_session(jsk); + struct sk_buff *skb; + size_t segment_size, todo_size; + int ret = 0; + + if (session && + session->total_message_size != session->total_queued_size + size) { + j1939_session_put(session); + return -EIO; + } + + todo_size = size; + + while (todo_size) { + struct j1939_sk_buff_cb *skcb; + + segment_size = min_t(size_t, J1939_MAX_TP_PACKET_SIZE, + todo_size); + + /* Allocate skb for one segment */ + skb = j1939_sk_alloc_skb(priv->ndev, sk, msg, segment_size, + &ret); + if (ret) + break; + + skcb = j1939_skb_to_cb(skb); + + if (!session) { + /* at this point the size should be full size + * of the session + */ + skcb->offset = 0; + session = j1939_tp_send(priv, skb, size); + if (IS_ERR(session)) { + ret = PTR_ERR(session); + goto kfree_skb; + } + if (j1939_sk_queue_session(session)) { + /* try to activate session if we a + * fist in the queue + */ + if (!j1939_session_activate(session)) { + j1939_tp_schedule_txtimer(session, 0); + } else { + ret = -EBUSY; + session->err = ret; + j1939_sk_queue_drop_all(priv, jsk, + EBUSY); + break; + } + } + } else { + skcb->offset = session->total_queued_size; + j1939_session_skb_queue(session, skb); + } + + todo_size -= 
segment_size; + session->total_queued_size += segment_size; + } + + switch (ret) { + case 0: /* OK */ + if (todo_size) + netdev_warn(priv->ndev, + "no error found and not completely queued?! %zu\n", + todo_size); + ret = size; + break; + case -ERESTARTSYS: + ret = -EINTR; + fallthrough; + case -EAGAIN: /* OK */ + if (todo_size != size) + ret = size - todo_size; + break; + default: /* ERROR */ + break; + } + + if (session) + j1939_session_put(session); + + return ret; + + kfree_skb: + kfree_skb(skb); + return ret; +} + +static int j1939_sk_sendmsg(struct socket *sock, struct msghdr *msg, + size_t size) +{ + struct sock *sk = sock->sk; + struct j1939_sock *jsk = j1939_sk(sk); + struct j1939_priv *priv; + int ifindex; + int ret; + + lock_sock(sock->sk); + /* various socket state tests */ + if (!(jsk->state & J1939_SOCK_BOUND)) { + ret = -EBADFD; + goto sendmsg_done; + } + + priv = jsk->priv; + ifindex = jsk->ifindex; + + if (!jsk->addr.src_name && jsk->addr.sa == J1939_NO_ADDR) { + /* no source address assigned yet */ + ret = -EBADFD; + goto sendmsg_done; + } + + /* deal with provided destination address info */ + if (msg->msg_name) { + struct sockaddr_can *addr = msg->msg_name; + + if (msg->msg_namelen < J1939_MIN_NAMELEN) { + ret = -EINVAL; + goto sendmsg_done; + } + + if (addr->can_family != AF_CAN) { + ret = -EINVAL; + goto sendmsg_done; + } + + if (addr->can_ifindex && addr->can_ifindex != ifindex) { + ret = -EBADFD; + goto sendmsg_done; + } + + if (j1939_pgn_is_valid(addr->can_addr.j1939.pgn) && + !j1939_pgn_is_clean_pdu(addr->can_addr.j1939.pgn)) { + ret = -EINVAL; + goto sendmsg_done; + } + + if (!addr->can_addr.j1939.name && + addr->can_addr.j1939.addr == J1939_NO_ADDR && + !sock_flag(sk, SOCK_BROADCAST)) { + /* broadcast, but SO_BROADCAST not set */ + ret = -EACCES; + goto sendmsg_done; + } + } else { + if (!jsk->addr.dst_name && jsk->addr.da == J1939_NO_ADDR && + !sock_flag(sk, SOCK_BROADCAST)) { + /* broadcast, but SO_BROADCAST not set */ + ret = -EACCES; + goto sendmsg_done; + } + } + + ret = j1939_sk_send_loop(priv, sk, msg, size); + +sendmsg_done: + release_sock(sock->sk); + + return ret; +} + +void j1939_sk_netdev_event_netdown(struct j1939_priv *priv) +{ + struct j1939_sock *jsk; + int error_code = ENETDOWN; + + spin_lock_bh(&priv->j1939_socks_lock); + list_for_each_entry(jsk, &priv->j1939_socks, list) { + jsk->sk.sk_err = error_code; + if (!sock_flag(&jsk->sk, SOCK_DEAD)) + sk_error_report(&jsk->sk); + + j1939_sk_queue_drop_all(priv, jsk, error_code); + } + spin_unlock_bh(&priv->j1939_socks_lock); +} + +static int j1939_sk_no_ioctlcmd(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* no ioctls for socket layer -> hand it down to NIC layer */ + return -ENOIOCTLCMD; +} + +static const struct proto_ops j1939_ops = { + .family = PF_CAN, + .release = j1939_sk_release, + .bind = j1939_sk_bind, + .connect = j1939_sk_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = j1939_sk_getname, + .poll = datagram_poll, + .ioctl = j1939_sk_no_ioctlcmd, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = j1939_sk_setsockopt, + .getsockopt = j1939_sk_getsockopt, + .sendmsg = j1939_sk_sendmsg, + .recvmsg = j1939_sk_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct proto j1939_proto __read_mostly = { + .name = "CAN_J1939", + .owner = THIS_MODULE, + .obj_size = sizeof(struct j1939_sock), + .init = j1939_sk_init, +}; + +const struct can_proto j1939_can_proto = { + .type = SOCK_DGRAM, + 
.protocol = CAN_J1939,
+ .ops = &j1939_ops,
+ .prot = &j1939_proto,
+};
diff --git a/net/can/j1939/transport.c b/net/can/j1939/transport.c
new file mode 100644
index 000000000..bd8ec2433
--- /dev/null
+++ b/net/can/j1939/transport.c
@@ -0,0 +1,2206 @@
+// SPDX-License-Identifier: GPL-2.0
+// Copyright (c) 2010-2011 EIA Electronics,
+// Kurt Van Dijck
+// Copyright (c) 2018 Protonic,
+// Robin van der Gracht
+// Copyright (c) 2017-2019 Pengutronix,
+// Marc Kleine-Budde
+// Copyright (c) 2017-2019 Pengutronix,
+// Oleksij Rempel
+
+#include <linux/can/skb.h>
+
+#include "j1939-priv.h"
+
+#define J1939_XTP_TX_RETRY_LIMIT 100
+
+#define J1939_ETP_PGN_CTL 0xc800
+#define J1939_ETP_PGN_DAT 0xc700
+#define J1939_TP_PGN_CTL 0xec00
+#define J1939_TP_PGN_DAT 0xeb00
+
+#define J1939_TP_CMD_RTS 0x10
+#define J1939_TP_CMD_CTS 0x11
+#define J1939_TP_CMD_EOMA 0x13
+#define J1939_TP_CMD_BAM 0x20
+#define J1939_TP_CMD_ABORT 0xff
+
+#define J1939_ETP_CMD_RTS 0x14
+#define J1939_ETP_CMD_CTS 0x15
+#define J1939_ETP_CMD_DPO 0x16
+#define J1939_ETP_CMD_EOMA 0x17
+#define J1939_ETP_CMD_ABORT 0xff
+
+enum j1939_xtp_abort {
+ J1939_XTP_NO_ABORT = 0,
+ J1939_XTP_ABORT_BUSY = 1,
+ /* Already in one or more connection managed sessions and
+ * cannot support another.
+ *
+ * EALREADY:
+ * Operation already in progress
+ */
+
+ J1939_XTP_ABORT_RESOURCE = 2,
+ /* System resources were needed for another task so this
+ * connection managed session was terminated.
+ *
+ * EMSGSIZE:
+ * The socket type requires that message be sent atomically,
+ * and the size of the message to be sent made this
+ * impossible.
+ */
+
+ J1939_XTP_ABORT_TIMEOUT = 3,
+ /* A timeout occurred and this is the connection abort to
+ * close the session.
+ *
+ * EHOSTUNREACH:
+ * The destination host cannot be reached (probably because
+ * the host is down or a remote router cannot reach it).
+ */
+
+ J1939_XTP_ABORT_GENERIC = 4,
+ /* CTS messages received when data transfer is in progress
+ *
+ * EBADMSG:
+ * Not a data message
+ */
+
+ J1939_XTP_ABORT_FAULT = 5,
+ /* Maximal retransmit request limit reached
+ *
+ * ENOTRECOVERABLE:
+ * State not recoverable
+ */
+
+ J1939_XTP_ABORT_UNEXPECTED_DATA = 6,
+ /* Unexpected data transfer packet
+ *
+ * ENOTCONN:
+ * Transport endpoint is not connected
+ */
+
+ J1939_XTP_ABORT_BAD_SEQ = 7,
+ /* Bad sequence number (and software is not able to recover)
+ *
+ * EILSEQ:
+ * Illegal byte sequence
+ */
+
+ J1939_XTP_ABORT_DUP_SEQ = 8,
+ /* Duplicate sequence number (and software is not able to
+ * recover)
+ */
+
+ J1939_XTP_ABORT_EDPO_UNEXPECTED = 9,
+ /* Unexpected EDPO packet (ETP) or Message size > 1785 bytes
+ * (TP)
+ */
+
+ J1939_XTP_ABORT_BAD_EDPO_PGN = 10,
+ /* Unexpected EDPO PGN (PGN in EDPO is bad) */
+
+ J1939_XTP_ABORT_EDPO_OUTOF_CTS = 11,
+ /* EDPO number of packets is greater than CTS */
+
+ J1939_XTP_ABORT_BAD_EDPO_OFFSET = 12,
+ /* Bad EDPO offset */
+
+ J1939_XTP_ABORT_OTHER_DEPRECATED = 13,
+ /* Deprecated. 
Use 250 instead (Any other reason) */ + + J1939_XTP_ABORT_ECTS_UNXPECTED_PGN = 14, + /* Unexpected ECTS PGN (PGN in ECTS is bad) */ + + J1939_XTP_ABORT_ECTS_TOO_BIG = 15, + /* ECTS requested packets exceeds message size */ + + J1939_XTP_ABORT_OTHER = 250, + /* Any other reason (if a Connection Abort reason is + * identified that is not listed in the table use code 250) + */ +}; + +static unsigned int j1939_tp_block = 255; +static unsigned int j1939_tp_packet_delay; +static unsigned int j1939_tp_padding = 1; + +/* helpers */ +static const char *j1939_xtp_abort_to_str(enum j1939_xtp_abort abort) +{ + switch (abort) { + case J1939_XTP_ABORT_BUSY: + return "Already in one or more connection managed sessions and cannot support another."; + case J1939_XTP_ABORT_RESOURCE: + return "System resources were needed for another task so this connection managed session was terminated."; + case J1939_XTP_ABORT_TIMEOUT: + return "A timeout occurred and this is the connection abort to close the session."; + case J1939_XTP_ABORT_GENERIC: + return "CTS messages received when data transfer is in progress"; + case J1939_XTP_ABORT_FAULT: + return "Maximal retransmit request limit reached"; + case J1939_XTP_ABORT_UNEXPECTED_DATA: + return "Unexpected data transfer packet"; + case J1939_XTP_ABORT_BAD_SEQ: + return "Bad sequence number (and software is not able to recover)"; + case J1939_XTP_ABORT_DUP_SEQ: + return "Duplicate sequence number (and software is not able to recover)"; + case J1939_XTP_ABORT_EDPO_UNEXPECTED: + return "Unexpected EDPO packet (ETP) or Message size > 1785 bytes (TP)"; + case J1939_XTP_ABORT_BAD_EDPO_PGN: + return "Unexpected EDPO PGN (PGN in EDPO is bad)"; + case J1939_XTP_ABORT_EDPO_OUTOF_CTS: + return "EDPO number of packets is greater than CTS"; + case J1939_XTP_ABORT_BAD_EDPO_OFFSET: + return "Bad EDPO offset"; + case J1939_XTP_ABORT_OTHER_DEPRECATED: + return "Deprecated. 
Use 250 instead (Any other reason)"; + case J1939_XTP_ABORT_ECTS_UNXPECTED_PGN: + return "Unexpected ECTS PGN (PGN in ECTS is bad)"; + case J1939_XTP_ABORT_ECTS_TOO_BIG: + return "ECTS requested packets exceeds message size"; + case J1939_XTP_ABORT_OTHER: + return "Any other reason (if a Connection Abort reason is identified that is not listed in the table use code 250)"; + default: + return ""; + } +} + +static int j1939_xtp_abort_to_errno(struct j1939_priv *priv, + enum j1939_xtp_abort abort) +{ + int err; + + switch (abort) { + case J1939_XTP_NO_ABORT: + WARN_ON_ONCE(abort == J1939_XTP_NO_ABORT); + err = 0; + break; + case J1939_XTP_ABORT_BUSY: + err = EALREADY; + break; + case J1939_XTP_ABORT_RESOURCE: + err = EMSGSIZE; + break; + case J1939_XTP_ABORT_TIMEOUT: + err = EHOSTUNREACH; + break; + case J1939_XTP_ABORT_GENERIC: + err = EBADMSG; + break; + case J1939_XTP_ABORT_FAULT: + err = ENOTRECOVERABLE; + break; + case J1939_XTP_ABORT_UNEXPECTED_DATA: + err = ENOTCONN; + break; + case J1939_XTP_ABORT_BAD_SEQ: + err = EILSEQ; + break; + case J1939_XTP_ABORT_DUP_SEQ: + err = EPROTO; + break; + case J1939_XTP_ABORT_EDPO_UNEXPECTED: + err = EPROTO; + break; + case J1939_XTP_ABORT_BAD_EDPO_PGN: + err = EPROTO; + break; + case J1939_XTP_ABORT_EDPO_OUTOF_CTS: + err = EPROTO; + break; + case J1939_XTP_ABORT_BAD_EDPO_OFFSET: + err = EPROTO; + break; + case J1939_XTP_ABORT_OTHER_DEPRECATED: + err = EPROTO; + break; + case J1939_XTP_ABORT_ECTS_UNXPECTED_PGN: + err = EPROTO; + break; + case J1939_XTP_ABORT_ECTS_TOO_BIG: + err = EPROTO; + break; + case J1939_XTP_ABORT_OTHER: + err = EPROTO; + break; + default: + netdev_warn(priv->ndev, "Unknown abort code %i", abort); + err = EPROTO; + } + + return err; +} + +static inline void j1939_session_list_lock(struct j1939_priv *priv) +{ + spin_lock_bh(&priv->active_session_list_lock); +} + +static inline void j1939_session_list_unlock(struct j1939_priv *priv) +{ + spin_unlock_bh(&priv->active_session_list_lock); +} + +void j1939_session_get(struct j1939_session *session) +{ + kref_get(&session->kref); +} + +/* session completion functions */ +static void __j1939_session_drop(struct j1939_session *session) +{ + if (!session->transmission) + return; + + j1939_sock_pending_del(session->sk); + sock_put(session->sk); +} + +static void j1939_session_destroy(struct j1939_session *session) +{ + struct sk_buff *skb; + + if (session->transmission) { + if (session->err) + j1939_sk_errqueue(session, J1939_ERRQUEUE_TX_ABORT); + else + j1939_sk_errqueue(session, J1939_ERRQUEUE_TX_ACK); + } else if (session->err) { + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_ABORT); + } + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + WARN_ON_ONCE(!list_empty(&session->sk_session_queue_entry)); + WARN_ON_ONCE(!list_empty(&session->active_session_list_entry)); + + while ((skb = skb_dequeue(&session->skb_queue)) != NULL) { + /* drop ref taken in j1939_session_skb_queue() */ + skb_unref(skb); + kfree_skb(skb); + } + __j1939_session_drop(session); + j1939_priv_put(session->priv); + kfree(session); +} + +static void __j1939_session_release(struct kref *kref) +{ + struct j1939_session *session = container_of(kref, struct j1939_session, + kref); + + j1939_session_destroy(session); +} + +void j1939_session_put(struct j1939_session *session) +{ + kref_put(&session->kref, __j1939_session_release); +} + +static void j1939_session_txtimer_cancel(struct j1939_session *session) +{ + if (hrtimer_cancel(&session->txtimer)) + j1939_session_put(session); +} + +static void 
j1939_session_rxtimer_cancel(struct j1939_session *session) +{ + if (hrtimer_cancel(&session->rxtimer)) + j1939_session_put(session); +} + +void j1939_session_timers_cancel(struct j1939_session *session) +{ + j1939_session_txtimer_cancel(session); + j1939_session_rxtimer_cancel(session); +} + +static inline bool j1939_cb_is_broadcast(const struct j1939_sk_buff_cb *skcb) +{ + return (!skcb->addr.dst_name && (skcb->addr.da == 0xff)); +} + +static void j1939_session_skb_drop_old(struct j1939_session *session) +{ + struct sk_buff *do_skb; + struct j1939_sk_buff_cb *do_skcb; + unsigned int offset_start; + unsigned long flags; + + if (skb_queue_len(&session->skb_queue) < 2) + return; + + offset_start = session->pkt.tx_acked * 7; + + spin_lock_irqsave(&session->skb_queue.lock, flags); + do_skb = skb_peek(&session->skb_queue); + do_skcb = j1939_skb_to_cb(do_skb); + + if ((do_skcb->offset + do_skb->len) < offset_start) { + __skb_unlink(do_skb, &session->skb_queue); + /* drop ref taken in j1939_session_skb_queue() */ + skb_unref(do_skb); + spin_unlock_irqrestore(&session->skb_queue.lock, flags); + + kfree_skb(do_skb); + } else { + spin_unlock_irqrestore(&session->skb_queue.lock, flags); + } +} + +void j1939_session_skb_queue(struct j1939_session *session, + struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_priv *priv = session->priv; + + j1939_ac_fixup(priv, skb); + + if (j1939_address_is_unicast(skcb->addr.da) && + priv->ents[skcb->addr.da].nusers) + skcb->flags |= J1939_ECU_LOCAL_DST; + + skcb->flags |= J1939_ECU_LOCAL_SRC; + + skb_get(skb); + skb_queue_tail(&session->skb_queue, skb); +} + +static struct +sk_buff *j1939_session_skb_get_by_offset(struct j1939_session *session, + unsigned int offset_start) +{ + struct j1939_priv *priv = session->priv; + struct j1939_sk_buff_cb *do_skcb; + struct sk_buff *skb = NULL; + struct sk_buff *do_skb; + unsigned long flags; + + spin_lock_irqsave(&session->skb_queue.lock, flags); + skb_queue_walk(&session->skb_queue, do_skb) { + do_skcb = j1939_skb_to_cb(do_skb); + + if (offset_start >= do_skcb->offset && + offset_start < (do_skcb->offset + do_skb->len)) { + skb = do_skb; + } + } + + if (skb) + skb_get(skb); + + spin_unlock_irqrestore(&session->skb_queue.lock, flags); + + if (!skb) + netdev_dbg(priv->ndev, "%s: 0x%p: no skb found for start: %i, queue size: %i\n", + __func__, session, offset_start, + skb_queue_len(&session->skb_queue)); + + return skb; +} + +static struct sk_buff *j1939_session_skb_get(struct j1939_session *session) +{ + unsigned int offset_start; + + offset_start = session->pkt.dpo * 7; + return j1939_session_skb_get_by_offset(session, offset_start); +} + +/* see if we are receiver + * returns 0 for broadcasts, although we will receive them + */ +static inline int j1939_tp_im_receiver(const struct j1939_sk_buff_cb *skcb) +{ + return skcb->flags & J1939_ECU_LOCAL_DST; +} + +/* see if we are sender */ +static inline int j1939_tp_im_transmitter(const struct j1939_sk_buff_cb *skcb) +{ + return skcb->flags & J1939_ECU_LOCAL_SRC; +} + +/* see if we are involved as either receiver or transmitter */ +static int j1939_tp_im_involved(const struct j1939_sk_buff_cb *skcb, bool swap) +{ + if (swap) + return j1939_tp_im_receiver(skcb); + else + return j1939_tp_im_transmitter(skcb); +} + +static int j1939_tp_im_involved_anydir(struct j1939_sk_buff_cb *skcb) +{ + return skcb->flags & (J1939_ECU_LOCAL_SRC | J1939_ECU_LOCAL_DST); +} + +/* extract pgn from flow-ctl message */ +static inline pgn_t 
j1939_xtp_ctl_to_pgn(const u8 *dat) +{ + pgn_t pgn; + + pgn = (dat[7] << 16) | (dat[6] << 8) | (dat[5] << 0); + if (j1939_pgn_is_pdu1(pgn)) + pgn &= 0xffff00; + return pgn; +} + +static inline unsigned int j1939_tp_ctl_to_size(const u8 *dat) +{ + return (dat[2] << 8) + (dat[1] << 0); +} + +static inline unsigned int j1939_etp_ctl_to_packet(const u8 *dat) +{ + return (dat[4] << 16) | (dat[3] << 8) | (dat[2] << 0); +} + +static inline unsigned int j1939_etp_ctl_to_size(const u8 *dat) +{ + return (dat[4] << 24) | (dat[3] << 16) | + (dat[2] << 8) | (dat[1] << 0); +} + +/* find existing session: + * reverse: swap cb's src & dst + * there is no problem with matching broadcasts, since + * broadcasts (no dst, no da) would never call this + * with reverse == true + */ +static bool j1939_session_match(struct j1939_addr *se_addr, + struct j1939_addr *sk_addr, bool reverse) +{ + if (se_addr->type != sk_addr->type) + return false; + + if (reverse) { + if (se_addr->src_name) { + if (se_addr->src_name != sk_addr->dst_name) + return false; + } else if (se_addr->sa != sk_addr->da) { + return false; + } + + if (se_addr->dst_name) { + if (se_addr->dst_name != sk_addr->src_name) + return false; + } else if (se_addr->da != sk_addr->sa) { + return false; + } + } else { + if (se_addr->src_name) { + if (se_addr->src_name != sk_addr->src_name) + return false; + } else if (se_addr->sa != sk_addr->sa) { + return false; + } + + if (se_addr->dst_name) { + if (se_addr->dst_name != sk_addr->dst_name) + return false; + } else if (se_addr->da != sk_addr->da) { + return false; + } + } + + return true; +} + +static struct +j1939_session *j1939_session_get_by_addr_locked(struct j1939_priv *priv, + struct list_head *root, + struct j1939_addr *addr, + bool reverse, bool transmitter) +{ + struct j1939_session *session; + + lockdep_assert_held(&priv->active_session_list_lock); + + list_for_each_entry(session, root, active_session_list_entry) { + j1939_session_get(session); + if (j1939_session_match(&session->skcb.addr, addr, reverse) && + session->transmission == transmitter) + return session; + j1939_session_put(session); + } + + return NULL; +} + +static struct +j1939_session *j1939_session_get_simple(struct j1939_priv *priv, + struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + + lockdep_assert_held(&priv->active_session_list_lock); + + list_for_each_entry(session, &priv->active_session_list, + active_session_list_entry) { + j1939_session_get(session); + if (session->skcb.addr.type == J1939_SIMPLE && + session->tskey == skcb->tskey && session->sk == skb->sk) + return session; + j1939_session_put(session); + } + + return NULL; +} + +static struct +j1939_session *j1939_session_get_by_addr(struct j1939_priv *priv, + struct j1939_addr *addr, + bool reverse, bool transmitter) +{ + struct j1939_session *session; + + j1939_session_list_lock(priv); + session = j1939_session_get_by_addr_locked(priv, + &priv->active_session_list, + addr, reverse, transmitter); + j1939_session_list_unlock(priv); + + return session; +} + +static void j1939_skbcb_swap(struct j1939_sk_buff_cb *skcb) +{ + u8 tmp = 0; + + swap(skcb->addr.dst_name, skcb->addr.src_name); + swap(skcb->addr.da, skcb->addr.sa); + + /* swap SRC and DST flags, leave other untouched */ + if (skcb->flags & J1939_ECU_LOCAL_SRC) + tmp |= J1939_ECU_LOCAL_DST; + if (skcb->flags & J1939_ECU_LOCAL_DST) + tmp |= J1939_ECU_LOCAL_SRC; + skcb->flags &= ~(J1939_ECU_LOCAL_SRC | J1939_ECU_LOCAL_DST); + skcb->flags |= tmp; +} + +static 
struct +sk_buff *j1939_tp_tx_dat_new(struct j1939_priv *priv, + const struct j1939_sk_buff_cb *re_skcb, + bool ctl, + bool swap_src_dst) +{ + struct sk_buff *skb; + struct j1939_sk_buff_cb *skcb; + + skb = alloc_skb(sizeof(struct can_frame) + sizeof(struct can_skb_priv), + GFP_ATOMIC); + if (unlikely(!skb)) + return ERR_PTR(-ENOMEM); + + skb->dev = priv->ndev; + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = priv->ndev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + /* reserve CAN header */ + skb_reserve(skb, offsetof(struct can_frame, data)); + + /* skb->cb must be large enough to hold a j1939_sk_buff_cb structure */ + BUILD_BUG_ON(sizeof(skb->cb) < sizeof(*re_skcb)); + + memcpy(skb->cb, re_skcb, sizeof(*re_skcb)); + skcb = j1939_skb_to_cb(skb); + if (swap_src_dst) + j1939_skbcb_swap(skcb); + + if (ctl) { + if (skcb->addr.type == J1939_ETP) + skcb->addr.pgn = J1939_ETP_PGN_CTL; + else + skcb->addr.pgn = J1939_TP_PGN_CTL; + } else { + if (skcb->addr.type == J1939_ETP) + skcb->addr.pgn = J1939_ETP_PGN_DAT; + else + skcb->addr.pgn = J1939_TP_PGN_DAT; + } + + return skb; +} + +/* TP transmit packet functions */ +static int j1939_tp_tx_dat(struct j1939_session *session, + const u8 *dat, int len) +{ + struct j1939_priv *priv = session->priv; + struct sk_buff *skb; + + skb = j1939_tp_tx_dat_new(priv, &session->skcb, + false, false); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + skb_put_data(skb, dat, len); + if (j1939_tp_padding && len < 8) + memset(skb_put(skb, 8 - len), 0xff, 8 - len); + + return j1939_send_one(priv, skb); +} + +static int j1939_xtp_do_tx_ctl(struct j1939_priv *priv, + const struct j1939_sk_buff_cb *re_skcb, + bool swap_src_dst, pgn_t pgn, const u8 *dat) +{ + struct sk_buff *skb; + u8 *skdat; + + if (!j1939_tp_im_involved(re_skcb, swap_src_dst)) + return 0; + + skb = j1939_tp_tx_dat_new(priv, re_skcb, true, swap_src_dst); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + skdat = skb_put(skb, 8); + memcpy(skdat, dat, 5); + skdat[5] = (pgn >> 0); + skdat[6] = (pgn >> 8); + skdat[7] = (pgn >> 16); + + return j1939_send_one(priv, skb); +} + +static inline int j1939_tp_tx_ctl(struct j1939_session *session, + bool swap_src_dst, const u8 *dat) +{ + struct j1939_priv *priv = session->priv; + + return j1939_xtp_do_tx_ctl(priv, &session->skcb, + swap_src_dst, + session->skcb.addr.pgn, dat); +} + +static int j1939_xtp_tx_abort(struct j1939_priv *priv, + const struct j1939_sk_buff_cb *re_skcb, + bool swap_src_dst, + enum j1939_xtp_abort err, + pgn_t pgn) +{ + u8 dat[5]; + + if (!j1939_tp_im_involved(re_skcb, swap_src_dst)) + return 0; + + memset(dat, 0xff, sizeof(dat)); + dat[0] = J1939_TP_CMD_ABORT; + dat[1] = err; + return j1939_xtp_do_tx_ctl(priv, re_skcb, swap_src_dst, pgn, dat); +} + +void j1939_tp_schedule_txtimer(struct j1939_session *session, int msec) +{ + j1939_session_get(session); + hrtimer_start(&session->txtimer, ms_to_ktime(msec), + HRTIMER_MODE_REL_SOFT); +} + +static inline void j1939_tp_set_rxtimeout(struct j1939_session *session, + int msec) +{ + j1939_session_rxtimer_cancel(session); + j1939_session_get(session); + hrtimer_start(&session->rxtimer, ms_to_ktime(msec), + HRTIMER_MODE_REL_SOFT); +} + +static int j1939_session_tx_rts(struct j1939_session *session) +{ + u8 dat[8]; + int ret; + + memset(dat, 0xff, sizeof(dat)); + + dat[1] = (session->total_message_size >> 0); + dat[2] = (session->total_message_size >> 8); + dat[3] = session->pkt.total; + + if (session->skcb.addr.type == J1939_ETP) { + dat[0] = J1939_ETP_CMD_RTS; + dat[1] = (session->total_message_size >> 0); + 
dat[2] = (session->total_message_size >> 8); + dat[3] = (session->total_message_size >> 16); + dat[4] = (session->total_message_size >> 24); + } else if (j1939_cb_is_broadcast(&session->skcb)) { + dat[0] = J1939_TP_CMD_BAM; + /* fake cts for broadcast */ + session->pkt.tx = 0; + } else { + dat[0] = J1939_TP_CMD_RTS; + dat[4] = dat[3]; + } + + if (dat[0] == session->last_txcmd) + /* done already */ + return 0; + + ret = j1939_tp_tx_ctl(session, false, dat); + if (ret < 0) + return ret; + + session->last_txcmd = dat[0]; + if (dat[0] == J1939_TP_CMD_BAM) { + j1939_tp_schedule_txtimer(session, 50); + j1939_tp_set_rxtimeout(session, 250); + } else { + j1939_tp_set_rxtimeout(session, 1250); + } + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + return 0; +} + +static int j1939_session_tx_dpo(struct j1939_session *session) +{ + unsigned int pkt; + u8 dat[8]; + int ret; + + memset(dat, 0xff, sizeof(dat)); + + dat[0] = J1939_ETP_CMD_DPO; + session->pkt.dpo = session->pkt.tx_acked; + pkt = session->pkt.dpo; + dat[1] = session->pkt.last - session->pkt.tx_acked; + dat[2] = (pkt >> 0); + dat[3] = (pkt >> 8); + dat[4] = (pkt >> 16); + + ret = j1939_tp_tx_ctl(session, false, dat); + if (ret < 0) + return ret; + + session->last_txcmd = dat[0]; + j1939_tp_set_rxtimeout(session, 1250); + session->pkt.tx = session->pkt.tx_acked; + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + return 0; +} + +static int j1939_session_tx_dat(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + struct j1939_sk_buff_cb *se_skcb; + int offset, pkt_done, pkt_end; + unsigned int len, pdelay; + struct sk_buff *se_skb; + const u8 *tpdat; + int ret = 0; + u8 dat[8]; + + se_skb = j1939_session_skb_get_by_offset(session, session->pkt.tx * 7); + if (!se_skb) + return -ENOBUFS; + + se_skcb = j1939_skb_to_cb(se_skb); + tpdat = se_skb->data; + ret = 0; + pkt_done = 0; + if (session->skcb.addr.type != J1939_ETP && + j1939_cb_is_broadcast(&session->skcb)) + pkt_end = session->pkt.total; + else + pkt_end = session->pkt.last; + + while (session->pkt.tx < pkt_end) { + dat[0] = session->pkt.tx - session->pkt.dpo + 1; + offset = (session->pkt.tx * 7) - se_skcb->offset; + len = se_skb->len - offset; + if (len > 7) + len = 7; + + if (offset + len > se_skb->len) { + netdev_err_once(priv->ndev, + "%s: 0x%p: requested data outside of queued buffer: offset %i, len %i, pkt.tx: %i\n", + __func__, session, se_skcb->offset, + se_skb->len , session->pkt.tx); + ret = -EOVERFLOW; + goto out_free; + } + + if (!len) { + ret = -ENOBUFS; + break; + } + + memcpy(&dat[1], &tpdat[offset], len); + ret = j1939_tp_tx_dat(session, dat, len + 1); + if (ret < 0) { + /* ENOBUFS == CAN interface TX queue is full */ + if (ret != -ENOBUFS) + netdev_alert(priv->ndev, + "%s: 0x%p: queue data error: %i\n", + __func__, session, ret); + break; + } + + session->last_txcmd = 0xff; + pkt_done++; + session->pkt.tx++; + pdelay = j1939_cb_is_broadcast(&session->skcb) ? 
50 : + j1939_tp_packet_delay; + + if (session->pkt.tx < session->pkt.total && pdelay) { + j1939_tp_schedule_txtimer(session, pdelay); + break; + } + } + + if (pkt_done) + j1939_tp_set_rxtimeout(session, 250); + + out_free: + if (ret) + kfree_skb(se_skb); + else + consume_skb(se_skb); + + return ret; +} + +static int j1939_xtp_txnext_transmiter(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + int ret = 0; + + if (!j1939_tp_im_transmitter(&session->skcb)) { + netdev_alert(priv->ndev, "%s: 0x%p: called by not transmitter!\n", + __func__, session); + return -EINVAL; + } + + switch (session->last_cmd) { + case 0: + ret = j1939_session_tx_rts(session); + break; + + case J1939_ETP_CMD_CTS: + if (session->last_txcmd != J1939_ETP_CMD_DPO) { + ret = j1939_session_tx_dpo(session); + if (ret) + return ret; + } + + fallthrough; + case J1939_TP_CMD_CTS: + case 0xff: /* did some data */ + case J1939_ETP_CMD_DPO: + case J1939_TP_CMD_BAM: + ret = j1939_session_tx_dat(session); + + break; + default: + netdev_alert(priv->ndev, "%s: 0x%p: unexpected last_cmd: %x\n", + __func__, session, session->last_cmd); + } + + return ret; +} + +static int j1939_session_tx_cts(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + unsigned int pkt, len; + int ret; + u8 dat[8]; + + if (!j1939_sk_recv_match(priv, &session->skcb)) + return -ENOENT; + + len = session->pkt.total - session->pkt.rx; + len = min3(len, session->pkt.block, j1939_tp_block ?: 255); + memset(dat, 0xff, sizeof(dat)); + + if (session->skcb.addr.type == J1939_ETP) { + pkt = session->pkt.rx + 1; + dat[0] = J1939_ETP_CMD_CTS; + dat[1] = len; + dat[2] = (pkt >> 0); + dat[3] = (pkt >> 8); + dat[4] = (pkt >> 16); + } else { + dat[0] = J1939_TP_CMD_CTS; + dat[1] = len; + dat[2] = session->pkt.rx + 1; + } + + if (dat[0] == session->last_txcmd) + /* done already */ + return 0; + + ret = j1939_tp_tx_ctl(session, true, dat); + if (ret < 0) + return ret; + + if (len) + /* only mark cts done when len is set */ + session->last_txcmd = dat[0]; + j1939_tp_set_rxtimeout(session, 1250); + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + return 0; +} + +static int j1939_session_tx_eoma(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + u8 dat[8]; + int ret; + + if (!j1939_sk_recv_match(priv, &session->skcb)) + return -ENOENT; + + memset(dat, 0xff, sizeof(dat)); + + if (session->skcb.addr.type == J1939_ETP) { + dat[0] = J1939_ETP_CMD_EOMA; + dat[1] = session->total_message_size >> 0; + dat[2] = session->total_message_size >> 8; + dat[3] = session->total_message_size >> 16; + dat[4] = session->total_message_size >> 24; + } else { + dat[0] = J1939_TP_CMD_EOMA; + dat[1] = session->total_message_size; + dat[2] = session->total_message_size >> 8; + dat[3] = session->pkt.total; + } + + if (dat[0] == session->last_txcmd) + /* done already */ + return 0; + + ret = j1939_tp_tx_ctl(session, true, dat); + if (ret < 0) + return ret; + + session->last_txcmd = dat[0]; + + /* wait for the EOMA packet to come in */ + j1939_tp_set_rxtimeout(session, 1250); + + netdev_dbg(session->priv->ndev, "%p: 0x%p\n", __func__, session); + + return 0; +} + +static int j1939_xtp_txnext_receiver(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + int ret = 0; + + if (!j1939_tp_im_receiver(&session->skcb)) { + netdev_alert(priv->ndev, "%s: 0x%p: called by not receiver!\n", + __func__, session); + return -EINVAL; + } + + switch (session->last_cmd) { + case J1939_TP_CMD_RTS: + 
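+ /* a received RTS (TP or ETP) is answered with a CTS */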
case J1939_ETP_CMD_RTS: + ret = j1939_session_tx_cts(session); + break; + + case J1939_ETP_CMD_CTS: + case J1939_TP_CMD_CTS: + case 0xff: /* did some data */ + case J1939_ETP_CMD_DPO: + if ((session->skcb.addr.type == J1939_TP && + j1939_cb_is_broadcast(&session->skcb))) + break; + + if (session->pkt.rx >= session->pkt.total) { + ret = j1939_session_tx_eoma(session); + } else if (session->pkt.rx >= session->pkt.last) { + session->last_txcmd = 0; + ret = j1939_session_tx_cts(session); + } + break; + default: + netdev_alert(priv->ndev, "%s: 0x%p: unexpected last_cmd: %x\n", + __func__, session, session->last_cmd); + } + + return ret; +} + +static int j1939_simple_txnext(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + struct sk_buff *se_skb = j1939_session_skb_get(session); + struct sk_buff *skb; + int ret; + + if (!se_skb) + return 0; + + skb = skb_clone(se_skb, GFP_ATOMIC); + if (!skb) { + ret = -ENOMEM; + goto out_free; + } + + can_skb_set_owner(skb, se_skb->sk); + + j1939_tp_set_rxtimeout(session, J1939_SIMPLE_ECHO_TIMEOUT_MS); + + ret = j1939_send_one(priv, skb); + if (ret) + goto out_free; + + j1939_sk_errqueue(session, J1939_ERRQUEUE_TX_SCHED); + j1939_sk_queue_activate_next(session); + + out_free: + if (ret) + kfree_skb(se_skb); + else + consume_skb(se_skb); + + return ret; +} + +static bool j1939_session_deactivate_locked(struct j1939_session *session) +{ + bool active = false; + + lockdep_assert_held(&session->priv->active_session_list_lock); + + if (session->state >= J1939_SESSION_ACTIVE && + session->state < J1939_SESSION_ACTIVE_MAX) { + active = true; + + list_del_init(&session->active_session_list_entry); + session->state = J1939_SESSION_DONE; + j1939_session_put(session); + } + + return active; +} + +static bool j1939_session_deactivate(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + bool active; + + j1939_session_list_lock(priv); + active = j1939_session_deactivate_locked(session); + j1939_session_list_unlock(priv); + + return active; +} + +static void +j1939_session_deactivate_activate_next(struct j1939_session *session) +{ + if (j1939_session_deactivate(session)) + j1939_sk_queue_activate_next(session); +} + +static void __j1939_session_cancel(struct j1939_session *session, + enum j1939_xtp_abort err) +{ + struct j1939_priv *priv = session->priv; + + WARN_ON_ONCE(!err); + lockdep_assert_held(&session->priv->active_session_list_lock); + + session->err = j1939_xtp_abort_to_errno(priv, err); + session->state = J1939_SESSION_WAITING_ABORT; + /* do not send aborts on incoming broadcasts */ + if (!j1939_cb_is_broadcast(&session->skcb)) { + j1939_xtp_tx_abort(priv, &session->skcb, + !session->transmission, + err, session->skcb.addr.pgn); + } + + if (session->sk) + j1939_sk_send_loop_abort(session->sk, session->err); +} + +static void j1939_session_cancel(struct j1939_session *session, + enum j1939_xtp_abort err) +{ + j1939_session_list_lock(session->priv); + + if (session->state >= J1939_SESSION_ACTIVE && + session->state < J1939_SESSION_WAITING_ABORT) { + j1939_tp_set_rxtimeout(session, J1939_XTP_ABORT_TIMEOUT_MS); + __j1939_session_cancel(session, err); + } + + j1939_session_list_unlock(session->priv); + + if (!session->sk) + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_ABORT); +} + +static enum hrtimer_restart j1939_tp_txtimer(struct hrtimer *hrtimer) +{ + struct j1939_session *session = + container_of(hrtimer, struct j1939_session, txtimer); + struct j1939_priv *priv = session->priv; + int ret = 0; + + if 
(session->skcb.addr.type == J1939_SIMPLE) { + ret = j1939_simple_txnext(session); + } else { + if (session->transmission) + ret = j1939_xtp_txnext_transmiter(session); + else + ret = j1939_xtp_txnext_receiver(session); + } + + switch (ret) { + case -ENOBUFS: + /* Retry limit is currently arbitrary chosen */ + if (session->tx_retry < J1939_XTP_TX_RETRY_LIMIT) { + session->tx_retry++; + j1939_tp_schedule_txtimer(session, + 10 + prandom_u32_max(16)); + } else { + netdev_alert(priv->ndev, "%s: 0x%p: tx retry count reached\n", + __func__, session); + session->err = -ENETUNREACH; + j1939_session_rxtimer_cancel(session); + j1939_session_deactivate_activate_next(session); + } + break; + case -ENETDOWN: + /* In this case we should get a netdev_event(), all active + * sessions will be cleared by + * j1939_cancel_all_active_sessions(). So handle this as an + * error, but let j1939_cancel_all_active_sessions() do the + * cleanup including propagation of the error to user space. + */ + break; + case -EOVERFLOW: + j1939_session_cancel(session, J1939_XTP_ABORT_ECTS_TOO_BIG); + break; + case 0: + session->tx_retry = 0; + break; + default: + netdev_alert(priv->ndev, "%s: 0x%p: tx aborted with unknown reason: %i\n", + __func__, session, ret); + if (session->skcb.addr.type != J1939_SIMPLE) { + j1939_session_cancel(session, J1939_XTP_ABORT_OTHER); + } else { + session->err = ret; + j1939_session_rxtimer_cancel(session); + j1939_session_deactivate_activate_next(session); + } + } + + j1939_session_put(session); + + return HRTIMER_NORESTART; +} + +static void j1939_session_completed(struct j1939_session *session) +{ + struct sk_buff *se_skb; + + if (!session->transmission) { + se_skb = j1939_session_skb_get(session); + /* distribute among j1939 receivers */ + j1939_sk_recv(session->priv, se_skb); + consume_skb(se_skb); + } + + j1939_session_deactivate_activate_next(session); +} + +static enum hrtimer_restart j1939_tp_rxtimer(struct hrtimer *hrtimer) +{ + struct j1939_session *session = container_of(hrtimer, + struct j1939_session, + rxtimer); + struct j1939_priv *priv = session->priv; + + if (session->state == J1939_SESSION_WAITING_ABORT) { + netdev_alert(priv->ndev, "%s: 0x%p: abort rx timeout. Force session deactivation\n", + __func__, session); + + j1939_session_deactivate_activate_next(session); + + } else if (session->skcb.addr.type == J1939_SIMPLE) { + netdev_alert(priv->ndev, "%s: 0x%p: Timeout. Failed to send simple message.\n", + __func__, session); + + /* The message is probably stuck in the CAN controller and can + * be send as soon as CAN bus is in working state again. 
+ */ + session->err = -ETIME; + j1939_session_deactivate(session); + } else { + j1939_session_list_lock(session->priv); + if (session->state >= J1939_SESSION_ACTIVE && + session->state < J1939_SESSION_ACTIVE_MAX) { + netdev_alert(priv->ndev, "%s: 0x%p: rx timeout, send abort\n", + __func__, session); + j1939_session_get(session); + hrtimer_start(&session->rxtimer, + ms_to_ktime(J1939_XTP_ABORT_TIMEOUT_MS), + HRTIMER_MODE_REL_SOFT); + __j1939_session_cancel(session, J1939_XTP_ABORT_TIMEOUT); + } + j1939_session_list_unlock(session->priv); + + if (!session->sk) + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_ABORT); + } + + j1939_session_put(session); + + return HRTIMER_NORESTART; +} + +static bool j1939_xtp_rx_cmd_bad_pgn(struct j1939_session *session, + const struct sk_buff *skb) +{ + const struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + pgn_t pgn = j1939_xtp_ctl_to_pgn(skb->data); + struct j1939_priv *priv = session->priv; + enum j1939_xtp_abort abort = J1939_XTP_NO_ABORT; + u8 cmd = skb->data[0]; + + if (session->skcb.addr.pgn == pgn) + return false; + + switch (cmd) { + case J1939_TP_CMD_BAM: + abort = J1939_XTP_NO_ABORT; + break; + + case J1939_ETP_CMD_RTS: + fallthrough; + case J1939_TP_CMD_RTS: + abort = J1939_XTP_ABORT_BUSY; + break; + + case J1939_ETP_CMD_CTS: + fallthrough; + case J1939_TP_CMD_CTS: + abort = J1939_XTP_ABORT_ECTS_UNXPECTED_PGN; + break; + + case J1939_ETP_CMD_DPO: + abort = J1939_XTP_ABORT_BAD_EDPO_PGN; + break; + + case J1939_ETP_CMD_EOMA: + fallthrough; + case J1939_TP_CMD_EOMA: + abort = J1939_XTP_ABORT_OTHER; + break; + + case J1939_ETP_CMD_ABORT: /* && J1939_TP_CMD_ABORT */ + abort = J1939_XTP_NO_ABORT; + break; + + default: + WARN_ON_ONCE(1); + break; + } + + netdev_warn(priv->ndev, "%s: 0x%p: CMD 0x%02x with PGN 0x%05x for running session with different PGN 0x%05x.\n", + __func__, session, cmd, pgn, session->skcb.addr.pgn); + if (abort != J1939_XTP_NO_ABORT) + j1939_xtp_tx_abort(priv, skcb, true, abort, pgn); + + return true; +} + +static void j1939_xtp_rx_abort_one(struct j1939_priv *priv, struct sk_buff *skb, + bool reverse, bool transmitter) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + u8 abort = skb->data[1]; + + session = j1939_session_get_by_addr(priv, &skcb->addr, reverse, + transmitter); + if (!session) + return; + + if (j1939_xtp_rx_cmd_bad_pgn(session, skb)) + goto abort_put; + + netdev_info(priv->ndev, "%s: 0x%p: 0x%05x: (%u) %s\n", __func__, + session, j1939_xtp_ctl_to_pgn(skb->data), abort, + j1939_xtp_abort_to_str(abort)); + + j1939_session_timers_cancel(session); + session->err = j1939_xtp_abort_to_errno(priv, abort); + if (session->sk) + j1939_sk_send_loop_abort(session->sk, session->err); + else + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_ABORT); + j1939_session_deactivate_activate_next(session); + +abort_put: + j1939_session_put(session); +} + +/* abort packets may come in 2 directions */ +static void +j1939_xtp_rx_abort(struct j1939_priv *priv, struct sk_buff *skb, + bool transmitter) +{ + j1939_xtp_rx_abort_one(priv, skb, false, transmitter); + j1939_xtp_rx_abort_one(priv, skb, true, transmitter); +} + +static void +j1939_xtp_rx_eoma_one(struct j1939_session *session, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + const u8 *dat; + int len; + + if (j1939_xtp_rx_cmd_bad_pgn(session, skb)) + return; + + dat = skb->data; + + if (skcb->addr.type == J1939_ETP) + len = j1939_etp_ctl_to_size(dat); + else + len = j1939_tp_ctl_to_size(dat); + + if 
(session->total_message_size != len) { + netdev_warn_once(session->priv->ndev, + "%s: 0x%p: Incorrect size. Expected: %i; got: %i.\n", + __func__, session, session->total_message_size, + len); + } + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + session->pkt.tx_acked = session->pkt.total; + j1939_session_timers_cancel(session); + /* transmitted without problems */ + j1939_session_completed(session); +} + +static void +j1939_xtp_rx_eoma(struct j1939_priv *priv, struct sk_buff *skb, + bool transmitter) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + + session = j1939_session_get_by_addr(priv, &skcb->addr, true, + transmitter); + if (!session) + return; + + j1939_xtp_rx_eoma_one(session, skb); + j1939_session_put(session); +} + +static void +j1939_xtp_rx_cts_one(struct j1939_session *session, struct sk_buff *skb) +{ + enum j1939_xtp_abort err = J1939_XTP_ABORT_FAULT; + unsigned int pkt; + const u8 *dat; + + dat = skb->data; + + if (j1939_xtp_rx_cmd_bad_pgn(session, skb)) + return; + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + if (session->last_cmd == dat[0]) { + err = J1939_XTP_ABORT_DUP_SEQ; + goto out_session_cancel; + } + + if (session->skcb.addr.type == J1939_ETP) + pkt = j1939_etp_ctl_to_packet(dat); + else + pkt = dat[2]; + + if (!pkt) + goto out_session_cancel; + else if (dat[1] > session->pkt.block /* 0xff for etp */) + goto out_session_cancel; + + /* set packet counters only when not CTS(0) */ + session->pkt.tx_acked = pkt - 1; + j1939_session_skb_drop_old(session); + session->pkt.last = session->pkt.tx_acked + dat[1]; + if (session->pkt.last > session->pkt.total) + /* safety measure */ + session->pkt.last = session->pkt.total; + /* TODO: do not set tx here, do it in txtimer */ + session->pkt.tx = session->pkt.tx_acked; + + session->last_cmd = dat[0]; + if (dat[1]) { + j1939_tp_set_rxtimeout(session, 1250); + if (session->transmission) { + if (session->pkt.tx_acked) + j1939_sk_errqueue(session, + J1939_ERRQUEUE_TX_SCHED); + j1939_session_txtimer_cancel(session); + j1939_tp_schedule_txtimer(session, 0); + } + } else { + /* CTS(0) */ + j1939_tp_set_rxtimeout(session, 550); + } + return; + + out_session_cancel: + j1939_session_timers_cancel(session); + j1939_session_cancel(session, err); +} + +static void +j1939_xtp_rx_cts(struct j1939_priv *priv, struct sk_buff *skb, bool transmitter) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + + session = j1939_session_get_by_addr(priv, &skcb->addr, true, + transmitter); + if (!session) + return; + j1939_xtp_rx_cts_one(session, skb); + j1939_session_put(session); +} + +static struct j1939_session *j1939_session_new(struct j1939_priv *priv, + struct sk_buff *skb, size_t size) +{ + struct j1939_session *session; + struct j1939_sk_buff_cb *skcb; + + session = kzalloc(sizeof(*session), gfp_any()); + if (!session) + return NULL; + + INIT_LIST_HEAD(&session->active_session_list_entry); + INIT_LIST_HEAD(&session->sk_session_queue_entry); + kref_init(&session->kref); + + j1939_priv_get(priv); + session->priv = priv; + session->total_message_size = size; + session->state = J1939_SESSION_NEW; + + skb_queue_head_init(&session->skb_queue); + skb_queue_tail(&session->skb_queue, skb); + + skcb = j1939_skb_to_cb(skb); + memcpy(&session->skcb, skcb, sizeof(session->skcb)); + + hrtimer_init(&session->txtimer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + session->txtimer.function = j1939_tp_txtimer; + 
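+ /* the rx timer watches for peer activity and is (re)armed via j1939_tp_set_rxtimeout() */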
hrtimer_init(&session->rxtimer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_SOFT); + session->rxtimer.function = j1939_tp_rxtimer; + + netdev_dbg(priv->ndev, "%s: 0x%p: sa: %02x, da: %02x\n", + __func__, session, skcb->addr.sa, skcb->addr.da); + + return session; +} + +static struct +j1939_session *j1939_session_fresh_new(struct j1939_priv *priv, + int size, + const struct j1939_sk_buff_cb *rel_skcb) +{ + struct sk_buff *skb; + struct j1939_sk_buff_cb *skcb; + struct j1939_session *session; + + skb = alloc_skb(size + sizeof(struct can_skb_priv), GFP_ATOMIC); + if (unlikely(!skb)) + return NULL; + + skb->dev = priv->ndev; + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = priv->ndev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + skcb = j1939_skb_to_cb(skb); + memcpy(skcb, rel_skcb, sizeof(*skcb)); + + session = j1939_session_new(priv, skb, size); + if (!session) { + kfree_skb(skb); + return NULL; + } + + /* alloc data area */ + skb_put(skb, size); + /* skb is recounted in j1939_session_new() */ + return session; +} + +int j1939_session_activate(struct j1939_session *session) +{ + struct j1939_priv *priv = session->priv; + struct j1939_session *active = NULL; + int ret = 0; + + j1939_session_list_lock(priv); + if (session->skcb.addr.type != J1939_SIMPLE) + active = j1939_session_get_by_addr_locked(priv, + &priv->active_session_list, + &session->skcb.addr, false, + session->transmission); + if (active) { + j1939_session_put(active); + ret = -EAGAIN; + } else { + WARN_ON_ONCE(session->state != J1939_SESSION_NEW); + list_add_tail(&session->active_session_list_entry, + &priv->active_session_list); + j1939_session_get(session); + session->state = J1939_SESSION_ACTIVE; + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", + __func__, session); + } + j1939_session_list_unlock(priv); + + return ret; +} + +static struct +j1939_session *j1939_xtp_rx_rts_session_new(struct j1939_priv *priv, + struct sk_buff *skb) +{ + enum j1939_xtp_abort abort = J1939_XTP_NO_ABORT; + struct j1939_sk_buff_cb skcb = *j1939_skb_to_cb(skb); + struct j1939_session *session; + const u8 *dat; + pgn_t pgn; + int len; + + netdev_dbg(priv->ndev, "%s\n", __func__); + + dat = skb->data; + pgn = j1939_xtp_ctl_to_pgn(dat); + skcb.addr.pgn = pgn; + + if (!j1939_sk_recv_match(priv, &skcb)) + return NULL; + + if (skcb.addr.type == J1939_ETP) { + len = j1939_etp_ctl_to_size(dat); + if (len > J1939_MAX_ETP_PACKET_SIZE) + abort = J1939_XTP_ABORT_FAULT; + else if (len > priv->tp_max_packet_size) + abort = J1939_XTP_ABORT_RESOURCE; + else if (len <= J1939_MAX_TP_PACKET_SIZE) + abort = J1939_XTP_ABORT_FAULT; + } else { + len = j1939_tp_ctl_to_size(dat); + if (len > J1939_MAX_TP_PACKET_SIZE) + abort = J1939_XTP_ABORT_FAULT; + else if (len > priv->tp_max_packet_size) + abort = J1939_XTP_ABORT_RESOURCE; + else if (len < J1939_MIN_TP_PACKET_SIZE) + abort = J1939_XTP_ABORT_FAULT; + } + + if (abort != J1939_XTP_NO_ABORT) { + j1939_xtp_tx_abort(priv, &skcb, true, abort, pgn); + return NULL; + } + + session = j1939_session_fresh_new(priv, len, &skcb); + if (!session) { + j1939_xtp_tx_abort(priv, &skcb, true, + J1939_XTP_ABORT_RESOURCE, pgn); + return NULL; + } + + /* initialize the control buffer: plain copy */ + session->pkt.total = (len + 6) / 7; + session->pkt.block = 0xff; + if (skcb.addr.type != J1939_ETP) { + if (dat[3] != session->pkt.total) + netdev_alert(priv->ndev, "%s: 0x%p: strange total, %u != %u\n", + __func__, session, session->pkt.total, + dat[3]); + session->pkt.total = dat[3]; + session->pkt.block = min(dat[3], dat[4]); + } + + 
session->pkt.rx = 0; + session->pkt.tx = 0; + + session->tskey = priv->rx_tskey++; + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_RTS); + + WARN_ON_ONCE(j1939_session_activate(session)); + + return session; +} + +static int j1939_xtp_rx_rts_session_active(struct j1939_session *session, + struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_priv *priv = session->priv; + + if (!session->transmission) { + if (j1939_xtp_rx_cmd_bad_pgn(session, skb)) + return -EBUSY; + + /* RTS on active session */ + j1939_session_timers_cancel(session); + j1939_session_cancel(session, J1939_XTP_ABORT_BUSY); + } + + if (session->last_cmd != 0) { + /* we received a second rts on the same connection */ + netdev_alert(priv->ndev, "%s: 0x%p: connection exists (%02x %02x). last cmd: %x\n", + __func__, session, skcb->addr.sa, skcb->addr.da, + session->last_cmd); + + j1939_session_timers_cancel(session); + j1939_session_cancel(session, J1939_XTP_ABORT_BUSY); + + return -EBUSY; + } + + if (session->skcb.addr.sa != skcb->addr.sa || + session->skcb.addr.da != skcb->addr.da) + netdev_warn(priv->ndev, "%s: 0x%p: session->skcb.addr.sa=0x%02x skcb->addr.sa=0x%02x session->skcb.addr.da=0x%02x skcb->addr.da=0x%02x\n", + __func__, session, + session->skcb.addr.sa, skcb->addr.sa, + session->skcb.addr.da, skcb->addr.da); + /* make sure 'sa' & 'da' are correct ! + * They may be 'not filled in yet' for sending + * skb's, since they did not pass the Address Claim ever. + */ + session->skcb.addr.sa = skcb->addr.sa; + session->skcb.addr.da = skcb->addr.da; + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + return 0; +} + +static void j1939_xtp_rx_rts(struct j1939_priv *priv, struct sk_buff *skb, + bool transmitter) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + u8 cmd = skb->data[0]; + + session = j1939_session_get_by_addr(priv, &skcb->addr, false, + transmitter); + + if (!session) { + if (transmitter) { + /* If we're the transmitter and this function is called, + * we received our own RTS. A session has already been + * created. + * + * For some reasons however it might have been destroyed + * already. So don't create a new one here (using + * "j1939_xtp_rx_rts_session_new()") as this will be a + * receiver session. 
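+			 * Simply ignore our own looped-back RTS in that case.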
+ * + * The reasons the session is already destroyed might + * be: + * - user space closed socket was and the session was + * aborted + * - session was aborted due to external abort message + */ + return; + } + session = j1939_xtp_rx_rts_session_new(priv, skb); + if (!session) { + if (cmd == J1939_TP_CMD_BAM && j1939_sk_recv_match(priv, skcb)) + netdev_info(priv->ndev, "%s: failed to create TP BAM session\n", + __func__); + return; + } + } else { + if (j1939_xtp_rx_rts_session_active(session, skb)) { + j1939_session_put(session); + return; + } + } + session->last_cmd = cmd; + + if (cmd == J1939_TP_CMD_BAM) { + if (!session->transmission) + j1939_tp_set_rxtimeout(session, 750); + } else { + if (!session->transmission) { + j1939_session_txtimer_cancel(session); + j1939_tp_schedule_txtimer(session, 0); + } + j1939_tp_set_rxtimeout(session, 1250); + } + + j1939_session_put(session); +} + +static void j1939_xtp_rx_dpo_one(struct j1939_session *session, + struct sk_buff *skb) +{ + const u8 *dat = skb->data; + + if (j1939_xtp_rx_cmd_bad_pgn(session, skb)) + return; + + netdev_dbg(session->priv->ndev, "%s: 0x%p\n", __func__, session); + + /* transmitted without problems */ + session->pkt.dpo = j1939_etp_ctl_to_packet(skb->data); + session->last_cmd = dat[0]; + j1939_tp_set_rxtimeout(session, 750); + + if (!session->transmission) + j1939_sk_errqueue(session, J1939_ERRQUEUE_RX_DPO); +} + +static void j1939_xtp_rx_dpo(struct j1939_priv *priv, struct sk_buff *skb, + bool transmitter) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + + session = j1939_session_get_by_addr(priv, &skcb->addr, false, + transmitter); + if (!session) { + netdev_info(priv->ndev, + "%s: no connection found\n", __func__); + return; + } + + j1939_xtp_rx_dpo_one(session, skb); + j1939_session_put(session); +} + +static void j1939_xtp_rx_dat_one(struct j1939_session *session, + struct sk_buff *skb) +{ + enum j1939_xtp_abort abort = J1939_XTP_ABORT_FAULT; + struct j1939_priv *priv = session->priv; + struct j1939_sk_buff_cb *skcb, *se_skcb; + struct sk_buff *se_skb = NULL; + const u8 *dat; + u8 *tpdat; + int offset; + int nbytes; + bool final = false; + bool remain = false; + bool do_cts_eoma = false; + int packet; + + skcb = j1939_skb_to_cb(skb); + dat = skb->data; + if (skb->len != 8) { + /* makes no sense */ + abort = J1939_XTP_ABORT_UNEXPECTED_DATA; + goto out_session_cancel; + } + + switch (session->last_cmd) { + case 0xff: + break; + case J1939_ETP_CMD_DPO: + if (skcb->addr.type == J1939_ETP) + break; + fallthrough; + case J1939_TP_CMD_BAM: + fallthrough; + case J1939_TP_CMD_CTS: + if (skcb->addr.type != J1939_ETP) + break; + fallthrough; + default: + netdev_info(priv->ndev, "%s: 0x%p: last %02x\n", __func__, + session, session->last_cmd); + goto out_session_cancel; + } + + packet = (dat[0] - 1 + session->pkt.dpo); + if (packet > session->pkt.total || + (session->pkt.rx + 1) > session->pkt.total) { + netdev_info(priv->ndev, "%s: 0x%p: should have been completed\n", + __func__, session); + goto out_session_cancel; + } + + se_skb = j1939_session_skb_get_by_offset(session, packet * 7); + if (!se_skb) { + netdev_warn(priv->ndev, "%s: 0x%p: no skb found\n", __func__, + session); + goto out_session_cancel; + } + + se_skcb = j1939_skb_to_cb(se_skb); + offset = packet * 7 - se_skcb->offset; + nbytes = se_skb->len - offset; + if (nbytes > 7) + nbytes = 7; + if (nbytes <= 0 || (nbytes + 1) > skb->len) { + netdev_info(priv->ndev, "%s: 0x%p: nbytes %i, len %i\n", + __func__, session, nbytes, 
skb->len); + goto out_session_cancel; + } + + tpdat = se_skb->data; + if (!session->transmission) { + memcpy(&tpdat[offset], &dat[1], nbytes); + } else { + int err; + + err = memcmp(&tpdat[offset], &dat[1], nbytes); + if (err) + netdev_err_once(priv->ndev, + "%s: 0x%p: Data of RX-looped back packet (%*ph) doesn't match TX data (%*ph)!\n", + __func__, session, + nbytes, &dat[1], + nbytes, &tpdat[offset]); + } + + if (packet == session->pkt.rx) + session->pkt.rx++; + + if (se_skcb->addr.type != J1939_ETP && + j1939_cb_is_broadcast(&session->skcb)) { + if (session->pkt.rx >= session->pkt.total) + final = true; + else + remain = true; + } else { + /* never final, an EOMA must follow */ + if (session->pkt.rx >= session->pkt.last) + do_cts_eoma = true; + } + + if (final) { + j1939_session_timers_cancel(session); + j1939_session_completed(session); + } else if (remain) { + if (!session->transmission) + j1939_tp_set_rxtimeout(session, 750); + } else if (do_cts_eoma) { + j1939_tp_set_rxtimeout(session, 1250); + if (!session->transmission) + j1939_tp_schedule_txtimer(session, 0); + } else { + j1939_tp_set_rxtimeout(session, 750); + } + session->last_cmd = 0xff; + consume_skb(se_skb); + j1939_session_put(session); + + return; + + out_session_cancel: + kfree_skb(se_skb); + j1939_session_timers_cancel(session); + j1939_session_cancel(session, abort); + j1939_session_put(session); +} + +static void j1939_xtp_rx_dat(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb; + struct j1939_session *session; + + skcb = j1939_skb_to_cb(skb); + + if (j1939_tp_im_transmitter(skcb)) { + session = j1939_session_get_by_addr(priv, &skcb->addr, false, + true); + if (!session) + netdev_info(priv->ndev, "%s: no tx connection found\n", + __func__); + else + j1939_xtp_rx_dat_one(session, skb); + } + + if (j1939_tp_im_receiver(skcb)) { + session = j1939_session_get_by_addr(priv, &skcb->addr, false, + false); + if (!session) + netdev_info(priv->ndev, "%s: no rx connection found\n", + __func__); + else + j1939_xtp_rx_dat_one(session, skb); + } + + if (j1939_cb_is_broadcast(skcb)) { + session = j1939_session_get_by_addr(priv, &skcb->addr, false, + false); + if (session) + j1939_xtp_rx_dat_one(session, skb); + } +} + +/* j1939 main intf */ +struct j1939_session *j1939_tp_send(struct j1939_priv *priv, + struct sk_buff *skb, size_t size) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + struct j1939_session *session; + int ret; + + if (skcb->addr.pgn == J1939_TP_PGN_DAT || + skcb->addr.pgn == J1939_TP_PGN_CTL || + skcb->addr.pgn == J1939_ETP_PGN_DAT || + skcb->addr.pgn == J1939_ETP_PGN_CTL) + /* avoid conflict */ + return ERR_PTR(-EDOM); + + if (size > priv->tp_max_packet_size) + return ERR_PTR(-EMSGSIZE); + + if (size <= 8) + skcb->addr.type = J1939_SIMPLE; + else if (size > J1939_MAX_TP_PACKET_SIZE) + skcb->addr.type = J1939_ETP; + else + skcb->addr.type = J1939_TP; + + if (skcb->addr.type == J1939_ETP && + j1939_cb_is_broadcast(skcb)) + return ERR_PTR(-EDESTADDRREQ); + + /* fill in addresses from names */ + ret = j1939_ac_fixup(priv, skb); + if (unlikely(ret)) + return ERR_PTR(ret); + + /* fix DST flags, it may be used there soon */ + if (j1939_address_is_unicast(skcb->addr.da) && + priv->ents[skcb->addr.da].nusers) + skcb->flags |= J1939_ECU_LOCAL_DST; + + /* src is always local, I'm sending ... 
*/ + skcb->flags |= J1939_ECU_LOCAL_SRC; + + /* prepare new session */ + session = j1939_session_new(priv, skb, size); + if (!session) + return ERR_PTR(-ENOMEM); + + /* skb is recounted in j1939_session_new() */ + sock_hold(skb->sk); + session->sk = skb->sk; + session->transmission = true; + session->pkt.total = (size + 6) / 7; + session->pkt.block = skcb->addr.type == J1939_ETP ? 255 : + min(j1939_tp_block ?: 255, session->pkt.total); + + if (j1939_cb_is_broadcast(&session->skcb)) + /* set the end-packet for broadcast */ + session->pkt.last = session->pkt.total; + + skcb->tskey = atomic_inc_return(&session->sk->sk_tskey) - 1; + session->tskey = skcb->tskey; + + return session; +} + +static void j1939_tp_cmd_recv(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + int extd = J1939_TP; + u8 cmd = skb->data[0]; + + switch (cmd) { + case J1939_ETP_CMD_RTS: + extd = J1939_ETP; + fallthrough; + case J1939_TP_CMD_BAM: + if (cmd == J1939_TP_CMD_BAM && !j1939_cb_is_broadcast(skcb)) { + netdev_err_once(priv->ndev, "%s: BAM to unicast (%02x), ignoring!\n", + __func__, skcb->addr.sa); + return; + } + fallthrough; + case J1939_TP_CMD_RTS: + if (skcb->addr.type != extd) + return; + + if (cmd == J1939_TP_CMD_RTS && j1939_cb_is_broadcast(skcb)) { + netdev_alert(priv->ndev, "%s: rts without destination (%02x)\n", + __func__, skcb->addr.sa); + return; + } + + if (j1939_tp_im_transmitter(skcb)) + j1939_xtp_rx_rts(priv, skb, true); + + if (j1939_tp_im_receiver(skcb) || j1939_cb_is_broadcast(skcb)) + j1939_xtp_rx_rts(priv, skb, false); + + break; + + case J1939_ETP_CMD_CTS: + extd = J1939_ETP; + fallthrough; + case J1939_TP_CMD_CTS: + if (skcb->addr.type != extd) + return; + + if (j1939_tp_im_transmitter(skcb)) + j1939_xtp_rx_cts(priv, skb, false); + + if (j1939_tp_im_receiver(skcb)) + j1939_xtp_rx_cts(priv, skb, true); + + break; + + case J1939_ETP_CMD_DPO: + if (skcb->addr.type != J1939_ETP) + return; + + if (j1939_tp_im_transmitter(skcb)) + j1939_xtp_rx_dpo(priv, skb, true); + + if (j1939_tp_im_receiver(skcb)) + j1939_xtp_rx_dpo(priv, skb, false); + + break; + + case J1939_ETP_CMD_EOMA: + extd = J1939_ETP; + fallthrough; + case J1939_TP_CMD_EOMA: + if (skcb->addr.type != extd) + return; + + if (j1939_tp_im_transmitter(skcb)) + j1939_xtp_rx_eoma(priv, skb, false); + + if (j1939_tp_im_receiver(skcb)) + j1939_xtp_rx_eoma(priv, skb, true); + + break; + + case J1939_ETP_CMD_ABORT: /* && J1939_TP_CMD_ABORT */ + if (j1939_cb_is_broadcast(skcb)) { + netdev_err_once(priv->ndev, "%s: abort to broadcast (%02x), ignoring!\n", + __func__, skcb->addr.sa); + return; + } + + if (j1939_tp_im_transmitter(skcb)) + j1939_xtp_rx_abort(priv, skb, true); + + if (j1939_tp_im_receiver(skcb)) + j1939_xtp_rx_abort(priv, skb, false); + + break; + default: + return; + } +} + +int j1939_tp_recv(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_sk_buff_cb *skcb = j1939_skb_to_cb(skb); + + if (!j1939_tp_im_involved_anydir(skcb) && !j1939_cb_is_broadcast(skcb)) + return 0; + + switch (skcb->addr.pgn) { + case J1939_ETP_PGN_DAT: + skcb->addr.type = J1939_ETP; + fallthrough; + case J1939_TP_PGN_DAT: + j1939_xtp_rx_dat(priv, skb); + break; + + case J1939_ETP_PGN_CTL: + skcb->addr.type = J1939_ETP; + fallthrough; + case J1939_TP_PGN_CTL: + if (skb->len < 8) + return 0; /* Don't care. 
Nothing to extract here */ + + j1939_tp_cmd_recv(priv, skb); + break; + default: + return 0; /* no problem */ + } + return 1; /* "I processed the message" */ +} + +void j1939_simple_recv(struct j1939_priv *priv, struct sk_buff *skb) +{ + struct j1939_session *session; + + if (!skb->sk) + return; + + if (skb->sk->sk_family != AF_CAN || + skb->sk->sk_protocol != CAN_J1939) + return; + + j1939_session_list_lock(priv); + session = j1939_session_get_simple(priv, skb); + j1939_session_list_unlock(priv); + if (!session) { + netdev_warn(priv->ndev, + "%s: Received already invalidated message\n", + __func__); + return; + } + + j1939_session_timers_cancel(session); + j1939_session_deactivate(session); + j1939_session_put(session); +} + +int j1939_cancel_active_session(struct j1939_priv *priv, struct sock *sk) +{ + struct j1939_session *session, *saved; + + netdev_dbg(priv->ndev, "%s, sk: %p\n", __func__, sk); + j1939_session_list_lock(priv); + list_for_each_entry_safe(session, saved, + &priv->active_session_list, + active_session_list_entry) { + if (!sk || sk == session->sk) { + if (hrtimer_try_to_cancel(&session->txtimer) == 1) + j1939_session_put(session); + if (hrtimer_try_to_cancel(&session->rxtimer) == 1) + j1939_session_put(session); + + session->err = ESHUTDOWN; + j1939_session_deactivate_locked(session); + } + } + j1939_session_list_unlock(priv); + return NOTIFY_DONE; +} + +void j1939_tp_init(struct j1939_priv *priv) +{ + spin_lock_init(&priv->active_session_list_lock); + INIT_LIST_HEAD(&priv->active_session_list); + priv->tp_max_packet_size = J1939_MAX_ETP_PACKET_SIZE; +} diff --git a/net/can/proc.c b/net/can/proc.c new file mode 100644 index 000000000..bbce97825 --- /dev/null +++ b/net/can/proc.c @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* + * proc.c - procfs support for Protocol family CAN core module + * + * Copyright (c) 2002-2007 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "af_can.h" + +/* + * proc filenames for the PF_CAN core + */ + +#define CAN_PROC_STATS "stats" +#define CAN_PROC_RESET_STATS "reset_stats" +#define CAN_PROC_RCVLIST_ALL "rcvlist_all" +#define CAN_PROC_RCVLIST_FIL "rcvlist_fil" +#define CAN_PROC_RCVLIST_INV "rcvlist_inv" +#define CAN_PROC_RCVLIST_SFF "rcvlist_sff" +#define CAN_PROC_RCVLIST_EFF "rcvlist_eff" +#define CAN_PROC_RCVLIST_ERR "rcvlist_err" + +static int user_reset; + +static const char rx_list_name[][8] = { + [RX_ERR] = "rx_err", + [RX_ALL] = "rx_all", + [RX_FIL] = "rx_fil", + [RX_INV] = "rx_inv", +}; + +/* + * af_can statistics stuff + */ + +static void can_init_stats(struct net *net) +{ + struct can_pkg_stats *pkg_stats = net->can.pkg_stats; + struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; + /* + * This memset function is called from a timer context (when + * can_stattimer is active which is the default) OR in a process + * context (reading the proc_fs when can_stattimer is disabled). + */ + memset(pkg_stats, 0, sizeof(struct can_pkg_stats)); + pkg_stats->jiffies_init = jiffies; + + rcv_lists_stats->stats_reset++; + + if (user_reset) { + user_reset = 0; + rcv_lists_stats->user_reset++; + } +} + +static unsigned long calc_rate(unsigned long oldjif, unsigned long newjif, + unsigned long count) +{ + if (oldjif == newjif) + return 0; + + /* see can_stat_update() - this should NEVER happen! */ + if (count > (ULONG_MAX / HZ)) { + printk(KERN_ERR "can: calc_rate: count exceeded! 
%ld\n", + count); + return 99999999; + } + + return (count * HZ) / (newjif - oldjif); +} + +void can_stat_update(struct timer_list *t) +{ + struct net *net = from_timer(net, t, can.stattimer); + struct can_pkg_stats *pkg_stats = net->can.pkg_stats; + unsigned long j = jiffies; /* snapshot */ + + /* restart counting in timer context on user request */ + if (user_reset) + can_init_stats(net); + + /* restart counting on jiffies overflow */ + if (j < pkg_stats->jiffies_init) + can_init_stats(net); + + /* prevent overflow in calc_rate() */ + if (pkg_stats->rx_frames > (ULONG_MAX / HZ)) + can_init_stats(net); + + /* prevent overflow in calc_rate() */ + if (pkg_stats->tx_frames > (ULONG_MAX / HZ)) + can_init_stats(net); + + /* matches overflow - very improbable */ + if (pkg_stats->matches > (ULONG_MAX / 100)) + can_init_stats(net); + + /* calc total values */ + if (pkg_stats->rx_frames) + pkg_stats->total_rx_match_ratio = (pkg_stats->matches * 100) / + pkg_stats->rx_frames; + + pkg_stats->total_tx_rate = calc_rate(pkg_stats->jiffies_init, j, + pkg_stats->tx_frames); + pkg_stats->total_rx_rate = calc_rate(pkg_stats->jiffies_init, j, + pkg_stats->rx_frames); + + /* calc current values */ + if (pkg_stats->rx_frames_delta) + pkg_stats->current_rx_match_ratio = + (pkg_stats->matches_delta * 100) / + pkg_stats->rx_frames_delta; + + pkg_stats->current_tx_rate = calc_rate(0, HZ, pkg_stats->tx_frames_delta); + pkg_stats->current_rx_rate = calc_rate(0, HZ, pkg_stats->rx_frames_delta); + + /* check / update maximum values */ + if (pkg_stats->max_tx_rate < pkg_stats->current_tx_rate) + pkg_stats->max_tx_rate = pkg_stats->current_tx_rate; + + if (pkg_stats->max_rx_rate < pkg_stats->current_rx_rate) + pkg_stats->max_rx_rate = pkg_stats->current_rx_rate; + + if (pkg_stats->max_rx_match_ratio < pkg_stats->current_rx_match_ratio) + pkg_stats->max_rx_match_ratio = pkg_stats->current_rx_match_ratio; + + /* clear values for 'current rate' calculation */ + pkg_stats->tx_frames_delta = 0; + pkg_stats->rx_frames_delta = 0; + pkg_stats->matches_delta = 0; + + /* restart timer (one second) */ + mod_timer(&net->can.stattimer, round_jiffies(jiffies + HZ)); +} + +/* + * proc read functions + */ + +static void can_print_rcvlist(struct seq_file *m, struct hlist_head *rx_list, + struct net_device *dev) +{ + struct receiver *r; + + hlist_for_each_entry_rcu(r, rx_list, list) { + char *fmt = (r->can_id & CAN_EFF_FLAG)? + " %-5s %08x %08x %pK %pK %8ld %s\n" : + " %-5s %03x %08x %pK %pK %8ld %s\n"; + + seq_printf(m, fmt, DNAME(dev), r->can_id, r->mask, + r->func, r->data, r->matches, r->ident); + } +} + +static void can_print_recv_banner(struct seq_file *m) +{ + /* + * can1. 00000000 00000000 00000000 + * ....... 
0 tp20 + */ + if (IS_ENABLED(CONFIG_64BIT)) + seq_puts(m, " device can_id can_mask function userdata matches ident\n"); + else + seq_puts(m, " device can_id can_mask function userdata matches ident\n"); +} + +static int can_stats_proc_show(struct seq_file *m, void *v) +{ + struct net *net = m->private; + struct can_pkg_stats *pkg_stats = net->can.pkg_stats; + struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; + + seq_putc(m, '\n'); + seq_printf(m, " %8ld transmitted frames (TXF)\n", pkg_stats->tx_frames); + seq_printf(m, " %8ld received frames (RXF)\n", pkg_stats->rx_frames); + seq_printf(m, " %8ld matched frames (RXMF)\n", pkg_stats->matches); + + seq_putc(m, '\n'); + + if (net->can.stattimer.function == can_stat_update) { + seq_printf(m, " %8ld %% total match ratio (RXMR)\n", + pkg_stats->total_rx_match_ratio); + + seq_printf(m, " %8ld frames/s total tx rate (TXR)\n", + pkg_stats->total_tx_rate); + seq_printf(m, " %8ld frames/s total rx rate (RXR)\n", + pkg_stats->total_rx_rate); + + seq_putc(m, '\n'); + + seq_printf(m, " %8ld %% current match ratio (CRXMR)\n", + pkg_stats->current_rx_match_ratio); + + seq_printf(m, " %8ld frames/s current tx rate (CTXR)\n", + pkg_stats->current_tx_rate); + seq_printf(m, " %8ld frames/s current rx rate (CRXR)\n", + pkg_stats->current_rx_rate); + + seq_putc(m, '\n'); + + seq_printf(m, " %8ld %% max match ratio (MRXMR)\n", + pkg_stats->max_rx_match_ratio); + + seq_printf(m, " %8ld frames/s max tx rate (MTXR)\n", + pkg_stats->max_tx_rate); + seq_printf(m, " %8ld frames/s max rx rate (MRXR)\n", + pkg_stats->max_rx_rate); + + seq_putc(m, '\n'); + } + + seq_printf(m, " %8ld current receive list entries (CRCV)\n", + rcv_lists_stats->rcv_entries); + seq_printf(m, " %8ld maximum receive list entries (MRCV)\n", + rcv_lists_stats->rcv_entries_max); + + if (rcv_lists_stats->stats_reset) + seq_printf(m, "\n %8ld statistic resets (STR)\n", + rcv_lists_stats->stats_reset); + + if (rcv_lists_stats->user_reset) + seq_printf(m, " %8ld user statistic resets (USTR)\n", + rcv_lists_stats->user_reset); + + seq_putc(m, '\n'); + return 0; +} + +static int can_reset_stats_proc_show(struct seq_file *m, void *v) +{ + struct net *net = m->private; + struct can_rcv_lists_stats *rcv_lists_stats = net->can.rcv_lists_stats; + struct can_pkg_stats *pkg_stats = net->can.pkg_stats; + + user_reset = 1; + + if (net->can.stattimer.function == can_stat_update) { + seq_printf(m, "Scheduled statistic reset #%ld.\n", + rcv_lists_stats->stats_reset + 1); + } else { + if (pkg_stats->jiffies_init != jiffies) + can_init_stats(net); + + seq_printf(m, "Performed statistic reset #%ld.\n", + rcv_lists_stats->stats_reset); + } + return 0; +} + +static inline void can_rcvlist_proc_show_one(struct seq_file *m, int idx, + struct net_device *dev, + struct can_dev_rcv_lists *dev_rcv_lists) +{ + if (!hlist_empty(&dev_rcv_lists->rx[idx])) { + can_print_recv_banner(m); + can_print_rcvlist(m, &dev_rcv_lists->rx[idx], dev); + } else + seq_printf(m, " (%s: no entry)\n", DNAME(dev)); + +} + +static int can_rcvlist_proc_show(struct seq_file *m, void *v) +{ + /* double cast to prevent GCC warning */ + int idx = (int)(long)pde_data(m->file->f_inode); + struct net_device *dev; + struct can_dev_rcv_lists *dev_rcv_lists; + struct net *net = m->private; + + seq_printf(m, "\nreceive list '%s':\n", rx_list_name[idx]); + + rcu_read_lock(); + + /* receive list for 'all' CAN devices (dev == NULL) */ + dev_rcv_lists = net->can.rx_alldev_list; + can_rcvlist_proc_show_one(m, idx, NULL, dev_rcv_lists); + + /* 
receive list for registered CAN devices */ + for_each_netdev_rcu(net, dev) { + struct can_ml_priv *can_ml = can_get_ml_priv(dev); + + if (can_ml) + can_rcvlist_proc_show_one(m, idx, dev, + &can_ml->dev_rcv_lists); + } + + rcu_read_unlock(); + + seq_putc(m, '\n'); + return 0; +} + +static inline void can_rcvlist_proc_show_array(struct seq_file *m, + struct net_device *dev, + struct hlist_head *rcv_array, + unsigned int rcv_array_sz) +{ + unsigned int i; + int all_empty = 1; + + /* check whether at least one list is non-empty */ + for (i = 0; i < rcv_array_sz; i++) + if (!hlist_empty(&rcv_array[i])) { + all_empty = 0; + break; + } + + if (!all_empty) { + can_print_recv_banner(m); + for (i = 0; i < rcv_array_sz; i++) { + if (!hlist_empty(&rcv_array[i])) + can_print_rcvlist(m, &rcv_array[i], dev); + } + } else + seq_printf(m, " (%s: no entry)\n", DNAME(dev)); +} + +static int can_rcvlist_sff_proc_show(struct seq_file *m, void *v) +{ + struct net_device *dev; + struct can_dev_rcv_lists *dev_rcv_lists; + struct net *net = m->private; + + /* RX_SFF */ + seq_puts(m, "\nreceive list 'rx_sff':\n"); + + rcu_read_lock(); + + /* sff receive list for 'all' CAN devices (dev == NULL) */ + dev_rcv_lists = net->can.rx_alldev_list; + can_rcvlist_proc_show_array(m, NULL, dev_rcv_lists->rx_sff, + ARRAY_SIZE(dev_rcv_lists->rx_sff)); + + /* sff receive list for registered CAN devices */ + for_each_netdev_rcu(net, dev) { + struct can_ml_priv *can_ml = can_get_ml_priv(dev); + + if (can_ml) { + dev_rcv_lists = &can_ml->dev_rcv_lists; + can_rcvlist_proc_show_array(m, dev, dev_rcv_lists->rx_sff, + ARRAY_SIZE(dev_rcv_lists->rx_sff)); + } + } + + rcu_read_unlock(); + + seq_putc(m, '\n'); + return 0; +} + +static int can_rcvlist_eff_proc_show(struct seq_file *m, void *v) +{ + struct net_device *dev; + struct can_dev_rcv_lists *dev_rcv_lists; + struct net *net = m->private; + + /* RX_EFF */ + seq_puts(m, "\nreceive list 'rx_eff':\n"); + + rcu_read_lock(); + + /* eff receive list for 'all' CAN devices (dev == NULL) */ + dev_rcv_lists = net->can.rx_alldev_list; + can_rcvlist_proc_show_array(m, NULL, dev_rcv_lists->rx_eff, + ARRAY_SIZE(dev_rcv_lists->rx_eff)); + + /* eff receive list for registered CAN devices */ + for_each_netdev_rcu(net, dev) { + struct can_ml_priv *can_ml = can_get_ml_priv(dev); + + if (can_ml) { + dev_rcv_lists = &can_ml->dev_rcv_lists; + can_rcvlist_proc_show_array(m, dev, dev_rcv_lists->rx_eff, + ARRAY_SIZE(dev_rcv_lists->rx_eff)); + } + } + + rcu_read_unlock(); + + seq_putc(m, '\n'); + return 0; +} + +/* + * can_init_proc - create main CAN proc directory and procfs entries + */ +void can_init_proc(struct net *net) +{ + /* create /proc/net/can directory */ + net->can.proc_dir = proc_net_mkdir(net, "can", net->proc_net); + + if (!net->can.proc_dir) { + printk(KERN_INFO "can: failed to create /proc/net/can . 
" + "CONFIG_PROC_FS missing?\n"); + return; + } + + /* own procfs entries from the AF_CAN core */ + net->can.pde_stats = proc_create_net_single(CAN_PROC_STATS, 0644, + net->can.proc_dir, can_stats_proc_show, NULL); + net->can.pde_reset_stats = proc_create_net_single(CAN_PROC_RESET_STATS, + 0644, net->can.proc_dir, can_reset_stats_proc_show, + NULL); + net->can.pde_rcvlist_err = proc_create_net_single(CAN_PROC_RCVLIST_ERR, + 0644, net->can.proc_dir, can_rcvlist_proc_show, + (void *)RX_ERR); + net->can.pde_rcvlist_all = proc_create_net_single(CAN_PROC_RCVLIST_ALL, + 0644, net->can.proc_dir, can_rcvlist_proc_show, + (void *)RX_ALL); + net->can.pde_rcvlist_fil = proc_create_net_single(CAN_PROC_RCVLIST_FIL, + 0644, net->can.proc_dir, can_rcvlist_proc_show, + (void *)RX_FIL); + net->can.pde_rcvlist_inv = proc_create_net_single(CAN_PROC_RCVLIST_INV, + 0644, net->can.proc_dir, can_rcvlist_proc_show, + (void *)RX_INV); + net->can.pde_rcvlist_eff = proc_create_net_single(CAN_PROC_RCVLIST_EFF, + 0644, net->can.proc_dir, can_rcvlist_eff_proc_show, NULL); + net->can.pde_rcvlist_sff = proc_create_net_single(CAN_PROC_RCVLIST_SFF, + 0644, net->can.proc_dir, can_rcvlist_sff_proc_show, NULL); +} + +/* + * can_remove_proc - remove procfs entries and main CAN proc directory + */ +void can_remove_proc(struct net *net) +{ + if (!net->can.proc_dir) + return; + + if (net->can.pde_stats) + remove_proc_entry(CAN_PROC_STATS, net->can.proc_dir); + + if (net->can.pde_reset_stats) + remove_proc_entry(CAN_PROC_RESET_STATS, net->can.proc_dir); + + if (net->can.pde_rcvlist_err) + remove_proc_entry(CAN_PROC_RCVLIST_ERR, net->can.proc_dir); + + if (net->can.pde_rcvlist_all) + remove_proc_entry(CAN_PROC_RCVLIST_ALL, net->can.proc_dir); + + if (net->can.pde_rcvlist_fil) + remove_proc_entry(CAN_PROC_RCVLIST_FIL, net->can.proc_dir); + + if (net->can.pde_rcvlist_inv) + remove_proc_entry(CAN_PROC_RCVLIST_INV, net->can.proc_dir); + + if (net->can.pde_rcvlist_eff) + remove_proc_entry(CAN_PROC_RCVLIST_EFF, net->can.proc_dir); + + if (net->can.pde_rcvlist_sff) + remove_proc_entry(CAN_PROC_RCVLIST_SFF, net->can.proc_dir); + + remove_proc_entry("can", net->proc_net); +} diff --git a/net/can/raw.c b/net/can/raw.c new file mode 100644 index 000000000..488320738 --- /dev/null +++ b/net/can/raw.c @@ -0,0 +1,1024 @@ +// SPDX-License-Identifier: (GPL-2.0 OR BSD-3-Clause) +/* raw.c - Raw sockets for protocol family CAN + * + * Copyright (c) 2002-2007 Volkswagen Group Electronic Research + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of Volkswagen nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. 
+ * + * The provided data structures and external interfaces from this code + * are not restricted to be used by modules with a GPL compatible license. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for can_is_canxl_dev_mtu() */ +#include +#include +#include +#include + +MODULE_DESCRIPTION("PF_CAN raw protocol"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_AUTHOR("Urs Thuermann "); +MODULE_ALIAS("can-proto-1"); + +#define RAW_MIN_NAMELEN CAN_REQUIRED_SIZE(struct sockaddr_can, can_ifindex) + +#define MASK_ALL 0 + +/* A raw socket has a list of can_filters attached to it, each receiving + * the CAN frames matching that filter. If the filter list is empty, + * no CAN frames will be received by the socket. The default after + * opening the socket, is to have one filter which receives all frames. + * The filter list is allocated dynamically with the exception of the + * list containing only one item. This common case is optimized by + * storing the single filter in dfilter, to avoid using dynamic memory. + */ + +struct uniqframe { + int skbcnt; + const struct sk_buff *skb; + unsigned int join_rx_count; +}; + +struct raw_sock { + struct sock sk; + int bound; + int ifindex; + struct net_device *dev; + netdevice_tracker dev_tracker; + struct list_head notifier; + int loopback; + int recv_own_msgs; + int fd_frames; + int xl_frames; + int join_filters; + int count; /* number of active filters */ + struct can_filter dfilter; /* default/single filter */ + struct can_filter *filter; /* pointer to filter(s) */ + can_err_mask_t err_mask; + struct uniqframe __percpu *uniq; +}; + +static LIST_HEAD(raw_notifier_list); +static DEFINE_SPINLOCK(raw_notifier_lock); +static struct raw_sock *raw_busy_notifier; + +/* Return pointer to store the extra msg flags for raw_recvmsg(). + * We use the space of one unsigned int beyond the 'struct sockaddr_can' + * in skb->cb. 
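+ * sock_skb_cb_check_size() below ensures this still fits into skb->cb.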
+ */ +static inline unsigned int *raw_flags(struct sk_buff *skb) +{ + sock_skb_cb_check_size(sizeof(struct sockaddr_can) + + sizeof(unsigned int)); + + /* return pointer after struct sockaddr_can */ + return (unsigned int *)(&((struct sockaddr_can *)skb->cb)[1]); +} + +static inline struct raw_sock *raw_sk(const struct sock *sk) +{ + return (struct raw_sock *)sk; +} + +static void raw_rcv(struct sk_buff *oskb, void *data) +{ + struct sock *sk = (struct sock *)data; + struct raw_sock *ro = raw_sk(sk); + struct sockaddr_can *addr; + struct sk_buff *skb; + unsigned int *pflags; + + /* check the received tx sock reference */ + if (!ro->recv_own_msgs && oskb->sk == sk) + return; + + /* make sure to not pass oversized frames to the socket */ + if ((!ro->fd_frames && can_is_canfd_skb(oskb)) || + (!ro->xl_frames && can_is_canxl_skb(oskb))) + return; + + /* eliminate multiple filter matches for the same skb */ + if (this_cpu_ptr(ro->uniq)->skb == oskb && + this_cpu_ptr(ro->uniq)->skbcnt == can_skb_prv(oskb)->skbcnt) { + if (!ro->join_filters) + return; + + this_cpu_inc(ro->uniq->join_rx_count); + /* drop frame until all enabled filters matched */ + if (this_cpu_ptr(ro->uniq)->join_rx_count < ro->count) + return; + } else { + this_cpu_ptr(ro->uniq)->skb = oskb; + this_cpu_ptr(ro->uniq)->skbcnt = can_skb_prv(oskb)->skbcnt; + this_cpu_ptr(ro->uniq)->join_rx_count = 1; + /* drop first frame to check all enabled filters? */ + if (ro->join_filters && ro->count > 1) + return; + } + + /* clone the given skb to be able to enqueue it into the rcv queue */ + skb = skb_clone(oskb, GFP_ATOMIC); + if (!skb) + return; + + /* Put the datagram to the queue so that raw_recvmsg() can get + * it from there. We need to pass the interface index to + * raw_recvmsg(). We pass a whole struct sockaddr_can in + * skb->cb containing the interface index. 
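+	 * The CAN specific message flags are stored right behind it (see raw_flags()).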
+ */ + + sock_skb_cb_check_size(sizeof(struct sockaddr_can)); + addr = (struct sockaddr_can *)skb->cb; + memset(addr, 0, sizeof(*addr)); + addr->can_family = AF_CAN; + addr->can_ifindex = skb->dev->ifindex; + + /* add CAN specific message flags for raw_recvmsg() */ + pflags = raw_flags(skb); + *pflags = 0; + if (oskb->sk) + *pflags |= MSG_DONTROUTE; + if (oskb->sk == sk) + *pflags |= MSG_CONFIRM; + + if (sock_queue_rcv_skb(sk, skb) < 0) + kfree_skb(skb); +} + +static int raw_enable_filters(struct net *net, struct net_device *dev, + struct sock *sk, struct can_filter *filter, + int count) +{ + int err = 0; + int i; + + for (i = 0; i < count; i++) { + err = can_rx_register(net, dev, filter[i].can_id, + filter[i].can_mask, + raw_rcv, sk, "raw", sk); + if (err) { + /* clean up successfully registered filters */ + while (--i >= 0) + can_rx_unregister(net, dev, filter[i].can_id, + filter[i].can_mask, + raw_rcv, sk); + break; + } + } + + return err; +} + +static int raw_enable_errfilter(struct net *net, struct net_device *dev, + struct sock *sk, can_err_mask_t err_mask) +{ + int err = 0; + + if (err_mask) + err = can_rx_register(net, dev, 0, err_mask | CAN_ERR_FLAG, + raw_rcv, sk, "raw", sk); + + return err; +} + +static void raw_disable_filters(struct net *net, struct net_device *dev, + struct sock *sk, struct can_filter *filter, + int count) +{ + int i; + + for (i = 0; i < count; i++) + can_rx_unregister(net, dev, filter[i].can_id, + filter[i].can_mask, raw_rcv, sk); +} + +static inline void raw_disable_errfilter(struct net *net, + struct net_device *dev, + struct sock *sk, + can_err_mask_t err_mask) + +{ + if (err_mask) + can_rx_unregister(net, dev, 0, err_mask | CAN_ERR_FLAG, + raw_rcv, sk); +} + +static inline void raw_disable_allfilters(struct net *net, + struct net_device *dev, + struct sock *sk) +{ + struct raw_sock *ro = raw_sk(sk); + + raw_disable_filters(net, dev, sk, ro->filter, ro->count); + raw_disable_errfilter(net, dev, sk, ro->err_mask); +} + +static int raw_enable_allfilters(struct net *net, struct net_device *dev, + struct sock *sk) +{ + struct raw_sock *ro = raw_sk(sk); + int err; + + err = raw_enable_filters(net, dev, sk, ro->filter, ro->count); + if (!err) { + err = raw_enable_errfilter(net, dev, sk, ro->err_mask); + if (err) + raw_disable_filters(net, dev, sk, ro->filter, + ro->count); + } + + return err; +} + +static void raw_notify(struct raw_sock *ro, unsigned long msg, + struct net_device *dev) +{ + struct sock *sk = &ro->sk; + + if (!net_eq(dev_net(dev), sock_net(sk))) + return; + + if (ro->dev != dev) + return; + + switch (msg) { + case NETDEV_UNREGISTER: + lock_sock(sk); + /* remove current filters & unregister */ + if (ro->bound) { + raw_disable_allfilters(dev_net(dev), dev, sk); + netdev_put(dev, &ro->dev_tracker); + } + + if (ro->count > 1) + kfree(ro->filter); + + ro->ifindex = 0; + ro->bound = 0; + ro->dev = NULL; + ro->count = 0; + release_sock(sk); + + sk->sk_err = ENODEV; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + break; + + case NETDEV_DOWN: + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + break; + } +} + +static int raw_notifier(struct notifier_block *nb, unsigned long msg, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (dev->type != ARPHRD_CAN) + return NOTIFY_DONE; + if (msg != NETDEV_UNREGISTER && msg != NETDEV_DOWN) + return NOTIFY_DONE; + if (unlikely(raw_busy_notifier)) /* Check for reentrant bug. 
*/ + return NOTIFY_DONE; + + spin_lock(&raw_notifier_lock); + list_for_each_entry(raw_busy_notifier, &raw_notifier_list, notifier) { + spin_unlock(&raw_notifier_lock); + raw_notify(raw_busy_notifier, msg, dev); + spin_lock(&raw_notifier_lock); + } + raw_busy_notifier = NULL; + spin_unlock(&raw_notifier_lock); + return NOTIFY_DONE; +} + +static int raw_init(struct sock *sk) +{ + struct raw_sock *ro = raw_sk(sk); + + ro->bound = 0; + ro->ifindex = 0; + ro->dev = NULL; + + /* set default filter to single entry dfilter */ + ro->dfilter.can_id = 0; + ro->dfilter.can_mask = MASK_ALL; + ro->filter = &ro->dfilter; + ro->count = 1; + + /* set default loopback behaviour */ + ro->loopback = 1; + ro->recv_own_msgs = 0; + ro->fd_frames = 0; + ro->xl_frames = 0; + ro->join_filters = 0; + + /* alloc_percpu provides zero'ed memory */ + ro->uniq = alloc_percpu(struct uniqframe); + if (unlikely(!ro->uniq)) + return -ENOMEM; + + /* set notifier */ + spin_lock(&raw_notifier_lock); + list_add_tail(&ro->notifier, &raw_notifier_list); + spin_unlock(&raw_notifier_lock); + + return 0; +} + +static int raw_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct raw_sock *ro; + + if (!sk) + return 0; + + ro = raw_sk(sk); + + spin_lock(&raw_notifier_lock); + while (raw_busy_notifier == ro) { + spin_unlock(&raw_notifier_lock); + schedule_timeout_uninterruptible(1); + spin_lock(&raw_notifier_lock); + } + list_del(&ro->notifier); + spin_unlock(&raw_notifier_lock); + + rtnl_lock(); + lock_sock(sk); + + /* remove current filters & unregister */ + if (ro->bound) { + if (ro->dev) { + raw_disable_allfilters(dev_net(ro->dev), ro->dev, sk); + netdev_put(ro->dev, &ro->dev_tracker); + } else { + raw_disable_allfilters(sock_net(sk), NULL, sk); + } + } + + if (ro->count > 1) + kfree(ro->filter); + + ro->ifindex = 0; + ro->bound = 0; + ro->dev = NULL; + ro->count = 0; + free_percpu(ro->uniq); + + sock_orphan(sk); + sock->sk = NULL; + + release_sock(sk); + rtnl_unlock(); + + sock_put(sk); + + return 0; +} + +static int raw_bind(struct socket *sock, struct sockaddr *uaddr, int len) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct raw_sock *ro = raw_sk(sk); + struct net_device *dev = NULL; + int ifindex; + int err = 0; + int notify_enetdown = 0; + + if (len < RAW_MIN_NAMELEN) + return -EINVAL; + if (addr->can_family != AF_CAN) + return -EINVAL; + + rtnl_lock(); + lock_sock(sk); + + if (ro->bound && addr->can_ifindex == ro->ifindex) + goto out; + + if (addr->can_ifindex) { + dev = dev_get_by_index(sock_net(sk), addr->can_ifindex); + if (!dev) { + err = -ENODEV; + goto out; + } + if (dev->type != ARPHRD_CAN) { + err = -ENODEV; + goto out_put_dev; + } + + if (!(dev->flags & IFF_UP)) + notify_enetdown = 1; + + ifindex = dev->ifindex; + + /* filters set by default/setsockopt */ + err = raw_enable_allfilters(sock_net(sk), dev, sk); + if (err) + goto out_put_dev; + + } else { + ifindex = 0; + + /* filters set by default/setsockopt */ + err = raw_enable_allfilters(sock_net(sk), NULL, sk); + } + + if (!err) { + if (ro->bound) { + /* unregister old filters */ + if (ro->dev) { + raw_disable_allfilters(dev_net(ro->dev), + ro->dev, sk); + /* drop reference to old ro->dev */ + netdev_put(ro->dev, &ro->dev_tracker); + } else { + raw_disable_allfilters(sock_net(sk), NULL, sk); + } + } + ro->ifindex = ifindex; + ro->bound = 1; + /* bind() ok -> hold a reference for new ro->dev */ + ro->dev = dev; + if (ro->dev) + netdev_hold(ro->dev, &ro->dev_tracker, GFP_KERNEL); + } + +out_put_dev: 
+ /* remove potential reference from dev_get_by_index() */ + if (dev) + dev_put(dev); +out: + release_sock(sk); + rtnl_unlock(); + + if (notify_enetdown) { + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + + return err; +} + +static int raw_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_can *addr = (struct sockaddr_can *)uaddr; + struct sock *sk = sock->sk; + struct raw_sock *ro = raw_sk(sk); + + if (peer) + return -EOPNOTSUPP; + + memset(addr, 0, RAW_MIN_NAMELEN); + addr->can_family = AF_CAN; + addr->can_ifindex = ro->ifindex; + + return RAW_MIN_NAMELEN; +} + +static int raw_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct raw_sock *ro = raw_sk(sk); + struct can_filter *filter = NULL; /* dyn. alloc'ed filters */ + struct can_filter sfilter; /* single filter */ + struct net_device *dev = NULL; + can_err_mask_t err_mask = 0; + int count = 0; + int err = 0; + + if (level != SOL_CAN_RAW) + return -EINVAL; + + switch (optname) { + case CAN_RAW_FILTER: + if (optlen % sizeof(struct can_filter) != 0) + return -EINVAL; + + if (optlen > CAN_RAW_FILTER_MAX * sizeof(struct can_filter)) + return -EINVAL; + + count = optlen / sizeof(struct can_filter); + + if (count > 1) { + /* filter does not fit into dfilter => alloc space */ + filter = memdup_sockptr(optval, optlen); + if (IS_ERR(filter)) + return PTR_ERR(filter); + } else if (count == 1) { + if (copy_from_sockptr(&sfilter, optval, sizeof(sfilter))) + return -EFAULT; + } + + rtnl_lock(); + lock_sock(sk); + + dev = ro->dev; + if (ro->bound && dev) { + if (dev->reg_state != NETREG_REGISTERED) { + if (count > 1) + kfree(filter); + err = -ENODEV; + goto out_fil; + } + } + + if (ro->bound) { + /* (try to) register the new filters */ + if (count == 1) + err = raw_enable_filters(sock_net(sk), dev, sk, + &sfilter, 1); + else + err = raw_enable_filters(sock_net(sk), dev, sk, + filter, count); + if (err) { + if (count > 1) + kfree(filter); + goto out_fil; + } + + /* remove old filter registrations */ + raw_disable_filters(sock_net(sk), dev, sk, ro->filter, + ro->count); + } + + /* remove old filter space */ + if (ro->count > 1) + kfree(ro->filter); + + /* link new filters to the socket */ + if (count == 1) { + /* copy filter data for single filter */ + ro->dfilter = sfilter; + filter = &ro->dfilter; + } + ro->filter = filter; + ro->count = count; + + out_fil: + release_sock(sk); + rtnl_unlock(); + + break; + + case CAN_RAW_ERR_FILTER: + if (optlen != sizeof(err_mask)) + return -EINVAL; + + if (copy_from_sockptr(&err_mask, optval, optlen)) + return -EFAULT; + + err_mask &= CAN_ERR_MASK; + + rtnl_lock(); + lock_sock(sk); + + dev = ro->dev; + if (ro->bound && dev) { + if (dev->reg_state != NETREG_REGISTERED) { + err = -ENODEV; + goto out_err; + } + } + + /* remove current error mask */ + if (ro->bound) { + /* (try to) register the new err_mask */ + err = raw_enable_errfilter(sock_net(sk), dev, sk, + err_mask); + + if (err) + goto out_err; + + /* remove old err_mask registration */ + raw_disable_errfilter(sock_net(sk), dev, sk, + ro->err_mask); + } + + /* link new err_mask to the socket */ + ro->err_mask = err_mask; + + out_err: + release_sock(sk); + rtnl_unlock(); + + break; + + case CAN_RAW_LOOPBACK: + if (optlen != sizeof(ro->loopback)) + return -EINVAL; + + if (copy_from_sockptr(&ro->loopback, optval, optlen)) + return -EFAULT; + + break; + + case CAN_RAW_RECV_OWN_MSGS: + if (optlen != 
sizeof(ro->recv_own_msgs)) + return -EINVAL; + + if (copy_from_sockptr(&ro->recv_own_msgs, optval, optlen)) + return -EFAULT; + + break; + + case CAN_RAW_FD_FRAMES: + if (optlen != sizeof(ro->fd_frames)) + return -EINVAL; + + if (copy_from_sockptr(&ro->fd_frames, optval, optlen)) + return -EFAULT; + + /* Enabling CAN XL includes CAN FD */ + if (ro->xl_frames && !ro->fd_frames) { + ro->fd_frames = ro->xl_frames; + return -EINVAL; + } + break; + + case CAN_RAW_XL_FRAMES: + if (optlen != sizeof(ro->xl_frames)) + return -EINVAL; + + if (copy_from_sockptr(&ro->xl_frames, optval, optlen)) + return -EFAULT; + + /* Enabling CAN XL includes CAN FD */ + if (ro->xl_frames) + ro->fd_frames = ro->xl_frames; + break; + + case CAN_RAW_JOIN_FILTERS: + if (optlen != sizeof(ro->join_filters)) + return -EINVAL; + + if (copy_from_sockptr(&ro->join_filters, optval, optlen)) + return -EFAULT; + + break; + + default: + return -ENOPROTOOPT; + } + return err; +} + +static int raw_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct raw_sock *ro = raw_sk(sk); + int len; + void *val; + int err = 0; + + if (level != SOL_CAN_RAW) + return -EINVAL; + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + switch (optname) { + case CAN_RAW_FILTER: + lock_sock(sk); + if (ro->count > 0) { + int fsize = ro->count * sizeof(struct can_filter); + + /* user space buffer to small for filter list? */ + if (len < fsize) { + /* return -ERANGE and needed space in optlen */ + err = -ERANGE; + if (put_user(fsize, optlen)) + err = -EFAULT; + } else { + if (len > fsize) + len = fsize; + if (copy_to_user(optval, ro->filter, len)) + err = -EFAULT; + } + } else { + len = 0; + } + release_sock(sk); + + if (!err) + err = put_user(len, optlen); + return err; + + case CAN_RAW_ERR_FILTER: + if (len > sizeof(can_err_mask_t)) + len = sizeof(can_err_mask_t); + val = &ro->err_mask; + break; + + case CAN_RAW_LOOPBACK: + if (len > sizeof(int)) + len = sizeof(int); + val = &ro->loopback; + break; + + case CAN_RAW_RECV_OWN_MSGS: + if (len > sizeof(int)) + len = sizeof(int); + val = &ro->recv_own_msgs; + break; + + case CAN_RAW_FD_FRAMES: + if (len > sizeof(int)) + len = sizeof(int); + val = &ro->fd_frames; + break; + + case CAN_RAW_XL_FRAMES: + if (len > sizeof(int)) + len = sizeof(int); + val = &ro->xl_frames; + break; + + case CAN_RAW_JOIN_FILTERS: + if (len > sizeof(int)) + len = sizeof(int); + val = &ro->join_filters; + break; + + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, val, len)) + return -EFAULT; + return 0; +} + +static bool raw_bad_txframe(struct raw_sock *ro, struct sk_buff *skb, int mtu) +{ + /* Classical CAN -> no checks for flags and device capabilities */ + if (can_is_can_skb(skb)) + return false; + + /* CAN FD -> needs to be enabled and a CAN FD or CAN XL device */ + if (ro->fd_frames && can_is_canfd_skb(skb) && + (mtu == CANFD_MTU || can_is_canxl_dev_mtu(mtu))) + return false; + + /* CAN XL -> needs to be enabled and a CAN XL device */ + if (ro->xl_frames && can_is_canxl_skb(skb) && + can_is_canxl_dev_mtu(mtu)) + return false; + + return true; +} + +static int raw_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) +{ + struct sock *sk = sock->sk; + struct raw_sock *ro = raw_sk(sk); + struct sockcm_cookie sockc; + struct sk_buff *skb; + struct net_device *dev; + int ifindex; + int err = -EINVAL; + + /* check for valid CAN frame sizes */ + if 
(size < CANXL_HDR_SIZE + CANXL_MIN_DLEN || size > CANXL_MTU) + return -EINVAL; + + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_can *, addr, msg->msg_name); + + if (msg->msg_namelen < RAW_MIN_NAMELEN) + return -EINVAL; + + if (addr->can_family != AF_CAN) + return -EINVAL; + + ifindex = addr->can_ifindex; + } else { + ifindex = ro->ifindex; + } + + dev = dev_get_by_index(sock_net(sk), ifindex); + if (!dev) + return -ENXIO; + + skb = sock_alloc_send_skb(sk, size + sizeof(struct can_skb_priv), + msg->msg_flags & MSG_DONTWAIT, &err); + if (!skb) + goto put_dev; + + can_skb_reserve(skb); + can_skb_prv(skb)->ifindex = dev->ifindex; + can_skb_prv(skb)->skbcnt = 0; + + /* fill the skb before testing for valid CAN frames */ + err = memcpy_from_msg(skb_put(skb, size), msg, size); + if (err < 0) + goto free_skb; + + err = -EINVAL; + if (raw_bad_txframe(ro, skb, dev->mtu)) + goto free_skb; + + sockcm_init(&sockc, sk); + if (msg->msg_controllen) { + err = sock_cmsg_send(sk, msg, &sockc); + if (unlikely(err)) + goto free_skb; + } + + skb->dev = dev; + skb->priority = sk->sk_priority; + skb->mark = sk->sk_mark; + skb->tstamp = sockc.transmit_time; + + skb_setup_tx_timestamp(skb, sockc.tsflags); + + err = can_send(skb, ro->loopback); + + dev_put(dev); + + if (err) + goto send_failed; + + return size; + +free_skb: + kfree_skb(skb); +put_dev: + dev_put(dev); +send_failed: + return err; +} + +static int raw_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int err = 0; + + if (flags & MSG_ERRQUEUE) + return sock_recv_errqueue(sk, msg, size, + SOL_CAN_RAW, SCM_CAN_RAW_ERRQUEUE); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + return err; + + if (size < skb->len) + msg->msg_flags |= MSG_TRUNC; + else + size = skb->len; + + err = memcpy_to_msg(msg, skb->data, size); + if (err < 0) { + skb_free_datagram(sk, skb); + return err; + } + + sock_recv_cmsgs(msg, sk, skb); + + if (msg->msg_name) { + __sockaddr_check_size(RAW_MIN_NAMELEN); + msg->msg_namelen = RAW_MIN_NAMELEN; + memcpy(msg->msg_name, skb->cb, msg->msg_namelen); + } + + /* assign the flags that have been recorded in raw_rcv() */ + msg->msg_flags |= *(raw_flags(skb)); + + skb_free_datagram(sk, skb); + + return size; +} + +static int raw_sock_no_ioctlcmd(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* no ioctls for socket layer -> hand it down to NIC layer */ + return -ENOIOCTLCMD; +} + +static const struct proto_ops raw_ops = { + .family = PF_CAN, + .release = raw_release, + .bind = raw_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = raw_getname, + .poll = datagram_poll, + .ioctl = raw_sock_no_ioctlcmd, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = raw_setsockopt, + .getsockopt = raw_getsockopt, + .sendmsg = raw_sendmsg, + .recvmsg = raw_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct proto raw_proto __read_mostly = { + .name = "CAN_RAW", + .owner = THIS_MODULE, + .obj_size = sizeof(struct raw_sock), + .init = raw_init, +}; + +static const struct can_proto raw_can_proto = { + .type = SOCK_RAW, + .protocol = CAN_RAW, + .ops = &raw_ops, + .prot = &raw_proto, +}; + +static struct notifier_block canraw_notifier = { + .notifier_call = raw_notifier +}; + +static __init int raw_module_init(void) +{ + int err; + + pr_info("can: raw protocol\n"); + + err = 
register_netdevice_notifier(&canraw_notifier); + if (err) + return err; + + err = can_proto_register(&raw_can_proto); + if (err < 0) { + pr_err("can: registration of raw protocol failed\n"); + goto register_proto_failed; + } + + return 0; + +register_proto_failed: + unregister_netdevice_notifier(&canraw_notifier); + return err; +} + +static __exit void raw_module_exit(void) +{ + can_proto_unregister(&raw_can_proto); + unregister_netdevice_notifier(&canraw_notifier); +} + +module_init(raw_module_init); +module_exit(raw_module_exit); diff --git a/net/ceph/Kconfig b/net/ceph/Kconfig new file mode 100644 index 000000000..c5c4eef3a --- /dev/null +++ b/net/ceph/Kconfig @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: GPL-2.0-only +config CEPH_LIB + tristate "Ceph core library" + depends on INET + select LIBCRC32C + select CRYPTO_AES + select CRYPTO_CBC + select CRYPTO_GCM + select CRYPTO_HMAC + select CRYPTO_SHA256 + select CRYPTO + select KEYS + default n + help + Choose Y or M here to include cephlib, which provides the + common functionality to both the Ceph filesystem and + to the rados block device (rbd). + + More information at https://ceph.io/. + + If unsure, say N. + +config CEPH_LIB_PRETTYDEBUG + bool "Include file:line in ceph debug output" + depends on CEPH_LIB + default n + help + If you say Y here, debug output will include a filename and + line to aid debugging. This increases kernel size and slows + execution slightly when debug call sites are enabled (e.g., + via CONFIG_DYNAMIC_DEBUG). + + If unsure, say N. + +config CEPH_LIB_USE_DNS_RESOLVER + bool "Use in-kernel support for DNS lookup" + depends on CEPH_LIB + select DNS_RESOLVER + default n + help + If you say Y here, hostnames (e.g. monitor addresses) will + be resolved using the CONFIG_DNS_RESOLVER facility. + + For information on how to use CONFIG_DNS_RESOLVER consult + Documentation/networking/dns_resolver.rst + + If unsure, say N. diff --git a/net/ceph/Makefile b/net/ceph/Makefile new file mode 100644 index 000000000..8802a0c01 --- /dev/null +++ b/net/ceph/Makefile @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for CEPH filesystem. +# +obj-$(CONFIG_CEPH_LIB) += libceph.o + +libceph-y := ceph_common.o messenger.o msgpool.o buffer.o pagelist.o \ + mon_client.o decode.o \ + cls_lock_client.o \ + osd_client.o osdmap.o crush/crush.o crush/mapper.o crush/hash.o \ + striper.o \ + debugfs.o \ + auth.o auth_none.o \ + crypto.o armor.o \ + auth_x.o \ + ceph_strings.o ceph_hash.o \ + pagevec.o snapshot.o string_table.o \ + messenger_v1.o messenger_v2.o diff --git a/net/ceph/armor.c b/net/ceph/armor.c new file mode 100644 index 000000000..0db806592 --- /dev/null +++ b/net/ceph/armor.c @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +int ceph_armor(char *dst, const char *src, const char *end); +int ceph_unarmor(char *dst, const char *src, const char *end); + +/* + * base64 encode/decode. 
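+ * The encoder inserts a newline after every 64 output characters.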
+ */ + +static const char *pem_key = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +static int encode_bits(int c) +{ + return pem_key[c]; +} + +static int decode_bits(char c) +{ + if (c >= 'A' && c <= 'Z') + return c - 'A'; + if (c >= 'a' && c <= 'z') + return c - 'a' + 26; + if (c >= '0' && c <= '9') + return c - '0' + 52; + if (c == '+') + return 62; + if (c == '/') + return 63; + if (c == '=') + return 0; /* just non-negative, please */ + return -EINVAL; +} + +int ceph_armor(char *dst, const char *src, const char *end) +{ + int olen = 0; + int line = 0; + + while (src < end) { + unsigned char a, b, c; + + a = *src++; + *dst++ = encode_bits(a >> 2); + if (src < end) { + b = *src++; + *dst++ = encode_bits(((a & 3) << 4) | (b >> 4)); + if (src < end) { + c = *src++; + *dst++ = encode_bits(((b & 15) << 2) | + (c >> 6)); + *dst++ = encode_bits(c & 63); + } else { + *dst++ = encode_bits((b & 15) << 2); + *dst++ = '='; + } + } else { + *dst++ = encode_bits(((a & 3) << 4)); + *dst++ = '='; + *dst++ = '='; + } + olen += 4; + line += 4; + if (line == 64) { + line = 0; + *(dst++) = '\n'; + olen++; + } + } + return olen; +} + +int ceph_unarmor(char *dst, const char *src, const char *end) +{ + int olen = 0; + + while (src < end) { + int a, b, c, d; + + if (src[0] == '\n') { + src++; + continue; + } + if (src + 4 > end) + return -EINVAL; + a = decode_bits(src[0]); + b = decode_bits(src[1]); + c = decode_bits(src[2]); + d = decode_bits(src[3]); + if (a < 0 || b < 0 || c < 0 || d < 0) + return -EINVAL; + + *dst++ = (a << 2) | (b >> 4); + if (src[2] == '=') + return olen + 1; + *dst++ = ((b & 15) << 4) | (c >> 2); + if (src[3] == '=') + return olen + 2; + *dst++ = ((c & 3) << 6) | d; + olen += 3; + src += 4; + } + return olen; +} diff --git a/net/ceph/auth.c b/net/ceph/auth.c new file mode 100644 index 000000000..d38c9eadb --- /dev/null +++ b/net/ceph/auth.c @@ -0,0 +1,659 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include + +#include +#include +#include +#include +#include "auth_none.h" +#include "auth_x.h" + + +/* + * get protocol handler + */ +static u32 supported_protocols[] = { + CEPH_AUTH_NONE, + CEPH_AUTH_CEPHX +}; + +static int init_protocol(struct ceph_auth_client *ac, int proto) +{ + dout("%s proto %d\n", __func__, proto); + + switch (proto) { + case CEPH_AUTH_NONE: + return ceph_auth_none_init(ac); + case CEPH_AUTH_CEPHX: + return ceph_x_init(ac); + default: + pr_err("bad auth protocol %d\n", proto); + return -EINVAL; + } +} + +void ceph_auth_set_global_id(struct ceph_auth_client *ac, u64 global_id) +{ + dout("%s global_id %llu\n", __func__, global_id); + + if (!global_id) + pr_err("got zero global_id\n"); + + if (ac->global_id && global_id != ac->global_id) + pr_err("global_id changed from %llu to %llu\n", ac->global_id, + global_id); + + ac->global_id = global_id; +} + +/* + * setup, teardown. 
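+ * The concrete auth protocol handler is only selected later, during negotiation.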
+ */ +struct ceph_auth_client *ceph_auth_init(const char *name, + const struct ceph_crypto_key *key, + const int *con_modes) +{ + struct ceph_auth_client *ac; + + ac = kzalloc(sizeof(*ac), GFP_NOFS); + if (!ac) + return ERR_PTR(-ENOMEM); + + mutex_init(&ac->mutex); + ac->negotiating = true; + if (name) + ac->name = name; + else + ac->name = CEPH_AUTH_NAME_DEFAULT; + ac->key = key; + ac->preferred_mode = con_modes[0]; + ac->fallback_mode = con_modes[1]; + + dout("%s name '%s' preferred_mode %d fallback_mode %d\n", __func__, + ac->name, ac->preferred_mode, ac->fallback_mode); + return ac; +} + +void ceph_auth_destroy(struct ceph_auth_client *ac) +{ + dout("auth_destroy %p\n", ac); + if (ac->ops) + ac->ops->destroy(ac); + kfree(ac); +} + +/* + * Reset occurs when reconnecting to the monitor. + */ +void ceph_auth_reset(struct ceph_auth_client *ac) +{ + mutex_lock(&ac->mutex); + dout("auth_reset %p\n", ac); + if (ac->ops && !ac->negotiating) + ac->ops->reset(ac); + ac->negotiating = true; + mutex_unlock(&ac->mutex); +} + +/* + * EntityName, not to be confused with entity_name_t + */ +int ceph_auth_entity_name_encode(const char *name, void **p, void *end) +{ + int len = strlen(name); + + if (*p + 2*sizeof(u32) + len > end) + return -ERANGE; + ceph_encode_32(p, CEPH_ENTITY_TYPE_CLIENT); + ceph_encode_32(p, len); + ceph_encode_copy(p, name, len); + return 0; +} + +/* + * Initiate protocol negotiation with monitor. Include entity name + * and list supported protocols. + */ +int ceph_auth_build_hello(struct ceph_auth_client *ac, void *buf, size_t len) +{ + struct ceph_mon_request_header *monhdr = buf; + void *p = monhdr + 1, *end = buf + len, *lenp; + int i, num; + int ret; + + mutex_lock(&ac->mutex); + dout("auth_build_hello\n"); + monhdr->have_version = 0; + monhdr->session_mon = cpu_to_le16(-1); + monhdr->session_mon_tid = 0; + + ceph_encode_32(&p, CEPH_AUTH_UNKNOWN); /* no protocol, yet */ + + lenp = p; + p += sizeof(u32); + + ceph_decode_need(&p, end, 1 + sizeof(u32), bad); + ceph_encode_8(&p, 1); + num = ARRAY_SIZE(supported_protocols); + ceph_encode_32(&p, num); + ceph_decode_need(&p, end, num * sizeof(u32), bad); + for (i = 0; i < num; i++) + ceph_encode_32(&p, supported_protocols[i]); + + ret = ceph_auth_entity_name_encode(ac->name, &p, end); + if (ret < 0) + goto out; + ceph_decode_need(&p, end, sizeof(u64), bad); + ceph_encode_64(&p, ac->global_id); + + ceph_encode_32(&lenp, p - lenp - sizeof(u32)); + ret = p - buf; +out: + mutex_unlock(&ac->mutex); + return ret; + +bad: + ret = -ERANGE; + goto out; +} + +static int build_request(struct ceph_auth_client *ac, bool add_header, + void *buf, int buf_len) +{ + void *end = buf + buf_len; + void *p; + int ret; + + p = buf; + if (add_header) { + /* struct ceph_mon_request_header + protocol */ + ceph_encode_64_safe(&p, end, 0, e_range); + ceph_encode_16_safe(&p, end, -1, e_range); + ceph_encode_64_safe(&p, end, 0, e_range); + ceph_encode_32_safe(&p, end, ac->protocol, e_range); + } + + ceph_encode_need(&p, end, sizeof(u32), e_range); + ret = ac->ops->build_request(ac, p + sizeof(u32), end); + if (ret < 0) { + pr_err("auth protocol '%s' building request failed: %d\n", + ceph_auth_proto_name(ac->protocol), ret); + return ret; + } + dout(" built request %d bytes\n", ret); + ceph_encode_32(&p, ret); + return p + ret - buf; + +e_range: + return -ERANGE; +} + +/* + * Handle auth message from monitor. 
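+ *
+ * The reply is decoded as: protocol (u32), result (s32), global_id
+ * (u64), a length-prefixed payload and a length-prefixed result
+ * message.  If the protocol handler returns -EAGAIN, a follow-up
+ * request is built into reply_buf for the caller to send.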
+ */ +int ceph_handle_auth_reply(struct ceph_auth_client *ac, + void *buf, size_t len, + void *reply_buf, size_t reply_len) +{ + void *p = buf; + void *end = buf + len; + int protocol; + s32 result; + u64 global_id; + void *payload, *payload_end; + int payload_len; + char *result_msg; + int result_msg_len; + int ret = -EINVAL; + + mutex_lock(&ac->mutex); + dout("handle_auth_reply %p %p\n", p, end); + ceph_decode_need(&p, end, sizeof(u32) * 3 + sizeof(u64), bad); + protocol = ceph_decode_32(&p); + result = ceph_decode_32(&p); + global_id = ceph_decode_64(&p); + payload_len = ceph_decode_32(&p); + payload = p; + p += payload_len; + ceph_decode_need(&p, end, sizeof(u32), bad); + result_msg_len = ceph_decode_32(&p); + result_msg = p; + p += result_msg_len; + if (p != end) + goto bad; + + dout(" result %d '%.*s' gid %llu len %d\n", result, result_msg_len, + result_msg, global_id, payload_len); + + payload_end = payload + payload_len; + + if (ac->negotiating) { + /* server does not support our protocols? */ + if (!protocol && result < 0) { + ret = result; + goto out; + } + /* set up (new) protocol handler? */ + if (ac->protocol && ac->protocol != protocol) { + ac->ops->destroy(ac); + ac->protocol = 0; + ac->ops = NULL; + } + if (ac->protocol != protocol) { + ret = init_protocol(ac, protocol); + if (ret) { + pr_err("auth protocol '%s' init failed: %d\n", + ceph_auth_proto_name(protocol), ret); + goto out; + } + } + + ac->negotiating = false; + } + + if (result) { + pr_err("auth protocol '%s' mauth authentication failed: %d\n", + ceph_auth_proto_name(ac->protocol), result); + ret = result; + goto out; + } + + ret = ac->ops->handle_reply(ac, global_id, payload, payload_end, + NULL, NULL, NULL, NULL); + if (ret == -EAGAIN) { + ret = build_request(ac, true, reply_buf, reply_len); + goto out; + } else if (ret) { + goto out; + } + +out: + mutex_unlock(&ac->mutex); + return ret; + +bad: + pr_err("failed to decode auth msg\n"); + ret = -EINVAL; + goto out; +} + +int ceph_build_auth(struct ceph_auth_client *ac, + void *msg_buf, size_t msg_len) +{ + int ret = 0; + + mutex_lock(&ac->mutex); + if (ac->ops->should_authenticate(ac)) + ret = build_request(ac, true, msg_buf, msg_len); + mutex_unlock(&ac->mutex); + return ret; +} + +int ceph_auth_is_authenticated(struct ceph_auth_client *ac) +{ + int ret = 0; + + mutex_lock(&ac->mutex); + if (ac->ops) + ret = ac->ops->is_authenticated(ac); + mutex_unlock(&ac->mutex); + return ret; +} +EXPORT_SYMBOL(ceph_auth_is_authenticated); + +int __ceph_auth_get_authorizer(struct ceph_auth_client *ac, + struct ceph_auth_handshake *auth, + int peer_type, bool force_new, + int *proto, int *pref_mode, int *fallb_mode) +{ + int ret; + + mutex_lock(&ac->mutex); + if (force_new && auth->authorizer) { + ceph_auth_destroy_authorizer(auth->authorizer); + auth->authorizer = NULL; + } + if (!auth->authorizer) + ret = ac->ops->create_authorizer(ac, peer_type, auth); + else if (ac->ops->update_authorizer) + ret = ac->ops->update_authorizer(ac, peer_type, auth); + else + ret = 0; + if (ret) + goto out; + + *proto = ac->protocol; + if (pref_mode && fallb_mode) { + *pref_mode = ac->preferred_mode; + *fallb_mode = ac->fallback_mode; + } + +out: + mutex_unlock(&ac->mutex); + return ret; +} +EXPORT_SYMBOL(__ceph_auth_get_authorizer); + +void ceph_auth_destroy_authorizer(struct ceph_authorizer *a) +{ + a->destroy(a); +} +EXPORT_SYMBOL(ceph_auth_destroy_authorizer); + +int ceph_auth_add_authorizer_challenge(struct ceph_auth_client *ac, + struct ceph_authorizer *a, + void *challenge_buf, + int 
challenge_buf_len) +{ + int ret = 0; + + mutex_lock(&ac->mutex); + if (ac->ops && ac->ops->add_authorizer_challenge) + ret = ac->ops->add_authorizer_challenge(ac, a, challenge_buf, + challenge_buf_len); + mutex_unlock(&ac->mutex); + return ret; +} +EXPORT_SYMBOL(ceph_auth_add_authorizer_challenge); + +int ceph_auth_verify_authorizer_reply(struct ceph_auth_client *ac, + struct ceph_authorizer *a, + void *reply, int reply_len, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + int ret = 0; + + mutex_lock(&ac->mutex); + if (ac->ops && ac->ops->verify_authorizer_reply) + ret = ac->ops->verify_authorizer_reply(ac, a, + reply, reply_len, session_key, session_key_len, + con_secret, con_secret_len); + mutex_unlock(&ac->mutex); + return ret; +} +EXPORT_SYMBOL(ceph_auth_verify_authorizer_reply); + +void ceph_auth_invalidate_authorizer(struct ceph_auth_client *ac, int peer_type) +{ + mutex_lock(&ac->mutex); + if (ac->ops && ac->ops->invalidate_authorizer) + ac->ops->invalidate_authorizer(ac, peer_type); + mutex_unlock(&ac->mutex); +} +EXPORT_SYMBOL(ceph_auth_invalidate_authorizer); + +/* + * msgr2 authentication + */ + +static bool contains(const int *arr, int cnt, int val) +{ + int i; + + for (i = 0; i < cnt; i++) { + if (arr[i] == val) + return true; + } + + return false; +} + +static int encode_con_modes(void **p, void *end, int pref_mode, int fallb_mode) +{ + WARN_ON(pref_mode == CEPH_CON_MODE_UNKNOWN); + if (fallb_mode != CEPH_CON_MODE_UNKNOWN) { + ceph_encode_32_safe(p, end, 2, e_range); + ceph_encode_32_safe(p, end, pref_mode, e_range); + ceph_encode_32_safe(p, end, fallb_mode, e_range); + } else { + ceph_encode_32_safe(p, end, 1, e_range); + ceph_encode_32_safe(p, end, pref_mode, e_range); + } + + return 0; + +e_range: + return -ERANGE; +} + +/* + * Similar to ceph_auth_build_hello(). + */ +int ceph_auth_get_request(struct ceph_auth_client *ac, void *buf, int buf_len) +{ + int proto = ac->key ? 
CEPH_AUTH_CEPHX : CEPH_AUTH_NONE; + void *end = buf + buf_len; + void *lenp; + void *p; + int ret; + + mutex_lock(&ac->mutex); + if (ac->protocol == CEPH_AUTH_UNKNOWN) { + ret = init_protocol(ac, proto); + if (ret) { + pr_err("auth protocol '%s' init failed: %d\n", + ceph_auth_proto_name(proto), ret); + goto out; + } + } else { + WARN_ON(ac->protocol != proto); + ac->ops->reset(ac); + } + + p = buf; + ceph_encode_32_safe(&p, end, ac->protocol, e_range); + ret = encode_con_modes(&p, end, ac->preferred_mode, ac->fallback_mode); + if (ret) + goto out; + + lenp = p; + p += 4; /* space for len */ + + ceph_encode_8_safe(&p, end, CEPH_AUTH_MODE_MON, e_range); + ret = ceph_auth_entity_name_encode(ac->name, &p, end); + if (ret) + goto out; + + ceph_encode_64_safe(&p, end, ac->global_id, e_range); + ceph_encode_32(&lenp, p - lenp - 4); + ret = p - buf; + +out: + mutex_unlock(&ac->mutex); + return ret; + +e_range: + ret = -ERANGE; + goto out; +} + +int ceph_auth_handle_reply_more(struct ceph_auth_client *ac, void *reply, + int reply_len, void *buf, int buf_len) +{ + int ret; + + mutex_lock(&ac->mutex); + ret = ac->ops->handle_reply(ac, 0, reply, reply + reply_len, + NULL, NULL, NULL, NULL); + if (ret == -EAGAIN) + ret = build_request(ac, false, buf, buf_len); + else + WARN_ON(ret >= 0); + mutex_unlock(&ac->mutex); + return ret; +} + +int ceph_auth_handle_reply_done(struct ceph_auth_client *ac, + u64 global_id, void *reply, int reply_len, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + int ret; + + mutex_lock(&ac->mutex); + ret = ac->ops->handle_reply(ac, global_id, reply, reply + reply_len, + session_key, session_key_len, + con_secret, con_secret_len); + WARN_ON(ret == -EAGAIN || ret > 0); + mutex_unlock(&ac->mutex); + return ret; +} + +bool ceph_auth_handle_bad_method(struct ceph_auth_client *ac, + int used_proto, int result, + const int *allowed_protos, int proto_cnt, + const int *allowed_modes, int mode_cnt) +{ + mutex_lock(&ac->mutex); + WARN_ON(used_proto != ac->protocol); + + if (result == -EOPNOTSUPP) { + if (!contains(allowed_protos, proto_cnt, ac->protocol)) { + pr_err("auth protocol '%s' not allowed\n", + ceph_auth_proto_name(ac->protocol)); + goto not_allowed; + } + if (!contains(allowed_modes, mode_cnt, ac->preferred_mode) && + (ac->fallback_mode == CEPH_CON_MODE_UNKNOWN || + !contains(allowed_modes, mode_cnt, ac->fallback_mode))) { + pr_err("preferred mode '%s' not allowed\n", + ceph_con_mode_name(ac->preferred_mode)); + if (ac->fallback_mode == CEPH_CON_MODE_UNKNOWN) + pr_err("no fallback mode\n"); + else + pr_err("fallback mode '%s' not allowed\n", + ceph_con_mode_name(ac->fallback_mode)); + goto not_allowed; + } + } + + WARN_ON(result == -EOPNOTSUPP || result >= 0); + pr_err("auth protocol '%s' msgr authentication failed: %d\n", + ceph_auth_proto_name(ac->protocol), result); + + mutex_unlock(&ac->mutex); + return true; + +not_allowed: + mutex_unlock(&ac->mutex); + return false; +} + +int ceph_auth_get_authorizer(struct ceph_auth_client *ac, + struct ceph_auth_handshake *auth, + int peer_type, void *buf, int *buf_len) +{ + void *end = buf + *buf_len; + int pref_mode, fallb_mode; + int proto; + void *p; + int ret; + + ret = __ceph_auth_get_authorizer(ac, auth, peer_type, true, &proto, + &pref_mode, &fallb_mode); + if (ret) + return ret; + + p = buf; + ceph_encode_32_safe(&p, end, proto, e_range); + ret = encode_con_modes(&p, end, pref_mode, fallb_mode); + if (ret) + return ret; + + ceph_encode_32_safe(&p, end, auth->authorizer_buf_len, e_range); + 
*buf_len = p - buf; + return 0; + +e_range: + return -ERANGE; +} +EXPORT_SYMBOL(ceph_auth_get_authorizer); + +int ceph_auth_handle_svc_reply_more(struct ceph_auth_client *ac, + struct ceph_auth_handshake *auth, + void *reply, int reply_len, + void *buf, int *buf_len) +{ + void *end = buf + *buf_len; + void *p; + int ret; + + ret = ceph_auth_add_authorizer_challenge(ac, auth->authorizer, + reply, reply_len); + if (ret) + return ret; + + p = buf; + ceph_encode_32_safe(&p, end, auth->authorizer_buf_len, e_range); + *buf_len = p - buf; + return 0; + +e_range: + return -ERANGE; +} +EXPORT_SYMBOL(ceph_auth_handle_svc_reply_more); + +int ceph_auth_handle_svc_reply_done(struct ceph_auth_client *ac, + struct ceph_auth_handshake *auth, + void *reply, int reply_len, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + return ceph_auth_verify_authorizer_reply(ac, auth->authorizer, + reply, reply_len, session_key, session_key_len, + con_secret, con_secret_len); +} +EXPORT_SYMBOL(ceph_auth_handle_svc_reply_done); + +bool ceph_auth_handle_bad_authorizer(struct ceph_auth_client *ac, + int peer_type, int used_proto, int result, + const int *allowed_protos, int proto_cnt, + const int *allowed_modes, int mode_cnt) +{ + mutex_lock(&ac->mutex); + WARN_ON(used_proto != ac->protocol); + + if (result == -EOPNOTSUPP) { + if (!contains(allowed_protos, proto_cnt, ac->protocol)) { + pr_err("auth protocol '%s' not allowed by %s\n", + ceph_auth_proto_name(ac->protocol), + ceph_entity_type_name(peer_type)); + goto not_allowed; + } + if (!contains(allowed_modes, mode_cnt, ac->preferred_mode) && + (ac->fallback_mode == CEPH_CON_MODE_UNKNOWN || + !contains(allowed_modes, mode_cnt, ac->fallback_mode))) { + pr_err("preferred mode '%s' not allowed by %s\n", + ceph_con_mode_name(ac->preferred_mode), + ceph_entity_type_name(peer_type)); + if (ac->fallback_mode == CEPH_CON_MODE_UNKNOWN) + pr_err("no fallback mode\n"); + else + pr_err("fallback mode '%s' not allowed by %s\n", + ceph_con_mode_name(ac->fallback_mode), + ceph_entity_type_name(peer_type)); + goto not_allowed; + } + } + + WARN_ON(result == -EOPNOTSUPP || result >= 0); + pr_err("auth protocol '%s' authorization to %s failed: %d\n", + ceph_auth_proto_name(ac->protocol), + ceph_entity_type_name(peer_type), result); + + if (ac->ops->invalidate_authorizer) + ac->ops->invalidate_authorizer(ac, peer_type); + + mutex_unlock(&ac->mutex); + return true; + +not_allowed: + mutex_unlock(&ac->mutex); + return false; +} +EXPORT_SYMBOL(ceph_auth_handle_bad_authorizer); diff --git a/net/ceph/auth_none.c b/net/ceph/auth_none.c new file mode 100644 index 000000000..77b5519bc --- /dev/null +++ b/net/ceph/auth_none.c @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include +#include +#include + +#include +#include + +#include "auth_none.h" + +static void reset(struct ceph_auth_client *ac) +{ + struct ceph_auth_none_info *xi = ac->private; + + xi->starting = true; +} + +static void destroy(struct ceph_auth_client *ac) +{ + kfree(ac->private); + ac->private = NULL; +} + +static int is_authenticated(struct ceph_auth_client *ac) +{ + struct ceph_auth_none_info *xi = ac->private; + + return !xi->starting; +} + +static int should_authenticate(struct ceph_auth_client *ac) +{ + struct ceph_auth_none_info *xi = ac->private; + + return xi->starting; +} + +static int ceph_auth_none_build_authorizer(struct ceph_auth_client *ac, + struct ceph_none_authorizer *au) +{ + void *p = au->buf; + void *const end = p + sizeof(au->buf); + int ret; 
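+ /*
+ * The "none" authorizer is just one version byte, the encoded
+ * EntityName and our global_id; no key material or encryption is
+ * involved.
+ */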
+ + ceph_encode_8_safe(&p, end, 1, e_range); + ret = ceph_auth_entity_name_encode(ac->name, &p, end); + if (ret < 0) + return ret; + + ceph_encode_64_safe(&p, end, ac->global_id, e_range); + au->buf_len = p - (void *)au->buf; + dout("%s built authorizer len %d\n", __func__, au->buf_len); + return 0; + +e_range: + return -ERANGE; +} + +static int build_request(struct ceph_auth_client *ac, void *buf, void *end) +{ + return 0; +} + +/* + * the generic auth code decode the global_id, and we carry no actual + * authenticate state, so nothing happens here. + */ +static int handle_reply(struct ceph_auth_client *ac, u64 global_id, + void *buf, void *end, u8 *session_key, + int *session_key_len, u8 *con_secret, + int *con_secret_len) +{ + struct ceph_auth_none_info *xi = ac->private; + + xi->starting = false; + ceph_auth_set_global_id(ac, global_id); + return 0; +} + +static void ceph_auth_none_destroy_authorizer(struct ceph_authorizer *a) +{ + kfree(a); +} + +/* + * build an 'authorizer' with our entity_name and global_id. it is + * identical for all services we connect to. + */ +static int ceph_auth_none_create_authorizer( + struct ceph_auth_client *ac, int peer_type, + struct ceph_auth_handshake *auth) +{ + struct ceph_none_authorizer *au; + int ret; + + au = kmalloc(sizeof(*au), GFP_NOFS); + if (!au) + return -ENOMEM; + + au->base.destroy = ceph_auth_none_destroy_authorizer; + + ret = ceph_auth_none_build_authorizer(ac, au); + if (ret) { + kfree(au); + return ret; + } + + auth->authorizer = (struct ceph_authorizer *) au; + auth->authorizer_buf = au->buf; + auth->authorizer_buf_len = au->buf_len; + auth->authorizer_reply_buf = NULL; + auth->authorizer_reply_buf_len = 0; + + return 0; +} + +static const struct ceph_auth_client_ops ceph_auth_none_ops = { + .reset = reset, + .destroy = destroy, + .is_authenticated = is_authenticated, + .should_authenticate = should_authenticate, + .build_request = build_request, + .handle_reply = handle_reply, + .create_authorizer = ceph_auth_none_create_authorizer, +}; + +int ceph_auth_none_init(struct ceph_auth_client *ac) +{ + struct ceph_auth_none_info *xi; + + dout("ceph_auth_none_init %p\n", ac); + xi = kzalloc(sizeof(*xi), GFP_NOFS); + if (!xi) + return -ENOMEM; + + xi->starting = true; + + ac->protocol = CEPH_AUTH_NONE; + ac->private = xi; + ac->ops = &ceph_auth_none_ops; + return 0; +} diff --git a/net/ceph/auth_none.h b/net/ceph/auth_none.h new file mode 100644 index 000000000..bb121539e --- /dev/null +++ b/net/ceph/auth_none.h @@ -0,0 +1,27 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _FS_CEPH_AUTH_NONE_H +#define _FS_CEPH_AUTH_NONE_H + +#include +#include + +/* + * null security mode. + * + * we use a single static authorizer that simply encodes our entity name + * and global id. 
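+ *
+ * the encoded form must fit in the fixed-size buf below;
+ * ceph_auth_none_build_authorizer() returns -ERANGE if it does not.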
+ */ + +struct ceph_none_authorizer { + struct ceph_authorizer base; + char buf[128]; + int buf_len; +}; + +struct ceph_auth_none_info { + bool starting; +}; + +int ceph_auth_none_init(struct ceph_auth_client *ac); + +#endif diff --git a/net/ceph/auth_x.c b/net/ceph/auth_x.c new file mode 100644 index 000000000..b71b16359 --- /dev/null +++ b/net/ceph/auth_x.c @@ -0,0 +1,1122 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "crypto.h" +#include "auth_x.h" +#include "auth_x_protocol.h" + +static void ceph_x_validate_tickets(struct ceph_auth_client *ac, int *pneed); + +static int ceph_x_is_authenticated(struct ceph_auth_client *ac) +{ + struct ceph_x_info *xi = ac->private; + int missing; + int need; /* missing + need renewal */ + + ceph_x_validate_tickets(ac, &need); + missing = ac->want_keys & ~xi->have_keys; + WARN_ON((need & missing) != missing); + dout("%s want 0x%x have 0x%x missing 0x%x -> %d\n", __func__, + ac->want_keys, xi->have_keys, missing, !missing); + return !missing; +} + +static int ceph_x_should_authenticate(struct ceph_auth_client *ac) +{ + struct ceph_x_info *xi = ac->private; + int need; + + ceph_x_validate_tickets(ac, &need); + dout("%s want 0x%x have 0x%x need 0x%x -> %d\n", __func__, + ac->want_keys, xi->have_keys, need, !!need); + return !!need; +} + +static int ceph_x_encrypt_offset(void) +{ + return sizeof(u32) + sizeof(struct ceph_x_encrypt_header); +} + +static int ceph_x_encrypt_buflen(int ilen) +{ + return ceph_x_encrypt_offset() + ilen + 16; +} + +static int ceph_x_encrypt(struct ceph_crypto_key *secret, void *buf, + int buf_len, int plaintext_len) +{ + struct ceph_x_encrypt_header *hdr = buf + sizeof(u32); + int ciphertext_len; + int ret; + + hdr->struct_v = 1; + hdr->magic = cpu_to_le64(CEPHX_ENC_MAGIC); + + ret = ceph_crypt(secret, true, buf + sizeof(u32), buf_len - sizeof(u32), + plaintext_len + sizeof(struct ceph_x_encrypt_header), + &ciphertext_len); + if (ret) + return ret; + + ceph_encode_32(&buf, ciphertext_len); + return sizeof(u32) + ciphertext_len; +} + +static int __ceph_x_decrypt(struct ceph_crypto_key *secret, void *p, + int ciphertext_len) +{ + struct ceph_x_encrypt_header *hdr = p; + int plaintext_len; + int ret; + + ret = ceph_crypt(secret, false, p, ciphertext_len, ciphertext_len, + &plaintext_len); + if (ret) + return ret; + + if (le64_to_cpu(hdr->magic) != CEPHX_ENC_MAGIC) { + pr_err("%s bad magic\n", __func__); + return -EINVAL; + } + + return plaintext_len - sizeof(*hdr); +} + +static int ceph_x_decrypt(struct ceph_crypto_key *secret, void **p, void *end) +{ + int ciphertext_len; + int ret; + + ceph_decode_32_safe(p, end, ciphertext_len, e_inval); + ceph_decode_need(p, end, ciphertext_len, e_inval); + + ret = __ceph_x_decrypt(secret, *p, ciphertext_len); + if (ret < 0) + return ret; + + *p += ciphertext_len; + return ret; + +e_inval: + return -EINVAL; +} + +/* + * get existing (or insert new) ticket handler + */ +static struct ceph_x_ticket_handler * +get_ticket_handler(struct ceph_auth_client *ac, int service) +{ + struct ceph_x_ticket_handler *th; + struct ceph_x_info *xi = ac->private; + struct rb_node *parent = NULL, **p = &xi->ticket_handlers.rb_node; + + while (*p) { + parent = *p; + th = rb_entry(parent, struct ceph_x_ticket_handler, node); + if (service < th->service) + p = &(*p)->rb_left; + else if (service > th->service) + p = &(*p)->rb_right; + else + return th; + } + + /* add it */ + th = kzalloc(sizeof(*th), GFP_NOFS); 
+ if (!th) + return ERR_PTR(-ENOMEM); + th->service = service; + rb_link_node(&th->node, parent, p); + rb_insert_color(&th->node, &xi->ticket_handlers); + return th; +} + +static void remove_ticket_handler(struct ceph_auth_client *ac, + struct ceph_x_ticket_handler *th) +{ + struct ceph_x_info *xi = ac->private; + + dout("remove_ticket_handler %p %d\n", th, th->service); + rb_erase(&th->node, &xi->ticket_handlers); + ceph_crypto_key_destroy(&th->session_key); + if (th->ticket_blob) + ceph_buffer_put(th->ticket_blob); + kfree(th); +} + +static int process_one_ticket(struct ceph_auth_client *ac, + struct ceph_crypto_key *secret, + void **p, void *end) +{ + struct ceph_x_info *xi = ac->private; + int type; + u8 tkt_struct_v, blob_struct_v; + struct ceph_x_ticket_handler *th; + void *dp, *dend; + int dlen; + char is_enc; + struct timespec64 validity; + void *tp, *tpend; + void **ptp; + struct ceph_crypto_key new_session_key = { 0 }; + struct ceph_buffer *new_ticket_blob; + time64_t new_expires, new_renew_after; + u64 new_secret_id; + int ret; + + ceph_decode_need(p, end, sizeof(u32) + 1, bad); + + type = ceph_decode_32(p); + dout(" ticket type %d %s\n", type, ceph_entity_type_name(type)); + + tkt_struct_v = ceph_decode_8(p); + if (tkt_struct_v != 1) + goto bad; + + th = get_ticket_handler(ac, type); + if (IS_ERR(th)) { + ret = PTR_ERR(th); + goto out; + } + + /* blob for me */ + dp = *p + ceph_x_encrypt_offset(); + ret = ceph_x_decrypt(secret, p, end); + if (ret < 0) + goto out; + dout(" decrypted %d bytes\n", ret); + dend = dp + ret; + + ceph_decode_8_safe(&dp, dend, tkt_struct_v, bad); + if (tkt_struct_v != 1) + goto bad; + + ret = ceph_crypto_key_decode(&new_session_key, &dp, dend); + if (ret) + goto out; + + ceph_decode_need(&dp, dend, sizeof(struct ceph_timespec), bad); + ceph_decode_timespec64(&validity, dp); + dp += sizeof(struct ceph_timespec); + new_expires = ktime_get_real_seconds() + validity.tv_sec; + new_renew_after = new_expires - (validity.tv_sec / 4); + dout(" expires=%llu renew_after=%llu\n", new_expires, + new_renew_after); + + /* ticket blob for service */ + ceph_decode_8_safe(p, end, is_enc, bad); + if (is_enc) { + /* encrypted */ + tp = *p + ceph_x_encrypt_offset(); + ret = ceph_x_decrypt(&th->session_key, p, end); + if (ret < 0) + goto out; + dout(" encrypted ticket, decrypted %d bytes\n", ret); + ptp = &tp; + tpend = tp + ret; + } else { + /* unencrypted */ + ptp = p; + tpend = end; + } + ceph_decode_32_safe(ptp, tpend, dlen, bad); + dout(" ticket blob is %d bytes\n", dlen); + ceph_decode_need(ptp, tpend, 1 + sizeof(u64), bad); + blob_struct_v = ceph_decode_8(ptp); + if (blob_struct_v != 1) + goto bad; + + new_secret_id = ceph_decode_64(ptp); + ret = ceph_decode_buffer(&new_ticket_blob, ptp, tpend); + if (ret) + goto out; + + /* all is well, update our ticket */ + ceph_crypto_key_destroy(&th->session_key); + if (th->ticket_blob) + ceph_buffer_put(th->ticket_blob); + th->session_key = new_session_key; + th->ticket_blob = new_ticket_blob; + th->secret_id = new_secret_id; + th->expires = new_expires; + th->renew_after = new_renew_after; + th->have_key = true; + dout(" got ticket service %d (%s) secret_id %lld len %d\n", + type, ceph_entity_type_name(type), th->secret_id, + (int)th->ticket_blob->vec.iov_len); + xi->have_keys |= th->service; + return 0; + +bad: + ret = -EINVAL; +out: + ceph_crypto_key_destroy(&new_session_key); + return ret; +} + +static int ceph_x_proc_ticket_reply(struct ceph_auth_client *ac, + struct ceph_crypto_key *secret, + void **p, void *end) +{ + u8 
reply_struct_v; + u32 num; + int ret; + + ceph_decode_8_safe(p, end, reply_struct_v, bad); + if (reply_struct_v != 1) + return -EINVAL; + + ceph_decode_32_safe(p, end, num, bad); + dout("%d tickets\n", num); + + while (num--) { + ret = process_one_ticket(ac, secret, p, end); + if (ret) + return ret; + } + + return 0; + +bad: + return -EINVAL; +} + +/* + * Encode and encrypt the second part (ceph_x_authorize_b) of the + * authorizer. The first part (ceph_x_authorize_a) should already be + * encoded. + */ +static int encrypt_authorizer(struct ceph_x_authorizer *au, + u64 *server_challenge) +{ + struct ceph_x_authorize_a *msg_a; + struct ceph_x_authorize_b *msg_b; + void *p, *end; + int ret; + + msg_a = au->buf->vec.iov_base; + WARN_ON(msg_a->ticket_blob.secret_id != cpu_to_le64(au->secret_id)); + p = (void *)(msg_a + 1) + le32_to_cpu(msg_a->ticket_blob.blob_len); + end = au->buf->vec.iov_base + au->buf->vec.iov_len; + + msg_b = p + ceph_x_encrypt_offset(); + msg_b->struct_v = 2; + msg_b->nonce = cpu_to_le64(au->nonce); + if (server_challenge) { + msg_b->have_challenge = 1; + msg_b->server_challenge_plus_one = + cpu_to_le64(*server_challenge + 1); + } else { + msg_b->have_challenge = 0; + msg_b->server_challenge_plus_one = 0; + } + + ret = ceph_x_encrypt(&au->session_key, p, end - p, sizeof(*msg_b)); + if (ret < 0) + return ret; + + p += ret; + if (server_challenge) { + WARN_ON(p != end); + } else { + WARN_ON(p > end); + au->buf->vec.iov_len = p - au->buf->vec.iov_base; + } + + return 0; +} + +static void ceph_x_authorizer_cleanup(struct ceph_x_authorizer *au) +{ + ceph_crypto_key_destroy(&au->session_key); + if (au->buf) { + ceph_buffer_put(au->buf); + au->buf = NULL; + } +} + +static int ceph_x_build_authorizer(struct ceph_auth_client *ac, + struct ceph_x_ticket_handler *th, + struct ceph_x_authorizer *au) +{ + int maxlen; + struct ceph_x_authorize_a *msg_a; + struct ceph_x_authorize_b *msg_b; + int ret; + int ticket_blob_len = + (th->ticket_blob ? 
th->ticket_blob->vec.iov_len : 0); + + dout("build_authorizer for %s %p\n", + ceph_entity_type_name(th->service), au); + + ceph_crypto_key_destroy(&au->session_key); + ret = ceph_crypto_key_clone(&au->session_key, &th->session_key); + if (ret) + goto out_au; + + maxlen = sizeof(*msg_a) + ticket_blob_len + + ceph_x_encrypt_buflen(sizeof(*msg_b)); + dout(" need len %d\n", maxlen); + if (au->buf && au->buf->alloc_len < maxlen) { + ceph_buffer_put(au->buf); + au->buf = NULL; + } + if (!au->buf) { + au->buf = ceph_buffer_new(maxlen, GFP_NOFS); + if (!au->buf) { + ret = -ENOMEM; + goto out_au; + } + } + au->service = th->service; + WARN_ON(!th->secret_id); + au->secret_id = th->secret_id; + + msg_a = au->buf->vec.iov_base; + msg_a->struct_v = 1; + msg_a->global_id = cpu_to_le64(ac->global_id); + msg_a->service_id = cpu_to_le32(th->service); + msg_a->ticket_blob.struct_v = 1; + msg_a->ticket_blob.secret_id = cpu_to_le64(th->secret_id); + msg_a->ticket_blob.blob_len = cpu_to_le32(ticket_blob_len); + if (ticket_blob_len) { + memcpy(msg_a->ticket_blob.blob, th->ticket_blob->vec.iov_base, + th->ticket_blob->vec.iov_len); + } + dout(" th %p secret_id %lld %lld\n", th, th->secret_id, + le64_to_cpu(msg_a->ticket_blob.secret_id)); + + get_random_bytes(&au->nonce, sizeof(au->nonce)); + ret = encrypt_authorizer(au, NULL); + if (ret) { + pr_err("failed to encrypt authorizer: %d", ret); + goto out_au; + } + + dout(" built authorizer nonce %llx len %d\n", au->nonce, + (int)au->buf->vec.iov_len); + return 0; + +out_au: + ceph_x_authorizer_cleanup(au); + return ret; +} + +static int ceph_x_encode_ticket(struct ceph_x_ticket_handler *th, + void **p, void *end) +{ + ceph_decode_need(p, end, 1 + sizeof(u64), bad); + ceph_encode_8(p, 1); + ceph_encode_64(p, th->secret_id); + if (th->ticket_blob) { + const char *buf = th->ticket_blob->vec.iov_base; + u32 len = th->ticket_blob->vec.iov_len; + + ceph_encode_32_safe(p, end, len, bad); + ceph_encode_copy_safe(p, end, buf, len, bad); + } else { + ceph_encode_32_safe(p, end, 0, bad); + } + + return 0; +bad: + return -ERANGE; +} + +static bool need_key(struct ceph_x_ticket_handler *th) +{ + if (!th->have_key) + return true; + + return ktime_get_real_seconds() >= th->renew_after; +} + +static bool have_key(struct ceph_x_ticket_handler *th) +{ + if (th->have_key && ktime_get_real_seconds() >= th->expires) { + dout("ticket %d (%s) secret_id %llu expired\n", th->service, + ceph_entity_type_name(th->service), th->secret_id); + th->have_key = false; + } + + return th->have_key; +} + +static void ceph_x_validate_tickets(struct ceph_auth_client *ac, int *pneed) +{ + int want = ac->want_keys; + struct ceph_x_info *xi = ac->private; + int service; + + *pneed = ac->want_keys & ~(xi->have_keys); + + for (service = 1; service <= want; service <<= 1) { + struct ceph_x_ticket_handler *th; + + if (!(ac->want_keys & service)) + continue; + + if (*pneed & service) + continue; + + th = get_ticket_handler(ac, service); + if (IS_ERR(th)) { + *pneed |= service; + continue; + } + + if (need_key(th)) + *pneed |= service; + if (!have_key(th)) + xi->have_keys &= ~service; + } +} + +static int ceph_x_build_request(struct ceph_auth_client *ac, + void *buf, void *end) +{ + struct ceph_x_info *xi = ac->private; + int need; + struct ceph_x_request_header *head = buf; + void *p; + int ret; + struct ceph_x_ticket_handler *th = + get_ticket_handler(ac, CEPH_ENTITY_TYPE_AUTH); + + if (IS_ERR(th)) + return PTR_ERR(th); + + ceph_x_validate_tickets(ac, &need); + dout("%s want 0x%x have 0x%x need 0x%x\n", 
__func__, ac->want_keys, + xi->have_keys, need); + + if (need & CEPH_ENTITY_TYPE_AUTH) { + struct ceph_x_authenticate *auth = (void *)(head + 1); + void *enc_buf = xi->auth_authorizer.enc_buf; + struct ceph_x_challenge_blob *blob = enc_buf + + ceph_x_encrypt_offset(); + u64 *u; + + p = auth + 1; + if (p > end) + return -ERANGE; + + dout(" get_auth_session_key\n"); + head->op = cpu_to_le16(CEPHX_GET_AUTH_SESSION_KEY); + + /* encrypt and hash */ + get_random_bytes(&auth->client_challenge, sizeof(u64)); + blob->client_challenge = auth->client_challenge; + blob->server_challenge = cpu_to_le64(xi->server_challenge); + ret = ceph_x_encrypt(&xi->secret, enc_buf, CEPHX_AU_ENC_BUF_LEN, + sizeof(*blob)); + if (ret < 0) + return ret; + + auth->struct_v = 3; /* nautilus+ */ + auth->key = 0; + for (u = (u64 *)enc_buf; u + 1 <= (u64 *)(enc_buf + ret); u++) + auth->key ^= *(__le64 *)u; + dout(" server_challenge %llx client_challenge %llx key %llx\n", + xi->server_challenge, le64_to_cpu(auth->client_challenge), + le64_to_cpu(auth->key)); + + /* now encode the old ticket if exists */ + ret = ceph_x_encode_ticket(th, &p, end); + if (ret < 0) + return ret; + + /* nautilus+: request service tickets at the same time */ + need = ac->want_keys & ~CEPH_ENTITY_TYPE_AUTH; + WARN_ON(!need); + ceph_encode_32_safe(&p, end, need, e_range); + return p - buf; + } + + if (need) { + dout(" get_principal_session_key\n"); + ret = ceph_x_build_authorizer(ac, th, &xi->auth_authorizer); + if (ret) + return ret; + + p = buf; + ceph_encode_16_safe(&p, end, CEPHX_GET_PRINCIPAL_SESSION_KEY, + e_range); + ceph_encode_copy_safe(&p, end, + xi->auth_authorizer.buf->vec.iov_base, + xi->auth_authorizer.buf->vec.iov_len, e_range); + ceph_encode_8_safe(&p, end, 1, e_range); + ceph_encode_32_safe(&p, end, need, e_range); + return p - buf; + } + + return 0; + +e_range: + return -ERANGE; +} + +static int decode_con_secret(void **p, void *end, u8 *con_secret, + int *con_secret_len) +{ + int len; + + ceph_decode_32_safe(p, end, len, bad); + ceph_decode_need(p, end, len, bad); + + dout("%s len %d\n", __func__, len); + if (con_secret) { + if (len > CEPH_MAX_CON_SECRET_LEN) { + pr_err("connection secret too big %d\n", len); + goto bad_memzero; + } + memcpy(con_secret, *p, len); + *con_secret_len = len; + } + memzero_explicit(*p, len); + *p += len; + return 0; + +bad_memzero: + memzero_explicit(*p, len); +bad: + pr_err("failed to decode connection secret\n"); + return -EINVAL; +} + +static int handle_auth_session_key(struct ceph_auth_client *ac, u64 global_id, + void **p, void *end, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + struct ceph_x_info *xi = ac->private; + struct ceph_x_ticket_handler *th; + void *dp, *dend; + int len; + int ret; + + /* AUTH ticket */ + ret = ceph_x_proc_ticket_reply(ac, &xi->secret, p, end); + if (ret) + return ret; + + ceph_auth_set_global_id(ac, global_id); + if (*p == end) { + /* pre-nautilus (or didn't request service tickets!) 
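+ * reply: nothing follows the AUTH ticket, so there is no session key
+ * or connection secret to copy out here.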
*/ + WARN_ON(session_key || con_secret); + return 0; + } + + th = get_ticket_handler(ac, CEPH_ENTITY_TYPE_AUTH); + if (IS_ERR(th)) + return PTR_ERR(th); + + if (session_key) { + memcpy(session_key, th->session_key.key, th->session_key.len); + *session_key_len = th->session_key.len; + } + + /* connection secret */ + ceph_decode_32_safe(p, end, len, e_inval); + dout("%s connection secret blob len %d\n", __func__, len); + if (len > 0) { + dp = *p + ceph_x_encrypt_offset(); + ret = ceph_x_decrypt(&th->session_key, p, *p + len); + if (ret < 0) + return ret; + + dout("%s decrypted %d bytes\n", __func__, ret); + dend = dp + ret; + + ret = decode_con_secret(&dp, dend, con_secret, con_secret_len); + if (ret) + return ret; + } + + /* service tickets */ + ceph_decode_32_safe(p, end, len, e_inval); + dout("%s service tickets blob len %d\n", __func__, len); + if (len > 0) { + ret = ceph_x_proc_ticket_reply(ac, &th->session_key, + p, *p + len); + if (ret) + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static int ceph_x_handle_reply(struct ceph_auth_client *ac, u64 global_id, + void *buf, void *end, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + struct ceph_x_info *xi = ac->private; + struct ceph_x_ticket_handler *th; + int len = end - buf; + int result; + void *p; + int op; + int ret; + + if (xi->starting) { + /* it's a hello */ + struct ceph_x_server_challenge *sc = buf; + + if (len != sizeof(*sc)) + return -EINVAL; + xi->server_challenge = le64_to_cpu(sc->server_challenge); + dout("handle_reply got server challenge %llx\n", + xi->server_challenge); + xi->starting = false; + xi->have_keys &= ~CEPH_ENTITY_TYPE_AUTH; + return -EAGAIN; + } + + p = buf; + ceph_decode_16_safe(&p, end, op, e_inval); + ceph_decode_32_safe(&p, end, result, e_inval); + dout("handle_reply op %d result %d\n", op, result); + switch (op) { + case CEPHX_GET_AUTH_SESSION_KEY: + /* AUTH ticket + [connection secret] + service tickets */ + ret = handle_auth_session_key(ac, global_id, &p, end, + session_key, session_key_len, + con_secret, con_secret_len); + break; + + case CEPHX_GET_PRINCIPAL_SESSION_KEY: + th = get_ticket_handler(ac, CEPH_ENTITY_TYPE_AUTH); + if (IS_ERR(th)) + return PTR_ERR(th); + + /* service tickets */ + ret = ceph_x_proc_ticket_reply(ac, &th->session_key, &p, end); + break; + + default: + return -EINVAL; + } + if (ret) + return ret; + if (ac->want_keys == xi->have_keys) + return 0; + return -EAGAIN; + +e_inval: + return -EINVAL; +} + +static void ceph_x_destroy_authorizer(struct ceph_authorizer *a) +{ + struct ceph_x_authorizer *au = (void *)a; + + ceph_x_authorizer_cleanup(au); + kfree(au); +} + +static int ceph_x_create_authorizer( + struct ceph_auth_client *ac, int peer_type, + struct ceph_auth_handshake *auth) +{ + struct ceph_x_authorizer *au; + struct ceph_x_ticket_handler *th; + int ret; + + th = get_ticket_handler(ac, peer_type); + if (IS_ERR(th)) + return PTR_ERR(th); + + au = kzalloc(sizeof(*au), GFP_NOFS); + if (!au) + return -ENOMEM; + + au->base.destroy = ceph_x_destroy_authorizer; + + ret = ceph_x_build_authorizer(ac, th, au); + if (ret) { + kfree(au); + return ret; + } + + auth->authorizer = (struct ceph_authorizer *) au; + auth->authorizer_buf = au->buf->vec.iov_base; + auth->authorizer_buf_len = au->buf->vec.iov_len; + auth->authorizer_reply_buf = au->enc_buf; + auth->authorizer_reply_buf_len = CEPHX_AU_ENC_BUF_LEN; + auth->sign_message = ac->ops->sign_message; + auth->check_message_signature = ac->ops->check_message_signature; + + return 0; 
+} + +static int ceph_x_update_authorizer( + struct ceph_auth_client *ac, int peer_type, + struct ceph_auth_handshake *auth) +{ + struct ceph_x_authorizer *au; + struct ceph_x_ticket_handler *th; + + th = get_ticket_handler(ac, peer_type); + if (IS_ERR(th)) + return PTR_ERR(th); + + au = (struct ceph_x_authorizer *)auth->authorizer; + if (au->secret_id < th->secret_id) { + dout("ceph_x_update_authorizer service %u secret %llu < %llu\n", + au->service, au->secret_id, th->secret_id); + return ceph_x_build_authorizer(ac, th, au); + } + return 0; +} + +/* + * CephXAuthorizeChallenge + */ +static int decrypt_authorizer_challenge(struct ceph_crypto_key *secret, + void *challenge, int challenge_len, + u64 *server_challenge) +{ + void *dp, *dend; + int ret; + + /* no leading len */ + ret = __ceph_x_decrypt(secret, challenge, challenge_len); + if (ret < 0) + return ret; + + dout("%s decrypted %d bytes\n", __func__, ret); + dp = challenge + sizeof(struct ceph_x_encrypt_header); + dend = dp + ret; + + ceph_decode_skip_8(&dp, dend, e_inval); /* struct_v */ + ceph_decode_64_safe(&dp, dend, *server_challenge, e_inval); + dout("%s server_challenge %llu\n", __func__, *server_challenge); + return 0; + +e_inval: + return -EINVAL; +} + +static int ceph_x_add_authorizer_challenge(struct ceph_auth_client *ac, + struct ceph_authorizer *a, + void *challenge, int challenge_len) +{ + struct ceph_x_authorizer *au = (void *)a; + u64 server_challenge; + int ret; + + ret = decrypt_authorizer_challenge(&au->session_key, challenge, + challenge_len, &server_challenge); + if (ret) { + pr_err("failed to decrypt authorize challenge: %d", ret); + return ret; + } + + ret = encrypt_authorizer(au, &server_challenge); + if (ret) { + pr_err("failed to encrypt authorizer w/ challenge: %d", ret); + return ret; + } + + return 0; +} + +/* + * CephXAuthorizeReply + */ +static int decrypt_authorizer_reply(struct ceph_crypto_key *secret, + void **p, void *end, u64 *nonce_plus_one, + u8 *con_secret, int *con_secret_len) +{ + void *dp, *dend; + u8 struct_v; + int ret; + + dp = *p + ceph_x_encrypt_offset(); + ret = ceph_x_decrypt(secret, p, end); + if (ret < 0) + return ret; + + dout("%s decrypted %d bytes\n", __func__, ret); + dend = dp + ret; + + ceph_decode_8_safe(&dp, dend, struct_v, e_inval); + ceph_decode_64_safe(&dp, dend, *nonce_plus_one, e_inval); + dout("%s nonce_plus_one %llu\n", __func__, *nonce_plus_one); + if (struct_v >= 2) { + ret = decode_con_secret(&dp, dend, con_secret, con_secret_len); + if (ret) + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static int ceph_x_verify_authorizer_reply(struct ceph_auth_client *ac, + struct ceph_authorizer *a, + void *reply, int reply_len, + u8 *session_key, int *session_key_len, + u8 *con_secret, int *con_secret_len) +{ + struct ceph_x_authorizer *au = (void *)a; + u64 nonce_plus_one; + int ret; + + if (session_key) { + memcpy(session_key, au->session_key.key, au->session_key.len); + *session_key_len = au->session_key.len; + } + + ret = decrypt_authorizer_reply(&au->session_key, &reply, + reply + reply_len, &nonce_plus_one, + con_secret, con_secret_len); + if (ret) + return ret; + + if (nonce_plus_one != au->nonce + 1) { + pr_err("failed to authenticate server\n"); + return -EPERM; + } + + return 0; +} + +static void ceph_x_reset(struct ceph_auth_client *ac) +{ + struct ceph_x_info *xi = ac->private; + + dout("reset\n"); + xi->starting = true; + xi->server_challenge = 0; +} + +static void ceph_x_destroy(struct ceph_auth_client *ac) +{ + struct ceph_x_info *xi = 
ac->private; + struct rb_node *p; + + dout("ceph_x_destroy %p\n", ac); + ceph_crypto_key_destroy(&xi->secret); + + while ((p = rb_first(&xi->ticket_handlers)) != NULL) { + struct ceph_x_ticket_handler *th = + rb_entry(p, struct ceph_x_ticket_handler, node); + remove_ticket_handler(ac, th); + } + + ceph_x_authorizer_cleanup(&xi->auth_authorizer); + + kfree(ac->private); + ac->private = NULL; +} + +static void invalidate_ticket(struct ceph_auth_client *ac, int peer_type) +{ + struct ceph_x_ticket_handler *th; + + th = get_ticket_handler(ac, peer_type); + if (IS_ERR(th)) + return; + + if (th->have_key) { + dout("ticket %d (%s) secret_id %llu invalidated\n", + th->service, ceph_entity_type_name(th->service), + th->secret_id); + th->have_key = false; + } +} + +static void ceph_x_invalidate_authorizer(struct ceph_auth_client *ac, + int peer_type) +{ + /* + * We are to invalidate a service ticket in the hopes of + * getting a new, hopefully more valid, one. But, we won't get + * it unless our AUTH ticket is good, so invalidate AUTH ticket + * as well, just in case. + */ + invalidate_ticket(ac, peer_type); + invalidate_ticket(ac, CEPH_ENTITY_TYPE_AUTH); +} + +static int calc_signature(struct ceph_x_authorizer *au, struct ceph_msg *msg, + __le64 *psig) +{ + void *enc_buf = au->enc_buf; + int ret; + + if (!CEPH_HAVE_FEATURE(msg->con->peer_features, CEPHX_V2)) { + struct { + __le32 len; + __le32 header_crc; + __le32 front_crc; + __le32 middle_crc; + __le32 data_crc; + } __packed *sigblock = enc_buf + ceph_x_encrypt_offset(); + + sigblock->len = cpu_to_le32(4*sizeof(u32)); + sigblock->header_crc = msg->hdr.crc; + sigblock->front_crc = msg->footer.front_crc; + sigblock->middle_crc = msg->footer.middle_crc; + sigblock->data_crc = msg->footer.data_crc; + + ret = ceph_x_encrypt(&au->session_key, enc_buf, + CEPHX_AU_ENC_BUF_LEN, sizeof(*sigblock)); + if (ret < 0) + return ret; + + *psig = *(__le64 *)(enc_buf + sizeof(u32)); + } else { + struct { + __le32 header_crc; + __le32 front_crc; + __le32 front_len; + __le32 middle_crc; + __le32 middle_len; + __le32 data_crc; + __le32 data_len; + __le32 seq_lower_word; + } __packed *sigblock = enc_buf; + struct { + __le64 a, b, c, d; + } __packed *penc = enc_buf; + int ciphertext_len; + + sigblock->header_crc = msg->hdr.crc; + sigblock->front_crc = msg->footer.front_crc; + sigblock->front_len = msg->hdr.front_len; + sigblock->middle_crc = msg->footer.middle_crc; + sigblock->middle_len = msg->hdr.middle_len; + sigblock->data_crc = msg->footer.data_crc; + sigblock->data_len = msg->hdr.data_len; + sigblock->seq_lower_word = *(__le32 *)&msg->hdr.seq; + + /* no leading len, no ceph_x_encrypt_header */ + ret = ceph_crypt(&au->session_key, true, enc_buf, + CEPHX_AU_ENC_BUF_LEN, sizeof(*sigblock), + &ciphertext_len); + if (ret) + return ret; + + *psig = penc->a ^ penc->b ^ penc->c ^ penc->d; + } + + return 0; +} + +static int ceph_x_sign_message(struct ceph_auth_handshake *auth, + struct ceph_msg *msg) +{ + __le64 sig; + int ret; + + if (ceph_test_opt(from_msgr(msg->con->msgr), NOMSGSIGN)) + return 0; + + ret = calc_signature((struct ceph_x_authorizer *)auth->authorizer, + msg, &sig); + if (ret) + return ret; + + msg->footer.sig = sig; + msg->footer.flags |= CEPH_MSG_FOOTER_SIGNED; + return 0; +} + +static int ceph_x_check_message_signature(struct ceph_auth_handshake *auth, + struct ceph_msg *msg) +{ + __le64 sig_check; + int ret; + + if (ceph_test_opt(from_msgr(msg->con->msgr), NOMSGSIGN)) + return 0; + + ret = calc_signature((struct ceph_x_authorizer *)auth->authorizer, + 
msg, &sig_check); + if (ret) + return ret; + if (sig_check == msg->footer.sig) + return 0; + if (msg->footer.flags & CEPH_MSG_FOOTER_SIGNED) + dout("ceph_x_check_message_signature %p has signature %llx " + "expect %llx\n", msg, msg->footer.sig, sig_check); + else + dout("ceph_x_check_message_signature %p sender did not set " + "CEPH_MSG_FOOTER_SIGNED\n", msg); + return -EBADMSG; +} + +static const struct ceph_auth_client_ops ceph_x_ops = { + .is_authenticated = ceph_x_is_authenticated, + .should_authenticate = ceph_x_should_authenticate, + .build_request = ceph_x_build_request, + .handle_reply = ceph_x_handle_reply, + .create_authorizer = ceph_x_create_authorizer, + .update_authorizer = ceph_x_update_authorizer, + .add_authorizer_challenge = ceph_x_add_authorizer_challenge, + .verify_authorizer_reply = ceph_x_verify_authorizer_reply, + .invalidate_authorizer = ceph_x_invalidate_authorizer, + .reset = ceph_x_reset, + .destroy = ceph_x_destroy, + .sign_message = ceph_x_sign_message, + .check_message_signature = ceph_x_check_message_signature, +}; + + +int ceph_x_init(struct ceph_auth_client *ac) +{ + struct ceph_x_info *xi; + int ret; + + dout("ceph_x_init %p\n", ac); + ret = -ENOMEM; + xi = kzalloc(sizeof(*xi), GFP_NOFS); + if (!xi) + goto out; + + ret = -EINVAL; + if (!ac->key) { + pr_err("no secret set (for auth_x protocol)\n"); + goto out_nomem; + } + + ret = ceph_crypto_key_clone(&xi->secret, ac->key); + if (ret < 0) { + pr_err("cannot clone key: %d\n", ret); + goto out_nomem; + } + + xi->starting = true; + xi->ticket_handlers = RB_ROOT; + + ac->protocol = CEPH_AUTH_CEPHX; + ac->private = xi; + ac->ops = &ceph_x_ops; + return 0; + +out_nomem: + kfree(xi); +out: + return ret; +} diff --git a/net/ceph/auth_x.h b/net/ceph/auth_x.h new file mode 100644 index 000000000..c03735f96 --- /dev/null +++ b/net/ceph/auth_x.h @@ -0,0 +1,54 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _FS_CEPH_AUTH_X_H +#define _FS_CEPH_AUTH_X_H + +#include + +#include + +#include "crypto.h" +#include "auth_x_protocol.h" + +/* + * Handle ticket for a single service. 
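+ *
+ * One handler exists per service we hold (or want) a ticket for; they
+ * are kept in the ticket_handlers rb tree of struct ceph_x_info,
+ * keyed by service id.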
+ */ +struct ceph_x_ticket_handler { + struct rb_node node; + unsigned int service; + + struct ceph_crypto_key session_key; + bool have_key; + + u64 secret_id; + struct ceph_buffer *ticket_blob; + + time64_t renew_after, expires; +}; + +#define CEPHX_AU_ENC_BUF_LEN 128 /* big enough for encrypted blob */ + +struct ceph_x_authorizer { + struct ceph_authorizer base; + struct ceph_crypto_key session_key; + struct ceph_buffer *buf; + unsigned int service; + u64 nonce; + u64 secret_id; + char enc_buf[CEPHX_AU_ENC_BUF_LEN] __aligned(8); +}; + +struct ceph_x_info { + struct ceph_crypto_key secret; + + bool starting; + u64 server_challenge; + + unsigned int have_keys; + struct rb_root ticket_handlers; + + struct ceph_x_authorizer auth_authorizer; +}; + +int ceph_x_init(struct ceph_auth_client *ac); + +#endif diff --git a/net/ceph/auth_x_protocol.h b/net/ceph/auth_x_protocol.h new file mode 100644 index 000000000..9c60feeb1 --- /dev/null +++ b/net/ceph/auth_x_protocol.h @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __FS_CEPH_AUTH_X_PROTOCOL +#define __FS_CEPH_AUTH_X_PROTOCOL + +#define CEPHX_GET_AUTH_SESSION_KEY 0x0100 +#define CEPHX_GET_PRINCIPAL_SESSION_KEY 0x0200 +#define CEPHX_GET_ROTATING_KEY 0x0400 + +/* common bits */ +struct ceph_x_ticket_blob { + __u8 struct_v; + __le64 secret_id; + __le32 blob_len; + char blob[]; +} __attribute__ ((packed)); + + +/* common request/reply headers */ +struct ceph_x_request_header { + __le16 op; +} __attribute__ ((packed)); + +struct ceph_x_reply_header { + __le16 op; + __le32 result; +} __attribute__ ((packed)); + + +/* authenticate handshake */ + +/* initial hello (no reply header) */ +struct ceph_x_server_challenge { + __u8 struct_v; + __le64 server_challenge; +} __attribute__ ((packed)); + +struct ceph_x_authenticate { + __u8 struct_v; + __le64 client_challenge; + __le64 key; + /* old_ticket blob */ + /* nautilus+: other_keys */ +} __attribute__ ((packed)); + +struct ceph_x_service_ticket_request { + __u8 struct_v; + __le32 keys; +} __attribute__ ((packed)); + +struct ceph_x_challenge_blob { + __le64 server_challenge; + __le64 client_challenge; +} __attribute__ ((packed)); + + + +/* authorize handshake */ + +/* + * The authorizer consists of two pieces: + * a - service id, ticket blob + * b - encrypted with session key + */ +struct ceph_x_authorize_a { + __u8 struct_v; + __le64 global_id; + __le32 service_id; + struct ceph_x_ticket_blob ticket_blob; +} __attribute__ ((packed)); + +struct ceph_x_authorize_b { + __u8 struct_v; + __le64 nonce; + __u8 have_challenge; + __le64 server_challenge_plus_one; +} __attribute__ ((packed)); + +struct ceph_x_authorize_challenge { + __u8 struct_v; + __le64 server_challenge; +} __attribute__ ((packed)); + +struct ceph_x_authorize_reply { + __u8 struct_v; + __le64 nonce_plus_one; +} __attribute__ ((packed)); + + +/* + * encryption bundle + */ +#define CEPHX_ENC_MAGIC 0xff009cad8826aa55ull + +struct ceph_x_encrypt_header { + __u8 struct_v; + __le64 magic; +} __attribute__ ((packed)); + +#endif diff --git a/net/ceph/buffer.c b/net/ceph/buffer.c new file mode 100644 index 000000000..7e51f1280 --- /dev/null +++ b/net/ceph/buffer.c @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include + +#include +#include +#include /* for kvmalloc */ + +struct ceph_buffer *ceph_buffer_new(size_t len, gfp_t gfp) +{ + struct ceph_buffer *b; + + b = kmalloc(sizeof(*b), gfp); + if (!b) + return NULL; + + b->vec.iov_base = kvmalloc(len, gfp); + if (!b->vec.iov_base) { + kfree(b); + return NULL; 
+ } + + kref_init(&b->kref); + b->alloc_len = len; + b->vec.iov_len = len; + dout("buffer_new %p\n", b); + return b; +} +EXPORT_SYMBOL(ceph_buffer_new); + +void ceph_buffer_release(struct kref *kref) +{ + struct ceph_buffer *b = container_of(kref, struct ceph_buffer, kref); + + dout("buffer_release %p\n", b); + kvfree(b->vec.iov_base); + kfree(b); +} +EXPORT_SYMBOL(ceph_buffer_release); + +int ceph_decode_buffer(struct ceph_buffer **b, void **p, void *end) +{ + size_t len; + + ceph_decode_need(p, end, sizeof(u32), bad); + len = ceph_decode_32(p); + dout("decode_buffer len %d\n", (int)len); + ceph_decode_need(p, end, len, bad); + *b = ceph_buffer_new(len, GFP_NOFS); + if (!*b) + return -ENOMEM; + ceph_decode_copy(p, (*b)->vec.iov_base, len); + return 0; +bad: + return -EINVAL; +} diff --git a/net/ceph/ceph_common.c b/net/ceph/ceph_common.c new file mode 100644 index 000000000..4c6441536 --- /dev/null +++ b/net/ceph/ceph_common.c @@ -0,0 +1,916 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#include +#include +#include +#include +#include +#include +#include "crypto.h" + + +/* + * Module compatibility interface. For now it doesn't do anything, + * but its existence signals a certain level of functionality. + * + * The data buffer is used to pass information both to and from + * libceph. The return value indicates whether libceph determines + * it is compatible with the caller (from another kernel module), + * given the provided data. + * + * The data pointer can be null. + */ +bool libceph_compatible(void *data) +{ + return true; +} +EXPORT_SYMBOL(libceph_compatible); + +static int param_get_supported_features(char *buffer, + const struct kernel_param *kp) +{ + return sprintf(buffer, "0x%llx", CEPH_FEATURES_SUPPORTED_DEFAULT); +} +static const struct kernel_param_ops param_ops_supported_features = { + .get = param_get_supported_features, +}; +module_param_cb(supported_features, ¶m_ops_supported_features, NULL, + 0444); + +const char *ceph_msg_type_name(int type) +{ + switch (type) { + case CEPH_MSG_SHUTDOWN: return "shutdown"; + case CEPH_MSG_PING: return "ping"; + case CEPH_MSG_AUTH: return "auth"; + case CEPH_MSG_AUTH_REPLY: return "auth_reply"; + case CEPH_MSG_MON_MAP: return "mon_map"; + case CEPH_MSG_MON_GET_MAP: return "mon_get_map"; + case CEPH_MSG_MON_SUBSCRIBE: return "mon_subscribe"; + case CEPH_MSG_MON_SUBSCRIBE_ACK: return "mon_subscribe_ack"; + case CEPH_MSG_STATFS: return "statfs"; + case CEPH_MSG_STATFS_REPLY: return "statfs_reply"; + case CEPH_MSG_MON_GET_VERSION: return "mon_get_version"; + case CEPH_MSG_MON_GET_VERSION_REPLY: return "mon_get_version_reply"; + case CEPH_MSG_MDS_MAP: return "mds_map"; + case CEPH_MSG_FS_MAP_USER: return "fs_map_user"; + case CEPH_MSG_CLIENT_SESSION: return "client_session"; + case CEPH_MSG_CLIENT_RECONNECT: return "client_reconnect"; + case CEPH_MSG_CLIENT_REQUEST: return "client_request"; + case CEPH_MSG_CLIENT_REQUEST_FORWARD: return "client_request_forward"; + case CEPH_MSG_CLIENT_REPLY: return "client_reply"; + case CEPH_MSG_CLIENT_CAPS: return "client_caps"; + case CEPH_MSG_CLIENT_CAPRELEASE: return "client_cap_release"; + case CEPH_MSG_CLIENT_QUOTA: return "client_quota"; + case CEPH_MSG_CLIENT_SNAP: return "client_snap"; + case CEPH_MSG_CLIENT_LEASE: return "client_lease"; + case CEPH_MSG_POOLOP_REPLY: return "poolop_reply"; + case 
CEPH_MSG_POOLOP: return "poolop"; + case CEPH_MSG_MON_COMMAND: return "mon_command"; + case CEPH_MSG_MON_COMMAND_ACK: return "mon_command_ack"; + case CEPH_MSG_OSD_MAP: return "osd_map"; + case CEPH_MSG_OSD_OP: return "osd_op"; + case CEPH_MSG_OSD_OPREPLY: return "osd_opreply"; + case CEPH_MSG_WATCH_NOTIFY: return "watch_notify"; + case CEPH_MSG_OSD_BACKOFF: return "osd_backoff"; + default: return "unknown"; + } +} +EXPORT_SYMBOL(ceph_msg_type_name); + +/* + * Initially learn our fsid, or verify an fsid matches. + */ +int ceph_check_fsid(struct ceph_client *client, struct ceph_fsid *fsid) +{ + if (client->have_fsid) { + if (ceph_fsid_compare(&client->fsid, fsid)) { + pr_err("bad fsid, had %pU got %pU", + &client->fsid, fsid); + return -1; + } + } else { + memcpy(&client->fsid, fsid, sizeof(*fsid)); + } + return 0; +} +EXPORT_SYMBOL(ceph_check_fsid); + +static int strcmp_null(const char *s1, const char *s2) +{ + if (!s1 && !s2) + return 0; + if (s1 && !s2) + return -1; + if (!s1 && s2) + return 1; + return strcmp(s1, s2); +} + +int ceph_compare_options(struct ceph_options *new_opt, + struct ceph_client *client) +{ + struct ceph_options *opt1 = new_opt; + struct ceph_options *opt2 = client->options; + int ofs = offsetof(struct ceph_options, mon_addr); + int i; + int ret; + + /* + * Don't bother comparing options if network namespaces don't + * match. + */ + if (!net_eq(current->nsproxy->net_ns, read_pnet(&client->msgr.net))) + return -1; + + ret = memcmp(opt1, opt2, ofs); + if (ret) + return ret; + + ret = strcmp_null(opt1->name, opt2->name); + if (ret) + return ret; + + if (opt1->key && !opt2->key) + return -1; + if (!opt1->key && opt2->key) + return 1; + if (opt1->key && opt2->key) { + if (opt1->key->type != opt2->key->type) + return -1; + if (opt1->key->created.tv_sec != opt2->key->created.tv_sec) + return -1; + if (opt1->key->created.tv_nsec != opt2->key->created.tv_nsec) + return -1; + if (opt1->key->len != opt2->key->len) + return -1; + if (opt1->key->key && !opt2->key->key) + return -1; + if (!opt1->key->key && opt2->key->key) + return 1; + if (opt1->key->key && opt2->key->key) { + ret = memcmp(opt1->key->key, opt2->key->key, opt1->key->len); + if (ret) + return ret; + } + } + + ret = ceph_compare_crush_locs(&opt1->crush_locs, &opt2->crush_locs); + if (ret) + return ret; + + /* any matching mon ip implies a match */ + for (i = 0; i < opt1->num_mon; i++) { + if (ceph_monmap_contains(client->monc.monmap, + &opt1->mon_addr[i])) + return 0; + } + return -1; +} +EXPORT_SYMBOL(ceph_compare_options); + +int ceph_parse_fsid(const char *str, struct ceph_fsid *fsid) +{ + int i = 0; + char tmp[3]; + int err = -EINVAL; + int d; + + dout("%s '%s'\n", __func__, str); + tmp[2] = 0; + while (*str && i < 16) { + if (ispunct(*str)) { + str++; + continue; + } + if (!isxdigit(str[0]) || !isxdigit(str[1])) + break; + tmp[0] = str[0]; + tmp[1] = str[1]; + if (sscanf(tmp, "%x", &d) < 1) + break; + fsid->fsid[i] = d & 0xff; + i++; + str += 2; + } + + if (i == 16) + err = 0; + dout("%s ret %d got fsid %pU\n", __func__, err, fsid); + return err; +} +EXPORT_SYMBOL(ceph_parse_fsid); + +/* + * ceph options + */ +enum { + Opt_osdkeepalivetimeout, + Opt_mount_timeout, + Opt_osd_idle_ttl, + Opt_osd_request_timeout, + /* int args above */ + Opt_fsid, + Opt_name, + Opt_secret, + Opt_key, + Opt_ip, + Opt_crush_location, + Opt_read_from_replica, + Opt_ms_mode, + /* string args above */ + Opt_share, + Opt_crc, + Opt_cephx_require_signatures, + Opt_cephx_sign_messages, + Opt_tcp_nodelay, + Opt_abort_on_full, + 
Opt_rxbounce, +}; + +enum { + Opt_read_from_replica_no, + Opt_read_from_replica_balance, + Opt_read_from_replica_localize, +}; + +static const struct constant_table ceph_param_read_from_replica[] = { + {"no", Opt_read_from_replica_no}, + {"balance", Opt_read_from_replica_balance}, + {"localize", Opt_read_from_replica_localize}, + {} +}; + +enum ceph_ms_mode { + Opt_ms_mode_legacy, + Opt_ms_mode_crc, + Opt_ms_mode_secure, + Opt_ms_mode_prefer_crc, + Opt_ms_mode_prefer_secure +}; + +static const struct constant_table ceph_param_ms_mode[] = { + {"legacy", Opt_ms_mode_legacy}, + {"crc", Opt_ms_mode_crc}, + {"secure", Opt_ms_mode_secure}, + {"prefer-crc", Opt_ms_mode_prefer_crc}, + {"prefer-secure", Opt_ms_mode_prefer_secure}, + {} +}; + +static const struct fs_parameter_spec ceph_parameters[] = { + fsparam_flag ("abort_on_full", Opt_abort_on_full), + __fsparam (NULL, "cephx_require_signatures", Opt_cephx_require_signatures, + fs_param_neg_with_no|fs_param_deprecated, NULL), + fsparam_flag_no ("cephx_sign_messages", Opt_cephx_sign_messages), + fsparam_flag_no ("crc", Opt_crc), + fsparam_string ("crush_location", Opt_crush_location), + fsparam_string ("fsid", Opt_fsid), + fsparam_string ("ip", Opt_ip), + fsparam_string ("key", Opt_key), + fsparam_u32 ("mount_timeout", Opt_mount_timeout), + fsparam_string ("name", Opt_name), + fsparam_u32 ("osd_idle_ttl", Opt_osd_idle_ttl), + fsparam_u32 ("osd_request_timeout", Opt_osd_request_timeout), + fsparam_u32 ("osdkeepalive", Opt_osdkeepalivetimeout), + fsparam_enum ("read_from_replica", Opt_read_from_replica, + ceph_param_read_from_replica), + fsparam_flag ("rxbounce", Opt_rxbounce), + fsparam_enum ("ms_mode", Opt_ms_mode, + ceph_param_ms_mode), + fsparam_string ("secret", Opt_secret), + fsparam_flag_no ("share", Opt_share), + fsparam_flag_no ("tcp_nodelay", Opt_tcp_nodelay), + {} +}; + +struct ceph_options *ceph_alloc_options(void) +{ + struct ceph_options *opt; + + opt = kzalloc(sizeof(*opt), GFP_KERNEL); + if (!opt) + return NULL; + + opt->crush_locs = RB_ROOT; + opt->mon_addr = kcalloc(CEPH_MAX_MON, sizeof(*opt->mon_addr), + GFP_KERNEL); + if (!opt->mon_addr) { + kfree(opt); + return NULL; + } + + opt->flags = CEPH_OPT_DEFAULT; + opt->osd_keepalive_timeout = CEPH_OSD_KEEPALIVE_DEFAULT; + opt->mount_timeout = CEPH_MOUNT_TIMEOUT_DEFAULT; + opt->osd_idle_ttl = CEPH_OSD_IDLE_TTL_DEFAULT; + opt->osd_request_timeout = CEPH_OSD_REQUEST_TIMEOUT_DEFAULT; + opt->read_from_replica = CEPH_READ_FROM_REPLICA_DEFAULT; + opt->con_modes[0] = CEPH_CON_MODE_UNKNOWN; + opt->con_modes[1] = CEPH_CON_MODE_UNKNOWN; + return opt; +} +EXPORT_SYMBOL(ceph_alloc_options); + +void ceph_destroy_options(struct ceph_options *opt) +{ + dout("destroy_options %p\n", opt); + if (!opt) + return; + + ceph_clear_crush_locs(&opt->crush_locs); + kfree(opt->name); + if (opt->key) { + ceph_crypto_key_destroy(opt->key); + kfree(opt->key); + } + kfree(opt->mon_addr); + kfree(opt); +} +EXPORT_SYMBOL(ceph_destroy_options); + +/* get secret from key store */ +static int get_secret(struct ceph_crypto_key *dst, const char *name, + struct p_log *log) +{ + struct key *ukey; + int key_err; + int err = 0; + struct ceph_crypto_key *ckey; + + ukey = request_key(&key_type_ceph, name, NULL); + if (IS_ERR(ukey)) { + /* request_key errors don't map nicely to mount(2) + errors; don't even try, but still printk */ + key_err = PTR_ERR(ukey); + switch (key_err) { + case -ENOKEY: + error_plog(log, "Failed due to key not found: %s", + name); + break; + case -EKEYEXPIRED: + error_plog(log, "Failed due to expired key: 
%s", + name); + break; + case -EKEYREVOKED: + error_plog(log, "Failed due to revoked key: %s", + name); + break; + default: + error_plog(log, "Failed due to key error %d: %s", + key_err, name); + } + err = -EPERM; + goto out; + } + + ckey = ukey->payload.data[0]; + err = ceph_crypto_key_clone(dst, ckey); + if (err) + goto out_key; + /* pass through, err is 0 */ + +out_key: + key_put(ukey); +out: + return err; +} + +int ceph_parse_mon_ips(const char *buf, size_t len, struct ceph_options *opt, + struct fc_log *l, char delim) +{ + struct p_log log = {.prefix = "libceph", .log = l}; + int ret; + + /* ip1[:port1][ip2[:port2]...] */ + ret = ceph_parse_ips(buf, buf + len, opt->mon_addr, CEPH_MAX_MON, + &opt->num_mon, delim); + if (ret) { + error_plog(&log, "Failed to parse monitor IPs: %d", ret); + return ret; + } + + return 0; +} +EXPORT_SYMBOL(ceph_parse_mon_ips); + +int ceph_parse_param(struct fs_parameter *param, struct ceph_options *opt, + struct fc_log *l) +{ + struct fs_parse_result result; + int token, err; + struct p_log log = {.prefix = "libceph", .log = l}; + + token = __fs_parse(&log, ceph_parameters, param, &result); + dout("%s fs_parse '%s' token %d\n", __func__, param->key, token); + if (token < 0) + return token; + + switch (token) { + case Opt_ip: + err = ceph_parse_ips(param->string, + param->string + param->size, + &opt->my_addr, 1, NULL, ','); + if (err) { + error_plog(&log, "Failed to parse ip: %d", err); + return err; + } + opt->flags |= CEPH_OPT_MYIP; + break; + + case Opt_fsid: + err = ceph_parse_fsid(param->string, &opt->fsid); + if (err) { + error_plog(&log, "Failed to parse fsid: %d", err); + return err; + } + opt->flags |= CEPH_OPT_FSID; + break; + case Opt_name: + kfree(opt->name); + opt->name = param->string; + param->string = NULL; + break; + case Opt_secret: + ceph_crypto_key_destroy(opt->key); + kfree(opt->key); + + opt->key = kzalloc(sizeof(*opt->key), GFP_KERNEL); + if (!opt->key) + return -ENOMEM; + err = ceph_crypto_key_unarmor(opt->key, param->string); + if (err) { + error_plog(&log, "Failed to parse secret: %d", err); + return err; + } + break; + case Opt_key: + ceph_crypto_key_destroy(opt->key); + kfree(opt->key); + + opt->key = kzalloc(sizeof(*opt->key), GFP_KERNEL); + if (!opt->key) + return -ENOMEM; + return get_secret(opt->key, param->string, &log); + case Opt_crush_location: + ceph_clear_crush_locs(&opt->crush_locs); + err = ceph_parse_crush_location(param->string, + &opt->crush_locs); + if (err) { + error_plog(&log, "Failed to parse CRUSH location: %d", + err); + return err; + } + break; + case Opt_read_from_replica: + switch (result.uint_32) { + case Opt_read_from_replica_no: + opt->read_from_replica = 0; + break; + case Opt_read_from_replica_balance: + opt->read_from_replica = CEPH_OSD_FLAG_BALANCE_READS; + break; + case Opt_read_from_replica_localize: + opt->read_from_replica = CEPH_OSD_FLAG_LOCALIZE_READS; + break; + default: + BUG(); + } + break; + case Opt_ms_mode: + switch (result.uint_32) { + case Opt_ms_mode_legacy: + opt->con_modes[0] = CEPH_CON_MODE_UNKNOWN; + opt->con_modes[1] = CEPH_CON_MODE_UNKNOWN; + break; + case Opt_ms_mode_crc: + opt->con_modes[0] = CEPH_CON_MODE_CRC; + opt->con_modes[1] = CEPH_CON_MODE_UNKNOWN; + break; + case Opt_ms_mode_secure: + opt->con_modes[0] = CEPH_CON_MODE_SECURE; + opt->con_modes[1] = CEPH_CON_MODE_UNKNOWN; + break; + case Opt_ms_mode_prefer_crc: + opt->con_modes[0] = CEPH_CON_MODE_CRC; + opt->con_modes[1] = CEPH_CON_MODE_SECURE; + break; + case Opt_ms_mode_prefer_secure: + opt->con_modes[0] = 
CEPH_CON_MODE_SECURE; + opt->con_modes[1] = CEPH_CON_MODE_CRC; + break; + default: + BUG(); + } + break; + + case Opt_osdkeepalivetimeout: + /* 0 isn't well defined right now, reject it */ + if (result.uint_32 < 1 || result.uint_32 > INT_MAX / 1000) + goto out_of_range; + opt->osd_keepalive_timeout = + msecs_to_jiffies(result.uint_32 * 1000); + break; + case Opt_osd_idle_ttl: + /* 0 isn't well defined right now, reject it */ + if (result.uint_32 < 1 || result.uint_32 > INT_MAX / 1000) + goto out_of_range; + opt->osd_idle_ttl = msecs_to_jiffies(result.uint_32 * 1000); + break; + case Opt_mount_timeout: + /* 0 is "wait forever" (i.e. infinite timeout) */ + if (result.uint_32 > INT_MAX / 1000) + goto out_of_range; + opt->mount_timeout = msecs_to_jiffies(result.uint_32 * 1000); + break; + case Opt_osd_request_timeout: + /* 0 is "wait forever" (i.e. infinite timeout) */ + if (result.uint_32 > INT_MAX / 1000) + goto out_of_range; + opt->osd_request_timeout = + msecs_to_jiffies(result.uint_32 * 1000); + break; + + case Opt_share: + if (!result.negated) + opt->flags &= ~CEPH_OPT_NOSHARE; + else + opt->flags |= CEPH_OPT_NOSHARE; + break; + case Opt_crc: + if (!result.negated) + opt->flags &= ~CEPH_OPT_NOCRC; + else + opt->flags |= CEPH_OPT_NOCRC; + break; + case Opt_cephx_require_signatures: + if (!result.negated) + warn_plog(&log, "Ignoring cephx_require_signatures"); + else + warn_plog(&log, "Ignoring nocephx_require_signatures, use nocephx_sign_messages"); + break; + case Opt_cephx_sign_messages: + if (!result.negated) + opt->flags &= ~CEPH_OPT_NOMSGSIGN; + else + opt->flags |= CEPH_OPT_NOMSGSIGN; + break; + case Opt_tcp_nodelay: + if (!result.negated) + opt->flags |= CEPH_OPT_TCP_NODELAY; + else + opt->flags &= ~CEPH_OPT_TCP_NODELAY; + break; + + case Opt_abort_on_full: + opt->flags |= CEPH_OPT_ABORT_ON_FULL; + break; + case Opt_rxbounce: + opt->flags |= CEPH_OPT_RXBOUNCE; + break; + + default: + BUG(); + } + + return 0; + +out_of_range: + return inval_plog(&log, "%s out of range", param->key); +} +EXPORT_SYMBOL(ceph_parse_param); + +int ceph_print_client_options(struct seq_file *m, struct ceph_client *client, + bool show_all) +{ + struct ceph_options *opt = client->options; + size_t pos = m->count; + struct rb_node *n; + + if (opt->name) { + seq_puts(m, "name="); + seq_escape(m, opt->name, ", \t\n\\"); + seq_putc(m, ','); + } + if (opt->key) + seq_puts(m, "secret=,"); + + if (!RB_EMPTY_ROOT(&opt->crush_locs)) { + seq_puts(m, "crush_location="); + for (n = rb_first(&opt->crush_locs); ; ) { + struct crush_loc_node *loc = + rb_entry(n, struct crush_loc_node, cl_node); + + seq_printf(m, "%s:%s", loc->cl_loc.cl_type_name, + loc->cl_loc.cl_name); + n = rb_next(n); + if (!n) + break; + + seq_putc(m, '|'); + } + seq_putc(m, ','); + } + if (opt->read_from_replica == CEPH_OSD_FLAG_BALANCE_READS) { + seq_puts(m, "read_from_replica=balance,"); + } else if (opt->read_from_replica == CEPH_OSD_FLAG_LOCALIZE_READS) { + seq_puts(m, "read_from_replica=localize,"); + } + if (opt->con_modes[0] != CEPH_CON_MODE_UNKNOWN) { + if (opt->con_modes[0] == CEPH_CON_MODE_CRC && + opt->con_modes[1] == CEPH_CON_MODE_UNKNOWN) { + seq_puts(m, "ms_mode=crc,"); + } else if (opt->con_modes[0] == CEPH_CON_MODE_SECURE && + opt->con_modes[1] == CEPH_CON_MODE_UNKNOWN) { + seq_puts(m, "ms_mode=secure,"); + } else if (opt->con_modes[0] == CEPH_CON_MODE_CRC && + opt->con_modes[1] == CEPH_CON_MODE_SECURE) { + seq_puts(m, "ms_mode=prefer-crc,"); + } else if (opt->con_modes[0] == CEPH_CON_MODE_SECURE && + opt->con_modes[1] == 
CEPH_CON_MODE_CRC) { + seq_puts(m, "ms_mode=prefer-secure,"); + } + } + + if (opt->flags & CEPH_OPT_FSID) + seq_printf(m, "fsid=%pU,", &opt->fsid); + if (opt->flags & CEPH_OPT_NOSHARE) + seq_puts(m, "noshare,"); + if (opt->flags & CEPH_OPT_NOCRC) + seq_puts(m, "nocrc,"); + if (opt->flags & CEPH_OPT_NOMSGSIGN) + seq_puts(m, "nocephx_sign_messages,"); + if ((opt->flags & CEPH_OPT_TCP_NODELAY) == 0) + seq_puts(m, "notcp_nodelay,"); + if (show_all && (opt->flags & CEPH_OPT_ABORT_ON_FULL)) + seq_puts(m, "abort_on_full,"); + if (opt->flags & CEPH_OPT_RXBOUNCE) + seq_puts(m, "rxbounce,"); + + if (opt->mount_timeout != CEPH_MOUNT_TIMEOUT_DEFAULT) + seq_printf(m, "mount_timeout=%d,", + jiffies_to_msecs(opt->mount_timeout) / 1000); + if (opt->osd_idle_ttl != CEPH_OSD_IDLE_TTL_DEFAULT) + seq_printf(m, "osd_idle_ttl=%d,", + jiffies_to_msecs(opt->osd_idle_ttl) / 1000); + if (opt->osd_keepalive_timeout != CEPH_OSD_KEEPALIVE_DEFAULT) + seq_printf(m, "osdkeepalivetimeout=%d,", + jiffies_to_msecs(opt->osd_keepalive_timeout) / 1000); + if (opt->osd_request_timeout != CEPH_OSD_REQUEST_TIMEOUT_DEFAULT) + seq_printf(m, "osd_request_timeout=%d,", + jiffies_to_msecs(opt->osd_request_timeout) / 1000); + + /* drop redundant comma */ + if (m->count != pos) + m->count--; + + return 0; +} +EXPORT_SYMBOL(ceph_print_client_options); + +struct ceph_entity_addr *ceph_client_addr(struct ceph_client *client) +{ + return &client->msgr.inst.addr; +} +EXPORT_SYMBOL(ceph_client_addr); + +u64 ceph_client_gid(struct ceph_client *client) +{ + return client->monc.auth->global_id; +} +EXPORT_SYMBOL(ceph_client_gid); + +/* + * create a fresh client instance + */ +struct ceph_client *ceph_create_client(struct ceph_options *opt, void *private) +{ + struct ceph_client *client; + struct ceph_entity_addr *myaddr = NULL; + int err; + + err = wait_for_random_bytes(); + if (err < 0) + return ERR_PTR(err); + + client = kzalloc(sizeof(*client), GFP_KERNEL); + if (client == NULL) + return ERR_PTR(-ENOMEM); + + client->private = private; + client->options = opt; + + mutex_init(&client->mount_mutex); + init_waitqueue_head(&client->auth_wq); + client->auth_err = 0; + + client->extra_mon_dispatch = NULL; + client->supported_features = CEPH_FEATURES_SUPPORTED_DEFAULT; + client->required_features = CEPH_FEATURES_REQUIRED_DEFAULT; + + if (!ceph_test_opt(client, NOMSGSIGN)) + client->required_features |= CEPH_FEATURE_MSG_AUTH; + + /* msgr */ + if (ceph_test_opt(client, MYIP)) + myaddr = &client->options->my_addr; + + ceph_messenger_init(&client->msgr, myaddr); + + /* subsystems */ + err = ceph_monc_init(&client->monc, client); + if (err < 0) + goto fail; + err = ceph_osdc_init(&client->osdc, client); + if (err < 0) + goto fail_monc; + + return client; + +fail_monc: + ceph_monc_stop(&client->monc); +fail: + ceph_messenger_fini(&client->msgr); + kfree(client); + return ERR_PTR(err); +} +EXPORT_SYMBOL(ceph_create_client); + +void ceph_destroy_client(struct ceph_client *client) +{ + dout("destroy_client %p\n", client); + + atomic_set(&client->msgr.stopping, 1); + + /* unmount */ + ceph_osdc_stop(&client->osdc); + ceph_monc_stop(&client->monc); + ceph_messenger_fini(&client->msgr); + + ceph_debugfs_client_cleanup(client); + + ceph_destroy_options(client->options); + + kfree(client); + dout("destroy_client %p done\n", client); +} +EXPORT_SYMBOL(ceph_destroy_client); + +void ceph_reset_client_addr(struct ceph_client *client) +{ + ceph_messenger_reset_nonce(&client->msgr); + ceph_monc_reopen_session(&client->monc); + ceph_osdc_reopen_osds(&client->osdc); +} 
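+
+/*
+ * Editorial illustration, not part of the upstream file: a consumer of the
+ * exported client API typically drives the lifecycle sketched below, with
+ * @opt already populated via ceph_alloc_options() and ceph_parse_param().
+ * On ceph_create_client() failure the caller still owns @opt; once creation
+ * succeeds, ceph_destroy_client() also destroys the options.
+ *
+ *	struct ceph_client *client = ceph_create_client(opt, NULL);
+ *	int err;
+ *
+ *	if (IS_ERR(client))
+ *		return PTR_ERR(client);
+ *	err = ceph_open_session(client);
+ *	if (err)
+ *		ceph_destroy_client(client);
+ */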
+EXPORT_SYMBOL(ceph_reset_client_addr); + +/* + * true if we have the mon map (and have thus joined the cluster) + */ +static bool have_mon_and_osd_map(struct ceph_client *client) +{ + return client->monc.monmap && client->monc.monmap->epoch && + client->osdc.osdmap && client->osdc.osdmap->epoch; +} + +/* + * mount: join the ceph cluster, and open root directory. + */ +int __ceph_open_session(struct ceph_client *client, unsigned long started) +{ + unsigned long timeout = client->options->mount_timeout; + long err; + + /* open session, and wait for mon and osd maps */ + err = ceph_monc_open_session(&client->monc); + if (err < 0) + return err; + + while (!have_mon_and_osd_map(client)) { + if (timeout && time_after_eq(jiffies, started + timeout)) + return -ETIMEDOUT; + + /* wait */ + dout("mount waiting for mon_map\n"); + err = wait_event_interruptible_timeout(client->auth_wq, + have_mon_and_osd_map(client) || (client->auth_err < 0), + ceph_timeout_jiffies(timeout)); + if (err < 0) + return err; + if (client->auth_err < 0) + return client->auth_err; + } + + pr_info("client%llu fsid %pU\n", ceph_client_gid(client), + &client->fsid); + ceph_debugfs_client_init(client); + + return 0; +} +EXPORT_SYMBOL(__ceph_open_session); + +int ceph_open_session(struct ceph_client *client) +{ + int ret; + unsigned long started = jiffies; /* note the start time */ + + dout("open_session start\n"); + mutex_lock(&client->mount_mutex); + + ret = __ceph_open_session(client, started); + + mutex_unlock(&client->mount_mutex); + return ret; +} +EXPORT_SYMBOL(ceph_open_session); + +int ceph_wait_for_latest_osdmap(struct ceph_client *client, + unsigned long timeout) +{ + u64 newest_epoch; + int ret; + + ret = ceph_monc_get_version(&client->monc, "osdmap", &newest_epoch); + if (ret) + return ret; + + if (client->osdc.osdmap->epoch >= newest_epoch) + return 0; + + ceph_osdc_maybe_request_map(&client->osdc); + return ceph_monc_wait_osdmap(&client->monc, newest_epoch, timeout); +} +EXPORT_SYMBOL(ceph_wait_for_latest_osdmap); + +static int __init init_ceph_lib(void) +{ + int ret = 0; + + ceph_debugfs_init(); + + ret = ceph_crypto_init(); + if (ret < 0) + goto out_debugfs; + + ret = ceph_msgr_init(); + if (ret < 0) + goto out_crypto; + + ret = ceph_osdc_setup(); + if (ret < 0) + goto out_msgr; + + pr_info("loaded (mon/osd proto %d/%d)\n", + CEPH_MONC_PROTOCOL, CEPH_OSDC_PROTOCOL); + + return 0; + +out_msgr: + ceph_msgr_exit(); +out_crypto: + ceph_crypto_shutdown(); +out_debugfs: + ceph_debugfs_cleanup(); + return ret; +} + +static void __exit exit_ceph_lib(void) +{ + dout("exit_ceph_lib\n"); + WARN_ON(!ceph_strings_empty()); + + ceph_osdc_cleanup(); + ceph_msgr_exit(); + ceph_crypto_shutdown(); + ceph_debugfs_cleanup(); +} + +module_init(init_ceph_lib); +module_exit(exit_ceph_lib); + +MODULE_AUTHOR("Sage Weil "); +MODULE_AUTHOR("Yehuda Sadeh "); +MODULE_AUTHOR("Patience Warnick "); +MODULE_DESCRIPTION("Ceph core library"); +MODULE_LICENSE("GPL"); diff --git a/net/ceph/ceph_hash.c b/net/ceph/ceph_hash.c new file mode 100644 index 000000000..16a47c0ee --- /dev/null +++ b/net/ceph/ceph_hash.c @@ -0,0 +1,131 @@ + +#include +#include + +/* + * Robert Jenkin's hash function. + * https://burtleburtle.net/bob/hash/evahash.html + * This is in the public domain. 
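+ *
+ * Note (editorial): the mix() macro below stirs three 32-bit state words;
+ * ceph_str_hash_rjenkins() seeds (a, b, c), folds the key through mix()
+ * twelve bytes at a time, absorbs the trailing bytes and the length, and
+ * returns the final value of c.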
+ */ +#define mix(a, b, c) \ + do { \ + a = a - b; a = a - c; a = a ^ (c >> 13); \ + b = b - c; b = b - a; b = b ^ (a << 8); \ + c = c - a; c = c - b; c = c ^ (b >> 13); \ + a = a - b; a = a - c; a = a ^ (c >> 12); \ + b = b - c; b = b - a; b = b ^ (a << 16); \ + c = c - a; c = c - b; c = c ^ (b >> 5); \ + a = a - b; a = a - c; a = a ^ (c >> 3); \ + b = b - c; b = b - a; b = b ^ (a << 10); \ + c = c - a; c = c - b; c = c ^ (b >> 15); \ + } while (0) + +unsigned int ceph_str_hash_rjenkins(const char *str, unsigned int length) +{ + const unsigned char *k = (const unsigned char *)str; + __u32 a, b, c; /* the internal state */ + __u32 len; /* how many key bytes still need mixing */ + + /* Set up the internal state */ + len = length; + a = 0x9e3779b9; /* the golden ratio; an arbitrary value */ + b = a; + c = 0; /* variable initialization of internal state */ + + /* handle most of the key */ + while (len >= 12) { + a = a + (k[0] + ((__u32)k[1] << 8) + ((__u32)k[2] << 16) + + ((__u32)k[3] << 24)); + b = b + (k[4] + ((__u32)k[5] << 8) + ((__u32)k[6] << 16) + + ((__u32)k[7] << 24)); + c = c + (k[8] + ((__u32)k[9] << 8) + ((__u32)k[10] << 16) + + ((__u32)k[11] << 24)); + mix(a, b, c); + k = k + 12; + len = len - 12; + } + + /* handle the last 11 bytes */ + c = c + length; + switch (len) { + case 11: + c = c + ((__u32)k[10] << 24); + fallthrough; + case 10: + c = c + ((__u32)k[9] << 16); + fallthrough; + case 9: + c = c + ((__u32)k[8] << 8); + /* the first byte of c is reserved for the length */ + fallthrough; + case 8: + b = b + ((__u32)k[7] << 24); + fallthrough; + case 7: + b = b + ((__u32)k[6] << 16); + fallthrough; + case 6: + b = b + ((__u32)k[5] << 8); + fallthrough; + case 5: + b = b + k[4]; + fallthrough; + case 4: + a = a + ((__u32)k[3] << 24); + fallthrough; + case 3: + a = a + ((__u32)k[2] << 16); + fallthrough; + case 2: + a = a + ((__u32)k[1] << 8); + fallthrough; + case 1: + a = a + k[0]; + /* case 0: nothing left to add */ + } + mix(a, b, c); + + return c; +} + +/* + * linux dcache hash + */ +unsigned int ceph_str_hash_linux(const char *str, unsigned int length) +{ + unsigned long hash = 0; + unsigned char c; + + while (length--) { + c = *str++; + hash = (hash + (c << 4) + (c >> 4)) * 11; + } + return hash; +} + + +unsigned int ceph_str_hash(int type, const char *s, unsigned int len) +{ + switch (type) { + case CEPH_STR_HASH_LINUX: + return ceph_str_hash_linux(s, len); + case CEPH_STR_HASH_RJENKINS: + return ceph_str_hash_rjenkins(s, len); + default: + return -1; + } +} +EXPORT_SYMBOL(ceph_str_hash); + +const char *ceph_str_hash_name(int type) +{ + switch (type) { + case CEPH_STR_HASH_LINUX: + return "linux"; + case CEPH_STR_HASH_RJENKINS: + return "rjenkins"; + default: + return "unknown"; + } +} +EXPORT_SYMBOL(ceph_str_hash_name); diff --git a/net/ceph/ceph_strings.c b/net/ceph/ceph_strings.c new file mode 100644 index 000000000..355fea272 --- /dev/null +++ b/net/ceph/ceph_strings.c @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Ceph string constants + */ +#include +#include + +const char *ceph_entity_type_name(int type) +{ + switch (type) { + case CEPH_ENTITY_TYPE_MDS: return "mds"; + case CEPH_ENTITY_TYPE_OSD: return "osd"; + case CEPH_ENTITY_TYPE_MON: return "mon"; + case CEPH_ENTITY_TYPE_CLIENT: return "client"; + case CEPH_ENTITY_TYPE_AUTH: return "auth"; + default: return "unknown"; + } +} +EXPORT_SYMBOL(ceph_entity_type_name); + +const char *ceph_auth_proto_name(int proto) +{ + switch (proto) { + case CEPH_AUTH_UNKNOWN: + return "unknown"; + case 
CEPH_AUTH_NONE: + return "none"; + case CEPH_AUTH_CEPHX: + return "cephx"; + default: + return "???"; + } +} + +const char *ceph_con_mode_name(int mode) +{ + switch (mode) { + case CEPH_CON_MODE_UNKNOWN: + return "unknown"; + case CEPH_CON_MODE_CRC: + return "crc"; + case CEPH_CON_MODE_SECURE: + return "secure"; + default: + return "???"; + } +} + +const char *ceph_osd_op_name(int op) +{ + switch (op) { +#define GENERATE_CASE(op, opcode, str) case CEPH_OSD_OP_##op: return (str); +__CEPH_FORALL_OSD_OPS(GENERATE_CASE) +#undef GENERATE_CASE + default: + return "???"; + } +} + +const char *ceph_osd_watch_op_name(int o) +{ + switch (o) { + case CEPH_OSD_WATCH_OP_UNWATCH: + return "unwatch"; + case CEPH_OSD_WATCH_OP_WATCH: + return "watch"; + case CEPH_OSD_WATCH_OP_RECONNECT: + return "reconnect"; + case CEPH_OSD_WATCH_OP_PING: + return "ping"; + default: + return "???"; + } +} + +const char *ceph_osd_state_name(int s) +{ + switch (s) { + case CEPH_OSD_EXISTS: + return "exists"; + case CEPH_OSD_UP: + return "up"; + case CEPH_OSD_AUTOOUT: + return "autoout"; + case CEPH_OSD_NEW: + return "new"; + default: + return "???"; + } +} diff --git a/net/ceph/cls_lock_client.c b/net/ceph/cls_lock_client.c new file mode 100644 index 000000000..66136a4c1 --- /dev/null +++ b/net/ceph/cls_lock_client.c @@ -0,0 +1,431 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include + +#include +#include +#include + +/** + * ceph_cls_lock - grab rados lock for object + * @osdc: OSD client instance + * @oid: object to lock + * @oloc: object to lock + * @lock_name: the name of the lock + * @type: lock type (CEPH_CLS_LOCK_EXCLUSIVE or CEPH_CLS_LOCK_SHARED) + * @cookie: user-defined identifier for this instance of the lock + * @tag: user-defined tag + * @desc: user-defined lock description + * @flags: lock flags + * + * All operations on the same lock should use the same tag. 
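+ *
+ * Editorial sketch, not upstream documentation: assuming "my_oid" and
+ * "my_oloc" already name an existing object, an exclusive lock could be
+ * taken and dropped with
+ *
+ *	ret = ceph_cls_lock(osdc, &my_oid, &my_oloc, "my_lock",
+ *			    CEPH_CLS_LOCK_EXCLUSIVE, "my-cookie", "", "", 0);
+ *	if (!ret)
+ *		ret = ceph_cls_unlock(osdc, &my_oid, &my_oloc, "my_lock",
+ *				      "my-cookie");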
+ */ +int ceph_cls_lock(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + char *lock_name, u8 type, char *cookie, + char *tag, char *desc, u8 flags) +{ + int lock_op_buf_size; + int name_len = strlen(lock_name); + int cookie_len = strlen(cookie); + int tag_len = strlen(tag); + int desc_len = strlen(desc); + void *p, *end; + struct page *lock_op_page; + struct timespec64 mtime; + int ret; + + lock_op_buf_size = name_len + sizeof(__le32) + + cookie_len + sizeof(__le32) + + tag_len + sizeof(__le32) + + desc_len + sizeof(__le32) + + sizeof(struct ceph_timespec) + + /* flag and type */ + sizeof(u8) + sizeof(u8) + + CEPH_ENCODING_START_BLK_LEN; + if (lock_op_buf_size > PAGE_SIZE) + return -E2BIG; + + lock_op_page = alloc_page(GFP_NOIO); + if (!lock_op_page) + return -ENOMEM; + + p = page_address(lock_op_page); + end = p + lock_op_buf_size; + + /* encode cls_lock_lock_op struct */ + ceph_start_encoding(&p, 1, 1, + lock_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + ceph_encode_8(&p, type); + ceph_encode_string(&p, end, cookie, cookie_len); + ceph_encode_string(&p, end, tag, tag_len); + ceph_encode_string(&p, end, desc, desc_len); + /* only support infinite duration */ + memset(&mtime, 0, sizeof(mtime)); + ceph_encode_timespec64(p, &mtime); + p += sizeof(struct ceph_timespec); + ceph_encode_8(&p, flags); + + dout("%s lock_name %s type %d cookie %s tag %s desc %s flags 0x%x\n", + __func__, lock_name, type, cookie, tag, desc, flags); + ret = ceph_osdc_call(osdc, oid, oloc, "lock", "lock", + CEPH_OSD_FLAG_WRITE, lock_op_page, + lock_op_buf_size, NULL, NULL); + + dout("%s: status %d\n", __func__, ret); + __free_page(lock_op_page); + return ret; +} +EXPORT_SYMBOL(ceph_cls_lock); + +/** + * ceph_cls_unlock - release rados lock for object + * @osdc: OSD client instance + * @oid: object to lock + * @oloc: object to lock + * @lock_name: the name of the lock + * @cookie: user-defined identifier for this instance of the lock + */ +int ceph_cls_unlock(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + char *lock_name, char *cookie) +{ + int unlock_op_buf_size; + int name_len = strlen(lock_name); + int cookie_len = strlen(cookie); + void *p, *end; + struct page *unlock_op_page; + int ret; + + unlock_op_buf_size = name_len + sizeof(__le32) + + cookie_len + sizeof(__le32) + + CEPH_ENCODING_START_BLK_LEN; + if (unlock_op_buf_size > PAGE_SIZE) + return -E2BIG; + + unlock_op_page = alloc_page(GFP_NOIO); + if (!unlock_op_page) + return -ENOMEM; + + p = page_address(unlock_op_page); + end = p + unlock_op_buf_size; + + /* encode cls_lock_unlock_op struct */ + ceph_start_encoding(&p, 1, 1, + unlock_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + ceph_encode_string(&p, end, cookie, cookie_len); + + dout("%s lock_name %s cookie %s\n", __func__, lock_name, cookie); + ret = ceph_osdc_call(osdc, oid, oloc, "lock", "unlock", + CEPH_OSD_FLAG_WRITE, unlock_op_page, + unlock_op_buf_size, NULL, NULL); + + dout("%s: status %d\n", __func__, ret); + __free_page(unlock_op_page); + return ret; +} +EXPORT_SYMBOL(ceph_cls_unlock); + +/** + * ceph_cls_break_lock - release rados lock for object for specified client + * @osdc: OSD client instance + * @oid: object to lock + * @oloc: object to lock + * @lock_name: the name of the lock + * @cookie: user-defined identifier for this instance of the lock + * @locker: current lock owner + */ +int 
ceph_cls_break_lock(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + char *lock_name, char *cookie, + struct ceph_entity_name *locker) +{ + int break_op_buf_size; + int name_len = strlen(lock_name); + int cookie_len = strlen(cookie); + struct page *break_op_page; + void *p, *end; + int ret; + + break_op_buf_size = name_len + sizeof(__le32) + + cookie_len + sizeof(__le32) + + sizeof(u8) + sizeof(__le64) + + CEPH_ENCODING_START_BLK_LEN; + if (break_op_buf_size > PAGE_SIZE) + return -E2BIG; + + break_op_page = alloc_page(GFP_NOIO); + if (!break_op_page) + return -ENOMEM; + + p = page_address(break_op_page); + end = p + break_op_buf_size; + + /* encode cls_lock_break_op struct */ + ceph_start_encoding(&p, 1, 1, + break_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + ceph_encode_copy(&p, locker, sizeof(*locker)); + ceph_encode_string(&p, end, cookie, cookie_len); + + dout("%s lock_name %s cookie %s locker %s%llu\n", __func__, lock_name, + cookie, ENTITY_NAME(*locker)); + ret = ceph_osdc_call(osdc, oid, oloc, "lock", "break_lock", + CEPH_OSD_FLAG_WRITE, break_op_page, + break_op_buf_size, NULL, NULL); + + dout("%s: status %d\n", __func__, ret); + __free_page(break_op_page); + return ret; +} +EXPORT_SYMBOL(ceph_cls_break_lock); + +int ceph_cls_set_cookie(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + char *lock_name, u8 type, char *old_cookie, + char *tag, char *new_cookie) +{ + int cookie_op_buf_size; + int name_len = strlen(lock_name); + int old_cookie_len = strlen(old_cookie); + int tag_len = strlen(tag); + int new_cookie_len = strlen(new_cookie); + void *p, *end; + struct page *cookie_op_page; + int ret; + + cookie_op_buf_size = name_len + sizeof(__le32) + + old_cookie_len + sizeof(__le32) + + tag_len + sizeof(__le32) + + new_cookie_len + sizeof(__le32) + + sizeof(u8) + CEPH_ENCODING_START_BLK_LEN; + if (cookie_op_buf_size > PAGE_SIZE) + return -E2BIG; + + cookie_op_page = alloc_page(GFP_NOIO); + if (!cookie_op_page) + return -ENOMEM; + + p = page_address(cookie_op_page); + end = p + cookie_op_buf_size; + + /* encode cls_lock_set_cookie_op struct */ + ceph_start_encoding(&p, 1, 1, + cookie_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + ceph_encode_8(&p, type); + ceph_encode_string(&p, end, old_cookie, old_cookie_len); + ceph_encode_string(&p, end, tag, tag_len); + ceph_encode_string(&p, end, new_cookie, new_cookie_len); + + dout("%s lock_name %s type %d old_cookie %s tag %s new_cookie %s\n", + __func__, lock_name, type, old_cookie, tag, new_cookie); + ret = ceph_osdc_call(osdc, oid, oloc, "lock", "set_cookie", + CEPH_OSD_FLAG_WRITE, cookie_op_page, + cookie_op_buf_size, NULL, NULL); + + dout("%s: status %d\n", __func__, ret); + __free_page(cookie_op_page); + return ret; +} +EXPORT_SYMBOL(ceph_cls_set_cookie); + +void ceph_free_lockers(struct ceph_locker *lockers, u32 num_lockers) +{ + int i; + + for (i = 0; i < num_lockers; i++) + kfree(lockers[i].id.cookie); + kfree(lockers); +} +EXPORT_SYMBOL(ceph_free_lockers); + +static int decode_locker(void **p, void *end, struct ceph_locker *locker) +{ + u8 struct_v; + u32 len; + char *s; + int ret; + + ret = ceph_start_decoding(p, end, 1, "locker_id_t", &struct_v, &len); + if (ret) + return ret; + + ceph_decode_copy(p, &locker->id.name, sizeof(locker->id.name)); + s = ceph_extract_encoded_string(p, end, NULL, GFP_NOIO); + if (IS_ERR(s)) + return 
PTR_ERR(s); + + locker->id.cookie = s; + + ret = ceph_start_decoding(p, end, 1, "locker_info_t", &struct_v, &len); + if (ret) + return ret; + + *p += sizeof(struct ceph_timespec); /* skip expiration */ + + ret = ceph_decode_entity_addr(p, end, &locker->info.addr); + if (ret) + return ret; + + len = ceph_decode_32(p); + *p += len; /* skip description */ + + dout("%s %s%llu cookie %s addr %s\n", __func__, + ENTITY_NAME(locker->id.name), locker->id.cookie, + ceph_pr_addr(&locker->info.addr)); + return 0; +} + +static int decode_lockers(void **p, void *end, u8 *type, char **tag, + struct ceph_locker **lockers, u32 *num_lockers) +{ + u8 struct_v; + u32 struct_len; + char *s; + int i; + int ret; + + ret = ceph_start_decoding(p, end, 1, "cls_lock_get_info_reply", + &struct_v, &struct_len); + if (ret) + return ret; + + *num_lockers = ceph_decode_32(p); + *lockers = kcalloc(*num_lockers, sizeof(**lockers), GFP_NOIO); + if (!*lockers) + return -ENOMEM; + + for (i = 0; i < *num_lockers; i++) { + ret = decode_locker(p, end, *lockers + i); + if (ret) + goto err_free_lockers; + } + + *type = ceph_decode_8(p); + s = ceph_extract_encoded_string(p, end, NULL, GFP_NOIO); + if (IS_ERR(s)) { + ret = PTR_ERR(s); + goto err_free_lockers; + } + + *tag = s; + return 0; + +err_free_lockers: + ceph_free_lockers(*lockers, *num_lockers); + return ret; +} + +/* + * On success, the caller is responsible for: + * + * kfree(tag); + * ceph_free_lockers(lockers, num_lockers); + */ +int ceph_cls_lock_info(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + char *lock_name, u8 *type, char **tag, + struct ceph_locker **lockers, u32 *num_lockers) +{ + int get_info_op_buf_size; + int name_len = strlen(lock_name); + struct page *get_info_op_page, *reply_page; + size_t reply_len = PAGE_SIZE; + void *p, *end; + int ret; + + get_info_op_buf_size = name_len + sizeof(__le32) + + CEPH_ENCODING_START_BLK_LEN; + if (get_info_op_buf_size > PAGE_SIZE) + return -E2BIG; + + get_info_op_page = alloc_page(GFP_NOIO); + if (!get_info_op_page) + return -ENOMEM; + + reply_page = alloc_page(GFP_NOIO); + if (!reply_page) { + __free_page(get_info_op_page); + return -ENOMEM; + } + + p = page_address(get_info_op_page); + end = p + get_info_op_buf_size; + + /* encode cls_lock_get_info_op struct */ + ceph_start_encoding(&p, 1, 1, + get_info_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + + dout("%s lock_name %s\n", __func__, lock_name); + ret = ceph_osdc_call(osdc, oid, oloc, "lock", "get_info", + CEPH_OSD_FLAG_READ, get_info_op_page, + get_info_op_buf_size, &reply_page, &reply_len); + + dout("%s: status %d\n", __func__, ret); + if (ret >= 0) { + p = page_address(reply_page); + end = p + reply_len; + + ret = decode_lockers(&p, end, type, tag, lockers, num_lockers); + } + + __free_page(get_info_op_page); + __free_page(reply_page); + return ret; +} +EXPORT_SYMBOL(ceph_cls_lock_info); + +int ceph_cls_assert_locked(struct ceph_osd_request *req, int which, + char *lock_name, u8 type, char *cookie, char *tag) +{ + int assert_op_buf_size; + int name_len = strlen(lock_name); + int cookie_len = strlen(cookie); + int tag_len = strlen(tag); + struct page **pages; + void *p, *end; + int ret; + + assert_op_buf_size = name_len + sizeof(__le32) + + cookie_len + sizeof(__le32) + + tag_len + sizeof(__le32) + + sizeof(u8) + CEPH_ENCODING_START_BLK_LEN; + if (assert_op_buf_size > PAGE_SIZE) + return -E2BIG; + + ret = osd_req_op_cls_init(req, which, "lock", "assert_locked"); + 
if (ret) + return ret; + + pages = ceph_alloc_page_vector(1, GFP_NOIO); + if (IS_ERR(pages)) + return PTR_ERR(pages); + + p = page_address(pages[0]); + end = p + assert_op_buf_size; + + /* encode cls_lock_assert_op struct */ + ceph_start_encoding(&p, 1, 1, + assert_op_buf_size - CEPH_ENCODING_START_BLK_LEN); + ceph_encode_string(&p, end, lock_name, name_len); + ceph_encode_8(&p, type); + ceph_encode_string(&p, end, cookie, cookie_len); + ceph_encode_string(&p, end, tag, tag_len); + WARN_ON(p != end); + + osd_req_op_cls_request_data_pages(req, which, pages, assert_op_buf_size, + 0, false, true); + return 0; +} +EXPORT_SYMBOL(ceph_cls_assert_locked); diff --git a/net/ceph/crush/crush.c b/net/ceph/crush/crush.c new file mode 100644 index 000000000..254ded0b0 --- /dev/null +++ b/net/ceph/crush/crush.c @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: GPL-2.0 +#ifdef __KERNEL__ +# include +# include +#else +# include "crush_compat.h" +# include "crush.h" +#endif + +const char *crush_bucket_alg_name(int alg) +{ + switch (alg) { + case CRUSH_BUCKET_UNIFORM: return "uniform"; + case CRUSH_BUCKET_LIST: return "list"; + case CRUSH_BUCKET_TREE: return "tree"; + case CRUSH_BUCKET_STRAW: return "straw"; + case CRUSH_BUCKET_STRAW2: return "straw2"; + default: return "unknown"; + } +} + +/** + * crush_get_bucket_item_weight - Get weight of an item in given bucket + * @b: bucket pointer + * @p: item index in bucket + */ +int crush_get_bucket_item_weight(const struct crush_bucket *b, int p) +{ + if ((__u32)p >= b->size) + return 0; + + switch (b->alg) { + case CRUSH_BUCKET_UNIFORM: + return ((struct crush_bucket_uniform *)b)->item_weight; + case CRUSH_BUCKET_LIST: + return ((struct crush_bucket_list *)b)->item_weights[p]; + case CRUSH_BUCKET_TREE: + return ((struct crush_bucket_tree *)b)->node_weights[crush_calc_tree_node(p)]; + case CRUSH_BUCKET_STRAW: + return ((struct crush_bucket_straw *)b)->item_weights[p]; + case CRUSH_BUCKET_STRAW2: + return ((struct crush_bucket_straw2 *)b)->item_weights[p]; + } + return 0; +} + +void crush_destroy_bucket_uniform(struct crush_bucket_uniform *b) +{ + kfree(b->h.items); + kfree(b); +} + +void crush_destroy_bucket_list(struct crush_bucket_list *b) +{ + kfree(b->item_weights); + kfree(b->sum_weights); + kfree(b->h.items); + kfree(b); +} + +void crush_destroy_bucket_tree(struct crush_bucket_tree *b) +{ + kfree(b->h.items); + kfree(b->node_weights); + kfree(b); +} + +void crush_destroy_bucket_straw(struct crush_bucket_straw *b) +{ + kfree(b->straws); + kfree(b->item_weights); + kfree(b->h.items); + kfree(b); +} + +void crush_destroy_bucket_straw2(struct crush_bucket_straw2 *b) +{ + kfree(b->item_weights); + kfree(b->h.items); + kfree(b); +} + +void crush_destroy_bucket(struct crush_bucket *b) +{ + switch (b->alg) { + case CRUSH_BUCKET_UNIFORM: + crush_destroy_bucket_uniform((struct crush_bucket_uniform *)b); + break; + case CRUSH_BUCKET_LIST: + crush_destroy_bucket_list((struct crush_bucket_list *)b); + break; + case CRUSH_BUCKET_TREE: + crush_destroy_bucket_tree((struct crush_bucket_tree *)b); + break; + case CRUSH_BUCKET_STRAW: + crush_destroy_bucket_straw((struct crush_bucket_straw *)b); + break; + case CRUSH_BUCKET_STRAW2: + crush_destroy_bucket_straw2((struct crush_bucket_straw2 *)b); + break; + } +} + +/** + * crush_destroy - Destroy a crush_map + * @map: crush_map pointer + */ +void crush_destroy(struct crush_map *map) +{ + /* buckets */ + if (map->buckets) { + __s32 b; + for (b = 0; b < map->max_buckets; b++) { + if (map->buckets[b] == NULL) + continue; + 
crush_destroy_bucket(map->buckets[b]); + } + kfree(map->buckets); + } + + /* rules */ + if (map->rules) { + __u32 b; + for (b = 0; b < map->max_rules; b++) + crush_destroy_rule(map->rules[b]); + kfree(map->rules); + } + +#ifndef __KERNEL__ + kfree(map->choose_tries); +#else + clear_crush_names(&map->type_names); + clear_crush_names(&map->names); + clear_choose_args(map); +#endif + kfree(map); +} + +void crush_destroy_rule(struct crush_rule *rule) +{ + kfree(rule); +} diff --git a/net/ceph/crush/crush_ln_table.h b/net/ceph/crush/crush_ln_table.h new file mode 100644 index 000000000..aae534c90 --- /dev/null +++ b/net/ceph/crush/crush_ln_table.h @@ -0,0 +1,164 @@ +/* + * Ceph - scalable distributed file system + * + * Copyright (C) 2015 Intel Corporation All Rights Reserved + * + * This is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License version 2.1, as published by the Free Software + * Foundation. See file COPYING. + * + */ + +#ifndef CEPH_CRUSH_LN_H +#define CEPH_CRUSH_LN_H + +#ifdef __KERNEL__ +# include +#else +# include "crush_compat.h" +#endif + +/* + * RH_LH_tbl[2*k] = 2^48/(1.0+k/128.0) + * RH_LH_tbl[2*k+1] = 2^48*log2(1.0+k/128.0) + */ +static __s64 __RH_LH_tbl[128*2+2] = { + 0x0001000000000000ll, 0x0000000000000000ll, 0x0000fe03f80fe040ll, 0x000002dfca16dde1ll, + 0x0000fc0fc0fc0fc1ll, 0x000005b9e5a170b4ll, 0x0000fa232cf25214ll, 0x0000088e68ea899all, + 0x0000f83e0f83e0f9ll, 0x00000b5d69bac77ell, 0x0000f6603d980f67ll, 0x00000e26fd5c8555ll, + 0x0000f4898d5f85bcll, 0x000010eb389fa29fll, 0x0000f2b9d6480f2cll, 0x000013aa2fdd27f1ll, + 0x0000f0f0f0f0f0f1ll, 0x00001663f6fac913ll, 0x0000ef2eb71fc435ll, 0x00001918a16e4633ll, + 0x0000ed7303b5cc0fll, 0x00001bc84240adabll, 0x0000ebbdb2a5c162ll, 0x00001e72ec117fa5ll, + 0x0000ea0ea0ea0ea1ll, 0x00002118b119b4f3ll, 0x0000e865ac7b7604ll, 0x000023b9a32eaa56ll, + 0x0000e6c2b4481cd9ll, 0x00002655d3c4f15cll, 0x0000e525982af70dll, 0x000028ed53f307eell, + 0x0000e38e38e38e39ll, 0x00002b803473f7adll, 0x0000e1fc780e1fc8ll, 0x00002e0e85a9de04ll, + 0x0000e070381c0e08ll, 0x0000309857a05e07ll, 0x0000dee95c4ca038ll, 0x0000331dba0efce1ll, + 0x0000dd67c8a60dd7ll, 0x0000359ebc5b69d9ll, 0x0000dbeb61eed19dll, 0x0000381b6d9bb29bll, + 0x0000da740da740dbll, 0x00003a93dc9864b2ll, 0x0000d901b2036407ll, 0x00003d0817ce9cd4ll, + 0x0000d79435e50d7all, 0x00003f782d7204d0ll, 0x0000d62b80d62b81ll, 0x000041e42b6ec0c0ll, + 0x0000d4c77b03531ell, 0x0000444c1f6b4c2dll, 0x0000d3680d3680d4ll, 0x000046b016ca47c1ll, + 0x0000d20d20d20d21ll, 0x000049101eac381cll, 0x0000d0b69fcbd259ll, 0x00004b6c43f1366all, + 0x0000cf6474a8819fll, 0x00004dc4933a9337ll, 0x0000ce168a772509ll, 0x0000501918ec6c11ll, + 0x0000cccccccccccdll, 0x00005269e12f346ell, 0x0000cb8727c065c4ll, 0x000054b6f7f1325all, + 0x0000ca4587e6b750ll, 0x0000570068e7ef5all, 0x0000c907da4e8712ll, 0x000059463f919deell, + 0x0000c7ce0c7ce0c8ll, 0x00005b8887367433ll, 0x0000c6980c6980c7ll, 0x00005dc74ae9fbecll, + 0x0000c565c87b5f9ell, 0x00006002958c5871ll, 0x0000c4372f855d83ll, 0x0000623a71cb82c8ll, + 0x0000c30c30c30c31ll, 0x0000646eea247c5cll, 0x0000c1e4bbd595f7ll, 0x000066a008e4788cll, + 0x0000c0c0c0c0c0c1ll, 0x000068cdd829fd81ll, 0x0000bfa02fe80bfbll, 0x00006af861e5fc7dll, + 0x0000be82fa0be830ll, 0x00006d1fafdce20all, 0x0000bd6910470767ll, 0x00006f43cba79e40ll, + 0x0000bc52640bc527ll, 0x00007164beb4a56dll, 0x0000bb3ee721a54ell, 0x000073829248e961ll, + 0x0000ba2e8ba2e8bbll, 0x0000759d4f80cba8ll, 0x0000b92143fa36f6ll, 0x000077b4ff5108d9ll, + 0x0000b81702e05c0cll, 
0x000079c9aa879d53ll, 0x0000b70fbb5a19bfll, 0x00007bdb59cca388ll, + 0x0000b60b60b60b61ll, 0x00007dea15a32c1bll, 0x0000b509e68a9b95ll, 0x00007ff5e66a0ffell, + 0x0000b40b40b40b41ll, 0x000081fed45cbccbll, 0x0000b30f63528918ll, 0x00008404e793fb81ll, + 0x0000b21642c8590cll, 0x000086082806b1d5ll, 0x0000b11fd3b80b12ll, 0x000088089d8a9e47ll, + 0x0000b02c0b02c0b1ll, 0x00008a064fd50f2all, 0x0000af3addc680b0ll, 0x00008c01467b94bbll, + 0x0000ae4c415c9883ll, 0x00008df988f4ae80ll, 0x0000ad602b580ad7ll, 0x00008fef1e987409ll, + 0x0000ac7691840ac8ll, 0x000091e20ea1393ell, 0x0000ab8f69e2835all, 0x000093d2602c2e5fll, + 0x0000aaaaaaaaaaabll, 0x000095c01a39fbd6ll, 0x0000a9c84a47a080ll, 0x000097ab43af59f9ll, + 0x0000a8e83f5717c1ll, 0x00009993e355a4e5ll, 0x0000a80a80a80a81ll, 0x00009b79ffdb6c8bll, + 0x0000a72f0539782all, 0x00009d5d9fd5010bll, 0x0000a655c4392d7cll, 0x00009f3ec9bcfb80ll, + 0x0000a57eb50295fbll, 0x0000a11d83f4c355ll, 0x0000a4a9cf1d9684ll, 0x0000a2f9d4c51039ll, + 0x0000a3d70a3d70a4ll, 0x0000a4d3c25e68dcll, 0x0000a3065e3fae7dll, 0x0000a6ab52d99e76ll, + 0x0000a237c32b16d0ll, 0x0000a8808c384547ll, 0x0000a16b312ea8fdll, 0x0000aa5374652a1cll, + 0x0000a0a0a0a0a0a1ll, 0x0000ac241134c4e9ll, 0x00009fd809fd80a0ll, 0x0000adf26865a8a1ll, + 0x00009f1165e72549ll, 0x0000afbe7fa0f04dll, 0x00009e4cad23dd60ll, 0x0000b1885c7aa982ll, + 0x00009d89d89d89d9ll, 0x0000b35004723c46ll, 0x00009cc8e160c3fcll, 0x0000b5157cf2d078ll, + 0x00009c09c09c09c1ll, 0x0000b6d8cb53b0call, 0x00009b4c6f9ef03bll, 0x0000b899f4d8ab63ll, + 0x00009a90e7d95bc7ll, 0x0000ba58feb2703all, 0x000099d722dabde6ll, 0x0000bc15edfeed32ll, + 0x0000991f1a515886ll, 0x0000bdd0c7c9a817ll, 0x00009868c809868dll, 0x0000bf89910c1678ll, + 0x000097b425ed097cll, 0x0000c1404eadf383ll, 0x000097012e025c05ll, 0x0000c2f5058593d9ll, + 0x0000964fda6c0965ll, 0x0000c4a7ba58377cll, 0x000095a02568095bll, 0x0000c65871da59ddll, + 0x000094f2094f2095ll, 0x0000c80730b00016ll, 0x0000944580944581ll, 0x0000c9b3fb6d0559ll, + 0x0000939a85c4093all, 0x0000cb5ed69565afll, 0x000092f113840498ll, 0x0000cd07c69d8702ll, + 0x0000924924924925ll, 0x0000ceaecfea8085ll, 0x000091a2b3c4d5e7ll, 0x0000d053f6d26089ll, + 0x000090fdbc090fdcll, 0x0000d1f73f9c70c0ll, 0x0000905a38633e07ll, 0x0000d398ae817906ll, + 0x00008fb823ee08fcll, 0x0000d53847ac00a6ll, 0x00008f1779d9fdc4ll, 0x0000d6d60f388e41ll, + 0x00008e78356d1409ll, 0x0000d8720935e643ll, 0x00008dda5202376all, 0x0000da0c39a54804ll, + 0x00008d3dcb08d3ddll, 0x0000dba4a47aa996ll, 0x00008ca29c046515ll, 0x0000dd3b4d9cf24bll, + 0x00008c08c08c08c1ll, 0x0000ded038e633f3ll, 0x00008b70344a139cll, 0x0000e0636a23e2eell, + 0x00008ad8f2fba939ll, 0x0000e1f4e5170d02ll, 0x00008a42f870566all, 0x0000e384ad748f0ell, + 0x000089ae4089ae41ll, 0x0000e512c6e54998ll, 0x0000891ac73ae982ll, 0x0000e69f35065448ll, + 0x0000888888888889ll, 0x0000e829fb693044ll, 0x000087f78087f781ll, 0x0000e9b31d93f98ell, + 0x00008767ab5f34e5ll, 0x0000eb3a9f019750ll, 0x000086d905447a35ll, 0x0000ecc08321eb30ll, + 0x0000864b8a7de6d2ll, 0x0000ee44cd59ffabll, 0x000085bf37612cefll, 0x0000efc781043579ll, + 0x0000853408534086ll, 0x0000f148a170700all, 0x000084a9f9c8084bll, 0x0000f2c831e44116ll, + 0x0000842108421085ll, 0x0000f446359b1353ll, 0x0000839930523fbfll, 0x0000f5c2afc65447ll, + 0x000083126e978d50ll, 0x0000f73da38d9d4all, 0x0000828cbfbeb9a1ll, 0x0000f8b7140edbb1ll, + 0x0000820820820821ll, 0x0000fa2f045e7832ll, 0x000081848da8faf1ll, 0x0000fba577877d7dll, + 0x0000810204081021ll, 0x0000fd1a708bbe11ll, 0x0000808080808081ll, 0x0000fe8df263f957ll, + 0x0000800000000000ll, 0x0000ffff00000000ll, +}; + +/* + * 
LL_tbl[k] = 2^48*log2(1.0+k/2^15) + */ +static __s64 __LL_tbl[256] = { + 0x0000000000000000ull, 0x00000002e2a60a00ull, 0x000000070cb64ec5ull, 0x00000009ef50ce67ull, + 0x0000000cd1e588fdull, 0x0000000fb4747e9cull, 0x0000001296fdaf5eull, 0x0000001579811b58ull, + 0x000000185bfec2a1ull, 0x0000001b3e76a552ull, 0x0000001e20e8c380ull, 0x0000002103551d43ull, + 0x00000023e5bbb2b2ull, 0x00000026c81c83e4ull, 0x00000029aa7790f0ull, 0x0000002c8cccd9edull, + 0x0000002f6f1c5ef2ull, 0x0000003251662017ull, 0x0000003533aa1d71ull, 0x0000003815e8571aull, + 0x0000003af820cd26ull, 0x0000003dda537faeull, 0x00000040bc806ec8ull, 0x000000439ea79a8cull, + 0x0000004680c90310ull, 0x0000004962e4a86cull, 0x0000004c44fa8ab6ull, 0x0000004f270aaa06ull, + 0x0000005209150672ull, 0x00000054eb19a013ull, 0x00000057cd1876fdull, 0x0000005aaf118b4aull, + 0x0000005d9104dd0full, 0x0000006072f26c64ull, 0x0000006354da3960ull, 0x0000006636bc441aull, + 0x0000006918988ca8ull, 0x0000006bfa6f1322ull, 0x0000006edc3fd79full, 0x00000071be0ada35ull, + 0x000000749fd01afdull, 0x00000077818f9a0cull, 0x0000007a6349577aull, 0x0000007d44fd535eull, + 0x0000008026ab8dceull, 0x00000083085406e3ull, 0x00000085e9f6beb2ull, 0x00000088cb93b552ull, + 0x0000008bad2aeadcull, 0x0000008e8ebc5f65ull, 0x0000009170481305ull, 0x0000009451ce05d3ull, + 0x00000097334e37e5ull, 0x0000009a14c8a953ull, 0x0000009cf63d5a33ull, 0x0000009fd7ac4a9dull, + 0x000000a2b07f3458ull, 0x000000a59a78ea6aull, 0x000000a87bd699fbull, 0x000000ab5d2e8970ull, + 0x000000ae3e80b8e3ull, 0x000000b11fcd2869ull, 0x000000b40113d818ull, 0x000000b6e254c80aull, + 0x000000b9c38ff853ull, 0x000000bca4c5690cull, 0x000000bf85f51a4aull, 0x000000c2671f0c26ull, + 0x000000c548433eb6ull, 0x000000c82961b211ull, 0x000000cb0a7a664dull, 0x000000cdeb8d5b82ull, + 0x000000d0cc9a91c8ull, 0x000000d3ada20933ull, 0x000000d68ea3c1ddull, 0x000000d96f9fbbdbull, + 0x000000dc5095f744ull, 0x000000df31867430ull, 0x000000e2127132b5ull, 0x000000e4f35632eaull, + 0x000000e7d43574e6ull, 0x000000eab50ef8c1ull, 0x000000ed95e2be90ull, 0x000000f076b0c66cull, + 0x000000f35779106aull, 0x000000f6383b9ca2ull, 0x000000f918f86b2aull, 0x000000fbf9af7c1aull, + 0x000000feda60cf88ull, 0x00000101bb0c658cull, 0x000001049bb23e3cull, 0x000001077c5259afull, + 0x0000010a5cecb7fcull, 0x0000010d3d81593aull, 0x000001101e103d7full, 0x00000112fe9964e4ull, + 0x00000115df1ccf7eull, 0x00000118bf9a7d64ull, 0x0000011ba0126eadull, 0x0000011e8084a371ull, + 0x0000012160f11bc6ull, 0x000001244157d7c3ull, 0x0000012721b8d77full, 0x0000012a02141b10ull, + 0x0000012ce269a28eull, 0x0000012fc2b96e0full, 0x00000132a3037daaull, 0x000001358347d177ull, + 0x000001386386698cull, 0x0000013b43bf45ffull, 0x0000013e23f266e9ull, 0x00000141041fcc5eull, + 0x00000143e4477678ull, 0x00000146c469654bull, 0x00000149a48598f0ull, 0x0000014c849c117cull, + 0x0000014f64accf08ull, 0x0000015244b7d1a9ull, 0x0000015524bd1976ull, 0x0000015804bca687ull, + 0x0000015ae4b678f2ull, 0x0000015dc4aa90ceull, 0x00000160a498ee31ull, 0x0000016384819134ull, + 0x00000166646479ecull, 0x000001694441a870ull, 0x0000016c24191cd7ull, 0x0000016df6ca19bdull, + 0x00000171e3b6d7aaull, 0x00000174c37d1e44ull, 0x00000177a33dab1cull, 0x0000017a82f87e49ull, + 0x0000017d62ad97e2ull, 0x00000180425cf7feull, 0x00000182b07f3458ull, 0x0000018601aa8c19ull, + 0x00000188e148c046ull, 0x0000018bc0e13b52ull, 0x0000018ea073fd52ull, 0x000001918001065dull, + 0x000001945f88568bull, 0x000001973f09edf2ull, 0x0000019a1e85ccaaull, 0x0000019cfdfbf2c8ull, + 0x0000019fdd6c6063ull, 0x000001a2bcd71593ull, 0x000001a59c3c126eull, 0x000001a87b9b570bull, + 
0x000001ab5af4e380ull, 0x000001ae3a48b7e5ull, 0x000001b11996d450ull, 0x000001b3f8df38d9ull, + 0x000001b6d821e595ull, 0x000001b9b75eda9bull, 0x000001bc96961803ull, 0x000001bf75c79de3ull, + 0x000001c254f36c51ull, 0x000001c534198365ull, 0x000001c81339e336ull, 0x000001caf2548bd9ull, + 0x000001cdd1697d67ull, 0x000001d0b078b7f5ull, 0x000001d38f823b9aull, 0x000001d66e86086dull, + 0x000001d94d841e86ull, 0x000001dc2c7c7df9ull, 0x000001df0b6f26dfull, 0x000001e1ea5c194eull, + 0x000001e4c943555dull, 0x000001e7a824db23ull, 0x000001ea8700aab5ull, 0x000001ed65d6c42bull, + 0x000001f044a7279dull, 0x000001f32371d51full, 0x000001f60236cccaull, 0x000001f8e0f60eb3ull, + 0x000001fbbfaf9af3ull, 0x000001fe9e63719eull, 0x000002017d1192ccull, 0x000002045bb9fe94ull, + 0x000002073a5cb50dull, 0x00000209c06e6212ull, 0x0000020cf791026aull, 0x0000020fd622997cull, + 0x00000212b07f3458ull, 0x000002159334a8d8ull, 0x0000021871b52150ull, 0x0000021b502fe517ull, + 0x0000021d6a73a78full, 0x000002210d144eeeull, 0x00000223eb7df52cull, 0x00000226c9e1e713ull, + 0x00000229a84024bbull, 0x0000022c23679b4eull, 0x0000022f64eb83a8ull, 0x000002324338a51bull, + 0x00000235218012a9ull, 0x00000237ffc1cc69ull, 0x0000023a2c3b0ea4ull, 0x0000023d13ee805bull, + 0x0000024035e9221full, 0x00000243788faf25ull, 0x0000024656b4e735ull, 0x00000247ed646bfeull, + 0x0000024c12ee3d98ull, 0x0000024ef1025c1aull, 0x00000251cf10c799ull, 0x0000025492644d65ull, + 0x000002578b1c85eeull, 0x0000025a6919d8f0ull, 0x0000025d13ee805bull, 0x0000026025036716ull, + 0x0000026296453882ull, 0x00000265e0d62b53ull, 0x00000268beb701f3ull, 0x0000026b9c92265eull, + 0x0000026d32f798a9ull, 0x00000271583758ebull, 0x000002743601673bull, 0x0000027713c5c3b0ull, + 0x00000279f1846e5full, 0x0000027ccf3d6761ull, 0x0000027e6580aecbull, 0x000002828a9e44b3ull, + 0x0000028568462932ull, 0x00000287bdbf5255ull, 0x0000028b2384de4aull, 0x0000028d13ee805bull, + 0x0000029035e9221full, 0x0000029296453882ull, 0x0000029699bdfb61ull, 0x0000029902a37aabull, + 0x0000029c54b864c9ull, 0x0000029deabd1083ull, 0x000002a20f9c0bb5ull, 0x000002a4c7605d61ull, + 0x000002a7bdbf5255ull, 0x000002a96056dafcull, 0x000002ac3daf14efull, 0x000002af1b019ecaull, + 0x000002b296453882ull, 0x000002b5d022d80full, 0x000002b8fa471cb3ull, 0x000002ba9012e713ull, + 0x000002bd6d4901ccull, 0x000002c04a796cf6ull, 0x000002c327a428a6ull, 0x000002c61a5e8f4cull, + 0x000002c8e1e891f6ull, 0x000002cbbf023fc2ull, 0x000002ce9c163e6eull, 0x000002d179248e13ull, + 0x000002d4562d2ec6ull, 0x000002d73330209dull, 0x000002da102d63b0ull, 0x000002dced24f814ull, +}; + +#endif diff --git a/net/ceph/crush/hash.c b/net/ceph/crush/hash.c new file mode 100644 index 000000000..fe79f6d2d --- /dev/null +++ b/net/ceph/crush/hash.c @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: GPL-2.0 +#ifdef __KERNEL__ +# include +#else +# include "hash.h" +#endif + +/* + * Robert Jenkins' function for mixing 32-bit values + * https://burtleburtle.net/bob/hash/evahash.html + * a, b = random bits, c = input and output + */ +#define crush_hashmix(a, b, c) do { \ + a = a-b; a = a-c; a = a^(c>>13); \ + b = b-c; b = b-a; b = b^(a<<8); \ + c = c-a; c = c-b; c = c^(b>>13); \ + a = a-b; a = a-c; a = a^(c>>12); \ + b = b-c; b = b-a; b = b^(a<<16); \ + c = c-a; c = c-b; c = c^(b>>5); \ + a = a-b; a = a-c; a = a^(c>>3); \ + b = b-c; b = b-a; b = b^(a<<10); \ + c = c-a; c = c-b; c = c^(b>>15); \ + } while (0) + +#define crush_hash_seed 1315423911 + +static __u32 crush_hash32_rjenkins1(__u32 a) +{ + __u32 hash = crush_hash_seed ^ a; + __u32 b = a; + __u32 x = 231232; + __u32 y = 1232; + 
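+	/* two mix rounds fold the input and pad constants into the seed */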
crush_hashmix(b, x, hash); + crush_hashmix(y, a, hash); + return hash; +} + +static __u32 crush_hash32_rjenkins1_2(__u32 a, __u32 b) +{ + __u32 hash = crush_hash_seed ^ a ^ b; + __u32 x = 231232; + __u32 y = 1232; + crush_hashmix(a, b, hash); + crush_hashmix(x, a, hash); + crush_hashmix(b, y, hash); + return hash; +} + +static __u32 crush_hash32_rjenkins1_3(__u32 a, __u32 b, __u32 c) +{ + __u32 hash = crush_hash_seed ^ a ^ b ^ c; + __u32 x = 231232; + __u32 y = 1232; + crush_hashmix(a, b, hash); + crush_hashmix(c, x, hash); + crush_hashmix(y, a, hash); + crush_hashmix(b, x, hash); + crush_hashmix(y, c, hash); + return hash; +} + +static __u32 crush_hash32_rjenkins1_4(__u32 a, __u32 b, __u32 c, __u32 d) +{ + __u32 hash = crush_hash_seed ^ a ^ b ^ c ^ d; + __u32 x = 231232; + __u32 y = 1232; + crush_hashmix(a, b, hash); + crush_hashmix(c, d, hash); + crush_hashmix(a, x, hash); + crush_hashmix(y, b, hash); + crush_hashmix(c, x, hash); + crush_hashmix(y, d, hash); + return hash; +} + +static __u32 crush_hash32_rjenkins1_5(__u32 a, __u32 b, __u32 c, __u32 d, + __u32 e) +{ + __u32 hash = crush_hash_seed ^ a ^ b ^ c ^ d ^ e; + __u32 x = 231232; + __u32 y = 1232; + crush_hashmix(a, b, hash); + crush_hashmix(c, d, hash); + crush_hashmix(e, x, hash); + crush_hashmix(y, a, hash); + crush_hashmix(b, x, hash); + crush_hashmix(y, c, hash); + crush_hashmix(d, x, hash); + crush_hashmix(y, e, hash); + return hash; +} + + +__u32 crush_hash32(int type, __u32 a) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return crush_hash32_rjenkins1(a); + default: + return 0; + } +} + +__u32 crush_hash32_2(int type, __u32 a, __u32 b) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return crush_hash32_rjenkins1_2(a, b); + default: + return 0; + } +} + +__u32 crush_hash32_3(int type, __u32 a, __u32 b, __u32 c) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return crush_hash32_rjenkins1_3(a, b, c); + default: + return 0; + } +} + +__u32 crush_hash32_4(int type, __u32 a, __u32 b, __u32 c, __u32 d) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return crush_hash32_rjenkins1_4(a, b, c, d); + default: + return 0; + } +} + +__u32 crush_hash32_5(int type, __u32 a, __u32 b, __u32 c, __u32 d, __u32 e) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return crush_hash32_rjenkins1_5(a, b, c, d, e); + default: + return 0; + } +} + +const char *crush_hash_name(int type) +{ + switch (type) { + case CRUSH_HASH_RJENKINS1: + return "rjenkins1"; + default: + return "unknown"; + } +} diff --git a/net/ceph/crush/mapper.c b/net/ceph/crush/mapper.c new file mode 100644 index 000000000..1daf95e17 --- /dev/null +++ b/net/ceph/crush/mapper.c @@ -0,0 +1,1096 @@ +/* + * Ceph - scalable distributed file system + * + * Copyright (C) 2015 Intel Corporation All Rights Reserved + * + * This is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License version 2.1, as published by the Free Software + * Foundation. See file COPYING. + * + */ + +#ifdef __KERNEL__ +# include +# include +# include +# include +# include +# include +# include +#else +# include "crush_compat.h" +# include "crush.h" +# include "hash.h" +# include "mapper.h" +#endif +#include "crush_ln_table.h" + +#define dprintk(args...) /* printf(args) */ + +/* + * Implement the core CRUSH mapping algorithm. + */ + +/** + * crush_find_rule - find a crush_rule id for a given ruleset, type, and size. 
+ * @map: the crush_map + * @ruleset: the storage ruleset id (user defined) + * @type: storage ruleset type (user defined) + * @size: output set size + */ +int crush_find_rule(const struct crush_map *map, int ruleset, int type, int size) +{ + __u32 i; + + for (i = 0; i < map->max_rules; i++) { + if (map->rules[i] && + map->rules[i]->mask.ruleset == ruleset && + map->rules[i]->mask.type == type && + map->rules[i]->mask.min_size <= size && + map->rules[i]->mask.max_size >= size) + return i; + } + return -1; +} + +/* + * bucket choose methods + * + * For each bucket algorithm, we have a "choose" method that, given a + * crush input @x and replica position (usually, position in output set) @r, + * will produce an item in the bucket. + */ + +/* + * Choose based on a random permutation of the bucket. + * + * We used to use some prime number arithmetic to do this, but it + * wasn't very random, and had some other bad behaviors. Instead, we + * calculate an actual random permutation of the bucket members. + * Since this is expensive, we optimize for the r=0 case, which + * captures the vast majority of calls. + */ +static int bucket_perm_choose(const struct crush_bucket *bucket, + struct crush_work_bucket *work, + int x, int r) +{ + unsigned int pr = r % bucket->size; + unsigned int i, s; + + /* start a new permutation if @x has changed */ + if (work->perm_x != (__u32)x || work->perm_n == 0) { + dprintk("bucket %d new x=%d\n", bucket->id, x); + work->perm_x = x; + + /* optimize common r=0 case */ + if (pr == 0) { + s = crush_hash32_3(bucket->hash, x, bucket->id, 0) % + bucket->size; + work->perm[0] = s; + work->perm_n = 0xffff; /* magic value, see below */ + goto out; + } + + for (i = 0; i < bucket->size; i++) + work->perm[i] = i; + work->perm_n = 0; + } else if (work->perm_n == 0xffff) { + /* clean up after the r=0 case above */ + for (i = 1; i < bucket->size; i++) + work->perm[i] = i; + work->perm[work->perm[0]] = 0; + work->perm_n = 1; + } + + /* calculate permutation up to pr */ + for (i = 0; i < work->perm_n; i++) + dprintk(" perm_choose have %d: %d\n", i, work->perm[i]); + while (work->perm_n <= pr) { + unsigned int p = work->perm_n; + /* no point in swapping the final entry */ + if (p < bucket->size - 1) { + i = crush_hash32_3(bucket->hash, x, bucket->id, p) % + (bucket->size - p); + if (i) { + unsigned int t = work->perm[p + i]; + work->perm[p + i] = work->perm[p]; + work->perm[p] = t; + } + dprintk(" perm_choose swap %d with %d\n", p, p+i); + } + work->perm_n++; + } + for (i = 0; i < bucket->size; i++) + dprintk(" perm_choose %d: %d\n", i, work->perm[i]); + + s = work->perm[pr]; +out: + dprintk(" perm_choose %d sz=%d x=%d r=%d (%d) s=%d\n", bucket->id, + bucket->size, x, r, pr, s); + return bucket->items[s]; +} + +/* uniform */ +static int bucket_uniform_choose(const struct crush_bucket_uniform *bucket, + struct crush_work_bucket *work, int x, int r) +{ + return bucket_perm_choose(&bucket->h, work, x, r); +} + +/* list */ +static int bucket_list_choose(const struct crush_bucket_list *bucket, + int x, int r) +{ + int i; + + for (i = bucket->h.size-1; i >= 0; i--) { + __u64 w = crush_hash32_4(bucket->h.hash, x, bucket->h.items[i], + r, bucket->h.id); + w &= 0xffff; + dprintk("list_choose i=%d x=%d r=%d item %d weight %x " + "sw %x rand %llx", + i, x, r, bucket->h.items[i], bucket->item_weights[i], + bucket->sum_weights[i], w); + w *= bucket->sum_weights[i]; + w = w >> 16; + /*dprintk(" scaled %llx\n", w);*/ + if (w < bucket->item_weights[i]) { + return bucket->h.items[i]; + } + } + + 
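+	/* only reached if the item/sum weights are inconsistent */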
dprintk("bad list sums for bucket %d\n", bucket->h.id); + return bucket->h.items[0]; +} + + +/* (binary) tree */ +static int height(int n) +{ + int h = 0; + while ((n & 1) == 0) { + h++; + n = n >> 1; + } + return h; +} + +static int left(int x) +{ + int h = height(x); + return x - (1 << (h-1)); +} + +static int right(int x) +{ + int h = height(x); + return x + (1 << (h-1)); +} + +static int terminal(int x) +{ + return x & 1; +} + +static int bucket_tree_choose(const struct crush_bucket_tree *bucket, + int x, int r) +{ + int n; + __u32 w; + __u64 t; + + /* start at root */ + n = bucket->num_nodes >> 1; + + while (!terminal(n)) { + int l; + /* pick point in [0, w) */ + w = bucket->node_weights[n]; + t = (__u64)crush_hash32_4(bucket->h.hash, x, n, r, + bucket->h.id) * (__u64)w; + t = t >> 32; + + /* descend to the left or right? */ + l = left(n); + if (t < bucket->node_weights[l]) + n = l; + else + n = right(n); + } + + return bucket->h.items[n >> 1]; +} + + +/* straw */ + +static int bucket_straw_choose(const struct crush_bucket_straw *bucket, + int x, int r) +{ + __u32 i; + int high = 0; + __u64 high_draw = 0; + __u64 draw; + + for (i = 0; i < bucket->h.size; i++) { + draw = crush_hash32_3(bucket->h.hash, x, bucket->h.items[i], r); + draw &= 0xffff; + draw *= bucket->straws[i]; + if (i == 0 || draw > high_draw) { + high = i; + high_draw = draw; + } + } + return bucket->h.items[high]; +} + +/* compute 2^44*log2(input+1) */ +static __u64 crush_ln(unsigned int xin) +{ + unsigned int x = xin; + int iexpon, index1, index2; + __u64 RH, LH, LL, xl64, result; + + x++; + + /* normalize input */ + iexpon = 15; + + /* + * figure out number of bits we need to shift and + * do it in one step instead of iteratively + */ + if (!(x & 0x18000)) { + int bits = __builtin_clz(x & 0x1FFFF) - 16; + x <<= bits; + iexpon = 15 - bits; + } + + index1 = (x >> 8) << 1; + /* RH ~ 2^56/index1 */ + RH = __RH_LH_tbl[index1 - 256]; + /* LH ~ 2^48 * log2(index1/256) */ + LH = __RH_LH_tbl[index1 + 1 - 256]; + + /* RH*x ~ 2^48 * (2^15 + xf), xf<2^8 */ + xl64 = (__s64)x * RH; + xl64 >>= 48; + + result = iexpon; + result <<= (12 + 32); + + index2 = xl64 & 0xff; + /* LL ~ 2^48*log2(1.0+index2/2^15) */ + LL = __LL_tbl[index2]; + + LH = LH + LL; + + LH >>= (48 - 12 - 32); + result += LH; + + return result; +} + + +/* + * straw2 + * + * for reference, see: + * + * https://en.wikipedia.org/wiki/Exponential_distribution#Distribution_of_the_minimum_of_exponential_random_variables + * + */ + +static __u32 *get_choose_arg_weights(const struct crush_bucket_straw2 *bucket, + const struct crush_choose_arg *arg, + int position) +{ + if (!arg || !arg->weight_set) + return bucket->item_weights; + + if (position >= arg->weight_set_size) + position = arg->weight_set_size - 1; + return arg->weight_set[position].weights; +} + +static __s32 *get_choose_arg_ids(const struct crush_bucket_straw2 *bucket, + const struct crush_choose_arg *arg) +{ + if (!arg || !arg->ids) + return bucket->h.items; + + return arg->ids; +} + +static int bucket_straw2_choose(const struct crush_bucket_straw2 *bucket, + int x, int r, + const struct crush_choose_arg *arg, + int position) +{ + unsigned int i, high = 0; + unsigned int u; + __s64 ln, draw, high_draw = 0; + __u32 *weights = get_choose_arg_weights(bucket, arg, position); + __s32 *ids = get_choose_arg_ids(bucket, arg); + + for (i = 0; i < bucket->h.size; i++) { + dprintk("weight 0x%x item %d\n", weights[i], ids[i]); + if (weights[i]) { + u = crush_hash32_3(bucket->h.hash, x, ids[i], r); + u &= 0xffff; + + /* + * 
for some reason slightly less than 0x10000 produces + * a slightly more accurate distribution... probably a + * rounding effect. + * + * the natural log lookup table maps [0,0xffff] + * (corresponding to real numbers [1/0x10000, 1] to + * [0, 0xffffffffffff] (corresponding to real numbers + * [-11.090355,0]). + */ + ln = crush_ln(u) - 0x1000000000000ll; + + /* + * divide by 16.16 fixed-point weight. note + * that the ln value is negative, so a larger + * weight means a larger (less negative) value + * for draw. + */ + draw = div64_s64(ln, weights[i]); + } else { + draw = S64_MIN; + } + + if (i == 0 || draw > high_draw) { + high = i; + high_draw = draw; + } + } + + return bucket->h.items[high]; +} + + +static int crush_bucket_choose(const struct crush_bucket *in, + struct crush_work_bucket *work, + int x, int r, + const struct crush_choose_arg *arg, + int position) +{ + dprintk(" crush_bucket_choose %d x=%d r=%d\n", in->id, x, r); + BUG_ON(in->size == 0); + switch (in->alg) { + case CRUSH_BUCKET_UNIFORM: + return bucket_uniform_choose( + (const struct crush_bucket_uniform *)in, + work, x, r); + case CRUSH_BUCKET_LIST: + return bucket_list_choose((const struct crush_bucket_list *)in, + x, r); + case CRUSH_BUCKET_TREE: + return bucket_tree_choose((const struct crush_bucket_tree *)in, + x, r); + case CRUSH_BUCKET_STRAW: + return bucket_straw_choose( + (const struct crush_bucket_straw *)in, + x, r); + case CRUSH_BUCKET_STRAW2: + return bucket_straw2_choose( + (const struct crush_bucket_straw2 *)in, + x, r, arg, position); + default: + dprintk("unknown bucket %d alg %d\n", in->id, in->alg); + return in->items[0]; + } +} + +/* + * true if device is marked "out" (failed, fully offloaded) + * of the cluster + */ +static int is_out(const struct crush_map *map, + const __u32 *weight, int weight_max, + int item, int x) +{ + if (item >= weight_max) + return 1; + if (weight[item] >= 0x10000) + return 0; + if (weight[item] == 0) + return 1; + if ((crush_hash32_2(CRUSH_HASH_RJENKINS1, x, item) & 0xffff) + < weight[item]) + return 0; + return 1; +} + +/** + * crush_choose_firstn - choose numrep distinct items of given type + * @map: the crush_map + * @bucket: the bucket we are choose an item from + * @x: crush input value + * @numrep: the number of items to choose + * @type: the type of item to choose + * @out: pointer to output vector + * @outpos: our position in that vector + * @out_size: size of the out vector + * @tries: number of attempts to make + * @recurse_tries: number of attempts to have recursive chooseleaf make + * @local_retries: localized retries + * @local_fallback_retries: localized fallback retries + * @recurse_to_leaf: true if we want one device under each item of given type (chooseleaf instead of choose) + * @stable: stable mode starts rep=0 in the recursive call for all replicas + * @vary_r: pass r to recursive calls + * @out2: second output vector for leaf items (if @recurse_to_leaf) + * @parent_r: r value passed from the parent + */ +static int crush_choose_firstn(const struct crush_map *map, + struct crush_work *work, + const struct crush_bucket *bucket, + const __u32 *weight, int weight_max, + int x, int numrep, int type, + int *out, int outpos, + int out_size, + unsigned int tries, + unsigned int recurse_tries, + unsigned int local_retries, + unsigned int local_fallback_retries, + int recurse_to_leaf, + unsigned int vary_r, + unsigned int stable, + int *out2, + int parent_r, + const struct crush_choose_arg *choose_args) +{ + int rep; + unsigned int ftotal, flocal; + int 
retry_descent, retry_bucket, skip_rep; + const struct crush_bucket *in = bucket; + int r; + int i; + int item = 0; + int itemtype; + int collide, reject; + int count = out_size; + + dprintk("CHOOSE%s bucket %d x %d outpos %d numrep %d tries %d recurse_tries %d local_retries %d local_fallback_retries %d parent_r %d stable %d\n", + recurse_to_leaf ? "_LEAF" : "", + bucket->id, x, outpos, numrep, + tries, recurse_tries, local_retries, local_fallback_retries, + parent_r, stable); + + for (rep = stable ? 0 : outpos; rep < numrep && count > 0 ; rep++) { + /* keep trying until we get a non-out, non-colliding item */ + ftotal = 0; + skip_rep = 0; + do { + retry_descent = 0; + in = bucket; /* initial bucket */ + + /* choose through intervening buckets */ + flocal = 0; + do { + collide = 0; + retry_bucket = 0; + r = rep + parent_r; + /* r' = r + f_total */ + r += ftotal; + + /* bucket choose */ + if (in->size == 0) { + reject = 1; + goto reject; + } + if (local_fallback_retries > 0 && + flocal >= (in->size>>1) && + flocal > local_fallback_retries) + item = bucket_perm_choose( + in, work->work[-1-in->id], + x, r); + else + item = crush_bucket_choose( + in, work->work[-1-in->id], + x, r, + (choose_args ? + &choose_args[-1-in->id] : NULL), + outpos); + if (item >= map->max_devices) { + dprintk(" bad item %d\n", item); + skip_rep = 1; + break; + } + + /* desired type? */ + if (item < 0) + itemtype = map->buckets[-1-item]->type; + else + itemtype = 0; + dprintk(" item %d type %d\n", item, itemtype); + + /* keep going? */ + if (itemtype != type) { + if (item >= 0 || + (-1-item) >= map->max_buckets) { + dprintk(" bad item type %d\n", type); + skip_rep = 1; + break; + } + in = map->buckets[-1-item]; + retry_bucket = 1; + continue; + } + + /* collision? */ + for (i = 0; i < outpos; i++) { + if (out[i] == item) { + collide = 1; + break; + } + } + + reject = 0; + if (!collide && recurse_to_leaf) { + if (item < 0) { + int sub_r; + if (vary_r) + sub_r = r >> (vary_r-1); + else + sub_r = 0; + if (crush_choose_firstn( + map, + work, + map->buckets[-1-item], + weight, weight_max, + x, stable ? 1 : outpos+1, 0, + out2, outpos, count, + recurse_tries, 0, + local_retries, + local_fallback_retries, + 0, + vary_r, + stable, + NULL, + sub_r, + choose_args) <= outpos) + /* didn't get leaf */ + reject = 1; + } else { + /* we already have a leaf! */ + out2[outpos] = item; + } + } + + if (!reject && !collide) { + /* out? 
*/ + if (itemtype == 0) + reject = is_out(map, weight, + weight_max, + item, x); + } + +reject: + if (reject || collide) { + ftotal++; + flocal++; + + if (collide && flocal <= local_retries) + /* retry locally a few times */ + retry_bucket = 1; + else if (local_fallback_retries > 0 && + flocal <= in->size + local_fallback_retries) + /* exhaustive bucket search */ + retry_bucket = 1; + else if (ftotal < tries) + /* then retry descent */ + retry_descent = 1; + else + /* else give up */ + skip_rep = 1; + dprintk(" reject %d collide %d " + "ftotal %u flocal %u\n", + reject, collide, ftotal, + flocal); + } + } while (retry_bucket); + } while (retry_descent); + + if (skip_rep) { + dprintk("skip rep\n"); + continue; + } + + dprintk("CHOOSE got %d\n", item); + out[outpos] = item; + outpos++; + count--; +#ifndef __KERNEL__ + if (map->choose_tries && ftotal <= map->choose_total_tries) + map->choose_tries[ftotal]++; +#endif + } + + dprintk("CHOOSE returns %d\n", outpos); + return outpos; +} + + +/** + * crush_choose_indep: alternative breadth-first positionally stable mapping + * + */ +static void crush_choose_indep(const struct crush_map *map, + struct crush_work *work, + const struct crush_bucket *bucket, + const __u32 *weight, int weight_max, + int x, int left, int numrep, int type, + int *out, int outpos, + unsigned int tries, + unsigned int recurse_tries, + int recurse_to_leaf, + int *out2, + int parent_r, + const struct crush_choose_arg *choose_args) +{ + const struct crush_bucket *in = bucket; + int endpos = outpos + left; + int rep; + unsigned int ftotal; + int r; + int i; + int item = 0; + int itemtype; + int collide; + + dprintk("CHOOSE%s INDEP bucket %d x %d outpos %d numrep %d\n", recurse_to_leaf ? "_LEAF" : "", + bucket->id, x, outpos, numrep); + + /* initially my result is undefined */ + for (rep = outpos; rep < endpos; rep++) { + out[rep] = CRUSH_ITEM_UNDEF; + if (out2) + out2[rep] = CRUSH_ITEM_UNDEF; + } + + for (ftotal = 0; left > 0 && ftotal < tries; ftotal++) { +#ifdef DEBUG_INDEP + if (out2 && ftotal) { + dprintk("%u %d a: ", ftotal, left); + for (rep = outpos; rep < endpos; rep++) { + dprintk(" %d", out[rep]); + } + dprintk("\n"); + dprintk("%u %d b: ", ftotal, left); + for (rep = outpos; rep < endpos; rep++) { + dprintk(" %d", out2[rep]); + } + dprintk("\n"); + } +#endif + for (rep = outpos; rep < endpos; rep++) { + if (out[rep] != CRUSH_ITEM_UNDEF) + continue; + + in = bucket; /* initial bucket */ + + /* choose through intervening buckets */ + for (;;) { + /* note: we base the choice on the position + * even in the nested call. that means that + * if the first layer chooses the same bucket + * in a different position, we will tend to + * choose a different item in that bucket. + * this will involve more devices in data + * movement and tend to distribute the load. + */ + r = rep + parent_r; + + /* be careful */ + if (in->alg == CRUSH_BUCKET_UNIFORM && + in->size % numrep == 0) + /* r'=r+(n+1)*f_total */ + r += (numrep+1) * ftotal; + else + /* r' = r + n*f_total */ + r += numrep * ftotal; + + /* bucket choose */ + if (in->size == 0) { + dprintk(" empty bucket\n"); + break; + } + + item = crush_bucket_choose( + in, work->work[-1-in->id], + x, r, + (choose_args ? + &choose_args[-1-in->id] : NULL), + outpos); + if (item >= map->max_devices) { + dprintk(" bad item %d\n", item); + out[rep] = CRUSH_ITEM_NONE; + if (out2) + out2[rep] = CRUSH_ITEM_NONE; + left--; + break; + } + + /* desired type? 
*/ + if (item < 0) + itemtype = map->buckets[-1-item]->type; + else + itemtype = 0; + dprintk(" item %d type %d\n", item, itemtype); + + /* keep going? */ + if (itemtype != type) { + if (item >= 0 || + (-1-item) >= map->max_buckets) { + dprintk(" bad item type %d\n", type); + out[rep] = CRUSH_ITEM_NONE; + if (out2) + out2[rep] = + CRUSH_ITEM_NONE; + left--; + break; + } + in = map->buckets[-1-item]; + continue; + } + + /* collision? */ + collide = 0; + for (i = outpos; i < endpos; i++) { + if (out[i] == item) { + collide = 1; + break; + } + } + if (collide) + break; + + if (recurse_to_leaf) { + if (item < 0) { + crush_choose_indep( + map, + work, + map->buckets[-1-item], + weight, weight_max, + x, 1, numrep, 0, + out2, rep, + recurse_tries, 0, + 0, NULL, r, + choose_args); + if (out2[rep] == CRUSH_ITEM_NONE) { + /* placed nothing; no leaf */ + break; + } + } else { + /* we already have a leaf! */ + out2[rep] = item; + } + } + + /* out? */ + if (itemtype == 0 && + is_out(map, weight, weight_max, item, x)) + break; + + /* yay! */ + out[rep] = item; + left--; + break; + } + } + } + for (rep = outpos; rep < endpos; rep++) { + if (out[rep] == CRUSH_ITEM_UNDEF) { + out[rep] = CRUSH_ITEM_NONE; + } + if (out2 && out2[rep] == CRUSH_ITEM_UNDEF) { + out2[rep] = CRUSH_ITEM_NONE; + } + } +#ifndef __KERNEL__ + if (map->choose_tries && ftotal <= map->choose_total_tries) + map->choose_tries[ftotal]++; +#endif +#ifdef DEBUG_INDEP + if (out2) { + dprintk("%u %d a: ", ftotal, left); + for (rep = outpos; rep < endpos; rep++) { + dprintk(" %d", out[rep]); + } + dprintk("\n"); + dprintk("%u %d b: ", ftotal, left); + for (rep = outpos; rep < endpos; rep++) { + dprintk(" %d", out2[rep]); + } + dprintk("\n"); + } +#endif +} + + +/* + * This takes a chunk of memory and sets it up to be a shiny new + * working area for a CRUSH placement computation. It must be called + * on any newly allocated memory before passing it in to + * crush_do_rule. It may be used repeatedly after that, so long as the + * map has not changed. If the map /has/ changed, you must make sure + * the working size is no smaller than what was allocated and re-run + * crush_init_workspace. + * + * If you do retain the working space between calls to crush, make it + * thread-local. + */ +void crush_init_workspace(const struct crush_map *map, void *v) +{ + struct crush_work *w = v; + __s32 b; + + /* + * We work by moving through the available space and setting + * values and pointers as we go. + * + * It's a bit like Forth's use of the 'allot' word since we + * set the pointer first and then reserve the space for it to + * point to by incrementing the point. 
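The 'allot' analogy above is the whole trick: crush_init_workspace() stores a pointer into the flat allocation and only then advances the cursor past the space that pointer owns, carving the buffer into a struct crush_work, one crush_work_bucket per bucket, and each bucket's permutation array without any further allocations. A small userspace sketch of the same bump-pointer carving, using invented structure names and sizes (not part of the patch):

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

struct demo_bucket_work {
	uint32_t perm_n;
	uint32_t *perm;		/* points into the same allocation */
};

int main(void)
{
	size_t nbuckets = 2, bucket_size = 4;
	size_t total = nbuckets * sizeof(struct demo_bucket_work) +
		       nbuckets * bucket_size * sizeof(uint32_t);
	unsigned char *base = malloc(total);
	unsigned char *v = base;	/* the moving cursor */

	if (!base)
		return 1;

	for (size_t b = 0; b < nbuckets; b++) {
		/* set the pointer first ... */
		struct demo_bucket_work *w = (struct demo_bucket_work *)v;
		/* ... then reserve its space by advancing the cursor */
		v += sizeof(*w);

		w->perm_n = 0;
		w->perm = (uint32_t *)v;
		v += bucket_size * sizeof(uint32_t);

		printf("bucket %zu work at offset %td, perm at %td\n",
		       b, (unsigned char *)w - base,
		       (unsigned char *)w->perm - base);
	}

	/* the cursor must land exactly at the end of the allocation */
	printf("used %td of %zu bytes\n", v - base, total);
	free(base);
	return 0;
}

The final check mirrors the BUG_ON(v - (void *)w != map->working_size) in the real function: after all buckets are laid out, exactly the precomputed working size must have been consumed.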
+ */ + v += sizeof(struct crush_work); + w->work = v; + v += map->max_buckets * sizeof(struct crush_work_bucket *); + for (b = 0; b < map->max_buckets; ++b) { + if (!map->buckets[b]) + continue; + + w->work[b] = v; + switch (map->buckets[b]->alg) { + default: + v += sizeof(struct crush_work_bucket); + break; + } + w->work[b]->perm_x = 0; + w->work[b]->perm_n = 0; + w->work[b]->perm = v; + v += map->buckets[b]->size * sizeof(__u32); + } + BUG_ON(v - (void *)w != map->working_size); +} + +/** + * crush_do_rule - calculate a mapping with the given input and rule + * @map: the crush_map + * @ruleno: the rule id + * @x: hash input + * @result: pointer to result vector + * @result_max: maximum result size + * @weight: weight vector (for map leaves) + * @weight_max: size of weight vector + * @cwin: pointer to at least crush_work_size() bytes of memory + * @choose_args: weights and ids for each known bucket + */ +int crush_do_rule(const struct crush_map *map, + int ruleno, int x, int *result, int result_max, + const __u32 *weight, int weight_max, + void *cwin, const struct crush_choose_arg *choose_args) +{ + int result_len; + struct crush_work *cw = cwin; + int *a = cwin + map->working_size; + int *b = a + result_max; + int *c = b + result_max; + int *w = a; + int *o = b; + int recurse_to_leaf; + int wsize = 0; + int osize; + const struct crush_rule *rule; + __u32 step; + int i, j; + int numrep; + int out_size; + /* + * the original choose_total_tries value was off by one (it + * counted "retries" and not "tries"). add one. + */ + int choose_tries = map->choose_total_tries + 1; + int choose_leaf_tries = 0; + /* + * the local tries values were counted as "retries", though, + * and need no adjustment + */ + int choose_local_retries = map->choose_local_tries; + int choose_local_fallback_retries = map->choose_local_fallback_tries; + + int vary_r = map->chooseleaf_vary_r; + int stable = map->chooseleaf_stable; + + if ((__u32)ruleno >= map->max_rules) { + dprintk(" bad ruleno %d\n", ruleno); + return 0; + } + + rule = map->rules[ruleno]; + result_len = 0; + + for (step = 0; step < rule->len; step++) { + int firstn = 0; + const struct crush_rule_step *curstep = &rule->steps[step]; + + switch (curstep->op) { + case CRUSH_RULE_TAKE: + if ((curstep->arg1 >= 0 && + curstep->arg1 < map->max_devices) || + (-1-curstep->arg1 >= 0 && + -1-curstep->arg1 < map->max_buckets && + map->buckets[-1-curstep->arg1])) { + w[0] = curstep->arg1; + wsize = 1; + } else { + dprintk(" bad take value %d\n", curstep->arg1); + } + break; + + case CRUSH_RULE_SET_CHOOSE_TRIES: + if (curstep->arg1 > 0) + choose_tries = curstep->arg1; + break; + + case CRUSH_RULE_SET_CHOOSELEAF_TRIES: + if (curstep->arg1 > 0) + choose_leaf_tries = curstep->arg1; + break; + + case CRUSH_RULE_SET_CHOOSE_LOCAL_TRIES: + if (curstep->arg1 >= 0) + choose_local_retries = curstep->arg1; + break; + + case CRUSH_RULE_SET_CHOOSE_LOCAL_FALLBACK_TRIES: + if (curstep->arg1 >= 0) + choose_local_fallback_retries = curstep->arg1; + break; + + case CRUSH_RULE_SET_CHOOSELEAF_VARY_R: + if (curstep->arg1 >= 0) + vary_r = curstep->arg1; + break; + + case CRUSH_RULE_SET_CHOOSELEAF_STABLE: + if (curstep->arg1 >= 0) + stable = curstep->arg1; + break; + + case CRUSH_RULE_CHOOSELEAF_FIRSTN: + case CRUSH_RULE_CHOOSE_FIRSTN: + firstn = 1; + fallthrough; + case CRUSH_RULE_CHOOSELEAF_INDEP: + case CRUSH_RULE_CHOOSE_INDEP: + if (wsize == 0) + break; + + recurse_to_leaf = + curstep->op == + CRUSH_RULE_CHOOSELEAF_FIRSTN || + curstep->op == + CRUSH_RULE_CHOOSELEAF_INDEP; + + /* reset 
output */ + osize = 0; + + for (i = 0; i < wsize; i++) { + int bno; + numrep = curstep->arg1; + if (numrep <= 0) { + numrep += result_max; + if (numrep <= 0) + continue; + } + j = 0; + /* make sure bucket id is valid */ + bno = -1 - w[i]; + if (bno < 0 || bno >= map->max_buckets) { + /* w[i] is probably CRUSH_ITEM_NONE */ + dprintk(" bad w[i] %d\n", w[i]); + continue; + } + if (firstn) { + int recurse_tries; + if (choose_leaf_tries) + recurse_tries = + choose_leaf_tries; + else if (map->chooseleaf_descend_once) + recurse_tries = 1; + else + recurse_tries = choose_tries; + osize += crush_choose_firstn( + map, + cw, + map->buckets[bno], + weight, weight_max, + x, numrep, + curstep->arg2, + o+osize, j, + result_max-osize, + choose_tries, + recurse_tries, + choose_local_retries, + choose_local_fallback_retries, + recurse_to_leaf, + vary_r, + stable, + c+osize, + 0, + choose_args); + } else { + out_size = ((numrep < (result_max-osize)) ? + numrep : (result_max-osize)); + crush_choose_indep( + map, + cw, + map->buckets[bno], + weight, weight_max, + x, out_size, numrep, + curstep->arg2, + o+osize, j, + choose_tries, + choose_leaf_tries ? + choose_leaf_tries : 1, + recurse_to_leaf, + c+osize, + 0, + choose_args); + osize += out_size; + } + } + + if (recurse_to_leaf) + /* copy final _leaf_ values to output set */ + memcpy(o, c, osize*sizeof(*o)); + + /* swap o and w arrays */ + swap(o, w); + wsize = osize; + break; + + + case CRUSH_RULE_EMIT: + for (i = 0; i < wsize && result_len < result_max; i++) { + result[result_len] = w[i]; + result_len++; + } + wsize = 0; + break; + + default: + dprintk(" unknown op %d at step %d\n", + curstep->op, step); + break; + } + } + + return result_len; +} diff --git a/net/ceph/crypto.c b/net/ceph/crypto.c new file mode 100644 index 000000000..051d22c0e --- /dev/null +++ b/net/ceph/crypto.c @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include "crypto.h" + +/* + * Set ->key and ->tfm. The rest of the key should be filled in before + * this function is called. 
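For reference, ceph_crypto_key_encode() and ceph_crypto_key_decode() below serialize a secret as a little-endian u16 type, the creation timestamp, a little-endian u16 length, and then the raw key bytes. The sketch below lays out such a blob in userspace; the two 32-bit timestamp fields, the type value 1 for AES, and the all-zero dummy key are assumptions made for illustration, not taken from the patch.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

static void put_le16(uint8_t *p, uint16_t v)
{
	p[0] = v & 0xff;
	p[1] = v >> 8;
}

static void put_le32(uint8_t *p, uint32_t v)
{
	p[0] = v; p[1] = v >> 8; p[2] = v >> 16; p[3] = v >> 24;
}

int main(void)
{
	uint8_t secret[16] = { 0 };		/* dummy 128-bit key (assumption) */
	uint8_t buf[2 + 8 + 2 + sizeof(secret)];
	uint8_t *p = buf;

	put_le16(p, 1);				/* type: 1 taken to mean AES here */
	p += 2;
	put_le32(p, 0);				/* created: assumed two le32 fields */
	put_le32(p + 4, 0);
	p += 8;
	put_le16(p, sizeof(secret));		/* key length */
	p += 2;
	memcpy(p, secret, sizeof(secret));	/* key payload */
	p += sizeof(secret);

	printf("encoded %td bytes\n", p - buf);
	return 0;
}

ceph_crypto_key_decode() walks the same fields in the same order before handing the key bytes to set_secret().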
+ */ +static int set_secret(struct ceph_crypto_key *key, void *buf) +{ + unsigned int noio_flag; + int ret; + + key->key = NULL; + key->tfm = NULL; + + switch (key->type) { + case CEPH_CRYPTO_NONE: + return 0; /* nothing to do */ + case CEPH_CRYPTO_AES: + break; + default: + return -ENOTSUPP; + } + + if (!key->len) + return -EINVAL; + + key->key = kmemdup(buf, key->len, GFP_NOIO); + if (!key->key) { + ret = -ENOMEM; + goto fail; + } + + /* crypto_alloc_sync_skcipher() allocates with GFP_KERNEL */ + noio_flag = memalloc_noio_save(); + key->tfm = crypto_alloc_sync_skcipher("cbc(aes)", 0, 0); + memalloc_noio_restore(noio_flag); + if (IS_ERR(key->tfm)) { + ret = PTR_ERR(key->tfm); + key->tfm = NULL; + goto fail; + } + + ret = crypto_sync_skcipher_setkey(key->tfm, key->key, key->len); + if (ret) + goto fail; + + return 0; + +fail: + ceph_crypto_key_destroy(key); + return ret; +} + +int ceph_crypto_key_clone(struct ceph_crypto_key *dst, + const struct ceph_crypto_key *src) +{ + memcpy(dst, src, sizeof(struct ceph_crypto_key)); + return set_secret(dst, src->key); +} + +int ceph_crypto_key_encode(struct ceph_crypto_key *key, void **p, void *end) +{ + if (*p + sizeof(u16) + sizeof(key->created) + + sizeof(u16) + key->len > end) + return -ERANGE; + ceph_encode_16(p, key->type); + ceph_encode_copy(p, &key->created, sizeof(key->created)); + ceph_encode_16(p, key->len); + ceph_encode_copy(p, key->key, key->len); + return 0; +} + +int ceph_crypto_key_decode(struct ceph_crypto_key *key, void **p, void *end) +{ + int ret; + + ceph_decode_need(p, end, 2*sizeof(u16) + sizeof(key->created), bad); + key->type = ceph_decode_16(p); + ceph_decode_copy(p, &key->created, sizeof(key->created)); + key->len = ceph_decode_16(p); + ceph_decode_need(p, end, key->len, bad); + ret = set_secret(key, *p); + memzero_explicit(*p, key->len); + *p += key->len; + return ret; + +bad: + dout("failed to decode crypto key\n"); + return -EINVAL; +} + +int ceph_crypto_key_unarmor(struct ceph_crypto_key *key, const char *inkey) +{ + int inlen = strlen(inkey); + int blen = inlen * 3 / 4; + void *buf, *p; + int ret; + + dout("crypto_key_unarmor %s\n", inkey); + buf = kmalloc(blen, GFP_NOFS); + if (!buf) + return -ENOMEM; + blen = ceph_unarmor(buf, inkey, inkey+inlen); + if (blen < 0) { + kfree(buf); + return blen; + } + + p = buf; + ret = ceph_crypto_key_decode(key, &p, p + blen); + kfree(buf); + if (ret) + return ret; + dout("crypto_key_unarmor key %p type %d len %d\n", key, + key->type, key->len); + return 0; +} + +void ceph_crypto_key_destroy(struct ceph_crypto_key *key) +{ + if (key) { + kfree_sensitive(key->key); + key->key = NULL; + if (key->tfm) { + crypto_free_sync_skcipher(key->tfm); + key->tfm = NULL; + } + } +} + +static const u8 *aes_iv = (u8 *)CEPH_AES_IV; + +/* + * Should be used for buffers allocated with kvmalloc(). + * Currently these are encrypt out-buffer (ceph_buffer) and decrypt + * in-buffer (msg front). + * + * Dispose of @sgt with teardown_sgtable(). + * + * @prealloc_sg is to avoid memory allocation inside sg_alloc_table() + * in cases where a single sg is sufficient. No attempt to reduce the + * number of sgs by squeezing physically contiguous pages together is + * made though, for simplicity. 
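The chunking that setup_sgtable() below performs for a vmalloc'ed buffer is easiest to see with numbers: the byte range [buf, buf + buf_len) is covered one page per scatterlist entry, so the entry count is PAGE_ALIGN(offset_in_page(buf) + buf_len) divided by PAGE_SIZE, and each entry's length is capped at what remains of its page. A standalone sketch of that arithmetic, with a made-up page offset and length (not part of the patch):

#include <stdio.h>

#define PAGE_SIZE	4096u
#define PAGE_MASK	(~(PAGE_SIZE - 1))
#define PAGE_ALIGN(x)	(((x) + PAGE_SIZE - 1) & PAGE_MASK)

int main(void)
{
	unsigned int off = 0x300;	/* offset of buf within its first page */
	unsigned int buf_len = 10000;	/* bytes to map */
	unsigned int chunk_cnt = PAGE_ALIGN(off + buf_len) / PAGE_SIZE;
	unsigned int remaining = buf_len;

	printf("%u bytes at page offset 0x%x -> %u sg entries\n",
	       buf_len, off, chunk_cnt);

	for (unsigned int i = 0; i < chunk_cnt; i++) {
		/* each entry covers at most the rest of the current page */
		unsigned int len = PAGE_SIZE - off;

		if (len > remaining)
			len = remaining;
		printf("  sg[%u]: offset 0x%x len %u\n", i, off, len);
		remaining -= len;
		off = 0;		/* later chunks start on a page boundary */
	}
	return 0;
}

With a 4 KiB page, 10000 bytes starting 0x300 into a page need three entries of 3328, 4096 and 2576 bytes, which is what the loop in setup_sgtable() produces via min(chunk_len - off, buf_len).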
+ */ +static int setup_sgtable(struct sg_table *sgt, struct scatterlist *prealloc_sg, + const void *buf, unsigned int buf_len) +{ + struct scatterlist *sg; + const bool is_vmalloc = is_vmalloc_addr(buf); + unsigned int off = offset_in_page(buf); + unsigned int chunk_cnt = 1; + unsigned int chunk_len = PAGE_ALIGN(off + buf_len); + int i; + int ret; + + if (buf_len == 0) { + memset(sgt, 0, sizeof(*sgt)); + return -EINVAL; + } + + if (is_vmalloc) { + chunk_cnt = chunk_len >> PAGE_SHIFT; + chunk_len = PAGE_SIZE; + } + + if (chunk_cnt > 1) { + ret = sg_alloc_table(sgt, chunk_cnt, GFP_NOFS); + if (ret) + return ret; + } else { + WARN_ON(chunk_cnt != 1); + sg_init_table(prealloc_sg, 1); + sgt->sgl = prealloc_sg; + sgt->nents = sgt->orig_nents = 1; + } + + for_each_sg(sgt->sgl, sg, sgt->orig_nents, i) { + struct page *page; + unsigned int len = min(chunk_len - off, buf_len); + + if (is_vmalloc) + page = vmalloc_to_page(buf); + else + page = virt_to_page(buf); + + sg_set_page(sg, page, len, off); + + off = 0; + buf += len; + buf_len -= len; + } + WARN_ON(buf_len != 0); + + return 0; +} + +static void teardown_sgtable(struct sg_table *sgt) +{ + if (sgt->orig_nents > 1) + sg_free_table(sgt); +} + +static int ceph_aes_crypt(const struct ceph_crypto_key *key, bool encrypt, + void *buf, int buf_len, int in_len, int *pout_len) +{ + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm); + struct sg_table sgt; + struct scatterlist prealloc_sg; + char iv[AES_BLOCK_SIZE] __aligned(8); + int pad_byte = AES_BLOCK_SIZE - (in_len & (AES_BLOCK_SIZE - 1)); + int crypt_len = encrypt ? in_len + pad_byte : in_len; + int ret; + + WARN_ON(crypt_len > buf_len); + if (encrypt) + memset(buf + in_len, pad_byte, pad_byte); + ret = setup_sgtable(&sgt, &prealloc_sg, buf, crypt_len); + if (ret) + return ret; + + memcpy(iv, aes_iv, AES_BLOCK_SIZE); + skcipher_request_set_sync_tfm(req, key->tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sgt.sgl, sgt.sgl, crypt_len, iv); + + /* + print_hex_dump(KERN_ERR, "key: ", DUMP_PREFIX_NONE, 16, 1, + key->key, key->len, 1); + print_hex_dump(KERN_ERR, " in: ", DUMP_PREFIX_NONE, 16, 1, + buf, crypt_len, 1); + */ + if (encrypt) + ret = crypto_skcipher_encrypt(req); + else + ret = crypto_skcipher_decrypt(req); + skcipher_request_zero(req); + if (ret) { + pr_err("%s %scrypt failed: %d\n", __func__, + encrypt ? 
"en" : "de", ret); + goto out_sgt; + } + /* + print_hex_dump(KERN_ERR, "out: ", DUMP_PREFIX_NONE, 16, 1, + buf, crypt_len, 1); + */ + + if (encrypt) { + *pout_len = crypt_len; + } else { + pad_byte = *(char *)(buf + in_len - 1); + if (pad_byte > 0 && pad_byte <= AES_BLOCK_SIZE && + in_len >= pad_byte) { + *pout_len = in_len - pad_byte; + } else { + pr_err("%s got bad padding %d on in_len %d\n", + __func__, pad_byte, in_len); + ret = -EPERM; + goto out_sgt; + } + } + +out_sgt: + teardown_sgtable(&sgt); + return ret; +} + +int ceph_crypt(const struct ceph_crypto_key *key, bool encrypt, + void *buf, int buf_len, int in_len, int *pout_len) +{ + switch (key->type) { + case CEPH_CRYPTO_NONE: + *pout_len = in_len; + return 0; + case CEPH_CRYPTO_AES: + return ceph_aes_crypt(key, encrypt, buf, buf_len, in_len, + pout_len); + default: + return -ENOTSUPP; + } +} + +static int ceph_key_preparse(struct key_preparsed_payload *prep) +{ + struct ceph_crypto_key *ckey; + size_t datalen = prep->datalen; + int ret; + void *p; + + ret = -EINVAL; + if (datalen <= 0 || datalen > 32767 || !prep->data) + goto err; + + ret = -ENOMEM; + ckey = kmalloc(sizeof(*ckey), GFP_KERNEL); + if (!ckey) + goto err; + + /* TODO ceph_crypto_key_decode should really take const input */ + p = (void *)prep->data; + ret = ceph_crypto_key_decode(ckey, &p, (char*)prep->data+datalen); + if (ret < 0) + goto err_ckey; + + prep->payload.data[0] = ckey; + prep->quotalen = datalen; + return 0; + +err_ckey: + kfree(ckey); +err: + return ret; +} + +static void ceph_key_free_preparse(struct key_preparsed_payload *prep) +{ + struct ceph_crypto_key *ckey = prep->payload.data[0]; + ceph_crypto_key_destroy(ckey); + kfree(ckey); +} + +static void ceph_key_destroy(struct key *key) +{ + struct ceph_crypto_key *ckey = key->payload.data[0]; + + ceph_crypto_key_destroy(ckey); + kfree(ckey); +} + +struct key_type key_type_ceph = { + .name = "ceph", + .preparse = ceph_key_preparse, + .free_preparse = ceph_key_free_preparse, + .instantiate = generic_key_instantiate, + .destroy = ceph_key_destroy, +}; + +int __init ceph_crypto_init(void) +{ + return register_key_type(&key_type_ceph); +} + +void ceph_crypto_shutdown(void) +{ + unregister_key_type(&key_type_ceph); +} diff --git a/net/ceph/crypto.h b/net/ceph/crypto.h new file mode 100644 index 000000000..13bd52634 --- /dev/null +++ b/net/ceph/crypto.h @@ -0,0 +1,39 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _FS_CEPH_CRYPTO_H +#define _FS_CEPH_CRYPTO_H + +#include +#include + +#define CEPH_KEY_LEN 16 +#define CEPH_MAX_CON_SECRET_LEN 64 + +/* + * cryptographic secret + */ +struct ceph_crypto_key { + int type; + struct ceph_timespec created; + int len; + void *key; + struct crypto_sync_skcipher *tfm; +}; + +int ceph_crypto_key_clone(struct ceph_crypto_key *dst, + const struct ceph_crypto_key *src); +int ceph_crypto_key_encode(struct ceph_crypto_key *key, void **p, void *end); +int ceph_crypto_key_decode(struct ceph_crypto_key *key, void **p, void *end); +int ceph_crypto_key_unarmor(struct ceph_crypto_key *key, const char *in); +void ceph_crypto_key_destroy(struct ceph_crypto_key *key); + +/* crypto.c */ +int ceph_crypt(const struct ceph_crypto_key *key, bool encrypt, + void *buf, int buf_len, int in_len, int *pout_len); +int ceph_crypto_init(void); +void ceph_crypto_shutdown(void); + +/* armor.c */ +int ceph_armor(char *dst, const char *src, const char *end); +int ceph_unarmor(char *dst, const char *src, const char *end); + +#endif diff --git a/net/ceph/debugfs.c b/net/ceph/debugfs.c new file mode 
100644 index 000000000..2110439f8 --- /dev/null +++ b/net/ceph/debugfs.c @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#ifdef CONFIG_DEBUG_FS + +/* + * Implement /sys/kernel/debug/ceph fun + * + * /sys/kernel/debug/ceph/client* - an instance of the ceph client + * .../osdmap - current osdmap + * .../monmap - current monmap + * .../osdc - active osd requests + * .../monc - mon client state + * .../client_options - libceph-only (i.e. not rbd or cephfs) options + * .../dentry_lru - dump contents of dentry lru + * .../caps - expose cap (reservation) stats + * .../bdi - symlink to ../../bdi/something + */ + +static struct dentry *ceph_debugfs_dir; + +static int monmap_show(struct seq_file *s, void *p) +{ + int i; + struct ceph_client *client = s->private; + + if (client->monc.monmap == NULL) + return 0; + + seq_printf(s, "epoch %d\n", client->monc.monmap->epoch); + for (i = 0; i < client->monc.monmap->num_mon; i++) { + struct ceph_entity_inst *inst = + &client->monc.monmap->mon_inst[i]; + + seq_printf(s, "\t%s%lld\t%s\n", + ENTITY_NAME(inst->name), + ceph_pr_addr(&inst->addr)); + } + return 0; +} + +static int osdmap_show(struct seq_file *s, void *p) +{ + int i; + struct ceph_client *client = s->private; + struct ceph_osd_client *osdc = &client->osdc; + struct ceph_osdmap *map = osdc->osdmap; + struct rb_node *n; + + if (map == NULL) + return 0; + + down_read(&osdc->lock); + seq_printf(s, "epoch %u barrier %u flags 0x%x\n", map->epoch, + osdc->epoch_barrier, map->flags); + + for (n = rb_first(&map->pg_pools); n; n = rb_next(n)) { + struct ceph_pg_pool_info *pi = + rb_entry(n, struct ceph_pg_pool_info, node); + + seq_printf(s, "pool %lld '%s' type %d size %d min_size %d pg_num %u pg_num_mask %d flags 0x%llx lfor %u read_tier %lld write_tier %lld\n", + pi->id, pi->name, pi->type, pi->size, pi->min_size, + pi->pg_num, pi->pg_num_mask, pi->flags, + pi->last_force_request_resend, pi->read_tier, + pi->write_tier); + } + for (i = 0; i < map->max_osd; i++) { + struct ceph_entity_addr *addr = &map->osd_addr[i]; + u32 state = map->osd_state[i]; + char sb[64]; + + seq_printf(s, "osd%d\t%s\t%3d%%\t(%s)\t%3d%%\t%2d\n", + i, ceph_pr_addr(addr), + ((map->osd_weight[i]*100) >> 16), + ceph_osdmap_state_str(sb, sizeof(sb), state), + ((ceph_get_primary_affinity(map, i)*100) >> 16), + ceph_get_crush_locality(map, i, + &client->options->crush_locs)); + } + for (n = rb_first(&map->pg_temp); n; n = rb_next(n)) { + struct ceph_pg_mapping *pg = + rb_entry(n, struct ceph_pg_mapping, node); + + seq_printf(s, "pg_temp %llu.%x [", pg->pgid.pool, + pg->pgid.seed); + for (i = 0; i < pg->pg_temp.len; i++) + seq_printf(s, "%s%d", (i == 0 ? "" : ","), + pg->pg_temp.osds[i]); + seq_printf(s, "]\n"); + } + for (n = rb_first(&map->primary_temp); n; n = rb_next(n)) { + struct ceph_pg_mapping *pg = + rb_entry(n, struct ceph_pg_mapping, node); + + seq_printf(s, "primary_temp %llu.%x %d\n", pg->pgid.pool, + pg->pgid.seed, pg->primary_temp.osd); + } + for (n = rb_first(&map->pg_upmap); n; n = rb_next(n)) { + struct ceph_pg_mapping *pg = + rb_entry(n, struct ceph_pg_mapping, node); + + seq_printf(s, "pg_upmap %llu.%x [", pg->pgid.pool, + pg->pgid.seed); + for (i = 0; i < pg->pg_upmap.len; i++) + seq_printf(s, "%s%d", (i == 0 ? 
"" : ","), + pg->pg_upmap.osds[i]); + seq_printf(s, "]\n"); + } + for (n = rb_first(&map->pg_upmap_items); n; n = rb_next(n)) { + struct ceph_pg_mapping *pg = + rb_entry(n, struct ceph_pg_mapping, node); + + seq_printf(s, "pg_upmap_items %llu.%x [", pg->pgid.pool, + pg->pgid.seed); + for (i = 0; i < pg->pg_upmap_items.len; i++) + seq_printf(s, "%s%d->%d", (i == 0 ? "" : ","), + pg->pg_upmap_items.from_to[i][0], + pg->pg_upmap_items.from_to[i][1]); + seq_printf(s, "]\n"); + } + + up_read(&osdc->lock); + return 0; +} + +static int monc_show(struct seq_file *s, void *p) +{ + struct ceph_client *client = s->private; + struct ceph_mon_generic_request *req; + struct ceph_mon_client *monc = &client->monc; + struct rb_node *rp; + int i; + + mutex_lock(&monc->mutex); + + for (i = 0; i < ARRAY_SIZE(monc->subs); i++) { + seq_printf(s, "have %s %u", ceph_sub_str[i], + monc->subs[i].have); + if (monc->subs[i].want) + seq_printf(s, " want %llu%s", + le64_to_cpu(monc->subs[i].item.start), + (monc->subs[i].item.flags & + CEPH_SUBSCRIBE_ONETIME ? "" : "+")); + seq_putc(s, '\n'); + } + seq_printf(s, "fs_cluster_id %d\n", monc->fs_cluster_id); + + for (rp = rb_first(&monc->generic_request_tree); rp; rp = rb_next(rp)) { + __u16 op; + req = rb_entry(rp, struct ceph_mon_generic_request, node); + op = le16_to_cpu(req->request->hdr.type); + if (op == CEPH_MSG_STATFS) + seq_printf(s, "%llu statfs\n", req->tid); + else if (op == CEPH_MSG_MON_GET_VERSION) + seq_printf(s, "%llu mon_get_version", req->tid); + else + seq_printf(s, "%llu unknown\n", req->tid); + } + + mutex_unlock(&monc->mutex); + return 0; +} + +static void dump_spgid(struct seq_file *s, const struct ceph_spg *spgid) +{ + seq_printf(s, "%llu.%x", spgid->pgid.pool, spgid->pgid.seed); + if (spgid->shard != CEPH_SPG_NOSHARD) + seq_printf(s, "s%d", spgid->shard); +} + +static void dump_target(struct seq_file *s, struct ceph_osd_request_target *t) +{ + int i; + + seq_printf(s, "osd%d\t%llu.%x\t", t->osd, t->pgid.pool, t->pgid.seed); + dump_spgid(s, &t->spgid); + seq_puts(s, "\t["); + for (i = 0; i < t->up.size; i++) + seq_printf(s, "%s%d", (!i ? "" : ","), t->up.osds[i]); + seq_printf(s, "]/%d\t[", t->up.primary); + for (i = 0; i < t->acting.size; i++) + seq_printf(s, "%s%d", (!i ? "" : ","), t->acting.osds[i]); + seq_printf(s, "]/%d\te%u\t", t->acting.primary, t->epoch); + if (t->target_oloc.pool_ns) { + seq_printf(s, "%*pE/%*pE\t0x%x", + (int)t->target_oloc.pool_ns->len, + t->target_oloc.pool_ns->str, + t->target_oid.name_len, t->target_oid.name, t->flags); + } else { + seq_printf(s, "%*pE\t0x%x", t->target_oid.name_len, + t->target_oid.name, t->flags); + } + if (t->paused) + seq_puts(s, "\tP"); +} + +static void dump_request(struct seq_file *s, struct ceph_osd_request *req) +{ + int i; + + seq_printf(s, "%llu\t", req->r_tid); + dump_target(s, &req->r_t); + + seq_printf(s, "\t%d", req->r_attempts); + + for (i = 0; i < req->r_num_ops; i++) { + struct ceph_osd_req_op *op = &req->r_ops[i]; + + seq_printf(s, "%s%s", (i == 0 ? 
"\t" : ","), + ceph_osd_op_name(op->op)); + if (op->op == CEPH_OSD_OP_WATCH) + seq_printf(s, "-%s", + ceph_osd_watch_op_name(op->watch.op)); + else if (op->op == CEPH_OSD_OP_CALL) + seq_printf(s, "-%s/%s", op->cls.class_name, + op->cls.method_name); + } + + seq_putc(s, '\n'); +} + +static void dump_requests(struct seq_file *s, struct ceph_osd *osd) +{ + struct rb_node *n; + + mutex_lock(&osd->lock); + for (n = rb_first(&osd->o_requests); n; n = rb_next(n)) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + + dump_request(s, req); + } + + mutex_unlock(&osd->lock); +} + +static void dump_linger_request(struct seq_file *s, + struct ceph_osd_linger_request *lreq) +{ + seq_printf(s, "%llu\t", lreq->linger_id); + dump_target(s, &lreq->t); + + seq_printf(s, "\t%u\t%s%s/%d\n", lreq->register_gen, + lreq->is_watch ? "W" : "N", lreq->committed ? "C" : "", + lreq->last_error); +} + +static void dump_linger_requests(struct seq_file *s, struct ceph_osd *osd) +{ + struct rb_node *n; + + mutex_lock(&osd->lock); + for (n = rb_first(&osd->o_linger_requests); n; n = rb_next(n)) { + struct ceph_osd_linger_request *lreq = + rb_entry(n, struct ceph_osd_linger_request, node); + + dump_linger_request(s, lreq); + } + + mutex_unlock(&osd->lock); +} + +static void dump_snapid(struct seq_file *s, u64 snapid) +{ + if (snapid == CEPH_NOSNAP) + seq_puts(s, "head"); + else if (snapid == CEPH_SNAPDIR) + seq_puts(s, "snapdir"); + else + seq_printf(s, "%llx", snapid); +} + +static void dump_name_escaped(struct seq_file *s, unsigned char *name, + size_t len) +{ + size_t i; + + for (i = 0; i < len; i++) { + if (name[i] == '%' || name[i] == ':' || name[i] == '/' || + name[i] < 32 || name[i] >= 127) { + seq_printf(s, "%%%02x", name[i]); + } else { + seq_putc(s, name[i]); + } + } +} + +static void dump_hoid(struct seq_file *s, const struct ceph_hobject_id *hoid) +{ + if (hoid->snapid == 0 && hoid->hash == 0 && !hoid->is_max && + hoid->pool == S64_MIN) { + seq_puts(s, "MIN"); + return; + } + if (hoid->is_max) { + seq_puts(s, "MAX"); + return; + } + seq_printf(s, "%lld:%08x:", hoid->pool, hoid->hash_reverse_bits); + dump_name_escaped(s, hoid->nspace, hoid->nspace_len); + seq_putc(s, ':'); + dump_name_escaped(s, hoid->key, hoid->key_len); + seq_putc(s, ':'); + dump_name_escaped(s, hoid->oid, hoid->oid_len); + seq_putc(s, ':'); + dump_snapid(s, hoid->snapid); +} + +static void dump_backoffs(struct seq_file *s, struct ceph_osd *osd) +{ + struct rb_node *n; + + mutex_lock(&osd->lock); + for (n = rb_first(&osd->o_backoffs_by_id); n; n = rb_next(n)) { + struct ceph_osd_backoff *backoff = + rb_entry(n, struct ceph_osd_backoff, id_node); + + seq_printf(s, "osd%d\t", osd->o_osd); + dump_spgid(s, &backoff->spgid); + seq_printf(s, "\t%llu\t", backoff->id); + dump_hoid(s, backoff->begin); + seq_putc(s, '\t'); + dump_hoid(s, backoff->end); + seq_putc(s, '\n'); + } + + mutex_unlock(&osd->lock); +} + +static int osdc_show(struct seq_file *s, void *pp) +{ + struct ceph_client *client = s->private; + struct ceph_osd_client *osdc = &client->osdc; + struct rb_node *n; + + down_read(&osdc->lock); + seq_printf(s, "REQUESTS %d homeless %d\n", + atomic_read(&osdc->num_requests), + atomic_read(&osdc->num_homeless)); + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + dump_requests(s, osd); + } + dump_requests(s, &osdc->homeless_osd); + + seq_puts(s, "LINGER REQUESTS\n"); + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = 
rb_entry(n, struct ceph_osd, o_node); + + dump_linger_requests(s, osd); + } + dump_linger_requests(s, &osdc->homeless_osd); + + seq_puts(s, "BACKOFFS\n"); + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + dump_backoffs(s, osd); + } + + up_read(&osdc->lock); + return 0; +} + +static int client_options_show(struct seq_file *s, void *p) +{ + struct ceph_client *client = s->private; + int ret; + + ret = ceph_print_client_options(s, client, true); + if (ret) + return ret; + + seq_putc(s, '\n'); + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(monmap); +DEFINE_SHOW_ATTRIBUTE(osdmap); +DEFINE_SHOW_ATTRIBUTE(monc); +DEFINE_SHOW_ATTRIBUTE(osdc); +DEFINE_SHOW_ATTRIBUTE(client_options); + +void __init ceph_debugfs_init(void) +{ + ceph_debugfs_dir = debugfs_create_dir("ceph", NULL); +} + +void ceph_debugfs_cleanup(void) +{ + debugfs_remove(ceph_debugfs_dir); +} + +void ceph_debugfs_client_init(struct ceph_client *client) +{ + char name[80]; + + snprintf(name, sizeof(name), "%pU.client%lld", &client->fsid, + client->monc.auth->global_id); + + dout("ceph_debugfs_client_init %p %s\n", client, name); + + client->debugfs_dir = debugfs_create_dir(name, ceph_debugfs_dir); + + client->monc.debugfs_file = debugfs_create_file("monc", + 0400, + client->debugfs_dir, + client, + &monc_fops); + + client->osdc.debugfs_file = debugfs_create_file("osdc", + 0400, + client->debugfs_dir, + client, + &osdc_fops); + + client->debugfs_monmap = debugfs_create_file("monmap", + 0400, + client->debugfs_dir, + client, + &monmap_fops); + + client->debugfs_osdmap = debugfs_create_file("osdmap", + 0400, + client->debugfs_dir, + client, + &osdmap_fops); + + client->debugfs_options = debugfs_create_file("client_options", + 0400, + client->debugfs_dir, + client, + &client_options_fops); +} + +void ceph_debugfs_client_cleanup(struct ceph_client *client) +{ + dout("ceph_debugfs_client_cleanup %p\n", client); + debugfs_remove(client->debugfs_options); + debugfs_remove(client->debugfs_osdmap); + debugfs_remove(client->debugfs_monmap); + debugfs_remove(client->osdc.debugfs_file); + debugfs_remove(client->monc.debugfs_file); + debugfs_remove(client->debugfs_dir); +} + +#else /* CONFIG_DEBUG_FS */ + +void __init ceph_debugfs_init(void) +{ +} + +void ceph_debugfs_cleanup(void) +{ +} + +void ceph_debugfs_client_init(struct ceph_client *client) +{ +} + +void ceph_debugfs_client_cleanup(struct ceph_client *client) +{ +} + +#endif /* CONFIG_DEBUG_FS */ diff --git a/net/ceph/decode.c b/net/ceph/decode.c new file mode 100644 index 000000000..bc109a1a4 --- /dev/null +++ b/net/ceph/decode.c @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include + +#include +#include /* for ceph_pr_addr() */ + +static int +ceph_decode_entity_addr_versioned(void **p, void *end, + struct ceph_entity_addr *addr) +{ + int ret; + u8 struct_v; + u32 struct_len, addr_len; + void *struct_end; + + ret = ceph_start_decoding(p, end, 1, "entity_addr_t", &struct_v, + &struct_len); + if (ret) + goto bad; + + ret = -EINVAL; + struct_end = *p + struct_len; + + ceph_decode_copy_safe(p, end, &addr->type, sizeof(addr->type), bad); + + ceph_decode_copy_safe(p, end, &addr->nonce, sizeof(addr->nonce), bad); + + ceph_decode_32_safe(p, end, addr_len, bad); + if (addr_len > sizeof(addr->in_addr)) + goto bad; + + memset(&addr->in_addr, 0, sizeof(addr->in_addr)); + if (addr_len) { + ceph_decode_copy_safe(p, end, &addr->in_addr, addr_len, bad); + + addr->in_addr.ss_family = + le16_to_cpu((__force 
__le16)addr->in_addr.ss_family); + } + + /* Advance past anything the client doesn't yet understand */ + *p = struct_end; + ret = 0; +bad: + return ret; +} + +static int +ceph_decode_entity_addr_legacy(void **p, void *end, + struct ceph_entity_addr *addr) +{ + int ret = -EINVAL; + + /* Skip rest of type field */ + ceph_decode_skip_n(p, end, 3, bad); + + /* + * Clients that don't support ADDR2 always send TYPE_NONE, change it + * to TYPE_LEGACY for forward compatibility. + */ + addr->type = CEPH_ENTITY_ADDR_TYPE_LEGACY; + ceph_decode_copy_safe(p, end, &addr->nonce, sizeof(addr->nonce), bad); + memset(&addr->in_addr, 0, sizeof(addr->in_addr)); + ceph_decode_copy_safe(p, end, &addr->in_addr, + sizeof(addr->in_addr), bad); + addr->in_addr.ss_family = + be16_to_cpu((__force __be16)addr->in_addr.ss_family); + ret = 0; +bad: + return ret; +} + +int +ceph_decode_entity_addr(void **p, void *end, struct ceph_entity_addr *addr) +{ + u8 marker; + + ceph_decode_8_safe(p, end, marker, bad); + if (marker == 1) + return ceph_decode_entity_addr_versioned(p, end, addr); + else if (marker == 0) + return ceph_decode_entity_addr_legacy(p, end, addr); +bad: + return -EINVAL; +} +EXPORT_SYMBOL(ceph_decode_entity_addr); + +/* + * Return addr of desired type (MSGR2 or LEGACY) or error. + * Make sure there is only one match. + * + * Assume encoding with MSG_ADDR2. + */ +int ceph_decode_entity_addrvec(void **p, void *end, bool msgr2, + struct ceph_entity_addr *addr) +{ + __le32 my_type = msgr2 ? CEPH_ENTITY_ADDR_TYPE_MSGR2 : + CEPH_ENTITY_ADDR_TYPE_LEGACY; + struct ceph_entity_addr tmp_addr; + int addr_cnt; + bool found; + u8 marker; + int ret; + int i; + + ceph_decode_8_safe(p, end, marker, e_inval); + if (marker != 2) { + pr_err("bad addrvec marker %d\n", marker); + return -EINVAL; + } + + ceph_decode_32_safe(p, end, addr_cnt, e_inval); + dout("%s addr_cnt %d\n", __func__, addr_cnt); + + found = false; + for (i = 0; i < addr_cnt; i++) { + ret = ceph_decode_entity_addr(p, end, &tmp_addr); + if (ret) + return ret; + + dout("%s i %d addr %s\n", __func__, i, ceph_pr_addr(&tmp_addr)); + if (tmp_addr.type == my_type) { + if (found) { + pr_err("another match of type %d in addrvec\n", + le32_to_cpu(my_type)); + return -EINVAL; + } + + memcpy(addr, &tmp_addr, sizeof(*addr)); + found = true; + } + } + + if (found) + return 0; + + if (!addr_cnt) + return 0; /* normal -- e.g. 
unused OSD id/slot */ + + if (addr_cnt == 1 && !memchr_inv(&tmp_addr, 0, sizeof(tmp_addr))) + return 0; /* weird but effectively the same as !addr_cnt */ + + pr_err("no match of type %d in addrvec\n", le32_to_cpu(my_type)); + return -ENOENT; + +e_inval: + return -EINVAL; +} +EXPORT_SYMBOL(ceph_decode_entity_addrvec); + +static int get_sockaddr_encoding_len(sa_family_t family) +{ + union { + struct sockaddr sa; + struct sockaddr_in sin; + struct sockaddr_in6 sin6; + } u; + + switch (family) { + case AF_INET: + return sizeof(u.sin); + case AF_INET6: + return sizeof(u.sin6); + default: + return sizeof(u); + } +} + +int ceph_entity_addr_encoding_len(const struct ceph_entity_addr *addr) +{ + sa_family_t family = get_unaligned(&addr->in_addr.ss_family); + int addr_len = get_sockaddr_encoding_len(family); + + return 1 + CEPH_ENCODING_START_BLK_LEN + 4 + 4 + 4 + addr_len; +} + +void ceph_encode_entity_addr(void **p, const struct ceph_entity_addr *addr) +{ + sa_family_t family = get_unaligned(&addr->in_addr.ss_family); + int addr_len = get_sockaddr_encoding_len(family); + + ceph_encode_8(p, 1); /* marker */ + ceph_start_encoding(p, 1, 1, sizeof(addr->type) + + sizeof(addr->nonce) + + sizeof(u32) + addr_len); + ceph_encode_copy(p, &addr->type, sizeof(addr->type)); + ceph_encode_copy(p, &addr->nonce, sizeof(addr->nonce)); + + ceph_encode_32(p, addr_len); + ceph_encode_16(p, family); + ceph_encode_copy(p, addr->in_addr.__data, addr_len - sizeof(family)); +} diff --git a/net/ceph/messenger.c b/net/ceph/messenger.c new file mode 100644 index 000000000..b9b64a242 --- /dev/null +++ b/net/ceph/messenger.c @@ -0,0 +1,2135 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_BLOCK +#include +#endif /* CONFIG_BLOCK */ +#include +#include + +#include +#include +#include +#include +#include +#include + +/* + * Ceph uses the messenger to exchange ceph_msg messages with other + * hosts in the system. The messenger provides ordered and reliable + * delivery. We tolerate TCP disconnects by reconnecting (with + * exponential backoff) in the case of a fault (disconnection, bad + * crc, protocol error). Acks allow sent messages to be discarded by + * the sender. + */ + +/* + * We track the state of the socket on a given connection using + * values defined below. The transition to a new socket state is + * handled by a function which verifies we aren't coming from an + * unexpected state. + * + * -------- + * | NEW* | transient initial state + * -------- + * | con_sock_state_init() + * v + * ---------- + * | CLOSED | initialized, but no socket (and no + * ---------- TCP connection) + * ^ \ + * | \ con_sock_state_connecting() + * | ---------------------- + * | \ + * + con_sock_state_closed() \ + * |+--------------------------- \ + * | \ \ \ + * | ----------- \ \ + * | | CLOSING | socket event; \ \ + * | ----------- await close \ \ + * | ^ \ | + * | | \ | + * | + con_sock_state_closing() \ | + * | / \ | | + * | / --------------- | | + * | / \ v v + * | / -------------- + * | / -----------------| CONNECTING | socket created, TCP + * | | / -------------- connect initiated + * | | | con_sock_state_connected() + * | | v + * ------------- + * | CONNECTED | TCP connection established + * ------------- + * + * State values for ceph_connection->sock_state; NEW is assumed to be 0. 
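The diagram above boils down to a small set of legal transitions, which the CON_SOCK_STATE_* values and the con_sock_state_*() helpers just below enforce with WARN_ONs. A compact userspace rendering of those checks as a lookup table, plus the normal connect/teardown path; the enum names and the table are a reading of this file, not part of the patch.

#include <stdbool.h>
#include <stdio.h>

enum sock_state { S_NEW, S_CLOSED, S_CONNECTING, S_CONNECTED, S_CLOSING };

static const char *name[] = {
	"NEW", "CLOSED", "CONNECTING", "CONNECTED", "CLOSING"
};

/* allowed[old][new], built from the WARN_ON conditions below */
static const bool allowed[5][5] = {
	[S_NEW]        = { [S_CLOSED] = true },
	[S_CLOSED]     = { [S_CONNECTING] = true, [S_CLOSED] = true },
	[S_CONNECTING] = { [S_CONNECTED] = true, [S_CLOSING] = true,
			   [S_CLOSED] = true },
	[S_CONNECTED]  = { [S_CLOSING] = true, [S_CLOSED] = true },
	[S_CLOSING]    = { [S_CLOSED] = true, [S_CLOSING] = true },
};

int main(void)
{
	/* a normal connect/teardown sequence */
	enum sock_state path[] = { S_NEW, S_CLOSED, S_CONNECTING,
				   S_CONNECTED, S_CLOSING, S_CLOSED };
	int i;

	for (i = 1; i < 6; i++)
		printf("%s -> %s : %s\n", name[path[i - 1]], name[path[i]],
		       allowed[path[i - 1]][path[i]] ? "ok" : "unexpected");
	return 0;
}

The extra CONNECTING -> CLOSED and CLOSED -> CLOSED entries come from con_sock_state_closed(), which tolerates a close arriving before the connect has completed.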
+ */ + +#define CON_SOCK_STATE_NEW 0 /* -> CLOSED */ +#define CON_SOCK_STATE_CLOSED 1 /* -> CONNECTING */ +#define CON_SOCK_STATE_CONNECTING 2 /* -> CONNECTED or -> CLOSING */ +#define CON_SOCK_STATE_CONNECTED 3 /* -> CLOSING or -> CLOSED */ +#define CON_SOCK_STATE_CLOSING 4 /* -> CLOSED */ + +static bool con_flag_valid(unsigned long con_flag) +{ + switch (con_flag) { + case CEPH_CON_F_LOSSYTX: + case CEPH_CON_F_KEEPALIVE_PENDING: + case CEPH_CON_F_WRITE_PENDING: + case CEPH_CON_F_SOCK_CLOSED: + case CEPH_CON_F_BACKOFF: + return true; + default: + return false; + } +} + +void ceph_con_flag_clear(struct ceph_connection *con, unsigned long con_flag) +{ + BUG_ON(!con_flag_valid(con_flag)); + + clear_bit(con_flag, &con->flags); +} + +void ceph_con_flag_set(struct ceph_connection *con, unsigned long con_flag) +{ + BUG_ON(!con_flag_valid(con_flag)); + + set_bit(con_flag, &con->flags); +} + +bool ceph_con_flag_test(struct ceph_connection *con, unsigned long con_flag) +{ + BUG_ON(!con_flag_valid(con_flag)); + + return test_bit(con_flag, &con->flags); +} + +bool ceph_con_flag_test_and_clear(struct ceph_connection *con, + unsigned long con_flag) +{ + BUG_ON(!con_flag_valid(con_flag)); + + return test_and_clear_bit(con_flag, &con->flags); +} + +bool ceph_con_flag_test_and_set(struct ceph_connection *con, + unsigned long con_flag) +{ + BUG_ON(!con_flag_valid(con_flag)); + + return test_and_set_bit(con_flag, &con->flags); +} + +/* Slab caches for frequently-allocated structures */ + +static struct kmem_cache *ceph_msg_cache; + +#ifdef CONFIG_LOCKDEP +static struct lock_class_key socket_class; +#endif + +static void queue_con(struct ceph_connection *con); +static void cancel_con(struct ceph_connection *con); +static void ceph_con_workfn(struct work_struct *); +static void con_fault(struct ceph_connection *con); + +/* + * Nicely render a sockaddr as a string. An array of formatted + * strings is used, to approximate reentrancy. + */ +#define ADDR_STR_COUNT_LOG 5 /* log2(# address strings in array) */ +#define ADDR_STR_COUNT (1 << ADDR_STR_COUNT_LOG) +#define ADDR_STR_COUNT_MASK (ADDR_STR_COUNT - 1) +#define MAX_ADDR_STR_LEN 64 /* 54 is enough */ + +static char addr_str[ADDR_STR_COUNT][MAX_ADDR_STR_LEN]; +static atomic_t addr_str_seq = ATOMIC_INIT(0); + +struct page *ceph_zero_page; /* used in certain error cases */ + +const char *ceph_pr_addr(const struct ceph_entity_addr *addr) +{ + int i; + char *s; + struct sockaddr_storage ss = addr->in_addr; /* align */ + struct sockaddr_in *in4 = (struct sockaddr_in *)&ss; + struct sockaddr_in6 *in6 = (struct sockaddr_in6 *)&ss; + + i = atomic_inc_return(&addr_str_seq) & ADDR_STR_COUNT_MASK; + s = addr_str[i]; + + switch (ss.ss_family) { + case AF_INET: + snprintf(s, MAX_ADDR_STR_LEN, "(%d)%pI4:%hu", + le32_to_cpu(addr->type), &in4->sin_addr, + ntohs(in4->sin_port)); + break; + + case AF_INET6: + snprintf(s, MAX_ADDR_STR_LEN, "(%d)[%pI6c]:%hu", + le32_to_cpu(addr->type), &in6->sin6_addr, + ntohs(in6->sin6_port)); + break; + + default: + snprintf(s, MAX_ADDR_STR_LEN, "(unknown sockaddr family %hu)", + ss.ss_family); + } + + return s; +} +EXPORT_SYMBOL(ceph_pr_addr); + +void ceph_encode_my_addr(struct ceph_messenger *msgr) +{ + if (!ceph_msgr2(from_msgr(msgr))) { + memcpy(&msgr->my_enc_addr, &msgr->inst.addr, + sizeof(msgr->my_enc_addr)); + ceph_encode_banner_addr(&msgr->my_enc_addr); + } +} + +/* + * work queue for all reading and writing to/from the socket. 
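ceph_pr_addr() above avoids allocating per call by formatting into one of ADDR_STR_COUNT static slots, selected with an atomic counter masked to a power of two, so a handful of recently returned strings stay usable concurrently ("approximate reentrancy"). A userspace sketch of the same rotating-slot trick; the slot count and the formatted contents are illustrative only, not part of the patch.

#include <stdatomic.h>
#include <stdio.h>

#define SLOT_COUNT_LOG	5
#define SLOT_COUNT	(1 << SLOT_COUNT_LOG)
#define SLOT_MASK	(SLOT_COUNT - 1)
#define SLOT_LEN	64

static char slots[SLOT_COUNT][SLOT_LEN];
static atomic_uint seq;

static const char *fmt_value(unsigned int v)
{
	/* each call claims its own slot; an older string is only
	 * overwritten after SLOT_COUNT further calls */
	unsigned int i = atomic_fetch_add(&seq, 1) & SLOT_MASK;

	snprintf(slots[i], SLOT_LEN, "value-%u", v);
	return slots[i];
}

int main(void)
{
	const char *a = fmt_value(1);
	const char *b = fmt_value(2);

	/* both strings remain valid because they live in different slots */
	printf("%s %s\n", a, b);
	return 0;
}

A string handed out SLOT_COUNT calls ago is eventually overwritten, which is why this is only approximate reentrancy, adequate for log and debug output.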
+ */ +static struct workqueue_struct *ceph_msgr_wq; + +static int ceph_msgr_slab_init(void) +{ + BUG_ON(ceph_msg_cache); + ceph_msg_cache = KMEM_CACHE(ceph_msg, 0); + if (!ceph_msg_cache) + return -ENOMEM; + + return 0; +} + +static void ceph_msgr_slab_exit(void) +{ + BUG_ON(!ceph_msg_cache); + kmem_cache_destroy(ceph_msg_cache); + ceph_msg_cache = NULL; +} + +static void _ceph_msgr_exit(void) +{ + if (ceph_msgr_wq) { + destroy_workqueue(ceph_msgr_wq); + ceph_msgr_wq = NULL; + } + + BUG_ON(!ceph_zero_page); + put_page(ceph_zero_page); + ceph_zero_page = NULL; + + ceph_msgr_slab_exit(); +} + +int __init ceph_msgr_init(void) +{ + if (ceph_msgr_slab_init()) + return -ENOMEM; + + BUG_ON(ceph_zero_page); + ceph_zero_page = ZERO_PAGE(0); + get_page(ceph_zero_page); + + /* + * The number of active work items is limited by the number of + * connections, so leave @max_active at default. + */ + ceph_msgr_wq = alloc_workqueue("ceph-msgr", WQ_MEM_RECLAIM, 0); + if (ceph_msgr_wq) + return 0; + + pr_err("msgr_init failed to create workqueue\n"); + _ceph_msgr_exit(); + + return -ENOMEM; +} + +void ceph_msgr_exit(void) +{ + BUG_ON(ceph_msgr_wq == NULL); + + _ceph_msgr_exit(); +} + +void ceph_msgr_flush(void) +{ + flush_workqueue(ceph_msgr_wq); +} +EXPORT_SYMBOL(ceph_msgr_flush); + +/* Connection socket state transition functions */ + +static void con_sock_state_init(struct ceph_connection *con) +{ + int old_state; + + old_state = atomic_xchg(&con->sock_state, CON_SOCK_STATE_CLOSED); + if (WARN_ON(old_state != CON_SOCK_STATE_NEW)) + printk("%s: unexpected old state %d\n", __func__, old_state); + dout("%s con %p sock %d -> %d\n", __func__, con, old_state, + CON_SOCK_STATE_CLOSED); +} + +static void con_sock_state_connecting(struct ceph_connection *con) +{ + int old_state; + + old_state = atomic_xchg(&con->sock_state, CON_SOCK_STATE_CONNECTING); + if (WARN_ON(old_state != CON_SOCK_STATE_CLOSED)) + printk("%s: unexpected old state %d\n", __func__, old_state); + dout("%s con %p sock %d -> %d\n", __func__, con, old_state, + CON_SOCK_STATE_CONNECTING); +} + +static void con_sock_state_connected(struct ceph_connection *con) +{ + int old_state; + + old_state = atomic_xchg(&con->sock_state, CON_SOCK_STATE_CONNECTED); + if (WARN_ON(old_state != CON_SOCK_STATE_CONNECTING)) + printk("%s: unexpected old state %d\n", __func__, old_state); + dout("%s con %p sock %d -> %d\n", __func__, con, old_state, + CON_SOCK_STATE_CONNECTED); +} + +static void con_sock_state_closing(struct ceph_connection *con) +{ + int old_state; + + old_state = atomic_xchg(&con->sock_state, CON_SOCK_STATE_CLOSING); + if (WARN_ON(old_state != CON_SOCK_STATE_CONNECTING && + old_state != CON_SOCK_STATE_CONNECTED && + old_state != CON_SOCK_STATE_CLOSING)) + printk("%s: unexpected old state %d\n", __func__, old_state); + dout("%s con %p sock %d -> %d\n", __func__, con, old_state, + CON_SOCK_STATE_CLOSING); +} + +static void con_sock_state_closed(struct ceph_connection *con) +{ + int old_state; + + old_state = atomic_xchg(&con->sock_state, CON_SOCK_STATE_CLOSED); + if (WARN_ON(old_state != CON_SOCK_STATE_CONNECTED && + old_state != CON_SOCK_STATE_CLOSING && + old_state != CON_SOCK_STATE_CONNECTING && + old_state != CON_SOCK_STATE_CLOSED)) + printk("%s: unexpected old state %d\n", __func__, old_state); + dout("%s con %p sock %d -> %d\n", __func__, con, old_state, + CON_SOCK_STATE_CLOSED); +} + +/* + * socket callback functions + */ + +/* data available on socket, or listen socket received a connect */ +static void ceph_sock_data_ready(struct sock *sk) +{ + 
struct ceph_connection *con = sk->sk_user_data; + if (atomic_read(&con->msgr->stopping)) { + return; + } + + if (sk->sk_state != TCP_CLOSE_WAIT) { + dout("%s %p state = %d, queueing work\n", __func__, + con, con->state); + queue_con(con); + } +} + +/* socket has buffer space for writing */ +static void ceph_sock_write_space(struct sock *sk) +{ + struct ceph_connection *con = sk->sk_user_data; + + /* only queue to workqueue if there is data we want to write, + * and there is sufficient space in the socket buffer to accept + * more data. clear SOCK_NOSPACE so that ceph_sock_write_space() + * doesn't get called again until try_write() fills the socket + * buffer. See net/ipv4/tcp_input.c:tcp_check_space() + * and net/core/stream.c:sk_stream_write_space(). + */ + if (ceph_con_flag_test(con, CEPH_CON_F_WRITE_PENDING)) { + if (sk_stream_is_writeable(sk)) { + dout("%s %p queueing write work\n", __func__, con); + clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + queue_con(con); + } + } else { + dout("%s %p nothing to write\n", __func__, con); + } +} + +/* socket's state has changed */ +static void ceph_sock_state_change(struct sock *sk) +{ + struct ceph_connection *con = sk->sk_user_data; + + dout("%s %p state = %d sk_state = %u\n", __func__, + con, con->state, sk->sk_state); + + switch (sk->sk_state) { + case TCP_CLOSE: + dout("%s TCP_CLOSE\n", __func__); + fallthrough; + case TCP_CLOSE_WAIT: + dout("%s TCP_CLOSE_WAIT\n", __func__); + con_sock_state_closing(con); + ceph_con_flag_set(con, CEPH_CON_F_SOCK_CLOSED); + queue_con(con); + break; + case TCP_ESTABLISHED: + dout("%s TCP_ESTABLISHED\n", __func__); + con_sock_state_connected(con); + queue_con(con); + break; + default: /* Everything else is uninteresting */ + break; + } +} + +/* + * set up socket callbacks + */ +static void set_sock_callbacks(struct socket *sock, + struct ceph_connection *con) +{ + struct sock *sk = sock->sk; + sk->sk_user_data = con; + sk->sk_data_ready = ceph_sock_data_ready; + sk->sk_write_space = ceph_sock_write_space; + sk->sk_state_change = ceph_sock_state_change; +} + + +/* + * socket helpers + */ + +/* + * initiate connection to a remote socket. + */ +int ceph_tcp_connect(struct ceph_connection *con) +{ + struct sockaddr_storage ss = con->peer_addr.in_addr; /* align */ + struct socket *sock; + unsigned int noio_flag; + int ret; + + dout("%s con %p peer_addr %s\n", __func__, con, + ceph_pr_addr(&con->peer_addr)); + BUG_ON(con->sock); + + /* sock_create_kern() allocates with GFP_KERNEL */ + noio_flag = memalloc_noio_save(); + ret = sock_create_kern(read_pnet(&con->msgr->net), ss.ss_family, + SOCK_STREAM, IPPROTO_TCP, &sock); + memalloc_noio_restore(noio_flag); + if (ret) + return ret; + sock->sk->sk_allocation = GFP_NOFS; + +#ifdef CONFIG_LOCKDEP + lockdep_set_class(&sock->sk->sk_lock, &socket_class); +#endif + + set_sock_callbacks(sock, con); + + con_sock_state_connecting(con); + ret = kernel_connect(sock, (struct sockaddr *)&ss, sizeof(ss), + O_NONBLOCK); + if (ret == -EINPROGRESS) { + dout("connect %s EINPROGRESS sk_state = %u\n", + ceph_pr_addr(&con->peer_addr), + sock->sk->sk_state); + } else if (ret < 0) { + pr_err("connect %s error %d\n", + ceph_pr_addr(&con->peer_addr), ret); + sock_release(sock); + return ret; + } + + if (ceph_test_opt(from_msgr(con->msgr), TCP_NODELAY)) + tcp_sock_set_nodelay(sock->sk); + + con->sock = sock; + return 0; +} + +/* + * Shutdown/close the socket for the given connection. 
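ceph_tcp_connect() above issues a non-blocking kernel_connect() and treats -EINPROGRESS as "connection in flight", relying on the socket state-change callback to learn the outcome later. A rough userspace analogue of the same pattern, using poll() in place of the callback; the 192.0.2.1 address and the port are placeholders, not taken from the patch.

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/socket.h>

int main(void)
{
	struct sockaddr_in sa = { .sin_family = AF_INET,
				  .sin_port = htons(6789) };	/* example port */
	int fd = socket(AF_INET, SOCK_STREAM, 0);
	int err = 0;
	socklen_t len = sizeof(err);
	struct pollfd pfd;

	if (fd < 0)
		return 1;
	inet_pton(AF_INET, "192.0.2.1", &sa.sin_addr);	/* documentation address */
	fcntl(fd, F_SETFL, O_NONBLOCK);

	if (connect(fd, (struct sockaddr *)&sa, sizeof(sa)) == 0) {
		puts("connected immediately");
	} else if (errno == EINPROGRESS) {
		/* the connection attempt continues in the background */
		pfd.fd = fd;
		pfd.events = POLLOUT;
		if (poll(&pfd, 1, 1000) == 1) {
			getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len);
			printf("connect finished: %s\n",
			       err ? strerror(err) : "established");
		} else {
			puts("still connecting (timed out waiting)");
		}
	} else {
		perror("connect");
	}
	close(fd);
	return 0;
}

The SO_ERROR check after writability plays roughly the role that the TCP_ESTABLISHED and TCP_CLOSE cases in ceph_sock_state_change() above play for the messenger.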
+ */ +int ceph_con_close_socket(struct ceph_connection *con) +{ + int rc = 0; + + dout("%s con %p sock %p\n", __func__, con, con->sock); + if (con->sock) { + rc = con->sock->ops->shutdown(con->sock, SHUT_RDWR); + sock_release(con->sock); + con->sock = NULL; + } + + /* + * Forcibly clear the SOCK_CLOSED flag. It gets set + * independent of the connection mutex, and we could have + * received a socket close event before we had the chance to + * shut the socket down. + */ + ceph_con_flag_clear(con, CEPH_CON_F_SOCK_CLOSED); + + con_sock_state_closed(con); + return rc; +} + +static void ceph_con_reset_protocol(struct ceph_connection *con) +{ + dout("%s con %p\n", __func__, con); + + ceph_con_close_socket(con); + if (con->in_msg) { + WARN_ON(con->in_msg->con != con); + ceph_msg_put(con->in_msg); + con->in_msg = NULL; + } + if (con->out_msg) { + WARN_ON(con->out_msg->con != con); + ceph_msg_put(con->out_msg); + con->out_msg = NULL; + } + if (con->bounce_page) { + __free_page(con->bounce_page); + con->bounce_page = NULL; + } + + if (ceph_msgr2(from_msgr(con->msgr))) + ceph_con_v2_reset_protocol(con); + else + ceph_con_v1_reset_protocol(con); +} + +/* + * Reset a connection. Discard all incoming and outgoing messages + * and clear *_seq state. + */ +static void ceph_msg_remove(struct ceph_msg *msg) +{ + list_del_init(&msg->list_head); + + ceph_msg_put(msg); +} + +static void ceph_msg_remove_list(struct list_head *head) +{ + while (!list_empty(head)) { + struct ceph_msg *msg = list_first_entry(head, struct ceph_msg, + list_head); + ceph_msg_remove(msg); + } +} + +void ceph_con_reset_session(struct ceph_connection *con) +{ + dout("%s con %p\n", __func__, con); + + WARN_ON(con->in_msg); + WARN_ON(con->out_msg); + ceph_msg_remove_list(&con->out_queue); + ceph_msg_remove_list(&con->out_sent); + con->out_seq = 0; + con->in_seq = 0; + con->in_seq_acked = 0; + + if (ceph_msgr2(from_msgr(con->msgr))) + ceph_con_v2_reset_session(con); + else + ceph_con_v1_reset_session(con); +} + +/* + * mark a peer down. drop any open connections. + */ +void ceph_con_close(struct ceph_connection *con) +{ + mutex_lock(&con->mutex); + dout("con_close %p peer %s\n", con, ceph_pr_addr(&con->peer_addr)); + con->state = CEPH_CON_S_CLOSED; + + ceph_con_flag_clear(con, CEPH_CON_F_LOSSYTX); /* so we retry next + connect */ + ceph_con_flag_clear(con, CEPH_CON_F_KEEPALIVE_PENDING); + ceph_con_flag_clear(con, CEPH_CON_F_WRITE_PENDING); + ceph_con_flag_clear(con, CEPH_CON_F_BACKOFF); + + ceph_con_reset_protocol(con); + ceph_con_reset_session(con); + cancel_con(con); + mutex_unlock(&con->mutex); +} +EXPORT_SYMBOL(ceph_con_close); + +/* + * Reopen a closed connection, with a new peer address. + */ +void ceph_con_open(struct ceph_connection *con, + __u8 entity_type, __u64 entity_num, + struct ceph_entity_addr *addr) +{ + mutex_lock(&con->mutex); + dout("con_open %p %s\n", con, ceph_pr_addr(addr)); + + WARN_ON(con->state != CEPH_CON_S_CLOSED); + con->state = CEPH_CON_S_PREOPEN; + + con->peer_name.type = (__u8) entity_type; + con->peer_name.num = cpu_to_le64(entity_num); + + memcpy(&con->peer_addr, addr, sizeof(*addr)); + con->delay = 0; /* reset backoff memory */ + mutex_unlock(&con->mutex); + queue_con(con); +} +EXPORT_SYMBOL(ceph_con_open); + +/* + * return true if this connection ever successfully opened + */ +bool ceph_con_opened(struct ceph_connection *con) +{ + if (ceph_msgr2(from_msgr(con->msgr))) + return ceph_con_v2_opened(con); + + return ceph_con_v1_opened(con); +} + +/* + * initialize a new connection. 
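+ *
+ * A minimal sketch of the expected call sequence from a hypothetical
+ * caller (the names osd, osd_con_ops, client and osd_addr below are
+ * illustrative assumptions, not taken from this file):
+ *
+ *    ceph_con_init(&osd->con, osd, &osd_con_ops, &client->msgr);
+ *    ceph_con_open(&osd->con, CEPH_ENTITY_TYPE_OSD, osd->id, &osd_addr);
+ *    ceph_con_send(&osd->con, msg);     (consumes the caller's msg ref)
+ *
+ * ceph_con_close() returns the connection to CLOSED; ceph_con_open()
+ * may then be used again with a new peer address.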
+ */ +void ceph_con_init(struct ceph_connection *con, void *private, + const struct ceph_connection_operations *ops, + struct ceph_messenger *msgr) +{ + dout("con_init %p\n", con); + memset(con, 0, sizeof(*con)); + con->private = private; + con->ops = ops; + con->msgr = msgr; + + con_sock_state_init(con); + + mutex_init(&con->mutex); + INIT_LIST_HEAD(&con->out_queue); + INIT_LIST_HEAD(&con->out_sent); + INIT_DELAYED_WORK(&con->work, ceph_con_workfn); + + con->state = CEPH_CON_S_CLOSED; +} +EXPORT_SYMBOL(ceph_con_init); + +/* + * We maintain a global counter to order connection attempts. Get + * a unique seq greater than @gt. + */ +u32 ceph_get_global_seq(struct ceph_messenger *msgr, u32 gt) +{ + u32 ret; + + spin_lock(&msgr->global_seq_lock); + if (msgr->global_seq < gt) + msgr->global_seq = gt; + ret = ++msgr->global_seq; + spin_unlock(&msgr->global_seq_lock); + return ret; +} + +/* + * Discard messages that have been acked by the server. + */ +void ceph_con_discard_sent(struct ceph_connection *con, u64 ack_seq) +{ + struct ceph_msg *msg; + u64 seq; + + dout("%s con %p ack_seq %llu\n", __func__, con, ack_seq); + while (!list_empty(&con->out_sent)) { + msg = list_first_entry(&con->out_sent, struct ceph_msg, + list_head); + WARN_ON(msg->needs_out_seq); + seq = le64_to_cpu(msg->hdr.seq); + if (seq > ack_seq) + break; + + dout("%s con %p discarding msg %p seq %llu\n", __func__, con, + msg, seq); + ceph_msg_remove(msg); + } +} + +/* + * Discard messages that have been requeued in con_fault(), up to + * reconnect_seq. This avoids gratuitously resending messages that + * the server had received and handled prior to reconnect. + */ +void ceph_con_discard_requeued(struct ceph_connection *con, u64 reconnect_seq) +{ + struct ceph_msg *msg; + u64 seq; + + dout("%s con %p reconnect_seq %llu\n", __func__, con, reconnect_seq); + while (!list_empty(&con->out_queue)) { + msg = list_first_entry(&con->out_queue, struct ceph_msg, + list_head); + if (msg->needs_out_seq) + break; + seq = le64_to_cpu(msg->hdr.seq); + if (seq > reconnect_seq) + break; + + dout("%s con %p discarding msg %p seq %llu\n", __func__, con, + msg, seq); + ceph_msg_remove(msg); + } +} + +#ifdef CONFIG_BLOCK + +/* + * For a bio data item, a piece is whatever remains of the next + * entry in the current bio iovec, or the first entry in the next + * bio in the list. 
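+ * In other words a piece never spans a bio_vec boundary: the cursor
+ * advances with bio_advance_iter() and only hops to the bi_next bio
+ * once the current bio's iterator is exhausted.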
+ */ +static void ceph_msg_data_bio_cursor_init(struct ceph_msg_data_cursor *cursor, + size_t length) +{ + struct ceph_msg_data *data = cursor->data; + struct ceph_bio_iter *it = &cursor->bio_iter; + + cursor->resid = min_t(size_t, length, data->bio_length); + *it = data->bio_pos; + if (cursor->resid < it->iter.bi_size) + it->iter.bi_size = cursor->resid; + + BUG_ON(cursor->resid < bio_iter_len(it->bio, it->iter)); +} + +static struct page *ceph_msg_data_bio_next(struct ceph_msg_data_cursor *cursor, + size_t *page_offset, + size_t *length) +{ + struct bio_vec bv = bio_iter_iovec(cursor->bio_iter.bio, + cursor->bio_iter.iter); + + *page_offset = bv.bv_offset; + *length = bv.bv_len; + return bv.bv_page; +} + +static bool ceph_msg_data_bio_advance(struct ceph_msg_data_cursor *cursor, + size_t bytes) +{ + struct ceph_bio_iter *it = &cursor->bio_iter; + struct page *page = bio_iter_page(it->bio, it->iter); + + BUG_ON(bytes > cursor->resid); + BUG_ON(bytes > bio_iter_len(it->bio, it->iter)); + cursor->resid -= bytes; + bio_advance_iter(it->bio, &it->iter, bytes); + + if (!cursor->resid) + return false; /* no more data */ + + if (!bytes || (it->iter.bi_size && it->iter.bi_bvec_done && + page == bio_iter_page(it->bio, it->iter))) + return false; /* more bytes to process in this segment */ + + if (!it->iter.bi_size) { + it->bio = it->bio->bi_next; + it->iter = it->bio->bi_iter; + if (cursor->resid < it->iter.bi_size) + it->iter.bi_size = cursor->resid; + } + + BUG_ON(cursor->resid < bio_iter_len(it->bio, it->iter)); + return true; +} +#endif /* CONFIG_BLOCK */ + +static void ceph_msg_data_bvecs_cursor_init(struct ceph_msg_data_cursor *cursor, + size_t length) +{ + struct ceph_msg_data *data = cursor->data; + struct bio_vec *bvecs = data->bvec_pos.bvecs; + + cursor->resid = min_t(size_t, length, data->bvec_pos.iter.bi_size); + cursor->bvec_iter = data->bvec_pos.iter; + cursor->bvec_iter.bi_size = cursor->resid; + + BUG_ON(cursor->resid < bvec_iter_len(bvecs, cursor->bvec_iter)); +} + +static struct page *ceph_msg_data_bvecs_next(struct ceph_msg_data_cursor *cursor, + size_t *page_offset, + size_t *length) +{ + struct bio_vec bv = bvec_iter_bvec(cursor->data->bvec_pos.bvecs, + cursor->bvec_iter); + + *page_offset = bv.bv_offset; + *length = bv.bv_len; + return bv.bv_page; +} + +static bool ceph_msg_data_bvecs_advance(struct ceph_msg_data_cursor *cursor, + size_t bytes) +{ + struct bio_vec *bvecs = cursor->data->bvec_pos.bvecs; + struct page *page = bvec_iter_page(bvecs, cursor->bvec_iter); + + BUG_ON(bytes > cursor->resid); + BUG_ON(bytes > bvec_iter_len(bvecs, cursor->bvec_iter)); + cursor->resid -= bytes; + bvec_iter_advance(bvecs, &cursor->bvec_iter, bytes); + + if (!cursor->resid) + return false; /* no more data */ + + if (!bytes || (cursor->bvec_iter.bi_bvec_done && + page == bvec_iter_page(bvecs, cursor->bvec_iter))) + return false; /* more bytes to process in this segment */ + + BUG_ON(cursor->resid < bvec_iter_len(bvecs, cursor->bvec_iter)); + return true; +} + +/* + * For a page array, a piece comes from the first page in the array + * that has not already been fully consumed. 
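+ *
+ * An illustrative example (assuming 4 KiB pages): with alignment
+ * 0x300 and 8192 bytes of data, the first piece is page 0 at offset
+ * 0x300 for 0xd00 bytes, the second is page 1 at offset 0 for 0x1000
+ * bytes, and the last is page 2 at offset 0 for the remaining 0x300
+ * bytes.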
+ */ +static void ceph_msg_data_pages_cursor_init(struct ceph_msg_data_cursor *cursor, + size_t length) +{ + struct ceph_msg_data *data = cursor->data; + int page_count; + + BUG_ON(data->type != CEPH_MSG_DATA_PAGES); + + BUG_ON(!data->pages); + BUG_ON(!data->length); + + cursor->resid = min(length, data->length); + page_count = calc_pages_for(data->alignment, (u64)data->length); + cursor->page_offset = data->alignment & ~PAGE_MASK; + cursor->page_index = 0; + BUG_ON(page_count > (int)USHRT_MAX); + cursor->page_count = (unsigned short)page_count; + BUG_ON(length > SIZE_MAX - cursor->page_offset); +} + +static struct page * +ceph_msg_data_pages_next(struct ceph_msg_data_cursor *cursor, + size_t *page_offset, size_t *length) +{ + struct ceph_msg_data *data = cursor->data; + + BUG_ON(data->type != CEPH_MSG_DATA_PAGES); + + BUG_ON(cursor->page_index >= cursor->page_count); + BUG_ON(cursor->page_offset >= PAGE_SIZE); + + *page_offset = cursor->page_offset; + *length = min_t(size_t, cursor->resid, PAGE_SIZE - *page_offset); + return data->pages[cursor->page_index]; +} + +static bool ceph_msg_data_pages_advance(struct ceph_msg_data_cursor *cursor, + size_t bytes) +{ + BUG_ON(cursor->data->type != CEPH_MSG_DATA_PAGES); + + BUG_ON(cursor->page_offset + bytes > PAGE_SIZE); + + /* Advance the cursor page offset */ + + cursor->resid -= bytes; + cursor->page_offset = (cursor->page_offset + bytes) & ~PAGE_MASK; + if (!bytes || cursor->page_offset) + return false; /* more bytes to process in the current page */ + + if (!cursor->resid) + return false; /* no more data */ + + /* Move on to the next page; offset is already at 0 */ + + BUG_ON(cursor->page_index >= cursor->page_count); + cursor->page_index++; + return true; +} + +/* + * For a pagelist, a piece is whatever remains to be consumed in the + * first page in the list, or the front of the next page. 
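+ * The cursor's offset counts from the start of the whole pagelist;
+ * the in-page offset is (offset & ~PAGE_MASK), and the cursor steps
+ * to the next page in the list whenever that wraps back to zero with
+ * data still remaining.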
+ */ +static void +ceph_msg_data_pagelist_cursor_init(struct ceph_msg_data_cursor *cursor, + size_t length) +{ + struct ceph_msg_data *data = cursor->data; + struct ceph_pagelist *pagelist; + struct page *page; + + BUG_ON(data->type != CEPH_MSG_DATA_PAGELIST); + + pagelist = data->pagelist; + BUG_ON(!pagelist); + + if (!length) + return; /* pagelist can be assigned but empty */ + + BUG_ON(list_empty(&pagelist->head)); + page = list_first_entry(&pagelist->head, struct page, lru); + + cursor->resid = min(length, pagelist->length); + cursor->page = page; + cursor->offset = 0; +} + +static struct page * +ceph_msg_data_pagelist_next(struct ceph_msg_data_cursor *cursor, + size_t *page_offset, size_t *length) +{ + struct ceph_msg_data *data = cursor->data; + struct ceph_pagelist *pagelist; + + BUG_ON(data->type != CEPH_MSG_DATA_PAGELIST); + + pagelist = data->pagelist; + BUG_ON(!pagelist); + + BUG_ON(!cursor->page); + BUG_ON(cursor->offset + cursor->resid != pagelist->length); + + /* offset of first page in pagelist is always 0 */ + *page_offset = cursor->offset & ~PAGE_MASK; + *length = min_t(size_t, cursor->resid, PAGE_SIZE - *page_offset); + return cursor->page; +} + +static bool ceph_msg_data_pagelist_advance(struct ceph_msg_data_cursor *cursor, + size_t bytes) +{ + struct ceph_msg_data *data = cursor->data; + struct ceph_pagelist *pagelist; + + BUG_ON(data->type != CEPH_MSG_DATA_PAGELIST); + + pagelist = data->pagelist; + BUG_ON(!pagelist); + + BUG_ON(cursor->offset + cursor->resid != pagelist->length); + BUG_ON((cursor->offset & ~PAGE_MASK) + bytes > PAGE_SIZE); + + /* Advance the cursor offset */ + + cursor->resid -= bytes; + cursor->offset += bytes; + /* offset of first page in pagelist is always 0 */ + if (!bytes || cursor->offset & ~PAGE_MASK) + return false; /* more bytes to process in the current page */ + + if (!cursor->resid) + return false; /* no more data */ + + /* Move on to the next page */ + + BUG_ON(list_is_last(&cursor->page->lru, &pagelist->head)); + cursor->page = list_next_entry(cursor->page, lru); + return true; +} + +/* + * Message data is handled (sent or received) in pieces, where each + * piece resides on a single page. The network layer might not + * consume an entire piece at once. A data item's cursor keeps + * track of which piece is next to process and how much remains to + * be processed in that piece. It also tracks whether the current + * piece is the last one in the data item. 
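+ *
+ * The expected consumer pattern, sketched loosely after
+ * write_partial_message_data() in messenger_v1.c (error handling
+ * omitted; "consume" stands in for whatever the caller does with
+ * each piece):
+ *
+ *    ceph_msg_data_cursor_init(&cursor, msg, msg->data_length);
+ *    while (cursor.total_resid) {
+ *        page = ceph_msg_data_next(&cursor, &off, &len);
+ *        ret = consume(page, off, len);     (may be short)
+ *        ceph_msg_data_advance(&cursor, ret);
+ *    }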
+ */ +static void __ceph_msg_data_cursor_init(struct ceph_msg_data_cursor *cursor) +{ + size_t length = cursor->total_resid; + + switch (cursor->data->type) { + case CEPH_MSG_DATA_PAGELIST: + ceph_msg_data_pagelist_cursor_init(cursor, length); + break; + case CEPH_MSG_DATA_PAGES: + ceph_msg_data_pages_cursor_init(cursor, length); + break; +#ifdef CONFIG_BLOCK + case CEPH_MSG_DATA_BIO: + ceph_msg_data_bio_cursor_init(cursor, length); + break; +#endif /* CONFIG_BLOCK */ + case CEPH_MSG_DATA_BVECS: + ceph_msg_data_bvecs_cursor_init(cursor, length); + break; + case CEPH_MSG_DATA_NONE: + default: + /* BUG(); */ + break; + } + cursor->need_crc = true; +} + +void ceph_msg_data_cursor_init(struct ceph_msg_data_cursor *cursor, + struct ceph_msg *msg, size_t length) +{ + BUG_ON(!length); + BUG_ON(length > msg->data_length); + BUG_ON(!msg->num_data_items); + + cursor->total_resid = length; + cursor->data = msg->data; + + __ceph_msg_data_cursor_init(cursor); +} + +/* + * Return the page containing the next piece to process for a given + * data item, and supply the page offset and length of that piece. + * Indicate whether this is the last piece in this data item. + */ +struct page *ceph_msg_data_next(struct ceph_msg_data_cursor *cursor, + size_t *page_offset, size_t *length) +{ + struct page *page; + + switch (cursor->data->type) { + case CEPH_MSG_DATA_PAGELIST: + page = ceph_msg_data_pagelist_next(cursor, page_offset, length); + break; + case CEPH_MSG_DATA_PAGES: + page = ceph_msg_data_pages_next(cursor, page_offset, length); + break; +#ifdef CONFIG_BLOCK + case CEPH_MSG_DATA_BIO: + page = ceph_msg_data_bio_next(cursor, page_offset, length); + break; +#endif /* CONFIG_BLOCK */ + case CEPH_MSG_DATA_BVECS: + page = ceph_msg_data_bvecs_next(cursor, page_offset, length); + break; + case CEPH_MSG_DATA_NONE: + default: + page = NULL; + break; + } + + BUG_ON(!page); + BUG_ON(*page_offset + *length > PAGE_SIZE); + BUG_ON(!*length); + BUG_ON(*length > cursor->resid); + + return page; +} + +/* + * Returns true if the result moves the cursor on to the next piece + * of the data item. 
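+ * (The function itself returns void; whether a new piece was reached
+ * is recorded in cursor->need_crc, which the data senders consult to
+ * decide when to restart the running data CRC.)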
+ */ +void ceph_msg_data_advance(struct ceph_msg_data_cursor *cursor, size_t bytes) +{ + bool new_piece; + + BUG_ON(bytes > cursor->resid); + switch (cursor->data->type) { + case CEPH_MSG_DATA_PAGELIST: + new_piece = ceph_msg_data_pagelist_advance(cursor, bytes); + break; + case CEPH_MSG_DATA_PAGES: + new_piece = ceph_msg_data_pages_advance(cursor, bytes); + break; +#ifdef CONFIG_BLOCK + case CEPH_MSG_DATA_BIO: + new_piece = ceph_msg_data_bio_advance(cursor, bytes); + break; +#endif /* CONFIG_BLOCK */ + case CEPH_MSG_DATA_BVECS: + new_piece = ceph_msg_data_bvecs_advance(cursor, bytes); + break; + case CEPH_MSG_DATA_NONE: + default: + BUG(); + break; + } + cursor->total_resid -= bytes; + + if (!cursor->resid && cursor->total_resid) { + cursor->data++; + __ceph_msg_data_cursor_init(cursor); + new_piece = true; + } + cursor->need_crc = new_piece; +} + +u32 ceph_crc32c_page(u32 crc, struct page *page, unsigned int page_offset, + unsigned int length) +{ + char *kaddr; + + kaddr = kmap(page); + BUG_ON(kaddr == NULL); + crc = crc32c(crc, kaddr + page_offset, length); + kunmap(page); + + return crc; +} + +bool ceph_addr_is_blank(const struct ceph_entity_addr *addr) +{ + struct sockaddr_storage ss = addr->in_addr; /* align */ + struct in_addr *addr4 = &((struct sockaddr_in *)&ss)->sin_addr; + struct in6_addr *addr6 = &((struct sockaddr_in6 *)&ss)->sin6_addr; + + switch (ss.ss_family) { + case AF_INET: + return addr4->s_addr == htonl(INADDR_ANY); + case AF_INET6: + return ipv6_addr_any(addr6); + default: + return true; + } +} +EXPORT_SYMBOL(ceph_addr_is_blank); + +int ceph_addr_port(const struct ceph_entity_addr *addr) +{ + switch (get_unaligned(&addr->in_addr.ss_family)) { + case AF_INET: + return ntohs(get_unaligned(&((struct sockaddr_in *)&addr->in_addr)->sin_port)); + case AF_INET6: + return ntohs(get_unaligned(&((struct sockaddr_in6 *)&addr->in_addr)->sin6_port)); + } + return 0; +} + +void ceph_addr_set_port(struct ceph_entity_addr *addr, int p) +{ + switch (get_unaligned(&addr->in_addr.ss_family)) { + case AF_INET: + put_unaligned(htons(p), &((struct sockaddr_in *)&addr->in_addr)->sin_port); + break; + case AF_INET6: + put_unaligned(htons(p), &((struct sockaddr_in6 *)&addr->in_addr)->sin6_port); + break; + } +} + +/* + * Unlike other *_pton function semantics, zero indicates success. + */ +static int ceph_pton(const char *str, size_t len, struct ceph_entity_addr *addr, + char delim, const char **ipend) +{ + memset(&addr->in_addr, 0, sizeof(addr->in_addr)); + + if (in4_pton(str, len, (u8 *)&((struct sockaddr_in *)&addr->in_addr)->sin_addr.s_addr, delim, ipend)) { + put_unaligned(AF_INET, &addr->in_addr.ss_family); + return 0; + } + + if (in6_pton(str, len, (u8 *)&((struct sockaddr_in6 *)&addr->in_addr)->sin6_addr.s6_addr, delim, ipend)) { + put_unaligned(AF_INET6, &addr->in_addr.ss_family); + return 0; + } + + return -EINVAL; +} + +/* + * Extract hostname string and resolve using kernel DNS facility. + */ +#ifdef CONFIG_CEPH_LIB_USE_DNS_RESOLVER +static int ceph_dns_resolve_name(const char *name, size_t namelen, + struct ceph_entity_addr *addr, char delim, const char **ipend) +{ + const char *end, *delim_p; + char *colon_p, *ip_addr = NULL; + int ip_len, ret; + + /* + * The end of the hostname occurs immediately preceding the delimiter or + * the port marker (':') where the delimiter takes precedence. + */ + delim_p = memchr(name, delim, namelen); + colon_p = memchr(name, ':', namelen); + + if (delim_p && colon_p) + end = delim_p < colon_p ? 
delim_p : colon_p; + else if (!delim_p && colon_p) + end = colon_p; + else { + end = delim_p; + if (!end) /* case: hostname:/ */ + end = name + namelen; + } + + if (end <= name) + return -EINVAL; + + /* do dns_resolve upcall */ + ip_len = dns_query(current->nsproxy->net_ns, + NULL, name, end - name, NULL, &ip_addr, NULL, false); + if (ip_len > 0) + ret = ceph_pton(ip_addr, ip_len, addr, -1, NULL); + else + ret = -ESRCH; + + kfree(ip_addr); + + *ipend = end; + + pr_info("resolve '%.*s' (ret=%d): %s\n", (int)(end - name), name, + ret, ret ? "failed" : ceph_pr_addr(addr)); + + return ret; +} +#else +static inline int ceph_dns_resolve_name(const char *name, size_t namelen, + struct ceph_entity_addr *addr, char delim, const char **ipend) +{ + return -EINVAL; +} +#endif + +/* + * Parse a server name (IP or hostname). If a valid IP address is not found + * then try to extract a hostname to resolve using userspace DNS upcall. + */ +static int ceph_parse_server_name(const char *name, size_t namelen, + struct ceph_entity_addr *addr, char delim, const char **ipend) +{ + int ret; + + ret = ceph_pton(name, namelen, addr, delim, ipend); + if (ret) + ret = ceph_dns_resolve_name(name, namelen, addr, delim, ipend); + + return ret; +} + +/* + * Parse an ip[:port] list into an addr array. Use the default + * monitor port if a port isn't specified. + */ +int ceph_parse_ips(const char *c, const char *end, + struct ceph_entity_addr *addr, + int max_count, int *count, char delim) +{ + int i, ret = -EINVAL; + const char *p = c; + + dout("parse_ips on '%.*s'\n", (int)(end-c), c); + for (i = 0; i < max_count; i++) { + char cur_delim = delim; + const char *ipend; + int port; + + if (*p == '[') { + cur_delim = ']'; + p++; + } + + ret = ceph_parse_server_name(p, end - p, &addr[i], cur_delim, + &ipend); + if (ret) + goto bad; + ret = -EINVAL; + + p = ipend; + + if (cur_delim == ']') { + if (*p != ']') { + dout("missing matching ']'\n"); + goto bad; + } + p++; + } + + /* port? */ + if (p < end && *p == ':') { + port = 0; + p++; + while (p < end && *p >= '0' && *p <= '9') { + port = (port * 10) + (*p - '0'); + p++; + } + if (port == 0) + port = CEPH_MON_PORT; + else if (port > 65535) + goto bad; + } else { + port = CEPH_MON_PORT; + } + + ceph_addr_set_port(&addr[i], port); + /* + * We want the type to be set according to ms_mode + * option, but options are normally parsed after mon + * addresses. Rather than complicating parsing, set + * to LEGACY and override in build_initial_monmap() + * for mon addresses and ceph_messenger_init() for + * ip option. + */ + addr[i].type = CEPH_ENTITY_ADDR_TYPE_LEGACY; + addr[i].nonce = 0; + + dout("%s got %s\n", __func__, ceph_pr_addr(&addr[i])); + + if (p == end) + break; + if (*p != delim) + goto bad; + p++; + } + + if (p != end) + goto bad; + + if (count) + *count = i + 1; + return 0; + +bad: + return ret; +} + +/* + * Process message. This happens in the worker thread. The callback should + * be careful not to do anything that waits on other incoming messages or it + * may deadlock. 
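+ * Note that con->mutex is dropped around the ->dispatch() callback
+ * and re-taken before returning to the caller.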
+ */ +void ceph_con_process_message(struct ceph_connection *con) +{ + struct ceph_msg *msg = con->in_msg; + + BUG_ON(con->in_msg->con != con); + con->in_msg = NULL; + + /* if first message, set peer_name */ + if (con->peer_name.type == 0) + con->peer_name = msg->hdr.src; + + con->in_seq++; + mutex_unlock(&con->mutex); + + dout("===== %p %llu from %s%lld %d=%s len %d+%d+%d (%u %u %u) =====\n", + msg, le64_to_cpu(msg->hdr.seq), + ENTITY_NAME(msg->hdr.src), + le16_to_cpu(msg->hdr.type), + ceph_msg_type_name(le16_to_cpu(msg->hdr.type)), + le32_to_cpu(msg->hdr.front_len), + le32_to_cpu(msg->hdr.middle_len), + le32_to_cpu(msg->hdr.data_len), + con->in_front_crc, con->in_middle_crc, con->in_data_crc); + con->ops->dispatch(con, msg); + + mutex_lock(&con->mutex); +} + +/* + * Atomically queue work on a connection after the specified delay. + * Bump @con reference to avoid races with connection teardown. + * Returns 0 if work was queued, or an error code otherwise. + */ +static int queue_con_delay(struct ceph_connection *con, unsigned long delay) +{ + if (!con->ops->get(con)) { + dout("%s %p ref count 0\n", __func__, con); + return -ENOENT; + } + + if (delay >= HZ) + delay = round_jiffies_relative(delay); + + dout("%s %p %lu\n", __func__, con, delay); + if (!queue_delayed_work(ceph_msgr_wq, &con->work, delay)) { + dout("%s %p - already queued\n", __func__, con); + con->ops->put(con); + return -EBUSY; + } + + return 0; +} + +static void queue_con(struct ceph_connection *con) +{ + (void) queue_con_delay(con, 0); +} + +static void cancel_con(struct ceph_connection *con) +{ + if (cancel_delayed_work(&con->work)) { + dout("%s %p\n", __func__, con); + con->ops->put(con); + } +} + +static bool con_sock_closed(struct ceph_connection *con) +{ + if (!ceph_con_flag_test_and_clear(con, CEPH_CON_F_SOCK_CLOSED)) + return false; + +#define CASE(x) \ + case CEPH_CON_S_ ## x: \ + con->error_msg = "socket closed (con state " #x ")"; \ + break; + + switch (con->state) { + CASE(CLOSED); + CASE(PREOPEN); + CASE(V1_BANNER); + CASE(V1_CONNECT_MSG); + CASE(V2_BANNER_PREFIX); + CASE(V2_BANNER_PAYLOAD); + CASE(V2_HELLO); + CASE(V2_AUTH); + CASE(V2_AUTH_SIGNATURE); + CASE(V2_SESSION_CONNECT); + CASE(V2_SESSION_RECONNECT); + CASE(OPEN); + CASE(STANDBY); + default: + BUG(); + } +#undef CASE + + return true; +} + +static bool con_backoff(struct ceph_connection *con) +{ + int ret; + + if (!ceph_con_flag_test_and_clear(con, CEPH_CON_F_BACKOFF)) + return false; + + ret = queue_con_delay(con, con->delay); + if (ret) { + dout("%s: con %p FAILED to back off %lu\n", __func__, + con, con->delay); + BUG_ON(ret == -ENOENT); + ceph_con_flag_set(con, CEPH_CON_F_BACKOFF); + } + + return true; +} + +/* Finish fault handling; con->mutex must *not* be held here */ + +static void con_fault_finish(struct ceph_connection *con) +{ + dout("%s %p\n", __func__, con); + + /* + * in case we faulted due to authentication, invalidate our + * current tickets so that we can get new ones. + */ + if (con->v1.auth_retry) { + dout("auth_retry %d, invalidating\n", con->v1.auth_retry); + if (con->ops->invalidate_authorizer) + con->ops->invalidate_authorizer(con); + con->v1.auth_retry = 0; + } + + if (con->ops->fault) + con->ops->fault(con); +} + +/* + * Do some work on a connection. Drop a connection ref when we're done. 
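+ * The loop below retries reads and writes that return -EAGAIN and
+ * stops on socket closure, backoff, STANDBY/CLOSED state or an error;
+ * faults are recorded with con->mutex held (con_fault()) and finished
+ * after it is released (con_fault_finish()).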
+ */ +static void ceph_con_workfn(struct work_struct *work) +{ + struct ceph_connection *con = container_of(work, struct ceph_connection, + work.work); + bool fault; + + mutex_lock(&con->mutex); + while (true) { + int ret; + + if ((fault = con_sock_closed(con))) { + dout("%s: con %p SOCK_CLOSED\n", __func__, con); + break; + } + if (con_backoff(con)) { + dout("%s: con %p BACKOFF\n", __func__, con); + break; + } + if (con->state == CEPH_CON_S_STANDBY) { + dout("%s: con %p STANDBY\n", __func__, con); + break; + } + if (con->state == CEPH_CON_S_CLOSED) { + dout("%s: con %p CLOSED\n", __func__, con); + BUG_ON(con->sock); + break; + } + if (con->state == CEPH_CON_S_PREOPEN) { + dout("%s: con %p PREOPEN\n", __func__, con); + BUG_ON(con->sock); + } + + if (ceph_msgr2(from_msgr(con->msgr))) + ret = ceph_con_v2_try_read(con); + else + ret = ceph_con_v1_try_read(con); + if (ret < 0) { + if (ret == -EAGAIN) + continue; + if (!con->error_msg) + con->error_msg = "socket error on read"; + fault = true; + break; + } + + if (ceph_msgr2(from_msgr(con->msgr))) + ret = ceph_con_v2_try_write(con); + else + ret = ceph_con_v1_try_write(con); + if (ret < 0) { + if (ret == -EAGAIN) + continue; + if (!con->error_msg) + con->error_msg = "socket error on write"; + fault = true; + } + + break; /* If we make it to here, we're done */ + } + if (fault) + con_fault(con); + mutex_unlock(&con->mutex); + + if (fault) + con_fault_finish(con); + + con->ops->put(con); +} + +/* + * Generic error/fault handler. A retry mechanism is used with + * exponential backoff + */ +static void con_fault(struct ceph_connection *con) +{ + dout("fault %p state %d to peer %s\n", + con, con->state, ceph_pr_addr(&con->peer_addr)); + + pr_warn("%s%lld %s %s\n", ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), con->error_msg); + con->error_msg = NULL; + + WARN_ON(con->state == CEPH_CON_S_STANDBY || + con->state == CEPH_CON_S_CLOSED); + + ceph_con_reset_protocol(con); + + if (ceph_con_flag_test(con, CEPH_CON_F_LOSSYTX)) { + dout("fault on LOSSYTX channel, marking CLOSED\n"); + con->state = CEPH_CON_S_CLOSED; + return; + } + + /* Requeue anything that hasn't been acked */ + list_splice_init(&con->out_sent, &con->out_queue); + + /* If there are no messages queued or keepalive pending, place + * the connection in a STANDBY state */ + if (list_empty(&con->out_queue) && + !ceph_con_flag_test(con, CEPH_CON_F_KEEPALIVE_PENDING)) { + dout("fault %p setting STANDBY clearing WRITE_PENDING\n", con); + ceph_con_flag_clear(con, CEPH_CON_F_WRITE_PENDING); + con->state = CEPH_CON_S_STANDBY; + } else { + /* retry after a delay. */ + con->state = CEPH_CON_S_PREOPEN; + if (!con->delay) { + con->delay = BASE_DELAY_INTERVAL; + } else if (con->delay < MAX_DELAY_INTERVAL) { + con->delay *= 2; + if (con->delay > MAX_DELAY_INTERVAL) + con->delay = MAX_DELAY_INTERVAL; + } + ceph_con_flag_set(con, CEPH_CON_F_BACKOFF); + queue_con(con); + } +} + +void ceph_messenger_reset_nonce(struct ceph_messenger *msgr) +{ + u32 nonce = le32_to_cpu(msgr->inst.addr.nonce) + 1000000; + msgr->inst.addr.nonce = cpu_to_le32(nonce); + ceph_encode_my_addr(msgr); +} + +/* + * initialize a new messenger instance + */ +void ceph_messenger_init(struct ceph_messenger *msgr, + struct ceph_entity_addr *myaddr) +{ + spin_lock_init(&msgr->global_seq_lock); + + if (myaddr) { + memcpy(&msgr->inst.addr.in_addr, &myaddr->in_addr, + sizeof(msgr->inst.addr.in_addr)); + ceph_addr_set_port(&msgr->inst.addr, 0); + } + + /* + * Since nautilus, clients are identified using type ANY. 
+ * For msgr1, ceph_encode_banner_addr() munges it to NONE. + */ + msgr->inst.addr.type = CEPH_ENTITY_ADDR_TYPE_ANY; + + /* generate a random non-zero nonce */ + do { + get_random_bytes(&msgr->inst.addr.nonce, + sizeof(msgr->inst.addr.nonce)); + } while (!msgr->inst.addr.nonce); + ceph_encode_my_addr(msgr); + + atomic_set(&msgr->stopping, 0); + write_pnet(&msgr->net, get_net(current->nsproxy->net_ns)); + + dout("%s %p\n", __func__, msgr); +} + +void ceph_messenger_fini(struct ceph_messenger *msgr) +{ + put_net(read_pnet(&msgr->net)); +} + +static void msg_con_set(struct ceph_msg *msg, struct ceph_connection *con) +{ + if (msg->con) + msg->con->ops->put(msg->con); + + msg->con = con ? con->ops->get(con) : NULL; + BUG_ON(msg->con != con); +} + +static void clear_standby(struct ceph_connection *con) +{ + /* come back from STANDBY? */ + if (con->state == CEPH_CON_S_STANDBY) { + dout("clear_standby %p and ++connect_seq\n", con); + con->state = CEPH_CON_S_PREOPEN; + con->v1.connect_seq++; + WARN_ON(ceph_con_flag_test(con, CEPH_CON_F_WRITE_PENDING)); + WARN_ON(ceph_con_flag_test(con, CEPH_CON_F_KEEPALIVE_PENDING)); + } +} + +/* + * Queue up an outgoing message on the given connection. + * + * Consumes a ref on @msg. + */ +void ceph_con_send(struct ceph_connection *con, struct ceph_msg *msg) +{ + /* set src+dst */ + msg->hdr.src = con->msgr->inst.name; + BUG_ON(msg->front.iov_len != le32_to_cpu(msg->hdr.front_len)); + msg->needs_out_seq = true; + + mutex_lock(&con->mutex); + + if (con->state == CEPH_CON_S_CLOSED) { + dout("con_send %p closed, dropping %p\n", con, msg); + ceph_msg_put(msg); + mutex_unlock(&con->mutex); + return; + } + + msg_con_set(msg, con); + + BUG_ON(!list_empty(&msg->list_head)); + list_add_tail(&msg->list_head, &con->out_queue); + dout("----- %p to %s%lld %d=%s len %d+%d+%d -----\n", msg, + ENTITY_NAME(con->peer_name), le16_to_cpu(msg->hdr.type), + ceph_msg_type_name(le16_to_cpu(msg->hdr.type)), + le32_to_cpu(msg->hdr.front_len), + le32_to_cpu(msg->hdr.middle_len), + le32_to_cpu(msg->hdr.data_len)); + + clear_standby(con); + mutex_unlock(&con->mutex); + + /* if there wasn't anything waiting to send before, queue + * new work */ + if (!ceph_con_flag_test_and_set(con, CEPH_CON_F_WRITE_PENDING)) + queue_con(con); +} +EXPORT_SYMBOL(ceph_con_send); + +/* + * Revoke a message that was previously queued for send + */ +void ceph_msg_revoke(struct ceph_msg *msg) +{ + struct ceph_connection *con = msg->con; + + if (!con) { + dout("%s msg %p null con\n", __func__, msg); + return; /* Message not in our possession */ + } + + mutex_lock(&con->mutex); + if (list_empty(&msg->list_head)) { + WARN_ON(con->out_msg == msg); + dout("%s con %p msg %p not linked\n", __func__, con, msg); + mutex_unlock(&con->mutex); + return; + } + + dout("%s con %p msg %p was linked\n", __func__, con, msg); + msg->hdr.seq = 0; + ceph_msg_remove(msg); + + if (con->out_msg == msg) { + WARN_ON(con->state != CEPH_CON_S_OPEN); + dout("%s con %p msg %p was sending\n", __func__, con, msg); + if (ceph_msgr2(from_msgr(con->msgr))) + ceph_con_v2_revoke(con); + else + ceph_con_v1_revoke(con); + ceph_msg_put(con->out_msg); + con->out_msg = NULL; + } else { + dout("%s con %p msg %p not current, out_msg %p\n", __func__, + con, msg, con->out_msg); + } + mutex_unlock(&con->mutex); +} + +/* + * Revoke a message that we may be reading data into + */ +void ceph_msg_revoke_incoming(struct ceph_msg *msg) +{ + struct ceph_connection *con = msg->con; + + if (!con) { + dout("%s msg %p null con\n", __func__, msg); + return; /* Message not 
in our possession */ + } + + mutex_lock(&con->mutex); + if (con->in_msg == msg) { + WARN_ON(con->state != CEPH_CON_S_OPEN); + dout("%s con %p msg %p was recving\n", __func__, con, msg); + if (ceph_msgr2(from_msgr(con->msgr))) + ceph_con_v2_revoke_incoming(con); + else + ceph_con_v1_revoke_incoming(con); + ceph_msg_put(con->in_msg); + con->in_msg = NULL; + } else { + dout("%s con %p msg %p not current, in_msg %p\n", __func__, + con, msg, con->in_msg); + } + mutex_unlock(&con->mutex); +} + +/* + * Queue a keepalive byte to ensure the tcp connection is alive. + */ +void ceph_con_keepalive(struct ceph_connection *con) +{ + dout("con_keepalive %p\n", con); + mutex_lock(&con->mutex); + clear_standby(con); + ceph_con_flag_set(con, CEPH_CON_F_KEEPALIVE_PENDING); + mutex_unlock(&con->mutex); + + if (!ceph_con_flag_test_and_set(con, CEPH_CON_F_WRITE_PENDING)) + queue_con(con); +} +EXPORT_SYMBOL(ceph_con_keepalive); + +bool ceph_con_keepalive_expired(struct ceph_connection *con, + unsigned long interval) +{ + if (interval > 0 && + (con->peer_features & CEPH_FEATURE_MSGR_KEEPALIVE2)) { + struct timespec64 now; + struct timespec64 ts; + ktime_get_real_ts64(&now); + jiffies_to_timespec64(interval, &ts); + ts = timespec64_add(con->last_keepalive_ack, ts); + return timespec64_compare(&now, &ts) >= 0; + } + return false; +} + +static struct ceph_msg_data *ceph_msg_data_add(struct ceph_msg *msg) +{ + BUG_ON(msg->num_data_items >= msg->max_data_items); + return &msg->data[msg->num_data_items++]; +} + +static void ceph_msg_data_destroy(struct ceph_msg_data *data) +{ + if (data->type == CEPH_MSG_DATA_PAGES && data->own_pages) { + int num_pages = calc_pages_for(data->alignment, data->length); + ceph_release_page_vector(data->pages, num_pages); + } else if (data->type == CEPH_MSG_DATA_PAGELIST) { + ceph_pagelist_release(data->pagelist); + } +} + +void ceph_msg_data_add_pages(struct ceph_msg *msg, struct page **pages, + size_t length, size_t alignment, bool own_pages) +{ + struct ceph_msg_data *data; + + BUG_ON(!pages); + BUG_ON(!length); + + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_PAGES; + data->pages = pages; + data->length = length; + data->alignment = alignment & ~PAGE_MASK; + data->own_pages = own_pages; + + msg->data_length += length; +} +EXPORT_SYMBOL(ceph_msg_data_add_pages); + +void ceph_msg_data_add_pagelist(struct ceph_msg *msg, + struct ceph_pagelist *pagelist) +{ + struct ceph_msg_data *data; + + BUG_ON(!pagelist); + BUG_ON(!pagelist->length); + + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_PAGELIST; + refcount_inc(&pagelist->refcnt); + data->pagelist = pagelist; + + msg->data_length += pagelist->length; +} +EXPORT_SYMBOL(ceph_msg_data_add_pagelist); + +#ifdef CONFIG_BLOCK +void ceph_msg_data_add_bio(struct ceph_msg *msg, struct ceph_bio_iter *bio_pos, + u32 length) +{ + struct ceph_msg_data *data; + + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_BIO; + data->bio_pos = *bio_pos; + data->bio_length = length; + + msg->data_length += length; +} +EXPORT_SYMBOL(ceph_msg_data_add_bio); +#endif /* CONFIG_BLOCK */ + +void ceph_msg_data_add_bvecs(struct ceph_msg *msg, + struct ceph_bvec_iter *bvec_pos) +{ + struct ceph_msg_data *data; + + data = ceph_msg_data_add(msg); + data->type = CEPH_MSG_DATA_BVECS; + data->bvec_pos = *bvec_pos; + + msg->data_length += bvec_pos->iter.bi_size; +} +EXPORT_SYMBOL(ceph_msg_data_add_bvecs); + +/* + * construct a new message with given type, size + * the new msg has a ref count of 1. 
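+ *
+ * A minimal allocation sketch (illustrative only; type and front_len
+ * are whatever the caller's payload requires):
+ *
+ *    msg = ceph_msg_new(type, front_len, GFP_NOFS, false);
+ *    if (!msg)
+ *        return -ENOMEM;
+ *    ... fill msg->front.iov_base with front_len bytes ...
+ *    ceph_con_send(con, msg);     (the connection takes the ref)
+ *
+ * Callers that keep their own reference take it with ceph_msg_get()
+ * and drop it with ceph_msg_put().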
+ */ +struct ceph_msg *ceph_msg_new2(int type, int front_len, int max_data_items, + gfp_t flags, bool can_fail) +{ + struct ceph_msg *m; + + m = kmem_cache_zalloc(ceph_msg_cache, flags); + if (m == NULL) + goto out; + + m->hdr.type = cpu_to_le16(type); + m->hdr.priority = cpu_to_le16(CEPH_MSG_PRIO_DEFAULT); + m->hdr.front_len = cpu_to_le32(front_len); + + INIT_LIST_HEAD(&m->list_head); + kref_init(&m->kref); + + /* front */ + if (front_len) { + m->front.iov_base = kvmalloc(front_len, flags); + if (m->front.iov_base == NULL) { + dout("ceph_msg_new can't allocate %d bytes\n", + front_len); + goto out2; + } + } else { + m->front.iov_base = NULL; + } + m->front_alloc_len = m->front.iov_len = front_len; + + if (max_data_items) { + m->data = kmalloc_array(max_data_items, sizeof(*m->data), + flags); + if (!m->data) + goto out2; + + m->max_data_items = max_data_items; + } + + dout("ceph_msg_new %p front %d\n", m, front_len); + return m; + +out2: + ceph_msg_put(m); +out: + if (!can_fail) { + pr_err("msg_new can't create type %d front %d\n", type, + front_len); + WARN_ON(1); + } else { + dout("msg_new can't create type %d front %d\n", type, + front_len); + } + return NULL; +} +EXPORT_SYMBOL(ceph_msg_new2); + +struct ceph_msg *ceph_msg_new(int type, int front_len, gfp_t flags, + bool can_fail) +{ + return ceph_msg_new2(type, front_len, 0, flags, can_fail); +} +EXPORT_SYMBOL(ceph_msg_new); + +/* + * Allocate "middle" portion of a message, if it is needed and wasn't + * allocated by alloc_msg. This allows us to read a small fixed-size + * per-type header in the front and then gracefully fail (i.e., + * propagate the error to the caller based on info in the front) when + * the middle is too large. + */ +static int ceph_alloc_middle(struct ceph_connection *con, struct ceph_msg *msg) +{ + int type = le16_to_cpu(msg->hdr.type); + int middle_len = le32_to_cpu(msg->hdr.middle_len); + + dout("alloc_middle %p type %d %s middle_len %d\n", msg, type, + ceph_msg_type_name(type), middle_len); + BUG_ON(!middle_len); + BUG_ON(msg->middle); + + msg->middle = ceph_buffer_new(middle_len, GFP_NOFS); + if (!msg->middle) + return -ENOMEM; + return 0; +} + +/* + * Allocate a message for receiving an incoming message on a + * connection, and save the result in con->in_msg. Uses the + * connection's private alloc_msg op if available. + * + * Returns 0 on success, or a negative error code. + * + * On success, if we set *skip = 1: + * - the next message should be skipped and ignored. + * - con->in_msg == NULL + * or if we set *skip = 0: + * - con->in_msg is non-null. + * On error (ENOMEM, EAGAIN, ...), + * - con->in_msg == NULL + */ +int ceph_con_in_msg_alloc(struct ceph_connection *con, + struct ceph_msg_header *hdr, int *skip) +{ + int middle_len = le32_to_cpu(hdr->middle_len); + struct ceph_msg *msg; + int ret = 0; + + BUG_ON(con->in_msg != NULL); + BUG_ON(!con->ops->alloc_msg); + + mutex_unlock(&con->mutex); + msg = con->ops->alloc_msg(con, hdr, skip); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_OPEN) { + if (msg) + ceph_msg_put(msg); + return -EAGAIN; + } + if (msg) { + BUG_ON(*skip); + msg_con_set(msg, con); + con->in_msg = msg; + } else { + /* + * Null message pointer means either we should skip + * this message or we couldn't allocate memory. The + * former is not an error. 
+ */ + if (*skip) + return 0; + + con->error_msg = "error allocating memory for incoming message"; + return -ENOMEM; + } + memcpy(&con->in_msg->hdr, hdr, sizeof(*hdr)); + + if (middle_len && !con->in_msg->middle) { + ret = ceph_alloc_middle(con, con->in_msg); + if (ret < 0) { + ceph_msg_put(con->in_msg); + con->in_msg = NULL; + } + } + + return ret; +} + +void ceph_con_get_out_msg(struct ceph_connection *con) +{ + struct ceph_msg *msg; + + BUG_ON(list_empty(&con->out_queue)); + msg = list_first_entry(&con->out_queue, struct ceph_msg, list_head); + WARN_ON(msg->con != con); + + /* + * Put the message on "sent" list using a ref from ceph_con_send(). + * It is put when the message is acked or revoked. + */ + list_move_tail(&msg->list_head, &con->out_sent); + + /* + * Only assign outgoing seq # if we haven't sent this message + * yet. If it is requeued, resend with it's original seq. + */ + if (msg->needs_out_seq) { + msg->hdr.seq = cpu_to_le64(++con->out_seq); + msg->needs_out_seq = false; + + if (con->ops->reencode_message) + con->ops->reencode_message(msg); + } + + /* + * Get a ref for out_msg. It is put when we are done sending the + * message or in case of a fault. + */ + WARN_ON(con->out_msg); + con->out_msg = ceph_msg_get(msg); +} + +/* + * Free a generically kmalloc'd message. + */ +static void ceph_msg_free(struct ceph_msg *m) +{ + dout("%s %p\n", __func__, m); + kvfree(m->front.iov_base); + kfree(m->data); + kmem_cache_free(ceph_msg_cache, m); +} + +static void ceph_msg_release(struct kref *kref) +{ + struct ceph_msg *m = container_of(kref, struct ceph_msg, kref); + int i; + + dout("%s %p\n", __func__, m); + WARN_ON(!list_empty(&m->list_head)); + + msg_con_set(m, NULL); + + /* drop middle, data, if any */ + if (m->middle) { + ceph_buffer_put(m->middle); + m->middle = NULL; + } + + for (i = 0; i < m->num_data_items; i++) + ceph_msg_data_destroy(&m->data[i]); + + if (m->pool) + ceph_msgpool_put(m->pool, m); + else + ceph_msg_free(m); +} + +struct ceph_msg *ceph_msg_get(struct ceph_msg *msg) +{ + dout("%s %p (was %d)\n", __func__, msg, + kref_read(&msg->kref)); + kref_get(&msg->kref); + return msg; +} +EXPORT_SYMBOL(ceph_msg_get); + +void ceph_msg_put(struct ceph_msg *msg) +{ + dout("%s %p (was %d)\n", __func__, msg, + kref_read(&msg->kref)); + kref_put(&msg->kref, ceph_msg_release); +} +EXPORT_SYMBOL(ceph_msg_put); + +void ceph_msg_dump(struct ceph_msg *msg) +{ + pr_debug("msg_dump %p (front_alloc_len %d length %zd)\n", msg, + msg->front_alloc_len, msg->data_length); + print_hex_dump(KERN_DEBUG, "header: ", + DUMP_PREFIX_OFFSET, 16, 1, + &msg->hdr, sizeof(msg->hdr), true); + print_hex_dump(KERN_DEBUG, " front: ", + DUMP_PREFIX_OFFSET, 16, 1, + msg->front.iov_base, msg->front.iov_len, true); + if (msg->middle) + print_hex_dump(KERN_DEBUG, "middle: ", + DUMP_PREFIX_OFFSET, 16, 1, + msg->middle->vec.iov_base, + msg->middle->vec.iov_len, true); + print_hex_dump(KERN_DEBUG, "footer: ", + DUMP_PREFIX_OFFSET, 16, 1, + &msg->footer, sizeof(msg->footer), true); +} +EXPORT_SYMBOL(ceph_msg_dump); diff --git a/net/ceph/messenger_v1.c b/net/ceph/messenger_v1.c new file mode 100644 index 000000000..d1787d7d3 --- /dev/null +++ b/net/ceph/messenger_v1.c @@ -0,0 +1,1548 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +/* static tag bytes (protocol control messages) */ +static char tag_msg = CEPH_MSGR_TAG_MSG; +static char tag_ack = CEPH_MSGR_TAG_ACK; +static char tag_keepalive = CEPH_MSGR_TAG_KEEPALIVE; 
+static char tag_keepalive2 = CEPH_MSGR_TAG_KEEPALIVE2; + +/* + * If @buf is NULL, discard up to @len bytes. + */ +static int ceph_tcp_recvmsg(struct socket *sock, void *buf, size_t len) +{ + struct kvec iov = {buf, len}; + struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL }; + int r; + + if (!buf) + msg.msg_flags |= MSG_TRUNC; + + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, len); + r = sock_recvmsg(sock, &msg, msg.msg_flags); + if (r == -EAGAIN) + r = 0; + return r; +} + +static int ceph_tcp_recvpage(struct socket *sock, struct page *page, + int page_offset, size_t length) +{ + struct bio_vec bvec = { + .bv_page = page, + .bv_offset = page_offset, + .bv_len = length + }; + struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL }; + int r; + + BUG_ON(page_offset + length > PAGE_SIZE); + iov_iter_bvec(&msg.msg_iter, ITER_DEST, &bvec, 1, length); + r = sock_recvmsg(sock, &msg, msg.msg_flags); + if (r == -EAGAIN) + r = 0; + return r; +} + +/* + * write something. @more is true if caller will be sending more data + * shortly. + */ +static int ceph_tcp_sendmsg(struct socket *sock, struct kvec *iov, + size_t kvlen, size_t len, bool more) +{ + struct msghdr msg = { .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL }; + int r; + + if (more) + msg.msg_flags |= MSG_MORE; + else + msg.msg_flags |= MSG_EOR; /* superfluous, but what the hell */ + + r = kernel_sendmsg(sock, &msg, iov, kvlen, len); + if (r == -EAGAIN) + r = 0; + return r; +} + +/* + * @more: either or both of MSG_MORE and MSG_SENDPAGE_NOTLAST + */ +static int ceph_tcp_sendpage(struct socket *sock, struct page *page, + int offset, size_t size, int more) +{ + ssize_t (*sendpage)(struct socket *sock, struct page *page, + int offset, size_t size, int flags); + int flags = MSG_DONTWAIT | MSG_NOSIGNAL | more; + int ret; + + /* + * sendpage cannot properly handle pages with page_count == 0, + * we need to fall back to sendmsg if that's the case. + * + * Same goes for slab pages: skb_can_coalesce() allows + * coalescing neighboring slab objects into a single frag which + * triggers one of hardened usercopy checks. + */ + if (sendpage_ok(page)) + sendpage = sock->ops->sendpage; + else + sendpage = sock_no_sendpage; + + ret = sendpage(sock, page, offset, size, flags); + if (ret == -EAGAIN) + ret = 0; + + return ret; +} + +static void con_out_kvec_reset(struct ceph_connection *con) +{ + BUG_ON(con->v1.out_skip); + + con->v1.out_kvec_left = 0; + con->v1.out_kvec_bytes = 0; + con->v1.out_kvec_cur = &con->v1.out_kvec[0]; +} + +static void con_out_kvec_add(struct ceph_connection *con, + size_t size, void *data) +{ + int index = con->v1.out_kvec_left; + + BUG_ON(con->v1.out_skip); + BUG_ON(index >= ARRAY_SIZE(con->v1.out_kvec)); + + con->v1.out_kvec[index].iov_len = size; + con->v1.out_kvec[index].iov_base = data; + con->v1.out_kvec_left++; + con->v1.out_kvec_bytes += size; +} + +/* + * Chop off a kvec from the end. Return residual number of bytes for + * that kvec, i.e. how many bytes would have been written if the kvec + * hadn't been nuked. + */ +static int con_out_kvec_skip(struct ceph_connection *con) +{ + int skip = 0; + + if (con->v1.out_kvec_bytes > 0) { + skip = con->v1.out_kvec_cur[con->v1.out_kvec_left - 1].iov_len; + BUG_ON(con->v1.out_kvec_bytes < skip); + BUG_ON(!con->v1.out_kvec_left); + con->v1.out_kvec_bytes -= skip; + con->v1.out_kvec_left--; + } + + return skip; +} + +static size_t sizeof_footer(struct ceph_connection *con) +{ + return (con->peer_features & CEPH_FEATURE_MSG_AUTH) ? 
+ sizeof(struct ceph_msg_footer) : + sizeof(struct ceph_msg_footer_old); +} + +static void prepare_message_data(struct ceph_msg *msg, u32 data_len) +{ + /* Initialize data cursor */ + + ceph_msg_data_cursor_init(&msg->cursor, msg, data_len); +} + +/* + * Prepare footer for currently outgoing message, and finish things + * off. Assumes out_kvec* are already valid.. we just add on to the end. + */ +static void prepare_write_message_footer(struct ceph_connection *con) +{ + struct ceph_msg *m = con->out_msg; + + m->footer.flags |= CEPH_MSG_FOOTER_COMPLETE; + + dout("prepare_write_message_footer %p\n", con); + con_out_kvec_add(con, sizeof_footer(con), &m->footer); + if (con->peer_features & CEPH_FEATURE_MSG_AUTH) { + if (con->ops->sign_message) + con->ops->sign_message(m); + else + m->footer.sig = 0; + } else { + m->old_footer.flags = m->footer.flags; + } + con->v1.out_more = m->more_to_follow; + con->v1.out_msg_done = true; +} + +/* + * Prepare headers for the next outgoing message. + */ +static void prepare_write_message(struct ceph_connection *con) +{ + struct ceph_msg *m; + u32 crc; + + con_out_kvec_reset(con); + con->v1.out_msg_done = false; + + /* Sneak an ack in there first? If we can get it into the same + * TCP packet that's a good thing. */ + if (con->in_seq > con->in_seq_acked) { + con->in_seq_acked = con->in_seq; + con_out_kvec_add(con, sizeof (tag_ack), &tag_ack); + con->v1.out_temp_ack = cpu_to_le64(con->in_seq_acked); + con_out_kvec_add(con, sizeof(con->v1.out_temp_ack), + &con->v1.out_temp_ack); + } + + ceph_con_get_out_msg(con); + m = con->out_msg; + + dout("prepare_write_message %p seq %lld type %d len %d+%d+%zd\n", + m, con->out_seq, le16_to_cpu(m->hdr.type), + le32_to_cpu(m->hdr.front_len), le32_to_cpu(m->hdr.middle_len), + m->data_length); + WARN_ON(m->front.iov_len != le32_to_cpu(m->hdr.front_len)); + WARN_ON(m->data_length != le32_to_cpu(m->hdr.data_len)); + + /* tag + hdr + front + middle */ + con_out_kvec_add(con, sizeof (tag_msg), &tag_msg); + con_out_kvec_add(con, sizeof(con->v1.out_hdr), &con->v1.out_hdr); + con_out_kvec_add(con, m->front.iov_len, m->front.iov_base); + + if (m->middle) + con_out_kvec_add(con, m->middle->vec.iov_len, + m->middle->vec.iov_base); + + /* fill in hdr crc and finalize hdr */ + crc = crc32c(0, &m->hdr, offsetof(struct ceph_msg_header, crc)); + con->out_msg->hdr.crc = cpu_to_le32(crc); + memcpy(&con->v1.out_hdr, &con->out_msg->hdr, sizeof(con->v1.out_hdr)); + + /* fill in front and middle crc, footer */ + crc = crc32c(0, m->front.iov_base, m->front.iov_len); + con->out_msg->footer.front_crc = cpu_to_le32(crc); + if (m->middle) { + crc = crc32c(0, m->middle->vec.iov_base, + m->middle->vec.iov_len); + con->out_msg->footer.middle_crc = cpu_to_le32(crc); + } else + con->out_msg->footer.middle_crc = 0; + dout("%s front_crc %u middle_crc %u\n", __func__, + le32_to_cpu(con->out_msg->footer.front_crc), + le32_to_cpu(con->out_msg->footer.middle_crc)); + con->out_msg->footer.flags = 0; + + /* is there a data payload? */ + con->out_msg->footer.data_crc = 0; + if (m->data_length) { + prepare_message_data(con->out_msg, m->data_length); + con->v1.out_more = 1; /* data + footer will follow */ + } else { + /* no, queue up footer too and be done */ + prepare_write_message_footer(con); + } + + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +/* + * Prepare an ack. 
+ */ +static void prepare_write_ack(struct ceph_connection *con) +{ + dout("prepare_write_ack %p %llu -> %llu\n", con, + con->in_seq_acked, con->in_seq); + con->in_seq_acked = con->in_seq; + + con_out_kvec_reset(con); + + con_out_kvec_add(con, sizeof (tag_ack), &tag_ack); + + con->v1.out_temp_ack = cpu_to_le64(con->in_seq_acked); + con_out_kvec_add(con, sizeof(con->v1.out_temp_ack), + &con->v1.out_temp_ack); + + con->v1.out_more = 1; /* more will follow.. eventually.. */ + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +/* + * Prepare to share the seq during handshake + */ +static void prepare_write_seq(struct ceph_connection *con) +{ + dout("prepare_write_seq %p %llu -> %llu\n", con, + con->in_seq_acked, con->in_seq); + con->in_seq_acked = con->in_seq; + + con_out_kvec_reset(con); + + con->v1.out_temp_ack = cpu_to_le64(con->in_seq_acked); + con_out_kvec_add(con, sizeof(con->v1.out_temp_ack), + &con->v1.out_temp_ack); + + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +/* + * Prepare to write keepalive byte. + */ +static void prepare_write_keepalive(struct ceph_connection *con) +{ + dout("prepare_write_keepalive %p\n", con); + con_out_kvec_reset(con); + if (con->peer_features & CEPH_FEATURE_MSGR_KEEPALIVE2) { + struct timespec64 now; + + ktime_get_real_ts64(&now); + con_out_kvec_add(con, sizeof(tag_keepalive2), &tag_keepalive2); + ceph_encode_timespec64(&con->v1.out_temp_keepalive2, &now); + con_out_kvec_add(con, sizeof(con->v1.out_temp_keepalive2), + &con->v1.out_temp_keepalive2); + } else { + con_out_kvec_add(con, sizeof(tag_keepalive), &tag_keepalive); + } + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +/* + * Connection negotiation. + */ + +static int get_connect_authorizer(struct ceph_connection *con) +{ + struct ceph_auth_handshake *auth; + int auth_proto; + + if (!con->ops->get_authorizer) { + con->v1.auth = NULL; + con->v1.out_connect.authorizer_protocol = CEPH_AUTH_UNKNOWN; + con->v1.out_connect.authorizer_len = 0; + return 0; + } + + auth = con->ops->get_authorizer(con, &auth_proto, con->v1.auth_retry); + if (IS_ERR(auth)) + return PTR_ERR(auth); + + con->v1.auth = auth; + con->v1.out_connect.authorizer_protocol = cpu_to_le32(auth_proto); + con->v1.out_connect.authorizer_len = + cpu_to_le32(auth->authorizer_buf_len); + return 0; +} + +/* + * We connected to a peer and are saying hello. 
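+ * The banner is the fixed CEPH_BANNER string followed immediately by
+ * our encoded address (msgr->my_enc_addr), queued as two kvecs below.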
+ */ +static void prepare_write_banner(struct ceph_connection *con) +{ + con_out_kvec_add(con, strlen(CEPH_BANNER), CEPH_BANNER); + con_out_kvec_add(con, sizeof (con->msgr->my_enc_addr), + &con->msgr->my_enc_addr); + + con->v1.out_more = 0; + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +static void __prepare_write_connect(struct ceph_connection *con) +{ + con_out_kvec_add(con, sizeof(con->v1.out_connect), + &con->v1.out_connect); + if (con->v1.auth) + con_out_kvec_add(con, con->v1.auth->authorizer_buf_len, + con->v1.auth->authorizer_buf); + + con->v1.out_more = 0; + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); +} + +static int prepare_write_connect(struct ceph_connection *con) +{ + unsigned int global_seq = ceph_get_global_seq(con->msgr, 0); + int proto; + int ret; + + switch (con->peer_name.type) { + case CEPH_ENTITY_TYPE_MON: + proto = CEPH_MONC_PROTOCOL; + break; + case CEPH_ENTITY_TYPE_OSD: + proto = CEPH_OSDC_PROTOCOL; + break; + case CEPH_ENTITY_TYPE_MDS: + proto = CEPH_MDSC_PROTOCOL; + break; + default: + BUG(); + } + + dout("prepare_write_connect %p cseq=%d gseq=%d proto=%d\n", con, + con->v1.connect_seq, global_seq, proto); + + con->v1.out_connect.features = + cpu_to_le64(from_msgr(con->msgr)->supported_features); + con->v1.out_connect.host_type = cpu_to_le32(CEPH_ENTITY_TYPE_CLIENT); + con->v1.out_connect.connect_seq = cpu_to_le32(con->v1.connect_seq); + con->v1.out_connect.global_seq = cpu_to_le32(global_seq); + con->v1.out_connect.protocol_version = cpu_to_le32(proto); + con->v1.out_connect.flags = 0; + + ret = get_connect_authorizer(con); + if (ret) + return ret; + + __prepare_write_connect(con); + return 0; +} + +/* + * write as much of pending kvecs to the socket as we can. + * 1 -> done + * 0 -> socket full, but more to do + * <0 -> error + */ +static int write_partial_kvec(struct ceph_connection *con) +{ + int ret; + + dout("write_partial_kvec %p %d left\n", con, con->v1.out_kvec_bytes); + while (con->v1.out_kvec_bytes > 0) { + ret = ceph_tcp_sendmsg(con->sock, con->v1.out_kvec_cur, + con->v1.out_kvec_left, + con->v1.out_kvec_bytes, + con->v1.out_more); + if (ret <= 0) + goto out; + con->v1.out_kvec_bytes -= ret; + if (!con->v1.out_kvec_bytes) + break; /* done */ + + /* account for full iov entries consumed */ + while (ret >= con->v1.out_kvec_cur->iov_len) { + BUG_ON(!con->v1.out_kvec_left); + ret -= con->v1.out_kvec_cur->iov_len; + con->v1.out_kvec_cur++; + con->v1.out_kvec_left--; + } + /* and for a partially-consumed entry */ + if (ret) { + con->v1.out_kvec_cur->iov_len -= ret; + con->v1.out_kvec_cur->iov_base += ret; + } + } + con->v1.out_kvec_left = 0; + ret = 1; +out: + dout("write_partial_kvec %p %d left in %d kvecs ret = %d\n", con, + con->v1.out_kvec_bytes, con->v1.out_kvec_left, ret); + return ret; /* done! */ +} + +/* + * Write as much message data payload as we can. If we finish, queue + * up the footer. + * 1 -> done, footer is now queued in out_kvec[]. + * 0 -> socket full, but more to do + * <0 -> error + */ +static int write_partial_message_data(struct ceph_connection *con) +{ + struct ceph_msg *msg = con->out_msg; + struct ceph_msg_data_cursor *cursor = &msg->cursor; + bool do_datacrc = !ceph_test_opt(from_msgr(con->msgr), NOCRC); + int more = MSG_MORE | MSG_SENDPAGE_NOTLAST; + u32 crc; + + dout("%s %p msg %p\n", __func__, con, msg); + + if (!msg->num_data_items) + return -EINVAL; + + /* + * Iterate through each page that contains data to be + * written, and send as much as possible for each. 
+ * + * If we are calculating the data crc (the default), we will + * need to map the page. If we have no pages, they have + * been revoked, so use the zero page. + */ + crc = do_datacrc ? le32_to_cpu(msg->footer.data_crc) : 0; + while (cursor->total_resid) { + struct page *page; + size_t page_offset; + size_t length; + int ret; + + if (!cursor->resid) { + ceph_msg_data_advance(cursor, 0); + continue; + } + + page = ceph_msg_data_next(cursor, &page_offset, &length); + if (length == cursor->total_resid) + more = MSG_MORE; + ret = ceph_tcp_sendpage(con->sock, page, page_offset, length, + more); + if (ret <= 0) { + if (do_datacrc) + msg->footer.data_crc = cpu_to_le32(crc); + + return ret; + } + if (do_datacrc && cursor->need_crc) + crc = ceph_crc32c_page(crc, page, page_offset, length); + ceph_msg_data_advance(cursor, (size_t)ret); + } + + dout("%s %p msg %p done\n", __func__, con, msg); + + /* prepare and queue up footer, too */ + if (do_datacrc) + msg->footer.data_crc = cpu_to_le32(crc); + else + msg->footer.flags |= CEPH_MSG_FOOTER_NOCRC; + con_out_kvec_reset(con); + prepare_write_message_footer(con); + + return 1; /* must return > 0 to indicate success */ +} + +/* + * write some zeros + */ +static int write_partial_skip(struct ceph_connection *con) +{ + int more = MSG_MORE | MSG_SENDPAGE_NOTLAST; + int ret; + + dout("%s %p %d left\n", __func__, con, con->v1.out_skip); + while (con->v1.out_skip > 0) { + size_t size = min(con->v1.out_skip, (int)PAGE_SIZE); + + if (size == con->v1.out_skip) + more = MSG_MORE; + ret = ceph_tcp_sendpage(con->sock, ceph_zero_page, 0, size, + more); + if (ret <= 0) + goto out; + con->v1.out_skip -= ret; + } + ret = 1; +out: + return ret; +} + +/* + * Prepare to read connection handshake, or an ack. + */ +static void prepare_read_banner(struct ceph_connection *con) +{ + dout("prepare_read_banner %p\n", con); + con->v1.in_base_pos = 0; +} + +static void prepare_read_connect(struct ceph_connection *con) +{ + dout("prepare_read_connect %p\n", con); + con->v1.in_base_pos = 0; +} + +static void prepare_read_ack(struct ceph_connection *con) +{ + dout("prepare_read_ack %p\n", con); + con->v1.in_base_pos = 0; +} + +static void prepare_read_seq(struct ceph_connection *con) +{ + dout("prepare_read_seq %p\n", con); + con->v1.in_base_pos = 0; + con->v1.in_tag = CEPH_MSGR_TAG_SEQ; +} + +static void prepare_read_tag(struct ceph_connection *con) +{ + dout("prepare_read_tag %p\n", con); + con->v1.in_base_pos = 0; + con->v1.in_tag = CEPH_MSGR_TAG_READY; +} + +static void prepare_read_keepalive_ack(struct ceph_connection *con) +{ + dout("prepare_read_keepalive_ack %p\n", con); + con->v1.in_base_pos = 0; +} + +/* + * Prepare to read a message. 
+ */ +static int prepare_read_message(struct ceph_connection *con) +{ + dout("prepare_read_message %p\n", con); + BUG_ON(con->in_msg != NULL); + con->v1.in_base_pos = 0; + con->in_front_crc = con->in_middle_crc = con->in_data_crc = 0; + return 0; +} + +static int read_partial(struct ceph_connection *con, + int end, int size, void *object) +{ + while (con->v1.in_base_pos < end) { + int left = end - con->v1.in_base_pos; + int have = size - left; + int ret = ceph_tcp_recvmsg(con->sock, object + have, left); + if (ret <= 0) + return ret; + con->v1.in_base_pos += ret; + } + return 1; +} + +/* + * Read all or part of the connect-side handshake on a new connection + */ +static int read_partial_banner(struct ceph_connection *con) +{ + int size; + int end; + int ret; + + dout("read_partial_banner %p at %d\n", con, con->v1.in_base_pos); + + /* peer's banner */ + size = strlen(CEPH_BANNER); + end = size; + ret = read_partial(con, end, size, con->v1.in_banner); + if (ret <= 0) + goto out; + + size = sizeof(con->v1.actual_peer_addr); + end += size; + ret = read_partial(con, end, size, &con->v1.actual_peer_addr); + if (ret <= 0) + goto out; + ceph_decode_banner_addr(&con->v1.actual_peer_addr); + + size = sizeof(con->v1.peer_addr_for_me); + end += size; + ret = read_partial(con, end, size, &con->v1.peer_addr_for_me); + if (ret <= 0) + goto out; + ceph_decode_banner_addr(&con->v1.peer_addr_for_me); + +out: + return ret; +} + +static int read_partial_connect(struct ceph_connection *con) +{ + int size; + int end; + int ret; + + dout("read_partial_connect %p at %d\n", con, con->v1.in_base_pos); + + size = sizeof(con->v1.in_reply); + end = size; + ret = read_partial(con, end, size, &con->v1.in_reply); + if (ret <= 0) + goto out; + + if (con->v1.auth) { + size = le32_to_cpu(con->v1.in_reply.authorizer_len); + if (size > con->v1.auth->authorizer_reply_buf_len) { + pr_err("authorizer reply too big: %d > %zu\n", size, + con->v1.auth->authorizer_reply_buf_len); + ret = -EINVAL; + goto out; + } + + end += size; + ret = read_partial(con, end, size, + con->v1.auth->authorizer_reply_buf); + if (ret <= 0) + goto out; + } + + dout("read_partial_connect %p tag %d, con_seq = %u, g_seq = %u\n", + con, con->v1.in_reply.tag, + le32_to_cpu(con->v1.in_reply.connect_seq), + le32_to_cpu(con->v1.in_reply.global_seq)); +out: + return ret; +} + +/* + * Verify the hello banner looks okay. + */ +static int verify_hello(struct ceph_connection *con) +{ + if (memcmp(con->v1.in_banner, CEPH_BANNER, strlen(CEPH_BANNER))) { + pr_err("connect to %s got bad banner\n", + ceph_pr_addr(&con->peer_addr)); + con->error_msg = "protocol error, bad banner"; + return -1; + } + return 0; +} + +static int process_banner(struct ceph_connection *con) +{ + struct ceph_entity_addr *my_addr = &con->msgr->inst.addr; + + dout("process_banner on %p\n", con); + + if (verify_hello(con) < 0) + return -1; + + /* + * Make sure the other end is who we wanted. note that the other + * end may not yet know their ip address, so if it's 0.0.0.0, give + * them the benefit of the doubt. 
+ */ + if (memcmp(&con->peer_addr, &con->v1.actual_peer_addr, + sizeof(con->peer_addr)) != 0 && + !(ceph_addr_is_blank(&con->v1.actual_peer_addr) && + con->v1.actual_peer_addr.nonce == con->peer_addr.nonce)) { + pr_warn("wrong peer, want %s/%u, got %s/%u\n", + ceph_pr_addr(&con->peer_addr), + le32_to_cpu(con->peer_addr.nonce), + ceph_pr_addr(&con->v1.actual_peer_addr), + le32_to_cpu(con->v1.actual_peer_addr.nonce)); + con->error_msg = "wrong peer at address"; + return -1; + } + + /* + * did we learn our address? + */ + if (ceph_addr_is_blank(my_addr)) { + memcpy(&my_addr->in_addr, + &con->v1.peer_addr_for_me.in_addr, + sizeof(con->v1.peer_addr_for_me.in_addr)); + ceph_addr_set_port(my_addr, 0); + ceph_encode_my_addr(con->msgr); + dout("process_banner learned my addr is %s\n", + ceph_pr_addr(my_addr)); + } + + return 0; +} + +static int process_connect(struct ceph_connection *con) +{ + u64 sup_feat = from_msgr(con->msgr)->supported_features; + u64 req_feat = from_msgr(con->msgr)->required_features; + u64 server_feat = le64_to_cpu(con->v1.in_reply.features); + int ret; + + dout("process_connect on %p tag %d\n", con, con->v1.in_tag); + + if (con->v1.auth) { + int len = le32_to_cpu(con->v1.in_reply.authorizer_len); + + /* + * Any connection that defines ->get_authorizer() + * should also define ->add_authorizer_challenge() and + * ->verify_authorizer_reply(). + * + * See get_connect_authorizer(). + */ + if (con->v1.in_reply.tag == + CEPH_MSGR_TAG_CHALLENGE_AUTHORIZER) { + ret = con->ops->add_authorizer_challenge( + con, con->v1.auth->authorizer_reply_buf, len); + if (ret < 0) + return ret; + + con_out_kvec_reset(con); + __prepare_write_connect(con); + prepare_read_connect(con); + return 0; + } + + if (len) { + ret = con->ops->verify_authorizer_reply(con); + if (ret < 0) { + con->error_msg = "bad authorize reply"; + return ret; + } + } + } + + switch (con->v1.in_reply.tag) { + case CEPH_MSGR_TAG_FEATURES: + pr_err("%s%lld %s feature set mismatch," + " my %llx < server's %llx, missing %llx\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), + sup_feat, server_feat, server_feat & ~sup_feat); + con->error_msg = "missing required protocol features"; + return -1; + + case CEPH_MSGR_TAG_BADPROTOVER: + pr_err("%s%lld %s protocol version mismatch," + " my %d != server's %d\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), + le32_to_cpu(con->v1.out_connect.protocol_version), + le32_to_cpu(con->v1.in_reply.protocol_version)); + con->error_msg = "protocol version mismatch"; + return -1; + + case CEPH_MSGR_TAG_BADAUTHORIZER: + con->v1.auth_retry++; + dout("process_connect %p got BADAUTHORIZER attempt %d\n", con, + con->v1.auth_retry); + if (con->v1.auth_retry == 2) { + con->error_msg = "connect authorization failure"; + return -1; + } + con_out_kvec_reset(con); + ret = prepare_write_connect(con); + if (ret < 0) + return ret; + prepare_read_connect(con); + break; + + case CEPH_MSGR_TAG_RESETSESSION: + /* + * If we connected with a large connect_seq but the peer + * has no record of a session with us (no connection, or + * connect_seq == 0), they will send RESETSESION to indicate + * that they must have reset their session, and may have + * dropped messages. 
+ */ + dout("process_connect got RESET peer seq %u\n", + le32_to_cpu(con->v1.in_reply.connect_seq)); + pr_info("%s%lld %s session reset\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr)); + ceph_con_reset_session(con); + con_out_kvec_reset(con); + ret = prepare_write_connect(con); + if (ret < 0) + return ret; + prepare_read_connect(con); + + /* Tell ceph about it. */ + mutex_unlock(&con->mutex); + if (con->ops->peer_reset) + con->ops->peer_reset(con); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V1_CONNECT_MSG) + return -EAGAIN; + break; + + case CEPH_MSGR_TAG_RETRY_SESSION: + /* + * If we sent a smaller connect_seq than the peer has, try + * again with a larger value. + */ + dout("process_connect got RETRY_SESSION my seq %u, peer %u\n", + le32_to_cpu(con->v1.out_connect.connect_seq), + le32_to_cpu(con->v1.in_reply.connect_seq)); + con->v1.connect_seq = le32_to_cpu(con->v1.in_reply.connect_seq); + con_out_kvec_reset(con); + ret = prepare_write_connect(con); + if (ret < 0) + return ret; + prepare_read_connect(con); + break; + + case CEPH_MSGR_TAG_RETRY_GLOBAL: + /* + * If we sent a smaller global_seq than the peer has, try + * again with a larger value. + */ + dout("process_connect got RETRY_GLOBAL my %u peer_gseq %u\n", + con->v1.peer_global_seq, + le32_to_cpu(con->v1.in_reply.global_seq)); + ceph_get_global_seq(con->msgr, + le32_to_cpu(con->v1.in_reply.global_seq)); + con_out_kvec_reset(con); + ret = prepare_write_connect(con); + if (ret < 0) + return ret; + prepare_read_connect(con); + break; + + case CEPH_MSGR_TAG_SEQ: + case CEPH_MSGR_TAG_READY: + if (req_feat & ~server_feat) { + pr_err("%s%lld %s protocol feature mismatch," + " my required %llx > server's %llx, need %llx\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), + req_feat, server_feat, req_feat & ~server_feat); + con->error_msg = "missing required protocol features"; + return -1; + } + + WARN_ON(con->state != CEPH_CON_S_V1_CONNECT_MSG); + con->state = CEPH_CON_S_OPEN; + con->v1.auth_retry = 0; /* we authenticated; clear flag */ + con->v1.peer_global_seq = + le32_to_cpu(con->v1.in_reply.global_seq); + con->v1.connect_seq++; + con->peer_features = server_feat; + dout("process_connect got READY gseq %d cseq %d (%d)\n", + con->v1.peer_global_seq, + le32_to_cpu(con->v1.in_reply.connect_seq), + con->v1.connect_seq); + WARN_ON(con->v1.connect_seq != + le32_to_cpu(con->v1.in_reply.connect_seq)); + + if (con->v1.in_reply.flags & CEPH_MSG_CONNECT_LOSSY) + ceph_con_flag_set(con, CEPH_CON_F_LOSSYTX); + + con->delay = 0; /* reset backoff memory */ + + if (con->v1.in_reply.tag == CEPH_MSGR_TAG_SEQ) { + prepare_write_seq(con); + prepare_read_seq(con); + } else { + prepare_read_tag(con); + } + break; + + case CEPH_MSGR_TAG_WAIT: + /* + * If there is a connection race (we are opening + * connections to each other), one of us may just have + * to WAIT. This shouldn't happen if we are the + * client. + */ + con->error_msg = "protocol error, got WAIT as client"; + return -1; + + default: + con->error_msg = "protocol error, garbage tag during connect"; + return -1; + } + return 0; +} + +/* + * read (part of) an ack + */ +static int read_partial_ack(struct ceph_connection *con) +{ + int size = sizeof(con->v1.in_temp_ack); + int end = size; + + return read_partial(con, end, size, &con->v1.in_temp_ack); +} + +/* + * We can finally discard anything that's been acked. 
+ */ +static void process_ack(struct ceph_connection *con) +{ + u64 ack = le64_to_cpu(con->v1.in_temp_ack); + + if (con->v1.in_tag == CEPH_MSGR_TAG_ACK) + ceph_con_discard_sent(con, ack); + else + ceph_con_discard_requeued(con, ack); + + prepare_read_tag(con); +} + +static int read_partial_message_section(struct ceph_connection *con, + struct kvec *section, + unsigned int sec_len, u32 *crc) +{ + int ret, left; + + BUG_ON(!section); + + while (section->iov_len < sec_len) { + BUG_ON(section->iov_base == NULL); + left = sec_len - section->iov_len; + ret = ceph_tcp_recvmsg(con->sock, (char *)section->iov_base + + section->iov_len, left); + if (ret <= 0) + return ret; + section->iov_len += ret; + } + if (section->iov_len == sec_len) + *crc = crc32c(0, section->iov_base, section->iov_len); + + return 1; +} + +static int read_partial_msg_data(struct ceph_connection *con) +{ + struct ceph_msg_data_cursor *cursor = &con->in_msg->cursor; + bool do_datacrc = !ceph_test_opt(from_msgr(con->msgr), NOCRC); + struct page *page; + size_t page_offset; + size_t length; + u32 crc = 0; + int ret; + + if (do_datacrc) + crc = con->in_data_crc; + while (cursor->total_resid) { + if (!cursor->resid) { + ceph_msg_data_advance(cursor, 0); + continue; + } + + page = ceph_msg_data_next(cursor, &page_offset, &length); + ret = ceph_tcp_recvpage(con->sock, page, page_offset, length); + if (ret <= 0) { + if (do_datacrc) + con->in_data_crc = crc; + + return ret; + } + + if (do_datacrc) + crc = ceph_crc32c_page(crc, page, page_offset, ret); + ceph_msg_data_advance(cursor, (size_t)ret); + } + if (do_datacrc) + con->in_data_crc = crc; + + return 1; /* must return > 0 to indicate success */ +} + +static int read_partial_msg_data_bounce(struct ceph_connection *con) +{ + struct ceph_msg_data_cursor *cursor = &con->in_msg->cursor; + struct page *page; + size_t off, len; + u32 crc; + int ret; + + if (unlikely(!con->bounce_page)) { + con->bounce_page = alloc_page(GFP_NOIO); + if (!con->bounce_page) { + pr_err("failed to allocate bounce page\n"); + return -ENOMEM; + } + } + + crc = con->in_data_crc; + while (cursor->total_resid) { + if (!cursor->resid) { + ceph_msg_data_advance(cursor, 0); + continue; + } + + page = ceph_msg_data_next(cursor, &off, &len); + ret = ceph_tcp_recvpage(con->sock, con->bounce_page, 0, len); + if (ret <= 0) { + con->in_data_crc = crc; + return ret; + } + + crc = crc32c(crc, page_address(con->bounce_page), ret); + memcpy_to_page(page, off, page_address(con->bounce_page), ret); + + ceph_msg_data_advance(cursor, ret); + } + con->in_data_crc = crc; + + return 1; /* must return > 0 to indicate success */ +} + +/* + * read (part of) a message. 
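/*
 * [Editorial aside -- illustration only, not part of the upstream patch.]
 *
 * read_partial() above tracks progress for a whole sequence of
 * fixed-size objects with a single cursor, con->v1.in_base_pos: each
 * object is described by its cumulative end offset ("end") and its own
 * size ("size"), so "size - (end - in_base_pos)" is how much of that
 * object has already arrived.  Callers chain objects with "end += size",
 * as read_partial_banner() above and read_partial_message() below do.
 * A minimal userspace sketch of the same arithmetic against a fake
 * socket that hands out a few bytes at a time (all names illustrative):
 */
#include <stdio.h>
#include <string.h>

static const char stream[] = "HDRHDRHDRpayload-payload";	/* fake wire data */
static size_t stream_pos;

/* fake recv: hands out at most 5 bytes per call */
static int fake_recv(void *buf, size_t len)
{
	size_t n = sizeof(stream) - 1 - stream_pos;

	if (n > len)
		n = len;
	if (n > 5)
		n = 5;
	memcpy(buf, stream + stream_pos, n);
	stream_pos += n;
	return (int)n;
}

/* same shape as the kernel helper: 1 = object complete */
static int read_partial(int *in_base_pos, int end, int size, void *object)
{
	while (*in_base_pos < end) {
		int left = end - *in_base_pos;
		int have = size - left;
		int ret = fake_recv((char *)object + have, left);

		if (ret <= 0)
			return ret;
		*in_base_pos += ret;
	}
	return 1;
}

int main(void)
{
	char hdr[9], payload[15];
	int in_base_pos = 0;
	int size, end;

	size = sizeof(hdr);		/* first object */
	end = size;
	if (read_partial(&in_base_pos, end, size, hdr) != 1)
		return 1;		/* real code would wait and retry */

	size = sizeof(payload);		/* next object: end is cumulative */
	end += size;
	if (read_partial(&in_base_pos, end, size, payload) != 1)
		return 1;

	printf("hdr=%.9s payload=%.15s\n", hdr, payload);
	return 0;
}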
+ */ +static int read_partial_message(struct ceph_connection *con) +{ + struct ceph_msg *m = con->in_msg; + int size; + int end; + int ret; + unsigned int front_len, middle_len, data_len; + bool do_datacrc = !ceph_test_opt(from_msgr(con->msgr), NOCRC); + bool need_sign = (con->peer_features & CEPH_FEATURE_MSG_AUTH); + u64 seq; + u32 crc; + + dout("read_partial_message con %p msg %p\n", con, m); + + /* header */ + size = sizeof(con->v1.in_hdr); + end = size; + ret = read_partial(con, end, size, &con->v1.in_hdr); + if (ret <= 0) + return ret; + + crc = crc32c(0, &con->v1.in_hdr, offsetof(struct ceph_msg_header, crc)); + if (cpu_to_le32(crc) != con->v1.in_hdr.crc) { + pr_err("read_partial_message bad hdr crc %u != expected %u\n", + crc, con->v1.in_hdr.crc); + return -EBADMSG; + } + + front_len = le32_to_cpu(con->v1.in_hdr.front_len); + if (front_len > CEPH_MSG_MAX_FRONT_LEN) + return -EIO; + middle_len = le32_to_cpu(con->v1.in_hdr.middle_len); + if (middle_len > CEPH_MSG_MAX_MIDDLE_LEN) + return -EIO; + data_len = le32_to_cpu(con->v1.in_hdr.data_len); + if (data_len > CEPH_MSG_MAX_DATA_LEN) + return -EIO; + + /* verify seq# */ + seq = le64_to_cpu(con->v1.in_hdr.seq); + if ((s64)seq - (s64)con->in_seq < 1) { + pr_info("skipping %s%lld %s seq %lld expected %lld\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), + seq, con->in_seq + 1); + con->v1.in_base_pos = -front_len - middle_len - data_len - + sizeof_footer(con); + con->v1.in_tag = CEPH_MSGR_TAG_READY; + return 1; + } else if ((s64)seq - (s64)con->in_seq > 1) { + pr_err("read_partial_message bad seq %lld expected %lld\n", + seq, con->in_seq + 1); + con->error_msg = "bad message sequence # for incoming message"; + return -EBADE; + } + + /* allocate message? */ + if (!con->in_msg) { + int skip = 0; + + dout("got hdr type %d front %d data %d\n", con->v1.in_hdr.type, + front_len, data_len); + ret = ceph_con_in_msg_alloc(con, &con->v1.in_hdr, &skip); + if (ret < 0) + return ret; + + BUG_ON((!con->in_msg) ^ skip); + if (skip) { + /* skip this message */ + dout("alloc_msg said skip message\n"); + con->v1.in_base_pos = -front_len - middle_len - + data_len - sizeof_footer(con); + con->v1.in_tag = CEPH_MSGR_TAG_READY; + con->in_seq++; + return 1; + } + + BUG_ON(!con->in_msg); + BUG_ON(con->in_msg->con != con); + m = con->in_msg; + m->front.iov_len = 0; /* haven't read it yet */ + if (m->middle) + m->middle->vec.iov_len = 0; + + /* prepare for data payload, if any */ + + if (data_len) + prepare_message_data(con->in_msg, data_len); + } + + /* front */ + ret = read_partial_message_section(con, &m->front, front_len, + &con->in_front_crc); + if (ret <= 0) + return ret; + + /* middle */ + if (m->middle) { + ret = read_partial_message_section(con, &m->middle->vec, + middle_len, + &con->in_middle_crc); + if (ret <= 0) + return ret; + } + + /* (page) data */ + if (data_len) { + if (!m->num_data_items) + return -EIO; + + if (ceph_test_opt(from_msgr(con->msgr), RXBOUNCE)) + ret = read_partial_msg_data_bounce(con); + else + ret = read_partial_msg_data(con); + if (ret <= 0) + return ret; + } + + /* footer */ + size = sizeof_footer(con); + end += size; + ret = read_partial(con, end, size, &m->footer); + if (ret <= 0) + return ret; + + if (!need_sign) { + m->footer.flags = m->old_footer.flags; + m->footer.sig = 0; + } + + dout("read_partial_message got msg %p %d (%u) + %d (%u) + %d (%u)\n", + m, front_len, m->footer.front_crc, middle_len, + m->footer.middle_crc, data_len, m->footer.data_crc); + + /* crc ok? 
*/ + if (con->in_front_crc != le32_to_cpu(m->footer.front_crc)) { + pr_err("read_partial_message %p front crc %u != exp. %u\n", + m, con->in_front_crc, m->footer.front_crc); + return -EBADMSG; + } + if (con->in_middle_crc != le32_to_cpu(m->footer.middle_crc)) { + pr_err("read_partial_message %p middle crc %u != exp %u\n", + m, con->in_middle_crc, m->footer.middle_crc); + return -EBADMSG; + } + if (do_datacrc && + (m->footer.flags & CEPH_MSG_FOOTER_NOCRC) == 0 && + con->in_data_crc != le32_to_cpu(m->footer.data_crc)) { + pr_err("read_partial_message %p data crc %u != exp. %u\n", m, + con->in_data_crc, le32_to_cpu(m->footer.data_crc)); + return -EBADMSG; + } + + if (need_sign && con->ops->check_message_signature && + con->ops->check_message_signature(m)) { + pr_err("read_partial_message %p signature check failed\n", m); + return -EBADMSG; + } + + return 1; /* done! */ +} + +static int read_keepalive_ack(struct ceph_connection *con) +{ + struct ceph_timespec ceph_ts; + size_t size = sizeof(ceph_ts); + int ret = read_partial(con, size, size, &ceph_ts); + if (ret <= 0) + return ret; + ceph_decode_timespec64(&con->last_keepalive_ack, &ceph_ts); + prepare_read_tag(con); + return 1; +} + +/* + * Read what we can from the socket. + */ +int ceph_con_v1_try_read(struct ceph_connection *con) +{ + int ret = -1; + +more: + dout("try_read start %p state %d\n", con, con->state); + if (con->state != CEPH_CON_S_V1_BANNER && + con->state != CEPH_CON_S_V1_CONNECT_MSG && + con->state != CEPH_CON_S_OPEN) + return 0; + + BUG_ON(!con->sock); + + dout("try_read tag %d in_base_pos %d\n", con->v1.in_tag, + con->v1.in_base_pos); + + if (con->state == CEPH_CON_S_V1_BANNER) { + ret = read_partial_banner(con); + if (ret <= 0) + goto out; + ret = process_banner(con); + if (ret < 0) + goto out; + + con->state = CEPH_CON_S_V1_CONNECT_MSG; + + /* + * Received banner is good, exchange connection info. + * Do not reset out_kvec, as sending our banner raced + * with receiving peer banner after connect completed. + */ + ret = prepare_write_connect(con); + if (ret < 0) + goto out; + prepare_read_connect(con); + + /* Send connection info before awaiting response */ + goto out; + } + + if (con->state == CEPH_CON_S_V1_CONNECT_MSG) { + ret = read_partial_connect(con); + if (ret <= 0) + goto out; + ret = process_connect(con); + if (ret < 0) + goto out; + goto more; + } + + WARN_ON(con->state != CEPH_CON_S_OPEN); + + if (con->v1.in_base_pos < 0) { + /* + * skipping + discarding content. + */ + ret = ceph_tcp_recvmsg(con->sock, NULL, -con->v1.in_base_pos); + if (ret <= 0) + goto out; + dout("skipped %d / %d bytes\n", ret, -con->v1.in_base_pos); + con->v1.in_base_pos += ret; + if (con->v1.in_base_pos) + goto more; + } + if (con->v1.in_tag == CEPH_MSGR_TAG_READY) { + /* + * what's next? 
+ */ + ret = ceph_tcp_recvmsg(con->sock, &con->v1.in_tag, 1); + if (ret <= 0) + goto out; + dout("try_read got tag %d\n", con->v1.in_tag); + switch (con->v1.in_tag) { + case CEPH_MSGR_TAG_MSG: + prepare_read_message(con); + break; + case CEPH_MSGR_TAG_ACK: + prepare_read_ack(con); + break; + case CEPH_MSGR_TAG_KEEPALIVE2_ACK: + prepare_read_keepalive_ack(con); + break; + case CEPH_MSGR_TAG_CLOSE: + ceph_con_close_socket(con); + con->state = CEPH_CON_S_CLOSED; + goto out; + default: + goto bad_tag; + } + } + if (con->v1.in_tag == CEPH_MSGR_TAG_MSG) { + ret = read_partial_message(con); + if (ret <= 0) { + switch (ret) { + case -EBADMSG: + con->error_msg = "bad crc/signature"; + fallthrough; + case -EBADE: + ret = -EIO; + break; + case -EIO: + con->error_msg = "io error"; + break; + } + goto out; + } + if (con->v1.in_tag == CEPH_MSGR_TAG_READY) + goto more; + ceph_con_process_message(con); + if (con->state == CEPH_CON_S_OPEN) + prepare_read_tag(con); + goto more; + } + if (con->v1.in_tag == CEPH_MSGR_TAG_ACK || + con->v1.in_tag == CEPH_MSGR_TAG_SEQ) { + /* + * the final handshake seq exchange is semantically + * equivalent to an ACK + */ + ret = read_partial_ack(con); + if (ret <= 0) + goto out; + process_ack(con); + goto more; + } + if (con->v1.in_tag == CEPH_MSGR_TAG_KEEPALIVE2_ACK) { + ret = read_keepalive_ack(con); + if (ret <= 0) + goto out; + goto more; + } + +out: + dout("try_read done on %p ret %d\n", con, ret); + return ret; + +bad_tag: + pr_err("try_read bad tag %d\n", con->v1.in_tag); + con->error_msg = "protocol error, garbage tag"; + ret = -1; + goto out; +} + +/* + * Write something to the socket. Called in a worker thread when the + * socket appears to be writeable and we have something ready to send. + */ +int ceph_con_v1_try_write(struct ceph_connection *con) +{ + int ret = 1; + + dout("try_write start %p state %d\n", con, con->state); + if (con->state != CEPH_CON_S_PREOPEN && + con->state != CEPH_CON_S_V1_BANNER && + con->state != CEPH_CON_S_V1_CONNECT_MSG && + con->state != CEPH_CON_S_OPEN) + return 0; + + /* open the socket first? */ + if (con->state == CEPH_CON_S_PREOPEN) { + BUG_ON(con->sock); + con->state = CEPH_CON_S_V1_BANNER; + + con_out_kvec_reset(con); + prepare_write_banner(con); + prepare_read_banner(con); + + BUG_ON(con->in_msg); + con->v1.in_tag = CEPH_MSGR_TAG_READY; + dout("try_write initiating connect on %p new state %d\n", + con, con->state); + ret = ceph_tcp_connect(con); + if (ret < 0) { + con->error_msg = "connect error"; + goto out; + } + } + +more: + dout("try_write out_kvec_bytes %d\n", con->v1.out_kvec_bytes); + BUG_ON(!con->sock); + + /* kvec data queued? */ + if (con->v1.out_kvec_left) { + ret = write_partial_kvec(con); + if (ret <= 0) + goto out; + } + if (con->v1.out_skip) { + ret = write_partial_skip(con); + if (ret <= 0) + goto out; + } + + /* msg pages? */ + if (con->out_msg) { + if (con->v1.out_msg_done) { + ceph_msg_put(con->out_msg); + con->out_msg = NULL; /* we're done with this one */ + goto do_next; + } + + ret = write_partial_message_data(con); + if (ret == 1) + goto more; /* we need to send the footer, too! */ + if (ret == 0) + goto out; + if (ret < 0) { + dout("try_write write_partial_message_data err %d\n", + ret); + goto out; + } + } + +do_next: + if (con->state == CEPH_CON_S_OPEN) { + if (ceph_con_flag_test_and_clear(con, + CEPH_CON_F_KEEPALIVE_PENDING)) { + prepare_write_keepalive(con); + goto more; + } + /* is anything else pending? 
*/ + if (!list_empty(&con->out_queue)) { + prepare_write_message(con); + goto more; + } + if (con->in_seq > con->in_seq_acked) { + prepare_write_ack(con); + goto more; + } + } + + /* Nothing to do! */ + ceph_con_flag_clear(con, CEPH_CON_F_WRITE_PENDING); + dout("try_write nothing else to write.\n"); + ret = 0; +out: + dout("try_write done on %p ret %d\n", con, ret); + return ret; +} + +void ceph_con_v1_revoke(struct ceph_connection *con) +{ + struct ceph_msg *msg = con->out_msg; + + WARN_ON(con->v1.out_skip); + /* footer */ + if (con->v1.out_msg_done) { + con->v1.out_skip += con_out_kvec_skip(con); + } else { + WARN_ON(!msg->data_length); + con->v1.out_skip += sizeof_footer(con); + } + /* data, middle, front */ + if (msg->data_length) + con->v1.out_skip += msg->cursor.total_resid; + if (msg->middle) + con->v1.out_skip += con_out_kvec_skip(con); + con->v1.out_skip += con_out_kvec_skip(con); + + dout("%s con %p out_kvec_bytes %d out_skip %d\n", __func__, con, + con->v1.out_kvec_bytes, con->v1.out_skip); +} + +void ceph_con_v1_revoke_incoming(struct ceph_connection *con) +{ + unsigned int front_len = le32_to_cpu(con->v1.in_hdr.front_len); + unsigned int middle_len = le32_to_cpu(con->v1.in_hdr.middle_len); + unsigned int data_len = le32_to_cpu(con->v1.in_hdr.data_len); + + /* skip rest of message */ + con->v1.in_base_pos = con->v1.in_base_pos - + sizeof(struct ceph_msg_header) - + front_len - + middle_len - + data_len - + sizeof(struct ceph_msg_footer); + + con->v1.in_tag = CEPH_MSGR_TAG_READY; + con->in_seq++; + + dout("%s con %p in_base_pos %d\n", __func__, con, con->v1.in_base_pos); +} + +bool ceph_con_v1_opened(struct ceph_connection *con) +{ + return con->v1.connect_seq; +} + +void ceph_con_v1_reset_session(struct ceph_connection *con) +{ + con->v1.connect_seq = 0; + con->v1.peer_global_seq = 0; +} + +void ceph_con_v1_reset_protocol(struct ceph_connection *con) +{ + con->v1.out_skip = 0; +} diff --git a/net/ceph/messenger_v2.c b/net/ceph/messenger_v2.c new file mode 100644 index 000000000..489baf2e6 --- /dev/null +++ b/net/ceph/messenger_v2.c @@ -0,0 +1,3588 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Ceph msgr2 protocol implementation + * + * Copyright (C) 2020 Ilya Dryomov + */ + +#include + +#include +#include /* for crypto_memneq() */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "crypto.h" /* for CEPH_KEY_LEN and CEPH_MAX_CON_SECRET_LEN */ + +#define FRAME_TAG_HELLO 1 +#define FRAME_TAG_AUTH_REQUEST 2 +#define FRAME_TAG_AUTH_BAD_METHOD 3 +#define FRAME_TAG_AUTH_REPLY_MORE 4 +#define FRAME_TAG_AUTH_REQUEST_MORE 5 +#define FRAME_TAG_AUTH_DONE 6 +#define FRAME_TAG_AUTH_SIGNATURE 7 +#define FRAME_TAG_CLIENT_IDENT 8 +#define FRAME_TAG_SERVER_IDENT 9 +#define FRAME_TAG_IDENT_MISSING_FEATURES 10 +#define FRAME_TAG_SESSION_RECONNECT 11 +#define FRAME_TAG_SESSION_RESET 12 +#define FRAME_TAG_SESSION_RETRY 13 +#define FRAME_TAG_SESSION_RETRY_GLOBAL 14 +#define FRAME_TAG_SESSION_RECONNECT_OK 15 +#define FRAME_TAG_WAIT 16 +#define FRAME_TAG_MESSAGE 17 +#define FRAME_TAG_KEEPALIVE2 18 +#define FRAME_TAG_KEEPALIVE2_ACK 19 +#define FRAME_TAG_ACK 20 + +#define FRAME_LATE_STATUS_ABORTED 0x1 +#define FRAME_LATE_STATUS_COMPLETE 0xe +#define FRAME_LATE_STATUS_ABORTED_MASK 0xf + +#define IN_S_HANDLE_PREAMBLE 1 +#define IN_S_HANDLE_CONTROL 2 +#define IN_S_HANDLE_CONTROL_REMAINDER 3 +#define IN_S_PREPARE_READ_DATA 4 +#define IN_S_PREPARE_READ_DATA_CONT 5 +#define IN_S_PREPARE_READ_ENC_PAGE 6 
+#define IN_S_HANDLE_EPILOGUE 7 +#define IN_S_FINISH_SKIP 8 + +#define OUT_S_QUEUE_DATA 1 +#define OUT_S_QUEUE_DATA_CONT 2 +#define OUT_S_QUEUE_ENC_PAGE 3 +#define OUT_S_QUEUE_ZEROS 4 +#define OUT_S_FINISH_MESSAGE 5 +#define OUT_S_GET_NEXT 6 + +#define CTRL_BODY(p) ((void *)(p) + CEPH_PREAMBLE_LEN) +#define FRONT_PAD(p) ((void *)(p) + CEPH_EPILOGUE_SECURE_LEN) +#define MIDDLE_PAD(p) (FRONT_PAD(p) + CEPH_GCM_BLOCK_LEN) +#define DATA_PAD(p) (MIDDLE_PAD(p) + CEPH_GCM_BLOCK_LEN) + +#define CEPH_MSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) + +static int do_recvmsg(struct socket *sock, struct iov_iter *it) +{ + struct msghdr msg = { .msg_flags = CEPH_MSG_FLAGS }; + int ret; + + msg.msg_iter = *it; + while (iov_iter_count(it)) { + ret = sock_recvmsg(sock, &msg, msg.msg_flags); + if (ret <= 0) { + if (ret == -EAGAIN) + ret = 0; + return ret; + } + + iov_iter_advance(it, ret); + } + + WARN_ON(msg_data_left(&msg)); + return 1; +} + +/* + * Read as much as possible. + * + * Return: + * 1 - done, nothing (else) to read + * 0 - socket is empty, need to wait + * <0 - error + */ +static int ceph_tcp_recv(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p %s %zu\n", __func__, con, + iov_iter_is_discard(&con->v2.in_iter) ? "discard" : "need", + iov_iter_count(&con->v2.in_iter)); + ret = do_recvmsg(con->sock, &con->v2.in_iter); + dout("%s con %p ret %d left %zu\n", __func__, con, ret, + iov_iter_count(&con->v2.in_iter)); + return ret; +} + +static int do_sendmsg(struct socket *sock, struct iov_iter *it) +{ + struct msghdr msg = { .msg_flags = CEPH_MSG_FLAGS }; + int ret; + + msg.msg_iter = *it; + while (iov_iter_count(it)) { + ret = sock_sendmsg(sock, &msg); + if (ret <= 0) { + if (ret == -EAGAIN) + ret = 0; + return ret; + } + + iov_iter_advance(it, ret); + } + + WARN_ON(msg_data_left(&msg)); + return 1; +} + +static int do_try_sendpage(struct socket *sock, struct iov_iter *it) +{ + struct msghdr msg = { .msg_flags = CEPH_MSG_FLAGS }; + struct bio_vec bv; + int ret; + + if (WARN_ON(!iov_iter_is_bvec(it))) + return -EINVAL; + + while (iov_iter_count(it)) { + /* iov_iter_iovec() for ITER_BVEC */ + bv.bv_page = it->bvec->bv_page; + bv.bv_offset = it->bvec->bv_offset + it->iov_offset; + bv.bv_len = min(iov_iter_count(it), + it->bvec->bv_len - it->iov_offset); + + /* + * sendpage cannot properly handle pages with + * page_count == 0, we need to fall back to sendmsg if + * that's the case. + * + * Same goes for slab pages: skb_can_coalesce() allows + * coalescing neighboring slab objects into a single frag + * which triggers one of hardened usercopy checks. + */ + if (sendpage_ok(bv.bv_page)) { + ret = sock->ops->sendpage(sock, bv.bv_page, + bv.bv_offset, bv.bv_len, + CEPH_MSG_FLAGS); + } else { + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bv, 1, bv.bv_len); + ret = sock_sendmsg(sock, &msg); + } + if (ret <= 0) { + if (ret == -EAGAIN) + ret = 0; + return ret; + } + + iov_iter_advance(it, ret); + } + + return 1; +} + +/* + * Write as much as possible. The socket is expected to be corked, + * so we don't bother with MSG_MORE/MSG_SENDPAGE_NOTLAST here. 
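/*
 * [Editorial aside -- illustration only, not part of the upstream patch.]
 *
 * do_sendmsg() and do_try_sendpage() above keep pushing into the socket
 * until the iov_iter is drained, mapping -EAGAIN to the "0 - socket is
 * full, need to wait" convention so the worker can re-arm on
 * writability and resume where it left off (the v1 code does the same
 * by hand in write_partial_kvec()).  A minimal userspace sketch of that
 * partial-write pattern over a plain iovec; the kernel side uses
 * iov_iter and sock_sendmsg(), send_all() below is illustrative only:
 */
#include <errno.h>
#include <sys/uio.h>

/* Returns 1 when everything was written, 0 if the fd would block,
 * or a negative errno on a real error. */
static int send_all(int fd, struct iovec *iov, int iovcnt)
{
	while (iovcnt > 0) {
		ssize_t n = writev(fd, iov, iovcnt);

		if (n < 0) {
			if (errno == EAGAIN || errno == EWOULDBLOCK)
				return 0;	/* wait for POLLOUT, then call again */
			return -errno;
		}

		/* account for fully written entries ... */
		while (iovcnt > 0 && (size_t)n >= iov->iov_len) {
			n -= iov->iov_len;
			iov++;
			iovcnt--;
		}
		/* ... and advance into a partially written one */
		if (iovcnt > 0 && n > 0) {
			iov->iov_base = (char *)iov->iov_base + n;
			iov->iov_len -= n;
		}
	}
	return 1;
}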
+ * + * Return: + * 1 - done, nothing (else) to write + * 0 - socket is full, need to wait + * <0 - error + */ +static int ceph_tcp_send(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p have %zu try_sendpage %d\n", __func__, con, + iov_iter_count(&con->v2.out_iter), con->v2.out_iter_sendpage); + if (con->v2.out_iter_sendpage) + ret = do_try_sendpage(con->sock, &con->v2.out_iter); + else + ret = do_sendmsg(con->sock, &con->v2.out_iter); + dout("%s con %p ret %d left %zu\n", __func__, con, ret, + iov_iter_count(&con->v2.out_iter)); + return ret; +} + +static void add_in_kvec(struct ceph_connection *con, void *buf, int len) +{ + BUG_ON(con->v2.in_kvec_cnt >= ARRAY_SIZE(con->v2.in_kvecs)); + WARN_ON(!iov_iter_is_kvec(&con->v2.in_iter)); + + con->v2.in_kvecs[con->v2.in_kvec_cnt].iov_base = buf; + con->v2.in_kvecs[con->v2.in_kvec_cnt].iov_len = len; + con->v2.in_kvec_cnt++; + + con->v2.in_iter.nr_segs++; + con->v2.in_iter.count += len; +} + +static void reset_in_kvecs(struct ceph_connection *con) +{ + WARN_ON(iov_iter_count(&con->v2.in_iter)); + + con->v2.in_kvec_cnt = 0; + iov_iter_kvec(&con->v2.in_iter, ITER_DEST, con->v2.in_kvecs, 0, 0); +} + +static void set_in_bvec(struct ceph_connection *con, const struct bio_vec *bv) +{ + WARN_ON(iov_iter_count(&con->v2.in_iter)); + + con->v2.in_bvec = *bv; + iov_iter_bvec(&con->v2.in_iter, ITER_DEST, &con->v2.in_bvec, 1, bv->bv_len); +} + +static void set_in_skip(struct ceph_connection *con, int len) +{ + WARN_ON(iov_iter_count(&con->v2.in_iter)); + + dout("%s con %p len %d\n", __func__, con, len); + iov_iter_discard(&con->v2.in_iter, ITER_DEST, len); +} + +static void add_out_kvec(struct ceph_connection *con, void *buf, int len) +{ + BUG_ON(con->v2.out_kvec_cnt >= ARRAY_SIZE(con->v2.out_kvecs)); + WARN_ON(!iov_iter_is_kvec(&con->v2.out_iter)); + WARN_ON(con->v2.out_zero); + + con->v2.out_kvecs[con->v2.out_kvec_cnt].iov_base = buf; + con->v2.out_kvecs[con->v2.out_kvec_cnt].iov_len = len; + con->v2.out_kvec_cnt++; + + con->v2.out_iter.nr_segs++; + con->v2.out_iter.count += len; +} + +static void reset_out_kvecs(struct ceph_connection *con) +{ + WARN_ON(iov_iter_count(&con->v2.out_iter)); + WARN_ON(con->v2.out_zero); + + con->v2.out_kvec_cnt = 0; + + iov_iter_kvec(&con->v2.out_iter, ITER_SOURCE, con->v2.out_kvecs, 0, 0); + con->v2.out_iter_sendpage = false; +} + +static void set_out_bvec(struct ceph_connection *con, const struct bio_vec *bv, + bool zerocopy) +{ + WARN_ON(iov_iter_count(&con->v2.out_iter)); + WARN_ON(con->v2.out_zero); + + con->v2.out_bvec = *bv; + con->v2.out_iter_sendpage = zerocopy; + iov_iter_bvec(&con->v2.out_iter, ITER_SOURCE, &con->v2.out_bvec, 1, + con->v2.out_bvec.bv_len); +} + +static void set_out_bvec_zero(struct ceph_connection *con) +{ + WARN_ON(iov_iter_count(&con->v2.out_iter)); + WARN_ON(!con->v2.out_zero); + + con->v2.out_bvec.bv_page = ceph_zero_page; + con->v2.out_bvec.bv_offset = 0; + con->v2.out_bvec.bv_len = min(con->v2.out_zero, (int)PAGE_SIZE); + con->v2.out_iter_sendpage = true; + iov_iter_bvec(&con->v2.out_iter, ITER_SOURCE, &con->v2.out_bvec, 1, + con->v2.out_bvec.bv_len); +} + +static void out_zero_add(struct ceph_connection *con, int len) +{ + dout("%s con %p len %d\n", __func__, con, len); + con->v2.out_zero += len; +} + +static void *alloc_conn_buf(struct ceph_connection *con, int len) +{ + void *buf; + + dout("%s con %p len %d\n", __func__, con, len); + + if (WARN_ON(con->v2.conn_buf_cnt >= ARRAY_SIZE(con->v2.conn_bufs))) + return NULL; + + buf = kvmalloc(len, GFP_NOIO); + if (!buf) + return 
NULL; + + con->v2.conn_bufs[con->v2.conn_buf_cnt++] = buf; + return buf; +} + +static void free_conn_bufs(struct ceph_connection *con) +{ + while (con->v2.conn_buf_cnt) + kvfree(con->v2.conn_bufs[--con->v2.conn_buf_cnt]); +} + +static void add_in_sign_kvec(struct ceph_connection *con, void *buf, int len) +{ + BUG_ON(con->v2.in_sign_kvec_cnt >= ARRAY_SIZE(con->v2.in_sign_kvecs)); + + con->v2.in_sign_kvecs[con->v2.in_sign_kvec_cnt].iov_base = buf; + con->v2.in_sign_kvecs[con->v2.in_sign_kvec_cnt].iov_len = len; + con->v2.in_sign_kvec_cnt++; +} + +static void clear_in_sign_kvecs(struct ceph_connection *con) +{ + con->v2.in_sign_kvec_cnt = 0; +} + +static void add_out_sign_kvec(struct ceph_connection *con, void *buf, int len) +{ + BUG_ON(con->v2.out_sign_kvec_cnt >= ARRAY_SIZE(con->v2.out_sign_kvecs)); + + con->v2.out_sign_kvecs[con->v2.out_sign_kvec_cnt].iov_base = buf; + con->v2.out_sign_kvecs[con->v2.out_sign_kvec_cnt].iov_len = len; + con->v2.out_sign_kvec_cnt++; +} + +static void clear_out_sign_kvecs(struct ceph_connection *con) +{ + con->v2.out_sign_kvec_cnt = 0; +} + +static bool con_secure(struct ceph_connection *con) +{ + return con->v2.con_mode == CEPH_CON_MODE_SECURE; +} + +static int front_len(const struct ceph_msg *msg) +{ + return le32_to_cpu(msg->hdr.front_len); +} + +static int middle_len(const struct ceph_msg *msg) +{ + return le32_to_cpu(msg->hdr.middle_len); +} + +static int data_len(const struct ceph_msg *msg) +{ + return le32_to_cpu(msg->hdr.data_len); +} + +static bool need_padding(int len) +{ + return !IS_ALIGNED(len, CEPH_GCM_BLOCK_LEN); +} + +static int padded_len(int len) +{ + return ALIGN(len, CEPH_GCM_BLOCK_LEN); +} + +static int padding_len(int len) +{ + return padded_len(len) - len; +} + +/* preamble + control segment */ +static int head_onwire_len(int ctrl_len, bool secure) +{ + int head_len; + int rem_len; + + BUG_ON(ctrl_len < 0 || ctrl_len > CEPH_MSG_MAX_CONTROL_LEN); + + if (secure) { + head_len = CEPH_PREAMBLE_SECURE_LEN; + if (ctrl_len > CEPH_PREAMBLE_INLINE_LEN) { + rem_len = ctrl_len - CEPH_PREAMBLE_INLINE_LEN; + head_len += padded_len(rem_len) + CEPH_GCM_TAG_LEN; + } + } else { + head_len = CEPH_PREAMBLE_PLAIN_LEN; + if (ctrl_len) + head_len += ctrl_len + CEPH_CRC_LEN; + } + return head_len; +} + +/* front, middle and data segments + epilogue */ +static int __tail_onwire_len(int front_len, int middle_len, int data_len, + bool secure) +{ + BUG_ON(front_len < 0 || front_len > CEPH_MSG_MAX_FRONT_LEN || + middle_len < 0 || middle_len > CEPH_MSG_MAX_MIDDLE_LEN || + data_len < 0 || data_len > CEPH_MSG_MAX_DATA_LEN); + + if (!front_len && !middle_len && !data_len) + return 0; + + if (!secure) + return front_len + middle_len + data_len + + CEPH_EPILOGUE_PLAIN_LEN; + + return padded_len(front_len) + padded_len(middle_len) + + padded_len(data_len) + CEPH_EPILOGUE_SECURE_LEN; +} + +static int tail_onwire_len(const struct ceph_msg *msg, bool secure) +{ + return __tail_onwire_len(front_len(msg), middle_len(msg), + data_len(msg), secure); +} + +/* head_onwire_len(sizeof(struct ceph_msg_header2), false) */ +#define MESSAGE_HEAD_PLAIN_LEN (CEPH_PREAMBLE_PLAIN_LEN + \ + sizeof(struct ceph_msg_header2) + \ + CEPH_CRC_LEN) + +static const int frame_aligns[] = { + sizeof(void *), + sizeof(void *), + sizeof(void *), + PAGE_SIZE +}; + +/* + * Discards trailing empty segments, unless there is just one segment. + * A frame always has at least one (possibly empty) segment. 
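/*
 * [Editorial aside -- illustration only, not part of the upstream patch.]
 *
 * In secure mode every segment is padded up to the AES block size
 * (CEPH_GCM_BLOCK_LEN), which is all need_padding()/padded_len()/
 * padding_len() above encode.  A worked example of that arithmetic,
 * assuming the usual 16-byte AES block (GCM_BLOCK_LEN below is a
 * stand-in for the kernel constant):
 */
#include <stdio.h>

#define GCM_BLOCK_LEN	16
#define PADDED_LEN(len)	(((len) + GCM_BLOCK_LEN - 1) & ~(GCM_BLOCK_LEN - 1))

int main(void)
{
	int lens[] = { 0, 1, 16, 53, 4096 };

	for (int i = 0; i < 5; i++)
		printf("len %4d -> padded %4d (pad %2d)\n", lens[i],
		       PADDED_LEN(lens[i]), PADDED_LEN(lens[i]) - lens[i]);
	return 0;
}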
+ */ +static int calc_segment_count(const int *lens, int len_cnt) +{ + int i; + + for (i = len_cnt - 1; i >= 0; i--) { + if (lens[i]) + return i + 1; + } + + return 1; +} + +static void init_frame_desc(struct ceph_frame_desc *desc, int tag, + const int *lens, int len_cnt) +{ + int i; + + memset(desc, 0, sizeof(*desc)); + + desc->fd_tag = tag; + desc->fd_seg_cnt = calc_segment_count(lens, len_cnt); + BUG_ON(desc->fd_seg_cnt > CEPH_FRAME_MAX_SEGMENT_COUNT); + for (i = 0; i < desc->fd_seg_cnt; i++) { + desc->fd_lens[i] = lens[i]; + desc->fd_aligns[i] = frame_aligns[i]; + } +} + +/* + * Preamble crc covers everything up to itself (28 bytes) and + * is calculated and verified irrespective of the connection mode + * (i.e. even if the frame is encrypted). + */ +static void encode_preamble(const struct ceph_frame_desc *desc, void *p) +{ + void *crcp = p + CEPH_PREAMBLE_LEN - CEPH_CRC_LEN; + void *start = p; + int i; + + memset(p, 0, CEPH_PREAMBLE_LEN); + + ceph_encode_8(&p, desc->fd_tag); + ceph_encode_8(&p, desc->fd_seg_cnt); + for (i = 0; i < desc->fd_seg_cnt; i++) { + ceph_encode_32(&p, desc->fd_lens[i]); + ceph_encode_16(&p, desc->fd_aligns[i]); + } + + put_unaligned_le32(crc32c(0, start, crcp - start), crcp); +} + +static int decode_preamble(void *p, struct ceph_frame_desc *desc) +{ + void *crcp = p + CEPH_PREAMBLE_LEN - CEPH_CRC_LEN; + u32 crc, expected_crc; + int i; + + crc = crc32c(0, p, crcp - p); + expected_crc = get_unaligned_le32(crcp); + if (crc != expected_crc) { + pr_err("bad preamble crc, calculated %u, expected %u\n", + crc, expected_crc); + return -EBADMSG; + } + + memset(desc, 0, sizeof(*desc)); + + desc->fd_tag = ceph_decode_8(&p); + desc->fd_seg_cnt = ceph_decode_8(&p); + if (desc->fd_seg_cnt < 1 || + desc->fd_seg_cnt > CEPH_FRAME_MAX_SEGMENT_COUNT) { + pr_err("bad segment count %d\n", desc->fd_seg_cnt); + return -EINVAL; + } + for (i = 0; i < desc->fd_seg_cnt; i++) { + desc->fd_lens[i] = ceph_decode_32(&p); + desc->fd_aligns[i] = ceph_decode_16(&p); + } + + if (desc->fd_lens[0] < 0 || + desc->fd_lens[0] > CEPH_MSG_MAX_CONTROL_LEN) { + pr_err("bad control segment length %d\n", desc->fd_lens[0]); + return -EINVAL; + } + if (desc->fd_lens[1] < 0 || + desc->fd_lens[1] > CEPH_MSG_MAX_FRONT_LEN) { + pr_err("bad front segment length %d\n", desc->fd_lens[1]); + return -EINVAL; + } + if (desc->fd_lens[2] < 0 || + desc->fd_lens[2] > CEPH_MSG_MAX_MIDDLE_LEN) { + pr_err("bad middle segment length %d\n", desc->fd_lens[2]); + return -EINVAL; + } + if (desc->fd_lens[3] < 0 || + desc->fd_lens[3] > CEPH_MSG_MAX_DATA_LEN) { + pr_err("bad data segment length %d\n", desc->fd_lens[3]); + return -EINVAL; + } + + /* + * This would fire for FRAME_TAG_WAIT (it has one empty + * segment), but we should never get it as client. + */ + if (!desc->fd_lens[desc->fd_seg_cnt - 1]) { + pr_err("last segment empty, segment count %d\n", + desc->fd_seg_cnt); + return -EINVAL; + } + + return 0; +} + +static void encode_epilogue_plain(struct ceph_connection *con, bool aborted) +{ + con->v2.out_epil.late_status = aborted ? FRAME_LATE_STATUS_ABORTED : + FRAME_LATE_STATUS_COMPLETE; + cpu_to_le32s(&con->v2.out_epil.front_crc); + cpu_to_le32s(&con->v2.out_epil.middle_crc); + cpu_to_le32s(&con->v2.out_epil.data_crc); +} + +static void encode_epilogue_secure(struct ceph_connection *con, bool aborted) +{ + memset(&con->v2.out_epil, 0, sizeof(con->v2.out_epil)); + con->v2.out_epil.late_status = aborted ? 
FRAME_LATE_STATUS_ABORTED : + FRAME_LATE_STATUS_COMPLETE; +} + +static int decode_epilogue(void *p, u32 *front_crc, u32 *middle_crc, + u32 *data_crc) +{ + u8 late_status; + + late_status = ceph_decode_8(&p); + if ((late_status & FRAME_LATE_STATUS_ABORTED_MASK) != + FRAME_LATE_STATUS_COMPLETE) { + /* we should never get an aborted message as client */ + pr_err("bad late_status 0x%x\n", late_status); + return -EINVAL; + } + + if (front_crc && middle_crc && data_crc) { + *front_crc = ceph_decode_32(&p); + *middle_crc = ceph_decode_32(&p); + *data_crc = ceph_decode_32(&p); + } + + return 0; +} + +static void fill_header(struct ceph_msg_header *hdr, + const struct ceph_msg_header2 *hdr2, + int front_len, int middle_len, int data_len, + const struct ceph_entity_name *peer_name) +{ + hdr->seq = hdr2->seq; + hdr->tid = hdr2->tid; + hdr->type = hdr2->type; + hdr->priority = hdr2->priority; + hdr->version = hdr2->version; + hdr->front_len = cpu_to_le32(front_len); + hdr->middle_len = cpu_to_le32(middle_len); + hdr->data_len = cpu_to_le32(data_len); + hdr->data_off = hdr2->data_off; + hdr->src = *peer_name; + hdr->compat_version = hdr2->compat_version; + hdr->reserved = 0; + hdr->crc = 0; +} + +static void fill_header2(struct ceph_msg_header2 *hdr2, + const struct ceph_msg_header *hdr, u64 ack_seq) +{ + hdr2->seq = hdr->seq; + hdr2->tid = hdr->tid; + hdr2->type = hdr->type; + hdr2->priority = hdr->priority; + hdr2->version = hdr->version; + hdr2->data_pre_padding_len = 0; + hdr2->data_off = hdr->data_off; + hdr2->ack_seq = cpu_to_le64(ack_seq); + hdr2->flags = 0; + hdr2->compat_version = hdr->compat_version; + hdr2->reserved = 0; +} + +static int verify_control_crc(struct ceph_connection *con) +{ + int ctrl_len = con->v2.in_desc.fd_lens[0]; + u32 crc, expected_crc; + + WARN_ON(con->v2.in_kvecs[0].iov_len != ctrl_len); + WARN_ON(con->v2.in_kvecs[1].iov_len != CEPH_CRC_LEN); + + crc = crc32c(-1, con->v2.in_kvecs[0].iov_base, ctrl_len); + expected_crc = get_unaligned_le32(con->v2.in_kvecs[1].iov_base); + if (crc != expected_crc) { + pr_err("bad control crc, calculated %u, expected %u\n", + crc, expected_crc); + return -EBADMSG; + } + + return 0; +} + +static int verify_epilogue_crcs(struct ceph_connection *con, u32 front_crc, + u32 middle_crc, u32 data_crc) +{ + if (front_len(con->in_msg)) { + con->in_front_crc = crc32c(-1, con->in_msg->front.iov_base, + front_len(con->in_msg)); + } else { + WARN_ON(!middle_len(con->in_msg) && !data_len(con->in_msg)); + con->in_front_crc = -1; + } + + if (middle_len(con->in_msg)) + con->in_middle_crc = crc32c(-1, + con->in_msg->middle->vec.iov_base, + middle_len(con->in_msg)); + else if (data_len(con->in_msg)) + con->in_middle_crc = -1; + else + con->in_middle_crc = 0; + + if (!data_len(con->in_msg)) + con->in_data_crc = 0; + + dout("%s con %p msg %p crcs %u %u %u\n", __func__, con, con->in_msg, + con->in_front_crc, con->in_middle_crc, con->in_data_crc); + + if (con->in_front_crc != front_crc) { + pr_err("bad front crc, calculated %u, expected %u\n", + con->in_front_crc, front_crc); + return -EBADMSG; + } + if (con->in_middle_crc != middle_crc) { + pr_err("bad middle crc, calculated %u, expected %u\n", + con->in_middle_crc, middle_crc); + return -EBADMSG; + } + if (con->in_data_crc != data_crc) { + pr_err("bad data crc, calculated %u, expected %u\n", + con->in_data_crc, data_crc); + return -EBADMSG; + } + + return 0; +} + +static int setup_crypto(struct ceph_connection *con, + const u8 *session_key, int session_key_len, + const u8 *con_secret, int con_secret_len) +{ + 
unsigned int noio_flag; + int ret; + + dout("%s con %p con_mode %d session_key_len %d con_secret_len %d\n", + __func__, con, con->v2.con_mode, session_key_len, con_secret_len); + WARN_ON(con->v2.hmac_tfm || con->v2.gcm_tfm || con->v2.gcm_req); + + if (con->v2.con_mode != CEPH_CON_MODE_CRC && + con->v2.con_mode != CEPH_CON_MODE_SECURE) { + pr_err("bad con_mode %d\n", con->v2.con_mode); + return -EINVAL; + } + + if (!session_key_len) { + WARN_ON(con->v2.con_mode != CEPH_CON_MODE_CRC); + WARN_ON(con_secret_len); + return 0; /* auth_none */ + } + + noio_flag = memalloc_noio_save(); + con->v2.hmac_tfm = crypto_alloc_shash("hmac(sha256)", 0, 0); + memalloc_noio_restore(noio_flag); + if (IS_ERR(con->v2.hmac_tfm)) { + ret = PTR_ERR(con->v2.hmac_tfm); + con->v2.hmac_tfm = NULL; + pr_err("failed to allocate hmac tfm context: %d\n", ret); + return ret; + } + + WARN_ON((unsigned long)session_key & + crypto_shash_alignmask(con->v2.hmac_tfm)); + ret = crypto_shash_setkey(con->v2.hmac_tfm, session_key, + session_key_len); + if (ret) { + pr_err("failed to set hmac key: %d\n", ret); + return ret; + } + + if (con->v2.con_mode == CEPH_CON_MODE_CRC) { + WARN_ON(con_secret_len); + return 0; /* auth_x, plain mode */ + } + + if (con_secret_len < CEPH_GCM_KEY_LEN + 2 * CEPH_GCM_IV_LEN) { + pr_err("con_secret too small %d\n", con_secret_len); + return -EINVAL; + } + + noio_flag = memalloc_noio_save(); + con->v2.gcm_tfm = crypto_alloc_aead("gcm(aes)", 0, 0); + memalloc_noio_restore(noio_flag); + if (IS_ERR(con->v2.gcm_tfm)) { + ret = PTR_ERR(con->v2.gcm_tfm); + con->v2.gcm_tfm = NULL; + pr_err("failed to allocate gcm tfm context: %d\n", ret); + return ret; + } + + WARN_ON((unsigned long)con_secret & + crypto_aead_alignmask(con->v2.gcm_tfm)); + ret = crypto_aead_setkey(con->v2.gcm_tfm, con_secret, CEPH_GCM_KEY_LEN); + if (ret) { + pr_err("failed to set gcm key: %d\n", ret); + return ret; + } + + WARN_ON(crypto_aead_ivsize(con->v2.gcm_tfm) != CEPH_GCM_IV_LEN); + ret = crypto_aead_setauthsize(con->v2.gcm_tfm, CEPH_GCM_TAG_LEN); + if (ret) { + pr_err("failed to set gcm tag size: %d\n", ret); + return ret; + } + + con->v2.gcm_req = aead_request_alloc(con->v2.gcm_tfm, GFP_NOIO); + if (!con->v2.gcm_req) { + pr_err("failed to allocate gcm request\n"); + return -ENOMEM; + } + + crypto_init_wait(&con->v2.gcm_wait); + aead_request_set_callback(con->v2.gcm_req, CRYPTO_TFM_REQ_MAY_BACKLOG, + crypto_req_done, &con->v2.gcm_wait); + + memcpy(&con->v2.in_gcm_nonce, con_secret + CEPH_GCM_KEY_LEN, + CEPH_GCM_IV_LEN); + memcpy(&con->v2.out_gcm_nonce, + con_secret + CEPH_GCM_KEY_LEN + CEPH_GCM_IV_LEN, + CEPH_GCM_IV_LEN); + return 0; /* auth_x, secure mode */ +} + +static int hmac_sha256(struct ceph_connection *con, const struct kvec *kvecs, + int kvec_cnt, u8 *hmac) +{ + SHASH_DESC_ON_STACK(desc, con->v2.hmac_tfm); /* tfm arg is ignored */ + int ret; + int i; + + dout("%s con %p hmac_tfm %p kvec_cnt %d\n", __func__, con, + con->v2.hmac_tfm, kvec_cnt); + + if (!con->v2.hmac_tfm) { + memset(hmac, 0, SHA256_DIGEST_SIZE); + return 0; /* auth_none */ + } + + desc->tfm = con->v2.hmac_tfm; + ret = crypto_shash_init(desc); + if (ret) + goto out; + + for (i = 0; i < kvec_cnt; i++) { + WARN_ON((unsigned long)kvecs[i].iov_base & + crypto_shash_alignmask(con->v2.hmac_tfm)); + ret = crypto_shash_update(desc, kvecs[i].iov_base, + kvecs[i].iov_len); + if (ret) + goto out; + } + + ret = crypto_shash_final(desc, hmac); + +out: + shash_desc_zero(desc); + return ret; /* auth_x, both plain and secure modes */ +} + +static void gcm_inc_nonce(struct 
ceph_gcm_nonce *nonce) +{ + u64 counter; + + counter = le64_to_cpu(nonce->counter); + nonce->counter = cpu_to_le64(counter + 1); +} + +static int gcm_crypt(struct ceph_connection *con, bool encrypt, + struct scatterlist *src, struct scatterlist *dst, + int src_len) +{ + struct ceph_gcm_nonce *nonce; + int ret; + + nonce = encrypt ? &con->v2.out_gcm_nonce : &con->v2.in_gcm_nonce; + + aead_request_set_ad(con->v2.gcm_req, 0); /* no AAD */ + aead_request_set_crypt(con->v2.gcm_req, src, dst, src_len, (u8 *)nonce); + ret = crypto_wait_req(encrypt ? crypto_aead_encrypt(con->v2.gcm_req) : + crypto_aead_decrypt(con->v2.gcm_req), + &con->v2.gcm_wait); + if (ret) + return ret; + + gcm_inc_nonce(nonce); + return 0; +} + +static void get_bvec_at(struct ceph_msg_data_cursor *cursor, + struct bio_vec *bv) +{ + struct page *page; + size_t off, len; + + WARN_ON(!cursor->total_resid); + + /* skip zero-length data items */ + while (!cursor->resid) + ceph_msg_data_advance(cursor, 0); + + /* get a piece of data, cursor isn't advanced */ + page = ceph_msg_data_next(cursor, &off, &len); + + bv->bv_page = page; + bv->bv_offset = off; + bv->bv_len = len; +} + +static int calc_sg_cnt(void *buf, int buf_len) +{ + int sg_cnt; + + if (!buf_len) + return 0; + + sg_cnt = need_padding(buf_len) ? 1 : 0; + if (is_vmalloc_addr(buf)) { + WARN_ON(offset_in_page(buf)); + sg_cnt += PAGE_ALIGN(buf_len) >> PAGE_SHIFT; + } else { + sg_cnt++; + } + + return sg_cnt; +} + +static int calc_sg_cnt_cursor(struct ceph_msg_data_cursor *cursor) +{ + int data_len = cursor->total_resid; + struct bio_vec bv; + int sg_cnt; + + if (!data_len) + return 0; + + sg_cnt = need_padding(data_len) ? 1 : 0; + do { + get_bvec_at(cursor, &bv); + sg_cnt++; + + ceph_msg_data_advance(cursor, bv.bv_len); + } while (cursor->total_resid); + + return sg_cnt; +} + +static void init_sgs(struct scatterlist **sg, void *buf, int buf_len, u8 *pad) +{ + void *end = buf + buf_len; + struct page *page; + int len; + void *p; + + if (!buf_len) + return; + + if (is_vmalloc_addr(buf)) { + p = buf; + do { + page = vmalloc_to_page(p); + len = min_t(int, end - p, PAGE_SIZE); + WARN_ON(!page || !len || offset_in_page(p)); + sg_set_page(*sg, page, len, 0); + *sg = sg_next(*sg); + p += len; + } while (p != end); + } else { + sg_set_buf(*sg, buf, buf_len); + *sg = sg_next(*sg); + } + + if (need_padding(buf_len)) { + sg_set_buf(*sg, pad, padding_len(buf_len)); + *sg = sg_next(*sg); + } +} + +static void init_sgs_cursor(struct scatterlist **sg, + struct ceph_msg_data_cursor *cursor, u8 *pad) +{ + int data_len = cursor->total_resid; + struct bio_vec bv; + + if (!data_len) + return; + + do { + get_bvec_at(cursor, &bv); + sg_set_page(*sg, bv.bv_page, bv.bv_len, bv.bv_offset); + *sg = sg_next(*sg); + + ceph_msg_data_advance(cursor, bv.bv_len); + } while (cursor->total_resid); + + if (need_padding(data_len)) { + sg_set_buf(*sg, pad, padding_len(data_len)); + *sg = sg_next(*sg); + } +} + +static int setup_message_sgs(struct sg_table *sgt, struct ceph_msg *msg, + u8 *front_pad, u8 *middle_pad, u8 *data_pad, + void *epilogue, bool add_tag) +{ + struct ceph_msg_data_cursor cursor; + struct scatterlist *cur_sg; + int sg_cnt; + int ret; + + if (!front_len(msg) && !middle_len(msg) && !data_len(msg)) + return 0; + + sg_cnt = 1; /* epilogue + [auth tag] */ + if (front_len(msg)) + sg_cnt += calc_sg_cnt(msg->front.iov_base, + front_len(msg)); + if (middle_len(msg)) + sg_cnt += calc_sg_cnt(msg->middle->vec.iov_base, + middle_len(msg)); + if (data_len(msg)) { + ceph_msg_data_cursor_init(&cursor, msg, 
data_len(msg)); + sg_cnt += calc_sg_cnt_cursor(&cursor); + } + + ret = sg_alloc_table(sgt, sg_cnt, GFP_NOIO); + if (ret) + return ret; + + cur_sg = sgt->sgl; + if (front_len(msg)) + init_sgs(&cur_sg, msg->front.iov_base, front_len(msg), + front_pad); + if (middle_len(msg)) + init_sgs(&cur_sg, msg->middle->vec.iov_base, middle_len(msg), + middle_pad); + if (data_len(msg)) { + ceph_msg_data_cursor_init(&cursor, msg, data_len(msg)); + init_sgs_cursor(&cur_sg, &cursor, data_pad); + } + + WARN_ON(!sg_is_last(cur_sg)); + sg_set_buf(cur_sg, epilogue, + CEPH_GCM_BLOCK_LEN + (add_tag ? CEPH_GCM_TAG_LEN : 0)); + return 0; +} + +static int decrypt_preamble(struct ceph_connection *con) +{ + struct scatterlist sg; + + sg_init_one(&sg, con->v2.in_buf, CEPH_PREAMBLE_SECURE_LEN); + return gcm_crypt(con, false, &sg, &sg, CEPH_PREAMBLE_SECURE_LEN); +} + +static int decrypt_control_remainder(struct ceph_connection *con) +{ + int ctrl_len = con->v2.in_desc.fd_lens[0]; + int rem_len = ctrl_len - CEPH_PREAMBLE_INLINE_LEN; + int pt_len = padding_len(rem_len) + CEPH_GCM_TAG_LEN; + struct scatterlist sgs[2]; + + WARN_ON(con->v2.in_kvecs[0].iov_len != rem_len); + WARN_ON(con->v2.in_kvecs[1].iov_len != pt_len); + + sg_init_table(sgs, 2); + sg_set_buf(&sgs[0], con->v2.in_kvecs[0].iov_base, rem_len); + sg_set_buf(&sgs[1], con->v2.in_buf, pt_len); + + return gcm_crypt(con, false, sgs, sgs, + padded_len(rem_len) + CEPH_GCM_TAG_LEN); +} + +static int decrypt_tail(struct ceph_connection *con) +{ + struct sg_table enc_sgt = {}; + struct sg_table sgt = {}; + int tail_len; + int ret; + + tail_len = tail_onwire_len(con->in_msg, true); + ret = sg_alloc_table_from_pages(&enc_sgt, con->v2.in_enc_pages, + con->v2.in_enc_page_cnt, 0, tail_len, + GFP_NOIO); + if (ret) + goto out; + + ret = setup_message_sgs(&sgt, con->in_msg, FRONT_PAD(con->v2.in_buf), + MIDDLE_PAD(con->v2.in_buf), DATA_PAD(con->v2.in_buf), + con->v2.in_buf, true); + if (ret) + goto out; + + dout("%s con %p msg %p enc_page_cnt %d sg_cnt %d\n", __func__, con, + con->in_msg, con->v2.in_enc_page_cnt, sgt.orig_nents); + ret = gcm_crypt(con, false, enc_sgt.sgl, sgt.sgl, tail_len); + if (ret) + goto out; + + WARN_ON(!con->v2.in_enc_page_cnt); + ceph_release_page_vector(con->v2.in_enc_pages, + con->v2.in_enc_page_cnt); + con->v2.in_enc_pages = NULL; + con->v2.in_enc_page_cnt = 0; + +out: + sg_free_table(&sgt); + sg_free_table(&enc_sgt); + return ret; +} + +static int prepare_banner(struct ceph_connection *con) +{ + int buf_len = CEPH_BANNER_V2_LEN + 2 + 8 + 8; + void *buf, *p; + + buf = alloc_conn_buf(con, buf_len); + if (!buf) + return -ENOMEM; + + p = buf; + ceph_encode_copy(&p, CEPH_BANNER_V2, CEPH_BANNER_V2_LEN); + ceph_encode_16(&p, sizeof(u64) + sizeof(u64)); + ceph_encode_64(&p, CEPH_MSGR2_SUPPORTED_FEATURES); + ceph_encode_64(&p, CEPH_MSGR2_REQUIRED_FEATURES); + WARN_ON(p != buf + buf_len); + + add_out_kvec(con, buf, buf_len); + add_out_sign_kvec(con, buf, buf_len); + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); + return 0; +} + +/* + * base: + * preamble + * control body (ctrl_len bytes) + * space for control crc + * + * extdata (optional): + * control body (extdata_len bytes) + * + * Compute control crc and gather base and extdata into: + * + * preamble + * control body (ctrl_len + extdata_len bytes) + * control crc + * + * Preamble should already be encoded at the start of base. 
+ */ +static void prepare_head_plain(struct ceph_connection *con, void *base, + int ctrl_len, void *extdata, int extdata_len, + bool to_be_signed) +{ + int base_len = CEPH_PREAMBLE_LEN + ctrl_len + CEPH_CRC_LEN; + void *crcp = base + base_len - CEPH_CRC_LEN; + u32 crc; + + crc = crc32c(-1, CTRL_BODY(base), ctrl_len); + if (extdata_len) + crc = crc32c(crc, extdata, extdata_len); + put_unaligned_le32(crc, crcp); + + if (!extdata_len) { + add_out_kvec(con, base, base_len); + if (to_be_signed) + add_out_sign_kvec(con, base, base_len); + return; + } + + add_out_kvec(con, base, crcp - base); + add_out_kvec(con, extdata, extdata_len); + add_out_kvec(con, crcp, CEPH_CRC_LEN); + if (to_be_signed) { + add_out_sign_kvec(con, base, crcp - base); + add_out_sign_kvec(con, extdata, extdata_len); + add_out_sign_kvec(con, crcp, CEPH_CRC_LEN); + } +} + +static int prepare_head_secure_small(struct ceph_connection *con, + void *base, int ctrl_len) +{ + struct scatterlist sg; + int ret; + + /* inline buffer padding? */ + if (ctrl_len < CEPH_PREAMBLE_INLINE_LEN) + memset(CTRL_BODY(base) + ctrl_len, 0, + CEPH_PREAMBLE_INLINE_LEN - ctrl_len); + + sg_init_one(&sg, base, CEPH_PREAMBLE_SECURE_LEN); + ret = gcm_crypt(con, true, &sg, &sg, + CEPH_PREAMBLE_SECURE_LEN - CEPH_GCM_TAG_LEN); + if (ret) + return ret; + + add_out_kvec(con, base, CEPH_PREAMBLE_SECURE_LEN); + return 0; +} + +/* + * base: + * preamble + * control body (ctrl_len bytes) + * space for padding, if needed + * space for control remainder auth tag + * space for preamble auth tag + * + * Encrypt preamble and the inline portion, then encrypt the remainder + * and gather into: + * + * preamble + * control body (48 bytes) + * preamble auth tag + * control body (ctrl_len - 48 bytes) + * zero padding, if needed + * control remainder auth tag + * + * Preamble should already be encoded at the start of base. + */ +static int prepare_head_secure_big(struct ceph_connection *con, + void *base, int ctrl_len) +{ + int rem_len = ctrl_len - CEPH_PREAMBLE_INLINE_LEN; + void *rem = CTRL_BODY(base) + CEPH_PREAMBLE_INLINE_LEN; + void *rem_tag = rem + padded_len(rem_len); + void *pmbl_tag = rem_tag + CEPH_GCM_TAG_LEN; + struct scatterlist sgs[2]; + int ret; + + sg_init_table(sgs, 2); + sg_set_buf(&sgs[0], base, rem - base); + sg_set_buf(&sgs[1], pmbl_tag, CEPH_GCM_TAG_LEN); + ret = gcm_crypt(con, true, sgs, sgs, rem - base); + if (ret) + return ret; + + /* control remainder padding? 
*/ + if (need_padding(rem_len)) + memset(rem + rem_len, 0, padding_len(rem_len)); + + sg_init_one(&sgs[0], rem, pmbl_tag - rem); + ret = gcm_crypt(con, true, sgs, sgs, rem_tag - rem); + if (ret) + return ret; + + add_out_kvec(con, base, rem - base); + add_out_kvec(con, pmbl_tag, CEPH_GCM_TAG_LEN); + add_out_kvec(con, rem, pmbl_tag - rem); + return 0; +} + +static int __prepare_control(struct ceph_connection *con, int tag, + void *base, int ctrl_len, void *extdata, + int extdata_len, bool to_be_signed) +{ + int total_len = ctrl_len + extdata_len; + struct ceph_frame_desc desc; + int ret; + + dout("%s con %p tag %d len %d (%d+%d)\n", __func__, con, tag, + total_len, ctrl_len, extdata_len); + + /* extdata may be vmalloc'ed but not base */ + if (WARN_ON(is_vmalloc_addr(base) || !ctrl_len)) + return -EINVAL; + + init_frame_desc(&desc, tag, &total_len, 1); + encode_preamble(&desc, base); + + if (con_secure(con)) { + if (WARN_ON(extdata_len || to_be_signed)) + return -EINVAL; + + if (ctrl_len <= CEPH_PREAMBLE_INLINE_LEN) + /* fully inlined, inline buffer may need padding */ + ret = prepare_head_secure_small(con, base, ctrl_len); + else + /* partially inlined, inline buffer is full */ + ret = prepare_head_secure_big(con, base, ctrl_len); + if (ret) + return ret; + } else { + prepare_head_plain(con, base, ctrl_len, extdata, extdata_len, + to_be_signed); + } + + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); + return 0; +} + +static int prepare_control(struct ceph_connection *con, int tag, + void *base, int ctrl_len) +{ + return __prepare_control(con, tag, base, ctrl_len, NULL, 0, false); +} + +static int prepare_hello(struct ceph_connection *con) +{ + void *buf, *p; + int ctrl_len; + + ctrl_len = 1 + ceph_entity_addr_encoding_len(&con->peer_addr); + buf = alloc_conn_buf(con, head_onwire_len(ctrl_len, false)); + if (!buf) + return -ENOMEM; + + p = CTRL_BODY(buf); + ceph_encode_8(&p, CEPH_ENTITY_TYPE_CLIENT); + ceph_encode_entity_addr(&p, &con->peer_addr); + WARN_ON(p != CTRL_BODY(buf) + ctrl_len); + + return __prepare_control(con, FRAME_TAG_HELLO, buf, ctrl_len, + NULL, 0, true); +} + +/* so that head_onwire_len(AUTH_BUF_LEN, false) is 512 */ +#define AUTH_BUF_LEN (512 - CEPH_CRC_LEN - CEPH_PREAMBLE_PLAIN_LEN) + +static int prepare_auth_request(struct ceph_connection *con) +{ + void *authorizer, *authorizer_copy; + int ctrl_len, authorizer_len; + void *buf; + int ret; + + ctrl_len = AUTH_BUF_LEN; + buf = alloc_conn_buf(con, head_onwire_len(ctrl_len, false)); + if (!buf) + return -ENOMEM; + + mutex_unlock(&con->mutex); + ret = con->ops->get_auth_request(con, CTRL_BODY(buf), &ctrl_len, + &authorizer, &authorizer_len); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V2_HELLO) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + return -EAGAIN; + } + + dout("%s con %p get_auth_request ret %d\n", __func__, con, ret); + if (ret) + return ret; + + authorizer_copy = alloc_conn_buf(con, authorizer_len); + if (!authorizer_copy) + return -ENOMEM; + + memcpy(authorizer_copy, authorizer, authorizer_len); + + return __prepare_control(con, FRAME_TAG_AUTH_REQUEST, buf, ctrl_len, + authorizer_copy, authorizer_len, true); +} + +static int prepare_auth_request_more(struct ceph_connection *con, + void *reply, int reply_len) +{ + int ctrl_len, authorizer_len; + void *authorizer; + void *buf; + int ret; + + ctrl_len = AUTH_BUF_LEN; + buf = alloc_conn_buf(con, head_onwire_len(ctrl_len, false)); + if (!buf) + return -ENOMEM; + + mutex_unlock(&con->mutex); + ret = 
con->ops->handle_auth_reply_more(con, reply, reply_len, + CTRL_BODY(buf), &ctrl_len, + &authorizer, &authorizer_len); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V2_AUTH) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + return -EAGAIN; + } + + dout("%s con %p handle_auth_reply_more ret %d\n", __func__, con, ret); + if (ret) + return ret; + + return __prepare_control(con, FRAME_TAG_AUTH_REQUEST_MORE, buf, + ctrl_len, authorizer, authorizer_len, true); +} + +static int prepare_auth_signature(struct ceph_connection *con) +{ + void *buf; + int ret; + + buf = alloc_conn_buf(con, head_onwire_len(SHA256_DIGEST_SIZE, + con_secure(con))); + if (!buf) + return -ENOMEM; + + ret = hmac_sha256(con, con->v2.in_sign_kvecs, con->v2.in_sign_kvec_cnt, + CTRL_BODY(buf)); + if (ret) + return ret; + + return prepare_control(con, FRAME_TAG_AUTH_SIGNATURE, buf, + SHA256_DIGEST_SIZE); +} + +static int prepare_client_ident(struct ceph_connection *con) +{ + struct ceph_entity_addr *my_addr = &con->msgr->inst.addr; + struct ceph_client *client = from_msgr(con->msgr); + u64 global_id = ceph_client_gid(client); + void *buf, *p; + int ctrl_len; + + WARN_ON(con->v2.server_cookie); + WARN_ON(con->v2.connect_seq); + WARN_ON(con->v2.peer_global_seq); + + if (!con->v2.client_cookie) { + do { + get_random_bytes(&con->v2.client_cookie, + sizeof(con->v2.client_cookie)); + } while (!con->v2.client_cookie); + dout("%s con %p generated cookie 0x%llx\n", __func__, con, + con->v2.client_cookie); + } else { + dout("%s con %p cookie already set 0x%llx\n", __func__, con, + con->v2.client_cookie); + } + + dout("%s con %p my_addr %s/%u peer_addr %s/%u global_id %llu global_seq %llu features 0x%llx required_features 0x%llx cookie 0x%llx\n", + __func__, con, ceph_pr_addr(my_addr), le32_to_cpu(my_addr->nonce), + ceph_pr_addr(&con->peer_addr), le32_to_cpu(con->peer_addr.nonce), + global_id, con->v2.global_seq, client->supported_features, + client->required_features, con->v2.client_cookie); + + ctrl_len = 1 + 4 + ceph_entity_addr_encoding_len(my_addr) + + ceph_entity_addr_encoding_len(&con->peer_addr) + 6 * 8; + buf = alloc_conn_buf(con, head_onwire_len(ctrl_len, con_secure(con))); + if (!buf) + return -ENOMEM; + + p = CTRL_BODY(buf); + ceph_encode_8(&p, 2); /* addrvec marker */ + ceph_encode_32(&p, 1); /* addr_cnt */ + ceph_encode_entity_addr(&p, my_addr); + ceph_encode_entity_addr(&p, &con->peer_addr); + ceph_encode_64(&p, global_id); + ceph_encode_64(&p, con->v2.global_seq); + ceph_encode_64(&p, client->supported_features); + ceph_encode_64(&p, client->required_features); + ceph_encode_64(&p, 0); /* flags */ + ceph_encode_64(&p, con->v2.client_cookie); + WARN_ON(p != CTRL_BODY(buf) + ctrl_len); + + return prepare_control(con, FRAME_TAG_CLIENT_IDENT, buf, ctrl_len); +} + +static int prepare_session_reconnect(struct ceph_connection *con) +{ + struct ceph_entity_addr *my_addr = &con->msgr->inst.addr; + void *buf, *p; + int ctrl_len; + + WARN_ON(!con->v2.client_cookie); + WARN_ON(!con->v2.server_cookie); + WARN_ON(!con->v2.connect_seq); + WARN_ON(!con->v2.peer_global_seq); + + dout("%s con %p my_addr %s/%u client_cookie 0x%llx server_cookie 0x%llx global_seq %llu connect_seq %llu in_seq %llu\n", + __func__, con, ceph_pr_addr(my_addr), le32_to_cpu(my_addr->nonce), + con->v2.client_cookie, con->v2.server_cookie, con->v2.global_seq, + con->v2.connect_seq, con->in_seq); + + ctrl_len = 1 + 4 + ceph_entity_addr_encoding_len(my_addr) + 5 * 8; + buf = alloc_conn_buf(con, head_onwire_len(ctrl_len, 
con_secure(con))); + if (!buf) + return -ENOMEM; + + p = CTRL_BODY(buf); + ceph_encode_8(&p, 2); /* entity_addrvec_t marker */ + ceph_encode_32(&p, 1); /* my_addrs len */ + ceph_encode_entity_addr(&p, my_addr); + ceph_encode_64(&p, con->v2.client_cookie); + ceph_encode_64(&p, con->v2.server_cookie); + ceph_encode_64(&p, con->v2.global_seq); + ceph_encode_64(&p, con->v2.connect_seq); + ceph_encode_64(&p, con->in_seq); + WARN_ON(p != CTRL_BODY(buf) + ctrl_len); + + return prepare_control(con, FRAME_TAG_SESSION_RECONNECT, buf, ctrl_len); +} + +static int prepare_keepalive2(struct ceph_connection *con) +{ + struct ceph_timespec *ts = CTRL_BODY(con->v2.out_buf); + struct timespec64 now; + + ktime_get_real_ts64(&now); + dout("%s con %p timestamp %lld.%09ld\n", __func__, con, now.tv_sec, + now.tv_nsec); + + ceph_encode_timespec64(ts, &now); + + reset_out_kvecs(con); + return prepare_control(con, FRAME_TAG_KEEPALIVE2, con->v2.out_buf, + sizeof(struct ceph_timespec)); +} + +static int prepare_ack(struct ceph_connection *con) +{ + void *p; + + dout("%s con %p in_seq_acked %llu -> %llu\n", __func__, con, + con->in_seq_acked, con->in_seq); + con->in_seq_acked = con->in_seq; + + p = CTRL_BODY(con->v2.out_buf); + ceph_encode_64(&p, con->in_seq_acked); + + reset_out_kvecs(con); + return prepare_control(con, FRAME_TAG_ACK, con->v2.out_buf, 8); +} + +static void prepare_epilogue_plain(struct ceph_connection *con, bool aborted) +{ + dout("%s con %p msg %p aborted %d crcs %u %u %u\n", __func__, con, + con->out_msg, aborted, con->v2.out_epil.front_crc, + con->v2.out_epil.middle_crc, con->v2.out_epil.data_crc); + + encode_epilogue_plain(con, aborted); + add_out_kvec(con, &con->v2.out_epil, CEPH_EPILOGUE_PLAIN_LEN); +} + +/* + * For "used" empty segments, crc is -1. For unused (trailing) + * segments, crc is 0. + */ +static void prepare_message_plain(struct ceph_connection *con) +{ + struct ceph_msg *msg = con->out_msg; + + prepare_head_plain(con, con->v2.out_buf, + sizeof(struct ceph_msg_header2), NULL, 0, false); + + if (!front_len(msg) && !middle_len(msg)) { + if (!data_len(msg)) { + /* + * Empty message: once the head is written, + * we are done -- there is no epilogue. + */ + con->v2.out_state = OUT_S_FINISH_MESSAGE; + return; + } + + con->v2.out_epil.front_crc = -1; + con->v2.out_epil.middle_crc = -1; + con->v2.out_state = OUT_S_QUEUE_DATA; + return; + } + + if (front_len(msg)) { + con->v2.out_epil.front_crc = crc32c(-1, msg->front.iov_base, + front_len(msg)); + add_out_kvec(con, msg->front.iov_base, front_len(msg)); + } else { + /* middle (at least) is there, checked above */ + con->v2.out_epil.front_crc = -1; + } + + if (middle_len(msg)) { + con->v2.out_epil.middle_crc = + crc32c(-1, msg->middle->vec.iov_base, middle_len(msg)); + add_out_kvec(con, msg->middle->vec.iov_base, middle_len(msg)); + } else { + con->v2.out_epil.middle_crc = data_len(msg) ? -1 : 0; + } + + if (data_len(msg)) { + con->v2.out_state = OUT_S_QUEUE_DATA; + } else { + con->v2.out_epil.data_crc = 0; + prepare_epilogue_plain(con, false); + con->v2.out_state = OUT_S_FINISH_MESSAGE; + } +} + +/* + * Unfortunately the kernel crypto API doesn't support streaming + * (piecewise) operation for AEAD algorithms, so we can't get away + * with a fixed size buffer and a couple sgs. Instead, we have to + * allocate pages for the entire tail of the message (currently up + * to ~32M) and two sgs arrays (up to ~256K each)... 
+ */ +static int prepare_message_secure(struct ceph_connection *con) +{ + void *zerop = page_address(ceph_zero_page); + struct sg_table enc_sgt = {}; + struct sg_table sgt = {}; + struct page **enc_pages; + int enc_page_cnt; + int tail_len; + int ret; + + ret = prepare_head_secure_small(con, con->v2.out_buf, + sizeof(struct ceph_msg_header2)); + if (ret) + return ret; + + tail_len = tail_onwire_len(con->out_msg, true); + if (!tail_len) { + /* + * Empty message: once the head is written, + * we are done -- there is no epilogue. + */ + con->v2.out_state = OUT_S_FINISH_MESSAGE; + return 0; + } + + encode_epilogue_secure(con, false); + ret = setup_message_sgs(&sgt, con->out_msg, zerop, zerop, zerop, + &con->v2.out_epil, false); + if (ret) + goto out; + + enc_page_cnt = calc_pages_for(0, tail_len); + enc_pages = ceph_alloc_page_vector(enc_page_cnt, GFP_NOIO); + if (IS_ERR(enc_pages)) { + ret = PTR_ERR(enc_pages); + goto out; + } + + WARN_ON(con->v2.out_enc_pages || con->v2.out_enc_page_cnt); + con->v2.out_enc_pages = enc_pages; + con->v2.out_enc_page_cnt = enc_page_cnt; + con->v2.out_enc_resid = tail_len; + con->v2.out_enc_i = 0; + + ret = sg_alloc_table_from_pages(&enc_sgt, enc_pages, enc_page_cnt, + 0, tail_len, GFP_NOIO); + if (ret) + goto out; + + ret = gcm_crypt(con, true, sgt.sgl, enc_sgt.sgl, + tail_len - CEPH_GCM_TAG_LEN); + if (ret) + goto out; + + dout("%s con %p msg %p sg_cnt %d enc_page_cnt %d\n", __func__, con, + con->out_msg, sgt.orig_nents, enc_page_cnt); + con->v2.out_state = OUT_S_QUEUE_ENC_PAGE; + +out: + sg_free_table(&sgt); + sg_free_table(&enc_sgt); + return ret; +} + +static int prepare_message(struct ceph_connection *con) +{ + int lens[] = { + sizeof(struct ceph_msg_header2), + front_len(con->out_msg), + middle_len(con->out_msg), + data_len(con->out_msg) + }; + struct ceph_frame_desc desc; + int ret; + + dout("%s con %p msg %p logical %d+%d+%d+%d\n", __func__, con, + con->out_msg, lens[0], lens[1], lens[2], lens[3]); + + if (con->in_seq > con->in_seq_acked) { + dout("%s con %p in_seq_acked %llu -> %llu\n", __func__, con, + con->in_seq_acked, con->in_seq); + con->in_seq_acked = con->in_seq; + } + + reset_out_kvecs(con); + init_frame_desc(&desc, FRAME_TAG_MESSAGE, lens, 4); + encode_preamble(&desc, con->v2.out_buf); + fill_header2(CTRL_BODY(con->v2.out_buf), &con->out_msg->hdr, + con->in_seq_acked); + + if (con_secure(con)) { + ret = prepare_message_secure(con); + if (ret) + return ret; + } else { + prepare_message_plain(con); + } + + ceph_con_flag_set(con, CEPH_CON_F_WRITE_PENDING); + return 0; +} + +static int prepare_read_banner_prefix(struct ceph_connection *con) +{ + void *buf; + + buf = alloc_conn_buf(con, CEPH_BANNER_V2_PREFIX_LEN); + if (!buf) + return -ENOMEM; + + reset_in_kvecs(con); + add_in_kvec(con, buf, CEPH_BANNER_V2_PREFIX_LEN); + add_in_sign_kvec(con, buf, CEPH_BANNER_V2_PREFIX_LEN); + con->state = CEPH_CON_S_V2_BANNER_PREFIX; + return 0; +} + +static int prepare_read_banner_payload(struct ceph_connection *con, + int payload_len) +{ + void *buf; + + buf = alloc_conn_buf(con, payload_len); + if (!buf) + return -ENOMEM; + + reset_in_kvecs(con); + add_in_kvec(con, buf, payload_len); + add_in_sign_kvec(con, buf, payload_len); + con->state = CEPH_CON_S_V2_BANNER_PAYLOAD; + return 0; +} + +static void prepare_read_preamble(struct ceph_connection *con) +{ + reset_in_kvecs(con); + add_in_kvec(con, con->v2.in_buf, + con_secure(con) ? 
CEPH_PREAMBLE_SECURE_LEN : + CEPH_PREAMBLE_PLAIN_LEN); + con->v2.in_state = IN_S_HANDLE_PREAMBLE; +} + +static int prepare_read_control(struct ceph_connection *con) +{ + int ctrl_len = con->v2.in_desc.fd_lens[0]; + int head_len; + void *buf; + + reset_in_kvecs(con); + if (con->state == CEPH_CON_S_V2_HELLO || + con->state == CEPH_CON_S_V2_AUTH) { + head_len = head_onwire_len(ctrl_len, false); + buf = alloc_conn_buf(con, head_len); + if (!buf) + return -ENOMEM; + + /* preserve preamble */ + memcpy(buf, con->v2.in_buf, CEPH_PREAMBLE_LEN); + + add_in_kvec(con, CTRL_BODY(buf), ctrl_len); + add_in_kvec(con, CTRL_BODY(buf) + ctrl_len, CEPH_CRC_LEN); + add_in_sign_kvec(con, buf, head_len); + } else { + if (ctrl_len > CEPH_PREAMBLE_INLINE_LEN) { + buf = alloc_conn_buf(con, ctrl_len); + if (!buf) + return -ENOMEM; + + add_in_kvec(con, buf, ctrl_len); + } else { + add_in_kvec(con, CTRL_BODY(con->v2.in_buf), ctrl_len); + } + add_in_kvec(con, con->v2.in_buf, CEPH_CRC_LEN); + } + con->v2.in_state = IN_S_HANDLE_CONTROL; + return 0; +} + +static int prepare_read_control_remainder(struct ceph_connection *con) +{ + int ctrl_len = con->v2.in_desc.fd_lens[0]; + int rem_len = ctrl_len - CEPH_PREAMBLE_INLINE_LEN; + void *buf; + + buf = alloc_conn_buf(con, ctrl_len); + if (!buf) + return -ENOMEM; + + memcpy(buf, CTRL_BODY(con->v2.in_buf), CEPH_PREAMBLE_INLINE_LEN); + + reset_in_kvecs(con); + add_in_kvec(con, buf + CEPH_PREAMBLE_INLINE_LEN, rem_len); + add_in_kvec(con, con->v2.in_buf, + padding_len(rem_len) + CEPH_GCM_TAG_LEN); + con->v2.in_state = IN_S_HANDLE_CONTROL_REMAINDER; + return 0; +} + +static int prepare_read_data(struct ceph_connection *con) +{ + struct bio_vec bv; + + con->in_data_crc = -1; + ceph_msg_data_cursor_init(&con->v2.in_cursor, con->in_msg, + data_len(con->in_msg)); + + get_bvec_at(&con->v2.in_cursor, &bv); + if (ceph_test_opt(from_msgr(con->msgr), RXBOUNCE)) { + if (unlikely(!con->bounce_page)) { + con->bounce_page = alloc_page(GFP_NOIO); + if (!con->bounce_page) { + pr_err("failed to allocate bounce page\n"); + return -ENOMEM; + } + } + + bv.bv_page = con->bounce_page; + bv.bv_offset = 0; + } + set_in_bvec(con, &bv); + con->v2.in_state = IN_S_PREPARE_READ_DATA_CONT; + return 0; +} + +static void prepare_read_data_cont(struct ceph_connection *con) +{ + struct bio_vec bv; + + if (ceph_test_opt(from_msgr(con->msgr), RXBOUNCE)) { + con->in_data_crc = crc32c(con->in_data_crc, + page_address(con->bounce_page), + con->v2.in_bvec.bv_len); + + get_bvec_at(&con->v2.in_cursor, &bv); + memcpy_to_page(bv.bv_page, bv.bv_offset, + page_address(con->bounce_page), + con->v2.in_bvec.bv_len); + } else { + con->in_data_crc = ceph_crc32c_page(con->in_data_crc, + con->v2.in_bvec.bv_page, + con->v2.in_bvec.bv_offset, + con->v2.in_bvec.bv_len); + } + + ceph_msg_data_advance(&con->v2.in_cursor, con->v2.in_bvec.bv_len); + if (con->v2.in_cursor.total_resid) { + get_bvec_at(&con->v2.in_cursor, &bv); + if (ceph_test_opt(from_msgr(con->msgr), RXBOUNCE)) { + bv.bv_page = con->bounce_page; + bv.bv_offset = 0; + } + set_in_bvec(con, &bv); + WARN_ON(con->v2.in_state != IN_S_PREPARE_READ_DATA_CONT); + return; + } + + /* + * We've read all data. Prepare to read epilogue. 
+ */ + reset_in_kvecs(con); + add_in_kvec(con, con->v2.in_buf, CEPH_EPILOGUE_PLAIN_LEN); + con->v2.in_state = IN_S_HANDLE_EPILOGUE; +} + +static int prepare_read_tail_plain(struct ceph_connection *con) +{ + struct ceph_msg *msg = con->in_msg; + + if (!front_len(msg) && !middle_len(msg)) { + WARN_ON(!data_len(msg)); + return prepare_read_data(con); + } + + reset_in_kvecs(con); + if (front_len(msg)) { + add_in_kvec(con, msg->front.iov_base, front_len(msg)); + WARN_ON(msg->front.iov_len != front_len(msg)); + } + if (middle_len(msg)) { + add_in_kvec(con, msg->middle->vec.iov_base, middle_len(msg)); + WARN_ON(msg->middle->vec.iov_len != middle_len(msg)); + } + + if (data_len(msg)) { + con->v2.in_state = IN_S_PREPARE_READ_DATA; + } else { + add_in_kvec(con, con->v2.in_buf, CEPH_EPILOGUE_PLAIN_LEN); + con->v2.in_state = IN_S_HANDLE_EPILOGUE; + } + return 0; +} + +static void prepare_read_enc_page(struct ceph_connection *con) +{ + struct bio_vec bv; + + dout("%s con %p i %d resid %d\n", __func__, con, con->v2.in_enc_i, + con->v2.in_enc_resid); + WARN_ON(!con->v2.in_enc_resid); + + bv.bv_page = con->v2.in_enc_pages[con->v2.in_enc_i]; + bv.bv_offset = 0; + bv.bv_len = min(con->v2.in_enc_resid, (int)PAGE_SIZE); + + set_in_bvec(con, &bv); + con->v2.in_enc_i++; + con->v2.in_enc_resid -= bv.bv_len; + + if (con->v2.in_enc_resid) { + con->v2.in_state = IN_S_PREPARE_READ_ENC_PAGE; + return; + } + + /* + * We are set to read the last piece of ciphertext (ending + * with epilogue) + auth tag. + */ + WARN_ON(con->v2.in_enc_i != con->v2.in_enc_page_cnt); + con->v2.in_state = IN_S_HANDLE_EPILOGUE; +} + +static int prepare_read_tail_secure(struct ceph_connection *con) +{ + struct page **enc_pages; + int enc_page_cnt; + int tail_len; + + tail_len = tail_onwire_len(con->in_msg, true); + WARN_ON(!tail_len); + + enc_page_cnt = calc_pages_for(0, tail_len); + enc_pages = ceph_alloc_page_vector(enc_page_cnt, GFP_NOIO); + if (IS_ERR(enc_pages)) + return PTR_ERR(enc_pages); + + WARN_ON(con->v2.in_enc_pages || con->v2.in_enc_page_cnt); + con->v2.in_enc_pages = enc_pages; + con->v2.in_enc_page_cnt = enc_page_cnt; + con->v2.in_enc_resid = tail_len; + con->v2.in_enc_i = 0; + + prepare_read_enc_page(con); + return 0; +} + +static void __finish_skip(struct ceph_connection *con) +{ + con->in_seq++; + prepare_read_preamble(con); +} + +static void prepare_skip_message(struct ceph_connection *con) +{ + struct ceph_frame_desc *desc = &con->v2.in_desc; + int tail_len; + + dout("%s con %p %d+%d+%d\n", __func__, con, desc->fd_lens[1], + desc->fd_lens[2], desc->fd_lens[3]); + + tail_len = __tail_onwire_len(desc->fd_lens[1], desc->fd_lens[2], + desc->fd_lens[3], con_secure(con)); + if (!tail_len) { + __finish_skip(con); + } else { + set_in_skip(con, tail_len); + con->v2.in_state = IN_S_FINISH_SKIP; + } +} + +static int process_banner_prefix(struct ceph_connection *con) +{ + int payload_len; + void *p; + + WARN_ON(con->v2.in_kvecs[0].iov_len != CEPH_BANNER_V2_PREFIX_LEN); + + p = con->v2.in_kvecs[0].iov_base; + if (memcmp(p, CEPH_BANNER_V2, CEPH_BANNER_V2_LEN)) { + if (!memcmp(p, CEPH_BANNER, CEPH_BANNER_LEN)) + con->error_msg = "server is speaking msgr1 protocol"; + else + con->error_msg = "protocol error, bad banner"; + return -EINVAL; + } + + p += CEPH_BANNER_V2_LEN; + payload_len = ceph_decode_16(&p); + dout("%s con %p payload_len %d\n", __func__, con, payload_len); + + return prepare_read_banner_payload(con, payload_len); +} + +static int process_banner_payload(struct ceph_connection *con) +{ + void *end = 
con->v2.in_kvecs[0].iov_base + con->v2.in_kvecs[0].iov_len; + u64 feat = CEPH_MSGR2_SUPPORTED_FEATURES; + u64 req_feat = CEPH_MSGR2_REQUIRED_FEATURES; + u64 server_feat, server_req_feat; + void *p; + int ret; + + p = con->v2.in_kvecs[0].iov_base; + ceph_decode_64_safe(&p, end, server_feat, bad); + ceph_decode_64_safe(&p, end, server_req_feat, bad); + + dout("%s con %p server_feat 0x%llx server_req_feat 0x%llx\n", + __func__, con, server_feat, server_req_feat); + + if (req_feat & ~server_feat) { + pr_err("msgr2 feature set mismatch: my required > server's supported 0x%llx, need 0x%llx\n", + server_feat, req_feat & ~server_feat); + con->error_msg = "missing required protocol features"; + return -EINVAL; + } + if (server_req_feat & ~feat) { + pr_err("msgr2 feature set mismatch: server's required > my supported 0x%llx, missing 0x%llx\n", + feat, server_req_feat & ~feat); + con->error_msg = "missing required protocol features"; + return -EINVAL; + } + + /* no reset_out_kvecs() as our banner may still be pending */ + ret = prepare_hello(con); + if (ret) { + pr_err("prepare_hello failed: %d\n", ret); + return ret; + } + + con->state = CEPH_CON_S_V2_HELLO; + prepare_read_preamble(con); + return 0; + +bad: + pr_err("failed to decode banner payload\n"); + return -EINVAL; +} + +static int process_hello(struct ceph_connection *con, void *p, void *end) +{ + struct ceph_entity_addr *my_addr = &con->msgr->inst.addr; + struct ceph_entity_addr addr_for_me; + u8 entity_type; + int ret; + + if (con->state != CEPH_CON_S_V2_HELLO) { + con->error_msg = "protocol error, unexpected hello"; + return -EINVAL; + } + + ceph_decode_8_safe(&p, end, entity_type, bad); + ret = ceph_decode_entity_addr(&p, end, &addr_for_me); + if (ret) { + pr_err("failed to decode addr_for_me: %d\n", ret); + return ret; + } + + dout("%s con %p entity_type %d addr_for_me %s\n", __func__, con, + entity_type, ceph_pr_addr(&addr_for_me)); + + if (entity_type != con->peer_name.type) { + pr_err("bad peer type, want %d, got %d\n", + con->peer_name.type, entity_type); + con->error_msg = "wrong peer at address"; + return -EINVAL; + } + + /* + * Set our address to the address our first peer (i.e. monitor) + * sees that we are connecting from. If we are behind some sort + * of NAT and want to be identified by some private (not NATed) + * address, ip option should be used. 
+ */ + if (ceph_addr_is_blank(my_addr)) { + memcpy(&my_addr->in_addr, &addr_for_me.in_addr, + sizeof(my_addr->in_addr)); + ceph_addr_set_port(my_addr, 0); + dout("%s con %p set my addr %s, as seen by peer %s\n", + __func__, con, ceph_pr_addr(my_addr), + ceph_pr_addr(&con->peer_addr)); + } else { + dout("%s con %p my addr already set %s\n", + __func__, con, ceph_pr_addr(my_addr)); + } + + WARN_ON(ceph_addr_is_blank(my_addr) || ceph_addr_port(my_addr)); + WARN_ON(my_addr->type != CEPH_ENTITY_ADDR_TYPE_ANY); + WARN_ON(!my_addr->nonce); + + /* no reset_out_kvecs() as our hello may still be pending */ + ret = prepare_auth_request(con); + if (ret) { + if (ret != -EAGAIN) + pr_err("prepare_auth_request failed: %d\n", ret); + return ret; + } + + con->state = CEPH_CON_S_V2_AUTH; + return 0; + +bad: + pr_err("failed to decode hello\n"); + return -EINVAL; +} + +static int process_auth_bad_method(struct ceph_connection *con, + void *p, void *end) +{ + int allowed_protos[8], allowed_modes[8]; + int allowed_proto_cnt, allowed_mode_cnt; + int used_proto, result; + int ret; + int i; + + if (con->state != CEPH_CON_S_V2_AUTH) { + con->error_msg = "protocol error, unexpected auth_bad_method"; + return -EINVAL; + } + + ceph_decode_32_safe(&p, end, used_proto, bad); + ceph_decode_32_safe(&p, end, result, bad); + dout("%s con %p used_proto %d result %d\n", __func__, con, used_proto, + result); + + ceph_decode_32_safe(&p, end, allowed_proto_cnt, bad); + if (allowed_proto_cnt > ARRAY_SIZE(allowed_protos)) { + pr_err("allowed_protos too big %d\n", allowed_proto_cnt); + return -EINVAL; + } + for (i = 0; i < allowed_proto_cnt; i++) { + ceph_decode_32_safe(&p, end, allowed_protos[i], bad); + dout("%s con %p allowed_protos[%d] %d\n", __func__, con, + i, allowed_protos[i]); + } + + ceph_decode_32_safe(&p, end, allowed_mode_cnt, bad); + if (allowed_mode_cnt > ARRAY_SIZE(allowed_modes)) { + pr_err("allowed_modes too big %d\n", allowed_mode_cnt); + return -EINVAL; + } + for (i = 0; i < allowed_mode_cnt; i++) { + ceph_decode_32_safe(&p, end, allowed_modes[i], bad); + dout("%s con %p allowed_modes[%d] %d\n", __func__, con, + i, allowed_modes[i]); + } + + mutex_unlock(&con->mutex); + ret = con->ops->handle_auth_bad_method(con, used_proto, result, + allowed_protos, + allowed_proto_cnt, + allowed_modes, + allowed_mode_cnt); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V2_AUTH) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + return -EAGAIN; + } + + dout("%s con %p handle_auth_bad_method ret %d\n", __func__, con, ret); + return ret; + +bad: + pr_err("failed to decode auth_bad_method\n"); + return -EINVAL; +} + +static int process_auth_reply_more(struct ceph_connection *con, + void *p, void *end) +{ + int payload_len; + int ret; + + if (con->state != CEPH_CON_S_V2_AUTH) { + con->error_msg = "protocol error, unexpected auth_reply_more"; + return -EINVAL; + } + + ceph_decode_32_safe(&p, end, payload_len, bad); + ceph_decode_need(&p, end, payload_len, bad); + + dout("%s con %p payload_len %d\n", __func__, con, payload_len); + + reset_out_kvecs(con); + ret = prepare_auth_request_more(con, p, payload_len); + if (ret) { + if (ret != -EAGAIN) + pr_err("prepare_auth_request_more failed: %d\n", ret); + return ret; + } + + return 0; + +bad: + pr_err("failed to decode auth_reply_more\n"); + return -EINVAL; +} + +/* + * Align session_key and con_secret to avoid GFP_ATOMIC allocation + * inside crypto_shash_setkey() and crypto_aead_setkey() called from + * setup_crypto(). 
__aligned(16) isn't guaranteed to work for stack + * objects, so do it by hand. + */ +static int process_auth_done(struct ceph_connection *con, void *p, void *end) +{ + u8 session_key_buf[CEPH_KEY_LEN + 16]; + u8 con_secret_buf[CEPH_MAX_CON_SECRET_LEN + 16]; + u8 *session_key = PTR_ALIGN(&session_key_buf[0], 16); + u8 *con_secret = PTR_ALIGN(&con_secret_buf[0], 16); + int session_key_len, con_secret_len; + int payload_len; + u64 global_id; + int ret; + + if (con->state != CEPH_CON_S_V2_AUTH) { + con->error_msg = "protocol error, unexpected auth_done"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, global_id, bad); + ceph_decode_32_safe(&p, end, con->v2.con_mode, bad); + ceph_decode_32_safe(&p, end, payload_len, bad); + + dout("%s con %p global_id %llu con_mode %d payload_len %d\n", + __func__, con, global_id, con->v2.con_mode, payload_len); + + mutex_unlock(&con->mutex); + session_key_len = 0; + con_secret_len = 0; + ret = con->ops->handle_auth_done(con, global_id, p, payload_len, + session_key, &session_key_len, + con_secret, &con_secret_len); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V2_AUTH) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + ret = -EAGAIN; + goto out; + } + + dout("%s con %p handle_auth_done ret %d\n", __func__, con, ret); + if (ret) + goto out; + + ret = setup_crypto(con, session_key, session_key_len, con_secret, + con_secret_len); + if (ret) + goto out; + + reset_out_kvecs(con); + ret = prepare_auth_signature(con); + if (ret) { + pr_err("prepare_auth_signature failed: %d\n", ret); + goto out; + } + + con->state = CEPH_CON_S_V2_AUTH_SIGNATURE; + +out: + memzero_explicit(session_key_buf, sizeof(session_key_buf)); + memzero_explicit(con_secret_buf, sizeof(con_secret_buf)); + return ret; + +bad: + pr_err("failed to decode auth_done\n"); + return -EINVAL; +} + +static int process_auth_signature(struct ceph_connection *con, + void *p, void *end) +{ + u8 hmac[SHA256_DIGEST_SIZE]; + int ret; + + if (con->state != CEPH_CON_S_V2_AUTH_SIGNATURE) { + con->error_msg = "protocol error, unexpected auth_signature"; + return -EINVAL; + } + + ret = hmac_sha256(con, con->v2.out_sign_kvecs, + con->v2.out_sign_kvec_cnt, hmac); + if (ret) + return ret; + + ceph_decode_need(&p, end, SHA256_DIGEST_SIZE, bad); + if (crypto_memneq(p, hmac, SHA256_DIGEST_SIZE)) { + con->error_msg = "integrity error, bad auth signature"; + return -EBADMSG; + } + + dout("%s con %p auth signature ok\n", __func__, con); + + /* no reset_out_kvecs() as our auth_signature may still be pending */ + if (!con->v2.server_cookie) { + ret = prepare_client_ident(con); + if (ret) { + pr_err("prepare_client_ident failed: %d\n", ret); + return ret; + } + + con->state = CEPH_CON_S_V2_SESSION_CONNECT; + } else { + ret = prepare_session_reconnect(con); + if (ret) { + pr_err("prepare_session_reconnect failed: %d\n", ret); + return ret; + } + + con->state = CEPH_CON_S_V2_SESSION_RECONNECT; + } + + return 0; + +bad: + pr_err("failed to decode auth_signature\n"); + return -EINVAL; +} + +static int process_server_ident(struct ceph_connection *con, + void *p, void *end) +{ + struct ceph_client *client = from_msgr(con->msgr); + u64 features, required_features; + struct ceph_entity_addr addr; + u64 global_seq; + u64 global_id; + u64 cookie; + u64 flags; + int ret; + + if (con->state != CEPH_CON_S_V2_SESSION_CONNECT) { + con->error_msg = "protocol error, unexpected server_ident"; + return -EINVAL; + } + + ret = ceph_decode_entity_addrvec(&p, end, true, &addr); + if (ret) { + pr_err("failed 
to decode server addrs: %d\n", ret); + return ret; + } + + ceph_decode_64_safe(&p, end, global_id, bad); + ceph_decode_64_safe(&p, end, global_seq, bad); + ceph_decode_64_safe(&p, end, features, bad); + ceph_decode_64_safe(&p, end, required_features, bad); + ceph_decode_64_safe(&p, end, flags, bad); + ceph_decode_64_safe(&p, end, cookie, bad); + + dout("%s con %p addr %s/%u global_id %llu global_seq %llu features 0x%llx required_features 0x%llx flags 0x%llx cookie 0x%llx\n", + __func__, con, ceph_pr_addr(&addr), le32_to_cpu(addr.nonce), + global_id, global_seq, features, required_features, flags, cookie); + + /* is this who we intended to talk to? */ + if (memcmp(&addr, &con->peer_addr, sizeof(con->peer_addr))) { + pr_err("bad peer addr/nonce, want %s/%u, got %s/%u\n", + ceph_pr_addr(&con->peer_addr), + le32_to_cpu(con->peer_addr.nonce), + ceph_pr_addr(&addr), le32_to_cpu(addr.nonce)); + con->error_msg = "wrong peer at address"; + return -EINVAL; + } + + if (client->required_features & ~features) { + pr_err("RADOS feature set mismatch: my required > server's supported 0x%llx, need 0x%llx\n", + features, client->required_features & ~features); + con->error_msg = "missing required protocol features"; + return -EINVAL; + } + + /* + * Both name->type and name->num are set in ceph_con_open() but + * name->num may be bogus in the initial monmap. name->type is + * verified in handle_hello(). + */ + WARN_ON(!con->peer_name.type); + con->peer_name.num = cpu_to_le64(global_id); + con->v2.peer_global_seq = global_seq; + con->peer_features = features; + WARN_ON(required_features & ~client->supported_features); + con->v2.server_cookie = cookie; + + if (flags & CEPH_MSG_CONNECT_LOSSY) { + ceph_con_flag_set(con, CEPH_CON_F_LOSSYTX); + WARN_ON(con->v2.server_cookie); + } else { + WARN_ON(!con->v2.server_cookie); + } + + clear_in_sign_kvecs(con); + clear_out_sign_kvecs(con); + free_conn_bufs(con); + con->delay = 0; /* reset backoff memory */ + + con->state = CEPH_CON_S_OPEN; + con->v2.out_state = OUT_S_GET_NEXT; + return 0; + +bad: + pr_err("failed to decode server_ident\n"); + return -EINVAL; +} + +static int process_ident_missing_features(struct ceph_connection *con, + void *p, void *end) +{ + struct ceph_client *client = from_msgr(con->msgr); + u64 missing_features; + + if (con->state != CEPH_CON_S_V2_SESSION_CONNECT) { + con->error_msg = "protocol error, unexpected ident_missing_features"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, missing_features, bad); + pr_err("RADOS feature set mismatch: server's required > my supported 0x%llx, missing 0x%llx\n", + client->supported_features, missing_features); + con->error_msg = "missing required protocol features"; + return -EINVAL; + +bad: + pr_err("failed to decode ident_missing_features\n"); + return -EINVAL; +} + +static int process_session_reconnect_ok(struct ceph_connection *con, + void *p, void *end) +{ + u64 seq; + + if (con->state != CEPH_CON_S_V2_SESSION_RECONNECT) { + con->error_msg = "protocol error, unexpected session_reconnect_ok"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, seq, bad); + + dout("%s con %p seq %llu\n", __func__, con, seq); + ceph_con_discard_requeued(con, seq); + + clear_in_sign_kvecs(con); + clear_out_sign_kvecs(con); + free_conn_bufs(con); + con->delay = 0; /* reset backoff memory */ + + con->state = CEPH_CON_S_OPEN; + con->v2.out_state = OUT_S_GET_NEXT; + return 0; + +bad: + pr_err("failed to decode session_reconnect_ok\n"); + return -EINVAL; +} + +static int process_session_retry(struct ceph_connection 
*con, + void *p, void *end) +{ + u64 connect_seq; + int ret; + + if (con->state != CEPH_CON_S_V2_SESSION_RECONNECT) { + con->error_msg = "protocol error, unexpected session_retry"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, connect_seq, bad); + + dout("%s con %p connect_seq %llu\n", __func__, con, connect_seq); + WARN_ON(connect_seq <= con->v2.connect_seq); + con->v2.connect_seq = connect_seq + 1; + + free_conn_bufs(con); + + reset_out_kvecs(con); + ret = prepare_session_reconnect(con); + if (ret) { + pr_err("prepare_session_reconnect (cseq) failed: %d\n", ret); + return ret; + } + + return 0; + +bad: + pr_err("failed to decode session_retry\n"); + return -EINVAL; +} + +static int process_session_retry_global(struct ceph_connection *con, + void *p, void *end) +{ + u64 global_seq; + int ret; + + if (con->state != CEPH_CON_S_V2_SESSION_RECONNECT) { + con->error_msg = "protocol error, unexpected session_retry_global"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, global_seq, bad); + + dout("%s con %p global_seq %llu\n", __func__, con, global_seq); + WARN_ON(global_seq <= con->v2.global_seq); + con->v2.global_seq = ceph_get_global_seq(con->msgr, global_seq); + + free_conn_bufs(con); + + reset_out_kvecs(con); + ret = prepare_session_reconnect(con); + if (ret) { + pr_err("prepare_session_reconnect (gseq) failed: %d\n", ret); + return ret; + } + + return 0; + +bad: + pr_err("failed to decode session_retry_global\n"); + return -EINVAL; +} + +static int process_session_reset(struct ceph_connection *con, + void *p, void *end) +{ + bool full; + int ret; + + if (con->state != CEPH_CON_S_V2_SESSION_RECONNECT) { + con->error_msg = "protocol error, unexpected session_reset"; + return -EINVAL; + } + + ceph_decode_8_safe(&p, end, full, bad); + if (!full) { + con->error_msg = "protocol error, bad session_reset"; + return -EINVAL; + } + + pr_info("%s%lld %s session reset\n", ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr)); + ceph_con_reset_session(con); + + mutex_unlock(&con->mutex); + if (con->ops->peer_reset) + con->ops->peer_reset(con); + mutex_lock(&con->mutex); + if (con->state != CEPH_CON_S_V2_SESSION_RECONNECT) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + return -EAGAIN; + } + + free_conn_bufs(con); + + reset_out_kvecs(con); + ret = prepare_client_ident(con); + if (ret) { + pr_err("prepare_client_ident (rst) failed: %d\n", ret); + return ret; + } + + con->state = CEPH_CON_S_V2_SESSION_CONNECT; + return 0; + +bad: + pr_err("failed to decode session_reset\n"); + return -EINVAL; +} + +static int process_keepalive2_ack(struct ceph_connection *con, + void *p, void *end) +{ + if (con->state != CEPH_CON_S_OPEN) { + con->error_msg = "protocol error, unexpected keepalive2_ack"; + return -EINVAL; + } + + ceph_decode_need(&p, end, sizeof(struct ceph_timespec), bad); + ceph_decode_timespec64(&con->last_keepalive_ack, p); + + dout("%s con %p timestamp %lld.%09ld\n", __func__, con, + con->last_keepalive_ack.tv_sec, con->last_keepalive_ack.tv_nsec); + + return 0; + +bad: + pr_err("failed to decode keepalive2_ack\n"); + return -EINVAL; +} + +static int process_ack(struct ceph_connection *con, void *p, void *end) +{ + u64 seq; + + if (con->state != CEPH_CON_S_OPEN) { + con->error_msg = "protocol error, unexpected ack"; + return -EINVAL; + } + + ceph_decode_64_safe(&p, end, seq, bad); + + dout("%s con %p seq %llu\n", __func__, con, seq); + ceph_con_discard_sent(con, seq); + return 0; + +bad: + pr_err("failed to decode ack\n"); + return -EINVAL; +} + 
+static int process_control(struct ceph_connection *con, void *p, void *end) +{ + int tag = con->v2.in_desc.fd_tag; + int ret; + + dout("%s con %p tag %d len %d\n", __func__, con, tag, (int)(end - p)); + + switch (tag) { + case FRAME_TAG_HELLO: + ret = process_hello(con, p, end); + break; + case FRAME_TAG_AUTH_BAD_METHOD: + ret = process_auth_bad_method(con, p, end); + break; + case FRAME_TAG_AUTH_REPLY_MORE: + ret = process_auth_reply_more(con, p, end); + break; + case FRAME_TAG_AUTH_DONE: + ret = process_auth_done(con, p, end); + break; + case FRAME_TAG_AUTH_SIGNATURE: + ret = process_auth_signature(con, p, end); + break; + case FRAME_TAG_SERVER_IDENT: + ret = process_server_ident(con, p, end); + break; + case FRAME_TAG_IDENT_MISSING_FEATURES: + ret = process_ident_missing_features(con, p, end); + break; + case FRAME_TAG_SESSION_RECONNECT_OK: + ret = process_session_reconnect_ok(con, p, end); + break; + case FRAME_TAG_SESSION_RETRY: + ret = process_session_retry(con, p, end); + break; + case FRAME_TAG_SESSION_RETRY_GLOBAL: + ret = process_session_retry_global(con, p, end); + break; + case FRAME_TAG_SESSION_RESET: + ret = process_session_reset(con, p, end); + break; + case FRAME_TAG_KEEPALIVE2_ACK: + ret = process_keepalive2_ack(con, p, end); + break; + case FRAME_TAG_ACK: + ret = process_ack(con, p, end); + break; + default: + pr_err("bad tag %d\n", tag); + con->error_msg = "protocol error, bad tag"; + return -EINVAL; + } + if (ret) { + dout("%s con %p error %d\n", __func__, con, ret); + return ret; + } + + prepare_read_preamble(con); + return 0; +} + +/* + * Return: + * 1 - con->in_msg set, read message + * 0 - skip message + * <0 - error + */ +static int process_message_header(struct ceph_connection *con, + void *p, void *end) +{ + struct ceph_frame_desc *desc = &con->v2.in_desc; + struct ceph_msg_header2 *hdr2 = p; + struct ceph_msg_header hdr; + int skip; + int ret; + u64 seq; + + /* verify seq# */ + seq = le64_to_cpu(hdr2->seq); + if ((s64)seq - (s64)con->in_seq < 1) { + pr_info("%s%lld %s skipping old message: seq %llu, expected %llu\n", + ENTITY_NAME(con->peer_name), + ceph_pr_addr(&con->peer_addr), + seq, con->in_seq + 1); + return 0; + } + if ((s64)seq - (s64)con->in_seq > 1) { + pr_err("bad seq %llu, expected %llu\n", seq, con->in_seq + 1); + con->error_msg = "bad message sequence # for incoming message"; + return -EBADE; + } + + ceph_con_discard_sent(con, le64_to_cpu(hdr2->ack_seq)); + + fill_header(&hdr, hdr2, desc->fd_lens[1], desc->fd_lens[2], + desc->fd_lens[3], &con->peer_name); + ret = ceph_con_in_msg_alloc(con, &hdr, &skip); + if (ret) + return ret; + + WARN_ON(!con->in_msg ^ skip); + if (skip) + return 0; + + WARN_ON(!con->in_msg); + WARN_ON(con->in_msg->con != con); + return 1; +} + +static int process_message(struct ceph_connection *con) +{ + ceph_con_process_message(con); + + /* + * We could have been closed by ceph_con_close() because + * ceph_con_process_message() temporarily drops con->mutex. 
+ */ + if (con->state != CEPH_CON_S_OPEN) { + dout("%s con %p state changed to %d\n", __func__, con, + con->state); + return -EAGAIN; + } + + prepare_read_preamble(con); + return 0; +} + +static int __handle_control(struct ceph_connection *con, void *p) +{ + void *end = p + con->v2.in_desc.fd_lens[0]; + struct ceph_msg *msg; + int ret; + + if (con->v2.in_desc.fd_tag != FRAME_TAG_MESSAGE) + return process_control(con, p, end); + + ret = process_message_header(con, p, end); + if (ret < 0) + return ret; + if (ret == 0) { + prepare_skip_message(con); + return 0; + } + + msg = con->in_msg; /* set in process_message_header() */ + if (front_len(msg)) { + WARN_ON(front_len(msg) > msg->front_alloc_len); + msg->front.iov_len = front_len(msg); + } else { + msg->front.iov_len = 0; + } + if (middle_len(msg)) { + WARN_ON(middle_len(msg) > msg->middle->alloc_len); + msg->middle->vec.iov_len = middle_len(msg); + } else if (msg->middle) { + msg->middle->vec.iov_len = 0; + } + + if (!front_len(msg) && !middle_len(msg) && !data_len(msg)) + return process_message(con); + + if (con_secure(con)) + return prepare_read_tail_secure(con); + + return prepare_read_tail_plain(con); +} + +static int handle_preamble(struct ceph_connection *con) +{ + struct ceph_frame_desc *desc = &con->v2.in_desc; + int ret; + + if (con_secure(con)) { + ret = decrypt_preamble(con); + if (ret) { + if (ret == -EBADMSG) + con->error_msg = "integrity error, bad preamble auth tag"; + return ret; + } + } + + ret = decode_preamble(con->v2.in_buf, desc); + if (ret) { + if (ret == -EBADMSG) + con->error_msg = "integrity error, bad crc"; + else + con->error_msg = "protocol error, bad preamble"; + return ret; + } + + dout("%s con %p tag %d seg_cnt %d %d+%d+%d+%d\n", __func__, + con, desc->fd_tag, desc->fd_seg_cnt, desc->fd_lens[0], + desc->fd_lens[1], desc->fd_lens[2], desc->fd_lens[3]); + + if (!con_secure(con)) + return prepare_read_control(con); + + if (desc->fd_lens[0] > CEPH_PREAMBLE_INLINE_LEN) + return prepare_read_control_remainder(con); + + return __handle_control(con, CTRL_BODY(con->v2.in_buf)); +} + +static int handle_control(struct ceph_connection *con) +{ + int ctrl_len = con->v2.in_desc.fd_lens[0]; + void *buf; + int ret; + + WARN_ON(con_secure(con)); + + ret = verify_control_crc(con); + if (ret) { + con->error_msg = "integrity error, bad crc"; + return ret; + } + + if (con->state == CEPH_CON_S_V2_AUTH) { + buf = alloc_conn_buf(con, ctrl_len); + if (!buf) + return -ENOMEM; + + memcpy(buf, con->v2.in_kvecs[0].iov_base, ctrl_len); + return __handle_control(con, buf); + } + + return __handle_control(con, con->v2.in_kvecs[0].iov_base); +} + +static int handle_control_remainder(struct ceph_connection *con) +{ + int ret; + + WARN_ON(!con_secure(con)); + + ret = decrypt_control_remainder(con); + if (ret) { + if (ret == -EBADMSG) + con->error_msg = "integrity error, bad control remainder auth tag"; + return ret; + } + + return __handle_control(con, con->v2.in_kvecs[0].iov_base - + CEPH_PREAMBLE_INLINE_LEN); +} + +static int handle_epilogue(struct ceph_connection *con) +{ + u32 front_crc, middle_crc, data_crc; + int ret; + + if (con_secure(con)) { + ret = decrypt_tail(con); + if (ret) { + if (ret == -EBADMSG) + con->error_msg = "integrity error, bad epilogue auth tag"; + return ret; + } + + /* just late_status */ + ret = decode_epilogue(con->v2.in_buf, NULL, NULL, NULL); + if (ret) { + con->error_msg = "protocol error, bad epilogue"; + return ret; + } + } else { + ret = decode_epilogue(con->v2.in_buf, &front_crc, + &middle_crc, &data_crc); + 
if (ret) { + con->error_msg = "protocol error, bad epilogue"; + return ret; + } + + ret = verify_epilogue_crcs(con, front_crc, middle_crc, + data_crc); + if (ret) { + con->error_msg = "integrity error, bad crc"; + return ret; + } + } + + return process_message(con); +} + +static void finish_skip(struct ceph_connection *con) +{ + dout("%s con %p\n", __func__, con); + + if (con_secure(con)) + gcm_inc_nonce(&con->v2.in_gcm_nonce); + + __finish_skip(con); +} + +static int populate_in_iter(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p state %d in_state %d\n", __func__, con, con->state, + con->v2.in_state); + WARN_ON(iov_iter_count(&con->v2.in_iter)); + + if (con->state == CEPH_CON_S_V2_BANNER_PREFIX) { + ret = process_banner_prefix(con); + } else if (con->state == CEPH_CON_S_V2_BANNER_PAYLOAD) { + ret = process_banner_payload(con); + } else if ((con->state >= CEPH_CON_S_V2_HELLO && + con->state <= CEPH_CON_S_V2_SESSION_RECONNECT) || + con->state == CEPH_CON_S_OPEN) { + switch (con->v2.in_state) { + case IN_S_HANDLE_PREAMBLE: + ret = handle_preamble(con); + break; + case IN_S_HANDLE_CONTROL: + ret = handle_control(con); + break; + case IN_S_HANDLE_CONTROL_REMAINDER: + ret = handle_control_remainder(con); + break; + case IN_S_PREPARE_READ_DATA: + ret = prepare_read_data(con); + break; + case IN_S_PREPARE_READ_DATA_CONT: + prepare_read_data_cont(con); + ret = 0; + break; + case IN_S_PREPARE_READ_ENC_PAGE: + prepare_read_enc_page(con); + ret = 0; + break; + case IN_S_HANDLE_EPILOGUE: + ret = handle_epilogue(con); + break; + case IN_S_FINISH_SKIP: + finish_skip(con); + ret = 0; + break; + default: + WARN(1, "bad in_state %d", con->v2.in_state); + return -EINVAL; + } + } else { + WARN(1, "bad state %d", con->state); + return -EINVAL; + } + if (ret) { + dout("%s con %p error %d\n", __func__, con, ret); + return ret; + } + + if (WARN_ON(!iov_iter_count(&con->v2.in_iter))) + return -ENODATA; + dout("%s con %p populated %zu\n", __func__, con, + iov_iter_count(&con->v2.in_iter)); + return 1; +} + +int ceph_con_v2_try_read(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p state %d need %zu\n", __func__, con, con->state, + iov_iter_count(&con->v2.in_iter)); + + if (con->state == CEPH_CON_S_PREOPEN) + return 0; + + /* + * We should always have something pending here. If not, + * avoid calling populate_in_iter() as if we read something + * (ceph_tcp_recv() would immediately return 1). 
+ */ + if (WARN_ON(!iov_iter_count(&con->v2.in_iter))) + return -ENODATA; + + for (;;) { + ret = ceph_tcp_recv(con); + if (ret <= 0) + return ret; + + ret = populate_in_iter(con); + if (ret <= 0) { + if (ret && ret != -EAGAIN && !con->error_msg) + con->error_msg = "read processing error"; + return ret; + } + } +} + +static void queue_data(struct ceph_connection *con) +{ + struct bio_vec bv; + + con->v2.out_epil.data_crc = -1; + ceph_msg_data_cursor_init(&con->v2.out_cursor, con->out_msg, + data_len(con->out_msg)); + + get_bvec_at(&con->v2.out_cursor, &bv); + set_out_bvec(con, &bv, true); + con->v2.out_state = OUT_S_QUEUE_DATA_CONT; +} + +static void queue_data_cont(struct ceph_connection *con) +{ + struct bio_vec bv; + + con->v2.out_epil.data_crc = ceph_crc32c_page( + con->v2.out_epil.data_crc, con->v2.out_bvec.bv_page, + con->v2.out_bvec.bv_offset, con->v2.out_bvec.bv_len); + + ceph_msg_data_advance(&con->v2.out_cursor, con->v2.out_bvec.bv_len); + if (con->v2.out_cursor.total_resid) { + get_bvec_at(&con->v2.out_cursor, &bv); + set_out_bvec(con, &bv, true); + WARN_ON(con->v2.out_state != OUT_S_QUEUE_DATA_CONT); + return; + } + + /* + * We've written all data. Queue epilogue. Once it's written, + * we are done. + */ + reset_out_kvecs(con); + prepare_epilogue_plain(con, false); + con->v2.out_state = OUT_S_FINISH_MESSAGE; +} + +static void queue_enc_page(struct ceph_connection *con) +{ + struct bio_vec bv; + + dout("%s con %p i %d resid %d\n", __func__, con, con->v2.out_enc_i, + con->v2.out_enc_resid); + WARN_ON(!con->v2.out_enc_resid); + + bv.bv_page = con->v2.out_enc_pages[con->v2.out_enc_i]; + bv.bv_offset = 0; + bv.bv_len = min(con->v2.out_enc_resid, (int)PAGE_SIZE); + + set_out_bvec(con, &bv, false); + con->v2.out_enc_i++; + con->v2.out_enc_resid -= bv.bv_len; + + if (con->v2.out_enc_resid) { + WARN_ON(con->v2.out_state != OUT_S_QUEUE_ENC_PAGE); + return; + } + + /* + * We've queued the last piece of ciphertext (ending with + * epilogue) + auth tag. Once it's written, we are done. + */ + WARN_ON(con->v2.out_enc_i != con->v2.out_enc_page_cnt); + con->v2.out_state = OUT_S_FINISH_MESSAGE; +} + +static void queue_zeros(struct ceph_connection *con) +{ + dout("%s con %p out_zero %d\n", __func__, con, con->v2.out_zero); + + if (con->v2.out_zero) { + set_out_bvec_zero(con); + con->v2.out_zero -= con->v2.out_bvec.bv_len; + con->v2.out_state = OUT_S_QUEUE_ZEROS; + return; + } + + /* + * We've zero-filled everything up to epilogue. Queue epilogue + * with late_status set to ABORTED and crcs adjusted for zeros. + * Once it's written, we are done patching up for the revoke. 
+ */ + reset_out_kvecs(con); + prepare_epilogue_plain(con, true); + con->v2.out_state = OUT_S_FINISH_MESSAGE; +} + +static void finish_message(struct ceph_connection *con) +{ + dout("%s con %p msg %p\n", __func__, con, con->out_msg); + + /* we end up here both plain and secure modes */ + if (con->v2.out_enc_pages) { + WARN_ON(!con->v2.out_enc_page_cnt); + ceph_release_page_vector(con->v2.out_enc_pages, + con->v2.out_enc_page_cnt); + con->v2.out_enc_pages = NULL; + con->v2.out_enc_page_cnt = 0; + } + /* message may have been revoked */ + if (con->out_msg) { + ceph_msg_put(con->out_msg); + con->out_msg = NULL; + } + + con->v2.out_state = OUT_S_GET_NEXT; +} + +static int populate_out_iter(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p state %d out_state %d\n", __func__, con, con->state, + con->v2.out_state); + WARN_ON(iov_iter_count(&con->v2.out_iter)); + + if (con->state != CEPH_CON_S_OPEN) { + WARN_ON(con->state < CEPH_CON_S_V2_BANNER_PREFIX || + con->state > CEPH_CON_S_V2_SESSION_RECONNECT); + goto nothing_pending; + } + + switch (con->v2.out_state) { + case OUT_S_QUEUE_DATA: + WARN_ON(!con->out_msg); + queue_data(con); + goto populated; + case OUT_S_QUEUE_DATA_CONT: + WARN_ON(!con->out_msg); + queue_data_cont(con); + goto populated; + case OUT_S_QUEUE_ENC_PAGE: + queue_enc_page(con); + goto populated; + case OUT_S_QUEUE_ZEROS: + WARN_ON(con->out_msg); /* revoked */ + queue_zeros(con); + goto populated; + case OUT_S_FINISH_MESSAGE: + finish_message(con); + break; + case OUT_S_GET_NEXT: + break; + default: + WARN(1, "bad out_state %d", con->v2.out_state); + return -EINVAL; + } + + WARN_ON(con->v2.out_state != OUT_S_GET_NEXT); + if (ceph_con_flag_test_and_clear(con, CEPH_CON_F_KEEPALIVE_PENDING)) { + ret = prepare_keepalive2(con); + if (ret) { + pr_err("prepare_keepalive2 failed: %d\n", ret); + return ret; + } + } else if (!list_empty(&con->out_queue)) { + ceph_con_get_out_msg(con); + ret = prepare_message(con); + if (ret) { + pr_err("prepare_message failed: %d\n", ret); + return ret; + } + } else if (con->in_seq > con->in_seq_acked) { + ret = prepare_ack(con); + if (ret) { + pr_err("prepare_ack failed: %d\n", ret); + return ret; + } + } else { + goto nothing_pending; + } + +populated: + if (WARN_ON(!iov_iter_count(&con->v2.out_iter))) + return -ENODATA; + dout("%s con %p populated %zu\n", __func__, con, + iov_iter_count(&con->v2.out_iter)); + return 1; + +nothing_pending: + WARN_ON(iov_iter_count(&con->v2.out_iter)); + dout("%s con %p nothing pending\n", __func__, con); + ceph_con_flag_clear(con, CEPH_CON_F_WRITE_PENDING); + return 0; +} + +int ceph_con_v2_try_write(struct ceph_connection *con) +{ + int ret; + + dout("%s con %p state %d have %zu\n", __func__, con, con->state, + iov_iter_count(&con->v2.out_iter)); + + /* open the socket first? */ + if (con->state == CEPH_CON_S_PREOPEN) { + WARN_ON(con->peer_addr.type != CEPH_ENTITY_ADDR_TYPE_MSGR2); + + /* + * Always bump global_seq. Bump connect_seq only if + * there is a session (i.e. we are reconnecting and will + * send session_reconnect instead of client_ident). 
+ */ + con->v2.global_seq = ceph_get_global_seq(con->msgr, 0); + if (con->v2.server_cookie) + con->v2.connect_seq++; + + ret = prepare_read_banner_prefix(con); + if (ret) { + pr_err("prepare_read_banner_prefix failed: %d\n", ret); + con->error_msg = "connect error"; + return ret; + } + + reset_out_kvecs(con); + ret = prepare_banner(con); + if (ret) { + pr_err("prepare_banner failed: %d\n", ret); + con->error_msg = "connect error"; + return ret; + } + + ret = ceph_tcp_connect(con); + if (ret) { + pr_err("ceph_tcp_connect failed: %d\n", ret); + con->error_msg = "connect error"; + return ret; + } + } + + if (!iov_iter_count(&con->v2.out_iter)) { + ret = populate_out_iter(con); + if (ret <= 0) { + if (ret && ret != -EAGAIN && !con->error_msg) + con->error_msg = "write processing error"; + return ret; + } + } + + tcp_sock_set_cork(con->sock->sk, true); + for (;;) { + ret = ceph_tcp_send(con); + if (ret <= 0) + break; + + ret = populate_out_iter(con); + if (ret <= 0) { + if (ret && ret != -EAGAIN && !con->error_msg) + con->error_msg = "write processing error"; + break; + } + } + + tcp_sock_set_cork(con->sock->sk, false); + return ret; +} + +static u32 crc32c_zeros(u32 crc, int zero_len) +{ + int len; + + while (zero_len) { + len = min(zero_len, (int)PAGE_SIZE); + crc = crc32c(crc, page_address(ceph_zero_page), len); + zero_len -= len; + } + + return crc; +} + +static void prepare_zero_front(struct ceph_connection *con, int resid) +{ + int sent; + + WARN_ON(!resid || resid > front_len(con->out_msg)); + sent = front_len(con->out_msg) - resid; + dout("%s con %p sent %d resid %d\n", __func__, con, sent, resid); + + if (sent) { + con->v2.out_epil.front_crc = + crc32c(-1, con->out_msg->front.iov_base, sent); + con->v2.out_epil.front_crc = + crc32c_zeros(con->v2.out_epil.front_crc, resid); + } else { + con->v2.out_epil.front_crc = crc32c_zeros(-1, resid); + } + + con->v2.out_iter.count -= resid; + out_zero_add(con, resid); +} + +static void prepare_zero_middle(struct ceph_connection *con, int resid) +{ + int sent; + + WARN_ON(!resid || resid > middle_len(con->out_msg)); + sent = middle_len(con->out_msg) - resid; + dout("%s con %p sent %d resid %d\n", __func__, con, sent, resid); + + if (sent) { + con->v2.out_epil.middle_crc = + crc32c(-1, con->out_msg->middle->vec.iov_base, sent); + con->v2.out_epil.middle_crc = + crc32c_zeros(con->v2.out_epil.middle_crc, resid); + } else { + con->v2.out_epil.middle_crc = crc32c_zeros(-1, resid); + } + + con->v2.out_iter.count -= resid; + out_zero_add(con, resid); +} + +static void prepare_zero_data(struct ceph_connection *con) +{ + dout("%s con %p\n", __func__, con); + con->v2.out_epil.data_crc = crc32c_zeros(-1, data_len(con->out_msg)); + out_zero_add(con, data_len(con->out_msg)); +} + +static void revoke_at_queue_data(struct ceph_connection *con) +{ + int boundary; + int resid; + + WARN_ON(!data_len(con->out_msg)); + WARN_ON(!iov_iter_is_kvec(&con->v2.out_iter)); + resid = iov_iter_count(&con->v2.out_iter); + + boundary = front_len(con->out_msg) + middle_len(con->out_msg); + if (resid > boundary) { + resid -= boundary; + WARN_ON(resid > MESSAGE_HEAD_PLAIN_LEN); + dout("%s con %p was sending head\n", __func__, con); + if (front_len(con->out_msg)) + prepare_zero_front(con, front_len(con->out_msg)); + if (middle_len(con->out_msg)) + prepare_zero_middle(con, middle_len(con->out_msg)); + prepare_zero_data(con); + WARN_ON(iov_iter_count(&con->v2.out_iter) != resid); + con->v2.out_state = OUT_S_QUEUE_ZEROS; + return; + } + + boundary = middle_len(con->out_msg); + if (resid 
> boundary) { + resid -= boundary; + dout("%s con %p was sending front\n", __func__, con); + prepare_zero_front(con, resid); + if (middle_len(con->out_msg)) + prepare_zero_middle(con, middle_len(con->out_msg)); + prepare_zero_data(con); + queue_zeros(con); + return; + } + + WARN_ON(!resid); + dout("%s con %p was sending middle\n", __func__, con); + prepare_zero_middle(con, resid); + prepare_zero_data(con); + queue_zeros(con); +} + +static void revoke_at_queue_data_cont(struct ceph_connection *con) +{ + int sent, resid; /* current piece of data */ + + WARN_ON(!data_len(con->out_msg)); + WARN_ON(!iov_iter_is_bvec(&con->v2.out_iter)); + resid = iov_iter_count(&con->v2.out_iter); + WARN_ON(!resid || resid > con->v2.out_bvec.bv_len); + sent = con->v2.out_bvec.bv_len - resid; + dout("%s con %p sent %d resid %d\n", __func__, con, sent, resid); + + if (sent) { + con->v2.out_epil.data_crc = ceph_crc32c_page( + con->v2.out_epil.data_crc, con->v2.out_bvec.bv_page, + con->v2.out_bvec.bv_offset, sent); + ceph_msg_data_advance(&con->v2.out_cursor, sent); + } + WARN_ON(resid > con->v2.out_cursor.total_resid); + con->v2.out_epil.data_crc = crc32c_zeros(con->v2.out_epil.data_crc, + con->v2.out_cursor.total_resid); + + con->v2.out_iter.count -= resid; + out_zero_add(con, con->v2.out_cursor.total_resid); + queue_zeros(con); +} + +static void revoke_at_finish_message(struct ceph_connection *con) +{ + int boundary; + int resid; + + WARN_ON(!iov_iter_is_kvec(&con->v2.out_iter)); + resid = iov_iter_count(&con->v2.out_iter); + + if (!front_len(con->out_msg) && !middle_len(con->out_msg) && + !data_len(con->out_msg)) { + WARN_ON(!resid || resid > MESSAGE_HEAD_PLAIN_LEN); + dout("%s con %p was sending head (empty message) - noop\n", + __func__, con); + return; + } + + boundary = front_len(con->out_msg) + middle_len(con->out_msg) + + CEPH_EPILOGUE_PLAIN_LEN; + if (resid > boundary) { + resid -= boundary; + WARN_ON(resid > MESSAGE_HEAD_PLAIN_LEN); + dout("%s con %p was sending head\n", __func__, con); + if (front_len(con->out_msg)) + prepare_zero_front(con, front_len(con->out_msg)); + if (middle_len(con->out_msg)) + prepare_zero_middle(con, middle_len(con->out_msg)); + con->v2.out_iter.count -= CEPH_EPILOGUE_PLAIN_LEN; + WARN_ON(iov_iter_count(&con->v2.out_iter) != resid); + con->v2.out_state = OUT_S_QUEUE_ZEROS; + return; + } + + boundary = middle_len(con->out_msg) + CEPH_EPILOGUE_PLAIN_LEN; + if (resid > boundary) { + resid -= boundary; + dout("%s con %p was sending front\n", __func__, con); + prepare_zero_front(con, resid); + if (middle_len(con->out_msg)) + prepare_zero_middle(con, middle_len(con->out_msg)); + con->v2.out_iter.count -= CEPH_EPILOGUE_PLAIN_LEN; + queue_zeros(con); + return; + } + + boundary = CEPH_EPILOGUE_PLAIN_LEN; + if (resid > boundary) { + resid -= boundary; + dout("%s con %p was sending middle\n", __func__, con); + prepare_zero_middle(con, resid); + con->v2.out_iter.count -= CEPH_EPILOGUE_PLAIN_LEN; + queue_zeros(con); + return; + } + + WARN_ON(!resid); + dout("%s con %p was sending epilogue - noop\n", __func__, con); +} + +void ceph_con_v2_revoke(struct ceph_connection *con) +{ + WARN_ON(con->v2.out_zero); + + if (con_secure(con)) { + WARN_ON(con->v2.out_state != OUT_S_QUEUE_ENC_PAGE && + con->v2.out_state != OUT_S_FINISH_MESSAGE); + dout("%s con %p secure - noop\n", __func__, con); + return; + } + + switch (con->v2.out_state) { + case OUT_S_QUEUE_DATA: + revoke_at_queue_data(con); + break; + case OUT_S_QUEUE_DATA_CONT: + revoke_at_queue_data_cont(con); + break; + case OUT_S_FINISH_MESSAGE: + 
revoke_at_finish_message(con); + break; + default: + WARN(1, "bad out_state %d", con->v2.out_state); + break; + } +} + +static void revoke_at_prepare_read_data(struct ceph_connection *con) +{ + int remaining; + int resid; + + WARN_ON(con_secure(con)); + WARN_ON(!data_len(con->in_msg)); + WARN_ON(!iov_iter_is_kvec(&con->v2.in_iter)); + resid = iov_iter_count(&con->v2.in_iter); + WARN_ON(!resid); + + remaining = data_len(con->in_msg) + CEPH_EPILOGUE_PLAIN_LEN; + dout("%s con %p resid %d remaining %d\n", __func__, con, resid, + remaining); + con->v2.in_iter.count -= resid; + set_in_skip(con, resid + remaining); + con->v2.in_state = IN_S_FINISH_SKIP; +} + +static void revoke_at_prepare_read_data_cont(struct ceph_connection *con) +{ + int recved, resid; /* current piece of data */ + int remaining; + + WARN_ON(con_secure(con)); + WARN_ON(!data_len(con->in_msg)); + WARN_ON(!iov_iter_is_bvec(&con->v2.in_iter)); + resid = iov_iter_count(&con->v2.in_iter); + WARN_ON(!resid || resid > con->v2.in_bvec.bv_len); + recved = con->v2.in_bvec.bv_len - resid; + dout("%s con %p recved %d resid %d\n", __func__, con, recved, resid); + + if (recved) + ceph_msg_data_advance(&con->v2.in_cursor, recved); + WARN_ON(resid > con->v2.in_cursor.total_resid); + + remaining = CEPH_EPILOGUE_PLAIN_LEN; + dout("%s con %p total_resid %zu remaining %d\n", __func__, con, + con->v2.in_cursor.total_resid, remaining); + con->v2.in_iter.count -= resid; + set_in_skip(con, con->v2.in_cursor.total_resid + remaining); + con->v2.in_state = IN_S_FINISH_SKIP; +} + +static void revoke_at_prepare_read_enc_page(struct ceph_connection *con) +{ + int resid; /* current enc page (not necessarily data) */ + + WARN_ON(!con_secure(con)); + WARN_ON(!iov_iter_is_bvec(&con->v2.in_iter)); + resid = iov_iter_count(&con->v2.in_iter); + WARN_ON(!resid || resid > con->v2.in_bvec.bv_len); + + dout("%s con %p resid %d enc_resid %d\n", __func__, con, resid, + con->v2.in_enc_resid); + con->v2.in_iter.count -= resid; + set_in_skip(con, resid + con->v2.in_enc_resid); + con->v2.in_state = IN_S_FINISH_SKIP; +} + +static void revoke_at_handle_epilogue(struct ceph_connection *con) +{ + int resid; + + resid = iov_iter_count(&con->v2.in_iter); + WARN_ON(!resid); + + dout("%s con %p resid %d\n", __func__, con, resid); + con->v2.in_iter.count -= resid; + set_in_skip(con, resid); + con->v2.in_state = IN_S_FINISH_SKIP; +} + +void ceph_con_v2_revoke_incoming(struct ceph_connection *con) +{ + switch (con->v2.in_state) { + case IN_S_PREPARE_READ_DATA: + revoke_at_prepare_read_data(con); + break; + case IN_S_PREPARE_READ_DATA_CONT: + revoke_at_prepare_read_data_cont(con); + break; + case IN_S_PREPARE_READ_ENC_PAGE: + revoke_at_prepare_read_enc_page(con); + break; + case IN_S_HANDLE_EPILOGUE: + revoke_at_handle_epilogue(con); + break; + default: + WARN(1, "bad in_state %d", con->v2.in_state); + break; + } +} + +bool ceph_con_v2_opened(struct ceph_connection *con) +{ + return con->v2.peer_global_seq; +} + +void ceph_con_v2_reset_session(struct ceph_connection *con) +{ + con->v2.client_cookie = 0; + con->v2.server_cookie = 0; + con->v2.global_seq = 0; + con->v2.connect_seq = 0; + con->v2.peer_global_seq = 0; +} + +void ceph_con_v2_reset_protocol(struct ceph_connection *con) +{ + iov_iter_truncate(&con->v2.in_iter, 0); + iov_iter_truncate(&con->v2.out_iter, 0); + con->v2.out_zero = 0; + + clear_in_sign_kvecs(con); + clear_out_sign_kvecs(con); + free_conn_bufs(con); + + if (con->v2.in_enc_pages) { + WARN_ON(!con->v2.in_enc_page_cnt); + 
ceph_release_page_vector(con->v2.in_enc_pages, + con->v2.in_enc_page_cnt); + con->v2.in_enc_pages = NULL; + con->v2.in_enc_page_cnt = 0; + } + if (con->v2.out_enc_pages) { + WARN_ON(!con->v2.out_enc_page_cnt); + ceph_release_page_vector(con->v2.out_enc_pages, + con->v2.out_enc_page_cnt); + con->v2.out_enc_pages = NULL; + con->v2.out_enc_page_cnt = 0; + } + + con->v2.con_mode = CEPH_CON_MODE_UNKNOWN; + memzero_explicit(&con->v2.in_gcm_nonce, CEPH_GCM_IV_LEN); + memzero_explicit(&con->v2.out_gcm_nonce, CEPH_GCM_IV_LEN); + + if (con->v2.hmac_tfm) { + crypto_free_shash(con->v2.hmac_tfm); + con->v2.hmac_tfm = NULL; + } + if (con->v2.gcm_req) { + aead_request_free(con->v2.gcm_req); + con->v2.gcm_req = NULL; + } + if (con->v2.gcm_tfm) { + crypto_free_aead(con->v2.gcm_tfm); + con->v2.gcm_tfm = NULL; + } +} diff --git a/net/ceph/mon_client.c b/net/ceph/mon_client.c new file mode 100644 index 000000000..db60217f9 --- /dev/null +++ b/net/ceph/mon_client.c @@ -0,0 +1,1586 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +/* + * Interact with Ceph monitor cluster. Handle requests for new map + * versions, and periodically resend as needed. Also implement + * statfs() and umount(). + * + * A small cluster of Ceph "monitors" are responsible for managing critical + * cluster configuration and state information. An odd number (e.g., 3, 5) + * of cmon daemons use a modified version of the Paxos part-time parliament + * algorithm to manage the MDS map (mds cluster membership), OSD map, and + * list of clients who have mounted the file system. + * + * We maintain an open, active session with a monitor at all times in order to + * receive timely MDSMap updates. We periodically send a keepalive byte on the + * TCP socket to ensure we detect a failure. If the connection does break, we + * randomly hunt for a new monitor. Once the connection is reestablished, we + * resend any outstanding requests. + */ + +static const struct ceph_connection_operations mon_con_ops; + +static int __validate_auth(struct ceph_mon_client *monc); + +static int decode_mon_info(void **p, void *end, bool msgr2, + struct ceph_entity_addr *addr) +{ + void *mon_info_end; + u32 struct_len; + u8 struct_v; + int ret; + + ret = ceph_start_decoding(p, end, 1, "mon_info_t", &struct_v, + &struct_len); + if (ret) + return ret; + + mon_info_end = *p + struct_len; + ceph_decode_skip_string(p, end, e_inval); /* skip mon name */ + ret = ceph_decode_entity_addrvec(p, end, msgr2, addr); + if (ret) + return ret; + + *p = mon_info_end; + return 0; + +e_inval: + return -EINVAL; +} + +/* + * Decode a monmap blob (e.g., during mount). + * + * Assume MonMap v3 (i.e. encoding with MONNAMES and MONENC). 
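+ * Encodings with struct_v >= 6 also carry monmap-level feature sets
+ * (skipped below) and describe each monitor with a mon_info_t that
+ * holds an addrvec; older encodings are a plain name -> entity_addr_t
+ * map.  Both forms are handled here.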
+ */ +static struct ceph_monmap *ceph_monmap_decode(void **p, void *end, bool msgr2) +{ + struct ceph_monmap *monmap = NULL; + struct ceph_fsid fsid; + u32 struct_len; + int blob_len; + int num_mon; + u8 struct_v; + u32 epoch; + int ret; + int i; + + ceph_decode_32_safe(p, end, blob_len, e_inval); + ceph_decode_need(p, end, blob_len, e_inval); + + ret = ceph_start_decoding(p, end, 6, "monmap", &struct_v, &struct_len); + if (ret) + goto fail; + + dout("%s struct_v %d\n", __func__, struct_v); + ceph_decode_copy_safe(p, end, &fsid, sizeof(fsid), e_inval); + ceph_decode_32_safe(p, end, epoch, e_inval); + if (struct_v >= 6) { + u32 feat_struct_len; + u8 feat_struct_v; + + *p += sizeof(struct ceph_timespec); /* skip last_changed */ + *p += sizeof(struct ceph_timespec); /* skip created */ + + ret = ceph_start_decoding(p, end, 1, "mon_feature_t", + &feat_struct_v, &feat_struct_len); + if (ret) + goto fail; + + *p += feat_struct_len; /* skip persistent_features */ + + ret = ceph_start_decoding(p, end, 1, "mon_feature_t", + &feat_struct_v, &feat_struct_len); + if (ret) + goto fail; + + *p += feat_struct_len; /* skip optional_features */ + } + ceph_decode_32_safe(p, end, num_mon, e_inval); + + dout("%s fsid %pU epoch %u num_mon %d\n", __func__, &fsid, epoch, + num_mon); + if (num_mon > CEPH_MAX_MON) + goto e_inval; + + monmap = kmalloc(struct_size(monmap, mon_inst, num_mon), GFP_NOIO); + if (!monmap) { + ret = -ENOMEM; + goto fail; + } + monmap->fsid = fsid; + monmap->epoch = epoch; + monmap->num_mon = num_mon; + + /* legacy_mon_addr map or mon_info map */ + for (i = 0; i < num_mon; i++) { + struct ceph_entity_inst *inst = &monmap->mon_inst[i]; + + ceph_decode_skip_string(p, end, e_inval); /* skip mon name */ + inst->name.type = CEPH_ENTITY_TYPE_MON; + inst->name.num = cpu_to_le64(i); + + if (struct_v >= 6) + ret = decode_mon_info(p, end, msgr2, &inst->addr); + else + ret = ceph_decode_entity_addr(p, end, &inst->addr); + if (ret) + goto fail; + + dout("%s mon%d addr %s\n", __func__, i, + ceph_pr_addr(&inst->addr)); + } + + return monmap; + +e_inval: + ret = -EINVAL; +fail: + kfree(monmap); + return ERR_PTR(ret); +} + +/* + * return true if *addr is included in the monmap. + */ +int ceph_monmap_contains(struct ceph_monmap *m, struct ceph_entity_addr *addr) +{ + int i; + + for (i = 0; i < m->num_mon; i++) { + if (ceph_addr_equal_no_type(addr, &m->mon_inst[i].addr)) + return 1; + } + + return 0; +} + +/* + * Send an auth request. + */ +static void __send_prepared_auth_request(struct ceph_mon_client *monc, int len) +{ + monc->pending_auth = 1; + monc->m_auth->front.iov_len = len; + monc->m_auth->hdr.front_len = cpu_to_le32(len); + ceph_msg_revoke(monc->m_auth); + ceph_msg_get(monc->m_auth); /* keep our ref */ + ceph_con_send(&monc->con, monc->m_auth); +} + +/* + * Close monitor session, if any. + */ +static void __close_session(struct ceph_mon_client *monc) +{ + dout("__close_session closing mon%d\n", monc->cur_mon); + ceph_msg_revoke(monc->m_auth); + ceph_msg_revoke_incoming(monc->m_auth_reply); + ceph_msg_revoke(monc->m_subscribe); + ceph_msg_revoke_incoming(monc->m_subscribe_ack); + ceph_con_close(&monc->con); + + monc->pending_auth = 0; + ceph_auth_reset(monc->auth); +} + +/* + * Pick a new monitor at random and set cur_mon. If we are repicking + * (i.e. cur_mon is already set), be sure to pick a different one. 
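+ * The replacement is chosen uniformly from the other num_mon - 1
+ * monitors: prandom_u32_max(num_mon - 1) picks a slot, which is then
+ * shifted past the old cur_mon.  E.g. with num_mon == 3 and
+ * cur_mon == 1 the result is mon0 or mon2, never mon1 again.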
+ */ +static void pick_new_mon(struct ceph_mon_client *monc) +{ + int old_mon = monc->cur_mon; + + BUG_ON(monc->monmap->num_mon < 1); + + if (monc->monmap->num_mon == 1) { + monc->cur_mon = 0; + } else { + int max = monc->monmap->num_mon; + int o = -1; + int n; + + if (monc->cur_mon >= 0) { + if (monc->cur_mon < monc->monmap->num_mon) + o = monc->cur_mon; + if (o >= 0) + max--; + } + + n = prandom_u32_max(max); + if (o >= 0 && n >= o) + n++; + + monc->cur_mon = n; + } + + dout("%s mon%d -> mon%d out of %d mons\n", __func__, old_mon, + monc->cur_mon, monc->monmap->num_mon); +} + +/* + * Open a session with a new monitor. + */ +static void __open_session(struct ceph_mon_client *monc) +{ + int ret; + + pick_new_mon(monc); + + monc->hunting = true; + if (monc->had_a_connection) { + monc->hunt_mult *= CEPH_MONC_HUNT_BACKOFF; + if (monc->hunt_mult > CEPH_MONC_HUNT_MAX_MULT) + monc->hunt_mult = CEPH_MONC_HUNT_MAX_MULT; + } + + monc->sub_renew_after = jiffies; /* i.e., expired */ + monc->sub_renew_sent = 0; + + dout("%s opening mon%d\n", __func__, monc->cur_mon); + ceph_con_open(&monc->con, CEPH_ENTITY_TYPE_MON, monc->cur_mon, + &monc->monmap->mon_inst[monc->cur_mon].addr); + + /* + * Queue a keepalive to ensure that in case of an early fault + * the messenger doesn't put us into STANDBY state and instead + * retries. This also ensures that our timestamp is valid by + * the time we finish hunting and delayed_work() checks it. + */ + ceph_con_keepalive(&monc->con); + if (ceph_msgr2(monc->client)) { + monc->pending_auth = 1; + return; + } + + /* initiate authentication handshake */ + ret = ceph_auth_build_hello(monc->auth, + monc->m_auth->front.iov_base, + monc->m_auth->front_alloc_len); + BUG_ON(ret <= 0); + __send_prepared_auth_request(monc, ret); +} + +static void reopen_session(struct ceph_mon_client *monc) +{ + if (!monc->hunting) + pr_info("mon%d %s session lost, hunting for new mon\n", + monc->cur_mon, ceph_pr_addr(&monc->con.peer_addr)); + + __close_session(monc); + __open_session(monc); +} + +void ceph_monc_reopen_session(struct ceph_mon_client *monc) +{ + mutex_lock(&monc->mutex); + reopen_session(monc); + mutex_unlock(&monc->mutex); +} + +static void un_backoff(struct ceph_mon_client *monc) +{ + monc->hunt_mult /= 2; /* reduce by 50% */ + if (monc->hunt_mult < 1) + monc->hunt_mult = 1; + dout("%s hunt_mult now %d\n", __func__, monc->hunt_mult); +} + +/* + * Reschedule delayed work timer. + */ +static void __schedule_delayed(struct ceph_mon_client *monc) +{ + unsigned long delay; + + if (monc->hunting) + delay = CEPH_MONC_HUNT_INTERVAL * monc->hunt_mult; + else + delay = CEPH_MONC_PING_INTERVAL; + + dout("__schedule_delayed after %lu\n", delay); + mod_delayed_work(system_wq, &monc->delayed_work, + round_jiffies_relative(delay)); +} + +const char *ceph_sub_str[] = { + [CEPH_SUB_MONMAP] = "monmap", + [CEPH_SUB_OSDMAP] = "osdmap", + [CEPH_SUB_FSMAP] = "fsmap.user", + [CEPH_SUB_MDSMAP] = "mdsmap", +}; + +/* + * Send subscribe request for one or more maps, according to + * monc->subs. 
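+ * The payload is a u32 count followed by (map name, ceph_mon_subscribe_item)
+ * pairs; for the mdsmap subscription the fs_cluster_id, if set, is
+ * appended to the name (e.g. "mdsmap.<id>").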
+ */ +static void __send_subscribe(struct ceph_mon_client *monc) +{ + struct ceph_msg *msg = monc->m_subscribe; + void *p = msg->front.iov_base; + void *const end = p + msg->front_alloc_len; + int num = 0; + int i; + + dout("%s sent %lu\n", __func__, monc->sub_renew_sent); + + BUG_ON(monc->cur_mon < 0); + + if (!monc->sub_renew_sent) + monc->sub_renew_sent = jiffies | 1; /* never 0 */ + + msg->hdr.version = cpu_to_le16(2); + + for (i = 0; i < ARRAY_SIZE(monc->subs); i++) { + if (monc->subs[i].want) + num++; + } + BUG_ON(num < 1); /* monmap sub is always there */ + ceph_encode_32(&p, num); + for (i = 0; i < ARRAY_SIZE(monc->subs); i++) { + char buf[32]; + int len; + + if (!monc->subs[i].want) + continue; + + len = sprintf(buf, "%s", ceph_sub_str[i]); + if (i == CEPH_SUB_MDSMAP && + monc->fs_cluster_id != CEPH_FS_CLUSTER_ID_NONE) + len += sprintf(buf + len, ".%d", monc->fs_cluster_id); + + dout("%s %s start %llu flags 0x%x\n", __func__, buf, + le64_to_cpu(monc->subs[i].item.start), + monc->subs[i].item.flags); + ceph_encode_string(&p, end, buf, len); + memcpy(p, &monc->subs[i].item, sizeof(monc->subs[i].item)); + p += sizeof(monc->subs[i].item); + } + + BUG_ON(p > end); + msg->front.iov_len = p - msg->front.iov_base; + msg->hdr.front_len = cpu_to_le32(msg->front.iov_len); + ceph_msg_revoke(msg); + ceph_con_send(&monc->con, ceph_msg_get(msg)); +} + +static void handle_subscribe_ack(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + unsigned int seconds; + struct ceph_mon_subscribe_ack *h = msg->front.iov_base; + + if (msg->front.iov_len < sizeof(*h)) + goto bad; + seconds = le32_to_cpu(h->duration); + + mutex_lock(&monc->mutex); + if (monc->sub_renew_sent) { + /* + * This is only needed for legacy (infernalis or older) + * MONs -- see delayed_work(). + */ + monc->sub_renew_after = monc->sub_renew_sent + + (seconds >> 1) * HZ - 1; + dout("%s sent %lu duration %d renew after %lu\n", __func__, + monc->sub_renew_sent, seconds, monc->sub_renew_after); + monc->sub_renew_sent = 0; + } else { + dout("%s sent %lu renew after %lu, ignoring\n", __func__, + monc->sub_renew_sent, monc->sub_renew_after); + } + mutex_unlock(&monc->mutex); + return; +bad: + pr_err("got corrupt subscribe-ack msg\n"); + ceph_msg_dump(msg); +} + +/* + * Register interest in a map + * + * @sub: one of CEPH_SUB_* + * @epoch: X for "every map since X", or 0 for "just the latest" + */ +static bool __ceph_monc_want_map(struct ceph_mon_client *monc, int sub, + u32 epoch, bool continuous) +{ + __le64 start = cpu_to_le64(epoch); + u8 flags = !continuous ? 
CEPH_SUBSCRIBE_ONETIME : 0; + + dout("%s %s epoch %u continuous %d\n", __func__, ceph_sub_str[sub], + epoch, continuous); + + if (monc->subs[sub].want && + monc->subs[sub].item.start == start && + monc->subs[sub].item.flags == flags) + return false; + + monc->subs[sub].item.start = start; + monc->subs[sub].item.flags = flags; + monc->subs[sub].want = true; + + return true; +} + +bool ceph_monc_want_map(struct ceph_mon_client *monc, int sub, u32 epoch, + bool continuous) +{ + bool need_request; + + mutex_lock(&monc->mutex); + need_request = __ceph_monc_want_map(monc, sub, epoch, continuous); + mutex_unlock(&monc->mutex); + + return need_request; +} +EXPORT_SYMBOL(ceph_monc_want_map); + +/* + * Keep track of which maps we have + * + * @sub: one of CEPH_SUB_* + */ +static void __ceph_monc_got_map(struct ceph_mon_client *monc, int sub, + u32 epoch) +{ + dout("%s %s epoch %u\n", __func__, ceph_sub_str[sub], epoch); + + if (monc->subs[sub].want) { + if (monc->subs[sub].item.flags & CEPH_SUBSCRIBE_ONETIME) + monc->subs[sub].want = false; + else + monc->subs[sub].item.start = cpu_to_le64(epoch + 1); + } + + monc->subs[sub].have = epoch; +} + +void ceph_monc_got_map(struct ceph_mon_client *monc, int sub, u32 epoch) +{ + mutex_lock(&monc->mutex); + __ceph_monc_got_map(monc, sub, epoch); + mutex_unlock(&monc->mutex); +} +EXPORT_SYMBOL(ceph_monc_got_map); + +void ceph_monc_renew_subs(struct ceph_mon_client *monc) +{ + mutex_lock(&monc->mutex); + __send_subscribe(monc); + mutex_unlock(&monc->mutex); +} +EXPORT_SYMBOL(ceph_monc_renew_subs); + +/* + * Wait for an osdmap with a given epoch. + * + * @epoch: epoch to wait for + * @timeout: in jiffies, 0 means "wait forever" + */ +int ceph_monc_wait_osdmap(struct ceph_mon_client *monc, u32 epoch, + unsigned long timeout) +{ + unsigned long started = jiffies; + long ret; + + mutex_lock(&monc->mutex); + while (monc->subs[CEPH_SUB_OSDMAP].have < epoch) { + mutex_unlock(&monc->mutex); + + if (timeout && time_after_eq(jiffies, started + timeout)) + return -ETIMEDOUT; + + ret = wait_event_interruptible_timeout(monc->client->auth_wq, + monc->subs[CEPH_SUB_OSDMAP].have >= epoch, + ceph_timeout_jiffies(timeout)); + if (ret < 0) + return ret; + + mutex_lock(&monc->mutex); + } + + mutex_unlock(&monc->mutex); + return 0; +} +EXPORT_SYMBOL(ceph_monc_wait_osdmap); + +/* + * Open a session with a random monitor. Request monmap and osdmap, + * which are waited upon in __ceph_open_session(). 
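+ * The monmap subscription is continuous; the osdmap subscription is a
+ * one-shot (CEPH_SUBSCRIBE_ONETIME), i.e. just the latest map.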
+ */ +int ceph_monc_open_session(struct ceph_mon_client *monc) +{ + mutex_lock(&monc->mutex); + __ceph_monc_want_map(monc, CEPH_SUB_MONMAP, 0, true); + __ceph_monc_want_map(monc, CEPH_SUB_OSDMAP, 0, false); + __open_session(monc); + __schedule_delayed(monc); + mutex_unlock(&monc->mutex); + return 0; +} +EXPORT_SYMBOL(ceph_monc_open_session); + +static void ceph_monc_handle_map(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + struct ceph_client *client = monc->client; + struct ceph_monmap *monmap; + void *p, *end; + + mutex_lock(&monc->mutex); + + dout("handle_monmap\n"); + p = msg->front.iov_base; + end = p + msg->front.iov_len; + + monmap = ceph_monmap_decode(&p, end, ceph_msgr2(client)); + if (IS_ERR(monmap)) { + pr_err("problem decoding monmap, %d\n", + (int)PTR_ERR(monmap)); + ceph_msg_dump(msg); + goto out; + } + + if (ceph_check_fsid(client, &monmap->fsid) < 0) { + kfree(monmap); + goto out; + } + + kfree(monc->monmap); + monc->monmap = monmap; + + __ceph_monc_got_map(monc, CEPH_SUB_MONMAP, monc->monmap->epoch); + client->have_fsid = true; + +out: + mutex_unlock(&monc->mutex); + wake_up_all(&client->auth_wq); +} + +/* + * generic requests (currently statfs, mon_get_version) + */ +DEFINE_RB_FUNCS(generic_request, struct ceph_mon_generic_request, tid, node) + +static void release_generic_request(struct kref *kref) +{ + struct ceph_mon_generic_request *req = + container_of(kref, struct ceph_mon_generic_request, kref); + + dout("%s greq %p request %p reply %p\n", __func__, req, req->request, + req->reply); + WARN_ON(!RB_EMPTY_NODE(&req->node)); + + if (req->reply) + ceph_msg_put(req->reply); + if (req->request) + ceph_msg_put(req->request); + + kfree(req); +} + +static void put_generic_request(struct ceph_mon_generic_request *req) +{ + if (req) + kref_put(&req->kref, release_generic_request); +} + +static void get_generic_request(struct ceph_mon_generic_request *req) +{ + kref_get(&req->kref); +} + +static struct ceph_mon_generic_request * +alloc_generic_request(struct ceph_mon_client *monc, gfp_t gfp) +{ + struct ceph_mon_generic_request *req; + + req = kzalloc(sizeof(*req), gfp); + if (!req) + return NULL; + + req->monc = monc; + kref_init(&req->kref); + RB_CLEAR_NODE(&req->node); + init_completion(&req->completion); + + dout("%s greq %p\n", __func__, req); + return req; +} + +static void register_generic_request(struct ceph_mon_generic_request *req) +{ + struct ceph_mon_client *monc = req->monc; + + WARN_ON(req->tid); + + get_generic_request(req); + req->tid = ++monc->last_tid; + insert_generic_request(&monc->generic_request_tree, req); +} + +static void send_generic_request(struct ceph_mon_client *monc, + struct ceph_mon_generic_request *req) +{ + WARN_ON(!req->tid); + + dout("%s greq %p tid %llu\n", __func__, req, req->tid); + req->request->hdr.tid = cpu_to_le64(req->tid); + ceph_con_send(&monc->con, ceph_msg_get(req->request)); +} + +static void __finish_generic_request(struct ceph_mon_generic_request *req) +{ + struct ceph_mon_client *monc = req->monc; + + dout("%s greq %p tid %llu\n", __func__, req, req->tid); + erase_generic_request(&monc->generic_request_tree, req); + + ceph_msg_revoke(req->request); + ceph_msg_revoke_incoming(req->reply); +} + +static void finish_generic_request(struct ceph_mon_generic_request *req) +{ + __finish_generic_request(req); + put_generic_request(req); +} + +static void complete_generic_request(struct ceph_mon_generic_request *req) +{ + if (req->complete_cb) + req->complete_cb(req); + else + complete_all(&req->completion); + 
put_generic_request(req); +} + +static void cancel_generic_request(struct ceph_mon_generic_request *req) +{ + struct ceph_mon_client *monc = req->monc; + struct ceph_mon_generic_request *lookup_req; + + dout("%s greq %p tid %llu\n", __func__, req, req->tid); + + mutex_lock(&monc->mutex); + lookup_req = lookup_generic_request(&monc->generic_request_tree, + req->tid); + if (lookup_req) { + WARN_ON(lookup_req != req); + finish_generic_request(req); + } + + mutex_unlock(&monc->mutex); +} + +static int wait_generic_request(struct ceph_mon_generic_request *req) +{ + int ret; + + dout("%s greq %p tid %llu\n", __func__, req, req->tid); + ret = wait_for_completion_interruptible(&req->completion); + if (ret) + cancel_generic_request(req); + else + ret = req->result; /* completed */ + + return ret; +} + +static struct ceph_msg *get_generic_reply(struct ceph_connection *con, + struct ceph_msg_header *hdr, + int *skip) +{ + struct ceph_mon_client *monc = con->private; + struct ceph_mon_generic_request *req; + u64 tid = le64_to_cpu(hdr->tid); + struct ceph_msg *m; + + mutex_lock(&monc->mutex); + req = lookup_generic_request(&monc->generic_request_tree, tid); + if (!req) { + dout("get_generic_reply %lld dne\n", tid); + *skip = 1; + m = NULL; + } else { + dout("get_generic_reply %lld got %p\n", tid, req->reply); + *skip = 0; + m = ceph_msg_get(req->reply); + /* + * we don't need to track the connection reading into + * this reply because we only have one open connection + * at a time, ever. + */ + } + mutex_unlock(&monc->mutex); + return m; +} + +/* + * statfs + */ +static void handle_statfs_reply(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + struct ceph_mon_generic_request *req; + struct ceph_mon_statfs_reply *reply = msg->front.iov_base; + u64 tid = le64_to_cpu(msg->hdr.tid); + + dout("%s msg %p tid %llu\n", __func__, msg, tid); + + if (msg->front.iov_len != sizeof(*reply)) + goto bad; + + mutex_lock(&monc->mutex); + req = lookup_generic_request(&monc->generic_request_tree, tid); + if (!req) { + mutex_unlock(&monc->mutex); + return; + } + + req->result = 0; + *req->u.st = reply->st; /* struct */ + __finish_generic_request(req); + mutex_unlock(&monc->mutex); + + complete_generic_request(req); + return; + +bad: + pr_err("corrupt statfs reply, tid %llu\n", tid); + ceph_msg_dump(msg); +} + +/* + * Do a synchronous statfs(). 
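+ * Blocks until the reply arrives (handle_statfs_reply() copies it into
+ * @buf) or the wait is interrupted, in which case the request is
+ * cancelled and the error is returned.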
+ */ +int ceph_monc_do_statfs(struct ceph_mon_client *monc, u64 data_pool, + struct ceph_statfs *buf) +{ + struct ceph_mon_generic_request *req; + struct ceph_mon_statfs *h; + int ret = -ENOMEM; + + req = alloc_generic_request(monc, GFP_NOFS); + if (!req) + goto out; + + req->request = ceph_msg_new(CEPH_MSG_STATFS, sizeof(*h), GFP_NOFS, + true); + if (!req->request) + goto out; + + req->reply = ceph_msg_new(CEPH_MSG_STATFS_REPLY, 64, GFP_NOFS, true); + if (!req->reply) + goto out; + + req->u.st = buf; + req->request->hdr.version = cpu_to_le16(2); + + mutex_lock(&monc->mutex); + register_generic_request(req); + /* fill out request */ + h = req->request->front.iov_base; + h->monhdr.have_version = 0; + h->monhdr.session_mon = cpu_to_le16(-1); + h->monhdr.session_mon_tid = 0; + h->fsid = monc->monmap->fsid; + h->contains_data_pool = (data_pool != CEPH_NOPOOL); + h->data_pool = cpu_to_le64(data_pool); + send_generic_request(monc, req); + mutex_unlock(&monc->mutex); + + ret = wait_generic_request(req); +out: + put_generic_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_monc_do_statfs); + +static void handle_get_version_reply(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + struct ceph_mon_generic_request *req; + u64 tid = le64_to_cpu(msg->hdr.tid); + void *p = msg->front.iov_base; + void *end = p + msg->front_alloc_len; + u64 handle; + + dout("%s msg %p tid %llu\n", __func__, msg, tid); + + ceph_decode_need(&p, end, 2*sizeof(u64), bad); + handle = ceph_decode_64(&p); + if (tid != 0 && tid != handle) + goto bad; + + mutex_lock(&monc->mutex); + req = lookup_generic_request(&monc->generic_request_tree, handle); + if (!req) { + mutex_unlock(&monc->mutex); + return; + } + + req->result = 0; + req->u.newest = ceph_decode_64(&p); + __finish_generic_request(req); + mutex_unlock(&monc->mutex); + + complete_generic_request(req); + return; + +bad: + pr_err("corrupt mon_get_version reply, tid %llu\n", tid); + ceph_msg_dump(msg); +} + +static struct ceph_mon_generic_request * +__ceph_monc_get_version(struct ceph_mon_client *monc, const char *what, + ceph_monc_callback_t cb, u64 private_data) +{ + struct ceph_mon_generic_request *req; + + req = alloc_generic_request(monc, GFP_NOIO); + if (!req) + goto err_put_req; + + req->request = ceph_msg_new(CEPH_MSG_MON_GET_VERSION, + sizeof(u64) + sizeof(u32) + strlen(what), + GFP_NOIO, true); + if (!req->request) + goto err_put_req; + + req->reply = ceph_msg_new(CEPH_MSG_MON_GET_VERSION_REPLY, 32, GFP_NOIO, + true); + if (!req->reply) + goto err_put_req; + + req->complete_cb = cb; + req->private_data = private_data; + + mutex_lock(&monc->mutex); + register_generic_request(req); + { + void *p = req->request->front.iov_base; + void *const end = p + req->request->front_alloc_len; + + ceph_encode_64(&p, req->tid); /* handle */ + ceph_encode_string(&p, end, what, strlen(what)); + WARN_ON(p != end); + } + send_generic_request(monc, req); + mutex_unlock(&monc->mutex); + + return req; + +err_put_req: + put_generic_request(req); + return ERR_PTR(-ENOMEM); +} + +/* + * Send MMonGetVersion and wait for the reply. 
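+ * On success the newest available epoch of the requested map is
+ * returned in *newest.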
+ * + * @what: one of "mdsmap", "osdmap" or "monmap" + */ +int ceph_monc_get_version(struct ceph_mon_client *monc, const char *what, + u64 *newest) +{ + struct ceph_mon_generic_request *req; + int ret; + + req = __ceph_monc_get_version(monc, what, NULL, 0); + if (IS_ERR(req)) + return PTR_ERR(req); + + ret = wait_generic_request(req); + if (!ret) + *newest = req->u.newest; + + put_generic_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_monc_get_version); + +/* + * Send MMonGetVersion, + * + * @what: one of "mdsmap", "osdmap" or "monmap" + */ +int ceph_monc_get_version_async(struct ceph_mon_client *monc, const char *what, + ceph_monc_callback_t cb, u64 private_data) +{ + struct ceph_mon_generic_request *req; + + req = __ceph_monc_get_version(monc, what, cb, private_data); + if (IS_ERR(req)) + return PTR_ERR(req); + + put_generic_request(req); + return 0; +} +EXPORT_SYMBOL(ceph_monc_get_version_async); + +static void handle_command_ack(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + struct ceph_mon_generic_request *req; + void *p = msg->front.iov_base; + void *const end = p + msg->front_alloc_len; + u64 tid = le64_to_cpu(msg->hdr.tid); + + dout("%s msg %p tid %llu\n", __func__, msg, tid); + + ceph_decode_need(&p, end, sizeof(struct ceph_mon_request_header) + + sizeof(u32), bad); + p += sizeof(struct ceph_mon_request_header); + + mutex_lock(&monc->mutex); + req = lookup_generic_request(&monc->generic_request_tree, tid); + if (!req) { + mutex_unlock(&monc->mutex); + return; + } + + req->result = ceph_decode_32(&p); + __finish_generic_request(req); + mutex_unlock(&monc->mutex); + + complete_generic_request(req); + return; + +bad: + pr_err("corrupt mon_command ack, tid %llu\n", tid); + ceph_msg_dump(msg); +} + +static __printf(2, 0) +int do_mon_command_vargs(struct ceph_mon_client *monc, const char *fmt, + va_list ap) +{ + struct ceph_mon_generic_request *req; + struct ceph_mon_command *h; + int ret = -ENOMEM; + int len; + + req = alloc_generic_request(monc, GFP_NOIO); + if (!req) + goto out; + + req->request = ceph_msg_new(CEPH_MSG_MON_COMMAND, 256, GFP_NOIO, true); + if (!req->request) + goto out; + + req->reply = ceph_msg_new(CEPH_MSG_MON_COMMAND_ACK, 512, GFP_NOIO, + true); + if (!req->reply) + goto out; + + mutex_lock(&monc->mutex); + register_generic_request(req); + h = req->request->front.iov_base; + h->monhdr.have_version = 0; + h->monhdr.session_mon = cpu_to_le16(-1); + h->monhdr.session_mon_tid = 0; + h->fsid = monc->monmap->fsid; + h->num_strs = cpu_to_le32(1); + len = vsprintf(h->str, fmt, ap); + h->str_len = cpu_to_le32(len); + send_generic_request(monc, req); + mutex_unlock(&monc->mutex); + + ret = wait_generic_request(req); +out: + put_generic_request(req); + return ret; +} + +static __printf(2, 3) +int do_mon_command(struct ceph_mon_client *monc, const char *fmt, ...) +{ + va_list ap; + int ret; + + va_start(ap, fmt); + ret = do_mon_command_vargs(monc, fmt, ap); + va_end(ap); + return ret; +} + +int ceph_monc_blocklist_add(struct ceph_mon_client *monc, + struct ceph_entity_addr *client_addr) +{ + int ret; + + ret = do_mon_command(monc, + "{ \"prefix\": \"osd blocklist\", \ + \"blocklistop\": \"add\", \ + \"addr\": \"%pISpc/%u\" }", + &client_addr->in_addr, + le32_to_cpu(client_addr->nonce)); + if (ret == -EINVAL) { + /* + * The monitor returns EINVAL on an unrecognized command. + * Try the legacy command -- it is exactly the same except + * for the name. 
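+ * ("osd blacklist" is the older spelling of the command; current
+ * monitors accept "osd blocklist", which is why that is tried first.)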
+ */ + ret = do_mon_command(monc, + "{ \"prefix\": \"osd blacklist\", \ + \"blacklistop\": \"add\", \ + \"addr\": \"%pISpc/%u\" }", + &client_addr->in_addr, + le32_to_cpu(client_addr->nonce)); + } + if (ret) + return ret; + + /* + * Make sure we have the osdmap that includes the blocklist + * entry. This is needed to ensure that the OSDs pick up the + * new blocklist before processing any future requests from + * this client. + */ + return ceph_wait_for_latest_osdmap(monc->client, 0); +} +EXPORT_SYMBOL(ceph_monc_blocklist_add); + +/* + * Resend pending generic requests. + */ +static void __resend_generic_request(struct ceph_mon_client *monc) +{ + struct ceph_mon_generic_request *req; + struct rb_node *p; + + for (p = rb_first(&monc->generic_request_tree); p; p = rb_next(p)) { + req = rb_entry(p, struct ceph_mon_generic_request, node); + ceph_msg_revoke(req->request); + ceph_msg_revoke_incoming(req->reply); + ceph_con_send(&monc->con, ceph_msg_get(req->request)); + } +} + +/* + * Delayed work. If we haven't mounted yet, retry. Otherwise, + * renew/retry subscription as needed (in case it is timing out, or we + * got an ENOMEM). And keep the monitor connection alive. + */ +static void delayed_work(struct work_struct *work) +{ + struct ceph_mon_client *monc = + container_of(work, struct ceph_mon_client, delayed_work.work); + + dout("monc delayed_work\n"); + mutex_lock(&monc->mutex); + if (monc->hunting) { + dout("%s continuing hunt\n", __func__); + reopen_session(monc); + } else { + int is_auth = ceph_auth_is_authenticated(monc->auth); + if (ceph_con_keepalive_expired(&monc->con, + CEPH_MONC_PING_TIMEOUT)) { + dout("monc keepalive timeout\n"); + is_auth = 0; + reopen_session(monc); + } + + if (!monc->hunting) { + ceph_con_keepalive(&monc->con); + __validate_auth(monc); + un_backoff(monc); + } + + if (is_auth && + !(monc->con.peer_features & CEPH_FEATURE_MON_STATEFUL_SUB)) { + unsigned long now = jiffies; + + dout("%s renew subs? now %lu renew after %lu\n", + __func__, now, monc->sub_renew_after); + if (time_after_eq(now, monc->sub_renew_after)) + __send_subscribe(monc); + } + } + __schedule_delayed(monc); + mutex_unlock(&monc->mutex); +} + +/* + * On startup, we build a temporary monmap populated with the IPs + * provided by mount(2). + */ +static int build_initial_monmap(struct ceph_mon_client *monc) +{ + __le32 my_type = ceph_msgr2(monc->client) ? 
+ CEPH_ENTITY_ADDR_TYPE_MSGR2 : CEPH_ENTITY_ADDR_TYPE_LEGACY; + struct ceph_options *opt = monc->client->options; + int num_mon = opt->num_mon; + int i; + + /* build initial monmap */ + monc->monmap = kzalloc(struct_size(monc->monmap, mon_inst, num_mon), + GFP_KERNEL); + if (!monc->monmap) + return -ENOMEM; + + for (i = 0; i < num_mon; i++) { + struct ceph_entity_inst *inst = &monc->monmap->mon_inst[i]; + + memcpy(&inst->addr.in_addr, &opt->mon_addr[i].in_addr, + sizeof(inst->addr.in_addr)); + inst->addr.type = my_type; + inst->addr.nonce = 0; + inst->name.type = CEPH_ENTITY_TYPE_MON; + inst->name.num = cpu_to_le64(i); + } + monc->monmap->num_mon = num_mon; + return 0; +} + +int ceph_monc_init(struct ceph_mon_client *monc, struct ceph_client *cl) +{ + int err; + + dout("init\n"); + memset(monc, 0, sizeof(*monc)); + monc->client = cl; + mutex_init(&monc->mutex); + + err = build_initial_monmap(monc); + if (err) + goto out; + + /* connection */ + /* authentication */ + monc->auth = ceph_auth_init(cl->options->name, cl->options->key, + cl->options->con_modes); + if (IS_ERR(monc->auth)) { + err = PTR_ERR(monc->auth); + goto out_monmap; + } + monc->auth->want_keys = + CEPH_ENTITY_TYPE_AUTH | CEPH_ENTITY_TYPE_MON | + CEPH_ENTITY_TYPE_OSD | CEPH_ENTITY_TYPE_MDS; + + /* msgs */ + err = -ENOMEM; + monc->m_subscribe_ack = ceph_msg_new(CEPH_MSG_MON_SUBSCRIBE_ACK, + sizeof(struct ceph_mon_subscribe_ack), + GFP_KERNEL, true); + if (!monc->m_subscribe_ack) + goto out_auth; + + monc->m_subscribe = ceph_msg_new(CEPH_MSG_MON_SUBSCRIBE, 128, + GFP_KERNEL, true); + if (!monc->m_subscribe) + goto out_subscribe_ack; + + monc->m_auth_reply = ceph_msg_new(CEPH_MSG_AUTH_REPLY, 4096, + GFP_KERNEL, true); + if (!monc->m_auth_reply) + goto out_subscribe; + + monc->m_auth = ceph_msg_new(CEPH_MSG_AUTH, 4096, GFP_KERNEL, true); + monc->pending_auth = 0; + if (!monc->m_auth) + goto out_auth_reply; + + ceph_con_init(&monc->con, monc, &mon_con_ops, + &monc->client->msgr); + + monc->cur_mon = -1; + monc->had_a_connection = false; + monc->hunt_mult = 1; + + INIT_DELAYED_WORK(&monc->delayed_work, delayed_work); + monc->generic_request_tree = RB_ROOT; + monc->last_tid = 0; + + monc->fs_cluster_id = CEPH_FS_CLUSTER_ID_NONE; + + return 0; + +out_auth_reply: + ceph_msg_put(monc->m_auth_reply); +out_subscribe: + ceph_msg_put(monc->m_subscribe); +out_subscribe_ack: + ceph_msg_put(monc->m_subscribe_ack); +out_auth: + ceph_auth_destroy(monc->auth); +out_monmap: + kfree(monc->monmap); +out: + return err; +} +EXPORT_SYMBOL(ceph_monc_init); + +void ceph_monc_stop(struct ceph_mon_client *monc) +{ + dout("stop\n"); + cancel_delayed_work_sync(&monc->delayed_work); + + mutex_lock(&monc->mutex); + __close_session(monc); + monc->cur_mon = -1; + mutex_unlock(&monc->mutex); + + /* + * flush msgr queue before we destroy ourselves to ensure that: + * - any work that references our embedded con is finished. + * - any osd_client or other work that may reference an authorizer + * finishes before we shut down the auth subsystem. 
+ */ + ceph_msgr_flush(); + + ceph_auth_destroy(monc->auth); + + WARN_ON(!RB_EMPTY_ROOT(&monc->generic_request_tree)); + + ceph_msg_put(monc->m_auth); + ceph_msg_put(monc->m_auth_reply); + ceph_msg_put(monc->m_subscribe); + ceph_msg_put(monc->m_subscribe_ack); + + kfree(monc->monmap); +} +EXPORT_SYMBOL(ceph_monc_stop); + +static void finish_hunting(struct ceph_mon_client *monc) +{ + if (monc->hunting) { + dout("%s found mon%d\n", __func__, monc->cur_mon); + monc->hunting = false; + monc->had_a_connection = true; + un_backoff(monc); + __schedule_delayed(monc); + } +} + +static void finish_auth(struct ceph_mon_client *monc, int auth_err, + bool was_authed) +{ + dout("%s auth_err %d was_authed %d\n", __func__, auth_err, was_authed); + WARN_ON(auth_err > 0); + + monc->pending_auth = 0; + if (auth_err) { + monc->client->auth_err = auth_err; + wake_up_all(&monc->client->auth_wq); + return; + } + + if (!was_authed && ceph_auth_is_authenticated(monc->auth)) { + dout("%s authenticated, starting session global_id %llu\n", + __func__, monc->auth->global_id); + + monc->client->msgr.inst.name.type = CEPH_ENTITY_TYPE_CLIENT; + monc->client->msgr.inst.name.num = + cpu_to_le64(monc->auth->global_id); + + __send_subscribe(monc); + __resend_generic_request(monc); + + pr_info("mon%d %s session established\n", monc->cur_mon, + ceph_pr_addr(&monc->con.peer_addr)); + } +} + +static void handle_auth_reply(struct ceph_mon_client *monc, + struct ceph_msg *msg) +{ + bool was_authed; + int ret; + + mutex_lock(&monc->mutex); + was_authed = ceph_auth_is_authenticated(monc->auth); + ret = ceph_handle_auth_reply(monc->auth, msg->front.iov_base, + msg->front.iov_len, + monc->m_auth->front.iov_base, + monc->m_auth->front_alloc_len); + if (ret > 0) { + __send_prepared_auth_request(monc, ret); + } else { + finish_auth(monc, ret, was_authed); + finish_hunting(monc); + } + mutex_unlock(&monc->mutex); +} + +static int __validate_auth(struct ceph_mon_client *monc) +{ + int ret; + + if (monc->pending_auth) + return 0; + + ret = ceph_build_auth(monc->auth, monc->m_auth->front.iov_base, + monc->m_auth->front_alloc_len); + if (ret <= 0) + return ret; /* either an error, or no need to authenticate */ + __send_prepared_auth_request(monc, ret); + return 0; +} + +int ceph_monc_validate_auth(struct ceph_mon_client *monc) +{ + int ret; + + mutex_lock(&monc->mutex); + ret = __validate_auth(monc); + mutex_unlock(&monc->mutex); + return ret; +} +EXPORT_SYMBOL(ceph_monc_validate_auth); + +static int mon_get_auth_request(struct ceph_connection *con, + void *buf, int *buf_len, + void **authorizer, int *authorizer_len) +{ + struct ceph_mon_client *monc = con->private; + int ret; + + mutex_lock(&monc->mutex); + ret = ceph_auth_get_request(monc->auth, buf, *buf_len); + mutex_unlock(&monc->mutex); + if (ret < 0) + return ret; + + *buf_len = ret; + *authorizer = NULL; + *authorizer_len = 0; + return 0; +} + +static int mon_handle_auth_reply_more(struct ceph_connection *con, + void *reply, int reply_len, + void *buf, int *buf_len, + void **authorizer, int *authorizer_len) +{ + struct ceph_mon_client *monc = con->private; + int ret; + + mutex_lock(&monc->mutex); + ret = ceph_auth_handle_reply_more(monc->auth, reply, reply_len, + buf, *buf_len); + mutex_unlock(&monc->mutex); + if (ret < 0) + return ret; + + *buf_len = ret; + *authorizer = NULL; + *authorizer_len = 0; + return 0; +} + +static int mon_handle_auth_done(struct ceph_connection *con, + u64 global_id, void *reply, int reply_len, + u8 *session_key, int *session_key_len, + u8 *con_secret, int 
*con_secret_len) +{ + struct ceph_mon_client *monc = con->private; + bool was_authed; + int ret; + + mutex_lock(&monc->mutex); + WARN_ON(!monc->hunting); + was_authed = ceph_auth_is_authenticated(monc->auth); + ret = ceph_auth_handle_reply_done(monc->auth, global_id, + reply, reply_len, + session_key, session_key_len, + con_secret, con_secret_len); + finish_auth(monc, ret, was_authed); + if (!ret) + finish_hunting(monc); + mutex_unlock(&monc->mutex); + return 0; +} + +static int mon_handle_auth_bad_method(struct ceph_connection *con, + int used_proto, int result, + const int *allowed_protos, int proto_cnt, + const int *allowed_modes, int mode_cnt) +{ + struct ceph_mon_client *monc = con->private; + bool was_authed; + + mutex_lock(&monc->mutex); + WARN_ON(!monc->hunting); + was_authed = ceph_auth_is_authenticated(monc->auth); + ceph_auth_handle_bad_method(monc->auth, used_proto, result, + allowed_protos, proto_cnt, + allowed_modes, mode_cnt); + finish_auth(monc, -EACCES, was_authed); + mutex_unlock(&monc->mutex); + return 0; +} + +/* + * handle incoming message + */ +static void mon_dispatch(struct ceph_connection *con, struct ceph_msg *msg) +{ + struct ceph_mon_client *monc = con->private; + int type = le16_to_cpu(msg->hdr.type); + + switch (type) { + case CEPH_MSG_AUTH_REPLY: + handle_auth_reply(monc, msg); + break; + + case CEPH_MSG_MON_SUBSCRIBE_ACK: + handle_subscribe_ack(monc, msg); + break; + + case CEPH_MSG_STATFS_REPLY: + handle_statfs_reply(monc, msg); + break; + + case CEPH_MSG_MON_GET_VERSION_REPLY: + handle_get_version_reply(monc, msg); + break; + + case CEPH_MSG_MON_COMMAND_ACK: + handle_command_ack(monc, msg); + break; + + case CEPH_MSG_MON_MAP: + ceph_monc_handle_map(monc, msg); + break; + + case CEPH_MSG_OSD_MAP: + ceph_osdc_handle_map(&monc->client->osdc, msg); + break; + + default: + /* can the chained handler handle it? */ + if (monc->client->extra_mon_dispatch && + monc->client->extra_mon_dispatch(monc->client, msg) == 0) + break; + + pr_err("received unknown message type %d %s\n", type, + ceph_msg_type_name(type)); + } + ceph_msg_put(msg); +} + +/* + * Allocate memory for incoming message + */ +static struct ceph_msg *mon_alloc_msg(struct ceph_connection *con, + struct ceph_msg_header *hdr, + int *skip) +{ + struct ceph_mon_client *monc = con->private; + int type = le16_to_cpu(hdr->type); + int front_len = le32_to_cpu(hdr->front_len); + struct ceph_msg *m = NULL; + + *skip = 0; + + switch (type) { + case CEPH_MSG_MON_SUBSCRIBE_ACK: + m = ceph_msg_get(monc->m_subscribe_ack); + break; + case CEPH_MSG_STATFS_REPLY: + case CEPH_MSG_MON_COMMAND_ACK: + return get_generic_reply(con, hdr, skip); + case CEPH_MSG_AUTH_REPLY: + m = ceph_msg_get(monc->m_auth_reply); + break; + case CEPH_MSG_MON_GET_VERSION_REPLY: + if (le64_to_cpu(hdr->tid) != 0) + return get_generic_reply(con, hdr, skip); + + /* + * Older OSDs don't set reply tid even if the original + * request had a non-zero tid. Work around this weirdness + * by allocating a new message. 
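+ * (The fallthrough below is only taken for get_version replies with
+ * tid == 0; the map message types always allocate a fresh message.)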
+ */ + fallthrough; + case CEPH_MSG_MON_MAP: + case CEPH_MSG_MDS_MAP: + case CEPH_MSG_OSD_MAP: + case CEPH_MSG_FS_MAP_USER: + m = ceph_msg_new(type, front_len, GFP_NOFS, false); + if (!m) + return NULL; /* ENOMEM--return skip == 0 */ + break; + } + + if (!m) { + pr_info("alloc_msg unknown type %d\n", type); + *skip = 1; + } else if (front_len > m->front_alloc_len) { + pr_warn("mon_alloc_msg front %d > prealloc %d (%u#%llu)\n", + front_len, m->front_alloc_len, + (unsigned int)con->peer_name.type, + le64_to_cpu(con->peer_name.num)); + ceph_msg_put(m); + m = ceph_msg_new(type, front_len, GFP_NOFS, false); + } + + return m; +} + +/* + * If the monitor connection resets, pick a new monitor and resubmit + * any pending requests. + */ +static void mon_fault(struct ceph_connection *con) +{ + struct ceph_mon_client *monc = con->private; + + mutex_lock(&monc->mutex); + dout("%s mon%d\n", __func__, monc->cur_mon); + if (monc->cur_mon >= 0) { + if (!monc->hunting) { + dout("%s hunting for new mon\n", __func__); + reopen_session(monc); + __schedule_delayed(monc); + } else { + dout("%s already hunting\n", __func__); + } + } + mutex_unlock(&monc->mutex); +} + +/* + * We can ignore refcounting on the connection struct, as all references + * will come from the messenger workqueue, which is drained prior to + * mon_client destruction. + */ +static struct ceph_connection *mon_get_con(struct ceph_connection *con) +{ + return con; +} + +static void mon_put_con(struct ceph_connection *con) +{ +} + +static const struct ceph_connection_operations mon_con_ops = { + .get = mon_get_con, + .put = mon_put_con, + .alloc_msg = mon_alloc_msg, + .dispatch = mon_dispatch, + .fault = mon_fault, + .get_auth_request = mon_get_auth_request, + .handle_auth_reply_more = mon_handle_auth_reply_more, + .handle_auth_done = mon_handle_auth_done, + .handle_auth_bad_method = mon_handle_auth_bad_method, +}; diff --git a/net/ceph/msgpool.c b/net/ceph/msgpool.c new file mode 100644 index 000000000..e3ecb80cd --- /dev/null +++ b/net/ceph/msgpool.c @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include + +#include +#include + +static void *msgpool_alloc(gfp_t gfp_mask, void *arg) +{ + struct ceph_msgpool *pool = arg; + struct ceph_msg *msg; + + msg = ceph_msg_new2(pool->type, pool->front_len, pool->max_data_items, + gfp_mask, true); + if (!msg) { + dout("msgpool_alloc %s failed\n", pool->name); + } else { + dout("msgpool_alloc %s %p\n", pool->name, msg); + msg->pool = pool; + } + return msg; +} + +static void msgpool_free(void *element, void *arg) +{ + struct ceph_msgpool *pool = arg; + struct ceph_msg *msg = element; + + dout("msgpool_release %s %p\n", pool->name, msg); + msg->pool = NULL; + ceph_msg_put(msg); +} + +int ceph_msgpool_init(struct ceph_msgpool *pool, int type, + int front_len, int max_data_items, int size, + const char *name) +{ + dout("msgpool %s init\n", name); + pool->type = type; + pool->front_len = front_len; + pool->max_data_items = max_data_items; + pool->pool = mempool_create(size, msgpool_alloc, msgpool_free, pool); + if (!pool->pool) + return -ENOMEM; + pool->name = name; + return 0; +} + +void ceph_msgpool_destroy(struct ceph_msgpool *pool) +{ + dout("msgpool %s destroy\n", pool->name); + mempool_destroy(pool->pool); +} + +struct ceph_msg *ceph_msgpool_get(struct ceph_msgpool *pool, int front_len, + int max_data_items) +{ + struct ceph_msg *msg; + + if (front_len > pool->front_len || + max_data_items > pool->max_data_items) { + pr_warn_ratelimited("%s need %d/%d, 
pool %s has %d/%d\n", + __func__, front_len, max_data_items, pool->name, + pool->front_len, pool->max_data_items); + WARN_ON_ONCE(1); + + /* try to alloc a fresh message */ + return ceph_msg_new2(pool->type, front_len, max_data_items, + GFP_NOFS, false); + } + + msg = mempool_alloc(pool->pool, GFP_NOFS); + dout("msgpool_get %s %p\n", pool->name, msg); + return msg; +} + +void ceph_msgpool_put(struct ceph_msgpool *pool, struct ceph_msg *msg) +{ + dout("msgpool_put %s %p\n", pool->name, msg); + + /* reset msg front_len; user may have changed it */ + msg->front.iov_len = pool->front_len; + msg->hdr.front_len = cpu_to_le32(pool->front_len); + + msg->data_length = 0; + msg->num_data_items = 0; + + kref_init(&msg->kref); /* retake single ref */ + mempool_free(msg, pool->pool); +} diff --git a/net/ceph/osd_client.c b/net/ceph/osd_client.c new file mode 100644 index 000000000..c22bb06b4 --- /dev/null +++ b/net/ceph/osd_client.c @@ -0,0 +1,5658 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_BLOCK +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#define OSD_OPREPLY_FRONT_LEN 512 + +static struct kmem_cache *ceph_osd_request_cache; + +static const struct ceph_connection_operations osd_con_ops; + +/* + * Implement client access to distributed object storage cluster. + * + * All data objects are stored within a cluster/cloud of OSDs, or + * "object storage devices." (Note that Ceph OSDs have _nothing_ to + * do with the T10 OSD extensions to SCSI.) Ceph OSDs are simply + * remote daemons serving up and coordinating consistent and safe + * access to storage. + * + * Cluster membership and the mapping of data objects onto storage devices + * are described by the osd map. + * + * We keep track of pending OSD requests (read, write), resubmit + * requests to different OSDs when the cluster topology/data layout + * change, or retry the affected requests when the communications + * channel with an OSD is reset. 
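+ *
+ * Requests that cannot currently be mapped to an OSD are parked on a
+ * "homeless" OSD (CEPH_HOMELESS_OSD) until an osdmap update makes
+ * them routable again.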
+ */ + +static void link_request(struct ceph_osd *osd, struct ceph_osd_request *req); +static void unlink_request(struct ceph_osd *osd, struct ceph_osd_request *req); +static void link_linger(struct ceph_osd *osd, + struct ceph_osd_linger_request *lreq); +static void unlink_linger(struct ceph_osd *osd, + struct ceph_osd_linger_request *lreq); +static void clear_backoffs(struct ceph_osd *osd); + +#if 1 +static inline bool rwsem_is_wrlocked(struct rw_semaphore *sem) +{ + bool wrlocked = true; + + if (unlikely(down_read_trylock(sem))) { + wrlocked = false; + up_read(sem); + } + + return wrlocked; +} +static inline void verify_osdc_locked(struct ceph_osd_client *osdc) +{ + WARN_ON(!rwsem_is_locked(&osdc->lock)); +} +static inline void verify_osdc_wrlocked(struct ceph_osd_client *osdc) +{ + WARN_ON(!rwsem_is_wrlocked(&osdc->lock)); +} +static inline void verify_osd_locked(struct ceph_osd *osd) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + + WARN_ON(!(mutex_is_locked(&osd->lock) && + rwsem_is_locked(&osdc->lock)) && + !rwsem_is_wrlocked(&osdc->lock)); +} +static inline void verify_lreq_locked(struct ceph_osd_linger_request *lreq) +{ + WARN_ON(!mutex_is_locked(&lreq->lock)); +} +#else +static inline void verify_osdc_locked(struct ceph_osd_client *osdc) { } +static inline void verify_osdc_wrlocked(struct ceph_osd_client *osdc) { } +static inline void verify_osd_locked(struct ceph_osd *osd) { } +static inline void verify_lreq_locked(struct ceph_osd_linger_request *lreq) { } +#endif + +/* + * calculate the mapping of a file extent onto an object, and fill out the + * request accordingly. shorten extent as necessary if it crosses an + * object boundary. + * + * fill osd op in request message. + */ +static int calc_layout(struct ceph_file_layout *layout, u64 off, u64 *plen, + u64 *objnum, u64 *objoff, u64 *objlen) +{ + u64 orig_len = *plen; + u32 xlen; + + /* object extent? */ + ceph_calc_file_object_mapping(layout, off, orig_len, objnum, + objoff, &xlen); + *objlen = xlen; + if (*objlen < orig_len) { + *plen = *objlen; + dout(" skipping last %llu, final file extent %llu~%llu\n", + orig_len - *plen, off, *plen); + } + + dout("calc_layout objnum=%llx %llu~%llu\n", *objnum, *objoff, *objlen); + return 0; +} + +static void ceph_osd_data_init(struct ceph_osd_data *osd_data) +{ + memset(osd_data, 0, sizeof (*osd_data)); + osd_data->type = CEPH_OSD_DATA_TYPE_NONE; +} + +/* + * Consumes @pages if @own_pages is true. + */ +static void ceph_osd_data_pages_init(struct ceph_osd_data *osd_data, + struct page **pages, u64 length, u32 alignment, + bool pages_from_pool, bool own_pages) +{ + osd_data->type = CEPH_OSD_DATA_TYPE_PAGES; + osd_data->pages = pages; + osd_data->length = length; + osd_data->alignment = alignment; + osd_data->pages_from_pool = pages_from_pool; + osd_data->own_pages = own_pages; +} + +/* + * Consumes a ref on @pagelist. 
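+ * The ref is dropped by ceph_osd_data_release() when the request is
+ * torn down.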
+ */ +static void ceph_osd_data_pagelist_init(struct ceph_osd_data *osd_data, + struct ceph_pagelist *pagelist) +{ + osd_data->type = CEPH_OSD_DATA_TYPE_PAGELIST; + osd_data->pagelist = pagelist; +} + +#ifdef CONFIG_BLOCK +static void ceph_osd_data_bio_init(struct ceph_osd_data *osd_data, + struct ceph_bio_iter *bio_pos, + u32 bio_length) +{ + osd_data->type = CEPH_OSD_DATA_TYPE_BIO; + osd_data->bio_pos = *bio_pos; + osd_data->bio_length = bio_length; +} +#endif /* CONFIG_BLOCK */ + +static void ceph_osd_data_bvecs_init(struct ceph_osd_data *osd_data, + struct ceph_bvec_iter *bvec_pos, + u32 num_bvecs) +{ + osd_data->type = CEPH_OSD_DATA_TYPE_BVECS; + osd_data->bvec_pos = *bvec_pos; + osd_data->num_bvecs = num_bvecs; +} + +static struct ceph_osd_data * +osd_req_op_raw_data_in(struct ceph_osd_request *osd_req, unsigned int which) +{ + BUG_ON(which >= osd_req->r_num_ops); + + return &osd_req->r_ops[which].raw_data_in; +} + +struct ceph_osd_data * +osd_req_op_extent_osd_data(struct ceph_osd_request *osd_req, + unsigned int which) +{ + return osd_req_op_data(osd_req, which, extent, osd_data); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data); + +void osd_req_op_raw_data_in_pages(struct ceph_osd_request *osd_req, + unsigned int which, struct page **pages, + u64 length, u32 alignment, + bool pages_from_pool, bool own_pages) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_raw_data_in(osd_req, which); + ceph_osd_data_pages_init(osd_data, pages, length, alignment, + pages_from_pool, own_pages); +} +EXPORT_SYMBOL(osd_req_op_raw_data_in_pages); + +void osd_req_op_extent_osd_data_pages(struct ceph_osd_request *osd_req, + unsigned int which, struct page **pages, + u64 length, u32 alignment, + bool pages_from_pool, bool own_pages) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, extent, osd_data); + ceph_osd_data_pages_init(osd_data, pages, length, alignment, + pages_from_pool, own_pages); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data_pages); + +void osd_req_op_extent_osd_data_pagelist(struct ceph_osd_request *osd_req, + unsigned int which, struct ceph_pagelist *pagelist) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, extent, osd_data); + ceph_osd_data_pagelist_init(osd_data, pagelist); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data_pagelist); + +#ifdef CONFIG_BLOCK +void osd_req_op_extent_osd_data_bio(struct ceph_osd_request *osd_req, + unsigned int which, + struct ceph_bio_iter *bio_pos, + u32 bio_length) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, extent, osd_data); + ceph_osd_data_bio_init(osd_data, bio_pos, bio_length); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data_bio); +#endif /* CONFIG_BLOCK */ + +void osd_req_op_extent_osd_data_bvecs(struct ceph_osd_request *osd_req, + unsigned int which, + struct bio_vec *bvecs, u32 num_bvecs, + u32 bytes) +{ + struct ceph_osd_data *osd_data; + struct ceph_bvec_iter it = { + .bvecs = bvecs, + .iter = { .bi_size = bytes }, + }; + + osd_data = osd_req_op_data(osd_req, which, extent, osd_data); + ceph_osd_data_bvecs_init(osd_data, &it, num_bvecs); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data_bvecs); + +void osd_req_op_extent_osd_data_bvec_pos(struct ceph_osd_request *osd_req, + unsigned int which, + struct ceph_bvec_iter *bvec_pos) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, extent, osd_data); + ceph_osd_data_bvecs_init(osd_data, bvec_pos, 0); +} +EXPORT_SYMBOL(osd_req_op_extent_osd_data_bvec_pos); + 
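+/*
+ * Illustrative sketch, not part of the original file: a caller that
+ * wants the payload of a read op landed in a page vector would
+ * typically pair the op setup with one of the helpers above, along
+ * the lines of:
+ *
+ *	num_pages = calc_pages_for(0, len);
+ *	pages = ceph_alloc_page_vector(num_pages, GFP_NOFS);
+ *	osd_req_op_extent_osd_data_pages(req, 0, pages, len, 0,
+ *					 false, true);
+ *
+ * With own_pages == true, ceph_osd_data_release() frees the vector
+ * when the request is destroyed (see below).
+ */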
+static void osd_req_op_cls_request_info_pagelist( + struct ceph_osd_request *osd_req, + unsigned int which, struct ceph_pagelist *pagelist) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, cls, request_info); + ceph_osd_data_pagelist_init(osd_data, pagelist); +} + +void osd_req_op_cls_request_data_pagelist( + struct ceph_osd_request *osd_req, + unsigned int which, struct ceph_pagelist *pagelist) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, cls, request_data); + ceph_osd_data_pagelist_init(osd_data, pagelist); + osd_req->r_ops[which].cls.indata_len += pagelist->length; + osd_req->r_ops[which].indata_len += pagelist->length; +} +EXPORT_SYMBOL(osd_req_op_cls_request_data_pagelist); + +void osd_req_op_cls_request_data_pages(struct ceph_osd_request *osd_req, + unsigned int which, struct page **pages, u64 length, + u32 alignment, bool pages_from_pool, bool own_pages) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, cls, request_data); + ceph_osd_data_pages_init(osd_data, pages, length, alignment, + pages_from_pool, own_pages); + osd_req->r_ops[which].cls.indata_len += length; + osd_req->r_ops[which].indata_len += length; +} +EXPORT_SYMBOL(osd_req_op_cls_request_data_pages); + +void osd_req_op_cls_request_data_bvecs(struct ceph_osd_request *osd_req, + unsigned int which, + struct bio_vec *bvecs, u32 num_bvecs, + u32 bytes) +{ + struct ceph_osd_data *osd_data; + struct ceph_bvec_iter it = { + .bvecs = bvecs, + .iter = { .bi_size = bytes }, + }; + + osd_data = osd_req_op_data(osd_req, which, cls, request_data); + ceph_osd_data_bvecs_init(osd_data, &it, num_bvecs); + osd_req->r_ops[which].cls.indata_len += bytes; + osd_req->r_ops[which].indata_len += bytes; +} +EXPORT_SYMBOL(osd_req_op_cls_request_data_bvecs); + +void osd_req_op_cls_response_data_pages(struct ceph_osd_request *osd_req, + unsigned int which, struct page **pages, u64 length, + u32 alignment, bool pages_from_pool, bool own_pages) +{ + struct ceph_osd_data *osd_data; + + osd_data = osd_req_op_data(osd_req, which, cls, response_data); + ceph_osd_data_pages_init(osd_data, pages, length, alignment, + pages_from_pool, own_pages); +} +EXPORT_SYMBOL(osd_req_op_cls_response_data_pages); + +static u64 ceph_osd_data_length(struct ceph_osd_data *osd_data) +{ + switch (osd_data->type) { + case CEPH_OSD_DATA_TYPE_NONE: + return 0; + case CEPH_OSD_DATA_TYPE_PAGES: + return osd_data->length; + case CEPH_OSD_DATA_TYPE_PAGELIST: + return (u64)osd_data->pagelist->length; +#ifdef CONFIG_BLOCK + case CEPH_OSD_DATA_TYPE_BIO: + return (u64)osd_data->bio_length; +#endif /* CONFIG_BLOCK */ + case CEPH_OSD_DATA_TYPE_BVECS: + return osd_data->bvec_pos.iter.bi_size; + default: + WARN(true, "unrecognized data type %d\n", (int)osd_data->type); + return 0; + } +} + +static void ceph_osd_data_release(struct ceph_osd_data *osd_data) +{ + if (osd_data->type == CEPH_OSD_DATA_TYPE_PAGES && osd_data->own_pages) { + int num_pages; + + num_pages = calc_pages_for((u64)osd_data->alignment, + (u64)osd_data->length); + ceph_release_page_vector(osd_data->pages, num_pages); + } else if (osd_data->type == CEPH_OSD_DATA_TYPE_PAGELIST) { + ceph_pagelist_release(osd_data->pagelist); + } + ceph_osd_data_init(osd_data); +} + +static void osd_req_op_data_release(struct ceph_osd_request *osd_req, + unsigned int which) +{ + struct ceph_osd_req_op *op; + + BUG_ON(which >= osd_req->r_num_ops); + op = &osd_req->r_ops[which]; + + switch (op->op) { + case CEPH_OSD_OP_READ: + case 
CEPH_OSD_OP_WRITE: + case CEPH_OSD_OP_WRITEFULL: + ceph_osd_data_release(&op->extent.osd_data); + break; + case CEPH_OSD_OP_CALL: + ceph_osd_data_release(&op->cls.request_info); + ceph_osd_data_release(&op->cls.request_data); + ceph_osd_data_release(&op->cls.response_data); + break; + case CEPH_OSD_OP_SETXATTR: + case CEPH_OSD_OP_CMPXATTR: + ceph_osd_data_release(&op->xattr.osd_data); + break; + case CEPH_OSD_OP_STAT: + ceph_osd_data_release(&op->raw_data_in); + break; + case CEPH_OSD_OP_NOTIFY_ACK: + ceph_osd_data_release(&op->notify_ack.request_data); + break; + case CEPH_OSD_OP_NOTIFY: + ceph_osd_data_release(&op->notify.request_data); + ceph_osd_data_release(&op->notify.response_data); + break; + case CEPH_OSD_OP_LIST_WATCHERS: + ceph_osd_data_release(&op->list_watchers.response_data); + break; + case CEPH_OSD_OP_COPY_FROM2: + ceph_osd_data_release(&op->copy_from.osd_data); + break; + default: + break; + } +} + +/* + * Assumes @t is zero-initialized. + */ +static void target_init(struct ceph_osd_request_target *t) +{ + ceph_oid_init(&t->base_oid); + ceph_oloc_init(&t->base_oloc); + ceph_oid_init(&t->target_oid); + ceph_oloc_init(&t->target_oloc); + + ceph_osds_init(&t->acting); + ceph_osds_init(&t->up); + t->size = -1; + t->min_size = -1; + + t->osd = CEPH_HOMELESS_OSD; +} + +static void target_copy(struct ceph_osd_request_target *dest, + const struct ceph_osd_request_target *src) +{ + ceph_oid_copy(&dest->base_oid, &src->base_oid); + ceph_oloc_copy(&dest->base_oloc, &src->base_oloc); + ceph_oid_copy(&dest->target_oid, &src->target_oid); + ceph_oloc_copy(&dest->target_oloc, &src->target_oloc); + + dest->pgid = src->pgid; /* struct */ + dest->spgid = src->spgid; /* struct */ + dest->pg_num = src->pg_num; + dest->pg_num_mask = src->pg_num_mask; + ceph_osds_copy(&dest->acting, &src->acting); + ceph_osds_copy(&dest->up, &src->up); + dest->size = src->size; + dest->min_size = src->min_size; + dest->sort_bitwise = src->sort_bitwise; + dest->recovery_deletes = src->recovery_deletes; + + dest->flags = src->flags; + dest->used_replica = src->used_replica; + dest->paused = src->paused; + + dest->epoch = src->epoch; + dest->last_force_resend = src->last_force_resend; + + dest->osd = src->osd; +} + +static void target_destroy(struct ceph_osd_request_target *t) +{ + ceph_oid_destroy(&t->base_oid); + ceph_oloc_destroy(&t->base_oloc); + ceph_oid_destroy(&t->target_oid); + ceph_oloc_destroy(&t->target_oloc); +} + +/* + * requests + */ +static void request_release_checks(struct ceph_osd_request *req) +{ + WARN_ON(!RB_EMPTY_NODE(&req->r_node)); + WARN_ON(!RB_EMPTY_NODE(&req->r_mc_node)); + WARN_ON(!list_empty(&req->r_private_item)); + WARN_ON(req->r_osd); +} + +static void ceph_osdc_release_request(struct kref *kref) +{ + struct ceph_osd_request *req = container_of(kref, + struct ceph_osd_request, r_kref); + unsigned int which; + + dout("%s %p (r_request %p r_reply %p)\n", __func__, req, + req->r_request, req->r_reply); + request_release_checks(req); + + if (req->r_request) + ceph_msg_put(req->r_request); + if (req->r_reply) + ceph_msg_put(req->r_reply); + + for (which = 0; which < req->r_num_ops; which++) + osd_req_op_data_release(req, which); + + target_destroy(&req->r_t); + ceph_put_snap_context(req->r_snapc); + + if (req->r_mempool) + mempool_free(req, req->r_osdc->req_mempool); + else if (req->r_num_ops <= CEPH_OSD_SLAB_OPS) + kmem_cache_free(ceph_osd_request_cache, req); + else + kfree(req); +} + +void ceph_osdc_get_request(struct ceph_osd_request *req) +{ + dout("%s %p (was %d)\n", __func__, req, 
+ kref_read(&req->r_kref)); + kref_get(&req->r_kref); +} +EXPORT_SYMBOL(ceph_osdc_get_request); + +void ceph_osdc_put_request(struct ceph_osd_request *req) +{ + if (req) { + dout("%s %p (was %d)\n", __func__, req, + kref_read(&req->r_kref)); + kref_put(&req->r_kref, ceph_osdc_release_request); + } +} +EXPORT_SYMBOL(ceph_osdc_put_request); + +static void request_init(struct ceph_osd_request *req) +{ + /* req only, each op is zeroed in osd_req_op_init() */ + memset(req, 0, sizeof(*req)); + + kref_init(&req->r_kref); + init_completion(&req->r_completion); + RB_CLEAR_NODE(&req->r_node); + RB_CLEAR_NODE(&req->r_mc_node); + INIT_LIST_HEAD(&req->r_private_item); + + target_init(&req->r_t); +} + +struct ceph_osd_request *ceph_osdc_alloc_request(struct ceph_osd_client *osdc, + struct ceph_snap_context *snapc, + unsigned int num_ops, + bool use_mempool, + gfp_t gfp_flags) +{ + struct ceph_osd_request *req; + + if (use_mempool) { + BUG_ON(num_ops > CEPH_OSD_SLAB_OPS); + req = mempool_alloc(osdc->req_mempool, gfp_flags); + } else if (num_ops <= CEPH_OSD_SLAB_OPS) { + req = kmem_cache_alloc(ceph_osd_request_cache, gfp_flags); + } else { + BUG_ON(num_ops > CEPH_OSD_MAX_OPS); + req = kmalloc(struct_size(req, r_ops, num_ops), gfp_flags); + } + if (unlikely(!req)) + return NULL; + + request_init(req); + req->r_osdc = osdc; + req->r_mempool = use_mempool; + req->r_num_ops = num_ops; + req->r_snapid = CEPH_NOSNAP; + req->r_snapc = ceph_get_snap_context(snapc); + + dout("%s req %p\n", __func__, req); + return req; +} +EXPORT_SYMBOL(ceph_osdc_alloc_request); + +static int ceph_oloc_encoding_size(const struct ceph_object_locator *oloc) +{ + return 8 + 4 + 4 + 4 + (oloc->pool_ns ? oloc->pool_ns->len : 0); +} + +static int __ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp, + int num_request_data_items, + int num_reply_data_items) +{ + struct ceph_osd_client *osdc = req->r_osdc; + struct ceph_msg *msg; + int msg_size; + + WARN_ON(req->r_request || req->r_reply); + WARN_ON(ceph_oid_empty(&req->r_base_oid)); + WARN_ON(ceph_oloc_empty(&req->r_base_oloc)); + + /* create request message */ + msg_size = CEPH_ENCODING_START_BLK_LEN + + CEPH_PGID_ENCODING_LEN + 1; /* spgid */ + msg_size += 4 + 4 + 4; /* hash, osdmap_epoch, flags */ + msg_size += CEPH_ENCODING_START_BLK_LEN + + sizeof(struct ceph_osd_reqid); /* reqid */ + msg_size += sizeof(struct ceph_blkin_trace_info); /* trace */ + msg_size += 4 + sizeof(struct ceph_timespec); /* client_inc, mtime */ + msg_size += CEPH_ENCODING_START_BLK_LEN + + ceph_oloc_encoding_size(&req->r_base_oloc); /* oloc */ + msg_size += 4 + req->r_base_oid.name_len; /* oid */ + msg_size += 2 + req->r_num_ops * sizeof(struct ceph_osd_op); + msg_size += 8; /* snapid */ + msg_size += 8; /* snap_seq */ + msg_size += 4 + 8 * (req->r_snapc ? 
req->r_snapc->num_snaps : 0); + msg_size += 4 + 8; /* retry_attempt, features */ + + if (req->r_mempool) + msg = ceph_msgpool_get(&osdc->msgpool_op, msg_size, + num_request_data_items); + else + msg = ceph_msg_new2(CEPH_MSG_OSD_OP, msg_size, + num_request_data_items, gfp, true); + if (!msg) + return -ENOMEM; + + memset(msg->front.iov_base, 0, msg->front.iov_len); + req->r_request = msg; + + /* create reply message */ + msg_size = OSD_OPREPLY_FRONT_LEN; + msg_size += req->r_base_oid.name_len; + msg_size += req->r_num_ops * sizeof(struct ceph_osd_op); + + if (req->r_mempool) + msg = ceph_msgpool_get(&osdc->msgpool_op_reply, msg_size, + num_reply_data_items); + else + msg = ceph_msg_new2(CEPH_MSG_OSD_OPREPLY, msg_size, + num_reply_data_items, gfp, true); + if (!msg) + return -ENOMEM; + + req->r_reply = msg; + + return 0; +} + +static bool osd_req_opcode_valid(u16 opcode) +{ + switch (opcode) { +#define GENERATE_CASE(op, opcode, str) case CEPH_OSD_OP_##op: return true; +__CEPH_FORALL_OSD_OPS(GENERATE_CASE) +#undef GENERATE_CASE + default: + return false; + } +} + +static void get_num_data_items(struct ceph_osd_request *req, + int *num_request_data_items, + int *num_reply_data_items) +{ + struct ceph_osd_req_op *op; + + *num_request_data_items = 0; + *num_reply_data_items = 0; + + for (op = req->r_ops; op != &req->r_ops[req->r_num_ops]; op++) { + switch (op->op) { + /* request */ + case CEPH_OSD_OP_WRITE: + case CEPH_OSD_OP_WRITEFULL: + case CEPH_OSD_OP_SETXATTR: + case CEPH_OSD_OP_CMPXATTR: + case CEPH_OSD_OP_NOTIFY_ACK: + case CEPH_OSD_OP_COPY_FROM2: + *num_request_data_items += 1; + break; + + /* reply */ + case CEPH_OSD_OP_STAT: + case CEPH_OSD_OP_READ: + case CEPH_OSD_OP_LIST_WATCHERS: + *num_reply_data_items += 1; + break; + + /* both */ + case CEPH_OSD_OP_NOTIFY: + *num_request_data_items += 1; + *num_reply_data_items += 1; + break; + case CEPH_OSD_OP_CALL: + *num_request_data_items += 2; + *num_reply_data_items += 1; + break; + + default: + WARN_ON(!osd_req_opcode_valid(op->op)); + break; + } + } +} + +/* + * oid, oloc and OSD op opcode(s) must be filled in before this function + * is called. + */ +int ceph_osdc_alloc_messages(struct ceph_osd_request *req, gfp_t gfp) +{ + int num_request_data_items, num_reply_data_items; + + get_num_data_items(req, &num_request_data_items, &num_reply_data_items); + return __ceph_osdc_alloc_messages(req, gfp, num_request_data_items, + num_reply_data_items); +} +EXPORT_SYMBOL(ceph_osdc_alloc_messages); + +/* + * This is an osd op init function for opcodes that have no data or + * other information associated with them. It also serves as a + * common init routine for all the other init functions, below. 
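+ *
+ * For example, ceph_osdc_new_request() below calls it directly for the
+ * data-less CEPH_OSD_OP_CREATE and CEPH_OSD_OP_DELETE opcodes:
+ *
+ *	osd_req_op_init(req, which, opcode, 0);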
+ */ +struct ceph_osd_req_op * +osd_req_op_init(struct ceph_osd_request *osd_req, unsigned int which, + u16 opcode, u32 flags) +{ + struct ceph_osd_req_op *op; + + BUG_ON(which >= osd_req->r_num_ops); + BUG_ON(!osd_req_opcode_valid(opcode)); + + op = &osd_req->r_ops[which]; + memset(op, 0, sizeof (*op)); + op->op = opcode; + op->flags = flags; + + return op; +} +EXPORT_SYMBOL(osd_req_op_init); + +void osd_req_op_extent_init(struct ceph_osd_request *osd_req, + unsigned int which, u16 opcode, + u64 offset, u64 length, + u64 truncate_size, u32 truncate_seq) +{ + struct ceph_osd_req_op *op = osd_req_op_init(osd_req, which, + opcode, 0); + size_t payload_len = 0; + + BUG_ON(opcode != CEPH_OSD_OP_READ && opcode != CEPH_OSD_OP_WRITE && + opcode != CEPH_OSD_OP_WRITEFULL && opcode != CEPH_OSD_OP_ZERO && + opcode != CEPH_OSD_OP_TRUNCATE); + + op->extent.offset = offset; + op->extent.length = length; + op->extent.truncate_size = truncate_size; + op->extent.truncate_seq = truncate_seq; + if (opcode == CEPH_OSD_OP_WRITE || opcode == CEPH_OSD_OP_WRITEFULL) + payload_len += length; + + op->indata_len = payload_len; +} +EXPORT_SYMBOL(osd_req_op_extent_init); + +void osd_req_op_extent_update(struct ceph_osd_request *osd_req, + unsigned int which, u64 length) +{ + struct ceph_osd_req_op *op; + u64 previous; + + BUG_ON(which >= osd_req->r_num_ops); + op = &osd_req->r_ops[which]; + previous = op->extent.length; + + if (length == previous) + return; /* Nothing to do */ + BUG_ON(length > previous); + + op->extent.length = length; + if (op->op == CEPH_OSD_OP_WRITE || op->op == CEPH_OSD_OP_WRITEFULL) + op->indata_len -= previous - length; +} +EXPORT_SYMBOL(osd_req_op_extent_update); + +void osd_req_op_extent_dup_last(struct ceph_osd_request *osd_req, + unsigned int which, u64 offset_inc) +{ + struct ceph_osd_req_op *op, *prev_op; + + BUG_ON(which + 1 >= osd_req->r_num_ops); + + prev_op = &osd_req->r_ops[which]; + op = osd_req_op_init(osd_req, which + 1, prev_op->op, prev_op->flags); + /* dup previous one */ + op->indata_len = prev_op->indata_len; + op->outdata_len = prev_op->outdata_len; + op->extent = prev_op->extent; + /* adjust offset */ + op->extent.offset += offset_inc; + op->extent.length -= offset_inc; + + if (op->op == CEPH_OSD_OP_WRITE || op->op == CEPH_OSD_OP_WRITEFULL) + op->indata_len -= offset_inc; +} +EXPORT_SYMBOL(osd_req_op_extent_dup_last); + +int osd_req_op_cls_init(struct ceph_osd_request *osd_req, unsigned int which, + const char *class, const char *method) +{ + struct ceph_osd_req_op *op; + struct ceph_pagelist *pagelist; + size_t payload_len = 0; + size_t size; + int ret; + + op = osd_req_op_init(osd_req, which, CEPH_OSD_OP_CALL, 0); + + pagelist = ceph_pagelist_alloc(GFP_NOFS); + if (!pagelist) + return -ENOMEM; + + op->cls.class_name = class; + size = strlen(class); + BUG_ON(size > (size_t) U8_MAX); + op->cls.class_len = size; + ret = ceph_pagelist_append(pagelist, class, size); + if (ret) + goto err_pagelist_free; + payload_len += size; + + op->cls.method_name = method; + size = strlen(method); + BUG_ON(size > (size_t) U8_MAX); + op->cls.method_len = size; + ret = ceph_pagelist_append(pagelist, method, size); + if (ret) + goto err_pagelist_free; + payload_len += size; + + osd_req_op_cls_request_info_pagelist(osd_req, which, pagelist); + op->indata_len = payload_len; + return 0; + +err_pagelist_free: + ceph_pagelist_release(pagelist); + return ret; +} +EXPORT_SYMBOL(osd_req_op_cls_init); + +int osd_req_op_xattr_init(struct ceph_osd_request *osd_req, unsigned int which, + u16 opcode, const 
char *name, const void *value, + size_t size, u8 cmp_op, u8 cmp_mode) +{ + struct ceph_osd_req_op *op = osd_req_op_init(osd_req, which, + opcode, 0); + struct ceph_pagelist *pagelist; + size_t payload_len; + int ret; + + BUG_ON(opcode != CEPH_OSD_OP_SETXATTR && opcode != CEPH_OSD_OP_CMPXATTR); + + pagelist = ceph_pagelist_alloc(GFP_NOFS); + if (!pagelist) + return -ENOMEM; + + payload_len = strlen(name); + op->xattr.name_len = payload_len; + ret = ceph_pagelist_append(pagelist, name, payload_len); + if (ret) + goto err_pagelist_free; + + op->xattr.value_len = size; + ret = ceph_pagelist_append(pagelist, value, size); + if (ret) + goto err_pagelist_free; + payload_len += size; + + op->xattr.cmp_op = cmp_op; + op->xattr.cmp_mode = cmp_mode; + + ceph_osd_data_pagelist_init(&op->xattr.osd_data, pagelist); + op->indata_len = payload_len; + return 0; + +err_pagelist_free: + ceph_pagelist_release(pagelist); + return ret; +} +EXPORT_SYMBOL(osd_req_op_xattr_init); + +/* + * @watch_opcode: CEPH_OSD_WATCH_OP_* + */ +static void osd_req_op_watch_init(struct ceph_osd_request *req, int which, + u8 watch_opcode, u64 cookie, u32 gen) +{ + struct ceph_osd_req_op *op; + + op = osd_req_op_init(req, which, CEPH_OSD_OP_WATCH, 0); + op->watch.cookie = cookie; + op->watch.op = watch_opcode; + op->watch.gen = gen; +} + +/* + * prot_ver, timeout and notify payload (may be empty) should already be + * encoded in @request_pl + */ +static void osd_req_op_notify_init(struct ceph_osd_request *req, int which, + u64 cookie, struct ceph_pagelist *request_pl) +{ + struct ceph_osd_req_op *op; + + op = osd_req_op_init(req, which, CEPH_OSD_OP_NOTIFY, 0); + op->notify.cookie = cookie; + + ceph_osd_data_pagelist_init(&op->notify.request_data, request_pl); + op->indata_len = request_pl->length; +} + +/* + * @flags: CEPH_OSD_OP_ALLOC_HINT_FLAG_* + */ +void osd_req_op_alloc_hint_init(struct ceph_osd_request *osd_req, + unsigned int which, + u64 expected_object_size, + u64 expected_write_size, + u32 flags) +{ + struct ceph_osd_req_op *op; + + op = osd_req_op_init(osd_req, which, CEPH_OSD_OP_SETALLOCHINT, 0); + op->alloc_hint.expected_object_size = expected_object_size; + op->alloc_hint.expected_write_size = expected_write_size; + op->alloc_hint.flags = flags; + + /* + * CEPH_OSD_OP_SETALLOCHINT op is advisory and therefore deemed + * not worth a feature bit. Set FAILOK per-op flag to make + * sure older osds don't trip over an unsupported opcode. 
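+ * (With FAILOK, a failure of this particular op, such as an
+ * unsupported-opcode error from an old OSD, should be ignored instead
+ * of failing the whole request.)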
+ */ + op->flags |= CEPH_OSD_OP_FLAG_FAILOK; +} +EXPORT_SYMBOL(osd_req_op_alloc_hint_init); + +static void ceph_osdc_msg_data_add(struct ceph_msg *msg, + struct ceph_osd_data *osd_data) +{ + u64 length = ceph_osd_data_length(osd_data); + + if (osd_data->type == CEPH_OSD_DATA_TYPE_PAGES) { + BUG_ON(length > (u64) SIZE_MAX); + if (length) + ceph_msg_data_add_pages(msg, osd_data->pages, + length, osd_data->alignment, false); + } else if (osd_data->type == CEPH_OSD_DATA_TYPE_PAGELIST) { + BUG_ON(!length); + ceph_msg_data_add_pagelist(msg, osd_data->pagelist); +#ifdef CONFIG_BLOCK + } else if (osd_data->type == CEPH_OSD_DATA_TYPE_BIO) { + ceph_msg_data_add_bio(msg, &osd_data->bio_pos, length); +#endif + } else if (osd_data->type == CEPH_OSD_DATA_TYPE_BVECS) { + ceph_msg_data_add_bvecs(msg, &osd_data->bvec_pos); + } else { + BUG_ON(osd_data->type != CEPH_OSD_DATA_TYPE_NONE); + } +} + +static u32 osd_req_encode_op(struct ceph_osd_op *dst, + const struct ceph_osd_req_op *src) +{ + switch (src->op) { + case CEPH_OSD_OP_STAT: + break; + case CEPH_OSD_OP_READ: + case CEPH_OSD_OP_WRITE: + case CEPH_OSD_OP_WRITEFULL: + case CEPH_OSD_OP_ZERO: + case CEPH_OSD_OP_TRUNCATE: + dst->extent.offset = cpu_to_le64(src->extent.offset); + dst->extent.length = cpu_to_le64(src->extent.length); + dst->extent.truncate_size = + cpu_to_le64(src->extent.truncate_size); + dst->extent.truncate_seq = + cpu_to_le32(src->extent.truncate_seq); + break; + case CEPH_OSD_OP_CALL: + dst->cls.class_len = src->cls.class_len; + dst->cls.method_len = src->cls.method_len; + dst->cls.indata_len = cpu_to_le32(src->cls.indata_len); + break; + case CEPH_OSD_OP_WATCH: + dst->watch.cookie = cpu_to_le64(src->watch.cookie); + dst->watch.ver = cpu_to_le64(0); + dst->watch.op = src->watch.op; + dst->watch.gen = cpu_to_le32(src->watch.gen); + break; + case CEPH_OSD_OP_NOTIFY_ACK: + break; + case CEPH_OSD_OP_NOTIFY: + dst->notify.cookie = cpu_to_le64(src->notify.cookie); + break; + case CEPH_OSD_OP_LIST_WATCHERS: + break; + case CEPH_OSD_OP_SETALLOCHINT: + dst->alloc_hint.expected_object_size = + cpu_to_le64(src->alloc_hint.expected_object_size); + dst->alloc_hint.expected_write_size = + cpu_to_le64(src->alloc_hint.expected_write_size); + dst->alloc_hint.flags = cpu_to_le32(src->alloc_hint.flags); + break; + case CEPH_OSD_OP_SETXATTR: + case CEPH_OSD_OP_CMPXATTR: + dst->xattr.name_len = cpu_to_le32(src->xattr.name_len); + dst->xattr.value_len = cpu_to_le32(src->xattr.value_len); + dst->xattr.cmp_op = src->xattr.cmp_op; + dst->xattr.cmp_mode = src->xattr.cmp_mode; + break; + case CEPH_OSD_OP_CREATE: + case CEPH_OSD_OP_DELETE: + break; + case CEPH_OSD_OP_COPY_FROM2: + dst->copy_from.snapid = cpu_to_le64(src->copy_from.snapid); + dst->copy_from.src_version = + cpu_to_le64(src->copy_from.src_version); + dst->copy_from.flags = src->copy_from.flags; + dst->copy_from.src_fadvise_flags = + cpu_to_le32(src->copy_from.src_fadvise_flags); + break; + default: + pr_err("unsupported osd opcode %s\n", + ceph_osd_op_name(src->op)); + WARN_ON(1); + + return 0; + } + + dst->op = cpu_to_le16(src->op); + dst->flags = cpu_to_le32(src->flags); + dst->payload_len = cpu_to_le32(src->indata_len); + + return src->indata_len; +} + +/* + * build new request AND message, calculate layout, and adjust file + * extent as needed. + * + * if the file was recently truncated, we include information about its + * old and new size so that the object can be updated appropriately. (we + * avoid synchronously deleting truncated objects because it's slow.) 
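+ *
+ * as an illustrative sketch only (not a caller from this file), a
+ * single-op object read into a caller-declared req pointer could be set
+ * up as
+ *
+ *	req = ceph_osdc_new_request(osdc, layout, vino, off, &len, 0, 1,
+ *	                            CEPH_OSD_OP_READ, CEPH_OSD_FLAG_READ,
+ *	                            NULL, 0, 0, false);
+ *	if (IS_ERR(req))
+ *		return PTR_ERR(req);
+ *
+ * where osdc, layout, vino, off and len are assumed to come from the
+ * caller (e.g. a filesystem read path).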
+ */ +struct ceph_osd_request *ceph_osdc_new_request(struct ceph_osd_client *osdc, + struct ceph_file_layout *layout, + struct ceph_vino vino, + u64 off, u64 *plen, + unsigned int which, int num_ops, + int opcode, int flags, + struct ceph_snap_context *snapc, + u32 truncate_seq, + u64 truncate_size, + bool use_mempool) +{ + struct ceph_osd_request *req; + u64 objnum = 0; + u64 objoff = 0; + u64 objlen = 0; + int r; + + BUG_ON(opcode != CEPH_OSD_OP_READ && opcode != CEPH_OSD_OP_WRITE && + opcode != CEPH_OSD_OP_ZERO && opcode != CEPH_OSD_OP_TRUNCATE && + opcode != CEPH_OSD_OP_CREATE && opcode != CEPH_OSD_OP_DELETE); + + req = ceph_osdc_alloc_request(osdc, snapc, num_ops, use_mempool, + GFP_NOFS); + if (!req) { + r = -ENOMEM; + goto fail; + } + + /* calculate max write size */ + r = calc_layout(layout, off, plen, &objnum, &objoff, &objlen); + if (r) + goto fail; + + if (opcode == CEPH_OSD_OP_CREATE || opcode == CEPH_OSD_OP_DELETE) { + osd_req_op_init(req, which, opcode, 0); + } else { + u32 object_size = layout->object_size; + u32 object_base = off - objoff; + if (!(truncate_seq == 1 && truncate_size == -1ULL)) { + if (truncate_size <= object_base) { + truncate_size = 0; + } else { + truncate_size -= object_base; + if (truncate_size > object_size) + truncate_size = object_size; + } + } + osd_req_op_extent_init(req, which, opcode, objoff, objlen, + truncate_size, truncate_seq); + } + + req->r_base_oloc.pool = layout->pool_id; + req->r_base_oloc.pool_ns = ceph_try_get_string(layout->pool_ns); + ceph_oid_printf(&req->r_base_oid, "%llx.%08llx", vino.ino, objnum); + req->r_flags = flags | osdc->client->options->read_from_replica; + + req->r_snapid = vino.snap; + if (flags & CEPH_OSD_FLAG_WRITE) + req->r_data_offset = off; + + if (num_ops > 1) + /* + * This is a special case for ceph_writepages_start(), but it + * also covers ceph_uninline_data(). If more multi-op request + * use cases emerge, we will need a separate helper. + */ + r = __ceph_osdc_alloc_messages(req, GFP_NOFS, num_ops, 0); + else + r = ceph_osdc_alloc_messages(req, GFP_NOFS); + if (r) + goto fail; + + return req; + +fail: + ceph_osdc_put_request(req); + return ERR_PTR(r); +} +EXPORT_SYMBOL(ceph_osdc_new_request); + +/* + * We keep osd requests in an rbtree, sorted by ->r_tid. + */ +DEFINE_RB_FUNCS(request, struct ceph_osd_request, r_tid, r_node) +DEFINE_RB_FUNCS(request_mc, struct ceph_osd_request, r_tid, r_mc_node) + +/* + * Call @fn on each OSD request as long as @fn returns 0. + */ +static void for_each_request(struct ceph_osd_client *osdc, + int (*fn)(struct ceph_osd_request *req, void *arg), + void *arg) +{ + struct rb_node *n, *p; + + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + for (p = rb_first(&osd->o_requests); p; ) { + struct ceph_osd_request *req = + rb_entry(p, struct ceph_osd_request, r_node); + + p = rb_next(p); + if (fn(req, arg)) + return; + } + } + + for (p = rb_first(&osdc->homeless_osd.o_requests); p; ) { + struct ceph_osd_request *req = + rb_entry(p, struct ceph_osd_request, r_node); + + p = rb_next(p); + if (fn(req, arg)) + return; + } +} + +static bool osd_homeless(struct ceph_osd *osd) +{ + return osd->o_osd == CEPH_HOMELESS_OSD; +} + +static bool osd_registered(struct ceph_osd *osd) +{ + verify_osdc_locked(osd->o_osdc); + + return !RB_EMPTY_NODE(&osd->o_node); +} + +/* + * Assumes @osd is zero-initialized. 
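+ * (create_osd() below obtains @osd with kzalloc() before calling this.)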
+ */ +static void osd_init(struct ceph_osd *osd) +{ + refcount_set(&osd->o_ref, 1); + RB_CLEAR_NODE(&osd->o_node); + osd->o_requests = RB_ROOT; + osd->o_linger_requests = RB_ROOT; + osd->o_backoff_mappings = RB_ROOT; + osd->o_backoffs_by_id = RB_ROOT; + INIT_LIST_HEAD(&osd->o_osd_lru); + INIT_LIST_HEAD(&osd->o_keepalive_item); + osd->o_incarnation = 1; + mutex_init(&osd->lock); +} + +static void osd_cleanup(struct ceph_osd *osd) +{ + WARN_ON(!RB_EMPTY_NODE(&osd->o_node)); + WARN_ON(!RB_EMPTY_ROOT(&osd->o_requests)); + WARN_ON(!RB_EMPTY_ROOT(&osd->o_linger_requests)); + WARN_ON(!RB_EMPTY_ROOT(&osd->o_backoff_mappings)); + WARN_ON(!RB_EMPTY_ROOT(&osd->o_backoffs_by_id)); + WARN_ON(!list_empty(&osd->o_osd_lru)); + WARN_ON(!list_empty(&osd->o_keepalive_item)); + + if (osd->o_auth.authorizer) { + WARN_ON(osd_homeless(osd)); + ceph_auth_destroy_authorizer(osd->o_auth.authorizer); + } +} + +/* + * Track open sessions with osds. + */ +static struct ceph_osd *create_osd(struct ceph_osd_client *osdc, int onum) +{ + struct ceph_osd *osd; + + WARN_ON(onum == CEPH_HOMELESS_OSD); + + osd = kzalloc(sizeof(*osd), GFP_NOIO | __GFP_NOFAIL); + osd_init(osd); + osd->o_osdc = osdc; + osd->o_osd = onum; + + ceph_con_init(&osd->o_con, osd, &osd_con_ops, &osdc->client->msgr); + + return osd; +} + +static struct ceph_osd *get_osd(struct ceph_osd *osd) +{ + if (refcount_inc_not_zero(&osd->o_ref)) { + dout("get_osd %p %d -> %d\n", osd, refcount_read(&osd->o_ref)-1, + refcount_read(&osd->o_ref)); + return osd; + } else { + dout("get_osd %p FAIL\n", osd); + return NULL; + } +} + +static void put_osd(struct ceph_osd *osd) +{ + dout("put_osd %p %d -> %d\n", osd, refcount_read(&osd->o_ref), + refcount_read(&osd->o_ref) - 1); + if (refcount_dec_and_test(&osd->o_ref)) { + osd_cleanup(osd); + kfree(osd); + } +} + +DEFINE_RB_FUNCS(osd, struct ceph_osd, o_osd, o_node) + +static void __move_osd_to_lru(struct ceph_osd *osd) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + + dout("%s osd %p osd%d\n", __func__, osd, osd->o_osd); + BUG_ON(!list_empty(&osd->o_osd_lru)); + + spin_lock(&osdc->osd_lru_lock); + list_add_tail(&osd->o_osd_lru, &osdc->osd_lru); + spin_unlock(&osdc->osd_lru_lock); + + osd->lru_ttl = jiffies + osdc->client->options->osd_idle_ttl; +} + +static void maybe_move_osd_to_lru(struct ceph_osd *osd) +{ + if (RB_EMPTY_ROOT(&osd->o_requests) && + RB_EMPTY_ROOT(&osd->o_linger_requests)) + __move_osd_to_lru(osd); +} + +static void __remove_osd_from_lru(struct ceph_osd *osd) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + + dout("%s osd %p osd%d\n", __func__, osd, osd->o_osd); + + spin_lock(&osdc->osd_lru_lock); + if (!list_empty(&osd->o_osd_lru)) + list_del_init(&osd->o_osd_lru); + spin_unlock(&osdc->osd_lru_lock); +} + +/* + * Close the connection and assign any leftover requests to the + * homeless session. 
+ */ +static void close_osd(struct ceph_osd *osd) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + struct rb_node *n; + + verify_osdc_wrlocked(osdc); + dout("%s osd %p osd%d\n", __func__, osd, osd->o_osd); + + ceph_con_close(&osd->o_con); + + for (n = rb_first(&osd->o_requests); n; ) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + + n = rb_next(n); /* unlink_request() */ + + dout(" reassigning req %p tid %llu\n", req, req->r_tid); + unlink_request(osd, req); + link_request(&osdc->homeless_osd, req); + } + for (n = rb_first(&osd->o_linger_requests); n; ) { + struct ceph_osd_linger_request *lreq = + rb_entry(n, struct ceph_osd_linger_request, node); + + n = rb_next(n); /* unlink_linger() */ + + dout(" reassigning lreq %p linger_id %llu\n", lreq, + lreq->linger_id); + unlink_linger(osd, lreq); + link_linger(&osdc->homeless_osd, lreq); + } + clear_backoffs(osd); + + __remove_osd_from_lru(osd); + erase_osd(&osdc->osds, osd); + put_osd(osd); +} + +/* + * reset osd connect + */ +static int reopen_osd(struct ceph_osd *osd) +{ + struct ceph_entity_addr *peer_addr; + + dout("%s osd %p osd%d\n", __func__, osd, osd->o_osd); + + if (RB_EMPTY_ROOT(&osd->o_requests) && + RB_EMPTY_ROOT(&osd->o_linger_requests)) { + close_osd(osd); + return -ENODEV; + } + + peer_addr = &osd->o_osdc->osdmap->osd_addr[osd->o_osd]; + if (!memcmp(peer_addr, &osd->o_con.peer_addr, sizeof (*peer_addr)) && + !ceph_con_opened(&osd->o_con)) { + struct rb_node *n; + + dout("osd addr hasn't changed and connection never opened, " + "letting msgr retry\n"); + /* touch each r_stamp for handle_timeout()'s benfit */ + for (n = rb_first(&osd->o_requests); n; n = rb_next(n)) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + req->r_stamp = jiffies; + } + + return -EAGAIN; + } + + ceph_con_close(&osd->o_con); + ceph_con_open(&osd->o_con, CEPH_ENTITY_TYPE_OSD, osd->o_osd, peer_addr); + osd->o_incarnation++; + + return 0; +} + +static struct ceph_osd *lookup_create_osd(struct ceph_osd_client *osdc, int o, + bool wrlocked) +{ + struct ceph_osd *osd; + + if (wrlocked) + verify_osdc_wrlocked(osdc); + else + verify_osdc_locked(osdc); + + if (o != CEPH_HOMELESS_OSD) + osd = lookup_osd(&osdc->osds, o); + else + osd = &osdc->homeless_osd; + if (!osd) { + if (!wrlocked) + return ERR_PTR(-EAGAIN); + + osd = create_osd(osdc, o); + insert_osd(&osdc->osds, osd); + ceph_con_open(&osd->o_con, CEPH_ENTITY_TYPE_OSD, osd->o_osd, + &osdc->osdmap->osd_addr[osd->o_osd]); + } + + dout("%s osdc %p osd%d -> osd %p\n", __func__, osdc, o, osd); + return osd; +} + +/* + * Create request <-> OSD session relation. + * + * @req has to be assigned a tid, @osd may be homeless. 
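+ *
+ * The request pins the session: link_request() takes a reference on @osd
+ * via get_osd() and, for a real OSD, pulls it off the idle LRU;
+ * unlink_request() undoes both.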
+ */ +static void link_request(struct ceph_osd *osd, struct ceph_osd_request *req) +{ + verify_osd_locked(osd); + WARN_ON(!req->r_tid || req->r_osd); + dout("%s osd %p osd%d req %p tid %llu\n", __func__, osd, osd->o_osd, + req, req->r_tid); + + if (!osd_homeless(osd)) + __remove_osd_from_lru(osd); + else + atomic_inc(&osd->o_osdc->num_homeless); + + get_osd(osd); + insert_request(&osd->o_requests, req); + req->r_osd = osd; +} + +static void unlink_request(struct ceph_osd *osd, struct ceph_osd_request *req) +{ + verify_osd_locked(osd); + WARN_ON(req->r_osd != osd); + dout("%s osd %p osd%d req %p tid %llu\n", __func__, osd, osd->o_osd, + req, req->r_tid); + + req->r_osd = NULL; + erase_request(&osd->o_requests, req); + put_osd(osd); + + if (!osd_homeless(osd)) + maybe_move_osd_to_lru(osd); + else + atomic_dec(&osd->o_osdc->num_homeless); +} + +static bool __pool_full(struct ceph_pg_pool_info *pi) +{ + return pi->flags & CEPH_POOL_FLAG_FULL; +} + +static bool have_pool_full(struct ceph_osd_client *osdc) +{ + struct rb_node *n; + + for (n = rb_first(&osdc->osdmap->pg_pools); n; n = rb_next(n)) { + struct ceph_pg_pool_info *pi = + rb_entry(n, struct ceph_pg_pool_info, node); + + if (__pool_full(pi)) + return true; + } + + return false; +} + +static bool pool_full(struct ceph_osd_client *osdc, s64 pool_id) +{ + struct ceph_pg_pool_info *pi; + + pi = ceph_pg_pool_by_id(osdc->osdmap, pool_id); + if (!pi) + return false; + + return __pool_full(pi); +} + +/* + * Returns whether a request should be blocked from being sent + * based on the current osdmap and osd_client settings. + */ +static bool target_should_be_paused(struct ceph_osd_client *osdc, + const struct ceph_osd_request_target *t, + struct ceph_pg_pool_info *pi) +{ + bool pauserd = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD); + bool pausewr = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSEWR) || + ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + __pool_full(pi); + + WARN_ON(pi->id != t->target_oloc.pool); + return ((t->flags & CEPH_OSD_FLAG_READ) && pauserd) || + ((t->flags & CEPH_OSD_FLAG_WRITE) && pausewr) || + (osdc->osdmap->epoch < osdc->epoch_barrier); +} + +static int pick_random_replica(const struct ceph_osds *acting) +{ + int i = prandom_u32_max(acting->size); + + dout("%s picked osd%d, primary osd%d\n", __func__, + acting->osds[i], acting->primary); + return i; +} + +/* + * Picks the closest replica based on client's location given by + * crush_location option. Prefers the primary if the locality is + * the same. 
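+ *
+ * Any non-negative locality is preferred over a negative one, and among
+ * non-negative localities the smallest wins (ties keep the earlier
+ * entry).  For example, with hypothetical localities of -1, 2 and 1 for
+ * acting->osds[0..2], acting->osds[2] would be picked.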
+ */ +static int pick_closest_replica(struct ceph_osd_client *osdc, + const struct ceph_osds *acting) +{ + struct ceph_options *opt = osdc->client->options; + int best_i, best_locality; + int i = 0, locality; + + do { + locality = ceph_get_crush_locality(osdc->osdmap, + acting->osds[i], + &opt->crush_locs); + if (i == 0 || + (locality >= 0 && best_locality < 0) || + (locality >= 0 && best_locality >= 0 && + locality < best_locality)) { + best_i = i; + best_locality = locality; + } + } while (++i < acting->size); + + dout("%s picked osd%d with locality %d, primary osd%d\n", __func__, + acting->osds[best_i], best_locality, acting->primary); + return best_i; +} + +enum calc_target_result { + CALC_TARGET_NO_ACTION = 0, + CALC_TARGET_NEED_RESEND, + CALC_TARGET_POOL_DNE, +}; + +static enum calc_target_result calc_target(struct ceph_osd_client *osdc, + struct ceph_osd_request_target *t, + bool any_change) +{ + struct ceph_pg_pool_info *pi; + struct ceph_pg pgid, last_pgid; + struct ceph_osds up, acting; + bool is_read = t->flags & CEPH_OSD_FLAG_READ; + bool is_write = t->flags & CEPH_OSD_FLAG_WRITE; + bool force_resend = false; + bool unpaused = false; + bool legacy_change = false; + bool split = false; + bool sort_bitwise = ceph_osdmap_flag(osdc, CEPH_OSDMAP_SORTBITWISE); + bool recovery_deletes = ceph_osdmap_flag(osdc, + CEPH_OSDMAP_RECOVERY_DELETES); + enum calc_target_result ct_res; + + t->epoch = osdc->osdmap->epoch; + pi = ceph_pg_pool_by_id(osdc->osdmap, t->base_oloc.pool); + if (!pi) { + t->osd = CEPH_HOMELESS_OSD; + ct_res = CALC_TARGET_POOL_DNE; + goto out; + } + + if (osdc->osdmap->epoch == pi->last_force_request_resend) { + if (t->last_force_resend < pi->last_force_request_resend) { + t->last_force_resend = pi->last_force_request_resend; + force_resend = true; + } else if (t->last_force_resend == 0) { + force_resend = true; + } + } + + /* apply tiering */ + ceph_oid_copy(&t->target_oid, &t->base_oid); + ceph_oloc_copy(&t->target_oloc, &t->base_oloc); + if ((t->flags & CEPH_OSD_FLAG_IGNORE_OVERLAY) == 0) { + if (is_read && pi->read_tier >= 0) + t->target_oloc.pool = pi->read_tier; + if (is_write && pi->write_tier >= 0) + t->target_oloc.pool = pi->write_tier; + + pi = ceph_pg_pool_by_id(osdc->osdmap, t->target_oloc.pool); + if (!pi) { + t->osd = CEPH_HOMELESS_OSD; + ct_res = CALC_TARGET_POOL_DNE; + goto out; + } + } + + __ceph_object_locator_to_pg(pi, &t->target_oid, &t->target_oloc, &pgid); + last_pgid.pool = pgid.pool; + last_pgid.seed = ceph_stable_mod(pgid.seed, t->pg_num, t->pg_num_mask); + + ceph_pg_to_up_acting_osds(osdc->osdmap, pi, &pgid, &up, &acting); + if (any_change && + ceph_is_new_interval(&t->acting, + &acting, + &t->up, + &up, + t->size, + pi->size, + t->min_size, + pi->min_size, + t->pg_num, + pi->pg_num, + t->sort_bitwise, + sort_bitwise, + t->recovery_deletes, + recovery_deletes, + &last_pgid)) + force_resend = true; + + if (t->paused && !target_should_be_paused(osdc, t, pi)) { + t->paused = false; + unpaused = true; + } + legacy_change = ceph_pg_compare(&t->pgid, &pgid) || + ceph_osds_changed(&t->acting, &acting, + t->used_replica || any_change); + if (t->pg_num) + split = ceph_pg_is_split(&last_pgid, t->pg_num, pi->pg_num); + + if (legacy_change || force_resend || split) { + t->pgid = pgid; /* struct */ + ceph_pg_to_primary_shard(osdc->osdmap, pi, &pgid, &t->spgid); + ceph_osds_copy(&t->acting, &acting); + ceph_osds_copy(&t->up, &up); + t->size = pi->size; + t->min_size = pi->min_size; + t->pg_num = pi->pg_num; + t->pg_num_mask = pi->pg_num_mask; + t->sort_bitwise 
= sort_bitwise; + t->recovery_deletes = recovery_deletes; + + if ((t->flags & (CEPH_OSD_FLAG_BALANCE_READS | + CEPH_OSD_FLAG_LOCALIZE_READS)) && + !is_write && pi->type == CEPH_POOL_TYPE_REP && + acting.size > 1) { + int pos; + + WARN_ON(!is_read || acting.osds[0] != acting.primary); + if (t->flags & CEPH_OSD_FLAG_BALANCE_READS) { + pos = pick_random_replica(&acting); + } else { + pos = pick_closest_replica(osdc, &acting); + } + t->osd = acting.osds[pos]; + t->used_replica = pos > 0; + } else { + t->osd = acting.primary; + t->used_replica = false; + } + } + + if (unpaused || legacy_change || force_resend || split) + ct_res = CALC_TARGET_NEED_RESEND; + else + ct_res = CALC_TARGET_NO_ACTION; + +out: + dout("%s t %p -> %d%d%d%d ct_res %d osd%d\n", __func__, t, unpaused, + legacy_change, force_resend, split, ct_res, t->osd); + return ct_res; +} + +static struct ceph_spg_mapping *alloc_spg_mapping(void) +{ + struct ceph_spg_mapping *spg; + + spg = kmalloc(sizeof(*spg), GFP_NOIO); + if (!spg) + return NULL; + + RB_CLEAR_NODE(&spg->node); + spg->backoffs = RB_ROOT; + return spg; +} + +static void free_spg_mapping(struct ceph_spg_mapping *spg) +{ + WARN_ON(!RB_EMPTY_NODE(&spg->node)); + WARN_ON(!RB_EMPTY_ROOT(&spg->backoffs)); + + kfree(spg); +} + +/* + * rbtree of ceph_spg_mapping for handling map, similar to + * ceph_pg_mapping. Used to track OSD backoffs -- a backoff [range] is + * defined only within a specific spgid; it does not pass anything to + * children on split, or to another primary. + */ +DEFINE_RB_FUNCS2(spg_mapping, struct ceph_spg_mapping, spgid, ceph_spg_compare, + RB_BYPTR, const struct ceph_spg *, node) + +static u64 hoid_get_bitwise_key(const struct ceph_hobject_id *hoid) +{ + return hoid->is_max ? 0x100000000ull : hoid->hash_reverse_bits; +} + +static void hoid_get_effective_key(const struct ceph_hobject_id *hoid, + void **pkey, size_t *pkey_len) +{ + if (hoid->key_len) { + *pkey = hoid->key; + *pkey_len = hoid->key_len; + } else { + *pkey = hoid->oid; + *pkey_len = hoid->oid_len; + } +} + +static int compare_names(const void *name1, size_t name1_len, + const void *name2, size_t name2_len) +{ + int ret; + + ret = memcmp(name1, name2, min(name1_len, name2_len)); + if (!ret) { + if (name1_len < name2_len) + ret = -1; + else if (name1_len > name2_len) + ret = 1; + } + return ret; +} + +static int hoid_compare(const struct ceph_hobject_id *lhs, + const struct ceph_hobject_id *rhs) +{ + void *effective_key1, *effective_key2; + size_t effective_key1_len, effective_key2_len; + int ret; + + if (lhs->is_max < rhs->is_max) + return -1; + if (lhs->is_max > rhs->is_max) + return 1; + + if (lhs->pool < rhs->pool) + return -1; + if (lhs->pool > rhs->pool) + return 1; + + if (hoid_get_bitwise_key(lhs) < hoid_get_bitwise_key(rhs)) + return -1; + if (hoid_get_bitwise_key(lhs) > hoid_get_bitwise_key(rhs)) + return 1; + + ret = compare_names(lhs->nspace, lhs->nspace_len, + rhs->nspace, rhs->nspace_len); + if (ret) + return ret; + + hoid_get_effective_key(lhs, &effective_key1, &effective_key1_len); + hoid_get_effective_key(rhs, &effective_key2, &effective_key2_len); + ret = compare_names(effective_key1, effective_key1_len, + effective_key2, effective_key2_len); + if (ret) + return ret; + + ret = compare_names(lhs->oid, lhs->oid_len, rhs->oid, rhs->oid_len); + if (ret) + return ret; + + if (lhs->snapid < rhs->snapid) + return -1; + if (lhs->snapid > rhs->snapid) + return 1; + + return 0; +} + +/* + * For decoding ->begin and ->end of MOSDBackoff only -- no MIN/MAX + * compat stuff here. 
+ * + * Assumes @hoid is zero-initialized. + */ +static int decode_hoid(void **p, void *end, struct ceph_hobject_id *hoid) +{ + u8 struct_v; + u32 struct_len; + int ret; + + ret = ceph_start_decoding(p, end, 4, "hobject_t", &struct_v, + &struct_len); + if (ret) + return ret; + + if (struct_v < 4) { + pr_err("got struct_v %d < 4 of hobject_t\n", struct_v); + goto e_inval; + } + + hoid->key = ceph_extract_encoded_string(p, end, &hoid->key_len, + GFP_NOIO); + if (IS_ERR(hoid->key)) { + ret = PTR_ERR(hoid->key); + hoid->key = NULL; + return ret; + } + + hoid->oid = ceph_extract_encoded_string(p, end, &hoid->oid_len, + GFP_NOIO); + if (IS_ERR(hoid->oid)) { + ret = PTR_ERR(hoid->oid); + hoid->oid = NULL; + return ret; + } + + ceph_decode_64_safe(p, end, hoid->snapid, e_inval); + ceph_decode_32_safe(p, end, hoid->hash, e_inval); + ceph_decode_8_safe(p, end, hoid->is_max, e_inval); + + hoid->nspace = ceph_extract_encoded_string(p, end, &hoid->nspace_len, + GFP_NOIO); + if (IS_ERR(hoid->nspace)) { + ret = PTR_ERR(hoid->nspace); + hoid->nspace = NULL; + return ret; + } + + ceph_decode_64_safe(p, end, hoid->pool, e_inval); + + ceph_hoid_build_hash_cache(hoid); + return 0; + +e_inval: + return -EINVAL; +} + +static int hoid_encoding_size(const struct ceph_hobject_id *hoid) +{ + return 8 + 4 + 1 + 8 + /* snapid, hash, is_max, pool */ + 4 + hoid->key_len + 4 + hoid->oid_len + 4 + hoid->nspace_len; +} + +static void encode_hoid(void **p, void *end, const struct ceph_hobject_id *hoid) +{ + ceph_start_encoding(p, 4, 3, hoid_encoding_size(hoid)); + ceph_encode_string(p, end, hoid->key, hoid->key_len); + ceph_encode_string(p, end, hoid->oid, hoid->oid_len); + ceph_encode_64(p, hoid->snapid); + ceph_encode_32(p, hoid->hash); + ceph_encode_8(p, hoid->is_max); + ceph_encode_string(p, end, hoid->nspace, hoid->nspace_len); + ceph_encode_64(p, hoid->pool); +} + +static void free_hoid(struct ceph_hobject_id *hoid) +{ + if (hoid) { + kfree(hoid->key); + kfree(hoid->oid); + kfree(hoid->nspace); + kfree(hoid); + } +} + +static struct ceph_osd_backoff *alloc_backoff(void) +{ + struct ceph_osd_backoff *backoff; + + backoff = kzalloc(sizeof(*backoff), GFP_NOIO); + if (!backoff) + return NULL; + + RB_CLEAR_NODE(&backoff->spg_node); + RB_CLEAR_NODE(&backoff->id_node); + return backoff; +} + +static void free_backoff(struct ceph_osd_backoff *backoff) +{ + WARN_ON(!RB_EMPTY_NODE(&backoff->spg_node)); + WARN_ON(!RB_EMPTY_NODE(&backoff->id_node)); + + free_hoid(backoff->begin); + free_hoid(backoff->end); + kfree(backoff); +} + +/* + * Within a specific spgid, backoffs are managed by ->begin hoid. + */ +DEFINE_RB_INSDEL_FUNCS2(backoff, struct ceph_osd_backoff, begin, hoid_compare, + RB_BYVAL, spg_node); + +static struct ceph_osd_backoff *lookup_containing_backoff(struct rb_root *root, + const struct ceph_hobject_id *hoid) +{ + struct rb_node *n = root->rb_node; + + while (n) { + struct ceph_osd_backoff *cur = + rb_entry(n, struct ceph_osd_backoff, spg_node); + int cmp; + + cmp = hoid_compare(hoid, cur->begin); + if (cmp < 0) { + n = n->rb_left; + } else if (cmp > 0) { + if (hoid_compare(hoid, cur->end) < 0) + return cur; + + n = n->rb_right; + } else { + return cur; + } + } + + return NULL; +} + +/* + * Each backoff has a unique id within its OSD session. 
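+ * Backoffs are thus indexed twice: by ->begin hoid within their spg
+ * mapping (spg_node, above) and by id within their OSD session (id_node,
+ * below); clear_backoffs() tears down both together.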
+ */ +DEFINE_RB_FUNCS(backoff_by_id, struct ceph_osd_backoff, id, id_node) + +static void clear_backoffs(struct ceph_osd *osd) +{ + while (!RB_EMPTY_ROOT(&osd->o_backoff_mappings)) { + struct ceph_spg_mapping *spg = + rb_entry(rb_first(&osd->o_backoff_mappings), + struct ceph_spg_mapping, node); + + while (!RB_EMPTY_ROOT(&spg->backoffs)) { + struct ceph_osd_backoff *backoff = + rb_entry(rb_first(&spg->backoffs), + struct ceph_osd_backoff, spg_node); + + erase_backoff(&spg->backoffs, backoff); + erase_backoff_by_id(&osd->o_backoffs_by_id, backoff); + free_backoff(backoff); + } + erase_spg_mapping(&osd->o_backoff_mappings, spg); + free_spg_mapping(spg); + } +} + +/* + * Set up a temporary, non-owning view into @t. + */ +static void hoid_fill_from_target(struct ceph_hobject_id *hoid, + const struct ceph_osd_request_target *t) +{ + hoid->key = NULL; + hoid->key_len = 0; + hoid->oid = t->target_oid.name; + hoid->oid_len = t->target_oid.name_len; + hoid->snapid = CEPH_NOSNAP; + hoid->hash = t->pgid.seed; + hoid->is_max = false; + if (t->target_oloc.pool_ns) { + hoid->nspace = t->target_oloc.pool_ns->str; + hoid->nspace_len = t->target_oloc.pool_ns->len; + } else { + hoid->nspace = NULL; + hoid->nspace_len = 0; + } + hoid->pool = t->target_oloc.pool; + ceph_hoid_build_hash_cache(hoid); +} + +static bool should_plug_request(struct ceph_osd_request *req) +{ + struct ceph_osd *osd = req->r_osd; + struct ceph_spg_mapping *spg; + struct ceph_osd_backoff *backoff; + struct ceph_hobject_id hoid; + + spg = lookup_spg_mapping(&osd->o_backoff_mappings, &req->r_t.spgid); + if (!spg) + return false; + + hoid_fill_from_target(&hoid, &req->r_t); + backoff = lookup_containing_backoff(&spg->backoffs, &hoid); + if (!backoff) + return false; + + dout("%s req %p tid %llu backoff osd%d spgid %llu.%xs%d id %llu\n", + __func__, req, req->r_tid, osd->o_osd, backoff->spgid.pgid.pool, + backoff->spgid.pgid.seed, backoff->spgid.shard, backoff->id); + return true; +} + +/* + * Keep get_num_data_items() in sync with this function. 
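+ * get_num_data_items() is what sizes the request and reply messages in
+ * __ceph_osdc_alloc_messages(), so the number of data items it predicts
+ * for each op must match what is actually added here.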
+ */ +static void setup_request_data(struct ceph_osd_request *req) +{ + struct ceph_msg *request_msg = req->r_request; + struct ceph_msg *reply_msg = req->r_reply; + struct ceph_osd_req_op *op; + + if (req->r_request->num_data_items || req->r_reply->num_data_items) + return; + + WARN_ON(request_msg->data_length || reply_msg->data_length); + for (op = req->r_ops; op != &req->r_ops[req->r_num_ops]; op++) { + switch (op->op) { + /* request */ + case CEPH_OSD_OP_WRITE: + case CEPH_OSD_OP_WRITEFULL: + WARN_ON(op->indata_len != op->extent.length); + ceph_osdc_msg_data_add(request_msg, + &op->extent.osd_data); + break; + case CEPH_OSD_OP_SETXATTR: + case CEPH_OSD_OP_CMPXATTR: + WARN_ON(op->indata_len != op->xattr.name_len + + op->xattr.value_len); + ceph_osdc_msg_data_add(request_msg, + &op->xattr.osd_data); + break; + case CEPH_OSD_OP_NOTIFY_ACK: + ceph_osdc_msg_data_add(request_msg, + &op->notify_ack.request_data); + break; + case CEPH_OSD_OP_COPY_FROM2: + ceph_osdc_msg_data_add(request_msg, + &op->copy_from.osd_data); + break; + + /* reply */ + case CEPH_OSD_OP_STAT: + ceph_osdc_msg_data_add(reply_msg, + &op->raw_data_in); + break; + case CEPH_OSD_OP_READ: + ceph_osdc_msg_data_add(reply_msg, + &op->extent.osd_data); + break; + case CEPH_OSD_OP_LIST_WATCHERS: + ceph_osdc_msg_data_add(reply_msg, + &op->list_watchers.response_data); + break; + + /* both */ + case CEPH_OSD_OP_CALL: + WARN_ON(op->indata_len != op->cls.class_len + + op->cls.method_len + + op->cls.indata_len); + ceph_osdc_msg_data_add(request_msg, + &op->cls.request_info); + /* optional, can be NONE */ + ceph_osdc_msg_data_add(request_msg, + &op->cls.request_data); + /* optional, can be NONE */ + ceph_osdc_msg_data_add(reply_msg, + &op->cls.response_data); + break; + case CEPH_OSD_OP_NOTIFY: + ceph_osdc_msg_data_add(request_msg, + &op->notify.request_data); + ceph_osdc_msg_data_add(reply_msg, + &op->notify.response_data); + break; + } + } +} + +static void encode_pgid(void **p, const struct ceph_pg *pgid) +{ + ceph_encode_8(p, 1); + ceph_encode_64(p, pgid->pool); + ceph_encode_32(p, pgid->seed); + ceph_encode_32(p, -1); /* preferred */ +} + +static void encode_spgid(void **p, const struct ceph_spg *spgid) +{ + ceph_start_encoding(p, 1, 1, CEPH_PGID_ENCODING_LEN + 1); + encode_pgid(p, &spgid->pgid); + ceph_encode_8(p, spgid->shard); +} + +static void encode_oloc(void **p, void *end, + const struct ceph_object_locator *oloc) +{ + ceph_start_encoding(p, 5, 4, ceph_oloc_encoding_size(oloc)); + ceph_encode_64(p, oloc->pool); + ceph_encode_32(p, -1); /* preferred */ + ceph_encode_32(p, 0); /* key len */ + if (oloc->pool_ns) + ceph_encode_string(p, end, oloc->pool_ns->str, + oloc->pool_ns->len); + else + ceph_encode_32(p, 0); +} + +static void encode_request_partial(struct ceph_osd_request *req, + struct ceph_msg *msg) +{ + void *p = msg->front.iov_base; + void *const end = p + msg->front_alloc_len; + u32 data_len = 0; + int i; + + if (req->r_flags & CEPH_OSD_FLAG_WRITE) { + /* snapshots aren't writeable */ + WARN_ON(req->r_snapid != CEPH_NOSNAP); + } else { + WARN_ON(req->r_mtime.tv_sec || req->r_mtime.tv_nsec || + req->r_data_offset || req->r_snapc); + } + + setup_request_data(req); + + encode_spgid(&p, &req->r_t.spgid); /* actual spg */ + ceph_encode_32(&p, req->r_t.pgid.seed); /* raw hash */ + ceph_encode_32(&p, req->r_osdc->osdmap->epoch); + ceph_encode_32(&p, req->r_flags); + + /* reqid */ + ceph_start_encoding(&p, 2, 2, sizeof(struct ceph_osd_reqid)); + memset(p, 0, sizeof(struct ceph_osd_reqid)); + p += sizeof(struct 
ceph_osd_reqid); + + /* trace */ + memset(p, 0, sizeof(struct ceph_blkin_trace_info)); + p += sizeof(struct ceph_blkin_trace_info); + + ceph_encode_32(&p, 0); /* client_inc, always 0 */ + ceph_encode_timespec64(p, &req->r_mtime); + p += sizeof(struct ceph_timespec); + + encode_oloc(&p, end, &req->r_t.target_oloc); + ceph_encode_string(&p, end, req->r_t.target_oid.name, + req->r_t.target_oid.name_len); + + /* ops, can imply data */ + ceph_encode_16(&p, req->r_num_ops); + for (i = 0; i < req->r_num_ops; i++) { + data_len += osd_req_encode_op(p, &req->r_ops[i]); + p += sizeof(struct ceph_osd_op); + } + + ceph_encode_64(&p, req->r_snapid); /* snapid */ + if (req->r_snapc) { + ceph_encode_64(&p, req->r_snapc->seq); + ceph_encode_32(&p, req->r_snapc->num_snaps); + for (i = 0; i < req->r_snapc->num_snaps; i++) + ceph_encode_64(&p, req->r_snapc->snaps[i]); + } else { + ceph_encode_64(&p, 0); /* snap_seq */ + ceph_encode_32(&p, 0); /* snaps len */ + } + + ceph_encode_32(&p, req->r_attempts); /* retry_attempt */ + BUG_ON(p > end - 8); /* space for features */ + + msg->hdr.version = cpu_to_le16(8); /* MOSDOp v8 */ + /* front_len is finalized in encode_request_finish() */ + msg->front.iov_len = p - msg->front.iov_base; + msg->hdr.front_len = cpu_to_le32(msg->front.iov_len); + msg->hdr.data_len = cpu_to_le32(data_len); + /* + * The header "data_off" is a hint to the receiver allowing it + * to align received data into its buffers such that there's no + * need to re-copy it before writing it to disk (direct I/O). + */ + msg->hdr.data_off = cpu_to_le16(req->r_data_offset); + + dout("%s req %p msg %p oid %s oid_len %d\n", __func__, req, msg, + req->r_t.target_oid.name, req->r_t.target_oid.name_len); +} + +static void encode_request_finish(struct ceph_msg *msg) +{ + void *p = msg->front.iov_base; + void *const partial_end = p + msg->front.iov_len; + void *const end = p + msg->front_alloc_len; + + if (CEPH_HAVE_FEATURE(msg->con->peer_features, RESEND_ON_SPLIT)) { + /* luminous OSD -- encode features and be done */ + p = partial_end; + ceph_encode_64(&p, msg->con->peer_features); + } else { + struct { + char spgid[CEPH_ENCODING_START_BLK_LEN + + CEPH_PGID_ENCODING_LEN + 1]; + __le32 hash; + __le32 epoch; + __le32 flags; + char reqid[CEPH_ENCODING_START_BLK_LEN + + sizeof(struct ceph_osd_reqid)]; + char trace[sizeof(struct ceph_blkin_trace_info)]; + __le32 client_inc; + struct ceph_timespec mtime; + } __packed head; + struct ceph_pg pgid; + void *oloc, *oid, *tail; + int oloc_len, oid_len, tail_len; + int len; + + /* + * Pre-luminous OSD -- reencode v8 into v4 using @head + * as a temporary buffer. Encode the raw PG; the rest + * is just a matter of moving oloc, oid and tail blobs + * around. 
+ */ + memcpy(&head, p, sizeof(head)); + p += sizeof(head); + + oloc = p; + p += CEPH_ENCODING_START_BLK_LEN; + pgid.pool = ceph_decode_64(&p); + p += 4 + 4; /* preferred, key len */ + len = ceph_decode_32(&p); + p += len; /* nspace */ + oloc_len = p - oloc; + + oid = p; + len = ceph_decode_32(&p); + p += len; + oid_len = p - oid; + + tail = p; + tail_len = partial_end - p; + + p = msg->front.iov_base; + ceph_encode_copy(&p, &head.client_inc, sizeof(head.client_inc)); + ceph_encode_copy(&p, &head.epoch, sizeof(head.epoch)); + ceph_encode_copy(&p, &head.flags, sizeof(head.flags)); + ceph_encode_copy(&p, &head.mtime, sizeof(head.mtime)); + + /* reassert_version */ + memset(p, 0, sizeof(struct ceph_eversion)); + p += sizeof(struct ceph_eversion); + + BUG_ON(p >= oloc); + memmove(p, oloc, oloc_len); + p += oloc_len; + + pgid.seed = le32_to_cpu(head.hash); + encode_pgid(&p, &pgid); /* raw pg */ + + BUG_ON(p >= oid); + memmove(p, oid, oid_len); + p += oid_len; + + /* tail -- ops, snapid, snapc, retry_attempt */ + BUG_ON(p >= tail); + memmove(p, tail, tail_len); + p += tail_len; + + msg->hdr.version = cpu_to_le16(4); /* MOSDOp v4 */ + } + + BUG_ON(p > end); + msg->front.iov_len = p - msg->front.iov_base; + msg->hdr.front_len = cpu_to_le32(msg->front.iov_len); + + dout("%s msg %p tid %llu %u+%u+%u v%d\n", __func__, msg, + le64_to_cpu(msg->hdr.tid), le32_to_cpu(msg->hdr.front_len), + le32_to_cpu(msg->hdr.middle_len), le32_to_cpu(msg->hdr.data_len), + le16_to_cpu(msg->hdr.version)); +} + +/* + * @req has to be assigned a tid and registered. + */ +static void send_request(struct ceph_osd_request *req) +{ + struct ceph_osd *osd = req->r_osd; + + verify_osd_locked(osd); + WARN_ON(osd->o_osd != req->r_t.osd); + + /* backoff? */ + if (should_plug_request(req)) + return; + + /* + * We may have a previously queued request message hanging + * around. Cancel it to avoid corrupting the msgr. 
+ */ + if (req->r_sent) + ceph_msg_revoke(req->r_request); + + req->r_flags |= CEPH_OSD_FLAG_KNOWN_REDIR; + if (req->r_attempts) + req->r_flags |= CEPH_OSD_FLAG_RETRY; + else + WARN_ON(req->r_flags & CEPH_OSD_FLAG_RETRY); + + encode_request_partial(req, req->r_request); + + dout("%s req %p tid %llu to pgid %llu.%x spgid %llu.%xs%d osd%d e%u flags 0x%x attempt %d\n", + __func__, req, req->r_tid, req->r_t.pgid.pool, req->r_t.pgid.seed, + req->r_t.spgid.pgid.pool, req->r_t.spgid.pgid.seed, + req->r_t.spgid.shard, osd->o_osd, req->r_t.epoch, req->r_flags, + req->r_attempts); + + req->r_t.paused = false; + req->r_stamp = jiffies; + req->r_attempts++; + + req->r_sent = osd->o_incarnation; + req->r_request->hdr.tid = cpu_to_le64(req->r_tid); + ceph_con_send(&osd->o_con, ceph_msg_get(req->r_request)); +} + +static void maybe_request_map(struct ceph_osd_client *osdc) +{ + bool continuous = false; + + verify_osdc_locked(osdc); + WARN_ON(!osdc->osdmap->epoch); + + if (ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD) || + ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSEWR)) { + dout("%s osdc %p continuous\n", __func__, osdc); + continuous = true; + } else { + dout("%s osdc %p onetime\n", __func__, osdc); + } + + if (ceph_monc_want_map(&osdc->client->monc, CEPH_SUB_OSDMAP, + osdc->osdmap->epoch + 1, continuous)) + ceph_monc_renew_subs(&osdc->client->monc); +} + +static void complete_request(struct ceph_osd_request *req, int err); +static void send_map_check(struct ceph_osd_request *req); + +static void __submit_request(struct ceph_osd_request *req, bool wrlocked) +{ + struct ceph_osd_client *osdc = req->r_osdc; + struct ceph_osd *osd; + enum calc_target_result ct_res; + int err = 0; + bool need_send = false; + bool promoted = false; + + WARN_ON(req->r_tid); + dout("%s req %p wrlocked %d\n", __func__, req, wrlocked); + +again: + ct_res = calc_target(osdc, &req->r_t, false); + if (ct_res == CALC_TARGET_POOL_DNE && !wrlocked) + goto promote; + + osd = lookup_create_osd(osdc, req->r_t.osd, wrlocked); + if (IS_ERR(osd)) { + WARN_ON(PTR_ERR(osd) != -EAGAIN || wrlocked); + goto promote; + } + + if (osdc->abort_err) { + dout("req %p abort_err %d\n", req, osdc->abort_err); + err = osdc->abort_err; + } else if (osdc->osdmap->epoch < osdc->epoch_barrier) { + dout("req %p epoch %u barrier %u\n", req, osdc->osdmap->epoch, + osdc->epoch_barrier); + req->r_t.paused = true; + maybe_request_map(osdc); + } else if ((req->r_flags & CEPH_OSD_FLAG_WRITE) && + ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSEWR)) { + dout("req %p pausewr\n", req); + req->r_t.paused = true; + maybe_request_map(osdc); + } else if ((req->r_flags & CEPH_OSD_FLAG_READ) && + ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD)) { + dout("req %p pauserd\n", req); + req->r_t.paused = true; + maybe_request_map(osdc); + } else if ((req->r_flags & CEPH_OSD_FLAG_WRITE) && + !(req->r_flags & (CEPH_OSD_FLAG_FULL_TRY | + CEPH_OSD_FLAG_FULL_FORCE)) && + (ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + pool_full(osdc, req->r_t.base_oloc.pool))) { + dout("req %p full/pool_full\n", req); + if (ceph_test_opt(osdc->client, ABORT_ON_FULL)) { + err = -ENOSPC; + } else { + if (ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL)) + pr_warn_ratelimited("cluster is full (osdmap FULL)\n"); + else + pr_warn_ratelimited("pool %lld is full or reached quota\n", + req->r_t.base_oloc.pool); + req->r_t.paused = true; + maybe_request_map(osdc); + } + } else if (!osd_homeless(osd)) { + need_send = true; + } else { + maybe_request_map(osdc); + } + + mutex_lock(&osd->lock); + 
/* + * Assign the tid atomically with send_request() to protect + * multiple writes to the same object from racing with each + * other, resulting in out of order ops on the OSDs. + */ + req->r_tid = atomic64_inc_return(&osdc->last_tid); + link_request(osd, req); + if (need_send) + send_request(req); + else if (err) + complete_request(req, err); + mutex_unlock(&osd->lock); + + if (!err && ct_res == CALC_TARGET_POOL_DNE) + send_map_check(req); + + if (promoted) + downgrade_write(&osdc->lock); + return; + +promote: + up_read(&osdc->lock); + down_write(&osdc->lock); + wrlocked = true; + promoted = true; + goto again; +} + +static void account_request(struct ceph_osd_request *req) +{ + WARN_ON(req->r_flags & (CEPH_OSD_FLAG_ACK | CEPH_OSD_FLAG_ONDISK)); + WARN_ON(!(req->r_flags & (CEPH_OSD_FLAG_READ | CEPH_OSD_FLAG_WRITE))); + + req->r_flags |= CEPH_OSD_FLAG_ONDISK; + atomic_inc(&req->r_osdc->num_requests); + + req->r_start_stamp = jiffies; + req->r_start_latency = ktime_get(); +} + +static void submit_request(struct ceph_osd_request *req, bool wrlocked) +{ + ceph_osdc_get_request(req); + account_request(req); + __submit_request(req, wrlocked); +} + +static void finish_request(struct ceph_osd_request *req) +{ + struct ceph_osd_client *osdc = req->r_osdc; + + WARN_ON(lookup_request_mc(&osdc->map_checks, req->r_tid)); + dout("%s req %p tid %llu\n", __func__, req, req->r_tid); + + req->r_end_latency = ktime_get(); + + if (req->r_osd) + unlink_request(req->r_osd, req); + atomic_dec(&osdc->num_requests); + + /* + * If an OSD has failed or returned and a request has been sent + * twice, it's possible to get a reply and end up here while the + * request message is queued for delivery. We will ignore the + * reply, so not a big deal, but better to try and catch it. + */ + ceph_msg_revoke(req->r_request); + ceph_msg_revoke_incoming(req->r_reply); +} + +static void __complete_request(struct ceph_osd_request *req) +{ + dout("%s req %p tid %llu cb %ps result %d\n", __func__, req, + req->r_tid, req->r_callback, req->r_result); + + if (req->r_callback) + req->r_callback(req); + complete_all(&req->r_completion); + ceph_osdc_put_request(req); +} + +static void complete_request_workfn(struct work_struct *work) +{ + struct ceph_osd_request *req = + container_of(work, struct ceph_osd_request, r_complete_work); + + __complete_request(req); +} + +/* + * This is open-coded in handle_reply(). 
+ */ +static void complete_request(struct ceph_osd_request *req, int err) +{ + dout("%s req %p tid %llu err %d\n", __func__, req, req->r_tid, err); + + req->r_result = err; + finish_request(req); + + INIT_WORK(&req->r_complete_work, complete_request_workfn); + queue_work(req->r_osdc->completion_wq, &req->r_complete_work); +} + +static void cancel_map_check(struct ceph_osd_request *req) +{ + struct ceph_osd_client *osdc = req->r_osdc; + struct ceph_osd_request *lookup_req; + + verify_osdc_wrlocked(osdc); + + lookup_req = lookup_request_mc(&osdc->map_checks, req->r_tid); + if (!lookup_req) + return; + + WARN_ON(lookup_req != req); + erase_request_mc(&osdc->map_checks, req); + ceph_osdc_put_request(req); +} + +static void cancel_request(struct ceph_osd_request *req) +{ + dout("%s req %p tid %llu\n", __func__, req, req->r_tid); + + cancel_map_check(req); + finish_request(req); + complete_all(&req->r_completion); + ceph_osdc_put_request(req); +} + +static void abort_request(struct ceph_osd_request *req, int err) +{ + dout("%s req %p tid %llu err %d\n", __func__, req, req->r_tid, err); + + cancel_map_check(req); + complete_request(req, err); +} + +static int abort_fn(struct ceph_osd_request *req, void *arg) +{ + int err = *(int *)arg; + + abort_request(req, err); + return 0; /* continue iteration */ +} + +/* + * Abort all in-flight requests with @err and arrange for all future + * requests to be failed immediately. + */ +void ceph_osdc_abort_requests(struct ceph_osd_client *osdc, int err) +{ + dout("%s osdc %p err %d\n", __func__, osdc, err); + down_write(&osdc->lock); + for_each_request(osdc, abort_fn, &err); + osdc->abort_err = err; + up_write(&osdc->lock); +} +EXPORT_SYMBOL(ceph_osdc_abort_requests); + +void ceph_osdc_clear_abort_err(struct ceph_osd_client *osdc) +{ + down_write(&osdc->lock); + osdc->abort_err = 0; + up_write(&osdc->lock); +} +EXPORT_SYMBOL(ceph_osdc_clear_abort_err); + +static void update_epoch_barrier(struct ceph_osd_client *osdc, u32 eb) +{ + if (likely(eb > osdc->epoch_barrier)) { + dout("updating epoch_barrier from %u to %u\n", + osdc->epoch_barrier, eb); + osdc->epoch_barrier = eb; + /* Request map if we're not to the barrier yet */ + if (eb > osdc->osdmap->epoch) + maybe_request_map(osdc); + } +} + +void ceph_osdc_update_epoch_barrier(struct ceph_osd_client *osdc, u32 eb) +{ + down_read(&osdc->lock); + if (unlikely(eb > osdc->epoch_barrier)) { + up_read(&osdc->lock); + down_write(&osdc->lock); + update_epoch_barrier(osdc, eb); + up_write(&osdc->lock); + } else { + up_read(&osdc->lock); + } +} +EXPORT_SYMBOL(ceph_osdc_update_epoch_barrier); + +/* + * We can end up releasing caps as a result of abort_request(). + * In that case, we probably want to ensure that the cap release message + * has an updated epoch barrier in it, so set the epoch barrier prior to + * aborting the first request. + */ +static int abort_on_full_fn(struct ceph_osd_request *req, void *arg) +{ + struct ceph_osd_client *osdc = req->r_osdc; + bool *victims = arg; + + if ((req->r_flags & CEPH_OSD_FLAG_WRITE) && + (ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + pool_full(osdc, req->r_t.base_oloc.pool))) { + if (!*victims) { + update_epoch_barrier(osdc, osdc->osdmap->epoch); + *victims = true; + } + abort_request(req, -ENOSPC); + } + + return 0; /* continue iteration */ +} + +/* + * Drop all pending requests that are stalled waiting on a full condition to + * clear, and complete them with ENOSPC as the return code. 
Set the + * osdc->epoch_barrier to the latest map epoch that we've seen if any were + * cancelled. + */ +static void ceph_osdc_abort_on_full(struct ceph_osd_client *osdc) +{ + bool victims = false; + + if (ceph_test_opt(osdc->client, ABORT_ON_FULL) && + (ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || have_pool_full(osdc))) + for_each_request(osdc, abort_on_full_fn, &victims); +} + +static void check_pool_dne(struct ceph_osd_request *req) +{ + struct ceph_osd_client *osdc = req->r_osdc; + struct ceph_osdmap *map = osdc->osdmap; + + verify_osdc_wrlocked(osdc); + WARN_ON(!map->epoch); + + if (req->r_attempts) { + /* + * We sent a request earlier, which means that + * previously the pool existed, and now it does not + * (i.e., it was deleted). + */ + req->r_map_dne_bound = map->epoch; + dout("%s req %p tid %llu pool disappeared\n", __func__, req, + req->r_tid); + } else { + dout("%s req %p tid %llu map_dne_bound %u have %u\n", __func__, + req, req->r_tid, req->r_map_dne_bound, map->epoch); + } + + if (req->r_map_dne_bound) { + if (map->epoch >= req->r_map_dne_bound) { + /* we had a new enough map */ + pr_info_ratelimited("tid %llu pool does not exist\n", + req->r_tid); + complete_request(req, -ENOENT); + } + } else { + send_map_check(req); + } +} + +static void map_check_cb(struct ceph_mon_generic_request *greq) +{ + struct ceph_osd_client *osdc = &greq->monc->client->osdc; + struct ceph_osd_request *req; + u64 tid = greq->private_data; + + WARN_ON(greq->result || !greq->u.newest); + + down_write(&osdc->lock); + req = lookup_request_mc(&osdc->map_checks, tid); + if (!req) { + dout("%s tid %llu dne\n", __func__, tid); + goto out_unlock; + } + + dout("%s req %p tid %llu map_dne_bound %u newest %llu\n", __func__, + req, req->r_tid, req->r_map_dne_bound, greq->u.newest); + if (!req->r_map_dne_bound) + req->r_map_dne_bound = greq->u.newest; + erase_request_mc(&osdc->map_checks, req); + check_pool_dne(req); + + ceph_osdc_put_request(req); +out_unlock: + up_write(&osdc->lock); +} + +static void send_map_check(struct ceph_osd_request *req) +{ + struct ceph_osd_client *osdc = req->r_osdc; + struct ceph_osd_request *lookup_req; + int ret; + + verify_osdc_wrlocked(osdc); + + lookup_req = lookup_request_mc(&osdc->map_checks, req->r_tid); + if (lookup_req) { + WARN_ON(lookup_req != req); + return; + } + + ceph_osdc_get_request(req); + insert_request_mc(&osdc->map_checks, req); + ret = ceph_monc_get_version_async(&osdc->client->monc, "osdmap", + map_check_cb, req->r_tid); + WARN_ON(ret); +} + +/* + * lingering requests, watch/notify v2 infrastructure + */ +static void linger_release(struct kref *kref) +{ + struct ceph_osd_linger_request *lreq = + container_of(kref, struct ceph_osd_linger_request, kref); + + dout("%s lreq %p reg_req %p ping_req %p\n", __func__, lreq, + lreq->reg_req, lreq->ping_req); + WARN_ON(!RB_EMPTY_NODE(&lreq->node)); + WARN_ON(!RB_EMPTY_NODE(&lreq->osdc_node)); + WARN_ON(!RB_EMPTY_NODE(&lreq->mc_node)); + WARN_ON(!list_empty(&lreq->scan_item)); + WARN_ON(!list_empty(&lreq->pending_lworks)); + WARN_ON(lreq->osd); + + if (lreq->request_pl) + ceph_pagelist_release(lreq->request_pl); + if (lreq->notify_id_pages) + ceph_release_page_vector(lreq->notify_id_pages, 1); + + ceph_osdc_put_request(lreq->reg_req); + ceph_osdc_put_request(lreq->ping_req); + target_destroy(&lreq->t); + kfree(lreq); +} + +static void linger_put(struct ceph_osd_linger_request *lreq) +{ + if (lreq) + kref_put(&lreq->kref, linger_release); +} + +static struct ceph_osd_linger_request * +linger_get(struct 
ceph_osd_linger_request *lreq) +{ + kref_get(&lreq->kref); + return lreq; +} + +static struct ceph_osd_linger_request * +linger_alloc(struct ceph_osd_client *osdc) +{ + struct ceph_osd_linger_request *lreq; + + lreq = kzalloc(sizeof(*lreq), GFP_NOIO); + if (!lreq) + return NULL; + + kref_init(&lreq->kref); + mutex_init(&lreq->lock); + RB_CLEAR_NODE(&lreq->node); + RB_CLEAR_NODE(&lreq->osdc_node); + RB_CLEAR_NODE(&lreq->mc_node); + INIT_LIST_HEAD(&lreq->scan_item); + INIT_LIST_HEAD(&lreq->pending_lworks); + init_completion(&lreq->reg_commit_wait); + init_completion(&lreq->notify_finish_wait); + + lreq->osdc = osdc; + target_init(&lreq->t); + + dout("%s lreq %p\n", __func__, lreq); + return lreq; +} + +DEFINE_RB_INSDEL_FUNCS(linger, struct ceph_osd_linger_request, linger_id, node) +DEFINE_RB_FUNCS(linger_osdc, struct ceph_osd_linger_request, linger_id, osdc_node) +DEFINE_RB_FUNCS(linger_mc, struct ceph_osd_linger_request, linger_id, mc_node) + +/* + * Create linger request <-> OSD session relation. + * + * @lreq has to be registered, @osd may be homeless. + */ +static void link_linger(struct ceph_osd *osd, + struct ceph_osd_linger_request *lreq) +{ + verify_osd_locked(osd); + WARN_ON(!lreq->linger_id || lreq->osd); + dout("%s osd %p osd%d lreq %p linger_id %llu\n", __func__, osd, + osd->o_osd, lreq, lreq->linger_id); + + if (!osd_homeless(osd)) + __remove_osd_from_lru(osd); + else + atomic_inc(&osd->o_osdc->num_homeless); + + get_osd(osd); + insert_linger(&osd->o_linger_requests, lreq); + lreq->osd = osd; +} + +static void unlink_linger(struct ceph_osd *osd, + struct ceph_osd_linger_request *lreq) +{ + verify_osd_locked(osd); + WARN_ON(lreq->osd != osd); + dout("%s osd %p osd%d lreq %p linger_id %llu\n", __func__, osd, + osd->o_osd, lreq, lreq->linger_id); + + lreq->osd = NULL; + erase_linger(&osd->o_linger_requests, lreq); + put_osd(osd); + + if (!osd_homeless(osd)) + maybe_move_osd_to_lru(osd); + else + atomic_dec(&osd->o_osdc->num_homeless); +} + +static bool __linger_registered(struct ceph_osd_linger_request *lreq) +{ + verify_osdc_locked(lreq->osdc); + + return !RB_EMPTY_NODE(&lreq->osdc_node); +} + +static bool linger_registered(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + bool registered; + + down_read(&osdc->lock); + registered = __linger_registered(lreq); + up_read(&osdc->lock); + + return registered; +} + +static void linger_register(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + + verify_osdc_wrlocked(osdc); + WARN_ON(lreq->linger_id); + + linger_get(lreq); + lreq->linger_id = ++osdc->last_linger_id; + insert_linger_osdc(&osdc->linger_requests, lreq); +} + +static void linger_unregister(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + + verify_osdc_wrlocked(osdc); + + erase_linger_osdc(&osdc->linger_requests, lreq); + linger_put(lreq); +} + +static void cancel_linger_request(struct ceph_osd_request *req) +{ + struct ceph_osd_linger_request *lreq = req->r_priv; + + WARN_ON(!req->r_linger); + cancel_request(req); + linger_put(lreq); +} + +struct linger_work { + struct work_struct work; + struct ceph_osd_linger_request *lreq; + struct list_head pending_item; + unsigned long queued_stamp; + + union { + struct { + u64 notify_id; + u64 notifier_id; + void *payload; /* points into @msg front */ + size_t payload_len; + + struct ceph_msg *msg; /* for ceph_msg_put() */ + } notify; + struct { + int err; + } error; + }; +}; + +static struct linger_work *lwork_alloc(struct 
ceph_osd_linger_request *lreq, + work_func_t workfn) +{ + struct linger_work *lwork; + + lwork = kzalloc(sizeof(*lwork), GFP_NOIO); + if (!lwork) + return NULL; + + INIT_WORK(&lwork->work, workfn); + INIT_LIST_HEAD(&lwork->pending_item); + lwork->lreq = linger_get(lreq); + + return lwork; +} + +static void lwork_free(struct linger_work *lwork) +{ + struct ceph_osd_linger_request *lreq = lwork->lreq; + + mutex_lock(&lreq->lock); + list_del(&lwork->pending_item); + mutex_unlock(&lreq->lock); + + linger_put(lreq); + kfree(lwork); +} + +static void lwork_queue(struct linger_work *lwork) +{ + struct ceph_osd_linger_request *lreq = lwork->lreq; + struct ceph_osd_client *osdc = lreq->osdc; + + verify_lreq_locked(lreq); + WARN_ON(!list_empty(&lwork->pending_item)); + + lwork->queued_stamp = jiffies; + list_add_tail(&lwork->pending_item, &lreq->pending_lworks); + queue_work(osdc->notify_wq, &lwork->work); +} + +static void do_watch_notify(struct work_struct *w) +{ + struct linger_work *lwork = container_of(w, struct linger_work, work); + struct ceph_osd_linger_request *lreq = lwork->lreq; + + if (!linger_registered(lreq)) { + dout("%s lreq %p not registered\n", __func__, lreq); + goto out; + } + + WARN_ON(!lreq->is_watch); + dout("%s lreq %p notify_id %llu notifier_id %llu payload_len %zu\n", + __func__, lreq, lwork->notify.notify_id, lwork->notify.notifier_id, + lwork->notify.payload_len); + lreq->wcb(lreq->data, lwork->notify.notify_id, lreq->linger_id, + lwork->notify.notifier_id, lwork->notify.payload, + lwork->notify.payload_len); + +out: + ceph_msg_put(lwork->notify.msg); + lwork_free(lwork); +} + +static void do_watch_error(struct work_struct *w) +{ + struct linger_work *lwork = container_of(w, struct linger_work, work); + struct ceph_osd_linger_request *lreq = lwork->lreq; + + if (!linger_registered(lreq)) { + dout("%s lreq %p not registered\n", __func__, lreq); + goto out; + } + + dout("%s lreq %p err %d\n", __func__, lreq, lwork->error.err); + lreq->errcb(lreq->data, lreq->linger_id, lwork->error.err); + +out: + lwork_free(lwork); +} + +static void queue_watch_error(struct ceph_osd_linger_request *lreq) +{ + struct linger_work *lwork; + + lwork = lwork_alloc(lreq, do_watch_error); + if (!lwork) { + pr_err("failed to allocate error-lwork\n"); + return; + } + + lwork->error.err = lreq->last_error; + lwork_queue(lwork); +} + +static void linger_reg_commit_complete(struct ceph_osd_linger_request *lreq, + int result) +{ + if (!completion_done(&lreq->reg_commit_wait)) { + lreq->reg_commit_error = (result <= 0 ? 
result : 0); + complete_all(&lreq->reg_commit_wait); + } +} + +static void linger_commit_cb(struct ceph_osd_request *req) +{ + struct ceph_osd_linger_request *lreq = req->r_priv; + + mutex_lock(&lreq->lock); + if (req != lreq->reg_req) { + dout("%s lreq %p linger_id %llu unknown req (%p != %p)\n", + __func__, lreq, lreq->linger_id, req, lreq->reg_req); + goto out; + } + + dout("%s lreq %p linger_id %llu result %d\n", __func__, lreq, + lreq->linger_id, req->r_result); + linger_reg_commit_complete(lreq, req->r_result); + lreq->committed = true; + + if (!lreq->is_watch) { + struct ceph_osd_data *osd_data = + osd_req_op_data(req, 0, notify, response_data); + void *p = page_address(osd_data->pages[0]); + + WARN_ON(req->r_ops[0].op != CEPH_OSD_OP_NOTIFY || + osd_data->type != CEPH_OSD_DATA_TYPE_PAGES); + + /* make note of the notify_id */ + if (req->r_ops[0].outdata_len >= sizeof(u64)) { + lreq->notify_id = ceph_decode_64(&p); + dout("lreq %p notify_id %llu\n", lreq, + lreq->notify_id); + } else { + dout("lreq %p no notify_id\n", lreq); + } + } + +out: + mutex_unlock(&lreq->lock); + linger_put(lreq); +} + +static int normalize_watch_error(int err) +{ + /* + * Translate ENOENT -> ENOTCONN so that a delete->disconnection + * notification and a failure to reconnect because we raced with + * the delete appear the same to the user. + */ + if (err == -ENOENT) + err = -ENOTCONN; + + return err; +} + +static void linger_reconnect_cb(struct ceph_osd_request *req) +{ + struct ceph_osd_linger_request *lreq = req->r_priv; + + mutex_lock(&lreq->lock); + if (req != lreq->reg_req) { + dout("%s lreq %p linger_id %llu unknown req (%p != %p)\n", + __func__, lreq, lreq->linger_id, req, lreq->reg_req); + goto out; + } + + dout("%s lreq %p linger_id %llu result %d last_error %d\n", __func__, + lreq, lreq->linger_id, req->r_result, lreq->last_error); + if (req->r_result < 0) { + if (!lreq->last_error) { + lreq->last_error = normalize_watch_error(req->r_result); + queue_watch_error(lreq); + } + } + +out: + mutex_unlock(&lreq->lock); + linger_put(lreq); +} + +static void send_linger(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osd_request *req; + int ret; + + verify_osdc_wrlocked(osdc); + mutex_lock(&lreq->lock); + dout("%s lreq %p linger_id %llu\n", __func__, lreq, lreq->linger_id); + + if (lreq->reg_req) { + if (lreq->reg_req->r_osd) + cancel_linger_request(lreq->reg_req); + ceph_osdc_put_request(lreq->reg_req); + } + + req = ceph_osdc_alloc_request(osdc, NULL, 1, true, GFP_NOIO); + BUG_ON(!req); + + target_copy(&req->r_t, &lreq->t); + req->r_mtime = lreq->mtime; + + if (lreq->is_watch && lreq->committed) { + osd_req_op_watch_init(req, 0, CEPH_OSD_WATCH_OP_RECONNECT, + lreq->linger_id, ++lreq->register_gen); + dout("lreq %p reconnect register_gen %u\n", lreq, + req->r_ops[0].watch.gen); + req->r_callback = linger_reconnect_cb; + } else { + if (lreq->is_watch) { + osd_req_op_watch_init(req, 0, CEPH_OSD_WATCH_OP_WATCH, + lreq->linger_id, 0); + } else { + lreq->notify_id = 0; + + refcount_inc(&lreq->request_pl->refcnt); + osd_req_op_notify_init(req, 0, lreq->linger_id, + lreq->request_pl); + ceph_osd_data_pages_init( + osd_req_op_data(req, 0, notify, response_data), + lreq->notify_id_pages, PAGE_SIZE, 0, false, false); + } + dout("lreq %p register\n", lreq); + req->r_callback = linger_commit_cb; + } + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + BUG_ON(ret); + + req->r_priv = linger_get(lreq); + req->r_linger = true; + lreq->reg_req = req; + 
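+ /*
+  * linger_commit_cb() and linger_reconnect_cb() compare against
+  * lreq->reg_req under lreq->lock, so the new request is recognized
+  * as the current registration attempt before the lock is dropped.
+  */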
mutex_unlock(&lreq->lock); + + submit_request(req, true); +} + +static void linger_ping_cb(struct ceph_osd_request *req) +{ + struct ceph_osd_linger_request *lreq = req->r_priv; + + mutex_lock(&lreq->lock); + if (req != lreq->ping_req) { + dout("%s lreq %p linger_id %llu unknown req (%p != %p)\n", + __func__, lreq, lreq->linger_id, req, lreq->ping_req); + goto out; + } + + dout("%s lreq %p linger_id %llu result %d ping_sent %lu last_error %d\n", + __func__, lreq, lreq->linger_id, req->r_result, lreq->ping_sent, + lreq->last_error); + if (lreq->register_gen == req->r_ops[0].watch.gen) { + if (!req->r_result) { + lreq->watch_valid_thru = lreq->ping_sent; + } else if (!lreq->last_error) { + lreq->last_error = normalize_watch_error(req->r_result); + queue_watch_error(lreq); + } + } else { + dout("lreq %p register_gen %u ignoring old pong %u\n", lreq, + lreq->register_gen, req->r_ops[0].watch.gen); + } + +out: + mutex_unlock(&lreq->lock); + linger_put(lreq); +} + +static void send_linger_ping(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osd_request *req; + int ret; + + if (ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD)) { + dout("%s PAUSERD\n", __func__); + return; + } + + lreq->ping_sent = jiffies; + dout("%s lreq %p linger_id %llu ping_sent %lu register_gen %u\n", + __func__, lreq, lreq->linger_id, lreq->ping_sent, + lreq->register_gen); + + if (lreq->ping_req) { + if (lreq->ping_req->r_osd) + cancel_linger_request(lreq->ping_req); + ceph_osdc_put_request(lreq->ping_req); + } + + req = ceph_osdc_alloc_request(osdc, NULL, 1, true, GFP_NOIO); + BUG_ON(!req); + + target_copy(&req->r_t, &lreq->t); + osd_req_op_watch_init(req, 0, CEPH_OSD_WATCH_OP_PING, lreq->linger_id, + lreq->register_gen); + req->r_callback = linger_ping_cb; + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + BUG_ON(ret); + + req->r_priv = linger_get(lreq); + req->r_linger = true; + lreq->ping_req = req; + + ceph_osdc_get_request(req); + account_request(req); + req->r_tid = atomic64_inc_return(&osdc->last_tid); + link_request(lreq->osd, req); + send_request(req); +} + +static void linger_submit(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osd *osd; + + down_write(&osdc->lock); + linger_register(lreq); + + calc_target(osdc, &lreq->t, false); + osd = lookup_create_osd(osdc, lreq->t.osd, true); + link_linger(osd, lreq); + + send_linger(lreq); + up_write(&osdc->lock); +} + +static void cancel_linger_map_check(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osd_linger_request *lookup_lreq; + + verify_osdc_wrlocked(osdc); + + lookup_lreq = lookup_linger_mc(&osdc->linger_map_checks, + lreq->linger_id); + if (!lookup_lreq) + return; + + WARN_ON(lookup_lreq != lreq); + erase_linger_mc(&osdc->linger_map_checks, lreq); + linger_put(lreq); +} + +/* + * @lreq has to be both registered and linked. 
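+ * Caller must hold osdc->lock for write: see linger_cancel() and
+ * check_linger_pool_dne().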
+ */ +static void __linger_cancel(struct ceph_osd_linger_request *lreq) +{ + if (lreq->ping_req && lreq->ping_req->r_osd) + cancel_linger_request(lreq->ping_req); + if (lreq->reg_req && lreq->reg_req->r_osd) + cancel_linger_request(lreq->reg_req); + cancel_linger_map_check(lreq); + unlink_linger(lreq->osd, lreq); + linger_unregister(lreq); +} + +static void linger_cancel(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + + down_write(&osdc->lock); + if (__linger_registered(lreq)) + __linger_cancel(lreq); + up_write(&osdc->lock); +} + +static void send_linger_map_check(struct ceph_osd_linger_request *lreq); + +static void check_linger_pool_dne(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osdmap *map = osdc->osdmap; + + verify_osdc_wrlocked(osdc); + WARN_ON(!map->epoch); + + if (lreq->register_gen) { + lreq->map_dne_bound = map->epoch; + dout("%s lreq %p linger_id %llu pool disappeared\n", __func__, + lreq, lreq->linger_id); + } else { + dout("%s lreq %p linger_id %llu map_dne_bound %u have %u\n", + __func__, lreq, lreq->linger_id, lreq->map_dne_bound, + map->epoch); + } + + if (lreq->map_dne_bound) { + if (map->epoch >= lreq->map_dne_bound) { + /* we had a new enough map */ + pr_info("linger_id %llu pool does not exist\n", + lreq->linger_id); + linger_reg_commit_complete(lreq, -ENOENT); + __linger_cancel(lreq); + } + } else { + send_linger_map_check(lreq); + } +} + +static void linger_map_check_cb(struct ceph_mon_generic_request *greq) +{ + struct ceph_osd_client *osdc = &greq->monc->client->osdc; + struct ceph_osd_linger_request *lreq; + u64 linger_id = greq->private_data; + + WARN_ON(greq->result || !greq->u.newest); + + down_write(&osdc->lock); + lreq = lookup_linger_mc(&osdc->linger_map_checks, linger_id); + if (!lreq) { + dout("%s linger_id %llu dne\n", __func__, linger_id); + goto out_unlock; + } + + dout("%s lreq %p linger_id %llu map_dne_bound %u newest %llu\n", + __func__, lreq, lreq->linger_id, lreq->map_dne_bound, + greq->u.newest); + if (!lreq->map_dne_bound) + lreq->map_dne_bound = greq->u.newest; + erase_linger_mc(&osdc->linger_map_checks, lreq); + check_linger_pool_dne(lreq); + + linger_put(lreq); +out_unlock: + up_write(&osdc->lock); +} + +static void send_linger_map_check(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + struct ceph_osd_linger_request *lookup_lreq; + int ret; + + verify_osdc_wrlocked(osdc); + + lookup_lreq = lookup_linger_mc(&osdc->linger_map_checks, + lreq->linger_id); + if (lookup_lreq) { + WARN_ON(lookup_lreq != lreq); + return; + } + + linger_get(lreq); + insert_linger_mc(&osdc->linger_map_checks, lreq); + ret = ceph_monc_get_version_async(&osdc->client->monc, "osdmap", + linger_map_check_cb, lreq->linger_id); + WARN_ON(ret); +} + +static int linger_reg_commit_wait(struct ceph_osd_linger_request *lreq) +{ + int ret; + + dout("%s lreq %p linger_id %llu\n", __func__, lreq, lreq->linger_id); + ret = wait_for_completion_killable(&lreq->reg_commit_wait); + return ret ?: lreq->reg_commit_error; +} + +static int linger_notify_finish_wait(struct ceph_osd_linger_request *lreq, + unsigned long timeout) +{ + long left; + + dout("%s lreq %p linger_id %llu\n", __func__, lreq, lreq->linger_id); + left = wait_for_completion_killable_timeout(&lreq->notify_finish_wait, + ceph_timeout_jiffies(timeout)); + if (left <= 0) + left = left ?: -ETIMEDOUT; + else + left = lreq->notify_finish_error; /* completed */ + + return left; +} + +/* + * Timeout 
callback, called every N seconds. When 1 or more OSD + * requests has been active for more than N seconds, we send a keepalive + * (tag + timestamp) to its OSD to ensure any communications channel + * reset is detected. + */ +static void handle_timeout(struct work_struct *work) +{ + struct ceph_osd_client *osdc = + container_of(work, struct ceph_osd_client, timeout_work.work); + struct ceph_options *opts = osdc->client->options; + unsigned long cutoff = jiffies - opts->osd_keepalive_timeout; + unsigned long expiry_cutoff = jiffies - opts->osd_request_timeout; + LIST_HEAD(slow_osds); + struct rb_node *n, *p; + + dout("%s osdc %p\n", __func__, osdc); + down_write(&osdc->lock); + + /* + * ping osds that are a bit slow. this ensures that if there + * is a break in the TCP connection we will notice, and reopen + * a connection with that osd (from the fault callback). + */ + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + bool found = false; + + for (p = rb_first(&osd->o_requests); p; ) { + struct ceph_osd_request *req = + rb_entry(p, struct ceph_osd_request, r_node); + + p = rb_next(p); /* abort_request() */ + + if (time_before(req->r_stamp, cutoff)) { + dout(" req %p tid %llu on osd%d is laggy\n", + req, req->r_tid, osd->o_osd); + found = true; + } + if (opts->osd_request_timeout && + time_before(req->r_start_stamp, expiry_cutoff)) { + pr_err_ratelimited("tid %llu on osd%d timeout\n", + req->r_tid, osd->o_osd); + abort_request(req, -ETIMEDOUT); + } + } + for (p = rb_first(&osd->o_linger_requests); p; p = rb_next(p)) { + struct ceph_osd_linger_request *lreq = + rb_entry(p, struct ceph_osd_linger_request, node); + + dout(" lreq %p linger_id %llu is served by osd%d\n", + lreq, lreq->linger_id, osd->o_osd); + found = true; + + mutex_lock(&lreq->lock); + if (lreq->is_watch && lreq->committed && !lreq->last_error) + send_linger_ping(lreq); + mutex_unlock(&lreq->lock); + } + + if (found) + list_move_tail(&osd->o_keepalive_item, &slow_osds); + } + + if (opts->osd_request_timeout) { + for (p = rb_first(&osdc->homeless_osd.o_requests); p; ) { + struct ceph_osd_request *req = + rb_entry(p, struct ceph_osd_request, r_node); + + p = rb_next(p); /* abort_request() */ + + if (time_before(req->r_start_stamp, expiry_cutoff)) { + pr_err_ratelimited("tid %llu on osd%d timeout\n", + req->r_tid, osdc->homeless_osd.o_osd); + abort_request(req, -ETIMEDOUT); + } + } + } + + if (atomic_read(&osdc->num_homeless) || !list_empty(&slow_osds)) + maybe_request_map(osdc); + + while (!list_empty(&slow_osds)) { + struct ceph_osd *osd = list_first_entry(&slow_osds, + struct ceph_osd, + o_keepalive_item); + list_del_init(&osd->o_keepalive_item); + ceph_con_keepalive(&osd->o_con); + } + + up_write(&osdc->lock); + schedule_delayed_work(&osdc->timeout_work, + osdc->client->options->osd_keepalive_timeout); +} + +static void handle_osds_timeout(struct work_struct *work) +{ + struct ceph_osd_client *osdc = + container_of(work, struct ceph_osd_client, + osds_timeout_work.work); + unsigned long delay = osdc->client->options->osd_idle_ttl / 4; + struct ceph_osd *osd, *nosd; + + dout("%s osdc %p\n", __func__, osdc); + down_write(&osdc->lock); + list_for_each_entry_safe(osd, nosd, &osdc->osd_lru, o_osd_lru) { + if (time_before(jiffies, osd->lru_ttl)) + break; + + WARN_ON(!RB_EMPTY_ROOT(&osd->o_requests)); + WARN_ON(!RB_EMPTY_ROOT(&osd->o_linger_requests)); + close_osd(osd); + } + + up_write(&osdc->lock); + schedule_delayed_work(&osdc->osds_timeout_work, + 
round_jiffies_relative(delay)); +} + +static int ceph_oloc_decode(void **p, void *end, + struct ceph_object_locator *oloc) +{ + u8 struct_v, struct_cv; + u32 len; + void *struct_end; + int ret = 0; + + ceph_decode_need(p, end, 1 + 1 + 4, e_inval); + struct_v = ceph_decode_8(p); + struct_cv = ceph_decode_8(p); + if (struct_v < 3) { + pr_warn("got v %d < 3 cv %d of ceph_object_locator\n", + struct_v, struct_cv); + goto e_inval; + } + if (struct_cv > 6) { + pr_warn("got v %d cv %d > 6 of ceph_object_locator\n", + struct_v, struct_cv); + goto e_inval; + } + len = ceph_decode_32(p); + ceph_decode_need(p, end, len, e_inval); + struct_end = *p + len; + + oloc->pool = ceph_decode_64(p); + *p += 4; /* skip preferred */ + + len = ceph_decode_32(p); + if (len > 0) { + pr_warn("ceph_object_locator::key is set\n"); + goto e_inval; + } + + if (struct_v >= 5) { + bool changed = false; + + len = ceph_decode_32(p); + if (len > 0) { + ceph_decode_need(p, end, len, e_inval); + if (!oloc->pool_ns || + ceph_compare_string(oloc->pool_ns, *p, len)) + changed = true; + *p += len; + } else { + if (oloc->pool_ns) + changed = true; + } + if (changed) { + /* redirect changes namespace */ + pr_warn("ceph_object_locator::nspace is changed\n"); + goto e_inval; + } + } + + if (struct_v >= 6) { + s64 hash = ceph_decode_64(p); + if (hash != -1) { + pr_warn("ceph_object_locator::hash is set\n"); + goto e_inval; + } + } + + /* skip the rest */ + *p = struct_end; +out: + return ret; + +e_inval: + ret = -EINVAL; + goto out; +} + +static int ceph_redirect_decode(void **p, void *end, + struct ceph_request_redirect *redir) +{ + u8 struct_v, struct_cv; + u32 len; + void *struct_end; + int ret; + + ceph_decode_need(p, end, 1 + 1 + 4, e_inval); + struct_v = ceph_decode_8(p); + struct_cv = ceph_decode_8(p); + if (struct_cv > 1) { + pr_warn("got v %d cv %d > 1 of ceph_request_redirect\n", + struct_v, struct_cv); + goto e_inval; + } + len = ceph_decode_32(p); + ceph_decode_need(p, end, len, e_inval); + struct_end = *p + len; + + ret = ceph_oloc_decode(p, end, &redir->oloc); + if (ret) + goto out; + + len = ceph_decode_32(p); + if (len > 0) { + pr_warn("ceph_request_redirect::object_name is set\n"); + goto e_inval; + } + + /* skip the rest */ + *p = struct_end; +out: + return ret; + +e_inval: + ret = -EINVAL; + goto out; +} + +struct MOSDOpReply { + struct ceph_pg pgid; + u64 flags; + int result; + u32 epoch; + int num_ops; + u32 outdata_len[CEPH_OSD_MAX_OPS]; + s32 rval[CEPH_OSD_MAX_OPS]; + int retry_attempt; + struct ceph_eversion replay_version; + u64 user_version; + struct ceph_request_redirect redirect; +}; + +static int decode_MOSDOpReply(const struct ceph_msg *msg, struct MOSDOpReply *m) +{ + void *p = msg->front.iov_base; + void *const end = p + msg->front.iov_len; + u16 version = le16_to_cpu(msg->hdr.version); + struct ceph_eversion bad_replay_version; + u8 decode_redir; + u32 len; + int ret; + int i; + + ceph_decode_32_safe(&p, end, len, e_inval); + ceph_decode_need(&p, end, len, e_inval); + p += len; /* skip oid */ + + ret = ceph_decode_pgid(&p, end, &m->pgid); + if (ret) + return ret; + + ceph_decode_64_safe(&p, end, m->flags, e_inval); + ceph_decode_32_safe(&p, end, m->result, e_inval); + ceph_decode_need(&p, end, sizeof(bad_replay_version), e_inval); + memcpy(&bad_replay_version, p, sizeof(bad_replay_version)); + p += sizeof(bad_replay_version); + ceph_decode_32_safe(&p, end, m->epoch, e_inval); + + ceph_decode_32_safe(&p, end, m->num_ops, e_inval); + if (m->num_ops > ARRAY_SIZE(m->outdata_len)) + goto e_inval; + + 
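+ /* one ceph_osd_op per op follows; only payload_len is consumed here */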
ceph_decode_need(&p, end, m->num_ops * sizeof(struct ceph_osd_op), + e_inval); + for (i = 0; i < m->num_ops; i++) { + struct ceph_osd_op *op = p; + + m->outdata_len[i] = le32_to_cpu(op->payload_len); + p += sizeof(*op); + } + + ceph_decode_32_safe(&p, end, m->retry_attempt, e_inval); + for (i = 0; i < m->num_ops; i++) + ceph_decode_32_safe(&p, end, m->rval[i], e_inval); + + if (version >= 5) { + ceph_decode_need(&p, end, sizeof(m->replay_version), e_inval); + memcpy(&m->replay_version, p, sizeof(m->replay_version)); + p += sizeof(m->replay_version); + ceph_decode_64_safe(&p, end, m->user_version, e_inval); + } else { + m->replay_version = bad_replay_version; /* struct */ + m->user_version = le64_to_cpu(m->replay_version.version); + } + + if (version >= 6) { + if (version >= 7) + ceph_decode_8_safe(&p, end, decode_redir, e_inval); + else + decode_redir = 1; + } else { + decode_redir = 0; + } + + if (decode_redir) { + ret = ceph_redirect_decode(&p, end, &m->redirect); + if (ret) + return ret; + } else { + ceph_oloc_init(&m->redirect.oloc); + } + + return 0; + +e_inval: + return -EINVAL; +} + +/* + * Handle MOSDOpReply. Set ->r_result and call the callback if it is + * specified. + */ +static void handle_reply(struct ceph_osd *osd, struct ceph_msg *msg) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + struct ceph_osd_request *req; + struct MOSDOpReply m; + u64 tid = le64_to_cpu(msg->hdr.tid); + u32 data_len = 0; + int ret; + int i; + + dout("%s msg %p tid %llu\n", __func__, msg, tid); + + down_read(&osdc->lock); + if (!osd_registered(osd)) { + dout("%s osd%d unknown\n", __func__, osd->o_osd); + goto out_unlock_osdc; + } + WARN_ON(osd->o_osd != le64_to_cpu(msg->hdr.src.num)); + + mutex_lock(&osd->lock); + req = lookup_request(&osd->o_requests, tid); + if (!req) { + dout("%s osd%d tid %llu unknown\n", __func__, osd->o_osd, tid); + goto out_unlock_session; + } + + m.redirect.oloc.pool_ns = req->r_t.target_oloc.pool_ns; + ret = decode_MOSDOpReply(msg, &m); + m.redirect.oloc.pool_ns = NULL; + if (ret) { + pr_err("failed to decode MOSDOpReply for tid %llu: %d\n", + req->r_tid, ret); + ceph_msg_dump(msg); + goto fail_request; + } + dout("%s req %p tid %llu flags 0x%llx pgid %llu.%x epoch %u attempt %d v %u'%llu uv %llu\n", + __func__, req, req->r_tid, m.flags, m.pgid.pool, m.pgid.seed, + m.epoch, m.retry_attempt, le32_to_cpu(m.replay_version.epoch), + le64_to_cpu(m.replay_version.version), m.user_version); + + if (m.retry_attempt >= 0) { + if (m.retry_attempt != req->r_attempts - 1) { + dout("req %p tid %llu retry_attempt %d != %d, ignoring\n", + req, req->r_tid, m.retry_attempt, + req->r_attempts - 1); + goto out_unlock_session; + } + } else { + WARN_ON(1); /* MOSDOpReply v4 is assumed */ + } + + if (!ceph_oloc_empty(&m.redirect.oloc)) { + dout("req %p tid %llu redirect pool %lld\n", req, req->r_tid, + m.redirect.oloc.pool); + unlink_request(osd, req); + mutex_unlock(&osd->lock); + + /* + * Not ceph_oloc_copy() - changing pool_ns is not + * supported. + */ + req->r_t.target_oloc.pool = m.redirect.oloc.pool; + req->r_flags |= CEPH_OSD_FLAG_REDIRECTED | + CEPH_OSD_FLAG_IGNORE_OVERLAY | + CEPH_OSD_FLAG_IGNORE_CACHE; + req->r_tid = 0; + __submit_request(req, false); + goto out_unlock_osdc; + } + + if (m.result == -EAGAIN) { + dout("req %p tid %llu EAGAIN\n", req, req->r_tid); + unlink_request(osd, req); + mutex_unlock(&osd->lock); + + /* + * The object is missing on the replica or not (yet) + * readable. Clear pgid to force a resend to the primary + * via legacy_change. 
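+ * The request is expected to have been sent to a replica (hence the
+ * WARN_ON below); the balanced/localized read flags are also cleared
+ * so that the resend targets the primary.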
+ */ + req->r_t.pgid.pool = 0; + req->r_t.pgid.seed = 0; + WARN_ON(!req->r_t.used_replica); + req->r_flags &= ~(CEPH_OSD_FLAG_BALANCE_READS | + CEPH_OSD_FLAG_LOCALIZE_READS); + req->r_tid = 0; + __submit_request(req, false); + goto out_unlock_osdc; + } + + if (m.num_ops != req->r_num_ops) { + pr_err("num_ops %d != %d for tid %llu\n", m.num_ops, + req->r_num_ops, req->r_tid); + goto fail_request; + } + for (i = 0; i < req->r_num_ops; i++) { + dout(" req %p tid %llu op %d rval %d len %u\n", req, + req->r_tid, i, m.rval[i], m.outdata_len[i]); + req->r_ops[i].rval = m.rval[i]; + req->r_ops[i].outdata_len = m.outdata_len[i]; + data_len += m.outdata_len[i]; + } + if (data_len != le32_to_cpu(msg->hdr.data_len)) { + pr_err("sum of lens %u != %u for tid %llu\n", data_len, + le32_to_cpu(msg->hdr.data_len), req->r_tid); + goto fail_request; + } + dout("%s req %p tid %llu result %d data_len %u\n", __func__, + req, req->r_tid, m.result, data_len); + + /* + * Since we only ever request ONDISK, we should only ever get + * one (type of) reply back. + */ + WARN_ON(!(m.flags & CEPH_OSD_FLAG_ONDISK)); + req->r_result = m.result ?: data_len; + finish_request(req); + mutex_unlock(&osd->lock); + up_read(&osdc->lock); + + __complete_request(req); + return; + +fail_request: + complete_request(req, -EIO); +out_unlock_session: + mutex_unlock(&osd->lock); +out_unlock_osdc: + up_read(&osdc->lock); +} + +static void set_pool_was_full(struct ceph_osd_client *osdc) +{ + struct rb_node *n; + + for (n = rb_first(&osdc->osdmap->pg_pools); n; n = rb_next(n)) { + struct ceph_pg_pool_info *pi = + rb_entry(n, struct ceph_pg_pool_info, node); + + pi->was_full = __pool_full(pi); + } +} + +static bool pool_cleared_full(struct ceph_osd_client *osdc, s64 pool_id) +{ + struct ceph_pg_pool_info *pi; + + pi = ceph_pg_pool_by_id(osdc->osdmap, pool_id); + if (!pi) + return false; + + return pi->was_full && !__pool_full(pi); +} + +static enum calc_target_result +recalc_linger_target(struct ceph_osd_linger_request *lreq) +{ + struct ceph_osd_client *osdc = lreq->osdc; + enum calc_target_result ct_res; + + ct_res = calc_target(osdc, &lreq->t, true); + if (ct_res == CALC_TARGET_NEED_RESEND) { + struct ceph_osd *osd; + + osd = lookup_create_osd(osdc, lreq->t.osd, true); + if (osd != lreq->osd) { + unlink_linger(lreq->osd, lreq); + link_linger(osd, lreq); + } + } + + return ct_res; +} + +/* + * Requeue requests whose mapping to an OSD has changed. + */ +static void scan_requests(struct ceph_osd *osd, + bool force_resend, + bool cleared_full, + bool check_pool_cleared_full, + struct rb_root *need_resend, + struct list_head *need_resend_linger) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + struct rb_node *n; + bool force_resend_writes; + + for (n = rb_first(&osd->o_linger_requests); n; ) { + struct ceph_osd_linger_request *lreq = + rb_entry(n, struct ceph_osd_linger_request, node); + enum calc_target_result ct_res; + + n = rb_next(n); /* recalc_linger_target() */ + + dout("%s lreq %p linger_id %llu\n", __func__, lreq, + lreq->linger_id); + ct_res = recalc_linger_target(lreq); + switch (ct_res) { + case CALC_TARGET_NO_ACTION: + force_resend_writes = cleared_full || + (check_pool_cleared_full && + pool_cleared_full(osdc, lreq->t.base_oloc.pool)); + if (!force_resend && !force_resend_writes) + break; + + fallthrough; + case CALC_TARGET_NEED_RESEND: + cancel_linger_map_check(lreq); + /* + * scan_requests() for the previous epoch(s) + * may have already added it to the list, since + * it's not unlinked here. 
+ */ + if (list_empty(&lreq->scan_item)) + list_add_tail(&lreq->scan_item, need_resend_linger); + break; + case CALC_TARGET_POOL_DNE: + list_del_init(&lreq->scan_item); + check_linger_pool_dne(lreq); + break; + } + } + + for (n = rb_first(&osd->o_requests); n; ) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + enum calc_target_result ct_res; + + n = rb_next(n); /* unlink_request(), check_pool_dne() */ + + dout("%s req %p tid %llu\n", __func__, req, req->r_tid); + ct_res = calc_target(osdc, &req->r_t, false); + switch (ct_res) { + case CALC_TARGET_NO_ACTION: + force_resend_writes = cleared_full || + (check_pool_cleared_full && + pool_cleared_full(osdc, req->r_t.base_oloc.pool)); + if (!force_resend && + (!(req->r_flags & CEPH_OSD_FLAG_WRITE) || + !force_resend_writes)) + break; + + fallthrough; + case CALC_TARGET_NEED_RESEND: + cancel_map_check(req); + unlink_request(osd, req); + insert_request(need_resend, req); + break; + case CALC_TARGET_POOL_DNE: + check_pool_dne(req); + break; + } + } +} + +static int handle_one_map(struct ceph_osd_client *osdc, + void *p, void *end, bool incremental, + struct rb_root *need_resend, + struct list_head *need_resend_linger) +{ + struct ceph_osdmap *newmap; + struct rb_node *n; + bool skipped_map = false; + bool was_full; + + was_full = ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL); + set_pool_was_full(osdc); + + if (incremental) + newmap = osdmap_apply_incremental(&p, end, + ceph_msgr2(osdc->client), + osdc->osdmap); + else + newmap = ceph_osdmap_decode(&p, end, ceph_msgr2(osdc->client)); + if (IS_ERR(newmap)) + return PTR_ERR(newmap); + + if (newmap != osdc->osdmap) { + /* + * Preserve ->was_full before destroying the old map. + * For pools that weren't in the old map, ->was_full + * should be false. 
+ */ + for (n = rb_first(&newmap->pg_pools); n; n = rb_next(n)) { + struct ceph_pg_pool_info *pi = + rb_entry(n, struct ceph_pg_pool_info, node); + struct ceph_pg_pool_info *old_pi; + + old_pi = ceph_pg_pool_by_id(osdc->osdmap, pi->id); + if (old_pi) + pi->was_full = old_pi->was_full; + else + WARN_ON(pi->was_full); + } + + if (osdc->osdmap->epoch && + osdc->osdmap->epoch + 1 < newmap->epoch) { + WARN_ON(incremental); + skipped_map = true; + } + + ceph_osdmap_destroy(osdc->osdmap); + osdc->osdmap = newmap; + } + + was_full &= !ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL); + scan_requests(&osdc->homeless_osd, skipped_map, was_full, true, + need_resend, need_resend_linger); + + for (n = rb_first(&osdc->osds); n; ) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + n = rb_next(n); /* close_osd() */ + + scan_requests(osd, skipped_map, was_full, true, need_resend, + need_resend_linger); + if (!ceph_osd_is_up(osdc->osdmap, osd->o_osd) || + memcmp(&osd->o_con.peer_addr, + ceph_osd_addr(osdc->osdmap, osd->o_osd), + sizeof(struct ceph_entity_addr))) + close_osd(osd); + } + + return 0; +} + +static void kick_requests(struct ceph_osd_client *osdc, + struct rb_root *need_resend, + struct list_head *need_resend_linger) +{ + struct ceph_osd_linger_request *lreq, *nlreq; + enum calc_target_result ct_res; + struct rb_node *n; + + /* make sure need_resend targets reflect latest map */ + for (n = rb_first(need_resend); n; ) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + + n = rb_next(n); + + if (req->r_t.epoch < osdc->osdmap->epoch) { + ct_res = calc_target(osdc, &req->r_t, false); + if (ct_res == CALC_TARGET_POOL_DNE) { + erase_request(need_resend, req); + check_pool_dne(req); + } + } + } + + for (n = rb_first(need_resend); n; ) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + struct ceph_osd *osd; + + n = rb_next(n); + erase_request(need_resend, req); /* before link_request() */ + + osd = lookup_create_osd(osdc, req->r_t.osd, true); + link_request(osd, req); + if (!req->r_linger) { + if (!osd_homeless(osd) && !req->r_t.paused) + send_request(req); + } else { + cancel_linger_request(req); + } + } + + list_for_each_entry_safe(lreq, nlreq, need_resend_linger, scan_item) { + if (!osd_homeless(lreq->osd)) + send_linger(lreq); + + list_del_init(&lreq->scan_item); + } +} + +/* + * Process updated osd map. + * + * The message contains any number of incremental and full maps, normally + * indicating some sort of topology change in the cluster. Kick requests + * off to different OSDs as needed. 
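+ *
+ * If a pause or full flag was set before or after the update, or if the
+ * epoch_barrier has not been reached yet, another map is requested via
+ * maybe_request_map().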
+ */ +void ceph_osdc_handle_map(struct ceph_osd_client *osdc, struct ceph_msg *msg) +{ + void *p = msg->front.iov_base; + void *const end = p + msg->front.iov_len; + u32 nr_maps, maplen; + u32 epoch; + struct ceph_fsid fsid; + struct rb_root need_resend = RB_ROOT; + LIST_HEAD(need_resend_linger); + bool handled_incremental = false; + bool was_pauserd, was_pausewr; + bool pauserd, pausewr; + int err; + + dout("%s have %u\n", __func__, osdc->osdmap->epoch); + down_write(&osdc->lock); + + /* verify fsid */ + ceph_decode_need(&p, end, sizeof(fsid), bad); + ceph_decode_copy(&p, &fsid, sizeof(fsid)); + if (ceph_check_fsid(osdc->client, &fsid) < 0) + goto bad; + + was_pauserd = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD); + was_pausewr = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSEWR) || + ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + have_pool_full(osdc); + + /* incremental maps */ + ceph_decode_32_safe(&p, end, nr_maps, bad); + dout(" %d inc maps\n", nr_maps); + while (nr_maps > 0) { + ceph_decode_need(&p, end, 2*sizeof(u32), bad); + epoch = ceph_decode_32(&p); + maplen = ceph_decode_32(&p); + ceph_decode_need(&p, end, maplen, bad); + if (osdc->osdmap->epoch && + osdc->osdmap->epoch + 1 == epoch) { + dout("applying incremental map %u len %d\n", + epoch, maplen); + err = handle_one_map(osdc, p, p + maplen, true, + &need_resend, &need_resend_linger); + if (err) + goto bad; + handled_incremental = true; + } else { + dout("ignoring incremental map %u len %d\n", + epoch, maplen); + } + p += maplen; + nr_maps--; + } + if (handled_incremental) + goto done; + + /* full maps */ + ceph_decode_32_safe(&p, end, nr_maps, bad); + dout(" %d full maps\n", nr_maps); + while (nr_maps) { + ceph_decode_need(&p, end, 2*sizeof(u32), bad); + epoch = ceph_decode_32(&p); + maplen = ceph_decode_32(&p); + ceph_decode_need(&p, end, maplen, bad); + if (nr_maps > 1) { + dout("skipping non-latest full map %u len %d\n", + epoch, maplen); + } else if (osdc->osdmap->epoch >= epoch) { + dout("skipping full map %u len %d, " + "older than our %u\n", epoch, maplen, + osdc->osdmap->epoch); + } else { + dout("taking full map %u len %d\n", epoch, maplen); + err = handle_one_map(osdc, p, p + maplen, false, + &need_resend, &need_resend_linger); + if (err) + goto bad; + } + p += maplen; + nr_maps--; + } + +done: + /* + * subscribe to subsequent osdmap updates if full to ensure + * we find out when we are no longer full and stop returning + * ENOSPC. + */ + pauserd = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSERD); + pausewr = ceph_osdmap_flag(osdc, CEPH_OSDMAP_PAUSEWR) || + ceph_osdmap_flag(osdc, CEPH_OSDMAP_FULL) || + have_pool_full(osdc); + if (was_pauserd || was_pausewr || pauserd || pausewr || + osdc->osdmap->epoch < osdc->epoch_barrier) + maybe_request_map(osdc); + + kick_requests(osdc, &need_resend, &need_resend_linger); + + ceph_osdc_abort_on_full(osdc); + ceph_monc_got_map(&osdc->client->monc, CEPH_SUB_OSDMAP, + osdc->osdmap->epoch); + up_write(&osdc->lock); + wake_up_all(&osdc->client->auth_wq); + return; + +bad: + pr_err("osdc handle_map corrupt msg\n"); + ceph_msg_dump(msg); + up_write(&osdc->lock); +} + +/* + * Resubmit requests pending on the given osd. 
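+ *
+ * Caller must hold osdc->lock for write: see osd_fault() and
+ * ceph_osdc_reopen_osds().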
+ */ +static void kick_osd_requests(struct ceph_osd *osd) +{ + struct rb_node *n; + + clear_backoffs(osd); + + for (n = rb_first(&osd->o_requests); n; ) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + + n = rb_next(n); /* cancel_linger_request() */ + + if (!req->r_linger) { + if (!req->r_t.paused) + send_request(req); + } else { + cancel_linger_request(req); + } + } + for (n = rb_first(&osd->o_linger_requests); n; n = rb_next(n)) { + struct ceph_osd_linger_request *lreq = + rb_entry(n, struct ceph_osd_linger_request, node); + + send_linger(lreq); + } +} + +/* + * If the osd connection drops, we need to resubmit all requests. + */ +static void osd_fault(struct ceph_connection *con) +{ + struct ceph_osd *osd = con->private; + struct ceph_osd_client *osdc = osd->o_osdc; + + dout("%s osd %p osd%d\n", __func__, osd, osd->o_osd); + + down_write(&osdc->lock); + if (!osd_registered(osd)) { + dout("%s osd%d unknown\n", __func__, osd->o_osd); + goto out_unlock; + } + + if (!reopen_osd(osd)) + kick_osd_requests(osd); + maybe_request_map(osdc); + +out_unlock: + up_write(&osdc->lock); +} + +struct MOSDBackoff { + struct ceph_spg spgid; + u32 map_epoch; + u8 op; + u64 id; + struct ceph_hobject_id *begin; + struct ceph_hobject_id *end; +}; + +static int decode_MOSDBackoff(const struct ceph_msg *msg, struct MOSDBackoff *m) +{ + void *p = msg->front.iov_base; + void *const end = p + msg->front.iov_len; + u8 struct_v; + u32 struct_len; + int ret; + + ret = ceph_start_decoding(&p, end, 1, "spg_t", &struct_v, &struct_len); + if (ret) + return ret; + + ret = ceph_decode_pgid(&p, end, &m->spgid.pgid); + if (ret) + return ret; + + ceph_decode_8_safe(&p, end, m->spgid.shard, e_inval); + ceph_decode_32_safe(&p, end, m->map_epoch, e_inval); + ceph_decode_8_safe(&p, end, m->op, e_inval); + ceph_decode_64_safe(&p, end, m->id, e_inval); + + m->begin = kzalloc(sizeof(*m->begin), GFP_NOIO); + if (!m->begin) + return -ENOMEM; + + ret = decode_hoid(&p, end, m->begin); + if (ret) { + free_hoid(m->begin); + return ret; + } + + m->end = kzalloc(sizeof(*m->end), GFP_NOIO); + if (!m->end) { + free_hoid(m->begin); + return -ENOMEM; + } + + ret = decode_hoid(&p, end, m->end); + if (ret) { + free_hoid(m->begin); + free_hoid(m->end); + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static struct ceph_msg *create_backoff_message( + const struct ceph_osd_backoff *backoff, + u32 map_epoch) +{ + struct ceph_msg *msg; + void *p, *end; + int msg_size; + + msg_size = CEPH_ENCODING_START_BLK_LEN + + CEPH_PGID_ENCODING_LEN + 1; /* spgid */ + msg_size += 4 + 1 + 8; /* map_epoch, op, id */ + msg_size += CEPH_ENCODING_START_BLK_LEN + + hoid_encoding_size(backoff->begin); + msg_size += CEPH_ENCODING_START_BLK_LEN + + hoid_encoding_size(backoff->end); + + msg = ceph_msg_new(CEPH_MSG_OSD_BACKOFF, msg_size, GFP_NOIO, true); + if (!msg) + return NULL; + + p = msg->front.iov_base; + end = p + msg->front_alloc_len; + + encode_spgid(&p, &backoff->spgid); + ceph_encode_32(&p, map_epoch); + ceph_encode_8(&p, CEPH_OSD_BACKOFF_OP_ACK_BLOCK); + ceph_encode_64(&p, backoff->id); + encode_hoid(&p, end, backoff->begin); + encode_hoid(&p, end, backoff->end); + BUG_ON(p != end); + + msg->front.iov_len = p - msg->front.iov_base; + msg->hdr.version = cpu_to_le16(1); /* MOSDBackoff v1 */ + msg->hdr.front_len = cpu_to_le32(msg->front.iov_len); + + return msg; +} + +static void handle_backoff_block(struct ceph_osd *osd, struct MOSDBackoff *m) +{ + struct ceph_spg_mapping *spg; + struct ceph_osd_backoff 
*backoff; + struct ceph_msg *msg; + + dout("%s osd%d spgid %llu.%xs%d id %llu\n", __func__, osd->o_osd, + m->spgid.pgid.pool, m->spgid.pgid.seed, m->spgid.shard, m->id); + + spg = lookup_spg_mapping(&osd->o_backoff_mappings, &m->spgid); + if (!spg) { + spg = alloc_spg_mapping(); + if (!spg) { + pr_err("%s failed to allocate spg\n", __func__); + return; + } + spg->spgid = m->spgid; /* struct */ + insert_spg_mapping(&osd->o_backoff_mappings, spg); + } + + backoff = alloc_backoff(); + if (!backoff) { + pr_err("%s failed to allocate backoff\n", __func__); + return; + } + backoff->spgid = m->spgid; /* struct */ + backoff->id = m->id; + backoff->begin = m->begin; + m->begin = NULL; /* backoff now owns this */ + backoff->end = m->end; + m->end = NULL; /* ditto */ + + insert_backoff(&spg->backoffs, backoff); + insert_backoff_by_id(&osd->o_backoffs_by_id, backoff); + + /* + * Ack with original backoff's epoch so that the OSD can + * discard this if there was a PG split. + */ + msg = create_backoff_message(backoff, m->map_epoch); + if (!msg) { + pr_err("%s failed to allocate msg\n", __func__); + return; + } + ceph_con_send(&osd->o_con, msg); +} + +static bool target_contained_by(const struct ceph_osd_request_target *t, + const struct ceph_hobject_id *begin, + const struct ceph_hobject_id *end) +{ + struct ceph_hobject_id hoid; + int cmp; + + hoid_fill_from_target(&hoid, t); + cmp = hoid_compare(&hoid, begin); + return !cmp || (cmp > 0 && hoid_compare(&hoid, end) < 0); +} + +static void handle_backoff_unblock(struct ceph_osd *osd, + const struct MOSDBackoff *m) +{ + struct ceph_spg_mapping *spg; + struct ceph_osd_backoff *backoff; + struct rb_node *n; + + dout("%s osd%d spgid %llu.%xs%d id %llu\n", __func__, osd->o_osd, + m->spgid.pgid.pool, m->spgid.pgid.seed, m->spgid.shard, m->id); + + backoff = lookup_backoff_by_id(&osd->o_backoffs_by_id, m->id); + if (!backoff) { + pr_err("%s osd%d spgid %llu.%xs%d id %llu backoff dne\n", + __func__, osd->o_osd, m->spgid.pgid.pool, + m->spgid.pgid.seed, m->spgid.shard, m->id); + return; + } + + if (hoid_compare(backoff->begin, m->begin) && + hoid_compare(backoff->end, m->end)) { + pr_err("%s osd%d spgid %llu.%xs%d id %llu bad range?\n", + __func__, osd->o_osd, m->spgid.pgid.pool, + m->spgid.pgid.seed, m->spgid.shard, m->id); + /* unblock it anyway... */ + } + + spg = lookup_spg_mapping(&osd->o_backoff_mappings, &backoff->spgid); + BUG_ON(!spg); + + erase_backoff(&spg->backoffs, backoff); + erase_backoff_by_id(&osd->o_backoffs_by_id, backoff); + free_backoff(backoff); + + if (RB_EMPTY_ROOT(&spg->backoffs)) { + erase_spg_mapping(&osd->o_backoff_mappings, spg); + free_spg_mapping(spg); + } + + for (n = rb_first(&osd->o_requests); n; n = rb_next(n)) { + struct ceph_osd_request *req = + rb_entry(n, struct ceph_osd_request, r_node); + + if (!ceph_spg_compare(&req->r_t.spgid, &m->spgid)) { + /* + * Match against @m, not @backoff -- the PG may + * have split on the OSD. + */ + if (target_contained_by(&req->r_t, m->begin, m->end)) { + /* + * If no other installed backoff applies, + * resend. 
+ */ + send_request(req); + } + } + } +} + +static void handle_backoff(struct ceph_osd *osd, struct ceph_msg *msg) +{ + struct ceph_osd_client *osdc = osd->o_osdc; + struct MOSDBackoff m; + int ret; + + down_read(&osdc->lock); + if (!osd_registered(osd)) { + dout("%s osd%d unknown\n", __func__, osd->o_osd); + up_read(&osdc->lock); + return; + } + WARN_ON(osd->o_osd != le64_to_cpu(msg->hdr.src.num)); + + mutex_lock(&osd->lock); + ret = decode_MOSDBackoff(msg, &m); + if (ret) { + pr_err("failed to decode MOSDBackoff: %d\n", ret); + ceph_msg_dump(msg); + goto out_unlock; + } + + switch (m.op) { + case CEPH_OSD_BACKOFF_OP_BLOCK: + handle_backoff_block(osd, &m); + break; + case CEPH_OSD_BACKOFF_OP_UNBLOCK: + handle_backoff_unblock(osd, &m); + break; + default: + pr_err("%s osd%d unknown op %d\n", __func__, osd->o_osd, m.op); + } + + free_hoid(m.begin); + free_hoid(m.end); + +out_unlock: + mutex_unlock(&osd->lock); + up_read(&osdc->lock); +} + +/* + * Process osd watch notifications + */ +static void handle_watch_notify(struct ceph_osd_client *osdc, + struct ceph_msg *msg) +{ + void *p = msg->front.iov_base; + void *const end = p + msg->front.iov_len; + struct ceph_osd_linger_request *lreq; + struct linger_work *lwork; + u8 proto_ver, opcode; + u64 cookie, notify_id; + u64 notifier_id = 0; + s32 return_code = 0; + void *payload = NULL; + u32 payload_len = 0; + + ceph_decode_8_safe(&p, end, proto_ver, bad); + ceph_decode_8_safe(&p, end, opcode, bad); + ceph_decode_64_safe(&p, end, cookie, bad); + p += 8; /* skip ver */ + ceph_decode_64_safe(&p, end, notify_id, bad); + + if (proto_ver >= 1) { + ceph_decode_32_safe(&p, end, payload_len, bad); + ceph_decode_need(&p, end, payload_len, bad); + payload = p; + p += payload_len; + } + + if (le16_to_cpu(msg->hdr.version) >= 2) + ceph_decode_32_safe(&p, end, return_code, bad); + + if (le16_to_cpu(msg->hdr.version) >= 3) + ceph_decode_64_safe(&p, end, notifier_id, bad); + + down_read(&osdc->lock); + lreq = lookup_linger_osdc(&osdc->linger_requests, cookie); + if (!lreq) { + dout("%s opcode %d cookie %llu dne\n", __func__, opcode, + cookie); + goto out_unlock_osdc; + } + + mutex_lock(&lreq->lock); + dout("%s opcode %d cookie %llu lreq %p is_watch %d\n", __func__, + opcode, cookie, lreq, lreq->is_watch); + if (opcode == CEPH_WATCH_EVENT_DISCONNECT) { + if (!lreq->last_error) { + lreq->last_error = -ENOTCONN; + queue_watch_error(lreq); + } + } else if (!lreq->is_watch) { + /* CEPH_WATCH_EVENT_NOTIFY_COMPLETE */ + if (lreq->notify_id && lreq->notify_id != notify_id) { + dout("lreq %p notify_id %llu != %llu, ignoring\n", lreq, + lreq->notify_id, notify_id); + } else if (!completion_done(&lreq->notify_finish_wait)) { + struct ceph_msg_data *data = + msg->num_data_items ? 
&msg->data[0] : NULL; + + if (data) { + if (lreq->preply_pages) { + WARN_ON(data->type != + CEPH_MSG_DATA_PAGES); + *lreq->preply_pages = data->pages; + *lreq->preply_len = data->length; + data->own_pages = false; + } + } + lreq->notify_finish_error = return_code; + complete_all(&lreq->notify_finish_wait); + } + } else { + /* CEPH_WATCH_EVENT_NOTIFY */ + lwork = lwork_alloc(lreq, do_watch_notify); + if (!lwork) { + pr_err("failed to allocate notify-lwork\n"); + goto out_unlock_lreq; + } + + lwork->notify.notify_id = notify_id; + lwork->notify.notifier_id = notifier_id; + lwork->notify.payload = payload; + lwork->notify.payload_len = payload_len; + lwork->notify.msg = ceph_msg_get(msg); + lwork_queue(lwork); + } + +out_unlock_lreq: + mutex_unlock(&lreq->lock); +out_unlock_osdc: + up_read(&osdc->lock); + return; + +bad: + pr_err("osdc handle_watch_notify corrupt msg\n"); +} + +/* + * Register request, send initial attempt. + */ +void ceph_osdc_start_request(struct ceph_osd_client *osdc, + struct ceph_osd_request *req) +{ + down_read(&osdc->lock); + submit_request(req, false); + up_read(&osdc->lock); +} +EXPORT_SYMBOL(ceph_osdc_start_request); + +/* + * Unregister request. If @req was registered, it isn't completed: + * r_result isn't set and __complete_request() isn't invoked. + * + * If @req wasn't registered, this call may have raced with + * handle_reply(), in which case r_result would already be set and + * __complete_request() would be getting invoked, possibly even + * concurrently with this call. + */ +void ceph_osdc_cancel_request(struct ceph_osd_request *req) +{ + struct ceph_osd_client *osdc = req->r_osdc; + + down_write(&osdc->lock); + if (req->r_osd) + cancel_request(req); + up_write(&osdc->lock); +} +EXPORT_SYMBOL(ceph_osdc_cancel_request); + +/* + * @timeout: in jiffies, 0 means "wait forever" + */ +static int wait_request_timeout(struct ceph_osd_request *req, + unsigned long timeout) +{ + long left; + + dout("%s req %p tid %llu\n", __func__, req, req->r_tid); + left = wait_for_completion_killable_timeout(&req->r_completion, + ceph_timeout_jiffies(timeout)); + if (left <= 0) { + left = left ?: -ETIMEDOUT; + ceph_osdc_cancel_request(req); + } else { + left = req->r_result; /* completed */ + } + + return left; +} + +/* + * wait for a request to complete + */ +int ceph_osdc_wait_request(struct ceph_osd_client *osdc, + struct ceph_osd_request *req) +{ + return wait_request_timeout(req, 0); +} +EXPORT_SYMBOL(ceph_osdc_wait_request); + +/* + * sync - wait for all in-flight requests to flush. avoid starvation. 
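+ *
+ * Only write requests with a tid at or below the tid sampled on entry
+ * are waited on; reads and requests submitted later are not.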
+ */ +void ceph_osdc_sync(struct ceph_osd_client *osdc) +{ + struct rb_node *n, *p; + u64 last_tid = atomic64_read(&osdc->last_tid); + +again: + down_read(&osdc->lock); + for (n = rb_first(&osdc->osds); n; n = rb_next(n)) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + mutex_lock(&osd->lock); + for (p = rb_first(&osd->o_requests); p; p = rb_next(p)) { + struct ceph_osd_request *req = + rb_entry(p, struct ceph_osd_request, r_node); + + if (req->r_tid > last_tid) + break; + + if (!(req->r_flags & CEPH_OSD_FLAG_WRITE)) + continue; + + ceph_osdc_get_request(req); + mutex_unlock(&osd->lock); + up_read(&osdc->lock); + dout("%s waiting on req %p tid %llu last_tid %llu\n", + __func__, req, req->r_tid, last_tid); + wait_for_completion(&req->r_completion); + ceph_osdc_put_request(req); + goto again; + } + + mutex_unlock(&osd->lock); + } + + up_read(&osdc->lock); + dout("%s done last_tid %llu\n", __func__, last_tid); +} +EXPORT_SYMBOL(ceph_osdc_sync); + +/* + * Returns a handle, caller owns a ref. + */ +struct ceph_osd_linger_request * +ceph_osdc_watch(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + rados_watchcb2_t wcb, + rados_watcherrcb_t errcb, + void *data) +{ + struct ceph_osd_linger_request *lreq; + int ret; + + lreq = linger_alloc(osdc); + if (!lreq) + return ERR_PTR(-ENOMEM); + + lreq->is_watch = true; + lreq->wcb = wcb; + lreq->errcb = errcb; + lreq->data = data; + lreq->watch_valid_thru = jiffies; + + ceph_oid_copy(&lreq->t.base_oid, oid); + ceph_oloc_copy(&lreq->t.base_oloc, oloc); + lreq->t.flags = CEPH_OSD_FLAG_WRITE; + ktime_get_real_ts64(&lreq->mtime); + + linger_submit(lreq); + ret = linger_reg_commit_wait(lreq); + if (ret) { + linger_cancel(lreq); + goto err_put_lreq; + } + + return lreq; + +err_put_lreq: + linger_put(lreq); + return ERR_PTR(ret); +} +EXPORT_SYMBOL(ceph_osdc_watch); + +/* + * Releases a ref. + * + * Times out after mount_timeout to preserve rbd unmap behaviour + * introduced in 2894e1d76974 ("rbd: timeout watch teardown on unmap + * with mount_timeout"). 
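+ *
+ * The linger request is torn down right after the UNWATCH op is
+ * submitted; only the final wait is bounded by mount_timeout.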
+ */ +int ceph_osdc_unwatch(struct ceph_osd_client *osdc, + struct ceph_osd_linger_request *lreq) +{ + struct ceph_options *opts = osdc->client->options; + struct ceph_osd_request *req; + int ret; + + req = ceph_osdc_alloc_request(osdc, NULL, 1, false, GFP_NOIO); + if (!req) + return -ENOMEM; + + ceph_oid_copy(&req->r_base_oid, &lreq->t.base_oid); + ceph_oloc_copy(&req->r_base_oloc, &lreq->t.base_oloc); + req->r_flags = CEPH_OSD_FLAG_WRITE; + ktime_get_real_ts64(&req->r_mtime); + osd_req_op_watch_init(req, 0, CEPH_OSD_WATCH_OP_UNWATCH, + lreq->linger_id, 0); + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + + ceph_osdc_start_request(osdc, req); + linger_cancel(lreq); + linger_put(lreq); + ret = wait_request_timeout(req, opts->mount_timeout); + +out_put_req: + ceph_osdc_put_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_unwatch); + +static int osd_req_op_notify_ack_init(struct ceph_osd_request *req, int which, + u64 notify_id, u64 cookie, void *payload, + u32 payload_len) +{ + struct ceph_osd_req_op *op; + struct ceph_pagelist *pl; + int ret; + + op = osd_req_op_init(req, which, CEPH_OSD_OP_NOTIFY_ACK, 0); + + pl = ceph_pagelist_alloc(GFP_NOIO); + if (!pl) + return -ENOMEM; + + ret = ceph_pagelist_encode_64(pl, notify_id); + ret |= ceph_pagelist_encode_64(pl, cookie); + if (payload) { + ret |= ceph_pagelist_encode_32(pl, payload_len); + ret |= ceph_pagelist_append(pl, payload, payload_len); + } else { + ret |= ceph_pagelist_encode_32(pl, 0); + } + if (ret) { + ceph_pagelist_release(pl); + return -ENOMEM; + } + + ceph_osd_data_pagelist_init(&op->notify_ack.request_data, pl); + op->indata_len = pl->length; + return 0; +} + +int ceph_osdc_notify_ack(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + u64 notify_id, + u64 cookie, + void *payload, + u32 payload_len) +{ + struct ceph_osd_request *req; + int ret; + + req = ceph_osdc_alloc_request(osdc, NULL, 1, false, GFP_NOIO); + if (!req) + return -ENOMEM; + + ceph_oid_copy(&req->r_base_oid, oid); + ceph_oloc_copy(&req->r_base_oloc, oloc); + req->r_flags = CEPH_OSD_FLAG_READ; + + ret = osd_req_op_notify_ack_init(req, 0, notify_id, cookie, payload, + payload_len); + if (ret) + goto out_put_req; + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + + ceph_osdc_start_request(osdc, req); + ret = ceph_osdc_wait_request(osdc, req); + +out_put_req: + ceph_osdc_put_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_notify_ack); + +/* + * @timeout: in seconds + * + * @preply_{pages,len} are initialized both on success and error. 
+ * The caller is responsible for: + * + * ceph_release_page_vector(reply_pages, calc_pages_for(0, reply_len)) + */ +int ceph_osdc_notify(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + void *payload, + u32 payload_len, + u32 timeout, + struct page ***preply_pages, + size_t *preply_len) +{ + struct ceph_osd_linger_request *lreq; + int ret; + + WARN_ON(!timeout); + if (preply_pages) { + *preply_pages = NULL; + *preply_len = 0; + } + + lreq = linger_alloc(osdc); + if (!lreq) + return -ENOMEM; + + lreq->request_pl = ceph_pagelist_alloc(GFP_NOIO); + if (!lreq->request_pl) { + ret = -ENOMEM; + goto out_put_lreq; + } + + ret = ceph_pagelist_encode_32(lreq->request_pl, 1); /* prot_ver */ + ret |= ceph_pagelist_encode_32(lreq->request_pl, timeout); + ret |= ceph_pagelist_encode_32(lreq->request_pl, payload_len); + ret |= ceph_pagelist_append(lreq->request_pl, payload, payload_len); + if (ret) { + ret = -ENOMEM; + goto out_put_lreq; + } + + /* for notify_id */ + lreq->notify_id_pages = ceph_alloc_page_vector(1, GFP_NOIO); + if (IS_ERR(lreq->notify_id_pages)) { + ret = PTR_ERR(lreq->notify_id_pages); + lreq->notify_id_pages = NULL; + goto out_put_lreq; + } + + lreq->preply_pages = preply_pages; + lreq->preply_len = preply_len; + + ceph_oid_copy(&lreq->t.base_oid, oid); + ceph_oloc_copy(&lreq->t.base_oloc, oloc); + lreq->t.flags = CEPH_OSD_FLAG_READ; + + linger_submit(lreq); + ret = linger_reg_commit_wait(lreq); + if (!ret) + ret = linger_notify_finish_wait(lreq, + msecs_to_jiffies(2 * timeout * MSEC_PER_SEC)); + else + dout("lreq %p failed to initiate notify %d\n", lreq, ret); + + linger_cancel(lreq); +out_put_lreq: + linger_put(lreq); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_notify); + +/* + * Return the number of milliseconds since the watch was last + * confirmed, or an error. If there is an error, the watch is no + * longer valid, and should be destroyed with ceph_osdc_unwatch(). 
+ */ +int ceph_osdc_watch_check(struct ceph_osd_client *osdc, + struct ceph_osd_linger_request *lreq) +{ + unsigned long stamp, age; + int ret; + + down_read(&osdc->lock); + mutex_lock(&lreq->lock); + stamp = lreq->watch_valid_thru; + if (!list_empty(&lreq->pending_lworks)) { + struct linger_work *lwork = + list_first_entry(&lreq->pending_lworks, + struct linger_work, + pending_item); + + if (time_before(lwork->queued_stamp, stamp)) + stamp = lwork->queued_stamp; + } + age = jiffies - stamp; + dout("%s lreq %p linger_id %llu age %lu last_error %d\n", __func__, + lreq, lreq->linger_id, age, lreq->last_error); + /* we are truncating to msecs, so return a safe upper bound */ + ret = lreq->last_error ?: 1 + jiffies_to_msecs(age); + + mutex_unlock(&lreq->lock); + up_read(&osdc->lock); + return ret; +} + +static int decode_watcher(void **p, void *end, struct ceph_watch_item *item) +{ + u8 struct_v; + u32 struct_len; + int ret; + + ret = ceph_start_decoding(p, end, 2, "watch_item_t", + &struct_v, &struct_len); + if (ret) + goto bad; + + ret = -EINVAL; + ceph_decode_copy_safe(p, end, &item->name, sizeof(item->name), bad); + ceph_decode_64_safe(p, end, item->cookie, bad); + ceph_decode_skip_32(p, end, bad); /* skip timeout seconds */ + + if (struct_v >= 2) { + ret = ceph_decode_entity_addr(p, end, &item->addr); + if (ret) + goto bad; + } else { + ret = 0; + } + + dout("%s %s%llu cookie %llu addr %s\n", __func__, + ENTITY_NAME(item->name), item->cookie, + ceph_pr_addr(&item->addr)); +bad: + return ret; +} + +static int decode_watchers(void **p, void *end, + struct ceph_watch_item **watchers, + u32 *num_watchers) +{ + u8 struct_v; + u32 struct_len; + int i; + int ret; + + ret = ceph_start_decoding(p, end, 1, "obj_list_watch_response_t", + &struct_v, &struct_len); + if (ret) + return ret; + + *num_watchers = ceph_decode_32(p); + *watchers = kcalloc(*num_watchers, sizeof(**watchers), GFP_NOIO); + if (!*watchers) + return -ENOMEM; + + for (i = 0; i < *num_watchers; i++) { + ret = decode_watcher(p, end, *watchers + i); + if (ret) { + kfree(*watchers); + return ret; + } + } + + return 0; +} + +/* + * On success, the caller is responsible for: + * + * kfree(watchers); + */ +int ceph_osdc_list_watchers(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + struct ceph_watch_item **watchers, + u32 *num_watchers) +{ + struct ceph_osd_request *req; + struct page **pages; + int ret; + + req = ceph_osdc_alloc_request(osdc, NULL, 1, false, GFP_NOIO); + if (!req) + return -ENOMEM; + + ceph_oid_copy(&req->r_base_oid, oid); + ceph_oloc_copy(&req->r_base_oloc, oloc); + req->r_flags = CEPH_OSD_FLAG_READ; + + pages = ceph_alloc_page_vector(1, GFP_NOIO); + if (IS_ERR(pages)) { + ret = PTR_ERR(pages); + goto out_put_req; + } + + osd_req_op_init(req, 0, CEPH_OSD_OP_LIST_WATCHERS, 0); + ceph_osd_data_pages_init(osd_req_op_data(req, 0, list_watchers, + response_data), + pages, PAGE_SIZE, 0, false, true); + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + + ceph_osdc_start_request(osdc, req); + ret = ceph_osdc_wait_request(osdc, req); + if (ret >= 0) { + void *p = page_address(pages[0]); + void *const end = p + req->r_ops[0].outdata_len; + + ret = decode_watchers(&p, end, watchers, num_watchers); + } + +out_put_req: + ceph_osdc_put_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_list_watchers); + +/* + * Call all pending notify callbacks - for use after a watch is + * unregistered, to make sure no more callbacks for it will be invoked + */ 
+void ceph_osdc_flush_notifies(struct ceph_osd_client *osdc) +{ + dout("%s osdc %p\n", __func__, osdc); + flush_workqueue(osdc->notify_wq); +} +EXPORT_SYMBOL(ceph_osdc_flush_notifies); + +void ceph_osdc_maybe_request_map(struct ceph_osd_client *osdc) +{ + down_read(&osdc->lock); + maybe_request_map(osdc); + up_read(&osdc->lock); +} +EXPORT_SYMBOL(ceph_osdc_maybe_request_map); + +/* + * Execute an OSD class method on an object. + * + * @flags: CEPH_OSD_FLAG_* + * @resp_len: in/out param for reply length + */ +int ceph_osdc_call(struct ceph_osd_client *osdc, + struct ceph_object_id *oid, + struct ceph_object_locator *oloc, + const char *class, const char *method, + unsigned int flags, + struct page *req_page, size_t req_len, + struct page **resp_pages, size_t *resp_len) +{ + struct ceph_osd_request *req; + int ret; + + if (req_len > PAGE_SIZE) + return -E2BIG; + + req = ceph_osdc_alloc_request(osdc, NULL, 1, false, GFP_NOIO); + if (!req) + return -ENOMEM; + + ceph_oid_copy(&req->r_base_oid, oid); + ceph_oloc_copy(&req->r_base_oloc, oloc); + req->r_flags = flags; + + ret = osd_req_op_cls_init(req, 0, class, method); + if (ret) + goto out_put_req; + + if (req_page) + osd_req_op_cls_request_data_pages(req, 0, &req_page, req_len, + 0, false, false); + if (resp_pages) + osd_req_op_cls_response_data_pages(req, 0, resp_pages, + *resp_len, 0, false, false); + + ret = ceph_osdc_alloc_messages(req, GFP_NOIO); + if (ret) + goto out_put_req; + + ceph_osdc_start_request(osdc, req); + ret = ceph_osdc_wait_request(osdc, req); + if (ret >= 0) { + ret = req->r_ops[0].rval; + if (resp_pages) + *resp_len = req->r_ops[0].outdata_len; + } + +out_put_req: + ceph_osdc_put_request(req); + return ret; +} +EXPORT_SYMBOL(ceph_osdc_call); + +/* + * reset all osd connections + */ +void ceph_osdc_reopen_osds(struct ceph_osd_client *osdc) +{ + struct rb_node *n; + + down_write(&osdc->lock); + for (n = rb_first(&osdc->osds); n; ) { + struct ceph_osd *osd = rb_entry(n, struct ceph_osd, o_node); + + n = rb_next(n); + if (!reopen_osd(osd)) + kick_osd_requests(osd); + } + up_write(&osdc->lock); +} + +/* + * init, shutdown + */ +int ceph_osdc_init(struct ceph_osd_client *osdc, struct ceph_client *client) +{ + int err; + + dout("init\n"); + osdc->client = client; + init_rwsem(&osdc->lock); + osdc->osds = RB_ROOT; + INIT_LIST_HEAD(&osdc->osd_lru); + spin_lock_init(&osdc->osd_lru_lock); + osd_init(&osdc->homeless_osd); + osdc->homeless_osd.o_osdc = osdc; + osdc->homeless_osd.o_osd = CEPH_HOMELESS_OSD; + osdc->last_linger_id = CEPH_LINGER_ID_START; + osdc->linger_requests = RB_ROOT; + osdc->map_checks = RB_ROOT; + osdc->linger_map_checks = RB_ROOT; + INIT_DELAYED_WORK(&osdc->timeout_work, handle_timeout); + INIT_DELAYED_WORK(&osdc->osds_timeout_work, handle_osds_timeout); + + err = -ENOMEM; + osdc->osdmap = ceph_osdmap_alloc(); + if (!osdc->osdmap) + goto out; + + osdc->req_mempool = mempool_create_slab_pool(10, + ceph_osd_request_cache); + if (!osdc->req_mempool) + goto out_map; + + err = ceph_msgpool_init(&osdc->msgpool_op, CEPH_MSG_OSD_OP, + PAGE_SIZE, CEPH_OSD_SLAB_OPS, 10, "osd_op"); + if (err < 0) + goto out_mempool; + err = ceph_msgpool_init(&osdc->msgpool_op_reply, CEPH_MSG_OSD_OPREPLY, + PAGE_SIZE, CEPH_OSD_SLAB_OPS, 10, + "osd_op_reply"); + if (err < 0) + goto out_msgpool; + + err = -ENOMEM; + osdc->notify_wq = create_singlethread_workqueue("ceph-watch-notify"); + if (!osdc->notify_wq) + goto out_msgpool_reply; + + osdc->completion_wq = create_singlethread_workqueue("ceph-completion"); + if (!osdc->completion_wq) + goto 
out_notify_wq; + + schedule_delayed_work(&osdc->timeout_work, + osdc->client->options->osd_keepalive_timeout); + schedule_delayed_work(&osdc->osds_timeout_work, + round_jiffies_relative(osdc->client->options->osd_idle_ttl)); + + return 0; + +out_notify_wq: + destroy_workqueue(osdc->notify_wq); +out_msgpool_reply: + ceph_msgpool_destroy(&osdc->msgpool_op_reply); +out_msgpool: + ceph_msgpool_destroy(&osdc->msgpool_op); +out_mempool: + mempool_destroy(osdc->req_mempool); +out_map: + ceph_osdmap_destroy(osdc->osdmap); +out: + return err; +} + +void ceph_osdc_stop(struct ceph_osd_client *osdc) +{ + destroy_workqueue(osdc->completion_wq); + destroy_workqueue(osdc->notify_wq); + cancel_delayed_work_sync(&osdc->timeout_work); + cancel_delayed_work_sync(&osdc->osds_timeout_work); + + down_write(&osdc->lock); + while (!RB_EMPTY_ROOT(&osdc->osds)) { + struct ceph_osd *osd = rb_entry(rb_first(&osdc->osds), + struct ceph_osd, o_node); + close_osd(osd); + } + up_write(&osdc->lock); + WARN_ON(refcount_read(&osdc->homeless_osd.o_ref) != 1); + osd_cleanup(&osdc->homeless_osd); + + WARN_ON(!list_empty(&osdc->osd_lru)); + WARN_ON(!RB_EMPTY_ROOT(&osdc->linger_requests)); + WARN_ON(!RB_EMPTY_ROOT(&osdc->map_checks)); + WARN_ON(!RB_EMPTY_ROOT(&osdc->linger_map_checks)); + WARN_ON(atomic_read(&osdc->num_requests)); + WARN_ON(atomic_read(&osdc->num_homeless)); + + ceph_osdmap_destroy(osdc->osdmap); + mempool_destroy(osdc->req_mempool); + ceph_msgpool_destroy(&osdc->msgpool_op); + ceph_msgpool_destroy(&osdc->msgpool_op_reply); +} + +int osd_req_op_copy_from_init(struct ceph_osd_request *req, + u64 src_snapid, u64 src_version, + struct ceph_object_id *src_oid, + struct ceph_object_locator *src_oloc, + u32 src_fadvise_flags, + u32 dst_fadvise_flags, + u32 truncate_seq, u64 truncate_size, + u8 copy_from_flags) +{ + struct ceph_osd_req_op *op; + struct page **pages; + void *p, *end; + + pages = ceph_alloc_page_vector(1, GFP_KERNEL); + if (IS_ERR(pages)) + return PTR_ERR(pages); + + op = osd_req_op_init(req, 0, CEPH_OSD_OP_COPY_FROM2, + dst_fadvise_flags); + op->copy_from.snapid = src_snapid; + op->copy_from.src_version = src_version; + op->copy_from.flags = copy_from_flags; + op->copy_from.src_fadvise_flags = src_fadvise_flags; + + p = page_address(pages[0]); + end = p + PAGE_SIZE; + ceph_encode_string(&p, end, src_oid->name, src_oid->name_len); + encode_oloc(&p, end, src_oloc); + ceph_encode_32(&p, truncate_seq); + ceph_encode_64(&p, truncate_size); + op->indata_len = PAGE_SIZE - (end - p); + + ceph_osd_data_pages_init(&op->copy_from.osd_data, pages, + op->indata_len, 0, false, true); + return 0; +} +EXPORT_SYMBOL(osd_req_op_copy_from_init); + +int __init ceph_osdc_setup(void) +{ + size_t size = sizeof(struct ceph_osd_request) + + CEPH_OSD_SLAB_OPS * sizeof(struct ceph_osd_req_op); + + BUG_ON(ceph_osd_request_cache); + ceph_osd_request_cache = kmem_cache_create("ceph_osd_request", size, + 0, 0, NULL); + + return ceph_osd_request_cache ? 
0 : -ENOMEM; +} + +void ceph_osdc_cleanup(void) +{ + BUG_ON(!ceph_osd_request_cache); + kmem_cache_destroy(ceph_osd_request_cache); + ceph_osd_request_cache = NULL; +} + +/* + * handle incoming message + */ +static void osd_dispatch(struct ceph_connection *con, struct ceph_msg *msg) +{ + struct ceph_osd *osd = con->private; + struct ceph_osd_client *osdc = osd->o_osdc; + int type = le16_to_cpu(msg->hdr.type); + + switch (type) { + case CEPH_MSG_OSD_MAP: + ceph_osdc_handle_map(osdc, msg); + break; + case CEPH_MSG_OSD_OPREPLY: + handle_reply(osd, msg); + break; + case CEPH_MSG_OSD_BACKOFF: + handle_backoff(osd, msg); + break; + case CEPH_MSG_WATCH_NOTIFY: + handle_watch_notify(osdc, msg); + break; + + default: + pr_err("received unknown message type %d %s\n", type, + ceph_msg_type_name(type)); + } + + ceph_msg_put(msg); +} + +/* + * Lookup and return message for incoming reply. Don't try to do + * anything about a larger than preallocated data portion of the + * message at the moment - for now, just skip the message. + */ +static struct ceph_msg *get_reply(struct ceph_connection *con, + struct ceph_msg_header *hdr, + int *skip) +{ + struct ceph_osd *osd = con->private; + struct ceph_osd_client *osdc = osd->o_osdc; + struct ceph_msg *m = NULL; + struct ceph_osd_request *req; + int front_len = le32_to_cpu(hdr->front_len); + int data_len = le32_to_cpu(hdr->data_len); + u64 tid = le64_to_cpu(hdr->tid); + + down_read(&osdc->lock); + if (!osd_registered(osd)) { + dout("%s osd%d unknown, skipping\n", __func__, osd->o_osd); + *skip = 1; + goto out_unlock_osdc; + } + WARN_ON(osd->o_osd != le64_to_cpu(hdr->src.num)); + + mutex_lock(&osd->lock); + req = lookup_request(&osd->o_requests, tid); + if (!req) { + dout("%s osd%d tid %llu unknown, skipping\n", __func__, + osd->o_osd, tid); + *skip = 1; + goto out_unlock_session; + } + + ceph_msg_revoke_incoming(req->r_reply); + + if (front_len > req->r_reply->front_alloc_len) { + pr_warn("%s osd%d tid %llu front %d > preallocated %d\n", + __func__, osd->o_osd, req->r_tid, front_len, + req->r_reply->front_alloc_len); + m = ceph_msg_new(CEPH_MSG_OSD_OPREPLY, front_len, GFP_NOFS, + false); + if (!m) + goto out_unlock_session; + ceph_msg_put(req->r_reply); + req->r_reply = m; + } + + if (data_len > req->r_reply->data_length) { + pr_warn("%s osd%d tid %llu data %d > preallocated %zu, skipping\n", + __func__, osd->o_osd, req->r_tid, data_len, + req->r_reply->data_length); + m = NULL; + *skip = 1; + goto out_unlock_session; + } + + m = ceph_msg_get(req->r_reply); + dout("get_reply tid %lld %p\n", tid, m); + +out_unlock_session: + mutex_unlock(&osd->lock); +out_unlock_osdc: + up_read(&osdc->lock); + return m; +} + +static struct ceph_msg *alloc_msg_with_page_vector(struct ceph_msg_header *hdr) +{ + struct ceph_msg *m; + int type = le16_to_cpu(hdr->type); + u32 front_len = le32_to_cpu(hdr->front_len); + u32 data_len = le32_to_cpu(hdr->data_len); + + m = ceph_msg_new2(type, front_len, 1, GFP_NOIO, false); + if (!m) + return NULL; + + if (data_len) { + struct page **pages; + + pages = ceph_alloc_page_vector(calc_pages_for(0, data_len), + GFP_NOIO); + if (IS_ERR(pages)) { + ceph_msg_put(m); + return NULL; + } + + ceph_msg_data_add_pages(m, pages, data_len, 0, true); + } + + return m; +} + +static struct ceph_msg *osd_alloc_msg(struct ceph_connection *con, + struct ceph_msg_header *hdr, + int *skip) +{ + struct ceph_osd *osd = con->private; + int type = le16_to_cpu(hdr->type); + + *skip = 0; + switch (type) { + case CEPH_MSG_OSD_MAP: + case CEPH_MSG_OSD_BACKOFF: + case 
CEPH_MSG_WATCH_NOTIFY: + return alloc_msg_with_page_vector(hdr); + case CEPH_MSG_OSD_OPREPLY: + return get_reply(con, hdr, skip); + default: + pr_warn("%s osd%d unknown msg type %d, skipping\n", __func__, + osd->o_osd, type); + *skip = 1; + return NULL; + } +} + +/* + * Wrappers to refcount containing ceph_osd struct + */ +static struct ceph_connection *osd_get_con(struct ceph_connection *con) +{ + struct ceph_osd *osd = con->private; + if (get_osd(osd)) + return con; + return NULL; +} + +static void osd_put_con(struct ceph_connection *con) +{ + struct ceph_osd *osd = con->private; + put_osd(osd); +} + +/* + * authentication + */ + +/* + * Note: returned pointer is the address of a structure that's + * managed separately. Caller must *not* attempt to free it. + */ +static struct ceph_auth_handshake * +osd_get_authorizer(struct ceph_connection *con, int *proto, int force_new) +{ + struct ceph_osd *o = con->private; + struct ceph_osd_client *osdc = o->o_osdc; + struct ceph_auth_client *ac = osdc->client->monc.auth; + struct ceph_auth_handshake *auth = &o->o_auth; + int ret; + + ret = __ceph_auth_get_authorizer(ac, auth, CEPH_ENTITY_TYPE_OSD, + force_new, proto, NULL, NULL); + if (ret) + return ERR_PTR(ret); + + return auth; +} + +static int osd_add_authorizer_challenge(struct ceph_connection *con, + void *challenge_buf, int challenge_buf_len) +{ + struct ceph_osd *o = con->private; + struct ceph_osd_client *osdc = o->o_osdc; + struct ceph_auth_client *ac = osdc->client->monc.auth; + + return ceph_auth_add_authorizer_challenge(ac, o->o_auth.authorizer, + challenge_buf, challenge_buf_len); +} + +static int osd_verify_authorizer_reply(struct ceph_connection *con) +{ + struct ceph_osd *o = con->private; + struct ceph_osd_client *osdc = o->o_osdc; + struct ceph_auth_client *ac = osdc->client->monc.auth; + struct ceph_auth_handshake *auth = &o->o_auth; + + return ceph_auth_verify_authorizer_reply(ac, auth->authorizer, + auth->authorizer_reply_buf, auth->authorizer_reply_buf_len, + NULL, NULL, NULL, NULL); +} + +static int osd_invalidate_authorizer(struct ceph_connection *con) +{ + struct ceph_osd *o = con->private; + struct ceph_osd_client *osdc = o->o_osdc; + struct ceph_auth_client *ac = osdc->client->monc.auth; + + ceph_auth_invalidate_authorizer(ac, CEPH_ENTITY_TYPE_OSD); + return ceph_monc_validate_auth(&osdc->client->monc); +} + +static int osd_get_auth_request(struct ceph_connection *con, + void *buf, int *buf_len, + void **authorizer, int *authorizer_len) +{ + struct ceph_osd *o = con->private; + struct ceph_auth_client *ac = o->o_osdc->client->monc.auth; + struct ceph_auth_handshake *auth = &o->o_auth; + int ret; + + ret = ceph_auth_get_authorizer(ac, auth, CEPH_ENTITY_TYPE_OSD, + buf, buf_len); + if (ret) + return ret; + + *authorizer = auth->authorizer_buf; + *authorizer_len = auth->authorizer_buf_len; + return 0; +} + +static int osd_handle_auth_reply_more(struct ceph_connection *con, + void *reply, int reply_len, + void *buf, int *buf_len, + void **authorizer, int *authorizer_len) +{ + struct ceph_osd *o = con->private; + struct ceph_auth_client *ac = o->o_osdc->client->monc.auth; + struct ceph_auth_handshake *auth = &o->o_auth; + int ret; + + ret = ceph_auth_handle_svc_reply_more(ac, auth, reply, reply_len, + buf, buf_len); + if (ret) + return ret; + + *authorizer = auth->authorizer_buf; + *authorizer_len = auth->authorizer_buf_len; + return 0; +} + +static int osd_handle_auth_done(struct ceph_connection *con, + u64 global_id, void *reply, int reply_len, + u8 *session_key, int 
*session_key_len, + u8 *con_secret, int *con_secret_len) +{ + struct ceph_osd *o = con->private; + struct ceph_auth_client *ac = o->o_osdc->client->monc.auth; + struct ceph_auth_handshake *auth = &o->o_auth; + + return ceph_auth_handle_svc_reply_done(ac, auth, reply, reply_len, + session_key, session_key_len, + con_secret, con_secret_len); +} + +static int osd_handle_auth_bad_method(struct ceph_connection *con, + int used_proto, int result, + const int *allowed_protos, int proto_cnt, + const int *allowed_modes, int mode_cnt) +{ + struct ceph_osd *o = con->private; + struct ceph_mon_client *monc = &o->o_osdc->client->monc; + int ret; + + if (ceph_auth_handle_bad_authorizer(monc->auth, CEPH_ENTITY_TYPE_OSD, + used_proto, result, + allowed_protos, proto_cnt, + allowed_modes, mode_cnt)) { + ret = ceph_monc_validate_auth(monc); + if (ret) + return ret; + } + + return -EACCES; +} + +static void osd_reencode_message(struct ceph_msg *msg) +{ + int type = le16_to_cpu(msg->hdr.type); + + if (type == CEPH_MSG_OSD_OP) + encode_request_finish(msg); +} + +static int osd_sign_message(struct ceph_msg *msg) +{ + struct ceph_osd *o = msg->con->private; + struct ceph_auth_handshake *auth = &o->o_auth; + + return ceph_auth_sign_message(auth, msg); +} + +static int osd_check_message_signature(struct ceph_msg *msg) +{ + struct ceph_osd *o = msg->con->private; + struct ceph_auth_handshake *auth = &o->o_auth; + + return ceph_auth_check_message_signature(auth, msg); +} + +static const struct ceph_connection_operations osd_con_ops = { + .get = osd_get_con, + .put = osd_put_con, + .alloc_msg = osd_alloc_msg, + .dispatch = osd_dispatch, + .fault = osd_fault, + .reencode_message = osd_reencode_message, + .get_authorizer = osd_get_authorizer, + .add_authorizer_challenge = osd_add_authorizer_challenge, + .verify_authorizer_reply = osd_verify_authorizer_reply, + .invalidate_authorizer = osd_invalidate_authorizer, + .sign_message = osd_sign_message, + .check_message_signature = osd_check_message_signature, + .get_auth_request = osd_get_auth_request, + .handle_auth_reply_more = osd_handle_auth_reply_more, + .handle_auth_done = osd_handle_auth_done, + .handle_auth_bad_method = osd_handle_auth_bad_method, +}; diff --git a/net/ceph/osdmap.c b/net/ceph/osdmap.c new file mode 100644 index 000000000..295098873 --- /dev/null +++ b/net/ceph/osdmap.c @@ -0,0 +1,3106 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include +#include + +#include +#include +#include +#include +#include + +static __printf(2, 3) +void osdmap_info(const struct ceph_osdmap *map, const char *fmt, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, fmt); + vaf.fmt = fmt; + vaf.va = &args; + + printk(KERN_INFO "%s (%pU e%u): %pV", KBUILD_MODNAME, &map->fsid, + map->epoch, &vaf); + + va_end(args); +} + +char *ceph_osdmap_state_str(char *str, int len, u32 state) +{ + if (!len) + return str; + + if ((state & CEPH_OSD_EXISTS) && (state & CEPH_OSD_UP)) + snprintf(str, len, "exists, up"); + else if (state & CEPH_OSD_EXISTS) + snprintf(str, len, "exists"); + else if (state & CEPH_OSD_UP) + snprintf(str, len, "up"); + else + snprintf(str, len, "doesn't exist"); + + return str; +} + +/* maps */ + +static int calc_bits_of(unsigned int t) +{ + int b = 0; + while (t) { + t = t >> 1; + b++; + } + return b; +} + +/* + * the foo_mask is the smallest value 2^n-1 that is >= foo. 
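+ * For example, pg_num = 12 gives pg_num_mask = 15.  The masks are used
+ * to fold a raw object hash onto the pool's pg_num placement groups.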
+ */ +static void calc_pg_masks(struct ceph_pg_pool_info *pi) +{ + pi->pg_num_mask = (1 << calc_bits_of(pi->pg_num-1)) - 1; + pi->pgp_num_mask = (1 << calc_bits_of(pi->pgp_num-1)) - 1; +} + +/* + * decode crush map + */ +static int crush_decode_uniform_bucket(void **p, void *end, + struct crush_bucket_uniform *b) +{ + dout("crush_decode_uniform_bucket %p to %p\n", *p, end); + ceph_decode_need(p, end, (1+b->h.size) * sizeof(u32), bad); + b->item_weight = ceph_decode_32(p); + return 0; +bad: + return -EINVAL; +} + +static int crush_decode_list_bucket(void **p, void *end, + struct crush_bucket_list *b) +{ + int j; + dout("crush_decode_list_bucket %p to %p\n", *p, end); + b->item_weights = kcalloc(b->h.size, sizeof(u32), GFP_NOFS); + if (b->item_weights == NULL) + return -ENOMEM; + b->sum_weights = kcalloc(b->h.size, sizeof(u32), GFP_NOFS); + if (b->sum_weights == NULL) + return -ENOMEM; + ceph_decode_need(p, end, 2 * b->h.size * sizeof(u32), bad); + for (j = 0; j < b->h.size; j++) { + b->item_weights[j] = ceph_decode_32(p); + b->sum_weights[j] = ceph_decode_32(p); + } + return 0; +bad: + return -EINVAL; +} + +static int crush_decode_tree_bucket(void **p, void *end, + struct crush_bucket_tree *b) +{ + int j; + dout("crush_decode_tree_bucket %p to %p\n", *p, end); + ceph_decode_8_safe(p, end, b->num_nodes, bad); + b->node_weights = kcalloc(b->num_nodes, sizeof(u32), GFP_NOFS); + if (b->node_weights == NULL) + return -ENOMEM; + ceph_decode_need(p, end, b->num_nodes * sizeof(u32), bad); + for (j = 0; j < b->num_nodes; j++) + b->node_weights[j] = ceph_decode_32(p); + return 0; +bad: + return -EINVAL; +} + +static int crush_decode_straw_bucket(void **p, void *end, + struct crush_bucket_straw *b) +{ + int j; + dout("crush_decode_straw_bucket %p to %p\n", *p, end); + b->item_weights = kcalloc(b->h.size, sizeof(u32), GFP_NOFS); + if (b->item_weights == NULL) + return -ENOMEM; + b->straws = kcalloc(b->h.size, sizeof(u32), GFP_NOFS); + if (b->straws == NULL) + return -ENOMEM; + ceph_decode_need(p, end, 2 * b->h.size * sizeof(u32), bad); + for (j = 0; j < b->h.size; j++) { + b->item_weights[j] = ceph_decode_32(p); + b->straws[j] = ceph_decode_32(p); + } + return 0; +bad: + return -EINVAL; +} + +static int crush_decode_straw2_bucket(void **p, void *end, + struct crush_bucket_straw2 *b) +{ + int j; + dout("crush_decode_straw2_bucket %p to %p\n", *p, end); + b->item_weights = kcalloc(b->h.size, sizeof(u32), GFP_NOFS); + if (b->item_weights == NULL) + return -ENOMEM; + ceph_decode_need(p, end, b->h.size * sizeof(u32), bad); + for (j = 0; j < b->h.size; j++) + b->item_weights[j] = ceph_decode_32(p); + return 0; +bad: + return -EINVAL; +} + +struct crush_name_node { + struct rb_node cn_node; + int cn_id; + char cn_name[]; +}; + +static struct crush_name_node *alloc_crush_name(size_t name_len) +{ + struct crush_name_node *cn; + + cn = kmalloc(sizeof(*cn) + name_len + 1, GFP_NOIO); + if (!cn) + return NULL; + + RB_CLEAR_NODE(&cn->cn_node); + return cn; +} + +static void free_crush_name(struct crush_name_node *cn) +{ + WARN_ON(!RB_EMPTY_NODE(&cn->cn_node)); + + kfree(cn); +} + +DEFINE_RB_FUNCS(crush_name, struct crush_name_node, cn_id, cn_node) + +static int decode_crush_names(void **p, void *end, struct rb_root *root) +{ + u32 n; + + ceph_decode_32_safe(p, end, n, e_inval); + while (n--) { + struct crush_name_node *cn; + int id; + u32 name_len; + + ceph_decode_32_safe(p, end, id, e_inval); + ceph_decode_32_safe(p, end, name_len, e_inval); + ceph_decode_need(p, end, name_len, e_inval); + + cn = 
alloc_crush_name(name_len); + if (!cn) + return -ENOMEM; + + cn->cn_id = id; + memcpy(cn->cn_name, *p, name_len); + cn->cn_name[name_len] = '\0'; + *p += name_len; + + if (!__insert_crush_name(root, cn)) { + free_crush_name(cn); + return -EEXIST; + } + } + + return 0; + +e_inval: + return -EINVAL; +} + +void clear_crush_names(struct rb_root *root) +{ + while (!RB_EMPTY_ROOT(root)) { + struct crush_name_node *cn = + rb_entry(rb_first(root), struct crush_name_node, cn_node); + + erase_crush_name(root, cn); + free_crush_name(cn); + } +} + +static struct crush_choose_arg_map *alloc_choose_arg_map(void) +{ + struct crush_choose_arg_map *arg_map; + + arg_map = kzalloc(sizeof(*arg_map), GFP_NOIO); + if (!arg_map) + return NULL; + + RB_CLEAR_NODE(&arg_map->node); + return arg_map; +} + +static void free_choose_arg_map(struct crush_choose_arg_map *arg_map) +{ + if (arg_map) { + int i, j; + + WARN_ON(!RB_EMPTY_NODE(&arg_map->node)); + + for (i = 0; i < arg_map->size; i++) { + struct crush_choose_arg *arg = &arg_map->args[i]; + + for (j = 0; j < arg->weight_set_size; j++) + kfree(arg->weight_set[j].weights); + kfree(arg->weight_set); + kfree(arg->ids); + } + kfree(arg_map->args); + kfree(arg_map); + } +} + +DEFINE_RB_FUNCS(choose_arg_map, struct crush_choose_arg_map, choose_args_index, + node); + +void clear_choose_args(struct crush_map *c) +{ + while (!RB_EMPTY_ROOT(&c->choose_args)) { + struct crush_choose_arg_map *arg_map = + rb_entry(rb_first(&c->choose_args), + struct crush_choose_arg_map, node); + + erase_choose_arg_map(&c->choose_args, arg_map); + free_choose_arg_map(arg_map); + } +} + +static u32 *decode_array_32_alloc(void **p, void *end, u32 *plen) +{ + u32 *a = NULL; + u32 len; + int ret; + + ceph_decode_32_safe(p, end, len, e_inval); + if (len) { + u32 i; + + a = kmalloc_array(len, sizeof(u32), GFP_NOIO); + if (!a) { + ret = -ENOMEM; + goto fail; + } + + ceph_decode_need(p, end, len * sizeof(u32), e_inval); + for (i = 0; i < len; i++) + a[i] = ceph_decode_32(p); + } + + *plen = len; + return a; + +e_inval: + ret = -EINVAL; +fail: + kfree(a); + return ERR_PTR(ret); +} + +/* + * Assumes @arg is zero-initialized. 
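+ * On failure, any partially decoded weight sets are left in place and
+ * are released later by free_choose_arg_map().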
+ */ +static int decode_choose_arg(void **p, void *end, struct crush_choose_arg *arg) +{ + int ret; + + ceph_decode_32_safe(p, end, arg->weight_set_size, e_inval); + if (arg->weight_set_size) { + u32 i; + + arg->weight_set = kmalloc_array(arg->weight_set_size, + sizeof(*arg->weight_set), + GFP_NOIO); + if (!arg->weight_set) + return -ENOMEM; + + for (i = 0; i < arg->weight_set_size; i++) { + struct crush_weight_set *w = &arg->weight_set[i]; + + w->weights = decode_array_32_alloc(p, end, &w->size); + if (IS_ERR(w->weights)) { + ret = PTR_ERR(w->weights); + w->weights = NULL; + return ret; + } + } + } + + arg->ids = decode_array_32_alloc(p, end, &arg->ids_size); + if (IS_ERR(arg->ids)) { + ret = PTR_ERR(arg->ids); + arg->ids = NULL; + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static int decode_choose_args(void **p, void *end, struct crush_map *c) +{ + struct crush_choose_arg_map *arg_map = NULL; + u32 num_choose_arg_maps, num_buckets; + int ret; + + ceph_decode_32_safe(p, end, num_choose_arg_maps, e_inval); + while (num_choose_arg_maps--) { + arg_map = alloc_choose_arg_map(); + if (!arg_map) { + ret = -ENOMEM; + goto fail; + } + + ceph_decode_64_safe(p, end, arg_map->choose_args_index, + e_inval); + arg_map->size = c->max_buckets; + arg_map->args = kcalloc(arg_map->size, sizeof(*arg_map->args), + GFP_NOIO); + if (!arg_map->args) { + ret = -ENOMEM; + goto fail; + } + + ceph_decode_32_safe(p, end, num_buckets, e_inval); + while (num_buckets--) { + struct crush_choose_arg *arg; + u32 bucket_index; + + ceph_decode_32_safe(p, end, bucket_index, e_inval); + if (bucket_index >= arg_map->size) + goto e_inval; + + arg = &arg_map->args[bucket_index]; + ret = decode_choose_arg(p, end, arg); + if (ret) + goto fail; + + if (arg->ids_size && + arg->ids_size != c->buckets[bucket_index]->size) + goto e_inval; + } + + insert_choose_arg_map(&c->choose_args, arg_map); + } + + return 0; + +e_inval: + ret = -EINVAL; +fail: + free_choose_arg_map(arg_map); + return ret; +} + +static void crush_finalize(struct crush_map *c) +{ + __s32 b; + + /* Space for the array of pointers to per-bucket workspace */ + c->working_size = sizeof(struct crush_work) + + c->max_buckets * sizeof(struct crush_work_bucket *); + + for (b = 0; b < c->max_buckets; b++) { + if (!c->buckets[b]) + continue; + + switch (c->buckets[b]->alg) { + default: + /* + * The base case, permutation variables and + * the pointer to the permutation array. + */ + c->working_size += sizeof(struct crush_work_bucket); + break; + } + /* Every bucket has a permutation array. 
*/ + c->working_size += c->buckets[b]->size * sizeof(__u32); + } +} + +static struct crush_map *crush_decode(void *pbyval, void *end) +{ + struct crush_map *c; + int err; + int i, j; + void **p = &pbyval; + void *start = pbyval; + u32 magic; + + dout("crush_decode %p to %p len %d\n", *p, end, (int)(end - *p)); + + c = kzalloc(sizeof(*c), GFP_NOFS); + if (c == NULL) + return ERR_PTR(-ENOMEM); + + c->type_names = RB_ROOT; + c->names = RB_ROOT; + c->choose_args = RB_ROOT; + + /* set tunables to default values */ + c->choose_local_tries = 2; + c->choose_local_fallback_tries = 5; + c->choose_total_tries = 19; + c->chooseleaf_descend_once = 0; + + ceph_decode_need(p, end, 4*sizeof(u32), bad); + magic = ceph_decode_32(p); + if (magic != CRUSH_MAGIC) { + pr_err("crush_decode magic %x != current %x\n", + (unsigned int)magic, (unsigned int)CRUSH_MAGIC); + goto bad; + } + c->max_buckets = ceph_decode_32(p); + c->max_rules = ceph_decode_32(p); + c->max_devices = ceph_decode_32(p); + + c->buckets = kcalloc(c->max_buckets, sizeof(*c->buckets), GFP_NOFS); + if (c->buckets == NULL) + goto badmem; + c->rules = kcalloc(c->max_rules, sizeof(*c->rules), GFP_NOFS); + if (c->rules == NULL) + goto badmem; + + /* buckets */ + for (i = 0; i < c->max_buckets; i++) { + int size = 0; + u32 alg; + struct crush_bucket *b; + + ceph_decode_32_safe(p, end, alg, bad); + if (alg == 0) { + c->buckets[i] = NULL; + continue; + } + dout("crush_decode bucket %d off %x %p to %p\n", + i, (int)(*p-start), *p, end); + + switch (alg) { + case CRUSH_BUCKET_UNIFORM: + size = sizeof(struct crush_bucket_uniform); + break; + case CRUSH_BUCKET_LIST: + size = sizeof(struct crush_bucket_list); + break; + case CRUSH_BUCKET_TREE: + size = sizeof(struct crush_bucket_tree); + break; + case CRUSH_BUCKET_STRAW: + size = sizeof(struct crush_bucket_straw); + break; + case CRUSH_BUCKET_STRAW2: + size = sizeof(struct crush_bucket_straw2); + break; + default: + goto bad; + } + BUG_ON(size == 0); + b = c->buckets[i] = kzalloc(size, GFP_NOFS); + if (b == NULL) + goto badmem; + + ceph_decode_need(p, end, 4*sizeof(u32), bad); + b->id = ceph_decode_32(p); + b->type = ceph_decode_16(p); + b->alg = ceph_decode_8(p); + b->hash = ceph_decode_8(p); + b->weight = ceph_decode_32(p); + b->size = ceph_decode_32(p); + + dout("crush_decode bucket size %d off %x %p to %p\n", + b->size, (int)(*p-start), *p, end); + + b->items = kcalloc(b->size, sizeof(__s32), GFP_NOFS); + if (b->items == NULL) + goto badmem; + + ceph_decode_need(p, end, b->size*sizeof(u32), bad); + for (j = 0; j < b->size; j++) + b->items[j] = ceph_decode_32(p); + + switch (b->alg) { + case CRUSH_BUCKET_UNIFORM: + err = crush_decode_uniform_bucket(p, end, + (struct crush_bucket_uniform *)b); + if (err < 0) + goto fail; + break; + case CRUSH_BUCKET_LIST: + err = crush_decode_list_bucket(p, end, + (struct crush_bucket_list *)b); + if (err < 0) + goto fail; + break; + case CRUSH_BUCKET_TREE: + err = crush_decode_tree_bucket(p, end, + (struct crush_bucket_tree *)b); + if (err < 0) + goto fail; + break; + case CRUSH_BUCKET_STRAW: + err = crush_decode_straw_bucket(p, end, + (struct crush_bucket_straw *)b); + if (err < 0) + goto fail; + break; + case CRUSH_BUCKET_STRAW2: + err = crush_decode_straw2_bucket(p, end, + (struct crush_bucket_straw2 *)b); + if (err < 0) + goto fail; + break; + } + } + + /* rules */ + dout("rule vec is %p\n", c->rules); + for (i = 0; i < c->max_rules; i++) { + u32 yes; + struct crush_rule *r; + + ceph_decode_32_safe(p, end, yes, bad); + if (!yes) { + dout("crush_decode NO rule %d off 
%x %p to %p\n", + i, (int)(*p-start), *p, end); + c->rules[i] = NULL; + continue; + } + + dout("crush_decode rule %d off %x %p to %p\n", + i, (int)(*p-start), *p, end); + + /* len */ + ceph_decode_32_safe(p, end, yes, bad); +#if BITS_PER_LONG == 32 + if (yes > (ULONG_MAX - sizeof(*r)) + / sizeof(struct crush_rule_step)) + goto bad; +#endif + r = kmalloc(struct_size(r, steps, yes), GFP_NOFS); + if (r == NULL) + goto badmem; + dout(" rule %d is at %p\n", i, r); + c->rules[i] = r; + r->len = yes; + ceph_decode_copy_safe(p, end, &r->mask, 4, bad); /* 4 u8's */ + ceph_decode_need(p, end, r->len*3*sizeof(u32), bad); + for (j = 0; j < r->len; j++) { + r->steps[j].op = ceph_decode_32(p); + r->steps[j].arg1 = ceph_decode_32(p); + r->steps[j].arg2 = ceph_decode_32(p); + } + } + + err = decode_crush_names(p, end, &c->type_names); + if (err) + goto fail; + + err = decode_crush_names(p, end, &c->names); + if (err) + goto fail; + + ceph_decode_skip_map(p, end, 32, string, bad); /* rule_name_map */ + + /* tunables */ + ceph_decode_need(p, end, 3*sizeof(u32), done); + c->choose_local_tries = ceph_decode_32(p); + c->choose_local_fallback_tries = ceph_decode_32(p); + c->choose_total_tries = ceph_decode_32(p); + dout("crush decode tunable choose_local_tries = %d\n", + c->choose_local_tries); + dout("crush decode tunable choose_local_fallback_tries = %d\n", + c->choose_local_fallback_tries); + dout("crush decode tunable choose_total_tries = %d\n", + c->choose_total_tries); + + ceph_decode_need(p, end, sizeof(u32), done); + c->chooseleaf_descend_once = ceph_decode_32(p); + dout("crush decode tunable chooseleaf_descend_once = %d\n", + c->chooseleaf_descend_once); + + ceph_decode_need(p, end, sizeof(u8), done); + c->chooseleaf_vary_r = ceph_decode_8(p); + dout("crush decode tunable chooseleaf_vary_r = %d\n", + c->chooseleaf_vary_r); + + /* skip straw_calc_version, allowed_bucket_algs */ + ceph_decode_need(p, end, sizeof(u8) + sizeof(u32), done); + *p += sizeof(u8) + sizeof(u32); + + ceph_decode_need(p, end, sizeof(u8), done); + c->chooseleaf_stable = ceph_decode_8(p); + dout("crush decode tunable chooseleaf_stable = %d\n", + c->chooseleaf_stable); + + if (*p != end) { + /* class_map */ + ceph_decode_skip_map(p, end, 32, 32, bad); + /* class_name */ + ceph_decode_skip_map(p, end, 32, string, bad); + /* class_bucket */ + ceph_decode_skip_map_of_map(p, end, 32, 32, 32, bad); + } + + if (*p != end) { + err = decode_choose_args(p, end, c); + if (err) + goto fail; + } + +done: + crush_finalize(c); + dout("crush_decode success\n"); + return c; + +badmem: + err = -ENOMEM; +fail: + dout("crush_decode fail %d\n", err); + crush_destroy(c); + return ERR_PTR(err); + +bad: + err = -EINVAL; + goto fail; +} + +int ceph_pg_compare(const struct ceph_pg *lhs, const struct ceph_pg *rhs) +{ + if (lhs->pool < rhs->pool) + return -1; + if (lhs->pool > rhs->pool) + return 1; + if (lhs->seed < rhs->seed) + return -1; + if (lhs->seed > rhs->seed) + return 1; + + return 0; +} + +int ceph_spg_compare(const struct ceph_spg *lhs, const struct ceph_spg *rhs) +{ + int ret; + + ret = ceph_pg_compare(&lhs->pgid, &rhs->pgid); + if (ret) + return ret; + + if (lhs->shard < rhs->shard) + return -1; + if (lhs->shard > rhs->shard) + return 1; + + return 0; +} + +static struct ceph_pg_mapping *alloc_pg_mapping(size_t payload_len) +{ + struct ceph_pg_mapping *pg; + + pg = kmalloc(sizeof(*pg) + payload_len, GFP_NOIO); + if (!pg) + return NULL; + + RB_CLEAR_NODE(&pg->node); + return pg; +} + +static void free_pg_mapping(struct ceph_pg_mapping *pg) +{ + 
WARN_ON(!RB_EMPTY_NODE(&pg->node)); + + kfree(pg); +} + +/* + * rbtree of pg_mapping for handling pg_temp (explicit mapping of pgid + * to a set of osds) and primary_temp (explicit primary setting) + */ +DEFINE_RB_FUNCS2(pg_mapping, struct ceph_pg_mapping, pgid, ceph_pg_compare, + RB_BYPTR, const struct ceph_pg *, node) + +/* + * rbtree of pg pool info + */ +DEFINE_RB_FUNCS(pg_pool, struct ceph_pg_pool_info, id, node) + +struct ceph_pg_pool_info *ceph_pg_pool_by_id(struct ceph_osdmap *map, u64 id) +{ + return lookup_pg_pool(&map->pg_pools, id); +} + +const char *ceph_pg_pool_name_by_id(struct ceph_osdmap *map, u64 id) +{ + struct ceph_pg_pool_info *pi; + + if (id == CEPH_NOPOOL) + return NULL; + + if (WARN_ON_ONCE(id > (u64) INT_MAX)) + return NULL; + + pi = lookup_pg_pool(&map->pg_pools, id); + return pi ? pi->name : NULL; +} +EXPORT_SYMBOL(ceph_pg_pool_name_by_id); + +int ceph_pg_poolid_by_name(struct ceph_osdmap *map, const char *name) +{ + struct rb_node *rbp; + + for (rbp = rb_first(&map->pg_pools); rbp; rbp = rb_next(rbp)) { + struct ceph_pg_pool_info *pi = + rb_entry(rbp, struct ceph_pg_pool_info, node); + if (pi->name && strcmp(pi->name, name) == 0) + return pi->id; + } + return -ENOENT; +} +EXPORT_SYMBOL(ceph_pg_poolid_by_name); + +u64 ceph_pg_pool_flags(struct ceph_osdmap *map, u64 id) +{ + struct ceph_pg_pool_info *pi; + + pi = lookup_pg_pool(&map->pg_pools, id); + return pi ? pi->flags : 0; +} +EXPORT_SYMBOL(ceph_pg_pool_flags); + +static void __remove_pg_pool(struct rb_root *root, struct ceph_pg_pool_info *pi) +{ + erase_pg_pool(root, pi); + kfree(pi->name); + kfree(pi); +} + +static int decode_pool(void **p, void *end, struct ceph_pg_pool_info *pi) +{ + u8 ev, cv; + unsigned len, num; + void *pool_end; + + ceph_decode_need(p, end, 2 + 4, bad); + ev = ceph_decode_8(p); /* encoding version */ + cv = ceph_decode_8(p); /* compat version */ + if (ev < 5) { + pr_warn("got v %d < 5 cv %d of ceph_pg_pool\n", ev, cv); + return -EINVAL; + } + if (cv > 9) { + pr_warn("got v %d cv %d > 9 of ceph_pg_pool\n", ev, cv); + return -EINVAL; + } + len = ceph_decode_32(p); + ceph_decode_need(p, end, len, bad); + pool_end = *p + len; + + pi->type = ceph_decode_8(p); + pi->size = ceph_decode_8(p); + pi->crush_ruleset = ceph_decode_8(p); + pi->object_hash = ceph_decode_8(p); + + pi->pg_num = ceph_decode_32(p); + pi->pgp_num = ceph_decode_32(p); + + *p += 4 + 4; /* skip lpg* */ + *p += 4; /* skip last_change */ + *p += 8 + 4; /* skip snap_seq, snap_epoch */ + + /* skip snaps */ + num = ceph_decode_32(p); + while (num--) { + *p += 8; /* snapid key */ + *p += 1 + 1; /* versions */ + len = ceph_decode_32(p); + *p += len; + } + + /* skip removed_snaps */ + num = ceph_decode_32(p); + *p += num * (8 + 8); + + *p += 8; /* skip auid */ + pi->flags = ceph_decode_64(p); + *p += 4; /* skip crash_replay_interval */ + + if (ev >= 7) + pi->min_size = ceph_decode_8(p); + else + pi->min_size = pi->size - pi->size / 2; + + if (ev >= 8) + *p += 8 + 8; /* skip quota_max_* */ + + if (ev >= 9) { + /* skip tiers */ + num = ceph_decode_32(p); + *p += num * 8; + + *p += 8; /* skip tier_of */ + *p += 1; /* skip cache_mode */ + + pi->read_tier = ceph_decode_64(p); + pi->write_tier = ceph_decode_64(p); + } else { + pi->read_tier = -1; + pi->write_tier = -1; + } + + if (ev >= 10) { + /* skip properties */ + num = ceph_decode_32(p); + while (num--) { + len = ceph_decode_32(p); + *p += len; /* key */ + len = ceph_decode_32(p); + *p += len; /* val */ + } + } + + if (ev >= 11) { + /* skip hit_set_params */ + *p += 1 + 1; /* 
versions */ + len = ceph_decode_32(p); + *p += len; + + *p += 4; /* skip hit_set_period */ + *p += 4; /* skip hit_set_count */ + } + + if (ev >= 12) + *p += 4; /* skip stripe_width */ + + if (ev >= 13) { + *p += 8; /* skip target_max_bytes */ + *p += 8; /* skip target_max_objects */ + *p += 4; /* skip cache_target_dirty_ratio_micro */ + *p += 4; /* skip cache_target_full_ratio_micro */ + *p += 4; /* skip cache_min_flush_age */ + *p += 4; /* skip cache_min_evict_age */ + } + + if (ev >= 14) { + /* skip erasure_code_profile */ + len = ceph_decode_32(p); + *p += len; + } + + /* + * last_force_op_resend_preluminous, will be overridden if the + * map was encoded with RESEND_ON_SPLIT + */ + if (ev >= 15) + pi->last_force_request_resend = ceph_decode_32(p); + else + pi->last_force_request_resend = 0; + + if (ev >= 16) + *p += 4; /* skip min_read_recency_for_promote */ + + if (ev >= 17) + *p += 8; /* skip expected_num_objects */ + + if (ev >= 19) + *p += 4; /* skip cache_target_dirty_high_ratio_micro */ + + if (ev >= 20) + *p += 4; /* skip min_write_recency_for_promote */ + + if (ev >= 21) + *p += 1; /* skip use_gmt_hitset */ + + if (ev >= 22) + *p += 1; /* skip fast_read */ + + if (ev >= 23) { + *p += 4; /* skip hit_set_grade_decay_rate */ + *p += 4; /* skip hit_set_search_last_n */ + } + + if (ev >= 24) { + /* skip opts */ + *p += 1 + 1; /* versions */ + len = ceph_decode_32(p); + *p += len; + } + + if (ev >= 25) + pi->last_force_request_resend = ceph_decode_32(p); + + /* ignore the rest */ + + *p = pool_end; + calc_pg_masks(pi); + return 0; + +bad: + return -EINVAL; +} + +static int decode_pool_names(void **p, void *end, struct ceph_osdmap *map) +{ + struct ceph_pg_pool_info *pi; + u32 num, len; + u64 pool; + + ceph_decode_32_safe(p, end, num, bad); + dout(" %d pool names\n", num); + while (num--) { + ceph_decode_64_safe(p, end, pool, bad); + ceph_decode_32_safe(p, end, len, bad); + dout(" pool %llu len %d\n", pool, len); + ceph_decode_need(p, end, len, bad); + pi = lookup_pg_pool(&map->pg_pools, pool); + if (pi) { + char *name = kstrndup(*p, len, GFP_NOFS); + + if (!name) + return -ENOMEM; + kfree(pi->name); + pi->name = name; + dout(" name is %s\n", pi->name); + } + *p += len; + } + return 0; + +bad: + return -EINVAL; +} + +/* + * CRUSH workspaces + * + * workspace_manager framework borrowed from fs/btrfs/compression.c. + * Two simplifications: there is only one type of workspace and there + * is always at least one workspace. 
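+ *
+ * The initial workspace is allocated in osdmap_set_crush() along with
+ * the new CRUSH map and added via add_initial_workspace(), which is why
+ * get_workspace() can always wait instead of failing.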
+ */ +static struct crush_work *alloc_workspace(const struct crush_map *c) +{ + struct crush_work *work; + size_t work_size; + + WARN_ON(!c->working_size); + work_size = crush_work_size(c, CEPH_PG_MAX_SIZE); + dout("%s work_size %zu bytes\n", __func__, work_size); + + work = kvmalloc(work_size, GFP_NOIO); + if (!work) + return NULL; + + INIT_LIST_HEAD(&work->item); + crush_init_workspace(c, work); + return work; +} + +static void free_workspace(struct crush_work *work) +{ + WARN_ON(!list_empty(&work->item)); + kvfree(work); +} + +static void init_workspace_manager(struct workspace_manager *wsm) +{ + INIT_LIST_HEAD(&wsm->idle_ws); + spin_lock_init(&wsm->ws_lock); + atomic_set(&wsm->total_ws, 0); + wsm->free_ws = 0; + init_waitqueue_head(&wsm->ws_wait); +} + +static void add_initial_workspace(struct workspace_manager *wsm, + struct crush_work *work) +{ + WARN_ON(!list_empty(&wsm->idle_ws)); + + list_add(&work->item, &wsm->idle_ws); + atomic_set(&wsm->total_ws, 1); + wsm->free_ws = 1; +} + +static void cleanup_workspace_manager(struct workspace_manager *wsm) +{ + struct crush_work *work; + + while (!list_empty(&wsm->idle_ws)) { + work = list_first_entry(&wsm->idle_ws, struct crush_work, + item); + list_del_init(&work->item); + free_workspace(work); + } + atomic_set(&wsm->total_ws, 0); + wsm->free_ws = 0; +} + +/* + * Finds an available workspace or allocates a new one. If it's not + * possible to allocate a new one, waits until there is one. + */ +static struct crush_work *get_workspace(struct workspace_manager *wsm, + const struct crush_map *c) +{ + struct crush_work *work; + int cpus = num_online_cpus(); + +again: + spin_lock(&wsm->ws_lock); + if (!list_empty(&wsm->idle_ws)) { + work = list_first_entry(&wsm->idle_ws, struct crush_work, + item); + list_del_init(&work->item); + wsm->free_ws--; + spin_unlock(&wsm->ws_lock); + return work; + + } + if (atomic_read(&wsm->total_ws) > cpus) { + DEFINE_WAIT(wait); + + spin_unlock(&wsm->ws_lock); + prepare_to_wait(&wsm->ws_wait, &wait, TASK_UNINTERRUPTIBLE); + if (atomic_read(&wsm->total_ws) > cpus && !wsm->free_ws) + schedule(); + finish_wait(&wsm->ws_wait, &wait); + goto again; + } + atomic_inc(&wsm->total_ws); + spin_unlock(&wsm->ws_lock); + + work = alloc_workspace(c); + if (!work) { + atomic_dec(&wsm->total_ws); + wake_up(&wsm->ws_wait); + + /* + * Do not return the error but go back to waiting. We + * have the initial workspace and the CRUSH computation + * time is bounded so we will get it eventually. + */ + WARN_ON(atomic_read(&wsm->total_ws) < 1); + goto again; + } + return work; +} + +/* + * Puts a workspace back on the list or frees it if we have enough + * idle ones sitting around. 
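+ * ("enough" being roughly one idle workspace per online CPU).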
+ */ +static void put_workspace(struct workspace_manager *wsm, + struct crush_work *work) +{ + spin_lock(&wsm->ws_lock); + if (wsm->free_ws <= num_online_cpus()) { + list_add(&work->item, &wsm->idle_ws); + wsm->free_ws++; + spin_unlock(&wsm->ws_lock); + goto wake; + } + spin_unlock(&wsm->ws_lock); + + free_workspace(work); + atomic_dec(&wsm->total_ws); +wake: + if (wq_has_sleeper(&wsm->ws_wait)) + wake_up(&wsm->ws_wait); +} + +/* + * osd map + */ +struct ceph_osdmap *ceph_osdmap_alloc(void) +{ + struct ceph_osdmap *map; + + map = kzalloc(sizeof(*map), GFP_NOIO); + if (!map) + return NULL; + + map->pg_pools = RB_ROOT; + map->pool_max = -1; + map->pg_temp = RB_ROOT; + map->primary_temp = RB_ROOT; + map->pg_upmap = RB_ROOT; + map->pg_upmap_items = RB_ROOT; + + init_workspace_manager(&map->crush_wsm); + + return map; +} + +void ceph_osdmap_destroy(struct ceph_osdmap *map) +{ + dout("osdmap_destroy %p\n", map); + + if (map->crush) + crush_destroy(map->crush); + cleanup_workspace_manager(&map->crush_wsm); + + while (!RB_EMPTY_ROOT(&map->pg_temp)) { + struct ceph_pg_mapping *pg = + rb_entry(rb_first(&map->pg_temp), + struct ceph_pg_mapping, node); + erase_pg_mapping(&map->pg_temp, pg); + free_pg_mapping(pg); + } + while (!RB_EMPTY_ROOT(&map->primary_temp)) { + struct ceph_pg_mapping *pg = + rb_entry(rb_first(&map->primary_temp), + struct ceph_pg_mapping, node); + erase_pg_mapping(&map->primary_temp, pg); + free_pg_mapping(pg); + } + while (!RB_EMPTY_ROOT(&map->pg_upmap)) { + struct ceph_pg_mapping *pg = + rb_entry(rb_first(&map->pg_upmap), + struct ceph_pg_mapping, node); + rb_erase(&pg->node, &map->pg_upmap); + kfree(pg); + } + while (!RB_EMPTY_ROOT(&map->pg_upmap_items)) { + struct ceph_pg_mapping *pg = + rb_entry(rb_first(&map->pg_upmap_items), + struct ceph_pg_mapping, node); + rb_erase(&pg->node, &map->pg_upmap_items); + kfree(pg); + } + while (!RB_EMPTY_ROOT(&map->pg_pools)) { + struct ceph_pg_pool_info *pi = + rb_entry(rb_first(&map->pg_pools), + struct ceph_pg_pool_info, node); + __remove_pg_pool(&map->pg_pools, pi); + } + kvfree(map->osd_state); + kvfree(map->osd_weight); + kvfree(map->osd_addr); + kvfree(map->osd_primary_affinity); + kfree(map); +} + +/* + * Adjust max_osd value, (re)allocate arrays. + * + * The new elements are properly initialized. 
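+ * Shrinking is handled as well: entries beyond the new maximum are
+ * simply dropped.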
+ */ +static int osdmap_set_max_osd(struct ceph_osdmap *map, u32 max) +{ + u32 *state; + u32 *weight; + struct ceph_entity_addr *addr; + u32 to_copy; + int i; + + dout("%s old %u new %u\n", __func__, map->max_osd, max); + if (max == map->max_osd) + return 0; + + state = kvmalloc(array_size(max, sizeof(*state)), GFP_NOFS); + weight = kvmalloc(array_size(max, sizeof(*weight)), GFP_NOFS); + addr = kvmalloc(array_size(max, sizeof(*addr)), GFP_NOFS); + if (!state || !weight || !addr) { + kvfree(state); + kvfree(weight); + kvfree(addr); + return -ENOMEM; + } + + to_copy = min(map->max_osd, max); + if (map->osd_state) { + memcpy(state, map->osd_state, to_copy * sizeof(*state)); + memcpy(weight, map->osd_weight, to_copy * sizeof(*weight)); + memcpy(addr, map->osd_addr, to_copy * sizeof(*addr)); + kvfree(map->osd_state); + kvfree(map->osd_weight); + kvfree(map->osd_addr); + } + + map->osd_state = state; + map->osd_weight = weight; + map->osd_addr = addr; + for (i = map->max_osd; i < max; i++) { + map->osd_state[i] = 0; + map->osd_weight[i] = CEPH_OSD_OUT; + memset(map->osd_addr + i, 0, sizeof(*map->osd_addr)); + } + + if (map->osd_primary_affinity) { + u32 *affinity; + + affinity = kvmalloc(array_size(max, sizeof(*affinity)), + GFP_NOFS); + if (!affinity) + return -ENOMEM; + + memcpy(affinity, map->osd_primary_affinity, + to_copy * sizeof(*affinity)); + kvfree(map->osd_primary_affinity); + + map->osd_primary_affinity = affinity; + for (i = map->max_osd; i < max; i++) + map->osd_primary_affinity[i] = + CEPH_OSD_DEFAULT_PRIMARY_AFFINITY; + } + + map->max_osd = max; + + return 0; +} + +static int osdmap_set_crush(struct ceph_osdmap *map, struct crush_map *crush) +{ + struct crush_work *work; + + if (IS_ERR(crush)) + return PTR_ERR(crush); + + work = alloc_workspace(crush); + if (!work) { + crush_destroy(crush); + return -ENOMEM; + } + + if (map->crush) + crush_destroy(map->crush); + cleanup_workspace_manager(&map->crush_wsm); + map->crush = crush; + add_initial_workspace(&map->crush_wsm, work); + return 0; +} + +#define OSDMAP_WRAPPER_COMPAT_VER 7 +#define OSDMAP_CLIENT_DATA_COMPAT_VER 1 + +/* + * Return 0 or error. On success, *v is set to 0 for old (v6) osdmaps, + * to struct_v of the client_data section for new (v7 and above) + * osdmaps. 
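+ *
+ * Callers use the returned struct_v to decide which optional sections
+ * (primary_temp, primary_affinity, pg_upmap, etc.) to expect.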
+ */ +static int get_osdmap_client_data_v(void **p, void *end, + const char *prefix, u8 *v) +{ + u8 struct_v; + + ceph_decode_8_safe(p, end, struct_v, e_inval); + if (struct_v >= 7) { + u8 struct_compat; + + ceph_decode_8_safe(p, end, struct_compat, e_inval); + if (struct_compat > OSDMAP_WRAPPER_COMPAT_VER) { + pr_warn("got v %d cv %d > %d of %s ceph_osdmap\n", + struct_v, struct_compat, + OSDMAP_WRAPPER_COMPAT_VER, prefix); + return -EINVAL; + } + *p += 4; /* ignore wrapper struct_len */ + + ceph_decode_8_safe(p, end, struct_v, e_inval); + ceph_decode_8_safe(p, end, struct_compat, e_inval); + if (struct_compat > OSDMAP_CLIENT_DATA_COMPAT_VER) { + pr_warn("got v %d cv %d > %d of %s ceph_osdmap client data\n", + struct_v, struct_compat, + OSDMAP_CLIENT_DATA_COMPAT_VER, prefix); + return -EINVAL; + } + *p += 4; /* ignore client data struct_len */ + } else { + u16 version; + + *p -= 1; + ceph_decode_16_safe(p, end, version, e_inval); + if (version < 6) { + pr_warn("got v %d < 6 of %s ceph_osdmap\n", + version, prefix); + return -EINVAL; + } + + /* old osdmap encoding */ + struct_v = 0; + } + + *v = struct_v; + return 0; + +e_inval: + return -EINVAL; +} + +static int __decode_pools(void **p, void *end, struct ceph_osdmap *map, + bool incremental) +{ + u32 n; + + ceph_decode_32_safe(p, end, n, e_inval); + while (n--) { + struct ceph_pg_pool_info *pi; + u64 pool; + int ret; + + ceph_decode_64_safe(p, end, pool, e_inval); + + pi = lookup_pg_pool(&map->pg_pools, pool); + if (!incremental || !pi) { + pi = kzalloc(sizeof(*pi), GFP_NOFS); + if (!pi) + return -ENOMEM; + + RB_CLEAR_NODE(&pi->node); + pi->id = pool; + + if (!__insert_pg_pool(&map->pg_pools, pi)) { + kfree(pi); + return -EEXIST; + } + } + + ret = decode_pool(p, end, pi); + if (ret) + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static int decode_pools(void **p, void *end, struct ceph_osdmap *map) +{ + return __decode_pools(p, end, map, false); +} + +static int decode_new_pools(void **p, void *end, struct ceph_osdmap *map) +{ + return __decode_pools(p, end, map, true); +} + +typedef struct ceph_pg_mapping *(*decode_mapping_fn_t)(void **, void *, bool); + +static int decode_pg_mapping(void **p, void *end, struct rb_root *mapping_root, + decode_mapping_fn_t fn, bool incremental) +{ + u32 n; + + WARN_ON(!incremental && !fn); + + ceph_decode_32_safe(p, end, n, e_inval); + while (n--) { + struct ceph_pg_mapping *pg; + struct ceph_pg pgid; + int ret; + + ret = ceph_decode_pgid(p, end, &pgid); + if (ret) + return ret; + + pg = lookup_pg_mapping(mapping_root, &pgid); + if (pg) { + WARN_ON(!incremental); + erase_pg_mapping(mapping_root, pg); + free_pg_mapping(pg); + } + + if (fn) { + pg = fn(p, end, incremental); + if (IS_ERR(pg)) + return PTR_ERR(pg); + + if (pg) { + pg->pgid = pgid; /* struct */ + insert_pg_mapping(mapping_root, pg); + } + } + } + + return 0; + +e_inval: + return -EINVAL; +} + +static struct ceph_pg_mapping *__decode_pg_temp(void **p, void *end, + bool incremental) +{ + struct ceph_pg_mapping *pg; + u32 len, i; + + ceph_decode_32_safe(p, end, len, e_inval); + if (len == 0 && incremental) + return NULL; /* new_pg_temp: [] to remove */ + if (len > (SIZE_MAX - sizeof(*pg)) / sizeof(u32)) + return ERR_PTR(-EINVAL); + + ceph_decode_need(p, end, len * sizeof(u32), e_inval); + pg = alloc_pg_mapping(len * sizeof(u32)); + if (!pg) + return ERR_PTR(-ENOMEM); + + pg->pg_temp.len = len; + for (i = 0; i < len; i++) + pg->pg_temp.osds[i] = ceph_decode_32(p); + + return pg; + +e_inval: + return ERR_PTR(-EINVAL); +} + 
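/*
 * Illustrative sketch, not from the kernel tree: how a client that
 * already holds a decoded struct ceph_osdmap might use the exported pool
 * helpers above (ceph_pg_poolid_by_name(), ceph_pg_pool_name_by_id(),
 * ceph_pg_pool_flags()).  The function name and the "rbd" pool are
 * hypothetical; a real caller would include <linux/ceph/osdmap.h>.
 */
static void example_pool_lookup(struct ceph_osdmap *map)
{
	int id;

	id = ceph_pg_poolid_by_name(map, "rbd");
	if (id < 0) {
		pr_warn("pool lookup failed: %d\n", id);
		return;
	}

	/* resolve back to the name and dump the pool flags */
	pr_info("pool %d name %s flags 0x%llx\n", id,
		ceph_pg_pool_name_by_id(map, id),
		ceph_pg_pool_flags(map, id));
}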
+static int decode_pg_temp(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_temp, __decode_pg_temp, + false); +} + +static int decode_new_pg_temp(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_temp, __decode_pg_temp, + true); +} + +static struct ceph_pg_mapping *__decode_primary_temp(void **p, void *end, + bool incremental) +{ + struct ceph_pg_mapping *pg; + u32 osd; + + ceph_decode_32_safe(p, end, osd, e_inval); + if (osd == (u32)-1 && incremental) + return NULL; /* new_primary_temp: -1 to remove */ + + pg = alloc_pg_mapping(0); + if (!pg) + return ERR_PTR(-ENOMEM); + + pg->primary_temp.osd = osd; + return pg; + +e_inval: + return ERR_PTR(-EINVAL); +} + +static int decode_primary_temp(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->primary_temp, + __decode_primary_temp, false); +} + +static int decode_new_primary_temp(void **p, void *end, + struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->primary_temp, + __decode_primary_temp, true); +} + +u32 ceph_get_primary_affinity(struct ceph_osdmap *map, int osd) +{ + BUG_ON(osd >= map->max_osd); + + if (!map->osd_primary_affinity) + return CEPH_OSD_DEFAULT_PRIMARY_AFFINITY; + + return map->osd_primary_affinity[osd]; +} + +static int set_primary_affinity(struct ceph_osdmap *map, int osd, u32 aff) +{ + BUG_ON(osd >= map->max_osd); + + if (!map->osd_primary_affinity) { + int i; + + map->osd_primary_affinity = kvmalloc( + array_size(map->max_osd, sizeof(*map->osd_primary_affinity)), + GFP_NOFS); + if (!map->osd_primary_affinity) + return -ENOMEM; + + for (i = 0; i < map->max_osd; i++) + map->osd_primary_affinity[i] = + CEPH_OSD_DEFAULT_PRIMARY_AFFINITY; + } + + map->osd_primary_affinity[osd] = aff; + + return 0; +} + +static int decode_primary_affinity(void **p, void *end, + struct ceph_osdmap *map) +{ + u32 len, i; + + ceph_decode_32_safe(p, end, len, e_inval); + if (len == 0) { + kvfree(map->osd_primary_affinity); + map->osd_primary_affinity = NULL; + return 0; + } + if (len != map->max_osd) + goto e_inval; + + ceph_decode_need(p, end, map->max_osd*sizeof(u32), e_inval); + + for (i = 0; i < map->max_osd; i++) { + int ret; + + ret = set_primary_affinity(map, i, ceph_decode_32(p)); + if (ret) + return ret; + } + + return 0; + +e_inval: + return -EINVAL; +} + +static int decode_new_primary_affinity(void **p, void *end, + struct ceph_osdmap *map) +{ + u32 n; + + ceph_decode_32_safe(p, end, n, e_inval); + while (n--) { + u32 osd, aff; + int ret; + + ceph_decode_32_safe(p, end, osd, e_inval); + ceph_decode_32_safe(p, end, aff, e_inval); + + ret = set_primary_affinity(map, osd, aff); + if (ret) + return ret; + + osdmap_info(map, "osd%d primary-affinity 0x%x\n", osd, aff); + } + + return 0; + +e_inval: + return -EINVAL; +} + +static struct ceph_pg_mapping *__decode_pg_upmap(void **p, void *end, + bool __unused) +{ + return __decode_pg_temp(p, end, false); +} + +static int decode_pg_upmap(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap, __decode_pg_upmap, + false); +} + +static int decode_new_pg_upmap(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap, __decode_pg_upmap, + true); +} + +static int decode_old_pg_upmap(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap, NULL, true); +} + +static struct ceph_pg_mapping *__decode_pg_upmap_items(void **p, void *end, + 
bool __unused) +{ + struct ceph_pg_mapping *pg; + u32 len, i; + + ceph_decode_32_safe(p, end, len, e_inval); + if (len > (SIZE_MAX - sizeof(*pg)) / (2 * sizeof(u32))) + return ERR_PTR(-EINVAL); + + ceph_decode_need(p, end, 2 * len * sizeof(u32), e_inval); + pg = alloc_pg_mapping(2 * len * sizeof(u32)); + if (!pg) + return ERR_PTR(-ENOMEM); + + pg->pg_upmap_items.len = len; + for (i = 0; i < len; i++) { + pg->pg_upmap_items.from_to[i][0] = ceph_decode_32(p); + pg->pg_upmap_items.from_to[i][1] = ceph_decode_32(p); + } + + return pg; + +e_inval: + return ERR_PTR(-EINVAL); +} + +static int decode_pg_upmap_items(void **p, void *end, struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap_items, + __decode_pg_upmap_items, false); +} + +static int decode_new_pg_upmap_items(void **p, void *end, + struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap_items, + __decode_pg_upmap_items, true); +} + +static int decode_old_pg_upmap_items(void **p, void *end, + struct ceph_osdmap *map) +{ + return decode_pg_mapping(p, end, &map->pg_upmap_items, NULL, true); +} + +/* + * decode a full map. + */ +static int osdmap_decode(void **p, void *end, bool msgr2, + struct ceph_osdmap *map) +{ + u8 struct_v; + u32 epoch = 0; + void *start = *p; + u32 max; + u32 len, i; + int err; + + dout("%s %p to %p len %d\n", __func__, *p, end, (int)(end - *p)); + + err = get_osdmap_client_data_v(p, end, "full", &struct_v); + if (err) + goto bad; + + /* fsid, epoch, created, modified */ + ceph_decode_need(p, end, sizeof(map->fsid) + sizeof(u32) + + sizeof(map->created) + sizeof(map->modified), e_inval); + ceph_decode_copy(p, &map->fsid, sizeof(map->fsid)); + epoch = map->epoch = ceph_decode_32(p); + ceph_decode_copy(p, &map->created, sizeof(map->created)); + ceph_decode_copy(p, &map->modified, sizeof(map->modified)); + + /* pools */ + err = decode_pools(p, end, map); + if (err) + goto bad; + + /* pool_name */ + err = decode_pool_names(p, end, map); + if (err) + goto bad; + + ceph_decode_32_safe(p, end, map->pool_max, e_inval); + + ceph_decode_32_safe(p, end, map->flags, e_inval); + + /* max_osd */ + ceph_decode_32_safe(p, end, max, e_inval); + + /* (re)alloc osd arrays */ + err = osdmap_set_max_osd(map, max); + if (err) + goto bad; + + /* osd_state, osd_weight, osd_addrs->client_addr */ + ceph_decode_need(p, end, 3*sizeof(u32) + + map->max_osd*(struct_v >= 5 ? 
sizeof(u32) : + sizeof(u8)) + + sizeof(*map->osd_weight), e_inval); + if (ceph_decode_32(p) != map->max_osd) + goto e_inval; + + if (struct_v >= 5) { + for (i = 0; i < map->max_osd; i++) + map->osd_state[i] = ceph_decode_32(p); + } else { + for (i = 0; i < map->max_osd; i++) + map->osd_state[i] = ceph_decode_8(p); + } + + if (ceph_decode_32(p) != map->max_osd) + goto e_inval; + + for (i = 0; i < map->max_osd; i++) + map->osd_weight[i] = ceph_decode_32(p); + + if (ceph_decode_32(p) != map->max_osd) + goto e_inval; + + for (i = 0; i < map->max_osd; i++) { + struct ceph_entity_addr *addr = &map->osd_addr[i]; + + if (struct_v >= 8) + err = ceph_decode_entity_addrvec(p, end, msgr2, addr); + else + err = ceph_decode_entity_addr(p, end, addr); + if (err) + goto bad; + + dout("%s osd%d addr %s\n", __func__, i, ceph_pr_addr(addr)); + } + + /* pg_temp */ + err = decode_pg_temp(p, end, map); + if (err) + goto bad; + + /* primary_temp */ + if (struct_v >= 1) { + err = decode_primary_temp(p, end, map); + if (err) + goto bad; + } + + /* primary_affinity */ + if (struct_v >= 2) { + err = decode_primary_affinity(p, end, map); + if (err) + goto bad; + } else { + WARN_ON(map->osd_primary_affinity); + } + + /* crush */ + ceph_decode_32_safe(p, end, len, e_inval); + err = osdmap_set_crush(map, crush_decode(*p, min(*p + len, end))); + if (err) + goto bad; + + *p += len; + if (struct_v >= 3) { + /* erasure_code_profiles */ + ceph_decode_skip_map_of_map(p, end, string, string, string, + e_inval); + } + + if (struct_v >= 4) { + err = decode_pg_upmap(p, end, map); + if (err) + goto bad; + + err = decode_pg_upmap_items(p, end, map); + if (err) + goto bad; + } else { + WARN_ON(!RB_EMPTY_ROOT(&map->pg_upmap)); + WARN_ON(!RB_EMPTY_ROOT(&map->pg_upmap_items)); + } + + /* ignore the rest */ + *p = end; + + dout("full osdmap epoch %d max_osd %d\n", map->epoch, map->max_osd); + return 0; + +e_inval: + err = -EINVAL; +bad: + pr_err("corrupt full osdmap (%d) epoch %d off %d (%p of %p-%p)\n", + err, epoch, (int)(*p - start), *p, start, end); + print_hex_dump(KERN_DEBUG, "osdmap: ", + DUMP_PREFIX_OFFSET, 16, 1, + start, end - start, true); + return err; +} + +/* + * Allocate and decode a full map. + */ +struct ceph_osdmap *ceph_osdmap_decode(void **p, void *end, bool msgr2) +{ + struct ceph_osdmap *map; + int ret; + + map = ceph_osdmap_alloc(); + if (!map) + return ERR_PTR(-ENOMEM); + + ret = osdmap_decode(p, end, msgr2, map); + if (ret) { + ceph_osdmap_destroy(map); + return ERR_PTR(ret); + } + + return map; +} + +/* + * Encoding order is (new_up_client, new_state, new_weight). Need to + * apply in the (new_weight, new_state, new_up_client) order, because + * an incremental map may look like e.g. + * + * new_up_client: { osd=6, addr=... } # set osd_state and addr + * new_state: { osd=6, xorstate=EXISTS } # clear osd_state + */ +static int decode_new_up_state_weight(void **p, void *end, u8 struct_v, + bool msgr2, struct ceph_osdmap *map) +{ + void *new_up_client; + void *new_state; + void *new_weight_end; + u32 len; + int ret; + int i; + + new_up_client = *p; + ceph_decode_32_safe(p, end, len, e_inval); + for (i = 0; i < len; ++i) { + struct ceph_entity_addr addr; + + ceph_decode_skip_32(p, end, e_inval); + if (struct_v >= 7) + ret = ceph_decode_entity_addrvec(p, end, msgr2, &addr); + else + ret = ceph_decode_entity_addr(p, end, &addr); + if (ret) + return ret; + } + + new_state = *p; + ceph_decode_32_safe(p, end, len, e_inval); + len *= sizeof(u32) + (struct_v >= 5 ? 
sizeof(u32) : sizeof(u8)); + ceph_decode_need(p, end, len, e_inval); + *p += len; + + /* new_weight */ + ceph_decode_32_safe(p, end, len, e_inval); + while (len--) { + s32 osd; + u32 w; + + ceph_decode_need(p, end, 2*sizeof(u32), e_inval); + osd = ceph_decode_32(p); + w = ceph_decode_32(p); + BUG_ON(osd >= map->max_osd); + osdmap_info(map, "osd%d weight 0x%x %s\n", osd, w, + w == CEPH_OSD_IN ? "(in)" : + (w == CEPH_OSD_OUT ? "(out)" : "")); + map->osd_weight[osd] = w; + + /* + * If we are marking in, set the EXISTS, and clear the + * AUTOOUT and NEW bits. + */ + if (w) { + map->osd_state[osd] |= CEPH_OSD_EXISTS; + map->osd_state[osd] &= ~(CEPH_OSD_AUTOOUT | + CEPH_OSD_NEW); + } + } + new_weight_end = *p; + + /* new_state (up/down) */ + *p = new_state; + len = ceph_decode_32(p); + while (len--) { + s32 osd; + u32 xorstate; + + osd = ceph_decode_32(p); + if (struct_v >= 5) + xorstate = ceph_decode_32(p); + else + xorstate = ceph_decode_8(p); + if (xorstate == 0) + xorstate = CEPH_OSD_UP; + BUG_ON(osd >= map->max_osd); + if ((map->osd_state[osd] & CEPH_OSD_UP) && + (xorstate & CEPH_OSD_UP)) + osdmap_info(map, "osd%d down\n", osd); + if ((map->osd_state[osd] & CEPH_OSD_EXISTS) && + (xorstate & CEPH_OSD_EXISTS)) { + osdmap_info(map, "osd%d does not exist\n", osd); + ret = set_primary_affinity(map, osd, + CEPH_OSD_DEFAULT_PRIMARY_AFFINITY); + if (ret) + return ret; + memset(map->osd_addr + osd, 0, sizeof(*map->osd_addr)); + map->osd_state[osd] = 0; + } else { + map->osd_state[osd] ^= xorstate; + } + } + + /* new_up_client */ + *p = new_up_client; + len = ceph_decode_32(p); + while (len--) { + s32 osd; + struct ceph_entity_addr addr; + + osd = ceph_decode_32(p); + BUG_ON(osd >= map->max_osd); + if (struct_v >= 7) + ret = ceph_decode_entity_addrvec(p, end, msgr2, &addr); + else + ret = ceph_decode_entity_addr(p, end, &addr); + if (ret) + return ret; + + dout("%s osd%d addr %s\n", __func__, osd, ceph_pr_addr(&addr)); + + osdmap_info(map, "osd%d up\n", osd); + map->osd_state[osd] |= CEPH_OSD_EXISTS | CEPH_OSD_UP; + map->osd_addr[osd] = addr; + } + + *p = new_weight_end; + return 0; + +e_inval: + return -EINVAL; +} + +/* + * decode and apply an incremental map update. + */ +struct ceph_osdmap *osdmap_apply_incremental(void **p, void *end, bool msgr2, + struct ceph_osdmap *map) +{ + struct ceph_fsid fsid; + u32 epoch = 0; + struct ceph_timespec modified; + s32 len; + u64 pool; + __s64 new_pool_max; + __s32 new_flags, max; + void *start = *p; + int err; + u8 struct_v; + + dout("%s %p to %p len %d\n", __func__, *p, end, (int)(end - *p)); + + err = get_osdmap_client_data_v(p, end, "inc", &struct_v); + if (err) + goto bad; + + /* fsid, epoch, modified, new_pool_max, new_flags */ + ceph_decode_need(p, end, sizeof(fsid) + sizeof(u32) + sizeof(modified) + + sizeof(u64) + sizeof(u32), e_inval); + ceph_decode_copy(p, &fsid, sizeof(fsid)); + epoch = ceph_decode_32(p); + BUG_ON(epoch != map->epoch+1); + ceph_decode_copy(p, &modified, sizeof(modified)); + new_pool_max = ceph_decode_64(p); + new_flags = ceph_decode_32(p); + + /* full map? */ + ceph_decode_32_safe(p, end, len, e_inval); + if (len > 0) { + dout("apply_incremental full map len %d, %p to %p\n", + len, *p, end); + return ceph_osdmap_decode(p, min(*p+len, end), msgr2); + } + + /* new crush? */ + ceph_decode_32_safe(p, end, len, e_inval); + if (len > 0) { + err = osdmap_set_crush(map, + crush_decode(*p, min(*p + len, end))); + if (err) + goto bad; + *p += len; + } + + /* new flags? 
*/ + if (new_flags >= 0) + map->flags = new_flags; + if (new_pool_max >= 0) + map->pool_max = new_pool_max; + + /* new max? */ + ceph_decode_32_safe(p, end, max, e_inval); + if (max >= 0) { + err = osdmap_set_max_osd(map, max); + if (err) + goto bad; + } + + map->epoch++; + map->modified = modified; + + /* new_pools */ + err = decode_new_pools(p, end, map); + if (err) + goto bad; + + /* new_pool_names */ + err = decode_pool_names(p, end, map); + if (err) + goto bad; + + /* old_pool */ + ceph_decode_32_safe(p, end, len, e_inval); + while (len--) { + struct ceph_pg_pool_info *pi; + + ceph_decode_64_safe(p, end, pool, e_inval); + pi = lookup_pg_pool(&map->pg_pools, pool); + if (pi) + __remove_pg_pool(&map->pg_pools, pi); + } + + /* new_up_client, new_state, new_weight */ + err = decode_new_up_state_weight(p, end, struct_v, msgr2, map); + if (err) + goto bad; + + /* new_pg_temp */ + err = decode_new_pg_temp(p, end, map); + if (err) + goto bad; + + /* new_primary_temp */ + if (struct_v >= 1) { + err = decode_new_primary_temp(p, end, map); + if (err) + goto bad; + } + + /* new_primary_affinity */ + if (struct_v >= 2) { + err = decode_new_primary_affinity(p, end, map); + if (err) + goto bad; + } + + if (struct_v >= 3) { + /* new_erasure_code_profiles */ + ceph_decode_skip_map_of_map(p, end, string, string, string, + e_inval); + /* old_erasure_code_profiles */ + ceph_decode_skip_set(p, end, string, e_inval); + } + + if (struct_v >= 4) { + err = decode_new_pg_upmap(p, end, map); + if (err) + goto bad; + + err = decode_old_pg_upmap(p, end, map); + if (err) + goto bad; + + err = decode_new_pg_upmap_items(p, end, map); + if (err) + goto bad; + + err = decode_old_pg_upmap_items(p, end, map); + if (err) + goto bad; + } + + /* ignore the rest */ + *p = end; + + dout("inc osdmap epoch %d max_osd %d\n", map->epoch, map->max_osd); + return map; + +e_inval: + err = -EINVAL; +bad: + pr_err("corrupt inc osdmap (%d) epoch %d off %d (%p of %p-%p)\n", + err, epoch, (int)(*p - start), *p, start, end); + print_hex_dump(KERN_DEBUG, "osdmap: ", + DUMP_PREFIX_OFFSET, 16, 1, + start, end - start, true); + return ERR_PTR(err); +} + +void ceph_oloc_copy(struct ceph_object_locator *dest, + const struct ceph_object_locator *src) +{ + ceph_oloc_destroy(dest); + + dest->pool = src->pool; + if (src->pool_ns) + dest->pool_ns = ceph_get_string(src->pool_ns); + else + dest->pool_ns = NULL; +} +EXPORT_SYMBOL(ceph_oloc_copy); + +void ceph_oloc_destroy(struct ceph_object_locator *oloc) +{ + ceph_put_string(oloc->pool_ns); +} +EXPORT_SYMBOL(ceph_oloc_destroy); + +void ceph_oid_copy(struct ceph_object_id *dest, + const struct ceph_object_id *src) +{ + ceph_oid_destroy(dest); + + if (src->name != src->inline_name) { + /* very rare, see ceph_object_id definition */ + dest->name = kmalloc(src->name_len + 1, + GFP_NOIO | __GFP_NOFAIL); + } else { + dest->name = dest->inline_name; + } + memcpy(dest->name, src->name, src->name_len + 1); + dest->name_len = src->name_len; +} +EXPORT_SYMBOL(ceph_oid_copy); + +static __printf(2, 0) +int oid_printf_vargs(struct ceph_object_id *oid, const char *fmt, va_list ap) +{ + int len; + + WARN_ON(!ceph_oid_empty(oid)); + + len = vsnprintf(oid->inline_name, sizeof(oid->inline_name), fmt, ap); + if (len >= sizeof(oid->inline_name)) + return len; + + oid->name_len = len; + return 0; +} + +/* + * If oid doesn't fit into inline buffer, BUG. + */ +void ceph_oid_printf(struct ceph_object_id *oid, const char *fmt, ...) 
+{ + va_list ap; + + va_start(ap, fmt); + BUG_ON(oid_printf_vargs(oid, fmt, ap)); + va_end(ap); +} +EXPORT_SYMBOL(ceph_oid_printf); + +static __printf(3, 0) +int oid_aprintf_vargs(struct ceph_object_id *oid, gfp_t gfp, + const char *fmt, va_list ap) +{ + va_list aq; + int len; + + va_copy(aq, ap); + len = oid_printf_vargs(oid, fmt, aq); + va_end(aq); + + if (len) { + char *external_name; + + external_name = kmalloc(len + 1, gfp); + if (!external_name) + return -ENOMEM; + + oid->name = external_name; + WARN_ON(vsnprintf(oid->name, len + 1, fmt, ap) != len); + oid->name_len = len; + } + + return 0; +} + +/* + * If oid doesn't fit into inline buffer, allocate. + */ +int ceph_oid_aprintf(struct ceph_object_id *oid, gfp_t gfp, + const char *fmt, ...) +{ + va_list ap; + int ret; + + va_start(ap, fmt); + ret = oid_aprintf_vargs(oid, gfp, fmt, ap); + va_end(ap); + + return ret; +} +EXPORT_SYMBOL(ceph_oid_aprintf); + +void ceph_oid_destroy(struct ceph_object_id *oid) +{ + if (oid->name != oid->inline_name) + kfree(oid->name); +} +EXPORT_SYMBOL(ceph_oid_destroy); + +/* + * osds only + */ +static bool __osds_equal(const struct ceph_osds *lhs, + const struct ceph_osds *rhs) +{ + if (lhs->size == rhs->size && + !memcmp(lhs->osds, rhs->osds, rhs->size * sizeof(rhs->osds[0]))) + return true; + + return false; +} + +/* + * osds + primary + */ +static bool osds_equal(const struct ceph_osds *lhs, + const struct ceph_osds *rhs) +{ + if (__osds_equal(lhs, rhs) && + lhs->primary == rhs->primary) + return true; + + return false; +} + +static bool osds_valid(const struct ceph_osds *set) +{ + /* non-empty set */ + if (set->size > 0 && set->primary >= 0) + return true; + + /* empty can_shift_osds set */ + if (!set->size && set->primary == -1) + return true; + + /* empty !can_shift_osds set - all NONE */ + if (set->size > 0 && set->primary == -1) { + int i; + + for (i = 0; i < set->size; i++) { + if (set->osds[i] != CRUSH_ITEM_NONE) + break; + } + if (i == set->size) + return true; + } + + return false; +} + +void ceph_osds_copy(struct ceph_osds *dest, const struct ceph_osds *src) +{ + memcpy(dest->osds, src->osds, src->size * sizeof(src->osds[0])); + dest->size = src->size; + dest->primary = src->primary; +} + +bool ceph_pg_is_split(const struct ceph_pg *pgid, u32 old_pg_num, + u32 new_pg_num) +{ + int old_bits = calc_bits_of(old_pg_num); + int old_mask = (1 << old_bits) - 1; + int n; + + WARN_ON(pgid->seed >= old_pg_num); + if (new_pg_num <= old_pg_num) + return false; + + for (n = 1; ; n++) { + int next_bit = n << (old_bits - 1); + u32 s = next_bit | pgid->seed; + + if (s < old_pg_num || s == pgid->seed) + continue; + if (s >= new_pg_num) + break; + + s = ceph_stable_mod(s, old_pg_num, old_mask); + if (s == pgid->seed) + return true; + } + + return false; +} + +bool ceph_is_new_interval(const struct ceph_osds *old_acting, + const struct ceph_osds *new_acting, + const struct ceph_osds *old_up, + const struct ceph_osds *new_up, + int old_size, + int new_size, + int old_min_size, + int new_min_size, + u32 old_pg_num, + u32 new_pg_num, + bool old_sort_bitwise, + bool new_sort_bitwise, + bool old_recovery_deletes, + bool new_recovery_deletes, + const struct ceph_pg *pgid) +{ + return !osds_equal(old_acting, new_acting) || + !osds_equal(old_up, new_up) || + old_size != new_size || + old_min_size != new_min_size || + ceph_pg_is_split(pgid, old_pg_num, new_pg_num) || + old_sort_bitwise != new_sort_bitwise || + old_recovery_deletes != new_recovery_deletes; +} + +static int calc_pg_rank(int osd, const struct ceph_osds 
*acting) +{ + int i; + + for (i = 0; i < acting->size; i++) { + if (acting->osds[i] == osd) + return i; + } + + return -1; +} + +static bool primary_changed(const struct ceph_osds *old_acting, + const struct ceph_osds *new_acting) +{ + if (!old_acting->size && !new_acting->size) + return false; /* both still empty */ + + if (!old_acting->size ^ !new_acting->size) + return true; /* was empty, now not, or vice versa */ + + if (old_acting->primary != new_acting->primary) + return true; /* primary changed */ + + if (calc_pg_rank(old_acting->primary, old_acting) != + calc_pg_rank(new_acting->primary, new_acting)) + return true; + + return false; /* same primary (tho replicas may have changed) */ +} + +bool ceph_osds_changed(const struct ceph_osds *old_acting, + const struct ceph_osds *new_acting, + bool any_change) +{ + if (primary_changed(old_acting, new_acting)) + return true; + + if (any_change && !__osds_equal(old_acting, new_acting)) + return true; + + return false; +} + +/* + * Map an object into a PG. + * + * Should only be called with target_oid and target_oloc (as opposed to + * base_oid and base_oloc), since tiering isn't taken into account. + */ +void __ceph_object_locator_to_pg(struct ceph_pg_pool_info *pi, + const struct ceph_object_id *oid, + const struct ceph_object_locator *oloc, + struct ceph_pg *raw_pgid) +{ + WARN_ON(pi->id != oloc->pool); + + if (!oloc->pool_ns) { + raw_pgid->pool = oloc->pool; + raw_pgid->seed = ceph_str_hash(pi->object_hash, oid->name, + oid->name_len); + dout("%s %s -> raw_pgid %llu.%x\n", __func__, oid->name, + raw_pgid->pool, raw_pgid->seed); + } else { + char stack_buf[256]; + char *buf = stack_buf; + int nsl = oloc->pool_ns->len; + size_t total = nsl + 1 + oid->name_len; + + if (total > sizeof(stack_buf)) + buf = kmalloc(total, GFP_NOIO | __GFP_NOFAIL); + memcpy(buf, oloc->pool_ns->str, nsl); + buf[nsl] = '\037'; + memcpy(buf + nsl + 1, oid->name, oid->name_len); + raw_pgid->pool = oloc->pool; + raw_pgid->seed = ceph_str_hash(pi->object_hash, buf, total); + if (buf != stack_buf) + kfree(buf); + dout("%s %s ns %.*s -> raw_pgid %llu.%x\n", __func__, + oid->name, nsl, oloc->pool_ns->str, + raw_pgid->pool, raw_pgid->seed); + } +} + +int ceph_object_locator_to_pg(struct ceph_osdmap *osdmap, + const struct ceph_object_id *oid, + const struct ceph_object_locator *oloc, + struct ceph_pg *raw_pgid) +{ + struct ceph_pg_pool_info *pi; + + pi = ceph_pg_pool_by_id(osdmap, oloc->pool); + if (!pi) + return -ENOENT; + + __ceph_object_locator_to_pg(pi, oid, oloc, raw_pgid); + return 0; +} +EXPORT_SYMBOL(ceph_object_locator_to_pg); + +/* + * Map a raw PG (full precision ps) into an actual PG. + */ +static void raw_pg_to_pg(struct ceph_pg_pool_info *pi, + const struct ceph_pg *raw_pgid, + struct ceph_pg *pgid) +{ + pgid->pool = raw_pgid->pool; + pgid->seed = ceph_stable_mod(raw_pgid->seed, pi->pg_num, + pi->pg_num_mask); +} + +/* + * Map a raw PG (full precision ps) into a placement ps (placement + * seed). Include pool id in that value so that different pools don't + * use the same seeds. + */ +static u32 raw_pg_to_pps(struct ceph_pg_pool_info *pi, + const struct ceph_pg *raw_pgid) +{ + if (pi->flags & CEPH_POOL_FLAG_HASHPSPOOL) { + /* hash pool id and seed so that pool PGs do not overlap */ + return crush_hash32_2(CRUSH_HASH_RJENKINS1, + ceph_stable_mod(raw_pgid->seed, + pi->pgp_num, + pi->pgp_num_mask), + raw_pgid->pool); + } else { + /* + * legacy behavior: add ps and pool together. 
this is + * not a great approach because the PGs from each pool + * will overlap on top of each other: 0.5 == 1.4 == + * 2.3 == ... + */ + return ceph_stable_mod(raw_pgid->seed, pi->pgp_num, + pi->pgp_num_mask) + + (unsigned)raw_pgid->pool; + } +} + +/* + * Magic value used for a "default" fallback choose_args, used if the + * crush_choose_arg_map passed to do_crush() does not exist. If this + * also doesn't exist, fall back to canonical weights. + */ +#define CEPH_DEFAULT_CHOOSE_ARGS -1 + +static int do_crush(struct ceph_osdmap *map, int ruleno, int x, + int *result, int result_max, + const __u32 *weight, int weight_max, + s64 choose_args_index) +{ + struct crush_choose_arg_map *arg_map; + struct crush_work *work; + int r; + + BUG_ON(result_max > CEPH_PG_MAX_SIZE); + + arg_map = lookup_choose_arg_map(&map->crush->choose_args, + choose_args_index); + if (!arg_map) + arg_map = lookup_choose_arg_map(&map->crush->choose_args, + CEPH_DEFAULT_CHOOSE_ARGS); + + work = get_workspace(&map->crush_wsm, map->crush); + r = crush_do_rule(map->crush, ruleno, x, result, result_max, + weight, weight_max, work, + arg_map ? arg_map->args : NULL); + put_workspace(&map->crush_wsm, work); + return r; +} + +static void remove_nonexistent_osds(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + struct ceph_osds *set) +{ + int i; + + if (ceph_can_shift_osds(pi)) { + int removed = 0; + + /* shift left */ + for (i = 0; i < set->size; i++) { + if (!ceph_osd_exists(osdmap, set->osds[i])) { + removed++; + continue; + } + if (removed) + set->osds[i - removed] = set->osds[i]; + } + set->size -= removed; + } else { + /* set dne devices to NONE */ + for (i = 0; i < set->size; i++) { + if (!ceph_osd_exists(osdmap, set->osds[i])) + set->osds[i] = CRUSH_ITEM_NONE; + } + } +} + +/* + * Calculate raw set (CRUSH output) for given PG and filter out + * nonexistent OSDs. ->primary is undefined for a raw set. + * + * Placement seed (CRUSH input) is returned through @ppps. 
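A quick illustration of the placement seed discussed here: the sketch below is plain userspace C, not kernel code. stable_mod() is a local copy of what ceph_stable_mod() is assumed to compute, and mix2() is only a toy stand-in for crush_hash32_2(); with pgp_num = 8 it reproduces the collision the legacy branch warns about (0.5, 1.4 and 2.3 all land on the same seed) and shows that folding the pool id into the hash, as the HASHPSPOOL branch does, keeps the seeds distinct.

/* Illustrative only: userspace sketch with local stand-ins for the
 * kernel helpers. */
#include <stdio.h>
#include <stdint.h>

/* assumed equivalent of ceph_stable_mod() */
static int stable_mod(int x, int b, int bmask)
{
	return ((x & bmask) < b) ? (x & bmask) : (x & (bmask >> 1));
}

/* toy mixer standing in for crush_hash32_2(CRUSH_HASH_RJENKINS1, ...) */
static uint32_t mix2(uint32_t a, uint32_t b)
{
	uint32_t h = a * 2654435761u ^ b * 2246822519u;

	h ^= h >> 15;
	return h * 2654435761u;
}

int main(void)
{
	int pgp_num = 8, mask = 7;
	int pool, seed;

	/* legacy: pps = stable_mod(seed) + pool, so 0.5 == 1.4 == 2.3 */
	for (pool = 0, seed = 5; pool < 3; pool++, seed--)
		printf("legacy %d.%d -> pps %d\n", pool, seed,
		       stable_mod(seed, pgp_num, mask) + pool);

	/* HASHPSPOOL: the pool id is hashed in, so the seeds stay distinct */
	for (pool = 0, seed = 5; pool < 3; pool++, seed--)
		printf("hashed %d.%d -> pps %u\n", pool, seed,
		       (unsigned)mix2(stable_mod(seed, pgp_num, mask), pool));
	return 0;
}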
+ */ +static void pg_to_raw_osds(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + const struct ceph_pg *raw_pgid, + struct ceph_osds *raw, + u32 *ppps) +{ + u32 pps = raw_pg_to_pps(pi, raw_pgid); + int ruleno; + int len; + + ceph_osds_init(raw); + if (ppps) + *ppps = pps; + + ruleno = crush_find_rule(osdmap->crush, pi->crush_ruleset, pi->type, + pi->size); + if (ruleno < 0) { + pr_err("no crush rule: pool %lld ruleset %d type %d size %d\n", + pi->id, pi->crush_ruleset, pi->type, pi->size); + return; + } + + if (pi->size > ARRAY_SIZE(raw->osds)) { + pr_err_ratelimited("pool %lld ruleset %d type %d too wide: size %d > %zu\n", + pi->id, pi->crush_ruleset, pi->type, pi->size, + ARRAY_SIZE(raw->osds)); + return; + } + + len = do_crush(osdmap, ruleno, pps, raw->osds, pi->size, + osdmap->osd_weight, osdmap->max_osd, pi->id); + if (len < 0) { + pr_err("error %d from crush rule %d: pool %lld ruleset %d type %d size %d\n", + len, ruleno, pi->id, pi->crush_ruleset, pi->type, + pi->size); + return; + } + + raw->size = len; + remove_nonexistent_osds(osdmap, pi, raw); +} + +/* apply pg_upmap[_items] mappings */ +static void apply_upmap(struct ceph_osdmap *osdmap, + const struct ceph_pg *pgid, + struct ceph_osds *raw) +{ + struct ceph_pg_mapping *pg; + int i, j; + + pg = lookup_pg_mapping(&osdmap->pg_upmap, pgid); + if (pg) { + /* make sure targets aren't marked out */ + for (i = 0; i < pg->pg_upmap.len; i++) { + int osd = pg->pg_upmap.osds[i]; + + if (osd != CRUSH_ITEM_NONE && + osd < osdmap->max_osd && + osdmap->osd_weight[osd] == 0) { + /* reject/ignore explicit mapping */ + return; + } + } + for (i = 0; i < pg->pg_upmap.len; i++) + raw->osds[i] = pg->pg_upmap.osds[i]; + raw->size = pg->pg_upmap.len; + /* check and apply pg_upmap_items, if any */ + } + + pg = lookup_pg_mapping(&osdmap->pg_upmap_items, pgid); + if (pg) { + /* + * Note: this approach does not allow a bidirectional swap, + * e.g., [[1,2],[2,1]] applied to [0,1,2] -> [0,2,1]. + */ + for (i = 0; i < pg->pg_upmap_items.len; i++) { + int from = pg->pg_upmap_items.from_to[i][0]; + int to = pg->pg_upmap_items.from_to[i][1]; + int pos = -1; + bool exists = false; + + /* make sure replacement doesn't already appear */ + for (j = 0; j < raw->size; j++) { + int osd = raw->osds[j]; + + if (osd == to) { + exists = true; + break; + } + /* ignore mapping if target is marked out */ + if (osd == from && pos < 0 && + !(to != CRUSH_ITEM_NONE && + to < osdmap->max_osd && + osdmap->osd_weight[to] == 0)) { + pos = j; + } + } + if (!exists && pos >= 0) + raw->osds[pos] = to; + } + } +} + +/* + * Given raw set, calculate up set and up primary. By definition of an + * up set, the result won't contain nonexistent or down OSDs. + * + * This is done in-place - on return @set is the up set. If it's + * empty, ->primary will remain undefined. 
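Before the up-set calculation below, a note on apply_upmap() above: the pg_upmap_items loop is easier to trace with the osd_weight ("marked out") checks stripped out. The sketch below is a userspace re-implementation under that simplification; it shows a plain [from, to] replacement working, and confirms the note above that [[1,2],[2,1]] cannot turn [0,1,2] into [0,2,1], because each replacement target already appears in the set.

/* Illustrative only: the pg_upmap_items replacement loop with the
 * osd_weight checks removed. */
#include <stdio.h>

static void apply_items(int *osds, int size, int (*items)[2], int nitems)
{
	int i, j;

	for (i = 0; i < nitems; i++) {
		int from = items[i][0], to = items[i][1];
		int pos = -1, exists = 0;

		for (j = 0; j < size; j++) {
			if (osds[j] == to) {
				exists = 1;	/* target already in the set */
				break;
			}
			if (osds[j] == from && pos < 0)
				pos = j;
		}
		if (!exists && pos >= 0)
			osds[pos] = to;
	}
}

static void show(const char *tag, int *osds, int size)
{
	int i;

	printf("%s:", tag);
	for (i = 0; i < size; i++)
		printf(" %d", osds[i]);
	printf("\n");
}

int main(void)
{
	int a[] = { 0, 1, 2 }, b[] = { 0, 1, 2 };
	int simple[][2] = { { 1, 6 } };
	int swap[][2] = { { 1, 2 }, { 2, 1 } };

	apply_items(a, 3, simple, 1);
	show("[[1,6]]      ", a, 3);	/* -> 0 6 2 */

	apply_items(b, 3, swap, 2);
	show("[[1,2],[2,1]]", b, 3);	/* -> 0 1 2, i.e. no swap */
	return 0;
}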
+ */ +static void raw_to_up_osds(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + struct ceph_osds *set) +{ + int i; + + /* ->primary is undefined for a raw set */ + BUG_ON(set->primary != -1); + + if (ceph_can_shift_osds(pi)) { + int removed = 0; + + /* shift left */ + for (i = 0; i < set->size; i++) { + if (ceph_osd_is_down(osdmap, set->osds[i])) { + removed++; + continue; + } + if (removed) + set->osds[i - removed] = set->osds[i]; + } + set->size -= removed; + if (set->size > 0) + set->primary = set->osds[0]; + } else { + /* set down/dne devices to NONE */ + for (i = set->size - 1; i >= 0; i--) { + if (ceph_osd_is_down(osdmap, set->osds[i])) + set->osds[i] = CRUSH_ITEM_NONE; + else + set->primary = set->osds[i]; + } + } +} + +static void apply_primary_affinity(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + u32 pps, + struct ceph_osds *up) +{ + int i; + int pos = -1; + + /* + * Do we have any non-default primary_affinity values for these + * osds? + */ + if (!osdmap->osd_primary_affinity) + return; + + for (i = 0; i < up->size; i++) { + int osd = up->osds[i]; + + if (osd != CRUSH_ITEM_NONE && + osdmap->osd_primary_affinity[osd] != + CEPH_OSD_DEFAULT_PRIMARY_AFFINITY) { + break; + } + } + if (i == up->size) + return; + + /* + * Pick the primary. Feed both the seed (for the pg) and the + * osd into the hash/rng so that a proportional fraction of an + * osd's pgs get rejected as primary. + */ + for (i = 0; i < up->size; i++) { + int osd = up->osds[i]; + u32 aff; + + if (osd == CRUSH_ITEM_NONE) + continue; + + aff = osdmap->osd_primary_affinity[osd]; + if (aff < CEPH_OSD_MAX_PRIMARY_AFFINITY && + (crush_hash32_2(CRUSH_HASH_RJENKINS1, + pps, osd) >> 16) >= aff) { + /* + * We chose not to use this primary. Note it + * anyway as a fallback in case we don't pick + * anyone else, but keep looking. + */ + if (pos < 0) + pos = i; + } else { + pos = i; + break; + } + } + if (pos < 0) + return; + + up->primary = up->osds[pos]; + + if (ceph_can_shift_osds(pi) && pos > 0) { + /* move the new primary to the front */ + for (i = pos; i > 0; i--) + up->osds[i] = up->osds[i - 1]; + up->osds[0] = up->primary; + } +} + +/* + * Get pg_temp and primary_temp mappings for given PG. + * + * Note that a PG may have none, only pg_temp, only primary_temp or + * both pg_temp and primary_temp mappings. This means @temp isn't + * always a valid OSD set on return: in the "only primary_temp" case, + * @temp will have its ->primary >= 0 but ->size == 0. + */ +static void get_temp_osds(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + const struct ceph_pg *pgid, + struct ceph_osds *temp) +{ + struct ceph_pg_mapping *pg; + int i; + + ceph_osds_init(temp); + + /* pg_temp? */ + pg = lookup_pg_mapping(&osdmap->pg_temp, pgid); + if (pg) { + for (i = 0; i < pg->pg_temp.len; i++) { + if (ceph_osd_is_down(osdmap, pg->pg_temp.osds[i])) { + if (ceph_can_shift_osds(pi)) + continue; + + temp->osds[temp->size++] = CRUSH_ITEM_NONE; + } else { + temp->osds[temp->size++] = pg->pg_temp.osds[i]; + } + } + + /* apply pg_temp's primary */ + for (i = 0; i < temp->size; i++) { + if (temp->osds[i] != CRUSH_ITEM_NONE) { + temp->primary = temp->osds[i]; + break; + } + } + } + + /* primary_temp? */ + pg = lookup_pg_mapping(&osdmap->primary_temp, pgid); + if (pg) + temp->primary = pg->primary_temp.osd; +} + +/* + * Map a PG to its acting set as well as its up set. 
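On apply_primary_affinity() above: the test (crush_hash32_2(pps, osd) >> 16) >= aff rejects an OSD as primary for a fraction of its PGs proportional to how far its affinity sits below the assumed maximum of 0x10000. The userspace sketch below estimates that fraction with mix2() as a toy stand-in for the real hash, so the percentages it prints are only indicative.

/* Illustrative only: estimates how often the affinity test keeps an OSD
 * as primary, with mix2() standing in for crush_hash32_2(). */
#include <stdio.h>
#include <stdint.h>

#define MAX_AFF	0x10000	/* assumed CEPH_OSD_MAX_PRIMARY_AFFINITY */
#define TRIALS	100000

static uint32_t mix2(uint32_t a, uint32_t b)
{
	uint32_t h = a * 2654435761u ^ b * 2246822519u;

	h ^= h >> 15;
	h *= 2654435761u;
	return h ^ (h >> 13);
}

int main(void)
{
	uint32_t affs[] = { MAX_AFF, MAX_AFF / 2, MAX_AFF / 4 };
	uint32_t pps;
	int osd = 3;
	int i, kept;

	for (i = 0; i < 3; i++) {
		kept = 0;
		/* same shape as apply_primary_affinity(): the OSD is rejected
		 * as primary when the top 16 hash bits are >= its affinity */
		for (pps = 0; pps < TRIALS; pps++)
			if (affs[i] == MAX_AFF ||
			    (mix2(pps, osd) >> 16) < affs[i])
				kept++;
		printf("aff 0x%05x -> kept as primary for %.1f%% of PGs\n",
		       (unsigned)affs[i], 100.0 * kept / TRIALS);
	}
	return 0;
}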
+ * + * Acting set is used for data mapping purposes, while up set can be + * recorded for detecting interval changes and deciding whether to + * resend a request. + */ +void ceph_pg_to_up_acting_osds(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + const struct ceph_pg *raw_pgid, + struct ceph_osds *up, + struct ceph_osds *acting) +{ + struct ceph_pg pgid; + u32 pps; + + WARN_ON(pi->id != raw_pgid->pool); + raw_pg_to_pg(pi, raw_pgid, &pgid); + + pg_to_raw_osds(osdmap, pi, raw_pgid, up, &pps); + apply_upmap(osdmap, &pgid, up); + raw_to_up_osds(osdmap, pi, up); + apply_primary_affinity(osdmap, pi, pps, up); + get_temp_osds(osdmap, pi, &pgid, acting); + if (!acting->size) { + memcpy(acting->osds, up->osds, up->size * sizeof(up->osds[0])); + acting->size = up->size; + if (acting->primary == -1) + acting->primary = up->primary; + } + WARN_ON(!osds_valid(up) || !osds_valid(acting)); +} + +bool ceph_pg_to_primary_shard(struct ceph_osdmap *osdmap, + struct ceph_pg_pool_info *pi, + const struct ceph_pg *raw_pgid, + struct ceph_spg *spgid) +{ + struct ceph_pg pgid; + struct ceph_osds up, acting; + int i; + + WARN_ON(pi->id != raw_pgid->pool); + raw_pg_to_pg(pi, raw_pgid, &pgid); + + if (ceph_can_shift_osds(pi)) { + spgid->pgid = pgid; /* struct */ + spgid->shard = CEPH_SPG_NOSHARD; + return true; + } + + ceph_pg_to_up_acting_osds(osdmap, pi, &pgid, &up, &acting); + for (i = 0; i < acting.size; i++) { + if (acting.osds[i] == acting.primary) { + spgid->pgid = pgid; /* struct */ + spgid->shard = i; + return true; + } + } + + return false; +} + +/* + * Return acting primary for given PG, or -1 if none. + */ +int ceph_pg_to_acting_primary(struct ceph_osdmap *osdmap, + const struct ceph_pg *raw_pgid) +{ + struct ceph_pg_pool_info *pi; + struct ceph_osds up, acting; + + pi = ceph_pg_pool_by_id(osdmap, raw_pgid->pool); + if (!pi) + return -1; + + ceph_pg_to_up_acting_osds(osdmap, pi, raw_pgid, &up, &acting); + return acting.primary; +} +EXPORT_SYMBOL(ceph_pg_to_acting_primary); + +static struct crush_loc_node *alloc_crush_loc(size_t type_name_len, + size_t name_len) +{ + struct crush_loc_node *loc; + + loc = kmalloc(sizeof(*loc) + type_name_len + name_len + 2, GFP_NOIO); + if (!loc) + return NULL; + + RB_CLEAR_NODE(&loc->cl_node); + return loc; +} + +static void free_crush_loc(struct crush_loc_node *loc) +{ + WARN_ON(!RB_EMPTY_NODE(&loc->cl_node)); + + kfree(loc); +} + +static int crush_loc_compare(const struct crush_loc *loc1, + const struct crush_loc *loc2) +{ + return strcmp(loc1->cl_type_name, loc2->cl_type_name) ?: + strcmp(loc1->cl_name, loc2->cl_name); +} + +DEFINE_RB_FUNCS2(crush_loc, struct crush_loc_node, cl_loc, crush_loc_compare, + RB_BYPTR, const struct crush_loc *, cl_node) + +/* + * Parses a set of ':' pairs separated + * by '|', e.g. "rack:foo1|rack:foo2|datacenter:bar". + * + * Note that @crush_location is modified by strsep(). 
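The sketch below is a standalone userspace version of the split that ceph_parse_crush_location() performs next: strsep() walks the '|'-separated pairs and strchr() splits each pair at the first ':', with the same empty-type/empty-name rejection; it only prints the pairs instead of inserting crush_loc_node entries.

/* Illustrative only: standalone version of the '|' / ':' split. */
#define _GNU_SOURCE	/* for strsep() in userspace */
#include <stdio.h>
#include <string.h>

int main(void)
{
	char buf[] = "rack:foo1|rack:foo2|datacenter:bar";
	char *p = buf;	/* strsep() modifies the string, as noted above */
	char *tok;

	while ((tok = strsep(&p, "|"))) {
		char *colon = strchr(tok, ':');

		/* reject a missing ':', an empty type name or an empty name */
		if (!colon || colon == tok || !colon[1]) {
			fprintf(stderr, "bad pair '%s'\n", tok);
			return 1;
		}
		*colon = '\0';
		printf("type '%s' name '%s'\n", tok, colon + 1);
	}
	return 0;
}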
+ */ +int ceph_parse_crush_location(char *crush_location, struct rb_root *locs) +{ + struct crush_loc_node *loc; + const char *type_name, *name, *colon; + size_t type_name_len, name_len; + + dout("%s '%s'\n", __func__, crush_location); + while ((type_name = strsep(&crush_location, "|"))) { + colon = strchr(type_name, ':'); + if (!colon) + return -EINVAL; + + type_name_len = colon - type_name; + if (type_name_len == 0) + return -EINVAL; + + name = colon + 1; + name_len = strlen(name); + if (name_len == 0) + return -EINVAL; + + loc = alloc_crush_loc(type_name_len, name_len); + if (!loc) + return -ENOMEM; + + loc->cl_loc.cl_type_name = loc->cl_data; + memcpy(loc->cl_loc.cl_type_name, type_name, type_name_len); + loc->cl_loc.cl_type_name[type_name_len] = '\0'; + + loc->cl_loc.cl_name = loc->cl_data + type_name_len + 1; + memcpy(loc->cl_loc.cl_name, name, name_len); + loc->cl_loc.cl_name[name_len] = '\0'; + + if (!__insert_crush_loc(locs, loc)) { + free_crush_loc(loc); + return -EEXIST; + } + + dout("%s type_name '%s' name '%s'\n", __func__, + loc->cl_loc.cl_type_name, loc->cl_loc.cl_name); + } + + return 0; +} + +int ceph_compare_crush_locs(struct rb_root *locs1, struct rb_root *locs2) +{ + struct rb_node *n1 = rb_first(locs1); + struct rb_node *n2 = rb_first(locs2); + int ret; + + for ( ; n1 && n2; n1 = rb_next(n1), n2 = rb_next(n2)) { + struct crush_loc_node *loc1 = + rb_entry(n1, struct crush_loc_node, cl_node); + struct crush_loc_node *loc2 = + rb_entry(n2, struct crush_loc_node, cl_node); + + ret = crush_loc_compare(&loc1->cl_loc, &loc2->cl_loc); + if (ret) + return ret; + } + + if (!n1 && n2) + return -1; + if (n1 && !n2) + return 1; + return 0; +} + +void ceph_clear_crush_locs(struct rb_root *locs) +{ + while (!RB_EMPTY_ROOT(locs)) { + struct crush_loc_node *loc = + rb_entry(rb_first(locs), struct crush_loc_node, cl_node); + + erase_crush_loc(locs, loc); + free_crush_loc(loc); + } +} + +/* + * [a-zA-Z0-9-_.]+ + */ +static bool is_valid_crush_name(const char *name) +{ + do { + if (!('a' <= *name && *name <= 'z') && + !('A' <= *name && *name <= 'Z') && + !('0' <= *name && *name <= '9') && + *name != '-' && *name != '_' && *name != '.') + return false; + } while (*++name != '\0'); + + return true; +} + +/* + * Gets the parent of an item. Returns its id (<0 because the + * parent is always a bucket), type id (>0 for the same reason, + * via @parent_type_id) and location (via @parent_loc). If no + * parent, returns 0. + * + * Does a linear search, as there are no parent pointers of any + * kind. Note that the result is ambiguous for items that occur + * multiple times in the map. + */ +static int get_immediate_parent(struct crush_map *c, int id, + u16 *parent_type_id, + struct crush_loc *parent_loc) +{ + struct crush_bucket *b; + struct crush_name_node *type_cn, *cn; + int i, j; + + for (i = 0; i < c->max_buckets; i++) { + b = c->buckets[i]; + if (!b) + continue; + + /* ignore per-class shadow hierarchy */ + cn = lookup_crush_name(&c->names, b->id); + if (!cn || !is_valid_crush_name(cn->cn_name)) + continue; + + for (j = 0; j < b->size; j++) { + if (b->items[j] != id) + continue; + + *parent_type_id = b->type; + type_cn = lookup_crush_name(&c->type_names, b->type); + parent_loc->cl_type_name = type_cn->cn_name; + parent_loc->cl_name = cn->cn_name; + return b->id; + } + } + + return 0; /* no parent */ +} + +/* + * Calculates the locality/distance from an item to a client + * location expressed in terms of CRUSH hierarchy as a set of + * (bucket type name, bucket name) pairs. 
Specifically, looks + * for the lowest-valued bucket type for which the location of + * @id matches one of the locations in @locs, so for standard + * bucket types (host = 1, rack = 3, datacenter = 8, zone = 9) + * a matching host is closer than a matching rack and a matching + * data center is closer than a matching zone. + * + * Specifying multiple locations (a "multipath" location) such + * as "rack=foo1 rack=foo2 datacenter=bar" is allowed -- @locs + * is a multimap. The locality will be: + * + * - 3 for OSDs in racks foo1 and foo2 + * - 8 for OSDs in data center bar + * - -1 for all other OSDs + * + * The lowest possible bucket type is 1, so the best locality + * for an OSD is 1 (i.e. a matching host). Locality 0 would be + * the OSD itself. + */ +int ceph_get_crush_locality(struct ceph_osdmap *osdmap, int id, + struct rb_root *locs) +{ + struct crush_loc loc; + u16 type_id; + + /* + * Instead of repeated get_immediate_parent() calls, + * the location of @id could be obtained with a single + * depth-first traversal. + */ + for (;;) { + id = get_immediate_parent(osdmap->crush, id, &type_id, &loc); + if (id >= 0) + return -1; /* not local */ + + if (lookup_crush_loc(locs, &loc)) + return type_id; + } +} diff --git a/net/ceph/pagelist.c b/net/ceph/pagelist.c new file mode 100644 index 000000000..74622b278 --- /dev/null +++ b/net/ceph/pagelist.c @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include + +struct ceph_pagelist *ceph_pagelist_alloc(gfp_t gfp_flags) +{ + struct ceph_pagelist *pl; + + pl = kmalloc(sizeof(*pl), gfp_flags); + if (!pl) + return NULL; + + INIT_LIST_HEAD(&pl->head); + pl->mapped_tail = NULL; + pl->length = 0; + pl->room = 0; + INIT_LIST_HEAD(&pl->free_list); + pl->num_pages_free = 0; + refcount_set(&pl->refcnt, 1); + + return pl; +} +EXPORT_SYMBOL(ceph_pagelist_alloc); + +static void ceph_pagelist_unmap_tail(struct ceph_pagelist *pl) +{ + if (pl->mapped_tail) { + struct page *page = list_entry(pl->head.prev, struct page, lru); + kunmap(page); + pl->mapped_tail = NULL; + } +} + +void ceph_pagelist_release(struct ceph_pagelist *pl) +{ + if (!refcount_dec_and_test(&pl->refcnt)) + return; + ceph_pagelist_unmap_tail(pl); + while (!list_empty(&pl->head)) { + struct page *page = list_first_entry(&pl->head, struct page, + lru); + list_del(&page->lru); + __free_page(page); + } + ceph_pagelist_free_reserve(pl); + kfree(pl); +} +EXPORT_SYMBOL(ceph_pagelist_release); + +static int ceph_pagelist_addpage(struct ceph_pagelist *pl) +{ + struct page *page; + + if (!pl->num_pages_free) { + page = __page_cache_alloc(GFP_NOFS); + } else { + page = list_first_entry(&pl->free_list, struct page, lru); + list_del(&page->lru); + --pl->num_pages_free; + } + if (!page) + return -ENOMEM; + pl->room += PAGE_SIZE; + ceph_pagelist_unmap_tail(pl); + list_add_tail(&page->lru, &pl->head); + pl->mapped_tail = kmap(page); + return 0; +} + +int ceph_pagelist_append(struct ceph_pagelist *pl, const void *buf, size_t len) +{ + while (pl->room < len) { + size_t bit = pl->room; + int ret; + + memcpy(pl->mapped_tail + (pl->length & ~PAGE_MASK), + buf, bit); + pl->length += bit; + pl->room -= bit; + buf += bit; + len -= bit; + ret = ceph_pagelist_addpage(pl); + if (ret) + return ret; + } + + memcpy(pl->mapped_tail + (pl->length & ~PAGE_MASK), buf, len); + pl->length += len; + pl->room -= len; + return 0; +} +EXPORT_SYMBOL(ceph_pagelist_append); + +/* Allocate enough pages for a pagelist to append the given amount + * of data without allocating. 
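As a usage note for the pagelist API above: a typical caller allocates a pagelist, optionally reserves room up front so later appends cannot fail with -ENOMEM, appends its data and drops the reference when done. The kernel-context sketch below (not a standalone program; build_payload() and the buffers are made up for illustration) strings those calls together.

/* Kernel-context sketch; build_payload() and the buffers are invented. */
#include <linux/ceph/pagelist.h>
#include <linux/errno.h>
#include <linux/gfp.h>

static int build_payload(struct ceph_pagelist **ppl)
{
	static const char hdr[] = "hdr";
	static const char body[] = "body bytes";
	struct ceph_pagelist *pl;
	int ret;

	pl = ceph_pagelist_alloc(GFP_NOFS);
	if (!pl)
		return -ENOMEM;

	/* optional: pre-allocate pages so the appends cannot hit -ENOMEM */
	ret = ceph_pagelist_reserve(pl, sizeof(hdr) + sizeof(body));
	if (ret)
		goto err;

	ret = ceph_pagelist_append(pl, hdr, sizeof(hdr));
	if (!ret)
		ret = ceph_pagelist_append(pl, body, sizeof(body));
	if (ret)
		goto err;

	*ppl = pl;
	return 0;

err:
	ceph_pagelist_release(pl);	/* drops the single reference */
	return ret;
}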
+ * Returns: 0 on success, -ENOMEM on error. + */ +int ceph_pagelist_reserve(struct ceph_pagelist *pl, size_t space) +{ + if (space <= pl->room) + return 0; + space -= pl->room; + space = (space + PAGE_SIZE - 1) >> PAGE_SHIFT; /* conv to num pages */ + + while (space > pl->num_pages_free) { + struct page *page = __page_cache_alloc(GFP_NOFS); + if (!page) + return -ENOMEM; + list_add_tail(&page->lru, &pl->free_list); + ++pl->num_pages_free; + } + return 0; +} +EXPORT_SYMBOL(ceph_pagelist_reserve); + +/* Free any pages that have been preallocated. */ +int ceph_pagelist_free_reserve(struct ceph_pagelist *pl) +{ + while (!list_empty(&pl->free_list)) { + struct page *page = list_first_entry(&pl->free_list, + struct page, lru); + list_del(&page->lru); + __free_page(page); + --pl->num_pages_free; + } + BUG_ON(pl->num_pages_free); + return 0; +} +EXPORT_SYMBOL(ceph_pagelist_free_reserve); + +/* Create a truncation point. */ +void ceph_pagelist_set_cursor(struct ceph_pagelist *pl, + struct ceph_pagelist_cursor *c) +{ + c->pl = pl; + c->page_lru = pl->head.prev; + c->room = pl->room; +} +EXPORT_SYMBOL(ceph_pagelist_set_cursor); + +/* Truncate a pagelist to the given point. Move extra pages to reserve. + * This won't sleep. + * Returns: 0 on success, + * -EINVAL if the pagelist doesn't match the trunc point pagelist + */ +int ceph_pagelist_truncate(struct ceph_pagelist *pl, + struct ceph_pagelist_cursor *c) +{ + struct page *page; + + if (pl != c->pl) + return -EINVAL; + ceph_pagelist_unmap_tail(pl); + while (pl->head.prev != c->page_lru) { + page = list_entry(pl->head.prev, struct page, lru); + /* move from pagelist to reserve */ + list_move_tail(&page->lru, &pl->free_list); + ++pl->num_pages_free; + } + pl->room = c->room; + if (!list_empty(&pl->head)) { + page = list_entry(pl->head.prev, struct page, lru); + pl->mapped_tail = kmap(page); + } + return 0; +} +EXPORT_SYMBOL(ceph_pagelist_truncate); diff --git a/net/ceph/pagevec.c b/net/ceph/pagevec.c new file mode 100644 index 000000000..64305e705 --- /dev/null +++ b/net/ceph/pagevec.c @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include +#include +#include +#include +#include +#include + +#include + +void ceph_put_page_vector(struct page **pages, int num_pages, bool dirty) +{ + int i; + + for (i = 0; i < num_pages; i++) { + if (dirty) + set_page_dirty_lock(pages[i]); + put_page(pages[i]); + } + kvfree(pages); +} +EXPORT_SYMBOL(ceph_put_page_vector); + +void ceph_release_page_vector(struct page **pages, int num_pages) +{ + int i; + + for (i = 0; i < num_pages; i++) + __free_pages(pages[i], 0); + kfree(pages); +} +EXPORT_SYMBOL(ceph_release_page_vector); + +/* + * allocate a vector new pages + */ +struct page **ceph_alloc_page_vector(int num_pages, gfp_t flags) +{ + struct page **pages; + int i; + + pages = kmalloc_array(num_pages, sizeof(*pages), flags); + if (!pages) + return ERR_PTR(-ENOMEM); + for (i = 0; i < num_pages; i++) { + pages[i] = __page_cache_alloc(flags); + if (pages[i] == NULL) { + ceph_release_page_vector(pages, i); + return ERR_PTR(-ENOMEM); + } + } + return pages; +} +EXPORT_SYMBOL(ceph_alloc_page_vector); + +/* + * copy user data into a page vector + */ +int ceph_copy_user_to_page_vector(struct page **pages, + const void __user *data, + loff_t off, size_t len) +{ + int i = 0; + int po = off & ~PAGE_MASK; + int left = len; + int l, bad; + + while (left > 0) { + l = min_t(int, PAGE_SIZE-po, left); + bad = copy_from_user(page_address(pages[i]) + po, data, l); + if (bad == l) + return -EFAULT; + data += l - bad; 
+ left -= l - bad; + po += l - bad; + if (po == PAGE_SIZE) { + po = 0; + i++; + } + } + return len; +} +EXPORT_SYMBOL(ceph_copy_user_to_page_vector); + +void ceph_copy_to_page_vector(struct page **pages, + const void *data, + loff_t off, size_t len) +{ + int i = 0; + size_t po = off & ~PAGE_MASK; + size_t left = len; + + while (left > 0) { + size_t l = min_t(size_t, PAGE_SIZE-po, left); + + memcpy(page_address(pages[i]) + po, data, l); + data += l; + left -= l; + po += l; + if (po == PAGE_SIZE) { + po = 0; + i++; + } + } +} +EXPORT_SYMBOL(ceph_copy_to_page_vector); + +void ceph_copy_from_page_vector(struct page **pages, + void *data, + loff_t off, size_t len) +{ + int i = 0; + size_t po = off & ~PAGE_MASK; + size_t left = len; + + while (left > 0) { + size_t l = min_t(size_t, PAGE_SIZE-po, left); + + memcpy(data, page_address(pages[i]) + po, l); + data += l; + left -= l; + po += l; + if (po == PAGE_SIZE) { + po = 0; + i++; + } + } +} +EXPORT_SYMBOL(ceph_copy_from_page_vector); + +/* + * Zero an extent within a page vector. Offset is relative to the + * start of the first page. + */ +void ceph_zero_page_vector_range(int off, int len, struct page **pages) +{ + int i = off >> PAGE_SHIFT; + + off &= ~PAGE_MASK; + + dout("zero_page_vector_page %u~%u\n", off, len); + + /* leading partial page? */ + if (off) { + int end = min((int)PAGE_SIZE, off + len); + dout("zeroing %d %p head from %d\n", i, pages[i], + (int)off); + zero_user_segment(pages[i], off, end); + len -= (end - off); + i++; + } + while (len >= PAGE_SIZE) { + dout("zeroing %d %p len=%d\n", i, pages[i], len); + zero_user_segment(pages[i], 0, PAGE_SIZE); + len -= PAGE_SIZE; + i++; + } + /* trailing partial page? */ + if (len) { + dout("zeroing %d %p tail to %d\n", i, pages[i], (int)len); + zero_user_segment(pages[i], 0, len); + } +} +EXPORT_SYMBOL(ceph_zero_page_vector_range); diff --git a/net/ceph/snapshot.c b/net/ceph/snapshot.c new file mode 100644 index 000000000..e24315937 --- /dev/null +++ b/net/ceph/snapshot.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * snapshot.c Ceph snapshot context utility routines (part of libceph) + * + * Copyright (C) 2013 Inktank Storage, Inc. + */ + +#include +#include +#include + +/* + * Ceph snapshot contexts are reference counted objects, and the + * returned structure holds a single reference. Acquire additional + * references with ceph_get_snap_context(), and release them with + * ceph_put_snap_context(). When the reference count reaches zero + * the entire structure is freed. + */ + +/* + * Create a new ceph snapshot context large enough to hold the + * indicated number of snapshot ids (which can be 0). Caller has + * to fill in snapc->seq and snapc->snaps[0..snap_count-1]. + * + * Returns a null pointer if an error occurs. 
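A kernel-context usage sketch for the snapshot-context helpers below: make_ctx(), the snapshot ids and the choice of <linux/ceph/libceph.h> as the header are assumptions for illustration. It creates a context for two snapshots, fills in seq and the ids as the comment above requires, and exercises the get/put reference pair.

/* Kernel-context sketch; the snapshot ids are made up. */
#include <linux/ceph/libceph.h>
#include <linux/gfp.h>

static struct ceph_snap_context *make_ctx(void)
{
	struct ceph_snap_context *snapc;

	snapc = ceph_create_snap_context(2, GFP_KERNEL);
	if (!snapc)
		return NULL;

	/* caller fills in seq and the snap ids, per the comment above */
	snapc->seq = 10;
	snapc->snaps[0] = 10;
	snapc->snaps[1] = 8;

	ceph_get_snap_context(snapc);	/* refcount 1 -> 2 */
	ceph_put_snap_context(snapc);	/* back to 1, context still alive */

	return snapc;	/* caller releases with ceph_put_snap_context() */
}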
+ */ +struct ceph_snap_context *ceph_create_snap_context(u32 snap_count, + gfp_t gfp_flags) +{ + struct ceph_snap_context *snapc; + size_t size; + + size = sizeof (struct ceph_snap_context); + size += snap_count * sizeof (snapc->snaps[0]); + snapc = kzalloc(size, gfp_flags); + if (!snapc) + return NULL; + + refcount_set(&snapc->nref, 1); + snapc->num_snaps = snap_count; + + return snapc; +} +EXPORT_SYMBOL(ceph_create_snap_context); + +struct ceph_snap_context *ceph_get_snap_context(struct ceph_snap_context *sc) +{ + if (sc) + refcount_inc(&sc->nref); + return sc; +} +EXPORT_SYMBOL(ceph_get_snap_context); + +void ceph_put_snap_context(struct ceph_snap_context *sc) +{ + if (!sc) + return; + if (refcount_dec_and_test(&sc->nref)) { + /*printk(" deleting snap_context %p\n", sc);*/ + kfree(sc); + } +} +EXPORT_SYMBOL(ceph_put_snap_context); diff --git a/net/ceph/string_table.c b/net/ceph/string_table.c new file mode 100644 index 000000000..3191d9d16 --- /dev/null +++ b/net/ceph/string_table.c @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(string_tree_lock); +static struct rb_root string_tree = RB_ROOT; + +struct ceph_string *ceph_find_or_create_string(const char* str, size_t len) +{ + struct ceph_string *cs, *exist; + struct rb_node **p, *parent; + int ret; + + exist = NULL; + spin_lock(&string_tree_lock); + p = &string_tree.rb_node; + while (*p) { + exist = rb_entry(*p, struct ceph_string, node); + ret = ceph_compare_string(exist, str, len); + if (ret > 0) + p = &(*p)->rb_left; + else if (ret < 0) + p = &(*p)->rb_right; + else + break; + exist = NULL; + } + if (exist && !kref_get_unless_zero(&exist->kref)) { + rb_erase(&exist->node, &string_tree); + RB_CLEAR_NODE(&exist->node); + exist = NULL; + } + spin_unlock(&string_tree_lock); + if (exist) + return exist; + + cs = kmalloc(sizeof(*cs) + len + 1, GFP_NOFS); + if (!cs) + return NULL; + + kref_init(&cs->kref); + cs->len = len; + memcpy(cs->str, str, len); + cs->str[len] = 0; + +retry: + exist = NULL; + parent = NULL; + p = &string_tree.rb_node; + spin_lock(&string_tree_lock); + while (*p) { + parent = *p; + exist = rb_entry(*p, struct ceph_string, node); + ret = ceph_compare_string(exist, str, len); + if (ret > 0) + p = &(*p)->rb_left; + else if (ret < 0) + p = &(*p)->rb_right; + else + break; + exist = NULL; + } + ret = 0; + if (!exist) { + rb_link_node(&cs->node, parent, p); + rb_insert_color(&cs->node, &string_tree); + } else if (!kref_get_unless_zero(&exist->kref)) { + rb_erase(&exist->node, &string_tree); + RB_CLEAR_NODE(&exist->node); + ret = -EAGAIN; + } + spin_unlock(&string_tree_lock); + if (ret == -EAGAIN) + goto retry; + + if (exist) { + kfree(cs); + cs = exist; + } + + return cs; +} +EXPORT_SYMBOL(ceph_find_or_create_string); + +void ceph_release_string(struct kref *ref) +{ + struct ceph_string *cs = container_of(ref, struct ceph_string, kref); + + spin_lock(&string_tree_lock); + if (!RB_EMPTY_NODE(&cs->node)) { + rb_erase(&cs->node, &string_tree); + RB_CLEAR_NODE(&cs->node); + } + spin_unlock(&string_tree_lock); + + kfree_rcu(cs, rcu); +} +EXPORT_SYMBOL(ceph_release_string); + +bool ceph_strings_empty(void) +{ + return RB_EMPTY_ROOT(&string_tree); +} diff --git a/net/ceph/striper.c b/net/ceph/striper.c new file mode 100644 index 000000000..3b3fa75d1 --- /dev/null +++ b/net/ceph/striper.c @@ -0,0 +1,278 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#include + +#include +#include + +#include +#include + +/* + * Map a file extent to a stripe unit within 
an object. + * Fill in objno, offset into object, and object extent length (i.e. the + * number of bytes mapped, less than or equal to @l->stripe_unit). + * + * Example for stripe_count = 3, stripes_per_object = 4: + * + * blockno | 0 3 6 9 | 1 4 7 10 | 2 5 8 11 | 12 15 18 21 | 13 16 19 + * stripeno | 0 1 2 3 | 0 1 2 3 | 0 1 2 3 | 4 5 6 7 | 4 5 6 + * stripepos | 0 | 1 | 2 | 0 | 1 + * objno | 0 | 1 | 2 | 3 | 4 + * objsetno | 0 | 1 + */ +void ceph_calc_file_object_mapping(struct ceph_file_layout *l, + u64 off, u64 len, + u64 *objno, u64 *objoff, u32 *xlen) +{ + u32 stripes_per_object = l->object_size / l->stripe_unit; + u64 blockno; /* which su in the file (i.e. globally) */ + u32 blockoff; /* offset into su */ + u64 stripeno; /* which stripe */ + u32 stripepos; /* which su in the stripe, + which object in the object set */ + u64 objsetno; /* which object set */ + u32 objsetpos; /* which stripe in the object set */ + + blockno = div_u64_rem(off, l->stripe_unit, &blockoff); + stripeno = div_u64_rem(blockno, l->stripe_count, &stripepos); + objsetno = div_u64_rem(stripeno, stripes_per_object, &objsetpos); + + *objno = objsetno * l->stripe_count + stripepos; + *objoff = objsetpos * l->stripe_unit + blockoff; + *xlen = min_t(u64, len, l->stripe_unit - blockoff); +} +EXPORT_SYMBOL(ceph_calc_file_object_mapping); + +/* + * Return the last extent with given objno (@object_extents is sorted + * by objno). If not found, return NULL and set @add_pos so that the + * new extent can be added with list_add(add_pos, new_ex). + */ +static struct ceph_object_extent * +lookup_last(struct list_head *object_extents, u64 objno, + struct list_head **add_pos) +{ + struct list_head *pos; + + list_for_each_prev(pos, object_extents) { + struct ceph_object_extent *ex = + list_entry(pos, typeof(*ex), oe_item); + + if (ex->oe_objno == objno) + return ex; + + if (ex->oe_objno < objno) + break; + } + + *add_pos = pos; + return NULL; +} + +static struct ceph_object_extent * +lookup_containing(struct list_head *object_extents, u64 objno, + u64 objoff, u32 xlen) +{ + struct ceph_object_extent *ex; + + list_for_each_entry(ex, object_extents, oe_item) { + if (ex->oe_objno == objno && + ex->oe_off <= objoff && + ex->oe_off + ex->oe_len >= objoff + xlen) /* paranoia */ + return ex; + + if (ex->oe_objno > objno) + break; + } + + return NULL; +} + +/* + * Map a file extent to a sorted list of object extents. + * + * We want only one (or as few as possible) object extents per object. + * Adjacent object extents will be merged together, each returned object + * extent may reverse map to multiple different file extents. + * + * Call @alloc_fn for each new object extent and @action_fn for each + * mapped stripe unit, whether it was merged into an already allocated + * object extent or started a new object extent. + * + * Newly allocated object extents are added to @object_extents. + * To keep @object_extents sorted, successive calls to this function + * must map successive file extents (i.e. the list of file extents that + * are mapped using the same @object_extents must be sorted). + * + * The caller is responsible for @object_extents. 
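The striping arithmetic in ceph_calc_file_object_mapping() above can be checked by hand. The standalone sketch below repeats the same divisions and remainders in userspace for the layout used in the table (stripe_count = 3, stripes_per_object = 4, with an assumed 4 KiB stripe unit) and an offset inside block 7, which lands in objno 1 exactly as the table shows.

/* Illustrative only: userspace copy of the mapping arithmetic for
 * stripe_unit = 4096, stripe_count = 3, object_size = 16384. */
#include <stdio.h>
#include <stdint.h>

int main(void)
{
	uint32_t stripe_unit = 4096, stripe_count = 3, object_size = 16384;
	uint32_t stripes_per_object = object_size / stripe_unit;	/* 4 */
	uint64_t off = 7 * 4096 + 100, len = 8192;	/* 100 bytes into block 7 */

	uint64_t blockno = off / stripe_unit;			/* 7 */
	uint32_t blockoff = off % stripe_unit;			/* 100 */
	uint64_t stripeno = blockno / stripe_count;		/* 2 */
	uint32_t stripepos = blockno % stripe_count;		/* 1 */
	uint64_t objsetno = stripeno / stripes_per_object;	/* 0 */
	uint32_t objsetpos = stripeno % stripes_per_object;	/* 2 */

	uint64_t objno = objsetno * stripe_count + stripepos;	/* 1, as in the table */
	uint64_t objoff = objsetpos * stripe_unit + blockoff;	/* 8292 */
	uint64_t xlen = len < stripe_unit - blockoff ? len : stripe_unit - blockoff;

	printf("objno %llu objoff %llu xlen %llu\n",
	       (unsigned long long)objno, (unsigned long long)objoff,
	       (unsigned long long)xlen);
	return 0;
}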
+ */ +int ceph_file_to_extents(struct ceph_file_layout *l, u64 off, u64 len, + struct list_head *object_extents, + struct ceph_object_extent *alloc_fn(void *arg), + void *alloc_arg, + ceph_object_extent_fn_t action_fn, + void *action_arg) +{ + struct ceph_object_extent *last_ex, *ex; + + while (len) { + struct list_head *add_pos = NULL; + u64 objno, objoff; + u32 xlen; + + ceph_calc_file_object_mapping(l, off, len, &objno, &objoff, + &xlen); + + last_ex = lookup_last(object_extents, objno, &add_pos); + if (!last_ex || last_ex->oe_off + last_ex->oe_len != objoff) { + ex = alloc_fn(alloc_arg); + if (!ex) + return -ENOMEM; + + ex->oe_objno = objno; + ex->oe_off = objoff; + ex->oe_len = xlen; + if (action_fn) + action_fn(ex, xlen, action_arg); + + if (!last_ex) + list_add(&ex->oe_item, add_pos); + else + list_add(&ex->oe_item, &last_ex->oe_item); + } else { + last_ex->oe_len += xlen; + if (action_fn) + action_fn(last_ex, xlen, action_arg); + } + + off += xlen; + len -= xlen; + } + + for (last_ex = list_first_entry(object_extents, typeof(*ex), oe_item), + ex = list_next_entry(last_ex, oe_item); + &ex->oe_item != object_extents; + last_ex = ex, ex = list_next_entry(ex, oe_item)) { + if (last_ex->oe_objno > ex->oe_objno || + (last_ex->oe_objno == ex->oe_objno && + last_ex->oe_off + last_ex->oe_len >= ex->oe_off)) { + WARN(1, "%s: object_extents list not sorted!\n", + __func__); + return -EINVAL; + } + } + + return 0; +} +EXPORT_SYMBOL(ceph_file_to_extents); + +/* + * A stripped down, non-allocating version of ceph_file_to_extents(), + * for when @object_extents is already populated. + */ +int ceph_iterate_extents(struct ceph_file_layout *l, u64 off, u64 len, + struct list_head *object_extents, + ceph_object_extent_fn_t action_fn, + void *action_arg) +{ + while (len) { + struct ceph_object_extent *ex; + u64 objno, objoff; + u32 xlen; + + ceph_calc_file_object_mapping(l, off, len, &objno, &objoff, + &xlen); + + ex = lookup_containing(object_extents, objno, objoff, xlen); + if (!ex) { + WARN(1, "%s: objno %llu %llu~%u not found!\n", + __func__, objno, objoff, xlen); + return -EINVAL; + } + + action_fn(ex, xlen, action_arg); + + off += xlen; + len -= xlen; + } + + return 0; +} +EXPORT_SYMBOL(ceph_iterate_extents); + +/* + * Reverse map an object extent to a sorted list of file extents. 
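A kernel-context usage sketch for ceph_file_to_extents() above: my_alloc(), my_count() and map_range() are made-up names, and the callback shape follows the action_fn(ex, xlen, action_arg) calls in the function itself. The caller owns the resulting list, so the sketch also frees it.

/* Kernel-context sketch: map a file range and count the bytes that land
 * in each object extent. */
#include <linux/ceph/striper.h>
#include <linux/slab.h>
#include <linux/list.h>
#include <linux/printk.h>

static struct ceph_object_extent *my_alloc(void *arg)
{
	return kzalloc(sizeof(struct ceph_object_extent), GFP_NOIO);
}

static void my_count(struct ceph_object_extent *ex, u32 bytes, void *arg)
{
	u64 *total = arg;

	*total += bytes;	/* bytes just merged into this object extent */
}

static int map_range(struct ceph_file_layout *l, u64 off, u64 len)
{
	LIST_HEAD(extents);
	struct ceph_object_extent *ex, *tmp;
	u64 total = 0;
	int ret;

	ret = ceph_file_to_extents(l, off, len, &extents,
				   my_alloc, NULL, my_count, &total);

	list_for_each_entry_safe(ex, tmp, &extents, oe_item) {
		pr_info("obj %llu: %llu~%u\n",
			ex->oe_objno, ex->oe_off, ex->oe_len);
		list_del(&ex->oe_item);
		kfree(ex);	/* the caller owns the list */
	}
	return ret;
}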
+ * + * On success, the caller is responsible for: + * + * kfree(file_extents) + */ +int ceph_extent_to_file(struct ceph_file_layout *l, + u64 objno, u64 objoff, u64 objlen, + struct ceph_file_extent **file_extents, + u32 *num_file_extents) +{ + u32 stripes_per_object = l->object_size / l->stripe_unit; + u64 blockno; /* which su */ + u32 blockoff; /* offset into su */ + u64 stripeno; /* which stripe */ + u32 stripepos; /* which su in the stripe, + which object in the object set */ + u64 objsetno; /* which object set */ + u32 i = 0; + + if (!objlen) { + *file_extents = NULL; + *num_file_extents = 0; + return 0; + } + + *num_file_extents = DIV_ROUND_UP_ULL(objoff + objlen, l->stripe_unit) - + DIV_ROUND_DOWN_ULL(objoff, l->stripe_unit); + *file_extents = kmalloc_array(*num_file_extents, sizeof(**file_extents), + GFP_NOIO); + if (!*file_extents) + return -ENOMEM; + + div_u64_rem(objoff, l->stripe_unit, &blockoff); + while (objlen) { + u64 off, len; + + objsetno = div_u64_rem(objno, l->stripe_count, &stripepos); + stripeno = div_u64(objoff, l->stripe_unit) + + objsetno * stripes_per_object; + blockno = stripeno * l->stripe_count + stripepos; + off = blockno * l->stripe_unit + blockoff; + len = min_t(u64, objlen, l->stripe_unit - blockoff); + + (*file_extents)[i].fe_off = off; + (*file_extents)[i].fe_len = len; + + blockoff = 0; + objoff += len; + objlen -= len; + i++; + } + + BUG_ON(i != *num_file_extents); + return 0; +} +EXPORT_SYMBOL(ceph_extent_to_file); + +u64 ceph_get_num_objects(struct ceph_file_layout *l, u64 size) +{ + u64 period = (u64)l->stripe_count * l->object_size; + u64 num_periods = DIV64_U64_ROUND_UP(size, period); + u64 remainder_bytes; + u64 remainder_objs = 0; + + div64_u64_rem(size, period, &remainder_bytes); + if (remainder_bytes > 0 && + remainder_bytes < (u64)l->stripe_count * l->stripe_unit) + remainder_objs = l->stripe_count - + DIV_ROUND_UP_ULL(remainder_bytes, l->stripe_unit); + + return num_periods * l->stripe_count - remainder_objs; +} +EXPORT_SYMBOL(ceph_get_num_objects); diff --git a/net/compat.c b/net/compat.c new file mode 100644 index 000000000..161b7bea1 --- /dev/null +++ b/net/compat.c @@ -0,0 +1,518 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 32bit Socket syscall emulation. Based on arch/sparc64/kernel/sys_sparc32.c. + * + * Copyright (C) 2000 VA Linux Co + * Copyright (C) 2000 Don Dugger + * Copyright (C) 1999 Arun Sharma + * Copyright (C) 1997,1998 Jakub Jelinek (jj@sunsite.mff.cuni.cz) + * Copyright (C) 1997 David S. Miller (davem@caip.rutgers.edu) + * Copyright (C) 2000 Hewlett-Packard Co. 
+ * Copyright (C) 2000 David Mosberger-Tang + * Copyright (C) 2000,2001 Andi Kleen, SuSE Labs + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +int __get_compat_msghdr(struct msghdr *kmsg, + struct compat_msghdr *msg, + struct sockaddr __user **save_addr) +{ + ssize_t err; + + kmsg->msg_flags = msg->msg_flags; + kmsg->msg_namelen = msg->msg_namelen; + + if (!msg->msg_name) + kmsg->msg_namelen = 0; + + if (kmsg->msg_namelen < 0) + return -EINVAL; + + if (kmsg->msg_namelen > sizeof(struct sockaddr_storage)) + kmsg->msg_namelen = sizeof(struct sockaddr_storage); + + kmsg->msg_control_is_user = true; + kmsg->msg_get_inq = 0; + kmsg->msg_control_user = compat_ptr(msg->msg_control); + kmsg->msg_controllen = msg->msg_controllen; + + if (save_addr) + *save_addr = compat_ptr(msg->msg_name); + + if (msg->msg_name && kmsg->msg_namelen) { + if (!save_addr) { + err = move_addr_to_kernel(compat_ptr(msg->msg_name), + kmsg->msg_namelen, + kmsg->msg_name); + if (err < 0) + return err; + } + } else { + kmsg->msg_name = NULL; + kmsg->msg_namelen = 0; + } + + if (msg->msg_iovlen > UIO_MAXIOV) + return -EMSGSIZE; + + kmsg->msg_iocb = NULL; + kmsg->msg_ubuf = NULL; + return 0; +} + +int get_compat_msghdr(struct msghdr *kmsg, + struct compat_msghdr __user *umsg, + struct sockaddr __user **save_addr, + struct iovec **iov) +{ + struct compat_msghdr msg; + ssize_t err; + + if (copy_from_user(&msg, umsg, sizeof(*umsg))) + return -EFAULT; + + err = __get_compat_msghdr(kmsg, &msg, save_addr); + if (err) + return err; + + err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE, + compat_ptr(msg.msg_iov), msg.msg_iovlen, + UIO_FASTIOV, iov, &kmsg->msg_iter); + return err < 0 ? err : 0; +} + +/* Bleech... */ +#define CMSG_COMPAT_ALIGN(len) ALIGN((len), sizeof(s32)) + +#define CMSG_COMPAT_DATA(cmsg) \ + ((void __user *)((char __user *)(cmsg) + sizeof(struct compat_cmsghdr))) +#define CMSG_COMPAT_SPACE(len) \ + (sizeof(struct compat_cmsghdr) + CMSG_COMPAT_ALIGN(len)) +#define CMSG_COMPAT_LEN(len) \ + (sizeof(struct compat_cmsghdr) + (len)) + +#define CMSG_COMPAT_FIRSTHDR(msg) \ + (((msg)->msg_controllen) >= sizeof(struct compat_cmsghdr) ? \ + (struct compat_cmsghdr __user *)((msg)->msg_control) : \ + (struct compat_cmsghdr __user *)NULL) + +#define CMSG_COMPAT_OK(ucmlen, ucmsg, mhdr) \ + ((ucmlen) >= sizeof(struct compat_cmsghdr) && \ + (ucmlen) <= (unsigned long) \ + ((mhdr)->msg_controllen - \ + ((char __user *)(ucmsg) - (char __user *)(mhdr)->msg_control_user))) + +static inline struct compat_cmsghdr __user *cmsg_compat_nxthdr(struct msghdr *msg, + struct compat_cmsghdr __user *cmsg, int cmsg_len) +{ + char __user *ptr = (char __user *)cmsg + CMSG_COMPAT_ALIGN(cmsg_len); + if ((unsigned long)(ptr + 1 - (char __user *)msg->msg_control) > + msg->msg_controllen) + return NULL; + return (struct compat_cmsghdr __user *)ptr; +} + +/* There is a lot of hair here because the alignment rules (and + * thus placement) of cmsg headers and length are different for + * 32-bit apps. 
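The difference the compat cmsg handling below deals with can be made concrete: a compat cmsg header is 12 bytes and its payload is padded to 4 bytes, while the native 64-bit header is 16 bytes padded to 8. The userspace sketch below uses local stand-in structs (not the kernel definitions) and assumes an LP64 kernel, so one 4-byte payload costs 16 bytes of compat control space versus 24 natively.

/* Illustrative only: local stand-in structs, assuming an LP64 kernel. */
#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

struct compat_hdr { uint32_t cmsg_len; int32_t cmsg_level; int32_t cmsg_type; };
struct native_hdr { uint64_t cmsg_len; int32_t cmsg_level; int32_t cmsg_type; };

static size_t align_to(size_t len, size_t a)
{
	return (len + a - 1) & ~(a - 1);
}

int main(void)
{
	size_t payload = sizeof(int);	/* e.g. one SCM_RIGHTS fd */

	/* mirrors CMSG_COMPAT_SPACE(): header + payload aligned to s32 */
	printf("compat space: %zu\n",
	       sizeof(struct compat_hdr) + align_to(payload, sizeof(int32_t)));
	/* native CMSG_SPACE() aligns to sizeof(long), 8 on LP64 */
	printf("native space: %zu\n",
	       sizeof(struct native_hdr) + align_to(payload, sizeof(long)));
	return 0;
}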
-DaveM + */ +int cmsghdr_from_user_compat_to_kern(struct msghdr *kmsg, struct sock *sk, + unsigned char *stackbuf, int stackbuf_size) +{ + struct compat_cmsghdr __user *ucmsg; + struct cmsghdr *kcmsg, *kcmsg_base; + compat_size_t ucmlen; + __kernel_size_t kcmlen, tmp; + int err = -EFAULT; + + BUILD_BUG_ON(sizeof(struct compat_cmsghdr) != + CMSG_COMPAT_ALIGN(sizeof(struct compat_cmsghdr))); + + kcmlen = 0; + kcmsg_base = kcmsg = (struct cmsghdr *)stackbuf; + ucmsg = CMSG_COMPAT_FIRSTHDR(kmsg); + while (ucmsg != NULL) { + if (get_user(ucmlen, &ucmsg->cmsg_len)) + return -EFAULT; + + /* Catch bogons. */ + if (!CMSG_COMPAT_OK(ucmlen, ucmsg, kmsg)) + return -EINVAL; + + tmp = ((ucmlen - sizeof(*ucmsg)) + sizeof(struct cmsghdr)); + tmp = CMSG_ALIGN(tmp); + kcmlen += tmp; + ucmsg = cmsg_compat_nxthdr(kmsg, ucmsg, ucmlen); + } + if (kcmlen == 0) + return -EINVAL; + + /* The kcmlen holds the 64-bit version of the control length. + * It may not be modified as we do not stick it into the kmsg + * until we have successfully copied over all of the data + * from the user. + */ + if (kcmlen > stackbuf_size) + kcmsg_base = kcmsg = sock_kmalloc(sk, kcmlen, GFP_KERNEL); + if (kcmsg == NULL) + return -ENOMEM; + + /* Now copy them over neatly. */ + memset(kcmsg, 0, kcmlen); + ucmsg = CMSG_COMPAT_FIRSTHDR(kmsg); + while (ucmsg != NULL) { + struct compat_cmsghdr cmsg; + if (copy_from_user(&cmsg, ucmsg, sizeof(cmsg))) + goto Efault; + if (!CMSG_COMPAT_OK(cmsg.cmsg_len, ucmsg, kmsg)) + goto Einval; + tmp = ((cmsg.cmsg_len - sizeof(*ucmsg)) + sizeof(struct cmsghdr)); + if ((char *)kcmsg_base + kcmlen - (char *)kcmsg < CMSG_ALIGN(tmp)) + goto Einval; + kcmsg->cmsg_len = tmp; + kcmsg->cmsg_level = cmsg.cmsg_level; + kcmsg->cmsg_type = cmsg.cmsg_type; + tmp = CMSG_ALIGN(tmp); + if (copy_from_user(CMSG_DATA(kcmsg), + CMSG_COMPAT_DATA(ucmsg), + (cmsg.cmsg_len - sizeof(*ucmsg)))) + goto Efault; + + /* Advance. */ + kcmsg = (struct cmsghdr *)((char *)kcmsg + tmp); + ucmsg = cmsg_compat_nxthdr(kmsg, ucmsg, cmsg.cmsg_len); + } + + /* + * check the length of messages copied in is the same as the + * what we get from the first loop + */ + if ((char *)kcmsg - (char *)kcmsg_base != kcmlen) + goto Einval; + + /* Ok, looks like we made it. Hook it up and return success. */ + kmsg->msg_control = kcmsg_base; + kmsg->msg_controllen = kcmlen; + return 0; + +Einval: + err = -EINVAL; +Efault: + if (kcmsg_base != (struct cmsghdr *)stackbuf) + sock_kfree_s(sk, kcmsg_base, kcmlen); + return err; +} + +int put_cmsg_compat(struct msghdr *kmsg, int level, int type, int len, void *data) +{ + struct compat_cmsghdr __user *cm = (struct compat_cmsghdr __user *) kmsg->msg_control; + struct compat_cmsghdr cmhdr; + struct old_timeval32 ctv; + struct old_timespec32 cts[3]; + int cmlen; + + if (cm == NULL || kmsg->msg_controllen < sizeof(*cm)) { + kmsg->msg_flags |= MSG_CTRUNC; + return 0; /* XXX: return error? check spec. */ + } + + if (!COMPAT_USE_64BIT_TIME) { + if (level == SOL_SOCKET && type == SO_TIMESTAMP_OLD) { + struct __kernel_old_timeval *tv = (struct __kernel_old_timeval *)data; + ctv.tv_sec = tv->tv_sec; + ctv.tv_usec = tv->tv_usec; + data = &ctv; + len = sizeof(ctv); + } + if (level == SOL_SOCKET && + (type == SO_TIMESTAMPNS_OLD || type == SO_TIMESTAMPING_OLD)) { + int count = type == SO_TIMESTAMPNS_OLD ? 
1 : 3; + int i; + struct __kernel_old_timespec *ts = data; + for (i = 0; i < count; i++) { + cts[i].tv_sec = ts[i].tv_sec; + cts[i].tv_nsec = ts[i].tv_nsec; + } + data = &cts; + len = sizeof(cts[0]) * count; + } + } + + cmlen = CMSG_COMPAT_LEN(len); + if (kmsg->msg_controllen < cmlen) { + kmsg->msg_flags |= MSG_CTRUNC; + cmlen = kmsg->msg_controllen; + } + cmhdr.cmsg_level = level; + cmhdr.cmsg_type = type; + cmhdr.cmsg_len = cmlen; + + if (copy_to_user(cm, &cmhdr, sizeof cmhdr)) + return -EFAULT; + if (copy_to_user(CMSG_COMPAT_DATA(cm), data, cmlen - sizeof(struct compat_cmsghdr))) + return -EFAULT; + cmlen = CMSG_COMPAT_SPACE(len); + if (kmsg->msg_controllen < cmlen) + cmlen = kmsg->msg_controllen; + kmsg->msg_control += cmlen; + kmsg->msg_controllen -= cmlen; + return 0; +} + +static int scm_max_fds_compat(struct msghdr *msg) +{ + if (msg->msg_controllen <= sizeof(struct compat_cmsghdr)) + return 0; + return (msg->msg_controllen - sizeof(struct compat_cmsghdr)) / sizeof(int); +} + +void scm_detach_fds_compat(struct msghdr *msg, struct scm_cookie *scm) +{ + struct compat_cmsghdr __user *cm = + (struct compat_cmsghdr __user *)msg->msg_control; + unsigned int o_flags = (msg->msg_flags & MSG_CMSG_CLOEXEC) ? O_CLOEXEC : 0; + int fdmax = min_t(int, scm_max_fds_compat(msg), scm->fp->count); + int __user *cmsg_data = CMSG_COMPAT_DATA(cm); + int err = 0, i; + + for (i = 0; i < fdmax; i++) { + err = receive_fd_user(scm->fp->fp[i], cmsg_data + i, o_flags); + if (err < 0) + break; + } + + if (i > 0) { + int cmlen = CMSG_COMPAT_LEN(i * sizeof(int)); + + err = put_user(SOL_SOCKET, &cm->cmsg_level); + if (!err) + err = put_user(SCM_RIGHTS, &cm->cmsg_type); + if (!err) + err = put_user(cmlen, &cm->cmsg_len); + if (!err) { + cmlen = CMSG_COMPAT_SPACE(i * sizeof(int)); + if (msg->msg_controllen < cmlen) + cmlen = msg->msg_controllen; + msg->msg_control += cmlen; + msg->msg_controllen -= cmlen; + } + } + + if (i < scm->fp->count || (scm->fp->count && fdmax <= 0)) + msg->msg_flags |= MSG_CTRUNC; + + /* + * All of the files that fit in the message have had their usage counts + * incremented, so we just free the list. 
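Stepping back to put_cmsg_compat() above: when COMPAT_USE_64BIT_TIME is false, an SO_TIMESTAMP_OLD payload is rewritten field by field into old_timeval32, which narrows tv_sec to 32 bits. The userspace sketch below uses local stand-in structs for the two layouts to show that copy; it is illustrative only.

/* Illustrative only: local stand-in structs for the two layouts. */
#include <stdio.h>
#include <stdint.h>

struct kernel_old_timeval { int64_t tv_sec; int64_t tv_usec; };
struct old_tv32 { int32_t tv_sec; int32_t tv_usec; };

int main(void)
{
	struct kernel_old_timeval tv = { .tv_sec = 1700000000, .tv_usec = 42 };
	struct old_tv32 ctv;

	/* same assignments as the SO_TIMESTAMP_OLD branch above;
	 * tv_sec is silently narrowed to 32 bits */
	ctv.tv_sec = tv.tv_sec;
	ctv.tv_usec = tv.tv_usec;

	printf("64-bit %lld.%06lld -> compat %d.%06d\n",
	       (long long)tv.tv_sec, (long long)tv.tv_usec,
	       ctv.tv_sec, ctv.tv_usec);
	return 0;
}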
+ */ + __scm_destroy(scm); +} + +/* Argument list sizes for compat_sys_socketcall */ +#define AL(x) ((x) * sizeof(u32)) +static unsigned char nas[21] = { + AL(0), AL(3), AL(3), AL(3), AL(2), AL(3), + AL(3), AL(3), AL(4), AL(4), AL(4), AL(6), + AL(6), AL(2), AL(5), AL(5), AL(3), AL(3), + AL(4), AL(5), AL(4) +}; +#undef AL + +static inline long __compat_sys_sendmsg(int fd, + struct compat_msghdr __user *msg, + unsigned int flags) +{ + return __sys_sendmsg(fd, (struct user_msghdr __user *)msg, + flags | MSG_CMSG_COMPAT, false); +} + +COMPAT_SYSCALL_DEFINE3(sendmsg, int, fd, struct compat_msghdr __user *, msg, + unsigned int, flags) +{ + return __compat_sys_sendmsg(fd, msg, flags); +} + +static inline long __compat_sys_sendmmsg(int fd, + struct compat_mmsghdr __user *mmsg, + unsigned int vlen, unsigned int flags) +{ + return __sys_sendmmsg(fd, (struct mmsghdr __user *)mmsg, vlen, + flags | MSG_CMSG_COMPAT, false); +} + +COMPAT_SYSCALL_DEFINE4(sendmmsg, int, fd, struct compat_mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags) +{ + return __compat_sys_sendmmsg(fd, mmsg, vlen, flags); +} + +static inline long __compat_sys_recvmsg(int fd, + struct compat_msghdr __user *msg, + unsigned int flags) +{ + return __sys_recvmsg(fd, (struct user_msghdr __user *)msg, + flags | MSG_CMSG_COMPAT, false); +} + +COMPAT_SYSCALL_DEFINE3(recvmsg, int, fd, struct compat_msghdr __user *, msg, + unsigned int, flags) +{ + return __compat_sys_recvmsg(fd, msg, flags); +} + +static inline long __compat_sys_recvfrom(int fd, void __user *buf, + compat_size_t len, unsigned int flags, + struct sockaddr __user *addr, + int __user *addrlen) +{ + return __sys_recvfrom(fd, buf, len, flags | MSG_CMSG_COMPAT, addr, + addrlen); +} + +COMPAT_SYSCALL_DEFINE4(recv, int, fd, void __user *, buf, compat_size_t, len, unsigned int, flags) +{ + return __compat_sys_recvfrom(fd, buf, len, flags, NULL, NULL); +} + +COMPAT_SYSCALL_DEFINE6(recvfrom, int, fd, void __user *, buf, compat_size_t, len, + unsigned int, flags, struct sockaddr __user *, addr, + int __user *, addrlen) +{ + return __compat_sys_recvfrom(fd, buf, len, flags, addr, addrlen); +} + +COMPAT_SYSCALL_DEFINE5(recvmmsg_time64, int, fd, struct compat_mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags, + struct __kernel_timespec __user *, timeout) +{ + return __sys_recvmmsg(fd, (struct mmsghdr __user *)mmsg, vlen, + flags | MSG_CMSG_COMPAT, timeout, NULL); +} + +#ifdef CONFIG_COMPAT_32BIT_TIME +COMPAT_SYSCALL_DEFINE5(recvmmsg_time32, int, fd, struct compat_mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags, + struct old_timespec32 __user *, timeout) +{ + return __sys_recvmmsg(fd, (struct mmsghdr __user *)mmsg, vlen, + flags | MSG_CMSG_COMPAT, NULL, timeout); +} +#endif + +COMPAT_SYSCALL_DEFINE2(socketcall, int, call, u32 __user *, args) +{ + u32 a[AUDITSC_ARGS]; + unsigned int len; + u32 a0, a1; + int ret; + + if (call < SYS_SOCKET || call > SYS_SENDMMSG) + return -EINVAL; + len = nas[call]; + if (len > sizeof(a)) + return -EINVAL; + + if (copy_from_user(a, args, len)) + return -EFAULT; + + ret = audit_socketcall_compat(len / sizeof(a[0]), a); + if (ret) + return ret; + + a0 = a[0]; + a1 = a[1]; + + switch (call) { + case SYS_SOCKET: + ret = __sys_socket(a0, a1, a[2]); + break; + case SYS_BIND: + ret = __sys_bind(a0, compat_ptr(a1), a[2]); + break; + case SYS_CONNECT: + ret = __sys_connect(a0, compat_ptr(a1), a[2]); + break; + case SYS_LISTEN: + ret = __sys_listen(a0, a1); + break; + case SYS_ACCEPT: + ret = __sys_accept4(a0, 
compat_ptr(a1), compat_ptr(a[2]), 0); + break; + case SYS_GETSOCKNAME: + ret = __sys_getsockname(a0, compat_ptr(a1), compat_ptr(a[2])); + break; + case SYS_GETPEERNAME: + ret = __sys_getpeername(a0, compat_ptr(a1), compat_ptr(a[2])); + break; + case SYS_SOCKETPAIR: + ret = __sys_socketpair(a0, a1, a[2], compat_ptr(a[3])); + break; + case SYS_SEND: + ret = __sys_sendto(a0, compat_ptr(a1), a[2], a[3], NULL, 0); + break; + case SYS_SENDTO: + ret = __sys_sendto(a0, compat_ptr(a1), a[2], a[3], + compat_ptr(a[4]), a[5]); + break; + case SYS_RECV: + ret = __compat_sys_recvfrom(a0, compat_ptr(a1), a[2], a[3], + NULL, NULL); + break; + case SYS_RECVFROM: + ret = __compat_sys_recvfrom(a0, compat_ptr(a1), a[2], a[3], + compat_ptr(a[4]), + compat_ptr(a[5])); + break; + case SYS_SHUTDOWN: + ret = __sys_shutdown(a0, a1); + break; + case SYS_SETSOCKOPT: + ret = __sys_setsockopt(a0, a1, a[2], compat_ptr(a[3]), a[4]); + break; + case SYS_GETSOCKOPT: + ret = __sys_getsockopt(a0, a1, a[2], compat_ptr(a[3]), + compat_ptr(a[4])); + break; + case SYS_SENDMSG: + ret = __compat_sys_sendmsg(a0, compat_ptr(a1), a[2]); + break; + case SYS_SENDMMSG: + ret = __compat_sys_sendmmsg(a0, compat_ptr(a1), a[2], a[3]); + break; + case SYS_RECVMSG: + ret = __compat_sys_recvmsg(a0, compat_ptr(a1), a[2]); + break; + case SYS_RECVMMSG: + ret = __sys_recvmmsg(a0, compat_ptr(a1), a[2], + a[3] | MSG_CMSG_COMPAT, NULL, + compat_ptr(a[4])); + break; + case SYS_ACCEPT4: + ret = __sys_accept4(a0, compat_ptr(a1), compat_ptr(a[2]), a[3]); + break; + default: + ret = -EINVAL; + break; + } + return ret; +} diff --git a/net/core/Makefile b/net/core/Makefile new file mode 100644 index 000000000..10edd66a8 --- /dev/null +++ b/net/core/Makefile @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux networking core. 
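+# (Editor's note, illustrative only: obj-y entries are always linked into
+# the networking core, while obj-$(CONFIG_FOO) entries expand to obj-y or
+# obj-m depending on how that option is set; e.g. with CONFIG_PROC_FS=y the
+# net-procfs.o line below behaves as "obj-y += net-procfs.o".)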
+# + +obj-y := sock.o request_sock.o skbuff.o datagram.o stream.o scm.o \ + gen_stats.o gen_estimator.o net_namespace.o secure_seq.o \ + flow_dissector.o + +obj-$(CONFIG_SYSCTL) += sysctl_net_core.o + +obj-y += dev.o dev_addr_lists.o dst.o netevent.o \ + neighbour.o rtnetlink.o utils.o link_watch.o filter.o \ + sock_diag.o dev_ioctl.o tso.o sock_reuseport.o \ + fib_notifier.o xdp.o flow_offload.o gro.o + +obj-$(CONFIG_NETDEV_ADDR_LIST_TEST) += dev_addr_lists_test.o + +obj-y += net-sysfs.o +obj-$(CONFIG_PAGE_POOL) += page_pool.o +obj-$(CONFIG_PROC_FS) += net-procfs.o +obj-$(CONFIG_NET_PKTGEN) += pktgen.o +obj-$(CONFIG_NETPOLL) += netpoll.o +obj-$(CONFIG_FIB_RULES) += fib_rules.o +obj-$(CONFIG_TRACEPOINTS) += net-traces.o +obj-$(CONFIG_NET_DROP_MONITOR) += drop_monitor.o +obj-$(CONFIG_NET_SELFTESTS) += selftests.o +obj-$(CONFIG_NETWORK_PHY_TIMESTAMPING) += timestamping.o +obj-$(CONFIG_NET_PTP_CLASSIFY) += ptp_classifier.o +obj-$(CONFIG_CGROUP_NET_PRIO) += netprio_cgroup.o +obj-$(CONFIG_CGROUP_NET_CLASSID) += netclassid_cgroup.o +obj-$(CONFIG_LWTUNNEL) += lwtunnel.o +obj-$(CONFIG_LWTUNNEL_BPF) += lwt_bpf.o +obj-$(CONFIG_DST_CACHE) += dst_cache.o +obj-$(CONFIG_HWBM) += hwbm.o +obj-$(CONFIG_GRO_CELLS) += gro_cells.o +obj-$(CONFIG_FAILOVER) += failover.o +obj-$(CONFIG_NET_SOCK_MSG) += skmsg.o +obj-$(CONFIG_BPF_SYSCALL) += sock_map.o +obj-$(CONFIG_BPF_SYSCALL) += bpf_sk_storage.o +obj-$(CONFIG_OF) += of_net.o diff --git a/net/core/bpf_sk_storage.c b/net/core/bpf_sk_storage.c new file mode 100644 index 000000000..ad01b1bea --- /dev/null +++ b/net/core/bpf_sk_storage.c @@ -0,0 +1,965 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2019 Facebook */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +DEFINE_BPF_STORAGE_CACHE(sk_cache); + +static struct bpf_local_storage_data * +bpf_sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit) +{ + struct bpf_local_storage *sk_storage; + struct bpf_local_storage_map *smap; + + sk_storage = + rcu_dereference_check(sk->sk_bpf_storage, bpf_rcu_lock_held()); + if (!sk_storage) + return NULL; + + smap = (struct bpf_local_storage_map *)map; + return bpf_local_storage_lookup(sk_storage, smap, cacheit_lockit); +} + +static int bpf_sk_storage_del(struct sock *sk, struct bpf_map *map) +{ + struct bpf_local_storage_data *sdata; + + sdata = bpf_sk_storage_lookup(sk, map, false); + if (!sdata) + return -ENOENT; + + bpf_selem_unlink(SELEM(sdata), true); + + return 0; +} + +/* Called by __sk_destruct() & bpf_sk_storage_clone() */ +void bpf_sk_storage_free(struct sock *sk) +{ + struct bpf_local_storage_elem *selem; + struct bpf_local_storage *sk_storage; + bool free_sk_storage = false; + struct hlist_node *n; + + rcu_read_lock(); + sk_storage = rcu_dereference(sk->sk_bpf_storage); + if (!sk_storage) { + rcu_read_unlock(); + return; + } + + /* Netiher the bpf_prog nor the bpf-map's syscall + * could be modifying the sk_storage->list now. + * Thus, no elem can be added-to or deleted-from the + * sk_storage->list by the bpf_prog or by the bpf-map's syscall. + * + * It is racing with bpf_local_storage_map_free() alone + * when unlinking elem from the sk_storage->list and + * the map's bucket->list. + */ + raw_spin_lock_bh(&sk_storage->lock); + hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) { + /* Always unlink from map before unlinking from + * sk_storage. 
+ */ + bpf_selem_unlink_map(selem); + free_sk_storage = bpf_selem_unlink_storage_nolock( + sk_storage, selem, true, false); + } + raw_spin_unlock_bh(&sk_storage->lock); + rcu_read_unlock(); + + if (free_sk_storage) + kfree_rcu(sk_storage, rcu); +} + +static void bpf_sk_storage_map_free(struct bpf_map *map) +{ + struct bpf_local_storage_map *smap; + + smap = (struct bpf_local_storage_map *)map; + bpf_local_storage_cache_idx_free(&sk_cache, smap->cache_idx); + bpf_local_storage_map_free(smap, NULL); +} + +static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr) +{ + struct bpf_local_storage_map *smap; + + smap = bpf_local_storage_map_alloc(attr); + if (IS_ERR(smap)) + return ERR_CAST(smap); + + smap->cache_idx = bpf_local_storage_cache_idx_get(&sk_cache); + return &smap->map; +} + +static int notsupp_get_next_key(struct bpf_map *map, void *key, + void *next_key) +{ + return -ENOTSUPP; +} + +static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key) +{ + struct bpf_local_storage_data *sdata; + struct socket *sock; + int fd, err; + + fd = *(int *)key; + sock = sockfd_lookup(fd, &err); + if (sock) { + sdata = bpf_sk_storage_lookup(sock->sk, map, true); + sockfd_put(sock); + return sdata ? sdata->data : NULL; + } + + return ERR_PTR(err); +} + +static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key, + void *value, u64 map_flags) +{ + struct bpf_local_storage_data *sdata; + struct socket *sock; + int fd, err; + + fd = *(int *)key; + sock = sockfd_lookup(fd, &err); + if (sock) { + sdata = bpf_local_storage_update( + sock->sk, (struct bpf_local_storage_map *)map, value, + map_flags, GFP_ATOMIC); + sockfd_put(sock); + return PTR_ERR_OR_ZERO(sdata); + } + + return err; +} + +static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key) +{ + struct socket *sock; + int fd, err; + + fd = *(int *)key; + sock = sockfd_lookup(fd, &err); + if (sock) { + err = bpf_sk_storage_del(sock->sk, map); + sockfd_put(sock); + return err; + } + + return err; +} + +static struct bpf_local_storage_elem * +bpf_sk_storage_clone_elem(struct sock *newsk, + struct bpf_local_storage_map *smap, + struct bpf_local_storage_elem *selem) +{ + struct bpf_local_storage_elem *copy_selem; + + copy_selem = bpf_selem_alloc(smap, newsk, NULL, true, GFP_ATOMIC); + if (!copy_selem) + return NULL; + + if (map_value_has_spin_lock(&smap->map)) + copy_map_value_locked(&smap->map, SDATA(copy_selem)->data, + SDATA(selem)->data, true); + else + copy_map_value(&smap->map, SDATA(copy_selem)->data, + SDATA(selem)->data); + + return copy_selem; +} + +int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk) +{ + struct bpf_local_storage *new_sk_storage = NULL; + struct bpf_local_storage *sk_storage; + struct bpf_local_storage_elem *selem; + int ret = 0; + + RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL); + + rcu_read_lock(); + sk_storage = rcu_dereference(sk->sk_bpf_storage); + + if (!sk_storage || hlist_empty(&sk_storage->list)) + goto out; + + hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) { + struct bpf_local_storage_elem *copy_selem; + struct bpf_local_storage_map *smap; + struct bpf_map *map; + + smap = rcu_dereference(SDATA(selem)->smap); + if (!(smap->map.map_flags & BPF_F_CLONE)) + continue; + + /* Note that for lockless listeners adding new element + * here can race with cleanup in bpf_local_storage_map_free. + * Try to grab map refcnt to make sure that it's still + * alive and prevent concurrent removal. 
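+ *
+ * (Editor's note, illustrative only: only maps created with the
+ * BPF_F_CLONE flag in their map_flags are copied here, so a loader
+ * that wants a stored value to follow a socket into its clones,
+ * e.g. the children produced by accept(), must pass that flag at
+ * BPF_MAP_CREATE time.)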
+ */ + map = bpf_map_inc_not_zero(&smap->map); + if (IS_ERR(map)) + continue; + + copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem); + if (!copy_selem) { + ret = -ENOMEM; + bpf_map_put(map); + goto out; + } + + if (new_sk_storage) { + bpf_selem_link_map(smap, copy_selem); + bpf_selem_link_storage_nolock(new_sk_storage, copy_selem); + } else { + ret = bpf_local_storage_alloc(newsk, smap, copy_selem, GFP_ATOMIC); + if (ret) { + kfree(copy_selem); + atomic_sub(smap->elem_size, + &newsk->sk_omem_alloc); + bpf_map_put(map); + goto out; + } + + new_sk_storage = + rcu_dereference(copy_selem->local_storage); + } + bpf_map_put(map); + } + +out: + rcu_read_unlock(); + + /* In case of an error, don't free anything explicitly here, the + * caller is responsible to call bpf_sk_storage_free. + */ + + return ret; +} + +/* *gfp_flags* is a hidden argument provided by the verifier */ +BPF_CALL_5(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk, + void *, value, u64, flags, gfp_t, gfp_flags) +{ + struct bpf_local_storage_data *sdata; + + WARN_ON_ONCE(!bpf_rcu_lock_held()); + if (!sk || !sk_fullsock(sk) || flags > BPF_SK_STORAGE_GET_F_CREATE) + return (unsigned long)NULL; + + sdata = bpf_sk_storage_lookup(sk, map, true); + if (sdata) + return (unsigned long)sdata->data; + + if (flags == BPF_SK_STORAGE_GET_F_CREATE && + /* Cannot add new elem to a going away sk. + * Otherwise, the new elem may become a leak + * (and also other memory issues during map + * destruction). + */ + refcount_inc_not_zero(&sk->sk_refcnt)) { + sdata = bpf_local_storage_update( + sk, (struct bpf_local_storage_map *)map, value, + BPF_NOEXIST, gfp_flags); + /* sk must be a fullsock (guaranteed by verifier), + * so sock_gen_put() is unnecessary. + */ + sock_put(sk); + return IS_ERR(sdata) ? 
+ (unsigned long)NULL : (unsigned long)sdata->data; + } + + return (unsigned long)NULL; +} + +BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk) +{ + WARN_ON_ONCE(!bpf_rcu_lock_held()); + if (!sk || !sk_fullsock(sk)) + return -EINVAL; + + if (refcount_inc_not_zero(&sk->sk_refcnt)) { + int err; + + err = bpf_sk_storage_del(sk, map); + sock_put(sk); + return err; + } + + return -ENOENT; +} + +static int bpf_sk_storage_charge(struct bpf_local_storage_map *smap, + void *owner, u32 size) +{ + int optmem_max = READ_ONCE(sysctl_optmem_max); + struct sock *sk = (struct sock *)owner; + + /* same check as in sock_kmalloc() */ + if (size <= optmem_max && + atomic_read(&sk->sk_omem_alloc) + size < optmem_max) { + atomic_add(size, &sk->sk_omem_alloc); + return 0; + } + + return -ENOMEM; +} + +static void bpf_sk_storage_uncharge(struct bpf_local_storage_map *smap, + void *owner, u32 size) +{ + struct sock *sk = owner; + + atomic_sub(size, &sk->sk_omem_alloc); +} + +static struct bpf_local_storage __rcu ** +bpf_sk_storage_ptr(void *owner) +{ + struct sock *sk = owner; + + return &sk->sk_bpf_storage; +} + +BTF_ID_LIST_SINGLE(sk_storage_map_btf_ids, struct, bpf_local_storage_map) +const struct bpf_map_ops sk_storage_map_ops = { + .map_meta_equal = bpf_map_meta_equal, + .map_alloc_check = bpf_local_storage_map_alloc_check, + .map_alloc = bpf_sk_storage_map_alloc, + .map_free = bpf_sk_storage_map_free, + .map_get_next_key = notsupp_get_next_key, + .map_lookup_elem = bpf_fd_sk_storage_lookup_elem, + .map_update_elem = bpf_fd_sk_storage_update_elem, + .map_delete_elem = bpf_fd_sk_storage_delete_elem, + .map_check_btf = bpf_local_storage_map_check_btf, + .map_btf_id = &sk_storage_map_btf_ids[0], + .map_local_storage_charge = bpf_sk_storage_charge, + .map_local_storage_uncharge = bpf_sk_storage_uncharge, + .map_owner_storage_ptr = bpf_sk_storage_ptr, +}; + +const struct bpf_func_proto bpf_sk_storage_get_proto = { + .func = bpf_sk_storage_get, + .gpl_only = false, + .ret_type = RET_PTR_TO_MAP_VALUE_OR_NULL, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg3_type = ARG_PTR_TO_MAP_VALUE_OR_NULL, + .arg4_type = ARG_ANYTHING, +}; + +const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto = { + .func = bpf_sk_storage_get, + .gpl_only = false, + .ret_type = RET_PTR_TO_MAP_VALUE_OR_NULL, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_PTR_TO_CTX, /* context is 'struct sock' */ + .arg3_type = ARG_PTR_TO_MAP_VALUE_OR_NULL, + .arg4_type = ARG_ANYTHING, +}; + +const struct bpf_func_proto bpf_sk_storage_delete_proto = { + .func = bpf_sk_storage_delete, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, +}; + +static bool bpf_sk_storage_tracing_allowed(const struct bpf_prog *prog) +{ + const struct btf *btf_vmlinux; + const struct btf_type *t; + const char *tname; + u32 btf_id; + + if (prog->aux->dst_prog) + return false; + + /* Ensure the tracing program is not tracing + * any bpf_sk_storage*() function and also + * use the bpf_sk_storage_(get|delete) helper. 
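+ *
+ * (Editor's note, illustrative only: per the switch below, fentry/fexit
+ * programs whose attach target name starts with "bpf_sk_storage" are
+ * rejected, while raw tracepoint and iterator programs are always
+ * allowed because these helpers expose no tracepoint.)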
+ */ + switch (prog->expected_attach_type) { + case BPF_TRACE_ITER: + case BPF_TRACE_RAW_TP: + /* bpf_sk_storage has no trace point */ + return true; + case BPF_TRACE_FENTRY: + case BPF_TRACE_FEXIT: + btf_vmlinux = bpf_get_btf_vmlinux(); + if (IS_ERR_OR_NULL(btf_vmlinux)) + return false; + btf_id = prog->aux->attach_btf_id; + t = btf_type_by_id(btf_vmlinux, btf_id); + tname = btf_name_by_offset(btf_vmlinux, t->name_off); + return !!strncmp(tname, "bpf_sk_storage", + strlen("bpf_sk_storage")); + default: + return false; + } + + return false; +} + +/* *gfp_flags* is a hidden argument provided by the verifier */ +BPF_CALL_5(bpf_sk_storage_get_tracing, struct bpf_map *, map, struct sock *, sk, + void *, value, u64, flags, gfp_t, gfp_flags) +{ + WARN_ON_ONCE(!bpf_rcu_lock_held()); + if (in_hardirq() || in_nmi()) + return (unsigned long)NULL; + + return (unsigned long)____bpf_sk_storage_get(map, sk, value, flags, + gfp_flags); +} + +BPF_CALL_2(bpf_sk_storage_delete_tracing, struct bpf_map *, map, + struct sock *, sk) +{ + WARN_ON_ONCE(!bpf_rcu_lock_held()); + if (in_hardirq() || in_nmi()) + return -EPERM; + + return ____bpf_sk_storage_delete(map, sk); +} + +const struct bpf_func_proto bpf_sk_storage_get_tracing_proto = { + .func = bpf_sk_storage_get_tracing, + .gpl_only = false, + .ret_type = RET_PTR_TO_MAP_VALUE_OR_NULL, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_PTR_TO_BTF_ID, + .arg2_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON], + .arg3_type = ARG_PTR_TO_MAP_VALUE_OR_NULL, + .arg4_type = ARG_ANYTHING, + .allowed = bpf_sk_storage_tracing_allowed, +}; + +const struct bpf_func_proto bpf_sk_storage_delete_tracing_proto = { + .func = bpf_sk_storage_delete_tracing, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_PTR_TO_BTF_ID, + .arg2_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON], + .allowed = bpf_sk_storage_tracing_allowed, +}; + +struct bpf_sk_storage_diag { + u32 nr_maps; + struct bpf_map *maps[]; +}; + +/* The reply will be like: + * INET_DIAG_BPF_SK_STORAGES (nla_nest) + * SK_DIAG_BPF_STORAGE (nla_nest) + * SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32) + * SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit) + * SK_DIAG_BPF_STORAGE (nla_nest) + * SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32) + * SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit) + * .... + */ +static int nla_value_size(u32 value_size) +{ + /* SK_DIAG_BPF_STORAGE (nla_nest) + * SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32) + * SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit) + */ + return nla_total_size(0) + nla_total_size(sizeof(u32)) + + nla_total_size_64bit(value_size); +} + +void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag) +{ + u32 i; + + if (!diag) + return; + + for (i = 0; i < diag->nr_maps; i++) + bpf_map_put(diag->maps[i]); + + kfree(diag); +} +EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free); + +static bool diag_check_dup(const struct bpf_sk_storage_diag *diag, + const struct bpf_map *map) +{ + u32 i; + + for (i = 0; i < diag->nr_maps; i++) { + if (diag->maps[i] == map) + return true; + } + + return false; +} + +struct bpf_sk_storage_diag * +bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs) +{ + struct bpf_sk_storage_diag *diag; + struct nlattr *nla; + u32 nr_maps = 0; + int rem, err; + + /* bpf_local_storage_map is currently limited to CAP_SYS_ADMIN as + * the map_alloc_check() side also does. 
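+ *
+ * (Editor's note, illustrative only: each map the caller wants dumped
+ * is passed as a SK_DIAG_BPF_STORAGE_REQ_MAP_FD u32 attribute nested in
+ * the request; duplicate maps and maps that are not of type
+ * BPF_MAP_TYPE_SK_STORAGE are rejected further down.)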
+ */ + if (!bpf_capable()) + return ERR_PTR(-EPERM); + + nla_for_each_nested(nla, nla_stgs, rem) { + if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD) { + if (nla_len(nla) != sizeof(u32)) + return ERR_PTR(-EINVAL); + nr_maps++; + } + } + + diag = kzalloc(struct_size(diag, maps, nr_maps), GFP_KERNEL); + if (!diag) + return ERR_PTR(-ENOMEM); + + nla_for_each_nested(nla, nla_stgs, rem) { + struct bpf_map *map; + int map_fd; + + if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD) + continue; + + map_fd = nla_get_u32(nla); + map = bpf_map_get(map_fd); + if (IS_ERR(map)) { + err = PTR_ERR(map); + goto err_free; + } + if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) { + bpf_map_put(map); + err = -EINVAL; + goto err_free; + } + if (diag_check_dup(diag, map)) { + bpf_map_put(map); + err = -EEXIST; + goto err_free; + } + diag->maps[diag->nr_maps++] = map; + } + + return diag; + +err_free: + bpf_sk_storage_diag_free(diag); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc); + +static int diag_get(struct bpf_local_storage_data *sdata, struct sk_buff *skb) +{ + struct nlattr *nla_stg, *nla_value; + struct bpf_local_storage_map *smap; + + /* It cannot exceed max nlattr's payload */ + BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < BPF_LOCAL_STORAGE_MAX_VALUE_SIZE); + + nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE); + if (!nla_stg) + return -EMSGSIZE; + + smap = rcu_dereference(sdata->smap); + if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id)) + goto errout; + + nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE, + smap->map.value_size, + SK_DIAG_BPF_STORAGE_PAD); + if (!nla_value) + goto errout; + + if (map_value_has_spin_lock(&smap->map)) + copy_map_value_locked(&smap->map, nla_data(nla_value), + sdata->data, true); + else + copy_map_value(&smap->map, nla_data(nla_value), sdata->data); + + nla_nest_end(skb, nla_stg); + return 0; + +errout: + nla_nest_cancel(skb, nla_stg); + return -EMSGSIZE; +} + +static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb, + int stg_array_type, + unsigned int *res_diag_size) +{ + /* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */ + unsigned int diag_size = nla_total_size(0); + struct bpf_local_storage *sk_storage; + struct bpf_local_storage_elem *selem; + struct bpf_local_storage_map *smap; + struct nlattr *nla_stgs; + unsigned int saved_len; + int err = 0; + + rcu_read_lock(); + + sk_storage = rcu_dereference(sk->sk_bpf_storage); + if (!sk_storage || hlist_empty(&sk_storage->list)) { + rcu_read_unlock(); + return 0; + } + + nla_stgs = nla_nest_start(skb, stg_array_type); + if (!nla_stgs) + /* Continue to learn diag_size */ + err = -EMSGSIZE; + + saved_len = skb->len; + hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) { + smap = rcu_dereference(SDATA(selem)->smap); + diag_size += nla_value_size(smap->map.value_size); + + if (nla_stgs && diag_get(SDATA(selem), skb)) + /* Continue to learn diag_size */ + err = -EMSGSIZE; + } + + rcu_read_unlock(); + + if (nla_stgs) { + if (saved_len == skb->len) + nla_nest_cancel(skb, nla_stgs); + else + nla_nest_end(skb, nla_stgs); + } + + if (diag_size == nla_total_size(0)) { + *res_diag_size = 0; + return 0; + } + + *res_diag_size = diag_size; + return err; +} + +int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag, + struct sock *sk, struct sk_buff *skb, + int stg_array_type, + unsigned int *res_diag_size) +{ + /* stg_array_type (e.g. 
INET_DIAG_BPF_SK_STORAGES) */ + unsigned int diag_size = nla_total_size(0); + struct bpf_local_storage *sk_storage; + struct bpf_local_storage_data *sdata; + struct nlattr *nla_stgs; + unsigned int saved_len; + int err = 0; + u32 i; + + *res_diag_size = 0; + + /* No map has been specified. Dump all. */ + if (!diag->nr_maps) + return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type, + res_diag_size); + + rcu_read_lock(); + sk_storage = rcu_dereference(sk->sk_bpf_storage); + if (!sk_storage || hlist_empty(&sk_storage->list)) { + rcu_read_unlock(); + return 0; + } + + nla_stgs = nla_nest_start(skb, stg_array_type); + if (!nla_stgs) + /* Continue to learn diag_size */ + err = -EMSGSIZE; + + saved_len = skb->len; + for (i = 0; i < diag->nr_maps; i++) { + sdata = bpf_local_storage_lookup(sk_storage, + (struct bpf_local_storage_map *)diag->maps[i], + false); + + if (!sdata) + continue; + + diag_size += nla_value_size(diag->maps[i]->value_size); + + if (nla_stgs && diag_get(sdata, skb)) + /* Continue to learn diag_size */ + err = -EMSGSIZE; + } + rcu_read_unlock(); + + if (nla_stgs) { + if (saved_len == skb->len) + nla_nest_cancel(skb, nla_stgs); + else + nla_nest_end(skb, nla_stgs); + } + + if (diag_size == nla_total_size(0)) { + *res_diag_size = 0; + return 0; + } + + *res_diag_size = diag_size; + return err; +} +EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put); + +struct bpf_iter_seq_sk_storage_map_info { + struct bpf_map *map; + unsigned int bucket_id; + unsigned skip_elems; +}; + +static struct bpf_local_storage_elem * +bpf_sk_storage_map_seq_find_next(struct bpf_iter_seq_sk_storage_map_info *info, + struct bpf_local_storage_elem *prev_selem) + __acquires(RCU) __releases(RCU) +{ + struct bpf_local_storage *sk_storage; + struct bpf_local_storage_elem *selem; + u32 skip_elems = info->skip_elems; + struct bpf_local_storage_map *smap; + u32 bucket_id = info->bucket_id; + u32 i, count, n_buckets; + struct bpf_local_storage_map_bucket *b; + + smap = (struct bpf_local_storage_map *)info->map; + n_buckets = 1U << smap->bucket_log; + if (bucket_id >= n_buckets) + return NULL; + + /* try to find next selem in the same bucket */ + selem = prev_selem; + count = 0; + while (selem) { + selem = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&selem->map_node)), + struct bpf_local_storage_elem, map_node); + if (!selem) { + /* not found, unlock and go to the next bucket */ + b = &smap->buckets[bucket_id++]; + rcu_read_unlock(); + skip_elems = 0; + break; + } + sk_storage = rcu_dereference(selem->local_storage); + if (sk_storage) { + info->skip_elems = skip_elems + count; + return selem; + } + count++; + } + + for (i = bucket_id; i < (1U << smap->bucket_log); i++) { + b = &smap->buckets[i]; + rcu_read_lock(); + count = 0; + hlist_for_each_entry_rcu(selem, &b->list, map_node) { + sk_storage = rcu_dereference(selem->local_storage); + if (sk_storage && count >= skip_elems) { + info->bucket_id = i; + info->skip_elems = count; + return selem; + } + count++; + } + rcu_read_unlock(); + skip_elems = 0; + } + + info->bucket_id = i; + info->skip_elems = 0; + return NULL; +} + +static void *bpf_sk_storage_map_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct bpf_local_storage_elem *selem; + + selem = bpf_sk_storage_map_seq_find_next(seq->private, NULL); + if (!selem) + return NULL; + + if (*pos == 0) + ++*pos; + return selem; +} + +static void *bpf_sk_storage_map_seq_next(struct seq_file *seq, void *v, + loff_t *pos) +{ + struct bpf_iter_seq_sk_storage_map_info *info = seq->private; + + ++*pos; + ++info->skip_elems; 
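+ /* Bump the resume cursor so that, if the walk is restarted from
+ * bpf_sk_storage_map_seq_start() (e.g. after the seq_file buffer is
+ * flushed), elements already emitted from this bucket are skipped.
+ * (Editor's comment, not in the upstream file.)
+ */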
+ return bpf_sk_storage_map_seq_find_next(seq->private, v); +} + +struct bpf_iter__bpf_sk_storage_map { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct bpf_map *, map); + __bpf_md_ptr(struct sock *, sk); + __bpf_md_ptr(void *, value); +}; + +DEFINE_BPF_ITER_FUNC(bpf_sk_storage_map, struct bpf_iter_meta *meta, + struct bpf_map *map, struct sock *sk, + void *value) + +static int __bpf_sk_storage_map_seq_show(struct seq_file *seq, + struct bpf_local_storage_elem *selem) +{ + struct bpf_iter_seq_sk_storage_map_info *info = seq->private; + struct bpf_iter__bpf_sk_storage_map ctx = {}; + struct bpf_local_storage *sk_storage; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + int ret = 0; + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, selem == NULL); + if (prog) { + ctx.meta = &meta; + ctx.map = info->map; + if (selem) { + sk_storage = rcu_dereference(selem->local_storage); + ctx.sk = sk_storage->owner; + ctx.value = SDATA(selem)->data; + } + ret = bpf_iter_run_prog(prog, &ctx); + } + + return ret; +} + +static int bpf_sk_storage_map_seq_show(struct seq_file *seq, void *v) +{ + return __bpf_sk_storage_map_seq_show(seq, v); +} + +static void bpf_sk_storage_map_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + if (!v) + (void)__bpf_sk_storage_map_seq_show(seq, v); + else + rcu_read_unlock(); +} + +static int bpf_iter_init_sk_storage_map(void *priv_data, + struct bpf_iter_aux_info *aux) +{ + struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data; + + bpf_map_inc_with_uref(aux->map); + seq_info->map = aux->map; + return 0; +} + +static void bpf_iter_fini_sk_storage_map(void *priv_data) +{ + struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data; + + bpf_map_put_with_uref(seq_info->map); +} + +static int bpf_iter_attach_map(struct bpf_prog *prog, + union bpf_iter_link_info *linfo, + struct bpf_iter_aux_info *aux) +{ + struct bpf_map *map; + int err = -EINVAL; + + if (!linfo->map.map_fd) + return -EBADF; + + map = bpf_map_get_with_uref(linfo->map.map_fd); + if (IS_ERR(map)) + return PTR_ERR(map); + + if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) + goto put_map; + + if (prog->aux->max_rdwr_access > map->value_size) { + err = -EACCES; + goto put_map; + } + + aux->map = map; + return 0; + +put_map: + bpf_map_put_with_uref(map); + return err; +} + +static void bpf_iter_detach_map(struct bpf_iter_aux_info *aux) +{ + bpf_map_put_with_uref(aux->map); +} + +static const struct seq_operations bpf_sk_storage_map_seq_ops = { + .start = bpf_sk_storage_map_seq_start, + .next = bpf_sk_storage_map_seq_next, + .stop = bpf_sk_storage_map_seq_stop, + .show = bpf_sk_storage_map_seq_show, +}; + +static const struct bpf_iter_seq_info iter_seq_info = { + .seq_ops = &bpf_sk_storage_map_seq_ops, + .init_seq_private = bpf_iter_init_sk_storage_map, + .fini_seq_private = bpf_iter_fini_sk_storage_map, + .seq_priv_size = sizeof(struct bpf_iter_seq_sk_storage_map_info), +}; + +static struct bpf_iter_reg bpf_sk_storage_map_reg_info = { + .target = "bpf_sk_storage_map", + .attach_target = bpf_iter_attach_map, + .detach_target = bpf_iter_detach_map, + .show_fdinfo = bpf_iter_map_show_fdinfo, + .fill_link_info = bpf_iter_map_fill_link_info, + .ctx_arg_info_size = 2, + .ctx_arg_info = { + { offsetof(struct bpf_iter__bpf_sk_storage_map, sk), + PTR_TO_BTF_ID_OR_NULL }, + { offsetof(struct bpf_iter__bpf_sk_storage_map, value), + PTR_TO_BUF | PTR_MAYBE_NULL }, + }, + .seq_info = &iter_seq_info, +}; + +static int __init bpf_sk_storage_map_iter_init(void) +{ + 
bpf_sk_storage_map_reg_info.ctx_arg_info[0].btf_id = + btf_sock_ids[BTF_SOCK_TYPE_SOCK]; + return bpf_iter_reg_target(&bpf_sk_storage_map_reg_info); +} +late_initcall(bpf_sk_storage_map_iter_init); diff --git a/net/core/datagram.c b/net/core/datagram.c new file mode 100644 index 000000000..8dabb9a74 --- /dev/null +++ b/net/core/datagram.c @@ -0,0 +1,842 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * SUCS NET3: + * + * Generic datagram handling routines. These are generic for all + * protocols. Possibly a generic IP version on top of these would + * make sense. Not tonight however 8-). + * This is used because UDP, RAW, PACKET, DDP, IPX, AX.25 and + * NetROM layer all have identical poll code and mostly + * identical recvmsg() code. So we share it here. The poll was + * shared before but buried in udp.c so I moved it. + * + * Authors: Alan Cox . (datagram_poll() from old + * udp.c code) + * + * Fixes: + * Alan Cox : NULL return from skb_peek_copy() + * understood + * Alan Cox : Rewrote skb_read_datagram to avoid the + * skb_peek_copy stuff. + * Alan Cox : Added support for SOCK_SEQPACKET. + * IPX can no longer use the SO_TYPE hack + * but AX.25 now works right, and SPX is + * feasible. + * Alan Cox : Fixed write poll of non IP protocol + * crash. + * Florian La Roche: Changed for my new skbuff handling. + * Darryl Miles : Fixed non-blocking SOCK_SEQPACKET. + * Linus Torvalds : BSD semantic fixes. + * Alan Cox : Datagram iovec handling + * Darryl Miles : Fixed non-blocking SOCK_STREAM. + * Alan Cox : POSIXisms + * Pete Wyckoff : Unconnected accept() fix. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +/* + * Is a socket 'connection oriented' ? + */ +static inline int connection_based(struct sock *sk) +{ + return sk->sk_type == SOCK_SEQPACKET || sk->sk_type == SOCK_STREAM; +} + +static int receiver_wake_function(wait_queue_entry_t *wait, unsigned int mode, int sync, + void *key) +{ + /* + * Avoid a wakeup if event not interesting for us + */ + if (key && !(key_to_poll(key) & (EPOLLIN | EPOLLERR))) + return 0; + return autoremove_wake_function(wait, mode, sync, key); +} +/* + * Wait for the last received packet to be different from skb + */ +int __skb_wait_for_more_packets(struct sock *sk, struct sk_buff_head *queue, + int *err, long *timeo_p, + const struct sk_buff *skb) +{ + int error; + DEFINE_WAIT_FUNC(wait, receiver_wake_function); + + prepare_to_wait_exclusive(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + /* Socket errors? */ + error = sock_error(sk); + if (error) + goto out_err; + + if (READ_ONCE(queue->prev) != skb) + goto out; + + /* Socket shut down? */ + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto out_noerr; + + /* Sequenced packets can come disconnected. 
+ * If so we report the problem + */ + error = -ENOTCONN; + if (connection_based(sk) && + !(sk->sk_state == TCP_ESTABLISHED || sk->sk_state == TCP_LISTEN)) + goto out_err; + + /* handle signals */ + if (signal_pending(current)) + goto interrupted; + + error = 0; + *timeo_p = schedule_timeout(*timeo_p); +out: + finish_wait(sk_sleep(sk), &wait); + return error; +interrupted: + error = sock_intr_errno(*timeo_p); +out_err: + *err = error; + goto out; +out_noerr: + *err = 0; + error = 1; + goto out; +} +EXPORT_SYMBOL(__skb_wait_for_more_packets); + +static struct sk_buff *skb_set_peeked(struct sk_buff *skb) +{ + struct sk_buff *nskb; + + if (skb->peeked) + return skb; + + /* We have to unshare an skb before modifying it. */ + if (!skb_shared(skb)) + goto done; + + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + return ERR_PTR(-ENOMEM); + + skb->prev->next = nskb; + skb->next->prev = nskb; + nskb->prev = skb->prev; + nskb->next = skb->next; + + consume_skb(skb); + skb = nskb; + +done: + skb->peeked = 1; + + return skb; +} + +struct sk_buff *__skb_try_recv_from_queue(struct sock *sk, + struct sk_buff_head *queue, + unsigned int flags, + int *off, int *err, + struct sk_buff **last) +{ + bool peek_at_off = false; + struct sk_buff *skb; + int _off = 0; + + if (unlikely(flags & MSG_PEEK && *off >= 0)) { + peek_at_off = true; + _off = *off; + } + + *last = queue->prev; + skb_queue_walk(queue, skb) { + if (flags & MSG_PEEK) { + if (peek_at_off && _off >= skb->len && + (_off || skb->peeked)) { + _off -= skb->len; + continue; + } + if (!skb->len) { + skb = skb_set_peeked(skb); + if (IS_ERR(skb)) { + *err = PTR_ERR(skb); + return NULL; + } + } + refcount_inc(&skb->users); + } else { + __skb_unlink(skb, queue); + } + *off = _off; + return skb; + } + return NULL; +} + +/** + * __skb_try_recv_datagram - Receive a datagram skbuff + * @sk: socket + * @queue: socket queue from which to receive + * @flags: MSG\_ flags + * @off: an offset in bytes to peek skb from. Returns an offset + * within an skb where data actually starts + * @err: error code returned + * @last: set to last peeked message to inform the wait function + * what to look for when peeking + * + * Get a datagram skbuff, understands the peeking, nonblocking wakeups + * and possible races. This replaces identical code in packet, raw and + * udp, as well as the IPX AX.25 and Appletalk. It also finally fixes + * the long standing peek and read race for datagram sockets. If you + * alter this routine remember it must be re-entrant. + * + * This function will lock the socket if a skb is returned, so + * the caller needs to unlock the socket in that case (usually by + * calling skb_free_datagram). Returns NULL with @err set to + * -EAGAIN if no data was available or to some other value if an + * error was detected. + * + * * It does not lock socket since today. This function is + * * free of race conditions. This measure should/can improve + * * significantly datagram socket latencies at high loads, + * * when data copying to user space takes lots of time. + * * (BTW I've just killed the last cli() in IP/IPv6/core/netlink/packet + * * 8) Great win.) + * * --ANK (980729) + * + * The order of the tests when we find no data waiting are specified + * quite explicitly by POSIX 1003.1g, don't change them without having + * the standard around please. 
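+ *
+ * (Editor's note, illustrative only: __skb_recv_datagram() below shows
+ * the canonical caller pattern, i.e. call this helper in a loop and
+ * sleep in __skb_wait_for_more_packets() while it keeps returning
+ * -EAGAIN and the receive timeout has not expired.)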
+ */ +struct sk_buff *__skb_try_recv_datagram(struct sock *sk, + struct sk_buff_head *queue, + unsigned int flags, int *off, int *err, + struct sk_buff **last) +{ + struct sk_buff *skb; + unsigned long cpu_flags; + /* + * Caller is allowed not to check sk->sk_err before skb_recv_datagram() + */ + int error = sock_error(sk); + + if (error) + goto no_packet; + + do { + /* Again only user level code calls this function, so nothing + * interrupt level will suddenly eat the receive_queue. + * + * Look at current nfs client by the way... + * However, this function was correct in any case. 8) + */ + spin_lock_irqsave(&queue->lock, cpu_flags); + skb = __skb_try_recv_from_queue(sk, queue, flags, off, &error, + last); + spin_unlock_irqrestore(&queue->lock, cpu_flags); + if (error) + goto no_packet; + if (skb) + return skb; + + if (!sk_can_busy_loop(sk)) + break; + + sk_busy_loop(sk, flags & MSG_DONTWAIT); + } while (READ_ONCE(queue->prev) != *last); + + error = -EAGAIN; + +no_packet: + *err = error; + return NULL; +} +EXPORT_SYMBOL(__skb_try_recv_datagram); + +struct sk_buff *__skb_recv_datagram(struct sock *sk, + struct sk_buff_head *sk_queue, + unsigned int flags, int *off, int *err) +{ + struct sk_buff *skb, *last; + long timeo; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + do { + skb = __skb_try_recv_datagram(sk, sk_queue, flags, off, err, + &last); + if (skb) + return skb; + + if (*err != -EAGAIN) + break; + } while (timeo && + !__skb_wait_for_more_packets(sk, sk_queue, err, + &timeo, last)); + + return NULL; +} +EXPORT_SYMBOL(__skb_recv_datagram); + +struct sk_buff *skb_recv_datagram(struct sock *sk, unsigned int flags, + int *err) +{ + int off = 0; + + return __skb_recv_datagram(sk, &sk->sk_receive_queue, flags, + &off, err); +} +EXPORT_SYMBOL(skb_recv_datagram); + +void skb_free_datagram(struct sock *sk, struct sk_buff *skb) +{ + consume_skb(skb); +} +EXPORT_SYMBOL(skb_free_datagram); + +void __skb_free_datagram_locked(struct sock *sk, struct sk_buff *skb, int len) +{ + bool slow; + + if (!skb_unref(skb)) { + sk_peek_offset_bwd(sk, len); + return; + } + + slow = lock_sock_fast(sk); + sk_peek_offset_bwd(sk, len); + skb_orphan(skb); + unlock_sock_fast(sk, slow); + + /* skb is now orphaned, can be freed outside of locked section */ + __kfree_skb(skb); +} +EXPORT_SYMBOL(__skb_free_datagram_locked); + +int __sk_queue_drop_skb(struct sock *sk, struct sk_buff_head *sk_queue, + struct sk_buff *skb, unsigned int flags, + void (*destructor)(struct sock *sk, + struct sk_buff *skb)) +{ + int err = 0; + + if (flags & MSG_PEEK) { + err = -ENOENT; + spin_lock_bh(&sk_queue->lock); + if (skb->next) { + __skb_unlink(skb, sk_queue); + refcount_dec(&skb->users); + if (destructor) + destructor(sk, skb); + err = 0; + } + spin_unlock_bh(&sk_queue->lock); + } + + atomic_inc(&sk->sk_drops); + return err; +} +EXPORT_SYMBOL(__sk_queue_drop_skb); + +/** + * skb_kill_datagram - Free a datagram skbuff forcibly + * @sk: socket + * @skb: datagram skbuff + * @flags: MSG\_ flags + * + * This function frees a datagram skbuff that was received by + * skb_recv_datagram. The flags argument must match the one + * used for skb_recv_datagram. + * + * If the MSG_PEEK flag is set, and the packet is still on the + * receive queue of the socket, it will be taken off the queue + * before it is freed. + * + * This function currently only disables BH when acquiring the + * sk_receive_queue lock. Therefore it must not be used in a + * context where that lock is acquired in an IRQ context. 
+ * + * It returns 0 if the packet was removed by us. + */ + +int skb_kill_datagram(struct sock *sk, struct sk_buff *skb, unsigned int flags) +{ + int err = __sk_queue_drop_skb(sk, &sk->sk_receive_queue, skb, flags, + NULL); + + kfree_skb(skb); + return err; +} +EXPORT_SYMBOL(skb_kill_datagram); + +INDIRECT_CALLABLE_DECLARE(static size_t simple_copy_to_iter(const void *addr, + size_t bytes, + void *data __always_unused, + struct iov_iter *i)); + +static int __skb_datagram_iter(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len, bool fault_short, + size_t (*cb)(const void *, size_t, void *, + struct iov_iter *), void *data) +{ + int start = skb_headlen(skb); + int i, copy = start - offset, start_off = offset, n; + struct sk_buff *frag_iter; + + /* Copy header. */ + if (copy > 0) { + if (copy > len) + copy = len; + n = INDIRECT_CALL_1(cb, simple_copy_to_iter, + skb->data + offset, copy, data, to); + offset += n; + if (n != copy) + goto short_copy; + if ((len -= copy) == 0) + return 0; + } + + /* Copy paged appendix. Hmm... why does this look so complicated? */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + const skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(frag); + if ((copy = end - offset) > 0) { + struct page *page = skb_frag_page(frag); + u8 *vaddr = kmap(page); + + if (copy > len) + copy = len; + n = INDIRECT_CALL_1(cb, simple_copy_to_iter, + vaddr + skb_frag_off(frag) + offset - start, + copy, data, to); + kunmap(page); + offset += n; + if (n != copy) + goto short_copy; + if (!(len -= copy)) + return 0; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + if (__skb_datagram_iter(frag_iter, offset - start, + to, copy, fault_short, cb, data)) + goto fault; + if ((len -= copy) == 0) + return 0; + offset += copy; + } + start = end; + } + if (!len) + return 0; + + /* This is not really a user copy fault, but rather someone + * gave us a bogus length on the skb. We should probably + * print a warning here as it may indicate a kernel bug. + */ + +fault: + iov_iter_revert(to, offset - start_off); + return -EFAULT; + +short_copy: + if (fault_short || iov_iter_count(to)) + goto fault; + + return 0; +} + +/** + * skb_copy_and_hash_datagram_iter - Copy datagram to an iovec iterator + * and update a hash. + * @skb: buffer to copy + * @offset: offset in the buffer to start copying from + * @to: iovec iterator to copy to + * @len: amount of data to copy from buffer to iovec + * @hash: hash request to update + */ +int skb_copy_and_hash_datagram_iter(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len, + struct ahash_request *hash) +{ + return __skb_datagram_iter(skb, offset, to, len, true, + hash_and_copy_to_iter, hash); +} +EXPORT_SYMBOL(skb_copy_and_hash_datagram_iter); + +static size_t simple_copy_to_iter(const void *addr, size_t bytes, + void *data __always_unused, struct iov_iter *i) +{ + return copy_to_iter(addr, bytes, i); +} + +/** + * skb_copy_datagram_iter - Copy a datagram to an iovec iterator. 
+ * @skb: buffer to copy + * @offset: offset in the buffer to start copying from + * @to: iovec iterator to copy to + * @len: amount of data to copy from buffer to iovec + */ +int skb_copy_datagram_iter(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len) +{ + trace_skb_copy_datagram_iovec(skb, len); + return __skb_datagram_iter(skb, offset, to, len, false, + simple_copy_to_iter, NULL); +} +EXPORT_SYMBOL(skb_copy_datagram_iter); + +/** + * skb_copy_datagram_from_iter - Copy a datagram from an iov_iter. + * @skb: buffer to copy + * @offset: offset in the buffer to start copying to + * @from: the copy source + * @len: amount of data to copy to buffer from iovec + * + * Returns 0 or -EFAULT. + */ +int skb_copy_datagram_from_iter(struct sk_buff *skb, int offset, + struct iov_iter *from, + int len) +{ + int start = skb_headlen(skb); + int i, copy = start - offset; + struct sk_buff *frag_iter; + + /* Copy header. */ + if (copy > 0) { + if (copy > len) + copy = len; + if (copy_from_iter(skb->data + offset, copy, from) != copy) + goto fault; + if ((len -= copy) == 0) + return 0; + offset += copy; + } + + /* Copy paged appendix. Hmm... why does this look so complicated? */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + const skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(frag); + if ((copy = end - offset) > 0) { + size_t copied; + + if (copy > len) + copy = len; + copied = copy_page_from_iter(skb_frag_page(frag), + skb_frag_off(frag) + offset - start, + copy, from); + if (copied != copy) + goto fault; + + if (!(len -= copy)) + return 0; + offset += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + if (skb_copy_datagram_from_iter(frag_iter, + offset - start, + from, copy)) + goto fault; + if ((len -= copy) == 0) + return 0; + offset += copy; + } + start = end; + } + if (!len) + return 0; + +fault: + return -EFAULT; +} +EXPORT_SYMBOL(skb_copy_datagram_from_iter); + +int __zerocopy_sg_from_iter(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb, struct iov_iter *from, + size_t length) +{ + int frag; + + if (msg && msg->msg_ubuf && msg->sg_from_iter) + return msg->sg_from_iter(sk, skb, from, length); + + frag = skb_shinfo(skb)->nr_frags; + + while (length && iov_iter_count(from)) { + struct page *pages[MAX_SKB_FRAGS]; + struct page *last_head = NULL; + size_t start; + ssize_t copied; + unsigned long truesize; + int refs, n = 0; + + if (frag == MAX_SKB_FRAGS) + return -EMSGSIZE; + + copied = iov_iter_get_pages2(from, pages, length, + MAX_SKB_FRAGS - frag, &start); + if (copied < 0) + return -EFAULT; + + length -= copied; + + truesize = PAGE_ALIGN(copied + start); + skb->data_len += copied; + skb->len += copied; + skb->truesize += truesize; + if (sk && sk->sk_type == SOCK_STREAM) { + sk_wmem_queued_add(sk, truesize); + if (!skb_zcopy_pure(skb)) + sk_mem_charge(sk, truesize); + } else { + refcount_add(truesize, &skb->sk->sk_wmem_alloc); + } + for (refs = 0; copied != 0; start = 0) { + int size = min_t(int, copied, PAGE_SIZE - start); + struct page *head = compound_head(pages[n]); + + start += (pages[n] - head) << PAGE_SHIFT; + copied -= size; + n++; + if (frag) { + skb_frag_t *last = &skb_shinfo(skb)->frags[frag - 1]; + + if (head == skb_frag_page(last) && + start == skb_frag_off(last) + skb_frag_size(last)) { + 
skb_frag_size_add(last, size); + /* We combined this page, we need to release + * a reference. Since compound pages refcount + * is shared among many pages, batch the refcount + * adjustments to limit false sharing. + */ + last_head = head; + refs++; + continue; + } + } + if (refs) { + page_ref_sub(last_head, refs); + refs = 0; + } + skb_fill_page_desc_noacc(skb, frag++, head, start, size); + } + if (refs) + page_ref_sub(last_head, refs); + } + return 0; +} +EXPORT_SYMBOL(__zerocopy_sg_from_iter); + +/** + * zerocopy_sg_from_iter - Build a zerocopy datagram from an iov_iter + * @skb: buffer to copy + * @from: the source to copy from + * + * The function will first copy up to headlen, and then pin the userspace + * pages and build frags through them. + * + * Returns 0, -EFAULT or -EMSGSIZE. + */ +int zerocopy_sg_from_iter(struct sk_buff *skb, struct iov_iter *from) +{ + int copy = min_t(int, skb_headlen(skb), iov_iter_count(from)); + + /* copy up to skb headlen */ + if (skb_copy_datagram_from_iter(skb, 0, from, copy)) + return -EFAULT; + + return __zerocopy_sg_from_iter(NULL, NULL, skb, from, ~0U); +} +EXPORT_SYMBOL(zerocopy_sg_from_iter); + +/** + * skb_copy_and_csum_datagram - Copy datagram to an iovec iterator + * and update a checksum. + * @skb: buffer to copy + * @offset: offset in the buffer to start copying from + * @to: iovec iterator to copy to + * @len: amount of data to copy from buffer to iovec + * @csump: checksum pointer + */ +static int skb_copy_and_csum_datagram(const struct sk_buff *skb, int offset, + struct iov_iter *to, int len, + __wsum *csump) +{ + struct csum_state csdata = { .csum = *csump }; + int ret; + + ret = __skb_datagram_iter(skb, offset, to, len, true, + csum_and_copy_to_iter, &csdata); + if (ret) + return ret; + + *csump = csdata.csum; + return 0; +} + +/** + * skb_copy_and_csum_datagram_msg - Copy and checksum skb to user iovec. + * @skb: skbuff + * @hlen: hardware length + * @msg: destination + * + * Caller _must_ check that skb will fit to this iovec. + * + * Returns: 0 - success. + * -EINVAL - checksum failure. + * -EFAULT - fault during copy. + */ +int skb_copy_and_csum_datagram_msg(struct sk_buff *skb, + int hlen, struct msghdr *msg) +{ + __wsum csum; + int chunk = skb->len - hlen; + + if (!chunk) + return 0; + + if (msg_data_left(msg) < chunk) { + if (__skb_checksum_complete(skb)) + return -EINVAL; + if (skb_copy_datagram_msg(skb, hlen, msg, chunk)) + goto fault; + } else { + csum = csum_partial(skb->data, hlen, skb->csum); + if (skb_copy_and_csum_datagram(skb, hlen, &msg->msg_iter, + chunk, &csum)) + goto fault; + + if (csum_fold(csum)) { + iov_iter_revert(&msg->msg_iter, chunk); + return -EINVAL; + } + + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(NULL, skb); + } + return 0; +fault: + return -EFAULT; +} +EXPORT_SYMBOL(skb_copy_and_csum_datagram_msg); + +/** + * datagram_poll - generic datagram poll + * @file: file struct + * @sock: socket + * @wait: poll table + * + * Datagram poll: Again totally generic. This also handles + * sequenced packet sockets providing the socket receive queue + * is only ever holding data ready to receive. + * + * Note: when you *don't* use this routine for this protocol, + * and you use a different write policy from sock_writeable() + * then please supply your own write_space callback. 
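+ *
+ * (Editor's note, illustrative only: datagram protocols typically wire
+ * this helper straight into their struct proto_ops, e.g. with an
+ * initializer line such as ".poll = datagram_poll,".)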
+ */ +__poll_t datagram_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask; + u8 shutdown; + + sock_poll_wait(file, sock, wait); + mask = 0; + + /* exceptional events? */ + if (READ_ONCE(sk->sk_err) || + !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? EPOLLPRI : 0); + + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + if (shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + /* readable? */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* Connection-based need to check for termination and startup */ + if (connection_based(sk)) { + int state = READ_ONCE(sk->sk_state); + + if (state == TCP_CLOSE) + mask |= EPOLLHUP; + /* connection hasn't started yet? */ + if (state == TCP_SYN_SENT) + return mask; + } + + /* writable? */ + if (sock_writeable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + else + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + return mask; +} +EXPORT_SYMBOL(datagram_poll); diff --git a/net/core/dev.c b/net/core/dev.c new file mode 100644 index 000000000..1ba3662fa --- /dev/null +++ b/net/core/dev.c @@ -0,0 +1,11486 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3 Protocol independent device support routines. + * + * Derived from the non IP parts of dev.c 1.0.19 + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * + * Additional Authors: + * Florian la Roche + * Alan Cox + * David Hinds + * Alexey Kuznetsov + * Adam Sulmicki + * Pekka Riikonen + * + * Changes: + * D.J. Barrow : Fixed bug where dev->refcnt gets set + * to 2 if register_netdev gets called + * before net_dev_init & also removed a + * few lines of code in the process. + * Alan Cox : device private ioctl copies fields back. + * Alan Cox : Transmit queue code does relevant + * stunts to keep the queue safe. + * Alan Cox : Fixed double lock. + * Alan Cox : Fixed promisc NULL pointer trap + * ???????? : Support the full private ioctl range + * Alan Cox : Moved ioctl permission check into + * drivers + * Tim Kordas : SIOCADDMULTI/SIOCDELMULTI + * Alan Cox : 100 backlog just doesn't cut it when + * you start doing multicast video 8) + * Alan Cox : Rewrote net_bh and list manager. + * Alan Cox : Fix ETH_P_ALL echoback lengths. + * Alan Cox : Took out transmit every packet pass + * Saved a few bytes in the ioctl handler + * Alan Cox : Network driver sets packet type before + * calling netif_rx. Saves a function + * call a packet. + * Alan Cox : Hashed net_bh() + * Richard Kooijman: Timestamp fixes. + * Alan Cox : Wrong field in SIOCGIFDSTADDR + * Alan Cox : Device lock protection. + * Alan Cox : Fixed nasty side effect of device close + * changes. + * Rudi Cilibrasi : Pass the right thing to + * set_mac_address() + * Dave Miller : 32bit quantity for the device lock to + * make it work out on a Sparc. + * Bjorn Ekwall : Added KERNELD hack. + * Alan Cox : Cleaned up the backlog initialise. + * Craig Metz : SIOCGIFCONF fix if space for under + * 1 device. + * Thomas Bogendoerfer : Return ENODEV for dev_open, if there + * is no device open function. + * Andi Kleen : Fix error reporting for SIOCGIFCONF + * Michael Chastain : Fix signed/unsigned for SIOCGIFCONF + * Cyrus Durgin : Cleaned for KMOD + * Adam Sulmicki : Bug Fix : Network Device Unload + * A network device unload needs to purge + * the backlog queue. 
+ * Paul Rusty Russell : SIOCSIFNAME + * Pekka Riikonen : Netdev boot-time settings code + * Andrew Morton : Make unregister_netdevice wait + * indefinitely on dev->refcnt + * J Hadi Salim : - Backlog queue sampling + * - netif_rx() feedback + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dev.h" +#include "net-sysfs.h" + + +static DEFINE_SPINLOCK(ptype_lock); +struct list_head ptype_base[PTYPE_HASH_SIZE] __read_mostly; +struct list_head ptype_all __read_mostly; /* Taps */ + +static int netif_rx_internal(struct sk_buff *skb); +static int call_netdevice_notifiers_info(unsigned long val, + struct netdev_notifier_info *info); +static int call_netdevice_notifiers_extack(unsigned long val, + struct net_device *dev, + struct netlink_ext_ack *extack); +static struct napi_struct *napi_by_id(unsigned int napi_id); + +/* + * The @dev_base_head list is protected by @dev_base_lock and the rtnl + * semaphore. + * + * Pure readers hold dev_base_lock for reading, or rcu_read_lock() + * + * Writers must hold the rtnl semaphore while they loop through the + * dev_base_head list, and hold dev_base_lock for writing when they do the + * actual updates. This allows pure readers to access the list even + * while a writer is preparing to update it. + * + * To put it another way, dev_base_lock is held for writing only to + * protect against pure readers; the rtnl semaphore provides the + * protection against other writers. + * + * See, for example usages, register_netdevice() and + * unregister_netdevice(), which must be called with the rtnl + * semaphore held. 
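+ *
+ * (Editor's note, illustrative only: a pure reader can therefore walk
+ * the list under RCU alone, e.g.
+ *
+ *	rcu_read_lock();
+ *	for_each_netdev_rcu(net, dev)
+ *		use(dev);
+ *	rcu_read_unlock();
+ *
+ * where use() stands in for the reader's own logic.)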
+ */ +DEFINE_RWLOCK(dev_base_lock); +EXPORT_SYMBOL(dev_base_lock); + +static DEFINE_MUTEX(ifalias_mutex); + +/* protects napi_hash addition/deletion and napi_gen_id */ +static DEFINE_SPINLOCK(napi_hash_lock); + +static unsigned int napi_gen_id = NR_CPUS; +static DEFINE_READ_MOSTLY_HASHTABLE(napi_hash, 8); + +static DECLARE_RWSEM(devnet_rename_sem); + +static inline void dev_base_seq_inc(struct net *net) +{ + while (++net->dev_base_seq == 0) + ; +} + +static inline struct hlist_head *dev_name_hash(struct net *net, const char *name) +{ + unsigned int hash = full_name_hash(net, name, strnlen(name, IFNAMSIZ)); + + return &net->dev_name_head[hash_32(hash, NETDEV_HASHBITS)]; +} + +static inline struct hlist_head *dev_index_hash(struct net *net, int ifindex) +{ + return &net->dev_index_head[ifindex & (NETDEV_HASHENTRIES - 1)]; +} + +static inline void rps_lock_irqsave(struct softnet_data *sd, + unsigned long *flags) +{ + if (IS_ENABLED(CONFIG_RPS)) + spin_lock_irqsave(&sd->input_pkt_queue.lock, *flags); + else if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + local_irq_save(*flags); +} + +static inline void rps_lock_irq_disable(struct softnet_data *sd) +{ + if (IS_ENABLED(CONFIG_RPS)) + spin_lock_irq(&sd->input_pkt_queue.lock); + else if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + local_irq_disable(); +} + +static inline void rps_unlock_irq_restore(struct softnet_data *sd, + unsigned long *flags) +{ + if (IS_ENABLED(CONFIG_RPS)) + spin_unlock_irqrestore(&sd->input_pkt_queue.lock, *flags); + else if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + local_irq_restore(*flags); +} + +static inline void rps_unlock_irq_enable(struct softnet_data *sd) +{ + if (IS_ENABLED(CONFIG_RPS)) + spin_unlock_irq(&sd->input_pkt_queue.lock); + else if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + local_irq_enable(); +} + +static struct netdev_name_node *netdev_name_node_alloc(struct net_device *dev, + const char *name) +{ + struct netdev_name_node *name_node; + + name_node = kmalloc(sizeof(*name_node), GFP_KERNEL); + if (!name_node) + return NULL; + INIT_HLIST_NODE(&name_node->hlist); + name_node->dev = dev; + name_node->name = name; + return name_node; +} + +static struct netdev_name_node * +netdev_name_node_head_alloc(struct net_device *dev) +{ + struct netdev_name_node *name_node; + + name_node = netdev_name_node_alloc(dev, dev->name); + if (!name_node) + return NULL; + INIT_LIST_HEAD(&name_node->list); + return name_node; +} + +static void netdev_name_node_free(struct netdev_name_node *name_node) +{ + kfree(name_node); +} + +static void netdev_name_node_add(struct net *net, + struct netdev_name_node *name_node) +{ + hlist_add_head_rcu(&name_node->hlist, + dev_name_hash(net, name_node->name)); +} + +static void netdev_name_node_del(struct netdev_name_node *name_node) +{ + hlist_del_rcu(&name_node->hlist); +} + +static struct netdev_name_node *netdev_name_node_lookup(struct net *net, + const char *name) +{ + struct hlist_head *head = dev_name_hash(net, name); + struct netdev_name_node *name_node; + + hlist_for_each_entry(name_node, head, hlist) + if (!strcmp(name_node->name, name)) + return name_node; + return NULL; +} + +static struct netdev_name_node *netdev_name_node_lookup_rcu(struct net *net, + const char *name) +{ + struct hlist_head *head = dev_name_hash(net, name); + struct netdev_name_node *name_node; + + hlist_for_each_entry_rcu(name_node, head, hlist) + if (!strcmp(name_node->name, name)) + return name_node; + return NULL; +} + +bool netdev_name_in_use(struct net *net, const char *name) +{ + return netdev_name_node_lookup(net, name); +} 
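For orientation alongside the name-node helpers above, a minimal usage sketch (assumptions: a hypothetical built-in caller that already holds RTNL, and the illustrative alternative name "altname0"; the explicit netdev_name_in_use() check is shown only to exercise that helper, since netdev_name_node_alt_create() performs the same lookup itself and returns -EEXIST):

	/* Sketch only: attach an extra lookup name to @dev while holding RTNL. */
	static int example_add_altname(struct net_device *dev)
	{
		char *name;
		int err;

		ASSERT_RTNL();

		name = kstrdup("altname0", GFP_KERNEL);	/* hypothetical name */
		if (!name)
			return -ENOMEM;

		if (netdev_name_in_use(dev_net(dev), name)) {
			kfree(name);
			return -EEXIST;
		}

		/* On success the name node keeps the pointer passed in and
		 * frees it when the alternative name is later destroyed, so
		 * the caller only frees it on failure.
		 */
		err = netdev_name_node_alt_create(dev, name);
		if (err)
			kfree(name);
		return err;
	}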
+EXPORT_SYMBOL(netdev_name_in_use); + +int netdev_name_node_alt_create(struct net_device *dev, const char *name) +{ + struct netdev_name_node *name_node; + struct net *net = dev_net(dev); + + name_node = netdev_name_node_lookup(net, name); + if (name_node) + return -EEXIST; + name_node = netdev_name_node_alloc(dev, name); + if (!name_node) + return -ENOMEM; + netdev_name_node_add(net, name_node); + /* The node that holds dev->name acts as a head of per-device list. */ + list_add_tail(&name_node->list, &dev->name_node->list); + + return 0; +} + +static void __netdev_name_node_alt_destroy(struct netdev_name_node *name_node) +{ + list_del(&name_node->list); + kfree(name_node->name); + netdev_name_node_free(name_node); +} + +int netdev_name_node_alt_destroy(struct net_device *dev, const char *name) +{ + struct netdev_name_node *name_node; + struct net *net = dev_net(dev); + + name_node = netdev_name_node_lookup(net, name); + if (!name_node) + return -ENOENT; + /* lookup might have found our primary name or a name belonging + * to another device. + */ + if (name_node == dev->name_node || name_node->dev != dev) + return -EINVAL; + + netdev_name_node_del(name_node); + synchronize_rcu(); + __netdev_name_node_alt_destroy(name_node); + + return 0; +} + +static void netdev_name_node_alt_flush(struct net_device *dev) +{ + struct netdev_name_node *name_node, *tmp; + + list_for_each_entry_safe(name_node, tmp, &dev->name_node->list, list) + __netdev_name_node_alt_destroy(name_node); +} + +/* Device list insertion */ +static void list_netdevice(struct net_device *dev) +{ + struct netdev_name_node *name_node; + struct net *net = dev_net(dev); + + ASSERT_RTNL(); + + write_lock(&dev_base_lock); + list_add_tail_rcu(&dev->dev_list, &net->dev_base_head); + netdev_name_node_add(net, dev->name_node); + hlist_add_head_rcu(&dev->index_hlist, + dev_index_hash(net, dev->ifindex)); + write_unlock(&dev_base_lock); + + netdev_for_each_altname(dev, name_node) + netdev_name_node_add(net, name_node); + + dev_base_seq_inc(net); +} + +/* Device list removal + * caller must respect a RCU grace period before freeing/reusing dev + */ +static void unlist_netdevice(struct net_device *dev, bool lock) +{ + struct netdev_name_node *name_node; + + ASSERT_RTNL(); + + netdev_for_each_altname(dev, name_node) + netdev_name_node_del(name_node); + + /* Unlink dev from the device chain */ + if (lock) + write_lock(&dev_base_lock); + list_del_rcu(&dev->dev_list); + netdev_name_node_del(dev->name_node); + hlist_del_rcu(&dev->index_hlist); + if (lock) + write_unlock(&dev_base_lock); + + dev_base_seq_inc(dev_net(dev)); +} + +/* + * Our notifier list + */ + +static RAW_NOTIFIER_HEAD(netdev_chain); + +/* + * Device drivers call our routines to queue packets here. We empty the + * queue in the local softnet handler. 
+ */ + +DEFINE_PER_CPU_ALIGNED(struct softnet_data, softnet_data); +EXPORT_PER_CPU_SYMBOL(softnet_data); + +#ifdef CONFIG_LOCKDEP +/* + * register_netdevice() inits txq->_xmit_lock and sets lockdep class + * according to dev->type + */ +static const unsigned short netdev_lock_type[] = { + ARPHRD_NETROM, ARPHRD_ETHER, ARPHRD_EETHER, ARPHRD_AX25, + ARPHRD_PRONET, ARPHRD_CHAOS, ARPHRD_IEEE802, ARPHRD_ARCNET, + ARPHRD_APPLETLK, ARPHRD_DLCI, ARPHRD_ATM, ARPHRD_METRICOM, + ARPHRD_IEEE1394, ARPHRD_EUI64, ARPHRD_INFINIBAND, ARPHRD_SLIP, + ARPHRD_CSLIP, ARPHRD_SLIP6, ARPHRD_CSLIP6, ARPHRD_RSRVD, + ARPHRD_ADAPT, ARPHRD_ROSE, ARPHRD_X25, ARPHRD_HWX25, + ARPHRD_PPP, ARPHRD_CISCO, ARPHRD_LAPB, ARPHRD_DDCMP, + ARPHRD_RAWHDLC, ARPHRD_TUNNEL, ARPHRD_TUNNEL6, ARPHRD_FRAD, + ARPHRD_SKIP, ARPHRD_LOOPBACK, ARPHRD_LOCALTLK, ARPHRD_FDDI, + ARPHRD_BIF, ARPHRD_SIT, ARPHRD_IPDDP, ARPHRD_IPGRE, + ARPHRD_PIMREG, ARPHRD_HIPPI, ARPHRD_ASH, ARPHRD_ECONET, + ARPHRD_IRDA, ARPHRD_FCPP, ARPHRD_FCAL, ARPHRD_FCPL, + ARPHRD_FCFABRIC, ARPHRD_IEEE80211, ARPHRD_IEEE80211_PRISM, + ARPHRD_IEEE80211_RADIOTAP, ARPHRD_PHONET, ARPHRD_PHONET_PIPE, + ARPHRD_IEEE802154, ARPHRD_VOID, ARPHRD_NONE}; + +static const char *const netdev_lock_name[] = { + "_xmit_NETROM", "_xmit_ETHER", "_xmit_EETHER", "_xmit_AX25", + "_xmit_PRONET", "_xmit_CHAOS", "_xmit_IEEE802", "_xmit_ARCNET", + "_xmit_APPLETLK", "_xmit_DLCI", "_xmit_ATM", "_xmit_METRICOM", + "_xmit_IEEE1394", "_xmit_EUI64", "_xmit_INFINIBAND", "_xmit_SLIP", + "_xmit_CSLIP", "_xmit_SLIP6", "_xmit_CSLIP6", "_xmit_RSRVD", + "_xmit_ADAPT", "_xmit_ROSE", "_xmit_X25", "_xmit_HWX25", + "_xmit_PPP", "_xmit_CISCO", "_xmit_LAPB", "_xmit_DDCMP", + "_xmit_RAWHDLC", "_xmit_TUNNEL", "_xmit_TUNNEL6", "_xmit_FRAD", + "_xmit_SKIP", "_xmit_LOOPBACK", "_xmit_LOCALTLK", "_xmit_FDDI", + "_xmit_BIF", "_xmit_SIT", "_xmit_IPDDP", "_xmit_IPGRE", + "_xmit_PIMREG", "_xmit_HIPPI", "_xmit_ASH", "_xmit_ECONET", + "_xmit_IRDA", "_xmit_FCPP", "_xmit_FCAL", "_xmit_FCPL", + "_xmit_FCFABRIC", "_xmit_IEEE80211", "_xmit_IEEE80211_PRISM", + "_xmit_IEEE80211_RADIOTAP", "_xmit_PHONET", "_xmit_PHONET_PIPE", + "_xmit_IEEE802154", "_xmit_VOID", "_xmit_NONE"}; + +static struct lock_class_key netdev_xmit_lock_key[ARRAY_SIZE(netdev_lock_type)]; +static struct lock_class_key netdev_addr_lock_key[ARRAY_SIZE(netdev_lock_type)]; + +static inline unsigned short netdev_lock_pos(unsigned short dev_type) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(netdev_lock_type); i++) + if (netdev_lock_type[i] == dev_type) + return i; + /* the last key is used by default */ + return ARRAY_SIZE(netdev_lock_type) - 1; +} + +static inline void netdev_set_xmit_lockdep_class(spinlock_t *lock, + unsigned short dev_type) +{ + int i; + + i = netdev_lock_pos(dev_type); + lockdep_set_class_and_name(lock, &netdev_xmit_lock_key[i], + netdev_lock_name[i]); +} + +static inline void netdev_set_addr_lockdep_class(struct net_device *dev) +{ + int i; + + i = netdev_lock_pos(dev->type); + lockdep_set_class_and_name(&dev->addr_list_lock, + &netdev_addr_lock_key[i], + netdev_lock_name[i]); +} +#else +static inline void netdev_set_xmit_lockdep_class(spinlock_t *lock, + unsigned short dev_type) +{ +} + +static inline void netdev_set_addr_lockdep_class(struct net_device *dev) +{ +} +#endif + +/******************************************************************************* + * + * Protocol management and registration routines + * + *******************************************************************************/ + + +/* + * Add a protocol ID to the list. 
Now that the input handler is + * smarter we can dispense with all the messy stuff that used to be + * here. + * + * BEWARE!!! Protocol handlers, mangling input packets, + * MUST BE last in hash buckets and checking protocol handlers + * MUST start from promiscuous ptype_all chain in net_bh. + * It is true now, do not change it. + * Explanation follows: if protocol handler, mangling packet, will + * be the first on list, it is not able to sense, that packet + * is cloned and should be copied-on-write, so that it will + * change it and subsequent readers will get broken packet. + * --ANK (980803) + */ + +static inline struct list_head *ptype_head(const struct packet_type *pt) +{ + if (pt->type == htons(ETH_P_ALL)) + return pt->dev ? &pt->dev->ptype_all : &ptype_all; + else + return pt->dev ? &pt->dev->ptype_specific : + &ptype_base[ntohs(pt->type) & PTYPE_HASH_MASK]; +} + +/** + * dev_add_pack - add packet handler + * @pt: packet type declaration + * + * Add a protocol handler to the networking stack. The passed &packet_type + * is linked into kernel lists and may not be freed until it has been + * removed from the kernel lists. + * + * This call does not sleep therefore it can not + * guarantee all CPU's that are in middle of receiving packets + * will see the new packet type (until the next received packet). + */ + +void dev_add_pack(struct packet_type *pt) +{ + struct list_head *head = ptype_head(pt); + + spin_lock(&ptype_lock); + list_add_rcu(&pt->list, head); + spin_unlock(&ptype_lock); +} +EXPORT_SYMBOL(dev_add_pack); + +/** + * __dev_remove_pack - remove packet handler + * @pt: packet type declaration + * + * Remove a protocol handler that was previously added to the kernel + * protocol handlers by dev_add_pack(). The passed &packet_type is removed + * from the kernel lists and can be freed or reused once this function + * returns. + * + * The packet type might still be in use by receivers + * and must not be freed until after all the CPU's have gone + * through a quiescent state. + */ +void __dev_remove_pack(struct packet_type *pt) +{ + struct list_head *head = ptype_head(pt); + struct packet_type *pt1; + + spin_lock(&ptype_lock); + + list_for_each_entry(pt1, head, list) { + if (pt == pt1) { + list_del_rcu(&pt->list); + goto out; + } + } + + pr_warn("dev_remove_pack: %p not found\n", pt); +out: + spin_unlock(&ptype_lock); +} +EXPORT_SYMBOL(__dev_remove_pack); + +/** + * dev_remove_pack - remove packet handler + * @pt: packet type declaration + * + * Remove a protocol handler that was previously added to the kernel + * protocol handlers by dev_add_pack(). The passed &packet_type is removed + * from the kernel lists and can be freed or reused once this function + * returns. + * + * This call sleeps to guarantee that no CPU is looking at the packet + * type after return. + */ +void dev_remove_pack(struct packet_type *pt) +{ + __dev_remove_pack(pt); + + synchronize_net(); +} +EXPORT_SYMBOL(dev_remove_pack); + + +/******************************************************************************* + * + * Device Interface Subroutines + * + *******************************************************************************/ + +/** + * dev_get_iflink - get 'iflink' value of a interface + * @dev: targeted interface + * + * Indicates the ifindex the interface is linked to. + * Physical interfaces have the same 'ifindex' and 'iflink' values. 
+ */ + +int dev_get_iflink(const struct net_device *dev) +{ + if (dev->netdev_ops && dev->netdev_ops->ndo_get_iflink) + return dev->netdev_ops->ndo_get_iflink(dev); + + return dev->ifindex; +} +EXPORT_SYMBOL(dev_get_iflink); + +/** + * dev_fill_metadata_dst - Retrieve tunnel egress information. + * @dev: targeted interface + * @skb: The packet. + * + * For better visibility of tunnel traffic OVS needs to retrieve + * egress tunnel information for a packet. Following API allows + * user to get this info. + */ +int dev_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) +{ + struct ip_tunnel_info *info; + + if (!dev->netdev_ops || !dev->netdev_ops->ndo_fill_metadata_dst) + return -EINVAL; + + info = skb_tunnel_info_unclone(skb); + if (!info) + return -ENOMEM; + if (unlikely(!(info->mode & IP_TUNNEL_INFO_TX))) + return -EINVAL; + + return dev->netdev_ops->ndo_fill_metadata_dst(dev, skb); +} +EXPORT_SYMBOL_GPL(dev_fill_metadata_dst); + +static struct net_device_path *dev_fwd_path(struct net_device_path_stack *stack) +{ + int k = stack->num_paths++; + + if (WARN_ON_ONCE(k >= NET_DEVICE_PATH_STACK_MAX)) + return NULL; + + return &stack->path[k]; +} + +int dev_fill_forward_path(const struct net_device *dev, const u8 *daddr, + struct net_device_path_stack *stack) +{ + const struct net_device *last_dev; + struct net_device_path_ctx ctx = { + .dev = dev, + }; + struct net_device_path *path; + int ret = 0; + + memcpy(ctx.daddr, daddr, sizeof(ctx.daddr)); + stack->num_paths = 0; + while (ctx.dev && ctx.dev->netdev_ops->ndo_fill_forward_path) { + last_dev = ctx.dev; + path = dev_fwd_path(stack); + if (!path) + return -1; + + memset(path, 0, sizeof(struct net_device_path)); + ret = ctx.dev->netdev_ops->ndo_fill_forward_path(&ctx, path); + if (ret < 0) + return -1; + + if (WARN_ON_ONCE(last_dev == ctx.dev)) + return -1; + } + + if (!ctx.dev) + return ret; + + path = dev_fwd_path(stack); + if (!path) + return -1; + path->type = DEV_PATH_ETHERNET; + path->dev = ctx.dev; + + return ret; +} +EXPORT_SYMBOL_GPL(dev_fill_forward_path); + +/** + * __dev_get_by_name - find a device by its name + * @net: the applicable net namespace + * @name: name to find + * + * Find an interface by name. Must be called under RTNL semaphore + * or @dev_base_lock. If the name is found a pointer to the device + * is returned. If the name is not found then %NULL is returned. The + * reference counters are not incremented so the caller must be + * careful with locks. + */ + +struct net_device *__dev_get_by_name(struct net *net, const char *name) +{ + struct netdev_name_node *node_name; + + node_name = netdev_name_node_lookup(net, name); + return node_name ? node_name->dev : NULL; +} +EXPORT_SYMBOL(__dev_get_by_name); + +/** + * dev_get_by_name_rcu - find a device by its name + * @net: the applicable net namespace + * @name: name to find + * + * Find an interface by name. + * If the name is found a pointer to the device is returned. + * If the name is not found then %NULL is returned. + * The reference counters are not incremented so the caller must be + * careful with locks. The caller must hold RCU lock. + */ + +struct net_device *dev_get_by_name_rcu(struct net *net, const char *name) +{ + struct netdev_name_node *node_name; + + node_name = netdev_name_node_lookup_rcu(net, name); + return node_name ? node_name->dev : NULL; +} +EXPORT_SYMBOL(dev_get_by_name_rcu); + +/** + * dev_get_by_name - find a device by its name + * @net: the applicable net namespace + * @name: name to find + * + * Find an interface by name. 
This can be called from any + * context and does its own locking. The returned handle has + * the usage count incremented and the caller must use dev_put() to + * release it when it is no longer needed. %NULL is returned if no + * matching device is found. + */ + +struct net_device *dev_get_by_name(struct net *net, const char *name) +{ + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_name_rcu(net, name); + dev_hold(dev); + rcu_read_unlock(); + return dev; +} +EXPORT_SYMBOL(dev_get_by_name); + +/** + * __dev_get_by_index - find a device by its ifindex + * @net: the applicable net namespace + * @ifindex: index of device + * + * Search for an interface by index. Returns %NULL if the device + * is not found or a pointer to the device. The device has not + * had its reference counter increased so the caller must be careful + * about locking. The caller must hold either the RTNL semaphore + * or @dev_base_lock. + */ + +struct net_device *__dev_get_by_index(struct net *net, int ifindex) +{ + struct net_device *dev; + struct hlist_head *head = dev_index_hash(net, ifindex); + + hlist_for_each_entry(dev, head, index_hlist) + if (dev->ifindex == ifindex) + return dev; + + return NULL; +} +EXPORT_SYMBOL(__dev_get_by_index); + +/** + * dev_get_by_index_rcu - find a device by its ifindex + * @net: the applicable net namespace + * @ifindex: index of device + * + * Search for an interface by index. Returns %NULL if the device + * is not found or a pointer to the device. The device has not + * had its reference counter increased so the caller must be careful + * about locking. The caller must hold RCU lock. + */ + +struct net_device *dev_get_by_index_rcu(struct net *net, int ifindex) +{ + struct net_device *dev; + struct hlist_head *head = dev_index_hash(net, ifindex); + + hlist_for_each_entry_rcu(dev, head, index_hlist) + if (dev->ifindex == ifindex) + return dev; + + return NULL; +} +EXPORT_SYMBOL(dev_get_by_index_rcu); + + +/** + * dev_get_by_index - find a device by its ifindex + * @net: the applicable net namespace + * @ifindex: index of device + * + * Search for an interface by index. Returns NULL if the device + * is not found or a pointer to the device. The device returned has + * had a reference added and the pointer is safe until the user calls + * dev_put to indicate they have finished with it. + */ + +struct net_device *dev_get_by_index(struct net *net, int ifindex) +{ + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, ifindex); + dev_hold(dev); + rcu_read_unlock(); + return dev; +} +EXPORT_SYMBOL(dev_get_by_index); + +/** + * dev_get_by_napi_id - find a device by napi_id + * @napi_id: ID of the NAPI struct + * + * Search for an interface by NAPI ID. Returns %NULL if the device + * is not found or a pointer to the device. The device has not had + * its reference counter increased so the caller must be careful + * about locking. The caller must hold RCU lock. + */ + +struct net_device *dev_get_by_napi_id(unsigned int napi_id) +{ + struct napi_struct *napi; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (napi_id < MIN_NAPI_ID) + return NULL; + + napi = napi_by_id(napi_id); + + return napi ? napi->dev : NULL; +} +EXPORT_SYMBOL(dev_get_by_napi_id); + +/** + * netdev_get_name - get a netdevice name, knowing its ifindex. + * @net: network namespace + * @name: a pointer to the buffer where the name will be stored. + * @ifindex: the ifindex of the interface to get the name from. 
+ */ +int netdev_get_name(struct net *net, char *name, int ifindex) +{ + struct net_device *dev; + int ret; + + down_read(&devnet_rename_sem); + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, ifindex); + if (!dev) { + ret = -ENODEV; + goto out; + } + + strcpy(name, dev->name); + + ret = 0; +out: + rcu_read_unlock(); + up_read(&devnet_rename_sem); + return ret; +} + +/** + * dev_getbyhwaddr_rcu - find a device by its hardware address + * @net: the applicable net namespace + * @type: media type of device + * @ha: hardware address + * + * Search for an interface by MAC address. Returns NULL if the device + * is not found or a pointer to the device. + * The caller must hold RCU or RTNL. + * The returned device has not had its ref count increased + * and the caller must therefore be careful about locking + * + */ + +struct net_device *dev_getbyhwaddr_rcu(struct net *net, unsigned short type, + const char *ha) +{ + struct net_device *dev; + + for_each_netdev_rcu(net, dev) + if (dev->type == type && + !memcmp(dev->dev_addr, ha, dev->addr_len)) + return dev; + + return NULL; +} +EXPORT_SYMBOL(dev_getbyhwaddr_rcu); + +struct net_device *dev_getfirstbyhwtype(struct net *net, unsigned short type) +{ + struct net_device *dev, *ret = NULL; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) + if (dev->type == type) { + dev_hold(dev); + ret = dev; + break; + } + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL(dev_getfirstbyhwtype); + +/** + * __dev_get_by_flags - find any device with given flags + * @net: the applicable net namespace + * @if_flags: IFF_* values + * @mask: bitmask of bits in if_flags to check + * + * Search for any interface with the given flags. Returns NULL if a device + * is not found or a pointer to the device. Must be called inside + * rtnl_lock(), and result refcount is unchanged. + */ + +struct net_device *__dev_get_by_flags(struct net *net, unsigned short if_flags, + unsigned short mask) +{ + struct net_device *dev, *ret; + + ASSERT_RTNL(); + + ret = NULL; + for_each_netdev(net, dev) { + if (((dev->flags ^ if_flags) & mask) == 0) { + ret = dev; + break; + } + } + return ret; +} +EXPORT_SYMBOL(__dev_get_by_flags); + +/** + * dev_valid_name - check if name is okay for network device + * @name: name string + * + * Network device names need to be valid file names to + * allow sysfs to work. We also disallow any kind of + * whitespace. + */ +bool dev_valid_name(const char *name) +{ + if (*name == '\0') + return false; + if (strnlen(name, IFNAMSIZ) == IFNAMSIZ) + return false; + if (!strcmp(name, ".") || !strcmp(name, "..")) + return false; + + while (*name) { + if (*name == '/' || *name == ':' || isspace(*name)) + return false; + name++; + } + return true; +} +EXPORT_SYMBOL(dev_valid_name); + +/** + * __dev_alloc_name - allocate a name for a device + * @net: network namespace to allocate the device name in + * @name: name format string + * @buf: scratch buffer and result name string + * + * Passed a format string - eg "lt%d" it will try and find a suitable + * id. It scans list of devices to build up a free map, then chooses + * the first empty slot. The caller must hold the dev_base or rtnl lock + * while allocating the name and adding the device in order to avoid + * duplicates. + * Limited to bits_per_byte * page size devices (ie 32K on most platforms). + * Returns the number of the unit assigned or a negative errno code. 
+ */ + +static int __dev_alloc_name(struct net *net, const char *name, char *buf) +{ + int i = 0; + const char *p; + const int max_netdevices = 8*PAGE_SIZE; + unsigned long *inuse; + struct net_device *d; + + if (!dev_valid_name(name)) + return -EINVAL; + + p = strchr(name, '%'); + if (p) { + /* + * Verify the string as this thing may have come from + * the user. There must be either one "%d" and no other "%" + * characters. + */ + if (p[1] != 'd' || strchr(p + 2, '%')) + return -EINVAL; + + /* Use one page as a bit array of possible slots */ + inuse = (unsigned long *) get_zeroed_page(GFP_ATOMIC); + if (!inuse) + return -ENOMEM; + + for_each_netdev(net, d) { + struct netdev_name_node *name_node; + + netdev_for_each_altname(d, name_node) { + if (!sscanf(name_node->name, name, &i)) + continue; + if (i < 0 || i >= max_netdevices) + continue; + + /* avoid cases where sscanf is not exact inverse of printf */ + snprintf(buf, IFNAMSIZ, name, i); + if (!strncmp(buf, name_node->name, IFNAMSIZ)) + __set_bit(i, inuse); + } + if (!sscanf(d->name, name, &i)) + continue; + if (i < 0 || i >= max_netdevices) + continue; + + /* avoid cases where sscanf is not exact inverse of printf */ + snprintf(buf, IFNAMSIZ, name, i); + if (!strncmp(buf, d->name, IFNAMSIZ)) + __set_bit(i, inuse); + } + + i = find_first_zero_bit(inuse, max_netdevices); + free_page((unsigned long) inuse); + } + + snprintf(buf, IFNAMSIZ, name, i); + if (!netdev_name_in_use(net, buf)) + return i; + + /* It is possible to run out of possible slots + * when the name is long and there isn't enough space left + * for the digits, or if all bits are used. + */ + return -ENFILE; +} + +static int dev_prep_valid_name(struct net *net, struct net_device *dev, + const char *want_name, char *out_name) +{ + int ret; + + if (!dev_valid_name(want_name)) + return -EINVAL; + + if (strchr(want_name, '%')) { + ret = __dev_alloc_name(net, want_name, out_name); + return ret < 0 ? ret : 0; + } else if (netdev_name_in_use(net, want_name)) { + return -EEXIST; + } else if (out_name != want_name) { + strscpy(out_name, want_name, IFNAMSIZ); + } + + return 0; +} + +static int dev_alloc_name_ns(struct net *net, + struct net_device *dev, + const char *name) +{ + char buf[IFNAMSIZ]; + int ret; + + BUG_ON(!net); + ret = __dev_alloc_name(net, name, buf); + if (ret >= 0) + strscpy(dev->name, buf, IFNAMSIZ); + return ret; +} + +/** + * dev_alloc_name - allocate a name for a device + * @dev: device + * @name: name format string + * + * Passed a format string - eg "lt%d" it will try and find a suitable + * id. It scans list of devices to build up a free map, then chooses + * the first empty slot. The caller must hold the dev_base or rtnl lock + * while allocating the name and adding the device in order to avoid + * duplicates. + * Limited to bits_per_byte * page size devices (ie 32K on most platforms). + * Returns the number of the unit assigned or a negative errno code. + */ + +int dev_alloc_name(struct net_device *dev, const char *name) +{ + return dev_alloc_name_ns(dev_net(dev), dev, name); +} +EXPORT_SYMBOL(dev_alloc_name); + +static int dev_get_valid_name(struct net *net, struct net_device *dev, + const char *name) +{ + char buf[IFNAMSIZ]; + int ret; + + ret = dev_prep_valid_name(net, dev, name, buf); + if (ret >= 0) + strscpy(dev->name, buf, IFNAMSIZ); + return ret; +} + +/** + * dev_change_name - change name of a device + * @dev: device + * @newname: name (or format string) must be at least IFNAMSIZ + * + * Change name of a device, can pass format strings "eth%d". 
+ * for wildcarding. + */ +int dev_change_name(struct net_device *dev, const char *newname) +{ + unsigned char old_assign_type; + char oldname[IFNAMSIZ]; + int err = 0; + int ret; + struct net *net; + + ASSERT_RTNL(); + BUG_ON(!dev_net(dev)); + + net = dev_net(dev); + + /* Some auto-enslaved devices e.g. failover slaves are + * special, as userspace might rename the device after + * the interface had been brought up and running since + * the point kernel initiated auto-enslavement. Allow + * live name change even when these slave devices are + * up and running. + * + * Typically, users of these auto-enslaving devices + * don't actually care about slave name change, as + * they are supposed to operate on master interface + * directly. + */ + if (dev->flags & IFF_UP && + likely(!(dev->priv_flags & IFF_LIVE_RENAME_OK))) + return -EBUSY; + + down_write(&devnet_rename_sem); + + if (strncmp(newname, dev->name, IFNAMSIZ) == 0) { + up_write(&devnet_rename_sem); + return 0; + } + + memcpy(oldname, dev->name, IFNAMSIZ); + + err = dev_get_valid_name(net, dev, newname); + if (err < 0) { + up_write(&devnet_rename_sem); + return err; + } + + if (oldname[0] && !strchr(oldname, '%')) + netdev_info(dev, "renamed from %s\n", oldname); + + old_assign_type = dev->name_assign_type; + dev->name_assign_type = NET_NAME_RENAMED; + +rollback: + ret = device_rename(&dev->dev, dev->name); + if (ret) { + memcpy(dev->name, oldname, IFNAMSIZ); + dev->name_assign_type = old_assign_type; + up_write(&devnet_rename_sem); + return ret; + } + + up_write(&devnet_rename_sem); + + netdev_adjacent_rename_links(dev, oldname); + + write_lock(&dev_base_lock); + netdev_name_node_del(dev->name_node); + write_unlock(&dev_base_lock); + + synchronize_rcu(); + + write_lock(&dev_base_lock); + netdev_name_node_add(net, dev->name_node); + write_unlock(&dev_base_lock); + + ret = call_netdevice_notifiers(NETDEV_CHANGENAME, dev); + ret = notifier_to_errno(ret); + + if (ret) { + /* err >= 0 after dev_alloc_name() or stores the first errno */ + if (err >= 0) { + err = ret; + down_write(&devnet_rename_sem); + memcpy(dev->name, oldname, IFNAMSIZ); + memcpy(oldname, newname, IFNAMSIZ); + dev->name_assign_type = old_assign_type; + old_assign_type = NET_NAME_RENAMED; + goto rollback; + } else { + netdev_err(dev, "name change rollback failed: %d\n", + ret); + } + } + + return err; +} + +/** + * dev_set_alias - change ifalias of a device + * @dev: device + * @alias: name up to IFALIASZ + * @len: limit of bytes to copy from info + * + * Set ifalias for a device, + */ +int dev_set_alias(struct net_device *dev, const char *alias, size_t len) +{ + struct dev_ifalias *new_alias = NULL; + + if (len >= IFALIASZ) + return -EINVAL; + + if (len) { + new_alias = kmalloc(sizeof(*new_alias) + len + 1, GFP_KERNEL); + if (!new_alias) + return -ENOMEM; + + memcpy(new_alias->ifalias, alias, len); + new_alias->ifalias[len] = 0; + } + + mutex_lock(&ifalias_mutex); + new_alias = rcu_replace_pointer(dev->ifalias, new_alias, + mutex_is_locked(&ifalias_mutex)); + mutex_unlock(&ifalias_mutex); + + if (new_alias) + kfree_rcu(new_alias, rcuhead); + + return len; +} +EXPORT_SYMBOL(dev_set_alias); + +/** + * dev_get_alias - get ifalias of a device + * @dev: device + * @name: buffer to store name of ifalias + * @len: size of buffer + * + * get ifalias for a device. Caller must make sure dev cannot go + * away, e.g. rcu read lock or own a reference count to device. 
+ */ +int dev_get_alias(const struct net_device *dev, char *name, size_t len) +{ + const struct dev_ifalias *alias; + int ret = 0; + + rcu_read_lock(); + alias = rcu_dereference(dev->ifalias); + if (alias) + ret = snprintf(name, len, "%s", alias->ifalias); + rcu_read_unlock(); + + return ret; +} + +/** + * netdev_features_change - device changes features + * @dev: device to cause notification + * + * Called to indicate a device has changed features. + */ +void netdev_features_change(struct net_device *dev) +{ + call_netdevice_notifiers(NETDEV_FEAT_CHANGE, dev); +} +EXPORT_SYMBOL(netdev_features_change); + +/** + * netdev_state_change - device changes state + * @dev: device to cause notification + * + * Called to indicate a device has changed state. This function calls + * the notifier chains for netdev_chain and sends a NEWLINK message + * to the routing socket. + */ +void netdev_state_change(struct net_device *dev) +{ + if (dev->flags & IFF_UP) { + struct netdev_notifier_change_info change_info = { + .info.dev = dev, + }; + + call_netdevice_notifiers_info(NETDEV_CHANGE, + &change_info.info); + rtmsg_ifinfo(RTM_NEWLINK, dev, 0, GFP_KERNEL); + } +} +EXPORT_SYMBOL(netdev_state_change); + +/** + * __netdev_notify_peers - notify network peers about existence of @dev, + * to be called when rtnl lock is already held. + * @dev: network device + * + * Generate traffic such that interested network peers are aware of + * @dev, such as by generating a gratuitous ARP. This may be used when + * a device wants to inform the rest of the network about some sort of + * reconfiguration such as a failover event or virtual machine + * migration. + */ +void __netdev_notify_peers(struct net_device *dev) +{ + ASSERT_RTNL(); + call_netdevice_notifiers(NETDEV_NOTIFY_PEERS, dev); + call_netdevice_notifiers(NETDEV_RESEND_IGMP, dev); +} +EXPORT_SYMBOL(__netdev_notify_peers); + +/** + * netdev_notify_peers - notify network peers about existence of @dev + * @dev: network device + * + * Generate traffic such that interested network peers are aware of + * @dev, such as by generating a gratuitous ARP. This may be used when + * a device wants to inform the rest of the network about some sort of + * reconfiguration such as a failover event or virtual machine + * migration. + */ +void netdev_notify_peers(struct net_device *dev) +{ + rtnl_lock(); + __netdev_notify_peers(dev); + rtnl_unlock(); +} +EXPORT_SYMBOL(netdev_notify_peers); + +static int napi_threaded_poll(void *data); + +static int napi_kthread_create(struct napi_struct *n) +{ + int err = 0; + + /* Create and wake up the kthread once to put it in + * TASK_INTERRUPTIBLE mode to avoid the blocked task + * warning and work with loadavg. + */ + n->thread = kthread_run(napi_threaded_poll, n, "napi/%s-%d", + n->dev->name, n->napi_id); + if (IS_ERR(n->thread)) { + err = PTR_ERR(n->thread); + pr_err("kthread_run failed with err %d\n", err); + n->thread = NULL; + } + + return err; +} + +static int __dev_open(struct net_device *dev, struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + int ret; + + ASSERT_RTNL(); + dev_addr_check(dev); + + if (!netif_device_present(dev)) { + /* may be detached because parent is runtime-suspended */ + if (dev->dev.parent) + pm_runtime_resume(dev->dev.parent); + if (!netif_device_present(dev)) + return -ENODEV; + } + + /* Block netpoll from trying to do any rx path servicing. 
+ * If we don't do this there is a chance ndo_poll_controller + * or ndo_poll may be running while we open the device + */ + netpoll_poll_disable(dev); + + ret = call_netdevice_notifiers_extack(NETDEV_PRE_UP, dev, extack); + ret = notifier_to_errno(ret); + if (ret) + return ret; + + set_bit(__LINK_STATE_START, &dev->state); + + if (ops->ndo_validate_addr) + ret = ops->ndo_validate_addr(dev); + + if (!ret && ops->ndo_open) + ret = ops->ndo_open(dev); + + netpoll_poll_enable(dev); + + if (ret) + clear_bit(__LINK_STATE_START, &dev->state); + else { + dev->flags |= IFF_UP; + dev_set_rx_mode(dev); + dev_activate(dev); + add_device_randomness(dev->dev_addr, dev->addr_len); + } + + return ret; +} + +/** + * dev_open - prepare an interface for use. + * @dev: device to open + * @extack: netlink extended ack + * + * Takes a device from down to up state. The device's private open + * function is invoked and then the multicast lists are loaded. Finally + * the device is moved into the up state and a %NETDEV_UP message is + * sent to the netdev notifier chain. + * + * Calling this function on an active interface is a nop. On a failure + * a negative errno code is returned. + */ +int dev_open(struct net_device *dev, struct netlink_ext_ack *extack) +{ + int ret; + + if (dev->flags & IFF_UP) + return 0; + + ret = __dev_open(dev, extack); + if (ret < 0) + return ret; + + rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP|IFF_RUNNING, GFP_KERNEL); + call_netdevice_notifiers(NETDEV_UP, dev); + + return ret; +} +EXPORT_SYMBOL(dev_open); + +static void __dev_close_many(struct list_head *head) +{ + struct net_device *dev; + + ASSERT_RTNL(); + might_sleep(); + + list_for_each_entry(dev, head, close_list) { + /* Temporarily disable netpoll until the interface is down */ + netpoll_poll_disable(dev); + + call_netdevice_notifiers(NETDEV_GOING_DOWN, dev); + + clear_bit(__LINK_STATE_START, &dev->state); + + /* Synchronize to scheduled poll. We cannot touch poll list, it + * can be even on different cpu. So just clear netif_running(). + * + * dev->stop() will invoke napi_disable() on all of it's + * napi_struct instances on this device. + */ + smp_mb__after_atomic(); /* Commit netif_running(). */ + } + + dev_deactivate_many(head); + + list_for_each_entry(dev, head, close_list) { + const struct net_device_ops *ops = dev->netdev_ops; + + /* + * Call the device specific close. This cannot fail. + * Only if device is UP + * + * We allow it to be called even after a DETACH hot-plug + * event. + */ + if (ops->ndo_stop) + ops->ndo_stop(dev); + + dev->flags &= ~IFF_UP; + netpoll_poll_enable(dev); + } +} + +static void __dev_close(struct net_device *dev) +{ + LIST_HEAD(single); + + list_add(&dev->close_list, &single); + __dev_close_many(&single); + list_del(&single); +} + +void dev_close_many(struct list_head *head, bool unlink) +{ + struct net_device *dev, *tmp; + + /* Remove the devices that don't need to be closed */ + list_for_each_entry_safe(dev, tmp, head, close_list) + if (!(dev->flags & IFF_UP)) + list_del_init(&dev->close_list); + + __dev_close_many(head); + + list_for_each_entry_safe(dev, tmp, head, close_list) { + rtmsg_ifinfo(RTM_NEWLINK, dev, IFF_UP|IFF_RUNNING, GFP_KERNEL); + call_netdevice_notifiers(NETDEV_DOWN, dev); + if (unlink) + list_del_init(&dev->close_list); + } +} +EXPORT_SYMBOL(dev_close_many); + +/** + * dev_close - shutdown an interface. + * @dev: device to shutdown + * + * This function moves an active device into down state. A + * %NETDEV_GOING_DOWN is sent to the netdev notifier chain. 
The device + * is then deactivated and finally a %NETDEV_DOWN is sent to the notifier + * chain. + */ +void dev_close(struct net_device *dev) +{ + if (dev->flags & IFF_UP) { + LIST_HEAD(single); + + list_add(&dev->close_list, &single); + dev_close_many(&single, true); + list_del(&single); + } +} +EXPORT_SYMBOL(dev_close); + + +/** + * dev_disable_lro - disable Large Receive Offload on a device + * @dev: device + * + * Disable Large Receive Offload (LRO) on a net device. Must be + * called under RTNL. This is needed if received packets may be + * forwarded to another interface. + */ +void dev_disable_lro(struct net_device *dev) +{ + struct net_device *lower_dev; + struct list_head *iter; + + dev->wanted_features &= ~NETIF_F_LRO; + netdev_update_features(dev); + + if (unlikely(dev->features & NETIF_F_LRO)) + netdev_WARN(dev, "failed to disable LRO!\n"); + + netdev_for_each_lower_dev(dev, lower_dev, iter) + dev_disable_lro(lower_dev); +} +EXPORT_SYMBOL(dev_disable_lro); + +/** + * dev_disable_gro_hw - disable HW Generic Receive Offload on a device + * @dev: device + * + * Disable HW Generic Receive Offload (GRO_HW) on a net device. Must be + * called under RTNL. This is needed if Generic XDP is installed on + * the device. + */ +static void dev_disable_gro_hw(struct net_device *dev) +{ + dev->wanted_features &= ~NETIF_F_GRO_HW; + netdev_update_features(dev); + + if (unlikely(dev->features & NETIF_F_GRO_HW)) + netdev_WARN(dev, "failed to disable GRO_HW!\n"); +} + +const char *netdev_cmd_to_name(enum netdev_cmd cmd) +{ +#define N(val) \ + case NETDEV_##val: \ + return "NETDEV_" __stringify(val); + switch (cmd) { + N(UP) N(DOWN) N(REBOOT) N(CHANGE) N(REGISTER) N(UNREGISTER) + N(CHANGEMTU) N(CHANGEADDR) N(GOING_DOWN) N(CHANGENAME) N(FEAT_CHANGE) + N(BONDING_FAILOVER) N(PRE_UP) N(PRE_TYPE_CHANGE) N(POST_TYPE_CHANGE) + N(POST_INIT) N(RELEASE) N(NOTIFY_PEERS) N(JOIN) N(CHANGEUPPER) + N(RESEND_IGMP) N(PRECHANGEMTU) N(CHANGEINFODATA) N(BONDING_INFO) + N(PRECHANGEUPPER) N(CHANGELOWERSTATE) N(UDP_TUNNEL_PUSH_INFO) + N(UDP_TUNNEL_DROP_INFO) N(CHANGE_TX_QUEUE_LEN) + N(CVLAN_FILTER_PUSH_INFO) N(CVLAN_FILTER_DROP_INFO) + N(SVLAN_FILTER_PUSH_INFO) N(SVLAN_FILTER_DROP_INFO) + N(PRE_CHANGEADDR) N(OFFLOAD_XSTATS_ENABLE) N(OFFLOAD_XSTATS_DISABLE) + N(OFFLOAD_XSTATS_REPORT_USED) N(OFFLOAD_XSTATS_REPORT_DELTA) + } +#undef N + return "UNKNOWN_NETDEV_EVENT"; +} +EXPORT_SYMBOL_GPL(netdev_cmd_to_name); + +static int call_netdevice_notifier(struct notifier_block *nb, unsigned long val, + struct net_device *dev) +{ + struct netdev_notifier_info info = { + .dev = dev, + }; + + return nb->notifier_call(nb, val, &info); +} + +static int call_netdevice_register_notifiers(struct notifier_block *nb, + struct net_device *dev) +{ + int err; + + err = call_netdevice_notifier(nb, NETDEV_REGISTER, dev); + err = notifier_to_errno(err); + if (err) + return err; + + if (!(dev->flags & IFF_UP)) + return 0; + + call_netdevice_notifier(nb, NETDEV_UP, dev); + return 0; +} + +static void call_netdevice_unregister_notifiers(struct notifier_block *nb, + struct net_device *dev) +{ + if (dev->flags & IFF_UP) { + call_netdevice_notifier(nb, NETDEV_GOING_DOWN, + dev); + call_netdevice_notifier(nb, NETDEV_DOWN, dev); + } + call_netdevice_notifier(nb, NETDEV_UNREGISTER, dev); +} + +static int call_netdevice_register_net_notifiers(struct notifier_block *nb, + struct net *net) +{ + struct net_device *dev; + int err; + + for_each_netdev(net, dev) { + err = call_netdevice_register_notifiers(nb, dev); + if (err) + goto rollback; + } + return 0; + 
+rollback: + for_each_netdev_continue_reverse(net, dev) + call_netdevice_unregister_notifiers(nb, dev); + return err; +} + +static void call_netdevice_unregister_net_notifiers(struct notifier_block *nb, + struct net *net) +{ + struct net_device *dev; + + for_each_netdev(net, dev) + call_netdevice_unregister_notifiers(nb, dev); +} + +static int dev_boot_phase = 1; + +/** + * register_netdevice_notifier - register a network notifier block + * @nb: notifier + * + * Register a notifier to be called when network device events occur. + * The notifier passed is linked into the kernel structures and must + * not be reused until it has been unregistered. A negative errno code + * is returned on a failure. + * + * When registered all registration and up events are replayed + * to the new notifier to allow device to have a race free + * view of the network device list. + */ + +int register_netdevice_notifier(struct notifier_block *nb) +{ + struct net *net; + int err; + + /* Close race with setup_net() and cleanup_net() */ + down_write(&pernet_ops_rwsem); + rtnl_lock(); + err = raw_notifier_chain_register(&netdev_chain, nb); + if (err) + goto unlock; + if (dev_boot_phase) + goto unlock; + for_each_net(net) { + err = call_netdevice_register_net_notifiers(nb, net); + if (err) + goto rollback; + } + +unlock: + rtnl_unlock(); + up_write(&pernet_ops_rwsem); + return err; + +rollback: + for_each_net_continue_reverse(net) + call_netdevice_unregister_net_notifiers(nb, net); + + raw_notifier_chain_unregister(&netdev_chain, nb); + goto unlock; +} +EXPORT_SYMBOL(register_netdevice_notifier); + +/** + * unregister_netdevice_notifier - unregister a network notifier block + * @nb: notifier + * + * Unregister a notifier previously registered by + * register_netdevice_notifier(). The notifier is unlinked into the + * kernel structures and may then be reused. A negative errno code + * is returned on a failure. + * + * After unregistering unregister and down device events are synthesized + * for all devices on the device list to the removed notifier to remove + * the need for special case cleanup code. + */ + +int unregister_netdevice_notifier(struct notifier_block *nb) +{ + struct net *net; + int err; + + /* Close race with setup_net() and cleanup_net() */ + down_write(&pernet_ops_rwsem); + rtnl_lock(); + err = raw_notifier_chain_unregister(&netdev_chain, nb); + if (err) + goto unlock; + + for_each_net(net) + call_netdevice_unregister_net_notifiers(nb, net); + +unlock: + rtnl_unlock(); + up_write(&pernet_ops_rwsem); + return err; +} +EXPORT_SYMBOL(unregister_netdevice_notifier); + +static int __register_netdevice_notifier_net(struct net *net, + struct notifier_block *nb, + bool ignore_call_fail) +{ + int err; + + err = raw_notifier_chain_register(&net->netdev_chain, nb); + if (err) + return err; + if (dev_boot_phase) + return 0; + + err = call_netdevice_register_net_notifiers(nb, net); + if (err && !ignore_call_fail) + goto chain_unregister; + + return 0; + +chain_unregister: + raw_notifier_chain_unregister(&net->netdev_chain, nb); + return err; +} + +static int __unregister_netdevice_notifier_net(struct net *net, + struct notifier_block *nb) +{ + int err; + + err = raw_notifier_chain_unregister(&net->netdev_chain, nb); + if (err) + return err; + + call_netdevice_unregister_net_notifiers(nb, net); + return 0; +} + +/** + * register_netdevice_notifier_net - register a per-netns network notifier block + * @net: network namespace + * @nb: notifier + * + * Register a notifier to be called when network device events occur. 
+ * The notifier passed is linked into the kernel structures and must + * not be reused until it has been unregistered. A negative errno code + * is returned on a failure. + * + * When registered all registration and up events are replayed + * to the new notifier to allow device to have a race free + * view of the network device list. + */ + +int register_netdevice_notifier_net(struct net *net, struct notifier_block *nb) +{ + int err; + + rtnl_lock(); + err = __register_netdevice_notifier_net(net, nb, false); + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(register_netdevice_notifier_net); + +/** + * unregister_netdevice_notifier_net - unregister a per-netns + * network notifier block + * @net: network namespace + * @nb: notifier + * + * Unregister a notifier previously registered by + * register_netdevice_notifier(). The notifier is unlinked into the + * kernel structures and may then be reused. A negative errno code + * is returned on a failure. + * + * After unregistering unregister and down device events are synthesized + * for all devices on the device list to the removed notifier to remove + * the need for special case cleanup code. + */ + +int unregister_netdevice_notifier_net(struct net *net, + struct notifier_block *nb) +{ + int err; + + rtnl_lock(); + err = __unregister_netdevice_notifier_net(net, nb); + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(unregister_netdevice_notifier_net); + +int register_netdevice_notifier_dev_net(struct net_device *dev, + struct notifier_block *nb, + struct netdev_net_notifier *nn) +{ + int err; + + rtnl_lock(); + err = __register_netdevice_notifier_net(dev_net(dev), nb, false); + if (!err) { + nn->nb = nb; + list_add(&nn->list, &dev->net_notifier_list); + } + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(register_netdevice_notifier_dev_net); + +int unregister_netdevice_notifier_dev_net(struct net_device *dev, + struct notifier_block *nb, + struct netdev_net_notifier *nn) +{ + int err; + + rtnl_lock(); + list_del(&nn->list); + err = __unregister_netdevice_notifier_net(dev_net(dev), nb); + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(unregister_netdevice_notifier_dev_net); + +static void move_netdevice_notifiers_dev_net(struct net_device *dev, + struct net *net) +{ + struct netdev_net_notifier *nn; + + list_for_each_entry(nn, &dev->net_notifier_list, list) { + __unregister_netdevice_notifier_net(dev_net(dev), nn->nb); + __register_netdevice_notifier_net(net, nn->nb, true); + } +} + +/** + * call_netdevice_notifiers_info - call all network notifier blocks + * @val: value passed unmodified to notifier function + * @info: notifier information data + * + * Call all network notifier blocks. Parameters and return value + * are as for raw_notifier_call_chain(). + */ + +static int call_netdevice_notifiers_info(unsigned long val, + struct netdev_notifier_info *info) +{ + struct net *net = dev_net(info->dev); + int ret; + + ASSERT_RTNL(); + + /* Run per-netns notifier block chain first, then run the global one. + * Hopefully, one day, the global one is going to be removed after + * all notifier block registrators get converted to be per-netns. 
+ */ + ret = raw_notifier_call_chain(&net->netdev_chain, val, info); + if (ret & NOTIFY_STOP_MASK) + return ret; + return raw_notifier_call_chain(&netdev_chain, val, info); +} + +/** + * call_netdevice_notifiers_info_robust - call per-netns notifier blocks + * for and rollback on error + * @val_up: value passed unmodified to notifier function + * @val_down: value passed unmodified to the notifier function when + * recovering from an error on @val_up + * @info: notifier information data + * + * Call all per-netns network notifier blocks, but not notifier blocks on + * the global notifier chain. Parameters and return value are as for + * raw_notifier_call_chain_robust(). + */ + +static int +call_netdevice_notifiers_info_robust(unsigned long val_up, + unsigned long val_down, + struct netdev_notifier_info *info) +{ + struct net *net = dev_net(info->dev); + + ASSERT_RTNL(); + + return raw_notifier_call_chain_robust(&net->netdev_chain, + val_up, val_down, info); +} + +static int call_netdevice_notifiers_extack(unsigned long val, + struct net_device *dev, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_info info = { + .dev = dev, + .extack = extack, + }; + + return call_netdevice_notifiers_info(val, &info); +} + +/** + * call_netdevice_notifiers - call all network notifier blocks + * @val: value passed unmodified to notifier function + * @dev: net_device pointer passed unmodified to notifier function + * + * Call all network notifier blocks. Parameters and return value + * are as for raw_notifier_call_chain(). + */ + +int call_netdevice_notifiers(unsigned long val, struct net_device *dev) +{ + return call_netdevice_notifiers_extack(val, dev, NULL); +} +EXPORT_SYMBOL(call_netdevice_notifiers); + +/** + * call_netdevice_notifiers_mtu - call all network notifier blocks + * @val: value passed unmodified to notifier function + * @dev: net_device pointer passed unmodified to notifier function + * @arg: additional u32 argument passed to the notifier function + * + * Call all network notifier blocks. Parameters and return value + * are as for raw_notifier_call_chain(). 
+ */ +static int call_netdevice_notifiers_mtu(unsigned long val, + struct net_device *dev, u32 arg) +{ + struct netdev_notifier_info_ext info = { + .info.dev = dev, + .ext.mtu = arg, + }; + + BUILD_BUG_ON(offsetof(struct netdev_notifier_info_ext, info) != 0); + + return call_netdevice_notifiers_info(val, &info.info); +} + +#ifdef CONFIG_NET_INGRESS +static DEFINE_STATIC_KEY_FALSE(ingress_needed_key); + +void net_inc_ingress_queue(void) +{ + static_branch_inc(&ingress_needed_key); +} +EXPORT_SYMBOL_GPL(net_inc_ingress_queue); + +void net_dec_ingress_queue(void) +{ + static_branch_dec(&ingress_needed_key); +} +EXPORT_SYMBOL_GPL(net_dec_ingress_queue); +#endif + +#ifdef CONFIG_NET_EGRESS +static DEFINE_STATIC_KEY_FALSE(egress_needed_key); + +void net_inc_egress_queue(void) +{ + static_branch_inc(&egress_needed_key); +} +EXPORT_SYMBOL_GPL(net_inc_egress_queue); + +void net_dec_egress_queue(void) +{ + static_branch_dec(&egress_needed_key); +} +EXPORT_SYMBOL_GPL(net_dec_egress_queue); +#endif + +DEFINE_STATIC_KEY_FALSE(netstamp_needed_key); +EXPORT_SYMBOL(netstamp_needed_key); +#ifdef CONFIG_JUMP_LABEL +static atomic_t netstamp_needed_deferred; +static atomic_t netstamp_wanted; +static void netstamp_clear(struct work_struct *work) +{ + int deferred = atomic_xchg(&netstamp_needed_deferred, 0); + int wanted; + + wanted = atomic_add_return(deferred, &netstamp_wanted); + if (wanted > 0) + static_branch_enable(&netstamp_needed_key); + else + static_branch_disable(&netstamp_needed_key); +} +static DECLARE_WORK(netstamp_work, netstamp_clear); +#endif + +void net_enable_timestamp(void) +{ +#ifdef CONFIG_JUMP_LABEL + int wanted; + + while (1) { + wanted = atomic_read(&netstamp_wanted); + if (wanted <= 0) + break; + if (atomic_cmpxchg(&netstamp_wanted, wanted, wanted + 1) == wanted) + return; + } + atomic_inc(&netstamp_needed_deferred); + schedule_work(&netstamp_work); +#else + static_branch_inc(&netstamp_needed_key); +#endif +} +EXPORT_SYMBOL(net_enable_timestamp); + +void net_disable_timestamp(void) +{ +#ifdef CONFIG_JUMP_LABEL + int wanted; + + while (1) { + wanted = atomic_read(&netstamp_wanted); + if (wanted <= 1) + break; + if (atomic_cmpxchg(&netstamp_wanted, wanted, wanted - 1) == wanted) + return; + } + atomic_dec(&netstamp_needed_deferred); + schedule_work(&netstamp_work); +#else + static_branch_dec(&netstamp_needed_key); +#endif +} +EXPORT_SYMBOL(net_disable_timestamp); + +static inline void net_timestamp_set(struct sk_buff *skb) +{ + skb->tstamp = 0; + skb->mono_delivery_time = 0; + if (static_branch_unlikely(&netstamp_needed_key)) + skb->tstamp = ktime_get_real(); +} + +#define net_timestamp_check(COND, SKB) \ + if (static_branch_unlikely(&netstamp_needed_key)) { \ + if ((COND) && !(SKB)->tstamp) \ + (SKB)->tstamp = ktime_get_real(); \ + } \ + +bool is_skb_forwardable(const struct net_device *dev, const struct sk_buff *skb) +{ + return __is_skb_forwardable(dev, skb, true); +} +EXPORT_SYMBOL_GPL(is_skb_forwardable); + +static int __dev_forward_skb2(struct net_device *dev, struct sk_buff *skb, + bool check_mtu) +{ + int ret = ____dev_forward_skb(dev, skb, check_mtu); + + if (likely(!ret)) { + skb->protocol = eth_type_trans(skb, dev); + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN); + } + + return ret; +} + +int __dev_forward_skb(struct net_device *dev, struct sk_buff *skb) +{ + return __dev_forward_skb2(dev, skb, true); +} +EXPORT_SYMBOL_GPL(__dev_forward_skb); + +/** + * dev_forward_skb - loopback an skb to another netif + * + * @dev: destination network device + * @skb: buffer to forward + 
* + * return values: + * NET_RX_SUCCESS (no congestion) + * NET_RX_DROP (packet was dropped, but freed) + * + * dev_forward_skb can be used for injecting an skb from the + * start_xmit function of one device into the receive queue + * of another device. + * + * The receiving device may be in another namespace, so + * we have to clear all information in the skb that could + * impact namespace isolation. + */ +int dev_forward_skb(struct net_device *dev, struct sk_buff *skb) +{ + return __dev_forward_skb(dev, skb) ?: netif_rx_internal(skb); +} +EXPORT_SYMBOL_GPL(dev_forward_skb); + +int dev_forward_skb_nomtu(struct net_device *dev, struct sk_buff *skb) +{ + return __dev_forward_skb2(dev, skb, false) ?: netif_rx_internal(skb); +} + +static inline int deliver_skb(struct sk_buff *skb, + struct packet_type *pt_prev, + struct net_device *orig_dev) +{ + if (unlikely(skb_orphan_frags_rx(skb, GFP_ATOMIC))) + return -ENOMEM; + refcount_inc(&skb->users); + return pt_prev->func(skb, skb->dev, pt_prev, orig_dev); +} + +static inline void deliver_ptype_list_skb(struct sk_buff *skb, + struct packet_type **pt, + struct net_device *orig_dev, + __be16 type, + struct list_head *ptype_list) +{ + struct packet_type *ptype, *pt_prev = *pt; + + list_for_each_entry_rcu(ptype, ptype_list, list) { + if (ptype->type != type) + continue; + if (pt_prev) + deliver_skb(skb, pt_prev, orig_dev); + pt_prev = ptype; + } + *pt = pt_prev; +} + +static inline bool skb_loop_sk(struct packet_type *ptype, struct sk_buff *skb) +{ + if (!ptype->af_packet_priv || !skb->sk) + return false; + + if (ptype->id_match) + return ptype->id_match(ptype, skb->sk); + else if ((struct sock *)ptype->af_packet_priv == skb->sk) + return true; + + return false; +} + +/** + * dev_nit_active - return true if any network interface taps are in use + * + * @dev: network device to check for the presence of taps + */ +bool dev_nit_active(struct net_device *dev) +{ + return !list_empty(&ptype_all) || !list_empty(&dev->ptype_all); +} +EXPORT_SYMBOL_GPL(dev_nit_active); + +/* + * Support routine. Sends outgoing frames to any network + * taps currently in use. + */ + +void dev_queue_xmit_nit(struct sk_buff *skb, struct net_device *dev) +{ + struct packet_type *ptype; + struct sk_buff *skb2 = NULL; + struct packet_type *pt_prev = NULL; + struct list_head *ptype_list = &ptype_all; + + rcu_read_lock(); +again: + list_for_each_entry_rcu(ptype, ptype_list, list) { + if (ptype->ignore_outgoing) + continue; + + /* Never send packets back to the socket + * they originated from - MvS (miquels@drinkel.ow.org) + */ + if (skb_loop_sk(ptype, skb)) + continue; + + if (pt_prev) { + deliver_skb(skb2, pt_prev, skb->dev); + pt_prev = ptype; + continue; + } + + /* need to clone skb, done only once */ + skb2 = skb_clone(skb, GFP_ATOMIC); + if (!skb2) + goto out_unlock; + + net_timestamp_set(skb2); + + /* skb->nh should be correctly + * set by sender, so that the second statement is + * just protection against buggy protocols. 
+ */ + skb_reset_mac_header(skb2); + + if (skb_network_header(skb2) < skb2->data || + skb_network_header(skb2) > skb_tail_pointer(skb2)) { + net_crit_ratelimited("protocol %04x is buggy, dev %s\n", + ntohs(skb2->protocol), + dev->name); + skb_reset_network_header(skb2); + } + + skb2->transport_header = skb2->network_header; + skb2->pkt_type = PACKET_OUTGOING; + pt_prev = ptype; + } + + if (ptype_list == &ptype_all) { + ptype_list = &dev->ptype_all; + goto again; + } +out_unlock: + if (pt_prev) { + if (!skb_orphan_frags_rx(skb2, GFP_ATOMIC)) + pt_prev->func(skb2, skb->dev, pt_prev, skb->dev); + else + kfree_skb(skb2); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(dev_queue_xmit_nit); + +/** + * netif_setup_tc - Handle tc mappings on real_num_tx_queues change + * @dev: Network device + * @txq: number of queues available + * + * If real_num_tx_queues is changed the tc mappings may no longer be + * valid. To resolve this verify the tc mapping remains valid and if + * not NULL the mapping. With no priorities mapping to this + * offset/count pair it will no longer be used. In the worst case TC0 + * is invalid nothing can be done so disable priority mappings. If is + * expected that drivers will fix this mapping if they can before + * calling netif_set_real_num_tx_queues. + */ +static void netif_setup_tc(struct net_device *dev, unsigned int txq) +{ + int i; + struct netdev_tc_txq *tc = &dev->tc_to_txq[0]; + + /* If TC0 is invalidated disable TC mapping */ + if (tc->offset + tc->count > txq) { + netdev_warn(dev, "Number of in use tx queues changed invalidating tc mappings. Priority traffic classification disabled!\n"); + dev->num_tc = 0; + return; + } + + /* Invalidated prio to tc mappings set to TC0 */ + for (i = 1; i < TC_BITMASK + 1; i++) { + int q = netdev_get_prio_tc_map(dev, i); + + tc = &dev->tc_to_txq[q]; + if (tc->offset + tc->count > txq) { + netdev_warn(dev, "Number of in use tx queues changed. Priority %i to tc mapping %i is no longer valid. 
Setting map to 0\n", + i, q); + netdev_set_prio_tc_map(dev, i, 0); + } + } +} + +int netdev_txq_to_tc(struct net_device *dev, unsigned int txq) +{ + if (dev->num_tc) { + struct netdev_tc_txq *tc = &dev->tc_to_txq[0]; + int i; + + /* walk through the TCs and see if it falls into any of them */ + for (i = 0; i < TC_MAX_QUEUE; i++, tc++) { + if ((txq - tc->offset) < tc->count) + return i; + } + + /* didn't find it, just return -1 to indicate no match */ + return -1; + } + + return 0; +} +EXPORT_SYMBOL(netdev_txq_to_tc); + +#ifdef CONFIG_XPS +static struct static_key xps_needed __read_mostly; +static struct static_key xps_rxqs_needed __read_mostly; +static DEFINE_MUTEX(xps_map_mutex); +#define xmap_dereference(P) \ + rcu_dereference_protected((P), lockdep_is_held(&xps_map_mutex)) + +static bool remove_xps_queue(struct xps_dev_maps *dev_maps, + struct xps_dev_maps *old_maps, int tci, u16 index) +{ + struct xps_map *map = NULL; + int pos; + + if (dev_maps) + map = xmap_dereference(dev_maps->attr_map[tci]); + if (!map) + return false; + + for (pos = map->len; pos--;) { + if (map->queues[pos] != index) + continue; + + if (map->len > 1) { + map->queues[pos] = map->queues[--map->len]; + break; + } + + if (old_maps) + RCU_INIT_POINTER(old_maps->attr_map[tci], NULL); + RCU_INIT_POINTER(dev_maps->attr_map[tci], NULL); + kfree_rcu(map, rcu); + return false; + } + + return true; +} + +static bool remove_xps_queue_cpu(struct net_device *dev, + struct xps_dev_maps *dev_maps, + int cpu, u16 offset, u16 count) +{ + int num_tc = dev_maps->num_tc; + bool active = false; + int tci; + + for (tci = cpu * num_tc; num_tc--; tci++) { + int i, j; + + for (i = count, j = offset; i--; j++) { + if (!remove_xps_queue(dev_maps, NULL, tci, j)) + break; + } + + active |= i < 0; + } + + return active; +} + +static void reset_xps_maps(struct net_device *dev, + struct xps_dev_maps *dev_maps, + enum xps_map_type type) +{ + static_key_slow_dec_cpuslocked(&xps_needed); + if (type == XPS_RXQS) + static_key_slow_dec_cpuslocked(&xps_rxqs_needed); + + RCU_INIT_POINTER(dev->xps_maps[type], NULL); + + kfree_rcu(dev_maps, rcu); +} + +static void clean_xps_maps(struct net_device *dev, enum xps_map_type type, + u16 offset, u16 count) +{ + struct xps_dev_maps *dev_maps; + bool active = false; + int i, j; + + dev_maps = xmap_dereference(dev->xps_maps[type]); + if (!dev_maps) + return; + + for (j = 0; j < dev_maps->nr_ids; j++) + active |= remove_xps_queue_cpu(dev, dev_maps, j, offset, count); + if (!active) + reset_xps_maps(dev, dev_maps, type); + + if (type == XPS_CPUS) { + for (i = offset + (count - 1); count--; i--) + netdev_queue_numa_node_write( + netdev_get_tx_queue(dev, i), NUMA_NO_NODE); + } +} + +static void netif_reset_xps_queues(struct net_device *dev, u16 offset, + u16 count) +{ + if (!static_key_false(&xps_needed)) + return; + + cpus_read_lock(); + mutex_lock(&xps_map_mutex); + + if (static_key_false(&xps_rxqs_needed)) + clean_xps_maps(dev, XPS_RXQS, offset, count); + + clean_xps_maps(dev, XPS_CPUS, offset, count); + + mutex_unlock(&xps_map_mutex); + cpus_read_unlock(); +} + +static void netif_reset_xps_queues_gt(struct net_device *dev, u16 index) +{ + netif_reset_xps_queues(dev, index, dev->num_tx_queues - index); +} + +static struct xps_map *expand_xps_map(struct xps_map *map, int attr_index, + u16 index, bool is_rxqs_map) +{ + struct xps_map *new_map; + int alloc_len = XPS_MIN_MAP_ALLOC; + int i, pos; + + for (pos = 0; map && pos < map->len; pos++) { + if (map->queues[pos] != index) + continue; + return map; + } + + /* Need 
to add tx-queue to this CPU's/rx-queue's existing map */ + if (map) { + if (pos < map->alloc_len) + return map; + + alloc_len = map->alloc_len * 2; + } + + /* Need to allocate new map to store tx-queue on this CPU's/rx-queue's + * map + */ + if (is_rxqs_map) + new_map = kzalloc(XPS_MAP_SIZE(alloc_len), GFP_KERNEL); + else + new_map = kzalloc_node(XPS_MAP_SIZE(alloc_len), GFP_KERNEL, + cpu_to_node(attr_index)); + if (!new_map) + return NULL; + + for (i = 0; i < pos; i++) + new_map->queues[i] = map->queues[i]; + new_map->alloc_len = alloc_len; + new_map->len = pos; + + return new_map; +} + +/* Copy xps maps at a given index */ +static void xps_copy_dev_maps(struct xps_dev_maps *dev_maps, + struct xps_dev_maps *new_dev_maps, int index, + int tc, bool skip_tc) +{ + int i, tci = index * dev_maps->num_tc; + struct xps_map *map; + + /* copy maps belonging to foreign traffic classes */ + for (i = 0; i < dev_maps->num_tc; i++, tci++) { + if (i == tc && skip_tc) + continue; + + /* fill in the new device map from the old device map */ + map = xmap_dereference(dev_maps->attr_map[tci]); + RCU_INIT_POINTER(new_dev_maps->attr_map[tci], map); + } +} + +/* Must be called under cpus_read_lock */ +int __netif_set_xps_queue(struct net_device *dev, const unsigned long *mask, + u16 index, enum xps_map_type type) +{ + struct xps_dev_maps *dev_maps, *new_dev_maps = NULL, *old_dev_maps = NULL; + const unsigned long *online_mask = NULL; + bool active = false, copy = false; + int i, j, tci, numa_node_id = -2; + int maps_sz, num_tc = 1, tc = 0; + struct xps_map *map, *new_map; + unsigned int nr_ids; + + WARN_ON_ONCE(index >= dev->num_tx_queues); + + if (dev->num_tc) { + /* Do not allow XPS on subordinate device directly */ + num_tc = dev->num_tc; + if (num_tc < 0) + return -EINVAL; + + /* If queue belongs to subordinate dev use its map */ + dev = netdev_get_tx_queue(dev, index)->sb_dev ? : dev; + + tc = netdev_txq_to_tc(dev, index); + if (tc < 0) + return -EINVAL; + } + + mutex_lock(&xps_map_mutex); + + dev_maps = xmap_dereference(dev->xps_maps[type]); + if (type == XPS_RXQS) { + maps_sz = XPS_RXQ_DEV_MAPS_SIZE(num_tc, dev->num_rx_queues); + nr_ids = dev->num_rx_queues; + } else { + maps_sz = XPS_CPU_DEV_MAPS_SIZE(num_tc); + if (num_possible_cpus() > 1) + online_mask = cpumask_bits(cpu_online_mask); + nr_ids = nr_cpu_ids; + } + + if (maps_sz < L1_CACHE_BYTES) + maps_sz = L1_CACHE_BYTES; + + /* The old dev_maps could be larger or smaller than the one we're + * setting up now, as dev->num_tc or nr_ids could have been updated in + * between. We could try to be smart, but let's be safe instead and only + * copy foreign traffic classes if the two map sizes match. + */ + if (dev_maps && + dev_maps->num_tc == num_tc && dev_maps->nr_ids == nr_ids) + copy = true; + + /* allocate memory for queue storage */ + for (j = -1; j = netif_attrmask_next_and(j, online_mask, mask, nr_ids), + j < nr_ids;) { + if (!new_dev_maps) { + new_dev_maps = kzalloc(maps_sz, GFP_KERNEL); + if (!new_dev_maps) { + mutex_unlock(&xps_map_mutex); + return -ENOMEM; + } + + new_dev_maps->nr_ids = nr_ids; + new_dev_maps->num_tc = num_tc; + } + + tci = j * num_tc + tc; + map = copy ? 
xmap_dereference(dev_maps->attr_map[tci]) : NULL; + + map = expand_xps_map(map, j, index, type == XPS_RXQS); + if (!map) + goto error; + + RCU_INIT_POINTER(new_dev_maps->attr_map[tci], map); + } + + if (!new_dev_maps) + goto out_no_new_maps; + + if (!dev_maps) { + /* Increment static keys at most once per type */ + static_key_slow_inc_cpuslocked(&xps_needed); + if (type == XPS_RXQS) + static_key_slow_inc_cpuslocked(&xps_rxqs_needed); + } + + for (j = 0; j < nr_ids; j++) { + bool skip_tc = false; + + tci = j * num_tc + tc; + if (netif_attr_test_mask(j, mask, nr_ids) && + netif_attr_test_online(j, online_mask, nr_ids)) { + /* add tx-queue to CPU/rx-queue maps */ + int pos = 0; + + skip_tc = true; + + map = xmap_dereference(new_dev_maps->attr_map[tci]); + while ((pos < map->len) && (map->queues[pos] != index)) + pos++; + + if (pos == map->len) + map->queues[map->len++] = index; +#ifdef CONFIG_NUMA + if (type == XPS_CPUS) { + if (numa_node_id == -2) + numa_node_id = cpu_to_node(j); + else if (numa_node_id != cpu_to_node(j)) + numa_node_id = -1; + } +#endif + } + + if (copy) + xps_copy_dev_maps(dev_maps, new_dev_maps, j, tc, + skip_tc); + } + + rcu_assign_pointer(dev->xps_maps[type], new_dev_maps); + + /* Cleanup old maps */ + if (!dev_maps) + goto out_no_old_maps; + + for (j = 0; j < dev_maps->nr_ids; j++) { + for (i = num_tc, tci = j * dev_maps->num_tc; i--; tci++) { + map = xmap_dereference(dev_maps->attr_map[tci]); + if (!map) + continue; + + if (copy) { + new_map = xmap_dereference(new_dev_maps->attr_map[tci]); + if (map == new_map) + continue; + } + + RCU_INIT_POINTER(dev_maps->attr_map[tci], NULL); + kfree_rcu(map, rcu); + } + } + + old_dev_maps = dev_maps; + +out_no_old_maps: + dev_maps = new_dev_maps; + active = true; + +out_no_new_maps: + if (type == XPS_CPUS) + /* update Tx queue numa node */ + netdev_queue_numa_node_write(netdev_get_tx_queue(dev, index), + (numa_node_id >= 0) ? + numa_node_id : NUMA_NO_NODE); + + if (!dev_maps) + goto out_no_maps; + + /* removes tx-queue from unused CPUs/rx-queues */ + for (j = 0; j < dev_maps->nr_ids; j++) { + tci = j * dev_maps->num_tc; + + for (i = 0; i < dev_maps->num_tc; i++, tci++) { + if (i == tc && + netif_attr_test_mask(j, mask, dev_maps->nr_ids) && + netif_attr_test_online(j, online_mask, dev_maps->nr_ids)) + continue; + + active |= remove_xps_queue(dev_maps, + copy ? old_dev_maps : NULL, + tci, index); + } + } + + if (old_dev_maps) + kfree_rcu(old_dev_maps, rcu); + + /* free map if not active */ + if (!active) + reset_xps_maps(dev, dev_maps, type); + +out_no_maps: + mutex_unlock(&xps_map_mutex); + + return 0; +error: + /* remove any maps that we added */ + for (j = 0; j < nr_ids; j++) { + for (i = num_tc, tci = j * num_tc; i--; tci++) { + new_map = xmap_dereference(new_dev_maps->attr_map[tci]); + map = copy ? 
+ xmap_dereference(dev_maps->attr_map[tci]) : + NULL; + if (new_map && new_map != map) + kfree(new_map); + } + } + + mutex_unlock(&xps_map_mutex); + + kfree(new_dev_maps); + return -ENOMEM; +} +EXPORT_SYMBOL_GPL(__netif_set_xps_queue); + +int netif_set_xps_queue(struct net_device *dev, const struct cpumask *mask, + u16 index) +{ + int ret; + + cpus_read_lock(); + ret = __netif_set_xps_queue(dev, cpumask_bits(mask), index, XPS_CPUS); + cpus_read_unlock(); + + return ret; +} +EXPORT_SYMBOL(netif_set_xps_queue); + +#endif +static void netdev_unbind_all_sb_channels(struct net_device *dev) +{ + struct netdev_queue *txq = &dev->_tx[dev->num_tx_queues]; + + /* Unbind any subordinate channels */ + while (txq-- != &dev->_tx[0]) { + if (txq->sb_dev) + netdev_unbind_sb_channel(dev, txq->sb_dev); + } +} + +void netdev_reset_tc(struct net_device *dev) +{ +#ifdef CONFIG_XPS + netif_reset_xps_queues_gt(dev, 0); +#endif + netdev_unbind_all_sb_channels(dev); + + /* Reset TC configuration of device */ + dev->num_tc = 0; + memset(dev->tc_to_txq, 0, sizeof(dev->tc_to_txq)); + memset(dev->prio_tc_map, 0, sizeof(dev->prio_tc_map)); +} +EXPORT_SYMBOL(netdev_reset_tc); + +int netdev_set_tc_queue(struct net_device *dev, u8 tc, u16 count, u16 offset) +{ + if (tc >= dev->num_tc) + return -EINVAL; + +#ifdef CONFIG_XPS + netif_reset_xps_queues(dev, offset, count); +#endif + dev->tc_to_txq[tc].count = count; + dev->tc_to_txq[tc].offset = offset; + return 0; +} +EXPORT_SYMBOL(netdev_set_tc_queue); + +int netdev_set_num_tc(struct net_device *dev, u8 num_tc) +{ + if (num_tc > TC_MAX_QUEUE) + return -EINVAL; + +#ifdef CONFIG_XPS + netif_reset_xps_queues_gt(dev, 0); +#endif + netdev_unbind_all_sb_channels(dev); + + dev->num_tc = num_tc; + return 0; +} +EXPORT_SYMBOL(netdev_set_num_tc); + +void netdev_unbind_sb_channel(struct net_device *dev, + struct net_device *sb_dev) +{ + struct netdev_queue *txq = &dev->_tx[dev->num_tx_queues]; + +#ifdef CONFIG_XPS + netif_reset_xps_queues_gt(sb_dev, 0); +#endif + memset(sb_dev->tc_to_txq, 0, sizeof(sb_dev->tc_to_txq)); + memset(sb_dev->prio_tc_map, 0, sizeof(sb_dev->prio_tc_map)); + + while (txq-- != &dev->_tx[0]) { + if (txq->sb_dev == sb_dev) + txq->sb_dev = NULL; + } +} +EXPORT_SYMBOL(netdev_unbind_sb_channel); + +int netdev_bind_sb_channel_queue(struct net_device *dev, + struct net_device *sb_dev, + u8 tc, u16 count, u16 offset) +{ + /* Make certain the sb_dev and dev are already configured */ + if (sb_dev->num_tc >= 0 || tc >= dev->num_tc) + return -EINVAL; + + /* We cannot hand out queues we don't have */ + if ((offset + count) > dev->real_num_tx_queues) + return -EINVAL; + + /* Record the mapping */ + sb_dev->tc_to_txq[tc].count = count; + sb_dev->tc_to_txq[tc].offset = offset; + + /* Provide a way for Tx queue to find the tc_to_txq map or + * XPS map for itself. + */ + while (count--) + netdev_get_tx_queue(dev, count + offset)->sb_dev = sb_dev; + + return 0; +} +EXPORT_SYMBOL(netdev_bind_sb_channel_queue); + +int netdev_set_sb_channel(struct net_device *dev, u16 channel) +{ + /* Do not use a multiqueue device to represent a subordinate channel */ + if (netif_is_multiqueue(dev)) + return -ENODEV; + + /* We allow channels 1 - 32767 to be used for subordinate channels. + * Channel 0 is meant to be "native" mode and used only to represent + * the main root device. We allow writing 0 to reset the device back + * to normal mode after being used as a subordinate channel. 
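+ *
+ * Rough usage illustration (hypothetical sb_netdev pointer, not taken
+ * from this file): a driver exposing a single-queue subordinate device
+ * would typically call
+ *
+ *   netdev_set_sb_channel(sb_netdev, 1);
+ *
+ * and later reset it with channel 0. The channel is stored as a negative
+ * dev->num_tc, which is how netdev_bind_sb_channel_queue() above tells a
+ * subordinate device apart (it rejects sb_dev->num_tc >= 0).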
+ */ + if (channel > S16_MAX) + return -EINVAL; + + dev->num_tc = -channel; + + return 0; +} +EXPORT_SYMBOL(netdev_set_sb_channel); + +/* + * Routine to help set real_num_tx_queues. To avoid skbs mapped to queues + * greater than real_num_tx_queues stale skbs on the qdisc must be flushed. + */ +int netif_set_real_num_tx_queues(struct net_device *dev, unsigned int txq) +{ + bool disabling; + int rc; + + disabling = txq < dev->real_num_tx_queues; + + if (txq < 1 || txq > dev->num_tx_queues) + return -EINVAL; + + if (dev->reg_state == NETREG_REGISTERED || + dev->reg_state == NETREG_UNREGISTERING) { + ASSERT_RTNL(); + + rc = netdev_queue_update_kobjects(dev, dev->real_num_tx_queues, + txq); + if (rc) + return rc; + + if (dev->num_tc) + netif_setup_tc(dev, txq); + + dev_qdisc_change_real_num_tx(dev, txq); + + dev->real_num_tx_queues = txq; + + if (disabling) { + synchronize_net(); + qdisc_reset_all_tx_gt(dev, txq); +#ifdef CONFIG_XPS + netif_reset_xps_queues_gt(dev, txq); +#endif + } + } else { + dev->real_num_tx_queues = txq; + } + + return 0; +} +EXPORT_SYMBOL(netif_set_real_num_tx_queues); + +#ifdef CONFIG_SYSFS +/** + * netif_set_real_num_rx_queues - set actual number of RX queues used + * @dev: Network device + * @rxq: Actual number of RX queues + * + * This must be called either with the rtnl_lock held or before + * registration of the net device. Returns 0 on success, or a + * negative error code. If called before registration, it always + * succeeds. + */ +int netif_set_real_num_rx_queues(struct net_device *dev, unsigned int rxq) +{ + int rc; + + if (rxq < 1 || rxq > dev->num_rx_queues) + return -EINVAL; + + if (dev->reg_state == NETREG_REGISTERED) { + ASSERT_RTNL(); + + rc = net_rx_queue_update_kobjects(dev, dev->real_num_rx_queues, + rxq); + if (rc) + return rc; + } + + dev->real_num_rx_queues = rxq; + return 0; +} +EXPORT_SYMBOL(netif_set_real_num_rx_queues); +#endif + +/** + * netif_set_real_num_queues - set actual number of RX and TX queues used + * @dev: Network device + * @txq: Actual number of TX queues + * @rxq: Actual number of RX queues + * + * Set the real number of both TX and RX queues. + * Does nothing if the number of queues is already correct. + */ +int netif_set_real_num_queues(struct net_device *dev, + unsigned int txq, unsigned int rxq) +{ + unsigned int old_rxq = dev->real_num_rx_queues; + int err; + + if (txq < 1 || txq > dev->num_tx_queues || + rxq < 1 || rxq > dev->num_rx_queues) + return -EINVAL; + + /* Start from increases, so the error path only does decreases - + * decreases can't fail. + */ + if (rxq > dev->real_num_rx_queues) { + err = netif_set_real_num_rx_queues(dev, rxq); + if (err) + return err; + } + if (txq > dev->real_num_tx_queues) { + err = netif_set_real_num_tx_queues(dev, txq); + if (err) + goto undo_rx; + } + if (rxq < dev->real_num_rx_queues) + WARN_ON(netif_set_real_num_rx_queues(dev, rxq)); + if (txq < dev->real_num_tx_queues) + WARN_ON(netif_set_real_num_tx_queues(dev, txq)); + + return 0; +undo_rx: + WARN_ON(netif_set_real_num_rx_queues(dev, old_rxq)); + return err; +} +EXPORT_SYMBOL(netif_set_real_num_queues); + +/** + * netif_set_tso_max_size() - set the max size of TSO frames supported + * @dev: netdev to update + * @size: max skb->len of a TSO frame + * + * Set the limit on the size of TSO super-frames the device can handle. + * Unless explicitly set the stack will assume the value of + * %GSO_LEGACY_MAX_SIZE. 
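+ *
+ * Rough usage sketch (hypothetical driver and value, not from this file):
+ * a driver whose DMA engine cannot express frames longer than 64KB might
+ * call, from its probe path,
+ *
+ *   netif_set_tso_max_size(netdev, 65535);
+ *
+ * The helper also clamps dev->gso_max_size when it currently exceeds the
+ * new limit, as the body below shows; netif_set_tso_max_segs() applies
+ * the same idea to the segment count.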
+ */ +void netif_set_tso_max_size(struct net_device *dev, unsigned int size) +{ + dev->tso_max_size = min(GSO_MAX_SIZE, size); + if (size < READ_ONCE(dev->gso_max_size)) + netif_set_gso_max_size(dev, size); +} +EXPORT_SYMBOL(netif_set_tso_max_size); + +/** + * netif_set_tso_max_segs() - set the max number of segs supported for TSO + * @dev: netdev to update + * @segs: max number of TCP segments + * + * Set the limit on the number of TCP segments the device can generate from + * a single TSO super-frame. + * Unless explicitly set the stack will assume the value of %GSO_MAX_SEGS. + */ +void netif_set_tso_max_segs(struct net_device *dev, unsigned int segs) +{ + dev->tso_max_segs = segs; + if (segs < READ_ONCE(dev->gso_max_segs)) + netif_set_gso_max_segs(dev, segs); +} +EXPORT_SYMBOL(netif_set_tso_max_segs); + +/** + * netif_inherit_tso_max() - copy all TSO limits from a lower device to an upper + * @to: netdev to update + * @from: netdev from which to copy the limits + */ +void netif_inherit_tso_max(struct net_device *to, const struct net_device *from) +{ + netif_set_tso_max_size(to, from->tso_max_size); + netif_set_tso_max_segs(to, from->tso_max_segs); +} +EXPORT_SYMBOL(netif_inherit_tso_max); + +/** + * netif_get_num_default_rss_queues - default number of RSS queues + * + * Default value is the number of physical cores if there are only 1 or 2, or + * divided by 2 if there are more. + */ +int netif_get_num_default_rss_queues(void) +{ + cpumask_var_t cpus; + int cpu, count = 0; + + if (unlikely(is_kdump_kernel() || !zalloc_cpumask_var(&cpus, GFP_KERNEL))) + return 1; + + cpumask_copy(cpus, cpu_online_mask); + for_each_cpu(cpu, cpus) { + ++count; + cpumask_andnot(cpus, cpus, topology_sibling_cpumask(cpu)); + } + free_cpumask_var(cpus); + + return count > 2 ? 
DIV_ROUND_UP(count, 2) : count; +} +EXPORT_SYMBOL(netif_get_num_default_rss_queues); + +static void __netif_reschedule(struct Qdisc *q) +{ + struct softnet_data *sd; + unsigned long flags; + + local_irq_save(flags); + sd = this_cpu_ptr(&softnet_data); + q->next_sched = NULL; + *sd->output_queue_tailp = q; + sd->output_queue_tailp = &q->next_sched; + raise_softirq_irqoff(NET_TX_SOFTIRQ); + local_irq_restore(flags); +} + +void __netif_schedule(struct Qdisc *q) +{ + if (!test_and_set_bit(__QDISC_STATE_SCHED, &q->state)) + __netif_reschedule(q); +} +EXPORT_SYMBOL(__netif_schedule); + +struct dev_kfree_skb_cb { + enum skb_free_reason reason; +}; + +static struct dev_kfree_skb_cb *get_kfree_skb_cb(const struct sk_buff *skb) +{ + return (struct dev_kfree_skb_cb *)skb->cb; +} + +void netif_schedule_queue(struct netdev_queue *txq) +{ + rcu_read_lock(); + if (!netif_xmit_stopped(txq)) { + struct Qdisc *q = rcu_dereference(txq->qdisc); + + __netif_schedule(q); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(netif_schedule_queue); + +void netif_tx_wake_queue(struct netdev_queue *dev_queue) +{ + if (test_and_clear_bit(__QUEUE_STATE_DRV_XOFF, &dev_queue->state)) { + struct Qdisc *q; + + rcu_read_lock(); + q = rcu_dereference(dev_queue->qdisc); + __netif_schedule(q); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL(netif_tx_wake_queue); + +void __dev_kfree_skb_irq(struct sk_buff *skb, enum skb_free_reason reason) +{ + unsigned long flags; + + if (unlikely(!skb)) + return; + + if (likely(refcount_read(&skb->users) == 1)) { + smp_rmb(); + refcount_set(&skb->users, 0); + } else if (likely(!refcount_dec_and_test(&skb->users))) { + return; + } + get_kfree_skb_cb(skb)->reason = reason; + local_irq_save(flags); + skb->next = __this_cpu_read(softnet_data.completion_queue); + __this_cpu_write(softnet_data.completion_queue, skb); + raise_softirq_irqoff(NET_TX_SOFTIRQ); + local_irq_restore(flags); +} +EXPORT_SYMBOL(__dev_kfree_skb_irq); + +void __dev_kfree_skb_any(struct sk_buff *skb, enum skb_free_reason reason) +{ + if (in_hardirq() || irqs_disabled()) + __dev_kfree_skb_irq(skb, reason); + else if (unlikely(reason == SKB_REASON_DROPPED)) + kfree_skb(skb); + else + consume_skb(skb); +} +EXPORT_SYMBOL(__dev_kfree_skb_any); + + +/** + * netif_device_detach - mark device as removed + * @dev: network device + * + * Mark device as removed from system and therefore no longer available. + */ +void netif_device_detach(struct net_device *dev) +{ + if (test_and_clear_bit(__LINK_STATE_PRESENT, &dev->state) && + netif_running(dev)) { + netif_tx_stop_all_queues(dev); + } +} +EXPORT_SYMBOL(netif_device_detach); + +/** + * netif_device_attach - mark device as attached + * @dev: network device + * + * Mark device as attached from system and restart if needed. + */ +void netif_device_attach(struct net_device *dev) +{ + if (!test_and_set_bit(__LINK_STATE_PRESENT, &dev->state) && + netif_running(dev)) { + netif_tx_wake_all_queues(dev); + __netdev_watchdog_up(dev); + } +} +EXPORT_SYMBOL(netif_device_attach); + +/* + * Returns a Tx hash based on the given packet descriptor a Tx queues' number + * to be used as a distribution range. 
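+ *
+ * Worked example (illustrative numbers only): for a traffic class whose
+ * tc_to_txq entry is { .offset = 8, .count = 4 }, the flow hash is scaled
+ * into [0, 4) with reciprocal_scale() and shifted by the offset, so the
+ * returned queue index is always one of 8..11. A recorded rx queue, when
+ * present, is folded into the same [offset, offset + count) window
+ * instead of being hashed.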
+ */ +static u16 skb_tx_hash(const struct net_device *dev, + const struct net_device *sb_dev, + struct sk_buff *skb) +{ + u32 hash; + u16 qoffset = 0; + u16 qcount = dev->real_num_tx_queues; + + if (dev->num_tc) { + u8 tc = netdev_get_prio_tc_map(dev, skb->priority); + + qoffset = sb_dev->tc_to_txq[tc].offset; + qcount = sb_dev->tc_to_txq[tc].count; + if (unlikely(!qcount)) { + net_warn_ratelimited("%s: invalid qcount, qoffset %u for tc %u\n", + sb_dev->name, qoffset, tc); + qoffset = 0; + qcount = dev->real_num_tx_queues; + } + } + + if (skb_rx_queue_recorded(skb)) { + DEBUG_NET_WARN_ON_ONCE(qcount == 0); + hash = skb_get_rx_queue(skb); + if (hash >= qoffset) + hash -= qoffset; + while (unlikely(hash >= qcount)) + hash -= qcount; + return hash + qoffset; + } + + return (u16) reciprocal_scale(skb_get_hash(skb), qcount) + qoffset; +} + +static void skb_warn_bad_offload(const struct sk_buff *skb) +{ + static const netdev_features_t null_features; + struct net_device *dev = skb->dev; + const char *name = ""; + + if (!net_ratelimit()) + return; + + if (dev) { + if (dev->dev.parent) + name = dev_driver_string(dev->dev.parent); + else + name = netdev_name(dev); + } + skb_dump(KERN_WARNING, skb, false); + WARN(1, "%s: caps=(%pNF, %pNF)\n", + name, dev ? &dev->features : &null_features, + skb->sk ? &skb->sk->sk_route_caps : &null_features); +} + +/* + * Invalidate hardware checksum when packet is to be mangled, and + * complete checksum manually on outgoing path. + */ +int skb_checksum_help(struct sk_buff *skb) +{ + __wsum csum; + int ret = 0, offset; + + if (skb->ip_summed == CHECKSUM_COMPLETE) + goto out_set_summed; + + if (unlikely(skb_is_gso(skb))) { + skb_warn_bad_offload(skb); + return -EINVAL; + } + + /* Before computing a checksum, we should make sure no frag could + * be modified by an external entity : checksum could be wrong. + */ + if (skb_has_shared_frag(skb)) { + ret = __skb_linearize(skb); + if (ret) + goto out; + } + + offset = skb_checksum_start_offset(skb); + ret = -EINVAL; + if (unlikely(offset >= skb_headlen(skb))) { + DO_ONCE_LITE(skb_dump, KERN_ERR, skb, false); + WARN_ONCE(true, "offset (%d) >= skb_headlen() (%u)\n", + offset, skb_headlen(skb)); + goto out; + } + csum = skb_checksum(skb, offset, skb->len - offset, 0); + + offset += skb->csum_offset; + if (unlikely(offset + sizeof(__sum16) > skb_headlen(skb))) { + DO_ONCE_LITE(skb_dump, KERN_ERR, skb, false); + WARN_ONCE(true, "offset+2 (%zu) > skb_headlen() (%u)\n", + offset + sizeof(__sum16), skb_headlen(skb)); + goto out; + } + ret = skb_ensure_writable(skb, offset + sizeof(__sum16)); + if (ret) + goto out; + + *(__sum16 *)(skb->data + offset) = csum_fold(csum) ?: CSUM_MANGLED_0; +out_set_summed: + skb->ip_summed = CHECKSUM_NONE; +out: + return ret; +} +EXPORT_SYMBOL(skb_checksum_help); + +int skb_crc32c_csum_help(struct sk_buff *skb) +{ + __le32 crc32c_csum; + int ret = 0, offset, start; + + if (skb->ip_summed != CHECKSUM_PARTIAL) + goto out; + + if (unlikely(skb_is_gso(skb))) + goto out; + + /* Before computing a checksum, we should make sure no frag could + * be modified by an external entity : checksum could be wrong. 
+ */ + if (unlikely(skb_has_shared_frag(skb))) { + ret = __skb_linearize(skb); + if (ret) + goto out; + } + start = skb_checksum_start_offset(skb); + offset = start + offsetof(struct sctphdr, checksum); + if (WARN_ON_ONCE(offset >= skb_headlen(skb))) { + ret = -EINVAL; + goto out; + } + + ret = skb_ensure_writable(skb, offset + sizeof(__le32)); + if (ret) + goto out; + + crc32c_csum = cpu_to_le32(~__skb_checksum(skb, start, + skb->len - start, ~(__u32)0, + crc32c_csum_stub)); + *(__le32 *)(skb->data + offset) = crc32c_csum; + skb->ip_summed = CHECKSUM_NONE; + skb->csum_not_inet = 0; +out: + return ret; +} + +__be16 skb_network_protocol(struct sk_buff *skb, int *depth) +{ + __be16 type = skb->protocol; + + /* Tunnel gso handlers can set protocol to ethernet. */ + if (type == htons(ETH_P_TEB)) { + struct ethhdr *eth; + + if (unlikely(!pskb_may_pull(skb, sizeof(struct ethhdr)))) + return 0; + + eth = (struct ethhdr *)skb->data; + type = eth->h_proto; + } + + return vlan_get_protocol_and_depth(skb, type, depth); +} + +/* openvswitch calls this on rx path, so we need a different check. + */ +static inline bool skb_needs_check(struct sk_buff *skb, bool tx_path) +{ + if (tx_path) + return skb->ip_summed != CHECKSUM_PARTIAL && + skb->ip_summed != CHECKSUM_UNNECESSARY; + + return skb->ip_summed == CHECKSUM_NONE; +} + +/** + * __skb_gso_segment - Perform segmentation on skb. + * @skb: buffer to segment + * @features: features for the output path (see dev->features) + * @tx_path: whether it is called in TX path + * + * This function segments the given skb and returns a list of segments. + * + * It may return NULL if the skb requires no segmentation. This is + * only possible when GSO is used for verifying header integrity. + * + * Segmentation preserves SKB_GSO_CB_OFFSET bytes of previous skb cb. + */ +struct sk_buff *__skb_gso_segment(struct sk_buff *skb, + netdev_features_t features, bool tx_path) +{ + struct sk_buff *segs; + + if (unlikely(skb_needs_check(skb, tx_path))) { + int err; + + /* We're going to init ->check field in TCP or UDP header */ + err = skb_cow_head(skb, 0); + if (err < 0) + return ERR_PTR(err); + } + + /* Only report GSO partial support if it will enable us to + * support segmentation on this frame without needing additional + * work. + */ + if (features & NETIF_F_GSO_PARTIAL) { + netdev_features_t partial_features = NETIF_F_GSO_ROBUST; + struct net_device *dev = skb->dev; + + partial_features |= dev->features & dev->gso_partial_features; + if (!skb_gso_ok(skb, features | partial_features)) + features &= ~NETIF_F_GSO_PARTIAL; + } + + BUILD_BUG_ON(SKB_GSO_CB_OFFSET + + sizeof(*SKB_GSO_CB(skb)) > sizeof(skb->cb)); + + SKB_GSO_CB(skb)->mac_offset = skb_headroom(skb); + SKB_GSO_CB(skb)->encap_level = 0; + + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + + segs = skb_mac_gso_segment(skb, features); + + if (segs != skb && unlikely(skb_needs_check(skb, tx_path) && !IS_ERR(segs))) + skb_warn_bad_offload(skb); + + return segs; +} +EXPORT_SYMBOL(__skb_gso_segment); + +/* Take action when hardware reception checksum errors are detected. */ +#ifdef CONFIG_BUG +static void do_netdev_rx_csum_fault(struct net_device *dev, struct sk_buff *skb) +{ + netdev_err(dev, "hw csum failure\n"); + skb_dump(KERN_ERR, skb, true); + dump_stack(); +} + +void netdev_rx_csum_fault(struct net_device *dev, struct sk_buff *skb) +{ + DO_ONCE_LITE(do_netdev_rx_csum_fault, dev, skb); +} +EXPORT_SYMBOL(netdev_rx_csum_fault); +#endif + +/* XXX: check that highmem exists at all on the given machine. 
*/ +static int illegal_highdma(struct net_device *dev, struct sk_buff *skb) +{ +#ifdef CONFIG_HIGHMEM + int i; + + if (!(dev->features & NETIF_F_HIGHDMA)) { + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + if (PageHighMem(skb_frag_page(frag))) + return 1; + } + } +#endif + return 0; +} + +/* If MPLS offload request, verify we are testing hardware MPLS features + * instead of standard features for the netdev. + */ +#if IS_ENABLED(CONFIG_NET_MPLS_GSO) +static netdev_features_t net_mpls_features(struct sk_buff *skb, + netdev_features_t features, + __be16 type) +{ + if (eth_p_mpls(type)) + features &= skb->dev->mpls_features; + + return features; +} +#else +static netdev_features_t net_mpls_features(struct sk_buff *skb, + netdev_features_t features, + __be16 type) +{ + return features; +} +#endif + +static netdev_features_t harmonize_features(struct sk_buff *skb, + netdev_features_t features) +{ + __be16 type; + + type = skb_network_protocol(skb, NULL); + features = net_mpls_features(skb, features, type); + + if (skb->ip_summed != CHECKSUM_NONE && + !can_checksum_protocol(features, type)) { + features &= ~(NETIF_F_CSUM_MASK | NETIF_F_GSO_MASK); + } + if (illegal_highdma(skb->dev, skb)) + features &= ~NETIF_F_SG; + + return features; +} + +netdev_features_t passthru_features_check(struct sk_buff *skb, + struct net_device *dev, + netdev_features_t features) +{ + return features; +} +EXPORT_SYMBOL(passthru_features_check); + +static netdev_features_t dflt_features_check(struct sk_buff *skb, + struct net_device *dev, + netdev_features_t features) +{ + return vlan_features_check(skb, features); +} + +static netdev_features_t gso_features_check(const struct sk_buff *skb, + struct net_device *dev, + netdev_features_t features) +{ + u16 gso_segs = skb_shinfo(skb)->gso_segs; + + if (gso_segs > READ_ONCE(dev->gso_max_segs)) + return features & ~NETIF_F_GSO_MASK; + + if (unlikely(skb->len >= READ_ONCE(dev->gso_max_size))) + return features & ~NETIF_F_GSO_MASK; + + if (!skb_shinfo(skb)->gso_type) { + skb_warn_bad_offload(skb); + return features & ~NETIF_F_GSO_MASK; + } + + /* Support for GSO partial features requires software + * intervention before we can actually process the packets + * so we need to strip support for any partial features now + * and we can pull them back in after we have partially + * segmented the frame. + */ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_PARTIAL)) + features &= ~dev->gso_partial_features; + + /* Make sure to clear the IPv4 ID mangling feature if the + * IPv4 header has the potential to be fragmented. + */ + if (skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4) { + struct iphdr *iph = skb->encapsulation ? 
+ inner_ip_hdr(skb) : ip_hdr(skb); + + if (!(iph->frag_off & htons(IP_DF))) + features &= ~NETIF_F_TSO_MANGLEID; + } + + return features; +} + +netdev_features_t netif_skb_features(struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + netdev_features_t features = dev->features; + + if (skb_is_gso(skb)) + features = gso_features_check(skb, dev, features); + + /* If encapsulation offload request, verify we are testing + * hardware encapsulation features instead of standard + * features for the netdev + */ + if (skb->encapsulation) + features &= dev->hw_enc_features; + + if (skb_vlan_tagged(skb)) + features = netdev_intersect_features(features, + dev->vlan_features | + NETIF_F_HW_VLAN_CTAG_TX | + NETIF_F_HW_VLAN_STAG_TX); + + if (dev->netdev_ops->ndo_features_check) + features &= dev->netdev_ops->ndo_features_check(skb, dev, + features); + else + features &= dflt_features_check(skb, dev, features); + + return harmonize_features(skb, features); +} +EXPORT_SYMBOL(netif_skb_features); + +static int xmit_one(struct sk_buff *skb, struct net_device *dev, + struct netdev_queue *txq, bool more) +{ + unsigned int len; + int rc; + + if (dev_nit_active(dev)) + dev_queue_xmit_nit(skb, dev); + + len = skb->len; + trace_net_dev_start_xmit(skb, dev); + rc = netdev_start_xmit(skb, dev, txq, more); + trace_net_dev_xmit(skb, rc, dev, len); + + return rc; +} + +struct sk_buff *dev_hard_start_xmit(struct sk_buff *first, struct net_device *dev, + struct netdev_queue *txq, int *ret) +{ + struct sk_buff *skb = first; + int rc = NETDEV_TX_OK; + + while (skb) { + struct sk_buff *next = skb->next; + + skb_mark_not_on_list(skb); + rc = xmit_one(skb, dev, txq, next != NULL); + if (unlikely(!dev_xmit_complete(rc))) { + skb->next = next; + goto out; + } + + skb = next; + if (netif_tx_queue_stopped(txq) && skb) { + rc = NETDEV_TX_BUSY; + break; + } + } + +out: + *ret = rc; + return skb; +} + +static struct sk_buff *validate_xmit_vlan(struct sk_buff *skb, + netdev_features_t features) +{ + if (skb_vlan_tag_present(skb) && + !vlan_hw_offload_capable(features, skb->vlan_proto)) + skb = __vlan_hwaccel_push_inside(skb); + return skb; +} + +int skb_csum_hwoffload_help(struct sk_buff *skb, + const netdev_features_t features) +{ + if (unlikely(skb_csum_is_sctp(skb))) + return !!(features & NETIF_F_SCTP_CRC) ? 0 : + skb_crc32c_csum_help(skb); + + if (features & NETIF_F_HW_CSUM) + return 0; + + if (features & (NETIF_F_IP_CSUM | NETIF_F_IPV6_CSUM)) { + switch (skb->csum_offset) { + case offsetof(struct tcphdr, check): + case offsetof(struct udphdr, check): + return 0; + } + } + + return skb_checksum_help(skb); +} +EXPORT_SYMBOL(skb_csum_hwoffload_help); + +static struct sk_buff *validate_xmit_skb(struct sk_buff *skb, struct net_device *dev, bool *again) +{ + netdev_features_t features; + + features = netif_skb_features(skb); + skb = validate_xmit_vlan(skb, features); + if (unlikely(!skb)) + goto out_null; + + skb = sk_validate_xmit_skb(skb, dev); + if (unlikely(!skb)) + goto out_null; + + if (netif_needs_gso(skb, features)) { + struct sk_buff *segs; + + segs = skb_gso_segment(skb, features); + if (IS_ERR(segs)) { + goto out_kfree_skb; + } else if (segs) { + consume_skb(skb); + skb = segs; + } + } else { + if (skb_needs_linearize(skb, features) && + __skb_linearize(skb)) + goto out_kfree_skb; + + /* If packet is not checksummed and device does not + * support checksumming for this protocol, complete + * checksumming here. 
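+ *
+ * CHECKSUM_PARTIAL means the stack has filled in the pseudo-header sum
+ * and left csum_start/csum_offset pointing at the field the device is
+ * expected to complete; skb_csum_hwoffload_help() below returns 0 when
+ * the device features cover this protocol and otherwise falls back to
+ * skb_checksum_help() (or skb_crc32c_csum_help() for SCTP).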
+ */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + if (skb->encapsulation) + skb_set_inner_transport_header(skb, + skb_checksum_start_offset(skb)); + else + skb_set_transport_header(skb, + skb_checksum_start_offset(skb)); + if (skb_csum_hwoffload_help(skb, features)) + goto out_kfree_skb; + } + } + + skb = validate_xmit_xfrm(skb, features, again); + + return skb; + +out_kfree_skb: + kfree_skb(skb); +out_null: + dev_core_stats_tx_dropped_inc(dev); + return NULL; +} + +struct sk_buff *validate_xmit_skb_list(struct sk_buff *skb, struct net_device *dev, bool *again) +{ + struct sk_buff *next, *head = NULL, *tail; + + for (; skb != NULL; skb = next) { + next = skb->next; + skb_mark_not_on_list(skb); + + /* in case skb wont be segmented, point to itself */ + skb->prev = skb; + + skb = validate_xmit_skb(skb, dev, again); + if (!skb) + continue; + + if (!head) + head = skb; + else + tail->next = skb; + /* If skb was segmented, skb->prev points to + * the last segment. If not, it still contains skb. + */ + tail = skb->prev; + } + return head; +} +EXPORT_SYMBOL_GPL(validate_xmit_skb_list); + +static void qdisc_pkt_len_init(struct sk_buff *skb) +{ + const struct skb_shared_info *shinfo = skb_shinfo(skb); + + qdisc_skb_cb(skb)->pkt_len = skb->len; + + /* To get more precise estimation of bytes sent on wire, + * we add to pkt_len the headers size of all segments + */ + if (shinfo->gso_size && skb_transport_header_was_set(skb)) { + unsigned int hdr_len; + u16 gso_segs = shinfo->gso_segs; + + /* mac layer + network layer */ + hdr_len = skb_transport_header(skb) - skb_mac_header(skb); + + /* + transport layer */ + if (likely(shinfo->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6))) { + const struct tcphdr *th; + struct tcphdr _tcphdr; + + th = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_tcphdr), &_tcphdr); + if (likely(th)) + hdr_len += __tcp_hdrlen(th); + } else { + struct udphdr _udphdr; + + if (skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_udphdr), &_udphdr)) + hdr_len += sizeof(struct udphdr); + } + + if (shinfo->gso_type & SKB_GSO_DODGY) + gso_segs = DIV_ROUND_UP(skb->len - hdr_len, + shinfo->gso_size); + + qdisc_skb_cb(skb)->pkt_len += (gso_segs - 1) * hdr_len; + } +} + +static int dev_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *q, + struct sk_buff **to_free, + struct netdev_queue *txq) +{ + int rc; + + rc = q->enqueue(skb, q, to_free) & NET_XMIT_MASK; + if (rc == NET_XMIT_SUCCESS) + trace_qdisc_enqueue(q, txq, skb); + return rc; +} + +static inline int __dev_xmit_skb(struct sk_buff *skb, struct Qdisc *q, + struct net_device *dev, + struct netdev_queue *txq) +{ + spinlock_t *root_lock = qdisc_lock(q); + struct sk_buff *to_free = NULL; + bool contended; + int rc; + + qdisc_calculate_pkt_len(skb, q); + + if (q->flags & TCQ_F_NOLOCK) { + if (q->flags & TCQ_F_CAN_BYPASS && nolock_qdisc_is_empty(q) && + qdisc_run_begin(q)) { + /* Retest nolock_qdisc_is_empty() within the protection + * of q->seqlock to protect from racing with requeuing. 
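+ *
+ * If the retest finds the qdisc non-empty after all, the skb is enqueued
+ * normally and __qdisc_run() drains the queue; otherwise the bypass path
+ * transmits it directly via sch_direct_xmit() while qdisc_run_begin()
+ * still holds the qdisc.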
+ */ + if (unlikely(!nolock_qdisc_is_empty(q))) { + rc = dev_qdisc_enqueue(skb, q, &to_free, txq); + __qdisc_run(q); + qdisc_run_end(q); + + goto no_lock_out; + } + + qdisc_bstats_cpu_update(q, skb); + if (sch_direct_xmit(skb, q, dev, txq, NULL, true) && + !nolock_qdisc_is_empty(q)) + __qdisc_run(q); + + qdisc_run_end(q); + return NET_XMIT_SUCCESS; + } + + rc = dev_qdisc_enqueue(skb, q, &to_free, txq); + qdisc_run(q); + +no_lock_out: + if (unlikely(to_free)) + kfree_skb_list_reason(to_free, + SKB_DROP_REASON_QDISC_DROP); + return rc; + } + + /* + * Heuristic to force contended enqueues to serialize on a + * separate lock before trying to get qdisc main lock. + * This permits qdisc->running owner to get the lock more + * often and dequeue packets faster. + * On PREEMPT_RT it is possible to preempt the qdisc owner during xmit + * and then other tasks will only enqueue packets. The packets will be + * sent after the qdisc owner is scheduled again. To prevent this + * scenario the task always serialize on the lock. + */ + contended = qdisc_is_running(q) || IS_ENABLED(CONFIG_PREEMPT_RT); + if (unlikely(contended)) + spin_lock(&q->busylock); + + spin_lock(root_lock); + if (unlikely(test_bit(__QDISC_STATE_DEACTIVATED, &q->state))) { + __qdisc_drop(skb, &to_free); + rc = NET_XMIT_DROP; + } else if ((q->flags & TCQ_F_CAN_BYPASS) && !qdisc_qlen(q) && + qdisc_run_begin(q)) { + /* + * This is a work-conserving queue; there are no old skbs + * waiting to be sent out; and the qdisc is not running - + * xmit the skb directly. + */ + + qdisc_bstats_update(q, skb); + + if (sch_direct_xmit(skb, q, dev, txq, root_lock, true)) { + if (unlikely(contended)) { + spin_unlock(&q->busylock); + contended = false; + } + __qdisc_run(q); + } + + qdisc_run_end(q); + rc = NET_XMIT_SUCCESS; + } else { + rc = dev_qdisc_enqueue(skb, q, &to_free, txq); + if (qdisc_run_begin(q)) { + if (unlikely(contended)) { + spin_unlock(&q->busylock); + contended = false; + } + __qdisc_run(q); + qdisc_run_end(q); + } + } + spin_unlock(root_lock); + if (unlikely(to_free)) + kfree_skb_list_reason(to_free, SKB_DROP_REASON_QDISC_DROP); + if (unlikely(contended)) + spin_unlock(&q->busylock); + return rc; +} + +#if IS_ENABLED(CONFIG_CGROUP_NET_PRIO) +static void skb_update_prio(struct sk_buff *skb) +{ + const struct netprio_map *map; + const struct sock *sk; + unsigned int prioidx; + + if (skb->priority) + return; + map = rcu_dereference_bh(skb->dev->priomap); + if (!map) + return; + sk = skb_to_full_sk(skb); + if (!sk) + return; + + prioidx = sock_cgroup_prioidx(&sk->sk_cgrp_data); + + if (prioidx < map->priomap_len) + skb->priority = map->priomap[prioidx]; +} +#else +#define skb_update_prio(skb) +#endif + +/** + * dev_loopback_xmit - loop back @skb + * @net: network namespace this loopback is happening in + * @sk: sk needed to be a netfilter okfn + * @skb: buffer to transmit + */ +int dev_loopback_xmit(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb_reset_mac_header(skb); + __skb_pull(skb, skb_network_offset(skb)); + skb->pkt_type = PACKET_LOOPBACK; + if (skb->ip_summed == CHECKSUM_NONE) + skb->ip_summed = CHECKSUM_UNNECESSARY; + DEBUG_NET_WARN_ON_ONCE(!skb_dst(skb)); + skb_dst_force(skb); + netif_rx(skb); + return 0; +} +EXPORT_SYMBOL(dev_loopback_xmit); + +#ifdef CONFIG_NET_EGRESS +static struct sk_buff * +sch_handle_egress(struct sk_buff *skb, int *ret, struct net_device *dev) +{ +#ifdef CONFIG_NET_CLS_ACT + struct mini_Qdisc *miniq = rcu_dereference_bh(dev->miniq_egress); + struct tcf_result cl_res; + + if (!miniq) + return 
skb; + + /* qdisc_skb_cb(skb)->pkt_len was already set by the caller. */ + tc_skb_cb(skb)->mru = 0; + tc_skb_cb(skb)->post_ct = false; + mini_qdisc_bstats_cpu_update(miniq, skb); + + switch (tcf_classify(skb, miniq->block, miniq->filter_list, &cl_res, false)) { + case TC_ACT_OK: + case TC_ACT_RECLASSIFY: + skb->tc_index = TC_H_MIN(cl_res.classid); + break; + case TC_ACT_SHOT: + mini_qdisc_qstats_cpu_drop(miniq); + *ret = NET_XMIT_DROP; + kfree_skb_reason(skb, SKB_DROP_REASON_TC_EGRESS); + return NULL; + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *ret = NET_XMIT_SUCCESS; + consume_skb(skb); + return NULL; + case TC_ACT_REDIRECT: + /* No need to push/pop skb's mac_header here on egress! */ + skb_do_redirect(skb); + *ret = NET_XMIT_SUCCESS; + return NULL; + default: + break; + } +#endif /* CONFIG_NET_CLS_ACT */ + + return skb; +} + +static struct netdev_queue * +netdev_tx_queue_mapping(struct net_device *dev, struct sk_buff *skb) +{ + int qm = skb_get_queue_mapping(skb); + + return netdev_get_tx_queue(dev, netdev_cap_txqueue(dev, qm)); +} + +static bool netdev_xmit_txqueue_skipped(void) +{ + return __this_cpu_read(softnet_data.xmit.skip_txqueue); +} + +void netdev_xmit_skip_txqueue(bool skip) +{ + __this_cpu_write(softnet_data.xmit.skip_txqueue, skip); +} +EXPORT_SYMBOL_GPL(netdev_xmit_skip_txqueue); +#endif /* CONFIG_NET_EGRESS */ + +#ifdef CONFIG_XPS +static int __get_xps_queue_idx(struct net_device *dev, struct sk_buff *skb, + struct xps_dev_maps *dev_maps, unsigned int tci) +{ + int tc = netdev_get_prio_tc_map(dev, skb->priority); + struct xps_map *map; + int queue_index = -1; + + if (tc >= dev_maps->num_tc || tci >= dev_maps->nr_ids) + return queue_index; + + tci *= dev_maps->num_tc; + tci += tc; + + map = rcu_dereference(dev_maps->attr_map[tci]); + if (map) { + if (map->len == 1) + queue_index = map->queues[0]; + else + queue_index = map->queues[reciprocal_scale( + skb_get_hash(skb), map->len)]; + if (unlikely(queue_index >= dev->real_num_tx_queues)) + queue_index = -1; + } + return queue_index; +} +#endif + +static int get_xps_queue(struct net_device *dev, struct net_device *sb_dev, + struct sk_buff *skb) +{ +#ifdef CONFIG_XPS + struct xps_dev_maps *dev_maps; + struct sock *sk = skb->sk; + int queue_index = -1; + + if (!static_key_false(&xps_needed)) + return -1; + + rcu_read_lock(); + if (!static_key_false(&xps_rxqs_needed)) + goto get_cpus_map; + + dev_maps = rcu_dereference(sb_dev->xps_maps[XPS_RXQS]); + if (dev_maps) { + int tci = sk_rx_queue_get(sk); + + if (tci >= 0) + queue_index = __get_xps_queue_idx(dev, skb, dev_maps, + tci); + } + +get_cpus_map: + if (queue_index < 0) { + dev_maps = rcu_dereference(sb_dev->xps_maps[XPS_CPUS]); + if (dev_maps) { + unsigned int tci = skb->sender_cpu - 1; + + queue_index = __get_xps_queue_idx(dev, skb, dev_maps, + tci); + } + } + rcu_read_unlock(); + + return queue_index; +#else + return -1; +#endif +} + +u16 dev_pick_tx_zero(struct net_device *dev, struct sk_buff *skb, + struct net_device *sb_dev) +{ + return 0; +} +EXPORT_SYMBOL(dev_pick_tx_zero); + +u16 dev_pick_tx_cpu_id(struct net_device *dev, struct sk_buff *skb, + struct net_device *sb_dev) +{ + return (u16)raw_smp_processor_id() % dev->real_num_tx_queues; +} +EXPORT_SYMBOL(dev_pick_tx_cpu_id); + +u16 netdev_pick_tx(struct net_device *dev, struct sk_buff *skb, + struct net_device *sb_dev) +{ + struct sock *sk = skb->sk; + int queue_index = sk_tx_queue_get(sk); + + sb_dev = sb_dev ? 
: dev; + + if (queue_index < 0 || skb->ooo_okay || + queue_index >= dev->real_num_tx_queues) { + int new_index = get_xps_queue(dev, sb_dev, skb); + + if (new_index < 0) + new_index = skb_tx_hash(dev, sb_dev, skb); + + if (queue_index != new_index && sk && + sk_fullsock(sk) && + rcu_access_pointer(sk->sk_dst_cache)) + sk_tx_queue_set(sk, new_index); + + queue_index = new_index; + } + + return queue_index; +} +EXPORT_SYMBOL(netdev_pick_tx); + +struct netdev_queue *netdev_core_pick_tx(struct net_device *dev, + struct sk_buff *skb, + struct net_device *sb_dev) +{ + int queue_index = 0; + +#ifdef CONFIG_XPS + u32 sender_cpu = skb->sender_cpu - 1; + + if (sender_cpu >= (u32)NR_CPUS) + skb->sender_cpu = raw_smp_processor_id() + 1; +#endif + + if (dev->real_num_tx_queues != 1) { + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_select_queue) + queue_index = ops->ndo_select_queue(dev, skb, sb_dev); + else + queue_index = netdev_pick_tx(dev, skb, sb_dev); + + queue_index = netdev_cap_txqueue(dev, queue_index); + } + + skb_set_queue_mapping(skb, queue_index); + return netdev_get_tx_queue(dev, queue_index); +} + +/** + * __dev_queue_xmit() - transmit a buffer + * @skb: buffer to transmit + * @sb_dev: suboordinate device used for L2 forwarding offload + * + * Queue a buffer for transmission to a network device. The caller must + * have set the device and priority and built the buffer before calling + * this function. The function can be called from an interrupt. + * + * When calling this method, interrupts MUST be enabled. This is because + * the BH enable code must have IRQs enabled so that it will not deadlock. + * + * Regardless of the return value, the skb is consumed, so it is currently + * difficult to retry a send to this method. (You can bump the ref count + * before sending to hold a reference for retry if you are careful.) + * + * Return: + * * 0 - buffer successfully transmitted + * * positive qdisc return code - NET_XMIT_DROP etc. + * * negative errno - other errors + */ +int __dev_queue_xmit(struct sk_buff *skb, struct net_device *sb_dev) +{ + struct net_device *dev = skb->dev; + struct netdev_queue *txq = NULL; + struct Qdisc *q; + int rc = -ENOMEM; + bool again = false; + + skb_reset_mac_header(skb); + skb_assert_len(skb); + + if (unlikely(skb_shinfo(skb)->tx_flags & SKBTX_SCHED_TSTAMP)) + __skb_tstamp_tx(skb, NULL, NULL, skb->sk, SCM_TSTAMP_SCHED); + + /* Disable soft irqs for various locks below. Also + * stops preemption for RCU. + */ + rcu_read_lock_bh(); + + skb_update_prio(skb); + + qdisc_pkt_len_init(skb); +#ifdef CONFIG_NET_CLS_ACT + skb->tc_at_ingress = 0; +#endif +#ifdef CONFIG_NET_EGRESS + if (static_branch_unlikely(&egress_needed_key)) { + if (nf_hook_egress_active()) { + skb = nf_hook_egress(skb, &rc, dev); + if (!skb) + goto out; + } + + netdev_xmit_skip_txqueue(false); + + nf_skip_egress(skb, true); + skb = sch_handle_egress(skb, &rc, dev); + if (!skb) + goto out; + nf_skip_egress(skb, false); + + if (netdev_xmit_txqueue_skipped()) + txq = netdev_tx_queue_mapping(dev, skb); + } +#endif + /* If device/qdisc don't need skb->dst, release it right now while + * its hot in this cpu cache. + */ + if (dev->priv_flags & IFF_XMIT_DST_RELEASE) + skb_dst_drop(skb); + else + skb_dst_force(skb); + + if (!txq) + txq = netdev_core_pick_tx(dev, skb, sb_dev); + + q = rcu_dereference_bh(txq->qdisc); + + trace_net_dev_queue(skb); + if (q->enqueue) { + rc = __dev_xmit_skb(skb, q, dev, txq); + goto out; + } + + /* The device has no queue. 
Common case for software devices: + * loopback, all the sorts of tunnels... + + * Really, it is unlikely that netif_tx_lock protection is necessary + * here. (f.e. loopback and IP tunnels are clean ignoring statistics + * counters.) + * However, it is possible, that they rely on protection + * made by us here. + + * Check this and shot the lock. It is not prone from deadlocks. + *Either shot noqueue qdisc, it is even simpler 8) + */ + if (dev->flags & IFF_UP) { + int cpu = smp_processor_id(); /* ok because BHs are off */ + + /* Other cpus might concurrently change txq->xmit_lock_owner + * to -1 or to their cpu id, but not to our id. + */ + if (READ_ONCE(txq->xmit_lock_owner) != cpu) { + if (dev_xmit_recursion()) + goto recursion_alert; + + skb = validate_xmit_skb(skb, dev, &again); + if (!skb) + goto out; + + HARD_TX_LOCK(dev, txq, cpu); + + if (!netif_xmit_stopped(txq)) { + dev_xmit_recursion_inc(); + skb = dev_hard_start_xmit(skb, dev, txq, &rc); + dev_xmit_recursion_dec(); + if (dev_xmit_complete(rc)) { + HARD_TX_UNLOCK(dev, txq); + goto out; + } + } + HARD_TX_UNLOCK(dev, txq); + net_crit_ratelimited("Virtual device %s asks to queue packet!\n", + dev->name); + } else { + /* Recursion is detected! It is possible, + * unfortunately + */ +recursion_alert: + net_crit_ratelimited("Dead loop on virtual device %s, fix it urgently!\n", + dev->name); + } + } + + rc = -ENETDOWN; + rcu_read_unlock_bh(); + + dev_core_stats_tx_dropped_inc(dev); + kfree_skb_list(skb); + return rc; +out: + rcu_read_unlock_bh(); + return rc; +} +EXPORT_SYMBOL(__dev_queue_xmit); + +int __dev_direct_xmit(struct sk_buff *skb, u16 queue_id) +{ + struct net_device *dev = skb->dev; + struct sk_buff *orig_skb = skb; + struct netdev_queue *txq; + int ret = NETDEV_TX_BUSY; + bool again = false; + + if (unlikely(!netif_running(dev) || + !netif_carrier_ok(dev))) + goto drop; + + skb = validate_xmit_skb_list(skb, dev, &again); + if (skb != orig_skb) + goto drop; + + skb_set_queue_mapping(skb, queue_id); + txq = skb_get_tx_queue(dev, skb); + + local_bh_disable(); + + dev_xmit_recursion_inc(); + HARD_TX_LOCK(dev, txq, smp_processor_id()); + if (!netif_xmit_frozen_or_drv_stopped(txq)) + ret = netdev_start_xmit(skb, dev, txq, false); + HARD_TX_UNLOCK(dev, txq); + dev_xmit_recursion_dec(); + + local_bh_enable(); + return ret; +drop: + dev_core_stats_tx_dropped_inc(dev); + kfree_skb_list(skb); + return NET_XMIT_DROP; +} +EXPORT_SYMBOL(__dev_direct_xmit); + +/************************************************************************* + * Receiver routines + *************************************************************************/ + +int netdev_max_backlog __read_mostly = 1000; +EXPORT_SYMBOL(netdev_max_backlog); + +int netdev_tstamp_prequeue __read_mostly = 1; +unsigned int sysctl_skb_defer_max __read_mostly = 64; +int netdev_budget __read_mostly = 300; +/* Must be at least 2 jiffes to guarantee 1 jiffy timeout */ +unsigned int __read_mostly netdev_budget_usecs = 2 * USEC_PER_SEC / HZ; +int weight_p __read_mostly = 64; /* old backlog weight */ +int dev_weight_rx_bias __read_mostly = 1; /* bias for backlog weight */ +int dev_weight_tx_bias __read_mostly = 1; /* bias for output_queue quota */ +int dev_rx_weight __read_mostly = 64; +int dev_tx_weight __read_mostly = 64; + +/* Called with irq disabled */ +static inline void ____napi_schedule(struct softnet_data *sd, + struct napi_struct *napi) +{ + struct task_struct *thread; + + lockdep_assert_irqs_disabled(); + + if (test_bit(NAPI_STATE_THREADED, &napi->state)) { + /* Paired with 
smp_mb__before_atomic() in + * napi_enable()/dev_set_threaded(). + * Use READ_ONCE() to guarantee a complete + * read on napi->thread. Only call + * wake_up_process() when it's not NULL. + */ + thread = READ_ONCE(napi->thread); + if (thread) { + /* Avoid doing set_bit() if the thread is in + * INTERRUPTIBLE state, cause napi_thread_wait() + * makes sure to proceed with napi polling + * if the thread is explicitly woken from here. + */ + if (READ_ONCE(thread->__state) != TASK_INTERRUPTIBLE) + set_bit(NAPI_STATE_SCHED_THREADED, &napi->state); + wake_up_process(thread); + return; + } + } + + list_add_tail(&napi->poll_list, &sd->poll_list); + __raise_softirq_irqoff(NET_RX_SOFTIRQ); +} + +#ifdef CONFIG_RPS + +/* One global table that all flow-based protocols share. */ +struct rps_sock_flow_table __rcu *rps_sock_flow_table __read_mostly; +EXPORT_SYMBOL(rps_sock_flow_table); +u32 rps_cpu_mask __read_mostly; +EXPORT_SYMBOL(rps_cpu_mask); + +struct static_key_false rps_needed __read_mostly; +EXPORT_SYMBOL(rps_needed); +struct static_key_false rfs_needed __read_mostly; +EXPORT_SYMBOL(rfs_needed); + +static struct rps_dev_flow * +set_rps_cpu(struct net_device *dev, struct sk_buff *skb, + struct rps_dev_flow *rflow, u16 next_cpu) +{ + if (next_cpu < nr_cpu_ids) { +#ifdef CONFIG_RFS_ACCEL + struct netdev_rx_queue *rxqueue; + struct rps_dev_flow_table *flow_table; + struct rps_dev_flow *old_rflow; + u32 flow_id; + u16 rxq_index; + int rc; + + /* Should we steer this flow to a different hardware queue? */ + if (!skb_rx_queue_recorded(skb) || !dev->rx_cpu_rmap || + !(dev->features & NETIF_F_NTUPLE)) + goto out; + rxq_index = cpu_rmap_lookup_index(dev->rx_cpu_rmap, next_cpu); + if (rxq_index == skb_get_rx_queue(skb)) + goto out; + + rxqueue = dev->_rx + rxq_index; + flow_table = rcu_dereference(rxqueue->rps_flow_table); + if (!flow_table) + goto out; + flow_id = skb_get_hash(skb) & flow_table->mask; + rc = dev->netdev_ops->ndo_rx_flow_steer(dev, skb, + rxq_index, flow_id); + if (rc < 0) + goto out; + old_rflow = rflow; + rflow = &flow_table->flows[flow_id]; + rflow->filter = rc; + if (old_rflow->filter == rflow->filter) + old_rflow->filter = RPS_NO_FILTER; + out: +#endif + rflow->last_qtail = + per_cpu(softnet_data, next_cpu).input_queue_head; + } + + rflow->cpu = next_cpu; + return rflow; +} + +/* + * get_rps_cpu is called from netif_receive_skb and returns the target + * CPU from the RPS map of the receiving queue for a given skb. + * rcu_read_lock must be held on entry. 
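+ *
+ * Condensed sketch of the calling convention (not verbatim from the
+ * callers later in this file):
+ *
+ *   rcu_read_lock();
+ *   cpu = get_rps_cpu(skb->dev, skb, &rflow);
+ *   if (cpu < 0)
+ *           cpu = smp_processor_id();
+ *   enqueue_to_backlog(skb, cpu, &rflow->last_qtail);
+ *   rcu_read_unlock();
+ *
+ * A negative return simply means "no RPS steering"; the packet is then
+ * handled on the local CPU.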
+ */ +static int get_rps_cpu(struct net_device *dev, struct sk_buff *skb, + struct rps_dev_flow **rflowp) +{ + const struct rps_sock_flow_table *sock_flow_table; + struct netdev_rx_queue *rxqueue = dev->_rx; + struct rps_dev_flow_table *flow_table; + struct rps_map *map; + int cpu = -1; + u32 tcpu; + u32 hash; + + if (skb_rx_queue_recorded(skb)) { + u16 index = skb_get_rx_queue(skb); + + if (unlikely(index >= dev->real_num_rx_queues)) { + WARN_ONCE(dev->real_num_rx_queues > 1, + "%s received packet on queue %u, but number " + "of RX queues is %u\n", + dev->name, index, dev->real_num_rx_queues); + goto done; + } + rxqueue += index; + } + + /* Avoid computing hash if RFS/RPS is not active for this rxqueue */ + + flow_table = rcu_dereference(rxqueue->rps_flow_table); + map = rcu_dereference(rxqueue->rps_map); + if (!flow_table && !map) + goto done; + + skb_reset_network_header(skb); + hash = skb_get_hash(skb); + if (!hash) + goto done; + + sock_flow_table = rcu_dereference(rps_sock_flow_table); + if (flow_table && sock_flow_table) { + struct rps_dev_flow *rflow; + u32 next_cpu; + u32 ident; + + /* First check into global flow table if there is a match. + * This READ_ONCE() pairs with WRITE_ONCE() from rps_record_sock_flow(). + */ + ident = READ_ONCE(sock_flow_table->ents[hash & sock_flow_table->mask]); + if ((ident ^ hash) & ~rps_cpu_mask) + goto try_rps; + + next_cpu = ident & rps_cpu_mask; + + /* OK, now we know there is a match, + * we can look at the local (per receive queue) flow table + */ + rflow = &flow_table->flows[hash & flow_table->mask]; + tcpu = rflow->cpu; + + /* + * If the desired CPU (where last recvmsg was done) is + * different from current CPU (one in the rx-queue flow + * table entry), switch if one of the following holds: + * - Current CPU is unset (>= nr_cpu_ids). + * - Current CPU is offline. + * - The current CPU's queue tail has advanced beyond the + * last packet that was enqueued using this table entry. + * This guarantees that all previous packets for the flow + * have been dequeued, thus preserving in order delivery. + */ + if (unlikely(tcpu != next_cpu) && + (tcpu >= nr_cpu_ids || !cpu_online(tcpu) || + ((int)(per_cpu(softnet_data, tcpu).input_queue_head - + rflow->last_qtail)) >= 0)) { + tcpu = next_cpu; + rflow = set_rps_cpu(dev, skb, rflow, next_cpu); + } + + if (tcpu < nr_cpu_ids && cpu_online(tcpu)) { + *rflowp = rflow; + cpu = tcpu; + goto done; + } + } + +try_rps: + + if (map) { + tcpu = map->cpus[reciprocal_scale(hash, map->len)]; + if (cpu_online(tcpu)) { + cpu = tcpu; + goto done; + } + } + +done: + return cpu; +} + +#ifdef CONFIG_RFS_ACCEL + +/** + * rps_may_expire_flow - check whether an RFS hardware filter may be removed + * @dev: Device on which the filter was set + * @rxq_index: RX queue index + * @flow_id: Flow ID passed to ndo_rx_flow_steer() + * @filter_id: Filter ID returned by ndo_rx_flow_steer() + * + * Drivers that implement ndo_rx_flow_steer() should periodically call + * this function for each installed filter and remove the filters for + * which it returns %true. 
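+ *
+ * Minimal driver-side sketch (hw_remove_filter() and the filter table are
+ * hypothetical, not defined here):
+ *
+ *   for (i = 0; i < n_filters; i++) {
+ *           if (rps_may_expire_flow(netdev, filters[i].rxq_index,
+ *                                   filters[i].flow_id, i))
+ *                   hw_remove_filter(priv, i);
+ *   }
+ *
+ * with i playing the role of the filter ID previously returned from
+ * ndo_rx_flow_steer().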
+ */ +bool rps_may_expire_flow(struct net_device *dev, u16 rxq_index, + u32 flow_id, u16 filter_id) +{ + struct netdev_rx_queue *rxqueue = dev->_rx + rxq_index; + struct rps_dev_flow_table *flow_table; + struct rps_dev_flow *rflow; + bool expire = true; + unsigned int cpu; + + rcu_read_lock(); + flow_table = rcu_dereference(rxqueue->rps_flow_table); + if (flow_table && flow_id <= flow_table->mask) { + rflow = &flow_table->flows[flow_id]; + cpu = READ_ONCE(rflow->cpu); + if (rflow->filter == filter_id && cpu < nr_cpu_ids && + ((int)(per_cpu(softnet_data, cpu).input_queue_head - + rflow->last_qtail) < + (int)(10 * flow_table->mask))) + expire = false; + } + rcu_read_unlock(); + return expire; +} +EXPORT_SYMBOL(rps_may_expire_flow); + +#endif /* CONFIG_RFS_ACCEL */ + +/* Called from hardirq (IPI) context */ +static void rps_trigger_softirq(void *data) +{ + struct softnet_data *sd = data; + + ____napi_schedule(sd, &sd->backlog); + sd->received_rps++; +} + +#endif /* CONFIG_RPS */ + +/* Called from hardirq (IPI) context */ +static void trigger_rx_softirq(void *data) +{ + struct softnet_data *sd = data; + + __raise_softirq_irqoff(NET_RX_SOFTIRQ); + smp_store_release(&sd->defer_ipi_scheduled, 0); +} + +/* + * Check if this softnet_data structure is another cpu one + * If yes, queue it to our IPI list and return 1 + * If no, return 0 + */ +static int napi_schedule_rps(struct softnet_data *sd) +{ + struct softnet_data *mysd = this_cpu_ptr(&softnet_data); + +#ifdef CONFIG_RPS + if (sd != mysd) { + sd->rps_ipi_next = mysd->rps_ipi_list; + mysd->rps_ipi_list = sd; + + __raise_softirq_irqoff(NET_RX_SOFTIRQ); + return 1; + } +#endif /* CONFIG_RPS */ + __napi_schedule_irqoff(&mysd->backlog); + return 0; +} + +#ifdef CONFIG_NET_FLOW_LIMIT +int netdev_flow_limit_table_len __read_mostly = (1 << 12); +#endif + +static bool skb_flow_limit(struct sk_buff *skb, unsigned int qlen) +{ +#ifdef CONFIG_NET_FLOW_LIMIT + struct sd_flow_limit *fl; + struct softnet_data *sd; + unsigned int old_flow, new_flow; + + if (qlen < (READ_ONCE(netdev_max_backlog) >> 1)) + return false; + + sd = this_cpu_ptr(&softnet_data); + + rcu_read_lock(); + fl = rcu_dereference(sd->flow_limit); + if (fl) { + new_flow = skb_get_hash(skb) & (fl->num_buckets - 1); + old_flow = fl->history[fl->history_head]; + fl->history[fl->history_head] = new_flow; + + fl->history_head++; + fl->history_head &= FLOW_LIMIT_HISTORY - 1; + + if (likely(fl->buckets[old_flow])) + fl->buckets[old_flow]--; + + if (++fl->buckets[new_flow] > (FLOW_LIMIT_HISTORY >> 1)) { + fl->count++; + rcu_read_unlock(); + return true; + } + } + rcu_read_unlock(); +#endif + return false; +} + +/* + * enqueue_to_backlog is called to queue an skb to a per CPU backlog + * queue (may be a remote CPU queue). 
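+ *
+ * Returns NET_RX_SUCCESS when the skb was queued (scheduling backlog NAPI
+ * if the queue was empty), or NET_RX_DROP when the device is not running,
+ * the backlog exceeds netdev_max_backlog, or the flow limit fires; in the
+ * drop case the skb is freed and the per-device rx_dropped count is bumped.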
+ */ +static int enqueue_to_backlog(struct sk_buff *skb, int cpu, + unsigned int *qtail) +{ + enum skb_drop_reason reason; + struct softnet_data *sd; + unsigned long flags; + unsigned int qlen; + + reason = SKB_DROP_REASON_NOT_SPECIFIED; + sd = &per_cpu(softnet_data, cpu); + + rps_lock_irqsave(sd, &flags); + if (!netif_running(skb->dev)) + goto drop; + qlen = skb_queue_len(&sd->input_pkt_queue); + if (qlen <= READ_ONCE(netdev_max_backlog) && !skb_flow_limit(skb, qlen)) { + if (qlen) { +enqueue: + __skb_queue_tail(&sd->input_pkt_queue, skb); + input_queue_tail_incr_save(sd, qtail); + rps_unlock_irq_restore(sd, &flags); + return NET_RX_SUCCESS; + } + + /* Schedule NAPI for backlog device + * We can use non atomic operation since we own the queue lock + */ + if (!__test_and_set_bit(NAPI_STATE_SCHED, &sd->backlog.state)) + napi_schedule_rps(sd); + goto enqueue; + } + reason = SKB_DROP_REASON_CPU_BACKLOG; + +drop: + sd->dropped++; + rps_unlock_irq_restore(sd, &flags); + + dev_core_stats_rx_dropped_inc(skb->dev); + kfree_skb_reason(skb, reason); + return NET_RX_DROP; +} + +static struct netdev_rx_queue *netif_get_rxqueue(struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + struct netdev_rx_queue *rxqueue; + + rxqueue = dev->_rx; + + if (skb_rx_queue_recorded(skb)) { + u16 index = skb_get_rx_queue(skb); + + if (unlikely(index >= dev->real_num_rx_queues)) { + WARN_ONCE(dev->real_num_rx_queues > 1, + "%s received packet on queue %u, but number " + "of RX queues is %u\n", + dev->name, index, dev->real_num_rx_queues); + + return rxqueue; /* Return first rxqueue */ + } + rxqueue += index; + } + return rxqueue; +} + +u32 bpf_prog_run_generic_xdp(struct sk_buff *skb, struct xdp_buff *xdp, + struct bpf_prog *xdp_prog) +{ + void *orig_data, *orig_data_end, *hard_start; + struct netdev_rx_queue *rxqueue; + bool orig_bcast, orig_host; + u32 mac_len, frame_sz; + __be16 orig_eth_type; + struct ethhdr *eth; + u32 metalen, act; + int off; + + /* The XDP program wants to see the packet starting at the MAC + * header. 
+ */ + mac_len = skb->data - skb_mac_header(skb); + hard_start = skb->data - skb_headroom(skb); + + /* SKB "head" area always have tailroom for skb_shared_info */ + frame_sz = (void *)skb_end_pointer(skb) - hard_start; + frame_sz += SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + + rxqueue = netif_get_rxqueue(skb); + xdp_init_buff(xdp, frame_sz, &rxqueue->xdp_rxq); + xdp_prepare_buff(xdp, hard_start, skb_headroom(skb) - mac_len, + skb_headlen(skb) + mac_len, true); + + orig_data_end = xdp->data_end; + orig_data = xdp->data; + eth = (struct ethhdr *)xdp->data; + orig_host = ether_addr_equal_64bits(eth->h_dest, skb->dev->dev_addr); + orig_bcast = is_multicast_ether_addr_64bits(eth->h_dest); + orig_eth_type = eth->h_proto; + + act = bpf_prog_run_xdp(xdp_prog, xdp); + + /* check if bpf_xdp_adjust_head was used */ + off = xdp->data - orig_data; + if (off) { + if (off > 0) + __skb_pull(skb, off); + else if (off < 0) + __skb_push(skb, -off); + + skb->mac_header += off; + skb_reset_network_header(skb); + } + + /* check if bpf_xdp_adjust_tail was used */ + off = xdp->data_end - orig_data_end; + if (off != 0) { + skb_set_tail_pointer(skb, xdp->data_end - xdp->data); + skb->len += off; /* positive on grow, negative on shrink */ + } + + /* check if XDP changed eth hdr such SKB needs update */ + eth = (struct ethhdr *)xdp->data; + if ((orig_eth_type != eth->h_proto) || + (orig_host != ether_addr_equal_64bits(eth->h_dest, + skb->dev->dev_addr)) || + (orig_bcast != is_multicast_ether_addr_64bits(eth->h_dest))) { + __skb_push(skb, ETH_HLEN); + skb->pkt_type = PACKET_HOST; + skb->protocol = eth_type_trans(skb, skb->dev); + } + + /* Redirect/Tx gives L2 packet, code that will reuse skb must __skb_pull + * before calling us again on redirect path. We do not call do_redirect + * as we leave that up to the caller. + * + * Caller is responsible for managing lifetime of skb (i.e. calling + * kfree_skb in response to actions it cannot handle/XDP_DROP). + */ + switch (act) { + case XDP_REDIRECT: + case XDP_TX: + __skb_push(skb, mac_len); + break; + case XDP_PASS: + metalen = xdp->data - xdp->data_meta; + if (metalen) + skb_metadata_set(skb, metalen); + break; + } + + return act; +} + +static u32 netif_receive_generic_xdp(struct sk_buff *skb, + struct xdp_buff *xdp, + struct bpf_prog *xdp_prog) +{ + u32 act = XDP_DROP; + + /* Reinjected packets coming from act_mirred or similar should + * not get XDP generic processing. + */ + if (skb_is_redirected(skb)) + return XDP_PASS; + + /* XDP packets must be linear and must have sufficient headroom + * of XDP_PACKET_HEADROOM bytes. This is the guarantee that also + * native XDP provides, thus we need to do it here as well. + */ + if (skb_cloned(skb) || skb_is_nonlinear(skb) || + skb_headroom(skb) < XDP_PACKET_HEADROOM) { + int hroom = XDP_PACKET_HEADROOM - skb_headroom(skb); + int troom = skb->tail + skb->data_len - skb->end; + + /* In case we have to go down the path and also linearize, + * then lets do the pskb_expand_head() work just once here. + */ + if (pskb_expand_head(skb, + hroom > 0 ? ALIGN(hroom, NET_SKB_PAD) : 0, + troom > 0 ? 
troom + 128 : 0, GFP_ATOMIC)) + goto do_drop; + if (skb_linearize(skb)) + goto do_drop; + } + + act = bpf_prog_run_generic_xdp(skb, xdp, xdp_prog); + switch (act) { + case XDP_REDIRECT: + case XDP_TX: + case XDP_PASS: + break; + default: + bpf_warn_invalid_xdp_action(skb->dev, xdp_prog, act); + fallthrough; + case XDP_ABORTED: + trace_xdp_exception(skb->dev, xdp_prog, act); + fallthrough; + case XDP_DROP: + do_drop: + kfree_skb(skb); + break; + } + + return act; +} + +/* When doing generic XDP we have to bypass the qdisc layer and the + * network taps in order to match in-driver-XDP behavior. This also means + * that XDP packets are able to starve other packets going through a qdisc, + * and DDOS attacks will be more effective. In-driver-XDP use dedicated TX + * queues, so they do not have this starvation issue. + */ +void generic_xdp_tx(struct sk_buff *skb, struct bpf_prog *xdp_prog) +{ + struct net_device *dev = skb->dev; + struct netdev_queue *txq; + bool free_skb = true; + int cpu, rc; + + txq = netdev_core_pick_tx(dev, skb, NULL); + cpu = smp_processor_id(); + HARD_TX_LOCK(dev, txq, cpu); + if (!netif_xmit_frozen_or_drv_stopped(txq)) { + rc = netdev_start_xmit(skb, dev, txq, 0); + if (dev_xmit_complete(rc)) + free_skb = false; + } + HARD_TX_UNLOCK(dev, txq); + if (free_skb) { + trace_xdp_exception(dev, xdp_prog, XDP_TX); + dev_core_stats_tx_dropped_inc(dev); + kfree_skb(skb); + } +} + +static DEFINE_STATIC_KEY_FALSE(generic_xdp_needed_key); + +int do_xdp_generic(struct bpf_prog *xdp_prog, struct sk_buff *skb) +{ + if (xdp_prog) { + struct xdp_buff xdp; + u32 act; + int err; + + act = netif_receive_generic_xdp(skb, &xdp, xdp_prog); + if (act != XDP_PASS) { + switch (act) { + case XDP_REDIRECT: + err = xdp_do_generic_redirect(skb->dev, skb, + &xdp, xdp_prog); + if (err) + goto out_redir; + break; + case XDP_TX: + generic_xdp_tx(skb, xdp_prog); + break; + } + return XDP_DROP; + } + } + return XDP_PASS; +out_redir: + kfree_skb_reason(skb, SKB_DROP_REASON_XDP); + return XDP_DROP; +} +EXPORT_SYMBOL_GPL(do_xdp_generic); + +static int netif_rx_internal(struct sk_buff *skb) +{ + int ret; + + net_timestamp_check(READ_ONCE(netdev_tstamp_prequeue), skb); + + trace_netif_rx(skb); + +#ifdef CONFIG_RPS + if (static_branch_unlikely(&rps_needed)) { + struct rps_dev_flow voidflow, *rflow = &voidflow; + int cpu; + + rcu_read_lock(); + + cpu = get_rps_cpu(skb->dev, skb, &rflow); + if (cpu < 0) + cpu = smp_processor_id(); + + ret = enqueue_to_backlog(skb, cpu, &rflow->last_qtail); + + rcu_read_unlock(); + } else +#endif + { + unsigned int qtail; + + ret = enqueue_to_backlog(skb, smp_processor_id(), &qtail); + } + return ret; +} + +/** + * __netif_rx - Slightly optimized version of netif_rx + * @skb: buffer to post + * + * This behaves as netif_rx except that it does not disable bottom halves. + * As a result this function may only be invoked from the interrupt context + * (either hard or soft interrupt). + */ +int __netif_rx(struct sk_buff *skb) +{ + int ret; + + lockdep_assert_once(hardirq_count() | softirq_count()); + + trace_netif_rx_entry(skb); + ret = netif_rx_internal(skb); + trace_netif_rx_exit(ret); + return ret; +} +EXPORT_SYMBOL(__netif_rx); + +/** + * netif_rx - post buffer to the network code + * @skb: buffer to post + * + * This function receives a packet from a device driver and queues it for + * the upper (protocol) levels to process via the backlog NAPI device. It + * always succeeds. The buffer may be dropped during processing for + * congestion control or by the protocol layers. 
+ * The network buffer is passed via the backlog NAPI device. Modern NIC + * driver should use NAPI and GRO. + * This function can used from interrupt and from process context. The + * caller from process context must not disable interrupts before invoking + * this function. + * + * return values: + * NET_RX_SUCCESS (no congestion) + * NET_RX_DROP (packet was dropped) + * + */ +int netif_rx(struct sk_buff *skb) +{ + bool need_bh_off = !(hardirq_count() | softirq_count()); + int ret; + + if (need_bh_off) + local_bh_disable(); + trace_netif_rx_entry(skb); + ret = netif_rx_internal(skb); + trace_netif_rx_exit(ret); + if (need_bh_off) + local_bh_enable(); + return ret; +} +EXPORT_SYMBOL(netif_rx); + +static __latent_entropy void net_tx_action(struct softirq_action *h) +{ + struct softnet_data *sd = this_cpu_ptr(&softnet_data); + + if (sd->completion_queue) { + struct sk_buff *clist; + + local_irq_disable(); + clist = sd->completion_queue; + sd->completion_queue = NULL; + local_irq_enable(); + + while (clist) { + struct sk_buff *skb = clist; + + clist = clist->next; + + WARN_ON(refcount_read(&skb->users)); + if (likely(get_kfree_skb_cb(skb)->reason == SKB_REASON_CONSUMED)) + trace_consume_skb(skb); + else + trace_kfree_skb(skb, net_tx_action, + SKB_DROP_REASON_NOT_SPECIFIED); + + if (skb->fclone != SKB_FCLONE_UNAVAILABLE) + __kfree_skb(skb); + else + __kfree_skb_defer(skb); + } + } + + if (sd->output_queue) { + struct Qdisc *head; + + local_irq_disable(); + head = sd->output_queue; + sd->output_queue = NULL; + sd->output_queue_tailp = &sd->output_queue; + local_irq_enable(); + + rcu_read_lock(); + + while (head) { + struct Qdisc *q = head; + spinlock_t *root_lock = NULL; + + head = head->next_sched; + + /* We need to make sure head->next_sched is read + * before clearing __QDISC_STATE_SCHED + */ + smp_mb__before_atomic(); + + if (!(q->flags & TCQ_F_NOLOCK)) { + root_lock = qdisc_lock(q); + spin_lock(root_lock); + } else if (unlikely(test_bit(__QDISC_STATE_DEACTIVATED, + &q->state))) { + /* There is a synchronize_net() between + * STATE_DEACTIVATED flag being set and + * qdisc_reset()/some_qdisc_is_busy() in + * dev_deactivate(), so we can safely bail out + * early here to avoid data race between + * qdisc_deactivate() and some_qdisc_is_busy() + * for lockless qdisc. + */ + clear_bit(__QDISC_STATE_SCHED, &q->state); + continue; + } + + clear_bit(__QDISC_STATE_SCHED, &q->state); + qdisc_run(q); + if (root_lock) + spin_unlock(root_lock); + } + + rcu_read_unlock(); + } + + xfrm_dev_backlog(sd); +} + +#if IS_ENABLED(CONFIG_BRIDGE) && IS_ENABLED(CONFIG_ATM_LANE) +/* This hook is defined here for ATM LANE */ +int (*br_fdb_test_addr_hook)(struct net_device *dev, + unsigned char *addr) __read_mostly; +EXPORT_SYMBOL_GPL(br_fdb_test_addr_hook); +#endif + +static inline struct sk_buff * +sch_handle_ingress(struct sk_buff *skb, struct packet_type **pt_prev, int *ret, + struct net_device *orig_dev, bool *another) +{ +#ifdef CONFIG_NET_CLS_ACT + struct mini_Qdisc *miniq = rcu_dereference_bh(skb->dev->miniq_ingress); + struct tcf_result cl_res; + + /* If there's at least one ingress present somewhere (so + * we get here via enabled static key), remaining devices + * that are not configured with an ingress qdisc will bail + * out here. 
+ */ + if (!miniq) + return skb; + + if (*pt_prev) { + *ret = deliver_skb(skb, *pt_prev, orig_dev); + *pt_prev = NULL; + } + + qdisc_skb_cb(skb)->pkt_len = skb->len; + tc_skb_cb(skb)->mru = 0; + tc_skb_cb(skb)->post_ct = false; + skb->tc_at_ingress = 1; + mini_qdisc_bstats_cpu_update(miniq, skb); + + switch (tcf_classify(skb, miniq->block, miniq->filter_list, &cl_res, false)) { + case TC_ACT_OK: + case TC_ACT_RECLASSIFY: + skb->tc_index = TC_H_MIN(cl_res.classid); + break; + case TC_ACT_SHOT: + mini_qdisc_qstats_cpu_drop(miniq); + kfree_skb_reason(skb, SKB_DROP_REASON_TC_INGRESS); + *ret = NET_RX_DROP; + return NULL; + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + consume_skb(skb); + *ret = NET_RX_SUCCESS; + return NULL; + case TC_ACT_REDIRECT: + /* skb_mac_header check was done by cls/act_bpf, so + * we can safely push the L2 header back before + * redirecting to another netdev + */ + __skb_push(skb, skb->mac_len); + if (skb_do_redirect(skb) == -EAGAIN) { + __skb_pull(skb, skb->mac_len); + *another = true; + break; + } + *ret = NET_RX_SUCCESS; + return NULL; + case TC_ACT_CONSUMED: + *ret = NET_RX_SUCCESS; + return NULL; + default: + break; + } +#endif /* CONFIG_NET_CLS_ACT */ + return skb; +} + +/** + * netdev_is_rx_handler_busy - check if receive handler is registered + * @dev: device to check + * + * Check if a receive handler is already registered for a given device. + * Return true if there one. + * + * The caller must hold the rtnl_mutex. + */ +bool netdev_is_rx_handler_busy(struct net_device *dev) +{ + ASSERT_RTNL(); + return dev && rtnl_dereference(dev->rx_handler); +} +EXPORT_SYMBOL_GPL(netdev_is_rx_handler_busy); + +/** + * netdev_rx_handler_register - register receive handler + * @dev: device to register a handler for + * @rx_handler: receive handler to register + * @rx_handler_data: data pointer that is used by rx handler + * + * Register a receive handler for a device. This handler will then be + * called from __netif_receive_skb. A negative errno code is returned + * on a failure. + * + * The caller must hold the rtnl_mutex. + * + * For a general description of rx_handler, see enum rx_handler_result. + */ +int netdev_rx_handler_register(struct net_device *dev, + rx_handler_func_t *rx_handler, + void *rx_handler_data) +{ + if (netdev_is_rx_handler_busy(dev)) + return -EBUSY; + + if (dev->priv_flags & IFF_NO_RX_HANDLER) + return -EINVAL; + + /* Note: rx_handler_data must be set before rx_handler */ + rcu_assign_pointer(dev->rx_handler_data, rx_handler_data); + rcu_assign_pointer(dev->rx_handler, rx_handler); + + return 0; +} +EXPORT_SYMBOL_GPL(netdev_rx_handler_register); + +/** + * netdev_rx_handler_unregister - unregister receive handler + * @dev: device to unregister a handler from + * + * Unregister a receive handler from a device. + * + * The caller must hold the rtnl_mutex. + */ +void netdev_rx_handler_unregister(struct net_device *dev) +{ + + ASSERT_RTNL(); + RCU_INIT_POINTER(dev->rx_handler, NULL); + /* a reader seeing a non NULL rx_handler in a rcu_read_lock() + * section has a guarantee to see a non NULL rx_handler_data + * as well. + */ + synchronize_net(); + RCU_INIT_POINTER(dev->rx_handler_data, NULL); +} +EXPORT_SYMBOL_GPL(netdev_rx_handler_unregister); + +/* + * Limit the use of PFMEMALLOC reserves to those protocols that implement + * the special handling of PFMEMALLOC skbs. 
+ */ +static bool skb_pfmemalloc_protocol(struct sk_buff *skb) +{ + switch (skb->protocol) { + case htons(ETH_P_ARP): + case htons(ETH_P_IP): + case htons(ETH_P_IPV6): + case htons(ETH_P_8021Q): + case htons(ETH_P_8021AD): + return true; + default: + return false; + } +} + +static inline int nf_ingress(struct sk_buff *skb, struct packet_type **pt_prev, + int *ret, struct net_device *orig_dev) +{ + if (nf_hook_ingress_active(skb)) { + int ingress_retval; + + if (*pt_prev) { + *ret = deliver_skb(skb, *pt_prev, orig_dev); + *pt_prev = NULL; + } + + rcu_read_lock(); + ingress_retval = nf_hook_ingress(skb); + rcu_read_unlock(); + return ingress_retval; + } + return 0; +} + +static int __netif_receive_skb_core(struct sk_buff **pskb, bool pfmemalloc, + struct packet_type **ppt_prev) +{ + struct packet_type *ptype, *pt_prev; + rx_handler_func_t *rx_handler; + struct sk_buff *skb = *pskb; + struct net_device *orig_dev; + bool deliver_exact = false; + int ret = NET_RX_DROP; + __be16 type; + + net_timestamp_check(!READ_ONCE(netdev_tstamp_prequeue), skb); + + trace_netif_receive_skb(skb); + + orig_dev = skb->dev; + + skb_reset_network_header(skb); + if (!skb_transport_header_was_set(skb)) + skb_reset_transport_header(skb); + skb_reset_mac_len(skb); + + pt_prev = NULL; + +another_round: + skb->skb_iif = skb->dev->ifindex; + + __this_cpu_inc(softnet_data.processed); + + if (static_branch_unlikely(&generic_xdp_needed_key)) { + int ret2; + + migrate_disable(); + ret2 = do_xdp_generic(rcu_dereference(skb->dev->xdp_prog), skb); + migrate_enable(); + + if (ret2 != XDP_PASS) { + ret = NET_RX_DROP; + goto out; + } + } + + if (eth_type_vlan(skb->protocol)) { + skb = skb_vlan_untag(skb); + if (unlikely(!skb)) + goto out; + } + + if (skb_skip_tc_classify(skb)) + goto skip_classify; + + if (pfmemalloc) + goto skip_taps; + + list_for_each_entry_rcu(ptype, &ptype_all, list) { + if (pt_prev) + ret = deliver_skb(skb, pt_prev, orig_dev); + pt_prev = ptype; + } + + list_for_each_entry_rcu(ptype, &skb->dev->ptype_all, list) { + if (pt_prev) + ret = deliver_skb(skb, pt_prev, orig_dev); + pt_prev = ptype; + } + +skip_taps: +#ifdef CONFIG_NET_INGRESS + if (static_branch_unlikely(&ingress_needed_key)) { + bool another = false; + + nf_skip_egress(skb, true); + skb = sch_handle_ingress(skb, &pt_prev, &ret, orig_dev, + &another); + if (another) + goto another_round; + if (!skb) + goto out; + + nf_skip_egress(skb, false); + if (nf_ingress(skb, &pt_prev, &ret, orig_dev) < 0) + goto out; + } +#endif + skb_reset_redirect(skb); +skip_classify: + if (pfmemalloc && !skb_pfmemalloc_protocol(skb)) + goto drop; + + if (skb_vlan_tag_present(skb)) { + if (pt_prev) { + ret = deliver_skb(skb, pt_prev, orig_dev); + pt_prev = NULL; + } + if (vlan_do_receive(&skb)) + goto another_round; + else if (unlikely(!skb)) + goto out; + } + + rx_handler = rcu_dereference(skb->dev->rx_handler); + if (rx_handler) { + if (pt_prev) { + ret = deliver_skb(skb, pt_prev, orig_dev); + pt_prev = NULL; + } + switch (rx_handler(&skb)) { + case RX_HANDLER_CONSUMED: + ret = NET_RX_SUCCESS; + goto out; + case RX_HANDLER_ANOTHER: + goto another_round; + case RX_HANDLER_EXACT: + deliver_exact = true; + break; + case RX_HANDLER_PASS: + break; + default: + BUG(); + } + } + + if (unlikely(skb_vlan_tag_present(skb)) && !netdev_uses_dsa(skb->dev)) { +check_vlan_id: + if (skb_vlan_tag_get_id(skb)) { + /* Vlan id is non 0 and vlan_do_receive() above couldn't + * find vlan device. 
+ */ + skb->pkt_type = PACKET_OTHERHOST; + } else if (eth_type_vlan(skb->protocol)) { + /* Outer header is 802.1P with vlan 0, inner header is + * 802.1Q or 802.1AD and vlan_do_receive() above could + * not find vlan dev for vlan id 0. + */ + __vlan_hwaccel_clear_tag(skb); + skb = skb_vlan_untag(skb); + if (unlikely(!skb)) + goto out; + if (vlan_do_receive(&skb)) + /* After stripping off 802.1P header with vlan 0 + * vlan dev is found for inner header. + */ + goto another_round; + else if (unlikely(!skb)) + goto out; + else + /* We have stripped outer 802.1P vlan 0 header. + * But could not find vlan dev. + * check again for vlan id to set OTHERHOST. + */ + goto check_vlan_id; + } + /* Note: we might in the future use prio bits + * and set skb->priority like in vlan_do_receive() + * For the time being, just ignore Priority Code Point + */ + __vlan_hwaccel_clear_tag(skb); + } + + type = skb->protocol; + + /* deliver only exact match when indicated */ + if (likely(!deliver_exact)) { + deliver_ptype_list_skb(skb, &pt_prev, orig_dev, type, + &ptype_base[ntohs(type) & + PTYPE_HASH_MASK]); + } + + deliver_ptype_list_skb(skb, &pt_prev, orig_dev, type, + &orig_dev->ptype_specific); + + if (unlikely(skb->dev != orig_dev)) { + deliver_ptype_list_skb(skb, &pt_prev, orig_dev, type, + &skb->dev->ptype_specific); + } + + if (pt_prev) { + if (unlikely(skb_orphan_frags_rx(skb, GFP_ATOMIC))) + goto drop; + *ppt_prev = pt_prev; + } else { +drop: + if (!deliver_exact) + dev_core_stats_rx_dropped_inc(skb->dev); + else + dev_core_stats_rx_nohandler_inc(skb->dev); + kfree_skb_reason(skb, SKB_DROP_REASON_UNHANDLED_PROTO); + /* Jamal, now you will not able to escape explaining + * me how you were going to use this. :-) + */ + ret = NET_RX_DROP; + } + +out: + /* The invariant here is that if *ppt_prev is not NULL + * then skb should also be non-NULL. + * + * Apparently *ppt_prev assignment above holds this invariant due to + * skb dereferencing near it. + */ + *pskb = skb; + return ret; +} + +static int __netif_receive_skb_one_core(struct sk_buff *skb, bool pfmemalloc) +{ + struct net_device *orig_dev = skb->dev; + struct packet_type *pt_prev = NULL; + int ret; + + ret = __netif_receive_skb_core(&skb, pfmemalloc, &pt_prev); + if (pt_prev) + ret = INDIRECT_CALL_INET(pt_prev->func, ipv6_rcv, ip_rcv, skb, + skb->dev, pt_prev, orig_dev); + return ret; +} + +/** + * netif_receive_skb_core - special purpose version of netif_receive_skb + * @skb: buffer to process + * + * More direct receive version of netif_receive_skb(). It should + * only be used by callers that have a need to skip RPS and Generic XDP. + * Caller must also take care of handling if ``(page_is_)pfmemalloc``. + * + * This function may only be called from softirq context and interrupts + * should be enabled. 
+ * + * Return values (usually ignored): + * NET_RX_SUCCESS: no congestion + * NET_RX_DROP: packet was dropped + */ +int netif_receive_skb_core(struct sk_buff *skb) +{ + int ret; + + rcu_read_lock(); + ret = __netif_receive_skb_one_core(skb, false); + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL(netif_receive_skb_core); + +static inline void __netif_receive_skb_list_ptype(struct list_head *head, + struct packet_type *pt_prev, + struct net_device *orig_dev) +{ + struct sk_buff *skb, *next; + + if (!pt_prev) + return; + if (list_empty(head)) + return; + if (pt_prev->list_func != NULL) + INDIRECT_CALL_INET(pt_prev->list_func, ipv6_list_rcv, + ip_list_rcv, head, pt_prev, orig_dev); + else + list_for_each_entry_safe(skb, next, head, list) { + skb_list_del_init(skb); + pt_prev->func(skb, skb->dev, pt_prev, orig_dev); + } +} + +static void __netif_receive_skb_list_core(struct list_head *head, bool pfmemalloc) +{ + /* Fast-path assumptions: + * - There is no RX handler. + * - Only one packet_type matches. + * If either of these fails, we will end up doing some per-packet + * processing in-line, then handling the 'last ptype' for the whole + * sublist. This can't cause out-of-order delivery to any single ptype, + * because the 'last ptype' must be constant across the sublist, and all + * other ptypes are handled per-packet. + */ + /* Current (common) ptype of sublist */ + struct packet_type *pt_curr = NULL; + /* Current (common) orig_dev of sublist */ + struct net_device *od_curr = NULL; + struct list_head sublist; + struct sk_buff *skb, *next; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + struct net_device *orig_dev = skb->dev; + struct packet_type *pt_prev = NULL; + + skb_list_del_init(skb); + __netif_receive_skb_core(&skb, pfmemalloc, &pt_prev); + if (!pt_prev) + continue; + if (pt_curr != pt_prev || od_curr != orig_dev) { + /* dispatch old sublist */ + __netif_receive_skb_list_ptype(&sublist, pt_curr, od_curr); + /* start new sublist */ + INIT_LIST_HEAD(&sublist); + pt_curr = pt_prev; + od_curr = orig_dev; + } + list_add_tail(&skb->list, &sublist); + } + + /* dispatch final sublist */ + __netif_receive_skb_list_ptype(&sublist, pt_curr, od_curr); +} + +static int __netif_receive_skb(struct sk_buff *skb) +{ + int ret; + + if (sk_memalloc_socks() && skb_pfmemalloc(skb)) { + unsigned int noreclaim_flag; + + /* + * PFMEMALLOC skbs are special, they should + * - be delivered to SOCK_MEMALLOC sockets only + * - stay away from userspace + * - have bounded memory usage + * + * Use PF_MEMALLOC as this saves us from propagating the allocation + * context down to all allocation sites. + */ + noreclaim_flag = memalloc_noreclaim_save(); + ret = __netif_receive_skb_one_core(skb, true); + memalloc_noreclaim_restore(noreclaim_flag); + } else + ret = __netif_receive_skb_one_core(skb, false); + + return ret; +} + +static void __netif_receive_skb_list(struct list_head *head) +{ + unsigned long noreclaim_flag = 0; + struct sk_buff *skb, *next; + bool pfmemalloc = false; /* Is current sublist PF_MEMALLOC? 
*/ + + list_for_each_entry_safe(skb, next, head, list) { + if ((sk_memalloc_socks() && skb_pfmemalloc(skb)) != pfmemalloc) { + struct list_head sublist; + + /* Handle the previous sublist */ + list_cut_before(&sublist, head, &skb->list); + if (!list_empty(&sublist)) + __netif_receive_skb_list_core(&sublist, pfmemalloc); + pfmemalloc = !pfmemalloc; + /* See comments in __netif_receive_skb */ + if (pfmemalloc) + noreclaim_flag = memalloc_noreclaim_save(); + else + memalloc_noreclaim_restore(noreclaim_flag); + } + } + /* Handle the remaining sublist */ + if (!list_empty(head)) + __netif_receive_skb_list_core(head, pfmemalloc); + /* Restore pflags */ + if (pfmemalloc) + memalloc_noreclaim_restore(noreclaim_flag); +} + +static int generic_xdp_install(struct net_device *dev, struct netdev_bpf *xdp) +{ + struct bpf_prog *old = rtnl_dereference(dev->xdp_prog); + struct bpf_prog *new = xdp->prog; + int ret = 0; + + switch (xdp->command) { + case XDP_SETUP_PROG: + rcu_assign_pointer(dev->xdp_prog, new); + if (old) + bpf_prog_put(old); + + if (old && !new) { + static_branch_dec(&generic_xdp_needed_key); + } else if (new && !old) { + static_branch_inc(&generic_xdp_needed_key); + dev_disable_lro(dev); + dev_disable_gro_hw(dev); + } + break; + + default: + ret = -EINVAL; + break; + } + + return ret; +} + +static int netif_receive_skb_internal(struct sk_buff *skb) +{ + int ret; + + net_timestamp_check(READ_ONCE(netdev_tstamp_prequeue), skb); + + if (skb_defer_rx_timestamp(skb)) + return NET_RX_SUCCESS; + + rcu_read_lock(); +#ifdef CONFIG_RPS + if (static_branch_unlikely(&rps_needed)) { + struct rps_dev_flow voidflow, *rflow = &voidflow; + int cpu = get_rps_cpu(skb->dev, skb, &rflow); + + if (cpu >= 0) { + ret = enqueue_to_backlog(skb, cpu, &rflow->last_qtail); + rcu_read_unlock(); + return ret; + } + } +#endif + ret = __netif_receive_skb(skb); + rcu_read_unlock(); + return ret; +} + +void netif_receive_skb_list_internal(struct list_head *head) +{ + struct sk_buff *skb, *next; + struct list_head sublist; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + net_timestamp_check(READ_ONCE(netdev_tstamp_prequeue), skb); + skb_list_del_init(skb); + if (!skb_defer_rx_timestamp(skb)) + list_add_tail(&skb->list, &sublist); + } + list_splice_init(&sublist, head); + + rcu_read_lock(); +#ifdef CONFIG_RPS + if (static_branch_unlikely(&rps_needed)) { + list_for_each_entry_safe(skb, next, head, list) { + struct rps_dev_flow voidflow, *rflow = &voidflow; + int cpu = get_rps_cpu(skb->dev, skb, &rflow); + + if (cpu >= 0) { + /* Will be handled, remove from list */ + skb_list_del_init(skb); + enqueue_to_backlog(skb, cpu, &rflow->last_qtail); + } + } + } +#endif + __netif_receive_skb_list(head); + rcu_read_unlock(); +} + +/** + * netif_receive_skb - process receive buffer from network + * @skb: buffer to process + * + * netif_receive_skb() is the main receive data processing function. + * It always succeeds. The buffer may be dropped during processing + * for congestion control or by the protocol layers. + * + * This function may only be called from softirq context and interrupts + * should be enabled. 
+ * + * Return values (usually ignored): + * NET_RX_SUCCESS: no congestion + * NET_RX_DROP: packet was dropped + */ +int netif_receive_skb(struct sk_buff *skb) +{ + int ret; + + trace_netif_receive_skb_entry(skb); + + ret = netif_receive_skb_internal(skb); + trace_netif_receive_skb_exit(ret); + + return ret; +} +EXPORT_SYMBOL(netif_receive_skb); + +/** + * netif_receive_skb_list - process many receive buffers from network + * @head: list of skbs to process. + * + * Since return value of netif_receive_skb() is normally ignored, and + * wouldn't be meaningful for a list, this function returns void. + * + * This function may only be called from softirq context and interrupts + * should be enabled. + */ +void netif_receive_skb_list(struct list_head *head) +{ + struct sk_buff *skb; + + if (list_empty(head)) + return; + if (trace_netif_receive_skb_list_entry_enabled()) { + list_for_each_entry(skb, head, list) + trace_netif_receive_skb_list_entry(skb); + } + netif_receive_skb_list_internal(head); + trace_netif_receive_skb_list_exit(0); +} +EXPORT_SYMBOL(netif_receive_skb_list); + +static DEFINE_PER_CPU(struct work_struct, flush_works); + +/* Network device is going away, flush any packets still pending */ +static void flush_backlog(struct work_struct *work) +{ + struct sk_buff *skb, *tmp; + struct softnet_data *sd; + + local_bh_disable(); + sd = this_cpu_ptr(&softnet_data); + + rps_lock_irq_disable(sd); + skb_queue_walk_safe(&sd->input_pkt_queue, skb, tmp) { + if (skb->dev->reg_state == NETREG_UNREGISTERING) { + __skb_unlink(skb, &sd->input_pkt_queue); + dev_kfree_skb_irq(skb); + input_queue_head_incr(sd); + } + } + rps_unlock_irq_enable(sd); + + skb_queue_walk_safe(&sd->process_queue, skb, tmp) { + if (skb->dev->reg_state == NETREG_UNREGISTERING) { + __skb_unlink(skb, &sd->process_queue); + kfree_skb(skb); + input_queue_head_incr(sd); + } + } + local_bh_enable(); +} + +static bool flush_required(int cpu) +{ +#if IS_ENABLED(CONFIG_RPS) + struct softnet_data *sd = &per_cpu(softnet_data, cpu); + bool do_flush; + + rps_lock_irq_disable(sd); + + /* as insertion into process_queue happens with the rps lock held, + * process_queue access may race only with dequeue + */ + do_flush = !skb_queue_empty(&sd->input_pkt_queue) || + !skb_queue_empty_lockless(&sd->process_queue); + rps_unlock_irq_enable(sd); + + return do_flush; +#endif + /* without RPS we can't safely check input_pkt_queue: during a + * concurrent remote skb_queue_splice() we can detect as empty both + * input_pkt_queue and process_queue even if the latter could end-up + * containing a lot of packets. 
+ */ + return true; +} + +static void flush_all_backlogs(void) +{ + static cpumask_t flush_cpus; + unsigned int cpu; + + /* since we are under rtnl lock protection we can use static data + * for the cpumask and avoid allocating on stack the possibly + * large mask + */ + ASSERT_RTNL(); + + cpus_read_lock(); + + cpumask_clear(&flush_cpus); + for_each_online_cpu(cpu) { + if (flush_required(cpu)) { + queue_work_on(cpu, system_highpri_wq, + per_cpu_ptr(&flush_works, cpu)); + cpumask_set_cpu(cpu, &flush_cpus); + } + } + + /* we can have in flight packet[s] on the cpus we are not flushing, + * synchronize_net() in unregister_netdevice_many() will take care of + * them + */ + for_each_cpu(cpu, &flush_cpus) + flush_work(per_cpu_ptr(&flush_works, cpu)); + + cpus_read_unlock(); +} + +static void net_rps_send_ipi(struct softnet_data *remsd) +{ +#ifdef CONFIG_RPS + while (remsd) { + struct softnet_data *next = remsd->rps_ipi_next; + + if (cpu_online(remsd->cpu)) + smp_call_function_single_async(remsd->cpu, &remsd->csd); + remsd = next; + } +#endif +} + +/* + * net_rps_action_and_irq_enable sends any pending IPI's for rps. + * Note: called with local irq disabled, but exits with local irq enabled. + */ +static void net_rps_action_and_irq_enable(struct softnet_data *sd) +{ +#ifdef CONFIG_RPS + struct softnet_data *remsd = sd->rps_ipi_list; + + if (remsd) { + sd->rps_ipi_list = NULL; + + local_irq_enable(); + + /* Send pending IPI's to kick RPS processing on remote cpus. */ + net_rps_send_ipi(remsd); + } else +#endif + local_irq_enable(); +} + +static bool sd_has_rps_ipi_waiting(struct softnet_data *sd) +{ +#ifdef CONFIG_RPS + return sd->rps_ipi_list != NULL; +#else + return false; +#endif +} + +static int process_backlog(struct napi_struct *napi, int quota) +{ + struct softnet_data *sd = container_of(napi, struct softnet_data, backlog); + bool again = true; + int work = 0; + + /* Check if we have pending ipi, its better to send them now, + * not waiting net_rx_action() end. + */ + if (sd_has_rps_ipi_waiting(sd)) { + local_irq_disable(); + net_rps_action_and_irq_enable(sd); + } + + napi->weight = READ_ONCE(dev_rx_weight); + while (again) { + struct sk_buff *skb; + + while ((skb = __skb_dequeue(&sd->process_queue))) { + rcu_read_lock(); + __netif_receive_skb(skb); + rcu_read_unlock(); + input_queue_head_incr(sd); + if (++work >= quota) + return work; + + } + + rps_lock_irq_disable(sd); + if (skb_queue_empty(&sd->input_pkt_queue)) { + /* + * Inline a custom version of __napi_complete(). + * only current cpu owns and manipulates this napi, + * and NAPI_STATE_SCHED is the only possible flag set + * on backlog. + * We can use a plain write instead of clear_bit(), + * and we dont need an smp_mb() memory barrier. + */ + napi->state = 0; + again = false; + } else { + skb_queue_splice_tail_init(&sd->input_pkt_queue, + &sd->process_queue); + } + rps_unlock_irq_enable(sd); + } + + return work; +} + +/** + * __napi_schedule - schedule for receive + * @n: entry to schedule + * + * The entry's receive function will be scheduled to run. + * Consider using __napi_schedule_irqoff() if hard irqs are masked. + */ +void __napi_schedule(struct napi_struct *n) +{ + unsigned long flags; + + local_irq_save(flags); + ____napi_schedule(this_cpu_ptr(&softnet_data), n); + local_irq_restore(flags); +} +EXPORT_SYMBOL(__napi_schedule); + +/** + * napi_schedule_prep - check if napi can be scheduled + * @n: napi context + * + * Test if NAPI routine is already running, and if not mark + * it as running. 
This is used as a condition variable to + * insure only one NAPI poll instance runs. We also make + * sure there is no pending NAPI disable. + */ +bool napi_schedule_prep(struct napi_struct *n) +{ + unsigned long val, new; + + do { + val = READ_ONCE(n->state); + if (unlikely(val & NAPIF_STATE_DISABLE)) + return false; + new = val | NAPIF_STATE_SCHED; + + /* Sets STATE_MISSED bit if STATE_SCHED was already set + * This was suggested by Alexander Duyck, as compiler + * emits better code than : + * if (val & NAPIF_STATE_SCHED) + * new |= NAPIF_STATE_MISSED; + */ + new |= (val & NAPIF_STATE_SCHED) / NAPIF_STATE_SCHED * + NAPIF_STATE_MISSED; + } while (cmpxchg(&n->state, val, new) != val); + + return !(val & NAPIF_STATE_SCHED); +} +EXPORT_SYMBOL(napi_schedule_prep); + +/** + * __napi_schedule_irqoff - schedule for receive + * @n: entry to schedule + * + * Variant of __napi_schedule() assuming hard irqs are masked. + * + * On PREEMPT_RT enabled kernels this maps to __napi_schedule() + * because the interrupt disabled assumption might not be true + * due to force-threaded interrupts and spinlock substitution. + */ +void __napi_schedule_irqoff(struct napi_struct *n) +{ + if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + ____napi_schedule(this_cpu_ptr(&softnet_data), n); + else + __napi_schedule(n); +} +EXPORT_SYMBOL(__napi_schedule_irqoff); + +bool napi_complete_done(struct napi_struct *n, int work_done) +{ + unsigned long flags, val, new, timeout = 0; + bool ret = true; + + /* + * 1) Don't let napi dequeue from the cpu poll list + * just in case its running on a different cpu. + * 2) If we are busy polling, do nothing here, we have + * the guarantee we will be called later. + */ + if (unlikely(n->state & (NAPIF_STATE_NPSVC | + NAPIF_STATE_IN_BUSY_POLL))) + return false; + + if (work_done) { + if (n->gro_bitmask) + timeout = READ_ONCE(n->dev->gro_flush_timeout); + n->defer_hard_irqs_count = READ_ONCE(n->dev->napi_defer_hard_irqs); + } + if (n->defer_hard_irqs_count > 0) { + n->defer_hard_irqs_count--; + timeout = READ_ONCE(n->dev->gro_flush_timeout); + if (timeout) + ret = false; + } + if (n->gro_bitmask) { + /* When the NAPI instance uses a timeout and keeps postponing + * it, we need to bound somehow the time packets are kept in + * the GRO layer + */ + napi_gro_flush(n, !!timeout); + } + + gro_normal_list(n); + + if (unlikely(!list_empty(&n->poll_list))) { + /* If n->poll_list is not empty, we need to mask irqs */ + local_irq_save(flags); + list_del_init(&n->poll_list); + local_irq_restore(flags); + } + + do { + val = READ_ONCE(n->state); + + WARN_ON_ONCE(!(val & NAPIF_STATE_SCHED)); + + new = val & ~(NAPIF_STATE_MISSED | NAPIF_STATE_SCHED | + NAPIF_STATE_SCHED_THREADED | + NAPIF_STATE_PREFER_BUSY_POLL); + + /* If STATE_MISSED was set, leave STATE_SCHED set, + * because we will call napi->poll() one more time. + * This C code was suggested by Alexander Duyck to help gcc. 
+ */ + new |= (val & NAPIF_STATE_MISSED) / NAPIF_STATE_MISSED * + NAPIF_STATE_SCHED; + } while (cmpxchg(&n->state, val, new) != val); + + if (unlikely(val & NAPIF_STATE_MISSED)) { + __napi_schedule(n); + return false; + } + + if (timeout) + hrtimer_start(&n->timer, ns_to_ktime(timeout), + HRTIMER_MODE_REL_PINNED); + return ret; +} +EXPORT_SYMBOL(napi_complete_done); + +/* must be called under rcu_read_lock(), as we dont take a reference */ +static struct napi_struct *napi_by_id(unsigned int napi_id) +{ + unsigned int hash = napi_id % HASH_SIZE(napi_hash); + struct napi_struct *napi; + + hlist_for_each_entry_rcu(napi, &napi_hash[hash], napi_hash_node) + if (napi->napi_id == napi_id) + return napi; + + return NULL; +} + +#if defined(CONFIG_NET_RX_BUSY_POLL) + +static void __busy_poll_stop(struct napi_struct *napi, bool skip_schedule) +{ + if (!skip_schedule) { + gro_normal_list(napi); + __napi_schedule(napi); + return; + } + + if (napi->gro_bitmask) { + /* flush too old packets + * If HZ < 1000, flush all packets. + */ + napi_gro_flush(napi, HZ >= 1000); + } + + gro_normal_list(napi); + clear_bit(NAPI_STATE_SCHED, &napi->state); +} + +static void busy_poll_stop(struct napi_struct *napi, void *have_poll_lock, bool prefer_busy_poll, + u16 budget) +{ + bool skip_schedule = false; + unsigned long timeout; + int rc; + + /* Busy polling means there is a high chance device driver hard irq + * could not grab NAPI_STATE_SCHED, and that NAPI_STATE_MISSED was + * set in napi_schedule_prep(). + * Since we are about to call napi->poll() once more, we can safely + * clear NAPI_STATE_MISSED. + * + * Note: x86 could use a single "lock and ..." instruction + * to perform these two clear_bit() + */ + clear_bit(NAPI_STATE_MISSED, &napi->state); + clear_bit(NAPI_STATE_IN_BUSY_POLL, &napi->state); + + local_bh_disable(); + + if (prefer_busy_poll) { + napi->defer_hard_irqs_count = READ_ONCE(napi->dev->napi_defer_hard_irqs); + timeout = READ_ONCE(napi->dev->gro_flush_timeout); + if (napi->defer_hard_irqs_count && timeout) { + hrtimer_start(&napi->timer, ns_to_ktime(timeout), HRTIMER_MODE_REL_PINNED); + skip_schedule = true; + } + } + + /* All we really want here is to re-enable device interrupts. + * Ideally, a new ndo_busy_poll_stop() could avoid another round. + */ + rc = napi->poll(napi, budget); + /* We can't gro_normal_list() here, because napi->poll() might have + * rearmed the napi (napi_complete_done()) in which case it could + * already be running on another CPU. + */ + trace_napi_poll(napi, rc, budget); + netpoll_poll_unlock(have_poll_lock); + if (rc == budget) + __busy_poll_stop(napi, skip_schedule); + local_bh_enable(); +} + +void napi_busy_loop(unsigned int napi_id, + bool (*loop_end)(void *, unsigned long), + void *loop_end_arg, bool prefer_busy_poll, u16 budget) +{ + unsigned long start_time = loop_end ? busy_loop_current_time() : 0; + int (*napi_poll)(struct napi_struct *napi, int budget); + void *have_poll_lock = NULL; + struct napi_struct *napi; + +restart: + napi_poll = NULL; + + rcu_read_lock(); + + napi = napi_by_id(napi_id); + if (!napi) + goto out; + + preempt_disable(); + for (;;) { + int work = 0; + + local_bh_disable(); + if (!napi_poll) { + unsigned long val = READ_ONCE(napi->state); + + /* If multiple threads are competing for this napi, + * we avoid dirtying napi->state as much as we can. 
+ */ + if (val & (NAPIF_STATE_DISABLE | NAPIF_STATE_SCHED | + NAPIF_STATE_IN_BUSY_POLL)) { + if (prefer_busy_poll) + set_bit(NAPI_STATE_PREFER_BUSY_POLL, &napi->state); + goto count; + } + if (cmpxchg(&napi->state, val, + val | NAPIF_STATE_IN_BUSY_POLL | + NAPIF_STATE_SCHED) != val) { + if (prefer_busy_poll) + set_bit(NAPI_STATE_PREFER_BUSY_POLL, &napi->state); + goto count; + } + have_poll_lock = netpoll_poll_lock(napi); + napi_poll = napi->poll; + } + work = napi_poll(napi, budget); + trace_napi_poll(napi, work, budget); + gro_normal_list(napi); +count: + if (work > 0) + __NET_ADD_STATS(dev_net(napi->dev), + LINUX_MIB_BUSYPOLLRXPACKETS, work); + local_bh_enable(); + + if (!loop_end || loop_end(loop_end_arg, start_time)) + break; + + if (unlikely(need_resched())) { + if (napi_poll) + busy_poll_stop(napi, have_poll_lock, prefer_busy_poll, budget); + preempt_enable(); + rcu_read_unlock(); + cond_resched(); + if (loop_end(loop_end_arg, start_time)) + return; + goto restart; + } + cpu_relax(); + } + if (napi_poll) + busy_poll_stop(napi, have_poll_lock, prefer_busy_poll, budget); + preempt_enable(); +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(napi_busy_loop); + +#endif /* CONFIG_NET_RX_BUSY_POLL */ + +static void napi_hash_add(struct napi_struct *napi) +{ + if (test_bit(NAPI_STATE_NO_BUSY_POLL, &napi->state)) + return; + + spin_lock(&napi_hash_lock); + + /* 0..NR_CPUS range is reserved for sender_cpu use */ + do { + if (unlikely(++napi_gen_id < MIN_NAPI_ID)) + napi_gen_id = MIN_NAPI_ID; + } while (napi_by_id(napi_gen_id)); + napi->napi_id = napi_gen_id; + + hlist_add_head_rcu(&napi->napi_hash_node, + &napi_hash[napi->napi_id % HASH_SIZE(napi_hash)]); + + spin_unlock(&napi_hash_lock); +} + +/* Warning : caller is responsible to make sure rcu grace period + * is respected before freeing memory containing @napi + */ +static void napi_hash_del(struct napi_struct *napi) +{ + spin_lock(&napi_hash_lock); + + hlist_del_init_rcu(&napi->napi_hash_node); + + spin_unlock(&napi_hash_lock); +} + +static enum hrtimer_restart napi_watchdog(struct hrtimer *timer) +{ + struct napi_struct *napi; + + napi = container_of(timer, struct napi_struct, timer); + + /* Note : we use a relaxed variant of napi_schedule_prep() not setting + * NAPI_STATE_MISSED, since we do not react to a device IRQ. + */ + if (!napi_disable_pending(napi) && + !test_and_set_bit(NAPI_STATE_SCHED, &napi->state)) { + clear_bit(NAPI_STATE_PREFER_BUSY_POLL, &napi->state); + __napi_schedule_irqoff(napi); + } + + return HRTIMER_NORESTART; +} + +static void init_gro_hash(struct napi_struct *napi) +{ + int i; + + for (i = 0; i < GRO_HASH_BUCKETS; i++) { + INIT_LIST_HEAD(&napi->gro_hash[i].list); + napi->gro_hash[i].count = 0; + } + napi->gro_bitmask = 0; +} + +int dev_set_threaded(struct net_device *dev, bool threaded) +{ + struct napi_struct *napi; + int err = 0; + + if (dev->threaded == threaded) + return 0; + + if (threaded) { + list_for_each_entry(napi, &dev->napi_list, dev_list) { + if (!napi->thread) { + err = napi_kthread_create(napi); + if (err) { + threaded = false; + break; + } + } + } + } + + dev->threaded = threaded; + + /* Make sure kthread is created before THREADED bit + * is set. + */ + smp_mb__before_atomic(); + + /* Setting/unsetting threaded mode on a napi might not immediately + * take effect, if the current napi instance is actively being + * polled. In this case, the switch between threaded mode and + * softirq mode will happen in the next round of napi_schedule(). + * This should not cause hiccups/stalls to the live traffic. 
+ */ + list_for_each_entry(napi, &dev->napi_list, dev_list) { + if (threaded) + set_bit(NAPI_STATE_THREADED, &napi->state); + else + clear_bit(NAPI_STATE_THREADED, &napi->state); + } + + return err; +} +EXPORT_SYMBOL(dev_set_threaded); + +void netif_napi_add_weight(struct net_device *dev, struct napi_struct *napi, + int (*poll)(struct napi_struct *, int), int weight) +{ + if (WARN_ON(test_and_set_bit(NAPI_STATE_LISTED, &napi->state))) + return; + + INIT_LIST_HEAD(&napi->poll_list); + INIT_HLIST_NODE(&napi->napi_hash_node); + hrtimer_init(&napi->timer, CLOCK_MONOTONIC, HRTIMER_MODE_REL_PINNED); + napi->timer.function = napi_watchdog; + init_gro_hash(napi); + napi->skb = NULL; + INIT_LIST_HEAD(&napi->rx_list); + napi->rx_count = 0; + napi->poll = poll; + if (weight > NAPI_POLL_WEIGHT) + netdev_err_once(dev, "%s() called with weight %d\n", __func__, + weight); + napi->weight = weight; + napi->dev = dev; +#ifdef CONFIG_NETPOLL + napi->poll_owner = -1; +#endif + set_bit(NAPI_STATE_SCHED, &napi->state); + set_bit(NAPI_STATE_NPSVC, &napi->state); + list_add_rcu(&napi->dev_list, &dev->napi_list); + napi_hash_add(napi); + napi_get_frags_check(napi); + /* Create kthread for this napi if dev->threaded is set. + * Clear dev->threaded if kthread creation failed so that + * threaded mode will not be enabled in napi_enable(). + */ + if (dev->threaded && napi_kthread_create(napi)) + dev->threaded = 0; +} +EXPORT_SYMBOL(netif_napi_add_weight); + +void napi_disable(struct napi_struct *n) +{ + unsigned long val, new; + + might_sleep(); + set_bit(NAPI_STATE_DISABLE, &n->state); + + for ( ; ; ) { + val = READ_ONCE(n->state); + if (val & (NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC)) { + usleep_range(20, 200); + continue; + } + + new = val | NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC; + new &= ~(NAPIF_STATE_THREADED | NAPIF_STATE_PREFER_BUSY_POLL); + + if (cmpxchg(&n->state, val, new) == val) + break; + } + + hrtimer_cancel(&n->timer); + + clear_bit(NAPI_STATE_DISABLE, &n->state); +} +EXPORT_SYMBOL(napi_disable); + +/** + * napi_enable - enable NAPI scheduling + * @n: NAPI context + * + * Resume NAPI from being scheduled on this context. + * Must be paired with napi_disable. + */ +void napi_enable(struct napi_struct *n) +{ + unsigned long val, new; + + do { + val = READ_ONCE(n->state); + BUG_ON(!test_bit(NAPI_STATE_SCHED, &val)); + + new = val & ~(NAPIF_STATE_SCHED | NAPIF_STATE_NPSVC); + if (n->dev->threaded && n->thread) + new |= NAPIF_STATE_THREADED; + } while (cmpxchg(&n->state, val, new) != val); +} +EXPORT_SYMBOL(napi_enable); + +static void flush_gro_hash(struct napi_struct *napi) +{ + int i; + + for (i = 0; i < GRO_HASH_BUCKETS; i++) { + struct sk_buff *skb, *n; + + list_for_each_entry_safe(skb, n, &napi->gro_hash[i].list, list) + kfree_skb(skb); + napi->gro_hash[i].count = 0; + } +} + +/* Must be called in process context */ +void __netif_napi_del(struct napi_struct *napi) +{ + if (!test_and_clear_bit(NAPI_STATE_LISTED, &napi->state)) + return; + + napi_hash_del(napi); + list_del_rcu(&napi->dev_list); + napi_free_frags(napi); + + flush_gro_hash(napi); + napi->gro_bitmask = 0; + + if (napi->thread) { + kthread_stop(napi->thread); + napi->thread = NULL; + } +} +EXPORT_SYMBOL(__netif_napi_del); + +static int __napi_poll(struct napi_struct *n, bool *repoll) +{ + int work, weight; + + weight = n->weight; + + /* This NAPI_STATE_SCHED test is for avoiding a race + * with netpoll's poll_napi(). Only the entity which + * obtains the lock and sees NAPI_STATE_SCHED set will + * actually make the ->poll() call. 
Therefore we avoid + * accidentally calling ->poll() when NAPI is not scheduled. + */ + work = 0; + if (test_bit(NAPI_STATE_SCHED, &n->state)) { + work = n->poll(n, weight); + trace_napi_poll(n, work, weight); + } + + if (unlikely(work > weight)) + netdev_err_once(n->dev, "NAPI poll function %pS returned %d, exceeding its budget of %d.\n", + n->poll, work, weight); + + if (likely(work < weight)) + return work; + + /* Drivers must not modify the NAPI state if they + * consume the entire weight. In such cases this code + * still "owns" the NAPI instance and therefore can + * move the instance around on the list at-will. + */ + if (unlikely(napi_disable_pending(n))) { + napi_complete(n); + return work; + } + + /* The NAPI context has more processing work, but busy-polling + * is preferred. Exit early. + */ + if (napi_prefer_busy_poll(n)) { + if (napi_complete_done(n, work)) { + /* If timeout is not set, we need to make sure + * that the NAPI is re-scheduled. + */ + napi_schedule(n); + } + return work; + } + + if (n->gro_bitmask) { + /* flush too old packets + * If HZ < 1000, flush all packets. + */ + napi_gro_flush(n, HZ >= 1000); + } + + gro_normal_list(n); + + /* Some drivers may have called napi_schedule + * prior to exhausting their budget. + */ + if (unlikely(!list_empty(&n->poll_list))) { + pr_warn_once("%s: Budget exhausted after napi rescheduled\n", + n->dev ? n->dev->name : "backlog"); + return work; + } + + *repoll = true; + + return work; +} + +static int napi_poll(struct napi_struct *n, struct list_head *repoll) +{ + bool do_repoll = false; + void *have; + int work; + + list_del_init(&n->poll_list); + + have = netpoll_poll_lock(n); + + work = __napi_poll(n, &do_repoll); + + if (do_repoll) + list_add_tail(&n->poll_list, repoll); + + netpoll_poll_unlock(have); + + return work; +} + +static int napi_thread_wait(struct napi_struct *napi) +{ + bool woken = false; + + set_current_state(TASK_INTERRUPTIBLE); + + while (!kthread_should_stop()) { + /* Testing SCHED_THREADED bit here to make sure the current + * kthread owns this napi and could poll on this napi. + * Testing SCHED bit is not enough because SCHED bit might be + * set by some other busy poll thread or by napi_disable(). + */ + if (test_bit(NAPI_STATE_SCHED_THREADED, &napi->state) || woken) { + WARN_ON(!list_empty(&napi->poll_list)); + __set_current_state(TASK_RUNNING); + return 0; + } + + schedule(); + /* woken being true indicates this thread owns this napi. 
*/ + woken = true; + set_current_state(TASK_INTERRUPTIBLE); + } + __set_current_state(TASK_RUNNING); + + return -1; +} + +static int napi_threaded_poll(void *data) +{ + struct napi_struct *napi = data; + void *have; + + while (!napi_thread_wait(napi)) { + for (;;) { + bool repoll = false; + + local_bh_disable(); + + have = netpoll_poll_lock(napi); + __napi_poll(napi, &repoll); + netpoll_poll_unlock(have); + + local_bh_enable(); + + if (!repoll) + break; + + cond_resched(); + } + } + return 0; +} + +static void skb_defer_free_flush(struct softnet_data *sd) +{ + struct sk_buff *skb, *next; + unsigned long flags; + + /* Paired with WRITE_ONCE() in skb_attempt_defer_free() */ + if (!READ_ONCE(sd->defer_list)) + return; + + spin_lock_irqsave(&sd->defer_lock, flags); + skb = sd->defer_list; + sd->defer_list = NULL; + sd->defer_count = 0; + spin_unlock_irqrestore(&sd->defer_lock, flags); + + while (skb != NULL) { + next = skb->next; + napi_consume_skb(skb, 1); + skb = next; + } +} + +static __latent_entropy void net_rx_action(struct softirq_action *h) +{ + struct softnet_data *sd = this_cpu_ptr(&softnet_data); + unsigned long time_limit = jiffies + + usecs_to_jiffies(READ_ONCE(netdev_budget_usecs)); + int budget = READ_ONCE(netdev_budget); + LIST_HEAD(list); + LIST_HEAD(repoll); + + local_irq_disable(); + list_splice_init(&sd->poll_list, &list); + local_irq_enable(); + + for (;;) { + struct napi_struct *n; + + skb_defer_free_flush(sd); + + if (list_empty(&list)) { + if (!sd_has_rps_ipi_waiting(sd) && list_empty(&repoll)) + goto end; + break; + } + + n = list_first_entry(&list, struct napi_struct, poll_list); + budget -= napi_poll(n, &repoll); + + /* If softirq window is exhausted then punt. + * Allow this to run for 2 jiffies since which will allow + * an average latency of 1.5/HZ. + */ + if (unlikely(budget <= 0 || + time_after_eq(jiffies, time_limit))) { + sd->time_squeeze++; + break; + } + } + + local_irq_disable(); + + list_splice_tail_init(&sd->poll_list, &list); + list_splice_tail(&repoll, &list); + list_splice(&list, &sd->poll_list); + if (!list_empty(&sd->poll_list)) + __raise_softirq_irqoff(NET_RX_SOFTIRQ); + + net_rps_action_and_irq_enable(sd); +end:; +} + +struct netdev_adjacent { + struct net_device *dev; + netdevice_tracker dev_tracker; + + /* upper master flag, there can only be one master device per list */ + bool master; + + /* lookup ignore flag */ + bool ignore; + + /* counter for the number of times this device was added to us */ + u16 ref_nr; + + /* private field for the users */ + void *private; + + struct list_head list; + struct rcu_head rcu; +}; + +static struct netdev_adjacent *__netdev_find_adj(struct net_device *adj_dev, + struct list_head *adj_list) +{ + struct netdev_adjacent *adj; + + list_for_each_entry(adj, adj_list, list) { + if (adj->dev == adj_dev) + return adj; + } + return NULL; +} + +static int ____netdev_has_upper_dev(struct net_device *upper_dev, + struct netdev_nested_priv *priv) +{ + struct net_device *dev = (struct net_device *)priv->data; + + return upper_dev == dev; +} + +/** + * netdev_has_upper_dev - Check if device is linked to an upper device + * @dev: device + * @upper_dev: upper device to check + * + * Find out if a device is linked to specified upper device and return true + * in case it is. Note that this checks only immediate upper device, + * not through a complete stack of devices. The caller must hold the RTNL lock. 
+ */ +bool netdev_has_upper_dev(struct net_device *dev, + struct net_device *upper_dev) +{ + struct netdev_nested_priv priv = { + .data = (void *)upper_dev, + }; + + ASSERT_RTNL(); + + return netdev_walk_all_upper_dev_rcu(dev, ____netdev_has_upper_dev, + &priv); +} +EXPORT_SYMBOL(netdev_has_upper_dev); + +/** + * netdev_has_upper_dev_all_rcu - Check if device is linked to an upper device + * @dev: device + * @upper_dev: upper device to check + * + * Find out if a device is linked to specified upper device and return true + * in case it is. Note that this checks the entire upper device chain. + * The caller must hold rcu lock. + */ + +bool netdev_has_upper_dev_all_rcu(struct net_device *dev, + struct net_device *upper_dev) +{ + struct netdev_nested_priv priv = { + .data = (void *)upper_dev, + }; + + return !!netdev_walk_all_upper_dev_rcu(dev, ____netdev_has_upper_dev, + &priv); +} +EXPORT_SYMBOL(netdev_has_upper_dev_all_rcu); + +/** + * netdev_has_any_upper_dev - Check if device is linked to some device + * @dev: device + * + * Find out if a device is linked to an upper device and return true in case + * it is. The caller must hold the RTNL lock. + */ +bool netdev_has_any_upper_dev(struct net_device *dev) +{ + ASSERT_RTNL(); + + return !list_empty(&dev->adj_list.upper); +} +EXPORT_SYMBOL(netdev_has_any_upper_dev); + +/** + * netdev_master_upper_dev_get - Get master upper device + * @dev: device + * + * Find a master upper device and return pointer to it or NULL in case + * it's not there. The caller must hold the RTNL lock. + */ +struct net_device *netdev_master_upper_dev_get(struct net_device *dev) +{ + struct netdev_adjacent *upper; + + ASSERT_RTNL(); + + if (list_empty(&dev->adj_list.upper)) + return NULL; + + upper = list_first_entry(&dev->adj_list.upper, + struct netdev_adjacent, list); + if (likely(upper->master)) + return upper->dev; + return NULL; +} +EXPORT_SYMBOL(netdev_master_upper_dev_get); + +static struct net_device *__netdev_master_upper_dev_get(struct net_device *dev) +{ + struct netdev_adjacent *upper; + + ASSERT_RTNL(); + + if (list_empty(&dev->adj_list.upper)) + return NULL; + + upper = list_first_entry(&dev->adj_list.upper, + struct netdev_adjacent, list); + if (likely(upper->master) && !upper->ignore) + return upper->dev; + return NULL; +} + +/** + * netdev_has_any_lower_dev - Check if device is linked to some device + * @dev: device + * + * Find out if a device is linked to a lower device and return true in case + * it is. The caller must hold the RTNL lock. + */ +static bool netdev_has_any_lower_dev(struct net_device *dev) +{ + ASSERT_RTNL(); + + return !list_empty(&dev->adj_list.lower); +} + +void *netdev_adjacent_get_private(struct list_head *adj_list) +{ + struct netdev_adjacent *adj; + + adj = list_entry(adj_list, struct netdev_adjacent, list); + + return adj->private; +} +EXPORT_SYMBOL(netdev_adjacent_get_private); + +/** + * netdev_upper_get_next_dev_rcu - Get the next dev from upper list + * @dev: device + * @iter: list_head ** of the current position + * + * Gets the next device from the dev's upper list, starting from iter + * position. The caller must hold RCU read lock. 
+ */ +struct net_device *netdev_upper_get_next_dev_rcu(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *upper; + + WARN_ON_ONCE(!rcu_read_lock_held() && !lockdep_rtnl_is_held()); + + upper = list_entry_rcu((*iter)->next, struct netdev_adjacent, list); + + if (&upper->list == &dev->adj_list.upper) + return NULL; + + *iter = &upper->list; + + return upper->dev; +} +EXPORT_SYMBOL(netdev_upper_get_next_dev_rcu); + +static struct net_device *__netdev_next_upper_dev(struct net_device *dev, + struct list_head **iter, + bool *ignore) +{ + struct netdev_adjacent *upper; + + upper = list_entry((*iter)->next, struct netdev_adjacent, list); + + if (&upper->list == &dev->adj_list.upper) + return NULL; + + *iter = &upper->list; + *ignore = upper->ignore; + + return upper->dev; +} + +static struct net_device *netdev_next_upper_dev_rcu(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *upper; + + WARN_ON_ONCE(!rcu_read_lock_held() && !lockdep_rtnl_is_held()); + + upper = list_entry_rcu((*iter)->next, struct netdev_adjacent, list); + + if (&upper->list == &dev->adj_list.upper) + return NULL; + + *iter = &upper->list; + + return upper->dev; +} + +static int __netdev_walk_all_upper_dev(struct net_device *dev, + int (*fn)(struct net_device *dev, + struct netdev_nested_priv *priv), + struct netdev_nested_priv *priv) +{ + struct net_device *udev, *next, *now, *dev_stack[MAX_NEST_DEV + 1]; + struct list_head *niter, *iter, *iter_stack[MAX_NEST_DEV + 1]; + int ret, cur = 0; + bool ignore; + + now = dev; + iter = &dev->adj_list.upper; + + while (1) { + if (now != dev) { + ret = fn(now, priv); + if (ret) + return ret; + } + + next = NULL; + while (1) { + udev = __netdev_next_upper_dev(now, &iter, &ignore); + if (!udev) + break; + if (ignore) + continue; + + next = udev; + niter = &udev->adj_list.upper; + dev_stack[cur] = now; + iter_stack[cur++] = iter; + break; + } + + if (!next) { + if (!cur) + return 0; + next = dev_stack[--cur]; + niter = iter_stack[cur]; + } + + now = next; + iter = niter; + } + + return 0; +} + +int netdev_walk_all_upper_dev_rcu(struct net_device *dev, + int (*fn)(struct net_device *dev, + struct netdev_nested_priv *priv), + struct netdev_nested_priv *priv) +{ + struct net_device *udev, *next, *now, *dev_stack[MAX_NEST_DEV + 1]; + struct list_head *niter, *iter, *iter_stack[MAX_NEST_DEV + 1]; + int ret, cur = 0; + + now = dev; + iter = &dev->adj_list.upper; + + while (1) { + if (now != dev) { + ret = fn(now, priv); + if (ret) + return ret; + } + + next = NULL; + while (1) { + udev = netdev_next_upper_dev_rcu(now, &iter); + if (!udev) + break; + + next = udev; + niter = &udev->adj_list.upper; + dev_stack[cur] = now; + iter_stack[cur++] = iter; + break; + } + + if (!next) { + if (!cur) + return 0; + next = dev_stack[--cur]; + niter = iter_stack[cur]; + } + + now = next; + iter = niter; + } + + return 0; +} +EXPORT_SYMBOL_GPL(netdev_walk_all_upper_dev_rcu); + +static bool __netdev_has_upper_dev(struct net_device *dev, + struct net_device *upper_dev) +{ + struct netdev_nested_priv priv = { + .flags = 0, + .data = (void *)upper_dev, + }; + + ASSERT_RTNL(); + + return __netdev_walk_all_upper_dev(dev, ____netdev_has_upper_dev, + &priv); +} + +/** + * netdev_lower_get_next_private - Get the next ->private from the + * lower neighbour list + * @dev: device + * @iter: list_head ** of the current position + * + * Gets the next netdev_adjacent->private from the dev's lower neighbour + * list, starting from iter position. 
The caller must hold either hold the + * RTNL lock or its own locking that guarantees that the neighbour lower + * list will remain unchanged. + */ +void *netdev_lower_get_next_private(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *lower; + + lower = list_entry(*iter, struct netdev_adjacent, list); + + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = lower->list.next; + + return lower->private; +} +EXPORT_SYMBOL(netdev_lower_get_next_private); + +/** + * netdev_lower_get_next_private_rcu - Get the next ->private from the + * lower neighbour list, RCU + * variant + * @dev: device + * @iter: list_head ** of the current position + * + * Gets the next netdev_adjacent->private from the dev's lower neighbour + * list, starting from iter position. The caller must hold RCU read lock. + */ +void *netdev_lower_get_next_private_rcu(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *lower; + + WARN_ON_ONCE(!rcu_read_lock_held() && !rcu_read_lock_bh_held()); + + lower = list_entry_rcu((*iter)->next, struct netdev_adjacent, list); + + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = &lower->list; + + return lower->private; +} +EXPORT_SYMBOL(netdev_lower_get_next_private_rcu); + +/** + * netdev_lower_get_next - Get the next device from the lower neighbour + * list + * @dev: device + * @iter: list_head ** of the current position + * + * Gets the next netdev_adjacent from the dev's lower neighbour + * list, starting from iter position. The caller must hold RTNL lock or + * its own locking that guarantees that the neighbour lower + * list will remain unchanged. + */ +void *netdev_lower_get_next(struct net_device *dev, struct list_head **iter) +{ + struct netdev_adjacent *lower; + + lower = list_entry(*iter, struct netdev_adjacent, list); + + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = lower->list.next; + + return lower->dev; +} +EXPORT_SYMBOL(netdev_lower_get_next); + +static struct net_device *netdev_next_lower_dev(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *lower; + + lower = list_entry((*iter)->next, struct netdev_adjacent, list); + + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = &lower->list; + + return lower->dev; +} + +static struct net_device *__netdev_next_lower_dev(struct net_device *dev, + struct list_head **iter, + bool *ignore) +{ + struct netdev_adjacent *lower; + + lower = list_entry((*iter)->next, struct netdev_adjacent, list); + + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = &lower->list; + *ignore = lower->ignore; + + return lower->dev; +} + +int netdev_walk_all_lower_dev(struct net_device *dev, + int (*fn)(struct net_device *dev, + struct netdev_nested_priv *priv), + struct netdev_nested_priv *priv) +{ + struct net_device *ldev, *next, *now, *dev_stack[MAX_NEST_DEV + 1]; + struct list_head *niter, *iter, *iter_stack[MAX_NEST_DEV + 1]; + int ret, cur = 0; + + now = dev; + iter = &dev->adj_list.lower; + + while (1) { + if (now != dev) { + ret = fn(now, priv); + if (ret) + return ret; + } + + next = NULL; + while (1) { + ldev = netdev_next_lower_dev(now, &iter); + if (!ldev) + break; + + next = ldev; + niter = &ldev->adj_list.lower; + dev_stack[cur] = now; + iter_stack[cur++] = iter; + break; + } + + if (!next) { + if (!cur) + return 0; + next = dev_stack[--cur]; + niter = iter_stack[cur]; + } + + now = next; + iter = niter; + } + + return 0; +} 
+EXPORT_SYMBOL_GPL(netdev_walk_all_lower_dev); + +static int __netdev_walk_all_lower_dev(struct net_device *dev, + int (*fn)(struct net_device *dev, + struct netdev_nested_priv *priv), + struct netdev_nested_priv *priv) +{ + struct net_device *ldev, *next, *now, *dev_stack[MAX_NEST_DEV + 1]; + struct list_head *niter, *iter, *iter_stack[MAX_NEST_DEV + 1]; + int ret, cur = 0; + bool ignore; + + now = dev; + iter = &dev->adj_list.lower; + + while (1) { + if (now != dev) { + ret = fn(now, priv); + if (ret) + return ret; + } + + next = NULL; + while (1) { + ldev = __netdev_next_lower_dev(now, &iter, &ignore); + if (!ldev) + break; + if (ignore) + continue; + + next = ldev; + niter = &ldev->adj_list.lower; + dev_stack[cur] = now; + iter_stack[cur++] = iter; + break; + } + + if (!next) { + if (!cur) + return 0; + next = dev_stack[--cur]; + niter = iter_stack[cur]; + } + + now = next; + iter = niter; + } + + return 0; +} + +struct net_device *netdev_next_lower_dev_rcu(struct net_device *dev, + struct list_head **iter) +{ + struct netdev_adjacent *lower; + + lower = list_entry_rcu((*iter)->next, struct netdev_adjacent, list); + if (&lower->list == &dev->adj_list.lower) + return NULL; + + *iter = &lower->list; + + return lower->dev; +} +EXPORT_SYMBOL(netdev_next_lower_dev_rcu); + +static u8 __netdev_upper_depth(struct net_device *dev) +{ + struct net_device *udev; + struct list_head *iter; + u8 max_depth = 0; + bool ignore; + + for (iter = &dev->adj_list.upper, + udev = __netdev_next_upper_dev(dev, &iter, &ignore); + udev; + udev = __netdev_next_upper_dev(dev, &iter, &ignore)) { + if (ignore) + continue; + if (max_depth < udev->upper_level) + max_depth = udev->upper_level; + } + + return max_depth; +} + +static u8 __netdev_lower_depth(struct net_device *dev) +{ + struct net_device *ldev; + struct list_head *iter; + u8 max_depth = 0; + bool ignore; + + for (iter = &dev->adj_list.lower, + ldev = __netdev_next_lower_dev(dev, &iter, &ignore); + ldev; + ldev = __netdev_next_lower_dev(dev, &iter, &ignore)) { + if (ignore) + continue; + if (max_depth < ldev->lower_level) + max_depth = ldev->lower_level; + } + + return max_depth; +} + +static int __netdev_update_upper_level(struct net_device *dev, + struct netdev_nested_priv *__unused) +{ + dev->upper_level = __netdev_upper_depth(dev) + 1; + return 0; +} + +#ifdef CONFIG_LOCKDEP +static LIST_HEAD(net_unlink_list); + +static void net_unlink_todo(struct net_device *dev) +{ + if (list_empty(&dev->unlink_list)) + list_add_tail(&dev->unlink_list, &net_unlink_list); +} +#endif + +static int __netdev_update_lower_level(struct net_device *dev, + struct netdev_nested_priv *priv) +{ + dev->lower_level = __netdev_lower_depth(dev) + 1; + +#ifdef CONFIG_LOCKDEP + if (!priv) + return 0; + + if (priv->flags & NESTED_SYNC_IMM) + dev->nested_level = dev->lower_level - 1; + if (priv->flags & NESTED_SYNC_TODO) + net_unlink_todo(dev); +#endif + return 0; +} + +int netdev_walk_all_lower_dev_rcu(struct net_device *dev, + int (*fn)(struct net_device *dev, + struct netdev_nested_priv *priv), + struct netdev_nested_priv *priv) +{ + struct net_device *ldev, *next, *now, *dev_stack[MAX_NEST_DEV + 1]; + struct list_head *niter, *iter, *iter_stack[MAX_NEST_DEV + 1]; + int ret, cur = 0; + + now = dev; + iter = &dev->adj_list.lower; + + while (1) { + if (now != dev) { + ret = fn(now, priv); + if (ret) + return ret; + } + + next = NULL; + while (1) { + ldev = netdev_next_lower_dev_rcu(now, &iter); + if (!ldev) + break; + + next = ldev; + niter = &ldev->adj_list.lower; + dev_stack[cur] = 
now; + iter_stack[cur++] = iter; + break; + } + + if (!next) { + if (!cur) + return 0; + next = dev_stack[--cur]; + niter = iter_stack[cur]; + } + + now = next; + iter = niter; + } + + return 0; +} +EXPORT_SYMBOL_GPL(netdev_walk_all_lower_dev_rcu); + +/** + * netdev_lower_get_first_private_rcu - Get the first ->private from the + * lower neighbour list, RCU + * variant + * @dev: device + * + * Gets the first netdev_adjacent->private from the dev's lower neighbour + * list. The caller must hold RCU read lock. + */ +void *netdev_lower_get_first_private_rcu(struct net_device *dev) +{ + struct netdev_adjacent *lower; + + lower = list_first_or_null_rcu(&dev->adj_list.lower, + struct netdev_adjacent, list); + if (lower) + return lower->private; + return NULL; +} +EXPORT_SYMBOL(netdev_lower_get_first_private_rcu); + +/** + * netdev_master_upper_dev_get_rcu - Get master upper device + * @dev: device + * + * Find a master upper device and return pointer to it or NULL in case + * it's not there. The caller must hold the RCU read lock. + */ +struct net_device *netdev_master_upper_dev_get_rcu(struct net_device *dev) +{ + struct netdev_adjacent *upper; + + upper = list_first_or_null_rcu(&dev->adj_list.upper, + struct netdev_adjacent, list); + if (upper && likely(upper->master)) + return upper->dev; + return NULL; +} +EXPORT_SYMBOL(netdev_master_upper_dev_get_rcu); + +static int netdev_adjacent_sysfs_add(struct net_device *dev, + struct net_device *adj_dev, + struct list_head *dev_list) +{ + char linkname[IFNAMSIZ+7]; + + sprintf(linkname, dev_list == &dev->adj_list.upper ? + "upper_%s" : "lower_%s", adj_dev->name); + return sysfs_create_link(&(dev->dev.kobj), &(adj_dev->dev.kobj), + linkname); +} +static void netdev_adjacent_sysfs_del(struct net_device *dev, + char *name, + struct list_head *dev_list) +{ + char linkname[IFNAMSIZ+7]; + + sprintf(linkname, dev_list == &dev->adj_list.upper ? + "upper_%s" : "lower_%s", name); + sysfs_remove_link(&(dev->dev.kobj), linkname); +} + +static inline bool netdev_adjacent_is_neigh_list(struct net_device *dev, + struct net_device *adj_dev, + struct list_head *dev_list) +{ + return (dev_list == &dev->adj_list.upper || + dev_list == &dev->adj_list.lower) && + net_eq(dev_net(dev), dev_net(adj_dev)); +} + +static int __netdev_adjacent_dev_insert(struct net_device *dev, + struct net_device *adj_dev, + struct list_head *dev_list, + void *private, bool master) +{ + struct netdev_adjacent *adj; + int ret; + + adj = __netdev_find_adj(adj_dev, dev_list); + + if (adj) { + adj->ref_nr += 1; + pr_debug("Insert adjacency: dev %s adj_dev %s adj->ref_nr %d\n", + dev->name, adj_dev->name, adj->ref_nr); + + return 0; + } + + adj = kmalloc(sizeof(*adj), GFP_KERNEL); + if (!adj) + return -ENOMEM; + + adj->dev = adj_dev; + adj->master = master; + adj->ref_nr = 1; + adj->private = private; + adj->ignore = false; + netdev_hold(adj_dev, &adj->dev_tracker, GFP_KERNEL); + + pr_debug("Insert adjacency: dev %s adj_dev %s adj->ref_nr %d; dev_hold on %s\n", + dev->name, adj_dev->name, adj->ref_nr, adj_dev->name); + + if (netdev_adjacent_is_neigh_list(dev, adj_dev, dev_list)) { + ret = netdev_adjacent_sysfs_add(dev, adj_dev, dev_list); + if (ret) + goto free_adj; + } + + /* Ensure that master link is always the first item in list. 
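+	 * That way netdev_master_upper_dev_get() can return the master by
+	 * looking only at list_first_entry() of adj_list.upper.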
*/ + if (master) { + ret = sysfs_create_link(&(dev->dev.kobj), + &(adj_dev->dev.kobj), "master"); + if (ret) + goto remove_symlinks; + + list_add_rcu(&adj->list, dev_list); + } else { + list_add_tail_rcu(&adj->list, dev_list); + } + + return 0; + +remove_symlinks: + if (netdev_adjacent_is_neigh_list(dev, adj_dev, dev_list)) + netdev_adjacent_sysfs_del(dev, adj_dev->name, dev_list); +free_adj: + netdev_put(adj_dev, &adj->dev_tracker); + kfree(adj); + + return ret; +} + +static void __netdev_adjacent_dev_remove(struct net_device *dev, + struct net_device *adj_dev, + u16 ref_nr, + struct list_head *dev_list) +{ + struct netdev_adjacent *adj; + + pr_debug("Remove adjacency: dev %s adj_dev %s ref_nr %d\n", + dev->name, adj_dev->name, ref_nr); + + adj = __netdev_find_adj(adj_dev, dev_list); + + if (!adj) { + pr_err("Adjacency does not exist for device %s from %s\n", + dev->name, adj_dev->name); + WARN_ON(1); + return; + } + + if (adj->ref_nr > ref_nr) { + pr_debug("adjacency: %s to %s ref_nr - %d = %d\n", + dev->name, adj_dev->name, ref_nr, + adj->ref_nr - ref_nr); + adj->ref_nr -= ref_nr; + return; + } + + if (adj->master) + sysfs_remove_link(&(dev->dev.kobj), "master"); + + if (netdev_adjacent_is_neigh_list(dev, adj_dev, dev_list)) + netdev_adjacent_sysfs_del(dev, adj_dev->name, dev_list); + + list_del_rcu(&adj->list); + pr_debug("adjacency: dev_put for %s, because link removed from %s to %s\n", + adj_dev->name, dev->name, adj_dev->name); + netdev_put(adj_dev, &adj->dev_tracker); + kfree_rcu(adj, rcu); +} + +static int __netdev_adjacent_dev_link_lists(struct net_device *dev, + struct net_device *upper_dev, + struct list_head *up_list, + struct list_head *down_list, + void *private, bool master) +{ + int ret; + + ret = __netdev_adjacent_dev_insert(dev, upper_dev, up_list, + private, master); + if (ret) + return ret; + + ret = __netdev_adjacent_dev_insert(upper_dev, dev, down_list, + private, false); + if (ret) { + __netdev_adjacent_dev_remove(dev, upper_dev, 1, up_list); + return ret; + } + + return 0; +} + +static void __netdev_adjacent_dev_unlink_lists(struct net_device *dev, + struct net_device *upper_dev, + u16 ref_nr, + struct list_head *up_list, + struct list_head *down_list) +{ + __netdev_adjacent_dev_remove(dev, upper_dev, ref_nr, up_list); + __netdev_adjacent_dev_remove(upper_dev, dev, ref_nr, down_list); +} + +static int __netdev_adjacent_dev_link_neighbour(struct net_device *dev, + struct net_device *upper_dev, + void *private, bool master) +{ + return __netdev_adjacent_dev_link_lists(dev, upper_dev, + &dev->adj_list.upper, + &upper_dev->adj_list.lower, + private, master); +} + +static void __netdev_adjacent_dev_unlink_neighbour(struct net_device *dev, + struct net_device *upper_dev) +{ + __netdev_adjacent_dev_unlink_lists(dev, upper_dev, 1, + &dev->adj_list.upper, + &upper_dev->adj_list.lower); +} + +static int __netdev_upper_dev_link(struct net_device *dev, + struct net_device *upper_dev, bool master, + void *upper_priv, void *upper_info, + struct netdev_nested_priv *priv, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_changeupper_info changeupper_info = { + .info = { + .dev = dev, + .extack = extack, + }, + .upper_dev = upper_dev, + .master = master, + .linking = true, + .upper_info = upper_info, + }; + struct net_device *master_dev; + int ret = 0; + + ASSERT_RTNL(); + + if (dev == upper_dev) + return -EBUSY; + + /* To prevent loops, check if dev is not upper device to upper_dev. 
*/ + if (__netdev_has_upper_dev(upper_dev, dev)) + return -EBUSY; + + if ((dev->lower_level + upper_dev->upper_level) > MAX_NEST_DEV) + return -EMLINK; + + if (!master) { + if (__netdev_has_upper_dev(dev, upper_dev)) + return -EEXIST; + } else { + master_dev = __netdev_master_upper_dev_get(dev); + if (master_dev) + return master_dev == upper_dev ? -EEXIST : -EBUSY; + } + + ret = call_netdevice_notifiers_info(NETDEV_PRECHANGEUPPER, + &changeupper_info.info); + ret = notifier_to_errno(ret); + if (ret) + return ret; + + ret = __netdev_adjacent_dev_link_neighbour(dev, upper_dev, upper_priv, + master); + if (ret) + return ret; + + ret = call_netdevice_notifiers_info(NETDEV_CHANGEUPPER, + &changeupper_info.info); + ret = notifier_to_errno(ret); + if (ret) + goto rollback; + + __netdev_update_upper_level(dev, NULL); + __netdev_walk_all_lower_dev(dev, __netdev_update_upper_level, NULL); + + __netdev_update_lower_level(upper_dev, priv); + __netdev_walk_all_upper_dev(upper_dev, __netdev_update_lower_level, + priv); + + return 0; + +rollback: + __netdev_adjacent_dev_unlink_neighbour(dev, upper_dev); + + return ret; +} + +/** + * netdev_upper_dev_link - Add a link to the upper device + * @dev: device + * @upper_dev: new upper device + * @extack: netlink extended ack + * + * Adds a link to device which is upper to this one. The caller must hold + * the RTNL lock. On a failure a negative errno code is returned. + * On success the reference counts are adjusted and the function + * returns zero. + */ +int netdev_upper_dev_link(struct net_device *dev, + struct net_device *upper_dev, + struct netlink_ext_ack *extack) +{ + struct netdev_nested_priv priv = { + .flags = NESTED_SYNC_IMM | NESTED_SYNC_TODO, + .data = NULL, + }; + + return __netdev_upper_dev_link(dev, upper_dev, false, + NULL, NULL, &priv, extack); +} +EXPORT_SYMBOL(netdev_upper_dev_link); + +/** + * netdev_master_upper_dev_link - Add a master link to the upper device + * @dev: device + * @upper_dev: new upper device + * @upper_priv: upper device private + * @upper_info: upper info to be passed down via notifier + * @extack: netlink extended ack + * + * Adds a link to device which is upper to this one. In this case, only + * one master upper device can be linked, although other non-master devices + * might be linked as well. The caller must hold the RTNL lock. + * On a failure a negative errno code is returned. On success the reference + * counts are adjusted and the function returns zero. 
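+ *
+ * A minimal usage sketch (hypothetical bonding-style caller; "bond_dev",
+ * "port_dev" and "extack" are illustrative only):
+ *
+ *	int err;
+ *
+ *	ASSERT_RTNL();
+ *	err = netdev_master_upper_dev_link(port_dev, bond_dev,
+ *					   NULL, NULL, extack);
+ *	if (err)
+ *		return err;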
+ */ +int netdev_master_upper_dev_link(struct net_device *dev, + struct net_device *upper_dev, + void *upper_priv, void *upper_info, + struct netlink_ext_ack *extack) +{ + struct netdev_nested_priv priv = { + .flags = NESTED_SYNC_IMM | NESTED_SYNC_TODO, + .data = NULL, + }; + + return __netdev_upper_dev_link(dev, upper_dev, true, + upper_priv, upper_info, &priv, extack); +} +EXPORT_SYMBOL(netdev_master_upper_dev_link); + +static void __netdev_upper_dev_unlink(struct net_device *dev, + struct net_device *upper_dev, + struct netdev_nested_priv *priv) +{ + struct netdev_notifier_changeupper_info changeupper_info = { + .info = { + .dev = dev, + }, + .upper_dev = upper_dev, + .linking = false, + }; + + ASSERT_RTNL(); + + changeupper_info.master = netdev_master_upper_dev_get(dev) == upper_dev; + + call_netdevice_notifiers_info(NETDEV_PRECHANGEUPPER, + &changeupper_info.info); + + __netdev_adjacent_dev_unlink_neighbour(dev, upper_dev); + + call_netdevice_notifiers_info(NETDEV_CHANGEUPPER, + &changeupper_info.info); + + __netdev_update_upper_level(dev, NULL); + __netdev_walk_all_lower_dev(dev, __netdev_update_upper_level, NULL); + + __netdev_update_lower_level(upper_dev, priv); + __netdev_walk_all_upper_dev(upper_dev, __netdev_update_lower_level, + priv); +} + +/** + * netdev_upper_dev_unlink - Removes a link to upper device + * @dev: device + * @upper_dev: new upper device + * + * Removes a link to device which is upper to this one. The caller must hold + * the RTNL lock. + */ +void netdev_upper_dev_unlink(struct net_device *dev, + struct net_device *upper_dev) +{ + struct netdev_nested_priv priv = { + .flags = NESTED_SYNC_TODO, + .data = NULL, + }; + + __netdev_upper_dev_unlink(dev, upper_dev, &priv); +} +EXPORT_SYMBOL(netdev_upper_dev_unlink); + +static void __netdev_adjacent_dev_set(struct net_device *upper_dev, + struct net_device *lower_dev, + bool val) +{ + struct netdev_adjacent *adj; + + adj = __netdev_find_adj(lower_dev, &upper_dev->adj_list.lower); + if (adj) + adj->ignore = val; + + adj = __netdev_find_adj(upper_dev, &lower_dev->adj_list.upper); + if (adj) + adj->ignore = val; +} + +static void netdev_adjacent_dev_disable(struct net_device *upper_dev, + struct net_device *lower_dev) +{ + __netdev_adjacent_dev_set(upper_dev, lower_dev, true); +} + +static void netdev_adjacent_dev_enable(struct net_device *upper_dev, + struct net_device *lower_dev) +{ + __netdev_adjacent_dev_set(upper_dev, lower_dev, false); +} + +int netdev_adjacent_change_prepare(struct net_device *old_dev, + struct net_device *new_dev, + struct net_device *dev, + struct netlink_ext_ack *extack) +{ + struct netdev_nested_priv priv = { + .flags = 0, + .data = NULL, + }; + int err; + + if (!new_dev) + return 0; + + if (old_dev && new_dev != old_dev) + netdev_adjacent_dev_disable(dev, old_dev); + err = __netdev_upper_dev_link(new_dev, dev, false, NULL, NULL, &priv, + extack); + if (err) { + if (old_dev && new_dev != old_dev) + netdev_adjacent_dev_enable(dev, old_dev); + return err; + } + + return 0; +} +EXPORT_SYMBOL(netdev_adjacent_change_prepare); + +void netdev_adjacent_change_commit(struct net_device *old_dev, + struct net_device *new_dev, + struct net_device *dev) +{ + struct netdev_nested_priv priv = { + .flags = NESTED_SYNC_IMM | NESTED_SYNC_TODO, + .data = NULL, + }; + + if (!new_dev || !old_dev) + return; + + if (new_dev == old_dev) + return; + + netdev_adjacent_dev_enable(dev, old_dev); + __netdev_upper_dev_unlink(old_dev, dev, &priv); +} +EXPORT_SYMBOL(netdev_adjacent_change_commit); + +void 
netdev_adjacent_change_abort(struct net_device *old_dev, + struct net_device *new_dev, + struct net_device *dev) +{ + struct netdev_nested_priv priv = { + .flags = 0, + .data = NULL, + }; + + if (!new_dev) + return; + + if (old_dev && new_dev != old_dev) + netdev_adjacent_dev_enable(dev, old_dev); + + __netdev_upper_dev_unlink(new_dev, dev, &priv); +} +EXPORT_SYMBOL(netdev_adjacent_change_abort); + +/** + * netdev_bonding_info_change - Dispatch event about slave change + * @dev: device + * @bonding_info: info to dispatch + * + * Send NETDEV_BONDING_INFO to netdev notifiers with info. + * The caller must hold the RTNL lock. + */ +void netdev_bonding_info_change(struct net_device *dev, + struct netdev_bonding_info *bonding_info) +{ + struct netdev_notifier_bonding_info info = { + .info.dev = dev, + }; + + memcpy(&info.bonding_info, bonding_info, + sizeof(struct netdev_bonding_info)); + call_netdevice_notifiers_info(NETDEV_BONDING_INFO, + &info.info); +} +EXPORT_SYMBOL(netdev_bonding_info_change); + +static int netdev_offload_xstats_enable_l3(struct net_device *dev, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_offload_xstats_info info = { + .info.dev = dev, + .info.extack = extack, + .type = NETDEV_OFFLOAD_XSTATS_TYPE_L3, + }; + int err; + int rc; + + dev->offload_xstats_l3 = kzalloc(sizeof(*dev->offload_xstats_l3), + GFP_KERNEL); + if (!dev->offload_xstats_l3) + return -ENOMEM; + + rc = call_netdevice_notifiers_info_robust(NETDEV_OFFLOAD_XSTATS_ENABLE, + NETDEV_OFFLOAD_XSTATS_DISABLE, + &info.info); + err = notifier_to_errno(rc); + if (err) + goto free_stats; + + return 0; + +free_stats: + kfree(dev->offload_xstats_l3); + dev->offload_xstats_l3 = NULL; + return err; +} + +int netdev_offload_xstats_enable(struct net_device *dev, + enum netdev_offload_xstats_type type, + struct netlink_ext_ack *extack) +{ + ASSERT_RTNL(); + + if (netdev_offload_xstats_enabled(dev, type)) + return -EALREADY; + + switch (type) { + case NETDEV_OFFLOAD_XSTATS_TYPE_L3: + return netdev_offload_xstats_enable_l3(dev, extack); + } + + WARN_ON(1); + return -EINVAL; +} +EXPORT_SYMBOL(netdev_offload_xstats_enable); + +static void netdev_offload_xstats_disable_l3(struct net_device *dev) +{ + struct netdev_notifier_offload_xstats_info info = { + .info.dev = dev, + .type = NETDEV_OFFLOAD_XSTATS_TYPE_L3, + }; + + call_netdevice_notifiers_info(NETDEV_OFFLOAD_XSTATS_DISABLE, + &info.info); + kfree(dev->offload_xstats_l3); + dev->offload_xstats_l3 = NULL; +} + +int netdev_offload_xstats_disable(struct net_device *dev, + enum netdev_offload_xstats_type type) +{ + ASSERT_RTNL(); + + if (!netdev_offload_xstats_enabled(dev, type)) + return -EALREADY; + + switch (type) { + case NETDEV_OFFLOAD_XSTATS_TYPE_L3: + netdev_offload_xstats_disable_l3(dev); + return 0; + } + + WARN_ON(1); + return -EINVAL; +} +EXPORT_SYMBOL(netdev_offload_xstats_disable); + +static void netdev_offload_xstats_disable_all(struct net_device *dev) +{ + netdev_offload_xstats_disable(dev, NETDEV_OFFLOAD_XSTATS_TYPE_L3); +} + +static struct rtnl_hw_stats64 * +netdev_offload_xstats_get_ptr(const struct net_device *dev, + enum netdev_offload_xstats_type type) +{ + switch (type) { + case NETDEV_OFFLOAD_XSTATS_TYPE_L3: + return dev->offload_xstats_l3; + } + + WARN_ON(1); + return NULL; +} + +bool netdev_offload_xstats_enabled(const struct net_device *dev, + enum netdev_offload_xstats_type type) +{ + ASSERT_RTNL(); + + return netdev_offload_xstats_get_ptr(dev, type); +} +EXPORT_SYMBOL(netdev_offload_xstats_enabled); + +struct 
netdev_notifier_offload_xstats_ru { + bool used; +}; + +struct netdev_notifier_offload_xstats_rd { + struct rtnl_hw_stats64 stats; + bool used; +}; + +static void netdev_hw_stats64_add(struct rtnl_hw_stats64 *dest, + const struct rtnl_hw_stats64 *src) +{ + dest->rx_packets += src->rx_packets; + dest->tx_packets += src->tx_packets; + dest->rx_bytes += src->rx_bytes; + dest->tx_bytes += src->tx_bytes; + dest->rx_errors += src->rx_errors; + dest->tx_errors += src->tx_errors; + dest->rx_dropped += src->rx_dropped; + dest->tx_dropped += src->tx_dropped; + dest->multicast += src->multicast; +} + +static int netdev_offload_xstats_get_used(struct net_device *dev, + enum netdev_offload_xstats_type type, + bool *p_used, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_offload_xstats_ru report_used = {}; + struct netdev_notifier_offload_xstats_info info = { + .info.dev = dev, + .info.extack = extack, + .type = type, + .report_used = &report_used, + }; + int rc; + + WARN_ON(!netdev_offload_xstats_enabled(dev, type)); + rc = call_netdevice_notifiers_info(NETDEV_OFFLOAD_XSTATS_REPORT_USED, + &info.info); + *p_used = report_used.used; + return notifier_to_errno(rc); +} + +static int netdev_offload_xstats_get_stats(struct net_device *dev, + enum netdev_offload_xstats_type type, + struct rtnl_hw_stats64 *p_stats, + bool *p_used, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_offload_xstats_rd report_delta = {}; + struct netdev_notifier_offload_xstats_info info = { + .info.dev = dev, + .info.extack = extack, + .type = type, + .report_delta = &report_delta, + }; + struct rtnl_hw_stats64 *stats; + int rc; + + stats = netdev_offload_xstats_get_ptr(dev, type); + if (WARN_ON(!stats)) + return -EINVAL; + + rc = call_netdevice_notifiers_info(NETDEV_OFFLOAD_XSTATS_REPORT_DELTA, + &info.info); + + /* Cache whatever we got, even if there was an error, otherwise the + * successful stats retrievals would get lost. 
+ */ + netdev_hw_stats64_add(stats, &report_delta.stats); + + if (p_stats) + *p_stats = *stats; + *p_used = report_delta.used; + + return notifier_to_errno(rc); +} + +int netdev_offload_xstats_get(struct net_device *dev, + enum netdev_offload_xstats_type type, + struct rtnl_hw_stats64 *p_stats, bool *p_used, + struct netlink_ext_ack *extack) +{ + ASSERT_RTNL(); + + if (p_stats) + return netdev_offload_xstats_get_stats(dev, type, p_stats, + p_used, extack); + else + return netdev_offload_xstats_get_used(dev, type, p_used, + extack); +} +EXPORT_SYMBOL(netdev_offload_xstats_get); + +void +netdev_offload_xstats_report_delta(struct netdev_notifier_offload_xstats_rd *report_delta, + const struct rtnl_hw_stats64 *stats) +{ + report_delta->used = true; + netdev_hw_stats64_add(&report_delta->stats, stats); +} +EXPORT_SYMBOL(netdev_offload_xstats_report_delta); + +void +netdev_offload_xstats_report_used(struct netdev_notifier_offload_xstats_ru *report_used) +{ + report_used->used = true; +} +EXPORT_SYMBOL(netdev_offload_xstats_report_used); + +void netdev_offload_xstats_push_delta(struct net_device *dev, + enum netdev_offload_xstats_type type, + const struct rtnl_hw_stats64 *p_stats) +{ + struct rtnl_hw_stats64 *stats; + + ASSERT_RTNL(); + + stats = netdev_offload_xstats_get_ptr(dev, type); + if (WARN_ON(!stats)) + return; + + netdev_hw_stats64_add(stats, p_stats); +} +EXPORT_SYMBOL(netdev_offload_xstats_push_delta); + +/** + * netdev_get_xmit_slave - Get the xmit slave of master device + * @dev: device + * @skb: The packet + * @all_slaves: assume all the slaves are active + * + * The reference counters are not incremented so the caller must be + * careful with locks. The caller must hold RCU lock. + * %NULL is returned if no slave is found. + */ + +struct net_device *netdev_get_xmit_slave(struct net_device *dev, + struct sk_buff *skb, + bool all_slaves) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (!ops->ndo_get_xmit_slave) + return NULL; + return ops->ndo_get_xmit_slave(dev, skb, all_slaves); +} +EXPORT_SYMBOL(netdev_get_xmit_slave); + +static struct net_device *netdev_sk_get_lower_dev(struct net_device *dev, + struct sock *sk) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (!ops->ndo_sk_get_lower_dev) + return NULL; + return ops->ndo_sk_get_lower_dev(dev, sk); +} + +/** + * netdev_sk_get_lowest_dev - Get the lowest device in chain given device and socket + * @dev: device + * @sk: the socket + * + * %NULL is returned if no lower device is found. 
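+ *
+ * A minimal usage sketch (hypothetical caller selecting a per-socket TX
+ * offload device; "bond_dev", "sk" and "setup_offload_on" are
+ * illustrative only):
+ *
+ *	struct net_device *lowest;
+ *
+ *	lowest = netdev_sk_get_lowest_dev(bond_dev, sk);
+ *	if (lowest)
+ *		setup_offload_on(lowest, sk);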
+ */ + +struct net_device *netdev_sk_get_lowest_dev(struct net_device *dev, + struct sock *sk) +{ + struct net_device *lower; + + lower = netdev_sk_get_lower_dev(dev, sk); + while (lower) { + dev = lower; + lower = netdev_sk_get_lower_dev(dev, sk); + } + + return dev; +} +EXPORT_SYMBOL(netdev_sk_get_lowest_dev); + +static void netdev_adjacent_add_links(struct net_device *dev) +{ + struct netdev_adjacent *iter; + + struct net *net = dev_net(dev); + + list_for_each_entry(iter, &dev->adj_list.upper, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_add(iter->dev, dev, + &iter->dev->adj_list.lower); + netdev_adjacent_sysfs_add(dev, iter->dev, + &dev->adj_list.upper); + } + + list_for_each_entry(iter, &dev->adj_list.lower, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_add(iter->dev, dev, + &iter->dev->adj_list.upper); + netdev_adjacent_sysfs_add(dev, iter->dev, + &dev->adj_list.lower); + } +} + +static void netdev_adjacent_del_links(struct net_device *dev) +{ + struct netdev_adjacent *iter; + + struct net *net = dev_net(dev); + + list_for_each_entry(iter, &dev->adj_list.upper, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_del(iter->dev, dev->name, + &iter->dev->adj_list.lower); + netdev_adjacent_sysfs_del(dev, iter->dev->name, + &dev->adj_list.upper); + } + + list_for_each_entry(iter, &dev->adj_list.lower, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_del(iter->dev, dev->name, + &iter->dev->adj_list.upper); + netdev_adjacent_sysfs_del(dev, iter->dev->name, + &dev->adj_list.lower); + } +} + +void netdev_adjacent_rename_links(struct net_device *dev, char *oldname) +{ + struct netdev_adjacent *iter; + + struct net *net = dev_net(dev); + + list_for_each_entry(iter, &dev->adj_list.upper, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_del(iter->dev, oldname, + &iter->dev->adj_list.lower); + netdev_adjacent_sysfs_add(iter->dev, dev, + &iter->dev->adj_list.lower); + } + + list_for_each_entry(iter, &dev->adj_list.lower, list) { + if (!net_eq(net, dev_net(iter->dev))) + continue; + netdev_adjacent_sysfs_del(iter->dev, oldname, + &iter->dev->adj_list.upper); + netdev_adjacent_sysfs_add(iter->dev, dev, + &iter->dev->adj_list.upper); + } +} + +void *netdev_lower_dev_get_private(struct net_device *dev, + struct net_device *lower_dev) +{ + struct netdev_adjacent *lower; + + if (!lower_dev) + return NULL; + lower = __netdev_find_adj(lower_dev, &dev->adj_list.lower); + if (!lower) + return NULL; + + return lower->private; +} +EXPORT_SYMBOL(netdev_lower_dev_get_private); + + +/** + * netdev_lower_state_changed - Dispatch event about lower device state change + * @lower_dev: device + * @lower_state_info: state to dispatch + * + * Send NETDEV_CHANGELOWERSTATE to netdev notifiers with info. + * The caller must hold the RTNL lock. 
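+ *
+ * A minimal usage sketch (a hypothetical master driver publishing the state
+ * of one of its ports; "port_dev" and struct my_port_state are illustrative
+ * only):
+ *
+ *	struct my_port_state state = { .link_up = true };
+ *
+ *	rtnl_lock();
+ *	netdev_lower_state_changed(port_dev, &state);
+ *	rtnl_unlock();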
+ */ +void netdev_lower_state_changed(struct net_device *lower_dev, + void *lower_state_info) +{ + struct netdev_notifier_changelowerstate_info changelowerstate_info = { + .info.dev = lower_dev, + }; + + ASSERT_RTNL(); + changelowerstate_info.lower_state_info = lower_state_info; + call_netdevice_notifiers_info(NETDEV_CHANGELOWERSTATE, + &changelowerstate_info.info); +} +EXPORT_SYMBOL(netdev_lower_state_changed); + +static void dev_change_rx_flags(struct net_device *dev, int flags) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_change_rx_flags) + ops->ndo_change_rx_flags(dev, flags); +} + +static int __dev_set_promiscuity(struct net_device *dev, int inc, bool notify) +{ + unsigned int old_flags = dev->flags; + kuid_t uid; + kgid_t gid; + + ASSERT_RTNL(); + + dev->flags |= IFF_PROMISC; + dev->promiscuity += inc; + if (dev->promiscuity == 0) { + /* + * Avoid overflow. + * If inc causes overflow, untouch promisc and return error. + */ + if (inc < 0) + dev->flags &= ~IFF_PROMISC; + else { + dev->promiscuity -= inc; + netdev_warn(dev, "promiscuity touches roof, set promiscuity failed. promiscuity feature of device might be broken.\n"); + return -EOVERFLOW; + } + } + if (dev->flags != old_flags) { + pr_info("device %s %s promiscuous mode\n", + dev->name, + dev->flags & IFF_PROMISC ? "entered" : "left"); + if (audit_enabled) { + current_uid_gid(&uid, &gid); + audit_log(audit_context(), GFP_ATOMIC, + AUDIT_ANOM_PROMISCUOUS, + "dev=%s prom=%d old_prom=%d auid=%u uid=%u gid=%u ses=%u", + dev->name, (dev->flags & IFF_PROMISC), + (old_flags & IFF_PROMISC), + from_kuid(&init_user_ns, audit_get_loginuid(current)), + from_kuid(&init_user_ns, uid), + from_kgid(&init_user_ns, gid), + audit_get_sessionid(current)); + } + + dev_change_rx_flags(dev, IFF_PROMISC); + } + if (notify) + __dev_notify_flags(dev, old_flags, IFF_PROMISC); + return 0; +} + +/** + * dev_set_promiscuity - update promiscuity count on a device + * @dev: device + * @inc: modifier + * + * Add or remove promiscuity from a device. While the count in the device + * remains above zero the interface remains promiscuous. Once it hits zero + * the device reverts back to normal filtering operation. A negative inc + * value is used to drop promiscuity on the device. + * Return 0 if successful or a negative errno code on error. + */ +int dev_set_promiscuity(struct net_device *dev, int inc) +{ + unsigned int old_flags = dev->flags; + int err; + + err = __dev_set_promiscuity(dev, inc, true); + if (err < 0) + return err; + if (dev->flags != old_flags) + dev_set_rx_mode(dev); + return err; +} +EXPORT_SYMBOL(dev_set_promiscuity); + +static int __dev_set_allmulti(struct net_device *dev, int inc, bool notify) +{ + unsigned int old_flags = dev->flags, old_gflags = dev->gflags; + + ASSERT_RTNL(); + + dev->flags |= IFF_ALLMULTI; + dev->allmulti += inc; + if (dev->allmulti == 0) { + /* + * Avoid overflow. + * If inc causes overflow, untouch allmulti and return error. + */ + if (inc < 0) + dev->flags &= ~IFF_ALLMULTI; + else { + dev->allmulti -= inc; + netdev_warn(dev, "allmulti touches roof, set allmulti failed. 
allmulti feature of device might be broken.\n"); + return -EOVERFLOW; + } + } + if (dev->flags ^ old_flags) { + dev_change_rx_flags(dev, IFF_ALLMULTI); + dev_set_rx_mode(dev); + if (notify) + __dev_notify_flags(dev, old_flags, + dev->gflags ^ old_gflags); + } + return 0; +} + +/** + * dev_set_allmulti - update allmulti count on a device + * @dev: device + * @inc: modifier + * + * Add or remove reception of all multicast frames to a device. While the + * count in the device remains above zero the interface remains listening + * to all interfaces. Once it hits zero the device reverts back to normal + * filtering operation. A negative @inc value is used to drop the counter + * when releasing a resource needing all multicasts. + * Return 0 if successful or a negative errno code on error. + */ + +int dev_set_allmulti(struct net_device *dev, int inc) +{ + return __dev_set_allmulti(dev, inc, true); +} +EXPORT_SYMBOL(dev_set_allmulti); + +/* + * Upload unicast and multicast address lists to device and + * configure RX filtering. When the device doesn't support unicast + * filtering it is put in promiscuous mode while unicast addresses + * are present. + */ +void __dev_set_rx_mode(struct net_device *dev) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + /* dev_open will call this function so the list will stay sane. */ + if (!(dev->flags&IFF_UP)) + return; + + if (!netif_device_present(dev)) + return; + + if (!(dev->priv_flags & IFF_UNICAST_FLT)) { + /* Unicast addresses changes may only happen under the rtnl, + * therefore calling __dev_set_promiscuity here is safe. + */ + if (!netdev_uc_empty(dev) && !dev->uc_promisc) { + __dev_set_promiscuity(dev, 1, false); + dev->uc_promisc = true; + } else if (netdev_uc_empty(dev) && dev->uc_promisc) { + __dev_set_promiscuity(dev, -1, false); + dev->uc_promisc = false; + } + } + + if (ops->ndo_set_rx_mode) + ops->ndo_set_rx_mode(dev); +} + +void dev_set_rx_mode(struct net_device *dev) +{ + netif_addr_lock_bh(dev); + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); +} + +/** + * dev_get_flags - get flags reported to userspace + * @dev: device + * + * Get the combination of flag bits exported through APIs to userspace. + */ +unsigned int dev_get_flags(const struct net_device *dev) +{ + unsigned int flags; + + flags = (dev->flags & ~(IFF_PROMISC | + IFF_ALLMULTI | + IFF_RUNNING | + IFF_LOWER_UP | + IFF_DORMANT)) | + (dev->gflags & (IFF_PROMISC | + IFF_ALLMULTI)); + + if (netif_running(dev)) { + if (netif_oper_up(dev)) + flags |= IFF_RUNNING; + if (netif_carrier_ok(dev)) + flags |= IFF_LOWER_UP; + if (netif_dormant(dev)) + flags |= IFF_DORMANT; + } + + return flags; +} +EXPORT_SYMBOL(dev_get_flags); + +int __dev_change_flags(struct net_device *dev, unsigned int flags, + struct netlink_ext_ack *extack) +{ + unsigned int old_flags = dev->flags; + int ret; + + ASSERT_RTNL(); + + /* + * Set the flags on our device. + */ + + dev->flags = (flags & (IFF_DEBUG | IFF_NOTRAILERS | IFF_NOARP | + IFF_DYNAMIC | IFF_MULTICAST | IFF_PORTSEL | + IFF_AUTOMEDIA)) | + (dev->flags & (IFF_UP | IFF_VOLATILE | IFF_PROMISC | + IFF_ALLMULTI)); + + /* + * Load in the correct multicast list now the flags have changed. + */ + + if ((old_flags ^ flags) & IFF_MULTICAST) + dev_change_rx_flags(dev, IFF_MULTICAST); + + dev_set_rx_mode(dev); + + /* + * Have we downed the interface. We handle IFF_UP ourselves + * according to user attempts to set it, rather than blindly + * setting it. 
+ */ + + ret = 0; + if ((old_flags ^ flags) & IFF_UP) { + if (old_flags & IFF_UP) + __dev_close(dev); + else + ret = __dev_open(dev, extack); + } + + if ((flags ^ dev->gflags) & IFF_PROMISC) { + int inc = (flags & IFF_PROMISC) ? 1 : -1; + unsigned int old_flags = dev->flags; + + dev->gflags ^= IFF_PROMISC; + + if (__dev_set_promiscuity(dev, inc, false) >= 0) + if (dev->flags != old_flags) + dev_set_rx_mode(dev); + } + + /* NOTE: order of synchronization of IFF_PROMISC and IFF_ALLMULTI + * is important. Some (broken) drivers set IFF_PROMISC, when + * IFF_ALLMULTI is requested not asking us and not reporting. + */ + if ((flags ^ dev->gflags) & IFF_ALLMULTI) { + int inc = (flags & IFF_ALLMULTI) ? 1 : -1; + + dev->gflags ^= IFF_ALLMULTI; + __dev_set_allmulti(dev, inc, false); + } + + return ret; +} + +void __dev_notify_flags(struct net_device *dev, unsigned int old_flags, + unsigned int gchanges) +{ + unsigned int changes = dev->flags ^ old_flags; + + if (gchanges) + rtmsg_ifinfo(RTM_NEWLINK, dev, gchanges, GFP_ATOMIC); + + if (changes & IFF_UP) { + if (dev->flags & IFF_UP) + call_netdevice_notifiers(NETDEV_UP, dev); + else + call_netdevice_notifiers(NETDEV_DOWN, dev); + } + + if (dev->flags & IFF_UP && + (changes & ~(IFF_UP | IFF_PROMISC | IFF_ALLMULTI | IFF_VOLATILE))) { + struct netdev_notifier_change_info change_info = { + .info = { + .dev = dev, + }, + .flags_changed = changes, + }; + + call_netdevice_notifiers_info(NETDEV_CHANGE, &change_info.info); + } +} + +/** + * dev_change_flags - change device settings + * @dev: device + * @flags: device state flags + * @extack: netlink extended ack + * + * Change settings on device based state flags. The flags are + * in the userspace exported format. + */ +int dev_change_flags(struct net_device *dev, unsigned int flags, + struct netlink_ext_ack *extack) +{ + int ret; + unsigned int changes, old_flags = dev->flags, old_gflags = dev->gflags; + + ret = __dev_change_flags(dev, flags, extack); + if (ret < 0) + return ret; + + changes = (old_flags ^ dev->flags) | (old_gflags ^ dev->gflags); + __dev_notify_flags(dev, old_flags, changes); + return ret; +} +EXPORT_SYMBOL(dev_change_flags); + +int __dev_set_mtu(struct net_device *dev, int new_mtu) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_change_mtu) + return ops->ndo_change_mtu(dev, new_mtu); + + /* Pairs with all the lockless reads of dev->mtu in the stack */ + WRITE_ONCE(dev->mtu, new_mtu); + return 0; +} +EXPORT_SYMBOL(__dev_set_mtu); + +int dev_validate_mtu(struct net_device *dev, int new_mtu, + struct netlink_ext_ack *extack) +{ + /* MTU must be positive, and in range */ + if (new_mtu < 0 || new_mtu < dev->min_mtu) { + NL_SET_ERR_MSG(extack, "mtu less than device minimum"); + return -EINVAL; + } + + if (dev->max_mtu > 0 && new_mtu > dev->max_mtu) { + NL_SET_ERR_MSG(extack, "mtu greater than device maximum"); + return -EINVAL; + } + return 0; +} + +/** + * dev_set_mtu_ext - Change maximum transfer unit + * @dev: device + * @new_mtu: new transfer unit + * @extack: netlink extended ack + * + * Change the maximum transfer size of the network device. 
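+ *
+ * A minimal usage sketch (hypothetical caller running under the RTNL lock,
+ * with @extack taken from the netlink request being served; the 9000 byte
+ * value is illustrative only):
+ *
+ *	int err;
+ *
+ *	err = dev_set_mtu_ext(dev, 9000, extack);
+ *	if (err)
+ *		return err;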
+ */ +int dev_set_mtu_ext(struct net_device *dev, int new_mtu, + struct netlink_ext_ack *extack) +{ + int err, orig_mtu; + + if (new_mtu == dev->mtu) + return 0; + + err = dev_validate_mtu(dev, new_mtu, extack); + if (err) + return err; + + if (!netif_device_present(dev)) + return -ENODEV; + + err = call_netdevice_notifiers(NETDEV_PRECHANGEMTU, dev); + err = notifier_to_errno(err); + if (err) + return err; + + orig_mtu = dev->mtu; + err = __dev_set_mtu(dev, new_mtu); + + if (!err) { + err = call_netdevice_notifiers_mtu(NETDEV_CHANGEMTU, dev, + orig_mtu); + err = notifier_to_errno(err); + if (err) { + /* setting mtu back and notifying everyone again, + * so that they have a chance to revert changes. + */ + __dev_set_mtu(dev, orig_mtu); + call_netdevice_notifiers_mtu(NETDEV_CHANGEMTU, dev, + new_mtu); + } + } + return err; +} + +int dev_set_mtu(struct net_device *dev, int new_mtu) +{ + struct netlink_ext_ack extack; + int err; + + memset(&extack, 0, sizeof(extack)); + err = dev_set_mtu_ext(dev, new_mtu, &extack); + if (err && extack._msg) + net_err_ratelimited("%s: %s\n", dev->name, extack._msg); + return err; +} +EXPORT_SYMBOL(dev_set_mtu); + +/** + * dev_change_tx_queue_len - Change TX queue length of a netdevice + * @dev: device + * @new_len: new tx queue length + */ +int dev_change_tx_queue_len(struct net_device *dev, unsigned long new_len) +{ + unsigned int orig_len = dev->tx_queue_len; + int res; + + if (new_len != (unsigned int)new_len) + return -ERANGE; + + if (new_len != orig_len) { + dev->tx_queue_len = new_len; + res = call_netdevice_notifiers(NETDEV_CHANGE_TX_QUEUE_LEN, dev); + res = notifier_to_errno(res); + if (res) + goto err_rollback; + res = dev_qdisc_change_tx_queue_len(dev); + if (res) + goto err_rollback; + } + + return 0; + +err_rollback: + netdev_err(dev, "refused to change device tx_queue_len\n"); + dev->tx_queue_len = orig_len; + return res; +} + +/** + * dev_set_group - Change group this device belongs to + * @dev: device + * @new_group: group this device should belong to + */ +void dev_set_group(struct net_device *dev, int new_group) +{ + dev->group = new_group; +} + +/** + * dev_pre_changeaddr_notify - Call NETDEV_PRE_CHANGEADDR. 
+ * @dev: device + * @addr: new address + * @extack: netlink extended ack + */ +int dev_pre_changeaddr_notify(struct net_device *dev, const char *addr, + struct netlink_ext_ack *extack) +{ + struct netdev_notifier_pre_changeaddr_info info = { + .info.dev = dev, + .info.extack = extack, + .dev_addr = addr, + }; + int rc; + + rc = call_netdevice_notifiers_info(NETDEV_PRE_CHANGEADDR, &info.info); + return notifier_to_errno(rc); +} +EXPORT_SYMBOL(dev_pre_changeaddr_notify); + +/** + * dev_set_mac_address - Change Media Access Control Address + * @dev: device + * @sa: new address + * @extack: netlink extended ack + * + * Change the hardware (MAC) address of the device + */ +int dev_set_mac_address(struct net_device *dev, struct sockaddr *sa, + struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + int err; + + if (!ops->ndo_set_mac_address) + return -EOPNOTSUPP; + if (sa->sa_family != dev->type) + return -EINVAL; + if (!netif_device_present(dev)) + return -ENODEV; + err = dev_pre_changeaddr_notify(dev, sa->sa_data, extack); + if (err) + return err; + err = ops->ndo_set_mac_address(dev, sa); + if (err) + return err; + dev->addr_assign_type = NET_ADDR_SET; + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + add_device_randomness(dev->dev_addr, dev->addr_len); + return 0; +} +EXPORT_SYMBOL(dev_set_mac_address); + +static DECLARE_RWSEM(dev_addr_sem); + +int dev_set_mac_address_user(struct net_device *dev, struct sockaddr *sa, + struct netlink_ext_ack *extack) +{ + int ret; + + down_write(&dev_addr_sem); + ret = dev_set_mac_address(dev, sa, extack); + up_write(&dev_addr_sem); + return ret; +} +EXPORT_SYMBOL(dev_set_mac_address_user); + +int dev_get_mac_address(struct sockaddr *sa, struct net *net, char *dev_name) +{ + size_t size = sizeof(sa->sa_data); + struct net_device *dev; + int ret = 0; + + down_read(&dev_addr_sem); + rcu_read_lock(); + + dev = dev_get_by_name_rcu(net, dev_name); + if (!dev) { + ret = -ENODEV; + goto unlock; + } + if (!dev->addr_len) + memset(sa->sa_data, 0, size); + else + memcpy(sa->sa_data, dev->dev_addr, + min_t(size_t, size, dev->addr_len)); + sa->sa_family = dev->type; + +unlock: + rcu_read_unlock(); + up_read(&dev_addr_sem); + return ret; +} +EXPORT_SYMBOL(dev_get_mac_address); + +/** + * dev_change_carrier - Change device carrier + * @dev: device + * @new_carrier: new value + * + * Change device carrier + */ +int dev_change_carrier(struct net_device *dev, bool new_carrier) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (!ops->ndo_change_carrier) + return -EOPNOTSUPP; + if (!netif_device_present(dev)) + return -ENODEV; + return ops->ndo_change_carrier(dev, new_carrier); +} + +/** + * dev_get_phys_port_id - Get device physical port ID + * @dev: device + * @ppid: port ID + * + * Get device physical port ID + */ +int dev_get_phys_port_id(struct net_device *dev, + struct netdev_phys_item_id *ppid) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (!ops->ndo_get_phys_port_id) + return -EOPNOTSUPP; + return ops->ndo_get_phys_port_id(dev, ppid); +} + +/** + * dev_get_phys_port_name - Get device physical port name + * @dev: device + * @name: port name + * @len: limit of bytes to copy to name + * + * Get device physical port name + */ +int dev_get_phys_port_name(struct net_device *dev, + char *name, size_t len) +{ + const struct net_device_ops *ops = dev->netdev_ops; + int err; + + if (ops->ndo_get_phys_port_name) { + err = ops->ndo_get_phys_port_name(dev, name, len); + if (err != -EOPNOTSUPP) + return err; + } 
+ return devlink_compat_phys_port_name_get(dev, name, len); +} + +/** + * dev_get_port_parent_id - Get the device's port parent identifier + * @dev: network device + * @ppid: pointer to a storage for the port's parent identifier + * @recurse: allow/disallow recursion to lower devices + * + * Get the devices's port parent identifier + */ +int dev_get_port_parent_id(struct net_device *dev, + struct netdev_phys_item_id *ppid, + bool recurse) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct netdev_phys_item_id first = { }; + struct net_device *lower_dev; + struct list_head *iter; + int err; + + if (ops->ndo_get_port_parent_id) { + err = ops->ndo_get_port_parent_id(dev, ppid); + if (err != -EOPNOTSUPP) + return err; + } + + err = devlink_compat_switch_id_get(dev, ppid); + if (!recurse || err != -EOPNOTSUPP) + return err; + + netdev_for_each_lower_dev(dev, lower_dev, iter) { + err = dev_get_port_parent_id(lower_dev, ppid, true); + if (err) + break; + if (!first.id_len) + first = *ppid; + else if (memcmp(&first, ppid, sizeof(*ppid))) + return -EOPNOTSUPP; + } + + return err; +} +EXPORT_SYMBOL(dev_get_port_parent_id); + +/** + * netdev_port_same_parent_id - Indicate if two network devices have + * the same port parent identifier + * @a: first network device + * @b: second network device + */ +bool netdev_port_same_parent_id(struct net_device *a, struct net_device *b) +{ + struct netdev_phys_item_id a_id = { }; + struct netdev_phys_item_id b_id = { }; + + if (dev_get_port_parent_id(a, &a_id, true) || + dev_get_port_parent_id(b, &b_id, true)) + return false; + + return netdev_phys_item_id_same(&a_id, &b_id); +} +EXPORT_SYMBOL(netdev_port_same_parent_id); + +/** + * dev_change_proto_down - set carrier according to proto_down. + * + * @dev: device + * @proto_down: new value + */ +int dev_change_proto_down(struct net_device *dev, bool proto_down) +{ + if (!(dev->priv_flags & IFF_CHANGE_PROTO_DOWN)) + return -EOPNOTSUPP; + if (!netif_device_present(dev)) + return -ENODEV; + if (proto_down) + netif_carrier_off(dev); + else + netif_carrier_on(dev); + dev->proto_down = proto_down; + return 0; +} + +/** + * dev_change_proto_down_reason - proto down reason + * + * @dev: device + * @mask: proto down mask + * @value: proto down value + */ +void dev_change_proto_down_reason(struct net_device *dev, unsigned long mask, + u32 value) +{ + int b; + + if (!mask) { + dev->proto_down_reason = value; + } else { + for_each_set_bit(b, &mask, 32) { + if (value & (1 << b)) + dev->proto_down_reason |= BIT(b); + else + dev->proto_down_reason &= ~BIT(b); + } + } +} + +struct bpf_xdp_link { + struct bpf_link link; + struct net_device *dev; /* protected by rtnl_lock, no refcnt held */ + int flags; +}; + +static enum bpf_xdp_mode dev_xdp_mode(struct net_device *dev, u32 flags) +{ + if (flags & XDP_FLAGS_HW_MODE) + return XDP_MODE_HW; + if (flags & XDP_FLAGS_DRV_MODE) + return XDP_MODE_DRV; + if (flags & XDP_FLAGS_SKB_MODE) + return XDP_MODE_SKB; + return dev->netdev_ops->ndo_bpf ? 
XDP_MODE_DRV : XDP_MODE_SKB; +} + +static bpf_op_t dev_xdp_bpf_op(struct net_device *dev, enum bpf_xdp_mode mode) +{ + switch (mode) { + case XDP_MODE_SKB: + return generic_xdp_install; + case XDP_MODE_DRV: + case XDP_MODE_HW: + return dev->netdev_ops->ndo_bpf; + default: + return NULL; + } +} + +static struct bpf_xdp_link *dev_xdp_link(struct net_device *dev, + enum bpf_xdp_mode mode) +{ + return dev->xdp_state[mode].link; +} + +static struct bpf_prog *dev_xdp_prog(struct net_device *dev, + enum bpf_xdp_mode mode) +{ + struct bpf_xdp_link *link = dev_xdp_link(dev, mode); + + if (link) + return link->link.prog; + return dev->xdp_state[mode].prog; +} + +u8 dev_xdp_prog_count(struct net_device *dev) +{ + u8 count = 0; + int i; + + for (i = 0; i < __MAX_XDP_MODE; i++) + if (dev->xdp_state[i].prog || dev->xdp_state[i].link) + count++; + return count; +} +EXPORT_SYMBOL_GPL(dev_xdp_prog_count); + +u32 dev_xdp_prog_id(struct net_device *dev, enum bpf_xdp_mode mode) +{ + struct bpf_prog *prog = dev_xdp_prog(dev, mode); + + return prog ? prog->aux->id : 0; +} + +static void dev_xdp_set_link(struct net_device *dev, enum bpf_xdp_mode mode, + struct bpf_xdp_link *link) +{ + dev->xdp_state[mode].link = link; + dev->xdp_state[mode].prog = NULL; +} + +static void dev_xdp_set_prog(struct net_device *dev, enum bpf_xdp_mode mode, + struct bpf_prog *prog) +{ + dev->xdp_state[mode].link = NULL; + dev->xdp_state[mode].prog = prog; +} + +static int dev_xdp_install(struct net_device *dev, enum bpf_xdp_mode mode, + bpf_op_t bpf_op, struct netlink_ext_ack *extack, + u32 flags, struct bpf_prog *prog) +{ + struct netdev_bpf xdp; + int err; + + memset(&xdp, 0, sizeof(xdp)); + xdp.command = mode == XDP_MODE_HW ? XDP_SETUP_PROG_HW : XDP_SETUP_PROG; + xdp.extack = extack; + xdp.flags = flags; + xdp.prog = prog; + + /* Drivers assume refcnt is already incremented (i.e, prog pointer is + * "moved" into driver), so they don't increment it on their own, but + * they do decrement refcnt when program is detached or replaced. + * Given net_device also owns link/prog, we need to bump refcnt here + * to prevent drivers from underflowing it. 
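+	 * On failure the extra reference is dropped again right below;
+	 * on success the driver keeps it until the program is detached
+	 * or replaced.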
+ */ + if (prog) + bpf_prog_inc(prog); + err = bpf_op(dev, &xdp); + if (err) { + if (prog) + bpf_prog_put(prog); + return err; + } + + if (mode != XDP_MODE_HW) + bpf_prog_change_xdp(dev_xdp_prog(dev, mode), prog); + + return 0; +} + +static void dev_xdp_uninstall(struct net_device *dev) +{ + struct bpf_xdp_link *link; + struct bpf_prog *prog; + enum bpf_xdp_mode mode; + bpf_op_t bpf_op; + + ASSERT_RTNL(); + + for (mode = XDP_MODE_SKB; mode < __MAX_XDP_MODE; mode++) { + prog = dev_xdp_prog(dev, mode); + if (!prog) + continue; + + bpf_op = dev_xdp_bpf_op(dev, mode); + if (!bpf_op) + continue; + + WARN_ON(dev_xdp_install(dev, mode, bpf_op, NULL, 0, NULL)); + + /* auto-detach link from net device */ + link = dev_xdp_link(dev, mode); + if (link) + link->dev = NULL; + else + bpf_prog_put(prog); + + dev_xdp_set_link(dev, mode, NULL); + } +} + +static int dev_xdp_attach(struct net_device *dev, struct netlink_ext_ack *extack, + struct bpf_xdp_link *link, struct bpf_prog *new_prog, + struct bpf_prog *old_prog, u32 flags) +{ + unsigned int num_modes = hweight32(flags & XDP_FLAGS_MODES); + struct bpf_prog *cur_prog; + struct net_device *upper; + struct list_head *iter; + enum bpf_xdp_mode mode; + bpf_op_t bpf_op; + int err; + + ASSERT_RTNL(); + + /* either link or prog attachment, never both */ + if (link && (new_prog || old_prog)) + return -EINVAL; + /* link supports only XDP mode flags */ + if (link && (flags & ~XDP_FLAGS_MODES)) { + NL_SET_ERR_MSG(extack, "Invalid XDP flags for BPF link attachment"); + return -EINVAL; + } + /* just one XDP mode bit should be set, zero defaults to drv/skb mode */ + if (num_modes > 1) { + NL_SET_ERR_MSG(extack, "Only one XDP mode flag can be set"); + return -EINVAL; + } + /* avoid ambiguity if offload + drv/skb mode progs are both loaded */ + if (!num_modes && dev_xdp_prog_count(dev) > 1) { + NL_SET_ERR_MSG(extack, + "More than one program loaded, unset mode is ambiguous"); + return -EINVAL; + } + /* old_prog != NULL implies XDP_FLAGS_REPLACE is set */ + if (old_prog && !(flags & XDP_FLAGS_REPLACE)) { + NL_SET_ERR_MSG(extack, "XDP_FLAGS_REPLACE is not specified"); + return -EINVAL; + } + + mode = dev_xdp_mode(dev, flags); + /* can't replace attached link */ + if (dev_xdp_link(dev, mode)) { + NL_SET_ERR_MSG(extack, "Can't replace active BPF XDP link"); + return -EBUSY; + } + + /* don't allow if an upper device already has a program */ + netdev_for_each_upper_dev_rcu(dev, upper, iter) { + if (dev_xdp_prog_count(upper) > 0) { + NL_SET_ERR_MSG(extack, "Cannot attach when an upper device already has a program"); + return -EEXIST; + } + } + + cur_prog = dev_xdp_prog(dev, mode); + /* can't replace attached prog with link */ + if (link && cur_prog) { + NL_SET_ERR_MSG(extack, "Can't replace active XDP program with BPF link"); + return -EBUSY; + } + if ((flags & XDP_FLAGS_REPLACE) && cur_prog != old_prog) { + NL_SET_ERR_MSG(extack, "Active program does not match expected"); + return -EEXIST; + } + + /* put effective new program into new_prog */ + if (link) + new_prog = link->link.prog; + + if (new_prog) { + bool offload = mode == XDP_MODE_HW; + enum bpf_xdp_mode other_mode = mode == XDP_MODE_SKB + ? 
XDP_MODE_DRV : XDP_MODE_SKB; + + if ((flags & XDP_FLAGS_UPDATE_IF_NOEXIST) && cur_prog) { + NL_SET_ERR_MSG(extack, "XDP program already attached"); + return -EBUSY; + } + if (!offload && dev_xdp_prog(dev, other_mode)) { + NL_SET_ERR_MSG(extack, "Native and generic XDP can't be active at the same time"); + return -EEXIST; + } + if (!offload && bpf_prog_is_dev_bound(new_prog->aux)) { + NL_SET_ERR_MSG(extack, "Using device-bound program without HW_MODE flag is not supported"); + return -EINVAL; + } + if (new_prog->expected_attach_type == BPF_XDP_DEVMAP) { + NL_SET_ERR_MSG(extack, "BPF_XDP_DEVMAP programs can not be attached to a device"); + return -EINVAL; + } + if (new_prog->expected_attach_type == BPF_XDP_CPUMAP) { + NL_SET_ERR_MSG(extack, "BPF_XDP_CPUMAP programs can not be attached to a device"); + return -EINVAL; + } + } + + /* don't call drivers if the effective program didn't change */ + if (new_prog != cur_prog) { + bpf_op = dev_xdp_bpf_op(dev, mode); + if (!bpf_op) { + NL_SET_ERR_MSG(extack, "Underlying driver does not support XDP in native mode"); + return -EOPNOTSUPP; + } + + err = dev_xdp_install(dev, mode, bpf_op, extack, flags, new_prog); + if (err) + return err; + } + + if (link) + dev_xdp_set_link(dev, mode, link); + else + dev_xdp_set_prog(dev, mode, new_prog); + if (cur_prog) + bpf_prog_put(cur_prog); + + return 0; +} + +static int dev_xdp_attach_link(struct net_device *dev, + struct netlink_ext_ack *extack, + struct bpf_xdp_link *link) +{ + return dev_xdp_attach(dev, extack, link, NULL, NULL, link->flags); +} + +static int dev_xdp_detach_link(struct net_device *dev, + struct netlink_ext_ack *extack, + struct bpf_xdp_link *link) +{ + enum bpf_xdp_mode mode; + bpf_op_t bpf_op; + + ASSERT_RTNL(); + + mode = dev_xdp_mode(dev, link->flags); + if (dev_xdp_link(dev, mode) != link) + return -EINVAL; + + bpf_op = dev_xdp_bpf_op(dev, mode); + WARN_ON(dev_xdp_install(dev, mode, bpf_op, NULL, 0, NULL)); + dev_xdp_set_link(dev, mode, NULL); + return 0; +} + +static void bpf_xdp_link_release(struct bpf_link *link) +{ + struct bpf_xdp_link *xdp_link = container_of(link, struct bpf_xdp_link, link); + + rtnl_lock(); + + /* if racing with net_device's tear down, xdp_link->dev might be + * already NULL, in which case link was already auto-detached + */ + if (xdp_link->dev) { + WARN_ON(dev_xdp_detach_link(xdp_link->dev, NULL, xdp_link)); + xdp_link->dev = NULL; + } + + rtnl_unlock(); +} + +static int bpf_xdp_link_detach(struct bpf_link *link) +{ + bpf_xdp_link_release(link); + return 0; +} + +static void bpf_xdp_link_dealloc(struct bpf_link *link) +{ + struct bpf_xdp_link *xdp_link = container_of(link, struct bpf_xdp_link, link); + + kfree(xdp_link); +} + +static void bpf_xdp_link_show_fdinfo(const struct bpf_link *link, + struct seq_file *seq) +{ + struct bpf_xdp_link *xdp_link = container_of(link, struct bpf_xdp_link, link); + u32 ifindex = 0; + + rtnl_lock(); + if (xdp_link->dev) + ifindex = xdp_link->dev->ifindex; + rtnl_unlock(); + + seq_printf(seq, "ifindex:\t%u\n", ifindex); +} + +static int bpf_xdp_link_fill_link_info(const struct bpf_link *link, + struct bpf_link_info *info) +{ + struct bpf_xdp_link *xdp_link = container_of(link, struct bpf_xdp_link, link); + u32 ifindex = 0; + + rtnl_lock(); + if (xdp_link->dev) + ifindex = xdp_link->dev->ifindex; + rtnl_unlock(); + + info->xdp.ifindex = ifindex; + return 0; +} + +static int bpf_xdp_link_update(struct bpf_link *link, struct bpf_prog *new_prog, + struct bpf_prog *old_prog) +{ + struct bpf_xdp_link *xdp_link = container_of(link, struct 
bpf_xdp_link, link); + enum bpf_xdp_mode mode; + bpf_op_t bpf_op; + int err = 0; + + rtnl_lock(); + + /* link might have been auto-released already, so fail */ + if (!xdp_link->dev) { + err = -ENOLINK; + goto out_unlock; + } + + if (old_prog && link->prog != old_prog) { + err = -EPERM; + goto out_unlock; + } + old_prog = link->prog; + if (old_prog->type != new_prog->type || + old_prog->expected_attach_type != new_prog->expected_attach_type) { + err = -EINVAL; + goto out_unlock; + } + + if (old_prog == new_prog) { + /* no-op, don't disturb drivers */ + bpf_prog_put(new_prog); + goto out_unlock; + } + + mode = dev_xdp_mode(xdp_link->dev, xdp_link->flags); + bpf_op = dev_xdp_bpf_op(xdp_link->dev, mode); + err = dev_xdp_install(xdp_link->dev, mode, bpf_op, NULL, + xdp_link->flags, new_prog); + if (err) + goto out_unlock; + + old_prog = xchg(&link->prog, new_prog); + bpf_prog_put(old_prog); + +out_unlock: + rtnl_unlock(); + return err; +} + +static const struct bpf_link_ops bpf_xdp_link_lops = { + .release = bpf_xdp_link_release, + .dealloc = bpf_xdp_link_dealloc, + .detach = bpf_xdp_link_detach, + .show_fdinfo = bpf_xdp_link_show_fdinfo, + .fill_link_info = bpf_xdp_link_fill_link_info, + .update_prog = bpf_xdp_link_update, +}; + +int bpf_xdp_link_attach(const union bpf_attr *attr, struct bpf_prog *prog) +{ + struct net *net = current->nsproxy->net_ns; + struct bpf_link_primer link_primer; + struct bpf_xdp_link *link; + struct net_device *dev; + int err, fd; + + rtnl_lock(); + dev = dev_get_by_index(net, attr->link_create.target_ifindex); + if (!dev) { + rtnl_unlock(); + return -EINVAL; + } + + link = kzalloc(sizeof(*link), GFP_USER); + if (!link) { + err = -ENOMEM; + goto unlock; + } + + bpf_link_init(&link->link, BPF_LINK_TYPE_XDP, &bpf_xdp_link_lops, prog); + link->dev = dev; + link->flags = attr->link_create.flags; + + err = bpf_link_prime(&link->link, &link_primer); + if (err) { + kfree(link); + goto unlock; + } + + err = dev_xdp_attach_link(dev, NULL, link); + rtnl_unlock(); + + if (err) { + link->dev = NULL; + bpf_link_cleanup(&link_primer); + goto out_put_dev; + } + + fd = bpf_link_settle(&link_primer); + /* link itself doesn't hold dev's refcnt to not complicate shutdown */ + dev_put(dev); + return fd; + +unlock: + rtnl_unlock(); + +out_put_dev: + dev_put(dev); + return err; +} + +/** + * dev_change_xdp_fd - set or clear a bpf program for a device rx path + * @dev: device + * @extack: netlink extended ack + * @fd: new program fd or negative value to clear + * @expected_fd: old program fd that userspace expects to replace or clear + * @flags: xdp-related flags + * + * Set or clear a bpf program for a device + */ +int dev_change_xdp_fd(struct net_device *dev, struct netlink_ext_ack *extack, + int fd, int expected_fd, u32 flags) +{ + enum bpf_xdp_mode mode = dev_xdp_mode(dev, flags); + struct bpf_prog *new_prog = NULL, *old_prog = NULL; + int err; + + ASSERT_RTNL(); + + if (fd >= 0) { + new_prog = bpf_prog_get_type_dev(fd, BPF_PROG_TYPE_XDP, + mode != XDP_MODE_SKB); + if (IS_ERR(new_prog)) + return PTR_ERR(new_prog); + } + + if (expected_fd >= 0) { + old_prog = bpf_prog_get_type_dev(expected_fd, BPF_PROG_TYPE_XDP, + mode != XDP_MODE_SKB); + if (IS_ERR(old_prog)) { + err = PTR_ERR(old_prog); + old_prog = NULL; + goto err_out; + } + } + + err = dev_xdp_attach(dev, extack, NULL, new_prog, old_prog, flags); + +err_out: + if (err && new_prog) + bpf_prog_put(new_prog); + if (old_prog) + bpf_prog_put(old_prog); + return err; +} + +/** + * dev_new_index - allocate an ifindex + * @net: the 
applicable net namespace + * + * Returns a suitable unique value for a new device interface + * number. The caller must hold the rtnl semaphore or the + * dev_base_lock to be sure it remains unique. + */ +static int dev_new_index(struct net *net) +{ + int ifindex = net->ifindex; + + for (;;) { + if (++ifindex <= 0) + ifindex = 1; + if (!__dev_get_by_index(net, ifindex)) + return net->ifindex = ifindex; + } +} + +/* Delayed registration/unregisteration */ +LIST_HEAD(net_todo_list); +DECLARE_WAIT_QUEUE_HEAD(netdev_unregistering_wq); + +static void net_set_todo(struct net_device *dev) +{ + list_add_tail(&dev->todo_list, &net_todo_list); + atomic_inc(&dev_net(dev)->dev_unreg_count); +} + +static netdev_features_t netdev_sync_upper_features(struct net_device *lower, + struct net_device *upper, netdev_features_t features) +{ + netdev_features_t upper_disables = NETIF_F_UPPER_DISABLES; + netdev_features_t feature; + int feature_bit; + + for_each_netdev_feature(upper_disables, feature_bit) { + feature = __NETIF_F_BIT(feature_bit); + if (!(upper->wanted_features & feature) + && (features & feature)) { + netdev_dbg(lower, "Dropping feature %pNF, upper dev %s has it off.\n", + &feature, upper->name); + features &= ~feature; + } + } + + return features; +} + +static void netdev_sync_lower_features(struct net_device *upper, + struct net_device *lower, netdev_features_t features) +{ + netdev_features_t upper_disables = NETIF_F_UPPER_DISABLES; + netdev_features_t feature; + int feature_bit; + + for_each_netdev_feature(upper_disables, feature_bit) { + feature = __NETIF_F_BIT(feature_bit); + if (!(features & feature) && (lower->features & feature)) { + netdev_dbg(upper, "Disabling feature %pNF on lower dev %s.\n", + &feature, lower->name); + lower->wanted_features &= ~feature; + __netdev_update_features(lower); + + if (unlikely(lower->features & feature)) + netdev_WARN(upper, "failed to disable %pNF on %s!\n", + &feature, lower->name); + else + netdev_features_change(lower); + } + } +} + +static netdev_features_t netdev_fix_features(struct net_device *dev, + netdev_features_t features) +{ + /* Fix illegal checksum combinations */ + if ((features & NETIF_F_HW_CSUM) && + (features & (NETIF_F_IP_CSUM|NETIF_F_IPV6_CSUM))) { + netdev_warn(dev, "mixed HW and IP checksum settings.\n"); + features &= ~(NETIF_F_IP_CSUM|NETIF_F_IPV6_CSUM); + } + + /* TSO requires that SG is present as well. */ + if ((features & NETIF_F_ALL_TSO) && !(features & NETIF_F_SG)) { + netdev_dbg(dev, "Dropping TSO features since no SG feature.\n"); + features &= ~NETIF_F_ALL_TSO; + } + + if ((features & NETIF_F_TSO) && !(features & NETIF_F_HW_CSUM) && + !(features & NETIF_F_IP_CSUM)) { + netdev_dbg(dev, "Dropping TSO features since no CSUM feature.\n"); + features &= ~NETIF_F_TSO; + features &= ~NETIF_F_TSO_ECN; + } + + if ((features & NETIF_F_TSO6) && !(features & NETIF_F_HW_CSUM) && + !(features & NETIF_F_IPV6_CSUM)) { + netdev_dbg(dev, "Dropping TSO6 features since no CSUM feature.\n"); + features &= ~NETIF_F_TSO6; + } + + /* TSO with IPv4 ID mangling requires IPv4 TSO be enabled */ + if ((features & NETIF_F_TSO_MANGLEID) && !(features & NETIF_F_TSO)) + features &= ~NETIF_F_TSO_MANGLEID; + + /* TSO ECN requires that TSO is present as well. */ + if ((features & NETIF_F_ALL_TSO) == NETIF_F_TSO_ECN) + features &= ~NETIF_F_TSO_ECN; + + /* Software GSO depends on SG. 
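[Editorial aside, not part of the patch] dev_new_index() above searches the ifindex space starting just past net->ifindex, wrapping back to 1 when the increment goes non-positive and skipping indexes that are still in use. A minimal userspace model of that search is sketched below; in_use() and its fixed table are toy stand-ins for __dev_get_by_index(), not kernel code.

#include <stdbool.h>
#include <stdio.h>

/* Toy stand-in for __dev_get_by_index(): a fixed set of taken indexes. */
static bool in_use(int ifindex)
{
	static const int taken[] = { 1, 2, 3, 5 };

	for (unsigned int i = 0; i < sizeof(taken) / sizeof(taken[0]); i++)
		if (taken[i] == ifindex)
			return true;
	return false;
}

/* Same shape as dev_new_index(): advance, wrap to 1, skip in-use slots.
 * (The kernel additionally relies on the increment being allowed to wrap;
 * this demo simply never reaches INT_MAX.) */
static int new_index(int last)
{
	int ifindex = last;

	for (;;) {
		if (++ifindex <= 0)
			ifindex = 1;
		if (!in_use(ifindex))
			return ifindex;
	}
}

int main(void)
{
	printf("next free after 2 is %d\n", new_index(2));   /* 3 is taken -> 4 */
	printf("next free after 4 is %d\n", new_index(4));   /* 5 is taken -> 6 */
	return 0;
}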
*/ + if ((features & NETIF_F_GSO) && !(features & NETIF_F_SG)) { + netdev_dbg(dev, "Dropping NETIF_F_GSO since no SG feature.\n"); + features &= ~NETIF_F_GSO; + } + + /* GSO partial features require GSO partial be set */ + if ((features & dev->gso_partial_features) && + !(features & NETIF_F_GSO_PARTIAL)) { + netdev_dbg(dev, + "Dropping partially supported GSO features since no GSO partial.\n"); + features &= ~dev->gso_partial_features; + } + + if (!(features & NETIF_F_RXCSUM)) { + /* NETIF_F_GRO_HW implies doing RXCSUM since every packet + * successfully merged by hardware must also have the + * checksum verified by hardware. If the user does not + * want to enable RXCSUM, logically, we should disable GRO_HW. + */ + if (features & NETIF_F_GRO_HW) { + netdev_dbg(dev, "Dropping NETIF_F_GRO_HW since no RXCSUM feature.\n"); + features &= ~NETIF_F_GRO_HW; + } + } + + /* LRO/HW-GRO features cannot be combined with RX-FCS */ + if (features & NETIF_F_RXFCS) { + if (features & NETIF_F_LRO) { + netdev_dbg(dev, "Dropping LRO feature since RX-FCS is requested.\n"); + features &= ~NETIF_F_LRO; + } + + if (features & NETIF_F_GRO_HW) { + netdev_dbg(dev, "Dropping HW-GRO feature since RX-FCS is requested.\n"); + features &= ~NETIF_F_GRO_HW; + } + } + + if ((features & NETIF_F_GRO_HW) && (features & NETIF_F_LRO)) { + netdev_dbg(dev, "Dropping LRO feature since HW-GRO is requested.\n"); + features &= ~NETIF_F_LRO; + } + + if (features & NETIF_F_HW_TLS_TX) { + bool ip_csum = (features & (NETIF_F_IP_CSUM | NETIF_F_IPV6_CSUM)) == + (NETIF_F_IP_CSUM | NETIF_F_IPV6_CSUM); + bool hw_csum = features & NETIF_F_HW_CSUM; + + if (!ip_csum && !hw_csum) { + netdev_dbg(dev, "Dropping TLS TX HW offload feature since no CSUM feature.\n"); + features &= ~NETIF_F_HW_TLS_TX; + } + } + + if ((features & NETIF_F_HW_TLS_RX) && !(features & NETIF_F_RXCSUM)) { + netdev_dbg(dev, "Dropping TLS RX HW offload feature since no RXCSUM feature.\n"); + features &= ~NETIF_F_HW_TLS_RX; + } + + return features; +} + +int __netdev_update_features(struct net_device *dev) +{ + struct net_device *upper, *lower; + netdev_features_t features; + struct list_head *iter; + int err = -1; + + ASSERT_RTNL(); + + features = netdev_get_wanted_features(dev); + + if (dev->netdev_ops->ndo_fix_features) + features = dev->netdev_ops->ndo_fix_features(dev, features); + + /* driver might be less strict about feature dependencies */ + features = netdev_fix_features(dev, features); + + /* some features can't be enabled if they're off on an upper device */ + netdev_for_each_upper_dev_rcu(dev, upper, iter) + features = netdev_sync_upper_features(dev, upper, features); + + if (dev->features == features) + goto sync_lower; + + netdev_dbg(dev, "Features changed: %pNF -> %pNF\n", + &dev->features, &features); + + if (dev->netdev_ops->ndo_set_features) + err = dev->netdev_ops->ndo_set_features(dev, features); + else + err = 0; + + if (unlikely(err < 0)) { + netdev_err(dev, + "set_features() failed (%d); wanted %pNF, left %pNF\n", + err, &features, &dev->features); + /* return non-0 since some features might have changed and + * it's better to fire a spurious notification than miss it + */ + return -1; + } + +sync_lower: + /* some features must be disabled on lower devices when disabled + * on an upper device (think: bonding master or bridge) + */ + netdev_for_each_lower_dev(dev, lower, iter) + netdev_sync_lower_features(dev, lower, features); + + if (!err) { + netdev_features_t diff = features ^ dev->features; + + if (diff & NETIF_F_RX_UDP_TUNNEL_PORT) { + /* 
udp_tunnel_{get,drop}_rx_info both need + * NETIF_F_RX_UDP_TUNNEL_PORT enabled on the + * device, or they won't do anything. + * Thus we need to update dev->features + * *before* calling udp_tunnel_get_rx_info, + * but *after* calling udp_tunnel_drop_rx_info. + */ + if (features & NETIF_F_RX_UDP_TUNNEL_PORT) { + dev->features = features; + udp_tunnel_get_rx_info(dev); + } else { + udp_tunnel_drop_rx_info(dev); + } + } + + if (diff & NETIF_F_HW_VLAN_CTAG_FILTER) { + if (features & NETIF_F_HW_VLAN_CTAG_FILTER) { + dev->features = features; + err |= vlan_get_rx_ctag_filter_info(dev); + } else { + vlan_drop_rx_ctag_filter_info(dev); + } + } + + if (diff & NETIF_F_HW_VLAN_STAG_FILTER) { + if (features & NETIF_F_HW_VLAN_STAG_FILTER) { + dev->features = features; + err |= vlan_get_rx_stag_filter_info(dev); + } else { + vlan_drop_rx_stag_filter_info(dev); + } + } + + dev->features = features; + } + + return err < 0 ? 0 : 1; +} + +/** + * netdev_update_features - recalculate device features + * @dev: the device to check + * + * Recalculate dev->features set and send notifications if it + * has changed. Should be called after driver or hardware dependent + * conditions might have changed that influence the features. + */ +void netdev_update_features(struct net_device *dev) +{ + if (__netdev_update_features(dev)) + netdev_features_change(dev); +} +EXPORT_SYMBOL(netdev_update_features); + +/** + * netdev_change_features - recalculate device features + * @dev: the device to check + * + * Recalculate dev->features set and send notifications even + * if they have not changed. Should be called instead of + * netdev_update_features() if also dev->vlan_features might + * have changed to allow the changes to be propagated to stacked + * VLAN devices. + */ +void netdev_change_features(struct net_device *dev) +{ + __netdev_update_features(dev); + netdev_features_change(dev); +} +EXPORT_SYMBOL(netdev_change_features); + +/** + * netif_stacked_transfer_operstate - transfer operstate + * @rootdev: the root or lower level device to transfer state from + * @dev: the device to transfer operstate to + * + * Transfer operational state from root to device. This is normally + * called when a stacking relationship exists between the root + * device and the device(a leaf device). 
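[Editorial aside, not part of the patch] netdev_fix_features() and __netdev_update_features() above reduce the user's wanted feature set to something internally consistent (TSO needs SG, TSO ECN needs TSO, hardware GRO needs RXCSUM, and so on), and only call the driver's ndo_set_features() hook when the result actually differs from the current set. The sketch below models that "clamp against dependency rules, then compare" flow; the flag names echo the kernel ones but the numeric values and the three rules chosen are illustrative assumptions, not the full rule set.

#include <stdint.h>
#include <stdio.h>

/* Illustrative bit values only; the real NETIF_F_* bits live in netdev_features.h. */
#define F_SG        (1u << 0)
#define F_TSO       (1u << 1)
#define F_TSO_ECN   (1u << 2)
#define F_RXCSUM    (1u << 3)
#define F_GRO_HW    (1u << 4)

/* Same spirit as netdev_fix_features(): drop bits whose dependencies are off. */
static uint32_t fix_features(uint32_t f)
{
	if ((f & F_TSO) && !(f & F_SG))
		f &= ~(F_TSO | F_TSO_ECN);      /* TSO requires scatter-gather */
	if ((f & F_TSO_ECN) && !(f & F_TSO))
		f &= ~F_TSO_ECN;                /* TSO ECN requires TSO */
	if ((f & F_GRO_HW) && !(f & F_RXCSUM))
		f &= ~F_GRO_HW;                 /* HW GRO needs RX checksumming */
	return f;
}

int main(void)
{
	uint32_t current = F_SG | F_TSO;
	uint32_t wanted  = F_TSO | F_TSO_ECN | F_GRO_HW;   /* note: no SG, no RXCSUM */
	uint32_t fixed   = fix_features(wanted);

	printf("wanted 0x%x -> fixed 0x%x\n", wanted, fixed);
	if (fixed == current)
		printf("nothing to push to the driver\n");
	else
		printf("would call the ndo_set_features()-like hook with 0x%x\n", fixed);
	return 0;
}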
+ */ +void netif_stacked_transfer_operstate(const struct net_device *rootdev, + struct net_device *dev) +{ + if (rootdev->operstate == IF_OPER_DORMANT) + netif_dormant_on(dev); + else + netif_dormant_off(dev); + + if (rootdev->operstate == IF_OPER_TESTING) + netif_testing_on(dev); + else + netif_testing_off(dev); + + if (netif_carrier_ok(rootdev)) + netif_carrier_on(dev); + else + netif_carrier_off(dev); +} +EXPORT_SYMBOL(netif_stacked_transfer_operstate); + +static int netif_alloc_rx_queues(struct net_device *dev) +{ + unsigned int i, count = dev->num_rx_queues; + struct netdev_rx_queue *rx; + size_t sz = count * sizeof(*rx); + int err = 0; + + BUG_ON(count < 1); + + rx = kvzalloc(sz, GFP_KERNEL_ACCOUNT | __GFP_RETRY_MAYFAIL); + if (!rx) + return -ENOMEM; + + dev->_rx = rx; + + for (i = 0; i < count; i++) { + rx[i].dev = dev; + + /* XDP RX-queue setup */ + err = xdp_rxq_info_reg(&rx[i].xdp_rxq, dev, i, 0); + if (err < 0) + goto err_rxq_info; + } + return 0; + +err_rxq_info: + /* Rollback successful reg's and free other resources */ + while (i--) + xdp_rxq_info_unreg(&rx[i].xdp_rxq); + kvfree(dev->_rx); + dev->_rx = NULL; + return err; +} + +static void netif_free_rx_queues(struct net_device *dev) +{ + unsigned int i, count = dev->num_rx_queues; + + /* netif_alloc_rx_queues alloc failed, resources have been unreg'ed */ + if (!dev->_rx) + return; + + for (i = 0; i < count; i++) + xdp_rxq_info_unreg(&dev->_rx[i].xdp_rxq); + + kvfree(dev->_rx); +} + +static void netdev_init_one_queue(struct net_device *dev, + struct netdev_queue *queue, void *_unused) +{ + /* Initialize queue lock */ + spin_lock_init(&queue->_xmit_lock); + netdev_set_xmit_lockdep_class(&queue->_xmit_lock, dev->type); + queue->xmit_lock_owner = -1; + netdev_queue_numa_node_write(queue, NUMA_NO_NODE); + queue->dev = dev; +#ifdef CONFIG_BQL + dql_init(&queue->dql, HZ); +#endif +} + +static void netif_free_tx_queues(struct net_device *dev) +{ + kvfree(dev->_tx); +} + +static int netif_alloc_netdev_queues(struct net_device *dev) +{ + unsigned int count = dev->num_tx_queues; + struct netdev_queue *tx; + size_t sz = count * sizeof(*tx); + + if (count < 1 || count > 0xffff) + return -EINVAL; + + tx = kvzalloc(sz, GFP_KERNEL_ACCOUNT | __GFP_RETRY_MAYFAIL); + if (!tx) + return -ENOMEM; + + dev->_tx = tx; + + netdev_for_each_tx_queue(dev, netdev_init_one_queue, NULL); + spin_lock_init(&dev->tx_global_lock); + + return 0; +} + +void netif_tx_stop_all_queues(struct net_device *dev) +{ + unsigned int i; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *txq = netdev_get_tx_queue(dev, i); + + netif_tx_stop_queue(txq); + } +} +EXPORT_SYMBOL(netif_tx_stop_all_queues); + +/** + * register_netdevice() - register a network device + * @dev: device to register + * + * Take a prepared network device structure and make it externally accessible. + * A %NETDEV_REGISTER message is sent to the netdev notifier chain. + * Callers must hold the rtnl lock - you may want register_netdev() + * instead of this. + */ +int register_netdevice(struct net_device *dev) +{ + int ret; + struct net *net = dev_net(dev); + + BUILD_BUG_ON(sizeof(netdev_features_t) * BITS_PER_BYTE < + NETDEV_FEATURE_COUNT); + BUG_ON(dev_boot_phase); + ASSERT_RTNL(); + + might_sleep(); + + /* When net_device's are persistent, this will be fatal. 
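[Editorial aside, not part of the patch] netif_alloc_rx_queues() above registers one xdp_rxq per RX queue and, if a registration fails partway through, walks back over only the queues that did succeed before freeing the array. That "unwind exactly what you set up" loop (while (i--) ...) is a common kernel idiom; here is a userspace sketch of the same pattern with stand-in register/unregister helpers and a deliberately failing slot.

#include <stdio.h>
#include <stdlib.h>

struct rxq { int id; int registered; };

/* Stand-ins for xdp_rxq_info_reg()/unreg(); queue 3 is made to fail. */
static int fake_reg(struct rxq *q, unsigned int i)
{
	if (i == 3)
		return -1;
	q->id = i;
	q->registered = 1;
	return 0;
}

static void fake_unreg(struct rxq *q)
{
	printf("rolling back queue %d\n", q->id);
	q->registered = 0;
}

static int alloc_rx_queues(struct rxq **out, unsigned int count)
{
	struct rxq *rx = calloc(count, sizeof(*rx));
	unsigned int i;

	if (!rx)
		return -1;

	for (i = 0; i < count; i++)
		if (fake_reg(&rx[i], i) < 0)
			goto err;

	*out = rx;
	return 0;

err:
	/* Unwind only the registrations that actually succeeded. */
	while (i--)
		fake_unreg(&rx[i]);
	free(rx);
	return -1;
}

int main(void)
{
	struct rxq *rx;

	if (alloc_rx_queues(&rx, 8) < 0)
		printf("allocation failed, everything rolled back\n");
	return 0;
}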
*/ + BUG_ON(dev->reg_state != NETREG_UNINITIALIZED); + BUG_ON(!net); + + ret = ethtool_check_ops(dev->ethtool_ops); + if (ret) + return ret; + + spin_lock_init(&dev->addr_list_lock); + netdev_set_addr_lockdep_class(dev); + + ret = dev_get_valid_name(net, dev, dev->name); + if (ret < 0) + goto out; + + ret = -ENOMEM; + dev->name_node = netdev_name_node_head_alloc(dev); + if (!dev->name_node) + goto out; + + /* Init, if this function is available */ + if (dev->netdev_ops->ndo_init) { + ret = dev->netdev_ops->ndo_init(dev); + if (ret) { + if (ret > 0) + ret = -EIO; + goto err_free_name; + } + } + + if (((dev->hw_features | dev->features) & + NETIF_F_HW_VLAN_CTAG_FILTER) && + (!dev->netdev_ops->ndo_vlan_rx_add_vid || + !dev->netdev_ops->ndo_vlan_rx_kill_vid)) { + netdev_WARN(dev, "Buggy VLAN acceleration in driver!\n"); + ret = -EINVAL; + goto err_uninit; + } + + ret = -EBUSY; + if (!dev->ifindex) + dev->ifindex = dev_new_index(net); + else if (__dev_get_by_index(net, dev->ifindex)) + goto err_uninit; + + /* Transfer changeable features to wanted_features and enable + * software offloads (GSO and GRO). + */ + dev->hw_features |= (NETIF_F_SOFT_FEATURES | NETIF_F_SOFT_FEATURES_OFF); + dev->features |= NETIF_F_SOFT_FEATURES; + + if (dev->udp_tunnel_nic_info) { + dev->features |= NETIF_F_RX_UDP_TUNNEL_PORT; + dev->hw_features |= NETIF_F_RX_UDP_TUNNEL_PORT; + } + + dev->wanted_features = dev->features & dev->hw_features; + + if (!(dev->flags & IFF_LOOPBACK)) + dev->hw_features |= NETIF_F_NOCACHE_COPY; + + /* If IPv4 TCP segmentation offload is supported we should also + * allow the device to enable segmenting the frame with the option + * of ignoring a static IP ID value. This doesn't enable the + * feature itself but allows the user to enable it later. + */ + if (dev->hw_features & NETIF_F_TSO) + dev->hw_features |= NETIF_F_TSO_MANGLEID; + if (dev->vlan_features & NETIF_F_TSO) + dev->vlan_features |= NETIF_F_TSO_MANGLEID; + if (dev->mpls_features & NETIF_F_TSO) + dev->mpls_features |= NETIF_F_TSO_MANGLEID; + if (dev->hw_enc_features & NETIF_F_TSO) + dev->hw_enc_features |= NETIF_F_TSO_MANGLEID; + + /* Make NETIF_F_HIGHDMA inheritable to VLAN devices. + */ + dev->vlan_features |= NETIF_F_HIGHDMA; + + /* Make NETIF_F_SG inheritable to tunnel devices. + */ + dev->hw_enc_features |= NETIF_F_SG | NETIF_F_GSO_PARTIAL; + + /* Make NETIF_F_SG inheritable to MPLS. + */ + dev->mpls_features |= NETIF_F_SG; + + ret = call_netdevice_notifiers(NETDEV_POST_INIT, dev); + ret = notifier_to_errno(ret); + if (ret) + goto err_uninit; + + ret = netdev_register_kobject(dev); + write_lock(&dev_base_lock); + dev->reg_state = ret ? NETREG_UNREGISTERED : NETREG_REGISTERED; + write_unlock(&dev_base_lock); + if (ret) + goto err_uninit; + + __netdev_update_features(dev); + + /* + * Default initial state at registry is that the + * device is present. + */ + + set_bit(__LINK_STATE_PRESENT, &dev->state); + + linkwatch_init_dev(dev); + + dev_init_scheduler(dev); + + netdev_hold(dev, &dev->dev_registered_tracker, GFP_KERNEL); + list_netdevice(dev); + + add_device_randomness(dev->dev_addr, dev->addr_len); + + /* If the device has permanent device address, driver should + * set dev_addr and also addr_assign_type should be set to + * NET_ADDR_PERM (default value). + */ + if (dev->addr_assign_type == NET_ADDR_PERM) + memcpy(dev->perm_addr, dev->dev_addr, dev->addr_len); + + /* Notify protocols, that a new device appeared. 
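[Editorial aside, not part of the patch] register_netdevice() above threads every failure path through goto labels (err_free_name, err_uninit, with the cleanup bodies completing just below), so each label undoes only the steps that had already succeeded. A compressed userspace sketch of that staged-init pattern, with placeholder resources instead of real netdev state:

#include <stdio.h>
#include <stdlib.h>

struct fake_dev {
	char *name_node;
	char *kobject;
};

/* Each step that can fail jumps to a label that unwinds earlier steps only. */
static int register_fake_dev(struct fake_dev *dev, int fail_at)
{
	dev->name_node = malloc(32);
	if (!dev->name_node)
		goto out;

	if (fail_at == 1)
		goto err_free_name;            /* e.g. the init hook failed */

	dev->kobject = malloc(32);
	if (!dev->kobject)
		goto err_free_name;

	if (fail_at == 2)
		goto err_free_kobject;         /* e.g. a notifier vetoed registration */

	return 0;

err_free_kobject:
	free(dev->kobject);
err_free_name:
	free(dev->name_node);
out:
	return -1;
}

int main(void)
{
	struct fake_dev dev = { 0 };

	printf("fail_at=2 -> %d (both stages unwound)\n", register_fake_dev(&dev, 2));
	if (register_fake_dev(&dev, 0) == 0) {
		printf("registered\n");
		free(dev.kobject);
		free(dev.name_node);
	}
	return 0;
}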
*/ + ret = call_netdevice_notifiers(NETDEV_REGISTER, dev); + ret = notifier_to_errno(ret); + if (ret) { + /* Expect explicit free_netdev() on failure */ + dev->needs_free_netdev = false; + unregister_netdevice_queue(dev, NULL); + goto out; + } + /* + * Prevent userspace races by waiting until the network + * device is fully setup before sending notifications. + */ + if (!dev->rtnl_link_ops || + dev->rtnl_link_state == RTNL_LINK_INITIALIZED) + rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL); + +out: + return ret; + +err_uninit: + if (dev->netdev_ops->ndo_uninit) + dev->netdev_ops->ndo_uninit(dev); + if (dev->priv_destructor) + dev->priv_destructor(dev); +err_free_name: + netdev_name_node_free(dev->name_node); + goto out; +} +EXPORT_SYMBOL(register_netdevice); + +/** + * init_dummy_netdev - init a dummy network device for NAPI + * @dev: device to init + * + * This takes a network device structure and initialize the minimum + * amount of fields so it can be used to schedule NAPI polls without + * registering a full blown interface. This is to be used by drivers + * that need to tie several hardware interfaces to a single NAPI + * poll scheduler due to HW limitations. + */ +int init_dummy_netdev(struct net_device *dev) +{ + /* Clear everything. Note we don't initialize spinlocks + * are they aren't supposed to be taken by any of the + * NAPI code and this dummy netdev is supposed to be + * only ever used for NAPI polls + */ + memset(dev, 0, sizeof(struct net_device)); + + /* make sure we BUG if trying to hit standard + * register/unregister code path + */ + dev->reg_state = NETREG_DUMMY; + + /* NAPI wants this */ + INIT_LIST_HEAD(&dev->napi_list); + + /* a dummy interface is started by default */ + set_bit(__LINK_STATE_PRESENT, &dev->state); + set_bit(__LINK_STATE_START, &dev->state); + + /* napi_busy_loop stats accounting wants this */ + dev_net_set(dev, &init_net); + + /* Note : We dont allocate pcpu_refcnt for dummy devices, + * because users of this 'device' dont need to change + * its refcount. + */ + + return 0; +} +EXPORT_SYMBOL_GPL(init_dummy_netdev); + + +/** + * register_netdev - register a network device + * @dev: device to register + * + * Take a completed network device structure and add it to the kernel + * interfaces. A %NETDEV_REGISTER message is sent to the netdev notifier + * chain. 0 is returned on success. A negative errno code is returned + * on a failure to set up the device, or if the name is a duplicate. + * + * This is a wrapper around register_netdevice that takes the rtnl semaphore + * and expands the device name if you passed a format string to + * alloc_netdev. + */ +int register_netdev(struct net_device *dev) +{ + int err; + + if (rtnl_lock_killable()) + return -EINTR; + err = register_netdevice(dev); + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(register_netdev); + +int netdev_refcnt_read(const struct net_device *dev) +{ +#ifdef CONFIG_PCPU_DEV_REFCNT + int i, refcnt = 0; + + for_each_possible_cpu(i) + refcnt += *per_cpu_ptr(dev->pcpu_refcnt, i); + return refcnt; +#else + return refcount_read(&dev->dev_refcnt); +#endif +} +EXPORT_SYMBOL(netdev_refcnt_read); + +int netdev_unregister_timeout_secs __read_mostly = 10; + +#define WAIT_REFS_MIN_MSECS 1 +#define WAIT_REFS_MAX_MSECS 250 +/** + * netdev_wait_allrefs_any - wait until all references are gone. + * @list: list of net_devices to wait on + * + * This is called when unregistering network devices. 
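[Editorial aside, not part of the patch] netdev_refcnt_read() above either sums a per-CPU counter (the CONFIG_PCPU_DEV_REFCNT case) or reads a single refcount_t. A userspace model of the per-CPU flavour, with a plain array standing in for per_cpu_ptr():

#include <stdio.h>

#define NR_CPUS 4

struct fake_dev {
	/* one signed count per CPU; holds and releases may land on different
	 * CPUs, only the sum across all of them is meaningful */
	int pcpu_refcnt[NR_CPUS];
};

static int refcnt_read(const struct fake_dev *dev)
{
	int cpu, refcnt = 0;

	for (cpu = 0; cpu < NR_CPUS; cpu++)
		refcnt += dev->pcpu_refcnt[cpu];
	return refcnt;
}

int main(void)
{
	/* e.g. 3 holds taken on CPU0, 1 on CPU2, 2 releases done on CPU3 */
	struct fake_dev dev = { .pcpu_refcnt = { 3, 0, 1, -2 } };

	printf("effective refcount = %d\n", refcnt_read(&dev));   /* prints 2 */
	return 0;
}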
+ * + * Any protocol or device that holds a reference should register + * for netdevice notification, and cleanup and put back the + * reference if they receive an UNREGISTER event. + * We can get stuck here if buggy protocols don't correctly + * call dev_put. + */ +static struct net_device *netdev_wait_allrefs_any(struct list_head *list) +{ + unsigned long rebroadcast_time, warning_time; + struct net_device *dev; + int wait = 0; + + rebroadcast_time = warning_time = jiffies; + + list_for_each_entry(dev, list, todo_list) + if (netdev_refcnt_read(dev) == 1) + return dev; + + while (true) { + if (time_after(jiffies, rebroadcast_time + 1 * HZ)) { + rtnl_lock(); + + /* Rebroadcast unregister notification */ + list_for_each_entry(dev, list, todo_list) + call_netdevice_notifiers(NETDEV_UNREGISTER, dev); + + __rtnl_unlock(); + rcu_barrier(); + rtnl_lock(); + + list_for_each_entry(dev, list, todo_list) + if (test_bit(__LINK_STATE_LINKWATCH_PENDING, + &dev->state)) { + /* We must not have linkwatch events + * pending on unregister. If this + * happens, we simply run the queue + * unscheduled, resulting in a noop + * for this device. + */ + linkwatch_run_queue(); + break; + } + + __rtnl_unlock(); + + rebroadcast_time = jiffies; + } + + if (!wait) { + rcu_barrier(); + wait = WAIT_REFS_MIN_MSECS; + } else { + msleep(wait); + wait = min(wait << 1, WAIT_REFS_MAX_MSECS); + } + + list_for_each_entry(dev, list, todo_list) + if (netdev_refcnt_read(dev) == 1) + return dev; + + if (time_after(jiffies, warning_time + + READ_ONCE(netdev_unregister_timeout_secs) * HZ)) { + list_for_each_entry(dev, list, todo_list) { + pr_emerg("unregister_netdevice: waiting for %s to become free. Usage count = %d\n", + dev->name, netdev_refcnt_read(dev)); + ref_tracker_dir_print(&dev->refcnt_tracker, 10); + } + + warning_time = jiffies; + } + } +} + +/* The sequence is: + * + * rtnl_lock(); + * ... + * register_netdevice(x1); + * register_netdevice(x2); + * ... + * unregister_netdevice(y1); + * unregister_netdevice(y2); + * ... + * rtnl_unlock(); + * free_netdev(y1); + * free_netdev(y2); + * + * We are invoked by rtnl_unlock(). + * This allows us to deal with problems: + * 1) We can delete sysfs objects which invoke hotplug + * without deadlocking with linkwatch via keventd. + * 2) Since we run with the RTNL semaphore not held, we can sleep + * safely in order to wait for the netdev refcnt to drop to zero. + * + * We must not return until all unregister events added during + * the interval the lock was held have been completed. 
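[Editorial aside, not part of the patch] netdev_wait_allrefs_any() above polls the refcounts with an exponential backoff between WAIT_REFS_MIN_MSECS and WAIT_REFS_MAX_MSECS, rebroadcasting the unregister notification about once a second and warning after netdev_unregister_timeout_secs. The backoff part in isolation, as a userspace sketch: usleep() stands in for msleep(), and the "last reference dropped" condition is simulated.

#include <stdio.h>
#include <unistd.h>

#define WAIT_MIN_MSECS 1
#define WAIT_MAX_MSECS 250

static int refs_gone(int iteration)
{
	return iteration >= 7;   /* pretend the last reference drops on check 7 */
}

int main(void)
{
	int wait = 0, iteration = 0;

	while (!refs_gone(iteration++)) {
		if (!wait) {
			/* first pass: no sleep, just re-check (the kernel does an
			 * rcu_barrier() here before starting to sleep) */
			wait = WAIT_MIN_MSECS;
			continue;
		}
		usleep(wait * 1000);
		printf("waited %d ms\n", wait);
		wait = wait * 2 > WAIT_MAX_MSECS ? WAIT_MAX_MSECS : wait * 2;
	}
	printf("all references gone after %d checks\n", iteration);
	return 0;
}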
+ */ +void netdev_run_todo(void) +{ + struct net_device *dev, *tmp; + struct list_head list; +#ifdef CONFIG_LOCKDEP + struct list_head unlink_list; + + list_replace_init(&net_unlink_list, &unlink_list); + + while (!list_empty(&unlink_list)) { + struct net_device *dev = list_first_entry(&unlink_list, + struct net_device, + unlink_list); + list_del_init(&dev->unlink_list); + dev->nested_level = dev->lower_level - 1; + } +#endif + + /* Snapshot list, allow later requests */ + list_replace_init(&net_todo_list, &list); + + __rtnl_unlock(); + + /* Wait for rcu callbacks to finish before next phase */ + if (!list_empty(&list)) + rcu_barrier(); + + list_for_each_entry_safe(dev, tmp, &list, todo_list) { + if (unlikely(dev->reg_state != NETREG_UNREGISTERING)) { + netdev_WARN(dev, "run_todo but not unregistering\n"); + list_del(&dev->todo_list); + continue; + } + + write_lock(&dev_base_lock); + dev->reg_state = NETREG_UNREGISTERED; + write_unlock(&dev_base_lock); + linkwatch_forget_dev(dev); + } + + while (!list_empty(&list)) { + dev = netdev_wait_allrefs_any(&list); + list_del(&dev->todo_list); + + /* paranoia */ + BUG_ON(netdev_refcnt_read(dev) != 1); + BUG_ON(!list_empty(&dev->ptype_all)); + BUG_ON(!list_empty(&dev->ptype_specific)); + WARN_ON(rcu_access_pointer(dev->ip_ptr)); + WARN_ON(rcu_access_pointer(dev->ip6_ptr)); + + if (dev->priv_destructor) + dev->priv_destructor(dev); + if (dev->needs_free_netdev) + free_netdev(dev); + + if (atomic_dec_and_test(&dev_net(dev)->dev_unreg_count)) + wake_up(&netdev_unregistering_wq); + + /* Free network device */ + kobject_put(&dev->dev.kobj); + } +} + +/* Convert net_device_stats to rtnl_link_stats64. rtnl_link_stats64 has + * all the same fields in the same order as net_device_stats, with only + * the type differing, but rtnl_link_stats64 may have additional fields + * at the end for newer counters. + */ +void netdev_stats_to_stats64(struct rtnl_link_stats64 *stats64, + const struct net_device_stats *netdev_stats) +{ + size_t i, n = sizeof(*netdev_stats) / sizeof(atomic_long_t); + const atomic_long_t *src = (atomic_long_t *)netdev_stats; + u64 *dst = (u64 *)stats64; + + BUILD_BUG_ON(n > sizeof(*stats64) / sizeof(u64)); + for (i = 0; i < n; i++) + dst[i] = (unsigned long)atomic_long_read(&src[i]); + /* zero out counters that only exist in rtnl_link_stats64 */ + memset((char *)stats64 + n * sizeof(u64), 0, + sizeof(*stats64) - n * sizeof(u64)); +} +EXPORT_SYMBOL(netdev_stats_to_stats64); + +struct net_device_core_stats __percpu *netdev_core_stats_alloc(struct net_device *dev) +{ + struct net_device_core_stats __percpu *p; + + p = alloc_percpu_gfp(struct net_device_core_stats, + GFP_ATOMIC | __GFP_NOWARN); + + if (p && cmpxchg(&dev->core_stats, NULL, p)) + free_percpu(p); + + /* This READ_ONCE() pairs with the cmpxchg() above */ + return READ_ONCE(dev->core_stats); +} +EXPORT_SYMBOL(netdev_core_stats_alloc); + +/** + * dev_get_stats - get network device statistics + * @dev: device to get statistics from + * @storage: place to store stats + * + * Get network statistics from device. Return @storage. + * The device driver may provide its own method by setting + * dev->netdev_ops->get_stats64 or dev->netdev_ops->get_stats; + * otherwise the internal statistics structure is used. 
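[Editorial aside, not part of the patch] netdev_stats_to_stats64() above depends on rtnl_link_stats64 starting with the same fields, in the same order, as net_device_stats, so it can widen the common prefix element by element and zero whatever extra counters only exist in the 64-bit layout. The shape of that conversion, with two small illustrative structs rather than the real ones:

#include <stdio.h>
#include <string.h>

/* Illustrative layouts: identical leading fields, wider type, one extra tail field. */
struct old_stats {
	unsigned long rx_packets;
	unsigned long tx_packets;
	unsigned long rx_errors;
};

struct new_stats64 {
	unsigned long long rx_packets;
	unsigned long long tx_packets;
	unsigned long long rx_errors;
	unsigned long long rx_nohandler;   /* only exists in the 64-bit variant */
};

static void stats_to_stats64(struct new_stats64 *dst64, const struct old_stats *src)
{
	size_t i, n = sizeof(*src) / sizeof(unsigned long);
	const unsigned long *s = (const unsigned long *)src;
	unsigned long long *d = (unsigned long long *)dst64;

	for (i = 0; i < n; i++)
		d[i] = s[i];
	/* zero the counters that have no counterpart in the old layout */
	memset((char *)dst64 + n * sizeof(*d), 0, sizeof(*dst64) - n * sizeof(*d));
}

int main(void)
{
	struct old_stats old = { .rx_packets = 10, .tx_packets = 7, .rx_errors = 1 };
	struct new_stats64 new64;

	stats_to_stats64(&new64, &old);
	printf("rx=%llu tx=%llu err=%llu nohandler=%llu\n",
	       new64.rx_packets, new64.tx_packets, new64.rx_errors, new64.rx_nohandler);
	return 0;
}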
+ */ +struct rtnl_link_stats64 *dev_get_stats(struct net_device *dev, + struct rtnl_link_stats64 *storage) +{ + const struct net_device_ops *ops = dev->netdev_ops; + const struct net_device_core_stats __percpu *p; + + if (ops->ndo_get_stats64) { + memset(storage, 0, sizeof(*storage)); + ops->ndo_get_stats64(dev, storage); + } else if (ops->ndo_get_stats) { + netdev_stats_to_stats64(storage, ops->ndo_get_stats(dev)); + } else { + netdev_stats_to_stats64(storage, &dev->stats); + } + + /* This READ_ONCE() pairs with the write in netdev_core_stats_alloc() */ + p = READ_ONCE(dev->core_stats); + if (p) { + const struct net_device_core_stats *core_stats; + int i; + + for_each_possible_cpu(i) { + core_stats = per_cpu_ptr(p, i); + storage->rx_dropped += READ_ONCE(core_stats->rx_dropped); + storage->tx_dropped += READ_ONCE(core_stats->tx_dropped); + storage->rx_nohandler += READ_ONCE(core_stats->rx_nohandler); + storage->rx_otherhost_dropped += READ_ONCE(core_stats->rx_otherhost_dropped); + } + } + return storage; +} +EXPORT_SYMBOL(dev_get_stats); + +/** + * dev_fetch_sw_netstats - get per-cpu network device statistics + * @s: place to store stats + * @netstats: per-cpu network stats to read from + * + * Read per-cpu network statistics and populate the related fields in @s. + */ +void dev_fetch_sw_netstats(struct rtnl_link_stats64 *s, + const struct pcpu_sw_netstats __percpu *netstats) +{ + int cpu; + + for_each_possible_cpu(cpu) { + u64 rx_packets, rx_bytes, tx_packets, tx_bytes; + const struct pcpu_sw_netstats *stats; + unsigned int start; + + stats = per_cpu_ptr(netstats, cpu); + do { + start = u64_stats_fetch_begin_irq(&stats->syncp); + rx_packets = u64_stats_read(&stats->rx_packets); + rx_bytes = u64_stats_read(&stats->rx_bytes); + tx_packets = u64_stats_read(&stats->tx_packets); + tx_bytes = u64_stats_read(&stats->tx_bytes); + } while (u64_stats_fetch_retry_irq(&stats->syncp, start)); + + s->rx_packets += rx_packets; + s->rx_bytes += rx_bytes; + s->tx_packets += tx_packets; + s->tx_bytes += tx_bytes; + } +} +EXPORT_SYMBOL_GPL(dev_fetch_sw_netstats); + +/** + * dev_get_tstats64 - ndo_get_stats64 implementation + * @dev: device to get statistics from + * @s: place to store stats + * + * Populate @s from dev->stats and dev->tstats. Can be used as + * ndo_get_stats64() callback. 
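[Editorial aside, not part of the patch] dev_fetch_sw_netstats() above sums per-CPU counters, re-reading each CPU's block until the u64_stats_fetch_begin_irq()/u64_stats_fetch_retry_irq() sequence counter says the snapshot was consistent. The sketch below mimics that begin/retry shape with C11 atomics; it is deliberately simplified (the real helpers supply the barrier and preemption discipline needed for genuinely concurrent readers and writers) and the struct layout is an assumption for illustration only.

#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

/* Per-CPU counter block guarded by a seqcount, loosely modelled on
 * struct pcpu_sw_netstats plus its u64_stats_sync. */
struct cpu_stats {
	atomic_uint seq;
	uint64_t rx_packets;
	uint64_t rx_bytes;
};

static void writer_update(struct cpu_stats *s, uint64_t pkts, uint64_t bytes)
{
	atomic_fetch_add_explicit(&s->seq, 1, memory_order_release); /* odd: update in flight */
	s->rx_packets += pkts;
	s->rx_bytes += bytes;
	atomic_fetch_add_explicit(&s->seq, 1, memory_order_release); /* even: consistent */
}

static void reader_snapshot(struct cpu_stats *s, uint64_t *pkts, uint64_t *bytes)
{
	unsigned int start;

	do {            /* same shape as u64_stats_fetch_begin()/retry() */
		start = atomic_load_explicit(&s->seq, memory_order_acquire);
		*pkts = s->rx_packets;
		*bytes = s->rx_bytes;
	} while ((start & 1) ||
		 start != atomic_load_explicit(&s->seq, memory_order_acquire));
}

static struct cpu_stats percpu[2];        /* stand-in for the per-CPU allocation */

int main(void)
{
	uint64_t pkts, bytes, tot_pkts = 0, tot_bytes = 0;
	int cpu;

	writer_update(&percpu[0], 5, 5000);
	writer_update(&percpu[1], 3, 1200);

	for (cpu = 0; cpu < 2; cpu++) {       /* dev_fetch_sw_netstats()-style sum */
		reader_snapshot(&percpu[cpu], &pkts, &bytes);
		tot_pkts += pkts;
		tot_bytes += bytes;
	}
	printf("rx_packets=%llu rx_bytes=%llu\n",
	       (unsigned long long)tot_pkts, (unsigned long long)tot_bytes);
	return 0;
}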
+ */ +void dev_get_tstats64(struct net_device *dev, struct rtnl_link_stats64 *s) +{ + netdev_stats_to_stats64(s, &dev->stats); + dev_fetch_sw_netstats(s, dev->tstats); +} +EXPORT_SYMBOL_GPL(dev_get_tstats64); + +struct netdev_queue *dev_ingress_queue_create(struct net_device *dev) +{ + struct netdev_queue *queue = dev_ingress_queue(dev); + +#ifdef CONFIG_NET_CLS_ACT + if (queue) + return queue; + queue = kzalloc(sizeof(*queue), GFP_KERNEL); + if (!queue) + return NULL; + netdev_init_one_queue(dev, queue, NULL); + RCU_INIT_POINTER(queue->qdisc, &noop_qdisc); + RCU_INIT_POINTER(queue->qdisc_sleeping, &noop_qdisc); + rcu_assign_pointer(dev->ingress_queue, queue); +#endif + return queue; +} + +static const struct ethtool_ops default_ethtool_ops; + +void netdev_set_default_ethtool_ops(struct net_device *dev, + const struct ethtool_ops *ops) +{ + if (dev->ethtool_ops == &default_ethtool_ops) + dev->ethtool_ops = ops; +} +EXPORT_SYMBOL_GPL(netdev_set_default_ethtool_ops); + +void netdev_freemem(struct net_device *dev) +{ + char *addr = (char *)dev - dev->padded; + + kvfree(addr); +} + +/** + * alloc_netdev_mqs - allocate network device + * @sizeof_priv: size of private data to allocate space for + * @name: device name format string + * @name_assign_type: origin of device name + * @setup: callback to initialize device + * @txqs: the number of TX subqueues to allocate + * @rxqs: the number of RX subqueues to allocate + * + * Allocates a struct net_device with private data area for driver use + * and performs basic initialization. Also allocates subqueue structs + * for each queue on the device. + */ +struct net_device *alloc_netdev_mqs(int sizeof_priv, const char *name, + unsigned char name_assign_type, + void (*setup)(struct net_device *), + unsigned int txqs, unsigned int rxqs) +{ + struct net_device *dev; + unsigned int alloc_size; + struct net_device *p; + + BUG_ON(strlen(name) >= sizeof(dev->name)); + + if (txqs < 1) { + pr_err("alloc_netdev: Unable to allocate device with zero queues\n"); + return NULL; + } + + if (rxqs < 1) { + pr_err("alloc_netdev: Unable to allocate device with zero RX queues\n"); + return NULL; + } + + alloc_size = sizeof(struct net_device); + if (sizeof_priv) { + /* ensure 32-byte alignment of private area */ + alloc_size = ALIGN(alloc_size, NETDEV_ALIGN); + alloc_size += sizeof_priv; + } + /* ensure 32-byte alignment of whole construct */ + alloc_size += NETDEV_ALIGN - 1; + + p = kvzalloc(alloc_size, GFP_KERNEL_ACCOUNT | __GFP_RETRY_MAYFAIL); + if (!p) + return NULL; + + dev = PTR_ALIGN(p, NETDEV_ALIGN); + dev->padded = (char *)dev - (char *)p; + + ref_tracker_dir_init(&dev->refcnt_tracker, 128); +#ifdef CONFIG_PCPU_DEV_REFCNT + dev->pcpu_refcnt = alloc_percpu(int); + if (!dev->pcpu_refcnt) + goto free_dev; + __dev_hold(dev); +#else + refcount_set(&dev->dev_refcnt, 1); +#endif + + if (dev_addr_init(dev)) + goto free_pcpu; + + dev_mc_init(dev); + dev_uc_init(dev); + + dev_net_set(dev, &init_net); + + dev->gso_max_size = GSO_LEGACY_MAX_SIZE; + dev->gso_max_segs = GSO_MAX_SEGS; + dev->gro_max_size = GRO_LEGACY_MAX_SIZE; + dev->tso_max_size = TSO_LEGACY_MAX_SIZE; + dev->tso_max_segs = TSO_MAX_SEGS; + dev->upper_level = 1; + dev->lower_level = 1; +#ifdef CONFIG_LOCKDEP + dev->nested_level = 0; + INIT_LIST_HEAD(&dev->unlink_list); +#endif + + INIT_LIST_HEAD(&dev->napi_list); + INIT_LIST_HEAD(&dev->unreg_list); + INIT_LIST_HEAD(&dev->close_list); + INIT_LIST_HEAD(&dev->link_watch_list); + INIT_LIST_HEAD(&dev->adj_list.upper); + INIT_LIST_HEAD(&dev->adj_list.lower); + 
INIT_LIST_HEAD(&dev->ptype_all); + INIT_LIST_HEAD(&dev->ptype_specific); + INIT_LIST_HEAD(&dev->net_notifier_list); +#ifdef CONFIG_NET_SCHED + hash_init(dev->qdisc_hash); +#endif + dev->priv_flags = IFF_XMIT_DST_RELEASE | IFF_XMIT_DST_RELEASE_PERM; + setup(dev); + + if (!dev->tx_queue_len) { + dev->priv_flags |= IFF_NO_QUEUE; + dev->tx_queue_len = DEFAULT_TX_QUEUE_LEN; + } + + dev->num_tx_queues = txqs; + dev->real_num_tx_queues = txqs; + if (netif_alloc_netdev_queues(dev)) + goto free_all; + + dev->num_rx_queues = rxqs; + dev->real_num_rx_queues = rxqs; + if (netif_alloc_rx_queues(dev)) + goto free_all; + + strcpy(dev->name, name); + dev->name_assign_type = name_assign_type; + dev->group = INIT_NETDEV_GROUP; + if (!dev->ethtool_ops) + dev->ethtool_ops = &default_ethtool_ops; + + nf_hook_netdev_init(dev); + + return dev; + +free_all: + free_netdev(dev); + return NULL; + +free_pcpu: +#ifdef CONFIG_PCPU_DEV_REFCNT + free_percpu(dev->pcpu_refcnt); +free_dev: +#endif + netdev_freemem(dev); + return NULL; +} +EXPORT_SYMBOL(alloc_netdev_mqs); + +/** + * free_netdev - free network device + * @dev: device + * + * This function does the last stage of destroying an allocated device + * interface. The reference to the device object is released. If this + * is the last reference then it will be freed.Must be called in process + * context. + */ +void free_netdev(struct net_device *dev) +{ + struct napi_struct *p, *n; + + might_sleep(); + + /* When called immediately after register_netdevice() failed the unwind + * handling may still be dismantling the device. Handle that case by + * deferring the free. + */ + if (dev->reg_state == NETREG_UNREGISTERING) { + ASSERT_RTNL(); + dev->needs_free_netdev = true; + return; + } + + netif_free_tx_queues(dev); + netif_free_rx_queues(dev); + + kfree(rcu_dereference_protected(dev->ingress_queue, 1)); + + /* Flush device addresses */ + dev_addr_flush(dev); + + list_for_each_entry_safe(p, n, &dev->napi_list, dev_list) + netif_napi_del(p); + + ref_tracker_dir_exit(&dev->refcnt_tracker); +#ifdef CONFIG_PCPU_DEV_REFCNT + free_percpu(dev->pcpu_refcnt); + dev->pcpu_refcnt = NULL; +#endif + free_percpu(dev->core_stats); + dev->core_stats = NULL; + free_percpu(dev->xdp_bulkq); + dev->xdp_bulkq = NULL; + + /* Compatibility with error handling in drivers */ + if (dev->reg_state == NETREG_UNINITIALIZED) { + netdev_freemem(dev); + return; + } + + BUG_ON(dev->reg_state != NETREG_UNREGISTERED); + dev->reg_state = NETREG_RELEASED; + + /* will free via device release */ + put_device(&dev->dev); +} +EXPORT_SYMBOL(free_netdev); + +/** + * synchronize_net - Synchronize with packet receive processing + * + * Wait for packets currently being received to be done. + * Does not block later packets from starting. + */ +void synchronize_net(void) +{ + might_sleep(); + if (rtnl_is_locked()) + synchronize_rcu_expedited(); + else + synchronize_rcu(); +} +EXPORT_SYMBOL(synchronize_net); + +/** + * unregister_netdevice_queue - remove device from the kernel + * @dev: device + * @head: list + * + * This function shuts down a device interface and removes it + * from the kernel tables. + * If head not NULL, device is queued to be unregistered later. + * + * Callers must hold the rtnl semaphore. You may want + * unregister_netdev() instead of this. 
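[Editorial aside, not part of the patch] alloc_netdev_mqs() above places the driver's private area in the same allocation as struct net_device, aligning both to NETDEV_ALIGN and recording the alignment padding so netdev_freemem() can recover the raw pointer later. A userspace sketch of that layout trick; the 32-byte value mirrors the "32-byte alignment" comments in the code, everything else (struct, helper names) is illustrative.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define DEV_ALIGN 32
#define ALIGN_UP(x, a) (((x) + (a) - 1) & ~((uintptr_t)(a) - 1))

struct fake_dev {
	unsigned int padded;   /* bytes between the raw allocation and this struct */
	char name[16];
};

/* One allocation holds the device struct plus an aligned private area after it. */
static struct fake_dev *alloc_dev(size_t sizeof_priv)
{
	size_t alloc_size = sizeof(struct fake_dev);
	struct fake_dev *dev;
	char *p;

	if (sizeof_priv)
		alloc_size = ALIGN_UP(alloc_size, DEV_ALIGN) + sizeof_priv;
	alloc_size += DEV_ALIGN - 1;          /* room to align the whole construct */

	p = calloc(1, alloc_size);
	if (!p)
		return NULL;

	dev = (struct fake_dev *)ALIGN_UP((uintptr_t)p, DEV_ALIGN);
	dev->padded = (unsigned int)((char *)dev - p);
	return dev;
}

/* Equivalent of netdev_priv(): the private area starts right after the
 * (aligned) device struct. */
static void *dev_priv(struct fake_dev *dev)
{
	return (char *)dev + ALIGN_UP(sizeof(*dev), DEV_ALIGN);
}

static void free_dev(struct fake_dev *dev)
{
	free((char *)dev - dev->padded);      /* undo the alignment padding */
}

int main(void)
{
	struct fake_dev *dev = alloc_dev(64);

	if (!dev)
		return 1;
	printf("dev=%p priv=%p padded=%u\n", (void *)dev, dev_priv(dev), dev->padded);
	free_dev(dev);
	return 0;
}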
+ */ + +void unregister_netdevice_queue(struct net_device *dev, struct list_head *head) +{ + ASSERT_RTNL(); + + if (head) { + list_move_tail(&dev->unreg_list, head); + } else { + LIST_HEAD(single); + + list_add(&dev->unreg_list, &single); + unregister_netdevice_many(&single); + } +} +EXPORT_SYMBOL(unregister_netdevice_queue); + +/** + * unregister_netdevice_many - unregister many devices + * @head: list of devices + * + * Note: As most callers use a stack allocated list_head, + * we force a list_del() to make sure stack wont be corrupted later. + */ +void unregister_netdevice_many(struct list_head *head) +{ + struct net_device *dev, *tmp; + LIST_HEAD(close_head); + + BUG_ON(dev_boot_phase); + ASSERT_RTNL(); + + if (list_empty(head)) + return; + + list_for_each_entry_safe(dev, tmp, head, unreg_list) { + /* Some devices call without registering + * for initialization unwind. Remove those + * devices and proceed with the remaining. + */ + if (dev->reg_state == NETREG_UNINITIALIZED) { + pr_debug("unregister_netdevice: device %s/%p never was registered\n", + dev->name, dev); + + WARN_ON(1); + list_del(&dev->unreg_list); + continue; + } + dev->dismantle = true; + BUG_ON(dev->reg_state != NETREG_REGISTERED); + } + + /* If device is running, close it first. */ + list_for_each_entry(dev, head, unreg_list) + list_add_tail(&dev->close_list, &close_head); + dev_close_many(&close_head, true); + + list_for_each_entry(dev, head, unreg_list) { + /* And unlink it from device chain. */ + write_lock(&dev_base_lock); + unlist_netdevice(dev, false); + dev->reg_state = NETREG_UNREGISTERING; + write_unlock(&dev_base_lock); + } + flush_all_backlogs(); + + synchronize_net(); + + list_for_each_entry(dev, head, unreg_list) { + struct sk_buff *skb = NULL; + + /* Shutdown queueing discipline. */ + dev_shutdown(dev); + + dev_xdp_uninstall(dev); + + netdev_offload_xstats_disable_all(dev); + + /* Notify protocols, that we are about to destroy + * this device. They should clean all the things. + */ + call_netdevice_notifiers(NETDEV_UNREGISTER, dev); + + if (!dev->rtnl_link_ops || + dev->rtnl_link_state == RTNL_LINK_INITIALIZED) + skb = rtmsg_ifinfo_build_skb(RTM_DELLINK, dev, ~0U, 0, + GFP_KERNEL, NULL, 0); + + /* + * Flush the unicast and multicast chains + */ + dev_uc_flush(dev); + dev_mc_flush(dev); + + netdev_name_node_alt_flush(dev); + netdev_name_node_free(dev->name_node); + + if (dev->netdev_ops->ndo_uninit) + dev->netdev_ops->ndo_uninit(dev); + + if (skb) + rtmsg_ifinfo_send(skb, dev, GFP_KERNEL); + + /* Notifier chain MUST detach us all upper devices. */ + WARN_ON(netdev_has_any_upper_dev(dev)); + WARN_ON(netdev_has_any_lower_dev(dev)); + + /* Remove entries from kobject tree */ + netdev_unregister_kobject(dev); +#ifdef CONFIG_XPS + /* Remove XPS queueing entries */ + netif_reset_xps_queues_gt(dev, 0); +#endif + } + + synchronize_net(); + + list_for_each_entry(dev, head, unreg_list) { + netdev_put(dev, &dev->dev_registered_tracker); + net_set_todo(dev); + } + + list_del(head); +} +EXPORT_SYMBOL(unregister_netdevice_many); + +/** + * unregister_netdev - remove device from the kernel + * @dev: device + * + * This function shuts down a device interface and removes it + * from the kernel tables. + * + * This is just a wrapper for unregister_netdevice that takes + * the rtnl semaphore. In general you want to use this and not + * unregister_netdevice. 
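[Editorial aside, not part of the patch] unregister_netdevice_queue() above either defers the device onto the caller's list or builds a one-entry list on the stack and runs the batch path immediately, so the heavy teardown logic lives only in unregister_netdevice_many(). A toy intrusive-list version of that shape (the list and helpers here are simplified stand-ins, not the kernel's list.h):

#include <stddef.h>
#include <stdio.h>

struct list_head { struct list_head *next; };

struct fake_dev {
	const char *name;
	struct list_head unreg_list;
};

static void list_add(struct list_head *entry, struct list_head *head)
{
	entry->next = head->next;
	head->next = entry;
}

static void unregister_many(struct list_head *head)
{
	struct list_head *p;

	for (p = head->next; p; p = p->next) {
		struct fake_dev *dev = (struct fake_dev *)
			((char *)p - offsetof(struct fake_dev, unreg_list));
		printf("tearing down %s\n", dev->name);
	}
	head->next = NULL;
}

/* If the caller passed a list, just queue; otherwise run a batch of one. */
static void unregister_queue(struct fake_dev *dev, struct list_head *head)
{
	if (head) {
		list_add(&dev->unreg_list, head);
	} else {
		struct list_head single = { NULL };

		list_add(&dev->unreg_list, &single);
		unregister_many(&single);
	}
}

int main(void)
{
	struct fake_dev a = { .name = "dummy0" }, b = { .name = "dummy1" };
	struct list_head batch = { NULL };

	unregister_queue(&a, &batch);    /* deferred onto the caller's batch */
	unregister_queue(&b, NULL);      /* torn down immediately */
	unregister_many(&batch);         /* now tear down the deferred one */
	return 0;
}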
+ */ +void unregister_netdev(struct net_device *dev) +{ + rtnl_lock(); + unregister_netdevice(dev); + rtnl_unlock(); +} +EXPORT_SYMBOL(unregister_netdev); + +/** + * __dev_change_net_namespace - move device to different nethost namespace + * @dev: device + * @net: network namespace + * @pat: If not NULL name pattern to try if the current device name + * is already taken in the destination network namespace. + * @new_ifindex: If not zero, specifies device index in the target + * namespace. + * + * This function shuts down a device interface and moves it + * to a new network namespace. On success 0 is returned, on + * a failure a netagive errno code is returned. + * + * Callers must hold the rtnl semaphore. + */ + +int __dev_change_net_namespace(struct net_device *dev, struct net *net, + const char *pat, int new_ifindex) +{ + struct netdev_name_node *name_node; + struct net *net_old = dev_net(dev); + char new_name[IFNAMSIZ] = {}; + int err, new_nsid; + + ASSERT_RTNL(); + + /* Don't allow namespace local devices to be moved. */ + err = -EINVAL; + if (dev->features & NETIF_F_NETNS_LOCAL) + goto out; + + /* Ensure the device has been registrered */ + if (dev->reg_state != NETREG_REGISTERED) + goto out; + + /* Get out if there is nothing todo */ + err = 0; + if (net_eq(net_old, net)) + goto out; + + /* Pick the destination device name, and ensure + * we can use it in the destination network namespace. + */ + err = -EEXIST; + if (netdev_name_in_use(net, dev->name)) { + /* We get here if we can't use the current device name */ + if (!pat) + goto out; + err = dev_prep_valid_name(net, dev, pat, new_name); + if (err < 0) + goto out; + } + /* Check that none of the altnames conflicts. */ + err = -EEXIST; + netdev_for_each_altname(dev, name_node) + if (netdev_name_in_use(net, name_node->name)) + goto out; + + /* Check that new_ifindex isn't used yet. */ + err = -EBUSY; + if (new_ifindex && __dev_get_by_index(net, new_ifindex)) + goto out; + + /* + * And now a mini version of register_netdevice unregister_netdevice. + */ + + /* If device is running close it first. */ + dev_close(dev); + + /* And unlink it from device chain */ + unlist_netdevice(dev, true); + + synchronize_net(); + + /* Shutdown queueing discipline. */ + dev_shutdown(dev); + + /* Notify protocols, that we are about to destroy + * this device. They should clean all the things. + * + * Note that dev->reg_state stays at NETREG_REGISTERED. + * This is wanted because this way 8021q and macvlan know + * the device is just moving and can keep their slaves up. 
+ */ + call_netdevice_notifiers(NETDEV_UNREGISTER, dev); + rcu_barrier(); + + new_nsid = peernet2id_alloc(dev_net(dev), net, GFP_KERNEL); + /* If there is an ifindex conflict assign a new one */ + if (!new_ifindex) { + if (__dev_get_by_index(net, dev->ifindex)) + new_ifindex = dev_new_index(net); + else + new_ifindex = dev->ifindex; + } + + rtmsg_ifinfo_newnet(RTM_DELLINK, dev, ~0U, GFP_KERNEL, &new_nsid, + new_ifindex); + + /* + * Flush the unicast and multicast chains + */ + dev_uc_flush(dev); + dev_mc_flush(dev); + + /* Send a netdev-removed uevent to the old namespace */ + kobject_uevent(&dev->dev.kobj, KOBJ_REMOVE); + netdev_adjacent_del_links(dev); + + /* Move per-net netdevice notifiers that are following the netdevice */ + move_netdevice_notifiers_dev_net(dev, net); + + /* Actually switch the network namespace */ + dev_net_set(dev, net); + dev->ifindex = new_ifindex; + + /* Send a netdev-add uevent to the new namespace */ + kobject_uevent(&dev->dev.kobj, KOBJ_ADD); + netdev_adjacent_add_links(dev); + + if (new_name[0]) /* Rename the netdev to prepared name */ + strscpy(dev->name, new_name, IFNAMSIZ); + + /* Fixup kobjects */ + err = device_rename(&dev->dev, dev->name); + WARN_ON(err); + + /* Adapt owner in case owning user namespace of target network + * namespace is different from the original one. + */ + err = netdev_change_owner(dev, net_old, net); + WARN_ON(err); + + /* Add the device back in the hashes */ + list_netdevice(dev); + + /* Notify protocols, that a new device appeared. */ + call_netdevice_notifiers(NETDEV_REGISTER, dev); + + /* + * Prevent userspace races by waiting until the network + * device is fully setup before sending notifications. + */ + rtmsg_ifinfo(RTM_NEWLINK, dev, ~0U, GFP_KERNEL); + + synchronize_net(); + err = 0; +out: + return err; +} +EXPORT_SYMBOL_GPL(__dev_change_net_namespace); + +static int dev_cpu_dead(unsigned int oldcpu) +{ + struct sk_buff **list_skb; + struct sk_buff *skb; + unsigned int cpu; + struct softnet_data *sd, *oldsd, *remsd = NULL; + + local_irq_disable(); + cpu = smp_processor_id(); + sd = &per_cpu(softnet_data, cpu); + oldsd = &per_cpu(softnet_data, oldcpu); + + /* Find end of our completion_queue. */ + list_skb = &sd->completion_queue; + while (*list_skb) + list_skb = &(*list_skb)->next; + /* Append completion queue from offline CPU. */ + *list_skb = oldsd->completion_queue; + oldsd->completion_queue = NULL; + + /* Append output queue from offline CPU. */ + if (oldsd->output_queue) { + *sd->output_queue_tailp = oldsd->output_queue; + sd->output_queue_tailp = oldsd->output_queue_tailp; + oldsd->output_queue = NULL; + oldsd->output_queue_tailp = &oldsd->output_queue; + } + /* Append NAPI poll list from offline CPU, with one exception : + * process_backlog() must be called by cpu owning percpu backlog. + * We properly handle process_queue & input_pkt_queue later. 
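[Editorial aside, not part of the patch] When a CPU goes offline, dev_cpu_dead() above splices the dead CPU's completion queue onto the current CPU's by walking a pointer-to-pointer to the end of the local singly linked list and pointing it at the remote head. That splice in isolation, with a toy skb struct standing in for sk_buff:

#include <stdio.h>
#include <stdlib.h>

struct skb {                 /* stand-in for struct sk_buff's ->next chaining */
	int id;
	struct skb *next;
};

/* Append the whole 'from' list to the tail of '*to', like the
 * completion_queue splice in dev_cpu_dead(). */
static void splice_tail(struct skb **to, struct skb **from)
{
	struct skb **tail = to;

	while (*tail)
		tail = &(*tail)->next;   /* find end of the local queue */
	*tail = *from;                   /* append the remote queue */
	*from = NULL;                    /* the remote CPU's queue is now empty */
}

static struct skb *make(int id, struct skb *next)
{
	struct skb *s = malloc(sizeof(*s));

	s->id = id;
	s->next = next;
	return s;
}

int main(void)
{
	struct skb *local = make(1, make(2, NULL));
	struct skb *remote = make(3, make(4, NULL));
	struct skb *s, *next;

	splice_tail(&local, &remote);

	for (s = local; s; s = next) {
		next = s->next;
		printf("%d ", s->id);
		free(s);
	}
	printf("\n");                    /* prints: 1 2 3 4 */
	return 0;
}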
+ */ + while (!list_empty(&oldsd->poll_list)) { + struct napi_struct *napi = list_first_entry(&oldsd->poll_list, + struct napi_struct, + poll_list); + + list_del_init(&napi->poll_list); + if (napi->poll == process_backlog) + napi->state = 0; + else + ____napi_schedule(sd, napi); + } + + raise_softirq_irqoff(NET_TX_SOFTIRQ); + local_irq_enable(); + +#ifdef CONFIG_RPS + remsd = oldsd->rps_ipi_list; + oldsd->rps_ipi_list = NULL; +#endif + /* send out pending IPI's on offline CPU */ + net_rps_send_ipi(remsd); + + /* Process offline CPU's input_pkt_queue */ + while ((skb = __skb_dequeue(&oldsd->process_queue))) { + netif_rx(skb); + input_queue_head_incr(oldsd); + } + while ((skb = skb_dequeue(&oldsd->input_pkt_queue))) { + netif_rx(skb); + input_queue_head_incr(oldsd); + } + + return 0; +} + +/** + * netdev_increment_features - increment feature set by one + * @all: current feature set + * @one: new feature set + * @mask: mask feature set + * + * Computes a new feature set after adding a device with feature set + * @one to the master device with current feature set @all. Will not + * enable anything that is off in @mask. Returns the new feature set. + */ +netdev_features_t netdev_increment_features(netdev_features_t all, + netdev_features_t one, netdev_features_t mask) +{ + if (mask & NETIF_F_HW_CSUM) + mask |= NETIF_F_CSUM_MASK; + mask |= NETIF_F_VLAN_CHALLENGED; + + all |= one & (NETIF_F_ONE_FOR_ALL | NETIF_F_CSUM_MASK) & mask; + all &= one | ~NETIF_F_ALL_FOR_ALL; + + /* If one device supports hw checksumming, set for all. */ + if (all & NETIF_F_HW_CSUM) + all &= ~(NETIF_F_CSUM_MASK & ~NETIF_F_HW_CSUM); + + return all; +} +EXPORT_SYMBOL(netdev_increment_features); + +static struct hlist_head * __net_init netdev_create_hash(void) +{ + int i; + struct hlist_head *hash; + + hash = kmalloc_array(NETDEV_HASHENTRIES, sizeof(*hash), GFP_KERNEL); + if (hash != NULL) + for (i = 0; i < NETDEV_HASHENTRIES; i++) + INIT_HLIST_HEAD(&hash[i]); + + return hash; +} + +/* Initialize per network namespace state */ +static int __net_init netdev_init(struct net *net) +{ + BUILD_BUG_ON(GRO_HASH_BUCKETS > + 8 * sizeof_field(struct napi_struct, gro_bitmask)); + + INIT_LIST_HEAD(&net->dev_base_head); + + net->dev_name_head = netdev_create_hash(); + if (net->dev_name_head == NULL) + goto err_name; + + net->dev_index_head = netdev_create_hash(); + if (net->dev_index_head == NULL) + goto err_idx; + + RAW_INIT_NOTIFIER_HEAD(&net->netdev_chain); + + return 0; + +err_idx: + kfree(net->dev_name_head); +err_name: + return -ENOMEM; +} + +/** + * netdev_drivername - network driver for the device + * @dev: network device + * + * Determine network driver for device. 
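[Editorial aside, not part of the patch] netdev_increment_features() above folds a new slave's feature set into a master's (bond, bridge): "one for all" bits can be contributed by any slave, while "all for all" bits survive only if every slave has them. The sketch below keeps just that OR/AND algebra with stand-in flag values; the checksum-mask special-casing in the real function is intentionally left out.

#include <stdint.h>
#include <stdio.h>

/* Illustrative bits: GSO/CSUM-ish things treated as "one for all",
 * SG treated as "all for all". Values are assumptions for the sketch. */
#define F_SG          (1u << 0)
#define F_GSO         (1u << 1)
#define F_HW_CSUM     (1u << 2)

#define F_ONE_FOR_ALL (F_GSO | F_HW_CSUM)
#define F_ALL_FOR_ALL (F_SG)

static uint32_t increment_features(uint32_t all, uint32_t one, uint32_t mask)
{
	all |= one & F_ONE_FOR_ALL & mask;   /* any slave can contribute these */
	all &= one | ~F_ALL_FOR_ALL;         /* these need support from every slave */
	return all;
}

int main(void)
{
	uint32_t master = F_SG | F_GSO;      /* running set so far */
	uint32_t slave1 = F_SG | F_HW_CSUM;
	uint32_t slave2 = F_GSO;             /* no SG on this one */
	uint32_t mask = ~0u;

	master = increment_features(master, slave1, mask);
	master = increment_features(master, slave2, mask);
	printf("master features = 0x%x (SG dropped, GSO/CSUM kept)\n", master);
	return 0;
}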
+ */ +const char *netdev_drivername(const struct net_device *dev) +{ + const struct device_driver *driver; + const struct device *parent; + const char *empty = ""; + + parent = dev->dev.parent; + if (!parent) + return empty; + + driver = parent->driver; + if (driver && driver->name) + return driver->name; + return empty; +} + +static void __netdev_printk(const char *level, const struct net_device *dev, + struct va_format *vaf) +{ + if (dev && dev->dev.parent) { + dev_printk_emit(level[1] - '0', + dev->dev.parent, + "%s %s %s%s: %pV", + dev_driver_string(dev->dev.parent), + dev_name(dev->dev.parent), + netdev_name(dev), netdev_reg_state(dev), + vaf); + } else if (dev) { + printk("%s%s%s: %pV", + level, netdev_name(dev), netdev_reg_state(dev), vaf); + } else { + printk("%s(NULL net_device): %pV", level, vaf); + } +} + +void netdev_printk(const char *level, const struct net_device *dev, + const char *format, ...) +{ + struct va_format vaf; + va_list args; + + va_start(args, format); + + vaf.fmt = format; + vaf.va = &args; + + __netdev_printk(level, dev, &vaf); + + va_end(args); +} +EXPORT_SYMBOL(netdev_printk); + +#define define_netdev_printk_level(func, level) \ +void func(const struct net_device *dev, const char *fmt, ...) \ +{ \ + struct va_format vaf; \ + va_list args; \ + \ + va_start(args, fmt); \ + \ + vaf.fmt = fmt; \ + vaf.va = &args; \ + \ + __netdev_printk(level, dev, &vaf); \ + \ + va_end(args); \ +} \ +EXPORT_SYMBOL(func); + +define_netdev_printk_level(netdev_emerg, KERN_EMERG); +define_netdev_printk_level(netdev_alert, KERN_ALERT); +define_netdev_printk_level(netdev_crit, KERN_CRIT); +define_netdev_printk_level(netdev_err, KERN_ERR); +define_netdev_printk_level(netdev_warn, KERN_WARNING); +define_netdev_printk_level(netdev_notice, KERN_NOTICE); +define_netdev_printk_level(netdev_info, KERN_INFO); + +static void __net_exit netdev_exit(struct net *net) +{ + kfree(net->dev_name_head); + kfree(net->dev_index_head); + if (net != &init_net) + WARN_ON_ONCE(!list_empty(&net->dev_base_head)); +} + +static struct pernet_operations __net_initdata netdev_net_ops = { + .init = netdev_init, + .exit = netdev_exit, +}; + +static void __net_exit default_device_exit_net(struct net *net) +{ + struct netdev_name_node *name_node, *tmp; + struct net_device *dev, *aux; + /* + * Push all migratable network devices back to the + * initial network namespace + */ + ASSERT_RTNL(); + for_each_netdev_safe(net, dev, aux) { + int err; + char fb_name[IFNAMSIZ]; + + /* Ignore unmoveable devices (i.e. loopback) */ + if (dev->features & NETIF_F_NETNS_LOCAL) + continue; + + /* Leave virtual devices for the generic cleanup */ + if (dev->rtnl_link_ops && !dev->rtnl_link_ops->netns_refund) + continue; + + /* Push remaining network devices to init_net */ + snprintf(fb_name, IFNAMSIZ, "dev%d", dev->ifindex); + if (netdev_name_in_use(&init_net, fb_name)) + snprintf(fb_name, IFNAMSIZ, "dev%%d"); + + netdev_for_each_altname_safe(dev, name_node, tmp) + if (netdev_name_in_use(&init_net, name_node->name)) { + netdev_name_node_del(name_node); + synchronize_rcu(); + __netdev_name_node_alt_destroy(name_node); + } + + err = dev_change_net_namespace(dev, &init_net, fb_name); + if (err) { + pr_emerg("%s: failed to move %s to init_net: %d\n", + __func__, dev->name, err); + BUG(); + } + } +} + +static void __net_exit default_device_exit_batch(struct list_head *net_list) +{ + /* At exit all network devices most be removed from a network + * namespace. Do this in the reverse order of registration. 
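[Editorial aside, not part of the patch] default_device_exit_net() below pushes physical devices back to init_net when their namespace exits, generating a "dev%d" fallback name from the ifindex and falling back to the literal pattern if even that is taken. The sketch folds in the related step from __dev_change_net_namespace(), which keeps the current name when it is still free in the target namespace; the name table is a toy stand-in for netdev_name_in_use(), and combining the two steps into one helper is a simplification for illustration.

#include <stdbool.h>
#include <stdio.h>
#include <string.h>

#define IFNAMSIZ 16

/* Toy stand-in for netdev_name_in_use() against init_net. */
static bool name_in_use(const char *name)
{
	static const char *taken[] = { "eth0", "dev3" };

	for (unsigned int i = 0; i < sizeof(taken) / sizeof(taken[0]); i++)
		if (!strcmp(taken[i], name))
			return true;
	return false;
}

/* Pick the name the device would get once it lands back in init_net. */
static void pick_fallback_name(const char *cur_name, int ifindex,
			       char fb_name[IFNAMSIZ])
{
	if (!name_in_use(cur_name)) {
		snprintf(fb_name, IFNAMSIZ, "%s", cur_name);
		return;
	}
	snprintf(fb_name, IFNAMSIZ, "dev%d", ifindex);
	if (name_in_use(fb_name))
		snprintf(fb_name, IFNAMSIZ, "dev%%d"); /* let the rename code pick a slot */
}

int main(void)
{
	char name[IFNAMSIZ];

	pick_fallback_name("eth0", 3, name);   /* "eth0" and "dev3" both taken */
	printf("-> %s\n", name);               /* prints: dev%d */
	pick_fallback_name("wlan0", 4, name);
	printf("-> %s\n", name);               /* prints: wlan0 */
	return 0;
}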
+ * Do this across as many network namespaces as possible to + * improve batching efficiency. + */ + struct net_device *dev; + struct net *net; + LIST_HEAD(dev_kill_list); + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + default_device_exit_net(net); + cond_resched(); + } + + list_for_each_entry(net, net_list, exit_list) { + for_each_netdev_reverse(net, dev) { + if (dev->rtnl_link_ops && dev->rtnl_link_ops->dellink) + dev->rtnl_link_ops->dellink(dev, &dev_kill_list); + else + unregister_netdevice_queue(dev, &dev_kill_list); + } + } + unregister_netdevice_many(&dev_kill_list); + rtnl_unlock(); +} + +static struct pernet_operations __net_initdata default_device_ops = { + .exit_batch = default_device_exit_batch, +}; + +/* + * Initialize the DEV module. At boot time this walks the device list and + * unhooks any devices that fail to initialise (normally hardware not + * present) and leaves us with a valid list of present and active devices. + * + */ + +/* + * This is called single threaded during boot, so no need + * to take the rtnl semaphore. + */ +static int __init net_dev_init(void) +{ + int i, rc = -ENOMEM; + + BUG_ON(!dev_boot_phase); + + if (dev_proc_init()) + goto out; + + if (netdev_kobject_init()) + goto out; + + INIT_LIST_HEAD(&ptype_all); + for (i = 0; i < PTYPE_HASH_SIZE; i++) + INIT_LIST_HEAD(&ptype_base[i]); + + if (register_pernet_subsys(&netdev_net_ops)) + goto out; + + /* + * Initialise the packet receive queues. + */ + + for_each_possible_cpu(i) { + struct work_struct *flush = per_cpu_ptr(&flush_works, i); + struct softnet_data *sd = &per_cpu(softnet_data, i); + + INIT_WORK(flush, flush_backlog); + + skb_queue_head_init(&sd->input_pkt_queue); + skb_queue_head_init(&sd->process_queue); +#ifdef CONFIG_XFRM_OFFLOAD + skb_queue_head_init(&sd->xfrm_backlog); +#endif + INIT_LIST_HEAD(&sd->poll_list); + sd->output_queue_tailp = &sd->output_queue; +#ifdef CONFIG_RPS + INIT_CSD(&sd->csd, rps_trigger_softirq, sd); + sd->cpu = i; +#endif + INIT_CSD(&sd->defer_csd, trigger_rx_softirq, sd); + spin_lock_init(&sd->defer_lock); + + init_gro_hash(&sd->backlog); + sd->backlog.poll = process_backlog; + sd->backlog.weight = weight_p; + } + + dev_boot_phase = 0; + + /* The loopback device is special if any other network devices + * is present in a network namespace the loopback device must + * be present. Since we now dynamically allocate and free the + * loopback device ensure this invariant is maintained by + * keeping the loopback device as the first device on the + * list of network devices. Ensuring the loopback devices + * is the first device that appears and the last network device + * that disappears. 
+ */ + if (register_pernet_device(&loopback_net_ops)) + goto out; + + if (register_pernet_device(&default_device_ops)) + goto out; + + open_softirq(NET_TX_SOFTIRQ, net_tx_action); + open_softirq(NET_RX_SOFTIRQ, net_rx_action); + + rc = cpuhp_setup_state_nocalls(CPUHP_NET_DEV_DEAD, "net/dev:dead", + NULL, dev_cpu_dead); + WARN_ON(rc < 0); + rc = 0; +out: + return rc; +} + +subsys_initcall(net_dev_init); diff --git a/net/core/dev.h b/net/core/dev.h new file mode 100644 index 000000000..db9ff8cd8 --- /dev/null +++ b/net/core/dev.h @@ -0,0 +1,118 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +#ifndef _NET_CORE_DEV_H +#define _NET_CORE_DEV_H + +#include + +struct net; +struct net_device; +struct netdev_bpf; +struct netdev_phys_item_id; +struct netlink_ext_ack; + +/* Random bits of netdevice that don't need to be exposed */ +#define FLOW_LIMIT_HISTORY (1 << 7) /* must be ^2 and !overflow buckets */ +struct sd_flow_limit { + u64 count; + unsigned int num_buckets; + unsigned int history_head; + u16 history[FLOW_LIMIT_HISTORY]; + u8 buckets[]; +}; + +extern int netdev_flow_limit_table_len; + +#ifdef CONFIG_PROC_FS +int __init dev_proc_init(void); +#else +#define dev_proc_init() 0 +#endif + +void linkwatch_init_dev(struct net_device *dev); +void linkwatch_forget_dev(struct net_device *dev); +void linkwatch_run_queue(void); + +void dev_addr_flush(struct net_device *dev); +int dev_addr_init(struct net_device *dev); +void dev_addr_check(struct net_device *dev); + +/* sysctls not referred to from outside net/core/ */ +extern int netdev_budget; +extern unsigned int netdev_budget_usecs; +extern unsigned int sysctl_skb_defer_max; +extern int netdev_tstamp_prequeue; +extern int netdev_unregister_timeout_secs; +extern int weight_p; +extern int dev_weight_rx_bias; +extern int dev_weight_tx_bias; + +/* rtnl helpers */ +extern struct list_head net_todo_list; +void netdev_run_todo(void); + +/* netdev management, shared between various uAPI entry points */ +struct netdev_name_node { + struct hlist_node hlist; + struct list_head list; + struct net_device *dev; + const char *name; +}; + +int netdev_get_name(struct net *net, char *name, int ifindex); +int dev_change_name(struct net_device *dev, const char *newname); + +#define netdev_for_each_altname(dev, namenode) \ + list_for_each_entry((namenode), &(dev)->name_node->list, list) +#define netdev_for_each_altname_safe(dev, namenode, next) \ + list_for_each_entry_safe((namenode), (next), &(dev)->name_node->list, \ + list) + +int netdev_name_node_alt_create(struct net_device *dev, const char *name); +int netdev_name_node_alt_destroy(struct net_device *dev, const char *name); + +int dev_validate_mtu(struct net_device *dev, int mtu, + struct netlink_ext_ack *extack); +int dev_set_mtu_ext(struct net_device *dev, int mtu, + struct netlink_ext_ack *extack); + +int dev_get_phys_port_id(struct net_device *dev, + struct netdev_phys_item_id *ppid); +int dev_get_phys_port_name(struct net_device *dev, + char *name, size_t len); + +int dev_change_proto_down(struct net_device *dev, bool proto_down); +void dev_change_proto_down_reason(struct net_device *dev, unsigned long mask, + u32 value); + +typedef int (*bpf_op_t)(struct net_device *dev, struct netdev_bpf *bpf); +int dev_change_xdp_fd(struct net_device *dev, struct netlink_ext_ack *extack, + int fd, int expected_fd, u32 flags); + +int dev_change_tx_queue_len(struct net_device *dev, unsigned long new_len); +void dev_set_group(struct net_device *dev, int new_group); +int dev_change_carrier(struct net_device *dev, bool 
new_carrier); + +void __dev_set_rx_mode(struct net_device *dev); + +static inline void netif_set_gso_max_size(struct net_device *dev, + unsigned int size) +{ + /* dev->gso_max_size is read locklessly from sk_setup_caps() */ + WRITE_ONCE(dev->gso_max_size, size); +} + +static inline void netif_set_gso_max_segs(struct net_device *dev, + unsigned int segs) +{ + /* dev->gso_max_segs is read locklessly from sk_setup_caps() */ + WRITE_ONCE(dev->gso_max_segs, segs); +} + +static inline void netif_set_gro_max_size(struct net_device *dev, + unsigned int size) +{ + /* This pairs with the READ_ONCE() in skb_gro_receive() */ + WRITE_ONCE(dev->gro_max_size, size); +} + +#endif diff --git a/net/core/dev_addr_lists.c b/net/core/dev_addr_lists.c new file mode 100644 index 000000000..baa63dee2 --- /dev/null +++ b/net/core/dev_addr_lists.c @@ -0,0 +1,1050 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/dev_addr_lists.c - Functions for handling net device lists + * Copyright (c) 2010 Jiri Pirko + * + * This file contains functions for working with unicast, multicast and device + * addresses lists. + */ + +#include +#include +#include +#include + +#include "dev.h" + +/* + * General list handling functions + */ + +static int __hw_addr_insert(struct netdev_hw_addr_list *list, + struct netdev_hw_addr *new, int addr_len) +{ + struct rb_node **ins_point = &list->tree.rb_node, *parent = NULL; + struct netdev_hw_addr *ha; + + while (*ins_point) { + int diff; + + ha = rb_entry(*ins_point, struct netdev_hw_addr, node); + diff = memcmp(new->addr, ha->addr, addr_len); + if (diff == 0) + diff = memcmp(&new->type, &ha->type, sizeof(new->type)); + + parent = *ins_point; + if (diff < 0) + ins_point = &parent->rb_left; + else if (diff > 0) + ins_point = &parent->rb_right; + else + return -EEXIST; + } + + rb_link_node_rcu(&new->node, parent, ins_point); + rb_insert_color(&new->node, &list->tree); + + return 0; +} + +static struct netdev_hw_addr* +__hw_addr_create(const unsigned char *addr, int addr_len, + unsigned char addr_type, bool global, bool sync) +{ + struct netdev_hw_addr *ha; + int alloc_size; + + alloc_size = sizeof(*ha); + if (alloc_size < L1_CACHE_BYTES) + alloc_size = L1_CACHE_BYTES; + ha = kmalloc(alloc_size, GFP_ATOMIC); + if (!ha) + return NULL; + memcpy(ha->addr, addr, addr_len); + ha->type = addr_type; + ha->refcount = 1; + ha->global_use = global; + ha->synced = sync ? 
1 : 0; + ha->sync_cnt = 0; + + return ha; +} + +static int __hw_addr_add_ex(struct netdev_hw_addr_list *list, + const unsigned char *addr, int addr_len, + unsigned char addr_type, bool global, bool sync, + int sync_count, bool exclusive) +{ + struct rb_node **ins_point = &list->tree.rb_node, *parent = NULL; + struct netdev_hw_addr *ha; + + if (addr_len > MAX_ADDR_LEN) + return -EINVAL; + + while (*ins_point) { + int diff; + + ha = rb_entry(*ins_point, struct netdev_hw_addr, node); + diff = memcmp(addr, ha->addr, addr_len); + if (diff == 0) + diff = memcmp(&addr_type, &ha->type, sizeof(addr_type)); + + parent = *ins_point; + if (diff < 0) { + ins_point = &parent->rb_left; + } else if (diff > 0) { + ins_point = &parent->rb_right; + } else { + if (exclusive) + return -EEXIST; + if (global) { + /* check if addr is already used as global */ + if (ha->global_use) + return 0; + else + ha->global_use = true; + } + if (sync) { + if (ha->synced && sync_count) + return -EEXIST; + else + ha->synced++; + } + ha->refcount++; + return 0; + } + } + + ha = __hw_addr_create(addr, addr_len, addr_type, global, sync); + if (!ha) + return -ENOMEM; + + rb_link_node(&ha->node, parent, ins_point); + rb_insert_color(&ha->node, &list->tree); + + list_add_tail_rcu(&ha->list, &list->list); + list->count++; + + return 0; +} + +static int __hw_addr_add(struct netdev_hw_addr_list *list, + const unsigned char *addr, int addr_len, + unsigned char addr_type) +{ + return __hw_addr_add_ex(list, addr, addr_len, addr_type, false, false, + 0, false); +} + +static int __hw_addr_del_entry(struct netdev_hw_addr_list *list, + struct netdev_hw_addr *ha, bool global, + bool sync) +{ + if (global && !ha->global_use) + return -ENOENT; + + if (sync && !ha->synced) + return -ENOENT; + + if (global) + ha->global_use = false; + + if (sync) + ha->synced--; + + if (--ha->refcount) + return 0; + + rb_erase(&ha->node, &list->tree); + + list_del_rcu(&ha->list); + kfree_rcu(ha, rcu_head); + list->count--; + return 0; +} + +static struct netdev_hw_addr *__hw_addr_lookup(struct netdev_hw_addr_list *list, + const unsigned char *addr, int addr_len, + unsigned char addr_type) +{ + struct rb_node *node; + + node = list->tree.rb_node; + + while (node) { + struct netdev_hw_addr *ha = rb_entry(node, struct netdev_hw_addr, node); + int diff = memcmp(addr, ha->addr, addr_len); + + if (diff == 0 && addr_type) + diff = memcmp(&addr_type, &ha->type, sizeof(addr_type)); + + if (diff < 0) + node = node->rb_left; + else if (diff > 0) + node = node->rb_right; + else + return ha; + } + + return NULL; +} + +static int __hw_addr_del_ex(struct netdev_hw_addr_list *list, + const unsigned char *addr, int addr_len, + unsigned char addr_type, bool global, bool sync) +{ + struct netdev_hw_addr *ha = __hw_addr_lookup(list, addr, addr_len, addr_type); + + if (!ha) + return -ENOENT; + return __hw_addr_del_entry(list, ha, global, sync); +} + +static int __hw_addr_del(struct netdev_hw_addr_list *list, + const unsigned char *addr, int addr_len, + unsigned char addr_type) +{ + return __hw_addr_del_ex(list, addr, addr_len, addr_type, false, false); +} + +static int __hw_addr_sync_one(struct netdev_hw_addr_list *to_list, + struct netdev_hw_addr *ha, + int addr_len) +{ + int err; + + err = __hw_addr_add_ex(to_list, ha->addr, addr_len, ha->type, + false, true, ha->sync_cnt, false); + if (err && err != -EEXIST) + return err; + + if (!err) { + ha->sync_cnt++; + ha->refcount++; + } + + return 0; +} + +static void __hw_addr_unsync_one(struct netdev_hw_addr_list *to_list, + struct 
netdev_hw_addr_list *from_list, + struct netdev_hw_addr *ha, + int addr_len) +{ + int err; + + err = __hw_addr_del_ex(to_list, ha->addr, addr_len, ha->type, + false, true); + if (err) + return; + ha->sync_cnt--; + /* address on from list is not marked synced */ + __hw_addr_del_entry(from_list, ha, false, false); +} + +static int __hw_addr_sync_multiple(struct netdev_hw_addr_list *to_list, + struct netdev_hw_addr_list *from_list, + int addr_len) +{ + int err = 0; + struct netdev_hw_addr *ha, *tmp; + + list_for_each_entry_safe(ha, tmp, &from_list->list, list) { + if (ha->sync_cnt == ha->refcount) { + __hw_addr_unsync_one(to_list, from_list, ha, addr_len); + } else { + err = __hw_addr_sync_one(to_list, ha, addr_len); + if (err) + break; + } + } + return err; +} + +/* This function only works where there is a strict 1-1 relationship + * between source and destionation of they synch. If you ever need to + * sync addresses to more then 1 destination, you need to use + * __hw_addr_sync_multiple(). + */ +int __hw_addr_sync(struct netdev_hw_addr_list *to_list, + struct netdev_hw_addr_list *from_list, + int addr_len) +{ + int err = 0; + struct netdev_hw_addr *ha, *tmp; + + list_for_each_entry_safe(ha, tmp, &from_list->list, list) { + if (!ha->sync_cnt) { + err = __hw_addr_sync_one(to_list, ha, addr_len); + if (err) + break; + } else if (ha->refcount == 1) + __hw_addr_unsync_one(to_list, from_list, ha, addr_len); + } + return err; +} +EXPORT_SYMBOL(__hw_addr_sync); + +void __hw_addr_unsync(struct netdev_hw_addr_list *to_list, + struct netdev_hw_addr_list *from_list, + int addr_len) +{ + struct netdev_hw_addr *ha, *tmp; + + list_for_each_entry_safe(ha, tmp, &from_list->list, list) { + if (ha->sync_cnt) + __hw_addr_unsync_one(to_list, from_list, ha, addr_len); + } +} +EXPORT_SYMBOL(__hw_addr_unsync); + +/** + * __hw_addr_sync_dev - Synchonize device's multicast list + * @list: address list to syncronize + * @dev: device to sync + * @sync: function to call if address should be added + * @unsync: function to call if address should be removed + * + * This function is intended to be called from the ndo_set_rx_mode + * function of devices that require explicit address add/remove + * notifications. The unsync function may be NULL in which case + * the addresses requiring removal will simply be removed without + * any notification to the device. 
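A minimal sketch of how a driver would typically consume __hw_addr_sync_dev() from its .ndo_set_rx_mode callback, as the kerneldoc above describes. The "foo" driver, its private struct and the foo_hw_*_filter() helpers are hypothetical; only __hw_addr_sync_dev() itself and the sync/unsync callback signatures come from the code in this file.

#include <linux/netdevice.h>

/* Hypothetical per-device state of an imaginary "foo" NIC driver. */
struct foo_priv {
	void __iomem *filter_regs;
};

static int foo_hw_add_uc_filter(struct foo_priv *priv, const unsigned char *addr)
{
	/* stub: a real driver would program one unicast filter entry here */
	return 0;
}

static int foo_hw_del_uc_filter(struct foo_priv *priv, const unsigned char *addr)
{
	/* stub: a real driver would clear the matching filter entry here */
	return 0;
}

static int foo_uc_sync(struct net_device *netdev, const unsigned char *addr)
{
	return foo_hw_add_uc_filter(netdev_priv(netdev), addr);
}

static int foo_uc_unsync(struct net_device *netdev, const unsigned char *addr)
{
	/* a non-zero return defers the removal to a later sync pass */
	return foo_hw_del_uc_filter(netdev_priv(netdev), addr);
}

/* .ndo_set_rx_mode - invoked with the device's addr_list_lock held */
static void foo_set_rx_mode(struct net_device *netdev)
{
	__hw_addr_sync_dev(&netdev->uc, netdev, foo_uc_sync, foo_uc_unsync);
}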
+ **/ +int __hw_addr_sync_dev(struct netdev_hw_addr_list *list, + struct net_device *dev, + int (*sync)(struct net_device *, const unsigned char *), + int (*unsync)(struct net_device *, + const unsigned char *)) +{ + struct netdev_hw_addr *ha, *tmp; + int err; + + /* first go through and flush out any stale entries */ + list_for_each_entry_safe(ha, tmp, &list->list, list) { + if (!ha->sync_cnt || ha->refcount != 1) + continue; + + /* if unsync is defined and fails defer unsyncing address */ + if (unsync && unsync(dev, ha->addr)) + continue; + + ha->sync_cnt--; + __hw_addr_del_entry(list, ha, false, false); + } + + /* go through and sync new entries to the list */ + list_for_each_entry_safe(ha, tmp, &list->list, list) { + if (ha->sync_cnt) + continue; + + err = sync(dev, ha->addr); + if (err) + return err; + + ha->sync_cnt++; + ha->refcount++; + } + + return 0; +} +EXPORT_SYMBOL(__hw_addr_sync_dev); + +/** + * __hw_addr_ref_sync_dev - Synchronize device's multicast address list taking + * into account references + * @list: address list to synchronize + * @dev: device to sync + * @sync: function to call if address or reference on it should be added + * @unsync: function to call if address or some reference on it should removed + * + * This function is intended to be called from the ndo_set_rx_mode + * function of devices that require explicit address or references on it + * add/remove notifications. The unsync function may be NULL in which case + * the addresses or references on it requiring removal will simply be + * removed without any notification to the device. That is responsibility of + * the driver to identify and distribute address or references on it between + * internal address tables. + **/ +int __hw_addr_ref_sync_dev(struct netdev_hw_addr_list *list, + struct net_device *dev, + int (*sync)(struct net_device *, + const unsigned char *, int), + int (*unsync)(struct net_device *, + const unsigned char *, int)) +{ + struct netdev_hw_addr *ha, *tmp; + int err, ref_cnt; + + /* first go through and flush out any unsynced/stale entries */ + list_for_each_entry_safe(ha, tmp, &list->list, list) { + /* sync if address is not used */ + if ((ha->sync_cnt << 1) <= ha->refcount) + continue; + + /* if fails defer unsyncing address */ + ref_cnt = ha->refcount - ha->sync_cnt; + if (unsync && unsync(dev, ha->addr, ref_cnt)) + continue; + + ha->refcount = (ref_cnt << 1) + 1; + ha->sync_cnt = ref_cnt; + __hw_addr_del_entry(list, ha, false, false); + } + + /* go through and sync updated/new entries to the list */ + list_for_each_entry_safe(ha, tmp, &list->list, list) { + /* sync if address added or reused */ + if ((ha->sync_cnt << 1) >= ha->refcount) + continue; + + ref_cnt = ha->refcount - ha->sync_cnt; + err = sync(dev, ha->addr, ref_cnt); + if (err) + return err; + + ha->refcount = ref_cnt << 1; + ha->sync_cnt = ref_cnt; + } + + return 0; +} +EXPORT_SYMBOL(__hw_addr_ref_sync_dev); + +/** + * __hw_addr_ref_unsync_dev - Remove synchronized addresses and references on + * it from device + * @list: address list to remove synchronized addresses (references on it) from + * @dev: device to sync + * @unsync: function to call if address and references on it should be removed + * + * Remove all addresses that were added to the device by + * __hw_addr_ref_sync_dev(). This function is intended to be called from the + * ndo_stop or ndo_open functions on devices that require explicit address (or + * references on it) add/remove notifications. 
If the unsync function pointer + * is NULL then this function can be used to just reset the sync_cnt for the + * addresses in the list. + **/ +void __hw_addr_ref_unsync_dev(struct netdev_hw_addr_list *list, + struct net_device *dev, + int (*unsync)(struct net_device *, + const unsigned char *, int)) +{ + struct netdev_hw_addr *ha, *tmp; + + list_for_each_entry_safe(ha, tmp, &list->list, list) { + if (!ha->sync_cnt) + continue; + + /* if fails defer unsyncing address */ + if (unsync && unsync(dev, ha->addr, ha->sync_cnt)) + continue; + + ha->refcount -= ha->sync_cnt - 1; + ha->sync_cnt = 0; + __hw_addr_del_entry(list, ha, false, false); + } +} +EXPORT_SYMBOL(__hw_addr_ref_unsync_dev); + +/** + * __hw_addr_unsync_dev - Remove synchronized addresses from device + * @list: address list to remove synchronized addresses from + * @dev: device to sync + * @unsync: function to call if address should be removed + * + * Remove all addresses that were added to the device by __hw_addr_sync_dev(). + * This function is intended to be called from the ndo_stop or ndo_open + * functions on devices that require explicit address add/remove + * notifications. If the unsync function pointer is NULL then this function + * can be used to just reset the sync_cnt for the addresses in the list. + **/ +void __hw_addr_unsync_dev(struct netdev_hw_addr_list *list, + struct net_device *dev, + int (*unsync)(struct net_device *, + const unsigned char *)) +{ + struct netdev_hw_addr *ha, *tmp; + + list_for_each_entry_safe(ha, tmp, &list->list, list) { + if (!ha->sync_cnt) + continue; + + /* if unsync is defined and fails defer unsyncing address */ + if (unsync && unsync(dev, ha->addr)) + continue; + + ha->sync_cnt--; + __hw_addr_del_entry(list, ha, false, false); + } +} +EXPORT_SYMBOL(__hw_addr_unsync_dev); + +static void __hw_addr_flush(struct netdev_hw_addr_list *list) +{ + struct netdev_hw_addr *ha, *tmp; + + list->tree = RB_ROOT; + list_for_each_entry_safe(ha, tmp, &list->list, list) { + list_del_rcu(&ha->list); + kfree_rcu(ha, rcu_head); + } + list->count = 0; +} + +void __hw_addr_init(struct netdev_hw_addr_list *list) +{ + INIT_LIST_HEAD(&list->list); + list->count = 0; + list->tree = RB_ROOT; +} +EXPORT_SYMBOL(__hw_addr_init); + +/* + * Device addresses handling functions + */ + +/* Check that netdev->dev_addr is not written to directly as this would + * break the rbtree layout. All changes should go thru dev_addr_set() and co. + * Remove this check in mid-2024. + */ +void dev_addr_check(struct net_device *dev) +{ + if (!memcmp(dev->dev_addr, dev->dev_addr_shadow, MAX_ADDR_LEN)) + return; + + netdev_warn(dev, "Current addr: %*ph\n", MAX_ADDR_LEN, dev->dev_addr); + netdev_warn(dev, "Expected addr: %*ph\n", + MAX_ADDR_LEN, dev->dev_addr_shadow); + netdev_WARN(dev, "Incorrect netdev->dev_addr\n"); +} + +/** + * dev_addr_flush - Flush device address list + * @dev: device + * + * Flush device address list and reset ->dev_addr. + * + * The caller must hold the rtnl_mutex. + */ +void dev_addr_flush(struct net_device *dev) +{ + /* rtnl_mutex must be held here */ + dev_addr_check(dev); + + __hw_addr_flush(&dev->dev_addrs); + dev->dev_addr = NULL; +} + +/** + * dev_addr_init - Init device address list + * @dev: device + * + * Init device address list and create the first element, + * used by ->dev_addr. + * + * The caller must hold the rtnl_mutex. 
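As a worked example of the rule enforced by dev_addr_check() above: a driver has to route every MAC update through eth_hw_addr_set()/dev_addr_set() so the rbtree-backed dev_addrs list stays consistent. A minimal sketch; the foo_* names are hypothetical, the etherdevice helpers are real.

#include <linux/etherdevice.h>
#include <linux/netdevice.h>

/* Hypothetical NVM read helper of an imaginary "foo" driver (stubbed). */
static void foo_read_mac_from_nvm(struct net_device *netdev, u8 *mac)
{
	eth_random_addr(mac);	/* stand-in for reading EEPROM/NVM */
}

static void foo_assign_mac(struct net_device *netdev)
{
	u8 mac[ETH_ALEN];

	foo_read_mac_from_nvm(netdev, mac);

	if (is_valid_ether_addr(mac))
		eth_hw_addr_set(netdev, mac);	/* keeps the dev_addrs rbtree in sync */
	else
		eth_hw_addr_random(netdev);

	/* memcpy(netdev->dev_addr, mac, ETH_ALEN) would trip dev_addr_check() */
}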
+ */ +int dev_addr_init(struct net_device *dev) +{ + unsigned char addr[MAX_ADDR_LEN]; + struct netdev_hw_addr *ha; + int err; + + /* rtnl_mutex must be held here */ + + __hw_addr_init(&dev->dev_addrs); + memset(addr, 0, sizeof(addr)); + err = __hw_addr_add(&dev->dev_addrs, addr, sizeof(addr), + NETDEV_HW_ADDR_T_LAN); + if (!err) { + /* + * Get the first (previously created) address from the list + * and set dev_addr pointer to this location. + */ + ha = list_first_entry(&dev->dev_addrs.list, + struct netdev_hw_addr, list); + dev->dev_addr = ha->addr; + } + return err; +} + +void dev_addr_mod(struct net_device *dev, unsigned int offset, + const void *addr, size_t len) +{ + struct netdev_hw_addr *ha; + + dev_addr_check(dev); + + ha = container_of(dev->dev_addr, struct netdev_hw_addr, addr[0]); + rb_erase(&ha->node, &dev->dev_addrs.tree); + memcpy(&ha->addr[offset], addr, len); + memcpy(&dev->dev_addr_shadow[offset], addr, len); + WARN_ON(__hw_addr_insert(&dev->dev_addrs, ha, dev->addr_len)); +} +EXPORT_SYMBOL(dev_addr_mod); + +/** + * dev_addr_add - Add a device address + * @dev: device + * @addr: address to add + * @addr_type: address type + * + * Add a device address to the device or increase the reference count if + * it already exists. + * + * The caller must hold the rtnl_mutex. + */ +int dev_addr_add(struct net_device *dev, const unsigned char *addr, + unsigned char addr_type) +{ + int err; + + ASSERT_RTNL(); + + err = dev_pre_changeaddr_notify(dev, addr, NULL); + if (err) + return err; + err = __hw_addr_add(&dev->dev_addrs, addr, dev->addr_len, addr_type); + if (!err) + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + return err; +} +EXPORT_SYMBOL(dev_addr_add); + +/** + * dev_addr_del - Release a device address. + * @dev: device + * @addr: address to delete + * @addr_type: address type + * + * Release reference to a device address and remove it from the device + * if the reference count drops to zero. + * + * The caller must hold the rtnl_mutex. + */ +int dev_addr_del(struct net_device *dev, const unsigned char *addr, + unsigned char addr_type) +{ + int err; + struct netdev_hw_addr *ha; + + ASSERT_RTNL(); + + /* + * We can not remove the first address from the list because + * dev->dev_addr points to that. + */ + ha = list_first_entry(&dev->dev_addrs.list, + struct netdev_hw_addr, list); + if (!memcmp(ha->addr, addr, dev->addr_len) && + ha->type == addr_type && ha->refcount == 1) + return -ENOENT; + + err = __hw_addr_del(&dev->dev_addrs, addr, dev->addr_len, + addr_type); + if (!err) + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + return err; +} +EXPORT_SYMBOL(dev_addr_del); + +/* + * Unicast list handling functions + */ + +/** + * dev_uc_add_excl - Add a global secondary unicast address + * @dev: device + * @addr: address to add + */ +int dev_uc_add_excl(struct net_device *dev, const unsigned char *addr) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_add_ex(&dev->uc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_UNICAST, true, false, + 0, true); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} +EXPORT_SYMBOL(dev_uc_add_excl); + +/** + * dev_uc_add - Add a secondary unicast address + * @dev: device + * @addr: address to add + * + * Add a secondary unicast address to the device or increase + * the reference count if it already exists. 
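A small sketch of the refcounting behaviour documented above, in the style of an upper device (macvlan-like) that needs its own MAC to pass the lower device's unicast filter. The bar_* names are hypothetical; dev_uc_add()/dev_uc_del() are the helpers defined here.

#include <linux/netdevice.h>

static int bar_join_lower(struct net_device *upper, struct net_device *lower)
{
	/* refcounted: adding an address that already exists only bumps its count */
	return dev_uc_add(lower, upper->dev_addr);
}

static void bar_leave_lower(struct net_device *upper, struct net_device *lower)
{
	/* drops one reference; the filter entry disappears at refcount zero */
	dev_uc_del(lower, upper->dev_addr);
}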
+ */ +int dev_uc_add(struct net_device *dev, const unsigned char *addr) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_add(&dev->uc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_UNICAST); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} +EXPORT_SYMBOL(dev_uc_add); + +/** + * dev_uc_del - Release secondary unicast address. + * @dev: device + * @addr: address to delete + * + * Release reference to a secondary unicast address and remove it + * from the device if the reference count drops to zero. + */ +int dev_uc_del(struct net_device *dev, const unsigned char *addr) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_del(&dev->uc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_UNICAST); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} +EXPORT_SYMBOL(dev_uc_del); + +/** + * dev_uc_sync - Synchronize device's unicast list to another device + * @to: destination device + * @from: source device + * + * Add newly added addresses to the destination device and release + * addresses that have no users left. The source device must be + * locked by netif_addr_lock_bh. + * + * This function is intended to be called from the dev->set_rx_mode + * function of layered software devices. This function assumes that + * addresses will only ever be synced to the @to devices and no other. + */ +int dev_uc_sync(struct net_device *to, struct net_device *from) +{ + int err = 0; + + if (to->addr_len != from->addr_len) + return -EINVAL; + + netif_addr_lock(to); + err = __hw_addr_sync(&to->uc, &from->uc, to->addr_len); + if (!err) + __dev_set_rx_mode(to); + netif_addr_unlock(to); + return err; +} +EXPORT_SYMBOL(dev_uc_sync); + +/** + * dev_uc_sync_multiple - Synchronize device's unicast list to another + * device, but allow for multiple calls to sync to multiple devices. + * @to: destination device + * @from: source device + * + * Add newly added addresses to the destination device and release + * addresses that have been deleted from the source. The source device + * must be locked by netif_addr_lock_bh. + * + * This function is intended to be called from the dev->set_rx_mode + * function of layered software devices. It allows for a single source + * device to be synced to multiple destination devices. + */ +int dev_uc_sync_multiple(struct net_device *to, struct net_device *from) +{ + int err = 0; + + if (to->addr_len != from->addr_len) + return -EINVAL; + + netif_addr_lock(to); + err = __hw_addr_sync_multiple(&to->uc, &from->uc, to->addr_len); + if (!err) + __dev_set_rx_mode(to); + netif_addr_unlock(to); + return err; +} +EXPORT_SYMBOL(dev_uc_sync_multiple); + +/** + * dev_uc_unsync - Remove synchronized addresses from the destination device + * @to: destination device + * @from: source device + * + * Remove all addresses that were added to the destination device by + * dev_uc_sync(). This function is intended to be called from the + * dev->stop function of layered software devices. + */ +void dev_uc_unsync(struct net_device *to, struct net_device *from) +{ + if (to->addr_len != from->addr_len) + return; + + /* netif_addr_lock_bh() uses lockdep subclass 0, this is okay for two + * reasons: + * 1) This is always called without any addr_list_lock, so as the + * outermost one here, it must be 0. + * 2) This is called by some callers after unlinking the upper device, + * so the dev->lower_level becomes 1 again. + * Therefore, the subclass for 'from' is 0, for 'to' is either 1 or + * larger. 
+ */ + netif_addr_lock_bh(from); + netif_addr_lock(to); + __hw_addr_unsync(&to->uc, &from->uc, to->addr_len); + __dev_set_rx_mode(to); + netif_addr_unlock(to); + netif_addr_unlock_bh(from); +} +EXPORT_SYMBOL(dev_uc_unsync); + +/** + * dev_uc_flush - Flush unicast addresses + * @dev: device + * + * Flush unicast addresses. + */ +void dev_uc_flush(struct net_device *dev) +{ + netif_addr_lock_bh(dev); + __hw_addr_flush(&dev->uc); + netif_addr_unlock_bh(dev); +} +EXPORT_SYMBOL(dev_uc_flush); + +/** + * dev_uc_init - Init unicast address list + * @dev: device + * + * Init unicast address list. + */ +void dev_uc_init(struct net_device *dev) +{ + __hw_addr_init(&dev->uc); +} +EXPORT_SYMBOL(dev_uc_init); + +/* + * Multicast list handling functions + */ + +/** + * dev_mc_add_excl - Add a global secondary multicast address + * @dev: device + * @addr: address to add + */ +int dev_mc_add_excl(struct net_device *dev, const unsigned char *addr) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_add_ex(&dev->mc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_MULTICAST, true, false, + 0, true); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} +EXPORT_SYMBOL(dev_mc_add_excl); + +static int __dev_mc_add(struct net_device *dev, const unsigned char *addr, + bool global) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_add_ex(&dev->mc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_MULTICAST, global, false, + 0, false); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} +/** + * dev_mc_add - Add a multicast address + * @dev: device + * @addr: address to add + * + * Add a multicast address to the device or increase + * the reference count if it already exists. + */ +int dev_mc_add(struct net_device *dev, const unsigned char *addr) +{ + return __dev_mc_add(dev, addr, false); +} +EXPORT_SYMBOL(dev_mc_add); + +/** + * dev_mc_add_global - Add a global multicast address + * @dev: device + * @addr: address to add + * + * Add a global multicast address to the device. + */ +int dev_mc_add_global(struct net_device *dev, const unsigned char *addr) +{ + return __dev_mc_add(dev, addr, true); +} +EXPORT_SYMBOL(dev_mc_add_global); + +static int __dev_mc_del(struct net_device *dev, const unsigned char *addr, + bool global) +{ + int err; + + netif_addr_lock_bh(dev); + err = __hw_addr_del_ex(&dev->mc, addr, dev->addr_len, + NETDEV_HW_ADDR_T_MULTICAST, global, false); + if (!err) + __dev_set_rx_mode(dev); + netif_addr_unlock_bh(dev); + return err; +} + +/** + * dev_mc_del - Delete a multicast address. + * @dev: device + * @addr: address to delete + * + * Release reference to a multicast address and remove it + * from the device if the reference count drops to zero. + */ +int dev_mc_del(struct net_device *dev, const unsigned char *addr) +{ + return __dev_mc_del(dev, addr, false); +} +EXPORT_SYMBOL(dev_mc_del); + +/** + * dev_mc_del_global - Delete a global multicast address. + * @dev: device + * @addr: address to delete + * + * Release reference to a multicast address and remove it + * from the device if the reference count drops to zero. + */ +int dev_mc_del_global(struct net_device *dev, const unsigned char *addr) +{ + return __dev_mc_del(dev, addr, true); +} +EXPORT_SYMBOL(dev_mc_del_global); + +/** + * dev_mc_sync - Synchronize device's multicast list to another device + * @to: destination device + * @from: source device + * + * Add newly added addresses to the destination device and release + * addresses that have no users left. 
The source device must be + * locked by netif_addr_lock_bh. + * + * This function is intended to be called from the ndo_set_rx_mode + * function of layered software devices. + */ +int dev_mc_sync(struct net_device *to, struct net_device *from) +{ + int err = 0; + + if (to->addr_len != from->addr_len) + return -EINVAL; + + netif_addr_lock(to); + err = __hw_addr_sync(&to->mc, &from->mc, to->addr_len); + if (!err) + __dev_set_rx_mode(to); + netif_addr_unlock(to); + return err; +} +EXPORT_SYMBOL(dev_mc_sync); + +/** + * dev_mc_sync_multiple - Synchronize device's multicast list to another + * device, but allow for multiple calls to sync to multiple devices. + * @to: destination device + * @from: source device + * + * Add newly added addresses to the destination device and release + * addresses that have no users left. The source device must be + * locked by netif_addr_lock_bh. + * + * This function is intended to be called from the ndo_set_rx_mode + * function of layered software devices. It allows for a single + * source device to be synced to multiple destination devices. + */ +int dev_mc_sync_multiple(struct net_device *to, struct net_device *from) +{ + int err = 0; + + if (to->addr_len != from->addr_len) + return -EINVAL; + + netif_addr_lock(to); + err = __hw_addr_sync_multiple(&to->mc, &from->mc, to->addr_len); + if (!err) + __dev_set_rx_mode(to); + netif_addr_unlock(to); + return err; +} +EXPORT_SYMBOL(dev_mc_sync_multiple); + +/** + * dev_mc_unsync - Remove synchronized addresses from the destination device + * @to: destination device + * @from: source device + * + * Remove all addresses that were added to the destination device by + * dev_mc_sync(). This function is intended to be called from the + * dev->stop function of layered software devices. + */ +void dev_mc_unsync(struct net_device *to, struct net_device *from) +{ + if (to->addr_len != from->addr_len) + return; + + /* See the above comments inside dev_uc_unsync(). */ + netif_addr_lock_bh(from); + netif_addr_lock(to); + __hw_addr_unsync(&to->mc, &from->mc, to->addr_len); + __dev_set_rx_mode(to); + netif_addr_unlock(to); + netif_addr_unlock_bh(from); +} +EXPORT_SYMBOL(dev_mc_unsync); + +/** + * dev_mc_flush - Flush multicast addresses + * @dev: device + * + * Flush multicast addresses. + */ +void dev_mc_flush(struct net_device *dev) +{ + netif_addr_lock_bh(dev); + __hw_addr_flush(&dev->mc); + netif_addr_unlock_bh(dev); +} +EXPORT_SYMBOL(dev_mc_flush); + +/** + * dev_mc_init - Init multicast address list + * @dev: device + * + * Init multicast address list. 
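Putting dev_uc_sync()/dev_mc_sync() and their unsync counterparts together, here is a sketch of how a stacked device (in the spirit of 8021q or macvlan) pushes its address lists down to the real device underneath. The baz_* names and the priv layout are hypothetical; the dev_*_sync()/dev_*_unsync() calls match the helpers in this file.

#include <linux/netdevice.h>

struct baz_priv {
	struct net_device *lower;	/* the real device underneath */
};

static struct net_device *baz_lower_dev(struct net_device *upper)
{
	return ((struct baz_priv *)netdev_priv(upper))->lower;
}

/* .ndo_set_rx_mode of the upper device */
static void baz_set_rx_mode(struct net_device *upper)
{
	struct net_device *lower = baz_lower_dev(upper);

	dev_uc_sync(lower, upper);	/* push new unicast entries down */
	dev_mc_sync(lower, upper);	/* push new multicast entries down */
}

/* .ndo_stop of the upper device */
static int baz_stop(struct net_device *upper)
{
	struct net_device *lower = baz_lower_dev(upper);

	dev_uc_unsync(lower, upper);	/* release everything we synced */
	dev_mc_unsync(lower, upper);
	return 0;
}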
+ */ +void dev_mc_init(struct net_device *dev) +{ + __hw_addr_init(&dev->mc); +} +EXPORT_SYMBOL(dev_mc_init); diff --git a/net/core/dev_addr_lists_test.c b/net/core/dev_addr_lists_test.c new file mode 100644 index 000000000..049cfbc58 --- /dev/null +++ b/net/core/dev_addr_lists_test.c @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include +#include +#include +#include + +static const struct net_device_ops dummy_netdev_ops = { +}; + +struct dev_addr_test_priv { + u32 addr_seen; +}; + +static int dev_addr_test_sync(struct net_device *netdev, const unsigned char *a) +{ + struct dev_addr_test_priv *datp = netdev_priv(netdev); + + if (a[0] < 31 && !memchr_inv(a, a[0], ETH_ALEN)) + datp->addr_seen |= 1 << a[0]; + return 0; +} + +static int dev_addr_test_unsync(struct net_device *netdev, + const unsigned char *a) +{ + struct dev_addr_test_priv *datp = netdev_priv(netdev); + + if (a[0] < 31 && !memchr_inv(a, a[0], ETH_ALEN)) + datp->addr_seen &= ~(1 << a[0]); + return 0; +} + +static int dev_addr_test_init(struct kunit *test) +{ + struct dev_addr_test_priv *datp; + struct net_device *netdev; + int err; + + netdev = alloc_etherdev(sizeof(*datp)); + KUNIT_ASSERT_TRUE(test, !!netdev); + + test->priv = netdev; + netdev->netdev_ops = &dummy_netdev_ops; + + err = register_netdev(netdev); + if (err) { + free_netdev(netdev); + KUNIT_FAIL(test, "Can't register netdev %d", err); + } + + rtnl_lock(); + return 0; +} + +static void dev_addr_test_exit(struct kunit *test) +{ + struct net_device *netdev = test->priv; + + rtnl_unlock(); + unregister_netdev(netdev); + free_netdev(netdev); +} + +static void dev_addr_test_basic(struct kunit *test) +{ + struct net_device *netdev = test->priv; + u8 addr[ETH_ALEN]; + + KUNIT_EXPECT_TRUE(test, !!netdev->dev_addr); + + memset(addr, 2, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + KUNIT_EXPECT_EQ(test, 0, memcmp(netdev->dev_addr, addr, sizeof(addr))); + + memset(addr, 3, sizeof(addr)); + dev_addr_set(netdev, addr); + KUNIT_EXPECT_EQ(test, 0, memcmp(netdev->dev_addr, addr, sizeof(addr))); +} + +static void dev_addr_test_sync_one(struct kunit *test) +{ + struct net_device *netdev = test->priv; + struct dev_addr_test_priv *datp; + u8 addr[ETH_ALEN]; + + datp = netdev_priv(netdev); + + memset(addr, 1, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + KUNIT_EXPECT_EQ(test, 2, datp->addr_seen); + + memset(addr, 2, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + + datp->addr_seen = 0; + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + /* It's not going to sync anything because the main address is + * considered synced and we overwrite in place. 
+ */ + KUNIT_EXPECT_EQ(test, 0, datp->addr_seen); +} + +static void dev_addr_test_add_del(struct kunit *test) +{ + struct net_device *netdev = test->priv; + struct dev_addr_test_priv *datp; + u8 addr[ETH_ALEN]; + int i; + + datp = netdev_priv(netdev); + + for (i = 1; i < 4; i++) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, 0, dev_addr_add(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + } + /* Add 3 again */ + KUNIT_EXPECT_EQ(test, 0, dev_addr_add(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + KUNIT_EXPECT_EQ(test, 0xf, datp->addr_seen); + + KUNIT_EXPECT_EQ(test, 0, dev_addr_del(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + KUNIT_EXPECT_EQ(test, 0xf, datp->addr_seen); + + for (i = 1; i < 4; i++) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, 0, dev_addr_del(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + } + + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + KUNIT_EXPECT_EQ(test, 1, datp->addr_seen); +} + +static void dev_addr_test_del_main(struct kunit *test) +{ + struct net_device *netdev = test->priv; + u8 addr[ETH_ALEN]; + + memset(addr, 1, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + + KUNIT_EXPECT_EQ(test, -ENOENT, dev_addr_del(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + KUNIT_EXPECT_EQ(test, 0, dev_addr_add(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + KUNIT_EXPECT_EQ(test, 0, dev_addr_del(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + KUNIT_EXPECT_EQ(test, -ENOENT, dev_addr_del(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); +} + +static void dev_addr_test_add_set(struct kunit *test) +{ + struct net_device *netdev = test->priv; + struct dev_addr_test_priv *datp; + u8 addr[ETH_ALEN]; + int i; + + datp = netdev_priv(netdev); + + /* There is no external API like dev_addr_add_excl(), + * so shuffle the tree a little bit and exploit aliasing. 
+ */ + for (i = 1; i < 16; i++) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, 0, dev_addr_add(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + } + + memset(addr, i, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + KUNIT_EXPECT_EQ(test, 0, dev_addr_add(netdev, addr, + NETDEV_HW_ADDR_T_LAN)); + memset(addr, 0, sizeof(addr)); + eth_hw_addr_set(netdev, addr); + + __hw_addr_sync_dev(&netdev->dev_addrs, netdev, dev_addr_test_sync, + dev_addr_test_unsync); + KUNIT_EXPECT_EQ(test, 0xffff, datp->addr_seen); +} + +static void dev_addr_test_add_excl(struct kunit *test) +{ + struct net_device *netdev = test->priv; + u8 addr[ETH_ALEN]; + int i; + + for (i = 0; i < 10; i++) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, 0, dev_uc_add_excl(netdev, addr)); + } + KUNIT_EXPECT_EQ(test, -EEXIST, dev_uc_add_excl(netdev, addr)); + + for (i = 0; i < 10; i += 2) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, 0, dev_uc_del(netdev, addr)); + } + for (i = 1; i < 10; i += 2) { + memset(addr, i, sizeof(addr)); + KUNIT_EXPECT_EQ(test, -EEXIST, dev_uc_add_excl(netdev, addr)); + } +} + +static struct kunit_case dev_addr_test_cases[] = { + KUNIT_CASE(dev_addr_test_basic), + KUNIT_CASE(dev_addr_test_sync_one), + KUNIT_CASE(dev_addr_test_add_del), + KUNIT_CASE(dev_addr_test_del_main), + KUNIT_CASE(dev_addr_test_add_set), + KUNIT_CASE(dev_addr_test_add_excl), + {} +}; + +static struct kunit_suite dev_addr_test_suite = { + .name = "dev-addr-list-test", + .test_cases = dev_addr_test_cases, + .init = dev_addr_test_init, + .exit = dev_addr_test_exit, +}; +kunit_test_suite(dev_addr_test_suite); + +MODULE_LICENSE("GPL"); diff --git a/net/core/dev_ioctl.c b/net/core/dev_ioctl.c new file mode 100644 index 000000000..7674bb9f3 --- /dev/null +++ b/net/core/dev_ioctl.c @@ -0,0 +1,619 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dev.h" + +/* + * Map an interface index to its name (SIOCGIFNAME) + */ + +/* + * We need this ioctl for efficient implementation of the + * if_indextoname() function required by the IPv6 API. Without + * it, we would have to search all the interfaces to find a + * match. --pb + */ + +static int dev_ifname(struct net *net, struct ifreq *ifr) +{ + ifr->ifr_name[IFNAMSIZ-1] = 0; + return netdev_get_name(net, ifr->ifr_name, ifr->ifr_ifindex); +} + +/* + * Perform a SIOCGIFCONF call. This structure will change + * size eventually, and there is nothing I can do about it. + * Thus we will need a 'compatibility mode'. + */ +int dev_ifconf(struct net *net, struct ifconf __user *uifc) +{ + struct net_device *dev; + void __user *pos; + size_t size; + int len, total = 0, done; + + /* both the ifconf and the ifreq structures are slightly different */ + if (in_compat_syscall()) { + struct compat_ifconf ifc32; + + if (copy_from_user(&ifc32, uifc, sizeof(struct compat_ifconf))) + return -EFAULT; + + pos = compat_ptr(ifc32.ifcbuf); + len = ifc32.ifc_len; + size = sizeof(struct compat_ifreq); + } else { + struct ifconf ifc; + + if (copy_from_user(&ifc, uifc, sizeof(struct ifconf))) + return -EFAULT; + + pos = ifc.ifc_buf; + len = ifc.ifc_len; + size = sizeof(struct ifreq); + } + + /* Loop over the interfaces, and write an info block for each. 
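For reference, the userspace side of the SIOCGIFCONF path handled here can be exercised with plain libc. A minimal sketch with a fixed-size buffer and no retry logic; note that, as implemented via inet_gifconf(), only interfaces holding an IPv4 address produce entries.

#include <stdio.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <net/if.h>

int main(void)
{
	struct ifreq reqs[32];
	struct ifconf ifc = {
		.ifc_len = sizeof(reqs),
		.ifc_req = reqs,
	};
	int fd = socket(AF_INET, SOCK_DGRAM, 0);

	if (fd < 0) {
		perror("socket");
		return 1;
	}
	if (ioctl(fd, SIOCGIFCONF, &ifc) < 0) {
		perror("SIOCGIFCONF");
		close(fd);
		return 1;
	}

	/* one entry per configured IPv4 address */
	for (int i = 0; i < (int)(ifc.ifc_len / sizeof(struct ifreq)); i++)
		printf("%s\n", reqs[i].ifr_name);

	close(fd);
	return 0;
}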
*/ + rtnl_lock(); + for_each_netdev(net, dev) { + if (!pos) + done = inet_gifconf(dev, NULL, 0, size); + else + done = inet_gifconf(dev, pos + total, + len - total, size); + if (done < 0) { + rtnl_unlock(); + return -EFAULT; + } + total += done; + } + rtnl_unlock(); + + return put_user(total, &uifc->ifc_len); +} + +static int dev_getifmap(struct net_device *dev, struct ifreq *ifr) +{ + struct ifmap *ifmap = &ifr->ifr_map; + + if (in_compat_syscall()) { + struct compat_ifmap *cifmap = (struct compat_ifmap *)ifmap; + + cifmap->mem_start = dev->mem_start; + cifmap->mem_end = dev->mem_end; + cifmap->base_addr = dev->base_addr; + cifmap->irq = dev->irq; + cifmap->dma = dev->dma; + cifmap->port = dev->if_port; + + return 0; + } + + ifmap->mem_start = dev->mem_start; + ifmap->mem_end = dev->mem_end; + ifmap->base_addr = dev->base_addr; + ifmap->irq = dev->irq; + ifmap->dma = dev->dma; + ifmap->port = dev->if_port; + + return 0; +} + +static int dev_setifmap(struct net_device *dev, struct ifreq *ifr) +{ + struct compat_ifmap *cifmap = (struct compat_ifmap *)&ifr->ifr_map; + + if (!dev->netdev_ops->ndo_set_config) + return -EOPNOTSUPP; + + if (in_compat_syscall()) { + struct ifmap ifmap = { + .mem_start = cifmap->mem_start, + .mem_end = cifmap->mem_end, + .base_addr = cifmap->base_addr, + .irq = cifmap->irq, + .dma = cifmap->dma, + .port = cifmap->port, + }; + + return dev->netdev_ops->ndo_set_config(dev, &ifmap); + } + + return dev->netdev_ops->ndo_set_config(dev, &ifr->ifr_map); +} + +/* + * Perform the SIOCxIFxxx calls, inside rcu_read_lock() + */ +static int dev_ifsioc_locked(struct net *net, struct ifreq *ifr, unsigned int cmd) +{ + int err; + struct net_device *dev = dev_get_by_name_rcu(net, ifr->ifr_name); + + if (!dev) + return -ENODEV; + + switch (cmd) { + case SIOCGIFFLAGS: /* Get interface flags */ + ifr->ifr_flags = (short) dev_get_flags(dev); + return 0; + + case SIOCGIFMETRIC: /* Get the metric on the interface + (currently unused) */ + ifr->ifr_metric = 0; + return 0; + + case SIOCGIFMTU: /* Get the MTU of a device */ + ifr->ifr_mtu = dev->mtu; + return 0; + + case SIOCGIFSLAVE: + err = -EINVAL; + break; + + case SIOCGIFMAP: + return dev_getifmap(dev, ifr); + + case SIOCGIFINDEX: + ifr->ifr_ifindex = dev->ifindex; + return 0; + + case SIOCGIFTXQLEN: + ifr->ifr_qlen = dev->tx_queue_len; + return 0; + + default: + /* dev_ioctl() should ensure this case + * is never reached + */ + WARN_ON(1); + err = -ENOTTY; + break; + + } + return err; +} + +static int net_hwtstamp_validate(struct ifreq *ifr) +{ + struct hwtstamp_config cfg; + enum hwtstamp_tx_types tx_type; + enum hwtstamp_rx_filters rx_filter; + int tx_type_valid = 0; + int rx_filter_valid = 0; + + if (copy_from_user(&cfg, ifr->ifr_data, sizeof(cfg))) + return -EFAULT; + + if (cfg.flags & ~HWTSTAMP_FLAG_MASK) + return -EINVAL; + + tx_type = cfg.tx_type; + rx_filter = cfg.rx_filter; + + switch (tx_type) { + case HWTSTAMP_TX_OFF: + case HWTSTAMP_TX_ON: + case HWTSTAMP_TX_ONESTEP_SYNC: + case HWTSTAMP_TX_ONESTEP_P2P: + tx_type_valid = 1; + break; + case __HWTSTAMP_TX_CNT: + /* not a real value */ + break; + } + + switch (rx_filter) { + case HWTSTAMP_FILTER_NONE: + case HWTSTAMP_FILTER_ALL: + case HWTSTAMP_FILTER_SOME: + case HWTSTAMP_FILTER_PTP_V1_L4_EVENT: + case HWTSTAMP_FILTER_PTP_V1_L4_SYNC: + case HWTSTAMP_FILTER_PTP_V1_L4_DELAY_REQ: + case HWTSTAMP_FILTER_PTP_V2_L4_EVENT: + case HWTSTAMP_FILTER_PTP_V2_L4_SYNC: + case HWTSTAMP_FILTER_PTP_V2_L4_DELAY_REQ: + case HWTSTAMP_FILTER_PTP_V2_L2_EVENT: + case 
HWTSTAMP_FILTER_PTP_V2_L2_SYNC: + case HWTSTAMP_FILTER_PTP_V2_L2_DELAY_REQ: + case HWTSTAMP_FILTER_PTP_V2_EVENT: + case HWTSTAMP_FILTER_PTP_V2_SYNC: + case HWTSTAMP_FILTER_PTP_V2_DELAY_REQ: + case HWTSTAMP_FILTER_NTP_ALL: + rx_filter_valid = 1; + break; + case __HWTSTAMP_FILTER_CNT: + /* not a real value */ + break; + } + + if (!tx_type_valid || !rx_filter_valid) + return -ERANGE; + + return 0; +} + +static int dev_eth_ioctl(struct net_device *dev, + struct ifreq *ifr, unsigned int cmd) +{ + const struct net_device_ops *ops = dev->netdev_ops; + int err; + + err = dsa_ndo_eth_ioctl(dev, ifr, cmd); + if (err == 0 || err != -EOPNOTSUPP) + return err; + + if (ops->ndo_eth_ioctl) { + if (netif_device_present(dev)) + err = ops->ndo_eth_ioctl(dev, ifr, cmd); + else + err = -ENODEV; + } + + return err; +} + +static int dev_siocbond(struct net_device *dev, + struct ifreq *ifr, unsigned int cmd) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_siocbond) { + if (netif_device_present(dev)) + return ops->ndo_siocbond(dev, ifr, cmd); + else + return -ENODEV; + } + + return -EOPNOTSUPP; +} + +static int dev_siocdevprivate(struct net_device *dev, struct ifreq *ifr, + void __user *data, unsigned int cmd) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_siocdevprivate) { + if (netif_device_present(dev)) + return ops->ndo_siocdevprivate(dev, ifr, data, cmd); + else + return -ENODEV; + } + + return -EOPNOTSUPP; +} + +static int dev_siocwandev(struct net_device *dev, struct if_settings *ifs) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (ops->ndo_siocwandev) { + if (netif_device_present(dev)) + return ops->ndo_siocwandev(dev, ifs); + else + return -ENODEV; + } + + return -EOPNOTSUPP; +} + +/* + * Perform the SIOCxIFxxx calls, inside rtnl_lock() + */ +static int dev_ifsioc(struct net *net, struct ifreq *ifr, void __user *data, + unsigned int cmd) +{ + int err; + struct net_device *dev = __dev_get_by_name(net, ifr->ifr_name); + const struct net_device_ops *ops; + netdevice_tracker dev_tracker; + + if (!dev) + return -ENODEV; + + ops = dev->netdev_ops; + + switch (cmd) { + case SIOCSIFFLAGS: /* Set interface flags */ + return dev_change_flags(dev, ifr->ifr_flags, NULL); + + case SIOCSIFMETRIC: /* Set the metric on the interface + (currently unused) */ + return -EOPNOTSUPP; + + case SIOCSIFMTU: /* Set the MTU of a device */ + return dev_set_mtu(dev, ifr->ifr_mtu); + + case SIOCSIFHWADDR: + if (dev->addr_len > sizeof(struct sockaddr)) + return -EINVAL; + return dev_set_mac_address_user(dev, &ifr->ifr_hwaddr, NULL); + + case SIOCSIFHWBROADCAST: + if (ifr->ifr_hwaddr.sa_family != dev->type) + return -EINVAL; + memcpy(dev->broadcast, ifr->ifr_hwaddr.sa_data, + min(sizeof(ifr->ifr_hwaddr.sa_data), + (size_t)dev->addr_len)); + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + return 0; + + case SIOCSIFMAP: + return dev_setifmap(dev, ifr); + + case SIOCADDMULTI: + if (!ops->ndo_set_rx_mode || + ifr->ifr_hwaddr.sa_family != AF_UNSPEC) + return -EINVAL; + if (!netif_device_present(dev)) + return -ENODEV; + return dev_mc_add_global(dev, ifr->ifr_hwaddr.sa_data); + + case SIOCDELMULTI: + if (!ops->ndo_set_rx_mode || + ifr->ifr_hwaddr.sa_family != AF_UNSPEC) + return -EINVAL; + if (!netif_device_present(dev)) + return -ENODEV; + return dev_mc_del_global(dev, ifr->ifr_hwaddr.sa_data); + + case SIOCSIFTXQLEN: + if (ifr->ifr_qlen < 0) + return -EINVAL; + return dev_change_tx_queue_len(dev, ifr->ifr_qlen); + + case SIOCSIFNAME: + ifr->ifr_newname[IFNAMSIZ-1] = 
'\0'; + return dev_change_name(dev, ifr->ifr_newname); + + case SIOCWANDEV: + return dev_siocwandev(dev, &ifr->ifr_settings); + + case SIOCBRADDIF: + case SIOCBRDELIF: + if (!netif_device_present(dev)) + return -ENODEV; + if (!netif_is_bridge_master(dev)) + return -EOPNOTSUPP; + netdev_hold(dev, &dev_tracker, GFP_KERNEL); + rtnl_unlock(); + err = br_ioctl_call(net, netdev_priv(dev), cmd, ifr, NULL); + netdev_put(dev, &dev_tracker); + rtnl_lock(); + return err; + + case SIOCSHWTSTAMP: + err = net_hwtstamp_validate(ifr); + if (err) + return err; + fallthrough; + + /* + * Unknown or private ioctl + */ + default: + if (cmd >= SIOCDEVPRIVATE && + cmd <= SIOCDEVPRIVATE + 15) + return dev_siocdevprivate(dev, ifr, data, cmd); + + if (cmd == SIOCGMIIPHY || + cmd == SIOCGMIIREG || + cmd == SIOCSMIIREG || + cmd == SIOCSHWTSTAMP || + cmd == SIOCGHWTSTAMP) { + err = dev_eth_ioctl(dev, ifr, cmd); + } else if (cmd == SIOCBONDENSLAVE || + cmd == SIOCBONDRELEASE || + cmd == SIOCBONDSETHWADDR || + cmd == SIOCBONDSLAVEINFOQUERY || + cmd == SIOCBONDINFOQUERY || + cmd == SIOCBONDCHANGEACTIVE) { + err = dev_siocbond(dev, ifr, cmd); + } else + err = -EINVAL; + + } + return err; +} + +/** + * dev_load - load a network module + * @net: the applicable net namespace + * @name: name of interface + * + * If a network interface is not present and the process has suitable + * privileges this function loads the module. If module loading is not + * available in this kernel then it becomes a nop. + */ + +void dev_load(struct net *net, const char *name) +{ + struct net_device *dev; + int no_module; + + rcu_read_lock(); + dev = dev_get_by_name_rcu(net, name); + rcu_read_unlock(); + + no_module = !dev; + if (no_module && capable(CAP_NET_ADMIN)) + no_module = request_module("netdev-%s", name); + if (no_module && capable(CAP_SYS_MODULE)) + request_module("%s", name); +} +EXPORT_SYMBOL(dev_load); + +/* + * This function handles all "interface"-type I/O control requests. The actual + * 'doing' part of this is dev_ifsioc above. + */ + +/** + * dev_ioctl - network device ioctl + * @net: the applicable net namespace + * @cmd: command to issue + * @ifr: pointer to a struct ifreq in user space + * @need_copyout: whether or not copy_to_user() should be called + * + * Issue ioctl functions to devices. This is normally called by the + * user space syscall interfaces but can sometimes be useful for + * other purposes. The return value is the return from the syscall if + * positive or a negative errno code on error. + */ + +int dev_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr, + void __user *data, bool *need_copyout) +{ + int ret; + char *colon; + + if (need_copyout) + *need_copyout = true; + if (cmd == SIOCGIFNAME) + return dev_ifname(net, ifr); + + ifr->ifr_name[IFNAMSIZ-1] = 0; + + colon = strchr(ifr->ifr_name, ':'); + if (colon) + *colon = 0; + + /* + * See which interface the caller is talking about. + */ + + switch (cmd) { + case SIOCGIFHWADDR: + dev_load(net, ifr->ifr_name); + ret = dev_get_mac_address(&ifr->ifr_hwaddr, net, ifr->ifr_name); + if (colon) + *colon = ':'; + return ret; + /* + * These ioctl calls: + * - can be done by all. + * - atomic and do not require locking. 
+ * - return a value + */ + case SIOCGIFFLAGS: + case SIOCGIFMETRIC: + case SIOCGIFMTU: + case SIOCGIFSLAVE: + case SIOCGIFMAP: + case SIOCGIFINDEX: + case SIOCGIFTXQLEN: + dev_load(net, ifr->ifr_name); + rcu_read_lock(); + ret = dev_ifsioc_locked(net, ifr, cmd); + rcu_read_unlock(); + if (colon) + *colon = ':'; + return ret; + + case SIOCETHTOOL: + dev_load(net, ifr->ifr_name); + ret = dev_ethtool(net, ifr, data); + if (colon) + *colon = ':'; + return ret; + + /* + * These ioctl calls: + * - require superuser power. + * - require strict serialization. + * - return a value + */ + case SIOCGMIIPHY: + case SIOCGMIIREG: + case SIOCSIFNAME: + dev_load(net, ifr->ifr_name); + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + rtnl_lock(); + ret = dev_ifsioc(net, ifr, data, cmd); + rtnl_unlock(); + if (colon) + *colon = ':'; + return ret; + + /* + * These ioctl calls: + * - require superuser power. + * - require strict serialization. + * - do not return a value + */ + case SIOCSIFMAP: + case SIOCSIFTXQLEN: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + fallthrough; + /* + * These ioctl calls: + * - require local superuser power. + * - require strict serialization. + * - do not return a value + */ + case SIOCSIFFLAGS: + case SIOCSIFMETRIC: + case SIOCSIFMTU: + case SIOCSIFHWADDR: + case SIOCSIFSLAVE: + case SIOCADDMULTI: + case SIOCDELMULTI: + case SIOCSIFHWBROADCAST: + case SIOCSMIIREG: + case SIOCBONDENSLAVE: + case SIOCBONDRELEASE: + case SIOCBONDSETHWADDR: + case SIOCBONDCHANGEACTIVE: + case SIOCBRADDIF: + case SIOCBRDELIF: + case SIOCSHWTSTAMP: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + fallthrough; + case SIOCBONDSLAVEINFOQUERY: + case SIOCBONDINFOQUERY: + dev_load(net, ifr->ifr_name); + rtnl_lock(); + ret = dev_ifsioc(net, ifr, data, cmd); + rtnl_unlock(); + if (need_copyout) + *need_copyout = false; + return ret; + + case SIOCGIFMEM: + /* Get the per device memory space. We can add this but + * currently do not support it */ + case SIOCSIFMEM: + /* Set the per device memory buffer space. + * Not applicable in our case */ + case SIOCSIFLINK: + return -ENOTTY; + + /* + * Unknown or private ioctl. + */ + default: + if (cmd == SIOCWANDEV || + cmd == SIOCGHWTSTAMP || + (cmd >= SIOCDEVPRIVATE && + cmd <= SIOCDEVPRIVATE + 15)) { + dev_load(net, ifr->ifr_name); + rtnl_lock(); + ret = dev_ifsioc(net, ifr, data, cmd); + rtnl_unlock(); + return ret; + } + return -ENOTTY; + } +} diff --git a/net/core/drop_monitor.c b/net/core/drop_monitor.c new file mode 100644 index 000000000..8e0a90b45 --- /dev/null +++ b/net/core/drop_monitor.c @@ -0,0 +1,1772 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Monitoring code for network dropped packet alerts + * + * Copyright (C) 2009 Neil Horman + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#define TRACE_ON 1 +#define TRACE_OFF 0 + +/* + * Globals, our netlink socket pointer + * and the work handle that will send up + * netlink alerts + */ +static int trace_state = TRACE_OFF; +static bool monitor_hw; + +/* net_dm_mutex + * + * An overall lock guarding every operation coming from userspace. 
+ */ +static DEFINE_MUTEX(net_dm_mutex); + +struct net_dm_stats { + u64_stats_t dropped; + struct u64_stats_sync syncp; +}; + +#define NET_DM_MAX_HW_TRAP_NAME_LEN 40 + +struct net_dm_hw_entry { + char trap_name[NET_DM_MAX_HW_TRAP_NAME_LEN]; + u32 count; +}; + +struct net_dm_hw_entries { + u32 num_entries; + struct net_dm_hw_entry entries[]; +}; + +struct per_cpu_dm_data { + spinlock_t lock; /* Protects 'skb', 'hw_entries' and + * 'send_timer' + */ + union { + struct sk_buff *skb; + struct net_dm_hw_entries *hw_entries; + }; + struct sk_buff_head drop_queue; + struct work_struct dm_alert_work; + struct timer_list send_timer; + struct net_dm_stats stats; +}; + +struct dm_hw_stat_delta { + unsigned long last_rx; + unsigned long last_drop_val; + struct rcu_head rcu; +}; + +static struct genl_family net_drop_monitor_family; + +static DEFINE_PER_CPU(struct per_cpu_dm_data, dm_cpu_data); +static DEFINE_PER_CPU(struct per_cpu_dm_data, dm_hw_cpu_data); + +static int dm_hit_limit = 64; +static int dm_delay = 1; +static unsigned long dm_hw_check_delta = 2*HZ; + +static enum net_dm_alert_mode net_dm_alert_mode = NET_DM_ALERT_MODE_SUMMARY; +static u32 net_dm_trunc_len; +static u32 net_dm_queue_len = 1000; + +struct net_dm_alert_ops { + void (*kfree_skb_probe)(void *ignore, struct sk_buff *skb, + void *location, + enum skb_drop_reason reason); + void (*napi_poll_probe)(void *ignore, struct napi_struct *napi, + int work, int budget); + void (*work_item_func)(struct work_struct *work); + void (*hw_work_item_func)(struct work_struct *work); + void (*hw_trap_probe)(void *ignore, const struct devlink *devlink, + struct sk_buff *skb, + const struct devlink_trap_metadata *metadata); +}; + +struct net_dm_skb_cb { + union { + struct devlink_trap_metadata *hw_metadata; + void *pc; + }; + enum skb_drop_reason reason; +}; + +#define NET_DM_SKB_CB(__skb) ((struct net_dm_skb_cb *)&((__skb)->cb[0])) + +static struct sk_buff *reset_per_cpu_data(struct per_cpu_dm_data *data) +{ + size_t al; + struct net_dm_alert_msg *msg; + struct nlattr *nla; + struct sk_buff *skb; + unsigned long flags; + void *msg_header; + + al = sizeof(struct net_dm_alert_msg); + al += dm_hit_limit * sizeof(struct net_dm_drop_point); + al += sizeof(struct nlattr); + + skb = genlmsg_new(al, GFP_KERNEL); + + if (!skb) + goto err; + + msg_header = genlmsg_put(skb, 0, 0, &net_drop_monitor_family, + 0, NET_DM_CMD_ALERT); + if (!msg_header) { + nlmsg_free(skb); + skb = NULL; + goto err; + } + nla = nla_reserve(skb, NLA_UNSPEC, + sizeof(struct net_dm_alert_msg)); + if (!nla) { + nlmsg_free(skb); + skb = NULL; + goto err; + } + msg = nla_data(nla); + memset(msg, 0, al); + goto out; + +err: + mod_timer(&data->send_timer, jiffies + HZ / 10); +out: + spin_lock_irqsave(&data->lock, flags); + swap(data->skb, skb); + spin_unlock_irqrestore(&data->lock, flags); + + if (skb) { + struct nlmsghdr *nlh = (struct nlmsghdr *)skb->data; + struct genlmsghdr *gnlh = (struct genlmsghdr *)nlmsg_data(nlh); + + genlmsg_end(skb, genlmsg_data(gnlh)); + } + + return skb; +} + +static const struct genl_multicast_group dropmon_mcgrps[] = { + { .name = "events", .cap_sys_admin = 1 }, +}; + +static void send_dm_alert(struct work_struct *work) +{ + struct sk_buff *skb; + struct per_cpu_dm_data *data; + + data = container_of(work, struct per_cpu_dm_data, dm_alert_work); + + skb = reset_per_cpu_data(data); + + if (skb) + genlmsg_multicast(&net_drop_monitor_family, skb, 0, + 0, GFP_KERNEL); +} + +/* + * This is the timer function to delay the sending of an alert + * in the event that 
more drops will arrive during the + * hysteresis period. + */ +static void sched_send_work(struct timer_list *t) +{ + struct per_cpu_dm_data *data = from_timer(data, t, send_timer); + + schedule_work(&data->dm_alert_work); +} + +static void trace_drop_common(struct sk_buff *skb, void *location) +{ + struct net_dm_alert_msg *msg; + struct net_dm_drop_point *point; + struct nlmsghdr *nlh; + struct nlattr *nla; + int i; + struct sk_buff *dskb; + struct per_cpu_dm_data *data; + unsigned long flags; + + local_irq_save(flags); + data = this_cpu_ptr(&dm_cpu_data); + spin_lock(&data->lock); + dskb = data->skb; + + if (!dskb) + goto out; + + nlh = (struct nlmsghdr *)dskb->data; + nla = genlmsg_data(nlmsg_data(nlh)); + msg = nla_data(nla); + point = msg->points; + for (i = 0; i < msg->entries; i++) { + if (!memcmp(&location, &point->pc, sizeof(void *))) { + point->count++; + goto out; + } + point++; + } + if (msg->entries == dm_hit_limit) + goto out; + /* + * We need to create a new entry + */ + __nla_reserve_nohdr(dskb, sizeof(struct net_dm_drop_point)); + nla->nla_len += NLA_ALIGN(sizeof(struct net_dm_drop_point)); + memcpy(point->pc, &location, sizeof(void *)); + point->count = 1; + msg->entries++; + + if (!timer_pending(&data->send_timer)) { + data->send_timer.expires = jiffies + dm_delay * HZ; + add_timer(&data->send_timer); + } + +out: + spin_unlock_irqrestore(&data->lock, flags); +} + +static void trace_kfree_skb_hit(void *ignore, struct sk_buff *skb, + void *location, + enum skb_drop_reason reason) +{ + trace_drop_common(skb, location); +} + +static void trace_napi_poll_hit(void *ignore, struct napi_struct *napi, + int work, int budget) +{ + struct net_device *dev = napi->dev; + struct dm_hw_stat_delta *stat; + /* + * Don't check napi structures with no associated device + */ + if (!dev) + return; + + rcu_read_lock(); + stat = rcu_dereference(dev->dm_private); + if (stat) { + /* + * only add a note to our monitor buffer if: + * 1) its after the last_rx delta + * 2) our rx_dropped count has gone up + */ + if (time_after(jiffies, stat->last_rx + dm_hw_check_delta) && + (dev->stats.rx_dropped != stat->last_drop_val)) { + trace_drop_common(NULL, NULL); + stat->last_drop_val = dev->stats.rx_dropped; + stat->last_rx = jiffies; + } + } + rcu_read_unlock(); +} + +static struct net_dm_hw_entries * +net_dm_hw_reset_per_cpu_data(struct per_cpu_dm_data *hw_data) +{ + struct net_dm_hw_entries *hw_entries; + unsigned long flags; + + hw_entries = kzalloc(struct_size(hw_entries, entries, dm_hit_limit), + GFP_KERNEL); + if (!hw_entries) { + /* If the memory allocation failed, we try to perform another + * allocation in 1/10 second. Otherwise, the probe function + * will constantly bail out. 
+ */ + mod_timer(&hw_data->send_timer, jiffies + HZ / 10); + } + + spin_lock_irqsave(&hw_data->lock, flags); + swap(hw_data->hw_entries, hw_entries); + spin_unlock_irqrestore(&hw_data->lock, flags); + + return hw_entries; +} + +static int net_dm_hw_entry_put(struct sk_buff *msg, + const struct net_dm_hw_entry *hw_entry) +{ + struct nlattr *attr; + + attr = nla_nest_start(msg, NET_DM_ATTR_HW_ENTRY); + if (!attr) + return -EMSGSIZE; + + if (nla_put_string(msg, NET_DM_ATTR_HW_TRAP_NAME, hw_entry->trap_name)) + goto nla_put_failure; + + if (nla_put_u32(msg, NET_DM_ATTR_HW_TRAP_COUNT, hw_entry->count)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int net_dm_hw_entries_put(struct sk_buff *msg, + const struct net_dm_hw_entries *hw_entries) +{ + struct nlattr *attr; + int i; + + attr = nla_nest_start(msg, NET_DM_ATTR_HW_ENTRIES); + if (!attr) + return -EMSGSIZE; + + for (i = 0; i < hw_entries->num_entries; i++) { + int rc; + + rc = net_dm_hw_entry_put(msg, &hw_entries->entries[i]); + if (rc) + goto nla_put_failure; + } + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int +net_dm_hw_summary_report_fill(struct sk_buff *msg, + const struct net_dm_hw_entries *hw_entries) +{ + struct net_dm_alert_msg anc_hdr = { 0 }; + void *hdr; + int rc; + + hdr = genlmsg_put(msg, 0, 0, &net_drop_monitor_family, 0, + NET_DM_CMD_ALERT); + if (!hdr) + return -EMSGSIZE; + + /* We need to put the ancillary header in order not to break user + * space. + */ + if (nla_put(msg, NLA_UNSPEC, sizeof(anc_hdr), &anc_hdr)) + goto nla_put_failure; + + rc = net_dm_hw_entries_put(msg, hw_entries); + if (rc) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void net_dm_hw_summary_work(struct work_struct *work) +{ + struct net_dm_hw_entries *hw_entries; + struct per_cpu_dm_data *hw_data; + struct sk_buff *msg; + int rc; + + hw_data = container_of(work, struct per_cpu_dm_data, dm_alert_work); + + hw_entries = net_dm_hw_reset_per_cpu_data(hw_data); + if (!hw_entries) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + goto out; + + rc = net_dm_hw_summary_report_fill(msg, hw_entries); + if (rc) { + nlmsg_free(msg); + goto out; + } + + genlmsg_multicast(&net_drop_monitor_family, msg, 0, 0, GFP_KERNEL); + +out: + kfree(hw_entries); +} + +static void +net_dm_hw_trap_summary_probe(void *ignore, const struct devlink *devlink, + struct sk_buff *skb, + const struct devlink_trap_metadata *metadata) +{ + struct net_dm_hw_entries *hw_entries; + struct net_dm_hw_entry *hw_entry; + struct per_cpu_dm_data *hw_data; + unsigned long flags; + int i; + + if (metadata->trap_type == DEVLINK_TRAP_TYPE_CONTROL) + return; + + hw_data = this_cpu_ptr(&dm_hw_cpu_data); + spin_lock_irqsave(&hw_data->lock, flags); + hw_entries = hw_data->hw_entries; + + if (!hw_entries) + goto out; + + for (i = 0; i < hw_entries->num_entries; i++) { + hw_entry = &hw_entries->entries[i]; + if (!strncmp(hw_entry->trap_name, metadata->trap_name, + NET_DM_MAX_HW_TRAP_NAME_LEN - 1)) { + hw_entry->count++; + goto out; + } + } + if (WARN_ON_ONCE(hw_entries->num_entries == dm_hit_limit)) + goto out; + + hw_entry = &hw_entries->entries[hw_entries->num_entries]; + strscpy(hw_entry->trap_name, metadata->trap_name, + NET_DM_MAX_HW_TRAP_NAME_LEN - 1); + hw_entry->count = 1; + 
hw_entries->num_entries++; + + if (!timer_pending(&hw_data->send_timer)) { + hw_data->send_timer.expires = jiffies + dm_delay * HZ; + add_timer(&hw_data->send_timer); + } + +out: + spin_unlock_irqrestore(&hw_data->lock, flags); +} + +static const struct net_dm_alert_ops net_dm_alert_summary_ops = { + .kfree_skb_probe = trace_kfree_skb_hit, + .napi_poll_probe = trace_napi_poll_hit, + .work_item_func = send_dm_alert, + .hw_work_item_func = net_dm_hw_summary_work, + .hw_trap_probe = net_dm_hw_trap_summary_probe, +}; + +static void net_dm_packet_trace_kfree_skb_hit(void *ignore, + struct sk_buff *skb, + void *location, + enum skb_drop_reason reason) +{ + ktime_t tstamp = ktime_get_real(); + struct per_cpu_dm_data *data; + struct net_dm_skb_cb *cb; + struct sk_buff *nskb; + unsigned long flags; + + if (!skb_mac_header_was_set(skb)) + return; + + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + return; + + if (unlikely(reason >= SKB_DROP_REASON_MAX || reason <= 0)) + reason = SKB_DROP_REASON_NOT_SPECIFIED; + cb = NET_DM_SKB_CB(nskb); + cb->reason = reason; + cb->pc = location; + /* Override the timestamp because we care about the time when the + * packet was dropped. + */ + nskb->tstamp = tstamp; + + data = this_cpu_ptr(&dm_cpu_data); + + spin_lock_irqsave(&data->drop_queue.lock, flags); + if (skb_queue_len(&data->drop_queue) < net_dm_queue_len) + __skb_queue_tail(&data->drop_queue, nskb); + else + goto unlock_free; + spin_unlock_irqrestore(&data->drop_queue.lock, flags); + + schedule_work(&data->dm_alert_work); + + return; + +unlock_free: + spin_unlock_irqrestore(&data->drop_queue.lock, flags); + u64_stats_update_begin(&data->stats.syncp); + u64_stats_inc(&data->stats.dropped); + u64_stats_update_end(&data->stats.syncp); + consume_skb(nskb); +} + +static void net_dm_packet_trace_napi_poll_hit(void *ignore, + struct napi_struct *napi, + int work, int budget) +{ +} + +static size_t net_dm_in_port_size(void) +{ + /* NET_DM_ATTR_IN_PORT nest */ + return nla_total_size(0) + + /* NET_DM_ATTR_PORT_NETDEV_IFINDEX */ + nla_total_size(sizeof(u32)) + + /* NET_DM_ATTR_PORT_NETDEV_NAME */ + nla_total_size(IFNAMSIZ + 1); +} + +#define NET_DM_MAX_SYMBOL_LEN 40 + +static size_t net_dm_packet_report_size(size_t payload_len, + enum skb_drop_reason reason) +{ + size_t size; + + size = nlmsg_msg_size(GENL_HDRLEN + net_drop_monitor_family.hdrsize); + + return NLMSG_ALIGN(size) + + /* NET_DM_ATTR_ORIGIN */ + nla_total_size(sizeof(u16)) + + /* NET_DM_ATTR_PC */ + nla_total_size(sizeof(u64)) + + /* NET_DM_ATTR_SYMBOL */ + nla_total_size(NET_DM_MAX_SYMBOL_LEN + 1) + + /* NET_DM_ATTR_IN_PORT */ + net_dm_in_port_size() + + /* NET_DM_ATTR_TIMESTAMP */ + nla_total_size(sizeof(u64)) + + /* NET_DM_ATTR_ORIG_LEN */ + nla_total_size(sizeof(u32)) + + /* NET_DM_ATTR_PROTO */ + nla_total_size(sizeof(u16)) + + /* NET_DM_ATTR_REASON */ + nla_total_size(strlen(drop_reasons[reason]) + 1) + + /* NET_DM_ATTR_PAYLOAD */ + nla_total_size(payload_len); +} + +static int net_dm_packet_report_in_port_put(struct sk_buff *msg, int ifindex, + const char *name) +{ + struct nlattr *attr; + + attr = nla_nest_start(msg, NET_DM_ATTR_IN_PORT); + if (!attr) + return -EMSGSIZE; + + if (ifindex && + nla_put_u32(msg, NET_DM_ATTR_PORT_NETDEV_IFINDEX, ifindex)) + goto nla_put_failure; + + if (name && nla_put_string(msg, NET_DM_ATTR_PORT_NETDEV_NAME, name)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int net_dm_packet_report_fill(struct sk_buff *msg, 
struct sk_buff *skb, + size_t payload_len) +{ + struct net_dm_skb_cb *cb = NET_DM_SKB_CB(skb); + char buf[NET_DM_MAX_SYMBOL_LEN]; + struct nlattr *attr; + void *hdr; + int rc; + + hdr = genlmsg_put(msg, 0, 0, &net_drop_monitor_family, 0, + NET_DM_CMD_PACKET_ALERT); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u16(msg, NET_DM_ATTR_ORIGIN, NET_DM_ORIGIN_SW)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NET_DM_ATTR_PC, (u64)(uintptr_t)cb->pc, + NET_DM_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_string(msg, NET_DM_ATTR_REASON, + drop_reasons[cb->reason])) + goto nla_put_failure; + + snprintf(buf, sizeof(buf), "%pS", cb->pc); + if (nla_put_string(msg, NET_DM_ATTR_SYMBOL, buf)) + goto nla_put_failure; + + rc = net_dm_packet_report_in_port_put(msg, skb->skb_iif, NULL); + if (rc) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NET_DM_ATTR_TIMESTAMP, + ktime_to_ns(skb->tstamp), NET_DM_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u32(msg, NET_DM_ATTR_ORIG_LEN, skb->len)) + goto nla_put_failure; + + if (!payload_len) + goto out; + + if (nla_put_u16(msg, NET_DM_ATTR_PROTO, be16_to_cpu(skb->protocol))) + goto nla_put_failure; + + attr = skb_put(msg, nla_total_size(payload_len)); + attr->nla_type = NET_DM_ATTR_PAYLOAD; + attr->nla_len = nla_attr_size(payload_len); + if (skb_copy_bits(skb, 0, nla_data(attr), payload_len)) + goto nla_put_failure; + +out: + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +#define NET_DM_MAX_PACKET_SIZE (0xffff - NLA_HDRLEN - NLA_ALIGNTO) + +static void net_dm_packet_report(struct sk_buff *skb) +{ + struct sk_buff *msg; + size_t payload_len; + int rc; + + /* Make sure we start copying the packet from the MAC header */ + if (skb->data > skb_mac_header(skb)) + skb_push(skb, skb->data - skb_mac_header(skb)); + else + skb_pull(skb, skb_mac_header(skb) - skb->data); + + /* Ensure packet fits inside a single netlink attribute */ + payload_len = min_t(size_t, skb->len, NET_DM_MAX_PACKET_SIZE); + if (net_dm_trunc_len) + payload_len = min_t(size_t, net_dm_trunc_len, payload_len); + + msg = nlmsg_new(net_dm_packet_report_size(payload_len, + NET_DM_SKB_CB(skb)->reason), + GFP_KERNEL); + if (!msg) + goto out; + + rc = net_dm_packet_report_fill(msg, skb, payload_len); + if (rc) { + nlmsg_free(msg); + goto out; + } + + genlmsg_multicast(&net_drop_monitor_family, msg, 0, 0, GFP_KERNEL); + +out: + consume_skb(skb); +} + +static void net_dm_packet_work(struct work_struct *work) +{ + struct per_cpu_dm_data *data; + struct sk_buff_head list; + struct sk_buff *skb; + unsigned long flags; + + data = container_of(work, struct per_cpu_dm_data, dm_alert_work); + + __skb_queue_head_init(&list); + + spin_lock_irqsave(&data->drop_queue.lock, flags); + skb_queue_splice_tail_init(&data->drop_queue, &list); + spin_unlock_irqrestore(&data->drop_queue.lock, flags); + + while ((skb = __skb_dequeue(&list))) + net_dm_packet_report(skb); +} + +static size_t +net_dm_flow_action_cookie_size(const struct devlink_trap_metadata *hw_metadata) +{ + return hw_metadata->fa_cookie ? 
+ nla_total_size(hw_metadata->fa_cookie->cookie_len) : 0; +} + +static size_t +net_dm_hw_packet_report_size(size_t payload_len, + const struct devlink_trap_metadata *hw_metadata) +{ + size_t size; + + size = nlmsg_msg_size(GENL_HDRLEN + net_drop_monitor_family.hdrsize); + + return NLMSG_ALIGN(size) + + /* NET_DM_ATTR_ORIGIN */ + nla_total_size(sizeof(u16)) + + /* NET_DM_ATTR_HW_TRAP_GROUP_NAME */ + nla_total_size(strlen(hw_metadata->trap_group_name) + 1) + + /* NET_DM_ATTR_HW_TRAP_NAME */ + nla_total_size(strlen(hw_metadata->trap_name) + 1) + + /* NET_DM_ATTR_IN_PORT */ + net_dm_in_port_size() + + /* NET_DM_ATTR_FLOW_ACTION_COOKIE */ + net_dm_flow_action_cookie_size(hw_metadata) + + /* NET_DM_ATTR_TIMESTAMP */ + nla_total_size(sizeof(u64)) + + /* NET_DM_ATTR_ORIG_LEN */ + nla_total_size(sizeof(u32)) + + /* NET_DM_ATTR_PROTO */ + nla_total_size(sizeof(u16)) + + /* NET_DM_ATTR_PAYLOAD */ + nla_total_size(payload_len); +} + +static int net_dm_hw_packet_report_fill(struct sk_buff *msg, + struct sk_buff *skb, size_t payload_len) +{ + struct devlink_trap_metadata *hw_metadata; + struct nlattr *attr; + void *hdr; + + hw_metadata = NET_DM_SKB_CB(skb)->hw_metadata; + + hdr = genlmsg_put(msg, 0, 0, &net_drop_monitor_family, 0, + NET_DM_CMD_PACKET_ALERT); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u16(msg, NET_DM_ATTR_ORIGIN, NET_DM_ORIGIN_HW)) + goto nla_put_failure; + + if (nla_put_string(msg, NET_DM_ATTR_HW_TRAP_GROUP_NAME, + hw_metadata->trap_group_name)) + goto nla_put_failure; + + if (nla_put_string(msg, NET_DM_ATTR_HW_TRAP_NAME, + hw_metadata->trap_name)) + goto nla_put_failure; + + if (hw_metadata->input_dev) { + struct net_device *dev = hw_metadata->input_dev; + int rc; + + rc = net_dm_packet_report_in_port_put(msg, dev->ifindex, + dev->name); + if (rc) + goto nla_put_failure; + } + + if (hw_metadata->fa_cookie && + nla_put(msg, NET_DM_ATTR_FLOW_ACTION_COOKIE, + hw_metadata->fa_cookie->cookie_len, + hw_metadata->fa_cookie->cookie)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NET_DM_ATTR_TIMESTAMP, + ktime_to_ns(skb->tstamp), NET_DM_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u32(msg, NET_DM_ATTR_ORIG_LEN, skb->len)) + goto nla_put_failure; + + if (!payload_len) + goto out; + + if (nla_put_u16(msg, NET_DM_ATTR_PROTO, be16_to_cpu(skb->protocol))) + goto nla_put_failure; + + attr = skb_put(msg, nla_total_size(payload_len)); + attr->nla_type = NET_DM_ATTR_PAYLOAD; + attr->nla_len = nla_attr_size(payload_len); + if (skb_copy_bits(skb, 0, nla_data(attr), payload_len)) + goto nla_put_failure; + +out: + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static struct devlink_trap_metadata * +net_dm_hw_metadata_copy(const struct devlink_trap_metadata *metadata) +{ + const struct flow_action_cookie *fa_cookie; + struct devlink_trap_metadata *hw_metadata; + const char *trap_group_name; + const char *trap_name; + + hw_metadata = kzalloc(sizeof(*hw_metadata), GFP_ATOMIC); + if (!hw_metadata) + return NULL; + + trap_group_name = kstrdup(metadata->trap_group_name, GFP_ATOMIC); + if (!trap_group_name) + goto free_hw_metadata; + hw_metadata->trap_group_name = trap_group_name; + + trap_name = kstrdup(metadata->trap_name, GFP_ATOMIC); + if (!trap_name) + goto free_trap_group; + hw_metadata->trap_name = trap_name; + + if (metadata->fa_cookie) { + size_t cookie_size = sizeof(*fa_cookie) + + metadata->fa_cookie->cookie_len; + + fa_cookie = kmemdup(metadata->fa_cookie, cookie_size, + GFP_ATOMIC); + if (!fa_cookie) + goto 
free_trap_name; + hw_metadata->fa_cookie = fa_cookie; + } + + hw_metadata->input_dev = metadata->input_dev; + netdev_hold(hw_metadata->input_dev, &hw_metadata->dev_tracker, + GFP_ATOMIC); + + return hw_metadata; + +free_trap_name: + kfree(trap_name); +free_trap_group: + kfree(trap_group_name); +free_hw_metadata: + kfree(hw_metadata); + return NULL; +} + +static void +net_dm_hw_metadata_free(struct devlink_trap_metadata *hw_metadata) +{ + netdev_put(hw_metadata->input_dev, &hw_metadata->dev_tracker); + kfree(hw_metadata->fa_cookie); + kfree(hw_metadata->trap_name); + kfree(hw_metadata->trap_group_name); + kfree(hw_metadata); +} + +static void net_dm_hw_packet_report(struct sk_buff *skb) +{ + struct devlink_trap_metadata *hw_metadata; + struct sk_buff *msg; + size_t payload_len; + int rc; + + if (skb->data > skb_mac_header(skb)) + skb_push(skb, skb->data - skb_mac_header(skb)); + else + skb_pull(skb, skb_mac_header(skb) - skb->data); + + payload_len = min_t(size_t, skb->len, NET_DM_MAX_PACKET_SIZE); + if (net_dm_trunc_len) + payload_len = min_t(size_t, net_dm_trunc_len, payload_len); + + hw_metadata = NET_DM_SKB_CB(skb)->hw_metadata; + msg = nlmsg_new(net_dm_hw_packet_report_size(payload_len, hw_metadata), + GFP_KERNEL); + if (!msg) + goto out; + + rc = net_dm_hw_packet_report_fill(msg, skb, payload_len); + if (rc) { + nlmsg_free(msg); + goto out; + } + + genlmsg_multicast(&net_drop_monitor_family, msg, 0, 0, GFP_KERNEL); + +out: + net_dm_hw_metadata_free(NET_DM_SKB_CB(skb)->hw_metadata); + consume_skb(skb); +} + +static void net_dm_hw_packet_work(struct work_struct *work) +{ + struct per_cpu_dm_data *hw_data; + struct sk_buff_head list; + struct sk_buff *skb; + unsigned long flags; + + hw_data = container_of(work, struct per_cpu_dm_data, dm_alert_work); + + __skb_queue_head_init(&list); + + spin_lock_irqsave(&hw_data->drop_queue.lock, flags); + skb_queue_splice_tail_init(&hw_data->drop_queue, &list); + spin_unlock_irqrestore(&hw_data->drop_queue.lock, flags); + + while ((skb = __skb_dequeue(&list))) + net_dm_hw_packet_report(skb); +} + +static void +net_dm_hw_trap_packet_probe(void *ignore, const struct devlink *devlink, + struct sk_buff *skb, + const struct devlink_trap_metadata *metadata) +{ + struct devlink_trap_metadata *n_hw_metadata; + ktime_t tstamp = ktime_get_real(); + struct per_cpu_dm_data *hw_data; + struct sk_buff *nskb; + unsigned long flags; + + if (metadata->trap_type == DEVLINK_TRAP_TYPE_CONTROL) + return; + + if (!skb_mac_header_was_set(skb)) + return; + + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + return; + + n_hw_metadata = net_dm_hw_metadata_copy(metadata); + if (!n_hw_metadata) + goto free; + + NET_DM_SKB_CB(nskb)->hw_metadata = n_hw_metadata; + nskb->tstamp = tstamp; + + hw_data = this_cpu_ptr(&dm_hw_cpu_data); + + spin_lock_irqsave(&hw_data->drop_queue.lock, flags); + if (skb_queue_len(&hw_data->drop_queue) < net_dm_queue_len) + __skb_queue_tail(&hw_data->drop_queue, nskb); + else + goto unlock_free; + spin_unlock_irqrestore(&hw_data->drop_queue.lock, flags); + + schedule_work(&hw_data->dm_alert_work); + + return; + +unlock_free: + spin_unlock_irqrestore(&hw_data->drop_queue.lock, flags); + u64_stats_update_begin(&hw_data->stats.syncp); + u64_stats_inc(&hw_data->stats.dropped); + u64_stats_update_end(&hw_data->stats.syncp); + net_dm_hw_metadata_free(n_hw_metadata); +free: + consume_skb(nskb); +} + +static const struct net_dm_alert_ops net_dm_alert_packet_ops = { + .kfree_skb_probe = net_dm_packet_trace_kfree_skb_hit, + .napi_poll_probe = 
net_dm_packet_trace_napi_poll_hit, + .work_item_func = net_dm_packet_work, + .hw_work_item_func = net_dm_hw_packet_work, + .hw_trap_probe = net_dm_hw_trap_packet_probe, +}; + +static const struct net_dm_alert_ops *net_dm_alert_ops_arr[] = { + [NET_DM_ALERT_MODE_SUMMARY] = &net_dm_alert_summary_ops, + [NET_DM_ALERT_MODE_PACKET] = &net_dm_alert_packet_ops, +}; + +#if IS_ENABLED(CONFIG_NET_DEVLINK) +static int net_dm_hw_probe_register(const struct net_dm_alert_ops *ops) +{ + return register_trace_devlink_trap_report(ops->hw_trap_probe, NULL); +} + +static void net_dm_hw_probe_unregister(const struct net_dm_alert_ops *ops) +{ + unregister_trace_devlink_trap_report(ops->hw_trap_probe, NULL); + tracepoint_synchronize_unregister(); +} +#else +static int net_dm_hw_probe_register(const struct net_dm_alert_ops *ops) +{ + return -EOPNOTSUPP; +} + +static void net_dm_hw_probe_unregister(const struct net_dm_alert_ops *ops) +{ +} +#endif + +static int net_dm_hw_monitor_start(struct netlink_ext_ack *extack) +{ + const struct net_dm_alert_ops *ops; + int cpu, rc; + + if (monitor_hw) { + NL_SET_ERR_MSG_MOD(extack, "Hardware monitoring already enabled"); + return -EAGAIN; + } + + ops = net_dm_alert_ops_arr[net_dm_alert_mode]; + + if (!try_module_get(THIS_MODULE)) { + NL_SET_ERR_MSG_MOD(extack, "Failed to take reference on module"); + return -ENODEV; + } + + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *hw_data = &per_cpu(dm_hw_cpu_data, cpu); + struct net_dm_hw_entries *hw_entries; + + INIT_WORK(&hw_data->dm_alert_work, ops->hw_work_item_func); + timer_setup(&hw_data->send_timer, sched_send_work, 0); + hw_entries = net_dm_hw_reset_per_cpu_data(hw_data); + kfree(hw_entries); + } + + rc = net_dm_hw_probe_register(ops); + if (rc) { + NL_SET_ERR_MSG_MOD(extack, "Failed to connect probe to devlink_trap_probe() tracepoint"); + goto err_module_put; + } + + monitor_hw = true; + + return 0; + +err_module_put: + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *hw_data = &per_cpu(dm_hw_cpu_data, cpu); + struct sk_buff *skb; + + del_timer_sync(&hw_data->send_timer); + cancel_work_sync(&hw_data->dm_alert_work); + while ((skb = __skb_dequeue(&hw_data->drop_queue))) { + struct devlink_trap_metadata *hw_metadata; + + hw_metadata = NET_DM_SKB_CB(skb)->hw_metadata; + net_dm_hw_metadata_free(hw_metadata); + consume_skb(skb); + } + } + module_put(THIS_MODULE); + return rc; +} + +static void net_dm_hw_monitor_stop(struct netlink_ext_ack *extack) +{ + const struct net_dm_alert_ops *ops; + int cpu; + + if (!monitor_hw) { + NL_SET_ERR_MSG_MOD(extack, "Hardware monitoring already disabled"); + return; + } + + ops = net_dm_alert_ops_arr[net_dm_alert_mode]; + + monitor_hw = false; + + net_dm_hw_probe_unregister(ops); + + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *hw_data = &per_cpu(dm_hw_cpu_data, cpu); + struct sk_buff *skb; + + del_timer_sync(&hw_data->send_timer); + cancel_work_sync(&hw_data->dm_alert_work); + while ((skb = __skb_dequeue(&hw_data->drop_queue))) { + struct devlink_trap_metadata *hw_metadata; + + hw_metadata = NET_DM_SKB_CB(skb)->hw_metadata; + net_dm_hw_metadata_free(hw_metadata); + consume_skb(skb); + } + } + + module_put(THIS_MODULE); +} + +static int net_dm_trace_on_set(struct netlink_ext_ack *extack) +{ + const struct net_dm_alert_ops *ops; + int cpu, rc; + + ops = net_dm_alert_ops_arr[net_dm_alert_mode]; + + if (!try_module_get(THIS_MODULE)) { + NL_SET_ERR_MSG_MOD(extack, "Failed to take reference on module"); + return -ENODEV; + } + + for_each_possible_cpu(cpu) { + struct 
per_cpu_dm_data *data = &per_cpu(dm_cpu_data, cpu); + struct sk_buff *skb; + + INIT_WORK(&data->dm_alert_work, ops->work_item_func); + timer_setup(&data->send_timer, sched_send_work, 0); + /* Allocate a new per-CPU skb for the summary alert message and + * free the old one which might contain stale data from + * previous tracing. + */ + skb = reset_per_cpu_data(data); + consume_skb(skb); + } + + rc = register_trace_kfree_skb(ops->kfree_skb_probe, NULL); + if (rc) { + NL_SET_ERR_MSG_MOD(extack, "Failed to connect probe to kfree_skb() tracepoint"); + goto err_module_put; + } + + rc = register_trace_napi_poll(ops->napi_poll_probe, NULL); + if (rc) { + NL_SET_ERR_MSG_MOD(extack, "Failed to connect probe to napi_poll() tracepoint"); + goto err_unregister_trace; + } + + return 0; + +err_unregister_trace: + unregister_trace_kfree_skb(ops->kfree_skb_probe, NULL); +err_module_put: + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *data = &per_cpu(dm_cpu_data, cpu); + struct sk_buff *skb; + + del_timer_sync(&data->send_timer); + cancel_work_sync(&data->dm_alert_work); + while ((skb = __skb_dequeue(&data->drop_queue))) + consume_skb(skb); + } + module_put(THIS_MODULE); + return rc; +} + +static void net_dm_trace_off_set(void) +{ + const struct net_dm_alert_ops *ops; + int cpu; + + ops = net_dm_alert_ops_arr[net_dm_alert_mode]; + + unregister_trace_napi_poll(ops->napi_poll_probe, NULL); + unregister_trace_kfree_skb(ops->kfree_skb_probe, NULL); + + tracepoint_synchronize_unregister(); + + /* Make sure we do not send notifications to user space after request + * to stop tracing returns. + */ + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *data = &per_cpu(dm_cpu_data, cpu); + struct sk_buff *skb; + + del_timer_sync(&data->send_timer); + cancel_work_sync(&data->dm_alert_work); + while ((skb = __skb_dequeue(&data->drop_queue))) + consume_skb(skb); + } + + module_put(THIS_MODULE); +} + +static int set_all_monitor_traces(int state, struct netlink_ext_ack *extack) +{ + int rc = 0; + + if (state == trace_state) { + NL_SET_ERR_MSG_MOD(extack, "Trace state already set to requested state"); + return -EAGAIN; + } + + switch (state) { + case TRACE_ON: + rc = net_dm_trace_on_set(extack); + break; + case TRACE_OFF: + net_dm_trace_off_set(); + break; + default: + rc = 1; + break; + } + + if (!rc) + trace_state = state; + else + rc = -EINPROGRESS; + + return rc; +} + +static bool net_dm_is_monitoring(void) +{ + return trace_state == TRACE_ON || monitor_hw; +} + +static int net_dm_alert_mode_get_from_info(struct genl_info *info, + enum net_dm_alert_mode *p_alert_mode) +{ + u8 val; + + val = nla_get_u8(info->attrs[NET_DM_ATTR_ALERT_MODE]); + + switch (val) { + case NET_DM_ALERT_MODE_SUMMARY: + case NET_DM_ALERT_MODE_PACKET: + *p_alert_mode = val; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int net_dm_alert_mode_set(struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + enum net_dm_alert_mode alert_mode; + int rc; + + if (!info->attrs[NET_DM_ATTR_ALERT_MODE]) + return 0; + + rc = net_dm_alert_mode_get_from_info(info, &alert_mode); + if (rc) { + NL_SET_ERR_MSG_MOD(extack, "Invalid alert mode"); + return -EINVAL; + } + + net_dm_alert_mode = alert_mode; + + return 0; +} + +static void net_dm_trunc_len_set(struct genl_info *info) +{ + if (!info->attrs[NET_DM_ATTR_TRUNC_LEN]) + return; + + net_dm_trunc_len = nla_get_u32(info->attrs[NET_DM_ATTR_TRUNC_LEN]); +} + +static void net_dm_queue_len_set(struct genl_info *info) +{ + if (!info->attrs[NET_DM_ATTR_QUEUE_LEN]) + 
return; + + net_dm_queue_len = nla_get_u32(info->attrs[NET_DM_ATTR_QUEUE_LEN]); +} + +static int net_dm_cmd_config(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + int rc; + + if (net_dm_is_monitoring()) { + NL_SET_ERR_MSG_MOD(extack, "Cannot configure drop monitor during monitoring"); + return -EBUSY; + } + + rc = net_dm_alert_mode_set(info); + if (rc) + return rc; + + net_dm_trunc_len_set(info); + + net_dm_queue_len_set(info); + + return 0; +} + +static int net_dm_monitor_start(bool set_sw, bool set_hw, + struct netlink_ext_ack *extack) +{ + bool sw_set = false; + int rc; + + if (set_sw) { + rc = set_all_monitor_traces(TRACE_ON, extack); + if (rc) + return rc; + sw_set = true; + } + + if (set_hw) { + rc = net_dm_hw_monitor_start(extack); + if (rc) + goto err_monitor_hw; + } + + return 0; + +err_monitor_hw: + if (sw_set) + set_all_monitor_traces(TRACE_OFF, extack); + return rc; +} + +static void net_dm_monitor_stop(bool set_sw, bool set_hw, + struct netlink_ext_ack *extack) +{ + if (set_hw) + net_dm_hw_monitor_stop(extack); + if (set_sw) + set_all_monitor_traces(TRACE_OFF, extack); +} + +static int net_dm_cmd_trace(struct sk_buff *skb, + struct genl_info *info) +{ + bool set_sw = !!info->attrs[NET_DM_ATTR_SW_DROPS]; + bool set_hw = !!info->attrs[NET_DM_ATTR_HW_DROPS]; + struct netlink_ext_ack *extack = info->extack; + + /* To maintain backward compatibility, we start / stop monitoring of + * software drops if no flag is specified. + */ + if (!set_sw && !set_hw) + set_sw = true; + + switch (info->genlhdr->cmd) { + case NET_DM_CMD_START: + return net_dm_monitor_start(set_sw, set_hw, extack); + case NET_DM_CMD_STOP: + net_dm_monitor_stop(set_sw, set_hw, extack); + return 0; + } + + return -EOPNOTSUPP; +} + +static int net_dm_config_fill(struct sk_buff *msg, struct genl_info *info) +{ + void *hdr; + + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &net_drop_monitor_family, 0, NET_DM_CMD_CONFIG_NEW); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u8(msg, NET_DM_ATTR_ALERT_MODE, net_dm_alert_mode)) + goto nla_put_failure; + + if (nla_put_u32(msg, NET_DM_ATTR_TRUNC_LEN, net_dm_trunc_len)) + goto nla_put_failure; + + if (nla_put_u32(msg, NET_DM_ATTR_QUEUE_LEN, net_dm_queue_len)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int net_dm_cmd_config_get(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + int rc; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + rc = net_dm_config_fill(msg, info); + if (rc) + goto free_msg; + + return genlmsg_reply(msg, info); + +free_msg: + nlmsg_free(msg); + return rc; +} + +static void net_dm_stats_read(struct net_dm_stats *stats) +{ + int cpu; + + memset(stats, 0, sizeof(*stats)); + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *data = &per_cpu(dm_cpu_data, cpu); + struct net_dm_stats *cpu_stats = &data->stats; + unsigned int start; + u64 dropped; + + do { + start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + dropped = u64_stats_read(&cpu_stats->dropped); + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + + u64_stats_add(&stats->dropped, dropped); + } +} + +static int net_dm_stats_put(struct sk_buff *msg) +{ + struct net_dm_stats stats; + struct nlattr *attr; + + net_dm_stats_read(&stats); + + attr = nla_nest_start(msg, NET_DM_ATTR_STATS); + if (!attr) + return -EMSGSIZE; + + if (nla_put_u64_64bit(msg, 
NET_DM_ATTR_STATS_DROPPED, + u64_stats_read(&stats.dropped), NET_DM_ATTR_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static void net_dm_hw_stats_read(struct net_dm_stats *stats) +{ + int cpu; + + memset(stats, 0, sizeof(*stats)); + for_each_possible_cpu(cpu) { + struct per_cpu_dm_data *hw_data = &per_cpu(dm_hw_cpu_data, cpu); + struct net_dm_stats *cpu_stats = &hw_data->stats; + unsigned int start; + u64 dropped; + + do { + start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + dropped = u64_stats_read(&cpu_stats->dropped); + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + + u64_stats_add(&stats->dropped, dropped); + } +} + +static int net_dm_hw_stats_put(struct sk_buff *msg) +{ + struct net_dm_stats stats; + struct nlattr *attr; + + net_dm_hw_stats_read(&stats); + + attr = nla_nest_start(msg, NET_DM_ATTR_HW_STATS); + if (!attr) + return -EMSGSIZE; + + if (nla_put_u64_64bit(msg, NET_DM_ATTR_STATS_DROPPED, + u64_stats_read(&stats.dropped), NET_DM_ATTR_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int net_dm_stats_fill(struct sk_buff *msg, struct genl_info *info) +{ + void *hdr; + int rc; + + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &net_drop_monitor_family, 0, NET_DM_CMD_STATS_NEW); + if (!hdr) + return -EMSGSIZE; + + rc = net_dm_stats_put(msg); + if (rc) + goto nla_put_failure; + + rc = net_dm_hw_stats_put(msg); + if (rc) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int net_dm_cmd_stats_get(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + int rc; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + rc = net_dm_stats_fill(msg, info); + if (rc) + goto free_msg; + + return genlmsg_reply(msg, info); + +free_msg: + nlmsg_free(msg); + return rc; +} + +static int dropmon_net_event(struct notifier_block *ev_block, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct dm_hw_stat_delta *stat; + + switch (event) { + case NETDEV_REGISTER: + if (WARN_ON_ONCE(rtnl_dereference(dev->dm_private))) + break; + stat = kzalloc(sizeof(*stat), GFP_KERNEL); + if (!stat) + break; + + stat->last_rx = jiffies; + rcu_assign_pointer(dev->dm_private, stat); + + break; + case NETDEV_UNREGISTER: + stat = rtnl_dereference(dev->dm_private); + if (stat) { + rcu_assign_pointer(dev->dm_private, NULL); + kfree_rcu(stat, rcu); + } + break; + } + return NOTIFY_DONE; +} + +static const struct nla_policy net_dm_nl_policy[NET_DM_ATTR_MAX + 1] = { + [NET_DM_ATTR_UNSPEC] = { .strict_start_type = NET_DM_ATTR_UNSPEC + 1 }, + [NET_DM_ATTR_ALERT_MODE] = { .type = NLA_U8 }, + [NET_DM_ATTR_TRUNC_LEN] = { .type = NLA_U32 }, + [NET_DM_ATTR_QUEUE_LEN] = { .type = NLA_U32 }, + [NET_DM_ATTR_SW_DROPS] = {. type = NLA_FLAG }, + [NET_DM_ATTR_HW_DROPS] = {. 
type = NLA_FLAG }, +}; + +static const struct genl_small_ops dropmon_ops[] = { + { + .cmd = NET_DM_CMD_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = net_dm_cmd_config, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NET_DM_CMD_START, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = net_dm_cmd_trace, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NET_DM_CMD_STOP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = net_dm_cmd_trace, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NET_DM_CMD_CONFIG_GET, + .doit = net_dm_cmd_config_get, + }, + { + .cmd = NET_DM_CMD_STATS_GET, + .doit = net_dm_cmd_stats_get, + }, +}; + +static int net_dm_nl_pre_doit(const struct genl_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + mutex_lock(&net_dm_mutex); + + return 0; +} + +static void net_dm_nl_post_doit(const struct genl_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + mutex_unlock(&net_dm_mutex); +} + +static struct genl_family net_drop_monitor_family __ro_after_init = { + .hdrsize = 0, + .name = "NET_DM", + .version = 2, + .maxattr = NET_DM_ATTR_MAX, + .policy = net_dm_nl_policy, + .pre_doit = net_dm_nl_pre_doit, + .post_doit = net_dm_nl_post_doit, + .module = THIS_MODULE, + .small_ops = dropmon_ops, + .n_small_ops = ARRAY_SIZE(dropmon_ops), + .resv_start_op = NET_DM_CMD_STATS_GET + 1, + .mcgrps = dropmon_mcgrps, + .n_mcgrps = ARRAY_SIZE(dropmon_mcgrps), +}; + +static struct notifier_block dropmon_net_notifier = { + .notifier_call = dropmon_net_event +}; + +static void __net_dm_cpu_data_init(struct per_cpu_dm_data *data) +{ + spin_lock_init(&data->lock); + skb_queue_head_init(&data->drop_queue); + u64_stats_init(&data->stats.syncp); +} + +static void __net_dm_cpu_data_fini(struct per_cpu_dm_data *data) +{ + WARN_ON(!skb_queue_empty(&data->drop_queue)); +} + +static void net_dm_cpu_data_init(int cpu) +{ + struct per_cpu_dm_data *data; + + data = &per_cpu(dm_cpu_data, cpu); + __net_dm_cpu_data_init(data); +} + +static void net_dm_cpu_data_fini(int cpu) +{ + struct per_cpu_dm_data *data; + + data = &per_cpu(dm_cpu_data, cpu); + /* At this point, we should have exclusive access + * to this struct and can free the skb inside it. 
+ */ + consume_skb(data->skb); + __net_dm_cpu_data_fini(data); +} + +static void net_dm_hw_cpu_data_init(int cpu) +{ + struct per_cpu_dm_data *hw_data; + + hw_data = &per_cpu(dm_hw_cpu_data, cpu); + __net_dm_cpu_data_init(hw_data); +} + +static void net_dm_hw_cpu_data_fini(int cpu) +{ + struct per_cpu_dm_data *hw_data; + + hw_data = &per_cpu(dm_hw_cpu_data, cpu); + kfree(hw_data->hw_entries); + __net_dm_cpu_data_fini(hw_data); +} + +static int __init init_net_drop_monitor(void) +{ + int cpu, rc; + + pr_info("Initializing network drop monitor service\n"); + + if (sizeof(void *) > 8) { + pr_err("Unable to store program counters on this arch, Drop monitor failed\n"); + return -ENOSPC; + } + + rc = genl_register_family(&net_drop_monitor_family); + if (rc) { + pr_err("Could not create drop monitor netlink family\n"); + return rc; + } + WARN_ON(net_drop_monitor_family.mcgrp_offset != NET_DM_GRP_ALERT); + + rc = register_netdevice_notifier(&dropmon_net_notifier); + if (rc < 0) { + pr_crit("Failed to register netdevice notifier\n"); + goto out_unreg; + } + + rc = 0; + + for_each_possible_cpu(cpu) { + net_dm_cpu_data_init(cpu); + net_dm_hw_cpu_data_init(cpu); + } + + goto out; + +out_unreg: + genl_unregister_family(&net_drop_monitor_family); +out: + return rc; +} + +static void exit_net_drop_monitor(void) +{ + int cpu; + + BUG_ON(unregister_netdevice_notifier(&dropmon_net_notifier)); + + /* + * Because of the module_get/put we do in the trace state change path + * we are guaranteed not to have any current users when we get here + */ + + for_each_possible_cpu(cpu) { + net_dm_hw_cpu_data_fini(cpu); + net_dm_cpu_data_fini(cpu); + } + + BUG_ON(genl_unregister_family(&net_drop_monitor_family)); +} + +module_init(init_net_drop_monitor); +module_exit(exit_net_drop_monitor); + +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Neil Horman "); +MODULE_ALIAS_GENL_FAMILY("NET_DM"); +MODULE_DESCRIPTION("Monitoring code for network dropped packet alerts"); diff --git a/net/core/dst.c b/net/core/dst.c new file mode 100644 index 000000000..d178c5641 --- /dev/null +++ b/net/core/dst.c @@ -0,0 +1,351 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/core/dst.c Protocol independent destination cache. + * + * Authors: Alexey Kuznetsov, + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +int dst_discard_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(dst_discard_out); + +const struct dst_metrics dst_default_metrics = { + /* This initializer is needed to force linker to place this variable + * into const section. Otherwise it might end into bss section. + * We really want to avoid false sharing on this variable, and catch + * any writes on it. 
+ */ + .refcnt = REFCOUNT_INIT(1), +}; +EXPORT_SYMBOL(dst_default_metrics); + +void dst_init(struct dst_entry *dst, struct dst_ops *ops, + struct net_device *dev, int initial_ref, int initial_obsolete, + unsigned short flags) +{ + dst->dev = dev; + netdev_hold(dev, &dst->dev_tracker, GFP_ATOMIC); + dst->ops = ops; + dst_init_metrics(dst, dst_default_metrics.metrics, true); + dst->expires = 0UL; +#ifdef CONFIG_XFRM + dst->xfrm = NULL; +#endif + dst->input = dst_discard; + dst->output = dst_discard_out; + dst->error = 0; + dst->obsolete = initial_obsolete; + dst->header_len = 0; + dst->trailer_len = 0; +#ifdef CONFIG_IP_ROUTE_CLASSID + dst->tclassid = 0; +#endif + dst->lwtstate = NULL; + atomic_set(&dst->__refcnt, initial_ref); + dst->__use = 0; + dst->lastuse = jiffies; + dst->flags = flags; + if (!(flags & DST_NOCOUNT)) + dst_entries_add(ops, 1); +} +EXPORT_SYMBOL(dst_init); + +void *dst_alloc(struct dst_ops *ops, struct net_device *dev, + int initial_ref, int initial_obsolete, unsigned short flags) +{ + struct dst_entry *dst; + + if (ops->gc && + !(flags & DST_NOCOUNT) && + dst_entries_get_fast(ops) > ops->gc_thresh) + ops->gc(ops); + + dst = kmem_cache_alloc(ops->kmem_cachep, GFP_ATOMIC); + if (!dst) + return NULL; + + dst_init(dst, ops, dev, initial_ref, initial_obsolete, flags); + + return dst; +} +EXPORT_SYMBOL(dst_alloc); + +struct dst_entry *dst_destroy(struct dst_entry * dst) +{ + struct dst_entry *child = NULL; + + smp_rmb(); + +#ifdef CONFIG_XFRM + if (dst->xfrm) { + struct xfrm_dst *xdst = (struct xfrm_dst *) dst; + + child = xdst->child; + } +#endif + if (!(dst->flags & DST_NOCOUNT)) + dst_entries_add(dst->ops, -1); + + if (dst->ops->destroy) + dst->ops->destroy(dst); + netdev_put(dst->dev, &dst->dev_tracker); + + lwtstate_put(dst->lwtstate); + + if (dst->flags & DST_METADATA) + metadata_dst_free((struct metadata_dst *)dst); + else + kmem_cache_free(dst->ops->kmem_cachep, dst); + + dst = child; + if (dst) + dst_release_immediate(dst); + return NULL; +} +EXPORT_SYMBOL(dst_destroy); + +static void dst_destroy_rcu(struct rcu_head *head) +{ + struct dst_entry *dst = container_of(head, struct dst_entry, rcu_head); + + dst = dst_destroy(dst); +} + +/* Operations to mark dst as DEAD and clean up the net device referenced + * by dst: + * 1. put the dst under blackhole interface and discard all tx/rx packets + * on this route. + * 2. release the net_device + * This function should be called when removing routes from the fib tree + * in preparation for a NETDEV_DOWN/NETDEV_UNREGISTER event and also to + * make the next dst_ops->check() fail. 
+ */ +void dst_dev_put(struct dst_entry *dst) +{ + struct net_device *dev = dst->dev; + + dst->obsolete = DST_OBSOLETE_DEAD; + if (dst->ops->ifdown) + dst->ops->ifdown(dst, dev, true); + dst->input = dst_discard; + dst->output = dst_discard_out; + dst->dev = blackhole_netdev; + netdev_ref_replace(dev, blackhole_netdev, &dst->dev_tracker, + GFP_ATOMIC); +} +EXPORT_SYMBOL(dst_dev_put); + +void dst_release(struct dst_entry *dst) +{ + if (dst) { + int newrefcnt; + + newrefcnt = atomic_dec_return(&dst->__refcnt); + if (WARN_ONCE(newrefcnt < 0, "dst_release underflow")) + net_warn_ratelimited("%s: dst:%p refcnt:%d\n", + __func__, dst, newrefcnt); + if (!newrefcnt) + call_rcu(&dst->rcu_head, dst_destroy_rcu); + } +} +EXPORT_SYMBOL(dst_release); + +void dst_release_immediate(struct dst_entry *dst) +{ + if (dst) { + int newrefcnt; + + newrefcnt = atomic_dec_return(&dst->__refcnt); + if (WARN_ONCE(newrefcnt < 0, "dst_release_immediate underflow")) + net_warn_ratelimited("%s: dst:%p refcnt:%d\n", + __func__, dst, newrefcnt); + if (!newrefcnt) + dst_destroy(dst); + } +} +EXPORT_SYMBOL(dst_release_immediate); + +u32 *dst_cow_metrics_generic(struct dst_entry *dst, unsigned long old) +{ + struct dst_metrics *p = kmalloc(sizeof(*p), GFP_ATOMIC); + + if (p) { + struct dst_metrics *old_p = (struct dst_metrics *)__DST_METRICS_PTR(old); + unsigned long prev, new; + + refcount_set(&p->refcnt, 1); + memcpy(p->metrics, old_p->metrics, sizeof(p->metrics)); + + new = (unsigned long) p; + prev = cmpxchg(&dst->_metrics, old, new); + + if (prev != old) { + kfree(p); + p = (struct dst_metrics *)__DST_METRICS_PTR(prev); + if (prev & DST_METRICS_READ_ONLY) + p = NULL; + } else if (prev & DST_METRICS_REFCOUNTED) { + if (refcount_dec_and_test(&old_p->refcnt)) + kfree(old_p); + } + } + BUILD_BUG_ON(offsetof(struct dst_metrics, metrics) != 0); + return (u32 *)p; +} +EXPORT_SYMBOL(dst_cow_metrics_generic); + +/* Caller asserts that dst_metrics_read_only(dst) is false. */ +void __dst_destroy_metrics_generic(struct dst_entry *dst, unsigned long old) +{ + unsigned long prev, new; + + new = ((unsigned long) &dst_default_metrics) | DST_METRICS_READ_ONLY; + prev = cmpxchg(&dst->_metrics, old, new); + if (prev == old) + kfree(__DST_METRICS_PTR(old)); +} +EXPORT_SYMBOL(__dst_destroy_metrics_generic); + +struct dst_entry *dst_blackhole_check(struct dst_entry *dst, u32 cookie) +{ + return NULL; +} + +u32 *dst_blackhole_cow_metrics(struct dst_entry *dst, unsigned long old) +{ + return NULL; +} + +struct neighbour *dst_blackhole_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr) +{ + return NULL; +} + +void dst_blackhole_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ +} +EXPORT_SYMBOL_GPL(dst_blackhole_update_pmtu); + +void dst_blackhole_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb) +{ +} +EXPORT_SYMBOL_GPL(dst_blackhole_redirect); + +unsigned int dst_blackhole_mtu(const struct dst_entry *dst) +{ + unsigned int mtu = dst_metric_raw(dst, RTAX_MTU); + + return mtu ? 
: dst->dev->mtu; +} +EXPORT_SYMBOL_GPL(dst_blackhole_mtu); + +static struct dst_ops dst_blackhole_ops = { + .family = AF_UNSPEC, + .neigh_lookup = dst_blackhole_neigh_lookup, + .check = dst_blackhole_check, + .cow_metrics = dst_blackhole_cow_metrics, + .update_pmtu = dst_blackhole_update_pmtu, + .redirect = dst_blackhole_redirect, + .mtu = dst_blackhole_mtu, +}; + +static void __metadata_dst_init(struct metadata_dst *md_dst, + enum metadata_type type, u8 optslen) +{ + struct dst_entry *dst; + + dst = &md_dst->dst; + dst_init(dst, &dst_blackhole_ops, NULL, 1, DST_OBSOLETE_NONE, + DST_METADATA | DST_NOCOUNT); + memset(dst + 1, 0, sizeof(*md_dst) + optslen - sizeof(*dst)); + md_dst->type = type; +} + +struct metadata_dst *metadata_dst_alloc(u8 optslen, enum metadata_type type, + gfp_t flags) +{ + struct metadata_dst *md_dst; + + md_dst = kmalloc(sizeof(*md_dst) + optslen, flags); + if (!md_dst) + return NULL; + + __metadata_dst_init(md_dst, type, optslen); + + return md_dst; +} +EXPORT_SYMBOL_GPL(metadata_dst_alloc); + +void metadata_dst_free(struct metadata_dst *md_dst) +{ +#ifdef CONFIG_DST_CACHE + if (md_dst->type == METADATA_IP_TUNNEL) + dst_cache_destroy(&md_dst->u.tun_info.dst_cache); +#endif + kfree(md_dst); +} +EXPORT_SYMBOL_GPL(metadata_dst_free); + +struct metadata_dst __percpu * +metadata_dst_alloc_percpu(u8 optslen, enum metadata_type type, gfp_t flags) +{ + int cpu; + struct metadata_dst __percpu *md_dst; + + md_dst = __alloc_percpu_gfp(sizeof(struct metadata_dst) + optslen, + __alignof__(struct metadata_dst), flags); + if (!md_dst) + return NULL; + + for_each_possible_cpu(cpu) + __metadata_dst_init(per_cpu_ptr(md_dst, cpu), type, optslen); + + return md_dst; +} +EXPORT_SYMBOL_GPL(metadata_dst_alloc_percpu); + +void metadata_dst_free_percpu(struct metadata_dst __percpu *md_dst) +{ +#ifdef CONFIG_DST_CACHE + int cpu; + + for_each_possible_cpu(cpu) { + struct metadata_dst *one_md_dst = per_cpu_ptr(md_dst, cpu); + + if (one_md_dst->type == METADATA_IP_TUNNEL) + dst_cache_destroy(&one_md_dst->u.tun_info.dst_cache); + } +#endif + free_percpu(md_dst); +} +EXPORT_SYMBOL_GPL(metadata_dst_free_percpu); diff --git a/net/core/dst_cache.c b/net/core/dst_cache.c new file mode 100644 index 000000000..0ccfd5fa5 --- /dev/null +++ b/net/core/dst_cache.c @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/dst_cache.c - dst entry cache + * + * Copyright (c) 2016 Paolo Abeni + */ + +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif +#include + +struct dst_cache_pcpu { + unsigned long refresh_ts; + struct dst_entry *dst; + u32 cookie; + union { + struct in_addr in_saddr; + struct in6_addr in6_saddr; + }; +}; + +static void dst_cache_per_cpu_dst_set(struct dst_cache_pcpu *dst_cache, + struct dst_entry *dst, u32 cookie) +{ + dst_release(dst_cache->dst); + if (dst) + dst_hold(dst); + + dst_cache->cookie = cookie; + dst_cache->dst = dst; +} + +static struct dst_entry *dst_cache_per_cpu_get(struct dst_cache *dst_cache, + struct dst_cache_pcpu *idst) +{ + struct dst_entry *dst; + + dst = idst->dst; + if (!dst) + goto fail; + + /* the cache already hold a dst reference; it can't go away */ + dst_hold(dst); + + if (unlikely(!time_after(idst->refresh_ts, dst_cache->reset_ts) || + (dst->obsolete && !dst->ops->check(dst, idst->cookie)))) { + dst_cache_per_cpu_dst_set(idst, NULL, 0); + dst_release(dst); + goto fail; + } + return dst; + +fail: + idst->refresh_ts = jiffies; + return NULL; +} + +struct dst_entry *dst_cache_get(struct dst_cache 
*dst_cache) +{ + if (!dst_cache->cache) + return NULL; + + return dst_cache_per_cpu_get(dst_cache, this_cpu_ptr(dst_cache->cache)); +} +EXPORT_SYMBOL_GPL(dst_cache_get); + +struct rtable *dst_cache_get_ip4(struct dst_cache *dst_cache, __be32 *saddr) +{ + struct dst_cache_pcpu *idst; + struct dst_entry *dst; + + if (!dst_cache->cache) + return NULL; + + idst = this_cpu_ptr(dst_cache->cache); + dst = dst_cache_per_cpu_get(dst_cache, idst); + if (!dst) + return NULL; + + *saddr = idst->in_saddr.s_addr; + return container_of(dst, struct rtable, dst); +} +EXPORT_SYMBOL_GPL(dst_cache_get_ip4); + +void dst_cache_set_ip4(struct dst_cache *dst_cache, struct dst_entry *dst, + __be32 saddr) +{ + struct dst_cache_pcpu *idst; + + if (!dst_cache->cache) + return; + + idst = this_cpu_ptr(dst_cache->cache); + dst_cache_per_cpu_dst_set(idst, dst, 0); + idst->in_saddr.s_addr = saddr; +} +EXPORT_SYMBOL_GPL(dst_cache_set_ip4); + +#if IS_ENABLED(CONFIG_IPV6) +void dst_cache_set_ip6(struct dst_cache *dst_cache, struct dst_entry *dst, + const struct in6_addr *saddr) +{ + struct dst_cache_pcpu *idst; + + if (!dst_cache->cache) + return; + + idst = this_cpu_ptr(dst_cache->cache); + dst_cache_per_cpu_dst_set(this_cpu_ptr(dst_cache->cache), dst, + rt6_get_cookie((struct rt6_info *)dst)); + idst->in6_saddr = *saddr; +} +EXPORT_SYMBOL_GPL(dst_cache_set_ip6); + +struct dst_entry *dst_cache_get_ip6(struct dst_cache *dst_cache, + struct in6_addr *saddr) +{ + struct dst_cache_pcpu *idst; + struct dst_entry *dst; + + if (!dst_cache->cache) + return NULL; + + idst = this_cpu_ptr(dst_cache->cache); + dst = dst_cache_per_cpu_get(dst_cache, idst); + if (!dst) + return NULL; + + *saddr = idst->in6_saddr; + return dst; +} +EXPORT_SYMBOL_GPL(dst_cache_get_ip6); +#endif + +int dst_cache_init(struct dst_cache *dst_cache, gfp_t gfp) +{ + dst_cache->cache = alloc_percpu_gfp(struct dst_cache_pcpu, + gfp | __GFP_ZERO); + if (!dst_cache->cache) + return -ENOMEM; + + dst_cache_reset(dst_cache); + return 0; +} +EXPORT_SYMBOL_GPL(dst_cache_init); + +void dst_cache_destroy(struct dst_cache *dst_cache) +{ + int i; + + if (!dst_cache->cache) + return; + + for_each_possible_cpu(i) + dst_release(per_cpu_ptr(dst_cache->cache, i)->dst); + + free_percpu(dst_cache->cache); +} +EXPORT_SYMBOL_GPL(dst_cache_destroy); + +void dst_cache_reset_now(struct dst_cache *dst_cache) +{ + int i; + + if (!dst_cache->cache) + return; + + dst_cache->reset_ts = jiffies; + for_each_possible_cpu(i) { + struct dst_cache_pcpu *idst = per_cpu_ptr(dst_cache->cache, i); + struct dst_entry *dst = idst->dst; + + idst->cookie = 0; + idst->dst = NULL; + dst_release(dst); + } +} +EXPORT_SYMBOL_GPL(dst_cache_reset_now); diff --git a/net/core/failover.c b/net/core/failover.c new file mode 100644 index 000000000..864d2d83e --- /dev/null +++ b/net/core/failover.c @@ -0,0 +1,315 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2018, Intel Corporation. */ + +/* A common module to handle registrations and notifications for paravirtual + * drivers to enable accelerated datapath and support VF live migration. + * + * The notifier and event handling code is based on netvsc driver. 
+ */ + +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(failover_list); +static DEFINE_SPINLOCK(failover_lock); + +static struct net_device *failover_get_bymac(u8 *mac, struct failover_ops **ops) +{ + struct net_device *failover_dev; + struct failover *failover; + + spin_lock(&failover_lock); + list_for_each_entry(failover, &failover_list, list) { + failover_dev = rtnl_dereference(failover->failover_dev); + if (ether_addr_equal(failover_dev->perm_addr, mac)) { + *ops = rtnl_dereference(failover->ops); + spin_unlock(&failover_lock); + return failover_dev; + } + } + spin_unlock(&failover_lock); + return NULL; +} + +/** + * failover_slave_register - Register a slave netdev + * + * @slave_dev: slave netdev that is being registered + * + * Registers a slave device to a failover instance. Only ethernet devices + * are supported. + */ +static int failover_slave_register(struct net_device *slave_dev) +{ + struct netdev_lag_upper_info lag_upper_info; + struct net_device *failover_dev; + struct failover_ops *fops; + int err; + + if (slave_dev->type != ARPHRD_ETHER) + goto done; + + ASSERT_RTNL(); + + failover_dev = failover_get_bymac(slave_dev->perm_addr, &fops); + if (!failover_dev) + goto done; + + if (fops && fops->slave_pre_register && + fops->slave_pre_register(slave_dev, failover_dev)) + goto done; + + err = netdev_rx_handler_register(slave_dev, fops->slave_handle_frame, + failover_dev); + if (err) { + netdev_err(slave_dev, "can not register failover rx handler (err = %d)\n", + err); + goto done; + } + + lag_upper_info.tx_type = NETDEV_LAG_TX_TYPE_ACTIVEBACKUP; + err = netdev_master_upper_dev_link(slave_dev, failover_dev, NULL, + &lag_upper_info, NULL); + if (err) { + netdev_err(slave_dev, "can not set failover device %s (err = %d)\n", + failover_dev->name, err); + goto err_upper_link; + } + + slave_dev->priv_flags |= (IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); + + if (fops && fops->slave_register && + !fops->slave_register(slave_dev, failover_dev)) + return NOTIFY_OK; + + netdev_upper_dev_unlink(slave_dev, failover_dev); + slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); +err_upper_link: + netdev_rx_handler_unregister(slave_dev); +done: + return NOTIFY_DONE; +} + +/** + * failover_slave_unregister - Unregister a slave netdev + * + * @slave_dev: slave netdev that is being unregistered + * + * Unregisters a slave device from a failover instance. 
+ */ +int failover_slave_unregister(struct net_device *slave_dev) +{ + struct net_device *failover_dev; + struct failover_ops *fops; + + if (!netif_is_failover_slave(slave_dev)) + goto done; + + ASSERT_RTNL(); + + failover_dev = failover_get_bymac(slave_dev->perm_addr, &fops); + if (!failover_dev) + goto done; + + if (fops && fops->slave_pre_unregister && + fops->slave_pre_unregister(slave_dev, failover_dev)) + goto done; + + netdev_rx_handler_unregister(slave_dev); + netdev_upper_dev_unlink(slave_dev, failover_dev); + slave_dev->priv_flags &= ~(IFF_FAILOVER_SLAVE | IFF_LIVE_RENAME_OK); + + if (fops && fops->slave_unregister && + !fops->slave_unregister(slave_dev, failover_dev)) + return NOTIFY_OK; + +done: + return NOTIFY_DONE; +} +EXPORT_SYMBOL_GPL(failover_slave_unregister); + +static int failover_slave_link_change(struct net_device *slave_dev) +{ + struct net_device *failover_dev; + struct failover_ops *fops; + + if (!netif_is_failover_slave(slave_dev)) + goto done; + + ASSERT_RTNL(); + + failover_dev = failover_get_bymac(slave_dev->perm_addr, &fops); + if (!failover_dev) + goto done; + + if (!netif_running(failover_dev)) + goto done; + + if (fops && fops->slave_link_change && + !fops->slave_link_change(slave_dev, failover_dev)) + return NOTIFY_OK; + +done: + return NOTIFY_DONE; +} + +static int failover_slave_name_change(struct net_device *slave_dev) +{ + struct net_device *failover_dev; + struct failover_ops *fops; + + if (!netif_is_failover_slave(slave_dev)) + goto done; + + ASSERT_RTNL(); + + failover_dev = failover_get_bymac(slave_dev->perm_addr, &fops); + if (!failover_dev) + goto done; + + if (!netif_running(failover_dev)) + goto done; + + if (fops && fops->slave_name_change && + !fops->slave_name_change(slave_dev, failover_dev)) + return NOTIFY_OK; + +done: + return NOTIFY_DONE; +} + +static int +failover_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *event_dev = netdev_notifier_info_to_dev(ptr); + + /* Skip parent events */ + if (netif_is_failover(event_dev)) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_REGISTER: + return failover_slave_register(event_dev); + case NETDEV_UNREGISTER: + return failover_slave_unregister(event_dev); + case NETDEV_UP: + case NETDEV_DOWN: + case NETDEV_CHANGE: + return failover_slave_link_change(event_dev); + case NETDEV_CHANGENAME: + return failover_slave_name_change(event_dev); + default: + return NOTIFY_DONE; + } +} + +static struct notifier_block failover_notifier = { + .notifier_call = failover_event, +}; + +static void +failover_existing_slave_register(struct net_device *failover_dev) +{ + struct net *net = dev_net(failover_dev); + struct net_device *dev; + + rtnl_lock(); + for_each_netdev(net, dev) { + if (netif_is_failover(dev)) + continue; + if (ether_addr_equal(failover_dev->perm_addr, dev->perm_addr)) + failover_slave_register(dev); + } + rtnl_unlock(); +} + +/** + * failover_register - Register a failover instance + * + * @dev: failover netdev + * @ops: failover ops + * + * Allocate and register a failover instance for a failover netdev. ops + * provides handlers for slave device register/unregister/link change/ + * name change events. 
+ * + * Return: pointer to failover instance + */ +struct failover *failover_register(struct net_device *dev, + struct failover_ops *ops) +{ + struct failover *failover; + + if (dev->type != ARPHRD_ETHER) + return ERR_PTR(-EINVAL); + + failover = kzalloc(sizeof(*failover), GFP_KERNEL); + if (!failover) + return ERR_PTR(-ENOMEM); + + rcu_assign_pointer(failover->ops, ops); + netdev_hold(dev, &failover->dev_tracker, GFP_KERNEL); + dev->priv_flags |= IFF_FAILOVER; + rcu_assign_pointer(failover->failover_dev, dev); + + spin_lock(&failover_lock); + list_add_tail(&failover->list, &failover_list); + spin_unlock(&failover_lock); + + netdev_info(dev, "failover master:%s registered\n", dev->name); + + failover_existing_slave_register(dev); + + return failover; +} +EXPORT_SYMBOL_GPL(failover_register); + +/** + * failover_unregister - Unregister a failover instance + * + * @failover: pointer to failover instance + * + * Unregisters and frees a failover instance. + */ +void failover_unregister(struct failover *failover) +{ + struct net_device *failover_dev; + + failover_dev = rcu_dereference(failover->failover_dev); + + netdev_info(failover_dev, "failover master:%s unregistered\n", + failover_dev->name); + + failover_dev->priv_flags &= ~IFF_FAILOVER; + netdev_put(failover_dev, &failover->dev_tracker); + + spin_lock(&failover_lock); + list_del(&failover->list); + spin_unlock(&failover_lock); + + kfree(failover); +} +EXPORT_SYMBOL_GPL(failover_unregister); + +static __init int +failover_init(void) +{ + register_netdevice_notifier(&failover_notifier); + + return 0; +} +module_init(failover_init); + +static __exit +void failover_exit(void) +{ + unregister_netdevice_notifier(&failover_notifier); +} +module_exit(failover_exit); + +MODULE_DESCRIPTION("Generic failover infrastructure/interface"); +MODULE_LICENSE("GPL v2"); diff --git a/net/core/fib_notifier.c b/net/core/fib_notifier.c new file mode 100644 index 000000000..fc9625980 --- /dev/null +++ b/net/core/fib_notifier.c @@ -0,0 +1,199 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int fib_notifier_net_id; + +struct fib_notifier_net { + struct list_head fib_notifier_ops; + struct atomic_notifier_head fib_chain; +}; + +int call_fib_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + int err; + + err = nb->notifier_call(nb, event_type, info); + return notifier_to_errno(err); +} +EXPORT_SYMBOL(call_fib_notifier); + +int call_fib_notifiers(struct net *net, enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + int err; + + err = atomic_notifier_call_chain(&fn_net->fib_chain, event_type, info); + return notifier_to_errno(err); +} +EXPORT_SYMBOL(call_fib_notifiers); + +static unsigned int fib_seq_sum(struct net *net) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + struct fib_notifier_ops *ops; + unsigned int fib_seq = 0; + + rtnl_lock(); + rcu_read_lock(); + list_for_each_entry_rcu(ops, &fn_net->fib_notifier_ops, list) { + if (!try_module_get(ops->owner)) + continue; + fib_seq += ops->fib_seq_read(net); + module_put(ops->owner); + } + rcu_read_unlock(); + rtnl_unlock(); + + return fib_seq; +} + +static int fib_net_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + struct fib_notifier_ops *ops; + int err = 0; + 
+ rcu_read_lock(); + list_for_each_entry_rcu(ops, &fn_net->fib_notifier_ops, list) { + if (!try_module_get(ops->owner)) + continue; + err = ops->fib_dump(net, nb, extack); + module_put(ops->owner); + if (err) + goto unlock; + } + +unlock: + rcu_read_unlock(); + + return err; +} + +static bool fib_dump_is_consistent(struct net *net, struct notifier_block *nb, + void (*cb)(struct notifier_block *nb), + unsigned int fib_seq) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + + atomic_notifier_chain_register(&fn_net->fib_chain, nb); + if (fib_seq == fib_seq_sum(net)) + return true; + atomic_notifier_chain_unregister(&fn_net->fib_chain, nb); + if (cb) + cb(nb); + return false; +} + +#define FIB_DUMP_MAX_RETRIES 5 +int register_fib_notifier(struct net *net, struct notifier_block *nb, + void (*cb)(struct notifier_block *nb), + struct netlink_ext_ack *extack) +{ + int retries = 0; + int err; + + do { + unsigned int fib_seq = fib_seq_sum(net); + + err = fib_net_dump(net, nb, extack); + if (err) + return err; + + if (fib_dump_is_consistent(net, nb, cb, fib_seq)) + return 0; + } while (++retries < FIB_DUMP_MAX_RETRIES); + + return -EBUSY; +} +EXPORT_SYMBOL(register_fib_notifier); + +int unregister_fib_notifier(struct net *net, struct notifier_block *nb) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + + return atomic_notifier_chain_unregister(&fn_net->fib_chain, nb); +} +EXPORT_SYMBOL(unregister_fib_notifier); + +static int __fib_notifier_ops_register(struct fib_notifier_ops *ops, + struct net *net) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + struct fib_notifier_ops *o; + + list_for_each_entry(o, &fn_net->fib_notifier_ops, list) + if (ops->family == o->family) + return -EEXIST; + list_add_tail_rcu(&ops->list, &fn_net->fib_notifier_ops); + return 0; +} + +struct fib_notifier_ops * +fib_notifier_ops_register(const struct fib_notifier_ops *tmpl, struct net *net) +{ + struct fib_notifier_ops *ops; + int err; + + ops = kmemdup(tmpl, sizeof(*ops), GFP_KERNEL); + if (!ops) + return ERR_PTR(-ENOMEM); + + err = __fib_notifier_ops_register(ops, net); + if (err) + goto err_register; + + return ops; + +err_register: + kfree(ops); + return ERR_PTR(err); +} +EXPORT_SYMBOL(fib_notifier_ops_register); + +void fib_notifier_ops_unregister(struct fib_notifier_ops *ops) +{ + list_del_rcu(&ops->list); + kfree_rcu(ops, rcu); +} +EXPORT_SYMBOL(fib_notifier_ops_unregister); + +static int __net_init fib_notifier_net_init(struct net *net) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + + INIT_LIST_HEAD(&fn_net->fib_notifier_ops); + ATOMIC_INIT_NOTIFIER_HEAD(&fn_net->fib_chain); + return 0; +} + +static void __net_exit fib_notifier_net_exit(struct net *net) +{ + struct fib_notifier_net *fn_net = net_generic(net, fib_notifier_net_id); + + WARN_ON_ONCE(!list_empty(&fn_net->fib_notifier_ops)); +} + +static struct pernet_operations fib_notifier_net_ops = { + .init = fib_notifier_net_init, + .exit = fib_notifier_net_exit, + .id = &fib_notifier_net_id, + .size = sizeof(struct fib_notifier_net), +}; + +static int __init fib_notifier_init(void) +{ + return register_pernet_subsys(&fib_notifier_net_ops); +} + +subsys_initcall(fib_notifier_init); diff --git a/net/core/fib_rules.c b/net/core/fib_rules.c new file mode 100644 index 000000000..75282222e --- /dev/null +++ b/net/core/fib_rules.c @@ -0,0 +1,1319 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/core/fib_rules.c Generic Routing Rules + * + * 
Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(CONFIG_IPV6) && defined(CONFIG_IPV6_MULTIPLE_TABLES) +#ifdef CONFIG_IP_MULTIPLE_TABLES +#define INDIRECT_CALL_MT(f, f2, f1, ...) \ + INDIRECT_CALL_INET(f, f2, f1, __VA_ARGS__) +#else +#define INDIRECT_CALL_MT(f, f2, f1, ...) INDIRECT_CALL_1(f, f2, __VA_ARGS__) +#endif +#elif defined(CONFIG_IP_MULTIPLE_TABLES) +#define INDIRECT_CALL_MT(f, f2, f1, ...) INDIRECT_CALL_1(f, f1, __VA_ARGS__) +#else +#define INDIRECT_CALL_MT(f, f2, f1, ...) f(__VA_ARGS__) +#endif + +static const struct fib_kuid_range fib_kuid_range_unset = { + KUIDT_INIT(0), + KUIDT_INIT(~0), +}; + +bool fib_rule_matchall(const struct fib_rule *rule) +{ + if (rule->iifindex || rule->oifindex || rule->mark || rule->tun_id || + rule->flags) + return false; + if (rule->suppress_ifgroup != -1 || rule->suppress_prefixlen != -1) + return false; + if (!uid_eq(rule->uid_range.start, fib_kuid_range_unset.start) || + !uid_eq(rule->uid_range.end, fib_kuid_range_unset.end)) + return false; + if (fib_rule_port_range_set(&rule->sport_range)) + return false; + if (fib_rule_port_range_set(&rule->dport_range)) + return false; + return true; +} +EXPORT_SYMBOL_GPL(fib_rule_matchall); + +int fib_default_rule_add(struct fib_rules_ops *ops, + u32 pref, u32 table, u32 flags) +{ + struct fib_rule *r; + + r = kzalloc(ops->rule_size, GFP_KERNEL_ACCOUNT); + if (r == NULL) + return -ENOMEM; + + refcount_set(&r->refcnt, 1); + r->action = FR_ACT_TO_TBL; + r->pref = pref; + r->table = table; + r->flags = flags; + r->proto = RTPROT_KERNEL; + r->fr_net = ops->fro_net; + r->uid_range = fib_kuid_range_unset; + + r->suppress_prefixlen = -1; + r->suppress_ifgroup = -1; + + /* The lock is not required here, the list in unreacheable + * at the moment this function is called */ + list_add_tail(&r->list, &ops->rules_list); + return 0; +} +EXPORT_SYMBOL(fib_default_rule_add); + +static u32 fib_default_rule_pref(struct fib_rules_ops *ops) +{ + struct list_head *pos; + struct fib_rule *rule; + + if (!list_empty(&ops->rules_list)) { + pos = ops->rules_list.next; + if (pos->next != &ops->rules_list) { + rule = list_entry(pos->next, struct fib_rule, list); + if (rule->pref) + return rule->pref - 1; + } + } + + return 0; +} + +static void notify_rule_change(int event, struct fib_rule *rule, + struct fib_rules_ops *ops, struct nlmsghdr *nlh, + u32 pid); + +static struct fib_rules_ops *lookup_rules_ops(struct net *net, int family) +{ + struct fib_rules_ops *ops; + + rcu_read_lock(); + list_for_each_entry_rcu(ops, &net->rules_ops, list) { + if (ops->family == family) { + if (!try_module_get(ops->owner)) + ops = NULL; + rcu_read_unlock(); + return ops; + } + } + rcu_read_unlock(); + + return NULL; +} + +static void rules_ops_put(struct fib_rules_ops *ops) +{ + if (ops) + module_put(ops->owner); +} + +static void flush_route_cache(struct fib_rules_ops *ops) +{ + if (ops->flush_cache) + ops->flush_cache(ops); +} + +static int __fib_rules_register(struct fib_rules_ops *ops) +{ + int err = -EEXIST; + struct fib_rules_ops *o; + struct net *net; + + net = ops->fro_net; + + if (ops->rule_size < sizeof(struct fib_rule)) + return -EINVAL; + + if (ops->match == NULL || ops->configure == NULL || + ops->compare == NULL || ops->fill == NULL || + ops->action == NULL) + return -EINVAL; + + spin_lock(&net->rules_mod_lock); + list_for_each_entry(o, &net->rules_ops, list) + if (ops->family == o->family) + goto errout; + + list_add_tail_rcu(&ops->list, 
&net->rules_ops); + err = 0; +errout: + spin_unlock(&net->rules_mod_lock); + + return err; +} + +struct fib_rules_ops * +fib_rules_register(const struct fib_rules_ops *tmpl, struct net *net) +{ + struct fib_rules_ops *ops; + int err; + + ops = kmemdup(tmpl, sizeof(*ops), GFP_KERNEL); + if (ops == NULL) + return ERR_PTR(-ENOMEM); + + INIT_LIST_HEAD(&ops->rules_list); + ops->fro_net = net; + + err = __fib_rules_register(ops); + if (err) { + kfree(ops); + ops = ERR_PTR(err); + } + + return ops; +} +EXPORT_SYMBOL_GPL(fib_rules_register); + +static void fib_rules_cleanup_ops(struct fib_rules_ops *ops) +{ + struct fib_rule *rule, *tmp; + + list_for_each_entry_safe(rule, tmp, &ops->rules_list, list) { + list_del_rcu(&rule->list); + if (ops->delete) + ops->delete(rule); + fib_rule_put(rule); + } +} + +void fib_rules_unregister(struct fib_rules_ops *ops) +{ + struct net *net = ops->fro_net; + + spin_lock(&net->rules_mod_lock); + list_del_rcu(&ops->list); + spin_unlock(&net->rules_mod_lock); + + fib_rules_cleanup_ops(ops); + kfree_rcu(ops, rcu); +} +EXPORT_SYMBOL_GPL(fib_rules_unregister); + +static int uid_range_set(struct fib_kuid_range *range) +{ + return uid_valid(range->start) && uid_valid(range->end); +} + +static struct fib_kuid_range nla_get_kuid_range(struct nlattr **tb) +{ + struct fib_rule_uid_range *in; + struct fib_kuid_range out; + + in = (struct fib_rule_uid_range *)nla_data(tb[FRA_UID_RANGE]); + + out.start = make_kuid(current_user_ns(), in->start); + out.end = make_kuid(current_user_ns(), in->end); + + return out; +} + +static int nla_put_uid_range(struct sk_buff *skb, struct fib_kuid_range *range) +{ + struct fib_rule_uid_range out = { + from_kuid_munged(current_user_ns(), range->start), + from_kuid_munged(current_user_ns(), range->end) + }; + + return nla_put(skb, FRA_UID_RANGE, sizeof(out), &out); +} + +static int nla_get_port_range(struct nlattr *pattr, + struct fib_rule_port_range *port_range) +{ + const struct fib_rule_port_range *pr = nla_data(pattr); + + if (!fib_rule_port_range_valid(pr)) + return -EINVAL; + + port_range->start = pr->start; + port_range->end = pr->end; + + return 0; +} + +static int nla_put_port_range(struct sk_buff *skb, int attrtype, + struct fib_rule_port_range *range) +{ + return nla_put(skb, attrtype, sizeof(*range), range); +} + +static int fib_rule_match(struct fib_rule *rule, struct fib_rules_ops *ops, + struct flowi *fl, int flags, + struct fib_lookup_arg *arg) +{ + int ret = 0; + + if (rule->iifindex && (rule->iifindex != fl->flowi_iif)) + goto out; + + if (rule->oifindex && (rule->oifindex != fl->flowi_oif)) + goto out; + + if ((rule->mark ^ fl->flowi_mark) & rule->mark_mask) + goto out; + + if (rule->tun_id && (rule->tun_id != fl->flowi_tun_key.tun_id)) + goto out; + + if (rule->l3mdev && !l3mdev_fib_rule_match(rule->fr_net, fl, arg)) + goto out; + + if (uid_lt(fl->flowi_uid, rule->uid_range.start) || + uid_gt(fl->flowi_uid, rule->uid_range.end)) + goto out; + + ret = INDIRECT_CALL_MT(ops->match, + fib6_rule_match, + fib4_rule_match, + rule, fl, flags); +out: + return (rule->flags & FIB_RULE_INVERT) ? 
!ret : ret; +} + +int fib_rules_lookup(struct fib_rules_ops *ops, struct flowi *fl, + int flags, struct fib_lookup_arg *arg) +{ + struct fib_rule *rule; + int err; + + rcu_read_lock(); + + list_for_each_entry_rcu(rule, &ops->rules_list, list) { +jumped: + if (!fib_rule_match(rule, ops, fl, flags, arg)) + continue; + + if (rule->action == FR_ACT_GOTO) { + struct fib_rule *target; + + target = rcu_dereference(rule->ctarget); + if (target == NULL) { + continue; + } else { + rule = target; + goto jumped; + } + } else if (rule->action == FR_ACT_NOP) + continue; + else + err = INDIRECT_CALL_MT(ops->action, + fib6_rule_action, + fib4_rule_action, + rule, fl, flags, arg); + + if (!err && ops->suppress && INDIRECT_CALL_MT(ops->suppress, + fib6_rule_suppress, + fib4_rule_suppress, + rule, flags, arg)) + continue; + + if (err != -EAGAIN) { + if ((arg->flags & FIB_LOOKUP_NOREF) || + likely(refcount_inc_not_zero(&rule->refcnt))) { + arg->rule = rule; + goto out; + } + break; + } + } + + err = -ESRCH; +out: + rcu_read_unlock(); + + return err; +} +EXPORT_SYMBOL_GPL(fib_rules_lookup); + +static int call_fib_rule_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib_rule *rule, int family, + struct netlink_ext_ack *extack) +{ + struct fib_rule_notifier_info info = { + .info.family = family, + .info.extack = extack, + .rule = rule, + }; + + return call_fib_notifier(nb, event_type, &info.info); +} + +static int call_fib_rule_notifiers(struct net *net, + enum fib_event_type event_type, + struct fib_rule *rule, + struct fib_rules_ops *ops, + struct netlink_ext_ack *extack) +{ + struct fib_rule_notifier_info info = { + .info.family = ops->family, + .info.extack = extack, + .rule = rule, + }; + + ops->fib_rules_seq++; + return call_fib_notifiers(net, event_type, &info.info); +} + +/* Called with rcu_read_lock() */ +int fib_rules_dump(struct net *net, struct notifier_block *nb, int family, + struct netlink_ext_ack *extack) +{ + struct fib_rules_ops *ops; + struct fib_rule *rule; + int err = 0; + + ops = lookup_rules_ops(net, family); + if (!ops) + return -EAFNOSUPPORT; + list_for_each_entry_rcu(rule, &ops->rules_list, list) { + err = call_fib_rule_notifier(nb, FIB_EVENT_RULE_ADD, + rule, family, extack); + if (err) + break; + } + rules_ops_put(ops); + + return err; +} +EXPORT_SYMBOL_GPL(fib_rules_dump); + +unsigned int fib_rules_seq_read(struct net *net, int family) +{ + unsigned int fib_rules_seq; + struct fib_rules_ops *ops; + + ASSERT_RTNL(); + + ops = lookup_rules_ops(net, family); + if (!ops) + return 0; + fib_rules_seq = ops->fib_rules_seq; + rules_ops_put(ops); + + return fib_rules_seq; +} +EXPORT_SYMBOL_GPL(fib_rules_seq_read); + +static struct fib_rule *rule_find(struct fib_rules_ops *ops, + struct fib_rule_hdr *frh, + struct nlattr **tb, + struct fib_rule *rule, + bool user_priority) +{ + struct fib_rule *r; + + list_for_each_entry(r, &ops->rules_list, list) { + if (rule->action && r->action != rule->action) + continue; + + if (rule->table && r->table != rule->table) + continue; + + if (user_priority && r->pref != rule->pref) + continue; + + if (rule->iifname[0] && + memcmp(r->iifname, rule->iifname, IFNAMSIZ)) + continue; + + if (rule->oifname[0] && + memcmp(r->oifname, rule->oifname, IFNAMSIZ)) + continue; + + if (rule->mark && r->mark != rule->mark) + continue; + + if (rule->suppress_ifgroup != -1 && + r->suppress_ifgroup != rule->suppress_ifgroup) + continue; + + if (rule->suppress_prefixlen != -1 && + r->suppress_prefixlen != rule->suppress_prefixlen) + continue; + + 
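+		/* As in the checks above and below, an attribute that was not
+		 * supplied in the request acts as a wildcard; only the fields
+		 * the user actually set, plus fr_net and the family specific
+		 * ->compare(), have to match the installed rule.
+		 */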
if (rule->mark_mask && r->mark_mask != rule->mark_mask) + continue; + + if (rule->tun_id && r->tun_id != rule->tun_id) + continue; + + if (r->fr_net != rule->fr_net) + continue; + + if (rule->l3mdev && r->l3mdev != rule->l3mdev) + continue; + + if (uid_range_set(&rule->uid_range) && + (!uid_eq(r->uid_range.start, rule->uid_range.start) || + !uid_eq(r->uid_range.end, rule->uid_range.end))) + continue; + + if (rule->ip_proto && r->ip_proto != rule->ip_proto) + continue; + + if (rule->proto && r->proto != rule->proto) + continue; + + if (fib_rule_port_range_set(&rule->sport_range) && + !fib_rule_port_range_compare(&r->sport_range, + &rule->sport_range)) + continue; + + if (fib_rule_port_range_set(&rule->dport_range) && + !fib_rule_port_range_compare(&r->dport_range, + &rule->dport_range)) + continue; + + if (!ops->compare(r, frh, tb)) + continue; + return r; + } + + return NULL; +} + +#ifdef CONFIG_NET_L3_MASTER_DEV +static int fib_nl2rule_l3mdev(struct nlattr *nla, struct fib_rule *nlrule, + struct netlink_ext_ack *extack) +{ + nlrule->l3mdev = nla_get_u8(nla); + if (nlrule->l3mdev != 1) { + NL_SET_ERR_MSG(extack, "Invalid l3mdev attribute"); + return -1; + } + + return 0; +} +#else +static int fib_nl2rule_l3mdev(struct nlattr *nla, struct fib_rule *nlrule, + struct netlink_ext_ack *extack) +{ + NL_SET_ERR_MSG(extack, "l3mdev support is not enabled in kernel"); + return -1; +} +#endif + +static int fib_nl2rule(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + struct fib_rules_ops *ops, + struct nlattr *tb[], + struct fib_rule **rule, + bool *user_priority) +{ + struct net *net = sock_net(skb->sk); + struct fib_rule_hdr *frh = nlmsg_data(nlh); + struct fib_rule *nlrule = NULL; + int err = -EINVAL; + + if (frh->src_len) + if (!tb[FRA_SRC] || + frh->src_len > (ops->addr_size * 8) || + nla_len(tb[FRA_SRC]) != ops->addr_size) { + NL_SET_ERR_MSG(extack, "Invalid source address"); + goto errout; + } + + if (frh->dst_len) + if (!tb[FRA_DST] || + frh->dst_len > (ops->addr_size * 8) || + nla_len(tb[FRA_DST]) != ops->addr_size) { + NL_SET_ERR_MSG(extack, "Invalid dst address"); + goto errout; + } + + nlrule = kzalloc(ops->rule_size, GFP_KERNEL_ACCOUNT); + if (!nlrule) { + err = -ENOMEM; + goto errout; + } + refcount_set(&nlrule->refcnt, 1); + nlrule->fr_net = net; + + if (tb[FRA_PRIORITY]) { + nlrule->pref = nla_get_u32(tb[FRA_PRIORITY]); + *user_priority = true; + } else { + nlrule->pref = fib_default_rule_pref(ops); + } + + nlrule->proto = tb[FRA_PROTOCOL] ? + nla_get_u8(tb[FRA_PROTOCOL]) : RTPROT_UNSPEC; + + if (tb[FRA_IIFNAME]) { + struct net_device *dev; + + nlrule->iifindex = -1; + nla_strscpy(nlrule->iifname, tb[FRA_IIFNAME], IFNAMSIZ); + dev = __dev_get_by_name(net, nlrule->iifname); + if (dev) + nlrule->iifindex = dev->ifindex; + } + + if (tb[FRA_OIFNAME]) { + struct net_device *dev; + + nlrule->oifindex = -1; + nla_strscpy(nlrule->oifname, tb[FRA_OIFNAME], IFNAMSIZ); + dev = __dev_get_by_name(net, nlrule->oifname); + if (dev) + nlrule->oifindex = dev->ifindex; + } + + if (tb[FRA_FWMARK]) { + nlrule->mark = nla_get_u32(tb[FRA_FWMARK]); + if (nlrule->mark) + /* compatibility: if the mark value is non-zero all bits + * are compared unless a mask is explicitly specified. 
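+			 * For example, FRA_FWMARK 0x10 without FRA_FWMASK matches
+			 * only flowi_mark == 0x10, while adding FRA_FWMASK 0xf0
+			 * also matches e.g. 0x11 through 0x1f, because
+			 * fib_rule_match() requires
+			 * ((rule->mark ^ flowi_mark) & rule->mark_mask) == 0.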
+ */ + nlrule->mark_mask = 0xFFFFFFFF; + } + + if (tb[FRA_FWMASK]) + nlrule->mark_mask = nla_get_u32(tb[FRA_FWMASK]); + + if (tb[FRA_TUN_ID]) + nlrule->tun_id = nla_get_be64(tb[FRA_TUN_ID]); + + err = -EINVAL; + if (tb[FRA_L3MDEV] && + fib_nl2rule_l3mdev(tb[FRA_L3MDEV], nlrule, extack) < 0) + goto errout_free; + + nlrule->action = frh->action; + nlrule->flags = frh->flags; + nlrule->table = frh_get_table(frh, tb); + if (tb[FRA_SUPPRESS_PREFIXLEN]) + nlrule->suppress_prefixlen = nla_get_u32(tb[FRA_SUPPRESS_PREFIXLEN]); + else + nlrule->suppress_prefixlen = -1; + + if (tb[FRA_SUPPRESS_IFGROUP]) + nlrule->suppress_ifgroup = nla_get_u32(tb[FRA_SUPPRESS_IFGROUP]); + else + nlrule->suppress_ifgroup = -1; + + if (tb[FRA_GOTO]) { + if (nlrule->action != FR_ACT_GOTO) { + NL_SET_ERR_MSG(extack, "Unexpected goto"); + goto errout_free; + } + + nlrule->target = nla_get_u32(tb[FRA_GOTO]); + /* Backward jumps are prohibited to avoid endless loops */ + if (nlrule->target <= nlrule->pref) { + NL_SET_ERR_MSG(extack, "Backward goto not supported"); + goto errout_free; + } + } else if (nlrule->action == FR_ACT_GOTO) { + NL_SET_ERR_MSG(extack, "Missing goto target for action goto"); + goto errout_free; + } + + if (nlrule->l3mdev && nlrule->table) { + NL_SET_ERR_MSG(extack, "l3mdev and table are mutually exclusive"); + goto errout_free; + } + + if (tb[FRA_UID_RANGE]) { + if (current_user_ns() != net->user_ns) { + err = -EPERM; + NL_SET_ERR_MSG(extack, "No permission to set uid"); + goto errout_free; + } + + nlrule->uid_range = nla_get_kuid_range(tb); + + if (!uid_range_set(&nlrule->uid_range) || + !uid_lte(nlrule->uid_range.start, nlrule->uid_range.end)) { + NL_SET_ERR_MSG(extack, "Invalid uid range"); + goto errout_free; + } + } else { + nlrule->uid_range = fib_kuid_range_unset; + } + + if (tb[FRA_IP_PROTO]) + nlrule->ip_proto = nla_get_u8(tb[FRA_IP_PROTO]); + + if (tb[FRA_SPORT_RANGE]) { + err = nla_get_port_range(tb[FRA_SPORT_RANGE], + &nlrule->sport_range); + if (err) { + NL_SET_ERR_MSG(extack, "Invalid sport range"); + goto errout_free; + } + } + + if (tb[FRA_DPORT_RANGE]) { + err = nla_get_port_range(tb[FRA_DPORT_RANGE], + &nlrule->dport_range); + if (err) { + NL_SET_ERR_MSG(extack, "Invalid dport range"); + goto errout_free; + } + } + + *rule = nlrule; + + return 0; + +errout_free: + kfree(nlrule); +errout: + return err; +} + +static int rule_exists(struct fib_rules_ops *ops, struct fib_rule_hdr *frh, + struct nlattr **tb, struct fib_rule *rule) +{ + struct fib_rule *r; + + list_for_each_entry(r, &ops->rules_list, list) { + if (r->action != rule->action) + continue; + + if (r->table != rule->table) + continue; + + if (r->pref != rule->pref) + continue; + + if (memcmp(r->iifname, rule->iifname, IFNAMSIZ)) + continue; + + if (memcmp(r->oifname, rule->oifname, IFNAMSIZ)) + continue; + + if (r->mark != rule->mark) + continue; + + if (r->suppress_ifgroup != rule->suppress_ifgroup) + continue; + + if (r->suppress_prefixlen != rule->suppress_prefixlen) + continue; + + if (r->mark_mask != rule->mark_mask) + continue; + + if (r->tun_id != rule->tun_id) + continue; + + if (r->fr_net != rule->fr_net) + continue; + + if (r->l3mdev != rule->l3mdev) + continue; + + if (!uid_eq(r->uid_range.start, rule->uid_range.start) || + !uid_eq(r->uid_range.end, rule->uid_range.end)) + continue; + + if (r->ip_proto != rule->ip_proto) + continue; + + if (r->proto != rule->proto) + continue; + + if (!fib_rule_port_range_compare(&r->sport_range, + &rule->sport_range)) + continue; + + if (!fib_rule_port_range_compare(&r->dport_range, 
+ &rule->dport_range)) + continue; + + if (!ops->compare(r, frh, tb)) + continue; + return 1; + } + return 0; +} + +static const struct nla_policy fib_rule_policy[FRA_MAX + 1] = { + [FRA_UNSPEC] = { .strict_start_type = FRA_DPORT_RANGE + 1 }, + [FRA_IIFNAME] = { .type = NLA_STRING, .len = IFNAMSIZ - 1 }, + [FRA_OIFNAME] = { .type = NLA_STRING, .len = IFNAMSIZ - 1 }, + [FRA_PRIORITY] = { .type = NLA_U32 }, + [FRA_FWMARK] = { .type = NLA_U32 }, + [FRA_FLOW] = { .type = NLA_U32 }, + [FRA_TUN_ID] = { .type = NLA_U64 }, + [FRA_FWMASK] = { .type = NLA_U32 }, + [FRA_TABLE] = { .type = NLA_U32 }, + [FRA_SUPPRESS_PREFIXLEN] = { .type = NLA_U32 }, + [FRA_SUPPRESS_IFGROUP] = { .type = NLA_U32 }, + [FRA_GOTO] = { .type = NLA_U32 }, + [FRA_L3MDEV] = { .type = NLA_U8 }, + [FRA_UID_RANGE] = { .len = sizeof(struct fib_rule_uid_range) }, + [FRA_PROTOCOL] = { .type = NLA_U8 }, + [FRA_IP_PROTO] = { .type = NLA_U8 }, + [FRA_SPORT_RANGE] = { .len = sizeof(struct fib_rule_port_range) }, + [FRA_DPORT_RANGE] = { .len = sizeof(struct fib_rule_port_range) } +}; + +int fib_nl_newrule(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct fib_rule_hdr *frh = nlmsg_data(nlh); + struct fib_rules_ops *ops = NULL; + struct fib_rule *rule = NULL, *r, *last = NULL; + struct nlattr *tb[FRA_MAX + 1]; + int err = -EINVAL, unresolved = 0; + bool user_priority = false; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid msg length"); + goto errout; + } + + ops = lookup_rules_ops(net, frh->family); + if (!ops) { + err = -EAFNOSUPPORT; + NL_SET_ERR_MSG(extack, "Rule family not supported"); + goto errout; + } + + err = nlmsg_parse_deprecated(nlh, sizeof(*frh), tb, FRA_MAX, + fib_rule_policy, extack); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Error parsing msg"); + goto errout; + } + + err = fib_nl2rule(skb, nlh, extack, ops, tb, &rule, &user_priority); + if (err) + goto errout; + + if ((nlh->nlmsg_flags & NLM_F_EXCL) && + rule_exists(ops, frh, tb, rule)) { + err = -EEXIST; + goto errout_free; + } + + err = ops->configure(rule, skb, frh, tb, extack); + if (err < 0) + goto errout_free; + + err = call_fib_rule_notifiers(net, FIB_EVENT_RULE_ADD, rule, ops, + extack); + if (err < 0) + goto errout_free; + + list_for_each_entry(r, &ops->rules_list, list) { + if (r->pref == rule->target) { + RCU_INIT_POINTER(rule->ctarget, r); + break; + } + } + + if (rcu_dereference_protected(rule->ctarget, 1) == NULL) + unresolved = 1; + + list_for_each_entry(r, &ops->rules_list, list) { + if (r->pref > rule->pref) + break; + last = r; + } + + if (last) + list_add_rcu(&rule->list, &last->list); + else + list_add_rcu(&rule->list, &ops->rules_list); + + if (ops->unresolved_rules) { + /* + * There are unresolved goto rules in the list, check if + * any of them are pointing to this new rule. 
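+		 * A rule such as "goto 2000" that was added before any rule
+		 * with preference 2000 existed is flagged FIB_RULE_UNRESOLVED;
+		 * the walk below resolves such forward references now.
+		 * (Backward gotos were already rejected in fib_nl2rule().)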
+ */ + list_for_each_entry(r, &ops->rules_list, list) { + if (r->action == FR_ACT_GOTO && + r->target == rule->pref && + rtnl_dereference(r->ctarget) == NULL) { + rcu_assign_pointer(r->ctarget, rule); + if (--ops->unresolved_rules == 0) + break; + } + } + } + + if (rule->action == FR_ACT_GOTO) + ops->nr_goto_rules++; + + if (unresolved) + ops->unresolved_rules++; + + if (rule->tun_id) + ip_tunnel_need_metadata(); + + notify_rule_change(RTM_NEWRULE, rule, ops, nlh, NETLINK_CB(skb).portid); + flush_route_cache(ops); + rules_ops_put(ops); + return 0; + +errout_free: + kfree(rule); +errout: + rules_ops_put(ops); + return err; +} +EXPORT_SYMBOL_GPL(fib_nl_newrule); + +int fib_nl_delrule(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct fib_rule_hdr *frh = nlmsg_data(nlh); + struct fib_rules_ops *ops = NULL; + struct fib_rule *rule = NULL, *r, *nlrule = NULL; + struct nlattr *tb[FRA_MAX+1]; + int err = -EINVAL; + bool user_priority = false; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid msg length"); + goto errout; + } + + ops = lookup_rules_ops(net, frh->family); + if (ops == NULL) { + err = -EAFNOSUPPORT; + NL_SET_ERR_MSG(extack, "Rule family not supported"); + goto errout; + } + + err = nlmsg_parse_deprecated(nlh, sizeof(*frh), tb, FRA_MAX, + fib_rule_policy, extack); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Error parsing msg"); + goto errout; + } + + err = fib_nl2rule(skb, nlh, extack, ops, tb, &nlrule, &user_priority); + if (err) + goto errout; + + rule = rule_find(ops, frh, tb, nlrule, user_priority); + if (!rule) { + err = -ENOENT; + goto errout; + } + + if (rule->flags & FIB_RULE_PERMANENT) { + err = -EPERM; + goto errout; + } + + if (ops->delete) { + err = ops->delete(rule); + if (err) + goto errout; + } + + if (rule->tun_id) + ip_tunnel_unneed_metadata(); + + list_del_rcu(&rule->list); + + if (rule->action == FR_ACT_GOTO) { + ops->nr_goto_rules--; + if (rtnl_dereference(rule->ctarget) == NULL) + ops->unresolved_rules--; + } + + /* + * Check if this rule is a target to any of them. If so, + * adjust to the next one with the same preference or + * disable them. As this operation is eventually very + * expensive, it is only performed if goto rules, except + * current if it is goto rule, have actually been added. 
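+	 * Concretely: every remaining goto rule whose resolved target
+	 * (ctarget) is the rule being deleted is re-pointed at the next
+	 * rule in the list if that one has the same preference, and is
+	 * marked unresolved again otherwise.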
+ */ + if (ops->nr_goto_rules > 0) { + struct fib_rule *n; + + n = list_next_entry(rule, list); + if (&n->list == &ops->rules_list || n->pref != rule->pref) + n = NULL; + list_for_each_entry(r, &ops->rules_list, list) { + if (rtnl_dereference(r->ctarget) != rule) + continue; + rcu_assign_pointer(r->ctarget, n); + if (!n) + ops->unresolved_rules++; + } + } + + call_fib_rule_notifiers(net, FIB_EVENT_RULE_DEL, rule, ops, + NULL); + notify_rule_change(RTM_DELRULE, rule, ops, nlh, + NETLINK_CB(skb).portid); + fib_rule_put(rule); + flush_route_cache(ops); + rules_ops_put(ops); + kfree(nlrule); + return 0; + +errout: + kfree(nlrule); + rules_ops_put(ops); + return err; +} +EXPORT_SYMBOL_GPL(fib_nl_delrule); + +static inline size_t fib_rule_nlmsg_size(struct fib_rules_ops *ops, + struct fib_rule *rule) +{ + size_t payload = NLMSG_ALIGN(sizeof(struct fib_rule_hdr)) + + nla_total_size(IFNAMSIZ) /* FRA_IIFNAME */ + + nla_total_size(IFNAMSIZ) /* FRA_OIFNAME */ + + nla_total_size(4) /* FRA_PRIORITY */ + + nla_total_size(4) /* FRA_TABLE */ + + nla_total_size(4) /* FRA_SUPPRESS_PREFIXLEN */ + + nla_total_size(4) /* FRA_SUPPRESS_IFGROUP */ + + nla_total_size(4) /* FRA_FWMARK */ + + nla_total_size(4) /* FRA_FWMASK */ + + nla_total_size_64bit(8) /* FRA_TUN_ID */ + + nla_total_size(sizeof(struct fib_kuid_range)) + + nla_total_size(1) /* FRA_PROTOCOL */ + + nla_total_size(1) /* FRA_IP_PROTO */ + + nla_total_size(sizeof(struct fib_rule_port_range)) /* FRA_SPORT_RANGE */ + + nla_total_size(sizeof(struct fib_rule_port_range)); /* FRA_DPORT_RANGE */ + + if (ops->nlmsg_payload) + payload += ops->nlmsg_payload(rule); + + return payload; +} + +static int fib_nl_fill_rule(struct sk_buff *skb, struct fib_rule *rule, + u32 pid, u32 seq, int type, int flags, + struct fib_rules_ops *ops) +{ + struct nlmsghdr *nlh; + struct fib_rule_hdr *frh; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*frh), flags); + if (nlh == NULL) + return -EMSGSIZE; + + frh = nlmsg_data(nlh); + frh->family = ops->family; + frh->table = rule->table < 256 ? 
rule->table : RT_TABLE_COMPAT; + if (nla_put_u32(skb, FRA_TABLE, rule->table)) + goto nla_put_failure; + if (nla_put_u32(skb, FRA_SUPPRESS_PREFIXLEN, rule->suppress_prefixlen)) + goto nla_put_failure; + frh->res1 = 0; + frh->res2 = 0; + frh->action = rule->action; + frh->flags = rule->flags; + + if (nla_put_u8(skb, FRA_PROTOCOL, rule->proto)) + goto nla_put_failure; + + if (rule->action == FR_ACT_GOTO && + rcu_access_pointer(rule->ctarget) == NULL) + frh->flags |= FIB_RULE_UNRESOLVED; + + if (rule->iifname[0]) { + if (nla_put_string(skb, FRA_IIFNAME, rule->iifname)) + goto nla_put_failure; + if (rule->iifindex == -1) + frh->flags |= FIB_RULE_IIF_DETACHED; + } + + if (rule->oifname[0]) { + if (nla_put_string(skb, FRA_OIFNAME, rule->oifname)) + goto nla_put_failure; + if (rule->oifindex == -1) + frh->flags |= FIB_RULE_OIF_DETACHED; + } + + if ((rule->pref && + nla_put_u32(skb, FRA_PRIORITY, rule->pref)) || + (rule->mark && + nla_put_u32(skb, FRA_FWMARK, rule->mark)) || + ((rule->mark_mask || rule->mark) && + nla_put_u32(skb, FRA_FWMASK, rule->mark_mask)) || + (rule->target && + nla_put_u32(skb, FRA_GOTO, rule->target)) || + (rule->tun_id && + nla_put_be64(skb, FRA_TUN_ID, rule->tun_id, FRA_PAD)) || + (rule->l3mdev && + nla_put_u8(skb, FRA_L3MDEV, rule->l3mdev)) || + (uid_range_set(&rule->uid_range) && + nla_put_uid_range(skb, &rule->uid_range)) || + (fib_rule_port_range_set(&rule->sport_range) && + nla_put_port_range(skb, FRA_SPORT_RANGE, &rule->sport_range)) || + (fib_rule_port_range_set(&rule->dport_range) && + nla_put_port_range(skb, FRA_DPORT_RANGE, &rule->dport_range)) || + (rule->ip_proto && nla_put_u8(skb, FRA_IP_PROTO, rule->ip_proto))) + goto nla_put_failure; + + if (rule->suppress_ifgroup != -1) { + if (nla_put_u32(skb, FRA_SUPPRESS_IFGROUP, rule->suppress_ifgroup)) + goto nla_put_failure; + } + + if (ops->fill(rule, skb, frh) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int dump_rules(struct sk_buff *skb, struct netlink_callback *cb, + struct fib_rules_ops *ops) +{ + int idx = 0; + struct fib_rule *rule; + int err = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(rule, &ops->rules_list, list) { + if (idx < cb->args[1]) + goto skip; + + err = fib_nl_fill_rule(skb, rule, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_NEWRULE, + NLM_F_MULTI, ops); + if (err) + break; +skip: + idx++; + } + rcu_read_unlock(); + cb->args[1] = idx; + rules_ops_put(ops); + + return err; +} + +static int fib_valid_dumprule_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct fib_rule_hdr *frh; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid header for fib rule dump request"); + return -EINVAL; + } + + frh = nlmsg_data(nlh); + if (frh->dst_len || frh->src_len || frh->tos || frh->table || + frh->res1 || frh->res2 || frh->action || frh->flags) { + NL_SET_ERR_MSG(extack, + "Invalid values in header for fib rule dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*frh))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in fib rule dump request"); + return -EINVAL; + } + + return 0; +} + +static int fib_nl_dumprule(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct fib_rules_ops *ops; + int idx = 0, family; + + if (cb->strict_check) { + int err = fib_valid_dumprule_req(nlh, cb->extack); + + if (err < 0) + return err; + } + + 
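+	/* cb->args[0] tracks how many families (ops) have been walked and
+	 * cb->args[1] the rule index inside the current one, so a
+	 * multi-part netlink dump can resume where the previous skb ended.
+	 */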
family = rtnl_msg_family(nlh); + if (family != AF_UNSPEC) { + /* Protocol specific dump request */ + ops = lookup_rules_ops(net, family); + if (ops == NULL) + return -EAFNOSUPPORT; + + dump_rules(skb, cb, ops); + + return skb->len; + } + + rcu_read_lock(); + list_for_each_entry_rcu(ops, &net->rules_ops, list) { + if (idx < cb->args[0] || !try_module_get(ops->owner)) + goto skip; + + if (dump_rules(skb, cb, ops) < 0) + break; + + cb->args[1] = 0; +skip: + idx++; + } + rcu_read_unlock(); + cb->args[0] = idx; + + return skb->len; +} + +static void notify_rule_change(int event, struct fib_rule *rule, + struct fib_rules_ops *ops, struct nlmsghdr *nlh, + u32 pid) +{ + struct net *net; + struct sk_buff *skb; + int err = -ENOMEM; + + net = ops->fro_net; + skb = nlmsg_new(fib_rule_nlmsg_size(ops, rule), GFP_KERNEL); + if (skb == NULL) + goto errout; + + err = fib_nl_fill_rule(skb, rule, pid, nlh->nlmsg_seq, event, 0, ops); + if (err < 0) { + /* -EMSGSIZE implies BUG in fib_rule_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, pid, ops->nlgroup, nlh, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, ops->nlgroup, err); +} + +static void attach_rules(struct list_head *rules, struct net_device *dev) +{ + struct fib_rule *rule; + + list_for_each_entry(rule, rules, list) { + if (rule->iifindex == -1 && + strcmp(dev->name, rule->iifname) == 0) + rule->iifindex = dev->ifindex; + if (rule->oifindex == -1 && + strcmp(dev->name, rule->oifname) == 0) + rule->oifindex = dev->ifindex; + } +} + +static void detach_rules(struct list_head *rules, struct net_device *dev) +{ + struct fib_rule *rule; + + list_for_each_entry(rule, rules, list) { + if (rule->iifindex == dev->ifindex) + rule->iifindex = -1; + if (rule->oifindex == dev->ifindex) + rule->oifindex = -1; + } +} + + +static int fib_rules_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct fib_rules_ops *ops; + + ASSERT_RTNL(); + + switch (event) { + case NETDEV_REGISTER: + list_for_each_entry(ops, &net->rules_ops, list) + attach_rules(&ops->rules_list, dev); + break; + + case NETDEV_CHANGENAME: + list_for_each_entry(ops, &net->rules_ops, list) { + detach_rules(&ops->rules_list, dev); + attach_rules(&ops->rules_list, dev); + } + break; + + case NETDEV_UNREGISTER: + list_for_each_entry(ops, &net->rules_ops, list) + detach_rules(&ops->rules_list, dev); + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block fib_rules_notifier = { + .notifier_call = fib_rules_event, +}; + +static int __net_init fib_rules_net_init(struct net *net) +{ + INIT_LIST_HEAD(&net->rules_ops); + spin_lock_init(&net->rules_mod_lock); + return 0; +} + +static void __net_exit fib_rules_net_exit(struct net *net) +{ + WARN_ON_ONCE(!list_empty(&net->rules_ops)); +} + +static struct pernet_operations fib_rules_net_ops = { + .init = fib_rules_net_init, + .exit = fib_rules_net_exit, +}; + +static int __init fib_rules_init(void) +{ + int err; + rtnl_register(PF_UNSPEC, RTM_NEWRULE, fib_nl_newrule, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELRULE, fib_nl_delrule, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETRULE, NULL, fib_nl_dumprule, 0); + + err = register_pernet_subsys(&fib_rules_net_ops); + if (err < 0) + goto fail; + + err = register_netdevice_notifier(&fib_rules_notifier); + if (err < 0) + goto fail_unregister; + + return 0; + +fail_unregister: + 
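+	/* Unwind in reverse order of setup: drop the pernet subsystem
+	 * first, then the rtnetlink handlers registered above.
+	 */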
unregister_pernet_subsys(&fib_rules_net_ops); +fail: + rtnl_unregister(PF_UNSPEC, RTM_NEWRULE); + rtnl_unregister(PF_UNSPEC, RTM_DELRULE); + rtnl_unregister(PF_UNSPEC, RTM_GETRULE); + return err; +} + +subsys_initcall(fib_rules_init); diff --git a/net/core/filter.c b/net/core/filter.c new file mode 100644 index 000000000..3a6110ea4 --- /dev/null +++ b/net/core/filter.c @@ -0,0 +1,11669 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux Socket Filter - Kernel level socket filtering + * + * Based on the design of the Berkeley Packet Filter. The new + * internal format has been designed by PLUMgrid: + * + * Copyright (c) 2011 - 2014 PLUMgrid, http://plumgrid.com + * + * Authors: + * + * Jay Schulist + * Alexei Starovoitov + * Daniel Borkmann + * + * Andi Kleen - Fix a few bad bugs and races. + * Kris Katterjohn - Added many additional checks in bpf_check_classic() + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const struct bpf_func_proto * +bpf_sk_base_func_proto(enum bpf_func_id func_id); + +int copy_bpf_fprog_from_user(struct sock_fprog *dst, sockptr_t src, int len) +{ + if (in_compat_syscall()) { + struct compat_sock_fprog f32; + + if (len != sizeof(f32)) + return -EINVAL; + if (copy_from_sockptr(&f32, src, sizeof(f32))) + return -EFAULT; + memset(dst, 0, sizeof(*dst)); + dst->len = f32.len; + dst->filter = compat_ptr(f32.filter); + } else { + if (len != sizeof(*dst)) + return -EINVAL; + if (copy_from_sockptr(dst, src, sizeof(*dst))) + return -EFAULT; + } + + return 0; +} +EXPORT_SYMBOL_GPL(copy_bpf_fprog_from_user); + +/** + * sk_filter_trim_cap - run a packet through a socket filter + * @sk: sock associated with &sk_buff + * @skb: buffer to filter + * @cap: limit on how short the eBPF program may trim the packet + * + * Run the eBPF program and then cut skb->data to correct size returned by + * the program. If pkt_len is 0 we toss packet. If skb->len is smaller + * than pkt_len we keep whole skb->data. This is the socket level + * wrapper to bpf_prog_run. It returns 0 if the packet should + * be accepted or -EPERM if the packet should be tossed. + * + */ +int sk_filter_trim_cap(struct sock *sk, struct sk_buff *skb, unsigned int cap) +{ + int err; + struct sk_filter *filter; + + /* + * If the skb was allocated from pfmemalloc reserves, only + * allow SOCK_MEMALLOC sockets to use it as this socket is + * helping free memory + */ + if (skb_pfmemalloc(skb) && !sock_flag(sk, SOCK_MEMALLOC)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_PFMEMALLOCDROP); + return -ENOMEM; + } + err = BPF_CGROUP_RUN_PROG_INET_INGRESS(sk, skb); + if (err) + return err; + + err = security_sock_rcv_skb(sk, skb); + if (err) + return err; + + rcu_read_lock(); + filter = rcu_dereference(sk->sk_filter); + if (filter) { + struct sock *save_sk = skb->sk; + unsigned int pkt_len; + + skb->sk = sk; + pkt_len = bpf_prog_run_save_cb(filter->prog, skb); + skb->sk = save_sk; + err = pkt_len ? 
pskb_trim(skb, max(cap, pkt_len)) : -EPERM; + } + rcu_read_unlock(); + + return err; +} +EXPORT_SYMBOL(sk_filter_trim_cap); + +BPF_CALL_1(bpf_skb_get_pay_offset, struct sk_buff *, skb) +{ + return skb_get_poff(skb); +} + +BPF_CALL_3(bpf_skb_get_nlattr, struct sk_buff *, skb, u32, a, u32, x) +{ + struct nlattr *nla; + + if (skb_is_nonlinear(skb)) + return 0; + + if (skb->len < sizeof(struct nlattr)) + return 0; + + if (a > skb->len - sizeof(struct nlattr)) + return 0; + + nla = nla_find((struct nlattr *) &skb->data[a], skb->len - a, x); + if (nla) + return (void *) nla - (void *) skb->data; + + return 0; +} + +BPF_CALL_3(bpf_skb_get_nlattr_nest, struct sk_buff *, skb, u32, a, u32, x) +{ + struct nlattr *nla; + + if (skb_is_nonlinear(skb)) + return 0; + + if (skb->len < sizeof(struct nlattr)) + return 0; + + if (a > skb->len - sizeof(struct nlattr)) + return 0; + + nla = (struct nlattr *) &skb->data[a]; + if (nla->nla_len > skb->len - a) + return 0; + + nla = nla_find_nested(nla, x); + if (nla) + return (void *) nla - (void *) skb->data; + + return 0; +} + +BPF_CALL_4(bpf_skb_load_helper_8, const struct sk_buff *, skb, const void *, + data, int, headlen, int, offset) +{ + u8 tmp, *ptr; + const int len = sizeof(tmp); + + if (offset >= 0) { + if (headlen - offset >= len) + return *(u8 *)(data + offset); + if (!skb_copy_bits(skb, offset, &tmp, sizeof(tmp))) + return tmp; + } else { + ptr = bpf_internal_load_pointer_neg_helper(skb, offset, len); + if (likely(ptr)) + return *(u8 *)ptr; + } + + return -EFAULT; +} + +BPF_CALL_2(bpf_skb_load_helper_8_no_cache, const struct sk_buff *, skb, + int, offset) +{ + return ____bpf_skb_load_helper_8(skb, skb->data, skb->len - skb->data_len, + offset); +} + +BPF_CALL_4(bpf_skb_load_helper_16, const struct sk_buff *, skb, const void *, + data, int, headlen, int, offset) +{ + __be16 tmp, *ptr; + const int len = sizeof(tmp); + + if (offset >= 0) { + if (headlen - offset >= len) + return get_unaligned_be16(data + offset); + if (!skb_copy_bits(skb, offset, &tmp, sizeof(tmp))) + return be16_to_cpu(tmp); + } else { + ptr = bpf_internal_load_pointer_neg_helper(skb, offset, len); + if (likely(ptr)) + return get_unaligned_be16(ptr); + } + + return -EFAULT; +} + +BPF_CALL_2(bpf_skb_load_helper_16_no_cache, const struct sk_buff *, skb, + int, offset) +{ + return ____bpf_skb_load_helper_16(skb, skb->data, skb->len - skb->data_len, + offset); +} + +BPF_CALL_4(bpf_skb_load_helper_32, const struct sk_buff *, skb, const void *, + data, int, headlen, int, offset) +{ + __be32 tmp, *ptr; + const int len = sizeof(tmp); + + if (likely(offset >= 0)) { + if (headlen - offset >= len) + return get_unaligned_be32(data + offset); + if (!skb_copy_bits(skb, offset, &tmp, sizeof(tmp))) + return be32_to_cpu(tmp); + } else { + ptr = bpf_internal_load_pointer_neg_helper(skb, offset, len); + if (likely(ptr)) + return get_unaligned_be32(ptr); + } + + return -EFAULT; +} + +BPF_CALL_2(bpf_skb_load_helper_32_no_cache, const struct sk_buff *, skb, + int, offset) +{ + return ____bpf_skb_load_helper_32(skb, skb->data, skb->len - skb->data_len, + offset); +} + +static u32 convert_skb_access(int skb_field, int dst_reg, int src_reg, + struct bpf_insn *insn_buf) +{ + struct bpf_insn *insn = insn_buf; + + switch (skb_field) { + case SKF_AD_MARK: + BUILD_BUG_ON(sizeof_field(struct sk_buff, mark) != 4); + + *insn++ = BPF_LDX_MEM(BPF_W, dst_reg, src_reg, + offsetof(struct sk_buff, mark)); + break; + + case SKF_AD_PKTTYPE: + *insn++ = BPF_LDX_MEM(BPF_B, dst_reg, src_reg, PKT_TYPE_OFFSET); + *insn++ = 
BPF_ALU32_IMM(BPF_AND, dst_reg, PKT_TYPE_MAX); +#ifdef __BIG_ENDIAN_BITFIELD + *insn++ = BPF_ALU32_IMM(BPF_RSH, dst_reg, 5); +#endif + break; + + case SKF_AD_QUEUE: + BUILD_BUG_ON(sizeof_field(struct sk_buff, queue_mapping) != 2); + + *insn++ = BPF_LDX_MEM(BPF_H, dst_reg, src_reg, + offsetof(struct sk_buff, queue_mapping)); + break; + + case SKF_AD_VLAN_TAG: + BUILD_BUG_ON(sizeof_field(struct sk_buff, vlan_tci) != 2); + + /* dst_reg = *(u16 *) (src_reg + offsetof(vlan_tci)) */ + *insn++ = BPF_LDX_MEM(BPF_H, dst_reg, src_reg, + offsetof(struct sk_buff, vlan_tci)); + break; + case SKF_AD_VLAN_TAG_PRESENT: + *insn++ = BPF_LDX_MEM(BPF_B, dst_reg, src_reg, PKT_VLAN_PRESENT_OFFSET); + if (PKT_VLAN_PRESENT_BIT) + *insn++ = BPF_ALU32_IMM(BPF_RSH, dst_reg, PKT_VLAN_PRESENT_BIT); + if (PKT_VLAN_PRESENT_BIT < 7) + *insn++ = BPF_ALU32_IMM(BPF_AND, dst_reg, 1); + break; + } + + return insn - insn_buf; +} + +static bool convert_bpf_extensions(struct sock_filter *fp, + struct bpf_insn **insnp) +{ + struct bpf_insn *insn = *insnp; + u32 cnt; + + switch (fp->k) { + case SKF_AD_OFF + SKF_AD_PROTOCOL: + BUILD_BUG_ON(sizeof_field(struct sk_buff, protocol) != 2); + + /* A = *(u16 *) (CTX + offsetof(protocol)) */ + *insn++ = BPF_LDX_MEM(BPF_H, BPF_REG_A, BPF_REG_CTX, + offsetof(struct sk_buff, protocol)); + /* A = ntohs(A) [emitting a nop or swap16] */ + *insn = BPF_ENDIAN(BPF_FROM_BE, BPF_REG_A, 16); + break; + + case SKF_AD_OFF + SKF_AD_PKTTYPE: + cnt = convert_skb_access(SKF_AD_PKTTYPE, BPF_REG_A, BPF_REG_CTX, insn); + insn += cnt - 1; + break; + + case SKF_AD_OFF + SKF_AD_IFINDEX: + case SKF_AD_OFF + SKF_AD_HATYPE: + BUILD_BUG_ON(sizeof_field(struct net_device, ifindex) != 4); + BUILD_BUG_ON(sizeof_field(struct net_device, type) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, dev), + BPF_REG_TMP, BPF_REG_CTX, + offsetof(struct sk_buff, dev)); + /* if (tmp != 0) goto pc + 1 */ + *insn++ = BPF_JMP_IMM(BPF_JNE, BPF_REG_TMP, 0, 1); + *insn++ = BPF_EXIT_INSN(); + if (fp->k == SKF_AD_OFF + SKF_AD_IFINDEX) + *insn = BPF_LDX_MEM(BPF_W, BPF_REG_A, BPF_REG_TMP, + offsetof(struct net_device, ifindex)); + else + *insn = BPF_LDX_MEM(BPF_H, BPF_REG_A, BPF_REG_TMP, + offsetof(struct net_device, type)); + break; + + case SKF_AD_OFF + SKF_AD_MARK: + cnt = convert_skb_access(SKF_AD_MARK, BPF_REG_A, BPF_REG_CTX, insn); + insn += cnt - 1; + break; + + case SKF_AD_OFF + SKF_AD_RXHASH: + BUILD_BUG_ON(sizeof_field(struct sk_buff, hash) != 4); + + *insn = BPF_LDX_MEM(BPF_W, BPF_REG_A, BPF_REG_CTX, + offsetof(struct sk_buff, hash)); + break; + + case SKF_AD_OFF + SKF_AD_QUEUE: + cnt = convert_skb_access(SKF_AD_QUEUE, BPF_REG_A, BPF_REG_CTX, insn); + insn += cnt - 1; + break; + + case SKF_AD_OFF + SKF_AD_VLAN_TAG: + cnt = convert_skb_access(SKF_AD_VLAN_TAG, + BPF_REG_A, BPF_REG_CTX, insn); + insn += cnt - 1; + break; + + case SKF_AD_OFF + SKF_AD_VLAN_TAG_PRESENT: + cnt = convert_skb_access(SKF_AD_VLAN_TAG_PRESENT, + BPF_REG_A, BPF_REG_CTX, insn); + insn += cnt - 1; + break; + + case SKF_AD_OFF + SKF_AD_VLAN_TPID: + BUILD_BUG_ON(sizeof_field(struct sk_buff, vlan_proto) != 2); + + /* A = *(u16 *) (CTX + offsetof(vlan_proto)) */ + *insn++ = BPF_LDX_MEM(BPF_H, BPF_REG_A, BPF_REG_CTX, + offsetof(struct sk_buff, vlan_proto)); + /* A = ntohs(A) [emitting a nop or swap16] */ + *insn = BPF_ENDIAN(BPF_FROM_BE, BPF_REG_A, 16); + break; + + case SKF_AD_OFF + SKF_AD_PAY_OFFSET: + case SKF_AD_OFF + SKF_AD_NLATTR: + case SKF_AD_OFF + SKF_AD_NLATTR_NEST: + case SKF_AD_OFF + SKF_AD_CPU: + case SKF_AD_OFF + SKF_AD_RANDOM: + /* arg1 
= CTX */ + *insn++ = BPF_MOV64_REG(BPF_REG_ARG1, BPF_REG_CTX); + /* arg2 = A */ + *insn++ = BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_A); + /* arg3 = X */ + *insn++ = BPF_MOV64_REG(BPF_REG_ARG3, BPF_REG_X); + /* Emit call(arg1=CTX, arg2=A, arg3=X) */ + switch (fp->k) { + case SKF_AD_OFF + SKF_AD_PAY_OFFSET: + *insn = BPF_EMIT_CALL(bpf_skb_get_pay_offset); + break; + case SKF_AD_OFF + SKF_AD_NLATTR: + *insn = BPF_EMIT_CALL(bpf_skb_get_nlattr); + break; + case SKF_AD_OFF + SKF_AD_NLATTR_NEST: + *insn = BPF_EMIT_CALL(bpf_skb_get_nlattr_nest); + break; + case SKF_AD_OFF + SKF_AD_CPU: + *insn = BPF_EMIT_CALL(bpf_get_raw_cpu_id); + break; + case SKF_AD_OFF + SKF_AD_RANDOM: + *insn = BPF_EMIT_CALL(bpf_user_rnd_u32); + bpf_user_rnd_init_once(); + break; + } + break; + + case SKF_AD_OFF + SKF_AD_ALU_XOR_X: + /* A ^= X */ + *insn = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_X); + break; + + default: + /* This is just a dummy call to avoid letting the compiler + * evict __bpf_call_base() as an optimization. Placed here + * where no-one bothers. + */ + BUG_ON(__bpf_call_base(0, 0, 0, 0, 0) != 0); + return false; + } + + *insnp = insn; + return true; +} + +static bool convert_bpf_ld_abs(struct sock_filter *fp, struct bpf_insn **insnp) +{ + const bool unaligned_ok = IS_BUILTIN(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS); + int size = bpf_size_to_bytes(BPF_SIZE(fp->code)); + bool endian = BPF_SIZE(fp->code) == BPF_H || + BPF_SIZE(fp->code) == BPF_W; + bool indirect = BPF_MODE(fp->code) == BPF_IND; + const int ip_align = NET_IP_ALIGN; + struct bpf_insn *insn = *insnp; + int offset = fp->k; + + if (!indirect && + ((unaligned_ok && offset >= 0) || + (!unaligned_ok && offset >= 0 && + offset + ip_align >= 0 && + offset + ip_align % size == 0))) { + bool ldx_off_ok = offset <= S16_MAX; + + *insn++ = BPF_MOV64_REG(BPF_REG_TMP, BPF_REG_H); + if (offset) + *insn++ = BPF_ALU64_IMM(BPF_SUB, BPF_REG_TMP, offset); + *insn++ = BPF_JMP_IMM(BPF_JSLT, BPF_REG_TMP, + size, 2 + endian + (!ldx_off_ok * 2)); + if (ldx_off_ok) { + *insn++ = BPF_LDX_MEM(BPF_SIZE(fp->code), BPF_REG_A, + BPF_REG_D, offset); + } else { + *insn++ = BPF_MOV64_REG(BPF_REG_TMP, BPF_REG_D); + *insn++ = BPF_ALU64_IMM(BPF_ADD, BPF_REG_TMP, offset); + *insn++ = BPF_LDX_MEM(BPF_SIZE(fp->code), BPF_REG_A, + BPF_REG_TMP, 0); + } + if (endian) + *insn++ = BPF_ENDIAN(BPF_FROM_BE, BPF_REG_A, size * 8); + *insn++ = BPF_JMP_A(8); + } + + *insn++ = BPF_MOV64_REG(BPF_REG_ARG1, BPF_REG_CTX); + *insn++ = BPF_MOV64_REG(BPF_REG_ARG2, BPF_REG_D); + *insn++ = BPF_MOV64_REG(BPF_REG_ARG3, BPF_REG_H); + if (!indirect) { + *insn++ = BPF_MOV64_IMM(BPF_REG_ARG4, offset); + } else { + *insn++ = BPF_MOV64_REG(BPF_REG_ARG4, BPF_REG_X); + if (fp->k) + *insn++ = BPF_ALU64_IMM(BPF_ADD, BPF_REG_ARG4, offset); + } + + switch (BPF_SIZE(fp->code)) { + case BPF_B: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_8); + break; + case BPF_H: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_16); + break; + case BPF_W: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_32); + break; + default: + return false; + } + + *insn++ = BPF_JMP_IMM(BPF_JSGE, BPF_REG_A, 0, 2); + *insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_A); + *insn = BPF_EXIT_INSN(); + + *insnp = insn; + return true; +} + +/** + * bpf_convert_filter - convert filter program + * @prog: the user passed filter program + * @len: the length of the user passed filter program + * @new_prog: allocated 'struct bpf_prog' or NULL + * @new_len: pointer to store length of converted program + * @seen_ld_abs: bool whether we've seen ld_abs/ind + * + * Remap 
'sock_filter' style classic BPF (cBPF) instruction set to 'bpf_insn' + * style extended BPF (eBPF). + * Conversion workflow: + * + * 1) First pass for calculating the new program length: + * bpf_convert_filter(old_prog, old_len, NULL, &new_len, &seen_ld_abs) + * + * 2) 2nd pass to remap in two passes: 1st pass finds new + * jump offsets, 2nd pass remapping: + * bpf_convert_filter(old_prog, old_len, new_prog, &new_len, &seen_ld_abs) + */ +static int bpf_convert_filter(struct sock_filter *prog, int len, + struct bpf_prog *new_prog, int *new_len, + bool *seen_ld_abs) +{ + int new_flen = 0, pass = 0, target, i, stack_off; + struct bpf_insn *new_insn, *first_insn = NULL; + struct sock_filter *fp; + int *addrs = NULL; + u8 bpf_src; + + BUILD_BUG_ON(BPF_MEMWORDS * sizeof(u32) > MAX_BPF_STACK); + BUILD_BUG_ON(BPF_REG_FP + 1 != MAX_BPF_REG); + + if (len <= 0 || len > BPF_MAXINSNS) + return -EINVAL; + + if (new_prog) { + first_insn = new_prog->insnsi; + addrs = kcalloc(len, sizeof(*addrs), + GFP_KERNEL | __GFP_NOWARN); + if (!addrs) + return -ENOMEM; + } + +do_pass: + new_insn = first_insn; + fp = prog; + + /* Classic BPF related prologue emission. */ + if (new_prog) { + /* Classic BPF expects A and X to be reset first. These need + * to be guaranteed to be the first two instructions. + */ + *new_insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_A); + *new_insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_X, BPF_REG_X); + + /* All programs must keep CTX in callee saved BPF_REG_CTX. + * In eBPF case it's done by the compiler, here we need to + * do this ourself. Initial CTX is present in BPF_REG_ARG1. + */ + *new_insn++ = BPF_MOV64_REG(BPF_REG_CTX, BPF_REG_ARG1); + if (*seen_ld_abs) { + /* For packet access in classic BPF, cache skb->data + * in callee-saved BPF R8 and skb->len - skb->data_len + * (headlen) in BPF R9. Since classic BPF is read-only + * on CTX, we only need to cache it once. + */ + *new_insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, data), + BPF_REG_D, BPF_REG_CTX, + offsetof(struct sk_buff, data)); + *new_insn++ = BPF_LDX_MEM(BPF_W, BPF_REG_H, BPF_REG_CTX, + offsetof(struct sk_buff, len)); + *new_insn++ = BPF_LDX_MEM(BPF_W, BPF_REG_TMP, BPF_REG_CTX, + offsetof(struct sk_buff, data_len)); + *new_insn++ = BPF_ALU32_REG(BPF_SUB, BPF_REG_H, BPF_REG_TMP); + } + } else { + new_insn += 3; + } + + for (i = 0; i < len; fp++, i++) { + struct bpf_insn tmp_insns[32] = { }; + struct bpf_insn *insn = tmp_insns; + + if (addrs) + addrs[i] = new_insn - first_insn; + + switch (fp->code) { + /* All arithmetic insns and skb loads map as-is. 
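+		 * e.g. a classic "add #5" (BPF_ALU | BPF_ADD | BPF_K, k = 5)
+		 * is re-emitted unchanged through BPF_RAW_INSN() at the end
+		 * of this block, now operating on BPF_REG_A and BPF_REG_X.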
*/ + case BPF_ALU | BPF_ADD | BPF_X: + case BPF_ALU | BPF_ADD | BPF_K: + case BPF_ALU | BPF_SUB | BPF_X: + case BPF_ALU | BPF_SUB | BPF_K: + case BPF_ALU | BPF_AND | BPF_X: + case BPF_ALU | BPF_AND | BPF_K: + case BPF_ALU | BPF_OR | BPF_X: + case BPF_ALU | BPF_OR | BPF_K: + case BPF_ALU | BPF_LSH | BPF_X: + case BPF_ALU | BPF_LSH | BPF_K: + case BPF_ALU | BPF_RSH | BPF_X: + case BPF_ALU | BPF_RSH | BPF_K: + case BPF_ALU | BPF_XOR | BPF_X: + case BPF_ALU | BPF_XOR | BPF_K: + case BPF_ALU | BPF_MUL | BPF_X: + case BPF_ALU | BPF_MUL | BPF_K: + case BPF_ALU | BPF_DIV | BPF_X: + case BPF_ALU | BPF_DIV | BPF_K: + case BPF_ALU | BPF_MOD | BPF_X: + case BPF_ALU | BPF_MOD | BPF_K: + case BPF_ALU | BPF_NEG: + case BPF_LD | BPF_ABS | BPF_W: + case BPF_LD | BPF_ABS | BPF_H: + case BPF_LD | BPF_ABS | BPF_B: + case BPF_LD | BPF_IND | BPF_W: + case BPF_LD | BPF_IND | BPF_H: + case BPF_LD | BPF_IND | BPF_B: + /* Check for overloaded BPF extension and + * directly convert it if found, otherwise + * just move on with mapping. + */ + if (BPF_CLASS(fp->code) == BPF_LD && + BPF_MODE(fp->code) == BPF_ABS && + convert_bpf_extensions(fp, &insn)) + break; + if (BPF_CLASS(fp->code) == BPF_LD && + convert_bpf_ld_abs(fp, &insn)) { + *seen_ld_abs = true; + break; + } + + if (fp->code == (BPF_ALU | BPF_DIV | BPF_X) || + fp->code == (BPF_ALU | BPF_MOD | BPF_X)) { + *insn++ = BPF_MOV32_REG(BPF_REG_X, BPF_REG_X); + /* Error with exception code on div/mod by 0. + * For cBPF programs, this was always return 0. + */ + *insn++ = BPF_JMP_IMM(BPF_JNE, BPF_REG_X, 0, 2); + *insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_A, BPF_REG_A); + *insn++ = BPF_EXIT_INSN(); + } + + *insn = BPF_RAW_INSN(fp->code, BPF_REG_A, BPF_REG_X, 0, fp->k); + break; + + /* Jump transformation cannot use BPF block macros + * everywhere as offset calculation and target updates + * require a bit more work than the rest, i.e. jump + * opcodes map as-is, but offsets need adjustment. + */ + +#define BPF_EMIT_JMP \ + do { \ + const s32 off_min = S16_MIN, off_max = S16_MAX; \ + s32 off; \ + \ + if (target >= len || target < 0) \ + goto err; \ + off = addrs ? addrs[target] - addrs[i] - 1 : 0; \ + /* Adjust pc relative offset for 2nd or 3rd insn. */ \ + off -= insn - tmp_insns; \ + /* Reject anything not fitting into insn->off. */ \ + if (off < off_min || off > off_max) \ + goto err; \ + insn->off = off; \ + } while (0) + + case BPF_JMP | BPF_JA: + target = i + fp->k + 1; + insn->code = fp->code; + BPF_EMIT_JMP; + break; + + case BPF_JMP | BPF_JEQ | BPF_K: + case BPF_JMP | BPF_JEQ | BPF_X: + case BPF_JMP | BPF_JSET | BPF_K: + case BPF_JMP | BPF_JSET | BPF_X: + case BPF_JMP | BPF_JGT | BPF_K: + case BPF_JMP | BPF_JGT | BPF_X: + case BPF_JMP | BPF_JGE | BPF_K: + case BPF_JMP | BPF_JGE | BPF_X: + if (BPF_SRC(fp->code) == BPF_K && (int) fp->k < 0) { + /* BPF immediates are signed, zero extend + * immediate into tmp register and use it + * in compare insn. + */ + *insn++ = BPF_MOV32_IMM(BPF_REG_TMP, fp->k); + + insn->dst_reg = BPF_REG_A; + insn->src_reg = BPF_REG_TMP; + bpf_src = BPF_X; + } else { + insn->dst_reg = BPF_REG_A; + insn->imm = fp->k; + bpf_src = BPF_SRC(fp->code); + insn->src_reg = bpf_src == BPF_X ? BPF_REG_X : 0; + } + + /* Common case where 'jump_false' is next insn. */ + if (fp->jf == 0) { + insn->code = BPF_JMP | BPF_OP(fp->code) | bpf_src; + target = i + fp->jt + 1; + BPF_EMIT_JMP; + break; + } + + /* Convert some jumps when 'jump_true' is next insn. 
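+			 * i.e. invert the condition (JEQ becomes JNE, JGT becomes
+			 * JLE, JGE becomes JLT) so that the single eBPF jump goes
+			 * to the 'false' target and execution simply falls through
+			 * on 'true'.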
*/ + if (fp->jt == 0) { + switch (BPF_OP(fp->code)) { + case BPF_JEQ: + insn->code = BPF_JMP | BPF_JNE | bpf_src; + break; + case BPF_JGT: + insn->code = BPF_JMP | BPF_JLE | bpf_src; + break; + case BPF_JGE: + insn->code = BPF_JMP | BPF_JLT | bpf_src; + break; + default: + goto jmp_rest; + } + + target = i + fp->jf + 1; + BPF_EMIT_JMP; + break; + } +jmp_rest: + /* Other jumps are mapped into two insns: Jxx and JA. */ + target = i + fp->jt + 1; + insn->code = BPF_JMP | BPF_OP(fp->code) | bpf_src; + BPF_EMIT_JMP; + insn++; + + insn->code = BPF_JMP | BPF_JA; + target = i + fp->jf + 1; + BPF_EMIT_JMP; + break; + + /* ldxb 4 * ([14] & 0xf) is remaped into 6 insns. */ + case BPF_LDX | BPF_MSH | BPF_B: { + struct sock_filter tmp = { + .code = BPF_LD | BPF_ABS | BPF_B, + .k = fp->k, + }; + + *seen_ld_abs = true; + + /* X = A */ + *insn++ = BPF_MOV64_REG(BPF_REG_X, BPF_REG_A); + /* A = BPF_R0 = *(u8 *) (skb->data + K) */ + convert_bpf_ld_abs(&tmp, &insn); + insn++; + /* A &= 0xf */ + *insn++ = BPF_ALU32_IMM(BPF_AND, BPF_REG_A, 0xf); + /* A <<= 2 */ + *insn++ = BPF_ALU32_IMM(BPF_LSH, BPF_REG_A, 2); + /* tmp = X */ + *insn++ = BPF_MOV64_REG(BPF_REG_TMP, BPF_REG_X); + /* X = A */ + *insn++ = BPF_MOV64_REG(BPF_REG_X, BPF_REG_A); + /* A = tmp */ + *insn = BPF_MOV64_REG(BPF_REG_A, BPF_REG_TMP); + break; + } + /* RET_K is remaped into 2 insns. RET_A case doesn't need an + * extra mov as BPF_REG_0 is already mapped into BPF_REG_A. + */ + case BPF_RET | BPF_A: + case BPF_RET | BPF_K: + if (BPF_RVAL(fp->code) == BPF_K) + *insn++ = BPF_MOV32_RAW(BPF_K, BPF_REG_0, + 0, fp->k); + *insn = BPF_EXIT_INSN(); + break; + + /* Store to stack. */ + case BPF_ST: + case BPF_STX: + stack_off = fp->k * 4 + 4; + *insn = BPF_STX_MEM(BPF_W, BPF_REG_FP, BPF_CLASS(fp->code) == + BPF_ST ? BPF_REG_A : BPF_REG_X, + -stack_off); + /* check_load_and_stores() verifies that classic BPF can + * load from stack only after write, so tracking + * stack_depth for ST|STX insns is enough + */ + if (new_prog && new_prog->aux->stack_depth < stack_off) + new_prog->aux->stack_depth = stack_off; + break; + + /* Load from stack. */ + case BPF_LD | BPF_MEM: + case BPF_LDX | BPF_MEM: + stack_off = fp->k * 4 + 4; + *insn = BPF_LDX_MEM(BPF_W, BPF_CLASS(fp->code) == BPF_LD ? + BPF_REG_A : BPF_REG_X, BPF_REG_FP, + -stack_off); + break; + + /* A = K or X = K */ + case BPF_LD | BPF_IMM: + case BPF_LDX | BPF_IMM: + *insn = BPF_MOV32_IMM(BPF_CLASS(fp->code) == BPF_LD ? + BPF_REG_A : BPF_REG_X, fp->k); + break; + + /* X = A */ + case BPF_MISC | BPF_TAX: + *insn = BPF_MOV64_REG(BPF_REG_X, BPF_REG_A); + break; + + /* A = X */ + case BPF_MISC | BPF_TXA: + *insn = BPF_MOV64_REG(BPF_REG_A, BPF_REG_X); + break; + + /* A = skb->len or X = skb->len */ + case BPF_LD | BPF_W | BPF_LEN: + case BPF_LDX | BPF_W | BPF_LEN: + *insn = BPF_LDX_MEM(BPF_W, BPF_CLASS(fp->code) == BPF_LD ? + BPF_REG_A : BPF_REG_X, BPF_REG_CTX, + offsetof(struct sk_buff, len)); + break; + + /* Access seccomp_data fields. */ + case BPF_LDX | BPF_ABS | BPF_W: + /* A = *(u32 *) (ctx + K) */ + *insn = BPF_LDX_MEM(BPF_W, BPF_REG_A, BPF_REG_CTX, fp->k); + break; + + /* Unknown instruction. */ + default: + goto err; + } + + insn++; + if (new_prog) + memcpy(new_insn, tmp_insns, + sizeof(*insn) * (insn - tmp_insns)); + new_insn += insn - tmp_insns; + } + + if (!new_prog) { + /* Only calculating new length. */ + *new_len = new_insn - first_insn; + if (*seen_ld_abs) + *new_len += 4; /* Prologue bits. 
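+					* These are the skb->data, skb->len and
+					* skb->data_len loads plus the headlen
+					* subtraction emitted in the prologue
+					* once an ld_abs/ind has been seen.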
*/ + return 0; + } + + pass++; + if (new_flen != new_insn - first_insn) { + new_flen = new_insn - first_insn; + if (pass > 2) + goto err; + goto do_pass; + } + + kfree(addrs); + BUG_ON(*new_len != new_flen); + return 0; +err: + kfree(addrs); + return -EINVAL; +} + +/* Security: + * + * As we dont want to clear mem[] array for each packet going through + * __bpf_prog_run(), we check that filter loaded by user never try to read + * a cell if not previously written, and we check all branches to be sure + * a malicious user doesn't try to abuse us. + */ +static int check_load_and_stores(const struct sock_filter *filter, int flen) +{ + u16 *masks, memvalid = 0; /* One bit per cell, 16 cells */ + int pc, ret = 0; + + BUILD_BUG_ON(BPF_MEMWORDS > 16); + + masks = kmalloc_array(flen, sizeof(*masks), GFP_KERNEL); + if (!masks) + return -ENOMEM; + + memset(masks, 0xff, flen * sizeof(*masks)); + + for (pc = 0; pc < flen; pc++) { + memvalid &= masks[pc]; + + switch (filter[pc].code) { + case BPF_ST: + case BPF_STX: + memvalid |= (1 << filter[pc].k); + break; + case BPF_LD | BPF_MEM: + case BPF_LDX | BPF_MEM: + if (!(memvalid & (1 << filter[pc].k))) { + ret = -EINVAL; + goto error; + } + break; + case BPF_JMP | BPF_JA: + /* A jump must set masks on target */ + masks[pc + 1 + filter[pc].k] &= memvalid; + memvalid = ~0; + break; + case BPF_JMP | BPF_JEQ | BPF_K: + case BPF_JMP | BPF_JEQ | BPF_X: + case BPF_JMP | BPF_JGE | BPF_K: + case BPF_JMP | BPF_JGE | BPF_X: + case BPF_JMP | BPF_JGT | BPF_K: + case BPF_JMP | BPF_JGT | BPF_X: + case BPF_JMP | BPF_JSET | BPF_K: + case BPF_JMP | BPF_JSET | BPF_X: + /* A jump must set masks on targets */ + masks[pc + 1 + filter[pc].jt] &= memvalid; + masks[pc + 1 + filter[pc].jf] &= memvalid; + memvalid = ~0; + break; + } + } +error: + kfree(masks); + return ret; +} + +static bool chk_code_allowed(u16 code_to_probe) +{ + static const bool codes[] = { + /* 32 bit ALU operations */ + [BPF_ALU | BPF_ADD | BPF_K] = true, + [BPF_ALU | BPF_ADD | BPF_X] = true, + [BPF_ALU | BPF_SUB | BPF_K] = true, + [BPF_ALU | BPF_SUB | BPF_X] = true, + [BPF_ALU | BPF_MUL | BPF_K] = true, + [BPF_ALU | BPF_MUL | BPF_X] = true, + [BPF_ALU | BPF_DIV | BPF_K] = true, + [BPF_ALU | BPF_DIV | BPF_X] = true, + [BPF_ALU | BPF_MOD | BPF_K] = true, + [BPF_ALU | BPF_MOD | BPF_X] = true, + [BPF_ALU | BPF_AND | BPF_K] = true, + [BPF_ALU | BPF_AND | BPF_X] = true, + [BPF_ALU | BPF_OR | BPF_K] = true, + [BPF_ALU | BPF_OR | BPF_X] = true, + [BPF_ALU | BPF_XOR | BPF_K] = true, + [BPF_ALU | BPF_XOR | BPF_X] = true, + [BPF_ALU | BPF_LSH | BPF_K] = true, + [BPF_ALU | BPF_LSH | BPF_X] = true, + [BPF_ALU | BPF_RSH | BPF_K] = true, + [BPF_ALU | BPF_RSH | BPF_X] = true, + [BPF_ALU | BPF_NEG] = true, + /* Load instructions */ + [BPF_LD | BPF_W | BPF_ABS] = true, + [BPF_LD | BPF_H | BPF_ABS] = true, + [BPF_LD | BPF_B | BPF_ABS] = true, + [BPF_LD | BPF_W | BPF_LEN] = true, + [BPF_LD | BPF_W | BPF_IND] = true, + [BPF_LD | BPF_H | BPF_IND] = true, + [BPF_LD | BPF_B | BPF_IND] = true, + [BPF_LD | BPF_IMM] = true, + [BPF_LD | BPF_MEM] = true, + [BPF_LDX | BPF_W | BPF_LEN] = true, + [BPF_LDX | BPF_B | BPF_MSH] = true, + [BPF_LDX | BPF_IMM] = true, + [BPF_LDX | BPF_MEM] = true, + /* Store instructions */ + [BPF_ST] = true, + [BPF_STX] = true, + /* Misc instructions */ + [BPF_MISC | BPF_TAX] = true, + [BPF_MISC | BPF_TXA] = true, + /* Return instructions */ + [BPF_RET | BPF_K] = true, + [BPF_RET | BPF_A] = true, + /* Jump instructions */ + [BPF_JMP | BPF_JA] = true, + [BPF_JMP | BPF_JEQ | BPF_K] = true, + [BPF_JMP | BPF_JEQ | 
BPF_X] = true, + [BPF_JMP | BPF_JGE | BPF_K] = true, + [BPF_JMP | BPF_JGE | BPF_X] = true, + [BPF_JMP | BPF_JGT | BPF_K] = true, + [BPF_JMP | BPF_JGT | BPF_X] = true, + [BPF_JMP | BPF_JSET | BPF_K] = true, + [BPF_JMP | BPF_JSET | BPF_X] = true, + }; + + if (code_to_probe >= ARRAY_SIZE(codes)) + return false; + + return codes[code_to_probe]; +} + +static bool bpf_check_basics_ok(const struct sock_filter *filter, + unsigned int flen) +{ + if (filter == NULL) + return false; + if (flen == 0 || flen > BPF_MAXINSNS) + return false; + + return true; +} + +/** + * bpf_check_classic - verify socket filter code + * @filter: filter to verify + * @flen: length of filter + * + * Check the user's filter code. If we let some ugly + * filter code slip through kaboom! The filter must contain + * no references or jumps that are out of range, no illegal + * instructions, and must end with a RET instruction. + * + * All jumps are forward as they are not signed. + * + * Returns 0 if the rule set is legal or -EINVAL if not. + */ +static int bpf_check_classic(const struct sock_filter *filter, + unsigned int flen) +{ + bool anc_found; + int pc; + + /* Check the filter code now */ + for (pc = 0; pc < flen; pc++) { + const struct sock_filter *ftest = &filter[pc]; + + /* May we actually operate on this code? */ + if (!chk_code_allowed(ftest->code)) + return -EINVAL; + + /* Some instructions need special checks */ + switch (ftest->code) { + case BPF_ALU | BPF_DIV | BPF_K: + case BPF_ALU | BPF_MOD | BPF_K: + /* Check for division by zero */ + if (ftest->k == 0) + return -EINVAL; + break; + case BPF_ALU | BPF_LSH | BPF_K: + case BPF_ALU | BPF_RSH | BPF_K: + if (ftest->k >= 32) + return -EINVAL; + break; + case BPF_LD | BPF_MEM: + case BPF_LDX | BPF_MEM: + case BPF_ST: + case BPF_STX: + /* Check for invalid memory addresses */ + if (ftest->k >= BPF_MEMWORDS) + return -EINVAL; + break; + case BPF_JMP | BPF_JA: + /* Note, the large ftest->k might cause loops. + * Compare this with conditional jumps below, + * where offsets are limited. 
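To make the memvalid bookkeeping of check_load_and_stores() earlier in this hunk more concrete, here is a deliberately simplified user-space sketch of the same idea. It only handles straight-line code; the kernel version additionally propagates the mask across every jump target. It assumes <linux/filter.h> for struct sock_filter and the opcode macros, and it is an illustration, not the kernel function:

#include <linux/filter.h>
#include <stdbool.h>

/* Returns false if any M[] scratch cell is read before it was written. */
static bool loads_follow_stores(const struct sock_filter *insns, int len)
{
	unsigned short valid = 0;	/* bit N set => scratch cell M[N] written */
	int pc;

	for (pc = 0; pc < len; pc++) {
		switch (insns[pc].code) {
		case BPF_ST:
		case BPF_STX:
			if (insns[pc].k < BPF_MEMWORDS)
				valid |= 1u << insns[pc].k;
			break;
		case BPF_LD | BPF_MEM:
		case BPF_LDX | BPF_MEM:
			if (insns[pc].k >= BPF_MEMWORDS ||
			    !(valid & (1u << insns[pc].k)))
				return false;	/* read before write */
			break;
		}
	}
	return true;
}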
--ANK (981016) + */ + if (ftest->k >= (unsigned int)(flen - pc - 1)) + return -EINVAL; + break; + case BPF_JMP | BPF_JEQ | BPF_K: + case BPF_JMP | BPF_JEQ | BPF_X: + case BPF_JMP | BPF_JGE | BPF_K: + case BPF_JMP | BPF_JGE | BPF_X: + case BPF_JMP | BPF_JGT | BPF_K: + case BPF_JMP | BPF_JGT | BPF_X: + case BPF_JMP | BPF_JSET | BPF_K: + case BPF_JMP | BPF_JSET | BPF_X: + /* Both conditionals must be safe */ + if (pc + ftest->jt + 1 >= flen || + pc + ftest->jf + 1 >= flen) + return -EINVAL; + break; + case BPF_LD | BPF_W | BPF_ABS: + case BPF_LD | BPF_H | BPF_ABS: + case BPF_LD | BPF_B | BPF_ABS: + anc_found = false; + if (bpf_anc_helper(ftest) & BPF_ANC) + anc_found = true; + /* Ancillary operation unknown or unsupported */ + if (anc_found == false && ftest->k >= SKF_AD_OFF) + return -EINVAL; + } + } + + /* Last instruction must be a RET code */ + switch (filter[flen - 1].code) { + case BPF_RET | BPF_K: + case BPF_RET | BPF_A: + return check_load_and_stores(filter, flen); + } + + return -EINVAL; +} + +static int bpf_prog_store_orig_filter(struct bpf_prog *fp, + const struct sock_fprog *fprog) +{ + unsigned int fsize = bpf_classic_proglen(fprog); + struct sock_fprog_kern *fkprog; + + fp->orig_prog = kmalloc(sizeof(*fkprog), GFP_KERNEL); + if (!fp->orig_prog) + return -ENOMEM; + + fkprog = fp->orig_prog; + fkprog->len = fprog->len; + + fkprog->filter = kmemdup(fp->insns, fsize, + GFP_KERNEL | __GFP_NOWARN); + if (!fkprog->filter) { + kfree(fp->orig_prog); + return -ENOMEM; + } + + return 0; +} + +static void bpf_release_orig_filter(struct bpf_prog *fp) +{ + struct sock_fprog_kern *fprog = fp->orig_prog; + + if (fprog) { + kfree(fprog->filter); + kfree(fprog); + } +} + +static void __bpf_prog_release(struct bpf_prog *prog) +{ + if (prog->type == BPF_PROG_TYPE_SOCKET_FILTER) { + bpf_prog_put(prog); + } else { + bpf_release_orig_filter(prog); + bpf_prog_free(prog); + } +} + +static void __sk_filter_release(struct sk_filter *fp) +{ + __bpf_prog_release(fp->prog); + kfree(fp); +} + +/** + * sk_filter_release_rcu - Release a socket filter by rcu_head + * @rcu: rcu_head that contains the sk_filter to free + */ +static void sk_filter_release_rcu(struct rcu_head *rcu) +{ + struct sk_filter *fp = container_of(rcu, struct sk_filter, rcu); + + __sk_filter_release(fp); +} + +/** + * sk_filter_release - release a socket filter + * @fp: filter to remove + * + * Remove a filter from a socket and release its resources. 
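For context, this release path is what eventually runs when user space detaches a filter; a small hedged example of the socket options involved (SO_DETACH_FILTER, plus SO_LOCK_FILTER, whose effect shows up further down as the SOCK_FILTER_LOCKED check in __get_filter()). Assumes a glibc recent enough to expose SO_LOCK_FILTER; illustrative only:

#include <stdio.h>
#include <sys/socket.h>

/* Detach any classic filter from 'fd', then lock the socket so that later
 * SO_ATTACH_FILTER/SO_DETACH_FILTER calls fail with EPERM.
 */
static int detach_and_lock(int fd)
{
	int one = 1;

	if (setsockopt(fd, SOL_SOCKET, SO_DETACH_FILTER, NULL, 0) < 0)
		perror("SO_DETACH_FILTER");	/* ENOENT if nothing attached */

	return setsockopt(fd, SOL_SOCKET, SO_LOCK_FILTER, &one, sizeof(one));
}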
+ */ +static void sk_filter_release(struct sk_filter *fp) +{ + if (refcount_dec_and_test(&fp->refcnt)) + call_rcu(&fp->rcu, sk_filter_release_rcu); +} + +void sk_filter_uncharge(struct sock *sk, struct sk_filter *fp) +{ + u32 filter_size = bpf_prog_size(fp->prog->len); + + atomic_sub(filter_size, &sk->sk_omem_alloc); + sk_filter_release(fp); +} + +/* try to charge the socket memory if there is space available + * return true on success + */ +static bool __sk_filter_charge(struct sock *sk, struct sk_filter *fp) +{ + u32 filter_size = bpf_prog_size(fp->prog->len); + int optmem_max = READ_ONCE(sysctl_optmem_max); + + /* same check as in sock_kmalloc() */ + if (filter_size <= optmem_max && + atomic_read(&sk->sk_omem_alloc) + filter_size < optmem_max) { + atomic_add(filter_size, &sk->sk_omem_alloc); + return true; + } + return false; +} + +bool sk_filter_charge(struct sock *sk, struct sk_filter *fp) +{ + if (!refcount_inc_not_zero(&fp->refcnt)) + return false; + + if (!__sk_filter_charge(sk, fp)) { + sk_filter_release(fp); + return false; + } + return true; +} + +static struct bpf_prog *bpf_migrate_filter(struct bpf_prog *fp) +{ + struct sock_filter *old_prog; + struct bpf_prog *old_fp; + int err, new_len, old_len = fp->len; + bool seen_ld_abs = false; + + /* We are free to overwrite insns et al right here as it won't be used at + * this point in time anymore internally after the migration to the eBPF + * instruction representation. + */ + BUILD_BUG_ON(sizeof(struct sock_filter) != + sizeof(struct bpf_insn)); + + /* Conversion cannot happen on overlapping memory areas, + * so we need to keep the user BPF around until the 2nd + * pass. At this time, the user BPF is stored in fp->insns. + */ + old_prog = kmemdup(fp->insns, old_len * sizeof(struct sock_filter), + GFP_KERNEL | __GFP_NOWARN); + if (!old_prog) { + err = -ENOMEM; + goto out_err; + } + + /* 1st pass: calculate the new program length. */ + err = bpf_convert_filter(old_prog, old_len, NULL, &new_len, + &seen_ld_abs); + if (err) + goto out_err_free; + + /* Expand fp for appending the new filter representation. */ + old_fp = fp; + fp = bpf_prog_realloc(old_fp, bpf_prog_size(new_len), 0); + if (!fp) { + /* The old_fp is still around in case we couldn't + * allocate new memory, so uncharge on that one. + */ + fp = old_fp; + err = -ENOMEM; + goto out_err_free; + } + + fp->len = new_len; + + /* 2nd pass: remap sock_filter insns into bpf_insn insns. */ + err = bpf_convert_filter(old_prog, old_len, fp, &new_len, + &seen_ld_abs); + if (err) + /* 2nd bpf_convert_filter() can fail only if it fails + * to allocate memory, remapping must succeed. Note, + * that at this time old_fp has already been released + * by krealloc(). + */ + goto out_err_free; + + fp = bpf_prog_select_runtime(fp, &err); + if (err) + goto out_err_free; + + kfree(old_prog); + return fp; + +out_err_free: + kfree(old_prog); +out_err: + __bpf_prog_release(fp); + return ERR_PTR(err); +} + +static struct bpf_prog *bpf_prepare_filter(struct bpf_prog *fp, + bpf_aux_classic_check_t trans) +{ + int err; + + fp->bpf_func = NULL; + fp->jited = 0; + + err = bpf_check_classic(fp->insns, fp->len); + if (err) { + __bpf_prog_release(fp); + return ERR_PTR(err); + } + + /* There might be additional checks and transformations + * needed on classic filters, f.e. in case of seccomp. + */ + if (trans) { + err = trans(fp->insns, fp->len); + if (err) { + __bpf_prog_release(fp); + return ERR_PTR(err); + } + } + + /* Probe if we can JIT compile the filter and if so, do + * the compilation of the filter. 
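Whether bpf_jit_compile() actually emits an image at this point depends on the architecture and on the net.core.bpf_jit_enable sysctl (pinned to 1 when CONFIG_BPF_JIT_ALWAYS_ON is set). A small user-space probe of that knob, assuming the usual procfs path and shown only as an illustration:

#include <stdio.h>

/* Returns net.core.bpf_jit_enable (0 = interpreter, 1 = JIT on,
 * 2 = JIT on with debug dump), or -1 on error.
 */
static int bpf_jit_enabled(void)
{
	FILE *f = fopen("/proc/sys/net/core/bpf_jit_enable", "r");
	int val = -1;

	if (f) {
		if (fscanf(f, "%d", &val) != 1)
			val = -1;
		fclose(f);
	}
	return val;
}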
+ */ + bpf_jit_compile(fp); + + /* JIT compiler couldn't process this filter, so do the eBPF translation + * for the optimized interpreter. + */ + if (!fp->jited) + fp = bpf_migrate_filter(fp); + + return fp; +} + +/** + * bpf_prog_create - create an unattached filter + * @pfp: the unattached filter that is created + * @fprog: the filter program + * + * Create a filter independent of any socket. We first run some + * sanity checks on it to make sure it does not explode on us later. + * If an error occurs or there is insufficient memory for the filter + * a negative errno code is returned. On success the return is zero. + */ +int bpf_prog_create(struct bpf_prog **pfp, struct sock_fprog_kern *fprog) +{ + unsigned int fsize = bpf_classic_proglen(fprog); + struct bpf_prog *fp; + + /* Make sure new filter is there and in the right amounts. */ + if (!bpf_check_basics_ok(fprog->filter, fprog->len)) + return -EINVAL; + + fp = bpf_prog_alloc(bpf_prog_size(fprog->len), 0); + if (!fp) + return -ENOMEM; + + memcpy(fp->insns, fprog->filter, fsize); + + fp->len = fprog->len; + /* Since unattached filters are not copied back to user + * space through sk_get_filter(), we do not need to hold + * a copy here, and can spare us the work. + */ + fp->orig_prog = NULL; + + /* bpf_prepare_filter() already takes care of freeing + * memory in case something goes wrong. + */ + fp = bpf_prepare_filter(fp, NULL); + if (IS_ERR(fp)) + return PTR_ERR(fp); + + *pfp = fp; + return 0; +} +EXPORT_SYMBOL_GPL(bpf_prog_create); + +/** + * bpf_prog_create_from_user - create an unattached filter from user buffer + * @pfp: the unattached filter that is created + * @fprog: the filter program + * @trans: post-classic verifier transformation handler + * @save_orig: save classic BPF program + * + * This function effectively does the same as bpf_prog_create(), only + * that it builds up its insns buffer from user space provided buffer. + * It also allows for passing a bpf_aux_classic_check_t handler. + */ +int bpf_prog_create_from_user(struct bpf_prog **pfp, struct sock_fprog *fprog, + bpf_aux_classic_check_t trans, bool save_orig) +{ + unsigned int fsize = bpf_classic_proglen(fprog); + struct bpf_prog *fp; + int err; + + /* Make sure new filter is there and in the right amounts. */ + if (!bpf_check_basics_ok(fprog->filter, fprog->len)) + return -EINVAL; + + fp = bpf_prog_alloc(bpf_prog_size(fprog->len), 0); + if (!fp) + return -ENOMEM; + + if (copy_from_user(fp->insns, fprog->filter, fsize)) { + __bpf_prog_free(fp); + return -EFAULT; + } + + fp->len = fprog->len; + fp->orig_prog = NULL; + + if (save_orig) { + err = bpf_prog_store_orig_filter(fp, fprog); + if (err) { + __bpf_prog_free(fp); + return -ENOMEM; + } + } + + /* bpf_prepare_filter() already takes care of freeing + * memory in case something goes wrong. 
+ */ + fp = bpf_prepare_filter(fp, trans); + if (IS_ERR(fp)) + return PTR_ERR(fp); + + *pfp = fp; + return 0; +} +EXPORT_SYMBOL_GPL(bpf_prog_create_from_user); + +void bpf_prog_destroy(struct bpf_prog *fp) +{ + __bpf_prog_release(fp); +} +EXPORT_SYMBOL_GPL(bpf_prog_destroy); + +static int __sk_attach_prog(struct bpf_prog *prog, struct sock *sk) +{ + struct sk_filter *fp, *old_fp; + + fp = kmalloc(sizeof(*fp), GFP_KERNEL); + if (!fp) + return -ENOMEM; + + fp->prog = prog; + + if (!__sk_filter_charge(sk, fp)) { + kfree(fp); + return -ENOMEM; + } + refcount_set(&fp->refcnt, 1); + + old_fp = rcu_dereference_protected(sk->sk_filter, + lockdep_sock_is_held(sk)); + rcu_assign_pointer(sk->sk_filter, fp); + + if (old_fp) + sk_filter_uncharge(sk, old_fp); + + return 0; +} + +static +struct bpf_prog *__get_filter(struct sock_fprog *fprog, struct sock *sk) +{ + unsigned int fsize = bpf_classic_proglen(fprog); + struct bpf_prog *prog; + int err; + + if (sock_flag(sk, SOCK_FILTER_LOCKED)) + return ERR_PTR(-EPERM); + + /* Make sure new filter is there and in the right amounts. */ + if (!bpf_check_basics_ok(fprog->filter, fprog->len)) + return ERR_PTR(-EINVAL); + + prog = bpf_prog_alloc(bpf_prog_size(fprog->len), 0); + if (!prog) + return ERR_PTR(-ENOMEM); + + if (copy_from_user(prog->insns, fprog->filter, fsize)) { + __bpf_prog_free(prog); + return ERR_PTR(-EFAULT); + } + + prog->len = fprog->len; + + err = bpf_prog_store_orig_filter(prog, fprog); + if (err) { + __bpf_prog_free(prog); + return ERR_PTR(-ENOMEM); + } + + /* bpf_prepare_filter() already takes care of freeing + * memory in case something goes wrong. + */ + return bpf_prepare_filter(prog, NULL); +} + +/** + * sk_attach_filter - attach a socket filter + * @fprog: the filter program + * @sk: the socket to use + * + * Attach the user's filter code. We first run some sanity checks on + * it to make sure it does not explode on us later. If an error + * occurs or there is insufficient memory for the filter a negative + * errno code is returned. On success the return is zero. 
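sk_attach_filter() below is the kernel side of setsockopt(SO_ATTACH_FILTER). A minimal user-space counterpart, purely illustrative (an ARP-only filter on a packet socket; requires CAP_NET_RAW, and the 0xffff return value means "accept up to 64k bytes"):

#include <stdio.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <linux/if_ether.h>
#include <linux/filter.h>

int main(void)
{
	struct sock_filter code[] = {
		BPF_STMT(BPF_LD | BPF_H | BPF_ABS, 12),               /* A = ethertype */
		BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, ETH_P_ARP, 0, 1), /* != ARP: drop  */
		BPF_STMT(BPF_RET | BPF_K, 0xffff),                    /* accept        */
		BPF_STMT(BPF_RET | BPF_K, 0),                         /* drop          */
	};
	struct sock_fprog prog = {
		.len	= sizeof(code) / sizeof(code[0]),
		.filter	= code,
	};
	int fd = socket(AF_PACKET, SOCK_RAW, htons(ETH_P_ALL));

	if (fd < 0 || setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER,
				 &prog, sizeof(prog)) < 0) {
		perror("attach");
		return 1;
	}
	/* recvfrom(fd, ...) now only ever sees ARP frames. */
	return 0;
}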
+ */ +int sk_attach_filter(struct sock_fprog *fprog, struct sock *sk) +{ + struct bpf_prog *prog = __get_filter(fprog, sk); + int err; + + if (IS_ERR(prog)) + return PTR_ERR(prog); + + err = __sk_attach_prog(prog, sk); + if (err < 0) { + __bpf_prog_release(prog); + return err; + } + + return 0; +} +EXPORT_SYMBOL_GPL(sk_attach_filter); + +int sk_reuseport_attach_filter(struct sock_fprog *fprog, struct sock *sk) +{ + struct bpf_prog *prog = __get_filter(fprog, sk); + int err; + + if (IS_ERR(prog)) + return PTR_ERR(prog); + + if (bpf_prog_size(prog->len) > READ_ONCE(sysctl_optmem_max)) + err = -ENOMEM; + else + err = reuseport_attach_prog(sk, prog); + + if (err) + __bpf_prog_release(prog); + + return err; +} + +static struct bpf_prog *__get_bpf(u32 ufd, struct sock *sk) +{ + if (sock_flag(sk, SOCK_FILTER_LOCKED)) + return ERR_PTR(-EPERM); + + return bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER); +} + +int sk_attach_bpf(u32 ufd, struct sock *sk) +{ + struct bpf_prog *prog = __get_bpf(ufd, sk); + int err; + + if (IS_ERR(prog)) + return PTR_ERR(prog); + + err = __sk_attach_prog(prog, sk); + if (err < 0) { + bpf_prog_put(prog); + return err; + } + + return 0; +} + +int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk) +{ + struct bpf_prog *prog; + int err; + + if (sock_flag(sk, SOCK_FILTER_LOCKED)) + return -EPERM; + + prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SOCKET_FILTER); + if (PTR_ERR(prog) == -EINVAL) + prog = bpf_prog_get_type(ufd, BPF_PROG_TYPE_SK_REUSEPORT); + if (IS_ERR(prog)) + return PTR_ERR(prog); + + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) { + /* Like other non BPF_PROG_TYPE_SOCKET_FILTER + * bpf prog (e.g. sockmap). It depends on the + * limitation imposed by bpf_prog_load(). + * Hence, sysctl_optmem_max is not checked. + */ + if ((sk->sk_type != SOCK_STREAM && + sk->sk_type != SOCK_DGRAM) || + (sk->sk_protocol != IPPROTO_UDP && + sk->sk_protocol != IPPROTO_TCP) || + (sk->sk_family != AF_INET && + sk->sk_family != AF_INET6)) { + err = -ENOTSUPP; + goto err_prog_put; + } + } else { + /* BPF_PROG_TYPE_SOCKET_FILTER */ + if (bpf_prog_size(prog->len) > READ_ONCE(sysctl_optmem_max)) { + err = -ENOMEM; + goto err_prog_put; + } + } + + err = reuseport_attach_prog(sk, prog); +err_prog_put: + if (err) + bpf_prog_put(prog); + + return err; +} + +void sk_reuseport_prog_free(struct bpf_prog *prog) +{ + if (!prog) + return; + + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) + bpf_prog_put(prog); + else + bpf_prog_destroy(prog); +} + +struct bpf_scratchpad { + union { + __be32 diff[MAX_BPF_STACK / sizeof(__be32)]; + u8 buff[MAX_BPF_STACK]; + }; +}; + +static DEFINE_PER_CPU(struct bpf_scratchpad, bpf_sp); + +static inline int __bpf_try_make_writable(struct sk_buff *skb, + unsigned int write_len) +{ + return skb_ensure_writable(skb, write_len); +} + +static inline int bpf_try_make_writable(struct sk_buff *skb, + unsigned int write_len) +{ + int err = __bpf_try_make_writable(skb, write_len); + + bpf_compute_data_pointers(skb); + return err; +} + +static int bpf_try_make_head_writable(struct sk_buff *skb) +{ + return bpf_try_make_writable(skb, skb_headlen(skb)); +} + +static inline void bpf_push_mac_rcsum(struct sk_buff *skb) +{ + if (skb_at_tc_ingress(skb)) + skb_postpush_rcsum(skb, skb_mac_header(skb), skb->mac_len); +} + +static inline void bpf_pull_mac_rcsum(struct sk_buff *skb) +{ + if (skb_at_tc_ingress(skb)) + skb_postpull_rcsum(skb, skb_mac_header(skb), skb->mac_len); +} + +BPF_CALL_5(bpf_skb_store_bytes, struct sk_buff *, skb, u32, offset, + const void *, from, u32, len, u64, 
flags) +{ + void *ptr; + + if (unlikely(flags & ~(BPF_F_RECOMPUTE_CSUM | BPF_F_INVALIDATE_HASH))) + return -EINVAL; + if (unlikely(offset > INT_MAX)) + return -EFAULT; + if (unlikely(bpf_try_make_writable(skb, offset + len))) + return -EFAULT; + + ptr = skb->data + offset; + if (flags & BPF_F_RECOMPUTE_CSUM) + __skb_postpull_rcsum(skb, ptr, len, offset); + + memcpy(ptr, from, len); + + if (flags & BPF_F_RECOMPUTE_CSUM) + __skb_postpush_rcsum(skb, ptr, len, offset); + if (flags & BPF_F_INVALIDATE_HASH) + skb_clear_hash(skb); + + return 0; +} + +static const struct bpf_func_proto bpf_skb_store_bytes_proto = { + .func = bpf_skb_store_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_skb_load_bytes, const struct sk_buff *, skb, u32, offset, + void *, to, u32, len) +{ + void *ptr; + + if (unlikely(offset > INT_MAX)) + goto err_clear; + + ptr = skb_header_pointer(skb, offset, len, to); + if (unlikely(!ptr)) + goto err_clear; + if (ptr != to) + memcpy(to, ptr, len); + + return 0; +err_clear: + memset(to, 0, len); + return -EFAULT; +} + +static const struct bpf_func_proto bpf_skb_load_bytes_proto = { + .func = bpf_skb_load_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, +}; + +BPF_CALL_4(bpf_flow_dissector_load_bytes, + const struct bpf_flow_dissector *, ctx, u32, offset, + void *, to, u32, len) +{ + void *ptr; + + if (unlikely(offset > 0xffff)) + goto err_clear; + + if (unlikely(!ctx->skb)) + goto err_clear; + + ptr = skb_header_pointer(ctx->skb, offset, len, to); + if (unlikely(!ptr)) + goto err_clear; + if (ptr != to) + memcpy(to, ptr, len); + + return 0; +err_clear: + memset(to, 0, len); + return -EFAULT; +} + +static const struct bpf_func_proto bpf_flow_dissector_load_bytes_proto = { + .func = bpf_flow_dissector_load_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_skb_load_bytes_relative, const struct sk_buff *, skb, + u32, offset, void *, to, u32, len, u32, start_header) +{ + u8 *end = skb_tail_pointer(skb); + u8 *start, *ptr; + + if (unlikely(offset > 0xffff)) + goto err_clear; + + switch (start_header) { + case BPF_HDR_START_MAC: + if (unlikely(!skb_mac_header_was_set(skb))) + goto err_clear; + start = skb_mac_header(skb); + break; + case BPF_HDR_START_NET: + start = skb_network_header(skb); + break; + default: + goto err_clear; + } + + ptr = start + offset; + + if (likely(ptr + len <= end)) { + memcpy(to, ptr, len); + return 0; + } + +err_clear: + memset(to, 0, len); + return -EFAULT; +} + +static const struct bpf_func_proto bpf_skb_load_bytes_relative_proto = { + .func = bpf_skb_load_bytes_relative, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_skb_pull_data, struct sk_buff *, skb, u32, len) +{ + /* Idea is the following: should the needed direct read/write + * test fail during runtime, we can pull in more data and redo + * again, since implicitly, we invalidate previous checks here. 
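The pattern this comment describes looks roughly as follows from the program side: check the direct-access window, and if it is too small, pull more data and re-derive the pointers, because the helper call invalidated them. A hedged tc-BPF sketch, assuming clang -target bpf and libbpf's bpf_helpers.h; the 34-byte value (Ethernet plus minimal IPv4 header) is an assumption for illustration:

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

SEC("tc")
int pull_before_parse(struct __sk_buff *skb)
{
	const int need = 34;	/* Ethernet + minimal IPv4 header */
	void *data = (void *)(long)skb->data;
	void *data_end = (void *)(long)skb->data_end;

	if (data + need > data_end) {
		/* Not linear yet: pull, then re-load the pointers, since
		 * the old ones were invalidated by the helper call.
		 */
		if (bpf_skb_pull_data(skb, need))
			return TC_ACT_SHOT;
		data = (void *)(long)skb->data;
		data_end = (void *)(long)skb->data_end;
		if (data + need > data_end)
			return TC_ACT_SHOT;
	}

	/* ... parse headers via 'data' here ... */
	return TC_ACT_OK;
}

char LICENSE[] SEC("license") = "GPL";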
+ * + * Or, since we know how much we need to make read/writeable, + * this can be done once at the program beginning for direct + * access case. By this we overcome limitations of only current + * headroom being accessible. + */ + return bpf_try_make_writable(skb, len ? : skb_headlen(skb)); +} + +static const struct bpf_func_proto bpf_skb_pull_data_proto = { + .func = bpf_skb_pull_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_sk_fullsock, struct sock *, sk) +{ + return sk_fullsock(sk) ? (unsigned long)sk : (unsigned long)NULL; +} + +static const struct bpf_func_proto bpf_sk_fullsock_proto = { + .func = bpf_sk_fullsock, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, +}; + +static inline int sk_skb_try_make_writable(struct sk_buff *skb, + unsigned int write_len) +{ + return __bpf_try_make_writable(skb, write_len); +} + +BPF_CALL_2(sk_skb_pull_data, struct sk_buff *, skb, u32, len) +{ + /* Idea is the following: should the needed direct read/write + * test fail during runtime, we can pull in more data and redo + * again, since implicitly, we invalidate previous checks here. + * + * Or, since we know how much we need to make read/writeable, + * this can be done once at the program beginning for direct + * access case. By this we overcome limitations of only current + * headroom being accessible. + */ + return sk_skb_try_make_writable(skb, len ? : skb_headlen(skb)); +} + +static const struct bpf_func_proto sk_skb_pull_data_proto = { + .func = sk_skb_pull_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_l3_csum_replace, struct sk_buff *, skb, u32, offset, + u64, from, u64, to, u64, flags) +{ + __sum16 *ptr; + + if (unlikely(flags & ~(BPF_F_HDR_FIELD_MASK))) + return -EINVAL; + if (unlikely(offset > 0xffff || offset & 1)) + return -EFAULT; + if (unlikely(bpf_try_make_writable(skb, offset + sizeof(*ptr)))) + return -EFAULT; + + ptr = (__sum16 *)(skb->data + offset); + switch (flags & BPF_F_HDR_FIELD_MASK) { + case 0: + if (unlikely(from != 0)) + return -EINVAL; + + csum_replace_by_diff(ptr, to); + break; + case 2: + csum_replace2(ptr, from, to); + break; + case 4: + csum_replace4(ptr, from, to); + break; + default: + return -EINVAL; + } + + return 0; +} + +static const struct bpf_func_proto bpf_l3_csum_replace_proto = { + .func = bpf_l3_csum_replace, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_l4_csum_replace, struct sk_buff *, skb, u32, offset, + u64, from, u64, to, u64, flags) +{ + bool is_pseudo = flags & BPF_F_PSEUDO_HDR; + bool is_mmzero = flags & BPF_F_MARK_MANGLED_0; + bool do_mforce = flags & BPF_F_MARK_ENFORCE; + __sum16 *ptr; + + if (unlikely(flags & ~(BPF_F_MARK_MANGLED_0 | BPF_F_MARK_ENFORCE | + BPF_F_PSEUDO_HDR | BPF_F_HDR_FIELD_MASK))) + return -EINVAL; + if (unlikely(offset > 0xffff || offset & 1)) + return -EFAULT; + if (unlikely(bpf_try_make_writable(skb, offset + sizeof(*ptr)))) + return -EFAULT; + + ptr = (__sum16 *)(skb->data + offset); + if (is_mmzero && !do_mforce && !*ptr) + return 0; + + switch (flags & BPF_F_HDR_FIELD_MASK) { + case 0: + if (unlikely(from != 0)) + return -EINVAL; + + inet_proto_csum_replace_by_diff(ptr, skb, to, is_pseudo); + break; + case 2: + 
inet_proto_csum_replace2(ptr, skb, from, to, is_pseudo); + break; + case 4: + inet_proto_csum_replace4(ptr, skb, from, to, is_pseudo); + break; + default: + return -EINVAL; + } + + if (is_mmzero && !*ptr) + *ptr = CSUM_MANGLED_0; + return 0; +} + +static const struct bpf_func_proto bpf_l4_csum_replace_proto = { + .func = bpf_l4_csum_replace, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_csum_diff, __be32 *, from, u32, from_size, + __be32 *, to, u32, to_size, __wsum, seed) +{ + struct bpf_scratchpad *sp = this_cpu_ptr(&bpf_sp); + u32 diff_size = from_size + to_size; + int i, j = 0; + + /* This is quite flexible, some examples: + * + * from_size == 0, to_size > 0, seed := csum --> pushing data + * from_size > 0, to_size == 0, seed := csum --> pulling data + * from_size > 0, to_size > 0, seed := 0 --> diffing data + * + * Even for diffing, from_size and to_size don't need to be equal. + */ + if (unlikely(((from_size | to_size) & (sizeof(__be32) - 1)) || + diff_size > sizeof(sp->diff))) + return -EINVAL; + + for (i = 0; i < from_size / sizeof(__be32); i++, j++) + sp->diff[j] = ~from[i]; + for (i = 0; i < to_size / sizeof(__be32); i++, j++) + sp->diff[j] = to[i]; + + return csum_partial(sp->diff, diff_size, seed); +} + +static const struct bpf_func_proto bpf_csum_diff_proto = { + .func = bpf_csum_diff, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_MEM | PTR_MAYBE_NULL | MEM_RDONLY, + .arg2_type = ARG_CONST_SIZE_OR_ZERO, + .arg3_type = ARG_PTR_TO_MEM | PTR_MAYBE_NULL | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE_OR_ZERO, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_csum_update, struct sk_buff *, skb, __wsum, csum) +{ + /* The interface is to be used in combination with bpf_csum_diff() + * for direct packet writes. csum rotation for alignment as well + * as emulating csum_sub() can be done from the eBPF program. + */ + if (skb->ip_summed == CHECKSUM_COMPLETE) + return (skb->csum = csum_add(skb->csum, csum)); + + return -ENOTSUPP; +} + +static const struct bpf_func_proto bpf_csum_update_proto = { + .func = bpf_csum_update, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_csum_level, struct sk_buff *, skb, u64, level) +{ + /* The interface is to be used in combination with bpf_skb_adjust_room() + * for encap/decap of packet headers when BPF_F_ADJ_ROOM_NO_CSUM_RESET + * is passed as flags, for example. + */ + switch (level) { + case BPF_CSUM_LEVEL_INC: + __skb_incr_checksum_unnecessary(skb); + break; + case BPF_CSUM_LEVEL_DEC: + __skb_decr_checksum_unnecessary(skb); + break; + case BPF_CSUM_LEVEL_RESET: + __skb_reset_checksum_unnecessary(skb); + break; + case BPF_CSUM_LEVEL_QUERY: + return skb->ip_summed == CHECKSUM_UNNECESSARY ? 
+ skb->csum_level : -EACCES; + default: + return -EINVAL; + } + + return 0; +} + +static const struct bpf_func_proto bpf_csum_level_proto = { + .func = bpf_csum_level, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +static inline int __bpf_rx_skb(struct net_device *dev, struct sk_buff *skb) +{ + return dev_forward_skb_nomtu(dev, skb); +} + +static inline int __bpf_rx_skb_no_mac(struct net_device *dev, + struct sk_buff *skb) +{ + int ret = ____dev_forward_skb(dev, skb, false); + + if (likely(!ret)) { + skb->dev = dev; + ret = netif_rx(skb); + } + + return ret; +} + +static inline int __bpf_tx_skb(struct net_device *dev, struct sk_buff *skb) +{ + int ret; + + if (dev_xmit_recursion()) { + net_crit_ratelimited("bpf: recursion limit reached on datapath, buggy bpf program?\n"); + kfree_skb(skb); + return -ENETDOWN; + } + + skb->dev = dev; + skb_clear_tstamp(skb); + + dev_xmit_recursion_inc(); + ret = dev_queue_xmit(skb); + dev_xmit_recursion_dec(); + + return ret; +} + +static int __bpf_redirect_no_mac(struct sk_buff *skb, struct net_device *dev, + u32 flags) +{ + unsigned int mlen = skb_network_offset(skb); + + if (unlikely(skb->len <= mlen)) { + kfree_skb(skb); + return -ERANGE; + } + + if (mlen) { + __skb_pull(skb, mlen); + if (unlikely(!skb->len)) { + kfree_skb(skb); + return -ERANGE; + } + + /* At ingress, the mac header has already been pulled once. + * At egress, skb_pospull_rcsum has to be done in case that + * the skb is originated from ingress (i.e. a forwarded skb) + * to ensure that rcsum starts at net header. + */ + if (!skb_at_tc_ingress(skb)) + skb_postpull_rcsum(skb, skb_mac_header(skb), mlen); + } + skb_pop_mac_header(skb); + skb_reset_mac_len(skb); + return flags & BPF_F_INGRESS ? + __bpf_rx_skb_no_mac(dev, skb) : __bpf_tx_skb(dev, skb); +} + +static int __bpf_redirect_common(struct sk_buff *skb, struct net_device *dev, + u32 flags) +{ + /* Verify that a link layer header is carried */ + if (unlikely(skb->mac_header >= skb->network_header || skb->len == 0)) { + kfree_skb(skb); + return -ERANGE; + } + + bpf_push_mac_rcsum(skb); + return flags & BPF_F_INGRESS ? 
+ __bpf_rx_skb(dev, skb) : __bpf_tx_skb(dev, skb); +} + +static int __bpf_redirect(struct sk_buff *skb, struct net_device *dev, + u32 flags) +{ + if (dev_is_mac_header_xmit(dev)) + return __bpf_redirect_common(skb, dev, flags); + else + return __bpf_redirect_no_mac(skb, dev, flags); +} + +#if IS_ENABLED(CONFIG_IPV6) +static int bpf_out_neigh_v6(struct net *net, struct sk_buff *skb, + struct net_device *dev, struct bpf_nh_params *nh) +{ + u32 hh_len = LL_RESERVED_SPACE(dev); + const struct in6_addr *nexthop; + struct dst_entry *dst = NULL; + struct neighbour *neigh; + + if (dev_xmit_recursion()) { + net_crit_ratelimited("bpf: recursion limit reached on datapath, buggy bpf program?\n"); + goto out_drop; + } + + skb->dev = dev; + skb_clear_tstamp(skb); + + if (unlikely(skb_headroom(skb) < hh_len && dev->header_ops)) { + skb = skb_expand_head(skb, hh_len); + if (!skb) + return -ENOMEM; + } + + rcu_read_lock(); + if (!nh) { + dst = skb_dst(skb); + nexthop = rt6_nexthop(container_of(dst, struct rt6_info, dst), + &ipv6_hdr(skb)->daddr); + } else { + nexthop = &nh->ipv6_nh; + } + neigh = ip_neigh_gw6(dev, nexthop); + if (likely(!IS_ERR(neigh))) { + int ret; + + sock_confirm_neigh(skb, neigh); + local_bh_disable(); + dev_xmit_recursion_inc(); + ret = neigh_output(neigh, skb, false); + dev_xmit_recursion_dec(); + local_bh_enable(); + rcu_read_unlock(); + return ret; + } + rcu_read_unlock_bh(); + if (dst) + IP6_INC_STATS(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTNOROUTES); +out_drop: + kfree_skb(skb); + return -ENETDOWN; +} + +static int __bpf_redirect_neigh_v6(struct sk_buff *skb, struct net_device *dev, + struct bpf_nh_params *nh) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + struct net *net = dev_net(dev); + int err, ret = NET_XMIT_DROP; + + if (!nh) { + struct dst_entry *dst; + struct flowi6 fl6 = { + .flowi6_flags = FLOWI_FLAG_ANYSRC, + .flowi6_mark = skb->mark, + .flowlabel = ip6_flowinfo(ip6h), + .flowi6_oif = dev->ifindex, + .flowi6_proto = ip6h->nexthdr, + .daddr = ip6h->daddr, + .saddr = ip6h->saddr, + }; + + dst = ipv6_stub->ipv6_dst_lookup_flow(net, NULL, &fl6, NULL); + if (IS_ERR(dst)) + goto out_drop; + + skb_dst_set(skb, dst); + } else if (nh->nh_family != AF_INET6) { + goto out_drop; + } + + err = bpf_out_neigh_v6(net, skb, dev, nh); + if (unlikely(net_xmit_eval(err))) + dev->stats.tx_errors++; + else + ret = NET_XMIT_SUCCESS; + goto out_xmit; +out_drop: + dev->stats.tx_errors++; + kfree_skb(skb); +out_xmit: + return ret; +} +#else +static int __bpf_redirect_neigh_v6(struct sk_buff *skb, struct net_device *dev, + struct bpf_nh_params *nh) +{ + kfree_skb(skb); + return NET_XMIT_DROP; +} +#endif /* CONFIG_IPV6 */ + +#if IS_ENABLED(CONFIG_INET) +static int bpf_out_neigh_v4(struct net *net, struct sk_buff *skb, + struct net_device *dev, struct bpf_nh_params *nh) +{ + u32 hh_len = LL_RESERVED_SPACE(dev); + struct neighbour *neigh; + bool is_v6gw = false; + + if (dev_xmit_recursion()) { + net_crit_ratelimited("bpf: recursion limit reached on datapath, buggy bpf program?\n"); + goto out_drop; + } + + skb->dev = dev; + skb_clear_tstamp(skb); + + if (unlikely(skb_headroom(skb) < hh_len && dev->header_ops)) { + skb = skb_expand_head(skb, hh_len); + if (!skb) + return -ENOMEM; + } + + rcu_read_lock(); + if (!nh) { + struct dst_entry *dst = skb_dst(skb); + struct rtable *rt = container_of(dst, struct rtable, dst); + + neigh = ip_neigh_for_gw(rt, skb, &is_v6gw); + } else if (nh->nh_family == AF_INET6) { + neigh = ip_neigh_gw6(dev, &nh->ipv6_nh); + is_v6gw = true; + } else if (nh->nh_family == 
AF_INET) { + neigh = ip_neigh_gw4(dev, nh->ipv4_nh); + } else { + rcu_read_unlock(); + goto out_drop; + } + + if (likely(!IS_ERR(neigh))) { + int ret; + + sock_confirm_neigh(skb, neigh); + local_bh_disable(); + dev_xmit_recursion_inc(); + ret = neigh_output(neigh, skb, is_v6gw); + dev_xmit_recursion_dec(); + local_bh_enable(); + rcu_read_unlock(); + return ret; + } + rcu_read_unlock(); +out_drop: + kfree_skb(skb); + return -ENETDOWN; +} + +static int __bpf_redirect_neigh_v4(struct sk_buff *skb, struct net_device *dev, + struct bpf_nh_params *nh) +{ + const struct iphdr *ip4h = ip_hdr(skb); + struct net *net = dev_net(dev); + int err, ret = NET_XMIT_DROP; + + if (!nh) { + struct flowi4 fl4 = { + .flowi4_flags = FLOWI_FLAG_ANYSRC, + .flowi4_mark = skb->mark, + .flowi4_tos = RT_TOS(ip4h->tos), + .flowi4_oif = dev->ifindex, + .flowi4_proto = ip4h->protocol, + .daddr = ip4h->daddr, + .saddr = ip4h->saddr, + }; + struct rtable *rt; + + rt = ip_route_output_flow(net, &fl4, NULL); + if (IS_ERR(rt)) + goto out_drop; + if (rt->rt_type != RTN_UNICAST && rt->rt_type != RTN_LOCAL) { + ip_rt_put(rt); + goto out_drop; + } + + skb_dst_set(skb, &rt->dst); + } + + err = bpf_out_neigh_v4(net, skb, dev, nh); + if (unlikely(net_xmit_eval(err))) + dev->stats.tx_errors++; + else + ret = NET_XMIT_SUCCESS; + goto out_xmit; +out_drop: + dev->stats.tx_errors++; + kfree_skb(skb); +out_xmit: + return ret; +} +#else +static int __bpf_redirect_neigh_v4(struct sk_buff *skb, struct net_device *dev, + struct bpf_nh_params *nh) +{ + kfree_skb(skb); + return NET_XMIT_DROP; +} +#endif /* CONFIG_INET */ + +static int __bpf_redirect_neigh(struct sk_buff *skb, struct net_device *dev, + struct bpf_nh_params *nh) +{ + struct ethhdr *ethh = eth_hdr(skb); + + if (unlikely(skb->mac_header >= skb->network_header)) + goto out; + bpf_push_mac_rcsum(skb); + if (is_multicast_ether_addr(ethh->h_dest)) + goto out; + + skb_pull(skb, sizeof(*ethh)); + skb_unset_mac_header(skb); + skb_reset_network_header(skb); + + if (skb->protocol == htons(ETH_P_IP)) + return __bpf_redirect_neigh_v4(skb, dev, nh); + else if (skb->protocol == htons(ETH_P_IPV6)) + return __bpf_redirect_neigh_v6(skb, dev, nh); +out: + kfree_skb(skb); + return -ENOTSUPP; +} + +/* Internal, non-exposed redirect flags. */ +enum { + BPF_F_NEIGH = (1ULL << 1), + BPF_F_PEER = (1ULL << 2), + BPF_F_NEXTHOP = (1ULL << 3), +#define BPF_F_REDIRECT_INTERNAL (BPF_F_NEIGH | BPF_F_PEER | BPF_F_NEXTHOP) +}; + +BPF_CALL_3(bpf_clone_redirect, struct sk_buff *, skb, u32, ifindex, u64, flags) +{ + struct net_device *dev; + struct sk_buff *clone; + int ret; + + if (unlikely(flags & (~(BPF_F_INGRESS) | BPF_F_REDIRECT_INTERNAL))) + return -EINVAL; + + dev = dev_get_by_index_rcu(dev_net(skb->dev), ifindex); + if (unlikely(!dev)) + return -EINVAL; + + clone = skb_clone(skb, GFP_ATOMIC); + if (unlikely(!clone)) + return -ENOMEM; + + /* For direct write, we need to keep the invariant that the skbs + * we're dealing with need to be uncloned. Should uncloning fail + * here, we need to free the just generated clone to unclone once + * again. 
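As a usage-level aside, the two redirect flavours differ as follows: bpf_clone_redirect() acts immediately on a clone from inside the helper, while bpf_redirect() only records the target and relies on the caller acting on TC_ACT_REDIRECT, which skb_do_redirect() further down then executes. A hedged tc-BPF sketch, where IFINDEX is an assumption standing in for a real device index:

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

#define IFINDEX 2	/* illustrative target device */

SEC("tc")
int mirror_and_redirect(struct __sk_buff *skb)
{
	/* Immediate: send a copy to IFINDEX's ingress path right away. */
	bpf_clone_redirect(skb, IFINDEX, BPF_F_INGRESS);

	/* Deferred: remember the target; the stack redirects the original
	 * skb once we return TC_ACT_REDIRECT.
	 */
	return bpf_redirect(IFINDEX, 0);
}

char LICENSE[] SEC("license") = "GPL";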
+ */ + ret = bpf_try_make_head_writable(skb); + if (unlikely(ret)) { + kfree_skb(clone); + return -ENOMEM; + } + + return __bpf_redirect(clone, dev, flags); +} + +static const struct bpf_func_proto bpf_clone_redirect_proto = { + .func = bpf_clone_redirect, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +DEFINE_PER_CPU(struct bpf_redirect_info, bpf_redirect_info); +EXPORT_PER_CPU_SYMBOL_GPL(bpf_redirect_info); + +int skb_do_redirect(struct sk_buff *skb) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + struct net *net = dev_net(skb->dev); + struct net_device *dev; + u32 flags = ri->flags; + + dev = dev_get_by_index_rcu(net, ri->tgt_index); + ri->tgt_index = 0; + ri->flags = 0; + if (unlikely(!dev)) + goto out_drop; + if (flags & BPF_F_PEER) { + const struct net_device_ops *ops = dev->netdev_ops; + + if (unlikely(!ops->ndo_get_peer_dev || + !skb_at_tc_ingress(skb))) + goto out_drop; + dev = ops->ndo_get_peer_dev(dev); + if (unlikely(!dev || + !(dev->flags & IFF_UP) || + net_eq(net, dev_net(dev)))) + goto out_drop; + skb->dev = dev; + return -EAGAIN; + } + return flags & BPF_F_NEIGH ? + __bpf_redirect_neigh(skb, dev, flags & BPF_F_NEXTHOP ? + &ri->nh : NULL) : + __bpf_redirect(skb, dev, flags); +out_drop: + kfree_skb(skb); + return -EINVAL; +} + +BPF_CALL_2(bpf_redirect, u32, ifindex, u64, flags) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + + if (unlikely(flags & (~(BPF_F_INGRESS) | BPF_F_REDIRECT_INTERNAL))) + return TC_ACT_SHOT; + + ri->flags = flags; + ri->tgt_index = ifindex; + + return TC_ACT_REDIRECT; +} + +static const struct bpf_func_proto bpf_redirect_proto = { + .func = bpf_redirect, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_ANYTHING, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_redirect_peer, u32, ifindex, u64, flags) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + + if (unlikely(flags)) + return TC_ACT_SHOT; + + ri->flags = BPF_F_PEER; + ri->tgt_index = ifindex; + + return TC_ACT_REDIRECT; +} + +static const struct bpf_func_proto bpf_redirect_peer_proto = { + .func = bpf_redirect_peer, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_ANYTHING, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_redirect_neigh, u32, ifindex, struct bpf_redir_neigh *, params, + int, plen, u64, flags) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + + if (unlikely((plen && plen < sizeof(*params)) || flags)) + return TC_ACT_SHOT; + + ri->flags = BPF_F_NEIGH | (plen ? 
BPF_F_NEXTHOP : 0); + ri->tgt_index = ifindex; + + BUILD_BUG_ON(sizeof(struct bpf_redir_neigh) != sizeof(struct bpf_nh_params)); + if (plen) + memcpy(&ri->nh, params, sizeof(ri->nh)); + + return TC_ACT_REDIRECT; +} + +static const struct bpf_func_proto bpf_redirect_neigh_proto = { + .func = bpf_redirect_neigh, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_ANYTHING, + .arg2_type = ARG_PTR_TO_MEM | PTR_MAYBE_NULL | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE_OR_ZERO, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_msg_apply_bytes, struct sk_msg *, msg, u32, bytes) +{ + msg->apply_bytes = bytes; + return 0; +} + +static const struct bpf_func_proto bpf_msg_apply_bytes_proto = { + .func = bpf_msg_apply_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg *, msg, u32, bytes) +{ + msg->cork_bytes = bytes; + return 0; +} + +static void sk_msg_reset_curr(struct sk_msg *msg) +{ + u32 i = msg->sg.start; + u32 len = 0; + + do { + len += sk_msg_elem(msg, i)->length; + sk_msg_iter_var_next(i); + if (len >= msg->sg.size) + break; + } while (i != msg->sg.end); + + msg->sg.curr = i; + msg->sg.copybreak = 0; +} + +static const struct bpf_func_proto bpf_msg_cork_bytes_proto = { + .func = bpf_msg_cork_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start, + u32, end, u64, flags) +{ + u32 len = 0, offset = 0, copy = 0, poffset = 0, bytes = end - start; + u32 first_sge, last_sge, i, shift, bytes_sg_total; + struct scatterlist *sge; + u8 *raw, *to, *from; + struct page *page; + + if (unlikely(flags || end <= start)) + return -EINVAL; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + offset += len; + len = sk_msg_elem(msg, i)->length; + if (start < offset + len) + break; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + if (unlikely(start >= offset + len)) + return -EINVAL; + + first_sge = i; + /* The start may point into the sg element so we need to also + * account for the headroom. + */ + bytes_sg_total = start - offset + bytes; + if (!test_bit(i, msg->sg.copy) && bytes_sg_total <= len) + goto out; + + /* At this point we need to linearize multiple scatterlist + * elements or a single shared page. Either way we need to + * copy into a linear buffer exclusively owned by BPF. Then + * place the buffer in the scatterlist and fixup the original + * entries by removing the entries now in the linear buffer + * and shifting the remaining entries. For now we do not try + * to copy partial entries to avoid complexity of running out + * of sg_entry slots. The downside is reading a single byte + * will copy the entire sg entry. 
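Because of the copy behaviour just described (touching one byte may linearize a whole sg entry), programs usually pull exactly the header bytes they are about to parse. A minimal SK_MSG sketch, assuming libbpf; the 8-byte header size is an assumption for illustration:

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

SEC("sk_msg")
int parse_first_bytes(struct sk_msg_md *msg)
{
	void *data, *data_end;

	/* Make bytes [0, 8) contiguous and BPF-owned before reading them. */
	if (bpf_msg_pull_data(msg, 0, 8, 0))
		return SK_PASS;

	data = (void *)(long)msg->data;
	data_end = (void *)(long)msg->data_end;
	if (data + 8 > data_end)
		return SK_PASS;

	/* ... parse the 8-byte application header at 'data' here ... */
	return SK_PASS;
}

char LICENSE[] SEC("license") = "GPL";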
+ */ + do { + copy += sk_msg_elem(msg, i)->length; + sk_msg_iter_var_next(i); + if (bytes_sg_total <= copy) + break; + } while (i != msg->sg.end); + last_sge = i; + + if (unlikely(bytes_sg_total > copy)) + return -EINVAL; + + page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP, + get_order(copy)); + if (unlikely(!page)) + return -ENOMEM; + + raw = page_address(page); + i = first_sge; + do { + sge = sk_msg_elem(msg, i); + from = sg_virt(sge); + len = sge->length; + to = raw + poffset; + + memcpy(to, from, len); + poffset += len; + sge->length = 0; + put_page(sg_page(sge)); + + sk_msg_iter_var_next(i); + } while (i != last_sge); + + sg_set_page(&msg->sg.data[first_sge], page, copy, 0); + + /* To repair sg ring we need to shift entries. If we only + * had a single entry though we can just replace it and + * be done. Otherwise walk the ring and shift the entries. + */ + WARN_ON_ONCE(last_sge == first_sge); + shift = last_sge > first_sge ? + last_sge - first_sge - 1 : + NR_MSG_FRAG_IDS - first_sge + last_sge - 1; + if (!shift) + goto out; + + i = first_sge; + sk_msg_iter_var_next(i); + do { + u32 move_from; + + if (i + shift >= NR_MSG_FRAG_IDS) + move_from = i + shift - NR_MSG_FRAG_IDS; + else + move_from = i + shift; + if (move_from == msg->sg.end) + break; + + msg->sg.data[i] = msg->sg.data[move_from]; + msg->sg.data[move_from].length = 0; + msg->sg.data[move_from].page_link = 0; + msg->sg.data[move_from].offset = 0; + sk_msg_iter_var_next(i); + } while (1); + + msg->sg.end = msg->sg.end - shift > msg->sg.end ? + msg->sg.end - shift + NR_MSG_FRAG_IDS : + msg->sg.end - shift; +out: + sk_msg_reset_curr(msg); + msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset; + msg->data_end = msg->data + bytes; + return 0; +} + +static const struct bpf_func_proto bpf_msg_pull_data_proto = { + .func = bpf_msg_pull_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start, + u32, len, u64, flags) +{ + struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge; + u32 new, i = 0, l = 0, space, copy = 0, offset = 0; + u8 *raw, *to, *from; + struct page *page; + + if (unlikely(flags)) + return -EINVAL; + + if (unlikely(len == 0)) + return 0; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + offset += l; + l = sk_msg_elem(msg, i)->length; + + if (start < offset + l) + break; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + if (start >= offset + l) + return -EINVAL; + + space = MAX_MSG_FRAGS - sk_msg_elem_used(msg); + + /* If no space available will fallback to copy, we need at + * least one scatterlist elem available to push data into + * when start aligns to the beginning of an element or two + * when it falls inside an element. We handle the start equals + * offset case because its the common case for inserting a + * header. 
+ */ + if (!space || (space == 1 && start != offset)) + copy = msg->sg.data[i].length; + + page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP, + get_order(copy + len)); + if (unlikely(!page)) + return -ENOMEM; + + if (copy) { + int front, back; + + raw = page_address(page); + + psge = sk_msg_elem(msg, i); + front = start - offset; + back = psge->length - front; + from = sg_virt(psge); + + if (front) + memcpy(raw, from, front); + + if (back) { + from += front; + to = raw + front + len; + + memcpy(to, from, back); + } + + put_page(sg_page(psge)); + } else if (start - offset) { + psge = sk_msg_elem(msg, i); + rsge = sk_msg_elem_cpy(msg, i); + + psge->length = start - offset; + rsge.length -= psge->length; + rsge.offset += start; + + sk_msg_iter_var_next(i); + sg_unmark_end(psge); + sg_unmark_end(&rsge); + sk_msg_iter_next(msg, end); + } + + /* Slot(s) to place newly allocated data */ + new = i; + + /* Shift one or two slots as needed */ + if (!copy) { + sge = sk_msg_elem_cpy(msg, i); + + sk_msg_iter_var_next(i); + sg_unmark_end(&sge); + sk_msg_iter_next(msg, end); + + nsge = sk_msg_elem_cpy(msg, i); + if (rsge.length) { + sk_msg_iter_var_next(i); + nnsge = sk_msg_elem_cpy(msg, i); + } + + while (i != msg->sg.end) { + msg->sg.data[i] = sge; + sge = nsge; + sk_msg_iter_var_next(i); + if (rsge.length) { + nsge = nnsge; + nnsge = sk_msg_elem_cpy(msg, i); + } else { + nsge = sk_msg_elem_cpy(msg, i); + } + } + } + + /* Place newly allocated data buffer */ + sk_mem_charge(msg->sk, len); + msg->sg.size += len; + __clear_bit(new, msg->sg.copy); + sg_set_page(&msg->sg.data[new], page, len + copy, 0); + if (rsge.length) { + get_page(sg_page(&rsge)); + sk_msg_iter_var_next(new); + msg->sg.data[new] = rsge; + } + + sk_msg_reset_curr(msg); + sk_msg_compute_data_pointers(msg); + return 0; +} + +static const struct bpf_func_proto bpf_msg_push_data_proto = { + .func = bpf_msg_push_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +static void sk_msg_shift_left(struct sk_msg *msg, int i) +{ + int prev; + + do { + prev = i; + sk_msg_iter_var_next(i); + msg->sg.data[prev] = msg->sg.data[i]; + } while (i != msg->sg.end); + + sk_msg_iter_prev(msg, end); +} + +static void sk_msg_shift_right(struct sk_msg *msg, int i) +{ + struct scatterlist tmp, sge; + + sk_msg_iter_next(msg, end); + sge = sk_msg_elem_cpy(msg, i); + sk_msg_iter_var_next(i); + tmp = sk_msg_elem_cpy(msg, i); + + while (i != msg->sg.end) { + msg->sg.data[i] = sge; + sk_msg_iter_var_next(i); + sge = tmp; + tmp = sk_msg_elem_cpy(msg, i); + } +} + +BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start, + u32, len, u64, flags) +{ + u32 i = 0, l = 0, space, offset = 0; + u64 last = start + len; + int pop; + + if (unlikely(flags)) + return -EINVAL; + + /* First find the starting scatterlist element */ + i = msg->sg.start; + do { + offset += l; + l = sk_msg_elem(msg, i)->length; + + if (start < offset + l) + break; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); + + /* Bounds checks: start and pop must be inside message */ + if (start >= offset + l || last >= msg->sg.size) + return -EINVAL; + + space = MAX_MSG_FRAGS - sk_msg_elem_used(msg); + + pop = len; + /* --------------| offset + * -| start |-------- len -------| + * + * |----- a ----|-------- pop -------|----- b ----| + * |______________________________________________| length + * + * + * a: region at front of scatter element to save + * b: region at 
back of scatter element to save when length > A + pop + * pop: region to pop from element, same as input 'pop' here will be + * decremented below per iteration. + * + * Two top-level cases to handle when start != offset, first B is non + * zero and second B is zero corresponding to when a pop includes more + * than one element. + * + * Then if B is non-zero AND there is no space allocate space and + * compact A, B regions into page. If there is space shift ring to + * the rigth free'ing the next element in ring to place B, leaving + * A untouched except to reduce length. + */ + if (start != offset) { + struct scatterlist *nsge, *sge = sk_msg_elem(msg, i); + int a = start; + int b = sge->length - pop - a; + + sk_msg_iter_var_next(i); + + if (pop < sge->length - a) { + if (space) { + sge->length = a; + sk_msg_shift_right(msg, i); + nsge = sk_msg_elem(msg, i); + get_page(sg_page(sge)); + sg_set_page(nsge, + sg_page(sge), + b, sge->offset + pop + a); + } else { + struct page *page, *orig; + u8 *to, *from; + + page = alloc_pages(__GFP_NOWARN | + __GFP_COMP | GFP_ATOMIC, + get_order(a + b)); + if (unlikely(!page)) + return -ENOMEM; + + sge->length = a; + orig = sg_page(sge); + from = sg_virt(sge); + to = page_address(page); + memcpy(to, from, a); + memcpy(to + a, from + a + pop, b); + sg_set_page(sge, page, a + b, 0); + put_page(orig); + } + pop = 0; + } else if (pop >= sge->length - a) { + pop -= (sge->length - a); + sge->length = a; + } + } + + /* From above the current layout _must_ be as follows, + * + * -| offset + * -| start + * + * |---- pop ---|---------------- b ------------| + * |____________________________________________| length + * + * Offset and start of the current msg elem are equal because in the + * previous case we handled offset != start and either consumed the + * entire element and advanced to the next element OR pop == 0. + * + * Two cases to handle here are first pop is less than the length + * leaving some remainder b above. Simply adjust the element's layout + * in this case. Or pop >= length of the element so that b = 0. In this + * case advance to next element decrementing pop. 
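Together with bpf_msg_push_data() earlier in this hunk, these helpers let a program insert and strip an application-layer header without copying the whole message. A short hedged sketch; pushing and popping the same bytes in one program is artificial and only meant to show both calls, and the 4-byte tag is an assumption:

#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

#define TAG_LEN 4	/* illustrative header size */

SEC("sk_msg")
int add_then_strip_tag(struct sk_msg_md *msg)
{
	/* Make room for TAG_LEN bytes at the start of the message
	 * (start == element offset, the cheap case described above).
	 */
	if (bpf_msg_push_data(msg, 0, TAG_LEN, 0))
		return SK_DROP;

	/* ... a real sender would now pull and fill the tag bytes ... */

	/* And this is how a receiver would strip the same bytes again. */
	if (bpf_msg_pop_data(msg, 0, TAG_LEN, 0))
		return SK_DROP;

	return SK_PASS;
}

char LICENSE[] SEC("license") = "GPL";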
+ */ + while (pop) { + struct scatterlist *sge = sk_msg_elem(msg, i); + + if (pop < sge->length) { + sge->length -= pop; + sge->offset += pop; + pop = 0; + } else { + pop -= sge->length; + sk_msg_shift_left(msg, i); + } + sk_msg_iter_var_next(i); + } + + sk_mem_uncharge(msg->sk, len - pop); + msg->sg.size -= (len - pop); + sk_msg_reset_curr(msg); + sk_msg_compute_data_pointers(msg); + return 0; +} + +static const struct bpf_func_proto bpf_msg_pop_data_proto = { + .func = bpf_msg_pop_data, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +#ifdef CONFIG_CGROUP_NET_CLASSID +BPF_CALL_0(bpf_get_cgroup_classid_curr) +{ + return __task_get_classid(current); +} + +const struct bpf_func_proto bpf_get_cgroup_classid_curr_proto = { + .func = bpf_get_cgroup_classid_curr, + .gpl_only = false, + .ret_type = RET_INTEGER, +}; + +BPF_CALL_1(bpf_skb_cgroup_classid, const struct sk_buff *, skb) +{ + struct sock *sk = skb_to_full_sk(skb); + + if (!sk || !sk_fullsock(sk)) + return 0; + + return sock_cgroup_classid(&sk->sk_cgrp_data); +} + +static const struct bpf_func_proto bpf_skb_cgroup_classid_proto = { + .func = bpf_skb_cgroup_classid, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; +#endif + +BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb) +{ + return task_get_classid(skb); +} + +static const struct bpf_func_proto bpf_get_cgroup_classid_proto = { + .func = bpf_get_cgroup_classid, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_get_route_realm, const struct sk_buff *, skb) +{ + return dst_tclassid(skb); +} + +static const struct bpf_func_proto bpf_get_route_realm_proto = { + .func = bpf_get_route_realm, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_get_hash_recalc, struct sk_buff *, skb) +{ + /* If skb_clear_hash() was called due to mangling, we can + * trigger SW recalculation here. Later access to hash + * can then use the inline skb->hash via context directly + * instead of calling this helper again. + */ + return skb_get_hash(skb); +} + +static const struct bpf_func_proto bpf_get_hash_recalc_proto = { + .func = bpf_get_hash_recalc, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_set_hash_invalid, struct sk_buff *, skb) +{ + /* After all direct packet write, this can be used once for + * triggering a lazy recalc on next skb_get_hash() invocation. + */ + skb_clear_hash(skb); + return 0; +} + +static const struct bpf_func_proto bpf_set_hash_invalid_proto = { + .func = bpf_set_hash_invalid, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_2(bpf_set_hash, struct sk_buff *, skb, u32, hash) +{ + /* Set user specified hash as L4(+), so that it gets returned + * on skb_get_hash() call unless BPF prog later on triggers a + * skb_clear_hash(). 
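The hash helpers in this hunk (bpf_get_hash_recalc(), bpf_set_hash_invalid(), bpf_set_hash()) are typically combined after a program has mangled packet bytes. A brief tc-BPF sketch, assuming libbpf; the calls are shown back to back only to illustrate each helper:

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>

SEC("tc")
int rehash_after_rewrite(struct __sk_buff *skb)
{
	__u32 hash;

	/* ... assume headers were rewritten with bpf_skb_store_bytes() ... */

	/* The old hash may be stale: drop it and force a fresh SW hash. */
	bpf_set_hash_invalid(skb);
	hash = bpf_get_hash_recalc(skb);

	/* Or install a hash the program computed itself as the L4 hash. */
	bpf_set_hash(skb, hash);

	return TC_ACT_OK;
}

char LICENSE[] SEC("license") = "GPL";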
+ */ + __skb_set_sw_hash(skb, hash, true); + return 0; +} + +static const struct bpf_func_proto bpf_set_hash_proto = { + .func = bpf_set_hash, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_skb_vlan_push, struct sk_buff *, skb, __be16, vlan_proto, + u16, vlan_tci) +{ + int ret; + + if (unlikely(vlan_proto != htons(ETH_P_8021Q) && + vlan_proto != htons(ETH_P_8021AD))) + vlan_proto = htons(ETH_P_8021Q); + + bpf_push_mac_rcsum(skb); + ret = skb_vlan_push(skb, vlan_proto, vlan_tci); + bpf_pull_mac_rcsum(skb); + + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_vlan_push_proto = { + .func = bpf_skb_vlan_push, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_skb_vlan_pop, struct sk_buff *, skb) +{ + int ret; + + bpf_push_mac_rcsum(skb); + ret = skb_vlan_pop(skb); + bpf_pull_mac_rcsum(skb); + + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_vlan_pop_proto = { + .func = bpf_skb_vlan_pop, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +static int bpf_skb_generic_push(struct sk_buff *skb, u32 off, u32 len) +{ + /* Caller already did skb_cow() with len as headroom, + * so no need to do it here. + */ + skb_push(skb, len); + memmove(skb->data, skb->data + len, off); + memset(skb->data + off, 0, len); + + /* No skb_postpush_rcsum(skb, skb->data + off, len) + * needed here as it does not change the skb->csum + * result for checksum complete when summing over + * zeroed blocks. + */ + return 0; +} + +static int bpf_skb_generic_pop(struct sk_buff *skb, u32 off, u32 len) +{ + void *old_data; + + /* skb_ensure_writable() is not needed here, as we're + * already working on an uncloned skb. + */ + if (unlikely(!pskb_may_pull(skb, off + len))) + return -ENOMEM; + + old_data = skb->data; + __skb_pull(skb, len); + skb_postpull_rcsum(skb, old_data + off, len); + memmove(skb->data, old_data, off); + + return 0; +} + +static int bpf_skb_net_hdr_push(struct sk_buff *skb, u32 off, u32 len) +{ + bool trans_same = skb->transport_header == skb->network_header; + int ret; + + /* There's no need for __skb_push()/__skb_pull() pair to + * get to the start of the mac header as we're guaranteed + * to always start from here under eBPF. + */ + ret = bpf_skb_generic_push(skb, off, len); + if (likely(!ret)) { + skb->mac_header -= len; + skb->network_header -= len; + if (trans_same) + skb->transport_header = skb->network_header; + } + + return ret; +} + +static int bpf_skb_net_hdr_pop(struct sk_buff *skb, u32 off, u32 len) +{ + bool trans_same = skb->transport_header == skb->network_header; + int ret; + + /* Same here, __skb_push()/__skb_pull() pair not needed. 
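Like the other rewriting helpers, bpf_skb_vlan_push()/bpf_skb_vlan_pop() recompute the data pointers, so any packet pointers held across the call must be re-derived. A small tc-BPF sketch, assuming libbpf; VLAN ID 100 is an assumption:

#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("tc")
int tag_with_vlan_100(struct __sk_buff *skb)
{
	/* Push an 802.1Q tag; any direct packet pointers are stale after
	 * this because the helper recomputes the data pointers.
	 */
	if (bpf_skb_vlan_push(skb, bpf_htons(ETH_P_8021Q), 100))
		return TC_ACT_SHOT;

	return TC_ACT_OK;
}

char LICENSE[] SEC("license") = "GPL";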
*/ + ret = bpf_skb_generic_pop(skb, off, len); + if (likely(!ret)) { + skb->mac_header += len; + skb->network_header += len; + if (trans_same) + skb->transport_header = skb->network_header; + } + + return ret; +} + +static int bpf_skb_proto_4_to_6(struct sk_buff *skb) +{ + const u32 len_diff = sizeof(struct ipv6hdr) - sizeof(struct iphdr); + u32 off = skb_mac_header_len(skb); + int ret; + + ret = skb_cow(skb, len_diff); + if (unlikely(ret < 0)) + return ret; + + ret = bpf_skb_net_hdr_push(skb, off, len_diff); + if (unlikely(ret < 0)) + return ret; + + if (skb_is_gso(skb)) { + struct skb_shared_info *shinfo = skb_shinfo(skb); + + /* SKB_GSO_TCPV4 needs to be changed into SKB_GSO_TCPV6. */ + if (shinfo->gso_type & SKB_GSO_TCPV4) { + shinfo->gso_type &= ~SKB_GSO_TCPV4; + shinfo->gso_type |= SKB_GSO_TCPV6; + } + } + + skb->protocol = htons(ETH_P_IPV6); + skb_clear_hash(skb); + + return 0; +} + +static int bpf_skb_proto_6_to_4(struct sk_buff *skb) +{ + const u32 len_diff = sizeof(struct ipv6hdr) - sizeof(struct iphdr); + u32 off = skb_mac_header_len(skb); + int ret; + + ret = skb_unclone(skb, GFP_ATOMIC); + if (unlikely(ret < 0)) + return ret; + + ret = bpf_skb_net_hdr_pop(skb, off, len_diff); + if (unlikely(ret < 0)) + return ret; + + if (skb_is_gso(skb)) { + struct skb_shared_info *shinfo = skb_shinfo(skb); + + /* SKB_GSO_TCPV6 needs to be changed into SKB_GSO_TCPV4. */ + if (shinfo->gso_type & SKB_GSO_TCPV6) { + shinfo->gso_type &= ~SKB_GSO_TCPV6; + shinfo->gso_type |= SKB_GSO_TCPV4; + } + } + + skb->protocol = htons(ETH_P_IP); + skb_clear_hash(skb); + + return 0; +} + +static int bpf_skb_proto_xlat(struct sk_buff *skb, __be16 to_proto) +{ + __be16 from_proto = skb->protocol; + + if (from_proto == htons(ETH_P_IP) && + to_proto == htons(ETH_P_IPV6)) + return bpf_skb_proto_4_to_6(skb); + + if (from_proto == htons(ETH_P_IPV6) && + to_proto == htons(ETH_P_IP)) + return bpf_skb_proto_6_to_4(skb); + + return -ENOTSUPP; +} + +BPF_CALL_3(bpf_skb_change_proto, struct sk_buff *, skb, __be16, proto, + u64, flags) +{ + int ret; + + if (unlikely(flags)) + return -EINVAL; + + /* General idea is that this helper does the basic groundwork + * needed for changing the protocol, and eBPF program fills the + * rest through bpf_skb_store_bytes(), bpf_lX_csum_replace() + * and other helpers, rather than passing a raw buffer here. + * + * The rationale is to keep this minimal and without a need to + * deal with raw packet data. F.e. even if we would pass buffers + * here, the program still needs to call the bpf_lX_csum_replace() + * helpers anyway. Plus, this way we keep also separation of + * concerns, since f.e. bpf_skb_store_bytes() should only take + * care of stores. + * + * Currently, additional options and extension header space are + * not supported, but flags register is reserved so we can adapt + * that. For offloads, we mark packet as dodgy, so that headers + * need to be verified first. + */ + ret = bpf_skb_proto_xlat(skb, proto); + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_change_proto_proto = { + .func = bpf_skb_change_proto, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_skb_change_type, struct sk_buff *, skb, u32, pkt_type) +{ + /* We only allow a restricted subset to be changed for now. 
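A hedged sketch of how the two helpers around this point are used from a program: bpf_skb_change_proto() only does the resizing and retagging groundwork, so the program must rewrite the IP header itself afterwards (elided here), and bpf_skb_change_type() flips skb->pkt_type within the allowed subset. Assumes libbpf; illustrative only:

#include <linux/bpf.h>
#include <linux/if_ether.h>
#include <linux/if_packet.h>
#include <linux/pkt_cls.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

SEC("tc")
int retag_packet(struct __sk_buff *skb)
{
	if (skb->protocol == bpf_htons(ETH_P_IP)) {
		/* Groundwork only: addresses, payload fix-ups and checksums
		 * still have to be written by the program, e.g. via
		 * bpf_skb_store_bytes() and bpf_l3_csum_replace().
		 */
		bpf_skb_change_proto(skb, bpf_htons(ETH_P_IPV6), 0);
	}

	/* Re-classify the packet as destined to the local host. */
	bpf_skb_change_type(skb, PACKET_HOST);

	return TC_ACT_OK;
}

char LICENSE[] SEC("license") = "GPL";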
*/ + if (unlikely(!skb_pkt_type_ok(skb->pkt_type) || + !skb_pkt_type_ok(pkt_type))) + return -EINVAL; + + skb->pkt_type = pkt_type; + return 0; +} + +static const struct bpf_func_proto bpf_skb_change_type_proto = { + .func = bpf_skb_change_type, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +static u32 bpf_skb_net_base_len(const struct sk_buff *skb) +{ + switch (skb->protocol) { + case htons(ETH_P_IP): + return sizeof(struct iphdr); + case htons(ETH_P_IPV6): + return sizeof(struct ipv6hdr); + default: + return ~0U; + } +} + +#define BPF_F_ADJ_ROOM_ENCAP_L3_MASK (BPF_F_ADJ_ROOM_ENCAP_L3_IPV4 | \ + BPF_F_ADJ_ROOM_ENCAP_L3_IPV6) + +#define BPF_F_ADJ_ROOM_MASK (BPF_F_ADJ_ROOM_FIXED_GSO | \ + BPF_F_ADJ_ROOM_ENCAP_L3_MASK | \ + BPF_F_ADJ_ROOM_ENCAP_L4_GRE | \ + BPF_F_ADJ_ROOM_ENCAP_L4_UDP | \ + BPF_F_ADJ_ROOM_ENCAP_L2_ETH | \ + BPF_F_ADJ_ROOM_ENCAP_L2( \ + BPF_ADJ_ROOM_ENCAP_L2_MASK)) + +static int bpf_skb_net_grow(struct sk_buff *skb, u32 off, u32 len_diff, + u64 flags) +{ + u8 inner_mac_len = flags >> BPF_ADJ_ROOM_ENCAP_L2_SHIFT; + bool encap = flags & BPF_F_ADJ_ROOM_ENCAP_L3_MASK; + u16 mac_len = 0, inner_net = 0, inner_trans = 0; + unsigned int gso_type = SKB_GSO_DODGY; + int ret; + + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) { + /* udp gso_size delineates datagrams, only allow if fixed */ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) || + !(flags & BPF_F_ADJ_ROOM_FIXED_GSO)) + return -ENOTSUPP; + } + + ret = skb_cow_head(skb, len_diff); + if (unlikely(ret < 0)) + return ret; + + if (encap) { + if (skb->protocol != htons(ETH_P_IP) && + skb->protocol != htons(ETH_P_IPV6)) + return -ENOTSUPP; + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV4 && + flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV6) + return -EINVAL; + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L4_GRE && + flags & BPF_F_ADJ_ROOM_ENCAP_L4_UDP) + return -EINVAL; + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L2_ETH && + inner_mac_len < ETH_HLEN) + return -EINVAL; + + if (skb->encapsulation) + return -EALREADY; + + mac_len = skb->network_header - skb->mac_header; + inner_net = skb->network_header; + if (inner_mac_len > len_diff) + return -EINVAL; + inner_trans = skb->transport_header; + } + + ret = bpf_skb_net_hdr_push(skb, off, len_diff); + if (unlikely(ret < 0)) + return ret; + + if (encap) { + skb->inner_mac_header = inner_net - inner_mac_len; + skb->inner_network_header = inner_net; + skb->inner_transport_header = inner_trans; + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L2_ETH) + skb_set_inner_protocol(skb, htons(ETH_P_TEB)); + else + skb_set_inner_protocol(skb, skb->protocol); + + skb->encapsulation = 1; + skb_set_network_header(skb, mac_len); + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L4_UDP) + gso_type |= SKB_GSO_UDP_TUNNEL; + else if (flags & BPF_F_ADJ_ROOM_ENCAP_L4_GRE) + gso_type |= SKB_GSO_GRE; + else if (flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV6) + gso_type |= SKB_GSO_IPXIP6; + else if (flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV4) + gso_type |= SKB_GSO_IPXIP4; + + if (flags & BPF_F_ADJ_ROOM_ENCAP_L4_GRE || + flags & BPF_F_ADJ_ROOM_ENCAP_L4_UDP) { + int nh_len = flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV6 ? 
+ sizeof(struct ipv6hdr) : + sizeof(struct iphdr); + + skb_set_transport_header(skb, mac_len + nh_len); + } + + /* Match skb->protocol to new outer l3 protocol */ + if (skb->protocol == htons(ETH_P_IP) && + flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV6) + skb->protocol = htons(ETH_P_IPV6); + else if (skb->protocol == htons(ETH_P_IPV6) && + flags & BPF_F_ADJ_ROOM_ENCAP_L3_IPV4) + skb->protocol = htons(ETH_P_IP); + } + + if (skb_is_gso(skb)) { + struct skb_shared_info *shinfo = skb_shinfo(skb); + + /* Due to header grow, MSS needs to be downgraded. */ + if (!(flags & BPF_F_ADJ_ROOM_FIXED_GSO)) + skb_decrease_gso_size(shinfo, len_diff); + + /* Header must be checked, and gso_segs recomputed. */ + shinfo->gso_type |= gso_type; + shinfo->gso_segs = 0; + } + + return 0; +} + +static int bpf_skb_net_shrink(struct sk_buff *skb, u32 off, u32 len_diff, + u64 flags) +{ + int ret; + + if (unlikely(flags & ~(BPF_F_ADJ_ROOM_FIXED_GSO | + BPF_F_ADJ_ROOM_NO_CSUM_RESET))) + return -EINVAL; + + if (skb_is_gso(skb) && !skb_is_gso_tcp(skb)) { + /* udp gso_size delineates datagrams, only allow if fixed */ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) || + !(flags & BPF_F_ADJ_ROOM_FIXED_GSO)) + return -ENOTSUPP; + } + + ret = skb_unclone(skb, GFP_ATOMIC); + if (unlikely(ret < 0)) + return ret; + + ret = bpf_skb_net_hdr_pop(skb, off, len_diff); + if (unlikely(ret < 0)) + return ret; + + if (skb_is_gso(skb)) { + struct skb_shared_info *shinfo = skb_shinfo(skb); + + /* Due to header shrink, MSS can be upgraded. */ + if (!(flags & BPF_F_ADJ_ROOM_FIXED_GSO)) + skb_increase_gso_size(shinfo, len_diff); + + /* Header must be checked, and gso_segs recomputed. */ + shinfo->gso_type |= SKB_GSO_DODGY; + shinfo->gso_segs = 0; + } + + return 0; +} + +#define BPF_SKB_MAX_LEN SKB_MAX_ALLOC + +BPF_CALL_4(sk_skb_adjust_room, struct sk_buff *, skb, s32, len_diff, + u32, mode, u64, flags) +{ + u32 len_diff_abs = abs(len_diff); + bool shrink = len_diff < 0; + int ret = 0; + + if (unlikely(flags || mode)) + return -EINVAL; + if (unlikely(len_diff_abs > 0xfffU)) + return -EFAULT; + + if (!shrink) { + ret = skb_cow(skb, len_diff); + if (unlikely(ret < 0)) + return ret; + __skb_push(skb, len_diff_abs); + memset(skb->data, 0, len_diff_abs); + } else { + if (unlikely(!pskb_may_pull(skb, len_diff_abs))) + return -ENOMEM; + __skb_pull(skb, len_diff_abs); + } + if (tls_sw_has_ctx_rx(skb->sk)) { + struct strp_msg *rxm = strp_msg(skb); + + rxm->full_len += len_diff; + } + return ret; +} + +static const struct bpf_func_proto sk_skb_adjust_room_proto = { + .func = sk_skb_adjust_room, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_skb_adjust_room, struct sk_buff *, skb, s32, len_diff, + u32, mode, u64, flags) +{ + u32 len_cur, len_diff_abs = abs(len_diff); + u32 len_min = bpf_skb_net_base_len(skb); + u32 len_max = BPF_SKB_MAX_LEN; + __be16 proto = skb->protocol; + bool shrink = len_diff < 0; + u32 off; + int ret; + + if (unlikely(flags & ~(BPF_F_ADJ_ROOM_MASK | + BPF_F_ADJ_ROOM_NO_CSUM_RESET))) + return -EINVAL; + if (unlikely(len_diff_abs > 0xfffU)) + return -EFAULT; + if (unlikely(proto != htons(ETH_P_IP) && + proto != htons(ETH_P_IPV6))) + return -ENOTSUPP; + + off = skb_mac_header_len(skb); + switch (mode) { + case BPF_ADJ_ROOM_NET: + off += bpf_skb_net_base_len(skb); + break; + case BPF_ADJ_ROOM_MAC: + break; + default: + return -ENOTSUPP; + } + + len_cur = skb->len - skb_network_offset(skb); + if 
((shrink && (len_diff_abs >= len_cur || + len_cur - len_diff_abs < len_min)) || + (!shrink && (skb->len + len_diff_abs > len_max && + !skb_is_gso(skb)))) + return -ENOTSUPP; + + ret = shrink ? bpf_skb_net_shrink(skb, off, len_diff_abs, flags) : + bpf_skb_net_grow(skb, off, len_diff_abs, flags); + if (!ret && !(flags & BPF_F_ADJ_ROOM_NO_CSUM_RESET)) + __skb_reset_checksum_unnecessary(skb); + + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_adjust_room_proto = { + .func = bpf_skb_adjust_room, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +static u32 __bpf_skb_min_len(const struct sk_buff *skb) +{ + u32 min_len = skb_network_offset(skb); + + if (skb_transport_header_was_set(skb)) + min_len = skb_transport_offset(skb); + if (skb->ip_summed == CHECKSUM_PARTIAL) + min_len = skb_checksum_start_offset(skb) + + skb->csum_offset + sizeof(__sum16); + return min_len; +} + +static int bpf_skb_grow_rcsum(struct sk_buff *skb, unsigned int new_len) +{ + unsigned int old_len = skb->len; + int ret; + + ret = __skb_grow_rcsum(skb, new_len); + if (!ret) + memset(skb->data + old_len, 0, new_len - old_len); + return ret; +} + +static int bpf_skb_trim_rcsum(struct sk_buff *skb, unsigned int new_len) +{ + return __skb_trim_rcsum(skb, new_len); +} + +static inline int __bpf_skb_change_tail(struct sk_buff *skb, u32 new_len, + u64 flags) +{ + u32 max_len = BPF_SKB_MAX_LEN; + u32 min_len = __bpf_skb_min_len(skb); + int ret; + + if (unlikely(flags || new_len > max_len || new_len < min_len)) + return -EINVAL; + if (skb->encapsulation) + return -ENOTSUPP; + + /* The basic idea of this helper is that it's performing the + * needed work to either grow or trim an skb, and eBPF program + * rewrites the rest via helpers like bpf_skb_store_bytes(), + * bpf_lX_csum_replace() and others rather than passing a raw + * buffer here. This one is a slow path helper and intended + * for replies with control messages. + * + * Like in bpf_skb_change_proto(), we want to keep this rather + * minimal and without protocol specifics so that we are able + * to separate concerns as in bpf_skb_store_bytes() should only + * be the one responsible for writing buffers. + * + * It's really expected to be a slow path operation here for + * control message replies, so we're implicitly linearizing, + * uncloning and drop offloads from the skb by this. 
+ */ + ret = __bpf_try_make_writable(skb, skb->len); + if (!ret) { + if (new_len > skb->len) + ret = bpf_skb_grow_rcsum(skb, new_len); + else if (new_len < skb->len) + ret = bpf_skb_trim_rcsum(skb, new_len); + if (!ret && skb_is_gso(skb)) + skb_gso_reset(skb); + } + return ret; +} + +BPF_CALL_3(bpf_skb_change_tail, struct sk_buff *, skb, u32, new_len, + u64, flags) +{ + int ret = __bpf_skb_change_tail(skb, new_len, flags); + + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_change_tail_proto = { + .func = bpf_skb_change_tail, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_3(sk_skb_change_tail, struct sk_buff *, skb, u32, new_len, + u64, flags) +{ + return __bpf_skb_change_tail(skb, new_len, flags); +} + +static const struct bpf_func_proto sk_skb_change_tail_proto = { + .func = sk_skb_change_tail, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +static inline int __bpf_skb_change_head(struct sk_buff *skb, u32 head_room, + u64 flags) +{ + u32 max_len = BPF_SKB_MAX_LEN; + u32 new_len = skb->len + head_room; + int ret; + + if (unlikely(flags || (!skb_is_gso(skb) && new_len > max_len) || + new_len < skb->len)) + return -EINVAL; + + ret = skb_cow(skb, head_room); + if (likely(!ret)) { + /* Idea for this helper is that we currently only + * allow to expand on mac header. This means that + * skb->protocol network header, etc, stay as is. + * Compared to bpf_skb_change_tail(), we're more + * flexible due to not needing to linearize or + * reset GSO. Intention for this helper is to be + * used by an L3 skb that needs to push mac header + * for redirection into L2 device. + */ + __skb_push(skb, head_room); + memset(skb->data, 0, head_room); + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + } + + return ret; +} + +BPF_CALL_3(bpf_skb_change_head, struct sk_buff *, skb, u32, head_room, + u64, flags) +{ + int ret = __bpf_skb_change_head(skb, head_room, flags); + + bpf_compute_data_pointers(skb); + return ret; +} + +static const struct bpf_func_proto bpf_skb_change_head_proto = { + .func = bpf_skb_change_head, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_3(sk_skb_change_head, struct sk_buff *, skb, u32, head_room, + u64, flags) +{ + return __bpf_skb_change_head(skb, head_room, flags); +} + +static const struct bpf_func_proto sk_skb_change_head_proto = { + .func = sk_skb_change_head, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_xdp_get_buff_len, struct xdp_buff*, xdp) +{ + return xdp_get_buff_len(xdp); +} + +static const struct bpf_func_proto bpf_xdp_get_buff_len_proto = { + .func = bpf_xdp_get_buff_len, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BTF_ID_LIST_SINGLE(bpf_xdp_get_buff_len_bpf_ids, struct, xdp_buff) + +const struct bpf_func_proto bpf_xdp_get_buff_len_trace_proto = { + .func = bpf_xdp_get_buff_len, + .gpl_only = false, + .arg1_type = ARG_PTR_TO_BTF_ID, + .arg1_btf_id = &bpf_xdp_get_buff_len_bpf_ids[0], +}; + +static unsigned long xdp_get_metalen(const struct xdp_buff *xdp) +{ + return xdp_data_meta_unsupported(xdp) ? 
0 : + xdp->data - xdp->data_meta; +} + +BPF_CALL_2(bpf_xdp_adjust_head, struct xdp_buff *, xdp, int, offset) +{ + void *xdp_frame_end = xdp->data_hard_start + sizeof(struct xdp_frame); + unsigned long metalen = xdp_get_metalen(xdp); + void *data_start = xdp_frame_end + metalen; + void *data = xdp->data + offset; + + if (unlikely(data < data_start || + data > xdp->data_end - ETH_HLEN)) + return -EINVAL; + + if (metalen) + memmove(xdp->data_meta + offset, + xdp->data_meta, metalen); + xdp->data_meta += offset; + xdp->data = data; + + return 0; +} + +static const struct bpf_func_proto bpf_xdp_adjust_head_proto = { + .func = bpf_xdp_adjust_head, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +static void bpf_xdp_copy_buf(struct xdp_buff *xdp, unsigned long off, + void *buf, unsigned long len, bool flush) +{ + unsigned long ptr_len, ptr_off = 0; + skb_frag_t *next_frag, *end_frag; + struct skb_shared_info *sinfo; + void *src, *dst; + u8 *ptr_buf; + + if (likely(xdp->data_end - xdp->data >= off + len)) { + src = flush ? buf : xdp->data + off; + dst = flush ? xdp->data + off : buf; + memcpy(dst, src, len); + return; + } + + sinfo = xdp_get_shared_info_from_buff(xdp); + end_frag = &sinfo->frags[sinfo->nr_frags]; + next_frag = &sinfo->frags[0]; + + ptr_len = xdp->data_end - xdp->data; + ptr_buf = xdp->data; + + while (true) { + if (off < ptr_off + ptr_len) { + unsigned long copy_off = off - ptr_off; + unsigned long copy_len = min(len, ptr_len - copy_off); + + src = flush ? buf : ptr_buf + copy_off; + dst = flush ? ptr_buf + copy_off : buf; + memcpy(dst, src, copy_len); + + off += copy_len; + len -= copy_len; + buf += copy_len; + } + + if (!len || next_frag == end_frag) + break; + + ptr_off += ptr_len; + ptr_buf = skb_frag_address(next_frag); + ptr_len = skb_frag_size(next_frag); + next_frag++; + } +} + +static void *bpf_xdp_pointer(struct xdp_buff *xdp, u32 offset, u32 len) +{ + struct skb_shared_info *sinfo = xdp_get_shared_info_from_buff(xdp); + u32 size = xdp->data_end - xdp->data; + void *addr = xdp->data; + int i; + + if (unlikely(offset > 0xffff || len > 0xffff)) + return ERR_PTR(-EFAULT); + + if (offset + len > xdp_get_buff_len(xdp)) + return ERR_PTR(-EINVAL); + + if (offset < size) /* linear area */ + goto out; + + offset -= size; + for (i = 0; i < sinfo->nr_frags; i++) { /* paged area */ + u32 frag_size = skb_frag_size(&sinfo->frags[i]); + + if (offset < frag_size) { + addr = skb_frag_address(&sinfo->frags[i]); + size = frag_size; + break; + } + offset -= frag_size; + } +out: + return offset + len <= size ? 
addr + offset : NULL; +} + +BPF_CALL_4(bpf_xdp_load_bytes, struct xdp_buff *, xdp, u32, offset, + void *, buf, u32, len) +{ + void *ptr; + + ptr = bpf_xdp_pointer(xdp, offset, len); + if (IS_ERR(ptr)) + return PTR_ERR(ptr); + + if (!ptr) + bpf_xdp_copy_buf(xdp, offset, buf, len, false); + else + memcpy(buf, ptr, len); + + return 0; +} + +static const struct bpf_func_proto bpf_xdp_load_bytes_proto = { + .func = bpf_xdp_load_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, +}; + +BPF_CALL_4(bpf_xdp_store_bytes, struct xdp_buff *, xdp, u32, offset, + void *, buf, u32, len) +{ + void *ptr; + + ptr = bpf_xdp_pointer(xdp, offset, len); + if (IS_ERR(ptr)) + return PTR_ERR(ptr); + + if (!ptr) + bpf_xdp_copy_buf(xdp, offset, buf, len, true); + else + memcpy(ptr, buf, len); + + return 0; +} + +static const struct bpf_func_proto bpf_xdp_store_bytes_proto = { + .func = bpf_xdp_store_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, +}; + +static int bpf_xdp_frags_increase_tail(struct xdp_buff *xdp, int offset) +{ + struct skb_shared_info *sinfo = xdp_get_shared_info_from_buff(xdp); + skb_frag_t *frag = &sinfo->frags[sinfo->nr_frags - 1]; + struct xdp_rxq_info *rxq = xdp->rxq; + unsigned int tailroom; + + if (!rxq->frag_size || rxq->frag_size > xdp->frame_sz) + return -EOPNOTSUPP; + + tailroom = rxq->frag_size - skb_frag_size(frag) - skb_frag_off(frag); + if (unlikely(offset > tailroom)) + return -EINVAL; + + memset(skb_frag_address(frag) + skb_frag_size(frag), 0, offset); + skb_frag_size_add(frag, offset); + sinfo->xdp_frags_size += offset; + + return 0; +} + +static int bpf_xdp_frags_shrink_tail(struct xdp_buff *xdp, int offset) +{ + struct skb_shared_info *sinfo = xdp_get_shared_info_from_buff(xdp); + int i, n_frags_free = 0, len_free = 0; + + if (unlikely(offset > (int)xdp_get_buff_len(xdp) - ETH_HLEN)) + return -EINVAL; + + for (i = sinfo->nr_frags - 1; i >= 0 && offset > 0; i--) { + skb_frag_t *frag = &sinfo->frags[i]; + int shrink = min_t(int, offset, skb_frag_size(frag)); + + len_free += shrink; + offset -= shrink; + + if (skb_frag_size(frag) == shrink) { + struct page *page = skb_frag_page(frag); + + __xdp_return(page_address(page), &xdp->rxq->mem, + false, NULL); + n_frags_free++; + } else { + skb_frag_size_sub(frag, shrink); + break; + } + } + sinfo->nr_frags -= n_frags_free; + sinfo->xdp_frags_size -= len_free; + + if (unlikely(!sinfo->nr_frags)) { + xdp_buff_clear_frags_flag(xdp); + xdp->data_end -= offset; + } + + return 0; +} + +BPF_CALL_2(bpf_xdp_adjust_tail, struct xdp_buff *, xdp, int, offset) +{ + void *data_hard_end = xdp_data_hard_end(xdp); /* use xdp->frame_sz */ + void *data_end = xdp->data_end + offset; + + if (unlikely(xdp_buff_has_frags(xdp))) { /* non-linear xdp buff */ + if (offset < 0) + return bpf_xdp_frags_shrink_tail(xdp, -offset); + + return bpf_xdp_frags_increase_tail(xdp, offset); + } + + /* Notice that xdp_data_hard_end have reserved some tailroom */ + if (unlikely(data_end > data_hard_end)) + return -EINVAL; + + if (unlikely(data_end < xdp->data + ETH_HLEN)) + return -EINVAL; + + /* Clear memory area on grow, can contain uninit kernel memory */ + if (offset > 0) + memset(xdp->data_end, 0, offset); + + xdp->data_end = data_end; + + return 0; +} + +static const struct bpf_func_proto 
bpf_xdp_adjust_tail_proto = { + .func = bpf_xdp_adjust_tail, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_2(bpf_xdp_adjust_meta, struct xdp_buff *, xdp, int, offset) +{ + void *xdp_frame_end = xdp->data_hard_start + sizeof(struct xdp_frame); + void *meta = xdp->data_meta + offset; + unsigned long metalen = xdp->data - meta; + + if (xdp_data_meta_unsupported(xdp)) + return -ENOTSUPP; + if (unlikely(meta < xdp_frame_end || + meta > xdp->data)) + return -EINVAL; + if (unlikely(xdp_metalen_invalid(metalen))) + return -EACCES; + + xdp->data_meta = meta; + + return 0; +} + +static const struct bpf_func_proto bpf_xdp_adjust_meta_proto = { + .func = bpf_xdp_adjust_meta, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +/* XDP_REDIRECT works by a three-step process, implemented in the functions + * below: + * + * 1. The bpf_redirect() and bpf_redirect_map() helpers will lookup the target + * of the redirect and store it (along with some other metadata) in a per-CPU + * struct bpf_redirect_info. + * + * 2. When the program returns the XDP_REDIRECT return code, the driver will + * call xdp_do_redirect() which will use the information in struct + * bpf_redirect_info to actually enqueue the frame into a map type-specific + * bulk queue structure. + * + * 3. Before exiting its NAPI poll loop, the driver will call xdp_do_flush(), + * which will flush all the different bulk queues, thus completing the + * redirect. + * + * Pointers to the map entries will be kept around for this whole sequence of + * steps, protected by RCU. However, there is no top-level rcu_read_lock() in + * the core code; instead, the RCU protection relies on everything happening + * inside a single NAPI poll sequence, which means it's between a pair of calls + * to local_bh_disable()/local_bh_enable(). + * + * The map entries are marked as __rcu and the map code makes sure to + * dereference those pointers with rcu_dereference_check() in a way that works + * for both sections that to hold an rcu_read_lock() and sections that are + * called from NAPI without a separate rcu_read_lock(). The code below does not + * use RCU annotations, but relies on those in the map code. + */ +void xdp_do_flush(void) +{ + __dev_flush(); + __cpu_map_flush(); + __xsk_map_flush(); +} +EXPORT_SYMBOL_GPL(xdp_do_flush); + +void bpf_clear_redirect_map(struct bpf_map *map) +{ + struct bpf_redirect_info *ri; + int cpu; + + for_each_possible_cpu(cpu) { + ri = per_cpu_ptr(&bpf_redirect_info, cpu); + /* Avoid polluting remote cacheline due to writes if + * not needed. Once we pass this test, we need the + * cmpxchg() to make sure it hasn't been changed in + * the meantime by remote CPU. + */ + if (unlikely(READ_ONCE(ri->map) == map)) + cmpxchg(&ri->map, map, NULL); + } +} + +DEFINE_STATIC_KEY_FALSE(bpf_master_redirect_enabled_key); +EXPORT_SYMBOL_GPL(bpf_master_redirect_enabled_key); + +u32 xdp_master_redirect(struct xdp_buff *xdp) +{ + struct net_device *master, *slave; + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + + master = netdev_master_upper_dev_get_rcu(xdp->rxq->dev); + slave = master->netdev_ops->ndo_xdp_get_xmit_slave(master, xdp); + if (slave && slave != xdp->rxq->dev) { + /* The target device is different from the receiving device, so + * redirect it to the new device. + * Using XDP_REDIRECT gets the correct behaviour from XDP enabled + * drivers to unmap the packet from their rx ring. 
+ */ + ri->tgt_index = slave->ifindex; + ri->map_id = INT_MAX; + ri->map_type = BPF_MAP_TYPE_UNSPEC; + return XDP_REDIRECT; + } + return XDP_TX; +} +EXPORT_SYMBOL_GPL(xdp_master_redirect); + +static inline int __xdp_do_redirect_xsk(struct bpf_redirect_info *ri, + struct net_device *dev, + struct xdp_buff *xdp, + struct bpf_prog *xdp_prog) +{ + enum bpf_map_type map_type = ri->map_type; + void *fwd = ri->tgt_value; + u32 map_id = ri->map_id; + int err; + + ri->map_id = 0; /* Valid map id idr range: [1,INT_MAX[ */ + ri->map_type = BPF_MAP_TYPE_UNSPEC; + + err = __xsk_map_redirect(fwd, xdp); + if (unlikely(err)) + goto err; + + _trace_xdp_redirect_map(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index); + return 0; +err: + _trace_xdp_redirect_map_err(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index, err); + return err; +} + +static __always_inline int __xdp_do_redirect_frame(struct bpf_redirect_info *ri, + struct net_device *dev, + struct xdp_frame *xdpf, + struct bpf_prog *xdp_prog) +{ + enum bpf_map_type map_type = ri->map_type; + void *fwd = ri->tgt_value; + u32 map_id = ri->map_id; + struct bpf_map *map; + int err; + + ri->map_id = 0; /* Valid map id idr range: [1,INT_MAX[ */ + ri->map_type = BPF_MAP_TYPE_UNSPEC; + + if (unlikely(!xdpf)) { + err = -EOVERFLOW; + goto err; + } + + switch (map_type) { + case BPF_MAP_TYPE_DEVMAP: + fallthrough; + case BPF_MAP_TYPE_DEVMAP_HASH: + map = READ_ONCE(ri->map); + if (unlikely(map)) { + WRITE_ONCE(ri->map, NULL); + err = dev_map_enqueue_multi(xdpf, dev, map, + ri->flags & BPF_F_EXCLUDE_INGRESS); + } else { + err = dev_map_enqueue(fwd, xdpf, dev); + } + break; + case BPF_MAP_TYPE_CPUMAP: + err = cpu_map_enqueue(fwd, xdpf, dev); + break; + case BPF_MAP_TYPE_UNSPEC: + if (map_id == INT_MAX) { + fwd = dev_get_by_index_rcu(dev_net(dev), ri->tgt_index); + if (unlikely(!fwd)) { + err = -EINVAL; + break; + } + err = dev_xdp_enqueue(fwd, xdpf, dev); + break; + } + fallthrough; + default: + err = -EBADRQC; + } + + if (unlikely(err)) + goto err; + + _trace_xdp_redirect_map(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index); + return 0; +err: + _trace_xdp_redirect_map_err(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index, err); + return err; +} + +int xdp_do_redirect(struct net_device *dev, struct xdp_buff *xdp, + struct bpf_prog *xdp_prog) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + enum bpf_map_type map_type = ri->map_type; + + /* XDP_REDIRECT is not fully supported yet for xdp frags since + * not all XDP capable drivers can map non-linear xdp_frame in + * ndo_xdp_xmit. 
+ */ + if (unlikely(xdp_buff_has_frags(xdp) && + map_type != BPF_MAP_TYPE_CPUMAP)) + return -EOPNOTSUPP; + + if (map_type == BPF_MAP_TYPE_XSKMAP) + return __xdp_do_redirect_xsk(ri, dev, xdp, xdp_prog); + + return __xdp_do_redirect_frame(ri, dev, xdp_convert_buff_to_frame(xdp), + xdp_prog); +} +EXPORT_SYMBOL_GPL(xdp_do_redirect); + +int xdp_do_redirect_frame(struct net_device *dev, struct xdp_buff *xdp, + struct xdp_frame *xdpf, struct bpf_prog *xdp_prog) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + enum bpf_map_type map_type = ri->map_type; + + if (map_type == BPF_MAP_TYPE_XSKMAP) + return __xdp_do_redirect_xsk(ri, dev, xdp, xdp_prog); + + return __xdp_do_redirect_frame(ri, dev, xdpf, xdp_prog); +} +EXPORT_SYMBOL_GPL(xdp_do_redirect_frame); + +static int xdp_do_generic_redirect_map(struct net_device *dev, + struct sk_buff *skb, + struct xdp_buff *xdp, + struct bpf_prog *xdp_prog, + void *fwd, + enum bpf_map_type map_type, u32 map_id) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + struct bpf_map *map; + int err; + + switch (map_type) { + case BPF_MAP_TYPE_DEVMAP: + fallthrough; + case BPF_MAP_TYPE_DEVMAP_HASH: + map = READ_ONCE(ri->map); + if (unlikely(map)) { + WRITE_ONCE(ri->map, NULL); + err = dev_map_redirect_multi(dev, skb, xdp_prog, map, + ri->flags & BPF_F_EXCLUDE_INGRESS); + } else { + err = dev_map_generic_redirect(fwd, skb, xdp_prog); + } + if (unlikely(err)) + goto err; + break; + case BPF_MAP_TYPE_XSKMAP: + err = xsk_generic_rcv(fwd, xdp); + if (err) + goto err; + consume_skb(skb); + break; + case BPF_MAP_TYPE_CPUMAP: + err = cpu_map_generic_redirect(fwd, skb); + if (unlikely(err)) + goto err; + break; + default: + err = -EBADRQC; + goto err; + } + + _trace_xdp_redirect_map(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index); + return 0; +err: + _trace_xdp_redirect_map_err(dev, xdp_prog, fwd, map_type, map_id, ri->tgt_index, err); + return err; +} + +int xdp_do_generic_redirect(struct net_device *dev, struct sk_buff *skb, + struct xdp_buff *xdp, struct bpf_prog *xdp_prog) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + enum bpf_map_type map_type = ri->map_type; + void *fwd = ri->tgt_value; + u32 map_id = ri->map_id; + int err; + + ri->map_id = 0; /* Valid map id idr range: [1,INT_MAX[ */ + ri->map_type = BPF_MAP_TYPE_UNSPEC; + + if (map_type == BPF_MAP_TYPE_UNSPEC && map_id == INT_MAX) { + fwd = dev_get_by_index_rcu(dev_net(dev), ri->tgt_index); + if (unlikely(!fwd)) { + err = -EINVAL; + goto err; + } + + err = xdp_ok_fwd_dev(fwd, skb->len); + if (unlikely(err)) + goto err; + + skb->dev = fwd; + _trace_xdp_redirect(dev, xdp_prog, ri->tgt_index); + generic_xdp_tx(skb, xdp_prog); + return 0; + } + + return xdp_do_generic_redirect_map(dev, skb, xdp, xdp_prog, fwd, map_type, map_id); +err: + _trace_xdp_redirect_err(dev, xdp_prog, ri->tgt_index, err); + return err; +} + +BPF_CALL_2(bpf_xdp_redirect, u32, ifindex, u64, flags) +{ + struct bpf_redirect_info *ri = this_cpu_ptr(&bpf_redirect_info); + + if (unlikely(flags)) + return XDP_ABORTED; + + /* NB! Map type UNSPEC and map_id == INT_MAX (never generated + * by map_idr) is used for ifindex based XDP redirect. 
+ */ + ri->tgt_index = ifindex; + ri->map_id = INT_MAX; + ri->map_type = BPF_MAP_TYPE_UNSPEC; + + return XDP_REDIRECT; +} + +static const struct bpf_func_proto bpf_xdp_redirect_proto = { + .func = bpf_xdp_redirect, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_ANYTHING, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_xdp_redirect_map, struct bpf_map *, map, u32, ifindex, + u64, flags) +{ + return map->ops->map_redirect(map, ifindex, flags); +} + +static const struct bpf_func_proto bpf_xdp_redirect_map_proto = { + .func = bpf_xdp_redirect_map, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_CONST_MAP_PTR, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +static unsigned long bpf_skb_copy(void *dst_buff, const void *skb, + unsigned long off, unsigned long len) +{ + void *ptr = skb_header_pointer(skb, off, len, dst_buff); + + if (unlikely(!ptr)) + return len; + if (ptr != dst_buff) + memcpy(dst_buff, ptr, len); + + return 0; +} + +BPF_CALL_5(bpf_skb_event_output, struct sk_buff *, skb, struct bpf_map *, map, + u64, flags, void *, meta, u64, meta_size) +{ + u64 skb_size = (flags & BPF_F_CTXLEN_MASK) >> 32; + + if (unlikely(flags & ~(BPF_F_CTXLEN_MASK | BPF_F_INDEX_MASK))) + return -EINVAL; + if (unlikely(!skb || skb_size > skb->len)) + return -EFAULT; + + return bpf_event_output(map, flags, meta, meta_size, skb, skb_size, + bpf_skb_copy); +} + +static const struct bpf_func_proto bpf_skb_event_output_proto = { + .func = bpf_skb_event_output, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE_OR_ZERO, +}; + +BTF_ID_LIST_SINGLE(bpf_skb_output_btf_ids, struct, sk_buff) + +const struct bpf_func_proto bpf_skb_output_proto = { + .func = bpf_skb_event_output, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID, + .arg1_btf_id = &bpf_skb_output_btf_ids[0], + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE_OR_ZERO, +}; + +static unsigned short bpf_tunnel_key_af(u64 flags) +{ + return flags & BPF_F_TUNINFO_IPV6 ? AF_INET6 : AF_INET; +} + +BPF_CALL_4(bpf_skb_get_tunnel_key, struct sk_buff *, skb, struct bpf_tunnel_key *, to, + u32, size, u64, flags) +{ + const struct ip_tunnel_info *info = skb_tunnel_info(skb); + u8 compat[sizeof(struct bpf_tunnel_key)]; + void *to_orig = to; + int err; + + if (unlikely(!info || (flags & ~(BPF_F_TUNINFO_IPV6 | + BPF_F_TUNINFO_FLAGS)))) { + err = -EINVAL; + goto err_clear; + } + if (ip_tunnel_info_af(info) != bpf_tunnel_key_af(flags)) { + err = -EPROTO; + goto err_clear; + } + if (unlikely(size != sizeof(struct bpf_tunnel_key))) { + err = -EINVAL; + switch (size) { + case offsetof(struct bpf_tunnel_key, local_ipv6[0]): + case offsetof(struct bpf_tunnel_key, tunnel_label): + case offsetof(struct bpf_tunnel_key, tunnel_ext): + goto set_compat; + case offsetof(struct bpf_tunnel_key, remote_ipv6[1]): + /* Fixup deprecated structure layouts here, so we have + * a common path later on. 
+ */ + if (ip_tunnel_info_af(info) != AF_INET) + goto err_clear; +set_compat: + to = (struct bpf_tunnel_key *)compat; + break; + default: + goto err_clear; + } + } + + to->tunnel_id = be64_to_cpu(info->key.tun_id); + to->tunnel_tos = info->key.tos; + to->tunnel_ttl = info->key.ttl; + if (flags & BPF_F_TUNINFO_FLAGS) + to->tunnel_flags = info->key.tun_flags; + else + to->tunnel_ext = 0; + + if (flags & BPF_F_TUNINFO_IPV6) { + memcpy(to->remote_ipv6, &info->key.u.ipv6.src, + sizeof(to->remote_ipv6)); + memcpy(to->local_ipv6, &info->key.u.ipv6.dst, + sizeof(to->local_ipv6)); + to->tunnel_label = be32_to_cpu(info->key.label); + } else { + to->remote_ipv4 = be32_to_cpu(info->key.u.ipv4.src); + memset(&to->remote_ipv6[1], 0, sizeof(__u32) * 3); + to->local_ipv4 = be32_to_cpu(info->key.u.ipv4.dst); + memset(&to->local_ipv6[1], 0, sizeof(__u32) * 3); + to->tunnel_label = 0; + } + + if (unlikely(size != sizeof(struct bpf_tunnel_key))) + memcpy(to_orig, to, size); + + return 0; +err_clear: + memset(to_orig, 0, size); + return err; +} + +static const struct bpf_func_proto bpf_skb_get_tunnel_key_proto = { + .func = bpf_skb_get_tunnel_key, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_UNINIT_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_skb_get_tunnel_opt, struct sk_buff *, skb, u8 *, to, u32, size) +{ + const struct ip_tunnel_info *info = skb_tunnel_info(skb); + int err; + + if (unlikely(!info || + !(info->key.tun_flags & TUNNEL_OPTIONS_PRESENT))) { + err = -ENOENT; + goto err_clear; + } + if (unlikely(size < info->options_len)) { + err = -ENOMEM; + goto err_clear; + } + + ip_tunnel_info_opts_get(to, info); + if (size > info->options_len) + memset(to + info->options_len, 0, size - info->options_len); + + return info->options_len; +err_clear: + memset(to, 0, size); + return err; +} + +static const struct bpf_func_proto bpf_skb_get_tunnel_opt_proto = { + .func = bpf_skb_get_tunnel_opt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_UNINIT_MEM, + .arg3_type = ARG_CONST_SIZE, +}; + +static struct metadata_dst __percpu *md_dst; + +BPF_CALL_4(bpf_skb_set_tunnel_key, struct sk_buff *, skb, + const struct bpf_tunnel_key *, from, u32, size, u64, flags) +{ + struct metadata_dst *md = this_cpu_ptr(md_dst); + u8 compat[sizeof(struct bpf_tunnel_key)]; + struct ip_tunnel_info *info; + + if (unlikely(flags & ~(BPF_F_TUNINFO_IPV6 | BPF_F_ZERO_CSUM_TX | + BPF_F_DONT_FRAGMENT | BPF_F_SEQ_NUMBER))) + return -EINVAL; + if (unlikely(size != sizeof(struct bpf_tunnel_key))) { + switch (size) { + case offsetof(struct bpf_tunnel_key, local_ipv6[0]): + case offsetof(struct bpf_tunnel_key, tunnel_label): + case offsetof(struct bpf_tunnel_key, tunnel_ext): + case offsetof(struct bpf_tunnel_key, remote_ipv6[1]): + /* Fixup deprecated structure layouts here, so we have + * a common path later on. 
+ */ + memcpy(compat, from, size); + memset(compat + size, 0, sizeof(compat) - size); + from = (const struct bpf_tunnel_key *) compat; + break; + default: + return -EINVAL; + } + } + if (unlikely((!(flags & BPF_F_TUNINFO_IPV6) && from->tunnel_label) || + from->tunnel_ext)) + return -EINVAL; + + skb_dst_drop(skb); + dst_hold((struct dst_entry *) md); + skb_dst_set(skb, (struct dst_entry *) md); + + info = &md->u.tun_info; + memset(info, 0, sizeof(*info)); + info->mode = IP_TUNNEL_INFO_TX; + + info->key.tun_flags = TUNNEL_KEY | TUNNEL_CSUM | TUNNEL_NOCACHE; + if (flags & BPF_F_DONT_FRAGMENT) + info->key.tun_flags |= TUNNEL_DONT_FRAGMENT; + if (flags & BPF_F_ZERO_CSUM_TX) + info->key.tun_flags &= ~TUNNEL_CSUM; + if (flags & BPF_F_SEQ_NUMBER) + info->key.tun_flags |= TUNNEL_SEQ; + + info->key.tun_id = cpu_to_be64(from->tunnel_id); + info->key.tos = from->tunnel_tos; + info->key.ttl = from->tunnel_ttl; + + if (flags & BPF_F_TUNINFO_IPV6) { + info->mode |= IP_TUNNEL_INFO_IPV6; + memcpy(&info->key.u.ipv6.dst, from->remote_ipv6, + sizeof(from->remote_ipv6)); + memcpy(&info->key.u.ipv6.src, from->local_ipv6, + sizeof(from->local_ipv6)); + info->key.label = cpu_to_be32(from->tunnel_label) & + IPV6_FLOWLABEL_MASK; + } else { + info->key.u.ipv4.dst = cpu_to_be32(from->remote_ipv4); + info->key.u.ipv4.src = cpu_to_be32(from->local_ipv4); + info->key.flow_flags = FLOWI_FLAG_ANYSRC; + } + + return 0; +} + +static const struct bpf_func_proto bpf_skb_set_tunnel_key_proto = { + .func = bpf_skb_set_tunnel_key, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_skb_set_tunnel_opt, struct sk_buff *, skb, + const u8 *, from, u32, size) +{ + struct ip_tunnel_info *info = skb_tunnel_info(skb); + const struct metadata_dst *md = this_cpu_ptr(md_dst); + + if (unlikely(info != &md->u.tun_info || (size & (sizeof(u32) - 1)))) + return -EINVAL; + if (unlikely(size > IP_TUNNEL_OPTS_MAX)) + return -ENOMEM; + + ip_tunnel_info_opts_set(info, from, size, TUNNEL_OPTIONS_PRESENT); + + return 0; +} + +static const struct bpf_func_proto bpf_skb_set_tunnel_opt_proto = { + .func = bpf_skb_set_tunnel_opt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, +}; + +static const struct bpf_func_proto * +bpf_get_skb_set_tunnel_proto(enum bpf_func_id which) +{ + if (!md_dst) { + struct metadata_dst __percpu *tmp; + + tmp = metadata_dst_alloc_percpu(IP_TUNNEL_OPTS_MAX, + METADATA_IP_TUNNEL, + GFP_KERNEL); + if (!tmp) + return NULL; + if (cmpxchg(&md_dst, NULL, tmp)) + metadata_dst_free_percpu(tmp); + } + + switch (which) { + case BPF_FUNC_skb_set_tunnel_key: + return &bpf_skb_set_tunnel_key_proto; + case BPF_FUNC_skb_set_tunnel_opt: + return &bpf_skb_set_tunnel_opt_proto; + default: + return NULL; + } +} + +BPF_CALL_3(bpf_skb_under_cgroup, struct sk_buff *, skb, struct bpf_map *, map, + u32, idx) +{ + struct bpf_array *array = container_of(map, struct bpf_array, map); + struct cgroup *cgrp; + struct sock *sk; + + sk = skb_to_full_sk(skb); + if (!sk || !sk_fullsock(sk)) + return -ENOENT; + if (unlikely(idx >= array->map.max_entries)) + return -E2BIG; + + cgrp = READ_ONCE(array->ptrs[idx]); + if (unlikely(!cgrp)) + return -EAGAIN; + + return sk_under_cgroup_hierarchy(sk, cgrp); +} + +static const struct bpf_func_proto bpf_skb_under_cgroup_proto = { + .func = bpf_skb_under_cgroup, + 
.gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, +}; + +#ifdef CONFIG_SOCK_CGROUP_DATA +static inline u64 __bpf_sk_cgroup_id(struct sock *sk) +{ + struct cgroup *cgrp; + + sk = sk_to_full_sk(sk); + if (!sk || !sk_fullsock(sk)) + return 0; + + cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data); + return cgroup_id(cgrp); +} + +BPF_CALL_1(bpf_skb_cgroup_id, const struct sk_buff *, skb) +{ + return __bpf_sk_cgroup_id(skb->sk); +} + +static const struct bpf_func_proto bpf_skb_cgroup_id_proto = { + .func = bpf_skb_cgroup_id, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +static inline u64 __bpf_sk_ancestor_cgroup_id(struct sock *sk, + int ancestor_level) +{ + struct cgroup *ancestor; + struct cgroup *cgrp; + + sk = sk_to_full_sk(sk); + if (!sk || !sk_fullsock(sk)) + return 0; + + cgrp = sock_cgroup_ptr(&sk->sk_cgrp_data); + ancestor = cgroup_ancestor(cgrp, ancestor_level); + if (!ancestor) + return 0; + + return cgroup_id(ancestor); +} + +BPF_CALL_2(bpf_skb_ancestor_cgroup_id, const struct sk_buff *, skb, int, + ancestor_level) +{ + return __bpf_sk_ancestor_cgroup_id(skb->sk, ancestor_level); +} + +static const struct bpf_func_proto bpf_skb_ancestor_cgroup_id_proto = { + .func = bpf_skb_ancestor_cgroup_id, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_sk_cgroup_id, struct sock *, sk) +{ + return __bpf_sk_cgroup_id(sk); +} + +static const struct bpf_func_proto bpf_sk_cgroup_id_proto = { + .func = bpf_sk_cgroup_id, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, +}; + +BPF_CALL_2(bpf_sk_ancestor_cgroup_id, struct sock *, sk, int, ancestor_level) +{ + return __bpf_sk_ancestor_cgroup_id(sk, ancestor_level); +} + +static const struct bpf_func_proto bpf_sk_ancestor_cgroup_id_proto = { + .func = bpf_sk_ancestor_cgroup_id, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, +}; +#endif + +static unsigned long bpf_xdp_copy(void *dst, const void *ctx, + unsigned long off, unsigned long len) +{ + struct xdp_buff *xdp = (struct xdp_buff *)ctx; + + bpf_xdp_copy_buf(xdp, off, dst, len, false); + return 0; +} + +BPF_CALL_5(bpf_xdp_event_output, struct xdp_buff *, xdp, struct bpf_map *, map, + u64, flags, void *, meta, u64, meta_size) +{ + u64 xdp_size = (flags & BPF_F_CTXLEN_MASK) >> 32; + + if (unlikely(flags & ~(BPF_F_CTXLEN_MASK | BPF_F_INDEX_MASK))) + return -EINVAL; + + if (unlikely(!xdp || xdp_size > xdp_get_buff_len(xdp))) + return -EFAULT; + + return bpf_event_output(map, flags, meta, meta_size, xdp, + xdp_size, bpf_xdp_copy); +} + +static const struct bpf_func_proto bpf_xdp_event_output_proto = { + .func = bpf_xdp_event_output, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE_OR_ZERO, +}; + +BTF_ID_LIST_SINGLE(bpf_xdp_output_btf_ids, struct, xdp_buff) + +const struct bpf_func_proto bpf_xdp_output_proto = { + .func = bpf_xdp_event_output, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID, + .arg1_btf_id = &bpf_xdp_output_btf_ids[0], + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE_OR_ZERO, +}; + 
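+/* Usage sketch, for orientation only: the two event output helpers above are
+ * what a tc or XDP program reaches through bpf_perf_event_output(). The upper
+ * 32 bits of the flags argument (BPF_F_CTXLEN_MASK) ask the kernel to append
+ * that many packet bytes after the caller's metadata, which is what the
+ * bpf_skb_copy()/bpf_xdp_copy() callbacks above service. Roughly, on the
+ * BPF program side (names such as "events" and "struct event_meta" are
+ * illustrative only and not defined in this file):
+ *
+ *	struct event_meta { __u32 ifindex; __u32 pkt_len; } meta = {
+ *		.ifindex = skb->ifindex,
+ *		.pkt_len = skb->len,
+ *	};
+ *	__u32 cap = skb->len < 64 ? skb->len : 64;
+ *	__u64 flags = BPF_F_CURRENT_CPU |
+ *		      (((__u64)cap << 32) & BPF_F_CTXLEN_MASK);
+ *	bpf_perf_event_output(skb, &events, flags, &meta, sizeof(meta));
+ *
+ * where "events" would be a BPF_MAP_TYPE_PERF_EVENT_ARRAY map and "skb" a
+ * struct __sk_buff context.
+ */
+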
+BPF_CALL_1(bpf_get_socket_cookie, struct sk_buff *, skb) +{ + return skb->sk ? __sock_gen_cookie(skb->sk) : 0; +} + +static const struct bpf_func_proto bpf_get_socket_cookie_proto = { + .func = bpf_get_socket_cookie, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_get_socket_cookie_sock_addr, struct bpf_sock_addr_kern *, ctx) +{ + return __sock_gen_cookie(ctx->sk); +} + +static const struct bpf_func_proto bpf_get_socket_cookie_sock_addr_proto = { + .func = bpf_get_socket_cookie_sock_addr, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_get_socket_cookie_sock, struct sock *, ctx) +{ + return __sock_gen_cookie(ctx); +} + +static const struct bpf_func_proto bpf_get_socket_cookie_sock_proto = { + .func = bpf_get_socket_cookie_sock, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_1(bpf_get_socket_ptr_cookie, struct sock *, sk) +{ + return sk ? sock_gen_cookie(sk) : 0; +} + +const struct bpf_func_proto bpf_get_socket_ptr_cookie_proto = { + .func = bpf_get_socket_ptr_cookie, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, +}; + +BPF_CALL_1(bpf_get_socket_cookie_sock_ops, struct bpf_sock_ops_kern *, ctx) +{ + return __sock_gen_cookie(ctx->sk); +} + +static const struct bpf_func_proto bpf_get_socket_cookie_sock_ops_proto = { + .func = bpf_get_socket_cookie_sock_ops, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +static u64 __bpf_get_netns_cookie(struct sock *sk) +{ + const struct net *net = sk ? sock_net(sk) : &init_net; + + return net->net_cookie; +} + +BPF_CALL_1(bpf_get_netns_cookie_sock, struct sock *, ctx) +{ + return __bpf_get_netns_cookie(ctx); +} + +static const struct bpf_func_proto bpf_get_netns_cookie_sock_proto = { + .func = bpf_get_netns_cookie_sock, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX_OR_NULL, +}; + +BPF_CALL_1(bpf_get_netns_cookie_sock_addr, struct bpf_sock_addr_kern *, ctx) +{ + return __bpf_get_netns_cookie(ctx ? ctx->sk : NULL); +} + +static const struct bpf_func_proto bpf_get_netns_cookie_sock_addr_proto = { + .func = bpf_get_netns_cookie_sock_addr, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX_OR_NULL, +}; + +BPF_CALL_1(bpf_get_netns_cookie_sock_ops, struct bpf_sock_ops_kern *, ctx) +{ + return __bpf_get_netns_cookie(ctx ? ctx->sk : NULL); +} + +static const struct bpf_func_proto bpf_get_netns_cookie_sock_ops_proto = { + .func = bpf_get_netns_cookie_sock_ops, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX_OR_NULL, +}; + +BPF_CALL_1(bpf_get_netns_cookie_sk_msg, struct sk_msg *, ctx) +{ + return __bpf_get_netns_cookie(ctx ? 
ctx->sk : NULL); +} + +static const struct bpf_func_proto bpf_get_netns_cookie_sk_msg_proto = { + .func = bpf_get_netns_cookie_sk_msg, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX_OR_NULL, +}; + +BPF_CALL_1(bpf_get_socket_uid, struct sk_buff *, skb) +{ + struct sock *sk = sk_to_full_sk(skb->sk); + kuid_t kuid; + + if (!sk || !sk_fullsock(sk)) + return overflowuid; + kuid = sock_net_uid(sock_net(sk), sk); + return from_kuid_munged(sock_net(sk)->user_ns, kuid); +} + +static const struct bpf_func_proto bpf_get_socket_uid_proto = { + .func = bpf_get_socket_uid, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +static int sol_socket_sockopt(struct sock *sk, int optname, + char *optval, int *optlen, + bool getopt) +{ + switch (optname) { + case SO_REUSEADDR: + case SO_SNDBUF: + case SO_RCVBUF: + case SO_KEEPALIVE: + case SO_PRIORITY: + case SO_REUSEPORT: + case SO_RCVLOWAT: + case SO_MARK: + case SO_MAX_PACING_RATE: + case SO_BINDTOIFINDEX: + case SO_TXREHASH: + if (*optlen != sizeof(int)) + return -EINVAL; + break; + case SO_BINDTODEVICE: + break; + default: + return -EINVAL; + } + + if (getopt) { + if (optname == SO_BINDTODEVICE) + return -EINVAL; + return sk_getsockopt(sk, SOL_SOCKET, optname, + KERNEL_SOCKPTR(optval), + KERNEL_SOCKPTR(optlen)); + } + + return sk_setsockopt(sk, SOL_SOCKET, optname, + KERNEL_SOCKPTR(optval), *optlen); +} + +static int bpf_sol_tcp_setsockopt(struct sock *sk, int optname, + char *optval, int optlen) +{ + struct tcp_sock *tp = tcp_sk(sk); + unsigned long timeout; + int val; + + if (optlen != sizeof(int)) + return -EINVAL; + + val = *(int *)optval; + + /* Only some options are supported */ + switch (optname) { + case TCP_BPF_IW: + if (val <= 0 || tp->data_segs_out > tp->syn_data) + return -EINVAL; + tcp_snd_cwnd_set(tp, val); + break; + case TCP_BPF_SNDCWND_CLAMP: + if (val <= 0) + return -EINVAL; + tp->snd_cwnd_clamp = val; + tp->snd_ssthresh = val; + break; + case TCP_BPF_DELACK_MAX: + timeout = usecs_to_jiffies(val); + if (timeout > TCP_DELACK_MAX || + timeout < TCP_TIMEOUT_MIN) + return -EINVAL; + inet_csk(sk)->icsk_delack_max = timeout; + break; + case TCP_BPF_RTO_MIN: + timeout = usecs_to_jiffies(val); + if (timeout > TCP_RTO_MIN || + timeout < TCP_TIMEOUT_MIN) + return -EINVAL; + inet_csk(sk)->icsk_rto_min = timeout; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int sol_tcp_sockopt_congestion(struct sock *sk, char *optval, + int *optlen, bool getopt) +{ + struct tcp_sock *tp; + int ret; + + if (*optlen < 2) + return -EINVAL; + + if (getopt) { + if (!inet_csk(sk)->icsk_ca_ops) + return -EINVAL; + /* BPF expects NULL-terminated tcp-cc string */ + optval[--(*optlen)] = '\0'; + return do_tcp_getsockopt(sk, SOL_TCP, TCP_CONGESTION, + KERNEL_SOCKPTR(optval), + KERNEL_SOCKPTR(optlen)); + } + + /* "cdg" is the only cc that alloc a ptr + * in inet_csk_ca area. The bpf-tcp-cc may + * overwrite this ptr after switching to cdg. + */ + if (*optlen >= sizeof("cdg") - 1 && !strncmp("cdg", optval, *optlen)) + return -ENOTSUPP; + + /* It stops this looping + * + * .init => bpf_setsockopt(tcp_cc) => .init => + * bpf_setsockopt(tcp_cc)" => .init => .... + * + * The second bpf_setsockopt(tcp_cc) is not allowed + * in order to break the loop when both .init + * are the same bpf prog. + * + * This applies even the second bpf_setsockopt(tcp_cc) + * does not cause a loop. This limits only the first + * '.init' can call bpf_setsockopt(TCP_CONGESTION) to + * pick a fallback cc (eg. 
peer does not support ECN) + * and the second '.init' cannot fallback to + * another. + */ + tp = tcp_sk(sk); + if (tp->bpf_chg_cc_inprogress) + return -EBUSY; + + tp->bpf_chg_cc_inprogress = 1; + ret = do_tcp_setsockopt(sk, SOL_TCP, TCP_CONGESTION, + KERNEL_SOCKPTR(optval), *optlen); + tp->bpf_chg_cc_inprogress = 0; + return ret; +} + +static int sol_tcp_sockopt(struct sock *sk, int optname, + char *optval, int *optlen, + bool getopt) +{ + if (sk->sk_prot->setsockopt != tcp_setsockopt) + return -EINVAL; + + switch (optname) { + case TCP_NODELAY: + case TCP_MAXSEG: + case TCP_KEEPIDLE: + case TCP_KEEPINTVL: + case TCP_KEEPCNT: + case TCP_SYNCNT: + case TCP_WINDOW_CLAMP: + case TCP_THIN_LINEAR_TIMEOUTS: + case TCP_USER_TIMEOUT: + case TCP_NOTSENT_LOWAT: + case TCP_SAVE_SYN: + if (*optlen != sizeof(int)) + return -EINVAL; + break; + case TCP_CONGESTION: + return sol_tcp_sockopt_congestion(sk, optval, optlen, getopt); + case TCP_SAVED_SYN: + if (*optlen < 1) + return -EINVAL; + break; + default: + if (getopt) + return -EINVAL; + return bpf_sol_tcp_setsockopt(sk, optname, optval, *optlen); + } + + if (getopt) { + if (optname == TCP_SAVED_SYN) { + struct tcp_sock *tp = tcp_sk(sk); + + if (!tp->saved_syn || + *optlen > tcp_saved_syn_len(tp->saved_syn)) + return -EINVAL; + memcpy(optval, tp->saved_syn->data, *optlen); + /* It cannot free tp->saved_syn here because it + * does not know if the user space still needs it. + */ + return 0; + } + + return do_tcp_getsockopt(sk, SOL_TCP, optname, + KERNEL_SOCKPTR(optval), + KERNEL_SOCKPTR(optlen)); + } + + return do_tcp_setsockopt(sk, SOL_TCP, optname, + KERNEL_SOCKPTR(optval), *optlen); +} + +static int sol_ip_sockopt(struct sock *sk, int optname, + char *optval, int *optlen, + bool getopt) +{ + if (sk->sk_family != AF_INET) + return -EINVAL; + + switch (optname) { + case IP_TOS: + if (*optlen != sizeof(int)) + return -EINVAL; + break; + default: + return -EINVAL; + } + + if (getopt) + return do_ip_getsockopt(sk, SOL_IP, optname, + KERNEL_SOCKPTR(optval), + KERNEL_SOCKPTR(optlen)); + + return do_ip_setsockopt(sk, SOL_IP, optname, + KERNEL_SOCKPTR(optval), *optlen); +} + +static int sol_ipv6_sockopt(struct sock *sk, int optname, + char *optval, int *optlen, + bool getopt) +{ + if (sk->sk_family != AF_INET6) + return -EINVAL; + + switch (optname) { + case IPV6_TCLASS: + case IPV6_AUTOFLOWLABEL: + if (*optlen != sizeof(int)) + return -EINVAL; + break; + default: + return -EINVAL; + } + + if (getopt) + return ipv6_bpf_stub->ipv6_getsockopt(sk, SOL_IPV6, optname, + KERNEL_SOCKPTR(optval), + KERNEL_SOCKPTR(optlen)); + + return ipv6_bpf_stub->ipv6_setsockopt(sk, SOL_IPV6, optname, + KERNEL_SOCKPTR(optval), *optlen); +} + +static int __bpf_setsockopt(struct sock *sk, int level, int optname, + char *optval, int optlen) +{ + if (!sk_fullsock(sk)) + return -EINVAL; + + if (level == SOL_SOCKET) + return sol_socket_sockopt(sk, optname, optval, &optlen, false); + else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) + return sol_ip_sockopt(sk, optname, optval, &optlen, false); + else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6) + return sol_ipv6_sockopt(sk, optname, optval, &optlen, false); + else if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP) + return sol_tcp_sockopt(sk, optname, optval, &optlen, false); + + return -EINVAL; +} + +static int _bpf_setsockopt(struct sock *sk, int level, int optname, + char *optval, int optlen) +{ + if (sk_fullsock(sk)) + sock_owned_by_me(sk); + return __bpf_setsockopt(sk, level, optname, optval, optlen); +} + +static int 
__bpf_getsockopt(struct sock *sk, int level, int optname, + char *optval, int optlen) +{ + int err, saved_optlen = optlen; + + if (!sk_fullsock(sk)) { + err = -EINVAL; + goto done; + } + + if (level == SOL_SOCKET) + err = sol_socket_sockopt(sk, optname, optval, &optlen, true); + else if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP) + err = sol_tcp_sockopt(sk, optname, optval, &optlen, true); + else if (IS_ENABLED(CONFIG_INET) && level == SOL_IP) + err = sol_ip_sockopt(sk, optname, optval, &optlen, true); + else if (IS_ENABLED(CONFIG_IPV6) && level == SOL_IPV6) + err = sol_ipv6_sockopt(sk, optname, optval, &optlen, true); + else + err = -EINVAL; + +done: + if (err) + optlen = 0; + if (optlen < saved_optlen) + memset(optval + optlen, 0, saved_optlen - optlen); + return err; +} + +static int _bpf_getsockopt(struct sock *sk, int level, int optname, + char *optval, int optlen) +{ + if (sk_fullsock(sk)) + sock_owned_by_me(sk); + return __bpf_getsockopt(sk, level, optname, optval, optlen); +} + +BPF_CALL_5(bpf_sk_setsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return _bpf_setsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_sk_setsockopt_proto = { + .func = bpf_sk_setsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_sk_getsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return _bpf_getsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_sk_getsockopt_proto = { + .func = bpf_sk_getsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_UNINIT_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_unlocked_sk_setsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return __bpf_setsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_unlocked_sk_setsockopt_proto = { + .func = bpf_unlocked_sk_setsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_unlocked_sk_getsockopt, struct sock *, sk, int, level, + int, optname, char *, optval, int, optlen) +{ + return __bpf_getsockopt(sk, level, optname, optval, optlen); +} + +const struct bpf_func_proto bpf_unlocked_sk_getsockopt_proto = { + .func = bpf_unlocked_sk_getsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_UNINIT_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_sock_addr_setsockopt, struct bpf_sock_addr_kern *, ctx, + int, level, int, optname, char *, optval, int, optlen) +{ + return _bpf_setsockopt(ctx->sk, level, optname, optval, optlen); +} + +static const struct bpf_func_proto bpf_sock_addr_setsockopt_proto = { + .func = bpf_sock_addr_setsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, 
+}; + +BPF_CALL_5(bpf_sock_addr_getsockopt, struct bpf_sock_addr_kern *, ctx, + int, level, int, optname, char *, optval, int, optlen) +{ + return _bpf_getsockopt(ctx->sk, level, optname, optval, optlen); +} + +static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = { + .func = bpf_sock_addr_getsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_UNINIT_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock, + int, level, int, optname, char *, optval, int, optlen) +{ + return _bpf_setsockopt(bpf_sock->sk, level, optname, optval, optlen); +} + +static const struct bpf_func_proto bpf_sock_ops_setsockopt_proto = { + .func = bpf_sock_ops_setsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, +}; + +static int bpf_sock_ops_get_syn(struct bpf_sock_ops_kern *bpf_sock, + int optname, const u8 **start) +{ + struct sk_buff *syn_skb = bpf_sock->syn_skb; + const u8 *hdr_start; + int ret; + + if (syn_skb) { + /* sk is a request_sock here */ + + if (optname == TCP_BPF_SYN) { + hdr_start = syn_skb->data; + ret = tcp_hdrlen(syn_skb); + } else if (optname == TCP_BPF_SYN_IP) { + hdr_start = skb_network_header(syn_skb); + ret = skb_network_header_len(syn_skb) + + tcp_hdrlen(syn_skb); + } else { + /* optname == TCP_BPF_SYN_MAC */ + hdr_start = skb_mac_header(syn_skb); + ret = skb_mac_header_len(syn_skb) + + skb_network_header_len(syn_skb) + + tcp_hdrlen(syn_skb); + } + } else { + struct sock *sk = bpf_sock->sk; + struct saved_syn *saved_syn; + + if (sk->sk_state == TCP_NEW_SYN_RECV) + /* synack retransmit. bpf_sock->syn_skb will + * not be available. It has to resort to + * saved_syn (if it is saved). 
+ */ + saved_syn = inet_reqsk(sk)->saved_syn; + else + saved_syn = tcp_sk(sk)->saved_syn; + + if (!saved_syn) + return -ENOENT; + + if (optname == TCP_BPF_SYN) { + hdr_start = saved_syn->data + + saved_syn->mac_hdrlen + + saved_syn->network_hdrlen; + ret = saved_syn->tcp_hdrlen; + } else if (optname == TCP_BPF_SYN_IP) { + hdr_start = saved_syn->data + + saved_syn->mac_hdrlen; + ret = saved_syn->network_hdrlen + + saved_syn->tcp_hdrlen; + } else { + /* optname == TCP_BPF_SYN_MAC */ + + /* TCP_SAVE_SYN may not have saved the mac hdr */ + if (!saved_syn->mac_hdrlen) + return -ENOENT; + + hdr_start = saved_syn->data; + ret = saved_syn->mac_hdrlen + + saved_syn->network_hdrlen + + saved_syn->tcp_hdrlen; + } + } + + *start = hdr_start; + return ret; +} + +BPF_CALL_5(bpf_sock_ops_getsockopt, struct bpf_sock_ops_kern *, bpf_sock, + int, level, int, optname, char *, optval, int, optlen) +{ + if (IS_ENABLED(CONFIG_INET) && level == SOL_TCP && + optname >= TCP_BPF_SYN && optname <= TCP_BPF_SYN_MAC) { + int ret, copy_len = 0; + const u8 *start; + + ret = bpf_sock_ops_get_syn(bpf_sock, optname, &start); + if (ret > 0) { + copy_len = ret; + if (optlen < copy_len) { + copy_len = optlen; + ret = -ENOSPC; + } + + memcpy(optval, start, copy_len); + } + + /* Zero out unused buffer at the end */ + memset(optval + copy_len, 0, optlen - copy_len); + + return ret; + } + + return _bpf_getsockopt(bpf_sock->sk, level, optname, optval, optlen); +} + +static const struct bpf_func_proto bpf_sock_ops_getsockopt_proto = { + .func = bpf_sock_ops_getsockopt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_PTR_TO_UNINIT_MEM, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_2(bpf_sock_ops_cb_flags_set, struct bpf_sock_ops_kern *, bpf_sock, + int, argval) +{ + struct sock *sk = bpf_sock->sk; + int val = argval & BPF_SOCK_OPS_ALL_CB_FLAGS; + + if (!IS_ENABLED(CONFIG_INET) || !sk_fullsock(sk)) + return -EINVAL; + + tcp_sk(sk)->bpf_sock_ops_cb_flags = val; + + return argval & (~BPF_SOCK_OPS_ALL_CB_FLAGS); +} + +static const struct bpf_func_proto bpf_sock_ops_cb_flags_set_proto = { + .func = bpf_sock_ops_cb_flags_set, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, +}; + +const struct ipv6_bpf_stub *ipv6_bpf_stub __read_mostly; +EXPORT_SYMBOL_GPL(ipv6_bpf_stub); + +BPF_CALL_3(bpf_bind, struct bpf_sock_addr_kern *, ctx, struct sockaddr *, addr, + int, addr_len) +{ +#ifdef CONFIG_INET + struct sock *sk = ctx->sk; + u32 flags = BIND_FROM_BPF; + int err; + + err = -EINVAL; + if (addr_len < offsetofend(struct sockaddr, sa_family)) + return err; + if (addr->sa_family == AF_INET) { + if (addr_len < sizeof(struct sockaddr_in)) + return err; + if (((struct sockaddr_in *)addr)->sin_port == htons(0)) + flags |= BIND_FORCE_ADDRESS_NO_PORT; + return __inet_bind(sk, addr, addr_len, flags); +#if IS_ENABLED(CONFIG_IPV6) + } else if (addr->sa_family == AF_INET6) { + if (addr_len < SIN6_LEN_RFC2133) + return err; + if (((struct sockaddr_in6 *)addr)->sin6_port == htons(0)) + flags |= BIND_FORCE_ADDRESS_NO_PORT; + /* ipv6_bpf_stub cannot be NULL, since it's called from + * bpf_cgroup_inet6_connect hook and ipv6 is already loaded + */ + return ipv6_bpf_stub->inet6_bind(sk, addr, addr_len, flags); +#endif /* CONFIG_IPV6 */ + } +#endif /* CONFIG_INET */ + + return -EAFNOSUPPORT; +} + +static const struct bpf_func_proto bpf_bind_proto = { + .func = bpf_bind, + .gpl_only = false, + 
.ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, +}; + +#ifdef CONFIG_XFRM +BPF_CALL_5(bpf_skb_get_xfrm_state, struct sk_buff *, skb, u32, index, + struct bpf_xfrm_state *, to, u32, size, u64, flags) +{ + const struct sec_path *sp = skb_sec_path(skb); + const struct xfrm_state *x; + + if (!sp || unlikely(index >= sp->len || flags)) + goto err_clear; + + x = sp->xvec[index]; + + if (unlikely(size != sizeof(struct bpf_xfrm_state))) + goto err_clear; + + to->reqid = x->props.reqid; + to->spi = x->id.spi; + to->family = x->props.family; + to->ext = 0; + + if (to->family == AF_INET6) { + memcpy(to->remote_ipv6, x->props.saddr.a6, + sizeof(to->remote_ipv6)); + } else { + to->remote_ipv4 = x->props.saddr.a4; + memset(&to->remote_ipv6[1], 0, sizeof(__u32) * 3); + } + + return 0; +err_clear: + memset(to, 0, size); + return -EINVAL; +} + +static const struct bpf_func_proto bpf_skb_get_xfrm_state_proto = { + .func = bpf_skb_get_xfrm_state, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, + .arg5_type = ARG_ANYTHING, +}; +#endif + +#if IS_ENABLED(CONFIG_INET) || IS_ENABLED(CONFIG_IPV6) +static int bpf_fib_set_fwd_params(struct bpf_fib_lookup *params, u32 mtu) +{ + params->h_vlan_TCI = 0; + params->h_vlan_proto = 0; + if (mtu) + params->mtu_result = mtu; /* union with tot_len */ + + return 0; +} +#endif + +#if IS_ENABLED(CONFIG_INET) +static int bpf_ipv4_fib_lookup(struct net *net, struct bpf_fib_lookup *params, + u32 flags, bool check_mtu) +{ + struct fib_nh_common *nhc; + struct in_device *in_dev; + struct neighbour *neigh; + struct net_device *dev; + struct fib_result res; + struct flowi4 fl4; + u32 mtu = 0; + int err; + + dev = dev_get_by_index_rcu(net, params->ifindex); + if (unlikely(!dev)) + return -ENODEV; + + /* verify forwarding is enabled on this interface */ + in_dev = __in_dev_get_rcu(dev); + if (unlikely(!in_dev || !IN_DEV_FORWARD(in_dev))) + return BPF_FIB_LKUP_RET_FWD_DISABLED; + + if (flags & BPF_FIB_LOOKUP_OUTPUT) { + fl4.flowi4_iif = 1; + fl4.flowi4_oif = params->ifindex; + } else { + fl4.flowi4_iif = params->ifindex; + fl4.flowi4_oif = 0; + } + fl4.flowi4_tos = params->tos & IPTOS_RT_MASK; + fl4.flowi4_scope = RT_SCOPE_UNIVERSE; + fl4.flowi4_flags = 0; + + fl4.flowi4_proto = params->l4_protocol; + fl4.daddr = params->ipv4_dst; + fl4.saddr = params->ipv4_src; + fl4.fl4_sport = params->sport; + fl4.fl4_dport = params->dport; + fl4.flowi4_multipath_hash = 0; + + if (flags & BPF_FIB_LOOKUP_DIRECT) { + u32 tbid = l3mdev_fib_table_rcu(dev) ? 
: RT_TABLE_MAIN; + struct fib_table *tb; + + tb = fib_get_table(net, tbid); + if (unlikely(!tb)) + return BPF_FIB_LKUP_RET_NOT_FWDED; + + err = fib_table_lookup(tb, &fl4, &res, FIB_LOOKUP_NOREF); + } else { + fl4.flowi4_mark = 0; + fl4.flowi4_secid = 0; + fl4.flowi4_tun_key.tun_id = 0; + fl4.flowi4_uid = sock_net_uid(net, NULL); + + err = fib_lookup(net, &fl4, &res, FIB_LOOKUP_NOREF); + } + + if (err) { + /* map fib lookup errors to RTN_ type */ + if (err == -EINVAL) + return BPF_FIB_LKUP_RET_BLACKHOLE; + if (err == -EHOSTUNREACH) + return BPF_FIB_LKUP_RET_UNREACHABLE; + if (err == -EACCES) + return BPF_FIB_LKUP_RET_PROHIBIT; + + return BPF_FIB_LKUP_RET_NOT_FWDED; + } + + if (res.type != RTN_UNICAST) + return BPF_FIB_LKUP_RET_NOT_FWDED; + + if (fib_info_num_path(res.fi) > 1) + fib_select_path(net, &res, &fl4, NULL); + + if (check_mtu) { + mtu = ip_mtu_from_fib_result(&res, params->ipv4_dst); + if (params->tot_len > mtu) { + params->mtu_result = mtu; /* union with tot_len */ + return BPF_FIB_LKUP_RET_FRAG_NEEDED; + } + } + + nhc = res.nhc; + + /* do not handle lwt encaps right now */ + if (nhc->nhc_lwtstate) + return BPF_FIB_LKUP_RET_UNSUPP_LWT; + + dev = nhc->nhc_dev; + + params->rt_metric = res.fi->fib_priority; + params->ifindex = dev->ifindex; + + /* xdp and cls_bpf programs are run in RCU-bh so + * rcu_read_lock_bh is not needed here + */ + if (likely(nhc->nhc_gw_family != AF_INET6)) { + if (nhc->nhc_gw_family) + params->ipv4_dst = nhc->nhc_gw.ipv4; + } else { + struct in6_addr *dst = (struct in6_addr *)params->ipv6_dst; + + params->family = AF_INET6; + *dst = nhc->nhc_gw.ipv6; + } + + if (flags & BPF_FIB_LOOKUP_SKIP_NEIGH) + goto set_fwd_params; + + if (likely(nhc->nhc_gw_family != AF_INET6)) + neigh = __ipv4_neigh_lookup_noref(dev, + (__force u32)params->ipv4_dst); + else + neigh = __ipv6_neigh_lookup_noref_stub(dev, params->ipv6_dst); + + if (!neigh || !(READ_ONCE(neigh->nud_state) & NUD_VALID)) + return BPF_FIB_LKUP_RET_NO_NEIGH; + memcpy(params->dmac, neigh->ha, ETH_ALEN); + memcpy(params->smac, dev->dev_addr, ETH_ALEN); + +set_fwd_params: + return bpf_fib_set_fwd_params(params, mtu); +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +static int bpf_ipv6_fib_lookup(struct net *net, struct bpf_fib_lookup *params, + u32 flags, bool check_mtu) +{ + struct in6_addr *src = (struct in6_addr *) params->ipv6_src; + struct in6_addr *dst = (struct in6_addr *) params->ipv6_dst; + struct fib6_result res = {}; + struct neighbour *neigh; + struct net_device *dev; + struct inet6_dev *idev; + struct flowi6 fl6; + int strict = 0; + int oif, err; + u32 mtu = 0; + + /* link local addresses are never forwarded */ + if (rt6_need_strict(dst) || rt6_need_strict(src)) + return BPF_FIB_LKUP_RET_NOT_FWDED; + + dev = dev_get_by_index_rcu(net, params->ifindex); + if (unlikely(!dev)) + return -ENODEV; + + idev = __in6_dev_get_safely(dev); + if (unlikely(!idev || !idev->cnf.forwarding)) + return BPF_FIB_LKUP_RET_FWD_DISABLED; + + if (flags & BPF_FIB_LOOKUP_OUTPUT) { + fl6.flowi6_iif = 1; + oif = fl6.flowi6_oif = params->ifindex; + } else { + oif = fl6.flowi6_iif = params->ifindex; + fl6.flowi6_oif = 0; + strict = RT6_LOOKUP_F_HAS_SADDR; + } + fl6.flowlabel = params->flowinfo; + fl6.flowi6_scope = 0; + fl6.flowi6_flags = 0; + fl6.mp_hash = 0; + + fl6.flowi6_proto = params->l4_protocol; + fl6.daddr = *dst; + fl6.saddr = *src; + fl6.fl6_sport = params->sport; + fl6.fl6_dport = params->dport; + + if (flags & BPF_FIB_LOOKUP_DIRECT) { + u32 tbid = l3mdev_fib_table_rcu(dev) ? 
: RT_TABLE_MAIN; + struct fib6_table *tb; + + tb = ipv6_stub->fib6_get_table(net, tbid); + if (unlikely(!tb)) + return BPF_FIB_LKUP_RET_NOT_FWDED; + + err = ipv6_stub->fib6_table_lookup(net, tb, oif, &fl6, &res, + strict); + } else { + fl6.flowi6_mark = 0; + fl6.flowi6_secid = 0; + fl6.flowi6_tun_key.tun_id = 0; + fl6.flowi6_uid = sock_net_uid(net, NULL); + + err = ipv6_stub->fib6_lookup(net, oif, &fl6, &res, strict); + } + + if (unlikely(err || IS_ERR_OR_NULL(res.f6i) || + res.f6i == net->ipv6.fib6_null_entry)) + return BPF_FIB_LKUP_RET_NOT_FWDED; + + switch (res.fib6_type) { + /* only unicast is forwarded */ + case RTN_UNICAST: + break; + case RTN_BLACKHOLE: + return BPF_FIB_LKUP_RET_BLACKHOLE; + case RTN_UNREACHABLE: + return BPF_FIB_LKUP_RET_UNREACHABLE; + case RTN_PROHIBIT: + return BPF_FIB_LKUP_RET_PROHIBIT; + default: + return BPF_FIB_LKUP_RET_NOT_FWDED; + } + + ipv6_stub->fib6_select_path(net, &res, &fl6, fl6.flowi6_oif, + fl6.flowi6_oif != 0, NULL, strict); + + if (check_mtu) { + mtu = ipv6_stub->ip6_mtu_from_fib6(&res, dst, src); + if (params->tot_len > mtu) { + params->mtu_result = mtu; /* union with tot_len */ + return BPF_FIB_LKUP_RET_FRAG_NEEDED; + } + } + + if (res.nh->fib_nh_lws) + return BPF_FIB_LKUP_RET_UNSUPP_LWT; + + if (res.nh->fib_nh_gw_family) + *dst = res.nh->fib_nh_gw6; + + dev = res.nh->fib_nh_dev; + params->rt_metric = res.f6i->fib6_metric; + params->ifindex = dev->ifindex; + + if (flags & BPF_FIB_LOOKUP_SKIP_NEIGH) + goto set_fwd_params; + + /* xdp and cls_bpf programs are run in RCU-bh so rcu_read_lock_bh is + * not needed here. + */ + neigh = __ipv6_neigh_lookup_noref_stub(dev, dst); + if (!neigh || !(READ_ONCE(neigh->nud_state) & NUD_VALID)) + return BPF_FIB_LKUP_RET_NO_NEIGH; + memcpy(params->dmac, neigh->ha, ETH_ALEN); + memcpy(params->smac, dev->dev_addr, ETH_ALEN); + +set_fwd_params: + return bpf_fib_set_fwd_params(params, mtu); +} +#endif + +#define BPF_FIB_LOOKUP_MASK (BPF_FIB_LOOKUP_DIRECT | BPF_FIB_LOOKUP_OUTPUT | \ + BPF_FIB_LOOKUP_SKIP_NEIGH) + +BPF_CALL_4(bpf_xdp_fib_lookup, struct xdp_buff *, ctx, + struct bpf_fib_lookup *, params, int, plen, u32, flags) +{ + if (plen < sizeof(*params)) + return -EINVAL; + + if (flags & ~BPF_FIB_LOOKUP_MASK) + return -EINVAL; + + switch (params->family) { +#if IS_ENABLED(CONFIG_INET) + case AF_INET: + return bpf_ipv4_fib_lookup(dev_net(ctx->rxq->dev), params, + flags, true); +#endif +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + return bpf_ipv6_fib_lookup(dev_net(ctx->rxq->dev), params, + flags, true); +#endif + } + return -EAFNOSUPPORT; +} + +static const struct bpf_func_proto bpf_xdp_fib_lookup_proto = { + .func = bpf_xdp_fib_lookup, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_skb_fib_lookup, struct sk_buff *, skb, + struct bpf_fib_lookup *, params, int, plen, u32, flags) +{ + struct net *net = dev_net(skb->dev); + int rc = -EAFNOSUPPORT; + bool check_mtu = false; + + if (plen < sizeof(*params)) + return -EINVAL; + + if (flags & ~BPF_FIB_LOOKUP_MASK) + return -EINVAL; + + if (params->tot_len) + check_mtu = true; + + switch (params->family) { +#if IS_ENABLED(CONFIG_INET) + case AF_INET: + rc = bpf_ipv4_fib_lookup(net, params, flags, check_mtu); + break; +#endif +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + rc = bpf_ipv6_fib_lookup(net, params, flags, check_mtu); + break; +#endif + } + + if (rc == BPF_FIB_LKUP_RET_SUCCESS && !check_mtu) { + struct net_device *dev; + 
+ /* When tot_len isn't provided by user, check skb + * against MTU of FIB lookup resulting net_device + */ + dev = dev_get_by_index_rcu(net, params->ifindex); + if (!is_skb_forwardable(dev, skb)) + rc = BPF_FIB_LKUP_RET_FRAG_NEEDED; + + params->mtu_result = dev->mtu; /* union with tot_len */ + } + + return rc; +} + +static const struct bpf_func_proto bpf_skb_fib_lookup_proto = { + .func = bpf_skb_fib_lookup, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +static struct net_device *__dev_via_ifindex(struct net_device *dev_curr, + u32 ifindex) +{ + struct net *netns = dev_net(dev_curr); + + /* Non-redirect use-cases can use ifindex=0 and save ifindex lookup */ + if (ifindex == 0) + return dev_curr; + + return dev_get_by_index_rcu(netns, ifindex); +} + +BPF_CALL_5(bpf_skb_check_mtu, struct sk_buff *, skb, + u32, ifindex, u32 *, mtu_len, s32, len_diff, u64, flags) +{ + int ret = BPF_MTU_CHK_RET_FRAG_NEEDED; + struct net_device *dev = skb->dev; + int skb_len, dev_len; + int mtu; + + if (unlikely(flags & ~(BPF_MTU_CHK_SEGS))) + return -EINVAL; + + if (unlikely(flags & BPF_MTU_CHK_SEGS && (len_diff || *mtu_len))) + return -EINVAL; + + dev = __dev_via_ifindex(dev, ifindex); + if (unlikely(!dev)) + return -ENODEV; + + mtu = READ_ONCE(dev->mtu); + + dev_len = mtu + dev->hard_header_len; + + /* If set use *mtu_len as input, L3 as iph->tot_len (like fib_lookup) */ + skb_len = *mtu_len ? *mtu_len + dev->hard_header_len : skb->len; + + skb_len += len_diff; /* minus result pass check */ + if (skb_len <= dev_len) { + ret = BPF_MTU_CHK_RET_SUCCESS; + goto out; + } + /* At this point, skb->len exceed MTU, but as it include length of all + * segments, it can still be below MTU. The SKB can possibly get + * re-segmented in transmit path (see validate_xmit_skb). Thus, user + * must choose if segs are to be MTU checked. 
+ */ + if (skb_is_gso(skb)) { + ret = BPF_MTU_CHK_RET_SUCCESS; + + if (flags & BPF_MTU_CHK_SEGS && + !skb_gso_validate_network_len(skb, mtu)) + ret = BPF_MTU_CHK_RET_SEGS_TOOBIG; + } +out: + /* BPF verifier guarantees valid pointer */ + *mtu_len = mtu; + + return ret; +} + +BPF_CALL_5(bpf_xdp_check_mtu, struct xdp_buff *, xdp, + u32, ifindex, u32 *, mtu_len, s32, len_diff, u64, flags) +{ + struct net_device *dev = xdp->rxq->dev; + int xdp_len = xdp->data_end - xdp->data; + int ret = BPF_MTU_CHK_RET_SUCCESS; + int mtu, dev_len; + + /* XDP variant doesn't support multi-buffer segment check (yet) */ + if (unlikely(flags)) + return -EINVAL; + + dev = __dev_via_ifindex(dev, ifindex); + if (unlikely(!dev)) + return -ENODEV; + + mtu = READ_ONCE(dev->mtu); + + /* Add L2-header as dev MTU is L3 size */ + dev_len = mtu + dev->hard_header_len; + + /* Use *mtu_len as input, L3 as iph->tot_len (like fib_lookup) */ + if (*mtu_len) + xdp_len = *mtu_len + dev->hard_header_len; + + xdp_len += len_diff; /* minus result pass check */ + if (xdp_len > dev_len) + ret = BPF_MTU_CHK_RET_FRAG_NEEDED; + + /* BPF verifier guarantees valid pointer */ + *mtu_len = mtu; + + return ret; +} + +static const struct bpf_func_proto bpf_skb_check_mtu_proto = { + .func = bpf_skb_check_mtu, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_INT, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +static const struct bpf_func_proto bpf_xdp_check_mtu_proto = { + .func = bpf_xdp_check_mtu, + .gpl_only = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_INT, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) +static int bpf_push_seg6_encap(struct sk_buff *skb, u32 type, void *hdr, u32 len) +{ + int err; + struct ipv6_sr_hdr *srh = (struct ipv6_sr_hdr *)hdr; + + if (!seg6_validate_srh(srh, len, false)) + return -EINVAL; + + switch (type) { + case BPF_LWT_ENCAP_SEG6_INLINE: + if (skb->protocol != htons(ETH_P_IPV6)) + return -EBADMSG; + + err = seg6_do_srh_inline(skb, srh); + break; + case BPF_LWT_ENCAP_SEG6: + skb_reset_inner_headers(skb); + skb->encapsulation = 1; + err = seg6_do_srh_encap(skb, srh, IPPROTO_IPV6); + break; + default: + return -EINVAL; + } + + bpf_compute_data_pointers(skb); + if (err) + return err; + + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + return seg6_lookup_nexthop(skb, NULL, 0); +} +#endif /* CONFIG_IPV6_SEG6_BPF */ + +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) +static int bpf_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, + bool ingress) +{ + return bpf_lwt_push_ip_encap(skb, hdr, len, ingress); +} +#endif + +BPF_CALL_4(bpf_lwt_in_push_encap, struct sk_buff *, skb, u32, type, void *, hdr, + u32, len) +{ + switch (type) { +#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) + case BPF_LWT_ENCAP_SEG6: + case BPF_LWT_ENCAP_SEG6_INLINE: + return bpf_push_seg6_encap(skb, type, hdr, len); +#endif +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) + case BPF_LWT_ENCAP_IP: + return bpf_push_ip_encap(skb, hdr, len, true /* ingress */); +#endif + default: + return -EINVAL; + } +} + +BPF_CALL_4(bpf_lwt_xmit_push_encap, struct sk_buff *, skb, u32, type, + void *, hdr, u32, len) +{ + switch (type) { +#if IS_ENABLED(CONFIG_LWTUNNEL_BPF) + case BPF_LWT_ENCAP_IP: + return bpf_push_ip_encap(skb, hdr, len, false /* egress */); +#endif + default: + return -EINVAL; + } +} + +static const struct bpf_func_proto 
bpf_lwt_in_push_encap_proto = { + .func = bpf_lwt_in_push_encap, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE +}; + +static const struct bpf_func_proto bpf_lwt_xmit_push_encap_proto = { + .func = bpf_lwt_xmit_push_encap, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE +}; + +#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) +BPF_CALL_4(bpf_lwt_seg6_store_bytes, struct sk_buff *, skb, u32, offset, + const void *, from, u32, len) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; + void *srh_tlvs, *srh_end, *ptr; + int srhoff = 0; + + if (srh == NULL) + return -EINVAL; + + srh_tlvs = (void *)((char *)srh + ((srh->first_segment + 1) << 4)); + srh_end = (void *)((char *)srh + sizeof(*srh) + srh_state->hdrlen); + + ptr = skb->data + offset; + if (ptr >= srh_tlvs && ptr + len <= srh_end) + srh_state->valid = false; + else if (ptr < (void *)&srh->flags || + ptr + len > (void *)&srh->segments) + return -EFAULT; + + if (unlikely(bpf_try_make_writable(skb, offset + len))) + return -EFAULT; + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return -EINVAL; + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + + memcpy(skb->data + offset, from, len); + return 0; +} + +static const struct bpf_func_proto bpf_lwt_seg6_store_bytes_proto = { + .func = bpf_lwt_seg6_store_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE +}; + +static void bpf_update_srh_state(struct sk_buff *skb) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + int srhoff = 0; + + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) { + srh_state->srh = NULL; + } else { + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + srh_state->hdrlen = srh_state->srh->hdrlen << 3; + srh_state->valid = true; + } +} + +BPF_CALL_4(bpf_lwt_seg6_action, struct sk_buff *, skb, + u32, action, void *, param, u32, param_len) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + int hdroff = 0; + int err; + + switch (action) { + case SEG6_LOCAL_ACTION_END_X: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + if (param_len != sizeof(struct in6_addr)) + return -EINVAL; + return seg6_lookup_nexthop(skb, (struct in6_addr *)param, 0); + case SEG6_LOCAL_ACTION_END_T: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + if (param_len != sizeof(int)) + return -EINVAL; + return seg6_lookup_nexthop(skb, NULL, *(int *)param); + case SEG6_LOCAL_ACTION_END_DT6: + if (!seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + if (param_len != sizeof(int)) + return -EINVAL; + + if (ipv6_find_hdr(skb, &hdroff, IPPROTO_IPV6, NULL, NULL) < 0) + return -EBADMSG; + if (!pskb_pull(skb, hdroff)) + return -EBADMSG; + + skb_postpull_rcsum(skb, skb_network_header(skb), hdroff); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb->encapsulation = 0; + + bpf_compute_data_pointers(skb); + bpf_update_srh_state(skb); + return seg6_lookup_nexthop(skb, NULL, *(int *)param); + case SEG6_LOCAL_ACTION_END_B6: + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + err = 
bpf_push_seg6_encap(skb, BPF_LWT_ENCAP_SEG6_INLINE, + param, param_len); + if (!err) + bpf_update_srh_state(skb); + + return err; + case SEG6_LOCAL_ACTION_END_B6_ENCAP: + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) + return -EBADMSG; + err = bpf_push_seg6_encap(skb, BPF_LWT_ENCAP_SEG6, + param, param_len); + if (!err) + bpf_update_srh_state(skb); + + return err; + default: + return -EINVAL; + } +} + +static const struct bpf_func_proto bpf_lwt_seg6_action_proto = { + .func = bpf_lwt_seg6_action, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg4_type = ARG_CONST_SIZE +}; + +BPF_CALL_3(bpf_lwt_seg6_adjust_srh, struct sk_buff *, skb, u32, offset, + s32, len) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; + void *srh_end, *srh_tlvs, *ptr; + struct ipv6hdr *hdr; + int srhoff = 0; + int ret; + + if (unlikely(srh == NULL)) + return -EINVAL; + + srh_tlvs = (void *)((unsigned char *)srh + sizeof(*srh) + + ((srh->first_segment + 1) << 4)); + srh_end = (void *)((unsigned char *)srh + sizeof(*srh) + + srh_state->hdrlen); + ptr = skb->data + offset; + + if (unlikely(ptr < srh_tlvs || ptr > srh_end)) + return -EFAULT; + if (unlikely(len < 0 && (void *)((char *)ptr - len) > srh_end)) + return -EFAULT; + + if (len > 0) { + ret = skb_cow_head(skb, len); + if (unlikely(ret < 0)) + return ret; + + ret = bpf_skb_net_hdr_push(skb, offset, len); + } else { + ret = bpf_skb_net_hdr_pop(skb, offset, -1 * len); + } + + bpf_compute_data_pointers(skb); + if (unlikely(ret < 0)) + return ret; + + hdr = (struct ipv6hdr *)skb->data; + hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return -EINVAL; + srh_state->srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + srh_state->hdrlen += len; + srh_state->valid = false; + return 0; +} + +static const struct bpf_func_proto bpf_lwt_seg6_adjust_srh_proto = { + .func = bpf_lwt_seg6_adjust_srh, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; +#endif /* CONFIG_IPV6_SEG6_BPF */ + +#ifdef CONFIG_INET +static struct sock *sk_lookup(struct net *net, struct bpf_sock_tuple *tuple, + int dif, int sdif, u8 family, u8 proto) +{ + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + bool refcounted = false; + struct sock *sk = NULL; + + if (family == AF_INET) { + __be32 src4 = tuple->ipv4.saddr; + __be32 dst4 = tuple->ipv4.daddr; + + if (proto == IPPROTO_TCP) + sk = __inet_lookup(net, hinfo, NULL, 0, + src4, tuple->ipv4.sport, + dst4, tuple->ipv4.dport, + dif, sdif, &refcounted); + else + sk = __udp4_lib_lookup(net, src4, tuple->ipv4.sport, + dst4, tuple->ipv4.dport, + dif, sdif, &udp_table, NULL); +#if IS_ENABLED(CONFIG_IPV6) + } else { + struct in6_addr *src6 = (struct in6_addr *)&tuple->ipv6.saddr; + struct in6_addr *dst6 = (struct in6_addr *)&tuple->ipv6.daddr; + + if (proto == IPPROTO_TCP) + sk = __inet6_lookup(net, hinfo, NULL, 0, + src6, tuple->ipv6.sport, + dst6, ntohs(tuple->ipv6.dport), + dif, sdif, &refcounted); + else if (likely(ipv6_bpf_stub)) + sk = ipv6_bpf_stub->udp6_lib_lookup(net, + src6, tuple->ipv6.sport, + dst6, tuple->ipv6.dport, + dif, sdif, + &udp_table, NULL); +#endif + } + + if (unlikely(sk && !refcounted && !sock_flag(sk, SOCK_RCU_FREE))) { + WARN_ONCE(1, "Found non-RCU, unreferenced 
socket!"); + sk = NULL; + } + return sk; +} + +/* bpf_skc_lookup performs the core lookup for different types of sockets, + * taking a reference on the socket if it doesn't have the flag SOCK_RCU_FREE. + */ +static struct sock * +__bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len, + struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id, + u64 flags, int sdif) +{ + struct sock *sk = NULL; + struct net *net; + u8 family; + + if (len == sizeof(tuple->ipv4)) + family = AF_INET; + else if (len == sizeof(tuple->ipv6)) + family = AF_INET6; + else + return NULL; + + if (unlikely(flags || !((s32)netns_id < 0 || netns_id <= S32_MAX))) + goto out; + + if (sdif < 0) { + if (family == AF_INET) + sdif = inet_sdif(skb); + else + sdif = inet6_sdif(skb); + } + + if ((s32)netns_id < 0) { + net = caller_net; + sk = sk_lookup(net, tuple, ifindex, sdif, family, proto); + } else { + net = get_net_ns_by_id(caller_net, netns_id); + if (unlikely(!net)) + goto out; + sk = sk_lookup(net, tuple, ifindex, sdif, family, proto); + put_net(net); + } + +out: + return sk; +} + +static struct sock * +__bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len, + struct net *caller_net, u32 ifindex, u8 proto, u64 netns_id, + u64 flags, int sdif) +{ + struct sock *sk = __bpf_skc_lookup(skb, tuple, len, caller_net, + ifindex, proto, netns_id, flags, + sdif); + + if (sk) { + struct sock *sk2 = sk_to_full_sk(sk); + + /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk + * sock refcnt is decremented to prevent a request_sock leak. + */ + if (!sk_fullsock(sk2)) + sk2 = NULL; + if (sk2 != sk) { + sock_gen_put(sk); + /* Ensure there is no need to bump sk2 refcnt */ + if (unlikely(sk2 && !sock_flag(sk2, SOCK_RCU_FREE))) { + WARN_ONCE(1, "Found non-RCU, unreferenced socket!"); + return NULL; + } + sk = sk2; + } + } + + return sk; +} + +static struct sock * +bpf_skc_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len, + u8 proto, u64 netns_id, u64 flags) +{ + struct net *caller_net; + int ifindex; + + if (skb->dev) { + caller_net = dev_net(skb->dev); + ifindex = skb->dev->ifindex; + } else { + caller_net = sock_net(skb->sk); + ifindex = 0; + } + + return __bpf_skc_lookup(skb, tuple, len, caller_net, ifindex, proto, + netns_id, flags, -1); +} + +static struct sock * +bpf_sk_lookup(struct sk_buff *skb, struct bpf_sock_tuple *tuple, u32 len, + u8 proto, u64 netns_id, u64 flags) +{ + struct sock *sk = bpf_skc_lookup(skb, tuple, len, proto, netns_id, + flags); + + if (sk) { + struct sock *sk2 = sk_to_full_sk(sk); + + /* sk_to_full_sk() may return (sk)->rsk_listener, so make sure the original sk + * sock refcnt is decremented to prevent a request_sock leak. 
+ */ + if (!sk_fullsock(sk2)) + sk2 = NULL; + if (sk2 != sk) { + sock_gen_put(sk); + /* Ensure there is no need to bump sk2 refcnt */ + if (unlikely(sk2 && !sock_flag(sk2, SOCK_RCU_FREE))) { + WARN_ONCE(1, "Found non-RCU, unreferenced socket!"); + return NULL; + } + sk = sk2; + } + } + + return sk; +} + +BPF_CALL_5(bpf_skc_lookup_tcp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)bpf_skc_lookup(skb, tuple, len, IPPROTO_TCP, + netns_id, flags); +} + +static const struct bpf_func_proto bpf_skc_lookup_tcp_proto = { + .func = bpf_skc_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCK_COMMON_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sk_lookup_tcp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_TCP, + netns_id, flags); +} + +static const struct bpf_func_proto bpf_sk_lookup_tcp_proto = { + .func = bpf_sk_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sk_lookup_udp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)bpf_sk_lookup(skb, tuple, len, IPPROTO_UDP, + netns_id, flags); +} + +static const struct bpf_func_proto bpf_sk_lookup_udp_proto = { + .func = bpf_sk_lookup_udp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_tc_skc_lookup_tcp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + struct net_device *dev = skb->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_skc_lookup(skb, tuple, len, caller_net, + ifindex, IPPROTO_TCP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_tc_skc_lookup_tcp_proto = { + .func = bpf_tc_skc_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCK_COMMON_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_tc_sk_lookup_tcp, struct sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + struct net_device *dev = skb->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_sk_lookup(skb, tuple, len, caller_net, + ifindex, IPPROTO_TCP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_tc_sk_lookup_tcp_proto = { + .func = bpf_tc_sk_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_tc_sk_lookup_udp, struct 
sk_buff *, skb, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + struct net_device *dev = skb->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_sk_lookup(skb, tuple, len, caller_net, + ifindex, IPPROTO_UDP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_tc_sk_lookup_udp_proto = { + .func = bpf_tc_sk_lookup_udp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_1(bpf_sk_release, struct sock *, sk) +{ + if (sk && sk_is_refcounted(sk)) + sock_gen_put(sk); + return 0; +} + +static const struct bpf_func_proto bpf_sk_release_proto = { + .func = bpf_sk_release, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON | OBJ_RELEASE, +}; + +BPF_CALL_5(bpf_xdp_sk_lookup_udp, struct xdp_buff *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags) +{ + struct net_device *dev = ctx->rxq->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net, + ifindex, IPPROTO_UDP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_xdp_sk_lookup_udp_proto = { + .func = bpf_xdp_sk_lookup_udp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_xdp_skc_lookup_tcp, struct xdp_buff *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags) +{ + struct net_device *dev = ctx->rxq->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_skc_lookup(NULL, tuple, len, caller_net, + ifindex, IPPROTO_TCP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_xdp_skc_lookup_tcp_proto = { + .func = bpf_xdp_skc_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCK_COMMON_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_xdp_sk_lookup_tcp, struct xdp_buff *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u32, netns_id, u64, flags) +{ + struct net_device *dev = ctx->rxq->dev; + int ifindex = dev->ifindex, sdif = dev_sdif(dev); + struct net *caller_net = dev_net(dev); + + return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, caller_net, + ifindex, IPPROTO_TCP, netns_id, + flags, sdif); +} + +static const struct bpf_func_proto bpf_xdp_sk_lookup_tcp_proto = { + .func = bpf_xdp_sk_lookup_tcp, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sock_addr_skc_lookup_tcp, struct bpf_sock_addr_kern *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)__bpf_skc_lookup(NULL, tuple, len, + sock_net(ctx->sk), 0, + IPPROTO_TCP, netns_id, flags, + -1); +} 
+ +static const struct bpf_func_proto bpf_sock_addr_skc_lookup_tcp_proto = { + .func = bpf_sock_addr_skc_lookup_tcp, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCK_COMMON_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sock_addr_sk_lookup_tcp, struct bpf_sock_addr_kern *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, + sock_net(ctx->sk), 0, IPPROTO_TCP, + netns_id, flags, -1); +} + +static const struct bpf_func_proto bpf_sock_addr_sk_lookup_tcp_proto = { + .func = bpf_sock_addr_sk_lookup_tcp, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +BPF_CALL_5(bpf_sock_addr_sk_lookup_udp, struct bpf_sock_addr_kern *, ctx, + struct bpf_sock_tuple *, tuple, u32, len, u64, netns_id, u64, flags) +{ + return (unsigned long)__bpf_sk_lookup(NULL, tuple, len, + sock_net(ctx->sk), 0, IPPROTO_UDP, + netns_id, flags, -1); +} + +static const struct bpf_func_proto bpf_sock_addr_sk_lookup_udp_proto = { + .func = bpf_sock_addr_sk_lookup_udp, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, + .arg5_type = ARG_ANYTHING, +}; + +bool bpf_tcp_sock_is_valid_access(int off, int size, enum bpf_access_type type, + struct bpf_insn_access_aux *info) +{ + if (off < 0 || off >= offsetofend(struct bpf_tcp_sock, + icsk_retransmits)) + return false; + + if (off % size != 0) + return false; + + switch (off) { + case offsetof(struct bpf_tcp_sock, bytes_received): + case offsetof(struct bpf_tcp_sock, bytes_acked): + return size == sizeof(__u64); + default: + return size == sizeof(__u32); + } +} + +u32 bpf_tcp_sock_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + +#define BPF_TCP_SOCK_GET_COMMON(FIELD) \ + do { \ + BUILD_BUG_ON(sizeof_field(struct tcp_sock, FIELD) > \ + sizeof_field(struct bpf_tcp_sock, FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct tcp_sock, FIELD),\ + si->dst_reg, si->src_reg, \ + offsetof(struct tcp_sock, FIELD)); \ + } while (0) + +#define BPF_INET_SOCK_GET_COMMON(FIELD) \ + do { \ + BUILD_BUG_ON(sizeof_field(struct inet_connection_sock, \ + FIELD) > \ + sizeof_field(struct bpf_tcp_sock, FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct inet_connection_sock, \ + FIELD), \ + si->dst_reg, si->src_reg, \ + offsetof( \ + struct inet_connection_sock, \ + FIELD)); \ + } while (0) + + if (insn > insn_buf) + return insn - insn_buf; + + switch (si->off) { + case offsetof(struct bpf_tcp_sock, rtt_min): + BUILD_BUG_ON(sizeof_field(struct tcp_sock, rtt_min) != + sizeof(struct minmax)); + BUILD_BUG_ON(sizeof(struct minmax) < + sizeof(struct minmax_sample)); + + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct tcp_sock, rtt_min) + + offsetof(struct minmax_sample, v)); + break; + case offsetof(struct bpf_tcp_sock, snd_cwnd): + BPF_TCP_SOCK_GET_COMMON(snd_cwnd); + break; + case offsetof(struct bpf_tcp_sock, srtt_us): + BPF_TCP_SOCK_GET_COMMON(srtt_us); + break; + case 
offsetof(struct bpf_tcp_sock, snd_ssthresh): + BPF_TCP_SOCK_GET_COMMON(snd_ssthresh); + break; + case offsetof(struct bpf_tcp_sock, rcv_nxt): + BPF_TCP_SOCK_GET_COMMON(rcv_nxt); + break; + case offsetof(struct bpf_tcp_sock, snd_nxt): + BPF_TCP_SOCK_GET_COMMON(snd_nxt); + break; + case offsetof(struct bpf_tcp_sock, snd_una): + BPF_TCP_SOCK_GET_COMMON(snd_una); + break; + case offsetof(struct bpf_tcp_sock, mss_cache): + BPF_TCP_SOCK_GET_COMMON(mss_cache); + break; + case offsetof(struct bpf_tcp_sock, ecn_flags): + BPF_TCP_SOCK_GET_COMMON(ecn_flags); + break; + case offsetof(struct bpf_tcp_sock, rate_delivered): + BPF_TCP_SOCK_GET_COMMON(rate_delivered); + break; + case offsetof(struct bpf_tcp_sock, rate_interval_us): + BPF_TCP_SOCK_GET_COMMON(rate_interval_us); + break; + case offsetof(struct bpf_tcp_sock, packets_out): + BPF_TCP_SOCK_GET_COMMON(packets_out); + break; + case offsetof(struct bpf_tcp_sock, retrans_out): + BPF_TCP_SOCK_GET_COMMON(retrans_out); + break; + case offsetof(struct bpf_tcp_sock, total_retrans): + BPF_TCP_SOCK_GET_COMMON(total_retrans); + break; + case offsetof(struct bpf_tcp_sock, segs_in): + BPF_TCP_SOCK_GET_COMMON(segs_in); + break; + case offsetof(struct bpf_tcp_sock, data_segs_in): + BPF_TCP_SOCK_GET_COMMON(data_segs_in); + break; + case offsetof(struct bpf_tcp_sock, segs_out): + BPF_TCP_SOCK_GET_COMMON(segs_out); + break; + case offsetof(struct bpf_tcp_sock, data_segs_out): + BPF_TCP_SOCK_GET_COMMON(data_segs_out); + break; + case offsetof(struct bpf_tcp_sock, lost_out): + BPF_TCP_SOCK_GET_COMMON(lost_out); + break; + case offsetof(struct bpf_tcp_sock, sacked_out): + BPF_TCP_SOCK_GET_COMMON(sacked_out); + break; + case offsetof(struct bpf_tcp_sock, bytes_received): + BPF_TCP_SOCK_GET_COMMON(bytes_received); + break; + case offsetof(struct bpf_tcp_sock, bytes_acked): + BPF_TCP_SOCK_GET_COMMON(bytes_acked); + break; + case offsetof(struct bpf_tcp_sock, dsack_dups): + BPF_TCP_SOCK_GET_COMMON(dsack_dups); + break; + case offsetof(struct bpf_tcp_sock, delivered): + BPF_TCP_SOCK_GET_COMMON(delivered); + break; + case offsetof(struct bpf_tcp_sock, delivered_ce): + BPF_TCP_SOCK_GET_COMMON(delivered_ce); + break; + case offsetof(struct bpf_tcp_sock, icsk_retransmits): + BPF_INET_SOCK_GET_COMMON(icsk_retransmits); + break; + } + + return insn - insn_buf; +} + +BPF_CALL_1(bpf_tcp_sock, struct sock *, sk) +{ + if (sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_tcp_sock_proto = { + .func = bpf_tcp_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_TCP_SOCK_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, +}; + +BPF_CALL_1(bpf_get_listener_sock, struct sock *, sk) +{ + sk = sk_to_full_sk(sk); + + if (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_RCU_FREE)) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +static const struct bpf_func_proto bpf_get_listener_sock_proto = { + .func = bpf_get_listener_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_SOCKET_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, +}; + +BPF_CALL_1(bpf_skb_ecn_set_ce, struct sk_buff *, skb) +{ + unsigned int iphdr_len; + + switch (skb_protocol(skb, true)) { + case cpu_to_be16(ETH_P_IP): + iphdr_len = sizeof(struct iphdr); + break; + case cpu_to_be16(ETH_P_IPV6): + iphdr_len = sizeof(struct ipv6hdr); + break; + default: + return 0; + } + + if (skb_headlen(skb) < iphdr_len) + return 0; + + if (skb_cloned(skb) && !skb_clone_writable(skb, iphdr_len)) + return 0; + + return 
INET_ECN_set_ce(skb); +} + +bool bpf_xdp_sock_is_valid_access(int off, int size, enum bpf_access_type type, + struct bpf_insn_access_aux *info) +{ + if (off < 0 || off >= offsetofend(struct bpf_xdp_sock, queue_id)) + return false; + + if (off % size != 0) + return false; + + switch (off) { + default: + return size == sizeof(__u32); + } +} + +u32 bpf_xdp_sock_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + +#define BPF_XDP_SOCK_GET(FIELD) \ + do { \ + BUILD_BUG_ON(sizeof_field(struct xdp_sock, FIELD) > \ + sizeof_field(struct bpf_xdp_sock, FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_sock, FIELD),\ + si->dst_reg, si->src_reg, \ + offsetof(struct xdp_sock, FIELD)); \ + } while (0) + + switch (si->off) { + case offsetof(struct bpf_xdp_sock, queue_id): + BPF_XDP_SOCK_GET(queue_id); + break; + } + + return insn - insn_buf; +} + +static const struct bpf_func_proto bpf_skb_ecn_set_ce_proto = { + .func = bpf_skb_ecn_set_ce, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, +}; + +BPF_CALL_5(bpf_tcp_check_syncookie, struct sock *, sk, void *, iph, u32, iph_len, + struct tcphdr *, th, u32, th_len) +{ +#ifdef CONFIG_SYN_COOKIES + u32 cookie; + int ret; + + if (unlikely(!sk || th_len < sizeof(*th))) + return -EINVAL; + + /* sk_listener() allows TCP_NEW_SYN_RECV, which makes no sense here. */ + if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state != TCP_LISTEN) + return -EINVAL; + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies)) + return -EINVAL; + + if (!th->ack || th->rst || th->syn) + return -ENOENT; + + if (unlikely(iph_len < sizeof(struct iphdr))) + return -EINVAL; + + if (tcp_synq_no_recent_overflow(sk)) + return -ENOENT; + + cookie = ntohl(th->ack_seq) - 1; + + /* Both struct iphdr and struct ipv6hdr have the version field at the + * same offset so we can cast to the shorter header (struct iphdr). 
+ */ + switch (((struct iphdr *)iph)->version) { + case 4: + if (sk->sk_family == AF_INET6 && ipv6_only_sock(sk)) + return -EINVAL; + + ret = __cookie_v4_check((struct iphdr *)iph, th, cookie); + break; + +#if IS_BUILTIN(CONFIG_IPV6) + case 6: + if (unlikely(iph_len < sizeof(struct ipv6hdr))) + return -EINVAL; + + if (sk->sk_family != AF_INET6) + return -EINVAL; + + ret = __cookie_v6_check((struct ipv6hdr *)iph, th, cookie); + break; +#endif /* CONFIG_IPV6 */ + + default: + return -EPROTONOSUPPORT; + } + + if (ret > 0) + return 0; + + return -ENOENT; +#else + return -ENOTSUPP; +#endif +} + +static const struct bpf_func_proto bpf_tcp_check_syncookie_proto = { + .func = bpf_tcp_check_syncookie, + .gpl_only = true, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(bpf_tcp_gen_syncookie, struct sock *, sk, void *, iph, u32, iph_len, + struct tcphdr *, th, u32, th_len) +{ +#ifdef CONFIG_SYN_COOKIES + u32 cookie; + u16 mss; + + if (unlikely(!sk || th_len < sizeof(*th) || th_len != th->doff * 4)) + return -EINVAL; + + if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state != TCP_LISTEN) + return -EINVAL; + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies)) + return -ENOENT; + + if (!th->syn || th->ack || th->fin || th->rst) + return -EINVAL; + + if (unlikely(iph_len < sizeof(struct iphdr))) + return -EINVAL; + + /* Both struct iphdr and struct ipv6hdr have the version field at the + * same offset so we can cast to the shorter header (struct iphdr). + */ + switch (((struct iphdr *)iph)->version) { + case 4: + if (sk->sk_family == AF_INET6 && ipv6_only_sock(sk)) + return -EINVAL; + + mss = tcp_v4_get_syncookie(sk, iph, th, &cookie); + break; + +#if IS_BUILTIN(CONFIG_IPV6) + case 6: + if (unlikely(iph_len < sizeof(struct ipv6hdr))) + return -EINVAL; + + if (sk->sk_family != AF_INET6) + return -EINVAL; + + mss = tcp_v6_get_syncookie(sk, iph, th, &cookie); + break; +#endif /* CONFIG_IPV6 */ + + default: + return -EPROTONOSUPPORT; + } + if (mss == 0) + return -ENOENT; + + return cookie | ((u64)mss << 32); +#else + return -EOPNOTSUPP; +#endif /* CONFIG_SYN_COOKIES */ +} + +static const struct bpf_func_proto bpf_tcp_gen_syncookie_proto = { + .func = bpf_tcp_gen_syncookie, + .gpl_only = true, /* __cookie_v*_init_sequence() is GPL */ + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg5_type = ARG_CONST_SIZE, +}; + +BPF_CALL_3(bpf_sk_assign, struct sk_buff *, skb, struct sock *, sk, u64, flags) +{ + if (!sk || flags != 0) + return -EINVAL; + if (!skb_at_tc_ingress(skb)) + return -EOPNOTSUPP; + if (unlikely(dev_net(skb->dev) != sock_net(sk))) + return -ENETUNREACH; + if (unlikely(sk_fullsock(sk) && sk->sk_reuseport)) + return -ESOCKTNOSUPPORT; + if (sk_unhashed(sk)) + return -EOPNOTSUPP; + if (sk_is_refcounted(sk) && + unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) + return -ENOENT; + + skb_orphan(skb); + skb->sk = sk; + skb->destructor = sock_pfree; + + return 0; +} + +static const struct bpf_func_proto bpf_sk_assign_proto = { + .func = bpf_sk_assign, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .arg3_type = ARG_ANYTHING, +}; + +static const u8 
*bpf_search_tcp_opt(const u8 *op, const u8 *opend, + u8 search_kind, const u8 *magic, + u8 magic_len, bool *eol) +{ + u8 kind, kind_len; + + *eol = false; + + while (op < opend) { + kind = op[0]; + + if (kind == TCPOPT_EOL) { + *eol = true; + return ERR_PTR(-ENOMSG); + } else if (kind == TCPOPT_NOP) { + op++; + continue; + } + + if (opend - op < 2 || opend - op < op[1] || op[1] < 2) + /* Something is wrong in the received header. + * Follow the TCP stack's tcp_parse_options() + * and just bail here. + */ + return ERR_PTR(-EFAULT); + + kind_len = op[1]; + if (search_kind == kind) { + if (!magic_len) + return op; + + if (magic_len > kind_len - 2) + return ERR_PTR(-ENOMSG); + + if (!memcmp(&op[2], magic, magic_len)) + return op; + } + + op += kind_len; + } + + return ERR_PTR(-ENOMSG); +} + +BPF_CALL_4(bpf_sock_ops_load_hdr_opt, struct bpf_sock_ops_kern *, bpf_sock, + void *, search_res, u32, len, u64, flags) +{ + bool eol, load_syn = flags & BPF_LOAD_HDR_OPT_TCP_SYN; + const u8 *op, *opend, *magic, *search = search_res; + u8 search_kind, search_len, copy_len, magic_len; + int ret; + + /* 2 byte is the minimal option len except TCPOPT_NOP and + * TCPOPT_EOL which are useless for the bpf prog to learn + * and this helper disallow loading them also. + */ + if (len < 2 || flags & ~BPF_LOAD_HDR_OPT_TCP_SYN) + return -EINVAL; + + search_kind = search[0]; + search_len = search[1]; + + if (search_len > len || search_kind == TCPOPT_NOP || + search_kind == TCPOPT_EOL) + return -EINVAL; + + if (search_kind == TCPOPT_EXP || search_kind == 253) { + /* 16 or 32 bit magic. +2 for kind and kind length */ + if (search_len != 4 && search_len != 6) + return -EINVAL; + magic = &search[2]; + magic_len = search_len - 2; + } else { + if (search_len) + return -EINVAL; + magic = NULL; + magic_len = 0; + } + + if (load_syn) { + ret = bpf_sock_ops_get_syn(bpf_sock, TCP_BPF_SYN, &op); + if (ret < 0) + return ret; + + opend = op + ret; + op += sizeof(struct tcphdr); + } else { + if (!bpf_sock->skb || + bpf_sock->op == BPF_SOCK_OPS_HDR_OPT_LEN_CB) + /* This bpf_sock->op cannot call this helper */ + return -EPERM; + + opend = bpf_sock->skb_data_end; + op = bpf_sock->skb->data + sizeof(struct tcphdr); + } + + op = bpf_search_tcp_opt(op, opend, search_kind, magic, magic_len, + &eol); + if (IS_ERR(op)) + return PTR_ERR(op); + + copy_len = op[1]; + ret = copy_len; + if (copy_len > len) { + ret = -ENOSPC; + copy_len = len; + } + + memcpy(search_res, op, copy_len); + return ret; +} + +static const struct bpf_func_proto bpf_sock_ops_load_hdr_opt_proto = { + .func = bpf_sock_ops_load_hdr_opt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_sock_ops_store_hdr_opt, struct bpf_sock_ops_kern *, bpf_sock, + const void *, from, u32, len, u64, flags) +{ + u8 new_kind, new_kind_len, magic_len = 0, *opend; + const u8 *op, *new_op, *magic = NULL; + struct sk_buff *skb; + bool eol; + + if (bpf_sock->op != BPF_SOCK_OPS_WRITE_HDR_OPT_CB) + return -EPERM; + + if (len < 2 || flags) + return -EINVAL; + + new_op = from; + new_kind = new_op[0]; + new_kind_len = new_op[1]; + + if (new_kind_len > len || new_kind == TCPOPT_NOP || + new_kind == TCPOPT_EOL) + return -EINVAL; + + if (new_kind_len > bpf_sock->remaining_opt_len) + return -ENOSPC; + + /* 253 is another experimental kind */ + if (new_kind == TCPOPT_EXP || new_kind == 253) { + if (new_kind_len < 4) + return -EINVAL; + /* Match for the 2 byte magic 
also. + * RFC 6994: the magic could be 2 or 4 bytes. + * Hence, matching by 2 byte only is on the + * conservative side but it is the right + * thing to do for the 'search-for-duplication' + * purpose. + */ + magic = &new_op[2]; + magic_len = 2; + } + + /* Check for duplication */ + skb = bpf_sock->skb; + op = skb->data + sizeof(struct tcphdr); + opend = bpf_sock->skb_data_end; + + op = bpf_search_tcp_opt(op, opend, new_kind, magic, magic_len, + &eol); + if (!IS_ERR(op)) + return -EEXIST; + + if (PTR_ERR(op) != -ENOMSG) + return PTR_ERR(op); + + if (eol) + /* The option has been ended. Treat it as no more + * header option can be written. + */ + return -ENOSPC; + + /* No duplication found. Store the header option. */ + memcpy(opend, from, new_kind_len); + + bpf_sock->remaining_opt_len -= new_kind_len; + bpf_sock->skb_data_end += new_kind_len; + + return 0; +} + +static const struct bpf_func_proto bpf_sock_ops_store_hdr_opt_proto = { + .func = bpf_sock_ops_store_hdr_opt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_MEM | MEM_RDONLY, + .arg3_type = ARG_CONST_SIZE, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_sock_ops_reserve_hdr_opt, struct bpf_sock_ops_kern *, bpf_sock, + u32, len, u64, flags) +{ + if (bpf_sock->op != BPF_SOCK_OPS_HDR_OPT_LEN_CB) + return -EPERM; + + if (flags || len < 2) + return -EINVAL; + + if (len > bpf_sock->remaining_opt_len) + return -ENOSPC; + + bpf_sock->remaining_opt_len -= len; + + return 0; +} + +static const struct bpf_func_proto bpf_sock_ops_reserve_hdr_opt_proto = { + .func = bpf_sock_ops_reserve_hdr_opt, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +BPF_CALL_3(bpf_skb_set_tstamp, struct sk_buff *, skb, + u64, tstamp, u32, tstamp_type) +{ + /* skb_clear_delivery_time() is done for inet protocol */ + if (skb->protocol != htons(ETH_P_IP) && + skb->protocol != htons(ETH_P_IPV6)) + return -EOPNOTSUPP; + + switch (tstamp_type) { + case BPF_SKB_TSTAMP_DELIVERY_MONO: + if (!tstamp) + return -EINVAL; + skb->tstamp = tstamp; + skb->mono_delivery_time = 1; + break; + case BPF_SKB_TSTAMP_UNSPEC: + if (tstamp) + return -EINVAL; + skb->tstamp = 0; + skb->mono_delivery_time = 0; + break; + default: + return -EINVAL; + } + + return 0; +} + +static const struct bpf_func_proto bpf_skb_set_tstamp_proto = { + .func = bpf_skb_set_tstamp, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_ANYTHING, +}; + +#ifdef CONFIG_SYN_COOKIES +BPF_CALL_3(bpf_tcp_raw_gen_syncookie_ipv4, struct iphdr *, iph, + struct tcphdr *, th, u32, th_len) +{ + u32 cookie; + u16 mss; + + if (unlikely(th_len < sizeof(*th) || th_len != th->doff * 4)) + return -EINVAL; + + mss = tcp_parse_mss_option(th, 0) ?: TCP_MSS_DEFAULT; + cookie = __cookie_v4_init_sequence(iph, th, &mss); + + return cookie | ((u64)mss << 32); +} + +static const struct bpf_func_proto bpf_tcp_raw_gen_syncookie_ipv4_proto = { + .func = bpf_tcp_raw_gen_syncookie_ipv4, + .gpl_only = true, /* __cookie_v4_init_sequence() is GPL */ + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg1_size = sizeof(struct iphdr), + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, +}; + +BPF_CALL_3(bpf_tcp_raw_gen_syncookie_ipv6, struct ipv6hdr *, iph, + struct tcphdr *, th, u32, th_len) +{ +#if IS_BUILTIN(CONFIG_IPV6) + const u16 mss_clamp = IPV6_MIN_MTU - sizeof(struct 
tcphdr) - + sizeof(struct ipv6hdr); + u32 cookie; + u16 mss; + + if (unlikely(th_len < sizeof(*th) || th_len != th->doff * 4)) + return -EINVAL; + + mss = tcp_parse_mss_option(th, 0) ?: mss_clamp; + cookie = __cookie_v6_init_sequence(iph, th, &mss); + + return cookie | ((u64)mss << 32); +#else + return -EPROTONOSUPPORT; +#endif +} + +static const struct bpf_func_proto bpf_tcp_raw_gen_syncookie_ipv6_proto = { + .func = bpf_tcp_raw_gen_syncookie_ipv6, + .gpl_only = true, /* __cookie_v6_init_sequence() is GPL */ + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg1_size = sizeof(struct ipv6hdr), + .arg2_type = ARG_PTR_TO_MEM, + .arg3_type = ARG_CONST_SIZE, +}; + +BPF_CALL_2(bpf_tcp_raw_check_syncookie_ipv4, struct iphdr *, iph, + struct tcphdr *, th) +{ + u32 cookie = ntohl(th->ack_seq) - 1; + + if (__cookie_v4_check(iph, th, cookie) > 0) + return 0; + + return -EACCES; +} + +static const struct bpf_func_proto bpf_tcp_raw_check_syncookie_ipv4_proto = { + .func = bpf_tcp_raw_check_syncookie_ipv4, + .gpl_only = true, /* __cookie_v4_check is GPL */ + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg1_size = sizeof(struct iphdr), + .arg2_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg2_size = sizeof(struct tcphdr), +}; + +BPF_CALL_2(bpf_tcp_raw_check_syncookie_ipv6, struct ipv6hdr *, iph, + struct tcphdr *, th) +{ +#if IS_BUILTIN(CONFIG_IPV6) + u32 cookie = ntohl(th->ack_seq) - 1; + + if (__cookie_v6_check(iph, th, cookie) > 0) + return 0; + + return -EACCES; +#else + return -EPROTONOSUPPORT; +#endif +} + +static const struct bpf_func_proto bpf_tcp_raw_check_syncookie_ipv6_proto = { + .func = bpf_tcp_raw_check_syncookie_ipv6, + .gpl_only = true, /* __cookie_v6_check is GPL */ + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg1_size = sizeof(struct ipv6hdr), + .arg2_type = ARG_PTR_TO_FIXED_SIZE_MEM, + .arg2_size = sizeof(struct tcphdr), +}; +#endif /* CONFIG_SYN_COOKIES */ + +#endif /* CONFIG_INET */ + +bool bpf_helper_changes_pkt_data(void *func) +{ + if (func == bpf_skb_vlan_push || + func == bpf_skb_vlan_pop || + func == bpf_skb_store_bytes || + func == bpf_skb_change_proto || + func == bpf_skb_change_head || + func == sk_skb_change_head || + func == bpf_skb_change_tail || + func == sk_skb_change_tail || + func == bpf_skb_adjust_room || + func == sk_skb_adjust_room || + func == bpf_skb_pull_data || + func == sk_skb_pull_data || + func == bpf_clone_redirect || + func == bpf_l3_csum_replace || + func == bpf_l4_csum_replace || + func == bpf_xdp_adjust_head || + func == bpf_xdp_adjust_meta || + func == bpf_msg_pull_data || + func == bpf_msg_push_data || + func == bpf_msg_pop_data || + func == bpf_xdp_adjust_tail || +#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) + func == bpf_lwt_seg6_store_bytes || + func == bpf_lwt_seg6_adjust_srh || + func == bpf_lwt_seg6_action || +#endif +#ifdef CONFIG_INET + func == bpf_sock_ops_store_hdr_opt || +#endif + func == bpf_lwt_in_push_encap || + func == bpf_lwt_xmit_push_encap) + return true; + + return false; +} + +const struct bpf_func_proto bpf_event_output_data_proto __weak; +const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto __weak; + +static const struct bpf_func_proto * +sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + const struct bpf_func_proto *func_proto; + + func_proto = cgroup_common_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + func_proto = 
cgroup_current_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + switch (func_id) { + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_sock_proto; + case BPF_FUNC_get_netns_cookie: + return &bpf_get_netns_cookie_sock_proto; + case BPF_FUNC_perf_event_output: + return &bpf_event_output_data_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_cg_sock_proto; + case BPF_FUNC_ktime_get_coarse_ns: + return &bpf_ktime_get_coarse_ns_proto; + default: + return bpf_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +sock_addr_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + const struct bpf_func_proto *func_proto; + + func_proto = cgroup_common_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + func_proto = cgroup_current_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + switch (func_id) { + case BPF_FUNC_bind: + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: + return &bpf_bind_proto; + default: + return NULL; + } + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_sock_addr_proto; + case BPF_FUNC_get_netns_cookie: + return &bpf_get_netns_cookie_sock_addr_proto; + case BPF_FUNC_perf_event_output: + return &bpf_event_output_data_proto; +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_sock_addr_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_sock_addr_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + case BPF_FUNC_skc_lookup_tcp: + return &bpf_sock_addr_skc_lookup_tcp_proto; +#endif /* CONFIG_INET */ + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; + case BPF_FUNC_setsockopt: + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET4_BIND: + case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: + case BPF_CGROUP_UDP4_RECVMSG: + case BPF_CGROUP_UDP6_RECVMSG: + case BPF_CGROUP_UDP4_SENDMSG: + case BPF_CGROUP_UDP6_SENDMSG: + case BPF_CGROUP_INET4_GETPEERNAME: + case BPF_CGROUP_INET6_GETPEERNAME: + case BPF_CGROUP_INET4_GETSOCKNAME: + case BPF_CGROUP_INET6_GETSOCKNAME: + return &bpf_sock_addr_setsockopt_proto; + default: + return NULL; + } + case BPF_FUNC_getsockopt: + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET4_BIND: + case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET6_CONNECT: + case BPF_CGROUP_UDP4_RECVMSG: + case BPF_CGROUP_UDP6_RECVMSG: + case BPF_CGROUP_UDP4_SENDMSG: + case BPF_CGROUP_UDP6_SENDMSG: + case BPF_CGROUP_INET4_GETPEERNAME: + case BPF_CGROUP_INET6_GETPEERNAME: + case BPF_CGROUP_INET4_GETSOCKNAME: + case BPF_CGROUP_INET6_GETSOCKNAME: + return &bpf_sock_addr_getsockopt_proto; + default: + return NULL; + } + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +sk_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_load_bytes: + return &bpf_skb_load_bytes_proto; + case BPF_FUNC_skb_load_bytes_relative: + return &bpf_skb_load_bytes_relative_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_proto; + case BPF_FUNC_get_socket_uid: + return &bpf_get_socket_uid_proto; + case BPF_FUNC_perf_event_output: + return &bpf_skb_event_output_proto; + default: + return bpf_sk_base_func_proto(func_id); + } +} + +const 
struct bpf_func_proto bpf_sk_storage_get_proto __weak; +const struct bpf_func_proto bpf_sk_storage_delete_proto __weak; + +static const struct bpf_func_proto * +cg_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + const struct bpf_func_proto *func_proto; + + func_proto = cgroup_common_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + switch (func_id) { + case BPF_FUNC_sk_fullsock: + return &bpf_sk_fullsock_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; + case BPF_FUNC_perf_event_output: + return &bpf_skb_event_output_proto; +#ifdef CONFIG_SOCK_CGROUP_DATA + case BPF_FUNC_skb_cgroup_id: + return &bpf_skb_cgroup_id_proto; + case BPF_FUNC_skb_ancestor_cgroup_id: + return &bpf_skb_ancestor_cgroup_id_proto; + case BPF_FUNC_sk_cgroup_id: + return &bpf_sk_cgroup_id_proto; + case BPF_FUNC_sk_ancestor_cgroup_id: + return &bpf_sk_ancestor_cgroup_id_proto; +#endif +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + case BPF_FUNC_skc_lookup_tcp: + return &bpf_skc_lookup_tcp_proto; + case BPF_FUNC_tcp_sock: + return &bpf_tcp_sock_proto; + case BPF_FUNC_get_listener_sock: + return &bpf_get_listener_sock_proto; + case BPF_FUNC_skb_ecn_set_ce: + return &bpf_skb_ecn_set_ce_proto; +#endif + default: + return sk_filter_func_proto(func_id, prog); + } +} + +static const struct bpf_func_proto * +tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_store_bytes: + return &bpf_skb_store_bytes_proto; + case BPF_FUNC_skb_load_bytes: + return &bpf_skb_load_bytes_proto; + case BPF_FUNC_skb_load_bytes_relative: + return &bpf_skb_load_bytes_relative_proto; + case BPF_FUNC_skb_pull_data: + return &bpf_skb_pull_data_proto; + case BPF_FUNC_csum_diff: + return &bpf_csum_diff_proto; + case BPF_FUNC_csum_update: + return &bpf_csum_update_proto; + case BPF_FUNC_csum_level: + return &bpf_csum_level_proto; + case BPF_FUNC_l3_csum_replace: + return &bpf_l3_csum_replace_proto; + case BPF_FUNC_l4_csum_replace: + return &bpf_l4_csum_replace_proto; + case BPF_FUNC_clone_redirect: + return &bpf_clone_redirect_proto; + case BPF_FUNC_get_cgroup_classid: + return &bpf_get_cgroup_classid_proto; + case BPF_FUNC_skb_vlan_push: + return &bpf_skb_vlan_push_proto; + case BPF_FUNC_skb_vlan_pop: + return &bpf_skb_vlan_pop_proto; + case BPF_FUNC_skb_change_proto: + return &bpf_skb_change_proto_proto; + case BPF_FUNC_skb_change_type: + return &bpf_skb_change_type_proto; + case BPF_FUNC_skb_adjust_room: + return &bpf_skb_adjust_room_proto; + case BPF_FUNC_skb_change_tail: + return &bpf_skb_change_tail_proto; + case BPF_FUNC_skb_change_head: + return &bpf_skb_change_head_proto; + case BPF_FUNC_skb_get_tunnel_key: + return &bpf_skb_get_tunnel_key_proto; + case BPF_FUNC_skb_set_tunnel_key: + return bpf_get_skb_set_tunnel_proto(func_id); + case BPF_FUNC_skb_get_tunnel_opt: + return &bpf_skb_get_tunnel_opt_proto; + case BPF_FUNC_skb_set_tunnel_opt: + return bpf_get_skb_set_tunnel_proto(func_id); + case BPF_FUNC_redirect: + return &bpf_redirect_proto; + case BPF_FUNC_redirect_neigh: + return &bpf_redirect_neigh_proto; + case BPF_FUNC_redirect_peer: + return &bpf_redirect_peer_proto; + case BPF_FUNC_get_route_realm: + return &bpf_get_route_realm_proto; + case BPF_FUNC_get_hash_recalc: 
+ return &bpf_get_hash_recalc_proto; + case BPF_FUNC_set_hash_invalid: + return &bpf_set_hash_invalid_proto; + case BPF_FUNC_set_hash: + return &bpf_set_hash_proto; + case BPF_FUNC_perf_event_output: + return &bpf_skb_event_output_proto; + case BPF_FUNC_get_smp_processor_id: + return &bpf_get_smp_processor_id_proto; + case BPF_FUNC_skb_under_cgroup: + return &bpf_skb_under_cgroup_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_proto; + case BPF_FUNC_get_socket_uid: + return &bpf_get_socket_uid_proto; + case BPF_FUNC_fib_lookup: + return &bpf_skb_fib_lookup_proto; + case BPF_FUNC_check_mtu: + return &bpf_skb_check_mtu_proto; + case BPF_FUNC_sk_fullsock: + return &bpf_sk_fullsock_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; +#ifdef CONFIG_XFRM + case BPF_FUNC_skb_get_xfrm_state: + return &bpf_skb_get_xfrm_state_proto; +#endif +#ifdef CONFIG_CGROUP_NET_CLASSID + case BPF_FUNC_skb_cgroup_classid: + return &bpf_skb_cgroup_classid_proto; +#endif +#ifdef CONFIG_SOCK_CGROUP_DATA + case BPF_FUNC_skb_cgroup_id: + return &bpf_skb_cgroup_id_proto; + case BPF_FUNC_skb_ancestor_cgroup_id: + return &bpf_skb_ancestor_cgroup_id_proto; +#endif +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_tc_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_tc_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + case BPF_FUNC_tcp_sock: + return &bpf_tcp_sock_proto; + case BPF_FUNC_get_listener_sock: + return &bpf_get_listener_sock_proto; + case BPF_FUNC_skc_lookup_tcp: + return &bpf_tc_skc_lookup_tcp_proto; + case BPF_FUNC_tcp_check_syncookie: + return &bpf_tcp_check_syncookie_proto; + case BPF_FUNC_skb_ecn_set_ce: + return &bpf_skb_ecn_set_ce_proto; + case BPF_FUNC_tcp_gen_syncookie: + return &bpf_tcp_gen_syncookie_proto; + case BPF_FUNC_sk_assign: + return &bpf_sk_assign_proto; + case BPF_FUNC_skb_set_tstamp: + return &bpf_skb_set_tstamp_proto; +#ifdef CONFIG_SYN_COOKIES + case BPF_FUNC_tcp_raw_gen_syncookie_ipv4: + return &bpf_tcp_raw_gen_syncookie_ipv4_proto; + case BPF_FUNC_tcp_raw_gen_syncookie_ipv6: + return &bpf_tcp_raw_gen_syncookie_ipv6_proto; + case BPF_FUNC_tcp_raw_check_syncookie_ipv4: + return &bpf_tcp_raw_check_syncookie_ipv4_proto; + case BPF_FUNC_tcp_raw_check_syncookie_ipv6: + return &bpf_tcp_raw_check_syncookie_ipv6_proto; +#endif +#endif + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +xdp_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_perf_event_output: + return &bpf_xdp_event_output_proto; + case BPF_FUNC_get_smp_processor_id: + return &bpf_get_smp_processor_id_proto; + case BPF_FUNC_csum_diff: + return &bpf_csum_diff_proto; + case BPF_FUNC_xdp_adjust_head: + return &bpf_xdp_adjust_head_proto; + case BPF_FUNC_xdp_adjust_meta: + return &bpf_xdp_adjust_meta_proto; + case BPF_FUNC_redirect: + return &bpf_xdp_redirect_proto; + case BPF_FUNC_redirect_map: + return &bpf_xdp_redirect_map_proto; + case BPF_FUNC_xdp_adjust_tail: + return &bpf_xdp_adjust_tail_proto; + case BPF_FUNC_xdp_get_buff_len: + return &bpf_xdp_get_buff_len_proto; + case BPF_FUNC_xdp_load_bytes: + return &bpf_xdp_load_bytes_proto; + case BPF_FUNC_xdp_store_bytes: + return &bpf_xdp_store_bytes_proto; + case BPF_FUNC_fib_lookup: + return &bpf_xdp_fib_lookup_proto; + case BPF_FUNC_check_mtu: + return &bpf_xdp_check_mtu_proto; +#ifdef 
CONFIG_INET + case BPF_FUNC_sk_lookup_udp: + return &bpf_xdp_sk_lookup_udp_proto; + case BPF_FUNC_sk_lookup_tcp: + return &bpf_xdp_sk_lookup_tcp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + case BPF_FUNC_skc_lookup_tcp: + return &bpf_xdp_skc_lookup_tcp_proto; + case BPF_FUNC_tcp_check_syncookie: + return &bpf_tcp_check_syncookie_proto; + case BPF_FUNC_tcp_gen_syncookie: + return &bpf_tcp_gen_syncookie_proto; +#ifdef CONFIG_SYN_COOKIES + case BPF_FUNC_tcp_raw_gen_syncookie_ipv4: + return &bpf_tcp_raw_gen_syncookie_ipv4_proto; + case BPF_FUNC_tcp_raw_gen_syncookie_ipv6: + return &bpf_tcp_raw_gen_syncookie_ipv6_proto; + case BPF_FUNC_tcp_raw_check_syncookie_ipv4: + return &bpf_tcp_raw_check_syncookie_ipv4_proto; + case BPF_FUNC_tcp_raw_check_syncookie_ipv6: + return &bpf_tcp_raw_check_syncookie_ipv6_proto; +#endif +#endif + default: + return bpf_sk_base_func_proto(func_id); + } + +#if IS_MODULE(CONFIG_NF_CONNTRACK) && IS_ENABLED(CONFIG_DEBUG_INFO_BTF_MODULES) + /* The nf_conn___init type is used in the NF_CONNTRACK kfuncs. The + * kfuncs are defined in two different modules, and we want to be able + * to use them interchangably with the same BTF type ID. Because modules + * can't de-duplicate BTF IDs between each other, we need the type to be + * referenced in the vmlinux BTF or the verifier will get confused about + * the different types. So we add this dummy type reference which will + * be included in vmlinux BTF, allowing both modules to refer to the + * same type ID. + */ + BTF_TYPE_EMIT(struct nf_conn___init); +#endif +} + +const struct bpf_func_proto bpf_sock_map_update_proto __weak; +const struct bpf_func_proto bpf_sock_hash_update_proto __weak; + +static const struct bpf_func_proto * +sock_ops_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + const struct bpf_func_proto *func_proto; + + func_proto = cgroup_common_func_proto(func_id, prog); + if (func_proto) + return func_proto; + + switch (func_id) { + case BPF_FUNC_setsockopt: + return &bpf_sock_ops_setsockopt_proto; + case BPF_FUNC_getsockopt: + return &bpf_sock_ops_getsockopt_proto; + case BPF_FUNC_sock_ops_cb_flags_set: + return &bpf_sock_ops_cb_flags_set_proto; + case BPF_FUNC_sock_map_update: + return &bpf_sock_map_update_proto; + case BPF_FUNC_sock_hash_update: + return &bpf_sock_hash_update_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_sock_ops_proto; + case BPF_FUNC_perf_event_output: + return &bpf_event_output_data_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; + case BPF_FUNC_get_netns_cookie: + return &bpf_get_netns_cookie_sock_ops_proto; +#ifdef CONFIG_INET + case BPF_FUNC_load_hdr_opt: + return &bpf_sock_ops_load_hdr_opt_proto; + case BPF_FUNC_store_hdr_opt: + return &bpf_sock_ops_store_hdr_opt_proto; + case BPF_FUNC_reserve_hdr_opt: + return &bpf_sock_ops_reserve_hdr_opt_proto; + case BPF_FUNC_tcp_sock: + return &bpf_tcp_sock_proto; +#endif /* CONFIG_INET */ + default: + return bpf_sk_base_func_proto(func_id); + } +} + +const struct bpf_func_proto bpf_msg_redirect_map_proto __weak; +const struct bpf_func_proto bpf_msg_redirect_hash_proto __weak; + +static const struct bpf_func_proto * +sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_msg_redirect_map: + return &bpf_msg_redirect_map_proto; + case BPF_FUNC_msg_redirect_hash: + return &bpf_msg_redirect_hash_proto; + case 
BPF_FUNC_msg_apply_bytes: + return &bpf_msg_apply_bytes_proto; + case BPF_FUNC_msg_cork_bytes: + return &bpf_msg_cork_bytes_proto; + case BPF_FUNC_msg_pull_data: + return &bpf_msg_pull_data_proto; + case BPF_FUNC_msg_push_data: + return &bpf_msg_push_data_proto; + case BPF_FUNC_msg_pop_data: + return &bpf_msg_pop_data_proto; + case BPF_FUNC_perf_event_output: + return &bpf_event_output_data_proto; + case BPF_FUNC_get_current_uid_gid: + return &bpf_get_current_uid_gid_proto; + case BPF_FUNC_get_current_pid_tgid: + return &bpf_get_current_pid_tgid_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; + case BPF_FUNC_get_netns_cookie: + return &bpf_get_netns_cookie_sk_msg_proto; +#ifdef CONFIG_CGROUPS + case BPF_FUNC_get_current_cgroup_id: + return &bpf_get_current_cgroup_id_proto; + case BPF_FUNC_get_current_ancestor_cgroup_id: + return &bpf_get_current_ancestor_cgroup_id_proto; +#endif +#ifdef CONFIG_CGROUP_NET_CLASSID + case BPF_FUNC_get_cgroup_classid: + return &bpf_get_cgroup_classid_curr_proto; +#endif + default: + return bpf_sk_base_func_proto(func_id); + } +} + +const struct bpf_func_proto bpf_sk_redirect_map_proto __weak; +const struct bpf_func_proto bpf_sk_redirect_hash_proto __weak; + +static const struct bpf_func_proto * +sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_store_bytes: + return &bpf_skb_store_bytes_proto; + case BPF_FUNC_skb_load_bytes: + return &bpf_skb_load_bytes_proto; + case BPF_FUNC_skb_pull_data: + return &sk_skb_pull_data_proto; + case BPF_FUNC_skb_change_tail: + return &sk_skb_change_tail_proto; + case BPF_FUNC_skb_change_head: + return &sk_skb_change_head_proto; + case BPF_FUNC_skb_adjust_room: + return &sk_skb_adjust_room_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_cookie_proto; + case BPF_FUNC_get_socket_uid: + return &bpf_get_socket_uid_proto; + case BPF_FUNC_sk_redirect_map: + return &bpf_sk_redirect_map_proto; + case BPF_FUNC_sk_redirect_hash: + return &bpf_sk_redirect_hash_proto; + case BPF_FUNC_perf_event_output: + return &bpf_skb_event_output_proto; +#ifdef CONFIG_INET + case BPF_FUNC_sk_lookup_tcp: + return &bpf_sk_lookup_tcp_proto; + case BPF_FUNC_sk_lookup_udp: + return &bpf_sk_lookup_udp_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + case BPF_FUNC_skc_lookup_tcp: + return &bpf_skc_lookup_tcp_proto; +#endif + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +flow_dissector_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_load_bytes: + return &bpf_flow_dissector_load_bytes_proto; + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +lwt_out_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_load_bytes: + return &bpf_skb_load_bytes_proto; + case BPF_FUNC_skb_pull_data: + return &bpf_skb_pull_data_proto; + case BPF_FUNC_csum_diff: + return &bpf_csum_diff_proto; + case BPF_FUNC_get_cgroup_classid: + return &bpf_get_cgroup_classid_proto; + case BPF_FUNC_get_route_realm: + return &bpf_get_route_realm_proto; + case BPF_FUNC_get_hash_recalc: + return &bpf_get_hash_recalc_proto; + case BPF_FUNC_perf_event_output: + return &bpf_skb_event_output_proto; + case BPF_FUNC_get_smp_processor_id: + return &bpf_get_smp_processor_id_proto; + case 
BPF_FUNC_skb_under_cgroup: + return &bpf_skb_under_cgroup_proto; + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static const struct bpf_func_proto * +lwt_in_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_lwt_push_encap: + return &bpf_lwt_in_push_encap_proto; + default: + return lwt_out_func_proto(func_id, prog); + } +} + +static const struct bpf_func_proto * +lwt_xmit_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_skb_get_tunnel_key: + return &bpf_skb_get_tunnel_key_proto; + case BPF_FUNC_skb_set_tunnel_key: + return bpf_get_skb_set_tunnel_proto(func_id); + case BPF_FUNC_skb_get_tunnel_opt: + return &bpf_skb_get_tunnel_opt_proto; + case BPF_FUNC_skb_set_tunnel_opt: + return bpf_get_skb_set_tunnel_proto(func_id); + case BPF_FUNC_redirect: + return &bpf_redirect_proto; + case BPF_FUNC_clone_redirect: + return &bpf_clone_redirect_proto; + case BPF_FUNC_skb_change_tail: + return &bpf_skb_change_tail_proto; + case BPF_FUNC_skb_change_head: + return &bpf_skb_change_head_proto; + case BPF_FUNC_skb_store_bytes: + return &bpf_skb_store_bytes_proto; + case BPF_FUNC_csum_update: + return &bpf_csum_update_proto; + case BPF_FUNC_csum_level: + return &bpf_csum_level_proto; + case BPF_FUNC_l3_csum_replace: + return &bpf_l3_csum_replace_proto; + case BPF_FUNC_l4_csum_replace: + return &bpf_l4_csum_replace_proto; + case BPF_FUNC_set_hash_invalid: + return &bpf_set_hash_invalid_proto; + case BPF_FUNC_lwt_push_encap: + return &bpf_lwt_xmit_push_encap_proto; + default: + return lwt_out_func_proto(func_id, prog); + } +} + +static const struct bpf_func_proto * +lwt_seg6local_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { +#if IS_ENABLED(CONFIG_IPV6_SEG6_BPF) + case BPF_FUNC_lwt_seg6_store_bytes: + return &bpf_lwt_seg6_store_bytes_proto; + case BPF_FUNC_lwt_seg6_action: + return &bpf_lwt_seg6_action_proto; + case BPF_FUNC_lwt_seg6_adjust_srh: + return &bpf_lwt_seg6_adjust_srh_proto; +#endif + default: + return lwt_out_func_proto(func_id, prog); + } +} + +static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + const int size_default = sizeof(__u32); + + if (off < 0 || off >= sizeof(struct __sk_buff)) + return false; + + /* The verifier guarantees that size > 0. 
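+ * Below we additionally require the offset to be a multiple of the
+ * access size.  Narrow (1 or 2 byte) reads of a 4-byte word such as
+ * "mark" are still accepted in the default case via
+ * bpf_ctx_narrow_access_ok(), whereas narrow writes are rejected there.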
*/ + if (off % size != 0) + return false; + + switch (off) { + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + if (off + size > offsetofend(struct __sk_buff, cb[4])) + return false; + break; + case bpf_ctx_range_till(struct __sk_buff, remote_ip6[0], remote_ip6[3]): + case bpf_ctx_range_till(struct __sk_buff, local_ip6[0], local_ip6[3]): + case bpf_ctx_range_till(struct __sk_buff, remote_ip4, remote_ip4): + case bpf_ctx_range_till(struct __sk_buff, local_ip4, local_ip4): + case bpf_ctx_range(struct __sk_buff, data): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, data_end): + if (size != size_default) + return false; + break; + case bpf_ctx_range_ptr(struct __sk_buff, flow_keys): + return false; + case bpf_ctx_range(struct __sk_buff, hwtstamp): + if (type == BPF_WRITE || size != sizeof(__u64)) + return false; + break; + case bpf_ctx_range(struct __sk_buff, tstamp): + if (size != sizeof(__u64)) + return false; + break; + case offsetof(struct __sk_buff, sk): + if (type == BPF_WRITE || size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_SOCK_COMMON_OR_NULL; + break; + case offsetof(struct __sk_buff, tstamp_type): + return false; + case offsetofend(struct __sk_buff, tstamp_type) ... offsetof(struct __sk_buff, hwtstamp) - 1: + /* Explicitly prohibit access to padding in __sk_buff. */ + return false; + default: + /* Only narrow read access allowed for now. */ + if (type == BPF_WRITE) { + if (size != size_default) + return false; + } else { + bpf_ctx_record_field_size(info, size_default); + if (!bpf_ctx_narrow_access_ok(off, size, size_default)) + return false; + } + } + + return true; +} + +static bool sk_filter_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range(struct __sk_buff, data): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, data_end): + case bpf_ctx_range_till(struct __sk_buff, family, local_port): + case bpf_ctx_range(struct __sk_buff, tstamp): + case bpf_ctx_range(struct __sk_buff, wire_len): + case bpf_ctx_range(struct __sk_buff, hwtstamp): + return false; + } + + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + break; + default: + return false; + } + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + +static bool cg_skb_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, wire_len): + return false; + case bpf_ctx_range(struct __sk_buff, data): + case bpf_ctx_range(struct __sk_buff, data_end): + if (!bpf_capable()) + return false; + break; + } + + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + case bpf_ctx_range(struct __sk_buff, priority): + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + break; + case bpf_ctx_range(struct __sk_buff, tstamp): + if (!bpf_capable()) + return false; + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return 
bpf_skb_is_valid_access(off, size, type, prog, info); +} + +static bool lwt_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range_till(struct __sk_buff, family, local_port): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, tstamp): + case bpf_ctx_range(struct __sk_buff, wire_len): + case bpf_ctx_range(struct __sk_buff, hwtstamp): + return false; + } + + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + case bpf_ctx_range(struct __sk_buff, priority): + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + +/* Attach type specific accesses */ +static bool __sock_filter_check_attach_type(int off, + enum bpf_access_type access_type, + enum bpf_attach_type attach_type) +{ + switch (off) { + case offsetof(struct bpf_sock, bound_dev_if): + case offsetof(struct bpf_sock, mark): + case offsetof(struct bpf_sock, priority): + switch (attach_type) { + case BPF_CGROUP_INET_SOCK_CREATE: + case BPF_CGROUP_INET_SOCK_RELEASE: + goto full_access; + default: + return false; + } + case bpf_ctx_range(struct bpf_sock, src_ip4): + switch (attach_type) { + case BPF_CGROUP_INET4_POST_BIND: + goto read_only; + default: + return false; + } + case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): + switch (attach_type) { + case BPF_CGROUP_INET6_POST_BIND: + goto read_only; + default: + return false; + } + case bpf_ctx_range(struct bpf_sock, src_port): + switch (attach_type) { + case BPF_CGROUP_INET4_POST_BIND: + case BPF_CGROUP_INET6_POST_BIND: + goto read_only; + default: + return false; + } + } +read_only: + return access_type == BPF_READ; +full_access: + return true; +} + +bool bpf_sock_common_is_valid_access(int off, int size, + enum bpf_access_type type, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range_till(struct bpf_sock, type, priority): + return false; + default: + return bpf_sock_is_valid_access(off, size, type, info); + } +} + +bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type, + struct bpf_insn_access_aux *info) +{ + const int size_default = sizeof(__u32); + int field_size; + + if (off < 0 || off >= sizeof(struct bpf_sock)) + return false; + if (off % size != 0) + return false; + + switch (off) { + case offsetof(struct bpf_sock, state): + case offsetof(struct bpf_sock, family): + case offsetof(struct bpf_sock, type): + case offsetof(struct bpf_sock, protocol): + case offsetof(struct bpf_sock, src_port): + case offsetof(struct bpf_sock, rx_queue_mapping): + case bpf_ctx_range(struct bpf_sock, src_ip4): + case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): + case bpf_ctx_range(struct bpf_sock, dst_ip4): + case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]): + bpf_ctx_record_field_size(info, size_default); + return bpf_ctx_narrow_access_ok(off, size, size_default); + case bpf_ctx_range(struct bpf_sock, dst_port): + field_size = size == size_default ? 
+ size_default : sizeof_field(struct bpf_sock, dst_port); + bpf_ctx_record_field_size(info, field_size); + return bpf_ctx_narrow_access_ok(off, size, field_size); + case offsetofend(struct bpf_sock, dst_port) ... + offsetof(struct bpf_sock, dst_ip4) - 1: + return false; + } + + return size == size_default; +} + +static bool sock_filter_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (!bpf_sock_is_valid_access(off, size, type, info)) + return false; + return __sock_filter_check_attach_type(off, type, + prog->expected_attach_type); +} + +static int bpf_noop_prologue(struct bpf_insn *insn_buf, bool direct_write, + const struct bpf_prog *prog) +{ + /* Neither direct read nor direct write requires any preliminary + * action. + */ + return 0; +} + +static int bpf_unclone_prologue(struct bpf_insn *insn_buf, bool direct_write, + const struct bpf_prog *prog, int drop_verdict) +{ + struct bpf_insn *insn = insn_buf; + + if (!direct_write) + return 0; + + /* if (!skb->cloned) + * goto start; + * + * (Fast-path, otherwise approximation that we might be + * a clone, do the rest in helper.) + */ + *insn++ = BPF_LDX_MEM(BPF_B, BPF_REG_6, BPF_REG_1, CLONED_OFFSET); + *insn++ = BPF_ALU32_IMM(BPF_AND, BPF_REG_6, CLONED_MASK); + *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_6, 0, 7); + + /* ret = bpf_skb_pull_data(skb, 0); */ + *insn++ = BPF_MOV64_REG(BPF_REG_6, BPF_REG_1); + *insn++ = BPF_ALU64_REG(BPF_XOR, BPF_REG_2, BPF_REG_2); + *insn++ = BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0, + BPF_FUNC_skb_pull_data); + /* if (!ret) + * goto restore; + * return TC_ACT_SHOT; + */ + *insn++ = BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 2); + *insn++ = BPF_ALU32_IMM(BPF_MOV, BPF_REG_0, drop_verdict); + *insn++ = BPF_EXIT_INSN(); + + /* restore: */ + *insn++ = BPF_MOV64_REG(BPF_REG_1, BPF_REG_6); + /* start: */ + *insn++ = prog->insnsi[0]; + + return insn - insn_buf; +} + +static int bpf_gen_ld_abs(const struct bpf_insn *orig, + struct bpf_insn *insn_buf) +{ + bool indirect = BPF_MODE(orig->code) == BPF_IND; + struct bpf_insn *insn = insn_buf; + + if (!indirect) { + *insn++ = BPF_MOV64_IMM(BPF_REG_2, orig->imm); + } else { + *insn++ = BPF_MOV64_REG(BPF_REG_2, orig->src_reg); + if (orig->imm) + *insn++ = BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, orig->imm); + } + /* We're guaranteed here that CTX is in R6. 
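+ * For example, a classic "ldh [12]" (BPF_LD | BPF_ABS, 2-byte load at
+ * offset 12) is expanded here into roughly:
+ *
+ *	r2 = 12
+ *	r1 = r6			(the skb context)
+ *	r0 = bpf_skb_load_helper_16_no_cache(r1, r2)
+ *	if r0 s>= 0 goto the next program instruction
+ *	r0 = 0
+ *	exit
+ *
+ * i.e. a negative helper return (e.g. the offset is beyond the end of
+ * the packet) terminates the program with return value 0.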
*/ + *insn++ = BPF_MOV64_REG(BPF_REG_1, BPF_REG_CTX); + + switch (BPF_SIZE(orig->code)) { + case BPF_B: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_8_no_cache); + break; + case BPF_H: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_16_no_cache); + break; + case BPF_W: + *insn++ = BPF_EMIT_CALL(bpf_skb_load_helper_32_no_cache); + break; + } + + *insn++ = BPF_JMP_IMM(BPF_JSGE, BPF_REG_0, 0, 2); + *insn++ = BPF_ALU32_REG(BPF_XOR, BPF_REG_0, BPF_REG_0); + *insn++ = BPF_EXIT_INSN(); + + return insn - insn_buf; +} + +static int tc_cls_act_prologue(struct bpf_insn *insn_buf, bool direct_write, + const struct bpf_prog *prog) +{ + return bpf_unclone_prologue(insn_buf, direct_write, prog, TC_ACT_SHOT); +} + +static bool tc_cls_act_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + case bpf_ctx_range(struct __sk_buff, tc_index): + case bpf_ctx_range(struct __sk_buff, priority): + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]): + case bpf_ctx_range(struct __sk_buff, tstamp): + case bpf_ctx_range(struct __sk_buff, queue_mapping): + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_meta): + info->reg_type = PTR_TO_PACKET_META; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + case bpf_ctx_range_till(struct __sk_buff, family, local_port): + return false; + case offsetof(struct __sk_buff, tstamp_type): + /* The convert_ctx_access() on reading and writing + * __sk_buff->tstamp depends on whether the bpf prog + * has used __sk_buff->tstamp_type or not. + * Thus, we need to set prog->tstamp_type_access + * earlier during is_valid_access() here. 
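+ * E.g. a tc prog that first does
+ *
+ *	if (skb->tstamp_type == BPF_SKB_TSTAMP_DELIVERY_MONO)
+ *		ts = skb->tstamp;
+ *
+ * gets tstamp_type_access set here, and its later tstamp loads and
+ * stores are then converted to access skb->tstamp as is.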
+ */ + ((struct bpf_prog *)prog)->tstamp_type_access = 1; + return size == sizeof(__u8); + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + +DEFINE_MUTEX(nf_conn_btf_access_lock); +EXPORT_SYMBOL_GPL(nf_conn_btf_access_lock); + +int (*nfct_btf_struct_access)(struct bpf_verifier_log *log, const struct btf *btf, + const struct btf_type *t, int off, int size, + enum bpf_access_type atype, u32 *next_btf_id, + enum bpf_type_flag *flag); +EXPORT_SYMBOL_GPL(nfct_btf_struct_access); + +static int tc_cls_act_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id, + enum bpf_type_flag *flag) +{ + int ret = -EACCES; + + if (atype == BPF_READ) + return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, + flag); + + mutex_lock(&nf_conn_btf_access_lock); + if (nfct_btf_struct_access) + ret = nfct_btf_struct_access(log, btf, t, off, size, atype, next_btf_id, flag); + mutex_unlock(&nf_conn_btf_access_lock); + + return ret; +} + +static bool __is_valid_xdp_access(int off, int size) +{ + if (off < 0 || off >= sizeof(struct xdp_md)) + return false; + if (off % size != 0) + return false; + if (size != sizeof(__u32)) + return false; + + return true; +} + +static bool xdp_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (prog->expected_attach_type != BPF_XDP_DEVMAP) { + switch (off) { + case offsetof(struct xdp_md, egress_ifindex): + return false; + } + } + + if (type == BPF_WRITE) { + if (bpf_prog_is_dev_bound(prog->aux)) { + switch (off) { + case offsetof(struct xdp_md, rx_queue_index): + return __is_valid_xdp_access(off, size); + } + } + return false; + } + + switch (off) { + case offsetof(struct xdp_md, data): + info->reg_type = PTR_TO_PACKET; + break; + case offsetof(struct xdp_md, data_meta): + info->reg_type = PTR_TO_PACKET_META; + break; + case offsetof(struct xdp_md, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return __is_valid_xdp_access(off, size); +} + +void bpf_warn_invalid_xdp_action(struct net_device *dev, struct bpf_prog *prog, u32 act) +{ + const u32 act_max = XDP_REDIRECT; + + pr_warn_once("%s XDP return value %u on prog %s (id %d) dev %s, expect packet loss!\n", + act > act_max ? "Illegal" : "Driver unsupported", + act, prog->aux->name, prog->aux->id, dev ? dev->name : "N/A"); +} +EXPORT_SYMBOL_GPL(bpf_warn_invalid_xdp_action); + +static int xdp_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id, + enum bpf_type_flag *flag) +{ + int ret = -EACCES; + + if (atype == BPF_READ) + return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, + flag); + + mutex_lock(&nf_conn_btf_access_lock); + if (nfct_btf_struct_access) + ret = nfct_btf_struct_access(log, btf, t, off, size, atype, next_btf_id, flag); + mutex_unlock(&nf_conn_btf_access_lock); + + return ret; +} + +static bool sock_addr_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + const int size_default = sizeof(__u32); + + if (off < 0 || off >= sizeof(struct bpf_sock_addr)) + return false; + if (off % size != 0) + return false; + + /* Disallow access to IPv6 fields from IPv4 contex and vise + * versa. 
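+ * E.g. a BPF_CGROUP_INET4_CONNECT prog may access user_ip4 but not
+ * user_ip6[0..3], while a BPF_CGROUP_INET6_CONNECT prog gets the
+ * opposite; the switch below enforces this per expected_attach_type.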
+ */ + switch (off) { + case bpf_ctx_range(struct bpf_sock_addr, user_ip4): + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET4_BIND: + case BPF_CGROUP_INET4_CONNECT: + case BPF_CGROUP_INET4_GETPEERNAME: + case BPF_CGROUP_INET4_GETSOCKNAME: + case BPF_CGROUP_UDP4_SENDMSG: + case BPF_CGROUP_UDP4_RECVMSG: + break; + default: + return false; + } + break; + case bpf_ctx_range_till(struct bpf_sock_addr, user_ip6[0], user_ip6[3]): + switch (prog->expected_attach_type) { + case BPF_CGROUP_INET6_BIND: + case BPF_CGROUP_INET6_CONNECT: + case BPF_CGROUP_INET6_GETPEERNAME: + case BPF_CGROUP_INET6_GETSOCKNAME: + case BPF_CGROUP_UDP6_SENDMSG: + case BPF_CGROUP_UDP6_RECVMSG: + break; + default: + return false; + } + break; + case bpf_ctx_range(struct bpf_sock_addr, msg_src_ip4): + switch (prog->expected_attach_type) { + case BPF_CGROUP_UDP4_SENDMSG: + break; + default: + return false; + } + break; + case bpf_ctx_range_till(struct bpf_sock_addr, msg_src_ip6[0], + msg_src_ip6[3]): + switch (prog->expected_attach_type) { + case BPF_CGROUP_UDP6_SENDMSG: + break; + default: + return false; + } + break; + } + + switch (off) { + case bpf_ctx_range(struct bpf_sock_addr, user_ip4): + case bpf_ctx_range_till(struct bpf_sock_addr, user_ip6[0], user_ip6[3]): + case bpf_ctx_range(struct bpf_sock_addr, msg_src_ip4): + case bpf_ctx_range_till(struct bpf_sock_addr, msg_src_ip6[0], + msg_src_ip6[3]): + case bpf_ctx_range(struct bpf_sock_addr, user_port): + if (type == BPF_READ) { + bpf_ctx_record_field_size(info, size_default); + + if (bpf_ctx_wide_access_ok(off, size, + struct bpf_sock_addr, + user_ip6)) + return true; + + if (bpf_ctx_wide_access_ok(off, size, + struct bpf_sock_addr, + msg_src_ip6)) + return true; + + if (!bpf_ctx_narrow_access_ok(off, size, size_default)) + return false; + } else { + if (bpf_ctx_wide_access_ok(off, size, + struct bpf_sock_addr, + user_ip6)) + return true; + + if (bpf_ctx_wide_access_ok(off, size, + struct bpf_sock_addr, + msg_src_ip6)) + return true; + + if (size != size_default) + return false; + } + break; + case offsetof(struct bpf_sock_addr, sk): + if (type != BPF_READ) + return false; + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_SOCKET; + break; + default: + if (type == BPF_READ) { + if (size != size_default) + return false; + } else { + return false; + } + } + + return true; +} + +static bool sock_ops_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + const int size_default = sizeof(__u32); + + if (off < 0 || off >= sizeof(struct bpf_sock_ops)) + return false; + + /* The verifier guarantees that size > 0. 
*/ + if (off % size != 0) + return false; + + if (type == BPF_WRITE) { + switch (off) { + case offsetof(struct bpf_sock_ops, reply): + case offsetof(struct bpf_sock_ops, sk_txhash): + if (size != size_default) + return false; + break; + default: + return false; + } + } else { + switch (off) { + case bpf_ctx_range_till(struct bpf_sock_ops, bytes_received, + bytes_acked): + if (size != sizeof(__u64)) + return false; + break; + case offsetof(struct bpf_sock_ops, sk): + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_SOCKET_OR_NULL; + break; + case offsetof(struct bpf_sock_ops, skb_data): + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_PACKET; + break; + case offsetof(struct bpf_sock_ops, skb_data_end): + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_PACKET_END; + break; + case offsetof(struct bpf_sock_ops, skb_tcp_flags): + bpf_ctx_record_field_size(info, size_default); + return bpf_ctx_narrow_access_ok(off, size, + size_default); + default: + if (size != size_default) + return false; + break; + } + } + + return true; +} + +static int sk_skb_prologue(struct bpf_insn *insn_buf, bool direct_write, + const struct bpf_prog *prog) +{ + return bpf_unclone_prologue(insn_buf, direct_write, prog, SK_DROP); +} + +static bool sk_skb_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_classid): + case bpf_ctx_range(struct __sk_buff, data_meta): + case bpf_ctx_range(struct __sk_buff, tstamp): + case bpf_ctx_range(struct __sk_buff, wire_len): + case bpf_ctx_range(struct __sk_buff, hwtstamp): + return false; + } + + if (type == BPF_WRITE) { + switch (off) { + case bpf_ctx_range(struct __sk_buff, tc_index): + case bpf_ctx_range(struct __sk_buff, priority): + break; + default: + return false; + } + } + + switch (off) { + case bpf_ctx_range(struct __sk_buff, mark): + return false; + case bpf_ctx_range(struct __sk_buff, data): + info->reg_type = PTR_TO_PACKET; + break; + case bpf_ctx_range(struct __sk_buff, data_end): + info->reg_type = PTR_TO_PACKET_END; + break; + } + + return bpf_skb_is_valid_access(off, size, type, prog, info); +} + +static bool sk_msg_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (type == BPF_WRITE) + return false; + + if (off % size != 0) + return false; + + switch (off) { + case offsetof(struct sk_msg_md, data): + info->reg_type = PTR_TO_PACKET; + if (size != sizeof(__u64)) + return false; + break; + case offsetof(struct sk_msg_md, data_end): + info->reg_type = PTR_TO_PACKET_END; + if (size != sizeof(__u64)) + return false; + break; + case offsetof(struct sk_msg_md, sk): + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_SOCKET; + break; + case bpf_ctx_range(struct sk_msg_md, family): + case bpf_ctx_range(struct sk_msg_md, remote_ip4): + case bpf_ctx_range(struct sk_msg_md, local_ip4): + case bpf_ctx_range_till(struct sk_msg_md, remote_ip6[0], remote_ip6[3]): + case bpf_ctx_range_till(struct sk_msg_md, local_ip6[0], local_ip6[3]): + case bpf_ctx_range(struct sk_msg_md, remote_port): + case bpf_ctx_range(struct sk_msg_md, local_port): + case bpf_ctx_range(struct sk_msg_md, size): + if (size != sizeof(__u32)) + return false; + break; + default: + return false; + } + return true; +} + +static bool flow_dissector_is_valid_access(int off, int size, + enum bpf_access_type type, 
+ const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + const int size_default = sizeof(__u32); + + if (off < 0 || off >= sizeof(struct __sk_buff)) + return false; + + if (type == BPF_WRITE) + return false; + + switch (off) { + case bpf_ctx_range(struct __sk_buff, data): + if (size != size_default) + return false; + info->reg_type = PTR_TO_PACKET; + return true; + case bpf_ctx_range(struct __sk_buff, data_end): + if (size != size_default) + return false; + info->reg_type = PTR_TO_PACKET_END; + return true; + case bpf_ctx_range_ptr(struct __sk_buff, flow_keys): + if (size != sizeof(__u64)) + return false; + info->reg_type = PTR_TO_FLOW_KEYS; + return true; + default: + return false; + } +} + +static u32 flow_dissector_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, + u32 *target_size) + +{ + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct __sk_buff, data): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_flow_dissector, data), + si->dst_reg, si->src_reg, + offsetof(struct bpf_flow_dissector, data)); + break; + + case offsetof(struct __sk_buff, data_end): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_flow_dissector, data_end), + si->dst_reg, si->src_reg, + offsetof(struct bpf_flow_dissector, data_end)); + break; + + case offsetof(struct __sk_buff, flow_keys): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_flow_dissector, flow_keys), + si->dst_reg, si->src_reg, + offsetof(struct bpf_flow_dissector, flow_keys)); + break; + } + + return insn - insn_buf; +} + +static struct bpf_insn *bpf_convert_tstamp_type_read(const struct bpf_insn *si, + struct bpf_insn *insn) +{ + __u8 value_reg = si->dst_reg; + __u8 skb_reg = si->src_reg; + /* AX is needed because src_reg and dst_reg could be the same */ + __u8 tmp_reg = BPF_REG_AX; + + *insn++ = BPF_LDX_MEM(BPF_B, tmp_reg, skb_reg, + PKT_VLAN_PRESENT_OFFSET); + *insn++ = BPF_JMP32_IMM(BPF_JSET, tmp_reg, + SKB_MONO_DELIVERY_TIME_MASK, 2); + *insn++ = BPF_MOV32_IMM(value_reg, BPF_SKB_TSTAMP_UNSPEC); + *insn++ = BPF_JMP_A(1); + *insn++ = BPF_MOV32_IMM(value_reg, BPF_SKB_TSTAMP_DELIVERY_MONO); + + return insn; +} + +static struct bpf_insn *bpf_convert_shinfo_access(const struct bpf_insn *si, + struct bpf_insn *insn) +{ + /* si->dst_reg = skb_shinfo(SKB); */ +#ifdef NET_SKBUFF_DATA_USES_OFFSET + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), + BPF_REG_AX, si->src_reg, + offsetof(struct sk_buff, end)); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, head)); + *insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX); +#else + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, end)); +#endif + + return insn; +} + +static struct bpf_insn *bpf_convert_tstamp_read(const struct bpf_prog *prog, + const struct bpf_insn *si, + struct bpf_insn *insn) +{ + __u8 value_reg = si->dst_reg; + __u8 skb_reg = si->src_reg; + +#ifdef CONFIG_NET_CLS_ACT + /* If the tstamp_type is read, + * the bpf prog is aware the tstamp could have delivery time. + * Thus, read skb->tstamp as is if tstamp_type_access is true. 
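+ * Otherwise the sequence emitted below behaves roughly like:
+ *
+ *	if (skb->tc_at_ingress && skb->mono_delivery_time)
+ *		value = 0	(no rcv timestamp is available)
+ *	else
+ *		value = skb->tstamp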
+ */ + if (!prog->tstamp_type_access) { + /* AX is needed because src_reg and dst_reg could be the same */ + __u8 tmp_reg = BPF_REG_AX; + + *insn++ = BPF_LDX_MEM(BPF_B, tmp_reg, skb_reg, PKT_VLAN_PRESENT_OFFSET); + *insn++ = BPF_ALU32_IMM(BPF_AND, tmp_reg, + TC_AT_INGRESS_MASK | SKB_MONO_DELIVERY_TIME_MASK); + *insn++ = BPF_JMP32_IMM(BPF_JNE, tmp_reg, + TC_AT_INGRESS_MASK | SKB_MONO_DELIVERY_TIME_MASK, 2); + /* skb->tc_at_ingress && skb->mono_delivery_time, + * read 0 as the (rcv) timestamp. + */ + *insn++ = BPF_MOV64_IMM(value_reg, 0); + *insn++ = BPF_JMP_A(1); + } +#endif + + *insn++ = BPF_LDX_MEM(BPF_DW, value_reg, skb_reg, + offsetof(struct sk_buff, tstamp)); + return insn; +} + +static struct bpf_insn *bpf_convert_tstamp_write(const struct bpf_prog *prog, + const struct bpf_insn *si, + struct bpf_insn *insn) +{ + __u8 value_reg = si->src_reg; + __u8 skb_reg = si->dst_reg; + +#ifdef CONFIG_NET_CLS_ACT + /* If the tstamp_type is read, + * the bpf prog is aware the tstamp could have delivery time. + * Thus, write skb->tstamp as is if tstamp_type_access is true. + * Otherwise, writing at ingress will have to clear the + * mono_delivery_time bit also. + */ + if (!prog->tstamp_type_access) { + __u8 tmp_reg = BPF_REG_AX; + + *insn++ = BPF_LDX_MEM(BPF_B, tmp_reg, skb_reg, PKT_VLAN_PRESENT_OFFSET); + /* Writing __sk_buff->tstamp as ingress, goto <clear> */ + *insn++ = BPF_JMP32_IMM(BPF_JSET, tmp_reg, TC_AT_INGRESS_MASK, 1); + /* goto <store> */ + *insn++ = BPF_JMP_A(2); + /* <clear>: mono_delivery_time */ + *insn++ = BPF_ALU32_IMM(BPF_AND, tmp_reg, ~SKB_MONO_DELIVERY_TIME_MASK); + *insn++ = BPF_STX_MEM(BPF_B, skb_reg, tmp_reg, PKT_VLAN_PRESENT_OFFSET); + } +#endif + + /* <store>: skb->tstamp = tstamp */ + *insn++ = BPF_STX_MEM(BPF_DW, skb_reg, value_reg, + offsetof(struct sk_buff, tstamp)); + return insn; +} + +static u32 bpf_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + int off; + + switch (si->off) { + case offsetof(struct __sk_buff, len): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, len, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, protocol): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, protocol, 2, + target_size)); + break; + + case offsetof(struct __sk_buff, vlan_proto): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, vlan_proto, 2, + target_size)); + break; + + case offsetof(struct __sk_buff, priority): + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, priority, 4, + target_size)); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, priority, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, ingress_ifindex): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, skb_iif, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, ifindex): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, dev), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, dev)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + bpf_target_off(struct net_device, ifindex, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, hash): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct 
sk_buff, hash, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, mark): + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, mark, 4, + target_size)); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, mark, 4, + target_size)); + break; + + case offsetof(struct __sk_buff, pkt_type): + *target_size = 1; + *insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->src_reg, + PKT_TYPE_OFFSET); + *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, PKT_TYPE_MAX); +#ifdef __BIG_ENDIAN_BITFIELD + *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, 5); +#endif + break; + + case offsetof(struct __sk_buff, queue_mapping): + if (type == BPF_WRITE) { + *insn++ = BPF_JMP_IMM(BPF_JGE, si->src_reg, NO_QUEUE_MAPPING, 1); + *insn++ = BPF_STX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, + queue_mapping, + 2, target_size)); + } else { + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, + queue_mapping, + 2, target_size)); + } + break; + + case offsetof(struct __sk_buff, vlan_present): + *target_size = 1; + *insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->src_reg, + PKT_VLAN_PRESENT_OFFSET); + if (PKT_VLAN_PRESENT_BIT) + *insn++ = BPF_ALU32_IMM(BPF_RSH, si->dst_reg, PKT_VLAN_PRESENT_BIT); + if (PKT_VLAN_PRESENT_BIT < 7) + *insn++ = BPF_ALU32_IMM(BPF_AND, si->dst_reg, 1); + break; + + case offsetof(struct __sk_buff, vlan_tci): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, vlan_tci, 2, + target_size)); + break; + + case offsetof(struct __sk_buff, cb[0]) ... + offsetofend(struct __sk_buff, cb[4]) - 1: + BUILD_BUG_ON(sizeof_field(struct qdisc_skb_cb, data) < 20); + BUILD_BUG_ON((offsetof(struct sk_buff, cb) + + offsetof(struct qdisc_skb_cb, data)) % + sizeof(__u64)); + + prog->cb_access = 1; + off = si->off; + off -= offsetof(struct __sk_buff, cb[0]); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct qdisc_skb_cb, data); + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_SIZE(si->code), si->dst_reg, + si->src_reg, off); + else + *insn++ = BPF_LDX_MEM(BPF_SIZE(si->code), si->dst_reg, + si->src_reg, off); + break; + + case offsetof(struct __sk_buff, tc_classid): + BUILD_BUG_ON(sizeof_field(struct qdisc_skb_cb, tc_classid) != 2); + + off = si->off; + off -= offsetof(struct __sk_buff, tc_classid); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct qdisc_skb_cb, tc_classid); + *target_size = 2; + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_H, si->dst_reg, + si->src_reg, off); + else + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, + si->src_reg, off); + break; + + case offsetof(struct __sk_buff, data): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, data), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, data)); + break; + + case offsetof(struct __sk_buff, data_meta): + off = si->off; + off -= offsetof(struct __sk_buff, data_meta); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct bpf_skb_data_end, data_meta); + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, + si->src_reg, off); + break; + + case offsetof(struct __sk_buff, data_end): + off = si->off; + off -= offsetof(struct __sk_buff, data_end); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct bpf_skb_data_end, data_end); + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, + si->src_reg, off); + break; + + case offsetof(struct __sk_buff, tc_index): +#ifdef CONFIG_NET_SCHED + if 
(type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, tc_index, 2, + target_size)); + else + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, tc_index, 2, + target_size)); +#else + *target_size = 2; + if (type == BPF_WRITE) + *insn++ = BPF_MOV64_REG(si->dst_reg, si->dst_reg); + else + *insn++ = BPF_MOV64_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct __sk_buff, napi_id): +#if defined(CONFIG_NET_RX_BUSY_POLL) + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct sk_buff, napi_id, 4, + target_size)); + *insn++ = BPF_JMP_IMM(BPF_JGE, si->dst_reg, MIN_NAPI_ID, 1); + *insn++ = BPF_MOV64_IMM(si->dst_reg, 0); +#else + *target_size = 4; + *insn++ = BPF_MOV64_IMM(si->dst_reg, 0); +#endif + break; + case offsetof(struct __sk_buff, family): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_family) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + bpf_target_off(struct sock_common, + skc_family, + 2, target_size)); + break; + case offsetof(struct __sk_buff, remote_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_daddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + bpf_target_off(struct sock_common, + skc_daddr, + 4, target_size)); + break; + case offsetof(struct __sk_buff, local_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_rcv_saddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + bpf_target_off(struct sock_common, + skc_rcv_saddr, + 4, target_size)); + break; + case offsetof(struct __sk_buff, remote_ip6[0]) ... + offsetof(struct __sk_buff, remote_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_daddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct __sk_buff, remote_ip6[0]); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_daddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + case offsetof(struct __sk_buff, local_ip6[0]) ... 
+ offsetof(struct __sk_buff, local_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct __sk_buff, local_ip6[0]); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct __sk_buff, remote_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_dport) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + bpf_target_off(struct sock_common, + skc_dport, + 2, target_size)); +#ifndef __BIG_ENDIAN_BITFIELD + *insn++ = BPF_ALU32_IMM(BPF_LSH, si->dst_reg, 16); +#endif + break; + + case offsetof(struct __sk_buff, local_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_num) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + bpf_target_off(struct sock_common, + skc_num, 2, target_size)); + break; + + case offsetof(struct __sk_buff, tstamp): + BUILD_BUG_ON(sizeof_field(struct sk_buff, tstamp) != 8); + + if (type == BPF_WRITE) + insn = bpf_convert_tstamp_write(prog, si, insn); + else + insn = bpf_convert_tstamp_read(prog, si, insn); + break; + + case offsetof(struct __sk_buff, tstamp_type): + insn = bpf_convert_tstamp_type_read(si, insn); + break; + + case offsetof(struct __sk_buff, gso_segs): + insn = bpf_convert_shinfo_access(si, insn); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_segs), + si->dst_reg, si->dst_reg, + bpf_target_off(struct skb_shared_info, + gso_segs, 2, + target_size)); + break; + case offsetof(struct __sk_buff, gso_size): + insn = bpf_convert_shinfo_access(si, insn); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct skb_shared_info, gso_size), + si->dst_reg, si->dst_reg, + bpf_target_off(struct skb_shared_info, + gso_size, 2, + target_size)); + break; + case offsetof(struct __sk_buff, wire_len): + BUILD_BUG_ON(sizeof_field(struct qdisc_skb_cb, pkt_len) != 4); + + off = si->off; + off -= offsetof(struct __sk_buff, wire_len); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct qdisc_skb_cb, pkt_len); + *target_size = 4; + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, off); + break; + + case offsetof(struct __sk_buff, sk): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, sk)); + break; + case offsetof(struct __sk_buff, hwtstamp): + BUILD_BUG_ON(sizeof_field(struct skb_shared_hwtstamps, hwtstamp) != 8); + BUILD_BUG_ON(offsetof(struct skb_shared_hwtstamps, hwtstamp) != 0); + + insn = bpf_convert_shinfo_access(si, insn); + *insn++ = BPF_LDX_MEM(BPF_DW, + si->dst_reg, si->dst_reg, + bpf_target_off(struct skb_shared_info, + hwtstamps, 8, + target_size)); + break; + } + + return insn - insn_buf; +} + +u32 bpf_sock_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + int off; + + switch (si->off) { + case offsetof(struct bpf_sock, bound_dev_if): + 
BUILD_BUG_ON(sizeof_field(struct sock, sk_bound_dev_if) != 4); + + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_bound_dev_if)); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_bound_dev_if)); + break; + + case offsetof(struct bpf_sock, mark): + BUILD_BUG_ON(sizeof_field(struct sock, sk_mark) != 4); + + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_mark)); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_mark)); + break; + + case offsetof(struct bpf_sock, priority): + BUILD_BUG_ON(sizeof_field(struct sock, sk_priority) != 4); + + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_priority)); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + offsetof(struct sock, sk_priority)); + break; + + case offsetof(struct bpf_sock, family): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_family), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, + skc_family, + sizeof_field(struct sock_common, + skc_family), + target_size)); + break; + + case offsetof(struct bpf_sock, type): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock, sk_type), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock, sk_type, + sizeof_field(struct sock, sk_type), + target_size)); + break; + + case offsetof(struct bpf_sock, protocol): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock, sk_protocol), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock, sk_protocol, + sizeof_field(struct sock, sk_protocol), + target_size)); + break; + + case offsetof(struct bpf_sock, src_ip4): + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_rcv_saddr, + sizeof_field(struct sock_common, + skc_rcv_saddr), + target_size)); + break; + + case offsetof(struct bpf_sock, dst_ip4): + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_daddr, + sizeof_field(struct sock_common, + skc_daddr), + target_size)); + break; + + case bpf_ctx_range_till(struct bpf_sock, src_ip6[0], src_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + off = si->off; + off -= offsetof(struct bpf_sock, src_ip6[0]); + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off( + struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0], + sizeof_field(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]), + target_size) + off); +#else + (void)off; + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case bpf_ctx_range_till(struct bpf_sock, dst_ip6[0], dst_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + off = si->off; + off -= offsetof(struct bpf_sock, dst_ip6[0]); + *insn++ = BPF_LDX_MEM( + BPF_SIZE(si->code), si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, + skc_v6_daddr.s6_addr32[0], + sizeof_field(struct sock_common, + skc_v6_daddr.s6_addr32[0]), + target_size) + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); + *target_size = 4; +#endif + break; + + case offsetof(struct bpf_sock, src_port): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_num), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_num, + sizeof_field(struct sock_common, + skc_num), + target_size)); + break; + + case offsetof(struct bpf_sock, dst_port): + *insn++ = BPF_LDX_MEM( + 
BPF_FIELD_SIZEOF(struct sock_common, skc_dport), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_dport, + sizeof_field(struct sock_common, + skc_dport), + target_size)); + break; + + case offsetof(struct bpf_sock, state): + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock_common, skc_state), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock_common, skc_state, + sizeof_field(struct sock_common, + skc_state), + target_size)); + break; + case offsetof(struct bpf_sock, rx_queue_mapping): +#ifdef CONFIG_SOCK_RX_QUEUE_MAPPING + *insn++ = BPF_LDX_MEM( + BPF_FIELD_SIZEOF(struct sock, sk_rx_queue_mapping), + si->dst_reg, si->src_reg, + bpf_target_off(struct sock, sk_rx_queue_mapping, + sizeof_field(struct sock, + sk_rx_queue_mapping), + target_size)); + *insn++ = BPF_JMP_IMM(BPF_JNE, si->dst_reg, NO_QUEUE_MAPPING, + 1); + *insn++ = BPF_MOV64_IMM(si->dst_reg, -1); +#else + *insn++ = BPF_MOV64_IMM(si->dst_reg, -1); + *target_size = 2; +#endif + break; + } + + return insn - insn_buf; +} + +static u32 tc_cls_act_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct __sk_buff, ifindex): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, dev), + si->dst_reg, si->src_reg, + offsetof(struct sk_buff, dev)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + bpf_target_off(struct net_device, ifindex, 4, + target_size)); + break; + default: + return bpf_convert_ctx_access(type, si, insn_buf, prog, + target_size); + } + + return insn - insn_buf; +} + +static u32 xdp_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct xdp_md, data): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, data), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, data)); + break; + case offsetof(struct xdp_md, data_meta): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, data_meta), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, data_meta)); + break; + case offsetof(struct xdp_md, data_end): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, data_end), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, data_end)); + break; + case offsetof(struct xdp_md, ingress_ifindex): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, rxq), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, rxq)); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_rxq_info, dev), + si->dst_reg, si->dst_reg, + offsetof(struct xdp_rxq_info, dev)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct net_device, ifindex)); + break; + case offsetof(struct xdp_md, rx_queue_index): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, rxq), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, rxq)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct xdp_rxq_info, + queue_index)); + break; + case offsetof(struct xdp_md, egress_ifindex): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_buff, txq), + si->dst_reg, si->src_reg, + offsetof(struct xdp_buff, txq)); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct xdp_txq_info, dev), + si->dst_reg, si->dst_reg, + offsetof(struct xdp_txq_info, dev)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct 
net_device, ifindex)); + break; + } + + return insn - insn_buf; +} + +/* SOCK_ADDR_LOAD_NESTED_FIELD() loads Nested Field S.F.NF where S is type of + * context Structure, F is Field in context structure that contains a pointer + * to Nested Structure of type NS that has the field NF. + * + * SIZE encodes the load size (BPF_B, BPF_H, etc). It's up to caller to make + * sure that SIZE is not greater than actual size of S.F.NF. + * + * If offset OFF is provided, the load happens from that offset relative to + * offset of NF. + */ +#define SOCK_ADDR_LOAD_NESTED_FIELD_SIZE_OFF(S, NS, F, NF, SIZE, OFF) \ + do { \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(S, F), si->dst_reg, \ + si->src_reg, offsetof(S, F)); \ + *insn++ = BPF_LDX_MEM( \ + SIZE, si->dst_reg, si->dst_reg, \ + bpf_target_off(NS, NF, sizeof_field(NS, NF), \ + target_size) \ + + OFF); \ + } while (0) + +#define SOCK_ADDR_LOAD_NESTED_FIELD(S, NS, F, NF) \ + SOCK_ADDR_LOAD_NESTED_FIELD_SIZE_OFF(S, NS, F, NF, \ + BPF_FIELD_SIZEOF(NS, NF), 0) + +/* SOCK_ADDR_STORE_NESTED_FIELD_OFF() has semantic similar to + * SOCK_ADDR_LOAD_NESTED_FIELD_SIZE_OFF() but for store operation. + * + * In addition it uses Temporary Field TF (member of struct S) as the 3rd + * "register" since two registers available in convert_ctx_access are not + * enough: we can't override neither SRC, since it contains value to store, nor + * DST since it contains pointer to context that may be used by later + * instructions. But we need a temporary place to save pointer to nested + * structure whose field we want to store to. + */ +#define SOCK_ADDR_STORE_NESTED_FIELD_OFF(S, NS, F, NF, SIZE, OFF, TF) \ + do { \ + int tmp_reg = BPF_REG_9; \ + if (si->src_reg == tmp_reg || si->dst_reg == tmp_reg) \ + --tmp_reg; \ + if (si->src_reg == tmp_reg || si->dst_reg == tmp_reg) \ + --tmp_reg; \ + *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, tmp_reg, \ + offsetof(S, TF)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(S, F), tmp_reg, \ + si->dst_reg, offsetof(S, F)); \ + *insn++ = BPF_STX_MEM(SIZE, tmp_reg, si->src_reg, \ + bpf_target_off(NS, NF, sizeof_field(NS, NF), \ + target_size) \ + + OFF); \ + *insn++ = BPF_LDX_MEM(BPF_DW, tmp_reg, si->dst_reg, \ + offsetof(S, TF)); \ + } while (0) + +#define SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF(S, NS, F, NF, SIZE, OFF, \ + TF) \ + do { \ + if (type == BPF_WRITE) { \ + SOCK_ADDR_STORE_NESTED_FIELD_OFF(S, NS, F, NF, SIZE, \ + OFF, TF); \ + } else { \ + SOCK_ADDR_LOAD_NESTED_FIELD_SIZE_OFF( \ + S, NS, F, NF, SIZE, OFF); \ + } \ + } while (0) + +#define SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD(S, NS, F, NF, TF) \ + SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( \ + S, NS, F, NF, BPF_FIELD_SIZEOF(NS, NF), 0, TF) + +static u32 sock_addr_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + int off, port_size = sizeof_field(struct sockaddr_in6, sin6_port); + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct bpf_sock_addr, user_family): + SOCK_ADDR_LOAD_NESTED_FIELD(struct bpf_sock_addr_kern, + struct sockaddr, uaddr, sa_family); + break; + + case offsetof(struct bpf_sock_addr, user_ip4): + SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( + struct bpf_sock_addr_kern, struct sockaddr_in, uaddr, + sin_addr, BPF_SIZE(si->code), 0, tmp_reg); + break; + + case bpf_ctx_range_till(struct bpf_sock_addr, user_ip6[0], user_ip6[3]): + off = si->off; + off -= offsetof(struct bpf_sock_addr, user_ip6[0]); + 
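For illustration only, a sketch (not part of this patch) of the classic consumer of these sock_addr conversions: a cgroup/connect4 program that rewrites the destination of a connect() call. The user_ip4 and user_port accesses below are the ones routed through the SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF() macro above; the address and port are placeholders for a local proxy:

	#include <linux/bpf.h>
	#include <bpf/bpf_helpers.h>
	#include <bpf/bpf_endian.h>

	SEC("cgroup/connect4")
	int redirect_connect4(struct bpf_sock_addr *ctx)
	{
		/* both fields live in the sockaddr_in passed to connect() */
		ctx->user_ip4 = bpf_htonl(0x7f000001);	/* 127.0.0.1, placeholder */
		ctx->user_port = bpf_htons(4040);	/* placeholder port */

		return 1;	/* continue with the (rewritten) connect() */
	}

	char _license[] SEC("license") = "GPL";

Returning 0 instead rejects the connect() attempt.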
SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( + struct bpf_sock_addr_kern, struct sockaddr_in6, uaddr, + sin6_addr.s6_addr32[0], BPF_SIZE(si->code), off, + tmp_reg); + break; + + case offsetof(struct bpf_sock_addr, user_port): + /* To get port we need to know sa_family first and then treat + * sockaddr as either sockaddr_in or sockaddr_in6. + * Though we can simplify since port field has same offset and + * size in both structures. + * Here we check this invariant and use just one of the + * structures if it's true. + */ + BUILD_BUG_ON(offsetof(struct sockaddr_in, sin_port) != + offsetof(struct sockaddr_in6, sin6_port)); + BUILD_BUG_ON(sizeof_field(struct sockaddr_in, sin_port) != + sizeof_field(struct sockaddr_in6, sin6_port)); + /* Account for sin6_port being smaller than user_port. */ + port_size = min(port_size, BPF_LDST_BYTES(si)); + SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( + struct bpf_sock_addr_kern, struct sockaddr_in6, uaddr, + sin6_port, bytes_to_bpf_size(port_size), 0, tmp_reg); + break; + + case offsetof(struct bpf_sock_addr, family): + SOCK_ADDR_LOAD_NESTED_FIELD(struct bpf_sock_addr_kern, + struct sock, sk, sk_family); + break; + + case offsetof(struct bpf_sock_addr, type): + SOCK_ADDR_LOAD_NESTED_FIELD(struct bpf_sock_addr_kern, + struct sock, sk, sk_type); + break; + + case offsetof(struct bpf_sock_addr, protocol): + SOCK_ADDR_LOAD_NESTED_FIELD(struct bpf_sock_addr_kern, + struct sock, sk, sk_protocol); + break; + + case offsetof(struct bpf_sock_addr, msg_src_ip4): + /* Treat t_ctx as struct in_addr for msg_src_ip4. */ + SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( + struct bpf_sock_addr_kern, struct in_addr, t_ctx, + s_addr, BPF_SIZE(si->code), 0, tmp_reg); + break; + + case bpf_ctx_range_till(struct bpf_sock_addr, msg_src_ip6[0], + msg_src_ip6[3]): + off = si->off; + off -= offsetof(struct bpf_sock_addr, msg_src_ip6[0]); + /* Treat t_ctx as struct in6_addr for msg_src_ip6. */ + SOCK_ADDR_LOAD_OR_STORE_NESTED_FIELD_SIZE_OFF( + struct bpf_sock_addr_kern, struct in6_addr, t_ctx, + s6_addr32[0], BPF_SIZE(si->code), off, tmp_reg); + break; + case offsetof(struct bpf_sock_addr, sk): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_addr_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_addr_kern, sk)); + break; + } + + return insn - insn_buf; +} + +static u32 sock_ops_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, + u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + int off; + +/* Helper macro for adding read access to tcp_sock or sock fields. 
*/ +#define SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ + do { \ + int fullsock_reg = si->dst_reg, reg = BPF_REG_9, jmp = 2; \ + BUILD_BUG_ON(sizeof_field(OBJ, OBJ_FIELD) > \ + sizeof_field(struct bpf_sock_ops, BPF_FIELD)); \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == si->src_reg) { \ + *insn++ = BPF_STX_MEM(BPF_DW, si->src_reg, reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + fullsock_reg = reg; \ + jmp += 2; \ + } \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, \ + is_fullsock), \ + fullsock_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + is_fullsock)); \ + *insn++ = BPF_JMP_IMM(BPF_JEQ, fullsock_reg, 0, jmp); \ + if (si->dst_reg == si->src_reg) \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, sk),\ + si->dst_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, sk));\ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(OBJ, \ + OBJ_FIELD), \ + si->dst_reg, si->dst_reg, \ + offsetof(OBJ, OBJ_FIELD)); \ + if (si->dst_reg == si->src_reg) { \ + *insn++ = BPF_JMP_A(1); \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + } \ + } while (0) + +#define SOCK_OPS_GET_SK() \ + do { \ + int fullsock_reg = si->dst_reg, reg = BPF_REG_9, jmp = 1; \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == si->src_reg) { \ + *insn++ = BPF_STX_MEM(BPF_DW, si->src_reg, reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + fullsock_reg = reg; \ + jmp += 2; \ + } \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, \ + is_fullsock), \ + fullsock_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + is_fullsock)); \ + *insn++ = BPF_JMP_IMM(BPF_JEQ, fullsock_reg, 0, jmp); \ + if (si->dst_reg == si->src_reg) \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, sk),\ + si->dst_reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, sk));\ + if (si->dst_reg == si->src_reg) { \ + *insn++ = BPF_JMP_A(1); \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->src_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + } \ + } while (0) + +#define SOCK_OPS_GET_TCP_SOCK_FIELD(FIELD) \ + SOCK_OPS_GET_FIELD(FIELD, FIELD, struct tcp_sock) + +/* Helper macro for adding write access to tcp_sock or sock fields. + * The macro is called with two registers, dst_reg which contains a pointer + * to ctx (context) and src_reg which contains the value that should be + * stored. However, we need an additional register since we cannot overwrite + * dst_reg because it may be used later in the program. + * Instead we "borrow" one of the other register. We first save its value + * into a new (temp) field in bpf_sock_ops_kern, use it, and then restore + * it at the end of the macro. 
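 *
 * As a sketch (editor's illustration, not part of this patch): with the
 * usual <linux/bpf.h> and bpf_helpers.h includes, a sockops program such
 * as the one below exercises these macros; the srtt_us and snd_cwnd reads
 * go through SOCK_OPS_GET_TCP_SOCK_FIELD() and the sk_txhash write goes
 * through SOCK_OPS_SET_FIELD().  The hash value is a placeholder.
 *
 *	SEC("sockops")
 *	int estab_cb(struct bpf_sock_ops *skops)
 *	{
 *		if (skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB &&
 *		    skops->srtt_us && skops->snd_cwnd)
 *			skops->sk_txhash = 0x1234;
 *		return 1;
 *	}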
+ */ +#define SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ) \ + do { \ + int reg = BPF_REG_9; \ + BUILD_BUG_ON(sizeof_field(OBJ, OBJ_FIELD) > \ + sizeof_field(struct bpf_sock_ops, BPF_FIELD)); \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + if (si->dst_reg == reg || si->src_reg == reg) \ + reg--; \ + *insn++ = BPF_STX_MEM(BPF_DW, si->dst_reg, reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, \ + is_fullsock), \ + reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + is_fullsock)); \ + *insn++ = BPF_JMP_IMM(BPF_JEQ, reg, 0, 2); \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( \ + struct bpf_sock_ops_kern, sk),\ + reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, sk));\ + *insn++ = BPF_STX_MEM(BPF_FIELD_SIZEOF(OBJ, OBJ_FIELD), \ + reg, si->src_reg, \ + offsetof(OBJ, OBJ_FIELD)); \ + *insn++ = BPF_LDX_MEM(BPF_DW, reg, si->dst_reg, \ + offsetof(struct bpf_sock_ops_kern, \ + temp)); \ + } while (0) + +#define SOCK_OPS_GET_OR_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ, TYPE) \ + do { \ + if (TYPE == BPF_WRITE) \ + SOCK_OPS_SET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ + else \ + SOCK_OPS_GET_FIELD(BPF_FIELD, OBJ_FIELD, OBJ); \ + } while (0) + + if (insn > insn_buf) + return insn - insn_buf; + + switch (si->off) { + case offsetof(struct bpf_sock_ops, op): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + op), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, op)); + break; + + case offsetof(struct bpf_sock_ops, replylong[0]) ... + offsetof(struct bpf_sock_ops, replylong[3]): + BUILD_BUG_ON(sizeof_field(struct bpf_sock_ops, reply) != + sizeof_field(struct bpf_sock_ops_kern, reply)); + BUILD_BUG_ON(sizeof_field(struct bpf_sock_ops, replylong) != + sizeof_field(struct bpf_sock_ops_kern, replylong)); + off = si->off; + off -= offsetof(struct bpf_sock_ops, replylong[0]); + off += offsetof(struct bpf_sock_ops_kern, replylong[0]); + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_W, si->dst_reg, si->src_reg, + off); + else + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + off); + break; + + case offsetof(struct bpf_sock_ops, family): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_family) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_family)); + break; + + case offsetof(struct bpf_sock_ops, remote_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_daddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_daddr)); + break; + + case offsetof(struct bpf_sock_ops, local_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_rcv_saddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_rcv_saddr)); + break; + + case offsetof(struct bpf_sock_ops, remote_ip6[0]) ... 
+ offsetof(struct bpf_sock_ops, remote_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_daddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct bpf_sock_ops, remote_ip6[0]); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_daddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct bpf_sock_ops, local_ip6[0]) ... + offsetof(struct bpf_sock_ops, local_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct bpf_sock_ops, local_ip6[0]); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct bpf_sock_ops, remote_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_dport) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_dport)); +#ifndef __BIG_ENDIAN_BITFIELD + *insn++ = BPF_ALU32_IMM(BPF_LSH, si->dst_reg, 16); +#endif + break; + + case offsetof(struct bpf_sock_ops, local_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_num) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_num)); + break; + + case offsetof(struct bpf_sock_ops, is_fullsock): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, + is_fullsock), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + is_fullsock)); + break; + + case offsetof(struct bpf_sock_ops, state): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_state) != 1); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_B, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_state)); + break; + + case offsetof(struct bpf_sock_ops, rtt_min): + BUILD_BUG_ON(sizeof_field(struct tcp_sock, rtt_min) != + sizeof(struct minmax)); + BUILD_BUG_ON(sizeof(struct minmax) < + sizeof(struct minmax_sample)); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct bpf_sock_ops_kern, sk), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct tcp_sock, rtt_min) + + sizeof_field(struct minmax_sample, t)); + break; + + case offsetof(struct bpf_sock_ops, bpf_sock_ops_cb_flags): + SOCK_OPS_GET_FIELD(bpf_sock_ops_cb_flags, bpf_sock_ops_cb_flags, + struct tcp_sock); + break; + + case offsetof(struct bpf_sock_ops, sk_txhash): + SOCK_OPS_GET_OR_SET_FIELD(sk_txhash, sk_txhash, + struct sock, type); + break; + case offsetof(struct bpf_sock_ops, snd_cwnd): + SOCK_OPS_GET_TCP_SOCK_FIELD(snd_cwnd); + break; + case 
offsetof(struct bpf_sock_ops, srtt_us): + SOCK_OPS_GET_TCP_SOCK_FIELD(srtt_us); + break; + case offsetof(struct bpf_sock_ops, snd_ssthresh): + SOCK_OPS_GET_TCP_SOCK_FIELD(snd_ssthresh); + break; + case offsetof(struct bpf_sock_ops, rcv_nxt): + SOCK_OPS_GET_TCP_SOCK_FIELD(rcv_nxt); + break; + case offsetof(struct bpf_sock_ops, snd_nxt): + SOCK_OPS_GET_TCP_SOCK_FIELD(snd_nxt); + break; + case offsetof(struct bpf_sock_ops, snd_una): + SOCK_OPS_GET_TCP_SOCK_FIELD(snd_una); + break; + case offsetof(struct bpf_sock_ops, mss_cache): + SOCK_OPS_GET_TCP_SOCK_FIELD(mss_cache); + break; + case offsetof(struct bpf_sock_ops, ecn_flags): + SOCK_OPS_GET_TCP_SOCK_FIELD(ecn_flags); + break; + case offsetof(struct bpf_sock_ops, rate_delivered): + SOCK_OPS_GET_TCP_SOCK_FIELD(rate_delivered); + break; + case offsetof(struct bpf_sock_ops, rate_interval_us): + SOCK_OPS_GET_TCP_SOCK_FIELD(rate_interval_us); + break; + case offsetof(struct bpf_sock_ops, packets_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(packets_out); + break; + case offsetof(struct bpf_sock_ops, retrans_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(retrans_out); + break; + case offsetof(struct bpf_sock_ops, total_retrans): + SOCK_OPS_GET_TCP_SOCK_FIELD(total_retrans); + break; + case offsetof(struct bpf_sock_ops, segs_in): + SOCK_OPS_GET_TCP_SOCK_FIELD(segs_in); + break; + case offsetof(struct bpf_sock_ops, data_segs_in): + SOCK_OPS_GET_TCP_SOCK_FIELD(data_segs_in); + break; + case offsetof(struct bpf_sock_ops, segs_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(segs_out); + break; + case offsetof(struct bpf_sock_ops, data_segs_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(data_segs_out); + break; + case offsetof(struct bpf_sock_ops, lost_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(lost_out); + break; + case offsetof(struct bpf_sock_ops, sacked_out): + SOCK_OPS_GET_TCP_SOCK_FIELD(sacked_out); + break; + case offsetof(struct bpf_sock_ops, bytes_received): + SOCK_OPS_GET_TCP_SOCK_FIELD(bytes_received); + break; + case offsetof(struct bpf_sock_ops, bytes_acked): + SOCK_OPS_GET_TCP_SOCK_FIELD(bytes_acked); + break; + case offsetof(struct bpf_sock_ops, sk): + SOCK_OPS_GET_SK(); + break; + case offsetof(struct bpf_sock_ops, skb_data_end): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + skb_data_end), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + skb_data_end)); + break; + case offsetof(struct bpf_sock_ops, skb_data): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + skb), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + skb)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, data), + si->dst_reg, si->dst_reg, + offsetof(struct sk_buff, data)); + break; + case offsetof(struct bpf_sock_ops, skb_len): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + skb), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + skb)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, len), + si->dst_reg, si->dst_reg, + offsetof(struct sk_buff, len)); + break; + case offsetof(struct bpf_sock_ops, skb_tcp_flags): + off = offsetof(struct sk_buff, cb); + off += offsetof(struct tcp_skb_cb, tcp_flags); + *target_size = sizeof_field(struct tcp_skb_cb, tcp_flags); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct bpf_sock_ops_kern, + skb), + si->dst_reg, si->src_reg, + offsetof(struct bpf_sock_ops_kern, + skb)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = 
BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct tcp_skb_cb, + tcp_flags), + si->dst_reg, si->dst_reg, off); + break; + } + return insn - insn_buf; +} + +/* data_end = skb->data + skb_headlen() */ +static struct bpf_insn *bpf_convert_data_end_access(const struct bpf_insn *si, + struct bpf_insn *insn) +{ + int reg; + int temp_reg_off = offsetof(struct sk_buff, cb) + + offsetof(struct sk_skb_cb, temp_reg); + + if (si->src_reg == si->dst_reg) { + /* We need an extra register, choose and save a register. */ + reg = BPF_REG_9; + if (si->src_reg == reg || si->dst_reg == reg) + reg--; + if (si->src_reg == reg || si->dst_reg == reg) + reg--; + *insn++ = BPF_STX_MEM(BPF_DW, si->src_reg, reg, temp_reg_off); + } else { + reg = si->dst_reg; + } + + /* reg = skb->data */ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, data), + reg, si->src_reg, + offsetof(struct sk_buff, data)); + /* AX = skb->len */ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, len), + BPF_REG_AX, si->src_reg, + offsetof(struct sk_buff, len)); + /* reg = skb->data + skb->len */ + *insn++ = BPF_ALU64_REG(BPF_ADD, reg, BPF_REG_AX); + /* AX = skb->data_len */ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, data_len), + BPF_REG_AX, si->src_reg, + offsetof(struct sk_buff, data_len)); + + /* reg = skb->data + skb->len - skb->data_len */ + *insn++ = BPF_ALU64_REG(BPF_SUB, reg, BPF_REG_AX); + + if (si->src_reg == si->dst_reg) { + /* Restore the saved register */ + *insn++ = BPF_MOV64_REG(BPF_REG_AX, si->src_reg); + *insn++ = BPF_MOV64_REG(si->dst_reg, reg); + *insn++ = BPF_LDX_MEM(BPF_DW, reg, BPF_REG_AX, temp_reg_off); + } + + return insn; +} + +static u32 sk_skb_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + int off; + + switch (si->off) { + case offsetof(struct __sk_buff, data_end): + insn = bpf_convert_data_end_access(si, insn); + break; + case offsetof(struct __sk_buff, cb[0]) ... 
+ offsetofend(struct __sk_buff, cb[4]) - 1: + BUILD_BUG_ON(sizeof_field(struct sk_skb_cb, data) < 20); + BUILD_BUG_ON((offsetof(struct sk_buff, cb) + + offsetof(struct sk_skb_cb, data)) % + sizeof(__u64)); + + prog->cb_access = 1; + off = si->off; + off -= offsetof(struct __sk_buff, cb[0]); + off += offsetof(struct sk_buff, cb); + off += offsetof(struct sk_skb_cb, data); + if (type == BPF_WRITE) + *insn++ = BPF_STX_MEM(BPF_SIZE(si->code), si->dst_reg, + si->src_reg, off); + else + *insn++ = BPF_LDX_MEM(BPF_SIZE(si->code), si->dst_reg, + si->src_reg, off); + break; + + + default: + return bpf_convert_ctx_access(type, si, insn_buf, prog, + target_size); + } + + return insn - insn_buf; +} + +static u32 sk_msg_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; +#if IS_ENABLED(CONFIG_IPV6) + int off; +#endif + + /* convert ctx uses the fact sg element is first in struct */ + BUILD_BUG_ON(offsetof(struct sk_msg, sg) != 0); + + switch (si->off) { + case offsetof(struct sk_msg_md, data): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, data)); + break; + case offsetof(struct sk_msg_md, data_end): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, data_end), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, data_end)); + break; + case offsetof(struct sk_msg_md, family): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_family) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_family)); + break; + + case offsetof(struct sk_msg_md, remote_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_daddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_daddr)); + break; + + case offsetof(struct sk_msg_md, local_ip4): + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_rcv_saddr) != 4); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_rcv_saddr)); + break; + + case offsetof(struct sk_msg_md, remote_ip6[0]) ... + offsetof(struct sk_msg_md, remote_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_daddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct sk_msg_md, remote_ip6[0]); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_daddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct sk_msg_md, local_ip6[0]) ... 
+ offsetof(struct sk_msg_md, local_ip6[3]): +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(sizeof_field(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) != 4); + + off = si->off; + off -= offsetof(struct sk_msg_md, local_ip6[0]); + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, + skc_v6_rcv_saddr.s6_addr32[0]) + + off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + + case offsetof(struct sk_msg_md, remote_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_dport) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_dport)); +#ifndef __BIG_ENDIAN_BITFIELD + *insn++ = BPF_ALU32_IMM(BPF_LSH, si->dst_reg, 16); +#endif + break; + + case offsetof(struct sk_msg_md, local_port): + BUILD_BUG_ON(sizeof_field(struct sock_common, skc_num) != 2); + + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF( + struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->dst_reg, + offsetof(struct sock_common, skc_num)); + break; + + case offsetof(struct sk_msg_md, size): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_sg, size), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg_sg, size)); + break; + + case offsetof(struct sk_msg_md, sk): + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg, sk), + si->dst_reg, si->src_reg, + offsetof(struct sk_msg, sk)); + break; + } + + return insn - insn_buf; +} + +const struct bpf_verifier_ops sk_filter_verifier_ops = { + .get_func_proto = sk_filter_func_proto, + .is_valid_access = sk_filter_is_valid_access, + .convert_ctx_access = bpf_convert_ctx_access, + .gen_ld_abs = bpf_gen_ld_abs, +}; + +const struct bpf_prog_ops sk_filter_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops tc_cls_act_verifier_ops = { + .get_func_proto = tc_cls_act_func_proto, + .is_valid_access = tc_cls_act_is_valid_access, + .convert_ctx_access = tc_cls_act_convert_ctx_access, + .gen_prologue = tc_cls_act_prologue, + .gen_ld_abs = bpf_gen_ld_abs, + .btf_struct_access = tc_cls_act_btf_struct_access, +}; + +const struct bpf_prog_ops tc_cls_act_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops xdp_verifier_ops = { + .get_func_proto = xdp_func_proto, + .is_valid_access = xdp_is_valid_access, + .convert_ctx_access = xdp_convert_ctx_access, + .gen_prologue = bpf_noop_prologue, + .btf_struct_access = xdp_btf_struct_access, +}; + +const struct bpf_prog_ops xdp_prog_ops = { + .test_run = bpf_prog_test_run_xdp, +}; + +const struct bpf_verifier_ops cg_skb_verifier_ops = { + .get_func_proto = cg_skb_func_proto, + .is_valid_access = cg_skb_is_valid_access, + .convert_ctx_access = bpf_convert_ctx_access, +}; + +const struct bpf_prog_ops cg_skb_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops lwt_in_verifier_ops = { + .get_func_proto = lwt_in_func_proto, + .is_valid_access = lwt_is_valid_access, + .convert_ctx_access = bpf_convert_ctx_access, +}; + +const struct bpf_prog_ops lwt_in_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops lwt_out_verifier_ops = { + .get_func_proto = lwt_out_func_proto, + .is_valid_access = lwt_is_valid_access, + 
.convert_ctx_access = bpf_convert_ctx_access, +}; + +const struct bpf_prog_ops lwt_out_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops lwt_xmit_verifier_ops = { + .get_func_proto = lwt_xmit_func_proto, + .is_valid_access = lwt_is_valid_access, + .convert_ctx_access = bpf_convert_ctx_access, + .gen_prologue = tc_cls_act_prologue, +}; + +const struct bpf_prog_ops lwt_xmit_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops lwt_seg6local_verifier_ops = { + .get_func_proto = lwt_seg6local_func_proto, + .is_valid_access = lwt_is_valid_access, + .convert_ctx_access = bpf_convert_ctx_access, +}; + +const struct bpf_prog_ops lwt_seg6local_prog_ops = { + .test_run = bpf_prog_test_run_skb, +}; + +const struct bpf_verifier_ops cg_sock_verifier_ops = { + .get_func_proto = sock_filter_func_proto, + .is_valid_access = sock_filter_is_valid_access, + .convert_ctx_access = bpf_sock_convert_ctx_access, +}; + +const struct bpf_prog_ops cg_sock_prog_ops = { +}; + +const struct bpf_verifier_ops cg_sock_addr_verifier_ops = { + .get_func_proto = sock_addr_func_proto, + .is_valid_access = sock_addr_is_valid_access, + .convert_ctx_access = sock_addr_convert_ctx_access, +}; + +const struct bpf_prog_ops cg_sock_addr_prog_ops = { +}; + +const struct bpf_verifier_ops sock_ops_verifier_ops = { + .get_func_proto = sock_ops_func_proto, + .is_valid_access = sock_ops_is_valid_access, + .convert_ctx_access = sock_ops_convert_ctx_access, +}; + +const struct bpf_prog_ops sock_ops_prog_ops = { +}; + +const struct bpf_verifier_ops sk_skb_verifier_ops = { + .get_func_proto = sk_skb_func_proto, + .is_valid_access = sk_skb_is_valid_access, + .convert_ctx_access = sk_skb_convert_ctx_access, + .gen_prologue = sk_skb_prologue, +}; + +const struct bpf_prog_ops sk_skb_prog_ops = { +}; + +const struct bpf_verifier_ops sk_msg_verifier_ops = { + .get_func_proto = sk_msg_func_proto, + .is_valid_access = sk_msg_is_valid_access, + .convert_ctx_access = sk_msg_convert_ctx_access, + .gen_prologue = bpf_noop_prologue, +}; + +const struct bpf_prog_ops sk_msg_prog_ops = { +}; + +const struct bpf_verifier_ops flow_dissector_verifier_ops = { + .get_func_proto = flow_dissector_func_proto, + .is_valid_access = flow_dissector_is_valid_access, + .convert_ctx_access = flow_dissector_convert_ctx_access, +}; + +const struct bpf_prog_ops flow_dissector_prog_ops = { + .test_run = bpf_prog_test_run_flow_dissector, +}; + +int sk_detach_filter(struct sock *sk) +{ + int ret = -ENOENT; + struct sk_filter *filter; + + if (sock_flag(sk, SOCK_FILTER_LOCKED)) + return -EPERM; + + filter = rcu_dereference_protected(sk->sk_filter, + lockdep_sock_is_held(sk)); + if (filter) { + RCU_INIT_POINTER(sk->sk_filter, NULL); + sk_filter_uncharge(sk, filter); + ret = 0; + } + + return ret; +} +EXPORT_SYMBOL_GPL(sk_detach_filter); + +int sk_get_filter(struct sock *sk, sockptr_t optval, unsigned int len) +{ + struct sock_fprog_kern *fprog; + struct sk_filter *filter; + int ret = 0; + + sockopt_lock_sock(sk); + filter = rcu_dereference_protected(sk->sk_filter, + lockdep_sock_is_held(sk)); + if (!filter) + goto out; + + /* We're copying the filter that has been originally attached, + * so no conversion/decode needed anymore. eBPF programs that + * have no original program cannot be dumped through this. + */ + ret = -EACCES; + fprog = filter->prog->orig_prog; + if (!fprog) + goto out; + + ret = fprog->len; + if (!len) + /* User space only enquires number of filter blocks. 
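 *
 * (Editor's sketch, not part of this patch: from user space the probe
 * described above looks roughly like this, error handling omitted.
 * Note that the second optlen is a number of blocks, not bytes.)
 *
 *	socklen_t optlen = 0;
 *	getsockopt(fd, SOL_SOCKET, SO_GET_FILTER, NULL, &optlen);
 *	struct sock_filter *insns = calloc(optlen, sizeof(*insns));
 *	getsockopt(fd, SOL_SOCKET, SO_GET_FILTER, insns, &optlen);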
*/ + goto out; + + ret = -EINVAL; + if (len < fprog->len) + goto out; + + ret = -EFAULT; + if (copy_to_sockptr(optval, fprog->filter, bpf_classic_proglen(fprog))) + goto out; + + /* Instead of bytes, the API requests to return the number + * of filter blocks. + */ + ret = fprog->len; +out: + sockopt_release_sock(sk); + return ret; +} + +#ifdef CONFIG_INET +static void bpf_init_reuseport_kern(struct sk_reuseport_kern *reuse_kern, + struct sock_reuseport *reuse, + struct sock *sk, struct sk_buff *skb, + struct sock *migrating_sk, + u32 hash) +{ + reuse_kern->skb = skb; + reuse_kern->sk = sk; + reuse_kern->selected_sk = NULL; + reuse_kern->migrating_sk = migrating_sk; + reuse_kern->data_end = skb->data + skb_headlen(skb); + reuse_kern->hash = hash; + reuse_kern->reuseport_id = reuse->reuseport_id; + reuse_kern->bind_inany = reuse->bind_inany; +} + +struct sock *bpf_run_sk_reuseport(struct sock_reuseport *reuse, struct sock *sk, + struct bpf_prog *prog, struct sk_buff *skb, + struct sock *migrating_sk, + u32 hash) +{ + struct sk_reuseport_kern reuse_kern; + enum sk_action action; + + bpf_init_reuseport_kern(&reuse_kern, reuse, sk, skb, migrating_sk, hash); + action = bpf_prog_run(prog, &reuse_kern); + + if (action == SK_PASS) + return reuse_kern.selected_sk; + else + return ERR_PTR(-ECONNREFUSED); +} + +BPF_CALL_4(sk_select_reuseport, struct sk_reuseport_kern *, reuse_kern, + struct bpf_map *, map, void *, key, u32, flags) +{ + bool is_sockarray = map->map_type == BPF_MAP_TYPE_REUSEPORT_SOCKARRAY; + struct sock_reuseport *reuse; + struct sock *selected_sk; + + selected_sk = map->ops->map_lookup_elem(map, key); + if (!selected_sk) + return -ENOENT; + + reuse = rcu_dereference(selected_sk->sk_reuseport_cb); + if (!reuse) { + /* Lookup in sock_map can return TCP ESTABLISHED sockets. */ + if (sk_is_refcounted(selected_sk)) + sock_put(selected_sk); + + /* reuseport_array has only sk with non NULL sk_reuseport_cb. + * The only (!reuse) case here is - the sk has already been + * unhashed (e.g. by close()), so treat it as -ENOENT. + * + * Other maps (e.g. sock_map) do not provide this guarantee and + * the sk may never be in the reuseport group to begin with. + */ + return is_sockarray ? -ENOENT : -EINVAL; + } + + if (unlikely(reuse->reuseport_id != reuse_kern->reuseport_id)) { + struct sock *sk = reuse_kern->sk; + + if (sk->sk_protocol != selected_sk->sk_protocol) + return -EPROTOTYPE; + else if (sk->sk_family != selected_sk->sk_family) + return -EAFNOSUPPORT; + + /* Catch all. Likely bound to a different sockaddr. 
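 *
 * (Editor's sketch, not part of this patch: a typical caller of this
 * helper is an SEC("sk_reuseport") program selecting from a
 * REUSEPORT_SOCKARRAY, e.g. with the usual libbpf includes:)
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_REUSEPORT_SOCKARRAY);
 *		__uint(max_entries, 1);
 *		__type(key, __u32);
 *		__type(value, __u64);
 *	} sock_map SEC(".maps");
 *
 *	SEC("sk_reuseport")
 *	int select_sock(struct sk_reuseport_md *reuse_md)
 *	{
 *		__u32 key = 0;
 *
 *		if (bpf_sk_select_reuseport(reuse_md, &sock_map, &key, 0))
 *			return SK_DROP;
 *		return SK_PASS;
 *	}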
*/ + return -EBADFD; + } + + reuse_kern->selected_sk = selected_sk; + + return 0; +} + +static const struct bpf_func_proto sk_select_reuseport_proto = { + .func = sk_select_reuseport, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(sk_reuseport_load_bytes, + const struct sk_reuseport_kern *, reuse_kern, u32, offset, + void *, to, u32, len) +{ + return ____bpf_skb_load_bytes(reuse_kern->skb, offset, to, len); +} + +static const struct bpf_func_proto sk_reuseport_load_bytes_proto = { + .func = sk_reuseport_load_bytes, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, +}; + +BPF_CALL_5(sk_reuseport_load_bytes_relative, + const struct sk_reuseport_kern *, reuse_kern, u32, offset, + void *, to, u32, len, u32, start_header) +{ + return ____bpf_skb_load_bytes_relative(reuse_kern->skb, offset, to, + len, start_header); +} + +static const struct bpf_func_proto sk_reuseport_load_bytes_relative_proto = { + .func = sk_reuseport_load_bytes_relative, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_ANYTHING, + .arg3_type = ARG_PTR_TO_UNINIT_MEM, + .arg4_type = ARG_CONST_SIZE, + .arg5_type = ARG_ANYTHING, +}; + +static const struct bpf_func_proto * +sk_reuseport_func_proto(enum bpf_func_id func_id, + const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_sk_select_reuseport: + return &sk_select_reuseport_proto; + case BPF_FUNC_skb_load_bytes: + return &sk_reuseport_load_bytes_proto; + case BPF_FUNC_skb_load_bytes_relative: + return &sk_reuseport_load_bytes_relative_proto; + case BPF_FUNC_get_socket_cookie: + return &bpf_get_socket_ptr_cookie_proto; + case BPF_FUNC_ktime_get_coarse_ns: + return &bpf_ktime_get_coarse_ns_proto; + default: + return bpf_base_func_proto(func_id); + } +} + +static bool +sk_reuseport_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + const u32 size_default = sizeof(__u32); + + if (off < 0 || off >= sizeof(struct sk_reuseport_md) || + off % size || type != BPF_READ) + return false; + + switch (off) { + case offsetof(struct sk_reuseport_md, data): + info->reg_type = PTR_TO_PACKET; + return size == sizeof(__u64); + + case offsetof(struct sk_reuseport_md, data_end): + info->reg_type = PTR_TO_PACKET_END; + return size == sizeof(__u64); + + case offsetof(struct sk_reuseport_md, hash): + return size == size_default; + + case offsetof(struct sk_reuseport_md, sk): + info->reg_type = PTR_TO_SOCKET; + return size == sizeof(__u64); + + case offsetof(struct sk_reuseport_md, migrating_sk): + info->reg_type = PTR_TO_SOCK_COMMON_OR_NULL; + return size == sizeof(__u64); + + /* Fields that allow narrowing */ + case bpf_ctx_range(struct sk_reuseport_md, eth_protocol): + if (size < sizeof_field(struct sk_buff, protocol)) + return false; + fallthrough; + case bpf_ctx_range(struct sk_reuseport_md, ip_protocol): + case bpf_ctx_range(struct sk_reuseport_md, bind_inany): + case bpf_ctx_range(struct sk_reuseport_md, len): + bpf_ctx_record_field_size(info, size_default); + return bpf_ctx_narrow_access_ok(off, size, size_default); + + default: + return false; + } +} + +#define SK_REUSEPORT_LOAD_FIELD(F) ({ \ + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_reuseport_kern, F), \ + 
si->dst_reg, si->src_reg, \ + bpf_target_off(struct sk_reuseport_kern, F, \ + sizeof_field(struct sk_reuseport_kern, F), \ + target_size)); \ + }) + +#define SK_REUSEPORT_LOAD_SKB_FIELD(SKB_FIELD) \ + SOCK_ADDR_LOAD_NESTED_FIELD(struct sk_reuseport_kern, \ + struct sk_buff, \ + skb, \ + SKB_FIELD) + +#define SK_REUSEPORT_LOAD_SK_FIELD(SK_FIELD) \ + SOCK_ADDR_LOAD_NESTED_FIELD(struct sk_reuseport_kern, \ + struct sock, \ + sk, \ + SK_FIELD) + +static u32 sk_reuseport_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, + u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct sk_reuseport_md, data): + SK_REUSEPORT_LOAD_SKB_FIELD(data); + break; + + case offsetof(struct sk_reuseport_md, len): + SK_REUSEPORT_LOAD_SKB_FIELD(len); + break; + + case offsetof(struct sk_reuseport_md, eth_protocol): + SK_REUSEPORT_LOAD_SKB_FIELD(protocol); + break; + + case offsetof(struct sk_reuseport_md, ip_protocol): + SK_REUSEPORT_LOAD_SK_FIELD(sk_protocol); + break; + + case offsetof(struct sk_reuseport_md, data_end): + SK_REUSEPORT_LOAD_FIELD(data_end); + break; + + case offsetof(struct sk_reuseport_md, hash): + SK_REUSEPORT_LOAD_FIELD(hash); + break; + + case offsetof(struct sk_reuseport_md, bind_inany): + SK_REUSEPORT_LOAD_FIELD(bind_inany); + break; + + case offsetof(struct sk_reuseport_md, sk): + SK_REUSEPORT_LOAD_FIELD(sk); + break; + + case offsetof(struct sk_reuseport_md, migrating_sk): + SK_REUSEPORT_LOAD_FIELD(migrating_sk); + break; + } + + return insn - insn_buf; +} + +const struct bpf_verifier_ops sk_reuseport_verifier_ops = { + .get_func_proto = sk_reuseport_func_proto, + .is_valid_access = sk_reuseport_is_valid_access, + .convert_ctx_access = sk_reuseport_convert_ctx_access, +}; + +const struct bpf_prog_ops sk_reuseport_prog_ops = { +}; + +DEFINE_STATIC_KEY_FALSE(bpf_sk_lookup_enabled); +EXPORT_SYMBOL(bpf_sk_lookup_enabled); + +BPF_CALL_3(bpf_sk_lookup_assign, struct bpf_sk_lookup_kern *, ctx, + struct sock *, sk, u64, flags) +{ + if (unlikely(flags & ~(BPF_SK_LOOKUP_F_REPLACE | + BPF_SK_LOOKUP_F_NO_REUSEPORT))) + return -EINVAL; + if (unlikely(sk && sk_is_refcounted(sk))) + return -ESOCKTNOSUPPORT; /* reject non-RCU freed sockets */ + if (unlikely(sk && sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN)) + return -ESOCKTNOSUPPORT; /* only accept TCP socket in LISTEN */ + if (unlikely(sk && sk_is_udp(sk) && sk->sk_state != TCP_CLOSE)) + return -ESOCKTNOSUPPORT; /* only accept UDP socket in CLOSE */ + + /* Check if socket is suitable for packet L3/L4 protocol */ + if (sk && sk->sk_protocol != ctx->protocol) + return -EPROTOTYPE; + if (sk && sk->sk_family != ctx->family && + (sk->sk_family == AF_INET || ipv6_only_sock(sk))) + return -EAFNOSUPPORT; + + if (ctx->selected_sk && !(flags & BPF_SK_LOOKUP_F_REPLACE)) + return -EEXIST; + + /* Select socket as lookup result */ + ctx->selected_sk = sk; + ctx->no_reuseport = flags & BPF_SK_LOOKUP_F_NO_REUSEPORT; + return 0; +} + +static const struct bpf_func_proto bpf_sk_lookup_assign_proto = { + .func = bpf_sk_lookup_assign, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_PTR_TO_SOCKET_OR_NULL, + .arg3_type = ARG_ANYTHING, +}; + +static const struct bpf_func_proto * +sk_lookup_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_perf_event_output: + return &bpf_event_output_data_proto; + case BPF_FUNC_sk_assign: + return 
&bpf_sk_lookup_assign_proto; + case BPF_FUNC_sk_release: + return &bpf_sk_release_proto; + default: + return bpf_sk_base_func_proto(func_id); + } +} + +static bool sk_lookup_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (off < 0 || off >= sizeof(struct bpf_sk_lookup)) + return false; + if (off % size != 0) + return false; + if (type != BPF_READ) + return false; + + switch (off) { + case offsetof(struct bpf_sk_lookup, sk): + info->reg_type = PTR_TO_SOCKET_OR_NULL; + return size == sizeof(__u64); + + case bpf_ctx_range(struct bpf_sk_lookup, family): + case bpf_ctx_range(struct bpf_sk_lookup, protocol): + case bpf_ctx_range(struct bpf_sk_lookup, remote_ip4): + case bpf_ctx_range(struct bpf_sk_lookup, local_ip4): + case bpf_ctx_range_till(struct bpf_sk_lookup, remote_ip6[0], remote_ip6[3]): + case bpf_ctx_range_till(struct bpf_sk_lookup, local_ip6[0], local_ip6[3]): + case bpf_ctx_range(struct bpf_sk_lookup, local_port): + case bpf_ctx_range(struct bpf_sk_lookup, ingress_ifindex): + bpf_ctx_record_field_size(info, sizeof(__u32)); + return bpf_ctx_narrow_access_ok(off, size, sizeof(__u32)); + + case bpf_ctx_range(struct bpf_sk_lookup, remote_port): + /* Allow 4-byte access to 2-byte field for backward compatibility */ + if (size == sizeof(__u32)) + return true; + bpf_ctx_record_field_size(info, sizeof(__be16)); + return bpf_ctx_narrow_access_ok(off, size, sizeof(__be16)); + + case offsetofend(struct bpf_sk_lookup, remote_port) ... + offsetof(struct bpf_sk_lookup, local_ip4) - 1: + /* Allow access to zero padding for backward compatibility */ + bpf_ctx_record_field_size(info, sizeof(__u16)); + return bpf_ctx_narrow_access_ok(off, size, sizeof(__u16)); + + default: + return false; + } +} + +static u32 sk_lookup_convert_ctx_access(enum bpf_access_type type, + const struct bpf_insn *si, + struct bpf_insn *insn_buf, + struct bpf_prog *prog, + u32 *target_size) +{ + struct bpf_insn *insn = insn_buf; + + switch (si->off) { + case offsetof(struct bpf_sk_lookup, sk): + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg, + offsetof(struct bpf_sk_lookup_kern, selected_sk)); + break; + + case offsetof(struct bpf_sk_lookup, family): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + family, 2, target_size)); + break; + + case offsetof(struct bpf_sk_lookup, protocol): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + protocol, 2, target_size)); + break; + + case offsetof(struct bpf_sk_lookup, remote_ip4): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + v4.saddr, 4, target_size)); + break; + + case offsetof(struct bpf_sk_lookup, local_ip4): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + v4.daddr, 4, target_size)); + break; + + case bpf_ctx_range_till(struct bpf_sk_lookup, + remote_ip6[0], remote_ip6[3]): { +#if IS_ENABLED(CONFIG_IPV6) + int off = si->off; + + off -= offsetof(struct bpf_sk_lookup, remote_ip6[0]); + off += bpf_target_off(struct in6_addr, s6_addr32[0], 4, target_size); + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg, + offsetof(struct bpf_sk_lookup_kern, v6.saddr)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + } + case 
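As a sketch (not part of this patch; map name, key and port are arbitrary placeholders): the typical user of bpf_sk_lookup_assign() above is an SEC("sk_lookup") program that steers an incoming connection to a socket taken from a map. ctx->local_port is in host byte order, as set up by the converter below:

	#include <linux/bpf.h>
	#include <bpf/bpf_helpers.h>

	struct {
		__uint(type, BPF_MAP_TYPE_SOCKMAP);
		__uint(max_entries, 1);
		__type(key, __u32);
		__type(value, __u64);
	} redir_map SEC(".maps");

	SEC("sk_lookup")
	int steer_to_proxy(struct bpf_sk_lookup *ctx)
	{
		struct bpf_sock *sk;
		__u32 key = 0;
		long err;

		if (ctx->local_port != 8080)	/* placeholder port */
			return SK_PASS;		/* fall back to normal lookup */

		sk = bpf_map_lookup_elem(&redir_map, &key);
		if (!sk)
			return SK_PASS;

		err = bpf_sk_assign(ctx, sk, 0);
		bpf_sk_release(sk);
		return err ? SK_DROP : SK_PASS;
	}

	char _license[] SEC("license") = "GPL";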
bpf_ctx_range_till(struct bpf_sk_lookup, + local_ip6[0], local_ip6[3]): { +#if IS_ENABLED(CONFIG_IPV6) + int off = si->off; + + off -= offsetof(struct bpf_sk_lookup, local_ip6[0]); + off += bpf_target_off(struct in6_addr, s6_addr32[0], 4, target_size); + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(void *), si->dst_reg, si->src_reg, + offsetof(struct bpf_sk_lookup_kern, v6.daddr)); + *insn++ = BPF_JMP_IMM(BPF_JEQ, si->dst_reg, 0, 1); + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->dst_reg, off); +#else + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); +#endif + break; + } + case offsetof(struct bpf_sk_lookup, remote_port): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + sport, 2, target_size)); + break; + + case offsetofend(struct bpf_sk_lookup, remote_port): + *target_size = 2; + *insn++ = BPF_MOV32_IMM(si->dst_reg, 0); + break; + + case offsetof(struct bpf_sk_lookup, local_port): + *insn++ = BPF_LDX_MEM(BPF_H, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + dport, 2, target_size)); + break; + + case offsetof(struct bpf_sk_lookup, ingress_ifindex): + *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, + bpf_target_off(struct bpf_sk_lookup_kern, + ingress_ifindex, 4, target_size)); + break; + } + + return insn - insn_buf; +} + +const struct bpf_prog_ops sk_lookup_prog_ops = { + .test_run = bpf_prog_test_run_sk_lookup, +}; + +const struct bpf_verifier_ops sk_lookup_verifier_ops = { + .get_func_proto = sk_lookup_func_proto, + .is_valid_access = sk_lookup_is_valid_access, + .convert_ctx_access = sk_lookup_convert_ctx_access, +}; + +#endif /* CONFIG_INET */ + +DEFINE_BPF_DISPATCHER(xdp) + +void bpf_prog_change_xdp(struct bpf_prog *prev_prog, struct bpf_prog *prog) +{ + bpf_dispatcher_change_prog(BPF_DISPATCHER_PTR(xdp), prev_prog, prog); +} + +BTF_ID_LIST_GLOBAL(btf_sock_ids, MAX_BTF_SOCK_TYPE) +#define BTF_SOCK_TYPE(name, type) BTF_ID(struct, type) +BTF_SOCK_TYPE_xxx +#undef BTF_SOCK_TYPE + +BPF_CALL_1(bpf_skc_to_tcp6_sock, struct sock *, sk) +{ + /* tcp6_sock type is not generated in dwarf and hence btf, + * trigger an explicit type generation here. + */ + BTF_TYPE_EMIT(struct tcp6_sock); + if (sk && sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP && + sk->sk_family == AF_INET6) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_tcp6_sock_proto = { + .func = bpf_skc_to_tcp6_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_TCP6], +}; + +BPF_CALL_1(bpf_skc_to_tcp_sock, struct sock *, sk) +{ + if (sk && sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_tcp_sock_proto = { + .func = bpf_skc_to_tcp_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_TCP], +}; + +BPF_CALL_1(bpf_skc_to_tcp_timewait_sock, struct sock *, sk) +{ + /* BTF types for tcp_timewait_sock and inet_timewait_sock are not + * generated if CONFIG_INET=n. Trigger an explicit generation here. 
+ */ + BTF_TYPE_EMIT(struct inet_timewait_sock); + BTF_TYPE_EMIT(struct tcp_timewait_sock); + +#ifdef CONFIG_INET + if (sk && sk->sk_prot == &tcp_prot && sk->sk_state == TCP_TIME_WAIT) + return (unsigned long)sk; +#endif + +#if IS_BUILTIN(CONFIG_IPV6) + if (sk && sk->sk_prot == &tcpv6_prot && sk->sk_state == TCP_TIME_WAIT) + return (unsigned long)sk; +#endif + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_tcp_timewait_sock_proto = { + .func = bpf_skc_to_tcp_timewait_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_TCP_TW], +}; + +BPF_CALL_1(bpf_skc_to_tcp_request_sock, struct sock *, sk) +{ +#ifdef CONFIG_INET + if (sk && sk->sk_prot == &tcp_prot && sk->sk_state == TCP_NEW_SYN_RECV) + return (unsigned long)sk; +#endif + +#if IS_BUILTIN(CONFIG_IPV6) + if (sk && sk->sk_prot == &tcpv6_prot && sk->sk_state == TCP_NEW_SYN_RECV) + return (unsigned long)sk; +#endif + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_tcp_request_sock_proto = { + .func = bpf_skc_to_tcp_request_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_TCP_REQ], +}; + +BPF_CALL_1(bpf_skc_to_udp6_sock, struct sock *, sk) +{ + /* udp6_sock type is not generated in dwarf and hence btf, + * trigger an explicit type generation here. + */ + BTF_TYPE_EMIT(struct udp6_sock); + if (sk && sk_fullsock(sk) && sk->sk_protocol == IPPROTO_UDP && + sk->sk_type == SOCK_DGRAM && sk->sk_family == AF_INET6) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_udp6_sock_proto = { + .func = bpf_skc_to_udp6_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_UDP6], +}; + +BPF_CALL_1(bpf_skc_to_unix_sock, struct sock *, sk) +{ + /* unix_sock type is not generated in dwarf and hence btf, + * trigger an explicit type generation here. 
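 *
 * (Editor's sketch, not part of this patch: the skc_to_* casts above and
 * below are typically used to go from a plain socket pointer to a
 * BTF-typed one, e.g. from a tc classifier built against vmlinux.h with
 * the usual libbpf and <linux/pkt_cls.h> includes; loading such a
 * program requires CAP_PERFMON.)
 *
 *	SEC("tc")
 *	int inspect(struct __sk_buff *skb)
 *	{
 *		struct bpf_sock *sk = skb->sk;
 *		struct tcp_sock *tp;
 *
 *		if (!sk)
 *			return TC_ACT_OK;
 *		tp = bpf_skc_to_tcp_sock(sk);
 *		if (tp)
 *			bpf_printk("snd_cwnd %u", tp->snd_cwnd);
 *		return TC_ACT_OK;
 *	}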
+ */ + BTF_TYPE_EMIT(struct unix_sock); + if (sk && sk_fullsock(sk) && sk->sk_family == AF_UNIX) + return (unsigned long)sk; + + return (unsigned long)NULL; +} + +const struct bpf_func_proto bpf_skc_to_unix_sock_proto = { + .func = bpf_skc_to_unix_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_BTF_ID_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_UNIX], +}; + +BPF_CALL_1(bpf_skc_to_mptcp_sock, struct sock *, sk) +{ + BTF_TYPE_EMIT(struct mptcp_sock); + return (unsigned long)bpf_mptcp_sock_from_subflow(sk); +} + +const struct bpf_func_proto bpf_skc_to_mptcp_sock_proto = { + .func = bpf_skc_to_mptcp_sock, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .arg1_type = ARG_PTR_TO_SOCK_COMMON, + .ret_btf_id = &btf_sock_ids[BTF_SOCK_TYPE_MPTCP], +}; + +BPF_CALL_1(bpf_sock_from_file, struct file *, file) +{ + return (unsigned long)sock_from_file(file); +} + +BTF_ID_LIST(bpf_sock_from_file_btf_ids) +BTF_ID(struct, socket) +BTF_ID(struct, file) + +const struct bpf_func_proto bpf_sock_from_file_proto = { + .func = bpf_sock_from_file, + .gpl_only = false, + .ret_type = RET_PTR_TO_BTF_ID_OR_NULL, + .ret_btf_id = &bpf_sock_from_file_btf_ids[0], + .arg1_type = ARG_PTR_TO_BTF_ID, + .arg1_btf_id = &bpf_sock_from_file_btf_ids[1], +}; + +static const struct bpf_func_proto * +bpf_sk_base_func_proto(enum bpf_func_id func_id) +{ + const struct bpf_func_proto *func; + + switch (func_id) { + case BPF_FUNC_skc_to_tcp6_sock: + func = &bpf_skc_to_tcp6_sock_proto; + break; + case BPF_FUNC_skc_to_tcp_sock: + func = &bpf_skc_to_tcp_sock_proto; + break; + case BPF_FUNC_skc_to_tcp_timewait_sock: + func = &bpf_skc_to_tcp_timewait_sock_proto; + break; + case BPF_FUNC_skc_to_tcp_request_sock: + func = &bpf_skc_to_tcp_request_sock_proto; + break; + case BPF_FUNC_skc_to_udp6_sock: + func = &bpf_skc_to_udp6_sock_proto; + break; + case BPF_FUNC_skc_to_unix_sock: + func = &bpf_skc_to_unix_sock_proto; + break; + case BPF_FUNC_skc_to_mptcp_sock: + func = &bpf_skc_to_mptcp_sock_proto; + break; + case BPF_FUNC_ktime_get_coarse_ns: + return &bpf_ktime_get_coarse_ns_proto; + default: + return bpf_base_func_proto(func_id); + } + + if (!perfmon_capable()) + return NULL; + + return func; +} diff --git a/net/core/flow_dissector.c b/net/core/flow_dissector.c new file mode 100644 index 000000000..0c85c8a9e --- /dev/null +++ b/net/core/flow_dissector.c @@ -0,0 +1,1960 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#include +#endif +#include + +static void dissector_set_key(struct flow_dissector *flow_dissector, + enum flow_dissector_key_id key_id) +{ + flow_dissector->used_keys |= (1 << key_id); +} + +void skb_flow_dissector_init(struct flow_dissector *flow_dissector, + const struct flow_dissector_key *key, + unsigned int key_count) +{ + unsigned int i; + + memset(flow_dissector, 0, sizeof(*flow_dissector)); + + for (i = 0; i < key_count; i++, key++) { + /* User should make sure that every key target offset is within + * boundaries of unsigned short. 
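 *
 * (Editor's sketch, not part of this patch: a caller typically describes
 * where each dissected key should land inside its own container struct,
 * much like cls_flower does; "struct my_keys" is a made-up container.)
 *
 *	static const struct flow_dissector_key keys[] = {
 *		{ .key_id = FLOW_DISSECTOR_KEY_CONTROL,
 *		  .offset = offsetof(struct my_keys, control) },
 *		{ .key_id = FLOW_DISSECTOR_KEY_BASIC,
 *		  .offset = offsetof(struct my_keys, basic) },
 *		{ .key_id = FLOW_DISSECTOR_KEY_PORTS,
 *		  .offset = offsetof(struct my_keys, ports) },
 *	};
 *
 *	skb_flow_dissector_init(&dissector, keys, ARRAY_SIZE(keys));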
+ */ + BUG_ON(key->offset > USHRT_MAX); + BUG_ON(dissector_uses_key(flow_dissector, + key->key_id)); + + dissector_set_key(flow_dissector, key->key_id); + flow_dissector->offset[key->key_id] = key->offset; + } + + /* Ensure that the dissector always includes control and basic key. + * That way we are able to avoid handling lack of these in fast path. + */ + BUG_ON(!dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_CONTROL)); + BUG_ON(!dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_BASIC)); +} +EXPORT_SYMBOL(skb_flow_dissector_init); + +#ifdef CONFIG_BPF_SYSCALL +int flow_dissector_bpf_prog_attach_check(struct net *net, + struct bpf_prog *prog) +{ + enum netns_bpf_attach_type type = NETNS_BPF_FLOW_DISSECTOR; + + if (net == &init_net) { + /* BPF flow dissector in the root namespace overrides + * any per-net-namespace one. When attaching to root, + * make sure we don't have any BPF program attached + * to the non-root namespaces. + */ + struct net *ns; + + for_each_net(ns) { + if (ns == &init_net) + continue; + if (rcu_access_pointer(ns->bpf.run_array[type])) + return -EEXIST; + } + } else { + /* Make sure root flow dissector is not attached + * when attaching to the non-root namespace. + */ + if (rcu_access_pointer(init_net.bpf.run_array[type])) + return -EEXIST; + } + + return 0; +} +#endif /* CONFIG_BPF_SYSCALL */ + +/** + * __skb_flow_get_ports - extract the upper layer ports and return them + * @skb: sk_buff to extract the ports from + * @thoff: transport header offset + * @ip_proto: protocol for which to get port offset + * @data: raw buffer pointer to the packet, if NULL use skb->data + * @hlen: packet header length, if @data is NULL use skb_headlen(skb) + * + * The function will try to retrieve the ports at offset thoff + poff where poff + * is the protocol port offset returned from proto_ports_offset + */ +__be32 __skb_flow_get_ports(const struct sk_buff *skb, int thoff, u8 ip_proto, + const void *data, int hlen) +{ + int poff = proto_ports_offset(ip_proto); + + if (!data) { + data = skb->data; + hlen = skb_headlen(skb); + } + + if (poff >= 0) { + __be32 *ports, _ports; + + ports = __skb_header_pointer(skb, thoff + poff, + sizeof(_ports), data, hlen, &_ports); + if (ports) + return *ports; + } + + return 0; +} +EXPORT_SYMBOL(__skb_flow_get_ports); + +static bool icmp_has_id(u8 type) +{ + switch (type) { + case ICMP_ECHO: + case ICMP_ECHOREPLY: + case ICMP_TIMESTAMP: + case ICMP_TIMESTAMPREPLY: + case ICMPV6_ECHO_REQUEST: + case ICMPV6_ECHO_REPLY: + return true; + } + + return false; +} + +/** + * skb_flow_get_icmp_tci - extract ICMP(6) Type, Code and Identifier fields + * @skb: sk_buff to extract from + * @key_icmp: struct flow_dissector_key_icmp to fill + * @data: raw buffer pointer to the packet + * @thoff: offset to extract at + * @hlen: packet header length + */ +void skb_flow_get_icmp_tci(const struct sk_buff *skb, + struct flow_dissector_key_icmp *key_icmp, + const void *data, int thoff, int hlen) +{ + struct icmphdr *ih, _ih; + + ih = __skb_header_pointer(skb, thoff, sizeof(_ih), data, hlen, &_ih); + if (!ih) + return; + + key_icmp->type = ih->type; + key_icmp->code = ih->code; + + /* As we use 0 to signal that the Id field is not present, + * avoid confusion with packets without such field + */ + if (icmp_has_id(ih->type)) + key_icmp->id = ih->un.echo.id ? ntohs(ih->un.echo.id) : 1; + else + key_icmp->id = 0; +} +EXPORT_SYMBOL(skb_flow_get_icmp_tci); + +/* If FLOW_DISSECTOR_KEY_ICMP is set, dissect an ICMP packet + * using skb_flow_get_icmp_tci(). 
+ */ +static void __skb_flow_dissect_icmp(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + int thoff, int hlen) +{ + struct flow_dissector_key_icmp *key_icmp; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ICMP)) + return; + + key_icmp = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ICMP, + target_container); + + skb_flow_get_icmp_tci(skb, key_icmp, data, thoff, hlen); +} + +static void __skb_flow_dissect_l2tpv3(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + int nhoff, int hlen) +{ + struct flow_dissector_key_l2tpv3 *key_l2tpv3; + struct { + __be32 session_id; + } *hdr, _hdr; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_L2TPV3)) + return; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), data, hlen, &_hdr); + if (!hdr) + return; + + key_l2tpv3 = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_L2TPV3, + target_container); + + key_l2tpv3->session_id = hdr->session_id; +} + +void skb_flow_dissect_meta(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container) +{ + struct flow_dissector_key_meta *meta; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_META)) + return; + + meta = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_META, + target_container); + meta->ingress_ifindex = skb->skb_iif; +} +EXPORT_SYMBOL(skb_flow_dissect_meta); + +static void +skb_flow_dissect_set_enc_addr_type(enum flow_dissector_key_id type, + struct flow_dissector *flow_dissector, + void *target_container) +{ + struct flow_dissector_key_control *ctrl; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ENC_CONTROL)) + return; + + ctrl = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_CONTROL, + target_container); + ctrl->addr_type = type; +} + +void +skb_flow_dissect_ct(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, u16 *ctinfo_map, + size_t mapsize, bool post_ct, u16 zone) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + struct flow_dissector_key_ct *key; + enum ip_conntrack_info ctinfo; + struct nf_conn_labels *cl; + struct nf_conn *ct; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_CT)) + return; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct && !post_ct) + return; + + key = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_CT, + target_container); + + if (!ct) { + key->ct_state = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_INVALID; + key->ct_zone = zone; + return; + } + + if (ctinfo < mapsize) + key->ct_state = ctinfo_map[ctinfo]; +#if IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) + key->ct_zone = ct->zone.id; +#endif +#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) + key->ct_mark = READ_ONCE(ct->mark); +#endif + + cl = nf_ct_labels_find(ct); + if (cl) + memcpy(key->ct_labels, cl->bits, sizeof(key->ct_labels)); +#endif /* CONFIG_NF_CONNTRACK */ +} +EXPORT_SYMBOL(skb_flow_dissect_ct); + +void +skb_flow_dissect_tunnel_info(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container) +{ + struct ip_tunnel_info *info; + struct ip_tunnel_key *key; + + /* A quick check to see if there might be something to do. 
*/ + if (!dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_KEYID) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_CONTROL) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_PORTS) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IP) && + !dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_OPTS)) + return; + + info = skb_tunnel_info(skb); + if (!info) + return; + + key = &info->key; + + switch (ip_tunnel_info_af(info)) { + case AF_INET: + skb_flow_dissect_set_enc_addr_type(FLOW_DISSECTOR_KEY_IPV4_ADDRS, + flow_dissector, + target_container); + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS)) { + struct flow_dissector_key_ipv4_addrs *ipv4; + + ipv4 = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS, + target_container); + ipv4->src = key->u.ipv4.src; + ipv4->dst = key->u.ipv4.dst; + } + break; + case AF_INET6: + skb_flow_dissect_set_enc_addr_type(FLOW_DISSECTOR_KEY_IPV6_ADDRS, + flow_dissector, + target_container); + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS)) { + struct flow_dissector_key_ipv6_addrs *ipv6; + + ipv6 = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS, + target_container); + ipv6->src = key->u.ipv6.src; + ipv6->dst = key->u.ipv6.dst; + } + break; + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ENC_KEYID)) { + struct flow_dissector_key_keyid *keyid; + + keyid = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_KEYID, + target_container); + keyid->keyid = tunnel_id_to_key32(key->tun_id); + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ENC_PORTS)) { + struct flow_dissector_key_ports *tp; + + tp = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_PORTS, + target_container); + tp->src = key->tp_src; + tp->dst = key->tp_dst; + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ENC_IP)) { + struct flow_dissector_key_ip *ip; + + ip = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_IP, + target_container); + ip->tos = key->tos; + ip->ttl = key->ttl; + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ENC_OPTS)) { + struct flow_dissector_key_enc_opts *enc_opt; + + enc_opt = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ENC_OPTS, + target_container); + + if (info->options_len) { + enc_opt->len = info->options_len; + ip_tunnel_info_opts_get(enc_opt->data, info); + enc_opt->dst_opt_type = info->key.tun_flags & + TUNNEL_OPTIONS_PRESENT; + } + } +} +EXPORT_SYMBOL(skb_flow_dissect_tunnel_info); + +void skb_flow_dissect_hash(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container) +{ + struct flow_dissector_key_hash *key; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_HASH)) + return; + + key = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_HASH, + target_container); + + key->hash = skb_get_hash_raw(skb); +} +EXPORT_SYMBOL(skb_flow_dissect_hash); + +static enum flow_dissect_ret +__skb_flow_dissect_mpls(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, int nhoff, + int hlen, int lse_index, bool *entropy_label) +{ + struct mpls_label *hdr, _hdr; + u32 entry, label, bos; + + if 
(!dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_MPLS_ENTROPY) && + !dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_MPLS)) + return FLOW_DISSECT_RET_OUT_GOOD; + + if (lse_index >= FLOW_DIS_MPLS_MAX) + return FLOW_DISSECT_RET_OUT_GOOD; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), data, + hlen, &_hdr); + if (!hdr) + return FLOW_DISSECT_RET_OUT_BAD; + + entry = ntohl(hdr->entry); + label = (entry & MPLS_LS_LABEL_MASK) >> MPLS_LS_LABEL_SHIFT; + bos = (entry & MPLS_LS_S_MASK) >> MPLS_LS_S_SHIFT; + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_MPLS)) { + struct flow_dissector_key_mpls *key_mpls; + struct flow_dissector_mpls_lse *lse; + + key_mpls = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_MPLS, + target_container); + lse = &key_mpls->ls[lse_index]; + + lse->mpls_ttl = (entry & MPLS_LS_TTL_MASK) >> MPLS_LS_TTL_SHIFT; + lse->mpls_bos = bos; + lse->mpls_tc = (entry & MPLS_LS_TC_MASK) >> MPLS_LS_TC_SHIFT; + lse->mpls_label = label; + dissector_set_mpls_lse(key_mpls, lse_index); + } + + if (*entropy_label && + dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_MPLS_ENTROPY)) { + struct flow_dissector_key_keyid *key_keyid; + + key_keyid = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_MPLS_ENTROPY, + target_container); + key_keyid->keyid = cpu_to_be32(label); + } + + *entropy_label = label == MPLS_LABEL_ENTROPY; + + return bos ? FLOW_DISSECT_RET_OUT_GOOD : FLOW_DISSECT_RET_PROTO_AGAIN; +} + +static enum flow_dissect_ret +__skb_flow_dissect_arp(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + int nhoff, int hlen) +{ + struct flow_dissector_key_arp *key_arp; + struct { + unsigned char ar_sha[ETH_ALEN]; + unsigned char ar_sip[4]; + unsigned char ar_tha[ETH_ALEN]; + unsigned char ar_tip[4]; + } *arp_eth, _arp_eth; + const struct arphdr *arp; + struct arphdr _arp; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_ARP)) + return FLOW_DISSECT_RET_OUT_GOOD; + + arp = __skb_header_pointer(skb, nhoff, sizeof(_arp), data, + hlen, &_arp); + if (!arp) + return FLOW_DISSECT_RET_OUT_BAD; + + if (arp->ar_hrd != htons(ARPHRD_ETHER) || + arp->ar_pro != htons(ETH_P_IP) || + arp->ar_hln != ETH_ALEN || + arp->ar_pln != 4 || + (arp->ar_op != htons(ARPOP_REPLY) && + arp->ar_op != htons(ARPOP_REQUEST))) + return FLOW_DISSECT_RET_OUT_BAD; + + arp_eth = __skb_header_pointer(skb, nhoff + sizeof(_arp), + sizeof(_arp_eth), data, + hlen, &_arp_eth); + if (!arp_eth) + return FLOW_DISSECT_RET_OUT_BAD; + + key_arp = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ARP, + target_container); + + memcpy(&key_arp->sip, arp_eth->ar_sip, sizeof(key_arp->sip)); + memcpy(&key_arp->tip, arp_eth->ar_tip, sizeof(key_arp->tip)); + + /* Only store the lower byte of the opcode; + * this covers ARPOP_REPLY and ARPOP_REQUEST. 
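+ * (ARPOP_REQUEST is 1 and ARPOP_REPLY is 2, so both fit in one byte.)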
+ */ + key_arp->op = ntohs(arp->ar_op) & 0xff; + + ether_addr_copy(key_arp->sha, arp_eth->ar_sha); + ether_addr_copy(key_arp->tha, arp_eth->ar_tha); + + return FLOW_DISSECT_RET_OUT_GOOD; +} + +static enum flow_dissect_ret +__skb_flow_dissect_gre(const struct sk_buff *skb, + struct flow_dissector_key_control *key_control, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + __be16 *p_proto, int *p_nhoff, int *p_hlen, + unsigned int flags) +{ + struct flow_dissector_key_keyid *key_keyid; + struct gre_base_hdr *hdr, _hdr; + int offset = 0; + u16 gre_ver; + + hdr = __skb_header_pointer(skb, *p_nhoff, sizeof(_hdr), + data, *p_hlen, &_hdr); + if (!hdr) + return FLOW_DISSECT_RET_OUT_BAD; + + /* Only look inside GRE without routing */ + if (hdr->flags & GRE_ROUTING) + return FLOW_DISSECT_RET_OUT_GOOD; + + /* Only look inside GRE for version 0 and 1 */ + gre_ver = ntohs(hdr->flags & GRE_VERSION); + if (gre_ver > 1) + return FLOW_DISSECT_RET_OUT_GOOD; + + *p_proto = hdr->protocol; + if (gre_ver) { + /* Version1 must be PPTP, and check the flags */ + if (!(*p_proto == GRE_PROTO_PPP && (hdr->flags & GRE_KEY))) + return FLOW_DISSECT_RET_OUT_GOOD; + } + + offset += sizeof(struct gre_base_hdr); + + if (hdr->flags & GRE_CSUM) + offset += sizeof_field(struct gre_full_hdr, csum) + + sizeof_field(struct gre_full_hdr, reserved1); + + if (hdr->flags & GRE_KEY) { + const __be32 *keyid; + __be32 _keyid; + + keyid = __skb_header_pointer(skb, *p_nhoff + offset, + sizeof(_keyid), + data, *p_hlen, &_keyid); + if (!keyid) + return FLOW_DISSECT_RET_OUT_BAD; + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_GRE_KEYID)) { + key_keyid = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_GRE_KEYID, + target_container); + if (gre_ver == 0) + key_keyid->keyid = *keyid; + else + key_keyid->keyid = *keyid & GRE_PPTP_KEY_MASK; + } + offset += sizeof_field(struct gre_full_hdr, key); + } + + if (hdr->flags & GRE_SEQ) + offset += sizeof_field(struct pptp_gre_header, seq); + + if (gre_ver == 0) { + if (*p_proto == htons(ETH_P_TEB)) { + const struct ethhdr *eth; + struct ethhdr _eth; + + eth = __skb_header_pointer(skb, *p_nhoff + offset, + sizeof(_eth), + data, *p_hlen, &_eth); + if (!eth) + return FLOW_DISSECT_RET_OUT_BAD; + *p_proto = eth->h_proto; + offset += sizeof(*eth); + + /* Cap headers that we access via pointers at the + * end of the Ethernet header as our maximum alignment + * at that point is only 2 bytes. 
+ */ + if (NET_IP_ALIGN) + *p_hlen = *p_nhoff + offset; + } else { /* version 1, must be PPTP */ + u8 _ppp_hdr[PPP_HDRLEN]; + u8 *ppp_hdr; + + if (hdr->flags & GRE_ACK) + offset += sizeof_field(struct pptp_gre_header, ack); + + ppp_hdr = __skb_header_pointer(skb, *p_nhoff + offset, + sizeof(_ppp_hdr), + data, *p_hlen, _ppp_hdr); + if (!ppp_hdr) + return FLOW_DISSECT_RET_OUT_BAD; + + switch (PPP_PROTOCOL(ppp_hdr)) { + case PPP_IP: + *p_proto = htons(ETH_P_IP); + break; + case PPP_IPV6: + *p_proto = htons(ETH_P_IPV6); + break; + default: + /* Could probably catch some more like MPLS */ + break; + } + + offset += PPP_HDRLEN; + } + + *p_nhoff += offset; + key_control->flags |= FLOW_DIS_ENCAPSULATION; + if (flags & FLOW_DISSECTOR_F_STOP_AT_ENCAP) + return FLOW_DISSECT_RET_OUT_GOOD; + + return FLOW_DISSECT_RET_PROTO_AGAIN; +} + +/** + * __skb_flow_dissect_batadv() - dissect batman-adv header + * @skb: sk_buff with the batman-adv header + * @key_control: flow dissectors control key + * @data: raw buffer pointer to the packet, if NULL use skb->data + * @p_proto: pointer used to update the protocol to process next + * @p_nhoff: pointer used to update inner network header offset + * @hlen: packet header length + * @flags: any combination of FLOW_DISSECTOR_F_* + * + * Dissection of ETH_P_BATMAN packets is attempted. Only + * &struct batadv_unicast packets are actually processed because they contain an + * inner ethernet header and are usually followed by an actual network header. This + * allows the flow dissector to continue processing the packet. + * + * Return: FLOW_DISSECT_RET_PROTO_AGAIN when &struct batadv_unicast was found, + * FLOW_DISSECT_RET_OUT_GOOD when dissector should stop after encapsulation, + * otherwise FLOW_DISSECT_RET_OUT_BAD + */ +static enum flow_dissect_ret +__skb_flow_dissect_batadv(const struct sk_buff *skb, + struct flow_dissector_key_control *key_control, + const void *data, __be16 *p_proto, int *p_nhoff, + int hlen, unsigned int flags) +{ + struct { + struct batadv_unicast_packet batadv_unicast; + struct ethhdr eth; + } *hdr, _hdr; + + hdr = __skb_header_pointer(skb, *p_nhoff, sizeof(_hdr), data, hlen, + &_hdr); + if (!hdr) + return FLOW_DISSECT_RET_OUT_BAD; + + if (hdr->batadv_unicast.version != BATADV_COMPAT_VERSION) + return FLOW_DISSECT_RET_OUT_BAD; + + if (hdr->batadv_unicast.packet_type != BATADV_UNICAST) + return FLOW_DISSECT_RET_OUT_BAD; + + *p_proto = hdr->eth.h_proto; + *p_nhoff += sizeof(*hdr); + + key_control->flags |= FLOW_DIS_ENCAPSULATION; + if (flags & FLOW_DISSECTOR_F_STOP_AT_ENCAP) + return FLOW_DISSECT_RET_OUT_GOOD; + + return FLOW_DISSECT_RET_PROTO_AGAIN; +} + +static void +__skb_flow_dissect_tcp(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + int thoff, int hlen) +{ + struct flow_dissector_key_tcp *key_tcp; + struct tcphdr *th, _th; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_TCP)) + return; + + th = __skb_header_pointer(skb, thoff, sizeof(_th), data, hlen, &_th); + if (!th) + return; + + if (unlikely(__tcp_hdrlen(th) < sizeof(_th))) + return; + + key_tcp = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_TCP, + target_container); + key_tcp->flags = (*(__be16 *) &tcp_flag_word(th) & htons(0x0FFF)); +} + +static void +__skb_flow_dissect_ports(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + int nhoff, u8 ip_proto, int hlen) +{ + enum flow_dissector_key_id dissector_ports =
FLOW_DISSECTOR_KEY_MAX; + struct flow_dissector_key_ports *key_ports; + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_PORTS)) + dissector_ports = FLOW_DISSECTOR_KEY_PORTS; + else if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_PORTS_RANGE)) + dissector_ports = FLOW_DISSECTOR_KEY_PORTS_RANGE; + + if (dissector_ports == FLOW_DISSECTOR_KEY_MAX) + return; + + key_ports = skb_flow_dissector_target(flow_dissector, + dissector_ports, + target_container); + key_ports->ports = __skb_flow_get_ports(skb, nhoff, ip_proto, + data, hlen); +} + +static void +__skb_flow_dissect_ipv4(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + const struct iphdr *iph) +{ + struct flow_dissector_key_ip *key_ip; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_IP)) + return; + + key_ip = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IP, + target_container); + key_ip->tos = iph->tos; + key_ip->ttl = iph->ttl; +} + +static void +__skb_flow_dissect_ipv6(const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + const struct ipv6hdr *iph) +{ + struct flow_dissector_key_ip *key_ip; + + if (!dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_IP)) + return; + + key_ip = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IP, + target_container); + key_ip->tos = ipv6_get_dsfield(iph); + key_ip->ttl = iph->hop_limit; +} + +/* Maximum number of protocol headers that can be parsed in + * __skb_flow_dissect + */ +#define MAX_FLOW_DISSECT_HDRS 15 + +static bool skb_flow_dissect_allowed(int *num_hdrs) +{ + ++*num_hdrs; + + return (*num_hdrs <= MAX_FLOW_DISSECT_HDRS); +} + +static void __skb_flow_bpf_to_target(const struct bpf_flow_keys *flow_keys, + struct flow_dissector *flow_dissector, + void *target_container) +{ + struct flow_dissector_key_ports *key_ports = NULL; + struct flow_dissector_key_control *key_control; + struct flow_dissector_key_basic *key_basic; + struct flow_dissector_key_addrs *key_addrs; + struct flow_dissector_key_tags *key_tags; + + key_control = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_CONTROL, + target_container); + key_control->thoff = flow_keys->thoff; + if (flow_keys->is_frag) + key_control->flags |= FLOW_DIS_IS_FRAGMENT; + if (flow_keys->is_first_frag) + key_control->flags |= FLOW_DIS_FIRST_FRAG; + if (flow_keys->is_encap) + key_control->flags |= FLOW_DIS_ENCAPSULATION; + + key_basic = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_BASIC, + target_container); + key_basic->n_proto = flow_keys->n_proto; + key_basic->ip_proto = flow_keys->ip_proto; + + if (flow_keys->addr_proto == ETH_P_IP && + dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_IPV4_ADDRS)) { + key_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IPV4_ADDRS, + target_container); + key_addrs->v4addrs.src = flow_keys->ipv4_src; + key_addrs->v4addrs.dst = flow_keys->ipv4_dst; + key_control->addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + } else if (flow_keys->addr_proto == ETH_P_IPV6 && + dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_IPV6_ADDRS)) { + key_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IPV6_ADDRS, + target_container); + memcpy(&key_addrs->v6addrs.src, &flow_keys->ipv6_src, + sizeof(key_addrs->v6addrs.src)); + memcpy(&key_addrs->v6addrs.dst, &flow_keys->ipv6_dst, + sizeof(key_addrs->v6addrs.dst)); + key_control->addr_type = 
FLOW_DISSECTOR_KEY_IPV6_ADDRS; + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_PORTS)) + key_ports = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_PORTS, + target_container); + else if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_PORTS_RANGE)) + key_ports = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_PORTS_RANGE, + target_container); + + if (key_ports) { + key_ports->src = flow_keys->sport; + key_ports->dst = flow_keys->dport; + } + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_FLOW_LABEL)) { + key_tags = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_FLOW_LABEL, + target_container); + key_tags->flow_label = ntohl(flow_keys->flow_label); + } +} + +u32 bpf_flow_dissect(struct bpf_prog *prog, struct bpf_flow_dissector *ctx, + __be16 proto, int nhoff, int hlen, unsigned int flags) +{ + struct bpf_flow_keys *flow_keys = ctx->flow_keys; + u32 result; + + /* Pass parameters to the BPF program */ + memset(flow_keys, 0, sizeof(*flow_keys)); + flow_keys->n_proto = proto; + flow_keys->nhoff = nhoff; + flow_keys->thoff = flow_keys->nhoff; + + BUILD_BUG_ON((int)BPF_FLOW_DISSECTOR_F_PARSE_1ST_FRAG != + (int)FLOW_DISSECTOR_F_PARSE_1ST_FRAG); + BUILD_BUG_ON((int)BPF_FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL != + (int)FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL); + BUILD_BUG_ON((int)BPF_FLOW_DISSECTOR_F_STOP_AT_ENCAP != + (int)FLOW_DISSECTOR_F_STOP_AT_ENCAP); + flow_keys->flags = flags; + + result = bpf_prog_run_pin_on_cpu(prog, ctx); + + flow_keys->nhoff = clamp_t(u16, flow_keys->nhoff, nhoff, hlen); + flow_keys->thoff = clamp_t(u16, flow_keys->thoff, + flow_keys->nhoff, hlen); + + return result; +} + +static bool is_pppoe_ses_hdr_valid(const struct pppoe_hdr *hdr) +{ + return hdr->ver == 1 && hdr->type == 1 && hdr->code == 0; +} + +/** + * __skb_flow_dissect - extract the flow_keys struct and return it + * @net: associated network namespace, derived from @skb if NULL + * @skb: sk_buff to extract the flow from, can be NULL if the rest are specified + * @flow_dissector: list of keys to dissect + * @target_container: target structure to put dissected values into + * @data: raw buffer pointer to the packet, if NULL use skb->data + * @proto: protocol for which to get the flow, if @data is NULL use skb->protocol + * @nhoff: network header offset, if @data is NULL use skb_network_offset(skb) + * @hlen: packet header length, if @data is NULL use skb_headlen(skb) + * @flags: flags that control the dissection process, e.g. + * FLOW_DISSECTOR_F_STOP_AT_ENCAP. + * + * The function will try to retrieve individual keys into target specified + * by flow_dissector from either the skbuff or a raw buffer specified by the + * rest parameters. + * + * Caller must take care of zeroing target container memory. + */ +bool __skb_flow_dissect(const struct net *net, + const struct sk_buff *skb, + struct flow_dissector *flow_dissector, + void *target_container, const void *data, + __be16 proto, int nhoff, int hlen, unsigned int flags) +{ + struct flow_dissector_key_control *key_control; + struct flow_dissector_key_basic *key_basic; + struct flow_dissector_key_addrs *key_addrs; + struct flow_dissector_key_tags *key_tags; + struct flow_dissector_key_vlan *key_vlan; + enum flow_dissect_ret fdret; + enum flow_dissector_key_id dissector_vlan = FLOW_DISSECTOR_KEY_MAX; + bool mpls_el = false; + int mpls_lse = 0; + int num_hdrs = 0; + u8 ip_proto = 0; + bool ret; + + if (!data) { + data = skb->data; + proto = skb_vlan_tag_present(skb) ? 
+ skb->vlan_proto : skb->protocol; + nhoff = skb_network_offset(skb); + hlen = skb_headlen(skb); +#if IS_ENABLED(CONFIG_NET_DSA) + if (unlikely(skb->dev && netdev_uses_dsa(skb->dev) && + proto == htons(ETH_P_XDSA))) { + const struct dsa_device_ops *ops; + int offset = 0; + + ops = skb->dev->dsa_ptr->tag_ops; + /* Only DSA header taggers break flow dissection */ + if (ops->needed_headroom) { + if (ops->flow_dissect) + ops->flow_dissect(skb, &proto, &offset); + else + dsa_tag_generic_flow_dissect(skb, + &proto, + &offset); + hlen -= offset; + nhoff += offset; + } + } +#endif + } + + /* It is ensured by skb_flow_dissector_init() that control key will + * be always present. + */ + key_control = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_CONTROL, + target_container); + + /* It is ensured by skb_flow_dissector_init() that basic key will + * be always present. + */ + key_basic = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_BASIC, + target_container); + + if (skb) { + if (!net) { + if (skb->dev) + net = dev_net(skb->dev); + else if (skb->sk) + net = sock_net(skb->sk); + } + } + + WARN_ON_ONCE(!net); + if (net) { + enum netns_bpf_attach_type type = NETNS_BPF_FLOW_DISSECTOR; + struct bpf_prog_array *run_array; + + rcu_read_lock(); + run_array = rcu_dereference(init_net.bpf.run_array[type]); + if (!run_array) + run_array = rcu_dereference(net->bpf.run_array[type]); + + if (run_array) { + struct bpf_flow_keys flow_keys; + struct bpf_flow_dissector ctx = { + .flow_keys = &flow_keys, + .data = data, + .data_end = data + hlen, + }; + __be16 n_proto = proto; + struct bpf_prog *prog; + u32 result; + + if (skb) { + ctx.skb = skb; + /* we can't use 'proto' in the skb case + * because it might be set to skb->vlan_proto + * which has been pulled from the data + */ + n_proto = skb->protocol; + } + + prog = READ_ONCE(run_array->items[0].prog); + result = bpf_flow_dissect(prog, &ctx, n_proto, nhoff, + hlen, flags); + if (result == BPF_FLOW_DISSECTOR_CONTINUE) + goto dissect_continue; + __skb_flow_bpf_to_target(&flow_keys, flow_dissector, + target_container); + rcu_read_unlock(); + return result == BPF_OK; + } +dissect_continue: + rcu_read_unlock(); + } + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_ETH_ADDRS)) { + struct ethhdr *eth = eth_hdr(skb); + struct flow_dissector_key_eth_addrs *key_eth_addrs; + + key_eth_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_ETH_ADDRS, + target_container); + memcpy(key_eth_addrs, eth, sizeof(*key_eth_addrs)); + } + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_NUM_OF_VLANS)) { + struct flow_dissector_key_num_of_vlans *key_num_of_vlans; + + key_num_of_vlans = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_NUM_OF_VLANS, + target_container); + key_num_of_vlans->num_of_vlans = 0; + } + +proto_again: + fdret = FLOW_DISSECT_RET_CONTINUE; + + switch (proto) { + case htons(ETH_P_IP): { + const struct iphdr *iph; + struct iphdr _iph; + + iph = __skb_header_pointer(skb, nhoff, sizeof(_iph), data, hlen, &_iph); + if (!iph || iph->ihl < 5) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + nhoff += iph->ihl * 4; + + ip_proto = iph->protocol; + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_IPV4_ADDRS)) { + key_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IPV4_ADDRS, + target_container); + + memcpy(&key_addrs->v4addrs.src, &iph->saddr, + sizeof(key_addrs->v4addrs.src)); + memcpy(&key_addrs->v4addrs.dst, &iph->daddr, + 
sizeof(key_addrs->v4addrs.dst)); + key_control->addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + } + + __skb_flow_dissect_ipv4(skb, flow_dissector, + target_container, data, iph); + + if (ip_is_fragment(iph)) { + key_control->flags |= FLOW_DIS_IS_FRAGMENT; + + if (iph->frag_off & htons(IP_OFFSET)) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } else { + key_control->flags |= FLOW_DIS_FIRST_FRAG; + if (!(flags & + FLOW_DISSECTOR_F_PARSE_1ST_FRAG)) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + } + } + + break; + } + case htons(ETH_P_IPV6): { + const struct ipv6hdr *iph; + struct ipv6hdr _iph; + + iph = __skb_header_pointer(skb, nhoff, sizeof(_iph), data, hlen, &_iph); + if (!iph) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + ip_proto = iph->nexthdr; + nhoff += sizeof(struct ipv6hdr); + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_IPV6_ADDRS)) { + key_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_IPV6_ADDRS, + target_container); + + memcpy(&key_addrs->v6addrs.src, &iph->saddr, + sizeof(key_addrs->v6addrs.src)); + memcpy(&key_addrs->v6addrs.dst, &iph->daddr, + sizeof(key_addrs->v6addrs.dst)); + key_control->addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + } + + if ((dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_FLOW_LABEL) || + (flags & FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL)) && + ip6_flowlabel(iph)) { + __be32 flow_label = ip6_flowlabel(iph); + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_FLOW_LABEL)) { + key_tags = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_FLOW_LABEL, + target_container); + key_tags->flow_label = ntohl(flow_label); + } + if (flags & FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + } + + __skb_flow_dissect_ipv6(skb, flow_dissector, + target_container, data, iph); + + break; + } + case htons(ETH_P_8021AD): + case htons(ETH_P_8021Q): { + const struct vlan_hdr *vlan = NULL; + struct vlan_hdr _vlan; + __be16 saved_vlan_tpid = proto; + + if (dissector_vlan == FLOW_DISSECTOR_KEY_MAX && + skb && skb_vlan_tag_present(skb)) { + proto = skb->protocol; + } else { + vlan = __skb_header_pointer(skb, nhoff, sizeof(_vlan), + data, hlen, &_vlan); + if (!vlan) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + proto = vlan->h_vlan_encapsulated_proto; + nhoff += sizeof(*vlan); + } + + if (dissector_uses_key(flow_dissector, FLOW_DISSECTOR_KEY_NUM_OF_VLANS) && + !(key_control->flags & FLOW_DIS_ENCAPSULATION)) { + struct flow_dissector_key_num_of_vlans *key_nvs; + + key_nvs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_NUM_OF_VLANS, + target_container); + key_nvs->num_of_vlans++; + } + + if (dissector_vlan == FLOW_DISSECTOR_KEY_MAX) { + dissector_vlan = FLOW_DISSECTOR_KEY_VLAN; + } else if (dissector_vlan == FLOW_DISSECTOR_KEY_VLAN) { + dissector_vlan = FLOW_DISSECTOR_KEY_CVLAN; + } else { + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + } + + if (dissector_uses_key(flow_dissector, dissector_vlan)) { + key_vlan = skb_flow_dissector_target(flow_dissector, + dissector_vlan, + target_container); + + if (!vlan) { + key_vlan->vlan_id = skb_vlan_tag_get_id(skb); + key_vlan->vlan_priority = skb_vlan_tag_get_prio(skb); + } else { + key_vlan->vlan_id = ntohs(vlan->h_vlan_TCI) & + VLAN_VID_MASK; + key_vlan->vlan_priority = + (ntohs(vlan->h_vlan_TCI) & + VLAN_PRIO_MASK) >> VLAN_PRIO_SHIFT; + } + key_vlan->vlan_tpid = saved_vlan_tpid; + key_vlan->vlan_eth_type = proto; + } + + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + } + case 
htons(ETH_P_PPP_SES): { + struct { + struct pppoe_hdr hdr; + __be16 proto; + } *hdr, _hdr; + u16 ppp_proto; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), data, hlen, &_hdr); + if (!hdr) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + if (!is_pppoe_ses_hdr_valid(&hdr->hdr)) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + /* least significant bit of the most significant octet + * indicates if protocol field was compressed + */ + ppp_proto = ntohs(hdr->proto); + if (ppp_proto & 0x0100) { + ppp_proto = ppp_proto >> 8; + nhoff += PPPOE_SES_HLEN - 1; + } else { + nhoff += PPPOE_SES_HLEN; + } + + if (ppp_proto == PPP_IP) { + proto = htons(ETH_P_IP); + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + } else if (ppp_proto == PPP_IPV6) { + proto = htons(ETH_P_IPV6); + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + } else if (ppp_proto == PPP_MPLS_UC) { + proto = htons(ETH_P_MPLS_UC); + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + } else if (ppp_proto == PPP_MPLS_MC) { + proto = htons(ETH_P_MPLS_MC); + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + } else if (ppp_proto_is_valid(ppp_proto)) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + } else { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_PPPOE)) { + struct flow_dissector_key_pppoe *key_pppoe; + + key_pppoe = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_PPPOE, + target_container); + key_pppoe->session_id = hdr->hdr.sid; + key_pppoe->ppp_proto = htons(ppp_proto); + key_pppoe->type = htons(ETH_P_PPP_SES); + } + break; + } + case htons(ETH_P_TIPC): { + struct tipc_basic_hdr *hdr, _hdr; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), + data, hlen, &_hdr); + if (!hdr) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + if (dissector_uses_key(flow_dissector, + FLOW_DISSECTOR_KEY_TIPC)) { + key_addrs = skb_flow_dissector_target(flow_dissector, + FLOW_DISSECTOR_KEY_TIPC, + target_container); + key_addrs->tipckey.key = tipc_hdr_rps_key(hdr); + key_control->addr_type = FLOW_DISSECTOR_KEY_TIPC; + } + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + case htons(ETH_P_MPLS_UC): + case htons(ETH_P_MPLS_MC): + fdret = __skb_flow_dissect_mpls(skb, flow_dissector, + target_container, data, + nhoff, hlen, mpls_lse, + &mpls_el); + nhoff += sizeof(struct mpls_label); + mpls_lse++; + break; + case htons(ETH_P_FCOE): + if ((hlen - nhoff) < FCOE_HEADER_LEN) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + nhoff += FCOE_HEADER_LEN; + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + + case htons(ETH_P_ARP): + case htons(ETH_P_RARP): + fdret = __skb_flow_dissect_arp(skb, flow_dissector, + target_container, data, + nhoff, hlen); + break; + + case htons(ETH_P_BATMAN): + fdret = __skb_flow_dissect_batadv(skb, key_control, data, + &proto, &nhoff, hlen, flags); + break; + + case htons(ETH_P_1588): { + struct ptp_header *hdr, _hdr; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), data, + hlen, &_hdr); + if (!hdr) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + nhoff += sizeof(struct ptp_header); + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + case htons(ETH_P_PRP): + case htons(ETH_P_HSR): { + struct hsr_tag *hdr, _hdr; + + hdr = __skb_header_pointer(skb, nhoff, sizeof(_hdr), data, hlen, + &_hdr); + if (!hdr) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + proto = hdr->encap_proto; + nhoff += HSR_HLEN; + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + } + + default: + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + /* Process result of proto processing */ + 
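+	/* FLOW_DISSECT_RET_PROTO_AGAIN loops back to the protocol switch
+	 * above (bounded by MAX_FLOW_DISSECT_HDRS), FLOW_DISSECT_RET_CONTINUE
+	 * and FLOW_DISSECT_RET_IPPROTO_AGAIN fall through to the IP protocol
+	 * handling below, and FLOW_DISSECT_RET_OUT_GOOD/_OUT_BAD end the
+	 * dissection with success or failure respectively.
+	 */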
switch (fdret) { + case FLOW_DISSECT_RET_OUT_GOOD: + goto out_good; + case FLOW_DISSECT_RET_PROTO_AGAIN: + if (skb_flow_dissect_allowed(&num_hdrs)) + goto proto_again; + goto out_good; + case FLOW_DISSECT_RET_CONTINUE: + case FLOW_DISSECT_RET_IPPROTO_AGAIN: + break; + case FLOW_DISSECT_RET_OUT_BAD: + default: + goto out_bad; + } + +ip_proto_again: + fdret = FLOW_DISSECT_RET_CONTINUE; + + switch (ip_proto) { + case IPPROTO_GRE: + if (flags & FLOW_DISSECTOR_F_STOP_BEFORE_ENCAP) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + fdret = __skb_flow_dissect_gre(skb, key_control, flow_dissector, + target_container, data, + &proto, &nhoff, &hlen, flags); + break; + + case NEXTHDR_HOP: + case NEXTHDR_ROUTING: + case NEXTHDR_DEST: { + u8 _opthdr[2], *opthdr; + + if (proto != htons(ETH_P_IPV6)) + break; + + opthdr = __skb_header_pointer(skb, nhoff, sizeof(_opthdr), + data, hlen, &_opthdr); + if (!opthdr) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + ip_proto = opthdr[0]; + nhoff += (opthdr[1] + 1) << 3; + + fdret = FLOW_DISSECT_RET_IPPROTO_AGAIN; + break; + } + case NEXTHDR_FRAGMENT: { + struct frag_hdr _fh, *fh; + + if (proto != htons(ETH_P_IPV6)) + break; + + fh = __skb_header_pointer(skb, nhoff, sizeof(_fh), + data, hlen, &_fh); + + if (!fh) { + fdret = FLOW_DISSECT_RET_OUT_BAD; + break; + } + + key_control->flags |= FLOW_DIS_IS_FRAGMENT; + + nhoff += sizeof(_fh); + ip_proto = fh->nexthdr; + + if (!(fh->frag_off & htons(IP6_OFFSET))) { + key_control->flags |= FLOW_DIS_FIRST_FRAG; + if (flags & FLOW_DISSECTOR_F_PARSE_1ST_FRAG) { + fdret = FLOW_DISSECT_RET_IPPROTO_AGAIN; + break; + } + } + + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + case IPPROTO_IPIP: + if (flags & FLOW_DISSECTOR_F_STOP_BEFORE_ENCAP) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + proto = htons(ETH_P_IP); + + key_control->flags |= FLOW_DIS_ENCAPSULATION; + if (flags & FLOW_DISSECTOR_F_STOP_AT_ENCAP) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + + case IPPROTO_IPV6: + if (flags & FLOW_DISSECTOR_F_STOP_BEFORE_ENCAP) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + proto = htons(ETH_P_IPV6); + + key_control->flags |= FLOW_DIS_ENCAPSULATION; + if (flags & FLOW_DISSECTOR_F_STOP_AT_ENCAP) { + fdret = FLOW_DISSECT_RET_OUT_GOOD; + break; + } + + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + + + case IPPROTO_MPLS: + proto = htons(ETH_P_MPLS_UC); + fdret = FLOW_DISSECT_RET_PROTO_AGAIN; + break; + + case IPPROTO_TCP: + __skb_flow_dissect_tcp(skb, flow_dissector, target_container, + data, nhoff, hlen); + break; + + case IPPROTO_ICMP: + case IPPROTO_ICMPV6: + __skb_flow_dissect_icmp(skb, flow_dissector, target_container, + data, nhoff, hlen); + break; + case IPPROTO_L2TP: + __skb_flow_dissect_l2tpv3(skb, flow_dissector, target_container, + data, nhoff, hlen); + break; + + default: + break; + } + + if (!(key_control->flags & FLOW_DIS_IS_FRAGMENT)) + __skb_flow_dissect_ports(skb, flow_dissector, target_container, + data, nhoff, ip_proto, hlen); + + /* Process result of IP proto processing */ + switch (fdret) { + case FLOW_DISSECT_RET_PROTO_AGAIN: + if (skb_flow_dissect_allowed(&num_hdrs)) + goto proto_again; + break; + case FLOW_DISSECT_RET_IPPROTO_AGAIN: + if (skb_flow_dissect_allowed(&num_hdrs)) + goto ip_proto_again; + break; + case FLOW_DISSECT_RET_OUT_GOOD: + case FLOW_DISSECT_RET_CONTINUE: + break; + case FLOW_DISSECT_RET_OUT_BAD: + default: + goto out_bad; + } + +out_good: + ret = true; + +out: + key_control->thoff = min_t(u16, nhoff, skb ? 
skb->len : hlen); + key_basic->n_proto = proto; + key_basic->ip_proto = ip_proto; + + return ret; + +out_bad: + ret = false; + goto out; +} +EXPORT_SYMBOL(__skb_flow_dissect); + +static siphash_aligned_key_t hashrnd; +static __always_inline void __flow_hash_secret_init(void) +{ + net_get_random_once(&hashrnd, sizeof(hashrnd)); +} + +static const void *flow_keys_hash_start(const struct flow_keys *flow) +{ + BUILD_BUG_ON(FLOW_KEYS_HASH_OFFSET % SIPHASH_ALIGNMENT); + return &flow->FLOW_KEYS_HASH_START_FIELD; +} + +static inline size_t flow_keys_hash_length(const struct flow_keys *flow) +{ + size_t diff = FLOW_KEYS_HASH_OFFSET + sizeof(flow->addrs); + + BUILD_BUG_ON((sizeof(*flow) - FLOW_KEYS_HASH_OFFSET) % sizeof(u32)); + + switch (flow->control.addr_type) { + case FLOW_DISSECTOR_KEY_IPV4_ADDRS: + diff -= sizeof(flow->addrs.v4addrs); + break; + case FLOW_DISSECTOR_KEY_IPV6_ADDRS: + diff -= sizeof(flow->addrs.v6addrs); + break; + case FLOW_DISSECTOR_KEY_TIPC: + diff -= sizeof(flow->addrs.tipckey); + break; + } + return sizeof(*flow) - diff; +} + +__be32 flow_get_u32_src(const struct flow_keys *flow) +{ + switch (flow->control.addr_type) { + case FLOW_DISSECTOR_KEY_IPV4_ADDRS: + return flow->addrs.v4addrs.src; + case FLOW_DISSECTOR_KEY_IPV6_ADDRS: + return (__force __be32)ipv6_addr_hash( + &flow->addrs.v6addrs.src); + case FLOW_DISSECTOR_KEY_TIPC: + return flow->addrs.tipckey.key; + default: + return 0; + } +} +EXPORT_SYMBOL(flow_get_u32_src); + +__be32 flow_get_u32_dst(const struct flow_keys *flow) +{ + switch (flow->control.addr_type) { + case FLOW_DISSECTOR_KEY_IPV4_ADDRS: + return flow->addrs.v4addrs.dst; + case FLOW_DISSECTOR_KEY_IPV6_ADDRS: + return (__force __be32)ipv6_addr_hash( + &flow->addrs.v6addrs.dst); + default: + return 0; + } +} +EXPORT_SYMBOL(flow_get_u32_dst); + +/* Sort the source and destination IP and the ports, + * to have consistent hash within the two directions + */ +static inline void __flow_hash_consistentify(struct flow_keys *keys) +{ + int addr_diff, i; + + switch (keys->control.addr_type) { + case FLOW_DISSECTOR_KEY_IPV4_ADDRS: + if ((__force u32)keys->addrs.v4addrs.dst < + (__force u32)keys->addrs.v4addrs.src) + swap(keys->addrs.v4addrs.src, keys->addrs.v4addrs.dst); + + if ((__force u16)keys->ports.dst < + (__force u16)keys->ports.src) { + swap(keys->ports.src, keys->ports.dst); + } + break; + case FLOW_DISSECTOR_KEY_IPV6_ADDRS: + addr_diff = memcmp(&keys->addrs.v6addrs.dst, + &keys->addrs.v6addrs.src, + sizeof(keys->addrs.v6addrs.dst)); + if (addr_diff < 0) { + for (i = 0; i < 4; i++) + swap(keys->addrs.v6addrs.src.s6_addr32[i], + keys->addrs.v6addrs.dst.s6_addr32[i]); + } + if ((__force u16)keys->ports.dst < + (__force u16)keys->ports.src) { + swap(keys->ports.src, keys->ports.dst); + } + break; + } +} + +static inline u32 __flow_hash_from_keys(struct flow_keys *keys, + const siphash_key_t *keyval) +{ + u32 hash; + + __flow_hash_consistentify(keys); + + hash = siphash(flow_keys_hash_start(keys), + flow_keys_hash_length(keys), keyval); + if (!hash) + hash = 1; + + return hash; +} + +u32 flow_hash_from_keys(struct flow_keys *keys) +{ + __flow_hash_secret_init(); + return __flow_hash_from_keys(keys, &hashrnd); +} +EXPORT_SYMBOL(flow_hash_from_keys); + +static inline u32 ___skb_get_hash(const struct sk_buff *skb, + struct flow_keys *keys, + const siphash_key_t *keyval) +{ + skb_flow_dissect_flow_keys(skb, keys, + FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL); + + return __flow_hash_from_keys(keys, keyval); +} + +struct _flow_keys_digest_data { + __be16 n_proto; + u8 
ip_proto; + u8 padding; + __be32 ports; + __be32 src; + __be32 dst; +}; + +void make_flow_keys_digest(struct flow_keys_digest *digest, + const struct flow_keys *flow) +{ + struct _flow_keys_digest_data *data = + (struct _flow_keys_digest_data *)digest; + + BUILD_BUG_ON(sizeof(*data) > sizeof(*digest)); + + memset(digest, 0, sizeof(*digest)); + + data->n_proto = flow->basic.n_proto; + data->ip_proto = flow->basic.ip_proto; + data->ports = flow->ports.ports; + data->src = flow->addrs.v4addrs.src; + data->dst = flow->addrs.v4addrs.dst; +} +EXPORT_SYMBOL(make_flow_keys_digest); + +static struct flow_dissector flow_keys_dissector_symmetric __read_mostly; + +u32 __skb_get_hash_symmetric(const struct sk_buff *skb) +{ + struct flow_keys keys; + + __flow_hash_secret_init(); + + memset(&keys, 0, sizeof(keys)); + __skb_flow_dissect(NULL, skb, &flow_keys_dissector_symmetric, + &keys, NULL, 0, 0, 0, 0); + + return __flow_hash_from_keys(&keys, &hashrnd); +} +EXPORT_SYMBOL_GPL(__skb_get_hash_symmetric); + +/** + * __skb_get_hash: calculate a flow hash + * @skb: sk_buff to calculate flow hash from + * + * This function calculates a flow hash based on src/dst addresses + * and src/dst port numbers. Sets hash in skb to non-zero hash value + * on success, zero indicates no valid hash. Also, sets l4_hash in skb + * if hash is a canonical 4-tuple hash over transport ports. + */ +void __skb_get_hash(struct sk_buff *skb) +{ + struct flow_keys keys; + u32 hash; + + __flow_hash_secret_init(); + + hash = ___skb_get_hash(skb, &keys, &hashrnd); + + __skb_set_sw_hash(skb, hash, flow_keys_have_l4(&keys)); +} +EXPORT_SYMBOL(__skb_get_hash); + +__u32 skb_get_hash_perturb(const struct sk_buff *skb, + const siphash_key_t *perturb) +{ + struct flow_keys keys; + + return ___skb_get_hash(skb, &keys, perturb); +} +EXPORT_SYMBOL(skb_get_hash_perturb); + +u32 __skb_get_poff(const struct sk_buff *skb, const void *data, + const struct flow_keys_basic *keys, int hlen) +{ + u32 poff = keys->control.thoff; + + /* skip L4 headers for fragments after the first */ + if ((keys->control.flags & FLOW_DIS_IS_FRAGMENT) && + !(keys->control.flags & FLOW_DIS_FIRST_FRAG)) + return poff; + + switch (keys->basic.ip_proto) { + case IPPROTO_TCP: { + /* access doff as u8 to avoid unaligned access */ + const u8 *doff; + u8 _doff; + + doff = __skb_header_pointer(skb, poff + 12, sizeof(_doff), + data, hlen, &_doff); + if (!doff) + return poff; + + poff += max_t(u32, sizeof(struct tcphdr), (*doff & 0xF0) >> 2); + break; + } + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + poff += sizeof(struct udphdr); + break; + /* For the rest, we do not really care about header + * extensions at this point for now. + */ + case IPPROTO_ICMP: + poff += sizeof(struct icmphdr); + break; + case IPPROTO_ICMPV6: + poff += sizeof(struct icmp6hdr); + break; + case IPPROTO_IGMP: + poff += sizeof(struct igmphdr); + break; + case IPPROTO_DCCP: + poff += sizeof(struct dccp_hdr); + break; + case IPPROTO_SCTP: + poff += sizeof(struct sctphdr); + break; + } + + return poff; +} + +/** + * skb_get_poff - get the offset to the payload + * @skb: sk_buff to get the payload offset from + * + * The function will get the offset to the payload as far as it could + * be dissected. The main user is currently BPF, so that we can dynamically + * truncate packets without needing to push actual payload to the user + * space and can analyze headers only, instead. 
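+ *
+ * For example, for an IPv4/TCP packet without TCP options the returned
+ * offset is the transport header offset plus the 20 byte minimum TCP
+ * header length.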
+ */ +u32 skb_get_poff(const struct sk_buff *skb) +{ + struct flow_keys_basic keys; + + if (!skb_flow_dissect_flow_keys_basic(NULL, skb, &keys, + NULL, 0, 0, 0, 0)) + return 0; + + return __skb_get_poff(skb, skb->data, &keys, skb_headlen(skb)); +} + +__u32 __get_hash_from_flowi6(const struct flowi6 *fl6, struct flow_keys *keys) +{ + memset(keys, 0, sizeof(*keys)); + + memcpy(&keys->addrs.v6addrs.src, &fl6->saddr, + sizeof(keys->addrs.v6addrs.src)); + memcpy(&keys->addrs.v6addrs.dst, &fl6->daddr, + sizeof(keys->addrs.v6addrs.dst)); + keys->control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + keys->ports.src = fl6->fl6_sport; + keys->ports.dst = fl6->fl6_dport; + keys->keyid.keyid = fl6->fl6_gre_key; + keys->tags.flow_label = (__force u32)flowi6_get_flowlabel(fl6); + keys->basic.ip_proto = fl6->flowi6_proto; + + return flow_hash_from_keys(keys); +} +EXPORT_SYMBOL(__get_hash_from_flowi6); + +static const struct flow_dissector_key flow_keys_dissector_keys[] = { + { + .key_id = FLOW_DISSECTOR_KEY_CONTROL, + .offset = offsetof(struct flow_keys, control), + }, + { + .key_id = FLOW_DISSECTOR_KEY_BASIC, + .offset = offsetof(struct flow_keys, basic), + }, + { + .key_id = FLOW_DISSECTOR_KEY_IPV4_ADDRS, + .offset = offsetof(struct flow_keys, addrs.v4addrs), + }, + { + .key_id = FLOW_DISSECTOR_KEY_IPV6_ADDRS, + .offset = offsetof(struct flow_keys, addrs.v6addrs), + }, + { + .key_id = FLOW_DISSECTOR_KEY_TIPC, + .offset = offsetof(struct flow_keys, addrs.tipckey), + }, + { + .key_id = FLOW_DISSECTOR_KEY_PORTS, + .offset = offsetof(struct flow_keys, ports), + }, + { + .key_id = FLOW_DISSECTOR_KEY_VLAN, + .offset = offsetof(struct flow_keys, vlan), + }, + { + .key_id = FLOW_DISSECTOR_KEY_FLOW_LABEL, + .offset = offsetof(struct flow_keys, tags), + }, + { + .key_id = FLOW_DISSECTOR_KEY_GRE_KEYID, + .offset = offsetof(struct flow_keys, keyid), + }, +}; + +static const struct flow_dissector_key flow_keys_dissector_symmetric_keys[] = { + { + .key_id = FLOW_DISSECTOR_KEY_CONTROL, + .offset = offsetof(struct flow_keys, control), + }, + { + .key_id = FLOW_DISSECTOR_KEY_BASIC, + .offset = offsetof(struct flow_keys, basic), + }, + { + .key_id = FLOW_DISSECTOR_KEY_IPV4_ADDRS, + .offset = offsetof(struct flow_keys, addrs.v4addrs), + }, + { + .key_id = FLOW_DISSECTOR_KEY_IPV6_ADDRS, + .offset = offsetof(struct flow_keys, addrs.v6addrs), + }, + { + .key_id = FLOW_DISSECTOR_KEY_PORTS, + .offset = offsetof(struct flow_keys, ports), + }, +}; + +static const struct flow_dissector_key flow_keys_basic_dissector_keys[] = { + { + .key_id = FLOW_DISSECTOR_KEY_CONTROL, + .offset = offsetof(struct flow_keys, control), + }, + { + .key_id = FLOW_DISSECTOR_KEY_BASIC, + .offset = offsetof(struct flow_keys, basic), + }, +}; + +struct flow_dissector flow_keys_dissector __read_mostly; +EXPORT_SYMBOL(flow_keys_dissector); + +struct flow_dissector flow_keys_basic_dissector __read_mostly; +EXPORT_SYMBOL(flow_keys_basic_dissector); + +static int __init init_default_flow_dissectors(void) +{ + skb_flow_dissector_init(&flow_keys_dissector, + flow_keys_dissector_keys, + ARRAY_SIZE(flow_keys_dissector_keys)); + skb_flow_dissector_init(&flow_keys_dissector_symmetric, + flow_keys_dissector_symmetric_keys, + ARRAY_SIZE(flow_keys_dissector_symmetric_keys)); + skb_flow_dissector_init(&flow_keys_basic_dissector, + flow_keys_basic_dissector_keys, + ARRAY_SIZE(flow_keys_basic_dissector_keys)); + return 0; +} +core_initcall(init_default_flow_dissectors); diff --git a/net/core/flow_offload.c b/net/core/flow_offload.c new file mode 100644 index 
000000000..abe423fd5 --- /dev/null +++ b/net/core/flow_offload.c @@ -0,0 +1,624 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include + +struct flow_rule *flow_rule_alloc(unsigned int num_actions) +{ + struct flow_rule *rule; + int i; + + rule = kzalloc(struct_size(rule, action.entries, num_actions), + GFP_KERNEL); + if (!rule) + return NULL; + + rule->action.num_entries = num_actions; + /* Pre-fill each action hw_stats with DONT_CARE. + * Caller can override this if it wants stats for a given action. + */ + for (i = 0; i < num_actions; i++) + rule->action.entries[i].hw_stats = FLOW_ACTION_HW_STATS_DONT_CARE; + + return rule; +} +EXPORT_SYMBOL(flow_rule_alloc); + +struct flow_offload_action *offload_action_alloc(unsigned int num_actions) +{ + struct flow_offload_action *fl_action; + int i; + + fl_action = kzalloc(struct_size(fl_action, action.entries, num_actions), + GFP_KERNEL); + if (!fl_action) + return NULL; + + fl_action->action.num_entries = num_actions; + /* Pre-fill each action hw_stats with DONT_CARE. + * Caller can override this if it wants stats for a given action. + */ + for (i = 0; i < num_actions; i++) + fl_action->action.entries[i].hw_stats = FLOW_ACTION_HW_STATS_DONT_CARE; + + return fl_action; +} + +#define FLOW_DISSECTOR_MATCH(__rule, __type, __out) \ + const struct flow_match *__m = &(__rule)->match; \ + struct flow_dissector *__d = (__m)->dissector; \ + \ + (__out)->key = skb_flow_dissector_target(__d, __type, (__m)->key); \ + (__out)->mask = skb_flow_dissector_target(__d, __type, (__m)->mask); \ + +void flow_rule_match_meta(const struct flow_rule *rule, + struct flow_match_meta *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_META, out); +} +EXPORT_SYMBOL(flow_rule_match_meta); + +void flow_rule_match_basic(const struct flow_rule *rule, + struct flow_match_basic *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_BASIC, out); +} +EXPORT_SYMBOL(flow_rule_match_basic); + +void flow_rule_match_control(const struct flow_rule *rule, + struct flow_match_control *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_CONTROL, out); +} +EXPORT_SYMBOL(flow_rule_match_control); + +void flow_rule_match_eth_addrs(const struct flow_rule *rule, + struct flow_match_eth_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ETH_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_eth_addrs); + +void flow_rule_match_vlan(const struct flow_rule *rule, + struct flow_match_vlan *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_VLAN, out); +} +EXPORT_SYMBOL(flow_rule_match_vlan); + +void flow_rule_match_cvlan(const struct flow_rule *rule, + struct flow_match_vlan *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_CVLAN, out); +} +EXPORT_SYMBOL(flow_rule_match_cvlan); + +void flow_rule_match_ipv4_addrs(const struct flow_rule *rule, + struct flow_match_ipv4_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IPV4_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_ipv4_addrs); + +void flow_rule_match_ipv6_addrs(const struct flow_rule *rule, + struct flow_match_ipv6_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IPV6_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_ipv6_addrs); + +void flow_rule_match_ip(const struct flow_rule *rule, + struct flow_match_ip *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_IP, out); +} +EXPORT_SYMBOL(flow_rule_match_ip); + +void flow_rule_match_ports(const struct flow_rule *rule, + struct flow_match_ports *out) +{ + FLOW_DISSECTOR_MATCH(rule, 
FLOW_DISSECTOR_KEY_PORTS, out); +} +EXPORT_SYMBOL(flow_rule_match_ports); + +void flow_rule_match_ports_range(const struct flow_rule *rule, + struct flow_match_ports_range *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_PORTS_RANGE, out); +} +EXPORT_SYMBOL(flow_rule_match_ports_range); + +void flow_rule_match_tcp(const struct flow_rule *rule, + struct flow_match_tcp *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_TCP, out); +} +EXPORT_SYMBOL(flow_rule_match_tcp); + +void flow_rule_match_icmp(const struct flow_rule *rule, + struct flow_match_icmp *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ICMP, out); +} +EXPORT_SYMBOL(flow_rule_match_icmp); + +void flow_rule_match_mpls(const struct flow_rule *rule, + struct flow_match_mpls *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_MPLS, out); +} +EXPORT_SYMBOL(flow_rule_match_mpls); + +void flow_rule_match_enc_control(const struct flow_rule *rule, + struct flow_match_control *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_CONTROL, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_control); + +void flow_rule_match_enc_ipv4_addrs(const struct flow_rule *rule, + struct flow_match_ipv4_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ipv4_addrs); + +void flow_rule_match_enc_ipv6_addrs(const struct flow_rule *rule, + struct flow_match_ipv6_addrs *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ipv6_addrs); + +void flow_rule_match_enc_ip(const struct flow_rule *rule, + struct flow_match_ip *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_IP, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ip); + +void flow_rule_match_enc_ports(const struct flow_rule *rule, + struct flow_match_ports *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_PORTS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_ports); + +void flow_rule_match_enc_keyid(const struct flow_rule *rule, + struct flow_match_enc_keyid *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_KEYID, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_keyid); + +void flow_rule_match_enc_opts(const struct flow_rule *rule, + struct flow_match_enc_opts *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_ENC_OPTS, out); +} +EXPORT_SYMBOL(flow_rule_match_enc_opts); + +struct flow_action_cookie *flow_action_cookie_create(void *data, + unsigned int len, + gfp_t gfp) +{ + struct flow_action_cookie *cookie; + + cookie = kmalloc(sizeof(*cookie) + len, gfp); + if (!cookie) + return NULL; + cookie->cookie_len = len; + memcpy(cookie->cookie, data, len); + return cookie; +} +EXPORT_SYMBOL(flow_action_cookie_create); + +void flow_action_cookie_destroy(struct flow_action_cookie *cookie) +{ + kfree(cookie); +} +EXPORT_SYMBOL(flow_action_cookie_destroy); + +void flow_rule_match_ct(const struct flow_rule *rule, + struct flow_match_ct *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_CT, out); +} +EXPORT_SYMBOL(flow_rule_match_ct); + +void flow_rule_match_pppoe(const struct flow_rule *rule, + struct flow_match_pppoe *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_PPPOE, out); +} +EXPORT_SYMBOL(flow_rule_match_pppoe); + +void flow_rule_match_l2tpv3(const struct flow_rule *rule, + struct flow_match_l2tpv3 *out) +{ + FLOW_DISSECTOR_MATCH(rule, FLOW_DISSECTOR_KEY_L2TPV3, out); +} +EXPORT_SYMBOL(flow_rule_match_l2tpv3); + +struct flow_block_cb *flow_block_cb_alloc(flow_setup_cb_t *cb, + void *cb_ident, void *cb_priv, 
+ void (*release)(void *cb_priv)) +{ + struct flow_block_cb *block_cb; + + block_cb = kzalloc(sizeof(*block_cb), GFP_KERNEL); + if (!block_cb) + return ERR_PTR(-ENOMEM); + + block_cb->cb = cb; + block_cb->cb_ident = cb_ident; + block_cb->cb_priv = cb_priv; + block_cb->release = release; + + return block_cb; +} +EXPORT_SYMBOL(flow_block_cb_alloc); + +void flow_block_cb_free(struct flow_block_cb *block_cb) +{ + if (block_cb->release) + block_cb->release(block_cb->cb_priv); + + kfree(block_cb); +} +EXPORT_SYMBOL(flow_block_cb_free); + +struct flow_block_cb *flow_block_cb_lookup(struct flow_block *block, + flow_setup_cb_t *cb, void *cb_ident) +{ + struct flow_block_cb *block_cb; + + list_for_each_entry(block_cb, &block->cb_list, list) { + if (block_cb->cb == cb && + block_cb->cb_ident == cb_ident) + return block_cb; + } + + return NULL; +} +EXPORT_SYMBOL(flow_block_cb_lookup); + +void *flow_block_cb_priv(struct flow_block_cb *block_cb) +{ + return block_cb->cb_priv; +} +EXPORT_SYMBOL(flow_block_cb_priv); + +void flow_block_cb_incref(struct flow_block_cb *block_cb) +{ + block_cb->refcnt++; +} +EXPORT_SYMBOL(flow_block_cb_incref); + +unsigned int flow_block_cb_decref(struct flow_block_cb *block_cb) +{ + return --block_cb->refcnt; +} +EXPORT_SYMBOL(flow_block_cb_decref); + +bool flow_block_cb_is_busy(flow_setup_cb_t *cb, void *cb_ident, + struct list_head *driver_block_list) +{ + struct flow_block_cb *block_cb; + + list_for_each_entry(block_cb, driver_block_list, driver_list) { + if (block_cb->cb == cb && + block_cb->cb_ident == cb_ident) + return true; + } + + return false; +} +EXPORT_SYMBOL(flow_block_cb_is_busy); + +int flow_block_cb_setup_simple(struct flow_block_offload *f, + struct list_head *driver_block_list, + flow_setup_cb_t *cb, + void *cb_ident, void *cb_priv, + bool ingress_only) +{ + struct flow_block_cb *block_cb; + + if (ingress_only && + f->binder_type != FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS) + return -EOPNOTSUPP; + + f->driver_block_list = driver_block_list; + + switch (f->command) { + case FLOW_BLOCK_BIND: + if (flow_block_cb_is_busy(cb, cb_ident, driver_block_list)) + return -EBUSY; + + block_cb = flow_block_cb_alloc(cb, cb_ident, cb_priv, NULL); + if (IS_ERR(block_cb)) + return PTR_ERR(block_cb); + + flow_block_cb_add(block_cb, f); + list_add_tail(&block_cb->driver_list, driver_block_list); + return 0; + case FLOW_BLOCK_UNBIND: + block_cb = flow_block_cb_lookup(f->block, cb, cb_ident); + if (!block_cb) + return -ENOENT; + + flow_block_cb_remove(block_cb, f); + list_del(&block_cb->driver_list); + return 0; + default: + return -EOPNOTSUPP; + } +} +EXPORT_SYMBOL(flow_block_cb_setup_simple); + +static DEFINE_MUTEX(flow_indr_block_lock); +static LIST_HEAD(flow_block_indr_list); +static LIST_HEAD(flow_block_indr_dev_list); +static LIST_HEAD(flow_indir_dev_list); + +struct flow_indr_dev { + struct list_head list; + flow_indr_block_bind_cb_t *cb; + void *cb_priv; + refcount_t refcnt; +}; + +static struct flow_indr_dev *flow_indr_dev_alloc(flow_indr_block_bind_cb_t *cb, + void *cb_priv) +{ + struct flow_indr_dev *indr_dev; + + indr_dev = kmalloc(sizeof(*indr_dev), GFP_KERNEL); + if (!indr_dev) + return NULL; + + indr_dev->cb = cb; + indr_dev->cb_priv = cb_priv; + refcount_set(&indr_dev->refcnt, 1); + + return indr_dev; +} + +struct flow_indir_dev_info { + void *data; + struct net_device *dev; + struct Qdisc *sch; + enum tc_setup_type type; + void (*cleanup)(struct flow_block_cb *block_cb); + struct list_head list; + enum flow_block_command command; + enum flow_block_binder_type 
binder_type; + struct list_head *cb_list; +}; + +static void existing_qdiscs_register(flow_indr_block_bind_cb_t *cb, void *cb_priv) +{ + struct flow_block_offload bo; + struct flow_indir_dev_info *cur; + + list_for_each_entry(cur, &flow_indir_dev_list, list) { + memset(&bo, 0, sizeof(bo)); + bo.command = cur->command; + bo.binder_type = cur->binder_type; + INIT_LIST_HEAD(&bo.cb_list); + cb(cur->dev, cur->sch, cb_priv, cur->type, &bo, cur->data, cur->cleanup); + list_splice(&bo.cb_list, cur->cb_list); + } +} + +int flow_indr_dev_register(flow_indr_block_bind_cb_t *cb, void *cb_priv) +{ + struct flow_indr_dev *indr_dev; + + mutex_lock(&flow_indr_block_lock); + list_for_each_entry(indr_dev, &flow_block_indr_dev_list, list) { + if (indr_dev->cb == cb && + indr_dev->cb_priv == cb_priv) { + refcount_inc(&indr_dev->refcnt); + mutex_unlock(&flow_indr_block_lock); + return 0; + } + } + + indr_dev = flow_indr_dev_alloc(cb, cb_priv); + if (!indr_dev) { + mutex_unlock(&flow_indr_block_lock); + return -ENOMEM; + } + + list_add(&indr_dev->list, &flow_block_indr_dev_list); + existing_qdiscs_register(cb, cb_priv); + mutex_unlock(&flow_indr_block_lock); + + tcf_action_reoffload_cb(cb, cb_priv, true); + + return 0; +} +EXPORT_SYMBOL(flow_indr_dev_register); + +static void __flow_block_indr_cleanup(void (*release)(void *cb_priv), + void *cb_priv, + struct list_head *cleanup_list) +{ + struct flow_block_cb *this, *next; + + list_for_each_entry_safe(this, next, &flow_block_indr_list, indr.list) { + if (this->release == release && + this->indr.cb_priv == cb_priv) + list_move(&this->indr.list, cleanup_list); + } +} + +static void flow_block_indr_notify(struct list_head *cleanup_list) +{ + struct flow_block_cb *this, *next; + + list_for_each_entry_safe(this, next, cleanup_list, indr.list) { + list_del(&this->indr.list); + this->indr.cleanup(this); + } +} + +void flow_indr_dev_unregister(flow_indr_block_bind_cb_t *cb, void *cb_priv, + void (*release)(void *cb_priv)) +{ + struct flow_indr_dev *this, *next, *indr_dev = NULL; + LIST_HEAD(cleanup_list); + + mutex_lock(&flow_indr_block_lock); + list_for_each_entry_safe(this, next, &flow_block_indr_dev_list, list) { + if (this->cb == cb && + this->cb_priv == cb_priv && + refcount_dec_and_test(&this->refcnt)) { + indr_dev = this; + list_del(&indr_dev->list); + break; + } + } + + if (!indr_dev) { + mutex_unlock(&flow_indr_block_lock); + return; + } + + __flow_block_indr_cleanup(release, cb_priv, &cleanup_list); + mutex_unlock(&flow_indr_block_lock); + + tcf_action_reoffload_cb(cb, cb_priv, false); + flow_block_indr_notify(&cleanup_list); + kfree(indr_dev); +} +EXPORT_SYMBOL(flow_indr_dev_unregister); + +static void flow_block_indr_init(struct flow_block_cb *flow_block, + struct flow_block_offload *bo, + struct net_device *dev, struct Qdisc *sch, void *data, + void *cb_priv, + void (*cleanup)(struct flow_block_cb *block_cb)) +{ + flow_block->indr.binder_type = bo->binder_type; + flow_block->indr.data = data; + flow_block->indr.cb_priv = cb_priv; + flow_block->indr.dev = dev; + flow_block->indr.sch = sch; + flow_block->indr.cleanup = cleanup; +} + +struct flow_block_cb *flow_indr_block_cb_alloc(flow_setup_cb_t *cb, + void *cb_ident, void *cb_priv, + void (*release)(void *cb_priv), + struct flow_block_offload *bo, + struct net_device *dev, + struct Qdisc *sch, void *data, + void *indr_cb_priv, + void (*cleanup)(struct flow_block_cb *block_cb)) +{ + struct flow_block_cb *block_cb; + + block_cb = flow_block_cb_alloc(cb, cb_ident, cb_priv, release); + if (IS_ERR(block_cb)) + 
goto out; + + flow_block_indr_init(block_cb, bo, dev, sch, data, indr_cb_priv, cleanup); + list_add(&block_cb->indr.list, &flow_block_indr_list); + +out: + return block_cb; +} +EXPORT_SYMBOL(flow_indr_block_cb_alloc); + +static struct flow_indir_dev_info *find_indir_dev(void *data) +{ + struct flow_indir_dev_info *cur; + + list_for_each_entry(cur, &flow_indir_dev_list, list) { + if (cur->data == data) + return cur; + } + return NULL; +} + +static int indir_dev_add(void *data, struct net_device *dev, struct Qdisc *sch, + enum tc_setup_type type, void (*cleanup)(struct flow_block_cb *block_cb), + struct flow_block_offload *bo) +{ + struct flow_indir_dev_info *info; + + info = find_indir_dev(data); + if (info) + return -EEXIST; + + info = kzalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return -ENOMEM; + + info->data = data; + info->dev = dev; + info->sch = sch; + info->type = type; + info->cleanup = cleanup; + info->command = bo->command; + info->binder_type = bo->binder_type; + info->cb_list = bo->cb_list_head; + + list_add(&info->list, &flow_indir_dev_list); + return 0; +} + +static int indir_dev_remove(void *data) +{ + struct flow_indir_dev_info *info; + + info = find_indir_dev(data); + if (!info) + return -ENOENT; + + list_del(&info->list); + + kfree(info); + return 0; +} + +int flow_indr_dev_setup_offload(struct net_device *dev, struct Qdisc *sch, + enum tc_setup_type type, void *data, + struct flow_block_offload *bo, + void (*cleanup)(struct flow_block_cb *block_cb)) +{ + struct flow_indr_dev *this; + u32 count = 0; + int err; + + mutex_lock(&flow_indr_block_lock); + if (bo) { + if (bo->command == FLOW_BLOCK_BIND) + indir_dev_add(data, dev, sch, type, cleanup, bo); + else if (bo->command == FLOW_BLOCK_UNBIND) + indir_dev_remove(data); + } + + list_for_each_entry(this, &flow_block_indr_dev_list, list) { + err = this->cb(dev, sch, this->cb_priv, type, bo, data, cleanup); + if (!err) + count++; + } + + mutex_unlock(&flow_indr_block_lock); + + return (bo && list_empty(&bo->cb_list)) ? -EOPNOTSUPP : count; +} +EXPORT_SYMBOL(flow_indr_dev_setup_offload); + +bool flow_indr_dev_exists(void) +{ + return !list_empty(&flow_block_indr_dev_list); +} +EXPORT_SYMBOL(flow_indr_dev_exists); diff --git a/net/core/gen_estimator.c b/net/core/gen_estimator.c new file mode 100644 index 000000000..4fcbdd71c --- /dev/null +++ b/net/core/gen_estimator.c @@ -0,0 +1,278 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/gen_estimator.c Simple rate estimator. + * + * Authors: Alexey Kuznetsov, + * Eric Dumazet + * + * Changes: + * Jamal Hadi Salim - moved it to net/core and reshulfed + * names to make it usable in general net subsystem. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* This code is NOT intended to be used for statistics collection, + * its purpose is to provide a base for statistical multiplexing + * for controlled load service. + * If you need only statistics, run a user level daemon which + * periodically reads byte counters. 
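+ *
+ * A rough sketch of the smoothing performed by est_timer() below,
+ * written out here purely for illustration: once per period the
+ * instantaneous rate is taken as
+ *
+ *	R = bytes_since_last_run << (10 - intvl_log)
+ *
+ * (i.e. roughly bytes/sec scaled by 2^8) and folded into the
+ * average with an exponentially weighted moving average,
+ * approximately
+ *
+ *	avbps += (R - avbps) >> ewma_log
+ *
+ * which is why gen_estimator_read() shifts avbps/avpps right by 8
+ * before reporting them.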
+ */ + +struct net_rate_estimator { + struct gnet_stats_basic_sync *bstats; + spinlock_t *stats_lock; + bool running; + struct gnet_stats_basic_sync __percpu *cpu_bstats; + u8 ewma_log; + u8 intvl_log; /* period : (250ms << intvl_log) */ + + seqcount_t seq; + u64 last_packets; + u64 last_bytes; + + u64 avpps; + u64 avbps; + + unsigned long next_jiffies; + struct timer_list timer; + struct rcu_head rcu; +}; + +static void est_fetch_counters(struct net_rate_estimator *e, + struct gnet_stats_basic_sync *b) +{ + gnet_stats_basic_sync_init(b); + if (e->stats_lock) + spin_lock(e->stats_lock); + + gnet_stats_add_basic(b, e->cpu_bstats, e->bstats, e->running); + + if (e->stats_lock) + spin_unlock(e->stats_lock); + +} + +static void est_timer(struct timer_list *t) +{ + struct net_rate_estimator *est = from_timer(est, t, timer); + struct gnet_stats_basic_sync b; + u64 b_bytes, b_packets; + u64 rate, brate; + + est_fetch_counters(est, &b); + b_bytes = u64_stats_read(&b.bytes); + b_packets = u64_stats_read(&b.packets); + + brate = (b_bytes - est->last_bytes) << (10 - est->intvl_log); + brate = (brate >> est->ewma_log) - (est->avbps >> est->ewma_log); + + rate = (b_packets - est->last_packets) << (10 - est->intvl_log); + rate = (rate >> est->ewma_log) - (est->avpps >> est->ewma_log); + + write_seqcount_begin(&est->seq); + est->avbps += brate; + est->avpps += rate; + write_seqcount_end(&est->seq); + + est->last_bytes = b_bytes; + est->last_packets = b_packets; + + est->next_jiffies += ((HZ/4) << est->intvl_log); + + if (unlikely(time_after_eq(jiffies, est->next_jiffies))) { + /* Ouch... timer was delayed. */ + est->next_jiffies = jiffies + 1; + } + mod_timer(&est->timer, est->next_jiffies); +} + +/** + * gen_new_estimator - create a new rate estimator + * @bstats: basic statistics + * @cpu_bstats: bstats per cpu + * @rate_est: rate estimator statistics + * @lock: lock for statistics and control path + * @running: true if @bstats represents a running qdisc, thus @bstats' + * internal values might change during basic reads. Only used + * if @bstats_cpu is NULL + * @opt: rate estimator configuration TLV + * + * Creates a new rate estimator with &bstats as source and &rate_est + * as destination. A new timer with the interval specified in the + * configuration TLV is created. Upon each interval, the latest statistics + * will be read from &bstats and the estimated rate will be stored in + * &rate_est with the statistics lock grabbed during this period. + * + * Returns 0 on success or a negative error code. 
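+ *
+ * A minimal usage sketch; the qdisc and attribute variable names
+ * below are illustrative only and not taken from a particular
+ * caller. tb_rate stands for a TCA_RATE netlink attribute carrying
+ * a struct gnet_estimator.
+ *
+ *	struct net_rate_estimator __rcu *rate_est = NULL;
+ *	int err;
+ *
+ *	err = gen_new_estimator(&sch->bstats, sch->cpu_bstats, &rate_est,
+ *				NULL, true, tb_rate);
+ *	if (err)
+ *		return err;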
+ * + */ +int gen_new_estimator(struct gnet_stats_basic_sync *bstats, + struct gnet_stats_basic_sync __percpu *cpu_bstats, + struct net_rate_estimator __rcu **rate_est, + spinlock_t *lock, + bool running, + struct nlattr *opt) +{ + struct gnet_estimator *parm = nla_data(opt); + struct net_rate_estimator *old, *est; + struct gnet_stats_basic_sync b; + int intvl_log; + + if (nla_len(opt) < sizeof(*parm)) + return -EINVAL; + + /* allowed timer periods are : + * -2 : 250ms, -1 : 500ms, 0 : 1 sec + * 1 : 2 sec, 2 : 4 sec, 3 : 8 sec + */ + if (parm->interval < -2 || parm->interval > 3) + return -EINVAL; + + if (parm->ewma_log == 0 || parm->ewma_log >= 31) + return -EINVAL; + + est = kzalloc(sizeof(*est), GFP_KERNEL); + if (!est) + return -ENOBUFS; + + seqcount_init(&est->seq); + intvl_log = parm->interval + 2; + est->bstats = bstats; + est->stats_lock = lock; + est->running = running; + est->ewma_log = parm->ewma_log; + est->intvl_log = intvl_log; + est->cpu_bstats = cpu_bstats; + + if (lock) + local_bh_disable(); + est_fetch_counters(est, &b); + if (lock) + local_bh_enable(); + est->last_bytes = u64_stats_read(&b.bytes); + est->last_packets = u64_stats_read(&b.packets); + + if (lock) + spin_lock_bh(lock); + old = rcu_dereference_protected(*rate_est, 1); + if (old) { + del_timer_sync(&old->timer); + est->avbps = old->avbps; + est->avpps = old->avpps; + } + + est->next_jiffies = jiffies + ((HZ/4) << intvl_log); + timer_setup(&est->timer, est_timer, 0); + mod_timer(&est->timer, est->next_jiffies); + + rcu_assign_pointer(*rate_est, est); + if (lock) + spin_unlock_bh(lock); + if (old) + kfree_rcu(old, rcu); + return 0; +} +EXPORT_SYMBOL(gen_new_estimator); + +/** + * gen_kill_estimator - remove a rate estimator + * @rate_est: rate estimator + * + * Removes the rate estimator. + * + */ +void gen_kill_estimator(struct net_rate_estimator __rcu **rate_est) +{ + struct net_rate_estimator *est; + + est = xchg((__force struct net_rate_estimator **)rate_est, NULL); + if (est) { + del_timer_sync(&est->timer); + kfree_rcu(est, rcu); + } +} +EXPORT_SYMBOL(gen_kill_estimator); + +/** + * gen_replace_estimator - replace rate estimator configuration + * @bstats: basic statistics + * @cpu_bstats: bstats per cpu + * @rate_est: rate estimator statistics + * @lock: lock for statistics and control path + * @running: true if @bstats represents a running qdisc, thus @bstats' + * internal values might change during basic reads. Only used + * if @cpu_bstats is NULL + * @opt: rate estimator configuration TLV + * + * Replaces the configuration of a rate estimator by calling + * gen_kill_estimator() and gen_new_estimator(). + * + * Returns 0 on success or a negative error code. + */ +int gen_replace_estimator(struct gnet_stats_basic_sync *bstats, + struct gnet_stats_basic_sync __percpu *cpu_bstats, + struct net_rate_estimator __rcu **rate_est, + spinlock_t *lock, + bool running, struct nlattr *opt) +{ + return gen_new_estimator(bstats, cpu_bstats, rate_est, + lock, running, opt); +} +EXPORT_SYMBOL(gen_replace_estimator); + +/** + * gen_estimator_active - test if estimator is currently in use + * @rate_est: rate estimator + * + * Returns true if estimator is active, and false if not. 
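+ *
+ * For example (a sketch only, with an illustrative @rate_est
+ * pointer), a caller can pair this with gen_estimator_read(),
+ * defined just below:
+ *
+ *	struct gnet_stats_rate_est64 sample;
+ *
+ *	if (gen_estimator_active(&rate_est) &&
+ *	    gen_estimator_read(&rate_est, &sample))
+ *		pr_debug("estimated rate: %llu bytes/s\n",
+ *			 (unsigned long long)sample.bps);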
+ */ +bool gen_estimator_active(struct net_rate_estimator __rcu **rate_est) +{ + return !!rcu_access_pointer(*rate_est); +} +EXPORT_SYMBOL(gen_estimator_active); + +bool gen_estimator_read(struct net_rate_estimator __rcu **rate_est, + struct gnet_stats_rate_est64 *sample) +{ + struct net_rate_estimator *est; + unsigned seq; + + rcu_read_lock(); + est = rcu_dereference(*rate_est); + if (!est) { + rcu_read_unlock(); + return false; + } + + do { + seq = read_seqcount_begin(&est->seq); + sample->bps = est->avbps >> 8; + sample->pps = est->avpps >> 8; + } while (read_seqcount_retry(&est->seq, seq)); + + rcu_read_unlock(); + return true; +} +EXPORT_SYMBOL(gen_estimator_read); diff --git a/net/core/gen_stats.c b/net/core/gen_stats.c new file mode 100644 index 000000000..c8d137ef5 --- /dev/null +++ b/net/core/gen_stats.c @@ -0,0 +1,485 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/gen_stats.c + * + * Authors: Thomas Graf + * Jamal Hadi Salim + * Alexey Kuznetsov, + * + * See Documentation/networking/gen_stats.rst + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static inline int +gnet_stats_copy(struct gnet_dump *d, int type, void *buf, int size, int padattr) +{ + if (nla_put_64bit(d->skb, type, size, buf, padattr)) + goto nla_put_failure; + return 0; + +nla_put_failure: + if (d->lock) + spin_unlock_bh(d->lock); + kfree(d->xstats); + d->xstats = NULL; + d->xstats_len = 0; + return -1; +} + +/** + * gnet_stats_start_copy_compat - start dumping procedure in compatibility mode + * @skb: socket buffer to put statistics TLVs into + * @type: TLV type for top level statistic TLV + * @tc_stats_type: TLV type for backward compatibility struct tc_stats TLV + * @xstats_type: TLV type for backward compatibility xstats TLV + * @lock: statistics lock + * @d: dumping handle + * @padattr: padding attribute + * + * Initializes the dumping handle, grabs the statistic lock and appends + * an empty TLV header to the socket buffer for use a container for all + * other statistic TLVS. + * + * The dumping handle is marked to be in backward compatibility mode telling + * all gnet_stats_copy_XXX() functions to fill a local copy of struct tc_stats. + * + * Returns 0 on success or -1 if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_start_copy_compat(struct sk_buff *skb, int type, int tc_stats_type, + int xstats_type, spinlock_t *lock, + struct gnet_dump *d, int padattr) + __acquires(lock) +{ + memset(d, 0, sizeof(*d)); + + if (type) + d->tail = (struct nlattr *)skb_tail_pointer(skb); + d->skb = skb; + d->compat_tc_stats = tc_stats_type; + d->compat_xstats = xstats_type; + d->padattr = padattr; + if (lock) { + d->lock = lock; + spin_lock_bh(lock); + } + if (d->tail) { + int ret = gnet_stats_copy(d, type, NULL, 0, padattr); + + /* The initial attribute added in gnet_stats_copy() may be + * preceded by a padding attribute, in which case d->tail will + * end up pointing at the padding instead of the real attribute. + * Fix this so gnet_stats_finish_copy() adjusts the length of + * the right attribute. 
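+ *
+ * As an illustration of the fixup done below: when padding was
+ * inserted, the area at the old tail looks like
+ *	[ padattr ][ empty attribute of @type ]
+ * with d->tail still pointing at the padattr, so it is advanced
+ * by NLA_ALIGN(the padattr's nla_len) to reach the attribute whose
+ * nla_len gnet_stats_finish_copy() will later patch.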
+ */ + if (ret == 0 && d->tail->nla_type == padattr) + d->tail = (struct nlattr *)((char *)d->tail + + NLA_ALIGN(d->tail->nla_len)); + return ret; + } + + return 0; +} +EXPORT_SYMBOL(gnet_stats_start_copy_compat); + +/** + * gnet_stats_start_copy - start dumping procedure in compatibility mode + * @skb: socket buffer to put statistics TLVs into + * @type: TLV type for top level statistic TLV + * @lock: statistics lock + * @d: dumping handle + * @padattr: padding attribute + * + * Initializes the dumping handle, grabs the statistic lock and appends + * an empty TLV header to the socket buffer for use a container for all + * other statistic TLVS. + * + * Returns 0 on success or -1 if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_start_copy(struct sk_buff *skb, int type, spinlock_t *lock, + struct gnet_dump *d, int padattr) +{ + return gnet_stats_start_copy_compat(skb, type, 0, 0, lock, d, padattr); +} +EXPORT_SYMBOL(gnet_stats_start_copy); + +/* Must not be inlined, due to u64_stats seqcount_t lockdep key */ +void gnet_stats_basic_sync_init(struct gnet_stats_basic_sync *b) +{ + u64_stats_set(&b->bytes, 0); + u64_stats_set(&b->packets, 0); + u64_stats_init(&b->syncp); +} +EXPORT_SYMBOL(gnet_stats_basic_sync_init); + +static void gnet_stats_add_basic_cpu(struct gnet_stats_basic_sync *bstats, + struct gnet_stats_basic_sync __percpu *cpu) +{ + u64 t_bytes = 0, t_packets = 0; + int i; + + for_each_possible_cpu(i) { + struct gnet_stats_basic_sync *bcpu = per_cpu_ptr(cpu, i); + unsigned int start; + u64 bytes, packets; + + do { + start = u64_stats_fetch_begin_irq(&bcpu->syncp); + bytes = u64_stats_read(&bcpu->bytes); + packets = u64_stats_read(&bcpu->packets); + } while (u64_stats_fetch_retry_irq(&bcpu->syncp, start)); + + t_bytes += bytes; + t_packets += packets; + } + _bstats_update(bstats, t_bytes, t_packets); +} + +void gnet_stats_add_basic(struct gnet_stats_basic_sync *bstats, + struct gnet_stats_basic_sync __percpu *cpu, + struct gnet_stats_basic_sync *b, bool running) +{ + unsigned int start; + u64 bytes = 0; + u64 packets = 0; + + WARN_ON_ONCE((cpu || running) && in_hardirq()); + + if (cpu) { + gnet_stats_add_basic_cpu(bstats, cpu); + return; + } + do { + if (running) + start = u64_stats_fetch_begin_irq(&b->syncp); + bytes = u64_stats_read(&b->bytes); + packets = u64_stats_read(&b->packets); + } while (running && u64_stats_fetch_retry_irq(&b->syncp, start)); + + _bstats_update(bstats, bytes, packets); +} +EXPORT_SYMBOL(gnet_stats_add_basic); + +static void gnet_stats_read_basic(u64 *ret_bytes, u64 *ret_packets, + struct gnet_stats_basic_sync __percpu *cpu, + struct gnet_stats_basic_sync *b, bool running) +{ + unsigned int start; + + if (cpu) { + u64 t_bytes = 0, t_packets = 0; + int i; + + for_each_possible_cpu(i) { + struct gnet_stats_basic_sync *bcpu = per_cpu_ptr(cpu, i); + unsigned int start; + u64 bytes, packets; + + do { + start = u64_stats_fetch_begin_irq(&bcpu->syncp); + bytes = u64_stats_read(&bcpu->bytes); + packets = u64_stats_read(&bcpu->packets); + } while (u64_stats_fetch_retry_irq(&bcpu->syncp, start)); + + t_bytes += bytes; + t_packets += packets; + } + *ret_bytes = t_bytes; + *ret_packets = t_packets; + return; + } + do { + if (running) + start = u64_stats_fetch_begin_irq(&b->syncp); + *ret_bytes = u64_stats_read(&b->bytes); + *ret_packets = u64_stats_read(&b->packets); + } while (running && u64_stats_fetch_retry_irq(&b->syncp, start)); +} + +static int +___gnet_stats_copy_basic(struct gnet_dump *d, + struct gnet_stats_basic_sync __percpu *cpu, + 
struct gnet_stats_basic_sync *b, + int type, bool running) +{ + u64 bstats_bytes, bstats_packets; + + gnet_stats_read_basic(&bstats_bytes, &bstats_packets, cpu, b, running); + + if (d->compat_tc_stats && type == TCA_STATS_BASIC) { + d->tc_stats.bytes = bstats_bytes; + d->tc_stats.packets = bstats_packets; + } + + if (d->tail) { + struct gnet_stats_basic sb; + int res; + + memset(&sb, 0, sizeof(sb)); + sb.bytes = bstats_bytes; + sb.packets = bstats_packets; + res = gnet_stats_copy(d, type, &sb, sizeof(sb), TCA_STATS_PAD); + if (res < 0 || sb.packets == bstats_packets) + return res; + /* emit 64bit stats only if needed */ + return gnet_stats_copy(d, TCA_STATS_PKT64, &bstats_packets, + sizeof(bstats_packets), TCA_STATS_PAD); + } + return 0; +} + +/** + * gnet_stats_copy_basic - copy basic statistics into statistic TLV + * @d: dumping handle + * @cpu: copy statistic per cpu + * @b: basic statistics + * @running: true if @b represents a running qdisc, thus @b's + * internal values might change during basic reads. + * Only used if @cpu is NULL + * + * Context: task; must not be run from IRQ or BH contexts + * + * Appends the basic statistics to the top level TLV created by + * gnet_stats_start_copy(). + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_copy_basic(struct gnet_dump *d, + struct gnet_stats_basic_sync __percpu *cpu, + struct gnet_stats_basic_sync *b, + bool running) +{ + return ___gnet_stats_copy_basic(d, cpu, b, TCA_STATS_BASIC, running); +} +EXPORT_SYMBOL(gnet_stats_copy_basic); + +/** + * gnet_stats_copy_basic_hw - copy basic hw statistics into statistic TLV + * @d: dumping handle + * @cpu: copy statistic per cpu + * @b: basic statistics + * @running: true if @b represents a running qdisc, thus @b's + * internal values might change during basic reads. + * Only used if @cpu is NULL + * + * Context: task; must not be run from IRQ or BH contexts + * + * Appends the basic statistics to the top level TLV created by + * gnet_stats_start_copy(). + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_copy_basic_hw(struct gnet_dump *d, + struct gnet_stats_basic_sync __percpu *cpu, + struct gnet_stats_basic_sync *b, + bool running) +{ + return ___gnet_stats_copy_basic(d, cpu, b, TCA_STATS_BASIC_HW, running); +} +EXPORT_SYMBOL(gnet_stats_copy_basic_hw); + +/** + * gnet_stats_copy_rate_est - copy rate estimator statistics into statistics TLV + * @d: dumping handle + * @rate_est: rate estimator + * + * Appends the rate estimator statistics to the top level TLV created by + * gnet_stats_start_copy(). + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. 
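+ *
+ * A typical dump sequence, sketched here with illustrative
+ * variable names (q, qlen) only:
+ *
+ *	struct gnet_dump d;
+ *
+ *	if (gnet_stats_start_copy(skb, TCA_STATS2, NULL, &d, TCA_PAD) < 0 ||
+ *	    gnet_stats_copy_basic(&d, q->cpu_bstats, &q->bstats, true) < 0 ||
+ *	    gnet_stats_copy_rate_est(&d, &q->rate_est) < 0 ||
+ *	    gnet_stats_copy_queue(&d, q->cpu_qstats, &q->qstats, qlen) < 0 ||
+ *	    gnet_stats_finish_copy(&d) < 0)
+ *		goto nla_put_failure;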
+ */ +int +gnet_stats_copy_rate_est(struct gnet_dump *d, + struct net_rate_estimator __rcu **rate_est) +{ + struct gnet_stats_rate_est64 sample; + struct gnet_stats_rate_est est; + int res; + + if (!gen_estimator_read(rate_est, &sample)) + return 0; + est.bps = min_t(u64, UINT_MAX, sample.bps); + /* we have some time before reaching 2^32 packets per second */ + est.pps = sample.pps; + + if (d->compat_tc_stats) { + d->tc_stats.bps = est.bps; + d->tc_stats.pps = est.pps; + } + + if (d->tail) { + res = gnet_stats_copy(d, TCA_STATS_RATE_EST, &est, sizeof(est), + TCA_STATS_PAD); + if (res < 0 || est.bps == sample.bps) + return res; + /* emit 64bit stats only if needed */ + return gnet_stats_copy(d, TCA_STATS_RATE_EST64, &sample, + sizeof(sample), TCA_STATS_PAD); + } + + return 0; +} +EXPORT_SYMBOL(gnet_stats_copy_rate_est); + +static void gnet_stats_add_queue_cpu(struct gnet_stats_queue *qstats, + const struct gnet_stats_queue __percpu *q) +{ + int i; + + for_each_possible_cpu(i) { + const struct gnet_stats_queue *qcpu = per_cpu_ptr(q, i); + + qstats->qlen += qcpu->qlen; + qstats->backlog += qcpu->backlog; + qstats->drops += qcpu->drops; + qstats->requeues += qcpu->requeues; + qstats->overlimits += qcpu->overlimits; + } +} + +void gnet_stats_add_queue(struct gnet_stats_queue *qstats, + const struct gnet_stats_queue __percpu *cpu, + const struct gnet_stats_queue *q) +{ + if (cpu) { + gnet_stats_add_queue_cpu(qstats, cpu); + } else { + qstats->qlen += q->qlen; + qstats->backlog += q->backlog; + qstats->drops += q->drops; + qstats->requeues += q->requeues; + qstats->overlimits += q->overlimits; + } +} +EXPORT_SYMBOL(gnet_stats_add_queue); + +/** + * gnet_stats_copy_queue - copy queue statistics into statistics TLV + * @d: dumping handle + * @cpu_q: per cpu queue statistics + * @q: queue statistics + * @qlen: queue length statistics + * + * Appends the queue statistics to the top level TLV created by + * gnet_stats_start_copy(). Using per cpu queue statistics if + * they are available. + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_copy_queue(struct gnet_dump *d, + struct gnet_stats_queue __percpu *cpu_q, + struct gnet_stats_queue *q, __u32 qlen) +{ + struct gnet_stats_queue qstats = {0}; + + gnet_stats_add_queue(&qstats, cpu_q, q); + qstats.qlen = qlen; + + if (d->compat_tc_stats) { + d->tc_stats.drops = qstats.drops; + d->tc_stats.qlen = qstats.qlen; + d->tc_stats.backlog = qstats.backlog; + d->tc_stats.overlimits = qstats.overlimits; + } + + if (d->tail) + return gnet_stats_copy(d, TCA_STATS_QUEUE, + &qstats, sizeof(qstats), + TCA_STATS_PAD); + + return 0; +} +EXPORT_SYMBOL(gnet_stats_copy_queue); + +/** + * gnet_stats_copy_app - copy application specific statistics into statistics TLV + * @d: dumping handle + * @st: application specific statistics data + * @len: length of data + * + * Appends the application specific statistics to the top level TLV created by + * gnet_stats_start_copy() and remembers the data for XSTATS if the dumping + * handle is in backward compatibility mode. + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. 
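+ *
+ * For instance (struct and field names are placeholders only):
+ *
+ *	struct tc_foo_xstats xst = { .drops = q->stat_drops };
+ *
+ *	if (gnet_stats_copy_app(&d, &xst, sizeof(xst)) < 0)
+ *		goto nla_put_failure;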
+ */ +int +gnet_stats_copy_app(struct gnet_dump *d, void *st, int len) +{ + if (d->compat_xstats) { + d->xstats = kmemdup(st, len, GFP_ATOMIC); + if (!d->xstats) + goto err_out; + d->xstats_len = len; + } + + if (d->tail) + return gnet_stats_copy(d, TCA_STATS_APP, st, len, + TCA_STATS_PAD); + + return 0; + +err_out: + if (d->lock) + spin_unlock_bh(d->lock); + d->xstats_len = 0; + return -1; +} +EXPORT_SYMBOL(gnet_stats_copy_app); + +/** + * gnet_stats_finish_copy - finish dumping procedure + * @d: dumping handle + * + * Corrects the length of the top level TLV to include all TLVs added + * by gnet_stats_copy_XXX() calls. Adds the backward compatibility TLVs + * if gnet_stats_start_copy_compat() was used and releases the statistics + * lock. + * + * Returns 0 on success or -1 with the statistic lock released + * if the room in the socket buffer was not sufficient. + */ +int +gnet_stats_finish_copy(struct gnet_dump *d) +{ + if (d->tail) + d->tail->nla_len = skb_tail_pointer(d->skb) - (u8 *)d->tail; + + if (d->compat_tc_stats) + if (gnet_stats_copy(d, d->compat_tc_stats, &d->tc_stats, + sizeof(d->tc_stats), d->padattr) < 0) + return -1; + + if (d->compat_xstats && d->xstats) { + if (gnet_stats_copy(d, d->compat_xstats, d->xstats, + d->xstats_len, d->padattr) < 0) + return -1; + } + + if (d->lock) + spin_unlock_bh(d->lock); + kfree(d->xstats); + d->xstats = NULL; + d->xstats_len = 0; + return 0; +} +EXPORT_SYMBOL(gnet_stats_finish_copy); diff --git a/net/core/gro.c b/net/core/gro.c new file mode 100644 index 000000000..352f966cb --- /dev/null +++ b/net/core/gro.c @@ -0,0 +1,815 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +#include +#include +#include +#include + +#define MAX_GRO_SKBS 8 + +/* This should be increased if a protocol with a bigger head is added. */ +#define GRO_MAX_HEAD (MAX_HEADER + 128) + +static DEFINE_SPINLOCK(offload_lock); +static struct list_head offload_base __read_mostly = LIST_HEAD_INIT(offload_base); +/* Maximum number of GRO_NORMAL skbs to batch up for list-RX */ +int gro_normal_batch __read_mostly = 8; + +/** + * dev_add_offload - register offload handlers + * @po: protocol offload declaration + * + * Add protocol offload handlers to the networking stack. The passed + * &proto_offload is linked into kernel lists and may not be freed until + * it has been removed from the kernel lists. + * + * This call does not sleep therefore it can not + * guarantee all CPU's that are in middle of receiving packets + * will see the new offload handlers (until the next received packet). + */ +void dev_add_offload(struct packet_offload *po) +{ + struct packet_offload *elem; + + spin_lock(&offload_lock); + list_for_each_entry(elem, &offload_base, list) { + if (po->priority < elem->priority) + break; + } + list_add_rcu(&po->list, elem->list.prev); + spin_unlock(&offload_lock); +} +EXPORT_SYMBOL(dev_add_offload); + +/** + * __dev_remove_offload - remove offload handler + * @po: packet offload declaration + * + * Remove a protocol offload handler that was previously added to the + * kernel offload handlers by dev_add_offload(). The passed &offload_type + * is removed from the kernel lists and can be freed or reused once this + * function returns. + * + * The packet type might still be in use by receivers + * and must not be freed until after all the CPU's have gone + * through a quiescent state. 
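+ *
+ * A matching add/remove pair, sketched with placeholder names
+ * (ETH_P_FOO and the foo_* callbacks are illustrative only):
+ *
+ *	static struct packet_offload foo_offload __read_mostly = {
+ *		.type = cpu_to_be16(ETH_P_FOO),
+ *		.priority = 10,
+ *		.callbacks = {
+ *			.gro_receive = foo_gro_receive,
+ *			.gro_complete = foo_gro_complete,
+ *		},
+ *	};
+ *
+ *	dev_add_offload(&foo_offload);
+ *	...
+ *	dev_remove_offload(&foo_offload);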
+ */ +static void __dev_remove_offload(struct packet_offload *po) +{ + struct list_head *head = &offload_base; + struct packet_offload *po1; + + spin_lock(&offload_lock); + + list_for_each_entry(po1, head, list) { + if (po == po1) { + list_del_rcu(&po->list); + goto out; + } + } + + pr_warn("dev_remove_offload: %p not found\n", po); +out: + spin_unlock(&offload_lock); +} + +/** + * dev_remove_offload - remove packet offload handler + * @po: packet offload declaration + * + * Remove a packet offload handler that was previously added to the kernel + * offload handlers by dev_add_offload(). The passed &offload_type is + * removed from the kernel lists and can be freed or reused once this + * function returns. + * + * This call sleeps to guarantee that no CPU is looking at the packet + * type after return. + */ +void dev_remove_offload(struct packet_offload *po) +{ + __dev_remove_offload(po); + + synchronize_net(); +} +EXPORT_SYMBOL(dev_remove_offload); + +/** + * skb_eth_gso_segment - segmentation handler for ethernet protocols. + * @skb: buffer to segment + * @features: features for the output path (see dev->features) + * @type: Ethernet Protocol ID + */ +struct sk_buff *skb_eth_gso_segment(struct sk_buff *skb, + netdev_features_t features, __be16 type) +{ + struct sk_buff *segs = ERR_PTR(-EPROTONOSUPPORT); + struct packet_offload *ptype; + + rcu_read_lock(); + list_for_each_entry_rcu(ptype, &offload_base, list) { + if (ptype->type == type && ptype->callbacks.gso_segment) { + segs = ptype->callbacks.gso_segment(skb, features); + break; + } + } + rcu_read_unlock(); + + return segs; +} +EXPORT_SYMBOL(skb_eth_gso_segment); + +/** + * skb_mac_gso_segment - mac layer segmentation handler. + * @skb: buffer to segment + * @features: features for the output path (see dev->features) + */ +struct sk_buff *skb_mac_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EPROTONOSUPPORT); + struct packet_offload *ptype; + int vlan_depth = skb->mac_len; + __be16 type = skb_network_protocol(skb, &vlan_depth); + + if (unlikely(!type)) + return ERR_PTR(-EINVAL); + + __skb_pull(skb, vlan_depth); + + rcu_read_lock(); + list_for_each_entry_rcu(ptype, &offload_base, list) { + if (ptype->type == type && ptype->callbacks.gso_segment) { + segs = ptype->callbacks.gso_segment(skb, features); + break; + } + } + rcu_read_unlock(); + + __skb_push(skb, skb->data - skb_mac_header(skb)); + + return segs; +} +EXPORT_SYMBOL(skb_mac_gso_segment); + +int skb_gro_receive(struct sk_buff *p, struct sk_buff *skb) +{ + struct skb_shared_info *pinfo, *skbinfo = skb_shinfo(skb); + unsigned int offset = skb_gro_offset(skb); + unsigned int headlen = skb_headlen(skb); + unsigned int len = skb_gro_len(skb); + unsigned int delta_truesize; + unsigned int gro_max_size; + unsigned int new_truesize; + struct sk_buff *lp; + int segs; + + /* Do not splice page pool based packets w/ non-page pool + * packets. This can result in reference count issues as page + * pool pages will not decrement the reference count and will + * instead be immediately returned to the pool or have frag + * count decremented. 
+ */ + if (p->pp_recycle != skb->pp_recycle) + return -ETOOMANYREFS; + + /* pairs with WRITE_ONCE() in netif_set_gro_max_size() */ + gro_max_size = READ_ONCE(p->dev->gro_max_size); + + if (unlikely(p->len + len >= gro_max_size || NAPI_GRO_CB(skb)->flush)) + return -E2BIG; + + if (unlikely(p->len + len >= GRO_LEGACY_MAX_SIZE)) { + if (p->protocol != htons(ETH_P_IPV6) || + skb_headroom(p) < sizeof(struct hop_jumbo_hdr) || + ipv6_hdr(p)->nexthdr != IPPROTO_TCP || + p->encapsulation) + return -E2BIG; + } + + segs = NAPI_GRO_CB(skb)->count; + lp = NAPI_GRO_CB(p)->last; + pinfo = skb_shinfo(lp); + + if (headlen <= offset) { + skb_frag_t *frag; + skb_frag_t *frag2; + int i = skbinfo->nr_frags; + int nr_frags = pinfo->nr_frags + i; + + if (nr_frags > MAX_SKB_FRAGS) + goto merge; + + offset -= headlen; + pinfo->nr_frags = nr_frags; + skbinfo->nr_frags = 0; + + frag = pinfo->frags + nr_frags; + frag2 = skbinfo->frags + i; + do { + *--frag = *--frag2; + } while (--i); + + skb_frag_off_add(frag, offset); + skb_frag_size_sub(frag, offset); + + /* all fragments truesize : remove (head size + sk_buff) */ + new_truesize = SKB_TRUESIZE(skb_end_offset(skb)); + delta_truesize = skb->truesize - new_truesize; + + skb->truesize = new_truesize; + skb->len -= skb->data_len; + skb->data_len = 0; + + NAPI_GRO_CB(skb)->free = NAPI_GRO_FREE; + goto done; + } else if (skb->head_frag) { + int nr_frags = pinfo->nr_frags; + skb_frag_t *frag = pinfo->frags + nr_frags; + struct page *page = virt_to_head_page(skb->head); + unsigned int first_size = headlen - offset; + unsigned int first_offset; + + if (nr_frags + 1 + skbinfo->nr_frags > MAX_SKB_FRAGS) + goto merge; + + first_offset = skb->data - + (unsigned char *)page_address(page) + + offset; + + pinfo->nr_frags = nr_frags + 1 + skbinfo->nr_frags; + + __skb_frag_set_page(frag, page); + skb_frag_off_set(frag, first_offset); + skb_frag_size_set(frag, first_size); + + memcpy(frag + 1, skbinfo->frags, sizeof(*frag) * skbinfo->nr_frags); + /* We dont need to clear skbinfo->nr_frags here */ + + new_truesize = SKB_DATA_ALIGN(sizeof(struct sk_buff)); + delta_truesize = skb->truesize - new_truesize; + skb->truesize = new_truesize; + NAPI_GRO_CB(skb)->free = NAPI_GRO_FREE_STOLEN_HEAD; + goto done; + } + +merge: + /* sk owenrship - if any - completely transferred to the aggregated packet */ + skb->destructor = NULL; + delta_truesize = skb->truesize; + if (offset > headlen) { + unsigned int eat = offset - headlen; + + skb_frag_off_add(&skbinfo->frags[0], eat); + skb_frag_size_sub(&skbinfo->frags[0], eat); + skb->data_len -= eat; + skb->len -= eat; + offset = headlen; + } + + __skb_pull(skb, offset); + + if (NAPI_GRO_CB(p)->last == p) + skb_shinfo(p)->frag_list = skb; + else + NAPI_GRO_CB(p)->last->next = skb; + NAPI_GRO_CB(p)->last = skb; + __skb_header_release(skb); + lp = p; + +done: + NAPI_GRO_CB(p)->count += segs; + p->data_len += len; + p->truesize += delta_truesize; + p->len += len; + if (lp != p) { + lp->data_len += len; + lp->truesize += delta_truesize; + lp->len += len; + } + NAPI_GRO_CB(skb)->same_flow = 1; + return 0; +} + + +static void napi_gro_complete(struct napi_struct *napi, struct sk_buff *skb) +{ + struct packet_offload *ptype; + __be16 type = skb->protocol; + struct list_head *head = &offload_base; + int err = -ENOENT; + + BUILD_BUG_ON(sizeof(struct napi_gro_cb) > sizeof(skb->cb)); + + if (NAPI_GRO_CB(skb)->count == 1) { + skb_shinfo(skb)->gso_size = 0; + goto out; + } + + rcu_read_lock(); + list_for_each_entry_rcu(ptype, head, list) { + if (ptype->type != type || 
!ptype->callbacks.gro_complete) + continue; + + err = INDIRECT_CALL_INET(ptype->callbacks.gro_complete, + ipv6_gro_complete, inet_gro_complete, + skb, 0); + break; + } + rcu_read_unlock(); + + if (err) { + WARN_ON(&ptype->list == head); + kfree_skb(skb); + return; + } + +out: + gro_normal_one(napi, skb, NAPI_GRO_CB(skb)->count); +} + +static void __napi_gro_flush_chain(struct napi_struct *napi, u32 index, + bool flush_old) +{ + struct list_head *head = &napi->gro_hash[index].list; + struct sk_buff *skb, *p; + + list_for_each_entry_safe_reverse(skb, p, head, list) { + if (flush_old && NAPI_GRO_CB(skb)->age == jiffies) + return; + skb_list_del_init(skb); + napi_gro_complete(napi, skb); + napi->gro_hash[index].count--; + } + + if (!napi->gro_hash[index].count) + __clear_bit(index, &napi->gro_bitmask); +} + +/* napi->gro_hash[].list contains packets ordered by age. + * youngest packets at the head of it. + * Complete skbs in reverse order to reduce latencies. + */ +void napi_gro_flush(struct napi_struct *napi, bool flush_old) +{ + unsigned long bitmask = napi->gro_bitmask; + unsigned int i, base = ~0U; + + while ((i = ffs(bitmask)) != 0) { + bitmask >>= i; + base += i; + __napi_gro_flush_chain(napi, base, flush_old); + } +} +EXPORT_SYMBOL(napi_gro_flush); + +static void gro_list_prepare(const struct list_head *head, + const struct sk_buff *skb) +{ + unsigned int maclen = skb->dev->hard_header_len; + u32 hash = skb_get_hash_raw(skb); + struct sk_buff *p; + + list_for_each_entry(p, head, list) { + unsigned long diffs; + + NAPI_GRO_CB(p)->flush = 0; + + if (hash != skb_get_hash_raw(p)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + + diffs = (unsigned long)p->dev ^ (unsigned long)skb->dev; + diffs |= skb_vlan_tag_present(p) ^ skb_vlan_tag_present(skb); + if (skb_vlan_tag_present(p)) + diffs |= skb_vlan_tag_get(p) ^ skb_vlan_tag_get(skb); + diffs |= skb_metadata_differs(p, skb); + if (maclen == ETH_HLEN) + diffs |= compare_ether_header(skb_mac_header(p), + skb_mac_header(skb)); + else if (!diffs) + diffs = memcmp(skb_mac_header(p), + skb_mac_header(skb), + maclen); + + /* in most common scenarions 'slow_gro' is 0 + * otherwise we are already on some slower paths + * either skip all the infrequent tests altogether or + * avoid trying too hard to skip each of them individually + */ + if (!diffs && unlikely(skb->slow_gro | p->slow_gro)) { +#if IS_ENABLED(CONFIG_SKB_EXTENSIONS) && IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + struct tc_skb_ext *skb_ext; + struct tc_skb_ext *p_ext; +#endif + + diffs |= p->sk != skb->sk; + diffs |= skb_metadata_dst_cmp(p, skb); + diffs |= skb_get_nfct(p) ^ skb_get_nfct(skb); + +#if IS_ENABLED(CONFIG_SKB_EXTENSIONS) && IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + skb_ext = skb_ext_find(skb, TC_SKB_EXT); + p_ext = skb_ext_find(p, TC_SKB_EXT); + + diffs |= (!!p_ext) ^ (!!skb_ext); + if (!diffs && unlikely(skb_ext)) + diffs |= p_ext->chain ^ skb_ext->chain; +#endif + } + + NAPI_GRO_CB(p)->same_flow = !diffs; + } +} + +static inline void skb_gro_reset_offset(struct sk_buff *skb, u32 nhoff) +{ + const struct skb_shared_info *pinfo = skb_shinfo(skb); + const skb_frag_t *frag0 = &pinfo->frags[0]; + + NAPI_GRO_CB(skb)->data_offset = 0; + NAPI_GRO_CB(skb)->frag0 = NULL; + NAPI_GRO_CB(skb)->frag0_len = 0; + + if (!skb_headlen(skb) && pinfo->nr_frags && + !PageHighMem(skb_frag_page(frag0)) && + (!NET_IP_ALIGN || !((skb_frag_off(frag0) + nhoff) & 3))) { + NAPI_GRO_CB(skb)->frag0 = skb_frag_address(frag0); + NAPI_GRO_CB(skb)->frag0_len = min_t(unsigned int, + skb_frag_size(frag0), + skb->end - 
skb->tail); + } +} + +static void gro_pull_from_frag0(struct sk_buff *skb, int grow) +{ + struct skb_shared_info *pinfo = skb_shinfo(skb); + + BUG_ON(skb->end - skb->tail < grow); + + memcpy(skb_tail_pointer(skb), NAPI_GRO_CB(skb)->frag0, grow); + + skb->data_len -= grow; + skb->tail += grow; + + skb_frag_off_add(&pinfo->frags[0], grow); + skb_frag_size_sub(&pinfo->frags[0], grow); + + if (unlikely(!skb_frag_size(&pinfo->frags[0]))) { + skb_frag_unref(skb, 0); + memmove(pinfo->frags, pinfo->frags + 1, + --pinfo->nr_frags * sizeof(pinfo->frags[0])); + } +} + +static void gro_flush_oldest(struct napi_struct *napi, struct list_head *head) +{ + struct sk_buff *oldest; + + oldest = list_last_entry(head, struct sk_buff, list); + + /* We are called with head length >= MAX_GRO_SKBS, so this is + * impossible. + */ + if (WARN_ON_ONCE(!oldest)) + return; + + /* Do not adjust napi->gro_hash[].count, caller is adding a new + * SKB to the chain. + */ + skb_list_del_init(oldest); + napi_gro_complete(napi, oldest); +} + +static enum gro_result dev_gro_receive(struct napi_struct *napi, struct sk_buff *skb) +{ + u32 bucket = skb_get_hash_raw(skb) & (GRO_HASH_BUCKETS - 1); + struct gro_list *gro_list = &napi->gro_hash[bucket]; + struct list_head *head = &offload_base; + struct packet_offload *ptype; + __be16 type = skb->protocol; + struct sk_buff *pp = NULL; + enum gro_result ret; + int same_flow; + int grow; + + if (netif_elide_gro(skb->dev)) + goto normal; + + gro_list_prepare(&gro_list->list, skb); + + rcu_read_lock(); + list_for_each_entry_rcu(ptype, head, list) { + if (ptype->type == type && ptype->callbacks.gro_receive) + goto found_ptype; + } + rcu_read_unlock(); + goto normal; + +found_ptype: + skb_set_network_header(skb, skb_gro_offset(skb)); + skb_reset_mac_len(skb); + BUILD_BUG_ON(sizeof_field(struct napi_gro_cb, zeroed) != sizeof(u32)); + BUILD_BUG_ON(!IS_ALIGNED(offsetof(struct napi_gro_cb, zeroed), + sizeof(u32))); /* Avoid slow unaligned acc */ + *(u32 *)&NAPI_GRO_CB(skb)->zeroed = 0; + NAPI_GRO_CB(skb)->flush = skb_has_frag_list(skb); + NAPI_GRO_CB(skb)->is_atomic = 1; + NAPI_GRO_CB(skb)->count = 1; + if (unlikely(skb_is_gso(skb))) { + NAPI_GRO_CB(skb)->count = skb_shinfo(skb)->gso_segs; + /* Only support TCP and non DODGY users. */ + if (!skb_is_gso_tcp(skb) || + (skb_shinfo(skb)->gso_type & SKB_GSO_DODGY)) + NAPI_GRO_CB(skb)->flush = 1; + } + + /* Setup for GRO checksum validation */ + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + NAPI_GRO_CB(skb)->csum = skb->csum; + NAPI_GRO_CB(skb)->csum_valid = 1; + break; + case CHECKSUM_UNNECESSARY: + NAPI_GRO_CB(skb)->csum_cnt = skb->csum_level + 1; + break; + } + + pp = INDIRECT_CALL_INET(ptype->callbacks.gro_receive, + ipv6_gro_receive, inet_gro_receive, + &gro_list->list, skb); + + rcu_read_unlock(); + + if (PTR_ERR(pp) == -EINPROGRESS) { + ret = GRO_CONSUMED; + goto ok; + } + + same_flow = NAPI_GRO_CB(skb)->same_flow; + ret = NAPI_GRO_CB(skb)->free ? 
GRO_MERGED_FREE : GRO_MERGED; + + if (pp) { + skb_list_del_init(pp); + napi_gro_complete(napi, pp); + gro_list->count--; + } + + if (same_flow) + goto ok; + + if (NAPI_GRO_CB(skb)->flush) + goto normal; + + if (unlikely(gro_list->count >= MAX_GRO_SKBS)) + gro_flush_oldest(napi, &gro_list->list); + else + gro_list->count++; + + NAPI_GRO_CB(skb)->age = jiffies; + NAPI_GRO_CB(skb)->last = skb; + if (!skb_is_gso(skb)) + skb_shinfo(skb)->gso_size = skb_gro_len(skb); + list_add(&skb->list, &gro_list->list); + ret = GRO_HELD; + +pull: + grow = skb_gro_offset(skb) - skb_headlen(skb); + if (grow > 0) + gro_pull_from_frag0(skb, grow); +ok: + if (gro_list->count) { + if (!test_bit(bucket, &napi->gro_bitmask)) + __set_bit(bucket, &napi->gro_bitmask); + } else if (test_bit(bucket, &napi->gro_bitmask)) { + __clear_bit(bucket, &napi->gro_bitmask); + } + + return ret; + +normal: + ret = GRO_NORMAL; + goto pull; +} + +struct packet_offload *gro_find_receive_by_type(__be16 type) +{ + struct list_head *offload_head = &offload_base; + struct packet_offload *ptype; + + list_for_each_entry_rcu(ptype, offload_head, list) { + if (ptype->type != type || !ptype->callbacks.gro_receive) + continue; + return ptype; + } + return NULL; +} +EXPORT_SYMBOL(gro_find_receive_by_type); + +struct packet_offload *gro_find_complete_by_type(__be16 type) +{ + struct list_head *offload_head = &offload_base; + struct packet_offload *ptype; + + list_for_each_entry_rcu(ptype, offload_head, list) { + if (ptype->type != type || !ptype->callbacks.gro_complete) + continue; + return ptype; + } + return NULL; +} +EXPORT_SYMBOL(gro_find_complete_by_type); + +static gro_result_t napi_skb_finish(struct napi_struct *napi, + struct sk_buff *skb, + gro_result_t ret) +{ + switch (ret) { + case GRO_NORMAL: + gro_normal_one(napi, skb, 1); + break; + + case GRO_MERGED_FREE: + if (NAPI_GRO_CB(skb)->free == NAPI_GRO_FREE_STOLEN_HEAD) + napi_skb_free_stolen_head(skb); + else if (skb->fclone != SKB_FCLONE_UNAVAILABLE) + __kfree_skb(skb); + else + __kfree_skb_defer(skb); + break; + + case GRO_HELD: + case GRO_MERGED: + case GRO_CONSUMED: + break; + } + + return ret; +} + +gro_result_t napi_gro_receive(struct napi_struct *napi, struct sk_buff *skb) +{ + gro_result_t ret; + + skb_mark_napi_id(skb, napi); + trace_napi_gro_receive_entry(skb); + + skb_gro_reset_offset(skb, 0); + + ret = napi_skb_finish(napi, skb, dev_gro_receive(napi, skb)); + trace_napi_gro_receive_exit(ret); + + return ret; +} +EXPORT_SYMBOL(napi_gro_receive); + +static void napi_reuse_skb(struct napi_struct *napi, struct sk_buff *skb) +{ + if (unlikely(skb->pfmemalloc)) { + consume_skb(skb); + return; + } + __skb_pull(skb, skb_headlen(skb)); + /* restore the reserve we had after netdev_alloc_skb_ip_align() */ + skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN - skb_headroom(skb)); + __vlan_hwaccel_clear_tag(skb); + skb->dev = napi->dev; + skb->skb_iif = 0; + + /* eth_type_trans() assumes pkt_type is PACKET_HOST */ + skb->pkt_type = PACKET_HOST; + + skb->encapsulation = 0; + skb_shinfo(skb)->gso_type = 0; + skb_shinfo(skb)->gso_size = 0; + if (unlikely(skb->slow_gro)) { + skb_orphan(skb); + skb_ext_reset(skb); + nf_reset_ct(skb); + skb->slow_gro = 0; + } + + napi->skb = skb; +} + +struct sk_buff *napi_get_frags(struct napi_struct *napi) +{ + struct sk_buff *skb = napi->skb; + + if (!skb) { + skb = napi_alloc_skb(napi, GRO_MAX_HEAD); + if (skb) { + napi->skb = skb; + skb_mark_napi_id(skb, napi); + } + } + return skb; +} +EXPORT_SYMBOL(napi_get_frags); + +static gro_result_t 
napi_frags_finish(struct napi_struct *napi, + struct sk_buff *skb, + gro_result_t ret) +{ + switch (ret) { + case GRO_NORMAL: + case GRO_HELD: + __skb_push(skb, ETH_HLEN); + skb->protocol = eth_type_trans(skb, skb->dev); + if (ret == GRO_NORMAL) + gro_normal_one(napi, skb, 1); + break; + + case GRO_MERGED_FREE: + if (NAPI_GRO_CB(skb)->free == NAPI_GRO_FREE_STOLEN_HEAD) + napi_skb_free_stolen_head(skb); + else + napi_reuse_skb(napi, skb); + break; + + case GRO_MERGED: + case GRO_CONSUMED: + break; + } + + return ret; +} + +/* Upper GRO stack assumes network header starts at gro_offset=0 + * Drivers could call both napi_gro_frags() and napi_gro_receive() + * We copy ethernet header into skb->data to have a common layout. + */ +static struct sk_buff *napi_frags_skb(struct napi_struct *napi) +{ + struct sk_buff *skb = napi->skb; + const struct ethhdr *eth; + unsigned int hlen = sizeof(*eth); + + napi->skb = NULL; + + skb_reset_mac_header(skb); + skb_gro_reset_offset(skb, hlen); + + if (unlikely(skb_gro_header_hard(skb, hlen))) { + eth = skb_gro_header_slow(skb, hlen, 0); + if (unlikely(!eth)) { + net_warn_ratelimited("%s: dropping impossible skb from %s\n", + __func__, napi->dev->name); + napi_reuse_skb(napi, skb); + return NULL; + } + } else { + eth = (const struct ethhdr *)skb->data; + gro_pull_from_frag0(skb, hlen); + NAPI_GRO_CB(skb)->frag0 += hlen; + NAPI_GRO_CB(skb)->frag0_len -= hlen; + } + __skb_pull(skb, hlen); + + /* + * This works because the only protocols we care about don't require + * special handling. + * We'll fix it up properly in napi_frags_finish() + */ + skb->protocol = eth->h_proto; + + return skb; +} + +gro_result_t napi_gro_frags(struct napi_struct *napi) +{ + gro_result_t ret; + struct sk_buff *skb = napi_frags_skb(napi); + + trace_napi_gro_frags_entry(skb); + + ret = napi_frags_finish(napi, skb, dev_gro_receive(napi, skb)); + trace_napi_gro_frags_exit(ret); + + return ret; +} +EXPORT_SYMBOL(napi_gro_frags); + +/* Compute the checksum from gro_offset and return the folded value + * after adding in any pseudo checksum. + */ +__sum16 __skb_gro_checksum_complete(struct sk_buff *skb) +{ + __wsum wsum; + __sum16 sum; + + wsum = skb_checksum(skb, skb_gro_offset(skb), skb_gro_len(skb), 0); + + /* NAPI_GRO_CB(skb)->csum holds pseudo checksum */ + sum = csum_fold(csum_add(NAPI_GRO_CB(skb)->csum, wsum)); + /* See comments in __skb_checksum_complete(). 
*/ + if (likely(!sum)) { + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev, skb); + } + + NAPI_GRO_CB(skb)->csum = wsum; + NAPI_GRO_CB(skb)->csum_valid = 1; + + return sum; +} +EXPORT_SYMBOL(__skb_gro_checksum_complete); diff --git a/net/core/gro_cells.c b/net/core/gro_cells.c new file mode 100644 index 000000000..ed5ec5de4 --- /dev/null +++ b/net/core/gro_cells.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +struct gro_cell { + struct sk_buff_head napi_skbs; + struct napi_struct napi; +}; + +int gro_cells_receive(struct gro_cells *gcells, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + struct gro_cell *cell; + int res; + + rcu_read_lock(); + if (unlikely(!(dev->flags & IFF_UP))) + goto drop; + + if (!gcells->cells || skb_cloned(skb) || netif_elide_gro(dev)) { + res = netif_rx(skb); + goto unlock; + } + + cell = this_cpu_ptr(gcells->cells); + + if (skb_queue_len(&cell->napi_skbs) > READ_ONCE(netdev_max_backlog)) { +drop: + dev_core_stats_rx_dropped_inc(dev); + kfree_skb(skb); + res = NET_RX_DROP; + goto unlock; + } + + __skb_queue_tail(&cell->napi_skbs, skb); + if (skb_queue_len(&cell->napi_skbs) == 1) + napi_schedule(&cell->napi); + + res = NET_RX_SUCCESS; + +unlock: + rcu_read_unlock(); + return res; +} +EXPORT_SYMBOL(gro_cells_receive); + +/* called under BH context */ +static int gro_cell_poll(struct napi_struct *napi, int budget) +{ + struct gro_cell *cell = container_of(napi, struct gro_cell, napi); + struct sk_buff *skb; + int work_done = 0; + + while (work_done < budget) { + skb = __skb_dequeue(&cell->napi_skbs); + if (!skb) + break; + napi_gro_receive(napi, skb); + work_done++; + } + + if (work_done < budget) + napi_complete_done(napi, work_done); + return work_done; +} + +int gro_cells_init(struct gro_cells *gcells, struct net_device *dev) +{ + int i; + + gcells->cells = alloc_percpu(struct gro_cell); + if (!gcells->cells) + return -ENOMEM; + + for_each_possible_cpu(i) { + struct gro_cell *cell = per_cpu_ptr(gcells->cells, i); + + __skb_queue_head_init(&cell->napi_skbs); + + set_bit(NAPI_STATE_NO_BUSY_POLL, &cell->napi.state); + + netif_napi_add(dev, &cell->napi, gro_cell_poll); + napi_enable(&cell->napi); + } + return 0; +} +EXPORT_SYMBOL(gro_cells_init); + +struct percpu_free_defer { + struct rcu_head rcu; + void __percpu *ptr; +}; + +static void percpu_free_defer_callback(struct rcu_head *head) +{ + struct percpu_free_defer *defer; + + defer = container_of(head, struct percpu_free_defer, rcu); + free_percpu(defer->ptr); + kfree(defer); +} + +void gro_cells_destroy(struct gro_cells *gcells) +{ + struct percpu_free_defer *defer; + int i; + + if (!gcells->cells) + return; + for_each_possible_cpu(i) { + struct gro_cell *cell = per_cpu_ptr(gcells->cells, i); + + napi_disable(&cell->napi); + __netif_napi_del(&cell->napi); + __skb_queue_purge(&cell->napi_skbs); + } + /* We need to observe an rcu grace period before freeing ->cells, + * because netpoll could access dev->napi_list under rcu protection. + * Try hard using call_rcu() instead of synchronize_rcu(), + * because we might be called from cleanup_net(), and we + * definitely do not want to block this critical task. + */ + defer = kmalloc(sizeof(*defer), GFP_KERNEL | __GFP_NOWARN); + if (likely(defer)) { + defer->ptr = gcells->cells; + call_rcu(&defer->rcu, percpu_free_defer_callback); + } else { + /* We do not hold RTNL at this point, synchronize_net() + * would not be able to expedite this sync. 
+ */ + synchronize_rcu_expedited(); + free_percpu(gcells->cells); + } + gcells->cells = NULL; +} +EXPORT_SYMBOL(gro_cells_destroy); diff --git a/net/core/hwbm.c b/net/core/hwbm.c new file mode 100644 index 000000000..ac1a66df9 --- /dev/null +++ b/net/core/hwbm.c @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Support for hardware buffer manager. + * + * Copyright (C) 2016 Marvell + * + * Gregory CLEMENT + */ +#include +#include +#include +#include + +void hwbm_buf_free(struct hwbm_pool *bm_pool, void *buf) +{ + if (likely(bm_pool->frag_size <= PAGE_SIZE)) + skb_free_frag(buf); + else + kfree(buf); +} +EXPORT_SYMBOL_GPL(hwbm_buf_free); + +/* Refill processing for HW buffer management */ +int hwbm_pool_refill(struct hwbm_pool *bm_pool, gfp_t gfp) +{ + int frag_size = bm_pool->frag_size; + void *buf; + + if (likely(frag_size <= PAGE_SIZE)) + buf = netdev_alloc_frag(frag_size); + else + buf = kmalloc(frag_size, gfp); + + if (!buf) + return -ENOMEM; + + if (bm_pool->construct) + if (bm_pool->construct(bm_pool, buf)) { + hwbm_buf_free(bm_pool, buf); + return -ENOMEM; + } + + return 0; +} +EXPORT_SYMBOL_GPL(hwbm_pool_refill); + +int hwbm_pool_add(struct hwbm_pool *bm_pool, unsigned int buf_num) +{ + int err, i; + + mutex_lock(&bm_pool->buf_lock); + if (bm_pool->buf_num == bm_pool->size) { + pr_warn("pool already filled\n"); + mutex_unlock(&bm_pool->buf_lock); + return bm_pool->buf_num; + } + + if (buf_num + bm_pool->buf_num > bm_pool->size) { + pr_warn("cannot allocate %d buffers for pool\n", + buf_num); + mutex_unlock(&bm_pool->buf_lock); + return 0; + } + + if ((buf_num + bm_pool->buf_num) < bm_pool->buf_num) { + pr_warn("Adding %d buffers to the %d current buffers will overflow\n", + buf_num, bm_pool->buf_num); + mutex_unlock(&bm_pool->buf_lock); + return 0; + } + + for (i = 0; i < buf_num; i++) { + err = hwbm_pool_refill(bm_pool, GFP_KERNEL); + if (err < 0) + break; + } + + /* Update BM driver with number of buffers added to pool */ + bm_pool->buf_num += i; + + pr_debug("hwpm pool: %d of %d buffers added\n", i, buf_num); + mutex_unlock(&bm_pool->buf_lock); + + return i; +} +EXPORT_SYMBOL_GPL(hwbm_pool_add); diff --git a/net/core/link_watch.c b/net/core/link_watch.c new file mode 100644 index 000000000..aa6cb1f90 --- /dev/null +++ b/net/core/link_watch.c @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux network device link state notification + * + * Author: + * Stefan Rompf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dev.h" + +enum lw_bits { + LW_URGENT = 0, +}; + +static unsigned long linkwatch_flags; +static unsigned long linkwatch_nextevent; + +static void linkwatch_event(struct work_struct *dummy); +static DECLARE_DELAYED_WORK(linkwatch_work, linkwatch_event); + +static LIST_HEAD(lweventlist); +static DEFINE_SPINLOCK(lweventlist_lock); + +static unsigned char default_operstate(const struct net_device *dev) +{ + if (netif_testing(dev)) + return IF_OPER_TESTING; + + if (!netif_carrier_ok(dev)) + return (dev->ifindex != dev_get_iflink(dev) ? 
+ IF_OPER_LOWERLAYERDOWN : IF_OPER_DOWN); + + if (netif_dormant(dev)) + return IF_OPER_DORMANT; + + return IF_OPER_UP; +} + + +static void rfc2863_policy(struct net_device *dev) +{ + unsigned char operstate = default_operstate(dev); + + if (operstate == dev->operstate) + return; + + write_lock(&dev_base_lock); + + switch(dev->link_mode) { + case IF_LINK_MODE_TESTING: + if (operstate == IF_OPER_UP) + operstate = IF_OPER_TESTING; + break; + + case IF_LINK_MODE_DORMANT: + if (operstate == IF_OPER_UP) + operstate = IF_OPER_DORMANT; + break; + case IF_LINK_MODE_DEFAULT: + default: + break; + } + + dev->operstate = operstate; + + write_unlock(&dev_base_lock); +} + + +void linkwatch_init_dev(struct net_device *dev) +{ + /* Handle pre-registration link state changes */ + if (!netif_carrier_ok(dev) || netif_dormant(dev) || + netif_testing(dev)) + rfc2863_policy(dev); +} + + +static bool linkwatch_urgent_event(struct net_device *dev) +{ + if (!netif_running(dev)) + return false; + + if (dev->ifindex != dev_get_iflink(dev)) + return true; + + if (netif_is_lag_port(dev) || netif_is_lag_master(dev)) + return true; + + return netif_carrier_ok(dev) && qdisc_tx_changing(dev); +} + + +static void linkwatch_add_event(struct net_device *dev) +{ + unsigned long flags; + + spin_lock_irqsave(&lweventlist_lock, flags); + if (list_empty(&dev->link_watch_list)) { + list_add_tail(&dev->link_watch_list, &lweventlist); + netdev_hold(dev, &dev->linkwatch_dev_tracker, GFP_ATOMIC); + } + spin_unlock_irqrestore(&lweventlist_lock, flags); +} + + +static void linkwatch_schedule_work(int urgent) +{ + unsigned long delay = linkwatch_nextevent - jiffies; + + if (test_bit(LW_URGENT, &linkwatch_flags)) + return; + + /* Minimise down-time: drop delay for up event. */ + if (urgent) { + if (test_and_set_bit(LW_URGENT, &linkwatch_flags)) + return; + delay = 0; + } + + /* If we wrap around we'll delay it by at most HZ. */ + if (delay > HZ) + delay = 0; + + /* + * If urgent, schedule immediate execution; otherwise, don't + * override the existing timer. + */ + if (test_bit(LW_URGENT, &linkwatch_flags)) + mod_delayed_work(system_wq, &linkwatch_work, 0); + else + schedule_delayed_work(&linkwatch_work, delay); +} + + +static void linkwatch_do_dev(struct net_device *dev) +{ + /* + * Make sure the above read is complete since it can be + * rewritten as soon as we clear the bit below. + */ + smp_mb__before_atomic(); + + /* We are about to handle this device, + * so new events can be accepted + */ + clear_bit(__LINK_STATE_LINKWATCH_PENDING, &dev->state); + + rfc2863_policy(dev); + if (dev->flags & IFF_UP) { + if (netif_carrier_ok(dev)) + dev_activate(dev); + else + dev_deactivate(dev); + + netdev_state_change(dev); + } + /* Note: our callers are responsible for calling netdev_tracker_free(). + * This is the reason we use __dev_put() instead of dev_put(). + */ + __dev_put(dev); +} + +static void __linkwatch_run_queue(int urgent_only) +{ +#define MAX_DO_DEV_PER_LOOP 100 + + int do_dev = MAX_DO_DEV_PER_LOOP; + struct net_device *dev; + LIST_HEAD(wrk); + + /* Give urgent case more budget */ + if (urgent_only) + do_dev += MAX_DO_DEV_PER_LOOP; + + /* + * Limit the number of linkwatch events to one + * per second so that a runaway driver does not + * cause a storm of messages on the netlink + * socket. This limit does not apply to up events + * while the device qdisc is down. + */ + if (!urgent_only) + linkwatch_nextevent = jiffies + HZ; + /* Limit wrap-around effect on delay. 
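+ * If the next event is more than HZ in the future, pull it back to now.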
*/ + else if (time_after(linkwatch_nextevent, jiffies + HZ)) + linkwatch_nextevent = jiffies; + + clear_bit(LW_URGENT, &linkwatch_flags); + + spin_lock_irq(&lweventlist_lock); + list_splice_init(&lweventlist, &wrk); + + while (!list_empty(&wrk) && do_dev > 0) { + + dev = list_first_entry(&wrk, struct net_device, link_watch_list); + list_del_init(&dev->link_watch_list); + + if (!netif_device_present(dev) || + (urgent_only && !linkwatch_urgent_event(dev))) { + list_add_tail(&dev->link_watch_list, &lweventlist); + continue; + } + /* We must free netdev tracker under + * the spinlock protection. + */ + netdev_tracker_free(dev, &dev->linkwatch_dev_tracker); + spin_unlock_irq(&lweventlist_lock); + linkwatch_do_dev(dev); + do_dev--; + spin_lock_irq(&lweventlist_lock); + } + + /* Add the remaining work back to lweventlist */ + list_splice_init(&wrk, &lweventlist); + + if (!list_empty(&lweventlist)) + linkwatch_schedule_work(0); + spin_unlock_irq(&lweventlist_lock); +} + +void linkwatch_forget_dev(struct net_device *dev) +{ + unsigned long flags; + int clean = 0; + + spin_lock_irqsave(&lweventlist_lock, flags); + if (!list_empty(&dev->link_watch_list)) { + list_del_init(&dev->link_watch_list); + clean = 1; + /* We must release netdev tracker under + * the spinlock protection. + */ + netdev_tracker_free(dev, &dev->linkwatch_dev_tracker); + } + spin_unlock_irqrestore(&lweventlist_lock, flags); + if (clean) + linkwatch_do_dev(dev); +} + + +/* Must be called with the rtnl semaphore held */ +void linkwatch_run_queue(void) +{ + __linkwatch_run_queue(0); +} + + +static void linkwatch_event(struct work_struct *dummy) +{ + rtnl_lock(); + __linkwatch_run_queue(time_after(linkwatch_nextevent, jiffies)); + rtnl_unlock(); +} + + +void linkwatch_fire_event(struct net_device *dev) +{ + bool urgent = linkwatch_urgent_event(dev); + + if (!test_and_set_bit(__LINK_STATE_LINKWATCH_PENDING, &dev->state)) { + linkwatch_add_event(dev); + } else if (!urgent) + return; + + linkwatch_schedule_work(urgent); +} +EXPORT_SYMBOL(linkwatch_fire_event); diff --git a/net/core/lwt_bpf.c b/net/core/lwt_bpf.c new file mode 100644 index 000000000..4a0797f0a --- /dev/null +++ b/net/core/lwt_bpf.c @@ -0,0 +1,657 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (c) 2016 Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct bpf_lwt_prog { + struct bpf_prog *prog; + char *name; +}; + +struct bpf_lwt { + struct bpf_lwt_prog in; + struct bpf_lwt_prog out; + struct bpf_lwt_prog xmit; + int family; +}; + +#define MAX_PROG_NAME 256 + +static inline struct bpf_lwt *bpf_lwt_lwtunnel(struct lwtunnel_state *lwt) +{ + return (struct bpf_lwt *)lwt->data; +} + +#define NO_REDIRECT false +#define CAN_REDIRECT true + +static int run_lwt_bpf(struct sk_buff *skb, struct bpf_lwt_prog *lwt, + struct dst_entry *dst, bool can_redirect) +{ + int ret; + + /* Migration disable and BH disable are needed to protect per-cpu + * redirect_info between BPF prog and skb_do_redirect(). + */ + migrate_disable(); + local_bh_disable(); + bpf_compute_data_pointers(skb); + ret = bpf_prog_run_save_cb(lwt->prog, skb); + + switch (ret) { + case BPF_OK: + case BPF_LWT_REROUTE: + break; + + case BPF_REDIRECT: + if (unlikely(!can_redirect)) { + pr_warn_once("Illegal redirect return code in prog %s\n", + lwt->name ? 
: ""); + ret = BPF_OK; + } else { + skb_reset_mac_header(skb); + skb_do_redirect(skb); + ret = BPF_REDIRECT; + } + break; + + case BPF_DROP: + kfree_skb(skb); + ret = -EPERM; + break; + + default: + pr_warn_once("bpf-lwt: Illegal return value %u, expect packet loss\n", ret); + kfree_skb(skb); + ret = -EINVAL; + break; + } + + local_bh_enable(); + migrate_enable(); + + return ret; +} + +static int bpf_lwt_input_reroute(struct sk_buff *skb) +{ + int err = -EINVAL; + + if (skb->protocol == htons(ETH_P_IP)) { + struct net_device *dev = skb_dst(skb)->dev; + struct iphdr *iph = ip_hdr(skb); + + dev_hold(dev); + skb_dst_drop(skb); + err = ip_route_input_noref(skb, iph->daddr, iph->saddr, + iph->tos, dev); + dev_put(dev); + } else if (skb->protocol == htons(ETH_P_IPV6)) { + skb_dst_drop(skb); + err = ipv6_stub->ipv6_route_input(skb); + } else { + err = -EAFNOSUPPORT; + } + + if (err) + goto err; + return dst_input(skb); + +err: + kfree_skb(skb); + return err; +} + +static int bpf_input(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct bpf_lwt *bpf; + int ret; + + bpf = bpf_lwt_lwtunnel(dst->lwtstate); + if (bpf->in.prog) { + ret = run_lwt_bpf(skb, &bpf->in, dst, NO_REDIRECT); + if (ret < 0) + return ret; + if (ret == BPF_LWT_REROUTE) + return bpf_lwt_input_reroute(skb); + } + + if (unlikely(!dst->lwtstate->orig_input)) { + kfree_skb(skb); + return -EINVAL; + } + + return dst->lwtstate->orig_input(skb); +} + +static int bpf_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct bpf_lwt *bpf; + int ret; + + bpf = bpf_lwt_lwtunnel(dst->lwtstate); + if (bpf->out.prog) { + ret = run_lwt_bpf(skb, &bpf->out, dst, NO_REDIRECT); + if (ret < 0) + return ret; + } + + if (unlikely(!dst->lwtstate->orig_output)) { + pr_warn_once("orig_output not set on dst for prog %s\n", + bpf->out.name); + kfree_skb(skb); + return -EINVAL; + } + + return dst->lwtstate->orig_output(net, sk, skb); +} + +static int xmit_check_hhlen(struct sk_buff *skb, int hh_len) +{ + if (skb_headroom(skb) < hh_len) { + int nhead = HH_DATA_ALIGN(hh_len - skb_headroom(skb)); + + if (pskb_expand_head(skb, nhead, 0, GFP_ATOMIC)) + return -ENOMEM; + } + + return 0; +} + +static int bpf_lwt_xmit_reroute(struct sk_buff *skb) +{ + struct net_device *l3mdev = l3mdev_master_dev_rcu(skb_dst(skb)->dev); + int oif = l3mdev ? 
l3mdev->ifindex : 0; + struct dst_entry *dst = NULL; + int err = -EAFNOSUPPORT; + struct sock *sk; + struct net *net; + bool ipv4; + + if (skb->protocol == htons(ETH_P_IP)) + ipv4 = true; + else if (skb->protocol == htons(ETH_P_IPV6)) + ipv4 = false; + else + goto err; + + sk = sk_to_full_sk(skb->sk); + if (sk) { + if (sk->sk_bound_dev_if) + oif = sk->sk_bound_dev_if; + net = sock_net(sk); + } else { + net = dev_net(skb_dst(skb)->dev); + } + + if (ipv4) { + struct iphdr *iph = ip_hdr(skb); + struct flowi4 fl4 = {}; + struct rtable *rt; + + fl4.flowi4_oif = oif; + fl4.flowi4_mark = skb->mark; + fl4.flowi4_uid = sock_net_uid(net, sk); + fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_flags = FLOWI_FLAG_ANYSRC; + fl4.flowi4_proto = iph->protocol; + fl4.daddr = iph->daddr; + fl4.saddr = iph->saddr; + + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + goto err; + } + dst = &rt->dst; + } else { + struct ipv6hdr *iph6 = ipv6_hdr(skb); + struct flowi6 fl6 = {}; + + fl6.flowi6_oif = oif; + fl6.flowi6_mark = skb->mark; + fl6.flowi6_uid = sock_net_uid(net, sk); + fl6.flowlabel = ip6_flowinfo(iph6); + fl6.flowi6_proto = iph6->nexthdr; + fl6.daddr = iph6->daddr; + fl6.saddr = iph6->saddr; + + dst = ipv6_stub->ipv6_dst_lookup_flow(net, skb->sk, &fl6, NULL); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto err; + } + } + if (unlikely(dst->error)) { + err = dst->error; + dst_release(dst); + goto err; + } + + /* Although skb header was reserved in bpf_lwt_push_ip_encap(), it + * was done for the previous dst, so we are doing it here again, in + * case the new dst needs much more space. The call below is a noop + * if there is enough header space in skb. + */ + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + goto err; + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + err = dst_output(dev_net(skb_dst(skb)->dev), skb->sk, skb); + if (unlikely(err)) + return net_xmit_errno(err); + + /* ip[6]_finish_output2 understand LWTUNNEL_XMIT_DONE */ + return LWTUNNEL_XMIT_DONE; + +err: + kfree_skb(skb); + return err; +} + +static int bpf_xmit(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct bpf_lwt *bpf; + + bpf = bpf_lwt_lwtunnel(dst->lwtstate); + if (bpf->xmit.prog) { + int hh_len = dst->dev->hard_header_len; + __be16 proto = skb->protocol; + int ret; + + ret = run_lwt_bpf(skb, &bpf->xmit, dst, CAN_REDIRECT); + switch (ret) { + case BPF_OK: + /* If the header changed, e.g. via bpf_lwt_push_encap, + * BPF_LWT_REROUTE below should have been used if the + * protocol was also changed. + */ + if (skb->protocol != proto) { + kfree_skb(skb); + return -EINVAL; + } + /* If the header was expanded, headroom might be too + * small for L2 header to come, expand as needed. 
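+ * xmit_check_hhlen() below grows the headroom with pskb_expand_head() when it is short.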
+ */ + ret = xmit_check_hhlen(skb, hh_len); + if (unlikely(ret)) + return ret; + + return LWTUNNEL_XMIT_CONTINUE; + case BPF_REDIRECT: + return LWTUNNEL_XMIT_DONE; + case BPF_LWT_REROUTE: + return bpf_lwt_xmit_reroute(skb); + default: + return ret; + } + } + + return LWTUNNEL_XMIT_CONTINUE; +} + +static void bpf_lwt_prog_destroy(struct bpf_lwt_prog *prog) +{ + if (prog->prog) + bpf_prog_put(prog->prog); + + kfree(prog->name); +} + +static void bpf_destroy_state(struct lwtunnel_state *lwt) +{ + struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt); + + bpf_lwt_prog_destroy(&bpf->in); + bpf_lwt_prog_destroy(&bpf->out); + bpf_lwt_prog_destroy(&bpf->xmit); +} + +static const struct nla_policy bpf_prog_policy[LWT_BPF_PROG_MAX + 1] = { + [LWT_BPF_PROG_FD] = { .type = NLA_U32, }, + [LWT_BPF_PROG_NAME] = { .type = NLA_NUL_STRING, + .len = MAX_PROG_NAME }, +}; + +static int bpf_parse_prog(struct nlattr *attr, struct bpf_lwt_prog *prog, + enum bpf_prog_type type) +{ + struct nlattr *tb[LWT_BPF_PROG_MAX + 1]; + struct bpf_prog *p; + int ret; + u32 fd; + + ret = nla_parse_nested_deprecated(tb, LWT_BPF_PROG_MAX, attr, + bpf_prog_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[LWT_BPF_PROG_FD] || !tb[LWT_BPF_PROG_NAME]) + return -EINVAL; + + prog->name = nla_memdup(tb[LWT_BPF_PROG_NAME], GFP_ATOMIC); + if (!prog->name) + return -ENOMEM; + + fd = nla_get_u32(tb[LWT_BPF_PROG_FD]); + p = bpf_prog_get_type(fd, type); + if (IS_ERR(p)) + return PTR_ERR(p); + + prog->prog = p; + + return 0; +} + +static const struct nla_policy bpf_nl_policy[LWT_BPF_MAX + 1] = { + [LWT_BPF_IN] = { .type = NLA_NESTED, }, + [LWT_BPF_OUT] = { .type = NLA_NESTED, }, + [LWT_BPF_XMIT] = { .type = NLA_NESTED, }, + [LWT_BPF_XMIT_HEADROOM] = { .type = NLA_U32 }, +}; + +static int bpf_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWT_BPF_MAX + 1]; + struct lwtunnel_state *newts; + struct bpf_lwt *bpf; + int ret; + + if (family != AF_INET && family != AF_INET6) + return -EAFNOSUPPORT; + + ret = nla_parse_nested_deprecated(tb, LWT_BPF_MAX, nla, bpf_nl_policy, + extack); + if (ret < 0) + return ret; + + if (!tb[LWT_BPF_IN] && !tb[LWT_BPF_OUT] && !tb[LWT_BPF_XMIT]) + return -EINVAL; + + newts = lwtunnel_state_alloc(sizeof(*bpf)); + if (!newts) + return -ENOMEM; + + newts->type = LWTUNNEL_ENCAP_BPF; + bpf = bpf_lwt_lwtunnel(newts); + + if (tb[LWT_BPF_IN]) { + newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT; + ret = bpf_parse_prog(tb[LWT_BPF_IN], &bpf->in, + BPF_PROG_TYPE_LWT_IN); + if (ret < 0) + goto errout; + } + + if (tb[LWT_BPF_OUT]) { + newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT; + ret = bpf_parse_prog(tb[LWT_BPF_OUT], &bpf->out, + BPF_PROG_TYPE_LWT_OUT); + if (ret < 0) + goto errout; + } + + if (tb[LWT_BPF_XMIT]) { + newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT; + ret = bpf_parse_prog(tb[LWT_BPF_XMIT], &bpf->xmit, + BPF_PROG_TYPE_LWT_XMIT); + if (ret < 0) + goto errout; + } + + if (tb[LWT_BPF_XMIT_HEADROOM]) { + u32 headroom = nla_get_u32(tb[LWT_BPF_XMIT_HEADROOM]); + + if (headroom > LWT_BPF_MAX_HEADROOM) { + ret = -ERANGE; + goto errout; + } + + newts->headroom = headroom; + } + + bpf->family = family; + *ts = newts; + + return 0; + +errout: + bpf_destroy_state(newts); + kfree(newts); + return ret; +} + +static int bpf_fill_lwt_prog(struct sk_buff *skb, int attr, + struct bpf_lwt_prog *prog) +{ + struct nlattr *nest; + + if (!prog->prog) + return 0; + + nest = nla_nest_start_noflag(skb, attr); + if (!nest) 
+ return -EMSGSIZE; + + if (prog->name && + nla_put_string(skb, LWT_BPF_PROG_NAME, prog->name)) + return -EMSGSIZE; + + return nla_nest_end(skb, nest); +} + +static int bpf_fill_encap_info(struct sk_buff *skb, struct lwtunnel_state *lwt) +{ + struct bpf_lwt *bpf = bpf_lwt_lwtunnel(lwt); + + if (bpf_fill_lwt_prog(skb, LWT_BPF_IN, &bpf->in) < 0 || + bpf_fill_lwt_prog(skb, LWT_BPF_OUT, &bpf->out) < 0 || + bpf_fill_lwt_prog(skb, LWT_BPF_XMIT, &bpf->xmit) < 0) + return -EMSGSIZE; + + return 0; +} + +static int bpf_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + int nest_len = nla_total_size(sizeof(struct nlattr)) + + nla_total_size(MAX_PROG_NAME) + /* LWT_BPF_PROG_NAME */ + 0; + + return nest_len + /* LWT_BPF_IN */ + nest_len + /* LWT_BPF_OUT */ + nest_len + /* LWT_BPF_XMIT */ + 0; +} + +static int bpf_lwt_prog_cmp(struct bpf_lwt_prog *a, struct bpf_lwt_prog *b) +{ + /* FIXME: + * The LWT state is currently rebuilt for delete requests which + * results in a new bpf_prog instance. Comparing names for now. + */ + if (!a->name && !b->name) + return 0; + + if (!a->name || !b->name) + return 1; + + return strcmp(a->name, b->name); +} + +static int bpf_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct bpf_lwt *a_bpf = bpf_lwt_lwtunnel(a); + struct bpf_lwt *b_bpf = bpf_lwt_lwtunnel(b); + + return bpf_lwt_prog_cmp(&a_bpf->in, &b_bpf->in) || + bpf_lwt_prog_cmp(&a_bpf->out, &b_bpf->out) || + bpf_lwt_prog_cmp(&a_bpf->xmit, &b_bpf->xmit); +} + +static const struct lwtunnel_encap_ops bpf_encap_ops = { + .build_state = bpf_build_state, + .destroy_state = bpf_destroy_state, + .input = bpf_input, + .output = bpf_output, + .xmit = bpf_xmit, + .fill_encap = bpf_fill_encap_info, + .get_encap_size = bpf_encap_nlsize, + .cmp_encap = bpf_encap_cmp, + .owner = THIS_MODULE, +}; + +static int handle_gso_type(struct sk_buff *skb, unsigned int gso_type, + int encap_len) +{ + struct skb_shared_info *shinfo = skb_shinfo(skb); + + gso_type |= SKB_GSO_DODGY; + shinfo->gso_type |= gso_type; + skb_decrease_gso_size(shinfo, encap_len); + shinfo->gso_segs = 0; + return 0; +} + +static int handle_gso_encap(struct sk_buff *skb, bool ipv4, int encap_len) +{ + int next_hdr_offset; + void *next_hdr; + __u8 protocol; + + /* SCTP and UDP_L4 gso need more nuanced handling than what + * handle_gso_type() does above: skb_decrease_gso_size() is not enough. + * So at the moment only TCP GSO packets are let through. 
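+ * Everything else is rejected with -ENOTSUPP.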
+ */ + if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6))) + return -ENOTSUPP; + + if (ipv4) { + protocol = ip_hdr(skb)->protocol; + next_hdr_offset = sizeof(struct iphdr); + next_hdr = skb_network_header(skb) + next_hdr_offset; + } else { + protocol = ipv6_hdr(skb)->nexthdr; + next_hdr_offset = sizeof(struct ipv6hdr); + next_hdr = skb_network_header(skb) + next_hdr_offset; + } + + switch (protocol) { + case IPPROTO_GRE: + next_hdr_offset += sizeof(struct gre_base_hdr); + if (next_hdr_offset > encap_len) + return -EINVAL; + + if (((struct gre_base_hdr *)next_hdr)->flags & GRE_CSUM) + return handle_gso_type(skb, SKB_GSO_GRE_CSUM, + encap_len); + return handle_gso_type(skb, SKB_GSO_GRE, encap_len); + + case IPPROTO_UDP: + next_hdr_offset += sizeof(struct udphdr); + if (next_hdr_offset > encap_len) + return -EINVAL; + + if (((struct udphdr *)next_hdr)->check) + return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL_CSUM, + encap_len); + return handle_gso_type(skb, SKB_GSO_UDP_TUNNEL, encap_len); + + case IPPROTO_IP: + case IPPROTO_IPV6: + if (ipv4) + return handle_gso_type(skb, SKB_GSO_IPXIP4, encap_len); + else + return handle_gso_type(skb, SKB_GSO_IPXIP6, encap_len); + + default: + return -EPROTONOSUPPORT; + } +} + +int bpf_lwt_push_ip_encap(struct sk_buff *skb, void *hdr, u32 len, bool ingress) +{ + struct iphdr *iph; + bool ipv4; + int err; + + if (unlikely(len < sizeof(struct iphdr) || len > LWT_BPF_MAX_HEADROOM)) + return -EINVAL; + + /* validate protocol and length */ + iph = (struct iphdr *)hdr; + if (iph->version == 4) { + ipv4 = true; + if (unlikely(len < iph->ihl * 4)) + return -EINVAL; + } else if (iph->version == 6) { + ipv4 = false; + if (unlikely(len < sizeof(struct ipv6hdr))) + return -EINVAL; + } else { + return -EINVAL; + } + + if (ingress) + err = skb_cow_head(skb, len + skb->mac_len); + else + err = skb_cow_head(skb, + len + LL_RESERVED_SPACE(skb_dst(skb)->dev)); + if (unlikely(err)) + return err; + + /* push the encap headers and fix pointers */ + skb_reset_inner_headers(skb); + skb_reset_inner_mac_header(skb); /* mac header is not yet set */ + skb_set_inner_protocol(skb, skb->protocol); + skb->encapsulation = 1; + skb_push(skb, len); + if (ingress) + skb_postpush_rcsum(skb, iph, len); + skb_reset_network_header(skb); + memcpy(skb_network_header(skb), hdr, len); + bpf_compute_data_pointers(skb); + skb_clear_hash(skb); + + if (ipv4) { + skb->protocol = htons(ETH_P_IP); + iph = ip_hdr(skb); + + if (!iph->check) + iph->check = ip_fast_csum((unsigned char *)iph, + iph->ihl); + } else { + skb->protocol = htons(ETH_P_IPV6); + } + + if (skb_is_gso(skb)) + return handle_gso_encap(skb, ipv4, len); + + return 0; +} + +static int __init bpf_lwt_init(void) +{ + return lwtunnel_encap_add_ops(&bpf_encap_ops, LWTUNNEL_ENCAP_BPF); +} + +subsys_initcall(bpf_lwt_init) diff --git a/net/core/lwtunnel.c b/net/core/lwtunnel.c new file mode 100644 index 000000000..711cd3b43 --- /dev/null +++ b/net/core/lwtunnel.c @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * lwtunnel Infrastructure for light weight tunnels like mpls + * + * Authors: Roopa Prabhu, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +DEFINE_STATIC_KEY_FALSE(nf_hooks_lwtunnel_enabled); +EXPORT_SYMBOL_GPL(nf_hooks_lwtunnel_enabled); + +#ifdef CONFIG_MODULES + +static const char *lwtunnel_encap_str(enum lwtunnel_encap_types encap_type) +{ + /* Only lwt encaps implemented without 
using an interface for + * the encap need to return a string here. + */ + switch (encap_type) { + case LWTUNNEL_ENCAP_MPLS: + return "MPLS"; + case LWTUNNEL_ENCAP_ILA: + return "ILA"; + case LWTUNNEL_ENCAP_SEG6: + return "SEG6"; + case LWTUNNEL_ENCAP_BPF: + return "BPF"; + case LWTUNNEL_ENCAP_SEG6_LOCAL: + return "SEG6LOCAL"; + case LWTUNNEL_ENCAP_RPL: + return "RPL"; + case LWTUNNEL_ENCAP_IOAM6: + return "IOAM6"; + case LWTUNNEL_ENCAP_XFRM: + /* module autoload not supported for encap type */ + return NULL; + case LWTUNNEL_ENCAP_IP6: + case LWTUNNEL_ENCAP_IP: + case LWTUNNEL_ENCAP_NONE: + case __LWTUNNEL_ENCAP_MAX: + /* should not have got here */ + WARN_ON(1); + break; + } + return NULL; +} + +#endif /* CONFIG_MODULES */ + +struct lwtunnel_state *lwtunnel_state_alloc(int encap_len) +{ + struct lwtunnel_state *lws; + + lws = kzalloc(sizeof(*lws) + encap_len, GFP_ATOMIC); + + return lws; +} +EXPORT_SYMBOL_GPL(lwtunnel_state_alloc); + +static const struct lwtunnel_encap_ops __rcu * + lwtun_encaps[LWTUNNEL_ENCAP_MAX + 1] __read_mostly; + +int lwtunnel_encap_add_ops(const struct lwtunnel_encap_ops *ops, + unsigned int num) +{ + if (num > LWTUNNEL_ENCAP_MAX) + return -ERANGE; + + return !cmpxchg((const struct lwtunnel_encap_ops **) + &lwtun_encaps[num], + NULL, ops) ? 0 : -1; +} +EXPORT_SYMBOL_GPL(lwtunnel_encap_add_ops); + +int lwtunnel_encap_del_ops(const struct lwtunnel_encap_ops *ops, + unsigned int encap_type) +{ + int ret; + + if (encap_type == LWTUNNEL_ENCAP_NONE || + encap_type > LWTUNNEL_ENCAP_MAX) + return -ERANGE; + + ret = (cmpxchg((const struct lwtunnel_encap_ops **) + &lwtun_encaps[encap_type], + ops, NULL) == ops) ? 0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_encap_del_ops); + +int lwtunnel_build_state(struct net *net, u16 encap_type, + struct nlattr *encap, unsigned int family, + const void *cfg, struct lwtunnel_state **lws, + struct netlink_ext_ack *extack) +{ + const struct lwtunnel_encap_ops *ops; + bool found = false; + int ret = -EINVAL; + + if (encap_type == LWTUNNEL_ENCAP_NONE || + encap_type > LWTUNNEL_ENCAP_MAX) { + NL_SET_ERR_MSG_ATTR(extack, encap, + "Unknown LWT encapsulation type"); + return ret; + } + + ret = -EOPNOTSUPP; + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[encap_type]); + if (likely(ops && ops->build_state && try_module_get(ops->owner))) + found = true; + rcu_read_unlock(); + + if (found) { + ret = ops->build_state(net, encap, family, cfg, lws, extack); + if (ret) + module_put(ops->owner); + } else { + /* don't rely on -EOPNOTSUPP to detect match as build_state + * handlers could return it + */ + NL_SET_ERR_MSG_ATTR(extack, encap, + "LWT encapsulation type not supported"); + } + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_build_state); + +int lwtunnel_valid_encap_type(u16 encap_type, struct netlink_ext_ack *extack) +{ + const struct lwtunnel_encap_ops *ops; + int ret = -EINVAL; + + if (encap_type == LWTUNNEL_ENCAP_NONE || + encap_type > LWTUNNEL_ENCAP_MAX) { + NL_SET_ERR_MSG(extack, "Unknown lwt encapsulation type"); + return ret; + } + + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[encap_type]); + rcu_read_unlock(); +#ifdef CONFIG_MODULES + if (!ops) { + const char *encap_type_str = lwtunnel_encap_str(encap_type); + + if (encap_type_str) { + __rtnl_unlock(); + request_module("rtnl-lwt-%s", encap_type_str); + rtnl_lock(); + + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[encap_type]); + rcu_read_unlock(); + } + } +#endif + ret = ops ? 
0 : -EOPNOTSUPP; + if (ret < 0) + NL_SET_ERR_MSG(extack, "lwt encapsulation type not supported"); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_valid_encap_type); + +int lwtunnel_valid_encap_type_attr(struct nlattr *attr, int remaining, + struct netlink_ext_ack *extack) +{ + struct rtnexthop *rtnh = (struct rtnexthop *)attr; + struct nlattr *nla_entype; + struct nlattr *attrs; + u16 encap_type; + int attrlen; + + while (rtnh_ok(rtnh, remaining)) { + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + attrs = rtnh_attrs(rtnh); + nla_entype = nla_find(attrs, attrlen, RTA_ENCAP_TYPE); + + if (nla_entype) { + if (nla_len(nla_entype) < sizeof(u16)) { + NL_SET_ERR_MSG(extack, "Invalid RTA_ENCAP_TYPE"); + return -EINVAL; + } + encap_type = nla_get_u16(nla_entype); + + if (lwtunnel_valid_encap_type(encap_type, + extack) != 0) + return -EOPNOTSUPP; + } + } + rtnh = rtnh_next(rtnh, &remaining); + } + + return 0; +} +EXPORT_SYMBOL_GPL(lwtunnel_valid_encap_type_attr); + +void lwtstate_free(struct lwtunnel_state *lws) +{ + const struct lwtunnel_encap_ops *ops = lwtun_encaps[lws->type]; + + if (ops->destroy_state) { + ops->destroy_state(lws); + kfree_rcu(lws, rcu); + } else { + kfree(lws); + } + module_put(ops->owner); +} +EXPORT_SYMBOL_GPL(lwtstate_free); + +int lwtunnel_fill_encap(struct sk_buff *skb, struct lwtunnel_state *lwtstate, + int encap_attr, int encap_type_attr) +{ + const struct lwtunnel_encap_ops *ops; + struct nlattr *nest; + int ret; + + if (!lwtstate) + return 0; + + if (lwtstate->type == LWTUNNEL_ENCAP_NONE || + lwtstate->type > LWTUNNEL_ENCAP_MAX) + return 0; + + nest = nla_nest_start_noflag(skb, encap_attr); + if (!nest) + return -EMSGSIZE; + + ret = -EOPNOTSUPP; + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[lwtstate->type]); + if (likely(ops && ops->fill_encap)) + ret = ops->fill_encap(skb, lwtstate); + rcu_read_unlock(); + + if (ret) + goto nla_put_failure; + nla_nest_end(skb, nest); + ret = nla_put_u16(skb, encap_type_attr, lwtstate->type); + if (ret) + goto nla_put_failure; + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + + return (ret == -EOPNOTSUPP ? 
0 : ret); +} +EXPORT_SYMBOL_GPL(lwtunnel_fill_encap); + +int lwtunnel_get_encap_size(struct lwtunnel_state *lwtstate) +{ + const struct lwtunnel_encap_ops *ops; + int ret = 0; + + if (!lwtstate) + return 0; + + if (lwtstate->type == LWTUNNEL_ENCAP_NONE || + lwtstate->type > LWTUNNEL_ENCAP_MAX) + return 0; + + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[lwtstate->type]); + if (likely(ops && ops->get_encap_size)) + ret = nla_total_size(ops->get_encap_size(lwtstate)); + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_get_encap_size); + +int lwtunnel_cmp_encap(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + const struct lwtunnel_encap_ops *ops; + int ret = 0; + + if (!a && !b) + return 0; + + if (!a || !b) + return 1; + + if (a->type != b->type) + return 1; + + if (a->type == LWTUNNEL_ENCAP_NONE || + a->type > LWTUNNEL_ENCAP_MAX) + return 0; + + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[a->type]); + if (likely(ops && ops->cmp_encap)) + ret = ops->cmp_encap(a, b); + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_cmp_encap); + +int lwtunnel_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + const struct lwtunnel_encap_ops *ops; + struct lwtunnel_state *lwtstate; + int ret = -EINVAL; + + if (!dst) + goto drop; + lwtstate = dst->lwtstate; + + if (lwtstate->type == LWTUNNEL_ENCAP_NONE || + lwtstate->type > LWTUNNEL_ENCAP_MAX) + return 0; + + ret = -EOPNOTSUPP; + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[lwtstate->type]); + if (likely(ops && ops->output)) + ret = ops->output(net, sk, skb); + rcu_read_unlock(); + + if (ret == -EOPNOTSUPP) + goto drop; + + return ret; + +drop: + kfree_skb(skb); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_output); + +int lwtunnel_xmit(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + const struct lwtunnel_encap_ops *ops; + struct lwtunnel_state *lwtstate; + int ret = -EINVAL; + + if (!dst) + goto drop; + + lwtstate = dst->lwtstate; + + if (lwtstate->type == LWTUNNEL_ENCAP_NONE || + lwtstate->type > LWTUNNEL_ENCAP_MAX) + return 0; + + ret = -EOPNOTSUPP; + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[lwtstate->type]); + if (likely(ops && ops->xmit)) + ret = ops->xmit(skb); + rcu_read_unlock(); + + if (ret == -EOPNOTSUPP) + goto drop; + + return ret; + +drop: + kfree_skb(skb); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_xmit); + +int lwtunnel_input(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + const struct lwtunnel_encap_ops *ops; + struct lwtunnel_state *lwtstate; + int ret = -EINVAL; + + if (!dst) + goto drop; + lwtstate = dst->lwtstate; + + if (lwtstate->type == LWTUNNEL_ENCAP_NONE || + lwtstate->type > LWTUNNEL_ENCAP_MAX) + return 0; + + ret = -EOPNOTSUPP; + rcu_read_lock(); + ops = rcu_dereference(lwtun_encaps[lwtstate->type]); + if (likely(ops && ops->input)) + ret = ops->input(skb); + rcu_read_unlock(); + + if (ret == -EOPNOTSUPP) + goto drop; + + return ret; + +drop: + kfree_skb(skb); + + return ret; +} +EXPORT_SYMBOL_GPL(lwtunnel_input); diff --git a/net/core/neighbour.c b/net/core/neighbour.c new file mode 100644 index 000000000..c842f150c --- /dev/null +++ b/net/core/neighbour.c @@ -0,0 +1,3887 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Generic address resolution entity + * + * Authors: + * Pedro Roque + * Alexey Kuznetsov + * + * Fixes: + * Vitaly E. Lavrov releasing NULL neighbor in neigh_add. 
+ * Harald Welte Add neighbour cache statistics like rtstat + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define NEIGH_DEBUG 1 +#define neigh_dbg(level, fmt, ...) \ +do { \ + if (level <= NEIGH_DEBUG) \ + pr_debug(fmt, ##__VA_ARGS__); \ +} while (0) + +#define PNEIGH_HASHMASK 0xF + +static void neigh_timer_handler(struct timer_list *t); +static void __neigh_notify(struct neighbour *n, int type, int flags, + u32 pid); +static void neigh_update_notify(struct neighbour *neigh, u32 nlmsg_pid); +static int pneigh_ifdown_and_unlock(struct neigh_table *tbl, + struct net_device *dev); + +#ifdef CONFIG_PROC_FS +static const struct seq_operations neigh_stat_seq_ops; +#endif + +/* + Neighbour hash table buckets are protected with rwlock tbl->lock. + + - All the scans/updates to hash buckets MUST be made under this lock. + - NOTHING clever should be made under this lock: no callbacks + to protocol backends, no attempts to send something to network. + It will result in deadlocks, if backend/driver wants to use neighbour + cache. + - If the entry requires some non-trivial actions, increase + its reference count and release table lock. + + Neighbour entries are protected: + - with reference count. + - with rwlock neigh->lock + + Reference count prevents destruction. + + neigh->lock mainly serializes ll address data and its validity state. + However, the same lock is used to protect another entry fields: + - timer + - resolution queue + + Again, nothing clever shall be made under neigh->lock, + the most complicated procedure, which we allow is dev->hard_header. + It is supposed, that dev->hard_header is simplistic and does + not make callbacks to neighbour tables. + */ + +static int neigh_blackhole(struct neighbour *neigh, struct sk_buff *skb) +{ + kfree_skb(skb); + return -ENETDOWN; +} + +static void neigh_cleanup_and_release(struct neighbour *neigh) +{ + trace_neigh_cleanup_and_release(neigh, 0); + __neigh_notify(neigh, RTM_DELNEIGH, 0, 0); + call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, neigh); + neigh_release(neigh); +} + +/* + * It is random distribution in the interval (1/2)*base...(3/2)*base. + * It corresponds to default IPv6 settings and is not overridable, + * because it is really reasonable choice. + */ + +unsigned long neigh_rand_reach_time(unsigned long base) +{ + return base ? 
prandom_u32_max(base) + (base >> 1) : 0; +} +EXPORT_SYMBOL(neigh_rand_reach_time); + +static void neigh_mark_dead(struct neighbour *n) +{ + n->dead = 1; + if (!list_empty(&n->gc_list)) { + list_del_init(&n->gc_list); + atomic_dec(&n->tbl->gc_entries); + } + if (!list_empty(&n->managed_list)) + list_del_init(&n->managed_list); +} + +static void neigh_update_gc_list(struct neighbour *n) +{ + bool on_gc_list, exempt_from_gc; + + write_lock_bh(&n->tbl->lock); + write_lock(&n->lock); + if (n->dead) + goto out; + + /* remove from the gc list if new state is permanent or if neighbor + * is externally learned; otherwise entry should be on the gc list + */ + exempt_from_gc = n->nud_state & NUD_PERMANENT || + n->flags & NTF_EXT_LEARNED; + on_gc_list = !list_empty(&n->gc_list); + + if (exempt_from_gc && on_gc_list) { + list_del_init(&n->gc_list); + atomic_dec(&n->tbl->gc_entries); + } else if (!exempt_from_gc && !on_gc_list) { + /* add entries to the tail; cleaning removes from the front */ + list_add_tail(&n->gc_list, &n->tbl->gc_list); + atomic_inc(&n->tbl->gc_entries); + } +out: + write_unlock(&n->lock); + write_unlock_bh(&n->tbl->lock); +} + +static void neigh_update_managed_list(struct neighbour *n) +{ + bool on_managed_list, add_to_managed; + + write_lock_bh(&n->tbl->lock); + write_lock(&n->lock); + if (n->dead) + goto out; + + add_to_managed = n->flags & NTF_MANAGED; + on_managed_list = !list_empty(&n->managed_list); + + if (!add_to_managed && on_managed_list) + list_del_init(&n->managed_list); + else if (add_to_managed && !on_managed_list) + list_add_tail(&n->managed_list, &n->tbl->managed_list); +out: + write_unlock(&n->lock); + write_unlock_bh(&n->tbl->lock); +} + +static void neigh_update_flags(struct neighbour *neigh, u32 flags, int *notify, + bool *gc_update, bool *managed_update) +{ + u32 ndm_flags, old_flags = neigh->flags; + + if (!(flags & NEIGH_UPDATE_F_ADMIN)) + return; + + ndm_flags = (flags & NEIGH_UPDATE_F_EXT_LEARNED) ? NTF_EXT_LEARNED : 0; + ndm_flags |= (flags & NEIGH_UPDATE_F_MANAGED) ? 
NTF_MANAGED : 0; + + if ((old_flags ^ ndm_flags) & NTF_EXT_LEARNED) { + if (ndm_flags & NTF_EXT_LEARNED) + neigh->flags |= NTF_EXT_LEARNED; + else + neigh->flags &= ~NTF_EXT_LEARNED; + *notify = 1; + *gc_update = true; + } + if ((old_flags ^ ndm_flags) & NTF_MANAGED) { + if (ndm_flags & NTF_MANAGED) + neigh->flags |= NTF_MANAGED; + else + neigh->flags &= ~NTF_MANAGED; + *notify = 1; + *managed_update = true; + } +} + +static bool neigh_del(struct neighbour *n, struct neighbour __rcu **np, + struct neigh_table *tbl) +{ + bool retval = false; + + write_lock(&n->lock); + if (refcount_read(&n->refcnt) == 1) { + struct neighbour *neigh; + + neigh = rcu_dereference_protected(n->next, + lockdep_is_held(&tbl->lock)); + rcu_assign_pointer(*np, neigh); + neigh_mark_dead(n); + retval = true; + } + write_unlock(&n->lock); + if (retval) + neigh_cleanup_and_release(n); + return retval; +} + +bool neigh_remove_one(struct neighbour *ndel, struct neigh_table *tbl) +{ + struct neigh_hash_table *nht; + void *pkey = ndel->primary_key; + u32 hash_val; + struct neighbour *n; + struct neighbour __rcu **np; + + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + hash_val = tbl->hash(pkey, ndel->dev, nht->hash_rnd); + hash_val = hash_val >> (32 - nht->hash_shift); + + np = &nht->hash_buckets[hash_val]; + while ((n = rcu_dereference_protected(*np, + lockdep_is_held(&tbl->lock)))) { + if (n == ndel) + return neigh_del(n, np, tbl); + np = &n->next; + } + return false; +} + +static int neigh_forced_gc(struct neigh_table *tbl) +{ + int max_clean = atomic_read(&tbl->gc_entries) - + READ_ONCE(tbl->gc_thresh2); + u64 tmax = ktime_get_ns() + NSEC_PER_MSEC; + unsigned long tref = jiffies - 5 * HZ; + struct neighbour *n, *tmp; + int shrunk = 0; + int loop = 0; + + NEIGH_CACHE_STAT_INC(tbl, forced_gc_runs); + + write_lock_bh(&tbl->lock); + + list_for_each_entry_safe(n, tmp, &tbl->gc_list, gc_list) { + if (refcount_read(&n->refcnt) == 1) { + bool remove = false; + + write_lock(&n->lock); + if ((n->nud_state == NUD_FAILED) || + (n->nud_state == NUD_NOARP) || + (tbl->is_multicast && + tbl->is_multicast(n->primary_key)) || + !time_in_range(n->updated, tref, jiffies)) + remove = true; + write_unlock(&n->lock); + + if (remove && neigh_remove_one(n, tbl)) + shrunk++; + if (shrunk >= max_clean) + break; + if (++loop == 16) { + if (ktime_get_ns() > tmax) + goto unlock; + loop = 0; + } + } + } + + WRITE_ONCE(tbl->last_flush, jiffies); +unlock: + write_unlock_bh(&tbl->lock); + + return shrunk; +} + +static void neigh_add_timer(struct neighbour *n, unsigned long when) +{ + /* Use safe distance from the jiffies - LONG_MAX point while timer + * is running in DELAY/PROBE state but still show to user space + * large times in the past. 
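+ * The clamp below keeps n->confirmed and n->used inside that safe range.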
+ */ + unsigned long mint = jiffies - (LONG_MAX - 86400 * HZ); + + neigh_hold(n); + if (!time_in_range(n->confirmed, mint, jiffies)) + n->confirmed = mint; + if (time_before(n->used, n->confirmed)) + n->used = n->confirmed; + if (unlikely(mod_timer(&n->timer, when))) { + printk("NEIGH: BUG, double timer add, state is %x\n", + n->nud_state); + dump_stack(); + } +} + +static int neigh_del_timer(struct neighbour *n) +{ + if ((n->nud_state & NUD_IN_TIMER) && + del_timer(&n->timer)) { + neigh_release(n); + return 1; + } + return 0; +} + +static struct neigh_parms *neigh_get_dev_parms_rcu(struct net_device *dev, + int family) +{ + switch (family) { + case AF_INET: + return __in_dev_arp_parms_get_rcu(dev); + case AF_INET6: + return __in6_dev_nd_parms_get_rcu(dev); + } + return NULL; +} + +static void neigh_parms_qlen_dec(struct net_device *dev, int family) +{ + struct neigh_parms *p; + + rcu_read_lock(); + p = neigh_get_dev_parms_rcu(dev, family); + if (p) + p->qlen--; + rcu_read_unlock(); +} + +static void pneigh_queue_purge(struct sk_buff_head *list, struct net *net, + int family) +{ + struct sk_buff_head tmp; + unsigned long flags; + struct sk_buff *skb; + + skb_queue_head_init(&tmp); + spin_lock_irqsave(&list->lock, flags); + skb = skb_peek(list); + while (skb != NULL) { + struct sk_buff *skb_next = skb_peek_next(skb, list); + struct net_device *dev = skb->dev; + + if (net == NULL || net_eq(dev_net(dev), net)) { + neigh_parms_qlen_dec(dev, family); + __skb_unlink(skb, list); + __skb_queue_tail(&tmp, skb); + } + skb = skb_next; + } + spin_unlock_irqrestore(&list->lock, flags); + + while ((skb = __skb_dequeue(&tmp))) { + dev_put(skb->dev); + kfree_skb(skb); + } +} + +static void neigh_flush_dev(struct neigh_table *tbl, struct net_device *dev, + bool skip_perm) +{ + int i; + struct neigh_hash_table *nht; + + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + + for (i = 0; i < (1 << nht->hash_shift); i++) { + struct neighbour *n; + struct neighbour __rcu **np = &nht->hash_buckets[i]; + + while ((n = rcu_dereference_protected(*np, + lockdep_is_held(&tbl->lock))) != NULL) { + if (dev && n->dev != dev) { + np = &n->next; + continue; + } + if (skip_perm && n->nud_state & NUD_PERMANENT) { + np = &n->next; + continue; + } + rcu_assign_pointer(*np, + rcu_dereference_protected(n->next, + lockdep_is_held(&tbl->lock))); + write_lock(&n->lock); + neigh_del_timer(n); + neigh_mark_dead(n); + if (refcount_read(&n->refcnt) != 1) { + /* The most unpleasant situation. + We must destroy neighbour entry, + but someone still uses it. + + The destroy will be delayed until + the last user releases us, but + we must kill timers etc. and move + it to safe state. + */ + __skb_queue_purge(&n->arp_queue); + n->arp_queue_len_bytes = 0; + WRITE_ONCE(n->output, neigh_blackhole); + if (n->nud_state & NUD_VALID) + n->nud_state = NUD_NOARP; + else + n->nud_state = NUD_NONE; + neigh_dbg(2, "neigh %p is stray\n", n); + } + write_unlock(&n->lock); + neigh_cleanup_and_release(n); + } + } +} + +void neigh_changeaddr(struct neigh_table *tbl, struct net_device *dev) +{ + write_lock_bh(&tbl->lock); + neigh_flush_dev(tbl, dev, false); + write_unlock_bh(&tbl->lock); +} +EXPORT_SYMBOL(neigh_changeaddr); + +static int __neigh_ifdown(struct neigh_table *tbl, struct net_device *dev, + bool skip_perm) +{ + write_lock_bh(&tbl->lock); + neigh_flush_dev(tbl, dev, skip_perm); + pneigh_ifdown_and_unlock(tbl, dev); + pneigh_queue_purge(&tbl->proxy_queue, dev ? 
dev_net(dev) : NULL, + tbl->family); + if (skb_queue_empty_lockless(&tbl->proxy_queue)) + del_timer_sync(&tbl->proxy_timer); + return 0; +} + +int neigh_carrier_down(struct neigh_table *tbl, struct net_device *dev) +{ + __neigh_ifdown(tbl, dev, true); + return 0; +} +EXPORT_SYMBOL(neigh_carrier_down); + +int neigh_ifdown(struct neigh_table *tbl, struct net_device *dev) +{ + __neigh_ifdown(tbl, dev, false); + return 0; +} +EXPORT_SYMBOL(neigh_ifdown); + +static struct neighbour *neigh_alloc(struct neigh_table *tbl, + struct net_device *dev, + u32 flags, bool exempt_from_gc) +{ + struct neighbour *n = NULL; + unsigned long now = jiffies; + int entries, gc_thresh3; + + if (exempt_from_gc) + goto do_alloc; + + entries = atomic_inc_return(&tbl->gc_entries) - 1; + gc_thresh3 = READ_ONCE(tbl->gc_thresh3); + if (entries >= gc_thresh3 || + (entries >= READ_ONCE(tbl->gc_thresh2) && + time_after(now, READ_ONCE(tbl->last_flush) + 5 * HZ))) { + if (!neigh_forced_gc(tbl) && entries >= gc_thresh3) { + net_info_ratelimited("%s: neighbor table overflow!\n", + tbl->id); + NEIGH_CACHE_STAT_INC(tbl, table_fulls); + goto out_entries; + } + } + +do_alloc: + n = kzalloc(tbl->entry_size + dev->neigh_priv_len, GFP_ATOMIC); + if (!n) + goto out_entries; + + __skb_queue_head_init(&n->arp_queue); + rwlock_init(&n->lock); + seqlock_init(&n->ha_lock); + n->updated = n->used = now; + n->nud_state = NUD_NONE; + n->output = neigh_blackhole; + n->flags = flags; + seqlock_init(&n->hh.hh_lock); + n->parms = neigh_parms_clone(&tbl->parms); + timer_setup(&n->timer, neigh_timer_handler, 0); + + NEIGH_CACHE_STAT_INC(tbl, allocs); + n->tbl = tbl; + refcount_set(&n->refcnt, 1); + n->dead = 1; + INIT_LIST_HEAD(&n->gc_list); + INIT_LIST_HEAD(&n->managed_list); + + atomic_inc(&tbl->entries); +out: + return n; + +out_entries: + if (!exempt_from_gc) + atomic_dec(&tbl->gc_entries); + goto out; +} + +static void neigh_get_hash_rnd(u32 *x) +{ + *x = get_random_u32() | 1; +} + +static struct neigh_hash_table *neigh_hash_alloc(unsigned int shift) +{ + size_t size = (1 << shift) * sizeof(struct neighbour *); + struct neigh_hash_table *ret; + struct neighbour __rcu **buckets; + int i; + + ret = kmalloc(sizeof(*ret), GFP_ATOMIC); + if (!ret) + return NULL; + if (size <= PAGE_SIZE) { + buckets = kzalloc(size, GFP_ATOMIC); + } else { + buckets = (struct neighbour __rcu **) + __get_free_pages(GFP_ATOMIC | __GFP_ZERO, + get_order(size)); + kmemleak_alloc(buckets, size, 1, GFP_ATOMIC); + } + if (!buckets) { + kfree(ret); + return NULL; + } + ret->hash_buckets = buckets; + ret->hash_shift = shift; + for (i = 0; i < NEIGH_NUM_HASH_RND; i++) + neigh_get_hash_rnd(&ret->hash_rnd[i]); + return ret; +} + +static void neigh_hash_free_rcu(struct rcu_head *head) +{ + struct neigh_hash_table *nht = container_of(head, + struct neigh_hash_table, + rcu); + size_t size = (1 << nht->hash_shift) * sizeof(struct neighbour *); + struct neighbour __rcu **buckets = nht->hash_buckets; + + if (size <= PAGE_SIZE) { + kfree(buckets); + } else { + kmemleak_free(buckets); + free_pages((unsigned long)buckets, get_order(size)); + } + kfree(nht); +} + +static struct neigh_hash_table *neigh_hash_grow(struct neigh_table *tbl, + unsigned long new_shift) +{ + unsigned int i, hash; + struct neigh_hash_table *new_nht, *old_nht; + + NEIGH_CACHE_STAT_INC(tbl, hash_grows); + + old_nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + new_nht = neigh_hash_alloc(new_shift); + if (!new_nht) + return old_nht; + + for (i = 0; i < (1 << old_nht->hash_shift); i++) { + 
struct neighbour *n, *next; + + for (n = rcu_dereference_protected(old_nht->hash_buckets[i], + lockdep_is_held(&tbl->lock)); + n != NULL; + n = next) { + hash = tbl->hash(n->primary_key, n->dev, + new_nht->hash_rnd); + + hash >>= (32 - new_nht->hash_shift); + next = rcu_dereference_protected(n->next, + lockdep_is_held(&tbl->lock)); + + rcu_assign_pointer(n->next, + rcu_dereference_protected( + new_nht->hash_buckets[hash], + lockdep_is_held(&tbl->lock))); + rcu_assign_pointer(new_nht->hash_buckets[hash], n); + } + } + + rcu_assign_pointer(tbl->nht, new_nht); + call_rcu(&old_nht->rcu, neigh_hash_free_rcu); + return new_nht; +} + +struct neighbour *neigh_lookup(struct neigh_table *tbl, const void *pkey, + struct net_device *dev) +{ + struct neighbour *n; + + NEIGH_CACHE_STAT_INC(tbl, lookups); + + rcu_read_lock(); + n = __neigh_lookup_noref(tbl, pkey, dev); + if (n) { + if (!refcount_inc_not_zero(&n->refcnt)) + n = NULL; + NEIGH_CACHE_STAT_INC(tbl, hits); + } + + rcu_read_unlock(); + return n; +} +EXPORT_SYMBOL(neigh_lookup); + +static struct neighbour * +___neigh_create(struct neigh_table *tbl, const void *pkey, + struct net_device *dev, u32 flags, + bool exempt_from_gc, bool want_ref) +{ + u32 hash_val, key_len = tbl->key_len; + struct neighbour *n1, *rc, *n; + struct neigh_hash_table *nht; + int error; + + n = neigh_alloc(tbl, dev, flags, exempt_from_gc); + trace_neigh_create(tbl, dev, pkey, n, exempt_from_gc); + if (!n) { + rc = ERR_PTR(-ENOBUFS); + goto out; + } + + memcpy(n->primary_key, pkey, key_len); + n->dev = dev; + netdev_hold(dev, &n->dev_tracker, GFP_ATOMIC); + + /* Protocol specific setup. */ + if (tbl->constructor && (error = tbl->constructor(n)) < 0) { + rc = ERR_PTR(error); + goto out_neigh_release; + } + + if (dev->netdev_ops->ndo_neigh_construct) { + error = dev->netdev_ops->ndo_neigh_construct(dev, n); + if (error < 0) { + rc = ERR_PTR(error); + goto out_neigh_release; + } + } + + /* Device specific setup. 
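+ * A failing parms->neigh_setup() hook aborts the creation.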
*/ + if (n->parms->neigh_setup && + (error = n->parms->neigh_setup(n)) < 0) { + rc = ERR_PTR(error); + goto out_neigh_release; + } + + n->confirmed = jiffies - (NEIGH_VAR(n->parms, BASE_REACHABLE_TIME) << 1); + + write_lock_bh(&tbl->lock); + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + + if (atomic_read(&tbl->entries) > (1 << nht->hash_shift)) + nht = neigh_hash_grow(tbl, nht->hash_shift + 1); + + hash_val = tbl->hash(n->primary_key, dev, nht->hash_rnd) >> (32 - nht->hash_shift); + + if (n->parms->dead) { + rc = ERR_PTR(-EINVAL); + goto out_tbl_unlock; + } + + for (n1 = rcu_dereference_protected(nht->hash_buckets[hash_val], + lockdep_is_held(&tbl->lock)); + n1 != NULL; + n1 = rcu_dereference_protected(n1->next, + lockdep_is_held(&tbl->lock))) { + if (dev == n1->dev && !memcmp(n1->primary_key, n->primary_key, key_len)) { + if (want_ref) + neigh_hold(n1); + rc = n1; + goto out_tbl_unlock; + } + } + + n->dead = 0; + if (!exempt_from_gc) + list_add_tail(&n->gc_list, &n->tbl->gc_list); + if (n->flags & NTF_MANAGED) + list_add_tail(&n->managed_list, &n->tbl->managed_list); + if (want_ref) + neigh_hold(n); + rcu_assign_pointer(n->next, + rcu_dereference_protected(nht->hash_buckets[hash_val], + lockdep_is_held(&tbl->lock))); + rcu_assign_pointer(nht->hash_buckets[hash_val], n); + write_unlock_bh(&tbl->lock); + neigh_dbg(2, "neigh %p is created\n", n); + rc = n; +out: + return rc; +out_tbl_unlock: + write_unlock_bh(&tbl->lock); +out_neigh_release: + if (!exempt_from_gc) + atomic_dec(&tbl->gc_entries); + neigh_release(n); + goto out; +} + +struct neighbour *__neigh_create(struct neigh_table *tbl, const void *pkey, + struct net_device *dev, bool want_ref) +{ + return ___neigh_create(tbl, pkey, dev, 0, false, want_ref); +} +EXPORT_SYMBOL(__neigh_create); + +static u32 pneigh_hash(const void *pkey, unsigned int key_len) +{ + u32 hash_val = *(u32 *)(pkey + key_len - 4); + hash_val ^= (hash_val >> 16); + hash_val ^= hash_val >> 8; + hash_val ^= hash_val >> 4; + hash_val &= PNEIGH_HASHMASK; + return hash_val; +} + +static struct pneigh_entry *__pneigh_lookup_1(struct pneigh_entry *n, + struct net *net, + const void *pkey, + unsigned int key_len, + struct net_device *dev) +{ + while (n) { + if (!memcmp(n->key, pkey, key_len) && + net_eq(pneigh_net(n), net) && + (n->dev == dev || !n->dev)) + return n; + n = n->next; + } + return NULL; +} + +struct pneigh_entry *__pneigh_lookup(struct neigh_table *tbl, + struct net *net, const void *pkey, struct net_device *dev) +{ + unsigned int key_len = tbl->key_len; + u32 hash_val = pneigh_hash(pkey, key_len); + + return __pneigh_lookup_1(tbl->phash_buckets[hash_val], + net, pkey, key_len, dev); +} +EXPORT_SYMBOL_GPL(__pneigh_lookup); + +struct pneigh_entry * pneigh_lookup(struct neigh_table *tbl, + struct net *net, const void *pkey, + struct net_device *dev, int creat) +{ + struct pneigh_entry *n; + unsigned int key_len = tbl->key_len; + u32 hash_val = pneigh_hash(pkey, key_len); + + read_lock_bh(&tbl->lock); + n = __pneigh_lookup_1(tbl->phash_buckets[hash_val], + net, pkey, key_len, dev); + read_unlock_bh(&tbl->lock); + + if (n || !creat) + goto out; + + ASSERT_RTNL(); + + n = kzalloc(sizeof(*n) + key_len, GFP_KERNEL); + if (!n) + goto out; + + write_pnet(&n->net, net); + memcpy(n->key, pkey, key_len); + n->dev = dev; + netdev_hold(dev, &n->dev_tracker, GFP_KERNEL); + + if (tbl->pconstructor && tbl->pconstructor(n)) { + netdev_put(dev, &n->dev_tracker); + kfree(n); + n = NULL; + goto out; + } + + write_lock_bh(&tbl->lock); + n->next = 
tbl->phash_buckets[hash_val]; + tbl->phash_buckets[hash_val] = n; + write_unlock_bh(&tbl->lock); +out: + return n; +} +EXPORT_SYMBOL(pneigh_lookup); + + +int pneigh_delete(struct neigh_table *tbl, struct net *net, const void *pkey, + struct net_device *dev) +{ + struct pneigh_entry *n, **np; + unsigned int key_len = tbl->key_len; + u32 hash_val = pneigh_hash(pkey, key_len); + + write_lock_bh(&tbl->lock); + for (np = &tbl->phash_buckets[hash_val]; (n = *np) != NULL; + np = &n->next) { + if (!memcmp(n->key, pkey, key_len) && n->dev == dev && + net_eq(pneigh_net(n), net)) { + *np = n->next; + write_unlock_bh(&tbl->lock); + if (tbl->pdestructor) + tbl->pdestructor(n); + netdev_put(n->dev, &n->dev_tracker); + kfree(n); + return 0; + } + } + write_unlock_bh(&tbl->lock); + return -ENOENT; +} + +static int pneigh_ifdown_and_unlock(struct neigh_table *tbl, + struct net_device *dev) +{ + struct pneigh_entry *n, **np, *freelist = NULL; + u32 h; + + for (h = 0; h <= PNEIGH_HASHMASK; h++) { + np = &tbl->phash_buckets[h]; + while ((n = *np) != NULL) { + if (!dev || n->dev == dev) { + *np = n->next; + n->next = freelist; + freelist = n; + continue; + } + np = &n->next; + } + } + write_unlock_bh(&tbl->lock); + while ((n = freelist)) { + freelist = n->next; + n->next = NULL; + if (tbl->pdestructor) + tbl->pdestructor(n); + netdev_put(n->dev, &n->dev_tracker); + kfree(n); + } + return -ENOENT; +} + +static void neigh_parms_destroy(struct neigh_parms *parms); + +static inline void neigh_parms_put(struct neigh_parms *parms) +{ + if (refcount_dec_and_test(&parms->refcnt)) + neigh_parms_destroy(parms); +} + +/* + * neighbour must already be out of the table; + * + */ +void neigh_destroy(struct neighbour *neigh) +{ + struct net_device *dev = neigh->dev; + + NEIGH_CACHE_STAT_INC(neigh->tbl, destroys); + + if (!neigh->dead) { + pr_warn("Destroying alive neighbour %p\n", neigh); + dump_stack(); + return; + } + + if (neigh_del_timer(neigh)) + pr_warn("Impossible event\n"); + + write_lock_bh(&neigh->lock); + __skb_queue_purge(&neigh->arp_queue); + write_unlock_bh(&neigh->lock); + neigh->arp_queue_len_bytes = 0; + + if (dev->netdev_ops->ndo_neigh_destroy) + dev->netdev_ops->ndo_neigh_destroy(dev, neigh); + + netdev_put(dev, &neigh->dev_tracker); + neigh_parms_put(neigh->parms); + + neigh_dbg(2, "neigh %p is destroyed\n", neigh); + + atomic_dec(&neigh->tbl->entries); + kfree_rcu(neigh, rcu); +} +EXPORT_SYMBOL(neigh_destroy); + +/* Neighbour state is suspicious; + disable fast path. + + Called with write_locked neigh. + */ +static void neigh_suspect(struct neighbour *neigh) +{ + neigh_dbg(2, "neigh %p is suspected\n", neigh); + + WRITE_ONCE(neigh->output, neigh->ops->output); +} + +/* Neighbour state is OK; + enable fast path. + + Called with write_locked neigh. 
+ */ +static void neigh_connect(struct neighbour *neigh) +{ + neigh_dbg(2, "neigh %p is connected\n", neigh); + + WRITE_ONCE(neigh->output, neigh->ops->connected_output); +} + +static void neigh_periodic_work(struct work_struct *work) +{ + struct neigh_table *tbl = container_of(work, struct neigh_table, gc_work.work); + struct neighbour *n; + struct neighbour __rcu **np; + unsigned int i; + struct neigh_hash_table *nht; + + NEIGH_CACHE_STAT_INC(tbl, periodic_gc_runs); + + write_lock_bh(&tbl->lock); + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + + /* + * periodically recompute ReachableTime from random function + */ + + if (time_after(jiffies, tbl->last_rand + 300 * HZ)) { + struct neigh_parms *p; + + WRITE_ONCE(tbl->last_rand, jiffies); + list_for_each_entry(p, &tbl->parms_list, list) + p->reachable_time = + neigh_rand_reach_time(NEIGH_VAR(p, BASE_REACHABLE_TIME)); + } + + if (atomic_read(&tbl->entries) < READ_ONCE(tbl->gc_thresh1)) + goto out; + + for (i = 0 ; i < (1 << nht->hash_shift); i++) { + np = &nht->hash_buckets[i]; + + while ((n = rcu_dereference_protected(*np, + lockdep_is_held(&tbl->lock))) != NULL) { + unsigned int state; + + write_lock(&n->lock); + + state = n->nud_state; + if ((state & (NUD_PERMANENT | NUD_IN_TIMER)) || + (n->flags & NTF_EXT_LEARNED)) { + write_unlock(&n->lock); + goto next_elt; + } + + if (time_before(n->used, n->confirmed) && + time_is_before_eq_jiffies(n->confirmed)) + n->used = n->confirmed; + + if (refcount_read(&n->refcnt) == 1 && + (state == NUD_FAILED || + !time_in_range_open(jiffies, n->used, + n->used + NEIGH_VAR(n->parms, GC_STALETIME)))) { + rcu_assign_pointer(*np, + rcu_dereference_protected(n->next, + lockdep_is_held(&tbl->lock))); + neigh_mark_dead(n); + write_unlock(&n->lock); + neigh_cleanup_and_release(n); + continue; + } + write_unlock(&n->lock); + +next_elt: + np = &n->next; + } + /* + * It's fine to release lock here, even if hash table + * grows while we are preempted. + */ + write_unlock_bh(&tbl->lock); + cond_resched(); + write_lock_bh(&tbl->lock); + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + } +out: + /* Cycle through all hash buckets every BASE_REACHABLE_TIME/2 ticks. + * ARP entry timeouts range from 1/2 BASE_REACHABLE_TIME to 3/2 + * BASE_REACHABLE_TIME. + */ + queue_delayed_work(system_power_efficient_wq, &tbl->gc_work, + NEIGH_VAR(&tbl->parms, BASE_REACHABLE_TIME) >> 1); + write_unlock_bh(&tbl->lock); +} + +static __inline__ int neigh_max_probes(struct neighbour *n) +{ + struct neigh_parms *p = n->parms; + return NEIGH_VAR(p, UCAST_PROBES) + NEIGH_VAR(p, APP_PROBES) + + (n->nud_state & NUD_PROBE ? NEIGH_VAR(p, MCAST_REPROBES) : + NEIGH_VAR(p, MCAST_PROBES)); +} + +static void neigh_invalidate(struct neighbour *neigh) + __releases(neigh->lock) + __acquires(neigh->lock) +{ + struct sk_buff *skb; + + NEIGH_CACHE_STAT_INC(neigh->tbl, res_failed); + neigh_dbg(2, "neigh %p is failed\n", neigh); + neigh->updated = jiffies; + + /* It is very thin place. report_unreachable is very complicated + routine. Particularly, it can hit the same neighbour entry! + + So that, we try to be accurate and avoid dead loop. 
--ANK + */ + while (neigh->nud_state == NUD_FAILED && + (skb = __skb_dequeue(&neigh->arp_queue)) != NULL) { + write_unlock(&neigh->lock); + neigh->ops->error_report(neigh, skb); + write_lock(&neigh->lock); + } + __skb_queue_purge(&neigh->arp_queue); + neigh->arp_queue_len_bytes = 0; +} + +static void neigh_probe(struct neighbour *neigh) + __releases(neigh->lock) +{ + struct sk_buff *skb = skb_peek_tail(&neigh->arp_queue); + /* keep skb alive even if arp_queue overflows */ + if (skb) + skb = skb_clone(skb, GFP_ATOMIC); + write_unlock(&neigh->lock); + if (neigh->ops->solicit) + neigh->ops->solicit(neigh, skb); + atomic_inc(&neigh->probes); + consume_skb(skb); +} + +/* Called when a timer expires for a neighbour entry. */ + +static void neigh_timer_handler(struct timer_list *t) +{ + unsigned long now, next; + struct neighbour *neigh = from_timer(neigh, t, timer); + unsigned int state; + int notify = 0; + + write_lock(&neigh->lock); + + state = neigh->nud_state; + now = jiffies; + next = now + HZ; + + if (!(state & NUD_IN_TIMER)) + goto out; + + if (state & NUD_REACHABLE) { + if (time_before_eq(now, + neigh->confirmed + neigh->parms->reachable_time)) { + neigh_dbg(2, "neigh %p is still alive\n", neigh); + next = neigh->confirmed + neigh->parms->reachable_time; + } else if (time_before_eq(now, + neigh->used + + NEIGH_VAR(neigh->parms, DELAY_PROBE_TIME))) { + neigh_dbg(2, "neigh %p is delayed\n", neigh); + WRITE_ONCE(neigh->nud_state, NUD_DELAY); + neigh->updated = jiffies; + neigh_suspect(neigh); + next = now + NEIGH_VAR(neigh->parms, DELAY_PROBE_TIME); + } else { + neigh_dbg(2, "neigh %p is suspected\n", neigh); + WRITE_ONCE(neigh->nud_state, NUD_STALE); + neigh->updated = jiffies; + neigh_suspect(neigh); + notify = 1; + } + } else if (state & NUD_DELAY) { + if (time_before_eq(now, + neigh->confirmed + + NEIGH_VAR(neigh->parms, DELAY_PROBE_TIME))) { + neigh_dbg(2, "neigh %p is now reachable\n", neigh); + WRITE_ONCE(neigh->nud_state, NUD_REACHABLE); + neigh->updated = jiffies; + neigh_connect(neigh); + notify = 1; + next = neigh->confirmed + neigh->parms->reachable_time; + } else { + neigh_dbg(2, "neigh %p is probed\n", neigh); + WRITE_ONCE(neigh->nud_state, NUD_PROBE); + neigh->updated = jiffies; + atomic_set(&neigh->probes, 0); + notify = 1; + next = now + max(NEIGH_VAR(neigh->parms, RETRANS_TIME), + HZ/100); + } + } else { + /* NUD_PROBE|NUD_INCOMPLETE */ + next = now + max(NEIGH_VAR(neigh->parms, RETRANS_TIME), HZ/100); + } + + if ((neigh->nud_state & (NUD_INCOMPLETE | NUD_PROBE)) && + atomic_read(&neigh->probes) >= neigh_max_probes(neigh)) { + WRITE_ONCE(neigh->nud_state, NUD_FAILED); + notify = 1; + neigh_invalidate(neigh); + goto out; + } + + if (neigh->nud_state & NUD_IN_TIMER) { + if (time_before(next, jiffies + HZ/100)) + next = jiffies + HZ/100; + if (!mod_timer(&neigh->timer, next)) + neigh_hold(neigh); + } + if (neigh->nud_state & (NUD_INCOMPLETE | NUD_PROBE)) { + neigh_probe(neigh); + } else { +out: + write_unlock(&neigh->lock); + } + + if (notify) + neigh_update_notify(neigh, 0); + + trace_neigh_timer_handler(neigh, 0); + + neigh_release(neigh); +} + +int __neigh_event_send(struct neighbour *neigh, struct sk_buff *skb, + const bool immediate_ok) +{ + int rc; + bool immediate_probe = false; + + write_lock_bh(&neigh->lock); + + rc = 0; + if (neigh->nud_state & (NUD_CONNECTED | NUD_DELAY | NUD_PROBE)) + goto out_unlock_bh; + if (neigh->dead) + goto out_dead; + + if (!(neigh->nud_state & (NUD_STALE | NUD_INCOMPLETE))) { + if (NEIGH_VAR(neigh->parms, MCAST_PROBES) + + 
NEIGH_VAR(neigh->parms, APP_PROBES)) { + unsigned long next, now = jiffies; + + atomic_set(&neigh->probes, + NEIGH_VAR(neigh->parms, UCAST_PROBES)); + neigh_del_timer(neigh); + WRITE_ONCE(neigh->nud_state, NUD_INCOMPLETE); + neigh->updated = now; + if (!immediate_ok) { + next = now + 1; + } else { + immediate_probe = true; + next = now + max(NEIGH_VAR(neigh->parms, + RETRANS_TIME), + HZ / 100); + } + neigh_add_timer(neigh, next); + } else { + WRITE_ONCE(neigh->nud_state, NUD_FAILED); + neigh->updated = jiffies; + write_unlock_bh(&neigh->lock); + + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_FAILED); + return 1; + } + } else if (neigh->nud_state & NUD_STALE) { + neigh_dbg(2, "neigh %p is delayed\n", neigh); + neigh_del_timer(neigh); + WRITE_ONCE(neigh->nud_state, NUD_DELAY); + neigh->updated = jiffies; + neigh_add_timer(neigh, jiffies + + NEIGH_VAR(neigh->parms, DELAY_PROBE_TIME)); + } + + if (neigh->nud_state == NUD_INCOMPLETE) { + if (skb) { + while (neigh->arp_queue_len_bytes + skb->truesize > + NEIGH_VAR(neigh->parms, QUEUE_LEN_BYTES)) { + struct sk_buff *buff; + + buff = __skb_dequeue(&neigh->arp_queue); + if (!buff) + break; + neigh->arp_queue_len_bytes -= buff->truesize; + kfree_skb_reason(buff, SKB_DROP_REASON_NEIGH_QUEUEFULL); + NEIGH_CACHE_STAT_INC(neigh->tbl, unres_discards); + } + skb_dst_force(skb); + __skb_queue_tail(&neigh->arp_queue, skb); + neigh->arp_queue_len_bytes += skb->truesize; + } + rc = 1; + } +out_unlock_bh: + if (immediate_probe) + neigh_probe(neigh); + else + write_unlock(&neigh->lock); + local_bh_enable(); + trace_neigh_event_send_done(neigh, rc); + return rc; + +out_dead: + if (neigh->nud_state & NUD_STALE) + goto out_unlock_bh; + write_unlock_bh(&neigh->lock); + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_DEAD); + trace_neigh_event_send_dead(neigh, 1); + return 1; +} +EXPORT_SYMBOL(__neigh_event_send); + +static void neigh_update_hhs(struct neighbour *neigh) +{ + struct hh_cache *hh; + void (*update)(struct hh_cache*, const struct net_device*, const unsigned char *) + = NULL; + + if (neigh->dev->header_ops) + update = neigh->dev->header_ops->cache_update; + + if (update) { + hh = &neigh->hh; + if (READ_ONCE(hh->hh_len)) { + write_seqlock_bh(&hh->hh_lock); + update(hh, neigh->dev, neigh->ha); + write_sequnlock_bh(&hh->hh_lock); + } + } +} + +/* Generic update routine. + -- lladdr is new lladdr or NULL, if it is not supplied. + -- new is new state. + -- flags + NEIGH_UPDATE_F_OVERRIDE allows to override existing lladdr, + if it is different. + NEIGH_UPDATE_F_WEAK_OVERRIDE will suspect existing "connected" + lladdr instead of overriding it + if it is different. + NEIGH_UPDATE_F_ADMIN means that the change is administrative. + NEIGH_UPDATE_F_USE means that the entry is user triggered. + NEIGH_UPDATE_F_MANAGED means that the entry will be auto-refreshed. + NEIGH_UPDATE_F_OVERRIDE_ISROUTER allows to override existing + NTF_ROUTER flag. + NEIGH_UPDATE_F_ISROUTER indicates if the neighbour is known as + a router. + + Caller MUST hold reference count on the entry. 
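   As an illustration only (a minimal sketch, assuming "dev", "peer_ip" and
   "new_lladdr" come from the caller, and using the ARP table as the example
   table):

	struct neighbour *n = neigh_lookup(&arp_tbl, &peer_ip, dev);

	if (n) {
		neigh_update(n, new_lladdr, NUD_REACHABLE,
			     NEIGH_UPDATE_F_OVERRIDE, 0);
		neigh_release(n);
	}

   neigh_lookup() returns the entry with a reference held, which satisfies
   the rule above.  With NEIGH_UPDATE_F_OVERRIDE a different cached lladdr
   is replaced; with only NEIGH_UPDATE_F_WEAK_OVERRIDE a differing address
   on a connected entry is kept and the entry is merely moved to NUD_STALE.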
+ */ +static int __neigh_update(struct neighbour *neigh, const u8 *lladdr, + u8 new, u32 flags, u32 nlmsg_pid, + struct netlink_ext_ack *extack) +{ + bool gc_update = false, managed_update = false; + int update_isrouter = 0; + struct net_device *dev; + int err, notify = 0; + u8 old; + + trace_neigh_update(neigh, lladdr, new, flags, nlmsg_pid); + + write_lock_bh(&neigh->lock); + + dev = neigh->dev; + old = neigh->nud_state; + err = -EPERM; + + if (neigh->dead) { + NL_SET_ERR_MSG(extack, "Neighbor entry is now dead"); + new = old; + goto out; + } + if (!(flags & NEIGH_UPDATE_F_ADMIN) && + (old & (NUD_NOARP | NUD_PERMANENT))) + goto out; + + neigh_update_flags(neigh, flags, ¬ify, &gc_update, &managed_update); + if (flags & (NEIGH_UPDATE_F_USE | NEIGH_UPDATE_F_MANAGED)) { + new = old & ~NUD_PERMANENT; + WRITE_ONCE(neigh->nud_state, new); + err = 0; + goto out; + } + + if (!(new & NUD_VALID)) { + neigh_del_timer(neigh); + if (old & NUD_CONNECTED) + neigh_suspect(neigh); + WRITE_ONCE(neigh->nud_state, new); + err = 0; + notify = old & NUD_VALID; + if ((old & (NUD_INCOMPLETE | NUD_PROBE)) && + (new & NUD_FAILED)) { + neigh_invalidate(neigh); + notify = 1; + } + goto out; + } + + /* Compare new lladdr with cached one */ + if (!dev->addr_len) { + /* First case: device needs no address. */ + lladdr = neigh->ha; + } else if (lladdr) { + /* The second case: if something is already cached + and a new address is proposed: + - compare new & old + - if they are different, check override flag + */ + if ((old & NUD_VALID) && + !memcmp(lladdr, neigh->ha, dev->addr_len)) + lladdr = neigh->ha; + } else { + /* No address is supplied; if we know something, + use it, otherwise discard the request. + */ + err = -EINVAL; + if (!(old & NUD_VALID)) { + NL_SET_ERR_MSG(extack, "No link layer address given"); + goto out; + } + lladdr = neigh->ha; + } + + /* Update confirmed timestamp for neighbour entry after we + * received ARP packet even if it doesn't change IP to MAC binding. + */ + if (new & NUD_CONNECTED) + neigh->confirmed = jiffies; + + /* If entry was valid and address is not changed, + do not change entry state, if new one is STALE. + */ + err = 0; + update_isrouter = flags & NEIGH_UPDATE_F_OVERRIDE_ISROUTER; + if (old & NUD_VALID) { + if (lladdr != neigh->ha && !(flags & NEIGH_UPDATE_F_OVERRIDE)) { + update_isrouter = 0; + if ((flags & NEIGH_UPDATE_F_WEAK_OVERRIDE) && + (old & NUD_CONNECTED)) { + lladdr = neigh->ha; + new = NUD_STALE; + } else + goto out; + } else { + if (lladdr == neigh->ha && new == NUD_STALE && + !(flags & NEIGH_UPDATE_F_ADMIN)) + new = old; + } + } + + /* Update timestamp only once we know we will make a change to the + * neighbour entry. Otherwise we risk to move the locktime window with + * noop updates and ignore relevant ARP updates. + */ + if (new != old || lladdr != neigh->ha) + neigh->updated = jiffies; + + if (new != old) { + neigh_del_timer(neigh); + if (new & NUD_PROBE) + atomic_set(&neigh->probes, 0); + if (new & NUD_IN_TIMER) + neigh_add_timer(neigh, (jiffies + + ((new & NUD_REACHABLE) ? 
+ neigh->parms->reachable_time : + 0))); + WRITE_ONCE(neigh->nud_state, new); + notify = 1; + } + + if (lladdr != neigh->ha) { + write_seqlock(&neigh->ha_lock); + memcpy(&neigh->ha, lladdr, dev->addr_len); + write_sequnlock(&neigh->ha_lock); + neigh_update_hhs(neigh); + if (!(new & NUD_CONNECTED)) + neigh->confirmed = jiffies - + (NEIGH_VAR(neigh->parms, BASE_REACHABLE_TIME) << 1); + notify = 1; + } + if (new == old) + goto out; + if (new & NUD_CONNECTED) + neigh_connect(neigh); + else + neigh_suspect(neigh); + if (!(old & NUD_VALID)) { + struct sk_buff *skb; + + /* Again: avoid dead loop if something went wrong */ + + while (neigh->nud_state & NUD_VALID && + (skb = __skb_dequeue(&neigh->arp_queue)) != NULL) { + struct dst_entry *dst = skb_dst(skb); + struct neighbour *n2, *n1 = neigh; + write_unlock_bh(&neigh->lock); + + rcu_read_lock(); + + /* Why not just use 'neigh' as-is? The problem is that + * things such as shaper, eql, and sch_teql can end up + * using alternative, different, neigh objects to output + * the packet in the output path. So what we need to do + * here is re-lookup the top-level neigh in the path so + * we can reinject the packet there. + */ + n2 = NULL; + if (dst && dst->obsolete != DST_OBSOLETE_DEAD) { + n2 = dst_neigh_lookup_skb(dst, skb); + if (n2) + n1 = n2; + } + READ_ONCE(n1->output)(n1, skb); + if (n2) + neigh_release(n2); + rcu_read_unlock(); + + write_lock_bh(&neigh->lock); + } + __skb_queue_purge(&neigh->arp_queue); + neigh->arp_queue_len_bytes = 0; + } +out: + if (update_isrouter) + neigh_update_is_router(neigh, flags, ¬ify); + write_unlock_bh(&neigh->lock); + if (((new ^ old) & NUD_PERMANENT) || gc_update) + neigh_update_gc_list(neigh); + if (managed_update) + neigh_update_managed_list(neigh); + if (notify) + neigh_update_notify(neigh, nlmsg_pid); + trace_neigh_update_done(neigh, err); + return err; +} + +int neigh_update(struct neighbour *neigh, const u8 *lladdr, u8 new, + u32 flags, u32 nlmsg_pid) +{ + return __neigh_update(neigh, lladdr, new, flags, nlmsg_pid, NULL); +} +EXPORT_SYMBOL(neigh_update); + +/* Update the neigh to listen temporarily for probe responses, even if it is + * in a NUD_FAILED state. The caller has to hold neigh->lock for writing. + */ +void __neigh_set_probe_once(struct neighbour *neigh) +{ + if (neigh->dead) + return; + neigh->updated = jiffies; + if (!(neigh->nud_state & NUD_FAILED)) + return; + WRITE_ONCE(neigh->nud_state, NUD_INCOMPLETE); + atomic_set(&neigh->probes, neigh_max_probes(neigh)); + neigh_add_timer(neigh, + jiffies + max(NEIGH_VAR(neigh->parms, RETRANS_TIME), + HZ/100)); +} +EXPORT_SYMBOL(__neigh_set_probe_once); + +struct neighbour *neigh_event_ns(struct neigh_table *tbl, + u8 *lladdr, void *saddr, + struct net_device *dev) +{ + struct neighbour *neigh = __neigh_lookup(tbl, saddr, dev, + lladdr || !dev->addr_len); + if (neigh) + neigh_update(neigh, lladdr, NUD_STALE, + NEIGH_UPDATE_F_OVERRIDE, 0); + return neigh; +} +EXPORT_SYMBOL(neigh_event_ns); + +/* called with read_lock_bh(&n->lock); */ +static void neigh_hh_init(struct neighbour *n) +{ + struct net_device *dev = n->dev; + __be16 prot = n->tbl->protocol; + struct hh_cache *hh = &n->hh; + + write_lock_bh(&n->lock); + + /* Only one thread can come in here and initialize the + * hh_cache entry. + */ + if (!hh->hh_len) + dev->header_ops->cache(n, hh, prot); + + write_unlock_bh(&n->lock); +} + +/* Slow and careful. 
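   Typically plugged into a protocol's neigh_ops, roughly like this (a sketch
   with made-up names; the solicit and error_report handlers are assumed to
   be supplied by the protocol, as ARP and NDISC do):

	static const struct neigh_ops my_neigh_ops = {
		.family           = AF_INET,
		.solicit          = my_solicit,
		.error_report     = my_error_report,
		.output           = neigh_resolve_output,
		.connected_output = neigh_connected_output,
	};

   neigh_connect() and neigh_suspect() switch n->output between
   ops->connected_output and ops->output as the NUD state changes, and
   neigh_direct_output() below simply hands the skb to dev_queue_xmit()
   for devices that need no resolution at all.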
*/ + +int neigh_resolve_output(struct neighbour *neigh, struct sk_buff *skb) +{ + int rc = 0; + + if (!neigh_event_send(neigh, skb)) { + int err; + struct net_device *dev = neigh->dev; + unsigned int seq; + + if (dev->header_ops->cache && !READ_ONCE(neigh->hh.hh_len)) + neigh_hh_init(neigh); + + do { + __skb_pull(skb, skb_network_offset(skb)); + seq = read_seqbegin(&neigh->ha_lock); + err = dev_hard_header(skb, dev, ntohs(skb->protocol), + neigh->ha, NULL, skb->len); + } while (read_seqretry(&neigh->ha_lock, seq)); + + if (err >= 0) + rc = dev_queue_xmit(skb); + else + goto out_kfree_skb; + } +out: + return rc; +out_kfree_skb: + rc = -EINVAL; + kfree_skb(skb); + goto out; +} +EXPORT_SYMBOL(neigh_resolve_output); + +/* As fast as possible without hh cache */ + +int neigh_connected_output(struct neighbour *neigh, struct sk_buff *skb) +{ + struct net_device *dev = neigh->dev; + unsigned int seq; + int err; + + do { + __skb_pull(skb, skb_network_offset(skb)); + seq = read_seqbegin(&neigh->ha_lock); + err = dev_hard_header(skb, dev, ntohs(skb->protocol), + neigh->ha, NULL, skb->len); + } while (read_seqretry(&neigh->ha_lock, seq)); + + if (err >= 0) + err = dev_queue_xmit(skb); + else { + err = -EINVAL; + kfree_skb(skb); + } + return err; +} +EXPORT_SYMBOL(neigh_connected_output); + +int neigh_direct_output(struct neighbour *neigh, struct sk_buff *skb) +{ + return dev_queue_xmit(skb); +} +EXPORT_SYMBOL(neigh_direct_output); + +static void neigh_managed_work(struct work_struct *work) +{ + struct neigh_table *tbl = container_of(work, struct neigh_table, + managed_work.work); + struct neighbour *neigh; + + write_lock_bh(&tbl->lock); + list_for_each_entry(neigh, &tbl->managed_list, managed_list) + neigh_event_send_probe(neigh, NULL, false); + queue_delayed_work(system_power_efficient_wq, &tbl->managed_work, + NEIGH_VAR(&tbl->parms, INTERVAL_PROBE_TIME_MS)); + write_unlock_bh(&tbl->lock); +} + +static void neigh_proxy_process(struct timer_list *t) +{ + struct neigh_table *tbl = from_timer(tbl, t, proxy_timer); + long sched_next = 0; + unsigned long now = jiffies; + struct sk_buff *skb, *n; + + spin_lock(&tbl->proxy_queue.lock); + + skb_queue_walk_safe(&tbl->proxy_queue, skb, n) { + long tdif = NEIGH_CB(skb)->sched_next - now; + + if (tdif <= 0) { + struct net_device *dev = skb->dev; + + neigh_parms_qlen_dec(dev, tbl->family); + __skb_unlink(skb, &tbl->proxy_queue); + + if (tbl->proxy_redo && netif_running(dev)) { + rcu_read_lock(); + tbl->proxy_redo(skb); + rcu_read_unlock(); + } else { + kfree_skb(skb); + } + + dev_put(dev); + } else if (!sched_next || tdif < sched_next) + sched_next = tdif; + } + del_timer(&tbl->proxy_timer); + if (sched_next) + mod_timer(&tbl->proxy_timer, jiffies + sched_next); + spin_unlock(&tbl->proxy_queue.lock); +} + +void pneigh_enqueue(struct neigh_table *tbl, struct neigh_parms *p, + struct sk_buff *skb) +{ + unsigned long sched_next = jiffies + + prandom_u32_max(NEIGH_VAR(p, PROXY_DELAY)); + + if (p->qlen > NEIGH_VAR(p, PROXY_QLEN)) { + kfree_skb(skb); + return; + } + + NEIGH_CB(skb)->sched_next = sched_next; + NEIGH_CB(skb)->flags |= LOCALLY_ENQUEUED; + + spin_lock(&tbl->proxy_queue.lock); + if (del_timer(&tbl->proxy_timer)) { + if (time_before(tbl->proxy_timer.expires, sched_next)) + sched_next = tbl->proxy_timer.expires; + } + skb_dst_drop(skb); + dev_hold(skb->dev); + __skb_queue_tail(&tbl->proxy_queue, skb); + p->qlen++; + mod_timer(&tbl->proxy_timer, sched_next); + spin_unlock(&tbl->proxy_queue.lock); +} +EXPORT_SYMBOL(pneigh_enqueue); + +static inline struct 
neigh_parms *lookup_neigh_parms(struct neigh_table *tbl, + struct net *net, int ifindex) +{ + struct neigh_parms *p; + + list_for_each_entry(p, &tbl->parms_list, list) { + if ((p->dev && p->dev->ifindex == ifindex && net_eq(neigh_parms_net(p), net)) || + (!p->dev && !ifindex && net_eq(net, &init_net))) + return p; + } + + return NULL; +} + +struct neigh_parms *neigh_parms_alloc(struct net_device *dev, + struct neigh_table *tbl) +{ + struct neigh_parms *p; + struct net *net = dev_net(dev); + const struct net_device_ops *ops = dev->netdev_ops; + + p = kmemdup(&tbl->parms, sizeof(*p), GFP_KERNEL); + if (p) { + p->tbl = tbl; + refcount_set(&p->refcnt, 1); + p->reachable_time = + neigh_rand_reach_time(NEIGH_VAR(p, BASE_REACHABLE_TIME)); + p->qlen = 0; + netdev_hold(dev, &p->dev_tracker, GFP_KERNEL); + p->dev = dev; + write_pnet(&p->net, net); + p->sysctl_table = NULL; + + if (ops->ndo_neigh_setup && ops->ndo_neigh_setup(dev, p)) { + netdev_put(dev, &p->dev_tracker); + kfree(p); + return NULL; + } + + write_lock_bh(&tbl->lock); + list_add(&p->list, &tbl->parms.list); + write_unlock_bh(&tbl->lock); + + neigh_parms_data_state_cleanall(p); + } + return p; +} +EXPORT_SYMBOL(neigh_parms_alloc); + +static void neigh_rcu_free_parms(struct rcu_head *head) +{ + struct neigh_parms *parms = + container_of(head, struct neigh_parms, rcu_head); + + neigh_parms_put(parms); +} + +void neigh_parms_release(struct neigh_table *tbl, struct neigh_parms *parms) +{ + if (!parms || parms == &tbl->parms) + return; + write_lock_bh(&tbl->lock); + list_del(&parms->list); + parms->dead = 1; + write_unlock_bh(&tbl->lock); + netdev_put(parms->dev, &parms->dev_tracker); + call_rcu(&parms->rcu_head, neigh_rcu_free_parms); +} +EXPORT_SYMBOL(neigh_parms_release); + +static void neigh_parms_destroy(struct neigh_parms *parms) +{ + kfree(parms); +} + +static struct lock_class_key neigh_table_proxy_queue_class; + +static struct neigh_table *neigh_tables[NEIGH_NR_TABLES] __read_mostly; + +void neigh_table_init(int index, struct neigh_table *tbl) +{ + unsigned long now = jiffies; + unsigned long phsize; + + INIT_LIST_HEAD(&tbl->parms_list); + INIT_LIST_HEAD(&tbl->gc_list); + INIT_LIST_HEAD(&tbl->managed_list); + + list_add(&tbl->parms.list, &tbl->parms_list); + write_pnet(&tbl->parms.net, &init_net); + refcount_set(&tbl->parms.refcnt, 1); + tbl->parms.reachable_time = + neigh_rand_reach_time(NEIGH_VAR(&tbl->parms, BASE_REACHABLE_TIME)); + tbl->parms.qlen = 0; + + tbl->stats = alloc_percpu(struct neigh_statistics); + if (!tbl->stats) + panic("cannot create neighbour cache statistics"); + +#ifdef CONFIG_PROC_FS + if (!proc_create_seq_data(tbl->id, 0, init_net.proc_net_stat, + &neigh_stat_seq_ops, tbl)) + panic("cannot create neighbour proc dir entry"); +#endif + + RCU_INIT_POINTER(tbl->nht, neigh_hash_alloc(3)); + + phsize = (PNEIGH_HASHMASK + 1) * sizeof(struct pneigh_entry *); + tbl->phash_buckets = kzalloc(phsize, GFP_KERNEL); + + if (!tbl->nht || !tbl->phash_buckets) + panic("cannot allocate neighbour cache hashes"); + + if (!tbl->entry_size) + tbl->entry_size = ALIGN(offsetof(struct neighbour, primary_key) + + tbl->key_len, NEIGH_PRIV_ALIGN); + else + WARN_ON(tbl->entry_size % NEIGH_PRIV_ALIGN); + + rwlock_init(&tbl->lock); + + INIT_DEFERRABLE_WORK(&tbl->gc_work, neigh_periodic_work); + queue_delayed_work(system_power_efficient_wq, &tbl->gc_work, + tbl->parms.reachable_time); + INIT_DEFERRABLE_WORK(&tbl->managed_work, neigh_managed_work); + queue_delayed_work(system_power_efficient_wq, &tbl->managed_work, 0); + + 
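	/*
	 * Per-device tuning hangs off the table initialised here; a
	 * minimal sketch of a protocol's device setup and teardown path
	 * (illustrative names only, "my_tbl" standing for a table that
	 * was registered with neigh_table_init()):
	 *
	 *	p = neigh_parms_alloc(dev, &my_tbl);
	 *	if (!p)
	 *		return -ENOBUFS;
	 *	NEIGH_VAR_SET(p, UCAST_PROBES, 5);	and other per-device knobs
	 *	...
	 *	neigh_parms_release(&my_tbl, p);	on teardown, freed via RCU above
	 *
	 * neigh_parms_alloc() copies tbl->parms, holds the device and
	 * links the new block into the table's parms_list under tbl->lock,
	 * which is what lookup_neigh_parms() and the table dump code walk.
	 */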
timer_setup(&tbl->proxy_timer, neigh_proxy_process, 0); + skb_queue_head_init_class(&tbl->proxy_queue, + &neigh_table_proxy_queue_class); + + tbl->last_flush = now; + tbl->last_rand = now + tbl->parms.reachable_time * 20; + + neigh_tables[index] = tbl; +} +EXPORT_SYMBOL(neigh_table_init); + +int neigh_table_clear(int index, struct neigh_table *tbl) +{ + neigh_tables[index] = NULL; + /* It is not clean... Fix it to unload IPv6 module safely */ + cancel_delayed_work_sync(&tbl->managed_work); + cancel_delayed_work_sync(&tbl->gc_work); + del_timer_sync(&tbl->proxy_timer); + pneigh_queue_purge(&tbl->proxy_queue, NULL, tbl->family); + neigh_ifdown(tbl, NULL); + if (atomic_read(&tbl->entries)) + pr_crit("neighbour leakage\n"); + + call_rcu(&rcu_dereference_protected(tbl->nht, 1)->rcu, + neigh_hash_free_rcu); + tbl->nht = NULL; + + kfree(tbl->phash_buckets); + tbl->phash_buckets = NULL; + + remove_proc_entry(tbl->id, init_net.proc_net_stat); + + free_percpu(tbl->stats); + tbl->stats = NULL; + + return 0; +} +EXPORT_SYMBOL(neigh_table_clear); + +static struct neigh_table *neigh_find_table(int family) +{ + struct neigh_table *tbl = NULL; + + switch (family) { + case AF_INET: + tbl = neigh_tables[NEIGH_ARP_TABLE]; + break; + case AF_INET6: + tbl = neigh_tables[NEIGH_ND_TABLE]; + break; + } + + return tbl; +} + +const struct nla_policy nda_policy[NDA_MAX+1] = { + [NDA_UNSPEC] = { .strict_start_type = NDA_NH_ID }, + [NDA_DST] = { .type = NLA_BINARY, .len = MAX_ADDR_LEN }, + [NDA_LLADDR] = { .type = NLA_BINARY, .len = MAX_ADDR_LEN }, + [NDA_CACHEINFO] = { .len = sizeof(struct nda_cacheinfo) }, + [NDA_PROBES] = { .type = NLA_U32 }, + [NDA_VLAN] = { .type = NLA_U16 }, + [NDA_PORT] = { .type = NLA_U16 }, + [NDA_VNI] = { .type = NLA_U32 }, + [NDA_IFINDEX] = { .type = NLA_U32 }, + [NDA_MASTER] = { .type = NLA_U32 }, + [NDA_PROTOCOL] = { .type = NLA_U8 }, + [NDA_NH_ID] = { .type = NLA_U32 }, + [NDA_FLAGS_EXT] = NLA_POLICY_MASK(NLA_U32, NTF_EXT_MASK), + [NDA_FDB_EXT_ATTRS] = { .type = NLA_NESTED }, +}; + +static int neigh_delete(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ndmsg *ndm; + struct nlattr *dst_attr; + struct neigh_table *tbl; + struct neighbour *neigh; + struct net_device *dev = NULL; + int err = -EINVAL; + + ASSERT_RTNL(); + if (nlmsg_len(nlh) < sizeof(*ndm)) + goto out; + + dst_attr = nlmsg_find_attr(nlh, sizeof(*ndm), NDA_DST); + if (!dst_attr) { + NL_SET_ERR_MSG(extack, "Network address not specified"); + goto out; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_ifindex) { + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (dev == NULL) { + err = -ENODEV; + goto out; + } + } + + tbl = neigh_find_table(ndm->ndm_family); + if (tbl == NULL) + return -EAFNOSUPPORT; + + if (nla_len(dst_attr) < (int)tbl->key_len) { + NL_SET_ERR_MSG(extack, "Invalid network address"); + goto out; + } + + if (ndm->ndm_flags & NTF_PROXY) { + err = pneigh_delete(tbl, net, nla_data(dst_attr), dev); + goto out; + } + + if (dev == NULL) + goto out; + + neigh = neigh_lookup(tbl, nla_data(dst_attr), dev); + if (neigh == NULL) { + err = -ENOENT; + goto out; + } + + err = __neigh_update(neigh, NULL, NUD_FAILED, + NEIGH_UPDATE_F_OVERRIDE | NEIGH_UPDATE_F_ADMIN, + NETLINK_CB(skb).portid, extack); + write_lock_bh(&tbl->lock); + neigh_release(neigh); + neigh_remove_one(neigh, tbl); + write_unlock_bh(&tbl->lock); + +out: + return err; +} + +static int neigh_add(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + int 
flags = NEIGH_UPDATE_F_ADMIN | NEIGH_UPDATE_F_OVERRIDE | + NEIGH_UPDATE_F_OVERRIDE_ISROUTER; + struct net *net = sock_net(skb->sk); + struct ndmsg *ndm; + struct nlattr *tb[NDA_MAX+1]; + struct neigh_table *tbl; + struct net_device *dev = NULL; + struct neighbour *neigh; + void *dst, *lladdr; + u8 protocol = 0; + u32 ndm_flags; + int err; + + ASSERT_RTNL(); + err = nlmsg_parse_deprecated(nlh, sizeof(*ndm), tb, NDA_MAX, + nda_policy, extack); + if (err < 0) + goto out; + + err = -EINVAL; + if (!tb[NDA_DST]) { + NL_SET_ERR_MSG(extack, "Network address not specified"); + goto out; + } + + ndm = nlmsg_data(nlh); + ndm_flags = ndm->ndm_flags; + if (tb[NDA_FLAGS_EXT]) { + u32 ext = nla_get_u32(tb[NDA_FLAGS_EXT]); + + BUILD_BUG_ON(sizeof(neigh->flags) * BITS_PER_BYTE < + (sizeof(ndm->ndm_flags) * BITS_PER_BYTE + + hweight32(NTF_EXT_MASK))); + ndm_flags |= (ext << NTF_EXT_SHIFT); + } + if (ndm->ndm_ifindex) { + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (dev == NULL) { + err = -ENODEV; + goto out; + } + + if (tb[NDA_LLADDR] && nla_len(tb[NDA_LLADDR]) < dev->addr_len) { + NL_SET_ERR_MSG(extack, "Invalid link address"); + goto out; + } + } + + tbl = neigh_find_table(ndm->ndm_family); + if (tbl == NULL) + return -EAFNOSUPPORT; + + if (nla_len(tb[NDA_DST]) < (int)tbl->key_len) { + NL_SET_ERR_MSG(extack, "Invalid network address"); + goto out; + } + + dst = nla_data(tb[NDA_DST]); + lladdr = tb[NDA_LLADDR] ? nla_data(tb[NDA_LLADDR]) : NULL; + + if (tb[NDA_PROTOCOL]) + protocol = nla_get_u8(tb[NDA_PROTOCOL]); + if (ndm_flags & NTF_PROXY) { + struct pneigh_entry *pn; + + if (ndm_flags & NTF_MANAGED) { + NL_SET_ERR_MSG(extack, "Invalid NTF_* flag combination"); + goto out; + } + + err = -ENOBUFS; + pn = pneigh_lookup(tbl, net, dst, dev, 1); + if (pn) { + pn->flags = ndm_flags; + if (protocol) + pn->protocol = protocol; + err = 0; + } + goto out; + } + + if (!dev) { + NL_SET_ERR_MSG(extack, "Device not specified"); + goto out; + } + + if (tbl->allow_add && !tbl->allow_add(dev, extack)) { + err = -EINVAL; + goto out; + } + + neigh = neigh_lookup(tbl, dst, dev); + if (neigh == NULL) { + bool ndm_permanent = ndm->ndm_state & NUD_PERMANENT; + bool exempt_from_gc = ndm_permanent || + ndm_flags & NTF_EXT_LEARNED; + + if (!(nlh->nlmsg_flags & NLM_F_CREATE)) { + err = -ENOENT; + goto out; + } + if (ndm_permanent && (ndm_flags & NTF_MANAGED)) { + NL_SET_ERR_MSG(extack, "Invalid NTF_* flag for permanent entry"); + err = -EINVAL; + goto out; + } + + neigh = ___neigh_create(tbl, dst, dev, + ndm_flags & + (NTF_EXT_LEARNED | NTF_MANAGED), + exempt_from_gc, true); + if (IS_ERR(neigh)) { + err = PTR_ERR(neigh); + goto out; + } + } else { + if (nlh->nlmsg_flags & NLM_F_EXCL) { + err = -EEXIST; + neigh_release(neigh); + goto out; + } + + if (!(nlh->nlmsg_flags & NLM_F_REPLACE)) + flags &= ~(NEIGH_UPDATE_F_OVERRIDE | + NEIGH_UPDATE_F_OVERRIDE_ISROUTER); + } + + if (protocol) + neigh->protocol = protocol; + if (ndm_flags & NTF_EXT_LEARNED) + flags |= NEIGH_UPDATE_F_EXT_LEARNED; + if (ndm_flags & NTF_ROUTER) + flags |= NEIGH_UPDATE_F_ISROUTER; + if (ndm_flags & NTF_MANAGED) + flags |= NEIGH_UPDATE_F_MANAGED; + if (ndm_flags & NTF_USE) + flags |= NEIGH_UPDATE_F_USE; + + err = __neigh_update(neigh, lladdr, ndm->ndm_state, flags, + NETLINK_CB(skb).portid, extack); + if (!err && ndm_flags & (NTF_USE | NTF_MANAGED)) { + neigh_event_send(neigh, NULL); + err = 0; + } + neigh_release(neigh); +out: + return err; +} + +static int neightbl_fill_parms(struct sk_buff *skb, struct neigh_parms *parms) +{ + struct nlattr *nest; + + 
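	/*
	 * Note on the deprecated packet-based knob dumped below: only
	 * QUEUE_LEN_BYTES is stored, so the value reported as
	 * NDTPA_QUEUE_LEN is derived (and truncated) as
	 *
	 *	QUEUE_LEN = QUEUE_LEN_BYTES / SKB_TRUESIZE(ETH_FRAME_LEN)
	 *
	 * while neightbl_set() converts a written NDTPA_QUEUE_LEN the
	 * other way, QUEUE_LEN_BYTES = n * SKB_TRUESIZE(ETH_FRAME_LEN).
	 * The round trip is only approximate because unresolved queues
	 * are really accounted in skb->truesize bytes, not packets.
	 */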
nest = nla_nest_start_noflag(skb, NDTA_PARMS); + if (nest == NULL) + return -ENOBUFS; + + if ((parms->dev && + nla_put_u32(skb, NDTPA_IFINDEX, parms->dev->ifindex)) || + nla_put_u32(skb, NDTPA_REFCNT, refcount_read(&parms->refcnt)) || + nla_put_u32(skb, NDTPA_QUEUE_LENBYTES, + NEIGH_VAR(parms, QUEUE_LEN_BYTES)) || + /* approximative value for deprecated QUEUE_LEN (in packets) */ + nla_put_u32(skb, NDTPA_QUEUE_LEN, + NEIGH_VAR(parms, QUEUE_LEN_BYTES) / SKB_TRUESIZE(ETH_FRAME_LEN)) || + nla_put_u32(skb, NDTPA_PROXY_QLEN, NEIGH_VAR(parms, PROXY_QLEN)) || + nla_put_u32(skb, NDTPA_APP_PROBES, NEIGH_VAR(parms, APP_PROBES)) || + nla_put_u32(skb, NDTPA_UCAST_PROBES, + NEIGH_VAR(parms, UCAST_PROBES)) || + nla_put_u32(skb, NDTPA_MCAST_PROBES, + NEIGH_VAR(parms, MCAST_PROBES)) || + nla_put_u32(skb, NDTPA_MCAST_REPROBES, + NEIGH_VAR(parms, MCAST_REPROBES)) || + nla_put_msecs(skb, NDTPA_REACHABLE_TIME, parms->reachable_time, + NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_BASE_REACHABLE_TIME, + NEIGH_VAR(parms, BASE_REACHABLE_TIME), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_GC_STALETIME, + NEIGH_VAR(parms, GC_STALETIME), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_DELAY_PROBE_TIME, + NEIGH_VAR(parms, DELAY_PROBE_TIME), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_RETRANS_TIME, + NEIGH_VAR(parms, RETRANS_TIME), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_ANYCAST_DELAY, + NEIGH_VAR(parms, ANYCAST_DELAY), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_PROXY_DELAY, + NEIGH_VAR(parms, PROXY_DELAY), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_LOCKTIME, + NEIGH_VAR(parms, LOCKTIME), NDTPA_PAD) || + nla_put_msecs(skb, NDTPA_INTERVAL_PROBE_TIME_MS, + NEIGH_VAR(parms, INTERVAL_PROBE_TIME_MS), NDTPA_PAD)) + goto nla_put_failure; + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int neightbl_fill_info(struct sk_buff *skb, struct neigh_table *tbl, + u32 pid, u32 seq, int type, int flags) +{ + struct nlmsghdr *nlh; + struct ndtmsg *ndtmsg; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ndtmsg), flags); + if (nlh == NULL) + return -EMSGSIZE; + + ndtmsg = nlmsg_data(nlh); + + read_lock_bh(&tbl->lock); + ndtmsg->ndtm_family = tbl->family; + ndtmsg->ndtm_pad1 = 0; + ndtmsg->ndtm_pad2 = 0; + + if (nla_put_string(skb, NDTA_NAME, tbl->id) || + nla_put_msecs(skb, NDTA_GC_INTERVAL, READ_ONCE(tbl->gc_interval), + NDTA_PAD) || + nla_put_u32(skb, NDTA_THRESH1, READ_ONCE(tbl->gc_thresh1)) || + nla_put_u32(skb, NDTA_THRESH2, READ_ONCE(tbl->gc_thresh2)) || + nla_put_u32(skb, NDTA_THRESH3, READ_ONCE(tbl->gc_thresh3))) + goto nla_put_failure; + { + unsigned long now = jiffies; + long flush_delta = now - READ_ONCE(tbl->last_flush); + long rand_delta = now - READ_ONCE(tbl->last_rand); + struct neigh_hash_table *nht; + struct ndt_config ndc = { + .ndtc_key_len = tbl->key_len, + .ndtc_entry_size = tbl->entry_size, + .ndtc_entries = atomic_read(&tbl->entries), + .ndtc_last_flush = jiffies_to_msecs(flush_delta), + .ndtc_last_rand = jiffies_to_msecs(rand_delta), + .ndtc_proxy_qlen = READ_ONCE(tbl->proxy_queue.qlen), + }; + + rcu_read_lock(); + nht = rcu_dereference(tbl->nht); + ndc.ndtc_hash_rnd = nht->hash_rnd[0]; + ndc.ndtc_hash_mask = ((1 << nht->hash_shift) - 1); + rcu_read_unlock(); + + if (nla_put(skb, NDTA_CONFIG, sizeof(ndc), &ndc)) + goto nla_put_failure; + } + + { + int cpu; + struct ndt_stats ndst; + + memset(&ndst, 0, sizeof(ndst)); + + for_each_possible_cpu(cpu) { + struct neigh_statistics *st; + + st = per_cpu_ptr(tbl->stats, cpu); + ndst.ndts_allocs += READ_ONCE(st->allocs); + 
ndst.ndts_destroys += READ_ONCE(st->destroys); + ndst.ndts_hash_grows += READ_ONCE(st->hash_grows); + ndst.ndts_res_failed += READ_ONCE(st->res_failed); + ndst.ndts_lookups += READ_ONCE(st->lookups); + ndst.ndts_hits += READ_ONCE(st->hits); + ndst.ndts_rcv_probes_mcast += READ_ONCE(st->rcv_probes_mcast); + ndst.ndts_rcv_probes_ucast += READ_ONCE(st->rcv_probes_ucast); + ndst.ndts_periodic_gc_runs += READ_ONCE(st->periodic_gc_runs); + ndst.ndts_forced_gc_runs += READ_ONCE(st->forced_gc_runs); + ndst.ndts_table_fulls += READ_ONCE(st->table_fulls); + } + + if (nla_put_64bit(skb, NDTA_STATS, sizeof(ndst), &ndst, + NDTA_PAD)) + goto nla_put_failure; + } + + BUG_ON(tbl->parms.dev); + if (neightbl_fill_parms(skb, &tbl->parms) < 0) + goto nla_put_failure; + + read_unlock_bh(&tbl->lock); + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + read_unlock_bh(&tbl->lock); + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int neightbl_fill_param_info(struct sk_buff *skb, + struct neigh_table *tbl, + struct neigh_parms *parms, + u32 pid, u32 seq, int type, + unsigned int flags) +{ + struct ndtmsg *ndtmsg; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ndtmsg), flags); + if (nlh == NULL) + return -EMSGSIZE; + + ndtmsg = nlmsg_data(nlh); + + read_lock_bh(&tbl->lock); + ndtmsg->ndtm_family = tbl->family; + ndtmsg->ndtm_pad1 = 0; + ndtmsg->ndtm_pad2 = 0; + + if (nla_put_string(skb, NDTA_NAME, tbl->id) < 0 || + neightbl_fill_parms(skb, parms) < 0) + goto errout; + + read_unlock_bh(&tbl->lock); + nlmsg_end(skb, nlh); + return 0; +errout: + read_unlock_bh(&tbl->lock); + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static const struct nla_policy nl_neightbl_policy[NDTA_MAX+1] = { + [NDTA_NAME] = { .type = NLA_STRING }, + [NDTA_THRESH1] = { .type = NLA_U32 }, + [NDTA_THRESH2] = { .type = NLA_U32 }, + [NDTA_THRESH3] = { .type = NLA_U32 }, + [NDTA_GC_INTERVAL] = { .type = NLA_U64 }, + [NDTA_PARMS] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy nl_ntbl_parm_policy[NDTPA_MAX+1] = { + [NDTPA_IFINDEX] = { .type = NLA_U32 }, + [NDTPA_QUEUE_LEN] = { .type = NLA_U32 }, + [NDTPA_PROXY_QLEN] = { .type = NLA_U32 }, + [NDTPA_APP_PROBES] = { .type = NLA_U32 }, + [NDTPA_UCAST_PROBES] = { .type = NLA_U32 }, + [NDTPA_MCAST_PROBES] = { .type = NLA_U32 }, + [NDTPA_MCAST_REPROBES] = { .type = NLA_U32 }, + [NDTPA_BASE_REACHABLE_TIME] = { .type = NLA_U64 }, + [NDTPA_GC_STALETIME] = { .type = NLA_U64 }, + [NDTPA_DELAY_PROBE_TIME] = { .type = NLA_U64 }, + [NDTPA_RETRANS_TIME] = { .type = NLA_U64 }, + [NDTPA_ANYCAST_DELAY] = { .type = NLA_U64 }, + [NDTPA_PROXY_DELAY] = { .type = NLA_U64 }, + [NDTPA_LOCKTIME] = { .type = NLA_U64 }, + [NDTPA_INTERVAL_PROBE_TIME_MS] = { .type = NLA_U64, .min = 1 }, +}; + +static int neightbl_set(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct neigh_table *tbl; + struct ndtmsg *ndtmsg; + struct nlattr *tb[NDTA_MAX+1]; + bool found = false; + int err, tidx; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ndtmsg), tb, NDTA_MAX, + nl_neightbl_policy, extack); + if (err < 0) + goto errout; + + if (tb[NDTA_NAME] == NULL) { + err = -EINVAL; + goto errout; + } + + ndtmsg = nlmsg_data(nlh); + + for (tidx = 0; tidx < NEIGH_NR_TABLES; tidx++) { + tbl = neigh_tables[tidx]; + if (!tbl) + continue; + if (ndtmsg->ndtm_family && tbl->family != ndtmsg->ndtm_family) + continue; + if (nla_strcmp(tb[NDTA_NAME], tbl->id) == 0) { + found = true; + break; + } + } + + if (!found) + return 
-ENOENT; + + /* + * We acquire tbl->lock to be nice to the periodic timers and + * make sure they always see a consistent set of values. + */ + write_lock_bh(&tbl->lock); + + if (tb[NDTA_PARMS]) { + struct nlattr *tbp[NDTPA_MAX+1]; + struct neigh_parms *p; + int i, ifindex = 0; + + err = nla_parse_nested_deprecated(tbp, NDTPA_MAX, + tb[NDTA_PARMS], + nl_ntbl_parm_policy, extack); + if (err < 0) + goto errout_tbl_lock; + + if (tbp[NDTPA_IFINDEX]) + ifindex = nla_get_u32(tbp[NDTPA_IFINDEX]); + + p = lookup_neigh_parms(tbl, net, ifindex); + if (p == NULL) { + err = -ENOENT; + goto errout_tbl_lock; + } + + for (i = 1; i <= NDTPA_MAX; i++) { + if (tbp[i] == NULL) + continue; + + switch (i) { + case NDTPA_QUEUE_LEN: + NEIGH_VAR_SET(p, QUEUE_LEN_BYTES, + nla_get_u32(tbp[i]) * + SKB_TRUESIZE(ETH_FRAME_LEN)); + break; + case NDTPA_QUEUE_LENBYTES: + NEIGH_VAR_SET(p, QUEUE_LEN_BYTES, + nla_get_u32(tbp[i])); + break; + case NDTPA_PROXY_QLEN: + NEIGH_VAR_SET(p, PROXY_QLEN, + nla_get_u32(tbp[i])); + break; + case NDTPA_APP_PROBES: + NEIGH_VAR_SET(p, APP_PROBES, + nla_get_u32(tbp[i])); + break; + case NDTPA_UCAST_PROBES: + NEIGH_VAR_SET(p, UCAST_PROBES, + nla_get_u32(tbp[i])); + break; + case NDTPA_MCAST_PROBES: + NEIGH_VAR_SET(p, MCAST_PROBES, + nla_get_u32(tbp[i])); + break; + case NDTPA_MCAST_REPROBES: + NEIGH_VAR_SET(p, MCAST_REPROBES, + nla_get_u32(tbp[i])); + break; + case NDTPA_BASE_REACHABLE_TIME: + NEIGH_VAR_SET(p, BASE_REACHABLE_TIME, + nla_get_msecs(tbp[i])); + /* update reachable_time as well, otherwise, the change will + * only be effective after the next time neigh_periodic_work + * decides to recompute it (can be multiple minutes) + */ + p->reachable_time = + neigh_rand_reach_time(NEIGH_VAR(p, BASE_REACHABLE_TIME)); + break; + case NDTPA_GC_STALETIME: + NEIGH_VAR_SET(p, GC_STALETIME, + nla_get_msecs(tbp[i])); + break; + case NDTPA_DELAY_PROBE_TIME: + NEIGH_VAR_SET(p, DELAY_PROBE_TIME, + nla_get_msecs(tbp[i])); + call_netevent_notifiers(NETEVENT_DELAY_PROBE_TIME_UPDATE, p); + break; + case NDTPA_INTERVAL_PROBE_TIME_MS: + NEIGH_VAR_SET(p, INTERVAL_PROBE_TIME_MS, + nla_get_msecs(tbp[i])); + break; + case NDTPA_RETRANS_TIME: + NEIGH_VAR_SET(p, RETRANS_TIME, + nla_get_msecs(tbp[i])); + break; + case NDTPA_ANYCAST_DELAY: + NEIGH_VAR_SET(p, ANYCAST_DELAY, + nla_get_msecs(tbp[i])); + break; + case NDTPA_PROXY_DELAY: + NEIGH_VAR_SET(p, PROXY_DELAY, + nla_get_msecs(tbp[i])); + break; + case NDTPA_LOCKTIME: + NEIGH_VAR_SET(p, LOCKTIME, + nla_get_msecs(tbp[i])); + break; + } + } + } + + err = -ENOENT; + if ((tb[NDTA_THRESH1] || tb[NDTA_THRESH2] || + tb[NDTA_THRESH3] || tb[NDTA_GC_INTERVAL]) && + !net_eq(net, &init_net)) + goto errout_tbl_lock; + + if (tb[NDTA_THRESH1]) + WRITE_ONCE(tbl->gc_thresh1, nla_get_u32(tb[NDTA_THRESH1])); + + if (tb[NDTA_THRESH2]) + WRITE_ONCE(tbl->gc_thresh2, nla_get_u32(tb[NDTA_THRESH2])); + + if (tb[NDTA_THRESH3]) + WRITE_ONCE(tbl->gc_thresh3, nla_get_u32(tb[NDTA_THRESH3])); + + if (tb[NDTA_GC_INTERVAL]) + WRITE_ONCE(tbl->gc_interval, nla_get_msecs(tb[NDTA_GC_INTERVAL])); + + err = 0; + +errout_tbl_lock: + write_unlock_bh(&tbl->lock); +errout: + return err; +} + +static int neightbl_valid_dump_info(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ndtmsg *ndtm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndtm))) { + NL_SET_ERR_MSG(extack, "Invalid header for neighbor table dump request"); + return -EINVAL; + } + + ndtm = nlmsg_data(nlh); + if (ndtm->ndtm_pad1 || ndtm->ndtm_pad2) { + NL_SET_ERR_MSG(extack, "Invalid values in header for neighbor 
table dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ndtm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in neighbor table dump request"); + return -EINVAL; + } + + return 0; +} + +static int neightbl_dump_info(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + int family, tidx, nidx = 0; + int tbl_skip = cb->args[0]; + int neigh_skip = cb->args[1]; + struct neigh_table *tbl; + + if (cb->strict_check) { + int err = neightbl_valid_dump_info(nlh, cb->extack); + + if (err < 0) + return err; + } + + family = ((struct rtgenmsg *)nlmsg_data(nlh))->rtgen_family; + + for (tidx = 0; tidx < NEIGH_NR_TABLES; tidx++) { + struct neigh_parms *p; + + tbl = neigh_tables[tidx]; + if (!tbl) + continue; + + if (tidx < tbl_skip || (family && tbl->family != family)) + continue; + + if (neightbl_fill_info(skb, tbl, NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, RTM_NEWNEIGHTBL, + NLM_F_MULTI) < 0) + break; + + nidx = 0; + p = list_next_entry(&tbl->parms, list); + list_for_each_entry_from(p, &tbl->parms_list, list) { + if (!net_eq(neigh_parms_net(p), net)) + continue; + + if (nidx < neigh_skip) + goto next; + + if (neightbl_fill_param_info(skb, tbl, p, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNEIGHTBL, + NLM_F_MULTI) < 0) + goto out; + next: + nidx++; + } + + neigh_skip = 0; + } +out: + cb->args[0] = tidx; + cb->args[1] = nidx; + + return skb->len; +} + +static int neigh_fill_info(struct sk_buff *skb, struct neighbour *neigh, + u32 pid, u32 seq, int type, unsigned int flags) +{ + u32 neigh_flags, neigh_flags_ext; + unsigned long now = jiffies; + struct nda_cacheinfo ci; + struct nlmsghdr *nlh; + struct ndmsg *ndm; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ndm), flags); + if (nlh == NULL) + return -EMSGSIZE; + + neigh_flags_ext = neigh->flags >> NTF_EXT_SHIFT; + neigh_flags = neigh->flags & NTF_OLD_MASK; + + ndm = nlmsg_data(nlh); + ndm->ndm_family = neigh->ops->family; + ndm->ndm_pad1 = 0; + ndm->ndm_pad2 = 0; + ndm->ndm_flags = neigh_flags; + ndm->ndm_type = neigh->type; + ndm->ndm_ifindex = neigh->dev->ifindex; + + if (nla_put(skb, NDA_DST, neigh->tbl->key_len, neigh->primary_key)) + goto nla_put_failure; + + read_lock_bh(&neigh->lock); + ndm->ndm_state = neigh->nud_state; + if (neigh->nud_state & NUD_VALID) { + char haddr[MAX_ADDR_LEN]; + + neigh_ha_snapshot(haddr, neigh, neigh->dev); + if (nla_put(skb, NDA_LLADDR, neigh->dev->addr_len, haddr) < 0) { + read_unlock_bh(&neigh->lock); + goto nla_put_failure; + } + } + + ci.ndm_used = jiffies_to_clock_t(now - neigh->used); + ci.ndm_confirmed = jiffies_to_clock_t(now - neigh->confirmed); + ci.ndm_updated = jiffies_to_clock_t(now - neigh->updated); + ci.ndm_refcnt = refcount_read(&neigh->refcnt) - 1; + read_unlock_bh(&neigh->lock); + + if (nla_put_u32(skb, NDA_PROBES, atomic_read(&neigh->probes)) || + nla_put(skb, NDA_CACHEINFO, sizeof(ci), &ci)) + goto nla_put_failure; + + if (neigh->protocol && nla_put_u8(skb, NDA_PROTOCOL, neigh->protocol)) + goto nla_put_failure; + if (neigh_flags_ext && nla_put_u32(skb, NDA_FLAGS_EXT, neigh_flags_ext)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int pneigh_fill_info(struct sk_buff *skb, struct pneigh_entry *pn, + u32 pid, u32 seq, int type, unsigned int flags, + struct neigh_table *tbl) +{ + u32 neigh_flags, neigh_flags_ext; + struct nlmsghdr *nlh; + struct ndmsg *ndm; + + nlh = nlmsg_put(skb, 
pid, seq, type, sizeof(*ndm), flags); + if (nlh == NULL) + return -EMSGSIZE; + + neigh_flags_ext = pn->flags >> NTF_EXT_SHIFT; + neigh_flags = pn->flags & NTF_OLD_MASK; + + ndm = nlmsg_data(nlh); + ndm->ndm_family = tbl->family; + ndm->ndm_pad1 = 0; + ndm->ndm_pad2 = 0; + ndm->ndm_flags = neigh_flags | NTF_PROXY; + ndm->ndm_type = RTN_UNICAST; + ndm->ndm_ifindex = pn->dev ? pn->dev->ifindex : 0; + ndm->ndm_state = NUD_NONE; + + if (nla_put(skb, NDA_DST, tbl->key_len, pn->key)) + goto nla_put_failure; + + if (pn->protocol && nla_put_u8(skb, NDA_PROTOCOL, pn->protocol)) + goto nla_put_failure; + if (neigh_flags_ext && nla_put_u32(skb, NDA_FLAGS_EXT, neigh_flags_ext)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static void neigh_update_notify(struct neighbour *neigh, u32 nlmsg_pid) +{ + call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, neigh); + __neigh_notify(neigh, RTM_NEWNEIGH, 0, nlmsg_pid); +} + +static bool neigh_master_filtered(struct net_device *dev, int master_idx) +{ + struct net_device *master; + + if (!master_idx) + return false; + + master = dev ? netdev_master_upper_dev_get(dev) : NULL; + + /* 0 is already used to denote NDA_MASTER wasn't passed, therefore need another + * invalid value for ifindex to denote "no master". + */ + if (master_idx == -1) + return !!master; + + if (!master || master->ifindex != master_idx) + return true; + + return false; +} + +static bool neigh_ifindex_filtered(struct net_device *dev, int filter_idx) +{ + if (filter_idx && (!dev || dev->ifindex != filter_idx)) + return true; + + return false; +} + +struct neigh_dump_filter { + int master_idx; + int dev_idx; +}; + +static int neigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, + struct netlink_callback *cb, + struct neigh_dump_filter *filter) +{ + struct net *net = sock_net(skb->sk); + struct neighbour *n; + int rc, h, s_h = cb->args[1]; + int idx, s_idx = idx = cb->args[2]; + struct neigh_hash_table *nht; + unsigned int flags = NLM_F_MULTI; + + if (filter->dev_idx || filter->master_idx) + flags |= NLM_F_DUMP_FILTERED; + + rcu_read_lock(); + nht = rcu_dereference(tbl->nht); + + for (h = s_h; h < (1 << nht->hash_shift); h++) { + if (h > s_h) + s_idx = 0; + for (n = rcu_dereference(nht->hash_buckets[h]), idx = 0; + n != NULL; + n = rcu_dereference(n->next)) { + if (idx < s_idx || !net_eq(dev_net(n->dev), net)) + goto next; + if (neigh_ifindex_filtered(n->dev, filter->dev_idx) || + neigh_master_filtered(n->dev, filter->master_idx)) + goto next; + if (neigh_fill_info(skb, n, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWNEIGH, + flags) < 0) { + rc = -1; + goto out; + } +next: + idx++; + } + } + rc = skb->len; +out: + rcu_read_unlock(); + cb->args[1] = h; + cb->args[2] = idx; + return rc; +} + +static int pneigh_dump_table(struct neigh_table *tbl, struct sk_buff *skb, + struct netlink_callback *cb, + struct neigh_dump_filter *filter) +{ + struct pneigh_entry *n; + struct net *net = sock_net(skb->sk); + int rc, h, s_h = cb->args[3]; + int idx, s_idx = idx = cb->args[4]; + unsigned int flags = NLM_F_MULTI; + + if (filter->dev_idx || filter->master_idx) + flags |= NLM_F_DUMP_FILTERED; + + read_lock_bh(&tbl->lock); + + for (h = s_h; h <= PNEIGH_HASHMASK; h++) { + if (h > s_h) + s_idx = 0; + for (n = tbl->phash_buckets[h], idx = 0; n; n = n->next) { + if (idx < s_idx || pneigh_net(n) != net) + goto next; + if (neigh_ifindex_filtered(n->dev, filter->dev_idx) || + neigh_master_filtered(n->dev, 
filter->master_idx)) + goto next; + if (pneigh_fill_info(skb, n, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWNEIGH, flags, tbl) < 0) { + read_unlock_bh(&tbl->lock); + rc = -1; + goto out; + } + next: + idx++; + } + } + + read_unlock_bh(&tbl->lock); + rc = skb->len; +out: + cb->args[3] = h; + cb->args[4] = idx; + return rc; + +} + +static int neigh_valid_dump_req(const struct nlmsghdr *nlh, + bool strict_check, + struct neigh_dump_filter *filter, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[NDA_MAX + 1]; + int err, i; + + if (strict_check) { + struct ndmsg *ndm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for neighbor dump request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_ifindex || + ndm->ndm_state || ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for neighbor dump request"); + return -EINVAL; + } + + if (ndm->ndm_flags & ~NTF_PROXY) { + NL_SET_ERR_MSG(extack, "Invalid flags in header for neighbor dump request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct ndmsg), + tb, NDA_MAX, nda_policy, + extack); + } else { + err = nlmsg_parse_deprecated(nlh, sizeof(struct ndmsg), tb, + NDA_MAX, nda_policy, extack); + } + if (err < 0) + return err; + + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + /* all new attributes should require strict_check */ + switch (i) { + case NDA_IFINDEX: + filter->dev_idx = nla_get_u32(tb[i]); + break; + case NDA_MASTER: + filter->master_idx = nla_get_u32(tb[i]); + break; + default: + if (strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in neighbor dump request"); + return -EINVAL; + } + } + } + + return 0; +} + +static int neigh_dump_info(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct neigh_dump_filter filter = {}; + struct neigh_table *tbl; + int t, family, s_t; + int proxy = 0; + int err; + + family = ((struct rtgenmsg *)nlmsg_data(nlh))->rtgen_family; + + /* check for full ndmsg structure presence, family member is + * the same for both structures + */ + if (nlmsg_len(nlh) >= sizeof(struct ndmsg) && + ((struct ndmsg *)nlmsg_data(nlh))->ndm_flags == NTF_PROXY) + proxy = 1; + + err = neigh_valid_dump_req(nlh, cb->strict_check, &filter, cb->extack); + if (err < 0 && cb->strict_check) + return err; + + s_t = cb->args[0]; + + for (t = 0; t < NEIGH_NR_TABLES; t++) { + tbl = neigh_tables[t]; + + if (!tbl) + continue; + if (t < s_t || (family && tbl->family != family)) + continue; + if (t > s_t) + memset(&cb->args[1], 0, sizeof(cb->args) - + sizeof(cb->args[0])); + if (proxy) + err = pneigh_dump_table(tbl, skb, cb, &filter); + else + err = neigh_dump_table(tbl, skb, cb, &filter); + if (err < 0) + break; + } + + cb->args[0] = t; + return skb->len; +} + +static int neigh_valid_get_req(const struct nlmsghdr *nlh, + struct neigh_table **tbl, + void **dst, int *dev_idx, u8 *ndm_flags, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[NDA_MAX + 1]; + struct ndmsg *ndm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for neighbor get request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_state || + ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for neighbor get request"); + return -EINVAL; + } + + if (ndm->ndm_flags & ~NTF_PROXY) { + NL_SET_ERR_MSG(extack, 
"Invalid flags in header for neighbor get request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct ndmsg), tb, + NDA_MAX, nda_policy, extack); + if (err < 0) + return err; + + *ndm_flags = ndm->ndm_flags; + *dev_idx = ndm->ndm_ifindex; + *tbl = neigh_find_table(ndm->ndm_family); + if (*tbl == NULL) { + NL_SET_ERR_MSG(extack, "Unsupported family in header for neighbor get request"); + return -EAFNOSUPPORT; + } + + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case NDA_DST: + if (nla_len(tb[i]) != (int)(*tbl)->key_len) { + NL_SET_ERR_MSG(extack, "Invalid network address in neighbor get request"); + return -EINVAL; + } + *dst = nla_data(tb[i]); + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in neighbor get request"); + return -EINVAL; + } + } + + return 0; +} + +static inline size_t neigh_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ndmsg)) + + nla_total_size(MAX_ADDR_LEN) /* NDA_DST */ + + nla_total_size(MAX_ADDR_LEN) /* NDA_LLADDR */ + + nla_total_size(sizeof(struct nda_cacheinfo)) + + nla_total_size(4) /* NDA_PROBES */ + + nla_total_size(4) /* NDA_FLAGS_EXT */ + + nla_total_size(1); /* NDA_PROTOCOL */ +} + +static int neigh_get_reply(struct net *net, struct neighbour *neigh, + u32 pid, u32 seq) +{ + struct sk_buff *skb; + int err = 0; + + skb = nlmsg_new(neigh_nlmsg_size(), GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + err = neigh_fill_info(skb, neigh, pid, seq, RTM_NEWNEIGH, 0); + if (err) { + kfree_skb(skb); + goto errout; + } + + err = rtnl_unicast(skb, net, pid); +errout: + return err; +} + +static inline size_t pneigh_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ndmsg)) + + nla_total_size(MAX_ADDR_LEN) /* NDA_DST */ + + nla_total_size(4) /* NDA_FLAGS_EXT */ + + nla_total_size(1); /* NDA_PROTOCOL */ +} + +static int pneigh_get_reply(struct net *net, struct pneigh_entry *neigh, + u32 pid, u32 seq, struct neigh_table *tbl) +{ + struct sk_buff *skb; + int err = 0; + + skb = nlmsg_new(pneigh_nlmsg_size(), GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + err = pneigh_fill_info(skb, neigh, pid, seq, RTM_NEWNEIGH, 0, tbl); + if (err) { + kfree_skb(skb); + goto errout; + } + + err = rtnl_unicast(skb, net, pid); +errout: + return err; +} + +static int neigh_get(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct net_device *dev = NULL; + struct neigh_table *tbl = NULL; + struct neighbour *neigh; + void *dst = NULL; + u8 ndm_flags = 0; + int dev_idx = 0; + int err; + + err = neigh_valid_get_req(nlh, &tbl, &dst, &dev_idx, &ndm_flags, + extack); + if (err < 0) + return err; + + if (dev_idx) { + dev = __dev_get_by_index(net, dev_idx); + if (!dev) { + NL_SET_ERR_MSG(extack, "Unknown device ifindex"); + return -ENODEV; + } + } + + if (!dst) { + NL_SET_ERR_MSG(extack, "Network address not specified"); + return -EINVAL; + } + + if (ndm_flags & NTF_PROXY) { + struct pneigh_entry *pn; + + pn = pneigh_lookup(tbl, net, dst, dev, 0); + if (!pn) { + NL_SET_ERR_MSG(extack, "Proxy neighbour entry not found"); + return -ENOENT; + } + return pneigh_get_reply(net, pn, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, tbl); + } + + if (!dev) { + NL_SET_ERR_MSG(extack, "No device specified"); + return -EINVAL; + } + + neigh = neigh_lookup(tbl, dst, dev); + if (!neigh) { + NL_SET_ERR_MSG(extack, "Neighbour entry not found"); + return -ENOENT; + } + + err = neigh_get_reply(net, neigh, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq); + + 
neigh_release(neigh); + + return err; +} + +void neigh_for_each(struct neigh_table *tbl, void (*cb)(struct neighbour *, void *), void *cookie) +{ + int chain; + struct neigh_hash_table *nht; + + rcu_read_lock(); + nht = rcu_dereference(tbl->nht); + + read_lock_bh(&tbl->lock); /* avoid resizes */ + for (chain = 0; chain < (1 << nht->hash_shift); chain++) { + struct neighbour *n; + + for (n = rcu_dereference(nht->hash_buckets[chain]); + n != NULL; + n = rcu_dereference(n->next)) + cb(n, cookie); + } + read_unlock_bh(&tbl->lock); + rcu_read_unlock(); +} +EXPORT_SYMBOL(neigh_for_each); + +/* The tbl->lock must be held as a writer and BH disabled. */ +void __neigh_for_each_release(struct neigh_table *tbl, + int (*cb)(struct neighbour *)) +{ + int chain; + struct neigh_hash_table *nht; + + nht = rcu_dereference_protected(tbl->nht, + lockdep_is_held(&tbl->lock)); + for (chain = 0; chain < (1 << nht->hash_shift); chain++) { + struct neighbour *n; + struct neighbour __rcu **np; + + np = &nht->hash_buckets[chain]; + while ((n = rcu_dereference_protected(*np, + lockdep_is_held(&tbl->lock))) != NULL) { + int release; + + write_lock(&n->lock); + release = cb(n); + if (release) { + rcu_assign_pointer(*np, + rcu_dereference_protected(n->next, + lockdep_is_held(&tbl->lock))); + neigh_mark_dead(n); + } else + np = &n->next; + write_unlock(&n->lock); + if (release) + neigh_cleanup_and_release(n); + } + } +} +EXPORT_SYMBOL(__neigh_for_each_release); + +int neigh_xmit(int index, struct net_device *dev, + const void *addr, struct sk_buff *skb) +{ + int err = -EAFNOSUPPORT; + if (likely(index < NEIGH_NR_TABLES)) { + struct neigh_table *tbl; + struct neighbour *neigh; + + tbl = neigh_tables[index]; + if (!tbl) + goto out; + rcu_read_lock(); + if (index == NEIGH_ARP_TABLE) { + u32 key = *((u32 *)addr); + + neigh = __ipv4_neigh_lookup_noref(dev, key); + } else { + neigh = __neigh_lookup_noref(tbl, addr, dev); + } + if (!neigh) + neigh = __neigh_create(tbl, addr, dev, false); + err = PTR_ERR(neigh); + if (IS_ERR(neigh)) { + rcu_read_unlock(); + goto out_kfree_skb; + } + err = READ_ONCE(neigh->output)(neigh, skb); + rcu_read_unlock(); + } + else if (index == NEIGH_LINK_TABLE) { + err = dev_hard_header(skb, dev, ntohs(skb->protocol), + addr, NULL, skb->len); + if (err < 0) + goto out_kfree_skb; + err = dev_queue_xmit(skb); + } +out: + return err; +out_kfree_skb: + kfree_skb(skb); + goto out; +} +EXPORT_SYMBOL(neigh_xmit); + +#ifdef CONFIG_PROC_FS + +static struct neighbour *neigh_get_first(struct seq_file *seq) +{ + struct neigh_seq_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct neigh_hash_table *nht = state->nht; + struct neighbour *n = NULL; + int bucket; + + state->flags &= ~NEIGH_SEQ_IS_PNEIGH; + for (bucket = 0; bucket < (1 << nht->hash_shift); bucket++) { + n = rcu_dereference(nht->hash_buckets[bucket]); + + while (n) { + if (!net_eq(dev_net(n->dev), net)) + goto next; + if (state->neigh_sub_iter) { + loff_t fakep = 0; + void *v; + + v = state->neigh_sub_iter(state, n, &fakep); + if (!v) + goto next; + } + if (!(state->flags & NEIGH_SEQ_SKIP_NOARP)) + break; + if (READ_ONCE(n->nud_state) & ~NUD_NOARP) + break; +next: + n = rcu_dereference(n->next); + } + + if (n) + break; + } + state->bucket = bucket; + + return n; +} + +static struct neighbour *neigh_get_next(struct seq_file *seq, + struct neighbour *n, + loff_t *pos) +{ + struct neigh_seq_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct neigh_hash_table *nht = state->nht; + + if (state->neigh_sub_iter) 
{ + void *v = state->neigh_sub_iter(state, n, pos); + if (v) + return n; + } + n = rcu_dereference(n->next); + + while (1) { + while (n) { + if (!net_eq(dev_net(n->dev), net)) + goto next; + if (state->neigh_sub_iter) { + void *v = state->neigh_sub_iter(state, n, pos); + if (v) + return n; + goto next; + } + if (!(state->flags & NEIGH_SEQ_SKIP_NOARP)) + break; + + if (READ_ONCE(n->nud_state) & ~NUD_NOARP) + break; +next: + n = rcu_dereference(n->next); + } + + if (n) + break; + + if (++state->bucket >= (1 << nht->hash_shift)) + break; + + n = rcu_dereference(nht->hash_buckets[state->bucket]); + } + + if (n && pos) + --(*pos); + return n; +} + +static struct neighbour *neigh_get_idx(struct seq_file *seq, loff_t *pos) +{ + struct neighbour *n = neigh_get_first(seq); + + if (n) { + --(*pos); + while (*pos) { + n = neigh_get_next(seq, n, pos); + if (!n) + break; + } + } + return *pos ? NULL : n; +} + +static struct pneigh_entry *pneigh_get_first(struct seq_file *seq) +{ + struct neigh_seq_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct neigh_table *tbl = state->tbl; + struct pneigh_entry *pn = NULL; + int bucket; + + state->flags |= NEIGH_SEQ_IS_PNEIGH; + for (bucket = 0; bucket <= PNEIGH_HASHMASK; bucket++) { + pn = tbl->phash_buckets[bucket]; + while (pn && !net_eq(pneigh_net(pn), net)) + pn = pn->next; + if (pn) + break; + } + state->bucket = bucket; + + return pn; +} + +static struct pneigh_entry *pneigh_get_next(struct seq_file *seq, + struct pneigh_entry *pn, + loff_t *pos) +{ + struct neigh_seq_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct neigh_table *tbl = state->tbl; + + do { + pn = pn->next; + } while (pn && !net_eq(pneigh_net(pn), net)); + + while (!pn) { + if (++state->bucket > PNEIGH_HASHMASK) + break; + pn = tbl->phash_buckets[state->bucket]; + while (pn && !net_eq(pneigh_net(pn), net)) + pn = pn->next; + if (pn) + break; + } + + if (pn && pos) + --(*pos); + + return pn; +} + +static struct pneigh_entry *pneigh_get_idx(struct seq_file *seq, loff_t *pos) +{ + struct pneigh_entry *pn = pneigh_get_first(seq); + + if (pn) { + --(*pos); + while (*pos) { + pn = pneigh_get_next(seq, pn, pos); + if (!pn) + break; + } + } + return *pos ? NULL : pn; +} + +static void *neigh_get_idx_any(struct seq_file *seq, loff_t *pos) +{ + struct neigh_seq_state *state = seq->private; + void *rc; + loff_t idxpos = *pos; + + rc = neigh_get_idx(seq, &idxpos); + if (!rc && !(state->flags & NEIGH_SEQ_NEIGH_ONLY)) + rc = pneigh_get_idx(seq, &idxpos); + + return rc; +} + +void *neigh_seq_start(struct seq_file *seq, loff_t *pos, struct neigh_table *tbl, unsigned int neigh_seq_flags) + __acquires(tbl->lock) + __acquires(rcu) +{ + struct neigh_seq_state *state = seq->private; + + state->tbl = tbl; + state->bucket = 0; + state->flags = (neigh_seq_flags & ~NEIGH_SEQ_IS_PNEIGH); + + rcu_read_lock(); + state->nht = rcu_dereference(tbl->nht); + read_lock_bh(&tbl->lock); + + return *pos ? 
neigh_get_idx_any(seq, pos) : SEQ_START_TOKEN; +} +EXPORT_SYMBOL(neigh_seq_start); + +void *neigh_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct neigh_seq_state *state; + void *rc; + + if (v == SEQ_START_TOKEN) { + rc = neigh_get_first(seq); + goto out; + } + + state = seq->private; + if (!(state->flags & NEIGH_SEQ_IS_PNEIGH)) { + rc = neigh_get_next(seq, v, NULL); + if (rc) + goto out; + if (!(state->flags & NEIGH_SEQ_NEIGH_ONLY)) + rc = pneigh_get_first(seq); + } else { + BUG_ON(state->flags & NEIGH_SEQ_NEIGH_ONLY); + rc = pneigh_get_next(seq, v, NULL); + } +out: + ++(*pos); + return rc; +} +EXPORT_SYMBOL(neigh_seq_next); + +void neigh_seq_stop(struct seq_file *seq, void *v) + __releases(tbl->lock) + __releases(rcu) +{ + struct neigh_seq_state *state = seq->private; + struct neigh_table *tbl = state->tbl; + + read_unlock_bh(&tbl->lock); + rcu_read_unlock(); +} +EXPORT_SYMBOL(neigh_seq_stop); + +/* statistics via seq_file */ + +static void *neigh_stat_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct neigh_table *tbl = pde_data(file_inode(seq->file)); + int cpu; + + if (*pos == 0) + return SEQ_START_TOKEN; + + for (cpu = *pos-1; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu+1; + return per_cpu_ptr(tbl->stats, cpu); + } + return NULL; +} + +static void *neigh_stat_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct neigh_table *tbl = pde_data(file_inode(seq->file)); + int cpu; + + for (cpu = *pos; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu+1; + return per_cpu_ptr(tbl->stats, cpu); + } + (*pos)++; + return NULL; +} + +static void neigh_stat_seq_stop(struct seq_file *seq, void *v) +{ + +} + +static int neigh_stat_seq_show(struct seq_file *seq, void *v) +{ + struct neigh_table *tbl = pde_data(file_inode(seq->file)); + struct neigh_statistics *st = v; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "entries allocs destroys hash_grows lookups hits res_failed rcv_probes_mcast rcv_probes_ucast periodic_gc_runs forced_gc_runs unresolved_discards table_fulls\n"); + return 0; + } + + seq_printf(seq, "%08x %08lx %08lx %08lx %08lx %08lx %08lx " + "%08lx %08lx %08lx " + "%08lx %08lx %08lx\n", + atomic_read(&tbl->entries), + + st->allocs, + st->destroys, + st->hash_grows, + + st->lookups, + st->hits, + + st->res_failed, + + st->rcv_probes_mcast, + st->rcv_probes_ucast, + + st->periodic_gc_runs, + st->forced_gc_runs, + st->unres_discards, + st->table_fulls + ); + + return 0; +} + +static const struct seq_operations neigh_stat_seq_ops = { + .start = neigh_stat_seq_start, + .next = neigh_stat_seq_next, + .stop = neigh_stat_seq_stop, + .show = neigh_stat_seq_show, +}; +#endif /* CONFIG_PROC_FS */ + +static void __neigh_notify(struct neighbour *n, int type, int flags, + u32 pid) +{ + struct net *net = dev_net(n->dev); + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(neigh_nlmsg_size(), GFP_ATOMIC); + if (skb == NULL) + goto errout; + + err = neigh_fill_info(skb, n, pid, 0, type, flags); + if (err < 0) { + /* -EMSGSIZE implies BUG in neigh_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_NEIGH, NULL, GFP_ATOMIC); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_NEIGH, err); +} + +void neigh_app_ns(struct neighbour *n) +{ + __neigh_notify(n, RTM_GETNEIGH, NLM_F_REQUEST, 0); +} +EXPORT_SYMBOL(neigh_app_ns); + +#ifdef CONFIG_SYSCTL +static int unres_qlen_max = INT_MAX / SKB_TRUESIZE(ETH_FRAME_LEN); + +static 
int proc_unres_qlen(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int size, ret; + struct ctl_table tmp = *ctl; + + tmp.extra1 = SYSCTL_ZERO; + tmp.extra2 = &unres_qlen_max; + tmp.data = &size; + + size = *(int *)ctl->data / SKB_TRUESIZE(ETH_FRAME_LEN); + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && !ret) + *(int *)ctl->data = size * SKB_TRUESIZE(ETH_FRAME_LEN); + return ret; +} + +static void neigh_copy_dflt_parms(struct net *net, struct neigh_parms *p, + int index) +{ + struct net_device *dev; + int family = neigh_parms_family(p); + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + struct neigh_parms *dst_p = + neigh_get_dev_parms_rcu(dev, family); + + if (dst_p && !test_bit(index, dst_p->data_state)) + dst_p->data[index] = p->data[index]; + } + rcu_read_unlock(); +} + +static void neigh_proc_update(struct ctl_table *ctl, int write) +{ + struct net_device *dev = ctl->extra1; + struct neigh_parms *p = ctl->extra2; + struct net *net = neigh_parms_net(p); + int index = (int *) ctl->data - p->data; + + if (!write) + return; + + set_bit(index, p->data_state); + if (index == NEIGH_VAR_DELAY_PROBE_TIME) + call_netevent_notifiers(NETEVENT_DELAY_PROBE_TIME_UPDATE, p); + if (!dev) /* NULL dev means this is default value */ + neigh_copy_dflt_parms(net, p, index); +} + +static int neigh_proc_dointvec_zero_intmax(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct ctl_table tmp = *ctl; + int ret; + + tmp.extra1 = SYSCTL_ZERO; + tmp.extra2 = SYSCTL_INT_MAX; + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + neigh_proc_update(ctl, write); + return ret; +} + +static int neigh_proc_dointvec_ms_jiffies_positive(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct ctl_table tmp = *ctl; + int ret; + + int min = msecs_to_jiffies(1); + + tmp.extra1 = &min; + tmp.extra2 = NULL; + + ret = proc_dointvec_ms_jiffies_minmax(&tmp, write, buffer, lenp, ppos); + neigh_proc_update(ctl, write); + return ret; +} + +int neigh_proc_dointvec(struct ctl_table *ctl, int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + int ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + + neigh_proc_update(ctl, write); + return ret; +} +EXPORT_SYMBOL(neigh_proc_dointvec); + +int neigh_proc_dointvec_jiffies(struct ctl_table *ctl, int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + int ret = proc_dointvec_jiffies(ctl, write, buffer, lenp, ppos); + + neigh_proc_update(ctl, write); + return ret; +} +EXPORT_SYMBOL(neigh_proc_dointvec_jiffies); + +static int neigh_proc_dointvec_userhz_jiffies(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + int ret = proc_dointvec_userhz_jiffies(ctl, write, buffer, lenp, ppos); + + neigh_proc_update(ctl, write); + return ret; +} + +int neigh_proc_dointvec_ms_jiffies(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int ret = proc_dointvec_ms_jiffies(ctl, write, buffer, lenp, ppos); + + neigh_proc_update(ctl, write); + return ret; +} +EXPORT_SYMBOL(neigh_proc_dointvec_ms_jiffies); + +static int neigh_proc_dointvec_unres_qlen(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + int ret = proc_unres_qlen(ctl, write, buffer, lenp, ppos); + + neigh_proc_update(ctl, write); + return ret; +} + +static int neigh_proc_base_reachable_time(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct neigh_parms 
*p = ctl->extra2; + int ret; + + if (strcmp(ctl->procname, "base_reachable_time") == 0) + ret = neigh_proc_dointvec_jiffies(ctl, write, buffer, lenp, ppos); + else if (strcmp(ctl->procname, "base_reachable_time_ms") == 0) + ret = neigh_proc_dointvec_ms_jiffies(ctl, write, buffer, lenp, ppos); + else + ret = -1; + + if (write && ret == 0) { + /* update reachable_time as well, otherwise, the change will + * only be effective after the next time neigh_periodic_work + * decides to recompute it + */ + p->reachable_time = + neigh_rand_reach_time(NEIGH_VAR(p, BASE_REACHABLE_TIME)); + } + return ret; +} + +#define NEIGH_PARMS_DATA_OFFSET(index) \ + (&((struct neigh_parms *) 0)->data[index]) + +#define NEIGH_SYSCTL_ENTRY(attr, data_attr, name, mval, proc) \ + [NEIGH_VAR_ ## attr] = { \ + .procname = name, \ + .data = NEIGH_PARMS_DATA_OFFSET(NEIGH_VAR_ ## data_attr), \ + .maxlen = sizeof(int), \ + .mode = mval, \ + .proc_handler = proc, \ + } + +#define NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, attr, name, 0644, neigh_proc_dointvec_zero_intmax) + +#define NEIGH_SYSCTL_JIFFIES_ENTRY(attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, attr, name, 0644, neigh_proc_dointvec_jiffies) + +#define NEIGH_SYSCTL_USERHZ_JIFFIES_ENTRY(attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, attr, name, 0644, neigh_proc_dointvec_userhz_jiffies) + +#define NEIGH_SYSCTL_MS_JIFFIES_POSITIVE_ENTRY(attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, attr, name, 0644, neigh_proc_dointvec_ms_jiffies_positive) + +#define NEIGH_SYSCTL_MS_JIFFIES_REUSED_ENTRY(attr, data_attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, data_attr, name, 0644, neigh_proc_dointvec_ms_jiffies) + +#define NEIGH_SYSCTL_UNRES_QLEN_REUSED_ENTRY(attr, data_attr, name) \ + NEIGH_SYSCTL_ENTRY(attr, data_attr, name, 0644, neigh_proc_dointvec_unres_qlen) + +static struct neigh_sysctl_table { + struct ctl_table_header *sysctl_header; + struct ctl_table neigh_vars[NEIGH_VAR_MAX + 1]; +} neigh_sysctl_template __read_mostly = { + .neigh_vars = { + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(MCAST_PROBES, "mcast_solicit"), + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(UCAST_PROBES, "ucast_solicit"), + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(APP_PROBES, "app_solicit"), + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(MCAST_REPROBES, "mcast_resolicit"), + NEIGH_SYSCTL_USERHZ_JIFFIES_ENTRY(RETRANS_TIME, "retrans_time"), + NEIGH_SYSCTL_JIFFIES_ENTRY(BASE_REACHABLE_TIME, "base_reachable_time"), + NEIGH_SYSCTL_JIFFIES_ENTRY(DELAY_PROBE_TIME, "delay_first_probe_time"), + NEIGH_SYSCTL_MS_JIFFIES_POSITIVE_ENTRY(INTERVAL_PROBE_TIME_MS, + "interval_probe_time_ms"), + NEIGH_SYSCTL_JIFFIES_ENTRY(GC_STALETIME, "gc_stale_time"), + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(QUEUE_LEN_BYTES, "unres_qlen_bytes"), + NEIGH_SYSCTL_ZERO_INTMAX_ENTRY(PROXY_QLEN, "proxy_qlen"), + NEIGH_SYSCTL_USERHZ_JIFFIES_ENTRY(ANYCAST_DELAY, "anycast_delay"), + NEIGH_SYSCTL_USERHZ_JIFFIES_ENTRY(PROXY_DELAY, "proxy_delay"), + NEIGH_SYSCTL_USERHZ_JIFFIES_ENTRY(LOCKTIME, "locktime"), + NEIGH_SYSCTL_UNRES_QLEN_REUSED_ENTRY(QUEUE_LEN, QUEUE_LEN_BYTES, "unres_qlen"), + NEIGH_SYSCTL_MS_JIFFIES_REUSED_ENTRY(RETRANS_TIME_MS, RETRANS_TIME, "retrans_time_ms"), + NEIGH_SYSCTL_MS_JIFFIES_REUSED_ENTRY(BASE_REACHABLE_TIME_MS, BASE_REACHABLE_TIME, "base_reachable_time_ms"), + [NEIGH_VAR_GC_INTERVAL] = { + .procname = "gc_interval", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NEIGH_VAR_GC_THRESH1] = { + .procname = "gc_thresh1", + .maxlen = sizeof(int), + .mode = 0644, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_INT_MAX, + .proc_handler = 
proc_dointvec_minmax, + }, + [NEIGH_VAR_GC_THRESH2] = { + .procname = "gc_thresh2", + .maxlen = sizeof(int), + .mode = 0644, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_INT_MAX, + .proc_handler = proc_dointvec_minmax, + }, + [NEIGH_VAR_GC_THRESH3] = { + .procname = "gc_thresh3", + .maxlen = sizeof(int), + .mode = 0644, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_INT_MAX, + .proc_handler = proc_dointvec_minmax, + }, + {}, + }, +}; + +int neigh_sysctl_register(struct net_device *dev, struct neigh_parms *p, + proc_handler *handler) +{ + int i; + struct neigh_sysctl_table *t; + const char *dev_name_source; + char neigh_path[ sizeof("net//neigh/") + IFNAMSIZ + IFNAMSIZ ]; + char *p_name; + + t = kmemdup(&neigh_sysctl_template, sizeof(*t), GFP_KERNEL_ACCOUNT); + if (!t) + goto err; + + for (i = 0; i < NEIGH_VAR_GC_INTERVAL; i++) { + t->neigh_vars[i].data += (long) p; + t->neigh_vars[i].extra1 = dev; + t->neigh_vars[i].extra2 = p; + } + + if (dev) { + dev_name_source = dev->name; + /* Terminate the table early */ + memset(&t->neigh_vars[NEIGH_VAR_GC_INTERVAL], 0, + sizeof(t->neigh_vars[NEIGH_VAR_GC_INTERVAL])); + } else { + struct neigh_table *tbl = p->tbl; + dev_name_source = "default"; + t->neigh_vars[NEIGH_VAR_GC_INTERVAL].data = &tbl->gc_interval; + t->neigh_vars[NEIGH_VAR_GC_THRESH1].data = &tbl->gc_thresh1; + t->neigh_vars[NEIGH_VAR_GC_THRESH2].data = &tbl->gc_thresh2; + t->neigh_vars[NEIGH_VAR_GC_THRESH3].data = &tbl->gc_thresh3; + } + + if (handler) { + /* RetransTime */ + t->neigh_vars[NEIGH_VAR_RETRANS_TIME].proc_handler = handler; + /* ReachableTime */ + t->neigh_vars[NEIGH_VAR_BASE_REACHABLE_TIME].proc_handler = handler; + /* RetransTime (in milliseconds)*/ + t->neigh_vars[NEIGH_VAR_RETRANS_TIME_MS].proc_handler = handler; + /* ReachableTime (in milliseconds) */ + t->neigh_vars[NEIGH_VAR_BASE_REACHABLE_TIME_MS].proc_handler = handler; + } else { + /* Those handlers will update p->reachable_time after + * base_reachable_time(_ms) is set to ensure the new timer starts being + * applied after the next neighbour update instead of waiting for + * neigh_periodic_work to update its value (can be multiple minutes) + * So any handler that replaces them should do this as well + */ + /* ReachableTime */ + t->neigh_vars[NEIGH_VAR_BASE_REACHABLE_TIME].proc_handler = + neigh_proc_base_reachable_time; + /* ReachableTime (in milliseconds) */ + t->neigh_vars[NEIGH_VAR_BASE_REACHABLE_TIME_MS].proc_handler = + neigh_proc_base_reachable_time; + } + + switch (neigh_parms_family(p)) { + case AF_INET: + p_name = "ipv4"; + break; + case AF_INET6: + p_name = "ipv6"; + break; + default: + BUG(); + } + + snprintf(neigh_path, sizeof(neigh_path), "net/%s/neigh/%s", + p_name, dev_name_source); + t->sysctl_header = + register_net_sysctl(neigh_parms_net(p), neigh_path, t->neigh_vars); + if (!t->sysctl_header) + goto free; + + p->sysctl_table = t; + return 0; + +free: + kfree(t); +err: + return -ENOBUFS; +} +EXPORT_SYMBOL(neigh_sysctl_register); + +void neigh_sysctl_unregister(struct neigh_parms *p) +{ + if (p->sysctl_table) { + struct neigh_sysctl_table *t = p->sysctl_table; + p->sysctl_table = NULL; + unregister_net_sysctl_table(t->sysctl_header); + kfree(t); + } +} +EXPORT_SYMBOL(neigh_sysctl_unregister); + +#endif /* CONFIG_SYSCTL */ + +static int __init neigh_init(void) +{ + rtnl_register(PF_UNSPEC, RTM_NEWNEIGH, neigh_add, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELNEIGH, neigh_delete, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETNEIGH, neigh_get, neigh_dump_info, 0); + + rtnl_register(PF_UNSPEC, 
RTM_GETNEIGHTBL, NULL, neightbl_dump_info, + 0); + rtnl_register(PF_UNSPEC, RTM_SETNEIGHTBL, neightbl_set, NULL, 0); + + return 0; +} + +subsys_initcall(neigh_init); diff --git a/net/core/net-procfs.c b/net/core/net-procfs.c new file mode 100644 index 000000000..1ec23bf8b --- /dev/null +++ b/net/core/net-procfs.c @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +#include "dev.h" + +#define BUCKET_SPACE (32 - NETDEV_HASHBITS - 1) + +#define get_bucket(x) ((x) >> BUCKET_SPACE) +#define get_offset(x) ((x) & ((1 << BUCKET_SPACE) - 1)) +#define set_bucket_offset(b, o) ((b) << BUCKET_SPACE | (o)) + +static inline struct net_device *dev_from_same_bucket(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + struct net_device *dev; + struct hlist_head *h; + unsigned int count = 0, offset = get_offset(*pos); + + h = &net->dev_index_head[get_bucket(*pos)]; + hlist_for_each_entry_rcu(dev, h, index_hlist) { + if (++count == offset) + return dev; + } + + return NULL; +} + +static inline struct net_device *dev_from_bucket(struct seq_file *seq, loff_t *pos) +{ + struct net_device *dev; + unsigned int bucket; + + do { + dev = dev_from_same_bucket(seq, pos); + if (dev) + return dev; + + bucket = get_bucket(*pos) + 1; + *pos = set_bucket_offset(bucket, 1); + } while (bucket < NETDEV_HASHENTRIES); + + return NULL; +} + +/* + * This is invoked by the /proc filesystem handler to display a device + * in detail. + */ +static void *dev_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + if (!*pos) + return SEQ_START_TOKEN; + + if (get_bucket(*pos) >= NETDEV_HASHENTRIES) + return NULL; + + return dev_from_bucket(seq, pos); +} + +static void *dev_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return dev_from_bucket(seq, pos); +} + +static void dev_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static void dev_seq_printf_stats(struct seq_file *seq, struct net_device *dev) +{ + struct rtnl_link_stats64 temp; + const struct rtnl_link_stats64 *stats = dev_get_stats(dev, &temp); + + seq_printf(seq, "%6s: %7llu %7llu %4llu %4llu %4llu %5llu %10llu %9llu " + "%8llu %7llu %4llu %4llu %4llu %5llu %7llu %10llu\n", + dev->name, stats->rx_bytes, stats->rx_packets, + stats->rx_errors, + stats->rx_dropped + stats->rx_missed_errors, + stats->rx_fifo_errors, + stats->rx_length_errors + stats->rx_over_errors + + stats->rx_crc_errors + stats->rx_frame_errors, + stats->rx_compressed, stats->multicast, + stats->tx_bytes, stats->tx_packets, + stats->tx_errors, stats->tx_dropped, + stats->tx_fifo_errors, stats->collisions, + stats->tx_carrier_errors + + stats->tx_aborted_errors + + stats->tx_window_errors + + stats->tx_heartbeat_errors, + stats->tx_compressed); +} + +/* + * Called from the PROCfs module. 
This now uses the new arbitrary sized + * /proc/net interface to create /proc/net/dev + */ +static int dev_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_puts(seq, "Inter-| Receive " + " | Transmit\n" + " face |bytes packets errs drop fifo frame " + "compressed multicast|bytes packets errs " + "drop fifo colls carrier compressed\n"); + else + dev_seq_printf_stats(seq, v); + return 0; +} + +static u32 softnet_backlog_len(struct softnet_data *sd) +{ + return skb_queue_len_lockless(&sd->input_pkt_queue) + + skb_queue_len_lockless(&sd->process_queue); +} + +static struct softnet_data *softnet_get_online(loff_t *pos) +{ + struct softnet_data *sd = NULL; + + while (*pos < nr_cpu_ids) + if (cpu_online(*pos)) { + sd = &per_cpu(softnet_data, *pos); + break; + } else + ++*pos; + return sd; +} + +static void *softnet_seq_start(struct seq_file *seq, loff_t *pos) +{ + return softnet_get_online(pos); +} + +static void *softnet_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return softnet_get_online(pos); +} + +static void softnet_seq_stop(struct seq_file *seq, void *v) +{ +} + +static int softnet_seq_show(struct seq_file *seq, void *v) +{ + struct softnet_data *sd = v; + unsigned int flow_limit_count = 0; + +#ifdef CONFIG_NET_FLOW_LIMIT + struct sd_flow_limit *fl; + + rcu_read_lock(); + fl = rcu_dereference(sd->flow_limit); + if (fl) + flow_limit_count = fl->count; + rcu_read_unlock(); +#endif + + /* the index is the CPU id owing this sd. Since offline CPUs are not + * displayed, it would be othrwise not trivial for the user-space + * mapping the data a specific CPU + */ + seq_printf(seq, + "%08x %08x %08x %08x %08x %08x %08x %08x %08x %08x %08x %08x %08x\n", + sd->processed, sd->dropped, sd->time_squeeze, 0, + 0, 0, 0, 0, /* was fastroute */ + 0, /* was cpu_collision */ + sd->received_rps, flow_limit_count, + softnet_backlog_len(sd), (int)seq->index); + return 0; +} + +static const struct seq_operations dev_seq_ops = { + .start = dev_seq_start, + .next = dev_seq_next, + .stop = dev_seq_stop, + .show = dev_seq_show, +}; + +static const struct seq_operations softnet_seq_ops = { + .start = softnet_seq_start, + .next = softnet_seq_next, + .stop = softnet_seq_stop, + .show = softnet_seq_show, +}; + +static void *ptype_get_idx(struct seq_file *seq, loff_t pos) +{ + struct list_head *ptype_list = NULL; + struct packet_type *pt = NULL; + struct net_device *dev; + loff_t i = 0; + int t; + + for_each_netdev_rcu(seq_file_net(seq), dev) { + ptype_list = &dev->ptype_all; + list_for_each_entry_rcu(pt, ptype_list, list) { + if (i == pos) + return pt; + ++i; + } + } + + list_for_each_entry_rcu(pt, &ptype_all, list) { + if (i == pos) + return pt; + ++i; + } + + for (t = 0; t < PTYPE_HASH_SIZE; t++) { + list_for_each_entry_rcu(pt, &ptype_base[t], list) { + if (i == pos) + return pt; + ++i; + } + } + return NULL; +} + +static void *ptype_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return *pos ? 
ptype_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *ptype_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net_device *dev; + struct packet_type *pt; + struct list_head *nxt; + int hash; + + ++*pos; + if (v == SEQ_START_TOKEN) + return ptype_get_idx(seq, 0); + + pt = v; + nxt = pt->list.next; + if (pt->dev) { + if (nxt != &pt->dev->ptype_all) + goto found; + + dev = pt->dev; + for_each_netdev_continue_rcu(seq_file_net(seq), dev) { + if (!list_empty(&dev->ptype_all)) { + nxt = dev->ptype_all.next; + goto found; + } + } + + nxt = ptype_all.next; + goto ptype_all; + } + + if (pt->type == htons(ETH_P_ALL)) { +ptype_all: + if (nxt != &ptype_all) + goto found; + hash = 0; + nxt = ptype_base[0].next; + } else + hash = ntohs(pt->type) & PTYPE_HASH_MASK; + + while (nxt == &ptype_base[hash]) { + if (++hash >= PTYPE_HASH_SIZE) + return NULL; + nxt = ptype_base[hash].next; + } +found: + return list_entry(nxt, struct packet_type, list); +} + +static void ptype_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int ptype_seq_show(struct seq_file *seq, void *v) +{ + struct packet_type *pt = v; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, "Type Device Function\n"); + else if ((!pt->af_packet_net || net_eq(pt->af_packet_net, seq_file_net(seq))) && + (!pt->dev || net_eq(dev_net(pt->dev), seq_file_net(seq)))) { + if (pt->type == htons(ETH_P_ALL)) + seq_puts(seq, "ALL "); + else + seq_printf(seq, "%04x", ntohs(pt->type)); + + seq_printf(seq, " %-8s %ps\n", + pt->dev ? pt->dev->name : "", pt->func); + } + + return 0; +} + +static const struct seq_operations ptype_seq_ops = { + .start = ptype_seq_start, + .next = ptype_seq_next, + .stop = ptype_seq_stop, + .show = ptype_seq_show, +}; + +static int __net_init dev_proc_net_init(struct net *net) +{ + int rc = -ENOMEM; + + if (!proc_create_net("dev", 0444, net->proc_net, &dev_seq_ops, + sizeof(struct seq_net_private))) + goto out; + if (!proc_create_seq("softnet_stat", 0444, net->proc_net, + &softnet_seq_ops)) + goto out_dev; + if (!proc_create_net("ptype", 0444, net->proc_net, &ptype_seq_ops, + sizeof(struct seq_net_private))) + goto out_softnet; + + if (wext_proc_init(net)) + goto out_ptype; + rc = 0; +out: + return rc; +out_ptype: + remove_proc_entry("ptype", net->proc_net); +out_softnet: + remove_proc_entry("softnet_stat", net->proc_net); +out_dev: + remove_proc_entry("dev", net->proc_net); + goto out; +} + +static void __net_exit dev_proc_net_exit(struct net *net) +{ + wext_proc_exit(net); + + remove_proc_entry("ptype", net->proc_net); + remove_proc_entry("softnet_stat", net->proc_net); + remove_proc_entry("dev", net->proc_net); +} + +static struct pernet_operations __net_initdata dev_proc_ops = { + .init = dev_proc_net_init, + .exit = dev_proc_net_exit, +}; + +static int dev_mc_seq_show(struct seq_file *seq, void *v) +{ + struct netdev_hw_addr *ha; + struct net_device *dev = v; + + if (v == SEQ_START_TOKEN) + return 0; + + netif_addr_lock_bh(dev); + netdev_for_each_mc_addr(ha, dev) { + seq_printf(seq, "%-4d %-15s %-5d %-5d %*phN\n", + dev->ifindex, dev->name, + ha->refcount, ha->global_use, + (int)dev->addr_len, ha->addr); + } + netif_addr_unlock_bh(dev); + return 0; +} + +static const struct seq_operations dev_mc_seq_ops = { + .start = dev_seq_start, + .next = dev_seq_next, + .stop = dev_seq_stop, + .show = dev_mc_seq_show, +}; + +static int __net_init dev_mc_net_init(struct net *net) +{ + if (!proc_create_net("dev_mcast", 0, net->proc_net, &dev_mc_seq_ops, + sizeof(struct 
seq_net_private))) + return -ENOMEM; + return 0; +} + +static void __net_exit dev_mc_net_exit(struct net *net) +{ + remove_proc_entry("dev_mcast", net->proc_net); +} + +static struct pernet_operations __net_initdata dev_mc_net_ops = { + .init = dev_mc_net_init, + .exit = dev_mc_net_exit, +}; + +int __init dev_proc_init(void) +{ + int ret = register_pernet_subsys(&dev_proc_ops); + if (!ret) + return register_pernet_subsys(&dev_mc_net_ops); + return ret; +} diff --git a/net/core/net-sysfs.c b/net/core/net-sysfs.c new file mode 100644 index 000000000..8409d4140 --- /dev/null +++ b/net/core/net-sysfs.c @@ -0,0 +1,2079 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net-sysfs.c - network device class and attributes + * + * Copyright (c) 2003 Stephen Hemminger + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dev.h" +#include "net-sysfs.h" + +#ifdef CONFIG_SYSFS +static const char fmt_hex[] = "%#x\n"; +static const char fmt_dec[] = "%d\n"; +static const char fmt_ulong[] = "%lu\n"; +static const char fmt_u64[] = "%llu\n"; + +/* Caller holds RTNL or dev_base_lock */ +static inline int dev_isalive(const struct net_device *dev) +{ + return dev->reg_state <= NETREG_REGISTERED; +} + +/* use same locking rules as GIF* ioctl's */ +static ssize_t netdev_show(const struct device *dev, + struct device_attribute *attr, char *buf, + ssize_t (*format)(const struct net_device *, char *)) +{ + struct net_device *ndev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + read_lock(&dev_base_lock); + if (dev_isalive(ndev)) + ret = (*format)(ndev, buf); + read_unlock(&dev_base_lock); + + return ret; +} + +/* generate a show function for simple field */ +#define NETDEVICE_SHOW(field, format_string) \ +static ssize_t format_##field(const struct net_device *dev, char *buf) \ +{ \ + return sysfs_emit(buf, format_string, dev->field); \ +} \ +static ssize_t field##_show(struct device *dev, \ + struct device_attribute *attr, char *buf) \ +{ \ + return netdev_show(dev, attr, buf, format_##field); \ +} \ + +#define NETDEVICE_SHOW_RO(field, format_string) \ +NETDEVICE_SHOW(field, format_string); \ +static DEVICE_ATTR_RO(field) + +#define NETDEVICE_SHOW_RW(field, format_string) \ +NETDEVICE_SHOW(field, format_string); \ +static DEVICE_ATTR_RW(field) + +/* use same locking and permission rules as SIF* ioctl's */ +static ssize_t netdev_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len, + int (*set)(struct net_device *, unsigned long)) +{ + struct net_device *netdev = to_net_dev(dev); + struct net *net = dev_net(netdev); + unsigned long new; + int ret; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = kstrtoul(buf, 0, &new); + if (ret) + goto err; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) { + ret = (*set)(netdev, new); + if (ret == 0) + ret = len; + } + rtnl_unlock(); + err: + return ret; +} + +NETDEVICE_SHOW_RO(dev_id, fmt_hex); +NETDEVICE_SHOW_RO(dev_port, fmt_dec); +NETDEVICE_SHOW_RO(addr_assign_type, fmt_dec); +NETDEVICE_SHOW_RO(addr_len, fmt_dec); +NETDEVICE_SHOW_RO(ifindex, fmt_dec); +NETDEVICE_SHOW_RO(type, fmt_dec); +NETDEVICE_SHOW_RO(link_mode, fmt_dec); + +static ssize_t iflink_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct net_device *ndev = to_net_dev(dev); + + return sysfs_emit(buf, fmt_dec, dev_get_iflink(ndev)); +} +static 
DEVICE_ATTR_RO(iflink); + +static ssize_t format_name_assign_type(const struct net_device *dev, char *buf) +{ + return sysfs_emit(buf, fmt_dec, dev->name_assign_type); +} + +static ssize_t name_assign_type_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct net_device *ndev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + if (ndev->name_assign_type != NET_NAME_UNKNOWN) + ret = netdev_show(dev, attr, buf, format_name_assign_type); + + return ret; +} +static DEVICE_ATTR_RO(name_assign_type); + +/* use same locking rules as GIFHWADDR ioctl's */ +static ssize_t address_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct net_device *ndev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + read_lock(&dev_base_lock); + if (dev_isalive(ndev)) + ret = sysfs_format_mac(buf, ndev->dev_addr, ndev->addr_len); + read_unlock(&dev_base_lock); + return ret; +} +static DEVICE_ATTR_RO(address); + +static ssize_t broadcast_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *ndev = to_net_dev(dev); + + if (dev_isalive(ndev)) + return sysfs_format_mac(buf, ndev->broadcast, ndev->addr_len); + return -EINVAL; +} +static DEVICE_ATTR_RO(broadcast); + +static int change_carrier(struct net_device *dev, unsigned long new_carrier) +{ + if (!netif_running(dev)) + return -EINVAL; + return dev_change_carrier(dev, (bool)new_carrier); +} + +static ssize_t carrier_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len) +{ + struct net_device *netdev = to_net_dev(dev); + + /* The check is also done in change_carrier; this helps returning early + * without hitting the trylock/restart in netdev_store. + */ + if (!netdev->netdev_ops->ndo_change_carrier) + return -EOPNOTSUPP; + + return netdev_store(dev, attr, buf, len, change_carrier); +} + +static ssize_t carrier_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + if (netif_running(netdev)) + return sysfs_emit(buf, fmt_dec, !!netif_carrier_ok(netdev)); + + return -EINVAL; +} +static DEVICE_ATTR_RW(carrier); + +static ssize_t speed_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + int ret = -EINVAL; + + /* The check is also done in __ethtool_get_link_ksettings; this helps + * returning early without hitting the trylock/restart below. + */ + if (!netdev->ethtool_ops->get_link_ksettings) + return ret; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (netif_running(netdev) && netif_device_present(netdev)) { + struct ethtool_link_ksettings cmd; + + if (!__ethtool_get_link_ksettings(netdev, &cmd)) + ret = sysfs_emit(buf, fmt_dec, cmd.base.speed); + } + rtnl_unlock(); + return ret; +} +static DEVICE_ATTR_RO(speed); + +static ssize_t duplex_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + int ret = -EINVAL; + + /* The check is also done in __ethtool_get_link_ksettings; this helps + * returning early without hitting the trylock/restart below. 
+ */ + if (!netdev->ethtool_ops->get_link_ksettings) + return ret; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (netif_running(netdev)) { + struct ethtool_link_ksettings cmd; + + if (!__ethtool_get_link_ksettings(netdev, &cmd)) { + const char *duplex; + + switch (cmd.base.duplex) { + case DUPLEX_HALF: + duplex = "half"; + break; + case DUPLEX_FULL: + duplex = "full"; + break; + default: + duplex = "unknown"; + break; + } + ret = sysfs_emit(buf, "%s\n", duplex); + } + } + rtnl_unlock(); + return ret; +} +static DEVICE_ATTR_RO(duplex); + +static ssize_t testing_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + if (netif_running(netdev)) + return sysfs_emit(buf, fmt_dec, !!netif_testing(netdev)); + + return -EINVAL; +} +static DEVICE_ATTR_RO(testing); + +static ssize_t dormant_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + if (netif_running(netdev)) + return sysfs_emit(buf, fmt_dec, !!netif_dormant(netdev)); + + return -EINVAL; +} +static DEVICE_ATTR_RO(dormant); + +static const char *const operstates[] = { + "unknown", + "notpresent", /* currently unused */ + "down", + "lowerlayerdown", + "testing", + "dormant", + "up" +}; + +static ssize_t operstate_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + const struct net_device *netdev = to_net_dev(dev); + unsigned char operstate; + + read_lock(&dev_base_lock); + operstate = netdev->operstate; + if (!netif_running(netdev)) + operstate = IF_OPER_DOWN; + read_unlock(&dev_base_lock); + + if (operstate >= ARRAY_SIZE(operstates)) + return -EINVAL; /* should not happen */ + + return sysfs_emit(buf, "%s\n", operstates[operstate]); +} +static DEVICE_ATTR_RO(operstate); + +static ssize_t carrier_changes_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + return sysfs_emit(buf, fmt_dec, + atomic_read(&netdev->carrier_up_count) + + atomic_read(&netdev->carrier_down_count)); +} +static DEVICE_ATTR_RO(carrier_changes); + +static ssize_t carrier_up_count_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + return sysfs_emit(buf, fmt_dec, atomic_read(&netdev->carrier_up_count)); +} +static DEVICE_ATTR_RO(carrier_up_count); + +static ssize_t carrier_down_count_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + + return sysfs_emit(buf, fmt_dec, atomic_read(&netdev->carrier_down_count)); +} +static DEVICE_ATTR_RO(carrier_down_count); + +/* read-write attributes */ + +static int change_mtu(struct net_device *dev, unsigned long new_mtu) +{ + return dev_set_mtu(dev, (int)new_mtu); +} + +static ssize_t mtu_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len) +{ + return netdev_store(dev, attr, buf, len, change_mtu); +} +NETDEVICE_SHOW_RW(mtu, fmt_dec); + +static int change_flags(struct net_device *dev, unsigned long new_flags) +{ + return dev_change_flags(dev, (unsigned int)new_flags, NULL); +} + +static ssize_t flags_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len) +{ + return netdev_store(dev, attr, buf, len, change_flags); +} +NETDEVICE_SHOW_RW(flags, fmt_hex); + +static ssize_t tx_queue_len_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t len) +{ + if 
(!capable(CAP_NET_ADMIN)) + return -EPERM; + + return netdev_store(dev, attr, buf, len, dev_change_tx_queue_len); +} +NETDEVICE_SHOW_RW(tx_queue_len, fmt_dec); + +static int change_gro_flush_timeout(struct net_device *dev, unsigned long val) +{ + WRITE_ONCE(dev->gro_flush_timeout, val); + return 0; +} + +static ssize_t gro_flush_timeout_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t len) +{ + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + return netdev_store(dev, attr, buf, len, change_gro_flush_timeout); +} +NETDEVICE_SHOW_RW(gro_flush_timeout, fmt_ulong); + +static int change_napi_defer_hard_irqs(struct net_device *dev, unsigned long val) +{ + WRITE_ONCE(dev->napi_defer_hard_irqs, val); + return 0; +} + +static ssize_t napi_defer_hard_irqs_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t len) +{ + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + return netdev_store(dev, attr, buf, len, change_napi_defer_hard_irqs); +} +NETDEVICE_SHOW_RW(napi_defer_hard_irqs, fmt_dec); + +static ssize_t ifalias_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len) +{ + struct net_device *netdev = to_net_dev(dev); + struct net *net = dev_net(netdev); + size_t count = len; + ssize_t ret = 0; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + /* ignore trailing newline */ + if (len > 0 && buf[len - 1] == '\n') + --count; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) { + ret = dev_set_alias(netdev, buf, count); + if (ret < 0) + goto err; + ret = len; + netdev_state_change(netdev); + } +err: + rtnl_unlock(); + + return ret; +} + +static ssize_t ifalias_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + const struct net_device *netdev = to_net_dev(dev); + char tmp[IFALIASZ]; + ssize_t ret = 0; + + ret = dev_get_alias(netdev, tmp, sizeof(tmp)); + if (ret > 0) + ret = sysfs_emit(buf, "%s\n", tmp); + return ret; +} +static DEVICE_ATTR_RW(ifalias); + +static int change_group(struct net_device *dev, unsigned long new_group) +{ + dev_set_group(dev, (int)new_group); + return 0; +} + +static ssize_t group_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t len) +{ + return netdev_store(dev, attr, buf, len, change_group); +} +NETDEVICE_SHOW(group, fmt_dec); +static DEVICE_ATTR(netdev_group, 0644, group_show, group_store); + +static int change_proto_down(struct net_device *dev, unsigned long proto_down) +{ + return dev_change_proto_down(dev, (bool)proto_down); +} + +static ssize_t proto_down_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return netdev_store(dev, attr, buf, len, change_proto_down); +} +NETDEVICE_SHOW_RW(proto_down, fmt_dec); + +static ssize_t phys_port_id_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + /* The check is also done in dev_get_phys_port_id; this helps returning + * early without hitting the trylock/restart below. 
+ */ + if (!netdev->netdev_ops->ndo_get_phys_port_id) + return -EOPNOTSUPP; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) { + struct netdev_phys_item_id ppid; + + ret = dev_get_phys_port_id(netdev, &ppid); + if (!ret) + ret = sysfs_emit(buf, "%*phN\n", ppid.id_len, ppid.id); + } + rtnl_unlock(); + + return ret; +} +static DEVICE_ATTR_RO(phys_port_id); + +static ssize_t phys_port_name_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + /* The checks are also done in dev_get_phys_port_name; this helps + * returning early without hitting the trylock/restart below. + */ + if (!netdev->netdev_ops->ndo_get_phys_port_name && + !netdev->netdev_ops->ndo_get_devlink_port) + return -EOPNOTSUPP; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) { + char name[IFNAMSIZ]; + + ret = dev_get_phys_port_name(netdev, name, sizeof(name)); + if (!ret) + ret = sysfs_emit(buf, "%s\n", name); + } + rtnl_unlock(); + + return ret; +} +static DEVICE_ATTR_RO(phys_port_name); + +static ssize_t phys_switch_id_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + /* The checks are also done in dev_get_phys_port_name; this helps + * returning early without hitting the trylock/restart below. This works + * because recurse is false when calling dev_get_port_parent_id. + */ + if (!netdev->netdev_ops->ndo_get_port_parent_id && + !netdev->netdev_ops->ndo_get_devlink_port) + return -EOPNOTSUPP; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) { + struct netdev_phys_item_id ppid = { }; + + ret = dev_get_port_parent_id(netdev, &ppid, false); + if (!ret) + ret = sysfs_emit(buf, "%*phN\n", ppid.id_len, ppid.id); + } + rtnl_unlock(); + + return ret; +} +static DEVICE_ATTR_RO(phys_switch_id); + +static ssize_t threaded_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct net_device *netdev = to_net_dev(dev); + ssize_t ret = -EINVAL; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (dev_isalive(netdev)) + ret = sysfs_emit(buf, fmt_dec, netdev->threaded); + + rtnl_unlock(); + return ret; +} + +static int modify_napi_threaded(struct net_device *dev, unsigned long val) +{ + int ret; + + if (list_empty(&dev->napi_list)) + return -EOPNOTSUPP; + + if (val != 0 && val != 1) + return -EOPNOTSUPP; + + ret = dev_set_threaded(dev, val); + + return ret; +} + +static ssize_t threaded_store(struct device *dev, + struct device_attribute *attr, + const char *buf, size_t len) +{ + return netdev_store(dev, attr, buf, len, modify_napi_threaded); +} +static DEVICE_ATTR_RW(threaded); + +static struct attribute *net_class_attrs[] __ro_after_init = { + &dev_attr_netdev_group.attr, + &dev_attr_type.attr, + &dev_attr_dev_id.attr, + &dev_attr_dev_port.attr, + &dev_attr_iflink.attr, + &dev_attr_ifindex.attr, + &dev_attr_name_assign_type.attr, + &dev_attr_addr_assign_type.attr, + &dev_attr_addr_len.attr, + &dev_attr_link_mode.attr, + &dev_attr_address.attr, + &dev_attr_broadcast.attr, + &dev_attr_speed.attr, + &dev_attr_duplex.attr, + &dev_attr_dormant.attr, + &dev_attr_testing.attr, + &dev_attr_operstate.attr, + &dev_attr_carrier_changes.attr, + &dev_attr_ifalias.attr, + &dev_attr_carrier.attr, + &dev_attr_mtu.attr, + &dev_attr_flags.attr, + &dev_attr_tx_queue_len.attr, + &dev_attr_gro_flush_timeout.attr, + 
&dev_attr_napi_defer_hard_irqs.attr, + &dev_attr_phys_port_id.attr, + &dev_attr_phys_port_name.attr, + &dev_attr_phys_switch_id.attr, + &dev_attr_proto_down.attr, + &dev_attr_carrier_up_count.attr, + &dev_attr_carrier_down_count.attr, + &dev_attr_threaded.attr, + NULL, +}; +ATTRIBUTE_GROUPS(net_class); + +/* Show a given an attribute in the statistics group */ +static ssize_t netstat_show(const struct device *d, + struct device_attribute *attr, char *buf, + unsigned long offset) +{ + struct net_device *dev = to_net_dev(d); + ssize_t ret = -EINVAL; + + WARN_ON(offset > sizeof(struct rtnl_link_stats64) || + offset % sizeof(u64) != 0); + + read_lock(&dev_base_lock); + if (dev_isalive(dev)) { + struct rtnl_link_stats64 temp; + const struct rtnl_link_stats64 *stats = dev_get_stats(dev, &temp); + + ret = sysfs_emit(buf, fmt_u64, *(u64 *)(((u8 *)stats) + offset)); + } + read_unlock(&dev_base_lock); + return ret; +} + +/* generate a read-only statistics attribute */ +#define NETSTAT_ENTRY(name) \ +static ssize_t name##_show(struct device *d, \ + struct device_attribute *attr, char *buf) \ +{ \ + return netstat_show(d, attr, buf, \ + offsetof(struct rtnl_link_stats64, name)); \ +} \ +static DEVICE_ATTR_RO(name) + +NETSTAT_ENTRY(rx_packets); +NETSTAT_ENTRY(tx_packets); +NETSTAT_ENTRY(rx_bytes); +NETSTAT_ENTRY(tx_bytes); +NETSTAT_ENTRY(rx_errors); +NETSTAT_ENTRY(tx_errors); +NETSTAT_ENTRY(rx_dropped); +NETSTAT_ENTRY(tx_dropped); +NETSTAT_ENTRY(multicast); +NETSTAT_ENTRY(collisions); +NETSTAT_ENTRY(rx_length_errors); +NETSTAT_ENTRY(rx_over_errors); +NETSTAT_ENTRY(rx_crc_errors); +NETSTAT_ENTRY(rx_frame_errors); +NETSTAT_ENTRY(rx_fifo_errors); +NETSTAT_ENTRY(rx_missed_errors); +NETSTAT_ENTRY(tx_aborted_errors); +NETSTAT_ENTRY(tx_carrier_errors); +NETSTAT_ENTRY(tx_fifo_errors); +NETSTAT_ENTRY(tx_heartbeat_errors); +NETSTAT_ENTRY(tx_window_errors); +NETSTAT_ENTRY(rx_compressed); +NETSTAT_ENTRY(tx_compressed); +NETSTAT_ENTRY(rx_nohandler); + +static struct attribute *netstat_attrs[] __ro_after_init = { + &dev_attr_rx_packets.attr, + &dev_attr_tx_packets.attr, + &dev_attr_rx_bytes.attr, + &dev_attr_tx_bytes.attr, + &dev_attr_rx_errors.attr, + &dev_attr_tx_errors.attr, + &dev_attr_rx_dropped.attr, + &dev_attr_tx_dropped.attr, + &dev_attr_multicast.attr, + &dev_attr_collisions.attr, + &dev_attr_rx_length_errors.attr, + &dev_attr_rx_over_errors.attr, + &dev_attr_rx_crc_errors.attr, + &dev_attr_rx_frame_errors.attr, + &dev_attr_rx_fifo_errors.attr, + &dev_attr_rx_missed_errors.attr, + &dev_attr_tx_aborted_errors.attr, + &dev_attr_tx_carrier_errors.attr, + &dev_attr_tx_fifo_errors.attr, + &dev_attr_tx_heartbeat_errors.attr, + &dev_attr_tx_window_errors.attr, + &dev_attr_rx_compressed.attr, + &dev_attr_tx_compressed.attr, + &dev_attr_rx_nohandler.attr, + NULL +}; + +static const struct attribute_group netstat_group = { + .name = "statistics", + .attrs = netstat_attrs, +}; + +static struct attribute *wireless_attrs[] = { + NULL +}; + +static const struct attribute_group wireless_group = { + .name = "wireless", + .attrs = wireless_attrs, +}; + +static bool wireless_group_needed(struct net_device *ndev) +{ +#if IS_ENABLED(CONFIG_CFG80211) + if (ndev->ieee80211_ptr) + return true; +#endif +#if IS_ENABLED(CONFIG_WIRELESS_EXT) + if (ndev->wireless_handlers) + return true; +#endif + return false; +} + +#else /* CONFIG_SYSFS */ +#define net_class_groups NULL +#endif /* CONFIG_SYSFS */ + +#ifdef CONFIG_SYSFS +#define to_rx_queue_attr(_attr) \ + container_of(_attr, struct rx_queue_attribute, attr) + +#define 
to_rx_queue(obj) container_of(obj, struct netdev_rx_queue, kobj) + +static ssize_t rx_queue_attr_show(struct kobject *kobj, struct attribute *attr, + char *buf) +{ + const struct rx_queue_attribute *attribute = to_rx_queue_attr(attr); + struct netdev_rx_queue *queue = to_rx_queue(kobj); + + if (!attribute->show) + return -EIO; + + return attribute->show(queue, buf); +} + +static ssize_t rx_queue_attr_store(struct kobject *kobj, struct attribute *attr, + const char *buf, size_t count) +{ + const struct rx_queue_attribute *attribute = to_rx_queue_attr(attr); + struct netdev_rx_queue *queue = to_rx_queue(kobj); + + if (!attribute->store) + return -EIO; + + return attribute->store(queue, buf, count); +} + +static const struct sysfs_ops rx_queue_sysfs_ops = { + .show = rx_queue_attr_show, + .store = rx_queue_attr_store, +}; + +#ifdef CONFIG_RPS +static ssize_t show_rps_map(struct netdev_rx_queue *queue, char *buf) +{ + struct rps_map *map; + cpumask_var_t mask; + int i, len; + + if (!zalloc_cpumask_var(&mask, GFP_KERNEL)) + return -ENOMEM; + + rcu_read_lock(); + map = rcu_dereference(queue->rps_map); + if (map) + for (i = 0; i < map->len; i++) + cpumask_set_cpu(map->cpus[i], mask); + + len = sysfs_emit(buf, "%*pb\n", cpumask_pr_args(mask)); + rcu_read_unlock(); + free_cpumask_var(mask); + + return len < PAGE_SIZE ? len : -EINVAL; +} + +static ssize_t store_rps_map(struct netdev_rx_queue *queue, + const char *buf, size_t len) +{ + struct rps_map *old_map, *map; + cpumask_var_t mask; + int err, cpu, i; + static DEFINE_MUTEX(rps_map_mutex); + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (!alloc_cpumask_var(&mask, GFP_KERNEL)) + return -ENOMEM; + + err = bitmap_parse(buf, len, cpumask_bits(mask), nr_cpumask_bits); + if (err) { + free_cpumask_var(mask); + return err; + } + + if (!cpumask_empty(mask)) { + cpumask_and(mask, mask, housekeeping_cpumask(HK_TYPE_DOMAIN)); + cpumask_and(mask, mask, housekeeping_cpumask(HK_TYPE_WQ)); + if (cpumask_empty(mask)) { + free_cpumask_var(mask); + return -EINVAL; + } + } + + map = kzalloc(max_t(unsigned int, + RPS_MAP_SIZE(cpumask_weight(mask)), L1_CACHE_BYTES), + GFP_KERNEL); + if (!map) { + free_cpumask_var(mask); + return -ENOMEM; + } + + i = 0; + for_each_cpu_and(cpu, mask, cpu_online_mask) + map->cpus[i++] = cpu; + + if (i) { + map->len = i; + } else { + kfree(map); + map = NULL; + } + + mutex_lock(&rps_map_mutex); + old_map = rcu_dereference_protected(queue->rps_map, + mutex_is_locked(&rps_map_mutex)); + rcu_assign_pointer(queue->rps_map, map); + + if (map) + static_branch_inc(&rps_needed); + if (old_map) + static_branch_dec(&rps_needed); + + mutex_unlock(&rps_map_mutex); + + if (old_map) + kfree_rcu(old_map, rcu); + + free_cpumask_var(mask); + return len; +} + +static ssize_t show_rps_dev_flow_table_cnt(struct netdev_rx_queue *queue, + char *buf) +{ + struct rps_dev_flow_table *flow_table; + unsigned long val = 0; + + rcu_read_lock(); + flow_table = rcu_dereference(queue->rps_flow_table); + if (flow_table) + val = (unsigned long)flow_table->mask + 1; + rcu_read_unlock(); + + return sysfs_emit(buf, "%lu\n", val); +} + +static void rps_dev_flow_table_release(struct rcu_head *rcu) +{ + struct rps_dev_flow_table *table = container_of(rcu, + struct rps_dev_flow_table, rcu); + vfree(table); +} + +static ssize_t store_rps_dev_flow_table_cnt(struct netdev_rx_queue *queue, + const char *buf, size_t len) +{ + unsigned long mask, count; + struct rps_dev_flow_table *table, *old_table; + static DEFINE_SPINLOCK(rps_dev_flow_lock); + int rc; + + if 
(!capable(CAP_NET_ADMIN)) + return -EPERM; + + rc = kstrtoul(buf, 0, &count); + if (rc < 0) + return rc; + + if (count) { + mask = count - 1; + /* mask = roundup_pow_of_two(count) - 1; + * without overflows... + */ + while ((mask | (mask >> 1)) != mask) + mask |= (mask >> 1); + /* On 64 bit arches, must check mask fits in table->mask (u32), + * and on 32bit arches, must check + * RPS_DEV_FLOW_TABLE_SIZE(mask + 1) doesn't overflow. + */ +#if BITS_PER_LONG > 32 + if (mask > (unsigned long)(u32)mask) + return -EINVAL; +#else + if (mask > (ULONG_MAX - RPS_DEV_FLOW_TABLE_SIZE(1)) + / sizeof(struct rps_dev_flow)) { + /* Enforce a limit to prevent overflow */ + return -EINVAL; + } +#endif + table = vmalloc(RPS_DEV_FLOW_TABLE_SIZE(mask + 1)); + if (!table) + return -ENOMEM; + + table->mask = mask; + for (count = 0; count <= mask; count++) + table->flows[count].cpu = RPS_NO_CPU; + } else { + table = NULL; + } + + spin_lock(&rps_dev_flow_lock); + old_table = rcu_dereference_protected(queue->rps_flow_table, + lockdep_is_held(&rps_dev_flow_lock)); + rcu_assign_pointer(queue->rps_flow_table, table); + spin_unlock(&rps_dev_flow_lock); + + if (old_table) + call_rcu(&old_table->rcu, rps_dev_flow_table_release); + + return len; +} + +static struct rx_queue_attribute rps_cpus_attribute __ro_after_init + = __ATTR(rps_cpus, 0644, show_rps_map, store_rps_map); + +static struct rx_queue_attribute rps_dev_flow_table_cnt_attribute __ro_after_init + = __ATTR(rps_flow_cnt, 0644, + show_rps_dev_flow_table_cnt, store_rps_dev_flow_table_cnt); +#endif /* CONFIG_RPS */ + +static struct attribute *rx_queue_default_attrs[] __ro_after_init = { +#ifdef CONFIG_RPS + &rps_cpus_attribute.attr, + &rps_dev_flow_table_cnt_attribute.attr, +#endif + NULL +}; +ATTRIBUTE_GROUPS(rx_queue_default); + +static void rx_queue_release(struct kobject *kobj) +{ + struct netdev_rx_queue *queue = to_rx_queue(kobj); +#ifdef CONFIG_RPS + struct rps_map *map; + struct rps_dev_flow_table *flow_table; + + map = rcu_dereference_protected(queue->rps_map, 1); + if (map) { + RCU_INIT_POINTER(queue->rps_map, NULL); + kfree_rcu(map, rcu); + } + + flow_table = rcu_dereference_protected(queue->rps_flow_table, 1); + if (flow_table) { + RCU_INIT_POINTER(queue->rps_flow_table, NULL); + call_rcu(&flow_table->rcu, rps_dev_flow_table_release); + } +#endif + + memset(kobj, 0, sizeof(*kobj)); + netdev_put(queue->dev, &queue->dev_tracker); +} + +static const void *rx_queue_namespace(struct kobject *kobj) +{ + struct netdev_rx_queue *queue = to_rx_queue(kobj); + struct device *dev = &queue->dev->dev; + const void *ns = NULL; + + if (dev->class && dev->class->ns_type) + ns = dev->class->namespace(dev); + + return ns; +} + +static void rx_queue_get_ownership(struct kobject *kobj, + kuid_t *uid, kgid_t *gid) +{ + const struct net *net = rx_queue_namespace(kobj); + + net_ns_get_ownership(net, uid, gid); +} + +static struct kobj_type rx_queue_ktype __ro_after_init = { + .sysfs_ops = &rx_queue_sysfs_ops, + .release = rx_queue_release, + .default_groups = rx_queue_default_groups, + .namespace = rx_queue_namespace, + .get_ownership = rx_queue_get_ownership, +}; + +static int rx_queue_add_kobject(struct net_device *dev, int index) +{ + struct netdev_rx_queue *queue = dev->_rx + index; + struct kobject *kobj = &queue->kobj; + int error = 0; + + /* Kobject_put later will trigger rx_queue_release call which + * decreases dev refcount: Take that reference here + */ + netdev_hold(queue->dev, &queue->dev_tracker, GFP_KERNEL); + + kobj->kset = dev->queues_kset; + error = 
kobject_init_and_add(kobj, &rx_queue_ktype, NULL, + "rx-%u", index); + if (error) + goto err; + + if (dev->sysfs_rx_queue_group) { + error = sysfs_create_group(kobj, dev->sysfs_rx_queue_group); + if (error) + goto err; + } + + kobject_uevent(kobj, KOBJ_ADD); + + return error; + +err: + kobject_put(kobj); + return error; +} + +static int rx_queue_change_owner(struct net_device *dev, int index, kuid_t kuid, + kgid_t kgid) +{ + struct netdev_rx_queue *queue = dev->_rx + index; + struct kobject *kobj = &queue->kobj; + int error; + + error = sysfs_change_owner(kobj, kuid, kgid); + if (error) + return error; + + if (dev->sysfs_rx_queue_group) + error = sysfs_group_change_owner( + kobj, dev->sysfs_rx_queue_group, kuid, kgid); + + return error; +} +#endif /* CONFIG_SYSFS */ + +int +net_rx_queue_update_kobjects(struct net_device *dev, int old_num, int new_num) +{ +#ifdef CONFIG_SYSFS + int i; + int error = 0; + +#ifndef CONFIG_RPS + if (!dev->sysfs_rx_queue_group) + return 0; +#endif + for (i = old_num; i < new_num; i++) { + error = rx_queue_add_kobject(dev, i); + if (error) { + new_num = old_num; + break; + } + } + + while (--i >= new_num) { + struct kobject *kobj = &dev->_rx[i].kobj; + + if (!refcount_read(&dev_net(dev)->ns.count)) + kobj->uevent_suppress = 1; + if (dev->sysfs_rx_queue_group) + sysfs_remove_group(kobj, dev->sysfs_rx_queue_group); + kobject_put(kobj); + } + + return error; +#else + return 0; +#endif +} + +static int net_rx_queue_change_owner(struct net_device *dev, int num, + kuid_t kuid, kgid_t kgid) +{ +#ifdef CONFIG_SYSFS + int error = 0; + int i; + +#ifndef CONFIG_RPS + if (!dev->sysfs_rx_queue_group) + return 0; +#endif + for (i = 0; i < num; i++) { + error = rx_queue_change_owner(dev, i, kuid, kgid); + if (error) + break; + } + + return error; +#else + return 0; +#endif +} + +#ifdef CONFIG_SYSFS +/* + * netdev_queue sysfs structures and functions. 
+ */ +struct netdev_queue_attribute { + struct attribute attr; + ssize_t (*show)(struct netdev_queue *queue, char *buf); + ssize_t (*store)(struct netdev_queue *queue, + const char *buf, size_t len); +}; +#define to_netdev_queue_attr(_attr) \ + container_of(_attr, struct netdev_queue_attribute, attr) + +#define to_netdev_queue(obj) container_of(obj, struct netdev_queue, kobj) + +static ssize_t netdev_queue_attr_show(struct kobject *kobj, + struct attribute *attr, char *buf) +{ + const struct netdev_queue_attribute *attribute + = to_netdev_queue_attr(attr); + struct netdev_queue *queue = to_netdev_queue(kobj); + + if (!attribute->show) + return -EIO; + + return attribute->show(queue, buf); +} + +static ssize_t netdev_queue_attr_store(struct kobject *kobj, + struct attribute *attr, + const char *buf, size_t count) +{ + const struct netdev_queue_attribute *attribute + = to_netdev_queue_attr(attr); + struct netdev_queue *queue = to_netdev_queue(kobj); + + if (!attribute->store) + return -EIO; + + return attribute->store(queue, buf, count); +} + +static const struct sysfs_ops netdev_queue_sysfs_ops = { + .show = netdev_queue_attr_show, + .store = netdev_queue_attr_store, +}; + +static ssize_t tx_timeout_show(struct netdev_queue *queue, char *buf) +{ + unsigned long trans_timeout = atomic_long_read(&queue->trans_timeout); + + return sysfs_emit(buf, fmt_ulong, trans_timeout); +} + +static unsigned int get_netdev_queue_index(struct netdev_queue *queue) +{ + struct net_device *dev = queue->dev; + unsigned int i; + + i = queue - dev->_tx; + BUG_ON(i >= dev->num_tx_queues); + + return i; +} + +static ssize_t traffic_class_show(struct netdev_queue *queue, + char *buf) +{ + struct net_device *dev = queue->dev; + int num_tc, tc; + int index; + + if (!netif_is_multiqueue(dev)) + return -ENOENT; + + if (!rtnl_trylock()) + return restart_syscall(); + + index = get_netdev_queue_index(queue); + + /* If queue belongs to subordinate dev use its TC mapping */ + dev = netdev_get_tx_queue(dev, index)->sb_dev ? : dev; + + num_tc = dev->num_tc; + tc = netdev_txq_to_tc(dev, index); + + rtnl_unlock(); + + if (tc < 0) + return -EINVAL; + + /* We can report the traffic class one of two ways: + * Subordinate device traffic classes are reported with the traffic + * class first, and then the subordinate class so for example TC0 on + * subordinate device 2 will be reported as "0-2". If the queue + * belongs to the root device it will be reported with just the + * traffic class, so just "0" for TC 0 for example. + */ + return num_tc < 0 ? sysfs_emit(buf, "%d%d\n", tc, num_tc) : + sysfs_emit(buf, "%d\n", tc); +} + +#ifdef CONFIG_XPS +static ssize_t tx_maxrate_show(struct netdev_queue *queue, + char *buf) +{ + return sysfs_emit(buf, "%lu\n", queue->tx_maxrate); +} + +static ssize_t tx_maxrate_store(struct netdev_queue *queue, + const char *buf, size_t len) +{ + struct net_device *dev = queue->dev; + int err, index = get_netdev_queue_index(queue); + u32 rate = 0; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + /* The check is also done later; this helps returning early without + * hitting the trylock/restart below. 
+ */ + if (!dev->netdev_ops->ndo_set_tx_maxrate) + return -EOPNOTSUPP; + + err = kstrtou32(buf, 10, &rate); + if (err < 0) + return err; + + if (!rtnl_trylock()) + return restart_syscall(); + + err = -EOPNOTSUPP; + if (dev->netdev_ops->ndo_set_tx_maxrate) + err = dev->netdev_ops->ndo_set_tx_maxrate(dev, index, rate); + + rtnl_unlock(); + if (!err) { + queue->tx_maxrate = rate; + return len; + } + return err; +} + +static struct netdev_queue_attribute queue_tx_maxrate __ro_after_init + = __ATTR_RW(tx_maxrate); +#endif + +static struct netdev_queue_attribute queue_trans_timeout __ro_after_init + = __ATTR_RO(tx_timeout); + +static struct netdev_queue_attribute queue_traffic_class __ro_after_init + = __ATTR_RO(traffic_class); + +#ifdef CONFIG_BQL +/* + * Byte queue limits sysfs structures and functions. + */ +static ssize_t bql_show(char *buf, unsigned int value) +{ + return sysfs_emit(buf, "%u\n", value); +} + +static ssize_t bql_set(const char *buf, const size_t count, + unsigned int *pvalue) +{ + unsigned int value; + int err; + + if (!strcmp(buf, "max") || !strcmp(buf, "max\n")) { + value = DQL_MAX_LIMIT; + } else { + err = kstrtouint(buf, 10, &value); + if (err < 0) + return err; + if (value > DQL_MAX_LIMIT) + return -EINVAL; + } + + *pvalue = value; + + return count; +} + +static ssize_t bql_show_hold_time(struct netdev_queue *queue, + char *buf) +{ + struct dql *dql = &queue->dql; + + return sysfs_emit(buf, "%u\n", jiffies_to_msecs(dql->slack_hold_time)); +} + +static ssize_t bql_set_hold_time(struct netdev_queue *queue, + const char *buf, size_t len) +{ + struct dql *dql = &queue->dql; + unsigned int value; + int err; + + err = kstrtouint(buf, 10, &value); + if (err < 0) + return err; + + dql->slack_hold_time = msecs_to_jiffies(value); + + return len; +} + +static struct netdev_queue_attribute bql_hold_time_attribute __ro_after_init + = __ATTR(hold_time, 0644, + bql_show_hold_time, bql_set_hold_time); + +static ssize_t bql_show_inflight(struct netdev_queue *queue, + char *buf) +{ + struct dql *dql = &queue->dql; + + return sysfs_emit(buf, "%u\n", dql->num_queued - dql->num_completed); +} + +static struct netdev_queue_attribute bql_inflight_attribute __ro_after_init = + __ATTR(inflight, 0444, bql_show_inflight, NULL); + +#define BQL_ATTR(NAME, FIELD) \ +static ssize_t bql_show_ ## NAME(struct netdev_queue *queue, \ + char *buf) \ +{ \ + return bql_show(buf, queue->dql.FIELD); \ +} \ + \ +static ssize_t bql_set_ ## NAME(struct netdev_queue *queue, \ + const char *buf, size_t len) \ +{ \ + return bql_set(buf, len, &queue->dql.FIELD); \ +} \ + \ +static struct netdev_queue_attribute bql_ ## NAME ## _attribute __ro_after_init \ + = __ATTR(NAME, 0644, \ + bql_show_ ## NAME, bql_set_ ## NAME) + +BQL_ATTR(limit, limit); +BQL_ATTR(limit_max, max_limit); +BQL_ATTR(limit_min, min_limit); + +static struct attribute *dql_attrs[] __ro_after_init = { + &bql_limit_attribute.attr, + &bql_limit_max_attribute.attr, + &bql_limit_min_attribute.attr, + &bql_hold_time_attribute.attr, + &bql_inflight_attribute.attr, + NULL +}; + +static const struct attribute_group dql_group = { + .name = "byte_queue_limits", + .attrs = dql_attrs, +}; +#endif /* CONFIG_BQL */ + +#ifdef CONFIG_XPS +static ssize_t xps_queue_show(struct net_device *dev, unsigned int index, + int tc, char *buf, enum xps_map_type type) +{ + struct xps_dev_maps *dev_maps; + unsigned long *mask; + unsigned int nr_ids; + int j, len; + + rcu_read_lock(); + dev_maps = rcu_dereference(dev->xps_maps[type]); + + /* Default to nr_cpu_ids/dev->num_rx_queues 
and do not just return 0 + * when dev_maps hasn't been allocated yet, to be backward compatible. + */ + nr_ids = dev_maps ? dev_maps->nr_ids : + (type == XPS_CPUS ? nr_cpu_ids : dev->num_rx_queues); + + mask = bitmap_zalloc(nr_ids, GFP_NOWAIT); + if (!mask) { + rcu_read_unlock(); + return -ENOMEM; + } + + if (!dev_maps || tc >= dev_maps->num_tc) + goto out_no_maps; + + for (j = 0; j < nr_ids; j++) { + int i, tci = j * dev_maps->num_tc + tc; + struct xps_map *map; + + map = rcu_dereference(dev_maps->attr_map[tci]); + if (!map) + continue; + + for (i = map->len; i--;) { + if (map->queues[i] == index) { + __set_bit(j, mask); + break; + } + } + } +out_no_maps: + rcu_read_unlock(); + + len = bitmap_print_to_pagebuf(false, buf, mask, nr_ids); + bitmap_free(mask); + + return len < PAGE_SIZE ? len : -EINVAL; +} + +static ssize_t xps_cpus_show(struct netdev_queue *queue, char *buf) +{ + struct net_device *dev = queue->dev; + unsigned int index; + int len, tc; + + if (!netif_is_multiqueue(dev)) + return -ENOENT; + + index = get_netdev_queue_index(queue); + + if (!rtnl_trylock()) + return restart_syscall(); + + /* If queue belongs to subordinate dev use its map */ + dev = netdev_get_tx_queue(dev, index)->sb_dev ? : dev; + + tc = netdev_txq_to_tc(dev, index); + if (tc < 0) { + rtnl_unlock(); + return -EINVAL; + } + + /* Make sure the subordinate device can't be freed */ + get_device(&dev->dev); + rtnl_unlock(); + + len = xps_queue_show(dev, index, tc, buf, XPS_CPUS); + + put_device(&dev->dev); + return len; +} + +static ssize_t xps_cpus_store(struct netdev_queue *queue, + const char *buf, size_t len) +{ + struct net_device *dev = queue->dev; + unsigned int index; + cpumask_var_t mask; + int err; + + if (!netif_is_multiqueue(dev)) + return -ENOENT; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (!alloc_cpumask_var(&mask, GFP_KERNEL)) + return -ENOMEM; + + index = get_netdev_queue_index(queue); + + err = bitmap_parse(buf, len, cpumask_bits(mask), nr_cpumask_bits); + if (err) { + free_cpumask_var(mask); + return err; + } + + if (!rtnl_trylock()) { + free_cpumask_var(mask); + return restart_syscall(); + } + + err = netif_set_xps_queue(dev, mask, index); + rtnl_unlock(); + + free_cpumask_var(mask); + + return err ? : len; +} + +static struct netdev_queue_attribute xps_cpus_attribute __ro_after_init + = __ATTR_RW(xps_cpus); + +static ssize_t xps_rxqs_show(struct netdev_queue *queue, char *buf) +{ + struct net_device *dev = queue->dev; + unsigned int index; + int tc; + + index = get_netdev_queue_index(queue); + + if (!rtnl_trylock()) + return restart_syscall(); + + tc = netdev_txq_to_tc(dev, index); + rtnl_unlock(); + if (tc < 0) + return -EINVAL; + + return xps_queue_show(dev, index, tc, buf, XPS_RXQS); +} + +static ssize_t xps_rxqs_store(struct netdev_queue *queue, const char *buf, + size_t len) +{ + struct net_device *dev = queue->dev; + struct net *net = dev_net(dev); + unsigned long *mask; + unsigned int index; + int err; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + mask = bitmap_zalloc(dev->num_rx_queues, GFP_KERNEL); + if (!mask) + return -ENOMEM; + + index = get_netdev_queue_index(queue); + + err = bitmap_parse(buf, len, mask, dev->num_rx_queues); + if (err) { + bitmap_free(mask); + return err; + } + + if (!rtnl_trylock()) { + bitmap_free(mask); + return restart_syscall(); + } + + cpus_read_lock(); + err = __netif_set_xps_queue(dev, mask, index, XPS_RXQS); + cpus_read_unlock(); + + rtnl_unlock(); + + bitmap_free(mask); + return err ? 
: len; +} + +static struct netdev_queue_attribute xps_rxqs_attribute __ro_after_init + = __ATTR_RW(xps_rxqs); +#endif /* CONFIG_XPS */ + +static struct attribute *netdev_queue_default_attrs[] __ro_after_init = { + &queue_trans_timeout.attr, + &queue_traffic_class.attr, +#ifdef CONFIG_XPS + &xps_cpus_attribute.attr, + &xps_rxqs_attribute.attr, + &queue_tx_maxrate.attr, +#endif + NULL +}; +ATTRIBUTE_GROUPS(netdev_queue_default); + +static void netdev_queue_release(struct kobject *kobj) +{ + struct netdev_queue *queue = to_netdev_queue(kobj); + + memset(kobj, 0, sizeof(*kobj)); + netdev_put(queue->dev, &queue->dev_tracker); +} + +static const void *netdev_queue_namespace(struct kobject *kobj) +{ + struct netdev_queue *queue = to_netdev_queue(kobj); + struct device *dev = &queue->dev->dev; + const void *ns = NULL; + + if (dev->class && dev->class->ns_type) + ns = dev->class->namespace(dev); + + return ns; +} + +static void netdev_queue_get_ownership(struct kobject *kobj, + kuid_t *uid, kgid_t *gid) +{ + const struct net *net = netdev_queue_namespace(kobj); + + net_ns_get_ownership(net, uid, gid); +} + +static struct kobj_type netdev_queue_ktype __ro_after_init = { + .sysfs_ops = &netdev_queue_sysfs_ops, + .release = netdev_queue_release, + .default_groups = netdev_queue_default_groups, + .namespace = netdev_queue_namespace, + .get_ownership = netdev_queue_get_ownership, +}; + +static int netdev_queue_add_kobject(struct net_device *dev, int index) +{ + struct netdev_queue *queue = dev->_tx + index; + struct kobject *kobj = &queue->kobj; + int error = 0; + + /* Kobject_put later will trigger netdev_queue_release call + * which decreases dev refcount: Take that reference here + */ + netdev_hold(queue->dev, &queue->dev_tracker, GFP_KERNEL); + + kobj->kset = dev->queues_kset; + error = kobject_init_and_add(kobj, &netdev_queue_ktype, NULL, + "tx-%u", index); + if (error) + goto err; + +#ifdef CONFIG_BQL + error = sysfs_create_group(kobj, &dql_group); + if (error) + goto err; +#endif + + kobject_uevent(kobj, KOBJ_ADD); + return 0; + +err: + kobject_put(kobj); + return error; +} + +static int tx_queue_change_owner(struct net_device *ndev, int index, + kuid_t kuid, kgid_t kgid) +{ + struct netdev_queue *queue = ndev->_tx + index; + struct kobject *kobj = &queue->kobj; + int error; + + error = sysfs_change_owner(kobj, kuid, kgid); + if (error) + return error; + +#ifdef CONFIG_BQL + error = sysfs_group_change_owner(kobj, &dql_group, kuid, kgid); +#endif + return error; +} +#endif /* CONFIG_SYSFS */ + +int +netdev_queue_update_kobjects(struct net_device *dev, int old_num, int new_num) +{ +#ifdef CONFIG_SYSFS + int i; + int error = 0; + + /* Tx queue kobjects are allowed to be updated when a device is being + * unregistered, but solely to remove queues from qdiscs. Any path + * adding queues should be fixed. 
+ */ + WARN(dev->reg_state == NETREG_UNREGISTERING && new_num > old_num, + "New queues can't be registered after device unregistration."); + + for (i = old_num; i < new_num; i++) { + error = netdev_queue_add_kobject(dev, i); + if (error) { + new_num = old_num; + break; + } + } + + while (--i >= new_num) { + struct netdev_queue *queue = dev->_tx + i; + + if (!refcount_read(&dev_net(dev)->ns.count)) + queue->kobj.uevent_suppress = 1; +#ifdef CONFIG_BQL + sysfs_remove_group(&queue->kobj, &dql_group); +#endif + kobject_put(&queue->kobj); + } + + return error; +#else + return 0; +#endif /* CONFIG_SYSFS */ +} + +static int net_tx_queue_change_owner(struct net_device *dev, int num, + kuid_t kuid, kgid_t kgid) +{ +#ifdef CONFIG_SYSFS + int error = 0; + int i; + + for (i = 0; i < num; i++) { + error = tx_queue_change_owner(dev, i, kuid, kgid); + if (error) + break; + } + + return error; +#else + return 0; +#endif /* CONFIG_SYSFS */ +} + +static int register_queue_kobjects(struct net_device *dev) +{ + int error = 0, txq = 0, rxq = 0, real_rx = 0, real_tx = 0; + +#ifdef CONFIG_SYSFS + dev->queues_kset = kset_create_and_add("queues", + NULL, &dev->dev.kobj); + if (!dev->queues_kset) + return -ENOMEM; + real_rx = dev->real_num_rx_queues; +#endif + real_tx = dev->real_num_tx_queues; + + error = net_rx_queue_update_kobjects(dev, 0, real_rx); + if (error) + goto error; + rxq = real_rx; + + error = netdev_queue_update_kobjects(dev, 0, real_tx); + if (error) + goto error; + txq = real_tx; + + return 0; + +error: + netdev_queue_update_kobjects(dev, txq, 0); + net_rx_queue_update_kobjects(dev, rxq, 0); +#ifdef CONFIG_SYSFS + kset_unregister(dev->queues_kset); +#endif + return error; +} + +static int queue_change_owner(struct net_device *ndev, kuid_t kuid, kgid_t kgid) +{ + int error = 0, real_rx = 0, real_tx = 0; + +#ifdef CONFIG_SYSFS + if (ndev->queues_kset) { + error = sysfs_change_owner(&ndev->queues_kset->kobj, kuid, kgid); + if (error) + return error; + } + real_rx = ndev->real_num_rx_queues; +#endif + real_tx = ndev->real_num_tx_queues; + + error = net_rx_queue_change_owner(ndev, real_rx, kuid, kgid); + if (error) + return error; + + error = net_tx_queue_change_owner(ndev, real_tx, kuid, kgid); + if (error) + return error; + + return 0; +} + +static void remove_queue_kobjects(struct net_device *dev) +{ + int real_rx = 0, real_tx = 0; + +#ifdef CONFIG_SYSFS + real_rx = dev->real_num_rx_queues; +#endif + real_tx = dev->real_num_tx_queues; + + net_rx_queue_update_kobjects(dev, real_rx, 0); + netdev_queue_update_kobjects(dev, real_tx, 0); + + dev->real_num_rx_queues = 0; + dev->real_num_tx_queues = 0; +#ifdef CONFIG_SYSFS + kset_unregister(dev->queues_kset); +#endif +} + +static bool net_current_may_mount(void) +{ + struct net *net = current->nsproxy->net_ns; + + return ns_capable(net->user_ns, CAP_SYS_ADMIN); +} + +static void *net_grab_current_ns(void) +{ + struct net *ns = current->nsproxy->net_ns; +#ifdef CONFIG_NET_NS + if (ns) + refcount_inc(&ns->passive); +#endif + return ns; +} + +static const void *net_initial_ns(void) +{ + return &init_net; +} + +static const void *net_netlink_ns(struct sock *sk) +{ + return sock_net(sk); +} + +const struct kobj_ns_type_operations net_ns_type_operations = { + .type = KOBJ_NS_TYPE_NET, + .current_may_mount = net_current_may_mount, + .grab_current_ns = net_grab_current_ns, + .netlink_ns = net_netlink_ns, + .initial_ns = net_initial_ns, + .drop_ns = net_drop_ns, +}; +EXPORT_SYMBOL_GPL(net_ns_type_operations); + +static int netdev_uevent(struct device *d, struct 
kobj_uevent_env *env) +{ + struct net_device *dev = to_net_dev(d); + int retval; + + /* pass interface to uevent. */ + retval = add_uevent_var(env, "INTERFACE=%s", dev->name); + if (retval) + goto exit; + + /* pass ifindex to uevent. + * ifindex is useful as it won't change (interface name may change) + * and is what RtNetlink uses natively. + */ + retval = add_uevent_var(env, "IFINDEX=%d", dev->ifindex); + +exit: + return retval; +} + +/* + * netdev_release -- destroy and free a dead device. + * Called when last reference to device kobject is gone. + */ +static void netdev_release(struct device *d) +{ + struct net_device *dev = to_net_dev(d); + + BUG_ON(dev->reg_state != NETREG_RELEASED); + + /* no need to wait for rcu grace period: + * device is dead and about to be freed. + */ + kfree(rcu_access_pointer(dev->ifalias)); + netdev_freemem(dev); +} + +static const void *net_namespace(struct device *d) +{ + struct net_device *dev = to_net_dev(d); + + return dev_net(dev); +} + +static void net_get_ownership(struct device *d, kuid_t *uid, kgid_t *gid) +{ + struct net_device *dev = to_net_dev(d); + const struct net *net = dev_net(dev); + + net_ns_get_ownership(net, uid, gid); +} + +static struct class net_class __ro_after_init = { + .name = "net", + .dev_release = netdev_release, + .dev_groups = net_class_groups, + .dev_uevent = netdev_uevent, + .ns_type = &net_ns_type_operations, + .namespace = net_namespace, + .get_ownership = net_get_ownership, +}; + +#ifdef CONFIG_OF +static int of_dev_node_match(struct device *dev, const void *data) +{ + for (; dev; dev = dev->parent) { + if (dev->of_node == data) + return 1; + } + + return 0; +} + +/* + * of_find_net_device_by_node - lookup the net device for the device node + * @np: OF device node + * + * Looks up the net_device structure corresponding with the device node. + * If successful, returns a pointer to the net_device with the embedded + * struct device refcount incremented by one, or NULL on failure. The + * refcount must be dropped when done with the net_device. + */ +struct net_device *of_find_net_device_by_node(struct device_node *np) +{ + struct device *dev; + + dev = class_find_device(&net_class, NULL, np, of_dev_node_match); + if (!dev) + return NULL; + + return to_net_dev(dev); +} +EXPORT_SYMBOL(of_find_net_device_by_node); +#endif + +/* Delete sysfs entries but hold kobject reference until after all + * netdev references are gone. + */ +void netdev_unregister_kobject(struct net_device *ndev) +{ + struct device *dev = &ndev->dev; + + if (!refcount_read(&dev_net(ndev)->ns.count)) + dev_set_uevent_suppress(dev, 1); + + kobject_get(&dev->kobj); + + remove_queue_kobjects(ndev); + + pm_runtime_set_memalloc_noio(dev, false); + + device_del(dev); +} + +/* Create sysfs entries for network device. 
*/ +int netdev_register_kobject(struct net_device *ndev) +{ + struct device *dev = &ndev->dev; + const struct attribute_group **groups = ndev->sysfs_groups; + int error = 0; + + device_initialize(dev); + dev->class = &net_class; + dev->platform_data = ndev; + dev->groups = groups; + + dev_set_name(dev, "%s", ndev->name); + +#ifdef CONFIG_SYSFS + /* Allow for a device specific group */ + if (*groups) + groups++; + + *groups++ = &netstat_group; + + if (wireless_group_needed(ndev)) + *groups++ = &wireless_group; +#endif /* CONFIG_SYSFS */ + + error = device_add(dev); + if (error) + return error; + + error = register_queue_kobjects(ndev); + if (error) { + device_del(dev); + return error; + } + + pm_runtime_set_memalloc_noio(dev, true); + + return error; +} + +/* Change owner for sysfs entries when moving network devices across network + * namespaces owned by different user namespaces. + */ +int netdev_change_owner(struct net_device *ndev, const struct net *net_old, + const struct net *net_new) +{ + kuid_t old_uid = GLOBAL_ROOT_UID, new_uid = GLOBAL_ROOT_UID; + kgid_t old_gid = GLOBAL_ROOT_GID, new_gid = GLOBAL_ROOT_GID; + struct device *dev = &ndev->dev; + int error; + + net_ns_get_ownership(net_old, &old_uid, &old_gid); + net_ns_get_ownership(net_new, &new_uid, &new_gid); + + /* The network namespace was changed but the owning user namespace is + * identical so there's no need to change the owner of sysfs entries. + */ + if (uid_eq(old_uid, new_uid) && gid_eq(old_gid, new_gid)) + return 0; + + error = device_change_owner(dev, new_uid, new_gid); + if (error) + return error; + + error = queue_change_owner(ndev, new_uid, new_gid); + if (error) + return error; + + return 0; +} + +int netdev_class_create_file_ns(const struct class_attribute *class_attr, + const void *ns) +{ + return class_create_file_ns(&net_class, class_attr, ns); +} +EXPORT_SYMBOL(netdev_class_create_file_ns); + +void netdev_class_remove_file_ns(const struct class_attribute *class_attr, + const void *ns) +{ + class_remove_file_ns(&net_class, class_attr, ns); +} +EXPORT_SYMBOL(netdev_class_remove_file_ns); + +int __init netdev_kobject_init(void) +{ + kobj_ns_type_register(&net_ns_type_operations); + return class_register(&net_class); +} diff --git a/net/core/net-sysfs.h b/net/core/net-sysfs.h new file mode 100644 index 000000000..8a5b04c26 --- /dev/null +++ b/net/core/net-sysfs.h @@ -0,0 +1,14 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __NET_SYSFS_H__ +#define __NET_SYSFS_H__ + +int __init netdev_kobject_init(void); +int netdev_register_kobject(struct net_device *); +void netdev_unregister_kobject(struct net_device *); +int net_rx_queue_update_kobjects(struct net_device *, int old_num, int new_num); +int netdev_queue_update_kobjects(struct net_device *net, + int old_num, int new_num); +int netdev_change_owner(struct net_device *, const struct net *net_old, + const struct net *net_new); + +#endif diff --git a/net/core/net-traces.c b/net/core/net-traces.c new file mode 100644 index 000000000..c40cd8dd7 --- /dev/null +++ b/net/core/net-traces.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * consolidates trace point definitions + * + * Copyright (C) 2009 Neil Horman + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define CREATE_TRACE_POINTS +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_BRIDGE) +#include 
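+/* The bridge FDB tracepoints below are defined in this file (because
+ * CREATE_TRACE_POINTS is set above) and exported so that the bridge
+ * code, which may be built as a module, can fire them.
+ */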
+EXPORT_TRACEPOINT_SYMBOL_GPL(br_fdb_add); +EXPORT_TRACEPOINT_SYMBOL_GPL(br_fdb_external_learn_add); +EXPORT_TRACEPOINT_SYMBOL_GPL(fdb_delete); +EXPORT_TRACEPOINT_SYMBOL_GPL(br_fdb_update); +#endif + +#if IS_ENABLED(CONFIG_PAGE_POOL) +#include +#endif + +#include +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_update); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_update_done); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_timer_handler); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_event_send_done); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_event_send_dead); +EXPORT_TRACEPOINT_SYMBOL_GPL(neigh_cleanup_and_release); + +EXPORT_TRACEPOINT_SYMBOL_GPL(kfree_skb); + +EXPORT_TRACEPOINT_SYMBOL_GPL(napi_poll); + +EXPORT_TRACEPOINT_SYMBOL_GPL(tcp_send_reset); +EXPORT_TRACEPOINT_SYMBOL_GPL(tcp_bad_csum); diff --git a/net/core/net_namespace.c b/net/core/net_namespace.c new file mode 100644 index 000000000..4c1707d0e --- /dev/null +++ b/net/core/net_namespace.c @@ -0,0 +1,1388 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +/* + * Our network namespace constructor/destructor lists + */ + +static LIST_HEAD(pernet_list); +static struct list_head *first_device = &pernet_list; + +LIST_HEAD(net_namespace_list); +EXPORT_SYMBOL_GPL(net_namespace_list); + +/* Protects net_namespace_list. Nests iside rtnl_lock() */ +DECLARE_RWSEM(net_rwsem); +EXPORT_SYMBOL_GPL(net_rwsem); + +#ifdef CONFIG_KEYS +static struct key_tag init_net_key_domain = { .usage = REFCOUNT_INIT(1) }; +#endif + +struct net init_net; +EXPORT_SYMBOL(init_net); + +static bool init_net_initialized; +/* + * pernet_ops_rwsem: protects: pernet_list, net_generic_ids, + * init_net_initialized and first_device pointer. + * This is internal net namespace object. Please, don't use it + * outside. + */ +DECLARE_RWSEM(pernet_ops_rwsem); +EXPORT_SYMBOL_GPL(pernet_ops_rwsem); + +#define MIN_PERNET_OPS_ID \ + ((sizeof(struct net_generic) + sizeof(void *) - 1) / sizeof(void *)) + +#define INITIAL_NET_GEN_PTRS 13 /* +1 for len +2 for rcu_head */ + +static unsigned int max_gen_ptrs = INITIAL_NET_GEN_PTRS; + +DEFINE_COOKIE(net_cookie); + +static struct net_generic *net_alloc_generic(void) +{ + struct net_generic *ng; + unsigned int generic_size = offsetof(struct net_generic, ptr[max_gen_ptrs]); + + ng = kzalloc(generic_size, GFP_KERNEL); + if (ng) + ng->s.len = max_gen_ptrs; + + return ng; +} + +static int net_assign_generic(struct net *net, unsigned int id, void *data) +{ + struct net_generic *ng, *old_ng; + + BUG_ON(id < MIN_PERNET_OPS_ID); + + old_ng = rcu_dereference_protected(net->gen, + lockdep_is_held(&pernet_ops_rwsem)); + if (old_ng->s.len > id) { + old_ng->ptr[id] = data; + return 0; + } + + ng = net_alloc_generic(); + if (!ng) + return -ENOMEM; + + /* + * Some synchronisation notes: + * + * The net_generic explores the net->gen array inside rcu + * read section. Besides once set the net->gen->ptr[x] + * pointer never changes (see rules in netns/generic.h). + * + * That said, we simply duplicate this array and schedule + * the old copy for kfree after a grace period. 
+ */ + + memcpy(&ng->ptr[MIN_PERNET_OPS_ID], &old_ng->ptr[MIN_PERNET_OPS_ID], + (old_ng->s.len - MIN_PERNET_OPS_ID) * sizeof(void *)); + ng->ptr[id] = data; + + rcu_assign_pointer(net->gen, ng); + kfree_rcu(old_ng, s.rcu); + return 0; +} + +static int ops_init(const struct pernet_operations *ops, struct net *net) +{ + struct net_generic *ng; + int err = -ENOMEM; + void *data = NULL; + + if (ops->id && ops->size) { + data = kzalloc(ops->size, GFP_KERNEL); + if (!data) + goto out; + + err = net_assign_generic(net, *ops->id, data); + if (err) + goto cleanup; + } + err = 0; + if (ops->init) + err = ops->init(net); + if (!err) + return 0; + + if (ops->id && ops->size) { + ng = rcu_dereference_protected(net->gen, + lockdep_is_held(&pernet_ops_rwsem)); + ng->ptr[*ops->id] = NULL; + } + +cleanup: + kfree(data); + +out: + return err; +} + +static void ops_pre_exit_list(const struct pernet_operations *ops, + struct list_head *net_exit_list) +{ + struct net *net; + + if (ops->pre_exit) { + list_for_each_entry(net, net_exit_list, exit_list) + ops->pre_exit(net); + } +} + +static void ops_exit_list(const struct pernet_operations *ops, + struct list_head *net_exit_list) +{ + struct net *net; + if (ops->exit) { + list_for_each_entry(net, net_exit_list, exit_list) { + ops->exit(net); + cond_resched(); + } + } + if (ops->exit_batch) + ops->exit_batch(net_exit_list); +} + +static void ops_free_list(const struct pernet_operations *ops, + struct list_head *net_exit_list) +{ + struct net *net; + if (ops->size && ops->id) { + list_for_each_entry(net, net_exit_list, exit_list) + kfree(net_generic(net, *ops->id)); + } +} + +/* should be called with nsid_lock held */ +static int alloc_netid(struct net *net, struct net *peer, int reqid) +{ + int min = 0, max = 0; + + if (reqid >= 0) { + min = reqid; + max = reqid + 1; + } + + return idr_alloc(&net->netns_ids, peer, min, max, GFP_ATOMIC); +} + +/* This function is used by idr_for_each(). If net is equal to peer, the + * function returns the id so that idr_for_each() stops. Because we cannot + * returns the id 0 (idr_for_each() will not stop), we return the magic value + * NET_ID_ZERO (-1) for it. + */ +#define NET_ID_ZERO -1 +static int net_eq_idr(int id, void *net, void *peer) +{ + if (net_eq(net, peer)) + return id ? : NET_ID_ZERO; + return 0; +} + +/* Must be called from RCU-critical section or with nsid_lock held */ +static int __peernet2id(const struct net *net, struct net *peer) +{ + int id = idr_for_each(&net->netns_ids, net_eq_idr, peer); + + /* Magic value for id 0. */ + if (id == NET_ID_ZERO) + return 0; + if (id > 0) + return id; + + return NETNSA_NSID_NOT_ASSIGNED; +} + +static void rtnl_net_notifyid(struct net *net, int cmd, int id, u32 portid, + struct nlmsghdr *nlh, gfp_t gfp); +/* This function returns the id of a peer netns. If no id is assigned, one will + * be allocated and returned. + */ +int peernet2id_alloc(struct net *net, struct net *peer, gfp_t gfp) +{ + int id; + + if (refcount_read(&net->ns.count) == 0) + return NETNSA_NSID_NOT_ASSIGNED; + + spin_lock_bh(&net->nsid_lock); + id = __peernet2id(net, peer); + if (id >= 0) { + spin_unlock_bh(&net->nsid_lock); + return id; + } + + /* When peer is obtained from RCU lists, we may race with + * its cleanup. Check whether it's alive, and this guarantees + * we never hash a peer back to net->netns_ids, after it has + * just been idr_remove()'d from there in cleanup_net(). 
+ */ + if (!maybe_get_net(peer)) { + spin_unlock_bh(&net->nsid_lock); + return NETNSA_NSID_NOT_ASSIGNED; + } + + id = alloc_netid(net, peer, -1); + spin_unlock_bh(&net->nsid_lock); + + put_net(peer); + if (id < 0) + return NETNSA_NSID_NOT_ASSIGNED; + + rtnl_net_notifyid(net, RTM_NEWNSID, id, 0, NULL, gfp); + + return id; +} +EXPORT_SYMBOL_GPL(peernet2id_alloc); + +/* This function returns, if assigned, the id of a peer netns. */ +int peernet2id(const struct net *net, struct net *peer) +{ + int id; + + rcu_read_lock(); + id = __peernet2id(net, peer); + rcu_read_unlock(); + + return id; +} +EXPORT_SYMBOL(peernet2id); + +/* This function returns true is the peer netns has an id assigned into the + * current netns. + */ +bool peernet_has_id(const struct net *net, struct net *peer) +{ + return peernet2id(net, peer) >= 0; +} + +struct net *get_net_ns_by_id(const struct net *net, int id) +{ + struct net *peer; + + if (id < 0) + return NULL; + + rcu_read_lock(); + peer = idr_find(&net->netns_ids, id); + if (peer) + peer = maybe_get_net(peer); + rcu_read_unlock(); + + return peer; +} +EXPORT_SYMBOL_GPL(get_net_ns_by_id); + +/* + * setup_net runs the initializers for the network namespace object. + */ +static __net_init int setup_net(struct net *net, struct user_namespace *user_ns) +{ + /* Must be called with pernet_ops_rwsem held */ + const struct pernet_operations *ops, *saved_ops; + int error = 0; + LIST_HEAD(net_exit_list); + + refcount_set(&net->ns.count, 1); + ref_tracker_dir_init(&net->refcnt_tracker, 128); + + refcount_set(&net->passive, 1); + get_random_bytes(&net->hash_mix, sizeof(u32)); + preempt_disable(); + net->net_cookie = gen_cookie_next(&net_cookie); + preempt_enable(); + net->dev_base_seq = 1; + net->user_ns = user_ns; + idr_init(&net->netns_ids); + spin_lock_init(&net->nsid_lock); + mutex_init(&net->ipv4.ra_mutex); + + list_for_each_entry(ops, &pernet_list, list) { + error = ops_init(ops, net); + if (error < 0) + goto out_undo; + } + down_write(&net_rwsem); + list_add_tail_rcu(&net->list, &net_namespace_list); + up_write(&net_rwsem); +out: + return error; + +out_undo: + /* Walk through the list backwards calling the exit functions + * for the pernet modules whose init functions did not fail. 
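+	 * The teardown below mirrors cleanup_net(): pre_exit callbacks first,
+	 * then synchronize_rcu() so concurrent RCU readers are done, then the
+	 * exit callbacks, and finally the per-net generic data is freed.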
+ */ + list_add(&net->exit_list, &net_exit_list); + saved_ops = ops; + list_for_each_entry_continue_reverse(ops, &pernet_list, list) + ops_pre_exit_list(ops, &net_exit_list); + + synchronize_rcu(); + + ops = saved_ops; + list_for_each_entry_continue_reverse(ops, &pernet_list, list) + ops_exit_list(ops, &net_exit_list); + + ops = saved_ops; + list_for_each_entry_continue_reverse(ops, &pernet_list, list) + ops_free_list(ops, &net_exit_list); + + rcu_barrier(); + goto out; +} + +static int __net_init net_defaults_init_net(struct net *net) +{ + net->core.sysctl_somaxconn = SOMAXCONN; + net->core.sysctl_txrehash = SOCK_TXREHASH_ENABLED; + + return 0; +} + +static struct pernet_operations net_defaults_ops = { + .init = net_defaults_init_net, +}; + +static __init int net_defaults_init(void) +{ + if (register_pernet_subsys(&net_defaults_ops)) + panic("Cannot initialize net default settings"); + + return 0; +} + +core_initcall(net_defaults_init); + +#ifdef CONFIG_NET_NS +static struct ucounts *inc_net_namespaces(struct user_namespace *ns) +{ + return inc_ucount(ns, current_euid(), UCOUNT_NET_NAMESPACES); +} + +static void dec_net_namespaces(struct ucounts *ucounts) +{ + dec_ucount(ucounts, UCOUNT_NET_NAMESPACES); +} + +static struct kmem_cache *net_cachep __ro_after_init; +static struct workqueue_struct *netns_wq; + +static struct net *net_alloc(void) +{ + struct net *net = NULL; + struct net_generic *ng; + + ng = net_alloc_generic(); + if (!ng) + goto out; + + net = kmem_cache_zalloc(net_cachep, GFP_KERNEL); + if (!net) + goto out_free; + +#ifdef CONFIG_KEYS + net->key_domain = kzalloc(sizeof(struct key_tag), GFP_KERNEL); + if (!net->key_domain) + goto out_free_2; + refcount_set(&net->key_domain->usage, 1); +#endif + + rcu_assign_pointer(net->gen, ng); +out: + return net; + +#ifdef CONFIG_KEYS +out_free_2: + kmem_cache_free(net_cachep, net); + net = NULL; +#endif +out_free: + kfree(ng); + goto out; +} + +static void net_free(struct net *net) +{ + if (refcount_dec_and_test(&net->passive)) { + kfree(rcu_access_pointer(net->gen)); + kmem_cache_free(net_cachep, net); + } +} + +void net_drop_ns(void *p) +{ + struct net *net = (struct net *)p; + + if (net) + net_free(net); +} + +struct net *copy_net_ns(unsigned long flags, + struct user_namespace *user_ns, struct net *old_net) +{ + struct ucounts *ucounts; + struct net *net; + int rv; + + if (!(flags & CLONE_NEWNET)) + return get_net(old_net); + + ucounts = inc_net_namespaces(user_ns); + if (!ucounts) + return ERR_PTR(-ENOSPC); + + net = net_alloc(); + if (!net) { + rv = -ENOMEM; + goto dec_ucounts; + } + refcount_set(&net->passive, 1); + net->ucounts = ucounts; + get_user_ns(user_ns); + + rv = down_read_killable(&pernet_ops_rwsem); + if (rv < 0) + goto put_userns; + + rv = setup_net(net, user_ns); + + up_read(&pernet_ops_rwsem); + + if (rv < 0) { +put_userns: +#ifdef CONFIG_KEYS + key_remove_domain(net->key_domain); +#endif + put_user_ns(user_ns); + net_free(net); +dec_ucounts: + dec_net_namespaces(ucounts); + return ERR_PTR(rv); + } + return net; +} + +/** + * net_ns_get_ownership - get sysfs ownership data for @net + * @net: network namespace in question (can be NULL) + * @uid: kernel user ID for sysfs objects + * @gid: kernel group ID for sysfs objects + * + * Returns the uid/gid pair of root in the user namespace associated with the + * given network namespace. 
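+ * If the namespace root does not map to a valid kernel uid or gid, the
+ * corresponding output parameter is left untouched, so callers should
+ * pre-initialise *uid and *gid (typically to GLOBAL_ROOT_UID/GID).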
+ */ +void net_ns_get_ownership(const struct net *net, kuid_t *uid, kgid_t *gid) +{ + if (net) { + kuid_t ns_root_uid = make_kuid(net->user_ns, 0); + kgid_t ns_root_gid = make_kgid(net->user_ns, 0); + + if (uid_valid(ns_root_uid)) + *uid = ns_root_uid; + + if (gid_valid(ns_root_gid)) + *gid = ns_root_gid; + } else { + *uid = GLOBAL_ROOT_UID; + *gid = GLOBAL_ROOT_GID; + } +} +EXPORT_SYMBOL_GPL(net_ns_get_ownership); + +static void unhash_nsid(struct net *net, struct net *last) +{ + struct net *tmp; + /* This function is only called from cleanup_net() work, + * and this work is the only process, that may delete + * a net from net_namespace_list. So, when the below + * is executing, the list may only grow. Thus, we do not + * use for_each_net_rcu() or net_rwsem. + */ + for_each_net(tmp) { + int id; + + spin_lock_bh(&tmp->nsid_lock); + id = __peernet2id(tmp, net); + if (id >= 0) + idr_remove(&tmp->netns_ids, id); + spin_unlock_bh(&tmp->nsid_lock); + if (id >= 0) + rtnl_net_notifyid(tmp, RTM_DELNSID, id, 0, NULL, + GFP_KERNEL); + if (tmp == last) + break; + } + spin_lock_bh(&net->nsid_lock); + idr_destroy(&net->netns_ids); + spin_unlock_bh(&net->nsid_lock); +} + +static LLIST_HEAD(cleanup_list); + +static void cleanup_net(struct work_struct *work) +{ + const struct pernet_operations *ops; + struct net *net, *tmp, *last; + struct llist_node *net_kill_list; + LIST_HEAD(net_exit_list); + + /* Atomically snapshot the list of namespaces to cleanup */ + net_kill_list = llist_del_all(&cleanup_list); + + down_read(&pernet_ops_rwsem); + + /* Don't let anyone else find us. */ + down_write(&net_rwsem); + llist_for_each_entry(net, net_kill_list, cleanup_list) + list_del_rcu(&net->list); + /* Cache last net. After we unlock rtnl, no one new net + * added to net_namespace_list can assign nsid pointer + * to a net from net_kill_list (see peernet2id_alloc()). + * So, we skip them in unhash_nsid(). + * + * Note, that unhash_nsid() does not delete nsid links + * between net_kill_list's nets, as they've already + * deleted from net_namespace_list. But, this would be + * useless anyway, as netns_ids are destroyed there. + */ + last = list_last_entry(&net_namespace_list, struct net, list); + up_write(&net_rwsem); + + llist_for_each_entry(net, net_kill_list, cleanup_list) { + unhash_nsid(net, last); + list_add_tail(&net->exit_list, &net_exit_list); + } + + /* Run all of the network namespace pre_exit methods */ + list_for_each_entry_reverse(ops, &pernet_list, list) + ops_pre_exit_list(ops, &net_exit_list); + + /* + * Another CPU might be rcu-iterating the list, wait for it. + * This needs to be before calling the exit() notifiers, so + * the rcu_barrier() below isn't sufficient alone. + * Also the pre_exit() and exit() methods need this barrier. + */ + synchronize_rcu(); + + /* Run all of the network namespace exit methods */ + list_for_each_entry_reverse(ops, &pernet_list, list) + ops_exit_list(ops, &net_exit_list); + + /* Free the net generic variables */ + list_for_each_entry_reverse(ops, &pernet_list, list) + ops_free_list(ops, &net_exit_list); + + up_read(&pernet_ops_rwsem); + + /* Ensure there are no outstanding rcu callbacks using this + * network namespace. 
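+	 * Unlike the synchronize_rcu() above, rcu_barrier() also waits for
+	 * pending RCU callbacks to finish running, not just for readers to
+	 * leave their read-side critical sections.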
+ */ + rcu_barrier(); + + /* Finally it is safe to free my network namespace structure */ + list_for_each_entry_safe(net, tmp, &net_exit_list, exit_list) { + list_del_init(&net->exit_list); + dec_net_namespaces(net->ucounts); +#ifdef CONFIG_KEYS + key_remove_domain(net->key_domain); +#endif + put_user_ns(net->user_ns); + net_free(net); + } +} + +/** + * net_ns_barrier - wait until concurrent net_cleanup_work is done + * + * cleanup_net runs from work queue and will first remove namespaces + * from the global list, then run net exit functions. + * + * Call this in module exit path to make sure that all netns + * ->exit ops have been invoked before the function is removed. + */ +void net_ns_barrier(void) +{ + down_write(&pernet_ops_rwsem); + up_write(&pernet_ops_rwsem); +} +EXPORT_SYMBOL(net_ns_barrier); + +static DECLARE_WORK(net_cleanup_work, cleanup_net); + +void __put_net(struct net *net) +{ + ref_tracker_dir_exit(&net->refcnt_tracker); + /* Cleanup the network namespace in process context */ + if (llist_add(&net->cleanup_list, &cleanup_list)) + queue_work(netns_wq, &net_cleanup_work); +} +EXPORT_SYMBOL_GPL(__put_net); + +/** + * get_net_ns - increment the refcount of the network namespace + * @ns: common namespace (net) + * + * Returns the net's common namespace. + */ +struct ns_common *get_net_ns(struct ns_common *ns) +{ + return &get_net(container_of(ns, struct net, ns))->ns; +} +EXPORT_SYMBOL_GPL(get_net_ns); + +struct net *get_net_ns_by_fd(int fd) +{ + struct file *file; + struct ns_common *ns; + struct net *net; + + file = proc_ns_fget(fd); + if (IS_ERR(file)) + return ERR_CAST(file); + + ns = get_proc_ns(file_inode(file)); + if (ns->ops == &netns_operations) + net = get_net(container_of(ns, struct net, ns)); + else + net = ERR_PTR(-EINVAL); + + fput(file); + return net; +} +EXPORT_SYMBOL_GPL(get_net_ns_by_fd); +#endif + +struct net *get_net_ns_by_pid(pid_t pid) +{ + struct task_struct *tsk; + struct net *net; + + /* Lookup the network namespace */ + net = ERR_PTR(-ESRCH); + rcu_read_lock(); + tsk = find_task_by_vpid(pid); + if (tsk) { + struct nsproxy *nsproxy; + task_lock(tsk); + nsproxy = tsk->nsproxy; + if (nsproxy) + net = get_net(nsproxy->net_ns); + task_unlock(tsk); + } + rcu_read_unlock(); + return net; +} +EXPORT_SYMBOL_GPL(get_net_ns_by_pid); + +static __net_init int net_ns_net_init(struct net *net) +{ +#ifdef CONFIG_NET_NS + net->ns.ops = &netns_operations; +#endif + return ns_alloc_inum(&net->ns); +} + +static __net_exit void net_ns_net_exit(struct net *net) +{ + ns_free_inum(&net->ns); +} + +static struct pernet_operations __net_initdata net_ns_ops = { + .init = net_ns_net_init, + .exit = net_ns_net_exit, +}; + +static const struct nla_policy rtnl_net_policy[NETNSA_MAX + 1] = { + [NETNSA_NONE] = { .type = NLA_UNSPEC }, + [NETNSA_NSID] = { .type = NLA_S32 }, + [NETNSA_PID] = { .type = NLA_U32 }, + [NETNSA_FD] = { .type = NLA_U32 }, + [NETNSA_TARGET_NSID] = { .type = NLA_S32 }, +}; + +static int rtnl_net_newid(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[NETNSA_MAX + 1]; + struct nlattr *nla; + struct net *peer; + int nsid, err; + + err = nlmsg_parse_deprecated(nlh, sizeof(struct rtgenmsg), tb, + NETNSA_MAX, rtnl_net_policy, extack); + if (err < 0) + return err; + if (!tb[NETNSA_NSID]) { + NL_SET_ERR_MSG(extack, "nsid is missing"); + return -EINVAL; + } + nsid = nla_get_s32(tb[NETNSA_NSID]); + + if (tb[NETNSA_PID]) { + peer = get_net_ns_by_pid(nla_get_u32(tb[NETNSA_PID])); + nla 
= tb[NETNSA_PID]; + } else if (tb[NETNSA_FD]) { + peer = get_net_ns_by_fd(nla_get_u32(tb[NETNSA_FD])); + nla = tb[NETNSA_FD]; + } else { + NL_SET_ERR_MSG(extack, "Peer netns reference is missing"); + return -EINVAL; + } + if (IS_ERR(peer)) { + NL_SET_BAD_ATTR(extack, nla); + NL_SET_ERR_MSG(extack, "Peer netns reference is invalid"); + return PTR_ERR(peer); + } + + spin_lock_bh(&net->nsid_lock); + if (__peernet2id(net, peer) >= 0) { + spin_unlock_bh(&net->nsid_lock); + err = -EEXIST; + NL_SET_BAD_ATTR(extack, nla); + NL_SET_ERR_MSG(extack, + "Peer netns already has a nsid assigned"); + goto out; + } + + err = alloc_netid(net, peer, nsid); + spin_unlock_bh(&net->nsid_lock); + if (err >= 0) { + rtnl_net_notifyid(net, RTM_NEWNSID, err, NETLINK_CB(skb).portid, + nlh, GFP_KERNEL); + err = 0; + } else if (err == -ENOSPC && nsid >= 0) { + err = -EEXIST; + NL_SET_BAD_ATTR(extack, tb[NETNSA_NSID]); + NL_SET_ERR_MSG(extack, "The specified nsid is already used"); + } +out: + put_net(peer); + return err; +} + +static int rtnl_net_get_size(void) +{ + return NLMSG_ALIGN(sizeof(struct rtgenmsg)) + + nla_total_size(sizeof(s32)) /* NETNSA_NSID */ + + nla_total_size(sizeof(s32)) /* NETNSA_CURRENT_NSID */ + ; +} + +struct net_fill_args { + u32 portid; + u32 seq; + int flags; + int cmd; + int nsid; + bool add_ref; + int ref_nsid; +}; + +static int rtnl_net_fill(struct sk_buff *skb, struct net_fill_args *args) +{ + struct nlmsghdr *nlh; + struct rtgenmsg *rth; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->cmd, sizeof(*rth), + args->flags); + if (!nlh) + return -EMSGSIZE; + + rth = nlmsg_data(nlh); + rth->rtgen_family = AF_UNSPEC; + + if (nla_put_s32(skb, NETNSA_NSID, args->nsid)) + goto nla_put_failure; + + if (args->add_ref && + nla_put_s32(skb, NETNSA_CURRENT_NSID, args->ref_nsid)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int rtnl_net_valid_getid_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(struct rtgenmsg), + tb, NETNSA_MAX, rtnl_net_policy, + extack); + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct rtgenmsg), tb, + NETNSA_MAX, rtnl_net_policy, + extack); + if (err) + return err; + + for (i = 0; i <= NETNSA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETNSA_PID: + case NETNSA_FD: + case NETNSA_NSID: + case NETNSA_TARGET_NSID: + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in peer netns getid request"); + return -EINVAL; + } + } + + return 0; +} + +static int rtnl_net_getid(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[NETNSA_MAX + 1]; + struct net_fill_args fillargs = { + .portid = NETLINK_CB(skb).portid, + .seq = nlh->nlmsg_seq, + .cmd = RTM_NEWNSID, + }; + struct net *peer, *target = net; + struct nlattr *nla; + struct sk_buff *msg; + int err; + + err = rtnl_net_valid_getid_req(skb, nlh, tb, extack); + if (err < 0) + return err; + if (tb[NETNSA_PID]) { + peer = get_net_ns_by_pid(nla_get_u32(tb[NETNSA_PID])); + nla = tb[NETNSA_PID]; + } else if (tb[NETNSA_FD]) { + peer = get_net_ns_by_fd(nla_get_u32(tb[NETNSA_FD])); + nla = tb[NETNSA_FD]; + } else if (tb[NETNSA_NSID]) { + peer = get_net_ns_by_id(net, nla_get_s32(tb[NETNSA_NSID])); + if (!peer) + peer = ERR_PTR(-ENOENT); + nla = tb[NETNSA_NSID]; + 
} else { + NL_SET_ERR_MSG(extack, "Peer netns reference is missing"); + return -EINVAL; + } + + if (IS_ERR(peer)) { + NL_SET_BAD_ATTR(extack, nla); + NL_SET_ERR_MSG(extack, "Peer netns reference is invalid"); + return PTR_ERR(peer); + } + + if (tb[NETNSA_TARGET_NSID]) { + int id = nla_get_s32(tb[NETNSA_TARGET_NSID]); + + target = rtnl_get_net_ns_capable(NETLINK_CB(skb).sk, id); + if (IS_ERR(target)) { + NL_SET_BAD_ATTR(extack, tb[NETNSA_TARGET_NSID]); + NL_SET_ERR_MSG(extack, + "Target netns reference is invalid"); + err = PTR_ERR(target); + goto out; + } + fillargs.add_ref = true; + fillargs.ref_nsid = peernet2id(net, peer); + } + + msg = nlmsg_new(rtnl_net_get_size(), GFP_KERNEL); + if (!msg) { + err = -ENOMEM; + goto out; + } + + fillargs.nsid = peernet2id(target, peer); + err = rtnl_net_fill(msg, &fillargs); + if (err < 0) + goto err_out; + + err = rtnl_unicast(msg, net, NETLINK_CB(skb).portid); + goto out; + +err_out: + nlmsg_free(msg); +out: + if (fillargs.add_ref) + put_net(target); + put_net(peer); + return err; +} + +struct rtnl_net_dump_cb { + struct net *tgt_net; + struct net *ref_net; + struct sk_buff *skb; + struct net_fill_args fillargs; + int idx; + int s_idx; +}; + +/* Runs in RCU-critical section. */ +static int rtnl_net_dumpid_one(int id, void *peer, void *data) +{ + struct rtnl_net_dump_cb *net_cb = (struct rtnl_net_dump_cb *)data; + int ret; + + if (net_cb->idx < net_cb->s_idx) + goto cont; + + net_cb->fillargs.nsid = id; + if (net_cb->fillargs.add_ref) + net_cb->fillargs.ref_nsid = __peernet2id(net_cb->ref_net, peer); + ret = rtnl_net_fill(net_cb->skb, &net_cb->fillargs); + if (ret < 0) + return ret; + +cont: + net_cb->idx++; + return 0; +} + +static int rtnl_valid_dump_net_req(const struct nlmsghdr *nlh, struct sock *sk, + struct rtnl_net_dump_cb *net_cb, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[NETNSA_MAX + 1]; + int err, i; + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct rtgenmsg), tb, + NETNSA_MAX, rtnl_net_policy, + extack); + if (err < 0) + return err; + + for (i = 0; i <= NETNSA_MAX; i++) { + if (!tb[i]) + continue; + + if (i == NETNSA_TARGET_NSID) { + struct net *net; + + net = rtnl_get_net_ns_capable(sk, nla_get_s32(tb[i])); + if (IS_ERR(net)) { + NL_SET_BAD_ATTR(extack, tb[i]); + NL_SET_ERR_MSG(extack, + "Invalid target network namespace id"); + return PTR_ERR(net); + } + net_cb->fillargs.add_ref = true; + net_cb->ref_net = net_cb->tgt_net; + net_cb->tgt_net = net; + } else { + NL_SET_BAD_ATTR(extack, tb[i]); + NL_SET_ERR_MSG(extack, + "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int rtnl_net_dumpid(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rtnl_net_dump_cb net_cb = { + .tgt_net = sock_net(skb->sk), + .skb = skb, + .fillargs = { + .portid = NETLINK_CB(cb->skb).portid, + .seq = cb->nlh->nlmsg_seq, + .flags = NLM_F_MULTI, + .cmd = RTM_NEWNSID, + }, + .idx = 0, + .s_idx = cb->args[0], + }; + int err = 0; + + if (cb->strict_check) { + err = rtnl_valid_dump_net_req(cb->nlh, skb->sk, &net_cb, cb); + if (err < 0) + goto end; + } + + rcu_read_lock(); + idr_for_each(&net_cb.tgt_net->netns_ids, rtnl_net_dumpid_one, &net_cb); + rcu_read_unlock(); + + cb->args[0] = net_cb.idx; +end: + if (net_cb.fillargs.add_ref) + put_net(net_cb.tgt_net); + return err < 0 ? 
err : skb->len; +} + +static void rtnl_net_notifyid(struct net *net, int cmd, int id, u32 portid, + struct nlmsghdr *nlh, gfp_t gfp) +{ + struct net_fill_args fillargs = { + .portid = portid, + .seq = nlh ? nlh->nlmsg_seq : 0, + .cmd = cmd, + .nsid = id, + }; + struct sk_buff *msg; + int err = -ENOMEM; + + msg = nlmsg_new(rtnl_net_get_size(), gfp); + if (!msg) + goto out; + + err = rtnl_net_fill(msg, &fillargs); + if (err < 0) + goto err_out; + + rtnl_notify(msg, net, portid, RTNLGRP_NSID, nlh, gfp); + return; + +err_out: + nlmsg_free(msg); +out: + rtnl_set_sk_err(net, RTNLGRP_NSID, err); +} + +void __init net_ns_init(void) +{ + struct net_generic *ng; + +#ifdef CONFIG_NET_NS + net_cachep = kmem_cache_create("net_namespace", sizeof(struct net), + SMP_CACHE_BYTES, + SLAB_PANIC|SLAB_ACCOUNT, NULL); + + /* Create workqueue for cleanup */ + netns_wq = create_singlethread_workqueue("netns"); + if (!netns_wq) + panic("Could not create netns workq"); +#endif + + ng = net_alloc_generic(); + if (!ng) + panic("Could not allocate generic netns"); + + rcu_assign_pointer(init_net.gen, ng); + +#ifdef CONFIG_KEYS + init_net.key_domain = &init_net_key_domain; +#endif + down_write(&pernet_ops_rwsem); + if (setup_net(&init_net, &init_user_ns)) + panic("Could not setup the initial network namespace"); + + init_net_initialized = true; + up_write(&pernet_ops_rwsem); + + if (register_pernet_subsys(&net_ns_ops)) + panic("Could not register network namespace subsystems"); + + rtnl_register(PF_UNSPEC, RTM_NEWNSID, rtnl_net_newid, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register(PF_UNSPEC, RTM_GETNSID, rtnl_net_getid, rtnl_net_dumpid, + RTNL_FLAG_DOIT_UNLOCKED); +} + +static void free_exit_list(struct pernet_operations *ops, struct list_head *net_exit_list) +{ + ops_pre_exit_list(ops, net_exit_list); + synchronize_rcu(); + ops_exit_list(ops, net_exit_list); + ops_free_list(ops, net_exit_list); +} + +#ifdef CONFIG_NET_NS +static int __register_pernet_operations(struct list_head *list, + struct pernet_operations *ops) +{ + struct net *net; + int error; + LIST_HEAD(net_exit_list); + + list_add_tail(&ops->list, list); + if (ops->init || (ops->id && ops->size)) { + /* We held write locked pernet_ops_rwsem, and parallel + * setup_net() and cleanup_net() are not possible. 
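+		 * Holding pernet_ops_rwsem for writing also keeps net_namespace_list
+		 * stable, so for_each_net() below is safe without taking net_rwsem.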
+ */ + for_each_net(net) { + error = ops_init(ops, net); + if (error) + goto out_undo; + list_add_tail(&net->exit_list, &net_exit_list); + } + } + return 0; + +out_undo: + /* If I have an error cleanup all namespaces I initialized */ + list_del(&ops->list); + free_exit_list(ops, &net_exit_list); + return error; +} + +static void __unregister_pernet_operations(struct pernet_operations *ops) +{ + struct net *net; + LIST_HEAD(net_exit_list); + + list_del(&ops->list); + /* See comment in __register_pernet_operations() */ + for_each_net(net) + list_add_tail(&net->exit_list, &net_exit_list); + + free_exit_list(ops, &net_exit_list); +} + +#else + +static int __register_pernet_operations(struct list_head *list, + struct pernet_operations *ops) +{ + if (!init_net_initialized) { + list_add_tail(&ops->list, list); + return 0; + } + + return ops_init(ops, &init_net); +} + +static void __unregister_pernet_operations(struct pernet_operations *ops) +{ + if (!init_net_initialized) { + list_del(&ops->list); + } else { + LIST_HEAD(net_exit_list); + list_add(&init_net.exit_list, &net_exit_list); + free_exit_list(ops, &net_exit_list); + } +} + +#endif /* CONFIG_NET_NS */ + +static DEFINE_IDA(net_generic_ids); + +static int register_pernet_operations(struct list_head *list, + struct pernet_operations *ops) +{ + int error; + + if (ops->id) { + error = ida_alloc_min(&net_generic_ids, MIN_PERNET_OPS_ID, + GFP_KERNEL); + if (error < 0) + return error; + *ops->id = error; + max_gen_ptrs = max(max_gen_ptrs, *ops->id + 1); + } + error = __register_pernet_operations(list, ops); + if (error) { + rcu_barrier(); + if (ops->id) + ida_free(&net_generic_ids, *ops->id); + } + + return error; +} + +static void unregister_pernet_operations(struct pernet_operations *ops) +{ + __unregister_pernet_operations(ops); + rcu_barrier(); + if (ops->id) + ida_free(&net_generic_ids, *ops->id); +} + +/** + * register_pernet_subsys - register a network namespace subsystem + * @ops: pernet operations structure for the subsystem + * + * Register a subsystem which has init and exit functions + * that are called when network namespaces are created and + * destroyed respectively. + * + * When registered all network namespace init functions are + * called for every existing network namespace. Allowing kernel + * modules to have a race free view of the set of network namespaces. + * + * When a new network namespace is created all of the init + * methods are called in the order in which they were registered. + * + * When a network namespace is destroyed all of the exit methods + * are called in the reverse of the order with which they were + * registered. + */ +int register_pernet_subsys(struct pernet_operations *ops) +{ + int error; + down_write(&pernet_ops_rwsem); + error = register_pernet_operations(first_device, ops); + up_write(&pernet_ops_rwsem); + return error; +} +EXPORT_SYMBOL_GPL(register_pernet_subsys); + +/** + * unregister_pernet_subsys - unregister a network namespace subsystem + * @ops: pernet operations structure to manipulate + * + * Remove the pernet operations structure from the list to be + * used when network namespaces are created or destroyed. In + * addition run the exit method for all existing network + * namespaces. 
+ */ +void unregister_pernet_subsys(struct pernet_operations *ops) +{ + down_write(&pernet_ops_rwsem); + unregister_pernet_operations(ops); + up_write(&pernet_ops_rwsem); +} +EXPORT_SYMBOL_GPL(unregister_pernet_subsys); + +/** + * register_pernet_device - register a network namespace device + * @ops: pernet operations structure for the subsystem + * + * Register a device which has init and exit functions + * that are called when network namespaces are created and + * destroyed respectively. + * + * When registered all network namespace init functions are + * called for every existing network namespace. Allowing kernel + * modules to have a race free view of the set of network namespaces. + * + * When a new network namespace is created all of the init + * methods are called in the order in which they were registered. + * + * When a network namespace is destroyed all of the exit methods + * are called in the reverse of the order with which they were + * registered. + */ +int register_pernet_device(struct pernet_operations *ops) +{ + int error; + down_write(&pernet_ops_rwsem); + error = register_pernet_operations(&pernet_list, ops); + if (!error && (first_device == &pernet_list)) + first_device = &ops->list; + up_write(&pernet_ops_rwsem); + return error; +} +EXPORT_SYMBOL_GPL(register_pernet_device); + +/** + * unregister_pernet_device - unregister a network namespace netdevice + * @ops: pernet operations structure to manipulate + * + * Remove the pernet operations structure from the list to be + * used when network namespaces are created or destroyed. In + * addition run the exit method for all existing network + * namespaces. + */ +void unregister_pernet_device(struct pernet_operations *ops) +{ + down_write(&pernet_ops_rwsem); + if (&ops->list == first_device) + first_device = first_device->next; + unregister_pernet_operations(ops); + up_write(&pernet_ops_rwsem); +} +EXPORT_SYMBOL_GPL(unregister_pernet_device); + +#ifdef CONFIG_NET_NS +static struct ns_common *netns_get(struct task_struct *task) +{ + struct net *net = NULL; + struct nsproxy *nsproxy; + + task_lock(task); + nsproxy = task->nsproxy; + if (nsproxy) + net = get_net(nsproxy->net_ns); + task_unlock(task); + + return net ? 
&net->ns : NULL; +} + +static inline struct net *to_net_ns(struct ns_common *ns) +{ + return container_of(ns, struct net, ns); +} + +static void netns_put(struct ns_common *ns) +{ + put_net(to_net_ns(ns)); +} + +static int netns_install(struct nsset *nsset, struct ns_common *ns) +{ + struct nsproxy *nsproxy = nsset->nsproxy; + struct net *net = to_net_ns(ns); + + if (!ns_capable(net->user_ns, CAP_SYS_ADMIN) || + !ns_capable(nsset->cred->user_ns, CAP_SYS_ADMIN)) + return -EPERM; + + put_net(nsproxy->net_ns); + nsproxy->net_ns = get_net(net); + return 0; +} + +static struct user_namespace *netns_owner(struct ns_common *ns) +{ + return to_net_ns(ns)->user_ns; +} + +const struct proc_ns_operations netns_operations = { + .name = "net", + .type = CLONE_NEWNET, + .get = netns_get, + .put = netns_put, + .install = netns_install, + .owner = netns_owner, +}; +#endif diff --git a/net/core/netclassid_cgroup.c b/net/core/netclassid_cgroup.c new file mode 100644 index 000000000..d6a70aeaa --- /dev/null +++ b/net/core/netclassid_cgroup.c @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/netclassid_cgroup.c Classid Cgroupfs Handling + * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include + +#include +#include + +static inline struct cgroup_cls_state *css_cls_state(struct cgroup_subsys_state *css) +{ + return css ? container_of(css, struct cgroup_cls_state, css) : NULL; +} + +struct cgroup_cls_state *task_cls_state(struct task_struct *p) +{ + return css_cls_state(task_css_check(p, net_cls_cgrp_id, + rcu_read_lock_bh_held())); +} +EXPORT_SYMBOL_GPL(task_cls_state); + +static struct cgroup_subsys_state * +cgrp_css_alloc(struct cgroup_subsys_state *parent_css) +{ + struct cgroup_cls_state *cs; + + cs = kzalloc(sizeof(*cs), GFP_KERNEL); + if (!cs) + return ERR_PTR(-ENOMEM); + + return &cs->css; +} + +static int cgrp_css_online(struct cgroup_subsys_state *css) +{ + struct cgroup_cls_state *cs = css_cls_state(css); + struct cgroup_cls_state *parent = css_cls_state(css->parent); + + if (parent) + cs->classid = parent->classid; + + return 0; +} + +static void cgrp_css_free(struct cgroup_subsys_state *css) +{ + kfree(css_cls_state(css)); +} + +/* + * To avoid freezing of sockets creation for tasks with big number of threads + * and opened sockets lets release file_lock every 1000 iterated descriptors. + * New sockets will already have been created with new classid. 
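+ * update_classid_sock() returns the next descriptor index once a batch is
+ * exhausted, which stops iterate_fd(); update_classid_task() then drops
+ * task_lock(), calls cond_resched() and resumes from that descriptor.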
+ */ + +struct update_classid_context { + u32 classid; + unsigned int batch; +}; + +#define UPDATE_CLASSID_BATCH 1000 + +static int update_classid_sock(const void *v, struct file *file, unsigned int n) +{ + struct update_classid_context *ctx = (void *)v; + struct socket *sock = sock_from_file(file); + + if (sock) + sock_cgroup_set_classid(&sock->sk->sk_cgrp_data, ctx->classid); + if (--ctx->batch == 0) { + ctx->batch = UPDATE_CLASSID_BATCH; + return n + 1; + } + return 0; +} + +static void update_classid_task(struct task_struct *p, u32 classid) +{ + struct update_classid_context ctx = { + .classid = classid, + .batch = UPDATE_CLASSID_BATCH + }; + unsigned int fd = 0; + + do { + task_lock(p); + fd = iterate_fd(p->files, fd, update_classid_sock, &ctx); + task_unlock(p); + cond_resched(); + } while (fd); +} + +static void cgrp_attach(struct cgroup_taskset *tset) +{ + struct cgroup_subsys_state *css; + struct task_struct *p; + + cgroup_taskset_for_each(p, css, tset) { + update_classid_task(p, css_cls_state(css)->classid); + } +} + +static u64 read_classid(struct cgroup_subsys_state *css, struct cftype *cft) +{ + return css_cls_state(css)->classid; +} + +static int write_classid(struct cgroup_subsys_state *css, struct cftype *cft, + u64 value) +{ + struct cgroup_cls_state *cs = css_cls_state(css); + struct css_task_iter it; + struct task_struct *p; + + cs->classid = (u32)value; + + css_task_iter_start(css, 0, &it); + while ((p = css_task_iter_next(&it))) + update_classid_task(p, cs->classid); + css_task_iter_end(&it); + + return 0; +} + +static struct cftype ss_files[] = { + { + .name = "classid", + .read_u64 = read_classid, + .write_u64 = write_classid, + }, + { } /* terminate */ +}; + +struct cgroup_subsys net_cls_cgrp_subsys = { + .css_alloc = cgrp_css_alloc, + .css_online = cgrp_css_online, + .css_free = cgrp_css_free, + .attach = cgrp_attach, + .legacy_cftypes = ss_files, +}; diff --git a/net/core/netevent.c b/net/core/netevent.c new file mode 100644 index 000000000..5bb615e96 --- /dev/null +++ b/net/core/netevent.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Network event notifiers + * + * Authors: + * Tom Tucker + * Steve Wise + * + * Fixes: + */ + +#include +#include +#include +#include + +static ATOMIC_NOTIFIER_HEAD(netevent_notif_chain); + +/** + * register_netevent_notifier - register a netevent notifier block + * @nb: notifier + * + * Register a notifier to be called when a netevent occurs. + * The notifier passed is linked into the kernel structures and must + * not be reused until it has been unregistered. A negative errno code + * is returned on a failure. + */ +int register_netevent_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_register(&netevent_notif_chain, nb); +} +EXPORT_SYMBOL_GPL(register_netevent_notifier); + +/** + * unregister_netevent_notifier - unregister a netevent notifier block + * @nb: notifier + * + * Unregister a notifier previously registered by + * register_neigh_notifier(). The notifier is unlinked into the + * kernel structures and may then be reused. A negative errno code + * is returned on a failure. 
+ */ + +int unregister_netevent_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_unregister(&netevent_notif_chain, nb); +} +EXPORT_SYMBOL_GPL(unregister_netevent_notifier); + +/** + * call_netevent_notifiers - call all netevent notifier blocks + * @val: value passed unmodified to notifier function + * @v: pointer passed unmodified to notifier function + * + * Call all neighbour notifier blocks. Parameters and return value + * are as for notifier_call_chain(). + */ + +int call_netevent_notifiers(unsigned long val, void *v) +{ + return atomic_notifier_call_chain(&netevent_notif_chain, val, v); +} +EXPORT_SYMBOL_GPL(call_netevent_notifiers); diff --git a/net/core/netpoll.c b/net/core/netpoll.c new file mode 100644 index 000000000..4ac8d0ad9 --- /dev/null +++ b/net/core/netpoll.c @@ -0,0 +1,878 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Common framework for low-level network console, dump, and debugger code + * + * Sep 8 2003 Matt Mackall + * + * based on the netconsole code from: + * + * Copyright (C) 2001 Ingo Molnar + * Copyright (C) 2002 Red Hat, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * We maintain a small pool of fully-sized skbs, to make sure the + * message gets out even in extreme OOM situations. + */ + +#define MAX_UDP_CHUNK 1460 +#define MAX_SKBS 32 + +static struct sk_buff_head skb_pool; + +DEFINE_STATIC_SRCU(netpoll_srcu); + +#define USEC_PER_POLL 50 + +#define MAX_SKB_SIZE \ + (sizeof(struct ethhdr) + \ + sizeof(struct iphdr) + \ + sizeof(struct udphdr) + \ + MAX_UDP_CHUNK) + +static void zap_completion_queue(void); + +static unsigned int carrier_timeout = 4; +module_param(carrier_timeout, uint, 0644); + +#define np_info(np, fmt, ...) \ + pr_info("%s: " fmt, np->name, ##__VA_ARGS__) +#define np_err(np, fmt, ...) \ + pr_err("%s: " fmt, np->name, ##__VA_ARGS__) +#define np_notice(np, fmt, ...) \ + pr_notice("%s: " fmt, np->name, ##__VA_ARGS__) + +static netdev_tx_t netpoll_start_xmit(struct sk_buff *skb, + struct net_device *dev, + struct netdev_queue *txq) +{ + netdev_tx_t status = NETDEV_TX_OK; + netdev_features_t features; + + features = netif_skb_features(skb); + + if (skb_vlan_tag_present(skb) && + !vlan_hw_offload_capable(features, skb->vlan_proto)) { + skb = __vlan_hwaccel_push_inside(skb); + if (unlikely(!skb)) { + /* This is actually a packet drop, but we + * don't want the code that calls this + * function to try and operate on a NULL skb. 
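+			 * status is still NETDEV_TX_OK at this point, so the
+			 * caller treats the transmit as completed and will not
+			 * requeue the freed skb.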
+ */ + goto out; + } + } + + status = netdev_start_xmit(skb, dev, txq, false); + +out: + return status; +} + +static void queue_process(struct work_struct *work) +{ + struct netpoll_info *npinfo = + container_of(work, struct netpoll_info, tx_work.work); + struct sk_buff *skb; + unsigned long flags; + + while ((skb = skb_dequeue(&npinfo->txq))) { + struct net_device *dev = skb->dev; + struct netdev_queue *txq; + unsigned int q_index; + + if (!netif_device_present(dev) || !netif_running(dev)) { + kfree_skb(skb); + continue; + } + + local_irq_save(flags); + /* check if skb->queue_mapping is still valid */ + q_index = skb_get_queue_mapping(skb); + if (unlikely(q_index >= dev->real_num_tx_queues)) { + q_index = q_index % dev->real_num_tx_queues; + skb_set_queue_mapping(skb, q_index); + } + txq = netdev_get_tx_queue(dev, q_index); + HARD_TX_LOCK(dev, txq, smp_processor_id()); + if (netif_xmit_frozen_or_stopped(txq) || + !dev_xmit_complete(netpoll_start_xmit(skb, dev, txq))) { + skb_queue_head(&npinfo->txq, skb); + HARD_TX_UNLOCK(dev, txq); + local_irq_restore(flags); + + schedule_delayed_work(&npinfo->tx_work, HZ/10); + return; + } + HARD_TX_UNLOCK(dev, txq); + local_irq_restore(flags); + } +} + +static int netif_local_xmit_active(struct net_device *dev) +{ + int i; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *txq = netdev_get_tx_queue(dev, i); + + if (READ_ONCE(txq->xmit_lock_owner) == smp_processor_id()) + return 1; + } + + return 0; +} + +static void poll_one_napi(struct napi_struct *napi) +{ + int work; + + /* If we set this bit but see that it has already been set, + * that indicates that napi has been disabled and we need + * to abort this operation + */ + if (test_and_set_bit(NAPI_STATE_NPSVC, &napi->state)) + return; + + /* We explicilty pass the polling call a budget of 0 to + * indicate that we are clearing the Tx path only. + */ + work = napi->poll(napi, 0); + WARN_ONCE(work, "%pS exceeded budget in poll\n", napi->poll); + trace_napi_poll(napi, work, 0); + + clear_bit(NAPI_STATE_NPSVC, &napi->state); +} + +static void poll_napi(struct net_device *dev) +{ + struct napi_struct *napi; + int cpu = smp_processor_id(); + + list_for_each_entry_rcu(napi, &dev->napi_list, dev_list) { + if (cmpxchg(&napi->poll_owner, -1, cpu) == -1) { + poll_one_napi(napi); + smp_store_release(&napi->poll_owner, -1); + } + } +} + +void netpoll_poll_dev(struct net_device *dev) +{ + struct netpoll_info *ni = rcu_dereference_bh(dev->npinfo); + const struct net_device_ops *ops; + + /* Don't do any rx activity if the dev_lock mutex is held + * the dev_open/close paths use this to block netpoll activity + * while changing device state + */ + if (!ni || down_trylock(&ni->dev_lock)) + return; + + /* Some drivers will take the same locks in poll and xmit, + * we can't poll if local CPU is already in xmit. 
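+	 * netif_local_xmit_active() reports whether this CPU currently owns
+	 * any of the device's Tx queue xmit locks.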
+ */ + if (!netif_running(dev) || netif_local_xmit_active(dev)) { + up(&ni->dev_lock); + return; + } + + ops = dev->netdev_ops; + if (ops->ndo_poll_controller) + ops->ndo_poll_controller(dev); + + poll_napi(dev); + + up(&ni->dev_lock); + + zap_completion_queue(); +} +EXPORT_SYMBOL(netpoll_poll_dev); + +void netpoll_poll_disable(struct net_device *dev) +{ + struct netpoll_info *ni; + int idx; + might_sleep(); + idx = srcu_read_lock(&netpoll_srcu); + ni = srcu_dereference(dev->npinfo, &netpoll_srcu); + if (ni) + down(&ni->dev_lock); + srcu_read_unlock(&netpoll_srcu, idx); +} +EXPORT_SYMBOL(netpoll_poll_disable); + +void netpoll_poll_enable(struct net_device *dev) +{ + struct netpoll_info *ni; + rcu_read_lock(); + ni = rcu_dereference(dev->npinfo); + if (ni) + up(&ni->dev_lock); + rcu_read_unlock(); +} +EXPORT_SYMBOL(netpoll_poll_enable); + +static void refill_skbs(void) +{ + struct sk_buff *skb; + unsigned long flags; + + spin_lock_irqsave(&skb_pool.lock, flags); + while (skb_pool.qlen < MAX_SKBS) { + skb = alloc_skb(MAX_SKB_SIZE, GFP_ATOMIC); + if (!skb) + break; + + __skb_queue_tail(&skb_pool, skb); + } + spin_unlock_irqrestore(&skb_pool.lock, flags); +} + +static void zap_completion_queue(void) +{ + unsigned long flags; + struct softnet_data *sd = &get_cpu_var(softnet_data); + + if (sd->completion_queue) { + struct sk_buff *clist; + + local_irq_save(flags); + clist = sd->completion_queue; + sd->completion_queue = NULL; + local_irq_restore(flags); + + while (clist != NULL) { + struct sk_buff *skb = clist; + clist = clist->next; + if (!skb_irq_freeable(skb)) { + refcount_set(&skb->users, 1); + dev_kfree_skb_any(skb); /* put this one back */ + } else { + __kfree_skb(skb); + } + } + } + + put_cpu_var(softnet_data); +} + +static struct sk_buff *find_skb(struct netpoll *np, int len, int reserve) +{ + int count = 0; + struct sk_buff *skb; + + zap_completion_queue(); + refill_skbs(); +repeat: + + skb = alloc_skb(len, GFP_ATOMIC); + if (!skb) + skb = skb_dequeue(&skb_pool); + + if (!skb) { + if (++count < 10) { + netpoll_poll_dev(np->dev); + goto repeat; + } + return NULL; + } + + refcount_set(&skb->users, 1); + skb_reserve(skb, reserve); + return skb; +} + +static int netpoll_owner_active(struct net_device *dev) +{ + struct napi_struct *napi; + + list_for_each_entry_rcu(napi, &dev->napi_list, dev_list) { + if (napi->poll_owner == smp_processor_id()) + return 1; + } + return 0; +} + +/* call with IRQ disabled */ +static netdev_tx_t __netpoll_send_skb(struct netpoll *np, struct sk_buff *skb) +{ + netdev_tx_t status = NETDEV_TX_BUSY; + struct net_device *dev; + unsigned long tries; + /* It is up to the caller to keep npinfo alive. 
*/ + struct netpoll_info *npinfo; + + lockdep_assert_irqs_disabled(); + + dev = np->dev; + npinfo = rcu_dereference_bh(dev->npinfo); + + if (!npinfo || !netif_running(dev) || !netif_device_present(dev)) { + dev_kfree_skb_irq(skb); + return NET_XMIT_DROP; + } + + /* don't get messages out of order, and no recursion */ + if (skb_queue_len(&npinfo->txq) == 0 && !netpoll_owner_active(dev)) { + struct netdev_queue *txq; + + txq = netdev_core_pick_tx(dev, skb, NULL); + + /* try until next clock tick */ + for (tries = jiffies_to_usecs(1)/USEC_PER_POLL; + tries > 0; --tries) { + if (HARD_TX_TRYLOCK(dev, txq)) { + if (!netif_xmit_stopped(txq)) + status = netpoll_start_xmit(skb, dev, txq); + + HARD_TX_UNLOCK(dev, txq); + + if (dev_xmit_complete(status)) + break; + + } + + /* tickle device maybe there is some cleanup */ + netpoll_poll_dev(np->dev); + + udelay(USEC_PER_POLL); + } + + WARN_ONCE(!irqs_disabled(), + "netpoll_send_skb_on_dev(): %s enabled interrupts in poll (%pS)\n", + dev->name, dev->netdev_ops->ndo_start_xmit); + + } + + if (!dev_xmit_complete(status)) { + skb_queue_tail(&npinfo->txq, skb); + schedule_delayed_work(&npinfo->tx_work,0); + } + return NETDEV_TX_OK; +} + +netdev_tx_t netpoll_send_skb(struct netpoll *np, struct sk_buff *skb) +{ + unsigned long flags; + netdev_tx_t ret; + + if (unlikely(!np)) { + dev_kfree_skb_irq(skb); + ret = NET_XMIT_DROP; + } else { + local_irq_save(flags); + ret = __netpoll_send_skb(np, skb); + local_irq_restore(flags); + } + return ret; +} +EXPORT_SYMBOL(netpoll_send_skb); + +void netpoll_send_udp(struct netpoll *np, const char *msg, int len) +{ + int total_len, ip_len, udp_len; + struct sk_buff *skb; + struct udphdr *udph; + struct iphdr *iph; + struct ethhdr *eth; + static atomic_t ip_ident; + struct ipv6hdr *ip6h; + + if (!IS_ENABLED(CONFIG_PREEMPT_RT)) + WARN_ON_ONCE(!irqs_disabled()); + + udp_len = len + sizeof(*udph); + if (np->ipv6) + ip_len = udp_len + sizeof(*ip6h); + else + ip_len = udp_len + sizeof(*iph); + + total_len = ip_len + LL_RESERVED_SPACE(np->dev); + + skb = find_skb(np, total_len + np->dev->needed_tailroom, + total_len - len); + if (!skb) + return; + + skb_copy_to_linear_data(skb, msg, len); + skb_put(skb, len); + + skb_push(skb, sizeof(*udph)); + skb_reset_transport_header(skb); + udph = udp_hdr(skb); + udph->source = htons(np->local_port); + udph->dest = htons(np->remote_port); + udph->len = htons(udp_len); + + if (np->ipv6) { + udph->check = 0; + udph->check = csum_ipv6_magic(&np->local_ip.in6, + &np->remote_ip.in6, + udp_len, IPPROTO_UDP, + csum_partial(udph, udp_len, 0)); + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + + skb_push(skb, sizeof(*ip6h)); + skb_reset_network_header(skb); + ip6h = ipv6_hdr(skb); + + /* ip6h->version = 6; ip6h->priority = 0; */ + *(unsigned char *)ip6h = 0x60; + ip6h->flow_lbl[0] = 0; + ip6h->flow_lbl[1] = 0; + ip6h->flow_lbl[2] = 0; + + ip6h->payload_len = htons(sizeof(struct udphdr) + len); + ip6h->nexthdr = IPPROTO_UDP; + ip6h->hop_limit = 32; + ip6h->saddr = np->local_ip.in6; + ip6h->daddr = np->remote_ip.in6; + + eth = skb_push(skb, ETH_HLEN); + skb_reset_mac_header(skb); + skb->protocol = eth->h_proto = htons(ETH_P_IPV6); + } else { + udph->check = 0; + udph->check = csum_tcpudp_magic(np->local_ip.ip, + np->remote_ip.ip, + udp_len, IPPROTO_UDP, + csum_partial(udph, udp_len, 0)); + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + + skb_push(skb, sizeof(*iph)); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + + /* iph->version = 4; iph->ihl = 5; */ + *(unsigned char *)iph = 
0x45; + iph->tos = 0; + put_unaligned(htons(ip_len), &(iph->tot_len)); + iph->id = htons(atomic_inc_return(&ip_ident)); + iph->frag_off = 0; + iph->ttl = 64; + iph->protocol = IPPROTO_UDP; + iph->check = 0; + put_unaligned(np->local_ip.ip, &(iph->saddr)); + put_unaligned(np->remote_ip.ip, &(iph->daddr)); + iph->check = ip_fast_csum((unsigned char *)iph, iph->ihl); + + eth = skb_push(skb, ETH_HLEN); + skb_reset_mac_header(skb); + skb->protocol = eth->h_proto = htons(ETH_P_IP); + } + + ether_addr_copy(eth->h_source, np->dev->dev_addr); + ether_addr_copy(eth->h_dest, np->remote_mac); + + skb->dev = np->dev; + + netpoll_send_skb(np, skb); +} +EXPORT_SYMBOL(netpoll_send_udp); + +void netpoll_print_options(struct netpoll *np) +{ + np_info(np, "local port %d\n", np->local_port); + if (np->ipv6) + np_info(np, "local IPv6 address %pI6c\n", &np->local_ip.in6); + else + np_info(np, "local IPv4 address %pI4\n", &np->local_ip.ip); + np_info(np, "interface '%s'\n", np->dev_name); + np_info(np, "remote port %d\n", np->remote_port); + if (np->ipv6) + np_info(np, "remote IPv6 address %pI6c\n", &np->remote_ip.in6); + else + np_info(np, "remote IPv4 address %pI4\n", &np->remote_ip.ip); + np_info(np, "remote ethernet address %pM\n", np->remote_mac); +} +EXPORT_SYMBOL(netpoll_print_options); + +static int netpoll_parse_ip_addr(const char *str, union inet_addr *addr) +{ + const char *end; + + if (!strchr(str, ':') && + in4_pton(str, -1, (void *)addr, -1, &end) > 0) { + if (!*end) + return 0; + } + if (in6_pton(str, -1, addr->in6.s6_addr, -1, &end) > 0) { +#if IS_ENABLED(CONFIG_IPV6) + if (!*end) + return 1; +#else + return -1; +#endif + } + return -1; +} + +int netpoll_parse_options(struct netpoll *np, char *opt) +{ + char *cur=opt, *delim; + int ipv6; + bool ipversion_set = false; + + if (*cur != '@') { + if ((delim = strchr(cur, '@')) == NULL) + goto parse_failed; + *delim = 0; + if (kstrtou16(cur, 10, &np->local_port)) + goto parse_failed; + cur = delim; + } + cur++; + + if (*cur != '/') { + ipversion_set = true; + if ((delim = strchr(cur, '/')) == NULL) + goto parse_failed; + *delim = 0; + ipv6 = netpoll_parse_ip_addr(cur, &np->local_ip); + if (ipv6 < 0) + goto parse_failed; + else + np->ipv6 = (bool)ipv6; + cur = delim; + } + cur++; + + if (*cur != ',') { + /* parse out dev name */ + if ((delim = strchr(cur, ',')) == NULL) + goto parse_failed; + *delim = 0; + strscpy(np->dev_name, cur, sizeof(np->dev_name)); + cur = delim; + } + cur++; + + if (*cur != '@') { + /* dst port */ + if ((delim = strchr(cur, '@')) == NULL) + goto parse_failed; + *delim = 0; + if (*cur == ' ' || *cur == '\t') + np_info(np, "warning: whitespace is not allowed\n"); + if (kstrtou16(cur, 10, &np->remote_port)) + goto parse_failed; + cur = delim; + } + cur++; + + /* dst ip */ + if ((delim = strchr(cur, '/')) == NULL) + goto parse_failed; + *delim = 0; + ipv6 = netpoll_parse_ip_addr(cur, &np->remote_ip); + if (ipv6 < 0) + goto parse_failed; + else if (ipversion_set && np->ipv6 != (bool)ipv6) + goto parse_failed; + else + np->ipv6 = (bool)ipv6; + cur = delim + 1; + + if (*cur != 0) { + /* MAC address */ + if (!mac_pton(cur, np->remote_mac)) + goto parse_failed; + } + + netpoll_print_options(np); + + return 0; + + parse_failed: + np_info(np, "couldn't parse config at '%s'!\n", cur); + return -1; +} +EXPORT_SYMBOL(netpoll_parse_options); + +int __netpoll_setup(struct netpoll *np, struct net_device *ndev) +{ + struct netpoll_info *npinfo; + const struct net_device_ops *ops; + int err; + + np->dev = ndev; + strscpy(np->dev_name, ndev->name, 
IFNAMSIZ); + + if (ndev->priv_flags & IFF_DISABLE_NETPOLL) { + np_err(np, "%s doesn't support polling, aborting\n", + np->dev_name); + err = -ENOTSUPP; + goto out; + } + + if (!ndev->npinfo) { + npinfo = kmalloc(sizeof(*npinfo), GFP_KERNEL); + if (!npinfo) { + err = -ENOMEM; + goto out; + } + + sema_init(&npinfo->dev_lock, 1); + skb_queue_head_init(&npinfo->txq); + INIT_DELAYED_WORK(&npinfo->tx_work, queue_process); + + refcount_set(&npinfo->refcnt, 1); + + ops = np->dev->netdev_ops; + if (ops->ndo_netpoll_setup) { + err = ops->ndo_netpoll_setup(ndev, npinfo); + if (err) + goto free_npinfo; + } + } else { + npinfo = rtnl_dereference(ndev->npinfo); + refcount_inc(&npinfo->refcnt); + } + + npinfo->netpoll = np; + + /* last thing to do is link it to the net device structure */ + rcu_assign_pointer(ndev->npinfo, npinfo); + + return 0; + +free_npinfo: + kfree(npinfo); +out: + return err; +} +EXPORT_SYMBOL_GPL(__netpoll_setup); + +int netpoll_setup(struct netpoll *np) +{ + struct net_device *ndev = NULL; + struct in_device *in_dev; + int err; + + rtnl_lock(); + if (np->dev_name[0]) { + struct net *net = current->nsproxy->net_ns; + ndev = __dev_get_by_name(net, np->dev_name); + } + if (!ndev) { + np_err(np, "%s doesn't exist, aborting\n", np->dev_name); + err = -ENODEV; + goto unlock; + } + dev_hold(ndev); + + if (netdev_master_upper_dev_get(ndev)) { + np_err(np, "%s is a slave device, aborting\n", np->dev_name); + err = -EBUSY; + goto put; + } + + if (!netif_running(ndev)) { + unsigned long atmost, atleast; + + np_info(np, "device %s not up yet, forcing it\n", np->dev_name); + + err = dev_open(ndev, NULL); + + if (err) { + np_err(np, "failed to open %s\n", ndev->name); + goto put; + } + + rtnl_unlock(); + atleast = jiffies + HZ/10; + atmost = jiffies + carrier_timeout * HZ; + while (!netif_carrier_ok(ndev)) { + if (time_after(jiffies, atmost)) { + np_notice(np, "timeout waiting for carrier\n"); + break; + } + msleep(1); + } + + /* If carrier appears to come up instantly, we don't + * trust it and pause so that we don't pump all our + * queued console messages into the bitbucket. 
+ */ + + if (time_before(jiffies, atleast)) { + np_notice(np, "carrier detect appears untrustworthy, waiting 4 seconds\n"); + msleep(4000); + } + rtnl_lock(); + } + + if (!np->local_ip.ip) { + if (!np->ipv6) { + const struct in_ifaddr *ifa; + + in_dev = __in_dev_get_rtnl(ndev); + if (!in_dev) + goto put_noaddr; + + ifa = rtnl_dereference(in_dev->ifa_list); + if (!ifa) { +put_noaddr: + np_err(np, "no IP address for %s, aborting\n", + np->dev_name); + err = -EDESTADDRREQ; + goto put; + } + + np->local_ip.ip = ifa->ifa_local; + np_info(np, "local IP %pI4\n", &np->local_ip.ip); + } else { +#if IS_ENABLED(CONFIG_IPV6) + struct inet6_dev *idev; + + err = -EDESTADDRREQ; + idev = __in6_dev_get(ndev); + if (idev) { + struct inet6_ifaddr *ifp; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifp, &idev->addr_list, if_list) { + if (!!(ipv6_addr_type(&ifp->addr) & IPV6_ADDR_LINKLOCAL) != + !!(ipv6_addr_type(&np->remote_ip.in6) & IPV6_ADDR_LINKLOCAL)) + continue; + np->local_ip.in6 = ifp->addr; + err = 0; + break; + } + read_unlock_bh(&idev->lock); + } + if (err) { + np_err(np, "no IPv6 address for %s, aborting\n", + np->dev_name); + goto put; + } else + np_info(np, "local IPv6 %pI6c\n", &np->local_ip.in6); +#else + np_err(np, "IPv6 is not supported %s, aborting\n", + np->dev_name); + err = -EINVAL; + goto put; +#endif + } + } + + /* fill up the skb queue */ + refill_skbs(); + + err = __netpoll_setup(np, ndev); + if (err) + goto put; + netdev_tracker_alloc(ndev, &np->dev_tracker, GFP_KERNEL); + rtnl_unlock(); + return 0; + +put: + dev_put(ndev); +unlock: + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(netpoll_setup); + +static int __init netpoll_init(void) +{ + skb_queue_head_init(&skb_pool); + return 0; +} +core_initcall(netpoll_init); + +static void rcu_cleanup_netpoll_info(struct rcu_head *rcu_head) +{ + struct netpoll_info *npinfo = + container_of(rcu_head, struct netpoll_info, rcu); + + skb_queue_purge(&npinfo->txq); + + /* we can't call cancel_delayed_work_sync here, as we are in softirq */ + cancel_delayed_work(&npinfo->tx_work); + + /* clean after last, unfinished work */ + __skb_queue_purge(&npinfo->txq); + /* now cancel it again */ + cancel_delayed_work(&npinfo->tx_work); + kfree(npinfo); +} + +void __netpoll_cleanup(struct netpoll *np) +{ + struct netpoll_info *npinfo; + + npinfo = rtnl_dereference(np->dev->npinfo); + if (!npinfo) + return; + + synchronize_srcu(&netpoll_srcu); + + if (refcount_dec_and_test(&npinfo->refcnt)) { + const struct net_device_ops *ops; + + ops = np->dev->netdev_ops; + if (ops->ndo_netpoll_cleanup) + ops->ndo_netpoll_cleanup(np->dev); + + RCU_INIT_POINTER(np->dev->npinfo, NULL); + call_rcu(&npinfo->rcu, rcu_cleanup_netpoll_info); + } else + RCU_INIT_POINTER(np->dev->npinfo, NULL); +} +EXPORT_SYMBOL_GPL(__netpoll_cleanup); + +void __netpoll_free(struct netpoll *np) +{ + ASSERT_RTNL(); + + /* Wait for transmitting packets to finish before freeing. 
*/ + synchronize_rcu(); + __netpoll_cleanup(np); + kfree(np); +} +EXPORT_SYMBOL_GPL(__netpoll_free); + +void netpoll_cleanup(struct netpoll *np) +{ + rtnl_lock(); + if (!np->dev) + goto out; + __netpoll_cleanup(np); + netdev_put(np->dev, &np->dev_tracker); + np->dev = NULL; +out: + rtnl_unlock(); +} +EXPORT_SYMBOL(netpoll_cleanup); diff --git a/net/core/netprio_cgroup.c b/net/core/netprio_cgroup.c new file mode 100644 index 000000000..8456dfbe2 --- /dev/null +++ b/net/core/netprio_cgroup.c @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/netprio_cgroup.c Priority Control Group + * + * Authors: Neil Horman + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +/* + * netprio allocates per-net_device priomap array which is indexed by + * css->id. Limiting css ID to 16bits doesn't lose anything. + */ +#define NETPRIO_ID_MAX USHRT_MAX + +#define PRIOMAP_MIN_SZ 128 + +/* + * Extend @dev->priomap so that it's large enough to accommodate + * @target_idx. @dev->priomap.priomap_len > @target_idx after successful + * return. Must be called under rtnl lock. + */ +static int extend_netdev_table(struct net_device *dev, u32 target_idx) +{ + struct netprio_map *old, *new; + size_t new_sz, new_len; + + /* is the existing priomap large enough? */ + old = rtnl_dereference(dev->priomap); + if (old && old->priomap_len > target_idx) + return 0; + + /* + * Determine the new size. Let's keep it power-of-two. We start + * from PRIOMAP_MIN_SZ and double it until it's large enough to + * accommodate @target_idx. + */ + new_sz = PRIOMAP_MIN_SZ; + while (true) { + new_len = (new_sz - offsetof(struct netprio_map, priomap)) / + sizeof(new->priomap[0]); + if (new_len > target_idx) + break; + new_sz *= 2; + /* overflowed? */ + if (WARN_ON(new_sz < PRIOMAP_MIN_SZ)) + return -ENOSPC; + } + + /* allocate & copy */ + new = kzalloc(new_sz, GFP_KERNEL); + if (!new) + return -ENOMEM; + + if (old) + memcpy(new->priomap, old->priomap, + old->priomap_len * sizeof(old->priomap[0])); + + new->priomap_len = new_len; + + /* install the new priomap */ + rcu_assign_pointer(dev->priomap, new); + if (old) + kfree_rcu(old, rcu); + return 0; +} + +/** + * netprio_prio - return the effective netprio of a cgroup-net_device pair + * @css: css part of the target pair + * @dev: net_device part of the target pair + * + * Should be called under RCU read or rtnl lock. + */ +static u32 netprio_prio(struct cgroup_subsys_state *css, struct net_device *dev) +{ + struct netprio_map *map = rcu_dereference_rtnl(dev->priomap); + int id = css->id; + + if (map && id < map->priomap_len) + return map->priomap[id]; + return 0; +} + +/** + * netprio_set_prio - set netprio on a cgroup-net_device pair + * @css: css part of the target pair + * @dev: net_device part of the target pair + * @prio: prio to set + * + * Set netprio to @prio on @css-@dev pair. Should be called under rtnl + * lock and may fail under memory pressure for non-zero @prio. 
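+ *
+ * Return: 0 on success, or a negative errno (e.g. -ENOMEM) when the
+ * priomap could not be extended.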
+ */ +static int netprio_set_prio(struct cgroup_subsys_state *css, + struct net_device *dev, u32 prio) +{ + struct netprio_map *map; + int id = css->id; + int ret; + + /* avoid extending priomap for zero writes */ + map = rtnl_dereference(dev->priomap); + if (!prio && (!map || map->priomap_len <= id)) + return 0; + + ret = extend_netdev_table(dev, id); + if (ret) + return ret; + + map = rtnl_dereference(dev->priomap); + map->priomap[id] = prio; + return 0; +} + +static struct cgroup_subsys_state * +cgrp_css_alloc(struct cgroup_subsys_state *parent_css) +{ + struct cgroup_subsys_state *css; + + css = kzalloc(sizeof(*css), GFP_KERNEL); + if (!css) + return ERR_PTR(-ENOMEM); + + return css; +} + +static int cgrp_css_online(struct cgroup_subsys_state *css) +{ + struct cgroup_subsys_state *parent_css = css->parent; + struct net_device *dev; + int ret = 0; + + if (css->id > NETPRIO_ID_MAX) + return -ENOSPC; + + if (!parent_css) + return 0; + + rtnl_lock(); + /* + * Inherit prios from the parent. As all prios are set during + * onlining, there is no need to clear them on offline. + */ + for_each_netdev(&init_net, dev) { + u32 prio = netprio_prio(parent_css, dev); + + ret = netprio_set_prio(css, dev, prio); + if (ret) + break; + } + rtnl_unlock(); + return ret; +} + +static void cgrp_css_free(struct cgroup_subsys_state *css) +{ + kfree(css); +} + +static u64 read_prioidx(struct cgroup_subsys_state *css, struct cftype *cft) +{ + return css->id; +} + +static int read_priomap(struct seq_file *sf, void *v) +{ + struct net_device *dev; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) + seq_printf(sf, "%s %u\n", dev->name, + netprio_prio(seq_css(sf), dev)); + rcu_read_unlock(); + return 0; +} + +static ssize_t write_priomap(struct kernfs_open_file *of, + char *buf, size_t nbytes, loff_t off) +{ + char devname[IFNAMSIZ + 1]; + struct net_device *dev; + u32 prio; + int ret; + + if (sscanf(buf, "%"__stringify(IFNAMSIZ)"s %u", devname, &prio) != 2) + return -EINVAL; + + dev = dev_get_by_name(&init_net, devname); + if (!dev) + return -ENODEV; + + rtnl_lock(); + + ret = netprio_set_prio(of_css(of), dev, prio); + + rtnl_unlock(); + dev_put(dev); + return ret ?: nbytes; +} + +static int update_netprio(const void *v, struct file *file, unsigned n) +{ + struct socket *sock = sock_from_file(file); + + if (sock) + sock_cgroup_set_prioidx(&sock->sk->sk_cgrp_data, + (unsigned long)v); + return 0; +} + +static void net_prio_attach(struct cgroup_taskset *tset) +{ + struct task_struct *p; + struct cgroup_subsys_state *css; + + cgroup_taskset_for_each(p, css, tset) { + void *v = (void *)(unsigned long)css->id; + + task_lock(p); + iterate_fd(p->files, 0, update_netprio, v); + task_unlock(p); + } +} + +static struct cftype ss_files[] = { + { + .name = "prioidx", + .read_u64 = read_prioidx, + }, + { + .name = "ifpriomap", + .seq_show = read_priomap, + .write = write_priomap, + }, + { } /* terminate */ +}; + +struct cgroup_subsys net_prio_cgrp_subsys = { + .css_alloc = cgrp_css_alloc, + .css_online = cgrp_css_online, + .css_free = cgrp_css_free, + .attach = net_prio_attach, + .legacy_cftypes = ss_files, +}; + +static int netprio_device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netprio_map *old; + + /* + * Note this is called with rtnl_lock held so we have update side + * protection on our rcu assignments + */ + + switch (event) { + case NETDEV_UNREGISTER: + old = rtnl_dereference(dev->priomap); + 
RCU_INIT_POINTER(dev->priomap, NULL); + if (old) + kfree_rcu(old, rcu); + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block netprio_device_notifier = { + .notifier_call = netprio_device_event +}; + +static int __init init_cgroup_netprio(void) +{ + register_netdevice_notifier(&netprio_device_notifier); + return 0; +} +subsys_initcall(init_cgroup_netprio); diff --git a/net/core/of_net.c b/net/core/of_net.c new file mode 100644 index 000000000..f1a9bf757 --- /dev/null +++ b/net/core/of_net.c @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * OF helpers for network devices. + * + * Initially copied out of arch/powerpc/kernel/prom_parse.c + */ +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * of_get_phy_mode - Get phy mode for given device_node + * @np: Pointer to the given device_node + * @interface: Pointer to the result + * + * The function gets phy interface string from property 'phy-mode' or + * 'phy-connection-type'. The index in phy_modes table is set in + * interface and 0 returned. In case of error interface is set to + * PHY_INTERFACE_MODE_NA and an errno is returned, e.g. -ENODEV. + */ +int of_get_phy_mode(struct device_node *np, phy_interface_t *interface) +{ + const char *pm; + int err, i; + + *interface = PHY_INTERFACE_MODE_NA; + + err = of_property_read_string(np, "phy-mode", &pm); + if (err < 0) + err = of_property_read_string(np, "phy-connection-type", &pm); + if (err < 0) + return err; + + for (i = 0; i < PHY_INTERFACE_MODE_MAX; i++) + if (!strcasecmp(pm, phy_modes(i))) { + *interface = i; + return 0; + } + + return -ENODEV; +} +EXPORT_SYMBOL_GPL(of_get_phy_mode); + +static int of_get_mac_addr(struct device_node *np, const char *name, u8 *addr) +{ + struct property *pp = of_find_property(np, name, NULL); + + if (pp && pp->length == ETH_ALEN && is_valid_ether_addr(pp->value)) { + memcpy(addr, pp->value, ETH_ALEN); + return 0; + } + return -ENODEV; +} + +static int of_get_mac_addr_nvmem(struct device_node *np, u8 *addr) +{ + struct platform_device *pdev = of_find_device_by_node(np); + struct nvmem_cell *cell; + const void *mac; + size_t len; + int ret; + + /* Try lookup by device first, there might be a nvmem_cell_lookup + * associated with a given device. + */ + if (pdev) { + ret = nvmem_get_mac_address(&pdev->dev, addr); + put_device(&pdev->dev); + return ret; + } + + cell = of_nvmem_cell_get(np, "mac-address"); + if (IS_ERR(cell)) + return PTR_ERR(cell); + + mac = nvmem_cell_read(cell, &len); + nvmem_cell_put(cell); + + if (IS_ERR(mac)) + return PTR_ERR(mac); + + if (len != ETH_ALEN || !is_valid_ether_addr(mac)) { + kfree(mac); + return -EINVAL; + } + + memcpy(addr, mac, ETH_ALEN); + kfree(mac); + + return 0; +} + +/** + * of_get_mac_address() + * @np: Caller's Device Node + * @addr: Pointer to a six-byte array for the result + * + * Search the device tree for the best MAC address to use. 'mac-address' is + * checked first, because that is supposed to contain to "most recent" MAC + * address. If that isn't set, then 'local-mac-address' is checked next, + * because that is the default address. If that isn't set, then the obsolete + * 'address' is checked, just in case we're using an old device tree. If any + * of the above isn't set, then try to get MAC address from nvmem cell named + * 'mac-address'. + * + * Note that the 'address' property is supposed to contain a virtual address of + * the register set, but some DTS files have redefined that property to be the + * MAC address. 
+ * + * All-zero MAC addresses are rejected, because those could be properties that + * exist in the device tree, but were not set by U-Boot. For example, the + * DTS could define 'mac-address' and 'local-mac-address', with zero MAC + * addresses. Some older U-Boots only initialized 'local-mac-address'. In + * this case, the real MAC is in 'local-mac-address', and 'mac-address' exists + * but is all zeros. + * + * Return: 0 on success and errno in case of error. +*/ +int of_get_mac_address(struct device_node *np, u8 *addr) +{ + int ret; + + if (!np) + return -ENODEV; + + ret = of_get_mac_addr(np, "mac-address", addr); + if (!ret) + return 0; + + ret = of_get_mac_addr(np, "local-mac-address", addr); + if (!ret) + return 0; + + ret = of_get_mac_addr(np, "address", addr); + if (!ret) + return 0; + + return of_get_mac_addr_nvmem(np, addr); +} +EXPORT_SYMBOL(of_get_mac_address); + +/** + * of_get_ethdev_address() + * @np: Caller's Device Node + * @dev: Pointer to netdevice which address will be updated + * + * Search the device tree for the best MAC address to use. + * If found set @dev->dev_addr to that address. + * + * See documentation of of_get_mac_address() for more information on how + * the best address is determined. + * + * Return: 0 on success and errno in case of error. + */ +int of_get_ethdev_address(struct device_node *np, struct net_device *dev) +{ + u8 addr[ETH_ALEN]; + int ret; + + ret = of_get_mac_address(np, addr); + if (!ret) + eth_hw_addr_set(dev, addr); + return ret; +} +EXPORT_SYMBOL(of_get_ethdev_address); diff --git a/net/core/page_pool.c b/net/core/page_pool.c new file mode 100644 index 000000000..caf6d950d --- /dev/null +++ b/net/core/page_pool.c @@ -0,0 +1,932 @@ +/* SPDX-License-Identifier: GPL-2.0 + * + * page_pool.c + * Author: Jesper Dangaard Brouer + * Copyright (C) 2016 Red Hat, Inc. + */ + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include /* for put_page() */ +#include +#include + +#include + +#define DEFER_TIME (msecs_to_jiffies(1000)) +#define DEFER_WARN_INTERVAL (60 * HZ) + +#define BIAS_MAX LONG_MAX + +#ifdef CONFIG_PAGE_POOL_STATS +/* alloc_stat_inc is intended to be used in softirq context */ +#define alloc_stat_inc(pool, __stat) (pool->alloc_stats.__stat++) +/* recycle_stat_inc is safe to use when preemption is possible. */ +#define recycle_stat_inc(pool, __stat) \ + do { \ + struct page_pool_recycle_stats __percpu *s = pool->recycle_stats; \ + this_cpu_inc(s->__stat); \ + } while (0) + +#define recycle_stat_add(pool, __stat, val) \ + do { \ + struct page_pool_recycle_stats __percpu *s = pool->recycle_stats; \ + this_cpu_add(s->__stat, val); \ + } while (0) + +static const char pp_stats[][ETH_GSTRING_LEN] = { + "rx_pp_alloc_fast", + "rx_pp_alloc_slow", + "rx_pp_alloc_slow_ho", + "rx_pp_alloc_empty", + "rx_pp_alloc_refill", + "rx_pp_alloc_waive", + "rx_pp_recycle_cached", + "rx_pp_recycle_cache_full", + "rx_pp_recycle_ring", + "rx_pp_recycle_ring_full", + "rx_pp_recycle_released_ref", +}; + +bool page_pool_get_stats(struct page_pool *pool, + struct page_pool_stats *stats) +{ + int cpu = 0; + + if (!stats) + return false; + + /* The caller is responsible to initialize stats. 
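+	 * Allocation stats are kept per pool (they are intended to be updated
+	 * only from softirq context); recycle stats are per-CPU and are summed
+	 * over all possible CPUs below.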
*/ + stats->alloc_stats.fast += pool->alloc_stats.fast; + stats->alloc_stats.slow += pool->alloc_stats.slow; + stats->alloc_stats.slow_high_order += pool->alloc_stats.slow_high_order; + stats->alloc_stats.empty += pool->alloc_stats.empty; + stats->alloc_stats.refill += pool->alloc_stats.refill; + stats->alloc_stats.waive += pool->alloc_stats.waive; + + for_each_possible_cpu(cpu) { + const struct page_pool_recycle_stats *pcpu = + per_cpu_ptr(pool->recycle_stats, cpu); + + stats->recycle_stats.cached += pcpu->cached; + stats->recycle_stats.cache_full += pcpu->cache_full; + stats->recycle_stats.ring += pcpu->ring; + stats->recycle_stats.ring_full += pcpu->ring_full; + stats->recycle_stats.released_refcnt += pcpu->released_refcnt; + } + + return true; +} +EXPORT_SYMBOL(page_pool_get_stats); + +u8 *page_pool_ethtool_stats_get_strings(u8 *data) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(pp_stats); i++) { + memcpy(data, pp_stats[i], ETH_GSTRING_LEN); + data += ETH_GSTRING_LEN; + } + + return data; +} +EXPORT_SYMBOL(page_pool_ethtool_stats_get_strings); + +int page_pool_ethtool_stats_get_count(void) +{ + return ARRAY_SIZE(pp_stats); +} +EXPORT_SYMBOL(page_pool_ethtool_stats_get_count); + +u64 *page_pool_ethtool_stats_get(u64 *data, void *stats) +{ + struct page_pool_stats *pool_stats = stats; + + *data++ = pool_stats->alloc_stats.fast; + *data++ = pool_stats->alloc_stats.slow; + *data++ = pool_stats->alloc_stats.slow_high_order; + *data++ = pool_stats->alloc_stats.empty; + *data++ = pool_stats->alloc_stats.refill; + *data++ = pool_stats->alloc_stats.waive; + *data++ = pool_stats->recycle_stats.cached; + *data++ = pool_stats->recycle_stats.cache_full; + *data++ = pool_stats->recycle_stats.ring; + *data++ = pool_stats->recycle_stats.ring_full; + *data++ = pool_stats->recycle_stats.released_refcnt; + + return data; +} +EXPORT_SYMBOL(page_pool_ethtool_stats_get); + +#else +#define alloc_stat_inc(pool, __stat) +#define recycle_stat_inc(pool, __stat) +#define recycle_stat_add(pool, __stat, val) +#endif + +static bool page_pool_producer_lock(struct page_pool *pool) + __acquires(&pool->ring.producer_lock) +{ + bool in_softirq = in_softirq(); + + if (in_softirq) + spin_lock(&pool->ring.producer_lock); + else + spin_lock_bh(&pool->ring.producer_lock); + + return in_softirq; +} + +static void page_pool_producer_unlock(struct page_pool *pool, + bool in_softirq) + __releases(&pool->ring.producer_lock) +{ + if (in_softirq) + spin_unlock(&pool->ring.producer_lock); + else + spin_unlock_bh(&pool->ring.producer_lock); +} + +static int page_pool_init(struct page_pool *pool, + const struct page_pool_params *params) +{ + unsigned int ring_qsize = 1024; /* Default */ + + memcpy(&pool->p, params, sizeof(pool->p)); + + /* Validate only known flags were used */ + if (pool->p.flags & ~(PP_FLAG_ALL)) + return -EINVAL; + + if (pool->p.pool_size) + ring_qsize = pool->p.pool_size; + + /* Sanity limit mem that can be pinned down */ + if (ring_qsize > 32768) + return -E2BIG; + + /* DMA direction is either DMA_FROM_DEVICE or DMA_BIDIRECTIONAL. + * DMA_BIDIRECTIONAL is for allowing page used for DMA sending, + * which is the XDP_TX use-case. 
+ */ + if (pool->p.flags & PP_FLAG_DMA_MAP) { + if ((pool->p.dma_dir != DMA_FROM_DEVICE) && + (pool->p.dma_dir != DMA_BIDIRECTIONAL)) + return -EINVAL; + } + + if (pool->p.flags & PP_FLAG_DMA_SYNC_DEV) { + /* In order to request DMA-sync-for-device the page + * needs to be mapped + */ + if (!(pool->p.flags & PP_FLAG_DMA_MAP)) + return -EINVAL; + + if (!pool->p.max_len) + return -EINVAL; + + /* pool->p.offset has to be set according to the address + * offset used by the DMA engine to start copying rx data + */ + } + + if (PAGE_POOL_DMA_USE_PP_FRAG_COUNT && + pool->p.flags & PP_FLAG_PAGE_FRAG) + return -EINVAL; + +#ifdef CONFIG_PAGE_POOL_STATS + pool->recycle_stats = alloc_percpu(struct page_pool_recycle_stats); + if (!pool->recycle_stats) + return -ENOMEM; +#endif + + if (ptr_ring_init(&pool->ring, ring_qsize, GFP_KERNEL) < 0) { +#ifdef CONFIG_PAGE_POOL_STATS + free_percpu(pool->recycle_stats); +#endif + return -ENOMEM; + } + + atomic_set(&pool->pages_state_release_cnt, 0); + + /* Driver calling page_pool_create() also call page_pool_destroy() */ + refcount_set(&pool->user_cnt, 1); + + if (pool->p.flags & PP_FLAG_DMA_MAP) + get_device(pool->p.dev); + + return 0; +} + +struct page_pool *page_pool_create(const struct page_pool_params *params) +{ + struct page_pool *pool; + int err; + + pool = kzalloc_node(sizeof(*pool), GFP_KERNEL, params->nid); + if (!pool) + return ERR_PTR(-ENOMEM); + + err = page_pool_init(pool, params); + if (err < 0) { + pr_warn("%s() gave up with errno %d\n", __func__, err); + kfree(pool); + return ERR_PTR(err); + } + + return pool; +} +EXPORT_SYMBOL(page_pool_create); + +static void page_pool_return_page(struct page_pool *pool, struct page *page); + +noinline +static struct page *page_pool_refill_alloc_cache(struct page_pool *pool) +{ + struct ptr_ring *r = &pool->ring; + struct page *page; + int pref_nid; /* preferred NUMA node */ + + /* Quicker fallback, avoid locks when ring is empty */ + if (__ptr_ring_empty(r)) { + alloc_stat_inc(pool, empty); + return NULL; + } + + /* Softirq guarantee CPU and thus NUMA node is stable. This, + * assumes CPU refilling driver RX-ring will also run RX-NAPI. + */ +#ifdef CONFIG_NUMA + pref_nid = (pool->p.nid == NUMA_NO_NODE) ? numa_mem_id() : pool->p.nid; +#else + /* Ignore pool->p.nid setting if !CONFIG_NUMA, helps compiler */ + pref_nid = numa_mem_id(); /* will be zero like page_to_nid() */ +#endif + + /* Refill alloc array, but only if NUMA match */ + do { + page = __ptr_ring_consume(r); + if (unlikely(!page)) + break; + + if (likely(page_to_nid(page) == pref_nid)) { + pool->alloc.cache[pool->alloc.count++] = page; + } else { + /* NUMA mismatch; + * (1) release 1 page to page-allocator and + * (2) break out to fallthrough to alloc_pages_node. + * This limit stress on page buddy alloactor. + */ + page_pool_return_page(pool, page); + alloc_stat_inc(pool, waive); + page = NULL; + break; + } + } while (pool->alloc.count < PP_ALLOC_CACHE_REFILL); + + /* Return last page */ + if (likely(pool->alloc.count > 0)) { + page = pool->alloc.cache[--pool->alloc.count]; + alloc_stat_inc(pool, refill); + } + + return page; +} + +/* fast path */ +static struct page *__page_pool_get_cached(struct page_pool *pool) +{ + struct page *page; + + /* Caller MUST guarantee safe non-concurrent access, e.g. 
softirq */ + if (likely(pool->alloc.count)) { + /* Fast-path */ + page = pool->alloc.cache[--pool->alloc.count]; + alloc_stat_inc(pool, fast); + } else { + page = page_pool_refill_alloc_cache(pool); + } + + return page; +} + +static void page_pool_dma_sync_for_device(struct page_pool *pool, + struct page *page, + unsigned int dma_sync_size) +{ + dma_addr_t dma_addr = page_pool_get_dma_addr(page); + + dma_sync_size = min(dma_sync_size, pool->p.max_len); + dma_sync_single_range_for_device(pool->p.dev, dma_addr, + pool->p.offset, dma_sync_size, + pool->p.dma_dir); +} + +static bool page_pool_dma_map(struct page_pool *pool, struct page *page) +{ + dma_addr_t dma; + + /* Setup DMA mapping: use 'struct page' area for storing DMA-addr + * since dma_addr_t can be either 32 or 64 bits and does not always fit + * into page private data (i.e 32bit cpu with 64bit DMA caps) + * This mapping is kept for lifetime of page, until leaving pool. + */ + dma = dma_map_page_attrs(pool->p.dev, page, 0, + (PAGE_SIZE << pool->p.order), + pool->p.dma_dir, DMA_ATTR_SKIP_CPU_SYNC); + if (dma_mapping_error(pool->p.dev, dma)) + return false; + + page_pool_set_dma_addr(page, dma); + + if (pool->p.flags & PP_FLAG_DMA_SYNC_DEV) + page_pool_dma_sync_for_device(pool, page, pool->p.max_len); + + return true; +} + +static void page_pool_set_pp_info(struct page_pool *pool, + struct page *page) +{ + page->pp = pool; + page->pp_magic |= PP_SIGNATURE; + if (pool->p.init_callback) + pool->p.init_callback(page, pool->p.init_arg); +} + +static void page_pool_clear_pp_info(struct page *page) +{ + page->pp_magic = 0; + page->pp = NULL; +} + +static struct page *__page_pool_alloc_page_order(struct page_pool *pool, + gfp_t gfp) +{ + struct page *page; + + gfp |= __GFP_COMP; + page = alloc_pages_node(pool->p.nid, gfp, pool->p.order); + if (unlikely(!page)) + return NULL; + + if ((pool->p.flags & PP_FLAG_DMA_MAP) && + unlikely(!page_pool_dma_map(pool, page))) { + put_page(page); + return NULL; + } + + alloc_stat_inc(pool, slow_high_order); + page_pool_set_pp_info(pool, page); + + /* Track how many pages are held 'in-flight' */ + pool->pages_state_hold_cnt++; + trace_page_pool_state_hold(pool, page, pool->pages_state_hold_cnt); + return page; +} + +/* slow path */ +noinline +static struct page *__page_pool_alloc_pages_slow(struct page_pool *pool, + gfp_t gfp) +{ + const int bulk = PP_ALLOC_CACHE_REFILL; + unsigned int pp_flags = pool->p.flags; + unsigned int pp_order = pool->p.order; + struct page *page; + int i, nr_pages; + + /* Don't support bulk alloc for high-order pages */ + if (unlikely(pp_order)) + return __page_pool_alloc_page_order(pool, gfp); + + /* Unnecessary as alloc cache is empty, but guarantees zero count */ + if (unlikely(pool->alloc.count > 0)) + return pool->alloc.cache[--pool->alloc.count]; + + /* Mark empty alloc.cache slots "empty" for alloc_pages_bulk_array */ + memset(&pool->alloc.cache, 0, sizeof(void *) * bulk); + + nr_pages = alloc_pages_bulk_array_node(gfp, pool->p.nid, bulk, + pool->alloc.cache); + if (unlikely(!nr_pages)) + return NULL; + + /* Pages have been filled into alloc.cache array, but count is zero and + * page element have not been (possibly) DMA mapped. 
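+	 * The loop below fixes both up: each page is DMA-mapped when
+	 * PP_FLAG_DMA_MAP is set (pages that fail to map are dropped) and
+	 * alloc.count is bumped for every page that is accepted.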
+ */ + for (i = 0; i < nr_pages; i++) { + page = pool->alloc.cache[i]; + if ((pp_flags & PP_FLAG_DMA_MAP) && + unlikely(!page_pool_dma_map(pool, page))) { + put_page(page); + continue; + } + + page_pool_set_pp_info(pool, page); + pool->alloc.cache[pool->alloc.count++] = page; + /* Track how many pages are held 'in-flight' */ + pool->pages_state_hold_cnt++; + trace_page_pool_state_hold(pool, page, + pool->pages_state_hold_cnt); + } + + /* Return last page */ + if (likely(pool->alloc.count > 0)) { + page = pool->alloc.cache[--pool->alloc.count]; + alloc_stat_inc(pool, slow); + } else { + page = NULL; + } + + /* When page just alloc'ed is should/must have refcnt 1. */ + return page; +} + +/* For using page_pool replace: alloc_pages() API calls, but provide + * synchronization guarantee for allocation side. + */ +struct page *page_pool_alloc_pages(struct page_pool *pool, gfp_t gfp) +{ + struct page *page; + + /* Fast-path: Get a page from cache */ + page = __page_pool_get_cached(pool); + if (page) + return page; + + /* Slow-path: cache empty, do real allocation */ + page = __page_pool_alloc_pages_slow(pool, gfp); + return page; +} +EXPORT_SYMBOL(page_pool_alloc_pages); + +/* Calculate distance between two u32 values, valid if distance is below 2^(31) + * https://en.wikipedia.org/wiki/Serial_number_arithmetic#General_Solution + */ +#define _distance(a, b) (s32)((a) - (b)) + +static s32 page_pool_inflight(struct page_pool *pool) +{ + u32 release_cnt = atomic_read(&pool->pages_state_release_cnt); + u32 hold_cnt = READ_ONCE(pool->pages_state_hold_cnt); + s32 inflight; + + inflight = _distance(hold_cnt, release_cnt); + + trace_page_pool_release(pool, inflight, hold_cnt, release_cnt); + WARN(inflight < 0, "Negative(%d) inflight packet-pages", inflight); + + return inflight; +} + +/* Disconnects a page (from a page_pool). API users can have a need + * to disconnect a page (from a page_pool), to allow it to be used as + * a regular page (that will eventually be returned to the normal + * page-allocator via put_page). + */ +void page_pool_release_page(struct page_pool *pool, struct page *page) +{ + dma_addr_t dma; + int count; + + if (!(pool->p.flags & PP_FLAG_DMA_MAP)) + /* Always account for inflight pages, even if we didn't + * map them + */ + goto skip_dma_unmap; + + dma = page_pool_get_dma_addr(page); + + /* When page is unmapped, it cannot be returned to our pool */ + dma_unmap_page_attrs(pool->p.dev, dma, + PAGE_SIZE << pool->p.order, pool->p.dma_dir, + DMA_ATTR_SKIP_CPU_SYNC); + page_pool_set_dma_addr(page, 0); +skip_dma_unmap: + page_pool_clear_pp_info(page); + + /* This may be the last page returned, releasing the pool, so + * it is not safe to reference pool afterwards. + */ + count = atomic_inc_return_relaxed(&pool->pages_state_release_cnt); + trace_page_pool_state_release(pool, page, count); +} +EXPORT_SYMBOL(page_pool_release_page); + +/* Return a page to the page allocator, cleaning up our state */ +static void page_pool_return_page(struct page_pool *pool, struct page *page) +{ + page_pool_release_page(pool, page); + + put_page(page); + /* An optimization would be to call __free_pages(page, pool->p.order) + * knowing page is not part of page-cache (thus avoiding a + * __page_cache_release() call). 
+ */ +} + +static bool page_pool_recycle_in_ring(struct page_pool *pool, struct page *page) +{ + int ret; + /* BH protection not needed if current is softirq */ + if (in_softirq()) + ret = ptr_ring_produce(&pool->ring, page); + else + ret = ptr_ring_produce_bh(&pool->ring, page); + + if (!ret) { + recycle_stat_inc(pool, ring); + return true; + } + + return false; +} + +/* Only allow direct recycling in special circumstances, into the + * alloc side cache. E.g. during RX-NAPI processing for XDP_DROP use-case. + * + * Caller must provide appropriate safe context. + */ +static bool page_pool_recycle_in_cache(struct page *page, + struct page_pool *pool) +{ + if (unlikely(pool->alloc.count == PP_ALLOC_CACHE_SIZE)) { + recycle_stat_inc(pool, cache_full); + return false; + } + + /* Caller MUST have verified/know (page_ref_count(page) == 1) */ + pool->alloc.cache[pool->alloc.count++] = page; + recycle_stat_inc(pool, cached); + return true; +} + +/* If the page refcnt == 1, this will try to recycle the page. + * if PP_FLAG_DMA_SYNC_DEV is set, we'll try to sync the DMA area for + * the configured size min(dma_sync_size, pool->max_len). + * If the page refcnt != 1, then the page will be returned to memory + * subsystem. + */ +static __always_inline struct page * +__page_pool_put_page(struct page_pool *pool, struct page *page, + unsigned int dma_sync_size, bool allow_direct) +{ + /* This allocator is optimized for the XDP mode that uses + * one-frame-per-page, but have fallbacks that act like the + * regular page allocator APIs. + * + * refcnt == 1 means page_pool owns page, and can recycle it. + * + * page is NOT reusable when allocated when system is under + * some pressure. (page_is_pfmemalloc) + */ + if (likely(page_ref_count(page) == 1 && !page_is_pfmemalloc(page))) { + /* Read barrier done in page_ref_count / READ_ONCE */ + + if (pool->p.flags & PP_FLAG_DMA_SYNC_DEV) + page_pool_dma_sync_for_device(pool, page, + dma_sync_size); + + if (allow_direct && in_softirq() && + page_pool_recycle_in_cache(page, pool)) + return NULL; + + /* Page found as candidate for recycling */ + return page; + } + /* Fallback/non-XDP mode: API user have elevated refcnt. + * + * Many drivers split up the page into fragments, and some + * want to keep doing this to save memory and do refcnt based + * recycling. Support this use case too, to ease drivers + * switching between XDP/non-XDP. + * + * In-case page_pool maintains the DMA mapping, API user must + * call page_pool_put_page once. In this elevated refcnt + * case, the DMA is unmapped/released, as driver is likely + * doing refcnt based recycle tricks, meaning another process + * will be invoking put_page. 
+ */ + recycle_stat_inc(pool, released_refcnt); + /* Do not replace this with page_pool_return_page() */ + page_pool_release_page(pool, page); + put_page(page); + + return NULL; +} + +void page_pool_put_defragged_page(struct page_pool *pool, struct page *page, + unsigned int dma_sync_size, bool allow_direct) +{ + page = __page_pool_put_page(pool, page, dma_sync_size, allow_direct); + if (page && !page_pool_recycle_in_ring(pool, page)) { + /* Cache full, fallback to free pages */ + recycle_stat_inc(pool, ring_full); + page_pool_return_page(pool, page); + } +} +EXPORT_SYMBOL(page_pool_put_defragged_page); + +/* Caller must not use data area after call, as this function overwrites it */ +void page_pool_put_page_bulk(struct page_pool *pool, void **data, + int count) +{ + int i, bulk_len = 0; + bool in_softirq; + + for (i = 0; i < count; i++) { + struct page *page = virt_to_head_page(data[i]); + + /* It is not the last user for the page frag case */ + if (!page_pool_is_last_frag(pool, page)) + continue; + + page = __page_pool_put_page(pool, page, -1, false); + /* Approved for bulk recycling in ptr_ring cache */ + if (page) + data[bulk_len++] = page; + } + + if (unlikely(!bulk_len)) + return; + + /* Bulk producer into ptr_ring page_pool cache */ + in_softirq = page_pool_producer_lock(pool); + for (i = 0; i < bulk_len; i++) { + if (__ptr_ring_produce(&pool->ring, data[i])) { + /* ring full */ + recycle_stat_inc(pool, ring_full); + break; + } + } + recycle_stat_add(pool, ring, i); + page_pool_producer_unlock(pool, in_softirq); + + /* Hopefully all pages was return into ptr_ring */ + if (likely(i == bulk_len)) + return; + + /* ptr_ring cache full, free remaining pages outside producer lock + * since put_page() with refcnt == 1 can be an expensive operation + */ + for (; i < bulk_len; i++) + page_pool_return_page(pool, data[i]); +} +EXPORT_SYMBOL(page_pool_put_page_bulk); + +static struct page *page_pool_drain_frag(struct page_pool *pool, + struct page *page) +{ + long drain_count = BIAS_MAX - pool->frag_users; + + /* Some user is still using the page frag */ + if (likely(page_pool_defrag_page(page, drain_count))) + return NULL; + + if (page_ref_count(page) == 1 && !page_is_pfmemalloc(page)) { + if (pool->p.flags & PP_FLAG_DMA_SYNC_DEV) + page_pool_dma_sync_for_device(pool, page, -1); + + return page; + } + + page_pool_return_page(pool, page); + return NULL; +} + +static void page_pool_free_frag(struct page_pool *pool) +{ + long drain_count = BIAS_MAX - pool->frag_users; + struct page *page = pool->frag_page; + + pool->frag_page = NULL; + + if (!page || page_pool_defrag_page(page, drain_count)) + return; + + page_pool_return_page(pool, page); +} + +struct page *page_pool_alloc_frag(struct page_pool *pool, + unsigned int *offset, + unsigned int size, gfp_t gfp) +{ + unsigned int max_size = PAGE_SIZE << pool->p.order; + struct page *page = pool->frag_page; + + if (WARN_ON(!(pool->p.flags & PP_FLAG_PAGE_FRAG) || + size > max_size)) + return NULL; + + size = ALIGN(size, dma_get_cache_alignment()); + *offset = pool->frag_offset; + + if (page && *offset + size > max_size) { + page = page_pool_drain_frag(pool, page); + if (page) { + alloc_stat_inc(pool, fast); + goto frag_reset; + } + } + + if (!page) { + page = page_pool_alloc_pages(pool, gfp); + if (unlikely(!page)) { + pool->frag_page = NULL; + return NULL; + } + + pool->frag_page = page; + +frag_reset: + pool->frag_users = 1; + *offset = 0; + pool->frag_offset = size; + page_pool_fragment_page(page, BIAS_MAX); + return page; + } + + pool->frag_users++; 
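+	/* Same page still has room: the offset is advanced past this
+	 * fragment so the next caller gets the following chunk.
+	 */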
+ pool->frag_offset = *offset + size; + alloc_stat_inc(pool, fast); + return page; +} +EXPORT_SYMBOL(page_pool_alloc_frag); + +static void page_pool_empty_ring(struct page_pool *pool) +{ + struct page *page; + + /* Empty recycle ring */ + while ((page = ptr_ring_consume_bh(&pool->ring))) { + /* Verify the refcnt invariant of cached pages */ + if (!(page_ref_count(page) == 1)) + pr_crit("%s() page_pool refcnt %d violation\n", + __func__, page_ref_count(page)); + + page_pool_return_page(pool, page); + } +} + +static void page_pool_free(struct page_pool *pool) +{ + if (pool->disconnect) + pool->disconnect(pool); + + ptr_ring_cleanup(&pool->ring, NULL); + + if (pool->p.flags & PP_FLAG_DMA_MAP) + put_device(pool->p.dev); + +#ifdef CONFIG_PAGE_POOL_STATS + free_percpu(pool->recycle_stats); +#endif + kfree(pool); +} + +static void page_pool_empty_alloc_cache_once(struct page_pool *pool) +{ + struct page *page; + + if (pool->destroy_cnt) + return; + + /* Empty alloc cache, assume caller made sure this is + * no-longer in use, and page_pool_alloc_pages() cannot be + * call concurrently. + */ + while (pool->alloc.count) { + page = pool->alloc.cache[--pool->alloc.count]; + page_pool_return_page(pool, page); + } +} + +static void page_pool_scrub(struct page_pool *pool) +{ + page_pool_empty_alloc_cache_once(pool); + pool->destroy_cnt++; + + /* No more consumers should exist, but producers could still + * be in-flight. + */ + page_pool_empty_ring(pool); +} + +static int page_pool_release(struct page_pool *pool) +{ + int inflight; + + page_pool_scrub(pool); + inflight = page_pool_inflight(pool); + if (!inflight) + page_pool_free(pool); + + return inflight; +} + +static void page_pool_release_retry(struct work_struct *wq) +{ + struct delayed_work *dwq = to_delayed_work(wq); + struct page_pool *pool = container_of(dwq, typeof(*pool), release_dw); + int inflight; + + inflight = page_pool_release(pool); + if (!inflight) + return; + + /* Periodic warning */ + if (time_after_eq(jiffies, pool->defer_warn)) { + int sec = (s32)((u32)jiffies - (u32)pool->defer_start) / HZ; + + pr_warn("%s() stalled pool shutdown %d inflight %d sec\n", + __func__, inflight, sec); + pool->defer_warn = jiffies + DEFER_WARN_INTERVAL; + } + + /* Still not ready to be disconnected, retry later */ + schedule_delayed_work(&pool->release_dw, DEFER_TIME); +} + +void page_pool_use_xdp_mem(struct page_pool *pool, void (*disconnect)(void *), + struct xdp_mem_info *mem) +{ + refcount_inc(&pool->user_cnt); + pool->disconnect = disconnect; + pool->xdp_mem_id = mem->id; +} + +void page_pool_destroy(struct page_pool *pool) +{ + if (!pool) + return; + + if (!page_pool_put(pool)) + return; + + page_pool_free_frag(pool); + + if (!page_pool_release(pool)) + return; + + pool->defer_start = jiffies; + pool->defer_warn = jiffies + DEFER_WARN_INTERVAL; + + INIT_DELAYED_WORK(&pool->release_dw, page_pool_release_retry); + schedule_delayed_work(&pool->release_dw, DEFER_TIME); +} +EXPORT_SYMBOL(page_pool_destroy); + +/* Caller must provide appropriate safe context, e.g. NAPI. 
*/ +void page_pool_update_nid(struct page_pool *pool, int new_nid) +{ + struct page *page; + + trace_page_pool_update_nid(pool, new_nid); + pool->p.nid = new_nid; + + /* Flush pool alloc cache, as refill will check NUMA node */ + while (pool->alloc.count) { + page = pool->alloc.cache[--pool->alloc.count]; + page_pool_return_page(pool, page); + } +} +EXPORT_SYMBOL(page_pool_update_nid); + +bool page_pool_return_skb_page(struct page *page) +{ + struct page_pool *pp; + + page = compound_head(page); + + /* page->pp_magic is OR'ed with PP_SIGNATURE after the allocation + * in order to preserve any existing bits, such as bit 0 for the + * head page of compound page and bit 1 for pfmemalloc page, so + * mask those bits for freeing side when doing below checking, + * and page_is_pfmemalloc() is checked in __page_pool_put_page() + * to avoid recycling the pfmemalloc page. + */ + if (unlikely((page->pp_magic & ~0x3UL) != PP_SIGNATURE)) + return false; + + pp = page->pp; + + /* Driver set this to memory recycling info. Reset it on recycle. + * This will *not* work for NIC using a split-page memory model. + * The page will be returned to the pool here regardless of the + * 'flipped' fragment being in use or not. + */ + page_pool_put_full_page(pp, page, false); + + return true; +} +EXPORT_SYMBOL(page_pool_return_skb_page); diff --git a/net/core/pktgen.c b/net/core/pktgen.c new file mode 100644 index 000000000..471d4effa --- /dev/null +++ b/net/core/pktgen.c @@ -0,0 +1,4039 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Authors: + * Copyright 2001, 2002 by Robert Olsson + * Uppsala University and + * Swedish University of Agricultural Sciences + * + * Alexey Kuznetsov + * Ben Greear + * Jens Låås + * + * A tool for loading the network with preconfigurated packets. + * The tool is implemented as a linux module. Parameters are output + * device, delay (to hard_xmit), number of packets, and whether + * to use multiple SKBs or just the same one. + * pktgen uses the installed interface's output routine. + * + * Additional hacking by: + * + * Jens.Laas@data.slu.se + * Improved by ANK. 010120. + * Improved by ANK even more. 010212. + * MAC address typo fixed. 010417 --ro + * Integrated. 020301 --DaveM + * Added multiskb option 020301 --DaveM + * Scaling of results. 020417--sigurdur@linpro.no + * Significant re-work of the module: + * * Convert to threaded model to more efficiently be able to transmit + * and receive on multiple interfaces at once. + * * Converted many counters to __u64 to allow longer runs. + * * Allow configuration of ranges, like min/max IP address, MACs, + * and UDP-ports, for both source and destination, and can + * set to use a random distribution or sequentially walk the range. + * * Can now change most values after starting. + * * Place 12-byte packet in UDP payload with magic number, + * sequence number, and timestamp. + * * Add receiver code that detects dropped pkts, re-ordered pkts, and + * latencies (with micro-second) precision. + * * Add IOCTL interface to easily get counters & configuration. + * --Ben Greear + * + * Renamed multiskb to clone_skb and cleaned up sending core for two distinct + * skb modes. A clone_skb=0 mode for Ben "ranges" work and a clone_skb != 0 + * as a "fastpath" with a configurable number of clones after alloc's. + * clone_skb=0 means all packets are allocated this also means ranges time + * stamps etc can be used. clone_skb=100 means 1 malloc is followed by 100 + * clones. 
+ * + * Also moved to /proc/net/pktgen/ + * --ro + * + * Sept 10: Fixed threading/locking. Lots of bone-headed and more clever + * mistakes. Also merged in DaveM's patch in the -pre6 patch. + * --Ben Greear + * + * Integrated to 2.5.x 021029 --Lucio Maciel (luciomaciel@zipmail.com.br) + * + * 021124 Finished major redesign and rewrite for new functionality. + * See Documentation/networking/pktgen.rst for how to use this. + * + * The new operation: + * For each CPU one thread/process is created at start. This process checks + * for running devices in the if_list and sends packets until count is 0 it + * also the thread checks the thread->control which is used for inter-process + * communication. controlling process "posts" operations to the threads this + * way. + * The if_list is RCU protected, and the if_lock remains to protect updating + * of if_list, from "add_device" as it invoked from userspace (via proc write). + * + * By design there should only be *one* "controlling" process. In practice + * multiple write accesses gives unpredictable result. Understood by "write" + * to /proc gives result code thats should be read be the "writer". + * For practical use this should be no problem. + * + * Note when adding devices to a specific CPU there good idea to also assign + * /proc/irq/XX/smp_affinity so TX-interrupts gets bound to the same CPU. + * --ro + * + * Fix refcount off by one if first packet fails, potential null deref, + * memleak 030710- KJP + * + * First "ranges" functionality for ipv6 030726 --ro + * + * Included flow support. 030802 ANK. + * + * Fixed unaligned access on IA-64 Grant Grundler + * + * Remove if fix from added Harald Welte 040419 + * ia64 compilation fix from Aron Griffis 040604 + * + * New xmit() return, do_div and misc clean up by Stephen Hemminger + * 040923 + * + * Randy Dunlap fixed u64 printk compiler warning + * + * Remove FCS from BW calculation. Lennert Buytenhek + * New time handling. Lennert Buytenhek 041213 + * + * Corrections from Nikolai Malykh (nmalykh@bilim.com) + * Removed unused flags F_SET_SRCMAC & F_SET_SRCIP 041230 + * + * interruptible_sleep_on_timeout() replaced Nishanth Aravamudan + * 050103 + * + * MPLS support by Steven Whitehouse + * + * 802.1Q/Q-in-Q support by Francesco Fondelli (FF) + * + * Fixed src_mac command to set source mac of packet to value specified in + * command by Adit Ranadive + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_XFRM +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include /* do_div */ + +#define VERSION "2.75" +#define IP_NAME_SZ 32 +#define MAX_MPLS_LABELS 16 /* This is the max label stack depth */ +#define MPLS_STACK_BOTTOM htonl(0x00000100) +/* Max number of internet mix entries that can be specified in imix_weights. 
*/ +#define MAX_IMIX_ENTRIES 20 +#define IMIX_PRECISION 100 /* Precision of IMIX distribution */ + +#define func_enter() pr_debug("entering %s\n", __func__); + +#define PKT_FLAGS \ + pf(IPV6) /* Interface in IPV6 Mode */ \ + pf(IPSRC_RND) /* IP-Src Random */ \ + pf(IPDST_RND) /* IP-Dst Random */ \ + pf(TXSIZE_RND) /* Transmit size is random */ \ + pf(UDPSRC_RND) /* UDP-Src Random */ \ + pf(UDPDST_RND) /* UDP-Dst Random */ \ + pf(UDPCSUM) /* Include UDP checksum */ \ + pf(NO_TIMESTAMP) /* Don't timestamp packets (default TS) */ \ + pf(MPLS_RND) /* Random MPLS labels */ \ + pf(QUEUE_MAP_RND) /* queue map Random */ \ + pf(QUEUE_MAP_CPU) /* queue map mirrors smp_processor_id() */ \ + pf(FLOW_SEQ) /* Sequential flows */ \ + pf(IPSEC) /* ipsec on for flows */ \ + pf(MACSRC_RND) /* MAC-Src Random */ \ + pf(MACDST_RND) /* MAC-Dst Random */ \ + pf(VID_RND) /* Random VLAN ID */ \ + pf(SVID_RND) /* Random SVLAN ID */ \ + pf(NODE) /* Node memory alloc*/ \ + +#define pf(flag) flag##_SHIFT, +enum pkt_flags { + PKT_FLAGS +}; +#undef pf + +/* Device flag bits */ +#define pf(flag) static const __u32 F_##flag = (1<if_lock)); +#define if_unlock(t) mutex_unlock(&(t->if_lock)); + +/* Used to help with determining the pkts on receive */ +#define PKTGEN_MAGIC 0xbe9be955 +#define PG_PROC_DIR "pktgen" +#define PGCTRL "pgctrl" + +#define MAX_CFLOWS 65536 + +#define VLAN_TAG_SIZE(x) ((x)->vlan_id == 0xffff ? 0 : 4) +#define SVLAN_TAG_SIZE(x) ((x)->svlan_id == 0xffff ? 0 : 4) + +struct imix_pkt { + u64 size; + u64 weight; + u64 count_so_far; +}; + +struct flow_state { + __be32 cur_daddr; + int count; +#ifdef CONFIG_XFRM + struct xfrm_state *x; +#endif + __u32 flags; +}; + +/* flow flag bits */ +#define F_INIT (1<<0) /* flow has been initialized */ + +struct pktgen_dev { + /* + * Try to keep frequent/infrequent used vars. separated. + */ + struct proc_dir_entry *entry; /* proc file */ + struct pktgen_thread *pg_thread;/* the owner */ + struct list_head list; /* chaining in the thread's run-queue */ + struct rcu_head rcu; /* freed by RCU */ + + int running; /* if false, the test will stop */ + + /* If min != max, then we will either do a linear iteration, or + * we will do a random selection from within the range. + */ + __u32 flags; + int xmit_mode; + int min_pkt_size; + int max_pkt_size; + int pkt_overhead; /* overhead for MPLS, VLANs, IPSEC etc */ + int nfrags; + int removal_mark; /* non-zero => the device is marked for + * removal by worker thread */ + + struct page *page; + u64 delay; /* nano-seconds */ + + __u64 count; /* Default No packets to send */ + __u64 sofar; /* How many pkts we've sent so far */ + __u64 tx_bytes; /* How many bytes we've transmitted */ + __u64 errors; /* Errors when trying to transmit, */ + + /* runtime counters relating to clone_skb */ + + __u32 clone_count; + int last_ok; /* Was last skb sent? + * Or a failed transmit of some sort? + * This will keep sequence numbers in order + */ + ktime_t next_tx; + ktime_t started_at; + ktime_t stopped_at; + u64 idle_acc; /* nano-seconds */ + + __u32 seq_num; + + int clone_skb; /* + * Use multiple SKBs during packet gen. + * If this number is greater than 1, then + * that many copies of the same packet will be + * sent before a new packet is allocated. + * If you want to send 1024 identical packets + * before creating a new packet, + * set clone_skb to 1024. 
+ */ + + char dst_min[IP_NAME_SZ]; /* IP, ie 1.2.3.4 */ + char dst_max[IP_NAME_SZ]; /* IP, ie 1.2.3.4 */ + char src_min[IP_NAME_SZ]; /* IP, ie 1.2.3.4 */ + char src_max[IP_NAME_SZ]; /* IP, ie 1.2.3.4 */ + + struct in6_addr in6_saddr; + struct in6_addr in6_daddr; + struct in6_addr cur_in6_daddr; + struct in6_addr cur_in6_saddr; + /* For ranges */ + struct in6_addr min_in6_daddr; + struct in6_addr max_in6_daddr; + struct in6_addr min_in6_saddr; + struct in6_addr max_in6_saddr; + + /* If we're doing ranges, random or incremental, then this + * defines the min/max for those ranges. + */ + __be32 saddr_min; /* inclusive, source IP address */ + __be32 saddr_max; /* exclusive, source IP address */ + __be32 daddr_min; /* inclusive, dest IP address */ + __be32 daddr_max; /* exclusive, dest IP address */ + + __u16 udp_src_min; /* inclusive, source UDP port */ + __u16 udp_src_max; /* exclusive, source UDP port */ + __u16 udp_dst_min; /* inclusive, dest UDP port */ + __u16 udp_dst_max; /* exclusive, dest UDP port */ + + /* DSCP + ECN */ + __u8 tos; /* six MSB of (former) IPv4 TOS + are for dscp codepoint */ + __u8 traffic_class; /* ditto for the (former) Traffic Class in IPv6 + (see RFC 3260, sec. 4) */ + + /* IMIX */ + unsigned int n_imix_entries; + struct imix_pkt imix_entries[MAX_IMIX_ENTRIES]; + /* Maps 0-IMIX_PRECISION range to imix_entry based on probability*/ + __u8 imix_distribution[IMIX_PRECISION]; + + /* MPLS */ + unsigned int nr_labels; /* Depth of stack, 0 = no MPLS */ + __be32 labels[MAX_MPLS_LABELS]; + + /* VLAN/SVLAN (802.1Q/Q-in-Q) */ + __u8 vlan_p; + __u8 vlan_cfi; + __u16 vlan_id; /* 0xffff means no vlan tag */ + + __u8 svlan_p; + __u8 svlan_cfi; + __u16 svlan_id; /* 0xffff means no svlan tag */ + + __u32 src_mac_count; /* How many MACs to iterate through */ + __u32 dst_mac_count; /* How many MACs to iterate through */ + + unsigned char dst_mac[ETH_ALEN]; + unsigned char src_mac[ETH_ALEN]; + + __u32 cur_dst_mac_offset; + __u32 cur_src_mac_offset; + __be32 cur_saddr; + __be32 cur_daddr; + __u16 ip_id; + __u16 cur_udp_dst; + __u16 cur_udp_src; + __u16 cur_queue_map; + __u32 cur_pkt_size; + __u32 last_pkt_size; + + __u8 hh[14]; + /* = { + 0x00, 0x80, 0xC8, 0x79, 0xB3, 0xCB, + + We fill in SRC address later + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x00 + }; + */ + __u16 pad; /* pad out the hh struct to an even 16 bytes */ + + struct sk_buff *skb; /* skb we are to transmit next, used for when we + * are transmitting the same one multiple times + */ + struct net_device *odev; /* The out-going device. + * Note that the device should have it's + * pg_info pointer pointing back to this + * device. + * Set when the user specifies the out-going + * device name (not when the inject is + * started as it used to do.) 
+ */ + netdevice_tracker dev_tracker; + char odevname[32]; + struct flow_state *flows; + unsigned int cflows; /* Concurrent flows (config) */ + unsigned int lflow; /* Flow length (config) */ + unsigned int nflows; /* accumulated flows (stats) */ + unsigned int curfl; /* current sequenced flow (state)*/ + + u16 queue_map_min; + u16 queue_map_max; + __u32 skb_priority; /* skb priority field */ + unsigned int burst; /* number of duplicated packets to burst */ + int node; /* Memory node */ + +#ifdef CONFIG_XFRM + __u8 ipsmode; /* IPSEC mode (config) */ + __u8 ipsproto; /* IPSEC type (config) */ + __u32 spi; + struct xfrm_dst xdst; + struct dst_ops dstops; +#endif + char result[512]; +}; + +struct pktgen_hdr { + __be32 pgh_magic; + __be32 seq_num; + __be32 tv_sec; + __be32 tv_usec; +}; + + +static unsigned int pg_net_id __read_mostly; + +struct pktgen_net { + struct net *net; + struct proc_dir_entry *proc_dir; + struct list_head pktgen_threads; + bool pktgen_exiting; +}; + +struct pktgen_thread { + struct mutex if_lock; /* for list of devices */ + struct list_head if_list; /* All device here */ + struct list_head th_list; + struct task_struct *tsk; + char result[512]; + + /* Field for thread to receive "posted" events terminate, + stop ifs etc. */ + + u32 control; + int cpu; + + wait_queue_head_t queue; + struct completion start_done; + struct pktgen_net *net; +}; + +#define REMOVE 1 +#define FIND 0 + +static const char version[] = + "Packet Generator for packet performance testing. " + "Version: " VERSION "\n"; + +static int pktgen_remove_device(struct pktgen_thread *t, struct pktgen_dev *i); +static int pktgen_add_device(struct pktgen_thread *t, const char *ifname); +static struct pktgen_dev *pktgen_find_dev(struct pktgen_thread *t, + const char *ifname, bool exact); +static int pktgen_device_event(struct notifier_block *, unsigned long, void *); +static void pktgen_run_all_threads(struct pktgen_net *pn); +static void pktgen_reset_all_threads(struct pktgen_net *pn); +static void pktgen_stop_all_threads(struct pktgen_net *pn); + +static void pktgen_stop(struct pktgen_thread *t); +static void pktgen_clear_counters(struct pktgen_dev *pkt_dev); +static void fill_imix_distribution(struct pktgen_dev *pkt_dev); + +/* Module parameters, defaults. 
*/ +static int pg_count_d __read_mostly = 1000; +static int pg_delay_d __read_mostly; +static int pg_clone_skb_d __read_mostly; +static int debug __read_mostly; + +static DEFINE_MUTEX(pktgen_thread_lock); + +static struct notifier_block pktgen_notifier_block = { + .notifier_call = pktgen_device_event, +}; + +/* + * /proc handling functions + * + */ + +static int pgctrl_show(struct seq_file *seq, void *v) +{ + seq_puts(seq, version); + return 0; +} + +static ssize_t pgctrl_write(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + char data[128]; + struct pktgen_net *pn = net_generic(current->nsproxy->net_ns, pg_net_id); + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (count == 0) + return -EINVAL; + + if (count > sizeof(data)) + count = sizeof(data); + + if (copy_from_user(data, buf, count)) + return -EFAULT; + + data[count - 1] = 0; /* Strip trailing '\n' and terminate string */ + + if (!strcmp(data, "stop")) + pktgen_stop_all_threads(pn); + else if (!strcmp(data, "start")) + pktgen_run_all_threads(pn); + else if (!strcmp(data, "reset")) + pktgen_reset_all_threads(pn); + else + return -EINVAL; + + return count; +} + +static int pgctrl_open(struct inode *inode, struct file *file) +{ + return single_open(file, pgctrl_show, pde_data(inode)); +} + +static const struct proc_ops pktgen_proc_ops = { + .proc_open = pgctrl_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_write = pgctrl_write, + .proc_release = single_release, +}; + +static int pktgen_if_show(struct seq_file *seq, void *v) +{ + const struct pktgen_dev *pkt_dev = seq->private; + ktime_t stopped; + unsigned int i; + u64 idle; + + seq_printf(seq, + "Params: count %llu min_pkt_size: %u max_pkt_size: %u\n", + (unsigned long long)pkt_dev->count, pkt_dev->min_pkt_size, + pkt_dev->max_pkt_size); + + if (pkt_dev->n_imix_entries > 0) { + seq_puts(seq, " imix_weights: "); + for (i = 0; i < pkt_dev->n_imix_entries; i++) { + seq_printf(seq, "%llu,%llu ", + pkt_dev->imix_entries[i].size, + pkt_dev->imix_entries[i].weight); + } + seq_puts(seq, "\n"); + } + + seq_printf(seq, + " frags: %d delay: %llu clone_skb: %d ifname: %s\n", + pkt_dev->nfrags, (unsigned long long) pkt_dev->delay, + pkt_dev->clone_skb, pkt_dev->odevname); + + seq_printf(seq, " flows: %u flowlen: %u\n", pkt_dev->cflows, + pkt_dev->lflow); + + seq_printf(seq, + " queue_map_min: %u queue_map_max: %u\n", + pkt_dev->queue_map_min, + pkt_dev->queue_map_max); + + if (pkt_dev->skb_priority) + seq_printf(seq, " skb_priority: %u\n", + pkt_dev->skb_priority); + + if (pkt_dev->flags & F_IPV6) { + seq_printf(seq, + " saddr: %pI6c min_saddr: %pI6c max_saddr: %pI6c\n" + " daddr: %pI6c min_daddr: %pI6c max_daddr: %pI6c\n", + &pkt_dev->in6_saddr, + &pkt_dev->min_in6_saddr, &pkt_dev->max_in6_saddr, + &pkt_dev->in6_daddr, + &pkt_dev->min_in6_daddr, &pkt_dev->max_in6_daddr); + } else { + seq_printf(seq, + " dst_min: %s dst_max: %s\n", + pkt_dev->dst_min, pkt_dev->dst_max); + seq_printf(seq, + " src_min: %s src_max: %s\n", + pkt_dev->src_min, pkt_dev->src_max); + } + + seq_puts(seq, " src_mac: "); + + seq_printf(seq, "%pM ", + is_zero_ether_addr(pkt_dev->src_mac) ? 
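+ /* src_mac defaults to the interface's own dev_addr until the user
+  * writes an explicit "src_mac" value, so show dev_addr while the
+  * configured source MAC is still all-zero.
+  */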
+ pkt_dev->odev->dev_addr : pkt_dev->src_mac); + + seq_puts(seq, "dst_mac: "); + seq_printf(seq, "%pM\n", pkt_dev->dst_mac); + + seq_printf(seq, + " udp_src_min: %d udp_src_max: %d" + " udp_dst_min: %d udp_dst_max: %d\n", + pkt_dev->udp_src_min, pkt_dev->udp_src_max, + pkt_dev->udp_dst_min, pkt_dev->udp_dst_max); + + seq_printf(seq, + " src_mac_count: %d dst_mac_count: %d\n", + pkt_dev->src_mac_count, pkt_dev->dst_mac_count); + + if (pkt_dev->nr_labels) { + seq_puts(seq, " mpls: "); + for (i = 0; i < pkt_dev->nr_labels; i++) + seq_printf(seq, "%08x%s", ntohl(pkt_dev->labels[i]), + i == pkt_dev->nr_labels-1 ? "\n" : ", "); + } + + if (pkt_dev->vlan_id != 0xffff) + seq_printf(seq, " vlan_id: %u vlan_p: %u vlan_cfi: %u\n", + pkt_dev->vlan_id, pkt_dev->vlan_p, + pkt_dev->vlan_cfi); + + if (pkt_dev->svlan_id != 0xffff) + seq_printf(seq, " svlan_id: %u vlan_p: %u vlan_cfi: %u\n", + pkt_dev->svlan_id, pkt_dev->svlan_p, + pkt_dev->svlan_cfi); + + if (pkt_dev->tos) + seq_printf(seq, " tos: 0x%02x\n", pkt_dev->tos); + + if (pkt_dev->traffic_class) + seq_printf(seq, " traffic_class: 0x%02x\n", pkt_dev->traffic_class); + + if (pkt_dev->burst > 1) + seq_printf(seq, " burst: %d\n", pkt_dev->burst); + + if (pkt_dev->node >= 0) + seq_printf(seq, " node: %d\n", pkt_dev->node); + + if (pkt_dev->xmit_mode == M_NETIF_RECEIVE) + seq_puts(seq, " xmit_mode: netif_receive\n"); + else if (pkt_dev->xmit_mode == M_QUEUE_XMIT) + seq_puts(seq, " xmit_mode: xmit_queue\n"); + + seq_puts(seq, " Flags: "); + + for (i = 0; i < NR_PKT_FLAGS; i++) { + if (i == FLOW_SEQ_SHIFT) + if (!pkt_dev->cflows) + continue; + + if (pkt_dev->flags & (1 << i)) { + seq_printf(seq, "%s ", pkt_flag_names[i]); +#ifdef CONFIG_XFRM + if (i == IPSEC_SHIFT && pkt_dev->spi) + seq_printf(seq, "spi:%u ", pkt_dev->spi); +#endif + } else if (i == FLOW_SEQ_SHIFT) { + seq_puts(seq, "FLOW_RND "); + } + } + + seq_puts(seq, "\n"); + + /* not really stopped, more like last-running-at */ + stopped = pkt_dev->running ? 
ktime_get() : pkt_dev->stopped_at; + idle = pkt_dev->idle_acc; + do_div(idle, NSEC_PER_USEC); + + seq_printf(seq, + "Current:\n pkts-sofar: %llu errors: %llu\n", + (unsigned long long)pkt_dev->sofar, + (unsigned long long)pkt_dev->errors); + + if (pkt_dev->n_imix_entries > 0) { + int i; + + seq_puts(seq, " imix_size_counts: "); + for (i = 0; i < pkt_dev->n_imix_entries; i++) { + seq_printf(seq, "%llu,%llu ", + pkt_dev->imix_entries[i].size, + pkt_dev->imix_entries[i].count_so_far); + } + seq_puts(seq, "\n"); + } + + seq_printf(seq, + " started: %lluus stopped: %lluus idle: %lluus\n", + (unsigned long long) ktime_to_us(pkt_dev->started_at), + (unsigned long long) ktime_to_us(stopped), + (unsigned long long) idle); + + seq_printf(seq, + " seq_num: %d cur_dst_mac_offset: %d cur_src_mac_offset: %d\n", + pkt_dev->seq_num, pkt_dev->cur_dst_mac_offset, + pkt_dev->cur_src_mac_offset); + + if (pkt_dev->flags & F_IPV6) { + seq_printf(seq, " cur_saddr: %pI6c cur_daddr: %pI6c\n", + &pkt_dev->cur_in6_saddr, + &pkt_dev->cur_in6_daddr); + } else + seq_printf(seq, " cur_saddr: %pI4 cur_daddr: %pI4\n", + &pkt_dev->cur_saddr, &pkt_dev->cur_daddr); + + seq_printf(seq, " cur_udp_dst: %d cur_udp_src: %d\n", + pkt_dev->cur_udp_dst, pkt_dev->cur_udp_src); + + seq_printf(seq, " cur_queue_map: %u\n", pkt_dev->cur_queue_map); + + seq_printf(seq, " flows: %u\n", pkt_dev->nflows); + + if (pkt_dev->result[0]) + seq_printf(seq, "Result: %s\n", pkt_dev->result); + else + seq_puts(seq, "Result: Idle\n"); + + return 0; +} + + +static int hex32_arg(const char __user *user_buffer, unsigned long maxlen, + __u32 *num) +{ + int i = 0; + *num = 0; + + for (; i < maxlen; i++) { + int value; + char c; + *num <<= 4; + if (get_user(c, &user_buffer[i])) + return -EFAULT; + value = hex_to_bin(c); + if (value >= 0) + *num |= value; + else + break; + } + return i; +} + +static int count_trail_chars(const char __user * user_buffer, + unsigned int maxlen) +{ + int i; + + for (i = 0; i < maxlen; i++) { + char c; + if (get_user(c, &user_buffer[i])) + return -EFAULT; + switch (c) { + case '\"': + case '\n': + case '\r': + case '\t': + case ' ': + case '=': + break; + default: + goto done; + } + } +done: + return i; +} + +static long num_arg(const char __user *user_buffer, unsigned long maxlen, + unsigned long *num) +{ + int i; + *num = 0; + + for (i = 0; i < maxlen; i++) { + char c; + if (get_user(c, &user_buffer[i])) + return -EFAULT; + if ((c >= '0') && (c <= '9')) { + *num *= 10; + *num += c - '0'; + } else + break; + } + return i; +} + +static int strn_len(const char __user * user_buffer, unsigned int maxlen) +{ + int i; + + for (i = 0; i < maxlen; i++) { + char c; + if (get_user(c, &user_buffer[i])) + return -EFAULT; + switch (c) { + case '\"': + case '\n': + case '\r': + case '\t': + case ' ': + goto done_str; + default: + break; + } + } +done_str: + return i; +} + +/* Parses imix entries from user buffer. + * The user buffer should consist of imix entries separated by spaces + * where each entry consists of size and weight delimited by commas. + * "size1,weight_1 size2,weight_2 ... size_n,weight_n" for example. 
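+ * Each size,weight pair is separated by a space and parsed by
+ * get_imix_entries() below: sizes smaller than 14 + 20 + 8 bytes
+ * (Ethernet + IPv4 + UDP headers) are rounded up, a zero weight is
+ * rejected, and at most MAX_IMIX_ENTRIES pairs are accepted.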
+ */ +static ssize_t get_imix_entries(const char __user *buffer, + struct pktgen_dev *pkt_dev) +{ + const int max_digits = 10; + int i = 0; + long len; + char c; + + pkt_dev->n_imix_entries = 0; + + do { + unsigned long weight; + unsigned long size; + + len = num_arg(&buffer[i], max_digits, &size); + if (len < 0) + return len; + i += len; + if (get_user(c, &buffer[i])) + return -EFAULT; + /* Check for comma between size_i and weight_i */ + if (c != ',') + return -EINVAL; + i++; + + if (size < 14 + 20 + 8) + size = 14 + 20 + 8; + + len = num_arg(&buffer[i], max_digits, &weight); + if (len < 0) + return len; + if (weight <= 0) + return -EINVAL; + + pkt_dev->imix_entries[pkt_dev->n_imix_entries].size = size; + pkt_dev->imix_entries[pkt_dev->n_imix_entries].weight = weight; + + i += len; + if (get_user(c, &buffer[i])) + return -EFAULT; + + i++; + pkt_dev->n_imix_entries++; + + if (pkt_dev->n_imix_entries > MAX_IMIX_ENTRIES) + return -E2BIG; + } while (c == ' '); + + return i; +} + +static ssize_t get_labels(const char __user *buffer, struct pktgen_dev *pkt_dev) +{ + unsigned int n = 0; + char c; + ssize_t i = 0; + int len; + + pkt_dev->nr_labels = 0; + do { + __u32 tmp; + len = hex32_arg(&buffer[i], 8, &tmp); + if (len <= 0) + return len; + pkt_dev->labels[n] = htonl(tmp); + if (pkt_dev->labels[n] & MPLS_STACK_BOTTOM) + pkt_dev->flags |= F_MPLS_RND; + i += len; + if (get_user(c, &buffer[i])) + return -EFAULT; + i++; + n++; + if (n >= MAX_MPLS_LABELS) + return -E2BIG; + } while (c == ','); + + pkt_dev->nr_labels = n; + return i; +} + +static __u32 pktgen_read_flag(const char *f, bool *disable) +{ + __u32 i; + + if (f[0] == '!') { + *disable = true; + f++; + } + + for (i = 0; i < NR_PKT_FLAGS; i++) { + if (!IS_ENABLED(CONFIG_XFRM) && i == IPSEC_SHIFT) + continue; + + /* allow only disabling ipv6 flag */ + if (!*disable && i == IPV6_SHIFT) + continue; + + if (strcmp(f, pkt_flag_names[i]) == 0) + return 1 << i; + } + + if (strcmp(f, "FLOW_RND") == 0) { + *disable = !*disable; + return F_FLOW_SEQ; + } + + return 0; +} + +static ssize_t pktgen_if_write(struct file *file, + const char __user * user_buffer, size_t count, + loff_t * offset) +{ + struct seq_file *seq = file->private_data; + struct pktgen_dev *pkt_dev = seq->private; + int i, max, len; + char name[16], valstr[32]; + unsigned long value = 0; + char *pg_result = NULL; + int tmp = 0; + char buf[128]; + + pg_result = &(pkt_dev->result[0]); + + if (count < 1) { + pr_warn("wrong command format\n"); + return -EINVAL; + } + + max = count; + tmp = count_trail_chars(user_buffer, max); + if (tmp < 0) { + pr_warn("illegal format\n"); + return tmp; + } + i = tmp; + + /* Read variable name */ + + len = strn_len(&user_buffer[i], sizeof(name) - 1); + if (len < 0) + return len; + + memset(name, 0, sizeof(name)); + if (copy_from_user(name, &user_buffer[i], len)) + return -EFAULT; + i += len; + + max = count - i; + len = count_trail_chars(&user_buffer[i], max); + if (len < 0) + return len; + + i += len; + + if (debug) { + size_t copy = min_t(size_t, count + 1, 1024); + char *tp = strndup_user(user_buffer, copy); + + if (IS_ERR(tp)) + return PTR_ERR(tp); + + pr_debug("%s,%zu buffer -:%s:-\n", name, count, tp); + kfree(tp); + } + + if (!strcmp(name, "min_pkt_size")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value < 14 + 20 + 8) + value = 14 + 20 + 8; + if (value != pkt_dev->min_pkt_size) { + pkt_dev->min_pkt_size = value; + pkt_dev->cur_pkt_size = value; + } + sprintf(pg_result, "OK: min_pkt_size=%d", + 
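+ /* 14 + 20 + 8 above is the smallest sensible frame:
+  * Ethernet + IPv4 + UDP headers with no payload.
+  */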
pkt_dev->min_pkt_size); + return count; + } + + if (!strcmp(name, "max_pkt_size")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value < 14 + 20 + 8) + value = 14 + 20 + 8; + if (value != pkt_dev->max_pkt_size) { + pkt_dev->max_pkt_size = value; + pkt_dev->cur_pkt_size = value; + } + sprintf(pg_result, "OK: max_pkt_size=%d", + pkt_dev->max_pkt_size); + return count; + } + + /* Shortcut for min = max */ + + if (!strcmp(name, "pkt_size")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value < 14 + 20 + 8) + value = 14 + 20 + 8; + if (value != pkt_dev->min_pkt_size) { + pkt_dev->min_pkt_size = value; + pkt_dev->max_pkt_size = value; + pkt_dev->cur_pkt_size = value; + } + sprintf(pg_result, "OK: pkt_size=%d", pkt_dev->min_pkt_size); + return count; + } + + if (!strcmp(name, "imix_weights")) { + if (pkt_dev->clone_skb > 0) + return -EINVAL; + + len = get_imix_entries(&user_buffer[i], pkt_dev); + if (len < 0) + return len; + + fill_imix_distribution(pkt_dev); + + i += len; + return count; + } + + if (!strcmp(name, "debug")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + debug = value; + sprintf(pg_result, "OK: debug=%u", debug); + return count; + } + + if (!strcmp(name, "frags")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->nfrags = value; + sprintf(pg_result, "OK: frags=%d", pkt_dev->nfrags); + return count; + } + if (!strcmp(name, "delay")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value == 0x7FFFFFFF) + pkt_dev->delay = ULLONG_MAX; + else + pkt_dev->delay = (u64)value; + + sprintf(pg_result, "OK: delay=%llu", + (unsigned long long) pkt_dev->delay); + return count; + } + if (!strcmp(name, "rate")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (!value) + return len; + pkt_dev->delay = pkt_dev->min_pkt_size*8*NSEC_PER_USEC/value; + if (debug) + pr_info("Delay set at: %llu ns\n", pkt_dev->delay); + + sprintf(pg_result, "OK: rate=%lu", value); + return count; + } + if (!strcmp(name, "ratep")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (!value) + return len; + pkt_dev->delay = NSEC_PER_SEC/value; + if (debug) + pr_info("Delay set at: %llu ns\n", pkt_dev->delay); + + sprintf(pg_result, "OK: rate=%lu", value); + return count; + } + if (!strcmp(name, "udp_src_min")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value != pkt_dev->udp_src_min) { + pkt_dev->udp_src_min = value; + pkt_dev->cur_udp_src = value; + } + sprintf(pg_result, "OK: udp_src_min=%u", pkt_dev->udp_src_min); + return count; + } + if (!strcmp(name, "udp_dst_min")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value != pkt_dev->udp_dst_min) { + pkt_dev->udp_dst_min = value; + pkt_dev->cur_udp_dst = value; + } + sprintf(pg_result, "OK: udp_dst_min=%u", pkt_dev->udp_dst_min); + return count; + } + if (!strcmp(name, "udp_src_max")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value != pkt_dev->udp_src_max) { + pkt_dev->udp_src_max = value; + pkt_dev->cur_udp_src = value; + } + sprintf(pg_result, "OK: udp_src_max=%u", pkt_dev->udp_src_max); + return count; + } + if (!strcmp(name, "udp_dst_max")) { + len = num_arg(&user_buffer[i], 10, 
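+ /* For the "rate" handler above: delay_ns = min_pkt_size * 8 * 1000 / rate,
+  * i.e. rate is taken in Mbit/s; e.g. min_pkt_size=1500 and "rate 1000"
+  * yields a 12000 ns inter-packet gap (~83k pps, ~1 Gbit/s).
+  */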
&value); + if (len < 0) + return len; + + i += len; + if (value != pkt_dev->udp_dst_max) { + pkt_dev->udp_dst_max = value; + pkt_dev->cur_udp_dst = value; + } + sprintf(pg_result, "OK: udp_dst_max=%u", pkt_dev->udp_dst_max); + return count; + } + if (!strcmp(name, "clone_skb")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + /* clone_skb is not supported for netif_receive xmit_mode and + * IMIX mode. + */ + if ((value > 0) && + ((pkt_dev->xmit_mode == M_NETIF_RECEIVE) || + !(pkt_dev->odev->priv_flags & IFF_TX_SKB_SHARING))) + return -ENOTSUPP; + if (value > 0 && pkt_dev->n_imix_entries > 0) + return -EINVAL; + + i += len; + pkt_dev->clone_skb = value; + + sprintf(pg_result, "OK: clone_skb=%d", pkt_dev->clone_skb); + return count; + } + if (!strcmp(name, "count")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->count = value; + sprintf(pg_result, "OK: count=%llu", + (unsigned long long)pkt_dev->count); + return count; + } + if (!strcmp(name, "src_mac_count")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (pkt_dev->src_mac_count != value) { + pkt_dev->src_mac_count = value; + pkt_dev->cur_src_mac_offset = 0; + } + sprintf(pg_result, "OK: src_mac_count=%d", + pkt_dev->src_mac_count); + return count; + } + if (!strcmp(name, "dst_mac_count")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (pkt_dev->dst_mac_count != value) { + pkt_dev->dst_mac_count = value; + pkt_dev->cur_dst_mac_offset = 0; + } + sprintf(pg_result, "OK: dst_mac_count=%d", + pkt_dev->dst_mac_count); + return count; + } + if (!strcmp(name, "burst")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if ((value > 1) && + ((pkt_dev->xmit_mode == M_QUEUE_XMIT) || + ((pkt_dev->xmit_mode == M_START_XMIT) && + (!(pkt_dev->odev->priv_flags & IFF_TX_SKB_SHARING))))) + return -ENOTSUPP; + pkt_dev->burst = value < 1 ? 
1 : value; + sprintf(pg_result, "OK: burst=%u", pkt_dev->burst); + return count; + } + if (!strcmp(name, "node")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + + if (node_possible(value)) { + pkt_dev->node = value; + sprintf(pg_result, "OK: node=%d", pkt_dev->node); + if (pkt_dev->page) { + put_page(pkt_dev->page); + pkt_dev->page = NULL; + } + } + else + sprintf(pg_result, "ERROR: node not possible"); + return count; + } + if (!strcmp(name, "xmit_mode")) { + char f[32]; + + memset(f, 0, 32); + len = strn_len(&user_buffer[i], sizeof(f) - 1); + if (len < 0) + return len; + + if (copy_from_user(f, &user_buffer[i], len)) + return -EFAULT; + i += len; + + if (strcmp(f, "start_xmit") == 0) { + pkt_dev->xmit_mode = M_START_XMIT; + } else if (strcmp(f, "netif_receive") == 0) { + /* clone_skb set earlier, not supported in this mode */ + if (pkt_dev->clone_skb > 0) + return -ENOTSUPP; + + pkt_dev->xmit_mode = M_NETIF_RECEIVE; + + /* make sure new packet is allocated every time + * pktgen_xmit() is called + */ + pkt_dev->last_ok = 1; + } else if (strcmp(f, "queue_xmit") == 0) { + pkt_dev->xmit_mode = M_QUEUE_XMIT; + pkt_dev->last_ok = 1; + } else { + sprintf(pg_result, + "xmit_mode -:%s:- unknown\nAvailable modes: %s", + f, "start_xmit, netif_receive\n"); + return count; + } + sprintf(pg_result, "OK: xmit_mode=%s", f); + return count; + } + if (!strcmp(name, "flag")) { + __u32 flag; + char f[32]; + bool disable = false; + + memset(f, 0, 32); + len = strn_len(&user_buffer[i], sizeof(f) - 1); + if (len < 0) + return len; + + if (copy_from_user(f, &user_buffer[i], len)) + return -EFAULT; + i += len; + + flag = pktgen_read_flag(f, &disable); + + if (flag) { + if (disable) + pkt_dev->flags &= ~flag; + else + pkt_dev->flags |= flag; + } else { + sprintf(pg_result, + "Flag -:%s:- unknown\nAvailable flags, (prepend ! 
to un-set flag):\n%s", + f, + "IPSRC_RND, IPDST_RND, UDPSRC_RND, UDPDST_RND, " + "MACSRC_RND, MACDST_RND, TXSIZE_RND, IPV6, " + "MPLS_RND, VID_RND, SVID_RND, FLOW_SEQ, " + "QUEUE_MAP_RND, QUEUE_MAP_CPU, UDPCSUM, " + "NO_TIMESTAMP, " +#ifdef CONFIG_XFRM + "IPSEC, " +#endif + "NODE_ALLOC\n"); + return count; + } + sprintf(pg_result, "OK: flags=0x%x", pkt_dev->flags); + return count; + } + if (!strcmp(name, "dst_min") || !strcmp(name, "dst")) { + len = strn_len(&user_buffer[i], sizeof(pkt_dev->dst_min) - 1); + if (len < 0) + return len; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + if (strcmp(buf, pkt_dev->dst_min) != 0) { + memset(pkt_dev->dst_min, 0, sizeof(pkt_dev->dst_min)); + strcpy(pkt_dev->dst_min, buf); + pkt_dev->daddr_min = in_aton(pkt_dev->dst_min); + pkt_dev->cur_daddr = pkt_dev->daddr_min; + } + if (debug) + pr_debug("dst_min set to: %s\n", pkt_dev->dst_min); + i += len; + sprintf(pg_result, "OK: dst_min=%s", pkt_dev->dst_min); + return count; + } + if (!strcmp(name, "dst_max")) { + len = strn_len(&user_buffer[i], sizeof(pkt_dev->dst_max) - 1); + if (len < 0) + return len; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + if (strcmp(buf, pkt_dev->dst_max) != 0) { + memset(pkt_dev->dst_max, 0, sizeof(pkt_dev->dst_max)); + strcpy(pkt_dev->dst_max, buf); + pkt_dev->daddr_max = in_aton(pkt_dev->dst_max); + pkt_dev->cur_daddr = pkt_dev->daddr_max; + } + if (debug) + pr_debug("dst_max set to: %s\n", pkt_dev->dst_max); + i += len; + sprintf(pg_result, "OK: dst_max=%s", pkt_dev->dst_max); + return count; + } + if (!strcmp(name, "dst6")) { + len = strn_len(&user_buffer[i], sizeof(buf) - 1); + if (len < 0) + return len; + + pkt_dev->flags |= F_IPV6; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + + in6_pton(buf, -1, pkt_dev->in6_daddr.s6_addr, -1, NULL); + snprintf(buf, sizeof(buf), "%pI6c", &pkt_dev->in6_daddr); + + pkt_dev->cur_in6_daddr = pkt_dev->in6_daddr; + + if (debug) + pr_debug("dst6 set to: %s\n", buf); + + i += len; + sprintf(pg_result, "OK: dst6=%s", buf); + return count; + } + if (!strcmp(name, "dst6_min")) { + len = strn_len(&user_buffer[i], sizeof(buf) - 1); + if (len < 0) + return len; + + pkt_dev->flags |= F_IPV6; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + + in6_pton(buf, -1, pkt_dev->min_in6_daddr.s6_addr, -1, NULL); + snprintf(buf, sizeof(buf), "%pI6c", &pkt_dev->min_in6_daddr); + + pkt_dev->cur_in6_daddr = pkt_dev->min_in6_daddr; + if (debug) + pr_debug("dst6_min set to: %s\n", buf); + + i += len; + sprintf(pg_result, "OK: dst6_min=%s", buf); + return count; + } + if (!strcmp(name, "dst6_max")) { + len = strn_len(&user_buffer[i], sizeof(buf) - 1); + if (len < 0) + return len; + + pkt_dev->flags |= F_IPV6; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + + in6_pton(buf, -1, pkt_dev->max_in6_daddr.s6_addr, -1, NULL); + snprintf(buf, sizeof(buf), "%pI6c", &pkt_dev->max_in6_daddr); + + if (debug) + pr_debug("dst6_max set to: %s\n", buf); + + i += len; + sprintf(pg_result, "OK: dst6_max=%s", buf); + return count; + } + if (!strcmp(name, "src6")) { + len = strn_len(&user_buffer[i], sizeof(buf) - 1); + if (len < 0) + return len; + + pkt_dev->flags |= F_IPV6; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + + in6_pton(buf, -1, pkt_dev->in6_saddr.s6_addr, -1, NULL); + snprintf(buf, sizeof(buf), "%pI6c", &pkt_dev->in6_saddr); + + 
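+ /* dst6/dst6_min/dst6_max/src6 all take a textual IPv6 address parsed
+  * with in6_pton(), e.g. "dst6 fec0::1" (address value illustrative).
+  */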
pkt_dev->cur_in6_saddr = pkt_dev->in6_saddr; + + if (debug) + pr_debug("src6 set to: %s\n", buf); + + i += len; + sprintf(pg_result, "OK: src6=%s", buf); + return count; + } + if (!strcmp(name, "src_min")) { + len = strn_len(&user_buffer[i], sizeof(pkt_dev->src_min) - 1); + if (len < 0) + return len; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + if (strcmp(buf, pkt_dev->src_min) != 0) { + memset(pkt_dev->src_min, 0, sizeof(pkt_dev->src_min)); + strcpy(pkt_dev->src_min, buf); + pkt_dev->saddr_min = in_aton(pkt_dev->src_min); + pkt_dev->cur_saddr = pkt_dev->saddr_min; + } + if (debug) + pr_debug("src_min set to: %s\n", pkt_dev->src_min); + i += len; + sprintf(pg_result, "OK: src_min=%s", pkt_dev->src_min); + return count; + } + if (!strcmp(name, "src_max")) { + len = strn_len(&user_buffer[i], sizeof(pkt_dev->src_max) - 1); + if (len < 0) + return len; + + if (copy_from_user(buf, &user_buffer[i], len)) + return -EFAULT; + buf[len] = 0; + if (strcmp(buf, pkt_dev->src_max) != 0) { + memset(pkt_dev->src_max, 0, sizeof(pkt_dev->src_max)); + strcpy(pkt_dev->src_max, buf); + pkt_dev->saddr_max = in_aton(pkt_dev->src_max); + pkt_dev->cur_saddr = pkt_dev->saddr_max; + } + if (debug) + pr_debug("src_max set to: %s\n", pkt_dev->src_max); + i += len; + sprintf(pg_result, "OK: src_max=%s", pkt_dev->src_max); + return count; + } + if (!strcmp(name, "dst_mac")) { + len = strn_len(&user_buffer[i], sizeof(valstr) - 1); + if (len < 0) + return len; + + memset(valstr, 0, sizeof(valstr)); + if (copy_from_user(valstr, &user_buffer[i], len)) + return -EFAULT; + + if (!mac_pton(valstr, pkt_dev->dst_mac)) + return -EINVAL; + /* Set up Dest MAC */ + ether_addr_copy(&pkt_dev->hh[0], pkt_dev->dst_mac); + + sprintf(pg_result, "OK: dstmac %pM", pkt_dev->dst_mac); + return count; + } + if (!strcmp(name, "src_mac")) { + len = strn_len(&user_buffer[i], sizeof(valstr) - 1); + if (len < 0) + return len; + + memset(valstr, 0, sizeof(valstr)); + if (copy_from_user(valstr, &user_buffer[i], len)) + return -EFAULT; + + if (!mac_pton(valstr, pkt_dev->src_mac)) + return -EINVAL; + /* Set up Src MAC */ + ether_addr_copy(&pkt_dev->hh[6], pkt_dev->src_mac); + + sprintf(pg_result, "OK: srcmac %pM", pkt_dev->src_mac); + return count; + } + + if (!strcmp(name, "clear_counters")) { + pktgen_clear_counters(pkt_dev); + sprintf(pg_result, "OK: Clearing counters.\n"); + return count; + } + + if (!strcmp(name, "flows")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + if (value > MAX_CFLOWS) + value = MAX_CFLOWS; + + pkt_dev->cflows = value; + sprintf(pg_result, "OK: flows=%u", pkt_dev->cflows); + return count; + } +#ifdef CONFIG_XFRM + if (!strcmp(name, "spi")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->spi = value; + sprintf(pg_result, "OK: spi=%u", pkt_dev->spi); + return count; + } +#endif + if (!strcmp(name, "flowlen")) { + len = num_arg(&user_buffer[i], 10, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->lflow = value; + sprintf(pg_result, "OK: flowlen=%u", pkt_dev->lflow); + return count; + } + + if (!strcmp(name, "queue_map_min")) { + len = num_arg(&user_buffer[i], 5, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->queue_map_min = value; + sprintf(pg_result, "OK: queue_map_min=%u", pkt_dev->queue_map_min); + return count; + } + + if (!strcmp(name, "queue_map_max")) { + len = num_arg(&user_buffer[i], 5, &value); + if (len < 0) + return len; + + i += len; + 
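+ /* queue_map_min/queue_map_max bound the TX queue range used per packet;
+  * within that range selection is sequential unless the QUEUE_MAP_RND or
+  * QUEUE_MAP_CPU flag overrides it (see set_cur_queue_map() further down).
+  */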
pkt_dev->queue_map_max = value; + sprintf(pg_result, "OK: queue_map_max=%u", pkt_dev->queue_map_max); + return count; + } + + if (!strcmp(name, "mpls")) { + unsigned int n, cnt; + + len = get_labels(&user_buffer[i], pkt_dev); + if (len < 0) + return len; + i += len; + cnt = sprintf(pg_result, "OK: mpls="); + for (n = 0; n < pkt_dev->nr_labels; n++) + cnt += sprintf(pg_result + cnt, + "%08x%s", ntohl(pkt_dev->labels[n]), + n == pkt_dev->nr_labels-1 ? "" : ","); + + if (pkt_dev->nr_labels && pkt_dev->vlan_id != 0xffff) { + pkt_dev->vlan_id = 0xffff; /* turn off VLAN/SVLAN */ + pkt_dev->svlan_id = 0xffff; + + if (debug) + pr_debug("VLAN/SVLAN auto turned off\n"); + } + return count; + } + + if (!strcmp(name, "vlan_id")) { + len = num_arg(&user_buffer[i], 4, &value); + if (len < 0) + return len; + + i += len; + if (value <= 4095) { + pkt_dev->vlan_id = value; /* turn on VLAN */ + + if (debug) + pr_debug("VLAN turned on\n"); + + if (debug && pkt_dev->nr_labels) + pr_debug("MPLS auto turned off\n"); + + pkt_dev->nr_labels = 0; /* turn off MPLS */ + sprintf(pg_result, "OK: vlan_id=%u", pkt_dev->vlan_id); + } else { + pkt_dev->vlan_id = 0xffff; /* turn off VLAN/SVLAN */ + pkt_dev->svlan_id = 0xffff; + + if (debug) + pr_debug("VLAN/SVLAN turned off\n"); + } + return count; + } + + if (!strcmp(name, "vlan_p")) { + len = num_arg(&user_buffer[i], 1, &value); + if (len < 0) + return len; + + i += len; + if ((value <= 7) && (pkt_dev->vlan_id != 0xffff)) { + pkt_dev->vlan_p = value; + sprintf(pg_result, "OK: vlan_p=%u", pkt_dev->vlan_p); + } else { + sprintf(pg_result, "ERROR: vlan_p must be 0-7"); + } + return count; + } + + if (!strcmp(name, "vlan_cfi")) { + len = num_arg(&user_buffer[i], 1, &value); + if (len < 0) + return len; + + i += len; + if ((value <= 1) && (pkt_dev->vlan_id != 0xffff)) { + pkt_dev->vlan_cfi = value; + sprintf(pg_result, "OK: vlan_cfi=%u", pkt_dev->vlan_cfi); + } else { + sprintf(pg_result, "ERROR: vlan_cfi must be 0-1"); + } + return count; + } + + if (!strcmp(name, "svlan_id")) { + len = num_arg(&user_buffer[i], 4, &value); + if (len < 0) + return len; + + i += len; + if ((value <= 4095) && ((pkt_dev->vlan_id != 0xffff))) { + pkt_dev->svlan_id = value; /* turn on SVLAN */ + + if (debug) + pr_debug("SVLAN turned on\n"); + + if (debug && pkt_dev->nr_labels) + pr_debug("MPLS auto turned off\n"); + + pkt_dev->nr_labels = 0; /* turn off MPLS */ + sprintf(pg_result, "OK: svlan_id=%u", pkt_dev->svlan_id); + } else { + pkt_dev->vlan_id = 0xffff; /* turn off VLAN/SVLAN */ + pkt_dev->svlan_id = 0xffff; + + if (debug) + pr_debug("VLAN/SVLAN turned off\n"); + } + return count; + } + + if (!strcmp(name, "svlan_p")) { + len = num_arg(&user_buffer[i], 1, &value); + if (len < 0) + return len; + + i += len; + if ((value <= 7) && (pkt_dev->svlan_id != 0xffff)) { + pkt_dev->svlan_p = value; + sprintf(pg_result, "OK: svlan_p=%u", pkt_dev->svlan_p); + } else { + sprintf(pg_result, "ERROR: svlan_p must be 0-7"); + } + return count; + } + + if (!strcmp(name, "svlan_cfi")) { + len = num_arg(&user_buffer[i], 1, &value); + if (len < 0) + return len; + + i += len; + if ((value <= 1) && (pkt_dev->svlan_id != 0xffff)) { + pkt_dev->svlan_cfi = value; + sprintf(pg_result, "OK: svlan_cfi=%u", pkt_dev->svlan_cfi); + } else { + sprintf(pg_result, "ERROR: svlan_cfi must be 0-1"); + } + return count; + } + + if (!strcmp(name, "tos")) { + __u32 tmp_value = 0; + len = hex32_arg(&user_buffer[i], 2, &tmp_value); + if (len < 0) + return len; + + i += len; + if (len == 2) { + pkt_dev->tos = tmp_value; + 
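+ /* tos/traffic_class expect exactly two hex digits; the upper six bits
+  * carry the DSCP, the lower two the ECN field.  E.g. DSCP EF (46) is
+  * written as "tos b8" since 46 << 2 == 0xb8 (example, not a default).
+  */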
sprintf(pg_result, "OK: tos=0x%02x", pkt_dev->tos); + } else { + sprintf(pg_result, "ERROR: tos must be 00-ff"); + } + return count; + } + + if (!strcmp(name, "traffic_class")) { + __u32 tmp_value = 0; + len = hex32_arg(&user_buffer[i], 2, &tmp_value); + if (len < 0) + return len; + + i += len; + if (len == 2) { + pkt_dev->traffic_class = tmp_value; + sprintf(pg_result, "OK: traffic_class=0x%02x", pkt_dev->traffic_class); + } else { + sprintf(pg_result, "ERROR: traffic_class must be 00-ff"); + } + return count; + } + + if (!strcmp(name, "skb_priority")) { + len = num_arg(&user_buffer[i], 9, &value); + if (len < 0) + return len; + + i += len; + pkt_dev->skb_priority = value; + sprintf(pg_result, "OK: skb_priority=%i", + pkt_dev->skb_priority); + return count; + } + + sprintf(pkt_dev->result, "No such parameter \"%s\"", name); + return -EINVAL; +} + +static int pktgen_if_open(struct inode *inode, struct file *file) +{ + return single_open(file, pktgen_if_show, pde_data(inode)); +} + +static const struct proc_ops pktgen_if_proc_ops = { + .proc_open = pktgen_if_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_write = pktgen_if_write, + .proc_release = single_release, +}; + +static int pktgen_thread_show(struct seq_file *seq, void *v) +{ + struct pktgen_thread *t = seq->private; + const struct pktgen_dev *pkt_dev; + + BUG_ON(!t); + + seq_puts(seq, "Running: "); + + rcu_read_lock(); + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) + if (pkt_dev->running) + seq_printf(seq, "%s ", pkt_dev->odevname); + + seq_puts(seq, "\nStopped: "); + + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) + if (!pkt_dev->running) + seq_printf(seq, "%s ", pkt_dev->odevname); + + if (t->result[0]) + seq_printf(seq, "\nResult: %s\n", t->result); + else + seq_puts(seq, "\nResult: NA\n"); + + rcu_read_unlock(); + + return 0; +} + +static ssize_t pktgen_thread_write(struct file *file, + const char __user * user_buffer, + size_t count, loff_t * offset) +{ + struct seq_file *seq = file->private_data; + struct pktgen_thread *t = seq->private; + int i, max, len, ret; + char name[40]; + char *pg_result; + + if (count < 1) { + // sprintf(pg_result, "Wrong command format"); + return -EINVAL; + } + + max = count; + len = count_trail_chars(user_buffer, max); + if (len < 0) + return len; + + i = len; + + /* Read variable name */ + + len = strn_len(&user_buffer[i], sizeof(name) - 1); + if (len < 0) + return len; + + memset(name, 0, sizeof(name)); + if (copy_from_user(name, &user_buffer[i], len)) + return -EFAULT; + i += len; + + max = count - i; + len = count_trail_chars(&user_buffer[i], max); + if (len < 0) + return len; + + i += len; + + if (debug) + pr_debug("t=%s, count=%lu\n", name, (unsigned long)count); + + if (!t) { + pr_err("ERROR: No thread\n"); + ret = -EINVAL; + goto out; + } + + pg_result = &(t->result[0]); + + if (!strcmp(name, "add_device")) { + char f[32]; + memset(f, 0, 32); + len = strn_len(&user_buffer[i], sizeof(f) - 1); + if (len < 0) { + ret = len; + goto out; + } + if (copy_from_user(f, &user_buffer[i], len)) + return -EFAULT; + i += len; + mutex_lock(&pktgen_thread_lock); + ret = pktgen_add_device(t, f); + mutex_unlock(&pktgen_thread_lock); + if (!ret) { + ret = count; + sprintf(pg_result, "OK: add_device=%s", f); + } else + sprintf(pg_result, "ERROR: can not add device %s", f); + goto out; + } + + if (!strcmp(name, "rem_device_all")) { + mutex_lock(&pktgen_thread_lock); + t->control |= T_REMDEVALL; + mutex_unlock(&pktgen_thread_lock); + 
schedule_timeout_interruptible(msecs_to_jiffies(125)); /* Propagate thread->control */ + ret = count; + sprintf(pg_result, "OK: rem_device_all"); + goto out; + } + + if (!strcmp(name, "max_before_softirq")) { + sprintf(pg_result, "OK: Note! max_before_softirq is obsoleted -- Do not use"); + ret = count; + goto out; + } + + ret = -EINVAL; +out: + return ret; +} + +static int pktgen_thread_open(struct inode *inode, struct file *file) +{ + return single_open(file, pktgen_thread_show, pde_data(inode)); +} + +static const struct proc_ops pktgen_thread_proc_ops = { + .proc_open = pktgen_thread_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_write = pktgen_thread_write, + .proc_release = single_release, +}; + +/* Think find or remove for NN */ +static struct pktgen_dev *__pktgen_NN_threads(const struct pktgen_net *pn, + const char *ifname, int remove) +{ + struct pktgen_thread *t; + struct pktgen_dev *pkt_dev = NULL; + bool exact = (remove == FIND); + + list_for_each_entry(t, &pn->pktgen_threads, th_list) { + pkt_dev = pktgen_find_dev(t, ifname, exact); + if (pkt_dev) { + if (remove) { + pkt_dev->removal_mark = 1; + t->control |= T_REMDEV; + } + break; + } + } + return pkt_dev; +} + +/* + * mark a device for removal + */ +static void pktgen_mark_device(const struct pktgen_net *pn, const char *ifname) +{ + struct pktgen_dev *pkt_dev = NULL; + const int max_tries = 10, msec_per_try = 125; + int i = 0; + + mutex_lock(&pktgen_thread_lock); + pr_debug("%s: marking %s for removal\n", __func__, ifname); + + while (1) { + + pkt_dev = __pktgen_NN_threads(pn, ifname, REMOVE); + if (pkt_dev == NULL) + break; /* success */ + + mutex_unlock(&pktgen_thread_lock); + pr_debug("%s: waiting for %s to disappear....\n", + __func__, ifname); + schedule_timeout_interruptible(msecs_to_jiffies(msec_per_try)); + mutex_lock(&pktgen_thread_lock); + + if (++i >= max_tries) { + pr_err("%s: timed out after waiting %d msec for device %s to be removed\n", + __func__, msec_per_try * i, ifname); + break; + } + + } + + mutex_unlock(&pktgen_thread_lock); +} + +static void pktgen_change_name(const struct pktgen_net *pn, struct net_device *dev) +{ + struct pktgen_thread *t; + + mutex_lock(&pktgen_thread_lock); + + list_for_each_entry(t, &pn->pktgen_threads, th_list) { + struct pktgen_dev *pkt_dev; + + if_lock(t); + list_for_each_entry(pkt_dev, &t->if_list, list) { + if (pkt_dev->odev != dev) + continue; + + proc_remove(pkt_dev->entry); + + pkt_dev->entry = proc_create_data(dev->name, 0600, + pn->proc_dir, + &pktgen_if_proc_ops, + pkt_dev); + if (!pkt_dev->entry) + pr_err("can't move proc entry for '%s'\n", + dev->name); + break; + } + if_unlock(t); + } + mutex_unlock(&pktgen_thread_lock); +} + +static int pktgen_device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct pktgen_net *pn = net_generic(dev_net(dev), pg_net_id); + + if (pn->pktgen_exiting) + return NOTIFY_DONE; + + /* It is OK that we do not hold the group lock right now, + * as we run under the RTNL lock. 
+ */ + + switch (event) { + case NETDEV_CHANGENAME: + pktgen_change_name(pn, dev); + break; + + case NETDEV_UNREGISTER: + pktgen_mark_device(pn, dev->name); + break; + } + + return NOTIFY_DONE; +} + +static struct net_device *pktgen_dev_get_by_name(const struct pktgen_net *pn, + struct pktgen_dev *pkt_dev, + const char *ifname) +{ + char b[IFNAMSIZ+5]; + int i; + + for (i = 0; ifname[i] != '@'; i++) { + if (i == IFNAMSIZ) + break; + + b[i] = ifname[i]; + } + b[i] = 0; + + return dev_get_by_name(pn->net, b); +} + + +/* Associate pktgen_dev with a device. */ + +static int pktgen_setup_dev(const struct pktgen_net *pn, + struct pktgen_dev *pkt_dev, const char *ifname) +{ + struct net_device *odev; + int err; + + /* Clean old setups */ + if (pkt_dev->odev) { + netdev_put(pkt_dev->odev, &pkt_dev->dev_tracker); + pkt_dev->odev = NULL; + } + + odev = pktgen_dev_get_by_name(pn, pkt_dev, ifname); + if (!odev) { + pr_err("no such netdevice: \"%s\"\n", ifname); + return -ENODEV; + } + + if (odev->type != ARPHRD_ETHER && odev->type != ARPHRD_LOOPBACK) { + pr_err("not an ethernet or loopback device: \"%s\"\n", ifname); + err = -EINVAL; + } else if (!netif_running(odev)) { + pr_err("device is down: \"%s\"\n", ifname); + err = -ENETDOWN; + } else { + pkt_dev->odev = odev; + netdev_tracker_alloc(odev, &pkt_dev->dev_tracker, GFP_KERNEL); + return 0; + } + + dev_put(odev); + return err; +} + +/* Read pkt_dev from the interface and set up internal pktgen_dev + * structure to have the right information to create/send packets + */ +static void pktgen_setup_inject(struct pktgen_dev *pkt_dev) +{ + int ntxq; + + if (!pkt_dev->odev) { + pr_err("ERROR: pkt_dev->odev == NULL in setup_inject\n"); + sprintf(pkt_dev->result, + "ERROR: pkt_dev->odev == NULL in setup_inject.\n"); + return; + } + + /* make sure that we don't pick a non-existing transmit queue */ + ntxq = pkt_dev->odev->real_num_tx_queues; + + if (ntxq <= pkt_dev->queue_map_min) { + pr_warn("WARNING: Requested queue_map_min (zero-based) (%d) exceeds valid range [0 - %d] for (%d) queues on %s, resetting\n", + pkt_dev->queue_map_min, (ntxq ?: 1) - 1, ntxq, + pkt_dev->odevname); + pkt_dev->queue_map_min = (ntxq ?: 1) - 1; + } + if (pkt_dev->queue_map_max >= ntxq) { + pr_warn("WARNING: Requested queue_map_max (zero-based) (%d) exceeds valid range [0 - %d] for (%d) queues on %s, resetting\n", + pkt_dev->queue_map_max, (ntxq ?: 1) - 1, ntxq, + pkt_dev->odevname); + pkt_dev->queue_map_max = (ntxq ?: 1) - 1; + } + + /* Default to the interface's mac if not explicitly set. */ + + if (is_zero_ether_addr(pkt_dev->src_mac)) + ether_addr_copy(&(pkt_dev->hh[6]), pkt_dev->odev->dev_addr); + + /* Set up Dest MAC */ + ether_addr_copy(&(pkt_dev->hh[0]), pkt_dev->dst_mac); + + if (pkt_dev->flags & F_IPV6) { + int i, set = 0, err = 1; + struct inet6_dev *idev; + + if (pkt_dev->min_pkt_size == 0) { + pkt_dev->min_pkt_size = 14 + sizeof(struct ipv6hdr) + + sizeof(struct udphdr) + + sizeof(struct pktgen_hdr) + + pkt_dev->pkt_overhead; + } + + for (i = 0; i < sizeof(struct in6_addr); i++) + if (pkt_dev->cur_in6_saddr.s6_addr[i]) { + set = 1; + break; + } + + if (!set) { + + /* + * Use linklevel address if unconfigured. 
+ * + * use ipv6_get_lladdr if/when it's get exported + */ + + rcu_read_lock(); + idev = __in6_dev_get(pkt_dev->odev); + if (idev) { + struct inet6_ifaddr *ifp; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifp, &idev->addr_list, if_list) { + if ((ifp->scope & IFA_LINK) && + !(ifp->flags & IFA_F_TENTATIVE)) { + pkt_dev->cur_in6_saddr = ifp->addr; + err = 0; + break; + } + } + read_unlock_bh(&idev->lock); + } + rcu_read_unlock(); + if (err) + pr_err("ERROR: IPv6 link address not available\n"); + } + } else { + if (pkt_dev->min_pkt_size == 0) { + pkt_dev->min_pkt_size = 14 + sizeof(struct iphdr) + + sizeof(struct udphdr) + + sizeof(struct pktgen_hdr) + + pkt_dev->pkt_overhead; + } + + pkt_dev->saddr_min = 0; + pkt_dev->saddr_max = 0; + if (strlen(pkt_dev->src_min) == 0) { + + struct in_device *in_dev; + + rcu_read_lock(); + in_dev = __in_dev_get_rcu(pkt_dev->odev); + if (in_dev) { + const struct in_ifaddr *ifa; + + ifa = rcu_dereference(in_dev->ifa_list); + if (ifa) { + pkt_dev->saddr_min = ifa->ifa_address; + pkt_dev->saddr_max = pkt_dev->saddr_min; + } + } + rcu_read_unlock(); + } else { + pkt_dev->saddr_min = in_aton(pkt_dev->src_min); + pkt_dev->saddr_max = in_aton(pkt_dev->src_max); + } + + pkt_dev->daddr_min = in_aton(pkt_dev->dst_min); + pkt_dev->daddr_max = in_aton(pkt_dev->dst_max); + } + /* Initialize current values. */ + pkt_dev->cur_pkt_size = pkt_dev->min_pkt_size; + if (pkt_dev->min_pkt_size > pkt_dev->max_pkt_size) + pkt_dev->max_pkt_size = pkt_dev->min_pkt_size; + + pkt_dev->cur_dst_mac_offset = 0; + pkt_dev->cur_src_mac_offset = 0; + pkt_dev->cur_saddr = pkt_dev->saddr_min; + pkt_dev->cur_daddr = pkt_dev->daddr_min; + pkt_dev->cur_udp_dst = pkt_dev->udp_dst_min; + pkt_dev->cur_udp_src = pkt_dev->udp_src_min; + pkt_dev->nflows = 0; +} + + +static void spin(struct pktgen_dev *pkt_dev, ktime_t spin_until) +{ + ktime_t start_time, end_time; + s64 remaining; + struct hrtimer_sleeper t; + + hrtimer_init_sleeper_on_stack(&t, CLOCK_MONOTONIC, HRTIMER_MODE_ABS); + hrtimer_set_expires(&t.timer, spin_until); + + remaining = ktime_to_ns(hrtimer_expires_remaining(&t.timer)); + if (remaining <= 0) + goto out; + + start_time = ktime_get(); + if (remaining < 100000) { + /* for small delays (<100us), just loop until limit is reached */ + do { + end_time = ktime_get(); + } while (ktime_compare(end_time, spin_until) < 0); + } else { + do { + set_current_state(TASK_INTERRUPTIBLE); + hrtimer_sleeper_start_expires(&t, HRTIMER_MODE_ABS); + + if (likely(t.task)) + schedule(); + + hrtimer_cancel(&t.timer); + } while (t.task && pkt_dev->running && !signal_pending(current)); + __set_current_state(TASK_RUNNING); + end_time = ktime_get(); + } + + pkt_dev->idle_acc += ktime_to_ns(ktime_sub(end_time, start_time)); +out: + pkt_dev->next_tx = ktime_add_ns(spin_until, pkt_dev->delay); + destroy_hrtimer_on_stack(&t.timer); +} + +static inline void set_pkt_overhead(struct pktgen_dev *pkt_dev) +{ + pkt_dev->pkt_overhead = 0; + pkt_dev->pkt_overhead += pkt_dev->nr_labels*sizeof(u32); + pkt_dev->pkt_overhead += VLAN_TAG_SIZE(pkt_dev); + pkt_dev->pkt_overhead += SVLAN_TAG_SIZE(pkt_dev); +} + +static inline int f_seen(const struct pktgen_dev *pkt_dev, int flow) +{ + return !!(pkt_dev->flows[flow].flags & F_INIT); +} + +static inline int f_pick(struct pktgen_dev *pkt_dev) +{ + int flow = pkt_dev->curfl; + + if (pkt_dev->flags & F_FLOW_SEQ) { + if (pkt_dev->flows[flow].count >= pkt_dev->lflow) { + /* reset time */ + pkt_dev->flows[flow].count = 0; + pkt_dev->flows[flow].flags = 0; + pkt_dev->curfl += 1; + if 
(pkt_dev->curfl >= pkt_dev->cflows) + pkt_dev->curfl = 0; /*reset */ + } + } else { + flow = prandom_u32_max(pkt_dev->cflows); + pkt_dev->curfl = flow; + + if (pkt_dev->flows[flow].count > pkt_dev->lflow) { + pkt_dev->flows[flow].count = 0; + pkt_dev->flows[flow].flags = 0; + } + } + + return pkt_dev->curfl; +} + + +#ifdef CONFIG_XFRM +/* If there was already an IPSEC SA, we keep it as is, else + * we go look for it ... +*/ +#define DUMMY_MARK 0 +static void get_ipsec_sa(struct pktgen_dev *pkt_dev, int flow) +{ + struct xfrm_state *x = pkt_dev->flows[flow].x; + struct pktgen_net *pn = net_generic(dev_net(pkt_dev->odev), pg_net_id); + if (!x) { + + if (pkt_dev->spi) { + /* We need as quick as possible to find the right SA + * Searching with minimum criteria to archieve this. + */ + x = xfrm_state_lookup_byspi(pn->net, htonl(pkt_dev->spi), AF_INET); + } else { + /* slow path: we dont already have xfrm_state */ + x = xfrm_stateonly_find(pn->net, DUMMY_MARK, 0, + (xfrm_address_t *)&pkt_dev->cur_daddr, + (xfrm_address_t *)&pkt_dev->cur_saddr, + AF_INET, + pkt_dev->ipsmode, + pkt_dev->ipsproto, 0); + } + if (x) { + pkt_dev->flows[flow].x = x; + set_pkt_overhead(pkt_dev); + pkt_dev->pkt_overhead += x->props.header_len; + } + + } +} +#endif +static void set_cur_queue_map(struct pktgen_dev *pkt_dev) +{ + + if (pkt_dev->flags & F_QUEUE_MAP_CPU) + pkt_dev->cur_queue_map = smp_processor_id(); + + else if (pkt_dev->queue_map_min <= pkt_dev->queue_map_max) { + __u16 t; + if (pkt_dev->flags & F_QUEUE_MAP_RND) { + t = prandom_u32_max(pkt_dev->queue_map_max - + pkt_dev->queue_map_min + 1) + + pkt_dev->queue_map_min; + } else { + t = pkt_dev->cur_queue_map + 1; + if (t > pkt_dev->queue_map_max) + t = pkt_dev->queue_map_min; + } + pkt_dev->cur_queue_map = t; + } + pkt_dev->cur_queue_map = pkt_dev->cur_queue_map % pkt_dev->odev->real_num_tx_queues; +} + +/* Increment/randomize headers according to flags and current values + * for IP src/dest, UDP src/dst port, MAC-Addr src/dst + */ +static void mod_cur_headers(struct pktgen_dev *pkt_dev) +{ + __u32 imn; + __u32 imx; + int flow = 0; + + if (pkt_dev->cflows) + flow = f_pick(pkt_dev); + + /* Deal with source MAC */ + if (pkt_dev->src_mac_count > 1) { + __u32 mc; + __u32 tmp; + + if (pkt_dev->flags & F_MACSRC_RND) + mc = prandom_u32_max(pkt_dev->src_mac_count); + else { + mc = pkt_dev->cur_src_mac_offset++; + if (pkt_dev->cur_src_mac_offset >= + pkt_dev->src_mac_count) + pkt_dev->cur_src_mac_offset = 0; + } + + tmp = pkt_dev->src_mac[5] + (mc & 0xFF); + pkt_dev->hh[11] = tmp; + tmp = (pkt_dev->src_mac[4] + ((mc >> 8) & 0xFF) + (tmp >> 8)); + pkt_dev->hh[10] = tmp; + tmp = (pkt_dev->src_mac[3] + ((mc >> 16) & 0xFF) + (tmp >> 8)); + pkt_dev->hh[9] = tmp; + tmp = (pkt_dev->src_mac[2] + ((mc >> 24) & 0xFF) + (tmp >> 8)); + pkt_dev->hh[8] = tmp; + tmp = (pkt_dev->src_mac[1] + (tmp >> 8)); + pkt_dev->hh[7] = tmp; + } + + /* Deal with Destination MAC */ + if (pkt_dev->dst_mac_count > 1) { + __u32 mc; + __u32 tmp; + + if (pkt_dev->flags & F_MACDST_RND) + mc = prandom_u32_max(pkt_dev->dst_mac_count); + + else { + mc = pkt_dev->cur_dst_mac_offset++; + if (pkt_dev->cur_dst_mac_offset >= + pkt_dev->dst_mac_count) { + pkt_dev->cur_dst_mac_offset = 0; + } + } + + tmp = pkt_dev->dst_mac[5] + (mc & 0xFF); + pkt_dev->hh[5] = tmp; + tmp = (pkt_dev->dst_mac[4] + ((mc >> 8) & 0xFF) + (tmp >> 8)); + pkt_dev->hh[4] = tmp; + tmp = (pkt_dev->dst_mac[3] + ((mc >> 16) & 0xFF) + (tmp >> 8)); + pkt_dev->hh[3] = tmp; + tmp = (pkt_dev->dst_mac[2] + ((mc >> 24) & 0xFF) + (tmp >> 8)); + 
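+ /* These additions ripple a carry through the destination MAC byte by
+  * byte: e.g. a base dst_mac ending in ...:00:ff with offset mc=1 is
+  * written to hh[] as ...:01:00.
+  */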
pkt_dev->hh[2] = tmp; + tmp = (pkt_dev->dst_mac[1] + (tmp >> 8)); + pkt_dev->hh[1] = tmp; + } + + if (pkt_dev->flags & F_MPLS_RND) { + unsigned int i; + for (i = 0; i < pkt_dev->nr_labels; i++) + if (pkt_dev->labels[i] & MPLS_STACK_BOTTOM) + pkt_dev->labels[i] = MPLS_STACK_BOTTOM | + ((__force __be32)get_random_u32() & + htonl(0x000fffff)); + } + + if ((pkt_dev->flags & F_VID_RND) && (pkt_dev->vlan_id != 0xffff)) { + pkt_dev->vlan_id = prandom_u32_max(4096); + } + + if ((pkt_dev->flags & F_SVID_RND) && (pkt_dev->svlan_id != 0xffff)) { + pkt_dev->svlan_id = prandom_u32_max(4096); + } + + if (pkt_dev->udp_src_min < pkt_dev->udp_src_max) { + if (pkt_dev->flags & F_UDPSRC_RND) + pkt_dev->cur_udp_src = prandom_u32_max( + pkt_dev->udp_src_max - pkt_dev->udp_src_min) + + pkt_dev->udp_src_min; + + else { + pkt_dev->cur_udp_src++; + if (pkt_dev->cur_udp_src >= pkt_dev->udp_src_max) + pkt_dev->cur_udp_src = pkt_dev->udp_src_min; + } + } + + if (pkt_dev->udp_dst_min < pkt_dev->udp_dst_max) { + if (pkt_dev->flags & F_UDPDST_RND) { + pkt_dev->cur_udp_dst = prandom_u32_max( + pkt_dev->udp_dst_max - pkt_dev->udp_dst_min) + + pkt_dev->udp_dst_min; + } else { + pkt_dev->cur_udp_dst++; + if (pkt_dev->cur_udp_dst >= pkt_dev->udp_dst_max) + pkt_dev->cur_udp_dst = pkt_dev->udp_dst_min; + } + } + + if (!(pkt_dev->flags & F_IPV6)) { + + imn = ntohl(pkt_dev->saddr_min); + imx = ntohl(pkt_dev->saddr_max); + if (imn < imx) { + __u32 t; + if (pkt_dev->flags & F_IPSRC_RND) + t = prandom_u32_max(imx - imn) + imn; + else { + t = ntohl(pkt_dev->cur_saddr); + t++; + if (t > imx) + t = imn; + + } + pkt_dev->cur_saddr = htonl(t); + } + + if (pkt_dev->cflows && f_seen(pkt_dev, flow)) { + pkt_dev->cur_daddr = pkt_dev->flows[flow].cur_daddr; + } else { + imn = ntohl(pkt_dev->daddr_min); + imx = ntohl(pkt_dev->daddr_max); + if (imn < imx) { + __u32 t; + __be32 s; + if (pkt_dev->flags & F_IPDST_RND) { + + do { + t = prandom_u32_max(imx - imn) + + imn; + s = htonl(t); + } while (ipv4_is_loopback(s) || + ipv4_is_multicast(s) || + ipv4_is_lbcast(s) || + ipv4_is_zeronet(s) || + ipv4_is_local_multicast(s)); + pkt_dev->cur_daddr = s; + } else { + t = ntohl(pkt_dev->cur_daddr); + t++; + if (t > imx) { + t = imn; + } + pkt_dev->cur_daddr = htonl(t); + } + } + if (pkt_dev->cflows) { + pkt_dev->flows[flow].flags |= F_INIT; + pkt_dev->flows[flow].cur_daddr = + pkt_dev->cur_daddr; +#ifdef CONFIG_XFRM + if (pkt_dev->flags & F_IPSEC) + get_ipsec_sa(pkt_dev, flow); +#endif + pkt_dev->nflows++; + } + } + } else { /* IPV6 * */ + + if (!ipv6_addr_any(&pkt_dev->min_in6_daddr)) { + int i; + + /* Only random destinations yet */ + + for (i = 0; i < 4; i++) { + pkt_dev->cur_in6_daddr.s6_addr32[i] = + (((__force __be32)get_random_u32() | + pkt_dev->min_in6_daddr.s6_addr32[i]) & + pkt_dev->max_in6_daddr.s6_addr32[i]); + } + } + } + + if (pkt_dev->min_pkt_size < pkt_dev->max_pkt_size) { + __u32 t; + if (pkt_dev->flags & F_TXSIZE_RND) { + t = prandom_u32_max(pkt_dev->max_pkt_size - + pkt_dev->min_pkt_size) + + pkt_dev->min_pkt_size; + } else { + t = pkt_dev->cur_pkt_size + 1; + if (t > pkt_dev->max_pkt_size) + t = pkt_dev->min_pkt_size; + } + pkt_dev->cur_pkt_size = t; + } else if (pkt_dev->n_imix_entries > 0) { + struct imix_pkt *entry; + __u32 t = prandom_u32_max(IMIX_PRECISION); + __u8 entry_index = pkt_dev->imix_distribution[t]; + + entry = &pkt_dev->imix_entries[entry_index]; + entry->count_so_far++; + pkt_dev->cur_pkt_size = entry->size; + } + + set_cur_queue_map(pkt_dev); + + pkt_dev->flows[flow].count++; +} + +static void 
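+/* Builds the 0..IMIX_PRECISION-1 lookup table consumed above when picking
+ * a packet size.  With weights 7,4,1 (for example) the cumulative cut-offs
+ * land at 58, 91 and 100, so slots 0-57 map to entry 0, 58-90 to entry 1
+ * and 91-99 to entry 2, i.e. roughly the requested 7:4:1 ratio.
+ */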
fill_imix_distribution(struct pktgen_dev *pkt_dev) +{ + int cumulative_probabilites[MAX_IMIX_ENTRIES]; + int j = 0; + __u64 cumulative_prob = 0; + __u64 total_weight = 0; + int i = 0; + + for (i = 0; i < pkt_dev->n_imix_entries; i++) + total_weight += pkt_dev->imix_entries[i].weight; + + /* Fill cumulative_probabilites with sum of normalized probabilities */ + for (i = 0; i < pkt_dev->n_imix_entries - 1; i++) { + cumulative_prob += div64_u64(pkt_dev->imix_entries[i].weight * + IMIX_PRECISION, + total_weight); + cumulative_probabilites[i] = cumulative_prob; + } + cumulative_probabilites[pkt_dev->n_imix_entries - 1] = 100; + + for (i = 0; i < IMIX_PRECISION; i++) { + if (i == cumulative_probabilites[j]) + j++; + pkt_dev->imix_distribution[i] = j; + } +} + +#ifdef CONFIG_XFRM +static u32 pktgen_dst_metrics[RTAX_MAX + 1] = { + + [RTAX_HOPLIMIT] = 0x5, /* Set a static hoplimit */ +}; + +static int pktgen_output_ipsec(struct sk_buff *skb, struct pktgen_dev *pkt_dev) +{ + struct xfrm_state *x = pkt_dev->flows[pkt_dev->curfl].x; + int err = 0; + struct net *net = dev_net(pkt_dev->odev); + + if (!x) + return 0; + /* XXX: we dont support tunnel mode for now until + * we resolve the dst issue */ + if ((x->props.mode != XFRM_MODE_TRANSPORT) && (pkt_dev->spi == 0)) + return 0; + + /* But when user specify an valid SPI, transformation + * supports both transport/tunnel mode + ESP/AH type. + */ + if ((x->props.mode == XFRM_MODE_TUNNEL) && (pkt_dev->spi != 0)) + skb->_skb_refdst = (unsigned long)&pkt_dev->xdst.u.dst | SKB_DST_NOREF; + + rcu_read_lock_bh(); + err = pktgen_xfrm_outer_mode_output(x, skb); + rcu_read_unlock_bh(); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEMODEERROR); + goto error; + } + err = x->type->output(x, skb); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEPROTOERROR); + goto error; + } + spin_lock_bh(&x->lock); + x->curlft.bytes += skb->len; + x->curlft.packets++; + spin_unlock_bh(&x->lock); +error: + return err; +} + +static void free_SAs(struct pktgen_dev *pkt_dev) +{ + if (pkt_dev->cflows) { + /* let go of the SAs if we have them */ + int i; + for (i = 0; i < pkt_dev->cflows; i++) { + struct xfrm_state *x = pkt_dev->flows[i].x; + if (x) { + xfrm_state_put(x); + pkt_dev->flows[i].x = NULL; + } + } + } +} + +static int process_ipsec(struct pktgen_dev *pkt_dev, + struct sk_buff *skb, __be16 protocol) +{ + if (pkt_dev->flags & F_IPSEC) { + struct xfrm_state *x = pkt_dev->flows[pkt_dev->curfl].x; + int nhead = 0; + if (x) { + struct ethhdr *eth; + struct iphdr *iph; + int ret; + + nhead = x->props.header_len - skb_headroom(skb); + if (nhead > 0) { + ret = pskb_expand_head(skb, nhead, 0, GFP_ATOMIC); + if (ret < 0) { + pr_err("Error expanding ipsec packet %d\n", + ret); + goto err; + } + } + + /* ipsec is not expecting ll header */ + skb_pull(skb, ETH_HLEN); + ret = pktgen_output_ipsec(skb, pkt_dev); + if (ret) { + pr_err("Error creating ipsec packet %d\n", ret); + goto err; + } + /* restore ll */ + eth = skb_push(skb, ETH_HLEN); + memcpy(eth, pkt_dev->hh, 2 * ETH_ALEN); + eth->h_proto = protocol; + + /* Update IPv4 header len as well as checksum value */ + iph = ip_hdr(skb); + iph->tot_len = htons(skb->len - ETH_HLEN); + ip_send_check(iph); + } + } + return 1; +err: + kfree_skb(skb); + return 0; +} +#endif + +static void mpls_push(__be32 *mpls, struct pktgen_dev *pkt_dev) +{ + unsigned int i; + for (i = 0; i < pkt_dev->nr_labels; i++) + *mpls++ = pkt_dev->labels[i] & ~MPLS_STACK_BOTTOM; + + mpls--; + *mpls |= MPLS_STACK_BOTTOM; +} + +static inline __be16 
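A userspace sketch (illustrative; the packet sizes and weights are made up) of what fill_imix_distribution() above computes: the entry weights are normalized into a cumulative table scaled to IMIX_PRECISION, and each of the IMIX_PRECISION slots stores the index of the size entry a uniform random draw would select.

#include <stdio.h>

#define IMIX_PRECISION 100

int main(void)
{
        int size[]   = { 64, 576, 1500 };       /* example IMIX sizes   */
        int weight[] = {  7,   4,    1 };       /* example IMIX weights */
        int n = 3, dist[IMIX_PRECISION], cumulative[3];
        long long total = 0, cum = 0;
        int i, j = 0;

        for (i = 0; i < n; i++)
                total += weight[i];
        for (i = 0; i < n - 1; i++) {
                cum += weight[i] * IMIX_PRECISION / total;
                cumulative[i] = cum;
        }
        cumulative[n - 1] = IMIX_PRECISION;

        for (i = 0; i < IMIX_PRECISION; i++) {
                if (i == cumulative[j])
                        j++;
                dist[i] = j;
        }
        /* with these weights: 58 slots pick 64B, 33 pick 576B, 9 pick 1500B */
        for (i = 0; i < IMIX_PRECISION; i += 10)
                printf("slot %2d -> %d bytes\n", i, size[dist[i]]);
        return 0;
}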
build_tci(unsigned int id, unsigned int cfi, + unsigned int prio) +{ + return htons(id | (cfi << 12) | (prio << 13)); +} + +static void pktgen_finalize_skb(struct pktgen_dev *pkt_dev, struct sk_buff *skb, + int datalen) +{ + struct timespec64 timestamp; + struct pktgen_hdr *pgh; + + pgh = skb_put(skb, sizeof(*pgh)); + datalen -= sizeof(*pgh); + + if (pkt_dev->nfrags <= 0) { + skb_put_zero(skb, datalen); + } else { + int frags = pkt_dev->nfrags; + int i, len; + int frag_len; + + + if (frags > MAX_SKB_FRAGS) + frags = MAX_SKB_FRAGS; + len = datalen - frags * PAGE_SIZE; + if (len > 0) { + skb_put_zero(skb, len); + datalen = frags * PAGE_SIZE; + } + + i = 0; + frag_len = (datalen/frags) < PAGE_SIZE ? + (datalen/frags) : PAGE_SIZE; + while (datalen > 0) { + if (unlikely(!pkt_dev->page)) { + int node = numa_node_id(); + + if (pkt_dev->node >= 0 && (pkt_dev->flags & F_NODE)) + node = pkt_dev->node; + pkt_dev->page = alloc_pages_node(node, GFP_KERNEL | __GFP_ZERO, 0); + if (!pkt_dev->page) + break; + } + get_page(pkt_dev->page); + skb_frag_set_page(skb, i, pkt_dev->page); + skb_frag_off_set(&skb_shinfo(skb)->frags[i], 0); + /*last fragment, fill rest of data*/ + if (i == (frags - 1)) + skb_frag_size_set(&skb_shinfo(skb)->frags[i], + (datalen < PAGE_SIZE ? datalen : PAGE_SIZE)); + else + skb_frag_size_set(&skb_shinfo(skb)->frags[i], frag_len); + datalen -= skb_frag_size(&skb_shinfo(skb)->frags[i]); + skb->len += skb_frag_size(&skb_shinfo(skb)->frags[i]); + skb->data_len += skb_frag_size(&skb_shinfo(skb)->frags[i]); + i++; + skb_shinfo(skb)->nr_frags = i; + } + } + + /* Stamp the time, and sequence number, + * convert them to network byte order + */ + pgh->pgh_magic = htonl(PKTGEN_MAGIC); + pgh->seq_num = htonl(pkt_dev->seq_num); + + if (pkt_dev->flags & F_NO_TIMESTAMP) { + pgh->tv_sec = 0; + pgh->tv_usec = 0; + } else { + /* + * pgh->tv_sec wraps in y2106 when interpreted as unsigned + * as done by wireshark, or y2038 when interpreted as signed. + * This is probably harmless, but if anyone wants to improve + * it, we could introduce a variant that puts 64-bit nanoseconds + * into the respective header bytes. + * This would also be slightly faster to read. + */ + ktime_get_real_ts64(×tamp); + pgh->tv_sec = htonl(timestamp.tv_sec); + pgh->tv_usec = htonl(timestamp.tv_nsec / NSEC_PER_USEC); + } +} + +static struct sk_buff *pktgen_alloc_skb(struct net_device *dev, + struct pktgen_dev *pkt_dev) +{ + unsigned int extralen = LL_RESERVED_SPACE(dev); + struct sk_buff *skb = NULL; + unsigned int size; + + size = pkt_dev->cur_pkt_size + 64 + extralen + pkt_dev->pkt_overhead; + if (pkt_dev->flags & F_NODE) { + int node = pkt_dev->node >= 0 ? 
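A small userspace sketch (illustrative; the helper name is an example) of the 802.1Q TCI layout that build_tci() above packs: priority code point in bits 15..13, CFI/DEI in bit 12, VLAN ID in bits 11..0, shown here in host byte order before the htons() the kernel helper applies.

#include <stdio.h>
#include <stdint.h>

static uint16_t tci(unsigned int id, unsigned int cfi, unsigned int prio)
{
        return (uint16_t)(id | (cfi << 12) | (prio << 13));
}

int main(void)
{
        uint16_t t = tci(100, 0, 5);    /* VLAN 100, priority 5 */

        printf("tci=0x%04x prio=%u cfi=%u vid=%u\n",
               t, t >> 13, (t >> 12) & 1, t & 0x0fff);
        return 0;
}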
pkt_dev->node : numa_node_id(); + + skb = __alloc_skb(NET_SKB_PAD + size, GFP_NOWAIT, 0, node); + if (likely(skb)) { + skb_reserve(skb, NET_SKB_PAD); + skb->dev = dev; + } + } else { + skb = __netdev_alloc_skb(dev, size, GFP_NOWAIT); + } + + /* the caller pre-fetches from skb->data and reserves for the mac hdr */ + if (likely(skb)) + skb_reserve(skb, extralen - 16); + + return skb; +} + +static struct sk_buff *fill_packet_ipv4(struct net_device *odev, + struct pktgen_dev *pkt_dev) +{ + struct sk_buff *skb = NULL; + __u8 *eth; + struct udphdr *udph; + int datalen, iplen; + struct iphdr *iph; + __be16 protocol = htons(ETH_P_IP); + __be32 *mpls; + __be16 *vlan_tci = NULL; /* Encapsulates priority and VLAN ID */ + __be16 *vlan_encapsulated_proto = NULL; /* packet type ID field (or len) for VLAN tag */ + __be16 *svlan_tci = NULL; /* Encapsulates priority and SVLAN ID */ + __be16 *svlan_encapsulated_proto = NULL; /* packet type ID field (or len) for SVLAN tag */ + u16 queue_map; + + if (pkt_dev->nr_labels) + protocol = htons(ETH_P_MPLS_UC); + + if (pkt_dev->vlan_id != 0xffff) + protocol = htons(ETH_P_8021Q); + + /* Update any of the values, used when we're incrementing various + * fields. + */ + mod_cur_headers(pkt_dev); + queue_map = pkt_dev->cur_queue_map; + + skb = pktgen_alloc_skb(odev, pkt_dev); + if (!skb) { + sprintf(pkt_dev->result, "No memory"); + return NULL; + } + + prefetchw(skb->data); + skb_reserve(skb, 16); + + /* Reserve for ethernet and IP header */ + eth = skb_push(skb, 14); + mpls = skb_put(skb, pkt_dev->nr_labels * sizeof(__u32)); + if (pkt_dev->nr_labels) + mpls_push(mpls, pkt_dev); + + if (pkt_dev->vlan_id != 0xffff) { + if (pkt_dev->svlan_id != 0xffff) { + svlan_tci = skb_put(skb, sizeof(__be16)); + *svlan_tci = build_tci(pkt_dev->svlan_id, + pkt_dev->svlan_cfi, + pkt_dev->svlan_p); + svlan_encapsulated_proto = skb_put(skb, + sizeof(__be16)); + *svlan_encapsulated_proto = htons(ETH_P_8021Q); + } + vlan_tci = skb_put(skb, sizeof(__be16)); + *vlan_tci = build_tci(pkt_dev->vlan_id, + pkt_dev->vlan_cfi, + pkt_dev->vlan_p); + vlan_encapsulated_proto = skb_put(skb, sizeof(__be16)); + *vlan_encapsulated_proto = htons(ETH_P_IP); + } + + skb_reset_mac_header(skb); + skb_set_network_header(skb, skb->len); + iph = skb_put(skb, sizeof(struct iphdr)); + + skb_set_transport_header(skb, skb->len); + udph = skb_put(skb, sizeof(struct udphdr)); + skb_set_queue_mapping(skb, queue_map); + skb->priority = pkt_dev->skb_priority; + + memcpy(eth, pkt_dev->hh, 12); + *(__be16 *) & eth[12] = protocol; + + /* Eth + IPh + UDPh + mpls */ + datalen = pkt_dev->cur_pkt_size - 14 - 20 - 8 - + pkt_dev->pkt_overhead; + if (datalen < 0 || datalen < sizeof(struct pktgen_hdr)) + datalen = sizeof(struct pktgen_hdr); + + udph->source = htons(pkt_dev->cur_udp_src); + udph->dest = htons(pkt_dev->cur_udp_dst); + udph->len = htons(datalen + 8); /* DATA + udphdr */ + udph->check = 0; + + iph->ihl = 5; + iph->version = 4; + iph->ttl = 32; + iph->tos = pkt_dev->tos; + iph->protocol = IPPROTO_UDP; /* UDP */ + iph->saddr = pkt_dev->cur_saddr; + iph->daddr = pkt_dev->cur_daddr; + iph->id = htons(pkt_dev->ip_id); + pkt_dev->ip_id++; + iph->frag_off = 0; + iplen = 20 + 8 + datalen; + iph->tot_len = htons(iplen); + ip_send_check(iph); + skb->protocol = protocol; + skb->dev = odev; + skb->pkt_type = PACKET_HOST; + + pktgen_finalize_skb(pkt_dev, skb, datalen); + + if (!(pkt_dev->flags & F_UDPCSUM)) { + skb->ip_summed = CHECKSUM_NONE; + } else if (odev->features & (NETIF_F_HW_CSUM | NETIF_F_IP_CSUM)) { + skb->ip_summed = 
CHECKSUM_PARTIAL; + skb->csum = 0; + udp4_hwcsum(skb, iph->saddr, iph->daddr); + } else { + __wsum csum = skb_checksum(skb, skb_transport_offset(skb), datalen + 8, 0); + + /* add protocol-dependent pseudo-header */ + udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, + datalen + 8, IPPROTO_UDP, csum); + + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + } + +#ifdef CONFIG_XFRM + if (!process_ipsec(pkt_dev, skb, protocol)) + return NULL; +#endif + + return skb; +} + +static struct sk_buff *fill_packet_ipv6(struct net_device *odev, + struct pktgen_dev *pkt_dev) +{ + struct sk_buff *skb = NULL; + __u8 *eth; + struct udphdr *udph; + int datalen, udplen; + struct ipv6hdr *iph; + __be16 protocol = htons(ETH_P_IPV6); + __be32 *mpls; + __be16 *vlan_tci = NULL; /* Encapsulates priority and VLAN ID */ + __be16 *vlan_encapsulated_proto = NULL; /* packet type ID field (or len) for VLAN tag */ + __be16 *svlan_tci = NULL; /* Encapsulates priority and SVLAN ID */ + __be16 *svlan_encapsulated_proto = NULL; /* packet type ID field (or len) for SVLAN tag */ + u16 queue_map; + + if (pkt_dev->nr_labels) + protocol = htons(ETH_P_MPLS_UC); + + if (pkt_dev->vlan_id != 0xffff) + protocol = htons(ETH_P_8021Q); + + /* Update any of the values, used when we're incrementing various + * fields. + */ + mod_cur_headers(pkt_dev); + queue_map = pkt_dev->cur_queue_map; + + skb = pktgen_alloc_skb(odev, pkt_dev); + if (!skb) { + sprintf(pkt_dev->result, "No memory"); + return NULL; + } + + prefetchw(skb->data); + skb_reserve(skb, 16); + + /* Reserve for ethernet and IP header */ + eth = skb_push(skb, 14); + mpls = skb_put(skb, pkt_dev->nr_labels * sizeof(__u32)); + if (pkt_dev->nr_labels) + mpls_push(mpls, pkt_dev); + + if (pkt_dev->vlan_id != 0xffff) { + if (pkt_dev->svlan_id != 0xffff) { + svlan_tci = skb_put(skb, sizeof(__be16)); + *svlan_tci = build_tci(pkt_dev->svlan_id, + pkt_dev->svlan_cfi, + pkt_dev->svlan_p); + svlan_encapsulated_proto = skb_put(skb, + sizeof(__be16)); + *svlan_encapsulated_proto = htons(ETH_P_8021Q); + } + vlan_tci = skb_put(skb, sizeof(__be16)); + *vlan_tci = build_tci(pkt_dev->vlan_id, + pkt_dev->vlan_cfi, + pkt_dev->vlan_p); + vlan_encapsulated_proto = skb_put(skb, sizeof(__be16)); + *vlan_encapsulated_proto = htons(ETH_P_IPV6); + } + + skb_reset_mac_header(skb); + skb_set_network_header(skb, skb->len); + iph = skb_put(skb, sizeof(struct ipv6hdr)); + + skb_set_transport_header(skb, skb->len); + udph = skb_put(skb, sizeof(struct udphdr)); + skb_set_queue_mapping(skb, queue_map); + skb->priority = pkt_dev->skb_priority; + + memcpy(eth, pkt_dev->hh, 12); + *(__be16 *) ð[12] = protocol; + + /* Eth + IPh + UDPh + mpls */ + datalen = pkt_dev->cur_pkt_size - 14 - + sizeof(struct ipv6hdr) - sizeof(struct udphdr) - + pkt_dev->pkt_overhead; + + if (datalen < 0 || datalen < sizeof(struct pktgen_hdr)) { + datalen = sizeof(struct pktgen_hdr); + net_info_ratelimited("increased datalen to %d\n", datalen); + } + + udplen = datalen + sizeof(struct udphdr); + udph->source = htons(pkt_dev->cur_udp_src); + udph->dest = htons(pkt_dev->cur_udp_dst); + udph->len = htons(udplen); + udph->check = 0; + + *(__be32 *) iph = htonl(0x60000000); /* Version + flow */ + + if (pkt_dev->traffic_class) { + /* Version + traffic class + flow (0) */ + *(__be32 *)iph |= htonl(0x60000000 | (pkt_dev->traffic_class << 20)); + } + + iph->hop_limit = 32; + + iph->payload_len = htons(udplen); + iph->nexthdr = IPPROTO_UDP; + + iph->daddr = pkt_dev->cur_in6_daddr; + iph->saddr = pkt_dev->cur_in6_saddr; + + skb->protocol = 
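A userspace sketch (illustrative only; function and sample data are examples) of the software fallback used above when the device cannot checksum: a 16-bit one's-complement sum over the IPv4 pseudo-header plus the UDP header and payload, with a zero result mangled to 0xffff, conceptually what skb_checksum() plus csum_tcpudp_magic() produce in-kernel.

#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

static uint16_t udp4_checksum(uint32_t saddr, uint32_t daddr,
                              const uint8_t *udp, size_t len)
{
        uint32_t sum = 0;
        size_t i;

        /* pseudo-header: source, destination, protocol, UDP length */
        sum += (saddr >> 16) + (saddr & 0xffff);
        sum += (daddr >> 16) + (daddr & 0xffff);
        sum += 17;                      /* IPPROTO_UDP */
        sum += (uint32_t)len;

        /* UDP header + payload as big-endian 16-bit words */
        for (i = 0; i + 1 < len; i += 2)
                sum += (udp[i] << 8) | udp[i + 1];
        if (len & 1)
                sum += udp[len - 1] << 8;

        while (sum >> 16)
                sum = (sum & 0xffff) + (sum >> 16);
        sum = ~sum & 0xffff;
        return sum ? (uint16_t)sum : 0xffff;    /* 0 is sent as CSUM_MANGLED_0 */
}

int main(void)
{
        /* src port 9, dst port 9, length 16, checksum 0, 8 zero payload bytes */
        uint8_t udp[16] = { 0x00, 0x09, 0x00, 0x09, 0x00, 0x10, 0x00, 0x00 };

        printf("check=0x%04x\n",
               udp4_checksum(ntohl(inet_addr("10.0.0.1")),
                             ntohl(inet_addr("10.0.0.2")), udp, sizeof(udp)));
        return 0;
}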
protocol; + skb->dev = odev; + skb->pkt_type = PACKET_HOST; + + pktgen_finalize_skb(pkt_dev, skb, datalen); + + if (!(pkt_dev->flags & F_UDPCSUM)) { + skb->ip_summed = CHECKSUM_NONE; + } else if (odev->features & (NETIF_F_HW_CSUM | NETIF_F_IPV6_CSUM)) { + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + udph->check = ~csum_ipv6_magic(&iph->saddr, &iph->daddr, udplen, IPPROTO_UDP, 0); + } else { + __wsum csum = skb_checksum(skb, skb_transport_offset(skb), udplen, 0); + + /* add protocol-dependent pseudo-header */ + udph->check = csum_ipv6_magic(&iph->saddr, &iph->daddr, udplen, IPPROTO_UDP, csum); + + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + } + + return skb; +} + +static struct sk_buff *fill_packet(struct net_device *odev, + struct pktgen_dev *pkt_dev) +{ + if (pkt_dev->flags & F_IPV6) + return fill_packet_ipv6(odev, pkt_dev); + else + return fill_packet_ipv4(odev, pkt_dev); +} + +static void pktgen_clear_counters(struct pktgen_dev *pkt_dev) +{ + pkt_dev->seq_num = 1; + pkt_dev->idle_acc = 0; + pkt_dev->sofar = 0; + pkt_dev->tx_bytes = 0; + pkt_dev->errors = 0; +} + +/* Set up structure for sending pkts, clear counters */ + +static void pktgen_run(struct pktgen_thread *t) +{ + struct pktgen_dev *pkt_dev; + int started = 0; + + func_enter(); + + rcu_read_lock(); + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) { + + /* + * setup odev and create initial packet. + */ + pktgen_setup_inject(pkt_dev); + + if (pkt_dev->odev) { + pktgen_clear_counters(pkt_dev); + pkt_dev->skb = NULL; + pkt_dev->started_at = pkt_dev->next_tx = ktime_get(); + + set_pkt_overhead(pkt_dev); + + strcpy(pkt_dev->result, "Starting"); + pkt_dev->running = 1; /* Cranke yeself! 
*/ + started++; + } else + strcpy(pkt_dev->result, "Error starting"); + } + rcu_read_unlock(); + if (started) + t->control &= ~(T_STOP); +} + +static void pktgen_handle_all_threads(struct pktgen_net *pn, u32 flags) +{ + struct pktgen_thread *t; + + mutex_lock(&pktgen_thread_lock); + + list_for_each_entry(t, &pn->pktgen_threads, th_list) + t->control |= (flags); + + mutex_unlock(&pktgen_thread_lock); +} + +static void pktgen_stop_all_threads(struct pktgen_net *pn) +{ + func_enter(); + + pktgen_handle_all_threads(pn, T_STOP); +} + +static int thread_is_running(const struct pktgen_thread *t) +{ + const struct pktgen_dev *pkt_dev; + + rcu_read_lock(); + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) + if (pkt_dev->running) { + rcu_read_unlock(); + return 1; + } + rcu_read_unlock(); + return 0; +} + +static int pktgen_wait_thread_run(struct pktgen_thread *t) +{ + while (thread_is_running(t)) { + + /* note: 't' will still be around even after the unlock/lock + * cycle because pktgen_thread threads are only cleared at + * net exit + */ + mutex_unlock(&pktgen_thread_lock); + msleep_interruptible(100); + mutex_lock(&pktgen_thread_lock); + + if (signal_pending(current)) + goto signal; + } + return 1; +signal: + return 0; +} + +static int pktgen_wait_all_threads_run(struct pktgen_net *pn) +{ + struct pktgen_thread *t; + int sig = 1; + + /* prevent from racing with rmmod */ + if (!try_module_get(THIS_MODULE)) + return sig; + + mutex_lock(&pktgen_thread_lock); + + list_for_each_entry(t, &pn->pktgen_threads, th_list) { + sig = pktgen_wait_thread_run(t); + if (sig == 0) + break; + } + + if (sig == 0) + list_for_each_entry(t, &pn->pktgen_threads, th_list) + t->control |= (T_STOP); + + mutex_unlock(&pktgen_thread_lock); + module_put(THIS_MODULE); + return sig; +} + +static void pktgen_run_all_threads(struct pktgen_net *pn) +{ + func_enter(); + + pktgen_handle_all_threads(pn, T_RUN); + + /* Propagate thread->control */ + schedule_timeout_interruptible(msecs_to_jiffies(125)); + + pktgen_wait_all_threads_run(pn); +} + +static void pktgen_reset_all_threads(struct pktgen_net *pn) +{ + func_enter(); + + pktgen_handle_all_threads(pn, T_REMDEVALL); + + /* Propagate thread->control */ + schedule_timeout_interruptible(msecs_to_jiffies(125)); + + pktgen_wait_all_threads_run(pn); +} + +static void show_results(struct pktgen_dev *pkt_dev, int nr_frags) +{ + __u64 bps, mbps, pps; + char *p = pkt_dev->result; + ktime_t elapsed = ktime_sub(pkt_dev->stopped_at, + pkt_dev->started_at); + ktime_t idle = ns_to_ktime(pkt_dev->idle_acc); + + p += sprintf(p, "OK: %llu(c%llu+d%llu) usec, %llu (%dbyte,%dfrags)\n", + (unsigned long long)ktime_to_us(elapsed), + (unsigned long long)ktime_to_us(ktime_sub(elapsed, idle)), + (unsigned long long)ktime_to_us(idle), + (unsigned long long)pkt_dev->sofar, + pkt_dev->cur_pkt_size, nr_frags); + + pps = div64_u64(pkt_dev->sofar * NSEC_PER_SEC, + ktime_to_ns(elapsed)); + + if (pkt_dev->n_imix_entries > 0) { + int i; + struct imix_pkt *entry; + + bps = 0; + for (i = 0; i < pkt_dev->n_imix_entries; i++) { + entry = &pkt_dev->imix_entries[i]; + bps += entry->size * entry->count_so_far; + } + bps = div64_u64(bps * 8 * NSEC_PER_SEC, ktime_to_ns(elapsed)); + } else { + bps = pps * 8 * pkt_dev->cur_pkt_size; + } + + mbps = bps; + do_div(mbps, 1000000); + p += sprintf(p, " %llupps %lluMb/sec (%llubps) errors: %llu", + (unsigned long long)pps, + (unsigned long long)mbps, + (unsigned long long)bps, + (unsigned long long)pkt_dev->errors); +} + +/* Set stopped-at timer, remove from running list, do 
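A tiny standalone sketch (sample numbers are made up) of the rate arithmetic show_results() above reports: packets per second from the packet count and elapsed nanoseconds, then bits per second and Mb/s from the packet size.

#include <stdio.h>

int main(void)
{
        unsigned long long sofar = 1000000, elapsed_ns = 840000000ULL;
        unsigned long long pps = sofar * 1000000000ULL / elapsed_ns;
        unsigned long long bps = pps * 8 * 60;  /* 60-byte packets */

        printf("%llupps %lluMb/sec (%llubps)\n", pps, bps / 1000000, bps);
        return 0;
}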
counters & statistics */ +static int pktgen_stop_device(struct pktgen_dev *pkt_dev) +{ + int nr_frags = pkt_dev->skb ? skb_shinfo(pkt_dev->skb)->nr_frags : -1; + + if (!pkt_dev->running) { + pr_warn("interface: %s is already stopped\n", + pkt_dev->odevname); + return -EINVAL; + } + + pkt_dev->running = 0; + kfree_skb(pkt_dev->skb); + pkt_dev->skb = NULL; + pkt_dev->stopped_at = ktime_get(); + + show_results(pkt_dev, nr_frags); + + return 0; +} + +static struct pktgen_dev *next_to_run(struct pktgen_thread *t) +{ + struct pktgen_dev *pkt_dev, *best = NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) { + if (!pkt_dev->running) + continue; + if (best == NULL) + best = pkt_dev; + else if (ktime_compare(pkt_dev->next_tx, best->next_tx) < 0) + best = pkt_dev; + } + rcu_read_unlock(); + + return best; +} + +static void pktgen_stop(struct pktgen_thread *t) +{ + struct pktgen_dev *pkt_dev; + + func_enter(); + + rcu_read_lock(); + + list_for_each_entry_rcu(pkt_dev, &t->if_list, list) { + pktgen_stop_device(pkt_dev); + } + + rcu_read_unlock(); +} + +/* + * one of our devices needs to be removed - find it + * and remove it + */ +static void pktgen_rem_one_if(struct pktgen_thread *t) +{ + struct list_head *q, *n; + struct pktgen_dev *cur; + + func_enter(); + + list_for_each_safe(q, n, &t->if_list) { + cur = list_entry(q, struct pktgen_dev, list); + + if (!cur->removal_mark) + continue; + + kfree_skb(cur->skb); + cur->skb = NULL; + + pktgen_remove_device(t, cur); + + break; + } +} + +static void pktgen_rem_all_ifs(struct pktgen_thread *t) +{ + struct list_head *q, *n; + struct pktgen_dev *cur; + + func_enter(); + + /* Remove all devices, free mem */ + + list_for_each_safe(q, n, &t->if_list) { + cur = list_entry(q, struct pktgen_dev, list); + + kfree_skb(cur->skb); + cur->skb = NULL; + + pktgen_remove_device(t, cur); + } +} + +static void pktgen_rem_thread(struct pktgen_thread *t) +{ + /* Remove from the thread list */ + remove_proc_entry(t->tsk->comm, t->net->proc_dir); +} + +static void pktgen_resched(struct pktgen_dev *pkt_dev) +{ + ktime_t idle_start = ktime_get(); + schedule(); + pkt_dev->idle_acc += ktime_to_ns(ktime_sub(ktime_get(), idle_start)); +} + +static void pktgen_wait_for_skb(struct pktgen_dev *pkt_dev) +{ + ktime_t idle_start = ktime_get(); + + while (refcount_read(&(pkt_dev->skb->users)) != 1) { + if (signal_pending(current)) + break; + + if (need_resched()) + pktgen_resched(pkt_dev); + else + cpu_relax(); + } + pkt_dev->idle_acc += ktime_to_ns(ktime_sub(ktime_get(), idle_start)); +} + +static void pktgen_xmit(struct pktgen_dev *pkt_dev) +{ + unsigned int burst = READ_ONCE(pkt_dev->burst); + struct net_device *odev = pkt_dev->odev; + struct netdev_queue *txq; + struct sk_buff *skb; + int ret; + + /* If device is offline, then don't send */ + if (unlikely(!netif_running(odev) || !netif_carrier_ok(odev))) { + pktgen_stop_device(pkt_dev); + return; + } + + /* This is max DELAY, this has special meaning of + * "never transmit" + */ + if (unlikely(pkt_dev->delay == ULLONG_MAX)) { + pkt_dev->next_tx = ktime_add_ns(ktime_get(), ULONG_MAX); + return; + } + + /* If no skb or clone count exhausted then get new one */ + if (!pkt_dev->skb || (pkt_dev->last_ok && + ++pkt_dev->clone_count >= pkt_dev->clone_skb)) { + /* build a new pkt */ + kfree_skb(pkt_dev->skb); + + pkt_dev->skb = fill_packet(odev, pkt_dev); + if (pkt_dev->skb == NULL) { + pr_err("ERROR: couldn't allocate skb in fill_packet\n"); + schedule(); + pkt_dev->clone_count--; /* back out increment, OOM */ + 
return; + } + pkt_dev->last_pkt_size = pkt_dev->skb->len; + pkt_dev->clone_count = 0; /* reset counter */ + } + + if (pkt_dev->delay && pkt_dev->last_ok) + spin(pkt_dev, pkt_dev->next_tx); + + if (pkt_dev->xmit_mode == M_NETIF_RECEIVE) { + skb = pkt_dev->skb; + skb->protocol = eth_type_trans(skb, skb->dev); + refcount_add(burst, &skb->users); + local_bh_disable(); + do { + ret = netif_receive_skb(skb); + if (ret == NET_RX_DROP) + pkt_dev->errors++; + pkt_dev->sofar++; + pkt_dev->seq_num++; + if (refcount_read(&skb->users) != burst) { + /* skb was queued by rps/rfs or taps, + * so cannot reuse this skb + */ + WARN_ON(refcount_sub_and_test(burst - 1, &skb->users)); + /* get out of the loop and wait + * until skb is consumed + */ + break; + } + /* skb was 'freed' by stack, so clean few + * bits and reuse it + */ + skb_reset_redirect(skb); + } while (--burst > 0); + goto out; /* Skips xmit_mode M_START_XMIT */ + } else if (pkt_dev->xmit_mode == M_QUEUE_XMIT) { + local_bh_disable(); + refcount_inc(&pkt_dev->skb->users); + + ret = dev_queue_xmit(pkt_dev->skb); + switch (ret) { + case NET_XMIT_SUCCESS: + pkt_dev->sofar++; + pkt_dev->seq_num++; + pkt_dev->tx_bytes += pkt_dev->last_pkt_size; + break; + case NET_XMIT_DROP: + case NET_XMIT_CN: + /* These are all valid return codes for a qdisc but + * indicate packets are being dropped or will likely + * be dropped soon. + */ + case NETDEV_TX_BUSY: + /* qdisc may call dev_hard_start_xmit directly in cases + * where no queues exist e.g. loopback device, virtual + * devices, etc. In this case we need to handle + * NETDEV_TX_ codes. + */ + default: + pkt_dev->errors++; + net_info_ratelimited("%s xmit error: %d\n", + pkt_dev->odevname, ret); + break; + } + goto out; + } + + txq = skb_get_tx_queue(odev, pkt_dev->skb); + + local_bh_disable(); + + HARD_TX_LOCK(odev, txq, smp_processor_id()); + + if (unlikely(netif_xmit_frozen_or_drv_stopped(txq))) { + pkt_dev->last_ok = 0; + goto unlock; + } + refcount_add(burst, &pkt_dev->skb->users); + +xmit_more: + ret = netdev_start_xmit(pkt_dev->skb, odev, txq, --burst > 0); + + switch (ret) { + case NETDEV_TX_OK: + pkt_dev->last_ok = 1; + pkt_dev->sofar++; + pkt_dev->seq_num++; + pkt_dev->tx_bytes += pkt_dev->last_pkt_size; + if (burst > 0 && !netif_xmit_frozen_or_drv_stopped(txq)) + goto xmit_more; + break; + case NET_XMIT_DROP: + case NET_XMIT_CN: + /* skb has been consumed */ + pkt_dev->errors++; + break; + default: /* Drivers are not supposed to return other values! 
*/ + net_info_ratelimited("%s xmit error: %d\n", + pkt_dev->odevname, ret); + pkt_dev->errors++; + fallthrough; + case NETDEV_TX_BUSY: + /* Retry it next time */ + refcount_dec(&(pkt_dev->skb->users)); + pkt_dev->last_ok = 0; + } + if (unlikely(burst)) + WARN_ON(refcount_sub_and_test(burst, &pkt_dev->skb->users)); +unlock: + HARD_TX_UNLOCK(odev, txq); + +out: + local_bh_enable(); + + /* If pkt_dev->count is zero, then run forever */ + if ((pkt_dev->count != 0) && (pkt_dev->sofar >= pkt_dev->count)) { + pktgen_wait_for_skb(pkt_dev); + + /* Done with this */ + pktgen_stop_device(pkt_dev); + } +} + +/* + * Main loop of the thread goes here + */ + +static int pktgen_thread_worker(void *arg) +{ + struct pktgen_thread *t = arg; + struct pktgen_dev *pkt_dev = NULL; + int cpu = t->cpu; + + WARN_ON(smp_processor_id() != cpu); + + init_waitqueue_head(&t->queue); + complete(&t->start_done); + + pr_debug("starting pktgen/%d: pid=%d\n", cpu, task_pid_nr(current)); + + set_freezable(); + + while (!kthread_should_stop()) { + pkt_dev = next_to_run(t); + + if (unlikely(!pkt_dev && t->control == 0)) { + if (t->net->pktgen_exiting) + break; + wait_event_interruptible_timeout(t->queue, + t->control != 0, + HZ/10); + try_to_freeze(); + continue; + } + + if (likely(pkt_dev)) { + pktgen_xmit(pkt_dev); + + if (need_resched()) + pktgen_resched(pkt_dev); + else + cpu_relax(); + } + + if (t->control & T_STOP) { + pktgen_stop(t); + t->control &= ~(T_STOP); + } + + if (t->control & T_RUN) { + pktgen_run(t); + t->control &= ~(T_RUN); + } + + if (t->control & T_REMDEVALL) { + pktgen_rem_all_ifs(t); + t->control &= ~(T_REMDEVALL); + } + + if (t->control & T_REMDEV) { + pktgen_rem_one_if(t); + t->control &= ~(T_REMDEV); + } + + try_to_freeze(); + } + + pr_debug("%s stopping all device\n", t->tsk->comm); + pktgen_stop(t); + + pr_debug("%s removing all device\n", t->tsk->comm); + pktgen_rem_all_ifs(t); + + pr_debug("%s removing thread\n", t->tsk->comm); + pktgen_rem_thread(t); + + return 0; +} + +static struct pktgen_dev *pktgen_find_dev(struct pktgen_thread *t, + const char *ifname, bool exact) +{ + struct pktgen_dev *p, *pkt_dev = NULL; + size_t len = strlen(ifname); + + rcu_read_lock(); + list_for_each_entry_rcu(p, &t->if_list, list) + if (strncmp(p->odevname, ifname, len) == 0) { + if (p->odevname[len]) { + if (exact || p->odevname[len] != '@') + continue; + } + pkt_dev = p; + break; + } + + rcu_read_unlock(); + pr_debug("find_dev(%s) returning %p\n", ifname, pkt_dev); + return pkt_dev; +} + +/* + * Adds a dev at front of if_list. + */ + +static int add_dev_to_thread(struct pktgen_thread *t, + struct pktgen_dev *pkt_dev) +{ + int rv = 0; + + /* This function cannot be called concurrently, as its called + * under pktgen_thread_lock mutex, but it can run from + * userspace on another CPU than the kthread. 
The if_lock() + * is used here to sync with concurrent instances of + * _rem_dev_from_if_list() invoked via kthread, which is also + * updating the if_list */ + if_lock(t); + + if (pkt_dev->pg_thread) { + pr_err("ERROR: already assigned to a thread\n"); + rv = -EBUSY; + goto out; + } + + pkt_dev->running = 0; + pkt_dev->pg_thread = t; + list_add_rcu(&pkt_dev->list, &t->if_list); + +out: + if_unlock(t); + return rv; +} + +/* Called under thread lock */ + +static int pktgen_add_device(struct pktgen_thread *t, const char *ifname) +{ + struct pktgen_dev *pkt_dev; + int err; + int node = cpu_to_node(t->cpu); + + /* We don't allow a device to be on several threads */ + + pkt_dev = __pktgen_NN_threads(t->net, ifname, FIND); + if (pkt_dev) { + pr_err("ERROR: interface already used\n"); + return -EBUSY; + } + + pkt_dev = kzalloc_node(sizeof(struct pktgen_dev), GFP_KERNEL, node); + if (!pkt_dev) + return -ENOMEM; + + strcpy(pkt_dev->odevname, ifname); + pkt_dev->flows = vzalloc_node(array_size(MAX_CFLOWS, + sizeof(struct flow_state)), + node); + if (pkt_dev->flows == NULL) { + kfree(pkt_dev); + return -ENOMEM; + } + + pkt_dev->removal_mark = 0; + pkt_dev->nfrags = 0; + pkt_dev->delay = pg_delay_d; + pkt_dev->count = pg_count_d; + pkt_dev->sofar = 0; + pkt_dev->udp_src_min = 9; /* sink port */ + pkt_dev->udp_src_max = 9; + pkt_dev->udp_dst_min = 9; + pkt_dev->udp_dst_max = 9; + pkt_dev->vlan_p = 0; + pkt_dev->vlan_cfi = 0; + pkt_dev->vlan_id = 0xffff; + pkt_dev->svlan_p = 0; + pkt_dev->svlan_cfi = 0; + pkt_dev->svlan_id = 0xffff; + pkt_dev->burst = 1; + pkt_dev->node = NUMA_NO_NODE; + + err = pktgen_setup_dev(t->net, pkt_dev, ifname); + if (err) + goto out1; + if (pkt_dev->odev->priv_flags & IFF_TX_SKB_SHARING) + pkt_dev->clone_skb = pg_clone_skb_d; + + pkt_dev->entry = proc_create_data(ifname, 0600, t->net->proc_dir, + &pktgen_if_proc_ops, pkt_dev); + if (!pkt_dev->entry) { + pr_err("cannot create %s/%s procfs entry\n", + PG_PROC_DIR, ifname); + err = -EINVAL; + goto out2; + } +#ifdef CONFIG_XFRM + pkt_dev->ipsmode = XFRM_MODE_TRANSPORT; + pkt_dev->ipsproto = IPPROTO_ESP; + + /* xfrm tunnel mode needs additional dst to extract outter + * ip header protocol/ttl/id field, here creat a phony one. + * instead of looking for a valid rt, which definitely hurting + * performance under such circumstance. 
+ */ + pkt_dev->dstops.family = AF_INET; + pkt_dev->xdst.u.dst.dev = pkt_dev->odev; + dst_init_metrics(&pkt_dev->xdst.u.dst, pktgen_dst_metrics, false); + pkt_dev->xdst.child = &pkt_dev->xdst.u.dst; + pkt_dev->xdst.u.dst.ops = &pkt_dev->dstops; +#endif + + return add_dev_to_thread(t, pkt_dev); +out2: + netdev_put(pkt_dev->odev, &pkt_dev->dev_tracker); +out1: +#ifdef CONFIG_XFRM + free_SAs(pkt_dev); +#endif + vfree(pkt_dev->flows); + kfree(pkt_dev); + return err; +} + +static int __net_init pktgen_create_thread(int cpu, struct pktgen_net *pn) +{ + struct pktgen_thread *t; + struct proc_dir_entry *pe; + struct task_struct *p; + + t = kzalloc_node(sizeof(struct pktgen_thread), GFP_KERNEL, + cpu_to_node(cpu)); + if (!t) { + pr_err("ERROR: out of memory, can't create new thread\n"); + return -ENOMEM; + } + + mutex_init(&t->if_lock); + t->cpu = cpu; + + INIT_LIST_HEAD(&t->if_list); + + list_add_tail(&t->th_list, &pn->pktgen_threads); + init_completion(&t->start_done); + + p = kthread_create_on_node(pktgen_thread_worker, + t, + cpu_to_node(cpu), + "kpktgend_%d", cpu); + if (IS_ERR(p)) { + pr_err("kthread_create_on_node() failed for cpu %d\n", t->cpu); + list_del(&t->th_list); + kfree(t); + return PTR_ERR(p); + } + kthread_bind(p, cpu); + t->tsk = p; + + pe = proc_create_data(t->tsk->comm, 0600, pn->proc_dir, + &pktgen_thread_proc_ops, t); + if (!pe) { + pr_err("cannot create %s/%s procfs entry\n", + PG_PROC_DIR, t->tsk->comm); + kthread_stop(p); + list_del(&t->th_list); + kfree(t); + return -EINVAL; + } + + t->net = pn; + get_task_struct(p); + wake_up_process(p); + wait_for_completion(&t->start_done); + + return 0; +} + +/* + * Removes a device from the thread if_list. + */ +static void _rem_dev_from_if_list(struct pktgen_thread *t, + struct pktgen_dev *pkt_dev) +{ + struct list_head *q, *n; + struct pktgen_dev *p; + + if_lock(t); + list_for_each_safe(q, n, &t->if_list) { + p = list_entry(q, struct pktgen_dev, list); + if (p == pkt_dev) + list_del_rcu(&p->list); + } + if_unlock(t); +} + +static int pktgen_remove_device(struct pktgen_thread *t, + struct pktgen_dev *pkt_dev) +{ + pr_debug("remove_device pkt_dev=%p\n", pkt_dev); + + if (pkt_dev->running) { + pr_warn("WARNING: trying to remove a running interface, stopping it now\n"); + pktgen_stop_device(pkt_dev); + } + + /* Dis-associate from the interface */ + + if (pkt_dev->odev) { + netdev_put(pkt_dev->odev, &pkt_dev->dev_tracker); + pkt_dev->odev = NULL; + } + + /* Remove proc before if_list entry, because add_device uses + * list to determine if interface already exist, avoid race + * with proc_create_data() */ + proc_remove(pkt_dev->entry); + + /* And update the thread if_list */ + _rem_dev_from_if_list(t, pkt_dev); + +#ifdef CONFIG_XFRM + free_SAs(pkt_dev); +#endif + vfree(pkt_dev->flows); + if (pkt_dev->page) + put_page(pkt_dev->page); + kfree_rcu(pkt_dev, rcu); + return 0; +} + +static int __net_init pg_net_init(struct net *net) +{ + struct pktgen_net *pn = net_generic(net, pg_net_id); + struct proc_dir_entry *pe; + int cpu, ret = 0; + + pn->net = net; + INIT_LIST_HEAD(&pn->pktgen_threads); + pn->pktgen_exiting = false; + pn->proc_dir = proc_mkdir(PG_PROC_DIR, pn->net->proc_net); + if (!pn->proc_dir) { + pr_warn("cannot create /proc/net/%s\n", PG_PROC_DIR); + return -ENODEV; + } + pe = proc_create(PGCTRL, 0600, pn->proc_dir, &pktgen_proc_ops); + if (pe == NULL) { + pr_err("cannot create %s procfs entry\n", PGCTRL); + ret = -EINVAL; + goto remove; + } + + for_each_online_cpu(cpu) { + int err; + + err = pktgen_create_thread(cpu, pn); + if 
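A userspace sketch of how the procfs entries created above are typically driven, assuming the standard pktgen control commands documented in Documentation/networking/pktgen.rst (add_device/rem_device_all on a kpktgend_N thread file, per-device settings such as count, pkt_size and dst, and "start" written to pgctrl). The interface and device names are examples and error handling is minimal.

#include <stdio.h>

static int pgset(const char *path, const char *cmd)
{
        FILE *f = fopen(path, "w");

        if (!f) {
                perror(path);
                return -1;
        }
        fprintf(f, "%s\n", cmd);
        return fclose(f);
}

int main(void)
{
        pgset("/proc/net/pktgen/kpktgend_0", "rem_device_all");
        pgset("/proc/net/pktgen/kpktgend_0", "add_device eth0");
        pgset("/proc/net/pktgen/eth0", "count 10000");
        pgset("/proc/net/pktgen/eth0", "pkt_size 300");
        pgset("/proc/net/pktgen/eth0", "dst 10.0.0.2");
        pgset("/proc/net/pktgen/eth0", "dst_mac 00:11:22:33:44:55");
        pgset("/proc/net/pktgen/pgctrl", "start");      /* returns when the run ends */
        return 0;
}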
(err) + pr_warn("Cannot create thread for cpu %d (%d)\n", + cpu, err); + } + + if (list_empty(&pn->pktgen_threads)) { + pr_err("Initialization failed for all threads\n"); + ret = -ENODEV; + goto remove_entry; + } + + return 0; + +remove_entry: + remove_proc_entry(PGCTRL, pn->proc_dir); +remove: + remove_proc_entry(PG_PROC_DIR, pn->net->proc_net); + return ret; +} + +static void __net_exit pg_net_exit(struct net *net) +{ + struct pktgen_net *pn = net_generic(net, pg_net_id); + struct pktgen_thread *t; + struct list_head *q, *n; + LIST_HEAD(list); + + /* Stop all interfaces & threads */ + pn->pktgen_exiting = true; + + mutex_lock(&pktgen_thread_lock); + list_splice_init(&pn->pktgen_threads, &list); + mutex_unlock(&pktgen_thread_lock); + + list_for_each_safe(q, n, &list) { + t = list_entry(q, struct pktgen_thread, th_list); + list_del(&t->th_list); + kthread_stop(t->tsk); + put_task_struct(t->tsk); + kfree(t); + } + + remove_proc_entry(PGCTRL, pn->proc_dir); + remove_proc_entry(PG_PROC_DIR, pn->net->proc_net); +} + +static struct pernet_operations pg_net_ops = { + .init = pg_net_init, + .exit = pg_net_exit, + .id = &pg_net_id, + .size = sizeof(struct pktgen_net), +}; + +static int __init pg_init(void) +{ + int ret = 0; + + pr_info("%s", version); + ret = register_pernet_subsys(&pg_net_ops); + if (ret) + return ret; + ret = register_netdevice_notifier(&pktgen_notifier_block); + if (ret) + unregister_pernet_subsys(&pg_net_ops); + + return ret; +} + +static void __exit pg_cleanup(void) +{ + unregister_netdevice_notifier(&pktgen_notifier_block); + unregister_pernet_subsys(&pg_net_ops); + /* Don't need rcu_barrier() due to use of kfree_rcu() */ +} + +module_init(pg_init); +module_exit(pg_cleanup); + +MODULE_AUTHOR("Robert Olsson "); +MODULE_DESCRIPTION("Packet Generator tool"); +MODULE_LICENSE("GPL"); +MODULE_VERSION(VERSION); +module_param(pg_count_d, int, 0); +MODULE_PARM_DESC(pg_count_d, "Default number of packets to inject"); +module_param(pg_delay_d, int, 0); +MODULE_PARM_DESC(pg_delay_d, "Default delay between packets (nanoseconds)"); +module_param(pg_clone_skb_d, int, 0); +MODULE_PARM_DESC(pg_clone_skb_d, "Default number of copies of the same packet"); +module_param(debug, int, 0); +MODULE_PARM_DESC(debug, "Enable debugging of pktgen module"); diff --git a/net/core/ptp_classifier.c b/net/core/ptp_classifier.c new file mode 100644 index 000000000..598041b04 --- /dev/null +++ b/net/core/ptp_classifier.c @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* PTP classifier + */ + +/* The below program is the bpf_asm (tools/net/) representation of + * the opcode array in the ptp_filter structure. + * + * For convenience, this can easily be altered and reviewed with + * bpf_asm and bpf_dbg, e.g. `./bpf_asm -c prog` where prog is a + * simple file containing the below program: + * + * ldh [12] ; load ethertype + * + * ; PTP over UDP over IPv4 over Ethernet + * test_ipv4: + * jneq #0x800, test_ipv6 ; ETH_P_IP ? + * ldb [23] ; load proto + * jneq #17, drop_ipv4 ; IPPROTO_UDP ? + * ldh [20] ; load frag offset field + * jset #0x1fff, drop_ipv4 ; don't allow fragments + * ldxb 4*([14]&0xf) ; load IP header len + * ldh [x + 16] ; load UDP dst port + * jneq #319, drop_ipv4 ; is port PTP_EV_PORT ? + * ldh [x + 22] ; load payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0x10 ; PTP_CLASS_IPV4 + * ret a ; return PTP class + * drop_ipv4: ret #0x0 ; PTP_CLASS_NONE + * + * ; PTP over UDP over IPv6 over Ethernet + * test_ipv6: + * jneq #0x86dd, test_8021q ; ETH_P_IPV6 ? 
+ * ldb [20] ; load proto + * jneq #17, drop_ipv6 ; IPPROTO_UDP ? + * ldh [56] ; load UDP dst port + * jneq #319, drop_ipv6 ; is port PTP_EV_PORT ? + * ldh [62] ; load payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0x20 ; PTP_CLASS_IPV6 + * ret a ; return PTP class + * drop_ipv6: ret #0x0 ; PTP_CLASS_NONE + * + * ; PTP over 802.1Q over Ethernet + * test_8021q: + * jneq #0x8100, test_ieee1588 ; ETH_P_8021Q ? + * ldh [16] ; load inner type + * jneq #0x88f7, test_8021q_ipv4 ; ETH_P_1588 ? + * ldb [18] ; load payload + * and #0x8 ; as we don't have ports here, test + * jneq #0x0, drop_ieee1588 ; for PTP_GEN_BIT and drop these + * ldh [18] ; reload payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0xc0 ; PTP_CLASS_VLAN|PTP_CLASS_L2 + * ret a ; return PTP class + * + * ; PTP over UDP over IPv4 over 802.1Q over Ethernet + * test_8021q_ipv4: + * jneq #0x800, test_8021q_ipv6 ; ETH_P_IP ? + * ldb [27] ; load proto + * jneq #17, drop_8021q_ipv4 ; IPPROTO_UDP ? + * ldh [24] ; load frag offset field + * jset #0x1fff, drop_8021q_ipv4; don't allow fragments + * ldxb 4*([18]&0xf) ; load IP header len + * ldh [x + 20] ; load UDP dst port + * jneq #319, drop_8021q_ipv4 ; is port PTP_EV_PORT ? + * ldh [x + 26] ; load payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0x90 ; PTP_CLASS_VLAN|PTP_CLASS_IPV4 + * ret a ; return PTP class + * drop_8021q_ipv4: ret #0x0 ; PTP_CLASS_NONE + * + * ; PTP over UDP over IPv6 over 802.1Q over Ethernet + * test_8021q_ipv6: + * jneq #0x86dd, drop_8021q_ipv6 ; ETH_P_IPV6 ? + * ldb [24] ; load proto + * jneq #17, drop_8021q_ipv6 ; IPPROTO_UDP ? + * ldh [60] ; load UDP dst port + * jneq #319, drop_8021q_ipv6 ; is port PTP_EV_PORT ? + * ldh [66] ; load payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0xa0 ; PTP_CLASS_VLAN|PTP_CLASS_IPV6 + * ret a ; return PTP class + * drop_8021q_ipv6: ret #0x0 ; PTP_CLASS_NONE + * + * ; PTP over Ethernet + * test_ieee1588: + * jneq #0x88f7, drop_ieee1588 ; ETH_P_1588 ? + * ldb [14] ; load payload + * and #0x8 ; as we don't have ports here, test + * jneq #0x0, drop_ieee1588 ; for PTP_GEN_BIT and drop these + * ldh [14] ; reload payload + * and #0xf ; mask PTP_CLASS_VMASK + * or #0x40 ; PTP_CLASS_L2 + * ret a ; return PTP class + * drop_ieee1588: ret #0x0 ; PTP_CLASS_NONE + */ + +#include +#include +#include + +static struct bpf_prog *ptp_insns __read_mostly; + +unsigned int ptp_classify_raw(const struct sk_buff *skb) +{ + return bpf_prog_run(ptp_insns, skb); +} +EXPORT_SYMBOL_GPL(ptp_classify_raw); + +struct ptp_header *ptp_parse_header(struct sk_buff *skb, unsigned int type) +{ + u8 *ptr = skb_mac_header(skb); + + if (type & PTP_CLASS_VLAN) + ptr += VLAN_HLEN; + + switch (type & PTP_CLASS_PMASK) { + case PTP_CLASS_IPV4: + ptr += IPV4_HLEN(ptr) + UDP_HLEN; + break; + case PTP_CLASS_IPV6: + ptr += IP6_HLEN + UDP_HLEN; + break; + case PTP_CLASS_L2: + break; + default: + return NULL; + } + + ptr += ETH_HLEN; + + /* Ensure that the entire header is present in this packet. 
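A userspace sketch (illustrative; the decode helper is an example) of the classification value the BPF program above returns: the low nibble carries the PTP version bits masked by PTP_CLASS_VMASK, and the upper bits encode the transport using the constants visible in the bpf_asm listing (0x10 IPv4, 0x20 IPv6, 0x40 L2, 0x80 VLAN, 0 for PTP_CLASS_NONE).

#include <stdio.h>

static void decode_ptp_class(unsigned int class)
{
        if (class == 0) {
                printf("not PTP\n");
                return;
        }
        printf("version bits 0x%x, transport:%s%s%s%s\n",
               class & 0xf,
               class & 0x10 ? " ipv4" : "",
               class & 0x20 ? " ipv6" : "",
               class & 0x40 ? " l2" : "",
               class & 0x80 ? " vlan" : "");
}

int main(void)
{
        decode_ptp_class(0x12);         /* e.g. PTPv2 over UDP/IPv4 */
        decode_ptp_class(0xc2);         /* e.g. PTPv2 over 802.1Q + L2 */
        decode_ptp_class(0x00);
        return 0;
}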
*/ + if (ptr + sizeof(struct ptp_header) > skb->data + skb->len) + return NULL; + + return (struct ptp_header *)ptr; +} +EXPORT_SYMBOL_GPL(ptp_parse_header); + +bool ptp_msg_is_sync(struct sk_buff *skb, unsigned int type) +{ + struct ptp_header *hdr; + + hdr = ptp_parse_header(skb, type); + if (!hdr) + return false; + + return ptp_get_msgtype(hdr, type) == PTP_MSGTYPE_SYNC; +} +EXPORT_SYMBOL_GPL(ptp_msg_is_sync); + +void __init ptp_classifier_init(void) +{ + static struct sock_filter ptp_filter[] __initdata = { + { 0x28, 0, 0, 0x0000000c }, + { 0x15, 0, 12, 0x00000800 }, + { 0x30, 0, 0, 0x00000017 }, + { 0x15, 0, 9, 0x00000011 }, + { 0x28, 0, 0, 0x00000014 }, + { 0x45, 7, 0, 0x00001fff }, + { 0xb1, 0, 0, 0x0000000e }, + { 0x48, 0, 0, 0x00000010 }, + { 0x15, 0, 4, 0x0000013f }, + { 0x48, 0, 0, 0x00000016 }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x00000010 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x06, 0, 0, 0x00000000 }, + { 0x15, 0, 9, 0x000086dd }, + { 0x30, 0, 0, 0x00000014 }, + { 0x15, 0, 6, 0x00000011 }, + { 0x28, 0, 0, 0x00000038 }, + { 0x15, 0, 4, 0x0000013f }, + { 0x28, 0, 0, 0x0000003e }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x00000020 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x06, 0, 0, 0x00000000 }, + { 0x15, 0, 32, 0x00008100 }, + { 0x28, 0, 0, 0x00000010 }, + { 0x15, 0, 7, 0x000088f7 }, + { 0x30, 0, 0, 0x00000012 }, + { 0x54, 0, 0, 0x00000008 }, + { 0x15, 0, 35, 0x00000000 }, + { 0x28, 0, 0, 0x00000012 }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x000000c0 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x15, 0, 12, 0x00000800 }, + { 0x30, 0, 0, 0x0000001b }, + { 0x15, 0, 9, 0x00000011 }, + { 0x28, 0, 0, 0x00000018 }, + { 0x45, 7, 0, 0x00001fff }, + { 0xb1, 0, 0, 0x00000012 }, + { 0x48, 0, 0, 0x00000014 }, + { 0x15, 0, 4, 0x0000013f }, + { 0x48, 0, 0, 0x0000001a }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x00000090 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x06, 0, 0, 0x00000000 }, + { 0x15, 0, 8, 0x000086dd }, + { 0x30, 0, 0, 0x00000018 }, + { 0x15, 0, 6, 0x00000011 }, + { 0x28, 0, 0, 0x0000003c }, + { 0x15, 0, 4, 0x0000013f }, + { 0x28, 0, 0, 0x00000042 }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x000000a0 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x06, 0, 0, 0x00000000 }, + { 0x15, 0, 7, 0x000088f7 }, + { 0x30, 0, 0, 0x0000000e }, + { 0x54, 0, 0, 0x00000008 }, + { 0x15, 0, 4, 0x00000000 }, + { 0x28, 0, 0, 0x0000000e }, + { 0x54, 0, 0, 0x0000000f }, + { 0x44, 0, 0, 0x00000040 }, + { 0x16, 0, 0, 0x00000000 }, + { 0x06, 0, 0, 0x00000000 }, + }; + struct sock_fprog_kern ptp_prog; + + ptp_prog.len = ARRAY_SIZE(ptp_filter); + ptp_prog.filter = ptp_filter; + + BUG_ON(bpf_prog_create(&ptp_insns, &ptp_prog)); +} diff --git a/net/core/request_sock.c b/net/core/request_sock.c new file mode 100644 index 000000000..63de5c635 --- /dev/null +++ b/net/core/request_sock.c @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET Generic infrastructure for Network protocols. + * + * Authors: Arnaldo Carvalho de Melo + * + * From code originally in include/net/tcp.h + */ + +#include +#include +#include +#include +#include +#include + +#include + +/* + * Maximum number of SYN_RECV sockets in queue per LISTEN socket. + * One SYN_RECV socket costs about 80bytes on a 32bit machine. + * It would be better to replace it with a global counter for all sockets + * but then some measure against one socket starving all other sockets + * would be needed. + * + * The minimum value of it is 128. 
Experiments with real servers show that + * it is absolutely not enough even at 100conn/sec. 256 cures most + * of problems. + * This value is adjusted to 128 for low memory machines, + * and it will increase in proportion to the memory of machine. + * Note : Dont forget somaxconn that may limit backlog too. + */ + +void reqsk_queue_alloc(struct request_sock_queue *queue) +{ + queue->fastopenq.rskq_rst_head = NULL; + queue->fastopenq.rskq_rst_tail = NULL; + queue->fastopenq.qlen = 0; + + queue->rskq_accept_head = NULL; +} + +/* + * This function is called to set a Fast Open socket's "fastopen_rsk" field + * to NULL when a TFO socket no longer needs to access the request_sock. + * This happens only after 3WHS has been either completed or aborted (e.g., + * RST is received). + * + * Before TFO, a child socket is created only after 3WHS is completed, + * hence it never needs to access the request_sock. things get a lot more + * complex with TFO. A child socket, accepted or not, has to access its + * request_sock for 3WHS processing, e.g., to retransmit SYN-ACK pkts, + * until 3WHS is either completed or aborted. Afterwards the req will stay + * until either the child socket is accepted, or in the rare case when the + * listener is closed before the child is accepted. + * + * In short, a request socket is only freed after BOTH 3WHS has completed + * (or aborted) and the child socket has been accepted (or listener closed). + * When a child socket is accepted, its corresponding req->sk is set to + * NULL since it's no longer needed. More importantly, "req->sk == NULL" + * will be used by the code below to determine if a child socket has been + * accepted or not, and the check is protected by the fastopenq->lock + * described below. + * + * Note that fastopen_rsk is only accessed from the child socket's context + * with its socket lock held. But a request_sock (req) can be accessed by + * both its child socket through fastopen_rsk, and a listener socket through + * icsk_accept_queue.rskq_accept_head. To protect the access a simple spin + * lock per listener "icsk->icsk_accept_queue.fastopenq->lock" is created. + * only in the rare case when both the listener and the child locks are held, + * e.g., in inet_csk_listen_stop() do we not need to acquire the lock. + * The lock also protects other fields such as fastopenq->qlen, which is + * decremented by this function when fastopen_rsk is no longer needed. + * + * Note that another solution was to simply use the existing socket lock + * from the listener. But first socket lock is difficult to use. It is not + * a simple spin lock - one must consider sock_owned_by_user() and arrange + * to use sk_add_backlog() stuff. But what really makes it infeasible is the + * locking hierarchy violation. E.g., inet_csk_listen_stop() may try to + * acquire a child's lock while holding listener's socket lock. A corner + * case might also exist in tcp_v4_hnd_req() that will trigger this locking + * order. + * + * This function also sets "treq->tfo_listener" to false. + * treq->tfo_listener is used by the listener so it is protected by the + * fastopenq->lock in this function. 
+ */ +void reqsk_fastopen_remove(struct sock *sk, struct request_sock *req, + bool reset) +{ + struct sock *lsk = req->rsk_listener; + struct fastopen_queue *fastopenq; + + fastopenq = &inet_csk(lsk)->icsk_accept_queue.fastopenq; + + RCU_INIT_POINTER(tcp_sk(sk)->fastopen_rsk, NULL); + spin_lock_bh(&fastopenq->lock); + fastopenq->qlen--; + tcp_rsk(req)->tfo_listener = false; + if (req->sk) /* the child socket hasn't been accepted yet */ + goto out; + + if (!reset || lsk->sk_state != TCP_LISTEN) { + /* If the listener has been closed don't bother with the + * special RST handling below. + */ + spin_unlock_bh(&fastopenq->lock); + reqsk_put(req); + return; + } + /* Wait for 60secs before removing a req that has triggered RST. + * This is a simple defense against TFO spoofing attack - by + * counting the req against fastopen.max_qlen, and disabling + * TFO when the qlen exceeds max_qlen. + * + * For more details see CoNext'11 "TCP Fast Open" paper. + */ + req->rsk_timer.expires = jiffies + 60*HZ; + if (fastopenq->rskq_rst_head == NULL) + fastopenq->rskq_rst_head = req; + else + fastopenq->rskq_rst_tail->dl_next = req; + + req->dl_next = NULL; + fastopenq->rskq_rst_tail = req; + fastopenq->qlen++; +out: + spin_unlock_bh(&fastopenq->lock); +} diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c new file mode 100644 index 000000000..7cf1e42d7 --- /dev/null +++ b/net/core/rtnetlink.c @@ -0,0 +1,6248 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Routing netlink socket interface: protocol independent part. + * + * Authors: Alexey Kuznetsov, + * + * Fixes: + * Vitaly E. Lavrov RTA_OK arithmetic was wrong. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dev.h" + +#define RTNL_MAX_TYPE 50 +#define RTNL_SLAVE_MAX_TYPE 40 + +struct rtnl_link { + rtnl_doit_func doit; + rtnl_dumpit_func dumpit; + struct module *owner; + unsigned int flags; + struct rcu_head rcu; +}; + +static DEFINE_MUTEX(rtnl_mutex); + +void rtnl_lock(void) +{ + mutex_lock(&rtnl_mutex); +} +EXPORT_SYMBOL(rtnl_lock); + +int rtnl_lock_killable(void) +{ + return mutex_lock_killable(&rtnl_mutex); +} +EXPORT_SYMBOL(rtnl_lock_killable); + +static struct sk_buff *defer_kfree_skb_list; +void rtnl_kfree_skbs(struct sk_buff *head, struct sk_buff *tail) +{ + if (head && tail) { + tail->next = defer_kfree_skb_list; + defer_kfree_skb_list = head; + } +} +EXPORT_SYMBOL(rtnl_kfree_skbs); + +void __rtnl_unlock(void) +{ + struct sk_buff *head = defer_kfree_skb_list; + + defer_kfree_skb_list = NULL; + + /* Ensure that we didn't actually add any TODO item when __rtnl_unlock() + * is used. In some places, e.g. in cfg80211, we have code that will do + * something like + * rtnl_lock() + * wiphy_lock() + * ... 
+ * rtnl_unlock() + * + * and because netdev_run_todo() acquires the RTNL for items on the list + * we could cause a situation such as this: + * Thread 1 Thread 2 + * rtnl_lock() + * unregister_netdevice() + * __rtnl_unlock() + * rtnl_lock() + * wiphy_lock() + * rtnl_unlock() + * netdev_run_todo() + * __rtnl_unlock() + * + * // list not empty now + * // because of thread 2 + * rtnl_lock() + * while (!list_empty(...)) + * rtnl_lock() + * wiphy_lock() + * **** DEADLOCK **** + * + * However, usage of __rtnl_unlock() is rare, and so we can ensure that + * it's not used in cases where something is added to do the list. + */ + WARN_ON(!list_empty(&net_todo_list)); + + mutex_unlock(&rtnl_mutex); + + while (head) { + struct sk_buff *next = head->next; + + kfree_skb(head); + cond_resched(); + head = next; + } +} + +void rtnl_unlock(void) +{ + /* This fellow will unlock it for us. */ + netdev_run_todo(); +} +EXPORT_SYMBOL(rtnl_unlock); + +int rtnl_trylock(void) +{ + return mutex_trylock(&rtnl_mutex); +} +EXPORT_SYMBOL(rtnl_trylock); + +int rtnl_is_locked(void) +{ + return mutex_is_locked(&rtnl_mutex); +} +EXPORT_SYMBOL(rtnl_is_locked); + +bool refcount_dec_and_rtnl_lock(refcount_t *r) +{ + return refcount_dec_and_mutex_lock(r, &rtnl_mutex); +} +EXPORT_SYMBOL(refcount_dec_and_rtnl_lock); + +#ifdef CONFIG_PROVE_LOCKING +bool lockdep_rtnl_is_held(void) +{ + return lockdep_is_held(&rtnl_mutex); +} +EXPORT_SYMBOL(lockdep_rtnl_is_held); +#endif /* #ifdef CONFIG_PROVE_LOCKING */ + +static struct rtnl_link __rcu *__rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1]; + +static inline int rtm_msgindex(int msgtype) +{ + int msgindex = msgtype - RTM_BASE; + + /* + * msgindex < 0 implies someone tried to register a netlink + * control code. msgindex >= RTM_NR_MSGTYPES may indicate that + * the message type has not been added to linux/rtnetlink.h + */ + BUG_ON(msgindex < 0 || msgindex >= RTM_NR_MSGTYPES); + + return msgindex; +} + +static struct rtnl_link *rtnl_get_link(int protocol, int msgtype) +{ + struct rtnl_link __rcu **tab; + + if (protocol >= ARRAY_SIZE(rtnl_msg_handlers)) + protocol = PF_UNSPEC; + + tab = rcu_dereference_rtnl(rtnl_msg_handlers[protocol]); + if (!tab) + tab = rcu_dereference_rtnl(rtnl_msg_handlers[PF_UNSPEC]); + + return rcu_dereference_rtnl(tab[msgtype]); +} + +static int rtnl_register_internal(struct module *owner, + int protocol, int msgtype, + rtnl_doit_func doit, rtnl_dumpit_func dumpit, + unsigned int flags) +{ + struct rtnl_link *link, *old; + struct rtnl_link __rcu **tab; + int msgindex; + int ret = -ENOBUFS; + + BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); + msgindex = rtm_msgindex(msgtype); + + rtnl_lock(); + tab = rtnl_dereference(rtnl_msg_handlers[protocol]); + if (tab == NULL) { + tab = kcalloc(RTM_NR_MSGTYPES, sizeof(void *), GFP_KERNEL); + if (!tab) + goto unlock; + + /* ensures we see the 0 stores */ + rcu_assign_pointer(rtnl_msg_handlers[protocol], tab); + } + + old = rtnl_dereference(tab[msgindex]); + if (old) { + link = kmemdup(old, sizeof(*old), GFP_KERNEL); + if (!link) + goto unlock; + } else { + link = kzalloc(sizeof(*link), GFP_KERNEL); + if (!link) + goto unlock; + } + + WARN_ON(link->owner && link->owner != owner); + link->owner = owner; + + WARN_ON(doit && link->doit && link->doit != doit); + if (doit) + link->doit = doit; + WARN_ON(dumpit && link->dumpit && link->dumpit != dumpit); + if (dumpit) + link->dumpit = dumpit; + + WARN_ON(rtnl_msgtype_kind(msgtype) != RTNL_KIND_DEL && + (flags & RTNL_FLAG_BULK_DEL_SUPPORTED)); + link->flags |= flags; + + /* publish 
protocol:msgtype */ + rcu_assign_pointer(tab[msgindex], link); + ret = 0; + if (old) + kfree_rcu(old, rcu); +unlock: + rtnl_unlock(); + return ret; +} + +/** + * rtnl_register_module - Register a rtnetlink message type + * + * @owner: module registering the hook (THIS_MODULE) + * @protocol: Protocol family or PF_UNSPEC + * @msgtype: rtnetlink message type + * @doit: Function pointer called for each request message + * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message + * @flags: rtnl_link_flags to modify behaviour of doit/dumpit functions + * + * Like rtnl_register, but for use by removable modules. + */ +int rtnl_register_module(struct module *owner, + int protocol, int msgtype, + rtnl_doit_func doit, rtnl_dumpit_func dumpit, + unsigned int flags) +{ + return rtnl_register_internal(owner, protocol, msgtype, + doit, dumpit, flags); +} +EXPORT_SYMBOL_GPL(rtnl_register_module); + +/** + * rtnl_register - Register a rtnetlink message type + * @protocol: Protocol family or PF_UNSPEC + * @msgtype: rtnetlink message type + * @doit: Function pointer called for each request message + * @dumpit: Function pointer called for each dump request (NLM_F_DUMP) message + * @flags: rtnl_link_flags to modify behaviour of doit/dumpit functions + * + * Registers the specified function pointers (at least one of them has + * to be non-NULL) to be called whenever a request message for the + * specified protocol family and message type is received. + * + * The special protocol family PF_UNSPEC may be used to define fallback + * function pointers for the case when no entry for the specific protocol + * family exists. + */ +void rtnl_register(int protocol, int msgtype, + rtnl_doit_func doit, rtnl_dumpit_func dumpit, + unsigned int flags) +{ + int err; + + err = rtnl_register_internal(NULL, protocol, msgtype, doit, dumpit, + flags); + if (err) + pr_err("Unable to register rtnetlink message handler, " + "protocol = %d, message type = %d\n", protocol, msgtype); +} + +/** + * rtnl_unregister - Unregister a rtnetlink message type + * @protocol: Protocol family or PF_UNSPEC + * @msgtype: rtnetlink message type + * + * Returns 0 on success or a negative error code. + */ +int rtnl_unregister(int protocol, int msgtype) +{ + struct rtnl_link __rcu **tab; + struct rtnl_link *link; + int msgindex; + + BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); + msgindex = rtm_msgindex(msgtype); + + rtnl_lock(); + tab = rtnl_dereference(rtnl_msg_handlers[protocol]); + if (!tab) { + rtnl_unlock(); + return -ENOENT; + } + + link = rtnl_dereference(tab[msgindex]); + RCU_INIT_POINTER(tab[msgindex], NULL); + rtnl_unlock(); + + kfree_rcu(link, rcu); + + return 0; +} +EXPORT_SYMBOL_GPL(rtnl_unregister); + +/** + * rtnl_unregister_all - Unregister all rtnetlink message type of a protocol + * @protocol : Protocol family or PF_UNSPEC + * + * Identical to calling rtnl_unregster() for all registered message types + * of a certain protocol family. 
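A userspace sketch (illustrative; error handling is minimal) of what the registration above serves: a NETLINK_ROUTE dump request with RTM_GETLINK and NLM_F_DUMP is the kind of message that gets dispatched to the dumpit handler registered for that protocol/msgtype pair, and each RTM_NEWLINK reply below comes from that handler.

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>

int main(void)
{
        struct {
                struct nlmsghdr nlh;
                struct rtgenmsg gen;
        } req;
        char buf[16384];
        int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
        ssize_t n;

        memset(&req, 0, sizeof(req));
        req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct rtgenmsg));
        req.nlh.nlmsg_type = RTM_GETLINK;
        req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
        req.nlh.nlmsg_seq = 1;
        req.gen.rtgen_family = AF_UNSPEC;
        send(fd, &req, req.nlh.nlmsg_len, 0);

        while ((n = recv(fd, buf, sizeof(buf), 0)) > 0) {
                struct nlmsghdr *nh = (struct nlmsghdr *)buf;
                int len = (int)n;

                for (; NLMSG_OK(nh, len); nh = NLMSG_NEXT(nh, len)) {
                        if (nh->nlmsg_type == NLMSG_DONE) {
                                close(fd);
                                return 0;
                        }
                        if (nh->nlmsg_type == RTM_NEWLINK)
                                printf("ifindex %d\n",
                                       ((struct ifinfomsg *)NLMSG_DATA(nh))->ifi_index);
                }
        }
        close(fd);
        return 0;
}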
+ */ +void rtnl_unregister_all(int protocol) +{ + struct rtnl_link __rcu **tab; + struct rtnl_link *link; + int msgindex; + + BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX); + + rtnl_lock(); + tab = rtnl_dereference(rtnl_msg_handlers[protocol]); + if (!tab) { + rtnl_unlock(); + return; + } + RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL); + for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) { + link = rtnl_dereference(tab[msgindex]); + if (!link) + continue; + + RCU_INIT_POINTER(tab[msgindex], NULL); + kfree_rcu(link, rcu); + } + rtnl_unlock(); + + synchronize_net(); + + kfree(tab); +} +EXPORT_SYMBOL_GPL(rtnl_unregister_all); + +static LIST_HEAD(link_ops); + +static const struct rtnl_link_ops *rtnl_link_ops_get(const char *kind) +{ + const struct rtnl_link_ops *ops; + + list_for_each_entry(ops, &link_ops, list) { + if (!strcmp(ops->kind, kind)) + return ops; + } + return NULL; +} + +/** + * __rtnl_link_register - Register rtnl_link_ops with rtnetlink. + * @ops: struct rtnl_link_ops * to register + * + * The caller must hold the rtnl_mutex. This function should be used + * by drivers that create devices during module initialization. It + * must be called before registering the devices. + * + * Returns 0 on success or a negative error code. + */ +int __rtnl_link_register(struct rtnl_link_ops *ops) +{ + if (rtnl_link_ops_get(ops->kind)) + return -EEXIST; + + /* The check for alloc/setup is here because if ops + * does not have that filled up, it is not possible + * to use the ops for creating device. So do not + * fill up dellink as well. That disables rtnl_dellink. + */ + if ((ops->alloc || ops->setup) && !ops->dellink) + ops->dellink = unregister_netdevice_queue; + + list_add_tail(&ops->list, &link_ops); + return 0; +} +EXPORT_SYMBOL_GPL(__rtnl_link_register); + +/** + * rtnl_link_register - Register rtnl_link_ops with rtnetlink. + * @ops: struct rtnl_link_ops * to register + * + * Returns 0 on success or a negative error code. + */ +int rtnl_link_register(struct rtnl_link_ops *ops) +{ + int err; + + /* Sanity-check max sizes to avoid stack buffer overflow. */ + if (WARN_ON(ops->maxtype > RTNL_MAX_TYPE || + ops->slave_maxtype > RTNL_SLAVE_MAX_TYPE)) + return -EINVAL; + + rtnl_lock(); + err = __rtnl_link_register(ops); + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL_GPL(rtnl_link_register); + +static void __rtnl_kill_links(struct net *net, struct rtnl_link_ops *ops) +{ + struct net_device *dev; + LIST_HEAD(list_kill); + + for_each_netdev(net, dev) { + if (dev->rtnl_link_ops == ops) + ops->dellink(dev, &list_kill); + } + unregister_netdevice_many(&list_kill); +} + +/** + * __rtnl_link_unregister - Unregister rtnl_link_ops from rtnetlink. + * @ops: struct rtnl_link_ops * to unregister + * + * The caller must hold the rtnl_mutex and guarantee net_namespace_list + * integrity (hold pernet_ops_rwsem for writing to close the race + * with setup_net() and cleanup_net()). + */ +void __rtnl_link_unregister(struct rtnl_link_ops *ops) +{ + struct net *net; + + for_each_net(net) { + __rtnl_kill_links(net, ops); + } + list_del(&ops->list); +} +EXPORT_SYMBOL_GPL(__rtnl_link_unregister); + +/* Return with the rtnl_lock held when there are no network + * devices unregistering in any network namespace. 
+ */ +static void rtnl_lock_unregistering_all(void) +{ + struct net *net; + bool unregistering; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(&netdev_unregistering_wq, &wait); + for (;;) { + unregistering = false; + rtnl_lock(); + /* We held write locked pernet_ops_rwsem, and parallel + * setup_net() and cleanup_net() are not possible. + */ + for_each_net(net) { + if (atomic_read(&net->dev_unreg_count) > 0) { + unregistering = true; + break; + } + } + if (!unregistering) + break; + __rtnl_unlock(); + + wait_woken(&wait, TASK_UNINTERRUPTIBLE, MAX_SCHEDULE_TIMEOUT); + } + remove_wait_queue(&netdev_unregistering_wq, &wait); +} + +/** + * rtnl_link_unregister - Unregister rtnl_link_ops from rtnetlink. + * @ops: struct rtnl_link_ops * to unregister + */ +void rtnl_link_unregister(struct rtnl_link_ops *ops) +{ + /* Close the race with setup_net() and cleanup_net() */ + down_write(&pernet_ops_rwsem); + rtnl_lock_unregistering_all(); + __rtnl_link_unregister(ops); + rtnl_unlock(); + up_write(&pernet_ops_rwsem); +} +EXPORT_SYMBOL_GPL(rtnl_link_unregister); + +static size_t rtnl_link_get_slave_info_data_size(const struct net_device *dev) +{ + struct net_device *master_dev; + const struct rtnl_link_ops *ops; + size_t size = 0; + + rcu_read_lock(); + + master_dev = netdev_master_upper_dev_get_rcu((struct net_device *)dev); + if (!master_dev) + goto out; + + ops = master_dev->rtnl_link_ops; + if (!ops || !ops->get_slave_size) + goto out; + /* IFLA_INFO_SLAVE_DATA + nested data */ + size = nla_total_size(sizeof(struct nlattr)) + + ops->get_slave_size(master_dev, dev); + +out: + rcu_read_unlock(); + return size; +} + +static size_t rtnl_link_get_size(const struct net_device *dev) +{ + const struct rtnl_link_ops *ops = dev->rtnl_link_ops; + size_t size; + + if (!ops) + return 0; + + size = nla_total_size(sizeof(struct nlattr)) + /* IFLA_LINKINFO */ + nla_total_size(strlen(ops->kind) + 1); /* IFLA_INFO_KIND */ + + if (ops->get_size) + /* IFLA_INFO_DATA + nested data */ + size += nla_total_size(sizeof(struct nlattr)) + + ops->get_size(dev); + + if (ops->get_xstats_size) + /* IFLA_INFO_XSTATS */ + size += nla_total_size(ops->get_xstats_size(dev)); + + size += rtnl_link_get_slave_info_data_size(dev); + + return size; +} + +static LIST_HEAD(rtnl_af_ops); + +static const struct rtnl_af_ops *rtnl_af_lookup(const int family) +{ + const struct rtnl_af_ops *ops; + + ASSERT_RTNL(); + + list_for_each_entry(ops, &rtnl_af_ops, list) { + if (ops->family == family) + return ops; + } + + return NULL; +} + +/** + * rtnl_af_register - Register rtnl_af_ops with rtnetlink. + * @ops: struct rtnl_af_ops * to register + * + * Returns 0 on success or a negative error code. + */ +void rtnl_af_register(struct rtnl_af_ops *ops) +{ + rtnl_lock(); + list_add_tail_rcu(&ops->list, &rtnl_af_ops); + rtnl_unlock(); +} +EXPORT_SYMBOL_GPL(rtnl_af_register); + +/** + * rtnl_af_unregister - Unregister rtnl_af_ops from rtnetlink. 
+ * @ops: struct rtnl_af_ops * to unregister + */ +void rtnl_af_unregister(struct rtnl_af_ops *ops) +{ + rtnl_lock(); + list_del_rcu(&ops->list); + rtnl_unlock(); + + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(rtnl_af_unregister); + +static size_t rtnl_link_get_af_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + struct rtnl_af_ops *af_ops; + size_t size; + + /* IFLA_AF_SPEC */ + size = nla_total_size(sizeof(struct nlattr)); + + rcu_read_lock(); + list_for_each_entry_rcu(af_ops, &rtnl_af_ops, list) { + if (af_ops->get_link_af_size) { + /* AF_* + nested data */ + size += nla_total_size(sizeof(struct nlattr)) + + af_ops->get_link_af_size(dev, ext_filter_mask); + } + } + rcu_read_unlock(); + + return size; +} + +static bool rtnl_have_link_slave_info(const struct net_device *dev) +{ + struct net_device *master_dev; + bool ret = false; + + rcu_read_lock(); + + master_dev = netdev_master_upper_dev_get_rcu((struct net_device *)dev); + if (master_dev && master_dev->rtnl_link_ops) + ret = true; + rcu_read_unlock(); + return ret; +} + +static int rtnl_link_slave_info_fill(struct sk_buff *skb, + const struct net_device *dev) +{ + struct net_device *master_dev; + const struct rtnl_link_ops *ops; + struct nlattr *slave_data; + int err; + + master_dev = netdev_master_upper_dev_get((struct net_device *) dev); + if (!master_dev) + return 0; + ops = master_dev->rtnl_link_ops; + if (!ops) + return 0; + if (nla_put_string(skb, IFLA_INFO_SLAVE_KIND, ops->kind) < 0) + return -EMSGSIZE; + if (ops->fill_slave_info) { + slave_data = nla_nest_start_noflag(skb, IFLA_INFO_SLAVE_DATA); + if (!slave_data) + return -EMSGSIZE; + err = ops->fill_slave_info(skb, master_dev, dev); + if (err < 0) + goto err_cancel_slave_data; + nla_nest_end(skb, slave_data); + } + return 0; + +err_cancel_slave_data: + nla_nest_cancel(skb, slave_data); + return err; +} + +static int rtnl_link_info_fill(struct sk_buff *skb, + const struct net_device *dev) +{ + const struct rtnl_link_ops *ops = dev->rtnl_link_ops; + struct nlattr *data; + int err; + + if (!ops) + return 0; + if (nla_put_string(skb, IFLA_INFO_KIND, ops->kind) < 0) + return -EMSGSIZE; + if (ops->fill_xstats) { + err = ops->fill_xstats(skb, dev); + if (err < 0) + return err; + } + if (ops->fill_info) { + data = nla_nest_start_noflag(skb, IFLA_INFO_DATA); + if (data == NULL) + return -EMSGSIZE; + err = ops->fill_info(skb, dev); + if (err < 0) + goto err_cancel_data; + nla_nest_end(skb, data); + } + return 0; + +err_cancel_data: + nla_nest_cancel(skb, data); + return err; +} + +static int rtnl_link_fill(struct sk_buff *skb, const struct net_device *dev) +{ + struct nlattr *linkinfo; + int err = -EMSGSIZE; + + linkinfo = nla_nest_start_noflag(skb, IFLA_LINKINFO); + if (linkinfo == NULL) + goto out; + + err = rtnl_link_info_fill(skb, dev); + if (err < 0) + goto err_cancel_link; + + err = rtnl_link_slave_info_fill(skb, dev); + if (err < 0) + goto err_cancel_link; + + nla_nest_end(skb, linkinfo); + return 0; + +err_cancel_link: + nla_nest_cancel(skb, linkinfo); +out: + return err; +} + +int rtnetlink_send(struct sk_buff *skb, struct net *net, u32 pid, unsigned int group, int echo) +{ + struct sock *rtnl = net->rtnl; + + return nlmsg_notify(rtnl, skb, pid, group, echo, GFP_KERNEL); +} + +int rtnl_unicast(struct sk_buff *skb, struct net *net, u32 pid) +{ + struct sock *rtnl = net->rtnl; + + return nlmsg_unicast(rtnl, skb, pid); +} +EXPORT_SYMBOL(rtnl_unicast); + +void rtnl_notify(struct sk_buff *skb, struct net *net, u32 pid, u32 group, + struct nlmsghdr *nlh, gfp_t flags) +{ 
+ struct sock *rtnl = net->rtnl; + + nlmsg_notify(rtnl, skb, pid, group, nlmsg_report(nlh), flags); +} +EXPORT_SYMBOL(rtnl_notify); + +void rtnl_set_sk_err(struct net *net, u32 group, int error) +{ + struct sock *rtnl = net->rtnl; + + netlink_set_err(rtnl, 0, group, error); +} +EXPORT_SYMBOL(rtnl_set_sk_err); + +int rtnetlink_put_metrics(struct sk_buff *skb, u32 *metrics) +{ + struct nlattr *mx; + int i, valid = 0; + + /* nothing is dumped for dst_default_metrics, so just skip the loop */ + if (metrics == dst_default_metrics.metrics) + return 0; + + mx = nla_nest_start_noflag(skb, RTA_METRICS); + if (mx == NULL) + return -ENOBUFS; + + for (i = 0; i < RTAX_MAX; i++) { + if (metrics[i]) { + if (i == RTAX_CC_ALGO - 1) { + char tmp[TCP_CA_NAME_MAX], *name; + + name = tcp_ca_get_name_by_key(metrics[i], tmp); + if (!name) + continue; + if (nla_put_string(skb, i + 1, name)) + goto nla_put_failure; + } else if (i == RTAX_FEATURES - 1) { + u32 user_features = metrics[i] & RTAX_FEATURE_MASK; + + if (!user_features) + continue; + BUILD_BUG_ON(RTAX_FEATURE_MASK & DST_FEATURE_MASK); + if (nla_put_u32(skb, i + 1, user_features)) + goto nla_put_failure; + } else { + if (nla_put_u32(skb, i + 1, metrics[i])) + goto nla_put_failure; + } + valid++; + } + } + + if (!valid) { + nla_nest_cancel(skb, mx); + return 0; + } + + return nla_nest_end(skb, mx); + +nla_put_failure: + nla_nest_cancel(skb, mx); + return -EMSGSIZE; +} +EXPORT_SYMBOL(rtnetlink_put_metrics); + +int rtnl_put_cacheinfo(struct sk_buff *skb, struct dst_entry *dst, u32 id, + long expires, u32 error) +{ + struct rta_cacheinfo ci = { + .rta_error = error, + .rta_id = id, + }; + + if (dst) { + ci.rta_lastuse = jiffies_delta_to_clock_t(jiffies - dst->lastuse); + ci.rta_used = dst->__use; + ci.rta_clntref = atomic_read(&dst->__refcnt); + } + if (expires) { + unsigned long clock; + + clock = jiffies_to_clock_t(abs(expires)); + clock = min_t(unsigned long, clock, INT_MAX); + ci.rta_expires = (expires > 0) ? 
clock : -clock; + } + return nla_put(skb, RTA_CACHEINFO, sizeof(ci), &ci); +} +EXPORT_SYMBOL_GPL(rtnl_put_cacheinfo); + +static void set_operstate(struct net_device *dev, unsigned char transition) +{ + unsigned char operstate = dev->operstate; + + switch (transition) { + case IF_OPER_UP: + if ((operstate == IF_OPER_DORMANT || + operstate == IF_OPER_TESTING || + operstate == IF_OPER_UNKNOWN) && + !netif_dormant(dev) && !netif_testing(dev)) + operstate = IF_OPER_UP; + break; + + case IF_OPER_TESTING: + if (netif_oper_up(dev)) + operstate = IF_OPER_TESTING; + break; + + case IF_OPER_DORMANT: + if (netif_oper_up(dev)) + operstate = IF_OPER_DORMANT; + break; + } + + if (dev->operstate != operstate) { + write_lock(&dev_base_lock); + dev->operstate = operstate; + write_unlock(&dev_base_lock); + netdev_state_change(dev); + } +} + +static unsigned int rtnl_dev_get_flags(const struct net_device *dev) +{ + return (dev->flags & ~(IFF_PROMISC | IFF_ALLMULTI)) | + (dev->gflags & (IFF_PROMISC | IFF_ALLMULTI)); +} + +static unsigned int rtnl_dev_combine_flags(const struct net_device *dev, + const struct ifinfomsg *ifm) +{ + unsigned int flags = ifm->ifi_flags; + + /* bugwards compatibility: ifi_change == 0 is treated as ~0 */ + if (ifm->ifi_change) + flags = (flags & ifm->ifi_change) | + (rtnl_dev_get_flags(dev) & ~ifm->ifi_change); + + return flags; +} + +static void copy_rtnl_link_stats(struct rtnl_link_stats *a, + const struct rtnl_link_stats64 *b) +{ + a->rx_packets = b->rx_packets; + a->tx_packets = b->tx_packets; + a->rx_bytes = b->rx_bytes; + a->tx_bytes = b->tx_bytes; + a->rx_errors = b->rx_errors; + a->tx_errors = b->tx_errors; + a->rx_dropped = b->rx_dropped; + a->tx_dropped = b->tx_dropped; + + a->multicast = b->multicast; + a->collisions = b->collisions; + + a->rx_length_errors = b->rx_length_errors; + a->rx_over_errors = b->rx_over_errors; + a->rx_crc_errors = b->rx_crc_errors; + a->rx_frame_errors = b->rx_frame_errors; + a->rx_fifo_errors = b->rx_fifo_errors; + a->rx_missed_errors = b->rx_missed_errors; + + a->tx_aborted_errors = b->tx_aborted_errors; + a->tx_carrier_errors = b->tx_carrier_errors; + a->tx_fifo_errors = b->tx_fifo_errors; + a->tx_heartbeat_errors = b->tx_heartbeat_errors; + a->tx_window_errors = b->tx_window_errors; + + a->rx_compressed = b->rx_compressed; + a->tx_compressed = b->tx_compressed; + + a->rx_nohandler = b->rx_nohandler; +} + +/* All VF info */ +static inline int rtnl_vfinfo_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + if (dev->dev.parent && (ext_filter_mask & RTEXT_FILTER_VF)) { + int num_vfs = dev_num_vf(dev->dev.parent); + size_t size = nla_total_size(0); + size += num_vfs * + (nla_total_size(0) + + nla_total_size(sizeof(struct ifla_vf_mac)) + + nla_total_size(sizeof(struct ifla_vf_broadcast)) + + nla_total_size(sizeof(struct ifla_vf_vlan)) + + nla_total_size(0) + /* nest IFLA_VF_VLAN_LIST */ + nla_total_size(MAX_VLAN_LIST_LEN * + sizeof(struct ifla_vf_vlan_info)) + + nla_total_size(sizeof(struct ifla_vf_spoofchk)) + + nla_total_size(sizeof(struct ifla_vf_tx_rate)) + + nla_total_size(sizeof(struct ifla_vf_rate)) + + nla_total_size(sizeof(struct ifla_vf_link_state)) + + nla_total_size(sizeof(struct ifla_vf_rss_query_en)) + + nla_total_size(sizeof(struct ifla_vf_trust))); + if (~ext_filter_mask & RTEXT_FILTER_SKIP_STATS) { + size += num_vfs * + (nla_total_size(0) + /* nest IFLA_VF_STATS */ + /* IFLA_VF_STATS_RX_PACKETS */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_TX_PACKETS */ + nla_total_size_64bit(sizeof(__u64)) + + /* 
IFLA_VF_STATS_RX_BYTES */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_TX_BYTES */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_BROADCAST */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_MULTICAST */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_RX_DROPPED */ + nla_total_size_64bit(sizeof(__u64)) + + /* IFLA_VF_STATS_TX_DROPPED */ + nla_total_size_64bit(sizeof(__u64))); + } + return size; + } else + return 0; +} + +static size_t rtnl_port_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + size_t port_size = nla_total_size(4) /* PORT_VF */ + + nla_total_size(PORT_PROFILE_MAX) /* PORT_PROFILE */ + + nla_total_size(PORT_UUID_MAX) /* PORT_INSTANCE_UUID */ + + nla_total_size(PORT_UUID_MAX) /* PORT_HOST_UUID */ + + nla_total_size(1) /* PROT_VDP_REQUEST */ + + nla_total_size(2); /* PORT_VDP_RESPONSE */ + size_t vf_ports_size = nla_total_size(sizeof(struct nlattr)); + size_t vf_port_size = nla_total_size(sizeof(struct nlattr)) + + port_size; + size_t port_self_size = nla_total_size(sizeof(struct nlattr)) + + port_size; + + if (!dev->netdev_ops->ndo_get_vf_port || !dev->dev.parent || + !(ext_filter_mask & RTEXT_FILTER_VF)) + return 0; + if (dev_num_vf(dev->dev.parent)) + return port_self_size + vf_ports_size + + vf_port_size * dev_num_vf(dev->dev.parent); + else + return port_self_size; +} + +static size_t rtnl_xdp_size(void) +{ + size_t xdp_size = nla_total_size(0) + /* nest IFLA_XDP */ + nla_total_size(1) + /* XDP_ATTACHED */ + nla_total_size(4) + /* XDP_PROG_ID (or 1st mode) */ + nla_total_size(4); /* XDP__PROG_ID */ + + return xdp_size; +} + +static size_t rtnl_prop_list_size(const struct net_device *dev) +{ + struct netdev_name_node *name_node; + size_t size; + + if (list_empty(&dev->name_node->list)) + return 0; + size = nla_total_size(0); + list_for_each_entry(name_node, &dev->name_node->list, list) + size += nla_total_size(ALTIFNAMSIZ); + return size; +} + +static size_t rtnl_proto_down_size(const struct net_device *dev) +{ + size_t size = nla_total_size(1); + + if (dev->proto_down_reason) + size += nla_total_size(0) + nla_total_size(4); + + return size; +} + +static noinline size_t if_nlmsg_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + return NLMSG_ALIGN(sizeof(struct ifinfomsg)) + + nla_total_size(IFNAMSIZ) /* IFLA_IFNAME */ + + nla_total_size(IFALIASZ) /* IFLA_IFALIAS */ + + nla_total_size(IFNAMSIZ) /* IFLA_QDISC */ + + nla_total_size_64bit(sizeof(struct rtnl_link_ifmap)) + + nla_total_size(sizeof(struct rtnl_link_stats)) + + nla_total_size_64bit(sizeof(struct rtnl_link_stats64)) + + nla_total_size(MAX_ADDR_LEN) /* IFLA_ADDRESS */ + + nla_total_size(MAX_ADDR_LEN) /* IFLA_BROADCAST */ + + nla_total_size(4) /* IFLA_TXQLEN */ + + nla_total_size(4) /* IFLA_WEIGHT */ + + nla_total_size(4) /* IFLA_MTU */ + + nla_total_size(4) /* IFLA_LINK */ + + nla_total_size(4) /* IFLA_MASTER */ + + nla_total_size(1) /* IFLA_CARRIER */ + + nla_total_size(4) /* IFLA_PROMISCUITY */ + + nla_total_size(4) /* IFLA_ALLMULTI */ + + nla_total_size(4) /* IFLA_NUM_TX_QUEUES */ + + nla_total_size(4) /* IFLA_NUM_RX_QUEUES */ + + nla_total_size(4) /* IFLA_GSO_MAX_SEGS */ + + nla_total_size(4) /* IFLA_GSO_MAX_SIZE */ + + nla_total_size(4) /* IFLA_GRO_MAX_SIZE */ + + nla_total_size(4) /* IFLA_TSO_MAX_SIZE */ + + nla_total_size(4) /* IFLA_TSO_MAX_SEGS */ + + nla_total_size(1) /* IFLA_OPERSTATE */ + + nla_total_size(1) /* IFLA_LINKMODE */ + + nla_total_size(4) /* IFLA_CARRIER_CHANGES */ + + nla_total_size(4) /* IFLA_LINK_NETNSID */ + + 
nla_total_size(4) /* IFLA_GROUP */ + + nla_total_size(ext_filter_mask + & RTEXT_FILTER_VF ? 4 : 0) /* IFLA_NUM_VF */ + + rtnl_vfinfo_size(dev, ext_filter_mask) /* IFLA_VFINFO_LIST */ + + rtnl_port_size(dev, ext_filter_mask) /* IFLA_VF_PORTS + IFLA_PORT_SELF */ + + rtnl_link_get_size(dev) /* IFLA_LINKINFO */ + + rtnl_link_get_af_size(dev, ext_filter_mask) /* IFLA_AF_SPEC */ + + nla_total_size(MAX_PHYS_ITEM_ID_LEN) /* IFLA_PHYS_PORT_ID */ + + nla_total_size(MAX_PHYS_ITEM_ID_LEN) /* IFLA_PHYS_SWITCH_ID */ + + nla_total_size(IFNAMSIZ) /* IFLA_PHYS_PORT_NAME */ + + rtnl_xdp_size() /* IFLA_XDP */ + + nla_total_size(4) /* IFLA_EVENT */ + + nla_total_size(4) /* IFLA_NEW_NETNSID */ + + nla_total_size(4) /* IFLA_NEW_IFINDEX */ + + rtnl_proto_down_size(dev) /* proto down */ + + nla_total_size(4) /* IFLA_TARGET_NETNSID */ + + nla_total_size(4) /* IFLA_CARRIER_UP_COUNT */ + + nla_total_size(4) /* IFLA_CARRIER_DOWN_COUNT */ + + nla_total_size(4) /* IFLA_MIN_MTU */ + + nla_total_size(4) /* IFLA_MAX_MTU */ + + rtnl_prop_list_size(dev) + + nla_total_size(MAX_ADDR_LEN) /* IFLA_PERM_ADDRESS */ + + 0; +} + +static int rtnl_vf_ports_fill(struct sk_buff *skb, struct net_device *dev) +{ + struct nlattr *vf_ports; + struct nlattr *vf_port; + int vf; + int err; + + vf_ports = nla_nest_start_noflag(skb, IFLA_VF_PORTS); + if (!vf_ports) + return -EMSGSIZE; + + for (vf = 0; vf < dev_num_vf(dev->dev.parent); vf++) { + vf_port = nla_nest_start_noflag(skb, IFLA_VF_PORT); + if (!vf_port) + goto nla_put_failure; + if (nla_put_u32(skb, IFLA_PORT_VF, vf)) + goto nla_put_failure; + err = dev->netdev_ops->ndo_get_vf_port(dev, vf, skb); + if (err == -EMSGSIZE) + goto nla_put_failure; + if (err) { + nla_nest_cancel(skb, vf_port); + continue; + } + nla_nest_end(skb, vf_port); + } + + nla_nest_end(skb, vf_ports); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, vf_ports); + return -EMSGSIZE; +} + +static int rtnl_port_self_fill(struct sk_buff *skb, struct net_device *dev) +{ + struct nlattr *port_self; + int err; + + port_self = nla_nest_start_noflag(skb, IFLA_PORT_SELF); + if (!port_self) + return -EMSGSIZE; + + err = dev->netdev_ops->ndo_get_vf_port(dev, PORT_SELF_VF, skb); + if (err) { + nla_nest_cancel(skb, port_self); + return (err == -EMSGSIZE) ? 
err : 0; + } + + nla_nest_end(skb, port_self); + + return 0; +} + +static int rtnl_port_fill(struct sk_buff *skb, struct net_device *dev, + u32 ext_filter_mask) +{ + int err; + + if (!dev->netdev_ops->ndo_get_vf_port || !dev->dev.parent || + !(ext_filter_mask & RTEXT_FILTER_VF)) + return 0; + + err = rtnl_port_self_fill(skb, dev); + if (err) + return err; + + if (dev_num_vf(dev->dev.parent)) { + err = rtnl_vf_ports_fill(skb, dev); + if (err) + return err; + } + + return 0; +} + +static int rtnl_phys_port_id_fill(struct sk_buff *skb, struct net_device *dev) +{ + int err; + struct netdev_phys_item_id ppid; + + err = dev_get_phys_port_id(dev, &ppid); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + + if (nla_put(skb, IFLA_PHYS_PORT_ID, ppid.id_len, ppid.id)) + return -EMSGSIZE; + + return 0; +} + +static int rtnl_phys_port_name_fill(struct sk_buff *skb, struct net_device *dev) +{ + char name[IFNAMSIZ]; + int err; + + err = dev_get_phys_port_name(dev, name, sizeof(name)); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + + if (nla_put_string(skb, IFLA_PHYS_PORT_NAME, name)) + return -EMSGSIZE; + + return 0; +} + +static int rtnl_phys_switch_id_fill(struct sk_buff *skb, struct net_device *dev) +{ + struct netdev_phys_item_id ppid = { }; + int err; + + err = dev_get_port_parent_id(dev, &ppid, false); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + + if (nla_put(skb, IFLA_PHYS_SWITCH_ID, ppid.id_len, ppid.id)) + return -EMSGSIZE; + + return 0; +} + +static noinline_for_stack int rtnl_fill_stats(struct sk_buff *skb, + struct net_device *dev) +{ + struct rtnl_link_stats64 *sp; + struct nlattr *attr; + + attr = nla_reserve_64bit(skb, IFLA_STATS64, + sizeof(struct rtnl_link_stats64), IFLA_PAD); + if (!attr) + return -EMSGSIZE; + + sp = nla_data(attr); + dev_get_stats(dev, sp); + + attr = nla_reserve(skb, IFLA_STATS, + sizeof(struct rtnl_link_stats)); + if (!attr) + return -EMSGSIZE; + + copy_rtnl_link_stats(nla_data(attr), sp); + + return 0; +} + +static noinline_for_stack int rtnl_fill_vfinfo(struct sk_buff *skb, + struct net_device *dev, + int vfs_num, + struct nlattr *vfinfo, + u32 ext_filter_mask) +{ + struct ifla_vf_rss_query_en vf_rss_query_en; + struct nlattr *vf, *vfstats, *vfvlanlist; + struct ifla_vf_link_state vf_linkstate; + struct ifla_vf_vlan_info vf_vlan_info; + struct ifla_vf_spoofchk vf_spoofchk; + struct ifla_vf_tx_rate vf_tx_rate; + struct ifla_vf_stats vf_stats; + struct ifla_vf_trust vf_trust; + struct ifla_vf_vlan vf_vlan; + struct ifla_vf_rate vf_rate; + struct ifla_vf_mac vf_mac; + struct ifla_vf_broadcast vf_broadcast; + struct ifla_vf_info ivi; + struct ifla_vf_guid node_guid; + struct ifla_vf_guid port_guid; + + memset(&ivi, 0, sizeof(ivi)); + + /* Not all SR-IOV capable drivers support the + * spoofcheck and "RSS query enable" query. Preset to + * -1 so the user space tool can detect that the driver + * didn't report anything. 
+ */ + ivi.spoofchk = -1; + ivi.rss_query_en = -1; + ivi.trusted = -1; + /* The default value for VF link state is "auto" + * IFLA_VF_LINK_STATE_AUTO which equals zero + */ + ivi.linkstate = 0; + /* VLAN Protocol by default is 802.1Q */ + ivi.vlan_proto = htons(ETH_P_8021Q); + if (dev->netdev_ops->ndo_get_vf_config(dev, vfs_num, &ivi)) + return 0; + + memset(&vf_vlan_info, 0, sizeof(vf_vlan_info)); + memset(&node_guid, 0, sizeof(node_guid)); + memset(&port_guid, 0, sizeof(port_guid)); + + vf_mac.vf = + vf_vlan.vf = + vf_vlan_info.vf = + vf_rate.vf = + vf_tx_rate.vf = + vf_spoofchk.vf = + vf_linkstate.vf = + vf_rss_query_en.vf = + vf_trust.vf = + node_guid.vf = + port_guid.vf = ivi.vf; + + memcpy(vf_mac.mac, ivi.mac, sizeof(ivi.mac)); + memcpy(vf_broadcast.broadcast, dev->broadcast, dev->addr_len); + vf_vlan.vlan = ivi.vlan; + vf_vlan.qos = ivi.qos; + vf_vlan_info.vlan = ivi.vlan; + vf_vlan_info.qos = ivi.qos; + vf_vlan_info.vlan_proto = ivi.vlan_proto; + vf_tx_rate.rate = ivi.max_tx_rate; + vf_rate.min_tx_rate = ivi.min_tx_rate; + vf_rate.max_tx_rate = ivi.max_tx_rate; + vf_spoofchk.setting = ivi.spoofchk; + vf_linkstate.link_state = ivi.linkstate; + vf_rss_query_en.setting = ivi.rss_query_en; + vf_trust.setting = ivi.trusted; + vf = nla_nest_start_noflag(skb, IFLA_VF_INFO); + if (!vf) + goto nla_put_vfinfo_failure; + if (nla_put(skb, IFLA_VF_MAC, sizeof(vf_mac), &vf_mac) || + nla_put(skb, IFLA_VF_BROADCAST, sizeof(vf_broadcast), &vf_broadcast) || + nla_put(skb, IFLA_VF_VLAN, sizeof(vf_vlan), &vf_vlan) || + nla_put(skb, IFLA_VF_RATE, sizeof(vf_rate), + &vf_rate) || + nla_put(skb, IFLA_VF_TX_RATE, sizeof(vf_tx_rate), + &vf_tx_rate) || + nla_put(skb, IFLA_VF_SPOOFCHK, sizeof(vf_spoofchk), + &vf_spoofchk) || + nla_put(skb, IFLA_VF_LINK_STATE, sizeof(vf_linkstate), + &vf_linkstate) || + nla_put(skb, IFLA_VF_RSS_QUERY_EN, + sizeof(vf_rss_query_en), + &vf_rss_query_en) || + nla_put(skb, IFLA_VF_TRUST, + sizeof(vf_trust), &vf_trust)) + goto nla_put_vf_failure; + + if (dev->netdev_ops->ndo_get_vf_guid && + !dev->netdev_ops->ndo_get_vf_guid(dev, vfs_num, &node_guid, + &port_guid)) { + if (nla_put(skb, IFLA_VF_IB_NODE_GUID, sizeof(node_guid), + &node_guid) || + nla_put(skb, IFLA_VF_IB_PORT_GUID, sizeof(port_guid), + &port_guid)) + goto nla_put_vf_failure; + } + vfvlanlist = nla_nest_start_noflag(skb, IFLA_VF_VLAN_LIST); + if (!vfvlanlist) + goto nla_put_vf_failure; + if (nla_put(skb, IFLA_VF_VLAN_INFO, sizeof(vf_vlan_info), + &vf_vlan_info)) { + nla_nest_cancel(skb, vfvlanlist); + goto nla_put_vf_failure; + } + nla_nest_end(skb, vfvlanlist); + if (~ext_filter_mask & RTEXT_FILTER_SKIP_STATS) { + memset(&vf_stats, 0, sizeof(vf_stats)); + if (dev->netdev_ops->ndo_get_vf_stats) + dev->netdev_ops->ndo_get_vf_stats(dev, vfs_num, + &vf_stats); + vfstats = nla_nest_start_noflag(skb, IFLA_VF_STATS); + if (!vfstats) + goto nla_put_vf_failure; + if (nla_put_u64_64bit(skb, IFLA_VF_STATS_RX_PACKETS, + vf_stats.rx_packets, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_TX_PACKETS, + vf_stats.tx_packets, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_RX_BYTES, + vf_stats.rx_bytes, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_TX_BYTES, + vf_stats.tx_bytes, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_BROADCAST, + vf_stats.broadcast, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_MULTICAST, + vf_stats.multicast, IFLA_VF_STATS_PAD) || + nla_put_u64_64bit(skb, IFLA_VF_STATS_RX_DROPPED, + vf_stats.rx_dropped, IFLA_VF_STATS_PAD) || + 
nla_put_u64_64bit(skb, IFLA_VF_STATS_TX_DROPPED, + vf_stats.tx_dropped, IFLA_VF_STATS_PAD)) { + nla_nest_cancel(skb, vfstats); + goto nla_put_vf_failure; + } + nla_nest_end(skb, vfstats); + } + nla_nest_end(skb, vf); + return 0; + +nla_put_vf_failure: + nla_nest_cancel(skb, vf); +nla_put_vfinfo_failure: + nla_nest_cancel(skb, vfinfo); + return -EMSGSIZE; +} + +static noinline_for_stack int rtnl_fill_vf(struct sk_buff *skb, + struct net_device *dev, + u32 ext_filter_mask) +{ + struct nlattr *vfinfo; + int i, num_vfs; + + if (!dev->dev.parent || ((ext_filter_mask & RTEXT_FILTER_VF) == 0)) + return 0; + + num_vfs = dev_num_vf(dev->dev.parent); + if (nla_put_u32(skb, IFLA_NUM_VF, num_vfs)) + return -EMSGSIZE; + + if (!dev->netdev_ops->ndo_get_vf_config) + return 0; + + vfinfo = nla_nest_start_noflag(skb, IFLA_VFINFO_LIST); + if (!vfinfo) + return -EMSGSIZE; + + for (i = 0; i < num_vfs; i++) { + if (rtnl_fill_vfinfo(skb, dev, i, vfinfo, ext_filter_mask)) + return -EMSGSIZE; + } + + nla_nest_end(skb, vfinfo); + return 0; +} + +static int rtnl_fill_link_ifmap(struct sk_buff *skb, struct net_device *dev) +{ + struct rtnl_link_ifmap map; + + memset(&map, 0, sizeof(map)); + map.mem_start = dev->mem_start; + map.mem_end = dev->mem_end; + map.base_addr = dev->base_addr; + map.irq = dev->irq; + map.dma = dev->dma; + map.port = dev->if_port; + + if (nla_put_64bit(skb, IFLA_MAP, sizeof(map), &map, IFLA_PAD)) + return -EMSGSIZE; + + return 0; +} + +static u32 rtnl_xdp_prog_skb(struct net_device *dev) +{ + const struct bpf_prog *generic_xdp_prog; + + ASSERT_RTNL(); + + generic_xdp_prog = rtnl_dereference(dev->xdp_prog); + if (!generic_xdp_prog) + return 0; + return generic_xdp_prog->aux->id; +} + +static u32 rtnl_xdp_prog_drv(struct net_device *dev) +{ + return dev_xdp_prog_id(dev, XDP_MODE_DRV); +} + +static u32 rtnl_xdp_prog_hw(struct net_device *dev) +{ + return dev_xdp_prog_id(dev, XDP_MODE_HW); +} + +static int rtnl_xdp_report_one(struct sk_buff *skb, struct net_device *dev, + u32 *prog_id, u8 *mode, u8 tgt_mode, u32 attr, + u32 (*get_prog_id)(struct net_device *dev)) +{ + u32 curr_id; + int err; + + curr_id = get_prog_id(dev); + if (!curr_id) + return 0; + + *prog_id = curr_id; + err = nla_put_u32(skb, attr, curr_id); + if (err) + return err; + + if (*mode != XDP_ATTACHED_NONE) + *mode = XDP_ATTACHED_MULTI; + else + *mode = tgt_mode; + + return 0; +} + +static int rtnl_xdp_fill(struct sk_buff *skb, struct net_device *dev) +{ + struct nlattr *xdp; + u32 prog_id; + int err; + u8 mode; + + xdp = nla_nest_start_noflag(skb, IFLA_XDP); + if (!xdp) + return -EMSGSIZE; + + prog_id = 0; + mode = XDP_ATTACHED_NONE; + err = rtnl_xdp_report_one(skb, dev, &prog_id, &mode, XDP_ATTACHED_SKB, + IFLA_XDP_SKB_PROG_ID, rtnl_xdp_prog_skb); + if (err) + goto err_cancel; + err = rtnl_xdp_report_one(skb, dev, &prog_id, &mode, XDP_ATTACHED_DRV, + IFLA_XDP_DRV_PROG_ID, rtnl_xdp_prog_drv); + if (err) + goto err_cancel; + err = rtnl_xdp_report_one(skb, dev, &prog_id, &mode, XDP_ATTACHED_HW, + IFLA_XDP_HW_PROG_ID, rtnl_xdp_prog_hw); + if (err) + goto err_cancel; + + err = nla_put_u8(skb, IFLA_XDP_ATTACHED, mode); + if (err) + goto err_cancel; + + if (prog_id && mode != XDP_ATTACHED_MULTI) { + err = nla_put_u32(skb, IFLA_XDP_PROG_ID, prog_id); + if (err) + goto err_cancel; + } + + nla_nest_end(skb, xdp); + return 0; + +err_cancel: + nla_nest_cancel(skb, xdp); + return err; +} + +static u32 rtnl_get_event(unsigned long event) +{ + u32 rtnl_event_type = IFLA_EVENT_NONE; + + switch (event) { + case NETDEV_REBOOT: + rtnl_event_type 
= IFLA_EVENT_REBOOT; + break; + case NETDEV_FEAT_CHANGE: + rtnl_event_type = IFLA_EVENT_FEATURES; + break; + case NETDEV_BONDING_FAILOVER: + rtnl_event_type = IFLA_EVENT_BONDING_FAILOVER; + break; + case NETDEV_NOTIFY_PEERS: + rtnl_event_type = IFLA_EVENT_NOTIFY_PEERS; + break; + case NETDEV_RESEND_IGMP: + rtnl_event_type = IFLA_EVENT_IGMP_RESEND; + break; + case NETDEV_CHANGEINFODATA: + rtnl_event_type = IFLA_EVENT_BONDING_OPTIONS; + break; + default: + break; + } + + return rtnl_event_type; +} + +static int put_master_ifindex(struct sk_buff *skb, struct net_device *dev) +{ + const struct net_device *upper_dev; + int ret = 0; + + rcu_read_lock(); + + upper_dev = netdev_master_upper_dev_get_rcu(dev); + if (upper_dev) + ret = nla_put_u32(skb, IFLA_MASTER, upper_dev->ifindex); + + rcu_read_unlock(); + return ret; +} + +static int nla_put_iflink(struct sk_buff *skb, const struct net_device *dev, + bool force) +{ + int ifindex = dev_get_iflink(dev); + + if (force || dev->ifindex != ifindex) + return nla_put_u32(skb, IFLA_LINK, ifindex); + + return 0; +} + +static noinline_for_stack int nla_put_ifalias(struct sk_buff *skb, + struct net_device *dev) +{ + char buf[IFALIASZ]; + int ret; + + ret = dev_get_alias(dev, buf, sizeof(buf)); + return ret > 0 ? nla_put_string(skb, IFLA_IFALIAS, buf) : 0; +} + +static int rtnl_fill_link_netnsid(struct sk_buff *skb, + const struct net_device *dev, + struct net *src_net, gfp_t gfp) +{ + bool put_iflink = false; + + if (dev->rtnl_link_ops && dev->rtnl_link_ops->get_link_net) { + struct net *link_net = dev->rtnl_link_ops->get_link_net(dev); + + if (!net_eq(dev_net(dev), link_net)) { + int id = peernet2id_alloc(src_net, link_net, gfp); + + if (nla_put_s32(skb, IFLA_LINK_NETNSID, id)) + return -EMSGSIZE; + + put_iflink = true; + } + } + + return nla_put_iflink(skb, dev, put_iflink); +} + +static int rtnl_fill_link_af(struct sk_buff *skb, + const struct net_device *dev, + u32 ext_filter_mask) +{ + const struct rtnl_af_ops *af_ops; + struct nlattr *af_spec; + + af_spec = nla_nest_start_noflag(skb, IFLA_AF_SPEC); + if (!af_spec) + return -EMSGSIZE; + + list_for_each_entry_rcu(af_ops, &rtnl_af_ops, list) { + struct nlattr *af; + int err; + + if (!af_ops->fill_link_af) + continue; + + af = nla_nest_start_noflag(skb, af_ops->family); + if (!af) + return -EMSGSIZE; + + err = af_ops->fill_link_af(skb, dev, ext_filter_mask); + /* + * Caller may return ENODATA to indicate that there + * was no data to be dumped. This is not an error, it + * means we should trim the attribute header and + * continue. 
+ */ + if (err == -ENODATA) + nla_nest_cancel(skb, af); + else if (err < 0) + return -EMSGSIZE; + + nla_nest_end(skb, af); + } + + nla_nest_end(skb, af_spec); + return 0; +} + +static int rtnl_fill_alt_ifnames(struct sk_buff *skb, + const struct net_device *dev) +{ + struct netdev_name_node *name_node; + int count = 0; + + list_for_each_entry(name_node, &dev->name_node->list, list) { + if (nla_put_string(skb, IFLA_ALT_IFNAME, name_node->name)) + return -EMSGSIZE; + count++; + } + return count; +} + +static int rtnl_fill_prop_list(struct sk_buff *skb, + const struct net_device *dev) +{ + struct nlattr *prop_list; + int ret; + + prop_list = nla_nest_start(skb, IFLA_PROP_LIST); + if (!prop_list) + return -EMSGSIZE; + + ret = rtnl_fill_alt_ifnames(skb, dev); + if (ret <= 0) + goto nest_cancel; + + nla_nest_end(skb, prop_list); + return 0; + +nest_cancel: + nla_nest_cancel(skb, prop_list); + return ret; +} + +static int rtnl_fill_proto_down(struct sk_buff *skb, + const struct net_device *dev) +{ + struct nlattr *pr; + u32 preason; + + if (nla_put_u8(skb, IFLA_PROTO_DOWN, dev->proto_down)) + goto nla_put_failure; + + preason = dev->proto_down_reason; + if (!preason) + return 0; + + pr = nla_nest_start(skb, IFLA_PROTO_DOWN_REASON); + if (!pr) + return -EMSGSIZE; + + if (nla_put_u32(skb, IFLA_PROTO_DOWN_REASON_VALUE, preason)) { + nla_nest_cancel(skb, pr); + goto nla_put_failure; + } + + nla_nest_end(skb, pr); + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int rtnl_fill_ifinfo(struct sk_buff *skb, + struct net_device *dev, struct net *src_net, + int type, u32 pid, u32 seq, u32 change, + unsigned int flags, u32 ext_filter_mask, + u32 event, int *new_nsid, int new_ifindex, + int tgt_netnsid, gfp_t gfp) +{ + struct ifinfomsg *ifm; + struct nlmsghdr *nlh; + struct Qdisc *qdisc; + + ASSERT_RTNL(); + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ifm), flags); + if (nlh == NULL) + return -EMSGSIZE; + + ifm = nlmsg_data(nlh); + ifm->ifi_family = AF_UNSPEC; + ifm->__ifi_pad = 0; + ifm->ifi_type = dev->type; + ifm->ifi_index = dev->ifindex; + ifm->ifi_flags = dev_get_flags(dev); + ifm->ifi_change = change; + + if (tgt_netnsid >= 0 && nla_put_s32(skb, IFLA_TARGET_NETNSID, tgt_netnsid)) + goto nla_put_failure; + + qdisc = rtnl_dereference(dev->qdisc); + if (nla_put_string(skb, IFLA_IFNAME, dev->name) || + nla_put_u32(skb, IFLA_TXQLEN, dev->tx_queue_len) || + nla_put_u8(skb, IFLA_OPERSTATE, + netif_running(dev) ? 
dev->operstate : IF_OPER_DOWN) || + nla_put_u8(skb, IFLA_LINKMODE, dev->link_mode) || + nla_put_u32(skb, IFLA_MTU, dev->mtu) || + nla_put_u32(skb, IFLA_MIN_MTU, dev->min_mtu) || + nla_put_u32(skb, IFLA_MAX_MTU, dev->max_mtu) || + nla_put_u32(skb, IFLA_GROUP, dev->group) || + nla_put_u32(skb, IFLA_PROMISCUITY, dev->promiscuity) || + nla_put_u32(skb, IFLA_ALLMULTI, dev->allmulti) || + nla_put_u32(skb, IFLA_NUM_TX_QUEUES, dev->num_tx_queues) || + nla_put_u32(skb, IFLA_GSO_MAX_SEGS, dev->gso_max_segs) || + nla_put_u32(skb, IFLA_GSO_MAX_SIZE, dev->gso_max_size) || + nla_put_u32(skb, IFLA_GRO_MAX_SIZE, dev->gro_max_size) || + nla_put_u32(skb, IFLA_TSO_MAX_SIZE, dev->tso_max_size) || + nla_put_u32(skb, IFLA_TSO_MAX_SEGS, dev->tso_max_segs) || +#ifdef CONFIG_RPS + nla_put_u32(skb, IFLA_NUM_RX_QUEUES, dev->num_rx_queues) || +#endif + put_master_ifindex(skb, dev) || + nla_put_u8(skb, IFLA_CARRIER, netif_carrier_ok(dev)) || + (qdisc && + nla_put_string(skb, IFLA_QDISC, qdisc->ops->id)) || + nla_put_ifalias(skb, dev) || + nla_put_u32(skb, IFLA_CARRIER_CHANGES, + atomic_read(&dev->carrier_up_count) + + atomic_read(&dev->carrier_down_count)) || + nla_put_u32(skb, IFLA_CARRIER_UP_COUNT, + atomic_read(&dev->carrier_up_count)) || + nla_put_u32(skb, IFLA_CARRIER_DOWN_COUNT, + atomic_read(&dev->carrier_down_count))) + goto nla_put_failure; + + if (rtnl_fill_proto_down(skb, dev)) + goto nla_put_failure; + + if (event != IFLA_EVENT_NONE) { + if (nla_put_u32(skb, IFLA_EVENT, event)) + goto nla_put_failure; + } + + if (rtnl_fill_link_ifmap(skb, dev)) + goto nla_put_failure; + + if (dev->addr_len) { + if (nla_put(skb, IFLA_ADDRESS, dev->addr_len, dev->dev_addr) || + nla_put(skb, IFLA_BROADCAST, dev->addr_len, dev->broadcast)) + goto nla_put_failure; + } + + if (rtnl_phys_port_id_fill(skb, dev)) + goto nla_put_failure; + + if (rtnl_phys_port_name_fill(skb, dev)) + goto nla_put_failure; + + if (rtnl_phys_switch_id_fill(skb, dev)) + goto nla_put_failure; + + if (rtnl_fill_stats(skb, dev)) + goto nla_put_failure; + + if (rtnl_fill_vf(skb, dev, ext_filter_mask)) + goto nla_put_failure; + + if (rtnl_port_fill(skb, dev, ext_filter_mask)) + goto nla_put_failure; + + if (rtnl_xdp_fill(skb, dev)) + goto nla_put_failure; + + if (dev->rtnl_link_ops || rtnl_have_link_slave_info(dev)) { + if (rtnl_link_fill(skb, dev) < 0) + goto nla_put_failure; + } + + if (rtnl_fill_link_netnsid(skb, dev, src_net, gfp)) + goto nla_put_failure; + + if (new_nsid && + nla_put_s32(skb, IFLA_NEW_NETNSID, *new_nsid) < 0) + goto nla_put_failure; + if (new_ifindex && + nla_put_s32(skb, IFLA_NEW_IFINDEX, new_ifindex) < 0) + goto nla_put_failure; + + if (memchr_inv(dev->perm_addr, '\0', dev->addr_len) && + nla_put(skb, IFLA_PERM_ADDRESS, dev->addr_len, dev->perm_addr)) + goto nla_put_failure; + + rcu_read_lock(); + if (rtnl_fill_link_af(skb, dev, ext_filter_mask)) + goto nla_put_failure_rcu; + rcu_read_unlock(); + + if (rtnl_fill_prop_list(skb, dev)) + goto nla_put_failure; + + if (dev->dev.parent && + nla_put_string(skb, IFLA_PARENT_DEV_NAME, + dev_name(dev->dev.parent))) + goto nla_put_failure; + + if (dev->dev.parent && dev->dev.parent->bus && + nla_put_string(skb, IFLA_PARENT_DEV_BUS_NAME, + dev->dev.parent->bus->name)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure_rcu: + rcu_read_unlock(); +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static const struct nla_policy ifla_policy[IFLA_MAX+1] = { + [IFLA_IFNAME] = { .type = NLA_STRING, .len = IFNAMSIZ-1 }, + [IFLA_ADDRESS] = { .type = 
NLA_BINARY, .len = MAX_ADDR_LEN }, + [IFLA_BROADCAST] = { .type = NLA_BINARY, .len = MAX_ADDR_LEN }, + [IFLA_MAP] = { .len = sizeof(struct rtnl_link_ifmap) }, + [IFLA_MTU] = { .type = NLA_U32 }, + [IFLA_LINK] = { .type = NLA_U32 }, + [IFLA_MASTER] = { .type = NLA_U32 }, + [IFLA_CARRIER] = { .type = NLA_U8 }, + [IFLA_TXQLEN] = { .type = NLA_U32 }, + [IFLA_WEIGHT] = { .type = NLA_U32 }, + [IFLA_OPERSTATE] = { .type = NLA_U8 }, + [IFLA_LINKMODE] = { .type = NLA_U8 }, + [IFLA_LINKINFO] = { .type = NLA_NESTED }, + [IFLA_NET_NS_PID] = { .type = NLA_U32 }, + [IFLA_NET_NS_FD] = { .type = NLA_U32 }, + /* IFLA_IFALIAS is a string, but policy is set to NLA_BINARY to + * allow 0-length string (needed to remove an alias). + */ + [IFLA_IFALIAS] = { .type = NLA_BINARY, .len = IFALIASZ - 1 }, + [IFLA_VFINFO_LIST] = {. type = NLA_NESTED }, + [IFLA_VF_PORTS] = { .type = NLA_NESTED }, + [IFLA_PORT_SELF] = { .type = NLA_NESTED }, + [IFLA_AF_SPEC] = { .type = NLA_NESTED }, + [IFLA_EXT_MASK] = { .type = NLA_U32 }, + [IFLA_PROMISCUITY] = { .type = NLA_U32 }, + [IFLA_NUM_TX_QUEUES] = { .type = NLA_U32 }, + [IFLA_NUM_RX_QUEUES] = { .type = NLA_U32 }, + [IFLA_GSO_MAX_SEGS] = { .type = NLA_U32 }, + [IFLA_GSO_MAX_SIZE] = { .type = NLA_U32 }, + [IFLA_PHYS_PORT_ID] = { .type = NLA_BINARY, .len = MAX_PHYS_ITEM_ID_LEN }, + [IFLA_CARRIER_CHANGES] = { .type = NLA_U32 }, /* ignored */ + [IFLA_PHYS_SWITCH_ID] = { .type = NLA_BINARY, .len = MAX_PHYS_ITEM_ID_LEN }, + [IFLA_LINK_NETNSID] = { .type = NLA_S32 }, + [IFLA_PROTO_DOWN] = { .type = NLA_U8 }, + [IFLA_XDP] = { .type = NLA_NESTED }, + [IFLA_EVENT] = { .type = NLA_U32 }, + [IFLA_GROUP] = { .type = NLA_U32 }, + [IFLA_TARGET_NETNSID] = { .type = NLA_S32 }, + [IFLA_CARRIER_UP_COUNT] = { .type = NLA_U32 }, + [IFLA_CARRIER_DOWN_COUNT] = { .type = NLA_U32 }, + [IFLA_MIN_MTU] = { .type = NLA_U32 }, + [IFLA_MAX_MTU] = { .type = NLA_U32 }, + [IFLA_PROP_LIST] = { .type = NLA_NESTED }, + [IFLA_ALT_IFNAME] = { .type = NLA_STRING, + .len = ALTIFNAMSIZ - 1 }, + [IFLA_PERM_ADDRESS] = { .type = NLA_REJECT }, + [IFLA_PROTO_DOWN_REASON] = { .type = NLA_NESTED }, + [IFLA_NEW_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 1), + [IFLA_PARENT_DEV_NAME] = { .type = NLA_NUL_STRING }, + [IFLA_GRO_MAX_SIZE] = { .type = NLA_U32 }, + [IFLA_TSO_MAX_SIZE] = { .type = NLA_REJECT }, + [IFLA_TSO_MAX_SEGS] = { .type = NLA_REJECT }, + [IFLA_ALLMULTI] = { .type = NLA_REJECT }, +}; + +static const struct nla_policy ifla_info_policy[IFLA_INFO_MAX+1] = { + [IFLA_INFO_KIND] = { .type = NLA_STRING }, + [IFLA_INFO_DATA] = { .type = NLA_NESTED }, + [IFLA_INFO_SLAVE_KIND] = { .type = NLA_STRING }, + [IFLA_INFO_SLAVE_DATA] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy ifla_vf_policy[IFLA_VF_MAX+1] = { + [IFLA_VF_MAC] = { .len = sizeof(struct ifla_vf_mac) }, + [IFLA_VF_BROADCAST] = { .type = NLA_REJECT }, + [IFLA_VF_VLAN] = { .len = sizeof(struct ifla_vf_vlan) }, + [IFLA_VF_VLAN_LIST] = { .type = NLA_NESTED }, + [IFLA_VF_TX_RATE] = { .len = sizeof(struct ifla_vf_tx_rate) }, + [IFLA_VF_SPOOFCHK] = { .len = sizeof(struct ifla_vf_spoofchk) }, + [IFLA_VF_RATE] = { .len = sizeof(struct ifla_vf_rate) }, + [IFLA_VF_LINK_STATE] = { .len = sizeof(struct ifla_vf_link_state) }, + [IFLA_VF_RSS_QUERY_EN] = { .len = sizeof(struct ifla_vf_rss_query_en) }, + [IFLA_VF_STATS] = { .type = NLA_NESTED }, + [IFLA_VF_TRUST] = { .len = sizeof(struct ifla_vf_trust) }, + [IFLA_VF_IB_NODE_GUID] = { .len = sizeof(struct ifla_vf_guid) }, + [IFLA_VF_IB_PORT_GUID] = { .len = sizeof(struct ifla_vf_guid) }, +}; + +static const struct 
nla_policy ifla_port_policy[IFLA_PORT_MAX+1] = { + [IFLA_PORT_VF] = { .type = NLA_U32 }, + [IFLA_PORT_PROFILE] = { .type = NLA_STRING, + .len = PORT_PROFILE_MAX }, + [IFLA_PORT_INSTANCE_UUID] = { .type = NLA_BINARY, + .len = PORT_UUID_MAX }, + [IFLA_PORT_HOST_UUID] = { .type = NLA_STRING, + .len = PORT_UUID_MAX }, + [IFLA_PORT_REQUEST] = { .type = NLA_U8, }, + [IFLA_PORT_RESPONSE] = { .type = NLA_U16, }, + + /* Unused, but we need to keep it here since user space could + * fill it. It's also broken with regard to NLA_BINARY use in + * combination with structs. + */ + [IFLA_PORT_VSI_TYPE] = { .type = NLA_BINARY, + .len = sizeof(struct ifla_port_vsi) }, +}; + +static const struct nla_policy ifla_xdp_policy[IFLA_XDP_MAX + 1] = { + [IFLA_XDP_UNSPEC] = { .strict_start_type = IFLA_XDP_EXPECTED_FD }, + [IFLA_XDP_FD] = { .type = NLA_S32 }, + [IFLA_XDP_EXPECTED_FD] = { .type = NLA_S32 }, + [IFLA_XDP_ATTACHED] = { .type = NLA_U8 }, + [IFLA_XDP_FLAGS] = { .type = NLA_U32 }, + [IFLA_XDP_PROG_ID] = { .type = NLA_U32 }, +}; + +static const struct rtnl_link_ops *linkinfo_to_kind_ops(const struct nlattr *nla) +{ + const struct rtnl_link_ops *ops = NULL; + struct nlattr *linfo[IFLA_INFO_MAX + 1]; + + if (nla_parse_nested_deprecated(linfo, IFLA_INFO_MAX, nla, ifla_info_policy, NULL) < 0) + return NULL; + + if (linfo[IFLA_INFO_KIND]) { + char kind[MODULE_NAME_LEN]; + + nla_strscpy(kind, linfo[IFLA_INFO_KIND], sizeof(kind)); + ops = rtnl_link_ops_get(kind); + } + + return ops; +} + +static bool link_master_filtered(struct net_device *dev, int master_idx) +{ + struct net_device *master; + + if (!master_idx) + return false; + + master = netdev_master_upper_dev_get(dev); + + /* 0 is already used to denote IFLA_MASTER wasn't passed, therefore need + * another invalid value for ifindex to denote "no master". + */ + if (master_idx == -1) + return !!master; + + if (!master || master->ifindex != master_idx) + return true; + + return false; +} + +static bool link_kind_filtered(const struct net_device *dev, + const struct rtnl_link_ops *kind_ops) +{ + if (kind_ops && dev->rtnl_link_ops != kind_ops) + return true; + + return false; +} + +static bool link_dump_filtered(struct net_device *dev, + int master_idx, + const struct rtnl_link_ops *kind_ops) +{ + if (link_master_filtered(dev, master_idx) || + link_kind_filtered(dev, kind_ops)) + return true; + + return false; +} + +/** + * rtnl_get_net_ns_capable - Get netns if sufficiently privileged. + * @sk: netlink socket + * @netnsid: network namespace identifier + * + * Returns the network namespace identified by netnsid on success or an error + * pointer on failure. + */ +struct net *rtnl_get_net_ns_capable(struct sock *sk, int netnsid) +{ + struct net *net; + + net = get_net_ns_by_id(sock_net(sk), netnsid); + if (!net) + return ERR_PTR(-EINVAL); + + /* For now, the caller is required to have CAP_NET_ADMIN in + * the user namespace owning the target net ns. 
+ */ + if (!sk_ns_capable(sk, net->user_ns, CAP_NET_ADMIN)) { + put_net(net); + return ERR_PTR(-EACCES); + } + return net; +} +EXPORT_SYMBOL_GPL(rtnl_get_net_ns_capable); + +static int rtnl_valid_dump_ifinfo_req(const struct nlmsghdr *nlh, + bool strict_check, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int hdrlen; + + if (strict_check) { + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change) { + NL_SET_ERR_MSG(extack, "Invalid values in header for link dump request"); + return -EINVAL; + } + if (ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Filter by device index not supported for link dumps"); + return -EINVAL; + } + + return nlmsg_parse_deprecated_strict(nlh, sizeof(*ifm), tb, + IFLA_MAX, ifla_policy, + extack); + } + + /* A hack to preserve kernel<->userspace interface. + * The correct header is ifinfomsg. It is consistent with rtnl_getlink. + * However, before Linux v3.9 the code here assumed rtgenmsg and that's + * what iproute2 < v3.9.0 used. + * We can detect the old iproute2. Even including the IFLA_EXT_MASK + * attribute, its netlink message is shorter than struct ifinfomsg. + */ + hdrlen = nlmsg_len(nlh) < sizeof(struct ifinfomsg) ? + sizeof(struct rtgenmsg) : sizeof(struct ifinfomsg); + + return nlmsg_parse_deprecated(nlh, hdrlen, tb, IFLA_MAX, ifla_policy, + extack); +} + +static int rtnl_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct net *tgt_net = net; + int h, s_h; + int idx = 0, s_idx; + struct net_device *dev; + struct hlist_head *head; + struct nlattr *tb[IFLA_MAX+1]; + u32 ext_filter_mask = 0; + const struct rtnl_link_ops *kind_ops = NULL; + unsigned int flags = NLM_F_MULTI; + int master_idx = 0; + int netnsid = -1; + int err, i; + + s_h = cb->args[0]; + s_idx = cb->args[1]; + + err = rtnl_valid_dump_ifinfo_req(nlh, cb->strict_check, tb, extack); + if (err < 0) { + if (cb->strict_check) + return err; + + goto walk_entries; + } + + for (i = 0; i <= IFLA_MAX; ++i) { + if (!tb[i]) + continue; + + /* new attributes should only be added with strict checking */ + switch (i) { + case IFLA_TARGET_NETNSID: + netnsid = nla_get_s32(tb[i]); + tgt_net = rtnl_get_net_ns_capable(skb->sk, netnsid); + if (IS_ERR(tgt_net)) { + NL_SET_ERR_MSG(extack, "Invalid target network namespace id"); + return PTR_ERR(tgt_net); + } + break; + case IFLA_EXT_MASK: + ext_filter_mask = nla_get_u32(tb[i]); + break; + case IFLA_MASTER: + master_idx = nla_get_u32(tb[i]); + break; + case IFLA_LINKINFO: + kind_ops = linkinfo_to_kind_ops(tb[i]); + break; + default: + if (cb->strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in link dump request"); + return -EINVAL; + } + } + } + + if (master_idx || kind_ops) + flags |= NLM_F_DUMP_FILTERED; + +walk_entries: + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &tgt_net->dev_index_head[h]; + hlist_for_each_entry(dev, head, index_hlist) { + if (link_dump_filtered(dev, master_idx, kind_ops)) + goto cont; + if (idx < s_idx) + goto cont; + err = rtnl_fill_ifinfo(skb, dev, net, + RTM_NEWLINK, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, 0, flags, + ext_filter_mask, 0, NULL, 0, + netnsid, GFP_KERNEL); + + if (err < 0) { + if (likely(skb->len)) + goto out; + 
+ goto out_err; + } +cont: + idx++; + } + } +out: + err = skb->len; +out_err: + cb->args[1] = idx; + cb->args[0] = h; + cb->seq = tgt_net->dev_base_seq; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + if (netnsid >= 0) + put_net(tgt_net); + + return err; +} + +int rtnl_nla_parse_ifinfomsg(struct nlattr **tb, const struct nlattr *nla_peer, + struct netlink_ext_ack *exterr) +{ + const struct ifinfomsg *ifmp; + const struct nlattr *attrs; + size_t len; + + ifmp = nla_data(nla_peer); + attrs = nla_data(nla_peer) + sizeof(struct ifinfomsg); + len = nla_len(nla_peer) - sizeof(struct ifinfomsg); + + if (ifmp->ifi_index < 0) { + NL_SET_ERR_MSG_ATTR(exterr, nla_peer, + "ifindex can't be negative"); + return -EINVAL; + } + + return nla_parse_deprecated(tb, IFLA_MAX, attrs, len, ifla_policy, + exterr); +} +EXPORT_SYMBOL(rtnl_nla_parse_ifinfomsg); + +struct net *rtnl_link_get_net(struct net *src_net, struct nlattr *tb[]) +{ + struct net *net; + /* Examine the link attributes and figure out which + * network namespace we are talking about. + */ + if (tb[IFLA_NET_NS_PID]) + net = get_net_ns_by_pid(nla_get_u32(tb[IFLA_NET_NS_PID])); + else if (tb[IFLA_NET_NS_FD]) + net = get_net_ns_by_fd(nla_get_u32(tb[IFLA_NET_NS_FD])); + else + net = get_net(src_net); + return net; +} +EXPORT_SYMBOL(rtnl_link_get_net); + +/* Figure out which network namespace we are talking about by + * examining the link attributes in the following order: + * + * 1. IFLA_NET_NS_PID + * 2. IFLA_NET_NS_FD + * 3. IFLA_TARGET_NETNSID + */ +static struct net *rtnl_link_get_net_by_nlattr(struct net *src_net, + struct nlattr *tb[]) +{ + struct net *net; + + if (tb[IFLA_NET_NS_PID] || tb[IFLA_NET_NS_FD]) + return rtnl_link_get_net(src_net, tb); + + if (!tb[IFLA_TARGET_NETNSID]) + return get_net(src_net); + + net = get_net_ns_by_id(src_net, nla_get_u32(tb[IFLA_TARGET_NETNSID])); + if (!net) + return ERR_PTR(-EINVAL); + + return net; +} + +static struct net *rtnl_link_get_net_capable(const struct sk_buff *skb, + struct net *src_net, + struct nlattr *tb[], int cap) +{ + struct net *net; + + net = rtnl_link_get_net_by_nlattr(src_net, tb); + if (IS_ERR(net)) + return net; + + if (!netlink_ns_capable(skb, net->user_ns, cap)) { + put_net(net); + return ERR_PTR(-EPERM); + } + + return net; +} + +/* Verify that rtnetlink requests do not pass additional properties + * potentially referring to different network namespaces. 
+ */ +static int rtnl_ensure_unique_netns(struct nlattr *tb[], + struct netlink_ext_ack *extack, + bool netns_id_only) +{ + + if (netns_id_only) { + if (!tb[IFLA_NET_NS_PID] && !tb[IFLA_NET_NS_FD]) + return 0; + + NL_SET_ERR_MSG(extack, "specified netns attribute not supported"); + return -EOPNOTSUPP; + } + + if (tb[IFLA_TARGET_NETNSID] && (tb[IFLA_NET_NS_PID] || tb[IFLA_NET_NS_FD])) + goto invalid_attr; + + if (tb[IFLA_NET_NS_PID] && (tb[IFLA_TARGET_NETNSID] || tb[IFLA_NET_NS_FD])) + goto invalid_attr; + + if (tb[IFLA_NET_NS_FD] && (tb[IFLA_TARGET_NETNSID] || tb[IFLA_NET_NS_PID])) + goto invalid_attr; + + return 0; + +invalid_attr: + NL_SET_ERR_MSG(extack, "multiple netns identifying attributes specified"); + return -EINVAL; +} + +static int rtnl_set_vf_rate(struct net_device *dev, int vf, int min_tx_rate, + int max_tx_rate) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + if (!ops->ndo_set_vf_rate) + return -EOPNOTSUPP; + if (max_tx_rate && max_tx_rate < min_tx_rate) + return -EINVAL; + + return ops->ndo_set_vf_rate(dev, vf, min_tx_rate, max_tx_rate); +} + +static int validate_linkmsg(struct net_device *dev, struct nlattr *tb[], + struct netlink_ext_ack *extack) +{ + if (dev) { + if (tb[IFLA_ADDRESS] && + nla_len(tb[IFLA_ADDRESS]) < dev->addr_len) + return -EINVAL; + + if (tb[IFLA_BROADCAST] && + nla_len(tb[IFLA_BROADCAST]) < dev->addr_len) + return -EINVAL; + } + + if (tb[IFLA_AF_SPEC]) { + struct nlattr *af; + int rem, err; + + nla_for_each_nested(af, tb[IFLA_AF_SPEC], rem) { + const struct rtnl_af_ops *af_ops; + + af_ops = rtnl_af_lookup(nla_type(af)); + if (!af_ops) + return -EAFNOSUPPORT; + + if (!af_ops->set_link_af) + return -EOPNOTSUPP; + + if (af_ops->validate_link_af) { + err = af_ops->validate_link_af(dev, af, extack); + if (err < 0) + return err; + } + } + } + + return 0; +} + +static int handle_infiniband_guid(struct net_device *dev, struct ifla_vf_guid *ivt, + int guid_type) +{ + const struct net_device_ops *ops = dev->netdev_ops; + + return ops->ndo_set_vf_guid(dev, ivt->vf, ivt->guid, guid_type); +} + +static int handle_vf_guid(struct net_device *dev, struct ifla_vf_guid *ivt, int guid_type) +{ + if (dev->type != ARPHRD_INFINIBAND) + return -EOPNOTSUPP; + + return handle_infiniband_guid(dev, ivt, guid_type); +} + +static int do_setvfinfo(struct net_device *dev, struct nlattr **tb) +{ + const struct net_device_ops *ops = dev->netdev_ops; + int err = -EINVAL; + + if (tb[IFLA_VF_MAC]) { + struct ifla_vf_mac *ivm = nla_data(tb[IFLA_VF_MAC]); + + if (ivm->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_mac) + err = ops->ndo_set_vf_mac(dev, ivm->vf, + ivm->mac); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_VLAN]) { + struct ifla_vf_vlan *ivv = nla_data(tb[IFLA_VF_VLAN]); + + if (ivv->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_vlan) + err = ops->ndo_set_vf_vlan(dev, ivv->vf, ivv->vlan, + ivv->qos, + htons(ETH_P_8021Q)); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_VLAN_LIST]) { + struct ifla_vf_vlan_info *ivvl[MAX_VLAN_LIST_LEN]; + struct nlattr *attr; + int rem, len = 0; + + err = -EOPNOTSUPP; + if (!ops->ndo_set_vf_vlan) + return err; + + nla_for_each_nested(attr, tb[IFLA_VF_VLAN_LIST], rem) { + if (nla_type(attr) != IFLA_VF_VLAN_INFO || + nla_len(attr) < NLA_HDRLEN) { + return -EINVAL; + } + if (len >= MAX_VLAN_LIST_LEN) + return -EOPNOTSUPP; + ivvl[len] = nla_data(attr); + + len++; + } + if (len == 0) + return -EINVAL; + + if (ivvl[0]->vf >= INT_MAX) + return -EINVAL; + err = 
ops->ndo_set_vf_vlan(dev, ivvl[0]->vf, ivvl[0]->vlan, + ivvl[0]->qos, ivvl[0]->vlan_proto); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_TX_RATE]) { + struct ifla_vf_tx_rate *ivt = nla_data(tb[IFLA_VF_TX_RATE]); + struct ifla_vf_info ivf; + + if (ivt->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_get_vf_config) + err = ops->ndo_get_vf_config(dev, ivt->vf, &ivf); + if (err < 0) + return err; + + err = rtnl_set_vf_rate(dev, ivt->vf, + ivf.min_tx_rate, ivt->rate); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_RATE]) { + struct ifla_vf_rate *ivt = nla_data(tb[IFLA_VF_RATE]); + + if (ivt->vf >= INT_MAX) + return -EINVAL; + + err = rtnl_set_vf_rate(dev, ivt->vf, + ivt->min_tx_rate, ivt->max_tx_rate); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_SPOOFCHK]) { + struct ifla_vf_spoofchk *ivs = nla_data(tb[IFLA_VF_SPOOFCHK]); + + if (ivs->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_spoofchk) + err = ops->ndo_set_vf_spoofchk(dev, ivs->vf, + ivs->setting); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_LINK_STATE]) { + struct ifla_vf_link_state *ivl = nla_data(tb[IFLA_VF_LINK_STATE]); + + if (ivl->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_link_state) + err = ops->ndo_set_vf_link_state(dev, ivl->vf, + ivl->link_state); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_RSS_QUERY_EN]) { + struct ifla_vf_rss_query_en *ivrssq_en; + + err = -EOPNOTSUPP; + ivrssq_en = nla_data(tb[IFLA_VF_RSS_QUERY_EN]); + if (ivrssq_en->vf >= INT_MAX) + return -EINVAL; + if (ops->ndo_set_vf_rss_query_en) + err = ops->ndo_set_vf_rss_query_en(dev, ivrssq_en->vf, + ivrssq_en->setting); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_TRUST]) { + struct ifla_vf_trust *ivt = nla_data(tb[IFLA_VF_TRUST]); + + if (ivt->vf >= INT_MAX) + return -EINVAL; + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_trust) + err = ops->ndo_set_vf_trust(dev, ivt->vf, ivt->setting); + if (err < 0) + return err; + } + + if (tb[IFLA_VF_IB_NODE_GUID]) { + struct ifla_vf_guid *ivt = nla_data(tb[IFLA_VF_IB_NODE_GUID]); + + if (ivt->vf >= INT_MAX) + return -EINVAL; + if (!ops->ndo_set_vf_guid) + return -EOPNOTSUPP; + return handle_vf_guid(dev, ivt, IFLA_VF_IB_NODE_GUID); + } + + if (tb[IFLA_VF_IB_PORT_GUID]) { + struct ifla_vf_guid *ivt = nla_data(tb[IFLA_VF_IB_PORT_GUID]); + + if (ivt->vf >= INT_MAX) + return -EINVAL; + if (!ops->ndo_set_vf_guid) + return -EOPNOTSUPP; + + return handle_vf_guid(dev, ivt, IFLA_VF_IB_PORT_GUID); + } + + return err; +} + +static int do_set_master(struct net_device *dev, int ifindex, + struct netlink_ext_ack *extack) +{ + struct net_device *upper_dev = netdev_master_upper_dev_get(dev); + const struct net_device_ops *ops; + int err; + + if (upper_dev) { + if (upper_dev->ifindex == ifindex) + return 0; + ops = upper_dev->netdev_ops; + if (ops->ndo_del_slave) { + err = ops->ndo_del_slave(upper_dev, dev); + if (err) + return err; + } else { + return -EOPNOTSUPP; + } + } + + if (ifindex) { + upper_dev = __dev_get_by_index(dev_net(dev), ifindex); + if (!upper_dev) + return -EINVAL; + ops = upper_dev->netdev_ops; + if (ops->ndo_add_slave) { + err = ops->ndo_add_slave(upper_dev, dev, extack); + if (err) + return err; + } else { + return -EOPNOTSUPP; + } + } + return 0; +} + +static const struct nla_policy ifla_proto_down_reason_policy[IFLA_PROTO_DOWN_REASON_VALUE + 1] = { + [IFLA_PROTO_DOWN_REASON_MASK] = { .type = NLA_U32 }, + [IFLA_PROTO_DOWN_REASON_VALUE] = { .type = NLA_U32 }, +}; + +static int do_set_proto_down(struct 
net_device *dev, + struct nlattr *nl_proto_down, + struct nlattr *nl_proto_down_reason, + struct netlink_ext_ack *extack) +{ + struct nlattr *pdreason[IFLA_PROTO_DOWN_REASON_MAX + 1]; + unsigned long mask = 0; + u32 value; + bool proto_down; + int err; + + if (!(dev->priv_flags & IFF_CHANGE_PROTO_DOWN)) { + NL_SET_ERR_MSG(extack, "Protodown not supported by device"); + return -EOPNOTSUPP; + } + + if (nl_proto_down_reason) { + err = nla_parse_nested_deprecated(pdreason, + IFLA_PROTO_DOWN_REASON_MAX, + nl_proto_down_reason, + ifla_proto_down_reason_policy, + NULL); + if (err < 0) + return err; + + if (!pdreason[IFLA_PROTO_DOWN_REASON_VALUE]) { + NL_SET_ERR_MSG(extack, "Invalid protodown reason value"); + return -EINVAL; + } + + value = nla_get_u32(pdreason[IFLA_PROTO_DOWN_REASON_VALUE]); + + if (pdreason[IFLA_PROTO_DOWN_REASON_MASK]) + mask = nla_get_u32(pdreason[IFLA_PROTO_DOWN_REASON_MASK]); + + dev_change_proto_down_reason(dev, mask, value); + } + + if (nl_proto_down) { + proto_down = nla_get_u8(nl_proto_down); + + /* Don't turn off protodown if there are active reasons */ + if (!proto_down && dev->proto_down_reason) { + NL_SET_ERR_MSG(extack, "Cannot clear protodown, active reasons"); + return -EBUSY; + } + err = dev_change_proto_down(dev, + proto_down); + if (err) + return err; + } + + return 0; +} + +#define DO_SETLINK_MODIFIED 0x01 +/* notify flag means notify + modified. */ +#define DO_SETLINK_NOTIFY 0x03 +static int do_setlink(const struct sk_buff *skb, + struct net_device *dev, struct ifinfomsg *ifm, + struct netlink_ext_ack *extack, + struct nlattr **tb, int status) +{ + const struct net_device_ops *ops = dev->netdev_ops; + char ifname[IFNAMSIZ]; + int err; + + err = validate_linkmsg(dev, tb, extack); + if (err < 0) + return err; + + if (tb[IFLA_IFNAME]) + nla_strscpy(ifname, tb[IFLA_IFNAME], IFNAMSIZ); + else + ifname[0] = '\0'; + + if (tb[IFLA_NET_NS_PID] || tb[IFLA_NET_NS_FD] || tb[IFLA_TARGET_NETNSID]) { + const char *pat = ifname[0] ? 
ifname : NULL; + struct net *net; + int new_ifindex; + + net = rtnl_link_get_net_capable(skb, dev_net(dev), + tb, CAP_NET_ADMIN); + if (IS_ERR(net)) { + err = PTR_ERR(net); + goto errout; + } + + if (tb[IFLA_NEW_IFINDEX]) + new_ifindex = nla_get_s32(tb[IFLA_NEW_IFINDEX]); + else + new_ifindex = 0; + + err = __dev_change_net_namespace(dev, net, pat, new_ifindex); + put_net(net); + if (err) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_MAP]) { + struct rtnl_link_ifmap *u_map; + struct ifmap k_map; + + if (!ops->ndo_set_config) { + err = -EOPNOTSUPP; + goto errout; + } + + if (!netif_device_present(dev)) { + err = -ENODEV; + goto errout; + } + + u_map = nla_data(tb[IFLA_MAP]); + k_map.mem_start = (unsigned long) u_map->mem_start; + k_map.mem_end = (unsigned long) u_map->mem_end; + k_map.base_addr = (unsigned short) u_map->base_addr; + k_map.irq = (unsigned char) u_map->irq; + k_map.dma = (unsigned char) u_map->dma; + k_map.port = (unsigned char) u_map->port; + + err = ops->ndo_set_config(dev, &k_map); + if (err < 0) + goto errout; + + status |= DO_SETLINK_NOTIFY; + } + + if (tb[IFLA_ADDRESS]) { + struct sockaddr *sa; + int len; + + len = sizeof(sa_family_t) + max_t(size_t, dev->addr_len, + sizeof(*sa)); + sa = kmalloc(len, GFP_KERNEL); + if (!sa) { + err = -ENOMEM; + goto errout; + } + sa->sa_family = dev->type; + memcpy(sa->sa_data, nla_data(tb[IFLA_ADDRESS]), + dev->addr_len); + err = dev_set_mac_address_user(dev, sa, extack); + kfree(sa); + if (err) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_MTU]) { + err = dev_set_mtu_ext(dev, nla_get_u32(tb[IFLA_MTU]), extack); + if (err < 0) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_GROUP]) { + dev_set_group(dev, nla_get_u32(tb[IFLA_GROUP])); + status |= DO_SETLINK_NOTIFY; + } + + /* + * Interface selected by interface index but interface + * name provided implies that a name change has been + * requested. 
+ */ + if (ifm->ifi_index > 0 && ifname[0]) { + err = dev_change_name(dev, ifname); + if (err < 0) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_IFALIAS]) { + err = dev_set_alias(dev, nla_data(tb[IFLA_IFALIAS]), + nla_len(tb[IFLA_IFALIAS])); + if (err < 0) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + + if (tb[IFLA_BROADCAST]) { + nla_memcpy(dev->broadcast, tb[IFLA_BROADCAST], dev->addr_len); + call_netdevice_notifiers(NETDEV_CHANGEADDR, dev); + } + + if (ifm->ifi_flags || ifm->ifi_change) { + err = dev_change_flags(dev, rtnl_dev_combine_flags(dev, ifm), + extack); + if (err < 0) + goto errout; + } + + if (tb[IFLA_MASTER]) { + err = do_set_master(dev, nla_get_u32(tb[IFLA_MASTER]), extack); + if (err) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_CARRIER]) { + err = dev_change_carrier(dev, nla_get_u8(tb[IFLA_CARRIER])); + if (err) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_TXQLEN]) { + unsigned int value = nla_get_u32(tb[IFLA_TXQLEN]); + + err = dev_change_tx_queue_len(dev, value); + if (err) + goto errout; + status |= DO_SETLINK_MODIFIED; + } + + if (tb[IFLA_GSO_MAX_SIZE]) { + u32 max_size = nla_get_u32(tb[IFLA_GSO_MAX_SIZE]); + + if (max_size > dev->tso_max_size) { + err = -EINVAL; + goto errout; + } + + if (dev->gso_max_size ^ max_size) { + netif_set_gso_max_size(dev, max_size); + status |= DO_SETLINK_MODIFIED; + } + } + + if (tb[IFLA_GSO_MAX_SEGS]) { + u32 max_segs = nla_get_u32(tb[IFLA_GSO_MAX_SEGS]); + + if (max_segs > GSO_MAX_SEGS || max_segs > dev->tso_max_segs) { + err = -EINVAL; + goto errout; + } + + if (dev->gso_max_segs ^ max_segs) { + netif_set_gso_max_segs(dev, max_segs); + status |= DO_SETLINK_MODIFIED; + } + } + + if (tb[IFLA_GRO_MAX_SIZE]) { + u32 gro_max_size = nla_get_u32(tb[IFLA_GRO_MAX_SIZE]); + + if (dev->gro_max_size ^ gro_max_size) { + netif_set_gro_max_size(dev, gro_max_size); + status |= DO_SETLINK_MODIFIED; + } + } + + if (tb[IFLA_OPERSTATE]) + set_operstate(dev, nla_get_u8(tb[IFLA_OPERSTATE])); + + if (tb[IFLA_LINKMODE]) { + unsigned char value = nla_get_u8(tb[IFLA_LINKMODE]); + + write_lock(&dev_base_lock); + if (dev->link_mode ^ value) + status |= DO_SETLINK_NOTIFY; + dev->link_mode = value; + write_unlock(&dev_base_lock); + } + + if (tb[IFLA_VFINFO_LIST]) { + struct nlattr *vfinfo[IFLA_VF_MAX + 1]; + struct nlattr *attr; + int rem; + + nla_for_each_nested(attr, tb[IFLA_VFINFO_LIST], rem) { + if (nla_type(attr) != IFLA_VF_INFO || + nla_len(attr) < NLA_HDRLEN) { + err = -EINVAL; + goto errout; + } + err = nla_parse_nested_deprecated(vfinfo, IFLA_VF_MAX, + attr, + ifla_vf_policy, + NULL); + if (err < 0) + goto errout; + err = do_setvfinfo(dev, vfinfo); + if (err < 0) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + } + err = 0; + + if (tb[IFLA_VF_PORTS]) { + struct nlattr *port[IFLA_PORT_MAX+1]; + struct nlattr *attr; + int vf; + int rem; + + err = -EOPNOTSUPP; + if (!ops->ndo_set_vf_port) + goto errout; + + nla_for_each_nested(attr, tb[IFLA_VF_PORTS], rem) { + if (nla_type(attr) != IFLA_VF_PORT || + nla_len(attr) < NLA_HDRLEN) { + err = -EINVAL; + goto errout; + } + err = nla_parse_nested_deprecated(port, IFLA_PORT_MAX, + attr, + ifla_port_policy, + NULL); + if (err < 0) + goto errout; + if (!port[IFLA_PORT_VF]) { + err = -EOPNOTSUPP; + goto errout; + } + vf = nla_get_u32(port[IFLA_PORT_VF]); + err = ops->ndo_set_vf_port(dev, vf, port); + if (err < 0) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + } + err = 0; + + if (tb[IFLA_PORT_SELF]) { + struct nlattr *port[IFLA_PORT_MAX+1]; 
+ + err = nla_parse_nested_deprecated(port, IFLA_PORT_MAX, + tb[IFLA_PORT_SELF], + ifla_port_policy, NULL); + if (err < 0) + goto errout; + + err = -EOPNOTSUPP; + if (ops->ndo_set_vf_port) + err = ops->ndo_set_vf_port(dev, PORT_SELF_VF, port); + if (err < 0) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + + if (tb[IFLA_AF_SPEC]) { + struct nlattr *af; + int rem; + + nla_for_each_nested(af, tb[IFLA_AF_SPEC], rem) { + const struct rtnl_af_ops *af_ops; + + BUG_ON(!(af_ops = rtnl_af_lookup(nla_type(af)))); + + err = af_ops->set_link_af(dev, af, extack); + if (err < 0) + goto errout; + + status |= DO_SETLINK_NOTIFY; + } + } + err = 0; + + if (tb[IFLA_PROTO_DOWN] || tb[IFLA_PROTO_DOWN_REASON]) { + err = do_set_proto_down(dev, tb[IFLA_PROTO_DOWN], + tb[IFLA_PROTO_DOWN_REASON], extack); + if (err) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + + if (tb[IFLA_XDP]) { + struct nlattr *xdp[IFLA_XDP_MAX + 1]; + u32 xdp_flags = 0; + + err = nla_parse_nested_deprecated(xdp, IFLA_XDP_MAX, + tb[IFLA_XDP], + ifla_xdp_policy, NULL); + if (err < 0) + goto errout; + + if (xdp[IFLA_XDP_ATTACHED] || xdp[IFLA_XDP_PROG_ID]) { + err = -EINVAL; + goto errout; + } + + if (xdp[IFLA_XDP_FLAGS]) { + xdp_flags = nla_get_u32(xdp[IFLA_XDP_FLAGS]); + if (xdp_flags & ~XDP_FLAGS_MASK) { + err = -EINVAL; + goto errout; + } + if (hweight32(xdp_flags & XDP_FLAGS_MODES) > 1) { + err = -EINVAL; + goto errout; + } + } + + if (xdp[IFLA_XDP_FD]) { + int expected_fd = -1; + + if (xdp_flags & XDP_FLAGS_REPLACE) { + if (!xdp[IFLA_XDP_EXPECTED_FD]) { + err = -EINVAL; + goto errout; + } + expected_fd = + nla_get_s32(xdp[IFLA_XDP_EXPECTED_FD]); + } + + err = dev_change_xdp_fd(dev, extack, + nla_get_s32(xdp[IFLA_XDP_FD]), + expected_fd, + xdp_flags); + if (err) + goto errout; + status |= DO_SETLINK_NOTIFY; + } + } + +errout: + if (status & DO_SETLINK_MODIFIED) { + if ((status & DO_SETLINK_NOTIFY) == DO_SETLINK_NOTIFY) + netdev_state_change(dev); + + if (err < 0) + net_warn_ratelimited("A link change request failed with some changes committed already. 
Interface %s may have been left with an inconsistent configuration, please check.\n", + dev->name); + } + + return err; +} + +static struct net_device *rtnl_dev_get(struct net *net, + struct nlattr *tb[]) +{ + char ifname[ALTIFNAMSIZ]; + + if (tb[IFLA_IFNAME]) + nla_strscpy(ifname, tb[IFLA_IFNAME], IFNAMSIZ); + else if (tb[IFLA_ALT_IFNAME]) + nla_strscpy(ifname, tb[IFLA_ALT_IFNAME], ALTIFNAMSIZ); + else + return NULL; + + return __dev_get_by_name(net, ifname); +} + +static int rtnl_setlink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifinfomsg *ifm; + struct net_device *dev; + int err; + struct nlattr *tb[IFLA_MAX+1]; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + if (err < 0) + goto errout; + + err = rtnl_ensure_unique_netns(tb, extack, false); + if (err < 0) + goto errout; + + err = -EINVAL; + ifm = nlmsg_data(nlh); + if (ifm->ifi_index > 0) + dev = __dev_get_by_index(net, ifm->ifi_index); + else if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME]) + dev = rtnl_dev_get(net, tb); + else + goto errout; + + if (dev == NULL) { + err = -ENODEV; + goto errout; + } + + err = do_setlink(skb, dev, ifm, extack, tb, 0); +errout: + return err; +} + +static int rtnl_group_dellink(const struct net *net, int group) +{ + struct net_device *dev, *aux; + LIST_HEAD(list_kill); + bool found = false; + + if (!group) + return -EPERM; + + for_each_netdev(net, dev) { + if (dev->group == group) { + const struct rtnl_link_ops *ops; + + found = true; + ops = dev->rtnl_link_ops; + if (!ops || !ops->dellink) + return -EOPNOTSUPP; + } + } + + if (!found) + return -ENODEV; + + for_each_netdev_safe(net, dev, aux) { + if (dev->group == group) { + const struct rtnl_link_ops *ops; + + ops = dev->rtnl_link_ops; + ops->dellink(dev, &list_kill); + } + } + unregister_netdevice_many(&list_kill); + + return 0; +} + +int rtnl_delete_link(struct net_device *dev) +{ + const struct rtnl_link_ops *ops; + LIST_HEAD(list_kill); + + ops = dev->rtnl_link_ops; + if (!ops || !ops->dellink) + return -EOPNOTSUPP; + + ops->dellink(dev, &list_kill); + unregister_netdevice_many(&list_kill); + + return 0; +} +EXPORT_SYMBOL_GPL(rtnl_delete_link); + +static int rtnl_dellink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net *tgt_net = net; + struct net_device *dev = NULL; + struct ifinfomsg *ifm; + struct nlattr *tb[IFLA_MAX+1]; + int err; + int netnsid = -1; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + if (err < 0) + return err; + + err = rtnl_ensure_unique_netns(tb, extack, true); + if (err < 0) + return err; + + if (tb[IFLA_TARGET_NETNSID]) { + netnsid = nla_get_s32(tb[IFLA_TARGET_NETNSID]); + tgt_net = rtnl_get_net_ns_capable(NETLINK_CB(skb).sk, netnsid); + if (IS_ERR(tgt_net)) + return PTR_ERR(tgt_net); + } + + err = -EINVAL; + ifm = nlmsg_data(nlh); + if (ifm->ifi_index > 0) + dev = __dev_get_by_index(tgt_net, ifm->ifi_index); + else if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME]) + dev = rtnl_dev_get(net, tb); + else if (tb[IFLA_GROUP]) + err = rtnl_group_dellink(tgt_net, nla_get_u32(tb[IFLA_GROUP])); + else + goto out; + + if (!dev) { + if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME] || ifm->ifi_index > 0) + err = -ENODEV; + + goto out; + } + + err = rtnl_delete_link(dev); + +out: + if (netnsid >= 0) + put_net(tgt_net); + + return err; +} + +int rtnl_configure_link(struct net_device *dev, const struct 
ifinfomsg *ifm) +{ + unsigned int old_flags; + int err; + + old_flags = dev->flags; + if (ifm && (ifm->ifi_flags || ifm->ifi_change)) { + err = __dev_change_flags(dev, rtnl_dev_combine_flags(dev, ifm), + NULL); + if (err < 0) + return err; + } + + if (dev->rtnl_link_state == RTNL_LINK_INITIALIZED) { + __dev_notify_flags(dev, old_flags, (old_flags ^ dev->flags)); + } else { + dev->rtnl_link_state = RTNL_LINK_INITIALIZED; + __dev_notify_flags(dev, old_flags, ~0U); + } + return 0; +} +EXPORT_SYMBOL(rtnl_configure_link); + +struct net_device *rtnl_create_link(struct net *net, const char *ifname, + unsigned char name_assign_type, + const struct rtnl_link_ops *ops, + struct nlattr *tb[], + struct netlink_ext_ack *extack) +{ + struct net_device *dev; + unsigned int num_tx_queues = 1; + unsigned int num_rx_queues = 1; + int err; + + if (tb[IFLA_NUM_TX_QUEUES]) + num_tx_queues = nla_get_u32(tb[IFLA_NUM_TX_QUEUES]); + else if (ops->get_num_tx_queues) + num_tx_queues = ops->get_num_tx_queues(); + + if (tb[IFLA_NUM_RX_QUEUES]) + num_rx_queues = nla_get_u32(tb[IFLA_NUM_RX_QUEUES]); + else if (ops->get_num_rx_queues) + num_rx_queues = ops->get_num_rx_queues(); + + if (num_tx_queues < 1 || num_tx_queues > 4096) { + NL_SET_ERR_MSG(extack, "Invalid number of transmit queues"); + return ERR_PTR(-EINVAL); + } + + if (num_rx_queues < 1 || num_rx_queues > 4096) { + NL_SET_ERR_MSG(extack, "Invalid number of receive queues"); + return ERR_PTR(-EINVAL); + } + + if (ops->alloc) { + dev = ops->alloc(tb, ifname, name_assign_type, + num_tx_queues, num_rx_queues); + if (IS_ERR(dev)) + return dev; + } else { + dev = alloc_netdev_mqs(ops->priv_size, ifname, + name_assign_type, ops->setup, + num_tx_queues, num_rx_queues); + } + + if (!dev) + return ERR_PTR(-ENOMEM); + + err = validate_linkmsg(dev, tb, extack); + if (err < 0) { + free_netdev(dev); + return ERR_PTR(err); + } + + dev_net_set(dev, net); + dev->rtnl_link_ops = ops; + dev->rtnl_link_state = RTNL_LINK_INITIALIZING; + + if (tb[IFLA_MTU]) { + u32 mtu = nla_get_u32(tb[IFLA_MTU]); + + err = dev_validate_mtu(dev, mtu, extack); + if (err) { + free_netdev(dev); + return ERR_PTR(err); + } + dev->mtu = mtu; + } + if (tb[IFLA_ADDRESS]) { + __dev_addr_set(dev, nla_data(tb[IFLA_ADDRESS]), + nla_len(tb[IFLA_ADDRESS])); + dev->addr_assign_type = NET_ADDR_SET; + } + if (tb[IFLA_BROADCAST]) + memcpy(dev->broadcast, nla_data(tb[IFLA_BROADCAST]), + nla_len(tb[IFLA_BROADCAST])); + if (tb[IFLA_TXQLEN]) + dev->tx_queue_len = nla_get_u32(tb[IFLA_TXQLEN]); + if (tb[IFLA_OPERSTATE]) + set_operstate(dev, nla_get_u8(tb[IFLA_OPERSTATE])); + if (tb[IFLA_LINKMODE]) + dev->link_mode = nla_get_u8(tb[IFLA_LINKMODE]); + if (tb[IFLA_GROUP]) + dev_set_group(dev, nla_get_u32(tb[IFLA_GROUP])); + if (tb[IFLA_GSO_MAX_SIZE]) + netif_set_gso_max_size(dev, nla_get_u32(tb[IFLA_GSO_MAX_SIZE])); + if (tb[IFLA_GSO_MAX_SEGS]) + netif_set_gso_max_segs(dev, nla_get_u32(tb[IFLA_GSO_MAX_SEGS])); + if (tb[IFLA_GRO_MAX_SIZE]) + netif_set_gro_max_size(dev, nla_get_u32(tb[IFLA_GRO_MAX_SIZE])); + + return dev; +} +EXPORT_SYMBOL(rtnl_create_link); + +static int rtnl_group_changelink(const struct sk_buff *skb, + struct net *net, int group, + struct ifinfomsg *ifm, + struct netlink_ext_ack *extack, + struct nlattr **tb) +{ + struct net_device *dev, *aux; + int err; + + for_each_netdev_safe(net, dev, aux) { + if (dev->group == group) { + err = do_setlink(skb, dev, ifm, extack, tb, 0); + if (err < 0) + return err; + } + } + + return 0; +} + +static int rtnl_newlink_create(struct sk_buff *skb, struct ifinfomsg *ifm, + 
const struct rtnl_link_ops *ops, + struct nlattr **tb, struct nlattr **data, + struct netlink_ext_ack *extack) +{ + unsigned char name_assign_type = NET_NAME_USER; + struct net *net = sock_net(skb->sk); + struct net *dest_net, *link_net; + struct net_device *dev; + char ifname[IFNAMSIZ]; + int err; + + if (!ops->alloc && !ops->setup) + return -EOPNOTSUPP; + + if (tb[IFLA_IFNAME]) { + nla_strscpy(ifname, tb[IFLA_IFNAME], IFNAMSIZ); + } else { + snprintf(ifname, IFNAMSIZ, "%s%%d", ops->kind); + name_assign_type = NET_NAME_ENUM; + } + + dest_net = rtnl_link_get_net_capable(skb, net, tb, CAP_NET_ADMIN); + if (IS_ERR(dest_net)) + return PTR_ERR(dest_net); + + if (tb[IFLA_LINK_NETNSID]) { + int id = nla_get_s32(tb[IFLA_LINK_NETNSID]); + + link_net = get_net_ns_by_id(dest_net, id); + if (!link_net) { + NL_SET_ERR_MSG(extack, "Unknown network namespace id"); + err = -EINVAL; + goto out; + } + err = -EPERM; + if (!netlink_ns_capable(skb, link_net->user_ns, CAP_NET_ADMIN)) + goto out; + } else { + link_net = NULL; + } + + dev = rtnl_create_link(link_net ? : dest_net, ifname, + name_assign_type, ops, tb, extack); + if (IS_ERR(dev)) { + err = PTR_ERR(dev); + goto out; + } + + dev->ifindex = ifm->ifi_index; + + if (ops->newlink) + err = ops->newlink(link_net ? : net, dev, tb, data, extack); + else + err = register_netdevice(dev); + if (err < 0) { + free_netdev(dev); + goto out; + } + + err = rtnl_configure_link(dev, ifm); + if (err < 0) + goto out_unregister; + if (link_net) { + err = dev_change_net_namespace(dev, dest_net, ifname); + if (err < 0) + goto out_unregister; + } + if (tb[IFLA_MASTER]) { + err = do_set_master(dev, nla_get_u32(tb[IFLA_MASTER]), extack); + if (err) + goto out_unregister; + } +out: + if (link_net) + put_net(link_net); + put_net(dest_net); + return err; +out_unregister: + if (ops->newlink) { + LIST_HEAD(list_kill); + + ops->dellink(dev, &list_kill); + unregister_netdevice_many(&list_kill); + } else { + unregister_netdevice(dev); + } + goto out; +} + +struct rtnl_newlink_tbs { + struct nlattr *tb[IFLA_MAX + 1]; + struct nlattr *attr[RTNL_MAX_TYPE + 1]; + struct nlattr *slave_attr[RTNL_SLAVE_MAX_TYPE + 1]; +}; + +static int __rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct rtnl_newlink_tbs *tbs, + struct netlink_ext_ack *extack) +{ + struct nlattr *linkinfo[IFLA_INFO_MAX + 1]; + struct nlattr ** const tb = tbs->tb; + const struct rtnl_link_ops *m_ops; + struct net_device *master_dev; + struct net *net = sock_net(skb->sk); + const struct rtnl_link_ops *ops; + struct nlattr **slave_data; + char kind[MODULE_NAME_LEN]; + struct net_device *dev; + struct ifinfomsg *ifm; + struct nlattr **data; + bool link_specified; + int err; + +#ifdef CONFIG_MODULES +replay: +#endif + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + if (err < 0) + return err; + + err = rtnl_ensure_unique_netns(tb, extack, false); + if (err < 0) + return err; + + ifm = nlmsg_data(nlh); + if (ifm->ifi_index > 0) { + link_specified = true; + dev = __dev_get_by_index(net, ifm->ifi_index); + } else if (ifm->ifi_index < 0) { + NL_SET_ERR_MSG(extack, "ifindex can't be negative"); + return -EINVAL; + } else if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME]) { + link_specified = true; + dev = rtnl_dev_get(net, tb); + } else { + link_specified = false; + dev = NULL; + } + + master_dev = NULL; + m_ops = NULL; + if (dev) { + master_dev = netdev_master_upper_dev_get(dev); + if (master_dev) + m_ops = master_dev->rtnl_link_ops; + } + + err = validate_linkmsg(dev, tb, extack); + if 
(err < 0) + return err; + + if (tb[IFLA_LINKINFO]) { + err = nla_parse_nested_deprecated(linkinfo, IFLA_INFO_MAX, + tb[IFLA_LINKINFO], + ifla_info_policy, NULL); + if (err < 0) + return err; + } else + memset(linkinfo, 0, sizeof(linkinfo)); + + if (linkinfo[IFLA_INFO_KIND]) { + nla_strscpy(kind, linkinfo[IFLA_INFO_KIND], sizeof(kind)); + ops = rtnl_link_ops_get(kind); + } else { + kind[0] = '\0'; + ops = NULL; + } + + data = NULL; + if (ops) { + if (ops->maxtype > RTNL_MAX_TYPE) + return -EINVAL; + + if (ops->maxtype && linkinfo[IFLA_INFO_DATA]) { + err = nla_parse_nested_deprecated(tbs->attr, ops->maxtype, + linkinfo[IFLA_INFO_DATA], + ops->policy, extack); + if (err < 0) + return err; + data = tbs->attr; + } + if (ops->validate) { + err = ops->validate(tb, data, extack); + if (err < 0) + return err; + } + } + + slave_data = NULL; + if (m_ops) { + if (m_ops->slave_maxtype > RTNL_SLAVE_MAX_TYPE) + return -EINVAL; + + if (m_ops->slave_maxtype && + linkinfo[IFLA_INFO_SLAVE_DATA]) { + err = nla_parse_nested_deprecated(tbs->slave_attr, + m_ops->slave_maxtype, + linkinfo[IFLA_INFO_SLAVE_DATA], + m_ops->slave_policy, + extack); + if (err < 0) + return err; + slave_data = tbs->slave_attr; + } + } + + if (dev) { + int status = 0; + + if (nlh->nlmsg_flags & NLM_F_EXCL) + return -EEXIST; + if (nlh->nlmsg_flags & NLM_F_REPLACE) + return -EOPNOTSUPP; + + if (linkinfo[IFLA_INFO_DATA]) { + if (!ops || ops != dev->rtnl_link_ops || + !ops->changelink) + return -EOPNOTSUPP; + + err = ops->changelink(dev, tb, data, extack); + if (err < 0) + return err; + status |= DO_SETLINK_NOTIFY; + } + + if (linkinfo[IFLA_INFO_SLAVE_DATA]) { + if (!m_ops || !m_ops->slave_changelink) + return -EOPNOTSUPP; + + err = m_ops->slave_changelink(master_dev, dev, tb, + slave_data, extack); + if (err < 0) + return err; + status |= DO_SETLINK_NOTIFY; + } + + return do_setlink(skb, dev, ifm, extack, tb, status); + } + + if (!(nlh->nlmsg_flags & NLM_F_CREATE)) { + /* No dev found and NLM_F_CREATE not set. 
Requested dev does not exist, + * or it's for a group + */ + if (link_specified) + return -ENODEV; + if (tb[IFLA_GROUP]) + return rtnl_group_changelink(skb, net, + nla_get_u32(tb[IFLA_GROUP]), + ifm, extack, tb); + return -ENODEV; + } + + if (tb[IFLA_MAP] || tb[IFLA_PROTINFO]) + return -EOPNOTSUPP; + + if (!ops) { +#ifdef CONFIG_MODULES + if (kind[0]) { + __rtnl_unlock(); + request_module("rtnl-link-%s", kind); + rtnl_lock(); + ops = rtnl_link_ops_get(kind); + if (ops) + goto replay; + } +#endif + NL_SET_ERR_MSG(extack, "Unknown device type"); + return -EOPNOTSUPP; + } + + return rtnl_newlink_create(skb, ifm, ops, tb, data, extack); +} + +static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct rtnl_newlink_tbs *tbs; + int ret; + + tbs = kmalloc(sizeof(*tbs), GFP_KERNEL); + if (!tbs) + return -ENOMEM; + + ret = __rtnl_newlink(skb, nlh, tbs, extack); + kfree(tbs); + return ret; +} + +static int rtnl_valid_getlink_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for get link"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change) { + NL_SET_ERR_MSG(extack, "Invalid values in header for get link request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*ifm), tb, IFLA_MAX, + ifla_policy, extack); + if (err) + return err; + + for (i = 0; i <= IFLA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFLA_IFNAME: + case IFLA_ALT_IFNAME: + case IFLA_EXT_MASK: + case IFLA_TARGET_NETNSID: + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in get link request"); + return -EINVAL; + } + } + + return 0; +} + +static int rtnl_getlink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net *tgt_net = net; + struct ifinfomsg *ifm; + struct nlattr *tb[IFLA_MAX+1]; + struct net_device *dev = NULL; + struct sk_buff *nskb; + int netnsid = -1; + int err; + u32 ext_filter_mask = 0; + + err = rtnl_valid_getlink_req(skb, nlh, tb, extack); + if (err < 0) + return err; + + err = rtnl_ensure_unique_netns(tb, extack, true); + if (err < 0) + return err; + + if (tb[IFLA_TARGET_NETNSID]) { + netnsid = nla_get_s32(tb[IFLA_TARGET_NETNSID]); + tgt_net = rtnl_get_net_ns_capable(NETLINK_CB(skb).sk, netnsid); + if (IS_ERR(tgt_net)) + return PTR_ERR(tgt_net); + } + + if (tb[IFLA_EXT_MASK]) + ext_filter_mask = nla_get_u32(tb[IFLA_EXT_MASK]); + + err = -EINVAL; + ifm = nlmsg_data(nlh); + if (ifm->ifi_index > 0) + dev = __dev_get_by_index(tgt_net, ifm->ifi_index); + else if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME]) + dev = rtnl_dev_get(tgt_net, tb); + else + goto out; + + err = -ENODEV; + if (dev == NULL) + goto out; + + err = -ENOBUFS; + nskb = nlmsg_new(if_nlmsg_size(dev, ext_filter_mask), GFP_KERNEL); + if (nskb == NULL) + goto out; + + err = rtnl_fill_ifinfo(nskb, dev, net, + RTM_NEWLINK, NETLINK_CB(skb).portid, + nlh->nlmsg_seq, 0, 0, ext_filter_mask, + 0, NULL, 0, netnsid, GFP_KERNEL); + if (err < 0) { + /* -EMSGSIZE implies BUG in if_nlmsg_size */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(nskb); + } else + err = rtnl_unicast(nskb, net, 
NETLINK_CB(skb).portid); +out: + if (netnsid >= 0) + put_net(tgt_net); + + return err; +} + +static int rtnl_alt_ifname(int cmd, struct net_device *dev, struct nlattr *attr, + bool *changed, struct netlink_ext_ack *extack) +{ + char *alt_ifname; + size_t size; + int err; + + err = nla_validate(attr, attr->nla_len, IFLA_MAX, ifla_policy, extack); + if (err) + return err; + + if (cmd == RTM_NEWLINKPROP) { + size = rtnl_prop_list_size(dev); + size += nla_total_size(ALTIFNAMSIZ); + if (size >= U16_MAX) { + NL_SET_ERR_MSG(extack, + "effective property list too long"); + return -EINVAL; + } + } + + alt_ifname = nla_strdup(attr, GFP_KERNEL_ACCOUNT); + if (!alt_ifname) + return -ENOMEM; + + if (cmd == RTM_NEWLINKPROP) { + err = netdev_name_node_alt_create(dev, alt_ifname); + if (!err) + alt_ifname = NULL; + } else if (cmd == RTM_DELLINKPROP) { + err = netdev_name_node_alt_destroy(dev, alt_ifname); + } else { + WARN_ON_ONCE(1); + err = -EINVAL; + } + + kfree(alt_ifname); + if (!err) + *changed = true; + return err; +} + +static int rtnl_linkprop(int cmd, struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFLA_MAX + 1]; + struct net_device *dev; + struct ifinfomsg *ifm; + bool changed = false; + struct nlattr *attr; + int err, rem; + + err = nlmsg_parse(nlh, sizeof(*ifm), tb, IFLA_MAX, ifla_policy, extack); + if (err) + return err; + + err = rtnl_ensure_unique_netns(tb, extack, true); + if (err) + return err; + + ifm = nlmsg_data(nlh); + if (ifm->ifi_index > 0) + dev = __dev_get_by_index(net, ifm->ifi_index); + else if (tb[IFLA_IFNAME] || tb[IFLA_ALT_IFNAME]) + dev = rtnl_dev_get(net, tb); + else + return -EINVAL; + + if (!dev) + return -ENODEV; + + if (!tb[IFLA_PROP_LIST]) + return 0; + + nla_for_each_nested(attr, tb[IFLA_PROP_LIST], rem) { + switch (nla_type(attr)) { + case IFLA_ALT_IFNAME: + err = rtnl_alt_ifname(cmd, dev, attr, &changed, extack); + if (err) + return err; + break; + } + } + + if (changed) + netdev_state_change(dev); + return 0; +} + +static int rtnl_newlinkprop(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + return rtnl_linkprop(RTM_NEWLINKPROP, skb, nlh, extack); +} + +static int rtnl_dellinkprop(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + return rtnl_linkprop(RTM_DELLINKPROP, skb, nlh, extack); +} + +static u32 rtnl_calcit(struct sk_buff *skb, struct nlmsghdr *nlh) +{ + struct net *net = sock_net(skb->sk); + size_t min_ifinfo_dump_size = 0; + struct nlattr *tb[IFLA_MAX+1]; + u32 ext_filter_mask = 0; + struct net_device *dev; + int hdrlen; + + /* Same kernel<->userspace interface hack as in rtnl_dump_ifinfo. */ + hdrlen = nlmsg_len(nlh) < sizeof(struct ifinfomsg) ? + sizeof(struct rtgenmsg) : sizeof(struct ifinfomsg); + + if (nlmsg_parse_deprecated(nlh, hdrlen, tb, IFLA_MAX, ifla_policy, NULL) >= 0) { + if (tb[IFLA_EXT_MASK]) + ext_filter_mask = nla_get_u32(tb[IFLA_EXT_MASK]); + } + + if (!ext_filter_mask) + return NLMSG_GOODSIZE; + /* + * traverse the list of net devices and compute the minimum + * buffer size based upon the filter mask. 
+ */ + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + min_ifinfo_dump_size = max(min_ifinfo_dump_size, + if_nlmsg_size(dev, ext_filter_mask)); + } + rcu_read_unlock(); + + return nlmsg_total_size(min_ifinfo_dump_size); +} + +static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb) +{ + int idx; + int s_idx = cb->family; + int type = cb->nlh->nlmsg_type - RTM_BASE; + int ret = 0; + + if (s_idx == 0) + s_idx = 1; + + for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) { + struct rtnl_link __rcu **tab; + struct rtnl_link *link; + rtnl_dumpit_func dumpit; + + if (idx < s_idx || idx == PF_PACKET) + continue; + + if (type < 0 || type >= RTM_NR_MSGTYPES) + continue; + + tab = rcu_dereference_rtnl(rtnl_msg_handlers[idx]); + if (!tab) + continue; + + link = rcu_dereference_rtnl(tab[type]); + if (!link) + continue; + + dumpit = link->dumpit; + if (!dumpit) + continue; + + if (idx > s_idx) { + memset(&cb->args[0], 0, sizeof(cb->args)); + cb->prev_seq = 0; + cb->seq = 0; + } + ret = dumpit(skb, cb); + if (ret) + break; + } + cb->family = idx; + + return skb->len ? : ret; +} + +struct sk_buff *rtmsg_ifinfo_build_skb(int type, struct net_device *dev, + unsigned int change, + u32 event, gfp_t flags, int *new_nsid, + int new_ifindex) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(if_nlmsg_size(dev, 0), flags); + if (skb == NULL) + goto errout; + + err = rtnl_fill_ifinfo(skb, dev, dev_net(dev), + type, 0, 0, change, 0, 0, event, + new_nsid, new_ifindex, -1, flags); + if (err < 0) { + /* -EMSGSIZE implies BUG in if_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + return skb; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_LINK, err); + return NULL; +} + +void rtmsg_ifinfo_send(struct sk_buff *skb, struct net_device *dev, gfp_t flags) +{ + struct net *net = dev_net(dev); + + rtnl_notify(skb, net, 0, RTNLGRP_LINK, NULL, flags); +} + +static void rtmsg_ifinfo_event(int type, struct net_device *dev, + unsigned int change, u32 event, + gfp_t flags, int *new_nsid, int new_ifindex) +{ + struct sk_buff *skb; + + if (dev->reg_state != NETREG_REGISTERED) + return; + + skb = rtmsg_ifinfo_build_skb(type, dev, change, event, flags, new_nsid, + new_ifindex); + if (skb) + rtmsg_ifinfo_send(skb, dev, flags); +} + +void rtmsg_ifinfo(int type, struct net_device *dev, unsigned int change, + gfp_t flags) +{ + rtmsg_ifinfo_event(type, dev, change, rtnl_get_event(0), flags, + NULL, 0); +} + +void rtmsg_ifinfo_newnet(int type, struct net_device *dev, unsigned int change, + gfp_t flags, int *new_nsid, int new_ifindex) +{ + rtmsg_ifinfo_event(type, dev, change, rtnl_get_event(0), flags, + new_nsid, new_ifindex); +} + +static int nlmsg_populate_fdb_fill(struct sk_buff *skb, + struct net_device *dev, + u8 *addr, u16 vid, u32 pid, u32 seq, + int type, unsigned int flags, + int nlflags, u16 ndm_state) +{ + struct nlmsghdr *nlh; + struct ndmsg *ndm; + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ndm), nlflags); + if (!nlh) + return -EMSGSIZE; + + ndm = nlmsg_data(nlh); + ndm->ndm_family = AF_BRIDGE; + ndm->ndm_pad1 = 0; + ndm->ndm_pad2 = 0; + ndm->ndm_flags = flags; + ndm->ndm_type = 0; + ndm->ndm_ifindex = dev->ifindex; + ndm->ndm_state = ndm_state; + + if (nla_put(skb, NDA_LLADDR, dev->addr_len, addr)) + goto nla_put_failure; + if (vid) + if (nla_put(skb, NDA_VLAN, sizeof(u16), &vid)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + 
+static inline size_t rtnl_fdb_nlmsg_size(const struct net_device *dev) +{ + return NLMSG_ALIGN(sizeof(struct ndmsg)) + + nla_total_size(dev->addr_len) + /* NDA_LLADDR */ + nla_total_size(sizeof(u16)) + /* NDA_VLAN */ + 0; +} + +static void rtnl_fdb_notify(struct net_device *dev, u8 *addr, u16 vid, int type, + u16 ndm_state) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(rtnl_fdb_nlmsg_size(dev), GFP_ATOMIC); + if (!skb) + goto errout; + + err = nlmsg_populate_fdb_fill(skb, dev, addr, vid, + 0, 0, type, NTF_SELF, 0, ndm_state); + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_NEIGH, NULL, GFP_ATOMIC); + return; +errout: + rtnl_set_sk_err(net, RTNLGRP_NEIGH, err); +} + +/* + * ndo_dflt_fdb_add - default netdevice operation to add an FDB entry + */ +int ndo_dflt_fdb_add(struct ndmsg *ndm, + struct nlattr *tb[], + struct net_device *dev, + const unsigned char *addr, u16 vid, + u16 flags) +{ + int err = -EINVAL; + + /* If aging addresses are supported device will need to + * implement its own handler for this. + */ + if (ndm->ndm_state && !(ndm->ndm_state & NUD_PERMANENT)) { + netdev_info(dev, "default FDB implementation only supports local addresses\n"); + return err; + } + + if (vid) { + netdev_info(dev, "vlans aren't supported yet for dev_uc|mc_add()\n"); + return err; + } + + if (is_unicast_ether_addr(addr) || is_link_local_ether_addr(addr)) + err = dev_uc_add_excl(dev, addr); + else if (is_multicast_ether_addr(addr)) + err = dev_mc_add_excl(dev, addr); + + /* Only return duplicate errors if NLM_F_EXCL is set */ + if (err == -EEXIST && !(flags & NLM_F_EXCL)) + err = 0; + + return err; +} +EXPORT_SYMBOL(ndo_dflt_fdb_add); + +static int fdb_vid_parse(struct nlattr *vlan_attr, u16 *p_vid, + struct netlink_ext_ack *extack) +{ + u16 vid = 0; + + if (vlan_attr) { + if (nla_len(vlan_attr) != sizeof(u16)) { + NL_SET_ERR_MSG(extack, "invalid vlan attribute size"); + return -EINVAL; + } + + vid = nla_get_u16(vlan_attr); + + if (!vid || vid >= VLAN_VID_MASK) { + NL_SET_ERR_MSG(extack, "invalid vlan id"); + return -EINVAL; + } + } + *p_vid = vid; + return 0; +} + +static int rtnl_fdb_add(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ndmsg *ndm; + struct nlattr *tb[NDA_MAX+1]; + struct net_device *dev; + u8 *addr; + u16 vid; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ndm), tb, NDA_MAX, NULL, + extack); + if (err < 0) + return err; + + ndm = nlmsg_data(nlh); + if (ndm->ndm_ifindex == 0) { + NL_SET_ERR_MSG(extack, "invalid ifindex"); + return -EINVAL; + } + + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (dev == NULL) { + NL_SET_ERR_MSG(extack, "unknown ifindex"); + return -ENODEV; + } + + if (!tb[NDA_LLADDR] || nla_len(tb[NDA_LLADDR]) != ETH_ALEN) { + NL_SET_ERR_MSG(extack, "invalid address"); + return -EINVAL; + } + + if (dev->type != ARPHRD_ETHER) { + NL_SET_ERR_MSG(extack, "FDB add only supported for Ethernet devices"); + return -EINVAL; + } + + addr = nla_data(tb[NDA_LLADDR]); + + err = fdb_vid_parse(tb[NDA_VLAN], &vid, extack); + if (err) + return err; + + err = -EOPNOTSUPP; + + /* Support fdb on master device the net/bridge default case */ + if ((!ndm->ndm_flags || ndm->ndm_flags & NTF_MASTER) && + netif_is_bridge_port(dev)) { + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + const struct net_device_ops *ops = br_dev->netdev_ops; + + err = ops->ndo_fdb_add(ndm, tb, dev, addr, vid, + 
nlh->nlmsg_flags, extack); + if (err) + goto out; + else + ndm->ndm_flags &= ~NTF_MASTER; + } + + /* Embedded bridge, macvlan, and any other device support */ + if ((ndm->ndm_flags & NTF_SELF)) { + if (dev->netdev_ops->ndo_fdb_add) + err = dev->netdev_ops->ndo_fdb_add(ndm, tb, dev, addr, + vid, + nlh->nlmsg_flags, + extack); + else + err = ndo_dflt_fdb_add(ndm, tb, dev, addr, vid, + nlh->nlmsg_flags); + + if (!err) { + rtnl_fdb_notify(dev, addr, vid, RTM_NEWNEIGH, + ndm->ndm_state); + ndm->ndm_flags &= ~NTF_SELF; + } + } +out: + return err; +} + +/* + * ndo_dflt_fdb_del - default netdevice operation to delete an FDB entry + */ +int ndo_dflt_fdb_del(struct ndmsg *ndm, + struct nlattr *tb[], + struct net_device *dev, + const unsigned char *addr, u16 vid) +{ + int err = -EINVAL; + + /* If aging addresses are supported device will need to + * implement its own handler for this. + */ + if (!(ndm->ndm_state & NUD_PERMANENT)) { + netdev_info(dev, "default FDB implementation only supports local addresses\n"); + return err; + } + + if (is_unicast_ether_addr(addr) || is_link_local_ether_addr(addr)) + err = dev_uc_del(dev, addr); + else if (is_multicast_ether_addr(addr)) + err = dev_mc_del(dev, addr); + + return err; +} +EXPORT_SYMBOL(ndo_dflt_fdb_del); + +static const struct nla_policy fdb_del_bulk_policy[NDA_MAX + 1] = { + [NDA_VLAN] = { .type = NLA_U16 }, + [NDA_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 1), + [NDA_NDM_STATE_MASK] = { .type = NLA_U16 }, + [NDA_NDM_FLAGS_MASK] = { .type = NLA_U8 }, +}; + +static int rtnl_fdb_del(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + bool del_bulk = !!(nlh->nlmsg_flags & NLM_F_BULK); + struct net *net = sock_net(skb->sk); + const struct net_device_ops *ops; + struct ndmsg *ndm; + struct nlattr *tb[NDA_MAX+1]; + struct net_device *dev; + __u8 *addr = NULL; + int err; + u16 vid; + + if (!netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (!del_bulk) { + err = nlmsg_parse_deprecated(nlh, sizeof(*ndm), tb, NDA_MAX, + NULL, extack); + } else { + err = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, + fdb_del_bulk_policy, extack); + } + if (err < 0) + return err; + + ndm = nlmsg_data(nlh); + if (ndm->ndm_ifindex == 0) { + NL_SET_ERR_MSG(extack, "invalid ifindex"); + return -EINVAL; + } + + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (dev == NULL) { + NL_SET_ERR_MSG(extack, "unknown ifindex"); + return -ENODEV; + } + + if (!del_bulk) { + if (!tb[NDA_LLADDR] || nla_len(tb[NDA_LLADDR]) != ETH_ALEN) { + NL_SET_ERR_MSG(extack, "invalid address"); + return -EINVAL; + } + addr = nla_data(tb[NDA_LLADDR]); + } + + if (dev->type != ARPHRD_ETHER) { + NL_SET_ERR_MSG(extack, "FDB delete only supported for Ethernet devices"); + return -EINVAL; + } + + err = fdb_vid_parse(tb[NDA_VLAN], &vid, extack); + if (err) + return err; + + err = -EOPNOTSUPP; + + /* Support fdb on master device the net/bridge default case */ + if ((!ndm->ndm_flags || ndm->ndm_flags & NTF_MASTER) && + netif_is_bridge_port(dev)) { + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + + ops = br_dev->netdev_ops; + if (!del_bulk) { + if (ops->ndo_fdb_del) + err = ops->ndo_fdb_del(ndm, tb, dev, addr, vid, extack); + } else { + if (ops->ndo_fdb_del_bulk) + err = ops->ndo_fdb_del_bulk(ndm, tb, dev, vid, + extack); + } + + if (err) + goto out; + else + ndm->ndm_flags &= ~NTF_MASTER; + } + + /* Embedded bridge, macvlan, and any other device support */ + if (ndm->ndm_flags & NTF_SELF) { + ops = dev->netdev_ops; + if (!del_bulk) { + if (ops->ndo_fdb_del) + err = 
ops->ndo_fdb_del(ndm, tb, dev, addr, vid, extack); + else + err = ndo_dflt_fdb_del(ndm, tb, dev, addr, vid); + } else { + /* in case err was cleared by NTF_MASTER call */ + err = -EOPNOTSUPP; + if (ops->ndo_fdb_del_bulk) + err = ops->ndo_fdb_del_bulk(ndm, tb, dev, vid, + extack); + } + + if (!err) { + if (!del_bulk) + rtnl_fdb_notify(dev, addr, vid, RTM_DELNEIGH, + ndm->ndm_state); + ndm->ndm_flags &= ~NTF_SELF; + } + } +out: + return err; +} + +static int nlmsg_populate_fdb(struct sk_buff *skb, + struct netlink_callback *cb, + struct net_device *dev, + int *idx, + struct netdev_hw_addr_list *list) +{ + struct netdev_hw_addr *ha; + int err; + u32 portid, seq; + + portid = NETLINK_CB(cb->skb).portid; + seq = cb->nlh->nlmsg_seq; + + list_for_each_entry(ha, &list->list, list) { + if (*idx < cb->args[2]) + goto skip; + + err = nlmsg_populate_fdb_fill(skb, dev, ha->addr, 0, + portid, seq, + RTM_NEWNEIGH, NTF_SELF, + NLM_F_MULTI, NUD_PERMANENT); + if (err < 0) + return err; +skip: + *idx += 1; + } + return 0; +} + +/** + * ndo_dflt_fdb_dump - default netdevice operation to dump an FDB table. + * @skb: socket buffer to store message in + * @cb: netlink callback + * @dev: netdevice + * @filter_dev: ignored + * @idx: the number of FDB table entries dumped is added to *@idx + * + * Default netdevice operation to dump the existing unicast address list. + * Returns number of addresses from list put in skb. + */ +int ndo_dflt_fdb_dump(struct sk_buff *skb, + struct netlink_callback *cb, + struct net_device *dev, + struct net_device *filter_dev, + int *idx) +{ + int err; + + if (dev->type != ARPHRD_ETHER) + return -EINVAL; + + netif_addr_lock_bh(dev); + err = nlmsg_populate_fdb(skb, cb, dev, idx, &dev->uc); + if (err) + goto out; + err = nlmsg_populate_fdb(skb, cb, dev, idx, &dev->mc); +out: + netif_addr_unlock_bh(dev); + return err; +} +EXPORT_SYMBOL(ndo_dflt_fdb_dump); + +static int valid_fdb_dump_strict(const struct nlmsghdr *nlh, + int *br_idx, int *brport_idx, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[NDA_MAX + 1]; + struct ndmsg *ndm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for fdb dump request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_state || + ndm->ndm_flags || ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for fdb dump request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct ndmsg), tb, + NDA_MAX, NULL, extack); + if (err < 0) + return err; + + *brport_idx = ndm->ndm_ifindex; + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case NDA_IFINDEX: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid IFINDEX attribute in fdb dump request"); + return -EINVAL; + } + *brport_idx = nla_get_u32(tb[NDA_IFINDEX]); + break; + case NDA_MASTER: + if (nla_len(tb[i]) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid MASTER attribute in fdb dump request"); + return -EINVAL; + } + *br_idx = nla_get_u32(tb[NDA_MASTER]); + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in fdb dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int valid_fdb_dump_legacy(const struct nlmsghdr *nlh, + int *br_idx, int *brport_idx, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MAX+1]; + int err; + + /* A hack to preserve kernel<->userspace interface. + * Before Linux v4.12 this code accepted ndmsg since iproute2 v3.3.0. 
+ * However, ndmsg is shorter than ifinfomsg thus nlmsg_parse() bails. + * So, check for ndmsg with an optional u32 attribute (not used here). + * Fortunately these sizes don't conflict with the size of ifinfomsg + * with an optional attribute. + */ + if (nlmsg_len(nlh) != sizeof(struct ndmsg) && + (nlmsg_len(nlh) != sizeof(struct ndmsg) + + nla_attr_size(sizeof(u32)))) { + struct ifinfomsg *ifm; + + err = nlmsg_parse_deprecated(nlh, sizeof(struct ifinfomsg), + tb, IFLA_MAX, ifla_policy, + extack); + if (err < 0) { + return -EINVAL; + } else if (err == 0) { + if (tb[IFLA_MASTER]) + *br_idx = nla_get_u32(tb[IFLA_MASTER]); + } + + ifm = nlmsg_data(nlh); + *brport_idx = ifm->ifi_index; + } + return 0; +} + +static int rtnl_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net_device *dev; + struct net_device *br_dev = NULL; + const struct net_device_ops *ops = NULL; + const struct net_device_ops *cops = NULL; + struct net *net = sock_net(skb->sk); + struct hlist_head *head; + int brport_idx = 0; + int br_idx = 0; + int h, s_h; + int idx = 0, s_idx; + int err = 0; + int fidx = 0; + + if (cb->strict_check) + err = valid_fdb_dump_strict(cb->nlh, &br_idx, &brport_idx, + cb->extack); + else + err = valid_fdb_dump_legacy(cb->nlh, &br_idx, &brport_idx, + cb->extack); + if (err < 0) + return err; + + if (br_idx) { + br_dev = __dev_get_by_index(net, br_idx); + if (!br_dev) + return -ENODEV; + + ops = br_dev->netdev_ops; + } + + s_h = cb->args[0]; + s_idx = cb->args[1]; + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + hlist_for_each_entry(dev, head, index_hlist) { + + if (brport_idx && (dev->ifindex != brport_idx)) + continue; + + if (!br_idx) { /* user did not specify a specific bridge */ + if (netif_is_bridge_port(dev)) { + br_dev = netdev_master_upper_dev_get(dev); + cops = br_dev->netdev_ops; + } + } else { + if (dev != br_dev && + !netif_is_bridge_port(dev)) + continue; + + if (br_dev != netdev_master_upper_dev_get(dev) && + !netif_is_bridge_master(dev)) + continue; + cops = ops; + } + + if (idx < s_idx) + goto cont; + + if (netif_is_bridge_port(dev)) { + if (cops && cops->ndo_fdb_dump) { + err = cops->ndo_fdb_dump(skb, cb, + br_dev, dev, + &fidx); + if (err == -EMSGSIZE) + goto out; + } + } + + if (dev->netdev_ops->ndo_fdb_dump) + err = dev->netdev_ops->ndo_fdb_dump(skb, cb, + dev, NULL, + &fidx); + else + err = ndo_dflt_fdb_dump(skb, cb, dev, NULL, + &fidx); + if (err == -EMSGSIZE) + goto out; + + cops = NULL; + + /* reset fdb offset to 0 for rest of the interfaces */ + cb->args[2] = 0; + fidx = 0; +cont: + idx++; + } + } + +out: + cb->args[0] = h; + cb->args[1] = idx; + cb->args[2] = fidx; + + return skb->len; +} + +static int valid_fdb_get_strict(const struct nlmsghdr *nlh, + struct nlattr **tb, u8 *ndm_flags, + int *br_idx, int *brport_idx, u8 **addr, + u16 *vid, struct netlink_ext_ack *extack) +{ + struct ndmsg *ndm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ndm))) { + NL_SET_ERR_MSG(extack, "Invalid header for fdb get request"); + return -EINVAL; + } + + ndm = nlmsg_data(nlh); + if (ndm->ndm_pad1 || ndm->ndm_pad2 || ndm->ndm_state || + ndm->ndm_type) { + NL_SET_ERR_MSG(extack, "Invalid values in header for fdb get request"); + return -EINVAL; + } + + if (ndm->ndm_flags & ~(NTF_MASTER | NTF_SELF)) { + NL_SET_ERR_MSG(extack, "Invalid flags in header for fdb get request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct ndmsg), tb, + NDA_MAX, nda_policy, extack); 
+ if (err < 0) + return err; + + *ndm_flags = ndm->ndm_flags; + *brport_idx = ndm->ndm_ifindex; + for (i = 0; i <= NDA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case NDA_MASTER: + *br_idx = nla_get_u32(tb[i]); + break; + case NDA_LLADDR: + if (nla_len(tb[i]) != ETH_ALEN) { + NL_SET_ERR_MSG(extack, "Invalid address in fdb get request"); + return -EINVAL; + } + *addr = nla_data(tb[i]); + break; + case NDA_VLAN: + err = fdb_vid_parse(tb[i], vid, extack); + if (err) + return err; + break; + case NDA_VNI: + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in fdb get request"); + return -EINVAL; + } + } + + return 0; +} + +static int rtnl_fdb_get(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = NULL, *br_dev = NULL; + const struct net_device_ops *ops = NULL; + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[NDA_MAX + 1]; + struct sk_buff *skb; + int brport_idx = 0; + u8 ndm_flags = 0; + int br_idx = 0; + u8 *addr = NULL; + u16 vid = 0; + int err; + + err = valid_fdb_get_strict(nlh, tb, &ndm_flags, &br_idx, + &brport_idx, &addr, &vid, extack); + if (err < 0) + return err; + + if (!addr) { + NL_SET_ERR_MSG(extack, "Missing lookup address for fdb get request"); + return -EINVAL; + } + + if (brport_idx) { + dev = __dev_get_by_index(net, brport_idx); + if (!dev) { + NL_SET_ERR_MSG(extack, "Unknown device ifindex"); + return -ENODEV; + } + } + + if (br_idx) { + if (dev) { + NL_SET_ERR_MSG(extack, "Master and device are mutually exclusive"); + return -EINVAL; + } + + br_dev = __dev_get_by_index(net, br_idx); + if (!br_dev) { + NL_SET_ERR_MSG(extack, "Invalid master ifindex"); + return -EINVAL; + } + ops = br_dev->netdev_ops; + } + + if (dev) { + if (!ndm_flags || (ndm_flags & NTF_MASTER)) { + if (!netif_is_bridge_port(dev)) { + NL_SET_ERR_MSG(extack, "Device is not a bridge port"); + return -EINVAL; + } + br_dev = netdev_master_upper_dev_get(dev); + if (!br_dev) { + NL_SET_ERR_MSG(extack, "Master of device not found"); + return -EINVAL; + } + ops = br_dev->netdev_ops; + } else { + if (!(ndm_flags & NTF_SELF)) { + NL_SET_ERR_MSG(extack, "Missing NTF_SELF"); + return -EINVAL; + } + ops = dev->netdev_ops; + } + } + + if (!br_dev && !dev) { + NL_SET_ERR_MSG(extack, "No device specified"); + return -ENODEV; + } + + if (!ops || !ops->ndo_fdb_get) { + NL_SET_ERR_MSG(extack, "Fdb get operation not supported by device"); + return -EOPNOTSUPP; + } + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (br_dev) + dev = br_dev; + err = ops->ndo_fdb_get(skb, tb, dev, addr, vid, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, extack); + if (err) + goto out; + + return rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +out: + kfree_skb(skb); + return err; +} + +static int brport_nla_put_flag(struct sk_buff *skb, u32 flags, u32 mask, + unsigned int attrnum, unsigned int flag) +{ + if (mask & flag) + return nla_put_u8(skb, attrnum, !!(flags & flag)); + return 0; +} + +int ndo_dflt_bridge_getlink(struct sk_buff *skb, u32 pid, u32 seq, + struct net_device *dev, u16 mode, + u32 flags, u32 mask, int nlflags, + u32 filter_mask, + int (*vlan_fill)(struct sk_buff *skb, + struct net_device *dev, + u32 filter_mask)) +{ + struct nlmsghdr *nlh; + struct ifinfomsg *ifm; + struct nlattr *br_afspec; + struct nlattr *protinfo; + u8 operstate = netif_running(dev) ? 
dev->operstate : IF_OPER_DOWN; + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + int err = 0; + + nlh = nlmsg_put(skb, pid, seq, RTM_NEWLINK, sizeof(*ifm), nlflags); + if (nlh == NULL) + return -EMSGSIZE; + + ifm = nlmsg_data(nlh); + ifm->ifi_family = AF_BRIDGE; + ifm->__ifi_pad = 0; + ifm->ifi_type = dev->type; + ifm->ifi_index = dev->ifindex; + ifm->ifi_flags = dev_get_flags(dev); + ifm->ifi_change = 0; + + + if (nla_put_string(skb, IFLA_IFNAME, dev->name) || + nla_put_u32(skb, IFLA_MTU, dev->mtu) || + nla_put_u8(skb, IFLA_OPERSTATE, operstate) || + (br_dev && + nla_put_u32(skb, IFLA_MASTER, br_dev->ifindex)) || + (dev->addr_len && + nla_put(skb, IFLA_ADDRESS, dev->addr_len, dev->dev_addr)) || + (dev->ifindex != dev_get_iflink(dev) && + nla_put_u32(skb, IFLA_LINK, dev_get_iflink(dev)))) + goto nla_put_failure; + + br_afspec = nla_nest_start_noflag(skb, IFLA_AF_SPEC); + if (!br_afspec) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_BRIDGE_FLAGS, BRIDGE_FLAGS_SELF)) { + nla_nest_cancel(skb, br_afspec); + goto nla_put_failure; + } + + if (mode != BRIDGE_MODE_UNDEF) { + if (nla_put_u16(skb, IFLA_BRIDGE_MODE, mode)) { + nla_nest_cancel(skb, br_afspec); + goto nla_put_failure; + } + } + if (vlan_fill) { + err = vlan_fill(skb, dev, filter_mask); + if (err) { + nla_nest_cancel(skb, br_afspec); + goto nla_put_failure; + } + } + nla_nest_end(skb, br_afspec); + + protinfo = nla_nest_start(skb, IFLA_PROTINFO); + if (!protinfo) + goto nla_put_failure; + + if (brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_MODE, BR_HAIRPIN_MODE) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_GUARD, BR_BPDU_GUARD) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_FAST_LEAVE, + BR_MULTICAST_FAST_LEAVE) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_PROTECT, BR_ROOT_BLOCK) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_LEARNING, BR_LEARNING) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_LEARNING_SYNC, BR_LEARNING_SYNC) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_UNICAST_FLOOD, BR_FLOOD) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_PROXYARP, BR_PROXYARP) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_MCAST_FLOOD, BR_MCAST_FLOOD) || + brport_nla_put_flag(skb, flags, mask, + IFLA_BRPORT_BCAST_FLOOD, BR_BCAST_FLOOD)) { + nla_nest_cancel(skb, protinfo); + goto nla_put_failure; + } + + nla_nest_end(skb, protinfo); + + nlmsg_end(skb, nlh); + return 0; +nla_put_failure: + nlmsg_cancel(skb, nlh); + return err ? 
err : -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(ndo_dflt_bridge_getlink); + +static int valid_bridge_getlink_req(const struct nlmsghdr *nlh, + bool strict_check, u32 *filter_mask, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MAX+1]; + int err, i; + + if (strict_check) { + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid header for bridge link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Invalid values in header for bridge link dump request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, + sizeof(struct ifinfomsg), + tb, IFLA_MAX, ifla_policy, + extack); + } else { + err = nlmsg_parse_deprecated(nlh, sizeof(struct ifinfomsg), + tb, IFLA_MAX, ifla_policy, + extack); + } + if (err < 0) + return err; + + /* new attributes should only be added with strict checking */ + for (i = 0; i <= IFLA_MAX; ++i) { + if (!tb[i]) + continue; + + switch (i) { + case IFLA_EXT_MASK: + *filter_mask = nla_get_u32(tb[i]); + break; + default: + if (strict_check) { + NL_SET_ERR_MSG(extack, "Unsupported attribute in bridge link dump request"); + return -EINVAL; + } + } + } + + return 0; +} + +static int rtnl_bridge_getlink(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct net_device *dev; + int idx = 0; + u32 portid = NETLINK_CB(cb->skb).portid; + u32 seq = nlh->nlmsg_seq; + u32 filter_mask = 0; + int err; + + err = valid_bridge_getlink_req(nlh, cb->strict_check, &filter_mask, + cb->extack); + if (err < 0 && cb->strict_check) + return err; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + const struct net_device_ops *ops = dev->netdev_ops; + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + + if (br_dev && br_dev->netdev_ops->ndo_bridge_getlink) { + if (idx >= cb->args[0]) { + err = br_dev->netdev_ops->ndo_bridge_getlink( + skb, portid, seq, dev, + filter_mask, NLM_F_MULTI); + if (err < 0 && err != -EOPNOTSUPP) { + if (likely(skb->len)) + break; + + goto out_err; + } + } + idx++; + } + + if (ops->ndo_bridge_getlink) { + if (idx >= cb->args[0]) { + err = ops->ndo_bridge_getlink(skb, portid, + seq, dev, + filter_mask, + NLM_F_MULTI); + if (err < 0 && err != -EOPNOTSUPP) { + if (likely(skb->len)) + break; + + goto out_err; + } + } + idx++; + } + } + err = skb->len; +out_err: + rcu_read_unlock(); + cb->args[0] = idx; + + return err; +} + +static inline size_t bridge_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ifinfomsg)) + + nla_total_size(IFNAMSIZ) /* IFLA_IFNAME */ + + nla_total_size(MAX_ADDR_LEN) /* IFLA_ADDRESS */ + + nla_total_size(sizeof(u32)) /* IFLA_MASTER */ + + nla_total_size(sizeof(u32)) /* IFLA_MTU */ + + nla_total_size(sizeof(u32)) /* IFLA_LINK */ + + nla_total_size(sizeof(u32)) /* IFLA_OPERSTATE */ + + nla_total_size(sizeof(u8)) /* IFLA_PROTINFO */ + + nla_total_size(sizeof(struct nlattr)) /* IFLA_AF_SPEC */ + + nla_total_size(sizeof(u16)) /* IFLA_BRIDGE_FLAGS */ + + nla_total_size(sizeof(u16)); /* IFLA_BRIDGE_MODE */ +} + +static int rtnl_bridge_notify(struct net_device *dev) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + int err = -EOPNOTSUPP; + + if (!dev->netdev_ops->ndo_bridge_getlink) + return 0; + + skb = nlmsg_new(bridge_nlmsg_size(), GFP_ATOMIC); + if (!skb) { + err = -ENOMEM; + goto errout; + } + + err = 
dev->netdev_ops->ndo_bridge_getlink(skb, 0, 0, dev, 0, 0); + if (err < 0) + goto errout; + + /* Notification info is only filled for bridge ports, not the bridge + * device itself. Therefore, a zero notification length is valid and + * should not result in an error. + */ + if (!skb->len) + goto errout; + + rtnl_notify(skb, net, 0, RTNLGRP_LINK, NULL, GFP_ATOMIC); + return 0; +errout: + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + if (err) + rtnl_set_sk_err(net, RTNLGRP_LINK, err); + return err; +} + +static int rtnl_bridge_setlink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifinfomsg *ifm; + struct net_device *dev; + struct nlattr *br_spec, *attr = NULL; + int rem, err = -EOPNOTSUPP; + u16 flags = 0; + bool have_flags = false; + + if (nlmsg_len(nlh) < sizeof(*ifm)) + return -EINVAL; + + ifm = nlmsg_data(nlh); + if (ifm->ifi_family != AF_BRIDGE) + return -EPFNOSUPPORT; + + dev = __dev_get_by_index(net, ifm->ifi_index); + if (!dev) { + NL_SET_ERR_MSG(extack, "unknown ifindex"); + return -ENODEV; + } + + br_spec = nlmsg_find_attr(nlh, sizeof(struct ifinfomsg), IFLA_AF_SPEC); + if (br_spec) { + nla_for_each_nested(attr, br_spec, rem) { + if (nla_type(attr) == IFLA_BRIDGE_FLAGS && !have_flags) { + if (nla_len(attr) < sizeof(flags)) + return -EINVAL; + + have_flags = true; + flags = nla_get_u16(attr); + } + + if (nla_type(attr) == IFLA_BRIDGE_MODE) { + if (nla_len(attr) < sizeof(u16)) + return -EINVAL; + } + } + } + + if (!flags || (flags & BRIDGE_FLAGS_MASTER)) { + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + + if (!br_dev || !br_dev->netdev_ops->ndo_bridge_setlink) { + err = -EOPNOTSUPP; + goto out; + } + + err = br_dev->netdev_ops->ndo_bridge_setlink(dev, nlh, flags, + extack); + if (err) + goto out; + + flags &= ~BRIDGE_FLAGS_MASTER; + } + + if ((flags & BRIDGE_FLAGS_SELF)) { + if (!dev->netdev_ops->ndo_bridge_setlink) + err = -EOPNOTSUPP; + else + err = dev->netdev_ops->ndo_bridge_setlink(dev, nlh, + flags, + extack); + if (!err) { + flags &= ~BRIDGE_FLAGS_SELF; + + /* Generate event to notify upper layer of bridge + * change + */ + err = rtnl_bridge_notify(dev); + } + } + + if (have_flags) + memcpy(nla_data(attr), &flags, sizeof(flags)); +out: + return err; +} + +static int rtnl_bridge_dellink(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifinfomsg *ifm; + struct net_device *dev; + struct nlattr *br_spec, *attr = NULL; + int rem, err = -EOPNOTSUPP; + u16 flags = 0; + bool have_flags = false; + + if (nlmsg_len(nlh) < sizeof(*ifm)) + return -EINVAL; + + ifm = nlmsg_data(nlh); + if (ifm->ifi_family != AF_BRIDGE) + return -EPFNOSUPPORT; + + dev = __dev_get_by_index(net, ifm->ifi_index); + if (!dev) { + NL_SET_ERR_MSG(extack, "unknown ifindex"); + return -ENODEV; + } + + br_spec = nlmsg_find_attr(nlh, sizeof(struct ifinfomsg), IFLA_AF_SPEC); + if (br_spec) { + nla_for_each_nested(attr, br_spec, rem) { + if (nla_type(attr) == IFLA_BRIDGE_FLAGS) { + if (nla_len(attr) < sizeof(flags)) + return -EINVAL; + + have_flags = true; + flags = nla_get_u16(attr); + break; + } + } + } + + if (!flags || (flags & BRIDGE_FLAGS_MASTER)) { + struct net_device *br_dev = netdev_master_upper_dev_get(dev); + + if (!br_dev || !br_dev->netdev_ops->ndo_bridge_dellink) { + err = -EOPNOTSUPP; + goto out; + } + + err = br_dev->netdev_ops->ndo_bridge_dellink(dev, nlh, flags); + if (err) + goto out; + + flags &= ~BRIDGE_FLAGS_MASTER; + } + + 
if ((flags & BRIDGE_FLAGS_SELF)) { + if (!dev->netdev_ops->ndo_bridge_dellink) + err = -EOPNOTSUPP; + else + err = dev->netdev_ops->ndo_bridge_dellink(dev, nlh, + flags); + + if (!err) { + flags &= ~BRIDGE_FLAGS_SELF; + + /* Generate event to notify upper layer of bridge + * change + */ + err = rtnl_bridge_notify(dev); + } + } + + if (have_flags) + memcpy(nla_data(attr), &flags, sizeof(flags)); +out: + return err; +} + +static bool stats_attr_valid(unsigned int mask, int attrid, int idxattr) +{ + return (mask & IFLA_STATS_FILTER_BIT(attrid)) && + (!idxattr || idxattr == attrid); +} + +static bool +rtnl_offload_xstats_have_ndo(const struct net_device *dev, int attr_id) +{ + return dev->netdev_ops && + dev->netdev_ops->ndo_has_offload_stats && + dev->netdev_ops->ndo_get_offload_stats && + dev->netdev_ops->ndo_has_offload_stats(dev, attr_id); +} + +static unsigned int +rtnl_offload_xstats_get_size_ndo(const struct net_device *dev, int attr_id) +{ + return rtnl_offload_xstats_have_ndo(dev, attr_id) ? + sizeof(struct rtnl_link_stats64) : 0; +} + +static int +rtnl_offload_xstats_fill_ndo(struct net_device *dev, int attr_id, + struct sk_buff *skb) +{ + unsigned int size = rtnl_offload_xstats_get_size_ndo(dev, attr_id); + struct nlattr *attr = NULL; + void *attr_data; + int err; + + if (!size) + return -ENODATA; + + attr = nla_reserve_64bit(skb, attr_id, size, + IFLA_OFFLOAD_XSTATS_UNSPEC); + if (!attr) + return -EMSGSIZE; + + attr_data = nla_data(attr); + memset(attr_data, 0, size); + + err = dev->netdev_ops->ndo_get_offload_stats(attr_id, dev, attr_data); + if (err) + return err; + + return 0; +} + +static unsigned int +rtnl_offload_xstats_get_size_stats(const struct net_device *dev, + enum netdev_offload_xstats_type type) +{ + bool enabled = netdev_offload_xstats_enabled(dev, type); + + return enabled ? 
sizeof(struct rtnl_hw_stats64) : 0; +} + +struct rtnl_offload_xstats_request_used { + bool request; + bool used; +}; + +static int +rtnl_offload_xstats_get_stats(struct net_device *dev, + enum netdev_offload_xstats_type type, + struct rtnl_offload_xstats_request_used *ru, + struct rtnl_hw_stats64 *stats, + struct netlink_ext_ack *extack) +{ + bool request; + bool used; + int err; + + request = netdev_offload_xstats_enabled(dev, type); + if (!request) { + used = false; + goto out; + } + + err = netdev_offload_xstats_get(dev, type, stats, &used, extack); + if (err) + return err; + +out: + if (ru) { + ru->request = request; + ru->used = used; + } + return 0; +} + +static int +rtnl_offload_xstats_fill_hw_s_info_one(struct sk_buff *skb, int attr_id, + struct rtnl_offload_xstats_request_used *ru) +{ + struct nlattr *nest; + + nest = nla_nest_start(skb, attr_id); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u8(skb, IFLA_OFFLOAD_XSTATS_HW_S_INFO_REQUEST, ru->request)) + goto nla_put_failure; + + if (nla_put_u8(skb, IFLA_OFFLOAD_XSTATS_HW_S_INFO_USED, ru->used)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int +rtnl_offload_xstats_fill_hw_s_info(struct sk_buff *skb, struct net_device *dev, + struct netlink_ext_ack *extack) +{ + enum netdev_offload_xstats_type t_l3 = NETDEV_OFFLOAD_XSTATS_TYPE_L3; + struct rtnl_offload_xstats_request_used ru_l3; + struct nlattr *nest; + int err; + + err = rtnl_offload_xstats_get_stats(dev, t_l3, &ru_l3, NULL, extack); + if (err) + return err; + + nest = nla_nest_start(skb, IFLA_OFFLOAD_XSTATS_HW_S_INFO); + if (!nest) + return -EMSGSIZE; + + if (rtnl_offload_xstats_fill_hw_s_info_one(skb, + IFLA_OFFLOAD_XSTATS_L3_STATS, + &ru_l3)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int rtnl_offload_xstats_fill(struct sk_buff *skb, struct net_device *dev, + int *prividx, u32 off_filter_mask, + struct netlink_ext_ack *extack) +{ + enum netdev_offload_xstats_type t_l3 = NETDEV_OFFLOAD_XSTATS_TYPE_L3; + int attr_id_hw_s_info = IFLA_OFFLOAD_XSTATS_HW_S_INFO; + int attr_id_l3_stats = IFLA_OFFLOAD_XSTATS_L3_STATS; + int attr_id_cpu_hit = IFLA_OFFLOAD_XSTATS_CPU_HIT; + bool have_data = false; + int err; + + if (*prividx <= attr_id_cpu_hit && + (off_filter_mask & + IFLA_STATS_FILTER_BIT(attr_id_cpu_hit))) { + err = rtnl_offload_xstats_fill_ndo(dev, attr_id_cpu_hit, skb); + if (!err) { + have_data = true; + } else if (err != -ENODATA) { + *prividx = attr_id_cpu_hit; + return err; + } + } + + if (*prividx <= attr_id_hw_s_info && + (off_filter_mask & IFLA_STATS_FILTER_BIT(attr_id_hw_s_info))) { + *prividx = attr_id_hw_s_info; + + err = rtnl_offload_xstats_fill_hw_s_info(skb, dev, extack); + if (err) + return err; + + have_data = true; + *prividx = 0; + } + + if (*prividx <= attr_id_l3_stats && + (off_filter_mask & IFLA_STATS_FILTER_BIT(attr_id_l3_stats))) { + unsigned int size_l3; + struct nlattr *attr; + + *prividx = attr_id_l3_stats; + + size_l3 = rtnl_offload_xstats_get_size_stats(dev, t_l3); + if (!size_l3) + goto skip_l3_stats; + attr = nla_reserve_64bit(skb, attr_id_l3_stats, size_l3, + IFLA_OFFLOAD_XSTATS_UNSPEC); + if (!attr) + return -EMSGSIZE; + + err = rtnl_offload_xstats_get_stats(dev, t_l3, NULL, + nla_data(attr), extack); + if (err) + return err; + + have_data = true; +skip_l3_stats: + *prividx = 0; + } + + if (!have_data) + return -ENODATA; + + *prividx = 0; + 
return 0; +} + +static unsigned int +rtnl_offload_xstats_get_size_hw_s_info_one(const struct net_device *dev, + enum netdev_offload_xstats_type type) +{ + return nla_total_size(0) + + /* IFLA_OFFLOAD_XSTATS_HW_S_INFO_REQUEST */ + nla_total_size(sizeof(u8)) + + /* IFLA_OFFLOAD_XSTATS_HW_S_INFO_USED */ + nla_total_size(sizeof(u8)) + + 0; +} + +static unsigned int +rtnl_offload_xstats_get_size_hw_s_info(const struct net_device *dev) +{ + enum netdev_offload_xstats_type t_l3 = NETDEV_OFFLOAD_XSTATS_TYPE_L3; + + return nla_total_size(0) + + /* IFLA_OFFLOAD_XSTATS_L3_STATS */ + rtnl_offload_xstats_get_size_hw_s_info_one(dev, t_l3) + + 0; +} + +static int rtnl_offload_xstats_get_size(const struct net_device *dev, + u32 off_filter_mask) +{ + enum netdev_offload_xstats_type t_l3 = NETDEV_OFFLOAD_XSTATS_TYPE_L3; + int attr_id_cpu_hit = IFLA_OFFLOAD_XSTATS_CPU_HIT; + int nla_size = 0; + int size; + + if (off_filter_mask & + IFLA_STATS_FILTER_BIT(attr_id_cpu_hit)) { + size = rtnl_offload_xstats_get_size_ndo(dev, attr_id_cpu_hit); + nla_size += nla_total_size_64bit(size); + } + + if (off_filter_mask & + IFLA_STATS_FILTER_BIT(IFLA_OFFLOAD_XSTATS_HW_S_INFO)) + nla_size += rtnl_offload_xstats_get_size_hw_s_info(dev); + + if (off_filter_mask & + IFLA_STATS_FILTER_BIT(IFLA_OFFLOAD_XSTATS_L3_STATS)) { + size = rtnl_offload_xstats_get_size_stats(dev, t_l3); + nla_size += nla_total_size_64bit(size); + } + + if (nla_size != 0) + nla_size += nla_total_size(0); + + return nla_size; +} + +struct rtnl_stats_dump_filters { + /* mask[0] filters outer attributes. Then individual nests have their + * filtering mask at the index of the nested attribute. + */ + u32 mask[IFLA_STATS_MAX + 1]; +}; + +static int rtnl_fill_statsinfo(struct sk_buff *skb, struct net_device *dev, + int type, u32 pid, u32 seq, u32 change, + unsigned int flags, + const struct rtnl_stats_dump_filters *filters, + int *idxattr, int *prividx, + struct netlink_ext_ack *extack) +{ + unsigned int filter_mask = filters->mask[0]; + struct if_stats_msg *ifsm; + struct nlmsghdr *nlh; + struct nlattr *attr; + int s_prividx = *prividx; + int err; + + ASSERT_RTNL(); + + nlh = nlmsg_put(skb, pid, seq, type, sizeof(*ifsm), flags); + if (!nlh) + return -EMSGSIZE; + + ifsm = nlmsg_data(nlh); + ifsm->family = PF_UNSPEC; + ifsm->pad1 = 0; + ifsm->pad2 = 0; + ifsm->ifindex = dev->ifindex; + ifsm->filter_mask = filter_mask; + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_64, *idxattr)) { + struct rtnl_link_stats64 *sp; + + attr = nla_reserve_64bit(skb, IFLA_STATS_LINK_64, + sizeof(struct rtnl_link_stats64), + IFLA_STATS_UNSPEC); + if (!attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + sp = nla_data(attr); + dev_get_stats(dev, sp); + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_XSTATS, *idxattr)) { + const struct rtnl_link_ops *ops = dev->rtnl_link_ops; + + if (ops && ops->fill_linkxstats) { + *idxattr = IFLA_STATS_LINK_XSTATS; + attr = nla_nest_start_noflag(skb, + IFLA_STATS_LINK_XSTATS); + if (!attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = ops->fill_linkxstats(skb, dev, prividx, *idxattr); + nla_nest_end(skb, attr); + if (err) + goto nla_put_failure; + *idxattr = 0; + } + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_XSTATS_SLAVE, + *idxattr)) { + const struct rtnl_link_ops *ops = NULL; + const struct net_device *master; + + master = netdev_master_upper_dev_get(dev); + if (master) + ops = master->rtnl_link_ops; + if (ops && ops->fill_linkxstats) { + *idxattr = IFLA_STATS_LINK_XSTATS_SLAVE; + attr = 
nla_nest_start_noflag(skb, + IFLA_STATS_LINK_XSTATS_SLAVE); + if (!attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = ops->fill_linkxstats(skb, dev, prividx, *idxattr); + nla_nest_end(skb, attr); + if (err) + goto nla_put_failure; + *idxattr = 0; + } + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_OFFLOAD_XSTATS, + *idxattr)) { + u32 off_filter_mask; + + off_filter_mask = filters->mask[IFLA_STATS_LINK_OFFLOAD_XSTATS]; + *idxattr = IFLA_STATS_LINK_OFFLOAD_XSTATS; + attr = nla_nest_start_noflag(skb, + IFLA_STATS_LINK_OFFLOAD_XSTATS); + if (!attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = rtnl_offload_xstats_fill(skb, dev, prividx, + off_filter_mask, extack); + if (err == -ENODATA) + nla_nest_cancel(skb, attr); + else + nla_nest_end(skb, attr); + + if (err && err != -ENODATA) + goto nla_put_failure; + *idxattr = 0; + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_AF_SPEC, *idxattr)) { + struct rtnl_af_ops *af_ops; + + *idxattr = IFLA_STATS_AF_SPEC; + attr = nla_nest_start_noflag(skb, IFLA_STATS_AF_SPEC); + if (!attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + rcu_read_lock(); + list_for_each_entry_rcu(af_ops, &rtnl_af_ops, list) { + if (af_ops->fill_stats_af) { + struct nlattr *af; + + af = nla_nest_start_noflag(skb, + af_ops->family); + if (!af) { + rcu_read_unlock(); + err = -EMSGSIZE; + goto nla_put_failure; + } + err = af_ops->fill_stats_af(skb, dev); + + if (err == -ENODATA) { + nla_nest_cancel(skb, af); + } else if (err < 0) { + rcu_read_unlock(); + goto nla_put_failure; + } + + nla_nest_end(skb, af); + } + } + rcu_read_unlock(); + + nla_nest_end(skb, attr); + + *idxattr = 0; + } + + nlmsg_end(skb, nlh); + + return 0; + +nla_put_failure: + /* not a multi message or no progress mean a real error */ + if (!(flags & NLM_F_MULTI) || s_prividx == *prividx) + nlmsg_cancel(skb, nlh); + else + nlmsg_end(skb, nlh); + + return err; +} + +static size_t if_nlmsg_stats_size(const struct net_device *dev, + const struct rtnl_stats_dump_filters *filters) +{ + size_t size = NLMSG_ALIGN(sizeof(struct if_stats_msg)); + unsigned int filter_mask = filters->mask[0]; + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_64, 0)) + size += nla_total_size_64bit(sizeof(struct rtnl_link_stats64)); + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_XSTATS, 0)) { + const struct rtnl_link_ops *ops = dev->rtnl_link_ops; + int attr = IFLA_STATS_LINK_XSTATS; + + if (ops && ops->get_linkxstats_size) { + size += nla_total_size(ops->get_linkxstats_size(dev, + attr)); + /* for IFLA_STATS_LINK_XSTATS */ + size += nla_total_size(0); + } + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_XSTATS_SLAVE, 0)) { + struct net_device *_dev = (struct net_device *)dev; + const struct rtnl_link_ops *ops = NULL; + const struct net_device *master; + + /* netdev_master_upper_dev_get can't take const */ + master = netdev_master_upper_dev_get(_dev); + if (master) + ops = master->rtnl_link_ops; + if (ops && ops->get_linkxstats_size) { + int attr = IFLA_STATS_LINK_XSTATS_SLAVE; + + size += nla_total_size(ops->get_linkxstats_size(dev, + attr)); + /* for IFLA_STATS_LINK_XSTATS_SLAVE */ + size += nla_total_size(0); + } + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_LINK_OFFLOAD_XSTATS, 0)) { + u32 off_filter_mask; + + off_filter_mask = filters->mask[IFLA_STATS_LINK_OFFLOAD_XSTATS]; + size += rtnl_offload_xstats_get_size(dev, off_filter_mask); + } + + if (stats_attr_valid(filter_mask, IFLA_STATS_AF_SPEC, 0)) { + struct rtnl_af_ops *af_ops; + + /* for IFLA_STATS_AF_SPEC */ + 
size += nla_total_size(0); + + rcu_read_lock(); + list_for_each_entry_rcu(af_ops, &rtnl_af_ops, list) { + if (af_ops->get_stats_af_size) { + size += nla_total_size( + af_ops->get_stats_af_size(dev)); + + /* for AF_* */ + size += nla_total_size(0); + } + } + rcu_read_unlock(); + } + + return size; +} + +#define RTNL_STATS_OFFLOAD_XSTATS_VALID ((1 << __IFLA_OFFLOAD_XSTATS_MAX) - 1) + +static const struct nla_policy +rtnl_stats_get_policy_filters[IFLA_STATS_MAX + 1] = { + [IFLA_STATS_LINK_OFFLOAD_XSTATS] = + NLA_POLICY_MASK(NLA_U32, RTNL_STATS_OFFLOAD_XSTATS_VALID), +}; + +static const struct nla_policy +rtnl_stats_get_policy[IFLA_STATS_GETSET_MAX + 1] = { + [IFLA_STATS_GET_FILTERS] = + NLA_POLICY_NESTED(rtnl_stats_get_policy_filters), +}; + +static const struct nla_policy +ifla_stats_set_policy[IFLA_STATS_GETSET_MAX + 1] = { + [IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS] = NLA_POLICY_MAX(NLA_U8, 1), +}; + +static int rtnl_stats_get_parse_filters(struct nlattr *ifla_filters, + struct rtnl_stats_dump_filters *filters, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_STATS_MAX + 1]; + int err; + int at; + + err = nla_parse_nested(tb, IFLA_STATS_MAX, ifla_filters, + rtnl_stats_get_policy_filters, extack); + if (err < 0) + return err; + + for (at = 1; at <= IFLA_STATS_MAX; at++) { + if (tb[at]) { + if (!(filters->mask[0] & IFLA_STATS_FILTER_BIT(at))) { + NL_SET_ERR_MSG(extack, "Filtered attribute not enabled in filter_mask"); + return -EINVAL; + } + filters->mask[at] = nla_get_u32(tb[at]); + } + } + + return 0; +} + +static int rtnl_stats_get_parse(const struct nlmsghdr *nlh, + u32 filter_mask, + struct rtnl_stats_dump_filters *filters, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_STATS_GETSET_MAX + 1]; + int err; + int i; + + filters->mask[0] = filter_mask; + for (i = 1; i < ARRAY_SIZE(filters->mask); i++) + filters->mask[i] = -1U; + + err = nlmsg_parse(nlh, sizeof(struct if_stats_msg), tb, + IFLA_STATS_GETSET_MAX, rtnl_stats_get_policy, extack); + if (err < 0) + return err; + + if (tb[IFLA_STATS_GET_FILTERS]) { + err = rtnl_stats_get_parse_filters(tb[IFLA_STATS_GET_FILTERS], + filters, extack); + if (err) + return err; + } + + return 0; +} + +static int rtnl_valid_stats_req(const struct nlmsghdr *nlh, bool strict_check, + bool is_dump, struct netlink_ext_ack *extack) +{ + struct if_stats_msg *ifsm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifsm))) { + NL_SET_ERR_MSG(extack, "Invalid header for stats dump"); + return -EINVAL; + } + + if (!strict_check) + return 0; + + ifsm = nlmsg_data(nlh); + + /* only requests using strict checks can pass data to influence + * the dump. The legacy exception is filter_mask. 
+ */ + if (ifsm->pad1 || ifsm->pad2 || (is_dump && ifsm->ifindex)) { + NL_SET_ERR_MSG(extack, "Invalid values in header for stats dump request"); + return -EINVAL; + } + if (ifsm->filter_mask >= IFLA_STATS_FILTER_BIT(IFLA_STATS_MAX + 1)) { + NL_SET_ERR_MSG(extack, "Invalid stats requested through filter mask"); + return -EINVAL; + } + + return 0; +} + +static int rtnl_stats_get(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct rtnl_stats_dump_filters filters; + struct net *net = sock_net(skb->sk); + struct net_device *dev = NULL; + int idxattr = 0, prividx = 0; + struct if_stats_msg *ifsm; + struct sk_buff *nskb; + int err; + + err = rtnl_valid_stats_req(nlh, netlink_strict_get_check(skb), + false, extack); + if (err) + return err; + + ifsm = nlmsg_data(nlh); + if (ifsm->ifindex > 0) + dev = __dev_get_by_index(net, ifsm->ifindex); + else + return -EINVAL; + + if (!dev) + return -ENODEV; + + if (!ifsm->filter_mask) { + NL_SET_ERR_MSG(extack, "Filter mask must be set for stats get"); + return -EINVAL; + } + + err = rtnl_stats_get_parse(nlh, ifsm->filter_mask, &filters, extack); + if (err) + return err; + + nskb = nlmsg_new(if_nlmsg_stats_size(dev, &filters), GFP_KERNEL); + if (!nskb) + return -ENOBUFS; + + err = rtnl_fill_statsinfo(nskb, dev, RTM_NEWSTATS, + NETLINK_CB(skb).portid, nlh->nlmsg_seq, 0, + 0, &filters, &idxattr, &prividx, extack); + if (err < 0) { + /* -EMSGSIZE implies BUG in if_nlmsg_stats_size */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(nskb); + } else { + err = rtnl_unicast(nskb, net, NETLINK_CB(skb).portid); + } + + return err; +} + +static int rtnl_stats_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + int h, s_h, err, s_idx, s_idxattr, s_prividx; + struct rtnl_stats_dump_filters filters; + struct net *net = sock_net(skb->sk); + unsigned int flags = NLM_F_MULTI; + struct if_stats_msg *ifsm; + struct hlist_head *head; + struct net_device *dev; + int idx = 0; + + s_h = cb->args[0]; + s_idx = cb->args[1]; + s_idxattr = cb->args[2]; + s_prividx = cb->args[3]; + + cb->seq = net->dev_base_seq; + + err = rtnl_valid_stats_req(cb->nlh, cb->strict_check, true, extack); + if (err) + return err; + + ifsm = nlmsg_data(cb->nlh); + if (!ifsm->filter_mask) { + NL_SET_ERR_MSG(extack, "Filter mask must be set for stats dump"); + return -EINVAL; + } + + err = rtnl_stats_get_parse(cb->nlh, ifsm->filter_mask, &filters, + extack); + if (err) + return err; + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + hlist_for_each_entry(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + err = rtnl_fill_statsinfo(skb, dev, RTM_NEWSTATS, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, 0, + flags, &filters, + &s_idxattr, &s_prividx, + extack); + /* If we ran out of room on the first message, + * we're in trouble + */ + WARN_ON((err == -EMSGSIZE) && (skb->len == 0)); + + if (err < 0) + goto out; + s_prividx = 0; + s_idxattr = 0; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + } +out: + cb->args[3] = s_prividx; + cb->args[2] = s_idxattr; + cb->args[1] = idx; + cb->args[0] = h; + + return skb->len; +} + +void rtnl_offload_xstats_notify(struct net_device *dev) +{ + struct rtnl_stats_dump_filters response_filters = {}; + struct net *net = dev_net(dev); + int idxattr = 0, prividx = 0; + struct sk_buff *skb; + int err = -ENOBUFS; + + ASSERT_RTNL(); + + response_filters.mask[0] |= + 
IFLA_STATS_FILTER_BIT(IFLA_STATS_LINK_OFFLOAD_XSTATS); + response_filters.mask[IFLA_STATS_LINK_OFFLOAD_XSTATS] |= + IFLA_STATS_FILTER_BIT(IFLA_OFFLOAD_XSTATS_HW_S_INFO); + + skb = nlmsg_new(if_nlmsg_stats_size(dev, &response_filters), + GFP_KERNEL); + if (!skb) + goto errout; + + err = rtnl_fill_statsinfo(skb, dev, RTM_NEWSTATS, 0, 0, 0, 0, + &response_filters, &idxattr, &prividx, NULL); + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_STATS, NULL, GFP_KERNEL); + return; + +errout: + rtnl_set_sk_err(net, RTNLGRP_STATS, err); +} +EXPORT_SYMBOL(rtnl_offload_xstats_notify); + +static int rtnl_stats_set(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + enum netdev_offload_xstats_type t_l3 = NETDEV_OFFLOAD_XSTATS_TYPE_L3; + struct rtnl_stats_dump_filters response_filters = {}; + struct nlattr *tb[IFLA_STATS_GETSET_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct net_device *dev = NULL; + struct if_stats_msg *ifsm; + bool notify = false; + int err; + + err = rtnl_valid_stats_req(nlh, netlink_strict_get_check(skb), + false, extack); + if (err) + return err; + + ifsm = nlmsg_data(nlh); + if (ifsm->family != AF_UNSPEC) { + NL_SET_ERR_MSG(extack, "Address family should be AF_UNSPEC"); + return -EINVAL; + } + + if (ifsm->ifindex > 0) + dev = __dev_get_by_index(net, ifsm->ifindex); + else + return -EINVAL; + + if (!dev) + return -ENODEV; + + if (ifsm->filter_mask) { + NL_SET_ERR_MSG(extack, "Filter mask must be 0 for stats set"); + return -EINVAL; + } + + err = nlmsg_parse(nlh, sizeof(*ifsm), tb, IFLA_STATS_GETSET_MAX, + ifla_stats_set_policy, extack); + if (err < 0) + return err; + + if (tb[IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS]) { + u8 req = nla_get_u8(tb[IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS]); + + if (req) + err = netdev_offload_xstats_enable(dev, t_l3, extack); + else + err = netdev_offload_xstats_disable(dev, t_l3); + + if (!err) + notify = true; + else if (err != -EALREADY) + return err; + + response_filters.mask[0] |= + IFLA_STATS_FILTER_BIT(IFLA_STATS_LINK_OFFLOAD_XSTATS); + response_filters.mask[IFLA_STATS_LINK_OFFLOAD_XSTATS] |= + IFLA_STATS_FILTER_BIT(IFLA_OFFLOAD_XSTATS_HW_S_INFO); + } + + if (notify) + rtnl_offload_xstats_notify(dev); + + return 0; +} + +/* Process one rtnetlink message. 
*/ + +static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct rtnl_link *link; + enum rtnl_kinds kind; + struct module *owner; + int err = -EOPNOTSUPP; + rtnl_doit_func doit; + unsigned int flags; + int family; + int type; + + type = nlh->nlmsg_type; + if (type > RTM_MAX) + return -EOPNOTSUPP; + + type -= RTM_BASE; + + /* All the messages must have at least 1 byte length */ + if (nlmsg_len(nlh) < sizeof(struct rtgenmsg)) + return 0; + + family = ((struct rtgenmsg *)nlmsg_data(nlh))->rtgen_family; + kind = rtnl_msgtype_kind(type); + + if (kind != RTNL_KIND_GET && !netlink_net_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + rcu_read_lock(); + if (kind == RTNL_KIND_GET && (nlh->nlmsg_flags & NLM_F_DUMP)) { + struct sock *rtnl; + rtnl_dumpit_func dumpit; + u32 min_dump_alloc = 0; + + link = rtnl_get_link(family, type); + if (!link || !link->dumpit) { + family = PF_UNSPEC; + link = rtnl_get_link(family, type); + if (!link || !link->dumpit) + goto err_unlock; + } + owner = link->owner; + dumpit = link->dumpit; + + if (type == RTM_GETLINK - RTM_BASE) + min_dump_alloc = rtnl_calcit(skb, nlh); + + err = 0; + /* need to do this before rcu_read_unlock() */ + if (!try_module_get(owner)) + err = -EPROTONOSUPPORT; + + rcu_read_unlock(); + + rtnl = net->rtnl; + if (err == 0) { + struct netlink_dump_control c = { + .dump = dumpit, + .min_dump_alloc = min_dump_alloc, + .module = owner, + }; + err = netlink_dump_start(rtnl, skb, nlh, &c); + /* netlink_dump_start() will keep a reference on + * module if dump is still in progress. + */ + module_put(owner); + } + return err; + } + + link = rtnl_get_link(family, type); + if (!link || !link->doit) { + family = PF_UNSPEC; + link = rtnl_get_link(PF_UNSPEC, type); + if (!link || !link->doit) + goto out_unlock; + } + + owner = link->owner; + if (!try_module_get(owner)) { + err = -EPROTONOSUPPORT; + goto out_unlock; + } + + flags = link->flags; + if (kind == RTNL_KIND_DEL && (nlh->nlmsg_flags & NLM_F_BULK) && + !(flags & RTNL_FLAG_BULK_DEL_SUPPORTED)) { + NL_SET_ERR_MSG(extack, "Bulk delete is not supported"); + module_put(owner); + goto err_unlock; + } + + if (flags & RTNL_FLAG_DOIT_UNLOCKED) { + doit = link->doit; + rcu_read_unlock(); + if (doit) + err = doit(skb, nlh, extack); + module_put(owner); + return err; + } + rcu_read_unlock(); + + rtnl_lock(); + link = rtnl_get_link(family, type); + if (link && link->doit) + err = link->doit(skb, nlh, extack); + rtnl_unlock(); + + module_put(owner); + + return err; + +out_unlock: + rcu_read_unlock(); + return err; + +err_unlock: + rcu_read_unlock(); + return -EOPNOTSUPP; +} + +static void rtnetlink_rcv(struct sk_buff *skb) +{ + netlink_rcv_skb(skb, &rtnetlink_rcv_msg); +} + +static int rtnetlink_bind(struct net *net, int group) +{ + switch (group) { + case RTNLGRP_IPV4_MROUTE_R: + case RTNLGRP_IPV6_MROUTE_R: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + break; + } + return 0; +} + +static int rtnetlink_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + switch (event) { + case NETDEV_REBOOT: + case NETDEV_CHANGEMTU: + case NETDEV_CHANGEADDR: + case NETDEV_CHANGENAME: + case NETDEV_FEAT_CHANGE: + case NETDEV_BONDING_FAILOVER: + case NETDEV_POST_TYPE_CHANGE: + case NETDEV_NOTIFY_PEERS: + case NETDEV_CHANGEUPPER: + case NETDEV_RESEND_IGMP: + case NETDEV_CHANGEINFODATA: + case NETDEV_CHANGELOWERSTATE: + case 
NETDEV_CHANGE_TX_QUEUE_LEN: + rtmsg_ifinfo_event(RTM_NEWLINK, dev, 0, rtnl_get_event(event), + GFP_KERNEL, NULL, 0); + break; + default: + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block rtnetlink_dev_notifier = { + .notifier_call = rtnetlink_event, +}; + + +static int __net_init rtnetlink_net_init(struct net *net) +{ + struct sock *sk; + struct netlink_kernel_cfg cfg = { + .groups = RTNLGRP_MAX, + .input = rtnetlink_rcv, + .cb_mutex = &rtnl_mutex, + .flags = NL_CFG_F_NONROOT_RECV, + .bind = rtnetlink_bind, + }; + + sk = netlink_kernel_create(net, NETLINK_ROUTE, &cfg); + if (!sk) + return -ENOMEM; + net->rtnl = sk; + return 0; +} + +static void __net_exit rtnetlink_net_exit(struct net *net) +{ + netlink_kernel_release(net->rtnl); + net->rtnl = NULL; +} + +static struct pernet_operations rtnetlink_net_ops = { + .init = rtnetlink_net_init, + .exit = rtnetlink_net_exit, +}; + +void __init rtnetlink_init(void) +{ + if (register_pernet_subsys(&rtnetlink_net_ops)) + panic("rtnetlink_init: cannot initialize rtnetlink\n"); + + register_netdevice_notifier(&rtnetlink_dev_notifier); + + rtnl_register(PF_UNSPEC, RTM_GETLINK, rtnl_getlink, + rtnl_dump_ifinfo, 0); + rtnl_register(PF_UNSPEC, RTM_SETLINK, rtnl_setlink, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_NEWLINK, rtnl_newlink, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELLINK, rtnl_dellink, NULL, 0); + + rtnl_register(PF_UNSPEC, RTM_GETADDR, NULL, rtnl_dump_all, 0); + rtnl_register(PF_UNSPEC, RTM_GETROUTE, NULL, rtnl_dump_all, 0); + rtnl_register(PF_UNSPEC, RTM_GETNETCONF, NULL, rtnl_dump_all, 0); + + rtnl_register(PF_UNSPEC, RTM_NEWLINKPROP, rtnl_newlinkprop, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELLINKPROP, rtnl_dellinkprop, NULL, 0); + + rtnl_register(PF_BRIDGE, RTM_NEWNEIGH, rtnl_fdb_add, NULL, 0); + rtnl_register(PF_BRIDGE, RTM_DELNEIGH, rtnl_fdb_del, NULL, + RTNL_FLAG_BULK_DEL_SUPPORTED); + rtnl_register(PF_BRIDGE, RTM_GETNEIGH, rtnl_fdb_get, rtnl_fdb_dump, 0); + + rtnl_register(PF_BRIDGE, RTM_GETLINK, NULL, rtnl_bridge_getlink, 0); + rtnl_register(PF_BRIDGE, RTM_DELLINK, rtnl_bridge_dellink, NULL, 0); + rtnl_register(PF_BRIDGE, RTM_SETLINK, rtnl_bridge_setlink, NULL, 0); + + rtnl_register(PF_UNSPEC, RTM_GETSTATS, rtnl_stats_get, rtnl_stats_dump, + 0); + rtnl_register(PF_UNSPEC, RTM_SETSTATS, rtnl_stats_set, NULL, 0); +} diff --git a/net/core/scm.c b/net/core/scm.c new file mode 100644 index 000000000..e762a4b8a --- /dev/null +++ b/net/core/scm.c @@ -0,0 +1,375 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* scm.c - Socket level control messages processing. + * + * Author: Alexey Kuznetsov, + * Alignment and value checking mods by Craig Metz + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + + +/* + * Only allow a user to send credentials, that they could set with + * setu(g)id. 
+ */ + +static __inline__ int scm_check_creds(struct ucred *creds) +{ + const struct cred *cred = current_cred(); + kuid_t uid = make_kuid(cred->user_ns, creds->uid); + kgid_t gid = make_kgid(cred->user_ns, creds->gid); + + if (!uid_valid(uid) || !gid_valid(gid)) + return -EINVAL; + + if ((creds->pid == task_tgid_vnr(current) || + ns_capable(task_active_pid_ns(current)->user_ns, CAP_SYS_ADMIN)) && + ((uid_eq(uid, cred->uid) || uid_eq(uid, cred->euid) || + uid_eq(uid, cred->suid)) || ns_capable(cred->user_ns, CAP_SETUID)) && + ((gid_eq(gid, cred->gid) || gid_eq(gid, cred->egid) || + gid_eq(gid, cred->sgid)) || ns_capable(cred->user_ns, CAP_SETGID))) { + return 0; + } + return -EPERM; +} + +static int scm_fp_copy(struct cmsghdr *cmsg, struct scm_fp_list **fplp) +{ + int *fdp = (int*)CMSG_DATA(cmsg); + struct scm_fp_list *fpl = *fplp; + struct file **fpp; + int i, num; + + num = (cmsg->cmsg_len - sizeof(struct cmsghdr))/sizeof(int); + + if (num <= 0) + return 0; + + if (num > SCM_MAX_FD) + return -EINVAL; + + if (!fpl) + { + fpl = kmalloc(sizeof(struct scm_fp_list), GFP_KERNEL_ACCOUNT); + if (!fpl) + return -ENOMEM; + *fplp = fpl; + fpl->count = 0; + fpl->max = SCM_MAX_FD; + fpl->user = NULL; + } + fpp = &fpl->fp[fpl->count]; + + if (fpl->count + num > fpl->max) + return -EINVAL; + + /* + * Verify the descriptors and increment the usage count. + */ + + for (i=0; i< num; i++) + { + int fd = fdp[i]; + struct file *file; + + if (fd < 0 || !(file = fget_raw(fd))) + return -EBADF; + /* don't allow io_uring files */ + if (io_uring_get_socket(file)) { + fput(file); + return -EINVAL; + } + *fpp++ = file; + fpl->count++; + } + + if (!fpl->user) + fpl->user = get_uid(current_user()); + + return num; +} + +void __scm_destroy(struct scm_cookie *scm) +{ + struct scm_fp_list *fpl = scm->fp; + int i; + + if (fpl) { + scm->fp = NULL; + for (i=fpl->count-1; i>=0; i--) + fput(fpl->fp[i]); + free_uid(fpl->user); + kfree(fpl); + } +} +EXPORT_SYMBOL(__scm_destroy); + +int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) +{ + struct cmsghdr *cmsg; + int err; + + for_each_cmsghdr(cmsg, msg) { + err = -EINVAL; + + /* Verify that cmsg_len is at least sizeof(struct cmsghdr) */ + /* The first check was omitted in <= 2.2.5. The reasoning was + that parser checks cmsg_len in any case, so that + additional check would be work duplication. + But if cmsg_level is not SOL_SOCKET, we do not check + for too short ancillary data object at all! Oops. + OK, let's add it... 
+ */ + if (!CMSG_OK(msg, cmsg)) + goto error; + + if (cmsg->cmsg_level != SOL_SOCKET) + continue; + + switch (cmsg->cmsg_type) + { + case SCM_RIGHTS: + if (!sock->ops || sock->ops->family != PF_UNIX) + goto error; + err=scm_fp_copy(cmsg, &p->fp); + if (err<0) + goto error; + break; + case SCM_CREDENTIALS: + { + struct ucred creds; + kuid_t uid; + kgid_t gid; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct ucred))) + goto error; + memcpy(&creds, CMSG_DATA(cmsg), sizeof(struct ucred)); + err = scm_check_creds(&creds); + if (err) + goto error; + + p->creds.pid = creds.pid; + if (!p->pid || pid_vnr(p->pid) != creds.pid) { + struct pid *pid; + err = -ESRCH; + pid = find_get_pid(creds.pid); + if (!pid) + goto error; + put_pid(p->pid); + p->pid = pid; + } + + err = -EINVAL; + uid = make_kuid(current_user_ns(), creds.uid); + gid = make_kgid(current_user_ns(), creds.gid); + if (!uid_valid(uid) || !gid_valid(gid)) + goto error; + + p->creds.uid = uid; + p->creds.gid = gid; + break; + } + default: + goto error; + } + } + + if (p->fp && !p->fp->count) + { + kfree(p->fp); + p->fp = NULL; + } + return 0; + +error: + scm_destroy(p); + return err; +} +EXPORT_SYMBOL(__scm_send); + +int put_cmsg(struct msghdr * msg, int level, int type, int len, void *data) +{ + int cmlen = CMSG_LEN(len); + + if (msg->msg_flags & MSG_CMSG_COMPAT) + return put_cmsg_compat(msg, level, type, len, data); + + if (!msg->msg_control || msg->msg_controllen < sizeof(struct cmsghdr)) { + msg->msg_flags |= MSG_CTRUNC; + return 0; /* XXX: return error? check spec. */ + } + if (msg->msg_controllen < cmlen) { + msg->msg_flags |= MSG_CTRUNC; + cmlen = msg->msg_controllen; + } + + if (msg->msg_control_is_user) { + struct cmsghdr __user *cm = msg->msg_control_user; + + check_object_size(data, cmlen - sizeof(*cm), true); + + if (!user_write_access_begin(cm, cmlen)) + goto efault; + + unsafe_put_user(cmlen, &cm->cmsg_len, efault_end); + unsafe_put_user(level, &cm->cmsg_level, efault_end); + unsafe_put_user(type, &cm->cmsg_type, efault_end); + unsafe_copy_to_user(CMSG_USER_DATA(cm), data, + cmlen - sizeof(*cm), efault_end); + user_write_access_end(); + } else { + struct cmsghdr *cm = msg->msg_control; + + cm->cmsg_level = level; + cm->cmsg_type = type; + cm->cmsg_len = cmlen; + memcpy(CMSG_DATA(cm), data, cmlen - sizeof(*cm)); + } + + cmlen = min(CMSG_SPACE(len), msg->msg_controllen); + msg->msg_control += cmlen; + msg->msg_controllen -= cmlen; + return 0; + +efault_end: + user_write_access_end(); +efault: + return -EFAULT; +} +EXPORT_SYMBOL(put_cmsg); + +void put_cmsg_scm_timestamping64(struct msghdr *msg, struct scm_timestamping_internal *tss_internal) +{ + struct scm_timestamping64 tss; + int i; + + for (i = 0; i < ARRAY_SIZE(tss.ts); i++) { + tss.ts[i].tv_sec = tss_internal->ts[i].tv_sec; + tss.ts[i].tv_nsec = tss_internal->ts[i].tv_nsec; + } + + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPING_NEW, sizeof(tss), &tss); +} +EXPORT_SYMBOL(put_cmsg_scm_timestamping64); + +void put_cmsg_scm_timestamping(struct msghdr *msg, struct scm_timestamping_internal *tss_internal) +{ + struct scm_timestamping tss; + int i; + + for (i = 0; i < ARRAY_SIZE(tss.ts); i++) { + tss.ts[i].tv_sec = tss_internal->ts[i].tv_sec; + tss.ts[i].tv_nsec = tss_internal->ts[i].tv_nsec; + } + + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPING_OLD, sizeof(tss), &tss); +} +EXPORT_SYMBOL(put_cmsg_scm_timestamping); + +static int scm_max_fds(struct msghdr *msg) +{ + if (msg->msg_controllen <= sizeof(struct cmsghdr)) + return 0; + return (msg->msg_controllen - sizeof(struct cmsghdr)) / 
sizeof(int); +} + +void scm_detach_fds(struct msghdr *msg, struct scm_cookie *scm) +{ + struct cmsghdr __user *cm = + (__force struct cmsghdr __user *)msg->msg_control; + unsigned int o_flags = (msg->msg_flags & MSG_CMSG_CLOEXEC) ? O_CLOEXEC : 0; + int fdmax = min_t(int, scm_max_fds(msg), scm->fp->count); + int __user *cmsg_data = CMSG_USER_DATA(cm); + int err = 0, i; + + /* no use for FD passing from kernel space callers */ + if (WARN_ON_ONCE(!msg->msg_control_is_user)) + return; + + if (msg->msg_flags & MSG_CMSG_COMPAT) { + scm_detach_fds_compat(msg, scm); + return; + } + + for (i = 0; i < fdmax; i++) { + err = receive_fd_user(scm->fp->fp[i], cmsg_data + i, o_flags); + if (err < 0) + break; + } + + if (i > 0) { + int cmlen = CMSG_LEN(i * sizeof(int)); + + err = put_user(SOL_SOCKET, &cm->cmsg_level); + if (!err) + err = put_user(SCM_RIGHTS, &cm->cmsg_type); + if (!err) + err = put_user(cmlen, &cm->cmsg_len); + if (!err) { + cmlen = CMSG_SPACE(i * sizeof(int)); + if (msg->msg_controllen < cmlen) + cmlen = msg->msg_controllen; + msg->msg_control += cmlen; + msg->msg_controllen -= cmlen; + } + } + + if (i < scm->fp->count || (scm->fp->count && fdmax <= 0)) + msg->msg_flags |= MSG_CTRUNC; + + /* + * All of the files that fit in the message have had their usage counts + * incremented, so we just free the list. + */ + __scm_destroy(scm); +} +EXPORT_SYMBOL(scm_detach_fds); + +struct scm_fp_list *scm_fp_dup(struct scm_fp_list *fpl) +{ + struct scm_fp_list *new_fpl; + int i; + + if (!fpl) + return NULL; + + new_fpl = kmemdup(fpl, offsetof(struct scm_fp_list, fp[fpl->count]), + GFP_KERNEL_ACCOUNT); + if (new_fpl) { + for (i = 0; i < fpl->count; i++) + get_file(fpl->fp[i]); + new_fpl->max = new_fpl->count; + new_fpl->user = get_uid(fpl->user); + } + return new_fpl; +} +EXPORT_SYMBOL(scm_fp_dup); diff --git a/net/core/secure_seq.c b/net/core/secure_seq.c new file mode 100644 index 000000000..b0ff6153b --- /dev/null +++ b/net/core/secure_seq.c @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2016 Jason A. Donenfeld . All Rights Reserved. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) || IS_ENABLED(CONFIG_INET) +#include +#include + +static siphash_aligned_key_t net_secret; +static siphash_aligned_key_t ts_secret; + +#define EPHEMERAL_PORT_SHUFFLE_PERIOD (10 * HZ) + +static __always_inline void net_secret_init(void) +{ + net_get_random_once(&net_secret, sizeof(net_secret)); +} + +static __always_inline void ts_secret_init(void) +{ + net_get_random_once(&ts_secret, sizeof(ts_secret)); +} +#endif + +#ifdef CONFIG_INET +static u32 seq_scale(u32 seq) +{ + /* + * As close as possible to RFC 793, which + * suggests using a 250 kHz clock. + * Further reading shows this assumes 2 Mb/s networks. + * For 10 Mb/s Ethernet, a 1 MHz clock is appropriate. + * For 10 Gb/s Ethernet, a 1 GHz clock should be ok, but + * we also need to limit the resolution so that the u32 seq + * overlaps less than one time per MSL (2 minutes). + * Choosing a clock of 64 ns period is OK. 
(period of 274 s) + */ + return seq + (ktime_get_real_ns() >> 6); +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +u32 secure_tcpv6_ts_off(const struct net *net, + const __be32 *saddr, const __be32 *daddr) +{ + const struct { + struct in6_addr saddr; + struct in6_addr daddr; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .saddr = *(struct in6_addr *)saddr, + .daddr = *(struct in6_addr *)daddr, + }; + + if (READ_ONCE(net->ipv4.sysctl_tcp_timestamps) != 1) + return 0; + + ts_secret_init(); + return siphash(&combined, offsetofend(typeof(combined), daddr), + &ts_secret); +} +EXPORT_SYMBOL(secure_tcpv6_ts_off); + +u32 secure_tcpv6_seq(const __be32 *saddr, const __be32 *daddr, + __be16 sport, __be16 dport) +{ + const struct { + struct in6_addr saddr; + struct in6_addr daddr; + __be16 sport; + __be16 dport; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .saddr = *(struct in6_addr *)saddr, + .daddr = *(struct in6_addr *)daddr, + .sport = sport, + .dport = dport + }; + u32 hash; + + net_secret_init(); + hash = siphash(&combined, offsetofend(typeof(combined), dport), + &net_secret); + return seq_scale(hash); +} +EXPORT_SYMBOL(secure_tcpv6_seq); + +u64 secure_ipv6_port_ephemeral(const __be32 *saddr, const __be32 *daddr, + __be16 dport) +{ + const struct { + struct in6_addr saddr; + struct in6_addr daddr; + unsigned int timeseed; + __be16 dport; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .saddr = *(struct in6_addr *)saddr, + .daddr = *(struct in6_addr *)daddr, + .timeseed = jiffies / EPHEMERAL_PORT_SHUFFLE_PERIOD, + .dport = dport, + }; + net_secret_init(); + return siphash(&combined, offsetofend(typeof(combined), dport), + &net_secret); +} +EXPORT_SYMBOL(secure_ipv6_port_ephemeral); +#endif + +#ifdef CONFIG_INET +u32 secure_tcp_ts_off(const struct net *net, __be32 saddr, __be32 daddr) +{ + if (READ_ONCE(net->ipv4.sysctl_tcp_timestamps) != 1) + return 0; + + ts_secret_init(); + return siphash_2u32((__force u32)saddr, (__force u32)daddr, + &ts_secret); +} + +/* secure_tcp_seq_and_tsoff(a, b, 0, d) == secure_ipv4_port_ephemeral(a, b, d), + * but fortunately, `sport' cannot be 0 in any circumstances. If this changes, + * it would be easy enough to have the former function use siphash_4u32, passing + * the arguments as separate u32. 
+ */ +u32 secure_tcp_seq(__be32 saddr, __be32 daddr, + __be16 sport, __be16 dport) +{ + u32 hash; + + net_secret_init(); + hash = siphash_3u32((__force u32)saddr, (__force u32)daddr, + (__force u32)sport << 16 | (__force u32)dport, + &net_secret); + return seq_scale(hash); +} +EXPORT_SYMBOL_GPL(secure_tcp_seq); + +u64 secure_ipv4_port_ephemeral(__be32 saddr, __be32 daddr, __be16 dport) +{ + net_secret_init(); + return siphash_4u32((__force u32)saddr, (__force u32)daddr, + (__force u16)dport, + jiffies / EPHEMERAL_PORT_SHUFFLE_PERIOD, + &net_secret); +} +EXPORT_SYMBOL_GPL(secure_ipv4_port_ephemeral); +#endif + +#if IS_ENABLED(CONFIG_IP_DCCP) +u64 secure_dccp_sequence_number(__be32 saddr, __be32 daddr, + __be16 sport, __be16 dport) +{ + u64 seq; + net_secret_init(); + seq = siphash_3u32((__force u32)saddr, (__force u32)daddr, + (__force u32)sport << 16 | (__force u32)dport, + &net_secret); + seq += ktime_get_real_ns(); + seq &= (1ull << 48) - 1; + return seq; +} +EXPORT_SYMBOL(secure_dccp_sequence_number); + +#if IS_ENABLED(CONFIG_IPV6) +u64 secure_dccpv6_sequence_number(__be32 *saddr, __be32 *daddr, + __be16 sport, __be16 dport) +{ + const struct { + struct in6_addr saddr; + struct in6_addr daddr; + __be16 sport; + __be16 dport; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .saddr = *(struct in6_addr *)saddr, + .daddr = *(struct in6_addr *)daddr, + .sport = sport, + .dport = dport + }; + u64 seq; + net_secret_init(); + seq = siphash(&combined, offsetofend(typeof(combined), dport), + &net_secret); + seq += ktime_get_real_ns(); + seq &= (1ull << 48) - 1; + return seq; +} +EXPORT_SYMBOL(secure_dccpv6_sequence_number); +#endif +#endif diff --git a/net/core/selftests.c b/net/core/selftests.c new file mode 100644 index 000000000..acb1ee97b --- /dev/null +++ b/net/core/selftests.c @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2019 Synopsys, Inc. and/or its affiliates. 
+ * stmmac Selftests Support + * + * Author: Jose Abreu + * + * Ported from stmmac by: + * Copyright (C) 2021 Oleksij Rempel + */ + +#include +#include +#include +#include + +struct net_packet_attrs { + const unsigned char *src; + const unsigned char *dst; + u32 ip_src; + u32 ip_dst; + bool tcp; + u16 sport; + u16 dport; + int timeout; + int size; + int max_size; + u8 id; + u16 queue_mapping; +}; + +struct net_test_priv { + struct net_packet_attrs *packet; + struct packet_type pt; + struct completion comp; + int double_vlan; + int vlan_id; + int ok; +}; + +struct netsfhdr { + __be32 version; + __be64 magic; + u8 id; +} __packed; + +static u8 net_test_next_id; + +#define NET_TEST_PKT_SIZE (sizeof(struct ethhdr) + sizeof(struct iphdr) + \ + sizeof(struct netsfhdr)) +#define NET_TEST_PKT_MAGIC 0xdeadcafecafedeadULL +#define NET_LB_TIMEOUT msecs_to_jiffies(200) + +static struct sk_buff *net_test_get_skb(struct net_device *ndev, + struct net_packet_attrs *attr) +{ + struct sk_buff *skb = NULL; + struct udphdr *uhdr = NULL; + struct tcphdr *thdr = NULL; + struct netsfhdr *shdr; + struct ethhdr *ehdr; + struct iphdr *ihdr; + int iplen, size; + + size = attr->size + NET_TEST_PKT_SIZE; + + if (attr->tcp) + size += sizeof(struct tcphdr); + else + size += sizeof(struct udphdr); + + if (attr->max_size && attr->max_size > size) + size = attr->max_size; + + skb = netdev_alloc_skb(ndev, size); + if (!skb) + return NULL; + + prefetchw(skb->data); + + ehdr = skb_push(skb, ETH_HLEN); + skb_reset_mac_header(skb); + + skb_set_network_header(skb, skb->len); + ihdr = skb_put(skb, sizeof(*ihdr)); + + skb_set_transport_header(skb, skb->len); + if (attr->tcp) + thdr = skb_put(skb, sizeof(*thdr)); + else + uhdr = skb_put(skb, sizeof(*uhdr)); + + eth_zero_addr(ehdr->h_dest); + + if (attr->src) + ether_addr_copy(ehdr->h_source, attr->src); + if (attr->dst) + ether_addr_copy(ehdr->h_dest, attr->dst); + + ehdr->h_proto = htons(ETH_P_IP); + + if (attr->tcp) { + thdr->source = htons(attr->sport); + thdr->dest = htons(attr->dport); + thdr->doff = sizeof(struct tcphdr) / 4; + thdr->check = 0; + } else { + uhdr->source = htons(attr->sport); + uhdr->dest = htons(attr->dport); + uhdr->len = htons(sizeof(*shdr) + sizeof(*uhdr) + attr->size); + if (attr->max_size) + uhdr->len = htons(attr->max_size - + (sizeof(*ihdr) + sizeof(*ehdr))); + uhdr->check = 0; + } + + ihdr->ihl = 5; + ihdr->ttl = 32; + ihdr->version = 4; + if (attr->tcp) + ihdr->protocol = IPPROTO_TCP; + else + ihdr->protocol = IPPROTO_UDP; + iplen = sizeof(*ihdr) + sizeof(*shdr) + attr->size; + if (attr->tcp) + iplen += sizeof(*thdr); + else + iplen += sizeof(*uhdr); + + if (attr->max_size) + iplen = attr->max_size - sizeof(*ehdr); + + ihdr->tot_len = htons(iplen); + ihdr->frag_off = 0; + ihdr->saddr = htonl(attr->ip_src); + ihdr->daddr = htonl(attr->ip_dst); + ihdr->tos = 0; + ihdr->id = 0; + ip_send_check(ihdr); + + shdr = skb_put(skb, sizeof(*shdr)); + shdr->version = 0; + shdr->magic = cpu_to_be64(NET_TEST_PKT_MAGIC); + attr->id = net_test_next_id; + shdr->id = net_test_next_id++; + + if (attr->size) + skb_put(skb, attr->size); + if (attr->max_size && attr->max_size > skb->len) + skb_put(skb, attr->max_size - skb->len); + + skb->csum = 0; + skb->ip_summed = CHECKSUM_PARTIAL; + if (attr->tcp) { + thdr->check = ~tcp_v4_check(skb->len, ihdr->saddr, + ihdr->daddr, 0); + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); + } else { + udp4_hwcsum(skb, ihdr->saddr, ihdr->daddr); + } + + skb->protocol = 
htons(ETH_P_IP); + skb->pkt_type = PACKET_HOST; + skb->dev = ndev; + + return skb; +} + +static int net_test_loopback_validate(struct sk_buff *skb, + struct net_device *ndev, + struct packet_type *pt, + struct net_device *orig_ndev) +{ + struct net_test_priv *tpriv = pt->af_packet_priv; + const unsigned char *src = tpriv->packet->src; + const unsigned char *dst = tpriv->packet->dst; + struct netsfhdr *shdr; + struct ethhdr *ehdr; + struct udphdr *uhdr; + struct tcphdr *thdr; + struct iphdr *ihdr; + + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + goto out; + + if (skb_linearize(skb)) + goto out; + if (skb_headlen(skb) < (NET_TEST_PKT_SIZE - ETH_HLEN)) + goto out; + + ehdr = (struct ethhdr *)skb_mac_header(skb); + if (dst) { + if (!ether_addr_equal_unaligned(ehdr->h_dest, dst)) + goto out; + } + + if (src) { + if (!ether_addr_equal_unaligned(ehdr->h_source, src)) + goto out; + } + + ihdr = ip_hdr(skb); + if (tpriv->double_vlan) + ihdr = (struct iphdr *)(skb_network_header(skb) + 4); + + if (tpriv->packet->tcp) { + if (ihdr->protocol != IPPROTO_TCP) + goto out; + + thdr = (struct tcphdr *)((u8 *)ihdr + 4 * ihdr->ihl); + if (thdr->dest != htons(tpriv->packet->dport)) + goto out; + + shdr = (struct netsfhdr *)((u8 *)thdr + sizeof(*thdr)); + } else { + if (ihdr->protocol != IPPROTO_UDP) + goto out; + + uhdr = (struct udphdr *)((u8 *)ihdr + 4 * ihdr->ihl); + if (uhdr->dest != htons(tpriv->packet->dport)) + goto out; + + shdr = (struct netsfhdr *)((u8 *)uhdr + sizeof(*uhdr)); + } + + if (shdr->magic != cpu_to_be64(NET_TEST_PKT_MAGIC)) + goto out; + if (tpriv->packet->id != shdr->id) + goto out; + + tpriv->ok = true; + complete(&tpriv->comp); +out: + kfree_skb(skb); + return 0; +} + +static int __net_test_loopback(struct net_device *ndev, + struct net_packet_attrs *attr) +{ + struct net_test_priv *tpriv; + struct sk_buff *skb = NULL; + int ret = 0; + + tpriv = kzalloc(sizeof(*tpriv), GFP_KERNEL); + if (!tpriv) + return -ENOMEM; + + tpriv->ok = false; + init_completion(&tpriv->comp); + + tpriv->pt.type = htons(ETH_P_IP); + tpriv->pt.func = net_test_loopback_validate; + tpriv->pt.dev = ndev; + tpriv->pt.af_packet_priv = tpriv; + tpriv->packet = attr; + dev_add_pack(&tpriv->pt); + + skb = net_test_get_skb(ndev, attr); + if (!skb) { + ret = -ENOMEM; + goto cleanup; + } + + ret = dev_direct_xmit(skb, attr->queue_mapping); + if (ret < 0) { + goto cleanup; + } else if (ret > 0) { + ret = -ENETUNREACH; + goto cleanup; + } + + if (!attr->timeout) + attr->timeout = NET_LB_TIMEOUT; + + wait_for_completion_timeout(&tpriv->comp, attr->timeout); + ret = tpriv->ok ? 0 : -ETIMEDOUT; + +cleanup: + dev_remove_pack(&tpriv->pt); + kfree(tpriv); + return ret; +} + +static int net_test_netif_carrier(struct net_device *ndev) +{ + return netif_carrier_ok(ndev) ? 0 : -ENOLINK; +} + +static int net_test_phy_phydev(struct net_device *ndev) +{ + return ndev->phydev ? 
0 : -EOPNOTSUPP; +} + +static int net_test_phy_loopback_enable(struct net_device *ndev) +{ + if (!ndev->phydev) + return -EOPNOTSUPP; + + return phy_loopback(ndev->phydev, true); +} + +static int net_test_phy_loopback_disable(struct net_device *ndev) +{ + if (!ndev->phydev) + return -EOPNOTSUPP; + + return phy_loopback(ndev->phydev, false); +} + +static int net_test_phy_loopback_udp(struct net_device *ndev) +{ + struct net_packet_attrs attr = { }; + + attr.dst = ndev->dev_addr; + return __net_test_loopback(ndev, &attr); +} + +static int net_test_phy_loopback_udp_mtu(struct net_device *ndev) +{ + struct net_packet_attrs attr = { }; + + attr.dst = ndev->dev_addr; + attr.max_size = ndev->mtu; + return __net_test_loopback(ndev, &attr); +} + +static int net_test_phy_loopback_tcp(struct net_device *ndev) +{ + struct net_packet_attrs attr = { }; + + attr.dst = ndev->dev_addr; + attr.tcp = true; + return __net_test_loopback(ndev, &attr); +} + +static const struct net_test { + char name[ETH_GSTRING_LEN]; + int (*fn)(struct net_device *ndev); +} net_selftests[] = { + { + .name = "Carrier ", + .fn = net_test_netif_carrier, + }, { + .name = "PHY dev is present ", + .fn = net_test_phy_phydev, + }, { + /* This test should be done before all PHY loopback test */ + .name = "PHY internal loopback, enable ", + .fn = net_test_phy_loopback_enable, + }, { + .name = "PHY internal loopback, UDP ", + .fn = net_test_phy_loopback_udp, + }, { + .name = "PHY internal loopback, MTU ", + .fn = net_test_phy_loopback_udp_mtu, + }, { + .name = "PHY internal loopback, TCP ", + .fn = net_test_phy_loopback_tcp, + }, { + /* This test should be done after all PHY loopback test */ + .name = "PHY internal loopback, disable", + .fn = net_test_phy_loopback_disable, + }, +}; + +void net_selftest(struct net_device *ndev, struct ethtool_test *etest, u64 *buf) +{ + int count = net_selftest_get_count(); + int i; + + memset(buf, 0, sizeof(*buf) * count); + net_test_next_id = 0; + + if (etest->flags != ETH_TEST_FL_OFFLINE) { + netdev_err(ndev, "Only offline tests are supported\n"); + etest->flags |= ETH_TEST_FL_FAILED; + return; + } + + + for (i = 0; i < count; i++) { + buf[i] = net_selftests[i].fn(ndev); + if (buf[i] && (buf[i] != -EOPNOTSUPP)) + etest->flags |= ETH_TEST_FL_FAILED; + } +} +EXPORT_SYMBOL_GPL(net_selftest); + +int net_selftest_get_count(void) +{ + return ARRAY_SIZE(net_selftests); +} +EXPORT_SYMBOL_GPL(net_selftest_get_count); + +void net_selftest_get_strings(u8 *data) +{ + u8 *p = data; + int i; + + for (i = 0; i < net_selftest_get_count(); i++) { + snprintf(p, ETH_GSTRING_LEN, "%2d. %s", i + 1, + net_selftests[i].name); + p += ETH_GSTRING_LEN; + } +} +EXPORT_SYMBOL_GPL(net_selftest_get_strings); + +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Oleksij Rempel "); diff --git a/net/core/skbuff.c b/net/core/skbuff.c new file mode 100644 index 000000000..8a819d0a7 --- /dev/null +++ b/net/core/skbuff.c @@ -0,0 +1,6687 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Routines having to do with the 'struct sk_buff' memory handlers. + * + * Authors: Alan Cox + * Florian La Roche + * + * Fixes: + * Alan Cox : Fixed the worst of the load + * balancer bugs. + * Dave Platt : Interrupt stacking fix. + * Richard Kooijman : Timestamp fixes. + * Alan Cox : Changed buffer format. + * Alan Cox : destructor hook for AF_UNIX etc. + * Linus Torvalds : Better skb_clone. + * Alan Cox : Added skb_copy. 
+ * Alan Cox : Added all the changed routines Linus + * only put in the headers + * Ray VanTassle : Fixed --skb->lock in free + * Alan Cox : skb_copy copy arp field + * Andi Kleen : slabified it. + * Robert Olsson : Removed skb_head_pool + * + * NOTE: + * The __skb_ routines should be called with interrupts + * disabled, or you better be *real* sure that the operation is atomic + * with respect to whatever list is being frobbed (e.g. via lock_sock() + * or via disabling bottom half handlers, etc). + */ + +/* + * The functions in this file will not compile correctly with gcc 2.4.x + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_NET_CLS_ACT +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "dev.h" +#include "sock_destructor.h" + +struct kmem_cache *skbuff_head_cache __ro_after_init; +static struct kmem_cache *skbuff_fclone_cache __ro_after_init; +#ifdef CONFIG_SKB_EXTENSIONS +static struct kmem_cache *skbuff_ext_cache __ro_after_init; +#endif +int sysctl_max_skb_frags __read_mostly = MAX_SKB_FRAGS; +EXPORT_SYMBOL(sysctl_max_skb_frags); + +#undef FN +#define FN(reason) [SKB_DROP_REASON_##reason] = #reason, +const char * const drop_reasons[] = { + DEFINE_DROP_REASON(FN, FN) +}; +EXPORT_SYMBOL(drop_reasons); + +/** + * skb_panic - private function for out-of-line support + * @skb: buffer + * @sz: size + * @addr: address + * @msg: skb_over_panic or skb_under_panic + * + * Out-of-line support for skb_put() and skb_push(). + * Called via the wrapper skb_over_panic() or skb_under_panic(). + * Keep out of line to prevent kernel bloat. + * __builtin_return_address is not used because it is not always reliable. + */ +static void skb_panic(struct sk_buff *skb, unsigned int sz, void *addr, + const char msg[]) +{ + pr_emerg("%s: text:%px len:%d put:%d head:%px data:%px tail:%#lx end:%#lx dev:%s\n", + msg, addr, skb->len, sz, skb->head, skb->data, + (unsigned long)skb->tail, (unsigned long)skb->end, + skb->dev ? skb->dev->name : ""); + BUG(); +} + +static void skb_over_panic(struct sk_buff *skb, unsigned int sz, void *addr) +{ + skb_panic(skb, sz, addr, __func__); +} + +static void skb_under_panic(struct sk_buff *skb, unsigned int sz, void *addr) +{ + skb_panic(skb, sz, addr, __func__); +} + +#define NAPI_SKB_CACHE_SIZE 64 +#define NAPI_SKB_CACHE_BULK 16 +#define NAPI_SKB_CACHE_HALF (NAPI_SKB_CACHE_SIZE / 2) + +#if PAGE_SIZE == SZ_4K + +#define NAPI_HAS_SMALL_PAGE_FRAG 1 +#define NAPI_SMALL_PAGE_PFMEMALLOC(nc) ((nc).pfmemalloc) + +/* specialized page frag allocator using a single order 0 page + * and slicing it into 1K sized fragment. 
Constrained to systems + * with a very limited amount of 1K fragments fitting a single + * page - to avoid excessive truesize underestimation + */ + +struct page_frag_1k { + void *va; + u16 offset; + bool pfmemalloc; +}; + +static void *page_frag_alloc_1k(struct page_frag_1k *nc, gfp_t gfp) +{ + struct page *page; + int offset; + + offset = nc->offset - SZ_1K; + if (likely(offset >= 0)) + goto use_frag; + + page = alloc_pages_node(NUMA_NO_NODE, gfp, 0); + if (!page) + return NULL; + + nc->va = page_address(page); + nc->pfmemalloc = page_is_pfmemalloc(page); + offset = PAGE_SIZE - SZ_1K; + page_ref_add(page, offset / SZ_1K); + +use_frag: + nc->offset = offset; + return nc->va + offset; +} +#else + +/* the small page is actually unused in this build; add dummy helpers + * to please the compiler and avoid later preprocessor's conditionals + */ +#define NAPI_HAS_SMALL_PAGE_FRAG 0 +#define NAPI_SMALL_PAGE_PFMEMALLOC(nc) false + +struct page_frag_1k { +}; + +static void *page_frag_alloc_1k(struct page_frag_1k *nc, gfp_t gfp_mask) +{ + return NULL; +} + +#endif + +struct napi_alloc_cache { + struct page_frag_cache page; + struct page_frag_1k page_small; + unsigned int skb_count; + void *skb_cache[NAPI_SKB_CACHE_SIZE]; +}; + +static DEFINE_PER_CPU(struct page_frag_cache, netdev_alloc_cache); +static DEFINE_PER_CPU(struct napi_alloc_cache, napi_alloc_cache); + +/* Double check that napi_get_frags() allocates skbs with + * skb->head being backed by slab, not a page fragment. + * This is to make sure bug fixed in 3226b158e67c + * ("net: avoid 32 x truesize under-estimation for tiny skbs") + * does not accidentally come back. + */ +void napi_get_frags_check(struct napi_struct *napi) +{ + struct sk_buff *skb; + + local_bh_disable(); + skb = napi_get_frags(napi); + WARN_ON_ONCE(!NAPI_HAS_SMALL_PAGE_FRAG && skb && skb->head_frag); + napi_free_frags(napi); + local_bh_enable(); +} + +void *__napi_alloc_frag_align(unsigned int fragsz, unsigned int align_mask) +{ + struct napi_alloc_cache *nc = this_cpu_ptr(&napi_alloc_cache); + + fragsz = SKB_DATA_ALIGN(fragsz); + + return page_frag_alloc_align(&nc->page, fragsz, GFP_ATOMIC, align_mask); +} +EXPORT_SYMBOL(__napi_alloc_frag_align); + +void *__netdev_alloc_frag_align(unsigned int fragsz, unsigned int align_mask) +{ + void *data; + + fragsz = SKB_DATA_ALIGN(fragsz); + if (in_hardirq() || irqs_disabled()) { + struct page_frag_cache *nc = this_cpu_ptr(&netdev_alloc_cache); + + data = page_frag_alloc_align(nc, fragsz, GFP_ATOMIC, align_mask); + } else { + struct napi_alloc_cache *nc; + + local_bh_disable(); + nc = this_cpu_ptr(&napi_alloc_cache); + data = page_frag_alloc_align(&nc->page, fragsz, GFP_ATOMIC, align_mask); + local_bh_enable(); + } + return data; +} +EXPORT_SYMBOL(__netdev_alloc_frag_align); + +static struct sk_buff *napi_skb_cache_get(void) +{ + struct napi_alloc_cache *nc = this_cpu_ptr(&napi_alloc_cache); + struct sk_buff *skb; + + if (unlikely(!nc->skb_count)) { + nc->skb_count = kmem_cache_alloc_bulk(skbuff_head_cache, + GFP_ATOMIC, + NAPI_SKB_CACHE_BULK, + nc->skb_cache); + if (unlikely(!nc->skb_count)) + return NULL; + } + + skb = nc->skb_cache[--nc->skb_count]; + kasan_unpoison_object_data(skbuff_head_cache, skb); + + return skb; +} + +/* Caller must provide SKB that is memset cleared */ +static void __build_skb_around(struct sk_buff *skb, void *data, + unsigned int frag_size) +{ + struct skb_shared_info *shinfo; + unsigned int size = frag_size ? 
: ksize(data); + + size -= SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + + /* Assumes caller memset cleared SKB */ + skb->truesize = SKB_TRUESIZE(size); + refcount_set(&skb->users, 1); + skb->head = data; + skb->data = data; + skb_reset_tail_pointer(skb); + skb_set_end_offset(skb, size); + skb->mac_header = (typeof(skb->mac_header))~0U; + skb->transport_header = (typeof(skb->transport_header))~0U; + skb->alloc_cpu = raw_smp_processor_id(); + /* make sure we initialize shinfo sequentially */ + shinfo = skb_shinfo(skb); + memset(shinfo, 0, offsetof(struct skb_shared_info, dataref)); + atomic_set(&shinfo->dataref, 1); + + skb_set_kcov_handle(skb, kcov_common_handle()); +} + +/** + * __build_skb - build a network buffer + * @data: data buffer provided by caller + * @frag_size: size of data, or 0 if head was kmalloced + * + * Allocate a new &sk_buff. Caller provides space holding head and + * skb_shared_info. @data must have been allocated by kmalloc() only if + * @frag_size is 0, otherwise data should come from the page allocator + * or vmalloc() + * The return is the new skb buffer. + * On a failure the return is %NULL, and @data is not freed. + * Notes : + * Before IO, driver allocates only data buffer where NIC put incoming frame + * Driver should add room at head (NET_SKB_PAD) and + * MUST add room at tail (SKB_DATA_ALIGN(skb_shared_info)) + * After IO, driver calls build_skb(), to allocate sk_buff and populate it + * before giving packet to stack. + * RX rings only contains data buffers, not full skbs. + */ +struct sk_buff *__build_skb(void *data, unsigned int frag_size) +{ + struct sk_buff *skb; + + skb = kmem_cache_alloc(skbuff_head_cache, GFP_ATOMIC); + if (unlikely(!skb)) + return NULL; + + memset(skb, 0, offsetof(struct sk_buff, tail)); + __build_skb_around(skb, data, frag_size); + + return skb; +} + +/* build_skb() is wrapper over __build_skb(), that specifically + * takes care of skb->head and skb->pfmemalloc + * This means that if @frag_size is not zero, then @data must be backed + * by a page fragment, not kmalloc() or vmalloc() + */ +struct sk_buff *build_skb(void *data, unsigned int frag_size) +{ + struct sk_buff *skb = __build_skb(data, frag_size); + + if (skb && frag_size) { + skb->head_frag = 1; + if (page_is_pfmemalloc(virt_to_head_page(data))) + skb->pfmemalloc = 1; + } + return skb; +} +EXPORT_SYMBOL(build_skb); + +/** + * build_skb_around - build a network buffer around provided skb + * @skb: sk_buff provide by caller, must be memset cleared + * @data: data buffer provided by caller + * @frag_size: size of data, or 0 if head was kmalloced + */ +struct sk_buff *build_skb_around(struct sk_buff *skb, + void *data, unsigned int frag_size) +{ + if (unlikely(!skb)) + return NULL; + + __build_skb_around(skb, data, frag_size); + + if (frag_size) { + skb->head_frag = 1; + if (page_is_pfmemalloc(virt_to_head_page(data))) + skb->pfmemalloc = 1; + } + return skb; +} +EXPORT_SYMBOL(build_skb_around); + +/** + * __napi_build_skb - build a network buffer + * @data: data buffer provided by caller + * @frag_size: size of data, or 0 if head was kmalloced + * + * Version of __build_skb() that uses NAPI percpu caches to obtain + * skbuff_head instead of inplace allocation. + * + * Returns a new &sk_buff on success, %NULL on allocation failure. 
+ */ +static struct sk_buff *__napi_build_skb(void *data, unsigned int frag_size) +{ + struct sk_buff *skb; + + skb = napi_skb_cache_get(); + if (unlikely(!skb)) + return NULL; + + memset(skb, 0, offsetof(struct sk_buff, tail)); + __build_skb_around(skb, data, frag_size); + + return skb; +} + +/** + * napi_build_skb - build a network buffer + * @data: data buffer provided by caller + * @frag_size: size of data, or 0 if head was kmalloced + * + * Version of __napi_build_skb() that takes care of skb->head_frag + * and skb->pfmemalloc when the data is a page or page fragment. + * + * Returns a new &sk_buff on success, %NULL on allocation failure. + */ +struct sk_buff *napi_build_skb(void *data, unsigned int frag_size) +{ + struct sk_buff *skb = __napi_build_skb(data, frag_size); + + if (likely(skb) && frag_size) { + skb->head_frag = 1; + skb_propagate_pfmemalloc(virt_to_head_page(data), skb); + } + + return skb; +} +EXPORT_SYMBOL(napi_build_skb); + +/* + * kmalloc_reserve is a wrapper around kmalloc_node_track_caller that tells + * the caller if emergency pfmemalloc reserves are being used. If it is and + * the socket is later found to be SOCK_MEMALLOC then PFMEMALLOC reserves + * may be used. Otherwise, the packet data may be discarded until enough + * memory is free + */ +static void *kmalloc_reserve(unsigned int *size, gfp_t flags, int node, + bool *pfmemalloc) +{ + bool ret_pfmemalloc = false; + size_t obj_size; + void *obj; + + obj_size = SKB_HEAD_ALIGN(*size); + + obj_size = kmalloc_size_roundup(obj_size); + /* The following cast might truncate high-order bits of obj_size, this + * is harmless because kmalloc(obj_size >= 2^32) will fail anyway. + */ + *size = (unsigned int)obj_size; + + /* + * Try a regular allocation, when that fails and we're not entitled + * to the reserves, fail. + */ + obj = kmalloc_node_track_caller(obj_size, + flags | __GFP_NOMEMALLOC | __GFP_NOWARN, + node); + if (obj || !(gfp_pfmemalloc_allowed(flags))) + goto out; + + /* Try again but now we are using pfmemalloc reserves */ + ret_pfmemalloc = true; + obj = kmalloc_node_track_caller(obj_size, flags, node); + +out: + if (pfmemalloc) + *pfmemalloc = ret_pfmemalloc; + + return obj; +} + +/* Allocate a new skbuff. We do this ourselves so we can fill in a few + * 'private' fields and also do memory statistics to find all the + * [BEEP] leaks. + * + */ + +/** + * __alloc_skb - allocate a network buffer + * @size: size to allocate + * @gfp_mask: allocation mask + * @flags: If SKB_ALLOC_FCLONE is set, allocate from fclone cache + * instead of head cache and allocate a cloned (child) skb. + * If SKB_ALLOC_RX is set, __GFP_MEMALLOC will be used for + * allocations in case the data is required for writeback + * @node: numa node to allocate memory on + * + * Allocate a new &sk_buff. The returned buffer has no headroom and a + * tail room of at least size bytes. The object has a reference count + * of one. The return is the buffer. On a failure the return is %NULL. + * + * Buffers may only be allocated from interrupts using a @gfp_mask of + * %GFP_ATOMIC. + */ +struct sk_buff *__alloc_skb(unsigned int size, gfp_t gfp_mask, + int flags, int node) +{ + struct kmem_cache *cache; + struct sk_buff *skb; + bool pfmemalloc; + u8 *data; + + cache = (flags & SKB_ALLOC_FCLONE) + ? 
skbuff_fclone_cache : skbuff_head_cache; + + if (sk_memalloc_socks() && (flags & SKB_ALLOC_RX)) + gfp_mask |= __GFP_MEMALLOC; + + /* Get the HEAD */ + if ((flags & (SKB_ALLOC_FCLONE | SKB_ALLOC_NAPI)) == SKB_ALLOC_NAPI && + likely(node == NUMA_NO_NODE || node == numa_mem_id())) + skb = napi_skb_cache_get(); + else + skb = kmem_cache_alloc_node(cache, gfp_mask & ~GFP_DMA, node); + if (unlikely(!skb)) + return NULL; + prefetchw(skb); + + /* We do our best to align skb_shared_info on a separate cache + * line. It usually works because kmalloc(X > SMP_CACHE_BYTES) gives + * aligned memory blocks, unless SLUB/SLAB debug is enabled. + * Both skb->head and skb_shared_info are cache line aligned. + */ + data = kmalloc_reserve(&size, gfp_mask, node, &pfmemalloc); + if (unlikely(!data)) + goto nodata; + /* kmalloc_size_roundup() might give us more room than requested. + * Put skb_shared_info exactly at the end of allocated zone, + * to allow max possible filling before reallocation. + */ + prefetchw(data + SKB_WITH_OVERHEAD(size)); + + /* + * Only clear those fields we need to clear, not those that we will + * actually initialise below. Hence, don't put any more fields after + * the tail pointer in struct sk_buff! + */ + memset(skb, 0, offsetof(struct sk_buff, tail)); + __build_skb_around(skb, data, size); + skb->pfmemalloc = pfmemalloc; + + if (flags & SKB_ALLOC_FCLONE) { + struct sk_buff_fclones *fclones; + + fclones = container_of(skb, struct sk_buff_fclones, skb1); + + skb->fclone = SKB_FCLONE_ORIG; + refcount_set(&fclones->fclone_ref, 1); + } + + return skb; + +nodata: + kmem_cache_free(cache, skb); + return NULL; +} +EXPORT_SYMBOL(__alloc_skb); + +/** + * __netdev_alloc_skb - allocate an skbuff for rx on a specific device + * @dev: network device to receive on + * @len: length to allocate + * @gfp_mask: get_free_pages mask, passed to alloc_skb + * + * Allocate a new &sk_buff and assign it a usage count of one. The + * buffer has NET_SKB_PAD headroom built in. Users should allocate + * the headroom they think they need without accounting for the + * built in space. The built in space is used for optimisations. + * + * %NULL is returned if there is no free memory. + */ +struct sk_buff *__netdev_alloc_skb(struct net_device *dev, unsigned int len, + gfp_t gfp_mask) +{ + struct page_frag_cache *nc; + struct sk_buff *skb; + bool pfmemalloc; + void *data; + + len += NET_SKB_PAD; + + /* If requested length is either too small or too big, + * we use kmalloc() for skb->head allocation. 
+ */ + if (len <= SKB_WITH_OVERHEAD(1024) || + len > SKB_WITH_OVERHEAD(PAGE_SIZE) || + (gfp_mask & (__GFP_DIRECT_RECLAIM | GFP_DMA))) { + skb = __alloc_skb(len, gfp_mask, SKB_ALLOC_RX, NUMA_NO_NODE); + if (!skb) + goto skb_fail; + goto skb_success; + } + + len = SKB_HEAD_ALIGN(len); + + if (sk_memalloc_socks()) + gfp_mask |= __GFP_MEMALLOC; + + if (in_hardirq() || irqs_disabled()) { + nc = this_cpu_ptr(&netdev_alloc_cache); + data = page_frag_alloc(nc, len, gfp_mask); + pfmemalloc = nc->pfmemalloc; + } else { + local_bh_disable(); + nc = this_cpu_ptr(&napi_alloc_cache.page); + data = page_frag_alloc(nc, len, gfp_mask); + pfmemalloc = nc->pfmemalloc; + local_bh_enable(); + } + + if (unlikely(!data)) + return NULL; + + skb = __build_skb(data, len); + if (unlikely(!skb)) { + skb_free_frag(data); + return NULL; + } + + if (pfmemalloc) + skb->pfmemalloc = 1; + skb->head_frag = 1; + +skb_success: + skb_reserve(skb, NET_SKB_PAD); + skb->dev = dev; + +skb_fail: + return skb; +} +EXPORT_SYMBOL(__netdev_alloc_skb); + +/** + * __napi_alloc_skb - allocate skbuff for rx in a specific NAPI instance + * @napi: napi instance this buffer was allocated for + * @len: length to allocate + * @gfp_mask: get_free_pages mask, passed to alloc_skb and alloc_pages + * + * Allocate a new sk_buff for use in NAPI receive. This buffer will + * attempt to allocate the head from a special reserved region used + * only for NAPI Rx allocation. By doing this we can save several + * CPU cycles by avoiding having to disable and re-enable IRQs. + * + * %NULL is returned if there is no free memory. + */ +struct sk_buff *__napi_alloc_skb(struct napi_struct *napi, unsigned int len, + gfp_t gfp_mask) +{ + struct napi_alloc_cache *nc; + struct sk_buff *skb; + bool pfmemalloc; + void *data; + + DEBUG_NET_WARN_ON_ONCE(!in_softirq()); + len += NET_SKB_PAD + NET_IP_ALIGN; + + /* If requested length is either too small or too big, + * we use kmalloc() for skb->head allocation. + * When the small frag allocator is available, prefer it over kmalloc + * for small fragments + */ + if ((!NAPI_HAS_SMALL_PAGE_FRAG && len <= SKB_WITH_OVERHEAD(1024)) || + len > SKB_WITH_OVERHEAD(PAGE_SIZE) || + (gfp_mask & (__GFP_DIRECT_RECLAIM | GFP_DMA))) { + skb = __alloc_skb(len, gfp_mask, SKB_ALLOC_RX | SKB_ALLOC_NAPI, + NUMA_NO_NODE); + if (!skb) + goto skb_fail; + goto skb_success; + } + + nc = this_cpu_ptr(&napi_alloc_cache); + + if (sk_memalloc_socks()) + gfp_mask |= __GFP_MEMALLOC; + + if (NAPI_HAS_SMALL_PAGE_FRAG && len <= SKB_WITH_OVERHEAD(1024)) { + /* we are artificially inflating the allocation size, but + * that is not as bad as it may look like, as: + * - 'len' less than GRO_MAX_HEAD makes little sense + * - On most systems, larger 'len' values lead to fragment + * size above 512 bytes + * - kmalloc would use the kmalloc-1k slab for such values + * - Builds with smaller GRO_MAX_HEAD will very likely do + * little networking, as that implies no WiFi and no + * tunnels support, and 32 bits arches. 
+ */ + len = SZ_1K; + + data = page_frag_alloc_1k(&nc->page_small, gfp_mask); + pfmemalloc = NAPI_SMALL_PAGE_PFMEMALLOC(nc->page_small); + } else { + len = SKB_HEAD_ALIGN(len); + + data = page_frag_alloc(&nc->page, len, gfp_mask); + pfmemalloc = nc->page.pfmemalloc; + } + + if (unlikely(!data)) + return NULL; + + skb = __napi_build_skb(data, len); + if (unlikely(!skb)) { + skb_free_frag(data); + return NULL; + } + + if (pfmemalloc) + skb->pfmemalloc = 1; + skb->head_frag = 1; + +skb_success: + skb_reserve(skb, NET_SKB_PAD + NET_IP_ALIGN); + skb->dev = napi->dev; + +skb_fail: + return skb; +} +EXPORT_SYMBOL(__napi_alloc_skb); + +void skb_add_rx_frag(struct sk_buff *skb, int i, struct page *page, int off, + int size, unsigned int truesize) +{ + skb_fill_page_desc(skb, i, page, off, size); + skb->len += size; + skb->data_len += size; + skb->truesize += truesize; +} +EXPORT_SYMBOL(skb_add_rx_frag); + +void skb_coalesce_rx_frag(struct sk_buff *skb, int i, int size, + unsigned int truesize) +{ + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + skb_frag_size_add(frag, size); + skb->len += size; + skb->data_len += size; + skb->truesize += truesize; +} +EXPORT_SYMBOL(skb_coalesce_rx_frag); + +static void skb_drop_list(struct sk_buff **listp) +{ + kfree_skb_list(*listp); + *listp = NULL; +} + +static inline void skb_drop_fraglist(struct sk_buff *skb) +{ + skb_drop_list(&skb_shinfo(skb)->frag_list); +} + +static void skb_clone_fraglist(struct sk_buff *skb) +{ + struct sk_buff *list; + + skb_walk_frags(skb, list) + skb_get(list); +} + +static void skb_free_head(struct sk_buff *skb) +{ + unsigned char *head = skb->head; + + if (skb->head_frag) { + if (skb_pp_recycle(skb, head)) + return; + skb_free_frag(head); + } else { + kfree(head); + } +} + +static void skb_release_data(struct sk_buff *skb) +{ + struct skb_shared_info *shinfo = skb_shinfo(skb); + int i; + + if (skb->cloned && + atomic_sub_return(skb->nohdr ? (1 << SKB_DATAREF_SHIFT) + 1 : 1, + &shinfo->dataref)) + goto exit; + + if (skb_zcopy(skb)) { + bool skip_unref = shinfo->flags & SKBFL_MANAGED_FRAG_REFS; + + skb_zcopy_clear(skb, true); + if (skip_unref) + goto free_head; + } + + for (i = 0; i < shinfo->nr_frags; i++) + __skb_frag_unref(&shinfo->frags[i], skb->pp_recycle); + +free_head: + if (shinfo->frag_list) + kfree_skb_list(shinfo->frag_list); + + skb_free_head(skb); +exit: + /* When we clone an SKB we copy the reycling bit. The pp_recycle + * bit is only set on the head though, so in order to avoid races + * while trying to recycle fragments on __skb_frag_unref() we need + * to make one SKB responsible for triggering the recycle path. + * So disable the recycling bit if an SKB is cloned and we have + * additional references to the fragmented part of the SKB. + * Eventually the last SKB will have the recycling bit set and it's + * dataref set to 0, which will trigger the recycling + */ + skb->pp_recycle = 0; +} + +/* + * Free an skbuff by memory without cleaning the state. + */ +static void kfree_skbmem(struct sk_buff *skb) +{ + struct sk_buff_fclones *fclones; + + switch (skb->fclone) { + case SKB_FCLONE_UNAVAILABLE: + kmem_cache_free(skbuff_head_cache, skb); + return; + + case SKB_FCLONE_ORIG: + fclones = container_of(skb, struct sk_buff_fclones, skb1); + + /* We usually free the clone (TX completion) before original skb + * This test would have no chance to be true for the clone, + * while here, branch prediction will be good. 
+ */ + if (refcount_read(&fclones->fclone_ref) == 1) + goto fastpath; + break; + + default: /* SKB_FCLONE_CLONE */ + fclones = container_of(skb, struct sk_buff_fclones, skb2); + break; + } + if (!refcount_dec_and_test(&fclones->fclone_ref)) + return; +fastpath: + kmem_cache_free(skbuff_fclone_cache, fclones); +} + +void skb_release_head_state(struct sk_buff *skb) +{ + skb_dst_drop(skb); + if (skb->destructor) { + DEBUG_NET_WARN_ON_ONCE(in_hardirq()); + skb->destructor(skb); + } +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + nf_conntrack_put(skb_nfct(skb)); +#endif + skb_ext_put(skb); +} + +/* Free everything but the sk_buff shell. */ +static void skb_release_all(struct sk_buff *skb) +{ + skb_release_head_state(skb); + if (likely(skb->head)) + skb_release_data(skb); +} + +/** + * __kfree_skb - private function + * @skb: buffer + * + * Free an sk_buff. Release anything attached to the buffer. + * Clean the state. This is an internal helper function. Users should + * always call kfree_skb + */ + +void __kfree_skb(struct sk_buff *skb) +{ + skb_release_all(skb); + kfree_skbmem(skb); +} +EXPORT_SYMBOL(__kfree_skb); + +/** + * kfree_skb_reason - free an sk_buff with special reason + * @skb: buffer to free + * @reason: reason why this skb is dropped + * + * Drop a reference to the buffer and free it if the usage count has + * hit zero. Meanwhile, pass the drop reason to 'kfree_skb' + * tracepoint. + */ +void __fix_address +kfree_skb_reason(struct sk_buff *skb, enum skb_drop_reason reason) +{ + if (unlikely(!skb_unref(skb))) + return; + + DEBUG_NET_WARN_ON_ONCE(reason <= 0 || reason >= SKB_DROP_REASON_MAX); + + trace_kfree_skb(skb, __builtin_return_address(0), reason); + __kfree_skb(skb); +} +EXPORT_SYMBOL(kfree_skb_reason); + +void kfree_skb_list_reason(struct sk_buff *segs, + enum skb_drop_reason reason) +{ + while (segs) { + struct sk_buff *next = segs->next; + + kfree_skb_reason(segs, reason); + segs = next; + } +} +EXPORT_SYMBOL(kfree_skb_list_reason); + +/* Dump skb information and contents. + * + * Must only be called from net_ratelimit()-ed paths. + * + * Dumps whole packets if full_pkt, only headers otherwise. + */ +void skb_dump(const char *level, const struct sk_buff *skb, bool full_pkt) +{ + struct skb_shared_info *sh = skb_shinfo(skb); + struct net_device *dev = skb->dev; + struct sock *sk = skb->sk; + struct sk_buff *list_skb; + bool has_mac, has_trans; + int headroom, tailroom; + int i, len, seg_len; + + if (full_pkt) + len = skb->len; + else + len = min_t(int, skb->len, MAX_HEADER + 128); + + headroom = skb_headroom(skb); + tailroom = skb_tailroom(skb); + + has_mac = skb_mac_header_was_set(skb); + has_trans = skb_transport_header_was_set(skb); + + printk("%sskb len=%u headroom=%u headlen=%u tailroom=%u\n" + "mac=(%d,%d) net=(%d,%d) trans=%d\n" + "shinfo(txflags=%u nr_frags=%u gso(size=%hu type=%u segs=%hu))\n" + "csum(0x%x ip_summed=%u complete_sw=%u valid=%u level=%u)\n" + "hash(0x%x sw=%u l4=%u) proto=0x%04x pkttype=%u iif=%d\n", + level, skb->len, headroom, skb_headlen(skb), tailroom, + has_mac ? skb->mac_header : -1, + has_mac ? skb_mac_header_len(skb) : -1, + skb->network_header, + has_trans ? skb_network_header_len(skb) : -1, + has_trans ? 
skb->transport_header : -1, + sh->tx_flags, sh->nr_frags, + sh->gso_size, sh->gso_type, sh->gso_segs, + skb->csum, skb->ip_summed, skb->csum_complete_sw, + skb->csum_valid, skb->csum_level, + skb->hash, skb->sw_hash, skb->l4_hash, + ntohs(skb->protocol), skb->pkt_type, skb->skb_iif); + + if (dev) + printk("%sdev name=%s feat=%pNF\n", + level, dev->name, &dev->features); + if (sk) + printk("%ssk family=%hu type=%u proto=%u\n", + level, sk->sk_family, sk->sk_type, sk->sk_protocol); + + if (full_pkt && headroom) + print_hex_dump(level, "skb headroom: ", DUMP_PREFIX_OFFSET, + 16, 1, skb->head, headroom, false); + + seg_len = min_t(int, skb_headlen(skb), len); + if (seg_len) + print_hex_dump(level, "skb linear: ", DUMP_PREFIX_OFFSET, + 16, 1, skb->data, seg_len, false); + len -= seg_len; + + if (full_pkt && tailroom) + print_hex_dump(level, "skb tailroom: ", DUMP_PREFIX_OFFSET, + 16, 1, skb_tail_pointer(skb), tailroom, false); + + for (i = 0; len && i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + u32 p_off, p_len, copied; + struct page *p; + u8 *vaddr; + + skb_frag_foreach_page(frag, skb_frag_off(frag), + skb_frag_size(frag), p, p_off, p_len, + copied) { + seg_len = min_t(int, p_len, len); + vaddr = kmap_atomic(p); + print_hex_dump(level, "skb frag: ", + DUMP_PREFIX_OFFSET, + 16, 1, vaddr + p_off, seg_len, false); + kunmap_atomic(vaddr); + len -= seg_len; + if (!len) + break; + } + } + + if (full_pkt && skb_has_frag_list(skb)) { + printk("skb fraglist:\n"); + skb_walk_frags(skb, list_skb) + skb_dump(level, list_skb, true); + } +} +EXPORT_SYMBOL(skb_dump); + +/** + * skb_tx_error - report an sk_buff xmit error + * @skb: buffer that triggered an error + * + * Report xmit error if a device callback is tracking this skb. + * skb must be freed afterwards. 
+ */ +void skb_tx_error(struct sk_buff *skb) +{ + if (skb) { + skb_zcopy_downgrade_managed(skb); + skb_zcopy_clear(skb, true); + } +} +EXPORT_SYMBOL(skb_tx_error); + +#ifdef CONFIG_TRACEPOINTS +/** + * consume_skb - free an skbuff + * @skb: buffer to free + * + * Drop a ref to the buffer and free it if the usage count has hit zero + * Functions identically to kfree_skb, but kfree_skb assumes that the frame + * is being dropped after a failure and notes that + */ +void consume_skb(struct sk_buff *skb) +{ + if (!skb_unref(skb)) + return; + + trace_consume_skb(skb); + __kfree_skb(skb); +} +EXPORT_SYMBOL(consume_skb); +#endif + +/** + * __consume_stateless_skb - free an skbuff, assuming it is stateless + * @skb: buffer to free + * + * Alike consume_skb(), but this variant assumes that this is the last + * skb reference and all the head states have been already dropped + */ +void __consume_stateless_skb(struct sk_buff *skb) +{ + trace_consume_skb(skb); + skb_release_data(skb); + kfree_skbmem(skb); +} + +static void napi_skb_cache_put(struct sk_buff *skb) +{ + struct napi_alloc_cache *nc = this_cpu_ptr(&napi_alloc_cache); + u32 i; + + kasan_poison_object_data(skbuff_head_cache, skb); + nc->skb_cache[nc->skb_count++] = skb; + + if (unlikely(nc->skb_count == NAPI_SKB_CACHE_SIZE)) { + for (i = NAPI_SKB_CACHE_HALF; i < NAPI_SKB_CACHE_SIZE; i++) + kasan_unpoison_object_data(skbuff_head_cache, + nc->skb_cache[i]); + + kmem_cache_free_bulk(skbuff_head_cache, NAPI_SKB_CACHE_HALF, + nc->skb_cache + NAPI_SKB_CACHE_HALF); + nc->skb_count = NAPI_SKB_CACHE_HALF; + } +} + +void __kfree_skb_defer(struct sk_buff *skb) +{ + skb_release_all(skb); + napi_skb_cache_put(skb); +} + +void napi_skb_free_stolen_head(struct sk_buff *skb) +{ + if (unlikely(skb->slow_gro)) { + nf_reset_ct(skb); + skb_dst_drop(skb); + skb_ext_put(skb); + skb_orphan(skb); + skb->slow_gro = 0; + } + napi_skb_cache_put(skb); +} + +void napi_consume_skb(struct sk_buff *skb, int budget) +{ + /* Zero budget indicate non-NAPI context called us, like netpoll */ + if (unlikely(!budget)) { + dev_consume_skb_any(skb); + return; + } + + DEBUG_NET_WARN_ON_ONCE(!in_softirq()); + + if (!skb_unref(skb)) + return; + + /* if reaching here SKB is ready to free */ + trace_consume_skb(skb); + + /* if SKB is a clone, don't handle this case */ + if (skb->fclone != SKB_FCLONE_UNAVAILABLE) { + __kfree_skb(skb); + return; + } + + skb_release_all(skb); + napi_skb_cache_put(skb); +} +EXPORT_SYMBOL(napi_consume_skb); + +/* Make sure a field is contained by headers group */ +#define CHECK_SKB_FIELD(field) \ + BUILD_BUG_ON(offsetof(struct sk_buff, field) != \ + offsetof(struct sk_buff, headers.field)); \ + +static void __copy_skb_header(struct sk_buff *new, const struct sk_buff *old) +{ + new->tstamp = old->tstamp; + /* We do not copy old->sk */ + new->dev = old->dev; + memcpy(new->cb, old->cb, sizeof(old->cb)); + skb_dst_copy(new, old); + __skb_ext_copy(new, old); + __nf_copy(new, old, false); + + /* Note : this field could be in the headers group. 
+ * It is not yet because we do not want to have a 16 bit hole + */ + new->queue_mapping = old->queue_mapping; + + memcpy(&new->headers, &old->headers, sizeof(new->headers)); + CHECK_SKB_FIELD(protocol); + CHECK_SKB_FIELD(csum); + CHECK_SKB_FIELD(hash); + CHECK_SKB_FIELD(priority); + CHECK_SKB_FIELD(skb_iif); + CHECK_SKB_FIELD(vlan_proto); + CHECK_SKB_FIELD(vlan_tci); + CHECK_SKB_FIELD(transport_header); + CHECK_SKB_FIELD(network_header); + CHECK_SKB_FIELD(mac_header); + CHECK_SKB_FIELD(inner_protocol); + CHECK_SKB_FIELD(inner_transport_header); + CHECK_SKB_FIELD(inner_network_header); + CHECK_SKB_FIELD(inner_mac_header); + CHECK_SKB_FIELD(mark); +#ifdef CONFIG_NETWORK_SECMARK + CHECK_SKB_FIELD(secmark); +#endif +#ifdef CONFIG_NET_RX_BUSY_POLL + CHECK_SKB_FIELD(napi_id); +#endif + CHECK_SKB_FIELD(alloc_cpu); +#ifdef CONFIG_XPS + CHECK_SKB_FIELD(sender_cpu); +#endif +#ifdef CONFIG_NET_SCHED + CHECK_SKB_FIELD(tc_index); +#endif + +} + +/* + * You should not add any new code to this function. Add it to + * __copy_skb_header above instead. + */ +static struct sk_buff *__skb_clone(struct sk_buff *n, struct sk_buff *skb) +{ +#define C(x) n->x = skb->x + + n->next = n->prev = NULL; + n->sk = NULL; + __copy_skb_header(n, skb); + + C(len); + C(data_len); + C(mac_len); + n->hdr_len = skb->nohdr ? skb_headroom(skb) : skb->hdr_len; + n->cloned = 1; + n->nohdr = 0; + n->peeked = 0; + C(pfmemalloc); + C(pp_recycle); + n->destructor = NULL; + C(tail); + C(end); + C(head); + C(head_frag); + C(data); + C(truesize); + refcount_set(&n->users, 1); + + atomic_inc(&(skb_shinfo(skb)->dataref)); + skb->cloned = 1; + + return n; +#undef C +} + +/** + * alloc_skb_for_msg() - allocate sk_buff to wrap frag list forming a msg + * @first: first sk_buff of the msg + */ +struct sk_buff *alloc_skb_for_msg(struct sk_buff *first) +{ + struct sk_buff *n; + + n = alloc_skb(0, GFP_ATOMIC); + if (!n) + return NULL; + + n->len = first->len; + n->data_len = first->len; + n->truesize = first->truesize; + + skb_shinfo(n)->frag_list = first; + + __copy_skb_header(n, first); + n->destructor = NULL; + + return n; +} +EXPORT_SYMBOL_GPL(alloc_skb_for_msg); + +/** + * skb_morph - morph one skb into another + * @dst: the skb to receive the contents + * @src: the skb to supply the contents + * + * This is identical to skb_clone except that the target skb is + * supplied by the user. + * + * The target skb is returned upon exit. + */ +struct sk_buff *skb_morph(struct sk_buff *dst, struct sk_buff *src) +{ + skb_release_all(dst); + return __skb_clone(dst, src); +} +EXPORT_SYMBOL_GPL(skb_morph); + +int mm_account_pinned_pages(struct mmpin *mmp, size_t size) +{ + unsigned long max_pg, num_pg, new_pg, old_pg; + struct user_struct *user; + + if (capable(CAP_IPC_LOCK) || !size) + return 0; + + num_pg = (size >> PAGE_SHIFT) + 2; /* worst case */ + max_pg = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; + user = mmp->user ? 
: current_user(); + + do { + old_pg = atomic_long_read(&user->locked_vm); + new_pg = old_pg + num_pg; + if (new_pg > max_pg) + return -ENOBUFS; + } while (atomic_long_cmpxchg(&user->locked_vm, old_pg, new_pg) != + old_pg); + + if (!mmp->user) { + mmp->user = get_uid(user); + mmp->num_pg = num_pg; + } else { + mmp->num_pg += num_pg; + } + + return 0; +} +EXPORT_SYMBOL_GPL(mm_account_pinned_pages); + +void mm_unaccount_pinned_pages(struct mmpin *mmp) +{ + if (mmp->user) { + atomic_long_sub(mmp->num_pg, &mmp->user->locked_vm); + free_uid(mmp->user); + } +} +EXPORT_SYMBOL_GPL(mm_unaccount_pinned_pages); + +static struct ubuf_info *msg_zerocopy_alloc(struct sock *sk, size_t size) +{ + struct ubuf_info_msgzc *uarg; + struct sk_buff *skb; + + WARN_ON_ONCE(!in_task()); + + skb = sock_omalloc(sk, 0, GFP_KERNEL); + if (!skb) + return NULL; + + BUILD_BUG_ON(sizeof(*uarg) > sizeof(skb->cb)); + uarg = (void *)skb->cb; + uarg->mmp.user = NULL; + + if (mm_account_pinned_pages(&uarg->mmp, size)) { + kfree_skb(skb); + return NULL; + } + + uarg->ubuf.callback = msg_zerocopy_callback; + uarg->id = ((u32)atomic_inc_return(&sk->sk_zckey)) - 1; + uarg->len = 1; + uarg->bytelen = size; + uarg->zerocopy = 1; + uarg->ubuf.flags = SKBFL_ZEROCOPY_FRAG | SKBFL_DONT_ORPHAN; + refcount_set(&uarg->ubuf.refcnt, 1); + sock_hold(sk); + + return &uarg->ubuf; +} + +static inline struct sk_buff *skb_from_uarg(struct ubuf_info_msgzc *uarg) +{ + return container_of((void *)uarg, struct sk_buff, cb); +} + +struct ubuf_info *msg_zerocopy_realloc(struct sock *sk, size_t size, + struct ubuf_info *uarg) +{ + if (uarg) { + struct ubuf_info_msgzc *uarg_zc; + const u32 byte_limit = 1 << 19; /* limit to a few TSO */ + u32 bytelen, next; + + /* there might be non MSG_ZEROCOPY users */ + if (uarg->callback != msg_zerocopy_callback) + return NULL; + + /* realloc only when socket is locked (TCP, UDP cork), + * so uarg->len and sk_zckey access is serialized + */ + if (!sock_owned_by_user(sk)) { + WARN_ON_ONCE(1); + return NULL; + } + + uarg_zc = uarg_to_msgzc(uarg); + bytelen = uarg_zc->bytelen + size; + if (uarg_zc->len == USHRT_MAX - 1 || bytelen > byte_limit) { + /* TCP can create new skb to attach new uarg */ + if (sk->sk_type == SOCK_STREAM) + goto new_alloc; + return NULL; + } + + next = (u32)atomic_read(&sk->sk_zckey); + if ((u32)(uarg_zc->id + uarg_zc->len) == next) { + if (mm_account_pinned_pages(&uarg_zc->mmp, size)) + return NULL; + uarg_zc->len++; + uarg_zc->bytelen = bytelen; + atomic_set(&sk->sk_zckey, ++next); + + /* no extra ref when appending to datagram (MSG_MORE) */ + if (sk->sk_type == SOCK_STREAM) + net_zcopy_get(uarg); + + return uarg; + } + } + +new_alloc: + return msg_zerocopy_alloc(sk, size); +} +EXPORT_SYMBOL_GPL(msg_zerocopy_realloc); + +static bool skb_zerocopy_notify_extend(struct sk_buff *skb, u32 lo, u16 len) +{ + struct sock_exterr_skb *serr = SKB_EXT_ERR(skb); + u32 old_lo, old_hi; + u64 sum_len; + + old_lo = serr->ee.ee_info; + old_hi = serr->ee.ee_data; + sum_len = old_hi - old_lo + 1ULL + len; + + if (sum_len >= (1ULL << 32)) + return false; + + if (lo != old_hi + 1) + return false; + + serr->ee.ee_data += len; + return true; +} + +static void __msg_zerocopy_callback(struct ubuf_info_msgzc *uarg) +{ + struct sk_buff *tail, *skb = skb_from_uarg(uarg); + struct sock_exterr_skb *serr; + struct sock *sk = skb->sk; + struct sk_buff_head *q; + unsigned long flags; + bool is_zerocopy; + u32 lo, hi; + u16 len; + + mm_unaccount_pinned_pages(&uarg->mmp); + + /* if !len, there was only 1 call, and it was aborted + * 
so do not queue a completion notification + */ + if (!uarg->len || sock_flag(sk, SOCK_DEAD)) + goto release; + + len = uarg->len; + lo = uarg->id; + hi = uarg->id + len - 1; + is_zerocopy = uarg->zerocopy; + + serr = SKB_EXT_ERR(skb); + memset(serr, 0, sizeof(*serr)); + serr->ee.ee_errno = 0; + serr->ee.ee_origin = SO_EE_ORIGIN_ZEROCOPY; + serr->ee.ee_data = hi; + serr->ee.ee_info = lo; + if (!is_zerocopy) + serr->ee.ee_code |= SO_EE_CODE_ZEROCOPY_COPIED; + + q = &sk->sk_error_queue; + spin_lock_irqsave(&q->lock, flags); + tail = skb_peek_tail(q); + if (!tail || SKB_EXT_ERR(tail)->ee.ee_origin != SO_EE_ORIGIN_ZEROCOPY || + !skb_zerocopy_notify_extend(tail, lo, len)) { + __skb_queue_tail(q, skb); + skb = NULL; + } + spin_unlock_irqrestore(&q->lock, flags); + + sk_error_report(sk); + +release: + consume_skb(skb); + sock_put(sk); +} + +void msg_zerocopy_callback(struct sk_buff *skb, struct ubuf_info *uarg, + bool success) +{ + struct ubuf_info_msgzc *uarg_zc = uarg_to_msgzc(uarg); + + uarg_zc->zerocopy = uarg_zc->zerocopy & success; + + if (refcount_dec_and_test(&uarg->refcnt)) + __msg_zerocopy_callback(uarg_zc); +} +EXPORT_SYMBOL_GPL(msg_zerocopy_callback); + +void msg_zerocopy_put_abort(struct ubuf_info *uarg, bool have_uref) +{ + struct sock *sk = skb_from_uarg(uarg_to_msgzc(uarg))->sk; + + atomic_dec(&sk->sk_zckey); + uarg_to_msgzc(uarg)->len--; + + if (have_uref) + msg_zerocopy_callback(NULL, uarg, true); +} +EXPORT_SYMBOL_GPL(msg_zerocopy_put_abort); + +int skb_zerocopy_iter_stream(struct sock *sk, struct sk_buff *skb, + struct msghdr *msg, int len, + struct ubuf_info *uarg) +{ + struct ubuf_info *orig_uarg = skb_zcopy(skb); + int err, orig_len = skb->len; + + /* An skb can only point to one uarg. This edge case happens when + * TCP appends to an skb, but zerocopy_realloc triggered a new alloc. + */ + if (orig_uarg && uarg != orig_uarg) + return -EEXIST; + + err = __zerocopy_sg_from_iter(msg, sk, skb, &msg->msg_iter, len); + if (err == -EFAULT || (err == -EMSGSIZE && skb->len == orig_len)) { + struct sock *save_sk = skb->sk; + + /* Streams do not free skb on error. Reset to prev state. */ + iov_iter_revert(&msg->msg_iter, skb->len - orig_len); + skb->sk = sk; + ___pskb_trim(skb, orig_len); + skb->sk = save_sk; + return err; + } + + skb_zcopy_set(skb, uarg, NULL); + return skb->len - orig_len; +} +EXPORT_SYMBOL_GPL(skb_zerocopy_iter_stream); + +void __skb_zcopy_downgrade_managed(struct sk_buff *skb) +{ + int i; + + skb_shinfo(skb)->flags &= ~SKBFL_MANAGED_FRAG_REFS; + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) + skb_frag_ref(skb, i); +} +EXPORT_SYMBOL_GPL(__skb_zcopy_downgrade_managed); + +static int skb_zerocopy_clone(struct sk_buff *nskb, struct sk_buff *orig, + gfp_t gfp_mask) +{ + if (skb_zcopy(orig)) { + if (skb_zcopy(nskb)) { + /* !gfp_mask callers are verified to !skb_zcopy(nskb) */ + if (!gfp_mask) { + WARN_ON_ONCE(1); + return -ENOMEM; + } + if (skb_uarg(nskb) == skb_uarg(orig)) + return 0; + if (skb_copy_ubufs(nskb, GFP_ATOMIC)) + return -EIO; + } + skb_zcopy_set(nskb, skb_uarg(orig), NULL); + } + return 0; +} + +/** + * skb_copy_ubufs - copy userspace skb frags buffers to kernel + * @skb: the skb to modify + * @gfp_mask: allocation priority + * + * This must be called on skb with SKBFL_ZEROCOPY_ENABLE. + * It will copy all frags into kernel and drop the reference + * to userspace pages. + * + * If this function is called from an interrupt gfp_mask() must be + * %GFP_ATOMIC. 
+ * + * Returns 0 on success or a negative error code on failure + * to allocate kernel memory to copy to. + */ +int skb_copy_ubufs(struct sk_buff *skb, gfp_t gfp_mask) +{ + int num_frags = skb_shinfo(skb)->nr_frags; + struct page *page, *head = NULL; + int i, order, psize, new_frags; + u32 d_off; + + if (skb_shared(skb) || skb_unclone(skb, gfp_mask)) + return -EINVAL; + + if (!num_frags) + goto release; + + /* We might have to allocate high order pages, so compute what minimum + * page order is needed. + */ + order = 0; + while ((PAGE_SIZE << order) * MAX_SKB_FRAGS < __skb_pagelen(skb)) + order++; + psize = (PAGE_SIZE << order); + + new_frags = (__skb_pagelen(skb) + psize - 1) >> (PAGE_SHIFT + order); + for (i = 0; i < new_frags; i++) { + page = alloc_pages(gfp_mask | __GFP_COMP, order); + if (!page) { + while (head) { + struct page *next = (struct page *)page_private(head); + put_page(head); + head = next; + } + return -ENOMEM; + } + set_page_private(page, (unsigned long)head); + head = page; + } + + page = head; + d_off = 0; + for (i = 0; i < num_frags; i++) { + skb_frag_t *f = &skb_shinfo(skb)->frags[i]; + u32 p_off, p_len, copied; + struct page *p; + u8 *vaddr; + + skb_frag_foreach_page(f, skb_frag_off(f), skb_frag_size(f), + p, p_off, p_len, copied) { + u32 copy, done = 0; + vaddr = kmap_atomic(p); + + while (done < p_len) { + if (d_off == psize) { + d_off = 0; + page = (struct page *)page_private(page); + } + copy = min_t(u32, psize - d_off, p_len - done); + memcpy(page_address(page) + d_off, + vaddr + p_off + done, copy); + done += copy; + d_off += copy; + } + kunmap_atomic(vaddr); + } + } + + /* skb frags release userspace buffers */ + for (i = 0; i < num_frags; i++) + skb_frag_unref(skb, i); + + /* skb frags point to kernel buffers */ + for (i = 0; i < new_frags - 1; i++) { + __skb_fill_page_desc(skb, i, head, 0, psize); + head = (struct page *)page_private(head); + } + __skb_fill_page_desc(skb, new_frags - 1, head, 0, d_off); + skb_shinfo(skb)->nr_frags = new_frags; + +release: + skb_zcopy_clear(skb, false); + return 0; +} +EXPORT_SYMBOL_GPL(skb_copy_ubufs); + +/** + * skb_clone - duplicate an sk_buff + * @skb: buffer to clone + * @gfp_mask: allocation priority + * + * Duplicate an &sk_buff. The new one is not owned by a socket. Both + * copies share the same packet data but not structure. The new + * buffer has a reference count of 1. If the allocation fails the + * function returns %NULL otherwise the new buffer is returned. + * + * If this function is called from an interrupt gfp_mask() must be + * %GFP_ATOMIC. 
+ */ + +struct sk_buff *skb_clone(struct sk_buff *skb, gfp_t gfp_mask) +{ + struct sk_buff_fclones *fclones = container_of(skb, + struct sk_buff_fclones, + skb1); + struct sk_buff *n; + + if (skb_orphan_frags(skb, gfp_mask)) + return NULL; + + if (skb->fclone == SKB_FCLONE_ORIG && + refcount_read(&fclones->fclone_ref) == 1) { + n = &fclones->skb2; + refcount_set(&fclones->fclone_ref, 2); + n->fclone = SKB_FCLONE_CLONE; + } else { + if (skb_pfmemalloc(skb)) + gfp_mask |= __GFP_MEMALLOC; + + n = kmem_cache_alloc(skbuff_head_cache, gfp_mask); + if (!n) + return NULL; + + n->fclone = SKB_FCLONE_UNAVAILABLE; + } + + return __skb_clone(n, skb); +} +EXPORT_SYMBOL(skb_clone); + +void skb_headers_offset_update(struct sk_buff *skb, int off) +{ + /* Only adjust this if it actually is csum_start rather than csum */ + if (skb->ip_summed == CHECKSUM_PARTIAL) + skb->csum_start += off; + /* {transport,network,mac}_header and tail are relative to skb->head */ + skb->transport_header += off; + skb->network_header += off; + if (skb_mac_header_was_set(skb)) + skb->mac_header += off; + skb->inner_transport_header += off; + skb->inner_network_header += off; + skb->inner_mac_header += off; +} +EXPORT_SYMBOL(skb_headers_offset_update); + +void skb_copy_header(struct sk_buff *new, const struct sk_buff *old) +{ + __copy_skb_header(new, old); + + skb_shinfo(new)->gso_size = skb_shinfo(old)->gso_size; + skb_shinfo(new)->gso_segs = skb_shinfo(old)->gso_segs; + skb_shinfo(new)->gso_type = skb_shinfo(old)->gso_type; +} +EXPORT_SYMBOL(skb_copy_header); + +static inline int skb_alloc_rx_flag(const struct sk_buff *skb) +{ + if (skb_pfmemalloc(skb)) + return SKB_ALLOC_RX; + return 0; +} + +/** + * skb_copy - create private copy of an sk_buff + * @skb: buffer to copy + * @gfp_mask: allocation priority + * + * Make a copy of both an &sk_buff and its data. This is used when the + * caller wishes to modify the data and needs a private copy of the + * data to alter. Returns %NULL on failure or the pointer to the buffer + * on success. The returned buffer has a reference count of 1. + * + * As by-product this function converts non-linear &sk_buff to linear + * one, so that &sk_buff becomes completely private and caller is allowed + * to modify all the data of returned buffer. This means that this + * function is not recommended for use in circumstances when only + * header is going to be modified. Use pskb_copy() instead. + */ + +struct sk_buff *skb_copy(const struct sk_buff *skb, gfp_t gfp_mask) +{ + int headerlen = skb_headroom(skb); + unsigned int size = skb_end_offset(skb) + skb->data_len; + struct sk_buff *n = __alloc_skb(size, gfp_mask, + skb_alloc_rx_flag(skb), NUMA_NO_NODE); + + if (!n) + return NULL; + + /* Set the data pointer */ + skb_reserve(n, headerlen); + /* Set the tail pointer and length */ + skb_put(n, skb->len); + + BUG_ON(skb_copy_bits(skb, -headerlen, n->head, headerlen + skb->len)); + + skb_copy_header(n, skb); + return n; +} +EXPORT_SYMBOL(skb_copy); + +/** + * __pskb_copy_fclone - create copy of an sk_buff with private head. + * @skb: buffer to copy + * @headroom: headroom of new skb + * @gfp_mask: allocation priority + * @fclone: if true allocate the copy of the skb from the fclone + * cache instead of the head cache; it is recommended to set this + * to true for the cases where the copy will likely be cloned + * + * Make a copy of both an &sk_buff and part of its data, located + * in header. Fragmented data remain shared. 
This is used when + * the caller wishes to modify only header of &sk_buff and needs + * private copy of the header to alter. Returns %NULL on failure + * or the pointer to the buffer on success. + * The returned buffer has a reference count of 1. + */ + +struct sk_buff *__pskb_copy_fclone(struct sk_buff *skb, int headroom, + gfp_t gfp_mask, bool fclone) +{ + unsigned int size = skb_headlen(skb) + headroom; + int flags = skb_alloc_rx_flag(skb) | (fclone ? SKB_ALLOC_FCLONE : 0); + struct sk_buff *n = __alloc_skb(size, gfp_mask, flags, NUMA_NO_NODE); + + if (!n) + goto out; + + /* Set the data pointer */ + skb_reserve(n, headroom); + /* Set the tail pointer and length */ + skb_put(n, skb_headlen(skb)); + /* Copy the bytes */ + skb_copy_from_linear_data(skb, n->data, n->len); + + n->truesize += skb->data_len; + n->data_len = skb->data_len; + n->len = skb->len; + + if (skb_shinfo(skb)->nr_frags) { + int i; + + if (skb_orphan_frags(skb, gfp_mask) || + skb_zerocopy_clone(n, skb, gfp_mask)) { + kfree_skb(n); + n = NULL; + goto out; + } + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_shinfo(n)->frags[i] = skb_shinfo(skb)->frags[i]; + skb_frag_ref(skb, i); + } + skb_shinfo(n)->nr_frags = i; + } + + if (skb_has_frag_list(skb)) { + skb_shinfo(n)->frag_list = skb_shinfo(skb)->frag_list; + skb_clone_fraglist(n); + } + + skb_copy_header(n, skb); +out: + return n; +} +EXPORT_SYMBOL(__pskb_copy_fclone); + +/** + * pskb_expand_head - reallocate header of &sk_buff + * @skb: buffer to reallocate + * @nhead: room to add at head + * @ntail: room to add at tail + * @gfp_mask: allocation priority + * + * Expands (or creates identical copy, if @nhead and @ntail are zero) + * header of @skb. &sk_buff itself is not changed. &sk_buff MUST have + * reference count of 1. Returns zero in the case of success or error, + * if expansion failed. In the last case, &sk_buff is not changed. + * + * All the pointers pointing into skb header may change and must be + * reloaded after call to this function. + */ + +int pskb_expand_head(struct sk_buff *skb, int nhead, int ntail, + gfp_t gfp_mask) +{ + unsigned int osize = skb_end_offset(skb); + unsigned int size = osize + nhead + ntail; + long off; + u8 *data; + int i; + + BUG_ON(nhead < 0); + + BUG_ON(skb_shared(skb)); + + skb_zcopy_downgrade_managed(skb); + + if (skb_pfmemalloc(skb)) + gfp_mask |= __GFP_MEMALLOC; + + data = kmalloc_reserve(&size, gfp_mask, NUMA_NO_NODE, NULL); + if (!data) + goto nodata; + size = SKB_WITH_OVERHEAD(size); + + /* Copy only real data... and, alas, header. This should be + * optimized for the cases when header is void. 
+ */ + memcpy(data + nhead, skb->head, skb_tail_pointer(skb) - skb->head); + + memcpy((struct skb_shared_info *)(data + size), + skb_shinfo(skb), + offsetof(struct skb_shared_info, frags[skb_shinfo(skb)->nr_frags])); + + /* + * if shinfo is shared we must drop the old head gracefully, but if it + * is not we can just drop the old head and let the existing refcount + * be since all we did is relocate the values + */ + if (skb_cloned(skb)) { + if (skb_orphan_frags(skb, gfp_mask)) + goto nofrags; + if (skb_zcopy(skb)) + refcount_inc(&skb_uarg(skb)->refcnt); + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) + skb_frag_ref(skb, i); + + if (skb_has_frag_list(skb)) + skb_clone_fraglist(skb); + + skb_release_data(skb); + } else { + skb_free_head(skb); + } + off = (data + nhead) - skb->head; + + skb->head = data; + skb->head_frag = 0; + skb->data += off; + + skb_set_end_offset(skb, size); +#ifdef NET_SKBUFF_DATA_USES_OFFSET + off = nhead; +#endif + skb->tail += off; + skb_headers_offset_update(skb, nhead); + skb->cloned = 0; + skb->hdr_len = 0; + skb->nohdr = 0; + atomic_set(&skb_shinfo(skb)->dataref, 1); + + skb_metadata_clear(skb); + + /* It is not generally safe to change skb->truesize. + * For the moment, we really care of rx path, or + * when skb is orphaned (not attached to a socket). + */ + if (!skb->sk || skb->destructor == sock_edemux) + skb->truesize += size - osize; + + return 0; + +nofrags: + kfree(data); +nodata: + return -ENOMEM; +} +EXPORT_SYMBOL(pskb_expand_head); + +/* Make private copy of skb with writable head and some headroom */ + +struct sk_buff *skb_realloc_headroom(struct sk_buff *skb, unsigned int headroom) +{ + struct sk_buff *skb2; + int delta = headroom - skb_headroom(skb); + + if (delta <= 0) + skb2 = pskb_copy(skb, GFP_ATOMIC); + else { + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2 && pskb_expand_head(skb2, SKB_DATA_ALIGN(delta), 0, + GFP_ATOMIC)) { + kfree_skb(skb2); + skb2 = NULL; + } + } + return skb2; +} +EXPORT_SYMBOL(skb_realloc_headroom); + +int __skb_unclone_keeptruesize(struct sk_buff *skb, gfp_t pri) +{ + unsigned int saved_end_offset, saved_truesize; + struct skb_shared_info *shinfo; + int res; + + saved_end_offset = skb_end_offset(skb); + saved_truesize = skb->truesize; + + res = pskb_expand_head(skb, 0, 0, pri); + if (res) + return res; + + skb->truesize = saved_truesize; + + if (likely(skb_end_offset(skb) == saved_end_offset)) + return 0; + + shinfo = skb_shinfo(skb); + + /* We are about to change back skb->end, + * we need to move skb_shinfo() to its new location. + */ + memmove(skb->head + saved_end_offset, + shinfo, + offsetof(struct skb_shared_info, frags[shinfo->nr_frags])); + + skb_set_end_offset(skb, saved_end_offset); + + return 0; +} + +/** + * skb_expand_head - reallocate header of &sk_buff + * @skb: buffer to reallocate + * @headroom: needed headroom + * + * Unlike skb_realloc_headroom, this one does not allocate a new skb + * if possible; copies skb->sk to new skb as needed + * and frees original skb in case of failures. + * + * It expect increased headroom and generates warning otherwise. + */ + +struct sk_buff *skb_expand_head(struct sk_buff *skb, unsigned int headroom) +{ + int delta = headroom - skb_headroom(skb); + int osize = skb_end_offset(skb); + struct sock *sk = skb->sk; + + if (WARN_ONCE(delta <= 0, + "%s is expecting an increase in the headroom", __func__)) + return skb; + + delta = SKB_DATA_ALIGN(delta); + /* pskb_expand_head() might crash, if skb is shared. 
*/ + if (skb_shared(skb) || !is_skb_wmem(skb)) { + struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC); + + if (unlikely(!nskb)) + goto fail; + + if (sk) + skb_set_owner_w(nskb, sk); + consume_skb(skb); + skb = nskb; + } + if (pskb_expand_head(skb, delta, 0, GFP_ATOMIC)) + goto fail; + + if (sk && is_skb_wmem(skb)) { + delta = skb_end_offset(skb) - osize; + refcount_add(delta, &sk->sk_wmem_alloc); + skb->truesize += delta; + } + return skb; + +fail: + kfree_skb(skb); + return NULL; +} +EXPORT_SYMBOL(skb_expand_head); + +/** + * skb_copy_expand - copy and expand sk_buff + * @skb: buffer to copy + * @newheadroom: new free bytes at head + * @newtailroom: new free bytes at tail + * @gfp_mask: allocation priority + * + * Make a copy of both an &sk_buff and its data and while doing so + * allocate additional space. + * + * This is used when the caller wishes to modify the data and needs a + * private copy of the data to alter as well as more space for new fields. + * Returns %NULL on failure or the pointer to the buffer + * on success. The returned buffer has a reference count of 1. + * + * You must pass %GFP_ATOMIC as the allocation priority if this function + * is called from an interrupt. + */ +struct sk_buff *skb_copy_expand(const struct sk_buff *skb, + int newheadroom, int newtailroom, + gfp_t gfp_mask) +{ + /* + * Allocate the copy buffer + */ + struct sk_buff *n = __alloc_skb(newheadroom + skb->len + newtailroom, + gfp_mask, skb_alloc_rx_flag(skb), + NUMA_NO_NODE); + int oldheadroom = skb_headroom(skb); + int head_copy_len, head_copy_off; + + if (!n) + return NULL; + + skb_reserve(n, newheadroom); + + /* Set the tail pointer and length */ + skb_put(n, skb->len); + + head_copy_len = oldheadroom; + head_copy_off = 0; + if (newheadroom <= head_copy_len) + head_copy_len = newheadroom; + else + head_copy_off = newheadroom - head_copy_len; + + /* Copy the linear header and data. */ + BUG_ON(skb_copy_bits(skb, -head_copy_len, n->head + head_copy_off, + skb->len + head_copy_len)); + + skb_copy_header(n, skb); + + skb_headers_offset_update(n, newheadroom - oldheadroom); + + return n; +} +EXPORT_SYMBOL(skb_copy_expand); + +/** + * __skb_pad - zero pad the tail of an skb + * @skb: buffer to pad + * @pad: space to pad + * @free_on_error: free buffer on error + * + * Ensure that a buffer is followed by a padding area that is zero + * filled. Used by network drivers which may DMA or transfer data + * beyond the buffer end onto the wire. + * + * May return error in out of memory cases. The skb is freed on error + * if @free_on_error is true. + */ + +int __skb_pad(struct sk_buff *skb, int pad, bool free_on_error) +{ + int err; + int ntail; + + /* If the skbuff is non linear tailroom is always zero.. */ + if (!skb_cloned(skb) && skb_tailroom(skb) >= pad) { + memset(skb->data+skb->len, 0, pad); + return 0; + } + + ntail = skb->data_len + pad - (skb->end - skb->tail); + if (likely(skb_cloned(skb) || ntail > 0)) { + err = pskb_expand_head(skb, 0, ntail, GFP_ATOMIC); + if (unlikely(err)) + goto free_skb; + } + + /* FIXME: The use of this function with non-linear skb's really needs + * to be audited. 
+ */ + err = skb_linearize(skb); + if (unlikely(err)) + goto free_skb; + + memset(skb->data + skb->len, 0, pad); + return 0; + +free_skb: + if (free_on_error) + kfree_skb(skb); + return err; +} +EXPORT_SYMBOL(__skb_pad); + +/** + * pskb_put - add data to the tail of a potentially fragmented buffer + * @skb: start of the buffer to use + * @tail: tail fragment of the buffer to use + * @len: amount of data to add + * + * This function extends the used data area of the potentially + * fragmented buffer. @tail must be the last fragment of @skb -- or + * @skb itself. If this would exceed the total buffer size the kernel + * will panic. A pointer to the first byte of the extra data is + * returned. + */ + +void *pskb_put(struct sk_buff *skb, struct sk_buff *tail, int len) +{ + if (tail != skb) { + skb->data_len += len; + skb->len += len; + } + return skb_put(tail, len); +} +EXPORT_SYMBOL_GPL(pskb_put); + +/** + * skb_put - add data to a buffer + * @skb: buffer to use + * @len: amount of data to add + * + * This function extends the used data area of the buffer. If this would + * exceed the total buffer size the kernel will panic. A pointer to the + * first byte of the extra data is returned. + */ +void *skb_put(struct sk_buff *skb, unsigned int len) +{ + void *tmp = skb_tail_pointer(skb); + SKB_LINEAR_ASSERT(skb); + skb->tail += len; + skb->len += len; + if (unlikely(skb->tail > skb->end)) + skb_over_panic(skb, len, __builtin_return_address(0)); + return tmp; +} +EXPORT_SYMBOL(skb_put); + +/** + * skb_push - add data to the start of a buffer + * @skb: buffer to use + * @len: amount of data to add + * + * This function extends the used data area of the buffer at the buffer + * start. If this would exceed the total buffer headroom the kernel will + * panic. A pointer to the first byte of the extra data is returned. + */ +void *skb_push(struct sk_buff *skb, unsigned int len) +{ + skb->data -= len; + skb->len += len; + if (unlikely(skb->data < skb->head)) + skb_under_panic(skb, len, __builtin_return_address(0)); + return skb->data; +} +EXPORT_SYMBOL(skb_push); + +/** + * skb_pull - remove data from the start of a buffer + * @skb: buffer to use + * @len: amount of data to remove + * + * This function removes data from the start of a buffer, returning + * the memory to the headroom. A pointer to the next data in the buffer + * is returned. Once the data has been pulled future pushes will overwrite + * the old data. + */ +void *skb_pull(struct sk_buff *skb, unsigned int len) +{ + return skb_pull_inline(skb, len); +} +EXPORT_SYMBOL(skb_pull); + +/** + * skb_pull_data - remove data from the start of a buffer returning its + * original position. + * @skb: buffer to use + * @len: amount of data to remove + * + * This function removes data from the start of a buffer, returning + * the memory to the headroom. A pointer to the original data in the buffer + * is returned after checking if there is enough data to pull. Once the + * data has been pulled future pushes will overwrite the old data. + */ +void *skb_pull_data(struct sk_buff *skb, size_t len) +{ + void *data = skb->data; + + if (skb->len < len) + return NULL; + + skb_pull(skb, len); + + return data; +} +EXPORT_SYMBOL(skb_pull_data); + +/** + * skb_trim - remove end from a buffer + * @skb: buffer to alter + * @len: new length + * + * Cut the length of a buffer down by removing data from the tail. If + * the buffer is already under the length specified it is not modified. + * The skb must be linear. 
+ */ +void skb_trim(struct sk_buff *skb, unsigned int len) +{ + if (skb->len > len) + __skb_trim(skb, len); +} +EXPORT_SYMBOL(skb_trim); + +/* Trims skb to length len. It can change skb pointers. + */ + +int ___pskb_trim(struct sk_buff *skb, unsigned int len) +{ + struct sk_buff **fragp; + struct sk_buff *frag; + int offset = skb_headlen(skb); + int nfrags = skb_shinfo(skb)->nr_frags; + int i; + int err; + + if (skb_cloned(skb) && + unlikely((err = pskb_expand_head(skb, 0, 0, GFP_ATOMIC)))) + return err; + + i = 0; + if (offset >= len) + goto drop_pages; + + for (; i < nfrags; i++) { + int end = offset + skb_frag_size(&skb_shinfo(skb)->frags[i]); + + if (end < len) { + offset = end; + continue; + } + + skb_frag_size_set(&skb_shinfo(skb)->frags[i++], len - offset); + +drop_pages: + skb_shinfo(skb)->nr_frags = i; + + for (; i < nfrags; i++) + skb_frag_unref(skb, i); + + if (skb_has_frag_list(skb)) + skb_drop_fraglist(skb); + goto done; + } + + for (fragp = &skb_shinfo(skb)->frag_list; (frag = *fragp); + fragp = &frag->next) { + int end = offset + frag->len; + + if (skb_shared(frag)) { + struct sk_buff *nfrag; + + nfrag = skb_clone(frag, GFP_ATOMIC); + if (unlikely(!nfrag)) + return -ENOMEM; + + nfrag->next = frag->next; + consume_skb(frag); + frag = nfrag; + *fragp = frag; + } + + if (end < len) { + offset = end; + continue; + } + + if (end > len && + unlikely((err = pskb_trim(frag, len - offset)))) + return err; + + if (frag->next) + skb_drop_list(&frag->next); + break; + } + +done: + if (len > skb_headlen(skb)) { + skb->data_len -= skb->len - len; + skb->len = len; + } else { + skb->len = len; + skb->data_len = 0; + skb_set_tail_pointer(skb, len); + } + + if (!skb->sk || skb->destructor == sock_edemux) + skb_condense(skb); + return 0; +} +EXPORT_SYMBOL(___pskb_trim); + +/* Note : use pskb_trim_rcsum() instead of calling this directly + */ +int pskb_trim_rcsum_slow(struct sk_buff *skb, unsigned int len) +{ + if (skb->ip_summed == CHECKSUM_COMPLETE) { + int delta = skb->len - len; + + skb->csum = csum_block_sub(skb->csum, + skb_checksum(skb, len, delta, 0), + len); + } else if (skb->ip_summed == CHECKSUM_PARTIAL) { + int hdlen = (len > skb_headlen(skb)) ? skb_headlen(skb) : len; + int offset = skb_checksum_start_offset(skb) + skb->csum_offset; + + if (offset + sizeof(__sum16) > hdlen) + return -EINVAL; + } + return __pskb_trim(skb, len); +} +EXPORT_SYMBOL(pskb_trim_rcsum_slow); + +/** + * __pskb_pull_tail - advance tail of skb header + * @skb: buffer to reallocate + * @delta: number of bytes to advance tail + * + * The function makes a sense only on a fragmented &sk_buff, + * it expands header moving its tail forward and copying necessary + * data from fragmented part. + * + * &sk_buff MUST have reference count of 1. + * + * Returns %NULL (and &sk_buff does not change) if pull failed + * or value of new tail of skb in the case of success. + * + * All the pointers pointing into skb header may change and must be + * reloaded after call to this function. + */ + +/* Moves tail of skb head forward, copying data from fragmented part, + * when it is necessary. + * 1. It may fail due to malloc failure. + * 2. It may change skb pointers. + * + * It is pretty complicated. Luckily, it is called only in exceptional cases. + */ +void *__pskb_pull_tail(struct sk_buff *skb, int delta) +{ + /* If skb has not enough free space at tail, get new one + * plus 128 bytes for future expansions. If we have enough + * room at tail, reallocate without expansion only if skb is cloned. 
+ */ + int i, k, eat = (skb->tail + delta) - skb->end; + + if (eat > 0 || skb_cloned(skb)) { + if (pskb_expand_head(skb, 0, eat > 0 ? eat + 128 : 0, + GFP_ATOMIC)) + return NULL; + } + + BUG_ON(skb_copy_bits(skb, skb_headlen(skb), + skb_tail_pointer(skb), delta)); + + /* Optimization: no fragments, no reasons to preestimate + * size of pulled pages. Superb. + */ + if (!skb_has_frag_list(skb)) + goto pull_pages; + + /* Estimate size of pulled pages. */ + eat = delta; + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int size = skb_frag_size(&skb_shinfo(skb)->frags[i]); + + if (size >= eat) + goto pull_pages; + eat -= size; + } + + /* If we need update frag list, we are in troubles. + * Certainly, it is possible to add an offset to skb data, + * but taking into account that pulling is expected to + * be very rare operation, it is worth to fight against + * further bloating skb head and crucify ourselves here instead. + * Pure masohism, indeed. 8)8) + */ + if (eat) { + struct sk_buff *list = skb_shinfo(skb)->frag_list; + struct sk_buff *clone = NULL; + struct sk_buff *insp = NULL; + + do { + if (list->len <= eat) { + /* Eaten as whole. */ + eat -= list->len; + list = list->next; + insp = list; + } else { + /* Eaten partially. */ + if (skb_is_gso(skb) && !list->head_frag && + skb_headlen(list)) + skb_shinfo(skb)->gso_type |= SKB_GSO_DODGY; + + if (skb_shared(list)) { + /* Sucks! We need to fork list. :-( */ + clone = skb_clone(list, GFP_ATOMIC); + if (!clone) + return NULL; + insp = list->next; + list = clone; + } else { + /* This may be pulled without + * problems. */ + insp = list; + } + if (!pskb_pull(list, eat)) { + kfree_skb(clone); + return NULL; + } + break; + } + } while (eat); + + /* Free pulled out fragments. */ + while ((list = skb_shinfo(skb)->frag_list) != insp) { + skb_shinfo(skb)->frag_list = list->next; + consume_skb(list); + } + /* And insert new clone at head. */ + if (clone) { + clone->next = list; + skb_shinfo(skb)->frag_list = clone; + } + } + /* Success! Now we may commit changes to skb data. */ + +pull_pages: + eat = delta; + k = 0; + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int size = skb_frag_size(&skb_shinfo(skb)->frags[i]); + + if (size <= eat) { + skb_frag_unref(skb, i); + eat -= size; + } else { + skb_frag_t *frag = &skb_shinfo(skb)->frags[k]; + + *frag = skb_shinfo(skb)->frags[i]; + if (eat) { + skb_frag_off_add(frag, eat); + skb_frag_size_sub(frag, eat); + if (!i) + goto end; + eat = 0; + } + k++; + } + } + skb_shinfo(skb)->nr_frags = k; + +end: + skb->tail += delta; + skb->data_len -= delta; + + if (!skb->data_len) + skb_zcopy_clear(skb, false); + + return skb_tail_pointer(skb); +} +EXPORT_SYMBOL(__pskb_pull_tail); + +/** + * skb_copy_bits - copy bits from skb to kernel buffer + * @skb: source skb + * @offset: offset in source + * @to: destination buffer + * @len: number of bytes to copy + * + * Copy the specified number of bytes from the source skb to the + * destination buffer. + * + * CAUTION ! : + * If its prototype is ever changed, + * check arch/{*}/net/{*}.S files, + * since it is called from BPF assembly code. + */ +int skb_copy_bits(const struct sk_buff *skb, int offset, void *to, int len) +{ + int start = skb_headlen(skb); + struct sk_buff *frag_iter; + int i, copy; + + if (offset > (int)skb->len - len) + goto fault; + + /* Copy header. 
*/ + if ((copy = start - offset) > 0) { + if (copy > len) + copy = len; + skb_copy_from_linear_data_offset(skb, offset, to, copy); + if ((len -= copy) == 0) + return 0; + offset += copy; + to += copy; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + skb_frag_t *f = &skb_shinfo(skb)->frags[i]; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(f); + if ((copy = end - offset) > 0) { + u32 p_off, p_len, copied; + struct page *p; + u8 *vaddr; + + if (copy > len) + copy = len; + + skb_frag_foreach_page(f, + skb_frag_off(f) + offset - start, + copy, p, p_off, p_len, copied) { + vaddr = kmap_atomic(p); + memcpy(to + copied, vaddr + p_off, p_len); + kunmap_atomic(vaddr); + } + + if ((len -= copy) == 0) + return 0; + offset += copy; + to += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + if (skb_copy_bits(frag_iter, offset - start, to, copy)) + goto fault; + if ((len -= copy) == 0) + return 0; + offset += copy; + to += copy; + } + start = end; + } + + if (!len) + return 0; + +fault: + return -EFAULT; +} +EXPORT_SYMBOL(skb_copy_bits); + +/* + * Callback from splice_to_pipe(), if we need to release some pages + * at the end of the spd in case we error'ed out in filling the pipe. + */ +static void sock_spd_release(struct splice_pipe_desc *spd, unsigned int i) +{ + put_page(spd->pages[i]); +} + +static struct page *linear_to_page(struct page *page, unsigned int *len, + unsigned int *offset, + struct sock *sk) +{ + struct page_frag *pfrag = sk_page_frag(sk); + + if (!sk_page_frag_refill(sk, pfrag)) + return NULL; + + *len = min_t(unsigned int, *len, pfrag->size - pfrag->offset); + + memcpy(page_address(pfrag->page) + pfrag->offset, + page_address(page) + *offset, *len); + *offset = pfrag->offset; + pfrag->offset += *len; + + return pfrag->page; +} + +static bool spd_can_coalesce(const struct splice_pipe_desc *spd, + struct page *page, + unsigned int offset) +{ + return spd->nr_pages && + spd->pages[spd->nr_pages - 1] == page && + (spd->partial[spd->nr_pages - 1].offset + + spd->partial[spd->nr_pages - 1].len == offset); +} + +/* + * Fill page/offset/length into spd, if it can hold more pages. 
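A short usage sketch for skb_copy_bits() (illustrative, not part of the patched file); the allocation size and GFP flag are arbitrary choices:

#include <linux/skbuff.h>
#include <linux/slab.h>

/* Hypothetical sketch: copy @len bytes starting at @offset out of a
 * possibly fragmented skb into a freshly allocated flat buffer;
 * skb_copy_bits() walks the linear area, page frags and frag list.
 */
static void *example_copy_payload(const struct sk_buff *skb, int offset, int len)
{
        void *buf = kmalloc(len, GFP_ATOMIC);

        if (!buf)
                return NULL;

        if (skb_copy_bits(skb, offset, buf, len)) {
                kfree(buf);             /* offset/len fell outside the skb */
                return NULL;
        }
        return buf;
}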
+ */ +static bool spd_fill_page(struct splice_pipe_desc *spd, + struct pipe_inode_info *pipe, struct page *page, + unsigned int *len, unsigned int offset, + bool linear, + struct sock *sk) +{ + if (unlikely(spd->nr_pages == MAX_SKB_FRAGS)) + return true; + + if (linear) { + page = linear_to_page(page, len, &offset, sk); + if (!page) + return true; + } + if (spd_can_coalesce(spd, page, offset)) { + spd->partial[spd->nr_pages - 1].len += *len; + return false; + } + get_page(page); + spd->pages[spd->nr_pages] = page; + spd->partial[spd->nr_pages].len = *len; + spd->partial[spd->nr_pages].offset = offset; + spd->nr_pages++; + + return false; +} + +static bool __splice_segment(struct page *page, unsigned int poff, + unsigned int plen, unsigned int *off, + unsigned int *len, + struct splice_pipe_desc *spd, bool linear, + struct sock *sk, + struct pipe_inode_info *pipe) +{ + if (!*len) + return true; + + /* skip this segment if already processed */ + if (*off >= plen) { + *off -= plen; + return false; + } + + /* ignore any bits we already processed */ + poff += *off; + plen -= *off; + *off = 0; + + do { + unsigned int flen = min(*len, plen); + + if (spd_fill_page(spd, pipe, page, &flen, poff, + linear, sk)) + return true; + poff += flen; + plen -= flen; + *len -= flen; + } while (*len && plen); + + return false; +} + +/* + * Map linear and fragment data from the skb to spd. It reports true if the + * pipe is full or if we already spliced the requested length. + */ +static bool __skb_splice_bits(struct sk_buff *skb, struct pipe_inode_info *pipe, + unsigned int *offset, unsigned int *len, + struct splice_pipe_desc *spd, struct sock *sk) +{ + int seg; + struct sk_buff *iter; + + /* map the linear part : + * If skb->head_frag is set, this 'linear' part is backed by a + * fragment, and if the head is not shared with any clones then + * we can avoid a copy since we own the head portion of this page. + */ + if (__splice_segment(virt_to_page(skb->data), + (unsigned long) skb->data & (PAGE_SIZE - 1), + skb_headlen(skb), + offset, len, spd, + skb_head_is_locked(skb), + sk, pipe)) + return true; + + /* + * then map the fragments + */ + for (seg = 0; seg < skb_shinfo(skb)->nr_frags; seg++) { + const skb_frag_t *f = &skb_shinfo(skb)->frags[seg]; + + if (__splice_segment(skb_frag_page(f), + skb_frag_off(f), skb_frag_size(f), + offset, len, spd, false, sk, pipe)) + return true; + } + + skb_walk_frags(skb, iter) { + if (*offset >= iter->len) { + *offset -= iter->len; + continue; + } + /* __skb_splice_bits() only fails if the output has no room + * left, so no point in going over the frag_list for the error + * case. + */ + if (__skb_splice_bits(iter, pipe, offset, len, spd, sk)) + return true; + } + + return false; +} + +/* + * Map data from the skb to a pipe. Should handle both the linear part, + * the fragments, and the frag list. 
+ */ +int skb_splice_bits(struct sk_buff *skb, struct sock *sk, unsigned int offset, + struct pipe_inode_info *pipe, unsigned int tlen, + unsigned int flags) +{ + struct partial_page partial[MAX_SKB_FRAGS]; + struct page *pages[MAX_SKB_FRAGS]; + struct splice_pipe_desc spd = { + .pages = pages, + .partial = partial, + .nr_pages_max = MAX_SKB_FRAGS, + .ops = &nosteal_pipe_buf_ops, + .spd_release = sock_spd_release, + }; + int ret = 0; + + __skb_splice_bits(skb, pipe, &offset, &tlen, &spd, sk); + + if (spd.nr_pages) + ret = splice_to_pipe(pipe, &spd); + + return ret; +} +EXPORT_SYMBOL_GPL(skb_splice_bits); + +static int sendmsg_unlocked(struct sock *sk, struct msghdr *msg, + struct kvec *vec, size_t num, size_t size) +{ + struct socket *sock = sk->sk_socket; + + if (!sock) + return -EINVAL; + return kernel_sendmsg(sock, msg, vec, num, size); +} + +static int sendpage_unlocked(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct socket *sock = sk->sk_socket; + + if (!sock) + return -EINVAL; + return kernel_sendpage(sock, page, offset, size, flags); +} + +typedef int (*sendmsg_func)(struct sock *sk, struct msghdr *msg, + struct kvec *vec, size_t num, size_t size); +typedef int (*sendpage_func)(struct sock *sk, struct page *page, int offset, + size_t size, int flags); +static int __skb_send_sock(struct sock *sk, struct sk_buff *skb, int offset, + int len, sendmsg_func sendmsg, sendpage_func sendpage) +{ + unsigned int orig_len = len; + struct sk_buff *head = skb; + unsigned short fragidx; + int slen, ret; + +do_frag_list: + + /* Deal with head data */ + while (offset < skb_headlen(skb) && len) { + struct kvec kv; + struct msghdr msg; + + slen = min_t(int, len, skb_headlen(skb) - offset); + kv.iov_base = skb->data + offset; + kv.iov_len = slen; + memset(&msg, 0, sizeof(msg)); + msg.msg_flags = MSG_DONTWAIT; + + ret = INDIRECT_CALL_2(sendmsg, kernel_sendmsg_locked, + sendmsg_unlocked, sk, &msg, &kv, 1, slen); + if (ret <= 0) + goto error; + + offset += ret; + len -= ret; + } + + /* All the data was skb head? */ + if (!len) + goto out; + + /* Make offset relative to start of frags */ + offset -= skb_headlen(skb); + + /* Find where we are in frag list */ + for (fragidx = 0; fragidx < skb_shinfo(skb)->nr_frags; fragidx++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[fragidx]; + + if (offset < skb_frag_size(frag)) + break; + + offset -= skb_frag_size(frag); + } + + for (; len && fragidx < skb_shinfo(skb)->nr_frags; fragidx++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[fragidx]; + + slen = min_t(size_t, len, skb_frag_size(frag) - offset); + + while (slen) { + ret = INDIRECT_CALL_2(sendpage, kernel_sendpage_locked, + sendpage_unlocked, sk, + skb_frag_page(frag), + skb_frag_off(frag) + offset, + slen, MSG_DONTWAIT); + if (ret <= 0) + goto error; + + len -= ret; + offset += ret; + slen -= ret; + } + + offset = 0; + } + + if (len) { + /* Process any frag lists */ + + if (skb == head) { + if (skb_has_frag_list(skb)) { + skb = skb_shinfo(skb)->frag_list; + goto do_frag_list; + } + } else if (skb->next) { + skb = skb->next; + goto do_frag_list; + } + } + +out: + return orig_len - len; + +error: + return orig_len == len ? ret : orig_len - len; +} + +/* Send skb data on a socket. Socket must be locked. */ +int skb_send_sock_locked(struct sock *sk, struct sk_buff *skb, int offset, + int len) +{ + return __skb_send_sock(sk, skb, offset, len, kernel_sendmsg_locked, + kernel_sendpage_locked); +} +EXPORT_SYMBOL_GPL(skb_send_sock_locked); + +/* Send skb data on a socket. 
Socket must be unlocked. */ +int skb_send_sock(struct sock *sk, struct sk_buff *skb, int offset, int len) +{ + return __skb_send_sock(sk, skb, offset, len, sendmsg_unlocked, + sendpage_unlocked); +} + +/** + * skb_store_bits - store bits from kernel buffer to skb + * @skb: destination buffer + * @offset: offset in destination + * @from: source buffer + * @len: number of bytes to copy + * + * Copy the specified number of bytes from the source buffer to the + * destination skb. This function handles all the messy bits of + * traversing fragment lists and such. + */ + +int skb_store_bits(struct sk_buff *skb, int offset, const void *from, int len) +{ + int start = skb_headlen(skb); + struct sk_buff *frag_iter; + int i, copy; + + if (offset > (int)skb->len - len) + goto fault; + + if ((copy = start - offset) > 0) { + if (copy > len) + copy = len; + skb_copy_to_linear_data_offset(skb, offset, from, copy); + if ((len -= copy) == 0) + return 0; + offset += copy; + from += copy; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + int end; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(frag); + if ((copy = end - offset) > 0) { + u32 p_off, p_len, copied; + struct page *p; + u8 *vaddr; + + if (copy > len) + copy = len; + + skb_frag_foreach_page(frag, + skb_frag_off(frag) + offset - start, + copy, p, p_off, p_len, copied) { + vaddr = kmap_atomic(p); + memcpy(vaddr + p_off, from + copied, p_len); + kunmap_atomic(vaddr); + } + + if ((len -= copy) == 0) + return 0; + offset += copy; + from += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + if (skb_store_bits(frag_iter, offset - start, + from, copy)) + goto fault; + if ((len -= copy) == 0) + return 0; + offset += copy; + from += copy; + } + start = end; + } + if (!len) + return 0; + +fault: + return -EFAULT; +} +EXPORT_SYMBOL(skb_store_bits); + +/* Checksum skb data. */ +__wsum __skb_checksum(const struct sk_buff *skb, int offset, int len, + __wsum csum, const struct skb_checksum_ops *ops) +{ + int start = skb_headlen(skb); + int i, copy = start - offset; + struct sk_buff *frag_iter; + int pos = 0; + + /* Checksum header. 
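Sketch (illustrative, not part of the patched file) of how a caller might drive the locked send helper defined above; the retry policy and error handling are simplified:

#include <linux/skbuff.h>
#include <net/sock.h>

/* Hypothetical sketch: push an skb's payload to a connected socket with
 * skb_send_sock_locked(), resuming from the last reported offset. The
 * helper uses MSG_DONTWAIT internally, so partial progress is normal.
 */
static int example_send_all_locked(struct sock *sk, struct sk_buff *skb)
{
        int offset = 0, len = skb->len;

        lock_sock(sk);
        while (len > 0) {
                int sent = skb_send_sock_locked(sk, skb, offset, len);

                if (sent <= 0) {
                        release_sock(sk);
                        return sent ? sent : -EAGAIN;
                }
                offset += sent;
                len -= sent;
        }
        release_sock(sk);
        return skb->len;
}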
*/ + if (copy > 0) { + if (copy > len) + copy = len; + csum = INDIRECT_CALL_1(ops->update, csum_partial_ext, + skb->data + offset, copy, csum); + if ((len -= copy) == 0) + return csum; + offset += copy; + pos = copy; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(frag); + if ((copy = end - offset) > 0) { + u32 p_off, p_len, copied; + struct page *p; + __wsum csum2; + u8 *vaddr; + + if (copy > len) + copy = len; + + skb_frag_foreach_page(frag, + skb_frag_off(frag) + offset - start, + copy, p, p_off, p_len, copied) { + vaddr = kmap_atomic(p); + csum2 = INDIRECT_CALL_1(ops->update, + csum_partial_ext, + vaddr + p_off, p_len, 0); + kunmap_atomic(vaddr); + csum = INDIRECT_CALL_1(ops->combine, + csum_block_add_ext, csum, + csum2, pos, p_len); + pos += p_len; + } + + if (!(len -= copy)) + return csum; + offset += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + __wsum csum2; + if (copy > len) + copy = len; + csum2 = __skb_checksum(frag_iter, offset - start, + copy, 0, ops); + csum = INDIRECT_CALL_1(ops->combine, csum_block_add_ext, + csum, csum2, pos, copy); + if ((len -= copy) == 0) + return csum; + offset += copy; + pos += copy; + } + start = end; + } + BUG_ON(len); + + return csum; +} +EXPORT_SYMBOL(__skb_checksum); + +__wsum skb_checksum(const struct sk_buff *skb, int offset, + int len, __wsum csum) +{ + const struct skb_checksum_ops ops = { + .update = csum_partial_ext, + .combine = csum_block_add_ext, + }; + + return __skb_checksum(skb, offset, len, csum, &ops); +} +EXPORT_SYMBOL(skb_checksum); + +/* Both of above in one bottle. */ + +__wsum skb_copy_and_csum_bits(const struct sk_buff *skb, int offset, + u8 *to, int len) +{ + int start = skb_headlen(skb); + int i, copy = start - offset; + struct sk_buff *frag_iter; + int pos = 0; + __wsum csum = 0; + + /* Copy header. 
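Usage note (illustrative, not part of the patched file): skb_checksum() plus csum_fold() is the software fallback used when no hardware checksum is available. A minimal sketch, assuming @pseudo already holds the pseudo-header sum:

#include <linux/skbuff.h>
#include <net/checksum.h>

/* Hypothetical sketch: a full software checksum over the packet is valid
 * when folding it together with the pseudo-header sum yields zero.
 */
static bool example_csum_ok(const struct sk_buff *skb, __wsum pseudo)
{
        return csum_fold(skb_checksum(skb, 0, skb->len, pseudo)) == 0;
}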
*/ + if (copy > 0) { + if (copy > len) + copy = len; + csum = csum_partial_copy_nocheck(skb->data + offset, to, + copy); + if ((len -= copy) == 0) + return csum; + offset += copy; + to += copy; + pos = copy; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); + if ((copy = end - offset) > 0) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + u32 p_off, p_len, copied; + struct page *p; + __wsum csum2; + u8 *vaddr; + + if (copy > len) + copy = len; + + skb_frag_foreach_page(frag, + skb_frag_off(frag) + offset - start, + copy, p, p_off, p_len, copied) { + vaddr = kmap_atomic(p); + csum2 = csum_partial_copy_nocheck(vaddr + p_off, + to + copied, + p_len); + kunmap_atomic(vaddr); + csum = csum_block_add(csum, csum2, pos); + pos += p_len; + } + + if (!(len -= copy)) + return csum; + offset += copy; + to += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + __wsum csum2; + int end; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (copy > len) + copy = len; + csum2 = skb_copy_and_csum_bits(frag_iter, + offset - start, + to, copy); + csum = csum_block_add(csum, csum2, pos); + if ((len -= copy) == 0) + return csum; + offset += copy; + to += copy; + pos += copy; + } + start = end; + } + BUG_ON(len); + return csum; +} +EXPORT_SYMBOL(skb_copy_and_csum_bits); + +__sum16 __skb_checksum_complete_head(struct sk_buff *skb, int len) +{ + __sum16 sum; + + sum = csum_fold(skb_checksum(skb, 0, len, skb->csum)); + /* See comments in __skb_checksum_complete(). */ + if (likely(!sum)) { + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev, skb); + } + if (!skb_shared(skb)) + skb->csum_valid = !sum; + return sum; +} +EXPORT_SYMBOL(__skb_checksum_complete_head); + +/* This function assumes skb->csum already holds pseudo header's checksum, + * which has been changed from the hardware checksum, for example, by + * __skb_checksum_validate_complete(). And, the original skb->csum must + * have been validated unsuccessfully for CHECKSUM_COMPLETE case. + * + * It returns non-zero if the recomputed checksum is still invalid, otherwise + * zero. The new checksum is stored back into skb->csum unless the skb is + * shared. + */ +__sum16 __skb_checksum_complete(struct sk_buff *skb) +{ + __wsum csum; + __sum16 sum; + + csum = skb_checksum(skb, 0, skb->len, 0); + + sum = csum_fold(csum_add(skb->csum, csum)); + /* This check is inverted, because we already knew the hardware + * checksum is invalid before calling this function. So, if the + * re-computed checksum is valid instead, then we have a mismatch + * between the original skb->csum and skb_checksum(). This means either + * the original hardware checksum is incorrect or we screw up skb->csum + * when moving skb->data around. 
+ */ + if (likely(!sum)) { + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev, skb); + } + + if (!skb_shared(skb)) { + /* Save full packet checksum */ + skb->csum = csum; + skb->ip_summed = CHECKSUM_COMPLETE; + skb->csum_complete_sw = 1; + skb->csum_valid = !sum; + } + + return sum; +} +EXPORT_SYMBOL(__skb_checksum_complete); + +static __wsum warn_crc32c_csum_update(const void *buff, int len, __wsum sum) +{ + net_warn_ratelimited( + "%s: attempt to compute crc32c without libcrc32c.ko\n", + __func__); + return 0; +} + +static __wsum warn_crc32c_csum_combine(__wsum csum, __wsum csum2, + int offset, int len) +{ + net_warn_ratelimited( + "%s: attempt to compute crc32c without libcrc32c.ko\n", + __func__); + return 0; +} + +static const struct skb_checksum_ops default_crc32c_ops = { + .update = warn_crc32c_csum_update, + .combine = warn_crc32c_csum_combine, +}; + +const struct skb_checksum_ops *crc32c_csum_stub __read_mostly = + &default_crc32c_ops; +EXPORT_SYMBOL(crc32c_csum_stub); + + /** + * skb_zerocopy_headlen - Calculate headroom needed for skb_zerocopy() + * @from: source buffer + * + * Calculates the amount of linear headroom needed in the 'to' skb passed + * into skb_zerocopy(). + */ +unsigned int +skb_zerocopy_headlen(const struct sk_buff *from) +{ + unsigned int hlen = 0; + + if (!from->head_frag || + skb_headlen(from) < L1_CACHE_BYTES || + skb_shinfo(from)->nr_frags >= MAX_SKB_FRAGS) { + hlen = skb_headlen(from); + if (!hlen) + hlen = from->len; + } + + if (skb_has_frag_list(from)) + hlen = from->len; + + return hlen; +} +EXPORT_SYMBOL_GPL(skb_zerocopy_headlen); + +/** + * skb_zerocopy - Zero copy skb to skb + * @to: destination buffer + * @from: source buffer + * @len: number of bytes to copy from source buffer + * @hlen: size of linear headroom in destination buffer + * + * Copies up to `len` bytes from `from` to `to` by creating references + * to the frags in the source buffer. + * + * The `hlen` as calculated by skb_zerocopy_headlen() specifies the + * headroom in the `to` buffer. 
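Usage sketch for the zerocopy pair above (illustrative, not part of the patched file); real users such as the netlink queueing code also reserve extra headroom for their own headers, which is omitted here:

#include <linux/skbuff.h>

/* Hypothetical sketch: build a new skb that references @from's pages
 * instead of copying the payload. The GFP flag is illustrative.
 */
static struct sk_buff *example_zerocopy_ref(struct sk_buff *from)
{
        unsigned int hlen = skb_zerocopy_headlen(from);
        struct sk_buff *to = alloc_skb(hlen, GFP_ATOMIC);

        if (!to)
                return NULL;

        if (skb_zerocopy(to, from, from->len, hlen)) {
                kfree_skb(to);          /* -ENOMEM or -EFAULT */
                return NULL;
        }
        return to;
}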
+ * + * Return value: + * 0: everything is OK + * -ENOMEM: couldn't orphan frags of @from due to lack of memory + * -EFAULT: skb_copy_bits() found some problem with skb geometry + */ +int +skb_zerocopy(struct sk_buff *to, struct sk_buff *from, int len, int hlen) +{ + int i, j = 0; + int plen = 0; /* length of skb->head fragment */ + int ret; + struct page *page; + unsigned int offset; + + BUG_ON(!from->head_frag && !hlen); + + /* dont bother with small payloads */ + if (len <= skb_tailroom(to)) + return skb_copy_bits(from, 0, skb_put(to, len), len); + + if (hlen) { + ret = skb_copy_bits(from, 0, skb_put(to, hlen), hlen); + if (unlikely(ret)) + return ret; + len -= hlen; + } else { + plen = min_t(int, skb_headlen(from), len); + if (plen) { + page = virt_to_head_page(from->head); + offset = from->data - (unsigned char *)page_address(page); + __skb_fill_page_desc(to, 0, page, offset, plen); + get_page(page); + j = 1; + len -= plen; + } + } + + skb_len_add(to, len + plen); + + if (unlikely(skb_orphan_frags(from, GFP_ATOMIC))) { + skb_tx_error(from); + return -ENOMEM; + } + skb_zerocopy_clone(to, from, GFP_ATOMIC); + + for (i = 0; i < skb_shinfo(from)->nr_frags; i++) { + int size; + + if (!len) + break; + skb_shinfo(to)->frags[j] = skb_shinfo(from)->frags[i]; + size = min_t(int, skb_frag_size(&skb_shinfo(to)->frags[j]), + len); + skb_frag_size_set(&skb_shinfo(to)->frags[j], size); + len -= size; + skb_frag_ref(to, j); + j++; + } + skb_shinfo(to)->nr_frags = j; + + return 0; +} +EXPORT_SYMBOL_GPL(skb_zerocopy); + +void skb_copy_and_csum_dev(const struct sk_buff *skb, u8 *to) +{ + __wsum csum; + long csstart; + + if (skb->ip_summed == CHECKSUM_PARTIAL) + csstart = skb_checksum_start_offset(skb); + else + csstart = skb_headlen(skb); + + BUG_ON(csstart > skb_headlen(skb)); + + skb_copy_from_linear_data(skb, to, csstart); + + csum = 0; + if (csstart != skb->len) + csum = skb_copy_and_csum_bits(skb, csstart, to + csstart, + skb->len - csstart); + + if (skb->ip_summed == CHECKSUM_PARTIAL) { + long csstuff = csstart + skb->csum_offset; + + *((__sum16 *)(to + csstuff)) = csum_fold(csum); + } +} +EXPORT_SYMBOL(skb_copy_and_csum_dev); + +/** + * skb_dequeue - remove from the head of the queue + * @list: list to dequeue from + * + * Remove the head of the list. The list lock is taken so the function + * may be used safely with other locking list functions. The head item is + * returned or %NULL if the list is empty. + */ + +struct sk_buff *skb_dequeue(struct sk_buff_head *list) +{ + unsigned long flags; + struct sk_buff *result; + + spin_lock_irqsave(&list->lock, flags); + result = __skb_dequeue(list); + spin_unlock_irqrestore(&list->lock, flags); + return result; +} +EXPORT_SYMBOL(skb_dequeue); + +/** + * skb_dequeue_tail - remove from the tail of the queue + * @list: list to dequeue from + * + * Remove the tail of the list. The list lock is taken so the function + * may be used safely with other locking list functions. The tail item is + * returned or %NULL if the list is empty. + */ +struct sk_buff *skb_dequeue_tail(struct sk_buff_head *list) +{ + unsigned long flags; + struct sk_buff *result; + + spin_lock_irqsave(&list->lock, flags); + result = __skb_dequeue_tail(list); + spin_unlock_irqrestore(&list->lock, flags); + return result; +} +EXPORT_SYMBOL(skb_dequeue_tail); + +/** + * skb_queue_purge - empty a list + * @list: list to empty + * + * Delete all buffers on an &sk_buff list. Each buffer is removed from + * the list and one reference dropped. 
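Sketch (illustrative, not part of the patched file) of the usual producer/consumer pattern behind skb_dequeue() and friends; all names are hypothetical:

#include <linux/skbuff.h>

/* Hypothetical private receive queue: the skb_queue_* helpers take the
 * list spinlock with IRQs disabled, so the producer may run in IRQ
 * context while the consumer runs in process context.
 */
static struct sk_buff_head example_rxq;

static void example_rxq_init(void)
{
        skb_queue_head_init(&example_rxq);
}

static void example_rx_enqueue(struct sk_buff *skb)     /* e.g. from an IRQ */
{
        skb_queue_tail(&example_rxq, skb);
}

static void example_rx_drain(void)                      /* e.g. from a workqueue */
{
        struct sk_buff *skb;

        while ((skb = skb_dequeue(&example_rxq)) != NULL)
                kfree_skb(skb); /* a real consumer would process the packet */
}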
This function takes the list + * lock and is atomic with respect to other list locking functions. + */ +void skb_queue_purge(struct sk_buff_head *list) +{ + struct sk_buff *skb; + while ((skb = skb_dequeue(list)) != NULL) + kfree_skb(skb); +} +EXPORT_SYMBOL(skb_queue_purge); + +/** + * skb_rbtree_purge - empty a skb rbtree + * @root: root of the rbtree to empty + * Return value: the sum of truesizes of all purged skbs. + * + * Delete all buffers on an &sk_buff rbtree. Each buffer is removed from + * the list and one reference dropped. This function does not take + * any lock. Synchronization should be handled by the caller (e.g., TCP + * out-of-order queue is protected by the socket lock). + */ +unsigned int skb_rbtree_purge(struct rb_root *root) +{ + struct rb_node *p = rb_first(root); + unsigned int sum = 0; + + while (p) { + struct sk_buff *skb = rb_entry(p, struct sk_buff, rbnode); + + p = rb_next(p); + rb_erase(&skb->rbnode, root); + sum += skb->truesize; + kfree_skb(skb); + } + return sum; +} + +/** + * skb_queue_head - queue a buffer at the list head + * @list: list to use + * @newsk: buffer to queue + * + * Queue a buffer at the start of the list. This function takes the + * list lock and can be used safely with other locking &sk_buff functions + * safely. + * + * A buffer cannot be placed on two lists at the same time. + */ +void skb_queue_head(struct sk_buff_head *list, struct sk_buff *newsk) +{ + unsigned long flags; + + spin_lock_irqsave(&list->lock, flags); + __skb_queue_head(list, newsk); + spin_unlock_irqrestore(&list->lock, flags); +} +EXPORT_SYMBOL(skb_queue_head); + +/** + * skb_queue_tail - queue a buffer at the list tail + * @list: list to use + * @newsk: buffer to queue + * + * Queue a buffer at the tail of the list. This function takes the + * list lock and can be used safely with other locking &sk_buff functions + * safely. + * + * A buffer cannot be placed on two lists at the same time. + */ +void skb_queue_tail(struct sk_buff_head *list, struct sk_buff *newsk) +{ + unsigned long flags; + + spin_lock_irqsave(&list->lock, flags); + __skb_queue_tail(list, newsk); + spin_unlock_irqrestore(&list->lock, flags); +} +EXPORT_SYMBOL(skb_queue_tail); + +/** + * skb_unlink - remove a buffer from a list + * @skb: buffer to remove + * @list: list to use + * + * Remove a packet from a list. The list locks are taken and this + * function is atomic with respect to other list locked calls + * + * You must know what list the SKB is on. + */ +void skb_unlink(struct sk_buff *skb, struct sk_buff_head *list) +{ + unsigned long flags; + + spin_lock_irqsave(&list->lock, flags); + __skb_unlink(skb, list); + spin_unlock_irqrestore(&list->lock, flags); +} +EXPORT_SYMBOL(skb_unlink); + +/** + * skb_append - append a buffer + * @old: buffer to insert after + * @newsk: buffer to insert + * @list: list to use + * + * Place a packet after a given packet in a list. The list locks are taken + * and this function is atomic with respect to other list locked calls. + * A buffer cannot be placed on two lists at the same time. 
+ */ +void skb_append(struct sk_buff *old, struct sk_buff *newsk, struct sk_buff_head *list) +{ + unsigned long flags; + + spin_lock_irqsave(&list->lock, flags); + __skb_queue_after(list, old, newsk); + spin_unlock_irqrestore(&list->lock, flags); +} +EXPORT_SYMBOL(skb_append); + +static inline void skb_split_inside_header(struct sk_buff *skb, + struct sk_buff* skb1, + const u32 len, const int pos) +{ + int i; + + skb_copy_from_linear_data_offset(skb, len, skb_put(skb1, pos - len), + pos - len); + /* And move data appendix as is. */ + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) + skb_shinfo(skb1)->frags[i] = skb_shinfo(skb)->frags[i]; + + skb_shinfo(skb1)->nr_frags = skb_shinfo(skb)->nr_frags; + skb_shinfo(skb)->nr_frags = 0; + skb1->data_len = skb->data_len; + skb1->len += skb1->data_len; + skb->data_len = 0; + skb->len = len; + skb_set_tail_pointer(skb, len); +} + +static inline void skb_split_no_header(struct sk_buff *skb, + struct sk_buff* skb1, + const u32 len, int pos) +{ + int i, k = 0; + const int nfrags = skb_shinfo(skb)->nr_frags; + + skb_shinfo(skb)->nr_frags = 0; + skb1->len = skb1->data_len = skb->len - len; + skb->len = len; + skb->data_len = len - pos; + + for (i = 0; i < nfrags; i++) { + int size = skb_frag_size(&skb_shinfo(skb)->frags[i]); + + if (pos + size > len) { + skb_shinfo(skb1)->frags[k] = skb_shinfo(skb)->frags[i]; + + if (pos < len) { + /* Split frag. + * We have two variants in this case: + * 1. Move all the frag to the second + * part, if it is possible. F.e. + * this approach is mandatory for TUX, + * where splitting is expensive. + * 2. Split is accurately. We make this. + */ + skb_frag_ref(skb, i); + skb_frag_off_add(&skb_shinfo(skb1)->frags[0], len - pos); + skb_frag_size_sub(&skb_shinfo(skb1)->frags[0], len - pos); + skb_frag_size_set(&skb_shinfo(skb)->frags[i], len - pos); + skb_shinfo(skb)->nr_frags++; + } + k++; + } else + skb_shinfo(skb)->nr_frags++; + pos += size; + } + skb_shinfo(skb1)->nr_frags = k; +} + +/** + * skb_split - Split fragmented skb to two parts at length len. + * @skb: the buffer to split + * @skb1: the buffer to receive the second part + * @len: new length for skb + */ +void skb_split(struct sk_buff *skb, struct sk_buff *skb1, const u32 len) +{ + int pos = skb_headlen(skb); + const int zc_flags = SKBFL_SHARED_FRAG | SKBFL_PURE_ZEROCOPY; + + skb_zcopy_downgrade_managed(skb); + + skb_shinfo(skb1)->flags |= skb_shinfo(skb)->flags & zc_flags; + skb_zerocopy_clone(skb1, skb, 0); + if (len < pos) /* Split line is inside header. */ + skb_split_inside_header(skb, skb1, len, pos); + else /* Second chunk has no header, nothing to copy. */ + skb_split_no_header(skb, skb1, len, pos); +} +EXPORT_SYMBOL(skb_split); + +/* Shifting from/to a cloned skb is a no-go. + * + * Caller cannot keep skb_shinfo related pointers past calling here! + */ +static int skb_prepare_for_shift(struct sk_buff *skb) +{ + return skb_unclone_keeptruesize(skb, GFP_ATOMIC); +} + +/** + * skb_shift - Shifts paged data partially from skb to another + * @tgt: buffer into which tail data gets added + * @skb: buffer from which the paged data comes from + * @shiftlen: shift up to this many bytes + * + * Attempts to shift up to shiftlen worth of bytes, which may be less than + * the length of the skb, from skb to tgt. Returns number bytes shifted. + * It's up to caller to free skb if everything was shifted. + * + * If @tgt runs out of frags, the whole operation is aborted. 
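Sketch (illustrative, not part of the patched file) for the skb_split() helper defined above; allocating skb_headlen() bytes is a worst-case simplification, and no protocol headers or skb metadata are copied here:

#include <linux/skbuff.h>

/* Hypothetical sketch: keep the first @len bytes in @skb and move the
 * rest into a newly allocated skb, roughly what a TCP-style segmenter
 * does. skb_headlen() bounds the linear bytes that can land in @rest.
 */
static struct sk_buff *example_split_at(struct sk_buff *skb, u32 len)
{
        struct sk_buff *rest = alloc_skb(skb_headlen(skb), GFP_ATOMIC);

        if (!rest)
                return NULL;

        skb_split(skb, rest, len);      /* moves data and frags past @len */
        return rest;
}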
+ * + * Skb cannot include anything else but paged data while tgt is allowed + * to have non-paged data as well. + * + * TODO: full sized shift could be optimized but that would need + * specialized skb free'er to handle frags without up-to-date nr_frags. + */ +int skb_shift(struct sk_buff *tgt, struct sk_buff *skb, int shiftlen) +{ + int from, to, merge, todo; + skb_frag_t *fragfrom, *fragto; + + BUG_ON(shiftlen > skb->len); + + if (skb_headlen(skb)) + return 0; + if (skb_zcopy(tgt) || skb_zcopy(skb)) + return 0; + + todo = shiftlen; + from = 0; + to = skb_shinfo(tgt)->nr_frags; + fragfrom = &skb_shinfo(skb)->frags[from]; + + /* Actual merge is delayed until the point when we know we can + * commit all, so that we don't have to undo partial changes + */ + if (!to || + !skb_can_coalesce(tgt, to, skb_frag_page(fragfrom), + skb_frag_off(fragfrom))) { + merge = -1; + } else { + merge = to - 1; + + todo -= skb_frag_size(fragfrom); + if (todo < 0) { + if (skb_prepare_for_shift(skb) || + skb_prepare_for_shift(tgt)) + return 0; + + /* All previous frag pointers might be stale! */ + fragfrom = &skb_shinfo(skb)->frags[from]; + fragto = &skb_shinfo(tgt)->frags[merge]; + + skb_frag_size_add(fragto, shiftlen); + skb_frag_size_sub(fragfrom, shiftlen); + skb_frag_off_add(fragfrom, shiftlen); + + goto onlymerged; + } + + from++; + } + + /* Skip full, not-fitting skb to avoid expensive operations */ + if ((shiftlen == skb->len) && + (skb_shinfo(skb)->nr_frags - from) > (MAX_SKB_FRAGS - to)) + return 0; + + if (skb_prepare_for_shift(skb) || skb_prepare_for_shift(tgt)) + return 0; + + while ((todo > 0) && (from < skb_shinfo(skb)->nr_frags)) { + if (to == MAX_SKB_FRAGS) + return 0; + + fragfrom = &skb_shinfo(skb)->frags[from]; + fragto = &skb_shinfo(tgt)->frags[to]; + + if (todo >= skb_frag_size(fragfrom)) { + *fragto = *fragfrom; + todo -= skb_frag_size(fragfrom); + from++; + to++; + + } else { + __skb_frag_ref(fragfrom); + skb_frag_page_copy(fragto, fragfrom); + skb_frag_off_copy(fragto, fragfrom); + skb_frag_size_set(fragto, todo); + + skb_frag_off_add(fragfrom, todo); + skb_frag_size_sub(fragfrom, todo); + todo = 0; + + to++; + break; + } + } + + /* Ready to "commit" this state change to tgt */ + skb_shinfo(tgt)->nr_frags = to; + + if (merge >= 0) { + fragfrom = &skb_shinfo(skb)->frags[0]; + fragto = &skb_shinfo(tgt)->frags[merge]; + + skb_frag_size_add(fragto, skb_frag_size(fragfrom)); + __skb_frag_unref(fragfrom, skb->pp_recycle); + } + + /* Reposition in the original skb */ + to = 0; + while (from < skb_shinfo(skb)->nr_frags) + skb_shinfo(skb)->frags[to++] = skb_shinfo(skb)->frags[from++]; + skb_shinfo(skb)->nr_frags = to; + + BUG_ON(todo > 0 && !skb_shinfo(skb)->nr_frags); + +onlymerged: + /* Most likely the tgt won't ever need its checksum anymore, skb on + * the other hand might need it if it needs to be resent + */ + tgt->ip_summed = CHECKSUM_PARTIAL; + skb->ip_summed = CHECKSUM_PARTIAL; + + skb_len_add(skb, -shiftlen); + skb_len_add(tgt, shiftlen); + + return shiftlen; +} + +/** + * skb_prepare_seq_read - Prepare a sequential read of skb data + * @skb: the buffer to read + * @from: lower offset of data to be read + * @to: upper offset of data to be read + * @st: state variable + * + * Initializes the specified state variable. Must be called before + * invoking skb_seq_read() for the first time. 
+ */ +void skb_prepare_seq_read(struct sk_buff *skb, unsigned int from, + unsigned int to, struct skb_seq_state *st) +{ + st->lower_offset = from; + st->upper_offset = to; + st->root_skb = st->cur_skb = skb; + st->frag_idx = st->stepped_offset = 0; + st->frag_data = NULL; + st->frag_off = 0; +} +EXPORT_SYMBOL(skb_prepare_seq_read); + +/** + * skb_seq_read - Sequentially read skb data + * @consumed: number of bytes consumed by the caller so far + * @data: destination pointer for data to be returned + * @st: state variable + * + * Reads a block of skb data at @consumed relative to the + * lower offset specified to skb_prepare_seq_read(). Assigns + * the head of the data block to @data and returns the length + * of the block or 0 if the end of the skb data or the upper + * offset has been reached. + * + * The caller is not required to consume all of the data + * returned, i.e. @consumed is typically set to the number + * of bytes already consumed and the next call to + * skb_seq_read() will return the remaining part of the block. + * + * Note 1: The size of each block of data returned can be arbitrary, + * this limitation is the cost for zerocopy sequential + * reads of potentially non linear data. + * + * Note 2: Fragment lists within fragments are not implemented + * at the moment, state->root_skb could be replaced with + * a stack for this purpose. + */ +unsigned int skb_seq_read(unsigned int consumed, const u8 **data, + struct skb_seq_state *st) +{ + unsigned int block_limit, abs_offset = consumed + st->lower_offset; + skb_frag_t *frag; + + if (unlikely(abs_offset >= st->upper_offset)) { + if (st->frag_data) { + kunmap_atomic(st->frag_data); + st->frag_data = NULL; + } + return 0; + } + +next_skb: + block_limit = skb_headlen(st->cur_skb) + st->stepped_offset; + + if (abs_offset < block_limit && !st->frag_data) { + *data = st->cur_skb->data + (abs_offset - st->stepped_offset); + return block_limit - abs_offset; + } + + if (st->frag_idx == 0 && !st->frag_data) + st->stepped_offset += skb_headlen(st->cur_skb); + + while (st->frag_idx < skb_shinfo(st->cur_skb)->nr_frags) { + unsigned int pg_idx, pg_off, pg_sz; + + frag = &skb_shinfo(st->cur_skb)->frags[st->frag_idx]; + + pg_idx = 0; + pg_off = skb_frag_off(frag); + pg_sz = skb_frag_size(frag); + + if (skb_frag_must_loop(skb_frag_page(frag))) { + pg_idx = (pg_off + st->frag_off) >> PAGE_SHIFT; + pg_off = offset_in_page(pg_off + st->frag_off); + pg_sz = min_t(unsigned int, pg_sz - st->frag_off, + PAGE_SIZE - pg_off); + } + + block_limit = pg_sz + st->stepped_offset; + if (abs_offset < block_limit) { + if (!st->frag_data) + st->frag_data = kmap_atomic(skb_frag_page(frag) + pg_idx); + + *data = (u8 *)st->frag_data + pg_off + + (abs_offset - st->stepped_offset); + + return block_limit - abs_offset; + } + + if (st->frag_data) { + kunmap_atomic(st->frag_data); + st->frag_data = NULL; + } + + st->stepped_offset += pg_sz; + st->frag_off += pg_sz; + if (st->frag_off == skb_frag_size(frag)) { + st->frag_off = 0; + st->frag_idx++; + } + } + + if (st->frag_data) { + kunmap_atomic(st->frag_data); + st->frag_data = NULL; + } + + if (st->root_skb == st->cur_skb && skb_has_frag_list(st->root_skb)) { + st->cur_skb = skb_shinfo(st->root_skb)->frag_list; + st->frag_idx = 0; + goto next_skb; + } else if (st->cur_skb->next) { + st->cur_skb = st->cur_skb->next; + st->frag_idx = 0; + goto next_skb; + } + + return 0; +} +EXPORT_SYMBOL(skb_seq_read); + +/** + * skb_abort_seq_read - Abort a sequential read of skb data + * @st: state variable + * + * Must be called if 
skb_seq_read() was not called until it + * returned 0. + */ +void skb_abort_seq_read(struct skb_seq_state *st) +{ + if (st->frag_data) + kunmap_atomic(st->frag_data); +} +EXPORT_SYMBOL(skb_abort_seq_read); + +#define TS_SKB_CB(state) ((struct skb_seq_state *) &((state)->cb)) + +static unsigned int skb_ts_get_next_block(unsigned int offset, const u8 **text, + struct ts_config *conf, + struct ts_state *state) +{ + return skb_seq_read(offset, text, TS_SKB_CB(state)); +} + +static void skb_ts_finish(struct ts_config *conf, struct ts_state *state) +{ + skb_abort_seq_read(TS_SKB_CB(state)); +} + +/** + * skb_find_text - Find a text pattern in skb data + * @skb: the buffer to look in + * @from: search offset + * @to: search limit + * @config: textsearch configuration + * + * Finds a pattern in the skb data according to the specified + * textsearch configuration. Use textsearch_next() to retrieve + * subsequent occurrences of the pattern. Returns the offset + * to the first occurrence or UINT_MAX if no match was found. + */ +unsigned int skb_find_text(struct sk_buff *skb, unsigned int from, + unsigned int to, struct ts_config *config) +{ + unsigned int patlen = config->ops->get_pattern_len(config); + struct ts_state state; + unsigned int ret; + + BUILD_BUG_ON(sizeof(struct skb_seq_state) > sizeof(state.cb)); + + config->get_next_block = skb_ts_get_next_block; + config->finish = skb_ts_finish; + + skb_prepare_seq_read(skb, from, to, TS_SKB_CB(&state)); + + ret = textsearch_find(config, &state); + return (ret + patlen <= to - from ? ret : UINT_MAX); +} +EXPORT_SYMBOL(skb_find_text); + +int skb_append_pagefrags(struct sk_buff *skb, struct page *page, + int offset, size_t size) +{ + int i = skb_shinfo(skb)->nr_frags; + + if (skb_can_coalesce(skb, i, page, offset)) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size); + } else if (i < MAX_SKB_FRAGS) { + skb_zcopy_downgrade_managed(skb); + get_page(page); + skb_fill_page_desc_noacc(skb, i, page, offset, size); + } else { + return -EMSGSIZE; + } + + return 0; +} +EXPORT_SYMBOL_GPL(skb_append_pagefrags); + +/** + * skb_pull_rcsum - pull skb and update receive checksum + * @skb: buffer to update + * @len: length of data pulled + * + * This function performs an skb_pull on the packet and updates + * the CHECKSUM_COMPLETE checksum. It should be used on + * receive path processing instead of skb_pull unless you know + * that the checksum difference is zero (e.g., a valid IP header) + * or you are setting ip_summed to CHECKSUM_NONE. 
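Sketch (illustrative, not part of the patched file) tying together the sequential-read API above; example_consume() is a placeholder for whatever processes the bytes:

#include <linux/skbuff.h>

/* Placeholder: a real user would feed the bytes to a hash, a parser, etc. */
static void example_consume(const u8 *data, unsigned int len)
{
}

/* Hypothetical sketch: visit every payload byte of a possibly nonlinear
 * skb without copying it. skb_abort_seq_read() is only needed when the
 * walk stops before skb_seq_read() has returned 0.
 */
static void example_walk_payload(struct sk_buff *skb)
{
        struct skb_seq_state st;
        unsigned int consumed = 0, len;
        const u8 *data;

        skb_prepare_seq_read(skb, 0, skb->len, &st);
        while ((len = skb_seq_read(consumed, &data, &st)) != 0) {
                example_consume(data, len);
                consumed += len;
        }
}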
+ */ +void *skb_pull_rcsum(struct sk_buff *skb, unsigned int len) +{ + unsigned char *data = skb->data; + + BUG_ON(len > skb->len); + __skb_pull(skb, len); + skb_postpull_rcsum(skb, data, len); + return skb->data; +} +EXPORT_SYMBOL_GPL(skb_pull_rcsum); + +static inline skb_frag_t skb_head_frag_to_page_desc(struct sk_buff *frag_skb) +{ + skb_frag_t head_frag; + struct page *page; + + page = virt_to_head_page(frag_skb->head); + __skb_frag_set_page(&head_frag, page); + skb_frag_off_set(&head_frag, frag_skb->data - + (unsigned char *)page_address(page)); + skb_frag_size_set(&head_frag, skb_headlen(frag_skb)); + return head_frag; +} + +struct sk_buff *skb_segment_list(struct sk_buff *skb, + netdev_features_t features, + unsigned int offset) +{ + struct sk_buff *list_skb = skb_shinfo(skb)->frag_list; + unsigned int tnl_hlen = skb_tnl_header_len(skb); + unsigned int delta_truesize = 0; + unsigned int delta_len = 0; + struct sk_buff *tail = NULL; + struct sk_buff *nskb, *tmp; + int len_diff, err; + + skb_push(skb, -skb_network_offset(skb) + offset); + + /* Ensure the head is writeable before touching the shared info */ + err = skb_unclone(skb, GFP_ATOMIC); + if (err) + goto err_linearize; + + skb_shinfo(skb)->frag_list = NULL; + + while (list_skb) { + nskb = list_skb; + list_skb = list_skb->next; + + err = 0; + delta_truesize += nskb->truesize; + if (skb_shared(nskb)) { + tmp = skb_clone(nskb, GFP_ATOMIC); + if (tmp) { + consume_skb(nskb); + nskb = tmp; + err = skb_unclone(nskb, GFP_ATOMIC); + } else { + err = -ENOMEM; + } + } + + if (!tail) + skb->next = nskb; + else + tail->next = nskb; + + if (unlikely(err)) { + nskb->next = list_skb; + goto err_linearize; + } + + tail = nskb; + + delta_len += nskb->len; + + skb_push(nskb, -skb_network_offset(nskb) + offset); + + skb_release_head_state(nskb); + len_diff = skb_network_header_len(nskb) - skb_network_header_len(skb); + __copy_skb_header(nskb, skb); + + skb_headers_offset_update(nskb, skb_headroom(nskb) - skb_headroom(skb)); + nskb->transport_header += len_diff; + skb_copy_from_linear_data_offset(skb, -tnl_hlen, + nskb->data - tnl_hlen, + offset + tnl_hlen); + + if (skb_needs_linearize(nskb, features) && + __skb_linearize(nskb)) + goto err_linearize; + } + + skb->truesize = skb->truesize - delta_truesize; + skb->data_len = skb->data_len - delta_len; + skb->len = skb->len - delta_len; + + skb_gso_reset(skb); + + skb->prev = tail; + + if (skb_needs_linearize(skb, features) && + __skb_linearize(skb)) + goto err_linearize; + + skb_get(skb); + + return skb; + +err_linearize: + kfree_skb_list(skb->next); + skb->next = NULL; + return ERR_PTR(-ENOMEM); +} +EXPORT_SYMBOL_GPL(skb_segment_list); + +/** + * skb_segment - Perform protocol segmentation on skb. + * @head_skb: buffer to segment + * @features: features for the output path (see dev->features) + * + * This function performs segmentation on the given skb. It returns + * a pointer to the first in a list of new skbs for the segments. + * In case of error it returns ERR_PTR(err). 
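Usage note (illustrative, not part of the patched file) for skb_pull_rcsum() defined above; the 4-byte encapsulation header is an arbitrary example:

#include <linux/skbuff.h>

/* Hypothetical sketch: strip a 4-byte encapsulation header on receive
 * without invalidating a CHECKSUM_COMPLETE value; skb_pull_rcsum()
 * subtracts the pulled bytes from skb->csum as it goes.
 */
static bool example_strip_encap(struct sk_buff *skb)
{
        if (!pskb_may_pull(skb, 4))
                return false;

        skb_pull_rcsum(skb, 4);
        return true;
}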
+ */ +struct sk_buff *skb_segment(struct sk_buff *head_skb, + netdev_features_t features) +{ + struct sk_buff *segs = NULL; + struct sk_buff *tail = NULL; + struct sk_buff *list_skb = skb_shinfo(head_skb)->frag_list; + unsigned int mss = skb_shinfo(head_skb)->gso_size; + unsigned int doffset = head_skb->data - skb_mac_header(head_skb); + unsigned int offset = doffset; + unsigned int tnl_hlen = skb_tnl_header_len(head_skb); + unsigned int partial_segs = 0; + unsigned int headroom; + unsigned int len = head_skb->len; + struct sk_buff *frag_skb; + skb_frag_t *frag; + __be16 proto; + bool csum, sg; + int err = -ENOMEM; + int i = 0; + int nfrags, pos; + + if ((skb_shinfo(head_skb)->gso_type & SKB_GSO_DODGY) && + mss != GSO_BY_FRAGS && mss != skb_headlen(head_skb)) { + struct sk_buff *check_skb; + + for (check_skb = list_skb; check_skb; check_skb = check_skb->next) { + if (skb_headlen(check_skb) && !check_skb->head_frag) { + /* gso_size is untrusted, and we have a frag_list with + * a linear non head_frag item. + * + * If head_skb's headlen does not fit requested gso_size, + * it means that the frag_list members do NOT terminate + * on exact gso_size boundaries. Hence we cannot perform + * skb_frag_t page sharing. Therefore we must fallback to + * copying the frag_list skbs; we do so by disabling SG. + */ + features &= ~NETIF_F_SG; + break; + } + } + } + + __skb_push(head_skb, doffset); + proto = skb_network_protocol(head_skb, NULL); + if (unlikely(!proto)) + return ERR_PTR(-EINVAL); + + sg = !!(features & NETIF_F_SG); + csum = !!can_checksum_protocol(features, proto); + + if (sg && csum && (mss != GSO_BY_FRAGS)) { + if (!(features & NETIF_F_GSO_PARTIAL)) { + struct sk_buff *iter; + unsigned int frag_len; + + if (!list_skb || + !net_gso_ok(features, skb_shinfo(head_skb)->gso_type)) + goto normal; + + /* If we get here then all the required + * GSO features except frag_list are supported. + * Try to split the SKB to multiple GSO SKBs + * with no frag_list. + * Currently we can do that only when the buffers don't + * have a linear part and all the buffers except + * the last are of the same length. + */ + frag_len = list_skb->len; + skb_walk_frags(head_skb, iter) { + if (frag_len != iter->len && iter->next) + goto normal; + if (skb_headlen(iter) && !iter->head_frag) + goto normal; + + len -= iter->len; + } + + if (len != frag_len) + goto normal; + } + + /* GSO partial only requires that we trim off any excess that + * doesn't fit into an MSS sized block, so take care of that + * now. 
+ */ + partial_segs = len / mss; + if (partial_segs > 1) + mss *= partial_segs; + else + partial_segs = 0; + } + +normal: + headroom = skb_headroom(head_skb); + pos = skb_headlen(head_skb); + + if (skb_orphan_frags(head_skb, GFP_ATOMIC)) + return ERR_PTR(-ENOMEM); + + nfrags = skb_shinfo(head_skb)->nr_frags; + frag = skb_shinfo(head_skb)->frags; + frag_skb = head_skb; + + do { + struct sk_buff *nskb; + skb_frag_t *nskb_frag; + int hsize; + int size; + + if (unlikely(mss == GSO_BY_FRAGS)) { + len = list_skb->len; + } else { + len = head_skb->len - offset; + if (len > mss) + len = mss; + } + + hsize = skb_headlen(head_skb) - offset; + + if (hsize <= 0 && i >= nfrags && skb_headlen(list_skb) && + (skb_headlen(list_skb) == len || sg)) { + BUG_ON(skb_headlen(list_skb) > len); + + nskb = skb_clone(list_skb, GFP_ATOMIC); + if (unlikely(!nskb)) + goto err; + + i = 0; + nfrags = skb_shinfo(list_skb)->nr_frags; + frag = skb_shinfo(list_skb)->frags; + frag_skb = list_skb; + pos += skb_headlen(list_skb); + + while (pos < offset + len) { + BUG_ON(i >= nfrags); + + size = skb_frag_size(frag); + if (pos + size > offset + len) + break; + + i++; + pos += size; + frag++; + } + + list_skb = list_skb->next; + + if (unlikely(pskb_trim(nskb, len))) { + kfree_skb(nskb); + goto err; + } + + hsize = skb_end_offset(nskb); + if (skb_cow_head(nskb, doffset + headroom)) { + kfree_skb(nskb); + goto err; + } + + nskb->truesize += skb_end_offset(nskb) - hsize; + skb_release_head_state(nskb); + __skb_push(nskb, doffset); + } else { + if (hsize < 0) + hsize = 0; + if (hsize > len || !sg) + hsize = len; + + nskb = __alloc_skb(hsize + doffset + headroom, + GFP_ATOMIC, skb_alloc_rx_flag(head_skb), + NUMA_NO_NODE); + + if (unlikely(!nskb)) + goto err; + + skb_reserve(nskb, headroom); + __skb_put(nskb, doffset); + } + + if (segs) + tail->next = nskb; + else + segs = nskb; + tail = nskb; + + __copy_skb_header(nskb, head_skb); + + skb_headers_offset_update(nskb, skb_headroom(nskb) - headroom); + skb_reset_mac_len(nskb); + + skb_copy_from_linear_data_offset(head_skb, -tnl_hlen, + nskb->data - tnl_hlen, + doffset + tnl_hlen); + + if (nskb->len == len + doffset) + goto perform_csum_check; + + if (!sg) { + if (!csum) { + if (!nskb->remcsum_offload) + nskb->ip_summed = CHECKSUM_NONE; + SKB_GSO_CB(nskb)->csum = + skb_copy_and_csum_bits(head_skb, offset, + skb_put(nskb, + len), + len); + SKB_GSO_CB(nskb)->csum_start = + skb_headroom(nskb) + doffset; + } else { + if (skb_copy_bits(head_skb, offset, skb_put(nskb, len), len)) + goto err; + } + continue; + } + + nskb_frag = skb_shinfo(nskb)->frags; + + skb_copy_from_linear_data_offset(head_skb, offset, + skb_put(nskb, hsize), hsize); + + skb_shinfo(nskb)->flags |= skb_shinfo(head_skb)->flags & + SKBFL_SHARED_FRAG; + + if (skb_zerocopy_clone(nskb, frag_skb, GFP_ATOMIC)) + goto err; + + while (pos < offset + len) { + if (i >= nfrags) { + if (skb_orphan_frags(list_skb, GFP_ATOMIC) || + skb_zerocopy_clone(nskb, list_skb, + GFP_ATOMIC)) + goto err; + + i = 0; + nfrags = skb_shinfo(list_skb)->nr_frags; + frag = skb_shinfo(list_skb)->frags; + frag_skb = list_skb; + if (!skb_headlen(list_skb)) { + BUG_ON(!nfrags); + } else { + BUG_ON(!list_skb->head_frag); + + /* to make room for head_frag. */ + i--; + frag--; + } + + list_skb = list_skb->next; + } + + if (unlikely(skb_shinfo(nskb)->nr_frags >= + MAX_SKB_FRAGS)) { + net_warn_ratelimited( + "skb_segment: too many frags: %u %u\n", + pos, mss); + err = -EINVAL; + goto err; + } + + *nskb_frag = (i < 0) ? 
skb_head_frag_to_page_desc(frag_skb) : *frag; + __skb_frag_ref(nskb_frag); + size = skb_frag_size(nskb_frag); + + if (pos < offset) { + skb_frag_off_add(nskb_frag, offset - pos); + skb_frag_size_sub(nskb_frag, offset - pos); + } + + skb_shinfo(nskb)->nr_frags++; + + if (pos + size <= offset + len) { + i++; + frag++; + pos += size; + } else { + skb_frag_size_sub(nskb_frag, pos + size - (offset + len)); + goto skip_fraglist; + } + + nskb_frag++; + } + +skip_fraglist: + nskb->data_len = len - hsize; + nskb->len += nskb->data_len; + nskb->truesize += nskb->data_len; + +perform_csum_check: + if (!csum) { + if (skb_has_shared_frag(nskb) && + __skb_linearize(nskb)) + goto err; + + if (!nskb->remcsum_offload) + nskb->ip_summed = CHECKSUM_NONE; + SKB_GSO_CB(nskb)->csum = + skb_checksum(nskb, doffset, + nskb->len - doffset, 0); + SKB_GSO_CB(nskb)->csum_start = + skb_headroom(nskb) + doffset; + } + } while ((offset += len) < head_skb->len); + + /* Some callers want to get the end of the list. + * Put it in segs->prev to avoid walking the list. + * (see validate_xmit_skb_list() for example) + */ + segs->prev = tail; + + if (partial_segs) { + struct sk_buff *iter; + int type = skb_shinfo(head_skb)->gso_type; + unsigned short gso_size = skb_shinfo(head_skb)->gso_size; + + /* Update type to add partial and then remove dodgy if set */ + type |= (features & NETIF_F_GSO_PARTIAL) / NETIF_F_GSO_PARTIAL * SKB_GSO_PARTIAL; + type &= ~SKB_GSO_DODGY; + + /* Update GSO info and prepare to start updating headers on + * our way back down the stack of protocols. + */ + for (iter = segs; iter; iter = iter->next) { + skb_shinfo(iter)->gso_size = gso_size; + skb_shinfo(iter)->gso_segs = partial_segs; + skb_shinfo(iter)->gso_type = type; + SKB_GSO_CB(iter)->data_offset = skb_headroom(iter) + doffset; + } + + if (tail->len - doffset <= gso_size) + skb_shinfo(tail)->gso_size = 0; + else if (tail != segs) + skb_shinfo(tail)->gso_segs = DIV_ROUND_UP(tail->len - doffset, gso_size); + } + + /* Following permits correct backpressure, for protocols + * using skb_set_owner_w(). + * Idea is to tranfert ownership from head_skb to last segment. 
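Sketch (illustrative, not part of the patched file) of how a software GSO fallback typically consumes skb_segment(), whose final lines follow below; the transmit step is a placeholder, and the required header preparation (mac/network offsets, gso_size) is assumed to have been done by the caller:

#include <linux/err.h>
#include <linux/netdevice.h>
#include <linux/skbuff.h>

/* Placeholder for the real per-segment transmit step. */
static void example_xmit_one(struct sk_buff *seg)
{
        kfree_skb(seg);
}

/* Hypothetical sketch: segment, release the original superpacket, then
 * walk the singly linked list of segments returned via ->next.
 */
static int example_software_gso(struct sk_buff *skb, netdev_features_t features)
{
        struct sk_buff *segs, *seg, *next;

        segs = skb_segment(skb, features);
        if (IS_ERR(segs))
                return PTR_ERR(segs);

        consume_skb(skb);

        skb_list_walk_safe(segs, seg, next) {
                skb_mark_not_on_list(seg);
                example_xmit_one(seg);
        }
        return 0;
}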
+ */ + if (head_skb->destructor == sock_wfree) { + swap(tail->truesize, head_skb->truesize); + swap(tail->destructor, head_skb->destructor); + swap(tail->sk, head_skb->sk); + } + return segs; + +err: + kfree_skb_list(segs); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(skb_segment); + +#ifdef CONFIG_SKB_EXTENSIONS +#define SKB_EXT_ALIGN_VALUE 8 +#define SKB_EXT_CHUNKSIZEOF(x) (ALIGN((sizeof(x)), SKB_EXT_ALIGN_VALUE) / SKB_EXT_ALIGN_VALUE) + +static const u8 skb_ext_type_len[] = { +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + [SKB_EXT_BRIDGE_NF] = SKB_EXT_CHUNKSIZEOF(struct nf_bridge_info), +#endif +#ifdef CONFIG_XFRM + [SKB_EXT_SEC_PATH] = SKB_EXT_CHUNKSIZEOF(struct sec_path), +#endif +#if IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + [TC_SKB_EXT] = SKB_EXT_CHUNKSIZEOF(struct tc_skb_ext), +#endif +#if IS_ENABLED(CONFIG_MPTCP) + [SKB_EXT_MPTCP] = SKB_EXT_CHUNKSIZEOF(struct mptcp_ext), +#endif +#if IS_ENABLED(CONFIG_MCTP_FLOWS) + [SKB_EXT_MCTP] = SKB_EXT_CHUNKSIZEOF(struct mctp_flow), +#endif +}; + +static __always_inline unsigned int skb_ext_total_length(void) +{ + return SKB_EXT_CHUNKSIZEOF(struct skb_ext) + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + skb_ext_type_len[SKB_EXT_BRIDGE_NF] + +#endif +#ifdef CONFIG_XFRM + skb_ext_type_len[SKB_EXT_SEC_PATH] + +#endif +#if IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + skb_ext_type_len[TC_SKB_EXT] + +#endif +#if IS_ENABLED(CONFIG_MPTCP) + skb_ext_type_len[SKB_EXT_MPTCP] + +#endif +#if IS_ENABLED(CONFIG_MCTP_FLOWS) + skb_ext_type_len[SKB_EXT_MCTP] + +#endif + 0; +} + +static void skb_extensions_init(void) +{ + BUILD_BUG_ON(SKB_EXT_NUM >= 8); + BUILD_BUG_ON(skb_ext_total_length() > 255); + + skbuff_ext_cache = kmem_cache_create("skbuff_ext_cache", + SKB_EXT_ALIGN_VALUE * skb_ext_total_length(), + 0, + SLAB_HWCACHE_ALIGN|SLAB_PANIC, + NULL); +} +#else +static void skb_extensions_init(void) {} +#endif + +void __init skb_init(void) +{ + skbuff_head_cache = kmem_cache_create_usercopy("skbuff_head_cache", + sizeof(struct sk_buff), + 0, + SLAB_HWCACHE_ALIGN|SLAB_PANIC, + offsetof(struct sk_buff, cb), + sizeof_field(struct sk_buff, cb), + NULL); + skbuff_fclone_cache = kmem_cache_create("skbuff_fclone_cache", + sizeof(struct sk_buff_fclones), + 0, + SLAB_HWCACHE_ALIGN|SLAB_PANIC, + NULL); + skb_extensions_init(); +} + +static int +__skb_to_sgvec(struct sk_buff *skb, struct scatterlist *sg, int offset, int len, + unsigned int recursion_level) +{ + int start = skb_headlen(skb); + int i, copy = start - offset; + struct sk_buff *frag_iter; + int elt = 0; + + if (unlikely(recursion_level >= 24)) + return -EMSGSIZE; + + if (copy > 0) { + if (copy > len) + copy = len; + sg_set_buf(sg, skb->data + offset, copy); + elt++; + if ((len -= copy) == 0) + return elt; + offset += copy; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); + if ((copy = end - offset) > 0) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + if (unlikely(elt && sg_is_last(&sg[elt - 1]))) + return -EMSGSIZE; + + if (copy > len) + copy = len; + sg_set_page(&sg[elt], skb_frag_page(frag), copy, + skb_frag_off(frag) + offset - start); + elt++; + if (!(len -= copy)) + return elt; + offset += copy; + } + start = end; + } + + skb_walk_frags(skb, frag_iter) { + int end, ret; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + if ((copy = end - offset) > 0) { + if (unlikely(elt && sg_is_last(&sg[elt - 1]))) + return -EMSGSIZE; + + if (copy > len) + copy = len; + ret = __skb_to_sgvec(frag_iter, 
sg+elt, offset - start, + copy, recursion_level + 1); + if (unlikely(ret < 0)) + return ret; + elt += ret; + if ((len -= copy) == 0) + return elt; + offset += copy; + } + start = end; + } + BUG_ON(len); + return elt; +} + +/** + * skb_to_sgvec - Fill a scatter-gather list from a socket buffer + * @skb: Socket buffer containing the buffers to be mapped + * @sg: The scatter-gather list to map into + * @offset: The offset into the buffer's contents to start mapping + * @len: Length of buffer space to be mapped + * + * Fill the specified scatter-gather list with mappings/pointers into a + * region of the buffer space attached to a socket buffer. Returns either + * the number of scatterlist items used, or -EMSGSIZE if the contents + * could not fit. + */ +int skb_to_sgvec(struct sk_buff *skb, struct scatterlist *sg, int offset, int len) +{ + int nsg = __skb_to_sgvec(skb, sg, offset, len, 0); + + if (nsg <= 0) + return nsg; + + sg_mark_end(&sg[nsg - 1]); + + return nsg; +} +EXPORT_SYMBOL_GPL(skb_to_sgvec); + +/* As compared with skb_to_sgvec, skb_to_sgvec_nomark only map skb to given + * sglist without mark the sg which contain last skb data as the end. + * So the caller can mannipulate sg list as will when padding new data after + * the first call without calling sg_unmark_end to expend sg list. + * + * Scenario to use skb_to_sgvec_nomark: + * 1. sg_init_table + * 2. skb_to_sgvec_nomark(payload1) + * 3. skb_to_sgvec_nomark(payload2) + * + * This is equivalent to: + * 1. sg_init_table + * 2. skb_to_sgvec(payload1) + * 3. sg_unmark_end + * 4. skb_to_sgvec(payload2) + * + * When mapping mutilple payload conditionally, skb_to_sgvec_nomark + * is more preferable. + */ +int skb_to_sgvec_nomark(struct sk_buff *skb, struct scatterlist *sg, + int offset, int len) +{ + return __skb_to_sgvec(skb, sg, offset, len, 0); +} +EXPORT_SYMBOL_GPL(skb_to_sgvec_nomark); + + + +/** + * skb_cow_data - Check that a socket buffer's data buffers are writable + * @skb: The socket buffer to check. + * @tailbits: Amount of trailing space to be added + * @trailer: Returned pointer to the skb where the @tailbits space begins + * + * Make sure that the data buffers attached to a socket buffer are + * writable. If they are not, private copies are made of the data buffers + * and the socket buffer is set to use these instead. + * + * If @tailbits is given, make sure that there is space to write @tailbits + * bytes of data beyond current end of socket buffer. @trailer will be + * set to point to the skb in which this space begins. + * + * The number of scatterlist elements required to completely map the + * COW'd and extended socket buffer will be returned. + */ +int skb_cow_data(struct sk_buff *skb, int tailbits, struct sk_buff **trailer) +{ + int copyflag; + int elt; + struct sk_buff *skb1, **skb_p; + + /* If skb is cloned or its head is paged, reallocate + * head pulling out all the pages (pages are considered not writable + * at the moment even if they are anonymous). + */ + if ((skb_cloned(skb) || skb_shinfo(skb)->nr_frags) && + !__pskb_pull_tail(skb, __skb_pagelen(skb))) + return -ENOMEM; + + /* Easy case. Most of packets will go this way. */ + if (!skb_has_frag_list(skb)) { + /* A little of trouble, not enough of space for trailer. + * This should not happen, when stack is tuned to generate + * good frames. OK, on miss we reallocate and reserve even more + * space, 128 bytes is fair. 
*/ + + if (skb_tailroom(skb) < tailbits && + pskb_expand_head(skb, 0, tailbits-skb_tailroom(skb)+128, GFP_ATOMIC)) + return -ENOMEM; + + /* Voila! */ + *trailer = skb; + return 1; + } + + /* Misery. We are in troubles, going to mincer fragments... */ + + elt = 1; + skb_p = &skb_shinfo(skb)->frag_list; + copyflag = 0; + + while ((skb1 = *skb_p) != NULL) { + int ntail = 0; + + /* The fragment is partially pulled by someone, + * this can happen on input. Copy it and everything + * after it. */ + + if (skb_shared(skb1)) + copyflag = 1; + + /* If the skb is the last, worry about trailer. */ + + if (skb1->next == NULL && tailbits) { + if (skb_shinfo(skb1)->nr_frags || + skb_has_frag_list(skb1) || + skb_tailroom(skb1) < tailbits) + ntail = tailbits + 128; + } + + if (copyflag || + skb_cloned(skb1) || + ntail || + skb_shinfo(skb1)->nr_frags || + skb_has_frag_list(skb1)) { + struct sk_buff *skb2; + + /* Fuck, we are miserable poor guys... */ + if (ntail == 0) + skb2 = skb_copy(skb1, GFP_ATOMIC); + else + skb2 = skb_copy_expand(skb1, + skb_headroom(skb1), + ntail, + GFP_ATOMIC); + if (unlikely(skb2 == NULL)) + return -ENOMEM; + + if (skb1->sk) + skb_set_owner_w(skb2, skb1->sk); + + /* Looking around. Are we still alive? + * OK, link new skb, drop old one */ + + skb2->next = skb1->next; + *skb_p = skb2; + kfree_skb(skb1); + skb1 = skb2; + } + elt++; + *trailer = skb1; + skb_p = &skb1->next; + } + + return elt; +} +EXPORT_SYMBOL_GPL(skb_cow_data); + +static void sock_rmem_free(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + + atomic_sub(skb->truesize, &sk->sk_rmem_alloc); +} + +static void skb_set_err_queue(struct sk_buff *skb) +{ + /* pkt_type of skbs received on local sockets is never PACKET_OUTGOING. + * So, it is safe to (mis)use it to mark skbs on the error queue. + */ + skb->pkt_type = PACKET_OUTGOING; + BUILD_BUG_ON(PACKET_OUTGOING == 0); +} + +/* + * Note: We dont mem charge error packets (no sk_forward_alloc changes) + */ +int sock_queue_err_skb(struct sock *sk, struct sk_buff *skb) +{ + if (atomic_read(&sk->sk_rmem_alloc) + skb->truesize >= + (unsigned int)READ_ONCE(sk->sk_rcvbuf)) + return -ENOMEM; + + skb_orphan(skb); + skb->sk = sk; + skb->destructor = sock_rmem_free; + atomic_add(skb->truesize, &sk->sk_rmem_alloc); + skb_set_err_queue(skb); + + /* before exiting rcu section, make sure dst is refcounted */ + skb_dst_force(skb); + + skb_queue_tail(&sk->sk_error_queue, skb); + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + return 0; +} +EXPORT_SYMBOL(sock_queue_err_skb); + +static bool is_icmp_err_skb(const struct sk_buff *skb) +{ + return skb && (SKB_EXT_ERR(skb)->ee.ee_origin == SO_EE_ORIGIN_ICMP || + SKB_EXT_ERR(skb)->ee.ee_origin == SO_EE_ORIGIN_ICMP6); +} + +struct sk_buff *sock_dequeue_err_skb(struct sock *sk) +{ + struct sk_buff_head *q = &sk->sk_error_queue; + struct sk_buff *skb, *skb_next = NULL; + bool icmp_next = false; + unsigned long flags; + + spin_lock_irqsave(&q->lock, flags); + skb = __skb_dequeue(q); + if (skb && (skb_next = skb_peek(q))) { + icmp_next = is_icmp_err_skb(skb_next); + if (icmp_next) + sk->sk_err = SKB_EXT_ERR(skb_next)->ee.ee_errno; + } + spin_unlock_irqrestore(&q->lock, flags); + + if (is_icmp_err_skb(skb) && !icmp_next) + sk->sk_err = 0; + + if (skb_next) + sk_error_report(sk); + + return skb; +} +EXPORT_SYMBOL(sock_dequeue_err_skb); + +/** + * skb_clone_sk - create clone of skb, and take reference to socket + * @skb: the skb to clone + * + * This function creates a clone of a buffer that holds a reference on + * sk_refcnt. 
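Sketch (illustrative, not part of the patched file), loosely modelled on how IPsec-style users pair skb_cow_data() with skb_to_sgvec(): make the buffers writable, then map the whole skb into a scatterlist. Allocation strategy and GFP flag are illustrative:

#include <linux/scatterlist.h>
#include <linux/skbuff.h>
#include <linux/slab.h>

/* Hypothetical sketch: COW the skb (reserving @tailbits of trailer room)
 * and map skb->len bytes into a scatterlist sized from skb_cow_data()'s
 * return value. The trailer itself is left unmapped here.
 */
static struct scatterlist *example_map_skb(struct sk_buff *skb, int tailbits)
{
        struct sk_buff *trailer;
        struct scatterlist *sg;
        int nsg;

        nsg = skb_cow_data(skb, tailbits, &trailer);
        if (nsg < 0)
                return NULL;

        sg = kmalloc_array(nsg, sizeof(*sg), GFP_ATOMIC);
        if (!sg)
                return NULL;

        sg_init_table(sg, nsg);
        if (skb_to_sgvec(skb, sg, 0, skb->len) < 0) {
                kfree(sg);
                return NULL;
        }
        return sg;
}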
Buffers created via this function are meant to be + * returned using sock_queue_err_skb, or free via kfree_skb. + * + * When passing buffers allocated with this function to sock_queue_err_skb + * it is necessary to wrap the call with sock_hold/sock_put in order to + * prevent the socket from being released prior to being enqueued on + * the sk_error_queue. + */ +struct sk_buff *skb_clone_sk(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + struct sk_buff *clone; + + if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt)) + return NULL; + + clone = skb_clone(skb, GFP_ATOMIC); + if (!clone) { + sock_put(sk); + return NULL; + } + + clone->sk = sk; + clone->destructor = sock_efree; + + return clone; +} +EXPORT_SYMBOL(skb_clone_sk); + +static void __skb_complete_tx_timestamp(struct sk_buff *skb, + struct sock *sk, + int tstype, + bool opt_stats) +{ + struct sock_exterr_skb *serr; + int err; + + BUILD_BUG_ON(sizeof(struct sock_exterr_skb) > sizeof(skb->cb)); + + serr = SKB_EXT_ERR(skb); + memset(serr, 0, sizeof(*serr)); + serr->ee.ee_errno = ENOMSG; + serr->ee.ee_origin = SO_EE_ORIGIN_TIMESTAMPING; + serr->ee.ee_info = tstype; + serr->opt_stats = opt_stats; + serr->header.h4.iif = skb->dev ? skb->dev->ifindex : 0; + if (READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID) { + serr->ee.ee_data = skb_shinfo(skb)->tskey; + if (sk_is_tcp(sk)) + serr->ee.ee_data -= atomic_read(&sk->sk_tskey); + } + + err = sock_queue_err_skb(sk, skb); + + if (err) + kfree_skb(skb); +} + +static bool skb_may_tx_timestamp(struct sock *sk, bool tsonly) +{ + bool ret; + + if (likely(READ_ONCE(sysctl_tstamp_allow_data) || tsonly)) + return true; + + read_lock_bh(&sk->sk_callback_lock); + ret = sk->sk_socket && sk->sk_socket->file && + file_ns_capable(sk->sk_socket->file, &init_user_ns, CAP_NET_RAW); + read_unlock_bh(&sk->sk_callback_lock); + return ret; +} + +void skb_complete_tx_timestamp(struct sk_buff *skb, + struct skb_shared_hwtstamps *hwtstamps) +{ + struct sock *sk = skb->sk; + + if (!skb_may_tx_timestamp(sk, false)) + goto err; + + /* Take a reference to prevent skb_orphan() from freeing the socket, + * but only if the socket refcount is not zero. 
+ */ + if (likely(refcount_inc_not_zero(&sk->sk_refcnt))) { + *skb_hwtstamps(skb) = *hwtstamps; + __skb_complete_tx_timestamp(skb, sk, SCM_TSTAMP_SND, false); + sock_put(sk); + return; + } + +err: + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(skb_complete_tx_timestamp); + +void __skb_tstamp_tx(struct sk_buff *orig_skb, + const struct sk_buff *ack_skb, + struct skb_shared_hwtstamps *hwtstamps, + struct sock *sk, int tstype) +{ + struct sk_buff *skb; + bool tsonly, opt_stats = false; + u32 tsflags; + + if (!sk) + return; + + tsflags = READ_ONCE(sk->sk_tsflags); + if (!hwtstamps && !(tsflags & SOF_TIMESTAMPING_OPT_TX_SWHW) && + skb_shinfo(orig_skb)->tx_flags & SKBTX_IN_PROGRESS) + return; + + tsonly = tsflags & SOF_TIMESTAMPING_OPT_TSONLY; + if (!skb_may_tx_timestamp(sk, tsonly)) + return; + + if (tsonly) { +#ifdef CONFIG_INET + if ((tsflags & SOF_TIMESTAMPING_OPT_STATS) && + sk_is_tcp(sk)) { + skb = tcp_get_timestamping_opt_stats(sk, orig_skb, + ack_skb); + opt_stats = true; + } else +#endif + skb = alloc_skb(0, GFP_ATOMIC); + } else { + skb = skb_clone(orig_skb, GFP_ATOMIC); + + if (skb_orphan_frags_rx(skb, GFP_ATOMIC)) { + kfree_skb(skb); + return; + } + } + if (!skb) + return; + + if (tsonly) { + skb_shinfo(skb)->tx_flags |= skb_shinfo(orig_skb)->tx_flags & + SKBTX_ANY_TSTAMP; + skb_shinfo(skb)->tskey = skb_shinfo(orig_skb)->tskey; + } + + if (hwtstamps) + *skb_hwtstamps(skb) = *hwtstamps; + else + __net_timestamp(skb); + + __skb_complete_tx_timestamp(skb, sk, tstype, opt_stats); +} +EXPORT_SYMBOL_GPL(__skb_tstamp_tx); + +void skb_tstamp_tx(struct sk_buff *orig_skb, + struct skb_shared_hwtstamps *hwtstamps) +{ + return __skb_tstamp_tx(orig_skb, NULL, hwtstamps, orig_skb->sk, + SCM_TSTAMP_SND); +} +EXPORT_SYMBOL_GPL(skb_tstamp_tx); + +void skb_complete_wifi_ack(struct sk_buff *skb, bool acked) +{ + struct sock *sk = skb->sk; + struct sock_exterr_skb *serr; + int err = 1; + + skb->wifi_acked_valid = 1; + skb->wifi_acked = acked; + + serr = SKB_EXT_ERR(skb); + memset(serr, 0, sizeof(*serr)); + serr->ee.ee_errno = ENOMSG; + serr->ee.ee_origin = SO_EE_ORIGIN_TXSTATUS; + + /* Take a reference to prevent skb_orphan() from freeing the socket, + * but only if the socket refcount is not zero. + */ + if (likely(refcount_inc_not_zero(&sk->sk_refcnt))) { + err = sock_queue_err_skb(sk, skb); + sock_put(sk); + } + if (err) + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(skb_complete_wifi_ack); + +/** + * skb_partial_csum_set - set up and verify partial csum values for packet + * @skb: the skb to set + * @start: the number of bytes after skb->data to start checksumming. + * @off: the offset from start to place the checksum. + * + * For untrusted partially-checksummed packets, we need to make sure the values + * for skb->csum_start and skb->csum_offset are valid so we don't oops. + * + * This function checks and sets those values and skb->ip_summed: if this + * returns false you should drop the packet. 
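+ *
+ * Illustrative use (sketch only; csum_start and csum_offset stand for
+ * values taken from untrusted device or guest metadata):
+ *
+ *	if (!skb_partial_csum_set(skb, csum_start, csum_offset)) {
+ *		kfree_skb(skb);
+ *		return -EINVAL;
+ *	}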
+ */ +bool skb_partial_csum_set(struct sk_buff *skb, u16 start, u16 off) +{ + u32 csum_end = (u32)start + (u32)off + sizeof(__sum16); + u32 csum_start = skb_headroom(skb) + (u32)start; + + if (unlikely(csum_start >= U16_MAX || csum_end > skb_headlen(skb))) { + net_warn_ratelimited("bad partial csum: csum=%u/%u headroom=%u headlen=%u\n", + start, off, skb_headroom(skb), skb_headlen(skb)); + return false; + } + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = csum_start; + skb->csum_offset = off; + skb->transport_header = csum_start; + return true; +} +EXPORT_SYMBOL_GPL(skb_partial_csum_set); + +static int skb_maybe_pull_tail(struct sk_buff *skb, unsigned int len, + unsigned int max) +{ + if (skb_headlen(skb) >= len) + return 0; + + /* If we need to pullup then pullup to the max, so we + * won't need to do it again. + */ + if (max > skb->len) + max = skb->len; + + if (__pskb_pull_tail(skb, max - skb_headlen(skb)) == NULL) + return -ENOMEM; + + if (skb_headlen(skb) < len) + return -EPROTO; + + return 0; +} + +#define MAX_TCP_HDR_LEN (15 * 4) + +static __sum16 *skb_checksum_setup_ip(struct sk_buff *skb, + typeof(IPPROTO_IP) proto, + unsigned int off) +{ + int err; + + switch (proto) { + case IPPROTO_TCP: + err = skb_maybe_pull_tail(skb, off + sizeof(struct tcphdr), + off + MAX_TCP_HDR_LEN); + if (!err && !skb_partial_csum_set(skb, off, + offsetof(struct tcphdr, + check))) + err = -EPROTO; + return err ? ERR_PTR(err) : &tcp_hdr(skb)->check; + + case IPPROTO_UDP: + err = skb_maybe_pull_tail(skb, off + sizeof(struct udphdr), + off + sizeof(struct udphdr)); + if (!err && !skb_partial_csum_set(skb, off, + offsetof(struct udphdr, + check))) + err = -EPROTO; + return err ? ERR_PTR(err) : &udp_hdr(skb)->check; + } + + return ERR_PTR(-EPROTO); +} + +/* This value should be large enough to cover a tagged ethernet header plus + * maximally sized IP and TCP or UDP headers. + */ +#define MAX_IP_HDR_LEN 128 + +static int skb_checksum_setup_ipv4(struct sk_buff *skb, bool recalculate) +{ + unsigned int off; + bool fragment; + __sum16 *csum; + int err; + + fragment = false; + + err = skb_maybe_pull_tail(skb, + sizeof(struct iphdr), + MAX_IP_HDR_LEN); + if (err < 0) + goto out; + + if (ip_is_fragment(ip_hdr(skb))) + fragment = true; + + off = ip_hdrlen(skb); + + err = -EPROTO; + + if (fragment) + goto out; + + csum = skb_checksum_setup_ip(skb, ip_hdr(skb)->protocol, off); + if (IS_ERR(csum)) + return PTR_ERR(csum); + + if (recalculate) + *csum = ~csum_tcpudp_magic(ip_hdr(skb)->saddr, + ip_hdr(skb)->daddr, + skb->len - off, + ip_hdr(skb)->protocol, 0); + err = 0; + +out: + return err; +} + +/* This value should be large enough to cover a tagged ethernet header plus + * an IPv6 header, all options, and a maximal TCP or UDP header. 
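+ * (For example: 18 bytes for a VLAN-tagged ethernet header, 40 bytes for
+ * the fixed IPv6 header and up to 60 bytes for a maximal TCP header add
+ * up to 118 bytes, leaving the rest of the 256 bytes for extension
+ * headers.)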
+ */ +#define MAX_IPV6_HDR_LEN 256 + +#define OPT_HDR(type, skb, off) \ + (type *)(skb_network_header(skb) + (off)) + +static int skb_checksum_setup_ipv6(struct sk_buff *skb, bool recalculate) +{ + int err; + u8 nexthdr; + unsigned int off; + unsigned int len; + bool fragment; + bool done; + __sum16 *csum; + + fragment = false; + done = false; + + off = sizeof(struct ipv6hdr); + + err = skb_maybe_pull_tail(skb, off, MAX_IPV6_HDR_LEN); + if (err < 0) + goto out; + + nexthdr = ipv6_hdr(skb)->nexthdr; + + len = sizeof(struct ipv6hdr) + ntohs(ipv6_hdr(skb)->payload_len); + while (off <= len && !done) { + switch (nexthdr) { + case IPPROTO_DSTOPTS: + case IPPROTO_HOPOPTS: + case IPPROTO_ROUTING: { + struct ipv6_opt_hdr *hp; + + err = skb_maybe_pull_tail(skb, + off + + sizeof(struct ipv6_opt_hdr), + MAX_IPV6_HDR_LEN); + if (err < 0) + goto out; + + hp = OPT_HDR(struct ipv6_opt_hdr, skb, off); + nexthdr = hp->nexthdr; + off += ipv6_optlen(hp); + break; + } + case IPPROTO_AH: { + struct ip_auth_hdr *hp; + + err = skb_maybe_pull_tail(skb, + off + + sizeof(struct ip_auth_hdr), + MAX_IPV6_HDR_LEN); + if (err < 0) + goto out; + + hp = OPT_HDR(struct ip_auth_hdr, skb, off); + nexthdr = hp->nexthdr; + off += ipv6_authlen(hp); + break; + } + case IPPROTO_FRAGMENT: { + struct frag_hdr *hp; + + err = skb_maybe_pull_tail(skb, + off + + sizeof(struct frag_hdr), + MAX_IPV6_HDR_LEN); + if (err < 0) + goto out; + + hp = OPT_HDR(struct frag_hdr, skb, off); + + if (hp->frag_off & htons(IP6_OFFSET | IP6_MF)) + fragment = true; + + nexthdr = hp->nexthdr; + off += sizeof(struct frag_hdr); + break; + } + default: + done = true; + break; + } + } + + err = -EPROTO; + + if (!done || fragment) + goto out; + + csum = skb_checksum_setup_ip(skb, nexthdr, off); + if (IS_ERR(csum)) + return PTR_ERR(csum); + + if (recalculate) + *csum = ~csum_ipv6_magic(&ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, + skb->len - off, nexthdr, 0); + err = 0; + +out: + return err; +} + +/** + * skb_checksum_setup - set up partial checksum offset + * @skb: the skb to set up + * @recalculate: if true the pseudo-header checksum will be recalculated + */ +int skb_checksum_setup(struct sk_buff *skb, bool recalculate) +{ + int err; + + switch (skb->protocol) { + case htons(ETH_P_IP): + err = skb_checksum_setup_ipv4(skb, recalculate); + break; + + case htons(ETH_P_IPV6): + err = skb_checksum_setup_ipv6(skb, recalculate); + break; + + default: + err = -EPROTO; + break; + } + + return err; +} +EXPORT_SYMBOL(skb_checksum_setup); + +/** + * skb_checksum_maybe_trim - maybe trims the given skb + * @skb: the skb to check + * @transport_len: the data length beyond the network header + * + * Checks whether the given skb has data beyond the given transport length. + * If so, returns a cloned skb trimmed to this transport length. + * Otherwise returns the provided skb. Returns NULL in error cases + * (e.g. transport_len exceeds skb length or out-of-memory). + * + * Caller needs to set the skb transport header and free any returned skb if it + * differs from the provided skb. 
+ */ +static struct sk_buff *skb_checksum_maybe_trim(struct sk_buff *skb, + unsigned int transport_len) +{ + struct sk_buff *skb_chk; + unsigned int len = skb_transport_offset(skb) + transport_len; + int ret; + + if (skb->len < len) + return NULL; + else if (skb->len == len) + return skb; + + skb_chk = skb_clone(skb, GFP_ATOMIC); + if (!skb_chk) + return NULL; + + ret = pskb_trim_rcsum(skb_chk, len); + if (ret) { + kfree_skb(skb_chk); + return NULL; + } + + return skb_chk; +} + +/** + * skb_checksum_trimmed - validate checksum of an skb + * @skb: the skb to check + * @transport_len: the data length beyond the network header + * @skb_chkf: checksum function to use + * + * Applies the given checksum function skb_chkf to the provided skb. + * Returns a checked and maybe trimmed skb. Returns NULL on error. + * + * If the skb has data beyond the given transport length, then a + * trimmed & cloned skb is checked and returned. + * + * Caller needs to set the skb transport header and free any returned skb if it + * differs from the provided skb. + */ +struct sk_buff *skb_checksum_trimmed(struct sk_buff *skb, + unsigned int transport_len, + __sum16(*skb_chkf)(struct sk_buff *skb)) +{ + struct sk_buff *skb_chk; + unsigned int offset = skb_transport_offset(skb); + __sum16 ret; + + skb_chk = skb_checksum_maybe_trim(skb, transport_len); + if (!skb_chk) + goto err; + + if (!pskb_may_pull(skb_chk, offset)) + goto err; + + skb_pull_rcsum(skb_chk, offset); + ret = skb_chkf(skb_chk); + skb_push_rcsum(skb_chk, offset); + + if (ret) + goto err; + + return skb_chk; + +err: + if (skb_chk && skb_chk != skb) + kfree_skb(skb_chk); + + return NULL; + +} +EXPORT_SYMBOL(skb_checksum_trimmed); + +void __skb_warn_lro_forwarding(const struct sk_buff *skb) +{ + net_warn_ratelimited("%s: received packets cannot be forwarded while LRO is enabled\n", + skb->dev->name); +} +EXPORT_SYMBOL(__skb_warn_lro_forwarding); + +void kfree_skb_partial(struct sk_buff *skb, bool head_stolen) +{ + if (head_stolen) { + skb_release_head_state(skb); + kmem_cache_free(skbuff_head_cache, skb); + } else { + __kfree_skb(skb); + } +} +EXPORT_SYMBOL(kfree_skb_partial); + +/** + * skb_try_coalesce - try to merge skb to prior one + * @to: prior buffer + * @from: buffer to add + * @fragstolen: pointer to boolean + * @delta_truesize: how much more was allocated than was requested + */ +bool skb_try_coalesce(struct sk_buff *to, struct sk_buff *from, + bool *fragstolen, int *delta_truesize) +{ + struct skb_shared_info *to_shinfo, *from_shinfo; + int i, delta, len = from->len; + + *fragstolen = false; + + if (skb_cloned(to)) + return false; + + /* In general, avoid mixing page_pool and non-page_pool allocated + * pages within the same SKB. Additionally avoid dealing with clones + * with page_pool pages, in case the SKB is using page_pool fragment + * references (PP_FLAG_PAGE_FRAG). Since we only take full page + * references for cloned SKBs at the moment that would result in + * inconsistent reference counts. + * In theory we could take full references if @from is cloned and + * !@to->pp_recycle but its tricky (due to potential race with + * the clone disappearing) and rare, so not worth dealing with. 
+ */ + if (to->pp_recycle != from->pp_recycle || + (from->pp_recycle && skb_cloned(from))) + return false; + + if (len <= skb_tailroom(to)) { + if (len) + BUG_ON(skb_copy_bits(from, 0, skb_put(to, len), len)); + *delta_truesize = 0; + return true; + } + + to_shinfo = skb_shinfo(to); + from_shinfo = skb_shinfo(from); + if (to_shinfo->frag_list || from_shinfo->frag_list) + return false; + if (skb_zcopy(to) || skb_zcopy(from)) + return false; + + if (skb_headlen(from) != 0) { + struct page *page; + unsigned int offset; + + if (to_shinfo->nr_frags + + from_shinfo->nr_frags >= MAX_SKB_FRAGS) + return false; + + if (skb_head_is_locked(from)) + return false; + + delta = from->truesize - SKB_DATA_ALIGN(sizeof(struct sk_buff)); + + page = virt_to_head_page(from->head); + offset = from->data - (unsigned char *)page_address(page); + + skb_fill_page_desc(to, to_shinfo->nr_frags, + page, offset, skb_headlen(from)); + *fragstolen = true; + } else { + if (to_shinfo->nr_frags + + from_shinfo->nr_frags > MAX_SKB_FRAGS) + return false; + + delta = from->truesize - SKB_TRUESIZE(skb_end_offset(from)); + } + + WARN_ON_ONCE(delta < len); + + memcpy(to_shinfo->frags + to_shinfo->nr_frags, + from_shinfo->frags, + from_shinfo->nr_frags * sizeof(skb_frag_t)); + to_shinfo->nr_frags += from_shinfo->nr_frags; + + if (!skb_cloned(from)) + from_shinfo->nr_frags = 0; + + /* if the skb is not cloned this does nothing + * since we set nr_frags to 0. + */ + for (i = 0; i < from_shinfo->nr_frags; i++) + __skb_frag_ref(&from_shinfo->frags[i]); + + to->truesize += delta; + to->len += len; + to->data_len += len; + + *delta_truesize = delta; + return true; +} +EXPORT_SYMBOL(skb_try_coalesce); + +/** + * skb_scrub_packet - scrub an skb + * + * @skb: buffer to clean + * @xnet: packet is crossing netns + * + * skb_scrub_packet can be used after encapsulating or decapsulting a packet + * into/from a tunnel. Some information have to be cleared during these + * operations. + * skb_scrub_packet can also be used to clean a skb before injecting it in + * another namespace (@xnet == true). We have to clear all information in the + * skb that could impact namespace isolation. + */ +void skb_scrub_packet(struct sk_buff *skb, bool xnet) +{ + skb->pkt_type = PACKET_HOST; + skb->skb_iif = 0; + skb->ignore_df = 0; + skb_dst_drop(skb); + skb_ext_reset(skb); + nf_reset_ct(skb); + nf_reset_trace(skb); + +#ifdef CONFIG_NET_SWITCHDEV + skb->offload_fwd_mark = 0; + skb->offload_l3_fwd_mark = 0; +#endif + + if (!xnet) + return; + + ipvs_reset(skb); + skb->mark = 0; + skb_clear_tstamp(skb); +} +EXPORT_SYMBOL_GPL(skb_scrub_packet); + +/** + * skb_gso_transport_seglen - Return length of individual segments of a gso packet + * + * @skb: GSO skb + * + * skb_gso_transport_seglen is used to determine the real size of the + * individual segments, including Layer4 headers (TCP/UDP). + * + * The MAC/L2 or network (IP, IPv6) headers are not accounted for. 
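+ *
+ * (Illustrative numbers: a TCP GSO skb with a 20 byte TCP header and a
+ * gso_size of 1448 yields 20 + 1448 = 1468 bytes per segment at the
+ * transport layer.)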
+ */ +static unsigned int skb_gso_transport_seglen(const struct sk_buff *skb) +{ + const struct skb_shared_info *shinfo = skb_shinfo(skb); + unsigned int thlen = 0; + + if (skb->encapsulation) { + thlen = skb_inner_transport_header(skb) - + skb_transport_header(skb); + + if (likely(shinfo->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6))) + thlen += inner_tcp_hdrlen(skb); + } else if (likely(shinfo->gso_type & (SKB_GSO_TCPV4 | SKB_GSO_TCPV6))) { + thlen = tcp_hdrlen(skb); + } else if (unlikely(skb_is_gso_sctp(skb))) { + thlen = sizeof(struct sctphdr); + } else if (shinfo->gso_type & SKB_GSO_UDP_L4) { + thlen = sizeof(struct udphdr); + } + /* UFO sets gso_size to the size of the fragmentation + * payload, i.e. the size of the L4 (UDP) header is already + * accounted for. + */ + return thlen + shinfo->gso_size; +} + +/** + * skb_gso_network_seglen - Return length of individual segments of a gso packet + * + * @skb: GSO skb + * + * skb_gso_network_seglen is used to determine the real size of the + * individual segments, including Layer3 (IP, IPv6) and L4 headers (TCP/UDP). + * + * The MAC/L2 header is not accounted for. + */ +static unsigned int skb_gso_network_seglen(const struct sk_buff *skb) +{ + unsigned int hdr_len = skb_transport_header(skb) - + skb_network_header(skb); + + return hdr_len + skb_gso_transport_seglen(skb); +} + +/** + * skb_gso_mac_seglen - Return length of individual segments of a gso packet + * + * @skb: GSO skb + * + * skb_gso_mac_seglen is used to determine the real size of the + * individual segments, including MAC/L2, Layer3 (IP, IPv6) and L4 + * headers (TCP/UDP). + */ +static unsigned int skb_gso_mac_seglen(const struct sk_buff *skb) +{ + unsigned int hdr_len = skb_transport_header(skb) - skb_mac_header(skb); + + return hdr_len + skb_gso_transport_seglen(skb); +} + +/** + * skb_gso_size_check - check the skb size, considering GSO_BY_FRAGS + * + * There are a couple of instances where we have a GSO skb, and we + * want to determine what size it would be after it is segmented. + * + * We might want to check: + * - L3+L4+payload size (e.g. IP forwarding) + * - L2+L3+L4+payload size (e.g. sanity check before passing to driver) + * + * This is a helper to do that correctly considering GSO_BY_FRAGS. + * + * @skb: GSO skb + * + * @seg_len: The segmented length (from skb_gso_*_seglen). In the + * GSO_BY_FRAGS case this will be [header sizes + GSO_BY_FRAGS]. + * + * @max_len: The maximum permissible length. + * + * Returns true if the segmented length <= max length. + */ +static inline bool skb_gso_size_check(const struct sk_buff *skb, + unsigned int seg_len, + unsigned int max_len) { + const struct skb_shared_info *shinfo = skb_shinfo(skb); + const struct sk_buff *iter; + + if (shinfo->gso_size != GSO_BY_FRAGS) + return seg_len <= max_len; + + /* Undo this so we can re-use header sizes */ + seg_len -= GSO_BY_FRAGS; + + skb_walk_frags(skb, iter) { + if (seg_len + skb_headlen(iter) > max_len) + return false; + } + + return true; +} + +/** + * skb_gso_validate_network_len - Will a split GSO skb fit into a given MTU? + * + * @skb: GSO skb + * @mtu: MTU to validate against + * + * skb_gso_validate_network_len validates if a given skb will fit a + * wanted MTU once split. It considers L3 headers, L4 headers, and the + * payload. 
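+ *
+ * Illustrative use (sketch of a forwarding-style MTU check; mtu and the
+ * drop label are caller-provided):
+ *
+ *	if (skb_is_gso(skb) && !skb_gso_validate_network_len(skb, mtu))
+ *		goto drop;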
+ */ +bool skb_gso_validate_network_len(const struct sk_buff *skb, unsigned int mtu) +{ + return skb_gso_size_check(skb, skb_gso_network_seglen(skb), mtu); +} +EXPORT_SYMBOL_GPL(skb_gso_validate_network_len); + +/** + * skb_gso_validate_mac_len - Will a split GSO skb fit in a given length? + * + * @skb: GSO skb + * @len: length to validate against + * + * skb_gso_validate_mac_len validates if a given skb will fit a wanted + * length once split, including L2, L3 and L4 headers and the payload. + */ +bool skb_gso_validate_mac_len(const struct sk_buff *skb, unsigned int len) +{ + return skb_gso_size_check(skb, skb_gso_mac_seglen(skb), len); +} +EXPORT_SYMBOL_GPL(skb_gso_validate_mac_len); + +static struct sk_buff *skb_reorder_vlan_header(struct sk_buff *skb) +{ + int mac_len, meta_len; + void *meta; + + if (skb_cow(skb, skb_headroom(skb)) < 0) { + kfree_skb(skb); + return NULL; + } + + mac_len = skb->data - skb_mac_header(skb); + if (likely(mac_len > VLAN_HLEN + ETH_TLEN)) { + memmove(skb_mac_header(skb) + VLAN_HLEN, skb_mac_header(skb), + mac_len - VLAN_HLEN - ETH_TLEN); + } + + meta_len = skb_metadata_len(skb); + if (meta_len) { + meta = skb_metadata_end(skb) - meta_len; + memmove(meta + VLAN_HLEN, meta, meta_len); + } + + skb->mac_header += VLAN_HLEN; + return skb; +} + +struct sk_buff *skb_vlan_untag(struct sk_buff *skb) +{ + struct vlan_hdr *vhdr; + u16 vlan_tci; + + if (unlikely(skb_vlan_tag_present(skb))) { + /* vlan_tci is already set-up so leave this for another time */ + return skb; + } + + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + goto err_free; + /* We may access the two bytes after vlan_hdr in vlan_set_encap_proto(). */ + if (unlikely(!pskb_may_pull(skb, VLAN_HLEN + sizeof(unsigned short)))) + goto err_free; + + vhdr = (struct vlan_hdr *)skb->data; + vlan_tci = ntohs(vhdr->h_vlan_TCI); + __vlan_hwaccel_put_tag(skb, skb->protocol, vlan_tci); + + skb_pull_rcsum(skb, VLAN_HLEN); + vlan_set_encap_proto(skb, vhdr); + + skb = skb_reorder_vlan_header(skb); + if (unlikely(!skb)) + goto err_free; + + skb_reset_network_header(skb); + if (!skb_transport_header_was_set(skb)) + skb_reset_transport_header(skb); + skb_reset_mac_len(skb); + + return skb; + +err_free: + kfree_skb(skb); + return NULL; +} +EXPORT_SYMBOL(skb_vlan_untag); + +int skb_ensure_writable(struct sk_buff *skb, unsigned int write_len) +{ + if (!pskb_may_pull(skb, write_len)) + return -ENOMEM; + + if (!skb_cloned(skb) || skb_clone_writable(skb, write_len)) + return 0; + + return pskb_expand_head(skb, 0, 0, GFP_ATOMIC); +} +EXPORT_SYMBOL(skb_ensure_writable); + +/* remove VLAN header from packet and update csum accordingly. 
+ * expects a non skb_vlan_tag_present skb with a vlan tag payload + */ +int __skb_vlan_pop(struct sk_buff *skb, u16 *vlan_tci) +{ + struct vlan_hdr *vhdr; + int offset = skb->data - skb_mac_header(skb); + int err; + + if (WARN_ONCE(offset, + "__skb_vlan_pop got skb with skb->data not at mac header (offset %d)\n", + offset)) { + return -EINVAL; + } + + err = skb_ensure_writable(skb, VLAN_ETH_HLEN); + if (unlikely(err)) + return err; + + skb_postpull_rcsum(skb, skb->data + (2 * ETH_ALEN), VLAN_HLEN); + + vhdr = (struct vlan_hdr *)(skb->data + ETH_HLEN); + *vlan_tci = ntohs(vhdr->h_vlan_TCI); + + memmove(skb->data + VLAN_HLEN, skb->data, 2 * ETH_ALEN); + __skb_pull(skb, VLAN_HLEN); + + vlan_set_encap_proto(skb, vhdr); + skb->mac_header += VLAN_HLEN; + + if (skb_network_offset(skb) < ETH_HLEN) + skb_set_network_header(skb, ETH_HLEN); + + skb_reset_mac_len(skb); + + return err; +} +EXPORT_SYMBOL(__skb_vlan_pop); + +/* Pop a vlan tag either from hwaccel or from payload. + * Expects skb->data at mac header. + */ +int skb_vlan_pop(struct sk_buff *skb) +{ + u16 vlan_tci; + __be16 vlan_proto; + int err; + + if (likely(skb_vlan_tag_present(skb))) { + __vlan_hwaccel_clear_tag(skb); + } else { + if (unlikely(!eth_type_vlan(skb->protocol))) + return 0; + + err = __skb_vlan_pop(skb, &vlan_tci); + if (err) + return err; + } + /* move next vlan tag to hw accel tag */ + if (likely(!eth_type_vlan(skb->protocol))) + return 0; + + vlan_proto = skb->protocol; + err = __skb_vlan_pop(skb, &vlan_tci); + if (unlikely(err)) + return err; + + __vlan_hwaccel_put_tag(skb, vlan_proto, vlan_tci); + return 0; +} +EXPORT_SYMBOL(skb_vlan_pop); + +/* Push a vlan tag either into hwaccel or into payload (if hwaccel tag present). + * Expects skb->data at mac header. + */ +int skb_vlan_push(struct sk_buff *skb, __be16 vlan_proto, u16 vlan_tci) +{ + if (skb_vlan_tag_present(skb)) { + int offset = skb->data - skb_mac_header(skb); + int err; + + if (WARN_ONCE(offset, + "skb_vlan_push got skb with skb->data not at mac header (offset %d)\n", + offset)) { + return -EINVAL; + } + + err = __vlan_insert_tag(skb, skb->vlan_proto, + skb_vlan_tag_get(skb)); + if (err) + return err; + + skb->protocol = skb->vlan_proto; + skb->mac_len += VLAN_HLEN; + + skb_postpush_rcsum(skb, skb->data + (2 * ETH_ALEN), VLAN_HLEN); + } + __vlan_hwaccel_put_tag(skb, vlan_proto, vlan_tci); + return 0; +} +EXPORT_SYMBOL(skb_vlan_push); + +/** + * skb_eth_pop() - Drop the Ethernet header at the head of a packet + * + * @skb: Socket buffer to modify + * + * Drop the Ethernet header of @skb. + * + * Expects that skb->data points to the mac header and that no VLAN tags are + * present. + * + * Returns 0 on success, -errno otherwise. + */ +int skb_eth_pop(struct sk_buff *skb) +{ + if (!pskb_may_pull(skb, ETH_HLEN) || skb_vlan_tagged(skb) || + skb_network_offset(skb) < ETH_HLEN) + return -EPROTO; + + skb_pull_rcsum(skb, ETH_HLEN); + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + + return 0; +} +EXPORT_SYMBOL(skb_eth_pop); + +/** + * skb_eth_push() - Add a new Ethernet header at the head of a packet + * + * @skb: Socket buffer to modify + * @dst: Destination MAC address of the new header + * @src: Source MAC address of the new header + * + * Prepend @skb with a new Ethernet header. + * + * Expects that skb->data points to the mac header, which must be empty. + * + * Returns 0 on success, -errno otherwise. 
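+ *
+ * The new header's h_proto is taken from skb->protocol. Illustrative use
+ * (sketch only; daddr and saddr are caller-provided MAC addresses):
+ *
+ *	err = skb_eth_push(skb, daddr, saddr);
+ *	if (err)
+ *		goto drop;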
+ */ +int skb_eth_push(struct sk_buff *skb, const unsigned char *dst, + const unsigned char *src) +{ + struct ethhdr *eth; + int err; + + if (skb_network_offset(skb) || skb_vlan_tag_present(skb)) + return -EPROTO; + + err = skb_cow_head(skb, sizeof(*eth)); + if (err < 0) + return err; + + skb_push(skb, sizeof(*eth)); + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + + eth = eth_hdr(skb); + ether_addr_copy(eth->h_dest, dst); + ether_addr_copy(eth->h_source, src); + eth->h_proto = skb->protocol; + + skb_postpush_rcsum(skb, eth, sizeof(*eth)); + + return 0; +} +EXPORT_SYMBOL(skb_eth_push); + +/* Update the ethertype of hdr and the skb csum value if required. */ +static void skb_mod_eth_type(struct sk_buff *skb, struct ethhdr *hdr, + __be16 ethertype) +{ + if (skb->ip_summed == CHECKSUM_COMPLETE) { + __be16 diff[] = { ~hdr->h_proto, ethertype }; + + skb->csum = csum_partial((char *)diff, sizeof(diff), skb->csum); + } + + hdr->h_proto = ethertype; +} + +/** + * skb_mpls_push() - push a new MPLS header after mac_len bytes from start of + * the packet + * + * @skb: buffer + * @mpls_lse: MPLS label stack entry to push + * @mpls_proto: ethertype of the new MPLS header (expects 0x8847 or 0x8848) + * @mac_len: length of the MAC header + * @ethernet: flag to indicate if the resulting packet after skb_mpls_push is + * ethernet + * + * Expects skb->data at mac header. + * + * Returns 0 on success, -errno otherwise. + */ +int skb_mpls_push(struct sk_buff *skb, __be32 mpls_lse, __be16 mpls_proto, + int mac_len, bool ethernet) +{ + struct mpls_shim_hdr *lse; + int err; + + if (unlikely(!eth_p_mpls(mpls_proto))) + return -EINVAL; + + /* Networking stack does not allow simultaneous Tunnel and MPLS GSO. */ + if (skb->encapsulation) + return -EINVAL; + + err = skb_cow_head(skb, MPLS_HLEN); + if (unlikely(err)) + return err; + + if (!skb->inner_protocol) { + skb_set_inner_network_header(skb, skb_network_offset(skb)); + skb_set_inner_protocol(skb, skb->protocol); + } + + skb_push(skb, MPLS_HLEN); + memmove(skb_mac_header(skb) - MPLS_HLEN, skb_mac_header(skb), + mac_len); + skb_reset_mac_header(skb); + skb_set_network_header(skb, mac_len); + skb_reset_mac_len(skb); + + lse = mpls_hdr(skb); + lse->label_stack_entry = mpls_lse; + skb_postpush_rcsum(skb, lse, MPLS_HLEN); + + if (ethernet && mac_len >= ETH_HLEN) + skb_mod_eth_type(skb, eth_hdr(skb), mpls_proto); + skb->protocol = mpls_proto; + + return 0; +} +EXPORT_SYMBOL_GPL(skb_mpls_push); + +/** + * skb_mpls_pop() - pop the outermost MPLS header + * + * @skb: buffer + * @next_proto: ethertype of header after popped MPLS header + * @mac_len: length of the MAC header + * @ethernet: flag to indicate if the packet is ethernet + * + * Expects skb->data at mac header. + * + * Returns 0 on success, -errno otherwise. + */ +int skb_mpls_pop(struct sk_buff *skb, __be16 next_proto, int mac_len, + bool ethernet) +{ + int err; + + if (unlikely(!eth_p_mpls(skb->protocol))) + return 0; + + err = skb_ensure_writable(skb, mac_len + MPLS_HLEN); + if (unlikely(err)) + return err; + + skb_postpull_rcsum(skb, mpls_hdr(skb), MPLS_HLEN); + memmove(skb_mac_header(skb) + MPLS_HLEN, skb_mac_header(skb), + mac_len); + + __skb_pull(skb, MPLS_HLEN); + skb_reset_mac_header(skb); + skb_set_network_header(skb, mac_len); + + if (ethernet && mac_len >= ETH_HLEN) { + struct ethhdr *hdr; + + /* use mpls_hdr() to get ethertype to account for VLANs. 
*/ + hdr = (struct ethhdr *)((void *)mpls_hdr(skb) - ETH_HLEN); + skb_mod_eth_type(skb, hdr, next_proto); + } + skb->protocol = next_proto; + + return 0; +} +EXPORT_SYMBOL_GPL(skb_mpls_pop); + +/** + * skb_mpls_update_lse() - modify outermost MPLS header and update csum + * + * @skb: buffer + * @mpls_lse: new MPLS label stack entry to update to + * + * Expects skb->data at mac header. + * + * Returns 0 on success, -errno otherwise. + */ +int skb_mpls_update_lse(struct sk_buff *skb, __be32 mpls_lse) +{ + int err; + + if (unlikely(!eth_p_mpls(skb->protocol))) + return -EINVAL; + + err = skb_ensure_writable(skb, skb->mac_len + MPLS_HLEN); + if (unlikely(err)) + return err; + + if (skb->ip_summed == CHECKSUM_COMPLETE) { + __be32 diff[] = { ~mpls_hdr(skb)->label_stack_entry, mpls_lse }; + + skb->csum = csum_partial((char *)diff, sizeof(diff), skb->csum); + } + + mpls_hdr(skb)->label_stack_entry = mpls_lse; + + return 0; +} +EXPORT_SYMBOL_GPL(skb_mpls_update_lse); + +/** + * skb_mpls_dec_ttl() - decrement the TTL of the outermost MPLS header + * + * @skb: buffer + * + * Expects skb->data at mac header. + * + * Returns 0 on success, -errno otherwise. + */ +int skb_mpls_dec_ttl(struct sk_buff *skb) +{ + u32 lse; + u8 ttl; + + if (unlikely(!eth_p_mpls(skb->protocol))) + return -EINVAL; + + if (!pskb_may_pull(skb, skb_network_offset(skb) + MPLS_HLEN)) + return -ENOMEM; + + lse = be32_to_cpu(mpls_hdr(skb)->label_stack_entry); + ttl = (lse & MPLS_LS_TTL_MASK) >> MPLS_LS_TTL_SHIFT; + if (!--ttl) + return -EINVAL; + + lse &= ~MPLS_LS_TTL_MASK; + lse |= ttl << MPLS_LS_TTL_SHIFT; + + return skb_mpls_update_lse(skb, cpu_to_be32(lse)); +} +EXPORT_SYMBOL_GPL(skb_mpls_dec_ttl); + +/** + * alloc_skb_with_frags - allocate skb with page frags + * + * @header_len: size of linear part + * @data_len: needed length in frags + * @max_page_order: max page order desired. + * @errcode: pointer to error code if any + * @gfp_mask: allocation mask + * + * This can be used to allocate a paged skb, given a maximal order for frags. + */ +struct sk_buff *alloc_skb_with_frags(unsigned long header_len, + unsigned long data_len, + int max_page_order, + int *errcode, + gfp_t gfp_mask) +{ + int npages = (data_len + (PAGE_SIZE - 1)) >> PAGE_SHIFT; + unsigned long chunk; + struct sk_buff *skb; + struct page *page; + int i; + + *errcode = -EMSGSIZE; + /* Note this test could be relaxed, if we succeed to allocate + * high order pages... 
+ */ + if (npages > MAX_SKB_FRAGS) + return NULL; + + *errcode = -ENOBUFS; + skb = alloc_skb(header_len, gfp_mask); + if (!skb) + return NULL; + + skb->truesize += npages << PAGE_SHIFT; + + for (i = 0; npages > 0; i++) { + int order = max_page_order; + + while (order) { + if (npages >= 1 << order) { + page = alloc_pages((gfp_mask & ~__GFP_DIRECT_RECLAIM) | + __GFP_COMP | + __GFP_NOWARN, + order); + if (page) + goto fill_page; + /* Do not retry other high order allocations */ + order = 1; + max_page_order = 0; + } + order--; + } + page = alloc_page(gfp_mask); + if (!page) + goto failure; +fill_page: + chunk = min_t(unsigned long, data_len, + PAGE_SIZE << order); + skb_fill_page_desc(skb, i, page, 0, chunk); + data_len -= chunk; + npages -= 1 << order; + } + return skb; + +failure: + kfree_skb(skb); + return NULL; +} +EXPORT_SYMBOL(alloc_skb_with_frags); + +/* carve out the first off bytes from skb when off < headlen */ +static int pskb_carve_inside_header(struct sk_buff *skb, const u32 off, + const int headlen, gfp_t gfp_mask) +{ + int i; + unsigned int size = skb_end_offset(skb); + int new_hlen = headlen - off; + u8 *data; + + if (skb_pfmemalloc(skb)) + gfp_mask |= __GFP_MEMALLOC; + + data = kmalloc_reserve(&size, gfp_mask, NUMA_NO_NODE, NULL); + if (!data) + return -ENOMEM; + size = SKB_WITH_OVERHEAD(size); + + /* Copy real data, and all frags */ + skb_copy_from_linear_data_offset(skb, off, data, new_hlen); + skb->len -= off; + + memcpy((struct skb_shared_info *)(data + size), + skb_shinfo(skb), + offsetof(struct skb_shared_info, + frags[skb_shinfo(skb)->nr_frags])); + if (skb_cloned(skb)) { + /* drop the old head gracefully */ + if (skb_orphan_frags(skb, gfp_mask)) { + kfree(data); + return -ENOMEM; + } + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) + skb_frag_ref(skb, i); + if (skb_has_frag_list(skb)) + skb_clone_fraglist(skb); + skb_release_data(skb); + } else { + /* we can reuse existing recount- all we did was + * relocate values + */ + skb_free_head(skb); + } + + skb->head = data; + skb->data = data; + skb->head_frag = 0; + skb_set_end_offset(skb, size); + skb_set_tail_pointer(skb, skb_headlen(skb)); + skb_headers_offset_update(skb, 0); + skb->cloned = 0; + skb->hdr_len = 0; + skb->nohdr = 0; + atomic_set(&skb_shinfo(skb)->dataref, 1); + + return 0; +} + +static int pskb_carve(struct sk_buff *skb, const u32 off, gfp_t gfp); + +/* carve out the first eat bytes from skb's frag_list. May recurse into + * pskb_carve() + */ +static int pskb_carve_frag_list(struct sk_buff *skb, + struct skb_shared_info *shinfo, int eat, + gfp_t gfp_mask) +{ + struct sk_buff *list = shinfo->frag_list; + struct sk_buff *clone = NULL; + struct sk_buff *insp = NULL; + + do { + if (!list) { + pr_err("Not enough bytes to eat. Want %d\n", eat); + return -EFAULT; + } + if (list->len <= eat) { + /* Eaten as whole. */ + eat -= list->len; + list = list->next; + insp = list; + } else { + /* Eaten partially. */ + if (skb_shared(list)) { + clone = skb_clone(list, gfp_mask); + if (!clone) + return -ENOMEM; + insp = list->next; + list = clone; + } else { + /* This may be pulled without problems. */ + insp = list; + } + if (pskb_carve(list, eat, gfp_mask) < 0) { + kfree_skb(clone); + return -ENOMEM; + } + break; + } + } while (eat); + + /* Free pulled out fragments. */ + while ((list = shinfo->frag_list) != insp) { + shinfo->frag_list = list->next; + consume_skb(list); + } + /* And insert new clone at head. 
*/ + if (clone) { + clone->next = list; + shinfo->frag_list = clone; + } + return 0; +} + +/* carve off first len bytes from skb. Split line (off) is in the + * non-linear part of skb + */ +static int pskb_carve_inside_nonlinear(struct sk_buff *skb, const u32 off, + int pos, gfp_t gfp_mask) +{ + int i, k = 0; + unsigned int size = skb_end_offset(skb); + u8 *data; + const int nfrags = skb_shinfo(skb)->nr_frags; + struct skb_shared_info *shinfo; + + if (skb_pfmemalloc(skb)) + gfp_mask |= __GFP_MEMALLOC; + + data = kmalloc_reserve(&size, gfp_mask, NUMA_NO_NODE, NULL); + if (!data) + return -ENOMEM; + size = SKB_WITH_OVERHEAD(size); + + memcpy((struct skb_shared_info *)(data + size), + skb_shinfo(skb), offsetof(struct skb_shared_info, frags[0])); + if (skb_orphan_frags(skb, gfp_mask)) { + kfree(data); + return -ENOMEM; + } + shinfo = (struct skb_shared_info *)(data + size); + for (i = 0; i < nfrags; i++) { + int fsize = skb_frag_size(&skb_shinfo(skb)->frags[i]); + + if (pos + fsize > off) { + shinfo->frags[k] = skb_shinfo(skb)->frags[i]; + + if (pos < off) { + /* Split frag. + * We have two variants in this case: + * 1. Move all the frag to the second + * part, if it is possible. F.e. + * this approach is mandatory for TUX, + * where splitting is expensive. + * 2. Split is accurately. We make this. + */ + skb_frag_off_add(&shinfo->frags[0], off - pos); + skb_frag_size_sub(&shinfo->frags[0], off - pos); + } + skb_frag_ref(skb, i); + k++; + } + pos += fsize; + } + shinfo->nr_frags = k; + if (skb_has_frag_list(skb)) + skb_clone_fraglist(skb); + + /* split line is in frag list */ + if (k == 0 && pskb_carve_frag_list(skb, shinfo, off - pos, gfp_mask)) { + /* skb_frag_unref() is not needed here as shinfo->nr_frags = 0. */ + if (skb_has_frag_list(skb)) + kfree_skb_list(skb_shinfo(skb)->frag_list); + kfree(data); + return -ENOMEM; + } + skb_release_data(skb); + + skb->head = data; + skb->head_frag = 0; + skb->data = data; + skb_set_end_offset(skb, size); + skb_reset_tail_pointer(skb); + skb_headers_offset_update(skb, 0); + skb->cloned = 0; + skb->hdr_len = 0; + skb->nohdr = 0; + skb->len -= off; + skb->data_len = skb->len; + atomic_set(&skb_shinfo(skb)->dataref, 1); + return 0; +} + +/* remove len bytes from the beginning of the skb */ +static int pskb_carve(struct sk_buff *skb, const u32 len, gfp_t gfp) +{ + int headlen = skb_headlen(skb); + + if (len < headlen) + return pskb_carve_inside_header(skb, len, headlen, gfp); + else + return pskb_carve_inside_nonlinear(skb, len, headlen, gfp); +} + +/* Extract to_copy bytes starting at off from skb, and return this in + * a new skb + */ +struct sk_buff *pskb_extract(struct sk_buff *skb, int off, + int to_copy, gfp_t gfp) +{ + struct sk_buff *clone = skb_clone(skb, gfp); + + if (!clone) + return NULL; + + if (pskb_carve(clone, off, gfp) < 0 || + pskb_trim(clone, to_copy)) { + kfree_skb(clone); + return NULL; + } + return clone; +} +EXPORT_SYMBOL(pskb_extract); + +/** + * skb_condense - try to get rid of fragments/frag_list if possible + * @skb: buffer + * + * Can be used to save memory before skb is added to a busy queue. + * If packet has bytes in frags and enough tail room in skb->head, + * pull all of them, so that we can free the frags right now and adjust + * truesize. + * Notes: + * We do not reallocate skb->head thus can not fail. + * Caller must re-evaluate skb->truesize if needed. 
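+ *
+ * Illustrative use (sketch only): a receive path about to queue @skb can
+ * simply call
+ *
+ *	skb_condense(skb);
+ *
+ * and then do its memory accounting against the (possibly reduced)
+ * skb->truesize.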
+ */ +void skb_condense(struct sk_buff *skb) +{ + if (skb->data_len) { + if (skb->data_len > skb->end - skb->tail || + skb_cloned(skb)) + return; + + /* Nice, we can free page frag(s) right now */ + __pskb_pull_tail(skb, skb->data_len); + } + /* At this point, skb->truesize might be over estimated, + * because skb had a fragment, and fragments do not tell + * their truesize. + * When we pulled its content into skb->head, fragment + * was freed, but __pskb_pull_tail() could not possibly + * adjust skb->truesize, not knowing the frag truesize. + */ + skb->truesize = SKB_TRUESIZE(skb_end_offset(skb)); +} + +#ifdef CONFIG_SKB_EXTENSIONS +static void *skb_ext_get_ptr(struct skb_ext *ext, enum skb_ext_id id) +{ + return (void *)ext + (ext->offset[id] * SKB_EXT_ALIGN_VALUE); +} + +/** + * __skb_ext_alloc - allocate a new skb extensions storage + * + * @flags: See kmalloc(). + * + * Returns the newly allocated pointer. The pointer can later attached to a + * skb via __skb_ext_set(). + * Note: caller must handle the skb_ext as an opaque data. + */ +struct skb_ext *__skb_ext_alloc(gfp_t flags) +{ + struct skb_ext *new = kmem_cache_alloc(skbuff_ext_cache, flags); + + if (new) { + memset(new->offset, 0, sizeof(new->offset)); + refcount_set(&new->refcnt, 1); + } + + return new; +} + +static struct skb_ext *skb_ext_maybe_cow(struct skb_ext *old, + unsigned int old_active) +{ + struct skb_ext *new; + + if (refcount_read(&old->refcnt) == 1) + return old; + + new = kmem_cache_alloc(skbuff_ext_cache, GFP_ATOMIC); + if (!new) + return NULL; + + memcpy(new, old, old->chunks * SKB_EXT_ALIGN_VALUE); + refcount_set(&new->refcnt, 1); + +#ifdef CONFIG_XFRM + if (old_active & (1 << SKB_EXT_SEC_PATH)) { + struct sec_path *sp = skb_ext_get_ptr(old, SKB_EXT_SEC_PATH); + unsigned int i; + + for (i = 0; i < sp->len; i++) + xfrm_state_hold(sp->xvec[i]); + } +#endif + __skb_ext_put(old); + return new; +} + +/** + * __skb_ext_set - attach the specified extension storage to this skb + * @skb: buffer + * @id: extension id + * @ext: extension storage previously allocated via __skb_ext_alloc() + * + * Existing extensions, if any, are cleared. + * + * Returns the pointer to the extension. + */ +void *__skb_ext_set(struct sk_buff *skb, enum skb_ext_id id, + struct skb_ext *ext) +{ + unsigned int newlen, newoff = SKB_EXT_CHUNKSIZEOF(*ext); + + skb_ext_put(skb); + newlen = newoff + skb_ext_type_len[id]; + ext->chunks = newlen; + ext->offset[id] = newoff; + skb->extensions = ext; + skb->active_extensions = 1 << id; + return skb_ext_get_ptr(ext, id); +} + +/** + * skb_ext_add - allocate space for given extension, COW if needed + * @skb: buffer + * @id: extension to allocate space for + * + * Allocates enough space for the given extension. + * If the extension is already present, a pointer to that extension + * is returned. + * + * If the skb was cloned, COW applies and the returned memory can be + * modified without changing the extension space of clones buffers. + * + * Returns pointer to the extension or NULL on allocation failure. 
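+ *
+ * Illustrative use (sketch only; SKB_EXT_SEC_PATH is just one possible
+ * extension id):
+ *
+ *	struct sec_path *sp = skb_ext_add(skb, SKB_EXT_SEC_PATH);
+ *
+ *	if (!sp)
+ *		return -ENOMEM;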
+ */ +void *skb_ext_add(struct sk_buff *skb, enum skb_ext_id id) +{ + struct skb_ext *new, *old = NULL; + unsigned int newlen, newoff; + + if (skb->active_extensions) { + old = skb->extensions; + + new = skb_ext_maybe_cow(old, skb->active_extensions); + if (!new) + return NULL; + + if (__skb_ext_exist(new, id)) + goto set_active; + + newoff = new->chunks; + } else { + newoff = SKB_EXT_CHUNKSIZEOF(*new); + + new = __skb_ext_alloc(GFP_ATOMIC); + if (!new) + return NULL; + } + + newlen = newoff + skb_ext_type_len[id]; + new->chunks = newlen; + new->offset[id] = newoff; +set_active: + skb->slow_gro = 1; + skb->extensions = new; + skb->active_extensions |= 1 << id; + return skb_ext_get_ptr(new, id); +} +EXPORT_SYMBOL(skb_ext_add); + +#ifdef CONFIG_XFRM +static void skb_ext_put_sp(struct sec_path *sp) +{ + unsigned int i; + + for (i = 0; i < sp->len; i++) + xfrm_state_put(sp->xvec[i]); +} +#endif + +#ifdef CONFIG_MCTP_FLOWS +static void skb_ext_put_mctp(struct mctp_flow *flow) +{ + if (flow->key) + mctp_key_unref(flow->key); +} +#endif + +void __skb_ext_del(struct sk_buff *skb, enum skb_ext_id id) +{ + struct skb_ext *ext = skb->extensions; + + skb->active_extensions &= ~(1 << id); + if (skb->active_extensions == 0) { + skb->extensions = NULL; + __skb_ext_put(ext); +#ifdef CONFIG_XFRM + } else if (id == SKB_EXT_SEC_PATH && + refcount_read(&ext->refcnt) == 1) { + struct sec_path *sp = skb_ext_get_ptr(ext, SKB_EXT_SEC_PATH); + + skb_ext_put_sp(sp); + sp->len = 0; +#endif + } +} +EXPORT_SYMBOL(__skb_ext_del); + +void __skb_ext_put(struct skb_ext *ext) +{ + /* If this is last clone, nothing can increment + * it after check passes. Avoids one atomic op. + */ + if (refcount_read(&ext->refcnt) == 1) + goto free_now; + + if (!refcount_dec_and_test(&ext->refcnt)) + return; +free_now: +#ifdef CONFIG_XFRM + if (__skb_ext_exist(ext, SKB_EXT_SEC_PATH)) + skb_ext_put_sp(skb_ext_get_ptr(ext, SKB_EXT_SEC_PATH)); +#endif +#ifdef CONFIG_MCTP_FLOWS + if (__skb_ext_exist(ext, SKB_EXT_MCTP)) + skb_ext_put_mctp(skb_ext_get_ptr(ext, SKB_EXT_MCTP)); +#endif + + kmem_cache_free(skbuff_ext_cache, ext); +} +EXPORT_SYMBOL(__skb_ext_put); +#endif /* CONFIG_SKB_EXTENSIONS */ + +/** + * skb_attempt_defer_free - queue skb for remote freeing + * @skb: buffer + * + * Put @skb in a per-cpu list, using the cpu which + * allocated the skb/pages to reduce false sharing + * and memory zone spinlock contention. + */ +void skb_attempt_defer_free(struct sk_buff *skb) +{ + int cpu = skb->alloc_cpu; + struct softnet_data *sd; + unsigned long flags; + unsigned int defer_max; + bool kick; + + if (WARN_ON_ONCE(cpu >= nr_cpu_ids) || + !cpu_online(cpu) || + cpu == raw_smp_processor_id()) { +nodefer: __kfree_skb(skb); + return; + } + + sd = &per_cpu(softnet_data, cpu); + defer_max = READ_ONCE(sysctl_skb_defer_max); + if (READ_ONCE(sd->defer_count) >= defer_max) + goto nodefer; + + spin_lock_irqsave(&sd->defer_lock, flags); + /* Send an IPI every time queue reaches half capacity. */ + kick = sd->defer_count == (defer_max >> 1); + /* Paired with the READ_ONCE() few lines above */ + WRITE_ONCE(sd->defer_count, sd->defer_count + 1); + + skb->next = sd->defer_list; + /* Paired with READ_ONCE() in skb_defer_free_flush() */ + WRITE_ONCE(sd->defer_list, skb); + spin_unlock_irqrestore(&sd->defer_lock, flags); + + /* Make sure to trigger NET_RX_SOFTIRQ on the remote CPU + * if we are unlucky enough (this seems very unlikely). 
+ */ + if (unlikely(kick) && !cmpxchg(&sd->defer_ipi_scheduled, 0, 1)) + smp_call_function_single_async(cpu, &sd->defer_csd); +} diff --git a/net/core/skmsg.c b/net/core/skmsg.c new file mode 100644 index 000000000..3818035ea --- /dev/null +++ b/net/core/skmsg.c @@ -0,0 +1,1246 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include +#include +#include + +#include +#include +#include + +static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) +{ + if (msg->sg.end > msg->sg.start && + elem_first_coalesce < msg->sg.end) + return true; + + if (msg->sg.end < msg->sg.start && + (elem_first_coalesce > msg->sg.start || + elem_first_coalesce < msg->sg.end)) + return true; + + return false; +} + +int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, + int elem_first_coalesce) +{ + struct page_frag *pfrag = sk_page_frag(sk); + u32 osize = msg->sg.size; + int ret = 0; + + len -= msg->sg.size; + while (len > 0) { + struct scatterlist *sge; + u32 orig_offset; + int use, i; + + if (!sk_page_frag_refill(sk, pfrag)) { + ret = -ENOMEM; + goto msg_trim; + } + + orig_offset = pfrag->offset; + use = min_t(int, len, pfrag->size - orig_offset); + if (!sk_wmem_schedule(sk, use)) { + ret = -ENOMEM; + goto msg_trim; + } + + i = msg->sg.end; + sk_msg_iter_var_prev(i); + sge = &msg->sg.data[i]; + + if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && + sg_page(sge) == pfrag->page && + sge->offset + sge->length == orig_offset) { + sge->length += use; + } else { + if (sk_msg_full(msg)) { + ret = -ENOSPC; + break; + } + + sge = &msg->sg.data[msg->sg.end]; + sg_unmark_end(sge); + sg_set_page(sge, pfrag->page, use, orig_offset); + get_page(pfrag->page); + sk_msg_iter_next(msg, end); + } + + sk_mem_charge(sk, use); + msg->sg.size += use; + pfrag->offset += use; + len -= use; + } + + return ret; + +msg_trim: + sk_msg_trim(sk, msg, osize); + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_alloc); + +int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, + u32 off, u32 len) +{ + int i = src->sg.start; + struct scatterlist *sge = sk_msg_elem(src, i); + struct scatterlist *sgd = NULL; + u32 sge_len, sge_off; + + while (off) { + if (sge->length > off) + break; + off -= sge->length; + sk_msg_iter_var_next(i); + if (i == src->sg.end && off) + return -ENOSPC; + sge = sk_msg_elem(src, i); + } + + while (len) { + sge_len = sge->length - off; + if (sge_len > len) + sge_len = len; + + if (dst->sg.end) + sgd = sk_msg_elem(dst, dst->sg.end - 1); + + if (sgd && + (sg_page(sge) == sg_page(sgd)) && + (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) { + sgd->length += sge_len; + dst->sg.size += sge_len; + } else if (!sk_msg_full(dst)) { + sge_off = sge->offset + off; + sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); + } else { + return -ENOSPC; + } + + off = 0; + len -= sge_len; + sk_mem_charge(sk, sge_len); + sk_msg_iter_var_next(i); + if (i == src->sg.end && len) + return -ENOSPC; + sge = sk_msg_elem(src, i); + } + + return 0; +} +EXPORT_SYMBOL_GPL(sk_msg_clone); + +void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) +{ + int i = msg->sg.start; + + do { + struct scatterlist *sge = sk_msg_elem(msg, i); + + if (bytes < sge->length) { + sge->length -= bytes; + sge->offset += bytes; + sk_mem_uncharge(sk, bytes); + break; + } + + sk_mem_uncharge(sk, sge->length); + bytes -= sge->length; + sge->length = 0; + sge->offset = 0; + sk_msg_iter_var_next(i); + } while (bytes && i != msg->sg.end); + msg->sg.start 
= i; +} +EXPORT_SYMBOL_GPL(sk_msg_return_zero); + +void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) +{ + int i = msg->sg.start; + + do { + struct scatterlist *sge = &msg->sg.data[i]; + int uncharge = (bytes < sge->length) ? bytes : sge->length; + + sk_mem_uncharge(sk, uncharge); + bytes -= uncharge; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); +} +EXPORT_SYMBOL_GPL(sk_msg_return); + +static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, + bool charge) +{ + struct scatterlist *sge = sk_msg_elem(msg, i); + u32 len = sge->length; + + /* When the skb owns the memory we free it from consume_skb path. */ + if (!msg->skb) { + if (charge) + sk_mem_uncharge(sk, len); + put_page(sg_page(sge)); + } + memset(sge, 0, sizeof(*sge)); + return len; +} + +static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, + bool charge) +{ + struct scatterlist *sge = sk_msg_elem(msg, i); + int freed = 0; + + while (msg->sg.size) { + msg->sg.size -= sge->length; + freed += sk_msg_free_elem(sk, msg, i, charge); + sk_msg_iter_var_next(i); + sk_msg_check_to_free(msg, i, msg->sg.size); + sge = sk_msg_elem(msg, i); + } + consume_skb(msg->skb); + sk_msg_init(msg); + return freed; +} + +int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) +{ + return __sk_msg_free(sk, msg, msg->sg.start, false); +} +EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); + +int sk_msg_free(struct sock *sk, struct sk_msg *msg) +{ + return __sk_msg_free(sk, msg, msg->sg.start, true); +} +EXPORT_SYMBOL_GPL(sk_msg_free); + +static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, + u32 bytes, bool charge) +{ + struct scatterlist *sge; + u32 i = msg->sg.start; + + while (bytes) { + sge = sk_msg_elem(msg, i); + if (!sge->length) + break; + if (bytes < sge->length) { + if (charge) + sk_mem_uncharge(sk, bytes); + sge->length -= bytes; + sge->offset += bytes; + msg->sg.size -= bytes; + break; + } + + msg->sg.size -= sge->length; + bytes -= sge->length; + sk_msg_free_elem(sk, msg, i, charge); + sk_msg_iter_var_next(i); + sk_msg_check_to_free(msg, i, bytes); + } + msg->sg.start = i; +} + +void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) +{ + __sk_msg_free_partial(sk, msg, bytes, true); +} +EXPORT_SYMBOL_GPL(sk_msg_free_partial); + +void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, + u32 bytes) +{ + __sk_msg_free_partial(sk, msg, bytes, false); +} + +void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) +{ + int trim = msg->sg.size - len; + u32 i = msg->sg.end; + + if (trim <= 0) { + WARN_ON(trim < 0); + return; + } + + sk_msg_iter_var_prev(i); + msg->sg.size = len; + while (msg->sg.data[i].length && + trim >= msg->sg.data[i].length) { + trim -= msg->sg.data[i].length; + sk_msg_free_elem(sk, msg, i, true); + sk_msg_iter_var_prev(i); + if (!trim) + goto out; + } + + msg->sg.data[i].length -= trim; + sk_mem_uncharge(sk, trim); + /* Adjust copybreak if it falls into the trimmed part of last buf */ + if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length) + msg->sg.copybreak = msg->sg.data[i].length; +out: + sk_msg_iter_var_next(i); + msg->sg.end = i; + + /* If we trim data a full sg elem before curr pointer update + * copybreak and current so that any future copy operations + * start at new copy location. + * However trimed data that has not yet been used in a copy op + * does not require an update. 
+ */ + if (!msg->sg.size) { + msg->sg.curr = msg->sg.start; + msg->sg.copybreak = 0; + } else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >= + sk_msg_iter_dist(msg->sg.start, msg->sg.end)) { + sk_msg_iter_var_prev(i); + msg->sg.curr = i; + msg->sg.copybreak = msg->sg.data[i].length; + } +} +EXPORT_SYMBOL_GPL(sk_msg_trim); + +int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, + struct sk_msg *msg, u32 bytes) +{ + int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); + const int to_max_pages = MAX_MSG_FRAGS; + struct page *pages[MAX_MSG_FRAGS]; + ssize_t orig, copied, use, offset; + + orig = msg->sg.size; + while (bytes > 0) { + i = 0; + maxpages = to_max_pages - num_elems; + if (maxpages == 0) { + ret = -EFAULT; + goto out; + } + + copied = iov_iter_get_pages2(from, pages, bytes, maxpages, + &offset); + if (copied <= 0) { + ret = -EFAULT; + goto out; + } + + bytes -= copied; + msg->sg.size += copied; + + while (copied) { + use = min_t(int, copied, PAGE_SIZE - offset); + sg_set_page(&msg->sg.data[msg->sg.end], + pages[i], use, offset); + sg_unmark_end(&msg->sg.data[msg->sg.end]); + sk_mem_charge(sk, use); + + offset = 0; + copied -= use; + sk_msg_iter_next(msg, end); + num_elems++; + i++; + } + /* When zerocopy is mixed with sk_msg_*copy* operations we + * may have a copybreak set in this case clear and prefer + * zerocopy remainder when possible. + */ + msg->sg.copybreak = 0; + msg->sg.curr = msg->sg.end; + } +out: + /* Revert iov_iter updates, msg will need to use 'trim' later if it + * also needs to be cleared. + */ + if (ret) + iov_iter_revert(from, msg->sg.size - orig); + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); + +int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, + struct sk_msg *msg, u32 bytes) +{ + int ret = -ENOSPC, i = msg->sg.curr; + struct scatterlist *sge; + u32 copy, buf_size; + void *to; + + do { + sge = sk_msg_elem(msg, i); + /* This is possible if a trim operation shrunk the buffer */ + if (msg->sg.copybreak >= sge->length) { + msg->sg.copybreak = 0; + sk_msg_iter_var_next(i); + if (i == msg->sg.end) + break; + sge = sk_msg_elem(msg, i); + } + + buf_size = sge->length - msg->sg.copybreak; + copy = (buf_size > bytes) ? bytes : buf_size; + to = sg_virt(sge) + msg->sg.copybreak; + msg->sg.copybreak += copy; + if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) + ret = copy_from_iter_nocache(to, copy, from); + else + ret = copy_from_iter(to, copy, from); + if (ret != copy) { + ret = -EFAULT; + goto out; + } + bytes -= copy; + if (!bytes) + break; + msg->sg.copybreak = 0; + sk_msg_iter_var_next(i); + } while (i != msg->sg.end); +out: + msg->sg.curr = i; + return ret; +} +EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); + +/* Receive sk_msg from psock->ingress_msg to @msg. */ +int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, + int len, int flags) +{ + struct iov_iter *iter = &msg->msg_iter; + int peek = flags & MSG_PEEK; + struct sk_msg *msg_rx; + int i, copied = 0; + + msg_rx = sk_psock_peek_msg(psock); + while (copied != len) { + struct scatterlist *sge; + + if (unlikely(!msg_rx)) + break; + + i = msg_rx->sg.start; + do { + struct page *page; + int copy; + + sge = sk_msg_elem(msg_rx, i); + copy = sge->length; + page = sg_page(sge); + if (copied + copy > len) + copy = len - copied; + copy = copy_page_to_iter(page, sge->offset, copy, iter); + if (!copy) { + copied = copied ? 
copied : -EFAULT; + goto out; + } + + copied += copy; + if (likely(!peek)) { + sge->offset += copy; + sge->length -= copy; + if (!msg_rx->skb) + sk_mem_uncharge(sk, copy); + msg_rx->sg.size -= copy; + + if (!sge->length) { + sk_msg_iter_var_next(i); + if (!msg_rx->skb) + put_page(page); + } + } else { + /* Lets not optimize peek case if copy_page_to_iter + * didn't copy the entire length lets just break. + */ + if (copy != sge->length) + goto out; + sk_msg_iter_var_next(i); + } + + if (copied == len) + break; + } while ((i != msg_rx->sg.end) && !sg_is_last(sge)); + + if (unlikely(peek)) { + msg_rx = sk_psock_next_msg(psock, msg_rx); + if (!msg_rx) + break; + continue; + } + + msg_rx->sg.start = i; + if (!sge->length && (i == msg_rx->sg.end || sg_is_last(sge))) { + msg_rx = sk_psock_dequeue_msg(psock); + kfree_sk_msg(msg_rx); + } + msg_rx = sk_psock_peek_msg(psock); + } +out: + return copied; +} +EXPORT_SYMBOL_GPL(sk_msg_recvmsg); + +bool sk_msg_is_readable(struct sock *sk) +{ + struct sk_psock *psock; + bool empty = true; + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock)) + empty = list_empty(&psock->ingress_msg); + rcu_read_unlock(); + return !empty; +} +EXPORT_SYMBOL_GPL(sk_msg_is_readable); + +static struct sk_msg *alloc_sk_msg(gfp_t gfp) +{ + struct sk_msg *msg; + + msg = kzalloc(sizeof(*msg), gfp | __GFP_NOWARN); + if (unlikely(!msg)) + return NULL; + sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS); + return msg; +} + +static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk, + struct sk_buff *skb) +{ + if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf) + return NULL; + + if (!sk_rmem_schedule(sk, skb, skb->truesize)) + return NULL; + + return alloc_sk_msg(GFP_KERNEL); +} + +static int sk_psock_skb_ingress_enqueue(struct sk_buff *skb, + u32 off, u32 len, + struct sk_psock *psock, + struct sock *sk, + struct sk_msg *msg) +{ + int num_sge, copied; + + num_sge = skb_to_sgvec(skb, msg->sg.data, off, len); + if (num_sge < 0) { + /* skb linearize may fail with ENOMEM, but lets simply try again + * later if this happens. Under memory pressure we don't want to + * drop the skb. We need to linearize the skb so that the mapping + * in skb_to_sgvec can not error. + */ + if (skb_linearize(skb)) + return -EAGAIN; + + num_sge = skb_to_sgvec(skb, msg->sg.data, off, len); + if (unlikely(num_sge < 0)) + return num_sge; + } + + copied = len; + msg->sg.start = 0; + msg->sg.size = copied; + msg->sg.end = num_sge; + msg->skb = skb; + + sk_psock_queue_msg(psock, msg); + sk_psock_data_ready(sk, psock); + return copied; +} + +static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb, + u32 off, u32 len); + +static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb, + u32 off, u32 len) +{ + struct sock *sk = psock->sk; + struct sk_msg *msg; + int err; + + /* If we are receiving on the same sock skb->sk is already assigned, + * skip memory accounting and owner transition seeing it already set + * correctly. + */ + if (unlikely(skb->sk == sk)) + return sk_psock_skb_ingress_self(psock, skb, off, len); + msg = sk_psock_create_ingress_msg(sk, skb); + if (!msg) + return -EAGAIN; + + /* This will transition ownership of the data from the socket where + * the BPF program was run initiating the redirect to the socket + * we will eventually receive this data on. The data will be released + * from skb_consume found in __tcp_bpf_recvmsg() after its been copied + * into user buffers. 
+ */ + skb_set_owner_r(skb, sk); + err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg); + if (err < 0) + kfree(msg); + return err; +} + +/* Puts an skb on the ingress queue of the socket already assigned to the + * skb. In this case we do not need to check memory limits or skb_set_owner_r + * because the skb is already accounted for here. + */ +static int sk_psock_skb_ingress_self(struct sk_psock *psock, struct sk_buff *skb, + u32 off, u32 len) +{ + struct sk_msg *msg = alloc_sk_msg(GFP_ATOMIC); + struct sock *sk = psock->sk; + int err; + + if (unlikely(!msg)) + return -EAGAIN; + skb_set_owner_r(skb, sk); + err = sk_psock_skb_ingress_enqueue(skb, off, len, psock, sk, msg); + if (err < 0) + kfree(msg); + return err; +} + +static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, + u32 off, u32 len, bool ingress) +{ + int err = 0; + + if (!ingress) { + if (!sock_writeable(psock->sk)) + return -EAGAIN; + return skb_send_sock(psock->sk, skb, off, len); + } + skb_get(skb); + err = sk_psock_skb_ingress(psock, skb, off, len); + if (err < 0) + kfree_skb(skb); + return err; +} + +static void sk_psock_skb_state(struct sk_psock *psock, + struct sk_psock_work_state *state, + int len, int off) +{ + spin_lock_bh(&psock->ingress_lock); + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { + state->len = len; + state->off = off; + } + spin_unlock_bh(&psock->ingress_lock); +} + +static void sk_psock_backlog(struct work_struct *work) +{ + struct delayed_work *dwork = to_delayed_work(work); + struct sk_psock *psock = container_of(dwork, struct sk_psock, work); + struct sk_psock_work_state *state = &psock->work_state; + struct sk_buff *skb = NULL; + u32 len = 0, off = 0; + bool ingress; + int ret; + + mutex_lock(&psock->work_mutex); + if (unlikely(state->len)) { + len = state->len; + off = state->off; + } + + while ((skb = skb_peek(&psock->ingress_skb))) { + len = skb->len; + off = 0; + if (skb_bpf_strparser(skb)) { + struct strp_msg *stm = strp_msg(skb); + + off = stm->offset; + len = stm->full_len; + } + ingress = skb_bpf_ingress(skb); + skb_bpf_redirect_clear(skb); + do { + ret = -EIO; + if (!sock_flag(psock->sk, SOCK_DEAD)) + ret = sk_psock_handle_skb(psock, skb, off, + len, ingress); + if (ret <= 0) { + if (ret == -EAGAIN) { + sk_psock_skb_state(psock, state, len, off); + + /* Delay slightly to prioritize any + * other work that might be here. + */ + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) + schedule_delayed_work(&psock->work, 1); + goto end; + } + /* Hard errors break pipe and stop xmit. */ + sk_psock_report_error(psock, ret ? 
-ret : EPIPE); + sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); + goto end; + } + off += ret; + len -= ret; + } while (len); + + skb = skb_dequeue(&psock->ingress_skb); + kfree_skb(skb); + } +end: + mutex_unlock(&psock->work_mutex); +} + +struct sk_psock *sk_psock_init(struct sock *sk, int node) +{ + struct sk_psock *psock; + struct proto *prot; + + write_lock_bh(&sk->sk_callback_lock); + + if (sk_is_inet(sk) && inet_csk_has_ulp(sk)) { + psock = ERR_PTR(-EINVAL); + goto out; + } + + if (sk->sk_user_data) { + psock = ERR_PTR(-EBUSY); + goto out; + } + + psock = kzalloc_node(sizeof(*psock), GFP_ATOMIC | __GFP_NOWARN, node); + if (!psock) { + psock = ERR_PTR(-ENOMEM); + goto out; + } + + prot = READ_ONCE(sk->sk_prot); + psock->sk = sk; + psock->eval = __SK_NONE; + psock->sk_proto = prot; + psock->saved_unhash = prot->unhash; + psock->saved_destroy = prot->destroy; + psock->saved_close = prot->close; + psock->saved_write_space = sk->sk_write_space; + + INIT_LIST_HEAD(&psock->link); + spin_lock_init(&psock->link_lock); + + INIT_DELAYED_WORK(&psock->work, sk_psock_backlog); + mutex_init(&psock->work_mutex); + INIT_LIST_HEAD(&psock->ingress_msg); + spin_lock_init(&psock->ingress_lock); + skb_queue_head_init(&psock->ingress_skb); + + sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); + refcount_set(&psock->refcnt, 1); + + __rcu_assign_sk_user_data_with_flags(sk, psock, + SK_USER_DATA_NOCOPY | + SK_USER_DATA_PSOCK); + sock_hold(sk); + +out: + write_unlock_bh(&sk->sk_callback_lock); + return psock; +} +EXPORT_SYMBOL_GPL(sk_psock_init); + +struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) +{ + struct sk_psock_link *link; + + spin_lock_bh(&psock->link_lock); + link = list_first_entry_or_null(&psock->link, struct sk_psock_link, + list); + if (link) + list_del(&link->list); + spin_unlock_bh(&psock->link_lock); + return link; +} + +static void __sk_psock_purge_ingress_msg(struct sk_psock *psock) +{ + struct sk_msg *msg, *tmp; + + list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { + list_del(&msg->list); + sk_msg_free(psock->sk, msg); + kfree(msg); + } +} + +static void __sk_psock_zap_ingress(struct sk_psock *psock) +{ + struct sk_buff *skb; + + while ((skb = skb_dequeue(&psock->ingress_skb)) != NULL) { + skb_bpf_redirect_clear(skb); + sock_drop(psock->sk, skb); + } + __sk_psock_purge_ingress_msg(psock); +} + +static void sk_psock_link_destroy(struct sk_psock *psock) +{ + struct sk_psock_link *link, *tmp; + + list_for_each_entry_safe(link, tmp, &psock->link, list) { + list_del(&link->list); + sk_psock_free_link(link); + } +} + +void sk_psock_stop(struct sk_psock *psock) +{ + spin_lock_bh(&psock->ingress_lock); + sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); + sk_psock_cork_free(psock); + spin_unlock_bh(&psock->ingress_lock); +} + +static void sk_psock_done_strp(struct sk_psock *psock); + +static void sk_psock_destroy(struct work_struct *work) +{ + struct sk_psock *psock = container_of(to_rcu_work(work), + struct sk_psock, rwork); + /* No sk_callback_lock since already detached. 
*/ + + sk_psock_done_strp(psock); + + cancel_delayed_work_sync(&psock->work); + __sk_psock_zap_ingress(psock); + mutex_destroy(&psock->work_mutex); + + psock_progs_drop(&psock->progs); + + sk_psock_link_destroy(psock); + sk_psock_cork_free(psock); + + if (psock->sk_redir) + sock_put(psock->sk_redir); + if (psock->sk_pair) + sock_put(psock->sk_pair); + sock_put(psock->sk); + kfree(psock); +} + +void sk_psock_drop(struct sock *sk, struct sk_psock *psock) +{ + write_lock_bh(&sk->sk_callback_lock); + sk_psock_restore_proto(sk, psock); + rcu_assign_sk_user_data(sk, NULL); + if (psock->progs.stream_parser) + sk_psock_stop_strp(sk, psock); + else if (psock->progs.stream_verdict || psock->progs.skb_verdict) + sk_psock_stop_verdict(sk, psock); + write_unlock_bh(&sk->sk_callback_lock); + + sk_psock_stop(psock); + + INIT_RCU_WORK(&psock->rwork, sk_psock_destroy); + queue_rcu_work(system_wq, &psock->rwork); +} +EXPORT_SYMBOL_GPL(sk_psock_drop); + +static int sk_psock_map_verd(int verdict, bool redir) +{ + switch (verdict) { + case SK_PASS: + return redir ? __SK_REDIRECT : __SK_PASS; + case SK_DROP: + default: + break; + } + + return __SK_DROP; +} + +int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg) +{ + struct bpf_prog *prog; + int ret; + + rcu_read_lock(); + prog = READ_ONCE(psock->progs.msg_parser); + if (unlikely(!prog)) { + ret = __SK_PASS; + goto out; + } + + sk_msg_compute_data_pointers(msg); + msg->sk = sk; + ret = bpf_prog_run_pin_on_cpu(prog, msg); + ret = sk_psock_map_verd(ret, msg->sk_redir); + psock->apply_bytes = msg->apply_bytes; + if (ret == __SK_REDIRECT) { + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + if (!msg->sk_redir) { + ret = __SK_DROP; + goto out; + } + psock->redir_ingress = sk_msg_to_ingress(msg); + psock->sk_redir = msg->sk_redir; + sock_hold(psock->sk_redir); + } +out: + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); + +static int sk_psock_skb_redirect(struct sk_psock *from, struct sk_buff *skb) +{ + struct sk_psock *psock_other; + struct sock *sk_other; + + sk_other = skb_bpf_redirect_fetch(skb); + /* This error is a buggy BPF program, it returned a redirect + * return code, but then didn't set a redirect interface. + */ + if (unlikely(!sk_other)) { + skb_bpf_redirect_clear(skb); + sock_drop(from->sk, skb); + return -EIO; + } + psock_other = sk_psock(sk_other); + /* This error indicates the socket is being torn down or had another + * error that caused the pipe to break. We can't send a packet on + * a socket that is in this state so we drop the skb. 
+ */ + if (!psock_other || sock_flag(sk_other, SOCK_DEAD)) { + skb_bpf_redirect_clear(skb); + sock_drop(from->sk, skb); + return -EIO; + } + spin_lock_bh(&psock_other->ingress_lock); + if (!sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) { + spin_unlock_bh(&psock_other->ingress_lock); + skb_bpf_redirect_clear(skb); + sock_drop(from->sk, skb); + return -EIO; + } + + skb_queue_tail(&psock_other->ingress_skb, skb); + schedule_delayed_work(&psock_other->work, 0); + spin_unlock_bh(&psock_other->ingress_lock); + return 0; +} + +static void sk_psock_tls_verdict_apply(struct sk_buff *skb, + struct sk_psock *from, int verdict) +{ + switch (verdict) { + case __SK_REDIRECT: + sk_psock_skb_redirect(from, skb); + break; + case __SK_PASS: + case __SK_DROP: + default: + break; + } +} + +int sk_psock_tls_strp_read(struct sk_psock *psock, struct sk_buff *skb) +{ + struct bpf_prog *prog; + int ret = __SK_PASS; + + rcu_read_lock(); + prog = READ_ONCE(psock->progs.stream_verdict); + if (likely(prog)) { + skb->sk = psock->sk; + skb_dst_drop(skb); + skb_bpf_redirect_clear(skb); + ret = bpf_prog_run_pin_on_cpu(prog, skb); + ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb)); + skb->sk = NULL; + } + sk_psock_tls_verdict_apply(skb, psock, ret); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(sk_psock_tls_strp_read); + +static int sk_psock_verdict_apply(struct sk_psock *psock, struct sk_buff *skb, + int verdict) +{ + struct sock *sk_other; + int err = 0; + u32 len, off; + + switch (verdict) { + case __SK_PASS: + err = -EIO; + sk_other = psock->sk; + if (sock_flag(sk_other, SOCK_DEAD) || + !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) + goto out_free; + + skb_bpf_set_ingress(skb); + + /* If the queue is empty then we can submit directly + * into the msg queue. If its not empty we have to + * queue work otherwise we may get OOO data. Otherwise, + * if sk_psock_skb_ingress errors will be handled by + * retrying later from workqueue. 
+ */ + if (skb_queue_empty(&psock->ingress_skb)) { + len = skb->len; + off = 0; + if (skb_bpf_strparser(skb)) { + struct strp_msg *stm = strp_msg(skb); + + off = stm->offset; + len = stm->full_len; + } + err = sk_psock_skb_ingress_self(psock, skb, off, len); + } + if (err < 0) { + spin_lock_bh(&psock->ingress_lock); + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { + skb_queue_tail(&psock->ingress_skb, skb); + schedule_delayed_work(&psock->work, 0); + err = 0; + } + spin_unlock_bh(&psock->ingress_lock); + if (err < 0) + goto out_free; + } + break; + case __SK_REDIRECT: + tcp_eat_skb(psock->sk, skb); + err = sk_psock_skb_redirect(psock, skb); + break; + case __SK_DROP: + default: +out_free: + skb_bpf_redirect_clear(skb); + tcp_eat_skb(psock->sk, skb); + sock_drop(psock->sk, skb); + } + + return err; +} + +static void sk_psock_write_space(struct sock *sk) +{ + struct sk_psock *psock; + void (*write_space)(struct sock *sk) = NULL; + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock)) { + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) + schedule_delayed_work(&psock->work, 0); + write_space = psock->saved_write_space; + } + rcu_read_unlock(); + if (write_space) + write_space(sk); +} + +#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) +static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) +{ + struct sk_psock *psock; + struct bpf_prog *prog; + int ret = __SK_DROP; + struct sock *sk; + + rcu_read_lock(); + sk = strp->sk; + psock = sk_psock(sk); + if (unlikely(!psock)) { + sock_drop(sk, skb); + goto out; + } + prog = READ_ONCE(psock->progs.stream_verdict); + if (likely(prog)) { + skb->sk = sk; + skb_dst_drop(skb); + skb_bpf_redirect_clear(skb); + ret = bpf_prog_run_pin_on_cpu(prog, skb); + skb_bpf_set_strparser(skb); + ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb)); + skb->sk = NULL; + } + sk_psock_verdict_apply(psock, skb, ret); +out: + rcu_read_unlock(); +} + +static int sk_psock_strp_read_done(struct strparser *strp, int err) +{ + return err; +} + +static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) +{ + struct sk_psock *psock = container_of(strp, struct sk_psock, strp); + struct bpf_prog *prog; + int ret = skb->len; + + rcu_read_lock(); + prog = READ_ONCE(psock->progs.stream_parser); + if (likely(prog)) { + skb->sk = psock->sk; + ret = bpf_prog_run_pin_on_cpu(prog, skb); + skb->sk = NULL; + } + rcu_read_unlock(); + return ret; +} + +/* Called with socket lock held. 
*/ +static void sk_psock_strp_data_ready(struct sock *sk) +{ + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (likely(psock)) { + if (tls_sw_has_ctx_rx(sk)) { + psock->saved_data_ready(sk); + } else { + write_lock_bh(&sk->sk_callback_lock); + strp_data_ready(&psock->strp); + write_unlock_bh(&sk->sk_callback_lock); + } + } + rcu_read_unlock(); +} + +int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) +{ + int ret; + + static const struct strp_callbacks cb = { + .rcv_msg = sk_psock_strp_read, + .read_sock_done = sk_psock_strp_read_done, + .parse_msg = sk_psock_strp_parse, + }; + + ret = strp_init(&psock->strp, sk, &cb); + if (!ret) + sk_psock_set_state(psock, SK_PSOCK_RX_STRP_ENABLED); + + return ret; +} + +void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) +{ + if (psock->saved_data_ready) + return; + + psock->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = sk_psock_strp_data_ready; + sk->sk_write_space = sk_psock_write_space; +} + +void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) +{ + psock_set_prog(&psock->progs.stream_parser, NULL); + + if (!psock->saved_data_ready) + return; + + sk->sk_data_ready = psock->saved_data_ready; + psock->saved_data_ready = NULL; + strp_stop(&psock->strp); +} + +static void sk_psock_done_strp(struct sk_psock *psock) +{ + /* Parser has been stopped */ + if (sk_psock_test_state(psock, SK_PSOCK_RX_STRP_ENABLED)) + strp_done(&psock->strp); +} +#else +static void sk_psock_done_strp(struct sk_psock *psock) +{ +} +#endif /* CONFIG_BPF_STREAM_PARSER */ + +static int sk_psock_verdict_recv(struct sock *sk, struct sk_buff *skb) +{ + struct sk_psock *psock; + struct bpf_prog *prog; + int ret = __SK_DROP; + int len = skb->len; + + rcu_read_lock(); + psock = sk_psock(sk); + if (unlikely(!psock)) { + len = 0; + tcp_eat_skb(sk, skb); + sock_drop(sk, skb); + goto out; + } + prog = READ_ONCE(psock->progs.stream_verdict); + if (!prog) + prog = READ_ONCE(psock->progs.skb_verdict); + if (likely(prog)) { + skb_dst_drop(skb); + skb_bpf_redirect_clear(skb); + ret = bpf_prog_run_pin_on_cpu(prog, skb); + ret = sk_psock_map_verd(ret, skb_bpf_redirect_fetch(skb)); + } + ret = sk_psock_verdict_apply(psock, skb, ret); + if (ret < 0) + len = ret; +out: + rcu_read_unlock(); + return len; +} + +static void sk_psock_verdict_data_ready(struct sock *sk) +{ + struct socket *sock = sk->sk_socket; + int copied; + + if (unlikely(!sock || !sock->ops || !sock->ops->read_skb)) + return; + copied = sock->ops->read_skb(sk, sk_psock_verdict_recv); + if (copied >= 0) { + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) + psock->saved_data_ready(sk); + rcu_read_unlock(); + } +} + +void sk_psock_start_verdict(struct sock *sk, struct sk_psock *psock) +{ + if (psock->saved_data_ready) + return; + + psock->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = sk_psock_verdict_data_ready; + sk->sk_write_space = sk_psock_write_space; +} + +void sk_psock_stop_verdict(struct sock *sk, struct sk_psock *psock) +{ + psock_set_prog(&psock->progs.stream_verdict, NULL); + psock_set_prog(&psock->progs.skb_verdict, NULL); + + if (!psock->saved_data_ready) + return; + + sk->sk_data_ready = psock->saved_data_ready; + psock->saved_data_ready = NULL; +} diff --git a/net/core/sock.c b/net/core/sock.c new file mode 100644 index 000000000..c8803b95e --- /dev/null +++ b/net/core/sock.c @@ -0,0 +1,4131 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for 
the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Generic socket support routines. Memory allocators, socket lock/release + * handler for protocols to use and generic option handler. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Florian La Roche, + * Alan Cox, + * + * Fixes: + * Alan Cox : Numerous verify_area() problems + * Alan Cox : Connecting on a connecting socket + * now returns an error for tcp. + * Alan Cox : sock->protocol is set correctly. + * and is not sometimes left as 0. + * Alan Cox : connect handles icmp errors on a + * connect properly. Unfortunately there + * is a restart syscall nasty there. I + * can't match BSD without hacking the C + * library. Ideas urgently sought! + * Alan Cox : Disallow bind() to addresses that are + * not ours - especially broadcast ones!! + * Alan Cox : Socket 1024 _IS_ ok for users. (fencepost) + * Alan Cox : sock_wfree/sock_rfree don't destroy sockets, + * instead they leave that for the DESTROY timer. + * Alan Cox : Clean up error flag in accept + * Alan Cox : TCP ack handling is buggy, the DESTROY timer + * was buggy. Put a remove_sock() in the handler + * for memory when we hit 0. Also altered the timer + * code. The ACK stuff can wait and needs major + * TCP layer surgery. + * Alan Cox : Fixed TCP ack bug, removed remove sock + * and fixed timer/inet_bh race. + * Alan Cox : Added zapped flag for TCP + * Alan Cox : Move kfree_skb into skbuff.c and tidied up surplus code + * Alan Cox : for new sk_buff allocations wmalloc/rmalloc now call alloc_skb + * Alan Cox : kfree_s calls now are kfree_skbmem so we can track skb resources + * Alan Cox : Supports socket option broadcast now as does udp. Packet and raw need fixing. + * Alan Cox : Added RCVBUF,SNDBUF size setting. It suddenly occurred to me how easy it was so... + * Rick Sladkey : Relaxed UDP rules for matching packets. + * C.E.Hawkins : IFF_PROMISC/SIOCGHWADDR support + * Pauline Middelink : identd support + * Alan Cox : Fixed connect() taking signals I think. + * Alan Cox : SO_LINGER supported + * Alan Cox : Error reporting fixes + * Anonymous : inet_create tidied up (sk->reuse setting) + * Alan Cox : inet sockets don't set sk->type! + * Alan Cox : Split socket option code + * Alan Cox : Callbacks + * Alan Cox : Nagle flag for Charles & Johannes stuff + * Alex : Removed restriction on inet fioctl + * Alan Cox : Splitting INET from NET core + * Alan Cox : Fixed bogus SO_TYPE handling in getsockopt() + * Adam Caldwell : Missing return in SO_DONTROUTE/SO_DEBUG code + * Alan Cox : Split IP from generic code + * Alan Cox : New kfree_skbmem() + * Alan Cox : Make SO_DEBUG superuser only. + * Alan Cox : Allow anyone to clear SO_DEBUG + * (compatibility fix) + * Alan Cox : Added optimistic memory grabbing for AF_UNIX throughput. + * Alan Cox : Allocator for a socket is settable. + * Alan Cox : SO_ERROR includes soft errors. + * Alan Cox : Allow NULL arguments on some SO_ opts + * Alan Cox : Generic socket allocation to make hooks + * easier (suggested by Craig Metz). + * Michael Pall : SO_ERROR returns positive errno again + * Steve Whitehouse: Added default destructor to free + * protocol private data. + * Steve Whitehouse: Added various other default routines + * common to several socket families. + * Chris Evans : Call suser() check last on F_SETOWN + * Jay Schulist : Added SO_ATTACH_FILTER and SO_DETACH_FILTER. 
+ * Andi Kleen : Add sock_kmalloc()/sock_kfree_s() + * Andi Kleen : Fix write_space callback + * Chris Evans : Security fixes - signedness again + * Arnaldo C. Melo : cleanups, use skb_queue_purge + * + * To Fix: + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include + +#include + +#include "dev.h" + +static DEFINE_MUTEX(proto_list_mutex); +static LIST_HEAD(proto_list); + +static void sock_def_write_space_wfree(struct sock *sk); +static void sock_def_write_space(struct sock *sk); + +/** + * sk_ns_capable - General socket capability test + * @sk: Socket to use a capability on or through + * @user_ns: The user namespace of the capability to use + * @cap: The capability to use + * + * Test to see if the opener of the socket had when the socket was + * created and the current process has the capability @cap in the user + * namespace @user_ns. + */ +bool sk_ns_capable(const struct sock *sk, + struct user_namespace *user_ns, int cap) +{ + return file_ns_capable(sk->sk_socket->file, user_ns, cap) && + ns_capable(user_ns, cap); +} +EXPORT_SYMBOL(sk_ns_capable); + +/** + * sk_capable - Socket global capability test + * @sk: Socket to use a capability on or through + * @cap: The global capability to use + * + * Test to see if the opener of the socket had when the socket was + * created and the current process has the capability @cap in all user + * namespaces. + */ +bool sk_capable(const struct sock *sk, int cap) +{ + return sk_ns_capable(sk, &init_user_ns, cap); +} +EXPORT_SYMBOL(sk_capable); + +/** + * sk_net_capable - Network namespace socket capability test + * @sk: Socket to use a capability on or through + * @cap: The capability to use + * + * Test to see if the opener of the socket had when the socket was created + * and the current process has the capability @cap over the network namespace + * the socket is a member of. + */ +bool sk_net_capable(const struct sock *sk, int cap) +{ + return sk_ns_capable(sk, sock_net(sk)->user_ns, cap); +} +EXPORT_SYMBOL(sk_net_capable); + +/* + * Each address family might have different locking rules, so we have + * one slock key per address family and separate keys for internal and + * userspace sockets. + */ +static struct lock_class_key af_family_keys[AF_MAX]; +static struct lock_class_key af_family_kern_keys[AF_MAX]; +static struct lock_class_key af_family_slock_keys[AF_MAX]; +static struct lock_class_key af_family_kern_slock_keys[AF_MAX]; + +/* + * Make lock validator output more readable. 
(we pre-construct these + * strings build-time, so that runtime initialization of socket + * locks is fast): + */ + +#define _sock_locks(x) \ + x "AF_UNSPEC", x "AF_UNIX" , x "AF_INET" , \ + x "AF_AX25" , x "AF_IPX" , x "AF_APPLETALK", \ + x "AF_NETROM", x "AF_BRIDGE" , x "AF_ATMPVC" , \ + x "AF_X25" , x "AF_INET6" , x "AF_ROSE" , \ + x "AF_DECnet", x "AF_NETBEUI" , x "AF_SECURITY" , \ + x "AF_KEY" , x "AF_NETLINK" , x "AF_PACKET" , \ + x "AF_ASH" , x "AF_ECONET" , x "AF_ATMSVC" , \ + x "AF_RDS" , x "AF_SNA" , x "AF_IRDA" , \ + x "AF_PPPOX" , x "AF_WANPIPE" , x "AF_LLC" , \ + x "27" , x "28" , x "AF_CAN" , \ + x "AF_TIPC" , x "AF_BLUETOOTH", x "IUCV" , \ + x "AF_RXRPC" , x "AF_ISDN" , x "AF_PHONET" , \ + x "AF_IEEE802154", x "AF_CAIF" , x "AF_ALG" , \ + x "AF_NFC" , x "AF_VSOCK" , x "AF_KCM" , \ + x "AF_QIPCRTR", x "AF_SMC" , x "AF_XDP" , \ + x "AF_MCTP" , \ + x "AF_MAX" + +static const char *const af_family_key_strings[AF_MAX+1] = { + _sock_locks("sk_lock-") +}; +static const char *const af_family_slock_key_strings[AF_MAX+1] = { + _sock_locks("slock-") +}; +static const char *const af_family_clock_key_strings[AF_MAX+1] = { + _sock_locks("clock-") +}; + +static const char *const af_family_kern_key_strings[AF_MAX+1] = { + _sock_locks("k-sk_lock-") +}; +static const char *const af_family_kern_slock_key_strings[AF_MAX+1] = { + _sock_locks("k-slock-") +}; +static const char *const af_family_kern_clock_key_strings[AF_MAX+1] = { + _sock_locks("k-clock-") +}; +static const char *const af_family_rlock_key_strings[AF_MAX+1] = { + _sock_locks("rlock-") +}; +static const char *const af_family_wlock_key_strings[AF_MAX+1] = { + _sock_locks("wlock-") +}; +static const char *const af_family_elock_key_strings[AF_MAX+1] = { + _sock_locks("elock-") +}; + +/* + * sk_callback_lock and sk queues locking rules are per-address-family, + * so split the lock classes by using a per-AF key: + */ +static struct lock_class_key af_callback_keys[AF_MAX]; +static struct lock_class_key af_rlock_keys[AF_MAX]; +static struct lock_class_key af_wlock_keys[AF_MAX]; +static struct lock_class_key af_elock_keys[AF_MAX]; +static struct lock_class_key af_kern_callback_keys[AF_MAX]; + +/* Run time adjustable parameters. */ +__u32 sysctl_wmem_max __read_mostly = SK_WMEM_MAX; +EXPORT_SYMBOL(sysctl_wmem_max); +__u32 sysctl_rmem_max __read_mostly = SK_RMEM_MAX; +EXPORT_SYMBOL(sysctl_rmem_max); +__u32 sysctl_wmem_default __read_mostly = SK_WMEM_MAX; +__u32 sysctl_rmem_default __read_mostly = SK_RMEM_MAX; + +/* Maximal space eaten by iovec or ancillary data plus some space */ +int sysctl_optmem_max __read_mostly = sizeof(unsigned long)*(2*UIO_MAXIOV+512); +EXPORT_SYMBOL(sysctl_optmem_max); + +int sysctl_tstamp_allow_data __read_mostly = 1; + +DEFINE_STATIC_KEY_FALSE(memalloc_socks_key); +EXPORT_SYMBOL_GPL(memalloc_socks_key); + +/** + * sk_set_memalloc - sets %SOCK_MEMALLOC + * @sk: socket to set it on + * + * Set %SOCK_MEMALLOC on a socket for access to emergency reserves. + * It's the responsibility of the admin to adjust min_free_kbytes + * to meet the requirements + */ +void sk_set_memalloc(struct sock *sk) +{ + sock_set_flag(sk, SOCK_MEMALLOC); + sk->sk_allocation |= __GFP_MEMALLOC; + static_branch_inc(&memalloc_socks_key); +} +EXPORT_SYMBOL_GPL(sk_set_memalloc); + +void sk_clear_memalloc(struct sock *sk) +{ + sock_reset_flag(sk, SOCK_MEMALLOC); + sk->sk_allocation &= ~__GFP_MEMALLOC; + static_branch_dec(&memalloc_socks_key); + + /* + * SOCK_MEMALLOC is allowed to ignore rmem limits to ensure forward + * progress of swapping. 
SOCK_MEMALLOC may be cleared while + * it has rmem allocations due to the last swapfile being deactivated + * but there is a risk that the socket is unusable due to exceeding + * the rmem limits. Reclaim the reserves and obey rmem limits again. + */ + sk_mem_reclaim(sk); +} +EXPORT_SYMBOL_GPL(sk_clear_memalloc); + +int __sk_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + int ret; + unsigned int noreclaim_flag; + + /* these should have been dropped before queueing */ + BUG_ON(!sock_flag(sk, SOCK_MEMALLOC)); + + noreclaim_flag = memalloc_noreclaim_save(); + ret = INDIRECT_CALL_INET(sk->sk_backlog_rcv, + tcp_v6_do_rcv, + tcp_v4_do_rcv, + sk, skb); + memalloc_noreclaim_restore(noreclaim_flag); + + return ret; +} +EXPORT_SYMBOL(__sk_backlog_rcv); + +void sk_error_report(struct sock *sk) +{ + sk->sk_error_report(sk); + + switch (sk->sk_family) { + case AF_INET: + fallthrough; + case AF_INET6: + trace_inet_sk_error_report(sk); + break; + default: + break; + } +} +EXPORT_SYMBOL(sk_error_report); + +int sock_get_timeout(long timeo, void *optval, bool old_timeval) +{ + struct __kernel_sock_timeval tv; + + if (timeo == MAX_SCHEDULE_TIMEOUT) { + tv.tv_sec = 0; + tv.tv_usec = 0; + } else { + tv.tv_sec = timeo / HZ; + tv.tv_usec = ((timeo % HZ) * USEC_PER_SEC) / HZ; + } + + if (old_timeval && in_compat_syscall() && !COMPAT_USE_64BIT_TIME) { + struct old_timeval32 tv32 = { tv.tv_sec, tv.tv_usec }; + *(struct old_timeval32 *)optval = tv32; + return sizeof(tv32); + } + + if (old_timeval) { + struct __kernel_old_timeval old_tv; + old_tv.tv_sec = tv.tv_sec; + old_tv.tv_usec = tv.tv_usec; + *(struct __kernel_old_timeval *)optval = old_tv; + return sizeof(old_tv); + } + + *(struct __kernel_sock_timeval *)optval = tv; + return sizeof(tv); +} +EXPORT_SYMBOL(sock_get_timeout); + +int sock_copy_user_timeval(struct __kernel_sock_timeval *tv, + sockptr_t optval, int optlen, bool old_timeval) +{ + if (old_timeval && in_compat_syscall() && !COMPAT_USE_64BIT_TIME) { + struct old_timeval32 tv32; + + if (optlen < sizeof(tv32)) + return -EINVAL; + + if (copy_from_sockptr(&tv32, optval, sizeof(tv32))) + return -EFAULT; + tv->tv_sec = tv32.tv_sec; + tv->tv_usec = tv32.tv_usec; + } else if (old_timeval) { + struct __kernel_old_timeval old_tv; + + if (optlen < sizeof(old_tv)) + return -EINVAL; + if (copy_from_sockptr(&old_tv, optval, sizeof(old_tv))) + return -EFAULT; + tv->tv_sec = old_tv.tv_sec; + tv->tv_usec = old_tv.tv_usec; + } else { + if (optlen < sizeof(*tv)) + return -EINVAL; + if (copy_from_sockptr(tv, optval, sizeof(*tv))) + return -EFAULT; + } + + return 0; +} +EXPORT_SYMBOL(sock_copy_user_timeval); + +static int sock_set_timeout(long *timeo_p, sockptr_t optval, int optlen, + bool old_timeval) +{ + struct __kernel_sock_timeval tv; + int err = sock_copy_user_timeval(&tv, optval, optlen, old_timeval); + long val; + + if (err) + return err; + + if (tv.tv_usec < 0 || tv.tv_usec >= USEC_PER_SEC) + return -EDOM; + + if (tv.tv_sec < 0) { + static int warned __read_mostly; + + WRITE_ONCE(*timeo_p, 0); + if (warned < 10 && net_ratelimit()) { + warned++; + pr_info("%s: `%s' (pid %d) tries to set negative timeout\n", + __func__, current->comm, task_pid_nr(current)); + } + return 0; + } + val = MAX_SCHEDULE_TIMEOUT; + if ((tv.tv_sec || tv.tv_usec) && + (tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1))) + val = tv.tv_sec * HZ + DIV_ROUND_UP((unsigned long)tv.tv_usec, + USEC_PER_SEC / HZ); + WRITE_ONCE(*timeo_p, val); + return 0; +} + +static bool sock_needs_netstamp(const struct sock *sk) +{ + switch (sk->sk_family) { + 
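	/* Note: this helper gates the global packet-timestamp machinery per
	 * address family. Purely local families (AF_UNSPEC and AF_UNIX below)
	 * return false, so sock_disable_timestamp() further down skips
	 * net_disable_timestamp() for them; every other family is treated as
	 * needing real network timestamps.
	 */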
case AF_UNSPEC: + case AF_UNIX: + return false; + default: + return true; + } +} + +static void sock_disable_timestamp(struct sock *sk, unsigned long flags) +{ + if (sk->sk_flags & flags) { + sk->sk_flags &= ~flags; + if (sock_needs_netstamp(sk) && + !(sk->sk_flags & SK_FLAGS_TIMESTAMP)) + net_disable_timestamp(); + } +} + + +int __sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + unsigned long flags; + struct sk_buff_head *list = &sk->sk_receive_queue; + + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) { + atomic_inc(&sk->sk_drops); + trace_sock_rcvqueue_full(sk, skb); + return -ENOMEM; + } + + if (!sk_rmem_schedule(sk, skb, skb->truesize)) { + atomic_inc(&sk->sk_drops); + return -ENOBUFS; + } + + skb->dev = NULL; + skb_set_owner_r(skb, sk); + + /* we escape from rcu protected region, make sure we dont leak + * a norefcounted dst + */ + skb_dst_force(skb); + + spin_lock_irqsave(&list->lock, flags); + sock_skb_set_dropcount(sk, skb); + __skb_queue_tail(list, skb); + spin_unlock_irqrestore(&list->lock, flags); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + return 0; +} +EXPORT_SYMBOL(__sock_queue_rcv_skb); + +int sock_queue_rcv_skb_reason(struct sock *sk, struct sk_buff *skb, + enum skb_drop_reason *reason) +{ + enum skb_drop_reason drop_reason; + int err; + + err = sk_filter(sk, skb); + if (err) { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + goto out; + } + err = __sock_queue_rcv_skb(sk, skb); + switch (err) { + case -ENOMEM: + drop_reason = SKB_DROP_REASON_SOCKET_RCVBUFF; + break; + case -ENOBUFS: + drop_reason = SKB_DROP_REASON_PROTO_MEM; + break; + default: + drop_reason = SKB_NOT_DROPPED_YET; + break; + } +out: + if (reason) + *reason = drop_reason; + return err; +} +EXPORT_SYMBOL(sock_queue_rcv_skb_reason); + +int __sk_receive_skb(struct sock *sk, struct sk_buff *skb, + const int nested, unsigned int trim_cap, bool refcounted) +{ + int rc = NET_RX_SUCCESS; + + if (sk_filter_trim_cap(sk, skb, trim_cap)) + goto discard_and_relse; + + skb->dev = NULL; + + if (sk_rcvqueues_full(sk, sk->sk_rcvbuf)) { + atomic_inc(&sk->sk_drops); + goto discard_and_relse; + } + if (nested) + bh_lock_sock_nested(sk); + else + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + /* + * trylock + unlock semantics: + */ + mutex_acquire(&sk->sk_lock.dep_map, 0, 1, _RET_IP_); + + rc = sk_backlog_rcv(sk, skb); + + mutex_release(&sk->sk_lock.dep_map, _RET_IP_); + } else if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf))) { + bh_unlock_sock(sk); + atomic_inc(&sk->sk_drops); + goto discard_and_relse; + } + + bh_unlock_sock(sk); +out: + if (refcounted) + sock_put(sk); + return rc; +discard_and_relse: + kfree_skb(skb); + goto out; +} +EXPORT_SYMBOL(__sk_receive_skb); + +INDIRECT_CALLABLE_DECLARE(struct dst_entry *ip6_dst_check(struct dst_entry *, + u32)); +INDIRECT_CALLABLE_DECLARE(struct dst_entry *ipv4_dst_check(struct dst_entry *, + u32)); +struct dst_entry *__sk_dst_check(struct sock *sk, u32 cookie) +{ + struct dst_entry *dst = __sk_dst_get(sk); + + if (dst && dst->obsolete && + INDIRECT_CALL_INET(dst->ops->check, ip6_dst_check, ipv4_dst_check, + dst, cookie) == NULL) { + sk_tx_queue_clear(sk); + WRITE_ONCE(sk->sk_dst_pending_confirm, 0); + RCU_INIT_POINTER(sk->sk_dst_cache, NULL); + dst_release(dst); + return NULL; + } + + return dst; +} +EXPORT_SYMBOL(__sk_dst_check); + +struct dst_entry *sk_dst_check(struct sock *sk, u32 cookie) +{ + struct dst_entry *dst = sk_dst_get(sk); + + if (dst && dst->obsolete && + INDIRECT_CALL_INET(dst->ops->check, ip6_dst_check, 
ipv4_dst_check, + dst, cookie) == NULL) { + sk_dst_reset(sk); + dst_release(dst); + return NULL; + } + + return dst; +} +EXPORT_SYMBOL(sk_dst_check); + +static int sock_bindtoindex_locked(struct sock *sk, int ifindex) +{ + int ret = -ENOPROTOOPT; +#ifdef CONFIG_NETDEVICES + struct net *net = sock_net(sk); + + /* Sorry... */ + ret = -EPERM; + if (sk->sk_bound_dev_if && !ns_capable(net->user_ns, CAP_NET_RAW)) + goto out; + + ret = -EINVAL; + if (ifindex < 0) + goto out; + + /* Paired with all READ_ONCE() done locklessly. */ + WRITE_ONCE(sk->sk_bound_dev_if, ifindex); + + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); + sk_dst_reset(sk); + + ret = 0; + +out: +#endif + + return ret; +} + +int sock_bindtoindex(struct sock *sk, int ifindex, bool lock_sk) +{ + int ret; + + if (lock_sk) + lock_sock(sk); + ret = sock_bindtoindex_locked(sk, ifindex); + if (lock_sk) + release_sock(sk); + + return ret; +} +EXPORT_SYMBOL(sock_bindtoindex); + +static int sock_setbindtodevice(struct sock *sk, sockptr_t optval, int optlen) +{ + int ret = -ENOPROTOOPT; +#ifdef CONFIG_NETDEVICES + struct net *net = sock_net(sk); + char devname[IFNAMSIZ]; + int index; + + ret = -EINVAL; + if (optlen < 0) + goto out; + + /* Bind this socket to a particular device like "eth0", + * as specified in the passed interface name. If the + * name is "" or the option length is zero the socket + * is not bound. + */ + if (optlen > IFNAMSIZ - 1) + optlen = IFNAMSIZ - 1; + memset(devname, 0, sizeof(devname)); + + ret = -EFAULT; + if (copy_from_sockptr(devname, optval, optlen)) + goto out; + + index = 0; + if (devname[0] != '\0') { + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_name_rcu(net, devname); + if (dev) + index = dev->ifindex; + rcu_read_unlock(); + ret = -ENODEV; + if (!dev) + goto out; + } + + sockopt_lock_sock(sk); + ret = sock_bindtoindex_locked(sk, index); + sockopt_release_sock(sk); +out: +#endif + + return ret; +} + +static int sock_getbindtodevice(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) +{ + int ret = -ENOPROTOOPT; +#ifdef CONFIG_NETDEVICES + int bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + struct net *net = sock_net(sk); + char devname[IFNAMSIZ]; + + if (bound_dev_if == 0) { + len = 0; + goto zero; + } + + ret = -EINVAL; + if (len < IFNAMSIZ) + goto out; + + ret = netdev_get_name(net, devname, bound_dev_if); + if (ret) + goto out; + + len = strlen(devname) + 1; + + ret = -EFAULT; + if (copy_to_sockptr(optval, devname, len)) + goto out; + +zero: + ret = -EFAULT; + if (copy_to_sockptr(optlen, &len, sizeof(int))) + goto out; + + ret = 0; + +out: +#endif + + return ret; +} + +bool sk_mc_loop(struct sock *sk) +{ + if (dev_recursion_level()) + return false; + if (!sk) + return true; + /* IPV6_ADDRFORM can change sk->sk_family under us. 
*/ + switch (READ_ONCE(sk->sk_family)) { + case AF_INET: + return inet_sk(sk)->mc_loop; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + return inet6_sk(sk)->mc_loop; +#endif + } + WARN_ON_ONCE(1); + return true; +} +EXPORT_SYMBOL(sk_mc_loop); + +void sock_set_reuseaddr(struct sock *sk) +{ + lock_sock(sk); + sk->sk_reuse = SK_CAN_REUSE; + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_reuseaddr); + +void sock_set_reuseport(struct sock *sk) +{ + lock_sock(sk); + sk->sk_reuseport = true; + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_reuseport); + +void sock_no_linger(struct sock *sk) +{ + lock_sock(sk); + WRITE_ONCE(sk->sk_lingertime, 0); + sock_set_flag(sk, SOCK_LINGER); + release_sock(sk); +} +EXPORT_SYMBOL(sock_no_linger); + +void sock_set_priority(struct sock *sk, u32 priority) +{ + lock_sock(sk); + WRITE_ONCE(sk->sk_priority, priority); + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_priority); + +void sock_set_sndtimeo(struct sock *sk, s64 secs) +{ + lock_sock(sk); + if (secs && secs < MAX_SCHEDULE_TIMEOUT / HZ - 1) + WRITE_ONCE(sk->sk_sndtimeo, secs * HZ); + else + WRITE_ONCE(sk->sk_sndtimeo, MAX_SCHEDULE_TIMEOUT); + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_sndtimeo); + +static void __sock_set_timestamps(struct sock *sk, bool val, bool new, bool ns) +{ + if (val) { + sock_valbool_flag(sk, SOCK_TSTAMP_NEW, new); + sock_valbool_flag(sk, SOCK_RCVTSTAMPNS, ns); + sock_set_flag(sk, SOCK_RCVTSTAMP); + sock_enable_timestamp(sk, SOCK_TIMESTAMP); + } else { + sock_reset_flag(sk, SOCK_RCVTSTAMP); + sock_reset_flag(sk, SOCK_RCVTSTAMPNS); + } +} + +void sock_enable_timestamps(struct sock *sk) +{ + lock_sock(sk); + __sock_set_timestamps(sk, true, false, true); + release_sock(sk); +} +EXPORT_SYMBOL(sock_enable_timestamps); + +void sock_set_timestamp(struct sock *sk, int optname, bool valbool) +{ + switch (optname) { + case SO_TIMESTAMP_OLD: + __sock_set_timestamps(sk, valbool, false, false); + break; + case SO_TIMESTAMP_NEW: + __sock_set_timestamps(sk, valbool, true, false); + break; + case SO_TIMESTAMPNS_OLD: + __sock_set_timestamps(sk, valbool, false, true); + break; + case SO_TIMESTAMPNS_NEW: + __sock_set_timestamps(sk, valbool, true, true); + break; + } +} + +static int sock_timestamping_bind_phc(struct sock *sk, int phc_index) +{ + struct net *net = sock_net(sk); + struct net_device *dev = NULL; + bool match = false; + int *vclock_index; + int i, num; + + if (sk->sk_bound_dev_if) + dev = dev_get_by_index(net, sk->sk_bound_dev_if); + + if (!dev) { + pr_err("%s: sock not bind to device\n", __func__); + return -EOPNOTSUPP; + } + + num = ethtool_get_phc_vclocks(dev, &vclock_index); + dev_put(dev); + + for (i = 0; i < num; i++) { + if (*(vclock_index + i) == phc_index) { + match = true; + break; + } + } + + if (num > 0) + kfree(vclock_index); + + if (!match) + return -EINVAL; + + WRITE_ONCE(sk->sk_bind_phc, phc_index); + + return 0; +} + +int sock_set_timestamping(struct sock *sk, int optname, + struct so_timestamping timestamping) +{ + int val = timestamping.flags; + int ret; + + if (val & ~SOF_TIMESTAMPING_MASK) + return -EINVAL; + + if (val & SOF_TIMESTAMPING_OPT_ID && + !(sk->sk_tsflags & SOF_TIMESTAMPING_OPT_ID)) { + if (sk_is_tcp(sk)) { + if ((1 << sk->sk_state) & + (TCPF_CLOSE | TCPF_LISTEN)) + return -EINVAL; + atomic_set(&sk->sk_tskey, tcp_sk(sk)->snd_una); + } else { + atomic_set(&sk->sk_tskey, 0); + } + } + + if (val & SOF_TIMESTAMPING_OPT_STATS && + !(val & SOF_TIMESTAMPING_OPT_TSONLY)) + return -EINVAL; + + if (val & SOF_TIMESTAMPING_BIND_PHC) { + ret = sock_timestamping_bind_phc(sk, 
timestamping.bind_phc); + if (ret) + return ret; + } + + WRITE_ONCE(sk->sk_tsflags, val); + sock_valbool_flag(sk, SOCK_TSTAMP_NEW, optname == SO_TIMESTAMPING_NEW); + + if (val & SOF_TIMESTAMPING_RX_SOFTWARE) + sock_enable_timestamp(sk, + SOCK_TIMESTAMPING_RX_SOFTWARE); + else + sock_disable_timestamp(sk, + (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE)); + return 0; +} + +void sock_set_keepalive(struct sock *sk) +{ + lock_sock(sk); + if (sk->sk_prot->keepalive) + sk->sk_prot->keepalive(sk, true); + sock_valbool_flag(sk, SOCK_KEEPOPEN, true); + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_keepalive); + +static void __sock_set_rcvbuf(struct sock *sk, int val) +{ + /* Ensure val * 2 fits into an int, to prevent max_t() from treating it + * as a negative value. + */ + val = min_t(int, val, INT_MAX / 2); + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + + /* We double it on the way in to account for "struct sk_buff" etc. + * overhead. Applications assume that the SO_RCVBUF setting they make + * will allow that much actual data to be received on that socket. + * + * Applications are unaware that "struct sk_buff" and other overheads + * allocate from the receive buffer during socket buffer allocation. + * + * And after considering the possible alternatives, returning the value + * we actually used in getsockopt is the most desirable behavior. + */ + WRITE_ONCE(sk->sk_rcvbuf, max_t(int, val * 2, SOCK_MIN_RCVBUF)); +} + +void sock_set_rcvbuf(struct sock *sk, int val) +{ + lock_sock(sk); + __sock_set_rcvbuf(sk, val); + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_rcvbuf); + +static void __sock_set_mark(struct sock *sk, u32 val) +{ + if (val != sk->sk_mark) { + WRITE_ONCE(sk->sk_mark, val); + sk_dst_reset(sk); + } +} + +void sock_set_mark(struct sock *sk, u32 val) +{ + lock_sock(sk); + __sock_set_mark(sk, val); + release_sock(sk); +} +EXPORT_SYMBOL(sock_set_mark); + +static void sock_release_reserved_memory(struct sock *sk, int bytes) +{ + /* Round down bytes to multiple of pages */ + bytes = round_down(bytes, PAGE_SIZE); + + WARN_ON(bytes > sk->sk_reserved_mem); + WRITE_ONCE(sk->sk_reserved_mem, sk->sk_reserved_mem - bytes); + sk_mem_reclaim(sk); +} + +static int sock_reserve_memory(struct sock *sk, int bytes) +{ + long allocated; + bool charged; + int pages; + + if (!mem_cgroup_sockets_enabled || !sk->sk_memcg || !sk_has_account(sk)) + return -EOPNOTSUPP; + + if (!bytes) + return 0; + + pages = sk_mem_pages(bytes); + + /* pre-charge to memcg */ + charged = mem_cgroup_charge_skmem(sk->sk_memcg, pages, + GFP_KERNEL | __GFP_RETRY_MAYFAIL); + if (!charged) + return -ENOMEM; + + /* pre-charge to forward_alloc */ + sk_memory_allocated_add(sk, pages); + allocated = sk_memory_allocated(sk); + /* If the system goes into memory pressure with this + * precharge, give up and return error. + */ + if (allocated > sk_prot_mem_limits(sk, 1)) { + sk_memory_allocated_sub(sk, pages); + mem_cgroup_uncharge_skmem(sk->sk_memcg, pages); + return -ENOMEM; + } + sk_forward_alloc_add(sk, pages << PAGE_SHIFT); + + WRITE_ONCE(sk->sk_reserved_mem, + sk->sk_reserved_mem + (pages << PAGE_SHIFT)); + + return 0; +} + +void sockopt_lock_sock(struct sock *sk) +{ + /* When current->bpf_ctx is set, the setsockopt is called from + * a bpf prog. bpf has ensured the sk lock has been + * acquired before calling setsockopt(). 
+ */ + if (has_current_bpf_ctx()) + return; + + lock_sock(sk); +} +EXPORT_SYMBOL(sockopt_lock_sock); + +void sockopt_release_sock(struct sock *sk) +{ + if (has_current_bpf_ctx()) + return; + + release_sock(sk); +} +EXPORT_SYMBOL(sockopt_release_sock); + +bool sockopt_ns_capable(struct user_namespace *ns, int cap) +{ + return has_current_bpf_ctx() || ns_capable(ns, cap); +} +EXPORT_SYMBOL(sockopt_ns_capable); + +bool sockopt_capable(int cap) +{ + return has_current_bpf_ctx() || capable(cap); +} +EXPORT_SYMBOL(sockopt_capable); + +/* + * This is meant for all protocols to use and covers goings on + * at the socket level. Everything here is generic. + */ + +int sk_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct so_timestamping timestamping; + struct socket *sock = sk->sk_socket; + struct sock_txtime sk_txtime; + int val; + int valbool; + struct linger ling; + int ret = 0; + + /* + * Options without arguments + */ + + if (optname == SO_BINDTODEVICE) + return sock_setbindtodevice(sk, optval, optlen); + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + valbool = val ? 1 : 0; + + sockopt_lock_sock(sk); + + switch (optname) { + case SO_DEBUG: + if (val && !sockopt_capable(CAP_NET_ADMIN)) + ret = -EACCES; + else + sock_valbool_flag(sk, SOCK_DBG, valbool); + break; + case SO_REUSEADDR: + sk->sk_reuse = (valbool ? SK_CAN_REUSE : SK_NO_REUSE); + break; + case SO_REUSEPORT: + sk->sk_reuseport = valbool; + break; + case SO_TYPE: + case SO_PROTOCOL: + case SO_DOMAIN: + case SO_ERROR: + ret = -ENOPROTOOPT; + break; + case SO_DONTROUTE: + sock_valbool_flag(sk, SOCK_LOCALROUTE, valbool); + sk_dst_reset(sk); + break; + case SO_BROADCAST: + sock_valbool_flag(sk, SOCK_BROADCAST, valbool); + break; + case SO_SNDBUF: + /* Don't error on this BSD doesn't and if you think + * about it this is right. Otherwise apps have to + * play 'guess the biggest size' games. RCVBUF/SNDBUF + * are treated in BSD as hints + */ + val = min_t(u32, val, READ_ONCE(sysctl_wmem_max)); +set_sndbuf: + /* Ensure val * 2 fits into an int, to prevent max_t() + * from treating it as a negative value. + */ + val = min_t(int, val, INT_MAX / 2); + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + WRITE_ONCE(sk->sk_sndbuf, + max_t(int, val * 2, SOCK_MIN_SNDBUF)); + /* Wake up sending tasks if we upped the value. */ + sk->sk_write_space(sk); + break; + + case SO_SNDBUFFORCE: + if (!sockopt_capable(CAP_NET_ADMIN)) { + ret = -EPERM; + break; + } + + /* No negative values (to prevent underflow, as val will be + * multiplied by 2). + */ + if (val < 0) + val = 0; + goto set_sndbuf; + + case SO_RCVBUF: + /* Don't error on this BSD doesn't and if you think + * about it this is right. Otherwise apps have to + * play 'guess the biggest size' games. RCVBUF/SNDBUF + * are treated in BSD as hints + */ + __sock_set_rcvbuf(sk, min_t(u32, val, READ_ONCE(sysctl_rmem_max))); + break; + + case SO_RCVBUFFORCE: + if (!sockopt_capable(CAP_NET_ADMIN)) { + ret = -EPERM; + break; + } + + /* No negative values (to prevent underflow, as val will be + * multiplied by 2). 
+ */ + __sock_set_rcvbuf(sk, max(val, 0)); + break; + + case SO_KEEPALIVE: + if (sk->sk_prot->keepalive) + sk->sk_prot->keepalive(sk, valbool); + sock_valbool_flag(sk, SOCK_KEEPOPEN, valbool); + break; + + case SO_OOBINLINE: + sock_valbool_flag(sk, SOCK_URGINLINE, valbool); + break; + + case SO_NO_CHECK: + sk->sk_no_check_tx = valbool; + break; + + case SO_PRIORITY: + if ((val >= 0 && val <= 6) || + sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) || + sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + WRITE_ONCE(sk->sk_priority, val); + else + ret = -EPERM; + break; + + case SO_LINGER: + if (optlen < sizeof(ling)) { + ret = -EINVAL; /* 1003.1g */ + break; + } + if (copy_from_sockptr(&ling, optval, sizeof(ling))) { + ret = -EFAULT; + break; + } + if (!ling.l_onoff) { + sock_reset_flag(sk, SOCK_LINGER); + } else { + unsigned long t_sec = ling.l_linger; + + if (t_sec >= MAX_SCHEDULE_TIMEOUT / HZ) + WRITE_ONCE(sk->sk_lingertime, MAX_SCHEDULE_TIMEOUT); + else + WRITE_ONCE(sk->sk_lingertime, t_sec * HZ); + sock_set_flag(sk, SOCK_LINGER); + } + break; + + case SO_BSDCOMPAT: + break; + + case SO_PASSCRED: + if (valbool) + set_bit(SOCK_PASSCRED, &sock->flags); + else + clear_bit(SOCK_PASSCRED, &sock->flags); + break; + + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: + case SO_TIMESTAMPNS_OLD: + case SO_TIMESTAMPNS_NEW: + sock_set_timestamp(sk, optname, valbool); + break; + + case SO_TIMESTAMPING_NEW: + case SO_TIMESTAMPING_OLD: + if (optlen == sizeof(timestamping)) { + if (copy_from_sockptr(×tamping, optval, + sizeof(timestamping))) { + ret = -EFAULT; + break; + } + } else { + memset(×tamping, 0, sizeof(timestamping)); + timestamping.flags = val; + } + ret = sock_set_timestamping(sk, optname, timestamping); + break; + + case SO_RCVLOWAT: + if (val < 0) + val = INT_MAX; + if (sock && sock->ops->set_rcvlowat) + ret = sock->ops->set_rcvlowat(sk, val); + else + WRITE_ONCE(sk->sk_rcvlowat, val ? 
: 1); + break; + + case SO_RCVTIMEO_OLD: + case SO_RCVTIMEO_NEW: + ret = sock_set_timeout(&sk->sk_rcvtimeo, optval, + optlen, optname == SO_RCVTIMEO_OLD); + break; + + case SO_SNDTIMEO_OLD: + case SO_SNDTIMEO_NEW: + ret = sock_set_timeout(&sk->sk_sndtimeo, optval, + optlen, optname == SO_SNDTIMEO_OLD); + break; + + case SO_ATTACH_FILTER: { + struct sock_fprog fprog; + + ret = copy_bpf_fprog_from_user(&fprog, optval, optlen); + if (!ret) + ret = sk_attach_filter(&fprog, sk); + break; + } + case SO_ATTACH_BPF: + ret = -EINVAL; + if (optlen == sizeof(u32)) { + u32 ufd; + + ret = -EFAULT; + if (copy_from_sockptr(&ufd, optval, sizeof(ufd))) + break; + + ret = sk_attach_bpf(ufd, sk); + } + break; + + case SO_ATTACH_REUSEPORT_CBPF: { + struct sock_fprog fprog; + + ret = copy_bpf_fprog_from_user(&fprog, optval, optlen); + if (!ret) + ret = sk_reuseport_attach_filter(&fprog, sk); + break; + } + case SO_ATTACH_REUSEPORT_EBPF: + ret = -EINVAL; + if (optlen == sizeof(u32)) { + u32 ufd; + + ret = -EFAULT; + if (copy_from_sockptr(&ufd, optval, sizeof(ufd))) + break; + + ret = sk_reuseport_attach_bpf(ufd, sk); + } + break; + + case SO_DETACH_REUSEPORT_BPF: + ret = reuseport_detach_prog(sk); + break; + + case SO_DETACH_FILTER: + ret = sk_detach_filter(sk); + break; + + case SO_LOCK_FILTER: + if (sock_flag(sk, SOCK_FILTER_LOCKED) && !valbool) + ret = -EPERM; + else + sock_valbool_flag(sk, SOCK_FILTER_LOCKED, valbool); + break; + + case SO_PASSSEC: + if (valbool) + set_bit(SOCK_PASSSEC, &sock->flags); + else + clear_bit(SOCK_PASSSEC, &sock->flags); + break; + case SO_MARK: + if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && + !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { + ret = -EPERM; + break; + } + + __sock_set_mark(sk, val); + break; + case SO_RCVMARK: + sock_valbool_flag(sk, SOCK_RCVMARK, valbool); + break; + + case SO_RXQ_OVFL: + sock_valbool_flag(sk, SOCK_RXQ_OVFL, valbool); + break; + + case SO_WIFI_STATUS: + sock_valbool_flag(sk, SOCK_WIFI_STATUS, valbool); + break; + + case SO_PEEK_OFF: + if (sock->ops->set_peek_off) + ret = sock->ops->set_peek_off(sk, val); + else + ret = -EOPNOTSUPP; + break; + + case SO_NOFCS: + sock_valbool_flag(sk, SOCK_NOFCS, valbool); + break; + + case SO_SELECT_ERR_QUEUE: + sock_valbool_flag(sk, SOCK_SELECT_ERR_QUEUE, valbool); + break; + +#ifdef CONFIG_NET_RX_BUSY_POLL + case SO_BUSY_POLL: + /* allow unprivileged users to decrease the value */ + if ((val > sk->sk_ll_usec) && !sockopt_capable(CAP_NET_ADMIN)) + ret = -EPERM; + else { + if (val < 0) + ret = -EINVAL; + else + WRITE_ONCE(sk->sk_ll_usec, val); + } + break; + case SO_PREFER_BUSY_POLL: + if (valbool && !sockopt_capable(CAP_NET_ADMIN)) + ret = -EPERM; + else + WRITE_ONCE(sk->sk_prefer_busy_poll, valbool); + break; + case SO_BUSY_POLL_BUDGET: + if (val > READ_ONCE(sk->sk_busy_poll_budget) && !sockopt_capable(CAP_NET_ADMIN)) { + ret = -EPERM; + } else { + if (val < 0 || val > U16_MAX) + ret = -EINVAL; + else + WRITE_ONCE(sk->sk_busy_poll_budget, val); + } + break; +#endif + + case SO_MAX_PACING_RATE: + { + unsigned long ulval = (val == ~0U) ? 
~0UL : (unsigned int)val; + + if (sizeof(ulval) != sizeof(val) && + optlen >= sizeof(ulval) && + copy_from_sockptr(&ulval, optval, sizeof(ulval))) { + ret = -EFAULT; + break; + } + if (ulval != ~0UL) + cmpxchg(&sk->sk_pacing_status, + SK_PACING_NONE, + SK_PACING_NEEDED); + /* Pairs with READ_ONCE() from sk_getsockopt() */ + WRITE_ONCE(sk->sk_max_pacing_rate, ulval); + sk->sk_pacing_rate = min(sk->sk_pacing_rate, ulval); + break; + } + case SO_INCOMING_CPU: + reuseport_update_incoming_cpu(sk, val); + break; + + case SO_CNX_ADVICE: + if (val == 1) + dst_negative_advice(sk); + break; + + case SO_ZEROCOPY: + if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6) { + if (!(sk_is_tcp(sk) || + (sk->sk_type == SOCK_DGRAM && + sk->sk_protocol == IPPROTO_UDP))) + ret = -EOPNOTSUPP; + } else if (sk->sk_family != PF_RDS) { + ret = -EOPNOTSUPP; + } + if (!ret) { + if (val < 0 || val > 1) + ret = -EINVAL; + else + sock_valbool_flag(sk, SOCK_ZEROCOPY, valbool); + } + break; + + case SO_TXTIME: + if (optlen != sizeof(struct sock_txtime)) { + ret = -EINVAL; + break; + } else if (copy_from_sockptr(&sk_txtime, optval, + sizeof(struct sock_txtime))) { + ret = -EFAULT; + break; + } else if (sk_txtime.flags & ~SOF_TXTIME_FLAGS_MASK) { + ret = -EINVAL; + break; + } + /* CLOCK_MONOTONIC is only used by sch_fq, and this packet + * scheduler has enough safe guards. + */ + if (sk_txtime.clockid != CLOCK_MONOTONIC && + !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { + ret = -EPERM; + break; + } + sock_valbool_flag(sk, SOCK_TXTIME, true); + sk->sk_clockid = sk_txtime.clockid; + sk->sk_txtime_deadline_mode = + !!(sk_txtime.flags & SOF_TXTIME_DEADLINE_MODE); + sk->sk_txtime_report_errors = + !!(sk_txtime.flags & SOF_TXTIME_REPORT_ERRORS); + break; + + case SO_BINDTOIFINDEX: + ret = sock_bindtoindex_locked(sk, val); + break; + + case SO_BUF_LOCK: + if (val & ~SOCK_BUF_LOCK_MASK) { + ret = -EINVAL; + break; + } + sk->sk_userlocks = val | (sk->sk_userlocks & + ~SOCK_BUF_LOCK_MASK); + break; + + case SO_RESERVE_MEM: + { + int delta; + + if (val < 0) { + ret = -EINVAL; + break; + } + + delta = val - sk->sk_reserved_mem; + if (delta < 0) + sock_release_reserved_memory(sk, -delta); + else + ret = sock_reserve_memory(sk, delta); + break; + } + + case SO_TXREHASH: + if (val < -1 || val > 1) { + ret = -EINVAL; + break; + } + if ((u8)val == SOCK_TXREHASH_DEFAULT) + val = READ_ONCE(sock_net(sk)->core.sysctl_txrehash); + /* Paired with READ_ONCE() in tcp_rtx_synack() + * and sk_getsockopt(). 
+ */ + WRITE_ONCE(sk->sk_txrehash, (u8)val); + break; + + default: + ret = -ENOPROTOOPT; + break; + } + sockopt_release_sock(sk); + return ret; +} + +int sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + return sk_setsockopt(sock->sk, level, optname, + optval, optlen); +} +EXPORT_SYMBOL(sock_setsockopt); + +static const struct cred *sk_get_peer_cred(struct sock *sk) +{ + const struct cred *cred; + + spin_lock(&sk->sk_peer_lock); + cred = get_cred(sk->sk_peer_cred); + spin_unlock(&sk->sk_peer_lock); + + return cred; +} + +static void cred_to_ucred(struct pid *pid, const struct cred *cred, + struct ucred *ucred) +{ + ucred->pid = pid_vnr(pid); + ucred->uid = ucred->gid = -1; + if (cred) { + struct user_namespace *current_ns = current_user_ns(); + + ucred->uid = from_kuid_munged(current_ns, cred->euid); + ucred->gid = from_kgid_munged(current_ns, cred->egid); + } +} + +static int groups_to_user(sockptr_t dst, const struct group_info *src) +{ + struct user_namespace *user_ns = current_user_ns(); + int i; + + for (i = 0; i < src->ngroups; i++) { + gid_t gid = from_kgid_munged(user_ns, src->gid[i]); + + if (copy_to_sockptr_offset(dst, i * sizeof(gid), &gid, sizeof(gid))) + return -EFAULT; + } + + return 0; +} + +int sk_getsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, sockptr_t optlen) +{ + struct socket *sock = sk->sk_socket; + + union { + int val; + u64 val64; + unsigned long ulval; + struct linger ling; + struct old_timeval32 tm32; + struct __kernel_old_timeval tm; + struct __kernel_sock_timeval stm; + struct sock_txtime txtime; + struct so_timestamping timestamping; + } v; + + int lv = sizeof(int); + int len; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + if (len < 0) + return -EINVAL; + + memset(&v, 0, sizeof(v)); + + switch (optname) { + case SO_DEBUG: + v.val = sock_flag(sk, SOCK_DBG); + break; + + case SO_DONTROUTE: + v.val = sock_flag(sk, SOCK_LOCALROUTE); + break; + + case SO_BROADCAST: + v.val = sock_flag(sk, SOCK_BROADCAST); + break; + + case SO_SNDBUF: + v.val = READ_ONCE(sk->sk_sndbuf); + break; + + case SO_RCVBUF: + v.val = READ_ONCE(sk->sk_rcvbuf); + break; + + case SO_REUSEADDR: + v.val = sk->sk_reuse; + break; + + case SO_REUSEPORT: + v.val = sk->sk_reuseport; + break; + + case SO_KEEPALIVE: + v.val = sock_flag(sk, SOCK_KEEPOPEN); + break; + + case SO_TYPE: + v.val = sk->sk_type; + break; + + case SO_PROTOCOL: + v.val = sk->sk_protocol; + break; + + case SO_DOMAIN: + v.val = sk->sk_family; + break; + + case SO_ERROR: + v.val = -sock_error(sk); + if (v.val == 0) + v.val = xchg(&sk->sk_err_soft, 0); + break; + + case SO_OOBINLINE: + v.val = sock_flag(sk, SOCK_URGINLINE); + break; + + case SO_NO_CHECK: + v.val = sk->sk_no_check_tx; + break; + + case SO_PRIORITY: + v.val = READ_ONCE(sk->sk_priority); + break; + + case SO_LINGER: + lv = sizeof(v.ling); + v.ling.l_onoff = sock_flag(sk, SOCK_LINGER); + v.ling.l_linger = READ_ONCE(sk->sk_lingertime) / HZ; + break; + + case SO_BSDCOMPAT: + break; + + case SO_TIMESTAMP_OLD: + v.val = sock_flag(sk, SOCK_RCVTSTAMP) && + !sock_flag(sk, SOCK_TSTAMP_NEW) && + !sock_flag(sk, SOCK_RCVTSTAMPNS); + break; + + case SO_TIMESTAMPNS_OLD: + v.val = sock_flag(sk, SOCK_RCVTSTAMPNS) && !sock_flag(sk, SOCK_TSTAMP_NEW); + break; + + case SO_TIMESTAMP_NEW: + v.val = sock_flag(sk, SOCK_RCVTSTAMP) && sock_flag(sk, SOCK_TSTAMP_NEW); + break; + + case SO_TIMESTAMPNS_NEW: + v.val = sock_flag(sk, SOCK_RCVTSTAMPNS) && sock_flag(sk, SOCK_TSTAMP_NEW); + break; + 
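	/* The four SO_TIMESTAMP{,NS}_{OLD,NEW} cases above only report the
	 * variant that matches the recorded SOCK_TSTAMP_NEW mode, so enabling
	 * one variant reads back as 0 through the other. A minimal user-space
	 * sketch, assuming a hypothetical connected socket descriptor "fd" and
	 * the uapi SO_TIMESTAMP_OLD/SO_TIMESTAMP_NEW definitions:
	 *
	 *	int on = 1, old_val = -1, new_val = -1;
	 *	socklen_t len = sizeof(int);
	 *
	 *	setsockopt(fd, SOL_SOCKET, SO_TIMESTAMP_NEW, &on, sizeof(on));
	 *	getsockopt(fd, SOL_SOCKET, SO_TIMESTAMP_OLD, &old_val, &len);
	 *	getsockopt(fd, SOL_SOCKET, SO_TIMESTAMP_NEW, &new_val, &len);
	 *
	 * With the flag tests above, old_val ends up 0 and new_val ends up 1.
	 */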
+ case SO_TIMESTAMPING_OLD: + case SO_TIMESTAMPING_NEW: + lv = sizeof(v.timestamping); + /* For the later-added case SO_TIMESTAMPING_NEW: Be strict about only + * returning the flags when they were set through the same option. + * Don't change the beviour for the old case SO_TIMESTAMPING_OLD. + */ + if (optname == SO_TIMESTAMPING_OLD || sock_flag(sk, SOCK_TSTAMP_NEW)) { + v.timestamping.flags = READ_ONCE(sk->sk_tsflags); + v.timestamping.bind_phc = READ_ONCE(sk->sk_bind_phc); + } + break; + + case SO_RCVTIMEO_OLD: + case SO_RCVTIMEO_NEW: + lv = sock_get_timeout(READ_ONCE(sk->sk_rcvtimeo), &v, + SO_RCVTIMEO_OLD == optname); + break; + + case SO_SNDTIMEO_OLD: + case SO_SNDTIMEO_NEW: + lv = sock_get_timeout(READ_ONCE(sk->sk_sndtimeo), &v, + SO_SNDTIMEO_OLD == optname); + break; + + case SO_RCVLOWAT: + v.val = READ_ONCE(sk->sk_rcvlowat); + break; + + case SO_SNDLOWAT: + v.val = 1; + break; + + case SO_PASSCRED: + v.val = !!test_bit(SOCK_PASSCRED, &sock->flags); + break; + + case SO_PEERCRED: + { + struct ucred peercred; + if (len > sizeof(peercred)) + len = sizeof(peercred); + + spin_lock(&sk->sk_peer_lock); + cred_to_ucred(sk->sk_peer_pid, sk->sk_peer_cred, &peercred); + spin_unlock(&sk->sk_peer_lock); + + if (copy_to_sockptr(optval, &peercred, len)) + return -EFAULT; + goto lenout; + } + + case SO_PEERGROUPS: + { + const struct cred *cred; + int ret, n; + + cred = sk_get_peer_cred(sk); + if (!cred) + return -ENODATA; + + n = cred->group_info->ngroups; + if (len < n * sizeof(gid_t)) { + len = n * sizeof(gid_t); + put_cred(cred); + return copy_to_sockptr(optlen, &len, sizeof(int)) ? -EFAULT : -ERANGE; + } + len = n * sizeof(gid_t); + + ret = groups_to_user(optval, cred->group_info); + put_cred(cred); + if (ret) + return ret; + goto lenout; + } + + case SO_PEERNAME: + { + struct sockaddr_storage address; + + lv = sock->ops->getname(sock, (struct sockaddr *)&address, 2); + if (lv < 0) + return -ENOTCONN; + if (lv < len) + return -EINVAL; + if (copy_to_sockptr(optval, &address, len)) + return -EFAULT; + goto lenout; + } + + /* Dubious BSD thing... Probably nobody even uses it, but + * the UNIX standard wants it for whatever reason... 
-DaveM + */ + case SO_ACCEPTCONN: + v.val = sk->sk_state == TCP_LISTEN; + break; + + case SO_PASSSEC: + v.val = !!test_bit(SOCK_PASSSEC, &sock->flags); + break; + + case SO_PEERSEC: + return security_socket_getpeersec_stream(sock, optval.user, optlen.user, len); + + case SO_MARK: + v.val = READ_ONCE(sk->sk_mark); + break; + + case SO_RCVMARK: + v.val = sock_flag(sk, SOCK_RCVMARK); + break; + + case SO_RXQ_OVFL: + v.val = sock_flag(sk, SOCK_RXQ_OVFL); + break; + + case SO_WIFI_STATUS: + v.val = sock_flag(sk, SOCK_WIFI_STATUS); + break; + + case SO_PEEK_OFF: + if (!sock->ops->set_peek_off) + return -EOPNOTSUPP; + + v.val = READ_ONCE(sk->sk_peek_off); + break; + case SO_NOFCS: + v.val = sock_flag(sk, SOCK_NOFCS); + break; + + case SO_BINDTODEVICE: + return sock_getbindtodevice(sk, optval, optlen, len); + + case SO_GET_FILTER: + len = sk_get_filter(sk, optval, len); + if (len < 0) + return len; + + goto lenout; + + case SO_LOCK_FILTER: + v.val = sock_flag(sk, SOCK_FILTER_LOCKED); + break; + + case SO_BPF_EXTENSIONS: + v.val = bpf_tell_extensions(); + break; + + case SO_SELECT_ERR_QUEUE: + v.val = sock_flag(sk, SOCK_SELECT_ERR_QUEUE); + break; + +#ifdef CONFIG_NET_RX_BUSY_POLL + case SO_BUSY_POLL: + v.val = READ_ONCE(sk->sk_ll_usec); + break; + case SO_PREFER_BUSY_POLL: + v.val = READ_ONCE(sk->sk_prefer_busy_poll); + break; +#endif + + case SO_MAX_PACING_RATE: + /* The READ_ONCE() pair with the WRITE_ONCE() in sk_setsockopt() */ + if (sizeof(v.ulval) != sizeof(v.val) && len >= sizeof(v.ulval)) { + lv = sizeof(v.ulval); + v.ulval = READ_ONCE(sk->sk_max_pacing_rate); + } else { + /* 32bit version */ + v.val = min_t(unsigned long, ~0U, + READ_ONCE(sk->sk_max_pacing_rate)); + } + break; + + case SO_INCOMING_CPU: + v.val = READ_ONCE(sk->sk_incoming_cpu); + break; + + case SO_MEMINFO: + { + u32 meminfo[SK_MEMINFO_VARS]; + + sk_get_meminfo(sk, meminfo); + + len = min_t(unsigned int, len, sizeof(meminfo)); + if (copy_to_sockptr(optval, &meminfo, len)) + return -EFAULT; + + goto lenout; + } + +#ifdef CONFIG_NET_RX_BUSY_POLL + case SO_INCOMING_NAPI_ID: + v.val = READ_ONCE(sk->sk_napi_id); + + /* aggregate non-NAPI IDs down to 0 */ + if (v.val < MIN_NAPI_ID) + v.val = 0; + + break; +#endif + + case SO_COOKIE: + lv = sizeof(u64); + if (len < lv) + return -EINVAL; + v.val64 = sock_gen_cookie(sk); + break; + + case SO_ZEROCOPY: + v.val = sock_flag(sk, SOCK_ZEROCOPY); + break; + + case SO_TXTIME: + lv = sizeof(v.txtime); + v.txtime.clockid = sk->sk_clockid; + v.txtime.flags |= sk->sk_txtime_deadline_mode ? + SOF_TXTIME_DEADLINE_MODE : 0; + v.txtime.flags |= sk->sk_txtime_report_errors ? + SOF_TXTIME_REPORT_ERRORS : 0; + break; + + case SO_BINDTOIFINDEX: + v.val = READ_ONCE(sk->sk_bound_dev_if); + break; + + case SO_NETNS_COOKIE: + lv = sizeof(u64); + if (len != lv) + return -EINVAL; + v.val64 = sock_net(sk)->net_cookie; + break; + + case SO_BUF_LOCK: + v.val = sk->sk_userlocks & SOCK_BUF_LOCK_MASK; + break; + + case SO_RESERVE_MEM: + v.val = READ_ONCE(sk->sk_reserved_mem); + break; + + case SO_TXREHASH: + /* Paired with WRITE_ONCE() in sk_setsockopt() */ + v.val = READ_ONCE(sk->sk_txrehash); + break; + + default: + /* We implement the SO_SNDLOWAT etc to not be settable + * (1003.1g 7). 
+ */ + return -ENOPROTOOPT; + } + + if (len > lv) + len = lv; + if (copy_to_sockptr(optval, &v, len)) + return -EFAULT; +lenout: + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + return 0; +} + +int sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + return sk_getsockopt(sock->sk, level, optname, + USER_SOCKPTR(optval), + USER_SOCKPTR(optlen)); +} + +/* + * Initialize an sk_lock. + * + * (We also register the sk_lock with the lock validator.) + */ +static inline void sock_lock_init(struct sock *sk) +{ + if (sk->sk_kern_sock) + sock_lock_init_class_and_name( + sk, + af_family_kern_slock_key_strings[sk->sk_family], + af_family_kern_slock_keys + sk->sk_family, + af_family_kern_key_strings[sk->sk_family], + af_family_kern_keys + sk->sk_family); + else + sock_lock_init_class_and_name( + sk, + af_family_slock_key_strings[sk->sk_family], + af_family_slock_keys + sk->sk_family, + af_family_key_strings[sk->sk_family], + af_family_keys + sk->sk_family); +} + +/* + * Copy all fields from osk to nsk but nsk->sk_refcnt must not change yet, + * even temporarly, because of RCU lookups. sk_node should also be left as is. + * We must not copy fields between sk_dontcopy_begin and sk_dontcopy_end + */ +static void sock_copy(struct sock *nsk, const struct sock *osk) +{ + const struct proto *prot = READ_ONCE(osk->sk_prot); +#ifdef CONFIG_SECURITY_NETWORK + void *sptr = nsk->sk_security; +#endif + + /* If we move sk_tx_queue_mapping out of the private section, + * we must check if sk_tx_queue_clear() is called after + * sock_copy() in sk_clone_lock(). + */ + BUILD_BUG_ON(offsetof(struct sock, sk_tx_queue_mapping) < + offsetof(struct sock, sk_dontcopy_begin) || + offsetof(struct sock, sk_tx_queue_mapping) >= + offsetof(struct sock, sk_dontcopy_end)); + + memcpy(nsk, osk, offsetof(struct sock, sk_dontcopy_begin)); + + memcpy(&nsk->sk_dontcopy_end, &osk->sk_dontcopy_end, + prot->obj_size - offsetof(struct sock, sk_dontcopy_end)); + +#ifdef CONFIG_SECURITY_NETWORK + nsk->sk_security = sptr; + security_sk_clone(osk, nsk); +#endif +} + +static struct sock *sk_prot_alloc(struct proto *prot, gfp_t priority, + int family) +{ + struct sock *sk; + struct kmem_cache *slab; + + slab = prot->slab; + if (slab != NULL) { + sk = kmem_cache_alloc(slab, priority & ~__GFP_ZERO); + if (!sk) + return sk; + if (want_init_on_alloc(priority)) + sk_prot_clear_nulls(sk, prot->obj_size); + } else + sk = kmalloc(prot->obj_size, priority); + + if (sk != NULL) { + if (security_sk_alloc(sk, family, priority)) + goto out_free; + + if (!try_module_get(prot->owner)) + goto out_free_sec; + } + + return sk; + +out_free_sec: + security_sk_free(sk); +out_free: + if (slab != NULL) + kmem_cache_free(slab, sk); + else + kfree(sk); + return NULL; +} + +static void sk_prot_free(struct proto *prot, struct sock *sk) +{ + struct kmem_cache *slab; + struct module *owner; + + owner = prot->owner; + slab = prot->slab; + + cgroup_sk_free(&sk->sk_cgrp_data); + mem_cgroup_sk_free(sk); + security_sk_free(sk); + if (slab != NULL) + kmem_cache_free(slab, sk); + else + kfree(sk); + module_put(owner); +} + +/** + * sk_alloc - All socket objects are allocated here + * @net: the applicable net namespace + * @family: protocol family + * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * @prot: struct proto associated with this new sock instance + * @kern: is this to be a kernel socket? 
+ */ +struct sock *sk_alloc(struct net *net, int family, gfp_t priority, + struct proto *prot, int kern) +{ + struct sock *sk; + + sk = sk_prot_alloc(prot, priority | __GFP_ZERO, family); + if (sk) { + sk->sk_family = family; + /* + * See comment in struct sock definition to understand + * why we need sk_prot_creator -acme + */ + sk->sk_prot = sk->sk_prot_creator = prot; + sk->sk_kern_sock = kern; + sock_lock_init(sk); + sk->sk_net_refcnt = kern ? 0 : 1; + if (likely(sk->sk_net_refcnt)) { + get_net_track(net, &sk->ns_tracker, priority); + sock_inuse_add(net, 1); + } + + sock_net_set(sk, net); + refcount_set(&sk->sk_wmem_alloc, 1); + + mem_cgroup_sk_alloc(sk); + cgroup_sk_alloc(&sk->sk_cgrp_data); + sock_update_classid(&sk->sk_cgrp_data); + sock_update_netprioidx(&sk->sk_cgrp_data); + sk_tx_queue_clear(sk); + } + + return sk; +} +EXPORT_SYMBOL(sk_alloc); + +/* Sockets having SOCK_RCU_FREE will call this function after one RCU + * grace period. This is the case for UDP sockets and TCP listeners. + */ +static void __sk_destruct(struct rcu_head *head) +{ + struct sock *sk = container_of(head, struct sock, sk_rcu); + struct sk_filter *filter; + + if (sk->sk_destruct) + sk->sk_destruct(sk); + + filter = rcu_dereference_check(sk->sk_filter, + refcount_read(&sk->sk_wmem_alloc) == 0); + if (filter) { + sk_filter_uncharge(sk, filter); + RCU_INIT_POINTER(sk->sk_filter, NULL); + } + + sock_disable_timestamp(sk, SK_FLAGS_TIMESTAMP); + +#ifdef CONFIG_BPF_SYSCALL + bpf_sk_storage_free(sk); +#endif + + if (atomic_read(&sk->sk_omem_alloc)) + pr_debug("%s: optmem leakage (%d bytes) detected\n", + __func__, atomic_read(&sk->sk_omem_alloc)); + + if (sk->sk_frag.page) { + put_page(sk->sk_frag.page); + sk->sk_frag.page = NULL; + } + + /* We do not need to acquire sk->sk_peer_lock, we are the last user. */ + put_cred(sk->sk_peer_cred); + put_pid(sk->sk_peer_pid); + + if (likely(sk->sk_net_refcnt)) + put_net_track(sock_net(sk), &sk->ns_tracker); + sk_prot_free(sk->sk_prot_creator, sk); +} + +void sk_destruct(struct sock *sk) +{ + bool use_call_rcu = sock_flag(sk, SOCK_RCU_FREE); + + if (rcu_access_pointer(sk->sk_reuseport_cb)) { + reuseport_detach_sock(sk); + use_call_rcu = true; + } + + if (use_call_rcu) + call_rcu(&sk->sk_rcu, __sk_destruct); + else + __sk_destruct(&sk->sk_rcu); +} + +static void __sk_free(struct sock *sk) +{ + if (likely(sk->sk_net_refcnt)) + sock_inuse_add(sock_net(sk), -1); + + if (unlikely(sk->sk_net_refcnt && sock_diag_has_destroy_listeners(sk))) + sock_diag_broadcast_destroy(sk); + else + sk_destruct(sk); +} + +void sk_free(struct sock *sk) +{ + /* + * We subtract one from sk_wmem_alloc and can know if + * some packets are still in some tx queue. 
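+	 * (sk_wmem_alloc is initialised to one in sk_alloc(), so it can only
+	 * drop to zero here once every in-flight skb has been freed.)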
+ * If not null, sock_wfree() will call __sk_free(sk) later + */ + if (refcount_dec_and_test(&sk->sk_wmem_alloc)) + __sk_free(sk); +} +EXPORT_SYMBOL(sk_free); + +static void sk_init_common(struct sock *sk) +{ + skb_queue_head_init(&sk->sk_receive_queue); + skb_queue_head_init(&sk->sk_write_queue); + skb_queue_head_init(&sk->sk_error_queue); + + rwlock_init(&sk->sk_callback_lock); + lockdep_set_class_and_name(&sk->sk_receive_queue.lock, + af_rlock_keys + sk->sk_family, + af_family_rlock_key_strings[sk->sk_family]); + lockdep_set_class_and_name(&sk->sk_write_queue.lock, + af_wlock_keys + sk->sk_family, + af_family_wlock_key_strings[sk->sk_family]); + lockdep_set_class_and_name(&sk->sk_error_queue.lock, + af_elock_keys + sk->sk_family, + af_family_elock_key_strings[sk->sk_family]); + lockdep_set_class_and_name(&sk->sk_callback_lock, + af_callback_keys + sk->sk_family, + af_family_clock_key_strings[sk->sk_family]); +} + +/** + * sk_clone_lock - clone a socket, and lock its clone + * @sk: the socket to clone + * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * + * Caller must unlock socket even in error path (bh_unlock_sock(newsk)) + */ +struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) +{ + struct proto *prot = READ_ONCE(sk->sk_prot); + struct sk_filter *filter; + bool is_charged = true; + struct sock *newsk; + + newsk = sk_prot_alloc(prot, priority, sk->sk_family); + if (!newsk) + goto out; + + sock_copy(newsk, sk); + + newsk->sk_prot_creator = prot; + + /* SANITY */ + if (likely(newsk->sk_net_refcnt)) { + get_net_track(sock_net(newsk), &newsk->ns_tracker, priority); + sock_inuse_add(sock_net(newsk), 1); + } + sk_node_init(&newsk->sk_node); + sock_lock_init(newsk); + bh_lock_sock(newsk); + newsk->sk_backlog.head = newsk->sk_backlog.tail = NULL; + newsk->sk_backlog.len = 0; + + atomic_set(&newsk->sk_rmem_alloc, 0); + + /* sk_wmem_alloc set to one (see sk_free() and sock_wfree()) */ + refcount_set(&newsk->sk_wmem_alloc, 1); + + atomic_set(&newsk->sk_omem_alloc, 0); + sk_init_common(newsk); + + newsk->sk_dst_cache = NULL; + newsk->sk_dst_pending_confirm = 0; + newsk->sk_wmem_queued = 0; + newsk->sk_forward_alloc = 0; + newsk->sk_reserved_mem = 0; + atomic_set(&newsk->sk_drops, 0); + newsk->sk_send_head = NULL; + newsk->sk_userlocks = sk->sk_userlocks & ~SOCK_BINDPORT_LOCK; + atomic_set(&newsk->sk_zckey, 0); + + sock_reset_flag(newsk, SOCK_DONE); + + /* sk->sk_memcg will be populated at accept() time */ + newsk->sk_memcg = NULL; + + cgroup_sk_clone(&newsk->sk_cgrp_data); + + rcu_read_lock(); + filter = rcu_dereference(sk->sk_filter); + if (filter != NULL) + /* though it's an empty new sock, the charging may fail + * if sysctl_optmem_max was changed between creation of + * original socket and cloning + */ + is_charged = sk_filter_charge(newsk, filter); + RCU_INIT_POINTER(newsk->sk_filter, filter); + rcu_read_unlock(); + + if (unlikely(!is_charged || xfrm_sk_clone_policy(newsk, sk))) { + /* We need to make sure that we don't uncharge the new + * socket if we couldn't charge it in the first place + * as otherwise we uncharge the parent's filter. + */ + if (!is_charged) + RCU_INIT_POINTER(newsk->sk_filter, NULL); + sk_free_unlock_clone(newsk); + newsk = NULL; + goto out; + } + RCU_INIT_POINTER(newsk->sk_reuseport_cb, NULL); + + if (bpf_sk_storage_clone(sk, newsk)) { + sk_free_unlock_clone(newsk); + newsk = NULL; + goto out; + } + + /* Clear sk_user_data if parent had the pointer tagged + * as not suitable for copying when cloning. 
+ */ + if (sk_user_data_is_nocopy(newsk)) + newsk->sk_user_data = NULL; + + newsk->sk_err = 0; + newsk->sk_err_soft = 0; + newsk->sk_priority = 0; + newsk->sk_incoming_cpu = raw_smp_processor_id(); + + /* Before updating sk_refcnt, we must commit prior changes to memory + * (Documentation/RCU/rculist_nulls.rst for details) + */ + smp_wmb(); + refcount_set(&newsk->sk_refcnt, 2); + + /* Increment the counter in the same struct proto as the master + * sock (sk_refcnt_debug_inc uses newsk->sk_prot->socks, that + * is the same as sk->sk_prot->socks, as this field was copied + * with memcpy). + * + * This _changes_ the previous behaviour, where + * tcp_create_openreq_child always was incrementing the + * equivalent to tcp_prot->socks (inet_sock_nr), so this have + * to be taken into account in all callers. -acme + */ + sk_refcnt_debug_inc(newsk); + sk_set_socket(newsk, NULL); + sk_tx_queue_clear(newsk); + RCU_INIT_POINTER(newsk->sk_wq, NULL); + + if (newsk->sk_prot->sockets_allocated) + sk_sockets_allocated_inc(newsk); + + if (sock_needs_netstamp(sk) && newsk->sk_flags & SK_FLAGS_TIMESTAMP) + net_enable_timestamp(); +out: + return newsk; +} +EXPORT_SYMBOL_GPL(sk_clone_lock); + +void sk_free_unlock_clone(struct sock *sk) +{ + /* It is still raw copy of parent, so invalidate + * destructor and make plain sk_free() */ + sk->sk_destruct = NULL; + bh_unlock_sock(sk); + sk_free(sk); +} +EXPORT_SYMBOL_GPL(sk_free_unlock_clone); + +static void sk_trim_gso_size(struct sock *sk) +{ + if (sk->sk_gso_max_size <= GSO_LEGACY_MAX_SIZE) + return; +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6 && + sk_is_tcp(sk) && + !ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + return; +#endif + sk->sk_gso_max_size = GSO_LEGACY_MAX_SIZE; +} + +void sk_setup_caps(struct sock *sk, struct dst_entry *dst) +{ + u32 max_segs = 1; + + sk->sk_route_caps = dst->dev->features; + if (sk_is_tcp(sk)) + sk->sk_route_caps |= NETIF_F_GSO; + if (sk->sk_route_caps & NETIF_F_GSO) + sk->sk_route_caps |= NETIF_F_GSO_SOFTWARE; + if (unlikely(sk->sk_gso_disabled)) + sk->sk_route_caps &= ~NETIF_F_GSO_MASK; + if (sk_can_gso(sk)) { + if (dst->header_len && !xfrm_dst_offload_ok(dst)) { + sk->sk_route_caps &= ~NETIF_F_GSO_MASK; + } else { + sk->sk_route_caps |= NETIF_F_SG | NETIF_F_HW_CSUM; + /* pairs with the WRITE_ONCE() in netif_set_gso_max_size() */ + sk->sk_gso_max_size = READ_ONCE(dst->dev->gso_max_size); + sk_trim_gso_size(sk); + sk->sk_gso_max_size -= (MAX_TCP_HEADER + 1); + /* pairs with the WRITE_ONCE() in netif_set_gso_max_segs() */ + max_segs = max_t(u32, READ_ONCE(dst->dev->gso_max_segs), 1); + } + } + sk->sk_gso_max_segs = max_segs; + sk_dst_set(sk, dst); +} +EXPORT_SYMBOL_GPL(sk_setup_caps); + +/* + * Simple resource managers for sockets. + */ + + +/* + * Write buffer destructor automatically called from kfree_skb. 
+ */ +void sock_wfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + unsigned int len = skb->truesize; + bool free; + + if (!sock_flag(sk, SOCK_USE_WRITE_QUEUE)) { + if (sock_flag(sk, SOCK_RCU_FREE) && + sk->sk_write_space == sock_def_write_space) { + rcu_read_lock(); + free = refcount_sub_and_test(len, &sk->sk_wmem_alloc); + sock_def_write_space_wfree(sk); + rcu_read_unlock(); + if (unlikely(free)) + __sk_free(sk); + return; + } + + /* + * Keep a reference on sk_wmem_alloc, this will be released + * after sk_write_space() call + */ + WARN_ON(refcount_sub_and_test(len - 1, &sk->sk_wmem_alloc)); + sk->sk_write_space(sk); + len = 1; + } + /* + * if sk_wmem_alloc reaches 0, we must finish what sk_free() + * could not do because of in-flight packets + */ + if (refcount_sub_and_test(len, &sk->sk_wmem_alloc)) + __sk_free(sk); +} +EXPORT_SYMBOL(sock_wfree); + +/* This variant of sock_wfree() is used by TCP, + * since it sets SOCK_USE_WRITE_QUEUE. + */ +void __sock_wfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + + if (refcount_sub_and_test(skb->truesize, &sk->sk_wmem_alloc)) + __sk_free(sk); +} + +void skb_set_owner_w(struct sk_buff *skb, struct sock *sk) +{ + skb_orphan(skb); + skb->sk = sk; +#ifdef CONFIG_INET + if (unlikely(!sk_fullsock(sk))) { + skb->destructor = sock_edemux; + sock_hold(sk); + return; + } +#endif + skb->destructor = sock_wfree; + skb_set_hash_from_sk(skb, sk); + /* + * We used to take a refcount on sk, but following operation + * is enough to guarantee sk_free() wont free this sock until + * all in-flight packets are completed + */ + refcount_add(skb->truesize, &sk->sk_wmem_alloc); +} +EXPORT_SYMBOL(skb_set_owner_w); + +static bool can_skb_orphan_partial(const struct sk_buff *skb) +{ +#ifdef CONFIG_TLS_DEVICE + /* Drivers depend on in-order delivery for crypto offload, + * partial orphan breaks out-of-order-OK logic. + */ + if (skb->decrypted) + return false; +#endif + return (skb->destructor == sock_wfree || + (IS_ENABLED(CONFIG_INET) && skb->destructor == tcp_wfree)); +} + +/* This helper is used by netem, as it can hold packets in its + * delay queue. We want to allow the owner socket to send more + * packets, as if they were already TX completed by a typical driver. + * But we also want to keep skb->sk set because some packet schedulers + * rely on it (sch_fq for example). + */ +void skb_orphan_partial(struct sk_buff *skb) +{ + if (skb_is_tcp_pure_ack(skb)) + return; + + if (can_skb_orphan_partial(skb) && skb_set_owner_sk_safe(skb, skb->sk)) + return; + + skb_orphan(skb); +} +EXPORT_SYMBOL(skb_orphan_partial); + +/* + * Read buffer destructor automatically called from kfree_skb. + */ +void sock_rfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + unsigned int len = skb->truesize; + + atomic_sub(len, &sk->sk_rmem_alloc); + sk_mem_uncharge(sk, len); +} +EXPORT_SYMBOL(sock_rfree); + +/* + * Buffer destructor for skbs that are not used directly in read or write + * path, e.g. for error handler skbs. Automatically called from kfree_skb. + */ +void sock_efree(struct sk_buff *skb) +{ + sock_put(skb->sk); +} +EXPORT_SYMBOL(sock_efree); + +/* Buffer destructor for prefetch/receive path where reference count may + * not be held, e.g. for listen sockets. + */ +#ifdef CONFIG_INET +void sock_pfree(struct sk_buff *skb) +{ + if (sk_is_refcounted(skb->sk)) + sock_gen_put(skb->sk); +} +EXPORT_SYMBOL(sock_pfree); +#endif /* CONFIG_INET */ + +kuid_t sock_i_uid(struct sock *sk) +{ + kuid_t uid; + + read_lock_bh(&sk->sk_callback_lock); + uid = sk->sk_socket ? 
SOCK_INODE(sk->sk_socket)->i_uid : GLOBAL_ROOT_UID; + read_unlock_bh(&sk->sk_callback_lock); + return uid; +} +EXPORT_SYMBOL(sock_i_uid); + +unsigned long __sock_i_ino(struct sock *sk) +{ + unsigned long ino; + + read_lock(&sk->sk_callback_lock); + ino = sk->sk_socket ? SOCK_INODE(sk->sk_socket)->i_ino : 0; + read_unlock(&sk->sk_callback_lock); + return ino; +} +EXPORT_SYMBOL(__sock_i_ino); + +unsigned long sock_i_ino(struct sock *sk) +{ + unsigned long ino; + + local_bh_disable(); + ino = __sock_i_ino(sk); + local_bh_enable(); + return ino; +} +EXPORT_SYMBOL(sock_i_ino); + +/* + * Allocate a skb from the socket's send buffer. + */ +struct sk_buff *sock_wmalloc(struct sock *sk, unsigned long size, int force, + gfp_t priority) +{ + if (force || + refcount_read(&sk->sk_wmem_alloc) < READ_ONCE(sk->sk_sndbuf)) { + struct sk_buff *skb = alloc_skb(size, priority); + + if (skb) { + skb_set_owner_w(skb, sk); + return skb; + } + } + return NULL; +} +EXPORT_SYMBOL(sock_wmalloc); + +static void sock_ofree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + + atomic_sub(skb->truesize, &sk->sk_omem_alloc); +} + +struct sk_buff *sock_omalloc(struct sock *sk, unsigned long size, + gfp_t priority) +{ + struct sk_buff *skb; + + /* small safe race: SKB_TRUESIZE may differ from final skb->truesize */ + if (atomic_read(&sk->sk_omem_alloc) + SKB_TRUESIZE(size) > + READ_ONCE(sysctl_optmem_max)) + return NULL; + + skb = alloc_skb(size, priority); + if (!skb) + return NULL; + + atomic_add(skb->truesize, &sk->sk_omem_alloc); + skb->sk = sk; + skb->destructor = sock_ofree; + return skb; +} + +/* + * Allocate a memory block from the socket's option memory buffer. + */ +void *sock_kmalloc(struct sock *sk, int size, gfp_t priority) +{ + int optmem_max = READ_ONCE(sysctl_optmem_max); + + if ((unsigned int)size <= optmem_max && + atomic_read(&sk->sk_omem_alloc) + size < optmem_max) { + void *mem; + /* First do the add, to avoid the race if kmalloc + * might sleep. + */ + atomic_add(size, &sk->sk_omem_alloc); + mem = kmalloc(size, priority); + if (mem) + return mem; + atomic_sub(size, &sk->sk_omem_alloc); + } + return NULL; +} +EXPORT_SYMBOL(sock_kmalloc); + +/* Free an option memory block. Note, we actually want the inline + * here as this allows gcc to detect the nullify and fold away the + * condition entirely. + */ +static inline void __sock_kfree_s(struct sock *sk, void *mem, int size, + const bool nullify) +{ + if (WARN_ON_ONCE(!mem)) + return; + if (nullify) + kfree_sensitive(mem); + else + kfree(mem); + atomic_sub(size, &sk->sk_omem_alloc); +} + +void sock_kfree_s(struct sock *sk, void *mem, int size) +{ + __sock_kfree_s(sk, mem, size, false); +} +EXPORT_SYMBOL(sock_kfree_s); + +void sock_kzfree_s(struct sock *sk, void *mem, int size) +{ + __sock_kfree_s(sk, mem, size, true); +} +EXPORT_SYMBOL(sock_kzfree_s); + +/* It is almost wait_for_tcp_memory minus release_sock/lock_sock. + I think, these locks should be removed for datagram sockets. 
+ */ +static long sock_wait_for_wmem(struct sock *sk, long timeo) +{ + DEFINE_WAIT(wait); + + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + for (;;) { + if (!timeo) + break; + if (signal_pending(current)) + break; + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + if (refcount_read(&sk->sk_wmem_alloc) < READ_ONCE(sk->sk_sndbuf)) + break; + if (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) + break; + if (READ_ONCE(sk->sk_err)) + break; + timeo = schedule_timeout(timeo); + } + finish_wait(sk_sleep(sk), &wait); + return timeo; +} + + +/* + * Generic send/receive buffer handlers + */ + +struct sk_buff *sock_alloc_send_pskb(struct sock *sk, unsigned long header_len, + unsigned long data_len, int noblock, + int *errcode, int max_page_order) +{ + struct sk_buff *skb; + long timeo; + int err; + + timeo = sock_sndtimeo(sk, noblock); + for (;;) { + err = sock_error(sk); + if (err != 0) + goto failure; + + err = -EPIPE; + if (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) + goto failure; + + if (sk_wmem_alloc_get(sk) < READ_ONCE(sk->sk_sndbuf)) + break; + + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + err = -EAGAIN; + if (!timeo) + goto failure; + if (signal_pending(current)) + goto interrupted; + timeo = sock_wait_for_wmem(sk, timeo); + } + skb = alloc_skb_with_frags(header_len, data_len, max_page_order, + errcode, sk->sk_allocation); + if (skb) + skb_set_owner_w(skb, sk); + return skb; + +interrupted: + err = sock_intr_errno(timeo); +failure: + *errcode = err; + return NULL; +} +EXPORT_SYMBOL(sock_alloc_send_pskb); + +int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg, + struct sockcm_cookie *sockc) +{ + u32 tsflags; + + switch (cmsg->cmsg_type) { + case SO_MARK: + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && + !ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(u32))) + return -EINVAL; + sockc->mark = *(u32 *)CMSG_DATA(cmsg); + break; + case SO_TIMESTAMPING_OLD: + case SO_TIMESTAMPING_NEW: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(u32))) + return -EINVAL; + + tsflags = *(u32 *)CMSG_DATA(cmsg); + if (tsflags & ~SOF_TIMESTAMPING_TX_RECORD_MASK) + return -EINVAL; + + sockc->tsflags &= ~SOF_TIMESTAMPING_TX_RECORD_MASK; + sockc->tsflags |= tsflags; + break; + case SCM_TXTIME: + if (!sock_flag(sk, SOCK_TXTIME)) + return -EINVAL; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(u64))) + return -EINVAL; + sockc->transmit_time = get_unaligned((u64 *)CMSG_DATA(cmsg)); + break; + /* SCM_RIGHTS and SCM_CREDENTIALS are semantically in SOL_UNIX. 
*/ + case SCM_RIGHTS: + case SCM_CREDENTIALS: + break; + default: + return -EINVAL; + } + return 0; +} +EXPORT_SYMBOL(__sock_cmsg_send); + +int sock_cmsg_send(struct sock *sk, struct msghdr *msg, + struct sockcm_cookie *sockc) +{ + struct cmsghdr *cmsg; + int ret; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + if (cmsg->cmsg_level != SOL_SOCKET) + continue; + ret = __sock_cmsg_send(sk, msg, cmsg, sockc); + if (ret) + return ret; + } + return 0; +} +EXPORT_SYMBOL(sock_cmsg_send); + +static void sk_enter_memory_pressure(struct sock *sk) +{ + if (!sk->sk_prot->enter_memory_pressure) + return; + + sk->sk_prot->enter_memory_pressure(sk); +} + +static void sk_leave_memory_pressure(struct sock *sk) +{ + if (sk->sk_prot->leave_memory_pressure) { + INDIRECT_CALL_INET_1(sk->sk_prot->leave_memory_pressure, + tcp_leave_memory_pressure, sk); + } else { + unsigned long *memory_pressure = sk->sk_prot->memory_pressure; + + if (memory_pressure && READ_ONCE(*memory_pressure)) + WRITE_ONCE(*memory_pressure, 0); + } +} + +DEFINE_STATIC_KEY_FALSE(net_high_order_alloc_disable_key); + +/** + * skb_page_frag_refill - check that a page_frag contains enough room + * @sz: minimum size of the fragment we want to get + * @pfrag: pointer to page_frag + * @gfp: priority for memory allocation + * + * Note: While this allocator tries to use high order pages, there is + * no guarantee that allocations succeed. Therefore, @sz MUST be + * less or equal than PAGE_SIZE. + */ +bool skb_page_frag_refill(unsigned int sz, struct page_frag *pfrag, gfp_t gfp) +{ + if (pfrag->page) { + if (page_ref_count(pfrag->page) == 1) { + pfrag->offset = 0; + return true; + } + if (pfrag->offset + sz <= pfrag->size) + return true; + put_page(pfrag->page); + } + + pfrag->offset = 0; + if (SKB_FRAG_PAGE_ORDER && + !static_branch_unlikely(&net_high_order_alloc_disable_key)) { + /* Avoid direct reclaim but allow kswapd to wake */ + pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) | + __GFP_COMP | __GFP_NOWARN | + __GFP_NORETRY, + SKB_FRAG_PAGE_ORDER); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER; + return true; + } + } + pfrag->page = alloc_page(gfp); + if (likely(pfrag->page)) { + pfrag->size = PAGE_SIZE; + return true; + } + return false; +} +EXPORT_SYMBOL(skb_page_frag_refill); + +bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag) +{ + if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation))) + return true; + + sk_enter_memory_pressure(sk); + sk_stream_moderate_sndbuf(sk); + return false; +} +EXPORT_SYMBOL(sk_page_frag_refill); + +void __lock_sock(struct sock *sk) + __releases(&sk->sk_lock.slock) + __acquires(&sk->sk_lock.slock) +{ + DEFINE_WAIT(wait); + + for (;;) { + prepare_to_wait_exclusive(&sk->sk_lock.wq, &wait, + TASK_UNINTERRUPTIBLE); + spin_unlock_bh(&sk->sk_lock.slock); + schedule(); + spin_lock_bh(&sk->sk_lock.slock); + if (!sock_owned_by_user(sk)) + break; + } + finish_wait(&sk->sk_lock.wq, &wait); +} + +void __release_sock(struct sock *sk) + __releases(&sk->sk_lock.slock) + __acquires(&sk->sk_lock.slock) +{ + struct sk_buff *skb, *next; + + while ((skb = sk->sk_backlog.head) != NULL) { + sk->sk_backlog.head = sk->sk_backlog.tail = NULL; + + spin_unlock_bh(&sk->sk_lock.slock); + + do { + next = skb->next; + prefetch(next); + DEBUG_NET_WARN_ON_ONCE(skb_dst_is_noref(skb)); + skb_mark_not_on_list(skb); + sk_backlog_rcv(sk, skb); + + cond_resched(); + + skb = next; + } while (skb != NULL); + + spin_lock_bh(&sk->sk_lock.slock); + } + + /* + * 
Doing the zeroing here guarantee we can not loop forever + * while a wild producer attempts to flood us. + */ + sk->sk_backlog.len = 0; +} + +void __sk_flush_backlog(struct sock *sk) +{ + spin_lock_bh(&sk->sk_lock.slock); + __release_sock(sk); + spin_unlock_bh(&sk->sk_lock.slock); +} +EXPORT_SYMBOL_GPL(__sk_flush_backlog); + +/** + * sk_wait_data - wait for data to arrive at sk_receive_queue + * @sk: sock to wait on + * @timeo: for how long + * @skb: last skb seen on sk_receive_queue + * + * Now socket state including sk->sk_err is changed only under lock, + * hence we may omit checks after joining wait queue. + * We check receive queue before schedule() only as optimization; + * it is very likely that release_sock() added new data. + */ +int sk_wait_data(struct sock *sk, long *timeo, const struct sk_buff *skb) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int rc; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + rc = sk_wait_event(sk, timeo, skb_peek_tail(&sk->sk_receive_queue) != skb, &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} +EXPORT_SYMBOL(sk_wait_data); + +/** + * __sk_mem_raise_allocated - increase memory_allocated + * @sk: socket + * @size: memory size to allocate + * @amt: pages to allocate + * @kind: allocation type + * + * Similar to __sk_mem_schedule(), but does not update sk_forward_alloc + */ +int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind) +{ + bool memcg_charge = mem_cgroup_sockets_enabled && sk->sk_memcg; + struct proto *prot = sk->sk_prot; + bool charged = true; + long allocated; + + sk_memory_allocated_add(sk, amt); + allocated = sk_memory_allocated(sk); + if (memcg_charge && + !(charged = mem_cgroup_charge_skmem(sk->sk_memcg, amt, + gfp_memcg_charge()))) + goto suppress_allocation; + + /* Under limit. */ + if (allocated <= sk_prot_mem_limits(sk, 0)) { + sk_leave_memory_pressure(sk); + return 1; + } + + /* Under pressure. */ + if (allocated > sk_prot_mem_limits(sk, 1)) + sk_enter_memory_pressure(sk); + + /* Over hard limit. */ + if (allocated > sk_prot_mem_limits(sk, 2)) + goto suppress_allocation; + + /* guarantee minimum buffer size under pressure */ + if (kind == SK_MEM_RECV) { + if (atomic_read(&sk->sk_rmem_alloc) < sk_get_rmem0(sk, prot)) + return 1; + + } else { /* SK_MEM_SEND */ + int wmem0 = sk_get_wmem0(sk, prot); + + if (sk->sk_type == SOCK_STREAM) { + if (sk->sk_wmem_queued < wmem0) + return 1; + } else if (refcount_read(&sk->sk_wmem_alloc) < wmem0) { + return 1; + } + } + + if (sk_has_memory_pressure(sk)) { + u64 alloc; + + if (!sk_under_memory_pressure(sk)) + return 1; + alloc = sk_sockets_allocated_read_positive(sk); + if (sk_prot_mem_limits(sk, 2) > alloc * + sk_mem_pages(sk->sk_wmem_queued + + atomic_read(&sk->sk_rmem_alloc) + + sk->sk_forward_alloc)) + return 1; + } + +suppress_allocation: + + if (kind == SK_MEM_SEND && sk->sk_type == SOCK_STREAM) { + sk_stream_moderate_sndbuf(sk); + + /* Fail only if socket is _under_ its sndbuf. + * In this case we cannot block, so that we have to fail. 
+ */ + if (sk->sk_wmem_queued + size >= sk->sk_sndbuf) { + /* Force charge with __GFP_NOFAIL */ + if (memcg_charge && !charged) { + mem_cgroup_charge_skmem(sk->sk_memcg, amt, + gfp_memcg_charge() | __GFP_NOFAIL); + } + return 1; + } + } + + if (kind == SK_MEM_SEND || (kind == SK_MEM_RECV && charged)) + trace_sock_exceed_buf_limit(sk, prot, allocated, kind); + + sk_memory_allocated_sub(sk, amt); + + if (memcg_charge && charged) + mem_cgroup_uncharge_skmem(sk->sk_memcg, amt); + + return 0; +} + +/** + * __sk_mem_schedule - increase sk_forward_alloc and memory_allocated + * @sk: socket + * @size: memory size to allocate + * @kind: allocation type + * + * If kind is SK_MEM_SEND, it means wmem allocation. Otherwise it means + * rmem allocation. This function assumes that protocols which have + * memory_pressure use sk_wmem_queued as write buffer accounting. + */ +int __sk_mem_schedule(struct sock *sk, int size, int kind) +{ + int ret, amt = sk_mem_pages(size); + + sk_forward_alloc_add(sk, amt << PAGE_SHIFT); + ret = __sk_mem_raise_allocated(sk, size, amt, kind); + if (!ret) + sk_forward_alloc_add(sk, -(amt << PAGE_SHIFT)); + return ret; +} +EXPORT_SYMBOL(__sk_mem_schedule); + +/** + * __sk_mem_reduce_allocated - reclaim memory_allocated + * @sk: socket + * @amount: number of quanta + * + * Similar to __sk_mem_reclaim(), but does not update sk_forward_alloc + */ +void __sk_mem_reduce_allocated(struct sock *sk, int amount) +{ + sk_memory_allocated_sub(sk, amount); + + if (mem_cgroup_sockets_enabled && sk->sk_memcg) + mem_cgroup_uncharge_skmem(sk->sk_memcg, amount); + + if (sk_under_global_memory_pressure(sk) && + (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0))) + sk_leave_memory_pressure(sk); +} + +/** + * __sk_mem_reclaim - reclaim sk_forward_alloc and memory_allocated + * @sk: socket + * @amount: number of bytes (rounded down to a PAGE_SIZE multiple) + */ +void __sk_mem_reclaim(struct sock *sk, int amount) +{ + amount >>= PAGE_SHIFT; + sk_forward_alloc_add(sk, -(amount << PAGE_SHIFT)); + __sk_mem_reduce_allocated(sk, amount); +} +EXPORT_SYMBOL(__sk_mem_reclaim); + +int sk_set_peek_off(struct sock *sk, int val) +{ + WRITE_ONCE(sk->sk_peek_off, val); + return 0; +} +EXPORT_SYMBOL_GPL(sk_set_peek_off); + +/* + * Set of default routines for initialising struct proto_ops when + * the protocol does not support a particular function. In certain + * cases where it makes no sense for a protocol to have a "do nothing" + * function, some default processing is provided. 
+ */ + +int sock_no_bind(struct socket *sock, struct sockaddr *saddr, int len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_bind); + +int sock_no_connect(struct socket *sock, struct sockaddr *saddr, + int len, int flags) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_connect); + +int sock_no_socketpair(struct socket *sock1, struct socket *sock2) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_socketpair); + +int sock_no_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_accept); + +int sock_no_getname(struct socket *sock, struct sockaddr *saddr, + int peer) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_getname); + +int sock_no_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_ioctl); + +int sock_no_listen(struct socket *sock, int backlog) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_listen); + +int sock_no_shutdown(struct socket *sock, int how) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_shutdown); + +int sock_no_sendmsg(struct socket *sock, struct msghdr *m, size_t len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_sendmsg); + +int sock_no_sendmsg_locked(struct sock *sk, struct msghdr *m, size_t len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_sendmsg_locked); + +int sock_no_recvmsg(struct socket *sock, struct msghdr *m, size_t len, + int flags) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL(sock_no_recvmsg); + +int sock_no_mmap(struct file *file, struct socket *sock, struct vm_area_struct *vma) +{ + /* Mirror missing mmap method error code */ + return -ENODEV; +} +EXPORT_SYMBOL(sock_no_mmap); + +/* + * When a file is received (via SCM_RIGHTS, etc), we must bump the + * various sock-based usage counts. 
+ */ +void __receive_sock(struct file *file) +{ + struct socket *sock; + + sock = sock_from_file(file); + if (sock) { + sock_update_netprioidx(&sock->sk->sk_cgrp_data); + sock_update_classid(&sock->sk->sk_cgrp_data); + } +} + +ssize_t sock_no_sendpage(struct socket *sock, struct page *page, int offset, size_t size, int flags) +{ + ssize_t res; + struct msghdr msg = {.msg_flags = flags}; + struct kvec iov; + char *kaddr = kmap(page); + iov.iov_base = kaddr + offset; + iov.iov_len = size; + res = kernel_sendmsg(sock, &msg, &iov, 1, size); + kunmap(page); + return res; +} +EXPORT_SYMBOL(sock_no_sendpage); + +ssize_t sock_no_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + ssize_t res; + struct msghdr msg = {.msg_flags = flags}; + struct kvec iov; + char *kaddr = kmap(page); + + iov.iov_base = kaddr + offset; + iov.iov_len = size; + res = kernel_sendmsg_locked(sk, &msg, &iov, 1, size); + kunmap(page); + return res; +} +EXPORT_SYMBOL(sock_no_sendpage_locked); + +/* + * Default Socket Callbacks + */ + +static void sock_def_wakeup(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_all(&wq->wait); + rcu_read_unlock(); +} + +static void sock_def_error_report(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_poll(&wq->wait, EPOLLERR); + sk_wake_async(sk, SOCK_WAKE_IO, POLL_ERR); + rcu_read_unlock(); +} + +void sock_def_readable(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLIN | EPOLLPRI | + EPOLLRDNORM | EPOLLRDBAND); + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN); + rcu_read_unlock(); +} + +static void sock_def_write_space(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + + /* Do not wake up a writer until he can make "significant" + * progress. --DaveM + */ + if (sock_writeable(sk)) { + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLOUT | + EPOLLWRNORM | EPOLLWRBAND); + + /* Should agree with poll, otherwise some programs break */ + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + } + + rcu_read_unlock(); +} + +/* An optimised version of sock_def_write_space(), should only be called + * for SOCK_RCU_FREE sockets under RCU read section and after putting + * ->sk_wmem_alloc. + */ +static void sock_def_write_space_wfree(struct sock *sk) +{ + /* Do not wake up a writer until he can make "significant" + * progress. 
--DaveM + */ + if (sock_writeable(sk)) { + struct socket_wq *wq = rcu_dereference(sk->sk_wq); + + /* rely on refcount_sub from sock_wfree() */ + smp_mb__after_atomic(); + if (wq && waitqueue_active(&wq->wait)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLOUT | + EPOLLWRNORM | EPOLLWRBAND); + + /* Should agree with poll, otherwise some programs break */ + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + } +} + +static void sock_def_destruct(struct sock *sk) +{ +} + +void sk_send_sigurg(struct sock *sk) +{ + if (sk->sk_socket && sk->sk_socket->file) + if (send_sigurg(&sk->sk_socket->file->f_owner)) + sk_wake_async(sk, SOCK_WAKE_URG, POLL_PRI); +} +EXPORT_SYMBOL(sk_send_sigurg); + +void sk_reset_timer(struct sock *sk, struct timer_list* timer, + unsigned long expires) +{ + if (!mod_timer(timer, expires)) + sock_hold(sk); +} +EXPORT_SYMBOL(sk_reset_timer); + +void sk_stop_timer(struct sock *sk, struct timer_list* timer) +{ + if (del_timer(timer)) + __sock_put(sk); +} +EXPORT_SYMBOL(sk_stop_timer); + +void sk_stop_timer_sync(struct sock *sk, struct timer_list *timer) +{ + if (del_timer_sync(timer)) + __sock_put(sk); +} +EXPORT_SYMBOL(sk_stop_timer_sync); + +void sock_init_data_uid(struct socket *sock, struct sock *sk, kuid_t uid) +{ + sk_init_common(sk); + sk->sk_send_head = NULL; + + timer_setup(&sk->sk_timer, NULL, 0); + + sk->sk_allocation = GFP_KERNEL; + sk->sk_rcvbuf = READ_ONCE(sysctl_rmem_default); + sk->sk_sndbuf = READ_ONCE(sysctl_wmem_default); + sk->sk_state = TCP_CLOSE; + sk_set_socket(sk, sock); + + sock_set_flag(sk, SOCK_ZAPPED); + + if (sock) { + sk->sk_type = sock->type; + RCU_INIT_POINTER(sk->sk_wq, &sock->wq); + sock->sk = sk; + } else { + RCU_INIT_POINTER(sk->sk_wq, NULL); + } + sk->sk_uid = uid; + + rwlock_init(&sk->sk_callback_lock); + if (sk->sk_kern_sock) + lockdep_set_class_and_name( + &sk->sk_callback_lock, + af_kern_callback_keys + sk->sk_family, + af_family_kern_clock_key_strings[sk->sk_family]); + else + lockdep_set_class_and_name( + &sk->sk_callback_lock, + af_callback_keys + sk->sk_family, + af_family_clock_key_strings[sk->sk_family]); + + sk->sk_state_change = sock_def_wakeup; + sk->sk_data_ready = sock_def_readable; + sk->sk_write_space = sock_def_write_space; + sk->sk_error_report = sock_def_error_report; + sk->sk_destruct = sock_def_destruct; + + sk->sk_frag.page = NULL; + sk->sk_frag.offset = 0; + sk->sk_peek_off = -1; + + sk->sk_peer_pid = NULL; + sk->sk_peer_cred = NULL; + spin_lock_init(&sk->sk_peer_lock); + + sk->sk_write_pending = 0; + sk->sk_rcvlowat = 1; + sk->sk_rcvtimeo = MAX_SCHEDULE_TIMEOUT; + sk->sk_sndtimeo = MAX_SCHEDULE_TIMEOUT; + + sk->sk_stamp = SK_DEFAULT_STAMP; +#if BITS_PER_LONG==32 + seqlock_init(&sk->sk_stamp_seq); +#endif + atomic_set(&sk->sk_zckey, 0); + +#ifdef CONFIG_NET_RX_BUSY_POLL + sk->sk_napi_id = 0; + sk->sk_ll_usec = READ_ONCE(sysctl_net_busy_read); +#endif + + sk->sk_max_pacing_rate = ~0UL; + sk->sk_pacing_rate = ~0UL; + WRITE_ONCE(sk->sk_pacing_shift, 10); + sk->sk_incoming_cpu = -1; + + sk_rx_queue_clear(sk); + /* + * Before updating sk_refcnt, we must commit prior changes to memory + * (Documentation/RCU/rculist_nulls.rst for details) + */ + smp_wmb(); + refcount_set(&sk->sk_refcnt, 1); + atomic_set(&sk->sk_drops, 0); +} +EXPORT_SYMBOL(sock_init_data_uid); + +void sock_init_data(struct socket *sock, struct sock *sk) +{ + kuid_t uid = sock ? 
+ SOCK_INODE(sock)->i_uid : + make_kuid(sock_net(sk)->user_ns, 0); + + sock_init_data_uid(sock, sk, uid); +} +EXPORT_SYMBOL(sock_init_data); + +void lock_sock_nested(struct sock *sk, int subclass) +{ + /* The sk_lock has mutex_lock() semantics here. */ + mutex_acquire(&sk->sk_lock.dep_map, subclass, 0, _RET_IP_); + + might_sleep(); + spin_lock_bh(&sk->sk_lock.slock); + if (sock_owned_by_user_nocheck(sk)) + __lock_sock(sk); + sk->sk_lock.owned = 1; + spin_unlock_bh(&sk->sk_lock.slock); +} +EXPORT_SYMBOL(lock_sock_nested); + +void release_sock(struct sock *sk) +{ + spin_lock_bh(&sk->sk_lock.slock); + if (sk->sk_backlog.tail) + __release_sock(sk); + + /* Warning : release_cb() might need to release sk ownership, + * ie call sock_release_ownership(sk) before us. + */ + if (sk->sk_prot->release_cb) + sk->sk_prot->release_cb(sk); + + sock_release_ownership(sk); + if (waitqueue_active(&sk->sk_lock.wq)) + wake_up(&sk->sk_lock.wq); + spin_unlock_bh(&sk->sk_lock.slock); +} +EXPORT_SYMBOL(release_sock); + +bool __lock_sock_fast(struct sock *sk) __acquires(&sk->sk_lock.slock) +{ + might_sleep(); + spin_lock_bh(&sk->sk_lock.slock); + + if (!sock_owned_by_user_nocheck(sk)) { + /* + * Fast path return with bottom halves disabled and + * sock::sk_lock.slock held. + * + * The 'mutex' is not contended and holding + * sock::sk_lock.slock prevents all other lockers to + * proceed so the corresponding unlock_sock_fast() can + * avoid the slow path of release_sock() completely and + * just release slock. + * + * From a semantical POV this is equivalent to 'acquiring' + * the 'mutex', hence the corresponding lockdep + * mutex_release() has to happen in the fast path of + * unlock_sock_fast(). + */ + return false; + } + + __lock_sock(sk); + sk->sk_lock.owned = 1; + __acquire(&sk->sk_lock.slock); + spin_unlock_bh(&sk->sk_lock.slock); + return true; +} +EXPORT_SYMBOL(__lock_sock_fast); + +int sock_gettstamp(struct socket *sock, void __user *userstamp, + bool timeval, bool time32) +{ + struct sock *sk = sock->sk; + struct timespec64 ts; + + sock_enable_timestamp(sk, SOCK_TIMESTAMP); + ts = ktime_to_timespec64(sock_read_timestamp(sk)); + if (ts.tv_sec == -1) + return -ENOENT; + if (ts.tv_sec == 0) { + ktime_t kt = ktime_get_real(); + sock_write_timestamp(sk, kt); + ts = ktime_to_timespec64(kt); + } + + if (timeval) + ts.tv_nsec /= 1000; + +#ifdef CONFIG_COMPAT_32BIT_TIME + if (time32) + return put_old_timespec32(&ts, userstamp); +#endif +#ifdef CONFIG_SPARC64 + /* beware of padding in sparc64 timeval */ + if (timeval && !in_compat_syscall()) { + struct __kernel_old_timeval __user tv = { + .tv_sec = ts.tv_sec, + .tv_usec = ts.tv_nsec, + }; + if (copy_to_user(userstamp, &tv, sizeof(tv))) + return -EFAULT; + return 0; + } +#endif + return put_timespec64(&ts, userstamp); +} +EXPORT_SYMBOL(sock_gettstamp); + +void sock_enable_timestamp(struct sock *sk, enum sock_flags flag) +{ + if (!sock_flag(sk, flag)) { + unsigned long previous_flags = sk->sk_flags; + + sock_set_flag(sk, flag); + /* + * we just set one of the two flags which require net + * time stamping, but time stamping might have been on + * already because of the other one + */ + if (sock_needs_netstamp(sk) && + !(previous_flags & SK_FLAGS_TIMESTAMP)) + net_enable_timestamp(); + } +} + +int sock_recv_errqueue(struct sock *sk, struct msghdr *msg, int len, + int level, int type) +{ + struct sock_exterr_skb *serr; + struct sk_buff *skb; + int copied, err; + + err = -EAGAIN; + skb = sock_dequeue_err_skb(sk); + if (skb == NULL) + goto out; + + copied = skb->len; + if 
(copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto out_free_skb; + + sock_recv_timestamp(msg, sk, skb); + + serr = SKB_EXT_ERR(skb); + put_cmsg(msg, level, type, sizeof(serr->ee), &serr->ee); + + msg->msg_flags |= MSG_ERRQUEUE; + err = copied; + +out_free_skb: + kfree_skb(skb); +out: + return err; +} +EXPORT_SYMBOL(sock_recv_errqueue); + +/* + * Get a socket option on an socket. + * + * FIX: POSIX 1003.1g is very ambiguous here. It states that + * asynchronous errors should be reported by getsockopt. We assume + * this means if you specify SO_ERROR (otherwise whats the point of it). + */ +int sock_common_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + return READ_ONCE(sk->sk_prot)->getsockopt(sk, level, optname, optval, optlen); +} +EXPORT_SYMBOL(sock_common_getsockopt); + +int sock_common_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + int addr_len = 0; + int err; + + err = sk->sk_prot->recvmsg(sk, msg, size, flags, &addr_len); + if (err >= 0) + msg->msg_namelen = addr_len; + return err; +} +EXPORT_SYMBOL(sock_common_recvmsg); + +/* + * Set socket options on an inet socket. + */ +int sock_common_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + return READ_ONCE(sk->sk_prot)->setsockopt(sk, level, optname, optval, optlen); +} +EXPORT_SYMBOL(sock_common_setsockopt); + +void sk_common_release(struct sock *sk) +{ + if (sk->sk_prot->destroy) + sk->sk_prot->destroy(sk); + + /* + * Observation: when sk_common_release is called, processes have + * no access to socket. But net still has. + * Step one, detach it from networking: + * + * A. Remove from hash tables. + */ + + sk->sk_prot->unhash(sk); + + /* + * In this point socket cannot receive new packets, but it is possible + * that some packets are in flight because some CPU runs receiver and + * did hash table lookup before we unhashed socket. They will achieve + * receive queue and will be purged by socket destructor. + * + * Also we still have packets pending on receive queue and probably, + * our own packets waiting in device queues. sock_destroy will drain + * receive queue, but transmitted packets will delay socket destruction + * until the last reference will be released. 
+ */ + + sock_orphan(sk); + + xfrm_sk_free_policy(sk); + + sk_refcnt_debug_release(sk); + + sock_put(sk); +} +EXPORT_SYMBOL(sk_common_release); + +void sk_get_meminfo(const struct sock *sk, u32 *mem) +{ + memset(mem, 0, sizeof(*mem) * SK_MEMINFO_VARS); + + mem[SK_MEMINFO_RMEM_ALLOC] = sk_rmem_alloc_get(sk); + mem[SK_MEMINFO_RCVBUF] = READ_ONCE(sk->sk_rcvbuf); + mem[SK_MEMINFO_WMEM_ALLOC] = sk_wmem_alloc_get(sk); + mem[SK_MEMINFO_SNDBUF] = READ_ONCE(sk->sk_sndbuf); + mem[SK_MEMINFO_FWD_ALLOC] = sk_forward_alloc_get(sk); + mem[SK_MEMINFO_WMEM_QUEUED] = READ_ONCE(sk->sk_wmem_queued); + mem[SK_MEMINFO_OPTMEM] = atomic_read(&sk->sk_omem_alloc); + mem[SK_MEMINFO_BACKLOG] = READ_ONCE(sk->sk_backlog.len); + mem[SK_MEMINFO_DROPS] = atomic_read(&sk->sk_drops); +} + +#ifdef CONFIG_PROC_FS +static DECLARE_BITMAP(proto_inuse_idx, PROTO_INUSE_NR); + +int sock_prot_inuse_get(struct net *net, struct proto *prot) +{ + int cpu, idx = prot->inuse_idx; + int res = 0; + + for_each_possible_cpu(cpu) + res += per_cpu_ptr(net->core.prot_inuse, cpu)->val[idx]; + + return res >= 0 ? res : 0; +} +EXPORT_SYMBOL_GPL(sock_prot_inuse_get); + +int sock_inuse_get(struct net *net) +{ + int cpu, res = 0; + + for_each_possible_cpu(cpu) + res += per_cpu_ptr(net->core.prot_inuse, cpu)->all; + + return res; +} + +EXPORT_SYMBOL_GPL(sock_inuse_get); + +static int __net_init sock_inuse_init_net(struct net *net) +{ + net->core.prot_inuse = alloc_percpu(struct prot_inuse); + if (net->core.prot_inuse == NULL) + return -ENOMEM; + return 0; +} + +static void __net_exit sock_inuse_exit_net(struct net *net) +{ + free_percpu(net->core.prot_inuse); +} + +static struct pernet_operations net_inuse_ops = { + .init = sock_inuse_init_net, + .exit = sock_inuse_exit_net, +}; + +static __init int net_inuse_init(void) +{ + if (register_pernet_subsys(&net_inuse_ops)) + panic("Cannot initialize net inuse counters"); + + return 0; +} + +core_initcall(net_inuse_init); + +static int assign_proto_idx(struct proto *prot) +{ + prot->inuse_idx = find_first_zero_bit(proto_inuse_idx, PROTO_INUSE_NR); + + if (unlikely(prot->inuse_idx == PROTO_INUSE_NR - 1)) { + pr_err("PROTO_INUSE_NR exhausted\n"); + return -ENOSPC; + } + + set_bit(prot->inuse_idx, proto_inuse_idx); + return 0; +} + +static void release_proto_idx(struct proto *prot) +{ + if (prot->inuse_idx != PROTO_INUSE_NR - 1) + clear_bit(prot->inuse_idx, proto_inuse_idx); +} +#else +static inline int assign_proto_idx(struct proto *prot) +{ + return 0; +} + +static inline void release_proto_idx(struct proto *prot) +{ +} + +#endif + +static void tw_prot_cleanup(struct timewait_sock_ops *twsk_prot) +{ + if (!twsk_prot) + return; + kfree(twsk_prot->twsk_slab_name); + twsk_prot->twsk_slab_name = NULL; + kmem_cache_destroy(twsk_prot->twsk_slab); + twsk_prot->twsk_slab = NULL; +} + +static int tw_prot_init(const struct proto *prot) +{ + struct timewait_sock_ops *twsk_prot = prot->twsk_prot; + + if (!twsk_prot) + return 0; + + twsk_prot->twsk_slab_name = kasprintf(GFP_KERNEL, "tw_sock_%s", + prot->name); + if (!twsk_prot->twsk_slab_name) + return -ENOMEM; + + twsk_prot->twsk_slab = + kmem_cache_create(twsk_prot->twsk_slab_name, + twsk_prot->twsk_obj_size, 0, + SLAB_ACCOUNT | prot->slab_flags, + NULL); + if (!twsk_prot->twsk_slab) { + pr_crit("%s: Can't create timewait sock SLAB cache!\n", + prot->name); + return -ENOMEM; + } + + return 0; +} + +static void req_prot_cleanup(struct request_sock_ops *rsk_prot) +{ + if (!rsk_prot) + return; + kfree(rsk_prot->slab_name); + rsk_prot->slab_name = NULL; + 
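+	/* kmem_cache_destroy() is a no-op on a NULL cache pointer, so this is
+	 * safe even if the slab was never created.
+	 */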
kmem_cache_destroy(rsk_prot->slab); + rsk_prot->slab = NULL; +} + +static int req_prot_init(const struct proto *prot) +{ + struct request_sock_ops *rsk_prot = prot->rsk_prot; + + if (!rsk_prot) + return 0; + + rsk_prot->slab_name = kasprintf(GFP_KERNEL, "request_sock_%s", + prot->name); + if (!rsk_prot->slab_name) + return -ENOMEM; + + rsk_prot->slab = kmem_cache_create(rsk_prot->slab_name, + rsk_prot->obj_size, 0, + SLAB_ACCOUNT | prot->slab_flags, + NULL); + + if (!rsk_prot->slab) { + pr_crit("%s: Can't create request sock SLAB cache!\n", + prot->name); + return -ENOMEM; + } + return 0; +} + +int proto_register(struct proto *prot, int alloc_slab) +{ + int ret = -ENOBUFS; + + if (prot->memory_allocated && !prot->sysctl_mem) { + pr_err("%s: missing sysctl_mem\n", prot->name); + return -EINVAL; + } + if (prot->memory_allocated && !prot->per_cpu_fw_alloc) { + pr_err("%s: missing per_cpu_fw_alloc\n", prot->name); + return -EINVAL; + } + if (alloc_slab) { + prot->slab = kmem_cache_create_usercopy(prot->name, + prot->obj_size, 0, + SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT | + prot->slab_flags, + prot->useroffset, prot->usersize, + NULL); + + if (prot->slab == NULL) { + pr_crit("%s: Can't create sock SLAB cache!\n", + prot->name); + goto out; + } + + if (req_prot_init(prot)) + goto out_free_request_sock_slab; + + if (tw_prot_init(prot)) + goto out_free_timewait_sock_slab; + } + + mutex_lock(&proto_list_mutex); + ret = assign_proto_idx(prot); + if (ret) { + mutex_unlock(&proto_list_mutex); + goto out_free_timewait_sock_slab; + } + list_add(&prot->node, &proto_list); + mutex_unlock(&proto_list_mutex); + return ret; + +out_free_timewait_sock_slab: + if (alloc_slab) + tw_prot_cleanup(prot->twsk_prot); +out_free_request_sock_slab: + if (alloc_slab) { + req_prot_cleanup(prot->rsk_prot); + + kmem_cache_destroy(prot->slab); + prot->slab = NULL; + } +out: + return ret; +} +EXPORT_SYMBOL(proto_register); + +void proto_unregister(struct proto *prot) +{ + mutex_lock(&proto_list_mutex); + release_proto_idx(prot); + list_del(&prot->node); + mutex_unlock(&proto_list_mutex); + + kmem_cache_destroy(prot->slab); + prot->slab = NULL; + + req_prot_cleanup(prot->rsk_prot); + tw_prot_cleanup(prot->twsk_prot); +} +EXPORT_SYMBOL(proto_unregister); + +int sock_load_diag_module(int family, int protocol) +{ + if (!protocol) { + if (!sock_is_registered(family)) + return -ENOENT; + + return request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK, + NETLINK_SOCK_DIAG, family); + } + +#ifdef CONFIG_INET + if (family == AF_INET && + protocol != IPPROTO_RAW && + protocol < MAX_INET_PROTOS && + !rcu_access_pointer(inet_protos[protocol])) + return -ENOENT; +#endif + + return request_module("net-pf-%d-proto-%d-type-%d-%d", PF_NETLINK, + NETLINK_SOCK_DIAG, family, protocol); +} +EXPORT_SYMBOL(sock_load_diag_module); + +#ifdef CONFIG_PROC_FS +static void *proto_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(proto_list_mutex) +{ + mutex_lock(&proto_list_mutex); + return seq_list_start_head(&proto_list, *pos); +} + +static void *proto_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_list_next(v, &proto_list, pos); +} + +static void proto_seq_stop(struct seq_file *seq, void *v) + __releases(proto_list_mutex) +{ + mutex_unlock(&proto_list_mutex); +} + +static char proto_method_implemented(const void *method) +{ + return method == NULL ? 'n' : 'y'; +} +static long sock_prot_memory_allocated(struct proto *proto) +{ + return proto->memory_allocated != NULL ? 
proto_memory_allocated(proto) : -1L; +} + +static const char *sock_prot_memory_pressure(struct proto *proto) +{ + return proto->memory_pressure != NULL ? + proto_memory_pressure(proto) ? "yes" : "no" : "NI"; +} + +static void proto_seq_printf(struct seq_file *seq, struct proto *proto) +{ + + seq_printf(seq, "%-9s %4u %6d %6ld %-3s %6u %-3s %-10s " + "%2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c %2c\n", + proto->name, + proto->obj_size, + sock_prot_inuse_get(seq_file_net(seq), proto), + sock_prot_memory_allocated(proto), + sock_prot_memory_pressure(proto), + proto->max_header, + proto->slab == NULL ? "no" : "yes", + module_name(proto->owner), + proto_method_implemented(proto->close), + proto_method_implemented(proto->connect), + proto_method_implemented(proto->disconnect), + proto_method_implemented(proto->accept), + proto_method_implemented(proto->ioctl), + proto_method_implemented(proto->init), + proto_method_implemented(proto->destroy), + proto_method_implemented(proto->shutdown), + proto_method_implemented(proto->setsockopt), + proto_method_implemented(proto->getsockopt), + proto_method_implemented(proto->sendmsg), + proto_method_implemented(proto->recvmsg), + proto_method_implemented(proto->sendpage), + proto_method_implemented(proto->bind), + proto_method_implemented(proto->backlog_rcv), + proto_method_implemented(proto->hash), + proto_method_implemented(proto->unhash), + proto_method_implemented(proto->get_port), + proto_method_implemented(proto->enter_memory_pressure)); +} + +static int proto_seq_show(struct seq_file *seq, void *v) +{ + if (v == &proto_list) + seq_printf(seq, "%-9s %-4s %-8s %-6s %-5s %-7s %-4s %-10s %s", + "protocol", + "size", + "sockets", + "memory", + "press", + "maxhdr", + "slab", + "module", + "cl co di ac io in de sh ss gs se re sp bi br ha uh gp em\n"); + else + proto_seq_printf(seq, list_entry(v, struct proto, node)); + return 0; +} + +static const struct seq_operations proto_seq_ops = { + .start = proto_seq_start, + .next = proto_seq_next, + .stop = proto_seq_stop, + .show = proto_seq_show, +}; + +static __net_init int proto_init_net(struct net *net) +{ + if (!proc_create_net("protocols", 0444, net->proc_net, &proto_seq_ops, + sizeof(struct seq_net_private))) + return -ENOMEM; + + return 0; +} + +static __net_exit void proto_exit_net(struct net *net) +{ + remove_proc_entry("protocols", net->proc_net); +} + + +static __net_initdata struct pernet_operations proto_net_ops = { + .init = proto_init_net, + .exit = proto_exit_net, +}; + +static int __init proto_init(void) +{ + return register_pernet_subsys(&proto_net_ops); +} + +subsys_initcall(proto_init); + +#endif /* PROC_FS */ + +#ifdef CONFIG_NET_RX_BUSY_POLL +bool sk_busy_loop_end(void *p, unsigned long start_time) +{ + struct sock *sk = p; + + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + return true; + + if (sk_is_udp(sk) && + !skb_queue_empty_lockless(&udp_sk(sk)->reader_queue)) + return true; + + return sk_busy_loop_timeout(sk, start_time); +} +EXPORT_SYMBOL(sk_busy_loop_end); +#endif /* CONFIG_NET_RX_BUSY_POLL */ + +int sock_bind_add(struct sock *sk, struct sockaddr *addr, int addr_len) +{ + if (!sk->sk_prot->bind_add) + return -EOPNOTSUPP; + return sk->sk_prot->bind_add(sk, addr, addr_len); +} +EXPORT_SYMBOL(sock_bind_add); diff --git a/net/core/sock_destructor.h b/net/core/sock_destructor.h new file mode 100644 index 000000000..2f396e6bf --- /dev/null +++ b/net/core/sock_destructor.h @@ -0,0 +1,12 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +#ifndef 
_NET_CORE_SOCK_DESTRUCTOR_H +#define _NET_CORE_SOCK_DESTRUCTOR_H +#include + +static inline bool is_skb_wmem(const struct sk_buff *skb) +{ + return skb->destructor == sock_wfree || + skb->destructor == __sock_wfree || + (IS_ENABLED(CONFIG_INET) && skb->destructor == tcp_wfree); +} +#endif diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c new file mode 100644 index 000000000..f7cf74cdd --- /dev/null +++ b/net/core/sock_diag.c @@ -0,0 +1,340 @@ +/* License: GPL */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const struct sock_diag_handler *sock_diag_handlers[AF_MAX]; +static int (*inet_rcv_compat)(struct sk_buff *skb, struct nlmsghdr *nlh); +static DEFINE_MUTEX(sock_diag_table_mutex); +static struct workqueue_struct *broadcast_wq; + +DEFINE_COOKIE(sock_cookie); + +u64 __sock_gen_cookie(struct sock *sk) +{ + while (1) { + u64 res = atomic64_read(&sk->sk_cookie); + + if (res) + return res; + res = gen_cookie_next(&sock_cookie); + atomic64_cmpxchg(&sk->sk_cookie, 0, res); + } +} + +int sock_diag_check_cookie(struct sock *sk, const __u32 *cookie) +{ + u64 res; + + if (cookie[0] == INET_DIAG_NOCOOKIE && cookie[1] == INET_DIAG_NOCOOKIE) + return 0; + + res = sock_gen_cookie(sk); + if ((u32)res != cookie[0] || (u32)(res >> 32) != cookie[1]) + return -ESTALE; + + return 0; +} +EXPORT_SYMBOL_GPL(sock_diag_check_cookie); + +void sock_diag_save_cookie(struct sock *sk, __u32 *cookie) +{ + u64 res = sock_gen_cookie(sk); + + cookie[0] = (u32)res; + cookie[1] = (u32)(res >> 32); +} +EXPORT_SYMBOL_GPL(sock_diag_save_cookie); + +int sock_diag_put_meminfo(struct sock *sk, struct sk_buff *skb, int attrtype) +{ + u32 mem[SK_MEMINFO_VARS]; + + sk_get_meminfo(sk, mem); + + return nla_put(skb, attrtype, sizeof(mem), &mem); +} +EXPORT_SYMBOL_GPL(sock_diag_put_meminfo); + +int sock_diag_put_filterinfo(bool may_report_filterinfo, struct sock *sk, + struct sk_buff *skb, int attrtype) +{ + struct sock_fprog_kern *fprog; + struct sk_filter *filter; + struct nlattr *attr; + unsigned int flen; + int err = 0; + + if (!may_report_filterinfo) { + nla_reserve(skb, attrtype, 0); + return 0; + } + + rcu_read_lock(); + filter = rcu_dereference(sk->sk_filter); + if (!filter) + goto out; + + fprog = filter->prog->orig_prog; + if (!fprog) + goto out; + + flen = bpf_classic_proglen(fprog); + + attr = nla_reserve(skb, attrtype, flen); + if (attr == NULL) { + err = -EMSGSIZE; + goto out; + } + + memcpy(nla_data(attr), fprog->filter, flen); +out: + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(sock_diag_put_filterinfo); + +struct broadcast_sk { + struct sock *sk; + struct work_struct work; +}; + +static size_t sock_diag_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct inet_diag_msg) + + nla_total_size(sizeof(u8)) /* INET_DIAG_PROTOCOL */ + + nla_total_size_64bit(sizeof(struct tcp_info))); /* INET_DIAG_INFO */ +} + +static void sock_diag_broadcast_destroy_work(struct work_struct *work) +{ + struct broadcast_sk *bsk = + container_of(work, struct broadcast_sk, work); + struct sock *sk = bsk->sk; + const struct sock_diag_handler *hndl; + struct sk_buff *skb; + const enum sknetlink_groups group = sock_diag_destroy_group(sk); + int err = -1; + + WARN_ON(group == SKNLGRP_NONE); + + skb = nlmsg_new(sock_diag_nlmsg_size(), GFP_KERNEL); + if (!skb) + goto out; + + mutex_lock(&sock_diag_table_mutex); + hndl = sock_diag_handlers[sk->sk_family]; + if (hndl && hndl->get_info) + err = hndl->get_info(skb, sk); + 
mutex_unlock(&sock_diag_table_mutex); + + if (!err) + nlmsg_multicast(sock_net(sk)->diag_nlsk, skb, 0, group, + GFP_KERNEL); + else + kfree_skb(skb); +out: + sk_destruct(sk); + kfree(bsk); +} + +void sock_diag_broadcast_destroy(struct sock *sk) +{ + /* Note, this function is often called from an interrupt context. */ + struct broadcast_sk *bsk = + kmalloc(sizeof(struct broadcast_sk), GFP_ATOMIC); + if (!bsk) + return sk_destruct(sk); + bsk->sk = sk; + INIT_WORK(&bsk->work, sock_diag_broadcast_destroy_work); + queue_work(broadcast_wq, &bsk->work); +} + +void sock_diag_register_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh)) +{ + mutex_lock(&sock_diag_table_mutex); + inet_rcv_compat = fn; + mutex_unlock(&sock_diag_table_mutex); +} +EXPORT_SYMBOL_GPL(sock_diag_register_inet_compat); + +void sock_diag_unregister_inet_compat(int (*fn)(struct sk_buff *skb, struct nlmsghdr *nlh)) +{ + mutex_lock(&sock_diag_table_mutex); + inet_rcv_compat = NULL; + mutex_unlock(&sock_diag_table_mutex); +} +EXPORT_SYMBOL_GPL(sock_diag_unregister_inet_compat); + +int sock_diag_register(const struct sock_diag_handler *hndl) +{ + int err = 0; + + if (hndl->family >= AF_MAX) + return -EINVAL; + + mutex_lock(&sock_diag_table_mutex); + if (sock_diag_handlers[hndl->family]) + err = -EBUSY; + else + sock_diag_handlers[hndl->family] = hndl; + mutex_unlock(&sock_diag_table_mutex); + + return err; +} +EXPORT_SYMBOL_GPL(sock_diag_register); + +void sock_diag_unregister(const struct sock_diag_handler *hnld) +{ + int family = hnld->family; + + if (family >= AF_MAX) + return; + + mutex_lock(&sock_diag_table_mutex); + BUG_ON(sock_diag_handlers[family] != hnld); + sock_diag_handlers[family] = NULL; + mutex_unlock(&sock_diag_table_mutex); +} +EXPORT_SYMBOL_GPL(sock_diag_unregister); + +static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh) +{ + int err; + struct sock_diag_req *req = nlmsg_data(nlh); + const struct sock_diag_handler *hndl; + + if (nlmsg_len(nlh) < sizeof(*req)) + return -EINVAL; + + if (req->sdiag_family >= AF_MAX) + return -EINVAL; + req->sdiag_family = array_index_nospec(req->sdiag_family, AF_MAX); + + if (sock_diag_handlers[req->sdiag_family] == NULL) + sock_load_diag_module(req->sdiag_family, 0); + + mutex_lock(&sock_diag_table_mutex); + hndl = sock_diag_handlers[req->sdiag_family]; + if (hndl == NULL) + err = -ENOENT; + else if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY) + err = hndl->dump(skb, nlh); + else if (nlh->nlmsg_type == SOCK_DESTROY && hndl->destroy) + err = hndl->destroy(skb, nlh); + else + err = -EOPNOTSUPP; + mutex_unlock(&sock_diag_table_mutex); + + return err; +} + +static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + int ret; + + switch (nlh->nlmsg_type) { + case TCPDIAG_GETSOCK: + case DCCPDIAG_GETSOCK: + if (inet_rcv_compat == NULL) + sock_load_diag_module(AF_INET, 0); + + mutex_lock(&sock_diag_table_mutex); + if (inet_rcv_compat != NULL) + ret = inet_rcv_compat(skb, nlh); + else + ret = -EOPNOTSUPP; + mutex_unlock(&sock_diag_table_mutex); + + return ret; + case SOCK_DIAG_BY_FAMILY: + case SOCK_DESTROY: + return __sock_diag_cmd(skb, nlh); + default: + return -EINVAL; + } +} + +static DEFINE_MUTEX(sock_diag_mutex); + +static void sock_diag_rcv(struct sk_buff *skb) +{ + mutex_lock(&sock_diag_mutex); + netlink_rcv_skb(skb, &sock_diag_rcv_msg); + mutex_unlock(&sock_diag_mutex); +} + +static int sock_diag_bind(struct net *net, int group) +{ + switch (group) { + case SKNLGRP_INET_TCP_DESTROY: + case 
SKNLGRP_INET_UDP_DESTROY: + if (!sock_diag_handlers[AF_INET]) + sock_load_diag_module(AF_INET, 0); + break; + case SKNLGRP_INET6_TCP_DESTROY: + case SKNLGRP_INET6_UDP_DESTROY: + if (!sock_diag_handlers[AF_INET6]) + sock_load_diag_module(AF_INET6, 0); + break; + } + return 0; +} + +int sock_diag_destroy(struct sock *sk, int err) +{ + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!sk->sk_prot->diag_destroy) + return -EOPNOTSUPP; + + return sk->sk_prot->diag_destroy(sk, err); +} +EXPORT_SYMBOL_GPL(sock_diag_destroy); + +static int __net_init diag_net_init(struct net *net) +{ + struct netlink_kernel_cfg cfg = { + .groups = SKNLGRP_MAX, + .input = sock_diag_rcv, + .bind = sock_diag_bind, + .flags = NL_CFG_F_NONROOT_RECV, + }; + + net->diag_nlsk = netlink_kernel_create(net, NETLINK_SOCK_DIAG, &cfg); + return net->diag_nlsk == NULL ? -ENOMEM : 0; +} + +static void __net_exit diag_net_exit(struct net *net) +{ + netlink_kernel_release(net->diag_nlsk); + net->diag_nlsk = NULL; +} + +static struct pernet_operations diag_net_ops = { + .init = diag_net_init, + .exit = diag_net_exit, +}; + +static int __init sock_diag_init(void) +{ + broadcast_wq = alloc_workqueue("sock_diag_events", 0, 0); + BUG_ON(!broadcast_wq); + return register_pernet_subsys(&diag_net_ops); +} +device_initcall(sock_diag_init); diff --git a/net/core/sock_map.c b/net/core/sock_map.c new file mode 100644 index 000000000..91140bc05 --- /dev/null +++ b/net/core/sock_map.c @@ -0,0 +1,1701 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct bpf_stab { + struct bpf_map map; + struct sock **sks; + struct sk_psock_progs progs; + raw_spinlock_t lock; +}; + +#define SOCK_CREATE_FLAG_MASK \ + (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY) + +static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, + struct bpf_prog *old, u32 which); +static struct sk_psock_progs *sock_map_progs(struct bpf_map *map); + +static struct bpf_map *sock_map_alloc(union bpf_attr *attr) +{ + struct bpf_stab *stab; + + if (!capable(CAP_NET_ADMIN)) + return ERR_PTR(-EPERM); + if (attr->max_entries == 0 || + attr->key_size != 4 || + (attr->value_size != sizeof(u32) && + attr->value_size != sizeof(u64)) || + attr->map_flags & ~SOCK_CREATE_FLAG_MASK) + return ERR_PTR(-EINVAL); + + stab = bpf_map_area_alloc(sizeof(*stab), NUMA_NO_NODE); + if (!stab) + return ERR_PTR(-ENOMEM); + + bpf_map_init_from_attr(&stab->map, attr); + raw_spin_lock_init(&stab->lock); + + stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries * + sizeof(struct sock *), + stab->map.numa_node); + if (!stab->sks) { + bpf_map_area_free(stab); + return ERR_PTR(-ENOMEM); + } + + return &stab->map; +} + +int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog) +{ + u32 ufd = attr->target_fd; + struct bpf_map *map; + struct fd f; + int ret; + + if (attr->attach_flags || attr->replace_bpf_fd) + return -EINVAL; + + f = fdget(ufd); + map = __bpf_map_get(f); + if (IS_ERR(map)) + return PTR_ERR(map); + ret = sock_map_prog_update(map, prog, NULL, attr->attach_type); + fdput(f); + return ret; +} + +int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype) +{ + u32 ufd = attr->target_fd; + struct bpf_prog *prog; + struct bpf_map *map; + struct fd f; + int ret; + + if (attr->attach_flags || attr->replace_bpf_fd) + return -EINVAL; + + f = 
fdget(ufd); + map = __bpf_map_get(f); + if (IS_ERR(map)) + return PTR_ERR(map); + + prog = bpf_prog_get(attr->attach_bpf_fd); + if (IS_ERR(prog)) { + ret = PTR_ERR(prog); + goto put_map; + } + + if (prog->type != ptype) { + ret = -EINVAL; + goto put_prog; + } + + ret = sock_map_prog_update(map, NULL, prog, attr->attach_type); +put_prog: + bpf_prog_put(prog); +put_map: + fdput(f); + return ret; +} + +static void sock_map_sk_acquire(struct sock *sk) + __acquires(&sk->sk_lock.slock) +{ + lock_sock(sk); + rcu_read_lock(); +} + +static void sock_map_sk_release(struct sock *sk) + __releases(&sk->sk_lock.slock) +{ + rcu_read_unlock(); + release_sock(sk); +} + +static void sock_map_add_link(struct sk_psock *psock, + struct sk_psock_link *link, + struct bpf_map *map, void *link_raw) +{ + link->link_raw = link_raw; + link->map = map; + spin_lock_bh(&psock->link_lock); + list_add_tail(&link->list, &psock->link); + spin_unlock_bh(&psock->link_lock); +} + +static void sock_map_del_link(struct sock *sk, + struct sk_psock *psock, void *link_raw) +{ + bool strp_stop = false, verdict_stop = false; + struct sk_psock_link *link, *tmp; + + spin_lock_bh(&psock->link_lock); + list_for_each_entry_safe(link, tmp, &psock->link, list) { + if (link->link_raw == link_raw) { + struct bpf_map *map = link->map; + struct sk_psock_progs *progs = sock_map_progs(map); + + if (psock->saved_data_ready && progs->stream_parser) + strp_stop = true; + if (psock->saved_data_ready && progs->stream_verdict) + verdict_stop = true; + if (psock->saved_data_ready && progs->skb_verdict) + verdict_stop = true; + list_del(&link->list); + sk_psock_free_link(link); + } + } + spin_unlock_bh(&psock->link_lock); + if (strp_stop || verdict_stop) { + write_lock_bh(&sk->sk_callback_lock); + if (strp_stop) + sk_psock_stop_strp(sk, psock); + if (verdict_stop) + sk_psock_stop_verdict(sk, psock); + + if (psock->psock_update_sk_prot) + psock->psock_update_sk_prot(sk, psock, false); + write_unlock_bh(&sk->sk_callback_lock); + } +} + +static void sock_map_unref(struct sock *sk, void *link_raw) +{ + struct sk_psock *psock = sk_psock(sk); + + if (likely(psock)) { + sock_map_del_link(sk, psock, link_raw); + sk_psock_put(sk, psock); + } +} + +static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock) +{ + if (!sk->sk_prot->psock_update_sk_prot) + return -EINVAL; + psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot; + return sk->sk_prot->psock_update_sk_prot(sk, psock, false); +} + +static struct sk_psock *sock_map_psock_get_checked(struct sock *sk) +{ + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) { + if (sk->sk_prot->close != sock_map_close) { + psock = ERR_PTR(-EBUSY); + goto out; + } + + if (!refcount_inc_not_zero(&psock->refcnt)) + psock = ERR_PTR(-EBUSY); + } +out: + rcu_read_unlock(); + return psock; +} + +static int sock_map_link(struct bpf_map *map, struct sock *sk) +{ + struct sk_psock_progs *progs = sock_map_progs(map); + struct bpf_prog *stream_verdict = NULL; + struct bpf_prog *stream_parser = NULL; + struct bpf_prog *skb_verdict = NULL; + struct bpf_prog *msg_parser = NULL; + struct sk_psock *psock; + int ret; + + stream_verdict = READ_ONCE(progs->stream_verdict); + if (stream_verdict) { + stream_verdict = bpf_prog_inc_not_zero(stream_verdict); + if (IS_ERR(stream_verdict)) + return PTR_ERR(stream_verdict); + } + + stream_parser = READ_ONCE(progs->stream_parser); + if (stream_parser) { + stream_parser = bpf_prog_inc_not_zero(stream_parser); + if (IS_ERR(stream_parser)) { + ret = 
PTR_ERR(stream_parser); + goto out_put_stream_verdict; + } + } + + msg_parser = READ_ONCE(progs->msg_parser); + if (msg_parser) { + msg_parser = bpf_prog_inc_not_zero(msg_parser); + if (IS_ERR(msg_parser)) { + ret = PTR_ERR(msg_parser); + goto out_put_stream_parser; + } + } + + skb_verdict = READ_ONCE(progs->skb_verdict); + if (skb_verdict) { + skb_verdict = bpf_prog_inc_not_zero(skb_verdict); + if (IS_ERR(skb_verdict)) { + ret = PTR_ERR(skb_verdict); + goto out_put_msg_parser; + } + } + + psock = sock_map_psock_get_checked(sk); + if (IS_ERR(psock)) { + ret = PTR_ERR(psock); + goto out_progs; + } + + if (psock) { + if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) || + (stream_parser && READ_ONCE(psock->progs.stream_parser)) || + (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) || + (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) || + (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) || + (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) { + sk_psock_put(sk, psock); + ret = -EBUSY; + goto out_progs; + } + } else { + psock = sk_psock_init(sk, map->numa_node); + if (IS_ERR(psock)) { + ret = PTR_ERR(psock); + goto out_progs; + } + } + + if (msg_parser) + psock_set_prog(&psock->progs.msg_parser, msg_parser); + if (stream_parser) + psock_set_prog(&psock->progs.stream_parser, stream_parser); + if (stream_verdict) + psock_set_prog(&psock->progs.stream_verdict, stream_verdict); + if (skb_verdict) + psock_set_prog(&psock->progs.skb_verdict, skb_verdict); + + /* msg_* and stream_* programs references tracked in psock after this + * point. Reference dec and cleanup will occur through psock destructor + */ + ret = sock_map_init_proto(sk, psock); + if (ret < 0) { + sk_psock_put(sk, psock); + goto out; + } + + write_lock_bh(&sk->sk_callback_lock); + if (stream_parser && stream_verdict && !psock->saved_data_ready) { + ret = sk_psock_init_strp(sk, psock); + if (ret) { + write_unlock_bh(&sk->sk_callback_lock); + sk_psock_put(sk, psock); + goto out; + } + sk_psock_start_strp(sk, psock); + } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) { + sk_psock_start_verdict(sk,psock); + } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) { + sk_psock_start_verdict(sk, psock); + } + write_unlock_bh(&sk->sk_callback_lock); + return 0; +out_progs: + if (skb_verdict) + bpf_prog_put(skb_verdict); +out_put_msg_parser: + if (msg_parser) + bpf_prog_put(msg_parser); +out_put_stream_parser: + if (stream_parser) + bpf_prog_put(stream_parser); +out_put_stream_verdict: + if (stream_verdict) + bpf_prog_put(stream_verdict); +out: + return ret; +} + +static void sock_map_free(struct bpf_map *map) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + int i; + + /* After the sync no updates or deletes will be in-flight so it + * is safe to walk map and remove entries without risking a race + * in EEXIST update case. 
+ */ + synchronize_rcu(); + for (i = 0; i < stab->map.max_entries; i++) { + struct sock **psk = &stab->sks[i]; + struct sock *sk; + + sk = xchg(psk, NULL); + if (sk) { + sock_hold(sk); + lock_sock(sk); + rcu_read_lock(); + sock_map_unref(sk, psk); + rcu_read_unlock(); + release_sock(sk); + sock_put(sk); + } + } + + /* wait for psock readers accessing its map link */ + synchronize_rcu(); + + bpf_map_area_free(stab->sks); + bpf_map_area_free(stab); +} + +static void sock_map_release_progs(struct bpf_map *map) +{ + psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs); +} + +static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (unlikely(key >= map->max_entries)) + return NULL; + return READ_ONCE(stab->sks[key]); +} + +static void *sock_map_lookup(struct bpf_map *map, void *key) +{ + struct sock *sk; + + sk = __sock_map_lookup_elem(map, *(u32 *)key); + if (!sk) + return NULL; + if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt)) + return NULL; + return sk; +} + +static void *sock_map_lookup_sys(struct bpf_map *map, void *key) +{ + struct sock *sk; + + if (map->value_size != sizeof(u64)) + return ERR_PTR(-ENOSPC); + + sk = __sock_map_lookup_elem(map, *(u32 *)key); + if (!sk) + return ERR_PTR(-ENOENT); + + __sock_gen_cookie(sk); + return &sk->sk_cookie; +} + +static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test, + struct sock **psk) +{ + struct sock *sk; + int err = 0; + + raw_spin_lock_bh(&stab->lock); + sk = *psk; + if (!sk_test || sk_test == sk) + sk = xchg(psk, NULL); + + if (likely(sk)) + sock_map_unref(sk, psk); + else + err = -EINVAL; + + raw_spin_unlock_bh(&stab->lock); + return err; +} + +static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk, + void *link_raw) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + + __sock_map_delete(stab, sk, link_raw); +} + +static int sock_map_delete_elem(struct bpf_map *map, void *key) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + u32 i = *(u32 *)key; + struct sock **psk; + + if (unlikely(i >= map->max_entries)) + return -EINVAL; + + psk = &stab->sks[i]; + return __sock_map_delete(stab, NULL, psk); +} + +static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + u32 i = key ? 
*(u32 *)key : U32_MAX; + u32 *key_next = next; + + if (i == stab->map.max_entries - 1) + return -ENOENT; + if (i >= stab->map.max_entries) + *key_next = 0; + else + *key_next = i + 1; + return 0; +} + +static int sock_map_update_common(struct bpf_map *map, u32 idx, + struct sock *sk, u64 flags) +{ + struct bpf_stab *stab = container_of(map, struct bpf_stab, map); + struct sk_psock_link *link; + struct sk_psock *psock; + struct sock *osk; + int ret; + + WARN_ON_ONCE(!rcu_read_lock_held()); + if (unlikely(flags > BPF_EXIST)) + return -EINVAL; + if (unlikely(idx >= map->max_entries)) + return -E2BIG; + + link = sk_psock_init_link(); + if (!link) + return -ENOMEM; + + ret = sock_map_link(map, sk); + if (ret < 0) + goto out_free; + + psock = sk_psock(sk); + WARN_ON_ONCE(!psock); + + raw_spin_lock_bh(&stab->lock); + osk = stab->sks[idx]; + if (osk && flags == BPF_NOEXIST) { + ret = -EEXIST; + goto out_unlock; + } else if (!osk && flags == BPF_EXIST) { + ret = -ENOENT; + goto out_unlock; + } + + sock_map_add_link(psock, link, map, &stab->sks[idx]); + stab->sks[idx] = sk; + if (osk) + sock_map_unref(osk, &stab->sks[idx]); + raw_spin_unlock_bh(&stab->lock); + return 0; +out_unlock: + raw_spin_unlock_bh(&stab->lock); + if (psock) + sk_psock_put(sk, psock); +out_free: + sk_psock_free_link(link); + return ret; +} + +static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops) +{ + return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB || + ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB || + ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB; +} + +static bool sock_map_redirect_allowed(const struct sock *sk) +{ + if (sk_is_tcp(sk)) + return sk->sk_state != TCP_LISTEN; + else + return sk->sk_state == TCP_ESTABLISHED; +} + +static bool sock_map_sk_is_suitable(const struct sock *sk) +{ + return !!sk->sk_prot->psock_update_sk_prot; +} + +static bool sock_map_sk_state_allowed(const struct sock *sk) +{ + if (sk_is_tcp(sk)) + return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN); + if (sk_is_stream_unix(sk)) + return (1 << sk->sk_state) & TCPF_ESTABLISHED; + return true; +} + +static int sock_hash_update_common(struct bpf_map *map, void *key, + struct sock *sk, u64 flags); + +int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value, + u64 flags) +{ + struct socket *sock; + struct sock *sk; + int ret; + u64 ufd; + + if (map->value_size == sizeof(u64)) + ufd = *(u64 *)value; + else + ufd = *(u32 *)value; + if (ufd > S32_MAX) + return -EINVAL; + + sock = sockfd_lookup(ufd, &ret); + if (!sock) + return ret; + sk = sock->sk; + if (!sk) { + ret = -EINVAL; + goto out; + } + if (!sock_map_sk_is_suitable(sk)) { + ret = -EOPNOTSUPP; + goto out; + } + + sock_map_sk_acquire(sk); + if (!sock_map_sk_state_allowed(sk)) + ret = -EOPNOTSUPP; + else if (map->map_type == BPF_MAP_TYPE_SOCKMAP) + ret = sock_map_update_common(map, *(u32 *)key, sk, flags); + else + ret = sock_hash_update_common(map, key, sk, flags); + sock_map_sk_release(sk); +out: + sockfd_put(sock); + return ret; +} + +static int sock_map_update_elem(struct bpf_map *map, void *key, + void *value, u64 flags) +{ + struct sock *sk = (struct sock *)value; + int ret; + + if (unlikely(!sk || !sk_fullsock(sk))) + return -EINVAL; + + if (!sock_map_sk_is_suitable(sk)) + return -EOPNOTSUPP; + + local_bh_disable(); + bh_lock_sock(sk); + if (!sock_map_sk_state_allowed(sk)) + ret = -EOPNOTSUPP; + else if (map->map_type == BPF_MAP_TYPE_SOCKMAP) + ret = sock_map_update_common(map, *(u32 *)key, sk, flags); + else + ret = sock_hash_update_common(map, key, sk, 
flags); + bh_unlock_sock(sk); + local_bh_enable(); + return ret; +} + +BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops, + struct bpf_map *, map, void *, key, u64, flags) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (likely(sock_map_sk_is_suitable(sops->sk) && + sock_map_op_okay(sops))) + return sock_map_update_common(map, *(u32 *)key, sops->sk, + flags); + return -EOPNOTSUPP; +} + +const struct bpf_func_proto bpf_sock_map_update_proto = { + .func = bpf_sock_map_update, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb, + struct bpf_map *, map, u32, key, u64, flags) +{ + struct sock *sk; + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + + sk = __sock_map_lookup_elem(map, key); + if (unlikely(!sk || !sock_map_redirect_allowed(sk))) + return SK_DROP; + + skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS); + return SK_PASS; +} + +const struct bpf_func_proto bpf_sk_redirect_map_proto = { + .func = bpf_sk_redirect_map, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg, + struct bpf_map *, map, u32, key, u64, flags) +{ + struct sock *sk; + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + + sk = __sock_map_lookup_elem(map, key); + if (unlikely(!sk || !sock_map_redirect_allowed(sk))) + return SK_DROP; + if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk)) + return SK_DROP; + + msg->flags = flags; + msg->sk_redir = sk; + return SK_PASS; +} + +const struct bpf_func_proto bpf_msg_redirect_map_proto = { + .func = bpf_msg_redirect_map, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_ANYTHING, + .arg4_type = ARG_ANYTHING, +}; + +struct sock_map_seq_info { + struct bpf_map *map; + struct sock *sk; + u32 index; +}; + +struct bpf_iter__sockmap { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct bpf_map *, map); + __bpf_md_ptr(void *, key); + __bpf_md_ptr(struct sock *, sk); +}; + +DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta, + struct bpf_map *map, void *key, + struct sock *sk) + +static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info) +{ + if (unlikely(info->index >= info->map->max_entries)) + return NULL; + + info->sk = __sock_map_lookup_elem(info->map, info->index); + + /* can't return sk directly, since that might be NULL */ + return info; +} + +static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + struct sock_map_seq_info *info = seq->private; + + if (*pos == 0) + ++*pos; + + /* pairs with sock_map_seq_stop */ + rcu_read_lock(); + return sock_map_seq_lookup_elem(info); +} + +static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos) + __must_hold(rcu) +{ + struct sock_map_seq_info *info = seq->private; + + ++*pos; + ++info->index; + + return sock_map_seq_lookup_elem(info); +} + +static int sock_map_seq_show(struct seq_file *seq, void *v) + __must_hold(rcu) +{ + struct sock_map_seq_info *info = seq->private; + struct bpf_iter__sockmap ctx = {}; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, !v); + if (!prog) + return 0; + + ctx.meta = &meta; 
+ ctx.map = info->map; + if (v) { + ctx.key = &info->index; + ctx.sk = info->sk; + } + + return bpf_iter_run_prog(prog, &ctx); +} + +static void sock_map_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + if (!v) + (void)sock_map_seq_show(seq, NULL); + + /* pairs with sock_map_seq_start */ + rcu_read_unlock(); +} + +static const struct seq_operations sock_map_seq_ops = { + .start = sock_map_seq_start, + .next = sock_map_seq_next, + .stop = sock_map_seq_stop, + .show = sock_map_seq_show, +}; + +static int sock_map_init_seq_private(void *priv_data, + struct bpf_iter_aux_info *aux) +{ + struct sock_map_seq_info *info = priv_data; + + bpf_map_inc_with_uref(aux->map); + info->map = aux->map; + return 0; +} + +static void sock_map_fini_seq_private(void *priv_data) +{ + struct sock_map_seq_info *info = priv_data; + + bpf_map_put_with_uref(info->map); +} + +static const struct bpf_iter_seq_info sock_map_iter_seq_info = { + .seq_ops = &sock_map_seq_ops, + .init_seq_private = sock_map_init_seq_private, + .fini_seq_private = sock_map_fini_seq_private, + .seq_priv_size = sizeof(struct sock_map_seq_info), +}; + +BTF_ID_LIST_SINGLE(sock_map_btf_ids, struct, bpf_stab) +const struct bpf_map_ops sock_map_ops = { + .map_meta_equal = bpf_map_meta_equal, + .map_alloc = sock_map_alloc, + .map_free = sock_map_free, + .map_get_next_key = sock_map_get_next_key, + .map_lookup_elem_sys_only = sock_map_lookup_sys, + .map_update_elem = sock_map_update_elem, + .map_delete_elem = sock_map_delete_elem, + .map_lookup_elem = sock_map_lookup, + .map_release_uref = sock_map_release_progs, + .map_check_btf = map_check_no_btf, + .map_btf_id = &sock_map_btf_ids[0], + .iter_seq_info = &sock_map_iter_seq_info, +}; + +struct bpf_shtab_elem { + struct rcu_head rcu; + u32 hash; + struct sock *sk; + struct hlist_node node; + u8 key[]; +}; + +struct bpf_shtab_bucket { + struct hlist_head head; + raw_spinlock_t lock; +}; + +struct bpf_shtab { + struct bpf_map map; + struct bpf_shtab_bucket *buckets; + u32 buckets_num; + u32 elem_size; + struct sk_psock_progs progs; + atomic_t count; +}; + +static inline u32 sock_hash_bucket_hash(const void *key, u32 len) +{ + return jhash(key, len, 0); +} + +static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab, + u32 hash) +{ + return &htab->buckets[hash & (htab->buckets_num - 1)]; +} + +static struct bpf_shtab_elem * +sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key, + u32 key_size) +{ + struct bpf_shtab_elem *elem; + + hlist_for_each_entry_rcu(elem, head, node) { + if (elem->hash == hash && + !memcmp(&elem->key, key, key_size)) + return elem; + } + + return NULL; +} + +static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + u32 key_size = map->key_size, hash; + struct bpf_shtab_bucket *bucket; + struct bpf_shtab_elem *elem; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + + return elem ? 
elem->sk : NULL; +} + +static void sock_hash_free_elem(struct bpf_shtab *htab, + struct bpf_shtab_elem *elem) +{ + atomic_dec(&htab->count); + kfree_rcu(elem, rcu); +} + +static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk, + void *link_raw) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_elem *elem_probe, *elem = link_raw; + struct bpf_shtab_bucket *bucket; + + WARN_ON_ONCE(!rcu_read_lock_held()); + bucket = sock_hash_select_bucket(htab, elem->hash); + + /* elem may be deleted in parallel from the map, but access here + * is okay since it's going away only after RCU grace period. + * However, we need to check whether it's still present. + */ + raw_spin_lock_bh(&bucket->lock); + elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash, + elem->key, map->key_size); + if (elem_probe && elem_probe == elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + } + raw_spin_unlock_bh(&bucket->lock); +} + +static int sock_hash_delete_elem(struct bpf_map *map, void *key) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + u32 hash, key_size = map->key_size; + struct bpf_shtab_bucket *bucket; + struct bpf_shtab_elem *elem; + int ret = -ENOENT; + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + + raw_spin_lock_bh(&bucket->lock); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + if (elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + ret = 0; + } + raw_spin_unlock_bh(&bucket->lock); + return ret; +} + +static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab, + void *key, u32 key_size, + u32 hash, struct sock *sk, + struct bpf_shtab_elem *old) +{ + struct bpf_shtab_elem *new; + + if (atomic_inc_return(&htab->count) > htab->map.max_entries) { + if (!old) { + atomic_dec(&htab->count); + return ERR_PTR(-E2BIG); + } + } + + new = bpf_map_kmalloc_node(&htab->map, htab->elem_size, + GFP_ATOMIC | __GFP_NOWARN, + htab->map.numa_node); + if (!new) { + atomic_dec(&htab->count); + return ERR_PTR(-ENOMEM); + } + memcpy(new->key, key, key_size); + new->sk = sk; + new->hash = hash; + return new; +} + +static int sock_hash_update_common(struct bpf_map *map, void *key, + struct sock *sk, u64 flags) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + u32 key_size = map->key_size, hash; + struct bpf_shtab_elem *elem, *elem_new; + struct bpf_shtab_bucket *bucket; + struct sk_psock_link *link; + struct sk_psock *psock; + int ret; + + WARN_ON_ONCE(!rcu_read_lock_held()); + if (unlikely(flags > BPF_EXIST)) + return -EINVAL; + + link = sk_psock_init_link(); + if (!link) + return -ENOMEM; + + ret = sock_map_link(map, sk); + if (ret < 0) + goto out_free; + + psock = sk_psock(sk); + WARN_ON_ONCE(!psock); + + hash = sock_hash_bucket_hash(key, key_size); + bucket = sock_hash_select_bucket(htab, hash); + + raw_spin_lock_bh(&bucket->lock); + elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size); + if (elem && flags == BPF_NOEXIST) { + ret = -EEXIST; + goto out_unlock; + } else if (!elem && flags == BPF_EXIST) { + ret = -ENOENT; + goto out_unlock; + } + + elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem); + if (IS_ERR(elem_new)) { + ret = PTR_ERR(elem_new); + goto out_unlock; + } + + sock_map_add_link(psock, link, map, elem_new); + /* Add new element to the head of the list, so that + * 
concurrent search will find it before old elem. + */ + hlist_add_head_rcu(&elem_new->node, &bucket->head); + if (elem) { + hlist_del_rcu(&elem->node); + sock_map_unref(elem->sk, elem); + sock_hash_free_elem(htab, elem); + } + raw_spin_unlock_bh(&bucket->lock); + return 0; +out_unlock: + raw_spin_unlock_bh(&bucket->lock); + sk_psock_put(sk, psock); +out_free: + sk_psock_free_link(link); + return ret; +} + +static int sock_hash_get_next_key(struct bpf_map *map, void *key, + void *key_next) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_elem *elem, *elem_next; + u32 hash, key_size = map->key_size; + struct hlist_head *head; + int i = 0; + + if (!key) + goto find_first_elem; + hash = sock_hash_bucket_hash(key, key_size); + head = &sock_hash_select_bucket(htab, hash)->head; + elem = sock_hash_lookup_elem_raw(head, hash, key, key_size); + if (!elem) + goto find_first_elem; + + elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)), + struct bpf_shtab_elem, node); + if (elem_next) { + memcpy(key_next, elem_next->key, key_size); + return 0; + } + + i = hash & (htab->buckets_num - 1); + i++; +find_first_elem: + for (; i < htab->buckets_num; i++) { + head = &sock_hash_select_bucket(htab, i)->head; + elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)), + struct bpf_shtab_elem, node); + if (elem_next) { + memcpy(key_next, elem_next->key, key_size); + return 0; + } + } + + return -ENOENT; +} + +static struct bpf_map *sock_hash_alloc(union bpf_attr *attr) +{ + struct bpf_shtab *htab; + int i, err; + + if (!capable(CAP_NET_ADMIN)) + return ERR_PTR(-EPERM); + if (attr->max_entries == 0 || + attr->key_size == 0 || + (attr->value_size != sizeof(u32) && + attr->value_size != sizeof(u64)) || + attr->map_flags & ~SOCK_CREATE_FLAG_MASK) + return ERR_PTR(-EINVAL); + if (attr->key_size > MAX_BPF_STACK) + return ERR_PTR(-E2BIG); + + htab = bpf_map_area_alloc(sizeof(*htab), NUMA_NO_NODE); + if (!htab) + return ERR_PTR(-ENOMEM); + + bpf_map_init_from_attr(&htab->map, attr); + + htab->buckets_num = roundup_pow_of_two(htab->map.max_entries); + htab->elem_size = sizeof(struct bpf_shtab_elem) + + round_up(htab->map.key_size, 8); + if (htab->buckets_num == 0 || + htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) { + err = -EINVAL; + goto free_htab; + } + + htab->buckets = bpf_map_area_alloc(htab->buckets_num * + sizeof(struct bpf_shtab_bucket), + htab->map.numa_node); + if (!htab->buckets) { + err = -ENOMEM; + goto free_htab; + } + + for (i = 0; i < htab->buckets_num; i++) { + INIT_HLIST_HEAD(&htab->buckets[i].head); + raw_spin_lock_init(&htab->buckets[i].lock); + } + + return &htab->map; +free_htab: + bpf_map_area_free(htab); + return ERR_PTR(err); +} + +static void sock_hash_free(struct bpf_map *map) +{ + struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map); + struct bpf_shtab_bucket *bucket; + struct hlist_head unlink_list; + struct bpf_shtab_elem *elem; + struct hlist_node *node; + int i; + + /* After the sync no updates or deletes will be in-flight so it + * is safe to walk map and remove entries without risking a race + * in EEXIST update case. + */ + synchronize_rcu(); + for (i = 0; i < htab->buckets_num; i++) { + bucket = sock_hash_select_bucket(htab, i); + + /* We are racing with sock_hash_delete_from_link to + * enter the spin-lock critical section. Every socket on + * the list is still linked to sockhash. Since link + * exists, psock exists and holds a ref to socket. 
That + * lets us to grab a socket ref too. + */ + raw_spin_lock_bh(&bucket->lock); + hlist_for_each_entry(elem, &bucket->head, node) + sock_hold(elem->sk); + hlist_move_list(&bucket->head, &unlink_list); + raw_spin_unlock_bh(&bucket->lock); + + /* Process removed entries out of atomic context to + * block for socket lock before deleting the psock's + * link to sockhash. + */ + hlist_for_each_entry_safe(elem, node, &unlink_list, node) { + hlist_del(&elem->node); + lock_sock(elem->sk); + rcu_read_lock(); + sock_map_unref(elem->sk, elem); + rcu_read_unlock(); + release_sock(elem->sk); + sock_put(elem->sk); + sock_hash_free_elem(htab, elem); + } + } + + /* wait for psock readers accessing its map link */ + synchronize_rcu(); + + bpf_map_area_free(htab->buckets); + bpf_map_area_free(htab); +} + +static void *sock_hash_lookup_sys(struct bpf_map *map, void *key) +{ + struct sock *sk; + + if (map->value_size != sizeof(u64)) + return ERR_PTR(-ENOSPC); + + sk = __sock_hash_lookup_elem(map, key); + if (!sk) + return ERR_PTR(-ENOENT); + + __sock_gen_cookie(sk); + return &sk->sk_cookie; +} + +static void *sock_hash_lookup(struct bpf_map *map, void *key) +{ + struct sock *sk; + + sk = __sock_hash_lookup_elem(map, key); + if (!sk) + return NULL; + if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt)) + return NULL; + return sk; +} + +static void sock_hash_release_progs(struct bpf_map *map) +{ + psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs); +} + +BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops, + struct bpf_map *, map, void *, key, u64, flags) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (likely(sock_map_sk_is_suitable(sops->sk) && + sock_map_op_okay(sops))) + return sock_hash_update_common(map, key, sops->sk, flags); + return -EOPNOTSUPP; +} + +const struct bpf_func_proto bpf_sock_hash_update_proto = { + .func = bpf_sock_hash_update, + .gpl_only = false, + .pkt_access = true, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb, + struct bpf_map *, map, void *, key, u64, flags) +{ + struct sock *sk; + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + + sk = __sock_hash_lookup_elem(map, key); + if (unlikely(!sk || !sock_map_redirect_allowed(sk))) + return SK_DROP; + + skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS); + return SK_PASS; +} + +const struct bpf_func_proto bpf_sk_redirect_hash_proto = { + .func = bpf_sk_redirect_hash, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg, + struct bpf_map *, map, void *, key, u64, flags) +{ + struct sock *sk; + + if (unlikely(flags & ~(BPF_F_INGRESS))) + return SK_DROP; + + sk = __sock_hash_lookup_elem(map, key); + if (unlikely(!sk || !sock_map_redirect_allowed(sk))) + return SK_DROP; + if (!(flags & BPF_F_INGRESS) && !sk_is_tcp(sk)) + return SK_DROP; + + msg->flags = flags; + msg->sk_redir = sk; + return SK_PASS; +} + +const struct bpf_func_proto bpf_msg_redirect_hash_proto = { + .func = bpf_msg_redirect_hash, + .gpl_only = false, + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_CTX, + .arg2_type = ARG_CONST_MAP_PTR, + .arg3_type = ARG_PTR_TO_MAP_KEY, + .arg4_type = ARG_ANYTHING, +}; + +struct sock_hash_seq_info { + struct bpf_map *map; + 
struct bpf_shtab *htab; + u32 bucket_id; +}; + +static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info, + struct bpf_shtab_elem *prev_elem) +{ + const struct bpf_shtab *htab = info->htab; + struct bpf_shtab_bucket *bucket; + struct bpf_shtab_elem *elem; + struct hlist_node *node; + + /* try to find next elem in the same bucket */ + if (prev_elem) { + node = rcu_dereference(hlist_next_rcu(&prev_elem->node)); + elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + if (elem) + return elem; + + /* no more elements, continue in the next bucket */ + info->bucket_id++; + } + + for (; info->bucket_id < htab->buckets_num; info->bucket_id++) { + bucket = &htab->buckets[info->bucket_id]; + node = rcu_dereference(hlist_first_rcu(&bucket->head)); + elem = hlist_entry_safe(node, struct bpf_shtab_elem, node); + if (elem) + return elem; + } + + return NULL; +} + +static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + struct sock_hash_seq_info *info = seq->private; + + if (*pos == 0) + ++*pos; + + /* pairs with sock_hash_seq_stop */ + rcu_read_lock(); + return sock_hash_seq_find_next(info, NULL); +} + +static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos) + __must_hold(rcu) +{ + struct sock_hash_seq_info *info = seq->private; + + ++*pos; + return sock_hash_seq_find_next(info, v); +} + +static int sock_hash_seq_show(struct seq_file *seq, void *v) + __must_hold(rcu) +{ + struct sock_hash_seq_info *info = seq->private; + struct bpf_iter__sockmap ctx = {}; + struct bpf_shtab_elem *elem = v; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, !elem); + if (!prog) + return 0; + + ctx.meta = &meta; + ctx.map = info->map; + if (elem) { + ctx.key = elem->key; + ctx.sk = elem->sk; + } + + return bpf_iter_run_prog(prog, &ctx); +} + +static void sock_hash_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + if (!v) + (void)sock_hash_seq_show(seq, NULL); + + /* pairs with sock_hash_seq_start */ + rcu_read_unlock(); +} + +static const struct seq_operations sock_hash_seq_ops = { + .start = sock_hash_seq_start, + .next = sock_hash_seq_next, + .stop = sock_hash_seq_stop, + .show = sock_hash_seq_show, +}; + +static int sock_hash_init_seq_private(void *priv_data, + struct bpf_iter_aux_info *aux) +{ + struct sock_hash_seq_info *info = priv_data; + + bpf_map_inc_with_uref(aux->map); + info->map = aux->map; + info->htab = container_of(aux->map, struct bpf_shtab, map); + return 0; +} + +static void sock_hash_fini_seq_private(void *priv_data) +{ + struct sock_hash_seq_info *info = priv_data; + + bpf_map_put_with_uref(info->map); +} + +static const struct bpf_iter_seq_info sock_hash_iter_seq_info = { + .seq_ops = &sock_hash_seq_ops, + .init_seq_private = sock_hash_init_seq_private, + .fini_seq_private = sock_hash_fini_seq_private, + .seq_priv_size = sizeof(struct sock_hash_seq_info), +}; + +BTF_ID_LIST_SINGLE(sock_hash_map_btf_ids, struct, bpf_shtab) +const struct bpf_map_ops sock_hash_ops = { + .map_meta_equal = bpf_map_meta_equal, + .map_alloc = sock_hash_alloc, + .map_free = sock_hash_free, + .map_get_next_key = sock_hash_get_next_key, + .map_update_elem = sock_map_update_elem, + .map_delete_elem = sock_hash_delete_elem, + .map_lookup_elem = sock_hash_lookup, + .map_lookup_elem_sys_only = sock_hash_lookup_sys, + .map_release_uref = sock_hash_release_progs, + .map_check_btf = map_check_no_btf, + .map_btf_id = &sock_hash_map_btf_ids[0], + .iter_seq_info = &sock_hash_iter_seq_info, +}; + 
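/*
 * Editor's aside (not part of the upstream patch): the sockmap/sockhash
 * map operations above are exercised from user space through the bpf(2)
 * syscall, and sock_map_update_elem_sys() accepts the map value as a
 * socket file descriptor, either u32- or u64-sized. A minimal libbpf
 * sketch of populating a BPF_MAP_TYPE_SOCKHASH follows; the map name,
 * function name, and the assumption that an established TCP socket fd is
 * passed in are illustrative only, and libbpf >= 0.7 is assumed for
 * bpf_map_create().
 */
#if 0 /* illustrative user-space sketch, not kernel code */
#include <bpf/bpf.h>		/* bpf_map_create(), bpf_map_update_elem() */

static int add_sock_to_sockhash(int established_tcp_fd)
{
	/* 4-byte key, u64 value holding the socket fd; both sizes are
	 * accepted by sock_hash_alloc()/sock_map_update_elem_sys() above.
	 */
	int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKHASH, "sock_hash_demo",
				    sizeof(__u32), sizeof(__u64), 16, NULL);
	__u32 key = 0;
	__u64 value = established_tcp_fd;

	if (map_fd < 0)
		return map_fd;

	/* Ends up in sock_map_update_elem_sys(); the socket must be in an
	 * allowed state (e.g. TCP_ESTABLISHED) and CAP_NET_ADMIN is needed
	 * to create the map.
	 */
	return bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
}
#endif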
+static struct sk_psock_progs *sock_map_progs(struct bpf_map *map) +{ + switch (map->map_type) { + case BPF_MAP_TYPE_SOCKMAP: + return &container_of(map, struct bpf_stab, map)->progs; + case BPF_MAP_TYPE_SOCKHASH: + return &container_of(map, struct bpf_shtab, map)->progs; + default: + break; + } + + return NULL; +} + +static int sock_map_prog_lookup(struct bpf_map *map, struct bpf_prog ***pprog, + u32 which) +{ + struct sk_psock_progs *progs = sock_map_progs(map); + + if (!progs) + return -EOPNOTSUPP; + + switch (which) { + case BPF_SK_MSG_VERDICT: + *pprog = &progs->msg_parser; + break; +#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER) + case BPF_SK_SKB_STREAM_PARSER: + *pprog = &progs->stream_parser; + break; +#endif + case BPF_SK_SKB_STREAM_VERDICT: + if (progs->skb_verdict) + return -EBUSY; + *pprog = &progs->stream_verdict; + break; + case BPF_SK_SKB_VERDICT: + if (progs->stream_verdict) + return -EBUSY; + *pprog = &progs->skb_verdict; + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog, + struct bpf_prog *old, u32 which) +{ + struct bpf_prog **pprog; + int ret; + + ret = sock_map_prog_lookup(map, &pprog, which); + if (ret) + return ret; + + if (old) + return psock_replace_prog(pprog, prog, old); + + psock_set_prog(pprog, prog); + return 0; +} + +int sock_map_bpf_prog_query(const union bpf_attr *attr, + union bpf_attr __user *uattr) +{ + __u32 __user *prog_ids = u64_to_user_ptr(attr->query.prog_ids); + u32 prog_cnt = 0, flags = 0, ufd = attr->target_fd; + struct bpf_prog **pprog; + struct bpf_prog *prog; + struct bpf_map *map; + struct fd f; + u32 id = 0; + int ret; + + if (attr->query.query_flags) + return -EINVAL; + + f = fdget(ufd); + map = __bpf_map_get(f); + if (IS_ERR(map)) + return PTR_ERR(map); + + rcu_read_lock(); + + ret = sock_map_prog_lookup(map, &pprog, attr->query.attach_type); + if (ret) + goto end; + + prog = *pprog; + prog_cnt = !prog ? 0 : 1; + + if (!attr->query.prog_cnt || !prog_ids || !prog_cnt) + goto end; + + /* we do not hold the refcnt, the bpf prog may be released + * asynchronously and the id would be set to 0. 
+ */ + id = data_race(prog->aux->id); + if (id == 0) + prog_cnt = 0; + +end: + rcu_read_unlock(); + + if (copy_to_user(&uattr->query.attach_flags, &flags, sizeof(flags)) || + (id != 0 && copy_to_user(prog_ids, &id, sizeof(u32))) || + copy_to_user(&uattr->query.prog_cnt, &prog_cnt, sizeof(prog_cnt))) + ret = -EFAULT; + + fdput(f); + return ret; +} + +static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link) +{ + switch (link->map->map_type) { + case BPF_MAP_TYPE_SOCKMAP: + return sock_map_delete_from_link(link->map, sk, + link->link_raw); + case BPF_MAP_TYPE_SOCKHASH: + return sock_hash_delete_from_link(link->map, sk, + link->link_raw); + default: + break; + } +} + +static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock) +{ + struct sk_psock_link *link; + + while ((link = sk_psock_link_pop(psock))) { + sock_map_unlink(sk, link); + sk_psock_free_link(link); + } +} + +void sock_map_unhash(struct sock *sk) +{ + void (*saved_unhash)(struct sock *sk); + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + saved_unhash = READ_ONCE(sk->sk_prot)->unhash; + } else { + saved_unhash = psock->saved_unhash; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + } + if (WARN_ON_ONCE(saved_unhash == sock_map_unhash)) + return; + if (saved_unhash) + saved_unhash(sk); +} +EXPORT_SYMBOL_GPL(sock_map_unhash); + +void sock_map_destroy(struct sock *sk) +{ + void (*saved_destroy)(struct sock *sk); + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock_get(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + saved_destroy = READ_ONCE(sk->sk_prot)->destroy; + } else { + saved_destroy = psock->saved_destroy; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock); + sk_psock_put(sk, psock); + } + if (WARN_ON_ONCE(saved_destroy == sock_map_destroy)) + return; + if (saved_destroy) + saved_destroy(sk); +} +EXPORT_SYMBOL_GPL(sock_map_destroy); + +void sock_map_close(struct sock *sk, long timeout) +{ + void (*saved_close)(struct sock *sk, long timeout); + struct sk_psock *psock; + + lock_sock(sk); + rcu_read_lock(); + psock = sk_psock_get(sk); + if (unlikely(!psock)) { + rcu_read_unlock(); + release_sock(sk); + saved_close = READ_ONCE(sk->sk_prot)->close; + } else { + saved_close = psock->saved_close; + sock_map_remove_links(sk, psock); + rcu_read_unlock(); + sk_psock_stop(psock); + release_sock(sk); + cancel_delayed_work_sync(&psock->work); + sk_psock_put(sk, psock); + } + + /* Make sure we do not recurse. This is a bug. + * Leak the socket instead of crashing on a stack overflow. 
+ */ + if (WARN_ON_ONCE(saved_close == sock_map_close)) + return; + saved_close(sk, timeout); +} +EXPORT_SYMBOL_GPL(sock_map_close); + +static int sock_map_iter_attach_target(struct bpf_prog *prog, + union bpf_iter_link_info *linfo, + struct bpf_iter_aux_info *aux) +{ + struct bpf_map *map; + int err = -EINVAL; + + if (!linfo->map.map_fd) + return -EBADF; + + map = bpf_map_get_with_uref(linfo->map.map_fd); + if (IS_ERR(map)) + return PTR_ERR(map); + + if (map->map_type != BPF_MAP_TYPE_SOCKMAP && + map->map_type != BPF_MAP_TYPE_SOCKHASH) + goto put_map; + + if (prog->aux->max_rdonly_access > map->key_size) { + err = -EACCES; + goto put_map; + } + + aux->map = map; + return 0; + +put_map: + bpf_map_put_with_uref(map); + return err; +} + +static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux) +{ + bpf_map_put_with_uref(aux->map); +} + +static struct bpf_iter_reg sock_map_iter_reg = { + .target = "sockmap", + .attach_target = sock_map_iter_attach_target, + .detach_target = sock_map_iter_detach_target, + .show_fdinfo = bpf_iter_map_show_fdinfo, + .fill_link_info = bpf_iter_map_fill_link_info, + .ctx_arg_info_size = 2, + .ctx_arg_info = { + { offsetof(struct bpf_iter__sockmap, key), + PTR_TO_BUF | PTR_MAYBE_NULL | MEM_RDONLY }, + { offsetof(struct bpf_iter__sockmap, sk), + PTR_TO_BTF_ID_OR_NULL }, + }, +}; + +static int __init bpf_sockmap_iter_init(void) +{ + sock_map_iter_reg.ctx_arg_info[1].btf_id = + btf_sock_ids[BTF_SOCK_TYPE_SOCK]; + return bpf_iter_reg_target(&sock_map_iter_reg); +} +late_initcall(bpf_sockmap_iter_init); diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c new file mode 100644 index 000000000..5a165286e --- /dev/null +++ b/net/core/sock_reuseport.c @@ -0,0 +1,749 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * To speed up listener socket lookup, create an array to store all sockets + * listening on the same port. This allows a decision to be made after finding + * the first socket. An optional BPF program can also be configured for + * selecting the socket index from the array of available sockets. + */ + +#include +#include +#include +#include +#include +#include + +#define INIT_SOCKS 128 + +DEFINE_SPINLOCK(reuseport_lock); + +static DEFINE_IDA(reuseport_ida); +static int reuseport_resurrect(struct sock *sk, struct sock_reuseport *old_reuse, + struct sock_reuseport *reuse, bool bind_inany); + +void reuseport_has_conns_set(struct sock *sk) +{ + struct sock_reuseport *reuse; + + if (!rcu_access_pointer(sk->sk_reuseport_cb)) + return; + + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + if (likely(reuse)) + reuse->has_conns = 1; + spin_unlock_bh(&reuseport_lock); +} +EXPORT_SYMBOL(reuseport_has_conns_set); + +static void __reuseport_get_incoming_cpu(struct sock_reuseport *reuse) +{ + /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). */ + WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu + 1); +} + +static void __reuseport_put_incoming_cpu(struct sock_reuseport *reuse) +{ + /* Paired with READ_ONCE() in reuseport_select_sock_by_hash(). 
*/ + WRITE_ONCE(reuse->incoming_cpu, reuse->incoming_cpu - 1); +} + +static void reuseport_get_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse) +{ + if (sk->sk_incoming_cpu >= 0) + __reuseport_get_incoming_cpu(reuse); +} + +static void reuseport_put_incoming_cpu(struct sock *sk, struct sock_reuseport *reuse) +{ + if (sk->sk_incoming_cpu >= 0) + __reuseport_put_incoming_cpu(reuse); +} + +void reuseport_update_incoming_cpu(struct sock *sk, int val) +{ + struct sock_reuseport *reuse; + int old_sk_incoming_cpu; + + if (unlikely(!rcu_access_pointer(sk->sk_reuseport_cb))) { + /* Paired with READ_ONCE() in sk_incoming_cpu_update() + * and compute_score(). + */ + WRITE_ONCE(sk->sk_incoming_cpu, val); + return; + } + + spin_lock_bh(&reuseport_lock); + + /* This must be done under reuseport_lock to avoid a race with + * reuseport_grow(), which accesses sk->sk_incoming_cpu without + * lock_sock() when detaching a shutdown()ed sk. + * + * Paired with READ_ONCE() in reuseport_select_sock_by_hash(). + */ + old_sk_incoming_cpu = sk->sk_incoming_cpu; + WRITE_ONCE(sk->sk_incoming_cpu, val); + + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + + /* reuseport_grow() has detached a closed sk. */ + if (!reuse) + goto out; + + if (old_sk_incoming_cpu < 0 && val >= 0) + __reuseport_get_incoming_cpu(reuse); + else if (old_sk_incoming_cpu >= 0 && val < 0) + __reuseport_put_incoming_cpu(reuse); + +out: + spin_unlock_bh(&reuseport_lock); +} + +static int reuseport_sock_index(struct sock *sk, + const struct sock_reuseport *reuse, + bool closed) +{ + int left, right; + + if (!closed) { + left = 0; + right = reuse->num_socks; + } else { + left = reuse->max_socks - reuse->num_closed_socks; + right = reuse->max_socks; + } + + for (; left < right; left++) + if (reuse->socks[left] == sk) + return left; + return -1; +} + +static void __reuseport_add_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + reuse->socks[reuse->num_socks] = sk; + /* paired with smp_rmb() in reuseport_(select|migrate)_sock() */ + smp_wmb(); + reuse->num_socks++; + reuseport_get_incoming_cpu(sk, reuse); +} + +static bool __reuseport_detach_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + int i = reuseport_sock_index(sk, reuse, false); + + if (i == -1) + return false; + + reuse->socks[i] = reuse->socks[reuse->num_socks - 1]; + reuse->num_socks--; + reuseport_put_incoming_cpu(sk, reuse); + + return true; +} + +static void __reuseport_add_closed_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + reuse->socks[reuse->max_socks - reuse->num_closed_socks - 1] = sk; + /* paired with READ_ONCE() in inet_csk_bind_conflict() */ + WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks + 1); + reuseport_get_incoming_cpu(sk, reuse); +} + +static bool __reuseport_detach_closed_sock(struct sock *sk, + struct sock_reuseport *reuse) +{ + int i = reuseport_sock_index(sk, reuse, true); + + if (i == -1) + return false; + + reuse->socks[i] = reuse->socks[reuse->max_socks - reuse->num_closed_socks]; + /* paired with READ_ONCE() in inet_csk_bind_conflict() */ + WRITE_ONCE(reuse->num_closed_socks, reuse->num_closed_socks - 1); + reuseport_put_incoming_cpu(sk, reuse); + + return true; +} + +static struct sock_reuseport *__reuseport_alloc(unsigned int max_socks) +{ + unsigned int size = sizeof(struct sock_reuseport) + + sizeof(struct sock *) * max_socks; + struct sock_reuseport *reuse = kzalloc(size, GFP_ATOMIC); + + if (!reuse) + return NULL; + + reuse->max_socks = max_socks; + +
RCU_INIT_POINTER(reuse->prog, NULL); + return reuse; +} + +int reuseport_alloc(struct sock *sk, bool bind_inany) +{ + struct sock_reuseport *reuse; + int id, ret = 0; + + /* bh lock used since this function call may precede hlist lock in + * soft irq of receive path or setsockopt from process context + */ + spin_lock_bh(&reuseport_lock); + + /* Allocation attempts can occur concurrently via the setsockopt path + * and the bind/hash path. Nothing to do when we lose the race. + */ + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + if (reuse) { + if (reuse->num_closed_socks) { + /* sk was shutdown()ed before */ + ret = reuseport_resurrect(sk, reuse, NULL, bind_inany); + goto out; + } + + /* Only set reuse->bind_inany if the bind_inany is true. + * Otherwise, it will overwrite the reuse->bind_inany + * which was set by the bind/hash path. + */ + if (bind_inany) + reuse->bind_inany = bind_inany; + goto out; + } + + reuse = __reuseport_alloc(INIT_SOCKS); + if (!reuse) { + ret = -ENOMEM; + goto out; + } + + id = ida_alloc(&reuseport_ida, GFP_ATOMIC); + if (id < 0) { + kfree(reuse); + ret = id; + goto out; + } + + reuse->reuseport_id = id; + reuse->bind_inany = bind_inany; + reuse->socks[0] = sk; + reuse->num_socks = 1; + reuseport_get_incoming_cpu(sk, reuse); + rcu_assign_pointer(sk->sk_reuseport_cb, reuse); + +out: + spin_unlock_bh(&reuseport_lock); + + return ret; +} +EXPORT_SYMBOL(reuseport_alloc); + +static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse) +{ + struct sock_reuseport *more_reuse; + u32 more_socks_size, i; + + more_socks_size = reuse->max_socks * 2U; + if (more_socks_size > U16_MAX) { + if (reuse->num_closed_socks) { + /* Make room by removing a closed sk. + * The child has already been migrated. + * Only reqsk left at this point. + */ + struct sock *sk; + + sk = reuse->socks[reuse->max_socks - reuse->num_closed_socks]; + RCU_INIT_POINTER(sk->sk_reuseport_cb, NULL); + __reuseport_detach_closed_sock(sk, reuse); + + return reuse; + } + + return NULL; + } + + more_reuse = __reuseport_alloc(more_socks_size); + if (!more_reuse) + return NULL; + + more_reuse->num_socks = reuse->num_socks; + more_reuse->num_closed_socks = reuse->num_closed_socks; + more_reuse->prog = reuse->prog; + more_reuse->reuseport_id = reuse->reuseport_id; + more_reuse->bind_inany = reuse->bind_inany; + more_reuse->has_conns = reuse->has_conns; + more_reuse->incoming_cpu = reuse->incoming_cpu; + + memcpy(more_reuse->socks, reuse->socks, + reuse->num_socks * sizeof(struct sock *)); + memcpy(more_reuse->socks + + (more_reuse->max_socks - more_reuse->num_closed_socks), + reuse->socks + (reuse->max_socks - reuse->num_closed_socks), + reuse->num_closed_socks * sizeof(struct sock *)); + more_reuse->synq_overflow_ts = READ_ONCE(reuse->synq_overflow_ts); + + for (i = 0; i < reuse->max_socks; ++i) + rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb, + more_reuse); + + /* Note: we use kfree_rcu here instead of reuseport_free_rcu so + * that reuse and more_reuse can temporarily share a reference + * to prog. + */ + kfree_rcu(reuse, rcu); + return more_reuse; +} + +static void reuseport_free_rcu(struct rcu_head *head) +{ + struct sock_reuseport *reuse; + + reuse = container_of(head, struct sock_reuseport, rcu); + sk_reuseport_prog_free(rcu_dereference_protected(reuse->prog, 1)); + ida_free(&reuseport_ida, reuse->reuseport_id); + kfree(reuse); +} + +/** + * reuseport_add_sock - Add a socket to the reuseport group of another. 
+ * @sk: New socket to add to the group. + * @sk2: Socket belonging to the existing reuseport group. + * @bind_inany: Whether or not the group is bound to a local INANY address. + * + * May return ENOMEM and not add socket to group under memory pressure. + */ +int reuseport_add_sock(struct sock *sk, struct sock *sk2, bool bind_inany) +{ + struct sock_reuseport *old_reuse, *reuse; + + if (!rcu_access_pointer(sk2->sk_reuseport_cb)) { + int err = reuseport_alloc(sk2, bind_inany); + + if (err) + return err; + } + + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk2->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + old_reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + if (old_reuse && old_reuse->num_closed_socks) { + /* sk was shutdown()ed before */ + int err = reuseport_resurrect(sk, old_reuse, reuse, reuse->bind_inany); + + spin_unlock_bh(&reuseport_lock); + return err; + } + + if (old_reuse && old_reuse->num_socks != 1) { + spin_unlock_bh(&reuseport_lock); + return -EBUSY; + } + + if (reuse->num_socks + reuse->num_closed_socks == reuse->max_socks) { + reuse = reuseport_grow(reuse); + if (!reuse) { + spin_unlock_bh(&reuseport_lock); + return -ENOMEM; + } + } + + __reuseport_add_sock(sk, reuse); + rcu_assign_pointer(sk->sk_reuseport_cb, reuse); + + spin_unlock_bh(&reuseport_lock); + + if (old_reuse) + call_rcu(&old_reuse->rcu, reuseport_free_rcu); + return 0; +} +EXPORT_SYMBOL(reuseport_add_sock); + +static int reuseport_resurrect(struct sock *sk, struct sock_reuseport *old_reuse, + struct sock_reuseport *reuse, bool bind_inany) +{ + if (old_reuse == reuse) { + /* If sk was in the same reuseport group, just pop sk out of + * the closed section and push sk into the listening section. + */ + __reuseport_detach_closed_sock(sk, old_reuse); + __reuseport_add_sock(sk, old_reuse); + return 0; + } + + if (!reuse) { + /* In bind()/listen() path, we cannot carry over the eBPF prog + * for the shutdown()ed socket. In setsockopt() path, we should + * not change the eBPF prog of listening sockets by attaching a + * prog to the shutdown()ed socket. Thus, we will allocate a new + * reuseport group and detach sk from the old group. + */ + int id; + + reuse = __reuseport_alloc(INIT_SOCKS); + if (!reuse) + return -ENOMEM; + + id = ida_alloc(&reuseport_ida, GFP_ATOMIC); + if (id < 0) { + kfree(reuse); + return id; + } + + reuse->reuseport_id = id; + reuse->bind_inany = bind_inany; + } else { + /* Move sk from the old group to the new one if + * - all the other listeners in the old group were close()d or + * shutdown()ed, and then sk2 has listen()ed on the same port + * OR + * - sk listen()ed without bind() (or with autobind), was + * shutdown()ed, and then listen()s on another port which + * sk2 listen()s on. + */ + if (reuse->num_socks + reuse->num_closed_socks == reuse->max_socks) { + reuse = reuseport_grow(reuse); + if (!reuse) + return -ENOMEM; + } + } + + __reuseport_detach_closed_sock(sk, old_reuse); + __reuseport_add_sock(sk, reuse); + rcu_assign_pointer(sk->sk_reuseport_cb, reuse); + + if (old_reuse->num_socks + old_reuse->num_closed_socks == 0) + call_rcu(&old_reuse->rcu, reuseport_free_rcu); + + return 0; +} + +void reuseport_detach_sock(struct sock *sk) +{ + struct sock_reuseport *reuse; + + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + + /* reuseport_grow() has detached a closed sk */ + if (!reuse) + goto out; + + /* Notify the bpf side. 
The sk may be added to a sockarray + * map. If so, sockarray logic will remove it from the map. + * + * Other bpf map types that work with reuseport, like sockmap, + * don't need an explicit callback from here. They override sk + * unhash/close ops to remove the sk from the map before we + * get to this point. + */ + bpf_sk_reuseport_detach(sk); + + rcu_assign_pointer(sk->sk_reuseport_cb, NULL); + + if (!__reuseport_detach_closed_sock(sk, reuse)) + __reuseport_detach_sock(sk, reuse); + + if (reuse->num_socks + reuse->num_closed_socks == 0) + call_rcu(&reuse->rcu, reuseport_free_rcu); + +out: + spin_unlock_bh(&reuseport_lock); +} +EXPORT_SYMBOL(reuseport_detach_sock); + +void reuseport_stop_listen_sock(struct sock *sk) +{ + if (sk->sk_protocol == IPPROTO_TCP) { + struct sock_reuseport *reuse; + struct bpf_prog *prog; + + spin_lock_bh(&reuseport_lock); + + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + prog = rcu_dereference_protected(reuse->prog, + lockdep_is_held(&reuseport_lock)); + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_migrate_req) || + (prog && prog->expected_attach_type == BPF_SK_REUSEPORT_SELECT_OR_MIGRATE)) { + /* Migration capable, move sk from the listening section + * to the closed section. + */ + bpf_sk_reuseport_detach(sk); + + __reuseport_detach_sock(sk, reuse); + __reuseport_add_closed_sock(sk, reuse); + + spin_unlock_bh(&reuseport_lock); + return; + } + + spin_unlock_bh(&reuseport_lock); + } + + /* Not capable to do migration, detach immediately */ + reuseport_detach_sock(sk); +} +EXPORT_SYMBOL(reuseport_stop_listen_sock); + +static struct sock *run_bpf_filter(struct sock_reuseport *reuse, u16 socks, + struct bpf_prog *prog, struct sk_buff *skb, + int hdr_len) +{ + struct sk_buff *nskb = NULL; + u32 index; + + if (skb_shared(skb)) { + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + return NULL; + skb = nskb; + } + + /* temporarily advance data past protocol header */ + if (!pskb_pull(skb, hdr_len)) { + kfree_skb(nskb); + return NULL; + } + index = bpf_prog_run_save_cb(prog, skb); + __skb_push(skb, hdr_len); + + consume_skb(nskb); + + if (index >= socks) + return NULL; + + return reuse->socks[index]; +} + +static struct sock *reuseport_select_sock_by_hash(struct sock_reuseport *reuse, + u32 hash, u16 num_socks) +{ + struct sock *first_valid_sk = NULL; + int i, j; + + i = j = reciprocal_scale(hash, num_socks); + do { + struct sock *sk = reuse->socks[i]; + + if (sk->sk_state != TCP_ESTABLISHED) { + /* Paired with WRITE_ONCE() in __reuseport_(get|put)_incoming_cpu(). */ + if (!READ_ONCE(reuse->incoming_cpu)) + return sk; + + /* Paired with WRITE_ONCE() in reuseport_update_incoming_cpu(). */ + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + return sk; + + if (!first_valid_sk) + first_valid_sk = sk; + } + + i++; + if (i >= num_socks) + i = 0; + } while (i != j); + + return first_valid_sk; +} + +/** + * reuseport_select_sock - Select a socket from an SO_REUSEPORT group. + * @sk: First socket in the group. + * @hash: When no BPF filter is available, use this hash to select. + * @skb: skb to run through BPF filter. + * @hdr_len: BPF filter expects skb data pointer at payload data. If + * the skb does not yet point at the payload, this parameter represents + * how far the pointer needs to advance to reach the payload. + * Returns a socket that should receive the packet (or NULL on error). 
+ */ +struct sock *reuseport_select_sock(struct sock *sk, + u32 hash, + struct sk_buff *skb, + int hdr_len) +{ + struct sock_reuseport *reuse; + struct bpf_prog *prog; + struct sock *sk2 = NULL; + u16 socks; + + rcu_read_lock(); + reuse = rcu_dereference(sk->sk_reuseport_cb); + + /* if memory allocation failed or add call is not yet complete */ + if (!reuse) + goto out; + + prog = rcu_dereference(reuse->prog); + socks = READ_ONCE(reuse->num_socks); + if (likely(socks)) { + /* paired with smp_wmb() in __reuseport_add_sock() */ + smp_rmb(); + + if (!prog || !skb) + goto select_by_hash; + + if (prog->type == BPF_PROG_TYPE_SK_REUSEPORT) + sk2 = bpf_run_sk_reuseport(reuse, sk, prog, skb, NULL, hash); + else + sk2 = run_bpf_filter(reuse, socks, prog, skb, hdr_len); + +select_by_hash: + /* no bpf or invalid bpf result: fall back to hash usage */ + if (!sk2) + sk2 = reuseport_select_sock_by_hash(reuse, hash, socks); + } + +out: + rcu_read_unlock(); + return sk2; +} +EXPORT_SYMBOL(reuseport_select_sock); + +/** + * reuseport_migrate_sock - Select a socket from an SO_REUSEPORT group. + * @sk: close()ed or shutdown()ed socket in the group. + * @migrating_sk: ESTABLISHED/SYN_RECV full socket in the accept queue or + * NEW_SYN_RECV request socket during 3WHS. + * @skb: skb to run through BPF filter. + * Returns a socket (with sk_refcnt +1) that should accept the child socket + * (or NULL on error). + */ +struct sock *reuseport_migrate_sock(struct sock *sk, + struct sock *migrating_sk, + struct sk_buff *skb) +{ + struct sock_reuseport *reuse; + struct sock *nsk = NULL; + bool allocated = false; + struct bpf_prog *prog; + u16 socks; + u32 hash; + + rcu_read_lock(); + + reuse = rcu_dereference(sk->sk_reuseport_cb); + if (!reuse) + goto out; + + socks = READ_ONCE(reuse->num_socks); + if (unlikely(!socks)) + goto failure; + + /* paired with smp_wmb() in __reuseport_add_sock() */ + smp_rmb(); + + hash = migrating_sk->sk_hash; + prog = rcu_dereference(reuse->prog); + if (!prog || prog->expected_attach_type != BPF_SK_REUSEPORT_SELECT_OR_MIGRATE) { + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_migrate_req)) + goto select_by_hash; + goto failure; + } + + if (!skb) { + skb = alloc_skb(0, GFP_ATOMIC); + if (!skb) + goto failure; + allocated = true; + } + + nsk = bpf_run_sk_reuseport(reuse, sk, prog, skb, migrating_sk, hash); + + if (allocated) + kfree_skb(skb); + +select_by_hash: + if (!nsk) + nsk = reuseport_select_sock_by_hash(reuse, hash, socks); + + if (IS_ERR_OR_NULL(nsk) || unlikely(!refcount_inc_not_zero(&nsk->sk_refcnt))) { + nsk = NULL; + goto failure; + } + +out: + rcu_read_unlock(); + return nsk; + +failure: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQFAILURE); + goto out; +} +EXPORT_SYMBOL(reuseport_migrate_sock); + +int reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog) +{ + struct sock_reuseport *reuse; + struct bpf_prog *old_prog; + + if (sk_unhashed(sk)) { + int err; + + if (!sk->sk_reuseport) + return -EINVAL; + + err = reuseport_alloc(sk, false); + if (err) + return err; + } else if (!rcu_access_pointer(sk->sk_reuseport_cb)) { + /* The socket wasn't bound with SO_REUSEPORT */ + return -EINVAL; + } + + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + old_prog = rcu_dereference_protected(reuse->prog, + lockdep_is_held(&reuseport_lock)); + rcu_assign_pointer(reuse->prog, prog); + spin_unlock_bh(&reuseport_lock); + + sk_reuseport_prog_free(old_prog); + return 0; +} 
+EXPORT_SYMBOL(reuseport_attach_prog); + +int reuseport_detach_prog(struct sock *sk) +{ + struct sock_reuseport *reuse; + struct bpf_prog *old_prog; + + old_prog = NULL; + spin_lock_bh(&reuseport_lock); + reuse = rcu_dereference_protected(sk->sk_reuseport_cb, + lockdep_is_held(&reuseport_lock)); + + /* reuse must be checked after acquiring the reuseport_lock + * because reuseport_grow() can detach a closed sk. + */ + if (!reuse) { + spin_unlock_bh(&reuseport_lock); + return sk->sk_reuseport ? -ENOENT : -EINVAL; + } + + if (sk_unhashed(sk) && reuse->num_closed_socks) { + spin_unlock_bh(&reuseport_lock); + return -ENOENT; + } + + old_prog = rcu_replace_pointer(reuse->prog, old_prog, + lockdep_is_held(&reuseport_lock)); + spin_unlock_bh(&reuseport_lock); + + if (!old_prog) + return -ENOENT; + + sk_reuseport_prog_free(old_prog); + return 0; +} +EXPORT_SYMBOL(reuseport_detach_prog); diff --git a/net/core/stream.c b/net/core/stream.c new file mode 100644 index 000000000..30e7deff4 --- /dev/null +++ b/net/core/stream.c @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * SUCS NET3: + * + * Generic stream handling routines. These are generic for most + * protocols. Even IP. Tonight 8-). + * This is used because TCP, LLC (others too) layer all have mostly + * identical sendmsg() and recvmsg() code. + * So we (will) share it here. + * + * Authors: Arnaldo Carvalho de Melo + * (from old tcp.c code) + * Alan Cox (Borrowed comments 8-)) + */ + +#include +#include +#include +#include +#include +#include +#include + +/** + * sk_stream_write_space - stream socket write_space callback. + * @sk: socket + * + * FIXME: write proper description + */ +void sk_stream_write_space(struct sock *sk) +{ + struct socket *sock = sk->sk_socket; + struct socket_wq *wq; + + if (__sk_stream_is_writeable(sk, 1) && sock) { + clear_bit(SOCK_NOSPACE, &sock->flags); + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_poll(&wq->wait, EPOLLOUT | + EPOLLWRNORM | EPOLLWRBAND); + if (wq && wq->fasync_list && !(sk->sk_shutdown & SEND_SHUTDOWN)) + sock_wake_async(wq, SOCK_WAKE_SPACE, POLL_OUT); + rcu_read_unlock(); + } +} + +/** + * sk_stream_wait_connect - Wait for a socket to get into the connected state + * @sk: sock to wait on + * @timeo_p: for how long to wait + * + * Must be called with the socket locked. + */ +int sk_stream_wait_connect(struct sock *sk, long *timeo_p) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct task_struct *tsk = current; + int done; + + do { + int err = sock_error(sk); + if (err) + return err; + if ((1 << sk->sk_state) & ~(TCPF_SYN_SENT | TCPF_SYN_RECV)) + return -EPIPE; + if (!*timeo_p) + return -EAGAIN; + if (signal_pending(tsk)) + return sock_intr_errno(*timeo_p); + + add_wait_queue(sk_sleep(sk), &wait); + sk->sk_write_pending++; + done = sk_wait_event(sk, timeo_p, + !READ_ONCE(sk->sk_err) && + !((1 << READ_ONCE(sk->sk_state)) & + ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)), &wait); + remove_wait_queue(sk_sleep(sk), &wait); + sk->sk_write_pending--; + } while (!done); + return done < 0 ? done : 0; +} +EXPORT_SYMBOL(sk_stream_wait_connect); + +/** + * sk_stream_closing - Return 1 if we still have things to send in our buffers. 
+ * @sk: socket to verify + */ +static int sk_stream_closing(const struct sock *sk) +{ + return (1 << READ_ONCE(sk->sk_state)) & + (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK); +} + +void sk_stream_wait_close(struct sock *sk, long timeout) +{ + if (timeout) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + + do { + if (sk_wait_event(sk, &timeout, !sk_stream_closing(sk), &wait)) + break; + } while (!signal_pending(current) && timeout); + + remove_wait_queue(sk_sleep(sk), &wait); + } +} +EXPORT_SYMBOL(sk_stream_wait_close); + +/** + * sk_stream_wait_memory - Wait for more memory for a socket + * @sk: socket to wait for memory + * @timeo_p: for how long + */ +int sk_stream_wait_memory(struct sock *sk, long *timeo_p) +{ + int ret, err = 0; + long vm_wait = 0; + long current_timeo = *timeo_p; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + if (sk_stream_memory_free(sk)) + current_timeo = vm_wait = prandom_u32_max(HZ / 5) + 2; + + add_wait_queue(sk_sleep(sk), &wait); + + while (1) { + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) + goto do_error; + if (!*timeo_p) + goto do_eagain; + if (signal_pending(current)) + goto do_interrupted; + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + if (sk_stream_memory_free(sk) && !vm_wait) + break; + + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + sk->sk_write_pending++; + ret = sk_wait_event(sk, ¤t_timeo, READ_ONCE(sk->sk_err) || + (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || + (sk_stream_memory_free(sk) && !vm_wait), + &wait); + sk->sk_write_pending--; + if (ret < 0) + goto do_error; + + if (vm_wait) { + vm_wait -= current_timeo; + current_timeo = *timeo_p; + if (current_timeo != MAX_SCHEDULE_TIMEOUT && + (current_timeo -= vm_wait) < 0) + current_timeo = 0; + vm_wait = 0; + } + *timeo_p = current_timeo; + } +out: + if (!sock_flag(sk, SOCK_DEAD)) + remove_wait_queue(sk_sleep(sk), &wait); + return err; + +do_error: + err = -EPIPE; + goto out; +do_eagain: + /* Make sure that whenever EAGAIN is returned, EPOLLOUT event can + * be generated later. + * When TCP receives ACK packets that make room, tcp_check_space() + * only calls tcp_new_space() if SOCK_NOSPACE is set. + */ + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + err = -EAGAIN; + goto out; +do_interrupted: + err = sock_intr_errno(*timeo_p); + goto out; +} +EXPORT_SYMBOL(sk_stream_wait_memory); + +int sk_stream_error(struct sock *sk, int flags, int err) +{ + if (err == -EPIPE) + err = sock_error(sk) ? : -EPIPE; + if (err == -EPIPE && !(flags & MSG_NOSIGNAL)) + send_sig(SIGPIPE, current, 0); + return err; +} +EXPORT_SYMBOL(sk_stream_error); + +void sk_stream_kill_queues(struct sock *sk) +{ + /* First the read buffer. */ + __skb_queue_purge(&sk->sk_receive_queue); + + /* Next, the error queue. + * We need to use queue lock, because other threads might + * add packets to the queue without socket lock being held. + */ + skb_queue_purge(&sk->sk_error_queue); + + /* Next, the write queue. */ + WARN_ON_ONCE(!skb_queue_empty(&sk->sk_write_queue)); + + /* Account for returned memory. */ + sk_mem_reclaim_final(sk); + + WARN_ON_ONCE(sk->sk_wmem_queued); + + /* It is _impossible_ for the backlog to contain anything + * when we get here. All user references to this socket + * have gone away, only the net layer knows can touch it. 
+ */ +} +EXPORT_SYMBOL(sk_stream_kill_queues); diff --git a/net/core/sysctl_net_core.c b/net/core/sysctl_net_core.c new file mode 100644 index 000000000..5b1ce656b --- /dev/null +++ b/net/core/sysctl_net_core.c @@ -0,0 +1,687 @@ +// SPDX-License-Identifier: GPL-2.0 +/* -*- linux-c -*- + * sysctl_net_core.c: sysctl interface to net core subsystem. + * + * Begun April 1, 1996, Mike Shaver. + * Added /proc/sys/net/core directory entry (empty =) ). [MS] + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "dev.h" + +static int int_3600 = 3600; +static int min_sndbuf = SOCK_MIN_SNDBUF; +static int min_rcvbuf = SOCK_MIN_RCVBUF; +static int max_skb_frags = MAX_SKB_FRAGS; + +static int net_msg_warn; /* Unused, but still a sysctl */ + +int sysctl_fb_tunnels_only_for_init_net __read_mostly = 0; +EXPORT_SYMBOL(sysctl_fb_tunnels_only_for_init_net); + +/* 0 - Keep current behavior: + * IPv4: inherit all current settings from init_net + * IPv6: reset all settings to default + * 1 - Both inherit all current settings from init_net + * 2 - Both reset all settings to default + * 3 - Both inherit all settings from current netns + */ +int sysctl_devconf_inherit_init_net __read_mostly; +EXPORT_SYMBOL(sysctl_devconf_inherit_init_net); + +#ifdef CONFIG_RPS +static int rps_sock_flow_sysctl(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + unsigned int orig_size, size; + int ret, i; + struct ctl_table tmp = { + .data = &size, + .maxlen = sizeof(size), + .mode = table->mode + }; + struct rps_sock_flow_table *orig_sock_table, *sock_table; + static DEFINE_MUTEX(sock_flow_mutex); + + mutex_lock(&sock_flow_mutex); + + orig_sock_table = rcu_dereference_protected(rps_sock_flow_table, + lockdep_is_held(&sock_flow_mutex)); + size = orig_size = orig_sock_table ? 
orig_sock_table->mask + 1 : 0; + + ret = proc_dointvec(&tmp, write, buffer, lenp, ppos); + + if (write) { + if (size) { + if (size > 1<<29) { + /* Enforce limit to prevent overflow */ + mutex_unlock(&sock_flow_mutex); + return -EINVAL; + } + size = roundup_pow_of_two(size); + if (size != orig_size) { + sock_table = + vmalloc(RPS_SOCK_FLOW_TABLE_SIZE(size)); + if (!sock_table) { + mutex_unlock(&sock_flow_mutex); + return -ENOMEM; + } + rps_cpu_mask = roundup_pow_of_two(nr_cpu_ids) - 1; + sock_table->mask = size - 1; + } else + sock_table = orig_sock_table; + + for (i = 0; i < size; i++) + sock_table->ents[i] = RPS_NO_CPU; + } else + sock_table = NULL; + + if (sock_table != orig_sock_table) { + rcu_assign_pointer(rps_sock_flow_table, sock_table); + if (sock_table) { + static_branch_inc(&rps_needed); + static_branch_inc(&rfs_needed); + } + if (orig_sock_table) { + static_branch_dec(&rps_needed); + static_branch_dec(&rfs_needed); + kvfree_rcu(orig_sock_table); + } + } + } + + mutex_unlock(&sock_flow_mutex); + + return ret; +} +#endif /* CONFIG_RPS */ + +#ifdef CONFIG_NET_FLOW_LIMIT +static DEFINE_MUTEX(flow_limit_update_mutex); + +static int flow_limit_cpu_sysctl(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct sd_flow_limit *cur; + struct softnet_data *sd; + cpumask_var_t mask; + int i, len, ret = 0; + + if (!alloc_cpumask_var(&mask, GFP_KERNEL)) + return -ENOMEM; + + if (write) { + ret = cpumask_parse(buffer, mask); + if (ret) + goto done; + + mutex_lock(&flow_limit_update_mutex); + len = sizeof(*cur) + netdev_flow_limit_table_len; + for_each_possible_cpu(i) { + sd = &per_cpu(softnet_data, i); + cur = rcu_dereference_protected(sd->flow_limit, + lockdep_is_held(&flow_limit_update_mutex)); + if (cur && !cpumask_test_cpu(i, mask)) { + RCU_INIT_POINTER(sd->flow_limit, NULL); + kfree_rcu(cur); + } else if (!cur && cpumask_test_cpu(i, mask)) { + cur = kzalloc_node(len, GFP_KERNEL, + cpu_to_node(i)); + if (!cur) { + /* not unwinding previous changes */ + ret = -ENOMEM; + goto write_unlock; + } + cur->num_buckets = netdev_flow_limit_table_len; + rcu_assign_pointer(sd->flow_limit, cur); + } + } +write_unlock: + mutex_unlock(&flow_limit_update_mutex); + } else { + char kbuf[128]; + + if (*ppos || !*lenp) { + *lenp = 0; + goto done; + } + + cpumask_clear(mask); + rcu_read_lock(); + for_each_possible_cpu(i) { + sd = &per_cpu(softnet_data, i); + if (rcu_dereference(sd->flow_limit)) + cpumask_set_cpu(i, mask); + } + rcu_read_unlock(); + + len = min(sizeof(kbuf) - 1, *lenp); + len = scnprintf(kbuf, len, "%*pb", cpumask_pr_args(mask)); + if (!len) { + *lenp = 0; + goto done; + } + if (len < *lenp) + kbuf[len++] = '\n'; + memcpy(buffer, kbuf, len); + *lenp = len; + *ppos += len; + } + +done: + free_cpumask_var(mask); + return ret; +} + +static int flow_limit_table_len_sysctl(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + unsigned int old, *ptr; + int ret; + + mutex_lock(&flow_limit_update_mutex); + + ptr = table->data; + old = *ptr; + ret = proc_dointvec(table, write, buffer, lenp, ppos); + if (!ret && write && !is_power_of_2(*ptr)) { + *ptr = old; + ret = -EINVAL; + } + + mutex_unlock(&flow_limit_update_mutex); + return ret; +} +#endif /* CONFIG_NET_FLOW_LIMIT */ + +#ifdef CONFIG_NET_SCHED +static int set_default_qdisc(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + char id[IFNAMSIZ]; + struct ctl_table tbl = { + .data = id, + .maxlen = IFNAMSIZ, + }; + int ret; + + qdisc_get_default(id, 
IFNAMSIZ); + + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) + ret = qdisc_set_default(id); + return ret; +} +#endif + +static int proc_do_dev_weight(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + static DEFINE_MUTEX(dev_weight_mutex); + int ret, weight; + + mutex_lock(&dev_weight_mutex); + ret = proc_dointvec(table, write, buffer, lenp, ppos); + if (!ret && write) { + weight = READ_ONCE(weight_p); + WRITE_ONCE(dev_rx_weight, weight * dev_weight_rx_bias); + WRITE_ONCE(dev_tx_weight, weight * dev_weight_tx_bias); + } + mutex_unlock(&dev_weight_mutex); + + return ret; +} + +static int proc_do_rss_key(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct ctl_table fake_table; + char buf[NETDEV_RSS_KEY_LEN * 3]; + + snprintf(buf, sizeof(buf), "%*phC", NETDEV_RSS_KEY_LEN, netdev_rss_key); + fake_table.data = buf; + fake_table.maxlen = sizeof(buf); + return proc_dostring(&fake_table, write, buffer, lenp, ppos); +} + +#ifdef CONFIG_BPF_JIT +static int proc_dointvec_minmax_bpf_enable(struct ctl_table *table, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + int ret, jit_enable = *(int *)table->data; + int min = *(int *)table->extra1; + int max = *(int *)table->extra2; + struct ctl_table tmp = *table; + + if (write && !capable(CAP_SYS_ADMIN)) + return -EPERM; + + tmp.data = &jit_enable; + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + if (write && !ret) { + if (jit_enable < 2 || + (jit_enable == 2 && bpf_dump_raw_ok(current_cred()))) { + *(int *)table->data = jit_enable; + if (jit_enable == 2) + pr_warn("bpf_jit_enable = 2 was set! NEVER use this in production, only for JIT debugging!\n"); + } else { + ret = -EPERM; + } + } + + if (write && ret && min == max) + pr_info_once("CONFIG_BPF_JIT_ALWAYS_ON is enabled, bpf_jit_enable is permanently set to 1.\n"); + + return ret; +} + +# ifdef CONFIG_HAVE_EBPF_JIT +static int +proc_dointvec_minmax_bpf_restricted(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + if (!capable(CAP_SYS_ADMIN)) + return -EPERM; + + return proc_dointvec_minmax(table, write, buffer, lenp, ppos); +} +# endif /* CONFIG_HAVE_EBPF_JIT */ + +static int +proc_dolongvec_minmax_bpf_restricted(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + if (!capable(CAP_SYS_ADMIN)) + return -EPERM; + + return proc_doulongvec_minmax(table, write, buffer, lenp, ppos); +} +#endif + +static struct ctl_table net_core_table[] = { + { + .procname = "wmem_max", + .data = &sysctl_wmem_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_sndbuf, + }, + { + .procname = "rmem_max", + .data = &sysctl_rmem_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_rcvbuf, + }, + { + .procname = "wmem_default", + .data = &sysctl_wmem_default, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_sndbuf, + }, + { + .procname = "rmem_default", + .data = &sysctl_rmem_default, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_rcvbuf, + }, + { + .procname = "dev_weight", + .data = &weight_p, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_dev_weight, + }, + { + .procname = "dev_weight_rx_bias", + .data = &dev_weight_rx_bias, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_dev_weight, + }, + { + 
.procname = "dev_weight_tx_bias", + .data = &dev_weight_tx_bias, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_dev_weight, + }, + { + .procname = "netdev_max_backlog", + .data = &netdev_max_backlog, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "netdev_rss_key", + .data = &netdev_rss_key, + .maxlen = sizeof(int), + .mode = 0444, + .proc_handler = proc_do_rss_key, + }, +#ifdef CONFIG_BPF_JIT + { + .procname = "bpf_jit_enable", + .data = &bpf_jit_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax_bpf_enable, +# ifdef CONFIG_BPF_JIT_ALWAYS_ON + .extra1 = SYSCTL_ONE, + .extra2 = SYSCTL_ONE, +# else + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, +# endif + }, +# ifdef CONFIG_HAVE_EBPF_JIT + { + .procname = "bpf_jit_harden", + .data = &bpf_jit_harden, + .maxlen = sizeof(int), + .mode = 0600, + .proc_handler = proc_dointvec_minmax_bpf_restricted, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "bpf_jit_kallsyms", + .data = &bpf_jit_kallsyms, + .maxlen = sizeof(int), + .mode = 0600, + .proc_handler = proc_dointvec_minmax_bpf_restricted, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +# endif + { + .procname = "bpf_jit_limit", + .data = &bpf_jit_limit, + .maxlen = sizeof(long), + .mode = 0600, + .proc_handler = proc_dolongvec_minmax_bpf_restricted, + .extra1 = SYSCTL_LONG_ONE, + .extra2 = &bpf_jit_limit_max, + }, +#endif + { + .procname = "netdev_tstamp_prequeue", + .data = &netdev_tstamp_prequeue, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "message_cost", + .data = &net_ratelimit_state.interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "message_burst", + .data = &net_ratelimit_state.burst, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "optmem_max", + .data = &sysctl_optmem_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tstamp_allow_data", + .data = &sysctl_tstamp_allow_data, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, +#ifdef CONFIG_RPS + { + .procname = "rps_sock_flow_entries", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = rps_sock_flow_sysctl + }, +#endif +#ifdef CONFIG_NET_FLOW_LIMIT + { + .procname = "flow_limit_cpu_bitmap", + .mode = 0644, + .proc_handler = flow_limit_cpu_sysctl + }, + { + .procname = "flow_limit_table_len", + .data = &netdev_flow_limit_table_len, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = flow_limit_table_len_sysctl + }, +#endif /* CONFIG_NET_FLOW_LIMIT */ +#ifdef CONFIG_NET_RX_BUSY_POLL + { + .procname = "busy_poll", + .data = &sysctl_net_busy_poll, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "busy_read", + .data = &sysctl_net_busy_read, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, +#endif +#ifdef CONFIG_NET_SCHED + { + .procname = "default_qdisc", + .mode = 0644, + .maxlen = IFNAMSIZ, + .proc_handler = set_default_qdisc + }, +#endif + { + .procname = "netdev_budget", + .data = &netdev_budget, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "warnings", + .data = &net_msg_warn, + .maxlen 
= sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "max_skb_frags", + .data = &sysctl_max_skb_frags, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &max_skb_frags, + }, + { + .procname = "netdev_budget_usecs", + .data = &netdev_budget_usecs, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "fb_tunnels_only_for_init_net", + .data = &sysctl_fb_tunnels_only_for_init_net, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "devconf_inherit_init_net", + .data = &sysctl_devconf_inherit_init_net, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_THREE, + }, + { + .procname = "high_order_alloc_disable", + .data = &net_high_order_alloc_disable_key.key, + .maxlen = sizeof(net_high_order_alloc_disable_key), + .mode = 0644, + .proc_handler = proc_do_static_key, + }, + { + .procname = "gro_normal_batch", + .data = &gro_normal_batch, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "netdev_unregister_timeout_secs", + .data = &netdev_unregister_timeout_secs, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &int_3600, + }, + { + .procname = "skb_defer_max", + .data = &sysctl_skb_defer_max, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { } +}; + +static struct ctl_table netns_core_table[] = { + { + .procname = "somaxconn", + .data = &init_net.core.sysctl_somaxconn, + .maxlen = sizeof(int), + .mode = 0644, + .extra1 = SYSCTL_ZERO, + .proc_handler = proc_dointvec_minmax + }, + { + .procname = "txrehash", + .data = &init_net.core.sysctl_txrehash, + .maxlen = sizeof(u8), + .mode = 0644, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + .proc_handler = proc_dou8vec_minmax, + }, + { } +}; + +static int __init fb_tunnels_only_for_init_net_sysctl_setup(char *str) +{ + /* fallback tunnels for initns only */ + if (!strncmp(str, "initns", 6)) + sysctl_fb_tunnels_only_for_init_net = 1; + /* no fallback tunnels anywhere */ + else if (!strncmp(str, "none", 4)) + sysctl_fb_tunnels_only_for_init_net = 2; + + return 1; +} +__setup("fb_tunnels=", fb_tunnels_only_for_init_net_sysctl_setup); + +static __net_init int sysctl_core_net_init(struct net *net) +{ + struct ctl_table *tbl, *tmp; + + tbl = netns_core_table; + if (!net_eq(net, &init_net)) { + tbl = kmemdup(tbl, sizeof(netns_core_table), GFP_KERNEL); + if (tbl == NULL) + goto err_dup; + + for (tmp = tbl; tmp->procname; tmp++) + tmp->data += (char *)net - (char *)&init_net; + + /* Don't export any sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) { + tbl[0].procname = NULL; + } + } + + net->core.sysctl_hdr = register_net_sysctl(net, "net/core", tbl); + if (net->core.sysctl_hdr == NULL) + goto err_reg; + + return 0; + +err_reg: + if (tbl != netns_core_table) + kfree(tbl); +err_dup: + return -ENOMEM; +} + +static __net_exit void sysctl_core_net_exit(struct net *net) +{ + struct ctl_table *tbl; + + tbl = net->core.sysctl_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->core.sysctl_hdr); + BUG_ON(tbl == netns_core_table); + kfree(tbl); +} + +static 
__net_initdata struct pernet_operations sysctl_core_ops = { + .init = sysctl_core_net_init, + .exit = sysctl_core_net_exit, +}; + +static __init int sysctl_core_init(void) +{ + register_net_sysctl(&init_net, "net/core", net_core_table); + return register_pernet_subsys(&sysctl_core_ops); +} + +fs_initcall(sysctl_core_init); diff --git a/net/core/timestamping.c b/net/core/timestamping.c new file mode 100644 index 000000000..04840697f --- /dev/null +++ b/net/core/timestamping.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * PTP 1588 clock support - support for timestamping in PHY devices + * + * Copyright (C) 2010 OMICRON electronics GmbH + */ +#include +#include +#include +#include +#include + +static unsigned int classify(const struct sk_buff *skb) +{ + if (likely(skb->dev && skb->dev->phydev && + skb->dev->phydev->mii_ts)) + return ptp_classify_raw(skb); + else + return PTP_CLASS_NONE; +} + +void skb_clone_tx_timestamp(struct sk_buff *skb) +{ + struct mii_timestamper *mii_ts; + struct sk_buff *clone; + unsigned int type; + + if (!skb->sk) + return; + + type = classify(skb); + if (type == PTP_CLASS_NONE) + return; + + mii_ts = skb->dev->phydev->mii_ts; + if (likely(mii_ts->txtstamp)) { + clone = skb_clone_sk(skb); + if (!clone) + return; + mii_ts->txtstamp(mii_ts, clone, type); + } +} +EXPORT_SYMBOL_GPL(skb_clone_tx_timestamp); + +bool skb_defer_rx_timestamp(struct sk_buff *skb) +{ + struct mii_timestamper *mii_ts; + unsigned int type; + + if (!skb->dev || !skb->dev->phydev || !skb->dev->phydev->mii_ts) + return false; + + if (skb_headroom(skb) < ETH_HLEN) + return false; + + __skb_push(skb, ETH_HLEN); + + type = ptp_classify_raw(skb); + + __skb_pull(skb, ETH_HLEN); + + if (type == PTP_CLASS_NONE) + return false; + + mii_ts = skb->dev->phydev->mii_ts; + if (likely(mii_ts->rxtstamp)) + return mii_ts->rxtstamp(mii_ts, skb, type); + + return false; +} +EXPORT_SYMBOL_GPL(skb_defer_rx_timestamp); diff --git a/net/core/tso.c b/net/core/tso.c new file mode 100644 index 000000000..4148f6d48 --- /dev/null +++ b/net/core/tso.c @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include + +/* Calculate expected number of TX descriptors */ +int tso_count_descs(const struct sk_buff *skb) +{ + /* The Marvell Way */ + return skb_shinfo(skb)->gso_segs * 2 + skb_shinfo(skb)->nr_frags; +} +EXPORT_SYMBOL(tso_count_descs); + +void tso_build_hdr(const struct sk_buff *skb, char *hdr, struct tso_t *tso, + int size, bool is_last) +{ + int hdr_len = skb_transport_offset(skb) + tso->tlen; + int mac_hdr_len = skb_network_offset(skb); + + memcpy(hdr, skb->data, hdr_len); + if (!tso->ipv6) { + struct iphdr *iph = (void *)(hdr + mac_hdr_len); + + iph->id = htons(tso->ip_id); + iph->tot_len = htons(size + hdr_len - mac_hdr_len); + tso->ip_id++; + } else { + struct ipv6hdr *iph = (void *)(hdr + mac_hdr_len); + + iph->payload_len = htons(size + tso->tlen); + } + hdr += skb_transport_offset(skb); + if (tso->tlen != sizeof(struct udphdr)) { + struct tcphdr *tcph = (struct tcphdr *)hdr; + + put_unaligned_be32(tso->tcp_seq, &tcph->seq); + + if (!is_last) { + /* Clear all special flags for not last packet */ + tcph->psh = 0; + tcph->fin = 0; + tcph->rst = 0; + } + } else { + struct udphdr *uh = (struct udphdr *)hdr; + + uh->len = htons(sizeof(*uh) + size); + } +} +EXPORT_SYMBOL(tso_build_hdr); + +void tso_build_data(const struct sk_buff *skb, struct tso_t *tso, int size) +{ + tso->tcp_seq += size; /* not worth avoiding this operation for UDP */ + tso->size -= 
size; + tso->data += size; + + if ((tso->size == 0) && + (tso->next_frag_idx < skb_shinfo(skb)->nr_frags)) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[tso->next_frag_idx]; + + /* Move to next segment */ + tso->size = skb_frag_size(frag); + tso->data = skb_frag_address(frag); + tso->next_frag_idx++; + } +} +EXPORT_SYMBOL(tso_build_data); + +int tso_start(struct sk_buff *skb, struct tso_t *tso) +{ + int tlen = skb_is_gso_tcp(skb) ? tcp_hdrlen(skb) : sizeof(struct udphdr); + int hdr_len = skb_transport_offset(skb) + tlen; + + tso->tlen = tlen; + tso->ip_id = ntohs(ip_hdr(skb)->id); + tso->tcp_seq = (tlen != sizeof(struct udphdr)) ? ntohl(tcp_hdr(skb)->seq) : 0; + tso->next_frag_idx = 0; + tso->ipv6 = vlan_get_protocol(skb) == htons(ETH_P_IPV6); + + /* Build first data */ + tso->size = skb_headlen(skb) - hdr_len; + tso->data = skb->data + hdr_len; + if ((tso->size == 0) && + (tso->next_frag_idx < skb_shinfo(skb)->nr_frags)) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[tso->next_frag_idx]; + + /* Move to next segment */ + tso->size = skb_frag_size(frag); + tso->data = skb_frag_address(frag); + tso->next_frag_idx++; + } + return hdr_len; +} +EXPORT_SYMBOL(tso_start); diff --git a/net/core/utils.c b/net/core/utils.c new file mode 100644 index 000000000..938495bc1 --- /dev/null +++ b/net/core/utils.c @@ -0,0 +1,486 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Generic address resolution entity + * + * Authors: + * net_random Alan Cox + * net_ratelimit Andi Kleen + * in{4,6}_pton YOSHIFUJI Hideaki, Copyright (C)2006 USAGI/WIDE Project + * + * Created by Alexey Kuznetsov + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +DEFINE_RATELIMIT_STATE(net_ratelimit_state, 5 * HZ, 10); +/* + * All net warning printk()s should be guarded by this function. + */ +int net_ratelimit(void) +{ + return __ratelimit(&net_ratelimit_state); +} +EXPORT_SYMBOL(net_ratelimit); + +/* + * Convert an ASCII string to binary IP. + * This is outside of net/ipv4/ because various code that uses IP addresses + * is otherwise not dependent on the TCP/IP stack. + */ + +__be32 in_aton(const char *str) +{ + unsigned int l; + unsigned int val; + int i; + + l = 0; + for (i = 0; i < 4; i++) { + l <<= 8; + if (*str != '\0') { + val = 0; + while (*str != '\0' && *str != '.' && *str != '\n') { + val *= 10; + val += *str - '0'; + str++; + } + l |= val; + if (*str != '\0') + str++; + } + } + return htonl(l); +} +EXPORT_SYMBOL(in_aton); + +#define IN6PTON_XDIGIT 0x00010000 +#define IN6PTON_DIGIT 0x00020000 +#define IN6PTON_COLON_MASK 0x00700000 +#define IN6PTON_COLON_1 0x00100000 /* single : requested */ +#define IN6PTON_COLON_2 0x00200000 /* second : requested */ +#define IN6PTON_COLON_1_2 0x00400000 /* :: requested */ +#define IN6PTON_DOT 0x00800000 /* . */ +#define IN6PTON_DELIM 0x10000000 +#define IN6PTON_NULL 0x20000000 /* first/tail */ +#define IN6PTON_UNKNOWN 0x40000000 + +static inline int xdigit2bin(char c, int delim) +{ + int val; + + if (c == delim || c == '\0') + return IN6PTON_DELIM; + if (c == ':') + return IN6PTON_COLON_MASK; + if (c == '.') + return IN6PTON_DOT; + + val = hex_to_bin(c); + if (val >= 0) + return val | IN6PTON_XDIGIT | (val < 10 ?
IN6PTON_DIGIT : 0); + + if (delim == -1) + return IN6PTON_DELIM; + return IN6PTON_UNKNOWN; +} + +/** + * in4_pton - convert an IPv4 address from literal to binary representation + * @src: the start of the IPv4 address string + * @srclen: the length of the string, -1 means strlen(src) + * @dst: the binary (u8[4] array) representation of the IPv4 address + * @delim: the delimiter of the IPv4 address in @src, -1 means no delimiter + * @end: A pointer to the end of the parsed string will be placed here + * + * Return one on success, return zero when any error occurs + * and @end will point to the end of the parsed string. + * + */ +int in4_pton(const char *src, int srclen, + u8 *dst, + int delim, const char **end) +{ + const char *s; + u8 *d; + u8 dbuf[4]; + int ret = 0; + int i; + int w = 0; + + if (srclen < 0) + srclen = strlen(src); + s = src; + d = dbuf; + i = 0; + while (1) { + int c; + c = xdigit2bin(srclen > 0 ? *s : '\0', delim); + if (!(c & (IN6PTON_DIGIT | IN6PTON_DOT | IN6PTON_DELIM | IN6PTON_COLON_MASK))) { + goto out; + } + if (c & (IN6PTON_DOT | IN6PTON_DELIM | IN6PTON_COLON_MASK)) { + if (w == 0) + goto out; + *d++ = w & 0xff; + w = 0; + i++; + if (c & (IN6PTON_DELIM | IN6PTON_COLON_MASK)) { + if (i != 4) + goto out; + break; + } + goto cont; + } + w = (w * 10) + c; + if ((w & 0xffff) > 255) { + goto out; + } +cont: + if (i >= 4) + goto out; + s++; + srclen--; + } + ret = 1; + memcpy(dst, dbuf, sizeof(dbuf)); +out: + if (end) + *end = s; + return ret; +} +EXPORT_SYMBOL(in4_pton); + +/** + * in6_pton - convert an IPv6 address from literal to binary representation + * @src: the start of the IPv6 address string + * @srclen: the length of the string, -1 means strlen(src) + * @dst: the binary (u8[16] array) representation of the IPv6 address + * @delim: the delimiter of the IPv6 address in @src, -1 means no delimiter + * @end: A pointer to the end of the parsed string will be placed here + * + * Return one on success, return zero when any error occurs + * and @end will point to the end of the parsed string. + * + */ +int in6_pton(const char *src, int srclen, + u8 *dst, + int delim, const char **end) +{ + const char *s, *tok = NULL; + u8 *d, *dc = NULL; + u8 dbuf[16]; + int ret = 0; + int i; + int state = IN6PTON_COLON_1_2 | IN6PTON_XDIGIT | IN6PTON_NULL; + int w = 0; + + memset(dbuf, 0, sizeof(dbuf)); + + s = src; + d = dbuf; + if (srclen < 0) + srclen = strlen(src); + + while (1) { + int c; + + c = xdigit2bin(srclen > 0 ? *s : '\0', delim); + if (!(c & state)) + goto out; + if (c & (IN6PTON_DELIM | IN6PTON_COLON_MASK)) { + /* process one 16-bit word */ + if (!(state & IN6PTON_NULL)) { + *d++ = (w >> 8) & 0xff; + *d++ = w & 0xff; + } + w = 0; + if (c & IN6PTON_DELIM) { + /* We've processed last word */ + break; + } + /* + * COLON_1 => XDIGIT + * COLON_2 => XDIGIT|DELIM + * COLON_1_2 => COLON_2 + */ + switch (state & IN6PTON_COLON_MASK) { + case IN6PTON_COLON_2: + dc = d; + state = IN6PTON_XDIGIT | IN6PTON_DELIM; + if (dc - dbuf >= sizeof(dbuf)) + state |= IN6PTON_NULL; + break; + case IN6PTON_COLON_1|IN6PTON_COLON_1_2: + state = IN6PTON_XDIGIT | IN6PTON_COLON_2; + break; + case IN6PTON_COLON_1: + state = IN6PTON_XDIGIT; + break; + case IN6PTON_COLON_1_2: + state = IN6PTON_COLON_2; + break; + default: + state = 0; + } + tok = s + 1; + goto cont; + } + + if (c & IN6PTON_DOT) { + ret = in4_pton(tok ? 
tok : s, srclen + (int)(s - tok), d, delim, &s); + if (ret > 0) { + d += 4; + break; + } + goto out; + } + + w = (w << 4) | (0xff & c); + state = IN6PTON_COLON_1 | IN6PTON_DELIM; + if (!(w & 0xf000)) { + state |= IN6PTON_XDIGIT; + } + if (!dc && d + 2 < dbuf + sizeof(dbuf)) { + state |= IN6PTON_COLON_1_2; + state &= ~IN6PTON_DELIM; + } + if (d + 2 >= dbuf + sizeof(dbuf)) { + state &= ~(IN6PTON_COLON_1|IN6PTON_COLON_1_2); + } +cont: + if ((dc && d + 4 < dbuf + sizeof(dbuf)) || + d + 4 == dbuf + sizeof(dbuf)) { + state |= IN6PTON_DOT; + } + if (d >= dbuf + sizeof(dbuf)) { + state &= ~(IN6PTON_XDIGIT|IN6PTON_COLON_MASK); + } + s++; + srclen--; + } + + i = 15; d--; + + if (dc) { + while (d >= dc) + dst[i--] = *d--; + while (i >= dc - dbuf) + dst[i--] = 0; + while (i >= 0) + dst[i--] = *d--; + } else + memcpy(dst, dbuf, sizeof(dbuf)); + + ret = 1; +out: + if (end) + *end = s; + return ret; +} +EXPORT_SYMBOL(in6_pton); + +static int inet4_pton(const char *src, u16 port_num, + struct sockaddr_storage *addr) +{ + struct sockaddr_in *addr4 = (struct sockaddr_in *)addr; + int srclen = strlen(src); + + if (srclen > INET_ADDRSTRLEN) + return -EINVAL; + + if (in4_pton(src, srclen, (u8 *)&addr4->sin_addr.s_addr, + '\n', NULL) == 0) + return -EINVAL; + + addr4->sin_family = AF_INET; + addr4->sin_port = htons(port_num); + + return 0; +} + +static int inet6_pton(struct net *net, const char *src, u16 port_num, + struct sockaddr_storage *addr) +{ + struct sockaddr_in6 *addr6 = (struct sockaddr_in6 *)addr; + const char *scope_delim; + int srclen = strlen(src); + + if (srclen > INET6_ADDRSTRLEN) + return -EINVAL; + + if (in6_pton(src, srclen, (u8 *)&addr6->sin6_addr.s6_addr, + '%', &scope_delim) == 0) + return -EINVAL; + + if (ipv6_addr_type(&addr6->sin6_addr) & IPV6_ADDR_LINKLOCAL && + src + srclen != scope_delim && *scope_delim == '%') { + struct net_device *dev; + char scope_id[16]; + size_t scope_len = min_t(size_t, sizeof(scope_id) - 1, + src + srclen - scope_delim - 1); + + memcpy(scope_id, scope_delim + 1, scope_len); + scope_id[scope_len] = '\0'; + + dev = dev_get_by_name(net, scope_id); + if (dev) { + addr6->sin6_scope_id = dev->ifindex; + dev_put(dev); + } else if (kstrtouint(scope_id, 0, &addr6->sin6_scope_id)) { + return -EINVAL; + } + } + + addr6->sin6_family = AF_INET6; + addr6->sin6_port = htons(port_num); + + return 0; +} + +/** + * inet_pton_with_scope - convert an IPv4/IPv6 and port to socket address + * @net: net namespace (used for scope handling) + * @af: address family, AF_INET, AF_INET6 or AF_UNSPEC for either + * @src: the start of the address string + * @port: the start of the port string (or NULL for none) + * @addr: output socket address + * + * Return zero on success, return errno when any error occurs. 
+ */ +int inet_pton_with_scope(struct net *net, __kernel_sa_family_t af, + const char *src, const char *port, struct sockaddr_storage *addr) +{ + u16 port_num; + int ret = -EINVAL; + + if (port) { + if (kstrtou16(port, 0, &port_num)) + return -EINVAL; + } else { + port_num = 0; + } + + switch (af) { + case AF_INET: + ret = inet4_pton(src, port_num, addr); + break; + case AF_INET6: + ret = inet6_pton(net, src, port_num, addr); + break; + case AF_UNSPEC: + ret = inet4_pton(src, port_num, addr); + if (ret) + ret = inet6_pton(net, src, port_num, addr); + break; + default: + pr_err("unexpected address family %d\n", af); + } + + return ret; +} +EXPORT_SYMBOL(inet_pton_with_scope); + +bool inet_addr_is_any(struct sockaddr *addr) +{ + if (addr->sa_family == AF_INET6) { + struct sockaddr_in6 *in6 = (struct sockaddr_in6 *)addr; + const struct sockaddr_in6 in6_any = + { .sin6_addr = IN6ADDR_ANY_INIT }; + + if (!memcmp(in6->sin6_addr.s6_addr, + in6_any.sin6_addr.s6_addr, 16)) + return true; + } else if (addr->sa_family == AF_INET) { + struct sockaddr_in *in = (struct sockaddr_in *)addr; + + if (in->sin_addr.s_addr == htonl(INADDR_ANY)) + return true; + } else { + pr_warn("unexpected address family %u\n", addr->sa_family); + } + + return false; +} +EXPORT_SYMBOL(inet_addr_is_any); + +void inet_proto_csum_replace4(__sum16 *sum, struct sk_buff *skb, + __be32 from, __be32 to, bool pseudohdr) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + csum_replace4(sum, from, to); + if (skb->ip_summed == CHECKSUM_COMPLETE && pseudohdr) + skb->csum = ~csum_add(csum_sub(~(skb->csum), + (__force __wsum)from), + (__force __wsum)to); + } else if (pseudohdr) + *sum = ~csum_fold(csum_add(csum_sub(csum_unfold(*sum), + (__force __wsum)from), + (__force __wsum)to)); +} +EXPORT_SYMBOL(inet_proto_csum_replace4); + +/** + * inet_proto_csum_replace16 - update layer 4 header checksum field + * @sum: Layer 4 header checksum field + * @skb: sk_buff for the packet + * @from: old IPv6 address + * @to: new IPv6 address + * @pseudohdr: True if layer 4 header checksum includes pseudoheader + * + * Update layer 4 header as per the update in IPv6 src/dst address. + * + * There is no need to update skb->csum in this function, because update in two + * fields a.) IPv6 src/dst address and b.) L4 header checksum cancels each other + * for skb->csum calculation. Whereas inet_proto_csum_replace4 function needs to + * update skb->csum, because update in 3 fields a.) IPv4 src/dst address, + * b.) IPv4 Header checksum and c.) L4 header checksum results in same diff as + * L4 Header checksum for skb->csum calculation. 
+ */ +void inet_proto_csum_replace16(__sum16 *sum, struct sk_buff *skb, + const __be32 *from, const __be32 *to, + bool pseudohdr) +{ + __be32 diff[] = { + ~from[0], ~from[1], ~from[2], ~from[3], + to[0], to[1], to[2], to[3], + }; + if (skb->ip_summed != CHECKSUM_PARTIAL) { + *sum = csum_fold(csum_partial(diff, sizeof(diff), + ~csum_unfold(*sum))); + } else if (pseudohdr) + *sum = ~csum_fold(csum_partial(diff, sizeof(diff), + csum_unfold(*sum))); +} +EXPORT_SYMBOL(inet_proto_csum_replace16); + +void inet_proto_csum_replace_by_diff(__sum16 *sum, struct sk_buff *skb, + __wsum diff, bool pseudohdr) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + csum_replace_by_diff(sum, diff); + if (skb->ip_summed == CHECKSUM_COMPLETE && pseudohdr) + skb->csum = ~csum_sub(diff, skb->csum); + } else if (pseudohdr) { + *sum = ~csum_fold(csum_add(diff, csum_unfold(*sum))); + } +} +EXPORT_SYMBOL(inet_proto_csum_replace_by_diff); diff --git a/net/core/xdp.c b/net/core/xdp.c new file mode 100644 index 000000000..844c9d99d --- /dev/null +++ b/net/core/xdp.c @@ -0,0 +1,711 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/core/xdp.c + * + * Copyright (c) 2017 Jesper Dangaard Brouer, Red Hat Inc. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include /* struct xdp_mem_allocator */ +#include +#include + +#define REG_STATE_NEW 0x0 +#define REG_STATE_REGISTERED 0x1 +#define REG_STATE_UNREGISTERED 0x2 +#define REG_STATE_UNUSED 0x3 + +static DEFINE_IDA(mem_id_pool); +static DEFINE_MUTEX(mem_id_lock); +#define MEM_ID_MAX 0xFFFE +#define MEM_ID_MIN 1 +static int mem_id_next = MEM_ID_MIN; + +static bool mem_id_init; /* false */ +static struct rhashtable *mem_id_ht; + +static u32 xdp_mem_id_hashfn(const void *data, u32 len, u32 seed) +{ + const u32 *k = data; + const u32 key = *k; + + BUILD_BUG_ON(sizeof_field(struct xdp_mem_allocator, mem.id) + != sizeof(u32)); + + /* Use cyclic increasing ID as direct hash key */ + return key; +} + +static int xdp_mem_id_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct xdp_mem_allocator *xa = ptr; + u32 mem_id = *(u32 *)arg->key; + + return xa->mem.id != mem_id; +} + +static const struct rhashtable_params mem_id_rht_params = { + .nelem_hint = 64, + .head_offset = offsetof(struct xdp_mem_allocator, node), + .key_offset = offsetof(struct xdp_mem_allocator, mem.id), + .key_len = sizeof_field(struct xdp_mem_allocator, mem.id), + .max_size = MEM_ID_MAX, + .min_size = 8, + .automatic_shrinking = true, + .hashfn = xdp_mem_id_hashfn, + .obj_cmpfn = xdp_mem_id_cmp, +}; + +static void __xdp_mem_allocator_rcu_free(struct rcu_head *rcu) +{ + struct xdp_mem_allocator *xa; + + xa = container_of(rcu, struct xdp_mem_allocator, rcu); + + /* Allow this ID to be reused */ + ida_simple_remove(&mem_id_pool, xa->mem.id); + + kfree(xa); +} + +static void mem_xa_remove(struct xdp_mem_allocator *xa) +{ + trace_mem_disconnect(xa); + + if (!rhashtable_remove_fast(mem_id_ht, &xa->node, mem_id_rht_params)) + call_rcu(&xa->rcu, __xdp_mem_allocator_rcu_free); +} + +static void mem_allocator_disconnect(void *allocator) +{ + struct xdp_mem_allocator *xa; + struct rhashtable_iter iter; + + mutex_lock(&mem_id_lock); + + rhashtable_walk_enter(mem_id_ht, &iter); + do { + rhashtable_walk_start(&iter); + + while ((xa = rhashtable_walk_next(&iter)) && !IS_ERR(xa)) { + if (xa->allocator == allocator) + mem_xa_remove(xa); + } + + rhashtable_walk_stop(&iter); + + } while (xa == ERR_PTR(-EAGAIN)); + rhashtable_walk_exit(&iter); + + 
mutex_unlock(&mem_id_lock); +} + +void xdp_unreg_mem_model(struct xdp_mem_info *mem) +{ + struct xdp_mem_allocator *xa; + int type = mem->type; + int id = mem->id; + + /* Reset mem info to defaults */ + mem->id = 0; + mem->type = 0; + + if (id == 0) + return; + + if (type == MEM_TYPE_PAGE_POOL) { + rcu_read_lock(); + xa = rhashtable_lookup(mem_id_ht, &id, mem_id_rht_params); + page_pool_destroy(xa->page_pool); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL_GPL(xdp_unreg_mem_model); + +void xdp_rxq_info_unreg_mem_model(struct xdp_rxq_info *xdp_rxq) +{ + if (xdp_rxq->reg_state != REG_STATE_REGISTERED) { + WARN(1, "Missing register, driver bug"); + return; + } + + xdp_unreg_mem_model(&xdp_rxq->mem); +} +EXPORT_SYMBOL_GPL(xdp_rxq_info_unreg_mem_model); + +void xdp_rxq_info_unreg(struct xdp_rxq_info *xdp_rxq) +{ + /* Simplify driver cleanup code paths, allow unreg "unused" */ + if (xdp_rxq->reg_state == REG_STATE_UNUSED) + return; + + xdp_rxq_info_unreg_mem_model(xdp_rxq); + + xdp_rxq->reg_state = REG_STATE_UNREGISTERED; + xdp_rxq->dev = NULL; +} +EXPORT_SYMBOL_GPL(xdp_rxq_info_unreg); + +static void xdp_rxq_info_init(struct xdp_rxq_info *xdp_rxq) +{ + memset(xdp_rxq, 0, sizeof(*xdp_rxq)); +} + +/* Returns 0 on success, negative on failure */ +int __xdp_rxq_info_reg(struct xdp_rxq_info *xdp_rxq, + struct net_device *dev, u32 queue_index, + unsigned int napi_id, u32 frag_size) +{ + if (!dev) { + WARN(1, "Missing net_device from driver"); + return -ENODEV; + } + + if (xdp_rxq->reg_state == REG_STATE_UNUSED) { + WARN(1, "Driver promised not to register this"); + return -EINVAL; + } + + if (xdp_rxq->reg_state == REG_STATE_REGISTERED) { + WARN(1, "Missing unregister, handled but fix driver"); + xdp_rxq_info_unreg(xdp_rxq); + } + + /* State either UNREGISTERED or NEW */ + xdp_rxq_info_init(xdp_rxq); + xdp_rxq->dev = dev; + xdp_rxq->queue_index = queue_index; + xdp_rxq->napi_id = napi_id; + xdp_rxq->frag_size = frag_size; + + xdp_rxq->reg_state = REG_STATE_REGISTERED; + return 0; +} +EXPORT_SYMBOL_GPL(__xdp_rxq_info_reg); + +void xdp_rxq_info_unused(struct xdp_rxq_info *xdp_rxq) +{ + xdp_rxq->reg_state = REG_STATE_UNUSED; +} +EXPORT_SYMBOL_GPL(xdp_rxq_info_unused); + +bool xdp_rxq_info_is_reg(struct xdp_rxq_info *xdp_rxq) +{ + return (xdp_rxq->reg_state == REG_STATE_REGISTERED); +} +EXPORT_SYMBOL_GPL(xdp_rxq_info_is_reg); + +static int __mem_id_init_hash_table(void) +{ + struct rhashtable *rht; + int ret; + + if (unlikely(mem_id_init)) + return 0; + + rht = kzalloc(sizeof(*rht), GFP_KERNEL); + if (!rht) + return -ENOMEM; + + ret = rhashtable_init(rht, &mem_id_rht_params); + if (ret < 0) { + kfree(rht); + return ret; + } + mem_id_ht = rht; + smp_mb(); /* mutex lock should provide enough pairing */ + mem_id_init = true; + + return 0; +} + +/* Allocate a cyclic ID that maps to allocator pointer. + * See: https://www.kernel.org/doc/html/latest/core-api/idr.html + * + * Caller must lock mem_id_lock. 
+ */ +static int __mem_id_cyclic_get(gfp_t gfp) +{ + int retries = 1; + int id; + +again: + id = ida_simple_get(&mem_id_pool, mem_id_next, MEM_ID_MAX, gfp); + if (id < 0) { + if (id == -ENOSPC) { + /* Cyclic allocator, reset next id */ + if (retries--) { + mem_id_next = MEM_ID_MIN; + goto again; + } + } + return id; /* errno */ + } + mem_id_next = id + 1; + + return id; +} + +static bool __is_supported_mem_type(enum xdp_mem_type type) +{ + if (type == MEM_TYPE_PAGE_POOL) + return is_page_pool_compiled_in(); + + if (type >= MEM_TYPE_MAX) + return false; + + return true; +} + +static struct xdp_mem_allocator *__xdp_reg_mem_model(struct xdp_mem_info *mem, + enum xdp_mem_type type, + void *allocator) +{ + struct xdp_mem_allocator *xdp_alloc; + gfp_t gfp = GFP_KERNEL; + int id, errno, ret; + void *ptr; + + if (!__is_supported_mem_type(type)) + return ERR_PTR(-EOPNOTSUPP); + + mem->type = type; + + if (!allocator) { + if (type == MEM_TYPE_PAGE_POOL) + return ERR_PTR(-EINVAL); /* Setup time check page_pool req */ + return NULL; + } + + /* Delay init of rhashtable to save memory if feature isn't used */ + if (!mem_id_init) { + mutex_lock(&mem_id_lock); + ret = __mem_id_init_hash_table(); + mutex_unlock(&mem_id_lock); + if (ret < 0) { + WARN_ON(1); + return ERR_PTR(ret); + } + } + + xdp_alloc = kzalloc(sizeof(*xdp_alloc), gfp); + if (!xdp_alloc) + return ERR_PTR(-ENOMEM); + + mutex_lock(&mem_id_lock); + id = __mem_id_cyclic_get(gfp); + if (id < 0) { + errno = id; + goto err; + } + mem->id = id; + xdp_alloc->mem = *mem; + xdp_alloc->allocator = allocator; + + /* Insert allocator into ID lookup table */ + ptr = rhashtable_insert_slow(mem_id_ht, &id, &xdp_alloc->node); + if (IS_ERR(ptr)) { + ida_simple_remove(&mem_id_pool, mem->id); + mem->id = 0; + errno = PTR_ERR(ptr); + goto err; + } + + if (type == MEM_TYPE_PAGE_POOL) + page_pool_use_xdp_mem(allocator, mem_allocator_disconnect, mem); + + mutex_unlock(&mem_id_lock); + + return xdp_alloc; +err: + mutex_unlock(&mem_id_lock); + kfree(xdp_alloc); + return ERR_PTR(errno); +} + +int xdp_reg_mem_model(struct xdp_mem_info *mem, + enum xdp_mem_type type, void *allocator) +{ + struct xdp_mem_allocator *xdp_alloc; + + xdp_alloc = __xdp_reg_mem_model(mem, type, allocator); + if (IS_ERR(xdp_alloc)) + return PTR_ERR(xdp_alloc); + return 0; +} +EXPORT_SYMBOL_GPL(xdp_reg_mem_model); + +int xdp_rxq_info_reg_mem_model(struct xdp_rxq_info *xdp_rxq, + enum xdp_mem_type type, void *allocator) +{ + struct xdp_mem_allocator *xdp_alloc; + + if (xdp_rxq->reg_state != REG_STATE_REGISTERED) { + WARN(1, "Missing register, driver bug"); + return -EFAULT; + } + + xdp_alloc = __xdp_reg_mem_model(&xdp_rxq->mem, type, allocator); + if (IS_ERR(xdp_alloc)) + return PTR_ERR(xdp_alloc); + + if (trace_mem_connect_enabled() && xdp_alloc) + trace_mem_connect(xdp_alloc, xdp_rxq); + return 0; +} + +EXPORT_SYMBOL_GPL(xdp_rxq_info_reg_mem_model); + +/* XDP RX runs under NAPI protection, and in different delivery error + * scenarios (e.g. queue full), it is possible to return the xdp_frame + * while still leveraging this protection. The @napi_direct boolean + * is used for those calls sites. Thus, allowing for faster recycling + * of xdp_frames/pages in those cases. 
+ */ +void __xdp_return(void *data, struct xdp_mem_info *mem, bool napi_direct, + struct xdp_buff *xdp) +{ + struct page *page; + + switch (mem->type) { + case MEM_TYPE_PAGE_POOL: + page = virt_to_head_page(data); + if (napi_direct && xdp_return_frame_no_direct()) + napi_direct = false; + /* No need to check ((page->pp_magic & ~0x3UL) == PP_SIGNATURE) + * as mem->type knows this a page_pool page + */ + page_pool_put_full_page(page->pp, page, napi_direct); + break; + case MEM_TYPE_PAGE_SHARED: + page_frag_free(data); + break; + case MEM_TYPE_PAGE_ORDER0: + page = virt_to_page(data); /* Assumes order0 page*/ + put_page(page); + break; + case MEM_TYPE_XSK_BUFF_POOL: + /* NB! Only valid from an xdp_buff! */ + xsk_buff_free(xdp); + break; + default: + /* Not possible, checked in xdp_rxq_info_reg_mem_model() */ + WARN(1, "Incorrect XDP memory type (%d) usage", mem->type); + break; + } +} + +void xdp_return_frame(struct xdp_frame *xdpf) +{ + struct skb_shared_info *sinfo; + int i; + + if (likely(!xdp_frame_has_frags(xdpf))) + goto out; + + sinfo = xdp_get_shared_info_from_frame(xdpf); + for (i = 0; i < sinfo->nr_frags; i++) { + struct page *page = skb_frag_page(&sinfo->frags[i]); + + __xdp_return(page_address(page), &xdpf->mem, false, NULL); + } +out: + __xdp_return(xdpf->data, &xdpf->mem, false, NULL); +} +EXPORT_SYMBOL_GPL(xdp_return_frame); + +void xdp_return_frame_rx_napi(struct xdp_frame *xdpf) +{ + struct skb_shared_info *sinfo; + int i; + + if (likely(!xdp_frame_has_frags(xdpf))) + goto out; + + sinfo = xdp_get_shared_info_from_frame(xdpf); + for (i = 0; i < sinfo->nr_frags; i++) { + struct page *page = skb_frag_page(&sinfo->frags[i]); + + __xdp_return(page_address(page), &xdpf->mem, true, NULL); + } +out: + __xdp_return(xdpf->data, &xdpf->mem, true, NULL); +} +EXPORT_SYMBOL_GPL(xdp_return_frame_rx_napi); + +/* XDP bulk APIs introduce a defer/flush mechanism to return + * pages belonging to the same xdp_mem_allocator object + * (identified via the mem.id field) in bulk to optimize + * I-cache and D-cache. + * The bulk queue size is set to 16 to be aligned to how + * XDP_REDIRECT bulking works. The bulk is flushed when + * it is full or when mem.id changes. + * xdp_frame_bulk is usually stored/allocated on the function + * call-stack to avoid locking penalties. 
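+ *
+ * A minimal usage sketch (hypothetical driver code, not part of this
+ * patch): a driver's TX completion path returning frames in bulk under
+ * RCU protection could look like
+ *
+ *	struct xdp_frame_bulk bq;
+ *
+ *	xdp_frame_bulk_init(&bq);
+ *	rcu_read_lock();
+ *	for (each completed struct xdp_frame *xdpf)
+ *		xdp_return_frame_bulk(xdpf, &bq);
+ *	xdp_flush_frame_bulk(&bq);
+ *	rcu_read_unlock();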
+ */ +void xdp_flush_frame_bulk(struct xdp_frame_bulk *bq) +{ + struct xdp_mem_allocator *xa = bq->xa; + + if (unlikely(!xa || !bq->count)) + return; + + page_pool_put_page_bulk(xa->page_pool, bq->q, bq->count); + /* bq->xa is not cleared to save lookup, if mem.id same in next bulk */ + bq->count = 0; +} +EXPORT_SYMBOL_GPL(xdp_flush_frame_bulk); + +/* Must be called with rcu_read_lock held */ +void xdp_return_frame_bulk(struct xdp_frame *xdpf, + struct xdp_frame_bulk *bq) +{ + struct xdp_mem_info *mem = &xdpf->mem; + struct xdp_mem_allocator *xa; + + if (mem->type != MEM_TYPE_PAGE_POOL) { + xdp_return_frame(xdpf); + return; + } + + xa = bq->xa; + if (unlikely(!xa)) { + xa = rhashtable_lookup(mem_id_ht, &mem->id, mem_id_rht_params); + bq->count = 0; + bq->xa = xa; + } + + if (bq->count == XDP_BULK_QUEUE_SIZE) + xdp_flush_frame_bulk(bq); + + if (unlikely(mem->id != xa->mem.id)) { + xdp_flush_frame_bulk(bq); + bq->xa = rhashtable_lookup(mem_id_ht, &mem->id, mem_id_rht_params); + } + + if (unlikely(xdp_frame_has_frags(xdpf))) { + struct skb_shared_info *sinfo; + int i; + + sinfo = xdp_get_shared_info_from_frame(xdpf); + for (i = 0; i < sinfo->nr_frags; i++) { + skb_frag_t *frag = &sinfo->frags[i]; + + bq->q[bq->count++] = skb_frag_address(frag); + if (bq->count == XDP_BULK_QUEUE_SIZE) + xdp_flush_frame_bulk(bq); + } + } + bq->q[bq->count++] = xdpf->data; +} +EXPORT_SYMBOL_GPL(xdp_return_frame_bulk); + +void xdp_return_buff(struct xdp_buff *xdp) +{ + struct skb_shared_info *sinfo; + int i; + + if (likely(!xdp_buff_has_frags(xdp))) + goto out; + + sinfo = xdp_get_shared_info_from_buff(xdp); + for (i = 0; i < sinfo->nr_frags; i++) { + struct page *page = skb_frag_page(&sinfo->frags[i]); + + __xdp_return(page_address(page), &xdp->rxq->mem, true, xdp); + } +out: + __xdp_return(xdp->data, &xdp->rxq->mem, true, xdp); +} +EXPORT_SYMBOL_GPL(xdp_return_buff); + +/* Only called for MEM_TYPE_PAGE_POOL see xdp.h */ +void __xdp_release_frame(void *data, struct xdp_mem_info *mem) +{ + struct xdp_mem_allocator *xa; + struct page *page; + + rcu_read_lock(); + xa = rhashtable_lookup(mem_id_ht, &mem->id, mem_id_rht_params); + page = virt_to_head_page(data); + if (xa) + page_pool_release_page(xa->page_pool, page); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(__xdp_release_frame); + +void xdp_attachment_setup(struct xdp_attachment_info *info, + struct netdev_bpf *bpf) +{ + if (info->prog) + bpf_prog_put(info->prog); + info->prog = bpf->prog; + info->flags = bpf->flags; +} +EXPORT_SYMBOL_GPL(xdp_attachment_setup); + +struct xdp_frame *xdp_convert_zc_to_xdp_frame(struct xdp_buff *xdp) +{ + unsigned int metasize, totsize; + void *addr, *data_to_copy; + struct xdp_frame *xdpf; + struct page *page; + + /* Clone into a MEM_TYPE_PAGE_ORDER0 xdp_frame. */ + metasize = xdp_data_meta_unsupported(xdp) ? 0 : + xdp->data - xdp->data_meta; + totsize = xdp->data_end - xdp->data + metasize; + + if (sizeof(*xdpf) + totsize > PAGE_SIZE) + return NULL; + + page = dev_alloc_page(); + if (!page) + return NULL; + + addr = page_to_virt(page); + xdpf = addr; + memset(xdpf, 0, sizeof(*xdpf)); + + addr += sizeof(*xdpf); + data_to_copy = metasize ? 
xdp->data_meta : xdp->data; + memcpy(addr, data_to_copy, totsize); + + xdpf->data = addr + metasize; + xdpf->len = totsize - metasize; + xdpf->headroom = 0; + xdpf->metasize = metasize; + xdpf->frame_sz = PAGE_SIZE; + xdpf->mem.type = MEM_TYPE_PAGE_ORDER0; + + xsk_buff_free(xdp); + return xdpf; +} +EXPORT_SYMBOL_GPL(xdp_convert_zc_to_xdp_frame); + +/* Used by XDP_WARN macro, to avoid inlining WARN() in fast-path */ +void xdp_warn(const char *msg, const char *func, const int line) +{ + WARN(1, "XDP_WARN: %s(line:%d): %s\n", func, line, msg); +}; +EXPORT_SYMBOL_GPL(xdp_warn); + +int xdp_alloc_skb_bulk(void **skbs, int n_skb, gfp_t gfp) +{ + n_skb = kmem_cache_alloc_bulk(skbuff_head_cache, gfp, + n_skb, skbs); + if (unlikely(!n_skb)) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL_GPL(xdp_alloc_skb_bulk); + +struct sk_buff *__xdp_build_skb_from_frame(struct xdp_frame *xdpf, + struct sk_buff *skb, + struct net_device *dev) +{ + struct skb_shared_info *sinfo = xdp_get_shared_info_from_frame(xdpf); + unsigned int headroom, frame_size; + void *hard_start; + u8 nr_frags; + + /* xdp frags frame */ + if (unlikely(xdp_frame_has_frags(xdpf))) + nr_frags = sinfo->nr_frags; + + /* Part of headroom was reserved to xdpf */ + headroom = sizeof(*xdpf) + xdpf->headroom; + + /* Memory size backing xdp_frame data already have reserved + * room for build_skb to place skb_shared_info in tailroom. + */ + frame_size = xdpf->frame_sz; + + hard_start = xdpf->data - headroom; + skb = build_skb_around(skb, hard_start, frame_size); + if (unlikely(!skb)) + return NULL; + + skb_reserve(skb, headroom); + __skb_put(skb, xdpf->len); + if (xdpf->metasize) + skb_metadata_set(skb, xdpf->metasize); + + if (unlikely(xdp_frame_has_frags(xdpf))) + xdp_update_skb_shared_info(skb, nr_frags, + sinfo->xdp_frags_size, + nr_frags * xdpf->frame_sz, + xdp_frame_is_frag_pfmemalloc(xdpf)); + + /* Essential SKB info: protocol and skb->dev */ + skb->protocol = eth_type_trans(skb, dev); + + /* Optional SKB info, currently missing: + * - HW checksum info (skb->ip_summed) + * - HW RX hash (skb_set_hash) + * - RX ring dev queue index (skb_record_rx_queue) + */ + + /* Until page_pool get SKB return path, release DMA here */ + xdp_release_frame(xdpf); + + /* Allow SKB to reuse area used by xdp_frame */ + xdp_scrub_frame(xdpf); + + return skb; +} +EXPORT_SYMBOL_GPL(__xdp_build_skb_from_frame); + +struct sk_buff *xdp_build_skb_from_frame(struct xdp_frame *xdpf, + struct net_device *dev) +{ + struct sk_buff *skb; + + skb = kmem_cache_alloc(skbuff_head_cache, GFP_ATOMIC); + if (unlikely(!skb)) + return NULL; + + memset(skb, 0, offsetof(struct sk_buff, tail)); + + return __xdp_build_skb_from_frame(xdpf, skb, dev); +} +EXPORT_SYMBOL_GPL(xdp_build_skb_from_frame); + +struct xdp_frame *xdpf_clone(struct xdp_frame *xdpf) +{ + unsigned int headroom, totalsize; + struct xdp_frame *nxdpf; + struct page *page; + void *addr; + + headroom = xdpf->headroom + sizeof(*xdpf); + totalsize = headroom + xdpf->len; + + if (unlikely(totalsize > PAGE_SIZE)) + return NULL; + page = dev_alloc_page(); + if (!page) + return NULL; + addr = page_to_virt(page); + + memcpy(addr, xdpf, totalsize); + + nxdpf = addr; + nxdpf->data = addr + headroom; + nxdpf->frame_sz = PAGE_SIZE; + nxdpf->mem.type = MEM_TYPE_PAGE_ORDER0; + nxdpf->mem.id = 0; + + return nxdpf; +} diff --git a/net/dcb/Kconfig b/net/dcb/Kconfig new file mode 100644 index 000000000..efee8b9fe --- /dev/null +++ b/net/dcb/Kconfig @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: GPL-2.0-only +config DCB + bool "Data Center 
Bridging support" + default n + help + This enables support for configuring Data Center Bridging (DCB) + features on DCB capable Ethernet adapters via rtnetlink. Say 'Y' + if you have a DCB capable Ethernet adapter which supports this + interface and you are connected to a DCB capable switch. + + DCB is a collection of Ethernet enhancements which allow DCB capable + NICs and switches to support network traffic with differing + requirements (highly reliable, no drops vs. best effort vs. low + latency) to co-exist on Ethernet. + + DCB features include: + Enhanced Transmission Selection (aka Priority Grouping) - provides a + framework for assigning bandwidth guarantees to traffic classes. + Priority-based Flow Control (PFC) - a MAC control pause frame which + works at the granularity of the 802.1p priority instead of the + link (802.3x). + + If unsure, say N. diff --git a/net/dcb/Makefile b/net/dcb/Makefile new file mode 100644 index 000000000..2c0fa16ee --- /dev/null +++ b/net/dcb/Makefile @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-y += dcbnl.o dcbevent.o diff --git a/net/dcb/dcbevent.c b/net/dcb/dcbevent.c new file mode 100644 index 000000000..8620564c2 --- /dev/null +++ b/net/dcb/dcbevent.c @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2010, Intel Corporation. + * + * Author: John Fastabend + */ + +#include +#include +#include +#include + +static ATOMIC_NOTIFIER_HEAD(dcbevent_notif_chain); + +int register_dcbevent_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_register(&dcbevent_notif_chain, nb); +} +EXPORT_SYMBOL(register_dcbevent_notifier); + +int unregister_dcbevent_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_unregister(&dcbevent_notif_chain, nb); +} +EXPORT_SYMBOL(unregister_dcbevent_notifier); + +int call_dcbevent_notifiers(unsigned long val, void *v) +{ + return atomic_notifier_call_chain(&dcbevent_notif_chain, val, v); +} diff --git a/net/dcb/dcbnl.c b/net/dcb/dcbnl.c new file mode 100644 index 000000000..d2981e89d --- /dev/null +++ b/net/dcb/dcbnl.c @@ -0,0 +1,2127 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2011, Intel Corporation. + * + * Description: Data Center Bridging netlink interface + * Author: Lucy Liu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Data Center Bridging (DCB) is a collection of Ethernet enhancements + * intended to allow network traffic with differing requirements + * (highly reliable, no drops vs. best effort vs. low latency) to operate + * and co-exist on Ethernet. Current DCB features are: + * + * Enhanced Transmission Selection (aka Priority Grouping [PG]) - provides a + * framework for assigning bandwidth guarantees to traffic classes. + * + * Priority-based Flow Control (PFC) - provides a flow control mechanism which + * can work independently for each 802.1p priority. + * + * Congestion Notification - provides a mechanism for end-to-end congestion + * control for protocols which do not have built-in congestion management. + * + * More information about the emerging standards for these Ethernet features + * can be found at: http://www.ieee802.org/1/pages/dcbridges.html + * + * This file implements an rtnetlink interface to allow configuration of DCB + * features for capable devices. 
+ */ + +/**************** DCB attribute policies *************************************/ + +/* DCB netlink attributes policy */ +static const struct nla_policy dcbnl_rtnl_policy[DCB_ATTR_MAX + 1] = { + [DCB_ATTR_IFNAME] = {.type = NLA_NUL_STRING, .len = IFNAMSIZ - 1}, + [DCB_ATTR_STATE] = {.type = NLA_U8}, + [DCB_ATTR_PFC_CFG] = {.type = NLA_NESTED}, + [DCB_ATTR_PG_CFG] = {.type = NLA_NESTED}, + [DCB_ATTR_SET_ALL] = {.type = NLA_U8}, + [DCB_ATTR_PERM_HWADDR] = {.type = NLA_FLAG}, + [DCB_ATTR_CAP] = {.type = NLA_NESTED}, + [DCB_ATTR_PFC_STATE] = {.type = NLA_U8}, + [DCB_ATTR_BCN] = {.type = NLA_NESTED}, + [DCB_ATTR_APP] = {.type = NLA_NESTED}, + [DCB_ATTR_IEEE] = {.type = NLA_NESTED}, + [DCB_ATTR_DCBX] = {.type = NLA_U8}, + [DCB_ATTR_FEATCFG] = {.type = NLA_NESTED}, +}; + +/* DCB priority flow control to User Priority nested attributes */ +static const struct nla_policy dcbnl_pfc_up_nest[DCB_PFC_UP_ATTR_MAX + 1] = { + [DCB_PFC_UP_ATTR_0] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_1] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_2] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_3] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_4] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_5] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_6] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_7] = {.type = NLA_U8}, + [DCB_PFC_UP_ATTR_ALL] = {.type = NLA_FLAG}, +}; + +/* DCB priority grouping nested attributes */ +static const struct nla_policy dcbnl_pg_nest[DCB_PG_ATTR_MAX + 1] = { + [DCB_PG_ATTR_TC_0] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_1] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_2] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_3] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_4] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_5] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_6] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_7] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_TC_ALL] = {.type = NLA_NESTED}, + [DCB_PG_ATTR_BW_ID_0] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_1] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_2] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_3] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_4] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_5] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_6] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_7] = {.type = NLA_U8}, + [DCB_PG_ATTR_BW_ID_ALL] = {.type = NLA_FLAG}, +}; + +/* DCB traffic class nested attributes. */ +static const struct nla_policy dcbnl_tc_param_nest[DCB_TC_ATTR_PARAM_MAX + 1] = { + [DCB_TC_ATTR_PARAM_PGID] = {.type = NLA_U8}, + [DCB_TC_ATTR_PARAM_UP_MAPPING] = {.type = NLA_U8}, + [DCB_TC_ATTR_PARAM_STRICT_PRIO] = {.type = NLA_U8}, + [DCB_TC_ATTR_PARAM_BW_PCT] = {.type = NLA_U8}, + [DCB_TC_ATTR_PARAM_ALL] = {.type = NLA_FLAG}, +}; + +/* DCB capabilities nested attributes. */ +static const struct nla_policy dcbnl_cap_nest[DCB_CAP_ATTR_MAX + 1] = { + [DCB_CAP_ATTR_ALL] = {.type = NLA_FLAG}, + [DCB_CAP_ATTR_PG] = {.type = NLA_U8}, + [DCB_CAP_ATTR_PFC] = {.type = NLA_U8}, + [DCB_CAP_ATTR_UP2TC] = {.type = NLA_U8}, + [DCB_CAP_ATTR_PG_TCS] = {.type = NLA_U8}, + [DCB_CAP_ATTR_PFC_TCS] = {.type = NLA_U8}, + [DCB_CAP_ATTR_GSP] = {.type = NLA_U8}, + [DCB_CAP_ATTR_BCN] = {.type = NLA_U8}, + [DCB_CAP_ATTR_DCBX] = {.type = NLA_U8}, +}; + +/* DCB capabilities nested attributes. */ +static const struct nla_policy dcbnl_numtcs_nest[DCB_NUMTCS_ATTR_MAX + 1] = { + [DCB_NUMTCS_ATTR_ALL] = {.type = NLA_FLAG}, + [DCB_NUMTCS_ATTR_PG] = {.type = NLA_U8}, + [DCB_NUMTCS_ATTR_PFC] = {.type = NLA_U8}, +}; + +/* DCB BCN nested attributes. 
*/ +static const struct nla_policy dcbnl_bcn_nest[DCB_BCN_ATTR_MAX + 1] = { + [DCB_BCN_ATTR_RP_0] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_1] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_2] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_3] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_4] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_5] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_6] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_7] = {.type = NLA_U8}, + [DCB_BCN_ATTR_RP_ALL] = {.type = NLA_FLAG}, + [DCB_BCN_ATTR_BCNA_0] = {.type = NLA_U32}, + [DCB_BCN_ATTR_BCNA_1] = {.type = NLA_U32}, + [DCB_BCN_ATTR_ALPHA] = {.type = NLA_U32}, + [DCB_BCN_ATTR_BETA] = {.type = NLA_U32}, + [DCB_BCN_ATTR_GD] = {.type = NLA_U32}, + [DCB_BCN_ATTR_GI] = {.type = NLA_U32}, + [DCB_BCN_ATTR_TMAX] = {.type = NLA_U32}, + [DCB_BCN_ATTR_TD] = {.type = NLA_U32}, + [DCB_BCN_ATTR_RMIN] = {.type = NLA_U32}, + [DCB_BCN_ATTR_W] = {.type = NLA_U32}, + [DCB_BCN_ATTR_RD] = {.type = NLA_U32}, + [DCB_BCN_ATTR_RU] = {.type = NLA_U32}, + [DCB_BCN_ATTR_WRTT] = {.type = NLA_U32}, + [DCB_BCN_ATTR_RI] = {.type = NLA_U32}, + [DCB_BCN_ATTR_C] = {.type = NLA_U32}, + [DCB_BCN_ATTR_ALL] = {.type = NLA_FLAG}, +}; + +/* DCB APP nested attributes. */ +static const struct nla_policy dcbnl_app_nest[DCB_APP_ATTR_MAX + 1] = { + [DCB_APP_ATTR_IDTYPE] = {.type = NLA_U8}, + [DCB_APP_ATTR_ID] = {.type = NLA_U16}, + [DCB_APP_ATTR_PRIORITY] = {.type = NLA_U8}, +}; + +/* IEEE 802.1Qaz nested attributes. */ +static const struct nla_policy dcbnl_ieee_policy[DCB_ATTR_IEEE_MAX + 1] = { + [DCB_ATTR_IEEE_ETS] = {.len = sizeof(struct ieee_ets)}, + [DCB_ATTR_IEEE_PFC] = {.len = sizeof(struct ieee_pfc)}, + [DCB_ATTR_IEEE_APP_TABLE] = {.type = NLA_NESTED}, + [DCB_ATTR_IEEE_MAXRATE] = {.len = sizeof(struct ieee_maxrate)}, + [DCB_ATTR_IEEE_QCN] = {.len = sizeof(struct ieee_qcn)}, + [DCB_ATTR_IEEE_QCN_STATS] = {.len = sizeof(struct ieee_qcn_stats)}, + [DCB_ATTR_DCB_BUFFER] = {.len = sizeof(struct dcbnl_buffer)}, +}; + +/* DCB number of traffic classes nested attributes. 
*/ +static const struct nla_policy dcbnl_featcfg_nest[DCB_FEATCFG_ATTR_MAX + 1] = { + [DCB_FEATCFG_ATTR_ALL] = {.type = NLA_FLAG}, + [DCB_FEATCFG_ATTR_PG] = {.type = NLA_U8}, + [DCB_FEATCFG_ATTR_PFC] = {.type = NLA_U8}, + [DCB_FEATCFG_ATTR_APP] = {.type = NLA_U8}, +}; + +static LIST_HEAD(dcb_app_list); +static DEFINE_SPINLOCK(dcb_lock); + +static struct sk_buff *dcbnl_newmsg(int type, u8 cmd, u32 port, u32 seq, + u32 flags, struct nlmsghdr **nlhp) +{ + struct sk_buff *skb; + struct dcbmsg *dcb; + struct nlmsghdr *nlh; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return NULL; + + nlh = nlmsg_put(skb, port, seq, type, sizeof(*dcb), flags); + BUG_ON(!nlh); + + dcb = nlmsg_data(nlh); + dcb->dcb_family = AF_UNSPEC; + dcb->cmd = cmd; + dcb->dcb_pad = 0; + + if (nlhp) + *nlhp = nlh; + + return skb; +} + +static int dcbnl_getstate(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + /* if (!tb[DCB_ATTR_STATE] || !netdev->dcbnl_ops->getstate) */ + if (!netdev->dcbnl_ops->getstate) + return -EOPNOTSUPP; + + return nla_put_u8(skb, DCB_ATTR_STATE, + netdev->dcbnl_ops->getstate(netdev)); +} + +static int dcbnl_getpfccfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_PFC_UP_ATTR_MAX + 1], *nest; + u8 value; + int ret; + int i; + int getall = 0; + + if (!tb[DCB_ATTR_PFC_CFG]) + return -EINVAL; + + if (!netdev->dcbnl_ops->getpfccfg) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_PFC_UP_ATTR_MAX, + tb[DCB_ATTR_PFC_CFG], + dcbnl_pfc_up_nest, NULL); + if (ret) + return ret; + + nest = nla_nest_start_noflag(skb, DCB_ATTR_PFC_CFG); + if (!nest) + return -EMSGSIZE; + + if (data[DCB_PFC_UP_ATTR_ALL]) + getall = 1; + + for (i = DCB_PFC_UP_ATTR_0; i <= DCB_PFC_UP_ATTR_7; i++) { + if (!getall && !data[i]) + continue; + + netdev->dcbnl_ops->getpfccfg(netdev, i - DCB_PFC_UP_ATTR_0, + &value); + ret = nla_put_u8(skb, i, value); + if (ret) { + nla_nest_cancel(skb, nest); + return ret; + } + } + nla_nest_end(skb, nest); + + return 0; +} + +static int dcbnl_getperm_hwaddr(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + u8 perm_addr[MAX_ADDR_LEN]; + + if (!netdev->dcbnl_ops->getpermhwaddr) + return -EOPNOTSUPP; + + memset(perm_addr, 0, sizeof(perm_addr)); + netdev->dcbnl_ops->getpermhwaddr(netdev, perm_addr); + + return nla_put(skb, DCB_ATTR_PERM_HWADDR, sizeof(perm_addr), perm_addr); +} + +static int dcbnl_getcap(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_CAP_ATTR_MAX + 1], *nest; + u8 value; + int ret; + int i; + int getall = 0; + + if (!tb[DCB_ATTR_CAP]) + return -EINVAL; + + if (!netdev->dcbnl_ops->getcap) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_CAP_ATTR_MAX, + tb[DCB_ATTR_CAP], dcbnl_cap_nest, + NULL); + if (ret) + return ret; + + nest = nla_nest_start_noflag(skb, DCB_ATTR_CAP); + if (!nest) + return -EMSGSIZE; + + if (data[DCB_CAP_ATTR_ALL]) + getall = 1; + + for (i = DCB_CAP_ATTR_ALL+1; i <= DCB_CAP_ATTR_MAX; i++) { + if (!getall && !data[i]) + continue; + + if (!netdev->dcbnl_ops->getcap(netdev, i, &value)) { + ret = nla_put_u8(skb, i, value); + if (ret) { + nla_nest_cancel(skb, nest); + return ret; + } + } + } + nla_nest_end(skb, nest); + + return 0; +} + +static int dcbnl_getnumtcs(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, 
struct sk_buff *skb) +{ + struct nlattr *data[DCB_NUMTCS_ATTR_MAX + 1], *nest; + u8 value; + int ret; + int i; + int getall = 0; + + if (!tb[DCB_ATTR_NUMTCS]) + return -EINVAL; + + if (!netdev->dcbnl_ops->getnumtcs) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_NUMTCS_ATTR_MAX, + tb[DCB_ATTR_NUMTCS], + dcbnl_numtcs_nest, NULL); + if (ret) + return ret; + + nest = nla_nest_start_noflag(skb, DCB_ATTR_NUMTCS); + if (!nest) + return -EMSGSIZE; + + if (data[DCB_NUMTCS_ATTR_ALL]) + getall = 1; + + for (i = DCB_NUMTCS_ATTR_ALL+1; i <= DCB_NUMTCS_ATTR_MAX; i++) { + if (!getall && !data[i]) + continue; + + ret = netdev->dcbnl_ops->getnumtcs(netdev, i, &value); + if (!ret) { + ret = nla_put_u8(skb, i, value); + if (ret) { + nla_nest_cancel(skb, nest); + return ret; + } + } else + return -EINVAL; + } + nla_nest_end(skb, nest); + + return 0; +} + +static int dcbnl_setnumtcs(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_NUMTCS_ATTR_MAX + 1]; + int ret; + u8 value; + int i; + + if (!tb[DCB_ATTR_NUMTCS]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setnumtcs) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_NUMTCS_ATTR_MAX, + tb[DCB_ATTR_NUMTCS], + dcbnl_numtcs_nest, NULL); + if (ret) + return ret; + + for (i = DCB_NUMTCS_ATTR_ALL+1; i <= DCB_NUMTCS_ATTR_MAX; i++) { + if (data[i] == NULL) + continue; + + value = nla_get_u8(data[i]); + + ret = netdev->dcbnl_ops->setnumtcs(netdev, i, value); + if (ret) + break; + } + + return nla_put_u8(skb, DCB_ATTR_NUMTCS, !!ret); +} + +static int dcbnl_getpfcstate(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + if (!netdev->dcbnl_ops->getpfcstate) + return -EOPNOTSUPP; + + return nla_put_u8(skb, DCB_ATTR_PFC_STATE, + netdev->dcbnl_ops->getpfcstate(netdev)); +} + +static int dcbnl_setpfcstate(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + u8 value; + + if (!tb[DCB_ATTR_PFC_STATE]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setpfcstate) + return -EOPNOTSUPP; + + value = nla_get_u8(tb[DCB_ATTR_PFC_STATE]); + + netdev->dcbnl_ops->setpfcstate(netdev, value); + + return nla_put_u8(skb, DCB_ATTR_PFC_STATE, 0); +} + +static int dcbnl_getapp(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *app_nest; + struct nlattr *app_tb[DCB_APP_ATTR_MAX + 1]; + u16 id; + u8 up, idtype; + int ret; + + if (!tb[DCB_ATTR_APP]) + return -EINVAL; + + ret = nla_parse_nested_deprecated(app_tb, DCB_APP_ATTR_MAX, + tb[DCB_ATTR_APP], dcbnl_app_nest, + NULL); + if (ret) + return ret; + + /* all must be non-null */ + if ((!app_tb[DCB_APP_ATTR_IDTYPE]) || + (!app_tb[DCB_APP_ATTR_ID])) + return -EINVAL; + + /* either by eth type or by socket number */ + idtype = nla_get_u8(app_tb[DCB_APP_ATTR_IDTYPE]); + if ((idtype != DCB_APP_IDTYPE_ETHTYPE) && + (idtype != DCB_APP_IDTYPE_PORTNUM)) + return -EINVAL; + + id = nla_get_u16(app_tb[DCB_APP_ATTR_ID]); + + if (netdev->dcbnl_ops->getapp) { + ret = netdev->dcbnl_ops->getapp(netdev, idtype, id); + if (ret < 0) + return ret; + else + up = ret; + } else { + struct dcb_app app = { + .selector = idtype, + .protocol = id, + }; + up = dcb_getapp(netdev, &app); + } + + app_nest = nla_nest_start_noflag(skb, DCB_ATTR_APP); + if (!app_nest) + return -EMSGSIZE; + + ret = nla_put_u8(skb, DCB_APP_ATTR_IDTYPE, idtype); + if (ret) + goto out_cancel; + + ret = 
nla_put_u16(skb, DCB_APP_ATTR_ID, id); + if (ret) + goto out_cancel; + + ret = nla_put_u8(skb, DCB_APP_ATTR_PRIORITY, up); + if (ret) + goto out_cancel; + + nla_nest_end(skb, app_nest); + + return 0; + +out_cancel: + nla_nest_cancel(skb, app_nest); + return ret; +} + +static int dcbnl_setapp(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + int ret; + u16 id; + u8 up, idtype; + struct nlattr *app_tb[DCB_APP_ATTR_MAX + 1]; + + if (!tb[DCB_ATTR_APP]) + return -EINVAL; + + ret = nla_parse_nested_deprecated(app_tb, DCB_APP_ATTR_MAX, + tb[DCB_ATTR_APP], dcbnl_app_nest, + NULL); + if (ret) + return ret; + + /* all must be non-null */ + if ((!app_tb[DCB_APP_ATTR_IDTYPE]) || + (!app_tb[DCB_APP_ATTR_ID]) || + (!app_tb[DCB_APP_ATTR_PRIORITY])) + return -EINVAL; + + /* either by eth type or by socket number */ + idtype = nla_get_u8(app_tb[DCB_APP_ATTR_IDTYPE]); + if ((idtype != DCB_APP_IDTYPE_ETHTYPE) && + (idtype != DCB_APP_IDTYPE_PORTNUM)) + return -EINVAL; + + id = nla_get_u16(app_tb[DCB_APP_ATTR_ID]); + up = nla_get_u8(app_tb[DCB_APP_ATTR_PRIORITY]); + + if (netdev->dcbnl_ops->setapp) { + ret = netdev->dcbnl_ops->setapp(netdev, idtype, id, up); + if (ret < 0) + return ret; + } else { + struct dcb_app app; + app.selector = idtype; + app.protocol = id; + app.priority = up; + ret = dcb_setapp(netdev, &app); + } + + ret = nla_put_u8(skb, DCB_ATTR_APP, ret); + dcbnl_cee_notify(netdev, RTM_SETDCB, DCB_CMD_SAPP, seq, 0); + + return ret; +} + +static int __dcbnl_pg_getcfg(struct net_device *netdev, struct nlmsghdr *nlh, + struct nlattr **tb, struct sk_buff *skb, int dir) +{ + struct nlattr *pg_nest, *param_nest, *data; + struct nlattr *pg_tb[DCB_PG_ATTR_MAX + 1]; + struct nlattr *param_tb[DCB_TC_ATTR_PARAM_MAX + 1]; + u8 prio, pgid, tc_pct, up_map; + int ret; + int getall = 0; + int i; + + if (!tb[DCB_ATTR_PG_CFG]) + return -EINVAL; + + if (!netdev->dcbnl_ops->getpgtccfgtx || + !netdev->dcbnl_ops->getpgtccfgrx || + !netdev->dcbnl_ops->getpgbwgcfgtx || + !netdev->dcbnl_ops->getpgbwgcfgrx) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(pg_tb, DCB_PG_ATTR_MAX, + tb[DCB_ATTR_PG_CFG], dcbnl_pg_nest, + NULL); + if (ret) + return ret; + + pg_nest = nla_nest_start_noflag(skb, DCB_ATTR_PG_CFG); + if (!pg_nest) + return -EMSGSIZE; + + if (pg_tb[DCB_PG_ATTR_TC_ALL]) + getall = 1; + + for (i = DCB_PG_ATTR_TC_0; i <= DCB_PG_ATTR_TC_7; i++) { + if (!getall && !pg_tb[i]) + continue; + + if (pg_tb[DCB_PG_ATTR_TC_ALL]) + data = pg_tb[DCB_PG_ATTR_TC_ALL]; + else + data = pg_tb[i]; + ret = nla_parse_nested_deprecated(param_tb, + DCB_TC_ATTR_PARAM_MAX, data, + dcbnl_tc_param_nest, NULL); + if (ret) + goto err_pg; + + param_nest = nla_nest_start_noflag(skb, i); + if (!param_nest) + goto err_pg; + + pgid = DCB_ATTR_VALUE_UNDEFINED; + prio = DCB_ATTR_VALUE_UNDEFINED; + tc_pct = DCB_ATTR_VALUE_UNDEFINED; + up_map = DCB_ATTR_VALUE_UNDEFINED; + + if (dir) { + /* Rx */ + netdev->dcbnl_ops->getpgtccfgrx(netdev, + i - DCB_PG_ATTR_TC_0, &prio, + &pgid, &tc_pct, &up_map); + } else { + /* Tx */ + netdev->dcbnl_ops->getpgtccfgtx(netdev, + i - DCB_PG_ATTR_TC_0, &prio, + &pgid, &tc_pct, &up_map); + } + + if (param_tb[DCB_TC_ATTR_PARAM_PGID] || + param_tb[DCB_TC_ATTR_PARAM_ALL]) { + ret = nla_put_u8(skb, + DCB_TC_ATTR_PARAM_PGID, pgid); + if (ret) + goto err_param; + } + if (param_tb[DCB_TC_ATTR_PARAM_UP_MAPPING] || + param_tb[DCB_TC_ATTR_PARAM_ALL]) { + ret = nla_put_u8(skb, + DCB_TC_ATTR_PARAM_UP_MAPPING, up_map); + if (ret) + goto err_param; + } + if 
(param_tb[DCB_TC_ATTR_PARAM_STRICT_PRIO] || + param_tb[DCB_TC_ATTR_PARAM_ALL]) { + ret = nla_put_u8(skb, + DCB_TC_ATTR_PARAM_STRICT_PRIO, prio); + if (ret) + goto err_param; + } + if (param_tb[DCB_TC_ATTR_PARAM_BW_PCT] || + param_tb[DCB_TC_ATTR_PARAM_ALL]) { + ret = nla_put_u8(skb, DCB_TC_ATTR_PARAM_BW_PCT, + tc_pct); + if (ret) + goto err_param; + } + nla_nest_end(skb, param_nest); + } + + if (pg_tb[DCB_PG_ATTR_BW_ID_ALL]) + getall = 1; + else + getall = 0; + + for (i = DCB_PG_ATTR_BW_ID_0; i <= DCB_PG_ATTR_BW_ID_7; i++) { + if (!getall && !pg_tb[i]) + continue; + + tc_pct = DCB_ATTR_VALUE_UNDEFINED; + + if (dir) { + /* Rx */ + netdev->dcbnl_ops->getpgbwgcfgrx(netdev, + i - DCB_PG_ATTR_BW_ID_0, &tc_pct); + } else { + /* Tx */ + netdev->dcbnl_ops->getpgbwgcfgtx(netdev, + i - DCB_PG_ATTR_BW_ID_0, &tc_pct); + } + ret = nla_put_u8(skb, i, tc_pct); + if (ret) + goto err_pg; + } + + nla_nest_end(skb, pg_nest); + + return 0; + +err_param: + nla_nest_cancel(skb, param_nest); +err_pg: + nla_nest_cancel(skb, pg_nest); + + return -EMSGSIZE; +} + +static int dcbnl_pgtx_getcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + return __dcbnl_pg_getcfg(netdev, nlh, tb, skb, 0); +} + +static int dcbnl_pgrx_getcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + return __dcbnl_pg_getcfg(netdev, nlh, tb, skb, 1); +} + +static int dcbnl_setstate(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + u8 value; + + if (!tb[DCB_ATTR_STATE]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setstate) + return -EOPNOTSUPP; + + value = nla_get_u8(tb[DCB_ATTR_STATE]); + + return nla_put_u8(skb, DCB_ATTR_STATE, + netdev->dcbnl_ops->setstate(netdev, value)); +} + +static int dcbnl_setpfccfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_PFC_UP_ATTR_MAX + 1]; + int i; + int ret; + u8 value; + + if (!tb[DCB_ATTR_PFC_CFG]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setpfccfg) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_PFC_UP_ATTR_MAX, + tb[DCB_ATTR_PFC_CFG], + dcbnl_pfc_up_nest, NULL); + if (ret) + return ret; + + for (i = DCB_PFC_UP_ATTR_0; i <= DCB_PFC_UP_ATTR_7; i++) { + if (data[i] == NULL) + continue; + value = nla_get_u8(data[i]); + netdev->dcbnl_ops->setpfccfg(netdev, + data[i]->nla_type - DCB_PFC_UP_ATTR_0, value); + } + + return nla_put_u8(skb, DCB_ATTR_PFC_CFG, 0); +} + +static int dcbnl_setall(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + int ret; + + if (!tb[DCB_ATTR_SET_ALL]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setall) + return -EOPNOTSUPP; + + ret = nla_put_u8(skb, DCB_ATTR_SET_ALL, + netdev->dcbnl_ops->setall(netdev)); + dcbnl_cee_notify(netdev, RTM_SETDCB, DCB_CMD_SET_ALL, seq, 0); + + return ret; +} + +static int __dcbnl_pg_setcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb, + int dir) +{ + struct nlattr *pg_tb[DCB_PG_ATTR_MAX + 1]; + struct nlattr *param_tb[DCB_TC_ATTR_PARAM_MAX + 1]; + int ret; + int i; + u8 pgid; + u8 up_map; + u8 prio; + u8 tc_pct; + + if (!tb[DCB_ATTR_PG_CFG]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setpgtccfgtx || + !netdev->dcbnl_ops->setpgtccfgrx || + !netdev->dcbnl_ops->setpgbwgcfgtx || + !netdev->dcbnl_ops->setpgbwgcfgrx) + return -EOPNOTSUPP; + + ret = 
nla_parse_nested_deprecated(pg_tb, DCB_PG_ATTR_MAX, + tb[DCB_ATTR_PG_CFG], dcbnl_pg_nest, + NULL); + if (ret) + return ret; + + for (i = DCB_PG_ATTR_TC_0; i <= DCB_PG_ATTR_TC_7; i++) { + if (!pg_tb[i]) + continue; + + ret = nla_parse_nested_deprecated(param_tb, + DCB_TC_ATTR_PARAM_MAX, + pg_tb[i], + dcbnl_tc_param_nest, NULL); + if (ret) + return ret; + + pgid = DCB_ATTR_VALUE_UNDEFINED; + prio = DCB_ATTR_VALUE_UNDEFINED; + tc_pct = DCB_ATTR_VALUE_UNDEFINED; + up_map = DCB_ATTR_VALUE_UNDEFINED; + + if (param_tb[DCB_TC_ATTR_PARAM_STRICT_PRIO]) + prio = + nla_get_u8(param_tb[DCB_TC_ATTR_PARAM_STRICT_PRIO]); + + if (param_tb[DCB_TC_ATTR_PARAM_PGID]) + pgid = nla_get_u8(param_tb[DCB_TC_ATTR_PARAM_PGID]); + + if (param_tb[DCB_TC_ATTR_PARAM_BW_PCT]) + tc_pct = nla_get_u8(param_tb[DCB_TC_ATTR_PARAM_BW_PCT]); + + if (param_tb[DCB_TC_ATTR_PARAM_UP_MAPPING]) + up_map = + nla_get_u8(param_tb[DCB_TC_ATTR_PARAM_UP_MAPPING]); + + /* dir: Tx = 0, Rx = 1 */ + if (dir) { + /* Rx */ + netdev->dcbnl_ops->setpgtccfgrx(netdev, + i - DCB_PG_ATTR_TC_0, + prio, pgid, tc_pct, up_map); + } else { + /* Tx */ + netdev->dcbnl_ops->setpgtccfgtx(netdev, + i - DCB_PG_ATTR_TC_0, + prio, pgid, tc_pct, up_map); + } + } + + for (i = DCB_PG_ATTR_BW_ID_0; i <= DCB_PG_ATTR_BW_ID_7; i++) { + if (!pg_tb[i]) + continue; + + tc_pct = nla_get_u8(pg_tb[i]); + + /* dir: Tx = 0, Rx = 1 */ + if (dir) { + /* Rx */ + netdev->dcbnl_ops->setpgbwgcfgrx(netdev, + i - DCB_PG_ATTR_BW_ID_0, tc_pct); + } else { + /* Tx */ + netdev->dcbnl_ops->setpgbwgcfgtx(netdev, + i - DCB_PG_ATTR_BW_ID_0, tc_pct); + } + } + + return nla_put_u8(skb, DCB_ATTR_PG_CFG, 0); +} + +static int dcbnl_pgtx_setcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + return __dcbnl_pg_setcfg(netdev, nlh, seq, tb, skb, 0); +} + +static int dcbnl_pgrx_setcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + return __dcbnl_pg_setcfg(netdev, nlh, seq, tb, skb, 1); +} + +static int dcbnl_bcn_getcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *bcn_nest; + struct nlattr *bcn_tb[DCB_BCN_ATTR_MAX + 1]; + u8 value_byte; + u32 value_integer; + int ret; + bool getall = false; + int i; + + if (!tb[DCB_ATTR_BCN]) + return -EINVAL; + + if (!netdev->dcbnl_ops->getbcnrp || + !netdev->dcbnl_ops->getbcncfg) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(bcn_tb, DCB_BCN_ATTR_MAX, + tb[DCB_ATTR_BCN], dcbnl_bcn_nest, + NULL); + if (ret) + return ret; + + bcn_nest = nla_nest_start_noflag(skb, DCB_ATTR_BCN); + if (!bcn_nest) + return -EMSGSIZE; + + if (bcn_tb[DCB_BCN_ATTR_ALL]) + getall = true; + + for (i = DCB_BCN_ATTR_RP_0; i <= DCB_BCN_ATTR_RP_7; i++) { + if (!getall && !bcn_tb[i]) + continue; + + netdev->dcbnl_ops->getbcnrp(netdev, i - DCB_BCN_ATTR_RP_0, + &value_byte); + ret = nla_put_u8(skb, i, value_byte); + if (ret) + goto err_bcn; + } + + for (i = DCB_BCN_ATTR_BCNA_0; i <= DCB_BCN_ATTR_RI; i++) { + if (!getall && !bcn_tb[i]) + continue; + + netdev->dcbnl_ops->getbcncfg(netdev, i, + &value_integer); + ret = nla_put_u32(skb, i, value_integer); + if (ret) + goto err_bcn; + } + + nla_nest_end(skb, bcn_nest); + + return 0; + +err_bcn: + nla_nest_cancel(skb, bcn_nest); + return ret; +} + +static int dcbnl_bcn_setcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_BCN_ATTR_MAX + 1]; + int i; + int ret; + u8 value_byte; 
+ u32 value_int; + + if (!tb[DCB_ATTR_BCN]) + return -EINVAL; + + if (!netdev->dcbnl_ops->setbcncfg || + !netdev->dcbnl_ops->setbcnrp) + return -EOPNOTSUPP; + + ret = nla_parse_nested_deprecated(data, DCB_BCN_ATTR_MAX, + tb[DCB_ATTR_BCN], dcbnl_bcn_nest, + NULL); + if (ret) + return ret; + + for (i = DCB_BCN_ATTR_RP_0; i <= DCB_BCN_ATTR_RP_7; i++) { + if (data[i] == NULL) + continue; + value_byte = nla_get_u8(data[i]); + netdev->dcbnl_ops->setbcnrp(netdev, + data[i]->nla_type - DCB_BCN_ATTR_RP_0, value_byte); + } + + for (i = DCB_BCN_ATTR_BCNA_0; i <= DCB_BCN_ATTR_RI; i++) { + if (data[i] == NULL) + continue; + value_int = nla_get_u32(data[i]); + netdev->dcbnl_ops->setbcncfg(netdev, + i, value_int); + } + + return nla_put_u8(skb, DCB_ATTR_BCN, 0); +} + +static int dcbnl_build_peer_app(struct net_device *netdev, struct sk_buff* skb, + int app_nested_type, int app_info_type, + int app_entry_type) +{ + struct dcb_peer_app_info info; + struct dcb_app *table = NULL; + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + u16 app_count; + int err; + + + /** + * retrieve the peer app configuration form the driver. If the driver + * handlers fail exit without doing anything + */ + err = ops->peer_getappinfo(netdev, &info, &app_count); + if (!err && app_count) { + table = kmalloc_array(app_count, sizeof(struct dcb_app), + GFP_KERNEL); + if (!table) + return -ENOMEM; + + err = ops->peer_getapptable(netdev, table); + } + + if (!err) { + u16 i; + struct nlattr *app; + + /** + * build the message, from here on the only possible failure + * is due to the skb size + */ + err = -EMSGSIZE; + + app = nla_nest_start_noflag(skb, app_nested_type); + if (!app) + goto nla_put_failure; + + if (app_info_type && + nla_put(skb, app_info_type, sizeof(info), &info)) + goto nla_put_failure; + + for (i = 0; i < app_count; i++) { + if (nla_put(skb, app_entry_type, sizeof(struct dcb_app), + &table[i])) + goto nla_put_failure; + } + nla_nest_end(skb, app); + } + err = 0; + +nla_put_failure: + kfree(table); + return err; +} + +/* Handle IEEE 802.1Qaz/802.1Qau/802.1Qbb GET commands. 
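+ *
+ * The reply built by dcbnl_ieee_fill() below carries DCB_ATTR_IFNAME plus
+ * a nested DCB_ATTR_IEEE attribute (ETS, PFC, maxrate, QCN, buffer and
+ * the APP table) and, when the driver reports one, the DCBX version in
+ * DCB_ATTR_DCBX.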
*/ +static int dcbnl_ieee_fill(struct sk_buff *skb, struct net_device *netdev) +{ + struct nlattr *ieee, *app; + struct dcb_app_type *itr; + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + int dcbx; + int err; + + if (nla_put_string(skb, DCB_ATTR_IFNAME, netdev->name)) + return -EMSGSIZE; + + ieee = nla_nest_start_noflag(skb, DCB_ATTR_IEEE); + if (!ieee) + return -EMSGSIZE; + + if (ops->ieee_getets) { + struct ieee_ets ets; + memset(&ets, 0, sizeof(ets)); + err = ops->ieee_getets(netdev, &ets); + if (!err && + nla_put(skb, DCB_ATTR_IEEE_ETS, sizeof(ets), &ets)) + return -EMSGSIZE; + } + + if (ops->ieee_getmaxrate) { + struct ieee_maxrate maxrate; + memset(&maxrate, 0, sizeof(maxrate)); + err = ops->ieee_getmaxrate(netdev, &maxrate); + if (!err) { + err = nla_put(skb, DCB_ATTR_IEEE_MAXRATE, + sizeof(maxrate), &maxrate); + if (err) + return -EMSGSIZE; + } + } + + if (ops->ieee_getqcn) { + struct ieee_qcn qcn; + + memset(&qcn, 0, sizeof(qcn)); + err = ops->ieee_getqcn(netdev, &qcn); + if (!err) { + err = nla_put(skb, DCB_ATTR_IEEE_QCN, + sizeof(qcn), &qcn); + if (err) + return -EMSGSIZE; + } + } + + if (ops->ieee_getqcnstats) { + struct ieee_qcn_stats qcn_stats; + + memset(&qcn_stats, 0, sizeof(qcn_stats)); + err = ops->ieee_getqcnstats(netdev, &qcn_stats); + if (!err) { + err = nla_put(skb, DCB_ATTR_IEEE_QCN_STATS, + sizeof(qcn_stats), &qcn_stats); + if (err) + return -EMSGSIZE; + } + } + + if (ops->ieee_getpfc) { + struct ieee_pfc pfc; + memset(&pfc, 0, sizeof(pfc)); + err = ops->ieee_getpfc(netdev, &pfc); + if (!err && + nla_put(skb, DCB_ATTR_IEEE_PFC, sizeof(pfc), &pfc)) + return -EMSGSIZE; + } + + if (ops->dcbnl_getbuffer) { + struct dcbnl_buffer buffer; + + memset(&buffer, 0, sizeof(buffer)); + err = ops->dcbnl_getbuffer(netdev, &buffer); + if (!err && + nla_put(skb, DCB_ATTR_DCB_BUFFER, sizeof(buffer), &buffer)) + return -EMSGSIZE; + } + + app = nla_nest_start_noflag(skb, DCB_ATTR_IEEE_APP_TABLE); + if (!app) + return -EMSGSIZE; + + spin_lock_bh(&dcb_lock); + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->ifindex == netdev->ifindex) { + err = nla_put(skb, DCB_ATTR_IEEE_APP, sizeof(itr->app), + &itr->app); + if (err) { + spin_unlock_bh(&dcb_lock); + return -EMSGSIZE; + } + } + } + + if (netdev->dcbnl_ops->getdcbx) + dcbx = netdev->dcbnl_ops->getdcbx(netdev); + else + dcbx = -EOPNOTSUPP; + + spin_unlock_bh(&dcb_lock); + nla_nest_end(skb, app); + + /* get peer info if available */ + if (ops->ieee_peer_getets) { + struct ieee_ets ets; + memset(&ets, 0, sizeof(ets)); + err = ops->ieee_peer_getets(netdev, &ets); + if (!err && + nla_put(skb, DCB_ATTR_IEEE_PEER_ETS, sizeof(ets), &ets)) + return -EMSGSIZE; + } + + if (ops->ieee_peer_getpfc) { + struct ieee_pfc pfc; + memset(&pfc, 0, sizeof(pfc)); + err = ops->ieee_peer_getpfc(netdev, &pfc); + if (!err && + nla_put(skb, DCB_ATTR_IEEE_PEER_PFC, sizeof(pfc), &pfc)) + return -EMSGSIZE; + } + + if (ops->peer_getappinfo && ops->peer_getapptable) { + err = dcbnl_build_peer_app(netdev, skb, + DCB_ATTR_IEEE_PEER_APP, + DCB_ATTR_IEEE_APP_UNSPEC, + DCB_ATTR_IEEE_APP); + if (err) + return -EMSGSIZE; + } + + nla_nest_end(skb, ieee); + if (dcbx >= 0) { + err = nla_put_u8(skb, DCB_ATTR_DCBX, dcbx); + if (err) + return -EMSGSIZE; + } + + return 0; +} + +static int dcbnl_cee_pg_fill(struct sk_buff *skb, struct net_device *dev, + int dir) +{ + u8 pgid, up_map, prio, tc_pct; + const struct dcbnl_rtnl_ops *ops = dev->dcbnl_ops; + int i = dir ? 
DCB_ATTR_CEE_TX_PG : DCB_ATTR_CEE_RX_PG; + struct nlattr *pg = nla_nest_start_noflag(skb, i); + + if (!pg) + return -EMSGSIZE; + + for (i = DCB_PG_ATTR_TC_0; i <= DCB_PG_ATTR_TC_7; i++) { + struct nlattr *tc_nest = nla_nest_start_noflag(skb, i); + + if (!tc_nest) + return -EMSGSIZE; + + pgid = DCB_ATTR_VALUE_UNDEFINED; + prio = DCB_ATTR_VALUE_UNDEFINED; + tc_pct = DCB_ATTR_VALUE_UNDEFINED; + up_map = DCB_ATTR_VALUE_UNDEFINED; + + if (!dir) + ops->getpgtccfgrx(dev, i - DCB_PG_ATTR_TC_0, + &prio, &pgid, &tc_pct, &up_map); + else + ops->getpgtccfgtx(dev, i - DCB_PG_ATTR_TC_0, + &prio, &pgid, &tc_pct, &up_map); + + if (nla_put_u8(skb, DCB_TC_ATTR_PARAM_PGID, pgid) || + nla_put_u8(skb, DCB_TC_ATTR_PARAM_UP_MAPPING, up_map) || + nla_put_u8(skb, DCB_TC_ATTR_PARAM_STRICT_PRIO, prio) || + nla_put_u8(skb, DCB_TC_ATTR_PARAM_BW_PCT, tc_pct)) + return -EMSGSIZE; + nla_nest_end(skb, tc_nest); + } + + for (i = DCB_PG_ATTR_BW_ID_0; i <= DCB_PG_ATTR_BW_ID_7; i++) { + tc_pct = DCB_ATTR_VALUE_UNDEFINED; + + if (!dir) + ops->getpgbwgcfgrx(dev, i - DCB_PG_ATTR_BW_ID_0, + &tc_pct); + else + ops->getpgbwgcfgtx(dev, i - DCB_PG_ATTR_BW_ID_0, + &tc_pct); + if (nla_put_u8(skb, i, tc_pct)) + return -EMSGSIZE; + } + nla_nest_end(skb, pg); + return 0; +} + +static int dcbnl_cee_fill(struct sk_buff *skb, struct net_device *netdev) +{ + struct nlattr *cee, *app; + struct dcb_app_type *itr; + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + int dcbx, i, err = -EMSGSIZE; + u8 value; + + if (nla_put_string(skb, DCB_ATTR_IFNAME, netdev->name)) + goto nla_put_failure; + cee = nla_nest_start_noflag(skb, DCB_ATTR_CEE); + if (!cee) + goto nla_put_failure; + + /* local pg */ + if (ops->getpgtccfgtx && ops->getpgbwgcfgtx) { + err = dcbnl_cee_pg_fill(skb, netdev, 1); + if (err) + goto nla_put_failure; + } + + if (ops->getpgtccfgrx && ops->getpgbwgcfgrx) { + err = dcbnl_cee_pg_fill(skb, netdev, 0); + if (err) + goto nla_put_failure; + } + + /* local pfc */ + if (ops->getpfccfg) { + struct nlattr *pfc_nest = nla_nest_start_noflag(skb, + DCB_ATTR_CEE_PFC); + + if (!pfc_nest) + goto nla_put_failure; + + for (i = DCB_PFC_UP_ATTR_0; i <= DCB_PFC_UP_ATTR_7; i++) { + ops->getpfccfg(netdev, i - DCB_PFC_UP_ATTR_0, &value); + if (nla_put_u8(skb, i, value)) + goto nla_put_failure; + } + nla_nest_end(skb, pfc_nest); + } + + /* local app */ + spin_lock_bh(&dcb_lock); + app = nla_nest_start_noflag(skb, DCB_ATTR_CEE_APP_TABLE); + if (!app) + goto dcb_unlock; + + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->ifindex == netdev->ifindex) { + struct nlattr *app_nest = nla_nest_start_noflag(skb, + DCB_ATTR_APP); + if (!app_nest) + goto dcb_unlock; + + err = nla_put_u8(skb, DCB_APP_ATTR_IDTYPE, + itr->app.selector); + if (err) + goto dcb_unlock; + + err = nla_put_u16(skb, DCB_APP_ATTR_ID, + itr->app.protocol); + if (err) + goto dcb_unlock; + + err = nla_put_u8(skb, DCB_APP_ATTR_PRIORITY, + itr->app.priority); + if (err) + goto dcb_unlock; + + nla_nest_end(skb, app_nest); + } + } + nla_nest_end(skb, app); + + if (netdev->dcbnl_ops->getdcbx) + dcbx = netdev->dcbnl_ops->getdcbx(netdev); + else + dcbx = -EOPNOTSUPP; + + spin_unlock_bh(&dcb_lock); + + /* features flags */ + if (ops->getfeatcfg) { + struct nlattr *feat = nla_nest_start_noflag(skb, + DCB_ATTR_CEE_FEAT); + if (!feat) + goto nla_put_failure; + + for (i = DCB_FEATCFG_ATTR_ALL + 1; i <= DCB_FEATCFG_ATTR_MAX; + i++) + if (!ops->getfeatcfg(netdev, i, &value) && + nla_put_u8(skb, i, value)) + goto nla_put_failure; + + nla_nest_end(skb, feat); + } + + /* peer info if available */ + 
if (ops->cee_peer_getpg) { + struct cee_pg pg; + memset(&pg, 0, sizeof(pg)); + err = ops->cee_peer_getpg(netdev, &pg); + if (!err && + nla_put(skb, DCB_ATTR_CEE_PEER_PG, sizeof(pg), &pg)) + goto nla_put_failure; + } + + if (ops->cee_peer_getpfc) { + struct cee_pfc pfc; + memset(&pfc, 0, sizeof(pfc)); + err = ops->cee_peer_getpfc(netdev, &pfc); + if (!err && + nla_put(skb, DCB_ATTR_CEE_PEER_PFC, sizeof(pfc), &pfc)) + goto nla_put_failure; + } + + if (ops->peer_getappinfo && ops->peer_getapptable) { + err = dcbnl_build_peer_app(netdev, skb, + DCB_ATTR_CEE_PEER_APP_TABLE, + DCB_ATTR_CEE_PEER_APP_INFO, + DCB_ATTR_CEE_PEER_APP); + if (err) + goto nla_put_failure; + } + nla_nest_end(skb, cee); + + /* DCBX state */ + if (dcbx >= 0) { + err = nla_put_u8(skb, DCB_ATTR_DCBX, dcbx); + if (err) + goto nla_put_failure; + } + return 0; + +dcb_unlock: + spin_unlock_bh(&dcb_lock); +nla_put_failure: + err = -EMSGSIZE; + return err; +} + +static int dcbnl_notify(struct net_device *dev, int event, int cmd, + u32 seq, u32 portid, int dcbx_ver) +{ + struct net *net = dev_net(dev); + struct sk_buff *skb; + struct nlmsghdr *nlh; + const struct dcbnl_rtnl_ops *ops = dev->dcbnl_ops; + int err; + + if (!ops) + return -EOPNOTSUPP; + + skb = dcbnl_newmsg(event, cmd, portid, seq, 0, &nlh); + if (!skb) + return -ENOMEM; + + if (dcbx_ver == DCB_CAP_DCBX_VER_IEEE) + err = dcbnl_ieee_fill(skb, dev); + else + err = dcbnl_cee_fill(skb, dev); + + if (err < 0) { + /* Report error to broadcast listeners */ + nlmsg_free(skb); + rtnl_set_sk_err(net, RTNLGRP_DCB, err); + } else { + /* End nlmsg and notify broadcast listeners */ + nlmsg_end(skb, nlh); + rtnl_notify(skb, net, 0, RTNLGRP_DCB, NULL, GFP_KERNEL); + } + + return err; +} + +int dcbnl_ieee_notify(struct net_device *dev, int event, int cmd, + u32 seq, u32 portid) +{ + return dcbnl_notify(dev, event, cmd, seq, portid, DCB_CAP_DCBX_VER_IEEE); +} +EXPORT_SYMBOL(dcbnl_ieee_notify); + +int dcbnl_cee_notify(struct net_device *dev, int event, int cmd, + u32 seq, u32 portid) +{ + return dcbnl_notify(dev, event, cmd, seq, portid, DCB_CAP_DCBX_VER_CEE); +} +EXPORT_SYMBOL(dcbnl_cee_notify); + +/* Handle IEEE 802.1Qaz/802.1Qau/802.1Qbb SET commands. + * If any requested operation can not be completed + * the entire msg is aborted and error value is returned. + * No attempt is made to reconcile the case where only part of the + * cmd can be completed. 
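+ *
+ * A set request is an RTM_SETDCB message with cmd DCB_CMD_IEEE_SET that
+ * carries DCB_ATTR_IFNAME and a nested DCB_ATTR_IEEE attribute holding,
+ * for example, a struct ieee_ets in DCB_ATTR_IEEE_ETS or a struct
+ * ieee_pfc in DCB_ATTR_IEEE_PFC.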
+ */ +static int dcbnl_ieee_set(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + struct nlattr *ieee[DCB_ATTR_IEEE_MAX + 1]; + int prio; + int err; + + if (!ops) + return -EOPNOTSUPP; + + if (!tb[DCB_ATTR_IEEE]) + return -EINVAL; + + err = nla_parse_nested_deprecated(ieee, DCB_ATTR_IEEE_MAX, + tb[DCB_ATTR_IEEE], + dcbnl_ieee_policy, NULL); + if (err) + return err; + + if (ieee[DCB_ATTR_IEEE_ETS] && ops->ieee_setets) { + struct ieee_ets *ets = nla_data(ieee[DCB_ATTR_IEEE_ETS]); + err = ops->ieee_setets(netdev, ets); + if (err) + goto err; + } + + if (ieee[DCB_ATTR_IEEE_MAXRATE] && ops->ieee_setmaxrate) { + struct ieee_maxrate *maxrate = + nla_data(ieee[DCB_ATTR_IEEE_MAXRATE]); + err = ops->ieee_setmaxrate(netdev, maxrate); + if (err) + goto err; + } + + if (ieee[DCB_ATTR_IEEE_QCN] && ops->ieee_setqcn) { + struct ieee_qcn *qcn = + nla_data(ieee[DCB_ATTR_IEEE_QCN]); + + err = ops->ieee_setqcn(netdev, qcn); + if (err) + goto err; + } + + if (ieee[DCB_ATTR_IEEE_PFC] && ops->ieee_setpfc) { + struct ieee_pfc *pfc = nla_data(ieee[DCB_ATTR_IEEE_PFC]); + err = ops->ieee_setpfc(netdev, pfc); + if (err) + goto err; + } + + if (ieee[DCB_ATTR_DCB_BUFFER] && ops->dcbnl_setbuffer) { + struct dcbnl_buffer *buffer = + nla_data(ieee[DCB_ATTR_DCB_BUFFER]); + + for (prio = 0; prio < ARRAY_SIZE(buffer->prio2buffer); prio++) { + if (buffer->prio2buffer[prio] >= DCBX_MAX_BUFFERS) { + err = -EINVAL; + goto err; + } + } + + err = ops->dcbnl_setbuffer(netdev, buffer); + if (err) + goto err; + } + + if (ieee[DCB_ATTR_IEEE_APP_TABLE]) { + struct nlattr *attr; + int rem; + + nla_for_each_nested(attr, ieee[DCB_ATTR_IEEE_APP_TABLE], rem) { + struct dcb_app *app_data; + + if (nla_type(attr) != DCB_ATTR_IEEE_APP) + continue; + + if (nla_len(attr) < sizeof(struct dcb_app)) { + err = -ERANGE; + goto err; + } + + app_data = nla_data(attr); + if (ops->ieee_setapp) + err = ops->ieee_setapp(netdev, app_data); + else + err = dcb_ieee_setapp(netdev, app_data); + if (err) + goto err; + } + } + +err: + err = nla_put_u8(skb, DCB_ATTR_IEEE, err); + dcbnl_ieee_notify(netdev, RTM_SETDCB, DCB_CMD_IEEE_SET, seq, 0); + return err; +} + +static int dcbnl_ieee_get(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + + if (!ops) + return -EOPNOTSUPP; + + return dcbnl_ieee_fill(skb, netdev); +} + +static int dcbnl_ieee_del(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + struct nlattr *ieee[DCB_ATTR_IEEE_MAX + 1]; + int err; + + if (!ops) + return -EOPNOTSUPP; + + if (!tb[DCB_ATTR_IEEE]) + return -EINVAL; + + err = nla_parse_nested_deprecated(ieee, DCB_ATTR_IEEE_MAX, + tb[DCB_ATTR_IEEE], + dcbnl_ieee_policy, NULL); + if (err) + return err; + + if (ieee[DCB_ATTR_IEEE_APP_TABLE]) { + struct nlattr *attr; + int rem; + + nla_for_each_nested(attr, ieee[DCB_ATTR_IEEE_APP_TABLE], rem) { + struct dcb_app *app_data; + + if (nla_type(attr) != DCB_ATTR_IEEE_APP) + continue; + app_data = nla_data(attr); + if (ops->ieee_delapp) + err = ops->ieee_delapp(netdev, app_data); + else + err = dcb_ieee_delapp(netdev, app_data); + if (err) + goto err; + } + } + +err: + err = nla_put_u8(skb, DCB_ATTR_IEEE, err); + dcbnl_ieee_notify(netdev, RTM_SETDCB, DCB_CMD_IEEE_DEL, seq, 0); + return err; +} + + +/* DCBX configuration */ +static int 
dcbnl_getdcbx(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + if (!netdev->dcbnl_ops->getdcbx) + return -EOPNOTSUPP; + + return nla_put_u8(skb, DCB_ATTR_DCBX, + netdev->dcbnl_ops->getdcbx(netdev)); +} + +static int dcbnl_setdcbx(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + u8 value; + + if (!netdev->dcbnl_ops->setdcbx) + return -EOPNOTSUPP; + + if (!tb[DCB_ATTR_DCBX]) + return -EINVAL; + + value = nla_get_u8(tb[DCB_ATTR_DCBX]); + + return nla_put_u8(skb, DCB_ATTR_DCBX, + netdev->dcbnl_ops->setdcbx(netdev, value)); +} + +static int dcbnl_getfeatcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_FEATCFG_ATTR_MAX + 1], *nest; + u8 value; + int ret, i; + int getall = 0; + + if (!netdev->dcbnl_ops->getfeatcfg) + return -EOPNOTSUPP; + + if (!tb[DCB_ATTR_FEATCFG]) + return -EINVAL; + + ret = nla_parse_nested_deprecated(data, DCB_FEATCFG_ATTR_MAX, + tb[DCB_ATTR_FEATCFG], + dcbnl_featcfg_nest, NULL); + if (ret) + return ret; + + nest = nla_nest_start_noflag(skb, DCB_ATTR_FEATCFG); + if (!nest) + return -EMSGSIZE; + + if (data[DCB_FEATCFG_ATTR_ALL]) + getall = 1; + + for (i = DCB_FEATCFG_ATTR_ALL+1; i <= DCB_FEATCFG_ATTR_MAX; i++) { + if (!getall && !data[i]) + continue; + + ret = netdev->dcbnl_ops->getfeatcfg(netdev, i, &value); + if (!ret) + ret = nla_put_u8(skb, i, value); + + if (ret) { + nla_nest_cancel(skb, nest); + goto nla_put_failure; + } + } + nla_nest_end(skb, nest); + +nla_put_failure: + return ret; +} + +static int dcbnl_setfeatcfg(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + struct nlattr *data[DCB_FEATCFG_ATTR_MAX + 1]; + int ret, i; + u8 value; + + if (!netdev->dcbnl_ops->setfeatcfg) + return -ENOTSUPP; + + if (!tb[DCB_ATTR_FEATCFG]) + return -EINVAL; + + ret = nla_parse_nested_deprecated(data, DCB_FEATCFG_ATTR_MAX, + tb[DCB_ATTR_FEATCFG], + dcbnl_featcfg_nest, NULL); + + if (ret) + goto err; + + for (i = DCB_FEATCFG_ATTR_ALL+1; i <= DCB_FEATCFG_ATTR_MAX; i++) { + if (data[i] == NULL) + continue; + + value = nla_get_u8(data[i]); + + ret = netdev->dcbnl_ops->setfeatcfg(netdev, i, value); + + if (ret) + goto err; + } +err: + ret = nla_put_u8(skb, DCB_ATTR_FEATCFG, ret); + + return ret; +} + +/* Handle CEE DCBX GET commands. 
*/ +static int dcbnl_cee_get(struct net_device *netdev, struct nlmsghdr *nlh, + u32 seq, struct nlattr **tb, struct sk_buff *skb) +{ + const struct dcbnl_rtnl_ops *ops = netdev->dcbnl_ops; + + if (!ops) + return -EOPNOTSUPP; + + return dcbnl_cee_fill(skb, netdev); +} + +struct reply_func { + /* reply netlink message type */ + int type; + + /* function to fill message contents */ + int (*cb)(struct net_device *, struct nlmsghdr *, u32, + struct nlattr **, struct sk_buff *); +}; + +static const struct reply_func reply_funcs[DCB_CMD_MAX+1] = { + [DCB_CMD_GSTATE] = { RTM_GETDCB, dcbnl_getstate }, + [DCB_CMD_SSTATE] = { RTM_SETDCB, dcbnl_setstate }, + [DCB_CMD_PFC_GCFG] = { RTM_GETDCB, dcbnl_getpfccfg }, + [DCB_CMD_PFC_SCFG] = { RTM_SETDCB, dcbnl_setpfccfg }, + [DCB_CMD_GPERM_HWADDR] = { RTM_GETDCB, dcbnl_getperm_hwaddr }, + [DCB_CMD_GCAP] = { RTM_GETDCB, dcbnl_getcap }, + [DCB_CMD_GNUMTCS] = { RTM_GETDCB, dcbnl_getnumtcs }, + [DCB_CMD_SNUMTCS] = { RTM_SETDCB, dcbnl_setnumtcs }, + [DCB_CMD_PFC_GSTATE] = { RTM_GETDCB, dcbnl_getpfcstate }, + [DCB_CMD_PFC_SSTATE] = { RTM_SETDCB, dcbnl_setpfcstate }, + [DCB_CMD_GAPP] = { RTM_GETDCB, dcbnl_getapp }, + [DCB_CMD_SAPP] = { RTM_SETDCB, dcbnl_setapp }, + [DCB_CMD_PGTX_GCFG] = { RTM_GETDCB, dcbnl_pgtx_getcfg }, + [DCB_CMD_PGTX_SCFG] = { RTM_SETDCB, dcbnl_pgtx_setcfg }, + [DCB_CMD_PGRX_GCFG] = { RTM_GETDCB, dcbnl_pgrx_getcfg }, + [DCB_CMD_PGRX_SCFG] = { RTM_SETDCB, dcbnl_pgrx_setcfg }, + [DCB_CMD_SET_ALL] = { RTM_SETDCB, dcbnl_setall }, + [DCB_CMD_BCN_GCFG] = { RTM_GETDCB, dcbnl_bcn_getcfg }, + [DCB_CMD_BCN_SCFG] = { RTM_SETDCB, dcbnl_bcn_setcfg }, + [DCB_CMD_IEEE_GET] = { RTM_GETDCB, dcbnl_ieee_get }, + [DCB_CMD_IEEE_SET] = { RTM_SETDCB, dcbnl_ieee_set }, + [DCB_CMD_IEEE_DEL] = { RTM_SETDCB, dcbnl_ieee_del }, + [DCB_CMD_GDCBX] = { RTM_GETDCB, dcbnl_getdcbx }, + [DCB_CMD_SDCBX] = { RTM_SETDCB, dcbnl_setdcbx }, + [DCB_CMD_GFEATCFG] = { RTM_GETDCB, dcbnl_getfeatcfg }, + [DCB_CMD_SFEATCFG] = { RTM_SETDCB, dcbnl_setfeatcfg }, + [DCB_CMD_CEE_GET] = { RTM_GETDCB, dcbnl_cee_get }, +}; + +static int dcb_doit(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net_device *netdev; + struct dcbmsg *dcb = nlmsg_data(nlh); + struct nlattr *tb[DCB_ATTR_MAX + 1]; + u32 portid = NETLINK_CB(skb).portid; + int ret = -EINVAL; + struct sk_buff *reply_skb; + struct nlmsghdr *reply_nlh = NULL; + const struct reply_func *fn; + + if ((nlh->nlmsg_type == RTM_SETDCB) && !netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + ret = nlmsg_parse_deprecated(nlh, sizeof(*dcb), tb, DCB_ATTR_MAX, + dcbnl_rtnl_policy, extack); + if (ret < 0) + return ret; + + if (dcb->cmd > DCB_CMD_MAX) + return -EINVAL; + + /* check if a reply function has been defined for the command */ + fn = &reply_funcs[dcb->cmd]; + if (!fn->cb) + return -EOPNOTSUPP; + if (fn->type == RTM_SETDCB && !netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (!tb[DCB_ATTR_IFNAME]) + return -EINVAL; + + netdev = __dev_get_by_name(net, nla_data(tb[DCB_ATTR_IFNAME])); + if (!netdev) + return -ENODEV; + + if (!netdev->dcbnl_ops) + return -EOPNOTSUPP; + + reply_skb = dcbnl_newmsg(fn->type, dcb->cmd, portid, nlh->nlmsg_seq, + nlh->nlmsg_flags, &reply_nlh); + if (!reply_skb) + return -ENOMEM; + + ret = fn->cb(netdev, nlh, nlh->nlmsg_seq, tb, reply_skb); + if (ret < 0) { + nlmsg_free(reply_skb); + goto out; + } + + nlmsg_end(reply_skb, reply_nlh); + + ret = rtnl_unicast(reply_skb, net, portid); +out: + return ret; +} + +static struct dcb_app_type 
*dcb_app_lookup(const struct dcb_app *app, + int ifindex, int prio) +{ + struct dcb_app_type *itr; + + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->app.selector == app->selector && + itr->app.protocol == app->protocol && + itr->ifindex == ifindex && + ((prio == -1) || itr->app.priority == prio)) + return itr; + } + + return NULL; +} + +static int dcb_app_add(const struct dcb_app *app, int ifindex) +{ + struct dcb_app_type *entry; + + entry = kmalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + return -ENOMEM; + + memcpy(&entry->app, app, sizeof(*app)); + entry->ifindex = ifindex; + list_add(&entry->list, &dcb_app_list); + + return 0; +} + +/** + * dcb_getapp - retrieve the DCBX application user priority + * @dev: network interface + * @app: application to get user priority of + * + * On success returns a non-zero 802.1p user priority bitmap + * otherwise returns 0 as the invalid user priority bitmap to + * indicate an error. + */ +u8 dcb_getapp(struct net_device *dev, struct dcb_app *app) +{ + struct dcb_app_type *itr; + u8 prio = 0; + + spin_lock_bh(&dcb_lock); + itr = dcb_app_lookup(app, dev->ifindex, -1); + if (itr) + prio = itr->app.priority; + spin_unlock_bh(&dcb_lock); + + return prio; +} +EXPORT_SYMBOL(dcb_getapp); + +/** + * dcb_setapp - add CEE dcb application data to app list + * @dev: network interface + * @new: application data to add + * + * Priority 0 is an invalid priority in CEE spec. This routine + * removes applications from the app list if the priority is + * set to zero. Priority is expected to be 8-bit 802.1p user priority bitmap + */ +int dcb_setapp(struct net_device *dev, struct dcb_app *new) +{ + struct dcb_app_type *itr; + struct dcb_app_type event; + int err = 0; + + event.ifindex = dev->ifindex; + memcpy(&event.app, new, sizeof(event.app)); + if (dev->dcbnl_ops->getdcbx) + event.dcbx = dev->dcbnl_ops->getdcbx(dev); + + spin_lock_bh(&dcb_lock); + /* Search for existing match and replace */ + itr = dcb_app_lookup(new, dev->ifindex, -1); + if (itr) { + if (new->priority) + itr->app.priority = new->priority; + else { + list_del(&itr->list); + kfree(itr); + } + goto out; + } + /* App type does not exist add new application type */ + if (new->priority) + err = dcb_app_add(new, dev->ifindex); +out: + spin_unlock_bh(&dcb_lock); + if (!err) + call_dcbevent_notifiers(DCB_APP_EVENT, &event); + return err; +} +EXPORT_SYMBOL(dcb_setapp); + +/** + * dcb_ieee_getapp_mask - retrieve the IEEE DCB application priority + * @dev: network interface + * @app: where to store the retrieve application data + * + * Helper routine which on success returns a non-zero 802.1Qaz user + * priority bitmap otherwise returns 0 to indicate the dcb_app was + * not found in APP list. + */ +u8 dcb_ieee_getapp_mask(struct net_device *dev, struct dcb_app *app) +{ + struct dcb_app_type *itr; + u8 prio = 0; + + spin_lock_bh(&dcb_lock); + itr = dcb_app_lookup(app, dev->ifindex, -1); + if (itr) + prio |= 1 << itr->app.priority; + spin_unlock_bh(&dcb_lock); + + return prio; +} +EXPORT_SYMBOL(dcb_ieee_getapp_mask); + +/** + * dcb_ieee_setapp - add IEEE dcb application data to app list + * @dev: network interface + * @new: application data to add + * + * This adds Application data to the list. Multiple application + * entries may exists for the same selector and protocol as long + * as the priorities are different. 
Priority is expected to be a + * 3-bit unsigned integer + */ +int dcb_ieee_setapp(struct net_device *dev, struct dcb_app *new) +{ + struct dcb_app_type event; + int err = 0; + + event.ifindex = dev->ifindex; + memcpy(&event.app, new, sizeof(event.app)); + if (dev->dcbnl_ops->getdcbx) + event.dcbx = dev->dcbnl_ops->getdcbx(dev); + + spin_lock_bh(&dcb_lock); + /* Search for existing match and abort if found */ + if (dcb_app_lookup(new, dev->ifindex, new->priority)) { + err = -EEXIST; + goto out; + } + + err = dcb_app_add(new, dev->ifindex); +out: + spin_unlock_bh(&dcb_lock); + if (!err) + call_dcbevent_notifiers(DCB_APP_EVENT, &event); + return err; +} +EXPORT_SYMBOL(dcb_ieee_setapp); + +/** + * dcb_ieee_delapp - delete IEEE dcb application data from list + * @dev: network interface + * @del: application data to delete + * + * This removes a matching APP data from the APP list + */ +int dcb_ieee_delapp(struct net_device *dev, struct dcb_app *del) +{ + struct dcb_app_type *itr; + struct dcb_app_type event; + int err = -ENOENT; + + event.ifindex = dev->ifindex; + memcpy(&event.app, del, sizeof(event.app)); + if (dev->dcbnl_ops->getdcbx) + event.dcbx = dev->dcbnl_ops->getdcbx(dev); + + spin_lock_bh(&dcb_lock); + /* Search for existing match and remove it. */ + if ((itr = dcb_app_lookup(del, dev->ifindex, del->priority))) { + list_del(&itr->list); + kfree(itr); + err = 0; + } + + spin_unlock_bh(&dcb_lock); + if (!err) + call_dcbevent_notifiers(DCB_APP_EVENT, &event); + return err; +} +EXPORT_SYMBOL(dcb_ieee_delapp); + +/* + * dcb_ieee_getapp_prio_dscp_mask_map - For a given device, find mapping from + * priorities to the DSCP values assigned to that priority. Initialize p_map + * such that each map element holds a bit mask of DSCP values configured for + * that priority by APP entries. + */ +void dcb_ieee_getapp_prio_dscp_mask_map(const struct net_device *dev, + struct dcb_ieee_app_prio_map *p_map) +{ + int ifindex = dev->ifindex; + struct dcb_app_type *itr; + u8 prio; + + memset(p_map->map, 0, sizeof(p_map->map)); + + spin_lock_bh(&dcb_lock); + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->ifindex == ifindex && + itr->app.selector == IEEE_8021QAZ_APP_SEL_DSCP && + itr->app.protocol < 64 && + itr->app.priority < IEEE_8021QAZ_MAX_TCS) { + prio = itr->app.priority; + p_map->map[prio] |= 1ULL << itr->app.protocol; + } + } + spin_unlock_bh(&dcb_lock); +} +EXPORT_SYMBOL(dcb_ieee_getapp_prio_dscp_mask_map); + +/* + * dcb_ieee_getapp_dscp_prio_mask_map - For a given device, find mapping from + * DSCP values to the priorities assigned to that DSCP value. Initialize p_map + * such that each map element holds a bit mask of priorities configured for a + * given DSCP value by APP entries. + */ +void +dcb_ieee_getapp_dscp_prio_mask_map(const struct net_device *dev, + struct dcb_ieee_app_dscp_map *p_map) +{ + int ifindex = dev->ifindex; + struct dcb_app_type *itr; + + memset(p_map->map, 0, sizeof(p_map->map)); + + spin_lock_bh(&dcb_lock); + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->ifindex == ifindex && + itr->app.selector == IEEE_8021QAZ_APP_SEL_DSCP && + itr->app.protocol < 64 && + itr->app.priority < IEEE_8021QAZ_MAX_TCS) + p_map->map[itr->app.protocol] |= 1 << itr->app.priority; + } + spin_unlock_bh(&dcb_lock); +} +EXPORT_SYMBOL(dcb_ieee_getapp_dscp_prio_mask_map); + +/* + * Per 802.1Q-2014, the selector value of 1 is used for matching on Ethernet + * type, with valid PID values >= 1536. A special meaning is then assigned to + * protocol value of 0: "default priority. 
For use when priority is not + * otherwise specified". + * + * dcb_ieee_getapp_default_prio_mask - For a given device, find all APP entries + * of the form {$PRIO, ETHERTYPE, 0} and construct a bit mask of all default + * priorities set by these entries. + */ +u8 dcb_ieee_getapp_default_prio_mask(const struct net_device *dev) +{ + int ifindex = dev->ifindex; + struct dcb_app_type *itr; + u8 mask = 0; + + spin_lock_bh(&dcb_lock); + list_for_each_entry(itr, &dcb_app_list, list) { + if (itr->ifindex == ifindex && + itr->app.selector == IEEE_8021QAZ_APP_SEL_ETHERTYPE && + itr->app.protocol == 0 && + itr->app.priority < IEEE_8021QAZ_MAX_TCS) + mask |= 1 << itr->app.priority; + } + spin_unlock_bh(&dcb_lock); + + return mask; +} +EXPORT_SYMBOL(dcb_ieee_getapp_default_prio_mask); + +static void dcbnl_flush_dev(struct net_device *dev) +{ + struct dcb_app_type *itr, *tmp; + + spin_lock_bh(&dcb_lock); + + list_for_each_entry_safe(itr, tmp, &dcb_app_list, list) { + if (itr->ifindex == dev->ifindex) { + list_del(&itr->list); + kfree(itr); + } + } + + spin_unlock_bh(&dcb_lock); +} + +static int dcbnl_netdevice_event(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + switch (event) { + case NETDEV_UNREGISTER: + if (!dev->dcbnl_ops) + return NOTIFY_DONE; + + dcbnl_flush_dev(dev); + + return NOTIFY_OK; + default: + return NOTIFY_DONE; + } +} + +static struct notifier_block dcbnl_nb __read_mostly = { + .notifier_call = dcbnl_netdevice_event, +}; + +static int __init dcbnl_init(void) +{ + int err; + + err = register_netdevice_notifier(&dcbnl_nb); + if (err) + return err; + + rtnl_register(PF_UNSPEC, RTM_GETDCB, dcb_doit, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_SETDCB, dcb_doit, NULL, 0); + + return 0; +} +device_initcall(dcbnl_init); diff --git a/net/dccp/Kconfig b/net/dccp/Kconfig new file mode 100644 index 000000000..0c7d2f66b --- /dev/null +++ b/net/dccp/Kconfig @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig IP_DCCP + tristate "The DCCP Protocol" + depends on INET + help + Datagram Congestion Control Protocol (RFC 4340) + + From https://www.ietf.org/rfc/rfc4340.txt: + + The Datagram Congestion Control Protocol (DCCP) is a transport + protocol that implements bidirectional, unicast connections of + congestion-controlled, unreliable datagrams. It should be suitable + for use by applications such as streaming media, Internet telephony, + and on-line games. + + To compile this protocol support as a module, choose M here: the + module will be called dccp. + + If in doubt, say N. + +if IP_DCCP + +config INET_DCCP_DIAG + depends on INET_DIAG + def_tristate y if (IP_DCCP = y && INET_DIAG = y) + def_tristate m + +source "net/dccp/ccids/Kconfig" + +menu "DCCP Kernel Hacking" + depends on DEBUG_KERNEL=y + +config IP_DCCP_DEBUG + bool "DCCP debug messages" + help + Only use this if you're hacking DCCP. + + When compiling DCCP as a module, this debugging output can be toggled + by setting the parameter dccp_debug of the `dccp' module to 0 or 1. + + Just say N. + + +endmenu + +endif # IP_DDCP diff --git a/net/dccp/Makefile b/net/dccp/Makefile new file mode 100644 index 000000000..5b4ff37bc --- /dev/null +++ b/net/dccp/Makefile @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_IP_DCCP) += dccp.o dccp_ipv4.o + +dccp-y := ccid.o feat.o input.o minisocks.o options.o output.o proto.o timer.o \ + qpolicy.o +# +# CCID algorithms to be used by dccp.ko +# +# CCID-2 is default (RFC 4340, p. 
77) and has Ack Vectors as dependency +dccp-y += ccids/ccid2.o ackvec.o +dccp-$(CONFIG_IP_DCCP_CCID3) += ccids/ccid3.o +dccp-$(CONFIG_IP_DCCP_TFRC_LIB) += ccids/lib/tfrc.o \ + ccids/lib/tfrc_equation.o \ + ccids/lib/packet_history.o \ + ccids/lib/loss_interval.o + +dccp_ipv4-y := ipv4.o + +# build dccp_ipv6 as module whenever either IPv6 or DCCP is a module +obj-$(subst y,$(CONFIG_IP_DCCP),$(CONFIG_IPV6)) += dccp_ipv6.o +dccp_ipv6-y := ipv6.o + +obj-$(CONFIG_INET_DCCP_DIAG) += dccp_diag.o + +dccp-$(CONFIG_SYSCTL) += sysctl.o + +dccp_diag-y := diag.o + +# build with local directory for trace.h +CFLAGS_proto.o := -I$(src) diff --git a/net/dccp/ackvec.c b/net/dccp/ackvec.c new file mode 100644 index 000000000..c4bbac997 --- /dev/null +++ b/net/dccp/ackvec.c @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/ackvec.c + * + * An implementation of Ack Vectors for the DCCP protocol + * Copyright (c) 2007 University of Aberdeen, Scotland, UK + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ +#include "dccp.h" +#include +#include +#include + +static struct kmem_cache *dccp_ackvec_slab; +static struct kmem_cache *dccp_ackvec_record_slab; + +struct dccp_ackvec *dccp_ackvec_alloc(const gfp_t priority) +{ + struct dccp_ackvec *av = kmem_cache_zalloc(dccp_ackvec_slab, priority); + + if (av != NULL) { + av->av_buf_head = av->av_buf_tail = DCCPAV_MAX_ACKVEC_LEN - 1; + INIT_LIST_HEAD(&av->av_records); + } + return av; +} + +static void dccp_ackvec_purge_records(struct dccp_ackvec *av) +{ + struct dccp_ackvec_record *cur, *next; + + list_for_each_entry_safe(cur, next, &av->av_records, avr_node) + kmem_cache_free(dccp_ackvec_record_slab, cur); + INIT_LIST_HEAD(&av->av_records); +} + +void dccp_ackvec_free(struct dccp_ackvec *av) +{ + if (likely(av != NULL)) { + dccp_ackvec_purge_records(av); + kmem_cache_free(dccp_ackvec_slab, av); + } +} + +/** + * dccp_ackvec_update_records - Record information about sent Ack Vectors + * @av: Ack Vector records to update + * @seqno: Sequence number of the packet carrying the Ack Vector just sent + * @nonce_sum: The sum of all buffer nonces contained in the Ack Vector + */ +int dccp_ackvec_update_records(struct dccp_ackvec *av, u64 seqno, u8 nonce_sum) +{ + struct dccp_ackvec_record *avr; + + avr = kmem_cache_alloc(dccp_ackvec_record_slab, GFP_ATOMIC); + if (avr == NULL) + return -ENOBUFS; + + avr->avr_ack_seqno = seqno; + avr->avr_ack_ptr = av->av_buf_head; + avr->avr_ack_ackno = av->av_buf_ackno; + avr->avr_ack_nonce = nonce_sum; + avr->avr_ack_runlen = dccp_ackvec_runlen(av->av_buf + av->av_buf_head); + /* + * When the buffer overflows, we keep no more than one record. This is + * the simplest way of disambiguating sender-Acks dating from before the + * overflow from sender-Acks which refer to after the overflow; a simple + * solution is preferable here since we are handling an exception. + */ + if (av->av_overflow) + dccp_ackvec_purge_records(av); + /* + * Since GSS is incremented for each packet, the list is automatically + * arranged in descending order of @ack_seqno. 
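+	 * (list_add() inserts at the head of the list, so the most recent record,
+	 *  i.e. the one with the highest @ack_seqno, always sits first in @av_records.)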
+ */ + list_add(&avr->avr_node, &av->av_records); + + dccp_pr_debug("Added Vector, ack_seqno=%llu, ack_ackno=%llu (rl=%u)\n", + (unsigned long long)avr->avr_ack_seqno, + (unsigned long long)avr->avr_ack_ackno, + avr->avr_ack_runlen); + return 0; +} + +static struct dccp_ackvec_record *dccp_ackvec_lookup(struct list_head *av_list, + const u64 ackno) +{ + struct dccp_ackvec_record *avr; + /* + * Exploit that records are inserted in descending order of sequence + * number, start with the oldest record first. If @ackno is `before' + * the earliest ack_ackno, the packet is too old to be considered. + */ + list_for_each_entry_reverse(avr, av_list, avr_node) { + if (avr->avr_ack_seqno == ackno) + return avr; + if (before48(ackno, avr->avr_ack_seqno)) + break; + } + return NULL; +} + +/* + * Buffer index and length computation using modulo-buffersize arithmetic. + * Note that, as pointers move from right to left, head is `before' tail. + */ +static inline u16 __ackvec_idx_add(const u16 a, const u16 b) +{ + return (a + b) % DCCPAV_MAX_ACKVEC_LEN; +} + +static inline u16 __ackvec_idx_sub(const u16 a, const u16 b) +{ + return __ackvec_idx_add(a, DCCPAV_MAX_ACKVEC_LEN - b); +} + +u16 dccp_ackvec_buflen(const struct dccp_ackvec *av) +{ + if (unlikely(av->av_overflow)) + return DCCPAV_MAX_ACKVEC_LEN; + return __ackvec_idx_sub(av->av_buf_tail, av->av_buf_head); +} + +/** + * dccp_ackvec_update_old - Update previous state as per RFC 4340, 11.4.1 + * @av: non-empty buffer to update + * @distance: negative or zero distance of @seqno from buf_ackno downward + * @seqno: the (old) sequence number whose record is to be updated + * @state: state in which packet carrying @seqno was received + */ +static void dccp_ackvec_update_old(struct dccp_ackvec *av, s64 distance, + u64 seqno, enum dccp_ackvec_states state) +{ + u16 ptr = av->av_buf_head; + + BUG_ON(distance > 0); + if (unlikely(dccp_ackvec_is_empty(av))) + return; + + do { + u8 runlen = dccp_ackvec_runlen(av->av_buf + ptr); + + if (distance + runlen >= 0) { + /* + * Only update the state if packet has not been received + * yet. This is OK as per the second table in RFC 4340, + * 11.4.1; i.e. here we are using the following table: + * RECEIVED + * 0 1 3 + * S +---+---+---+ + * T 0 | 0 | 0 | 0 | + * O +---+---+---+ + * R 1 | 1 | 1 | 1 | + * E +---+---+---+ + * D 3 | 0 | 1 | 3 | + * +---+---+---+ + * The "Not Received" state was set by reserve_seats(). + */ + if (av->av_buf[ptr] == DCCPAV_NOT_RECEIVED) + av->av_buf[ptr] = state; + else + dccp_pr_debug("Not changing %llu state to %u\n", + (unsigned long long)seqno, state); + break; + } + + distance += runlen + 1; + ptr = __ackvec_idx_add(ptr, 1); + + } while (ptr != av->av_buf_tail); +} + +/* Mark @num entries after buf_head as "Not yet received". 
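+ * The range may wrap past the end of the circular buffer, in which case the
+ * memset below is split in two.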
*/ +static void dccp_ackvec_reserve_seats(struct dccp_ackvec *av, u16 num) +{ + u16 start = __ackvec_idx_add(av->av_buf_head, 1), + len = DCCPAV_MAX_ACKVEC_LEN - start; + + /* check for buffer wrap-around */ + if (num > len) { + memset(av->av_buf + start, DCCPAV_NOT_RECEIVED, len); + start = 0; + num -= len; + } + if (num) + memset(av->av_buf + start, DCCPAV_NOT_RECEIVED, num); +} + +/** + * dccp_ackvec_add_new - Record one or more new entries in Ack Vector buffer + * @av: container of buffer to update (can be empty or non-empty) + * @num_packets: number of packets to register (must be >= 1) + * @seqno: sequence number of the first packet in @num_packets + * @state: state in which packet carrying @seqno was received + */ +static void dccp_ackvec_add_new(struct dccp_ackvec *av, u32 num_packets, + u64 seqno, enum dccp_ackvec_states state) +{ + u32 num_cells = num_packets; + + if (num_packets > DCCPAV_BURST_THRESH) { + u32 lost_packets = num_packets - 1; + + DCCP_WARN("Warning: large burst loss (%u)\n", lost_packets); + /* + * We received 1 packet and have a loss of size "num_packets-1" + * which we squeeze into num_cells-1 rather than reserving an + * entire byte for each lost packet. + * The reason is that the vector grows in O(burst_length); when + * it grows too large there will no room left for the payload. + * This is a trade-off: if a few packets out of the burst show + * up later, their state will not be changed; it is simply too + * costly to reshuffle/reallocate/copy the buffer each time. + * Should such problems persist, we will need to switch to a + * different underlying data structure. + */ + for (num_packets = num_cells = 1; lost_packets; ++num_cells) { + u8 len = min_t(u32, lost_packets, DCCPAV_MAX_RUNLEN); + + av->av_buf_head = __ackvec_idx_sub(av->av_buf_head, 1); + av->av_buf[av->av_buf_head] = DCCPAV_NOT_RECEIVED | len; + + lost_packets -= len; + } + } + + if (num_cells + dccp_ackvec_buflen(av) >= DCCPAV_MAX_ACKVEC_LEN) { + DCCP_CRIT("Ack Vector buffer overflow: dropping old entries"); + av->av_overflow = true; + } + + av->av_buf_head = __ackvec_idx_sub(av->av_buf_head, num_packets); + if (av->av_overflow) + av->av_buf_tail = av->av_buf_head; + + av->av_buf[av->av_buf_head] = state; + av->av_buf_ackno = seqno; + + if (num_packets > 1) + dccp_ackvec_reserve_seats(av, num_packets - 1); +} + +/** + * dccp_ackvec_input - Register incoming packet in the buffer + * @av: Ack Vector to register packet to + * @skb: Packet to register + */ +void dccp_ackvec_input(struct dccp_ackvec *av, struct sk_buff *skb) +{ + u64 seqno = DCCP_SKB_CB(skb)->dccpd_seq; + enum dccp_ackvec_states state = DCCPAV_RECEIVED; + + if (dccp_ackvec_is_empty(av)) { + dccp_ackvec_add_new(av, 1, seqno, state); + av->av_tail_ackno = seqno; + + } else { + s64 num_packets = dccp_delta_seqno(av->av_buf_ackno, seqno); + u8 *current_head = av->av_buf + av->av_buf_head; + + if (num_packets == 1 && + dccp_ackvec_state(current_head) == state && + dccp_ackvec_runlen(current_head) < DCCPAV_MAX_RUNLEN) { + + *current_head += 1; + av->av_buf_ackno = seqno; + + } else if (num_packets > 0) { + dccp_ackvec_add_new(av, num_packets, seqno, state); + } else { + dccp_ackvec_update_old(av, num_packets, seqno, state); + } + } +} + +/** + * dccp_ackvec_clear_state - Perform house-keeping / garbage-collection + * @av: Ack Vector record to clean + * @ackno: last Ack Vector which has been acknowledged + * + * This routine is called when the peer acknowledges the receipt of Ack Vectors + * up to and including @ackno. 
While based on section A.3 of RFC 4340, here + * are additional precautions to prevent corrupted buffer state. In particular, + * we use tail_ackno to identify outdated records; it always marks the earliest + * packet of group (2) in 11.4.2. + */ +void dccp_ackvec_clear_state(struct dccp_ackvec *av, const u64 ackno) +{ + struct dccp_ackvec_record *avr, *next; + u8 runlen_now, eff_runlen; + s64 delta; + + avr = dccp_ackvec_lookup(&av->av_records, ackno); + if (avr == NULL) + return; + /* + * Deal with outdated acknowledgments: this arises when e.g. there are + * several old records and the acks from the peer come in slowly. In + * that case we may still have records that pre-date tail_ackno. + */ + delta = dccp_delta_seqno(av->av_tail_ackno, avr->avr_ack_ackno); + if (delta < 0) + goto free_records; + /* + * Deal with overlapping Ack Vectors: don't subtract more than the + * number of packets between tail_ackno and ack_ackno. + */ + eff_runlen = delta < avr->avr_ack_runlen ? delta : avr->avr_ack_runlen; + + runlen_now = dccp_ackvec_runlen(av->av_buf + avr->avr_ack_ptr); + /* + * The run length of Ack Vector cells does not decrease over time. If + * the run length is the same as at the time the Ack Vector was sent, we + * free the ack_ptr cell. That cell can however not be freed if the run + * length has increased: in this case we need to move the tail pointer + * backwards (towards higher indices), to its next-oldest neighbour. + */ + if (runlen_now > eff_runlen) { + + av->av_buf[avr->avr_ack_ptr] -= eff_runlen + 1; + av->av_buf_tail = __ackvec_idx_add(avr->avr_ack_ptr, 1); + + /* This move may not have cleared the overflow flag. */ + if (av->av_overflow) + av->av_overflow = (av->av_buf_head == av->av_buf_tail); + } else { + av->av_buf_tail = avr->avr_ack_ptr; + /* + * We have made sure that avr points to a valid cell within the + * buffer. This cell is either older than head, or equals head + * (empty buffer): in both cases we no longer have any overflow. + */ + av->av_overflow = 0; + } + + /* + * The peer has acknowledged up to and including ack_ackno. Hence the + * first packet in group (2) of 11.4.2 is the successor of ack_ackno. 
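+	 * (ADD48() keeps the increment within the 48-bit sequence-number space,
+	 *  i.e. the addition wraps modulo 2^48.)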
+ */ + av->av_tail_ackno = ADD48(avr->avr_ack_ackno, 1); + +free_records: + list_for_each_entry_safe_from(avr, next, &av->av_records, avr_node) { + list_del(&avr->avr_node); + kmem_cache_free(dccp_ackvec_record_slab, avr); + } +} + +/* + * Routines to keep track of Ack Vectors received in an skb + */ +int dccp_ackvec_parsed_add(struct list_head *head, u8 *vec, u8 len, u8 nonce) +{ + struct dccp_ackvec_parsed *new = kmalloc(sizeof(*new), GFP_ATOMIC); + + if (new == NULL) + return -ENOBUFS; + new->vec = vec; + new->len = len; + new->nonce = nonce; + + list_add_tail(&new->node, head); + return 0; +} +EXPORT_SYMBOL_GPL(dccp_ackvec_parsed_add); + +void dccp_ackvec_parsed_cleanup(struct list_head *parsed_chunks) +{ + struct dccp_ackvec_parsed *cur, *next; + + list_for_each_entry_safe(cur, next, parsed_chunks, node) + kfree(cur); + INIT_LIST_HEAD(parsed_chunks); +} +EXPORT_SYMBOL_GPL(dccp_ackvec_parsed_cleanup); + +int __init dccp_ackvec_init(void) +{ + dccp_ackvec_slab = kmem_cache_create("dccp_ackvec", + sizeof(struct dccp_ackvec), 0, + SLAB_HWCACHE_ALIGN, NULL); + if (dccp_ackvec_slab == NULL) + goto out_err; + + dccp_ackvec_record_slab = kmem_cache_create("dccp_ackvec_record", + sizeof(struct dccp_ackvec_record), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (dccp_ackvec_record_slab == NULL) + goto out_destroy_slab; + + return 0; + +out_destroy_slab: + kmem_cache_destroy(dccp_ackvec_slab); + dccp_ackvec_slab = NULL; +out_err: + DCCP_CRIT("Unable to create Ack Vector slab cache"); + return -ENOBUFS; +} + +void dccp_ackvec_exit(void) +{ + kmem_cache_destroy(dccp_ackvec_slab); + dccp_ackvec_slab = NULL; + kmem_cache_destroy(dccp_ackvec_record_slab); + dccp_ackvec_record_slab = NULL; +} diff --git a/net/dccp/ackvec.h b/net/dccp/ackvec.h new file mode 100644 index 000000000..d2c4220fb --- /dev/null +++ b/net/dccp/ackvec.h @@ -0,0 +1,136 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _ACKVEC_H +#define _ACKVEC_H +/* + * net/dccp/ackvec.h + * + * An implementation of Ack Vectors for the DCCP protocol + * Copyright (c) 2007 University of Aberdeen, Scotland, UK + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include + +/* + * Ack Vector buffer space is static, in multiples of %DCCP_SINGLE_OPT_MAXLEN, + * the maximum size of a single Ack Vector. Setting %DCCPAV_NUM_ACKVECS to 1 + * will be sufficient for most cases of low Ack Ratios, using a value of 2 gives + * more headroom if Ack Ratio is higher or when the sender acknowledges slowly. + * The maximum value is bounded by the u16 types for indices and functions. + */ +#define DCCPAV_NUM_ACKVECS 2 +#define DCCPAV_MAX_ACKVEC_LEN (DCCP_SINGLE_OPT_MAXLEN * DCCPAV_NUM_ACKVECS) + +/* Estimated minimum average Ack Vector length - used for updating MPS */ +#define DCCPAV_MIN_OPTLEN 16 + +/* Threshold for coping with large bursts of losses */ +#define DCCPAV_BURST_THRESH (DCCPAV_MAX_ACKVEC_LEN / 8) + +enum dccp_ackvec_states { + DCCPAV_RECEIVED = 0x00, + DCCPAV_ECN_MARKED = 0x40, + DCCPAV_RESERVED = 0x80, + DCCPAV_NOT_RECEIVED = 0xC0 +}; +#define DCCPAV_MAX_RUNLEN 0x3F + +static inline u8 dccp_ackvec_runlen(const u8 *cell) +{ + return *cell & DCCPAV_MAX_RUNLEN; +} + +static inline u8 dccp_ackvec_state(const u8 *cell) +{ + return *cell & ~DCCPAV_MAX_RUNLEN; +} + +/** + * struct dccp_ackvec - Ack Vector main data structure + * + * This implements a fixed-size circular buffer within an array and is largely + * based on Appendix A of RFC 4340. 
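+ * Each buffer cell carries a 2-bit state (enum dccp_ackvec_states) in its top
+ * bits and a 6-bit run length in its lower bits, cf. dccp_ackvec_state() and
+ * dccp_ackvec_runlen().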
+ * + * @av_buf: circular buffer storage area + * @av_buf_head: head index; begin of live portion in @av_buf + * @av_buf_tail: tail index; first index _after_ the live portion in @av_buf + * @av_buf_ackno: highest seqno of acknowledgeable packet recorded in @av_buf + * @av_tail_ackno: lowest seqno of acknowledgeable packet recorded in @av_buf + * @av_buf_nonce: ECN nonce sums, each covering subsequent segments of up to + * %DCCP_SINGLE_OPT_MAXLEN cells in the live portion of @av_buf + * @av_overflow: if 1 then buf_head == buf_tail indicates buffer wraparound + * @av_records: list of %dccp_ackvec_record (Ack Vectors sent previously) + */ +struct dccp_ackvec { + u8 av_buf[DCCPAV_MAX_ACKVEC_LEN]; + u16 av_buf_head; + u16 av_buf_tail; + u64 av_buf_ackno:48; + u64 av_tail_ackno:48; + bool av_buf_nonce[DCCPAV_NUM_ACKVECS]; + u8 av_overflow:1; + struct list_head av_records; +}; + +/** + * struct dccp_ackvec_record - Records information about sent Ack Vectors + * + * These list entries define the additional information which the HC-Receiver + * keeps about recently-sent Ack Vectors; again refer to RFC 4340, Appendix A. + * + * @avr_node: the list node in @av_records + * @avr_ack_seqno: sequence number of the packet the Ack Vector was sent on + * @avr_ack_ackno: the Ack number that this record/Ack Vector refers to + * @avr_ack_ptr: pointer into @av_buf where this record starts + * @avr_ack_runlen: run length of @avr_ack_ptr at the time of sending + * @avr_ack_nonce: the sum of @av_buf_nonce's at the time this record was sent + * + * The list as a whole is sorted in descending order by @avr_ack_seqno. + */ +struct dccp_ackvec_record { + struct list_head avr_node; + u64 avr_ack_seqno:48; + u64 avr_ack_ackno:48; + u16 avr_ack_ptr; + u8 avr_ack_runlen; + u8 avr_ack_nonce:1; +}; + +int dccp_ackvec_init(void); +void dccp_ackvec_exit(void); + +struct dccp_ackvec *dccp_ackvec_alloc(const gfp_t priority); +void dccp_ackvec_free(struct dccp_ackvec *av); + +void dccp_ackvec_input(struct dccp_ackvec *av, struct sk_buff *skb); +int dccp_ackvec_update_records(struct dccp_ackvec *av, u64 seq, u8 sum); +void dccp_ackvec_clear_state(struct dccp_ackvec *av, const u64 ackno); +u16 dccp_ackvec_buflen(const struct dccp_ackvec *av); + +static inline bool dccp_ackvec_is_empty(const struct dccp_ackvec *av) +{ + return av->av_overflow == 0 && av->av_buf_head == av->av_buf_tail; +} + +/** + * struct dccp_ackvec_parsed - Record offsets of Ack Vectors in skb + * @vec: start of vector (offset into skb) + * @len: length of @vec + * @nonce: whether @vec had an ECN nonce of 0 or 1 + * @node: FIFO - arranged in descending order of ack_ackno + * + * This structure is used by CCIDs to access Ack Vectors in a received skb. 
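+ * CCID-2, for instance, queues these entries via dccp_ackvec_parsed_add() and
+ * walks the resulting list in ccid2_hc_tx_packet_recv().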
+ */ +struct dccp_ackvec_parsed { + u8 *vec, + len, + nonce:1; + struct list_head node; +}; + +int dccp_ackvec_parsed_add(struct list_head *head, u8 *vec, u8 len, u8 nonce); +void dccp_ackvec_parsed_cleanup(struct list_head *parsed_chunks); +#endif /* _ACKVEC_H */ diff --git a/net/dccp/ccid.c b/net/dccp/ccid.c new file mode 100644 index 000000000..6beac5d34 --- /dev/null +++ b/net/dccp/ccid.c @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/ccid.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + * + * CCID infrastructure + */ + +#include + +#include "ccid.h" +#include "ccids/lib/tfrc.h" + +static struct ccid_operations *ccids[] = { + &ccid2_ops, +#ifdef CONFIG_IP_DCCP_CCID3 + &ccid3_ops, +#endif +}; + +static struct ccid_operations *ccid_by_number(const u8 id) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(ccids); i++) + if (ccids[i]->ccid_id == id) + return ccids[i]; + return NULL; +} + +/* check that up to @array_len members in @ccid_array are supported */ +bool ccid_support_check(u8 const *ccid_array, u8 array_len) +{ + while (array_len > 0) + if (ccid_by_number(ccid_array[--array_len]) == NULL) + return false; + return true; +} + +/** + * ccid_get_builtin_ccids - Populate a list of built-in CCIDs + * @ccid_array: pointer to copy into + * @array_len: value to return length into + * + * This function allocates memory - caller must see that it is freed after use. + */ +int ccid_get_builtin_ccids(u8 **ccid_array, u8 *array_len) +{ + *ccid_array = kmalloc(ARRAY_SIZE(ccids), gfp_any()); + if (*ccid_array == NULL) + return -ENOBUFS; + + for (*array_len = 0; *array_len < ARRAY_SIZE(ccids); *array_len += 1) + (*ccid_array)[*array_len] = ccids[*array_len]->ccid_id; + return 0; +} + +int ccid_getsockopt_builtin_ccids(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + u8 *ccid_array, array_len; + int err = 0; + + if (ccid_get_builtin_ccids(&ccid_array, &array_len)) + return -ENOBUFS; + + if (put_user(array_len, optlen)) + err = -EFAULT; + else if (len > 0 && copy_to_user(optval, ccid_array, + len > array_len ? array_len : len)) + err = -EFAULT; + + kfree(ccid_array); + return err; +} + +static __printf(3, 4) struct kmem_cache *ccid_kmem_cache_create(int obj_size, char *slab_name_fmt, const char *fmt,...) 
+{ + struct kmem_cache *slab; + va_list args; + + va_start(args, fmt); + vsnprintf(slab_name_fmt, CCID_SLAB_NAME_LENGTH, fmt, args); + va_end(args); + + slab = kmem_cache_create(slab_name_fmt, sizeof(struct ccid) + obj_size, 0, + SLAB_HWCACHE_ALIGN, NULL); + return slab; +} + +static void ccid_kmem_cache_destroy(struct kmem_cache *slab) +{ + kmem_cache_destroy(slab); +} + +static int __init ccid_activate(struct ccid_operations *ccid_ops) +{ + int err = -ENOBUFS; + + ccid_ops->ccid_hc_rx_slab = + ccid_kmem_cache_create(ccid_ops->ccid_hc_rx_obj_size, + ccid_ops->ccid_hc_rx_slab_name, + "ccid%u_hc_rx_sock", + ccid_ops->ccid_id); + if (ccid_ops->ccid_hc_rx_slab == NULL) + goto out; + + ccid_ops->ccid_hc_tx_slab = + ccid_kmem_cache_create(ccid_ops->ccid_hc_tx_obj_size, + ccid_ops->ccid_hc_tx_slab_name, + "ccid%u_hc_tx_sock", + ccid_ops->ccid_id); + if (ccid_ops->ccid_hc_tx_slab == NULL) + goto out_free_rx_slab; + + pr_info("DCCP: Activated CCID %d (%s)\n", + ccid_ops->ccid_id, ccid_ops->ccid_name); + err = 0; +out: + return err; +out_free_rx_slab: + ccid_kmem_cache_destroy(ccid_ops->ccid_hc_rx_slab); + ccid_ops->ccid_hc_rx_slab = NULL; + goto out; +} + +static void ccid_deactivate(struct ccid_operations *ccid_ops) +{ + ccid_kmem_cache_destroy(ccid_ops->ccid_hc_tx_slab); + ccid_ops->ccid_hc_tx_slab = NULL; + ccid_kmem_cache_destroy(ccid_ops->ccid_hc_rx_slab); + ccid_ops->ccid_hc_rx_slab = NULL; + + pr_info("DCCP: Deactivated CCID %d (%s)\n", + ccid_ops->ccid_id, ccid_ops->ccid_name); +} + +struct ccid *ccid_new(const u8 id, struct sock *sk, bool rx) +{ + struct ccid_operations *ccid_ops = ccid_by_number(id); + struct ccid *ccid = NULL; + + if (ccid_ops == NULL) + goto out; + + ccid = kmem_cache_alloc(rx ? ccid_ops->ccid_hc_rx_slab : + ccid_ops->ccid_hc_tx_slab, gfp_any()); + if (ccid == NULL) + goto out; + ccid->ccid_ops = ccid_ops; + if (rx) { + memset(ccid + 1, 0, ccid_ops->ccid_hc_rx_obj_size); + if (ccid->ccid_ops->ccid_hc_rx_init != NULL && + ccid->ccid_ops->ccid_hc_rx_init(ccid, sk) != 0) + goto out_free_ccid; + } else { + memset(ccid + 1, 0, ccid_ops->ccid_hc_tx_obj_size); + if (ccid->ccid_ops->ccid_hc_tx_init != NULL && + ccid->ccid_ops->ccid_hc_tx_init(ccid, sk) != 0) + goto out_free_ccid; + } +out: + return ccid; +out_free_ccid: + kmem_cache_free(rx ? 
ccid_ops->ccid_hc_rx_slab : + ccid_ops->ccid_hc_tx_slab, ccid); + ccid = NULL; + goto out; +} + +void ccid_hc_rx_delete(struct ccid *ccid, struct sock *sk) +{ + if (ccid != NULL) { + if (ccid->ccid_ops->ccid_hc_rx_exit != NULL) + ccid->ccid_ops->ccid_hc_rx_exit(sk); + kmem_cache_free(ccid->ccid_ops->ccid_hc_rx_slab, ccid); + } +} + +void ccid_hc_tx_delete(struct ccid *ccid, struct sock *sk) +{ + if (ccid != NULL) { + if (ccid->ccid_ops->ccid_hc_tx_exit != NULL) + ccid->ccid_ops->ccid_hc_tx_exit(sk); + kmem_cache_free(ccid->ccid_ops->ccid_hc_tx_slab, ccid); + } +} + +int __init ccid_initialize_builtins(void) +{ + int i, err = tfrc_lib_init(); + + if (err) + return err; + + for (i = 0; i < ARRAY_SIZE(ccids); i++) { + err = ccid_activate(ccids[i]); + if (err) + goto unwind_registrations; + } + return 0; + +unwind_registrations: + while(--i >= 0) + ccid_deactivate(ccids[i]); + tfrc_lib_exit(); + return err; +} + +void ccid_cleanup_builtins(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(ccids); i++) + ccid_deactivate(ccids[i]); + tfrc_lib_exit(); +} diff --git a/net/dccp/ccid.h b/net/dccp/ccid.h new file mode 100644 index 000000000..105f3734d --- /dev/null +++ b/net/dccp/ccid.h @@ -0,0 +1,262 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _CCID_H +#define _CCID_H +/* + * net/dccp/ccid.h + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + * + * CCID infrastructure + */ + +#include +#include +#include +#include +#include + +/* maximum value for a CCID (RFC 4340, 19.5) */ +#define CCID_MAX 255 +#define CCID_SLAB_NAME_LENGTH 32 + +struct tcp_info; + +/** + * struct ccid_operations - Interface to Congestion-Control Infrastructure + * + * @ccid_id: numerical CCID ID (up to %CCID_MAX, cf. table 5 in RFC 4340, 10.) + * @ccid_ccmps: the CCMPS including network/transport headers (0 when disabled) + * @ccid_name: alphabetical identifier string for @ccid_id + * @ccid_hc_{r,t}x_slab: memory pool for the receiver/sender half-connection + * @ccid_hc_{r,t}x_obj_size: size of the receiver/sender half-connection socket + * + * @ccid_hc_{r,t}x_init: CCID-specific initialisation routine (before startup) + * @ccid_hc_{r,t}x_exit: CCID-specific cleanup routine (before destruction) + * @ccid_hc_rx_packet_recv: implements the HC-receiver side + * @ccid_hc_{r,t}x_parse_options: parsing routine for CCID/HC-specific options + * @ccid_hc_{r,t}x_insert_options: insert routine for CCID/HC-specific options + * @ccid_hc_tx_packet_recv: implements feedback processing for the HC-sender + * @ccid_hc_tx_send_packet: implements the sending part of the HC-sender + * @ccid_hc_tx_packet_sent: does accounting for packets in flight by HC-sender + * @ccid_hc_{r,t}x_get_info: INET_DIAG information for HC-receiver/sender + * @ccid_hc_{r,t}x_getsockopt: socket options specific to HC-receiver/sender + */ +struct ccid_operations { + unsigned char ccid_id; + __u32 ccid_ccmps; + const char *ccid_name; + struct kmem_cache *ccid_hc_rx_slab, + *ccid_hc_tx_slab; + char ccid_hc_rx_slab_name[CCID_SLAB_NAME_LENGTH]; + char ccid_hc_tx_slab_name[CCID_SLAB_NAME_LENGTH]; + __u32 ccid_hc_rx_obj_size, + ccid_hc_tx_obj_size; + /* Interface Routines */ + int (*ccid_hc_rx_init)(struct ccid *ccid, struct sock *sk); + int (*ccid_hc_tx_init)(struct ccid *ccid, struct sock *sk); + void (*ccid_hc_rx_exit)(struct sock *sk); + void (*ccid_hc_tx_exit)(struct sock *sk); + void (*ccid_hc_rx_packet_recv)(struct sock *sk, + struct sk_buff *skb); + int (*ccid_hc_rx_parse_options)(struct sock *sk, u8 pkt, + u8 opt, u8 *val, u8 len); + 
int (*ccid_hc_rx_insert_options)(struct sock *sk, + struct sk_buff *skb); + void (*ccid_hc_tx_packet_recv)(struct sock *sk, + struct sk_buff *skb); + int (*ccid_hc_tx_parse_options)(struct sock *sk, u8 pkt, + u8 opt, u8 *val, u8 len); + int (*ccid_hc_tx_send_packet)(struct sock *sk, + struct sk_buff *skb); + void (*ccid_hc_tx_packet_sent)(struct sock *sk, + unsigned int len); + void (*ccid_hc_rx_get_info)(struct sock *sk, + struct tcp_info *info); + void (*ccid_hc_tx_get_info)(struct sock *sk, + struct tcp_info *info); + int (*ccid_hc_rx_getsockopt)(struct sock *sk, + const int optname, int len, + u32 __user *optval, + int __user *optlen); + int (*ccid_hc_tx_getsockopt)(struct sock *sk, + const int optname, int len, + u32 __user *optval, + int __user *optlen); +}; + +extern struct ccid_operations ccid2_ops; +#ifdef CONFIG_IP_DCCP_CCID3 +extern struct ccid_operations ccid3_ops; +#endif + +int ccid_initialize_builtins(void); +void ccid_cleanup_builtins(void); + +struct ccid { + struct ccid_operations *ccid_ops; + char ccid_priv[]; +}; + +static inline void *ccid_priv(const struct ccid *ccid) +{ + return (void *)ccid->ccid_priv; +} + +bool ccid_support_check(u8 const *ccid_array, u8 array_len); +int ccid_get_builtin_ccids(u8 **ccid_array, u8 *array_len); +int ccid_getsockopt_builtin_ccids(struct sock *sk, int len, + char __user *, int __user *); + +struct ccid *ccid_new(const u8 id, struct sock *sk, bool rx); + +static inline int ccid_get_current_rx_ccid(struct dccp_sock *dp) +{ + struct ccid *ccid = dp->dccps_hc_rx_ccid; + + if (ccid == NULL || ccid->ccid_ops == NULL) + return -1; + return ccid->ccid_ops->ccid_id; +} + +static inline int ccid_get_current_tx_ccid(struct dccp_sock *dp) +{ + struct ccid *ccid = dp->dccps_hc_tx_ccid; + + if (ccid == NULL || ccid->ccid_ops == NULL) + return -1; + return ccid->ccid_ops->ccid_id; +} + +void ccid_hc_rx_delete(struct ccid *ccid, struct sock *sk); +void ccid_hc_tx_delete(struct ccid *ccid, struct sock *sk); + +/* + * Congestion control of queued data packets via CCID decision. + * + * The TX CCID performs its congestion-control by indicating whether and when a + * queued packet may be sent, using the return code of ccid_hc_tx_send_packet(). + * The following modes are supported via the symbolic constants below: + * - timer-based pacing (CCID returns a delay value in milliseconds); + * - autonomous dequeueing (CCID internally schedules dccps_xmitlet). 
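+ * ccid_packet_dequeue_eval() below translates the raw return code of
+ * ccid_hc_tx_send_packet() into one of these constants.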
+ */ + +enum ccid_dequeueing_decision { + CCID_PACKET_SEND_AT_ONCE = 0x00000, /* "green light": no delay */ + CCID_PACKET_DELAY_MAX = 0x0FFFF, /* maximum delay in msecs */ + CCID_PACKET_DELAY = 0x10000, /* CCID msec-delay mode */ + CCID_PACKET_WILL_DEQUEUE_LATER = 0x20000, /* CCID autonomous mode */ + CCID_PACKET_ERR = 0xF0000, /* error condition */ +}; + +static inline int ccid_packet_dequeue_eval(const int return_code) +{ + if (return_code < 0) + return CCID_PACKET_ERR; + if (return_code == 0) + return CCID_PACKET_SEND_AT_ONCE; + if (return_code <= CCID_PACKET_DELAY_MAX) + return CCID_PACKET_DELAY; + return return_code; +} + +static inline int ccid_hc_tx_send_packet(struct ccid *ccid, struct sock *sk, + struct sk_buff *skb) +{ + if (ccid->ccid_ops->ccid_hc_tx_send_packet != NULL) + return ccid->ccid_ops->ccid_hc_tx_send_packet(sk, skb); + return CCID_PACKET_SEND_AT_ONCE; +} + +static inline void ccid_hc_tx_packet_sent(struct ccid *ccid, struct sock *sk, + unsigned int len) +{ + if (ccid->ccid_ops->ccid_hc_tx_packet_sent != NULL) + ccid->ccid_ops->ccid_hc_tx_packet_sent(sk, len); +} + +static inline void ccid_hc_rx_packet_recv(struct ccid *ccid, struct sock *sk, + struct sk_buff *skb) +{ + if (ccid->ccid_ops->ccid_hc_rx_packet_recv != NULL) + ccid->ccid_ops->ccid_hc_rx_packet_recv(sk, skb); +} + +static inline void ccid_hc_tx_packet_recv(struct ccid *ccid, struct sock *sk, + struct sk_buff *skb) +{ + if (ccid->ccid_ops->ccid_hc_tx_packet_recv != NULL) + ccid->ccid_ops->ccid_hc_tx_packet_recv(sk, skb); +} + +/** + * ccid_hc_tx_parse_options - Parse CCID-specific options sent by the receiver + * @pkt: type of packet that @opt appears on (RFC 4340, 5.1) + * @opt: the CCID-specific option type (RFC 4340, 5.8 and 10.3) + * @val: value of @opt + * @len: length of @val in bytes + */ +static inline int ccid_hc_tx_parse_options(struct ccid *ccid, struct sock *sk, + u8 pkt, u8 opt, u8 *val, u8 len) +{ + if (!ccid || !ccid->ccid_ops->ccid_hc_tx_parse_options) + return 0; + return ccid->ccid_ops->ccid_hc_tx_parse_options(sk, pkt, opt, val, len); +} + +/** + * ccid_hc_rx_parse_options - Parse CCID-specific options sent by the sender + * Arguments are analogous to ccid_hc_tx_parse_options() + */ +static inline int ccid_hc_rx_parse_options(struct ccid *ccid, struct sock *sk, + u8 pkt, u8 opt, u8 *val, u8 len) +{ + if (!ccid || !ccid->ccid_ops->ccid_hc_rx_parse_options) + return 0; + return ccid->ccid_ops->ccid_hc_rx_parse_options(sk, pkt, opt, val, len); +} + +static inline int ccid_hc_rx_insert_options(struct ccid *ccid, struct sock *sk, + struct sk_buff *skb) +{ + if (ccid->ccid_ops->ccid_hc_rx_insert_options != NULL) + return ccid->ccid_ops->ccid_hc_rx_insert_options(sk, skb); + return 0; +} + +static inline void ccid_hc_rx_get_info(struct ccid *ccid, struct sock *sk, + struct tcp_info *info) +{ + if (ccid->ccid_ops->ccid_hc_rx_get_info != NULL) + ccid->ccid_ops->ccid_hc_rx_get_info(sk, info); +} + +static inline void ccid_hc_tx_get_info(struct ccid *ccid, struct sock *sk, + struct tcp_info *info) +{ + if (ccid->ccid_ops->ccid_hc_tx_get_info != NULL) + ccid->ccid_ops->ccid_hc_tx_get_info(sk, info); +} + +static inline int ccid_hc_rx_getsockopt(struct ccid *ccid, struct sock *sk, + const int optname, int len, + u32 __user *optval, int __user *optlen) +{ + int rc = -ENOPROTOOPT; + if (ccid != NULL && ccid->ccid_ops->ccid_hc_rx_getsockopt != NULL) + rc = ccid->ccid_ops->ccid_hc_rx_getsockopt(sk, optname, len, + optval, optlen); + return rc; +} + +static inline int ccid_hc_tx_getsockopt(struct ccid 
*ccid, struct sock *sk, + const int optname, int len, + u32 __user *optval, int __user *optlen) +{ + int rc = -ENOPROTOOPT; + if (ccid != NULL && ccid->ccid_ops->ccid_hc_tx_getsockopt != NULL) + rc = ccid->ccid_ops->ccid_hc_tx_getsockopt(sk, optname, len, + optval, optlen); + return rc; +} +#endif /* _CCID_H */ diff --git a/net/dccp/ccids/Kconfig b/net/dccp/ccids/Kconfig new file mode 100644 index 000000000..a3eeb84d1 --- /dev/null +++ b/net/dccp/ccids/Kconfig @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: GPL-2.0-only +menu "DCCP CCIDs Configuration" + +config IP_DCCP_CCID2_DEBUG + bool "CCID-2 debugging messages" + help + Enable CCID-2 specific debugging messages. + + The debugging output can additionally be toggled by setting the + ccid2_debug parameter to 0 or 1. + + If in doubt, say N. + +config IP_DCCP_CCID3 + bool "CCID-3 (TCP-Friendly)" + def_bool y if (IP_DCCP = y || IP_DCCP = m) + help + CCID-3 denotes TCP-Friendly Rate Control (TFRC), an equation-based + rate-controlled congestion control mechanism. TFRC is designed to + be reasonably fair when competing for bandwidth with TCP-like flows, + where a flow is "reasonably fair" if its sending rate is generally + within a factor of two of the sending rate of a TCP flow under the + same conditions. However, TFRC has a much lower variation of + throughput over time compared with TCP, which makes CCID-3 more + suitable than CCID-2 for applications such streaming media where a + relatively smooth sending rate is of importance. + + CCID-3 is further described in RFC 4342, + https://www.ietf.org/rfc/rfc4342.txt + + The TFRC congestion control algorithms were initially described in + RFC 5348. + + This text was extracted from RFC 4340 (sec. 10.2), + https://www.ietf.org/rfc/rfc4340.txt + + If in doubt, say N. + +config IP_DCCP_CCID3_DEBUG + bool "CCID-3 debugging messages" + depends on IP_DCCP_CCID3 + help + Enable CCID-3 specific debugging messages. + + The debugging output can additionally be toggled by setting the + ccid3_debug parameter to 0 or 1. + + If in doubt, say N. + +config IP_DCCP_TFRC_LIB + def_bool y if IP_DCCP_CCID3 + +config IP_DCCP_TFRC_DEBUG + def_bool y if IP_DCCP_CCID3_DEBUG +endmenu diff --git a/net/dccp/ccids/ccid2.c b/net/dccp/ccids/ccid2.c new file mode 100644 index 000000000..4d9823d6d --- /dev/null +++ b/net/dccp/ccids/ccid2.c @@ -0,0 +1,793 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2005, 2006 Andrea Bittau + * + * Changes to meet Linux coding standards, and DCCP infrastructure fixes. + * + * Copyright (c) 2006 Arnaldo Carvalho de Melo + */ + +/* + * This implementation should follow RFC 4341 + */ +#include +#include "../feat.h" +#include "ccid2.h" + + +#ifdef CONFIG_IP_DCCP_CCID2_DEBUG +static bool ccid2_debug; +#define ccid2_pr_debug(format, a...) DCCP_PR_DEBUG(ccid2_debug, format, ##a) +#else +#define ccid2_pr_debug(format, a...) 
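+/* no-op stub: debug statements compile away when CONFIG_IP_DCCP_CCID2_DEBUG is not set */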
+#endif + +static int ccid2_hc_tx_alloc_seq(struct ccid2_hc_tx_sock *hc) +{ + struct ccid2_seq *seqp; + int i; + + /* check if we have space to preserve the pointer to the buffer */ + if (hc->tx_seqbufc >= (sizeof(hc->tx_seqbuf) / + sizeof(struct ccid2_seq *))) + return -ENOMEM; + + /* allocate buffer and initialize linked list */ + seqp = kmalloc_array(CCID2_SEQBUF_LEN, sizeof(struct ccid2_seq), + gfp_any()); + if (seqp == NULL) + return -ENOMEM; + + for (i = 0; i < (CCID2_SEQBUF_LEN - 1); i++) { + seqp[i].ccid2s_next = &seqp[i + 1]; + seqp[i + 1].ccid2s_prev = &seqp[i]; + } + seqp[CCID2_SEQBUF_LEN - 1].ccid2s_next = seqp; + seqp->ccid2s_prev = &seqp[CCID2_SEQBUF_LEN - 1]; + + /* This is the first allocation. Initiate the head and tail. */ + if (hc->tx_seqbufc == 0) + hc->tx_seqh = hc->tx_seqt = seqp; + else { + /* link the existing list with the one we just created */ + hc->tx_seqh->ccid2s_next = seqp; + seqp->ccid2s_prev = hc->tx_seqh; + + hc->tx_seqt->ccid2s_prev = &seqp[CCID2_SEQBUF_LEN - 1]; + seqp[CCID2_SEQBUF_LEN - 1].ccid2s_next = hc->tx_seqt; + } + + /* store the original pointer to the buffer so we can free it */ + hc->tx_seqbuf[hc->tx_seqbufc] = seqp; + hc->tx_seqbufc++; + + return 0; +} + +static int ccid2_hc_tx_send_packet(struct sock *sk, struct sk_buff *skb) +{ + if (ccid2_cwnd_network_limited(ccid2_hc_tx_sk(sk))) + return CCID_PACKET_WILL_DEQUEUE_LATER; + return CCID_PACKET_SEND_AT_ONCE; +} + +static void ccid2_change_l_ack_ratio(struct sock *sk, u32 val) +{ + u32 max_ratio = DIV_ROUND_UP(ccid2_hc_tx_sk(sk)->tx_cwnd, 2); + + /* + * Ensure that Ack Ratio does not exceed ceil(cwnd/2), which is (2) from + * RFC 4341, 6.1.2. We ignore the statement that Ack Ratio 2 is always + * acceptable since this causes starvation/deadlock whenever cwnd < 2. + * The same problem arises when Ack Ratio is 0 (ie. Ack Ratio disabled). + */ + if (val == 0 || val > max_ratio) { + DCCP_WARN("Limiting Ack Ratio (%u) to %u\n", val, max_ratio); + val = max_ratio; + } + dccp_feat_signal_nn_change(sk, DCCPF_ACK_RATIO, + min_t(u32, val, DCCPF_ACK_RATIO_MAX)); +} + +static void ccid2_check_l_ack_ratio(struct sock *sk) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + + /* + * After a loss, idle period, application limited period, or RTO we + * need to check that the ack ratio is still less than the congestion + * window. Otherwise, we will send an entire congestion window of + * packets and got no response because we haven't sent ack ratio + * packets yet. + * If the ack ratio does need to be reduced, we reduce it to half of + * the congestion window (or 1 if that's zero) instead of to the + * congestion window. This prevents problems if one ack is lost. + */ + if (dccp_feat_nn_get(sk, DCCPF_ACK_RATIO) > hc->tx_cwnd) + ccid2_change_l_ack_ratio(sk, hc->tx_cwnd/2 ? 
: 1U); +} + +static void ccid2_change_l_seq_window(struct sock *sk, u64 val) +{ + dccp_feat_signal_nn_change(sk, DCCPF_SEQUENCE_WINDOW, + clamp_val(val, DCCPF_SEQ_WMIN, + DCCPF_SEQ_WMAX)); +} + +static void dccp_tasklet_schedule(struct sock *sk) +{ + struct tasklet_struct *t = &dccp_sk(sk)->dccps_xmitlet; + + if (!test_and_set_bit(TASKLET_STATE_SCHED, &t->state)) { + sock_hold(sk); + __tasklet_schedule(t); + } +} + +static void ccid2_hc_tx_rto_expire(struct timer_list *t) +{ + struct ccid2_hc_tx_sock *hc = from_timer(hc, t, tx_rtotimer); + struct sock *sk = hc->sk; + const bool sender_was_blocked = ccid2_cwnd_network_limited(hc); + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + sk_reset_timer(sk, &hc->tx_rtotimer, jiffies + HZ / 5); + goto out; + } + + ccid2_pr_debug("RTO_EXPIRE\n"); + + if (sk->sk_state == DCCP_CLOSED) + goto out; + + /* back-off timer */ + hc->tx_rto <<= 1; + if (hc->tx_rto > DCCP_RTO_MAX) + hc->tx_rto = DCCP_RTO_MAX; + + /* adjust pipe, cwnd etc */ + hc->tx_ssthresh = hc->tx_cwnd / 2; + if (hc->tx_ssthresh < 2) + hc->tx_ssthresh = 2; + hc->tx_cwnd = 1; + hc->tx_pipe = 0; + + /* clear state about stuff we sent */ + hc->tx_seqt = hc->tx_seqh; + hc->tx_packets_acked = 0; + + /* clear ack ratio state. */ + hc->tx_rpseq = 0; + hc->tx_rpdupack = -1; + ccid2_change_l_ack_ratio(sk, 1); + + /* if we were blocked before, we may now send cwnd=1 packet */ + if (sender_was_blocked) + dccp_tasklet_schedule(sk); + /* restart backed-off timer */ + sk_reset_timer(sk, &hc->tx_rtotimer, jiffies + hc->tx_rto); +out: + bh_unlock_sock(sk); + sock_put(sk); +} + +/* + * Congestion window validation (RFC 2861). + */ +static bool ccid2_do_cwv = true; +module_param(ccid2_do_cwv, bool, 0644); +MODULE_PARM_DESC(ccid2_do_cwv, "Perform RFC2861 Congestion Window Validation"); + +/** + * ccid2_update_used_window - Track how much of cwnd is actually used + * @hc: socket to update window + * @new_wnd: new window values to add into the filter + * + * This is done in addition to CWV. The sender needs to have an idea of how many + * packets may be in flight, to set the local Sequence Window value accordingly + * (RFC 4340, 7.5.2). The CWV mechanism is exploited to keep track of the + * maximum-used window. We use an EWMA low-pass filter to filter out noise. 
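+ * E.g., with tx_expected_wnd == 20 and new_wnd == 40, the filter below yields
+ * (3 * 20 + 40) / 4 = 25.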
+ */ +static void ccid2_update_used_window(struct ccid2_hc_tx_sock *hc, u32 new_wnd) +{ + hc->tx_expected_wnd = (3 * hc->tx_expected_wnd + new_wnd) / 4; +} + +/* This borrows the code of tcp_cwnd_application_limited() */ +static void ccid2_cwnd_application_limited(struct sock *sk, const u32 now) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + /* don't reduce cwnd below the initial window (IW) */ + u32 init_win = rfc3390_bytes_to_packets(dccp_sk(sk)->dccps_mss_cache), + win_used = max(hc->tx_cwnd_used, init_win); + + if (win_used < hc->tx_cwnd) { + hc->tx_ssthresh = max(hc->tx_ssthresh, + (hc->tx_cwnd >> 1) + (hc->tx_cwnd >> 2)); + hc->tx_cwnd = (hc->tx_cwnd + win_used) >> 1; + } + hc->tx_cwnd_used = 0; + hc->tx_cwnd_stamp = now; + + ccid2_check_l_ack_ratio(sk); +} + +/* This borrows the code of tcp_cwnd_restart() */ +static void ccid2_cwnd_restart(struct sock *sk, const u32 now) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + u32 cwnd = hc->tx_cwnd, restart_cwnd, + iwnd = rfc3390_bytes_to_packets(dccp_sk(sk)->dccps_mss_cache); + s32 delta = now - hc->tx_lsndtime; + + hc->tx_ssthresh = max(hc->tx_ssthresh, (cwnd >> 1) + (cwnd >> 2)); + + /* don't reduce cwnd below the initial window (IW) */ + restart_cwnd = min(cwnd, iwnd); + + while ((delta -= hc->tx_rto) >= 0 && cwnd > restart_cwnd) + cwnd >>= 1; + hc->tx_cwnd = max(cwnd, restart_cwnd); + hc->tx_cwnd_stamp = now; + hc->tx_cwnd_used = 0; + + ccid2_check_l_ack_ratio(sk); +} + +static void ccid2_hc_tx_packet_sent(struct sock *sk, unsigned int len) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + const u32 now = ccid2_jiffies32; + struct ccid2_seq *next; + + /* slow-start after idle periods (RFC 2581, RFC 2861) */ + if (ccid2_do_cwv && !hc->tx_pipe && + (s32)(now - hc->tx_lsndtime) >= hc->tx_rto) + ccid2_cwnd_restart(sk, now); + + hc->tx_lsndtime = now; + hc->tx_pipe += 1; + + /* see whether cwnd was fully used (RFC 2861), update expected window */ + if (ccid2_cwnd_network_limited(hc)) { + ccid2_update_used_window(hc, hc->tx_cwnd); + hc->tx_cwnd_used = 0; + hc->tx_cwnd_stamp = now; + } else { + if (hc->tx_pipe > hc->tx_cwnd_used) + hc->tx_cwnd_used = hc->tx_pipe; + + ccid2_update_used_window(hc, hc->tx_cwnd_used); + + if (ccid2_do_cwv && (s32)(now - hc->tx_cwnd_stamp) >= hc->tx_rto) + ccid2_cwnd_application_limited(sk, now); + } + + hc->tx_seqh->ccid2s_seq = dp->dccps_gss; + hc->tx_seqh->ccid2s_acked = 0; + hc->tx_seqh->ccid2s_sent = now; + + next = hc->tx_seqh->ccid2s_next; + /* check if we need to alloc more space */ + if (next == hc->tx_seqt) { + if (ccid2_hc_tx_alloc_seq(hc)) { + DCCP_CRIT("packet history - out of memory!"); + /* FIXME: find a more graceful way to bail out */ + return; + } + next = hc->tx_seqh->ccid2s_next; + BUG_ON(next == hc->tx_seqt); + } + hc->tx_seqh = next; + + ccid2_pr_debug("cwnd=%d pipe=%d\n", hc->tx_cwnd, hc->tx_pipe); + + /* + * FIXME: The code below is broken and the variables have been removed + * from the socket struct. The `ackloss' variable was always set to 0, + * and with arsent there are several problems: + * (i) it doesn't just count the number of Acks, but all sent packets; + * (ii) it is expressed in # of packets, not # of windows, so the + * comparison below uses the wrong formula: Appendix A of RFC 4341 + * comes up with the number K = cwnd / (R^2 - R) of consecutive windows + * of data with no lost or marked Ack packets. 
If arsent were the # of + * consecutive Acks received without loss, then Ack Ratio needs to be + * decreased by 1 when + * arsent >= K * cwnd / R = cwnd^2 / (R^3 - R^2) + * where cwnd / R is the number of Acks received per window of data + * (cf. RFC 4341, App. A). The problems are that + * - arsent counts other packets as well; + * - the comparison uses a formula different from RFC 4341; + * - computing a cubic/quadratic equation each time is too complicated. + * Hence a different algorithm is needed. + */ +#if 0 + /* Ack Ratio. Need to maintain a concept of how many windows we sent */ + hc->tx_arsent++; + /* We had an ack loss in this window... */ + if (hc->tx_ackloss) { + if (hc->tx_arsent >= hc->tx_cwnd) { + hc->tx_arsent = 0; + hc->tx_ackloss = 0; + } + } else { + /* No acks lost up to now... */ + /* decrease ack ratio if enough packets were sent */ + if (dp->dccps_l_ack_ratio > 1) { + /* XXX don't calculate denominator each time */ + int denom = dp->dccps_l_ack_ratio * dp->dccps_l_ack_ratio - + dp->dccps_l_ack_ratio; + + denom = hc->tx_cwnd * hc->tx_cwnd / denom; + + if (hc->tx_arsent >= denom) { + ccid2_change_l_ack_ratio(sk, dp->dccps_l_ack_ratio - 1); + hc->tx_arsent = 0; + } + } else { + /* we can't increase ack ratio further [1] */ + hc->tx_arsent = 0; /* or maybe set it to cwnd*/ + } + } +#endif + + sk_reset_timer(sk, &hc->tx_rtotimer, jiffies + hc->tx_rto); + +#ifdef CONFIG_IP_DCCP_CCID2_DEBUG + do { + struct ccid2_seq *seqp = hc->tx_seqt; + + while (seqp != hc->tx_seqh) { + ccid2_pr_debug("out seq=%llu acked=%d time=%u\n", + (unsigned long long)seqp->ccid2s_seq, + seqp->ccid2s_acked, seqp->ccid2s_sent); + seqp = seqp->ccid2s_next; + } + } while (0); + ccid2_pr_debug("=========\n"); +#endif +} + +/** + * ccid2_rtt_estimator - Sample RTT and compute RTO using RFC2988 algorithm + * @sk: socket to perform estimator on + * + * This code is almost identical with TCP's tcp_rtt_estimator(), since + * - it has a higher sampling frequency (recommended by RFC 1323), + * - the RTO does not collapse into RTT due to RTTVAR going towards zero, + * - it is simple (cf. more complex proposals such as Eifel timer or research + * which suggests that the gain should be set according to window size), + * - in tests it was found to work well with CCID2 [gerrit]. + */ +static void ccid2_rtt_estimator(struct sock *sk, const long mrtt) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + long m = mrtt ? : 1; + + if (hc->tx_srtt == 0) { + /* First measurement m */ + hc->tx_srtt = m << 3; + hc->tx_mdev = m << 1; + + hc->tx_mdev_max = max(hc->tx_mdev, tcp_rto_min(sk)); + hc->tx_rttvar = hc->tx_mdev_max; + + hc->tx_rtt_seq = dccp_sk(sk)->dccps_gss; + } else { + /* Update scaled SRTT as SRTT += 1/8 * (m - SRTT) */ + m -= (hc->tx_srtt >> 3); + hc->tx_srtt += m; + + /* Similarly, update scaled mdev with regard to |m| */ + if (m < 0) { + m = -m; + m -= (hc->tx_mdev >> 2); + /* + * This neutralises RTO increase when RTT < SRTT - mdev + * (see P. Sarolahti, A. Kuznetsov,"Congestion Control + * in Linux TCP", USENIX 2002, pp. 49-62). + */ + if (m > 0) + m >>= 3; + } else { + m -= (hc->tx_mdev >> 2); + } + hc->tx_mdev += m; + + if (hc->tx_mdev > hc->tx_mdev_max) { + hc->tx_mdev_max = hc->tx_mdev; + if (hc->tx_mdev_max > hc->tx_rttvar) + hc->tx_rttvar = hc->tx_mdev_max; + } + + /* + * Decay RTTVAR at most once per flight, exploiting that + * 1) pipe <= cwnd <= Sequence_Window = W (RFC 4340, 7.5.2) + * 2) AWL = GSS-W+1 <= GAR <= GSS (RFC 4340, 7.5.1) + * GAR is a useful bound for FlightSize = pipe. 
+ * AWL is probably too low here, as it over-estimates pipe. + */ + if (after48(dccp_sk(sk)->dccps_gar, hc->tx_rtt_seq)) { + if (hc->tx_mdev_max < hc->tx_rttvar) + hc->tx_rttvar -= (hc->tx_rttvar - + hc->tx_mdev_max) >> 2; + hc->tx_rtt_seq = dccp_sk(sk)->dccps_gss; + hc->tx_mdev_max = tcp_rto_min(sk); + } + } + + /* + * Set RTO from SRTT and RTTVAR + * As in TCP, 4 * RTTVAR >= TCP_RTO_MIN, giving a minimum RTO of 200 ms. + * This agrees with RFC 4341, 5: + * "Because DCCP does not retransmit data, DCCP does not require + * TCP's recommended minimum timeout of one second". + */ + hc->tx_rto = (hc->tx_srtt >> 3) + hc->tx_rttvar; + + if (hc->tx_rto > DCCP_RTO_MAX) + hc->tx_rto = DCCP_RTO_MAX; +} + +static void ccid2_new_ack(struct sock *sk, struct ccid2_seq *seqp, + unsigned int *maxincr) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + struct dccp_sock *dp = dccp_sk(sk); + int r_seq_used = hc->tx_cwnd / dp->dccps_l_ack_ratio; + + if (hc->tx_cwnd < dp->dccps_l_seq_win && + r_seq_used < dp->dccps_r_seq_win) { + if (hc->tx_cwnd < hc->tx_ssthresh) { + if (*maxincr > 0 && ++hc->tx_packets_acked >= 2) { + hc->tx_cwnd += 1; + *maxincr -= 1; + hc->tx_packets_acked = 0; + } + } else if (++hc->tx_packets_acked >= hc->tx_cwnd) { + hc->tx_cwnd += 1; + hc->tx_packets_acked = 0; + } + } + + /* + * Adjust the local sequence window and the ack ratio to allow about + * 5 times the number of packets in the network (RFC 4340 7.5.2) + */ + if (r_seq_used * CCID2_WIN_CHANGE_FACTOR >= dp->dccps_r_seq_win) + ccid2_change_l_ack_ratio(sk, dp->dccps_l_ack_ratio * 2); + else if (r_seq_used * CCID2_WIN_CHANGE_FACTOR < dp->dccps_r_seq_win/2) + ccid2_change_l_ack_ratio(sk, dp->dccps_l_ack_ratio / 2 ? : 1U); + + if (hc->tx_cwnd * CCID2_WIN_CHANGE_FACTOR >= dp->dccps_l_seq_win) + ccid2_change_l_seq_window(sk, dp->dccps_l_seq_win * 2); + else if (hc->tx_cwnd * CCID2_WIN_CHANGE_FACTOR < dp->dccps_l_seq_win/2) + ccid2_change_l_seq_window(sk, dp->dccps_l_seq_win / 2); + + /* + * FIXME: RTT is sampled several times per acknowledgment (for each + * entry in the Ack Vector), instead of once per Ack (as in TCP SACK). + * This causes the RTT to be over-estimated, since the older entries + * in the Ack Vector have earlier sending times. + * The cleanest solution is to not use the ccid2s_sent field at all + * and instead use DCCP timestamps: requires changes in other places. + */ + ccid2_rtt_estimator(sk, ccid2_jiffies32 - seqp->ccid2s_sent); +} + +static void ccid2_congestion_event(struct sock *sk, struct ccid2_seq *seqp) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + + if ((s32)(seqp->ccid2s_sent - hc->tx_last_cong) < 0) { + ccid2_pr_debug("Multiple losses in an RTT---treating as one\n"); + return; + } + + hc->tx_last_cong = ccid2_jiffies32; + + hc->tx_cwnd = hc->tx_cwnd / 2 ? 
: 1U; + hc->tx_ssthresh = max(hc->tx_cwnd, 2U); + + ccid2_check_l_ack_ratio(sk); +} + +static int ccid2_hc_tx_parse_options(struct sock *sk, u8 packet_type, + u8 option, u8 *optval, u8 optlen) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + + switch (option) { + case DCCPO_ACK_VECTOR_0: + case DCCPO_ACK_VECTOR_1: + return dccp_ackvec_parsed_add(&hc->tx_av_chunks, optval, optlen, + option - DCCPO_ACK_VECTOR_0); + } + return 0; +} + +static void ccid2_hc_tx_packet_recv(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + const bool sender_was_blocked = ccid2_cwnd_network_limited(hc); + struct dccp_ackvec_parsed *avp; + u64 ackno, seqno; + struct ccid2_seq *seqp; + int done = 0; + unsigned int maxincr = 0; + + /* check reverse path congestion */ + seqno = DCCP_SKB_CB(skb)->dccpd_seq; + + /* XXX this whole "algorithm" is broken. Need to fix it to keep track + * of the seqnos of the dupacks so that rpseq and rpdupack are correct + * -sorbo. + */ + /* need to bootstrap */ + if (hc->tx_rpdupack == -1) { + hc->tx_rpdupack = 0; + hc->tx_rpseq = seqno; + } else { + /* check if packet is consecutive */ + if (dccp_delta_seqno(hc->tx_rpseq, seqno) == 1) + hc->tx_rpseq = seqno; + /* it's a later packet */ + else if (after48(seqno, hc->tx_rpseq)) { + hc->tx_rpdupack++; + + /* check if we got enough dupacks */ + if (hc->tx_rpdupack >= NUMDUPACK) { + hc->tx_rpdupack = -1; /* XXX lame */ + hc->tx_rpseq = 0; +#ifdef __CCID2_COPES_GRACEFULLY_WITH_ACK_CONGESTION_CONTROL__ + /* + * FIXME: Ack Congestion Control is broken; in + * the current state instabilities occurred with + * Ack Ratios greater than 1; causing hang-ups + * and long RTO timeouts. This needs to be fixed + * before opening up dynamic changes. -- gerrit + */ + ccid2_change_l_ack_ratio(sk, 2 * dp->dccps_l_ack_ratio); +#endif + } + } + } + + /* check forward path congestion */ + if (dccp_packet_without_ack(skb)) + return; + + /* still didn't send out new data packets */ + if (hc->tx_seqh == hc->tx_seqt) + goto done; + + ackno = DCCP_SKB_CB(skb)->dccpd_ack_seq; + if (after48(ackno, hc->tx_high_ack)) + hc->tx_high_ack = ackno; + + seqp = hc->tx_seqt; + while (before48(seqp->ccid2s_seq, ackno)) { + seqp = seqp->ccid2s_next; + if (seqp == hc->tx_seqh) { + seqp = hc->tx_seqh->ccid2s_prev; + break; + } + } + + /* + * In slow-start, cwnd can increase up to a maximum of Ack Ratio/2 + * packets per acknowledgement. Rounding up avoids that cwnd is not + * advanced when Ack Ratio is 1 and gives a slight edge otherwise. + */ + if (hc->tx_cwnd < hc->tx_ssthresh) + maxincr = DIV_ROUND_UP(dp->dccps_l_ack_ratio, 2); + + /* go through all ack vectors */ + list_for_each_entry(avp, &hc->tx_av_chunks, node) { + /* go through this ack vector */ + for (; avp->len--; avp->vec++) { + u64 ackno_end_rl = SUB48(ackno, + dccp_ackvec_runlen(avp->vec)); + + ccid2_pr_debug("ackvec %llu |%u,%u|\n", + (unsigned long long)ackno, + dccp_ackvec_state(avp->vec) >> 6, + dccp_ackvec_runlen(avp->vec)); + /* if the seqno we are analyzing is larger than the + * current ackno, then move towards the tail of our + * seqnos. 
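+ * (tx_seqt is the oldest and tx_seqh the newest entry, so stepping
+ * via ccid2s_prev moves towards older packets.)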
+ */ + while (after48(seqp->ccid2s_seq, ackno)) { + if (seqp == hc->tx_seqt) { + done = 1; + break; + } + seqp = seqp->ccid2s_prev; + } + if (done) + break; + + /* check all seqnos in the range of the vector + * run length + */ + while (between48(seqp->ccid2s_seq,ackno_end_rl,ackno)) { + const u8 state = dccp_ackvec_state(avp->vec); + + /* new packet received or marked */ + if (state != DCCPAV_NOT_RECEIVED && + !seqp->ccid2s_acked) { + if (state == DCCPAV_ECN_MARKED) + ccid2_congestion_event(sk, + seqp); + else + ccid2_new_ack(sk, seqp, + &maxincr); + + seqp->ccid2s_acked = 1; + ccid2_pr_debug("Got ack for %llu\n", + (unsigned long long)seqp->ccid2s_seq); + hc->tx_pipe--; + } + if (seqp == hc->tx_seqt) { + done = 1; + break; + } + seqp = seqp->ccid2s_prev; + } + if (done) + break; + + ackno = SUB48(ackno_end_rl, 1); + } + if (done) + break; + } + + /* The state about what is acked should be correct now + * Check for NUMDUPACK + */ + seqp = hc->tx_seqt; + while (before48(seqp->ccid2s_seq, hc->tx_high_ack)) { + seqp = seqp->ccid2s_next; + if (seqp == hc->tx_seqh) { + seqp = hc->tx_seqh->ccid2s_prev; + break; + } + } + done = 0; + while (1) { + if (seqp->ccid2s_acked) { + done++; + if (done == NUMDUPACK) + break; + } + if (seqp == hc->tx_seqt) + break; + seqp = seqp->ccid2s_prev; + } + + /* If there are at least 3 acknowledgements, anything unacknowledged + * below the last sequence number is considered lost + */ + if (done == NUMDUPACK) { + struct ccid2_seq *last_acked = seqp; + + /* check for lost packets */ + while (1) { + if (!seqp->ccid2s_acked) { + ccid2_pr_debug("Packet lost: %llu\n", + (unsigned long long)seqp->ccid2s_seq); + /* XXX need to traverse from tail -> head in + * order to detect multiple congestion events in + * one ack vector. + */ + ccid2_congestion_event(sk, seqp); + hc->tx_pipe--; + } + if (seqp == hc->tx_seqt) + break; + seqp = seqp->ccid2s_prev; + } + + hc->tx_seqt = last_acked; + } + + /* trim acked packets in tail */ + while (hc->tx_seqt != hc->tx_seqh) { + if (!hc->tx_seqt->ccid2s_acked) + break; + + hc->tx_seqt = hc->tx_seqt->ccid2s_next; + } + + /* restart RTO timer if not all outstanding data has been acked */ + if (hc->tx_pipe == 0) + sk_stop_timer(sk, &hc->tx_rtotimer); + else + sk_reset_timer(sk, &hc->tx_rtotimer, jiffies + hc->tx_rto); +done: + /* check if incoming Acks allow pending packets to be sent */ + if (sender_was_blocked && !ccid2_cwnd_network_limited(hc)) + dccp_tasklet_schedule(sk); + dccp_ackvec_parsed_cleanup(&hc->tx_av_chunks); +} + +static int ccid2_hc_tx_init(struct ccid *ccid, struct sock *sk) +{ + struct ccid2_hc_tx_sock *hc = ccid_priv(ccid); + struct dccp_sock *dp = dccp_sk(sk); + u32 max_ratio; + + /* RFC 4341, 5: initialise ssthresh to arbitrarily high (max) value */ + hc->tx_ssthresh = ~0U; + + /* Use larger initial windows (RFC 4341, section 5). */ + hc->tx_cwnd = rfc3390_bytes_to_packets(dp->dccps_mss_cache); + hc->tx_expected_wnd = hc->tx_cwnd; + + /* Make sure that Ack Ratio is enabled and within bounds. */ + max_ratio = DIV_ROUND_UP(hc->tx_cwnd, 2); + if (dp->dccps_l_ack_ratio == 0 || dp->dccps_l_ack_ratio > max_ratio) + dp->dccps_l_ack_ratio = max_ratio; + + /* XXX init ~ to window size... 
*/ + if (ccid2_hc_tx_alloc_seq(hc)) + return -ENOMEM; + + hc->tx_rto = DCCP_TIMEOUT_INIT; + hc->tx_rpdupack = -1; + hc->tx_last_cong = hc->tx_lsndtime = hc->tx_cwnd_stamp = ccid2_jiffies32; + hc->tx_cwnd_used = 0; + hc->sk = sk; + timer_setup(&hc->tx_rtotimer, ccid2_hc_tx_rto_expire, 0); + INIT_LIST_HEAD(&hc->tx_av_chunks); + return 0; +} + +static void ccid2_hc_tx_exit(struct sock *sk) +{ + struct ccid2_hc_tx_sock *hc = ccid2_hc_tx_sk(sk); + int i; + + sk_stop_timer(sk, &hc->tx_rtotimer); + + for (i = 0; i < hc->tx_seqbufc; i++) + kfree(hc->tx_seqbuf[i]); + hc->tx_seqbufc = 0; + dccp_ackvec_parsed_cleanup(&hc->tx_av_chunks); +} + +static void ccid2_hc_rx_packet_recv(struct sock *sk, struct sk_buff *skb) +{ + struct ccid2_hc_rx_sock *hc = ccid2_hc_rx_sk(sk); + + if (!dccp_data_packet(skb)) + return; + + if (++hc->rx_num_data_pkts >= dccp_sk(sk)->dccps_r_ack_ratio) { + dccp_send_ack(sk); + hc->rx_num_data_pkts = 0; + } +} + +struct ccid_operations ccid2_ops = { + .ccid_id = DCCPC_CCID2, + .ccid_name = "TCP-like", + .ccid_hc_tx_obj_size = sizeof(struct ccid2_hc_tx_sock), + .ccid_hc_tx_init = ccid2_hc_tx_init, + .ccid_hc_tx_exit = ccid2_hc_tx_exit, + .ccid_hc_tx_send_packet = ccid2_hc_tx_send_packet, + .ccid_hc_tx_packet_sent = ccid2_hc_tx_packet_sent, + .ccid_hc_tx_parse_options = ccid2_hc_tx_parse_options, + .ccid_hc_tx_packet_recv = ccid2_hc_tx_packet_recv, + .ccid_hc_rx_obj_size = sizeof(struct ccid2_hc_rx_sock), + .ccid_hc_rx_packet_recv = ccid2_hc_rx_packet_recv, +}; + +#ifdef CONFIG_IP_DCCP_CCID2_DEBUG +module_param(ccid2_debug, bool, 0644); +MODULE_PARM_DESC(ccid2_debug, "Enable CCID-2 debug messages"); +#endif diff --git a/net/dccp/ccids/ccid2.h b/net/dccp/ccids/ccid2.h new file mode 100644 index 000000000..330c7b4ec --- /dev/null +++ b/net/dccp/ccids/ccid2.h @@ -0,0 +1,121 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (c) 2005 Andrea Bittau + */ +#ifndef _DCCP_CCID2_H_ +#define _DCCP_CCID2_H_ + +#include +#include +#include "../ccid.h" +#include "../dccp.h" + +/* + * CCID-2 timestamping faces the same issues as TCP timestamping. + * Hence we reuse/share as much of the code as possible. + */ +#define ccid2_jiffies32 ((u32)jiffies) + +/* NUMDUPACK parameter from RFC 4341, p. 
6 */ +#define NUMDUPACK 3 + +struct ccid2_seq { + u64 ccid2s_seq; + u32 ccid2s_sent; + int ccid2s_acked; + struct ccid2_seq *ccid2s_prev; + struct ccid2_seq *ccid2s_next; +}; + +#define CCID2_SEQBUF_LEN 1024 +#define CCID2_SEQBUF_MAX 128 + +/* + * Multiple of congestion window to keep the sequence window at + * (RFC 4340 7.5.2) + */ +#define CCID2_WIN_CHANGE_FACTOR 5 + +/** + * struct ccid2_hc_tx_sock - CCID2 TX half connection + * @tx_{cwnd,ssthresh,pipe}: as per RFC 4341, section 5 + * @tx_packets_acked: Ack counter for deriving cwnd growth (RFC 3465) + * @tx_srtt: smoothed RTT estimate, scaled by 2^3 + * @tx_mdev: smoothed RTT variation, scaled by 2^2 + * @tx_mdev_max: maximum of @mdev during one flight + * @tx_rttvar: moving average/maximum of @mdev_max + * @tx_rto: RTO value deriving from SRTT and RTTVAR (RFC 2988) + * @tx_rtt_seq: to decay RTTVAR at most once per flight + * @tx_cwnd_used: actually used cwnd, W_used of RFC 2861 + * @tx_expected_wnd: moving average of @tx_cwnd_used + * @tx_cwnd_stamp: to track idle periods in CWV + * @tx_lsndtime: last time (in jiffies) a data packet was sent + * @tx_rpseq: last consecutive seqno + * @tx_rpdupack: dupacks since rpseq + * @tx_av_chunks: list of Ack Vectors received on current skb + */ +struct ccid2_hc_tx_sock { + u32 tx_cwnd; + u32 tx_ssthresh; + u32 tx_pipe; + u32 tx_packets_acked; + struct ccid2_seq *tx_seqbuf[CCID2_SEQBUF_MAX]; + int tx_seqbufc; + struct ccid2_seq *tx_seqh; + struct ccid2_seq *tx_seqt; + + /* RTT measurement: variables/principles are the same as in TCP */ + u32 tx_srtt, + tx_mdev, + tx_mdev_max, + tx_rttvar, + tx_rto; + u64 tx_rtt_seq:48; + struct timer_list tx_rtotimer; + struct sock *sk; + + /* Congestion Window validation (optional, RFC 2861) */ + u32 tx_cwnd_used, + tx_expected_wnd, + tx_cwnd_stamp, + tx_lsndtime; + + u64 tx_rpseq; + int tx_rpdupack; + u32 tx_last_cong; + u64 tx_high_ack; + struct list_head tx_av_chunks; +}; + +static inline bool ccid2_cwnd_network_limited(struct ccid2_hc_tx_sock *hc) +{ + return hc->tx_pipe >= hc->tx_cwnd; +} + +/* + * Convert RFC 3390 larger initial window into an equivalent number of packets. + * This is based on the numbers specified in RFC 5681, 3.1. + */ +static inline u32 rfc3390_bytes_to_packets(const u32 smss) +{ + return smss <= 1095 ? 4 : (smss > 2190 ? 2 : 3); +} + +/** + * struct ccid2_hc_rx_sock - Receiving end of CCID-2 half-connection + * @rx_num_data_pkts: number of data packets received since last feedback + */ +struct ccid2_hc_rx_sock { + u32 rx_num_data_pkts; +}; + +static inline struct ccid2_hc_tx_sock *ccid2_hc_tx_sk(const struct sock *sk) +{ + return ccid_priv(dccp_sk(sk)->dccps_hc_tx_ccid); +} + +static inline struct ccid2_hc_rx_sock *ccid2_hc_rx_sk(const struct sock *sk) +{ + return ccid_priv(dccp_sk(sk)->dccps_hc_rx_ccid); +} +#endif /* _DCCP_CCID2_H_ */ diff --git a/net/dccp/ccids/ccid3.c b/net/dccp/ccids/ccid3.c new file mode 100644 index 000000000..ca8670f78 --- /dev/null +++ b/net/dccp/ccids/ccid3.c @@ -0,0 +1,866 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-7 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2005-7 Ian McDonald + * + * An implementation of the DCCP protocol + * + * This code has been developed by the University of Waikato WAND + * research group. 
For further information please see https://www.wand.net.nz/ + * + * This code also uses code from Lulea University, rereleased as GPL by its + * authors: + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + * + * Changes to meet Linux coding standards, to make it meet latest ccid3 draft + * and to make it work as a loadable module in the DCCP stack written by + * Arnaldo Carvalho de Melo . + * + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ +#include "../dccp.h" +#include "ccid3.h" + +#include + +#ifdef CONFIG_IP_DCCP_CCID3_DEBUG +static bool ccid3_debug; +#define ccid3_pr_debug(format, a...) DCCP_PR_DEBUG(ccid3_debug, format, ##a) +#else +#define ccid3_pr_debug(format, a...) +#endif + +/* + * Transmitter Half-Connection Routines + */ +#ifdef CONFIG_IP_DCCP_CCID3_DEBUG +static const char *ccid3_tx_state_name(enum ccid3_hc_tx_states state) +{ + static const char *const ccid3_state_names[] = { + [TFRC_SSTATE_NO_SENT] = "NO_SENT", + [TFRC_SSTATE_NO_FBACK] = "NO_FBACK", + [TFRC_SSTATE_FBACK] = "FBACK", + }; + + return ccid3_state_names[state]; +} +#endif + +static void ccid3_hc_tx_set_state(struct sock *sk, + enum ccid3_hc_tx_states state) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + enum ccid3_hc_tx_states oldstate = hc->tx_state; + + ccid3_pr_debug("%s(%p) %-8.8s -> %s\n", + dccp_role(sk), sk, ccid3_tx_state_name(oldstate), + ccid3_tx_state_name(state)); + WARN_ON(state == oldstate); + hc->tx_state = state; +} + +/* + * Compute the initial sending rate X_init in the manner of RFC 3390: + * + * X_init = min(4 * s, max(2 * s, 4380 bytes)) / RTT + * + * Note that RFC 3390 uses MSS, RFC 4342 refers to RFC 3390, and rfc3448bis + * (rev-02) clarifies the use of RFC 3390 with regard to the above formula. + * For consistency with other parts of the code, X_init is scaled by 2^6. + */ +static inline u64 rfc3390_initial_rate(struct sock *sk) +{ + const struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + const __u32 w_init = clamp_t(__u32, 4380U, 2 * hc->tx_s, 4 * hc->tx_s); + + return scaled_div(w_init << 6, hc->tx_rtt); +} + +/** + * ccid3_update_send_interval - Calculate new t_ipi = s / X_inst + * @hc: socket to have the send interval updated + * + * This respects the granularity of X_inst (64 * bytes/second). + */ +static void ccid3_update_send_interval(struct ccid3_hc_tx_sock *hc) +{ + hc->tx_t_ipi = scaled_div32(((u64)hc->tx_s) << 6, hc->tx_x); + + DCCP_BUG_ON(hc->tx_t_ipi == 0); + ccid3_pr_debug("t_ipi=%u, s=%u, X=%u\n", hc->tx_t_ipi, + hc->tx_s, (unsigned int)(hc->tx_x >> 6)); +} + +static u32 ccid3_hc_tx_idle_rtt(struct ccid3_hc_tx_sock *hc, ktime_t now) +{ + u32 delta = ktime_us_delta(now, hc->tx_t_last_win_count); + + return delta / hc->tx_rtt; +} + +/** + * ccid3_hc_tx_update_x - Update allowed sending rate X + * @sk: socket to be updated + * @stamp: most recent time if available - can be left NULL. + * + * This function tracks draft rfc3448bis, check there for latest details. + * + * Note: X and X_recv are both stored in units of 64 * bytes/second, to support + * fine-grained resolution of sending rates. This requires scaling by 2^6 + * throughout the code. Only X_calc is unscaled (in bytes/second). + * + */ +static void ccid3_hc_tx_update_x(struct sock *sk, ktime_t *stamp) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + __u64 min_rate = 2 * hc->tx_x_recv; + const __u64 old_x = hc->tx_x; + ktime_t now = stamp ? 
*stamp : ktime_get_real(); + + /* + * Handle IDLE periods: do not reduce below RFC3390 initial sending rate + * when idling [RFC 4342, 5.1]. Definition of idling is from rfc3448bis: + * a sender is idle if it has not sent anything over a 2-RTT-period. + * For consistency with X and X_recv, min_rate is also scaled by 2^6. + */ + if (ccid3_hc_tx_idle_rtt(hc, now) >= 2) { + min_rate = rfc3390_initial_rate(sk); + min_rate = max(min_rate, 2 * hc->tx_x_recv); + } + + if (hc->tx_p > 0) { + + hc->tx_x = min(((__u64)hc->tx_x_calc) << 6, min_rate); + hc->tx_x = max(hc->tx_x, (((__u64)hc->tx_s) << 6) / TFRC_T_MBI); + + } else if (ktime_us_delta(now, hc->tx_t_ld) - (s64)hc->tx_rtt >= 0) { + + hc->tx_x = min(2 * hc->tx_x, min_rate); + hc->tx_x = max(hc->tx_x, + scaled_div(((__u64)hc->tx_s) << 6, hc->tx_rtt)); + hc->tx_t_ld = now; + } + + if (hc->tx_x != old_x) { + ccid3_pr_debug("X_prev=%u, X_now=%u, X_calc=%u, " + "X_recv=%u\n", (unsigned int)(old_x >> 6), + (unsigned int)(hc->tx_x >> 6), hc->tx_x_calc, + (unsigned int)(hc->tx_x_recv >> 6)); + + ccid3_update_send_interval(hc); + } +} + +/** + * ccid3_hc_tx_update_s - Track the mean packet size `s' + * @hc: socket to be updated + * @len: DCCP packet payload size in bytes + * + * cf. RFC 4342, 5.3 and RFC 3448, 4.1 + */ +static inline void ccid3_hc_tx_update_s(struct ccid3_hc_tx_sock *hc, int len) +{ + const u16 old_s = hc->tx_s; + + hc->tx_s = tfrc_ewma(hc->tx_s, len, 9); + + if (hc->tx_s != old_s) + ccid3_update_send_interval(hc); +} + +/* + * Update Window Counter using the algorithm from [RFC 4342, 8.1]. + * As elsewhere, RTT > 0 is assumed by using dccp_sample_rtt(). + */ +static inline void ccid3_hc_tx_update_win_count(struct ccid3_hc_tx_sock *hc, + ktime_t now) +{ + u32 delta = ktime_us_delta(now, hc->tx_t_last_win_count), + quarter_rtts = (4 * delta) / hc->tx_rtt; + + if (quarter_rtts > 0) { + hc->tx_t_last_win_count = now; + hc->tx_last_win_count += min(quarter_rtts, 5U); + hc->tx_last_win_count &= 0xF; /* mod 16 */ + } +} + +static void ccid3_hc_tx_no_feedback_timer(struct timer_list *t) +{ + struct ccid3_hc_tx_sock *hc = from_timer(hc, t, tx_no_feedback_timer); + struct sock *sk = hc->sk; + unsigned long t_nfb = USEC_PER_SEC / 5; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + /* Try again later. */ + /* XXX: set some sensible MIB */ + goto restart_timer; + } + + ccid3_pr_debug("%s(%p, state=%s) - entry\n", dccp_role(sk), sk, + ccid3_tx_state_name(hc->tx_state)); + + /* Ignore and do not restart after leaving the established state */ + if ((1 << sk->sk_state) & ~(DCCPF_OPEN | DCCPF_PARTOPEN)) + goto out; + + /* Reset feedback state to "no feedback received" */ + if (hc->tx_state == TFRC_SSTATE_FBACK) + ccid3_hc_tx_set_state(sk, TFRC_SSTATE_NO_FBACK); + + /* + * Determine new allowed sending rate X as per draft rfc3448bis-00, 4.4 + * RTO is 0 if and only if no feedback has been received yet. 
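+ * Without feedback, or while p == 0, X_calc is not yet usable, so X is
+ * halved directly (never below s/t_mbi); otherwise the cached X_recv is
+ * reduced and X recomputed via ccid3_hc_tx_update_x().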
+ */ + if (hc->tx_t_rto == 0 || hc->tx_p == 0) { + + /* halve send rate directly */ + hc->tx_x = max(hc->tx_x / 2, + (((__u64)hc->tx_s) << 6) / TFRC_T_MBI); + ccid3_update_send_interval(hc); + } else { + /* + * Modify the cached value of X_recv + * + * If (X_calc > 2 * X_recv) + * X_recv = max(X_recv / 2, s / (2 * t_mbi)); + * Else + * X_recv = X_calc / 4; + * + * Note that X_recv is scaled by 2^6 while X_calc is not + */ + if (hc->tx_x_calc > (hc->tx_x_recv >> 5)) + hc->tx_x_recv = + max(hc->tx_x_recv / 2, + (((__u64)hc->tx_s) << 6) / (2*TFRC_T_MBI)); + else { + hc->tx_x_recv = hc->tx_x_calc; + hc->tx_x_recv <<= 4; + } + ccid3_hc_tx_update_x(sk, NULL); + } + ccid3_pr_debug("Reduced X to %llu/64 bytes/sec\n", + (unsigned long long)hc->tx_x); + + /* + * Set new timeout for the nofeedback timer. + * See comments in packet_recv() regarding the value of t_RTO. + */ + if (unlikely(hc->tx_t_rto == 0)) /* no feedback received yet */ + t_nfb = TFRC_INITIAL_TIMEOUT; + else + t_nfb = max(hc->tx_t_rto, 2 * hc->tx_t_ipi); + +restart_timer: + sk_reset_timer(sk, &hc->tx_no_feedback_timer, + jiffies + usecs_to_jiffies(t_nfb)); +out: + bh_unlock_sock(sk); + sock_put(sk); +} + +/** + * ccid3_hc_tx_send_packet - Delay-based dequeueing of TX packets + * @sk: socket to send packet from + * @skb: next packet candidate to send on @sk + * + * This function uses the convention of ccid_packet_dequeue_eval() and + * returns a millisecond-delay value between 0 and t_mbi = 64000 msec. + */ +static int ccid3_hc_tx_send_packet(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + ktime_t now = ktime_get_real(); + s64 delay; + + /* + * This function is called only for Data and DataAck packets. Sending + * zero-sized Data(Ack)s is theoretically possible, but for congestion + * control this case is pathological - ignore it. + */ + if (unlikely(skb->len == 0)) + return -EBADMSG; + + if (hc->tx_state == TFRC_SSTATE_NO_SENT) { + sk_reset_timer(sk, &hc->tx_no_feedback_timer, (jiffies + + usecs_to_jiffies(TFRC_INITIAL_TIMEOUT))); + hc->tx_last_win_count = 0; + hc->tx_t_last_win_count = now; + + /* Set t_0 for initial packet */ + hc->tx_t_nom = now; + + hc->tx_s = skb->len; + + /* + * Use initial RTT sample when available: recommended by erratum + * to RFC 4342. This implements the initialisation procedure of + * draft rfc3448bis, section 4.2. Remember, X is scaled by 2^6. + */ + if (dp->dccps_syn_rtt) { + ccid3_pr_debug("SYN RTT = %uus\n", dp->dccps_syn_rtt); + hc->tx_rtt = dp->dccps_syn_rtt; + hc->tx_x = rfc3390_initial_rate(sk); + hc->tx_t_ld = now; + } else { + /* + * Sender does not have RTT sample: + * - set fallback RTT (RFC 4340, 3.4) since a RTT value + * is needed in several parts (e.g. window counter); + * - set sending rate X_pps = 1pps as per RFC 3448, 4.2. + */ + hc->tx_rtt = DCCP_FALLBACK_RTT; + hc->tx_x = hc->tx_s; + hc->tx_x <<= 6; + } + ccid3_update_send_interval(hc); + + ccid3_hc_tx_set_state(sk, TFRC_SSTATE_NO_FBACK); + + } else { + delay = ktime_us_delta(hc->tx_t_nom, now); + ccid3_pr_debug("delay=%ld\n", (long)delay); + /* + * Scheduling of packet transmissions (RFC 5348, 8.3) + * + * if (t_now > t_nom - delta) + * // send the packet now + * else + * // send the packet in (t_nom - t_now) milliseconds. + */ + if (delay >= TFRC_T_DELTA) + return (u32)delay / USEC_PER_MSEC; + + ccid3_hc_tx_update_win_count(hc, now); + } + + /* prepare to send now (add options etc.) 
*/ + dp->dccps_hc_tx_insert_options = 1; + DCCP_SKB_CB(skb)->dccpd_ccval = hc->tx_last_win_count; + + /* set the nominal send time for the next following packet */ + hc->tx_t_nom = ktime_add_us(hc->tx_t_nom, hc->tx_t_ipi); + return CCID_PACKET_SEND_AT_ONCE; +} + +static void ccid3_hc_tx_packet_sent(struct sock *sk, unsigned int len) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + + ccid3_hc_tx_update_s(hc, len); + + if (tfrc_tx_hist_add(&hc->tx_hist, dccp_sk(sk)->dccps_gss)) + DCCP_CRIT("packet history - out of memory!"); +} + +static void ccid3_hc_tx_packet_recv(struct sock *sk, struct sk_buff *skb) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + struct tfrc_tx_hist_entry *acked; + ktime_t now; + unsigned long t_nfb; + u32 r_sample; + + /* we are only interested in ACKs */ + if (!(DCCP_SKB_CB(skb)->dccpd_type == DCCP_PKT_ACK || + DCCP_SKB_CB(skb)->dccpd_type == DCCP_PKT_DATAACK)) + return; + /* + * Locate the acknowledged packet in the TX history. + * + * Returning "entry not found" here can for instance happen when + * - the host has not sent out anything (e.g. a passive server), + * - the Ack is outdated (packet with higher Ack number was received), + * - it is a bogus Ack (for a packet not sent on this connection). + */ + acked = tfrc_tx_hist_find_entry(hc->tx_hist, dccp_hdr_ack_seq(skb)); + if (acked == NULL) + return; + /* For the sake of RTT sampling, ignore/remove all older entries */ + tfrc_tx_hist_purge(&acked->next); + + /* Update the moving average for the RTT estimate (RFC 3448, 4.3) */ + now = ktime_get_real(); + r_sample = dccp_sample_rtt(sk, ktime_us_delta(now, acked->stamp)); + hc->tx_rtt = tfrc_ewma(hc->tx_rtt, r_sample, 9); + + /* + * Update allowed sending rate X as per draft rfc3448bis-00, 4.2/3 + */ + if (hc->tx_state == TFRC_SSTATE_NO_FBACK) { + ccid3_hc_tx_set_state(sk, TFRC_SSTATE_FBACK); + + if (hc->tx_t_rto == 0) { + /* + * Initial feedback packet: Larger Initial Windows (4.2) + */ + hc->tx_x = rfc3390_initial_rate(sk); + hc->tx_t_ld = now; + + ccid3_update_send_interval(hc); + + goto done_computing_x; + } else if (hc->tx_p == 0) { + /* + * First feedback after nofeedback timer expiry (4.3) + */ + goto done_computing_x; + } + } + + /* Update sending rate (step 4 of [RFC 3448, 4.3]) */ + if (hc->tx_p > 0) + hc->tx_x_calc = tfrc_calc_x(hc->tx_s, hc->tx_rtt, hc->tx_p); + ccid3_hc_tx_update_x(sk, &now); + +done_computing_x: + ccid3_pr_debug("%s(%p), RTT=%uus (sample=%uus), s=%u, " + "p=%u, X_calc=%u, X_recv=%u, X=%u\n", + dccp_role(sk), sk, hc->tx_rtt, r_sample, + hc->tx_s, hc->tx_p, hc->tx_x_calc, + (unsigned int)(hc->tx_x_recv >> 6), + (unsigned int)(hc->tx_x >> 6)); + + /* unschedule no feedback timer */ + sk_stop_timer(sk, &hc->tx_no_feedback_timer); + + /* + * As we have calculated new ipi, delta, t_nom it is possible + * that we now can send a packet, so wake up dccp_wait_for_ccid + */ + sk->sk_write_space(sk); + + /* + * Update timeout interval for the nofeedback timer. In order to control + * rate halving on networks with very low RTTs (<= 1 ms), use per-route + * tunable RTAX_RTO_MIN value as the lower bound. 
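+ * Note that tcp_rto_min() returns a value in jiffies, hence the
+ * USEC_PER_SEC/HZ factor below to convert it to microseconds, the unit
+ * of tx_rtt and tx_t_rto.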
+ */ + hc->tx_t_rto = max_t(u32, 4 * hc->tx_rtt, + USEC_PER_SEC/HZ * tcp_rto_min(sk)); + /* + * Schedule no feedback timer to expire in + * max(t_RTO, 2 * s/X) = max(t_RTO, 2 * t_ipi) + */ + t_nfb = max(hc->tx_t_rto, 2 * hc->tx_t_ipi); + + ccid3_pr_debug("%s(%p), Scheduled no feedback timer to " + "expire in %lu jiffies (%luus)\n", + dccp_role(sk), sk, usecs_to_jiffies(t_nfb), t_nfb); + + sk_reset_timer(sk, &hc->tx_no_feedback_timer, + jiffies + usecs_to_jiffies(t_nfb)); +} + +static int ccid3_hc_tx_parse_options(struct sock *sk, u8 packet_type, + u8 option, u8 *optval, u8 optlen) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + __be32 opt_val; + + switch (option) { + case TFRC_OPT_RECEIVE_RATE: + case TFRC_OPT_LOSS_EVENT_RATE: + /* Must be ignored on Data packets, cf. RFC 4342 8.3 and 8.5 */ + if (packet_type == DCCP_PKT_DATA) + break; + if (unlikely(optlen != 4)) { + DCCP_WARN("%s(%p), invalid len %d for %u\n", + dccp_role(sk), sk, optlen, option); + return -EINVAL; + } + opt_val = ntohl(get_unaligned((__be32 *)optval)); + + if (option == TFRC_OPT_RECEIVE_RATE) { + /* Receive Rate is kept in units of 64 bytes/second */ + hc->tx_x_recv = opt_val; + hc->tx_x_recv <<= 6; + + ccid3_pr_debug("%s(%p), RECEIVE_RATE=%u\n", + dccp_role(sk), sk, opt_val); + } else { + /* Update the fixpoint Loss Event Rate fraction */ + hc->tx_p = tfrc_invert_loss_event_rate(opt_val); + + ccid3_pr_debug("%s(%p), LOSS_EVENT_RATE=%u\n", + dccp_role(sk), sk, opt_val); + } + } + return 0; +} + +static int ccid3_hc_tx_init(struct ccid *ccid, struct sock *sk) +{ + struct ccid3_hc_tx_sock *hc = ccid_priv(ccid); + + hc->tx_state = TFRC_SSTATE_NO_SENT; + hc->tx_hist = NULL; + hc->sk = sk; + timer_setup(&hc->tx_no_feedback_timer, + ccid3_hc_tx_no_feedback_timer, 0); + return 0; +} + +static void ccid3_hc_tx_exit(struct sock *sk) +{ + struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + + sk_stop_timer(sk, &hc->tx_no_feedback_timer); + tfrc_tx_hist_purge(&hc->tx_hist); +} + +static void ccid3_hc_tx_get_info(struct sock *sk, struct tcp_info *info) +{ + info->tcpi_rto = ccid3_hc_tx_sk(sk)->tx_t_rto; + info->tcpi_rtt = ccid3_hc_tx_sk(sk)->tx_rtt; +} + +static int ccid3_hc_tx_getsockopt(struct sock *sk, const int optname, int len, + u32 __user *optval, int __user *optlen) +{ + const struct ccid3_hc_tx_sock *hc = ccid3_hc_tx_sk(sk); + struct tfrc_tx_info tfrc; + const void *val; + + switch (optname) { + case DCCP_SOCKOPT_CCID_TX_INFO: + if (len < sizeof(tfrc)) + return -EINVAL; + memset(&tfrc, 0, sizeof(tfrc)); + tfrc.tfrctx_x = hc->tx_x; + tfrc.tfrctx_x_recv = hc->tx_x_recv; + tfrc.tfrctx_x_calc = hc->tx_x_calc; + tfrc.tfrctx_rtt = hc->tx_rtt; + tfrc.tfrctx_p = hc->tx_p; + tfrc.tfrctx_rto = hc->tx_t_rto; + tfrc.tfrctx_ipi = hc->tx_t_ipi; + len = sizeof(tfrc); + val = &tfrc; + break; + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen) || copy_to_user(optval, val, len)) + return -EFAULT; + + return 0; +} + +/* + * Receiver Half-Connection Routines + */ + +/* CCID3 feedback types */ +enum ccid3_fback_type { + CCID3_FBACK_NONE = 0, + CCID3_FBACK_INITIAL, + CCID3_FBACK_PERIODIC, + CCID3_FBACK_PARAM_CHANGE +}; + +#ifdef CONFIG_IP_DCCP_CCID3_DEBUG +static const char *ccid3_rx_state_name(enum ccid3_hc_rx_states state) +{ + static const char *const ccid3_rx_state_names[] = { + [TFRC_RSTATE_NO_DATA] = "NO_DATA", + [TFRC_RSTATE_DATA] = "DATA", + }; + + return ccid3_rx_state_names[state]; +} +#endif + +static void ccid3_hc_rx_set_state(struct sock *sk, + enum ccid3_hc_rx_states state) +{ + struct ccid3_hc_rx_sock *hc 
= ccid3_hc_rx_sk(sk); + enum ccid3_hc_rx_states oldstate = hc->rx_state; + + ccid3_pr_debug("%s(%p) %-8.8s -> %s\n", + dccp_role(sk), sk, ccid3_rx_state_name(oldstate), + ccid3_rx_state_name(state)); + WARN_ON(state == oldstate); + hc->rx_state = state; +} + +static void ccid3_hc_rx_send_feedback(struct sock *sk, + const struct sk_buff *skb, + enum ccid3_fback_type fbtype) +{ + struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + struct dccp_sock *dp = dccp_sk(sk); + ktime_t now = ktime_get(); + s64 delta = 0; + + switch (fbtype) { + case CCID3_FBACK_INITIAL: + hc->rx_x_recv = 0; + hc->rx_pinv = ~0U; /* see RFC 4342, 8.5 */ + break; + case CCID3_FBACK_PARAM_CHANGE: + /* + * When parameters change (new loss or p > p_prev), we do not + * have a reliable estimate for R_m of [RFC 3448, 6.2] and so + * need to reuse the previous value of X_recv. However, when + * X_recv was 0 (due to early loss), this would kill X down to + * s/t_mbi (i.e. one packet in 64 seconds). + * To avoid such drastic reduction, we approximate X_recv as + * the number of bytes since last feedback. + * This is a safe fallback, since X is bounded above by X_calc. + */ + if (hc->rx_x_recv > 0) + break; + fallthrough; + case CCID3_FBACK_PERIODIC: + delta = ktime_us_delta(now, hc->rx_tstamp_last_feedback); + if (delta <= 0) + delta = 1; + hc->rx_x_recv = scaled_div32(hc->rx_bytes_recv, delta); + break; + default: + return; + } + + ccid3_pr_debug("Interval %lldusec, X_recv=%u, 1/p=%u\n", delta, + hc->rx_x_recv, hc->rx_pinv); + + hc->rx_tstamp_last_feedback = now; + hc->rx_last_counter = dccp_hdr(skb)->dccph_ccval; + hc->rx_bytes_recv = 0; + + dp->dccps_hc_rx_insert_options = 1; + dccp_send_ack(sk); +} + +static int ccid3_hc_rx_insert_options(struct sock *sk, struct sk_buff *skb) +{ + const struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + __be32 x_recv, pinv; + + if (!(sk->sk_state == DCCP_OPEN || sk->sk_state == DCCP_PARTOPEN)) + return 0; + + if (dccp_packet_without_ack(skb)) + return 0; + + x_recv = htonl(hc->rx_x_recv); + pinv = htonl(hc->rx_pinv); + + if (dccp_insert_option(skb, TFRC_OPT_LOSS_EVENT_RATE, + &pinv, sizeof(pinv)) || + dccp_insert_option(skb, TFRC_OPT_RECEIVE_RATE, + &x_recv, sizeof(x_recv))) + return -1; + + return 0; +} + +/** + * ccid3_first_li - Implements [RFC 5348, 6.3.1] + * @sk: socket to calculate loss interval for + * + * Determine the length of the first loss interval via inverse lookup. + * Assume that X_recv can be computed by the throughput equation + * s + * X_recv = -------- + * R * fval + * Find some p such that f(p) = fval; return 1/p (scaled). + */ +static u32 ccid3_first_li(struct sock *sk) +{ + struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + u32 x_recv, p; + s64 delta; + u64 fval; + + if (hc->rx_rtt == 0) { + DCCP_WARN("No RTT estimate available, using fallback RTT\n"); + hc->rx_rtt = DCCP_FALLBACK_RTT; + } + + delta = ktime_us_delta(ktime_get(), hc->rx_tstamp_last_feedback); + if (delta <= 0) + delta = 1; + x_recv = scaled_div32(hc->rx_bytes_recv, delta); + if (x_recv == 0) { /* would also trigger divide-by-zero */ + DCCP_WARN("X_recv==0\n"); + if (hc->rx_x_recv == 0) { + DCCP_BUG("stored value of X_recv is zero"); + return ~0U; + } + x_recv = hc->rx_x_recv; + } + + fval = scaled_div(hc->rx_s, hc->rx_rtt); + fval = scaled_div32(fval, x_recv); + p = tfrc_calc_x_reverse_lookup(fval); + + ccid3_pr_debug("%s(%p), receive rate=%u bytes/s, implied " + "loss rate=%u\n", dccp_role(sk), sk, x_recv, p); + + return p == 0 ? 
~0U : scaled_div(1, p); +} + +static void ccid3_hc_rx_packet_recv(struct sock *sk, struct sk_buff *skb) +{ + struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + enum ccid3_fback_type do_feedback = CCID3_FBACK_NONE; + const u64 ndp = dccp_sk(sk)->dccps_options_received.dccpor_ndp; + const bool is_data_packet = dccp_data_packet(skb); + + if (unlikely(hc->rx_state == TFRC_RSTATE_NO_DATA)) { + if (is_data_packet) { + const u32 payload = skb->len - dccp_hdr(skb)->dccph_doff * 4; + do_feedback = CCID3_FBACK_INITIAL; + ccid3_hc_rx_set_state(sk, TFRC_RSTATE_DATA); + hc->rx_s = payload; + /* + * Not necessary to update rx_bytes_recv here, + * since X_recv = 0 for the first feedback packet (cf. + * RFC 3448, 6.3) -- gerrit + */ + } + goto update_records; + } + + if (tfrc_rx_hist_duplicate(&hc->rx_hist, skb)) + return; /* done receiving */ + + if (is_data_packet) { + const u32 payload = skb->len - dccp_hdr(skb)->dccph_doff * 4; + /* + * Update moving-average of s and the sum of received payload bytes + */ + hc->rx_s = tfrc_ewma(hc->rx_s, payload, 9); + hc->rx_bytes_recv += payload; + } + + /* + * Perform loss detection and handle pending losses + */ + if (tfrc_rx_handle_loss(&hc->rx_hist, &hc->rx_li_hist, + skb, ndp, ccid3_first_li, sk)) { + do_feedback = CCID3_FBACK_PARAM_CHANGE; + goto done_receiving; + } + + if (tfrc_rx_hist_loss_pending(&hc->rx_hist)) + return; /* done receiving */ + + /* + * Handle data packets: RTT sampling and monitoring p + */ + if (unlikely(!is_data_packet)) + goto update_records; + + if (!tfrc_lh_is_initialised(&hc->rx_li_hist)) { + const u32 sample = tfrc_rx_hist_sample_rtt(&hc->rx_hist, skb); + /* + * Empty loss history: no loss so far, hence p stays 0. + * Sample RTT values, since an RTT estimate is required for the + * computation of p when the first loss occurs; RFC 3448, 6.3.1. + */ + if (sample != 0) + hc->rx_rtt = tfrc_ewma(hc->rx_rtt, sample, 9); + + } else if (tfrc_lh_update_i_mean(&hc->rx_li_hist, skb)) { + /* + * Step (3) of [RFC 3448, 6.1]: Recompute I_mean and, if I_mean + * has decreased (resp. p has increased), send feedback now. 
+ */ + do_feedback = CCID3_FBACK_PARAM_CHANGE; + } + + /* + * Check if the periodic once-per-RTT feedback is due; RFC 4342, 10.3 + */ + if (SUB16(dccp_hdr(skb)->dccph_ccval, hc->rx_last_counter) > 3) + do_feedback = CCID3_FBACK_PERIODIC; + +update_records: + tfrc_rx_hist_add_packet(&hc->rx_hist, skb, ndp); + +done_receiving: + if (do_feedback) + ccid3_hc_rx_send_feedback(sk, skb, do_feedback); +} + +static int ccid3_hc_rx_init(struct ccid *ccid, struct sock *sk) +{ + struct ccid3_hc_rx_sock *hc = ccid_priv(ccid); + + hc->rx_state = TFRC_RSTATE_NO_DATA; + tfrc_lh_init(&hc->rx_li_hist); + return tfrc_rx_hist_alloc(&hc->rx_hist); +} + +static void ccid3_hc_rx_exit(struct sock *sk) +{ + struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + + tfrc_rx_hist_purge(&hc->rx_hist); + tfrc_lh_cleanup(&hc->rx_li_hist); +} + +static void ccid3_hc_rx_get_info(struct sock *sk, struct tcp_info *info) +{ + info->tcpi_ca_state = ccid3_hc_rx_sk(sk)->rx_state; + info->tcpi_options |= TCPI_OPT_TIMESTAMPS; + info->tcpi_rcv_rtt = ccid3_hc_rx_sk(sk)->rx_rtt; +} + +static int ccid3_hc_rx_getsockopt(struct sock *sk, const int optname, int len, + u32 __user *optval, int __user *optlen) +{ + const struct ccid3_hc_rx_sock *hc = ccid3_hc_rx_sk(sk); + struct tfrc_rx_info rx_info; + const void *val; + + switch (optname) { + case DCCP_SOCKOPT_CCID_RX_INFO: + if (len < sizeof(rx_info)) + return -EINVAL; + rx_info.tfrcrx_x_recv = hc->rx_x_recv; + rx_info.tfrcrx_rtt = hc->rx_rtt; + rx_info.tfrcrx_p = tfrc_invert_loss_event_rate(hc->rx_pinv); + len = sizeof(rx_info); + val = &rx_info; + break; + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen) || copy_to_user(optval, val, len)) + return -EFAULT; + + return 0; +} + +struct ccid_operations ccid3_ops = { + .ccid_id = DCCPC_CCID3, + .ccid_name = "TCP-Friendly Rate Control", + .ccid_hc_tx_obj_size = sizeof(struct ccid3_hc_tx_sock), + .ccid_hc_tx_init = ccid3_hc_tx_init, + .ccid_hc_tx_exit = ccid3_hc_tx_exit, + .ccid_hc_tx_send_packet = ccid3_hc_tx_send_packet, + .ccid_hc_tx_packet_sent = ccid3_hc_tx_packet_sent, + .ccid_hc_tx_packet_recv = ccid3_hc_tx_packet_recv, + .ccid_hc_tx_parse_options = ccid3_hc_tx_parse_options, + .ccid_hc_rx_obj_size = sizeof(struct ccid3_hc_rx_sock), + .ccid_hc_rx_init = ccid3_hc_rx_init, + .ccid_hc_rx_exit = ccid3_hc_rx_exit, + .ccid_hc_rx_insert_options = ccid3_hc_rx_insert_options, + .ccid_hc_rx_packet_recv = ccid3_hc_rx_packet_recv, + .ccid_hc_rx_get_info = ccid3_hc_rx_get_info, + .ccid_hc_tx_get_info = ccid3_hc_tx_get_info, + .ccid_hc_rx_getsockopt = ccid3_hc_rx_getsockopt, + .ccid_hc_tx_getsockopt = ccid3_hc_tx_getsockopt, +}; + +#ifdef CONFIG_IP_DCCP_CCID3_DEBUG +module_param(ccid3_debug, bool, 0644); +MODULE_PARM_DESC(ccid3_debug, "Enable CCID-3 debug messages"); +#endif diff --git a/net/dccp/ccids/ccid3.h b/net/dccp/ccids/ccid3.h new file mode 100644 index 000000000..02e0fc9f6 --- /dev/null +++ b/net/dccp/ccids/ccid3.h @@ -0,0 +1,148 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (c) 2005-7 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * + * An implementation of the DCCP protocol + * + * This code has been developed by the University of Waikato WAND + * research group. 
For further information please see https://www.wand.net.nz/ + * or e-mail Ian McDonald - ian.mcdonald@jandi.co.nz + * + * This code also uses code from Lulea University, rereleased as GPL by its + * authors: + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + * + * Changes to meet Linux coding standards, to make it meet latest ccid3 draft + * and to make it work as a loadable module in the DCCP stack written by + * Arnaldo Carvalho de Melo . + * + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ +#ifndef _DCCP_CCID3_H_ +#define _DCCP_CCID3_H_ + +#include +#include +#include +#include +#include "lib/tfrc.h" +#include "../ccid.h" + +/* Two seconds as per RFC 5348, 4.2 */ +#define TFRC_INITIAL_TIMEOUT (2 * USEC_PER_SEC) + +/* Parameter t_mbi from [RFC 3448, 4.3]: backoff interval in seconds */ +#define TFRC_T_MBI 64 + +/* + * The t_delta parameter (RFC 5348, 8.3): delays of less than %USEC_PER_MSEC are + * rounded down to 0, since sk_reset_timer() here uses millisecond granularity. + * Hence we can use a constant t_delta = %USEC_PER_MSEC when HZ >= 500. A coarse + * resolution of HZ < 500 means that the error is below one timer tick (t_gran) + * when using the constant t_delta = t_gran / 2 = %USEC_PER_SEC / (2 * HZ). + */ +#if (HZ >= 500) +# define TFRC_T_DELTA USEC_PER_MSEC +#else +# define TFRC_T_DELTA (USEC_PER_SEC / (2 * HZ)) +#endif + +enum ccid3_options { + TFRC_OPT_LOSS_EVENT_RATE = 192, + TFRC_OPT_LOSS_INTERVALS = 193, + TFRC_OPT_RECEIVE_RATE = 194, +}; + +/* TFRC sender states */ +enum ccid3_hc_tx_states { + TFRC_SSTATE_NO_SENT = 1, + TFRC_SSTATE_NO_FBACK, + TFRC_SSTATE_FBACK, +}; + +/** + * struct ccid3_hc_tx_sock - CCID3 sender half-connection socket + * @tx_x: Current sending rate in 64 * bytes per second + * @tx_x_recv: Receive rate in 64 * bytes per second + * @tx_x_calc: Calculated rate in bytes per second + * @tx_rtt: Estimate of current round trip time in usecs + * @tx_p: Current loss event rate (0-1) scaled by 1000000 + * @tx_s: Packet size in bytes + * @tx_t_rto: Nofeedback Timer setting in usecs + * @tx_t_ipi: Interpacket (send) interval (RFC 3448, 4.6) in usecs + * @tx_state: Sender state, one of %ccid3_hc_tx_states + * @tx_last_win_count: Last window counter sent + * @tx_t_last_win_count: Timestamp of earliest packet + * with last_win_count value sent + * @tx_no_feedback_timer: Handle to no feedback timer + * @tx_t_ld: Time last doubled during slow start + * @tx_t_nom: Nominal send time of next packet + * @tx_hist: Packet history + */ +struct ccid3_hc_tx_sock { + u64 tx_x; + u64 tx_x_recv; + u32 tx_x_calc; + u32 tx_rtt; + u32 tx_p; + u32 tx_t_rto; + u32 tx_t_ipi; + u16 tx_s; + enum ccid3_hc_tx_states tx_state:8; + u8 tx_last_win_count; + ktime_t tx_t_last_win_count; + struct timer_list tx_no_feedback_timer; + struct sock *sk; + ktime_t tx_t_ld; + ktime_t tx_t_nom; + struct tfrc_tx_hist_entry *tx_hist; +}; + +static inline struct ccid3_hc_tx_sock *ccid3_hc_tx_sk(const struct sock *sk) +{ + struct ccid3_hc_tx_sock *hctx = ccid_priv(dccp_sk(sk)->dccps_hc_tx_ccid); + BUG_ON(hctx == NULL); + return hctx; +} + +/* TFRC receiver states */ +enum ccid3_hc_rx_states { + TFRC_RSTATE_NO_DATA = 1, + TFRC_RSTATE_DATA, +}; + +/** + * struct ccid3_hc_rx_sock - CCID3 receiver half-connection socket + * @rx_last_counter: Tracks window counter (RFC 4342, 8.1) + * @rx_state: Receiver state, one of %ccid3_hc_rx_states + * @rx_bytes_recv: Total sum of DCCP payload bytes + * @rx_x_recv: Receiver estimate of send rate (RFC 3448, sec. 
4.3) + * @rx_rtt: Receiver estimate of RTT + * @rx_tstamp_last_feedback: Time at which last feedback was sent + * @rx_hist: Packet history (loss detection + RTT sampling) + * @rx_li_hist: Loss Interval database + * @rx_s: Received packet size in bytes + * @rx_pinv: Inverse of Loss Event Rate (RFC 4342, sec. 8.5) + */ +struct ccid3_hc_rx_sock { + u8 rx_last_counter:4; + enum ccid3_hc_rx_states rx_state:8; + u32 rx_bytes_recv; + u32 rx_x_recv; + u32 rx_rtt; + ktime_t rx_tstamp_last_feedback; + struct tfrc_rx_hist rx_hist; + struct tfrc_loss_hist rx_li_hist; + u16 rx_s; +#define rx_pinv rx_li_hist.i_mean +}; + +static inline struct ccid3_hc_rx_sock *ccid3_hc_rx_sk(const struct sock *sk) +{ + struct ccid3_hc_rx_sock *hcrx = ccid_priv(dccp_sk(sk)->dccps_hc_rx_ccid); + BUG_ON(hcrx == NULL); + return hcrx; +} + +#endif /* _DCCP_CCID3_H_ */ diff --git a/net/dccp/ccids/lib/loss_interval.c b/net/dccp/ccids/lib/loss_interval.c new file mode 100644 index 000000000..da9531984 --- /dev/null +++ b/net/dccp/ccids/lib/loss_interval.c @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-7 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2005-7 Ian McDonald + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ +#include +#include "tfrc.h" + +static struct kmem_cache *tfrc_lh_slab __read_mostly; +/* Loss Interval weights from [RFC 3448, 5.4], scaled by 10 */ +static const int tfrc_lh_weights[NINTERVAL] = { 10, 10, 10, 10, 8, 6, 4, 2 }; + +/* implements LIFO semantics on the array */ +static inline u8 LIH_INDEX(const u8 ctr) +{ + return LIH_SIZE - 1 - (ctr % LIH_SIZE); +} + +/* the `counter' index always points at the next entry to be populated */ +static inline struct tfrc_loss_interval *tfrc_lh_peek(struct tfrc_loss_hist *lh) +{ + return lh->counter ? 
lh->ring[LIH_INDEX(lh->counter - 1)] : NULL; +} + +/* given i with 0 <= i <= k, return I_i as per the rfc3448bis notation */ +static inline u32 tfrc_lh_get_interval(struct tfrc_loss_hist *lh, const u8 i) +{ + BUG_ON(i >= lh->counter); + return lh->ring[LIH_INDEX(lh->counter - i - 1)]->li_length; +} + +/* + * On-demand allocation and de-allocation of entries + */ +static struct tfrc_loss_interval *tfrc_lh_demand_next(struct tfrc_loss_hist *lh) +{ + if (lh->ring[LIH_INDEX(lh->counter)] == NULL) + lh->ring[LIH_INDEX(lh->counter)] = kmem_cache_alloc(tfrc_lh_slab, + GFP_ATOMIC); + return lh->ring[LIH_INDEX(lh->counter)]; +} + +void tfrc_lh_cleanup(struct tfrc_loss_hist *lh) +{ + if (!tfrc_lh_is_initialised(lh)) + return; + + for (lh->counter = 0; lh->counter < LIH_SIZE; lh->counter++) + if (lh->ring[LIH_INDEX(lh->counter)] != NULL) { + kmem_cache_free(tfrc_lh_slab, + lh->ring[LIH_INDEX(lh->counter)]); + lh->ring[LIH_INDEX(lh->counter)] = NULL; + } +} + +static void tfrc_lh_calc_i_mean(struct tfrc_loss_hist *lh) +{ + u32 i_i, i_tot0 = 0, i_tot1 = 0, w_tot = 0; + int i, k = tfrc_lh_length(lh) - 1; /* k is as in rfc3448bis, 5.4 */ + + if (k <= 0) + return; + + for (i = 0; i <= k; i++) { + i_i = tfrc_lh_get_interval(lh, i); + + if (i < k) { + i_tot0 += i_i * tfrc_lh_weights[i]; + w_tot += tfrc_lh_weights[i]; + } + if (i > 0) + i_tot1 += i_i * tfrc_lh_weights[i-1]; + } + + lh->i_mean = max(i_tot0, i_tot1) / w_tot; +} + +/** + * tfrc_lh_update_i_mean - Update the `open' loss interval I_0 + * @lh: histogram to update + * @skb: received socket triggering loss interval update + * + * For recomputing p: returns `true' if p > p_prev <=> 1/p < 1/p_prev + */ +u8 tfrc_lh_update_i_mean(struct tfrc_loss_hist *lh, struct sk_buff *skb) +{ + struct tfrc_loss_interval *cur = tfrc_lh_peek(lh); + u32 old_i_mean = lh->i_mean; + s64 len; + + if (cur == NULL) /* not initialised */ + return 0; + + len = dccp_delta_seqno(cur->li_seqno, DCCP_SKB_CB(skb)->dccpd_seq) + 1; + + if (len - (s64)cur->li_length <= 0) /* duplicate or reordered */ + return 0; + + if (SUB16(dccp_hdr(skb)->dccph_ccval, cur->li_ccval) > 4) + /* + * Implements RFC 4342, 10.2: + * If a packet S (skb) exists whose seqno comes `after' the one + * starting the current loss interval (cur) and if the modulo-16 + * distance from C(cur) to C(S) is greater than 4, consider all + * subsequent packets as belonging to a new loss interval. This + * test is necessary since CCVal may wrap between intervals. + */ + cur->li_is_closed = 1; + + if (tfrc_lh_length(lh) == 1) /* due to RFC 3448, 6.3.1 */ + return 0; + + cur->li_length = len; + tfrc_lh_calc_i_mean(lh); + + return lh->i_mean < old_i_mean; +} + +/* Determine if `new_loss' does begin a new loss interval [RFC 4342, 10.2] */ +static inline u8 tfrc_lh_is_new_loss(struct tfrc_loss_interval *cur, + struct tfrc_rx_hist_entry *new_loss) +{ + return dccp_delta_seqno(cur->li_seqno, new_loss->tfrchrx_seqno) > 0 && + (cur->li_is_closed || SUB16(new_loss->tfrchrx_ccval, cur->li_ccval) > 4); +} + +/** + * tfrc_lh_interval_add - Insert new record into the Loss Interval database + * @lh: Loss Interval database + * @rh: Receive history containing a fresh loss event + * @calc_first_li: Caller-dependent routine to compute length of first interval + * @sk: Used by @calc_first_li in caller-specific way (subtyping) + * + * Updates I_mean and returns 1 if a new interval has in fact been added to @lh. 
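+ * Returns 0 if the loss still belongs to the current (open) interval or
+ * if no new record could be allocated.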
+ */ +int tfrc_lh_interval_add(struct tfrc_loss_hist *lh, struct tfrc_rx_hist *rh, + u32 (*calc_first_li)(struct sock *), struct sock *sk) +{ + struct tfrc_loss_interval *cur = tfrc_lh_peek(lh), *new; + + if (cur != NULL && !tfrc_lh_is_new_loss(cur, tfrc_rx_hist_loss_prev(rh))) + return 0; + + new = tfrc_lh_demand_next(lh); + if (unlikely(new == NULL)) { + DCCP_CRIT("Cannot allocate/add loss record."); + return 0; + } + + new->li_seqno = tfrc_rx_hist_loss_prev(rh)->tfrchrx_seqno; + new->li_ccval = tfrc_rx_hist_loss_prev(rh)->tfrchrx_ccval; + new->li_is_closed = 0; + + if (++lh->counter == 1) + lh->i_mean = new->li_length = (*calc_first_li)(sk); + else { + cur->li_length = dccp_delta_seqno(cur->li_seqno, new->li_seqno); + new->li_length = dccp_delta_seqno(new->li_seqno, + tfrc_rx_hist_last_rcv(rh)->tfrchrx_seqno) + 1; + if (lh->counter > (2*LIH_SIZE)) + lh->counter -= LIH_SIZE; + + tfrc_lh_calc_i_mean(lh); + } + return 1; +} + +int __init tfrc_li_init(void) +{ + tfrc_lh_slab = kmem_cache_create("tfrc_li_hist", + sizeof(struct tfrc_loss_interval), 0, + SLAB_HWCACHE_ALIGN, NULL); + return tfrc_lh_slab == NULL ? -ENOBUFS : 0; +} + +void tfrc_li_exit(void) +{ + if (tfrc_lh_slab != NULL) { + kmem_cache_destroy(tfrc_lh_slab); + tfrc_lh_slab = NULL; + } +} diff --git a/net/dccp/ccids/lib/loss_interval.h b/net/dccp/ccids/lib/loss_interval.h new file mode 100644 index 000000000..c3d95f85e --- /dev/null +++ b/net/dccp/ccids/lib/loss_interval.h @@ -0,0 +1,69 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +#ifndef _DCCP_LI_HIST_ +#define _DCCP_LI_HIST_ +/* + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-7 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2005-7 Ian McDonald + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ +#include +#include +#include + +/* + * Number of loss intervals (RFC 4342, 8.6.1). The history size is one more than + * NINTERVAL, since the `open' interval I_0 is always stored as the first entry. 
+ */ +#define NINTERVAL 8 +#define LIH_SIZE (NINTERVAL + 1) + +/** + * tfrc_loss_interval - Loss history record for TFRC-based protocols + * @li_seqno: Highest received seqno before the start of loss + * @li_ccval: The CCVal belonging to @li_seqno + * @li_is_closed: Whether @li_seqno is older than 1 RTT + * @li_length: Loss interval sequence length + */ +struct tfrc_loss_interval { + u64 li_seqno:48, + li_ccval:4, + li_is_closed:1; + u32 li_length; +}; + +/** + * tfrc_loss_hist - Loss record database + * @ring: Circular queue managed in LIFO manner + * @counter: Current count of entries (can be more than %LIH_SIZE) + * @i_mean: Current Average Loss Interval [RFC 3448, 5.4] + */ +struct tfrc_loss_hist { + struct tfrc_loss_interval *ring[LIH_SIZE]; + u8 counter; + u32 i_mean; +}; + +static inline void tfrc_lh_init(struct tfrc_loss_hist *lh) +{ + memset(lh, 0, sizeof(struct tfrc_loss_hist)); +} + +static inline u8 tfrc_lh_is_initialised(struct tfrc_loss_hist *lh) +{ + return lh->counter > 0; +} + +static inline u8 tfrc_lh_length(struct tfrc_loss_hist *lh) +{ + return min(lh->counter, (u8)LIH_SIZE); +} + +struct tfrc_rx_hist; + +int tfrc_lh_interval_add(struct tfrc_loss_hist *, struct tfrc_rx_hist *, + u32 (*first_li)(struct sock *), struct sock *); +u8 tfrc_lh_update_i_mean(struct tfrc_loss_hist *lh, struct sk_buff *); +void tfrc_lh_cleanup(struct tfrc_loss_hist *lh); + +#endif /* _DCCP_LI_HIST_ */ diff --git a/net/dccp/ccids/lib/packet_history.c b/net/dccp/ccids/lib/packet_history.c new file mode 100644 index 000000000..0cdda3c66 --- /dev/null +++ b/net/dccp/ccids/lib/packet_history.c @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-7 The University of Waikato, Hamilton, New Zealand. + * + * An implementation of the DCCP protocol + * + * This code has been developed by the University of Waikato WAND + * research group. For further information please see https://www.wand.net.nz/ + * or e-mail Ian McDonald - ian.mcdonald@jandi.co.nz + * + * This code also uses code from Lulea University, rereleased as GPL by its + * authors: + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + * + * Changes to meet Linux coding standards, to make it meet latest ccid3 draft + * and to make it work as a loadable module in the DCCP stack written by + * Arnaldo Carvalho de Melo . + * + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ + +#include +#include +#include "packet_history.h" +#include "../../dccp.h" + +/* + * Transmitter History Routines + */ +static struct kmem_cache *tfrc_tx_hist_slab; + +int __init tfrc_tx_packet_history_init(void) +{ + tfrc_tx_hist_slab = kmem_cache_create("tfrc_tx_hist", + sizeof(struct tfrc_tx_hist_entry), + 0, SLAB_HWCACHE_ALIGN, NULL); + return tfrc_tx_hist_slab == NULL ? 
-ENOBUFS : 0; +} + +void tfrc_tx_packet_history_exit(void) +{ + if (tfrc_tx_hist_slab != NULL) { + kmem_cache_destroy(tfrc_tx_hist_slab); + tfrc_tx_hist_slab = NULL; + } +} + +int tfrc_tx_hist_add(struct tfrc_tx_hist_entry **headp, u64 seqno) +{ + struct tfrc_tx_hist_entry *entry = kmem_cache_alloc(tfrc_tx_hist_slab, gfp_any()); + + if (entry == NULL) + return -ENOBUFS; + entry->seqno = seqno; + entry->stamp = ktime_get_real(); + entry->next = *headp; + *headp = entry; + return 0; +} + +void tfrc_tx_hist_purge(struct tfrc_tx_hist_entry **headp) +{ + struct tfrc_tx_hist_entry *head = *headp; + + while (head != NULL) { + struct tfrc_tx_hist_entry *next = head->next; + + kmem_cache_free(tfrc_tx_hist_slab, head); + head = next; + } + + *headp = NULL; +} + +/* + * Receiver History Routines + */ +static struct kmem_cache *tfrc_rx_hist_slab; + +int __init tfrc_rx_packet_history_init(void) +{ + tfrc_rx_hist_slab = kmem_cache_create("tfrc_rxh_cache", + sizeof(struct tfrc_rx_hist_entry), + 0, SLAB_HWCACHE_ALIGN, NULL); + return tfrc_rx_hist_slab == NULL ? -ENOBUFS : 0; +} + +void tfrc_rx_packet_history_exit(void) +{ + if (tfrc_rx_hist_slab != NULL) { + kmem_cache_destroy(tfrc_rx_hist_slab); + tfrc_rx_hist_slab = NULL; + } +} + +static inline void tfrc_rx_hist_entry_from_skb(struct tfrc_rx_hist_entry *entry, + const struct sk_buff *skb, + const u64 ndp) +{ + const struct dccp_hdr *dh = dccp_hdr(skb); + + entry->tfrchrx_seqno = DCCP_SKB_CB(skb)->dccpd_seq; + entry->tfrchrx_ccval = dh->dccph_ccval; + entry->tfrchrx_type = dh->dccph_type; + entry->tfrchrx_ndp = ndp; + entry->tfrchrx_tstamp = ktime_get_real(); +} + +void tfrc_rx_hist_add_packet(struct tfrc_rx_hist *h, + const struct sk_buff *skb, + const u64 ndp) +{ + struct tfrc_rx_hist_entry *entry = tfrc_rx_hist_last_rcv(h); + + tfrc_rx_hist_entry_from_skb(entry, skb, ndp); +} + +/* has the packet contained in skb been seen before? */ +int tfrc_rx_hist_duplicate(struct tfrc_rx_hist *h, struct sk_buff *skb) +{ + const u64 seq = DCCP_SKB_CB(skb)->dccpd_seq; + int i; + + if (dccp_delta_seqno(tfrc_rx_hist_loss_prev(h)->tfrchrx_seqno, seq) <= 0) + return 1; + + for (i = 1; i <= h->loss_count; i++) + if (tfrc_rx_hist_entry(h, i)->tfrchrx_seqno == seq) + return 1; + + return 0; +} + +static void tfrc_rx_hist_swap(struct tfrc_rx_hist *h, const u8 a, const u8 b) +{ + const u8 idx_a = tfrc_rx_hist_index(h, a), + idx_b = tfrc_rx_hist_index(h, b); + + swap(h->ring[idx_a], h->ring[idx_b]); +} + +/* + * Private helper functions for loss detection. + * + * In the descriptions, `Si' refers to the sequence number of entry number i, + * whose NDP count is `Ni' (lower case is used for variables). + * Note: All __xxx_loss functions expect that a test against duplicates has been + * performed already: the seqno of the skb must not be less than the seqno + * of loss_prev; and it must not equal that of any valid history entry. 
+ */ +static void __do_track_loss(struct tfrc_rx_hist *h, struct sk_buff *skb, u64 n1) +{ + u64 s0 = tfrc_rx_hist_loss_prev(h)->tfrchrx_seqno, + s1 = DCCP_SKB_CB(skb)->dccpd_seq; + + if (!dccp_loss_free(s0, s1, n1)) { /* gap between S0 and S1 */ + h->loss_count = 1; + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 1), skb, n1); + } +} + +static void __one_after_loss(struct tfrc_rx_hist *h, struct sk_buff *skb, u32 n2) +{ + u64 s0 = tfrc_rx_hist_loss_prev(h)->tfrchrx_seqno, + s1 = tfrc_rx_hist_entry(h, 1)->tfrchrx_seqno, + s2 = DCCP_SKB_CB(skb)->dccpd_seq; + + if (likely(dccp_delta_seqno(s1, s2) > 0)) { /* S1 < S2 */ + h->loss_count = 2; + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 2), skb, n2); + return; + } + + /* S0 < S2 < S1 */ + + if (dccp_loss_free(s0, s2, n2)) { + u64 n1 = tfrc_rx_hist_entry(h, 1)->tfrchrx_ndp; + + if (dccp_loss_free(s2, s1, n1)) { + /* hole is filled: S0, S2, and S1 are consecutive */ + h->loss_count = 0; + h->loss_start = tfrc_rx_hist_index(h, 1); + } else + /* gap between S2 and S1: just update loss_prev */ + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_loss_prev(h), skb, n2); + + } else { /* gap between S0 and S2 */ + /* + * Reorder history to insert S2 between S0 and S1 + */ + tfrc_rx_hist_swap(h, 0, 3); + h->loss_start = tfrc_rx_hist_index(h, 3); + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 1), skb, n2); + h->loss_count = 2; + } +} + +/* return 1 if a new loss event has been identified */ +static int __two_after_loss(struct tfrc_rx_hist *h, struct sk_buff *skb, u32 n3) +{ + u64 s0 = tfrc_rx_hist_loss_prev(h)->tfrchrx_seqno, + s1 = tfrc_rx_hist_entry(h, 1)->tfrchrx_seqno, + s2 = tfrc_rx_hist_entry(h, 2)->tfrchrx_seqno, + s3 = DCCP_SKB_CB(skb)->dccpd_seq; + + if (likely(dccp_delta_seqno(s2, s3) > 0)) { /* S2 < S3 */ + h->loss_count = 3; + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 3), skb, n3); + return 1; + } + + /* S3 < S2 */ + + if (dccp_delta_seqno(s1, s3) > 0) { /* S1 < S3 < S2 */ + /* + * Reorder history to insert S3 between S1 and S2 + */ + tfrc_rx_hist_swap(h, 2, 3); + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 2), skb, n3); + h->loss_count = 3; + return 1; + } + + /* S0 < S3 < S1 */ + + if (dccp_loss_free(s0, s3, n3)) { + u64 n1 = tfrc_rx_hist_entry(h, 1)->tfrchrx_ndp; + + if (dccp_loss_free(s3, s1, n1)) { + /* hole between S0 and S1 filled by S3 */ + u64 n2 = tfrc_rx_hist_entry(h, 2)->tfrchrx_ndp; + + if (dccp_loss_free(s1, s2, n2)) { + /* entire hole filled by S0, S3, S1, S2 */ + h->loss_start = tfrc_rx_hist_index(h, 2); + h->loss_count = 0; + } else { + /* gap remains between S1 and S2 */ + h->loss_start = tfrc_rx_hist_index(h, 1); + h->loss_count = 1; + } + + } else /* gap exists between S3 and S1, loss_count stays at 2 */ + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_loss_prev(h), skb, n3); + + return 0; + } + + /* + * The remaining case: S0 < S3 < S1 < S2; gap between S0 and S3 + * Reorder history to insert S3 between S0 and S1. + */ + tfrc_rx_hist_swap(h, 0, 3); + h->loss_start = tfrc_rx_hist_index(h, 3); + tfrc_rx_hist_entry_from_skb(tfrc_rx_hist_entry(h, 1), skb, n3); + h->loss_count = 3; + + return 1; +} + +/* recycle RX history records to continue loss detection if necessary */ +static void __three_after_loss(struct tfrc_rx_hist *h) +{ + /* + * At this stage we know already that there is a gap between S0 and S1 + * (since S0 was the highest sequence number received before detecting + * the loss). To recycle the loss record, it is thus only necessary to + * check for other possible gaps between S1/S2 and between S2/S3. 
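+ * Three outcomes are possible: both gaps are filled (loss_count = 0), a
+ * gap remains only between S2 and S3 (loss_count = 1), or a gap remains
+ * between S1 and S2 (loss_count = 2).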
+ */ + u64 s1 = tfrc_rx_hist_entry(h, 1)->tfrchrx_seqno, + s2 = tfrc_rx_hist_entry(h, 2)->tfrchrx_seqno, + s3 = tfrc_rx_hist_entry(h, 3)->tfrchrx_seqno; + u64 n2 = tfrc_rx_hist_entry(h, 2)->tfrchrx_ndp, + n3 = tfrc_rx_hist_entry(h, 3)->tfrchrx_ndp; + + if (dccp_loss_free(s1, s2, n2)) { + + if (dccp_loss_free(s2, s3, n3)) { + /* no gap between S2 and S3: entire hole is filled */ + h->loss_start = tfrc_rx_hist_index(h, 3); + h->loss_count = 0; + } else { + /* gap between S2 and S3 */ + h->loss_start = tfrc_rx_hist_index(h, 2); + h->loss_count = 1; + } + + } else { /* gap between S1 and S2 */ + h->loss_start = tfrc_rx_hist_index(h, 1); + h->loss_count = 2; + } +} + +/** + * tfrc_rx_handle_loss - Loss detection and further processing + * @h: The non-empty RX history object + * @lh: Loss Intervals database to update + * @skb: Currently received packet + * @ndp: The NDP count belonging to @skb + * @calc_first_li: Caller-dependent computation of first loss interval in @lh + * @sk: Used by @calc_first_li (see tfrc_lh_interval_add) + * + * Chooses action according to pending loss, updates LI database when a new + * loss was detected, and does required post-processing. Returns 1 when caller + * should send feedback, 0 otherwise. + * Since it also takes care of reordering during loss detection and updates the + * records accordingly, the caller should not perform any more RX history + * operations when loss_count is greater than 0 after calling this function. + */ +int tfrc_rx_handle_loss(struct tfrc_rx_hist *h, + struct tfrc_loss_hist *lh, + struct sk_buff *skb, const u64 ndp, + u32 (*calc_first_li)(struct sock *), struct sock *sk) +{ + int is_new_loss = 0; + + if (h->loss_count == 0) { + __do_track_loss(h, skb, ndp); + } else if (h->loss_count == 1) { + __one_after_loss(h, skb, ndp); + } else if (h->loss_count != 2) { + DCCP_BUG("invalid loss_count %d", h->loss_count); + } else if (__two_after_loss(h, skb, ndp)) { + /* + * Update Loss Interval database and recycle RX records + */ + is_new_loss = tfrc_lh_interval_add(lh, h, calc_first_li, sk); + __three_after_loss(h); + } + return is_new_loss; +} + +int tfrc_rx_hist_alloc(struct tfrc_rx_hist *h) +{ + int i; + + for (i = 0; i <= TFRC_NDUPACK; i++) { + h->ring[i] = kmem_cache_alloc(tfrc_rx_hist_slab, GFP_ATOMIC); + if (h->ring[i] == NULL) + goto out_free; + } + + h->loss_count = h->loss_start = 0; + return 0; + +out_free: + while (i-- != 0) { + kmem_cache_free(tfrc_rx_hist_slab, h->ring[i]); + h->ring[i] = NULL; + } + return -ENOBUFS; +} + +void tfrc_rx_hist_purge(struct tfrc_rx_hist *h) +{ + int i; + + for (i = 0; i <= TFRC_NDUPACK; ++i) + if (h->ring[i] != NULL) { + kmem_cache_free(tfrc_rx_hist_slab, h->ring[i]); + h->ring[i] = NULL; + } +} + +/** + * tfrc_rx_hist_rtt_last_s - reference entry to compute RTT samples against + * @h: The non-empty RX history object + */ +static inline struct tfrc_rx_hist_entry * + tfrc_rx_hist_rtt_last_s(const struct tfrc_rx_hist *h) +{ + return h->ring[0]; +} + +/** + * tfrc_rx_hist_rtt_prev_s - previously suitable (wrt rtt_last_s) RTT-sampling entry + * @h: The non-empty RX history object + */ +static inline struct tfrc_rx_hist_entry * + tfrc_rx_hist_rtt_prev_s(const struct tfrc_rx_hist *h) +{ + return h->ring[h->rtt_sample_prev]; +} + +/** + * tfrc_rx_hist_sample_rtt - Sample RTT from timestamp / CCVal + * @h: receive histogram + * @skb: packet containing timestamp. + * + * Based on ideas presented in RFC 4342, 8.1. 
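The receiver history manipulated above is a ring of TFRC_NDUPACK + 1 = 4 entries addressed relative to a movable loss_start, so the logical positions S0..S3 can be recycled by moving loss_start instead of copying entries. A simplified userspace sketch follows; the ring holds bare sequence numbers here instead of allocated history entries.

#include <stdio.h>
#include <stdint.h>

#define NDUPACK 3	/* the ring has NDUPACK + 1 = 4 slots */

struct rx_hist {
	uint64_t ring[NDUPACK + 1];	/* bare sequence numbers, for the sketch */
	unsigned int loss_start;	/* ring slot holding logical entry S0 */
	unsigned int loss_count;
};

/* map logical position S_n onto a ring slot, relative to loss_start */
static unsigned int rx_hist_index(const struct rx_hist *h, unsigned int n)
{
	return (h->loss_start + n) & NDUPACK;
}

int main(void)
{
	/* S0 = 10 was the last packet before the hole; seqno 11 is missing */
	struct rx_hist h = { .ring = { 10, 12, 13, 14 }, .loss_start = 0, .loss_count = 3 };
	unsigned int n;

	for (n = 0; n <= NDUPACK; n++)
		printf("S%u = %llu (slot %u)\n", n,
		       (unsigned long long)h.ring[rx_hist_index(&h, n)],
		       rx_hist_index(&h, n));

	/*
	 * Once the loss has been filed and S1..S3 turn out to be consecutive
	 * (cf. __three_after_loss() above), S3 becomes the new S0: recycling
	 * the record is just a matter of moving loss_start.
	 */
	h.loss_start = rx_hist_index(&h, 3);
	h.loss_count = 0;
	printf("new S0 = %llu\n", (unsigned long long)h.ring[rx_hist_index(&h, 0)]);
	return 0;
}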
Returns 0 if it was not able + * to compute a sample with given data - calling function should check this. + */ +u32 tfrc_rx_hist_sample_rtt(struct tfrc_rx_hist *h, const struct sk_buff *skb) +{ + u32 sample = 0, + delta_v = SUB16(dccp_hdr(skb)->dccph_ccval, + tfrc_rx_hist_rtt_last_s(h)->tfrchrx_ccval); + + if (delta_v < 1 || delta_v > 4) { /* unsuitable CCVal delta */ + if (h->rtt_sample_prev == 2) { /* previous candidate stored */ + sample = SUB16(tfrc_rx_hist_rtt_prev_s(h)->tfrchrx_ccval, + tfrc_rx_hist_rtt_last_s(h)->tfrchrx_ccval); + if (sample) + sample = 4 / sample * + ktime_us_delta(tfrc_rx_hist_rtt_prev_s(h)->tfrchrx_tstamp, + tfrc_rx_hist_rtt_last_s(h)->tfrchrx_tstamp); + else /* + * FIXME: This condition is in principle not + * possible but occurs when CCID is used for + * two-way data traffic. I have tried to trace + * it, but the cause does not seem to be here. + */ + DCCP_BUG("please report to dccp@vger.kernel.org" + " => prev = %u, last = %u", + tfrc_rx_hist_rtt_prev_s(h)->tfrchrx_ccval, + tfrc_rx_hist_rtt_last_s(h)->tfrchrx_ccval); + } else if (delta_v < 1) { + h->rtt_sample_prev = 1; + goto keep_ref_for_next_time; + } + + } else if (delta_v == 4) /* optimal match */ + sample = ktime_to_us(net_timedelta(tfrc_rx_hist_rtt_last_s(h)->tfrchrx_tstamp)); + else { /* suboptimal match */ + h->rtt_sample_prev = 2; + goto keep_ref_for_next_time; + } + + if (unlikely(sample > DCCP_SANE_RTT_MAX)) { + DCCP_WARN("RTT sample %u too large, using max\n", sample); + sample = DCCP_SANE_RTT_MAX; + } + + h->rtt_sample_prev = 0; /* use current entry as next reference */ +keep_ref_for_next_time: + + return sample; +} diff --git a/net/dccp/ccids/lib/packet_history.h b/net/dccp/ccids/lib/packet_history.h new file mode 100644 index 000000000..159cc9326 --- /dev/null +++ b/net/dccp/ccids/lib/packet_history.h @@ -0,0 +1,142 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Packet RX/TX history data structures and routines for TFRC-based protocols. + * + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-6 The University of Waikato, Hamilton, New Zealand. + * + * This code has been developed by the University of Waikato WAND + * research group. For further information please see https://www.wand.net.nz/ + * or e-mail Ian McDonald - ian.mcdonald@jandi.co.nz + * + * This code also uses code from Lulea University, rereleased as GPL by its + * authors: + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + * + * Changes to meet Linux coding standards, to make it meet latest ccid3 draft + * and to make it work as a loadable module in the DCCP stack written by + * Arnaldo Carvalho de Melo . 
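The RTT sampler above exploits the RFC 4342, 8.1 rule that the sender advances the CCVal window counter every quarter RTT: if the counter moved by delta (mod 16) between two packets received elapsed_us apart, the RTT is roughly 4/delta * elapsed_us. The userspace sketch below collapses the kernel's two-step handling (which defers deltas of 1..3 until a better reference packet shows up) into a single function and keeps the same coarse integer arithmetic; it is illustrative only.

#include <stdint.h>
#include <stdio.h>

#define SUB16(a, b)	(((a) + 16 - (b)) & 0xF)	/* mod-16 counter distance */

static unsigned int rtt_from_ccval(unsigned int ccval_new, unsigned int ccval_ref,
				   uint64_t elapsed_us)
{
	unsigned int delta = SUB16(ccval_new, ccval_ref);

	if (delta < 1 || delta > 4)	/* counters too close together or too far apart */
		return 0;		/* no usable sample, try other packets */

	return (unsigned int)(4 / delta * elapsed_us);
}

int main(void)
{
	/* counter advanced by 4 quarter-RTTs in 120 ms -> RTT estimate of 120 ms */
	printf("rtt = %u us\n", rtt_from_ccval(9, 5, 120000));
	/* counter wrapped around 15 -> 0: SUB16(1, 14) == 3, coarse 4/3 == 1 */
	printf("rtt = %u us\n", rtt_from_ccval(1, 14, 90000));
	return 0;
}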
+ * + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ + +#ifndef _DCCP_PKT_HIST_ +#define _DCCP_PKT_HIST_ + +#include +#include +#include "tfrc.h" + +/** + * tfrc_tx_hist_entry - Simple singly-linked TX history list + * @next: next oldest entry (LIFO order) + * @seqno: sequence number of this entry + * @stamp: send time of packet with sequence number @seqno + */ +struct tfrc_tx_hist_entry { + struct tfrc_tx_hist_entry *next; + u64 seqno; + ktime_t stamp; +}; + +static inline struct tfrc_tx_hist_entry * + tfrc_tx_hist_find_entry(struct tfrc_tx_hist_entry *head, u64 seqno) +{ + while (head != NULL && head->seqno != seqno) + head = head->next; + return head; +} + +int tfrc_tx_hist_add(struct tfrc_tx_hist_entry **headp, u64 seqno); +void tfrc_tx_hist_purge(struct tfrc_tx_hist_entry **headp); + +/* Subtraction a-b modulo-16, respects circular wrap-around */ +#define SUB16(a, b) (((a) + 16 - (b)) & 0xF) + +/* Number of packets to wait after a missing packet (RFC 4342, 6.1) */ +#define TFRC_NDUPACK 3 + +/** + * tfrc_rx_hist_entry - Store information about a single received packet + * @tfrchrx_seqno: DCCP packet sequence number + * @tfrchrx_ccval: window counter value of packet (RFC 4342, 8.1) + * @tfrchrx_ndp: the NDP count (if any) of the packet + * @tfrchrx_tstamp: actual receive time of packet + */ +struct tfrc_rx_hist_entry { + u64 tfrchrx_seqno:48, + tfrchrx_ccval:4, + tfrchrx_type:4; + u64 tfrchrx_ndp:48; + ktime_t tfrchrx_tstamp; +}; + +/** + * tfrc_rx_hist - RX history structure for TFRC-based protocols + * @ring: Packet history for RTT sampling and loss detection + * @loss_count: Number of entries in circular history + * @loss_start: Movable index (for loss detection) + * @rtt_sample_prev: Used during RTT sampling, points to candidate entry + */ +struct tfrc_rx_hist { + struct tfrc_rx_hist_entry *ring[TFRC_NDUPACK + 1]; + u8 loss_count:2, + loss_start:2; +#define rtt_sample_prev loss_start +}; + +/** + * tfrc_rx_hist_index - index to reach n-th entry after loss_start + */ +static inline u8 tfrc_rx_hist_index(const struct tfrc_rx_hist *h, const u8 n) +{ + return (h->loss_start + n) & TFRC_NDUPACK; +} + +/** + * tfrc_rx_hist_last_rcv - entry with highest-received-seqno so far + */ +static inline struct tfrc_rx_hist_entry * + tfrc_rx_hist_last_rcv(const struct tfrc_rx_hist *h) +{ + return h->ring[tfrc_rx_hist_index(h, h->loss_count)]; +} + +/** + * tfrc_rx_hist_entry - return the n-th history entry after loss_start + */ +static inline struct tfrc_rx_hist_entry * + tfrc_rx_hist_entry(const struct tfrc_rx_hist *h, const u8 n) +{ + return h->ring[tfrc_rx_hist_index(h, n)]; +} + +/** + * tfrc_rx_hist_loss_prev - entry with highest-received-seqno before loss was detected + */ +static inline struct tfrc_rx_hist_entry * + tfrc_rx_hist_loss_prev(const struct tfrc_rx_hist *h) +{ + return h->ring[h->loss_start]; +} + +/* indicate whether previously a packet was detected missing */ +static inline bool tfrc_rx_hist_loss_pending(const struct tfrc_rx_hist *h) +{ + return h->loss_count > 0; +} + +void tfrc_rx_hist_add_packet(struct tfrc_rx_hist *h, const struct sk_buff *skb, + const u64 ndp); + +int tfrc_rx_hist_duplicate(struct tfrc_rx_hist *h, struct sk_buff *skb); + +struct tfrc_loss_hist; +int tfrc_rx_handle_loss(struct tfrc_rx_hist *h, struct tfrc_loss_hist *lh, + struct sk_buff *skb, const u64 ndp, + u32 (*first_li)(struct sock *sk), struct sock *sk); +u32 tfrc_rx_hist_sample_rtt(struct tfrc_rx_hist *h, const struct sk_buff *skb); +int tfrc_rx_hist_alloc(struct tfrc_rx_hist *h); +void 
tfrc_rx_hist_purge(struct tfrc_rx_hist *h); + +#endif /* _DCCP_PKT_HIST_ */ diff --git a/net/dccp/ccids/lib/tfrc.c b/net/dccp/ccids/lib/tfrc.c new file mode 100644 index 000000000..d7f265e1f --- /dev/null +++ b/net/dccp/ccids/lib/tfrc.c @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * TFRC library initialisation + * + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2007 Arnaldo Carvalho de Melo + */ +#include +#include "tfrc.h" + +#ifdef CONFIG_IP_DCCP_TFRC_DEBUG +bool tfrc_debug; +module_param(tfrc_debug, bool, 0644); +MODULE_PARM_DESC(tfrc_debug, "Enable TFRC debug messages"); +#endif + +int __init tfrc_lib_init(void) +{ + int rc = tfrc_li_init(); + + if (rc) + goto out; + + rc = tfrc_tx_packet_history_init(); + if (rc) + goto out_free_loss_intervals; + + rc = tfrc_rx_packet_history_init(); + if (rc) + goto out_free_tx_history; + return 0; + +out_free_tx_history: + tfrc_tx_packet_history_exit(); +out_free_loss_intervals: + tfrc_li_exit(); +out: + return rc; +} + +void tfrc_lib_exit(void) +{ + tfrc_rx_packet_history_exit(); + tfrc_tx_packet_history_exit(); + tfrc_li_exit(); +} diff --git a/net/dccp/ccids/lib/tfrc.h b/net/dccp/ccids/lib/tfrc.h new file mode 100644 index 000000000..0a63e8750 --- /dev/null +++ b/net/dccp/ccids/lib/tfrc.h @@ -0,0 +1,73 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +#ifndef _TFRC_H_ +#define _TFRC_H_ +/* + * Copyright (c) 2007 The University of Aberdeen, Scotland, UK + * Copyright (c) 2005-6 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2005-6 Ian McDonald + * Copyright (c) 2005 Arnaldo Carvalho de Melo + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + */ +#include +#include +#include "../../dccp.h" + +/* internal includes that this library exports: */ +#include "loss_interval.h" +#include "packet_history.h" + +#ifdef CONFIG_IP_DCCP_TFRC_DEBUG +extern bool tfrc_debug; +#define tfrc_pr_debug(format, a...) DCCP_PR_DEBUG(tfrc_debug, format, ##a) +#else +#define tfrc_pr_debug(format, a...) +#endif + +/* integer-arithmetic divisions of type (a * 1000000)/b */ +static inline u64 scaled_div(u64 a, u64 b) +{ + BUG_ON(b == 0); + return div64_u64(a * 1000000, b); +} + +static inline u32 scaled_div32(u64 a, u64 b) +{ + u64 result = scaled_div(a, b); + + if (result > UINT_MAX) { + DCCP_CRIT("Overflow: %llu/%llu > UINT_MAX", + (unsigned long long)a, (unsigned long long)b); + return UINT_MAX; + } + return result; +} + +/** + * tfrc_ewma - Exponentially weighted moving average + * @weight: Weight to be used as damping factor, in units of 1/10 + */ +static inline u32 tfrc_ewma(const u32 avg, const u32 newval, const u8 weight) +{ + return avg ? 
(weight * avg + (10 - weight) * newval) / 10 : newval; +} + +u32 tfrc_calc_x(u16 s, u32 R, u32 p); +u32 tfrc_calc_x_reverse_lookup(u32 fvalue); +u32 tfrc_invert_loss_event_rate(u32 loss_event_rate); + +int tfrc_tx_packet_history_init(void); +void tfrc_tx_packet_history_exit(void); +int tfrc_rx_packet_history_init(void); +void tfrc_rx_packet_history_exit(void); + +int tfrc_li_init(void); +void tfrc_li_exit(void); + +#ifdef CONFIG_IP_DCCP_TFRC_LIB +int tfrc_lib_init(void); +void tfrc_lib_exit(void); +#else +#define tfrc_lib_init() (0) +#define tfrc_lib_exit() +#endif +#endif /* _TFRC_H_ */ diff --git a/net/dccp/ccids/lib/tfrc_equation.c b/net/dccp/ccids/lib/tfrc_equation.c new file mode 100644 index 000000000..92a8c6bea --- /dev/null +++ b/net/dccp/ccids/lib/tfrc_equation.c @@ -0,0 +1,702 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2005 The University of Waikato, Hamilton, New Zealand. + * Copyright (c) 2005 Ian McDonald + * Copyright (c) 2005 Arnaldo Carvalho de Melo + * Copyright (c) 2003 Nils-Erik Mattsson, Joacim Haggmark, Magnus Erixzon + */ + +#include +#include "../../dccp.h" +#include "tfrc.h" + +#define TFRC_CALC_X_ARRSIZE 500 +#define TFRC_CALC_X_SPLIT 50000 /* 0.05 * 1000000, details below */ +#define TFRC_SMALLEST_P (TFRC_CALC_X_SPLIT/TFRC_CALC_X_ARRSIZE) + +/* + TFRC TCP Reno Throughput Equation Lookup Table for f(p) + + The following two-column lookup table implements a part of the TCP throughput + equation from [RFC 3448, sec. 3.1]: + + s + X_calc = -------------------------------------------------------------- + R * sqrt(2*b*p/3) + (3 * t_RTO * sqrt(3*b*p/8) * (p + 32*p^3)) + + Where: + X is the transmit rate in bytes/second + s is the packet size in bytes + R is the round trip time in seconds + p is the loss event rate, between 0 and 1.0, of the number of loss + events as a fraction of the number of packets transmitted + t_RTO is the TCP retransmission timeout value in seconds + b is the number of packets acknowledged by a single TCP ACK + + We can assume that b = 1 and t_RTO is 4 * R. The equation now becomes: + + s + X_calc = ------------------------------------------------------- + R * sqrt(p*2/3) + (12 * R * sqrt(p*3/8) * (p + 32*p^3)) + + which we can break down into: + + s + X_calc = --------- + R * f(p) + + where f(p) is given for 0 < p <= 1 by: + + f(p) = sqrt(2*p/3) + 12 * sqrt(3*p/8) * (p + 32*p^3) + + Since this is kernel code, floating-point arithmetic is avoided in favour of + integer arithmetic. This means that nearly all fractional parameters are + scaled by 1000000: + * the parameters p and R + * the return result f(p) + The lookup table therefore actually tabulates the following function g(q): + + g(q) = 1000000 * f(q/1000000) + + Hence, when p <= 1, q must be less than or equal to 1000000. To achieve finer + granularity for the practically more relevant case of small values of p (up to + 5%), the second column is used; the first one ranges up to 100%. This split + corresponds to the value of q = TFRC_CALC_X_SPLIT. 
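Since the kernel cannot use floating point, the table that follows was precomputed from the f(p) and g(q) definitions given above. The small userspace generator below (compile with -lm) follows those definitions; with truncation to unsigned int it reproduces the first table row, { 37172, 8172 }.

#include <math.h>
#include <stdio.h>

#define ARRSIZE	500
#define SPLIT	50000		/* q = 0.05 * 1000000 */

/* f(p) = sqrt(2p/3) + 12 * sqrt(3p/8) * (p + 32p^3), as in the comment above */
static double f(double p)
{
	return sqrt(2.0 * p / 3.0) + 12.0 * sqrt(3.0 * p / 8.0) * (p + 32.0 * p * p * p);
}

/* g(q) = 1000000 * f(q / 1000000), truncated to an integer table entry */
static unsigned int g(double q)
{
	return (unsigned int)(1000000.0 * f(q / 1000000.0));
}

int main(void)
{
	int i = 0;	/* first row; loop i over 0..ARRSIZE-1 to regenerate the table */

	printf("{ %u, %u }\n", g((i + 1) * 1000000.0 / ARRSIZE),
			       g((i + 1) * (double)SPLIT / ARRSIZE));
	return 0;
}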
At the same time this also + determines the smallest resolution possible with this lookup table: + + TFRC_SMALLEST_P = TFRC_CALC_X_SPLIT / TFRC_CALC_X_ARRSIZE + + The entire table is generated by: + for(i=0; i < TFRC_CALC_X_ARRSIZE; i++) { + lookup[i][0] = g((i+1) * 1000000/TFRC_CALC_X_ARRSIZE); + lookup[i][1] = g((i+1) * TFRC_CALC_X_SPLIT/TFRC_CALC_X_ARRSIZE); + } + + With the given configuration, we have, with M = TFRC_CALC_X_ARRSIZE-1, + lookup[0][0] = g(1000000/(M+1)) = 1000000 * f(0.2%) + lookup[M][0] = g(1000000) = 1000000 * f(100%) + lookup[0][1] = g(TFRC_SMALLEST_P) = 1000000 * f(0.01%) + lookup[M][1] = g(TFRC_CALC_X_SPLIT) = 1000000 * f(5%) + + In summary, the two columns represent f(p) for the following ranges: + * The first column is for 0.002 <= p <= 1.0 + * The second column is for 0.0001 <= p <= 0.05 + Where the columns overlap, the second (finer-grained) is given preference, + i.e. the first column is used only for p >= 0.05. + */ +static const u32 tfrc_calc_x_lookup[TFRC_CALC_X_ARRSIZE][2] = { + { 37172, 8172 }, + { 53499, 11567 }, + { 66664, 14180 }, + { 78298, 16388 }, + { 89021, 18339 }, + { 99147, 20108 }, + { 108858, 21738 }, + { 118273, 23260 }, + { 127474, 24693 }, + { 136520, 26052 }, + { 145456, 27348 }, + { 154316, 28589 }, + { 163130, 29783 }, + { 171919, 30935 }, + { 180704, 32049 }, + { 189502, 33130 }, + { 198328, 34180 }, + { 207194, 35202 }, + { 216114, 36198 }, + { 225097, 37172 }, + { 234153, 38123 }, + { 243294, 39055 }, + { 252527, 39968 }, + { 261861, 40864 }, + { 271305, 41743 }, + { 280866, 42607 }, + { 290553, 43457 }, + { 300372, 44293 }, + { 310333, 45117 }, + { 320441, 45929 }, + { 330705, 46729 }, + { 341131, 47518 }, + { 351728, 48297 }, + { 362501, 49066 }, + { 373460, 49826 }, + { 384609, 50577 }, + { 395958, 51320 }, + { 407513, 52054 }, + { 419281, 52780 }, + { 431270, 53499 }, + { 443487, 54211 }, + { 455940, 54916 }, + { 468635, 55614 }, + { 481581, 56306 }, + { 494785, 56991 }, + { 508254, 57671 }, + { 521996, 58345 }, + { 536019, 59014 }, + { 550331, 59677 }, + { 564939, 60335 }, + { 579851, 60988 }, + { 595075, 61636 }, + { 610619, 62279 }, + { 626491, 62918 }, + { 642700, 63553 }, + { 659253, 64183 }, + { 676158, 64809 }, + { 693424, 65431 }, + { 711060, 66050 }, + { 729073, 66664 }, + { 747472, 67275 }, + { 766266, 67882 }, + { 785464, 68486 }, + { 805073, 69087 }, + { 825103, 69684 }, + { 845562, 70278 }, + { 866460, 70868 }, + { 887805, 71456 }, + { 909606, 72041 }, + { 931873, 72623 }, + { 954614, 73202 }, + { 977839, 73778 }, + { 1001557, 74352 }, + { 1025777, 74923 }, + { 1050508, 75492 }, + { 1075761, 76058 }, + { 1101544, 76621 }, + { 1127867, 77183 }, + { 1154739, 77741 }, + { 1182172, 78298 }, + { 1210173, 78852 }, + { 1238753, 79405 }, + { 1267922, 79955 }, + { 1297689, 80503 }, + { 1328066, 81049 }, + { 1359060, 81593 }, + { 1390684, 82135 }, + { 1422947, 82675 }, + { 1455859, 83213 }, + { 1489430, 83750 }, + { 1523671, 84284 }, + { 1558593, 84817 }, + { 1594205, 85348 }, + { 1630518, 85878 }, + { 1667543, 86406 }, + { 1705290, 86932 }, + { 1743770, 87457 }, + { 1782994, 87980 }, + { 1822973, 88501 }, + { 1863717, 89021 }, + { 1905237, 89540 }, + { 1947545, 90057 }, + { 1990650, 90573 }, + { 2034566, 91087 }, + { 2079301, 91600 }, + { 2124869, 92111 }, + { 2171279, 92622 }, + { 2218543, 93131 }, + { 2266673, 93639 }, + { 2315680, 94145 }, + { 2365575, 94650 }, + { 2416371, 95154 }, + { 2468077, 95657 }, + { 2520707, 96159 }, + { 2574271, 96660 }, + { 2628782, 97159 }, + { 2684250, 97658 }, + { 2740689, 98155 }, + { 2798110, 
98651 }, + { 2856524, 99147 }, + { 2915944, 99641 }, + { 2976382, 100134 }, + { 3037850, 100626 }, + { 3100360, 101117 }, + { 3163924, 101608 }, + { 3228554, 102097 }, + { 3294263, 102586 }, + { 3361063, 103073 }, + { 3428966, 103560 }, + { 3497984, 104045 }, + { 3568131, 104530 }, + { 3639419, 105014 }, + { 3711860, 105498 }, + { 3785467, 105980 }, + { 3860253, 106462 }, + { 3936229, 106942 }, + { 4013410, 107422 }, + { 4091808, 107902 }, + { 4171435, 108380 }, + { 4252306, 108858 }, + { 4334431, 109335 }, + { 4417825, 109811 }, + { 4502501, 110287 }, + { 4588472, 110762 }, + { 4675750, 111236 }, + { 4764349, 111709 }, + { 4854283, 112182 }, + { 4945564, 112654 }, + { 5038206, 113126 }, + { 5132223, 113597 }, + { 5227627, 114067 }, + { 5324432, 114537 }, + { 5422652, 115006 }, + { 5522299, 115474 }, + { 5623389, 115942 }, + { 5725934, 116409 }, + { 5829948, 116876 }, + { 5935446, 117342 }, + { 6042439, 117808 }, + { 6150943, 118273 }, + { 6260972, 118738 }, + { 6372538, 119202 }, + { 6485657, 119665 }, + { 6600342, 120128 }, + { 6716607, 120591 }, + { 6834467, 121053 }, + { 6953935, 121514 }, + { 7075025, 121976 }, + { 7197752, 122436 }, + { 7322131, 122896 }, + { 7448175, 123356 }, + { 7575898, 123815 }, + { 7705316, 124274 }, + { 7836442, 124733 }, + { 7969291, 125191 }, + { 8103877, 125648 }, + { 8240216, 126105 }, + { 8378321, 126562 }, + { 8518208, 127018 }, + { 8659890, 127474 }, + { 8803384, 127930 }, + { 8948702, 128385 }, + { 9095861, 128840 }, + { 9244875, 129294 }, + { 9395760, 129748 }, + { 9548529, 130202 }, + { 9703198, 130655 }, + { 9859782, 131108 }, + { 10018296, 131561 }, + { 10178755, 132014 }, + { 10341174, 132466 }, + { 10505569, 132917 }, + { 10671954, 133369 }, + { 10840345, 133820 }, + { 11010757, 134271 }, + { 11183206, 134721 }, + { 11357706, 135171 }, + { 11534274, 135621 }, + { 11712924, 136071 }, + { 11893673, 136520 }, + { 12076536, 136969 }, + { 12261527, 137418 }, + { 12448664, 137867 }, + { 12637961, 138315 }, + { 12829435, 138763 }, + { 13023101, 139211 }, + { 13218974, 139658 }, + { 13417071, 140106 }, + { 13617407, 140553 }, + { 13819999, 140999 }, + { 14024862, 141446 }, + { 14232012, 141892 }, + { 14441465, 142339 }, + { 14653238, 142785 }, + { 14867346, 143230 }, + { 15083805, 143676 }, + { 15302632, 144121 }, + { 15523842, 144566 }, + { 15747453, 145011 }, + { 15973479, 145456 }, + { 16201939, 145900 }, + { 16432847, 146345 }, + { 16666221, 146789 }, + { 16902076, 147233 }, + { 17140429, 147677 }, + { 17381297, 148121 }, + { 17624696, 148564 }, + { 17870643, 149007 }, + { 18119154, 149451 }, + { 18370247, 149894 }, + { 18623936, 150336 }, + { 18880241, 150779 }, + { 19139176, 151222 }, + { 19400759, 151664 }, + { 19665007, 152107 }, + { 19931936, 152549 }, + { 20201564, 152991 }, + { 20473907, 153433 }, + { 20748982, 153875 }, + { 21026807, 154316 }, + { 21307399, 154758 }, + { 21590773, 155199 }, + { 21876949, 155641 }, + { 22165941, 156082 }, + { 22457769, 156523 }, + { 22752449, 156964 }, + { 23049999, 157405 }, + { 23350435, 157846 }, + { 23653774, 158287 }, + { 23960036, 158727 }, + { 24269236, 159168 }, + { 24581392, 159608 }, + { 24896521, 160049 }, + { 25214642, 160489 }, + { 25535772, 160929 }, + { 25859927, 161370 }, + { 26187127, 161810 }, + { 26517388, 162250 }, + { 26850728, 162690 }, + { 27187165, 163130 }, + { 27526716, 163569 }, + { 27869400, 164009 }, + { 28215234, 164449 }, + { 28564236, 164889 }, + { 28916423, 165328 }, + { 29271815, 165768 }, + { 29630428, 166208 }, + { 29992281, 166647 }, + { 30357392, 167087 }, + { 30725779, 
167526 }, + { 31097459, 167965 }, + { 31472452, 168405 }, + { 31850774, 168844 }, + { 32232445, 169283 }, + { 32617482, 169723 }, + { 33005904, 170162 }, + { 33397730, 170601 }, + { 33792976, 171041 }, + { 34191663, 171480 }, + { 34593807, 171919 }, + { 34999428, 172358 }, + { 35408544, 172797 }, + { 35821174, 173237 }, + { 36237335, 173676 }, + { 36657047, 174115 }, + { 37080329, 174554 }, + { 37507197, 174993 }, + { 37937673, 175433 }, + { 38371773, 175872 }, + { 38809517, 176311 }, + { 39250924, 176750 }, + { 39696012, 177190 }, + { 40144800, 177629 }, + { 40597308, 178068 }, + { 41053553, 178507 }, + { 41513554, 178947 }, + { 41977332, 179386 }, + { 42444904, 179825 }, + { 42916290, 180265 }, + { 43391509, 180704 }, + { 43870579, 181144 }, + { 44353520, 181583 }, + { 44840352, 182023 }, + { 45331092, 182462 }, + { 45825761, 182902 }, + { 46324378, 183342 }, + { 46826961, 183781 }, + { 47333531, 184221 }, + { 47844106, 184661 }, + { 48358706, 185101 }, + { 48877350, 185541 }, + { 49400058, 185981 }, + { 49926849, 186421 }, + { 50457743, 186861 }, + { 50992759, 187301 }, + { 51531916, 187741 }, + { 52075235, 188181 }, + { 52622735, 188622 }, + { 53174435, 189062 }, + { 53730355, 189502 }, + { 54290515, 189943 }, + { 54854935, 190383 }, + { 55423634, 190824 }, + { 55996633, 191265 }, + { 56573950, 191706 }, + { 57155606, 192146 }, + { 57741621, 192587 }, + { 58332014, 193028 }, + { 58926806, 193470 }, + { 59526017, 193911 }, + { 60129666, 194352 }, + { 60737774, 194793 }, + { 61350361, 195235 }, + { 61967446, 195677 }, + { 62589050, 196118 }, + { 63215194, 196560 }, + { 63845897, 197002 }, + { 64481179, 197444 }, + { 65121061, 197886 }, + { 65765563, 198328 }, + { 66414705, 198770 }, + { 67068508, 199213 }, + { 67726992, 199655 }, + { 68390177, 200098 }, + { 69058085, 200540 }, + { 69730735, 200983 }, + { 70408147, 201426 }, + { 71090343, 201869 }, + { 71777343, 202312 }, + { 72469168, 202755 }, + { 73165837, 203199 }, + { 73867373, 203642 }, + { 74573795, 204086 }, + { 75285124, 204529 }, + { 76001380, 204973 }, + { 76722586, 205417 }, + { 77448761, 205861 }, + { 78179926, 206306 }, + { 78916102, 206750 }, + { 79657310, 207194 }, + { 80403571, 207639 }, + { 81154906, 208084 }, + { 81911335, 208529 }, + { 82672880, 208974 }, + { 83439562, 209419 }, + { 84211402, 209864 }, + { 84988421, 210309 }, + { 85770640, 210755 }, + { 86558080, 211201 }, + { 87350762, 211647 }, + { 88148708, 212093 }, + { 88951938, 212539 }, + { 89760475, 212985 }, + { 90574339, 213432 }, + { 91393551, 213878 }, + { 92218133, 214325 }, + { 93048107, 214772 }, + { 93883493, 215219 }, + { 94724314, 215666 }, + { 95570590, 216114 }, + { 96422343, 216561 }, + { 97279594, 217009 }, + { 98142366, 217457 }, + { 99010679, 217905 }, + { 99884556, 218353 }, + { 100764018, 218801 }, + { 101649086, 219250 }, + { 102539782, 219698 }, + { 103436128, 220147 }, + { 104338146, 220596 }, + { 105245857, 221046 }, + { 106159284, 221495 }, + { 107078448, 221945 }, + { 108003370, 222394 }, + { 108934074, 222844 }, + { 109870580, 223294 }, + { 110812910, 223745 }, + { 111761087, 224195 }, + { 112715133, 224646 }, + { 113675069, 225097 }, + { 114640918, 225548 }, + { 115612702, 225999 }, + { 116590442, 226450 }, + { 117574162, 226902 }, + { 118563882, 227353 }, + { 119559626, 227805 }, + { 120561415, 228258 }, + { 121569272, 228710 }, + { 122583219, 229162 }, + { 123603278, 229615 }, + { 124629471, 230068 }, + { 125661822, 230521 }, + { 126700352, 230974 }, + { 127745083, 231428 }, + { 128796039, 231882 }, + { 129853241, 232336 }, + { 
130916713, 232790 }, + { 131986475, 233244 }, + { 133062553, 233699 }, + { 134144966, 234153 }, + { 135233739, 234608 }, + { 136328894, 235064 }, + { 137430453, 235519 }, + { 138538440, 235975 }, + { 139652876, 236430 }, + { 140773786, 236886 }, + { 141901190, 237343 }, + { 143035113, 237799 }, + { 144175576, 238256 }, + { 145322604, 238713 }, + { 146476218, 239170 }, + { 147636442, 239627 }, + { 148803298, 240085 }, + { 149976809, 240542 }, + { 151156999, 241000 }, + { 152343890, 241459 }, + { 153537506, 241917 }, + { 154737869, 242376 }, + { 155945002, 242835 }, + { 157158929, 243294 }, + { 158379673, 243753 }, + { 159607257, 244213 }, + { 160841704, 244673 }, + { 162083037, 245133 }, + { 163331279, 245593 }, + { 164586455, 246054 }, + { 165848586, 246514 }, + { 167117696, 246975 }, + { 168393810, 247437 }, + { 169676949, 247898 }, + { 170967138, 248360 }, + { 172264399, 248822 }, + { 173568757, 249284 }, + { 174880235, 249747 }, + { 176198856, 250209 }, + { 177524643, 250672 }, + { 178857621, 251136 }, + { 180197813, 251599 }, + { 181545242, 252063 }, + { 182899933, 252527 }, + { 184261908, 252991 }, + { 185631191, 253456 }, + { 187007807, 253920 }, + { 188391778, 254385 }, + { 189783129, 254851 }, + { 191181884, 255316 }, + { 192588065, 255782 }, + { 194001698, 256248 }, + { 195422805, 256714 }, + { 196851411, 257181 }, + { 198287540, 257648 }, + { 199731215, 258115 }, + { 201182461, 258582 }, + { 202641302, 259050 }, + { 204107760, 259518 }, + { 205581862, 259986 }, + { 207063630, 260454 }, + { 208553088, 260923 }, + { 210050262, 261392 }, + { 211555174, 261861 }, + { 213067849, 262331 }, + { 214588312, 262800 }, + { 216116586, 263270 }, + { 217652696, 263741 }, + { 219196666, 264211 }, + { 220748520, 264682 }, + { 222308282, 265153 }, + { 223875978, 265625 }, + { 225451630, 266097 }, + { 227035265, 266569 }, + { 228626905, 267041 }, + { 230226576, 267514 }, + { 231834302, 267986 }, + { 233450107, 268460 }, + { 235074016, 268933 }, + { 236706054, 269407 }, + { 238346244, 269881 }, + { 239994613, 270355 }, + { 241651183, 270830 }, + { 243315981, 271305 } +}; + +/* return largest index i such that fval <= lookup[i][small] */ +static inline u32 tfrc_binsearch(u32 fval, u8 small) +{ + u32 try, low = 0, high = TFRC_CALC_X_ARRSIZE - 1; + + while (low < high) { + try = (low + high) / 2; + if (fval <= tfrc_calc_x_lookup[try][small]) + high = try; + else + low = try + 1; + } + return high; +} + +/** + * tfrc_calc_x - Calculate the send rate as per section 3.1 of RFC3448 + * @s: packet size in bytes + * @R: RTT scaled by 1000000 (i.e., microseconds) + * @p: loss ratio estimate scaled by 1000000 + * + * Returns X_calc in bytes per second (not scaled). + */ +u32 tfrc_calc_x(u16 s, u32 R, u32 p) +{ + u16 index; + u32 f; + u64 result; + + /* check against invalid parameters and divide-by-zero */ + BUG_ON(p > 1000000); /* p must not exceed 100% */ + BUG_ON(p == 0); /* f(0) = 0, divide by zero */ + if (R == 0) { /* possible divide by zero */ + DCCP_CRIT("WARNING: RTT is 0, returning maximum X_calc."); + return ~0U; + } + + if (p <= TFRC_CALC_X_SPLIT) { /* 0.0000 < p <= 0.05 */ + if (p < TFRC_SMALLEST_P) { /* 0.0000 < p < 0.0001 */ + DCCP_WARN("Value of p (%d) below resolution. 
" + "Substituting %d\n", p, TFRC_SMALLEST_P); + index = 0; + } else /* 0.0001 <= p <= 0.05 */ + index = p/TFRC_SMALLEST_P - 1; + + f = tfrc_calc_x_lookup[index][1]; + + } else { /* 0.05 < p <= 1.00 */ + index = p/(1000000/TFRC_CALC_X_ARRSIZE) - 1; + + f = tfrc_calc_x_lookup[index][0]; + } + + /* + * Compute X = s/(R*f(p)) in bytes per second. + * Since f(p) and R are both scaled by 1000000, we need to multiply by + * 1000000^2. To avoid overflow, the result is computed in two stages. + * This works under almost all reasonable operational conditions, for a + * wide range of parameters. Yet, should some strange combination of + * parameters result in overflow, the use of scaled_div32 will catch + * this and return UINT_MAX - which is a logically adequate consequence. + */ + result = scaled_div(s, R); + return scaled_div32(result, f); +} + +/** + * tfrc_calc_x_reverse_lookup - try to find p given f(p) + * @fvalue: function value to match, scaled by 1000000 + * + * Returns closest match for p, also scaled by 1000000 + */ +u32 tfrc_calc_x_reverse_lookup(u32 fvalue) +{ + int index; + + if (fvalue == 0) /* f(p) = 0 whenever p = 0 */ + return 0; + + /* Error cases. */ + if (fvalue < tfrc_calc_x_lookup[0][1]) { + DCCP_WARN("fvalue %u smaller than resolution\n", fvalue); + return TFRC_SMALLEST_P; + } + if (fvalue > tfrc_calc_x_lookup[TFRC_CALC_X_ARRSIZE - 1][0]) { + DCCP_WARN("fvalue %u exceeds bounds!\n", fvalue); + return 1000000; + } + + if (fvalue <= tfrc_calc_x_lookup[TFRC_CALC_X_ARRSIZE - 1][1]) { + index = tfrc_binsearch(fvalue, 1); + return (index + 1) * TFRC_CALC_X_SPLIT / TFRC_CALC_X_ARRSIZE; + } + + /* else ... it must be in the coarse-grained column */ + index = tfrc_binsearch(fvalue, 0); + return (index + 1) * 1000000 / TFRC_CALC_X_ARRSIZE; +} + +/** + * tfrc_invert_loss_event_rate - Compute p so that 10^6 corresponds to 100% + * @loss_event_rate: loss event rate to invert + * When @loss_event_rate is large, there is a chance that p is truncated to 0. + * To avoid re-entering slow-start in that case, we set p = TFRC_SMALLEST_P > 0. + */ +u32 tfrc_invert_loss_event_rate(u32 loss_event_rate) +{ + if (loss_event_rate == UINT_MAX) /* see RFC 4342, 8.5 */ + return 0; + if (unlikely(loss_event_rate == 0)) /* map 1/0 into 100% */ + return 1000000; + return max_t(u32, scaled_div(1, loss_event_rate), TFRC_SMALLEST_P); +} diff --git a/net/dccp/dccp.h b/net/dccp/dccp.h new file mode 100644 index 000000000..9ddc3a9e8 --- /dev/null +++ b/net/dccp/dccp.h @@ -0,0 +1,483 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _DCCP_H +#define _DCCP_H +/* + * net/dccp/dccp.h + * + * An implementation of the DCCP protocol + * Copyright (c) 2005 Arnaldo Carvalho de Melo + * Copyright (c) 2005-6 Ian McDonald + */ + +#include +#include +#include +#include +#include +#include "ackvec.h" + +/* + * DCCP - specific warning and debugging macros. + */ +#define DCCP_WARN(fmt, ...) \ + net_warn_ratelimited("%s: " fmt, __func__, ##__VA_ARGS__) +#define DCCP_CRIT(fmt, a...) printk(KERN_CRIT fmt " at %s:%d/%s()\n", ##a, \ + __FILE__, __LINE__, __func__) +#define DCCP_BUG(a...) do { DCCP_CRIT("BUG: " a); dump_stack(); } while(0) +#define DCCP_BUG_ON(cond) do { if (unlikely((cond) != 0)) \ + DCCP_BUG("\"%s\" holds (exception!)", \ + __stringify(cond)); \ + } while (0) + +#define DCCP_PRINTK(enable, fmt, args...) do { if (enable) \ + printk(fmt, ##args); \ + } while(0) +#define DCCP_PR_DEBUG(enable, fmt, a...) 
DCCP_PRINTK(enable, KERN_DEBUG \ + "%s: " fmt, __func__, ##a) + +#ifdef CONFIG_IP_DCCP_DEBUG +extern bool dccp_debug; +#define dccp_pr_debug(format, a...) DCCP_PR_DEBUG(dccp_debug, format, ##a) +#define dccp_pr_debug_cat(format, a...) DCCP_PRINTK(dccp_debug, format, ##a) +#define dccp_debug(fmt, a...) dccp_pr_debug_cat(KERN_DEBUG fmt, ##a) +#else +#define dccp_pr_debug(format, a...) do {} while (0) +#define dccp_pr_debug_cat(format, a...) do {} while (0) +#define dccp_debug(format, a...) do {} while (0) +#endif + +extern struct inet_hashinfo dccp_hashinfo; + +DECLARE_PER_CPU(unsigned int, dccp_orphan_count); + +void dccp_time_wait(struct sock *sk, int state, int timeo); + +/* + * Set safe upper bounds for header and option length. Since Data Offset is 8 + * bits (RFC 4340, sec. 5.1), the total header length can never be more than + * 4 * 255 = 1020 bytes. The largest possible header length is 28 bytes (X=1): + * - DCCP-Response with ACK Subheader and 4 bytes of Service code OR + * - DCCP-Reset with ACK Subheader and 4 bytes of Reset Code fields + * Hence a safe upper bound for the maximum option length is 1020-28 = 992 + */ +#define MAX_DCCP_SPECIFIC_HEADER (255 * sizeof(uint32_t)) +#define DCCP_MAX_PACKET_HDR 28 +#define DCCP_MAX_OPT_LEN (MAX_DCCP_SPECIFIC_HEADER - DCCP_MAX_PACKET_HDR) +#define MAX_DCCP_HEADER (MAX_DCCP_SPECIFIC_HEADER + MAX_HEADER) + +/* Upper bound for initial feature-negotiation overhead (padded to 32 bits) */ +#define DCCP_FEATNEG_OVERHEAD (32 * sizeof(uint32_t)) + +#define DCCP_TIMEWAIT_LEN (60 * HZ) /* how long to wait to destroy TIME-WAIT + * state, about 60 seconds */ + +/* RFC 1122, 4.2.3.1 initial RTO value */ +#define DCCP_TIMEOUT_INIT ((unsigned int)(3 * HZ)) + +/* + * The maximum back-off value for retransmissions. This is needed for + * - retransmitting client-Requests (sec. 8.1.1), + * - retransmitting Close/CloseReq when closing (sec. 8.3), + * - feature-negotiation retransmission (sec. 6.6.3), + * - Acks in client-PARTOPEN state (sec. 8.1.5). + */ +#define DCCP_RTO_MAX ((unsigned int)(64 * HZ)) + +/* + * RTT sampling: sanity bounds and fallback RTT value from RFC 4340, section 3.4 + */ +#define DCCP_SANE_RTT_MIN 100 +#define DCCP_FALLBACK_RTT (USEC_PER_SEC / 5) +#define DCCP_SANE_RTT_MAX (3 * USEC_PER_SEC) + +/* sysctl variables for DCCP */ +extern int sysctl_dccp_request_retries; +extern int sysctl_dccp_retries1; +extern int sysctl_dccp_retries2; +extern int sysctl_dccp_tx_qlen; +extern int sysctl_dccp_sync_ratelimit; + +/* + * 48-bit sequence number arithmetic (signed and unsigned) + */ +#define INT48_MIN 0x800000000000LL /* 2^47 */ +#define UINT48_MAX 0xFFFFFFFFFFFFLL /* 2^48 - 1 */ +#define COMPLEMENT48(x) (0x1000000000000LL - (x)) /* 2^48 - x */ +#define TO_SIGNED48(x) (((x) < INT48_MIN)? (x) : -COMPLEMENT48( (x))) +#define TO_UNSIGNED48(x) (((x) >= 0)? (x) : COMPLEMENT48(-(x))) +#define ADD48(a, b) (((a) + (b)) & UINT48_MAX) +#define SUB48(a, b) ADD48((a), COMPLEMENT48(b)) + +static inline void dccp_inc_seqno(u64 *seqno) +{ + *seqno = ADD48(*seqno, 1); +} + +/* signed mod-2^48 distance: pos. if seqno1 < seqno2, neg. if seqno1 > seqno2 */ +static inline s64 dccp_delta_seqno(const u64 seqno1, const u64 seqno2) +{ + u64 delta = SUB48(seqno2, seqno1); + + return TO_SIGNED48(delta); +} + +/* is seq1 < seq2 ? */ +static inline int before48(const u64 seq1, const u64 seq2) +{ + return (s64)((seq2 << 16) - (seq1 << 16)) > 0; +} + +/* is seq1 > seq2 ? */ +#define after48(seq1, seq2) before48(seq2, seq1) + +/* is seq2 <= seq1 <= seq3 ? 
*/ +static inline int between48(const u64 seq1, const u64 seq2, const u64 seq3) +{ + return (seq3 << 16) - (seq2 << 16) >= (seq1 << 16) - (seq2 << 16); +} + +/** + * dccp_loss_count - Approximate the number of lost data packets in a burst loss + * @s1: last known sequence number before the loss ('hole') + * @s2: first sequence number seen after the 'hole' + * @ndp: NDP count on packet with sequence number @s2 + */ +static inline u64 dccp_loss_count(const u64 s1, const u64 s2, const u64 ndp) +{ + s64 delta = dccp_delta_seqno(s1, s2); + + WARN_ON(delta < 0); + delta -= ndp + 1; + + return delta > 0 ? delta : 0; +} + +/** + * dccp_loss_free - Evaluate condition for data loss from RFC 4340, 7.7.1 + */ +static inline bool dccp_loss_free(const u64 s1, const u64 s2, const u64 ndp) +{ + return dccp_loss_count(s1, s2, ndp) == 0; +} + +enum { + DCCP_MIB_NUM = 0, + DCCP_MIB_ACTIVEOPENS, /* ActiveOpens */ + DCCP_MIB_ESTABRESETS, /* EstabResets */ + DCCP_MIB_CURRESTAB, /* CurrEstab */ + DCCP_MIB_OUTSEGS, /* OutSegs */ + DCCP_MIB_OUTRSTS, + DCCP_MIB_ABORTONTIMEOUT, + DCCP_MIB_TIMEOUTS, + DCCP_MIB_ABORTFAILED, + DCCP_MIB_PASSIVEOPENS, + DCCP_MIB_ATTEMPTFAILS, + DCCP_MIB_OUTDATAGRAMS, + DCCP_MIB_INERRS, + DCCP_MIB_OPTMANDATORYERROR, + DCCP_MIB_INVALIDOPT, + __DCCP_MIB_MAX +}; + +#define DCCP_MIB_MAX __DCCP_MIB_MAX +struct dccp_mib { + unsigned long mibs[DCCP_MIB_MAX]; +}; + +DECLARE_SNMP_STAT(struct dccp_mib, dccp_statistics); +#define DCCP_INC_STATS(field) SNMP_INC_STATS(dccp_statistics, field) +#define __DCCP_INC_STATS(field) __SNMP_INC_STATS(dccp_statistics, field) +#define DCCP_DEC_STATS(field) SNMP_DEC_STATS(dccp_statistics, field) + +/* + * Checksumming routines + */ +static inline unsigned int dccp_csum_coverage(const struct sk_buff *skb) +{ + const struct dccp_hdr* dh = dccp_hdr(skb); + + if (dh->dccph_cscov == 0) + return skb->len; + return (dh->dccph_doff + dh->dccph_cscov - 1) * sizeof(u32); +} + +static inline void dccp_csum_outgoing(struct sk_buff *skb) +{ + unsigned int cov = dccp_csum_coverage(skb); + + if (cov >= skb->len) + dccp_hdr(skb)->dccph_cscov = 0; + + skb->csum = skb_checksum(skb, 0, (cov > skb->len)? 
skb->len : cov, 0); +} + +void dccp_v4_send_check(struct sock *sk, struct sk_buff *skb); + +int dccp_retransmit_skb(struct sock *sk); + +void dccp_send_ack(struct sock *sk); +void dccp_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, + struct request_sock *rsk); + +void dccp_send_sync(struct sock *sk, const u64 seq, + const enum dccp_pkt_type pkt_type); + +/* + * TX Packet Dequeueing Interface + */ +void dccp_qpolicy_push(struct sock *sk, struct sk_buff *skb); +bool dccp_qpolicy_full(struct sock *sk); +void dccp_qpolicy_drop(struct sock *sk, struct sk_buff *skb); +struct sk_buff *dccp_qpolicy_top(struct sock *sk); +struct sk_buff *dccp_qpolicy_pop(struct sock *sk); +bool dccp_qpolicy_param_ok(struct sock *sk, __be32 param); + +/* + * TX Packet Output and TX Timers + */ +void dccp_write_xmit(struct sock *sk); +void dccp_write_space(struct sock *sk); +void dccp_flush_write_queue(struct sock *sk, long *time_budget); + +void dccp_init_xmit_timers(struct sock *sk); +static inline void dccp_clear_xmit_timers(struct sock *sk) +{ + inet_csk_clear_xmit_timers(sk); +} + +unsigned int dccp_sync_mss(struct sock *sk, u32 pmtu); + +const char *dccp_packet_name(const int type); + +void dccp_set_state(struct sock *sk, const int state); +void dccp_done(struct sock *sk); + +int dccp_reqsk_init(struct request_sock *rq, struct dccp_sock const *dp, + struct sk_buff const *skb); + +int dccp_v4_conn_request(struct sock *sk, struct sk_buff *skb); + +struct sock *dccp_create_openreq_child(const struct sock *sk, + const struct request_sock *req, + const struct sk_buff *skb); + +int dccp_v4_do_rcv(struct sock *sk, struct sk_buff *skb); + +struct sock *dccp_v4_request_recv_sock(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req); +struct sock *dccp_check_req(struct sock *sk, struct sk_buff *skb, + struct request_sock *req); + +int dccp_child_process(struct sock *parent, struct sock *child, + struct sk_buff *skb); +int dccp_rcv_state_process(struct sock *sk, struct sk_buff *skb, + struct dccp_hdr *dh, unsigned int len); +int dccp_rcv_established(struct sock *sk, struct sk_buff *skb, + const struct dccp_hdr *dh, const unsigned int len); + +void dccp_destruct_common(struct sock *sk); +int dccp_init_sock(struct sock *sk, const __u8 ctl_sock_initialized); +void dccp_destroy_sock(struct sock *sk); + +void dccp_close(struct sock *sk, long timeout); +struct sk_buff *dccp_make_response(const struct sock *sk, struct dst_entry *dst, + struct request_sock *req); + +int dccp_connect(struct sock *sk); +int dccp_disconnect(struct sock *sk, int flags); +int dccp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen); +int dccp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen); +int dccp_ioctl(struct sock *sk, int cmd, unsigned long arg); +int dccp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int dccp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len); +void dccp_shutdown(struct sock *sk, int how); +int inet_dccp_listen(struct socket *sock, int backlog); +__poll_t dccp_poll(struct file *file, struct socket *sock, + poll_table *wait); +int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len); +void dccp_req_err(struct sock *sk, u64 seq); + +struct sk_buff *dccp_ctl_make_reset(struct sock *sk, struct sk_buff *skb); +int dccp_send_reset(struct sock *sk, enum dccp_reset_codes code); 
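For reference, the coverage rule implemented by dccp_csum_coverage()/dccp_csum_outgoing() above follows RFC 4340, 9.2: CsCov == 0 covers the whole packet, while CsCov == n > 0 covers the header plus the first (n - 1) * 4 bytes of application data, where doff gives the header length in 32-bit words. A small userspace sketch, with the clamping to the packet length folded in:

#include <stdio.h>

static unsigned int csum_coverage(unsigned int doff_words, unsigned int cscov,
				  unsigned int packet_len)
{
	unsigned int cov;

	if (cscov == 0)
		return packet_len;			/* full coverage */

	cov = (doff_words + cscov - 1) * 4;		/* header + (n-1) words of data */
	return cov < packet_len ? cov : packet_len;	/* never beyond the packet */
}

int main(void)
{
	/* 28-byte header (doff = 7) followed by 100 bytes of payload */
	printf("cscov=0 -> %u bytes\n", csum_coverage(7, 0, 128));	/* 128            */
	printf("cscov=1 -> %u bytes\n", csum_coverage(7, 1, 128));	/* 28: header only */
	printf("cscov=3 -> %u bytes\n", csum_coverage(7, 3, 128));	/* 36: header + 8  */
	return 0;
}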
+void dccp_send_close(struct sock *sk, const int active); +int dccp_invalid_packet(struct sk_buff *skb); +u32 dccp_sample_rtt(struct sock *sk, long delta); + +static inline bool dccp_bad_service_code(const struct sock *sk, + const __be32 service) +{ + const struct dccp_sock *dp = dccp_sk(sk); + + if (dp->dccps_service == service) + return false; + return !dccp_list_has_service(dp->dccps_service_list, service); +} + +/** + * dccp_skb_cb - DCCP per-packet control information + * @dccpd_type: one of %dccp_pkt_type (or unknown) + * @dccpd_ccval: CCVal field (5.1), see e.g. RFC 4342, 8.1 + * @dccpd_reset_code: one of %dccp_reset_codes + * @dccpd_reset_data: Data1..3 fields (depend on @dccpd_reset_code) + * @dccpd_opt_len: total length of all options (5.8) in the packet + * @dccpd_seq: sequence number + * @dccpd_ack_seq: acknowledgment number subheader field value + * + * This is used for transmission as well as for reception. + */ +struct dccp_skb_cb { + union { + struct inet_skb_parm h4; +#if IS_ENABLED(CONFIG_IPV6) + struct inet6_skb_parm h6; +#endif + } header; + __u8 dccpd_type:4; + __u8 dccpd_ccval:4; + __u8 dccpd_reset_code, + dccpd_reset_data[3]; + __u16 dccpd_opt_len; + __u64 dccpd_seq; + __u64 dccpd_ack_seq; +}; + +#define DCCP_SKB_CB(__skb) ((struct dccp_skb_cb *)&((__skb)->cb[0])) + +/* RFC 4340, sec. 7.7 */ +static inline int dccp_non_data_packet(const struct sk_buff *skb) +{ + const __u8 type = DCCP_SKB_CB(skb)->dccpd_type; + + return type == DCCP_PKT_ACK || + type == DCCP_PKT_CLOSE || + type == DCCP_PKT_CLOSEREQ || + type == DCCP_PKT_RESET || + type == DCCP_PKT_SYNC || + type == DCCP_PKT_SYNCACK; +} + +/* RFC 4340, sec. 7.7 */ +static inline int dccp_data_packet(const struct sk_buff *skb) +{ + const __u8 type = DCCP_SKB_CB(skb)->dccpd_type; + + return type == DCCP_PKT_DATA || + type == DCCP_PKT_DATAACK || + type == DCCP_PKT_REQUEST || + type == DCCP_PKT_RESPONSE; +} + +static inline int dccp_packet_without_ack(const struct sk_buff *skb) +{ + const __u8 type = DCCP_SKB_CB(skb)->dccpd_type; + + return type == DCCP_PKT_DATA || type == DCCP_PKT_REQUEST; +} + +#define DCCP_PKT_WITHOUT_ACK_SEQ (UINT48_MAX << 2) + +static inline void dccp_hdr_set_seq(struct dccp_hdr *dh, const u64 gss) +{ + struct dccp_hdr_ext *dhx = (struct dccp_hdr_ext *)((void *)dh + + sizeof(*dh)); + dh->dccph_seq2 = 0; + dh->dccph_seq = htons((gss >> 32) & 0xfffff); + dhx->dccph_seq_low = htonl(gss & 0xffffffff); +} + +static inline void dccp_hdr_set_ack(struct dccp_hdr_ack_bits *dhack, + const u64 gsr) +{ + dhack->dccph_reserved1 = 0; + dhack->dccph_ack_nr_high = htons(gsr >> 32); + dhack->dccph_ack_nr_low = htonl(gsr & 0xffffffff); +} + +static inline void dccp_update_gsr(struct sock *sk, u64 seq) +{ + struct dccp_sock *dp = dccp_sk(sk); + + if (after48(seq, dp->dccps_gsr)) + dp->dccps_gsr = seq; + /* Sequence validity window depends on remote Sequence Window (7.5.1) */ + dp->dccps_swl = SUB48(ADD48(dp->dccps_gsr, 1), dp->dccps_r_seq_win / 4); + /* + * Adjust SWL so that it is not below ISR. In contrast to RFC 4340, + * 7.5.1 we perform this check beyond the initial handshake: W/W' are + * always > 32, so for the first W/W' packets in the lifetime of a + * connection we always have to adjust SWL. + * A second reason why we are doing this is that the window depends on + * the feature-remote value of Sequence Window: nothing stops the peer + * from updating this value while we are busy adjusting SWL for the + * first W packets (we would have to count from scratch again then). 
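The window bookkeeping in dccp_update_gsr(), which begins just above and continues after this aside, follows RFC 4340, 7.5.1: SWL = GSR + 1 - floor(W/4) and SWH = GSR + floor(3W/4), computed mod 2^48 and with SWL never dropping below ISR. A userspace sketch with the same macros spelled out (the ISR clamp is simplified to a plain comparison instead of before48()):

#include <stdint.h>
#include <stdio.h>
#include <inttypes.h>

#define UINT48_MAX	0xFFFFFFFFFFFFULL
#define ADD48(a, b)	(((a) + (b)) & UINT48_MAX)
#define SUB48(a, b)	(((a) + (UINT48_MAX + 1) - (b)) & UINT48_MAX)

static void update_window(uint64_t gsr, uint64_t isr, uint64_t w,
			  uint64_t *swl, uint64_t *swh)
{
	*swl = SUB48(ADD48(gsr, 1), w / 4);	/* lower edge of validity window */
	if (*swl < isr)				/* simplified stand-in for before48() */
		*swl = isr;
	*swh = ADD48(gsr, (3 * w) / 4);		/* upper edge of validity window */
}

int main(void)
{
	uint64_t swl, swh;

	update_window(/* gsr */ 1000, /* isr */ 900, /* seq window */ 100, &swl, &swh);
	printf("SWL=%" PRIu64 " SWH=%" PRIu64 "\n", swl, swh);	/* SWL=976, SWH=1075 */
	return 0;
}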
+ * Therefore it is safer to always make sure that the Sequence Window + * is not artificially extended by a peer who grows SWL downwards by + * continually updating the feature-remote Sequence-Window. + * If sequence numbers wrap it is bad luck. But that will take a while + * (48 bit), and this measure prevents Sequence-number attacks. + */ + if (before48(dp->dccps_swl, dp->dccps_isr)) + dp->dccps_swl = dp->dccps_isr; + dp->dccps_swh = ADD48(dp->dccps_gsr, (3 * dp->dccps_r_seq_win) / 4); +} + +static inline void dccp_update_gss(struct sock *sk, u64 seq) +{ + struct dccp_sock *dp = dccp_sk(sk); + + dp->dccps_gss = seq; + /* Ack validity window depends on local Sequence Window value (7.5.1) */ + dp->dccps_awl = SUB48(ADD48(dp->dccps_gss, 1), dp->dccps_l_seq_win); + /* Adjust AWL so that it is not below ISS - see comment above for SWL */ + if (before48(dp->dccps_awl, dp->dccps_iss)) + dp->dccps_awl = dp->dccps_iss; + dp->dccps_awh = dp->dccps_gss; +} + +static inline int dccp_ackvec_pending(const struct sock *sk) +{ + return dccp_sk(sk)->dccps_hc_rx_ackvec != NULL && + !dccp_ackvec_is_empty(dccp_sk(sk)->dccps_hc_rx_ackvec); +} + +static inline int dccp_ack_pending(const struct sock *sk) +{ + return dccp_ackvec_pending(sk) || inet_csk_ack_scheduled(sk); +} + +int dccp_feat_signal_nn_change(struct sock *sk, u8 feat, u64 nn_val); +int dccp_feat_finalise_settings(struct dccp_sock *dp); +int dccp_feat_server_ccid_dependencies(struct dccp_request_sock *dreq); +int dccp_feat_insert_opts(struct dccp_sock*, struct dccp_request_sock*, + struct sk_buff *skb); +int dccp_feat_activate_values(struct sock *sk, struct list_head *fn); +void dccp_feat_list_purge(struct list_head *fn_list); + +int dccp_insert_options(struct sock *sk, struct sk_buff *skb); +int dccp_insert_options_rsk(struct dccp_request_sock *, struct sk_buff *); +u32 dccp_timestamp(void); +void dccp_timestamping_init(void); +int dccp_insert_option(struct sk_buff *skb, unsigned char option, + const void *value, unsigned char len); + +#ifdef CONFIG_SYSCTL +int dccp_sysctl_init(void); +void dccp_sysctl_exit(void); +#else +static inline int dccp_sysctl_init(void) +{ + return 0; +} + +static inline void dccp_sysctl_exit(void) +{ +} +#endif + +#endif /* _DCCP_H */ diff --git a/net/dccp/diag.c b/net/dccp/diag.c new file mode 100644 index 000000000..8a82c5a2c --- /dev/null +++ b/net/dccp/diag.c @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/diag.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + + +#include +#include + +#include "ccid.h" +#include "dccp.h" + +static void dccp_get_info(struct sock *sk, struct tcp_info *info) +{ + struct dccp_sock *dp = dccp_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + + memset(info, 0, sizeof(*info)); + + info->tcpi_state = sk->sk_state; + info->tcpi_retransmits = icsk->icsk_retransmits; + info->tcpi_probes = icsk->icsk_probes_out; + info->tcpi_backoff = icsk->icsk_backoff; + info->tcpi_pmtu = icsk->icsk_pmtu_cookie; + + if (dp->dccps_hc_rx_ackvec != NULL) + info->tcpi_options |= TCPI_OPT_SACK; + + if (dp->dccps_hc_rx_ccid != NULL) + ccid_hc_rx_get_info(dp->dccps_hc_rx_ccid, sk, info); + + if (dp->dccps_hc_tx_ccid != NULL) + ccid_hc_tx_get_info(dp->dccps_hc_tx_ccid, sk, info); +} + +static void dccp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *_info) +{ + r->idiag_rqueue = r->idiag_wqueue = 0; + + if (_info != NULL) + dccp_get_info(sk, _info); +} + +static void dccp_diag_dump(struct sk_buff *skb, struct 
netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + inet_diag_dump_icsk(&dccp_hashinfo, skb, cb, r); +} + +static int dccp_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + return inet_diag_dump_one_icsk(&dccp_hashinfo, cb, req); +} + +static const struct inet_diag_handler dccp_diag_handler = { + .dump = dccp_diag_dump, + .dump_one = dccp_diag_dump_one, + .idiag_get_info = dccp_diag_get_info, + .idiag_type = IPPROTO_DCCP, + .idiag_info_size = sizeof(struct tcp_info), +}; + +static int __init dccp_diag_init(void) +{ + return inet_diag_register(&dccp_diag_handler); +} + +static void __exit dccp_diag_fini(void) +{ + inet_diag_unregister(&dccp_diag_handler); +} + +module_init(dccp_diag_init); +module_exit(dccp_diag_fini); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arnaldo Carvalho de Melo "); +MODULE_DESCRIPTION("DCCP inet_diag handler"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-33 /* AF_INET - IPPROTO_DCCP */); diff --git a/net/dccp/feat.c b/net/dccp/feat.c new file mode 100644 index 000000000..54086bb05 --- /dev/null +++ b/net/dccp/feat.c @@ -0,0 +1,1577 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/feat.c + * + * Feature negotiation for the DCCP protocol (RFC 4340, section 6) + * + * Copyright (c) 2008 Gerrit Renker + * Rewrote from scratch, some bits from earlier code by + * Copyright (c) 2005 Andrea Bittau + * + * ASSUMPTIONS + * ----------- + * o Feature negotiation is coordinated with connection setup (as in TCP), wild + * changes of parameters of an established connection are not supported. + * o Changing non-negotiable (NN) values is supported in state OPEN/PARTOPEN. + * o All currently known SP features have 1-byte quantities. If in the future + * extensions of RFCs 4340..42 define features with item lengths larger than + * one byte, a feature-specific extension of the code will be required. + */ +#include +#include +#include "ccid.h" +#include "feat.h" + +/* feature-specific sysctls - initialised to the defaults from RFC 4340, 6.4 */ +unsigned long sysctl_dccp_sequence_window __read_mostly = 100; +int sysctl_dccp_rx_ccid __read_mostly = 2, + sysctl_dccp_tx_ccid __read_mostly = 2; + +/* + * Feature activation handlers. + * + * These all use an u64 argument, to provide enough room for NN/SP features. At + * this stage the negotiated values have been checked to be within their range. 
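The comment above describes a table-driven scheme: every understood feature gets one activation handler that receives a u64 value (wide enough for both NN and SP quantities) plus an RX/TX flag, and the table that follows in the patch wires feature numbers to those handlers. The sketch below only illustrates that dispatch shape; the handlers and the lookup are stand-ins, while the two feature numbers are the real ones from RFC 4340, 6.4.

#include <stdint.h>
#include <stdio.h>
#include <stdbool.h>

#define DCCPF_CCID		1
#define DCCPF_SEQUENCE_WINDOW	3

static int activate_ccid(uint64_t val, bool rx)
{
	printf("using CCID-%llu for the %s half-connection\n",
	       (unsigned long long)val, rx ? "RX" : "TX");
	return 0;
}

static int activate_seq_win(uint64_t val, bool rx)
{
	printf("%s sequence window = %llu\n", rx ? "remote" : "local",
	       (unsigned long long)val);
	return 0;
}

static const struct {
	unsigned int feat_num;
	int (*activation_hdlr)(uint64_t val, bool rx);
} feat_table[] = {
	{ DCCPF_CCID,		  activate_ccid },
	{ DCCPF_SEQUENCE_WINDOW,  activate_seq_win },
};

static int activate(unsigned int feat_num, uint64_t val, bool rx)
{
	unsigned int i;

	for (i = 0; i < sizeof(feat_table) / sizeof(feat_table[0]); i++)
		if (feat_table[i].feat_num == feat_num)
			return feat_table[i].activation_hdlr(val, rx);
	return -1;	/* unknown feature */
}

int main(void)
{
	activate(DCCPF_CCID, 2, false);			/* local CCID -> TX side */
	activate(DCCPF_SEQUENCE_WINDOW, 100, true);	/* remote Sequence Window */
	return 0;
}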
+ */ +static int dccp_hdlr_ccid(struct sock *sk, u64 ccid, bool rx) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct ccid *new_ccid = ccid_new(ccid, sk, rx); + + if (new_ccid == NULL) + return -ENOMEM; + + if (rx) { + ccid_hc_rx_delete(dp->dccps_hc_rx_ccid, sk); + dp->dccps_hc_rx_ccid = new_ccid; + } else { + ccid_hc_tx_delete(dp->dccps_hc_tx_ccid, sk); + dp->dccps_hc_tx_ccid = new_ccid; + } + return 0; +} + +static int dccp_hdlr_seq_win(struct sock *sk, u64 seq_win, bool rx) +{ + struct dccp_sock *dp = dccp_sk(sk); + + if (rx) { + dp->dccps_r_seq_win = seq_win; + /* propagate changes to update SWL/SWH */ + dccp_update_gsr(sk, dp->dccps_gsr); + } else { + dp->dccps_l_seq_win = seq_win; + /* propagate changes to update AWL */ + dccp_update_gss(sk, dp->dccps_gss); + } + return 0; +} + +static int dccp_hdlr_ack_ratio(struct sock *sk, u64 ratio, bool rx) +{ + if (rx) + dccp_sk(sk)->dccps_r_ack_ratio = ratio; + else + dccp_sk(sk)->dccps_l_ack_ratio = ratio; + return 0; +} + +static int dccp_hdlr_ackvec(struct sock *sk, u64 enable, bool rx) +{ + struct dccp_sock *dp = dccp_sk(sk); + + if (rx) { + if (enable && dp->dccps_hc_rx_ackvec == NULL) { + dp->dccps_hc_rx_ackvec = dccp_ackvec_alloc(gfp_any()); + if (dp->dccps_hc_rx_ackvec == NULL) + return -ENOMEM; + } else if (!enable) { + dccp_ackvec_free(dp->dccps_hc_rx_ackvec); + dp->dccps_hc_rx_ackvec = NULL; + } + } + return 0; +} + +static int dccp_hdlr_ndp(struct sock *sk, u64 enable, bool rx) +{ + if (!rx) + dccp_sk(sk)->dccps_send_ndp_count = (enable > 0); + return 0; +} + +/* + * Minimum Checksum Coverage is located at the RX side (9.2.1). This means that + * `rx' holds when the sending peer informs about his partial coverage via a + * ChangeR() option. In the other case, we are the sender and the receiver + * announces its coverage via ChangeL() options. The policy here is to honour + * such communication by enabling the corresponding partial coverage - but only + * if it has not been set manually before; the warning here means that all + * packets will be dropped. + */ +static int dccp_hdlr_min_cscov(struct sock *sk, u64 cscov, bool rx) +{ + struct dccp_sock *dp = dccp_sk(sk); + + if (rx) + dp->dccps_pcrlen = cscov; + else { + if (dp->dccps_pcslen == 0) + dp->dccps_pcslen = cscov; + else if (cscov > dp->dccps_pcslen) + DCCP_WARN("CsCov %u too small, peer requires >= %u\n", + dp->dccps_pcslen, (u8)cscov); + } + return 0; +} + +static const struct { + u8 feat_num; /* DCCPF_xxx */ + enum dccp_feat_type rxtx; /* RX or TX */ + enum dccp_feat_type reconciliation; /* SP or NN */ + u8 default_value; /* as in 6.4 */ + int (*activation_hdlr)(struct sock *sk, u64 val, bool rx); +/* + * Lookup table for location and type of features (from RFC 4340/4342) + * +--------------------------+----+-----+----+----+---------+-----------+ + * | Feature | Location | Reconc. 
| Initial | Section | + * | | RX | TX | SP | NN | Value | Reference | + * +--------------------------+----+-----+----+----+---------+-----------+ + * | DCCPF_CCID | | X | X | | 2 | 10 | + * | DCCPF_SHORT_SEQNOS | | X | X | | 0 | 7.6.1 | + * | DCCPF_SEQUENCE_WINDOW | | X | | X | 100 | 7.5.2 | + * | DCCPF_ECN_INCAPABLE | X | | X | | 0 | 12.1 | + * | DCCPF_ACK_RATIO | | X | | X | 2 | 11.3 | + * | DCCPF_SEND_ACK_VECTOR | X | | X | | 0 | 11.5 | + * | DCCPF_SEND_NDP_COUNT | | X | X | | 0 | 7.7.2 | + * | DCCPF_MIN_CSUM_COVER | X | | X | | 0 | 9.2.1 | + * | DCCPF_DATA_CHECKSUM | X | | X | | 0 | 9.3.1 | + * | DCCPF_SEND_LEV_RATE | X | | X | | 0 | 4342/8.4 | + * +--------------------------+----+-----+----+----+---------+-----------+ + */ +} dccp_feat_table[] = { + { DCCPF_CCID, FEAT_AT_TX, FEAT_SP, 2, dccp_hdlr_ccid }, + { DCCPF_SHORT_SEQNOS, FEAT_AT_TX, FEAT_SP, 0, NULL }, + { DCCPF_SEQUENCE_WINDOW, FEAT_AT_TX, FEAT_NN, 100, dccp_hdlr_seq_win }, + { DCCPF_ECN_INCAPABLE, FEAT_AT_RX, FEAT_SP, 0, NULL }, + { DCCPF_ACK_RATIO, FEAT_AT_TX, FEAT_NN, 2, dccp_hdlr_ack_ratio}, + { DCCPF_SEND_ACK_VECTOR, FEAT_AT_RX, FEAT_SP, 0, dccp_hdlr_ackvec }, + { DCCPF_SEND_NDP_COUNT, FEAT_AT_TX, FEAT_SP, 0, dccp_hdlr_ndp }, + { DCCPF_MIN_CSUM_COVER, FEAT_AT_RX, FEAT_SP, 0, dccp_hdlr_min_cscov}, + { DCCPF_DATA_CHECKSUM, FEAT_AT_RX, FEAT_SP, 0, NULL }, + { DCCPF_SEND_LEV_RATE, FEAT_AT_RX, FEAT_SP, 0, NULL }, +}; +#define DCCP_FEAT_SUPPORTED_MAX ARRAY_SIZE(dccp_feat_table) + +/** + * dccp_feat_index - Hash function to map feature number into array position + * @feat_num: feature to hash, one of %dccp_feature_numbers + * + * Returns consecutive array index or -1 if the feature is not understood. + */ +static int dccp_feat_index(u8 feat_num) +{ + /* The first 9 entries are occupied by the types from RFC 4340, 6.4 */ + if (feat_num > DCCPF_RESERVED && feat_num <= DCCPF_DATA_CHECKSUM) + return feat_num - 1; + + /* + * Other features: add cases for new feature types here after adding + * them to the above table. + */ + switch (feat_num) { + case DCCPF_SEND_LEV_RATE: + return DCCP_FEAT_SUPPORTED_MAX - 1; + } + return -1; +} + +static u8 dccp_feat_type(u8 feat_num) +{ + int idx = dccp_feat_index(feat_num); + + if (idx < 0) + return FEAT_UNKNOWN; + return dccp_feat_table[idx].reconciliation; +} + +static int dccp_feat_default_value(u8 feat_num) +{ + int idx = dccp_feat_index(feat_num); + /* + * There are no default values for unknown features, so encountering a + * negative index here indicates a serious problem somewhere else. + */ + DCCP_BUG_ON(idx < 0); + + return idx < 0 ? 0 : dccp_feat_table[idx].default_value; +} + +/* + * Debugging and verbose-printing section + */ +static const char *dccp_feat_fname(const u8 feat) +{ + static const char *const feature_names[] = { + [DCCPF_RESERVED] = "Reserved", + [DCCPF_CCID] = "CCID", + [DCCPF_SHORT_SEQNOS] = "Allow Short Seqnos", + [DCCPF_SEQUENCE_WINDOW] = "Sequence Window", + [DCCPF_ECN_INCAPABLE] = "ECN Incapable", + [DCCPF_ACK_RATIO] = "Ack Ratio", + [DCCPF_SEND_ACK_VECTOR] = "Send ACK Vector", + [DCCPF_SEND_NDP_COUNT] = "Send NDP Count", + [DCCPF_MIN_CSUM_COVER] = "Min. 
Csum Coverage", + [DCCPF_DATA_CHECKSUM] = "Send Data Checksum", + }; + if (feat > DCCPF_DATA_CHECKSUM && feat < DCCPF_MIN_CCID_SPECIFIC) + return feature_names[DCCPF_RESERVED]; + + if (feat == DCCPF_SEND_LEV_RATE) + return "Send Loss Event Rate"; + if (feat >= DCCPF_MIN_CCID_SPECIFIC) + return "CCID-specific"; + + return feature_names[feat]; +} + +static const char *const dccp_feat_sname[] = { + "DEFAULT", "INITIALISING", "CHANGING", "UNSTABLE", "STABLE", +}; + +#ifdef CONFIG_IP_DCCP_DEBUG +static const char *dccp_feat_oname(const u8 opt) +{ + switch (opt) { + case DCCPO_CHANGE_L: return "Change_L"; + case DCCPO_CONFIRM_L: return "Confirm_L"; + case DCCPO_CHANGE_R: return "Change_R"; + case DCCPO_CONFIRM_R: return "Confirm_R"; + } + return NULL; +} + +static void dccp_feat_printval(u8 feat_num, dccp_feat_val const *val) +{ + u8 i, type = dccp_feat_type(feat_num); + + if (val == NULL || (type == FEAT_SP && val->sp.vec == NULL)) + dccp_pr_debug_cat("(NULL)"); + else if (type == FEAT_SP) + for (i = 0; i < val->sp.len; i++) + dccp_pr_debug_cat("%s%u", i ? " " : "", val->sp.vec[i]); + else if (type == FEAT_NN) + dccp_pr_debug_cat("%llu", (unsigned long long)val->nn); + else + dccp_pr_debug_cat("unknown type %u", type); +} + +static void dccp_feat_printvals(u8 feat_num, u8 *list, u8 len) +{ + u8 type = dccp_feat_type(feat_num); + dccp_feat_val fval = { .sp.vec = list, .sp.len = len }; + + if (type == FEAT_NN) + fval.nn = dccp_decode_value_var(list, len); + dccp_feat_printval(feat_num, &fval); +} + +static void dccp_feat_print_entry(struct dccp_feat_entry const *entry) +{ + dccp_debug(" * %s %s = ", entry->is_local ? "local" : "remote", + dccp_feat_fname(entry->feat_num)); + dccp_feat_printval(entry->feat_num, &entry->val); + dccp_pr_debug_cat(", state=%s %s\n", dccp_feat_sname[entry->state], + entry->needs_confirm ? "(Confirm pending)" : ""); +} + +#define dccp_feat_print_opt(opt, feat, val, len, mandatory) do { \ + dccp_pr_debug("%s(%s, ", dccp_feat_oname(opt), dccp_feat_fname(feat));\ + dccp_feat_printvals(feat, val, len); \ + dccp_pr_debug_cat(") %s\n", mandatory ? "!" : ""); } while (0) + +#define dccp_feat_print_fnlist(fn_list) { \ + const struct dccp_feat_entry *___entry; \ + \ + dccp_pr_debug("List Dump:\n"); \ + list_for_each_entry(___entry, fn_list, node) \ + dccp_feat_print_entry(___entry); \ +} +#else /* ! CONFIG_IP_DCCP_DEBUG */ +#define dccp_feat_print_opt(opt, feat, val, len, mandatory) +#define dccp_feat_print_fnlist(fn_list) +#endif + +static int __dccp_feat_activate(struct sock *sk, const int idx, + const bool is_local, dccp_feat_val const *fval) +{ + bool rx; + u64 val; + + if (idx < 0 || idx >= DCCP_FEAT_SUPPORTED_MAX) + return -1; + if (dccp_feat_table[idx].activation_hdlr == NULL) + return 0; + + if (fval == NULL) { + val = dccp_feat_table[idx].default_value; + } else if (dccp_feat_table[idx].reconciliation == FEAT_SP) { + if (fval->sp.vec == NULL) { + /* + * This can happen when an empty Confirm is sent + * for an SP (i.e. known) feature. In this case + * we would be using the default anyway. + */ + DCCP_CRIT("Feature #%d undefined: using default", idx); + val = dccp_feat_table[idx].default_value; + } else { + val = fval->sp.vec[0]; + } + } else { + val = fval->nn; + } + + /* Location is RX if this is a local-RX or remote-TX feature */ + rx = (is_local == (dccp_feat_table[idx].rxtx == FEAT_AT_RX)); + + dccp_debug(" -> activating %s %s, %sval=%llu\n", rx ? "RX" : "TX", + dccp_feat_fname(dccp_feat_table[idx].feat_num), + fval ? 
"" : "default ", (unsigned long long)val); + + return dccp_feat_table[idx].activation_hdlr(sk, val, rx); +} + +/** + * dccp_feat_activate - Activate feature value on socket + * @sk: fully connected DCCP socket (after handshake is complete) + * @feat_num: feature to activate, one of %dccp_feature_numbers + * @local: whether local (1) or remote (0) @feat_num is meant + * @fval: the value (SP or NN) to activate, or NULL to use the default value + * + * For general use this function is preferable over __dccp_feat_activate(). + */ +static int dccp_feat_activate(struct sock *sk, u8 feat_num, bool local, + dccp_feat_val const *fval) +{ + return __dccp_feat_activate(sk, dccp_feat_index(feat_num), local, fval); +} + +/* Test for "Req'd" feature (RFC 4340, 6.4) */ +static inline int dccp_feat_must_be_understood(u8 feat_num) +{ + return feat_num == DCCPF_CCID || feat_num == DCCPF_SHORT_SEQNOS || + feat_num == DCCPF_SEQUENCE_WINDOW; +} + +/* copy constructor, fval must not already contain allocated memory */ +static int dccp_feat_clone_sp_val(dccp_feat_val *fval, u8 const *val, u8 len) +{ + fval->sp.len = len; + if (fval->sp.len > 0) { + fval->sp.vec = kmemdup(val, len, gfp_any()); + if (fval->sp.vec == NULL) { + fval->sp.len = 0; + return -ENOMEM; + } + } + return 0; +} + +static void dccp_feat_val_destructor(u8 feat_num, dccp_feat_val *val) +{ + if (unlikely(val == NULL)) + return; + if (dccp_feat_type(feat_num) == FEAT_SP) + kfree(val->sp.vec); + memset(val, 0, sizeof(*val)); +} + +static struct dccp_feat_entry * + dccp_feat_clone_entry(struct dccp_feat_entry const *original) +{ + struct dccp_feat_entry *new; + u8 type = dccp_feat_type(original->feat_num); + + if (type == FEAT_UNKNOWN) + return NULL; + + new = kmemdup(original, sizeof(struct dccp_feat_entry), gfp_any()); + if (new == NULL) + return NULL; + + if (type == FEAT_SP && dccp_feat_clone_sp_val(&new->val, + original->val.sp.vec, + original->val.sp.len)) { + kfree(new); + return NULL; + } + return new; +} + +static void dccp_feat_entry_destructor(struct dccp_feat_entry *entry) +{ + if (entry != NULL) { + dccp_feat_val_destructor(entry->feat_num, &entry->val); + kfree(entry); + } +} + +/* + * List management functions + * + * Feature negotiation lists rely on and maintain the following invariants: + * - each feat_num in the list is known, i.e. we know its type and default value + * - each feat_num/is_local combination is unique (old entries are overwritten) + * - SP values are always freshly allocated + * - list is sorted in increasing order of feature number (faster lookup) + */ +static struct dccp_feat_entry *dccp_feat_list_lookup(struct list_head *fn_list, + u8 feat_num, bool is_local) +{ + struct dccp_feat_entry *entry; + + list_for_each_entry(entry, fn_list, node) { + if (entry->feat_num == feat_num && entry->is_local == is_local) + return entry; + else if (entry->feat_num > feat_num) + break; + } + return NULL; +} + +/** + * dccp_feat_entry_new - Central list update routine (called by all others) + * @head: list to add to + * @feat: feature number + * @local: whether the local (1) or remote feature with number @feat is meant + * + * This is the only constructor and serves to ensure the above invariants. 
+ */ +static struct dccp_feat_entry * + dccp_feat_entry_new(struct list_head *head, u8 feat, bool local) +{ + struct dccp_feat_entry *entry; + + list_for_each_entry(entry, head, node) + if (entry->feat_num == feat && entry->is_local == local) { + dccp_feat_val_destructor(entry->feat_num, &entry->val); + return entry; + } else if (entry->feat_num > feat) { + head = &entry->node; + break; + } + + entry = kmalloc(sizeof(*entry), gfp_any()); + if (entry != NULL) { + entry->feat_num = feat; + entry->is_local = local; + list_add_tail(&entry->node, head); + } + return entry; +} + +/** + * dccp_feat_push_change - Add/overwrite a Change option in the list + * @fn_list: feature-negotiation list to update + * @feat: one of %dccp_feature_numbers + * @local: whether local (1) or remote (0) @feat_num is meant + * @mandatory: whether to use Mandatory feature negotiation options + * @fval: pointer to NN/SP value to be inserted (will be copied) + */ +static int dccp_feat_push_change(struct list_head *fn_list, u8 feat, u8 local, + u8 mandatory, dccp_feat_val *fval) +{ + struct dccp_feat_entry *new = dccp_feat_entry_new(fn_list, feat, local); + + if (new == NULL) + return -ENOMEM; + + new->feat_num = feat; + new->is_local = local; + new->state = FEAT_INITIALISING; + new->needs_confirm = false; + new->empty_confirm = false; + new->val = *fval; + new->needs_mandatory = mandatory; + + return 0; +} + +/** + * dccp_feat_push_confirm - Add a Confirm entry to the FN list + * @fn_list: feature-negotiation list to add to + * @feat: one of %dccp_feature_numbers + * @local: whether local (1) or remote (0) @feat_num is being confirmed + * @fval: pointer to NN/SP value to be inserted or NULL + * + * Returns 0 on success, a Reset code for further processing otherwise. + */ +static int dccp_feat_push_confirm(struct list_head *fn_list, u8 feat, u8 local, + dccp_feat_val *fval) +{ + struct dccp_feat_entry *new = dccp_feat_entry_new(fn_list, feat, local); + + if (new == NULL) + return DCCP_RESET_CODE_TOO_BUSY; + + new->feat_num = feat; + new->is_local = local; + new->state = FEAT_STABLE; /* transition in 6.6.2 */ + new->needs_confirm = true; + new->empty_confirm = (fval == NULL); + new->val.nn = 0; /* zeroes the whole structure */ + if (!new->empty_confirm) + new->val = *fval; + new->needs_mandatory = false; + + return 0; +} + +static int dccp_push_empty_confirm(struct list_head *fn_list, u8 feat, u8 local) +{ + return dccp_feat_push_confirm(fn_list, feat, local, NULL); +} + +static inline void dccp_feat_list_pop(struct dccp_feat_entry *entry) +{ + list_del(&entry->node); + dccp_feat_entry_destructor(entry); +} + +void dccp_feat_list_purge(struct list_head *fn_list) +{ + struct dccp_feat_entry *entry, *next; + + list_for_each_entry_safe(entry, next, fn_list, node) + dccp_feat_entry_destructor(entry); + INIT_LIST_HEAD(fn_list); +} +EXPORT_SYMBOL_GPL(dccp_feat_list_purge); + +/* generate @to as full clone of @from - @to must not contain any nodes */ +int dccp_feat_clone_list(struct list_head const *from, struct list_head *to) +{ + struct dccp_feat_entry *entry, *new; + + INIT_LIST_HEAD(to); + list_for_each_entry(entry, from, node) { + new = dccp_feat_clone_entry(entry); + if (new == NULL) + goto cloning_failed; + list_add_tail(&new->node, to); + } + return 0; + +cloning_failed: + dccp_feat_list_purge(to); + return -ENOMEM; +} + +/** + * dccp_feat_valid_nn_length - Enforce length constraints on NN options + * @feat_num: feature to return length of, one of %dccp_feature_numbers + * + * Length is between 0 and %DCCP_OPTVAL_MAXLEN. 
Used for outgoing packets only, + * incoming options are accepted as long as their values are valid. + */ +static u8 dccp_feat_valid_nn_length(u8 feat_num) +{ + if (feat_num == DCCPF_ACK_RATIO) /* RFC 4340, 11.3 and 6.6.8 */ + return 2; + if (feat_num == DCCPF_SEQUENCE_WINDOW) /* RFC 4340, 7.5.2 and 6.5 */ + return 6; + return 0; +} + +static u8 dccp_feat_is_valid_nn_val(u8 feat_num, u64 val) +{ + switch (feat_num) { + case DCCPF_ACK_RATIO: + return val <= DCCPF_ACK_RATIO_MAX; + case DCCPF_SEQUENCE_WINDOW: + return val >= DCCPF_SEQ_WMIN && val <= DCCPF_SEQ_WMAX; + } + return 0; /* feature unknown - so we can't tell */ +} + +/* check that SP values are within the ranges defined in RFC 4340 */ +static u8 dccp_feat_is_valid_sp_val(u8 feat_num, u8 val) +{ + switch (feat_num) { + case DCCPF_CCID: + return val == DCCPC_CCID2 || val == DCCPC_CCID3; + /* Type-check Boolean feature values: */ + case DCCPF_SHORT_SEQNOS: + case DCCPF_ECN_INCAPABLE: + case DCCPF_SEND_ACK_VECTOR: + case DCCPF_SEND_NDP_COUNT: + case DCCPF_DATA_CHECKSUM: + case DCCPF_SEND_LEV_RATE: + return val < 2; + case DCCPF_MIN_CSUM_COVER: + return val < 16; + } + return 0; /* feature unknown */ +} + +static u8 dccp_feat_sp_list_ok(u8 feat_num, u8 const *sp_list, u8 sp_len) +{ + if (sp_list == NULL || sp_len < 1) + return 0; + while (sp_len--) + if (!dccp_feat_is_valid_sp_val(feat_num, *sp_list++)) + return 0; + return 1; +} + +/** + * dccp_feat_insert_opts - Generate FN options from current list state + * @skb: next sk_buff to be sent to the peer + * @dp: for client during handshake and general negotiation + * @dreq: used by the server only (all Changes/Confirms in LISTEN/RESPOND) + */ +int dccp_feat_insert_opts(struct dccp_sock *dp, struct dccp_request_sock *dreq, + struct sk_buff *skb) +{ + struct list_head *fn = dreq ? &dreq->dreq_featneg : &dp->dccps_featneg; + struct dccp_feat_entry *pos, *next; + u8 opt, type, len, *ptr, nn_in_nbo[DCCP_OPTVAL_MAXLEN]; + bool rpt; + + /* put entries into @skb in the order they appear in the list */ + list_for_each_entry_safe_reverse(pos, next, fn, node) { + opt = dccp_feat_genopt(pos); + type = dccp_feat_type(pos->feat_num); + rpt = false; + + if (pos->empty_confirm) { + len = 0; + ptr = NULL; + } else { + if (type == FEAT_SP) { + len = pos->val.sp.len; + ptr = pos->val.sp.vec; + rpt = pos->needs_confirm; + } else if (type == FEAT_NN) { + len = dccp_feat_valid_nn_length(pos->feat_num); + ptr = nn_in_nbo; + dccp_encode_value_var(pos->val.nn, ptr, len); + } else { + DCCP_BUG("unknown feature %u", pos->feat_num); + return -1; + } + } + dccp_feat_print_opt(opt, pos->feat_num, ptr, len, 0); + + if (dccp_insert_fn_opt(skb, opt, pos->feat_num, ptr, len, rpt)) + return -1; + if (pos->needs_mandatory && dccp_insert_option_mandatory(skb)) + return -1; + + if (skb->sk->sk_state == DCCP_OPEN && + (opt == DCCPO_CONFIRM_R || opt == DCCPO_CONFIRM_L)) { + /* + * Confirms don't get retransmitted (6.6.3) once the + * connection is in state OPEN + */ + dccp_feat_list_pop(pos); + } else { + /* + * Enter CHANGING after transmitting the Change + * option (6.6.2). + */ + if (pos->state == FEAT_INITIALISING) + pos->state = FEAT_CHANGING; + } + } + return 0; +} + +/** + * __feat_register_nn - Register new NN value on socket + * @fn: feature-negotiation list to register with + * @feat: an NN feature from %dccp_feature_numbers + * @mandatory: use Mandatory option if 1 + * @nn_val: value to register (restricted to 4 bytes) + * + * Note that NN features are local by definition (RFC 4340, 6.3.2). 
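+ * Example: registering DCCPF_SEQUENCE_WINDOW with @nn_val = 100 is a no-op, + * since 100 is already the default value of that feature (see table above).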
+ */ +static int __feat_register_nn(struct list_head *fn, u8 feat, + u8 mandatory, u64 nn_val) +{ + dccp_feat_val fval = { .nn = nn_val }; + + if (dccp_feat_type(feat) != FEAT_NN || + !dccp_feat_is_valid_nn_val(feat, nn_val)) + return -EINVAL; + + /* Don't bother with default values, they will be activated anyway. */ + if (nn_val - (u64)dccp_feat_default_value(feat) == 0) + return 0; + + return dccp_feat_push_change(fn, feat, 1, mandatory, &fval); +} + +/** + * __feat_register_sp - Register new SP value/list on socket + * @fn: feature-negotiation list to register with + * @feat: an SP feature from %dccp_feature_numbers + * @is_local: whether the local (1) or the remote (0) @feat is meant + * @mandatory: use Mandatory option if 1 + * @sp_val: SP value followed by optional preference list + * @sp_len: length of @sp_val in bytes + */ +static int __feat_register_sp(struct list_head *fn, u8 feat, u8 is_local, + u8 mandatory, u8 const *sp_val, u8 sp_len) +{ + dccp_feat_val fval; + + if (dccp_feat_type(feat) != FEAT_SP || + !dccp_feat_sp_list_ok(feat, sp_val, sp_len)) + return -EINVAL; + + /* Avoid negotiating alien CCIDs by only advertising supported ones */ + if (feat == DCCPF_CCID && !ccid_support_check(sp_val, sp_len)) + return -EOPNOTSUPP; + + if (dccp_feat_clone_sp_val(&fval, sp_val, sp_len)) + return -ENOMEM; + + if (dccp_feat_push_change(fn, feat, is_local, mandatory, &fval)) { + kfree(fval.sp.vec); + return -ENOMEM; + } + + return 0; +} + +/** + * dccp_feat_register_sp - Register requests to change SP feature values + * @sk: client or listening socket + * @feat: one of %dccp_feature_numbers + * @is_local: whether the local (1) or remote (0) @feat is meant + * @list: array of preferred values, in descending order of preference + * @len: length of @list in bytes + */ +int dccp_feat_register_sp(struct sock *sk, u8 feat, u8 is_local, + u8 const *list, u8 len) +{ /* any changes must be registered before establishing the connection */ + if (sk->sk_state != DCCP_CLOSED) + return -EISCONN; + if (dccp_feat_type(feat) != FEAT_SP) + return -EINVAL; + return __feat_register_sp(&dccp_sk(sk)->dccps_featneg, feat, is_local, + 0, list, len); +} + +/** + * dccp_feat_nn_get - Query current/pending value of NN feature + * @sk: DCCP socket of an established connection + * @feat: NN feature number from %dccp_feature_numbers + * + * For a known NN feature, returns value currently being negotiated, or + * current (confirmed) value if no negotiation is going on. + */ +u64 dccp_feat_nn_get(struct sock *sk, u8 feat) +{ + if (dccp_feat_type(feat) == FEAT_NN) { + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_feat_entry *entry; + + entry = dccp_feat_list_lookup(&dp->dccps_featneg, feat, 1); + if (entry != NULL) + return entry->val.nn; + + switch (feat) { + case DCCPF_ACK_RATIO: + return dp->dccps_l_ack_ratio; + case DCCPF_SEQUENCE_WINDOW: + return dp->dccps_l_seq_win; + } + } + DCCP_BUG("attempt to look up unsupported feature %u", feat); + return 0; +} +EXPORT_SYMBOL_GPL(dccp_feat_nn_get); + +/** + * dccp_feat_signal_nn_change - Update NN values for an established connection + * @sk: DCCP socket of an established connection + * @feat: NN feature number from %dccp_feature_numbers + * @nn_val: the new value to use + * + * This function is used to communicate NN updates out-of-band. 
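+ * It is used, for example, by CCID-2 to adjust the Ack Ratio while the + * connection is open.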
+ */ +int dccp_feat_signal_nn_change(struct sock *sk, u8 feat, u64 nn_val) +{ + struct list_head *fn = &dccp_sk(sk)->dccps_featneg; + dccp_feat_val fval = { .nn = nn_val }; + struct dccp_feat_entry *entry; + + if (sk->sk_state != DCCP_OPEN && sk->sk_state != DCCP_PARTOPEN) + return 0; + + if (dccp_feat_type(feat) != FEAT_NN || + !dccp_feat_is_valid_nn_val(feat, nn_val)) + return -EINVAL; + + if (nn_val == dccp_feat_nn_get(sk, feat)) + return 0; /* already set or negotiation under way */ + + entry = dccp_feat_list_lookup(fn, feat, 1); + if (entry != NULL) { + dccp_pr_debug("Clobbering existing NN entry %llu -> %llu\n", + (unsigned long long)entry->val.nn, + (unsigned long long)nn_val); + dccp_feat_list_pop(entry); + } + + inet_csk_schedule_ack(sk); + return dccp_feat_push_change(fn, feat, 1, 0, &fval); +} +EXPORT_SYMBOL_GPL(dccp_feat_signal_nn_change); + +/* + * Tracking features whose value depend on the choice of CCID + * + * This is designed with an extension in mind so that a list walk could be done + * before activating any features. However, the existing framework was found to + * work satisfactorily up until now, the automatic verification is left open. + * When adding new CCIDs, add a corresponding dependency table here. + */ +static const struct ccid_dependency *dccp_feat_ccid_deps(u8 ccid, bool is_local) +{ + static const struct ccid_dependency ccid2_dependencies[2][2] = { + /* + * CCID2 mandates Ack Vectors (RFC 4341, 4.): as CCID is a TX + * feature and Send Ack Vector is an RX feature, `is_local' + * needs to be reversed. + */ + { /* Dependencies of the receiver-side (remote) CCID2 */ + { + .dependent_feat = DCCPF_SEND_ACK_VECTOR, + .is_local = true, + .is_mandatory = true, + .val = 1 + }, + { 0, 0, 0, 0 } + }, + { /* Dependencies of the sender-side (local) CCID2 */ + { + .dependent_feat = DCCPF_SEND_ACK_VECTOR, + .is_local = false, + .is_mandatory = true, + .val = 1 + }, + { 0, 0, 0, 0 } + } + }; + static const struct ccid_dependency ccid3_dependencies[2][5] = { + { /* + * Dependencies of the receiver-side CCID3 + */ + { /* locally disable Ack Vectors */ + .dependent_feat = DCCPF_SEND_ACK_VECTOR, + .is_local = true, + .is_mandatory = false, + .val = 0 + }, + { /* see below why Send Loss Event Rate is on */ + .dependent_feat = DCCPF_SEND_LEV_RATE, + .is_local = true, + .is_mandatory = true, + .val = 1 + }, + { /* NDP Count is needed as per RFC 4342, 6.1.1 */ + .dependent_feat = DCCPF_SEND_NDP_COUNT, + .is_local = false, + .is_mandatory = true, + .val = 1 + }, + { 0, 0, 0, 0 }, + }, + { /* + * CCID3 at the TX side: we request that the HC-receiver + * will not send Ack Vectors (they will be ignored, so + * Mandatory is not set); we enable Send Loss Event Rate + * (Mandatory since the implementation does not support + * the Loss Intervals option of RFC 4342, 8.6). + * The last two options are for peer's information only. 
+ */ + { + .dependent_feat = DCCPF_SEND_ACK_VECTOR, + .is_local = false, + .is_mandatory = false, + .val = 0 + }, + { + .dependent_feat = DCCPF_SEND_LEV_RATE, + .is_local = false, + .is_mandatory = true, + .val = 1 + }, + { /* this CCID does not support Ack Ratio */ + .dependent_feat = DCCPF_ACK_RATIO, + .is_local = true, + .is_mandatory = false, + .val = 0 + }, + { /* tell receiver we are sending NDP counts */ + .dependent_feat = DCCPF_SEND_NDP_COUNT, + .is_local = true, + .is_mandatory = false, + .val = 1 + }, + { 0, 0, 0, 0 } + } + }; + switch (ccid) { + case DCCPC_CCID2: + return ccid2_dependencies[is_local]; + case DCCPC_CCID3: + return ccid3_dependencies[is_local]; + default: + return NULL; + } +} + +/** + * dccp_feat_propagate_ccid - Resolve dependencies of features on choice of CCID + * @fn: feature-negotiation list to update + * @id: CCID number to track + * @is_local: whether TX CCID (1) or RX CCID (0) is meant + * + * This function needs to be called after registering all other features. + */ +static int dccp_feat_propagate_ccid(struct list_head *fn, u8 id, bool is_local) +{ + const struct ccid_dependency *table = dccp_feat_ccid_deps(id, is_local); + int i, rc = (table == NULL); + + for (i = 0; rc == 0 && table[i].dependent_feat != DCCPF_RESERVED; i++) + if (dccp_feat_type(table[i].dependent_feat) == FEAT_SP) + rc = __feat_register_sp(fn, table[i].dependent_feat, + table[i].is_local, + table[i].is_mandatory, + &table[i].val, 1); + else + rc = __feat_register_nn(fn, table[i].dependent_feat, + table[i].is_mandatory, + table[i].val); + return rc; +} + +/** + * dccp_feat_finalise_settings - Finalise settings before starting negotiation + * @dp: client or listening socket (settings will be inherited) + * + * This is called after all registrations (socket initialisation, sysctls, and + * sockopt calls), and before sending the first packet containing Change options + * (ie. client-Request or server-Response), to ensure internal consistency. + */ +int dccp_feat_finalise_settings(struct dccp_sock *dp) +{ + struct list_head *fn = &dp->dccps_featneg; + struct dccp_feat_entry *entry; + int i = 2, ccids[2] = { -1, -1 }; + + /* + * Propagating CCIDs: + * 1) not useful to propagate CCID settings if this host advertises more + * than one CCID: the choice of CCID may still change - if this is + * the client, or if this is the server and the client sends + * singleton CCID values. + * 2) since is that propagate_ccid changes the list, we defer changing + * the sorted list until after the traversal. + */ + list_for_each_entry(entry, fn, node) + if (entry->feat_num == DCCPF_CCID && entry->val.sp.len == 1) + ccids[entry->is_local] = entry->val.sp.vec[0]; + while (i--) + if (ccids[i] > 0 && dccp_feat_propagate_ccid(fn, ccids[i], i)) + return -1; + dccp_feat_print_fnlist(fn); + return 0; +} + +/** + * dccp_feat_server_ccid_dependencies - Resolve CCID-dependent features + * @dreq: server socket to resolve + * + * It is the server which resolves the dependencies once the CCID has been + * fully negotiated. If no CCID has been negotiated, it uses the default CCID. 
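+ * (The default CCID is CCID-2, cf. the default value in dccp_feat_table.)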
+ */ +int dccp_feat_server_ccid_dependencies(struct dccp_request_sock *dreq) +{ + struct list_head *fn = &dreq->dreq_featneg; + struct dccp_feat_entry *entry; + u8 is_local, ccid; + + for (is_local = 0; is_local <= 1; is_local++) { + entry = dccp_feat_list_lookup(fn, DCCPF_CCID, is_local); + + if (entry != NULL && !entry->empty_confirm) + ccid = entry->val.sp.vec[0]; + else + ccid = dccp_feat_default_value(DCCPF_CCID); + + if (dccp_feat_propagate_ccid(fn, ccid, is_local)) + return -1; + } + return 0; +} + +/* Select the first entry in @servlist that also occurs in @clilist (6.3.1) */ +static int dccp_feat_preflist_match(u8 *servlist, u8 slen, u8 *clilist, u8 clen) +{ + u8 c, s; + + for (s = 0; s < slen; s++) + for (c = 0; c < clen; c++) + if (servlist[s] == clilist[c]) + return servlist[s]; + return -1; +} + +/** + * dccp_feat_prefer - Move preferred entry to the start of array + * @preferred_value: entry to move to start of array + * @array: array of preferred entries + * @array_len: size of the array + * + * Reorder the @array_len elements in @array so that @preferred_value comes + * first. Returns >0 to indicate that @preferred_value does occur in @array. + */ +static u8 dccp_feat_prefer(u8 preferred_value, u8 *array, u8 array_len) +{ + u8 i, does_occur = 0; + + if (array != NULL) { + for (i = 0; i < array_len; i++) + if (array[i] == preferred_value) { + array[i] = array[0]; + does_occur++; + } + if (does_occur) + array[0] = preferred_value; + } + return does_occur; +} + +/** + * dccp_feat_reconcile - Reconcile SP preference lists + * @fv: SP list to reconcile into + * @arr: received SP preference list + * @len: length of @arr in bytes + * @is_server: whether this side is the server (and @fv is the server's list) + * @reorder: whether to reorder the list in @fv after reconciling with @arr + * When successful, > 0 is returned and the reconciled list is in @fval. + * A value of 0 means that negotiation failed (no shared entry). + */ +static int dccp_feat_reconcile(dccp_feat_val *fv, u8 *arr, u8 len, + bool is_server, bool reorder) +{ + int rc; + + if (!fv->sp.vec || !arr) { + DCCP_CRIT("NULL feature value or array"); + return 0; + } + + if (is_server) + rc = dccp_feat_preflist_match(fv->sp.vec, fv->sp.len, arr, len); + else + rc = dccp_feat_preflist_match(arr, len, fv->sp.vec, fv->sp.len); + + if (!reorder) + return rc; + if (rc < 0) + return 0; + + /* + * Reorder list: used for activating features and in dccp_insert_fn_opt. + */ + return dccp_feat_prefer(rc, fv->sp.vec, fv->sp.len); +} + +/** + * dccp_feat_change_recv - Process incoming ChangeL/R options + * @fn: feature-negotiation list to update + * @is_mandatory: whether the Change was preceded by a Mandatory option + * @opt: %DCCPO_CHANGE_L or %DCCPO_CHANGE_R + * @feat: one of %dccp_feature_numbers + * @val: NN value or SP value/preference list + * @len: length of @val in bytes + * @server: whether this node is the server (1) or the client (0) + */ +static u8 dccp_feat_change_recv(struct list_head *fn, u8 is_mandatory, u8 opt, + u8 feat, u8 *val, u8 len, const bool server) +{ + u8 defval, type = dccp_feat_type(feat); + const bool local = (opt == DCCPO_CHANGE_R); + struct dccp_feat_entry *entry; + dccp_feat_val fval; + + if (len == 0 || type == FEAT_UNKNOWN) /* 6.1 and 6.6.8 */ + goto unknown_feature_or_value; + + dccp_feat_print_opt(opt, feat, val, len, is_mandatory); + + /* + * Negotiation of NN features: Change R is invalid, so there is no + * simultaneous negotiation; hence we do not look up in the list. 
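+ * A Change R for an NN feature is therefore treated below like an unknown + * feature/value (RFC 4340, 6.3.2).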
+ */ + if (type == FEAT_NN) { + if (local || len > sizeof(fval.nn)) + goto unknown_feature_or_value; + + /* 6.3.2: "The feature remote MUST accept any valid value..." */ + fval.nn = dccp_decode_value_var(val, len); + if (!dccp_feat_is_valid_nn_val(feat, fval.nn)) + goto unknown_feature_or_value; + + return dccp_feat_push_confirm(fn, feat, local, &fval); + } + + /* + * Unidirectional/simultaneous negotiation of SP features (6.3.1) + */ + entry = dccp_feat_list_lookup(fn, feat, local); + if (entry == NULL) { + /* + * No particular preferences have been registered. We deal with + * this situation by assuming that all valid values are equally + * acceptable, and apply the following checks: + * - if the peer's list is a singleton, we accept a valid value; + * - if we are the server, we first try to see if the peer (the + * client) advertises the default value. If yes, we use it, + * otherwise we accept the preferred value; + * - else if we are the client, we use the first list element. + */ + if (dccp_feat_clone_sp_val(&fval, val, 1)) + return DCCP_RESET_CODE_TOO_BUSY; + + if (len > 1 && server) { + defval = dccp_feat_default_value(feat); + if (dccp_feat_preflist_match(&defval, 1, val, len) > -1) + fval.sp.vec[0] = defval; + } else if (!dccp_feat_is_valid_sp_val(feat, fval.sp.vec[0])) { + kfree(fval.sp.vec); + goto unknown_feature_or_value; + } + + /* Treat unsupported CCIDs like invalid values */ + if (feat == DCCPF_CCID && !ccid_support_check(fval.sp.vec, 1)) { + kfree(fval.sp.vec); + goto not_valid_or_not_known; + } + + return dccp_feat_push_confirm(fn, feat, local, &fval); + + } else if (entry->state == FEAT_UNSTABLE) { /* 6.6.2 */ + return 0; + } + + if (dccp_feat_reconcile(&entry->val, val, len, server, true)) { + entry->empty_confirm = false; + } else if (is_mandatory) { + return DCCP_RESET_CODE_MANDATORY_ERROR; + } else if (entry->state == FEAT_INITIALISING) { + /* + * Failed simultaneous negotiation (server only): try to `save' + * the connection by checking whether entry contains the default + * value for @feat. If yes, send an empty Confirm to signal that + * the received Change was not understood - which implies using + * the default value. + * If this also fails, we use Reset as the last resort. + */ + WARN_ON(!server); + defval = dccp_feat_default_value(feat); + if (!dccp_feat_reconcile(&entry->val, &defval, 1, server, true)) + return DCCP_RESET_CODE_OPTION_ERROR; + entry->empty_confirm = true; + } + entry->needs_confirm = true; + entry->needs_mandatory = false; + entry->state = FEAT_STABLE; + return 0; + +unknown_feature_or_value: + if (!is_mandatory) + return dccp_push_empty_confirm(fn, feat, local); + +not_valid_or_not_known: + return is_mandatory ? 
DCCP_RESET_CODE_MANDATORY_ERROR + : DCCP_RESET_CODE_OPTION_ERROR; +} + +/** + * dccp_feat_confirm_recv - Process received Confirm options + * @fn: feature-negotiation list to update + * @is_mandatory: whether @opt was preceded by a Mandatory option + * @opt: %DCCPO_CONFIRM_L or %DCCPO_CONFIRM_R + * @feat: one of %dccp_feature_numbers + * @val: NN value or SP value/preference list + * @len: length of @val in bytes + * @server: whether this node is server (1) or client (0) + */ +static u8 dccp_feat_confirm_recv(struct list_head *fn, u8 is_mandatory, u8 opt, + u8 feat, u8 *val, u8 len, const bool server) +{ + u8 *plist, plen, type = dccp_feat_type(feat); + const bool local = (opt == DCCPO_CONFIRM_R); + struct dccp_feat_entry *entry = dccp_feat_list_lookup(fn, feat, local); + + dccp_feat_print_opt(opt, feat, val, len, is_mandatory); + + if (entry == NULL) { /* nothing queued: ignore or handle error */ + if (is_mandatory && type == FEAT_UNKNOWN) + return DCCP_RESET_CODE_MANDATORY_ERROR; + + if (!local && type == FEAT_NN) /* 6.3.2 */ + goto confirmation_failed; + return 0; + } + + if (entry->state != FEAT_CHANGING) /* 6.6.2 */ + return 0; + + if (len == 0) { + if (dccp_feat_must_be_understood(feat)) /* 6.6.7 */ + goto confirmation_failed; + /* + * Empty Confirm during connection setup: this means reverting + * to the `old' value, which in this case is the default. Since + * we handle default values automatically when no other values + * have been set, we revert to the old value by removing this + * entry from the list. + */ + dccp_feat_list_pop(entry); + return 0; + } + + if (type == FEAT_NN) { + if (len > sizeof(entry->val.nn)) + goto confirmation_failed; + + if (entry->val.nn == dccp_decode_value_var(val, len)) + goto confirmation_succeeded; + + DCCP_WARN("Bogus Confirm for non-existing value\n"); + goto confirmation_failed; + } + + /* + * Parsing SP Confirms: the first element of @val is the preferred + * SP value which the peer confirms, the remainder depends on @len. + * Note that only the confirmed value need to be a valid SP value. + */ + if (!dccp_feat_is_valid_sp_val(feat, *val)) + goto confirmation_failed; + + if (len == 1) { /* peer didn't supply a preference list */ + plist = val; + plen = len; + } else { /* preferred value + preference list */ + plist = val + 1; + plen = len - 1; + } + + /* Check whether the peer got the reconciliation right (6.6.8) */ + if (dccp_feat_reconcile(&entry->val, plist, plen, server, 0) != *val) { + DCCP_WARN("Confirm selected the wrong value %u\n", *val); + return DCCP_RESET_CODE_OPTION_ERROR; + } + entry->val.sp.vec[0] = *val; + +confirmation_succeeded: + entry->state = FEAT_STABLE; + return 0; + +confirmation_failed: + DCCP_WARN("Confirmation failed\n"); + return is_mandatory ? DCCP_RESET_CODE_MANDATORY_ERROR + : DCCP_RESET_CODE_OPTION_ERROR; +} + +/** + * dccp_feat_handle_nn_established - Fast-path reception of NN options + * @sk: socket of an established DCCP connection + * @mandatory: whether @opt was preceded by a Mandatory option + * @opt: %DCCPO_CHANGE_L | %DCCPO_CONFIRM_R (NN only) + * @feat: NN number, one of %dccp_feature_numbers + * @val: NN value + * @len: length of @val in bytes + * + * This function combines the functionality of change_recv/confirm_recv, with + * the following differences (reset codes are the same): + * - cleanup after receiving the Confirm; + * - values are directly activated after successful parsing; + * - deliberately restricted to NN features. 
+ * The restriction to NN features is essential since SP features can have non- + * predictable outcomes (depending on the remote configuration), and are inter- + * dependent (CCIDs for instance cause further dependencies). + */ +static u8 dccp_feat_handle_nn_established(struct sock *sk, u8 mandatory, u8 opt, + u8 feat, u8 *val, u8 len) +{ + struct list_head *fn = &dccp_sk(sk)->dccps_featneg; + const bool local = (opt == DCCPO_CONFIRM_R); + struct dccp_feat_entry *entry; + u8 type = dccp_feat_type(feat); + dccp_feat_val fval; + + dccp_feat_print_opt(opt, feat, val, len, mandatory); + + /* Ignore non-mandatory unknown and non-NN features */ + if (type == FEAT_UNKNOWN) { + if (local && !mandatory) + return 0; + goto fast_path_unknown; + } else if (type != FEAT_NN) { + return 0; + } + + /* + * We don't accept empty Confirms, since in fast-path feature + * negotiation the values are enabled immediately after sending + * the Change option. + * Empty Changes on the other hand are invalid (RFC 4340, 6.1). + */ + if (len == 0 || len > sizeof(fval.nn)) + goto fast_path_unknown; + + if (opt == DCCPO_CHANGE_L) { + fval.nn = dccp_decode_value_var(val, len); + if (!dccp_feat_is_valid_nn_val(feat, fval.nn)) + goto fast_path_unknown; + + if (dccp_feat_push_confirm(fn, feat, local, &fval) || + dccp_feat_activate(sk, feat, local, &fval)) + return DCCP_RESET_CODE_TOO_BUSY; + + /* set the `Ack Pending' flag to piggyback a Confirm */ + inet_csk_schedule_ack(sk); + + } else if (opt == DCCPO_CONFIRM_R) { + entry = dccp_feat_list_lookup(fn, feat, local); + if (entry == NULL || entry->state != FEAT_CHANGING) + return 0; + + fval.nn = dccp_decode_value_var(val, len); + /* + * Just ignore a value that doesn't match our current value. + * If the option changes twice within two RTTs, then at least + * one CONFIRM will be received for the old value after a + * new CHANGE was sent. + */ + if (fval.nn != entry->val.nn) + return 0; + + /* Only activate after receiving the Confirm option (6.6.1). */ + dccp_feat_activate(sk, feat, local, &fval); + + /* It has been confirmed - so remove the entry */ + dccp_feat_list_pop(entry); + + } else { + DCCP_WARN("Received illegal option %u\n", opt); + goto fast_path_failed; + } + return 0; + +fast_path_unknown: + if (!mandatory) + return dccp_push_empty_confirm(fn, feat, local); + +fast_path_failed: + return mandatory ? DCCP_RESET_CODE_MANDATORY_ERROR + : DCCP_RESET_CODE_OPTION_ERROR; +} + +/** + * dccp_feat_parse_options - Process Feature-Negotiation Options + * @sk: for general use and used by the client during connection setup + * @dreq: used by the server during connection setup + * @mandatory: whether @opt was preceded by a Mandatory option + * @opt: %DCCPO_CHANGE_L | %DCCPO_CHANGE_R | %DCCPO_CONFIRM_L | %DCCPO_CONFIRM_R + * @feat: one of %dccp_feature_numbers + * @val: value contents of @opt + * @len: length of @val in bytes + * + * Returns 0 on success, a Reset code for ending the connection otherwise. + */ +int dccp_feat_parse_options(struct sock *sk, struct dccp_request_sock *dreq, + u8 mandatory, u8 opt, u8 feat, u8 *val, u8 len) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct list_head *fn = dreq ? 
&dreq->dreq_featneg : &dp->dccps_featneg; + bool server = false; + + switch (sk->sk_state) { + /* + * Negotiation during connection setup + */ + case DCCP_LISTEN: + server = true; + fallthrough; + case DCCP_REQUESTING: + switch (opt) { + case DCCPO_CHANGE_L: + case DCCPO_CHANGE_R: + return dccp_feat_change_recv(fn, mandatory, opt, feat, + val, len, server); + case DCCPO_CONFIRM_R: + case DCCPO_CONFIRM_L: + return dccp_feat_confirm_recv(fn, mandatory, opt, feat, + val, len, server); + } + break; + /* + * Support for exchanging NN options on an established connection. + */ + case DCCP_OPEN: + case DCCP_PARTOPEN: + return dccp_feat_handle_nn_established(sk, mandatory, opt, feat, + val, len); + } + return 0; /* ignore FN options in all other states */ +} + +/** + * dccp_feat_init - Seed feature negotiation with host-specific defaults + * @sk: Socket to initialize. + * + * This initialises global defaults, depending on the value of the sysctls. + * These can later be overridden by registering changes via setsockopt calls. + * The last link in the chain is finalise_settings, to make sure that between + * here and the start of actual feature negotiation no inconsistencies enter. + * + * All features not appearing below use either defaults or are otherwise + * later adjusted through dccp_feat_finalise_settings(). + */ +int dccp_feat_init(struct sock *sk) +{ + struct list_head *fn = &dccp_sk(sk)->dccps_featneg; + u8 on = 1, off = 0; + int rc; + struct { + u8 *val; + u8 len; + } tx, rx; + + /* Non-negotiable (NN) features */ + rc = __feat_register_nn(fn, DCCPF_SEQUENCE_WINDOW, 0, + sysctl_dccp_sequence_window); + if (rc) + return rc; + + /* Server-priority (SP) features */ + + /* Advertise that short seqnos are not supported (7.6.1) */ + rc = __feat_register_sp(fn, DCCPF_SHORT_SEQNOS, true, true, &off, 1); + if (rc) + return rc; + + /* RFC 4340 12.1: "If a DCCP is not ECN capable, ..." */ + rc = __feat_register_sp(fn, DCCPF_ECN_INCAPABLE, true, true, &on, 1); + if (rc) + return rc; + + /* + * We advertise the available list of CCIDs and reorder according to + * preferences, to avoid failure resulting from negotiating different + * singleton values (which always leads to failure). + * These settings can still (later) be overridden via sockopts. + */ + if (ccid_get_builtin_ccids(&tx.val, &tx.len)) + return -ENOBUFS; + if (ccid_get_builtin_ccids(&rx.val, &rx.len)) { + kfree(tx.val); + return -ENOBUFS; + } + + if (!dccp_feat_prefer(sysctl_dccp_tx_ccid, tx.val, tx.len) || + !dccp_feat_prefer(sysctl_dccp_rx_ccid, rx.val, rx.len)) + goto free_ccid_lists; + + rc = __feat_register_sp(fn, DCCPF_CCID, true, false, tx.val, tx.len); + if (rc) + goto free_ccid_lists; + + rc = __feat_register_sp(fn, DCCPF_CCID, false, false, rx.val, rx.len); + +free_ccid_lists: + kfree(tx.val); + kfree(rx.val); + return rc; +} + +int dccp_feat_activate_values(struct sock *sk, struct list_head *fn_list) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_feat_entry *cur, *next; + int idx; + dccp_feat_val *fvals[DCCP_FEAT_SUPPORTED_MAX][2] = { + [0 ... DCCP_FEAT_SUPPORTED_MAX-1] = { NULL, NULL } + }; + + list_for_each_entry(cur, fn_list, node) { + /* + * An empty Confirm means that either an unknown feature type + * or an invalid value was present. In the first case there is + * nothing to activate, in the other the default value is used. 
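+ * (The default is supplied by __dccp_feat_activate() further below, which is + * then called with a NULL @fval for this feature.)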
+ */ + if (cur->empty_confirm) + continue; + + idx = dccp_feat_index(cur->feat_num); + if (idx < 0) { + DCCP_BUG("Unknown feature %u", cur->feat_num); + goto activation_failed; + } + if (cur->state != FEAT_STABLE) { + DCCP_CRIT("Negotiation of %s %s failed in state %s", + cur->is_local ? "local" : "remote", + dccp_feat_fname(cur->feat_num), + dccp_feat_sname[cur->state]); + goto activation_failed; + } + fvals[idx][cur->is_local] = &cur->val; + } + + /* + * Activate in decreasing order of index, so that the CCIDs are always + * activated as the last feature. This avoids the case where a CCID + * relies on the initialisation of one or more features that it depends + * on (e.g. Send NDP Count, Send Ack Vector, and Ack Ratio features). + */ + for (idx = DCCP_FEAT_SUPPORTED_MAX; --idx >= 0;) + if (__dccp_feat_activate(sk, idx, 0, fvals[idx][0]) || + __dccp_feat_activate(sk, idx, 1, fvals[idx][1])) { + DCCP_CRIT("Could not activate %d", idx); + goto activation_failed; + } + + /* Clean up Change options which have been confirmed already */ + list_for_each_entry_safe(cur, next, fn_list, node) + if (!cur->needs_confirm) + dccp_feat_list_pop(cur); + + dccp_pr_debug("Activation OK\n"); + return 0; + +activation_failed: + /* + * We clean up everything that may have been allocated, since + * it is difficult to track at which stage negotiation failed. + * This is ok, since all allocation functions below are robust + * against NULL arguments. + */ + ccid_hc_rx_delete(dp->dccps_hc_rx_ccid, sk); + ccid_hc_tx_delete(dp->dccps_hc_tx_ccid, sk); + dp->dccps_hc_rx_ccid = dp->dccps_hc_tx_ccid = NULL; + dccp_ackvec_free(dp->dccps_hc_rx_ackvec); + dp->dccps_hc_rx_ackvec = NULL; + return -1; +} diff --git a/net/dccp/feat.h b/net/dccp/feat.h new file mode 100644 index 000000000..d76c9be5b --- /dev/null +++ b/net/dccp/feat.h @@ -0,0 +1,134 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _DCCP_FEAT_H +#define _DCCP_FEAT_H +/* + * net/dccp/feat.h + * + * Feature negotiation for the DCCP protocol (RFC 4340, section 6) + * Copyright (c) 2008 Gerrit Renker + * Copyright (c) 2005 Andrea Bittau + */ +#include +#include "dccp.h" + +/* + * Known limit values + */ +/* Ack Ratio takes 2-byte integer values (11.3) */ +#define DCCPF_ACK_RATIO_MAX 0xFFFF +/* Wmin=32 and Wmax=2^46-1 from 7.5.2 */ +#define DCCPF_SEQ_WMIN 32 +#define DCCPF_SEQ_WMAX 0x3FFFFFFFFFFFull +/* Maximum number of SP values that fit in a single (Confirm) option */ +#define DCCP_FEAT_MAX_SP_VALS (DCCP_SINGLE_OPT_MAXLEN - 2) + +enum dccp_feat_type { + FEAT_AT_RX = 1, /* located at RX side of half-connection */ + FEAT_AT_TX = 2, /* located at TX side of half-connection */ + FEAT_SP = 4, /* server-priority reconciliation (6.3.1) */ + FEAT_NN = 8, /* non-negotiable reconciliation (6.3.2) */ + FEAT_UNKNOWN = 0xFF /* not understood or invalid feature */ +}; + +enum dccp_feat_state { + FEAT_DEFAULT = 0, /* using default values from 6.4 */ + FEAT_INITIALISING, /* feature is being initialised */ + FEAT_CHANGING, /* Change sent but not confirmed yet */ + FEAT_UNSTABLE, /* local modification in state CHANGING */ + FEAT_STABLE /* both ends (think they) agree */ +}; + +/** + * dccp_feat_val - Container for SP or NN feature values + * @nn: single NN value + * @sp.vec: single SP value plus optional preference list + * @sp.len: length of @sp.vec in bytes + */ +typedef union { + u64 nn; + struct { + u8 *vec; + u8 len; + } sp; +} dccp_feat_val; + +/** + * struct feat_entry - Data structure to perform feature negotiation + * @val: feature's current value (SP features may 
have preference list) + * @state: feature's current state + * @feat_num: one of %dccp_feature_numbers + * @needs_mandatory: whether Mandatory options should be sent + * @needs_confirm: whether to send a Confirm instead of a Change + * @empty_confirm: whether to send an empty Confirm (depends on @needs_confirm) + * @is_local: feature location (1) or feature-remote (0) + * @node: list pointers, entries arranged in FIFO order + */ +struct dccp_feat_entry { + dccp_feat_val val; + enum dccp_feat_state state:8; + u8 feat_num; + + bool needs_mandatory, + needs_confirm, + empty_confirm, + is_local; + + struct list_head node; +}; + +static inline u8 dccp_feat_genopt(struct dccp_feat_entry *entry) +{ + if (entry->needs_confirm) + return entry->is_local ? DCCPO_CONFIRM_L : DCCPO_CONFIRM_R; + return entry->is_local ? DCCPO_CHANGE_L : DCCPO_CHANGE_R; +} + +/** + * struct ccid_dependency - Track changes resulting from choosing a CCID + * @dependent_feat: one of %dccp_feature_numbers + * @is_local: local (1) or remote (0) @dependent_feat + * @is_mandatory: whether presence of @dependent_feat is mission-critical or not + * @val: corresponding default value for @dependent_feat (u8 is sufficient here) + */ +struct ccid_dependency { + u8 dependent_feat; + bool is_local:1, + is_mandatory:1; + u8 val; +}; + +/* + * Sysctls to seed defaults for feature negotiation + */ +extern unsigned long sysctl_dccp_sequence_window; +extern int sysctl_dccp_rx_ccid; +extern int sysctl_dccp_tx_ccid; + +int dccp_feat_init(struct sock *sk); +void dccp_feat_initialise_sysctls(void); +int dccp_feat_register_sp(struct sock *sk, u8 feat, u8 is_local, + u8 const *list, u8 len); +int dccp_feat_parse_options(struct sock *, struct dccp_request_sock *, + u8 mand, u8 opt, u8 feat, u8 *val, u8 len); +int dccp_feat_clone_list(struct list_head const *, struct list_head *); + +/* + * Encoding variable-length options and their maximum length. + * + * This affects NN options (SP options are all u8) and other variable-length + * options (see table 3 in RFC 4340). The limit is currently given the Sequence + * Window NN value (sec. 7.5.2) and the NDP count (sec. 7.7) option, all other + * options consume less than 6 bytes (timestamps are 4 bytes). + * When updating this constant (e.g. due to new internet drafts / RFCs), make + * sure that you also update all code which refers to it. 
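+ * Example: the largest Sequence Window value, 2^46 - 1 (DCCPF_SEQ_WMAX), + * occupies all 6 bytes when encoded in network byte order by + * dccp_encode_value_var().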
+ */ +#define DCCP_OPTVAL_MAXLEN 6 + +void dccp_encode_value_var(const u64 value, u8 *to, const u8 len); +u64 dccp_decode_value_var(const u8 *bf, const u8 len); +u64 dccp_feat_nn_get(struct sock *sk, u8 feat); + +int dccp_insert_option_mandatory(struct sk_buff *skb); +int dccp_insert_fn_opt(struct sk_buff *skb, u8 type, u8 feat, u8 *val, u8 len, + bool repeat_first); +#endif /* _DCCP_FEAT_H */ diff --git a/net/dccp/input.c b/net/dccp/input.c new file mode 100644 index 000000000..2cbb757a8 --- /dev/null +++ b/net/dccp/input.c @@ -0,0 +1,739 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/input.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include + +#include + +#include "ackvec.h" +#include "ccid.h" +#include "dccp.h" + +/* rate-limit for syncs in reply to sequence-invalid packets; RFC 4340, 7.5.4 */ +int sysctl_dccp_sync_ratelimit __read_mostly = HZ / 8; + +static void dccp_enqueue_skb(struct sock *sk, struct sk_buff *skb) +{ + __skb_pull(skb, dccp_hdr(skb)->dccph_doff * 4); + __skb_queue_tail(&sk->sk_receive_queue, skb); + skb_set_owner_r(skb, sk); + sk->sk_data_ready(sk); +} + +static void dccp_fin(struct sock *sk, struct sk_buff *skb) +{ + /* + * On receiving Close/CloseReq, both RD/WR shutdown are performed. + * RFC 4340, 8.3 says that we MAY send further Data/DataAcks after + * receiving the closing segment, but there is no guarantee that such + * data will be processed at all. + */ + sk->sk_shutdown = SHUTDOWN_MASK; + sock_set_flag(sk, SOCK_DONE); + dccp_enqueue_skb(sk, skb); +} + +static int dccp_rcv_close(struct sock *sk, struct sk_buff *skb) +{ + int queued = 0; + + switch (sk->sk_state) { + /* + * We ignore Close when received in one of the following states: + * - CLOSED (may be a late or duplicate packet) + * - PASSIVE_CLOSEREQ (the peer has sent a CloseReq earlier) + * - RESPOND (already handled by dccp_check_req) + */ + case DCCP_CLOSING: + /* + * Simultaneous-close: receiving a Close after sending one. This + * can happen if both client and server perform active-close and + * will result in an endless ping-pong of crossing and retrans- + * mitted Close packets, which only terminates when one of the + * nodes times out (min. 64 seconds). Quicker convergence can be + * achieved when one of the nodes acts as tie-breaker. + * This is ok as both ends are done with data transfer and each + * end is just waiting for the other to acknowledge termination. + */ + if (dccp_sk(sk)->dccps_role != DCCP_ROLE_CLIENT) + break; + fallthrough; + case DCCP_REQUESTING: + case DCCP_ACTIVE_CLOSEREQ: + dccp_send_reset(sk, DCCP_RESET_CODE_CLOSED); + dccp_done(sk); + break; + case DCCP_OPEN: + case DCCP_PARTOPEN: + /* Give waiting application a chance to read pending data */ + queued = 1; + dccp_fin(sk, skb); + dccp_set_state(sk, DCCP_PASSIVE_CLOSE); + fallthrough; + case DCCP_PASSIVE_CLOSE: + /* + * Retransmitted Close: we have already enqueued the first one. 
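+ * For such a duplicate, `queued' stays 0 and the caller frees the skb; we + * merely notify the waiting process again.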
+ */ + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP); + } + return queued; +} + +static int dccp_rcv_closereq(struct sock *sk, struct sk_buff *skb) +{ + int queued = 0; + + /* + * Step 7: Check for unexpected packet types + * If (S.is_server and P.type == CloseReq) + * Send Sync packet acknowledging P.seqno + * Drop packet and return + */ + if (dccp_sk(sk)->dccps_role != DCCP_ROLE_CLIENT) { + dccp_send_sync(sk, DCCP_SKB_CB(skb)->dccpd_seq, DCCP_PKT_SYNC); + return queued; + } + + /* Step 13: process relevant Client states < CLOSEREQ */ + switch (sk->sk_state) { + case DCCP_REQUESTING: + dccp_send_close(sk, 0); + dccp_set_state(sk, DCCP_CLOSING); + break; + case DCCP_OPEN: + case DCCP_PARTOPEN: + /* Give waiting application a chance to read pending data */ + queued = 1; + dccp_fin(sk, skb); + dccp_set_state(sk, DCCP_PASSIVE_CLOSEREQ); + fallthrough; + case DCCP_PASSIVE_CLOSEREQ: + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP); + } + return queued; +} + +static u16 dccp_reset_code_convert(const u8 code) +{ + static const u16 error_code[] = { + [DCCP_RESET_CODE_CLOSED] = 0, /* normal termination */ + [DCCP_RESET_CODE_UNSPECIFIED] = 0, /* nothing known */ + [DCCP_RESET_CODE_ABORTED] = ECONNRESET, + + [DCCP_RESET_CODE_NO_CONNECTION] = ECONNREFUSED, + [DCCP_RESET_CODE_CONNECTION_REFUSED] = ECONNREFUSED, + [DCCP_RESET_CODE_TOO_BUSY] = EUSERS, + [DCCP_RESET_CODE_AGGRESSION_PENALTY] = EDQUOT, + + [DCCP_RESET_CODE_PACKET_ERROR] = ENOMSG, + [DCCP_RESET_CODE_BAD_INIT_COOKIE] = EBADR, + [DCCP_RESET_CODE_BAD_SERVICE_CODE] = EBADRQC, + [DCCP_RESET_CODE_OPTION_ERROR] = EILSEQ, + [DCCP_RESET_CODE_MANDATORY_ERROR] = EOPNOTSUPP, + }; + + return code >= DCCP_MAX_RESET_CODES ? 0 : error_code[code]; +} + +static void dccp_rcv_reset(struct sock *sk, struct sk_buff *skb) +{ + u16 err = dccp_reset_code_convert(dccp_hdr_reset(skb)->dccph_reset_code); + + sk->sk_err = err; + + /* Queue the equivalent of TCP fin so that dccp_recvmsg exits the loop */ + dccp_fin(sk, skb); + + if (err && !sock_flag(sk, SOCK_DEAD)) + sk_wake_async(sk, SOCK_WAKE_IO, POLL_ERR); + dccp_time_wait(sk, DCCP_TIME_WAIT, 0); +} + +static void dccp_handle_ackvec_processing(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_ackvec *av = dccp_sk(sk)->dccps_hc_rx_ackvec; + + if (av == NULL) + return; + if (DCCP_SKB_CB(skb)->dccpd_ack_seq != DCCP_PKT_WITHOUT_ACK_SEQ) + dccp_ackvec_clear_state(av, DCCP_SKB_CB(skb)->dccpd_ack_seq); + dccp_ackvec_input(av, skb); +} + +static void dccp_deliver_input_to_ccids(struct sock *sk, struct sk_buff *skb) +{ + const struct dccp_sock *dp = dccp_sk(sk); + + /* Don't deliver to RX CCID when node has shut down read end. */ + if (!(sk->sk_shutdown & RCV_SHUTDOWN)) + ccid_hc_rx_packet_recv(dp->dccps_hc_rx_ccid, sk, skb); + /* + * Until the TX queue has been drained, we can not honour SHUT_WR, since + * we need received feedback as input to adjust congestion control. + */ + if (sk->sk_write_queue.qlen > 0 || !(sk->sk_shutdown & SEND_SHUTDOWN)) + ccid_hc_tx_packet_recv(dp->dccps_hc_tx_ccid, sk, skb); +} + +static int dccp_check_seqno(struct sock *sk, struct sk_buff *skb) +{ + const struct dccp_hdr *dh = dccp_hdr(skb); + struct dccp_sock *dp = dccp_sk(sk); + u64 lswl, lawl, seqno = DCCP_SKB_CB(skb)->dccpd_seq, + ackno = DCCP_SKB_CB(skb)->dccpd_ack_seq; + + /* + * Step 5: Prepare sequence numbers for Sync + * If P.type == Sync or P.type == SyncAck, + * If S.AWL <= P.ackno <= S.AWH and P.seqno >= S.SWL, + * / * P is valid, so update sequence number variables + * accordingly. 
After this update, P will pass the tests + * in Step 6. A SyncAck is generated if necessary in + * Step 15 * / + * Update S.GSR, S.SWL, S.SWH + * Otherwise, + * Drop packet and return + */ + if (dh->dccph_type == DCCP_PKT_SYNC || + dh->dccph_type == DCCP_PKT_SYNCACK) { + if (between48(ackno, dp->dccps_awl, dp->dccps_awh) && + dccp_delta_seqno(dp->dccps_swl, seqno) >= 0) + dccp_update_gsr(sk, seqno); + else + return -1; + } + + /* + * Step 6: Check sequence numbers + * Let LSWL = S.SWL and LAWL = S.AWL + * If P.type == CloseReq or P.type == Close or P.type == Reset, + * LSWL := S.GSR + 1, LAWL := S.GAR + * If LSWL <= P.seqno <= S.SWH + * and (P.ackno does not exist or LAWL <= P.ackno <= S.AWH), + * Update S.GSR, S.SWL, S.SWH + * If P.type != Sync, + * Update S.GAR + */ + lswl = dp->dccps_swl; + lawl = dp->dccps_awl; + + if (dh->dccph_type == DCCP_PKT_CLOSEREQ || + dh->dccph_type == DCCP_PKT_CLOSE || + dh->dccph_type == DCCP_PKT_RESET) { + lswl = ADD48(dp->dccps_gsr, 1); + lawl = dp->dccps_gar; + } + + if (between48(seqno, lswl, dp->dccps_swh) && + (ackno == DCCP_PKT_WITHOUT_ACK_SEQ || + between48(ackno, lawl, dp->dccps_awh))) { + dccp_update_gsr(sk, seqno); + + if (dh->dccph_type != DCCP_PKT_SYNC && + ackno != DCCP_PKT_WITHOUT_ACK_SEQ && + after48(ackno, dp->dccps_gar)) + dp->dccps_gar = ackno; + } else { + unsigned long now = jiffies; + /* + * Step 6: Check sequence numbers + * Otherwise, + * If P.type == Reset, + * Send Sync packet acknowledging S.GSR + * Otherwise, + * Send Sync packet acknowledging P.seqno + * Drop packet and return + * + * These Syncs are rate-limited as per RFC 4340, 7.5.4: + * at most 1 / (dccp_sync_rate_limit * HZ) Syncs per second. + */ + if (time_before(now, (dp->dccps_rate_last + + sysctl_dccp_sync_ratelimit))) + return -1; + + DCCP_WARN("Step 6 failed for %s packet, " + "(LSWL(%llu) <= P.seqno(%llu) <= S.SWH(%llu)) and " + "(P.ackno %s or LAWL(%llu) <= P.ackno(%llu) <= S.AWH(%llu), " + "sending SYNC...\n", dccp_packet_name(dh->dccph_type), + (unsigned long long) lswl, (unsigned long long) seqno, + (unsigned long long) dp->dccps_swh, + (ackno == DCCP_PKT_WITHOUT_ACK_SEQ) ? 
"doesn't exist" + : "exists", + (unsigned long long) lawl, (unsigned long long) ackno, + (unsigned long long) dp->dccps_awh); + + dp->dccps_rate_last = now; + + if (dh->dccph_type == DCCP_PKT_RESET) + seqno = dp->dccps_gsr; + dccp_send_sync(sk, seqno, DCCP_PKT_SYNC); + return -1; + } + + return 0; +} + +static int __dccp_rcv_established(struct sock *sk, struct sk_buff *skb, + const struct dccp_hdr *dh, const unsigned int len) +{ + struct dccp_sock *dp = dccp_sk(sk); + + switch (dccp_hdr(skb)->dccph_type) { + case DCCP_PKT_DATAACK: + case DCCP_PKT_DATA: + /* + * FIXME: schedule DATA_DROPPED (RFC 4340, 11.7.2) if and when + * - sk_shutdown == RCV_SHUTDOWN, use Code 1, "Not Listening" + * - sk_receive_queue is full, use Code 2, "Receive Buffer" + */ + dccp_enqueue_skb(sk, skb); + return 0; + case DCCP_PKT_ACK: + goto discard; + case DCCP_PKT_RESET: + /* + * Step 9: Process Reset + * If P.type == Reset, + * Tear down connection + * S.state := TIMEWAIT + * Set TIMEWAIT timer + * Drop packet and return + */ + dccp_rcv_reset(sk, skb); + return 0; + case DCCP_PKT_CLOSEREQ: + if (dccp_rcv_closereq(sk, skb)) + return 0; + goto discard; + case DCCP_PKT_CLOSE: + if (dccp_rcv_close(sk, skb)) + return 0; + goto discard; + case DCCP_PKT_REQUEST: + /* Step 7 + * or (S.is_server and P.type == Response) + * or (S.is_client and P.type == Request) + * or (S.state >= OPEN and P.type == Request + * and P.seqno >= S.OSR) + * or (S.state >= OPEN and P.type == Response + * and P.seqno >= S.OSR) + * or (S.state == RESPOND and P.type == Data), + * Send Sync packet acknowledging P.seqno + * Drop packet and return + */ + if (dp->dccps_role != DCCP_ROLE_LISTEN) + goto send_sync; + goto check_seq; + case DCCP_PKT_RESPONSE: + if (dp->dccps_role != DCCP_ROLE_CLIENT) + goto send_sync; +check_seq: + if (dccp_delta_seqno(dp->dccps_osr, + DCCP_SKB_CB(skb)->dccpd_seq) >= 0) { +send_sync: + dccp_send_sync(sk, DCCP_SKB_CB(skb)->dccpd_seq, + DCCP_PKT_SYNC); + } + break; + case DCCP_PKT_SYNC: + dccp_send_sync(sk, DCCP_SKB_CB(skb)->dccpd_seq, + DCCP_PKT_SYNCACK); + /* + * From RFC 4340, sec. 5.7 + * + * As with DCCP-Ack packets, DCCP-Sync and DCCP-SyncAck packets + * MAY have non-zero-length application data areas, whose + * contents receivers MUST ignore. 
+ */ + goto discard; + } + + DCCP_INC_STATS(DCCP_MIB_INERRS); +discard: + __kfree_skb(skb); + return 0; +} + +int dccp_rcv_established(struct sock *sk, struct sk_buff *skb, + const struct dccp_hdr *dh, const unsigned int len) +{ + if (dccp_check_seqno(sk, skb)) + goto discard; + + if (dccp_parse_options(sk, NULL, skb)) + return 1; + + dccp_handle_ackvec_processing(sk, skb); + dccp_deliver_input_to_ccids(sk, skb); + + return __dccp_rcv_established(sk, skb, dh, len); +discard: + __kfree_skb(skb); + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_rcv_established); + +static int dccp_rcv_request_sent_state_process(struct sock *sk, + struct sk_buff *skb, + const struct dccp_hdr *dh, + const unsigned int len) +{ + /* + * Step 4: Prepare sequence numbers in REQUEST + * If S.state == REQUEST, + * If (P.type == Response or P.type == Reset) + * and S.AWL <= P.ackno <= S.AWH, + * / * Set sequence number variables corresponding to the + * other endpoint, so P will pass the tests in Step 6 * / + * Set S.GSR, S.ISR, S.SWL, S.SWH + * / * Response processing continues in Step 10; Reset + * processing continues in Step 9 * / + */ + if (dh->dccph_type == DCCP_PKT_RESPONSE) { + const struct inet_connection_sock *icsk = inet_csk(sk); + struct dccp_sock *dp = dccp_sk(sk); + long tstamp = dccp_timestamp(); + + if (!between48(DCCP_SKB_CB(skb)->dccpd_ack_seq, + dp->dccps_awl, dp->dccps_awh)) { + dccp_pr_debug("invalid ackno: S.AWL=%llu, " + "P.ackno=%llu, S.AWH=%llu\n", + (unsigned long long)dp->dccps_awl, + (unsigned long long)DCCP_SKB_CB(skb)->dccpd_ack_seq, + (unsigned long long)dp->dccps_awh); + goto out_invalid_packet; + } + + /* + * If option processing (Step 8) failed, return 1 here so that + * dccp_v4_do_rcv() sends a Reset. The Reset code depends on + * the option type and is set in dccp_parse_options(). + */ + if (dccp_parse_options(sk, NULL, skb)) + return 1; + + /* Obtain usec RTT sample from SYN exchange (used by TFRC). */ + if (likely(dp->dccps_options_received.dccpor_timestamp_echo)) + dp->dccps_syn_rtt = dccp_sample_rtt(sk, 10 * (tstamp - + dp->dccps_options_received.dccpor_timestamp_echo)); + + /* Stop the REQUEST timer */ + inet_csk_clear_xmit_timer(sk, ICSK_TIME_RETRANS); + WARN_ON(sk->sk_send_head == NULL); + kfree_skb(sk->sk_send_head); + sk->sk_send_head = NULL; + + /* + * Set ISR, GSR from packet. ISS was set in dccp_v{4,6}_connect + * and GSS in dccp_transmit_skb(). Setting AWL/AWH and SWL/SWH + * is done as part of activating the feature values below, since + * these settings depend on the local/remote Sequence Window + * features, which were undefined or not confirmed until now. + */ + dp->dccps_gsr = dp->dccps_isr = DCCP_SKB_CB(skb)->dccpd_seq; + + dccp_sync_mss(sk, icsk->icsk_pmtu_cookie); + + /* + * Step 10: Process REQUEST state (second part) + * If S.state == REQUEST, + * / * If we get here, P is a valid Response from the + * server (see Step 4), and we should move to + * PARTOPEN state. PARTOPEN means send an Ack, + * don't send Data packets, retransmit Acks + * periodically, and always include any Init Cookie + * from the Response * / + * S.state := PARTOPEN + * Set PARTOPEN timer + * Continue with S.state == PARTOPEN + * / * Step 12 will send the Ack completing the + * three-way handshake * / + */ + dccp_set_state(sk, DCCP_PARTOPEN); + + /* + * If feature negotiation was successful, activate features now; + * an activation failure means that this host could not activate + * one ore more features (e.g. insufficient memory), which would + * leave at least one feature in an undefined state. 
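+ * (In that case, unable_to_proceed below closes the socket and reports the + * failure with Reset code DCCP_RESET_CODE_ABORTED.)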
+ */ + if (dccp_feat_activate_values(sk, &dp->dccps_featneg)) + goto unable_to_proceed; + + /* Make sure socket is routed, for correct metrics. */ + icsk->icsk_af_ops->rebuild_header(sk); + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sk_wake_async(sk, SOCK_WAKE_IO, POLL_OUT); + } + + if (sk->sk_write_pending || inet_csk_in_pingpong_mode(sk) || + icsk->icsk_accept_queue.rskq_defer_accept) { + /* Save one ACK. Data will be ready after + * several ticks, if write_pending is set. + * + * It may be deleted, but with this feature tcpdumps + * look so _wonderfully_ clever, that I was not able + * to stand against the temptation 8) --ANK + */ + /* + * OK, in DCCP we can as well do a similar trick, its + * even in the draft, but there is no need for us to + * schedule an ack here, as dccp_sendmsg does this for + * us, also stated in the draft. -acme + */ + __kfree_skb(skb); + return 0; + } + dccp_send_ack(sk); + return -1; + } + +out_invalid_packet: + /* dccp_v4_do_rcv will send a reset */ + DCCP_SKB_CB(skb)->dccpd_reset_code = DCCP_RESET_CODE_PACKET_ERROR; + return 1; + +unable_to_proceed: + DCCP_SKB_CB(skb)->dccpd_reset_code = DCCP_RESET_CODE_ABORTED; + /* + * We mark this socket as no longer usable, so that the loop in + * dccp_sendmsg() terminates and the application gets notified. + */ + dccp_set_state(sk, DCCP_CLOSED); + sk->sk_err = ECOMM; + return 1; +} + +static int dccp_rcv_respond_partopen_state_process(struct sock *sk, + struct sk_buff *skb, + const struct dccp_hdr *dh, + const unsigned int len) +{ + struct dccp_sock *dp = dccp_sk(sk); + u32 sample = dp->dccps_options_received.dccpor_timestamp_echo; + int queued = 0; + + switch (dh->dccph_type) { + case DCCP_PKT_RESET: + inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); + break; + case DCCP_PKT_DATA: + if (sk->sk_state == DCCP_RESPOND) + break; + fallthrough; + case DCCP_PKT_DATAACK: + case DCCP_PKT_ACK: + /* + * FIXME: we should be resetting the PARTOPEN (DELACK) timer + * here but only if we haven't used the DELACK timer for + * something else, like sending a delayed ack for a TIMESTAMP + * echo, etc, for now were not clearing it, sending an extra + * ACK when there is nothing else to do in DELACK is not a big + * deal after all. + */ + + /* Stop the PARTOPEN timer */ + if (sk->sk_state == DCCP_PARTOPEN) + inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); + + /* Obtain usec RTT sample from SYN exchange (used by TFRC). */ + if (likely(sample)) { + long delta = dccp_timestamp() - sample; + + dp->dccps_syn_rtt = dccp_sample_rtt(sk, 10 * delta); + } + + dp->dccps_osr = DCCP_SKB_CB(skb)->dccpd_seq; + dccp_set_state(sk, DCCP_OPEN); + + if (dh->dccph_type == DCCP_PKT_DATAACK || + dh->dccph_type == DCCP_PKT_DATA) { + __dccp_rcv_established(sk, skb, dh, len); + queued = 1; /* packet was queued + (by __dccp_rcv_established) */ + } + break; + } + + return queued; +} + +int dccp_rcv_state_process(struct sock *sk, struct sk_buff *skb, + struct dccp_hdr *dh, unsigned int len) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_skb_cb *dcb = DCCP_SKB_CB(skb); + const int old_state = sk->sk_state; + bool acceptable; + int queued = 0; + + /* + * Step 3: Process LISTEN state + * + * If S.state == LISTEN, + * If P.type == Request or P contains a valid Init Cookie option, + * (* Must scan the packet's options to check for Init + * Cookies. Only Init Cookies are processed here, + * however; other options are processed in Step 8. 
This + * scan need only be performed if the endpoint uses Init + * Cookies *) + * (* Generate a new socket and switch to that socket *) + * Set S := new socket for this port pair + * S.state = RESPOND + * Choose S.ISS (initial seqno) or set from Init Cookies + * Initialize S.GAR := S.ISS + * Set S.ISR, S.GSR, S.SWL, S.SWH from packet or Init + * Cookies Continue with S.state == RESPOND + * (* A Response packet will be generated in Step 11 *) + * Otherwise, + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + */ + if (sk->sk_state == DCCP_LISTEN) { + if (dh->dccph_type == DCCP_PKT_REQUEST) { + /* It is possible that we process SYN packets from backlog, + * so we need to make sure to disable BH and RCU right there. + */ + rcu_read_lock(); + local_bh_disable(); + acceptable = inet_csk(sk)->icsk_af_ops->conn_request(sk, skb) >= 0; + local_bh_enable(); + rcu_read_unlock(); + if (!acceptable) + return 1; + consume_skb(skb); + return 0; + } + if (dh->dccph_type == DCCP_PKT_RESET) + goto discard; + + /* Caller (dccp_v4_do_rcv) will send Reset */ + dcb->dccpd_reset_code = DCCP_RESET_CODE_NO_CONNECTION; + return 1; + } else if (sk->sk_state == DCCP_CLOSED) { + dcb->dccpd_reset_code = DCCP_RESET_CODE_NO_CONNECTION; + return 1; + } + + /* Step 6: Check sequence numbers (omitted in LISTEN/REQUEST state) */ + if (sk->sk_state != DCCP_REQUESTING && dccp_check_seqno(sk, skb)) + goto discard; + + /* + * Step 7: Check for unexpected packet types + * If (S.is_server and P.type == Response) + * or (S.is_client and P.type == Request) + * or (S.state == RESPOND and P.type == Data), + * Send Sync packet acknowledging P.seqno + * Drop packet and return + */ + if ((dp->dccps_role != DCCP_ROLE_CLIENT && + dh->dccph_type == DCCP_PKT_RESPONSE) || + (dp->dccps_role == DCCP_ROLE_CLIENT && + dh->dccph_type == DCCP_PKT_REQUEST) || + (sk->sk_state == DCCP_RESPOND && dh->dccph_type == DCCP_PKT_DATA)) { + dccp_send_sync(sk, dcb->dccpd_seq, DCCP_PKT_SYNC); + goto discard; + } + + /* Step 8: Process options */ + if (dccp_parse_options(sk, NULL, skb)) + return 1; + + /* + * Step 9: Process Reset + * If P.type == Reset, + * Tear down connection + * S.state := TIMEWAIT + * Set TIMEWAIT timer + * Drop packet and return + */ + if (dh->dccph_type == DCCP_PKT_RESET) { + dccp_rcv_reset(sk, skb); + return 0; + } else if (dh->dccph_type == DCCP_PKT_CLOSEREQ) { /* Step 13 */ + if (dccp_rcv_closereq(sk, skb)) + return 0; + goto discard; + } else if (dh->dccph_type == DCCP_PKT_CLOSE) { /* Step 14 */ + if (dccp_rcv_close(sk, skb)) + return 0; + goto discard; + } + + switch (sk->sk_state) { + case DCCP_REQUESTING: + queued = dccp_rcv_request_sent_state_process(sk, skb, dh, len); + if (queued >= 0) + return queued; + + __kfree_skb(skb); + return 0; + + case DCCP_PARTOPEN: + /* Step 8: if using Ack Vectors, mark packet acknowledgeable */ + dccp_handle_ackvec_processing(sk, skb); + dccp_deliver_input_to_ccids(sk, skb); + fallthrough; + case DCCP_RESPOND: + queued = dccp_rcv_respond_partopen_state_process(sk, skb, + dh, len); + break; + } + + if (dh->dccph_type == DCCP_PKT_ACK || + dh->dccph_type == DCCP_PKT_DATAACK) { + switch (old_state) { + case DCCP_PARTOPEN: + sk->sk_state_change(sk); + sk_wake_async(sk, SOCK_WAKE_IO, POLL_OUT); + break; + } + } else if (unlikely(dh->dccph_type == DCCP_PKT_SYNC)) { + dccp_send_sync(sk, dcb->dccpd_seq, DCCP_PKT_SYNCACK); + goto discard; + } + + if (!queued) { +discard: + __kfree_skb(skb); + } + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_rcv_state_process); + +/** + * 
dccp_sample_rtt - Validate and finalise computation of RTT sample + * @sk: socket structure + * @delta: number of microseconds between packet and acknowledgment + * + * The routine is kept generic to work in different contexts. It should be + * called immediately when the ACK used for the RTT sample arrives. + */ +u32 dccp_sample_rtt(struct sock *sk, long delta) +{ + /* dccpor_elapsed_time is either zeroed out or set and > 0 */ + delta -= dccp_sk(sk)->dccps_options_received.dccpor_elapsed_time * 10; + + if (unlikely(delta <= 0)) { + DCCP_WARN("unusable RTT sample %ld, using min\n", delta); + return DCCP_SANE_RTT_MIN; + } + if (unlikely(delta > DCCP_SANE_RTT_MAX)) { + DCCP_WARN("RTT sample %ld too large, using max\n", delta); + return DCCP_SANE_RTT_MAX; + } + + return delta; +} diff --git a/net/dccp/ipv4.c b/net/dccp/ipv4.c new file mode 100644 index 000000000..9fe6d9679 --- /dev/null +++ b/net/dccp/ipv4.c @@ -0,0 +1,1100 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/ipv4.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ackvec.h" +#include "ccid.h" +#include "dccp.h" +#include "feat.h" + +struct dccp_v4_pernet { + struct sock *v4_ctl_sk; +}; + +static unsigned int dccp_v4_pernet_id __read_mostly; + +/* + * The per-net v4_ctl_sk socket is used for responding to + * the Out-of-the-blue (OOTB) packets. A control sock will be created + * for this socket at the initialization time. + */ + +int dccp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + const struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; + struct inet_sock *inet = inet_sk(sk); + struct dccp_sock *dp = dccp_sk(sk); + __be16 orig_sport, orig_dport; + __be32 daddr, nexthop; + struct flowi4 *fl4; + struct rtable *rt; + int err; + struct ip_options_rcu *inet_opt; + + dp->dccps_role = DCCP_ROLE_CLIENT; + + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + if (usin->sin_family != AF_INET) + return -EAFNOSUPPORT; + + nexthop = daddr = usin->sin_addr.s_addr; + + inet_opt = rcu_dereference_protected(inet->inet_opt, + lockdep_sock_is_held(sk)); + if (inet_opt != NULL && inet_opt->opt.srr) { + if (daddr == 0) + return -EINVAL; + nexthop = inet_opt->opt.faddr; + } + + orig_sport = inet->inet_sport; + orig_dport = usin->sin_port; + fl4 = &inet->cork.fl.u.ip4; + rt = ip_route_connect(fl4, nexthop, inet->inet_saddr, + sk->sk_bound_dev_if, IPPROTO_DCCP, orig_sport, + orig_dport, sk); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + if (rt->rt_flags & (RTCF_MULTICAST | RTCF_BROADCAST)) { + ip_rt_put(rt); + return -ENETUNREACH; + } + + if (inet_opt == NULL || !inet_opt->opt.srr) + daddr = fl4->daddr; + + if (inet->inet_saddr == 0) { + err = inet_bhash2_update_saddr(sk, &fl4->saddr, AF_INET); + if (err) { + ip_rt_put(rt); + return err; + } + } else { + sk_rcv_saddr_set(sk, inet->inet_saddr); + } + + inet->inet_dport = usin->sin_port; + sk_daddr_set(sk, daddr); + + inet_csk(sk)->icsk_ext_hdr_len = 0; + if (inet_opt) + inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen; + /* + * Socket identity is still unknown (sport may be zero). + * However we set state to DCCP_REQUESTING and not releasing socket + * lock select source port, enter ourselves into the hash tables and + * complete initialization after this. 
+ */ + dccp_set_state(sk, DCCP_REQUESTING); + err = inet_hash_connect(&dccp_death_row, sk); + if (err != 0) + goto failure; + + rt = ip_route_newports(fl4, rt, orig_sport, orig_dport, + inet->inet_sport, inet->inet_dport, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + goto failure; + } + /* OK, now commit destination to socket. */ + sk_setup_caps(sk, &rt->dst); + + dp->dccps_iss = secure_dccp_sequence_number(inet->inet_saddr, + inet->inet_daddr, + inet->inet_sport, + inet->inet_dport); + atomic_set(&inet->inet_id, get_random_u16()); + + err = dccp_connect(sk); + rt = NULL; + if (err != 0) + goto failure; +out: + return err; +failure: + /* + * This unhashes the socket and releases the local port, if necessary. + */ + dccp_set_state(sk, DCCP_CLOSED); + inet_bhash2_reset_saddr(sk); + ip_rt_put(rt); + sk->sk_route_caps = 0; + inet->inet_dport = 0; + goto out; +} +EXPORT_SYMBOL_GPL(dccp_v4_connect); + +/* + * This routine does path mtu discovery as defined in RFC1191. + */ +static inline void dccp_do_pmtu_discovery(struct sock *sk, + const struct iphdr *iph, + u32 mtu) +{ + struct dst_entry *dst; + const struct inet_sock *inet = inet_sk(sk); + const struct dccp_sock *dp = dccp_sk(sk); + + /* We are not interested in DCCP_LISTEN and request_socks (RESPONSEs + * send out by Linux are always < 576bytes so they should go through + * unfragmented). + */ + if (sk->sk_state == DCCP_LISTEN) + return; + + dst = inet_csk_update_pmtu(sk, mtu); + if (!dst) + return; + + /* Something is about to be wrong... Remember soft error + * for the case, if this connection will not able to recover. + */ + if (mtu < dst_mtu(dst) && ip_dont_fragment(sk, dst)) + sk->sk_err_soft = EMSGSIZE; + + mtu = dst_mtu(dst); + + if (inet->pmtudisc != IP_PMTUDISC_DONT && + ip_sk_accept_pmtu(sk) && + inet_csk(sk)->icsk_pmtu_cookie > mtu) { + dccp_sync_mss(sk, mtu); + + /* + * From RFC 4340, sec. 14.1: + * + * DCCP-Sync packets are the best choice for upward + * probing, since DCCP-Sync probes do not risk application + * data loss. + */ + dccp_send_sync(sk, dp->dccps_gsr, DCCP_PKT_SYNC); + } /* else let the usual retransmit timer handle it */ +} + +static void dccp_do_redirect(struct sk_buff *skb, struct sock *sk) +{ + struct dst_entry *dst = __sk_dst_check(sk, 0); + + if (dst) + dst->ops->redirect(dst, sk, skb); +} + +void dccp_req_err(struct sock *sk, u64 seq) + { + struct request_sock *req = inet_reqsk(sk); + struct net *net = sock_net(sk); + + /* + * ICMPs are not backlogged, hence we cannot get an established + * socket here. + */ + if (!between48(seq, dccp_rsk(req)->dreq_iss, dccp_rsk(req)->dreq_gss)) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + } else { + /* + * Still in RESPOND, just remove it silently. + * There is no good way to pass the error to the newly + * created socket, and POSIX does not want network + * errors returned from accept(). + */ + inet_csk_reqsk_queue_drop(req->rsk_listener, req); + } + reqsk_put(req); +} +EXPORT_SYMBOL(dccp_req_err); + +/* + * This routine is called by the ICMP module when it gets some sort of error + * condition. If err < 0 then the socket should be closed and the error + * returned to the user. If err > 0 it's just the icmp type << 8 | icmp code. + * After adjustment header points to the first 8 bytes of the tcp header. We + * need to find the appropriate port. + * + * The locking strategy used here is very "optimistic". When someone else + * accesses the socket the ICMP is just dropped and for some paths there is no + * check at all. 
A more general error queue to queue errors for later handling + * is probably better. + */ +static int dccp_v4_err(struct sk_buff *skb, u32 info) +{ + const struct iphdr *iph = (struct iphdr *)skb->data; + const u8 offset = iph->ihl << 2; + const struct dccp_hdr *dh; + struct dccp_sock *dp; + struct inet_sock *inet; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + struct sock *sk; + __u64 seq; + int err; + struct net *net = dev_net(skb->dev); + + if (!pskb_may_pull(skb, offset + sizeof(*dh))) + return -EINVAL; + dh = (struct dccp_hdr *)(skb->data + offset); + if (!pskb_may_pull(skb, offset + __dccp_basic_hdr_len(dh))) + return -EINVAL; + iph = (struct iphdr *)skb->data; + dh = (struct dccp_hdr *)(skb->data + offset); + + sk = __inet_lookup_established(net, &dccp_hashinfo, + iph->daddr, dh->dccph_dport, + iph->saddr, ntohs(dh->dccph_sport), + inet_iif(skb), 0); + if (!sk) { + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return -ENOENT; + } + + if (sk->sk_state == DCCP_TIME_WAIT) { + inet_twsk_put(inet_twsk(sk)); + return 0; + } + seq = dccp_hdr_seq(dh); + if (sk->sk_state == DCCP_NEW_SYN_RECV) { + dccp_req_err(sk, seq); + return 0; + } + + bh_lock_sock(sk); + /* If too many ICMPs get dropped on busy + * servers this needs to be solved differently. + */ + if (sock_owned_by_user(sk)) + __NET_INC_STATS(net, LINUX_MIB_LOCKDROPPEDICMPS); + + if (sk->sk_state == DCCP_CLOSED) + goto out; + + dp = dccp_sk(sk); + if ((1 << sk->sk_state) & ~(DCCPF_REQUESTING | DCCPF_LISTEN) && + !between48(seq, dp->dccps_awl, dp->dccps_awh)) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + goto out; + } + + switch (type) { + case ICMP_REDIRECT: + if (!sock_owned_by_user(sk)) + dccp_do_redirect(skb, sk); + goto out; + case ICMP_SOURCE_QUENCH: + /* Just silently ignore these. */ + goto out; + case ICMP_PARAMETERPROB: + err = EPROTO; + break; + case ICMP_DEST_UNREACH: + if (code > NR_ICMP_UNREACH) + goto out; + + if (code == ICMP_FRAG_NEEDED) { /* PMTU discovery (RFC1191) */ + if (!sock_owned_by_user(sk)) + dccp_do_pmtu_discovery(sk, iph, info); + goto out; + } + + err = icmp_err_convert[code].errno; + break; + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + default: + goto out; + } + + switch (sk->sk_state) { + case DCCP_REQUESTING: + case DCCP_RESPOND: + if (!sock_owned_by_user(sk)) { + __DCCP_INC_STATS(DCCP_MIB_ATTEMPTFAILS); + sk->sk_err = err; + + sk_error_report(sk); + + dccp_done(sk); + } else + sk->sk_err_soft = err; + goto out; + } + + /* If we've already connected we will keep trying + * until we time out, or the user gives up. + * + * rfc1122 4.2.3.9 allows to consider as hard errors + * only PROTO_UNREACH and PORT_UNREACH (well, FRAG_FAILED too, + * but it is obsoleted by pmtu discovery). + * + * Note, that in modern internet, where routing is unreliable + * and in each dark corner broken firewalls sit, sending random + * errors ordered by their masters even this two messages finally lose + * their original sense (even Linux sends invalid PORT_UNREACHs) + * + * Now we are in compliance with RFCs. 
+ * --ANK (980905) + */ + + inet = inet_sk(sk); + if (!sock_owned_by_user(sk) && inet->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else /* Only an error on timeout */ + sk->sk_err_soft = err; +out: + bh_unlock_sock(sk); + sock_put(sk); + return 0; +} + +static inline __sum16 dccp_v4_csum_finish(struct sk_buff *skb, + __be32 src, __be32 dst) +{ + return csum_tcpudp_magic(src, dst, skb->len, IPPROTO_DCCP, skb->csum); +} + +void dccp_v4_send_check(struct sock *sk, struct sk_buff *skb) +{ + const struct inet_sock *inet = inet_sk(sk); + struct dccp_hdr *dh = dccp_hdr(skb); + + dccp_csum_outgoing(skb); + dh->dccph_checksum = dccp_v4_csum_finish(skb, + inet->inet_saddr, + inet->inet_daddr); +} +EXPORT_SYMBOL_GPL(dccp_v4_send_check); + +static inline u64 dccp_v4_init_sequence(const struct sk_buff *skb) +{ + return secure_dccp_sequence_number(ip_hdr(skb)->daddr, + ip_hdr(skb)->saddr, + dccp_hdr(skb)->dccph_dport, + dccp_hdr(skb)->dccph_sport); +} + +/* + * The three way handshake has completed - we got a valid ACK or DATAACK - + * now create the new socket. + * + * This is the equivalent of TCP's tcp_v4_syn_recv_sock + */ +struct sock *dccp_v4_request_recv_sock(const struct sock *sk, + struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req) +{ + struct inet_request_sock *ireq; + struct inet_sock *newinet; + struct sock *newsk; + + if (sk_acceptq_is_full(sk)) + goto exit_overflow; + + newsk = dccp_create_openreq_child(sk, req, skb); + if (newsk == NULL) + goto exit_nonewsk; + + newinet = inet_sk(newsk); + ireq = inet_rsk(req); + sk_daddr_set(newsk, ireq->ir_rmt_addr); + sk_rcv_saddr_set(newsk, ireq->ir_loc_addr); + newinet->inet_saddr = ireq->ir_loc_addr; + RCU_INIT_POINTER(newinet->inet_opt, rcu_dereference(ireq->ireq_opt)); + newinet->mc_index = inet_iif(skb); + newinet->mc_ttl = ip_hdr(skb)->ttl; + atomic_set(&newinet->inet_id, get_random_u16()); + + if (dst == NULL && (dst = inet_csk_route_child_sock(sk, newsk, req)) == NULL) + goto put_and_exit; + + sk_setup_caps(newsk, dst); + + dccp_sync_mss(newsk, dst_mtu(dst)); + + if (__inet_inherit_port(sk, newsk) < 0) + goto put_and_exit; + *own_req = inet_ehash_nolisten(newsk, req_to_sk(req_unhash), NULL); + if (*own_req) + ireq->ireq_opt = NULL; + else + newinet->inet_opt = NULL; + return newsk; + +exit_overflow: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); +exit_nonewsk: + dst_release(dst); +exit: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENDROPS); + return NULL; +put_and_exit: + newinet->inet_opt = NULL; + inet_csk_prepare_forced_close(newsk); + dccp_done(newsk); + goto exit; +} +EXPORT_SYMBOL_GPL(dccp_v4_request_recv_sock); + +static struct dst_entry* dccp_v4_route_skb(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct rtable *rt; + const struct iphdr *iph = ip_hdr(skb); + struct flowi4 fl4 = { + .flowi4_oif = inet_iif(skb), + .daddr = iph->saddr, + .saddr = iph->daddr, + .flowi4_tos = RT_CONN_FLAGS(sk), + .flowi4_proto = sk->sk_protocol, + .fl4_sport = dccp_hdr(skb)->dccph_dport, + .fl4_dport = dccp_hdr(skb)->dccph_sport, + }; + + security_skb_classify_flow(skb, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_flow(net, &fl4, sk); + if (IS_ERR(rt)) { + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + return NULL; + } + + return &rt->dst; +} + +static int dccp_v4_send_response(const struct sock *sk, struct request_sock *req) +{ + int err = -1; + struct sk_buff *skb; + struct dst_entry *dst; + struct flowi4 fl4; + + dst = 
inet_csk_route_req(sk, &fl4, req); + if (dst == NULL) + goto out; + + skb = dccp_make_response(sk, dst, req); + if (skb != NULL) { + const struct inet_request_sock *ireq = inet_rsk(req); + struct dccp_hdr *dh = dccp_hdr(skb); + + dh->dccph_checksum = dccp_v4_csum_finish(skb, ireq->ir_loc_addr, + ireq->ir_rmt_addr); + rcu_read_lock(); + err = ip_build_and_send_pkt(skb, sk, ireq->ir_loc_addr, + ireq->ir_rmt_addr, + rcu_dereference(ireq->ireq_opt), + inet_sk(sk)->tos); + rcu_read_unlock(); + err = net_xmit_eval(err); + } + +out: + dst_release(dst); + return err; +} + +static void dccp_v4_ctl_send_reset(const struct sock *sk, struct sk_buff *rxskb) +{ + int err; + const struct iphdr *rxiph; + struct sk_buff *skb; + struct dst_entry *dst; + struct net *net = dev_net(skb_dst(rxskb)->dev); + struct dccp_v4_pernet *pn; + struct sock *ctl_sk; + + /* Never send a reset in response to a reset. */ + if (dccp_hdr(rxskb)->dccph_type == DCCP_PKT_RESET) + return; + + if (skb_rtable(rxskb)->rt_type != RTN_LOCAL) + return; + + pn = net_generic(net, dccp_v4_pernet_id); + ctl_sk = pn->v4_ctl_sk; + dst = dccp_v4_route_skb(net, ctl_sk, rxskb); + if (dst == NULL) + return; + + skb = dccp_ctl_make_reset(ctl_sk, rxskb); + if (skb == NULL) + goto out; + + rxiph = ip_hdr(rxskb); + dccp_hdr(skb)->dccph_checksum = dccp_v4_csum_finish(skb, rxiph->saddr, + rxiph->daddr); + skb_dst_set(skb, dst_clone(dst)); + + local_bh_disable(); + bh_lock_sock(ctl_sk); + err = ip_build_and_send_pkt(skb, ctl_sk, + rxiph->daddr, rxiph->saddr, NULL, + inet_sk(ctl_sk)->tos); + bh_unlock_sock(ctl_sk); + + if (net_xmit_eval(err) == 0) { + __DCCP_INC_STATS(DCCP_MIB_OUTSEGS); + __DCCP_INC_STATS(DCCP_MIB_OUTRSTS); + } + local_bh_enable(); +out: + dst_release(dst); +} + +static void dccp_v4_reqsk_destructor(struct request_sock *req) +{ + dccp_feat_list_purge(&dccp_rsk(req)->dreq_featneg); + kfree(rcu_dereference_protected(inet_rsk(req)->ireq_opt, 1)); +} + +void dccp_syn_ack_timeout(const struct request_sock *req) +{ +} +EXPORT_SYMBOL(dccp_syn_ack_timeout); + +static struct request_sock_ops dccp_request_sock_ops __read_mostly = { + .family = PF_INET, + .obj_size = sizeof(struct dccp_request_sock), + .rtx_syn_ack = dccp_v4_send_response, + .send_ack = dccp_reqsk_send_ack, + .destructor = dccp_v4_reqsk_destructor, + .send_reset = dccp_v4_ctl_send_reset, + .syn_ack_timeout = dccp_syn_ack_timeout, +}; + +int dccp_v4_conn_request(struct sock *sk, struct sk_buff *skb) +{ + struct inet_request_sock *ireq; + struct request_sock *req; + struct dccp_request_sock *dreq; + const __be32 service = dccp_hdr_request(skb)->dccph_req_service; + struct dccp_skb_cb *dcb = DCCP_SKB_CB(skb); + + /* Never answer to DCCP_PKT_REQUESTs send to broadcast or multicast */ + if (skb_rtable(skb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) + return 0; /* discard, don't send a reset here */ + + if (dccp_bad_service_code(sk, service)) { + dcb->dccpd_reset_code = DCCP_RESET_CODE_BAD_SERVICE_CODE; + goto drop; + } + /* + * TW buckets are converted to open requests without + * limitations, they conserve resources and peer is + * evidently real one. 
+ */ + dcb->dccpd_reset_code = DCCP_RESET_CODE_TOO_BUSY; + if (inet_csk_reqsk_queue_is_full(sk)) + goto drop; + + if (sk_acceptq_is_full(sk)) + goto drop; + + req = inet_reqsk_alloc(&dccp_request_sock_ops, sk, true); + if (req == NULL) + goto drop; + + if (dccp_reqsk_init(req, dccp_sk(sk), skb)) + goto drop_and_free; + + dreq = dccp_rsk(req); + if (dccp_parse_options(sk, dreq, skb)) + goto drop_and_free; + + ireq = inet_rsk(req); + sk_rcv_saddr_set(req_to_sk(req), ip_hdr(skb)->daddr); + sk_daddr_set(req_to_sk(req), ip_hdr(skb)->saddr); + ireq->ir_mark = inet_request_mark(sk, skb); + ireq->ireq_family = AF_INET; + ireq->ir_iif = READ_ONCE(sk->sk_bound_dev_if); + + if (security_inet_conn_request(sk, skb, req)) + goto drop_and_free; + + /* + * Step 3: Process LISTEN state + * + * Set S.ISR, S.GSR, S.SWL, S.SWH from packet or Init Cookie + * + * Setting S.SWL/S.SWH to is deferred to dccp_create_openreq_child(). + */ + dreq->dreq_isr = dcb->dccpd_seq; + dreq->dreq_gsr = dreq->dreq_isr; + dreq->dreq_iss = dccp_v4_init_sequence(skb); + dreq->dreq_gss = dreq->dreq_iss; + dreq->dreq_service = service; + + if (dccp_v4_send_response(sk, req)) + goto drop_and_free; + + inet_csk_reqsk_queue_hash_add(sk, req, DCCP_TIMEOUT_INIT); + reqsk_put(req); + return 0; + +drop_and_free: + reqsk_free(req); +drop: + __DCCP_INC_STATS(DCCP_MIB_ATTEMPTFAILS); + return -1; +} +EXPORT_SYMBOL_GPL(dccp_v4_conn_request); + +int dccp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_hdr *dh = dccp_hdr(skb); + + if (sk->sk_state == DCCP_OPEN) { /* Fast path */ + if (dccp_rcv_established(sk, skb, dh, skb->len)) + goto reset; + return 0; + } + + /* + * Step 3: Process LISTEN state + * If P.type == Request or P contains a valid Init Cookie option, + * (* Must scan the packet's options to check for Init + * Cookies. Only Init Cookies are processed here, + * however; other options are processed in Step 8. This + * scan need only be performed if the endpoint uses Init + * Cookies *) + * (* Generate a new socket and switch to that socket *) + * Set S := new socket for this port pair + * S.state = RESPOND + * Choose S.ISS (initial seqno) or set from Init Cookies + * Initialize S.GAR := S.ISS + * Set S.ISR, S.GSR, S.SWL, S.SWH from packet or Init Cookies + * Continue with S.state == RESPOND + * (* A Response packet will be generated in Step 11 *) + * Otherwise, + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + * + * NOTE: the check for the packet types is done in + * dccp_rcv_state_process + */ + + if (dccp_rcv_state_process(sk, skb, dh, skb->len)) + goto reset; + return 0; + +reset: + dccp_v4_ctl_send_reset(sk, skb); + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL_GPL(dccp_v4_do_rcv); + +/** + * dccp_invalid_packet - check for malformed packets + * @skb: Packet to validate + * + * Implements RFC 4340, 8.5: Step 1: Check header basics + * Packets that fail these checks are ignored and do not receive Resets. 
+ */ +int dccp_invalid_packet(struct sk_buff *skb) +{ + const struct dccp_hdr *dh; + unsigned int cscov; + u8 dccph_doff; + + if (skb->pkt_type != PACKET_HOST) + return 1; + + /* If the packet is shorter than 12 bytes, drop packet and return */ + if (!pskb_may_pull(skb, sizeof(struct dccp_hdr))) { + DCCP_WARN("pskb_may_pull failed\n"); + return 1; + } + + dh = dccp_hdr(skb); + + /* If P.type is not understood, drop packet and return */ + if (dh->dccph_type >= DCCP_PKT_INVALID) { + DCCP_WARN("invalid packet type\n"); + return 1; + } + + /* + * If P.Data Offset is too small for packet type, drop packet and return + */ + dccph_doff = dh->dccph_doff; + if (dccph_doff < dccp_hdr_len(skb) / sizeof(u32)) { + DCCP_WARN("P.Data Offset(%u) too small\n", dccph_doff); + return 1; + } + /* + * If P.Data Offset is too large for packet, drop packet and return + */ + if (!pskb_may_pull(skb, dccph_doff * sizeof(u32))) { + DCCP_WARN("P.Data Offset(%u) too large\n", dccph_doff); + return 1; + } + dh = dccp_hdr(skb); + /* + * If P.type is not Data, Ack, or DataAck and P.X == 0 (the packet + * has short sequence numbers), drop packet and return + */ + if ((dh->dccph_type < DCCP_PKT_DATA || + dh->dccph_type > DCCP_PKT_DATAACK) && dh->dccph_x == 0) { + DCCP_WARN("P.type (%s) not Data || [Data]Ack, while P.X == 0\n", + dccp_packet_name(dh->dccph_type)); + return 1; + } + + /* + * If P.CsCov is too large for the packet size, drop packet and return. + * This must come _before_ checksumming (not as RFC 4340 suggests). + */ + cscov = dccp_csum_coverage(skb); + if (cscov > skb->len) { + DCCP_WARN("P.CsCov %u exceeds packet length %d\n", + dh->dccph_cscov, skb->len); + return 1; + } + + /* If header checksum is incorrect, drop packet and return. + * (This step is completed in the AF-dependent functions.) */ + skb->csum = skb_checksum(skb, 0, cscov, 0); + + return 0; +} +EXPORT_SYMBOL_GPL(dccp_invalid_packet); + +/* this is called when real data arrives */ +static int dccp_v4_rcv(struct sk_buff *skb) +{ + const struct dccp_hdr *dh; + const struct iphdr *iph; + bool refcounted; + struct sock *sk; + int min_cov; + + /* Step 1: Check header basics */ + + if (dccp_invalid_packet(skb)) + goto discard_it; + + iph = ip_hdr(skb); + /* Step 1: If header checksum is incorrect, drop packet and return */ + if (dccp_v4_csum_finish(skb, iph->saddr, iph->daddr)) { + DCCP_WARN("dropped packet with invalid checksum\n"); + goto discard_it; + } + + dh = dccp_hdr(skb); + + DCCP_SKB_CB(skb)->dccpd_seq = dccp_hdr_seq(dh); + DCCP_SKB_CB(skb)->dccpd_type = dh->dccph_type; + + dccp_pr_debug("%8.8s src=%pI4@%-5d dst=%pI4@%-5d seq=%llu", + dccp_packet_name(dh->dccph_type), + &iph->saddr, ntohs(dh->dccph_sport), + &iph->daddr, ntohs(dh->dccph_dport), + (unsigned long long) DCCP_SKB_CB(skb)->dccpd_seq); + + if (dccp_packet_without_ack(skb)) { + DCCP_SKB_CB(skb)->dccpd_ack_seq = DCCP_PKT_WITHOUT_ACK_SEQ; + dccp_pr_debug_cat("\n"); + } else { + DCCP_SKB_CB(skb)->dccpd_ack_seq = dccp_hdr_ack_seq(skb); + dccp_pr_debug_cat(", ack=%llu\n", (unsigned long long) + DCCP_SKB_CB(skb)->dccpd_ack_seq); + } + +lookup: + sk = __inet_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh), + dh->dccph_sport, dh->dccph_dport, 0, &refcounted); + if (!sk) { + dccp_pr_debug("failed to look up flow ID in table and " + "get corresponding socket\n"); + goto no_dccp_socket; + } + + /* + * Step 2: + * ... 
or S.state == TIMEWAIT, + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + */ + if (sk->sk_state == DCCP_TIME_WAIT) { + dccp_pr_debug("sk->sk_state == DCCP_TIME_WAIT: do_time_wait\n"); + inet_twsk_put(inet_twsk(sk)); + goto no_dccp_socket; + } + + if (sk->sk_state == DCCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + struct sock *nsk; + + sk = req->rsk_listener; + if (unlikely(sk->sk_state != DCCP_LISTEN)) { + inet_csk_reqsk_queue_drop_and_put(sk, req); + goto lookup; + } + sock_hold(sk); + refcounted = true; + nsk = dccp_check_req(sk, skb, req); + if (!nsk) { + reqsk_put(req); + goto discard_and_relse; + } + if (nsk == sk) { + reqsk_put(req); + } else if (dccp_child_process(sk, nsk, skb)) { + dccp_v4_ctl_send_reset(sk, skb); + goto discard_and_relse; + } else { + sock_put(sk); + return 0; + } + } + /* + * RFC 4340, sec. 9.2.1: Minimum Checksum Coverage + * o if MinCsCov = 0, only packets with CsCov = 0 are accepted + * o if MinCsCov > 0, also accept packets with CsCov >= MinCsCov + */ + min_cov = dccp_sk(sk)->dccps_pcrlen; + if (dh->dccph_cscov && (min_cov == 0 || dh->dccph_cscov < min_cov)) { + dccp_pr_debug("Packet CsCov %d does not satisfy MinCsCov %d\n", + dh->dccph_cscov, min_cov); + /* FIXME: "Such packets SHOULD be reported using Data Dropped + * options (Section 11.7) with Drop Code 0, Protocol + * Constraints." */ + goto discard_and_relse; + } + + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + goto discard_and_relse; + nf_reset_ct(skb); + + return __sk_receive_skb(sk, skb, 1, dh->dccph_doff * 4, refcounted); + +no_dccp_socket: + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto discard_it; + /* + * Step 2: + * If no socket ... + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + */ + if (dh->dccph_type != DCCP_PKT_RESET) { + DCCP_SKB_CB(skb)->dccpd_reset_code = + DCCP_RESET_CODE_NO_CONNECTION; + dccp_v4_ctl_send_reset(sk, skb); + } + +discard_it: + kfree_skb(skb); + return 0; + +discard_and_relse: + if (refcounted) + sock_put(sk); + goto discard_it; +} + +static const struct inet_connection_sock_af_ops dccp_ipv4_af_ops = { + .queue_xmit = ip_queue_xmit, + .send_check = dccp_v4_send_check, + .rebuild_header = inet_sk_rebuild_header, + .conn_request = dccp_v4_conn_request, + .syn_recv_sock = dccp_v4_request_recv_sock, + .net_header_len = sizeof(struct iphdr), + .setsockopt = ip_setsockopt, + .getsockopt = ip_getsockopt, + .addr2sockaddr = inet_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct sockaddr_in), +}; + +static int dccp_v4_init_sock(struct sock *sk) +{ + static __u8 dccp_v4_ctl_sock_initialized; + int err = dccp_init_sock(sk, dccp_v4_ctl_sock_initialized); + + if (err == 0) { + if (unlikely(!dccp_v4_ctl_sock_initialized)) + dccp_v4_ctl_sock_initialized = 1; + inet_csk(sk)->icsk_af_ops = &dccp_ipv4_af_ops; + } + + return err; +} + +static struct timewait_sock_ops dccp_timewait_sock_ops = { + .twsk_obj_size = sizeof(struct inet_timewait_sock), +}; + +static struct proto dccp_v4_prot = { + .name = "DCCP", + .owner = THIS_MODULE, + .close = dccp_close, + .connect = dccp_v4_connect, + .disconnect = dccp_disconnect, + .ioctl = dccp_ioctl, + .init = dccp_v4_init_sock, + .setsockopt = dccp_setsockopt, + .getsockopt = dccp_getsockopt, + .sendmsg = dccp_sendmsg, + .recvmsg = dccp_recvmsg, + .backlog_rcv = dccp_v4_do_rcv, + .hash = inet_hash, + .unhash = inet_unhash, + .accept = inet_csk_accept, + .get_port = inet_csk_get_port, + .shutdown = dccp_shutdown, + .destroy = 
dccp_destroy_sock, + .orphan_count = &dccp_orphan_count, + .max_header = MAX_DCCP_HEADER, + .obj_size = sizeof(struct dccp_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, + .rsk_prot = &dccp_request_sock_ops, + .twsk_prot = &dccp_timewait_sock_ops, + .h.hashinfo = &dccp_hashinfo, +}; + +static const struct net_protocol dccp_v4_protocol = { + .handler = dccp_v4_rcv, + .err_handler = dccp_v4_err, + .no_policy = 1, + .icmp_strict_tag_validation = 1, +}; + +static const struct proto_ops inet_dccp_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = inet_bind, + .connect = inet_stream_connect, + .socketpair = sock_no_socketpair, + .accept = inet_accept, + .getname = inet_getname, + /* FIXME: work on tcp_poll to rename it to inet_csk_poll */ + .poll = dccp_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + /* FIXME: work on inet_listen to rename it to sock_common_listen */ + .listen = inet_dccp_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct inet_protosw dccp_v4_protosw = { + .type = SOCK_DCCP, + .protocol = IPPROTO_DCCP, + .prot = &dccp_v4_prot, + .ops = &inet_dccp_ops, + .flags = INET_PROTOSW_ICSK, +}; + +static int __net_init dccp_v4_init_net(struct net *net) +{ + struct dccp_v4_pernet *pn = net_generic(net, dccp_v4_pernet_id); + + if (dccp_hashinfo.bhash == NULL) + return -ESOCKTNOSUPPORT; + + return inet_ctl_sock_create(&pn->v4_ctl_sk, PF_INET, + SOCK_DCCP, IPPROTO_DCCP, net); +} + +static void __net_exit dccp_v4_exit_net(struct net *net) +{ + struct dccp_v4_pernet *pn = net_generic(net, dccp_v4_pernet_id); + + inet_ctl_sock_destroy(pn->v4_ctl_sk); +} + +static void __net_exit dccp_v4_exit_batch(struct list_head *net_exit_list) +{ + inet_twsk_purge(&dccp_hashinfo, AF_INET); +} + +static struct pernet_operations dccp_v4_ops = { + .init = dccp_v4_init_net, + .exit = dccp_v4_exit_net, + .exit_batch = dccp_v4_exit_batch, + .id = &dccp_v4_pernet_id, + .size = sizeof(struct dccp_v4_pernet), +}; + +static int __init dccp_v4_init(void) +{ + int err = proto_register(&dccp_v4_prot, 1); + + if (err) + goto out; + + inet_register_protosw(&dccp_v4_protosw); + + err = register_pernet_subsys(&dccp_v4_ops); + if (err) + goto out_destroy_ctl_sock; + + err = inet_add_protocol(&dccp_v4_protocol, IPPROTO_DCCP); + if (err) + goto out_proto_unregister; + +out: + return err; +out_proto_unregister: + unregister_pernet_subsys(&dccp_v4_ops); +out_destroy_ctl_sock: + inet_unregister_protosw(&dccp_v4_protosw); + proto_unregister(&dccp_v4_prot); + goto out; +} + +static void __exit dccp_v4_exit(void) +{ + inet_del_protocol(&dccp_v4_protocol, IPPROTO_DCCP); + unregister_pernet_subsys(&dccp_v4_ops); + inet_unregister_protosw(&dccp_v4_protosw); + proto_unregister(&dccp_v4_prot); +} + +module_init(dccp_v4_init); +module_exit(dccp_v4_exit); + +/* + * __stringify doesn't likes enums, so use SOCK_DCCP (6) and IPPROTO_DCCP (33) + * values directly, Also cover the case where the protocol is not specified, + * i.e. 
net-pf-PF_INET-proto-0-type-SOCK_DCCP + */ +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET, 33, 6); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET, 0, 6); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arnaldo Carvalho de Melo "); +MODULE_DESCRIPTION("DCCP - Datagram Congestion Controlled Protocol"); diff --git a/net/dccp/ipv6.c b/net/dccp/ipv6.c new file mode 100644 index 000000000..e0b0bf75a --- /dev/null +++ b/net/dccp/ipv6.c @@ -0,0 +1,1181 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * DCCP over IPv6 + * Linux INET6 implementation + * + * Based on net/dccp6/ipv6.c + * + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dccp.h" +#include "ipv6.h" +#include "feat.h" + +struct dccp_v6_pernet { + struct sock *v6_ctl_sk; +}; + +static unsigned int dccp_v6_pernet_id __read_mostly; + +/* The per-net v6_ctl_sk is used for sending RSTs and ACKs */ + +static const struct inet_connection_sock_af_ops dccp_ipv6_mapped; +static const struct inet_connection_sock_af_ops dccp_ipv6_af_ops; + +/* add pseudo-header to DCCP checksum stored in skb->csum */ +static inline __sum16 dccp_v6_csum_finish(struct sk_buff *skb, + const struct in6_addr *saddr, + const struct in6_addr *daddr) +{ + return csum_ipv6_magic(saddr, daddr, skb->len, IPPROTO_DCCP, skb->csum); +} + +static inline void dccp_v6_send_check(struct sock *sk, struct sk_buff *skb) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct dccp_hdr *dh = dccp_hdr(skb); + + dccp_csum_outgoing(skb); + dh->dccph_checksum = dccp_v6_csum_finish(skb, &np->saddr, &sk->sk_v6_daddr); +} + +static inline __u64 dccp_v6_init_sequence(struct sk_buff *skb) +{ + return secure_dccpv6_sequence_number(ipv6_hdr(skb)->daddr.s6_addr32, + ipv6_hdr(skb)->saddr.s6_addr32, + dccp_hdr(skb)->dccph_dport, + dccp_hdr(skb)->dccph_sport ); + +} + +static int dccp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + const struct ipv6hdr *hdr; + const struct dccp_hdr *dh; + struct dccp_sock *dp; + struct ipv6_pinfo *np; + struct sock *sk; + int err; + __u64 seq; + struct net *net = dev_net(skb->dev); + + if (!pskb_may_pull(skb, offset + sizeof(*dh))) + return -EINVAL; + dh = (struct dccp_hdr *)(skb->data + offset); + if (!pskb_may_pull(skb, offset + __dccp_basic_hdr_len(dh))) + return -EINVAL; + hdr = (const struct ipv6hdr *)skb->data; + dh = (struct dccp_hdr *)(skb->data + offset); + + sk = __inet6_lookup_established(net, &dccp_hashinfo, + &hdr->daddr, dh->dccph_dport, + &hdr->saddr, ntohs(dh->dccph_sport), + inet6_iif(skb), 0); + + if (!sk) { + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), + ICMP6_MIB_INERRORS); + return -ENOENT; + } + + if (sk->sk_state == DCCP_TIME_WAIT) { + inet_twsk_put(inet_twsk(sk)); + return 0; + } + seq = dccp_hdr_seq(dh); + if (sk->sk_state == DCCP_NEW_SYN_RECV) { + dccp_req_err(sk, seq); + return 0; + } + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) + __NET_INC_STATS(net, LINUX_MIB_LOCKDROPPEDICMPS); + + if (sk->sk_state == DCCP_CLOSED) + goto out; + + dp = dccp_sk(sk); + if ((1 << sk->sk_state) & ~(DCCPF_REQUESTING | DCCPF_LISTEN) && + !between48(seq, dp->dccps_awl, dp->dccps_awh)) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + goto out; + } + + np = inet6_sk(sk); + + if (type == NDISC_REDIRECT) { + if (!sock_owned_by_user(sk)) { + struct dst_entry *dst = __sk_dst_check(sk, np->dst_cookie); + + if (dst) + 
dst->ops->redirect(dst, sk, skb); + } + goto out; + } + + if (type == ICMPV6_PKT_TOOBIG) { + struct dst_entry *dst = NULL; + + if (!ip6_sk_accept_pmtu(sk)) + goto out; + + if (sock_owned_by_user(sk)) + goto out; + if ((1 << sk->sk_state) & (DCCPF_LISTEN | DCCPF_CLOSED)) + goto out; + + dst = inet6_csk_update_pmtu(sk, ntohl(info)); + if (!dst) + goto out; + + if (inet_csk(sk)->icsk_pmtu_cookie > dst_mtu(dst)) + dccp_sync_mss(sk, dst_mtu(dst)); + goto out; + } + + icmpv6_err_convert(type, code, &err); + + /* Might be for an request_sock */ + switch (sk->sk_state) { + case DCCP_REQUESTING: + case DCCP_RESPOND: /* Cannot happen. + It can, it SYNs are crossed. --ANK */ + if (!sock_owned_by_user(sk)) { + __DCCP_INC_STATS(DCCP_MIB_ATTEMPTFAILS); + sk->sk_err = err; + /* + * Wake people up to see the error + * (see connect in sock.c) + */ + sk_error_report(sk); + dccp_done(sk); + } else + sk->sk_err_soft = err; + goto out; + } + + if (!sock_owned_by_user(sk) && np->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else + sk->sk_err_soft = err; + +out: + bh_unlock_sock(sk); + sock_put(sk); + return 0; +} + + +static int dccp_v6_send_response(const struct sock *sk, struct request_sock *req) +{ + struct inet_request_sock *ireq = inet_rsk(req); + struct ipv6_pinfo *np = inet6_sk(sk); + struct sk_buff *skb; + struct in6_addr *final_p, final; + struct flowi6 fl6; + int err = -1; + struct dst_entry *dst; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_DCCP; + fl6.daddr = ireq->ir_v6_rmt_addr; + fl6.saddr = ireq->ir_v6_loc_addr; + fl6.flowlabel = 0; + fl6.flowi6_oif = ireq->ir_iif; + fl6.fl6_dport = ireq->ir_rmt_port; + fl6.fl6_sport = htons(ireq->ir_num); + security_req_classify_flow(req, flowi6_to_flowi_common(&fl6)); + + + rcu_read_lock(); + final_p = fl6_update_dst(&fl6, rcu_dereference(np->opt), &final); + rcu_read_unlock(); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + goto done; + } + + skb = dccp_make_response(sk, dst, req); + if (skb != NULL) { + struct dccp_hdr *dh = dccp_hdr(skb); + struct ipv6_txoptions *opt; + + dh->dccph_checksum = dccp_v6_csum_finish(skb, + &ireq->ir_v6_loc_addr, + &ireq->ir_v6_rmt_addr); + fl6.daddr = ireq->ir_v6_rmt_addr; + rcu_read_lock(); + opt = ireq->ipv6_opt; + if (!opt) + opt = rcu_dereference(np->opt); + err = ip6_xmit(sk, skb, &fl6, READ_ONCE(sk->sk_mark), opt, + np->tclass, sk->sk_priority); + rcu_read_unlock(); + err = net_xmit_eval(err); + } + +done: + dst_release(dst); + return err; +} + +static void dccp_v6_reqsk_destructor(struct request_sock *req) +{ + dccp_feat_list_purge(&dccp_rsk(req)->dreq_featneg); + kfree(inet_rsk(req)->ipv6_opt); + kfree_skb(inet_rsk(req)->pktopts); +} + +static void dccp_v6_ctl_send_reset(const struct sock *sk, struct sk_buff *rxskb) +{ + const struct ipv6hdr *rxip6h; + struct sk_buff *skb; + struct flowi6 fl6; + struct net *net = dev_net(skb_dst(rxskb)->dev); + struct dccp_v6_pernet *pn; + struct sock *ctl_sk; + struct dst_entry *dst; + + if (dccp_hdr(rxskb)->dccph_type == DCCP_PKT_RESET) + return; + + if (!ipv6_unicast_destination(rxskb)) + return; + + pn = net_generic(net, dccp_v6_pernet_id); + ctl_sk = pn->v6_ctl_sk; + skb = dccp_ctl_make_reset(ctl_sk, rxskb); + if (skb == NULL) + return; + + rxip6h = ipv6_hdr(rxskb); + dccp_hdr(skb)->dccph_checksum = dccp_v6_csum_finish(skb, &rxip6h->saddr, + &rxip6h->daddr); + + memset(&fl6, 0, sizeof(fl6)); + fl6.daddr = rxip6h->saddr; + fl6.saddr = rxip6h->daddr; + + fl6.flowi6_proto = 
IPPROTO_DCCP; + fl6.flowi6_oif = inet6_iif(rxskb); + fl6.fl6_dport = dccp_hdr(skb)->dccph_dport; + fl6.fl6_sport = dccp_hdr(skb)->dccph_sport; + security_skb_classify_flow(rxskb, flowi6_to_flowi_common(&fl6)); + + /* sk = NULL, but it is safe for now. RST socket required. */ + dst = ip6_dst_lookup_flow(sock_net(ctl_sk), ctl_sk, &fl6, NULL); + if (!IS_ERR(dst)) { + skb_dst_set(skb, dst); + ip6_xmit(ctl_sk, skb, &fl6, 0, NULL, 0, 0); + DCCP_INC_STATS(DCCP_MIB_OUTSEGS); + DCCP_INC_STATS(DCCP_MIB_OUTRSTS); + return; + } + + kfree_skb(skb); +} + +static struct request_sock_ops dccp6_request_sock_ops = { + .family = AF_INET6, + .obj_size = sizeof(struct dccp6_request_sock), + .rtx_syn_ack = dccp_v6_send_response, + .send_ack = dccp_reqsk_send_ack, + .destructor = dccp_v6_reqsk_destructor, + .send_reset = dccp_v6_ctl_send_reset, + .syn_ack_timeout = dccp_syn_ack_timeout, +}; + +static int dccp_v6_conn_request(struct sock *sk, struct sk_buff *skb) +{ + struct request_sock *req; + struct dccp_request_sock *dreq; + struct inet_request_sock *ireq; + struct ipv6_pinfo *np = inet6_sk(sk); + const __be32 service = dccp_hdr_request(skb)->dccph_req_service; + struct dccp_skb_cb *dcb = DCCP_SKB_CB(skb); + + if (skb->protocol == htons(ETH_P_IP)) + return dccp_v4_conn_request(sk, skb); + + if (!ipv6_unicast_destination(skb)) + return 0; /* discard, don't send a reset here */ + + if (ipv6_addr_v4mapped(&ipv6_hdr(skb)->saddr)) { + __IP6_INC_STATS(sock_net(sk), NULL, IPSTATS_MIB_INHDRERRORS); + return 0; + } + + if (dccp_bad_service_code(sk, service)) { + dcb->dccpd_reset_code = DCCP_RESET_CODE_BAD_SERVICE_CODE; + goto drop; + } + /* + * There are no SYN attacks on IPv6, yet... + */ + dcb->dccpd_reset_code = DCCP_RESET_CODE_TOO_BUSY; + if (inet_csk_reqsk_queue_is_full(sk)) + goto drop; + + if (sk_acceptq_is_full(sk)) + goto drop; + + req = inet_reqsk_alloc(&dccp6_request_sock_ops, sk, true); + if (req == NULL) + goto drop; + + if (dccp_reqsk_init(req, dccp_sk(sk), skb)) + goto drop_and_free; + + dreq = dccp_rsk(req); + if (dccp_parse_options(sk, dreq, skb)) + goto drop_and_free; + + ireq = inet_rsk(req); + ireq->ir_v6_rmt_addr = ipv6_hdr(skb)->saddr; + ireq->ir_v6_loc_addr = ipv6_hdr(skb)->daddr; + ireq->ireq_family = AF_INET6; + ireq->ir_mark = inet_request_mark(sk, skb); + + if (security_inet_conn_request(sk, skb, req)) + goto drop_and_free; + + if (ipv6_opt_accepted(sk, skb, IP6CB(skb)) || + np->rxopt.bits.rxinfo || np->rxopt.bits.rxoinfo || + np->rxopt.bits.rxhlim || np->rxopt.bits.rxohlim) { + refcount_inc(&skb->users); + ireq->pktopts = skb; + } + ireq->ir_iif = READ_ONCE(sk->sk_bound_dev_if); + + /* So that link locals have meaning */ + if (!ireq->ir_iif && + ipv6_addr_type(&ireq->ir_v6_rmt_addr) & IPV6_ADDR_LINKLOCAL) + ireq->ir_iif = inet6_iif(skb); + + /* + * Step 3: Process LISTEN state + * + * Set S.ISR, S.GSR, S.SWL, S.SWH from packet or Init Cookie + * + * Setting S.SWL/S.SWH to is deferred to dccp_create_openreq_child(). 
+ */ + dreq->dreq_isr = dcb->dccpd_seq; + dreq->dreq_gsr = dreq->dreq_isr; + dreq->dreq_iss = dccp_v6_init_sequence(skb); + dreq->dreq_gss = dreq->dreq_iss; + dreq->dreq_service = service; + + if (dccp_v6_send_response(sk, req)) + goto drop_and_free; + + inet_csk_reqsk_queue_hash_add(sk, req, DCCP_TIMEOUT_INIT); + reqsk_put(req); + return 0; + +drop_and_free: + reqsk_free(req); +drop: + __DCCP_INC_STATS(DCCP_MIB_ATTEMPTFAILS); + return -1; +} + +static struct sock *dccp_v6_request_recv_sock(const struct sock *sk, + struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req) +{ + struct inet_request_sock *ireq = inet_rsk(req); + struct ipv6_pinfo *newnp; + const struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_txoptions *opt; + struct inet_sock *newinet; + struct dccp6_sock *newdp6; + struct sock *newsk; + + if (skb->protocol == htons(ETH_P_IP)) { + /* + * v6 mapped + */ + newsk = dccp_v4_request_recv_sock(sk, skb, req, dst, + req_unhash, own_req); + if (newsk == NULL) + return NULL; + + newdp6 = (struct dccp6_sock *)newsk; + newinet = inet_sk(newsk); + newinet->pinet6 = &newdp6->inet6; + newnp = inet6_sk(newsk); + + memcpy(newnp, np, sizeof(struct ipv6_pinfo)); + + newnp->saddr = newsk->sk_v6_rcv_saddr; + + inet_csk(newsk)->icsk_af_ops = &dccp_ipv6_mapped; + newsk->sk_backlog_rcv = dccp_v4_do_rcv; + newnp->pktoptions = NULL; + newnp->opt = NULL; + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + newnp->ipv6_fl_list = NULL; + newnp->mcast_oif = inet_iif(skb); + newnp->mcast_hops = ip_hdr(skb)->ttl; + + /* + * No need to charge this sock to the relevant IPv6 refcnt debug socks count + * here, dccp_create_openreq_child now does this for us, see the comment in + * that function for the gory details. -acme + */ + + /* It is tricky place. Until this moment IPv4 tcp + worked with IPv6 icsk.icsk_af_ops. + Sync it now. + */ + dccp_sync_mss(newsk, inet_csk(newsk)->icsk_pmtu_cookie); + + return newsk; + } + + + if (sk_acceptq_is_full(sk)) + goto out_overflow; + + if (!dst) { + struct flowi6 fl6; + + dst = inet6_csk_route_req(sk, &fl6, req, IPPROTO_DCCP); + if (!dst) + goto out; + } + + newsk = dccp_create_openreq_child(sk, req, skb); + if (newsk == NULL) + goto out_nonewsk; + + /* + * No need to charge this sock to the relevant IPv6 refcnt debug socks + * count here, dccp_create_openreq_child now does this for us, see the + * comment in that function for the gory details. -acme + */ + + ip6_dst_store(newsk, dst, NULL, NULL); + newsk->sk_route_caps = dst->dev->features & ~(NETIF_F_IP_CSUM | + NETIF_F_TSO); + newdp6 = (struct dccp6_sock *)newsk; + newinet = inet_sk(newsk); + newinet->pinet6 = &newdp6->inet6; + newnp = inet6_sk(newsk); + + memcpy(newnp, np, sizeof(struct ipv6_pinfo)); + + newsk->sk_v6_daddr = ireq->ir_v6_rmt_addr; + newnp->saddr = ireq->ir_v6_loc_addr; + newsk->sk_v6_rcv_saddr = ireq->ir_v6_loc_addr; + newsk->sk_bound_dev_if = ireq->ir_iif; + + /* Now IPv6 options... + + First: no IPv4 options. + */ + newinet->inet_opt = NULL; + + /* Clone RX bits */ + newnp->rxopt.all = np->rxopt.all; + + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + newnp->ipv6_fl_list = NULL; + newnp->pktoptions = NULL; + newnp->opt = NULL; + newnp->mcast_oif = inet6_iif(skb); + newnp->mcast_hops = ipv6_hdr(skb)->hop_limit; + + /* + * Clone native IPv6 options from listening socket (if any) + * + * Yes, keeping reference count would be much more clever, but we make + * one more one thing there: reattach optmem to newsk. 
+ */ + opt = ireq->ipv6_opt; + if (!opt) + opt = rcu_dereference(np->opt); + if (opt) { + opt = ipv6_dup_options(newsk, opt); + RCU_INIT_POINTER(newnp->opt, opt); + } + inet_csk(newsk)->icsk_ext_hdr_len = 0; + if (opt) + inet_csk(newsk)->icsk_ext_hdr_len = opt->opt_nflen + + opt->opt_flen; + + dccp_sync_mss(newsk, dst_mtu(dst)); + + newinet->inet_daddr = newinet->inet_saddr = LOOPBACK4_IPV6; + newinet->inet_rcv_saddr = LOOPBACK4_IPV6; + + if (__inet_inherit_port(sk, newsk) < 0) { + inet_csk_prepare_forced_close(newsk); + dccp_done(newsk); + goto out; + } + *own_req = inet_ehash_nolisten(newsk, req_to_sk(req_unhash), NULL); + /* Clone pktoptions received with SYN, if we own the req */ + if (*own_req && ireq->pktopts) { + newnp->pktoptions = skb_clone_and_charge_r(ireq->pktopts, newsk); + consume_skb(ireq->pktopts); + ireq->pktopts = NULL; + } + + return newsk; + +out_overflow: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); +out_nonewsk: + dst_release(dst); +out: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENDROPS); + return NULL; +} + +/* The socket must have it's spinlock held when we get + * here. + * + * We have a potential double-lock case here, so even when + * doing backlog processing we use the BH locking scheme. + * This is because we cannot sleep with the original spinlock + * held. + */ +static int dccp_v6_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct sk_buff *opt_skb = NULL; + + /* Imagine: socket is IPv6. IPv4 packet arrives, + goes to IPv4 receive handler and backlogged. + From backlog it always goes here. Kerboom... + Fortunately, dccp_rcv_established and rcv_established + handle them correctly, but it is not case with + dccp_v6_hnd_req and dccp_v6_ctl_send_reset(). --ANK + */ + + if (skb->protocol == htons(ETH_P_IP)) + return dccp_v4_do_rcv(sk, skb); + + if (sk_filter(sk, skb)) + goto discard; + + /* + * socket locking is here for SMP purposes as backlog rcv is currently + * called with bh processing disabled. + */ + + /* Do Stevens' IPV6_PKTOPTIONS. + + Yes, guys, it is the only place in our code, where we + may make it not affecting IPv4. + The rest of code is protocol independent, + and I do not like idea to uglify IPv4. + + Actually, all the idea behind IPV6_PKTOPTIONS + looks not very well thought. For now we latch + options, received in the last packet, enqueued + by tcp. Feel free to propose better solution. + --ANK (980728) + */ + if (np->rxopt.all) + opt_skb = skb_clone_and_charge_r(skb, sk); + + if (sk->sk_state == DCCP_OPEN) { /* Fast path */ + if (dccp_rcv_established(sk, skb, dccp_hdr(skb), skb->len)) + goto reset; + if (opt_skb) + goto ipv6_pktoptions; + return 0; + } + + /* + * Step 3: Process LISTEN state + * If S.state == LISTEN, + * If P.type == Request or P contains a valid Init Cookie option, + * (* Must scan the packet's options to check for Init + * Cookies. Only Init Cookies are processed here, + * however; other options are processed in Step 8. 
This + * scan need only be performed if the endpoint uses Init + * Cookies *) + * (* Generate a new socket and switch to that socket *) + * Set S := new socket for this port pair + * S.state = RESPOND + * Choose S.ISS (initial seqno) or set from Init Cookies + * Initialize S.GAR := S.ISS + * Set S.ISR, S.GSR, S.SWL, S.SWH from packet or Init Cookies + * Continue with S.state == RESPOND + * (* A Response packet will be generated in Step 11 *) + * Otherwise, + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + * + * NOTE: the check for the packet types is done in + * dccp_rcv_state_process + */ + + if (dccp_rcv_state_process(sk, skb, dccp_hdr(skb), skb->len)) + goto reset; + if (opt_skb) + goto ipv6_pktoptions; + return 0; + +reset: + dccp_v6_ctl_send_reset(sk, skb); +discard: + if (opt_skb != NULL) + __kfree_skb(opt_skb); + kfree_skb(skb); + return 0; + +/* Handling IPV6_PKTOPTIONS skb the similar + * way it's done for net/ipv6/tcp_ipv6.c + */ +ipv6_pktoptions: + if (!((1 << sk->sk_state) & (DCCPF_CLOSED | DCCPF_LISTEN))) { + if (np->rxopt.bits.rxinfo || np->rxopt.bits.rxoinfo) + np->mcast_oif = inet6_iif(opt_skb); + if (np->rxopt.bits.rxhlim || np->rxopt.bits.rxohlim) + np->mcast_hops = ipv6_hdr(opt_skb)->hop_limit; + if (np->rxopt.bits.rxflow || np->rxopt.bits.rxtclass) + np->rcv_flowinfo = ip6_flowinfo(ipv6_hdr(opt_skb)); + if (np->repflow) + np->flow_label = ip6_flowlabel(ipv6_hdr(opt_skb)); + if (ipv6_opt_accepted(sk, opt_skb, + &DCCP_SKB_CB(opt_skb)->header.h6)) { + memmove(IP6CB(opt_skb), + &DCCP_SKB_CB(opt_skb)->header.h6, + sizeof(struct inet6_skb_parm)); + opt_skb = xchg(&np->pktoptions, opt_skb); + } else { + __kfree_skb(opt_skb); + opt_skb = xchg(&np->pktoptions, NULL); + } + } + + kfree_skb(opt_skb); + return 0; +} + +static int dccp_v6_rcv(struct sk_buff *skb) +{ + const struct dccp_hdr *dh; + bool refcounted; + struct sock *sk; + int min_cov; + + /* Step 1: Check header basics */ + + if (dccp_invalid_packet(skb)) + goto discard_it; + + /* Step 1: If header checksum is incorrect, drop packet and return. */ + if (dccp_v6_csum_finish(skb, &ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr)) { + DCCP_WARN("dropped packet with invalid checksum\n"); + goto discard_it; + } + + dh = dccp_hdr(skb); + + DCCP_SKB_CB(skb)->dccpd_seq = dccp_hdr_seq(dh); + DCCP_SKB_CB(skb)->dccpd_type = dh->dccph_type; + + if (dccp_packet_without_ack(skb)) + DCCP_SKB_CB(skb)->dccpd_ack_seq = DCCP_PKT_WITHOUT_ACK_SEQ; + else + DCCP_SKB_CB(skb)->dccpd_ack_seq = dccp_hdr_ack_seq(skb); + +lookup: + sk = __inet6_lookup_skb(&dccp_hashinfo, skb, __dccp_hdr_len(dh), + dh->dccph_sport, dh->dccph_dport, + inet6_iif(skb), 0, &refcounted); + if (!sk) { + dccp_pr_debug("failed to look up flow ID in table and " + "get corresponding socket\n"); + goto no_dccp_socket; + } + + /* + * Step 2: + * ... 
or S.state == TIMEWAIT, + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + */ + if (sk->sk_state == DCCP_TIME_WAIT) { + dccp_pr_debug("sk->sk_state == DCCP_TIME_WAIT: do_time_wait\n"); + inet_twsk_put(inet_twsk(sk)); + goto no_dccp_socket; + } + + if (sk->sk_state == DCCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + struct sock *nsk; + + sk = req->rsk_listener; + if (unlikely(sk->sk_state != DCCP_LISTEN)) { + inet_csk_reqsk_queue_drop_and_put(sk, req); + goto lookup; + } + sock_hold(sk); + refcounted = true; + nsk = dccp_check_req(sk, skb, req); + if (!nsk) { + reqsk_put(req); + goto discard_and_relse; + } + if (nsk == sk) { + reqsk_put(req); + } else if (dccp_child_process(sk, nsk, skb)) { + dccp_v6_ctl_send_reset(sk, skb); + goto discard_and_relse; + } else { + sock_put(sk); + return 0; + } + } + /* + * RFC 4340, sec. 9.2.1: Minimum Checksum Coverage + * o if MinCsCov = 0, only packets with CsCov = 0 are accepted + * o if MinCsCov > 0, also accept packets with CsCov >= MinCsCov + */ + min_cov = dccp_sk(sk)->dccps_pcrlen; + if (dh->dccph_cscov && (min_cov == 0 || dh->dccph_cscov < min_cov)) { + dccp_pr_debug("Packet CsCov %d does not satisfy MinCsCov %d\n", + dh->dccph_cscov, min_cov); + /* FIXME: send Data Dropped option (see also dccp_v4_rcv) */ + goto discard_and_relse; + } + + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) + goto discard_and_relse; + nf_reset_ct(skb); + + return __sk_receive_skb(sk, skb, 1, dh->dccph_doff * 4, + refcounted) ? -1 : 0; + +no_dccp_socket: + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto discard_it; + /* + * Step 2: + * If no socket ... + * Generate Reset(No Connection) unless P.type == Reset + * Drop packet and return + */ + if (dh->dccph_type != DCCP_PKT_RESET) { + DCCP_SKB_CB(skb)->dccpd_reset_code = + DCCP_RESET_CODE_NO_CONNECTION; + dccp_v6_ctl_send_reset(sk, skb); + } + +discard_it: + kfree_skb(skb); + return 0; + +discard_and_relse: + if (refcounted) + sock_put(sk); + goto discard_it; +} + +static int dccp_v6_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + struct sockaddr_in6 *usin = (struct sockaddr_in6 *)uaddr; + struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct dccp_sock *dp = dccp_sk(sk); + struct in6_addr *saddr = NULL, *final_p, final; + struct ipv6_txoptions *opt; + struct flowi6 fl6; + struct dst_entry *dst; + int addr_type; + int err; + + dp->dccps_role = DCCP_ROLE_CLIENT; + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (usin->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + memset(&fl6, 0, sizeof(fl6)); + + if (np->sndflow) { + fl6.flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK; + IP6_ECN_flow_init(fl6.flowlabel); + if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) { + struct ip6_flowlabel *flowlabel; + flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + fl6_sock_release(flowlabel); + } + } + /* + * connect() to INADDR_ANY means loopback (BSD'ism). + */ + if (ipv6_addr_any(&usin->sin6_addr)) + usin->sin6_addr.s6_addr[15] = 1; + + addr_type = ipv6_addr_type(&usin->sin6_addr); + + if (addr_type & IPV6_ADDR_MULTICAST) + return -ENETUNREACH; + + if (addr_type & IPV6_ADDR_LINKLOCAL) { + if (addr_len >= sizeof(struct sockaddr_in6) && + usin->sin6_scope_id) { + /* If interface is set while binding, indices + * must coincide. 
+ */ + if (sk->sk_bound_dev_if && + sk->sk_bound_dev_if != usin->sin6_scope_id) + return -EINVAL; + + sk->sk_bound_dev_if = usin->sin6_scope_id; + } + + /* Connect to link-local address requires an interface */ + if (!sk->sk_bound_dev_if) + return -EINVAL; + } + + sk->sk_v6_daddr = usin->sin6_addr; + np->flow_label = fl6.flowlabel; + + /* + * DCCP over IPv4 + */ + if (addr_type == IPV6_ADDR_MAPPED) { + u32 exthdrlen = icsk->icsk_ext_hdr_len; + struct sockaddr_in sin; + + SOCK_DEBUG(sk, "connect: ipv4 mapped\n"); + + if (ipv6_only_sock(sk)) + return -ENETUNREACH; + + sin.sin_family = AF_INET; + sin.sin_port = usin->sin6_port; + sin.sin_addr.s_addr = usin->sin6_addr.s6_addr32[3]; + + icsk->icsk_af_ops = &dccp_ipv6_mapped; + sk->sk_backlog_rcv = dccp_v4_do_rcv; + + err = dccp_v4_connect(sk, (struct sockaddr *)&sin, sizeof(sin)); + if (err) { + icsk->icsk_ext_hdr_len = exthdrlen; + icsk->icsk_af_ops = &dccp_ipv6_af_ops; + sk->sk_backlog_rcv = dccp_v6_do_rcv; + goto failure; + } + np->saddr = sk->sk_v6_rcv_saddr; + return err; + } + + if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + saddr = &sk->sk_v6_rcv_saddr; + + fl6.flowi6_proto = IPPROTO_DCCP; + fl6.daddr = sk->sk_v6_daddr; + fl6.saddr = saddr ? *saddr : np->saddr; + fl6.flowi6_oif = sk->sk_bound_dev_if; + fl6.fl6_dport = usin->sin6_port; + fl6.fl6_sport = inet->inet_sport; + security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6)); + + opt = rcu_dereference_protected(np->opt, lockdep_sock_is_held(sk)); + final_p = fl6_update_dst(&fl6, opt, &final); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto failure; + } + + if (saddr == NULL) { + saddr = &fl6.saddr; + + err = inet_bhash2_update_saddr(sk, saddr, AF_INET6); + if (err) + goto failure; + } + + /* set the source address */ + np->saddr = *saddr; + inet->inet_rcv_saddr = LOOPBACK4_IPV6; + + ip6_dst_store(sk, dst, NULL, NULL); + + icsk->icsk_ext_hdr_len = 0; + if (opt) + icsk->icsk_ext_hdr_len = opt->opt_flen + opt->opt_nflen; + + inet->inet_dport = usin->sin6_port; + + dccp_set_state(sk, DCCP_REQUESTING); + err = inet6_hash_connect(&dccp_death_row, sk); + if (err) + goto late_failure; + + dp->dccps_iss = secure_dccpv6_sequence_number(np->saddr.s6_addr32, + sk->sk_v6_daddr.s6_addr32, + inet->inet_sport, + inet->inet_dport); + err = dccp_connect(sk); + if (err) + goto late_failure; + + return 0; + +late_failure: + dccp_set_state(sk, DCCP_CLOSED); + inet_bhash2_reset_saddr(sk); + __sk_dst_reset(sk); +failure: + inet->inet_dport = 0; + sk->sk_route_caps = 0; + return err; +} + +static const struct inet_connection_sock_af_ops dccp_ipv6_af_ops = { + .queue_xmit = inet6_csk_xmit, + .send_check = dccp_v6_send_check, + .rebuild_header = inet6_sk_rebuild_header, + .conn_request = dccp_v6_conn_request, + .syn_recv_sock = dccp_v6_request_recv_sock, + .net_header_len = sizeof(struct ipv6hdr), + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .addr2sockaddr = inet6_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct sockaddr_in6), +}; + +/* + * DCCP over IPv4 via INET6 API + */ +static const struct inet_connection_sock_af_ops dccp_ipv6_mapped = { + .queue_xmit = ip_queue_xmit, + .send_check = dccp_v4_send_check, + .rebuild_header = inet_sk_rebuild_header, + .conn_request = dccp_v6_conn_request, + .syn_recv_sock = dccp_v6_request_recv_sock, + .net_header_len = sizeof(struct iphdr), + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .addr2sockaddr = inet6_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct 
sockaddr_in6), +}; + +static void dccp_v6_sk_destruct(struct sock *sk) +{ + dccp_destruct_common(sk); + inet6_sock_destruct(sk); +} + +/* NOTE: A lot of things set to zero explicitly by call to + * sk_alloc() so need not be done here. + */ +static int dccp_v6_init_sock(struct sock *sk) +{ + static __u8 dccp_v6_ctl_sock_initialized; + int err = dccp_init_sock(sk, dccp_v6_ctl_sock_initialized); + + if (err == 0) { + if (unlikely(!dccp_v6_ctl_sock_initialized)) + dccp_v6_ctl_sock_initialized = 1; + inet_csk(sk)->icsk_af_ops = &dccp_ipv6_af_ops; + sk->sk_destruct = dccp_v6_sk_destruct; + } + + return err; +} + +static struct timewait_sock_ops dccp6_timewait_sock_ops = { + .twsk_obj_size = sizeof(struct dccp6_timewait_sock), +}; + +static struct proto dccp_v6_prot = { + .name = "DCCPv6", + .owner = THIS_MODULE, + .close = dccp_close, + .connect = dccp_v6_connect, + .disconnect = dccp_disconnect, + .ioctl = dccp_ioctl, + .init = dccp_v6_init_sock, + .setsockopt = dccp_setsockopt, + .getsockopt = dccp_getsockopt, + .sendmsg = dccp_sendmsg, + .recvmsg = dccp_recvmsg, + .backlog_rcv = dccp_v6_do_rcv, + .hash = inet6_hash, + .unhash = inet_unhash, + .accept = inet_csk_accept, + .get_port = inet_csk_get_port, + .shutdown = dccp_shutdown, + .destroy = dccp_destroy_sock, + .orphan_count = &dccp_orphan_count, + .max_header = MAX_DCCP_HEADER, + .obj_size = sizeof(struct dccp6_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, + .rsk_prot = &dccp6_request_sock_ops, + .twsk_prot = &dccp6_timewait_sock_ops, + .h.hashinfo = &dccp_hashinfo, +}; + +static const struct inet6_protocol dccp_v6_protocol = { + .handler = dccp_v6_rcv, + .err_handler = dccp_v6_err, + .flags = INET6_PROTO_NOPOLICY | INET6_PROTO_FINAL, +}; + +static const struct proto_ops inet6_dccp_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = inet_stream_connect, + .socketpair = sock_no_socketpair, + .accept = inet_accept, + .getname = inet6_getname, + .poll = dccp_poll, + .ioctl = inet6_ioctl, + .gettstamp = sock_gettstamp, + .listen = inet_dccp_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static struct inet_protosw dccp_v6_protosw = { + .type = SOCK_DCCP, + .protocol = IPPROTO_DCCP, + .prot = &dccp_v6_prot, + .ops = &inet6_dccp_ops, + .flags = INET_PROTOSW_ICSK, +}; + +static int __net_init dccp_v6_init_net(struct net *net) +{ + struct dccp_v6_pernet *pn = net_generic(net, dccp_v6_pernet_id); + + if (dccp_hashinfo.bhash == NULL) + return -ESOCKTNOSUPPORT; + + return inet_ctl_sock_create(&pn->v6_ctl_sk, PF_INET6, + SOCK_DCCP, IPPROTO_DCCP, net); +} + +static void __net_exit dccp_v6_exit_net(struct net *net) +{ + struct dccp_v6_pernet *pn = net_generic(net, dccp_v6_pernet_id); + + inet_ctl_sock_destroy(pn->v6_ctl_sk); +} + +static void __net_exit dccp_v6_exit_batch(struct list_head *net_exit_list) +{ + inet_twsk_purge(&dccp_hashinfo, AF_INET6); +} + +static struct pernet_operations dccp_v6_ops = { + .init = dccp_v6_init_net, + .exit = dccp_v6_exit_net, + .exit_batch = dccp_v6_exit_batch, + .id = &dccp_v6_pernet_id, + .size = sizeof(struct dccp_v6_pernet), +}; + +static int __init dccp_v6_init(void) +{ + int err = proto_register(&dccp_v6_prot, 1); + + if (err) + goto out; + + inet6_register_protosw(&dccp_v6_protosw); + + err = 
register_pernet_subsys(&dccp_v6_ops); + if (err) + goto out_destroy_ctl_sock; + + err = inet6_add_protocol(&dccp_v6_protocol, IPPROTO_DCCP); + if (err) + goto out_unregister_proto; + +out: + return err; +out_unregister_proto: + unregister_pernet_subsys(&dccp_v6_ops); +out_destroy_ctl_sock: + inet6_unregister_protosw(&dccp_v6_protosw); + proto_unregister(&dccp_v6_prot); + goto out; +} + +static void __exit dccp_v6_exit(void) +{ + inet6_del_protocol(&dccp_v6_protocol, IPPROTO_DCCP); + unregister_pernet_subsys(&dccp_v6_ops); + inet6_unregister_protosw(&dccp_v6_protosw); + proto_unregister(&dccp_v6_prot); +} + +module_init(dccp_v6_init); +module_exit(dccp_v6_exit); + +/* + * __stringify doesn't likes enums, so use SOCK_DCCP (6) and IPPROTO_DCCP (33) + * values directly, Also cover the case where the protocol is not specified, + * i.e. net-pf-PF_INET6-proto-0-type-SOCK_DCCP + */ +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET6, 33, 6); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET6, 0, 6); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arnaldo Carvalho de Melo "); +MODULE_DESCRIPTION("DCCPv6 - Datagram Congestion Controlled Protocol"); diff --git a/net/dccp/ipv6.h b/net/dccp/ipv6.h new file mode 100644 index 000000000..7e4c2a3b3 --- /dev/null +++ b/net/dccp/ipv6.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _DCCP_IPV6_H +#define _DCCP_IPV6_H +/* + * net/dccp/ipv6.h + * + * An implementation of the DCCP protocol + * Copyright (c) 2005 Arnaldo Carvalho de Melo + */ + +#include +#include + +struct dccp6_sock { + struct dccp_sock dccp; + /* + * ipv6_pinfo has to be the last member of dccp6_sock, + * see inet6_sk_generic. + */ + struct ipv6_pinfo inet6; +}; + +struct dccp6_request_sock { + struct dccp_request_sock dccp; +}; + +struct dccp6_timewait_sock { + struct inet_timewait_sock inet; +}; + +#endif /* _DCCP_IPV6_H */ diff --git a/net/dccp/minisocks.c b/net/dccp/minisocks.c new file mode 100644 index 000000000..64d805b27 --- /dev/null +++ b/net/dccp/minisocks.c @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/minisocks.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ackvec.h" +#include "ccid.h" +#include "dccp.h" +#include "feat.h" + +struct inet_timewait_death_row dccp_death_row = { + .tw_refcount = REFCOUNT_INIT(1), + .sysctl_max_tw_buckets = NR_FILE * 2, + .hashinfo = &dccp_hashinfo, +}; + +EXPORT_SYMBOL_GPL(dccp_death_row); + +void dccp_time_wait(struct sock *sk, int state, int timeo) +{ + struct inet_timewait_sock *tw; + + tw = inet_twsk_alloc(sk, &dccp_death_row, state); + + if (tw != NULL) { + const struct inet_connection_sock *icsk = inet_csk(sk); + const int rto = (icsk->icsk_rto << 2) - (icsk->icsk_rto >> 1); +#if IS_ENABLED(CONFIG_IPV6) + if (tw->tw_family == PF_INET6) { + tw->tw_v6_daddr = sk->sk_v6_daddr; + tw->tw_v6_rcv_saddr = sk->sk_v6_rcv_saddr; + tw->tw_ipv6only = sk->sk_ipv6only; + } +#endif + + /* Get the TIME_WAIT timeout firing. */ + if (timeo < rto) + timeo = rto; + + if (state == DCCP_TIME_WAIT) + timeo = DCCP_TIMEWAIT_LEN; + + /* tw_timer is pinned, so we need to make sure BH are disabled + * in following section, otherwise timer handler could run before + * we complete the initialization. + */ + local_bh_disable(); + inet_twsk_schedule(tw, timeo); + /* Linkage updates. + * Note that access to tw after this point is illegal. 
+ */ + inet_twsk_hashdance(tw, sk, &dccp_hashinfo); + local_bh_enable(); + } else { + /* Sorry, if we're out of memory, just CLOSE this + * socket up. We've got bigger problems than + * non-graceful socket closings. + */ + DCCP_WARN("time wait bucket table overflow\n"); + } + + dccp_done(sk); +} + +struct sock *dccp_create_openreq_child(const struct sock *sk, + const struct request_sock *req, + const struct sk_buff *skb) +{ + /* + * Step 3: Process LISTEN state + * + * (* Generate a new socket and switch to that socket *) + * Set S := new socket for this port pair + */ + struct sock *newsk = inet_csk_clone_lock(sk, req, GFP_ATOMIC); + + if (newsk != NULL) { + struct dccp_request_sock *dreq = dccp_rsk(req); + struct inet_connection_sock *newicsk = inet_csk(newsk); + struct dccp_sock *newdp = dccp_sk(newsk); + + newdp->dccps_role = DCCP_ROLE_SERVER; + newdp->dccps_hc_rx_ackvec = NULL; + newdp->dccps_service_list = NULL; + newdp->dccps_hc_rx_ccid = NULL; + newdp->dccps_hc_tx_ccid = NULL; + newdp->dccps_service = dreq->dreq_service; + newdp->dccps_timestamp_echo = dreq->dreq_timestamp_echo; + newdp->dccps_timestamp_time = dreq->dreq_timestamp_time; + newicsk->icsk_rto = DCCP_TIMEOUT_INIT; + + INIT_LIST_HEAD(&newdp->dccps_featneg); + /* + * Step 3: Process LISTEN state + * + * Choose S.ISS (initial seqno) or set from Init Cookies + * Initialize S.GAR := S.ISS + * Set S.ISR, S.GSR from packet (or Init Cookies) + * + * Setting AWL/AWH and SWL/SWH happens as part of the feature + * activation below, as these windows all depend on the local + * and remote Sequence Window feature values (7.5.2). + */ + newdp->dccps_iss = dreq->dreq_iss; + newdp->dccps_gss = dreq->dreq_gss; + newdp->dccps_gar = newdp->dccps_iss; + newdp->dccps_isr = dreq->dreq_isr; + newdp->dccps_gsr = dreq->dreq_gsr; + + /* + * Activate features: initialise CCIDs, sequence windows etc. + */ + if (dccp_feat_activate_values(newsk, &dreq->dreq_featneg)) { + sk_free_unlock_clone(newsk); + return NULL; + } + dccp_init_xmit_timers(newsk); + + __DCCP_INC_STATS(DCCP_MIB_PASSIVEOPENS); + } + return newsk; +} + +EXPORT_SYMBOL_GPL(dccp_create_openreq_child); + +/* + * Process an incoming packet for RESPOND sockets represented + * as an request_sock. + */ +struct sock *dccp_check_req(struct sock *sk, struct sk_buff *skb, + struct request_sock *req) +{ + struct sock *child = NULL; + struct dccp_request_sock *dreq = dccp_rsk(req); + bool own_req; + + /* TCP/DCCP listeners became lockless. + * DCCP stores complex state in its request_sock, so we need + * a protection for them, now this code runs without being protected + * by the parent (listener) lock. + */ + spin_lock_bh(&dreq->dreq_lock); + + /* Check for retransmitted REQUEST */ + if (dccp_hdr(skb)->dccph_type == DCCP_PKT_REQUEST) { + + if (after48(DCCP_SKB_CB(skb)->dccpd_seq, dreq->dreq_gsr)) { + dccp_pr_debug("Retransmitted REQUEST\n"); + dreq->dreq_gsr = DCCP_SKB_CB(skb)->dccpd_seq; + /* + * Send another RESPONSE packet + * To protect against Request floods, increment retrans + * counter (backoff, monitored by dccp_response_timer). 
+ */ + inet_rtx_syn_ack(sk, req); + } + /* Network Duplicate, discard packet */ + goto out; + } + + DCCP_SKB_CB(skb)->dccpd_reset_code = DCCP_RESET_CODE_PACKET_ERROR; + + if (dccp_hdr(skb)->dccph_type != DCCP_PKT_ACK && + dccp_hdr(skb)->dccph_type != DCCP_PKT_DATAACK) + goto drop; + + /* Invalid ACK */ + if (!between48(DCCP_SKB_CB(skb)->dccpd_ack_seq, + dreq->dreq_iss, dreq->dreq_gss)) { + dccp_pr_debug("Invalid ACK number: ack_seq=%llu, " + "dreq_iss=%llu, dreq_gss=%llu\n", + (unsigned long long) + DCCP_SKB_CB(skb)->dccpd_ack_seq, + (unsigned long long) dreq->dreq_iss, + (unsigned long long) dreq->dreq_gss); + goto drop; + } + + if (dccp_parse_options(sk, dreq, skb)) + goto drop; + + child = inet_csk(sk)->icsk_af_ops->syn_recv_sock(sk, skb, req, NULL, + req, &own_req); + if (child) { + child = inet_csk_complete_hashdance(sk, child, req, own_req); + goto out; + } + + DCCP_SKB_CB(skb)->dccpd_reset_code = DCCP_RESET_CODE_TOO_BUSY; +drop: + if (dccp_hdr(skb)->dccph_type != DCCP_PKT_RESET) + req->rsk_ops->send_reset(sk, skb); + + inet_csk_reqsk_queue_drop(sk, req); +out: + spin_unlock_bh(&dreq->dreq_lock); + return child; +} + +EXPORT_SYMBOL_GPL(dccp_check_req); + +/* + * Queue segment on the new socket if the new socket is active, + * otherwise we just shortcircuit this and continue with + * the new socket. + */ +int dccp_child_process(struct sock *parent, struct sock *child, + struct sk_buff *skb) + __releases(child) +{ + int ret = 0; + const int state = child->sk_state; + + if (!sock_owned_by_user(child)) { + ret = dccp_rcv_state_process(child, skb, dccp_hdr(skb), + skb->len); + + /* Wakeup parent, send SIGIO */ + if (state == DCCP_RESPOND && child->sk_state != state) + parent->sk_data_ready(parent); + } else { + /* Alas, it is possible again, because we do lookup + * in main socket hash table and lock on listening + * socket does not protect us more. 
+ */ + __sk_add_backlog(child, skb); + } + + bh_unlock_sock(child); + sock_put(child); + return ret; +} + +EXPORT_SYMBOL_GPL(dccp_child_process); + +void dccp_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, + struct request_sock *rsk) +{ + DCCP_BUG("DCCP-ACK packets are never sent in LISTEN/RESPOND state"); +} + +EXPORT_SYMBOL_GPL(dccp_reqsk_send_ack); + +int dccp_reqsk_init(struct request_sock *req, + struct dccp_sock const *dp, struct sk_buff const *skb) +{ + struct dccp_request_sock *dreq = dccp_rsk(req); + + spin_lock_init(&dreq->dreq_lock); + inet_rsk(req)->ir_rmt_port = dccp_hdr(skb)->dccph_sport; + inet_rsk(req)->ir_num = ntohs(dccp_hdr(skb)->dccph_dport); + inet_rsk(req)->acked = 0; + dreq->dreq_timestamp_echo = 0; + + /* inherit feature negotiation options from listening socket */ + return dccp_feat_clone_list(&dp->dccps_featneg, &dreq->dreq_featneg); +} + +EXPORT_SYMBOL_GPL(dccp_reqsk_init); diff --git a/net/dccp/options.c b/net/dccp/options.c new file mode 100644 index 000000000..d24cad050 --- /dev/null +++ b/net/dccp/options.c @@ -0,0 +1,609 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/options.c + * + * An implementation of the DCCP protocol + * Copyright (c) 2005 Aristeu Sergio Rozanski Filho + * Copyright (c) 2005 Arnaldo Carvalho de Melo + * Copyright (c) 2005 Ian McDonald + */ +#include +#include +#include +#include +#include +#include + +#include "ackvec.h" +#include "ccid.h" +#include "dccp.h" +#include "feat.h" + +u64 dccp_decode_value_var(const u8 *bf, const u8 len) +{ + u64 value = 0; + + if (len >= DCCP_OPTVAL_MAXLEN) + value += ((u64)*bf++) << 40; + if (len > 4) + value += ((u64)*bf++) << 32; + if (len > 3) + value += ((u64)*bf++) << 24; + if (len > 2) + value += ((u64)*bf++) << 16; + if (len > 1) + value += ((u64)*bf++) << 8; + if (len > 0) + value += *bf; + + return value; +} + +/** + * dccp_parse_options - Parse DCCP options present in @skb + * @sk: client|server|listening dccp socket (when @dreq != NULL) + * @dreq: request socket to use during connection setup, or NULL + * @skb: frame to parse + */ +int dccp_parse_options(struct sock *sk, struct dccp_request_sock *dreq, + struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + const struct dccp_hdr *dh = dccp_hdr(skb); + const u8 pkt_type = DCCP_SKB_CB(skb)->dccpd_type; + unsigned char *options = (unsigned char *)dh + dccp_hdr_len(skb); + unsigned char *opt_ptr = options; + const unsigned char *opt_end = (unsigned char *)dh + + (dh->dccph_doff * 4); + struct dccp_options_received *opt_recv = &dp->dccps_options_received; + unsigned char opt, len; + unsigned char *value; + u32 elapsed_time; + __be32 opt_val; + int rc; + int mandatory = 0; + + memset(opt_recv, 0, sizeof(*opt_recv)); + + opt = len = 0; + while (opt_ptr != opt_end) { + opt = *opt_ptr++; + len = 0; + value = NULL; + + /* Check if this isn't a single byte option */ + if (opt > DCCPO_MAX_RESERVED) { + if (opt_ptr == opt_end) + goto out_nonsensical_length; + + len = *opt_ptr++; + if (len < 2) + goto out_nonsensical_length; + /* + * Remove the type and len fields, leaving + * just the value size + */ + len -= 2; + value = opt_ptr; + opt_ptr += len; + + if (opt_ptr > opt_end) + goto out_nonsensical_length; + } + + /* + * CCID-specific options are ignored during connection setup, as + * negotiation may still be in progress (see RFC 4340, 10.3). + * The same applies to Ack Vectors, as these depend on the CCID. 
+ */ + if (dreq != NULL && (opt >= DCCPO_MIN_RX_CCID_SPECIFIC || + opt == DCCPO_ACK_VECTOR_0 || opt == DCCPO_ACK_VECTOR_1)) + goto ignore_option; + + switch (opt) { + case DCCPO_PADDING: + break; + case DCCPO_MANDATORY: + if (mandatory) + goto out_invalid_option; + if (pkt_type != DCCP_PKT_DATA) + mandatory = 1; + break; + case DCCPO_NDP_COUNT: + if (len > 6) + goto out_invalid_option; + + opt_recv->dccpor_ndp = dccp_decode_value_var(value, len); + dccp_pr_debug("%s opt: NDP count=%llu\n", dccp_role(sk), + (unsigned long long)opt_recv->dccpor_ndp); + break; + case DCCPO_CHANGE_L ... DCCPO_CONFIRM_R: + if (pkt_type == DCCP_PKT_DATA) /* RFC 4340, 6 */ + break; + if (len == 0) + goto out_invalid_option; + rc = dccp_feat_parse_options(sk, dreq, mandatory, opt, + *value, value + 1, len - 1); + if (rc) + goto out_featneg_failed; + break; + case DCCPO_TIMESTAMP: + if (len != 4) + goto out_invalid_option; + /* + * RFC 4340 13.1: "The precise time corresponding to + * Timestamp Value zero is not specified". We use + * zero to indicate absence of a meaningful timestamp. + */ + opt_val = get_unaligned((__be32 *)value); + if (unlikely(opt_val == 0)) { + DCCP_WARN("Timestamp with zero value\n"); + break; + } + + if (dreq != NULL) { + dreq->dreq_timestamp_echo = ntohl(opt_val); + dreq->dreq_timestamp_time = dccp_timestamp(); + } else { + opt_recv->dccpor_timestamp = + dp->dccps_timestamp_echo = ntohl(opt_val); + dp->dccps_timestamp_time = dccp_timestamp(); + } + dccp_pr_debug("%s rx opt: TIMESTAMP=%u, ackno=%llu\n", + dccp_role(sk), ntohl(opt_val), + (unsigned long long) + DCCP_SKB_CB(skb)->dccpd_ack_seq); + /* schedule an Ack in case this sender is quiescent */ + inet_csk_schedule_ack(sk); + break; + case DCCPO_TIMESTAMP_ECHO: + if (len != 4 && len != 6 && len != 8) + goto out_invalid_option; + + opt_val = get_unaligned((__be32 *)value); + opt_recv->dccpor_timestamp_echo = ntohl(opt_val); + + dccp_pr_debug("%s rx opt: TIMESTAMP_ECHO=%u, len=%d, " + "ackno=%llu", dccp_role(sk), + opt_recv->dccpor_timestamp_echo, + len + 2, + (unsigned long long) + DCCP_SKB_CB(skb)->dccpd_ack_seq); + + value += 4; + + if (len == 4) { /* no elapsed time included */ + dccp_pr_debug_cat("\n"); + break; + } + + if (len == 6) { /* 2-byte elapsed time */ + __be16 opt_val2 = get_unaligned((__be16 *)value); + elapsed_time = ntohs(opt_val2); + } else { /* 4-byte elapsed time */ + opt_val = get_unaligned((__be32 *)value); + elapsed_time = ntohl(opt_val); + } + + dccp_pr_debug_cat(", ELAPSED_TIME=%u\n", elapsed_time); + + /* Give precedence to the biggest ELAPSED_TIME */ + if (elapsed_time > opt_recv->dccpor_elapsed_time) + opt_recv->dccpor_elapsed_time = elapsed_time; + break; + case DCCPO_ELAPSED_TIME: + if (dccp_packet_without_ack(skb)) /* RFC 4340, 13.2 */ + break; + + if (len == 2) { + __be16 opt_val2 = get_unaligned((__be16 *)value); + elapsed_time = ntohs(opt_val2); + } else if (len == 4) { + opt_val = get_unaligned((__be32 *)value); + elapsed_time = ntohl(opt_val); + } else { + goto out_invalid_option; + } + + if (elapsed_time > opt_recv->dccpor_elapsed_time) + opt_recv->dccpor_elapsed_time = elapsed_time; + + dccp_pr_debug("%s rx opt: ELAPSED_TIME=%d\n", + dccp_role(sk), elapsed_time); + break; + case DCCPO_MIN_RX_CCID_SPECIFIC ... 
DCCPO_MAX_RX_CCID_SPECIFIC: + if (ccid_hc_rx_parse_options(dp->dccps_hc_rx_ccid, sk, + pkt_type, opt, value, len)) + goto out_invalid_option; + break; + case DCCPO_ACK_VECTOR_0: + case DCCPO_ACK_VECTOR_1: + if (dccp_packet_without_ack(skb)) /* RFC 4340, 11.4 */ + break; + /* + * Ack vectors are processed by the TX CCID if it is + * interested. The RX CCID need not parse Ack Vectors, + * since it is only interested in clearing old state. + */ + fallthrough; + case DCCPO_MIN_TX_CCID_SPECIFIC ... DCCPO_MAX_TX_CCID_SPECIFIC: + if (ccid_hc_tx_parse_options(dp->dccps_hc_tx_ccid, sk, + pkt_type, opt, value, len)) + goto out_invalid_option; + break; + default: + DCCP_CRIT("DCCP(%p): option %d(len=%d) not " + "implemented, ignoring", sk, opt, len); + break; + } +ignore_option: + if (opt != DCCPO_MANDATORY) + mandatory = 0; + } + + /* mandatory was the last byte in option list -> reset connection */ + if (mandatory) + goto out_invalid_option; + +out_nonsensical_length: + /* RFC 4340, 5.8: ignore option and all remaining option space */ + return 0; + +out_invalid_option: + DCCP_INC_STATS(DCCP_MIB_INVALIDOPT); + rc = DCCP_RESET_CODE_OPTION_ERROR; +out_featneg_failed: + DCCP_WARN("DCCP(%p): Option %d (len=%d) error=%u\n", sk, opt, len, rc); + DCCP_SKB_CB(skb)->dccpd_reset_code = rc; + DCCP_SKB_CB(skb)->dccpd_reset_data[0] = opt; + DCCP_SKB_CB(skb)->dccpd_reset_data[1] = len > 0 ? value[0] : 0; + DCCP_SKB_CB(skb)->dccpd_reset_data[2] = len > 1 ? value[1] : 0; + return -1; +} + +EXPORT_SYMBOL_GPL(dccp_parse_options); + +void dccp_encode_value_var(const u64 value, u8 *to, const u8 len) +{ + if (len >= DCCP_OPTVAL_MAXLEN) + *to++ = (value & 0xFF0000000000ull) >> 40; + if (len > 4) + *to++ = (value & 0xFF00000000ull) >> 32; + if (len > 3) + *to++ = (value & 0xFF000000) >> 24; + if (len > 2) + *to++ = (value & 0xFF0000) >> 16; + if (len > 1) + *to++ = (value & 0xFF00) >> 8; + if (len > 0) + *to++ = (value & 0xFF); +} + +static inline u8 dccp_ndp_len(const u64 ndp) +{ + if (likely(ndp <= 0xFF)) + return 1; + return likely(ndp <= USHRT_MAX) ? 2 : (ndp <= UINT_MAX ? 4 : 6); +} + +int dccp_insert_option(struct sk_buff *skb, const unsigned char option, + const void *value, const unsigned char len) +{ + unsigned char *to; + + if (DCCP_SKB_CB(skb)->dccpd_opt_len + len + 2 > DCCP_MAX_OPT_LEN) + return -1; + + DCCP_SKB_CB(skb)->dccpd_opt_len += len + 2; + + to = skb_push(skb, len + 2); + *to++ = option; + *to++ = len + 2; + + memcpy(to, value, len); + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_insert_option); + +static int dccp_insert_option_ndp(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + u64 ndp = dp->dccps_ndp_count; + + if (dccp_non_data_packet(skb)) + ++dp->dccps_ndp_count; + else + dp->dccps_ndp_count = 0; + + if (ndp > 0) { + unsigned char *ptr; + const int ndp_len = dccp_ndp_len(ndp); + const int len = ndp_len + 2; + + if (DCCP_SKB_CB(skb)->dccpd_opt_len + len > DCCP_MAX_OPT_LEN) + return -1; + + DCCP_SKB_CB(skb)->dccpd_opt_len += len; + + ptr = skb_push(skb, len); + *ptr++ = DCCPO_NDP_COUNT; + *ptr++ = len; + dccp_encode_value_var(ndp, ptr, ndp_len); + } + + return 0; +} + +static inline int dccp_elapsed_time_len(const u32 elapsed_time) +{ + return elapsed_time == 0 ? 0 : elapsed_time <= 0xFFFF ? 
2 : 4; +} + +static int dccp_insert_option_timestamp(struct sk_buff *skb) +{ + __be32 now = htonl(dccp_timestamp()); + /* yes this will overflow but that is the point as we want a + * 10 usec 32 bit timer which mean it wraps every 11.9 hours */ + + return dccp_insert_option(skb, DCCPO_TIMESTAMP, &now, sizeof(now)); +} + +static int dccp_insert_option_timestamp_echo(struct dccp_sock *dp, + struct dccp_request_sock *dreq, + struct sk_buff *skb) +{ + __be32 tstamp_echo; + unsigned char *to; + u32 elapsed_time, elapsed_time_len, len; + + if (dreq != NULL) { + elapsed_time = dccp_timestamp() - dreq->dreq_timestamp_time; + tstamp_echo = htonl(dreq->dreq_timestamp_echo); + dreq->dreq_timestamp_echo = 0; + } else { + elapsed_time = dccp_timestamp() - dp->dccps_timestamp_time; + tstamp_echo = htonl(dp->dccps_timestamp_echo); + dp->dccps_timestamp_echo = 0; + } + + elapsed_time_len = dccp_elapsed_time_len(elapsed_time); + len = 6 + elapsed_time_len; + + if (DCCP_SKB_CB(skb)->dccpd_opt_len + len > DCCP_MAX_OPT_LEN) + return -1; + + DCCP_SKB_CB(skb)->dccpd_opt_len += len; + + to = skb_push(skb, len); + *to++ = DCCPO_TIMESTAMP_ECHO; + *to++ = len; + + memcpy(to, &tstamp_echo, 4); + to += 4; + + if (elapsed_time_len == 2) { + const __be16 var16 = htons((u16)elapsed_time); + memcpy(to, &var16, 2); + } else if (elapsed_time_len == 4) { + const __be32 var32 = htonl(elapsed_time); + memcpy(to, &var32, 4); + } + + return 0; +} + +static int dccp_insert_option_ackvec(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_ackvec *av = dp->dccps_hc_rx_ackvec; + struct dccp_skb_cb *dcb = DCCP_SKB_CB(skb); + const u16 buflen = dccp_ackvec_buflen(av); + /* Figure out how many options do we need to represent the ackvec */ + const u8 nr_opts = DIV_ROUND_UP(buflen, DCCP_SINGLE_OPT_MAXLEN); + u16 len = buflen + 2 * nr_opts; + u8 i, nonce = 0; + const unsigned char *tail, *from; + unsigned char *to; + + if (dcb->dccpd_opt_len + len > DCCP_MAX_OPT_LEN) { + DCCP_WARN("Lacking space for %u bytes on %s packet\n", len, + dccp_packet_name(dcb->dccpd_type)); + return -1; + } + /* + * Since Ack Vectors are variable-length, we can not always predict + * their size. To catch exception cases where the space is running out + * on the skb, a separate Sync is scheduled to carry the Ack Vector. + */ + if (len > DCCPAV_MIN_OPTLEN && + len + dcb->dccpd_opt_len + skb->len > dp->dccps_mss_cache) { + DCCP_WARN("No space left for Ack Vector (%u) on skb (%u+%u), " + "MPS=%u ==> reduce payload size?\n", len, skb->len, + dcb->dccpd_opt_len, dp->dccps_mss_cache); + dp->dccps_sync_scheduled = 1; + return 0; + } + dcb->dccpd_opt_len += len; + + to = skb_push(skb, len); + len = buflen; + from = av->av_buf + av->av_buf_head; + tail = av->av_buf + DCCPAV_MAX_ACKVEC_LEN; + + for (i = 0; i < nr_opts; ++i) { + int copylen = len; + + if (len > DCCP_SINGLE_OPT_MAXLEN) + copylen = DCCP_SINGLE_OPT_MAXLEN; + + /* + * RFC 4340, 12.2: Encode the Nonce Echo for this Ack Vector via + * its type; ack_nonce is the sum of all individual buf_nonce's. 
+ */ + nonce ^= av->av_buf_nonce[i]; + + *to++ = DCCPO_ACK_VECTOR_0 + av->av_buf_nonce[i]; + *to++ = copylen + 2; + + /* Check if buf_head wraps */ + if (from + copylen > tail) { + const u16 tailsize = tail - from; + + memcpy(to, from, tailsize); + to += tailsize; + len -= tailsize; + copylen -= tailsize; + from = av->av_buf; + } + + memcpy(to, from, copylen); + from += copylen; + to += copylen; + len -= copylen; + } + /* + * Each sent Ack Vector is recorded in the list, as per A.2 of RFC 4340. + */ + if (dccp_ackvec_update_records(av, dcb->dccpd_seq, nonce)) + return -ENOBUFS; + return 0; +} + +/** + * dccp_insert_option_mandatory - Mandatory option (5.8.2) + * @skb: frame into which to insert option + * + * Note that since we are using skb_push, this function needs to be called + * _after_ inserting the option it is supposed to influence (stack order). + */ +int dccp_insert_option_mandatory(struct sk_buff *skb) +{ + if (DCCP_SKB_CB(skb)->dccpd_opt_len >= DCCP_MAX_OPT_LEN) + return -1; + + DCCP_SKB_CB(skb)->dccpd_opt_len++; + *(u8 *)skb_push(skb, 1) = DCCPO_MANDATORY; + return 0; +} + +/** + * dccp_insert_fn_opt - Insert single Feature-Negotiation option into @skb + * @skb: frame to insert feature negotiation option into + * @type: %DCCPO_CHANGE_L, %DCCPO_CHANGE_R, %DCCPO_CONFIRM_L, %DCCPO_CONFIRM_R + * @feat: one out of %dccp_feature_numbers + * @val: NN value or SP array (preferred element first) to copy + * @len: true length of @val in bytes (excluding first element repetition) + * @repeat_first: whether to copy the first element of @val twice + * + * The last argument is used to construct Confirm options, where the preferred + * value and the preference list appear separately (RFC 4340, 6.3.1). Preference + * lists are kept such that the preferred entry is always first, so we only need + * to copy twice, and avoid the overhead of cloning into a bigger array. + */ +int dccp_insert_fn_opt(struct sk_buff *skb, u8 type, u8 feat, + u8 *val, u8 len, bool repeat_first) +{ + u8 tot_len, *to; + + /* take the `Feature' field and possible repetition into account */ + if (len > (DCCP_SINGLE_OPT_MAXLEN - 2)) { + DCCP_WARN("length %u for feature %u too large\n", len, feat); + return -1; + } + + if (unlikely(val == NULL || len == 0)) + len = repeat_first = false; + tot_len = 3 + repeat_first + len; + + if (DCCP_SKB_CB(skb)->dccpd_opt_len + tot_len > DCCP_MAX_OPT_LEN) { + DCCP_WARN("packet too small for feature %d option!\n", feat); + return -1; + } + DCCP_SKB_CB(skb)->dccpd_opt_len += tot_len; + + to = skb_push(skb, tot_len); + *to++ = type; + *to++ = tot_len; + *to++ = feat; + + if (repeat_first) + *to++ = *val; + if (len) + memcpy(to, val, len); + return 0; +} + +/* The length of all options needs to be a multiple of 4 (5.8) */ +static void dccp_insert_option_padding(struct sk_buff *skb) +{ + int padding = DCCP_SKB_CB(skb)->dccpd_opt_len % 4; + + if (padding != 0) { + padding = 4 - padding; + memset(skb_push(skb, padding), 0, padding); + DCCP_SKB_CB(skb)->dccpd_opt_len += padding; + } +} + +int dccp_insert_options(struct sock *sk, struct sk_buff *skb) +{ + struct dccp_sock *dp = dccp_sk(sk); + + DCCP_SKB_CB(skb)->dccpd_opt_len = 0; + + if (dp->dccps_send_ndp_count && dccp_insert_option_ndp(sk, skb)) + return -1; + + if (DCCP_SKB_CB(skb)->dccpd_type != DCCP_PKT_DATA) { + + /* Feature Negotiation */ + if (dccp_feat_insert_opts(dp, NULL, skb)) + return -1; + + if (DCCP_SKB_CB(skb)->dccpd_type == DCCP_PKT_REQUEST) { + /* + * Obtain RTT sample from Request/Response exchange. 
+ * This is currently used for TFRC initialisation. + */ + if (dccp_insert_option_timestamp(skb)) + return -1; + + } else if (dccp_ackvec_pending(sk) && + dccp_insert_option_ackvec(sk, skb)) { + return -1; + } + } + + if (dp->dccps_hc_rx_insert_options) { + if (ccid_hc_rx_insert_options(dp->dccps_hc_rx_ccid, sk, skb)) + return -1; + dp->dccps_hc_rx_insert_options = 0; + } + + if (dp->dccps_timestamp_echo != 0 && + dccp_insert_option_timestamp_echo(dp, NULL, skb)) + return -1; + + dccp_insert_option_padding(skb); + return 0; +} + +int dccp_insert_options_rsk(struct dccp_request_sock *dreq, struct sk_buff *skb) +{ + DCCP_SKB_CB(skb)->dccpd_opt_len = 0; + + if (dccp_feat_insert_opts(NULL, dreq, skb)) + return -1; + + /* Obtain RTT sample from Response/Ack exchange (used by TFRC). */ + if (dccp_insert_option_timestamp(skb)) + return -1; + + if (dreq->dreq_timestamp_echo != 0 && + dccp_insert_option_timestamp_echo(NULL, dreq, skb)) + return -1; + + dccp_insert_option_padding(skb); + return 0; +} diff --git a/net/dccp/output.c b/net/dccp/output.c new file mode 100644 index 000000000..fd2eb148d --- /dev/null +++ b/net/dccp/output.c @@ -0,0 +1,709 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/output.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "ackvec.h" +#include "ccid.h" +#include "dccp.h" + +static inline void dccp_event_ack_sent(struct sock *sk) +{ + inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); +} + +/* enqueue @skb on sk_send_head for retransmission, return clone to send now */ +static struct sk_buff *dccp_skb_entail(struct sock *sk, struct sk_buff *skb) +{ + skb_set_owner_w(skb, sk); + WARN_ON(sk->sk_send_head); + sk->sk_send_head = skb; + return skb_clone(sk->sk_send_head, gfp_any()); +} + +/* + * All SKB's seen here are completely headerless. It is our + * job to build the DCCP header, and pass the packet down to + * IP so it can do the same plus pass the packet off to the + * device. + */ +static int dccp_transmit_skb(struct sock *sk, struct sk_buff *skb) +{ + if (likely(skb != NULL)) { + struct inet_sock *inet = inet_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_skb_cb *dcb = DCCP_SKB_CB(skb); + struct dccp_hdr *dh; + /* XXX For now we're using only 48 bits sequence numbers */ + const u32 dccp_header_size = sizeof(*dh) + + sizeof(struct dccp_hdr_ext) + + dccp_packet_hdr_len(dcb->dccpd_type); + int err, set_ack = 1; + u64 ackno = dp->dccps_gsr; + /* + * Increment GSS here already in case the option code needs it. + * Update GSS for real only if option processing below succeeds. + */ + dcb->dccpd_seq = ADD48(dp->dccps_gss, 1); + + switch (dcb->dccpd_type) { + case DCCP_PKT_DATA: + set_ack = 0; + fallthrough; + case DCCP_PKT_DATAACK: + case DCCP_PKT_RESET: + break; + + case DCCP_PKT_REQUEST: + set_ack = 0; + /* Use ISS on the first (non-retransmitted) Request. */ + if (icsk->icsk_retransmits == 0) + dcb->dccpd_seq = dp->dccps_iss; + fallthrough; + + case DCCP_PKT_SYNC: + case DCCP_PKT_SYNCACK: + ackno = dcb->dccpd_ack_seq; + fallthrough; + default: + /* + * Set owner/destructor: some skbs are allocated via + * alloc_skb (e.g. when retransmission may happen). + * Only Data, DataAck, and Reset packets should come + * through here with skb->sk set. 
+ */ + WARN_ON(skb->sk); + skb_set_owner_w(skb, sk); + break; + } + + if (dccp_insert_options(sk, skb)) { + kfree_skb(skb); + return -EPROTO; + } + + + /* Build DCCP header and checksum it. */ + dh = dccp_zeroed_hdr(skb, dccp_header_size); + dh->dccph_type = dcb->dccpd_type; + dh->dccph_sport = inet->inet_sport; + dh->dccph_dport = inet->inet_dport; + dh->dccph_doff = (dccp_header_size + dcb->dccpd_opt_len) / 4; + dh->dccph_ccval = dcb->dccpd_ccval; + dh->dccph_cscov = dp->dccps_pcslen; + /* XXX For now we're using only 48 bits sequence numbers */ + dh->dccph_x = 1; + + dccp_update_gss(sk, dcb->dccpd_seq); + dccp_hdr_set_seq(dh, dp->dccps_gss); + if (set_ack) + dccp_hdr_set_ack(dccp_hdr_ack_bits(skb), ackno); + + switch (dcb->dccpd_type) { + case DCCP_PKT_REQUEST: + dccp_hdr_request(skb)->dccph_req_service = + dp->dccps_service; + /* + * Limit Ack window to ISS <= P.ackno <= GSS, so that + * only Responses to Requests we sent are considered. + */ + dp->dccps_awl = dp->dccps_iss; + break; + case DCCP_PKT_RESET: + dccp_hdr_reset(skb)->dccph_reset_code = + dcb->dccpd_reset_code; + break; + } + + icsk->icsk_af_ops->send_check(sk, skb); + + if (set_ack) + dccp_event_ack_sent(sk); + + DCCP_INC_STATS(DCCP_MIB_OUTSEGS); + + err = icsk->icsk_af_ops->queue_xmit(sk, skb, &inet->cork.fl); + return net_xmit_eval(err); + } + return -ENOBUFS; +} + +/** + * dccp_determine_ccmps - Find out about CCID-specific packet-size limits + * @dp: socket to find packet size limits of + * + * We only consider the HC-sender CCID for setting the CCMPS (RFC 4340, 14.), + * since the RX CCID is restricted to feedback packets (Acks), which are small + * in comparison with the data traffic. A value of 0 means "no current CCMPS". + */ +static u32 dccp_determine_ccmps(const struct dccp_sock *dp) +{ + const struct ccid *tx_ccid = dp->dccps_hc_tx_ccid; + + if (tx_ccid == NULL || tx_ccid->ccid_ops == NULL) + return 0; + return tx_ccid->ccid_ops->ccid_ccmps; +} + +unsigned int dccp_sync_mss(struct sock *sk, u32 pmtu) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct dccp_sock *dp = dccp_sk(sk); + u32 ccmps = dccp_determine_ccmps(dp); + u32 cur_mps = ccmps ? min(pmtu, ccmps) : pmtu; + + /* Account for header lengths and IPv4/v6 option overhead */ + cur_mps -= (icsk->icsk_af_ops->net_header_len + icsk->icsk_ext_hdr_len + + sizeof(struct dccp_hdr) + sizeof(struct dccp_hdr_ext)); + + /* + * Leave enough headroom for common DCCP header options. + * This only considers options which may appear on DCCP-Data packets, as + * per table 3 in RFC 4340, 5.8. When running out of space for other + * options (eg. Ack Vector which can take up to 255 bytes), it is better + * to schedule a separate Ack. Thus we leave headroom for the following: + * - 1 byte for Slow Receiver (11.6) + * - 6 bytes for Timestamp (13.1) + * - 10 bytes for Timestamp Echo (13.3) + * - 8 bytes for NDP count (7.7, when activated) + * - 6 bytes for Data Checksum (9.3) + * - %DCCPAV_MIN_OPTLEN bytes for Ack Vector size (11.4, when enabled) + */ + cur_mps -= roundup(1 + 6 + 10 + dp->dccps_send_ndp_count * 8 + 6 + + (dp->dccps_hc_rx_ackvec ? 
DCCPAV_MIN_OPTLEN : 0), 4); + + /* And store cached results */ + icsk->icsk_pmtu_cookie = pmtu; + WRITE_ONCE(dp->dccps_mss_cache, cur_mps); + + return cur_mps; +} + +EXPORT_SYMBOL_GPL(dccp_sync_mss); + +void dccp_write_space(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible(&wq->wait); + /* Should agree with poll, otherwise some programs break */ + if (sock_writeable(sk)) + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + + rcu_read_unlock(); +} + +/** + * dccp_wait_for_ccid - Await CCID send permission + * @sk: socket to wait for + * @delay: timeout in jiffies + * + * This is used by CCIDs which need to delay the send time in process context. + */ +static int dccp_wait_for_ccid(struct sock *sk, unsigned long delay) +{ + DEFINE_WAIT(wait); + long remaining; + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + sk->sk_write_pending++; + release_sock(sk); + + remaining = schedule_timeout(delay); + + lock_sock(sk); + sk->sk_write_pending--; + finish_wait(sk_sleep(sk), &wait); + + if (signal_pending(current) || sk->sk_err) + return -1; + return remaining; +} + +/** + * dccp_xmit_packet - Send data packet under control of CCID + * @sk: socket to send data packet on + * + * Transmits next-queued payload and informs CCID to account for the packet. + */ +static void dccp_xmit_packet(struct sock *sk) +{ + int err, len; + struct dccp_sock *dp = dccp_sk(sk); + struct sk_buff *skb = dccp_qpolicy_pop(sk); + + if (unlikely(skb == NULL)) + return; + len = skb->len; + + if (sk->sk_state == DCCP_PARTOPEN) { + const u32 cur_mps = dp->dccps_mss_cache - DCCP_FEATNEG_OVERHEAD; + /* + * See 8.1.5 - Handshake Completion. + * + * For robustness we resend Confirm options until the client has + * entered OPEN. During the initial feature negotiation, the MPS + * is smaller than usual, reduced by the Change/Confirm options. + */ + if (!list_empty(&dp->dccps_featneg) && len > cur_mps) { + DCCP_WARN("Payload too large (%d) for featneg.\n", len); + dccp_send_ack(sk); + dccp_feat_list_purge(&dp->dccps_featneg); + } + + inet_csk_schedule_ack(sk); + inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, + inet_csk(sk)->icsk_rto, + DCCP_RTO_MAX); + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_DATAACK; + } else if (dccp_ack_pending(sk)) { + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_DATAACK; + } else { + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_DATA; + } + + err = dccp_transmit_skb(sk, skb); + if (err) + dccp_pr_debug("transmit_skb() returned err=%d\n", err); + /* + * Register this one as sent even if an error occurred. To the remote + * end a local packet drop is indistinguishable from network loss, i.e. + * any local drop will eventually be reported via receiver feedback. + */ + ccid_hc_tx_packet_sent(dp->dccps_hc_tx_ccid, sk, len); + + /* + * If the CCID needs to transfer additional header options out-of-band + * (e.g. Ack Vectors or feature-negotiation options), it activates this + * flag to schedule a Sync. The Sync will automatically incorporate all + * currently pending header options, thus clearing the backlog. + */ + if (dp->dccps_sync_scheduled) + dccp_send_sync(sk, dp->dccps_gsr, DCCP_PKT_SYNC); +} + +/** + * dccp_flush_write_queue - Drain queue at end of connection + * @sk: socket to be drained + * @time_budget: time allowed to drain the queue + * + * Since dccp_sendmsg queues packets without waiting for them to be sent, it may + * happen that the TX queue is not empty at the end of a connection. 
We give the + * HC-sender CCID a grace period of up to @time_budget jiffies. If this function + * returns with a non-empty write queue, it will be purged later. + */ +void dccp_flush_write_queue(struct sock *sk, long *time_budget) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct sk_buff *skb; + long delay, rc; + + while (*time_budget > 0 && (skb = skb_peek(&sk->sk_write_queue))) { + rc = ccid_hc_tx_send_packet(dp->dccps_hc_tx_ccid, sk, skb); + + switch (ccid_packet_dequeue_eval(rc)) { + case CCID_PACKET_WILL_DEQUEUE_LATER: + /* + * If the CCID determines when to send, the next sending + * time is unknown or the CCID may not even send again + * (e.g. remote host crashes or lost Ack packets). + */ + DCCP_WARN("CCID did not manage to send all packets\n"); + return; + case CCID_PACKET_DELAY: + delay = msecs_to_jiffies(rc); + if (delay > *time_budget) + return; + rc = dccp_wait_for_ccid(sk, delay); + if (rc < 0) + return; + *time_budget -= (delay - rc); + /* check again if we can send now */ + break; + case CCID_PACKET_SEND_AT_ONCE: + dccp_xmit_packet(sk); + break; + case CCID_PACKET_ERR: + skb_dequeue(&sk->sk_write_queue); + kfree_skb(skb); + dccp_pr_debug("packet discarded due to err=%ld\n", rc); + } + } +} + +void dccp_write_xmit(struct sock *sk) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct sk_buff *skb; + + while ((skb = dccp_qpolicy_top(sk))) { + int rc = ccid_hc_tx_send_packet(dp->dccps_hc_tx_ccid, sk, skb); + + switch (ccid_packet_dequeue_eval(rc)) { + case CCID_PACKET_WILL_DEQUEUE_LATER: + return; + case CCID_PACKET_DELAY: + sk_reset_timer(sk, &dp->dccps_xmit_timer, + jiffies + msecs_to_jiffies(rc)); + return; + case CCID_PACKET_SEND_AT_ONCE: + dccp_xmit_packet(sk); + break; + case CCID_PACKET_ERR: + dccp_qpolicy_drop(sk, skb); + dccp_pr_debug("packet discarded due to err=%d\n", rc); + } + } +} + +/** + * dccp_retransmit_skb - Retransmit Request, Close, or CloseReq packets + * @sk: socket to perform retransmit on + * + * There are only four retransmittable packet types in DCCP: + * - Request in client-REQUEST state (sec. 8.1.1), + * - CloseReq in server-CLOSEREQ state (sec. 8.3), + * - Close in node-CLOSING state (sec. 8.3), + * - Acks in client-PARTOPEN state (sec. 8.1.5, handled by dccp_delack_timer()). + * This function expects sk->sk_send_head to contain the original skb. + */ +int dccp_retransmit_skb(struct sock *sk) +{ + WARN_ON(sk->sk_send_head == NULL); + + if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk) != 0) + return -EHOSTUNREACH; /* Routing failure or similar. */ + + /* this count is used to distinguish original and retransmitted skb */ + inet_csk(sk)->icsk_retransmits++; + + return dccp_transmit_skb(sk, skb_clone(sk->sk_send_head, GFP_ATOMIC)); +} + +struct sk_buff *dccp_make_response(const struct sock *sk, struct dst_entry *dst, + struct request_sock *req) +{ + struct dccp_hdr *dh; + struct dccp_request_sock *dreq; + const u32 dccp_header_size = sizeof(struct dccp_hdr) + + sizeof(struct dccp_hdr_ext) + + sizeof(struct dccp_hdr_response); + struct sk_buff *skb; + + /* sk is marked const to clearly express we dont hold socket lock. + * sock_wmalloc() will atomically change sk->sk_wmem_alloc, + * it is safe to promote sk to non const. 
+ */ + skb = sock_wmalloc((struct sock *)sk, MAX_DCCP_HEADER, 1, + GFP_ATOMIC); + if (!skb) + return NULL; + + skb_reserve(skb, MAX_DCCP_HEADER); + + skb_dst_set(skb, dst_clone(dst)); + + dreq = dccp_rsk(req); + if (inet_rsk(req)->acked) /* increase GSS upon retransmission */ + dccp_inc_seqno(&dreq->dreq_gss); + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_RESPONSE; + DCCP_SKB_CB(skb)->dccpd_seq = dreq->dreq_gss; + + /* Resolve feature dependencies resulting from choice of CCID */ + if (dccp_feat_server_ccid_dependencies(dreq)) + goto response_failed; + + if (dccp_insert_options_rsk(dreq, skb)) + goto response_failed; + + /* Build and checksum header */ + dh = dccp_zeroed_hdr(skb, dccp_header_size); + + dh->dccph_sport = htons(inet_rsk(req)->ir_num); + dh->dccph_dport = inet_rsk(req)->ir_rmt_port; + dh->dccph_doff = (dccp_header_size + + DCCP_SKB_CB(skb)->dccpd_opt_len) / 4; + dh->dccph_type = DCCP_PKT_RESPONSE; + dh->dccph_x = 1; + dccp_hdr_set_seq(dh, dreq->dreq_gss); + dccp_hdr_set_ack(dccp_hdr_ack_bits(skb), dreq->dreq_gsr); + dccp_hdr_response(skb)->dccph_resp_service = dreq->dreq_service; + + dccp_csum_outgoing(skb); + + /* We use `acked' to remember that a Response was already sent. */ + inet_rsk(req)->acked = 1; + DCCP_INC_STATS(DCCP_MIB_OUTSEGS); + return skb; +response_failed: + kfree_skb(skb); + return NULL; +} + +EXPORT_SYMBOL_GPL(dccp_make_response); + +/* answer offending packet in @rcv_skb with Reset from control socket @ctl */ +struct sk_buff *dccp_ctl_make_reset(struct sock *sk, struct sk_buff *rcv_skb) +{ + struct dccp_hdr *rxdh = dccp_hdr(rcv_skb), *dh; + struct dccp_skb_cb *dcb = DCCP_SKB_CB(rcv_skb); + const u32 dccp_hdr_reset_len = sizeof(struct dccp_hdr) + + sizeof(struct dccp_hdr_ext) + + sizeof(struct dccp_hdr_reset); + struct dccp_hdr_reset *dhr; + struct sk_buff *skb; + + skb = alloc_skb(sk->sk_prot->max_header, GFP_ATOMIC); + if (skb == NULL) + return NULL; + + skb_reserve(skb, sk->sk_prot->max_header); + + /* Swap the send and the receive. */ + dh = dccp_zeroed_hdr(skb, dccp_hdr_reset_len); + dh->dccph_type = DCCP_PKT_RESET; + dh->dccph_sport = rxdh->dccph_dport; + dh->dccph_dport = rxdh->dccph_sport; + dh->dccph_doff = dccp_hdr_reset_len / 4; + dh->dccph_x = 1; + + dhr = dccp_hdr_reset(skb); + dhr->dccph_reset_code = dcb->dccpd_reset_code; + + switch (dcb->dccpd_reset_code) { + case DCCP_RESET_CODE_PACKET_ERROR: + dhr->dccph_reset_data[0] = rxdh->dccph_type; + break; + case DCCP_RESET_CODE_OPTION_ERROR: + case DCCP_RESET_CODE_MANDATORY_ERROR: + memcpy(dhr->dccph_reset_data, dcb->dccpd_reset_data, 3); + break; + } + /* + * From RFC 4340, 8.3.1: + * If P.ackno exists, set R.seqno := P.ackno + 1. + * Else set R.seqno := 0. + */ + if (dcb->dccpd_ack_seq != DCCP_PKT_WITHOUT_ACK_SEQ) + dccp_hdr_set_seq(dh, ADD48(dcb->dccpd_ack_seq, 1)); + dccp_hdr_set_ack(dccp_hdr_ack_bits(skb), dcb->dccpd_seq); + + dccp_csum_outgoing(skb); + return skb; +} + +EXPORT_SYMBOL_GPL(dccp_ctl_make_reset); + +/* send Reset on established socket, to close or abort the connection */ +int dccp_send_reset(struct sock *sk, enum dccp_reset_codes code) +{ + struct sk_buff *skb; + /* + * FIXME: what if rebuild_header fails? + * Should we be doing a rebuild_header here? + */ + int err = inet_csk(sk)->icsk_af_ops->rebuild_header(sk); + + if (err != 0) + return err; + + skb = sock_wmalloc(sk, sk->sk_prot->max_header, 1, GFP_ATOMIC); + if (skb == NULL) + return -ENOBUFS; + + /* Reserve space for headers and prepare control bits. 
*/ + skb_reserve(skb, sk->sk_prot->max_header); + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_RESET; + DCCP_SKB_CB(skb)->dccpd_reset_code = code; + + return dccp_transmit_skb(sk, skb); +} + +/* + * Do all connect socket setups that can be done AF independent. + */ +int dccp_connect(struct sock *sk) +{ + struct sk_buff *skb; + struct dccp_sock *dp = dccp_sk(sk); + struct dst_entry *dst = __sk_dst_get(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + + sk->sk_err = 0; + sock_reset_flag(sk, SOCK_DONE); + + dccp_sync_mss(sk, dst_mtu(dst)); + + /* do not connect if feature negotiation setup fails */ + if (dccp_feat_finalise_settings(dccp_sk(sk))) + return -EPROTO; + + /* Initialise GAR as per 8.5; AWL/AWH are set in dccp_transmit_skb() */ + dp->dccps_gar = dp->dccps_iss; + + skb = alloc_skb(sk->sk_prot->max_header, sk->sk_allocation); + if (unlikely(skb == NULL)) + return -ENOBUFS; + + /* Reserve space for headers. */ + skb_reserve(skb, sk->sk_prot->max_header); + + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_REQUEST; + + dccp_transmit_skb(sk, dccp_skb_entail(sk, skb)); + DCCP_INC_STATS(DCCP_MIB_ACTIVEOPENS); + + /* Timer for repeating the REQUEST until an answer. */ + icsk->icsk_retransmits = 0; + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + icsk->icsk_rto, DCCP_RTO_MAX); + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_connect); + +void dccp_send_ack(struct sock *sk) +{ + /* If we have been reset, we may not send again. */ + if (sk->sk_state != DCCP_CLOSED) { + struct sk_buff *skb = alloc_skb(sk->sk_prot->max_header, + GFP_ATOMIC); + + if (skb == NULL) { + inet_csk_schedule_ack(sk); + inet_csk(sk)->icsk_ack.ato = TCP_ATO_MIN; + inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, + TCP_DELACK_MAX, + DCCP_RTO_MAX); + return; + } + + /* Reserve space for headers */ + skb_reserve(skb, sk->sk_prot->max_header); + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_ACK; + dccp_transmit_skb(sk, skb); + } +} + +EXPORT_SYMBOL_GPL(dccp_send_ack); + +#if 0 +/* FIXME: Is this still necessary (11.3) - currently nowhere used by DCCP. */ +void dccp_send_delayed_ack(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + /* + * FIXME: tune this timer. elapsed time fixes the skew, so no problem + * with using 2s, and active senders also piggyback the ACK into a + * DATAACK packet, so this is really for quiescent senders. + */ + unsigned long timeout = jiffies + 2 * HZ; + + /* Use new timeout only if there wasn't a older one earlier. */ + if (icsk->icsk_ack.pending & ICSK_ACK_TIMER) { + /* If delack timer was blocked or is about to expire, + * send ACK now. + * + * FIXME: check the "about to expire" part + */ + if (icsk->icsk_ack.blocked) { + dccp_send_ack(sk); + return; + } + + if (!time_before(timeout, icsk->icsk_ack.timeout)) + timeout = icsk->icsk_ack.timeout; + } + icsk->icsk_ack.pending |= ICSK_ACK_SCHED | ICSK_ACK_TIMER; + icsk->icsk_ack.timeout = timeout; + sk_reset_timer(sk, &icsk->icsk_delack_timer, timeout); +} +#endif + +void dccp_send_sync(struct sock *sk, const u64 ackno, + const enum dccp_pkt_type pkt_type) +{ + /* + * We are not putting this on the write queue, so + * dccp_transmit_skb() will set the ownership to this + * sock. + */ + struct sk_buff *skb = alloc_skb(sk->sk_prot->max_header, GFP_ATOMIC); + + if (skb == NULL) { + /* FIXME: how to make sure the sync is sent? */ + DCCP_CRIT("could not send %s", dccp_packet_name(pkt_type)); + return; + } + + /* Reserve space for headers and prepare control bits. 
*/ + skb_reserve(skb, sk->sk_prot->max_header); + DCCP_SKB_CB(skb)->dccpd_type = pkt_type; + DCCP_SKB_CB(skb)->dccpd_ack_seq = ackno; + + /* + * Clear the flag in case the Sync was scheduled for out-of-band data, + * such as carrying a long Ack Vector. + */ + dccp_sk(sk)->dccps_sync_scheduled = 0; + + dccp_transmit_skb(sk, skb); +} + +EXPORT_SYMBOL_GPL(dccp_send_sync); + +/* + * Send a DCCP_PKT_CLOSE/CLOSEREQ. The caller locks the socket for us. This + * cannot be allowed to fail queueing a DCCP_PKT_CLOSE/CLOSEREQ frame under + * any circumstances. + */ +void dccp_send_close(struct sock *sk, const int active) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct sk_buff *skb; + const gfp_t prio = active ? GFP_KERNEL : GFP_ATOMIC; + + skb = alloc_skb(sk->sk_prot->max_header, prio); + if (skb == NULL) + return; + + /* Reserve space for headers and prepare control bits. */ + skb_reserve(skb, sk->sk_prot->max_header); + if (dp->dccps_role == DCCP_ROLE_SERVER && !dp->dccps_server_timewait) + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_CLOSEREQ; + else + DCCP_SKB_CB(skb)->dccpd_type = DCCP_PKT_CLOSE; + + if (active) { + skb = dccp_skb_entail(sk, skb); + /* + * Retransmission timer for active-close: RFC 4340, 8.3 requires + * to retransmit the Close/CloseReq until the CLOSING/CLOSEREQ + * state can be left. The initial timeout is 2 RTTs. + * Since RTT measurement is done by the CCIDs, there is no easy + * way to get an RTT sample. The fallback RTT from RFC 4340, 3.4 + * is too low (200ms); we use a high value to avoid unnecessary + * retransmissions when the link RTT is > 0.2 seconds. + * FIXME: Let main module sample RTTs and use that instead. + */ + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + DCCP_TIMEOUT_INIT, DCCP_RTO_MAX); + } + dccp_transmit_skb(sk, skb); +} diff --git a/net/dccp/proto.c b/net/dccp/proto.c new file mode 100644 index 000000000..c522c76a9 --- /dev/null +++ b/net/dccp/proto.c @@ -0,0 +1,1290 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/proto.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "ccid.h" +#include "dccp.h" +#include "feat.h" + +#define CREATE_TRACE_POINTS +#include "trace.h" + +DEFINE_SNMP_STAT(struct dccp_mib, dccp_statistics) __read_mostly; + +EXPORT_SYMBOL_GPL(dccp_statistics); + +DEFINE_PER_CPU(unsigned int, dccp_orphan_count); +EXPORT_PER_CPU_SYMBOL_GPL(dccp_orphan_count); + +struct inet_hashinfo dccp_hashinfo; +EXPORT_SYMBOL_GPL(dccp_hashinfo); + +/* the maximum queue length for tx in packets. 
0 is no limit */ +int sysctl_dccp_tx_qlen __read_mostly = 5; + +#ifdef CONFIG_IP_DCCP_DEBUG +static const char *dccp_state_name(const int state) +{ + static const char *const dccp_state_names[] = { + [DCCP_OPEN] = "OPEN", + [DCCP_REQUESTING] = "REQUESTING", + [DCCP_PARTOPEN] = "PARTOPEN", + [DCCP_LISTEN] = "LISTEN", + [DCCP_RESPOND] = "RESPOND", + [DCCP_CLOSING] = "CLOSING", + [DCCP_ACTIVE_CLOSEREQ] = "CLOSEREQ", + [DCCP_PASSIVE_CLOSE] = "PASSIVE_CLOSE", + [DCCP_PASSIVE_CLOSEREQ] = "PASSIVE_CLOSEREQ", + [DCCP_TIME_WAIT] = "TIME_WAIT", + [DCCP_CLOSED] = "CLOSED", + }; + + if (state >= DCCP_MAX_STATES) + return "INVALID STATE!"; + else + return dccp_state_names[state]; +} +#endif + +void dccp_set_state(struct sock *sk, const int state) +{ + const int oldstate = sk->sk_state; + + dccp_pr_debug("%s(%p) %s --> %s\n", dccp_role(sk), sk, + dccp_state_name(oldstate), dccp_state_name(state)); + WARN_ON(state == oldstate); + + switch (state) { + case DCCP_OPEN: + if (oldstate != DCCP_OPEN) + DCCP_INC_STATS(DCCP_MIB_CURRESTAB); + /* Client retransmits all Confirm options until entering OPEN */ + if (oldstate == DCCP_PARTOPEN) + dccp_feat_list_purge(&dccp_sk(sk)->dccps_featneg); + break; + + case DCCP_CLOSED: + if (oldstate == DCCP_OPEN || oldstate == DCCP_ACTIVE_CLOSEREQ || + oldstate == DCCP_CLOSING) + DCCP_INC_STATS(DCCP_MIB_ESTABRESETS); + + sk->sk_prot->unhash(sk); + if (inet_csk(sk)->icsk_bind_hash != NULL && + !(sk->sk_userlocks & SOCK_BINDPORT_LOCK)) + inet_put_port(sk); + fallthrough; + default: + if (oldstate == DCCP_OPEN) + DCCP_DEC_STATS(DCCP_MIB_CURRESTAB); + } + + /* Change state AFTER socket is unhashed to avoid closed + * socket sitting in hash tables. + */ + inet_sk_set_state(sk, state); +} + +EXPORT_SYMBOL_GPL(dccp_set_state); + +static void dccp_finish_passive_close(struct sock *sk) +{ + switch (sk->sk_state) { + case DCCP_PASSIVE_CLOSE: + /* Node (client or server) has received Close packet. */ + dccp_send_reset(sk, DCCP_RESET_CODE_CLOSED); + dccp_set_state(sk, DCCP_CLOSED); + break; + case DCCP_PASSIVE_CLOSEREQ: + /* + * Client received CloseReq. We set the `active' flag so that + * dccp_send_close() retransmits the Close as per RFC 4340, 8.3. 
+ */ + dccp_send_close(sk, 1); + dccp_set_state(sk, DCCP_CLOSING); + } +} + +void dccp_done(struct sock *sk) +{ + dccp_set_state(sk, DCCP_CLOSED); + dccp_clear_xmit_timers(sk); + + sk->sk_shutdown = SHUTDOWN_MASK; + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + else + inet_csk_destroy_sock(sk); +} + +EXPORT_SYMBOL_GPL(dccp_done); + +const char *dccp_packet_name(const int type) +{ + static const char *const dccp_packet_names[] = { + [DCCP_PKT_REQUEST] = "REQUEST", + [DCCP_PKT_RESPONSE] = "RESPONSE", + [DCCP_PKT_DATA] = "DATA", + [DCCP_PKT_ACK] = "ACK", + [DCCP_PKT_DATAACK] = "DATAACK", + [DCCP_PKT_CLOSEREQ] = "CLOSEREQ", + [DCCP_PKT_CLOSE] = "CLOSE", + [DCCP_PKT_RESET] = "RESET", + [DCCP_PKT_SYNC] = "SYNC", + [DCCP_PKT_SYNCACK] = "SYNCACK", + }; + + if (type >= DCCP_NR_PKT_TYPES) + return "INVALID"; + else + return dccp_packet_names[type]; +} + +EXPORT_SYMBOL_GPL(dccp_packet_name); + +void dccp_destruct_common(struct sock *sk) +{ + struct dccp_sock *dp = dccp_sk(sk); + + ccid_hc_tx_delete(dp->dccps_hc_tx_ccid, sk); + dp->dccps_hc_tx_ccid = NULL; +} +EXPORT_SYMBOL_GPL(dccp_destruct_common); + +static void dccp_sk_destruct(struct sock *sk) +{ + dccp_destruct_common(sk); + inet_sock_destruct(sk); +} + +int dccp_init_sock(struct sock *sk, const __u8 ctl_sock_initialized) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + + icsk->icsk_rto = DCCP_TIMEOUT_INIT; + icsk->icsk_syn_retries = sysctl_dccp_request_retries; + sk->sk_state = DCCP_CLOSED; + sk->sk_write_space = dccp_write_space; + sk->sk_destruct = dccp_sk_destruct; + icsk->icsk_sync_mss = dccp_sync_mss; + dp->dccps_mss_cache = 536; + dp->dccps_rate_last = jiffies; + dp->dccps_role = DCCP_ROLE_UNDEFINED; + dp->dccps_service = DCCP_SERVICE_CODE_IS_ABSENT; + dp->dccps_tx_qlen = sysctl_dccp_tx_qlen; + + dccp_init_xmit_timers(sk); + + INIT_LIST_HEAD(&dp->dccps_featneg); + /* control socket doesn't need feat nego */ + if (likely(ctl_sock_initialized)) + return dccp_feat_init(sk); + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_init_sock); + +void dccp_destroy_sock(struct sock *sk) +{ + struct dccp_sock *dp = dccp_sk(sk); + + __skb_queue_purge(&sk->sk_write_queue); + if (sk->sk_send_head != NULL) { + kfree_skb(sk->sk_send_head); + sk->sk_send_head = NULL; + } + + /* Clean up a referenced DCCP bind bucket. */ + if (inet_csk(sk)->icsk_bind_hash != NULL) + inet_put_port(sk); + + kfree(dp->dccps_service_list); + dp->dccps_service_list = NULL; + + if (dp->dccps_hc_rx_ackvec != NULL) { + dccp_ackvec_free(dp->dccps_hc_rx_ackvec); + dp->dccps_hc_rx_ackvec = NULL; + } + ccid_hc_rx_delete(dp->dccps_hc_rx_ccid, sk); + dp->dccps_hc_rx_ccid = NULL; + + /* clean up feature negotiation state */ + dccp_feat_list_purge(&dp->dccps_featneg); +} + +EXPORT_SYMBOL_GPL(dccp_destroy_sock); + +static inline int dccp_need_reset(int state) +{ + return state != DCCP_CLOSED && state != DCCP_LISTEN && + state != DCCP_REQUESTING; +} + +int dccp_disconnect(struct sock *sk, int flags) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_sock *inet = inet_sk(sk); + struct dccp_sock *dp = dccp_sk(sk); + const int old_state = sk->sk_state; + + if (old_state != DCCP_CLOSED) + dccp_set_state(sk, DCCP_CLOSED); + + /* + * This corresponds to the ABORT function of RFC793, sec. 3.8 + * TCP uses a RST segment, DCCP a Reset packet with Code 2, "Aborted". 
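+ * The Reset is sent only from the states accepted by dccp_need_reset(),
+ * i.e. once the handshake has progressed beyond REQUESTING.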
+ */ + if (old_state == DCCP_LISTEN) { + inet_csk_listen_stop(sk); + } else if (dccp_need_reset(old_state)) { + dccp_send_reset(sk, DCCP_RESET_CODE_ABORTED); + sk->sk_err = ECONNRESET; + } else if (old_state == DCCP_REQUESTING) + sk->sk_err = ECONNRESET; + + dccp_clear_xmit_timers(sk); + ccid_hc_rx_delete(dp->dccps_hc_rx_ccid, sk); + dp->dccps_hc_rx_ccid = NULL; + + __skb_queue_purge(&sk->sk_receive_queue); + __skb_queue_purge(&sk->sk_write_queue); + if (sk->sk_send_head != NULL) { + __kfree_skb(sk->sk_send_head); + sk->sk_send_head = NULL; + } + + inet->inet_dport = 0; + + inet_bhash2_reset_saddr(sk); + + sk->sk_shutdown = 0; + sock_reset_flag(sk, SOCK_DONE); + + icsk->icsk_backoff = 0; + inet_csk_delack_init(sk); + __sk_dst_reset(sk); + + WARN_ON(inet->inet_num && !icsk->icsk_bind_hash); + + sk_error_report(sk); + return 0; +} + +EXPORT_SYMBOL_GPL(dccp_disconnect); + +/* + * Wait for a DCCP event. + * + * Note that we don't need to lock the socket, as the upper poll layers + * take care of normal races (between the test and the event) and we don't + * go look at any of the socket buffers directly. + */ +__poll_t dccp_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask; + u8 shutdown; + int state; + + sock_poll_wait(file, sock, wait); + + state = inet_sk_state_load(sk); + if (state == DCCP_LISTEN) + return inet_csk_listen_poll(sk); + + /* Socket is not locked. We are protected from async events + by poll logic and correct handling of state changes + made by another threads is impossible in any case. + */ + + mask = 0; + if (READ_ONCE(sk->sk_err)) + mask = EPOLLERR; + shutdown = READ_ONCE(sk->sk_shutdown); + + if (shutdown == SHUTDOWN_MASK || state == DCCP_CLOSED) + mask |= EPOLLHUP; + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; + + /* Connected? */ + if ((1 << state) & ~(DCCPF_REQUESTING | DCCPF_RESPOND)) { + if (atomic_read(&sk->sk_rmem_alloc) > 0) + mask |= EPOLLIN | EPOLLRDNORM; + + if (!(shutdown & SEND_SHUTDOWN)) { + if (sk_stream_is_writeable(sk)) { + mask |= EPOLLOUT | EPOLLWRNORM; + } else { /* send SIGIO later */ + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + + /* Race breaker. If space is freed after + * wspace test but before the flags are set, + * IO signal will be lost. + */ + if (sk_stream_is_writeable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM; + } + } + } + return mask; +} +EXPORT_SYMBOL_GPL(dccp_poll); + +int dccp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + int rc = -ENOTCONN; + + lock_sock(sk); + + if (sk->sk_state == DCCP_LISTEN) + goto out; + + switch (cmd) { + case SIOCOUTQ: { + int amount = sk_wmem_alloc_get(sk); + /* Using sk_wmem_alloc here because sk_wmem_queued is not used by DCCP and + * always 0, comparably to UDP. + */ + + rc = put_user(amount, (int __user *)arg); + } + break; + case SIOCINQ: { + struct sk_buff *skb; + unsigned long amount = 0; + + skb = skb_peek(&sk->sk_receive_queue); + if (skb != NULL) { + /* + * We will only return the amount of this packet since + * that is all that will be read. 
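+ * DCCP preserves datagram boundaries, so unlike TCP's SIOCINQ a single
+ * read can never return more than the packet at the head of the queue.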
+ */ + amount = skb->len; + } + rc = put_user(amount, (int __user *)arg); + } + break; + default: + rc = -ENOIOCTLCMD; + break; + } +out: + release_sock(sk); + return rc; +} + +EXPORT_SYMBOL_GPL(dccp_ioctl); + +static int dccp_setsockopt_service(struct sock *sk, const __be32 service, + sockptr_t optval, unsigned int optlen) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct dccp_service_list *sl = NULL; + + if (service == DCCP_SERVICE_INVALID_VALUE || + optlen > DCCP_SERVICE_LIST_MAX_LEN * sizeof(u32)) + return -EINVAL; + + if (optlen > sizeof(service)) { + sl = kmalloc(optlen, GFP_KERNEL); + if (sl == NULL) + return -ENOMEM; + + sl->dccpsl_nr = optlen / sizeof(u32) - 1; + if (copy_from_sockptr_offset(sl->dccpsl_list, optval, + sizeof(service), optlen - sizeof(service)) || + dccp_list_has_service(sl, DCCP_SERVICE_INVALID_VALUE)) { + kfree(sl); + return -EFAULT; + } + } + + lock_sock(sk); + dp->dccps_service = service; + + kfree(dp->dccps_service_list); + + dp->dccps_service_list = sl; + release_sock(sk); + return 0; +} + +static int dccp_setsockopt_cscov(struct sock *sk, int cscov, bool rx) +{ + u8 *list, len; + int i, rc; + + if (cscov < 0 || cscov > 15) + return -EINVAL; + /* + * Populate a list of permissible values, in the range cscov...15. This + * is necessary since feature negotiation of single values only works if + * both sides incidentally choose the same value. Since the list starts + * lowest-value first, negotiation will pick the smallest shared value. + */ + if (cscov == 0) + return 0; + len = 16 - cscov; + + list = kmalloc(len, GFP_KERNEL); + if (list == NULL) + return -ENOBUFS; + + for (i = 0; i < len; i++) + list[i] = cscov++; + + rc = dccp_feat_register_sp(sk, DCCPF_MIN_CSUM_COVER, rx, list, len); + + if (rc == 0) { + if (rx) + dccp_sk(sk)->dccps_pcrlen = cscov; + else + dccp_sk(sk)->dccps_pcslen = cscov; + } + kfree(list); + return rc; +} + +static int dccp_setsockopt_ccid(struct sock *sk, int type, + sockptr_t optval, unsigned int optlen) +{ + u8 *val; + int rc = 0; + + if (optlen < 1 || optlen > DCCP_FEAT_MAX_SP_VALS) + return -EINVAL; + + val = memdup_sockptr(optval, optlen); + if (IS_ERR(val)) + return PTR_ERR(val); + + lock_sock(sk); + if (type == DCCP_SOCKOPT_TX_CCID || type == DCCP_SOCKOPT_CCID) + rc = dccp_feat_register_sp(sk, DCCPF_CCID, 1, val, optlen); + + if (!rc && (type == DCCP_SOCKOPT_RX_CCID || type == DCCP_SOCKOPT_CCID)) + rc = dccp_feat_register_sp(sk, DCCPF_CCID, 0, val, optlen); + release_sock(sk); + + kfree(val); + return rc; +} + +static int do_dccp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct dccp_sock *dp = dccp_sk(sk); + int val, err = 0; + + switch (optname) { + case DCCP_SOCKOPT_PACKET_SIZE: + DCCP_WARN("sockopt(PACKET_SIZE) is deprecated: fix your app\n"); + return 0; + case DCCP_SOCKOPT_CHANGE_L: + case DCCP_SOCKOPT_CHANGE_R: + DCCP_WARN("sockopt(CHANGE_L/R) is deprecated: fix your app\n"); + return 0; + case DCCP_SOCKOPT_CCID: + case DCCP_SOCKOPT_RX_CCID: + case DCCP_SOCKOPT_TX_CCID: + return dccp_setsockopt_ccid(sk, optname, optval, optlen); + } + + if (optlen < (int)sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + if (optname == DCCP_SOCKOPT_SERVICE) + return dccp_setsockopt_service(sk, val, optval, optlen); + + lock_sock(sk); + switch (optname) { + case DCCP_SOCKOPT_SERVER_TIMEWAIT: + if (dp->dccps_role != DCCP_ROLE_SERVER) + err = -EOPNOTSUPP; + else + dp->dccps_server_timewait = (val != 0); + break; + case 
DCCP_SOCKOPT_SEND_CSCOV: + err = dccp_setsockopt_cscov(sk, val, false); + break; + case DCCP_SOCKOPT_RECV_CSCOV: + err = dccp_setsockopt_cscov(sk, val, true); + break; + case DCCP_SOCKOPT_QPOLICY_ID: + if (sk->sk_state != DCCP_CLOSED) + err = -EISCONN; + else if (val < 0 || val >= DCCPQ_POLICY_MAX) + err = -EINVAL; + else + dp->dccps_qpolicy = val; + break; + case DCCP_SOCKOPT_QPOLICY_TXQLEN: + if (val < 0) + err = -EINVAL; + else + dp->dccps_tx_qlen = val; + break; + default: + err = -ENOPROTOOPT; + break; + } + release_sock(sk); + + return err; +} + +int dccp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + if (level != SOL_DCCP) + return inet_csk(sk)->icsk_af_ops->setsockopt(sk, level, + optname, optval, + optlen); + return do_dccp_setsockopt(sk, level, optname, optval, optlen); +} + +EXPORT_SYMBOL_GPL(dccp_setsockopt); + +static int dccp_getsockopt_service(struct sock *sk, int len, + __be32 __user *optval, + int __user *optlen) +{ + const struct dccp_sock *dp = dccp_sk(sk); + const struct dccp_service_list *sl; + int err = -ENOENT, slen = 0, total_len = sizeof(u32); + + lock_sock(sk); + if ((sl = dp->dccps_service_list) != NULL) { + slen = sl->dccpsl_nr * sizeof(u32); + total_len += slen; + } + + err = -EINVAL; + if (total_len > len) + goto out; + + err = 0; + if (put_user(total_len, optlen) || + put_user(dp->dccps_service, optval) || + (sl != NULL && copy_to_user(optval + 1, sl->dccpsl_list, slen))) + err = -EFAULT; +out: + release_sock(sk); + return err; +} + +static int do_dccp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct dccp_sock *dp; + int val, len; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < (int)sizeof(int)) + return -EINVAL; + + dp = dccp_sk(sk); + + switch (optname) { + case DCCP_SOCKOPT_PACKET_SIZE: + DCCP_WARN("sockopt(PACKET_SIZE) is deprecated: fix your app\n"); + return 0; + case DCCP_SOCKOPT_SERVICE: + return dccp_getsockopt_service(sk, len, + (__be32 __user *)optval, optlen); + case DCCP_SOCKOPT_GET_CUR_MPS: + val = READ_ONCE(dp->dccps_mss_cache); + break; + case DCCP_SOCKOPT_AVAILABLE_CCIDS: + return ccid_getsockopt_builtin_ccids(sk, len, optval, optlen); + case DCCP_SOCKOPT_TX_CCID: + val = ccid_get_current_tx_ccid(dp); + if (val < 0) + return -ENOPROTOOPT; + break; + case DCCP_SOCKOPT_RX_CCID: + val = ccid_get_current_rx_ccid(dp); + if (val < 0) + return -ENOPROTOOPT; + break; + case DCCP_SOCKOPT_SERVER_TIMEWAIT: + val = dp->dccps_server_timewait; + break; + case DCCP_SOCKOPT_SEND_CSCOV: + val = dp->dccps_pcslen; + break; + case DCCP_SOCKOPT_RECV_CSCOV: + val = dp->dccps_pcrlen; + break; + case DCCP_SOCKOPT_QPOLICY_ID: + val = dp->dccps_qpolicy; + break; + case DCCP_SOCKOPT_QPOLICY_TXQLEN: + val = dp->dccps_tx_qlen; + break; + case 128 ... 191: + return ccid_hc_rx_getsockopt(dp->dccps_hc_rx_ccid, sk, optname, + len, (u32 __user *)optval, optlen); + case 192 ... 
255: + return ccid_hc_tx_getsockopt(dp->dccps_hc_tx_ccid, sk, optname, + len, (u32 __user *)optval, optlen); + default: + return -ENOPROTOOPT; + } + + len = sizeof(val); + if (put_user(len, optlen) || copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +int dccp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + if (level != SOL_DCCP) + return inet_csk(sk)->icsk_af_ops->getsockopt(sk, level, + optname, optval, + optlen); + return do_dccp_getsockopt(sk, level, optname, optval, optlen); +} + +EXPORT_SYMBOL_GPL(dccp_getsockopt); + +static int dccp_msghdr_parse(struct msghdr *msg, struct sk_buff *skb) +{ + struct cmsghdr *cmsg; + + /* + * Assign an (opaque) qpolicy priority value to skb->priority. + * + * We are overloading this skb field for use with the qpolicy subystem. + * The skb->priority is normally used for the SO_PRIORITY option, which + * is initialised from sk_priority. Since the assignment of sk_priority + * to skb->priority happens later (on layer 3), we overload this field + * for use with queueing priorities as long as the skb is on layer 4. + * The default priority value (if nothing is set) is 0. + */ + skb->priority = 0; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_DCCP) + continue; + + if (cmsg->cmsg_type <= DCCP_SCM_QPOLICY_MAX && + !dccp_qpolicy_param_ok(skb->sk, cmsg->cmsg_type)) + return -EINVAL; + + switch (cmsg->cmsg_type) { + case DCCP_SCM_PRIORITY: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(__u32))) + return -EINVAL; + skb->priority = *(__u32 *)CMSG_DATA(cmsg); + break; + default: + return -EINVAL; + } + } + return 0; +} + +int dccp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + const struct dccp_sock *dp = dccp_sk(sk); + const int flags = msg->msg_flags; + const int noblock = flags & MSG_DONTWAIT; + struct sk_buff *skb; + int rc, size; + long timeo; + + trace_dccp_probe(sk, len); + + if (len > READ_ONCE(dp->dccps_mss_cache)) + return -EMSGSIZE; + + lock_sock(sk); + + timeo = sock_sndtimeo(sk, noblock); + + /* + * We have to use sk_stream_wait_connect here to set sk_write_pending, + * so that the trick in dccp_rcv_request_sent_state_process. + */ + /* Wait for a connection to finish. */ + if ((1 << sk->sk_state) & ~(DCCPF_OPEN | DCCPF_PARTOPEN)) + if ((rc = sk_stream_wait_connect(sk, &timeo)) != 0) + goto out_release; + + size = sk->sk_prot->max_header + len; + release_sock(sk); + skb = sock_alloc_send_skb(sk, size, noblock, &rc); + lock_sock(sk); + if (skb == NULL) + goto out_release; + + if (dccp_qpolicy_full(sk)) { + rc = -EAGAIN; + goto out_discard; + } + + if (sk->sk_state == DCCP_CLOSED) { + rc = -ENOTCONN; + goto out_discard; + } + + /* We need to check dccps_mss_cache after socket is locked. */ + if (len > dp->dccps_mss_cache) { + rc = -EMSGSIZE; + goto out_discard; + } + + skb_reserve(skb, sk->sk_prot->max_header); + rc = memcpy_from_msg(skb_put(skb, len), msg, len); + if (rc != 0) + goto out_discard; + + rc = dccp_msghdr_parse(msg, skb); + if (rc != 0) + goto out_discard; + + dccp_qpolicy_push(sk, skb); + /* + * The xmit_timer is set if the TX CCID is rate-based and will expire + * when congestion control permits to release further packets into the + * network. Window-based CCIDs do not use this timer. + */ + if (!timer_pending(&dp->dccps_xmit_timer)) + dccp_write_xmit(sk); +out_release: + release_sock(sk); + return rc ? 
: len; +out_discard: + kfree_skb(skb); + goto out_release; +} + +EXPORT_SYMBOL_GPL(dccp_sendmsg); + +int dccp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) +{ + const struct dccp_hdr *dh; + long timeo; + + lock_sock(sk); + + if (sk->sk_state == DCCP_LISTEN) { + len = -ENOTCONN; + goto out; + } + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + do { + struct sk_buff *skb = skb_peek(&sk->sk_receive_queue); + + if (skb == NULL) + goto verify_sock_status; + + dh = dccp_hdr(skb); + + switch (dh->dccph_type) { + case DCCP_PKT_DATA: + case DCCP_PKT_DATAACK: + goto found_ok_skb; + + case DCCP_PKT_CLOSE: + case DCCP_PKT_CLOSEREQ: + if (!(flags & MSG_PEEK)) + dccp_finish_passive_close(sk); + fallthrough; + case DCCP_PKT_RESET: + dccp_pr_debug("found fin (%s) ok!\n", + dccp_packet_name(dh->dccph_type)); + len = 0; + goto found_fin_ok; + default: + dccp_pr_debug("packet_type=%s\n", + dccp_packet_name(dh->dccph_type)); + sk_eat_skb(sk, skb); + } +verify_sock_status: + if (sock_flag(sk, SOCK_DONE)) { + len = 0; + break; + } + + if (sk->sk_err) { + len = sock_error(sk); + break; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) { + len = 0; + break; + } + + if (sk->sk_state == DCCP_CLOSED) { + if (!sock_flag(sk, SOCK_DONE)) { + /* This occurs when user tries to read + * from never connected socket. + */ + len = -ENOTCONN; + break; + } + len = 0; + break; + } + + if (!timeo) { + len = -EAGAIN; + break; + } + + if (signal_pending(current)) { + len = sock_intr_errno(timeo); + break; + } + + sk_wait_data(sk, &timeo, NULL); + continue; + found_ok_skb: + if (len > skb->len) + len = skb->len; + else if (len < skb->len) + msg->msg_flags |= MSG_TRUNC; + + if (skb_copy_datagram_msg(skb, 0, msg, len)) { + /* Exception. Bailout! */ + len = -EFAULT; + break; + } + if (flags & MSG_TRUNC) + len = skb->len; + found_fin_ok: + if (!(flags & MSG_PEEK)) + sk_eat_skb(sk, skb); + break; + } while (1); +out: + release_sock(sk); + return len; +} + +EXPORT_SYMBOL_GPL(dccp_recvmsg); + +int inet_dccp_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + unsigned char old_state; + int err; + + lock_sock(sk); + + err = -EINVAL; + if (sock->state != SS_UNCONNECTED || sock->type != SOCK_DCCP) + goto out; + + old_state = sk->sk_state; + if (!((1 << old_state) & (DCCPF_CLOSED | DCCPF_LISTEN))) + goto out; + + WRITE_ONCE(sk->sk_max_ack_backlog, backlog); + /* Really, if the socket is already in listen state + * we can only allow the backlog to be adjusted. 
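+ * (sk_max_ack_backlog has already been updated above in either case.)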
+ */ + if (old_state != DCCP_LISTEN) { + struct dccp_sock *dp = dccp_sk(sk); + + dp->dccps_role = DCCP_ROLE_LISTEN; + + /* do not start to listen if feature negotiation setup fails */ + if (dccp_feat_finalise_settings(dp)) { + err = -EPROTO; + goto out; + } + + err = inet_csk_listen_start(sk); + if (err) + goto out; + } + err = 0; + +out: + release_sock(sk); + return err; +} + +EXPORT_SYMBOL_GPL(inet_dccp_listen); + +static void dccp_terminate_connection(struct sock *sk) +{ + u8 next_state = DCCP_CLOSED; + + switch (sk->sk_state) { + case DCCP_PASSIVE_CLOSE: + case DCCP_PASSIVE_CLOSEREQ: + dccp_finish_passive_close(sk); + break; + case DCCP_PARTOPEN: + dccp_pr_debug("Stop PARTOPEN timer (%p)\n", sk); + inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); + fallthrough; + case DCCP_OPEN: + dccp_send_close(sk, 1); + + if (dccp_sk(sk)->dccps_role == DCCP_ROLE_SERVER && + !dccp_sk(sk)->dccps_server_timewait) + next_state = DCCP_ACTIVE_CLOSEREQ; + else + next_state = DCCP_CLOSING; + fallthrough; + default: + dccp_set_state(sk, next_state); + } +} + +void dccp_close(struct sock *sk, long timeout) +{ + struct dccp_sock *dp = dccp_sk(sk); + struct sk_buff *skb; + u32 data_was_unread = 0; + int state; + + lock_sock(sk); + + sk->sk_shutdown = SHUTDOWN_MASK; + + if (sk->sk_state == DCCP_LISTEN) { + dccp_set_state(sk, DCCP_CLOSED); + + /* Special case. */ + inet_csk_listen_stop(sk); + + goto adjudge_to_death; + } + + sk_stop_timer(sk, &dp->dccps_xmit_timer); + + /* + * We need to flush the recv. buffs. We do this only on the + * descriptor close, not protocol-sourced closes, because the + *reader process may not have drained the data yet! + */ + while ((skb = __skb_dequeue(&sk->sk_receive_queue)) != NULL) { + data_was_unread += skb->len; + __kfree_skb(skb); + } + + /* If socket has been already reset kill it. */ + if (sk->sk_state == DCCP_CLOSED) + goto adjudge_to_death; + + if (data_was_unread) { + /* Unread data was tossed, send an appropriate Reset Code */ + DCCP_WARN("ABORT with %u bytes unread\n", data_was_unread); + dccp_send_reset(sk, DCCP_RESET_CODE_ABORTED); + dccp_set_state(sk, DCCP_CLOSED); + } else if (sock_flag(sk, SOCK_LINGER) && !sk->sk_lingertime) { + /* Check zero linger _after_ checking for unread data. */ + sk->sk_prot->disconnect(sk, 0); + } else if (sk->sk_state != DCCP_CLOSED) { + /* + * Normal connection termination. May need to wait if there are + * still packets in the TX queue that are delayed by the CCID. + */ + dccp_flush_write_queue(sk, &timeout); + dccp_terminate_connection(sk); + } + + /* + * Flush write queue. This may be necessary in several cases: + * - we have been closed by the peer but still have application data; + * - abortive termination (unread data or zero linger time), + * - normal termination but queue could not be flushed within time limit + */ + __skb_queue_purge(&sk->sk_write_queue); + + sk_stream_wait_close(sk, timeout); + +adjudge_to_death: + state = sk->sk_state; + sock_hold(sk); + sock_orphan(sk); + + /* + * It is the last release_sock in its life. It will remove backlog. + */ + release_sock(sk); + /* + * Now socket is owned by kernel and we acquire BH lock + * to finish close. No need to check for user refs. + */ + local_bh_disable(); + bh_lock_sock(sk); + WARN_ON(sock_owned_by_user(sk)); + + this_cpu_inc(dccp_orphan_count); + + /* Have we already been destroyed by a softirq or backlog? 
*/ + if (state != DCCP_CLOSED && sk->sk_state == DCCP_CLOSED) + goto out; + + if (sk->sk_state == DCCP_CLOSED) + inet_csk_destroy_sock(sk); + + /* Otherwise, socket is reprieved until protocol close. */ + +out: + bh_unlock_sock(sk); + local_bh_enable(); + sock_put(sk); +} + +EXPORT_SYMBOL_GPL(dccp_close); + +void dccp_shutdown(struct sock *sk, int how) +{ + dccp_pr_debug("called shutdown(%x)\n", how); +} + +EXPORT_SYMBOL_GPL(dccp_shutdown); + +static inline int __init dccp_mib_init(void) +{ + dccp_statistics = alloc_percpu(struct dccp_mib); + if (!dccp_statistics) + return -ENOMEM; + return 0; +} + +static inline void dccp_mib_exit(void) +{ + free_percpu(dccp_statistics); +} + +static int thash_entries; +module_param(thash_entries, int, 0444); +MODULE_PARM_DESC(thash_entries, "Number of ehash buckets"); + +#ifdef CONFIG_IP_DCCP_DEBUG +bool dccp_debug; +module_param(dccp_debug, bool, 0644); +MODULE_PARM_DESC(dccp_debug, "Enable debug messages"); + +EXPORT_SYMBOL_GPL(dccp_debug); +#endif + +static int __init dccp_init(void) +{ + unsigned long goal; + unsigned long nr_pages = totalram_pages(); + int ehash_order, bhash_order, i; + int rc; + + BUILD_BUG_ON(sizeof(struct dccp_skb_cb) > + sizeof_field(struct sk_buff, cb)); + rc = inet_hashinfo2_init_mod(&dccp_hashinfo); + if (rc) + goto out_fail; + rc = -ENOBUFS; + dccp_hashinfo.bind_bucket_cachep = + kmem_cache_create("dccp_bind_bucket", + sizeof(struct inet_bind_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL); + if (!dccp_hashinfo.bind_bucket_cachep) + goto out_free_hashinfo2; + dccp_hashinfo.bind2_bucket_cachep = + kmem_cache_create("dccp_bind2_bucket", + sizeof(struct inet_bind2_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL); + if (!dccp_hashinfo.bind2_bucket_cachep) + goto out_free_bind_bucket_cachep; + + /* + * Size and allocate the main established and bind bucket + * hash tables. + * + * The methodology is similar to that of the buffer cache. 
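+ * The goal scales with the amount of system memory and can be
+ * overridden with the thash_entries module parameter.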
+ */ + if (nr_pages >= (128 * 1024)) + goal = nr_pages >> (21 - PAGE_SHIFT); + else + goal = nr_pages >> (23 - PAGE_SHIFT); + + if (thash_entries) + goal = (thash_entries * + sizeof(struct inet_ehash_bucket)) >> PAGE_SHIFT; + for (ehash_order = 0; (1UL << ehash_order) < goal; ehash_order++) + ; + do { + unsigned long hash_size = (1UL << ehash_order) * PAGE_SIZE / + sizeof(struct inet_ehash_bucket); + + while (hash_size & (hash_size - 1)) + hash_size--; + dccp_hashinfo.ehash_mask = hash_size - 1; + dccp_hashinfo.ehash = (struct inet_ehash_bucket *) + __get_free_pages(GFP_ATOMIC|__GFP_NOWARN, ehash_order); + } while (!dccp_hashinfo.ehash && --ehash_order > 0); + + if (!dccp_hashinfo.ehash) { + DCCP_CRIT("Failed to allocate DCCP established hash table"); + goto out_free_bind2_bucket_cachep; + } + + for (i = 0; i <= dccp_hashinfo.ehash_mask; i++) + INIT_HLIST_NULLS_HEAD(&dccp_hashinfo.ehash[i].chain, i); + + if (inet_ehash_locks_alloc(&dccp_hashinfo)) + goto out_free_dccp_ehash; + + bhash_order = ehash_order; + + do { + dccp_hashinfo.bhash_size = (1UL << bhash_order) * PAGE_SIZE / + sizeof(struct inet_bind_hashbucket); + if ((dccp_hashinfo.bhash_size > (64 * 1024)) && + bhash_order > 0) + continue; + dccp_hashinfo.bhash = (struct inet_bind_hashbucket *) + __get_free_pages(GFP_ATOMIC|__GFP_NOWARN, bhash_order); + } while (!dccp_hashinfo.bhash && --bhash_order >= 0); + + if (!dccp_hashinfo.bhash) { + DCCP_CRIT("Failed to allocate DCCP bind hash table"); + goto out_free_dccp_locks; + } + + dccp_hashinfo.bhash2 = (struct inet_bind_hashbucket *) + __get_free_pages(GFP_ATOMIC | __GFP_NOWARN, bhash_order); + + if (!dccp_hashinfo.bhash2) { + DCCP_CRIT("Failed to allocate DCCP bind2 hash table"); + goto out_free_dccp_bhash; + } + + for (i = 0; i < dccp_hashinfo.bhash_size; i++) { + spin_lock_init(&dccp_hashinfo.bhash[i].lock); + INIT_HLIST_HEAD(&dccp_hashinfo.bhash[i].chain); + spin_lock_init(&dccp_hashinfo.bhash2[i].lock); + INIT_HLIST_HEAD(&dccp_hashinfo.bhash2[i].chain); + } + + dccp_hashinfo.pernet = false; + + rc = dccp_mib_init(); + if (rc) + goto out_free_dccp_bhash2; + + rc = dccp_ackvec_init(); + if (rc) + goto out_free_dccp_mib; + + rc = dccp_sysctl_init(); + if (rc) + goto out_ackvec_exit; + + rc = ccid_initialize_builtins(); + if (rc) + goto out_sysctl_exit; + + dccp_timestamping_init(); + + return 0; + +out_sysctl_exit: + dccp_sysctl_exit(); +out_ackvec_exit: + dccp_ackvec_exit(); +out_free_dccp_mib: + dccp_mib_exit(); +out_free_dccp_bhash2: + free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order); +out_free_dccp_bhash: + free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order); +out_free_dccp_locks: + inet_ehash_locks_free(&dccp_hashinfo); +out_free_dccp_ehash: + free_pages((unsigned long)dccp_hashinfo.ehash, ehash_order); +out_free_bind2_bucket_cachep: + kmem_cache_destroy(dccp_hashinfo.bind2_bucket_cachep); +out_free_bind_bucket_cachep: + kmem_cache_destroy(dccp_hashinfo.bind_bucket_cachep); +out_free_hashinfo2: + inet_hashinfo2_free_mod(&dccp_hashinfo); +out_fail: + dccp_hashinfo.bhash = NULL; + dccp_hashinfo.bhash2 = NULL; + dccp_hashinfo.ehash = NULL; + dccp_hashinfo.bind_bucket_cachep = NULL; + dccp_hashinfo.bind2_bucket_cachep = NULL; + return rc; +} + +static void __exit dccp_fini(void) +{ + int bhash_order = get_order(dccp_hashinfo.bhash_size * + sizeof(struct inet_bind_hashbucket)); + + ccid_cleanup_builtins(); + dccp_mib_exit(); + free_pages((unsigned long)dccp_hashinfo.bhash, bhash_order); + free_pages((unsigned long)dccp_hashinfo.bhash2, bhash_order); + 
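+ /* bhash and bhash2 were allocated with the same page order in
+ * dccp_init(), so one bhash_order covers both frees above.
+ */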
free_pages((unsigned long)dccp_hashinfo.ehash, + get_order((dccp_hashinfo.ehash_mask + 1) * + sizeof(struct inet_ehash_bucket))); + inet_ehash_locks_free(&dccp_hashinfo); + kmem_cache_destroy(dccp_hashinfo.bind_bucket_cachep); + dccp_ackvec_exit(); + dccp_sysctl_exit(); + inet_hashinfo2_free_mod(&dccp_hashinfo); +} + +module_init(dccp_init); +module_exit(dccp_fini); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arnaldo Carvalho de Melo "); +MODULE_DESCRIPTION("DCCP - Datagram Congestion Controlled Protocol"); diff --git a/net/dccp/qpolicy.c b/net/dccp/qpolicy.c new file mode 100644 index 000000000..5ba204ec0 --- /dev/null +++ b/net/dccp/qpolicy.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/qpolicy.c + * + * Policy-based packet dequeueing interface for DCCP. + * + * Copyright (c) 2008 Tomasz Grobelny + */ +#include "dccp.h" + +/* + * Simple Dequeueing Policy: + * If tx_qlen is different from 0, enqueue up to tx_qlen elements. + */ +static void qpolicy_simple_push(struct sock *sk, struct sk_buff *skb) +{ + skb_queue_tail(&sk->sk_write_queue, skb); +} + +static bool qpolicy_simple_full(struct sock *sk) +{ + return dccp_sk(sk)->dccps_tx_qlen && + sk->sk_write_queue.qlen >= dccp_sk(sk)->dccps_tx_qlen; +} + +static struct sk_buff *qpolicy_simple_top(struct sock *sk) +{ + return skb_peek(&sk->sk_write_queue); +} + +/* + * Priority-based Dequeueing Policy: + * If tx_qlen is different from 0 and the queue has reached its upper bound + * of tx_qlen elements, replace older packets lowest-priority-first. + */ +static struct sk_buff *qpolicy_prio_best_skb(struct sock *sk) +{ + struct sk_buff *skb, *best = NULL; + + skb_queue_walk(&sk->sk_write_queue, skb) + if (best == NULL || skb->priority > best->priority) + best = skb; + return best; +} + +static struct sk_buff *qpolicy_prio_worst_skb(struct sock *sk) +{ + struct sk_buff *skb, *worst = NULL; + + skb_queue_walk(&sk->sk_write_queue, skb) + if (worst == NULL || skb->priority < worst->priority) + worst = skb; + return worst; +} + +static bool qpolicy_prio_full(struct sock *sk) +{ + if (qpolicy_simple_full(sk)) + dccp_qpolicy_drop(sk, qpolicy_prio_worst_skb(sk)); + return false; +} + +/** + * struct dccp_qpolicy_operations - TX Packet Dequeueing Interface + * @push: add a new @skb to the write queue + * @full: indicates that no more packets will be admitted + * @top: peeks at whatever the queueing policy defines as its `top' + * @params: parameter passed to policy operation + */ +struct dccp_qpolicy_operations { + void (*push) (struct sock *sk, struct sk_buff *skb); + bool (*full) (struct sock *sk); + struct sk_buff* (*top) (struct sock *sk); + __be32 params; +}; + +static struct dccp_qpolicy_operations qpol_table[DCCPQ_POLICY_MAX] = { + [DCCPQ_POLICY_SIMPLE] = { + .push = qpolicy_simple_push, + .full = qpolicy_simple_full, + .top = qpolicy_simple_top, + .params = 0, + }, + [DCCPQ_POLICY_PRIO] = { + .push = qpolicy_simple_push, + .full = qpolicy_prio_full, + .top = qpolicy_prio_best_skb, + .params = DCCP_SCM_PRIORITY, + }, +}; + +/* + * Externally visible interface + */ +void dccp_qpolicy_push(struct sock *sk, struct sk_buff *skb) +{ + qpol_table[dccp_sk(sk)->dccps_qpolicy].push(sk, skb); +} + +bool dccp_qpolicy_full(struct sock *sk) +{ + return qpol_table[dccp_sk(sk)->dccps_qpolicy].full(sk); +} + +void dccp_qpolicy_drop(struct sock *sk, struct sk_buff *skb) +{ + if (skb != NULL) { + skb_unlink(skb, &sk->sk_write_queue); + kfree_skb(skb); + } +} + +struct sk_buff *dccp_qpolicy_top(struct sock *sk) +{ + return 
qpol_table[dccp_sk(sk)->dccps_qpolicy].top(sk); +} + +struct sk_buff *dccp_qpolicy_pop(struct sock *sk) +{ + struct sk_buff *skb = dccp_qpolicy_top(sk); + + if (skb != NULL) { + /* Clear any skb fields that we used internally */ + skb->priority = 0; + skb_unlink(skb, &sk->sk_write_queue); + } + return skb; +} + +bool dccp_qpolicy_param_ok(struct sock *sk, __be32 param) +{ + /* check if exactly one bit is set */ + if (!param || (param & (param - 1))) + return false; + return (qpol_table[dccp_sk(sk)->dccps_qpolicy].params & param) == param; +} diff --git a/net/dccp/sysctl.c b/net/dccp/sysctl.c new file mode 100644 index 000000000..ee8d4f5af --- /dev/null +++ b/net/dccp/sysctl.c @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/dccp/sysctl.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include "dccp.h" +#include "feat.h" + +#ifndef CONFIG_SYSCTL +#error This file should not be compiled without CONFIG_SYSCTL defined +#endif + +/* Boundary values */ +static int u8_max = 0xFF; +static unsigned long seqw_min = DCCPF_SEQ_WMIN, + seqw_max = 0xFFFFFFFF; /* maximum on 32 bit */ + +static struct ctl_table dccp_default_table[] = { + { + .procname = "seq_window", + .data = &sysctl_dccp_sequence_window, + .maxlen = sizeof(sysctl_dccp_sequence_window), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + .extra1 = &seqw_min, /* RFC 4340, 7.5.2 */ + .extra2 = &seqw_max, + }, + { + .procname = "rx_ccid", + .data = &sysctl_dccp_rx_ccid, + .maxlen = sizeof(sysctl_dccp_rx_ccid), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &u8_max, /* RFC 4340, 10. */ + }, + { + .procname = "tx_ccid", + .data = &sysctl_dccp_tx_ccid, + .maxlen = sizeof(sysctl_dccp_tx_ccid), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &u8_max, /* RFC 4340, 10. */ + }, + { + .procname = "request_retries", + .data = &sysctl_dccp_request_retries, + .maxlen = sizeof(sysctl_dccp_request_retries), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &u8_max, + }, + { + .procname = "retries1", + .data = &sysctl_dccp_retries1, + .maxlen = sizeof(sysctl_dccp_retries1), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &u8_max, + }, + { + .procname = "retries2", + .data = &sysctl_dccp_retries2, + .maxlen = sizeof(sysctl_dccp_retries2), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &u8_max, + }, + { + .procname = "tx_qlen", + .data = &sysctl_dccp_tx_qlen, + .maxlen = sizeof(sysctl_dccp_tx_qlen), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "sync_ratelimit", + .data = &sysctl_dccp_sync_ratelimit, + .maxlen = sizeof(sysctl_dccp_sync_ratelimit), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + + { } +}; + +static struct ctl_table_header *dccp_table_header; + +int __init dccp_sysctl_init(void) +{ + dccp_table_header = register_net_sysctl(&init_net, "net/dccp/default", + dccp_default_table); + + return dccp_table_header != NULL ? 
0 : -ENOMEM; +} + +void dccp_sysctl_exit(void) +{ + if (dccp_table_header != NULL) { + unregister_net_sysctl_table(dccp_table_header); + dccp_table_header = NULL; + } +} diff --git a/net/dccp/timer.c b/net/dccp/timer.c new file mode 100644 index 000000000..27a3b37ac --- /dev/null +++ b/net/dccp/timer.c @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dccp/timer.c + * + * An implementation of the DCCP protocol + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include + +#include "dccp.h" + +/* sysctl variables governing numbers of retransmission attempts */ +int sysctl_dccp_request_retries __read_mostly = TCP_SYN_RETRIES; +int sysctl_dccp_retries1 __read_mostly = TCP_RETR1; +int sysctl_dccp_retries2 __read_mostly = TCP_RETR2; + +static void dccp_write_err(struct sock *sk) +{ + sk->sk_err = sk->sk_err_soft ? : ETIMEDOUT; + sk_error_report(sk); + + dccp_send_reset(sk, DCCP_RESET_CODE_ABORTED); + dccp_done(sk); + __DCCP_INC_STATS(DCCP_MIB_ABORTONTIMEOUT); +} + +/* A write timeout has occurred. Process the after effects. */ +static int dccp_write_timeout(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + int retry_until; + + if (sk->sk_state == DCCP_REQUESTING || sk->sk_state == DCCP_PARTOPEN) { + if (icsk->icsk_retransmits != 0) + dst_negative_advice(sk); + retry_until = icsk->icsk_syn_retries ? + : sysctl_dccp_request_retries; + } else { + if (icsk->icsk_retransmits >= sysctl_dccp_retries1) { + /* NOTE. draft-ietf-tcpimpl-pmtud-01.txt requires pmtu + black hole detection. :-( + + It is place to make it. It is not made. I do not want + to make it. It is disguisting. It does not work in any + case. Let me to cite the same draft, which requires for + us to implement this: + + "The one security concern raised by this memo is that ICMP black holes + are often caused by over-zealous security administrators who block + all ICMP messages. It is vitally important that those who design and + deploy security systems understand the impact of strict filtering on + upper-layer protocols. The safest web site in the world is worthless + if most TCP implementations cannot transfer data from it. It would + be far nicer to have all of the black holes fixed rather than fixing + all of the TCP implementations." + + Golden words :-). + */ + + dst_negative_advice(sk); + } + + retry_until = sysctl_dccp_retries2; + /* + * FIXME: see tcp_write_timout and tcp_out_of_resources + */ + } + + if (icsk->icsk_retransmits >= retry_until) { + /* Has it gone just too far? */ + dccp_write_err(sk); + return 1; + } + return 0; +} + +/* + * The DCCP retransmit timer. + */ +static void dccp_retransmit_timer(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + /* + * More than 4MSL (8 minutes) has passed, a RESET(aborted) was + * sent, no need to retransmit, this sock is dead. + */ + if (dccp_write_timeout(sk)) + return; + + /* + * We want to know the number of packets retransmitted, not the + * total number of retransmissions of clones of original packets. + */ + if (icsk->icsk_retransmits == 0) + __DCCP_INC_STATS(DCCP_MIB_TIMEOUTS); + + if (dccp_retransmit_skb(sk) != 0) { + /* + * Retransmission failed because of local congestion, + * do not backoff. 
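+ * The counter is kept non-zero so that the retry accounting in
+ * dccp_write_timeout() still sees the failed attempt.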
+ */ + if (--icsk->icsk_retransmits == 0) + icsk->icsk_retransmits = 1; + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + min(icsk->icsk_rto, + TCP_RESOURCE_PROBE_INTERVAL), + DCCP_RTO_MAX); + return; + } + + icsk->icsk_backoff++; + + icsk->icsk_rto = min(icsk->icsk_rto << 1, DCCP_RTO_MAX); + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, icsk->icsk_rto, + DCCP_RTO_MAX); + if (icsk->icsk_retransmits > sysctl_dccp_retries1) + __sk_dst_reset(sk); +} + +static void dccp_write_timer(struct timer_list *t) +{ + struct inet_connection_sock *icsk = + from_timer(icsk, t, icsk_retransmit_timer); + struct sock *sk = &icsk->icsk_inet.sk; + int event = 0; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + /* Try again later */ + sk_reset_timer(sk, &icsk->icsk_retransmit_timer, + jiffies + (HZ / 20)); + goto out; + } + + if (sk->sk_state == DCCP_CLOSED || !icsk->icsk_pending) + goto out; + + if (time_after(icsk->icsk_timeout, jiffies)) { + sk_reset_timer(sk, &icsk->icsk_retransmit_timer, + icsk->icsk_timeout); + goto out; + } + + event = icsk->icsk_pending; + icsk->icsk_pending = 0; + + switch (event) { + case ICSK_TIME_RETRANS: + dccp_retransmit_timer(sk); + break; + } +out: + bh_unlock_sock(sk); + sock_put(sk); +} + +static void dccp_keepalive_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + pr_err("dccp should not use a keepalive timer !\n"); + sock_put(sk); +} + +/* This is the same as tcp_delack_timer, sans prequeue & mem_reclaim stuff */ +static void dccp_delack_timer(struct timer_list *t) +{ + struct inet_connection_sock *icsk = + from_timer(icsk, t, icsk_delack_timer); + struct sock *sk = &icsk->icsk_inet.sk; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + /* Try again later. */ + __NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOCKED); + sk_reset_timer(sk, &icsk->icsk_delack_timer, + jiffies + TCP_DELACK_MIN); + goto out; + } + + if (sk->sk_state == DCCP_CLOSED || + !(icsk->icsk_ack.pending & ICSK_ACK_TIMER)) + goto out; + if (time_after(icsk->icsk_ack.timeout, jiffies)) { + sk_reset_timer(sk, &icsk->icsk_delack_timer, + icsk->icsk_ack.timeout); + goto out; + } + + icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER; + + if (inet_csk_ack_scheduled(sk)) { + if (!inet_csk_in_pingpong_mode(sk)) { + /* Delayed ACK missed: inflate ATO. */ + icsk->icsk_ack.ato = min(icsk->icsk_ack.ato << 1, + icsk->icsk_rto); + } else { + /* Delayed ACK missed: leave pingpong mode and + * deflate ATO. + */ + inet_csk_exit_pingpong_mode(sk); + icsk->icsk_ack.ato = TCP_ATO_MIN; + } + dccp_send_ack(sk); + __NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKS); + } +out: + bh_unlock_sock(sk); + sock_put(sk); +} + +/** + * dccp_write_xmitlet - Workhorse for CCID packet dequeueing interface + * @t: pointer to the tasklet associated with this handler + * + * See the comments above %ccid_dequeueing_decision for supported modes. 
+ */ +static void dccp_write_xmitlet(struct tasklet_struct *t) +{ + struct dccp_sock *dp = from_tasklet(dp, t, dccps_xmitlet); + struct sock *sk = &dp->dccps_inet_connection.icsk_inet.sk; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) + sk_reset_timer(sk, &dccp_sk(sk)->dccps_xmit_timer, jiffies + 1); + else + dccp_write_xmit(sk); + bh_unlock_sock(sk); + sock_put(sk); +} + +static void dccp_write_xmit_timer(struct timer_list *t) +{ + struct dccp_sock *dp = from_timer(dp, t, dccps_xmit_timer); + + dccp_write_xmitlet(&dp->dccps_xmitlet); +} + +void dccp_init_xmit_timers(struct sock *sk) +{ + struct dccp_sock *dp = dccp_sk(sk); + + tasklet_setup(&dp->dccps_xmitlet, dccp_write_xmitlet); + timer_setup(&dp->dccps_xmit_timer, dccp_write_xmit_timer, 0); + inet_csk_init_xmit_timers(sk, &dccp_write_timer, &dccp_delack_timer, + &dccp_keepalive_timer); +} + +static ktime_t dccp_timestamp_seed; +/** + * dccp_timestamp - 10s of microseconds time source + * Returns the number of 10s of microseconds since loading DCCP. This is native + * DCCP time difference format (RFC 4340, sec. 13). + * Please note: This will wrap around about circa every 11.9 hours. + */ +u32 dccp_timestamp(void) +{ + u64 delta = (u64)ktime_us_delta(ktime_get_real(), dccp_timestamp_seed); + + do_div(delta, 10); + return delta; +} +EXPORT_SYMBOL_GPL(dccp_timestamp); + +void __init dccp_timestamping_init(void) +{ + dccp_timestamp_seed = ktime_get_real(); +} diff --git a/net/dccp/trace.h b/net/dccp/trace.h new file mode 100644 index 000000000..5a43b3508 --- /dev/null +++ b/net/dccp/trace.h @@ -0,0 +1,82 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#undef TRACE_SYSTEM +#define TRACE_SYSTEM dccp + +#if !defined(_TRACE_DCCP_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_DCCP_H + +#include +#include "dccp.h" +#include "ccids/ccid3.h" +#include +#include + +TRACE_EVENT(dccp_probe, + + TP_PROTO(struct sock *sk, size_t size), + + TP_ARGS(sk, size), + + TP_STRUCT__entry( + /* sockaddr_in6 is always bigger than sockaddr_in */ + __array(__u8, saddr, sizeof(struct sockaddr_in6)) + __array(__u8, daddr, sizeof(struct sockaddr_in6)) + __field(__u16, sport) + __field(__u16, dport) + __field(__u16, size) + __field(__u16, tx_s) + __field(__u32, tx_rtt) + __field(__u32, tx_p) + __field(__u32, tx_x_calc) + __field(__u64, tx_x_recv) + __field(__u64, tx_x) + __field(__u32, tx_t_ipi) + ), + + TP_fast_assign( + const struct inet_sock *inet = inet_sk(sk); + struct ccid3_hc_tx_sock *hc = NULL; + + if (ccid_get_current_tx_ccid(dccp_sk(sk)) == DCCPC_CCID3) + hc = ccid3_hc_tx_sk(sk); + + memset(__entry->saddr, 0, sizeof(struct sockaddr_in6)); + memset(__entry->daddr, 0, sizeof(struct sockaddr_in6)); + + TP_STORE_ADDR_PORTS(__entry, inet, sk); + + /* For filtering use */ + __entry->sport = ntohs(inet->inet_sport); + __entry->dport = ntohs(inet->inet_dport); + + __entry->size = size; + if (hc) { + __entry->tx_s = hc->tx_s; + __entry->tx_rtt = hc->tx_rtt; + __entry->tx_p = hc->tx_p; + __entry->tx_x_calc = hc->tx_x_calc; + __entry->tx_x_recv = hc->tx_x_recv >> 6; + __entry->tx_x = hc->tx_x >> 6; + __entry->tx_t_ipi = hc->tx_t_ipi; + } else { + __entry->tx_s = 0; + memset_startat(__entry, 0, tx_rtt); + } + ), + + TP_printk("src=%pISpc dest=%pISpc size=%d tx_s=%d tx_rtt=%d " + "tx_p=%d tx_x_calc=%u tx_x_recv=%llu tx_x=%llu tx_t_ipi=%d", + __entry->saddr, __entry->daddr, __entry->size, + __entry->tx_s, __entry->tx_rtt, __entry->tx_p, + __entry->tx_x_calc, __entry->tx_x_recv, __entry->tx_x, + __entry->tx_t_ipi) +); + +#endif /* _TRACE_TCP_H */ + +/* This part 
must be outside protection */ +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/devlink/Makefile b/net/devlink/Makefile new file mode 100644 index 000000000..3a60959f7 --- /dev/null +++ b/net/devlink/Makefile @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: GPL-2.0 + +obj-y := leftover.o diff --git a/net/devlink/leftover.c b/net/devlink/leftover.c new file mode 100644 index 000000000..032c7af06 --- /dev/null +++ b/net/devlink/leftover.c @@ -0,0 +1,12550 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/devlink.c - Network physical/parent device Netlink interface + * + * Heavily inspired by net/wireless/ + * Copyright (c) 2016 Mellanox Technologies. All rights reserved. + * Copyright (c) 2016 Jiri Pirko + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define CREATE_TRACE_POINTS +#include + +#define DEVLINK_RELOAD_STATS_ARRAY_SIZE \ + (__DEVLINK_RELOAD_LIMIT_MAX * __DEVLINK_RELOAD_ACTION_MAX) + +struct devlink_dev_stats { + u32 reload_stats[DEVLINK_RELOAD_STATS_ARRAY_SIZE]; + u32 remote_reload_stats[DEVLINK_RELOAD_STATS_ARRAY_SIZE]; +}; + +struct devlink { + u32 index; + struct list_head port_list; + struct list_head rate_list; + struct list_head sb_list; + struct list_head dpipe_table_list; + struct list_head resource_list; + struct list_head param_list; + struct list_head region_list; + struct list_head reporter_list; + struct mutex reporters_lock; /* protects reporter_list */ + struct devlink_dpipe_headers *dpipe_headers; + struct list_head trap_list; + struct list_head trap_group_list; + struct list_head trap_policer_list; + struct list_head linecard_list; + struct mutex linecards_lock; /* protects linecard_list */ + const struct devlink_ops *ops; + u64 features; + struct xarray snapshot_ids; + struct devlink_dev_stats stats; + struct device *dev; + possible_net_t _net; + /* Serializes access to devlink instance specific objects such as + * port, sb, dpipe, resource, params, region, traps and more. 
+ */ + struct mutex lock; + struct lock_class_key lock_key; + u8 reload_failed:1; + refcount_t refcount; + struct completion comp; + struct rcu_head rcu; + char priv[] __aligned(NETDEV_ALIGN); +}; + +struct devlink_linecard_ops; +struct devlink_linecard_type; + +struct devlink_linecard { + struct list_head list; + struct devlink *devlink; + unsigned int index; + refcount_t refcount; + const struct devlink_linecard_ops *ops; + void *priv; + enum devlink_linecard_state state; + struct mutex state_lock; /* Protects state */ + const char *type; + struct devlink_linecard_type *types; + unsigned int types_count; + struct devlink *nested_devlink; +}; + +/** + * struct devlink_resource - devlink resource + * @name: name of the resource + * @id: id, per devlink instance + * @size: size of the resource + * @size_new: updated size of the resource, reload is needed + * @size_valid: valid in case the total size of the resource is valid + * including its children + * @parent: parent resource + * @size_params: size parameters + * @list: parent list + * @resource_list: list of child resources + * @occ_get: occupancy getter callback + * @occ_get_priv: occupancy getter callback priv + */ +struct devlink_resource { + const char *name; + u64 id; + u64 size; + u64 size_new; + bool size_valid; + struct devlink_resource *parent; + struct devlink_resource_size_params size_params; + struct list_head list; + struct list_head resource_list; + devlink_resource_occ_get_t *occ_get; + void *occ_get_priv; +}; + +void *devlink_priv(struct devlink *devlink) +{ + return &devlink->priv; +} +EXPORT_SYMBOL_GPL(devlink_priv); + +struct devlink *priv_to_devlink(void *priv) +{ + return container_of(priv, struct devlink, priv); +} +EXPORT_SYMBOL_GPL(priv_to_devlink); + +struct device *devlink_to_dev(const struct devlink *devlink) +{ + return devlink->dev; +} +EXPORT_SYMBOL_GPL(devlink_to_dev); + +static struct devlink_dpipe_field devlink_dpipe_fields_ethernet[] = { + { + .name = "destination mac", + .id = DEVLINK_DPIPE_FIELD_ETHERNET_DST_MAC, + .bitwidth = 48, + }, +}; + +struct devlink_dpipe_header devlink_dpipe_header_ethernet = { + .name = "ethernet", + .id = DEVLINK_DPIPE_HEADER_ETHERNET, + .fields = devlink_dpipe_fields_ethernet, + .fields_count = ARRAY_SIZE(devlink_dpipe_fields_ethernet), + .global = true, +}; +EXPORT_SYMBOL_GPL(devlink_dpipe_header_ethernet); + +static struct devlink_dpipe_field devlink_dpipe_fields_ipv4[] = { + { + .name = "destination ip", + .id = DEVLINK_DPIPE_FIELD_IPV4_DST_IP, + .bitwidth = 32, + }, +}; + +struct devlink_dpipe_header devlink_dpipe_header_ipv4 = { + .name = "ipv4", + .id = DEVLINK_DPIPE_HEADER_IPV4, + .fields = devlink_dpipe_fields_ipv4, + .fields_count = ARRAY_SIZE(devlink_dpipe_fields_ipv4), + .global = true, +}; +EXPORT_SYMBOL_GPL(devlink_dpipe_header_ipv4); + +static struct devlink_dpipe_field devlink_dpipe_fields_ipv6[] = { + { + .name = "destination ip", + .id = DEVLINK_DPIPE_FIELD_IPV6_DST_IP, + .bitwidth = 128, + }, +}; + +struct devlink_dpipe_header devlink_dpipe_header_ipv6 = { + .name = "ipv6", + .id = DEVLINK_DPIPE_HEADER_IPV6, + .fields = devlink_dpipe_fields_ipv6, + .fields_count = ARRAY_SIZE(devlink_dpipe_fields_ipv6), + .global = true, +}; +EXPORT_SYMBOL_GPL(devlink_dpipe_header_ipv6); + +EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwmsg); +EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_hwerr); +EXPORT_TRACEPOINT_SYMBOL_GPL(devlink_trap_report); + +static const struct nla_policy devlink_function_nl_policy[DEVLINK_PORT_FUNCTION_ATTR_MAX + 1] = { + [DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR] = { 
.type = NLA_BINARY }, + [DEVLINK_PORT_FN_ATTR_STATE] = + NLA_POLICY_RANGE(NLA_U8, DEVLINK_PORT_FN_STATE_INACTIVE, + DEVLINK_PORT_FN_STATE_ACTIVE), +}; + +static const struct nla_policy devlink_selftest_nl_policy[DEVLINK_ATTR_SELFTEST_ID_MAX + 1] = { + [DEVLINK_ATTR_SELFTEST_ID_FLASH] = { .type = NLA_FLAG }, +}; + +static DEFINE_XARRAY_FLAGS(devlinks, XA_FLAGS_ALLOC); +#define DEVLINK_REGISTERED XA_MARK_1 +#define DEVLINK_UNREGISTERING XA_MARK_2 + +/* devlink instances are open to the access from the user space after + * devlink_register() call. Such logical barrier allows us to have certain + * expectations related to locking. + * + * Before *_register() - we are in initialization stage and no parallel + * access possible to the devlink instance. All drivers perform that phase + * by implicitly holding device_lock. + * + * After *_register() - users and driver can access devlink instance at + * the same time. + */ +#define ASSERT_DEVLINK_REGISTERED(d) \ + WARN_ON_ONCE(!xa_get_mark(&devlinks, (d)->index, DEVLINK_REGISTERED)) +#define ASSERT_DEVLINK_NOT_REGISTERED(d) \ + WARN_ON_ONCE(xa_get_mark(&devlinks, (d)->index, DEVLINK_REGISTERED)) + +struct net *devlink_net(const struct devlink *devlink) +{ + return read_pnet(&devlink->_net); +} +EXPORT_SYMBOL_GPL(devlink_net); + +static void __devlink_put_rcu(struct rcu_head *head) +{ + struct devlink *devlink = container_of(head, struct devlink, rcu); + + complete(&devlink->comp); +} + +void devlink_put(struct devlink *devlink) +{ + if (refcount_dec_and_test(&devlink->refcount)) + /* Make sure unregister operation that may await the completion + * is unblocked only after all users are after the end of + * RCU grace period. + */ + call_rcu(&devlink->rcu, __devlink_put_rcu); +} + +struct devlink *__must_check devlink_try_get(struct devlink *devlink) +{ + if (refcount_inc_not_zero(&devlink->refcount)) + return devlink; + return NULL; +} + +void devl_assert_locked(struct devlink *devlink) +{ + lockdep_assert_held(&devlink->lock); +} +EXPORT_SYMBOL_GPL(devl_assert_locked); + +#ifdef CONFIG_LOCKDEP +/* For use in conjunction with LOCKDEP only e.g. rcu_dereference_protected() */ +bool devl_lock_is_held(struct devlink *devlink) +{ + return lockdep_is_held(&devlink->lock); +} +EXPORT_SYMBOL_GPL(devl_lock_is_held); +#endif + +void devl_lock(struct devlink *devlink) +{ + mutex_lock(&devlink->lock); +} +EXPORT_SYMBOL_GPL(devl_lock); + +int devl_trylock(struct devlink *devlink) +{ + return mutex_trylock(&devlink->lock); +} +EXPORT_SYMBOL_GPL(devl_trylock); + +void devl_unlock(struct devlink *devlink) +{ + mutex_unlock(&devlink->lock); +} +EXPORT_SYMBOL_GPL(devl_unlock); + +static struct devlink * +devlinks_xa_find_get(struct net *net, unsigned long *indexp, xa_mark_t filter, + void * (*xa_find_fn)(struct xarray *, unsigned long *, + unsigned long, xa_mark_t)) +{ + struct devlink *devlink; + + rcu_read_lock(); +retry: + devlink = xa_find_fn(&devlinks, indexp, ULONG_MAX, DEVLINK_REGISTERED); + if (!devlink) + goto unlock; + + /* In case devlink_unregister() was already called and "unregistering" + * mark was set, do not allow to get a devlink reference here. + * This prevents live-lock of devlink_unregister() wait for completion. 
+ */ + if (xa_get_mark(&devlinks, *indexp, DEVLINK_UNREGISTERING)) + goto retry; + + /* For a possible retry, the xa_find_after() should be always used */ + xa_find_fn = xa_find_after; + if (!devlink_try_get(devlink)) + goto retry; + if (!net_eq(devlink_net(devlink), net)) { + devlink_put(devlink); + goto retry; + } +unlock: + rcu_read_unlock(); + return devlink; +} + +static struct devlink *devlinks_xa_find_get_first(struct net *net, + unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(net, indexp, filter, xa_find); +} + +static struct devlink *devlinks_xa_find_get_next(struct net *net, + unsigned long *indexp, + xa_mark_t filter) +{ + return devlinks_xa_find_get(net, indexp, filter, xa_find_after); +} + +/* Iterate over devlink pointers which were possible to get reference to. + * devlink_put() needs to be called for each iterated devlink pointer + * in loop body in order to release the reference. + */ +#define devlinks_xa_for_each_get(net, index, devlink, filter) \ + for (index = 0, \ + devlink = devlinks_xa_find_get_first(net, &index, filter); \ + devlink; devlink = devlinks_xa_find_get_next(net, &index, filter)) + +#define devlinks_xa_for_each_registered_get(net, index, devlink) \ + devlinks_xa_for_each_get(net, index, devlink, DEVLINK_REGISTERED) + +static struct devlink *devlink_get_from_attrs(struct net *net, + struct nlattr **attrs) +{ + struct devlink *devlink; + unsigned long index; + char *busname; + char *devname; + + if (!attrs[DEVLINK_ATTR_BUS_NAME] || !attrs[DEVLINK_ATTR_DEV_NAME]) + return ERR_PTR(-EINVAL); + + busname = nla_data(attrs[DEVLINK_ATTR_BUS_NAME]); + devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]); + + devlinks_xa_for_each_registered_get(net, index, devlink) { + if (strcmp(devlink->dev->bus->name, busname) == 0 && + strcmp(dev_name(devlink->dev), devname) == 0) + return devlink; + devlink_put(devlink); + } + + return ERR_PTR(-ENODEV); +} + +#define ASSERT_DEVLINK_PORT_REGISTERED(devlink_port) \ + WARN_ON_ONCE(!(devlink_port)->registered) +#define ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port) \ + WARN_ON_ONCE((devlink_port)->registered) +#define ASSERT_DEVLINK_PORT_INITIALIZED(devlink_port) \ + WARN_ON_ONCE(!(devlink_port)->initialized) + +static struct devlink_port *devlink_port_get_by_index(struct devlink *devlink, + unsigned int port_index) +{ + struct devlink_port *devlink_port; + + list_for_each_entry(devlink_port, &devlink->port_list, list) { + if (devlink_port->index == port_index) + return devlink_port; + } + return NULL; +} + +static bool devlink_port_index_exists(struct devlink *devlink, + unsigned int port_index) +{ + return devlink_port_get_by_index(devlink, port_index); +} + +static struct devlink_port *devlink_port_get_from_attrs(struct devlink *devlink, + struct nlattr **attrs) +{ + if (attrs[DEVLINK_ATTR_PORT_INDEX]) { + u32 port_index = nla_get_u32(attrs[DEVLINK_ATTR_PORT_INDEX]); + struct devlink_port *devlink_port; + + devlink_port = devlink_port_get_by_index(devlink, port_index); + if (!devlink_port) + return ERR_PTR(-ENODEV); + return devlink_port; + } + return ERR_PTR(-EINVAL); +} + +static struct devlink_port *devlink_port_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + return devlink_port_get_from_attrs(devlink, info->attrs); +} + +static inline bool +devlink_rate_is_leaf(struct devlink_rate *devlink_rate) +{ + return devlink_rate->type == DEVLINK_RATE_TYPE_LEAF; +} + +static inline bool +devlink_rate_is_node(struct devlink_rate *devlink_rate) +{ + return devlink_rate->type == 
DEVLINK_RATE_TYPE_NODE; +} + +static struct devlink_rate * +devlink_rate_leaf_get_from_info(struct devlink *devlink, struct genl_info *info) +{ + struct devlink_rate *devlink_rate; + struct devlink_port *devlink_port; + + devlink_port = devlink_port_get_from_attrs(devlink, info->attrs); + if (IS_ERR(devlink_port)) + return ERR_CAST(devlink_port); + devlink_rate = devlink_port->devlink_rate; + return devlink_rate ?: ERR_PTR(-ENODEV); +} + +static struct devlink_rate * +devlink_rate_node_get_by_name(struct devlink *devlink, const char *node_name) +{ + static struct devlink_rate *devlink_rate; + + list_for_each_entry(devlink_rate, &devlink->rate_list, list) { + if (devlink_rate_is_node(devlink_rate) && + !strcmp(node_name, devlink_rate->name)) + return devlink_rate; + } + return ERR_PTR(-ENODEV); +} + +static struct devlink_rate * +devlink_rate_node_get_from_attrs(struct devlink *devlink, struct nlattr **attrs) +{ + const char *rate_node_name; + size_t len; + + if (!attrs[DEVLINK_ATTR_RATE_NODE_NAME]) + return ERR_PTR(-EINVAL); + rate_node_name = nla_data(attrs[DEVLINK_ATTR_RATE_NODE_NAME]); + len = strlen(rate_node_name); + /* Name cannot be empty or decimal number */ + if (!len || strspn(rate_node_name, "0123456789") == len) + return ERR_PTR(-EINVAL); + + return devlink_rate_node_get_by_name(devlink, rate_node_name); +} + +static struct devlink_rate * +devlink_rate_node_get_from_info(struct devlink *devlink, struct genl_info *info) +{ + return devlink_rate_node_get_from_attrs(devlink, info->attrs); +} + +static struct devlink_rate * +devlink_rate_get_from_info(struct devlink *devlink, struct genl_info *info) +{ + struct nlattr **attrs = info->attrs; + + if (attrs[DEVLINK_ATTR_PORT_INDEX]) + return devlink_rate_leaf_get_from_info(devlink, info); + else if (attrs[DEVLINK_ATTR_RATE_NODE_NAME]) + return devlink_rate_node_get_from_info(devlink, info); + else + return ERR_PTR(-EINVAL); +} + +static struct devlink_linecard * +devlink_linecard_get_by_index(struct devlink *devlink, + unsigned int linecard_index) +{ + struct devlink_linecard *devlink_linecard; + + list_for_each_entry(devlink_linecard, &devlink->linecard_list, list) { + if (devlink_linecard->index == linecard_index) + return devlink_linecard; + } + return NULL; +} + +static bool devlink_linecard_index_exists(struct devlink *devlink, + unsigned int linecard_index) +{ + return devlink_linecard_get_by_index(devlink, linecard_index); +} + +static struct devlink_linecard * +devlink_linecard_get_from_attrs(struct devlink *devlink, struct nlattr **attrs) +{ + if (attrs[DEVLINK_ATTR_LINECARD_INDEX]) { + u32 linecard_index = nla_get_u32(attrs[DEVLINK_ATTR_LINECARD_INDEX]); + struct devlink_linecard *linecard; + + mutex_lock(&devlink->linecards_lock); + linecard = devlink_linecard_get_by_index(devlink, linecard_index); + if (linecard) + refcount_inc(&linecard->refcount); + mutex_unlock(&devlink->linecards_lock); + if (!linecard) + return ERR_PTR(-ENODEV); + return linecard; + } + return ERR_PTR(-EINVAL); +} + +static struct devlink_linecard * +devlink_linecard_get_from_info(struct devlink *devlink, struct genl_info *info) +{ + return devlink_linecard_get_from_attrs(devlink, info->attrs); +} + +static void devlink_linecard_put(struct devlink_linecard *linecard) +{ + if (refcount_dec_and_test(&linecard->refcount)) { + mutex_destroy(&linecard->state_lock); + kfree(linecard); + } +} + +struct devlink_sb { + struct list_head list; + unsigned int index; + u32 size; + u16 ingress_pools_count; + u16 egress_pools_count; + u16 ingress_tc_count; + u16 
egress_tc_count; +}; + +static u16 devlink_sb_pool_count(struct devlink_sb *devlink_sb) +{ + return devlink_sb->ingress_pools_count + devlink_sb->egress_pools_count; +} + +static struct devlink_sb *devlink_sb_get_by_index(struct devlink *devlink, + unsigned int sb_index) +{ + struct devlink_sb *devlink_sb; + + list_for_each_entry(devlink_sb, &devlink->sb_list, list) { + if (devlink_sb->index == sb_index) + return devlink_sb; + } + return NULL; +} + +static bool devlink_sb_index_exists(struct devlink *devlink, + unsigned int sb_index) +{ + return devlink_sb_get_by_index(devlink, sb_index); +} + +static struct devlink_sb *devlink_sb_get_from_attrs(struct devlink *devlink, + struct nlattr **attrs) +{ + if (attrs[DEVLINK_ATTR_SB_INDEX]) { + u32 sb_index = nla_get_u32(attrs[DEVLINK_ATTR_SB_INDEX]); + struct devlink_sb *devlink_sb; + + devlink_sb = devlink_sb_get_by_index(devlink, sb_index); + if (!devlink_sb) + return ERR_PTR(-ENODEV); + return devlink_sb; + } + return ERR_PTR(-EINVAL); +} + +static struct devlink_sb *devlink_sb_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + return devlink_sb_get_from_attrs(devlink, info->attrs); +} + +static int devlink_sb_pool_index_get_from_attrs(struct devlink_sb *devlink_sb, + struct nlattr **attrs, + u16 *p_pool_index) +{ + u16 val; + + if (!attrs[DEVLINK_ATTR_SB_POOL_INDEX]) + return -EINVAL; + + val = nla_get_u16(attrs[DEVLINK_ATTR_SB_POOL_INDEX]); + if (val >= devlink_sb_pool_count(devlink_sb)) + return -EINVAL; + *p_pool_index = val; + return 0; +} + +static int devlink_sb_pool_index_get_from_info(struct devlink_sb *devlink_sb, + struct genl_info *info, + u16 *p_pool_index) +{ + return devlink_sb_pool_index_get_from_attrs(devlink_sb, info->attrs, + p_pool_index); +} + +static int +devlink_sb_pool_type_get_from_attrs(struct nlattr **attrs, + enum devlink_sb_pool_type *p_pool_type) +{ + u8 val; + + if (!attrs[DEVLINK_ATTR_SB_POOL_TYPE]) + return -EINVAL; + + val = nla_get_u8(attrs[DEVLINK_ATTR_SB_POOL_TYPE]); + if (val != DEVLINK_SB_POOL_TYPE_INGRESS && + val != DEVLINK_SB_POOL_TYPE_EGRESS) + return -EINVAL; + *p_pool_type = val; + return 0; +} + +static int +devlink_sb_pool_type_get_from_info(struct genl_info *info, + enum devlink_sb_pool_type *p_pool_type) +{ + return devlink_sb_pool_type_get_from_attrs(info->attrs, p_pool_type); +} + +static int +devlink_sb_th_type_get_from_attrs(struct nlattr **attrs, + enum devlink_sb_threshold_type *p_th_type) +{ + u8 val; + + if (!attrs[DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE]) + return -EINVAL; + + val = nla_get_u8(attrs[DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE]); + if (val != DEVLINK_SB_THRESHOLD_TYPE_STATIC && + val != DEVLINK_SB_THRESHOLD_TYPE_DYNAMIC) + return -EINVAL; + *p_th_type = val; + return 0; +} + +static int +devlink_sb_th_type_get_from_info(struct genl_info *info, + enum devlink_sb_threshold_type *p_th_type) +{ + return devlink_sb_th_type_get_from_attrs(info->attrs, p_th_type); +} + +static int +devlink_sb_tc_index_get_from_attrs(struct devlink_sb *devlink_sb, + struct nlattr **attrs, + enum devlink_sb_pool_type pool_type, + u16 *p_tc_index) +{ + u16 val; + + if (!attrs[DEVLINK_ATTR_SB_TC_INDEX]) + return -EINVAL; + + val = nla_get_u16(attrs[DEVLINK_ATTR_SB_TC_INDEX]); + if (pool_type == DEVLINK_SB_POOL_TYPE_INGRESS && + val >= devlink_sb->ingress_tc_count) + return -EINVAL; + if (pool_type == DEVLINK_SB_POOL_TYPE_EGRESS && + val >= devlink_sb->egress_tc_count) + return -EINVAL; + *p_tc_index = val; + return 0; +} + +static int +devlink_sb_tc_index_get_from_info(struct devlink_sb 
*devlink_sb, + struct genl_info *info, + enum devlink_sb_pool_type pool_type, + u16 *p_tc_index) +{ + return devlink_sb_tc_index_get_from_attrs(devlink_sb, info->attrs, + pool_type, p_tc_index); +} + +struct devlink_region { + struct devlink *devlink; + struct devlink_port *port; + struct list_head list; + union { + const struct devlink_region_ops *ops; + const struct devlink_port_region_ops *port_ops; + }; + struct mutex snapshot_lock; /* protects snapshot_list, + * max_snapshots and cur_snapshots + * consistency. + */ + struct list_head snapshot_list; + u32 max_snapshots; + u32 cur_snapshots; + u64 size; +}; + +struct devlink_snapshot { + struct list_head list; + struct devlink_region *region; + u8 *data; + u32 id; +}; + +static struct devlink_region * +devlink_region_get_by_name(struct devlink *devlink, const char *region_name) +{ + struct devlink_region *region; + + list_for_each_entry(region, &devlink->region_list, list) + if (!strcmp(region->ops->name, region_name)) + return region; + + return NULL; +} + +static struct devlink_region * +devlink_port_region_get_by_name(struct devlink_port *port, + const char *region_name) +{ + struct devlink_region *region; + + list_for_each_entry(region, &port->region_list, list) + if (!strcmp(region->ops->name, region_name)) + return region; + + return NULL; +} + +static struct devlink_snapshot * +devlink_region_snapshot_get_by_id(struct devlink_region *region, u32 id) +{ + struct devlink_snapshot *snapshot; + + list_for_each_entry(snapshot, ®ion->snapshot_list, list) + if (snapshot->id == id) + return snapshot; + + return NULL; +} + +#define DEVLINK_NL_FLAG_NEED_PORT BIT(0) +#define DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT BIT(1) +#define DEVLINK_NL_FLAG_NEED_RATE BIT(2) +#define DEVLINK_NL_FLAG_NEED_RATE_NODE BIT(3) +#define DEVLINK_NL_FLAG_NEED_LINECARD BIT(4) + +static int devlink_nl_pre_doit(const struct genl_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + struct devlink_linecard *linecard; + struct devlink_port *devlink_port; + struct devlink *devlink; + int err; + + devlink = devlink_get_from_attrs(genl_info_net(info), info->attrs); + if (IS_ERR(devlink)) + return PTR_ERR(devlink); + devl_lock(devlink); + info->user_ptr[0] = devlink; + if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_PORT) { + devlink_port = devlink_port_get_from_info(devlink, info); + if (IS_ERR(devlink_port)) { + err = PTR_ERR(devlink_port); + goto unlock; + } + info->user_ptr[1] = devlink_port; + } else if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT) { + devlink_port = devlink_port_get_from_info(devlink, info); + if (!IS_ERR(devlink_port)) + info->user_ptr[1] = devlink_port; + } else if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_RATE) { + struct devlink_rate *devlink_rate; + + devlink_rate = devlink_rate_get_from_info(devlink, info); + if (IS_ERR(devlink_rate)) { + err = PTR_ERR(devlink_rate); + goto unlock; + } + info->user_ptr[1] = devlink_rate; + } else if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_RATE_NODE) { + struct devlink_rate *rate_node; + + rate_node = devlink_rate_node_get_from_info(devlink, info); + if (IS_ERR(rate_node)) { + err = PTR_ERR(rate_node); + goto unlock; + } + info->user_ptr[1] = rate_node; + } else if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_LINECARD) { + linecard = devlink_linecard_get_from_info(devlink, info); + if (IS_ERR(linecard)) { + err = PTR_ERR(linecard); + goto unlock; + } + info->user_ptr[1] = linecard; + } + return 0; + +unlock: + devl_unlock(devlink); + devlink_put(devlink); + return err; +} + +static void 
devlink_nl_post_doit(const struct genl_ops *ops, + struct sk_buff *skb, struct genl_info *info) +{ + struct devlink_linecard *linecard; + struct devlink *devlink; + + devlink = info->user_ptr[0]; + if (ops->internal_flags & DEVLINK_NL_FLAG_NEED_LINECARD) { + linecard = info->user_ptr[1]; + devlink_linecard_put(linecard); + } + devl_unlock(devlink); + devlink_put(devlink); +} + +static struct genl_family devlink_nl_family; + +enum devlink_multicast_groups { + DEVLINK_MCGRP_CONFIG, +}; + +static const struct genl_multicast_group devlink_nl_mcgrps[] = { + [DEVLINK_MCGRP_CONFIG] = { .name = DEVLINK_GENL_MCGRP_CONFIG_NAME }, +}; + +static int devlink_nl_put_handle(struct sk_buff *msg, struct devlink *devlink) +{ + if (nla_put_string(msg, DEVLINK_ATTR_BUS_NAME, devlink->dev->bus->name)) + return -EMSGSIZE; + if (nla_put_string(msg, DEVLINK_ATTR_DEV_NAME, dev_name(devlink->dev))) + return -EMSGSIZE; + return 0; +} + +static int devlink_nl_put_nested_handle(struct sk_buff *msg, struct devlink *devlink) +{ + struct nlattr *nested_attr; + + nested_attr = nla_nest_start(msg, DEVLINK_ATTR_NESTED_DEVLINK); + if (!nested_attr) + return -EMSGSIZE; + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + nla_nest_end(msg, nested_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(msg, nested_attr); + return -EMSGSIZE; +} + +struct devlink_reload_combination { + enum devlink_reload_action action; + enum devlink_reload_limit limit; +}; + +static const struct devlink_reload_combination devlink_reload_invalid_combinations[] = { + { + /* can't reinitialize driver with no down time */ + .action = DEVLINK_RELOAD_ACTION_DRIVER_REINIT, + .limit = DEVLINK_RELOAD_LIMIT_NO_RESET, + }, +}; + +static bool +devlink_reload_combination_is_invalid(enum devlink_reload_action action, + enum devlink_reload_limit limit) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(devlink_reload_invalid_combinations); i++) + if (devlink_reload_invalid_combinations[i].action == action && + devlink_reload_invalid_combinations[i].limit == limit) + return true; + return false; +} + +static bool +devlink_reload_action_is_supported(struct devlink *devlink, enum devlink_reload_action action) +{ + return test_bit(action, &devlink->ops->reload_actions); +} + +static bool +devlink_reload_limit_is_supported(struct devlink *devlink, enum devlink_reload_limit limit) +{ + return test_bit(limit, &devlink->ops->reload_limits); +} + +static int devlink_reload_stat_put(struct sk_buff *msg, + enum devlink_reload_limit limit, u32 value) +{ + struct nlattr *reload_stats_entry; + + reload_stats_entry = nla_nest_start(msg, DEVLINK_ATTR_RELOAD_STATS_ENTRY); + if (!reload_stats_entry) + return -EMSGSIZE; + + if (nla_put_u8(msg, DEVLINK_ATTR_RELOAD_STATS_LIMIT, limit) || + nla_put_u32(msg, DEVLINK_ATTR_RELOAD_STATS_VALUE, value)) + goto nla_put_failure; + nla_nest_end(msg, reload_stats_entry); + return 0; + +nla_put_failure: + nla_nest_cancel(msg, reload_stats_entry); + return -EMSGSIZE; +} + +static int devlink_reload_stats_put(struct sk_buff *msg, struct devlink *devlink, bool is_remote) +{ + struct nlattr *reload_stats_attr, *act_info, *act_stats; + int i, j, stat_idx; + u32 value; + + if (!is_remote) + reload_stats_attr = nla_nest_start(msg, DEVLINK_ATTR_RELOAD_STATS); + else + reload_stats_attr = nla_nest_start(msg, DEVLINK_ATTR_REMOTE_RELOAD_STATS); + + if (!reload_stats_attr) + return -EMSGSIZE; + + for (i = 0; i <= DEVLINK_RELOAD_ACTION_MAX; i++) { + if ((!is_remote && + !devlink_reload_action_is_supported(devlink, i)) || + i == 
DEVLINK_RELOAD_ACTION_UNSPEC) + continue; + act_info = nla_nest_start(msg, DEVLINK_ATTR_RELOAD_ACTION_INFO); + if (!act_info) + goto nla_put_failure; + + if (nla_put_u8(msg, DEVLINK_ATTR_RELOAD_ACTION, i)) + goto action_info_nest_cancel; + act_stats = nla_nest_start(msg, DEVLINK_ATTR_RELOAD_ACTION_STATS); + if (!act_stats) + goto action_info_nest_cancel; + + for (j = 0; j <= DEVLINK_RELOAD_LIMIT_MAX; j++) { + /* Remote stats are shown even if not locally supported. + * Stats of actions with unspecified limit are shown + * though drivers don't need to register unspecified + * limit. + */ + if ((!is_remote && j != DEVLINK_RELOAD_LIMIT_UNSPEC && + !devlink_reload_limit_is_supported(devlink, j)) || + devlink_reload_combination_is_invalid(i, j)) + continue; + + stat_idx = j * __DEVLINK_RELOAD_ACTION_MAX + i; + if (!is_remote) + value = devlink->stats.reload_stats[stat_idx]; + else + value = devlink->stats.remote_reload_stats[stat_idx]; + if (devlink_reload_stat_put(msg, j, value)) + goto action_stats_nest_cancel; + } + nla_nest_end(msg, act_stats); + nla_nest_end(msg, act_info); + } + nla_nest_end(msg, reload_stats_attr); + return 0; + +action_stats_nest_cancel: + nla_nest_cancel(msg, act_stats); +action_info_nest_cancel: + nla_nest_cancel(msg, act_info); +nla_put_failure: + nla_nest_cancel(msg, reload_stats_attr); + return -EMSGSIZE; +} + +static int devlink_nl_fill(struct sk_buff *msg, struct devlink *devlink, + enum devlink_command cmd, u32 portid, + u32 seq, int flags) +{ + struct nlattr *dev_stats; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u8(msg, DEVLINK_ATTR_RELOAD_FAILED, devlink->reload_failed)) + goto nla_put_failure; + + dev_stats = nla_nest_start(msg, DEVLINK_ATTR_DEV_STATS); + if (!dev_stats) + goto nla_put_failure; + + if (devlink_reload_stats_put(msg, devlink, false)) + goto dev_stats_nest_cancel; + if (devlink_reload_stats_put(msg, devlink, true)) + goto dev_stats_nest_cancel; + + nla_nest_end(msg, dev_stats); + genlmsg_end(msg, hdr); + return 0; + +dev_stats_nest_cancel: + nla_nest_cancel(msg, dev_stats); +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void devlink_notify(struct devlink *devlink, enum devlink_command cmd) +{ + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_NEW && cmd != DEVLINK_CMD_DEL); + WARN_ON(!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_fill(msg, devlink, cmd, 0, 0, 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int devlink_nl_port_attrs_put(struct sk_buff *msg, + struct devlink_port *devlink_port) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + + if (!devlink_port->attrs_set) + return 0; + if (attrs->lanes) { + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_LANES, attrs->lanes)) + return -EMSGSIZE; + } + if (nla_put_u8(msg, DEVLINK_ATTR_PORT_SPLITTABLE, attrs->splittable)) + return -EMSGSIZE; + if (nla_put_u16(msg, DEVLINK_ATTR_PORT_FLAVOUR, attrs->flavour)) + return -EMSGSIZE; + switch (devlink_port->attrs.flavour) { + case DEVLINK_PORT_FLAVOUR_PCI_PF: + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_CONTROLLER_NUMBER, + attrs->pci_pf.controller) || + nla_put_u16(msg, DEVLINK_ATTR_PORT_PCI_PF_NUMBER, 
attrs->pci_pf.pf)) + return -EMSGSIZE; + if (nla_put_u8(msg, DEVLINK_ATTR_PORT_EXTERNAL, attrs->pci_pf.external)) + return -EMSGSIZE; + break; + case DEVLINK_PORT_FLAVOUR_PCI_VF: + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_CONTROLLER_NUMBER, + attrs->pci_vf.controller) || + nla_put_u16(msg, DEVLINK_ATTR_PORT_PCI_PF_NUMBER, attrs->pci_vf.pf) || + nla_put_u16(msg, DEVLINK_ATTR_PORT_PCI_VF_NUMBER, attrs->pci_vf.vf)) + return -EMSGSIZE; + if (nla_put_u8(msg, DEVLINK_ATTR_PORT_EXTERNAL, attrs->pci_vf.external)) + return -EMSGSIZE; + break; + case DEVLINK_PORT_FLAVOUR_PCI_SF: + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_CONTROLLER_NUMBER, + attrs->pci_sf.controller) || + nla_put_u16(msg, DEVLINK_ATTR_PORT_PCI_PF_NUMBER, + attrs->pci_sf.pf) || + nla_put_u32(msg, DEVLINK_ATTR_PORT_PCI_SF_NUMBER, + attrs->pci_sf.sf)) + return -EMSGSIZE; + break; + case DEVLINK_PORT_FLAVOUR_PHYSICAL: + case DEVLINK_PORT_FLAVOUR_CPU: + case DEVLINK_PORT_FLAVOUR_DSA: + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_NUMBER, + attrs->phys.port_number)) + return -EMSGSIZE; + if (!attrs->split) + return 0; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_SPLIT_GROUP, + attrs->phys.port_number)) + return -EMSGSIZE; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_SPLIT_SUBPORT_NUMBER, + attrs->phys.split_subport_number)) + return -EMSGSIZE; + break; + default: + break; + } + return 0; +} + +static int devlink_port_fn_hw_addr_fill(const struct devlink_ops *ops, + struct devlink_port *port, + struct sk_buff *msg, + struct netlink_ext_ack *extack, + bool *msg_updated) +{ + u8 hw_addr[MAX_ADDR_LEN]; + int hw_addr_len; + int err; + + if (!ops->port_function_hw_addr_get) + return 0; + + err = ops->port_function_hw_addr_get(port, hw_addr, &hw_addr_len, + extack); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + err = nla_put(msg, DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR, hw_addr_len, hw_addr); + if (err) + return err; + *msg_updated = true; + return 0; +} + +static int devlink_nl_rate_fill(struct sk_buff *msg, + struct devlink_rate *devlink_rate, + enum devlink_command cmd, u32 portid, u32 seq, + int flags, struct netlink_ext_ack *extack) +{ + struct devlink *devlink = devlink_rate->devlink; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (nla_put_u16(msg, DEVLINK_ATTR_RATE_TYPE, devlink_rate->type)) + goto nla_put_failure; + + if (devlink_rate_is_leaf(devlink_rate)) { + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, + devlink_rate->devlink_port->index)) + goto nla_put_failure; + } else if (devlink_rate_is_node(devlink_rate)) { + if (nla_put_string(msg, DEVLINK_ATTR_RATE_NODE_NAME, + devlink_rate->name)) + goto nla_put_failure; + } + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_RATE_TX_SHARE, + devlink_rate->tx_share, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_RATE_TX_MAX, + devlink_rate->tx_max, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (devlink_rate->parent) + if (nla_put_string(msg, DEVLINK_ATTR_RATE_PARENT_NODE_NAME, + devlink_rate->parent->name)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static bool +devlink_port_fn_state_valid(enum devlink_port_fn_state state) +{ + return state == DEVLINK_PORT_FN_STATE_INACTIVE || + state == DEVLINK_PORT_FN_STATE_ACTIVE; +} + +static bool +devlink_port_fn_opstate_valid(enum devlink_port_fn_opstate opstate) +{ + return 
opstate == DEVLINK_PORT_FN_OPSTATE_DETACHED || + opstate == DEVLINK_PORT_FN_OPSTATE_ATTACHED; +} + +static int devlink_port_fn_state_fill(const struct devlink_ops *ops, + struct devlink_port *port, + struct sk_buff *msg, + struct netlink_ext_ack *extack, + bool *msg_updated) +{ + enum devlink_port_fn_opstate opstate; + enum devlink_port_fn_state state; + int err; + + if (!ops->port_fn_state_get) + return 0; + + err = ops->port_fn_state_get(port, &state, &opstate, extack); + if (err) { + if (err == -EOPNOTSUPP) + return 0; + return err; + } + if (!devlink_port_fn_state_valid(state)) { + WARN_ON_ONCE(1); + NL_SET_ERR_MSG_MOD(extack, "Invalid state read from driver"); + return -EINVAL; + } + if (!devlink_port_fn_opstate_valid(opstate)) { + WARN_ON_ONCE(1); + NL_SET_ERR_MSG_MOD(extack, + "Invalid operational state read from driver"); + return -EINVAL; + } + if (nla_put_u8(msg, DEVLINK_PORT_FN_ATTR_STATE, state) || + nla_put_u8(msg, DEVLINK_PORT_FN_ATTR_OPSTATE, opstate)) + return -EMSGSIZE; + *msg_updated = true; + return 0; +} + +static int +devlink_nl_port_function_attrs_put(struct sk_buff *msg, struct devlink_port *port, + struct netlink_ext_ack *extack) +{ + const struct devlink_ops *ops; + struct nlattr *function_attr; + bool msg_updated = false; + int err; + + function_attr = nla_nest_start_noflag(msg, DEVLINK_ATTR_PORT_FUNCTION); + if (!function_attr) + return -EMSGSIZE; + + ops = port->devlink->ops; + err = devlink_port_fn_hw_addr_fill(ops, port, msg, extack, + &msg_updated); + if (err) + goto out; + err = devlink_port_fn_state_fill(ops, port, msg, extack, &msg_updated); +out: + if (err || !msg_updated) + nla_nest_cancel(msg, function_attr); + else + nla_nest_end(msg, function_attr); + return err; +} + +static int devlink_nl_port_fill(struct sk_buff *msg, + struct devlink_port *devlink_port, + enum devlink_command cmd, u32 portid, u32 seq, + int flags, struct netlink_ext_ack *extack) +{ + struct devlink *devlink = devlink_port->devlink; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index)) + goto nla_put_failure; + + /* Hold rtnl lock while accessing port's netdev attributes. 
*/ + rtnl_lock(); + spin_lock_bh(&devlink_port->type_lock); + if (nla_put_u16(msg, DEVLINK_ATTR_PORT_TYPE, devlink_port->type)) + goto nla_put_failure_type_locked; + if (devlink_port->desired_type != DEVLINK_PORT_TYPE_NOTSET && + nla_put_u16(msg, DEVLINK_ATTR_PORT_DESIRED_TYPE, + devlink_port->desired_type)) + goto nla_put_failure_type_locked; + if (devlink_port->type == DEVLINK_PORT_TYPE_ETH) { + struct net *net = devlink_net(devlink_port->devlink); + struct net_device *netdev = devlink_port->type_dev; + + if (netdev && net_eq(net, dev_net(netdev)) && + (nla_put_u32(msg, DEVLINK_ATTR_PORT_NETDEV_IFINDEX, + netdev->ifindex) || + nla_put_string(msg, DEVLINK_ATTR_PORT_NETDEV_NAME, + netdev->name))) + goto nla_put_failure_type_locked; + } + if (devlink_port->type == DEVLINK_PORT_TYPE_IB) { + struct ib_device *ibdev = devlink_port->type_dev; + + if (ibdev && + nla_put_string(msg, DEVLINK_ATTR_PORT_IBDEV_NAME, + ibdev->name)) + goto nla_put_failure_type_locked; + } + spin_unlock_bh(&devlink_port->type_lock); + rtnl_unlock(); + if (devlink_nl_port_attrs_put(msg, devlink_port)) + goto nla_put_failure; + if (devlink_nl_port_function_attrs_put(msg, devlink_port, extack)) + goto nla_put_failure; + if (devlink_port->linecard && + nla_put_u32(msg, DEVLINK_ATTR_LINECARD_INDEX, + devlink_port->linecard->index)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure_type_locked: + spin_unlock_bh(&devlink_port->type_lock); + rtnl_unlock(); +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void devlink_port_notify(struct devlink_port *devlink_port, + enum devlink_command cmd) +{ + struct devlink *devlink = devlink_port->devlink; + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_PORT_NEW && cmd != DEVLINK_CMD_PORT_DEL); + + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_port_fill(msg, devlink_port, cmd, 0, 0, 0, NULL); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), msg, + 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static void devlink_rate_notify(struct devlink_rate *devlink_rate, + enum devlink_command cmd) +{ + struct devlink *devlink = devlink_rate->devlink; + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_RATE_NEW && cmd != DEVLINK_CMD_RATE_DEL); + + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_rate_fill(msg, devlink_rate, cmd, 0, 0, 0, NULL); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), msg, + 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int devlink_nl_cmd_rate_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_rate *devlink_rate; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(devlink_rate, &devlink->rate_list, list) { + enum devlink_command cmd = DEVLINK_CMD_RATE_NEW; + u32 id = NETLINK_CB(cb->skb).portid; + + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_rate_fill(msg, devlink_rate, cmd, id, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, NULL); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto 
out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_nl_cmd_rate_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_rate *devlink_rate = info->user_ptr[1]; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_rate_fill(msg, devlink_rate, DEVLINK_CMD_RATE_NEW, + info->snd_portid, info->snd_seq, 0, + info->extack); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static bool +devlink_rate_is_parent_node(struct devlink_rate *devlink_rate, + struct devlink_rate *parent) +{ + while (parent) { + if (parent == devlink_rate) + return true; + parent = parent->parent; + } + return false; +} + +static int devlink_nl_cmd_get_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_fill(msg, devlink, DEVLINK_CMD_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (idx < start) { + idx++; + devlink_put(devlink); + continue; + } + + devl_lock(devlink); + err = devlink_nl_fill(msg, devlink, DEVLINK_CMD_NEW, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI); + devl_unlock(devlink); + devlink_put(devlink); + + if (err) + goto out; + idx++; + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int devlink_nl_cmd_port_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_port_fill(msg, devlink_port, DEVLINK_CMD_PORT_NEW, + info->snd_portid, info->snd_seq, 0, + info->extack); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_port_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + struct devlink_port *devlink_port; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(devlink_port, &devlink->port_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_port_fill(msg, devlink_port, + DEVLINK_CMD_NEW, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, cb->extack); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int devlink_port_type_set(struct devlink_port *devlink_port, + enum devlink_port_type port_type) + +{ + int err; + + if (!devlink_port->devlink->ops->port_type_set) + return -EOPNOTSUPP; + + if (port_type == devlink_port->type) + return 0; + + err = devlink_port->devlink->ops->port_type_set(devlink_port, + port_type); + if (err) + 
return err; + + devlink_port->desired_type = port_type; + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); + return 0; +} + +static int devlink_port_function_hw_addr_set(struct devlink_port *port, + const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const struct devlink_ops *ops = port->devlink->ops; + const u8 *hw_addr; + int hw_addr_len; + + hw_addr = nla_data(attr); + hw_addr_len = nla_len(attr); + if (hw_addr_len > MAX_ADDR_LEN) { + NL_SET_ERR_MSG_MOD(extack, "Port function hardware address too long"); + return -EINVAL; + } + if (port->type == DEVLINK_PORT_TYPE_ETH) { + if (hw_addr_len != ETH_ALEN) { + NL_SET_ERR_MSG_MOD(extack, "Address must be 6 bytes for Ethernet device"); + return -EINVAL; + } + if (!is_unicast_ether_addr(hw_addr)) { + NL_SET_ERR_MSG_MOD(extack, "Non-unicast hardware address unsupported"); + return -EINVAL; + } + } + + if (!ops->port_function_hw_addr_set) { + NL_SET_ERR_MSG_MOD(extack, "Port doesn't support function attributes"); + return -EOPNOTSUPP; + } + + return ops->port_function_hw_addr_set(port, hw_addr, hw_addr_len, + extack); +} + +static int devlink_port_fn_state_set(struct devlink_port *port, + const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + enum devlink_port_fn_state state; + const struct devlink_ops *ops; + + state = nla_get_u8(attr); + ops = port->devlink->ops; + if (!ops->port_fn_state_set) { + NL_SET_ERR_MSG_MOD(extack, + "Function does not support state setting"); + return -EOPNOTSUPP; + } + return ops->port_fn_state_set(port, state, extack); +} + +static int devlink_port_function_set(struct devlink_port *port, + const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[DEVLINK_PORT_FUNCTION_ATTR_MAX + 1]; + int err; + + err = nla_parse_nested(tb, DEVLINK_PORT_FUNCTION_ATTR_MAX, attr, + devlink_function_nl_policy, extack); + if (err < 0) { + NL_SET_ERR_MSG_MOD(extack, "Fail to parse port function attributes"); + return err; + } + + attr = tb[DEVLINK_PORT_FUNCTION_ATTR_HW_ADDR]; + if (attr) { + err = devlink_port_function_hw_addr_set(port, attr, extack); + if (err) + return err; + } + /* Keep this as the last function attribute set, so that when + * multiple port function attributes are set along with state, + * Those can be applied first before activating the state. 
+ */ + attr = tb[DEVLINK_PORT_FN_ATTR_STATE]; + if (attr) + err = devlink_port_fn_state_set(port, attr, extack); + + if (!err) + devlink_port_notify(port, DEVLINK_CMD_PORT_NEW); + return err; +} + +static int devlink_nl_cmd_port_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + int err; + + if (info->attrs[DEVLINK_ATTR_PORT_TYPE]) { + enum devlink_port_type port_type; + + port_type = nla_get_u16(info->attrs[DEVLINK_ATTR_PORT_TYPE]); + err = devlink_port_type_set(devlink_port, port_type); + if (err) + return err; + } + + if (info->attrs[DEVLINK_ATTR_PORT_FUNCTION]) { + struct nlattr *attr = info->attrs[DEVLINK_ATTR_PORT_FUNCTION]; + struct netlink_ext_ack *extack = info->extack; + + err = devlink_port_function_set(devlink_port, attr, extack); + if (err) + return err; + } + + return 0; +} + +static int devlink_nl_cmd_port_split_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = info->user_ptr[0]; + u32 count; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_PORT_SPLIT_COUNT)) + return -EINVAL; + if (!devlink->ops->port_split) + return -EOPNOTSUPP; + + count = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_SPLIT_COUNT]); + + if (!devlink_port->attrs.splittable) { + /* Split ports cannot be split. */ + if (devlink_port->attrs.split) + NL_SET_ERR_MSG_MOD(info->extack, "Port cannot be split further"); + else + NL_SET_ERR_MSG_MOD(info->extack, "Port cannot be split"); + return -EINVAL; + } + + if (count < 2 || !is_power_of_2(count) || count > devlink_port->attrs.lanes) { + NL_SET_ERR_MSG_MOD(info->extack, "Invalid split count"); + return -EINVAL; + } + + return devlink->ops->port_split(devlink, devlink_port, count, + info->extack); +} + +static int devlink_nl_cmd_port_unsplit_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = info->user_ptr[0]; + + if (!devlink->ops->port_unsplit) + return -EOPNOTSUPP; + return devlink->ops->port_unsplit(devlink, devlink_port, info->extack); +} + +static int devlink_port_new_notify(struct devlink *devlink, + unsigned int port_index, + struct genl_info *info) +{ + struct devlink_port *devlink_port; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + lockdep_assert_held(&devlink->lock); + devlink_port = devlink_port_get_by_index(devlink, port_index); + if (!devlink_port) { + err = -ENODEV; + goto out; + } + + err = devlink_nl_port_fill(msg, devlink_port, DEVLINK_CMD_NEW, + info->snd_portid, info->snd_seq, 0, NULL); + if (err) + goto out; + + return genlmsg_reply(msg, info); + +out: + nlmsg_free(msg); + return err; +} + +static int devlink_nl_cmd_port_new_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink_port_new_attrs new_attrs = {}; + struct devlink *devlink = info->user_ptr[0]; + unsigned int new_port_index; + int err; + + if (!devlink->ops->port_new || !devlink->ops->port_del) + return -EOPNOTSUPP; + + if (!info->attrs[DEVLINK_ATTR_PORT_FLAVOUR] || + !info->attrs[DEVLINK_ATTR_PORT_PCI_PF_NUMBER]) { + NL_SET_ERR_MSG_MOD(extack, "Port flavour or PCI PF are not specified"); + return -EINVAL; + } + new_attrs.flavour = nla_get_u16(info->attrs[DEVLINK_ATTR_PORT_FLAVOUR]); + new_attrs.pfnum = + nla_get_u16(info->attrs[DEVLINK_ATTR_PORT_PCI_PF_NUMBER]); + + if (info->attrs[DEVLINK_ATTR_PORT_INDEX]) 
{ + /* Port index of the new port being created by driver. */ + new_attrs.port_index = + nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]); + new_attrs.port_index_valid = true; + } + if (info->attrs[DEVLINK_ATTR_PORT_CONTROLLER_NUMBER]) { + new_attrs.controller = + nla_get_u16(info->attrs[DEVLINK_ATTR_PORT_CONTROLLER_NUMBER]); + new_attrs.controller_valid = true; + } + if (new_attrs.flavour == DEVLINK_PORT_FLAVOUR_PCI_SF && + info->attrs[DEVLINK_ATTR_PORT_PCI_SF_NUMBER]) { + new_attrs.sfnum = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_PCI_SF_NUMBER]); + new_attrs.sfnum_valid = true; + } + + err = devlink->ops->port_new(devlink, &new_attrs, extack, + &new_port_index); + if (err) + return err; + + err = devlink_port_new_notify(devlink, new_port_index, info); + if (err && err != -ENODEV) { + /* Fail to send the response; destroy newly created port. */ + devlink->ops->port_del(devlink, new_port_index, extack); + } + return err; +} + +static int devlink_nl_cmd_port_del_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + unsigned int port_index; + + if (!devlink->ops->port_del) + return -EOPNOTSUPP; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_PORT_INDEX)) { + NL_SET_ERR_MSG_MOD(extack, "Port index is not specified"); + return -EINVAL; + } + port_index = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]); + + return devlink->ops->port_del(devlink, port_index, extack); +} + +static int +devlink_nl_rate_parent_node_set(struct devlink_rate *devlink_rate, + struct genl_info *info, + struct nlattr *nla_parent) +{ + struct devlink *devlink = devlink_rate->devlink; + const char *parent_name = nla_data(nla_parent); + const struct devlink_ops *ops = devlink->ops; + size_t len = strlen(parent_name); + struct devlink_rate *parent; + int err = -EOPNOTSUPP; + + parent = devlink_rate->parent; + if (parent && len) { + NL_SET_ERR_MSG_MOD(info->extack, "Rate object already has parent."); + return -EBUSY; + } else if (parent && !len) { + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_parent_set(devlink_rate, NULL, + devlink_rate->priv, NULL, + info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_parent_set(devlink_rate, NULL, + devlink_rate->priv, NULL, + info->extack); + if (err) + return err; + + refcount_dec(&parent->refcnt); + devlink_rate->parent = NULL; + } else if (!parent && len) { + parent = devlink_rate_node_get_by_name(devlink, parent_name); + if (IS_ERR(parent)) + return -ENODEV; + + if (parent == devlink_rate) { + NL_SET_ERR_MSG_MOD(info->extack, "Parent to self is not allowed"); + return -EINVAL; + } + + if (devlink_rate_is_node(devlink_rate) && + devlink_rate_is_parent_node(devlink_rate, parent->parent)) { + NL_SET_ERR_MSG_MOD(info->extack, "Node is already a parent of parent node."); + return -EEXIST; + } + + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_parent_set(devlink_rate, parent, + devlink_rate->priv, parent->priv, + info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_parent_set(devlink_rate, parent, + devlink_rate->priv, parent->priv, + info->extack); + if (err) + return err; + + refcount_inc(&parent->refcnt); + devlink_rate->parent = parent; + } + + return 0; +} + +static int devlink_nl_rate_set(struct devlink_rate *devlink_rate, + const struct devlink_ops *ops, + struct genl_info *info) +{ + struct nlattr *nla_parent, **attrs = info->attrs; + int err = -EOPNOTSUPP; + u64 rate; + + if 
(attrs[DEVLINK_ATTR_RATE_TX_SHARE]) { + rate = nla_get_u64(attrs[DEVLINK_ATTR_RATE_TX_SHARE]); + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_tx_share_set(devlink_rate, devlink_rate->priv, + rate, info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_tx_share_set(devlink_rate, devlink_rate->priv, + rate, info->extack); + if (err) + return err; + devlink_rate->tx_share = rate; + } + + if (attrs[DEVLINK_ATTR_RATE_TX_MAX]) { + rate = nla_get_u64(attrs[DEVLINK_ATTR_RATE_TX_MAX]); + if (devlink_rate_is_leaf(devlink_rate)) + err = ops->rate_leaf_tx_max_set(devlink_rate, devlink_rate->priv, + rate, info->extack); + else if (devlink_rate_is_node(devlink_rate)) + err = ops->rate_node_tx_max_set(devlink_rate, devlink_rate->priv, + rate, info->extack); + if (err) + return err; + devlink_rate->tx_max = rate; + } + + nla_parent = attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME]; + if (nla_parent) { + err = devlink_nl_rate_parent_node_set(devlink_rate, info, + nla_parent); + if (err) + return err; + } + + return 0; +} + +static bool devlink_rate_set_ops_supported(const struct devlink_ops *ops, + struct genl_info *info, + enum devlink_rate_type type) +{ + struct nlattr **attrs = info->attrs; + + if (type == DEVLINK_RATE_TYPE_LEAF) { + if (attrs[DEVLINK_ATTR_RATE_TX_SHARE] && !ops->rate_leaf_tx_share_set) { + NL_SET_ERR_MSG_MOD(info->extack, "TX share set isn't supported for the leafs"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_TX_MAX] && !ops->rate_leaf_tx_max_set) { + NL_SET_ERR_MSG_MOD(info->extack, "TX max set isn't supported for the leafs"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME] && + !ops->rate_leaf_parent_set) { + NL_SET_ERR_MSG_MOD(info->extack, "Parent set isn't supported for the leafs"); + return false; + } + } else if (type == DEVLINK_RATE_TYPE_NODE) { + if (attrs[DEVLINK_ATTR_RATE_TX_SHARE] && !ops->rate_node_tx_share_set) { + NL_SET_ERR_MSG_MOD(info->extack, "TX share set isn't supported for the nodes"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_TX_MAX] && !ops->rate_node_tx_max_set) { + NL_SET_ERR_MSG_MOD(info->extack, "TX max set isn't supported for the nodes"); + return false; + } + if (attrs[DEVLINK_ATTR_RATE_PARENT_NODE_NAME] && + !ops->rate_node_parent_set) { + NL_SET_ERR_MSG_MOD(info->extack, "Parent set isn't supported for the nodes"); + return false; + } + } else { + WARN(1, "Unknown type of rate object"); + return false; + } + + return true; +} + +static int devlink_nl_cmd_rate_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_rate *devlink_rate = info->user_ptr[1]; + struct devlink *devlink = devlink_rate->devlink; + const struct devlink_ops *ops = devlink->ops; + int err; + + if (!ops || !devlink_rate_set_ops_supported(ops, info, devlink_rate->type)) + return -EOPNOTSUPP; + + err = devlink_nl_rate_set(devlink_rate, ops, info); + + if (!err) + devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_NEW); + return err; +} + +static int devlink_nl_cmd_rate_new_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_rate *rate_node; + const struct devlink_ops *ops; + int err; + + ops = devlink->ops; + if (!ops || !ops->rate_node_new || !ops->rate_node_del) { + NL_SET_ERR_MSG_MOD(info->extack, "Rate nodes aren't supported"); + return -EOPNOTSUPP; + } + + if (!devlink_rate_set_ops_supported(ops, info, DEVLINK_RATE_TYPE_NODE)) + return -EOPNOTSUPP; + + rate_node = devlink_rate_node_get_from_attrs(devlink, info->attrs); 
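+ /* Note: a rate node with the requested name must not already exist.
+ * The lookup above returns -EINVAL for a missing or malformed name and
+ * -ENODEV when no node of that name is found, which is the signal that
+ * a new node may be allocated below.
+ */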
+ if (!IS_ERR(rate_node)) + return -EEXIST; + else if (rate_node == ERR_PTR(-EINVAL)) + return -EINVAL; + + rate_node = kzalloc(sizeof(*rate_node), GFP_KERNEL); + if (!rate_node) + return -ENOMEM; + + rate_node->devlink = devlink; + rate_node->type = DEVLINK_RATE_TYPE_NODE; + rate_node->name = nla_strdup(info->attrs[DEVLINK_ATTR_RATE_NODE_NAME], GFP_KERNEL); + if (!rate_node->name) { + err = -ENOMEM; + goto err_strdup; + } + + err = ops->rate_node_new(rate_node, &rate_node->priv, info->extack); + if (err) + goto err_node_new; + + err = devlink_nl_rate_set(rate_node, ops, info); + if (err) + goto err_rate_set; + + refcount_set(&rate_node->refcnt, 1); + list_add(&rate_node->list, &devlink->rate_list); + devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW); + return 0; + +err_rate_set: + ops->rate_node_del(rate_node, rate_node->priv, info->extack); +err_node_new: + kfree(rate_node->name); +err_strdup: + kfree(rate_node); + return err; +} + +static int devlink_nl_cmd_rate_del_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_rate *rate_node = info->user_ptr[1]; + struct devlink *devlink = rate_node->devlink; + const struct devlink_ops *ops = devlink->ops; + int err; + + if (refcount_read(&rate_node->refcnt) > 1) { + NL_SET_ERR_MSG_MOD(info->extack, "Node has children. Cannot delete node."); + return -EBUSY; + } + + devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_DEL); + err = ops->rate_node_del(rate_node, rate_node->priv, info->extack); + if (rate_node->parent) + refcount_dec(&rate_node->parent->refcnt); + list_del(&rate_node->list); + kfree(rate_node->name); + kfree(rate_node); + return err; +} + +struct devlink_linecard_type { + const char *type; + const void *priv; +}; + +static int devlink_nl_linecard_fill(struct sk_buff *msg, + struct devlink *devlink, + struct devlink_linecard *linecard, + enum devlink_command cmd, u32 portid, + u32 seq, int flags, + struct netlink_ext_ack *extack) +{ + struct devlink_linecard_type *linecard_type; + struct nlattr *attr; + void *hdr; + int i; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_LINECARD_INDEX, linecard->index)) + goto nla_put_failure; + if (nla_put_u8(msg, DEVLINK_ATTR_LINECARD_STATE, linecard->state)) + goto nla_put_failure; + if (linecard->type && + nla_put_string(msg, DEVLINK_ATTR_LINECARD_TYPE, linecard->type)) + goto nla_put_failure; + + if (linecard->types_count) { + attr = nla_nest_start(msg, + DEVLINK_ATTR_LINECARD_SUPPORTED_TYPES); + if (!attr) + goto nla_put_failure; + for (i = 0; i < linecard->types_count; i++) { + linecard_type = &linecard->types[i]; + if (nla_put_string(msg, DEVLINK_ATTR_LINECARD_TYPE, + linecard_type->type)) { + nla_nest_cancel(msg, attr); + goto nla_put_failure; + } + } + nla_nest_end(msg, attr); + } + + if (linecard->nested_devlink && + devlink_nl_put_nested_handle(msg, linecard->nested_devlink)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void devlink_linecard_notify(struct devlink_linecard *linecard, + enum devlink_command cmd) +{ + struct devlink *devlink = linecard->devlink; + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_LINECARD_NEW && + cmd != DEVLINK_CMD_LINECARD_DEL); + + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if 
(!msg) + return; + + err = devlink_nl_linecard_fill(msg, devlink, linecard, cmd, 0, 0, 0, + NULL); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int devlink_nl_cmd_linecard_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_linecard *linecard = info->user_ptr[1]; + struct devlink *devlink = linecard->devlink; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + mutex_lock(&linecard->state_lock); + err = devlink_nl_linecard_fill(msg, devlink, linecard, + DEVLINK_CMD_LINECARD_NEW, + info->snd_portid, info->snd_seq, 0, + info->extack); + mutex_unlock(&linecard->state_lock); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_linecard_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_linecard *linecard; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + mutex_lock(&devlink->linecards_lock); + list_for_each_entry(linecard, &devlink->linecard_list, list) { + if (idx < start) { + idx++; + continue; + } + mutex_lock(&linecard->state_lock); + err = devlink_nl_linecard_fill(msg, devlink, linecard, + DEVLINK_CMD_LINECARD_NEW, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + cb->extack); + mutex_unlock(&linecard->state_lock); + if (err) { + mutex_unlock(&devlink->linecards_lock); + devlink_put(devlink); + goto out; + } + idx++; + } + mutex_unlock(&devlink->linecards_lock); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static struct devlink_linecard_type * +devlink_linecard_type_lookup(struct devlink_linecard *linecard, + const char *type) +{ + struct devlink_linecard_type *linecard_type; + int i; + + for (i = 0; i < linecard->types_count; i++) { + linecard_type = &linecard->types[i]; + if (!strcmp(type, linecard_type->type)) + return linecard_type; + } + return NULL; +} + +static int devlink_linecard_type_set(struct devlink_linecard *linecard, + const char *type, + struct netlink_ext_ack *extack) +{ + const struct devlink_linecard_ops *ops = linecard->ops; + struct devlink_linecard_type *linecard_type; + int err; + + mutex_lock(&linecard->state_lock); + if (linecard->state == DEVLINK_LINECARD_STATE_PROVISIONING) { + NL_SET_ERR_MSG_MOD(extack, "Line card is currently being provisioned"); + err = -EBUSY; + goto out; + } + if (linecard->state == DEVLINK_LINECARD_STATE_UNPROVISIONING) { + NL_SET_ERR_MSG_MOD(extack, "Line card is currently being unprovisioned"); + err = -EBUSY; + goto out; + } + + linecard_type = devlink_linecard_type_lookup(linecard, type); + if (!linecard_type) { + NL_SET_ERR_MSG_MOD(extack, "Unsupported line card type provided"); + err = -EINVAL; + goto out; + } + + if (linecard->state != DEVLINK_LINECARD_STATE_UNPROVISIONED && + linecard->state != DEVLINK_LINECARD_STATE_PROVISIONING_FAILED) { + NL_SET_ERR_MSG_MOD(extack, "Line card already provisioned"); + err = -EBUSY; + /* Check if the line card is provisioned in the same + * way the user asks. In case it is, make the operation + * to return success. 
+ */ + if (ops->same_provision && + ops->same_provision(linecard, linecard->priv, + linecard_type->type, + linecard_type->priv)) + err = 0; + goto out; + } + + linecard->state = DEVLINK_LINECARD_STATE_PROVISIONING; + linecard->type = linecard_type->type; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); + err = ops->provision(linecard, linecard->priv, linecard_type->type, + linecard_type->priv, extack); + if (err) { + /* Provisioning failed. Assume the linecard is unprovisioned + * for future operations. + */ + mutex_lock(&linecard->state_lock); + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; + linecard->type = NULL; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); + } + return err; + +out: + mutex_unlock(&linecard->state_lock); + return err; +} + +static int devlink_linecard_type_unset(struct devlink_linecard *linecard, + struct netlink_ext_ack *extack) +{ + int err; + + mutex_lock(&linecard->state_lock); + if (linecard->state == DEVLINK_LINECARD_STATE_PROVISIONING) { + NL_SET_ERR_MSG_MOD(extack, "Line card is currently being provisioned"); + err = -EBUSY; + goto out; + } + if (linecard->state == DEVLINK_LINECARD_STATE_UNPROVISIONING) { + NL_SET_ERR_MSG_MOD(extack, "Line card is currently being unprovisioned"); + err = -EBUSY; + goto out; + } + if (linecard->state == DEVLINK_LINECARD_STATE_PROVISIONING_FAILED) { + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; + linecard->type = NULL; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + err = 0; + goto out; + } + + if (linecard->state == DEVLINK_LINECARD_STATE_UNPROVISIONED) { + NL_SET_ERR_MSG_MOD(extack, "Line card is not provisioned"); + err = 0; + goto out; + } + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONING; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); + err = linecard->ops->unprovision(linecard, linecard->priv, + extack); + if (err) { + /* Unprovisioning failed. Assume the linecard is unprovisioned + * for future operations. 
+ */ + mutex_lock(&linecard->state_lock); + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; + linecard->type = NULL; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); + } + return err; + +out: + mutex_unlock(&linecard->state_lock); + return err; +} + +static int devlink_nl_cmd_linecard_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_linecard *linecard = info->user_ptr[1]; + struct netlink_ext_ack *extack = info->extack; + int err; + + if (info->attrs[DEVLINK_ATTR_LINECARD_TYPE]) { + const char *type; + + type = nla_data(info->attrs[DEVLINK_ATTR_LINECARD_TYPE]); + if (strcmp(type, "")) { + err = devlink_linecard_type_set(linecard, type, extack); + if (err) + return err; + } else { + err = devlink_linecard_type_unset(linecard, extack); + if (err) + return err; + } + } + + return 0; +} + +static int devlink_nl_sb_fill(struct sk_buff *msg, struct devlink *devlink, + struct devlink_sb *devlink_sb, + enum devlink_command cmd, u32 portid, + u32 seq, int flags) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_SIZE, devlink_sb->size)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_INGRESS_POOL_COUNT, + devlink_sb->ingress_pools_count)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_EGRESS_POOL_COUNT, + devlink_sb->egress_pools_count)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_INGRESS_TC_COUNT, + devlink_sb->ingress_tc_count)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_EGRESS_TC_COUNT, + devlink_sb->egress_tc_count)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_sb_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_sb *devlink_sb; + struct sk_buff *msg; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_sb_fill(msg, devlink, devlink_sb, + DEVLINK_CMD_SB_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_sb_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + struct devlink_sb *devlink_sb; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(devlink_sb, &devlink->sb_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_sb_fill(msg, devlink, devlink_sb, + DEVLINK_CMD_SB_NEW, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int devlink_nl_sb_pool_fill(struct sk_buff *msg, struct devlink *devlink, + struct devlink_sb *devlink_sb, + u16 pool_index, enum devlink_command cmd, + u32 portid, 
u32 seq, int flags) +{ + struct devlink_sb_pool_info pool_info; + void *hdr; + int err; + + err = devlink->ops->sb_pool_get(devlink, devlink_sb->index, + pool_index, &pool_info); + if (err) + return err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index)) + goto nla_put_failure; + if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_TYPE, pool_info.pool_type)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_POOL_SIZE, pool_info.size)) + goto nla_put_failure; + if (nla_put_u8(msg, DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE, + pool_info.threshold_type)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_POOL_CELL_SIZE, + pool_info.cell_size)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_sb_pool_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_sb *devlink_sb; + struct sk_buff *msg; + u16 pool_index; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_index_get_from_info(devlink_sb, info, + &pool_index); + if (err) + return err; + + if (!devlink->ops->sb_pool_get) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_sb_pool_fill(msg, devlink, devlink_sb, pool_index, + DEVLINK_CMD_SB_POOL_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int __sb_pool_get_dumpit(struct sk_buff *msg, int start, int *p_idx, + struct devlink *devlink, + struct devlink_sb *devlink_sb, + u32 portid, u32 seq) +{ + u16 pool_count = devlink_sb_pool_count(devlink_sb); + u16 pool_index; + int err; + + for (pool_index = 0; pool_index < pool_count; pool_index++) { + if (*p_idx < start) { + (*p_idx)++; + continue; + } + err = devlink_nl_sb_pool_fill(msg, devlink, + devlink_sb, + pool_index, + DEVLINK_CMD_SB_POOL_NEW, + portid, seq, NLM_F_MULTI); + if (err) + return err; + (*p_idx)++; + } + return 0; +} + +static int devlink_nl_cmd_sb_pool_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + struct devlink_sb *devlink_sb; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_pool_get) + goto retry; + + devl_lock(devlink); + list_for_each_entry(devlink_sb, &devlink->sb_list, list) { + err = __sb_pool_get_dumpit(msg, start, &idx, devlink, + devlink_sb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq); + if (err == -EOPNOTSUPP) { + err = 0; + } else if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + } + devl_unlock(devlink); +retry: + devlink_put(devlink); + } +out: + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_sb_pool_set(struct devlink *devlink, unsigned int sb_index, + u16 pool_index, u32 size, + enum devlink_sb_threshold_type threshold_type, + struct netlink_ext_ack *extack) + +{ + const struct devlink_ops *ops = devlink->ops; + + if 
(ops->sb_pool_set) + return ops->sb_pool_set(devlink, sb_index, pool_index, + size, threshold_type, extack); + return -EOPNOTSUPP; +} + +static int devlink_nl_cmd_sb_pool_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + enum devlink_sb_threshold_type threshold_type; + struct devlink_sb *devlink_sb; + u16 pool_index; + u32 size; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_index_get_from_info(devlink_sb, info, + &pool_index); + if (err) + return err; + + err = devlink_sb_th_type_get_from_info(info, &threshold_type); + if (err) + return err; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_POOL_SIZE)) + return -EINVAL; + + size = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_POOL_SIZE]); + return devlink_sb_pool_set(devlink, devlink_sb->index, + pool_index, size, threshold_type, + info->extack); +} + +static int devlink_nl_sb_port_pool_fill(struct sk_buff *msg, + struct devlink *devlink, + struct devlink_port *devlink_port, + struct devlink_sb *devlink_sb, + u16 pool_index, + enum devlink_command cmd, + u32 portid, u32 seq, int flags) +{ + const struct devlink_ops *ops = devlink->ops; + u32 threshold; + void *hdr; + int err; + + err = ops->sb_port_pool_get(devlink_port, devlink_sb->index, + pool_index, &threshold); + if (err) + return err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_THRESHOLD, threshold)) + goto nla_put_failure; + + if (ops->sb_occ_port_pool_get) { + u32 cur; + u32 max; + + err = ops->sb_occ_port_pool_get(devlink_port, devlink_sb->index, + pool_index, &cur, &max); + if (err && err != -EOPNOTSUPP) + goto sb_occ_get_failure; + if (!err) { + if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_CUR, cur)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_MAX, max)) + goto nla_put_failure; + } + } + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + err = -EMSGSIZE; +sb_occ_get_failure: + genlmsg_cancel(msg, hdr); + return err; +} + +static int devlink_nl_cmd_sb_port_pool_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = devlink_port->devlink; + struct devlink_sb *devlink_sb; + struct sk_buff *msg; + u16 pool_index; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_index_get_from_info(devlink_sb, info, + &pool_index); + if (err) + return err; + + if (!devlink->ops->sb_port_pool_get) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_sb_port_pool_fill(msg, devlink, devlink_port, + devlink_sb, pool_index, + DEVLINK_CMD_SB_PORT_POOL_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int __sb_port_pool_get_dumpit(struct sk_buff *msg, int start, int *p_idx, + struct devlink *devlink, + struct devlink_sb *devlink_sb, + u32 
portid, u32 seq) +{ + struct devlink_port *devlink_port; + u16 pool_count = devlink_sb_pool_count(devlink_sb); + u16 pool_index; + int err; + + list_for_each_entry(devlink_port, &devlink->port_list, list) { + for (pool_index = 0; pool_index < pool_count; pool_index++) { + if (*p_idx < start) { + (*p_idx)++; + continue; + } + err = devlink_nl_sb_port_pool_fill(msg, devlink, + devlink_port, + devlink_sb, + pool_index, + DEVLINK_CMD_SB_PORT_POOL_NEW, + portid, seq, + NLM_F_MULTI); + if (err) + return err; + (*p_idx)++; + } + } + return 0; +} + +static int devlink_nl_cmd_sb_port_pool_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + struct devlink_sb *devlink_sb; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_port_pool_get) + goto retry; + + devl_lock(devlink); + list_for_each_entry(devlink_sb, &devlink->sb_list, list) { + err = __sb_port_pool_get_dumpit(msg, start, &idx, + devlink, devlink_sb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq); + if (err == -EOPNOTSUPP) { + err = 0; + } else if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + } + devl_unlock(devlink); +retry: + devlink_put(devlink); + } +out: + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_sb_port_pool_set(struct devlink_port *devlink_port, + unsigned int sb_index, u16 pool_index, + u32 threshold, + struct netlink_ext_ack *extack) + +{ + const struct devlink_ops *ops = devlink_port->devlink->ops; + + if (ops->sb_port_pool_set) + return ops->sb_port_pool_set(devlink_port, sb_index, + pool_index, threshold, extack); + return -EOPNOTSUPP; +} + +static int devlink_nl_cmd_sb_port_pool_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = info->user_ptr[0]; + struct devlink_sb *devlink_sb; + u16 pool_index; + u32 threshold; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_index_get_from_info(devlink_sb, info, + &pool_index); + if (err) + return err; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_THRESHOLD)) + return -EINVAL; + + threshold = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_THRESHOLD]); + return devlink_sb_port_pool_set(devlink_port, devlink_sb->index, + pool_index, threshold, info->extack); +} + +static int +devlink_nl_sb_tc_pool_bind_fill(struct sk_buff *msg, struct devlink *devlink, + struct devlink_port *devlink_port, + struct devlink_sb *devlink_sb, u16 tc_index, + enum devlink_sb_pool_type pool_type, + enum devlink_command cmd, + u32 portid, u32 seq, int flags) +{ + const struct devlink_ops *ops = devlink->ops; + u16 pool_index; + u32 threshold; + void *hdr; + int err; + + err = ops->sb_tc_pool_bind_get(devlink_port, devlink_sb->index, + tc_index, pool_type, + &pool_index, &threshold); + if (err) + return err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, devlink_port->index)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_INDEX, devlink_sb->index)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_TC_INDEX, tc_index)) + goto nla_put_failure; + if (nla_put_u8(msg, 
DEVLINK_ATTR_SB_POOL_TYPE, pool_type)) + goto nla_put_failure; + if (nla_put_u16(msg, DEVLINK_ATTR_SB_POOL_INDEX, pool_index)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_THRESHOLD, threshold)) + goto nla_put_failure; + + if (ops->sb_occ_tc_port_bind_get) { + u32 cur; + u32 max; + + err = ops->sb_occ_tc_port_bind_get(devlink_port, + devlink_sb->index, + tc_index, pool_type, + &cur, &max); + if (err && err != -EOPNOTSUPP) + return err; + if (!err) { + if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_CUR, cur)) + goto nla_put_failure; + if (nla_put_u32(msg, DEVLINK_ATTR_SB_OCC_MAX, max)) + goto nla_put_failure; + } + } + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_sb_tc_pool_bind_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = devlink_port->devlink; + struct devlink_sb *devlink_sb; + struct sk_buff *msg; + enum devlink_sb_pool_type pool_type; + u16 tc_index; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_type_get_from_info(info, &pool_type); + if (err) + return err; + + err = devlink_sb_tc_index_get_from_info(devlink_sb, info, + pool_type, &tc_index); + if (err) + return err; + + if (!devlink->ops->sb_tc_pool_bind_get) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink, devlink_port, + devlink_sb, tc_index, pool_type, + DEVLINK_CMD_SB_TC_POOL_BIND_NEW, + info->snd_portid, + info->snd_seq, 0); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int __sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, + int start, int *p_idx, + struct devlink *devlink, + struct devlink_sb *devlink_sb, + u32 portid, u32 seq) +{ + struct devlink_port *devlink_port; + u16 tc_index; + int err; + + list_for_each_entry(devlink_port, &devlink->port_list, list) { + for (tc_index = 0; + tc_index < devlink_sb->ingress_tc_count; tc_index++) { + if (*p_idx < start) { + (*p_idx)++; + continue; + } + err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink, + devlink_port, + devlink_sb, + tc_index, + DEVLINK_SB_POOL_TYPE_INGRESS, + DEVLINK_CMD_SB_TC_POOL_BIND_NEW, + portid, seq, + NLM_F_MULTI); + if (err) + return err; + (*p_idx)++; + } + for (tc_index = 0; + tc_index < devlink_sb->egress_tc_count; tc_index++) { + if (*p_idx < start) { + (*p_idx)++; + continue; + } + err = devlink_nl_sb_tc_pool_bind_fill(msg, devlink, + devlink_port, + devlink_sb, + tc_index, + DEVLINK_SB_POOL_TYPE_EGRESS, + DEVLINK_CMD_SB_TC_POOL_BIND_NEW, + portid, seq, + NLM_F_MULTI); + if (err) + return err; + (*p_idx)++; + } + } + return 0; +} + +static int +devlink_nl_cmd_sb_tc_pool_bind_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + struct devlink_sb *devlink_sb; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (!devlink->ops->sb_tc_pool_bind_get) + goto retry; + + devl_lock(devlink); + list_for_each_entry(devlink_sb, &devlink->sb_list, list) { + err = __sb_tc_pool_bind_get_dumpit(msg, start, &idx, + devlink, + devlink_sb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq); + if (err == -EOPNOTSUPP) { + err = 0; + } else if (err) { + 
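/* Any error other than -EOPNOTSUPP stops the walk here. -EMSGSIZE
 * only means the current skb is full; it is turned into a normal end
 * of this dump message at the "out" label below, with cb->args[0]
 * recording the index at which the next dump call resumes.
 */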
devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + } + devl_unlock(devlink); +retry: + devlink_put(devlink); + } +out: + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_sb_tc_pool_bind_set(struct devlink_port *devlink_port, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u16 pool_index, u32 threshold, + struct netlink_ext_ack *extack) + +{ + const struct devlink_ops *ops = devlink_port->devlink->ops; + + if (ops->sb_tc_pool_bind_set) + return ops->sb_tc_pool_bind_set(devlink_port, sb_index, + tc_index, pool_type, + pool_index, threshold, extack); + return -EOPNOTSUPP; +} + +static int devlink_nl_cmd_sb_tc_pool_bind_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_port *devlink_port = info->user_ptr[1]; + struct devlink *devlink = info->user_ptr[0]; + enum devlink_sb_pool_type pool_type; + struct devlink_sb *devlink_sb; + u16 tc_index; + u16 pool_index; + u32 threshold; + int err; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + err = devlink_sb_pool_type_get_from_info(info, &pool_type); + if (err) + return err; + + err = devlink_sb_tc_index_get_from_info(devlink_sb, info, + pool_type, &tc_index); + if (err) + return err; + + err = devlink_sb_pool_index_get_from_info(devlink_sb, info, + &pool_index); + if (err) + return err; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SB_THRESHOLD)) + return -EINVAL; + + threshold = nla_get_u32(info->attrs[DEVLINK_ATTR_SB_THRESHOLD]); + return devlink_sb_tc_pool_bind_set(devlink_port, devlink_sb->index, + tc_index, pool_type, + pool_index, threshold, info->extack); +} + +static int devlink_nl_cmd_sb_occ_snapshot_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const struct devlink_ops *ops = devlink->ops; + struct devlink_sb *devlink_sb; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + if (ops->sb_occ_snapshot) + return ops->sb_occ_snapshot(devlink, devlink_sb->index); + return -EOPNOTSUPP; +} + +static int devlink_nl_cmd_sb_occ_max_clear_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const struct devlink_ops *ops = devlink->ops; + struct devlink_sb *devlink_sb; + + devlink_sb = devlink_sb_get_from_info(devlink, info); + if (IS_ERR(devlink_sb)) + return PTR_ERR(devlink_sb); + + if (ops->sb_occ_max_clear) + return ops->sb_occ_max_clear(devlink, devlink_sb->index); + return -EOPNOTSUPP; +} + +static int devlink_nl_eswitch_fill(struct sk_buff *msg, struct devlink *devlink, + enum devlink_command cmd, u32 portid, + u32 seq, int flags) +{ + const struct devlink_ops *ops = devlink->ops; + enum devlink_eswitch_encap_mode encap_mode; + u8 inline_mode; + void *hdr; + int err = 0; + u16 mode; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + err = devlink_nl_put_handle(msg, devlink); + if (err) + goto nla_put_failure; + + if (ops->eswitch_mode_get) { + err = ops->eswitch_mode_get(devlink, &mode); + if (err) + goto nla_put_failure; + err = nla_put_u16(msg, DEVLINK_ATTR_ESWITCH_MODE, mode); + if (err) + goto nla_put_failure; + } + + if (ops->eswitch_inline_mode_get) { + err = ops->eswitch_inline_mode_get(devlink, &inline_mode); + if (err) + goto nla_put_failure; + err = nla_put_u8(msg, DEVLINK_ATTR_ESWITCH_INLINE_MODE, + inline_mode); + 
if (err) + goto nla_put_failure; + } + + if (ops->eswitch_encap_mode_get) { + err = ops->eswitch_encap_mode_get(devlink, &encap_mode); + if (err) + goto nla_put_failure; + err = nla_put_u8(msg, DEVLINK_ATTR_ESWITCH_ENCAP_MODE, encap_mode); + if (err) + goto nla_put_failure; + } + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return err; +} + +static int devlink_nl_cmd_eswitch_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_eswitch_fill(msg, devlink, DEVLINK_CMD_ESWITCH_GET, + info->snd_portid, info->snd_seq, 0); + + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_rate_nodes_check(struct devlink *devlink, u16 mode, + struct netlink_ext_ack *extack) +{ + struct devlink_rate *devlink_rate; + + list_for_each_entry(devlink_rate, &devlink->rate_list, list) + if (devlink_rate_is_node(devlink_rate)) { + NL_SET_ERR_MSG_MOD(extack, "Rate node(s) exists."); + return -EBUSY; + } + return 0; +} + +static int devlink_nl_cmd_eswitch_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const struct devlink_ops *ops = devlink->ops; + enum devlink_eswitch_encap_mode encap_mode; + u8 inline_mode; + int err = 0; + u16 mode; + + if (info->attrs[DEVLINK_ATTR_ESWITCH_MODE]) { + if (!ops->eswitch_mode_set) + return -EOPNOTSUPP; + mode = nla_get_u16(info->attrs[DEVLINK_ATTR_ESWITCH_MODE]); + err = devlink_rate_nodes_check(devlink, mode, info->extack); + if (err) + return err; + err = ops->eswitch_mode_set(devlink, mode, info->extack); + if (err) + return err; + } + + if (info->attrs[DEVLINK_ATTR_ESWITCH_INLINE_MODE]) { + if (!ops->eswitch_inline_mode_set) + return -EOPNOTSUPP; + inline_mode = nla_get_u8( + info->attrs[DEVLINK_ATTR_ESWITCH_INLINE_MODE]); + err = ops->eswitch_inline_mode_set(devlink, inline_mode, + info->extack); + if (err) + return err; + } + + if (info->attrs[DEVLINK_ATTR_ESWITCH_ENCAP_MODE]) { + if (!ops->eswitch_encap_mode_set) + return -EOPNOTSUPP; + encap_mode = nla_get_u8(info->attrs[DEVLINK_ATTR_ESWITCH_ENCAP_MODE]); + err = ops->eswitch_encap_mode_set(devlink, encap_mode, + info->extack); + if (err) + return err; + } + + return 0; +} + +int devlink_dpipe_match_put(struct sk_buff *skb, + struct devlink_dpipe_match *match) +{ + struct devlink_dpipe_header *header = match->header; + struct devlink_dpipe_field *field = &header->fields[match->field_id]; + struct nlattr *match_attr; + + match_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_MATCH); + if (!match_attr) + return -EMSGSIZE; + + if (nla_put_u32(skb, DEVLINK_ATTR_DPIPE_MATCH_TYPE, match->type) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_HEADER_INDEX, match->header_index) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_HEADER_ID, header->id) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_FIELD_ID, field->id) || + nla_put_u8(skb, DEVLINK_ATTR_DPIPE_HEADER_GLOBAL, header->global)) + goto nla_put_failure; + + nla_nest_end(skb, match_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, match_attr); + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(devlink_dpipe_match_put); + +static int devlink_dpipe_matches_put(struct devlink_dpipe_table *table, + struct sk_buff *skb) +{ + struct nlattr *matches_attr; + + matches_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_TABLE_MATCHES); + if (!matches_attr) 
+ return -EMSGSIZE; + + if (table->table_ops->matches_dump(table->priv, skb)) + goto nla_put_failure; + + nla_nest_end(skb, matches_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, matches_attr); + return -EMSGSIZE; +} + +int devlink_dpipe_action_put(struct sk_buff *skb, + struct devlink_dpipe_action *action) +{ + struct devlink_dpipe_header *header = action->header; + struct devlink_dpipe_field *field = &header->fields[action->field_id]; + struct nlattr *action_attr; + + action_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_ACTION); + if (!action_attr) + return -EMSGSIZE; + + if (nla_put_u32(skb, DEVLINK_ATTR_DPIPE_ACTION_TYPE, action->type) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_HEADER_INDEX, action->header_index) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_HEADER_ID, header->id) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_FIELD_ID, field->id) || + nla_put_u8(skb, DEVLINK_ATTR_DPIPE_HEADER_GLOBAL, header->global)) + goto nla_put_failure; + + nla_nest_end(skb, action_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, action_attr); + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(devlink_dpipe_action_put); + +static int devlink_dpipe_actions_put(struct devlink_dpipe_table *table, + struct sk_buff *skb) +{ + struct nlattr *actions_attr; + + actions_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_TABLE_ACTIONS); + if (!actions_attr) + return -EMSGSIZE; + + if (table->table_ops->actions_dump(table->priv, skb)) + goto nla_put_failure; + + nla_nest_end(skb, actions_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, actions_attr); + return -EMSGSIZE; +} + +static int devlink_dpipe_table_put(struct sk_buff *skb, + struct devlink_dpipe_table *table) +{ + struct nlattr *table_attr; + u64 table_size; + + table_size = table->table_ops->size_get(table->priv); + table_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_TABLE); + if (!table_attr) + return -EMSGSIZE; + + if (nla_put_string(skb, DEVLINK_ATTR_DPIPE_TABLE_NAME, table->name) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_DPIPE_TABLE_SIZE, table_size, + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + if (nla_put_u8(skb, DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED, + table->counters_enabled)) + goto nla_put_failure; + + if (table->resource_valid) { + if (nla_put_u64_64bit(skb, DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_ID, + table->resource_id, DEVLINK_ATTR_PAD) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_DPIPE_TABLE_RESOURCE_UNITS, + table->resource_units, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + } + if (devlink_dpipe_matches_put(table, skb)) + goto nla_put_failure; + + if (devlink_dpipe_actions_put(table, skb)) + goto nla_put_failure; + + nla_nest_end(skb, table_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, table_attr); + return -EMSGSIZE; +} + +static int devlink_dpipe_send_and_alloc_skb(struct sk_buff **pskb, + struct genl_info *info) +{ + int err; + + if (*pskb) { + err = genlmsg_reply(*pskb, info); + if (err) + return err; + } + *pskb = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!*pskb) + return -ENOMEM; + return 0; +} + +static int devlink_dpipe_tables_fill(struct genl_info *info, + enum devlink_command cmd, int flags, + struct list_head *dpipe_tables, + const char *table_name) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_dpipe_table *table; + struct nlattr *tables_attr; + struct sk_buff *skb = NULL; + struct nlmsghdr *nlh; + bool incomplete; + void *hdr; + int i; + int err; + + table = list_first_entry(dpipe_tables, + struct devlink_dpipe_table, list); +start_again: + 
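/* A single skb may not hold every table: send what has been filled so
 * far (if anything), allocate a fresh skb and continue from the table
 * that did not fit. The NLM_F_MULTI sequence is terminated by the
 * NLMSG_DONE message put at the send_done label below.
 */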
err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &devlink_nl_family, NLM_F_MULTI, cmd); + if (!hdr) { + nlmsg_free(skb); + return -EMSGSIZE; + } + + if (devlink_nl_put_handle(skb, devlink)) + goto nla_put_failure; + tables_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_TABLES); + if (!tables_attr) + goto nla_put_failure; + + i = 0; + incomplete = false; + list_for_each_entry_from(table, dpipe_tables, list) { + if (!table_name) { + err = devlink_dpipe_table_put(skb, table); + if (err) { + if (!i) + goto err_table_put; + incomplete = true; + break; + } + } else { + if (!strcmp(table->name, table_name)) { + err = devlink_dpipe_table_put(skb, table); + if (err) + break; + } + } + i++; + } + + nla_nest_end(skb, tables_attr); + genlmsg_end(skb, hdr); + if (incomplete) + goto start_again; + +send_done: + nlh = nlmsg_put(skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + goto send_done; + } + + return genlmsg_reply(skb, info); + +nla_put_failure: + err = -EMSGSIZE; +err_table_put: + nlmsg_free(skb); + return err; +} + +static int devlink_nl_cmd_dpipe_table_get(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const char *table_name = NULL; + + if (info->attrs[DEVLINK_ATTR_DPIPE_TABLE_NAME]) + table_name = nla_data(info->attrs[DEVLINK_ATTR_DPIPE_TABLE_NAME]); + + return devlink_dpipe_tables_fill(info, DEVLINK_CMD_DPIPE_TABLE_GET, 0, + &devlink->dpipe_table_list, + table_name); +} + +static int devlink_dpipe_value_put(struct sk_buff *skb, + struct devlink_dpipe_value *value) +{ + if (nla_put(skb, DEVLINK_ATTR_DPIPE_VALUE, + value->value_size, value->value)) + return -EMSGSIZE; + if (value->mask) + if (nla_put(skb, DEVLINK_ATTR_DPIPE_VALUE_MASK, + value->value_size, value->mask)) + return -EMSGSIZE; + if (value->mapping_valid) + if (nla_put_u32(skb, DEVLINK_ATTR_DPIPE_VALUE_MAPPING, + value->mapping_value)) + return -EMSGSIZE; + return 0; +} + +static int devlink_dpipe_action_value_put(struct sk_buff *skb, + struct devlink_dpipe_value *value) +{ + if (!value->action) + return -EINVAL; + if (devlink_dpipe_action_put(skb, value->action)) + return -EMSGSIZE; + if (devlink_dpipe_value_put(skb, value)) + return -EMSGSIZE; + return 0; +} + +static int devlink_dpipe_action_values_put(struct sk_buff *skb, + struct devlink_dpipe_value *values, + unsigned int values_count) +{ + struct nlattr *action_attr; + int i; + int err; + + for (i = 0; i < values_count; i++) { + action_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_ACTION_VALUE); + if (!action_attr) + return -EMSGSIZE; + err = devlink_dpipe_action_value_put(skb, &values[i]); + if (err) + goto err_action_value_put; + nla_nest_end(skb, action_attr); + } + return 0; + +err_action_value_put: + nla_nest_cancel(skb, action_attr); + return err; +} + +static int devlink_dpipe_match_value_put(struct sk_buff *skb, + struct devlink_dpipe_value *value) +{ + if (!value->match) + return -EINVAL; + if (devlink_dpipe_match_put(skb, value->match)) + return -EMSGSIZE; + if (devlink_dpipe_value_put(skb, value)) + return -EMSGSIZE; + return 0; +} + +static int devlink_dpipe_match_values_put(struct sk_buff *skb, + struct devlink_dpipe_value *values, + unsigned int values_count) +{ + struct nlattr *match_attr; + int i; + int err; + + for (i = 0; i < values_count; i++) { + match_attr = 
nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_MATCH_VALUE); + if (!match_attr) + return -EMSGSIZE; + err = devlink_dpipe_match_value_put(skb, &values[i]); + if (err) + goto err_match_value_put; + nla_nest_end(skb, match_attr); + } + return 0; + +err_match_value_put: + nla_nest_cancel(skb, match_attr); + return err; +} + +static int devlink_dpipe_entry_put(struct sk_buff *skb, + struct devlink_dpipe_entry *entry) +{ + struct nlattr *entry_attr, *matches_attr, *actions_attr; + int err; + + entry_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_ENTRY); + if (!entry_attr) + return -EMSGSIZE; + + if (nla_put_u64_64bit(skb, DEVLINK_ATTR_DPIPE_ENTRY_INDEX, entry->index, + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + if (entry->counter_valid) + if (nla_put_u64_64bit(skb, DEVLINK_ATTR_DPIPE_ENTRY_COUNTER, + entry->counter, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + matches_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_ENTRY_MATCH_VALUES); + if (!matches_attr) + goto nla_put_failure; + + err = devlink_dpipe_match_values_put(skb, entry->match_values, + entry->match_values_count); + if (err) { + nla_nest_cancel(skb, matches_attr); + goto err_match_values_put; + } + nla_nest_end(skb, matches_attr); + + actions_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_ENTRY_ACTION_VALUES); + if (!actions_attr) + goto nla_put_failure; + + err = devlink_dpipe_action_values_put(skb, entry->action_values, + entry->action_values_count); + if (err) { + nla_nest_cancel(skb, actions_attr); + goto err_action_values_put; + } + nla_nest_end(skb, actions_attr); + + nla_nest_end(skb, entry_attr); + return 0; + +nla_put_failure: + err = -EMSGSIZE; +err_match_values_put: +err_action_values_put: + nla_nest_cancel(skb, entry_attr); + return err; +} + +static struct devlink_dpipe_table * +devlink_dpipe_table_find(struct list_head *dpipe_tables, + const char *table_name, struct devlink *devlink) +{ + struct devlink_dpipe_table *table; + list_for_each_entry_rcu(table, dpipe_tables, list, + lockdep_is_held(&devlink->lock)) { + if (!strcmp(table->name, table_name)) + return table; + } + return NULL; +} + +int devlink_dpipe_entry_ctx_prepare(struct devlink_dpipe_dump_ctx *dump_ctx) +{ + struct devlink *devlink; + int err; + + err = devlink_dpipe_send_and_alloc_skb(&dump_ctx->skb, + dump_ctx->info); + if (err) + return err; + + dump_ctx->hdr = genlmsg_put(dump_ctx->skb, + dump_ctx->info->snd_portid, + dump_ctx->info->snd_seq, + &devlink_nl_family, NLM_F_MULTI, + dump_ctx->cmd); + if (!dump_ctx->hdr) + goto nla_put_failure; + + devlink = dump_ctx->info->user_ptr[0]; + if (devlink_nl_put_handle(dump_ctx->skb, devlink)) + goto nla_put_failure; + dump_ctx->nest = nla_nest_start_noflag(dump_ctx->skb, + DEVLINK_ATTR_DPIPE_ENTRIES); + if (!dump_ctx->nest) + goto nla_put_failure; + return 0; + +nla_put_failure: + nlmsg_free(dump_ctx->skb); + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(devlink_dpipe_entry_ctx_prepare); + +int devlink_dpipe_entry_ctx_append(struct devlink_dpipe_dump_ctx *dump_ctx, + struct devlink_dpipe_entry *entry) +{ + return devlink_dpipe_entry_put(dump_ctx->skb, entry); +} +EXPORT_SYMBOL_GPL(devlink_dpipe_entry_ctx_append); + +int devlink_dpipe_entry_ctx_close(struct devlink_dpipe_dump_ctx *dump_ctx) +{ + nla_nest_end(dump_ctx->skb, dump_ctx->nest); + genlmsg_end(dump_ctx->skb, dump_ctx->hdr); + return 0; +} +EXPORT_SYMBOL_GPL(devlink_dpipe_entry_ctx_close); + +void devlink_dpipe_entry_clear(struct devlink_dpipe_entry *entry) + +{ + unsigned int value_count, value_index; + struct devlink_dpipe_value 
*value; + + value = entry->action_values; + value_count = entry->action_values_count; + for (value_index = 0; value_index < value_count; value_index++) { + kfree(value[value_index].value); + kfree(value[value_index].mask); + } + + value = entry->match_values; + value_count = entry->match_values_count; + for (value_index = 0; value_index < value_count; value_index++) { + kfree(value[value_index].value); + kfree(value[value_index].mask); + } +} +EXPORT_SYMBOL_GPL(devlink_dpipe_entry_clear); + +static int devlink_dpipe_entries_fill(struct genl_info *info, + enum devlink_command cmd, int flags, + struct devlink_dpipe_table *table) +{ + struct devlink_dpipe_dump_ctx dump_ctx; + struct nlmsghdr *nlh; + int err; + + dump_ctx.skb = NULL; + dump_ctx.cmd = cmd; + dump_ctx.info = info; + + err = table->table_ops->entries_dump(table->priv, + table->counters_enabled, + &dump_ctx); + if (err) + return err; + +send_done: + nlh = nlmsg_put(dump_ctx.skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = devlink_dpipe_send_and_alloc_skb(&dump_ctx.skb, info); + if (err) + return err; + goto send_done; + } + return genlmsg_reply(dump_ctx.skb, info); +} + +static int devlink_nl_cmd_dpipe_entries_get(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_dpipe_table *table; + const char *table_name; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_DPIPE_TABLE_NAME)) + return -EINVAL; + + table_name = nla_data(info->attrs[DEVLINK_ATTR_DPIPE_TABLE_NAME]); + table = devlink_dpipe_table_find(&devlink->dpipe_table_list, + table_name, devlink); + if (!table) + return -EINVAL; + + if (!table->table_ops->entries_dump) + return -EINVAL; + + return devlink_dpipe_entries_fill(info, DEVLINK_CMD_DPIPE_ENTRIES_GET, + 0, table); +} + +static int devlink_dpipe_fields_put(struct sk_buff *skb, + const struct devlink_dpipe_header *header) +{ + struct devlink_dpipe_field *field; + struct nlattr *field_attr; + int i; + + for (i = 0; i < header->fields_count; i++) { + field = &header->fields[i]; + field_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_FIELD); + if (!field_attr) + return -EMSGSIZE; + if (nla_put_string(skb, DEVLINK_ATTR_DPIPE_FIELD_NAME, field->name) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_FIELD_ID, field->id) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_FIELD_BITWIDTH, field->bitwidth) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_FIELD_MAPPING_TYPE, field->mapping_type)) + goto nla_put_failure; + nla_nest_end(skb, field_attr); + } + return 0; + +nla_put_failure: + nla_nest_cancel(skb, field_attr); + return -EMSGSIZE; +} + +static int devlink_dpipe_header_put(struct sk_buff *skb, + struct devlink_dpipe_header *header) +{ + struct nlattr *fields_attr, *header_attr; + int err; + + header_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_HEADER); + if (!header_attr) + return -EMSGSIZE; + + if (nla_put_string(skb, DEVLINK_ATTR_DPIPE_HEADER_NAME, header->name) || + nla_put_u32(skb, DEVLINK_ATTR_DPIPE_HEADER_ID, header->id) || + nla_put_u8(skb, DEVLINK_ATTR_DPIPE_HEADER_GLOBAL, header->global)) + goto nla_put_failure; + + fields_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_DPIPE_HEADER_FIELDS); + if (!fields_attr) + goto nla_put_failure; + + err = devlink_dpipe_fields_put(skb, header); + if (err) { + nla_nest_cancel(skb, fields_attr); + goto nla_put_failure; + } + nla_nest_end(skb, fields_attr); + nla_nest_end(skb, header_attr); + return 0; + +nla_put_failure: + err = -EMSGSIZE; + nla_nest_cancel(skb, header_attr); + 
return err; +} + +static int devlink_dpipe_headers_fill(struct genl_info *info, + enum devlink_command cmd, int flags, + struct devlink_dpipe_headers * + dpipe_headers) +{ + struct devlink *devlink = info->user_ptr[0]; + struct nlattr *headers_attr; + struct sk_buff *skb = NULL; + struct nlmsghdr *nlh; + void *hdr; + int i, j; + int err; + + i = 0; +start_again: + err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &devlink_nl_family, NLM_F_MULTI, cmd); + if (!hdr) { + nlmsg_free(skb); + return -EMSGSIZE; + } + + if (devlink_nl_put_handle(skb, devlink)) + goto nla_put_failure; + headers_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_DPIPE_HEADERS); + if (!headers_attr) + goto nla_put_failure; + + j = 0; + for (; i < dpipe_headers->headers_count; i++) { + err = devlink_dpipe_header_put(skb, dpipe_headers->headers[i]); + if (err) { + if (!j) + goto err_table_put; + break; + } + j++; + } + nla_nest_end(skb, headers_attr); + genlmsg_end(skb, hdr); + if (i != dpipe_headers->headers_count) + goto start_again; + +send_done: + nlh = nlmsg_put(skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + goto send_done; + } + return genlmsg_reply(skb, info); + +nla_put_failure: + err = -EMSGSIZE; +err_table_put: + nlmsg_free(skb); + return err; +} + +static int devlink_nl_cmd_dpipe_headers_get(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + + if (!devlink->dpipe_headers) + return -EOPNOTSUPP; + return devlink_dpipe_headers_fill(info, DEVLINK_CMD_DPIPE_HEADERS_GET, + 0, devlink->dpipe_headers); +} + +static int devlink_dpipe_table_counters_set(struct devlink *devlink, + const char *table_name, + bool enable) +{ + struct devlink_dpipe_table *table; + + table = devlink_dpipe_table_find(&devlink->dpipe_table_list, + table_name, devlink); + if (!table) + return -EINVAL; + + if (table->counter_control_extern) + return -EOPNOTSUPP; + + if (!(table->counters_enabled ^ enable)) + return 0; + + table->counters_enabled = enable; + if (table->table_ops->counters_set_update) + table->table_ops->counters_set_update(table->priv, enable); + return 0; +} + +static int devlink_nl_cmd_dpipe_table_counters_set(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + const char *table_name; + bool counters_enable; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_DPIPE_TABLE_NAME) || + GENL_REQ_ATTR_CHECK(info, + DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED)) + return -EINVAL; + + table_name = nla_data(info->attrs[DEVLINK_ATTR_DPIPE_TABLE_NAME]); + counters_enable = !!nla_get_u8(info->attrs[DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED]); + + return devlink_dpipe_table_counters_set(devlink, table_name, + counters_enable); +} + +static struct devlink_resource * +devlink_resource_find(struct devlink *devlink, + struct devlink_resource *resource, u64 resource_id) +{ + struct list_head *resource_list; + + if (resource) + resource_list = &resource->resource_list; + else + resource_list = &devlink->resource_list; + + list_for_each_entry(resource, resource_list, list) { + struct devlink_resource *child_resource; + + if (resource->id == resource_id) + return resource; + + child_resource = devlink_resource_find(devlink, resource, + resource_id); + if (child_resource) + return child_resource; + } + return NULL; +} + +static void 
+devlink_resource_validate_children(struct devlink_resource *resource) +{ + struct devlink_resource *child_resource; + bool size_valid = true; + u64 parts_size = 0; + + if (list_empty(&resource->resource_list)) + goto out; + + list_for_each_entry(child_resource, &resource->resource_list, list) + parts_size += child_resource->size_new; + + if (parts_size > resource->size_new) + size_valid = false; +out: + resource->size_valid = size_valid; +} + +static int +devlink_resource_validate_size(struct devlink_resource *resource, u64 size, + struct netlink_ext_ack *extack) +{ + u64 reminder; + int err = 0; + + if (size > resource->size_params.size_max) { + NL_SET_ERR_MSG_MOD(extack, "Size larger than maximum"); + err = -EINVAL; + } + + if (size < resource->size_params.size_min) { + NL_SET_ERR_MSG_MOD(extack, "Size smaller than minimum"); + err = -EINVAL; + } + + div64_u64_rem(size, resource->size_params.size_granularity, &reminder); + if (reminder) { + NL_SET_ERR_MSG_MOD(extack, "Wrong granularity"); + err = -EINVAL; + } + + return err; +} + +static int devlink_nl_cmd_resource_set(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_resource *resource; + u64 resource_id; + u64 size; + int err; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_RESOURCE_ID) || + GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_RESOURCE_SIZE)) + return -EINVAL; + resource_id = nla_get_u64(info->attrs[DEVLINK_ATTR_RESOURCE_ID]); + + resource = devlink_resource_find(devlink, NULL, resource_id); + if (!resource) + return -EINVAL; + + size = nla_get_u64(info->attrs[DEVLINK_ATTR_RESOURCE_SIZE]); + err = devlink_resource_validate_size(resource, size, info->extack); + if (err) + return err; + + resource->size_new = size; + devlink_resource_validate_children(resource); + if (resource->parent) + devlink_resource_validate_children(resource->parent); + return 0; +} + +static int +devlink_resource_size_params_put(struct devlink_resource *resource, + struct sk_buff *skb) +{ + struct devlink_resource_size_params *size_params; + + size_params = &resource->size_params; + if (nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_GRAN, + size_params->size_granularity, DEVLINK_ATTR_PAD) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_MAX, + size_params->size_max, DEVLINK_ATTR_PAD) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_MIN, + size_params->size_min, DEVLINK_ATTR_PAD) || + nla_put_u8(skb, DEVLINK_ATTR_RESOURCE_UNIT, size_params->unit)) + return -EMSGSIZE; + return 0; +} + +static int devlink_resource_occ_put(struct devlink_resource *resource, + struct sk_buff *skb) +{ + if (!resource->occ_get) + return 0; + return nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_OCC, + resource->occ_get(resource->occ_get_priv), + DEVLINK_ATTR_PAD); +} + +static int devlink_resource_put(struct devlink *devlink, struct sk_buff *skb, + struct devlink_resource *resource) +{ + struct devlink_resource *child_resource; + struct nlattr *child_resource_attr; + struct nlattr *resource_attr; + + resource_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_RESOURCE); + if (!resource_attr) + return -EMSGSIZE; + + if (nla_put_string(skb, DEVLINK_ATTR_RESOURCE_NAME, resource->name) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE, resource->size, + DEVLINK_ATTR_PAD) || + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_ID, resource->id, + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + if (resource->size != resource->size_new) + nla_put_u64_64bit(skb, DEVLINK_ATTR_RESOURCE_SIZE_NEW, + resource->size_new, 
DEVLINK_ATTR_PAD); + if (devlink_resource_occ_put(resource, skb)) + goto nla_put_failure; + if (devlink_resource_size_params_put(resource, skb)) + goto nla_put_failure; + if (list_empty(&resource->resource_list)) + goto out; + + if (nla_put_u8(skb, DEVLINK_ATTR_RESOURCE_SIZE_VALID, + resource->size_valid)) + goto nla_put_failure; + + child_resource_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_RESOURCE_LIST); + if (!child_resource_attr) + goto nla_put_failure; + + list_for_each_entry(child_resource, &resource->resource_list, list) { + if (devlink_resource_put(devlink, skb, child_resource)) + goto resource_put_failure; + } + + nla_nest_end(skb, child_resource_attr); +out: + nla_nest_end(skb, resource_attr); + return 0; + +resource_put_failure: + nla_nest_cancel(skb, child_resource_attr); +nla_put_failure: + nla_nest_cancel(skb, resource_attr); + return -EMSGSIZE; +} + +static int devlink_resource_fill(struct genl_info *info, + enum devlink_command cmd, int flags) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_resource *resource; + struct nlattr *resources_attr; + struct sk_buff *skb = NULL; + struct nlmsghdr *nlh; + bool incomplete; + void *hdr; + int i; + int err; + + resource = list_first_entry(&devlink->resource_list, + struct devlink_resource, list); +start_again: + err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &devlink_nl_family, NLM_F_MULTI, cmd); + if (!hdr) { + nlmsg_free(skb); + return -EMSGSIZE; + } + + if (devlink_nl_put_handle(skb, devlink)) + goto nla_put_failure; + + resources_attr = nla_nest_start_noflag(skb, + DEVLINK_ATTR_RESOURCE_LIST); + if (!resources_attr) + goto nla_put_failure; + + incomplete = false; + i = 0; + list_for_each_entry_from(resource, &devlink->resource_list, list) { + err = devlink_resource_put(devlink, skb, resource); + if (err) { + if (!i) + goto err_resource_put; + incomplete = true; + break; + } + i++; + } + nla_nest_end(skb, resources_attr); + genlmsg_end(skb, hdr); + if (incomplete) + goto start_again; +send_done: + nlh = nlmsg_put(skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = devlink_dpipe_send_and_alloc_skb(&skb, info); + if (err) + return err; + goto send_done; + } + return genlmsg_reply(skb, info); + +nla_put_failure: + err = -EMSGSIZE; +err_resource_put: + nlmsg_free(skb); + return err; +} + +static int devlink_nl_cmd_resource_dump(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + + if (list_empty(&devlink->resource_list)) + return -EOPNOTSUPP; + + return devlink_resource_fill(info, DEVLINK_CMD_RESOURCE_DUMP, 0); +} + +static int +devlink_resources_validate(struct devlink *devlink, + struct devlink_resource *resource, + struct genl_info *info) +{ + struct list_head *resource_list; + int err = 0; + + if (resource) + resource_list = &resource->resource_list; + else + resource_list = &devlink->resource_list; + + list_for_each_entry(resource, resource_list, list) { + if (!resource->size_valid) + return -EINVAL; + err = devlink_resources_validate(devlink, resource, info); + if (err) + return err; + } + return err; +} + +static struct net *devlink_netns_get(struct sk_buff *skb, + struct genl_info *info) +{ + struct nlattr *netns_pid_attr = info->attrs[DEVLINK_ATTR_NETNS_PID]; + struct nlattr *netns_fd_attr = info->attrs[DEVLINK_ATTR_NETNS_FD]; + struct nlattr *netns_id_attr = info->attrs[DEVLINK_ATTR_NETNS_ID]; + struct net *net; + + if 
(!!netns_pid_attr + !!netns_fd_attr + !!netns_id_attr > 1) { + NL_SET_ERR_MSG_MOD(info->extack, "multiple netns identifying attributes specified"); + return ERR_PTR(-EINVAL); + } + + if (netns_pid_attr) { + net = get_net_ns_by_pid(nla_get_u32(netns_pid_attr)); + } else if (netns_fd_attr) { + net = get_net_ns_by_fd(nla_get_u32(netns_fd_attr)); + } else if (netns_id_attr) { + net = get_net_ns_by_id(sock_net(skb->sk), + nla_get_u32(netns_id_attr)); + if (!net) + net = ERR_PTR(-EINVAL); + } else { + WARN_ON(1); + net = ERR_PTR(-EINVAL); + } + if (IS_ERR(net)) { + NL_SET_ERR_MSG_MOD(info->extack, "Unknown network namespace"); + return ERR_PTR(-EINVAL); + } + if (!netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) { + put_net(net); + return ERR_PTR(-EPERM); + } + return net; +} + +static void devlink_param_notify(struct devlink *devlink, + unsigned int port_index, + struct devlink_param_item *param_item, + enum devlink_command cmd); + +static void devlink_ns_change_notify(struct devlink *devlink, + struct net *dest_net, struct net *curr_net, + bool new) +{ + struct devlink_param_item *param_item; + enum devlink_command cmd; + + /* Userspace needs to be notified about devlink objects + * removed from original and entering new network namespace. + * The rest of the devlink objects are re-created during + * reload process so the notifications are generated separatelly. + */ + + if (!dest_net || net_eq(dest_net, curr_net)) + return; + + if (new) + devlink_notify(devlink, DEVLINK_CMD_NEW); + + cmd = new ? DEVLINK_CMD_PARAM_NEW : DEVLINK_CMD_PARAM_DEL; + list_for_each_entry(param_item, &devlink->param_list, list) + devlink_param_notify(devlink, 0, param_item, cmd); + + if (!new) + devlink_notify(devlink, DEVLINK_CMD_DEL); +} + +static bool devlink_reload_supported(const struct devlink_ops *ops) +{ + return ops->reload_down && ops->reload_up; +} + +static void devlink_reload_failed_set(struct devlink *devlink, + bool reload_failed) +{ + if (devlink->reload_failed == reload_failed) + return; + devlink->reload_failed = reload_failed; + devlink_notify(devlink, DEVLINK_CMD_NEW); +} + +bool devlink_is_reload_failed(const struct devlink *devlink) +{ + return devlink->reload_failed; +} +EXPORT_SYMBOL_GPL(devlink_is_reload_failed); + +static void +__devlink_reload_stats_update(struct devlink *devlink, u32 *reload_stats, + enum devlink_reload_limit limit, u32 actions_performed) +{ + unsigned long actions = actions_performed; + int stat_idx; + int action; + + for_each_set_bit(action, &actions, __DEVLINK_RELOAD_ACTION_MAX) { + stat_idx = limit * __DEVLINK_RELOAD_ACTION_MAX + action; + reload_stats[stat_idx]++; + } + devlink_notify(devlink, DEVLINK_CMD_NEW); +} + +static void +devlink_reload_stats_update(struct devlink *devlink, enum devlink_reload_limit limit, + u32 actions_performed) +{ + __devlink_reload_stats_update(devlink, devlink->stats.reload_stats, limit, + actions_performed); +} + +/** + * devlink_remote_reload_actions_performed - Update devlink on reload actions + * performed which are not a direct result of devlink reload call. + * + * This should be called by a driver after performing reload actions in case it was not + * a result of devlink reload call. For example fw_activate was performed as a result + * of devlink reload triggered fw_activate on another host. + * The motivation for this function is to keep data on reload actions performed on this + * function whether it was done due to direct devlink reload call or not. 
+ * + * @devlink: devlink + * @limit: reload limit + * @actions_performed: bitmask of actions performed + */ +void devlink_remote_reload_actions_performed(struct devlink *devlink, + enum devlink_reload_limit limit, + u32 actions_performed) +{ + if (WARN_ON(!actions_performed || + actions_performed & BIT(DEVLINK_RELOAD_ACTION_UNSPEC) || + actions_performed >= BIT(__DEVLINK_RELOAD_ACTION_MAX) || + limit > DEVLINK_RELOAD_LIMIT_MAX)) + return; + + __devlink_reload_stats_update(devlink, devlink->stats.remote_reload_stats, limit, + actions_performed); +} +EXPORT_SYMBOL_GPL(devlink_remote_reload_actions_performed); + +static int devlink_reload(struct devlink *devlink, struct net *dest_net, + enum devlink_reload_action action, enum devlink_reload_limit limit, + u32 *actions_performed, struct netlink_ext_ack *extack) +{ + u32 remote_reload_stats[DEVLINK_RELOAD_STATS_ARRAY_SIZE]; + struct net *curr_net; + int err; + + memcpy(remote_reload_stats, devlink->stats.remote_reload_stats, + sizeof(remote_reload_stats)); + + curr_net = devlink_net(devlink); + devlink_ns_change_notify(devlink, dest_net, curr_net, false); + err = devlink->ops->reload_down(devlink, !!dest_net, action, limit, extack); + if (err) + return err; + + if (dest_net && !net_eq(dest_net, curr_net)) + write_pnet(&devlink->_net, dest_net); + + err = devlink->ops->reload_up(devlink, action, limit, actions_performed, extack); + devlink_reload_failed_set(devlink, !!err); + if (err) + return err; + + devlink_ns_change_notify(devlink, dest_net, curr_net, true); + WARN_ON(!(*actions_performed & BIT(action))); + /* Catch driver on updating the remote action within devlink reload */ + WARN_ON(memcmp(remote_reload_stats, devlink->stats.remote_reload_stats, + sizeof(remote_reload_stats))); + devlink_reload_stats_update(devlink, limit, *actions_performed); + return 0; +} + +static int +devlink_nl_reload_actions_performed_snd(struct devlink *devlink, u32 actions_performed, + enum devlink_command cmd, struct genl_info *info) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, &devlink_nl_family, 0, cmd); + if (!hdr) + goto free_msg; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (nla_put_bitfield32(msg, DEVLINK_ATTR_RELOAD_ACTIONS_PERFORMED, actions_performed, + actions_performed)) + goto nla_put_failure; + genlmsg_end(msg, hdr); + + return genlmsg_reply(msg, info); + +nla_put_failure: + genlmsg_cancel(msg, hdr); +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_reload(struct sk_buff *skb, struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + enum devlink_reload_action action; + enum devlink_reload_limit limit; + struct net *dest_net = NULL; + u32 actions_performed; + int err; + + if (!(devlink->features & DEVLINK_F_RELOAD)) + return -EOPNOTSUPP; + + err = devlink_resources_validate(devlink, NULL, info); + if (err) { + NL_SET_ERR_MSG_MOD(info->extack, "resources size validation failed"); + return err; + } + + if (info->attrs[DEVLINK_ATTR_RELOAD_ACTION]) + action = nla_get_u8(info->attrs[DEVLINK_ATTR_RELOAD_ACTION]); + else + action = DEVLINK_RELOAD_ACTION_DRIVER_REINIT; + + if (!devlink_reload_action_is_supported(devlink, action)) { + NL_SET_ERR_MSG_MOD(info->extack, + "Requested reload action is not supported by the driver"); + return -EOPNOTSUPP; + } + + limit = DEVLINK_RELOAD_LIMIT_UNSPEC; + if (info->attrs[DEVLINK_ATTR_RELOAD_LIMITS]) { + 
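/* DEVLINK_ATTR_RELOAD_LIMITS is a bitfield32: exactly one limit may be
 * selected, and it must be supported by the driver and be a valid
 * combination with the requested reload action.
 */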
struct nla_bitfield32 limits; + u32 limits_selected; + + limits = nla_get_bitfield32(info->attrs[DEVLINK_ATTR_RELOAD_LIMITS]); + limits_selected = limits.value & limits.selector; + if (!limits_selected) { + NL_SET_ERR_MSG_MOD(info->extack, "Invalid limit selected"); + return -EINVAL; + } + for (limit = 0 ; limit <= DEVLINK_RELOAD_LIMIT_MAX ; limit++) + if (limits_selected & BIT(limit)) + break; + /* UAPI enables multiselection, but currently it is not used */ + if (limits_selected != BIT(limit)) { + NL_SET_ERR_MSG_MOD(info->extack, + "Multiselection of limit is not supported"); + return -EOPNOTSUPP; + } + if (!devlink_reload_limit_is_supported(devlink, limit)) { + NL_SET_ERR_MSG_MOD(info->extack, + "Requested limit is not supported by the driver"); + return -EOPNOTSUPP; + } + if (devlink_reload_combination_is_invalid(action, limit)) { + NL_SET_ERR_MSG_MOD(info->extack, + "Requested limit is invalid for this action"); + return -EINVAL; + } + } + if (info->attrs[DEVLINK_ATTR_NETNS_PID] || + info->attrs[DEVLINK_ATTR_NETNS_FD] || + info->attrs[DEVLINK_ATTR_NETNS_ID]) { + dest_net = devlink_netns_get(skb, info); + if (IS_ERR(dest_net)) + return PTR_ERR(dest_net); + } + + err = devlink_reload(devlink, dest_net, action, limit, &actions_performed, info->extack); + + if (dest_net) + put_net(dest_net); + + if (err) + return err; + /* For backward compatibility generate reply only if attributes used by user */ + if (!info->attrs[DEVLINK_ATTR_RELOAD_ACTION] && !info->attrs[DEVLINK_ATTR_RELOAD_LIMITS]) + return 0; + + return devlink_nl_reload_actions_performed_snd(devlink, actions_performed, + DEVLINK_CMD_RELOAD, info); +} + +static int devlink_nl_flash_update_fill(struct sk_buff *msg, + struct devlink *devlink, + enum devlink_command cmd, + struct devlink_flash_notify *params) +{ + void *hdr; + + hdr = genlmsg_put(msg, 0, 0, &devlink_nl_family, 0, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (cmd != DEVLINK_CMD_FLASH_UPDATE_STATUS) + goto out; + + if (params->status_msg && + nla_put_string(msg, DEVLINK_ATTR_FLASH_UPDATE_STATUS_MSG, + params->status_msg)) + goto nla_put_failure; + if (params->component && + nla_put_string(msg, DEVLINK_ATTR_FLASH_UPDATE_COMPONENT, + params->component)) + goto nla_put_failure; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_FLASH_UPDATE_STATUS_DONE, + params->done, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_FLASH_UPDATE_STATUS_TOTAL, + params->total, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_FLASH_UPDATE_STATUS_TIMEOUT, + params->timeout, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + +out: + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void __devlink_flash_update_notify(struct devlink *devlink, + enum devlink_command cmd, + struct devlink_flash_notify *params) +{ + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_FLASH_UPDATE && + cmd != DEVLINK_CMD_FLASH_UPDATE_END && + cmd != DEVLINK_CMD_FLASH_UPDATE_STATUS); + + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_flash_update_fill(msg, devlink, cmd, params); + if (err) + goto out_free_msg; + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); + return; + +out_free_msg: + nlmsg_free(msg); +} + +static void 
devlink_flash_update_begin_notify(struct devlink *devlink) +{ + struct devlink_flash_notify params = {}; + + __devlink_flash_update_notify(devlink, + DEVLINK_CMD_FLASH_UPDATE, + &params); +} + +static void devlink_flash_update_end_notify(struct devlink *devlink) +{ + struct devlink_flash_notify params = {}; + + __devlink_flash_update_notify(devlink, + DEVLINK_CMD_FLASH_UPDATE_END, + &params); +} + +void devlink_flash_update_status_notify(struct devlink *devlink, + const char *status_msg, + const char *component, + unsigned long done, + unsigned long total) +{ + struct devlink_flash_notify params = { + .status_msg = status_msg, + .component = component, + .done = done, + .total = total, + }; + + __devlink_flash_update_notify(devlink, + DEVLINK_CMD_FLASH_UPDATE_STATUS, + &params); +} +EXPORT_SYMBOL_GPL(devlink_flash_update_status_notify); + +void devlink_flash_update_timeout_notify(struct devlink *devlink, + const char *status_msg, + const char *component, + unsigned long timeout) +{ + struct devlink_flash_notify params = { + .status_msg = status_msg, + .component = component, + .timeout = timeout, + }; + + __devlink_flash_update_notify(devlink, + DEVLINK_CMD_FLASH_UPDATE_STATUS, + &params); +} +EXPORT_SYMBOL_GPL(devlink_flash_update_timeout_notify); + +struct devlink_info_req { + struct sk_buff *msg; + void (*version_cb)(const char *version_name, + enum devlink_info_version_type version_type, + void *version_cb_priv); + void *version_cb_priv; +}; + +struct devlink_flash_component_lookup_ctx { + const char *lookup_name; + bool lookup_name_found; +}; + +static void +devlink_flash_component_lookup_cb(const char *version_name, + enum devlink_info_version_type version_type, + void *version_cb_priv) +{ + struct devlink_flash_component_lookup_ctx *lookup_ctx = version_cb_priv; + + if (version_type != DEVLINK_INFO_VERSION_TYPE_COMPONENT || + lookup_ctx->lookup_name_found) + return; + + lookup_ctx->lookup_name_found = + !strcmp(lookup_ctx->lookup_name, version_name); +} + +static int devlink_flash_component_get(struct devlink *devlink, + struct nlattr *nla_component, + const char **p_component, + struct netlink_ext_ack *extack) +{ + struct devlink_flash_component_lookup_ctx lookup_ctx = {}; + struct devlink_info_req req = {}; + const char *component; + int ret; + + if (!nla_component) + return 0; + + component = nla_data(nla_component); + + if (!devlink->ops->info_get) { + NL_SET_ERR_MSG_ATTR(extack, nla_component, + "component update is not supported by this device"); + return -EOPNOTSUPP; + } + + lookup_ctx.lookup_name = component; + req.version_cb = devlink_flash_component_lookup_cb; + req.version_cb_priv = &lookup_ctx; + + ret = devlink->ops->info_get(devlink, &req, NULL); + if (ret) + return ret; + + if (!lookup_ctx.lookup_name_found) { + NL_SET_ERR_MSG_ATTR(extack, nla_component, + "selected component is not supported by this device"); + return -EINVAL; + } + *p_component = component; + return 0; +} + +static int devlink_nl_cmd_flash_update(struct sk_buff *skb, + struct genl_info *info) +{ + struct nlattr *nla_overwrite_mask, *nla_file_name; + struct devlink_flash_update_params params = {}; + struct devlink *devlink = info->user_ptr[0]; + const char *file_name; + u32 supported_params; + int ret; + + if (!devlink->ops->flash_update) + return -EOPNOTSUPP; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME)) + return -EINVAL; + + ret = devlink_flash_component_get(devlink, + info->attrs[DEVLINK_ATTR_FLASH_UPDATE_COMPONENT], + &params.component, info->extack); + if (ret) + return ret; + + 
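/* Optional flash attributes are honoured only if the driver advertises
 * them in ops->supported_flash_update_params; an overwrite mask the
 * driver does not support is rejected instead of being silently
 * dropped.
 */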
supported_params = devlink->ops->supported_flash_update_params; + + nla_overwrite_mask = info->attrs[DEVLINK_ATTR_FLASH_UPDATE_OVERWRITE_MASK]; + if (nla_overwrite_mask) { + struct nla_bitfield32 sections; + + if (!(supported_params & DEVLINK_SUPPORT_FLASH_UPDATE_OVERWRITE_MASK)) { + NL_SET_ERR_MSG_ATTR(info->extack, nla_overwrite_mask, + "overwrite settings are not supported by this device"); + return -EOPNOTSUPP; + } + sections = nla_get_bitfield32(nla_overwrite_mask); + params.overwrite_mask = sections.value & sections.selector; + } + + nla_file_name = info->attrs[DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME]; + file_name = nla_data(nla_file_name); + ret = request_firmware(&params.fw, file_name, devlink->dev); + if (ret) { + NL_SET_ERR_MSG_ATTR(info->extack, nla_file_name, "failed to locate the requested firmware file"); + return ret; + } + + devlink_flash_update_begin_notify(devlink); + ret = devlink->ops->flash_update(devlink, &params, info->extack); + devlink_flash_update_end_notify(devlink); + + release_firmware(params.fw); + + return ret; +} + +static int +devlink_nl_selftests_fill(struct sk_buff *msg, struct devlink *devlink, + u32 portid, u32 seq, int flags, + struct netlink_ext_ack *extack) +{ + struct nlattr *selftests; + void *hdr; + int err; + int i; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, + DEVLINK_CMD_SELFTESTS_GET); + if (!hdr) + return -EMSGSIZE; + + err = -EMSGSIZE; + if (devlink_nl_put_handle(msg, devlink)) + goto err_cancel_msg; + + selftests = nla_nest_start(msg, DEVLINK_ATTR_SELFTESTS); + if (!selftests) + goto err_cancel_msg; + + for (i = DEVLINK_ATTR_SELFTEST_ID_UNSPEC + 1; + i <= DEVLINK_ATTR_SELFTEST_ID_MAX; i++) { + if (devlink->ops->selftest_check(devlink, i, extack)) { + err = nla_put_flag(msg, i); + if (err) + goto err_cancel_msg; + } + } + + nla_nest_end(msg, selftests); + genlmsg_end(msg, hdr); + return 0; + +err_cancel_msg: + genlmsg_cancel(msg, hdr); + return err; +} + +static int devlink_nl_cmd_selftests_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + if (!devlink->ops->selftest_check) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_selftests_fill(msg, devlink, info->snd_portid, + info->snd_seq, 0, info->extack); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_selftests_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (idx < start || !devlink->ops->selftest_check) + goto inc; + + devl_lock(devlink); + err = devlink_nl_selftests_fill(msg, devlink, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->extack); + devl_unlock(devlink); + if (err) { + devlink_put(devlink); + break; + } +inc: + idx++; + devlink_put(devlink); + } + + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int devlink_selftest_result_put(struct sk_buff *skb, unsigned int id, + enum devlink_selftest_status test_status) +{ + struct nlattr *result_attr; + + result_attr = nla_nest_start(skb, DEVLINK_ATTR_SELFTEST_RESULT); + if (!result_attr) + return -EMSGSIZE; + + if (nla_put_u32(skb, DEVLINK_ATTR_SELFTEST_RESULT_ID, id) || + nla_put_u8(skb, DEVLINK_ATTR_SELFTEST_RESULT_STATUS, 
+ test_status)) + goto nla_put_failure; + + nla_nest_end(skb, result_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, result_attr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_selftests_run(struct sk_buff *skb, + struct genl_info *info) +{ + struct nlattr *tb[DEVLINK_ATTR_SELFTEST_ID_MAX + 1]; + struct devlink *devlink = info->user_ptr[0]; + struct nlattr *attrs, *selftests; + struct sk_buff *msg; + void *hdr; + int err; + int i; + + if (!devlink->ops->selftest_run || !devlink->ops->selftest_check) + return -EOPNOTSUPP; + + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_SELFTESTS)) + return -EINVAL; + + attrs = info->attrs[DEVLINK_ATTR_SELFTESTS]; + + err = nla_parse_nested(tb, DEVLINK_ATTR_SELFTEST_ID_MAX, attrs, + devlink_selftest_nl_policy, info->extack); + if (err < 0) + return err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = -EMSGSIZE; + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &devlink_nl_family, 0, DEVLINK_CMD_SELFTESTS_RUN); + if (!hdr) + goto free_msg; + + if (devlink_nl_put_handle(msg, devlink)) + goto genlmsg_cancel; + + selftests = nla_nest_start(msg, DEVLINK_ATTR_SELFTESTS); + if (!selftests) + goto genlmsg_cancel; + + for (i = DEVLINK_ATTR_SELFTEST_ID_UNSPEC + 1; + i <= DEVLINK_ATTR_SELFTEST_ID_MAX; i++) { + enum devlink_selftest_status test_status; + + if (nla_get_flag(tb[i])) { + if (!devlink->ops->selftest_check(devlink, i, + info->extack)) { + if (devlink_selftest_result_put(msg, i, + DEVLINK_SELFTEST_STATUS_SKIP)) + goto selftests_nest_cancel; + continue; + } + + test_status = devlink->ops->selftest_run(devlink, i, + info->extack); + if (devlink_selftest_result_put(msg, i, test_status)) + goto selftests_nest_cancel; + } + } + + nla_nest_end(msg, selftests); + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +selftests_nest_cancel: + nla_nest_cancel(msg, selftests); +genlmsg_cancel: + genlmsg_cancel(msg, hdr); +free_msg: + nlmsg_free(msg); + return err; +} + +static const struct devlink_param devlink_param_generic[] = { + { + .id = DEVLINK_PARAM_GENERIC_ID_INT_ERR_RESET, + .name = DEVLINK_PARAM_GENERIC_INT_ERR_RESET_NAME, + .type = DEVLINK_PARAM_GENERIC_INT_ERR_RESET_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_MAX_MACS, + .name = DEVLINK_PARAM_GENERIC_MAX_MACS_NAME, + .type = DEVLINK_PARAM_GENERIC_MAX_MACS_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_SRIOV, + .name = DEVLINK_PARAM_GENERIC_ENABLE_SRIOV_NAME, + .type = DEVLINK_PARAM_GENERIC_ENABLE_SRIOV_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_REGION_SNAPSHOT, + .name = DEVLINK_PARAM_GENERIC_REGION_SNAPSHOT_NAME, + .type = DEVLINK_PARAM_GENERIC_REGION_SNAPSHOT_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_IGNORE_ARI, + .name = DEVLINK_PARAM_GENERIC_IGNORE_ARI_NAME, + .type = DEVLINK_PARAM_GENERIC_IGNORE_ARI_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_MSIX_VEC_PER_PF_MAX, + .name = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MAX_NAME, + .type = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MAX_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_MSIX_VEC_PER_PF_MIN, + .name = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MIN_NAME, + .type = DEVLINK_PARAM_GENERIC_MSIX_VEC_PER_PF_MIN_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_FW_LOAD_POLICY, + .name = DEVLINK_PARAM_GENERIC_FW_LOAD_POLICY_NAME, + .type = DEVLINK_PARAM_GENERIC_FW_LOAD_POLICY_TYPE, + }, + { + .id = DEVLINK_PARAM_GENERIC_ID_RESET_DEV_ON_DRV_PROBE, + .name = DEVLINK_PARAM_GENERIC_RESET_DEV_ON_DRV_PROBE_NAME, + .type = DEVLINK_PARAM_GENERIC_RESET_DEV_ON_DRV_PROBE_TYPE, 
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_ROCE,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_ROCE_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_ROCE_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_REMOTE_DEV_RESET,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_REMOTE_DEV_RESET_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_REMOTE_DEV_RESET_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_ETH,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_ETH_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_ETH_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_RDMA,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_RDMA_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_RDMA_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_VNET,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_VNET_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_VNET_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_ENABLE_IWARP,
+ .name = DEVLINK_PARAM_GENERIC_ENABLE_IWARP_NAME,
+ .type = DEVLINK_PARAM_GENERIC_ENABLE_IWARP_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_IO_EQ_SIZE,
+ .name = DEVLINK_PARAM_GENERIC_IO_EQ_SIZE_NAME,
+ .type = DEVLINK_PARAM_GENERIC_IO_EQ_SIZE_TYPE,
+ },
+ {
+ .id = DEVLINK_PARAM_GENERIC_ID_EVENT_EQ_SIZE,
+ .name = DEVLINK_PARAM_GENERIC_EVENT_EQ_SIZE_NAME,
+ .type = DEVLINK_PARAM_GENERIC_EVENT_EQ_SIZE_TYPE,
+ },
+};
+
+static int devlink_param_generic_verify(const struct devlink_param *param)
+{
+ /* verify it match generic parameter by id and name */
+ if (param->id > DEVLINK_PARAM_GENERIC_ID_MAX)
+ return -EINVAL;
+ if (strcmp(param->name, devlink_param_generic[param->id].name))
+ return -ENOENT;
+
+ WARN_ON(param->type != devlink_param_generic[param->id].type);
+
+ return 0;
+}
+
+static int devlink_param_driver_verify(const struct devlink_param *param)
+{
+ int i;
+
+ if (param->id <= DEVLINK_PARAM_GENERIC_ID_MAX)
+ return -EINVAL;
+ /* verify no such name in generic params */
+ for (i = 0; i <= DEVLINK_PARAM_GENERIC_ID_MAX; i++)
+ if (!strcmp(param->name, devlink_param_generic[i].name))
+ return -EEXIST;
+
+ return 0;
+}
+
+static struct devlink_param_item *
+devlink_param_find_by_name(struct list_head *param_list,
+ const char *param_name)
+{
+ struct devlink_param_item *param_item;
+
+ list_for_each_entry(param_item, param_list, list)
+ if (!strcmp(param_item->param->name, param_name))
+ return param_item;
+ return NULL;
+}
+
+static struct devlink_param_item *
+devlink_param_find_by_id(struct list_head *param_list, u32 param_id)
+{
+ struct devlink_param_item *param_item;
+
+ list_for_each_entry(param_item, param_list, list)
+ if (param_item->param->id == param_id)
+ return param_item;
+ return NULL;
+}
+
+static bool
+devlink_param_cmode_is_supported(const struct devlink_param *param,
+ enum devlink_param_cmode cmode)
+{
+ return test_bit(cmode, &param->supported_cmodes);
+}
+
+static int devlink_param_get(struct devlink *devlink,
+ const struct devlink_param *param,
+ struct devlink_param_gset_ctx *ctx)
+{
+ if (!param->get)
+ return -EOPNOTSUPP;
+ return param->get(devlink, param->id, ctx);
+}
+
+static int devlink_param_set(struct devlink *devlink,
+ const struct devlink_param *param,
+ struct devlink_param_gset_ctx *ctx)
+{
+ if (!param->set)
+ return -EOPNOTSUPP;
+ return param->set(devlink, param->id, ctx);
+}
+
+static int
+devlink_param_type_to_nla_type(enum devlink_param_type param_type)
+{
+ switch (param_type) {
+ case DEVLINK_PARAM_TYPE_U8:
+ return NLA_U8;
+ case DEVLINK_PARAM_TYPE_U16:
+ return NLA_U16;
+ case DEVLINK_PARAM_TYPE_U32:
+ return NLA_U32;
+ case DEVLINK_PARAM_TYPE_STRING:
+ return NLA_STRING;
+ case
DEVLINK_PARAM_TYPE_BOOL: + return NLA_FLAG; + default: + return -EINVAL; + } +} + +static int +devlink_nl_param_value_fill_one(struct sk_buff *msg, + enum devlink_param_type type, + enum devlink_param_cmode cmode, + union devlink_param_value val) +{ + struct nlattr *param_value_attr; + + param_value_attr = nla_nest_start_noflag(msg, + DEVLINK_ATTR_PARAM_VALUE); + if (!param_value_attr) + goto nla_put_failure; + + if (nla_put_u8(msg, DEVLINK_ATTR_PARAM_VALUE_CMODE, cmode)) + goto value_nest_cancel; + + switch (type) { + case DEVLINK_PARAM_TYPE_U8: + if (nla_put_u8(msg, DEVLINK_ATTR_PARAM_VALUE_DATA, val.vu8)) + goto value_nest_cancel; + break; + case DEVLINK_PARAM_TYPE_U16: + if (nla_put_u16(msg, DEVLINK_ATTR_PARAM_VALUE_DATA, val.vu16)) + goto value_nest_cancel; + break; + case DEVLINK_PARAM_TYPE_U32: + if (nla_put_u32(msg, DEVLINK_ATTR_PARAM_VALUE_DATA, val.vu32)) + goto value_nest_cancel; + break; + case DEVLINK_PARAM_TYPE_STRING: + if (nla_put_string(msg, DEVLINK_ATTR_PARAM_VALUE_DATA, + val.vstr)) + goto value_nest_cancel; + break; + case DEVLINK_PARAM_TYPE_BOOL: + if (val.vbool && + nla_put_flag(msg, DEVLINK_ATTR_PARAM_VALUE_DATA)) + goto value_nest_cancel; + break; + } + + nla_nest_end(msg, param_value_attr); + return 0; + +value_nest_cancel: + nla_nest_cancel(msg, param_value_attr); +nla_put_failure: + return -EMSGSIZE; +} + +static int devlink_nl_param_fill(struct sk_buff *msg, struct devlink *devlink, + unsigned int port_index, + struct devlink_param_item *param_item, + enum devlink_command cmd, + u32 portid, u32 seq, int flags) +{ + union devlink_param_value param_value[DEVLINK_PARAM_CMODE_MAX + 1]; + bool param_value_set[DEVLINK_PARAM_CMODE_MAX + 1] = {}; + const struct devlink_param *param = param_item->param; + struct devlink_param_gset_ctx ctx; + struct nlattr *param_values_list; + struct nlattr *param_attr; + int nla_type; + void *hdr; + int err; + int i; + + /* Get value from driver part to driverinit configuration mode */ + for (i = 0; i <= DEVLINK_PARAM_CMODE_MAX; i++) { + if (!devlink_param_cmode_is_supported(param, i)) + continue; + if (i == DEVLINK_PARAM_CMODE_DRIVERINIT) { + if (!param_item->driverinit_value_valid) + return -EOPNOTSUPP; + param_value[i] = param_item->driverinit_value; + } else { + ctx.cmode = i; + err = devlink_param_get(devlink, param, &ctx); + if (err) + return err; + param_value[i] = ctx.val; + } + param_value_set[i] = true; + } + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto genlmsg_cancel; + + if (cmd == DEVLINK_CMD_PORT_PARAM_GET || + cmd == DEVLINK_CMD_PORT_PARAM_NEW || + cmd == DEVLINK_CMD_PORT_PARAM_DEL) + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, port_index)) + goto genlmsg_cancel; + + param_attr = nla_nest_start_noflag(msg, DEVLINK_ATTR_PARAM); + if (!param_attr) + goto genlmsg_cancel; + if (nla_put_string(msg, DEVLINK_ATTR_PARAM_NAME, param->name)) + goto param_nest_cancel; + if (param->generic && nla_put_flag(msg, DEVLINK_ATTR_PARAM_GENERIC)) + goto param_nest_cancel; + + nla_type = devlink_param_type_to_nla_type(param->type); + if (nla_type < 0) + goto param_nest_cancel; + if (nla_put_u8(msg, DEVLINK_ATTR_PARAM_TYPE, nla_type)) + goto param_nest_cancel; + + param_values_list = nla_nest_start_noflag(msg, + DEVLINK_ATTR_PARAM_VALUES_LIST); + if (!param_values_list) + goto param_nest_cancel; + + for (i = 0; i <= DEVLINK_PARAM_CMODE_MAX; i++) { + if (!param_value_set[i]) + continue; + err = devlink_nl_param_value_fill_one(msg, 
param->type, + i, param_value[i]); + if (err) + goto values_list_nest_cancel; + } + + nla_nest_end(msg, param_values_list); + nla_nest_end(msg, param_attr); + genlmsg_end(msg, hdr); + return 0; + +values_list_nest_cancel: + nla_nest_end(msg, param_values_list); +param_nest_cancel: + nla_nest_cancel(msg, param_attr); +genlmsg_cancel: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void devlink_param_notify(struct devlink *devlink, + unsigned int port_index, + struct devlink_param_item *param_item, + enum devlink_command cmd) +{ + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_PARAM_NEW && cmd != DEVLINK_CMD_PARAM_DEL && + cmd != DEVLINK_CMD_PORT_PARAM_NEW && + cmd != DEVLINK_CMD_PORT_PARAM_DEL); + ASSERT_DEVLINK_REGISTERED(devlink); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + err = devlink_nl_param_fill(msg, devlink, port_index, param_item, cmd, + 0, 0, 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int devlink_nl_cmd_param_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_param_item *param_item; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(param_item, &devlink->param_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_param_fill(msg, devlink, 0, param_item, + DEVLINK_CMD_PARAM_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err == -EOPNOTSUPP) { + err = 0; + } else if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +static int +devlink_param_type_get_from_info(struct genl_info *info, + enum devlink_param_type *param_type) +{ + if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_PARAM_TYPE)) + return -EINVAL; + + switch (nla_get_u8(info->attrs[DEVLINK_ATTR_PARAM_TYPE])) { + case NLA_U8: + *param_type = DEVLINK_PARAM_TYPE_U8; + break; + case NLA_U16: + *param_type = DEVLINK_PARAM_TYPE_U16; + break; + case NLA_U32: + *param_type = DEVLINK_PARAM_TYPE_U32; + break; + case NLA_STRING: + *param_type = DEVLINK_PARAM_TYPE_STRING; + break; + case NLA_FLAG: + *param_type = DEVLINK_PARAM_TYPE_BOOL; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int +devlink_param_value_get_from_info(const struct devlink_param *param, + struct genl_info *info, + union devlink_param_value *value) +{ + struct nlattr *param_data; + int len; + + param_data = info->attrs[DEVLINK_ATTR_PARAM_VALUE_DATA]; + + if (param->type != DEVLINK_PARAM_TYPE_BOOL && !param_data) + return -EINVAL; + + switch (param->type) { + case DEVLINK_PARAM_TYPE_U8: + if (nla_len(param_data) != sizeof(u8)) + return -EINVAL; + value->vu8 = nla_get_u8(param_data); + break; + case DEVLINK_PARAM_TYPE_U16: + if (nla_len(param_data) != sizeof(u16)) + return -EINVAL; + value->vu16 = nla_get_u16(param_data); + break; + case DEVLINK_PARAM_TYPE_U32: + if (nla_len(param_data) != sizeof(u32)) + return -EINVAL; + value->vu32 = nla_get_u32(param_data); + break; + case DEVLINK_PARAM_TYPE_STRING: + len = strnlen(nla_data(param_data), nla_len(param_data)); + if (len == nla_len(param_data) || + len >= 
__DEVLINK_PARAM_MAX_STRING_VALUE)
+ return -EINVAL;
+ strcpy(value->vstr, nla_data(param_data));
+ break;
+ case DEVLINK_PARAM_TYPE_BOOL:
+ if (param_data && nla_len(param_data))
+ return -EINVAL;
+ value->vbool = nla_get_flag(param_data);
+ break;
+ }
+ return 0;
+}
+
+static struct devlink_param_item *
+devlink_param_get_from_info(struct list_head *param_list,
+ struct genl_info *info)
+{
+ char *param_name;
+
+ if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_PARAM_NAME))
+ return NULL;
+
+ param_name = nla_data(info->attrs[DEVLINK_ATTR_PARAM_NAME]);
+ return devlink_param_find_by_name(param_list, param_name);
+}
+
+static int devlink_nl_cmd_param_get_doit(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct devlink *devlink = info->user_ptr[0];
+ struct devlink_param_item *param_item;
+ struct sk_buff *msg;
+ int err;
+
+ param_item = devlink_param_get_from_info(&devlink->param_list, info);
+ if (!param_item)
+ return -EINVAL;
+
+ msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!msg)
+ return -ENOMEM;
+
+ err = devlink_nl_param_fill(msg, devlink, 0, param_item,
+ DEVLINK_CMD_PARAM_GET,
+ info->snd_portid, info->snd_seq, 0);
+ if (err) {
+ nlmsg_free(msg);
+ return err;
+ }
+
+ return genlmsg_reply(msg, info);
+}
+
+static int __devlink_nl_cmd_param_set_doit(struct devlink *devlink,
+ unsigned int port_index,
+ struct list_head *param_list,
+ struct genl_info *info,
+ enum devlink_command cmd)
+{
+ enum devlink_param_type param_type;
+ struct devlink_param_gset_ctx ctx;
+ enum devlink_param_cmode cmode;
+ struct devlink_param_item *param_item;
+ const struct devlink_param *param;
+ union devlink_param_value value;
+ int err = 0;
+
+ param_item = devlink_param_get_from_info(param_list, info);
+ if (!param_item)
+ return -EINVAL;
+ param = param_item->param;
+ err = devlink_param_type_get_from_info(info, &param_type);
+ if (err)
+ return err;
+ if (param_type != param->type)
+ return -EINVAL;
+ err = devlink_param_value_get_from_info(param, info, &value);
+ if (err)
+ return err;
+ if (param->validate) {
+ err = param->validate(devlink, param->id, value, info->extack);
+ if (err)
+ return err;
+ }
+
+ if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_PARAM_VALUE_CMODE))
+ return -EINVAL;
+ cmode = nla_get_u8(info->attrs[DEVLINK_ATTR_PARAM_VALUE_CMODE]);
+ if (!devlink_param_cmode_is_supported(param, cmode))
+ return -EOPNOTSUPP;
+
+ if (cmode == DEVLINK_PARAM_CMODE_DRIVERINIT) {
+ if (param->type == DEVLINK_PARAM_TYPE_STRING)
+ strcpy(param_item->driverinit_value.vstr, value.vstr);
+ else
+ param_item->driverinit_value = value;
+ param_item->driverinit_value_valid = true;
+ } else {
+ if (!param->set)
+ return -EOPNOTSUPP;
+ ctx.val = value;
+ ctx.cmode = cmode;
+ err = devlink_param_set(devlink, param, &ctx);
+ if (err)
+ return err;
+ }
+
+ devlink_param_notify(devlink, port_index, param_item, cmd);
+ return 0;
+}
+
+static int devlink_nl_cmd_param_set_doit(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct devlink *devlink = info->user_ptr[0];
+
+ return __devlink_nl_cmd_param_set_doit(devlink, 0, &devlink->param_list,
+ info, DEVLINK_CMD_PARAM_NEW);
+}
+
+static int devlink_nl_cmd_port_param_get_dumpit(struct sk_buff *msg,
+ struct netlink_callback *cb)
+{
+ NL_SET_ERR_MSG_MOD(cb->extack, "Port params are not supported");
+ return msg->len;
+}
+
+static int devlink_nl_cmd_port_param_get_doit(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ NL_SET_ERR_MSG_MOD(info->extack, "Port params are not supported");
+ return -EINVAL;
+}
+
+static int devlink_nl_cmd_port_param_set_doit(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ NL_SET_ERR_MSG_MOD(info->extack, "Port params are not supported");
+ return -EINVAL;
+}
+
+static int devlink_nl_region_snapshot_id_put(struct sk_buff *msg,
+ struct devlink *devlink,
+ struct devlink_snapshot *snapshot)
+{
+ struct nlattr *snap_attr;
+ int err;
+
+ snap_attr = nla_nest_start_noflag(msg, DEVLINK_ATTR_REGION_SNAPSHOT);
+ if (!snap_attr)
+ return -EINVAL;
+
+ err = nla_put_u32(msg, DEVLINK_ATTR_REGION_SNAPSHOT_ID, snapshot->id);
+ if (err)
+ goto nla_put_failure;
+
+ nla_nest_end(msg, snap_attr);
+ return 0;
+
+nla_put_failure:
+ nla_nest_cancel(msg, snap_attr);
+ return err;
+}
+
+static int devlink_nl_region_snapshots_id_put(struct sk_buff *msg,
+ struct devlink *devlink,
+ struct devlink_region *region)
+{
+ struct devlink_snapshot *snapshot;
+ struct nlattr *snapshots_attr;
+ int err;
+
+ snapshots_attr = nla_nest_start_noflag(msg,
+ DEVLINK_ATTR_REGION_SNAPSHOTS);
+ if (!snapshots_attr)
+ return -EINVAL;
+
+ list_for_each_entry(snapshot, &region->snapshot_list, list) {
+ err = devlink_nl_region_snapshot_id_put(msg, devlink, snapshot);
+ if (err)
+ goto nla_put_failure;
+ }
+
+ nla_nest_end(msg, snapshots_attr);
+ return 0;
+
+nla_put_failure:
+ nla_nest_cancel(msg, snapshots_attr);
+ return err;
+}
+
+static int devlink_nl_region_fill(struct sk_buff *msg, struct devlink *devlink,
+ enum devlink_command cmd, u32 portid,
+ u32 seq, int flags,
+ struct devlink_region *region)
+{
+ void *hdr;
+ int err;
+
+ hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd);
+ if (!hdr)
+ return -EMSGSIZE;
+
+ err = devlink_nl_put_handle(msg, devlink);
+ if (err)
+ goto nla_put_failure;
+
+ if (region->port) {
+ err = nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX,
+ region->port->index);
+ if (err)
+ goto nla_put_failure;
+ }
+
+ err = nla_put_string(msg, DEVLINK_ATTR_REGION_NAME, region->ops->name);
+ if (err)
+ goto nla_put_failure;
+
+ err = nla_put_u64_64bit(msg, DEVLINK_ATTR_REGION_SIZE,
+ region->size,
+ DEVLINK_ATTR_PAD);
+ if (err)
+ goto nla_put_failure;
+
+ err = nla_put_u32(msg, DEVLINK_ATTR_REGION_MAX_SNAPSHOTS,
+ region->max_snapshots);
+ if (err)
+ goto nla_put_failure;
+
+ err = devlink_nl_region_snapshots_id_put(msg, devlink, region);
+ if (err)
+ goto nla_put_failure;
+
+ genlmsg_end(msg, hdr);
+ return 0;
+
+nla_put_failure:
+ genlmsg_cancel(msg, hdr);
+ return err;
+}
+
+static struct sk_buff *
+devlink_nl_region_notify_build(struct devlink_region *region,
+ struct devlink_snapshot *snapshot,
+ enum devlink_command cmd, u32 portid, u32 seq)
+{
+ struct devlink *devlink = region->devlink;
+ struct sk_buff *msg;
+ void *hdr;
+ int err;
+
+
+ msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!msg)
+ return ERR_PTR(-ENOMEM);
+
+ hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, 0, cmd);
+ if (!hdr) {
+ err = -EMSGSIZE;
+ goto out_free_msg;
+ }
+
+ err = devlink_nl_put_handle(msg, devlink);
+ if (err)
+ goto out_cancel_msg;
+
+ if (region->port) {
+ err = nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX,
+ region->port->index);
+ if (err)
+ goto out_cancel_msg;
+ }
+
+ err = nla_put_string(msg, DEVLINK_ATTR_REGION_NAME,
+ region->ops->name);
+ if (err)
+ goto out_cancel_msg;
+
+ if (snapshot) {
+ err = nla_put_u32(msg, DEVLINK_ATTR_REGION_SNAPSHOT_ID,
+ snapshot->id);
+ if (err)
+ goto out_cancel_msg;
+ } else {
+ err = nla_put_u64_64bit(msg, DEVLINK_ATTR_REGION_SIZE,
+ region->size, DEVLINK_ATTR_PAD);
+ if (err)
+ goto out_cancel_msg;
+ }
+ genlmsg_end(msg, hdr);
+
+ return msg;
+ +out_cancel_msg: + genlmsg_cancel(msg, hdr); +out_free_msg: + nlmsg_free(msg); + return ERR_PTR(err); +} + +static void devlink_nl_region_notify(struct devlink_region *region, + struct devlink_snapshot *snapshot, + enum devlink_command cmd) +{ + struct devlink *devlink = region->devlink; + struct sk_buff *msg; + + WARN_ON(cmd != DEVLINK_CMD_REGION_NEW && cmd != DEVLINK_CMD_REGION_DEL); + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = devlink_nl_region_notify_build(region, snapshot, cmd, 0, 0); + if (IS_ERR(msg)) + return; + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), msg, + 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +/** + * __devlink_snapshot_id_increment - Increment number of snapshots using an id + * @devlink: devlink instance + * @id: the snapshot id + * + * Track when a new snapshot begins using an id. Load the count for the + * given id from the snapshot xarray, increment it, and store it back. + * + * Called when a new snapshot is created with the given id. + * + * The id *must* have been previously allocated by + * devlink_region_snapshot_id_get(). + * + * Returns 0 on success, or an error on failure. + */ +static int __devlink_snapshot_id_increment(struct devlink *devlink, u32 id) +{ + unsigned long count; + void *p; + int err; + + xa_lock(&devlink->snapshot_ids); + p = xa_load(&devlink->snapshot_ids, id); + if (WARN_ON(!p)) { + err = -EINVAL; + goto unlock; + } + + if (WARN_ON(!xa_is_value(p))) { + err = -EINVAL; + goto unlock; + } + + count = xa_to_value(p); + count++; + + err = xa_err(__xa_store(&devlink->snapshot_ids, id, xa_mk_value(count), + GFP_ATOMIC)); +unlock: + xa_unlock(&devlink->snapshot_ids); + return err; +} + +/** + * __devlink_snapshot_id_decrement - Decrease number of snapshots using an id + * @devlink: devlink instance + * @id: the snapshot id + * + * Track when a snapshot is deleted and stops using an id. Load the count + * for the given id from the snapshot xarray, decrement it, and store it + * back. + * + * If the count reaches zero, erase this id from the xarray, freeing it + * up for future re-use by devlink_region_snapshot_id_get(). + * + * Called when a snapshot using the given id is deleted, and when the + * initial allocator of the id is finished using it. + */ +static void __devlink_snapshot_id_decrement(struct devlink *devlink, u32 id) +{ + unsigned long count; + void *p; + + xa_lock(&devlink->snapshot_ids); + p = xa_load(&devlink->snapshot_ids, id); + if (WARN_ON(!p)) + goto unlock; + + if (WARN_ON(!xa_is_value(p))) + goto unlock; + + count = xa_to_value(p); + + if (count > 1) { + count--; + __xa_store(&devlink->snapshot_ids, id, xa_mk_value(count), + GFP_ATOMIC); + } else { + /* If this was the last user, we can erase this id */ + __xa_erase(&devlink->snapshot_ids, id); + } +unlock: + xa_unlock(&devlink->snapshot_ids); +} + +/** + * __devlink_snapshot_id_insert - Insert a specific snapshot ID + * @devlink: devlink instance + * @id: the snapshot id + * + * Mark the given snapshot id as used by inserting a zero value into the + * snapshot xarray. + * + * This must be called while holding the devlink instance lock. Unlike + * devlink_snapshot_id_get, the initial reference count is zero, not one. + * It is expected that the id will immediately be used before + * releasing the devlink instance lock. + * + * Returns zero on success, or an error code if the snapshot id could not + * be inserted. 
+ */
+static int __devlink_snapshot_id_insert(struct devlink *devlink, u32 id)
+{
+ int err;
+
+ xa_lock(&devlink->snapshot_ids);
+ if (xa_load(&devlink->snapshot_ids, id)) {
+ xa_unlock(&devlink->snapshot_ids);
+ return -EEXIST;
+ }
+ err = xa_err(__xa_store(&devlink->snapshot_ids, id, xa_mk_value(0),
+ GFP_ATOMIC));
+ xa_unlock(&devlink->snapshot_ids);
+ return err;
+}
+
+/**
+ * __devlink_region_snapshot_id_get - get snapshot ID
+ * @devlink: devlink instance
+ * @id: storage to return snapshot id
+ *
+ * Allocates a new snapshot id. Returns zero on success, or a negative
+ * error on failure. Must be called while holding the devlink instance
+ * lock.
+ *
+ * Snapshot IDs are tracked using an xarray which stores the number of
+ * users of the snapshot id.
+ *
+ * Note that the caller of this function counts as a 'user', in order to
+ * avoid race conditions. The caller must release its hold on the
+ * snapshot by using devlink_region_snapshot_id_put.
+ */
+static int __devlink_region_snapshot_id_get(struct devlink *devlink, u32 *id)
+{
+ return xa_alloc(&devlink->snapshot_ids, id, xa_mk_value(1),
+ xa_limit_32b, GFP_KERNEL);
+}
+
+/**
+ * __devlink_region_snapshot_create - create a new snapshot
+ * This will add a new snapshot of a region. The snapshot
+ * will be stored on the region struct and can be accessed
+ * from devlink. This is useful for future analyses of snapshots.
+ * Multiple snapshots can be created on a region.
+ * The @snapshot_id should be obtained using the getter function.
+ *
+ * Must be called only while holding the region snapshot lock.
+ *
+ * @region: devlink region of the snapshot
+ * @data: snapshot data
+ * @snapshot_id: snapshot id to be created
+ */
+static int
+__devlink_region_snapshot_create(struct devlink_region *region,
+ u8 *data, u32 snapshot_id)
+{
+ struct devlink *devlink = region->devlink;
+ struct devlink_snapshot *snapshot;
+ int err;
+
+ lockdep_assert_held(&region->snapshot_lock);
+
+ /* check if region can hold one more snapshot */
+ if (region->cur_snapshots == region->max_snapshots)
+ return -ENOSPC;
+
+ if (devlink_region_snapshot_get_by_id(region, snapshot_id))
+ return -EEXIST;
+
+ snapshot = kzalloc(sizeof(*snapshot), GFP_KERNEL);
+ if (!snapshot)
+ return -ENOMEM;
+
+ err = __devlink_snapshot_id_increment(devlink, snapshot_id);
+ if (err)
+ goto err_snapshot_id_increment;
+
+ snapshot->id = snapshot_id;
+ snapshot->region = region;
+ snapshot->data = data;
+
+ list_add_tail(&snapshot->list, &region->snapshot_list);
+
+ region->cur_snapshots++;
+
+ devlink_nl_region_notify(region, snapshot, DEVLINK_CMD_REGION_NEW);
+ return 0;
+
+err_snapshot_id_increment:
+ kfree(snapshot);
+ return err;
+}
+
+static void devlink_region_snapshot_del(struct devlink_region *region,
+ struct devlink_snapshot *snapshot)
+{
+ struct devlink *devlink = region->devlink;
+
+ lockdep_assert_held(&region->snapshot_lock);
+
+ devlink_nl_region_notify(region, snapshot, DEVLINK_CMD_REGION_DEL);
+ region->cur_snapshots--;
+ list_del(&snapshot->list);
+ region->ops->destructor(snapshot->data);
+ __devlink_snapshot_id_decrement(devlink, snapshot->id);
+ kfree(snapshot);
+}
+
+static int devlink_nl_cmd_region_get_doit(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct devlink *devlink = info->user_ptr[0];
+ struct devlink_port *port = NULL;
+ struct devlink_region *region;
+ const char *region_name;
+ struct sk_buff *msg;
+ unsigned int index;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_REGION_NAME))
+ return -EINVAL;
+
+ if (info->attrs[DEVLINK_ATTR_PORT_INDEX]) {
+ index = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]);
+
+ port = devlink_port_get_by_index(devlink, index);
+ if (!port)
+ return -ENODEV;
+ }
+
+ region_name = nla_data(info->attrs[DEVLINK_ATTR_REGION_NAME]);
+ if (port)
+ region = devlink_port_region_get_by_name(port, region_name);
+ else
+ region = devlink_region_get_by_name(devlink, region_name);
+
+ if (!region)
+ return -EINVAL;
+
+ msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
+ if (!msg)
+ return -ENOMEM;
+
+ err = devlink_nl_region_fill(msg, devlink, DEVLINK_CMD_REGION_GET,
+ info->snd_portid, info->snd_seq, 0,
+ region);
+ if (err) {
+ nlmsg_free(msg);
+ return err;
+ }
+
+ return genlmsg_reply(msg, info);
+}
+
+static int devlink_nl_cmd_region_get_port_dumpit(struct sk_buff *msg,
+ struct netlink_callback *cb,
+ struct devlink_port *port,
+ int *idx,
+ int start)
+{
+ struct devlink_region *region;
+ int err = 0;
+
+ list_for_each_entry(region, &port->region_list, list) {
+ if (*idx < start) {
+ (*idx)++;
+ continue;
+ }
+ err = devlink_nl_region_fill(msg, port->devlink,
+ DEVLINK_CMD_REGION_GET,
+ NETLINK_CB(cb->skb).portid,
+ cb->nlh->nlmsg_seq,
+ NLM_F_MULTI, region);
+ if (err)
+ goto out;
+ (*idx)++;
+ }
+
+out:
+ return err;
+}
+
+static int devlink_nl_cmd_region_get_devlink_dumpit(struct sk_buff *msg,
+ struct netlink_callback *cb,
+ struct devlink *devlink,
+ int *idx,
+ int start)
+{
+ struct devlink_region *region;
+ struct devlink_port *port;
+ int err = 0;
+
+ devl_lock(devlink);
+ list_for_each_entry(region, &devlink->region_list, list) {
+ if (*idx < start) {
+ (*idx)++;
+ continue;
+ }
+ err = devlink_nl_region_fill(msg, devlink,
+ DEVLINK_CMD_REGION_GET,
+ NETLINK_CB(cb->skb).portid,
+ cb->nlh->nlmsg_seq,
+ NLM_F_MULTI, region);
+ if (err)
+ goto out;
+ (*idx)++;
+ }
+
+ list_for_each_entry(port, &devlink->port_list, list) {
+ err = devlink_nl_cmd_region_get_port_dumpit(msg, cb, port, idx,
+ start);
+ if (err)
+ goto out;
+ }
+
+out:
+ devl_unlock(devlink);
+ return err;
+}
+
+static int devlink_nl_cmd_region_get_dumpit(struct sk_buff *msg,
+ struct netlink_callback *cb)
+{
+ struct devlink *devlink;
+ int start = cb->args[0];
+ unsigned long index;
+ int idx = 0;
+ int err = 0;
+
+ devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) {
+ err = devlink_nl_cmd_region_get_devlink_dumpit(msg, cb, devlink,
+ &idx, start);
+ devlink_put(devlink);
+ if (err)
+ goto out;
+ }
+out:
+ cb->args[0] = idx;
+ return msg->len;
+}
+
+static int devlink_nl_cmd_region_del(struct sk_buff *skb,
+ struct genl_info *info)
+{
+ struct devlink *devlink = info->user_ptr[0];
+ struct devlink_snapshot *snapshot;
+ struct devlink_port *port = NULL;
+ struct devlink_region *region;
+ const char *region_name;
+ unsigned int index;
+ u32 snapshot_id;
+
+ if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_REGION_NAME) ||
+ GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_REGION_SNAPSHOT_ID))
+ return -EINVAL;
+
+ region_name = nla_data(info->attrs[DEVLINK_ATTR_REGION_NAME]);
+ snapshot_id = nla_get_u32(info->attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]);
+
+ if (info->attrs[DEVLINK_ATTR_PORT_INDEX]) {
+ index = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]);
+
+ port = devlink_port_get_by_index(devlink, index);
+ if (!port)
+ return -ENODEV;
+ }
+
+ if (port)
+ region = devlink_port_region_get_by_name(port, region_name);
+ else
+ region = devlink_region_get_by_name(devlink, region_name);
+
+ if (!region)
+ return -EINVAL;
+
+ mutex_lock(&region->snapshot_lock);
+ snapshot = devlink_region_snapshot_get_by_id(region, snapshot_id);
+ if (!snapshot) {
+ mutex_unlock(&region->snapshot_lock);
+ return -EINVAL;
+ }
+
+ devlink_region_snapshot_del(region, snapshot);
+ mutex_unlock(&region->snapshot_lock);
+ return 0;
+}
+
+static int
+devlink_nl_cmd_region_new(struct sk_buff *skb, struct genl_info *info)
+{
+ struct devlink *devlink = info->user_ptr[0];
+ struct devlink_snapshot *snapshot;
+ struct devlink_port *port = NULL;
+ struct nlattr *snapshot_id_attr;
+ struct devlink_region *region;
+ const char *region_name;
+ unsigned int index;
+ u32 snapshot_id;
+ u8 *data;
+ int err;
+
+ if (GENL_REQ_ATTR_CHECK(info, DEVLINK_ATTR_REGION_NAME)) {
+ NL_SET_ERR_MSG_MOD(info->extack, "No region name provided");
+ return -EINVAL;
+ }
+
+ region_name = nla_data(info->attrs[DEVLINK_ATTR_REGION_NAME]);
+
+ if (info->attrs[DEVLINK_ATTR_PORT_INDEX]) {
+ index = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]);
+
+ port = devlink_port_get_by_index(devlink, index);
+ if (!port)
+ return -ENODEV;
+ }
+
+ if (port)
+ region = devlink_port_region_get_by_name(port, region_name);
+ else
+ region = devlink_region_get_by_name(devlink, region_name);
+
+ if (!region) {
+ NL_SET_ERR_MSG_MOD(info->extack, "The requested region does not exist");
+ return -EINVAL;
+ }
+
+ if (!region->ops->snapshot) {
+ NL_SET_ERR_MSG_MOD(info->extack, "The requested region does not support taking an immediate snapshot");
+ return -EOPNOTSUPP;
+ }
+
+ mutex_lock(&region->snapshot_lock);
+
+ if (region->cur_snapshots == region->max_snapshots) {
+ NL_SET_ERR_MSG_MOD(info->extack, "The region has reached the maximum number of stored snapshots");
+ err = -ENOSPC;
+ goto unlock;
+ }
+
+ snapshot_id_attr = info->attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID];
+ if (snapshot_id_attr) {
+ snapshot_id = nla_get_u32(snapshot_id_attr);
+
+ if (devlink_region_snapshot_get_by_id(region, snapshot_id)) {
+ NL_SET_ERR_MSG_MOD(info->extack, "The requested snapshot id is already in use");
+ err = -EEXIST;
+ goto unlock;
+ }
+
+ err = __devlink_snapshot_id_insert(devlink, snapshot_id);
+ if (err)
+ goto unlock;
+ } else {
+ err = __devlink_region_snapshot_id_get(devlink, &snapshot_id);
+ if (err) {
+ NL_SET_ERR_MSG_MOD(info->extack, "Failed to allocate a new snapshot id");
+ goto unlock;
+ }
+ }
+
+ if (port)
+ err = region->port_ops->snapshot(port, region->port_ops,
+ info->extack, &data);
+ else
+ err = region->ops->snapshot(devlink, region->ops,
+ info->extack, &data);
+ if (err)
+ goto err_snapshot_capture;
+
+ err = __devlink_region_snapshot_create(region, data, snapshot_id);
+ if (err)
+ goto err_snapshot_create;
+
+ if (!snapshot_id_attr) {
+ struct sk_buff *msg;
+
+ snapshot = devlink_region_snapshot_get_by_id(region,
+ snapshot_id);
+ if (WARN_ON(!snapshot)) {
+ err = -EINVAL;
+ goto unlock;
+ }
+
+ msg = devlink_nl_region_notify_build(region, snapshot,
+ DEVLINK_CMD_REGION_NEW,
+ info->snd_portid,
+ info->snd_seq);
+ err = PTR_ERR_OR_ZERO(msg);
+ if (err)
+ goto err_notify;
+
+ err = genlmsg_reply(msg, info);
+ if (err)
+ goto err_notify;
+ }
+
+ mutex_unlock(&region->snapshot_lock);
+ return 0;
+
+err_snapshot_create:
+ region->ops->destructor(data);
+err_snapshot_capture:
+ __devlink_snapshot_id_decrement(devlink, snapshot_id);
+ mutex_unlock(&region->snapshot_lock);
+ return err;
+
+err_notify:
+ devlink_region_snapshot_del(region, snapshot);
+unlock:
+ mutex_unlock(&region->snapshot_lock);
+ return err;
+}
+
+static int devlink_nl_cmd_region_read_chunk_fill(struct sk_buff *msg,
+ struct devlink *devlink,
+ u8 *chunk, u32 chunk_size,
+ u64
addr) +{ + struct nlattr *chunk_attr; + int err; + + chunk_attr = nla_nest_start_noflag(msg, DEVLINK_ATTR_REGION_CHUNK); + if (!chunk_attr) + return -EINVAL; + + err = nla_put(msg, DEVLINK_ATTR_REGION_CHUNK_DATA, chunk_size, chunk); + if (err) + goto nla_put_failure; + + err = nla_put_u64_64bit(msg, DEVLINK_ATTR_REGION_CHUNK_ADDR, addr, + DEVLINK_ATTR_PAD); + if (err) + goto nla_put_failure; + + nla_nest_end(msg, chunk_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(msg, chunk_attr); + return err; +} + +#define DEVLINK_REGION_READ_CHUNK_SIZE 256 + +static int devlink_nl_region_read_snapshot_fill(struct sk_buff *skb, + struct devlink *devlink, + struct devlink_region *region, + struct nlattr **attrs, + u64 start_offset, + u64 end_offset, + u64 *new_offset) +{ + struct devlink_snapshot *snapshot; + u64 curr_offset = start_offset; + u32 snapshot_id; + int err = 0; + + *new_offset = start_offset; + + snapshot_id = nla_get_u32(attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]); + snapshot = devlink_region_snapshot_get_by_id(region, snapshot_id); + if (!snapshot) + return -EINVAL; + + while (curr_offset < end_offset) { + u32 data_size; + u8 *data; + + if (end_offset - curr_offset < DEVLINK_REGION_READ_CHUNK_SIZE) + data_size = end_offset - curr_offset; + else + data_size = DEVLINK_REGION_READ_CHUNK_SIZE; + + data = &snapshot->data[curr_offset]; + err = devlink_nl_cmd_region_read_chunk_fill(skb, devlink, + data, data_size, + curr_offset); + if (err) + break; + + curr_offset += data_size; + } + *new_offset = curr_offset; + + return err; +} + +static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + u64 ret_offset, start_offset, end_offset = U64_MAX; + struct nlattr **attrs = info->attrs; + struct devlink_port *port = NULL; + struct devlink_region *region; + struct nlattr *chunks_attr; + const char *region_name; + struct devlink *devlink; + unsigned int index; + void *hdr; + int err; + + start_offset = *((u64 *)&cb->args[0]); + + devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + if (IS_ERR(devlink)) + return PTR_ERR(devlink); + + devl_lock(devlink); + + if (!attrs[DEVLINK_ATTR_REGION_NAME] || + !attrs[DEVLINK_ATTR_REGION_SNAPSHOT_ID]) { + err = -EINVAL; + goto out_unlock; + } + + if (info->attrs[DEVLINK_ATTR_PORT_INDEX]) { + index = nla_get_u32(info->attrs[DEVLINK_ATTR_PORT_INDEX]); + + port = devlink_port_get_by_index(devlink, index); + if (!port) { + err = -ENODEV; + goto out_unlock; + } + } + + region_name = nla_data(attrs[DEVLINK_ATTR_REGION_NAME]); + + if (port) + region = devlink_port_region_get_by_name(port, region_name); + else + region = devlink_region_get_by_name(devlink, region_name); + + if (!region) { + err = -EINVAL; + goto out_unlock; + } + + if (attrs[DEVLINK_ATTR_REGION_CHUNK_ADDR] && + attrs[DEVLINK_ATTR_REGION_CHUNK_LEN]) { + if (!start_offset) + start_offset = + nla_get_u64(attrs[DEVLINK_ATTR_REGION_CHUNK_ADDR]); + + end_offset = nla_get_u64(attrs[DEVLINK_ATTR_REGION_CHUNK_ADDR]); + end_offset += nla_get_u64(attrs[DEVLINK_ATTR_REGION_CHUNK_LEN]); + } + + if (end_offset > region->size) + end_offset = region->size; + + /* return 0 if there is no further data to read */ + if (start_offset == end_offset) { + err = 0; + goto out_unlock; + } + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &devlink_nl_family, NLM_F_ACK | NLM_F_MULTI, + DEVLINK_CMD_REGION_READ); + if (!hdr) { + err = -EMSGSIZE; + goto out_unlock; + } + + err = 
devlink_nl_put_handle(skb, devlink); + if (err) + goto nla_put_failure; + + if (region->port) { + err = nla_put_u32(skb, DEVLINK_ATTR_PORT_INDEX, + region->port->index); + if (err) + goto nla_put_failure; + } + + err = nla_put_string(skb, DEVLINK_ATTR_REGION_NAME, region_name); + if (err) + goto nla_put_failure; + + chunks_attr = nla_nest_start_noflag(skb, DEVLINK_ATTR_REGION_CHUNKS); + if (!chunks_attr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = devlink_nl_region_read_snapshot_fill(skb, devlink, + region, attrs, + start_offset, + end_offset, &ret_offset); + + if (err && err != -EMSGSIZE) + goto nla_put_failure; + + /* Check if there was any progress done to prevent infinite loop */ + if (ret_offset == start_offset) { + err = -EINVAL; + goto nla_put_failure; + } + + *((u64 *)&cb->args[0]) = ret_offset; + + nla_nest_end(skb, chunks_attr); + genlmsg_end(skb, hdr); + devl_unlock(devlink); + devlink_put(devlink); + return skb->len; + +nla_put_failure: + genlmsg_cancel(skb, hdr); +out_unlock: + devl_unlock(devlink); + devlink_put(devlink); + return err; +} + +int devlink_info_driver_name_put(struct devlink_info_req *req, const char *name) +{ + if (!req->msg) + return 0; + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_DRIVER_NAME, name); +} +EXPORT_SYMBOL_GPL(devlink_info_driver_name_put); + +int devlink_info_serial_number_put(struct devlink_info_req *req, const char *sn) +{ + if (!req->msg) + return 0; + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_SERIAL_NUMBER, sn); +} +EXPORT_SYMBOL_GPL(devlink_info_serial_number_put); + +int devlink_info_board_serial_number_put(struct devlink_info_req *req, + const char *bsn) +{ + if (!req->msg) + return 0; + return nla_put_string(req->msg, DEVLINK_ATTR_INFO_BOARD_SERIAL_NUMBER, + bsn); +} +EXPORT_SYMBOL_GPL(devlink_info_board_serial_number_put); + +static int devlink_info_version_put(struct devlink_info_req *req, int attr, + const char *version_name, + const char *version_value, + enum devlink_info_version_type version_type) +{ + struct nlattr *nest; + int err; + + if (req->version_cb) + req->version_cb(version_name, version_type, + req->version_cb_priv); + + if (!req->msg) + return 0; + + nest = nla_nest_start_noflag(req->msg, attr); + if (!nest) + return -EMSGSIZE; + + err = nla_put_string(req->msg, DEVLINK_ATTR_INFO_VERSION_NAME, + version_name); + if (err) + goto nla_put_failure; + + err = nla_put_string(req->msg, DEVLINK_ATTR_INFO_VERSION_VALUE, + version_value); + if (err) + goto nla_put_failure; + + nla_nest_end(req->msg, nest); + + return 0; + +nla_put_failure: + nla_nest_cancel(req->msg, nest); + return err; +} + +int devlink_info_version_fixed_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_FIXED, + version_name, version_value, + DEVLINK_INFO_VERSION_TYPE_NONE); +} +EXPORT_SYMBOL_GPL(devlink_info_version_fixed_put); + +int devlink_info_version_stored_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_STORED, + version_name, version_value, + DEVLINK_INFO_VERSION_TYPE_NONE); +} +EXPORT_SYMBOL_GPL(devlink_info_version_stored_put); + +int devlink_info_version_stored_put_ext(struct devlink_info_req *req, + const char *version_name, + const char *version_value, + enum devlink_info_version_type version_type) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_STORED, + version_name, version_value, 
+ version_type); +} +EXPORT_SYMBOL_GPL(devlink_info_version_stored_put_ext); + +int devlink_info_version_running_put(struct devlink_info_req *req, + const char *version_name, + const char *version_value) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_RUNNING, + version_name, version_value, + DEVLINK_INFO_VERSION_TYPE_NONE); +} +EXPORT_SYMBOL_GPL(devlink_info_version_running_put); + +int devlink_info_version_running_put_ext(struct devlink_info_req *req, + const char *version_name, + const char *version_value, + enum devlink_info_version_type version_type) +{ + return devlink_info_version_put(req, DEVLINK_ATTR_INFO_VERSION_RUNNING, + version_name, version_value, + version_type); +} +EXPORT_SYMBOL_GPL(devlink_info_version_running_put_ext); + +static int +devlink_nl_info_fill(struct sk_buff *msg, struct devlink *devlink, + enum devlink_command cmd, u32 portid, + u32 seq, int flags, struct netlink_ext_ack *extack) +{ + struct devlink_info_req req = {}; + void *hdr; + int err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + err = -EMSGSIZE; + if (devlink_nl_put_handle(msg, devlink)) + goto err_cancel_msg; + + req.msg = msg; + err = devlink->ops->info_get(devlink, &req, extack); + if (err) + goto err_cancel_msg; + + genlmsg_end(msg, hdr); + return 0; + +err_cancel_msg: + genlmsg_cancel(msg, hdr); + return err; +} + +static int devlink_nl_cmd_info_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + if (!devlink->ops->info_get) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_info_fill(msg, devlink, DEVLINK_CMD_INFO_GET, + info->snd_portid, info->snd_seq, 0, + info->extack); + if (err) { + nlmsg_free(msg); + return err; + } + + return genlmsg_reply(msg, info); +} + +static int devlink_nl_cmd_info_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err = 0; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + if (idx < start || !devlink->ops->info_get) + goto inc; + + devl_lock(devlink); + err = devlink_nl_info_fill(msg, devlink, DEVLINK_CMD_INFO_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->extack); + devl_unlock(devlink); + if (err == -EOPNOTSUPP) + err = 0; + else if (err) { + devlink_put(devlink); + break; + } +inc: + idx++; + devlink_put(devlink); + } + + if (err != -EMSGSIZE) + return err; + + cb->args[0] = idx; + return msg->len; +} + +struct devlink_fmsg_item { + struct list_head list; + int attrtype; + u8 nla_type; + u16 len; + int value[]; +}; + +struct devlink_fmsg { + struct list_head item_list; + bool putting_binary; /* This flag forces enclosing of binary data + * in an array brackets. 
It forces using + * of designated API: + * devlink_fmsg_binary_pair_nest_start() + * devlink_fmsg_binary_pair_nest_end() + */ +}; + +static struct devlink_fmsg *devlink_fmsg_alloc(void) +{ + struct devlink_fmsg *fmsg; + + fmsg = kzalloc(sizeof(*fmsg), GFP_KERNEL); + if (!fmsg) + return NULL; + + INIT_LIST_HEAD(&fmsg->item_list); + + return fmsg; +} + +static void devlink_fmsg_free(struct devlink_fmsg *fmsg) +{ + struct devlink_fmsg_item *item, *tmp; + + list_for_each_entry_safe(item, tmp, &fmsg->item_list, list) { + list_del(&item->list); + kfree(item); + } + kfree(fmsg); +} + +static int devlink_fmsg_nest_common(struct devlink_fmsg *fmsg, + int attrtype) +{ + struct devlink_fmsg_item *item; + + item = kzalloc(sizeof(*item), GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->attrtype = attrtype; + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +int devlink_fmsg_obj_nest_start(struct devlink_fmsg *fmsg) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_OBJ_NEST_START); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_obj_nest_start); + +static int devlink_fmsg_nest_end(struct devlink_fmsg *fmsg) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_NEST_END); +} + +int devlink_fmsg_obj_nest_end(struct devlink_fmsg *fmsg) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_nest_end(fmsg); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_obj_nest_end); + +#define DEVLINK_FMSG_MAX_SIZE (GENLMSG_DEFAULT_SIZE - GENL_HDRLEN - NLA_HDRLEN) + +static int devlink_fmsg_put_name(struct devlink_fmsg *fmsg, const char *name) +{ + struct devlink_fmsg_item *item; + + if (fmsg->putting_binary) + return -EINVAL; + + if (strlen(name) + 1 > DEVLINK_FMSG_MAX_SIZE) + return -EMSGSIZE; + + item = kzalloc(sizeof(*item) + strlen(name) + 1, GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->nla_type = NLA_NUL_STRING; + item->len = strlen(name) + 1; + item->attrtype = DEVLINK_ATTR_FMSG_OBJ_NAME; + memcpy(&item->value, name, item->len); + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +int devlink_fmsg_pair_nest_start(struct devlink_fmsg *fmsg, const char *name) +{ + int err; + + if (fmsg->putting_binary) + return -EINVAL; + + err = devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_PAIR_NEST_START); + if (err) + return err; + + err = devlink_fmsg_put_name(fmsg, name); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_pair_nest_start); + +int devlink_fmsg_pair_nest_end(struct devlink_fmsg *fmsg) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_nest_end(fmsg); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_pair_nest_end); + +int devlink_fmsg_arr_pair_nest_start(struct devlink_fmsg *fmsg, + const char *name) +{ + int err; + + if (fmsg->putting_binary) + return -EINVAL; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_nest_common(fmsg, DEVLINK_ATTR_FMSG_ARR_NEST_START); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_arr_pair_nest_start); + +int devlink_fmsg_arr_pair_nest_end(struct devlink_fmsg *fmsg) +{ + int err; + + if (fmsg->putting_binary) + return -EINVAL; + + err = devlink_fmsg_nest_end(fmsg); + if (err) + return err; + + err = devlink_fmsg_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_arr_pair_nest_end); + +int devlink_fmsg_binary_pair_nest_start(struct devlink_fmsg *fmsg, + const char *name) +{ + int err; + + err 
= devlink_fmsg_arr_pair_nest_start(fmsg, name); + if (err) + return err; + + fmsg->putting_binary = true; + return err; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_pair_nest_start); + +int devlink_fmsg_binary_pair_nest_end(struct devlink_fmsg *fmsg) +{ + if (!fmsg->putting_binary) + return -EINVAL; + + fmsg->putting_binary = false; + return devlink_fmsg_arr_pair_nest_end(fmsg); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_pair_nest_end); + +static int devlink_fmsg_put_value(struct devlink_fmsg *fmsg, + const void *value, u16 value_len, + u8 value_nla_type) +{ + struct devlink_fmsg_item *item; + + if (value_len > DEVLINK_FMSG_MAX_SIZE) + return -EMSGSIZE; + + item = kzalloc(sizeof(*item) + value_len, GFP_KERNEL); + if (!item) + return -ENOMEM; + + item->nla_type = value_nla_type; + item->len = value_len; + item->attrtype = DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA; + memcpy(&item->value, value, item->len); + list_add_tail(&item->list, &fmsg->item_list); + + return 0; +} + +static int devlink_fmsg_bool_put(struct devlink_fmsg *fmsg, bool value) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_FLAG); +} + +static int devlink_fmsg_u8_put(struct devlink_fmsg *fmsg, u8 value) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U8); +} + +int devlink_fmsg_u32_put(struct devlink_fmsg *fmsg, u32 value) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U32); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u32_put); + +static int devlink_fmsg_u64_put(struct devlink_fmsg *fmsg, u64 value) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, &value, sizeof(value), NLA_U64); +} + +int devlink_fmsg_string_put(struct devlink_fmsg *fmsg, const char *value) +{ + if (fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, value, strlen(value) + 1, + NLA_NUL_STRING); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_string_put); + +int devlink_fmsg_binary_put(struct devlink_fmsg *fmsg, const void *value, + u16 value_len) +{ + if (!fmsg->putting_binary) + return -EINVAL; + + return devlink_fmsg_put_value(fmsg, value, value_len, NLA_BINARY); +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_put); + +int devlink_fmsg_bool_pair_put(struct devlink_fmsg *fmsg, const char *name, + bool value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_bool_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_bool_pair_put); + +int devlink_fmsg_u8_pair_put(struct devlink_fmsg *fmsg, const char *name, + u8 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u8_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u8_pair_put); + +int devlink_fmsg_u32_pair_put(struct devlink_fmsg *fmsg, const char *name, + u32 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u32_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u32_pair_put); + +int devlink_fmsg_u64_pair_put(struct devlink_fmsg *fmsg, const char *name, + 
u64 value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_u64_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_u64_pair_put); + +int devlink_fmsg_string_pair_put(struct devlink_fmsg *fmsg, const char *name, + const char *value) +{ + int err; + + err = devlink_fmsg_pair_nest_start(fmsg, name); + if (err) + return err; + + err = devlink_fmsg_string_put(fmsg, value); + if (err) + return err; + + err = devlink_fmsg_pair_nest_end(fmsg); + if (err) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_string_pair_put); + +int devlink_fmsg_binary_pair_put(struct devlink_fmsg *fmsg, const char *name, + const void *value, u32 value_len) +{ + u32 data_size; + int end_err; + u32 offset; + int err; + + err = devlink_fmsg_binary_pair_nest_start(fmsg, name); + if (err) + return err; + + for (offset = 0; offset < value_len; offset += data_size) { + data_size = value_len - offset; + if (data_size > DEVLINK_FMSG_MAX_SIZE) + data_size = DEVLINK_FMSG_MAX_SIZE; + err = devlink_fmsg_binary_put(fmsg, value + offset, data_size); + if (err) + break; + /* Exit from loop with a break (instead of + * return) to make sure putting_binary is turned off in + * devlink_fmsg_binary_pair_nest_end + */ + } + + end_err = devlink_fmsg_binary_pair_nest_end(fmsg); + if (end_err) + err = end_err; + + return err; +} +EXPORT_SYMBOL_GPL(devlink_fmsg_binary_pair_put); + +static int +devlink_fmsg_item_fill_type(struct devlink_fmsg_item *msg, struct sk_buff *skb) +{ + switch (msg->nla_type) { + case NLA_FLAG: + case NLA_U8: + case NLA_U32: + case NLA_U64: + case NLA_NUL_STRING: + case NLA_BINARY: + return nla_put_u8(skb, DEVLINK_ATTR_FMSG_OBJ_VALUE_TYPE, + msg->nla_type); + default: + return -EINVAL; + } +} + +static int +devlink_fmsg_item_fill_data(struct devlink_fmsg_item *msg, struct sk_buff *skb) +{ + int attrtype = DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA; + u8 tmp; + + switch (msg->nla_type) { + case NLA_FLAG: + /* Always provide flag data, regardless of its value */ + tmp = *(bool *) msg->value; + + return nla_put_u8(skb, attrtype, tmp); + case NLA_U8: + return nla_put_u8(skb, attrtype, *(u8 *) msg->value); + case NLA_U32: + return nla_put_u32(skb, attrtype, *(u32 *) msg->value); + case NLA_U64: + return nla_put_u64_64bit(skb, attrtype, *(u64 *) msg->value, + DEVLINK_ATTR_PAD); + case NLA_NUL_STRING: + return nla_put_string(skb, attrtype, (char *) &msg->value); + case NLA_BINARY: + return nla_put(skb, attrtype, msg->len, (void *) &msg->value); + default: + return -EINVAL; + } +} + +static int +devlink_fmsg_prepare_skb(struct devlink_fmsg *fmsg, struct sk_buff *skb, + int *start) +{ + struct devlink_fmsg_item *item; + struct nlattr *fmsg_nlattr; + int i = 0; + int err; + + fmsg_nlattr = nla_nest_start_noflag(skb, DEVLINK_ATTR_FMSG); + if (!fmsg_nlattr) + return -EMSGSIZE; + + list_for_each_entry(item, &fmsg->item_list, list) { + if (i < *start) { + i++; + continue; + } + + switch (item->attrtype) { + case DEVLINK_ATTR_FMSG_OBJ_NEST_START: + case DEVLINK_ATTR_FMSG_PAIR_NEST_START: + case DEVLINK_ATTR_FMSG_ARR_NEST_START: + case DEVLINK_ATTR_FMSG_NEST_END: + err = nla_put_flag(skb, item->attrtype); + break; + case DEVLINK_ATTR_FMSG_OBJ_VALUE_DATA: + err = devlink_fmsg_item_fill_type(item, skb); + if (err) + break; + err = devlink_fmsg_item_fill_data(item, skb); + break; + case DEVLINK_ATTR_FMSG_OBJ_NAME: + err = nla_put_string(skb, item->attrtype, + 
(char *) &item->value); + break; + default: + err = -EINVAL; + break; + } + if (!err) + *start = ++i; + else + break; + } + + nla_nest_end(skb, fmsg_nlattr); + return err; +} + +static int devlink_fmsg_snd(struct devlink_fmsg *fmsg, + struct genl_info *info, + enum devlink_command cmd, int flags) +{ + struct nlmsghdr *nlh; + struct sk_buff *skb; + bool last = false; + int index = 0; + void *hdr; + int err; + + while (!last) { + int tmp_index = index; + + skb = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &devlink_nl_family, flags | NLM_F_MULTI, cmd); + if (!hdr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = devlink_fmsg_prepare_skb(fmsg, skb, &index); + if (!err) + last = true; + else if (err != -EMSGSIZE || tmp_index == index) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + err = genlmsg_reply(skb, info); + if (err) + return err; + } + + skb = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return -ENOMEM; + nlh = nlmsg_put(skb, info->snd_portid, info->snd_seq, + NLMSG_DONE, 0, flags | NLM_F_MULTI); + if (!nlh) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + return genlmsg_reply(skb, info); + +nla_put_failure: + nlmsg_free(skb); + return err; +} + +static int devlink_fmsg_dumpit(struct devlink_fmsg *fmsg, struct sk_buff *skb, + struct netlink_callback *cb, + enum devlink_command cmd) +{ + int index = cb->args[0]; + int tmp_index = index; + void *hdr; + int err; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &devlink_nl_family, NLM_F_ACK | NLM_F_MULTI, cmd); + if (!hdr) { + err = -EMSGSIZE; + goto nla_put_failure; + } + + err = devlink_fmsg_prepare_skb(fmsg, skb, &index); + if ((err && err != -EMSGSIZE) || tmp_index == index) + goto nla_put_failure; + + cb->args[0] = index; + genlmsg_end(skb, hdr); + return skb->len; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return err; +} + +struct devlink_health_reporter { + struct list_head list; + void *priv; + const struct devlink_health_reporter_ops *ops; + struct devlink *devlink; + struct devlink_port *devlink_port; + struct devlink_fmsg *dump_fmsg; + struct mutex dump_lock; /* lock parallel read/write from dump buffers */ + u64 graceful_period; + bool auto_recover; + bool auto_dump; + u8 health_state; + u64 dump_ts; + u64 dump_real_ts; + u64 error_count; + u64 recovery_count; + u64 last_recovery_ts; + refcount_t refcount; +}; + +void * +devlink_health_reporter_priv(struct devlink_health_reporter *reporter) +{ + return reporter->priv; +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_priv); + +static struct devlink_health_reporter * +__devlink_health_reporter_find_by_name(struct list_head *reporter_list, + struct mutex *list_lock, + const char *reporter_name) +{ + struct devlink_health_reporter *reporter; + + lockdep_assert_held(list_lock); + list_for_each_entry(reporter, reporter_list, list) + if (!strcmp(reporter->ops->name, reporter_name)) + return reporter; + return NULL; +} + +static struct devlink_health_reporter * +devlink_health_reporter_find_by_name(struct devlink *devlink, + const char *reporter_name) +{ + return __devlink_health_reporter_find_by_name(&devlink->reporter_list, + &devlink->reporters_lock, + reporter_name); +} + +static struct devlink_health_reporter * +devlink_port_health_reporter_find_by_name(struct devlink_port *devlink_port, + const char *reporter_name) +{ + return __devlink_health_reporter_find_by_name(&devlink_port->reporter_list, + 
&devlink_port->reporters_lock, + reporter_name); +} + +static struct devlink_health_reporter * +__devlink_health_reporter_create(struct devlink *devlink, + const struct devlink_health_reporter_ops *ops, + u64 graceful_period, void *priv) +{ + struct devlink_health_reporter *reporter; + + if (WARN_ON(graceful_period && !ops->recover)) + return ERR_PTR(-EINVAL); + + reporter = kzalloc(sizeof(*reporter), GFP_KERNEL); + if (!reporter) + return ERR_PTR(-ENOMEM); + + reporter->priv = priv; + reporter->ops = ops; + reporter->devlink = devlink; + reporter->graceful_period = graceful_period; + reporter->auto_recover = !!ops->recover; + reporter->auto_dump = !!ops->dump; + mutex_init(&reporter->dump_lock); + refcount_set(&reporter->refcount, 1); + return reporter; +} + +/** + * devlink_port_health_reporter_create - create devlink health reporter for + * specified port instance + * + * @port: devlink_port which should contain the new reporter + * @ops: ops + * @graceful_period: to avoid recovery loops, in msecs + * @priv: priv + */ +struct devlink_health_reporter * +devlink_port_health_reporter_create(struct devlink_port *port, + const struct devlink_health_reporter_ops *ops, + u64 graceful_period, void *priv) +{ + struct devlink_health_reporter *reporter; + + mutex_lock(&port->reporters_lock); + if (__devlink_health_reporter_find_by_name(&port->reporter_list, + &port->reporters_lock, ops->name)) { + reporter = ERR_PTR(-EEXIST); + goto unlock; + } + + reporter = __devlink_health_reporter_create(port->devlink, ops, + graceful_period, priv); + if (IS_ERR(reporter)) + goto unlock; + + reporter->devlink_port = port; + list_add_tail(&reporter->list, &port->reporter_list); +unlock: + mutex_unlock(&port->reporters_lock); + return reporter; +} +EXPORT_SYMBOL_GPL(devlink_port_health_reporter_create); + +/** + * devlink_health_reporter_create - create devlink health reporter + * + * @devlink: devlink + * @ops: ops + * @graceful_period: to avoid recovery loops, in msecs + * @priv: priv + */ +struct devlink_health_reporter * +devlink_health_reporter_create(struct devlink *devlink, + const struct devlink_health_reporter_ops *ops, + u64 graceful_period, void *priv) +{ + struct devlink_health_reporter *reporter; + + mutex_lock(&devlink->reporters_lock); + if (devlink_health_reporter_find_by_name(devlink, ops->name)) { + reporter = ERR_PTR(-EEXIST); + goto unlock; + } + + reporter = __devlink_health_reporter_create(devlink, ops, + graceful_period, priv); + if (IS_ERR(reporter)) + goto unlock; + + list_add_tail(&reporter->list, &devlink->reporter_list); +unlock: + mutex_unlock(&devlink->reporters_lock); + return reporter; +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_create); + +static void +devlink_health_reporter_free(struct devlink_health_reporter *reporter) +{ + mutex_destroy(&reporter->dump_lock); + if (reporter->dump_fmsg) + devlink_fmsg_free(reporter->dump_fmsg); + kfree(reporter); +} + +static void +devlink_health_reporter_put(struct devlink_health_reporter *reporter) +{ + if (refcount_dec_and_test(&reporter->refcount)) + devlink_health_reporter_free(reporter); +} + +static void +__devlink_health_reporter_destroy(struct devlink_health_reporter *reporter) +{ + list_del(&reporter->list); + devlink_health_reporter_put(reporter); +} + +/** + * devlink_health_reporter_destroy - destroy devlink health reporter + * + * @reporter: devlink health reporter to destroy + */ +void +devlink_health_reporter_destroy(struct devlink_health_reporter *reporter) +{ + struct mutex *lock = &reporter->devlink->reporters_lock; + + 
mutex_lock(lock); + __devlink_health_reporter_destroy(reporter); + mutex_unlock(lock); +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_destroy); + +/** + * devlink_port_health_reporter_destroy - destroy devlink port health reporter + * + * @reporter: devlink health reporter to destroy + */ +void +devlink_port_health_reporter_destroy(struct devlink_health_reporter *reporter) +{ + struct mutex *lock = &reporter->devlink_port->reporters_lock; + + mutex_lock(lock); + __devlink_health_reporter_destroy(reporter); + mutex_unlock(lock); +} +EXPORT_SYMBOL_GPL(devlink_port_health_reporter_destroy); + +static int +devlink_nl_health_reporter_fill(struct sk_buff *msg, + struct devlink_health_reporter *reporter, + enum devlink_command cmd, u32 portid, + u32 seq, int flags) +{ + struct devlink *devlink = reporter->devlink; + struct nlattr *reporter_attr; + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto genlmsg_cancel; + + if (reporter->devlink_port) { + if (nla_put_u32(msg, DEVLINK_ATTR_PORT_INDEX, reporter->devlink_port->index)) + goto genlmsg_cancel; + } + reporter_attr = nla_nest_start_noflag(msg, + DEVLINK_ATTR_HEALTH_REPORTER); + if (!reporter_attr) + goto genlmsg_cancel; + if (nla_put_string(msg, DEVLINK_ATTR_HEALTH_REPORTER_NAME, + reporter->ops->name)) + goto reporter_nest_cancel; + if (nla_put_u8(msg, DEVLINK_ATTR_HEALTH_REPORTER_STATE, + reporter->health_state)) + goto reporter_nest_cancel; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_ERR_COUNT, + reporter->error_count, DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_RECOVER_COUNT, + reporter->recovery_count, DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->ops->recover && + nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD, + reporter->graceful_period, + DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->ops->recover && + nla_put_u8(msg, DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER, + reporter->auto_recover)) + goto reporter_nest_cancel; + if (reporter->dump_fmsg && + nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS, + jiffies_to_msecs(reporter->dump_ts), + DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->dump_fmsg && + nla_put_u64_64bit(msg, DEVLINK_ATTR_HEALTH_REPORTER_DUMP_TS_NS, + reporter->dump_real_ts, DEVLINK_ATTR_PAD)) + goto reporter_nest_cancel; + if (reporter->ops->dump && + nla_put_u8(msg, DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP, + reporter->auto_dump)) + goto reporter_nest_cancel; + + nla_nest_end(msg, reporter_attr); + genlmsg_end(msg, hdr); + return 0; + +reporter_nest_cancel: + nla_nest_end(msg, reporter_attr); +genlmsg_cancel: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static void devlink_recover_notify(struct devlink_health_reporter *reporter, + enum devlink_command cmd) +{ + struct devlink *devlink = reporter->devlink; + struct sk_buff *msg; + int err; + + WARN_ON(cmd != DEVLINK_CMD_HEALTH_REPORTER_RECOVER); + WARN_ON(!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_health_reporter_fill(msg, reporter, cmd, 0, 0, 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), msg, + 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +void +devlink_health_reporter_recovery_done(struct devlink_health_reporter 
*reporter) +{ + reporter->recovery_count++; + reporter->last_recovery_ts = jiffies; +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_recovery_done); + +static int +devlink_health_reporter_recover(struct devlink_health_reporter *reporter, + void *priv_ctx, struct netlink_ext_ack *extack) +{ + int err; + + if (reporter->health_state == DEVLINK_HEALTH_REPORTER_STATE_HEALTHY) + return 0; + + if (!reporter->ops->recover) + return -EOPNOTSUPP; + + err = reporter->ops->recover(reporter, priv_ctx, extack); + if (err) + return err; + + devlink_health_reporter_recovery_done(reporter); + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_HEALTHY; + devlink_recover_notify(reporter, DEVLINK_CMD_HEALTH_REPORTER_RECOVER); + + return 0; +} + +static void +devlink_health_dump_clear(struct devlink_health_reporter *reporter) +{ + if (!reporter->dump_fmsg) + return; + devlink_fmsg_free(reporter->dump_fmsg); + reporter->dump_fmsg = NULL; +} + +static int devlink_health_do_dump(struct devlink_health_reporter *reporter, + void *priv_ctx, + struct netlink_ext_ack *extack) +{ + int err; + + if (!reporter->ops->dump) + return 0; + + if (reporter->dump_fmsg) + return 0; + + reporter->dump_fmsg = devlink_fmsg_alloc(); + if (!reporter->dump_fmsg) { + err = -ENOMEM; + return err; + } + + err = devlink_fmsg_obj_nest_start(reporter->dump_fmsg); + if (err) + goto dump_err; + + err = reporter->ops->dump(reporter, reporter->dump_fmsg, + priv_ctx, extack); + if (err) + goto dump_err; + + err = devlink_fmsg_obj_nest_end(reporter->dump_fmsg); + if (err) + goto dump_err; + + reporter->dump_ts = jiffies; + reporter->dump_real_ts = ktime_get_real_ns(); + + return 0; + +dump_err: + devlink_health_dump_clear(reporter); + return err; +} + +int devlink_health_report(struct devlink_health_reporter *reporter, + const char *msg, void *priv_ctx) +{ + enum devlink_health_reporter_state prev_health_state; + struct devlink *devlink = reporter->devlink; + unsigned long recover_ts_threshold; + int ret; + + /* write a log message of the current error */ + WARN_ON(!msg); + trace_devlink_health_report(devlink, reporter->ops->name, msg); + reporter->error_count++; + prev_health_state = reporter->health_state; + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_ERROR; + devlink_recover_notify(reporter, DEVLINK_CMD_HEALTH_REPORTER_RECOVER); + + /* abort if the previous error wasn't recovered */ + recover_ts_threshold = reporter->last_recovery_ts + + msecs_to_jiffies(reporter->graceful_period); + if (reporter->auto_recover && + (prev_health_state != DEVLINK_HEALTH_REPORTER_STATE_HEALTHY || + (reporter->last_recovery_ts && reporter->recovery_count && + time_is_after_jiffies(recover_ts_threshold)))) { + trace_devlink_health_recover_aborted(devlink, + reporter->ops->name, + reporter->health_state, + jiffies - + reporter->last_recovery_ts); + return -ECANCELED; + } + + reporter->health_state = DEVLINK_HEALTH_REPORTER_STATE_ERROR; + + if (reporter->auto_dump) { + mutex_lock(&reporter->dump_lock); + /* store current dump of current error, for later analysis */ + devlink_health_do_dump(reporter, priv_ctx, NULL); + mutex_unlock(&reporter->dump_lock); + } + + if (!reporter->auto_recover) + return 0; + + devl_lock(devlink); + ret = devlink_health_reporter_recover(reporter, priv_ctx, NULL); + devl_unlock(devlink); + + return ret; +} +EXPORT_SYMBOL_GPL(devlink_health_report); + +static struct devlink_health_reporter * +devlink_health_reporter_get_from_attrs(struct devlink *devlink, + struct nlattr **attrs) +{ + struct devlink_health_reporter *reporter; + 
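/*
 * Illustrative sketch, not part of the upstream diff: feeding an error into
 * devlink_health_report() above.  The core handles the error counter, the
 * ERROR state, the notification, and the optional auto dump and auto
 * recover; the driver supplies a message plus an opaque context that is
 * passed back to its .dump/.recover ops.  foo_* names are hypothetical
 * (see the sketch further up); callers must be in sleepable context since
 * the core may take mutexes and the instance lock.
 */
struct foo_crash_ctx {
	u32 syndrome;
};

static void foo_handle_fw_crash(struct foo_priv *foo, u32 syndrome)
{
	struct foo_crash_ctx ctx = { .syndrome = syndrome };

	foo->fw_crash_count++;
	/* May return -ECANCELED when the previous error was not recovered
	 * within the graceful period (see the threshold check above).
	 */
	devlink_health_report(foo->fw_reporter, "FW crash detected", &ctx);
}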
struct devlink_port *devlink_port; + char *reporter_name; + + if (!attrs[DEVLINK_ATTR_HEALTH_REPORTER_NAME]) + return NULL; + + reporter_name = nla_data(attrs[DEVLINK_ATTR_HEALTH_REPORTER_NAME]); + devlink_port = devlink_port_get_from_attrs(devlink, attrs); + if (IS_ERR(devlink_port)) { + mutex_lock(&devlink->reporters_lock); + reporter = devlink_health_reporter_find_by_name(devlink, reporter_name); + if (reporter) + refcount_inc(&reporter->refcount); + mutex_unlock(&devlink->reporters_lock); + } else { + mutex_lock(&devlink_port->reporters_lock); + reporter = devlink_port_health_reporter_find_by_name(devlink_port, reporter_name); + if (reporter) + refcount_inc(&reporter->refcount); + mutex_unlock(&devlink_port->reporters_lock); + } + + return reporter; +} + +static struct devlink_health_reporter * +devlink_health_reporter_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + return devlink_health_reporter_get_from_attrs(devlink, info->attrs); +} + +static struct devlink_health_reporter * +devlink_health_reporter_get_from_cb(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct devlink_health_reporter *reporter; + struct nlattr **attrs = info->attrs; + struct devlink *devlink; + + devlink = devlink_get_from_attrs(sock_net(cb->skb->sk), attrs); + if (IS_ERR(devlink)) + return NULL; + + reporter = devlink_health_reporter_get_from_attrs(devlink, attrs); + devlink_put(devlink); + return reporter; +} + +void +devlink_health_reporter_state_update(struct devlink_health_reporter *reporter, + enum devlink_health_reporter_state state) +{ + if (WARN_ON(state != DEVLINK_HEALTH_REPORTER_STATE_HEALTHY && + state != DEVLINK_HEALTH_REPORTER_STATE_ERROR)) + return; + + if (reporter->health_state == state) + return; + + reporter->health_state = state; + trace_devlink_health_reporter_state_update(reporter->devlink, + reporter->ops->name, state); + devlink_recover_notify(reporter, DEVLINK_CMD_HEALTH_REPORTER_RECOVER); +} +EXPORT_SYMBOL_GPL(devlink_health_reporter_state_update); + +static int devlink_nl_cmd_health_reporter_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + struct sk_buff *msg; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + err = -ENOMEM; + goto out; + } + + err = devlink_nl_health_reporter_fill(msg, reporter, + DEVLINK_CMD_HEALTH_REPORTER_GET, + info->snd_portid, info->snd_seq, + 0); + if (err) { + nlmsg_free(msg); + goto out; + } + + err = genlmsg_reply(msg, info); +out: + devlink_health_reporter_put(reporter); + return err; +} + +static int +devlink_nl_cmd_health_reporter_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_health_reporter *reporter; + struct devlink_port *port; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + mutex_lock(&devlink->reporters_lock); + list_for_each_entry(reporter, &devlink->reporter_list, + list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_health_reporter_fill( + msg, reporter, DEVLINK_CMD_HEALTH_REPORTER_GET, + NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + mutex_unlock(&devlink->reporters_lock); + devlink_put(devlink); + goto out; + } + idx++; + } 
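/*
 * Illustrative sketch, not part of the upstream diff: a driver whose
 * recovery completes asynchronously (e.g. after a firmware reset event)
 * can move the reporter back to HEALTHY itself with
 * devlink_health_reporter_state_update() above, instead of returning 0
 * synchronously from .recover.  foo_* names are hypothetical (see the
 * sketches further up).
 */
static void foo_fw_reset_done(struct foo_priv *foo)
{
	devlink_health_reporter_recovery_done(foo->fw_reporter);
	devlink_health_reporter_state_update(foo->fw_reporter,
					     DEVLINK_HEALTH_REPORTER_STATE_HEALTHY);
}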
+ mutex_unlock(&devlink->reporters_lock); + devlink_put(devlink); + } + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(port, &devlink->port_list, list) { + mutex_lock(&port->reporters_lock); + list_for_each_entry(reporter, &port->reporter_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_health_reporter_fill( + msg, reporter, + DEVLINK_CMD_HEALTH_REPORTER_GET, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI); + if (err) { + mutex_unlock(&port->reporters_lock); + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + mutex_unlock(&port->reporters_lock); + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int +devlink_nl_cmd_health_reporter_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->recover && + (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD] || + info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER])) { + err = -EOPNOTSUPP; + goto out; + } + if (!reporter->ops->dump && + info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP]) { + err = -EOPNOTSUPP; + goto out; + } + + if (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD]) + reporter->graceful_period = + nla_get_u64(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD]); + + if (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER]) + reporter->auto_recover = + nla_get_u8(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER]); + + if (info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP]) + reporter->auto_dump = + nla_get_u8(info->attrs[DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP]); + + devlink_health_reporter_put(reporter); + return 0; +out: + devlink_health_reporter_put(reporter); + return err; +} + +static int devlink_nl_cmd_health_reporter_recover_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + err = devlink_health_reporter_recover(reporter, NULL, info->extack); + + devlink_health_reporter_put(reporter); + return err; +} + +static int devlink_nl_cmd_health_reporter_diagnose_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + struct devlink_fmsg *fmsg; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->diagnose) { + devlink_health_reporter_put(reporter); + return -EOPNOTSUPP; + } + + fmsg = devlink_fmsg_alloc(); + if (!fmsg) { + devlink_health_reporter_put(reporter); + return -ENOMEM; + } + + err = devlink_fmsg_obj_nest_start(fmsg); + if (err) + goto out; + + err = reporter->ops->diagnose(reporter, fmsg, info->extack); + if (err) + goto out; + + err = devlink_fmsg_obj_nest_end(fmsg); + if (err) + goto out; + + err = devlink_fmsg_snd(fmsg, info, + DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE, 0); + +out: + devlink_fmsg_free(fmsg); + devlink_health_reporter_put(reporter); + return err; +} + +static int +devlink_nl_cmd_health_reporter_dump_get_dumpit(struct sk_buff *skb, + struct netlink_callback *cb) 
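/*
 * Illustrative sketch, not part of the upstream diff: a .diagnose op as
 * invoked by devlink_nl_cmd_health_reporter_diagnose_doit() above.  The
 * core opens and closes the surrounding fmsg object nest, so the op only
 * emits name/value pairs.  Assumes the devlink_fmsg_*_pair_put() helpers
 * defined earlier in this file; foo_* names are hypothetical (see the
 * sketches further up).
 */
static int foo_fw_diagnose(struct devlink_health_reporter *reporter,
			   struct devlink_fmsg *fmsg,
			   struct netlink_ext_ack *extack)
{
	struct foo_priv *foo = devlink_health_reporter_priv(reporter);
	int err;

	err = devlink_fmsg_string_pair_put(fmsg, "fw.state", "running");
	if (err)
		return err;

	return devlink_fmsg_u32_pair_put(fmsg, "fw.crash_count",
					 foo->fw_crash_count);
}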
+{ + struct devlink_health_reporter *reporter; + u64 start = cb->args[0]; + int err; + + reporter = devlink_health_reporter_get_from_cb(cb); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->dump) { + err = -EOPNOTSUPP; + goto out; + } + mutex_lock(&reporter->dump_lock); + if (!start) { + err = devlink_health_do_dump(reporter, NULL, cb->extack); + if (err) + goto unlock; + cb->args[1] = reporter->dump_ts; + } + if (!reporter->dump_fmsg || cb->args[1] != reporter->dump_ts) { + NL_SET_ERR_MSG_MOD(cb->extack, "Dump trampled, please retry"); + err = -EAGAIN; + goto unlock; + } + + err = devlink_fmsg_dumpit(reporter->dump_fmsg, skb, cb, + DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET); +unlock: + mutex_unlock(&reporter->dump_lock); +out: + devlink_health_reporter_put(reporter); + return err; +} + +static int +devlink_nl_cmd_health_reporter_dump_clear_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->dump) { + devlink_health_reporter_put(reporter); + return -EOPNOTSUPP; + } + + mutex_lock(&reporter->dump_lock); + devlink_health_dump_clear(reporter); + mutex_unlock(&reporter->dump_lock); + devlink_health_reporter_put(reporter); + return 0; +} + +static int devlink_nl_cmd_health_reporter_test_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink *devlink = info->user_ptr[0]; + struct devlink_health_reporter *reporter; + int err; + + reporter = devlink_health_reporter_get_from_info(devlink, info); + if (!reporter) + return -EINVAL; + + if (!reporter->ops->test) { + devlink_health_reporter_put(reporter); + return -EOPNOTSUPP; + } + + err = reporter->ops->test(reporter, info->extack); + + devlink_health_reporter_put(reporter); + return err; +} + +struct devlink_stats { + u64_stats_t rx_bytes; + u64_stats_t rx_packets; + struct u64_stats_sync syncp; +}; + +/** + * struct devlink_trap_policer_item - Packet trap policer attributes. + * @policer: Immutable packet trap policer attributes. + * @rate: Rate in packets / sec. + * @burst: Burst size in packets. + * @list: trap_policer_list member. + * + * Describes packet trap policer attributes. Created by devlink during trap + * policer registration. + */ +struct devlink_trap_policer_item { + const struct devlink_trap_policer *policer; + u64 rate; + u64 burst; + struct list_head list; +}; + +/** + * struct devlink_trap_group_item - Packet trap group attributes. + * @group: Immutable packet trap group attributes. + * @policer_item: Associated policer item. Can be NULL. + * @list: trap_group_list member. + * @stats: Trap group statistics. + * + * Describes packet trap group attributes. Created by devlink during trap + * group registration. + */ +struct devlink_trap_group_item { + const struct devlink_trap_group *group; + struct devlink_trap_policer_item *policer_item; + struct list_head list; + struct devlink_stats __percpu *stats; +}; + +/** + * struct devlink_trap_item - Packet trap attributes. + * @trap: Immutable packet trap attributes. + * @group_item: Associated group item. + * @list: trap_list member. + * @action: Trap action. + * @stats: Trap statistics. + * @priv: Driver private information. + * + * Describes both mutable and immutable packet trap attributes. Created by + * devlink during trap registration and used for all trap related operations. 
+ */ +struct devlink_trap_item { + const struct devlink_trap *trap; + struct devlink_trap_group_item *group_item; + struct list_head list; + enum devlink_trap_action action; + struct devlink_stats __percpu *stats; + void *priv; +}; + +static struct devlink_trap_policer_item * +devlink_trap_policer_item_lookup(struct devlink *devlink, u32 id) +{ + struct devlink_trap_policer_item *policer_item; + + list_for_each_entry(policer_item, &devlink->trap_policer_list, list) { + if (policer_item->policer->id == id) + return policer_item; + } + + return NULL; +} + +static struct devlink_trap_item * +devlink_trap_item_lookup(struct devlink *devlink, const char *name) +{ + struct devlink_trap_item *trap_item; + + list_for_each_entry(trap_item, &devlink->trap_list, list) { + if (!strcmp(trap_item->trap->name, name)) + return trap_item; + } + + return NULL; +} + +static struct devlink_trap_item * +devlink_trap_item_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + struct nlattr *attr; + + if (!info->attrs[DEVLINK_ATTR_TRAP_NAME]) + return NULL; + attr = info->attrs[DEVLINK_ATTR_TRAP_NAME]; + + return devlink_trap_item_lookup(devlink, nla_data(attr)); +} + +static int +devlink_trap_action_get_from_info(struct genl_info *info, + enum devlink_trap_action *p_trap_action) +{ + u8 val; + + val = nla_get_u8(info->attrs[DEVLINK_ATTR_TRAP_ACTION]); + switch (val) { + case DEVLINK_TRAP_ACTION_DROP: + case DEVLINK_TRAP_ACTION_TRAP: + case DEVLINK_TRAP_ACTION_MIRROR: + *p_trap_action = val; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int devlink_trap_metadata_put(struct sk_buff *msg, + const struct devlink_trap *trap) +{ + struct nlattr *attr; + + attr = nla_nest_start(msg, DEVLINK_ATTR_TRAP_METADATA); + if (!attr) + return -EMSGSIZE; + + if ((trap->metadata_cap & DEVLINK_TRAP_METADATA_TYPE_F_IN_PORT) && + nla_put_flag(msg, DEVLINK_ATTR_TRAP_METADATA_TYPE_IN_PORT)) + goto nla_put_failure; + if ((trap->metadata_cap & DEVLINK_TRAP_METADATA_TYPE_F_FA_COOKIE) && + nla_put_flag(msg, DEVLINK_ATTR_TRAP_METADATA_TYPE_FA_COOKIE)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static void devlink_trap_stats_read(struct devlink_stats __percpu *trap_stats, + struct devlink_stats *stats) +{ + int i; + + memset(stats, 0, sizeof(*stats)); + for_each_possible_cpu(i) { + struct devlink_stats *cpu_stats; + u64 rx_packets, rx_bytes; + unsigned int start; + + cpu_stats = per_cpu_ptr(trap_stats, i); + do { + start = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + rx_packets = u64_stats_read(&cpu_stats->rx_packets); + rx_bytes = u64_stats_read(&cpu_stats->rx_bytes); + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, start)); + + u64_stats_add(&stats->rx_packets, rx_packets); + u64_stats_add(&stats->rx_bytes, rx_bytes); + } +} + +static int +devlink_trap_group_stats_put(struct sk_buff *msg, + struct devlink_stats __percpu *trap_stats) +{ + struct devlink_stats stats; + struct nlattr *attr; + + devlink_trap_stats_read(trap_stats, &stats); + + attr = nla_nest_start(msg, DEVLINK_ATTR_STATS); + if (!attr) + return -EMSGSIZE; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_PACKETS, + u64_stats_read(&stats.rx_packets), + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_BYTES, + u64_stats_read(&stats.rx_bytes), + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + 
return -EMSGSIZE; +} + +static int devlink_trap_stats_put(struct sk_buff *msg, struct devlink *devlink, + const struct devlink_trap_item *trap_item) +{ + struct devlink_stats stats; + struct nlattr *attr; + u64 drops = 0; + int err; + + if (devlink->ops->trap_drop_counter_get) { + err = devlink->ops->trap_drop_counter_get(devlink, + trap_item->trap, + &drops); + if (err) + return err; + } + + devlink_trap_stats_read(trap_item->stats, &stats); + + attr = nla_nest_start(msg, DEVLINK_ATTR_STATS); + if (!attr) + return -EMSGSIZE; + + if (devlink->ops->trap_drop_counter_get && + nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_DROPPED, drops, + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_PACKETS, + u64_stats_read(&stats.rx_packets), + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_BYTES, + u64_stats_read(&stats.rx_bytes), + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int devlink_nl_trap_fill(struct sk_buff *msg, struct devlink *devlink, + const struct devlink_trap_item *trap_item, + enum devlink_command cmd, u32 portid, u32 seq, + int flags) +{ + struct devlink_trap_group_item *group_item = trap_item->group_item; + void *hdr; + int err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (nla_put_string(msg, DEVLINK_ATTR_TRAP_GROUP_NAME, + group_item->group->name)) + goto nla_put_failure; + + if (nla_put_string(msg, DEVLINK_ATTR_TRAP_NAME, trap_item->trap->name)) + goto nla_put_failure; + + if (nla_put_u8(msg, DEVLINK_ATTR_TRAP_TYPE, trap_item->trap->type)) + goto nla_put_failure; + + if (trap_item->trap->generic && + nla_put_flag(msg, DEVLINK_ATTR_TRAP_GENERIC)) + goto nla_put_failure; + + if (nla_put_u8(msg, DEVLINK_ATTR_TRAP_ACTION, trap_item->action)) + goto nla_put_failure; + + err = devlink_trap_metadata_put(msg, trap_item->trap); + if (err) + goto nla_put_failure; + + err = devlink_trap_stats_put(msg, devlink, trap_item); + if (err) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_trap_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + struct devlink_trap_item *trap_item; + struct sk_buff *msg; + int err; + + if (list_empty(&devlink->trap_list)) + return -EOPNOTSUPP; + + trap_item = devlink_trap_item_get_from_info(devlink, info); + if (!trap_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap"); + return -ENOENT; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_trap_fill(msg, devlink, trap_item, + DEVLINK_CMD_TRAP_NEW, info->snd_portid, + info->snd_seq, 0); + if (err) + goto err_trap_fill; + + return genlmsg_reply(msg, info); + +err_trap_fill: + nlmsg_free(msg); + return err; +} + +static int devlink_nl_cmd_trap_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct devlink_trap_item *trap_item; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(trap_item, 
&devlink->trap_list, list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_trap_fill(msg, devlink, trap_item, + DEVLINK_CMD_TRAP_NEW, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int __devlink_trap_action_set(struct devlink *devlink, + struct devlink_trap_item *trap_item, + enum devlink_trap_action trap_action, + struct netlink_ext_ack *extack) +{ + int err; + + if (trap_item->action != trap_action && + trap_item->trap->type != DEVLINK_TRAP_TYPE_DROP) { + NL_SET_ERR_MSG_MOD(extack, "Cannot change action of non-drop traps. Skipping"); + return 0; + } + + err = devlink->ops->trap_action_set(devlink, trap_item->trap, + trap_action, extack); + if (err) + return err; + + trap_item->action = trap_action; + + return 0; +} + +static int devlink_trap_action_set(struct devlink *devlink, + struct devlink_trap_item *trap_item, + struct genl_info *info) +{ + enum devlink_trap_action trap_action; + int err; + + if (!info->attrs[DEVLINK_ATTR_TRAP_ACTION]) + return 0; + + err = devlink_trap_action_get_from_info(info, &trap_action); + if (err) { + NL_SET_ERR_MSG_MOD(info->extack, "Invalid trap action"); + return -EINVAL; + } + + return __devlink_trap_action_set(devlink, trap_item, trap_action, + info->extack); +} + +static int devlink_nl_cmd_trap_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + struct devlink_trap_item *trap_item; + + if (list_empty(&devlink->trap_list)) + return -EOPNOTSUPP; + + trap_item = devlink_trap_item_get_from_info(devlink, info); + if (!trap_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap"); + return -ENOENT; + } + + return devlink_trap_action_set(devlink, trap_item, info); +} + +static struct devlink_trap_group_item * +devlink_trap_group_item_lookup(struct devlink *devlink, const char *name) +{ + struct devlink_trap_group_item *group_item; + + list_for_each_entry(group_item, &devlink->trap_group_list, list) { + if (!strcmp(group_item->group->name, name)) + return group_item; + } + + return NULL; +} + +static struct devlink_trap_group_item * +devlink_trap_group_item_lookup_by_id(struct devlink *devlink, u16 id) +{ + struct devlink_trap_group_item *group_item; + + list_for_each_entry(group_item, &devlink->trap_group_list, list) { + if (group_item->group->id == id) + return group_item; + } + + return NULL; +} + +static struct devlink_trap_group_item * +devlink_trap_group_item_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + char *name; + + if (!info->attrs[DEVLINK_ATTR_TRAP_GROUP_NAME]) + return NULL; + name = nla_data(info->attrs[DEVLINK_ATTR_TRAP_GROUP_NAME]); + + return devlink_trap_group_item_lookup(devlink, name); +} + +static int +devlink_nl_trap_group_fill(struct sk_buff *msg, struct devlink *devlink, + const struct devlink_trap_group_item *group_item, + enum devlink_command cmd, u32 portid, u32 seq, + int flags) +{ + void *hdr; + int err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (nla_put_string(msg, DEVLINK_ATTR_TRAP_GROUP_NAME, + group_item->group->name)) + goto nla_put_failure; + + if (group_item->group->generic && + nla_put_flag(msg, 
DEVLINK_ATTR_TRAP_GENERIC)) + goto nla_put_failure; + + if (group_item->policer_item && + nla_put_u32(msg, DEVLINK_ATTR_TRAP_POLICER_ID, + group_item->policer_item->policer->id)) + goto nla_put_failure; + + err = devlink_trap_group_stats_put(msg, group_item->stats); + if (err) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_trap_group_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + struct devlink_trap_group_item *group_item; + struct sk_buff *msg; + int err; + + if (list_empty(&devlink->trap_group_list)) + return -EOPNOTSUPP; + + group_item = devlink_trap_group_item_get_from_info(devlink, info); + if (!group_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap group"); + return -ENOENT; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_trap_group_fill(msg, devlink, group_item, + DEVLINK_CMD_TRAP_GROUP_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) + goto err_trap_group_fill; + + return genlmsg_reply(msg, info); + +err_trap_group_fill: + nlmsg_free(msg); + return err; +} + +static int devlink_nl_cmd_trap_group_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + enum devlink_command cmd = DEVLINK_CMD_TRAP_GROUP_NEW; + struct devlink_trap_group_item *group_item; + u32 portid = NETLINK_CB(cb->skb).portid; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(group_item, &devlink->trap_group_list, + list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_trap_group_fill(msg, devlink, + group_item, cmd, + portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int +__devlink_trap_group_action_set(struct devlink *devlink, + struct devlink_trap_group_item *group_item, + enum devlink_trap_action trap_action, + struct netlink_ext_ack *extack) +{ + const char *group_name = group_item->group->name; + struct devlink_trap_item *trap_item; + int err; + + if (devlink->ops->trap_group_action_set) { + err = devlink->ops->trap_group_action_set(devlink, group_item->group, + trap_action, extack); + if (err) + return err; + + list_for_each_entry(trap_item, &devlink->trap_list, list) { + if (strcmp(trap_item->group_item->group->name, group_name)) + continue; + if (trap_item->action != trap_action && + trap_item->trap->type != DEVLINK_TRAP_TYPE_DROP) + continue; + trap_item->action = trap_action; + } + + return 0; + } + + list_for_each_entry(trap_item, &devlink->trap_list, list) { + if (strcmp(trap_item->group_item->group->name, group_name)) + continue; + err = __devlink_trap_action_set(devlink, trap_item, + trap_action, extack); + if (err) + return err; + } + + return 0; +} + +static int +devlink_trap_group_action_set(struct devlink *devlink, + struct devlink_trap_group_item *group_item, + struct genl_info *info, bool *p_modified) +{ + enum devlink_trap_action trap_action; + int err; + + if (!info->attrs[DEVLINK_ATTR_TRAP_ACTION]) + return 0; + + err = devlink_trap_action_get_from_info(info, &trap_action); + if (err) { + 
NL_SET_ERR_MSG_MOD(info->extack, "Invalid trap action"); + return -EINVAL; + } + + err = __devlink_trap_group_action_set(devlink, group_item, trap_action, + info->extack); + if (err) + return err; + + *p_modified = true; + + return 0; +} + +static int devlink_trap_group_set(struct devlink *devlink, + struct devlink_trap_group_item *group_item, + struct genl_info *info) +{ + struct devlink_trap_policer_item *policer_item; + struct netlink_ext_ack *extack = info->extack; + const struct devlink_trap_policer *policer; + struct nlattr **attrs = info->attrs; + int err; + + if (!attrs[DEVLINK_ATTR_TRAP_POLICER_ID]) + return 0; + + if (!devlink->ops->trap_group_set) + return -EOPNOTSUPP; + + policer_item = group_item->policer_item; + if (attrs[DEVLINK_ATTR_TRAP_POLICER_ID]) { + u32 policer_id; + + policer_id = nla_get_u32(attrs[DEVLINK_ATTR_TRAP_POLICER_ID]); + policer_item = devlink_trap_policer_item_lookup(devlink, + policer_id); + if (policer_id && !policer_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap policer"); + return -ENOENT; + } + } + policer = policer_item ? policer_item->policer : NULL; + + err = devlink->ops->trap_group_set(devlink, group_item->group, policer, + extack); + if (err) + return err; + + group_item->policer_item = policer_item; + + return 0; +} + +static int devlink_nl_cmd_trap_group_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + struct devlink_trap_group_item *group_item; + bool modified = false; + int err; + + if (list_empty(&devlink->trap_group_list)) + return -EOPNOTSUPP; + + group_item = devlink_trap_group_item_get_from_info(devlink, info); + if (!group_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap group"); + return -ENOENT; + } + + err = devlink_trap_group_action_set(devlink, group_item, info, + &modified); + if (err) + return err; + + err = devlink_trap_group_set(devlink, group_item, info); + if (err) + goto err_trap_group_set; + + return 0; + +err_trap_group_set: + if (modified) + NL_SET_ERR_MSG_MOD(extack, "Trap group set failed, but some changes were committed already"); + return err; +} + +static struct devlink_trap_policer_item * +devlink_trap_policer_item_get_from_info(struct devlink *devlink, + struct genl_info *info) +{ + u32 id; + + if (!info->attrs[DEVLINK_ATTR_TRAP_POLICER_ID]) + return NULL; + id = nla_get_u32(info->attrs[DEVLINK_ATTR_TRAP_POLICER_ID]); + + return devlink_trap_policer_item_lookup(devlink, id); +} + +static int +devlink_trap_policer_stats_put(struct sk_buff *msg, struct devlink *devlink, + const struct devlink_trap_policer *policer) +{ + struct nlattr *attr; + u64 drops; + int err; + + if (!devlink->ops->trap_policer_counter_get) + return 0; + + err = devlink->ops->trap_policer_counter_get(devlink, policer, &drops); + if (err) + return err; + + attr = nla_nest_start(msg, DEVLINK_ATTR_STATS); + if (!attr) + return -EMSGSIZE; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_STATS_RX_DROPPED, drops, + DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + return 0; + +nla_put_failure: + nla_nest_cancel(msg, attr); + return -EMSGSIZE; +} + +static int +devlink_nl_trap_policer_fill(struct sk_buff *msg, struct devlink *devlink, + const struct devlink_trap_policer_item *policer_item, + enum devlink_command cmd, u32 portid, u32 seq, + int flags) +{ + void *hdr; + int err; + + hdr = genlmsg_put(msg, portid, seq, &devlink_nl_family, flags, cmd); + if (!hdr) + return 
-EMSGSIZE; + + if (devlink_nl_put_handle(msg, devlink)) + goto nla_put_failure; + + if (nla_put_u32(msg, DEVLINK_ATTR_TRAP_POLICER_ID, + policer_item->policer->id)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_TRAP_POLICER_RATE, + policer_item->rate, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, DEVLINK_ATTR_TRAP_POLICER_BURST, + policer_item->burst, DEVLINK_ATTR_PAD)) + goto nla_put_failure; + + err = devlink_trap_policer_stats_put(msg, devlink, + policer_item->policer); + if (err) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int devlink_nl_cmd_trap_policer_get_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_trap_policer_item *policer_item; + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + struct sk_buff *msg; + int err; + + if (list_empty(&devlink->trap_policer_list)) + return -EOPNOTSUPP; + + policer_item = devlink_trap_policer_item_get_from_info(devlink, info); + if (!policer_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap policer"); + return -ENOENT; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + err = devlink_nl_trap_policer_fill(msg, devlink, policer_item, + DEVLINK_CMD_TRAP_POLICER_NEW, + info->snd_portid, info->snd_seq, 0); + if (err) + goto err_trap_policer_fill; + + return genlmsg_reply(msg, info); + +err_trap_policer_fill: + nlmsg_free(msg); + return err; +} + +static int devlink_nl_cmd_trap_policer_get_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + enum devlink_command cmd = DEVLINK_CMD_TRAP_POLICER_NEW; + struct devlink_trap_policer_item *policer_item; + u32 portid = NETLINK_CB(cb->skb).portid; + struct devlink *devlink; + int start = cb->args[0]; + unsigned long index; + int idx = 0; + int err; + + devlinks_xa_for_each_registered_get(sock_net(msg->sk), index, devlink) { + devl_lock(devlink); + list_for_each_entry(policer_item, &devlink->trap_policer_list, + list) { + if (idx < start) { + idx++; + continue; + } + err = devlink_nl_trap_policer_fill(msg, devlink, + policer_item, cmd, + portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI); + if (err) { + devl_unlock(devlink); + devlink_put(devlink); + goto out; + } + idx++; + } + devl_unlock(devlink); + devlink_put(devlink); + } +out: + cb->args[0] = idx; + return msg->len; +} + +static int +devlink_trap_policer_set(struct devlink *devlink, + struct devlink_trap_policer_item *policer_item, + struct genl_info *info) +{ + struct netlink_ext_ack *extack = info->extack; + struct nlattr **attrs = info->attrs; + u64 rate, burst; + int err; + + rate = policer_item->rate; + burst = policer_item->burst; + + if (attrs[DEVLINK_ATTR_TRAP_POLICER_RATE]) + rate = nla_get_u64(attrs[DEVLINK_ATTR_TRAP_POLICER_RATE]); + + if (attrs[DEVLINK_ATTR_TRAP_POLICER_BURST]) + burst = nla_get_u64(attrs[DEVLINK_ATTR_TRAP_POLICER_BURST]); + + if (rate < policer_item->policer->min_rate) { + NL_SET_ERR_MSG_MOD(extack, "Policer rate lower than limit"); + return -EINVAL; + } + + if (rate > policer_item->policer->max_rate) { + NL_SET_ERR_MSG_MOD(extack, "Policer rate higher than limit"); + return -EINVAL; + } + + if (burst < policer_item->policer->min_burst) { + NL_SET_ERR_MSG_MOD(extack, "Policer burst size lower than limit"); + return -EINVAL; + } + + if (burst > policer_item->policer->max_burst) { + NL_SET_ERR_MSG_MOD(extack, "Policer burst size higher than 
limit"); + return -EINVAL; + } + + err = devlink->ops->trap_policer_set(devlink, policer_item->policer, + rate, burst, info->extack); + if (err) + return err; + + policer_item->rate = rate; + policer_item->burst = burst; + + return 0; +} + +static int devlink_nl_cmd_trap_policer_set_doit(struct sk_buff *skb, + struct genl_info *info) +{ + struct devlink_trap_policer_item *policer_item; + struct netlink_ext_ack *extack = info->extack; + struct devlink *devlink = info->user_ptr[0]; + + if (list_empty(&devlink->trap_policer_list)) + return -EOPNOTSUPP; + + if (!devlink->ops->trap_policer_set) + return -EOPNOTSUPP; + + policer_item = devlink_trap_policer_item_get_from_info(devlink, info); + if (!policer_item) { + NL_SET_ERR_MSG_MOD(extack, "Device did not register this trap policer"); + return -ENOENT; + } + + return devlink_trap_policer_set(devlink, policer_item, info); +} + +static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = { + [DEVLINK_ATTR_UNSPEC] = { .strict_start_type = + DEVLINK_ATTR_TRAP_POLICER_ID }, + [DEVLINK_ATTR_BUS_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_DEV_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_PORT_INDEX] = { .type = NLA_U32 }, + [DEVLINK_ATTR_PORT_TYPE] = NLA_POLICY_RANGE(NLA_U16, DEVLINK_PORT_TYPE_AUTO, + DEVLINK_PORT_TYPE_IB), + [DEVLINK_ATTR_PORT_SPLIT_COUNT] = { .type = NLA_U32 }, + [DEVLINK_ATTR_SB_INDEX] = { .type = NLA_U32 }, + [DEVLINK_ATTR_SB_POOL_INDEX] = { .type = NLA_U16 }, + [DEVLINK_ATTR_SB_POOL_TYPE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_SB_POOL_SIZE] = { .type = NLA_U32 }, + [DEVLINK_ATTR_SB_POOL_THRESHOLD_TYPE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_SB_THRESHOLD] = { .type = NLA_U32 }, + [DEVLINK_ATTR_SB_TC_INDEX] = { .type = NLA_U16 }, + [DEVLINK_ATTR_ESWITCH_MODE] = NLA_POLICY_RANGE(NLA_U16, DEVLINK_ESWITCH_MODE_LEGACY, + DEVLINK_ESWITCH_MODE_SWITCHDEV), + [DEVLINK_ATTR_ESWITCH_INLINE_MODE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_ESWITCH_ENCAP_MODE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_DPIPE_TABLE_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_DPIPE_TABLE_COUNTERS_ENABLED] = { .type = NLA_U8 }, + [DEVLINK_ATTR_RESOURCE_ID] = { .type = NLA_U64}, + [DEVLINK_ATTR_RESOURCE_SIZE] = { .type = NLA_U64}, + [DEVLINK_ATTR_PARAM_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_PARAM_TYPE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_PARAM_VALUE_CMODE] = { .type = NLA_U8 }, + [DEVLINK_ATTR_REGION_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_REGION_SNAPSHOT_ID] = { .type = NLA_U32 }, + [DEVLINK_ATTR_REGION_CHUNK_ADDR] = { .type = NLA_U64 }, + [DEVLINK_ATTR_REGION_CHUNK_LEN] = { .type = NLA_U64 }, + [DEVLINK_ATTR_HEALTH_REPORTER_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_HEALTH_REPORTER_GRACEFUL_PERIOD] = { .type = NLA_U64 }, + [DEVLINK_ATTR_HEALTH_REPORTER_AUTO_RECOVER] = { .type = NLA_U8 }, + [DEVLINK_ATTR_FLASH_UPDATE_FILE_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_FLASH_UPDATE_COMPONENT] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_FLASH_UPDATE_OVERWRITE_MASK] = + NLA_POLICY_BITFIELD32(DEVLINK_SUPPORTED_FLASH_OVERWRITE_SECTIONS), + [DEVLINK_ATTR_TRAP_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_TRAP_ACTION] = { .type = NLA_U8 }, + [DEVLINK_ATTR_TRAP_GROUP_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_NETNS_PID] = { .type = NLA_U32 }, + [DEVLINK_ATTR_NETNS_FD] = { .type = NLA_U32 }, + [DEVLINK_ATTR_NETNS_ID] = { .type = NLA_U32 }, + [DEVLINK_ATTR_HEALTH_REPORTER_AUTO_DUMP] = { .type = NLA_U8 }, + [DEVLINK_ATTR_TRAP_POLICER_ID] = { .type = NLA_U32 }, + [DEVLINK_ATTR_TRAP_POLICER_RATE] 
= { .type = NLA_U64 }, + [DEVLINK_ATTR_TRAP_POLICER_BURST] = { .type = NLA_U64 }, + [DEVLINK_ATTR_PORT_FUNCTION] = { .type = NLA_NESTED }, + [DEVLINK_ATTR_RELOAD_ACTION] = NLA_POLICY_RANGE(NLA_U8, DEVLINK_RELOAD_ACTION_DRIVER_REINIT, + DEVLINK_RELOAD_ACTION_MAX), + [DEVLINK_ATTR_RELOAD_LIMITS] = NLA_POLICY_BITFIELD32(DEVLINK_RELOAD_LIMITS_VALID_MASK), + [DEVLINK_ATTR_PORT_FLAVOUR] = { .type = NLA_U16 }, + [DEVLINK_ATTR_PORT_PCI_PF_NUMBER] = { .type = NLA_U16 }, + [DEVLINK_ATTR_PORT_PCI_SF_NUMBER] = { .type = NLA_U32 }, + [DEVLINK_ATTR_PORT_CONTROLLER_NUMBER] = { .type = NLA_U32 }, + [DEVLINK_ATTR_RATE_TYPE] = { .type = NLA_U16 }, + [DEVLINK_ATTR_RATE_TX_SHARE] = { .type = NLA_U64 }, + [DEVLINK_ATTR_RATE_TX_MAX] = { .type = NLA_U64 }, + [DEVLINK_ATTR_RATE_NODE_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_RATE_PARENT_NODE_NAME] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_LINECARD_INDEX] = { .type = NLA_U32 }, + [DEVLINK_ATTR_LINECARD_TYPE] = { .type = NLA_NUL_STRING }, + [DEVLINK_ATTR_SELFTESTS] = { .type = NLA_NESTED }, +}; + +static const struct genl_small_ops devlink_nl_ops[] = { + { + .cmd = DEVLINK_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_get_doit, + .dumpit = devlink_nl_cmd_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_PORT_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_get_doit, + .dumpit = devlink_nl_cmd_port_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_PORT_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_RATE_GET, + .doit = devlink_nl_cmd_rate_get_doit, + .dumpit = devlink_nl_cmd_rate_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_RATE, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_RATE_SET, + .doit = devlink_nl_cmd_rate_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_RATE, + }, + { + .cmd = DEVLINK_CMD_RATE_NEW, + .doit = devlink_nl_cmd_rate_new_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_RATE_DEL, + .doit = devlink_nl_cmd_rate_del_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_RATE_NODE, + }, + { + .cmd = DEVLINK_CMD_PORT_SPLIT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_split_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_PORT_UNSPLIT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_unsplit_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_PORT_NEW, + .doit = devlink_nl_cmd_port_new_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_PORT_DEL, + .doit = devlink_nl_cmd_port_del_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_LINECARD_GET, + .doit = devlink_nl_cmd_linecard_get_doit, + .dumpit = devlink_nl_cmd_linecard_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_LINECARD, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_LINECARD_SET, + .doit = devlink_nl_cmd_linecard_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_LINECARD, + }, + { 
+ .cmd = DEVLINK_CMD_SB_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_get_doit, + .dumpit = devlink_nl_cmd_sb_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_SB_POOL_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_pool_get_doit, + .dumpit = devlink_nl_cmd_sb_pool_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_SB_POOL_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_pool_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_SB_PORT_POOL_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_port_pool_get_doit, + .dumpit = devlink_nl_cmd_sb_port_pool_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_SB_PORT_POOL_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_port_pool_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_SB_TC_POOL_BIND_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_tc_pool_bind_get_doit, + .dumpit = devlink_nl_cmd_sb_tc_pool_bind_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_SB_TC_POOL_BIND_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_tc_pool_bind_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_SB_OCC_SNAPSHOT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_occ_snapshot_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_SB_OCC_MAX_CLEAR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_sb_occ_max_clear_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_ESWITCH_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_eswitch_get_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_ESWITCH_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_eswitch_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_DPIPE_TABLE_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_dpipe_table_get, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_DPIPE_ENTRIES_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_dpipe_entries_get, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_DPIPE_HEADERS_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_dpipe_headers_get, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_DPIPE_TABLE_COUNTERS_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_dpipe_table_counters_set, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_RESOURCE_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_resource_set, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_RESOURCE_DUMP, + 
.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_resource_dump, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_RELOAD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_reload, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_PARAM_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_param_get_doit, + .dumpit = devlink_nl_cmd_param_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_PARAM_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_param_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_PORT_PARAM_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_param_get_doit, + .dumpit = devlink_nl_cmd_port_param_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_PORT_PARAM_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_port_param_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_PORT, + }, + { + .cmd = DEVLINK_CMD_REGION_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_region_get_doit, + .dumpit = devlink_nl_cmd_region_get_dumpit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_REGION_NEW, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_region_new, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_REGION_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_region_del, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_REGION_READ, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = devlink_nl_cmd_region_read_dumpit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_INFO_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_info_get_doit, + .dumpit = devlink_nl_cmd_info_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_get_doit, + .dumpit = devlink_nl_cmd_health_reporter_get_dumpit, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_set_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_RECOVER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_recover_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DIAGNOSE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_diagnose_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DUMP_GET, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = 
devlink_nl_cmd_health_reporter_dump_get_dumpit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_DUMP_CLEAR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_dump_clear_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + }, + { + .cmd = DEVLINK_CMD_HEALTH_REPORTER_TEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_health_reporter_test_doit, + .flags = GENL_ADMIN_PERM, + .internal_flags = DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT, + }, + { + .cmd = DEVLINK_CMD_FLASH_UPDATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = devlink_nl_cmd_flash_update, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_TRAP_GET, + .doit = devlink_nl_cmd_trap_get_doit, + .dumpit = devlink_nl_cmd_trap_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_TRAP_SET, + .doit = devlink_nl_cmd_trap_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_TRAP_GROUP_GET, + .doit = devlink_nl_cmd_trap_group_get_doit, + .dumpit = devlink_nl_cmd_trap_group_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_TRAP_GROUP_SET, + .doit = devlink_nl_cmd_trap_group_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_TRAP_POLICER_GET, + .doit = devlink_nl_cmd_trap_policer_get_doit, + .dumpit = devlink_nl_cmd_trap_policer_get_dumpit, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_TRAP_POLICER_SET, + .doit = devlink_nl_cmd_trap_policer_set_doit, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = DEVLINK_CMD_SELFTESTS_GET, + .doit = devlink_nl_cmd_selftests_get_doit, + .dumpit = devlink_nl_cmd_selftests_get_dumpit + /* can be retrieved by unprivileged users */ + }, + { + .cmd = DEVLINK_CMD_SELFTESTS_RUN, + .doit = devlink_nl_cmd_selftests_run, + .flags = GENL_ADMIN_PERM, + }, +}; + +static struct genl_family devlink_nl_family __ro_after_init = { + .name = DEVLINK_GENL_NAME, + .version = DEVLINK_GENL_VERSION, + .maxattr = DEVLINK_ATTR_MAX, + .policy = devlink_nl_policy, + .netnsok = true, + .parallel_ops = true, + .pre_doit = devlink_nl_pre_doit, + .post_doit = devlink_nl_post_doit, + .module = THIS_MODULE, + .small_ops = devlink_nl_ops, + .n_small_ops = ARRAY_SIZE(devlink_nl_ops), + .resv_start_op = DEVLINK_CMD_SELFTESTS_RUN + 1, + .mcgrps = devlink_nl_mcgrps, + .n_mcgrps = ARRAY_SIZE(devlink_nl_mcgrps), +}; + +static bool devlink_reload_actions_valid(const struct devlink_ops *ops) +{ + const struct devlink_reload_combination *comb; + int i; + + if (!devlink_reload_supported(ops)) { + if (WARN_ON(ops->reload_actions)) + return false; + return true; + } + + if (WARN_ON(!ops->reload_actions || + ops->reload_actions & BIT(DEVLINK_RELOAD_ACTION_UNSPEC) || + ops->reload_actions >= BIT(__DEVLINK_RELOAD_ACTION_MAX))) + return false; + + if (WARN_ON(ops->reload_limits & BIT(DEVLINK_RELOAD_LIMIT_UNSPEC) || + ops->reload_limits >= BIT(__DEVLINK_RELOAD_LIMIT_MAX))) + return false; + + for (i = 0; i < ARRAY_SIZE(devlink_reload_invalid_combinations); i++) { + comb = &devlink_reload_invalid_combinations[i]; + if (ops->reload_actions == BIT(comb->action) && + ops->reload_limits == BIT(comb->limit)) + return false; + } + return true; +} + +/** + * devlink_set_features - Set devlink supported features + * + * @devlink: devlink + * @features: devlink support features + * + * This interface allows us to set reload ops 
separatelly from + * the devlink_alloc. + */ +void devlink_set_features(struct devlink *devlink, u64 features) +{ + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + WARN_ON(features & DEVLINK_F_RELOAD && + !devlink_reload_supported(devlink->ops)); + devlink->features = features; +} +EXPORT_SYMBOL_GPL(devlink_set_features); + +/** + * devlink_alloc_ns - Allocate new devlink instance resources + * in specific namespace + * + * @ops: ops + * @priv_size: size of user private data + * @net: net namespace + * @dev: parent device + * + * Allocate new devlink instance resources, including devlink index + * and name. + */ +struct devlink *devlink_alloc_ns(const struct devlink_ops *ops, + size_t priv_size, struct net *net, + struct device *dev) +{ + struct devlink *devlink; + static u32 last_id; + int ret; + + WARN_ON(!ops || !dev); + if (!devlink_reload_actions_valid(ops)) + return NULL; + + devlink = kzalloc(sizeof(*devlink) + priv_size, GFP_KERNEL); + if (!devlink) + return NULL; + + ret = xa_alloc_cyclic(&devlinks, &devlink->index, devlink, xa_limit_31b, + &last_id, GFP_KERNEL); + if (ret < 0) { + kfree(devlink); + return NULL; + } + + devlink->dev = dev; + devlink->ops = ops; + xa_init_flags(&devlink->snapshot_ids, XA_FLAGS_ALLOC); + write_pnet(&devlink->_net, net); + INIT_LIST_HEAD(&devlink->port_list); + INIT_LIST_HEAD(&devlink->rate_list); + INIT_LIST_HEAD(&devlink->linecard_list); + INIT_LIST_HEAD(&devlink->sb_list); + INIT_LIST_HEAD_RCU(&devlink->dpipe_table_list); + INIT_LIST_HEAD(&devlink->resource_list); + INIT_LIST_HEAD(&devlink->param_list); + INIT_LIST_HEAD(&devlink->region_list); + INIT_LIST_HEAD(&devlink->reporter_list); + INIT_LIST_HEAD(&devlink->trap_list); + INIT_LIST_HEAD(&devlink->trap_group_list); + INIT_LIST_HEAD(&devlink->trap_policer_list); + lockdep_register_key(&devlink->lock_key); + mutex_init(&devlink->lock); + lockdep_set_class(&devlink->lock, &devlink->lock_key); + mutex_init(&devlink->reporters_lock); + mutex_init(&devlink->linecards_lock); + refcount_set(&devlink->refcount, 1); + init_completion(&devlink->comp); + + return devlink; +} +EXPORT_SYMBOL_GPL(devlink_alloc_ns); + +static void +devlink_trap_policer_notify(struct devlink *devlink, + const struct devlink_trap_policer_item *policer_item, + enum devlink_command cmd); +static void +devlink_trap_group_notify(struct devlink *devlink, + const struct devlink_trap_group_item *group_item, + enum devlink_command cmd); +static void devlink_trap_notify(struct devlink *devlink, + const struct devlink_trap_item *trap_item, + enum devlink_command cmd); + +static void devlink_notify_register(struct devlink *devlink) +{ + struct devlink_trap_policer_item *policer_item; + struct devlink_trap_group_item *group_item; + struct devlink_param_item *param_item; + struct devlink_trap_item *trap_item; + struct devlink_port *devlink_port; + struct devlink_linecard *linecard; + struct devlink_rate *rate_node; + struct devlink_region *region; + + devlink_notify(devlink, DEVLINK_CMD_NEW); + list_for_each_entry(linecard, &devlink->linecard_list, list) + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + + list_for_each_entry(devlink_port, &devlink->port_list, list) + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); + + list_for_each_entry(policer_item, &devlink->trap_policer_list, list) + devlink_trap_policer_notify(devlink, policer_item, + DEVLINK_CMD_TRAP_POLICER_NEW); + + list_for_each_entry(group_item, &devlink->trap_group_list, list) + devlink_trap_group_notify(devlink, group_item, + DEVLINK_CMD_TRAP_GROUP_NEW); + + 
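/*
 * Illustrative sketch, not part of the upstream diff: probe-time allocation
 * through devlink_alloc_ns() above.  The foo_* names, the empty
 * foo_devlink_ops and the use of &init_net are hypothetical; devlink_priv()
 * is the usual accessor for the driver-private area sized by the priv_size
 * argument.
 */
static const struct devlink_ops foo_devlink_ops = {
	/* Reload callbacks would be filled in here if the driver also called
	 * devlink_set_features(devlink, DEVLINK_F_RELOAD) above.
	 */
};

static struct devlink *foo_devlink_alloc(struct device *dev)
{
	struct devlink *devlink;
	struct foo_priv *foo;

	devlink = devlink_alloc_ns(&foo_devlink_ops, sizeof(*foo),
				   &init_net, dev);
	if (!devlink)
		return NULL;

	foo = devlink_priv(devlink);
	foo->fw_reporter = NULL;
	return devlink;
}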
list_for_each_entry(trap_item, &devlink->trap_list, list) + devlink_trap_notify(devlink, trap_item, DEVLINK_CMD_TRAP_NEW); + + list_for_each_entry(rate_node, &devlink->rate_list, list) + devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_NEW); + + list_for_each_entry(region, &devlink->region_list, list) + devlink_nl_region_notify(region, NULL, DEVLINK_CMD_REGION_NEW); + + list_for_each_entry(param_item, &devlink->param_list, list) + devlink_param_notify(devlink, 0, param_item, + DEVLINK_CMD_PARAM_NEW); +} + +static void devlink_notify_unregister(struct devlink *devlink) +{ + struct devlink_trap_policer_item *policer_item; + struct devlink_trap_group_item *group_item; + struct devlink_param_item *param_item; + struct devlink_trap_item *trap_item; + struct devlink_port *devlink_port; + struct devlink_linecard *linecard; + struct devlink_rate *rate_node; + struct devlink_region *region; + + list_for_each_entry_reverse(param_item, &devlink->param_list, list) + devlink_param_notify(devlink, 0, param_item, + DEVLINK_CMD_PARAM_DEL); + + list_for_each_entry_reverse(region, &devlink->region_list, list) + devlink_nl_region_notify(region, NULL, DEVLINK_CMD_REGION_DEL); + + list_for_each_entry_reverse(rate_node, &devlink->rate_list, list) + devlink_rate_notify(rate_node, DEVLINK_CMD_RATE_DEL); + + list_for_each_entry_reverse(trap_item, &devlink->trap_list, list) + devlink_trap_notify(devlink, trap_item, DEVLINK_CMD_TRAP_DEL); + + list_for_each_entry_reverse(group_item, &devlink->trap_group_list, list) + devlink_trap_group_notify(devlink, group_item, + DEVLINK_CMD_TRAP_GROUP_DEL); + list_for_each_entry_reverse(policer_item, &devlink->trap_policer_list, + list) + devlink_trap_policer_notify(devlink, policer_item, + DEVLINK_CMD_TRAP_POLICER_DEL); + + list_for_each_entry_reverse(devlink_port, &devlink->port_list, list) + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_DEL); + list_for_each_entry_reverse(linecard, &devlink->linecard_list, list) + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_DEL); + devlink_notify(devlink, DEVLINK_CMD_DEL); +} + +/** + * devlink_register - Register devlink instance + * + * @devlink: devlink + */ +void devlink_register(struct devlink *devlink) +{ + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + /* Make sure that we are in .probe() routine */ + + xa_set_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); + devlink_notify_register(devlink); +} +EXPORT_SYMBOL_GPL(devlink_register); + +/** + * devlink_unregister - Unregister devlink instance + * + * @devlink: devlink + */ +void devlink_unregister(struct devlink *devlink) +{ + ASSERT_DEVLINK_REGISTERED(devlink); + /* Make sure that we are in .remove() routine */ + + xa_set_mark(&devlinks, devlink->index, DEVLINK_UNREGISTERING); + devlink_put(devlink); + wait_for_completion(&devlink->comp); + + devlink_notify_unregister(devlink); + xa_clear_mark(&devlinks, devlink->index, DEVLINK_REGISTERED); + xa_clear_mark(&devlinks, devlink->index, DEVLINK_UNREGISTERING); +} +EXPORT_SYMBOL_GPL(devlink_unregister); + +/** + * devlink_free - Free devlink instance resources + * + * @devlink: devlink + */ +void devlink_free(struct devlink *devlink) +{ + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + mutex_destroy(&devlink->linecards_lock); + mutex_destroy(&devlink->reporters_lock); + mutex_destroy(&devlink->lock); + lockdep_unregister_key(&devlink->lock_key); + WARN_ON(!list_empty(&devlink->trap_policer_list)); + WARN_ON(!list_empty(&devlink->trap_group_list)); + WARN_ON(!list_empty(&devlink->trap_list)); + 
WARN_ON(!list_empty(&devlink->reporter_list)); + WARN_ON(!list_empty(&devlink->region_list)); + WARN_ON(!list_empty(&devlink->param_list)); + WARN_ON(!list_empty(&devlink->resource_list)); + WARN_ON(!list_empty(&devlink->dpipe_table_list)); + WARN_ON(!list_empty(&devlink->sb_list)); + WARN_ON(!list_empty(&devlink->rate_list)); + WARN_ON(!list_empty(&devlink->linecard_list)); + WARN_ON(!list_empty(&devlink->port_list)); + + xa_destroy(&devlink->snapshot_ids); + xa_erase(&devlinks, devlink->index); + + kfree(devlink); +} +EXPORT_SYMBOL_GPL(devlink_free); + +static void devlink_port_type_warn(struct work_struct *work) +{ + struct devlink_port *port = container_of(to_delayed_work(work), + struct devlink_port, + type_warn_dw); + dev_warn(port->devlink->dev, "Type was not set for devlink port."); +} + +static bool devlink_port_type_should_warn(struct devlink_port *devlink_port) +{ + /* Ignore CPU and DSA flavours. */ + return devlink_port->attrs.flavour != DEVLINK_PORT_FLAVOUR_CPU && + devlink_port->attrs.flavour != DEVLINK_PORT_FLAVOUR_DSA && + devlink_port->attrs.flavour != DEVLINK_PORT_FLAVOUR_UNUSED; +} + +#define DEVLINK_PORT_TYPE_WARN_TIMEOUT (HZ * 3600) + +static void devlink_port_type_warn_schedule(struct devlink_port *devlink_port) +{ + if (!devlink_port_type_should_warn(devlink_port)) + return; + /* Schedule a work to WARN in case driver does not set port + * type within timeout. + */ + schedule_delayed_work(&devlink_port->type_warn_dw, + DEVLINK_PORT_TYPE_WARN_TIMEOUT); +} + +static void devlink_port_type_warn_cancel(struct devlink_port *devlink_port) +{ + if (!devlink_port_type_should_warn(devlink_port)) + return; + cancel_delayed_work_sync(&devlink_port->type_warn_dw); +} + +/** + * devlink_port_init() - Init devlink port + * + * @devlink: devlink + * @devlink_port: devlink port + * + * Initialize essencial stuff that is needed for functions + * that may be called before devlink port registration. + * Call to this function is optional and not needed + * in case the driver does not use such functions. + */ +void devlink_port_init(struct devlink *devlink, + struct devlink_port *devlink_port) +{ + if (devlink_port->initialized) + return; + devlink_port->devlink = devlink; + INIT_LIST_HEAD(&devlink_port->region_list); + devlink_port->initialized = true; +} +EXPORT_SYMBOL_GPL(devlink_port_init); + +/** + * devlink_port_fini() - Deinitialize devlink port + * + * @devlink_port: devlink port + * + * Deinitialize essencial stuff that is in use for functions + * that may be called after devlink port unregistration. + * Call to this function is optional and not needed + * in case the driver does not use such functions. + */ +void devlink_port_fini(struct devlink_port *devlink_port) +{ + WARN_ON(!list_empty(&devlink_port->region_list)); +} +EXPORT_SYMBOL_GPL(devlink_port_fini); + +/** + * devl_port_register() - Register devlink port + * + * @devlink: devlink + * @devlink_port: devlink port + * @port_index: driver-specific numerical identifier of the port + * + * Register devlink port with provided port index. User can use + * any indexing, even hw-related one. devlink_port structure + * is convenient to be embedded inside user driver private structure. + * Note that the caller should take care of zeroing the devlink_port + * structure. 
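+ *
+ * A minimal usage sketch (not taken from an in-tree driver; the
+ * 'my_port' wrapper structure and the error handling are hypothetical):
+ *
+ *    struct my_port {
+ *        struct devlink_port dl_port;
+ *    };
+ *
+ *    static int my_port_add(struct devlink *devlink, unsigned int index)
+ *    {
+ *        struct my_port *p = kzalloc(sizeof(*p), GFP_KERNEL);
+ *        int err;
+ *
+ *        if (!p)
+ *            return -ENOMEM;
+ *        devl_lock(devlink);
+ *        err = devl_port_register(devlink, &p->dl_port, index);
+ *        devl_unlock(devlink);
+ *        if (err)
+ *            kfree(p);
+ *        return err;
+ *    }
+ *
+ * kzalloc() also provides the required zeroing of the embedded
+ * devlink_port. Drivers that do not already hold the instance lock can
+ * simply call devlink_port_register() instead.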
+ */ +int devl_port_register(struct devlink *devlink, + struct devlink_port *devlink_port, + unsigned int port_index) +{ + devl_assert_locked(devlink); + + if (devlink_port_index_exists(devlink, port_index)) + return -EEXIST; + + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + devlink_port_init(devlink, devlink_port); + devlink_port->registered = true; + devlink_port->index = port_index; + spin_lock_init(&devlink_port->type_lock); + INIT_LIST_HEAD(&devlink_port->reporter_list); + mutex_init(&devlink_port->reporters_lock); + list_add_tail(&devlink_port->list, &devlink->port_list); + + INIT_DELAYED_WORK(&devlink_port->type_warn_dw, &devlink_port_type_warn); + devlink_port_type_warn_schedule(devlink_port); + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); + return 0; +} +EXPORT_SYMBOL_GPL(devl_port_register); + +/** + * devlink_port_register - Register devlink port + * + * @devlink: devlink + * @devlink_port: devlink port + * @port_index: driver-specific numerical identifier of the port + * + * Register devlink port with provided port index. User can use + * any indexing, even hw-related one. devlink_port structure + * is convenient to be embedded inside user driver private structure. + * Note that the caller should take care of zeroing the devlink_port + * structure. + * + * Context: Takes and release devlink->lock . + */ +int devlink_port_register(struct devlink *devlink, + struct devlink_port *devlink_port, + unsigned int port_index) +{ + int err; + + devl_lock(devlink); + err = devl_port_register(devlink, devlink_port, port_index); + devl_unlock(devlink); + return err; +} +EXPORT_SYMBOL_GPL(devlink_port_register); + +/** + * devl_port_unregister() - Unregister devlink port + * + * @devlink_port: devlink port + */ +void devl_port_unregister(struct devlink_port *devlink_port) +{ + lockdep_assert_held(&devlink_port->devlink->lock); + + devlink_port_type_warn_cancel(devlink_port); + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_DEL); + list_del(&devlink_port->list); + WARN_ON(!list_empty(&devlink_port->reporter_list)); + mutex_destroy(&devlink_port->reporters_lock); + devlink_port->registered = false; +} +EXPORT_SYMBOL_GPL(devl_port_unregister); + +/** + * devlink_port_unregister - Unregister devlink port + * + * @devlink_port: devlink port + * + * Context: Takes and release devlink->lock . + */ +void devlink_port_unregister(struct devlink_port *devlink_port) +{ + struct devlink *devlink = devlink_port->devlink; + + devl_lock(devlink); + devl_port_unregister(devlink_port); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_port_unregister); + +static void __devlink_port_type_set(struct devlink_port *devlink_port, + enum devlink_port_type type, + void *type_dev) +{ + ASSERT_DEVLINK_PORT_REGISTERED(devlink_port); + + devlink_port_type_warn_cancel(devlink_port); + spin_lock_bh(&devlink_port->type_lock); + devlink_port->type = type; + devlink_port->type_dev = type_dev; + spin_unlock_bh(&devlink_port->type_lock); + devlink_port_notify(devlink_port, DEVLINK_CMD_PORT_NEW); +} + +static void devlink_port_type_netdev_checks(struct devlink_port *devlink_port, + struct net_device *netdev) +{ + const struct net_device_ops *ops = netdev->netdev_ops; + + /* If driver registers devlink port, it should set devlink port + * attributes accordingly so the compat functions are called + * and the original ops are not used. + */ + if (ops->ndo_get_phys_port_name) { + /* Some drivers use the same set of ndos for netdevs + * that have devlink_port registered and also for + * those who don't. 
Make sure that ndo_get_phys_port_name + * returns -EOPNOTSUPP here in case it is defined. + * Warn if not. + */ + char name[IFNAMSIZ]; + int err; + + err = ops->ndo_get_phys_port_name(netdev, name, sizeof(name)); + WARN_ON(err != -EOPNOTSUPP); + } + if (ops->ndo_get_port_parent_id) { + /* Some drivers use the same set of ndos for netdevs + * that have devlink_port registered and also for + * those who don't. Make sure that ndo_get_port_parent_id + * returns -EOPNOTSUPP here in case it is defined. + * Warn if not. + */ + struct netdev_phys_item_id ppid; + int err; + + err = ops->ndo_get_port_parent_id(netdev, &ppid); + WARN_ON(err != -EOPNOTSUPP); + } +} + +/** + * devlink_port_type_eth_set - Set port type to Ethernet + * + * @devlink_port: devlink port + * @netdev: related netdevice + */ +void devlink_port_type_eth_set(struct devlink_port *devlink_port, + struct net_device *netdev) +{ + if (netdev) + devlink_port_type_netdev_checks(devlink_port, netdev); + else + dev_warn(devlink_port->devlink->dev, + "devlink port type for port %d set to Ethernet without a software interface reference, device type not supported by the kernel?\n", + devlink_port->index); + + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_ETH, netdev); +} +EXPORT_SYMBOL_GPL(devlink_port_type_eth_set); + +/** + * devlink_port_type_ib_set - Set port type to InfiniBand + * + * @devlink_port: devlink port + * @ibdev: related IB device + */ +void devlink_port_type_ib_set(struct devlink_port *devlink_port, + struct ib_device *ibdev) +{ + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_IB, ibdev); +} +EXPORT_SYMBOL_GPL(devlink_port_type_ib_set); + +/** + * devlink_port_type_clear - Clear port type + * + * @devlink_port: devlink port + */ +void devlink_port_type_clear(struct devlink_port *devlink_port) +{ + __devlink_port_type_set(devlink_port, DEVLINK_PORT_TYPE_NOTSET, NULL); + devlink_port_type_warn_schedule(devlink_port); +} +EXPORT_SYMBOL_GPL(devlink_port_type_clear); + +static int __devlink_port_attrs_set(struct devlink_port *devlink_port, + enum devlink_port_flavour flavour) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + + devlink_port->attrs_set = true; + attrs->flavour = flavour; + if (attrs->switch_id.id_len) { + devlink_port->switch_port = true; + if (WARN_ON(attrs->switch_id.id_len > MAX_PHYS_ITEM_ID_LEN)) + attrs->switch_id.id_len = MAX_PHYS_ITEM_ID_LEN; + } else { + devlink_port->switch_port = false; + } + return 0; +} + +/** + * devlink_port_attrs_set - Set port attributes + * + * @devlink_port: devlink port + * @attrs: devlink port attrs + */ +void devlink_port_attrs_set(struct devlink_port *devlink_port, + struct devlink_port_attrs *attrs) +{ + int ret; + + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + devlink_port->attrs = *attrs; + ret = __devlink_port_attrs_set(devlink_port, attrs->flavour); + if (ret) + return; + WARN_ON(attrs->splittable && attrs->split); +} +EXPORT_SYMBOL_GPL(devlink_port_attrs_set); + +/** + * devlink_port_attrs_pci_pf_set - Set PCI PF port attributes + * + * @devlink_port: devlink port + * @controller: associated controller number for the devlink port instance + * @pf: associated PF for the devlink port instance + * @external: indicates if the port is for an external controller + */ +void devlink_port_attrs_pci_pf_set(struct devlink_port *devlink_port, u32 controller, + u16 pf, bool external) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + int ret; + + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + ret = 
__devlink_port_attrs_set(devlink_port, + DEVLINK_PORT_FLAVOUR_PCI_PF); + if (ret) + return; + attrs->pci_pf.controller = controller; + attrs->pci_pf.pf = pf; + attrs->pci_pf.external = external; +} +EXPORT_SYMBOL_GPL(devlink_port_attrs_pci_pf_set); + +/** + * devlink_port_attrs_pci_vf_set - Set PCI VF port attributes + * + * @devlink_port: devlink port + * @controller: associated controller number for the devlink port instance + * @pf: associated PF for the devlink port instance + * @vf: associated VF of a PF for the devlink port instance + * @external: indicates if the port is for an external controller + */ +void devlink_port_attrs_pci_vf_set(struct devlink_port *devlink_port, u32 controller, + u16 pf, u16 vf, bool external) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + int ret; + + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + ret = __devlink_port_attrs_set(devlink_port, + DEVLINK_PORT_FLAVOUR_PCI_VF); + if (ret) + return; + attrs->pci_vf.controller = controller; + attrs->pci_vf.pf = pf; + attrs->pci_vf.vf = vf; + attrs->pci_vf.external = external; +} +EXPORT_SYMBOL_GPL(devlink_port_attrs_pci_vf_set); + +/** + * devlink_port_attrs_pci_sf_set - Set PCI SF port attributes + * + * @devlink_port: devlink port + * @controller: associated controller number for the devlink port instance + * @pf: associated PF for the devlink port instance + * @sf: associated SF of a PF for the devlink port instance + * @external: indicates if the port is for an external controller + */ +void devlink_port_attrs_pci_sf_set(struct devlink_port *devlink_port, u32 controller, + u16 pf, u32 sf, bool external) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + int ret; + + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + ret = __devlink_port_attrs_set(devlink_port, + DEVLINK_PORT_FLAVOUR_PCI_SF); + if (ret) + return; + attrs->pci_sf.controller = controller; + attrs->pci_sf.pf = pf; + attrs->pci_sf.sf = sf; + attrs->pci_sf.external = external; +} +EXPORT_SYMBOL_GPL(devlink_port_attrs_pci_sf_set); + +/** + * devl_rate_leaf_create - create devlink rate leaf + * @devlink_port: devlink port object to create rate object on + * @priv: driver private data + * + * Create devlink rate object of type leaf on provided @devlink_port. + */ +int devl_rate_leaf_create(struct devlink_port *devlink_port, void *priv) +{ + struct devlink *devlink = devlink_port->devlink; + struct devlink_rate *devlink_rate; + + devl_assert_locked(devlink_port->devlink); + + if (WARN_ON(devlink_port->devlink_rate)) + return -EBUSY; + + devlink_rate = kzalloc(sizeof(*devlink_rate), GFP_KERNEL); + if (!devlink_rate) + return -ENOMEM; + + devlink_rate->type = DEVLINK_RATE_TYPE_LEAF; + devlink_rate->devlink = devlink; + devlink_rate->devlink_port = devlink_port; + devlink_rate->priv = priv; + list_add_tail(&devlink_rate->list, &devlink->rate_list); + devlink_port->devlink_rate = devlink_rate; + devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_NEW); + + return 0; +} +EXPORT_SYMBOL_GPL(devl_rate_leaf_create); + +/** + * devl_rate_leaf_destroy - destroy devlink rate leaf + * + * @devlink_port: devlink port linked to the rate object + * + * Destroy the devlink rate object of type leaf on provided @devlink_port. 
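+ *
+ * A minimal sketch of the expected pairing with devl_rate_leaf_create()
+ * ('my_priv' is a hypothetical driver pointer; the devlink instance
+ * lock is assumed to be held, as devl_assert_locked() requires):
+ *
+ *    err = devl_rate_leaf_create(devlink_port, my_priv);
+ *    if (err)
+ *        return err;
+ *    ...
+ *    devl_rate_leaf_destroy(devlink_port);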
+ */ +void devl_rate_leaf_destroy(struct devlink_port *devlink_port) +{ + struct devlink_rate *devlink_rate = devlink_port->devlink_rate; + + devl_assert_locked(devlink_port->devlink); + if (!devlink_rate) + return; + + devlink_rate_notify(devlink_rate, DEVLINK_CMD_RATE_DEL); + if (devlink_rate->parent) + refcount_dec(&devlink_rate->parent->refcnt); + list_del(&devlink_rate->list); + devlink_port->devlink_rate = NULL; + kfree(devlink_rate); +} +EXPORT_SYMBOL_GPL(devl_rate_leaf_destroy); + +/** + * devl_rate_nodes_destroy - destroy all devlink rate nodes on device + * @devlink: devlink instance + * + * Unset parent for all rate objects and destroy all rate nodes + * on specified device. + */ +void devl_rate_nodes_destroy(struct devlink *devlink) +{ + static struct devlink_rate *devlink_rate, *tmp; + const struct devlink_ops *ops = devlink->ops; + + devl_assert_locked(devlink); + + list_for_each_entry(devlink_rate, &devlink->rate_list, list) { + if (!devlink_rate->parent) + continue; + + refcount_dec(&devlink_rate->parent->refcnt); + if (devlink_rate_is_leaf(devlink_rate)) + ops->rate_leaf_parent_set(devlink_rate, NULL, devlink_rate->priv, + NULL, NULL); + else if (devlink_rate_is_node(devlink_rate)) + ops->rate_node_parent_set(devlink_rate, NULL, devlink_rate->priv, + NULL, NULL); + } + list_for_each_entry_safe(devlink_rate, tmp, &devlink->rate_list, list) { + if (devlink_rate_is_node(devlink_rate)) { + ops->rate_node_del(devlink_rate, devlink_rate->priv, NULL); + list_del(&devlink_rate->list); + kfree(devlink_rate->name); + kfree(devlink_rate); + } + } +} +EXPORT_SYMBOL_GPL(devl_rate_nodes_destroy); + +/** + * devlink_port_linecard_set - Link port with a linecard + * + * @devlink_port: devlink port + * @linecard: devlink linecard + */ +void devlink_port_linecard_set(struct devlink_port *devlink_port, + struct devlink_linecard *linecard) +{ + ASSERT_DEVLINK_PORT_NOT_REGISTERED(devlink_port); + + devlink_port->linecard = linecard; +} +EXPORT_SYMBOL_GPL(devlink_port_linecard_set); + +static int __devlink_port_phys_port_name_get(struct devlink_port *devlink_port, + char *name, size_t len) +{ + struct devlink_port_attrs *attrs = &devlink_port->attrs; + int n = 0; + + if (!devlink_port->attrs_set) + return -EOPNOTSUPP; + + switch (attrs->flavour) { + case DEVLINK_PORT_FLAVOUR_PHYSICAL: + if (devlink_port->linecard) + n = snprintf(name, len, "l%u", + devlink_port->linecard->index); + if (n < len) + n += snprintf(name + n, len - n, "p%u", + attrs->phys.port_number); + if (n < len && attrs->split) + n += snprintf(name + n, len - n, "s%u", + attrs->phys.split_subport_number); + break; + case DEVLINK_PORT_FLAVOUR_CPU: + case DEVLINK_PORT_FLAVOUR_DSA: + case DEVLINK_PORT_FLAVOUR_UNUSED: + /* As CPU and DSA ports do not have a netdevice associated + * case should not ever happen. 
+ */ + WARN_ON(1); + return -EINVAL; + case DEVLINK_PORT_FLAVOUR_PCI_PF: + if (attrs->pci_pf.external) { + n = snprintf(name, len, "c%u", attrs->pci_pf.controller); + if (n >= len) + return -EINVAL; + len -= n; + name += n; + } + n = snprintf(name, len, "pf%u", attrs->pci_pf.pf); + break; + case DEVLINK_PORT_FLAVOUR_PCI_VF: + if (attrs->pci_vf.external) { + n = snprintf(name, len, "c%u", attrs->pci_vf.controller); + if (n >= len) + return -EINVAL; + len -= n; + name += n; + } + n = snprintf(name, len, "pf%uvf%u", + attrs->pci_vf.pf, attrs->pci_vf.vf); + break; + case DEVLINK_PORT_FLAVOUR_PCI_SF: + if (attrs->pci_sf.external) { + n = snprintf(name, len, "c%u", attrs->pci_sf.controller); + if (n >= len) + return -EINVAL; + len -= n; + name += n; + } + n = snprintf(name, len, "pf%usf%u", attrs->pci_sf.pf, + attrs->pci_sf.sf); + break; + case DEVLINK_PORT_FLAVOUR_VIRTUAL: + return -EOPNOTSUPP; + } + + if (n >= len) + return -EINVAL; + + return 0; +} + +static int devlink_linecard_types_init(struct devlink_linecard *linecard) +{ + struct devlink_linecard_type *linecard_type; + unsigned int count; + int i; + + count = linecard->ops->types_count(linecard, linecard->priv); + linecard->types = kmalloc_array(count, sizeof(*linecard_type), + GFP_KERNEL); + if (!linecard->types) + return -ENOMEM; + linecard->types_count = count; + + for (i = 0; i < count; i++) { + linecard_type = &linecard->types[i]; + linecard->ops->types_get(linecard, linecard->priv, i, + &linecard_type->type, + &linecard_type->priv); + } + return 0; +} + +static void devlink_linecard_types_fini(struct devlink_linecard *linecard) +{ + kfree(linecard->types); +} + +/** + * devlink_linecard_create - Create devlink linecard + * + * @devlink: devlink + * @linecard_index: driver-specific numerical identifier of the linecard + * @ops: linecards ops + * @priv: user priv pointer + * + * Create devlink linecard instance with provided linecard index. + * Caller can use any indexing, even hw-related one. + * + * Return: Line card structure or an ERR_PTR() encoded error code. 
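+ *
+ * A minimal usage sketch; 'my_linecard_ops' and 'my_priv' are
+ * hypothetical driver-side names, and the ops structure is assumed to
+ * provide the mandatory provision, unprovision, types_count and
+ * types_get callbacks:
+ *
+ *    struct devlink_linecard *lc;
+ *
+ *    lc = devlink_linecard_create(devlink, 0, &my_linecard_ops, my_priv);
+ *    if (IS_ERR(lc))
+ *        return PTR_ERR(lc);
+ *
+ * The matching teardown call is devlink_linecard_destroy().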
+ */ +struct devlink_linecard * +devlink_linecard_create(struct devlink *devlink, unsigned int linecard_index, + const struct devlink_linecard_ops *ops, void *priv) +{ + struct devlink_linecard *linecard; + int err; + + if (WARN_ON(!ops || !ops->provision || !ops->unprovision || + !ops->types_count || !ops->types_get)) + return ERR_PTR(-EINVAL); + + mutex_lock(&devlink->linecards_lock); + if (devlink_linecard_index_exists(devlink, linecard_index)) { + mutex_unlock(&devlink->linecards_lock); + return ERR_PTR(-EEXIST); + } + + linecard = kzalloc(sizeof(*linecard), GFP_KERNEL); + if (!linecard) { + mutex_unlock(&devlink->linecards_lock); + return ERR_PTR(-ENOMEM); + } + + linecard->devlink = devlink; + linecard->index = linecard_index; + linecard->ops = ops; + linecard->priv = priv; + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; + mutex_init(&linecard->state_lock); + + err = devlink_linecard_types_init(linecard); + if (err) { + mutex_destroy(&linecard->state_lock); + kfree(linecard); + mutex_unlock(&devlink->linecards_lock); + return ERR_PTR(err); + } + + list_add_tail(&linecard->list, &devlink->linecard_list); + refcount_set(&linecard->refcount, 1); + mutex_unlock(&devlink->linecards_lock); + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + return linecard; +} +EXPORT_SYMBOL_GPL(devlink_linecard_create); + +/** + * devlink_linecard_destroy - Destroy devlink linecard + * + * @linecard: devlink linecard + */ +void devlink_linecard_destroy(struct devlink_linecard *linecard) +{ + struct devlink *devlink = linecard->devlink; + + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_DEL); + mutex_lock(&devlink->linecards_lock); + list_del(&linecard->list); + devlink_linecard_types_fini(linecard); + mutex_unlock(&devlink->linecards_lock); + devlink_linecard_put(linecard); +} +EXPORT_SYMBOL_GPL(devlink_linecard_destroy); + +/** + * devlink_linecard_provision_set - Set provisioning on linecard + * + * @linecard: devlink linecard + * @type: linecard type + * + * This is either called directly from the provision() op call or + * as a result of the provision() op call asynchronously. + */ +void devlink_linecard_provision_set(struct devlink_linecard *linecard, + const char *type) +{ + mutex_lock(&linecard->state_lock); + WARN_ON(linecard->type && strcmp(linecard->type, type)); + linecard->state = DEVLINK_LINECARD_STATE_PROVISIONED; + linecard->type = type; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_provision_set); + +/** + * devlink_linecard_provision_clear - Clear provisioning on linecard + * + * @linecard: devlink linecard + * + * This is either called directly from the unprovision() op call or + * as a result of the unprovision() op call asynchronously. + */ +void devlink_linecard_provision_clear(struct devlink_linecard *linecard) +{ + mutex_lock(&linecard->state_lock); + WARN_ON(linecard->nested_devlink); + linecard->state = DEVLINK_LINECARD_STATE_UNPROVISIONED; + linecard->type = NULL; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_provision_clear); + +/** + * devlink_linecard_provision_fail - Fail provisioning on linecard + * + * @linecard: devlink linecard + * + * This is either called directly from the provision() op call or + * as a result of the provision() op call asynchronously. 
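+ *
+ * A hedged sketch of how an asynchronous provisioning completion might
+ * report its result; 'my_lc_provision_done' and 'hw_err' are
+ * hypothetical driver-side names:
+ *
+ *    static void my_lc_provision_done(struct devlink_linecard *lc,
+ *                                     const char *type, int hw_err)
+ *    {
+ *        if (hw_err)
+ *            devlink_linecard_provision_fail(lc);
+ *        else
+ *            devlink_linecard_provision_set(lc, type);
+ *    }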
+ */ +void devlink_linecard_provision_fail(struct devlink_linecard *linecard) +{ + mutex_lock(&linecard->state_lock); + WARN_ON(linecard->nested_devlink); + linecard->state = DEVLINK_LINECARD_STATE_PROVISIONING_FAILED; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_provision_fail); + +/** + * devlink_linecard_activate - Set linecard active + * + * @linecard: devlink linecard + */ +void devlink_linecard_activate(struct devlink_linecard *linecard) +{ + mutex_lock(&linecard->state_lock); + WARN_ON(linecard->state != DEVLINK_LINECARD_STATE_PROVISIONED); + linecard->state = DEVLINK_LINECARD_STATE_ACTIVE; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_activate); + +/** + * devlink_linecard_deactivate - Set linecard inactive + * + * @linecard: devlink linecard + */ +void devlink_linecard_deactivate(struct devlink_linecard *linecard) +{ + mutex_lock(&linecard->state_lock); + switch (linecard->state) { + case DEVLINK_LINECARD_STATE_ACTIVE: + linecard->state = DEVLINK_LINECARD_STATE_PROVISIONED; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + break; + case DEVLINK_LINECARD_STATE_UNPROVISIONING: + /* Line card is being deactivated as part + * of unprovisioning flow. + */ + break; + default: + WARN_ON(1); + break; + } + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_deactivate); + +/** + * devlink_linecard_nested_dl_set - Attach/detach nested devlink + * instance to linecard. + * + * @linecard: devlink linecard + * @nested_devlink: devlink instance to attach or NULL to detach + */ +void devlink_linecard_nested_dl_set(struct devlink_linecard *linecard, + struct devlink *nested_devlink) +{ + mutex_lock(&linecard->state_lock); + linecard->nested_devlink = nested_devlink; + devlink_linecard_notify(linecard, DEVLINK_CMD_LINECARD_NEW); + mutex_unlock(&linecard->state_lock); +} +EXPORT_SYMBOL_GPL(devlink_linecard_nested_dl_set); + +int devl_sb_register(struct devlink *devlink, unsigned int sb_index, + u32 size, u16 ingress_pools_count, + u16 egress_pools_count, u16 ingress_tc_count, + u16 egress_tc_count) +{ + struct devlink_sb *devlink_sb; + + lockdep_assert_held(&devlink->lock); + + if (devlink_sb_index_exists(devlink, sb_index)) + return -EEXIST; + + devlink_sb = kzalloc(sizeof(*devlink_sb), GFP_KERNEL); + if (!devlink_sb) + return -ENOMEM; + devlink_sb->index = sb_index; + devlink_sb->size = size; + devlink_sb->ingress_pools_count = ingress_pools_count; + devlink_sb->egress_pools_count = egress_pools_count; + devlink_sb->ingress_tc_count = ingress_tc_count; + devlink_sb->egress_tc_count = egress_tc_count; + list_add_tail(&devlink_sb->list, &devlink->sb_list); + return 0; +} +EXPORT_SYMBOL_GPL(devl_sb_register); + +int devlink_sb_register(struct devlink *devlink, unsigned int sb_index, + u32 size, u16 ingress_pools_count, + u16 egress_pools_count, u16 ingress_tc_count, + u16 egress_tc_count) +{ + int err; + + devl_lock(devlink); + err = devl_sb_register(devlink, sb_index, size, ingress_pools_count, + egress_pools_count, ingress_tc_count, + egress_tc_count); + devl_unlock(devlink); + return err; +} +EXPORT_SYMBOL_GPL(devlink_sb_register); + +void devl_sb_unregister(struct devlink *devlink, unsigned int sb_index) +{ + struct devlink_sb *devlink_sb; + + lockdep_assert_held(&devlink->lock); + + devlink_sb = devlink_sb_get_by_index(devlink, sb_index); + WARN_ON(!devlink_sb); + 
list_del(&devlink_sb->list); + kfree(devlink_sb); +} +EXPORT_SYMBOL_GPL(devl_sb_unregister); + +void devlink_sb_unregister(struct devlink *devlink, unsigned int sb_index) +{ + devl_lock(devlink); + devl_sb_unregister(devlink, sb_index); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_sb_unregister); + +/** + * devl_dpipe_headers_register - register dpipe headers + * + * @devlink: devlink + * @dpipe_headers: dpipe header array + * + * Register the headers supported by hardware. + */ +void devl_dpipe_headers_register(struct devlink *devlink, + struct devlink_dpipe_headers *dpipe_headers) +{ + lockdep_assert_held(&devlink->lock); + + devlink->dpipe_headers = dpipe_headers; +} +EXPORT_SYMBOL_GPL(devl_dpipe_headers_register); + +/** + * devl_dpipe_headers_unregister - unregister dpipe headers + * + * @devlink: devlink + * + * Unregister the headers supported by hardware. + */ +void devl_dpipe_headers_unregister(struct devlink *devlink) +{ + lockdep_assert_held(&devlink->lock); + + devlink->dpipe_headers = NULL; +} +EXPORT_SYMBOL_GPL(devl_dpipe_headers_unregister); + +/** + * devlink_dpipe_table_counter_enabled - check if counter allocation + * required + * @devlink: devlink + * @table_name: tables name + * + * Used by driver to check if counter allocation is required. + * After counter allocation is turned on the table entries + * are updated to include counter statistics. + * + * After that point on the driver must respect the counter + * state so that each entry added to the table is added + * with a counter. + */ +bool devlink_dpipe_table_counter_enabled(struct devlink *devlink, + const char *table_name) +{ + struct devlink_dpipe_table *table; + bool enabled; + + rcu_read_lock(); + table = devlink_dpipe_table_find(&devlink->dpipe_table_list, + table_name, devlink); + enabled = false; + if (table) + enabled = table->counters_enabled; + rcu_read_unlock(); + return enabled; +} +EXPORT_SYMBOL_GPL(devlink_dpipe_table_counter_enabled); + +/** + * devl_dpipe_table_register - register dpipe table + * + * @devlink: devlink + * @table_name: table name + * @table_ops: table ops + * @priv: priv + * @counter_control_extern: external control for counters + */ +int devl_dpipe_table_register(struct devlink *devlink, + const char *table_name, + struct devlink_dpipe_table_ops *table_ops, + void *priv, bool counter_control_extern) +{ + struct devlink_dpipe_table *table; + + lockdep_assert_held(&devlink->lock); + + if (WARN_ON(!table_ops->size_get)) + return -EINVAL; + + if (devlink_dpipe_table_find(&devlink->dpipe_table_list, table_name, + devlink)) + return -EEXIST; + + table = kzalloc(sizeof(*table), GFP_KERNEL); + if (!table) + return -ENOMEM; + + table->name = table_name; + table->table_ops = table_ops; + table->priv = priv; + table->counter_control_extern = counter_control_extern; + + list_add_tail_rcu(&table->list, &devlink->dpipe_table_list); + + return 0; +} +EXPORT_SYMBOL_GPL(devl_dpipe_table_register); + +/** + * devl_dpipe_table_unregister - unregister dpipe table + * + * @devlink: devlink + * @table_name: table name + */ +void devl_dpipe_table_unregister(struct devlink *devlink, + const char *table_name) +{ + struct devlink_dpipe_table *table; + + lockdep_assert_held(&devlink->lock); + + table = devlink_dpipe_table_find(&devlink->dpipe_table_list, + table_name, devlink); + if (!table) + return; + list_del_rcu(&table->list); + kfree_rcu(table, rcu); +} +EXPORT_SYMBOL_GPL(devl_dpipe_table_unregister); + +/** + * devl_resource_register - devlink resource register + * + * @devlink: devlink + * 
@resource_name: resource's name + * @resource_size: resource's size + * @resource_id: resource's id + * @parent_resource_id: resource's parent id + * @size_params: size parameters + * + * Generic resources should reuse the same names across drivers. + * Please see the generic resources list at: + * Documentation/networking/devlink/devlink-resource.rst + */ +int devl_resource_register(struct devlink *devlink, + const char *resource_name, + u64 resource_size, + u64 resource_id, + u64 parent_resource_id, + const struct devlink_resource_size_params *size_params) +{ + struct devlink_resource *resource; + struct list_head *resource_list; + bool top_hierarchy; + + lockdep_assert_held(&devlink->lock); + + top_hierarchy = parent_resource_id == DEVLINK_RESOURCE_ID_PARENT_TOP; + + resource = devlink_resource_find(devlink, NULL, resource_id); + if (resource) + return -EINVAL; + + resource = kzalloc(sizeof(*resource), GFP_KERNEL); + if (!resource) + return -ENOMEM; + + if (top_hierarchy) { + resource_list = &devlink->resource_list; + } else { + struct devlink_resource *parent_resource; + + parent_resource = devlink_resource_find(devlink, NULL, + parent_resource_id); + if (parent_resource) { + resource_list = &parent_resource->resource_list; + resource->parent = parent_resource; + } else { + kfree(resource); + return -EINVAL; + } + } + + resource->name = resource_name; + resource->size = resource_size; + resource->size_new = resource_size; + resource->id = resource_id; + resource->size_valid = true; + memcpy(&resource->size_params, size_params, + sizeof(resource->size_params)); + INIT_LIST_HEAD(&resource->resource_list); + list_add_tail(&resource->list, resource_list); + + return 0; +} +EXPORT_SYMBOL_GPL(devl_resource_register); + +/** + * devlink_resource_register - devlink resource register + * + * @devlink: devlink + * @resource_name: resource's name + * @resource_size: resource's size + * @resource_id: resource's id + * @parent_resource_id: resource's parent id + * @size_params: size parameters + * + * Generic resources should reuse the same names across drivers. + * Please see the generic resources list at: + * Documentation/networking/devlink/devlink-resource.rst + * + * Context: Takes and release devlink->lock . 
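+ *
+ * A minimal sketch; the resource name, sizes and the 0x10 id are
+ * hypothetical, and the devlink_resource_size_params_init() helper from
+ * include/net/devlink.h is assumed:
+ *
+ *    struct devlink_resource_size_params params;
+ *    int err;
+ *
+ *    devlink_resource_size_params_init(&params, 0, 4096, 1,
+ *                                      DEVLINK_RESOURCE_UNIT_ENTRY);
+ *    err = devlink_resource_register(devlink, "my_table", 2048, 0x10,
+ *                                    DEVLINK_RESOURCE_ID_PARENT_TOP,
+ *                                    &params);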
+ */ +int devlink_resource_register(struct devlink *devlink, + const char *resource_name, + u64 resource_size, + u64 resource_id, + u64 parent_resource_id, + const struct devlink_resource_size_params *size_params) +{ + int err; + + devl_lock(devlink); + err = devl_resource_register(devlink, resource_name, resource_size, + resource_id, parent_resource_id, size_params); + devl_unlock(devlink); + return err; +} +EXPORT_SYMBOL_GPL(devlink_resource_register); + +static void devlink_resource_unregister(struct devlink *devlink, + struct devlink_resource *resource) +{ + struct devlink_resource *tmp, *child_resource; + + list_for_each_entry_safe(child_resource, tmp, &resource->resource_list, + list) { + devlink_resource_unregister(devlink, child_resource); + list_del(&child_resource->list); + kfree(child_resource); + } +} + +/** + * devl_resources_unregister - free all resources + * + * @devlink: devlink + */ +void devl_resources_unregister(struct devlink *devlink) +{ + struct devlink_resource *tmp, *child_resource; + + lockdep_assert_held(&devlink->lock); + + list_for_each_entry_safe(child_resource, tmp, &devlink->resource_list, + list) { + devlink_resource_unregister(devlink, child_resource); + list_del(&child_resource->list); + kfree(child_resource); + } +} +EXPORT_SYMBOL_GPL(devl_resources_unregister); + +/** + * devlink_resources_unregister - free all resources + * + * @devlink: devlink + * + * Context: Takes and release devlink->lock . + */ +void devlink_resources_unregister(struct devlink *devlink) +{ + devl_lock(devlink); + devl_resources_unregister(devlink); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_resources_unregister); + +/** + * devl_resource_size_get - get and update size + * + * @devlink: devlink + * @resource_id: the requested resource id + * @p_resource_size: ptr to update + */ +int devl_resource_size_get(struct devlink *devlink, + u64 resource_id, + u64 *p_resource_size) +{ + struct devlink_resource *resource; + + lockdep_assert_held(&devlink->lock); + + resource = devlink_resource_find(devlink, NULL, resource_id); + if (!resource) + return -EINVAL; + *p_resource_size = resource->size_new; + resource->size = resource->size_new; + return 0; +} +EXPORT_SYMBOL_GPL(devl_resource_size_get); + +/** + * devl_dpipe_table_resource_set - set the resource id + * + * @devlink: devlink + * @table_name: table name + * @resource_id: resource id + * @resource_units: number of resource's units consumed per table's entry + */ +int devl_dpipe_table_resource_set(struct devlink *devlink, + const char *table_name, u64 resource_id, + u64 resource_units) +{ + struct devlink_dpipe_table *table; + + table = devlink_dpipe_table_find(&devlink->dpipe_table_list, + table_name, devlink); + if (!table) + return -EINVAL; + + table->resource_id = resource_id; + table->resource_units = resource_units; + table->resource_valid = true; + return 0; +} +EXPORT_SYMBOL_GPL(devl_dpipe_table_resource_set); + +/** + * devl_resource_occ_get_register - register occupancy getter + * + * @devlink: devlink + * @resource_id: resource id + * @occ_get: occupancy getter callback + * @occ_get_priv: occupancy getter callback priv + */ +void devl_resource_occ_get_register(struct devlink *devlink, + u64 resource_id, + devlink_resource_occ_get_t *occ_get, + void *occ_get_priv) +{ + struct devlink_resource *resource; + + lockdep_assert_held(&devlink->lock); + + resource = devlink_resource_find(devlink, NULL, resource_id); + if (WARN_ON(!resource)) + return; + WARN_ON(resource->occ_get); + + resource->occ_get = occ_get; + 
resource->occ_get_priv = occ_get_priv; +} +EXPORT_SYMBOL_GPL(devl_resource_occ_get_register); + +/** + * devlink_resource_occ_get_register - register occupancy getter + * + * @devlink: devlink + * @resource_id: resource id + * @occ_get: occupancy getter callback + * @occ_get_priv: occupancy getter callback priv + * + * Context: Takes and release devlink->lock . + */ +void devlink_resource_occ_get_register(struct devlink *devlink, + u64 resource_id, + devlink_resource_occ_get_t *occ_get, + void *occ_get_priv) +{ + devl_lock(devlink); + devl_resource_occ_get_register(devlink, resource_id, + occ_get, occ_get_priv); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_resource_occ_get_register); + +/** + * devl_resource_occ_get_unregister - unregister occupancy getter + * + * @devlink: devlink + * @resource_id: resource id + */ +void devl_resource_occ_get_unregister(struct devlink *devlink, + u64 resource_id) +{ + struct devlink_resource *resource; + + lockdep_assert_held(&devlink->lock); + + resource = devlink_resource_find(devlink, NULL, resource_id); + if (WARN_ON(!resource)) + return; + WARN_ON(!resource->occ_get); + + resource->occ_get = NULL; + resource->occ_get_priv = NULL; +} +EXPORT_SYMBOL_GPL(devl_resource_occ_get_unregister); + +/** + * devlink_resource_occ_get_unregister - unregister occupancy getter + * + * @devlink: devlink + * @resource_id: resource id + * + * Context: Takes and release devlink->lock . + */ +void devlink_resource_occ_get_unregister(struct devlink *devlink, + u64 resource_id) +{ + devl_lock(devlink); + devl_resource_occ_get_unregister(devlink, resource_id); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_resource_occ_get_unregister); + +static int devlink_param_verify(const struct devlink_param *param) +{ + if (!param || !param->name || !param->supported_cmodes) + return -EINVAL; + if (param->generic) + return devlink_param_generic_verify(param); + else + return devlink_param_driver_verify(param); +} + +/** + * devlink_params_register - register configuration parameters + * + * @devlink: devlink + * @params: configuration parameters array + * @params_count: number of parameters provided + * + * Register the configuration parameters supported by the driver. + */ +int devlink_params_register(struct devlink *devlink, + const struct devlink_param *params, + size_t params_count) +{ + const struct devlink_param *param = params; + int i, err; + + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + for (i = 0; i < params_count; i++, param++) { + err = devlink_param_register(devlink, param); + if (err) + goto rollback; + } + return 0; + +rollback: + if (!i) + return err; + + for (param--; i > 0; i--, param--) + devlink_param_unregister(devlink, param); + return err; +} +EXPORT_SYMBOL_GPL(devlink_params_register); + +/** + * devlink_params_unregister - unregister configuration parameters + * @devlink: devlink + * @params: configuration parameters to unregister + * @params_count: number of parameters provided + */ +void devlink_params_unregister(struct devlink *devlink, + const struct devlink_param *params, + size_t params_count) +{ + const struct devlink_param *param = params; + int i; + + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + for (i = 0; i < params_count; i++, param++) + devlink_param_unregister(devlink, param); +} +EXPORT_SYMBOL_GPL(devlink_params_unregister); + +/** + * devlink_param_register - register one configuration parameter + * + * @devlink: devlink + * @param: one configuration parameter + * + * Register the configuration parameter supported by the driver. 
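+ *
+ * A minimal sketch of a driver-specific, driverinit-only parameter;
+ * the id, name and type are hypothetical. For this cmode the get/set
+ * callbacks are intentionally left unset, as the checks in this
+ * function expect:
+ *
+ *    static const struct devlink_param my_param = {
+ *        .id = DEVLINK_PARAM_GENERIC_ID_MAX + 1,
+ *        .name = "my_queue_count",
+ *        .type = DEVLINK_PARAM_TYPE_U32,
+ *        .supported_cmodes = BIT(DEVLINK_PARAM_CMODE_DRIVERINIT),
+ *    };
+ *
+ *    err = devlink_param_register(devlink, &my_param);
+ *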
+ * Return: returns 0 on successful registration or error code otherwise. + */ +int devlink_param_register(struct devlink *devlink, + const struct devlink_param *param) +{ + struct devlink_param_item *param_item; + + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + WARN_ON(devlink_param_verify(param)); + WARN_ON(devlink_param_find_by_name(&devlink->param_list, param->name)); + + if (param->supported_cmodes == BIT(DEVLINK_PARAM_CMODE_DRIVERINIT)) + WARN_ON(param->get || param->set); + else + WARN_ON(!param->get || !param->set); + + param_item = kzalloc(sizeof(*param_item), GFP_KERNEL); + if (!param_item) + return -ENOMEM; + + param_item->param = param; + + list_add_tail(¶m_item->list, &devlink->param_list); + return 0; +} +EXPORT_SYMBOL_GPL(devlink_param_register); + +/** + * devlink_param_unregister - unregister one configuration parameter + * @devlink: devlink + * @param: configuration parameter to unregister + */ +void devlink_param_unregister(struct devlink *devlink, + const struct devlink_param *param) +{ + struct devlink_param_item *param_item; + + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + param_item = + devlink_param_find_by_name(&devlink->param_list, param->name); + WARN_ON(!param_item); + list_del(¶m_item->list); + kfree(param_item); +} +EXPORT_SYMBOL_GPL(devlink_param_unregister); + +/** + * devlink_param_driverinit_value_get - get configuration parameter + * value for driver initializing + * + * @devlink: devlink + * @param_id: parameter ID + * @init_val: value of parameter in driverinit configuration mode + * + * This function should be used by the driver to get driverinit + * configuration for initialization after reload command. + */ +int devlink_param_driverinit_value_get(struct devlink *devlink, u32 param_id, + union devlink_param_value *init_val) +{ + struct devlink_param_item *param_item; + + if (!devlink_reload_supported(devlink->ops)) + return -EOPNOTSUPP; + + param_item = devlink_param_find_by_id(&devlink->param_list, param_id); + if (!param_item) + return -EINVAL; + + if (!param_item->driverinit_value_valid || + !devlink_param_cmode_is_supported(param_item->param, + DEVLINK_PARAM_CMODE_DRIVERINIT)) + return -EOPNOTSUPP; + + if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(init_val->vstr, param_item->driverinit_value.vstr); + else + *init_val = param_item->driverinit_value; + + return 0; +} +EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_get); + +/** + * devlink_param_driverinit_value_set - set value of configuration + * parameter for driverinit + * configuration mode + * + * @devlink: devlink + * @param_id: parameter ID + * @init_val: value of parameter to set for driverinit configuration mode + * + * This function should be used by the driver to set driverinit + * configuration mode default value. 
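+ *
+ * A minimal sketch; MY_PARAM_ID stands for a hypothetical driver
+ * parameter id. Note that this must run before devlink_register(), as
+ * the assertion below requires:
+ *
+ *    union devlink_param_value val;
+ *
+ *    val.vu32 = 32;
+ *    err = devlink_param_driverinit_value_set(devlink, MY_PARAM_ID, val);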
+ */ +int devlink_param_driverinit_value_set(struct devlink *devlink, u32 param_id, + union devlink_param_value init_val) +{ + struct devlink_param_item *param_item; + + ASSERT_DEVLINK_NOT_REGISTERED(devlink); + + param_item = devlink_param_find_by_id(&devlink->param_list, param_id); + if (!param_item) + return -EINVAL; + + if (!devlink_param_cmode_is_supported(param_item->param, + DEVLINK_PARAM_CMODE_DRIVERINIT)) + return -EOPNOTSUPP; + + if (param_item->param->type == DEVLINK_PARAM_TYPE_STRING) + strcpy(param_item->driverinit_value.vstr, init_val.vstr); + else + param_item->driverinit_value = init_val; + param_item->driverinit_value_valid = true; + return 0; +} +EXPORT_SYMBOL_GPL(devlink_param_driverinit_value_set); + +/** + * devlink_param_value_changed - notify devlink on a parameter's value + * change. Should be called by the driver + * right after the change. + * + * @devlink: devlink + * @param_id: parameter ID + * + * This function should be used by the driver to notify devlink on value + * change, excluding driverinit configuration mode. + * For driverinit configuration mode driver should use the function + */ +void devlink_param_value_changed(struct devlink *devlink, u32 param_id) +{ + struct devlink_param_item *param_item; + + param_item = devlink_param_find_by_id(&devlink->param_list, param_id); + WARN_ON(!param_item); + + devlink_param_notify(devlink, 0, param_item, DEVLINK_CMD_PARAM_NEW); +} +EXPORT_SYMBOL_GPL(devlink_param_value_changed); + +/** + * devl_region_create - create a new address region + * + * @devlink: devlink + * @ops: region operations and name + * @region_max_snapshots: Maximum supported number of snapshots for region + * @region_size: size of region + */ +struct devlink_region *devl_region_create(struct devlink *devlink, + const struct devlink_region_ops *ops, + u32 region_max_snapshots, + u64 region_size) +{ + struct devlink_region *region; + + devl_assert_locked(devlink); + + if (WARN_ON(!ops) || WARN_ON(!ops->destructor)) + return ERR_PTR(-EINVAL); + + if (devlink_region_get_by_name(devlink, ops->name)) + return ERR_PTR(-EEXIST); + + region = kzalloc(sizeof(*region), GFP_KERNEL); + if (!region) + return ERR_PTR(-ENOMEM); + + region->devlink = devlink; + region->max_snapshots = region_max_snapshots; + region->ops = ops; + region->size = region_size; + INIT_LIST_HEAD(®ion->snapshot_list); + mutex_init(®ion->snapshot_lock); + list_add_tail(®ion->list, &devlink->region_list); + devlink_nl_region_notify(region, NULL, DEVLINK_CMD_REGION_NEW); + + return region; +} +EXPORT_SYMBOL_GPL(devl_region_create); + +/** + * devlink_region_create - create a new address region + * + * @devlink: devlink + * @ops: region operations and name + * @region_max_snapshots: Maximum supported number of snapshots for region + * @region_size: size of region + * + * Context: Takes and release devlink->lock . 
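+ *
+ * A minimal sketch; the region name, snapshot count and size are
+ * hypothetical, and a driver that wants devlink-triggered snapshots
+ * would also fill in a .snapshot callback:
+ *
+ *    static const struct devlink_region_ops my_region_ops = {
+ *        .name = "my-region",
+ *        .destructor = kfree,
+ *    };
+ *
+ *    region = devlink_region_create(devlink, &my_region_ops, 4, 4096);
+ *    if (IS_ERR(region))
+ *        return PTR_ERR(region);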
+ */ +struct devlink_region * +devlink_region_create(struct devlink *devlink, + const struct devlink_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + struct devlink_region *region; + + devl_lock(devlink); + region = devl_region_create(devlink, ops, region_max_snapshots, + region_size); + devl_unlock(devlink); + return region; +} +EXPORT_SYMBOL_GPL(devlink_region_create); + +/** + * devlink_port_region_create - create a new address region for a port + * + * @port: devlink port + * @ops: region operations and name + * @region_max_snapshots: Maximum supported number of snapshots for region + * @region_size: size of region + * + * Context: Takes and release devlink->lock . + */ +struct devlink_region * +devlink_port_region_create(struct devlink_port *port, + const struct devlink_port_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + struct devlink *devlink = port->devlink; + struct devlink_region *region; + int err = 0; + + ASSERT_DEVLINK_PORT_INITIALIZED(port); + + if (WARN_ON(!ops) || WARN_ON(!ops->destructor)) + return ERR_PTR(-EINVAL); + + devl_lock(devlink); + + if (devlink_port_region_get_by_name(port, ops->name)) { + err = -EEXIST; + goto unlock; + } + + region = kzalloc(sizeof(*region), GFP_KERNEL); + if (!region) { + err = -ENOMEM; + goto unlock; + } + + region->devlink = devlink; + region->port = port; + region->max_snapshots = region_max_snapshots; + region->port_ops = ops; + region->size = region_size; + INIT_LIST_HEAD(®ion->snapshot_list); + mutex_init(®ion->snapshot_lock); + list_add_tail(®ion->list, &port->region_list); + devlink_nl_region_notify(region, NULL, DEVLINK_CMD_REGION_NEW); + + devl_unlock(devlink); + return region; + +unlock: + devl_unlock(devlink); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(devlink_port_region_create); + +/** + * devl_region_destroy - destroy address region + * + * @region: devlink region to destroy + */ +void devl_region_destroy(struct devlink_region *region) +{ + struct devlink *devlink = region->devlink; + struct devlink_snapshot *snapshot, *ts; + + devl_assert_locked(devlink); + + /* Free all snapshots of region */ + mutex_lock(®ion->snapshot_lock); + list_for_each_entry_safe(snapshot, ts, ®ion->snapshot_list, list) + devlink_region_snapshot_del(region, snapshot); + mutex_unlock(®ion->snapshot_lock); + + list_del(®ion->list); + mutex_destroy(®ion->snapshot_lock); + + devlink_nl_region_notify(region, NULL, DEVLINK_CMD_REGION_DEL); + kfree(region); +} +EXPORT_SYMBOL_GPL(devl_region_destroy); + +/** + * devlink_region_destroy - destroy address region + * + * @region: devlink region to destroy + * + * Context: Takes and release devlink->lock . + */ +void devlink_region_destroy(struct devlink_region *region) +{ + struct devlink *devlink = region->devlink; + + devl_lock(devlink); + devl_region_destroy(region); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_region_destroy); + +/** + * devlink_region_snapshot_id_get - get snapshot ID + * + * This callback should be called when adding a new snapshot, + * Driver should use the same id for multiple snapshots taken + * on multiple regions at the same time/by the same trigger. + * + * The caller of this function must use devlink_region_snapshot_id_put + * when finished creating regions using this id. + * + * Returns zero on success, or a negative error code on failure. 
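+ *
+ * A minimal sketch of the id lifecycle together with
+ * devlink_region_snapshot_create() and devlink_region_snapshot_id_put();
+ * 'data' is assumed to be a buffer of the region size whose ownership
+ * passes to the region on success (freed later via the region ops
+ * destructor, e.g. kfree):
+ *
+ *    err = devlink_region_snapshot_id_get(devlink, &id);
+ *    if (err)
+ *        return err;
+ *    err = devlink_region_snapshot_create(region, data, id);
+ *    devlink_region_snapshot_id_put(devlink, id);
+ *    if (err)
+ *        kfree(data);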
+ * + * @devlink: devlink + * @id: storage to return id + */ +int devlink_region_snapshot_id_get(struct devlink *devlink, u32 *id) +{ + return __devlink_region_snapshot_id_get(devlink, id); +} +EXPORT_SYMBOL_GPL(devlink_region_snapshot_id_get); + +/** + * devlink_region_snapshot_id_put - put snapshot ID reference + * + * This should be called by a driver after finishing creating snapshots + * with an id. Doing so ensures that the ID can later be released in the + * event that all snapshots using it have been destroyed. + * + * @devlink: devlink + * @id: id to release reference on + */ +void devlink_region_snapshot_id_put(struct devlink *devlink, u32 id) +{ + __devlink_snapshot_id_decrement(devlink, id); +} +EXPORT_SYMBOL_GPL(devlink_region_snapshot_id_put); + +/** + * devlink_region_snapshot_create - create a new snapshot + * This will add a new snapshot of a region. The snapshot + * will be stored on the region struct and can be accessed + * from devlink. This is useful for future analyses of snapshots. + * Multiple snapshots can be created on a region. + * The @snapshot_id should be obtained using the getter function. + * + * @region: devlink region of the snapshot + * @data: snapshot data + * @snapshot_id: snapshot id to be created + */ +int devlink_region_snapshot_create(struct devlink_region *region, + u8 *data, u32 snapshot_id) +{ + int err; + + mutex_lock(®ion->snapshot_lock); + err = __devlink_region_snapshot_create(region, data, snapshot_id); + mutex_unlock(®ion->snapshot_lock); + return err; +} +EXPORT_SYMBOL_GPL(devlink_region_snapshot_create); + +#define DEVLINK_TRAP(_id, _type) \ + { \ + .type = DEVLINK_TRAP_TYPE_##_type, \ + .id = DEVLINK_TRAP_GENERIC_ID_##_id, \ + .name = DEVLINK_TRAP_GENERIC_NAME_##_id, \ + } + +static const struct devlink_trap devlink_trap_generic[] = { + DEVLINK_TRAP(SMAC_MC, DROP), + DEVLINK_TRAP(VLAN_TAG_MISMATCH, DROP), + DEVLINK_TRAP(INGRESS_VLAN_FILTER, DROP), + DEVLINK_TRAP(INGRESS_STP_FILTER, DROP), + DEVLINK_TRAP(EMPTY_TX_LIST, DROP), + DEVLINK_TRAP(PORT_LOOPBACK_FILTER, DROP), + DEVLINK_TRAP(BLACKHOLE_ROUTE, DROP), + DEVLINK_TRAP(TTL_ERROR, EXCEPTION), + DEVLINK_TRAP(TAIL_DROP, DROP), + DEVLINK_TRAP(NON_IP_PACKET, DROP), + DEVLINK_TRAP(UC_DIP_MC_DMAC, DROP), + DEVLINK_TRAP(DIP_LB, DROP), + DEVLINK_TRAP(SIP_MC, DROP), + DEVLINK_TRAP(SIP_LB, DROP), + DEVLINK_TRAP(CORRUPTED_IP_HDR, DROP), + DEVLINK_TRAP(IPV4_SIP_BC, DROP), + DEVLINK_TRAP(IPV6_MC_DIP_RESERVED_SCOPE, DROP), + DEVLINK_TRAP(IPV6_MC_DIP_INTERFACE_LOCAL_SCOPE, DROP), + DEVLINK_TRAP(MTU_ERROR, EXCEPTION), + DEVLINK_TRAP(UNRESOLVED_NEIGH, EXCEPTION), + DEVLINK_TRAP(RPF, EXCEPTION), + DEVLINK_TRAP(REJECT_ROUTE, EXCEPTION), + DEVLINK_TRAP(IPV4_LPM_UNICAST_MISS, EXCEPTION), + DEVLINK_TRAP(IPV6_LPM_UNICAST_MISS, EXCEPTION), + DEVLINK_TRAP(NON_ROUTABLE, DROP), + DEVLINK_TRAP(DECAP_ERROR, EXCEPTION), + DEVLINK_TRAP(OVERLAY_SMAC_MC, DROP), + DEVLINK_TRAP(INGRESS_FLOW_ACTION_DROP, DROP), + DEVLINK_TRAP(EGRESS_FLOW_ACTION_DROP, DROP), + DEVLINK_TRAP(STP, CONTROL), + DEVLINK_TRAP(LACP, CONTROL), + DEVLINK_TRAP(LLDP, CONTROL), + DEVLINK_TRAP(IGMP_QUERY, CONTROL), + DEVLINK_TRAP(IGMP_V1_REPORT, CONTROL), + DEVLINK_TRAP(IGMP_V2_REPORT, CONTROL), + DEVLINK_TRAP(IGMP_V3_REPORT, CONTROL), + DEVLINK_TRAP(IGMP_V2_LEAVE, CONTROL), + DEVLINK_TRAP(MLD_QUERY, CONTROL), + DEVLINK_TRAP(MLD_V1_REPORT, CONTROL), + DEVLINK_TRAP(MLD_V2_REPORT, CONTROL), + DEVLINK_TRAP(MLD_V1_DONE, CONTROL), + DEVLINK_TRAP(IPV4_DHCP, CONTROL), + DEVLINK_TRAP(IPV6_DHCP, CONTROL), + DEVLINK_TRAP(ARP_REQUEST, CONTROL), + 
DEVLINK_TRAP(ARP_RESPONSE, CONTROL), + DEVLINK_TRAP(ARP_OVERLAY, CONTROL), + DEVLINK_TRAP(IPV6_NEIGH_SOLICIT, CONTROL), + DEVLINK_TRAP(IPV6_NEIGH_ADVERT, CONTROL), + DEVLINK_TRAP(IPV4_BFD, CONTROL), + DEVLINK_TRAP(IPV6_BFD, CONTROL), + DEVLINK_TRAP(IPV4_OSPF, CONTROL), + DEVLINK_TRAP(IPV6_OSPF, CONTROL), + DEVLINK_TRAP(IPV4_BGP, CONTROL), + DEVLINK_TRAP(IPV6_BGP, CONTROL), + DEVLINK_TRAP(IPV4_VRRP, CONTROL), + DEVLINK_TRAP(IPV6_VRRP, CONTROL), + DEVLINK_TRAP(IPV4_PIM, CONTROL), + DEVLINK_TRAP(IPV6_PIM, CONTROL), + DEVLINK_TRAP(UC_LB, CONTROL), + DEVLINK_TRAP(LOCAL_ROUTE, CONTROL), + DEVLINK_TRAP(EXTERNAL_ROUTE, CONTROL), + DEVLINK_TRAP(IPV6_UC_DIP_LINK_LOCAL_SCOPE, CONTROL), + DEVLINK_TRAP(IPV6_DIP_ALL_NODES, CONTROL), + DEVLINK_TRAP(IPV6_DIP_ALL_ROUTERS, CONTROL), + DEVLINK_TRAP(IPV6_ROUTER_SOLICIT, CONTROL), + DEVLINK_TRAP(IPV6_ROUTER_ADVERT, CONTROL), + DEVLINK_TRAP(IPV6_REDIRECT, CONTROL), + DEVLINK_TRAP(IPV4_ROUTER_ALERT, CONTROL), + DEVLINK_TRAP(IPV6_ROUTER_ALERT, CONTROL), + DEVLINK_TRAP(PTP_EVENT, CONTROL), + DEVLINK_TRAP(PTP_GENERAL, CONTROL), + DEVLINK_TRAP(FLOW_ACTION_SAMPLE, CONTROL), + DEVLINK_TRAP(FLOW_ACTION_TRAP, CONTROL), + DEVLINK_TRAP(EARLY_DROP, DROP), + DEVLINK_TRAP(VXLAN_PARSING, DROP), + DEVLINK_TRAP(LLC_SNAP_PARSING, DROP), + DEVLINK_TRAP(VLAN_PARSING, DROP), + DEVLINK_TRAP(PPPOE_PPP_PARSING, DROP), + DEVLINK_TRAP(MPLS_PARSING, DROP), + DEVLINK_TRAP(ARP_PARSING, DROP), + DEVLINK_TRAP(IP_1_PARSING, DROP), + DEVLINK_TRAP(IP_N_PARSING, DROP), + DEVLINK_TRAP(GRE_PARSING, DROP), + DEVLINK_TRAP(UDP_PARSING, DROP), + DEVLINK_TRAP(TCP_PARSING, DROP), + DEVLINK_TRAP(IPSEC_PARSING, DROP), + DEVLINK_TRAP(SCTP_PARSING, DROP), + DEVLINK_TRAP(DCCP_PARSING, DROP), + DEVLINK_TRAP(GTP_PARSING, DROP), + DEVLINK_TRAP(ESP_PARSING, DROP), + DEVLINK_TRAP(BLACKHOLE_NEXTHOP, DROP), + DEVLINK_TRAP(DMAC_FILTER, DROP), +}; + +#define DEVLINK_TRAP_GROUP(_id) \ + { \ + .id = DEVLINK_TRAP_GROUP_GENERIC_ID_##_id, \ + .name = DEVLINK_TRAP_GROUP_GENERIC_NAME_##_id, \ + } + +static const struct devlink_trap_group devlink_trap_group_generic[] = { + DEVLINK_TRAP_GROUP(L2_DROPS), + DEVLINK_TRAP_GROUP(L3_DROPS), + DEVLINK_TRAP_GROUP(L3_EXCEPTIONS), + DEVLINK_TRAP_GROUP(BUFFER_DROPS), + DEVLINK_TRAP_GROUP(TUNNEL_DROPS), + DEVLINK_TRAP_GROUP(ACL_DROPS), + DEVLINK_TRAP_GROUP(STP), + DEVLINK_TRAP_GROUP(LACP), + DEVLINK_TRAP_GROUP(LLDP), + DEVLINK_TRAP_GROUP(MC_SNOOPING), + DEVLINK_TRAP_GROUP(DHCP), + DEVLINK_TRAP_GROUP(NEIGH_DISCOVERY), + DEVLINK_TRAP_GROUP(BFD), + DEVLINK_TRAP_GROUP(OSPF), + DEVLINK_TRAP_GROUP(BGP), + DEVLINK_TRAP_GROUP(VRRP), + DEVLINK_TRAP_GROUP(PIM), + DEVLINK_TRAP_GROUP(UC_LB), + DEVLINK_TRAP_GROUP(LOCAL_DELIVERY), + DEVLINK_TRAP_GROUP(EXTERNAL_DELIVERY), + DEVLINK_TRAP_GROUP(IPV6), + DEVLINK_TRAP_GROUP(PTP_EVENT), + DEVLINK_TRAP_GROUP(PTP_GENERAL), + DEVLINK_TRAP_GROUP(ACL_SAMPLE), + DEVLINK_TRAP_GROUP(ACL_TRAP), + DEVLINK_TRAP_GROUP(PARSER_ERROR_DROPS), +}; + +static int devlink_trap_generic_verify(const struct devlink_trap *trap) +{ + if (trap->id > DEVLINK_TRAP_GENERIC_ID_MAX) + return -EINVAL; + + if (strcmp(trap->name, devlink_trap_generic[trap->id].name)) + return -EINVAL; + + if (trap->type != devlink_trap_generic[trap->id].type) + return -EINVAL; + + return 0; +} + +static int devlink_trap_driver_verify(const struct devlink_trap *trap) +{ + int i; + + if (trap->id <= DEVLINK_TRAP_GENERIC_ID_MAX) + return -EINVAL; + + for (i = 0; i < ARRAY_SIZE(devlink_trap_generic); i++) { + if (!strcmp(trap->name, devlink_trap_generic[i].name)) + return -EEXIST; + } + + return 0; +} + +static 
int devlink_trap_verify(const struct devlink_trap *trap) +{ + if (!trap || !trap->name) + return -EINVAL; + + if (trap->generic) + return devlink_trap_generic_verify(trap); + else + return devlink_trap_driver_verify(trap); +} + +static int +devlink_trap_group_generic_verify(const struct devlink_trap_group *group) +{ + if (group->id > DEVLINK_TRAP_GROUP_GENERIC_ID_MAX) + return -EINVAL; + + if (strcmp(group->name, devlink_trap_group_generic[group->id].name)) + return -EINVAL; + + return 0; +} + +static int +devlink_trap_group_driver_verify(const struct devlink_trap_group *group) +{ + int i; + + if (group->id <= DEVLINK_TRAP_GROUP_GENERIC_ID_MAX) + return -EINVAL; + + for (i = 0; i < ARRAY_SIZE(devlink_trap_group_generic); i++) { + if (!strcmp(group->name, devlink_trap_group_generic[i].name)) + return -EEXIST; + } + + return 0; +} + +static int devlink_trap_group_verify(const struct devlink_trap_group *group) +{ + if (group->generic) + return devlink_trap_group_generic_verify(group); + else + return devlink_trap_group_driver_verify(group); +} + +static void +devlink_trap_group_notify(struct devlink *devlink, + const struct devlink_trap_group_item *group_item, + enum devlink_command cmd) +{ + struct sk_buff *msg; + int err; + + WARN_ON_ONCE(cmd != DEVLINK_CMD_TRAP_GROUP_NEW && + cmd != DEVLINK_CMD_TRAP_GROUP_DEL); + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_trap_group_fill(msg, devlink, group_item, cmd, 0, 0, + 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int +devlink_trap_item_group_link(struct devlink *devlink, + struct devlink_trap_item *trap_item) +{ + u16 group_id = trap_item->trap->init_group_id; + struct devlink_trap_group_item *group_item; + + group_item = devlink_trap_group_item_lookup_by_id(devlink, group_id); + if (WARN_ON_ONCE(!group_item)) + return -EINVAL; + + trap_item->group_item = group_item; + + return 0; +} + +static void devlink_trap_notify(struct devlink *devlink, + const struct devlink_trap_item *trap_item, + enum devlink_command cmd) +{ + struct sk_buff *msg; + int err; + + WARN_ON_ONCE(cmd != DEVLINK_CMD_TRAP_NEW && + cmd != DEVLINK_CMD_TRAP_DEL); + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_trap_fill(msg, devlink, trap_item, cmd, 0, 0, 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int +devlink_trap_register(struct devlink *devlink, + const struct devlink_trap *trap, void *priv) +{ + struct devlink_trap_item *trap_item; + int err; + + if (devlink_trap_item_lookup(devlink, trap->name)) + return -EEXIST; + + trap_item = kzalloc(sizeof(*trap_item), GFP_KERNEL); + if (!trap_item) + return -ENOMEM; + + trap_item->stats = netdev_alloc_pcpu_stats(struct devlink_stats); + if (!trap_item->stats) { + err = -ENOMEM; + goto err_stats_alloc; + } + + trap_item->trap = trap; + trap_item->action = trap->init_action; + trap_item->priv = priv; + + err = devlink_trap_item_group_link(devlink, trap_item); + if (err) + goto err_group_link; + + err = devlink->ops->trap_init(devlink, trap, trap_item); + if (err) + goto err_trap_init; + + list_add_tail(&trap_item->list, 
&devlink->trap_list); + devlink_trap_notify(devlink, trap_item, DEVLINK_CMD_TRAP_NEW); + + return 0; + +err_trap_init: +err_group_link: + free_percpu(trap_item->stats); +err_stats_alloc: + kfree(trap_item); + return err; +} + +static void devlink_trap_unregister(struct devlink *devlink, + const struct devlink_trap *trap) +{ + struct devlink_trap_item *trap_item; + + trap_item = devlink_trap_item_lookup(devlink, trap->name); + if (WARN_ON_ONCE(!trap_item)) + return; + + devlink_trap_notify(devlink, trap_item, DEVLINK_CMD_TRAP_DEL); + list_del(&trap_item->list); + if (devlink->ops->trap_fini) + devlink->ops->trap_fini(devlink, trap, trap_item); + free_percpu(trap_item->stats); + kfree(trap_item); +} + +static void devlink_trap_disable(struct devlink *devlink, + const struct devlink_trap *trap) +{ + struct devlink_trap_item *trap_item; + + trap_item = devlink_trap_item_lookup(devlink, trap->name); + if (WARN_ON_ONCE(!trap_item)) + return; + + devlink->ops->trap_action_set(devlink, trap, DEVLINK_TRAP_ACTION_DROP, + NULL); + trap_item->action = DEVLINK_TRAP_ACTION_DROP; +} + +/** + * devl_traps_register - Register packet traps with devlink. + * @devlink: devlink. + * @traps: Packet traps. + * @traps_count: Count of provided packet traps. + * @priv: Driver private information. + * + * Return: Non-zero value on failure. + */ +int devl_traps_register(struct devlink *devlink, + const struct devlink_trap *traps, + size_t traps_count, void *priv) +{ + int i, err; + + if (!devlink->ops->trap_init || !devlink->ops->trap_action_set) + return -EINVAL; + + devl_assert_locked(devlink); + for (i = 0; i < traps_count; i++) { + const struct devlink_trap *trap = &traps[i]; + + err = devlink_trap_verify(trap); + if (err) + goto err_trap_verify; + + err = devlink_trap_register(devlink, trap, priv); + if (err) + goto err_trap_register; + } + + return 0; + +err_trap_register: +err_trap_verify: + for (i--; i >= 0; i--) + devlink_trap_unregister(devlink, &traps[i]); + return err; +} +EXPORT_SYMBOL_GPL(devl_traps_register); + +/** + * devlink_traps_register - Register packet traps with devlink. + * @devlink: devlink. + * @traps: Packet traps. + * @traps_count: Count of provided packet traps. + * @priv: Driver private information. + * + * Context: Takes and release devlink->lock . + * + * Return: Non-zero value on failure. + */ +int devlink_traps_register(struct devlink *devlink, + const struct devlink_trap *traps, + size_t traps_count, void *priv) +{ + int err; + + devl_lock(devlink); + err = devl_traps_register(devlink, traps, traps_count, priv); + devl_unlock(devlink); + return err; +} +EXPORT_SYMBOL_GPL(devlink_traps_register); + +/** + * devl_traps_unregister - Unregister packet traps from devlink. + * @devlink: devlink. + * @traps: Packet traps. + * @traps_count: Count of provided packet traps. + */ +void devl_traps_unregister(struct devlink *devlink, + const struct devlink_trap *traps, + size_t traps_count) +{ + int i; + + devl_assert_locked(devlink); + /* Make sure we do not have any packets in-flight while unregistering + * traps by disabling all of them and waiting for a grace period. + */ + for (i = traps_count - 1; i >= 0; i--) + devlink_trap_disable(devlink, &traps[i]); + synchronize_rcu(); + for (i = traps_count - 1; i >= 0; i--) + devlink_trap_unregister(devlink, &traps[i]); +} +EXPORT_SYMBOL_GPL(devl_traps_unregister); + +/** + * devlink_traps_unregister - Unregister packet traps from devlink. + * @devlink: devlink. + * @traps: Packet traps. + * @traps_count: Count of provided packet traps. 
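To make the registration flow above concrete, here is a minimal, hypothetical driver-side sketch (not part of this patch): a generic trap group is registered first, then a driver-specific trap that points at it through init_group_id. The "example_*" identifiers are invented; include/net/devlink.h also provides DEVLINK_TRAP_GENERIC()/DEVLINK_TRAP_GROUP_GENERIC() convenience macros for building the same initializers.

#include <net/devlink.h>

/* Hypothetical driver-specific trap ID: driver IDs must be greater than
 * DEVLINK_TRAP_GENERIC_ID_MAX and must not reuse a generic trap name.
 */
enum {
	EXAMPLE_TRAP_ID_SAMPLE_DROP = DEVLINK_TRAP_GENERIC_ID_MAX + 1,
};

static const struct devlink_trap_group example_trap_groups[] = {
	{
		/* Generic group: id and name must match the generic tables */
		.name = DEVLINK_TRAP_GROUP_GENERIC_NAME_L2_DROPS,
		.id = DEVLINK_TRAP_GROUP_GENERIC_ID_L2_DROPS,
		.generic = true,
	},
};

static const struct devlink_trap example_traps[] = {
	{
		/* Driver-specific trap bound to the group registered above */
		.type = DEVLINK_TRAP_TYPE_DROP,
		.init_action = DEVLINK_TRAP_ACTION_DROP,
		.generic = false,
		.id = EXAMPLE_TRAP_ID_SAMPLE_DROP,
		.name = "example_sample_drop",
		.init_group_id = DEVLINK_TRAP_GROUP_GENERIC_ID_L2_DROPS,
	},
};

/* Called with the devlink instance lock held (devl_lock()); the driver
 * must implement the trap_init and trap_action_set ops, otherwise
 * devl_traps_register() returns -EINVAL.
 */
static int example_traps_init(struct devlink *devlink, void *priv)
{
	int err;

	err = devl_trap_groups_register(devlink, example_trap_groups,
					ARRAY_SIZE(example_trap_groups));
	if (err)
		return err;

	err = devl_traps_register(devlink, example_traps,
				  ARRAY_SIZE(example_traps), priv);
	if (err)
		devl_trap_groups_unregister(devlink, example_trap_groups,
					    ARRAY_SIZE(example_trap_groups));
	return err;
}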
+ * + * Context: Takes and release devlink->lock . + */ +void devlink_traps_unregister(struct devlink *devlink, + const struct devlink_trap *traps, + size_t traps_count) +{ + devl_lock(devlink); + devl_traps_unregister(devlink, traps, traps_count); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_traps_unregister); + +static void +devlink_trap_stats_update(struct devlink_stats __percpu *trap_stats, + size_t skb_len) +{ + struct devlink_stats *stats; + + stats = this_cpu_ptr(trap_stats); + u64_stats_update_begin(&stats->syncp); + u64_stats_add(&stats->rx_bytes, skb_len); + u64_stats_inc(&stats->rx_packets); + u64_stats_update_end(&stats->syncp); +} + +static void +devlink_trap_report_metadata_set(struct devlink_trap_metadata *metadata, + const struct devlink_trap_item *trap_item, + struct devlink_port *in_devlink_port, + const struct flow_action_cookie *fa_cookie) +{ + metadata->trap_name = trap_item->trap->name; + metadata->trap_group_name = trap_item->group_item->group->name; + metadata->fa_cookie = fa_cookie; + metadata->trap_type = trap_item->trap->type; + + spin_lock(&in_devlink_port->type_lock); + if (in_devlink_port->type == DEVLINK_PORT_TYPE_ETH) + metadata->input_dev = in_devlink_port->type_dev; + spin_unlock(&in_devlink_port->type_lock); +} + +/** + * devlink_trap_report - Report trapped packet to drop monitor. + * @devlink: devlink. + * @skb: Trapped packet. + * @trap_ctx: Trap context. + * @in_devlink_port: Input devlink port. + * @fa_cookie: Flow action cookie. Could be NULL. + */ +void devlink_trap_report(struct devlink *devlink, struct sk_buff *skb, + void *trap_ctx, struct devlink_port *in_devlink_port, + const struct flow_action_cookie *fa_cookie) + +{ + struct devlink_trap_item *trap_item = trap_ctx; + + devlink_trap_stats_update(trap_item->stats, skb->len); + devlink_trap_stats_update(trap_item->group_item->stats, skb->len); + + if (trace_devlink_trap_report_enabled()) { + struct devlink_trap_metadata metadata = {}; + + devlink_trap_report_metadata_set(&metadata, trap_item, + in_devlink_port, fa_cookie); + trace_devlink_trap_report(devlink, skb, &metadata); + } +} +EXPORT_SYMBOL_GPL(devlink_trap_report); + +/** + * devlink_trap_ctx_priv - Trap context to driver private information. + * @trap_ctx: Trap context. + * + * Return: Driver private information passed during registration. 
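As a rough sketch of the reporting side (an illustration, not part of this patch), a driver's RX path might hand a trapped packet to devlink_trap_report() using the trap context that devlink passed to its ->trap_init() op; all "example_*" names below are hypothetical.

#include <linux/skbuff.h>
#include <net/devlink.h>

static void example_rx_trapped(struct devlink *devlink, struct sk_buff *skb,
			       void *trap_ctx,
			       struct devlink_port *in_devlink_port)
{
	/* Driver-private pointer that was passed to devl_traps_register() */
	void *priv = devlink_trap_ctx_priv(trap_ctx);

	/* Updates the trap and group statistics and, when the tracepoint
	 * is enabled, feeds the packet metadata to the drop monitor.
	 */
	devlink_trap_report(devlink, skb, trap_ctx, in_devlink_port, NULL);

	/* A real driver would typically re-inject control packets into the
	 * kernel RX path here; this sketch simply frees the packet.
	 */
	consume_skb(skb);
	(void)priv;
}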
+ */ +void *devlink_trap_ctx_priv(void *trap_ctx) +{ + struct devlink_trap_item *trap_item = trap_ctx; + + return trap_item->priv; +} +EXPORT_SYMBOL_GPL(devlink_trap_ctx_priv); + +static int +devlink_trap_group_item_policer_link(struct devlink *devlink, + struct devlink_trap_group_item *group_item) +{ + u32 policer_id = group_item->group->init_policer_id; + struct devlink_trap_policer_item *policer_item; + + if (policer_id == 0) + return 0; + + policer_item = devlink_trap_policer_item_lookup(devlink, policer_id); + if (WARN_ON_ONCE(!policer_item)) + return -EINVAL; + + group_item->policer_item = policer_item; + + return 0; +} + +static int +devlink_trap_group_register(struct devlink *devlink, + const struct devlink_trap_group *group) +{ + struct devlink_trap_group_item *group_item; + int err; + + if (devlink_trap_group_item_lookup(devlink, group->name)) + return -EEXIST; + + group_item = kzalloc(sizeof(*group_item), GFP_KERNEL); + if (!group_item) + return -ENOMEM; + + group_item->stats = netdev_alloc_pcpu_stats(struct devlink_stats); + if (!group_item->stats) { + err = -ENOMEM; + goto err_stats_alloc; + } + + group_item->group = group; + + err = devlink_trap_group_item_policer_link(devlink, group_item); + if (err) + goto err_policer_link; + + if (devlink->ops->trap_group_init) { + err = devlink->ops->trap_group_init(devlink, group); + if (err) + goto err_group_init; + } + + list_add_tail(&group_item->list, &devlink->trap_group_list); + devlink_trap_group_notify(devlink, group_item, + DEVLINK_CMD_TRAP_GROUP_NEW); + + return 0; + +err_group_init: +err_policer_link: + free_percpu(group_item->stats); +err_stats_alloc: + kfree(group_item); + return err; +} + +static void +devlink_trap_group_unregister(struct devlink *devlink, + const struct devlink_trap_group *group) +{ + struct devlink_trap_group_item *group_item; + + group_item = devlink_trap_group_item_lookup(devlink, group->name); + if (WARN_ON_ONCE(!group_item)) + return; + + devlink_trap_group_notify(devlink, group_item, + DEVLINK_CMD_TRAP_GROUP_DEL); + list_del(&group_item->list); + free_percpu(group_item->stats); + kfree(group_item); +} + +/** + * devl_trap_groups_register - Register packet trap groups with devlink. + * @devlink: devlink. + * @groups: Packet trap groups. + * @groups_count: Count of provided packet trap groups. + * + * Return: Non-zero value on failure. + */ +int devl_trap_groups_register(struct devlink *devlink, + const struct devlink_trap_group *groups, + size_t groups_count) +{ + int i, err; + + devl_assert_locked(devlink); + for (i = 0; i < groups_count; i++) { + const struct devlink_trap_group *group = &groups[i]; + + err = devlink_trap_group_verify(group); + if (err) + goto err_trap_group_verify; + + err = devlink_trap_group_register(devlink, group); + if (err) + goto err_trap_group_register; + } + + return 0; + +err_trap_group_register: +err_trap_group_verify: + for (i--; i >= 0; i--) + devlink_trap_group_unregister(devlink, &groups[i]); + return err; +} +EXPORT_SYMBOL_GPL(devl_trap_groups_register); + +/** + * devlink_trap_groups_register - Register packet trap groups with devlink. + * @devlink: devlink. + * @groups: Packet trap groups. + * @groups_count: Count of provided packet trap groups. + * + * Context: Takes and release devlink->lock . + * + * Return: Non-zero value on failure. 
+ */ +int devlink_trap_groups_register(struct devlink *devlink, + const struct devlink_trap_group *groups, + size_t groups_count) +{ + int err; + + devl_lock(devlink); + err = devl_trap_groups_register(devlink, groups, groups_count); + devl_unlock(devlink); + return err; +} +EXPORT_SYMBOL_GPL(devlink_trap_groups_register); + +/** + * devl_trap_groups_unregister - Unregister packet trap groups from devlink. + * @devlink: devlink. + * @groups: Packet trap groups. + * @groups_count: Count of provided packet trap groups. + */ +void devl_trap_groups_unregister(struct devlink *devlink, + const struct devlink_trap_group *groups, + size_t groups_count) +{ + int i; + + devl_assert_locked(devlink); + for (i = groups_count - 1; i >= 0; i--) + devlink_trap_group_unregister(devlink, &groups[i]); +} +EXPORT_SYMBOL_GPL(devl_trap_groups_unregister); + +/** + * devlink_trap_groups_unregister - Unregister packet trap groups from devlink. + * @devlink: devlink. + * @groups: Packet trap groups. + * @groups_count: Count of provided packet trap groups. + * + * Context: Takes and release devlink->lock . + */ +void devlink_trap_groups_unregister(struct devlink *devlink, + const struct devlink_trap_group *groups, + size_t groups_count) +{ + devl_lock(devlink); + devl_trap_groups_unregister(devlink, groups, groups_count); + devl_unlock(devlink); +} +EXPORT_SYMBOL_GPL(devlink_trap_groups_unregister); + +static void +devlink_trap_policer_notify(struct devlink *devlink, + const struct devlink_trap_policer_item *policer_item, + enum devlink_command cmd) +{ + struct sk_buff *msg; + int err; + + WARN_ON_ONCE(cmd != DEVLINK_CMD_TRAP_POLICER_NEW && + cmd != DEVLINK_CMD_TRAP_POLICER_DEL); + if (!xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + err = devlink_nl_trap_policer_fill(msg, devlink, policer_item, cmd, 0, + 0, 0); + if (err) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&devlink_nl_family, devlink_net(devlink), + msg, 0, DEVLINK_MCGRP_CONFIG, GFP_KERNEL); +} + +static int +devlink_trap_policer_register(struct devlink *devlink, + const struct devlink_trap_policer *policer) +{ + struct devlink_trap_policer_item *policer_item; + int err; + + if (devlink_trap_policer_item_lookup(devlink, policer->id)) + return -EEXIST; + + policer_item = kzalloc(sizeof(*policer_item), GFP_KERNEL); + if (!policer_item) + return -ENOMEM; + + policer_item->policer = policer; + policer_item->rate = policer->init_rate; + policer_item->burst = policer->init_burst; + + if (devlink->ops->trap_policer_init) { + err = devlink->ops->trap_policer_init(devlink, policer); + if (err) + goto err_policer_init; + } + + list_add_tail(&policer_item->list, &devlink->trap_policer_list); + devlink_trap_policer_notify(devlink, policer_item, + DEVLINK_CMD_TRAP_POLICER_NEW); + + return 0; + +err_policer_init: + kfree(policer_item); + return err; +} + +static void +devlink_trap_policer_unregister(struct devlink *devlink, + const struct devlink_trap_policer *policer) +{ + struct devlink_trap_policer_item *policer_item; + + policer_item = devlink_trap_policer_item_lookup(devlink, policer->id); + if (WARN_ON_ONCE(!policer_item)) + return; + + devlink_trap_policer_notify(devlink, policer_item, + DEVLINK_CMD_TRAP_POLICER_DEL); + list_del(&policer_item->list); + if (devlink->ops->trap_policer_fini) + devlink->ops->trap_policer_fini(devlink, policer); + kfree(policer_item); +} + +/** + * devl_trap_policers_register - Register packet trap policers with 
devlink. + * @devlink: devlink. + * @policers: Packet trap policers. + * @policers_count: Count of provided packet trap policers. + * + * Return: Non-zero value on failure. + */ +int +devl_trap_policers_register(struct devlink *devlink, + const struct devlink_trap_policer *policers, + size_t policers_count) +{ + int i, err; + + devl_assert_locked(devlink); + for (i = 0; i < policers_count; i++) { + const struct devlink_trap_policer *policer = &policers[i]; + + if (WARN_ON(policer->id == 0 || + policer->max_rate < policer->min_rate || + policer->max_burst < policer->min_burst)) { + err = -EINVAL; + goto err_trap_policer_verify; + } + + err = devlink_trap_policer_register(devlink, policer); + if (err) + goto err_trap_policer_register; + } + return 0; + +err_trap_policer_register: +err_trap_policer_verify: + for (i--; i >= 0; i--) + devlink_trap_policer_unregister(devlink, &policers[i]); + return err; +} +EXPORT_SYMBOL_GPL(devl_trap_policers_register); + +/** + * devl_trap_policers_unregister - Unregister packet trap policers from devlink. + * @devlink: devlink. + * @policers: Packet trap policers. + * @policers_count: Count of provided packet trap policers. + */ +void +devl_trap_policers_unregister(struct devlink *devlink, + const struct devlink_trap_policer *policers, + size_t policers_count) +{ + int i; + + devl_assert_locked(devlink); + for (i = policers_count - 1; i >= 0; i--) + devlink_trap_policer_unregister(devlink, &policers[i]); +} +EXPORT_SYMBOL_GPL(devl_trap_policers_unregister); + +static void __devlink_compat_running_version(struct devlink *devlink, + char *buf, size_t len) +{ + struct devlink_info_req req = {}; + const struct nlattr *nlattr; + struct sk_buff *msg; + int rem, err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + req.msg = msg; + err = devlink->ops->info_get(devlink, &req, NULL); + if (err) + goto free_msg; + + nla_for_each_attr(nlattr, (void *)msg->data, msg->len, rem) { + const struct nlattr *kv; + int rem_kv; + + if (nla_type(nlattr) != DEVLINK_ATTR_INFO_VERSION_RUNNING) + continue; + + nla_for_each_nested(kv, nlattr, rem_kv) { + if (nla_type(kv) != DEVLINK_ATTR_INFO_VERSION_VALUE) + continue; + + strlcat(buf, nla_data(kv), len); + strlcat(buf, " ", len); + } + } +free_msg: + nlmsg_free(msg); +} + +static struct devlink_port *netdev_to_devlink_port(struct net_device *dev) +{ + if (!dev->netdev_ops->ndo_get_devlink_port) + return NULL; + + return dev->netdev_ops->ndo_get_devlink_port(dev); +} + +void devlink_compat_running_version(struct devlink *devlink, + char *buf, size_t len) +{ + if (!devlink->ops->info_get) + return; + + devl_lock(devlink); + __devlink_compat_running_version(devlink, buf, len); + devl_unlock(devlink); +} + +int devlink_compat_flash_update(struct devlink *devlink, const char *file_name) +{ + struct devlink_flash_update_params params = {}; + int ret; + + if (!devlink->ops->flash_update) + return -EOPNOTSUPP; + + ret = request_firmware(&params.fw, file_name, devlink->dev); + if (ret) + return ret; + + devl_lock(devlink); + devlink_flash_update_begin_notify(devlink); + ret = devlink->ops->flash_update(devlink, &params, NULL); + devlink_flash_update_end_notify(devlink); + devl_unlock(devlink); + + release_firmware(params.fw); + + return ret; +} + +int devlink_compat_phys_port_name_get(struct net_device *dev, + char *name, size_t len) +{ + struct devlink_port *devlink_port; + + /* RTNL mutex is held here which ensures that devlink_port + * instance cannot disappear in the middle.
No need to take + * any devlink lock as only permanent values are accessed. + */ + ASSERT_RTNL(); + + devlink_port = netdev_to_devlink_port(dev); + if (!devlink_port) + return -EOPNOTSUPP; + + return __devlink_port_phys_port_name_get(devlink_port, name, len); +} + +int devlink_compat_switch_id_get(struct net_device *dev, + struct netdev_phys_item_id *ppid) +{ + struct devlink_port *devlink_port; + + /* Caller must hold RTNL mutex or reference to dev, which ensures that + * devlink_port instance cannot disappear in the middle. No need to take + * any devlink lock as only permanent values are accessed. + */ + devlink_port = netdev_to_devlink_port(dev); + if (!devlink_port || !devlink_port->switch_port) + return -EOPNOTSUPP; + + memcpy(ppid, &devlink_port->attrs.switch_id, sizeof(*ppid)); + + return 0; +} + +static void __net_exit devlink_pernet_pre_exit(struct net *net) +{ + struct devlink *devlink; + u32 actions_performed; + unsigned long index; + int err; + + /* In case network namespace is getting destroyed, reload + * all devlink instances from this namespace into init_net. + */ + devlinks_xa_for_each_registered_get(net, index, devlink) { + WARN_ON(!(devlink->features & DEVLINK_F_RELOAD)); + mutex_lock(&devlink->lock); + err = devlink_reload(devlink, &init_net, + DEVLINK_RELOAD_ACTION_DRIVER_REINIT, + DEVLINK_RELOAD_LIMIT_UNSPEC, + &actions_performed, NULL); + mutex_unlock(&devlink->lock); + if (err && err != -EOPNOTSUPP) + pr_warn("Failed to reload devlink instance into init_net\n"); + devlink_put(devlink); + } +} + +static struct pernet_operations devlink_pernet_ops __net_initdata = { + .pre_exit = devlink_pernet_pre_exit, +}; + +static int __init devlink_init(void) +{ + int err; + + err = genl_register_family(&devlink_nl_family); + if (err) + goto out; + err = register_pernet_subsys(&devlink_pernet_ops); + +out: + WARN_ON(err); + return err; +} + +subsys_initcall(devlink_init); diff --git a/net/devres.c b/net/devres.c new file mode 100644 index 000000000..5ccf6ca31 --- /dev/null +++ b/net/devres.c @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * This file contains all networking devres helpers. + */ + +#include +#include +#include + +struct net_device_devres { + struct net_device *ndev; +}; + +static void devm_free_netdev(struct device *dev, void *this) +{ + struct net_device_devres *res = this; + + free_netdev(res->ndev); +} + +struct net_device *devm_alloc_etherdev_mqs(struct device *dev, int sizeof_priv, + unsigned int txqs, unsigned int rxqs) +{ + struct net_device_devres *dr; + + dr = devres_alloc(devm_free_netdev, sizeof(*dr), GFP_KERNEL); + if (!dr) + return NULL; + + dr->ndev = alloc_etherdev_mqs(sizeof_priv, txqs, rxqs); + if (!dr->ndev) { + devres_free(dr); + return NULL; + } + + devres_add(dev, dr); + + return dr->ndev; +} +EXPORT_SYMBOL(devm_alloc_etherdev_mqs); + +static void devm_unregister_netdev(struct device *dev, void *this) +{ + struct net_device_devres *res = this; + + unregister_netdev(res->ndev); +} + +static int netdev_devres_match(struct device *dev, void *this, void *match_data) +{ + struct net_device_devres *res = this; + struct net_device *ndev = match_data; + + return ndev == res->ndev; +} + +/** + * devm_register_netdev - resource managed variant of register_netdev() + * @dev: managing device for this netdev - usually the parent device + * @ndev: device to register + * + * This is a devres variant of register_netdev() for which the unregister + * function will be called automatically when the managing device is + * detached. 
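As a concrete illustration (a sketch under assumptions, not part of this patch), a platform driver's probe might pair the two devres helpers as follows; the driver name and private struct are invented.

#include <linux/etherdevice.h>
#include <linux/platform_device.h>

struct example_priv {
	void *regs;	/* hypothetical private data */
};

static int example_probe(struct platform_device *pdev)
{
	struct net_device *ndev;

	/* free_netdev() is called automatically when &pdev->dev detaches */
	ndev = devm_alloc_etherdev_mqs(&pdev->dev, sizeof(struct example_priv),
				       1, 1);
	if (!ndev)
		return -ENOMEM;

	/* ... set up ndev->netdev_ops, the MAC address, priv data, ... */

	/* unregister_netdev() is likewise called automatically on detach,
	 * so no explicit cleanup is needed in the remove path.
	 */
	return devm_register_netdev(&pdev->dev, ndev);
}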
Note: the net_device used must also be resource managed by + * the same struct device. + */ +int devm_register_netdev(struct device *dev, struct net_device *ndev) +{ + struct net_device_devres *dr; + int ret; + + /* struct net_device must itself be managed. For now a managed netdev + * can only be allocated by devm_alloc_etherdev_mqs() so the check is + * straightforward. + */ + if (WARN_ON(!devres_find(dev, devm_free_netdev, + netdev_devres_match, ndev))) + return -EINVAL; + + dr = devres_alloc(devm_unregister_netdev, sizeof(*dr), GFP_KERNEL); + if (!dr) + return -ENOMEM; + + ret = register_netdev(ndev); + if (ret) { + devres_free(dr); + return ret; + } + + dr->ndev = ndev; + devres_add(ndev->dev.parent, dr); + + return 0; +} +EXPORT_SYMBOL(devm_register_netdev); diff --git a/net/dns_resolver/Kconfig b/net/dns_resolver/Kconfig new file mode 100644 index 000000000..155b06163 --- /dev/null +++ b/net/dns_resolver/Kconfig @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Configuration for DNS Resolver +# +config DNS_RESOLVER + tristate "DNS Resolver support" + depends on KEYS + help + Saying Y here will include support for the DNS Resolver key type + which can be used to make upcalls to perform DNS lookups in + userspace. + + DNS Resolver is used to query DNS server for information. Examples + being resolving a UNC hostname element to an IP address for CIFS or + performing a DNS query for AFSDB records so that AFS can locate a + cell's volume location database servers. + + DNS Resolver is used by the CIFS and AFS modules, and would support + SMB2 later. DNS Resolver is supported by the userspace upcall + helper "/sbin/dns.resolver" via /etc/request-key.conf. + + See for further + information. + + To compile this as a module, choose M here: the module will be called + dnsresolver. + + If unsure, say N. diff --git a/net/dns_resolver/Makefile b/net/dns_resolver/Makefile new file mode 100644 index 000000000..877532d66 --- /dev/null +++ b/net/dns_resolver/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux DNS Resolver. +# + +obj-$(CONFIG_DNS_RESOLVER) += dns_resolver.o + +dns_resolver-y := dns_key.o dns_query.o diff --git a/net/dns_resolver/dns_key.c b/net/dns_resolver/dns_key.c new file mode 100644 index 000000000..26a9d8434 --- /dev/null +++ b/net/dns_resolver/dns_key.c @@ -0,0 +1,391 @@ +/* Key type used to cache DNS lookups made by the kernel + * + * See Documentation/networking/dns_resolver.rst + * + * Copyright (c) 2007 Igor Mammedov + * Author(s): Igor Mammedov (niallain@gmail.com) + * Steve French (sfrench@us.ibm.com) + * Wang Lei (wang840925@gmail.com) + * David Howells (dhowells@redhat.com) + * + * This library is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; either version 2.1 of the License, or + * (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See + * the GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this library; if not, see . 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "internal.h" + +MODULE_DESCRIPTION("DNS Resolver"); +MODULE_AUTHOR("Wang Lei"); +MODULE_LICENSE("GPL"); + +unsigned int dns_resolver_debug; +module_param_named(debug, dns_resolver_debug, uint, 0644); +MODULE_PARM_DESC(debug, "DNS Resolver debugging mask"); + +const struct cred *dns_resolver_cache; + +#define DNS_ERRORNO_OPTION "dnserror" + +/* + * Preparse instantiation data for a dns_resolver key. + * + * For normal hostname lookups, the data must be a NUL-terminated string, with + * the NUL char accounted in datalen. + * + * If the data contains a '#' characters, then we take the clause after each + * one to be an option of the form 'key=value'. The actual data of interest is + * the string leading up to the first '#'. For instance: + * + * "ip1,ip2,...#foo=bar" + * + * For server list requests, the data must begin with a NUL char and be + * followed by a byte indicating the version of the data format. Version 1 + * looks something like (note this is packed): + * + * u8 Non-string marker (ie. 0) + * u8 Content (DNS_PAYLOAD_IS_*) + * u8 Version (e.g. 1) + * u8 Source of server list + * u8 Lookup status of server list + * u8 Number of servers + * foreach-server { + * __le16 Name length + * __le16 Priority (as per SRV record, low first) + * __le16 Weight (as per SRV record, higher first) + * __le16 Port + * u8 Source of address list + * u8 Lookup status of address list + * u8 Protocol (DNS_SERVER_PROTOCOL_*) + * u8 Number of addresses + * char[] Name (not NUL-terminated) + * foreach-address { + * u8 Family (DNS_ADDRESS_IS_*) + * union { + * u8[4] ipv4_addr + * u8[16] ipv6_addr + * } + * } + * } + * + */ +static int +dns_resolver_preparse(struct key_preparsed_payload *prep) +{ + struct user_key_payload *upayload; + unsigned long derrno; + int ret; + int datalen = prep->datalen, result_len = 0; + const char *data = prep->data, *end, *opt; + + if (datalen <= 1 || !data) + return -EINVAL; + + if (data[0] == 0) { + const struct dns_server_list_v1_header *v1; + + /* It may be a server list. 
*/ + if (datalen < sizeof(*v1)) + return -EINVAL; + + v1 = (const struct dns_server_list_v1_header *)data; + kenter("[%u,%u],%u", v1->hdr.content, v1->hdr.version, datalen); + if (v1->hdr.content != DNS_PAYLOAD_IS_SERVER_LIST) { + pr_warn_ratelimited( + "dns_resolver: Unsupported content type (%u)\n", + v1->hdr.content); + return -EINVAL; + } + + if (v1->hdr.version != 1) { + pr_warn_ratelimited( + "dns_resolver: Unsupported server list version (%u)\n", + v1->hdr.version); + return -EINVAL; + } + + if ((v1->status != DNS_LOOKUP_GOOD && + v1->status != DNS_LOOKUP_GOOD_WITH_BAD)) { + if (prep->expiry == TIME64_MAX) + prep->expiry = ktime_get_real_seconds() + 1; + } + + result_len = datalen; + goto store_result; + } + + kenter("'%*.*s',%u", datalen, datalen, data, datalen); + + if (!data || data[datalen - 1] != '\0') + return -EINVAL; + datalen--; + + /* deal with any options embedded in the data */ + end = data + datalen; + opt = memchr(data, '#', datalen); + if (!opt) { + /* no options: the entire data is the result */ + kdebug("no options"); + result_len = datalen; + } else { + const char *next_opt; + + result_len = opt - data; + opt++; + kdebug("options: '%s'", opt); + do { + int opt_len, opt_nlen; + const char *eq; + char optval[128]; + + next_opt = memchr(opt, '#', end - opt) ?: end; + opt_len = next_opt - opt; + if (opt_len <= 0 || opt_len > sizeof(optval)) { + pr_warn_ratelimited("Invalid option length (%d) for dns_resolver key\n", + opt_len); + return -EINVAL; + } + + eq = memchr(opt, '=', opt_len); + if (eq) { + opt_nlen = eq - opt; + eq++; + memcpy(optval, eq, next_opt - eq); + optval[next_opt - eq] = '\0'; + } else { + opt_nlen = opt_len; + optval[0] = '\0'; + } + + kdebug("option '%*.*s' val '%s'", + opt_nlen, opt_nlen, opt, optval); + + /* see if it's an error number representing a DNS error + * that's to be recorded as the result in this key */ + if (opt_nlen == sizeof(DNS_ERRORNO_OPTION) - 1 && + memcmp(opt, DNS_ERRORNO_OPTION, opt_nlen) == 0) { + kdebug("dns error number option"); + + ret = kstrtoul(optval, 10, &derrno); + if (ret < 0) + goto bad_option_value; + + if (derrno < 1 || derrno > 511) + goto bad_option_value; + + kdebug("dns error no. = %lu", derrno); + prep->payload.data[dns_key_error] = ERR_PTR(-derrno); + continue; + } + + bad_option_value: + pr_warn_ratelimited("Option '%*.*s' to dns_resolver key: bad/missing value\n", + opt_nlen, opt_nlen, opt); + return -EINVAL; + } while (opt = next_opt + 1, opt < end); + } + + /* don't cache the result if we're caching an error saying there's no + * result */ + if (prep->payload.data[dns_key_error]) { + kleave(" = 0 [h_error %ld]", PTR_ERR(prep->payload.data[dns_key_error])); + return 0; + } + +store_result: + kdebug("store result"); + prep->quotalen = result_len; + + upayload = kmalloc(sizeof(*upayload) + result_len + 1, GFP_KERNEL); + if (!upayload) { + kleave(" = -ENOMEM"); + return -ENOMEM; + } + + upayload->datalen = result_len; + memcpy(upayload->data, data, result_len); + upayload->data[result_len] = '\0'; + + prep->payload.data[dns_key_data] = upayload; + kleave(" = 0"); + return 0; +} + +/* + * Clean up the preparse data + */ +static void dns_resolver_free_preparse(struct key_preparsed_payload *prep) +{ + pr_devel("==>%s()\n", __func__); + + kfree(prep->payload.data[dns_key_data]); +} + +/* + * The description is of the form "[:]" + * + * The domain name may be a simple name or an absolute domain name (which + * should end with a period). The domain name is case-independent. 
+ */ +static bool dns_resolver_cmp(const struct key *key, + const struct key_match_data *match_data) +{ + int slen, dlen, ret = 0; + const char *src = key->description, *dsp = match_data->raw_data; + + kenter("%s,%s", src, dsp); + + if (!src || !dsp) + goto no_match; + + if (strcasecmp(src, dsp) == 0) + goto matched; + + slen = strlen(src); + dlen = strlen(dsp); + if (slen <= 0 || dlen <= 0) + goto no_match; + if (src[slen - 1] == '.') + slen--; + if (dsp[dlen - 1] == '.') + dlen--; + if (slen != dlen || strncasecmp(src, dsp, slen) != 0) + goto no_match; + +matched: + ret = 1; +no_match: + kleave(" = %d", ret); + return ret; +} + +/* + * Preparse the match criterion. + */ +static int dns_resolver_match_preparse(struct key_match_data *match_data) +{ + match_data->lookup_type = KEYRING_SEARCH_LOOKUP_ITERATE; + match_data->cmp = dns_resolver_cmp; + return 0; +} + +/* + * Describe a DNS key + */ +static void dns_resolver_describe(const struct key *key, struct seq_file *m) +{ + seq_puts(m, key->description); + if (key_is_positive(key)) { + int err = PTR_ERR(key->payload.data[dns_key_error]); + + if (err) + seq_printf(m, ": %d", err); + else + seq_printf(m, ": %u", key->datalen); + } +} + +/* + * read the DNS data + * - the key's semaphore is read-locked + */ +static long dns_resolver_read(const struct key *key, + char *buffer, size_t buflen) +{ + int err = PTR_ERR(key->payload.data[dns_key_error]); + + if (err) + return err; + + return user_read(key, buffer, buflen); +} + +struct key_type key_type_dns_resolver = { + .name = "dns_resolver", + .flags = KEY_TYPE_NET_DOMAIN | KEY_TYPE_INSTANT_REAP, + .preparse = dns_resolver_preparse, + .free_preparse = dns_resolver_free_preparse, + .instantiate = generic_key_instantiate, + .match_preparse = dns_resolver_match_preparse, + .revoke = user_revoke, + .destroy = user_destroy, + .describe = dns_resolver_describe, + .read = dns_resolver_read, +}; + +static int __init init_dns_resolver(void) +{ + struct cred *cred; + struct key *keyring; + int ret; + + /* create an override credential set with a special thread keyring in + * which DNS requests are cached + * + * this is used to prevent malicious redirections from being installed + * with add_key(). 
+ */ + cred = prepare_kernel_cred(NULL); + if (!cred) + return -ENOMEM; + + keyring = keyring_alloc(".dns_resolver", + GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, cred, + (KEY_POS_ALL & ~KEY_POS_SETATTR) | + KEY_USR_VIEW | KEY_USR_READ, + KEY_ALLOC_NOT_IN_QUOTA, NULL, NULL); + if (IS_ERR(keyring)) { + ret = PTR_ERR(keyring); + goto failed_put_cred; + } + + ret = register_key_type(&key_type_dns_resolver); + if (ret < 0) + goto failed_put_key; + + /* instruct request_key() to use this special keyring as a cache for + * the results it looks up */ + set_bit(KEY_FLAG_ROOT_CAN_CLEAR, &keyring->flags); + cred->thread_keyring = keyring; + cred->jit_keyring = KEY_REQKEY_DEFL_THREAD_KEYRING; + dns_resolver_cache = cred; + + kdebug("DNS resolver keyring: %d\n", key_serial(keyring)); + return 0; + +failed_put_key: + key_put(keyring); +failed_put_cred: + put_cred(cred); + return ret; +} + +static void __exit exit_dns_resolver(void) +{ + key_revoke(dns_resolver_cache->thread_keyring); + unregister_key_type(&key_type_dns_resolver); + put_cred(dns_resolver_cache); +} + +module_init(init_dns_resolver) +module_exit(exit_dns_resolver) +MODULE_LICENSE("GPL"); diff --git a/net/dns_resolver/dns_query.c b/net/dns_resolver/dns_query.c new file mode 100644 index 000000000..82b084cc1 --- /dev/null +++ b/net/dns_resolver/dns_query.c @@ -0,0 +1,172 @@ +/* Upcall routine, designed to work as a key type and working through + * /sbin/request-key to contact userspace when handling DNS queries. + * + * See Documentation/networking/dns_resolver.rst + * + * Copyright (c) 2007 Igor Mammedov + * Author(s): Igor Mammedov (niallain@gmail.com) + * Steve French (sfrench@us.ibm.com) + * Wang Lei (wang840925@gmail.com) + * David Howells (dhowells@redhat.com) + * + * The upcall wrapper used to make an arbitrary DNS query. + * + * This function requires the appropriate userspace tool dns.upcall to be + * installed and something like the following lines should be added to the + * /etc/request-key.conf file: + * + * create dns_resolver * * /sbin/dns.upcall %k + * + * For example to use this module to query AFSDB RR: + * + * create dns_resolver afsdb:* * /sbin/dns.afsdb %k + * + * This library is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; either version 2.1 of the License, or + * (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See + * the GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this library; if not, see . + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "internal.h" + +/** + * dns_query - Query the DNS + * @net: The network namespace to operate in. + * @type: Query type (or NULL for straight host->IP lookup) + * @name: Name to look up + * @namelen: Length of name + * @options: Request options (or NULL if no options) + * @_result: Where to place the returned data (or NULL) + * @_expiry: Where to store the result expiry time (or NULL) + * @invalidate: Always invalidate the key after use + * + * The data will be returned in the pointer at *result, if provided, and the + * caller is responsible for freeing it. 
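To illustrate the calling convention documented here (a hedged sketch, not part of this patch), a kernel-side user could issue an AFSDB query roughly as follows; the cell-lookup wrapper is invented, while the "afsdb" type mirrors the example given in the header comment above.

#include <linux/dns_resolver.h>
#include <linux/printk.h>
#include <linux/slab.h>
#include <linux/string.h>

/* Resolve the AFSDB record for a cell name and log the raw result. */
static int example_lookup_afsdb(struct net *net, const char *cell)
{
	char *result = NULL;
	time64_t expiry;
	int len;

	len = dns_query(net, "afsdb", cell, strlen(cell), NULL,
			&result, &expiry, false);
	if (len < 0)
		return len;	/* negative errno from the upcall */

	pr_info("afsdb %s -> %*.*s (expires %lld)\n",
		cell, len, len, result, (long long)expiry);

	kfree(result);		/* the caller owns the returned buffer */
	return 0;
}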
+ * + * The description should be of the form "[:]", and + * the options need to be appropriate for the query type requested. If no + * query_type is given, then the query is a straight hostname to IP address + * lookup. + * + * The DNS resolution lookup is performed by upcalling to userspace by way of + * requesting a key of type dns_resolver. + * + * Returns the size of the result on success, -ve error code otherwise. + */ +int dns_query(struct net *net, + const char *type, const char *name, size_t namelen, + const char *options, char **_result, time64_t *_expiry, + bool invalidate) +{ + struct key *rkey; + struct user_key_payload *upayload; + const struct cred *saved_cred; + size_t typelen, desclen; + char *desc, *cp; + int ret, len; + + kenter("%s,%*.*s,%zu,%s", + type, (int)namelen, (int)namelen, name, namelen, options); + + if (!name || namelen == 0) + return -EINVAL; + + /* construct the query key description as "[:]" */ + typelen = 0; + desclen = 0; + if (type) { + typelen = strlen(type); + if (typelen < 1) + return -EINVAL; + desclen += typelen + 1; + } + + if (namelen < 3 || namelen > 255) + return -EINVAL; + desclen += namelen + 1; + + desc = kmalloc(desclen, GFP_KERNEL); + if (!desc) + return -ENOMEM; + + cp = desc; + if (type) { + memcpy(cp, type, typelen); + cp += typelen; + *cp++ = ':'; + } + memcpy(cp, name, namelen); + cp += namelen; + *cp = '\0'; + + if (!options) + options = ""; + kdebug("call request_key(,%s,%s)", desc, options); + + /* make the upcall, using special credentials to prevent the use of + * add_key() to preinstall malicious redirections + */ + saved_cred = override_creds(dns_resolver_cache); + rkey = request_key_net(&key_type_dns_resolver, desc, net, options); + revert_creds(saved_cred); + kfree(desc); + if (IS_ERR(rkey)) { + ret = PTR_ERR(rkey); + goto out; + } + + down_read(&rkey->sem); + set_bit(KEY_FLAG_ROOT_CAN_INVAL, &rkey->flags); + rkey->perm |= KEY_USR_VIEW; + + ret = key_validate(rkey); + if (ret < 0) + goto put; + + /* If the DNS server gave an error, return that to the caller */ + ret = PTR_ERR(rkey->payload.data[dns_key_error]); + if (ret) + goto put; + + upayload = user_key_payload_locked(rkey); + len = upayload->datalen; + + if (_result) { + ret = -ENOMEM; + *_result = kmemdup_nul(upayload->data, len, GFP_KERNEL); + if (!*_result) + goto put; + } + + if (_expiry) + *_expiry = rkey->expiry; + + ret = len; +put: + up_read(&rkey->sem); + if (invalidate) + key_invalidate(rkey); + key_put(rkey); +out: + kleave(" = %d", ret); + return ret; +} +EXPORT_SYMBOL(dns_query); diff --git a/net/dns_resolver/internal.h b/net/dns_resolver/internal.h new file mode 100644 index 000000000..0c570d40e --- /dev/null +++ b/net/dns_resolver/internal.h @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2010 Wang Lei + * Author(s): Wang Lei (wang840925@gmail.com). All Rights Reserved. + * + * Internal DNS Rsolver stuff + * + * This library is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; either version 2.1 of the License, or + * (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See + * the GNU Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this library; if not, see . 
+ */ + +#include +#include +#include + +/* + * Layout of key payload words. + */ +enum { + dns_key_data, + dns_key_error, +}; + +/* + * dns_key.c + */ +extern const struct cred *dns_resolver_cache; + +/* + * debug tracing + */ +extern unsigned int dns_resolver_debug; + +#define kdebug(FMT, ...) \ +do { \ + if (unlikely(dns_resolver_debug)) \ + printk(KERN_DEBUG "[%-6.6s] "FMT"\n", \ + current->comm, ##__VA_ARGS__); \ +} while (0) + +#define kenter(FMT, ...) kdebug("==> %s("FMT")", __func__, ##__VA_ARGS__) +#define kleave(FMT, ...) kdebug("<== %s()"FMT"", __func__, ##__VA_ARGS__) diff --git a/net/dsa/Kconfig b/net/dsa/Kconfig new file mode 100644 index 000000000..3eef72ce9 --- /dev/null +++ b/net/dsa/Kconfig @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: GPL-2.0-only + +menuconfig NET_DSA + tristate "Distributed Switch Architecture" + depends on BRIDGE || BRIDGE=n + depends on HSR || HSR=n + depends on INET && NETDEVICES + select GRO_CELLS + select NET_SWITCHDEV + select PHYLINK + select NET_DEVLINK + imply NET_SELFTESTS + help + Say Y if you want to enable support for the hardware switches supported + by the Distributed Switch Architecture. + +if NET_DSA + +# Drivers must select the appropriate tagging format(s) + +config NET_DSA_TAG_AR9331 + tristate "Tag driver for Atheros AR9331 SoC with built-in switch" + help + Say Y or M if you want to enable support for tagging frames for + the Atheros AR9331 SoC with built-in switch. + +config NET_DSA_TAG_BRCM_COMMON + tristate + default n + +config NET_DSA_TAG_BRCM + tristate "Tag driver for Broadcom switches using in-frame headers" + select NET_DSA_TAG_BRCM_COMMON + help + Say Y if you want to enable support for tagging frames for the + Broadcom switches which place the tag after the MAC source address. + +config NET_DSA_TAG_BRCM_LEGACY + tristate "Tag driver for Broadcom legacy switches using in-frame headers" + select NET_DSA_TAG_BRCM_COMMON + help + Say Y if you want to enable support for tagging frames for the + Broadcom legacy switches which place the tag after the MAC source + address. + +config NET_DSA_TAG_BRCM_PREPEND + tristate "Tag driver for Broadcom switches using prepended headers" + select NET_DSA_TAG_BRCM_COMMON + help + Say Y if you want to enable support for tagging frames for the + Broadcom switches which places the tag before the Ethernet header + (prepended). + +config NET_DSA_TAG_HELLCREEK + tristate "Tag driver for Hirschmann Hellcreek TSN switches" + help + Say Y or M if you want to enable support for tagging frames + for the Hirschmann Hellcreek TSN switches. + +config NET_DSA_TAG_GSWIP + tristate "Tag driver for Lantiq / Intel GSWIP switches" + help + Say Y or M if you want to enable support for tagging frames for the + Lantiq / Intel GSWIP switches. + +config NET_DSA_TAG_DSA_COMMON + tristate + +config NET_DSA_TAG_DSA + tristate "Tag driver for Marvell switches using DSA headers" + select NET_DSA_TAG_DSA_COMMON + help + Say Y or M if you want to enable support for tagging frames for the + Marvell switches which use DSA headers. + +config NET_DSA_TAG_EDSA + tristate "Tag driver for Marvell switches using EtherType DSA headers" + select NET_DSA_TAG_DSA_COMMON + help + Say Y or M if you want to enable support for tagging frames for the + Marvell switches which use EtherType DSA headers. + +config NET_DSA_TAG_MTK + tristate "Tag driver for Mediatek switches" + help + Say Y or M if you want to enable support for tagging frames for + Mediatek switches. 
+ +config NET_DSA_TAG_KSZ + tristate "Tag driver for Microchip 8795/937x/9477/9893 families of switches" + help + Say Y if you want to enable support for tagging frames for the + Microchip 8795/937x/9477/9893 families of switches. + +config NET_DSA_TAG_OCELOT + tristate "Tag driver for Ocelot family of switches, using NPI port" + select PACKING + help + Say Y or M if you want to enable NPI tagging for the Ocelot switches + (VSC7511, VSC7512, VSC7513, VSC7514, VSC9953, VSC9959). In this mode, + the frames over the Ethernet CPU port are prepended with a + hardware-defined injection/extraction frame header. Flow control + (PAUSE frames) over the CPU port is not supported when operating in + this mode. + +config NET_DSA_TAG_OCELOT_8021Q + tristate "Tag driver for Ocelot family of switches, using VLAN" + help + Say Y or M if you want to enable support for tagging frames with a + custom VLAN-based header. Frames that require timestamping, such as + PTP, are not delivered over Ethernet but over register-based MMIO. + Flow control over the CPU port is functional in this mode. When using + this mode, less TCAM resources (VCAP IS1, IS2, ES0) are available for + use with tc-flower. + +config NET_DSA_TAG_QCA + tristate "Tag driver for Qualcomm Atheros QCA8K switches" + help + Say Y or M if you want to enable support for tagging frames for + the Qualcomm Atheros QCA8K switches. + +config NET_DSA_TAG_RTL4_A + tristate "Tag driver for Realtek 4 byte protocol A tags" + help + Say Y or M if you want to enable support for tagging frames for the + Realtek switches with 4 byte protocol A tags, such as found in + the Realtek RTL8366RB. + +config NET_DSA_TAG_RTL8_4 + tristate "Tag driver for Realtek 8 byte protocol 4 tags" + help + Say Y or M if you want to enable support for tagging frames for Realtek + switches with 8 byte protocol 4 tags, such as the Realtek RTL8365MB-VC. + +config NET_DSA_TAG_RZN1_A5PSW + tristate "Tag driver for Renesas RZ/N1 A5PSW switch" + help + Say Y or M if you want to enable support for tagging frames for + Renesas RZ/N1 embedded switch that uses an 8 byte tag located after + destination MAC address. + +config NET_DSA_TAG_LAN9303 + tristate "Tag driver for SMSC/Microchip LAN9303 family of switches" + help + Say Y or M if you want to enable support for tagging frames for the + SMSC/Microchip LAN9303 family of switches. + +config NET_DSA_TAG_SJA1105 + tristate "Tag driver for NXP SJA1105 switches" + select PACKING + help + Say Y or M if you want to enable support for tagging frames with the + NXP SJA1105 switch family. Both the native tagging protocol (which + is only for link-local traffic) as well as non-native tagging (based + on a custom 802.1Q VLAN header) are available. + +config NET_DSA_TAG_TRAILER + tristate "Tag driver for switches using a trailer tag" + help + Say Y or M if you want to enable support for tagging frames + with a trailer tag, e.g. Marvell 88E6060. + +config NET_DSA_TAG_XRS700X + tristate "Tag driver for XRS700x switches" + help + Say Y or M if you want to enable support for tagging frames for + Arrow SpeedChips XRS700x switches that use a single byte tag trailer.
+ +endif diff --git a/net/dsa/Makefile b/net/dsa/Makefile new file mode 100644 index 000000000..bf57ef3bc --- /dev/null +++ b/net/dsa/Makefile @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: GPL-2.0 +# the core +obj-$(CONFIG_NET_DSA) += dsa_core.o +dsa_core-y += \ + dsa.o \ + dsa2.o \ + master.o \ + netlink.o \ + port.o \ + slave.o \ + switch.o \ + tag_8021q.o + +# tagging formats +obj-$(CONFIG_NET_DSA_TAG_AR9331) += tag_ar9331.o +obj-$(CONFIG_NET_DSA_TAG_BRCM_COMMON) += tag_brcm.o +obj-$(CONFIG_NET_DSA_TAG_DSA_COMMON) += tag_dsa.o +obj-$(CONFIG_NET_DSA_TAG_GSWIP) += tag_gswip.o +obj-$(CONFIG_NET_DSA_TAG_HELLCREEK) += tag_hellcreek.o +obj-$(CONFIG_NET_DSA_TAG_KSZ) += tag_ksz.o +obj-$(CONFIG_NET_DSA_TAG_LAN9303) += tag_lan9303.o +obj-$(CONFIG_NET_DSA_TAG_MTK) += tag_mtk.o +obj-$(CONFIG_NET_DSA_TAG_OCELOT) += tag_ocelot.o +obj-$(CONFIG_NET_DSA_TAG_OCELOT_8021Q) += tag_ocelot_8021q.o +obj-$(CONFIG_NET_DSA_TAG_QCA) += tag_qca.o +obj-$(CONFIG_NET_DSA_TAG_RTL4_A) += tag_rtl4_a.o +obj-$(CONFIG_NET_DSA_TAG_RTL8_4) += tag_rtl8_4.o +obj-$(CONFIG_NET_DSA_TAG_RZN1_A5PSW) += tag_rzn1_a5psw.o +obj-$(CONFIG_NET_DSA_TAG_SJA1105) += tag_sja1105.o +obj-$(CONFIG_NET_DSA_TAG_TRAILER) += tag_trailer.o +obj-$(CONFIG_NET_DSA_TAG_XRS700X) += tag_xrs700x.o diff --git a/net/dsa/dsa.c b/net/dsa/dsa.c new file mode 100644 index 000000000..64b14f655 --- /dev/null +++ b/net/dsa/dsa.c @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dsa/dsa.c - Hardware switch handling + * Copyright (c) 2008-2009 Marvell Semiconductor + * Copyright (c) 2013 Florian Fainelli + */ + +#include +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +static LIST_HEAD(dsa_tag_drivers_list); +static DEFINE_MUTEX(dsa_tag_drivers_lock); + +static struct sk_buff *dsa_slave_notag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + /* Just return the original SKB */ + return skb; +} + +static const struct dsa_device_ops none_ops = { + .name = "none", + .proto = DSA_TAG_PROTO_NONE, + .xmit = dsa_slave_notag_xmit, + .rcv = NULL, +}; + +DSA_TAG_DRIVER(none_ops); + +static void dsa_tag_driver_register(struct dsa_tag_driver *dsa_tag_driver, + struct module *owner) +{ + dsa_tag_driver->owner = owner; + + mutex_lock(&dsa_tag_drivers_lock); + list_add_tail(&dsa_tag_driver->list, &dsa_tag_drivers_list); + mutex_unlock(&dsa_tag_drivers_lock); +} + +void dsa_tag_drivers_register(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count, struct module *owner) +{ + unsigned int i; + + for (i = 0; i < count; i++) + dsa_tag_driver_register(dsa_tag_driver_array[i], owner); +} + +static void dsa_tag_driver_unregister(struct dsa_tag_driver *dsa_tag_driver) +{ + mutex_lock(&dsa_tag_drivers_lock); + list_del(&dsa_tag_driver->list); + mutex_unlock(&dsa_tag_drivers_lock); +} +EXPORT_SYMBOL_GPL(dsa_tag_drivers_register); + +void dsa_tag_drivers_unregister(struct dsa_tag_driver *dsa_tag_driver_array[], + unsigned int count) +{ + unsigned int i; + + for (i = 0; i < count; i++) + dsa_tag_driver_unregister(dsa_tag_driver_array[i]); +} +EXPORT_SYMBOL_GPL(dsa_tag_drivers_unregister); + +const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops) +{ + return ops->name; +}; + +/* Function takes a reference on the module owning the tagger, + * so dsa_tag_driver_put must be called afterwards. 
+ */ +const struct dsa_device_ops *dsa_find_tagger_by_name(const char *buf) +{ + const struct dsa_device_ops *ops = ERR_PTR(-ENOPROTOOPT); + struct dsa_tag_driver *dsa_tag_driver; + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + const struct dsa_device_ops *tmp = dsa_tag_driver->ops; + + if (!sysfs_streq(buf, tmp->name)) + continue; + + if (!try_module_get(dsa_tag_driver->owner)) + break; + + ops = tmp; + break; + } + mutex_unlock(&dsa_tag_drivers_lock); + + return ops; +} + +const struct dsa_device_ops *dsa_tag_driver_get(int tag_protocol) +{ + struct dsa_tag_driver *dsa_tag_driver; + const struct dsa_device_ops *ops; + bool found = false; + + request_module("%s%d", DSA_TAG_DRIVER_ALIAS, tag_protocol); + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + ops = dsa_tag_driver->ops; + if (ops->proto == tag_protocol) { + found = true; + break; + } + } + + if (found) { + if (!try_module_get(dsa_tag_driver->owner)) + ops = ERR_PTR(-ENOPROTOOPT); + } else { + ops = ERR_PTR(-ENOPROTOOPT); + } + + mutex_unlock(&dsa_tag_drivers_lock); + + return ops; +} + +void dsa_tag_driver_put(const struct dsa_device_ops *ops) +{ + struct dsa_tag_driver *dsa_tag_driver; + + mutex_lock(&dsa_tag_drivers_lock); + list_for_each_entry(dsa_tag_driver, &dsa_tag_drivers_list, list) { + if (dsa_tag_driver->ops == ops) { + module_put(dsa_tag_driver->owner); + break; + } + } + mutex_unlock(&dsa_tag_drivers_lock); +} + +static int dev_is_class(struct device *dev, void *class) +{ + if (dev->class != NULL && !strcmp(dev->class->name, class)) + return 1; + + return 0; +} + +static struct device *dev_find_class(struct device *parent, char *class) +{ + if (dev_is_class(parent, class)) { + get_device(parent); + return parent; + } + + return device_find_child(parent, class, dev_is_class); +} + +struct net_device *dsa_dev_to_net_device(struct device *dev) +{ + struct device *d; + + d = dev_find_class(dev, "net"); + if (d != NULL) { + struct net_device *nd; + + nd = to_net_dev(d); + dev_hold(nd); + put_device(d); + + return nd; + } + + return NULL; +} +EXPORT_SYMBOL_GPL(dsa_dev_to_net_device); + +/* Determine if we should defer delivery of skb until we have a rx timestamp. + * + * Called from dsa_switch_rcv. For now, this will only work if tagging is + * enabled on the switch. Normally the MAC driver would retrieve the hardware + * timestamp when it reads the packet out of the hardware. However in a DSA + * switch, the DSA driver owning the interface to which the packet is + * delivered is never notified unless we do so here. 
+ */ +static bool dsa_skb_defer_rx_timestamp(struct dsa_slave_priv *p, + struct sk_buff *skb) +{ + struct dsa_switch *ds = p->dp->ds; + unsigned int type; + + if (skb_headroom(skb) < ETH_HLEN) + return false; + + __skb_push(skb, ETH_HLEN); + + type = ptp_classify_raw(skb); + + __skb_pull(skb, ETH_HLEN); + + if (type == PTP_CLASS_NONE) + return false; + + if (likely(ds->ops->port_rxtstamp)) + return ds->ops->port_rxtstamp(ds, p->dp->index, skb, type); + + return false; +} + +static int dsa_switch_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *unused) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct sk_buff *nskb = NULL; + struct dsa_slave_priv *p; + + if (unlikely(!cpu_dp)) { + kfree_skb(skb); + return 0; + } + + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + return 0; + + nskb = cpu_dp->rcv(skb, dev); + if (!nskb) { + kfree_skb(skb); + return 0; + } + + skb = nskb; + skb_push(skb, ETH_HLEN); + skb->pkt_type = PACKET_HOST; + skb->protocol = eth_type_trans(skb, skb->dev); + + if (unlikely(!dsa_slave_dev_check(skb->dev))) { + /* Packet is to be injected directly on an upper + * device, e.g. a team/bond, so skip all DSA-port + * specific actions. + */ + netif_rx(skb); + return 0; + } + + p = netdev_priv(skb->dev); + + if (unlikely(cpu_dp->ds->untag_bridge_pvid)) { + nskb = dsa_untag_bridge_pvid(skb); + if (!nskb) { + kfree_skb(skb); + return 0; + } + skb = nskb; + } + + dev_sw_netstats_rx_add(skb->dev, skb->len); + + if (dsa_skb_defer_rx_timestamp(p, skb)) + return 0; + + gro_cells_receive(&p->gcells, skb); + + return 0; +} + +#ifdef CONFIG_PM_SLEEP +static bool dsa_port_is_initialized(const struct dsa_port *dp) +{ + return dp->type == DSA_PORT_TYPE_USER && dp->slave; +} + +int dsa_switch_suspend(struct dsa_switch *ds) +{ + struct dsa_port *dp; + int ret = 0; + + /* Suspend slave network devices */ + dsa_switch_for_each_port(dp, ds) { + if (!dsa_port_is_initialized(dp)) + continue; + + ret = dsa_slave_suspend(dp->slave); + if (ret) + return ret; + } + + if (ds->ops->suspend) + ret = ds->ops->suspend(ds); + + return ret; +} +EXPORT_SYMBOL_GPL(dsa_switch_suspend); + +int dsa_switch_resume(struct dsa_switch *ds) +{ + struct dsa_port *dp; + int ret = 0; + + if (ds->ops->resume) + ret = ds->ops->resume(ds); + + if (ret) + return ret; + + /* Resume slave network devices */ + dsa_switch_for_each_port(dp, ds) { + if (!dsa_port_is_initialized(dp)) + continue; + + ret = dsa_slave_resume(dp->slave); + if (ret) + return ret; + } + + return 0; +} +EXPORT_SYMBOL_GPL(dsa_switch_resume); +#endif + +static struct packet_type dsa_pack_type __read_mostly = { + .type = cpu_to_be16(ETH_P_XDSA), + .func = dsa_switch_rcv, +}; + +static struct workqueue_struct *dsa_owq; + +bool dsa_schedule_work(struct work_struct *work) +{ + return queue_work(dsa_owq, work); +} + +void dsa_flush_workqueue(void) +{ + flush_workqueue(dsa_owq); +} +EXPORT_SYMBOL_GPL(dsa_flush_workqueue); + +int dsa_devlink_param_get(struct devlink *dl, u32 id, + struct devlink_param_gset_ctx *ctx) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_param_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_param_get(ds, id, ctx); +} +EXPORT_SYMBOL_GPL(dsa_devlink_param_get); + +int dsa_devlink_param_set(struct devlink *dl, u32 id, + struct devlink_param_gset_ctx *ctx) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_param_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_param_set(ds, id, ctx); +} 
+EXPORT_SYMBOL_GPL(dsa_devlink_param_set); + +int dsa_devlink_params_register(struct dsa_switch *ds, + const struct devlink_param *params, + size_t params_count) +{ + return devlink_params_register(ds->devlink, params, params_count); +} +EXPORT_SYMBOL_GPL(dsa_devlink_params_register); + +void dsa_devlink_params_unregister(struct dsa_switch *ds, + const struct devlink_param *params, + size_t params_count) +{ + devlink_params_unregister(ds->devlink, params, params_count); +} +EXPORT_SYMBOL_GPL(dsa_devlink_params_unregister); + +int dsa_devlink_resource_register(struct dsa_switch *ds, + const char *resource_name, + u64 resource_size, + u64 resource_id, + u64 parent_resource_id, + const struct devlink_resource_size_params *size_params) +{ + return devlink_resource_register(ds->devlink, resource_name, + resource_size, resource_id, + parent_resource_id, + size_params); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_register); + +void dsa_devlink_resources_unregister(struct dsa_switch *ds) +{ + devlink_resources_unregister(ds->devlink); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resources_unregister); + +void dsa_devlink_resource_occ_get_register(struct dsa_switch *ds, + u64 resource_id, + devlink_resource_occ_get_t *occ_get, + void *occ_get_priv) +{ + return devlink_resource_occ_get_register(ds->devlink, resource_id, + occ_get, occ_get_priv); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_register); + +void dsa_devlink_resource_occ_get_unregister(struct dsa_switch *ds, + u64 resource_id) +{ + devlink_resource_occ_get_unregister(ds->devlink, resource_id); +} +EXPORT_SYMBOL_GPL(dsa_devlink_resource_occ_get_unregister); + +struct devlink_region * +dsa_devlink_region_create(struct dsa_switch *ds, + const struct devlink_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + return devlink_region_create(ds->devlink, ops, region_max_snapshots, + region_size); +} +EXPORT_SYMBOL_GPL(dsa_devlink_region_create); + +struct devlink_region * +dsa_devlink_port_region_create(struct dsa_switch *ds, + int port, + const struct devlink_port_region_ops *ops, + u32 region_max_snapshots, u64 region_size) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + + return devlink_port_region_create(&dp->devlink_port, ops, + region_max_snapshots, + region_size); +} +EXPORT_SYMBOL_GPL(dsa_devlink_port_region_create); + +void dsa_devlink_region_destroy(struct devlink_region *region) +{ + devlink_region_destroy(region); +} +EXPORT_SYMBOL_GPL(dsa_devlink_region_destroy); + +struct dsa_port *dsa_port_from_netdev(struct net_device *netdev) +{ + if (!netdev || !dsa_slave_dev_check(netdev)) + return ERR_PTR(-ENODEV); + + return dsa_slave_to_port(netdev); +} +EXPORT_SYMBOL_GPL(dsa_port_from_netdev); + +bool dsa_db_equal(const struct dsa_db *a, const struct dsa_db *b) +{ + if (a->type != b->type) + return false; + + switch (a->type) { + case DSA_DB_PORT: + return a->dp == b->dp; + case DSA_DB_LAG: + return a->lag.dev == b->lag.dev; + case DSA_DB_BRIDGE: + return a->bridge.num == b->bridge.num; + default: + WARN_ON(1); + return false; + } +} + +bool dsa_fdb_present_in_other_db(struct dsa_switch *ds, int port, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + struct dsa_mac_addr *a; + + lockdep_assert_held(&dp->addr_lists_lock); + + list_for_each_entry(a, &dp->fdbs, list) { + if (!ether_addr_equal(a->addr, addr) || a->vid != vid) + continue; + + if (a->db.type == db.type && !dsa_db_equal(&a->db, &db)) + return true; + } + + return false; +} 
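dsa_db_equal() and dsa_fdb_present_in_other_db() above exist for drivers whose hardware keeps a single FDB shared between the standalone-port and bridge databases: before wiping the entry on behalf of one database, the driver can check whether another database still needs it. A minimal sketch of a driver's .port_fdb_del using the helper (all "example_" names are hypothetical; it assumes the DSA core holds dp->addr_lists_lock around the op, which is what the lockdep assertion above expects):

/* Illustrative sketch, not part of the patch. example_hw_fdb_del() is a
 * hypothetical stand-in for the driver's real hardware access.
 */
static int example_hw_fdb_del(void *priv, int port,
                              const unsigned char *addr, u16 vid)
{
        return 0;       /* hardware access omitted in this sketch */
}

static int example_port_fdb_del(struct dsa_switch *ds, int port,
                                const unsigned char *addr, u16 vid,
                                struct dsa_db db)
{
        /* Another database (e.g. the port's standalone DB while a bridge
         * DB entry is being deleted) still references this address: keep
         * the shared hardware entry.
         */
        if (dsa_fdb_present_in_other_db(ds, port, addr, vid, db))
                return 0;

        return example_hw_fdb_del(ds->priv, port, addr, vid);
}

The core never applies this check on its own; the helper is exported precisely so that drivers which multiplex several databases onto one hardware table can make the decision themselves.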
+EXPORT_SYMBOL_GPL(dsa_fdb_present_in_other_db); + +bool dsa_mdb_present_in_other_db(struct dsa_switch *ds, int port, + const struct switchdev_obj_port_mdb *mdb, + struct dsa_db db) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + struct dsa_mac_addr *a; + + lockdep_assert_held(&dp->addr_lists_lock); + + list_for_each_entry(a, &dp->mdbs, list) { + if (!ether_addr_equal(a->addr, mdb->addr) || a->vid != mdb->vid) + continue; + + if (a->db.type == db.type && !dsa_db_equal(&a->db, &db)) + return true; + } + + return false; +} +EXPORT_SYMBOL_GPL(dsa_mdb_present_in_other_db); + +static int __init dsa_init_module(void) +{ + int rc; + + dsa_owq = alloc_ordered_workqueue("dsa_ordered", + WQ_MEM_RECLAIM); + if (!dsa_owq) + return -ENOMEM; + + rc = dsa_slave_register_notifier(); + if (rc) + goto register_notifier_fail; + + dev_add_pack(&dsa_pack_type); + + dsa_tag_driver_register(&DSA_TAG_DRIVER_NAME(none_ops), + THIS_MODULE); + + rc = rtnl_link_register(&dsa_link_ops); + if (rc) + goto netlink_register_fail; + + return 0; + +netlink_register_fail: + dsa_tag_driver_unregister(&DSA_TAG_DRIVER_NAME(none_ops)); + dsa_slave_unregister_notifier(); + dev_remove_pack(&dsa_pack_type); +register_notifier_fail: + destroy_workqueue(dsa_owq); + + return rc; +} +module_init(dsa_init_module); + +static void __exit dsa_cleanup_module(void) +{ + rtnl_link_unregister(&dsa_link_ops); + dsa_tag_driver_unregister(&DSA_TAG_DRIVER_NAME(none_ops)); + + dsa_slave_unregister_notifier(); + dev_remove_pack(&dsa_pack_type); + destroy_workqueue(dsa_owq); +} +module_exit(dsa_cleanup_module); + +MODULE_AUTHOR("Lennert Buytenhek "); +MODULE_DESCRIPTION("Driver for Distributed Switch Architecture switch chips"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("platform:dsa"); diff --git a/net/dsa/dsa2.c b/net/dsa/dsa2.c new file mode 100644 index 000000000..5417f7b11 --- /dev/null +++ b/net/dsa/dsa2.c @@ -0,0 +1,1829 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dsa/dsa2.c - Hardware switch handling, binding version 2 + * Copyright (c) 2008-2009 Marvell Semiconductor + * Copyright (c) 2013 Florian Fainelli + * Copyright (c) 2016 Andrew Lunn + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +static DEFINE_MUTEX(dsa2_mutex); +LIST_HEAD(dsa_tree_list); + +/* Track the bridges with forwarding offload enabled */ +static unsigned long dsa_fwd_offloading_bridges; + +/** + * dsa_tree_notify - Execute code for all switches in a DSA switch tree. + * @dst: collection of struct dsa_switch devices to notify. + * @e: event, must be of type DSA_NOTIFIER_* + * @v: event-specific value. + * + * Given a struct dsa_switch_tree, this can be used to run a function once for + * each member DSA switch. The other alternative of traversing the tree is only + * through its ports list, which does not uniquely list the switches. + */ +int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v) +{ + struct raw_notifier_head *nh = &dst->nh; + int err; + + err = raw_notifier_call_chain(nh, e, v); + + return notifier_to_errno(err); +} + +/** + * dsa_broadcast - Notify all DSA trees in the system. + * @e: event, must be of type DSA_NOTIFIER_* + * @v: event-specific value. + * + * Can be used to notify the switching fabric of events such as cross-chip + * bridging between disjoint trees (such as islands of tagger-compatible + * switches bridged by an incompatible middle switch). 
+ * + * WARNING: this function is not reliable during probe time, because probing + * between trees is asynchronous and not all DSA trees might have probed. + */ +int dsa_broadcast(unsigned long e, void *v) +{ + struct dsa_switch_tree *dst; + int err = 0; + + list_for_each_entry(dst, &dsa_tree_list, list) { + err = dsa_tree_notify(dst, e, v); + if (err) + break; + } + + return err; +} + +/** + * dsa_lag_map() - Map LAG structure to a linear LAG array + * @dst: Tree in which to record the mapping. + * @lag: LAG structure that is to be mapped to the tree's array. + * + * dsa_lag_id/dsa_lag_by_id can then be used to translate between the + * two spaces. The size of the mapping space is determined by the + * driver by setting ds->num_lag_ids. It is perfectly legal to leave + * it unset if it is not needed, in which case these functions become + * no-ops. + */ +void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag) +{ + unsigned int id; + + for (id = 1; id <= dst->lags_len; id++) { + if (!dsa_lag_by_id(dst, id)) { + dst->lags[id - 1] = lag; + lag->id = id; + return; + } + } + + /* No IDs left, which is OK. Some drivers do not need it. The + * ones that do, e.g. mv88e6xxx, will discover that dsa_lag_id + * returns an error for this device when joining the LAG. The + * driver can then return -EOPNOTSUPP back to DSA, which will + * fall back to a software LAG. + */ +} + +/** + * dsa_lag_unmap() - Remove a LAG ID mapping + * @dst: Tree in which the mapping is recorded. + * @lag: LAG structure that was mapped. + * + * As there may be multiple users of the mapping, it is only removed + * if there are no other references to it. + */ +void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag) +{ + unsigned int id; + + dsa_lags_foreach_id(id, dst) { + if (dsa_lag_by_id(dst, id) == lag) { + dst->lags[id - 1] = NULL; + lag->id = 0; + break; + } + } +} + +struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, + const struct net_device *lag_dev) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_lag_dev_get(dp) == lag_dev) + return dp->lag; + + return NULL; +} + +struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, + const struct net_device *br) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_bridge_dev_get(dp) == br) + return dp->bridge; + + return NULL; +} + +static int dsa_bridge_num_find(const struct net_device *bridge_dev) +{ + struct dsa_switch_tree *dst; + + list_for_each_entry(dst, &dsa_tree_list, list) { + struct dsa_bridge *bridge; + + bridge = dsa_tree_bridge_find(dst, bridge_dev); + if (bridge) + return bridge->num; + } + + return 0; +} + +unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max) +{ + unsigned int bridge_num = dsa_bridge_num_find(bridge_dev); + + /* Switches without FDB isolation support don't get unique + * bridge numbering + */ + if (!max) + return 0; + + if (!bridge_num) { + /* First port that requests FDB isolation or TX forwarding + * offload for this bridge + */ + bridge_num = find_next_zero_bit(&dsa_fwd_offloading_bridges, + DSA_MAX_NUM_OFFLOADING_BRIDGES, + 1); + if (bridge_num >= max) + return 0; + + set_bit(bridge_num, &dsa_fwd_offloading_bridges); + } + + return bridge_num; +} + +void dsa_bridge_num_put(const struct net_device *bridge_dev, + unsigned int bridge_num) +{ + /* Since we refcount bridges, we know that when we call this function + * it is no longer in use, so we can just go ahead and remove it from + * the bit mask. 
+ */ + clear_bit(bridge_num, &dsa_fwd_offloading_bridges); +} + +struct dsa_switch *dsa_switch_find(int tree_index, int sw_index) +{ + struct dsa_switch_tree *dst; + struct dsa_port *dp; + + list_for_each_entry(dst, &dsa_tree_list, list) { + if (dst->index != tree_index) + continue; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->ds->index != sw_index) + continue; + + return dp->ds; + } + } + + return NULL; +} +EXPORT_SYMBOL_GPL(dsa_switch_find); + +static struct dsa_switch_tree *dsa_tree_find(int index) +{ + struct dsa_switch_tree *dst; + + list_for_each_entry(dst, &dsa_tree_list, list) + if (dst->index == index) + return dst; + + return NULL; +} + +static struct dsa_switch_tree *dsa_tree_alloc(int index) +{ + struct dsa_switch_tree *dst; + + dst = kzalloc(sizeof(*dst), GFP_KERNEL); + if (!dst) + return NULL; + + dst->index = index; + + INIT_LIST_HEAD(&dst->rtable); + + INIT_LIST_HEAD(&dst->ports); + + INIT_LIST_HEAD(&dst->list); + list_add_tail(&dst->list, &dsa_tree_list); + + kref_init(&dst->refcount); + + return dst; +} + +static void dsa_tree_free(struct dsa_switch_tree *dst) +{ + if (dst->tag_ops) + dsa_tag_driver_put(dst->tag_ops); + list_del(&dst->list); + kfree(dst); +} + +static struct dsa_switch_tree *dsa_tree_get(struct dsa_switch_tree *dst) +{ + if (dst) + kref_get(&dst->refcount); + + return dst; +} + +static struct dsa_switch_tree *dsa_tree_touch(int index) +{ + struct dsa_switch_tree *dst; + + dst = dsa_tree_find(index); + if (dst) + return dsa_tree_get(dst); + else + return dsa_tree_alloc(index); +} + +static void dsa_tree_release(struct kref *ref) +{ + struct dsa_switch_tree *dst; + + dst = container_of(ref, struct dsa_switch_tree, refcount); + + dsa_tree_free(dst); +} + +static void dsa_tree_put(struct dsa_switch_tree *dst) +{ + if (dst) + kref_put(&dst->refcount, dsa_tree_release); +} + +static struct dsa_port *dsa_tree_find_port_by_node(struct dsa_switch_tree *dst, + struct device_node *dn) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dp->dn == dn) + return dp; + + return NULL; +} + +static struct dsa_link *dsa_link_touch(struct dsa_port *dp, + struct dsa_port *link_dp) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst; + struct dsa_link *dl; + + dst = ds->dst; + + list_for_each_entry(dl, &dst->rtable, list) + if (dl->dp == dp && dl->link_dp == link_dp) + return dl; + + dl = kzalloc(sizeof(*dl), GFP_KERNEL); + if (!dl) + return NULL; + + dl->dp = dp; + dl->link_dp = link_dp; + + INIT_LIST_HEAD(&dl->list); + list_add_tail(&dl->list, &dst->rtable); + + return dl; +} + +static bool dsa_port_setup_routing_table(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst = ds->dst; + struct device_node *dn = dp->dn; + struct of_phandle_iterator it; + struct dsa_port *link_dp; + struct dsa_link *dl; + int err; + + of_for_each_phandle(&it, err, dn, "link", NULL, 0) { + link_dp = dsa_tree_find_port_by_node(dst, it.node); + if (!link_dp) { + of_node_put(it.node); + return false; + } + + dl = dsa_link_touch(dp, link_dp); + if (!dl) { + of_node_put(it.node); + return false; + } + } + + return true; +} + +static bool dsa_tree_setup_routing_table(struct dsa_switch_tree *dst) +{ + bool complete = true; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_dsa(dp)) { + complete = dsa_port_setup_routing_table(dp); + if (!complete) + break; + } + } + + return complete; +} + +static struct dsa_port *dsa_tree_find_first_cpu(struct dsa_switch_tree *dst) +{ + struct 
dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_cpu(dp)) + return dp; + + return NULL; +} + +struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst) +{ + struct device_node *ethernet; + struct net_device *master; + struct dsa_port *cpu_dp; + + cpu_dp = dsa_tree_find_first_cpu(dst); + ethernet = of_parse_phandle(cpu_dp->dn, "ethernet", 0); + master = of_find_net_device_by_node(ethernet); + of_node_put(ethernet); + + return master; +} + +/* Assign the default CPU port (the first one in the tree) to all ports of the + * fabric which don't already have one as part of their own switch. + */ +static int dsa_tree_setup_default_cpu(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp, *dp; + + cpu_dp = dsa_tree_find_first_cpu(dst); + if (!cpu_dp) { + pr_err("DSA: tree %d has no CPU port\n", dst->index); + return -EINVAL; + } + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->cpu_dp) + continue; + + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = cpu_dp; + } + + return 0; +} + +/* Perform initial assignment of CPU ports to user ports and DSA links in the + * fabric, giving preference to CPU ports local to each switch. Default to + * using the first CPU port in the switch tree if the port does not have a CPU + * port local to this switch. + */ +static int dsa_tree_setup_cpu_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp, *dp; + + list_for_each_entry(cpu_dp, &dst->ports, list) { + if (!dsa_port_is_cpu(cpu_dp)) + continue; + + /* Prefer a local CPU port */ + dsa_switch_for_each_port(dp, cpu_dp->ds) { + /* Prefer the first local CPU port found */ + if (dp->cpu_dp) + continue; + + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = cpu_dp; + } + } + + return dsa_tree_setup_default_cpu(dst); +} + +static void dsa_tree_teardown_cpu_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_user(dp) || dsa_port_is_dsa(dp)) + dp->cpu_dp = NULL; +} + +static int dsa_port_devlink_setup(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + struct dsa_switch_tree *dst = dp->ds->dst; + struct devlink_port_attrs attrs = {}; + struct devlink *dl = dp->ds->devlink; + struct dsa_switch *ds = dp->ds; + const unsigned char *id; + unsigned char len; + int err; + + memset(dlp, 0, sizeof(*dlp)); + devlink_port_init(dl, dlp); + + if (ds->ops->port_setup) { + err = ds->ops->port_setup(ds, dp->index); + if (err) + return err; + } + + id = (const unsigned char *)&dst->index; + len = sizeof(dst->index); + + attrs.phys.port_number = dp->index; + memcpy(attrs.switch_id.id, id, len); + attrs.switch_id.id_len = len; + + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + attrs.flavour = DEVLINK_PORT_FLAVOUR_UNUSED; + break; + case DSA_PORT_TYPE_CPU: + attrs.flavour = DEVLINK_PORT_FLAVOUR_CPU; + break; + case DSA_PORT_TYPE_DSA: + attrs.flavour = DEVLINK_PORT_FLAVOUR_DSA; + break; + case DSA_PORT_TYPE_USER: + attrs.flavour = DEVLINK_PORT_FLAVOUR_PHYSICAL; + break; + } + + devlink_port_attrs_set(dlp, &attrs); + err = devlink_port_register(dl, dlp, dp->index); + if (err) { + if (ds->ops->port_teardown) + ds->ops->port_teardown(ds, dp->index); + return err; + } + + return 0; +} + +static void dsa_port_devlink_teardown(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + struct dsa_switch *ds = dp->ds; + + devlink_port_unregister(dlp); + + if (ds->ops->port_teardown) + ds->ops->port_teardown(ds, dp->index); + + 
devlink_port_fini(dlp); +} + +static int dsa_port_setup(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + bool dsa_port_link_registered = false; + struct dsa_switch *ds = dp->ds; + bool dsa_port_enabled = false; + int err = 0; + + if (dp->setup) + return 0; + + err = dsa_port_devlink_setup(dp); + if (err) + return err; + + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + dsa_port_disable(dp); + break; + case DSA_PORT_TYPE_CPU: + if (dp->dn) { + err = dsa_shared_port_link_register_of(dp); + if (err) + break; + dsa_port_link_registered = true; + } else { + dev_warn(ds->dev, + "skipping link registration for CPU port %d\n", + dp->index); + } + + err = dsa_port_enable(dp, NULL); + if (err) + break; + dsa_port_enabled = true; + + break; + case DSA_PORT_TYPE_DSA: + if (dp->dn) { + err = dsa_shared_port_link_register_of(dp); + if (err) + break; + dsa_port_link_registered = true; + } else { + dev_warn(ds->dev, + "skipping link registration for DSA port %d\n", + dp->index); + } + + err = dsa_port_enable(dp, NULL); + if (err) + break; + dsa_port_enabled = true; + + break; + case DSA_PORT_TYPE_USER: + of_get_mac_address(dp->dn, dp->mac); + err = dsa_slave_create(dp); + if (err) + break; + + devlink_port_type_eth_set(dlp, dp->slave); + break; + } + + if (err && dsa_port_enabled) + dsa_port_disable(dp); + if (err && dsa_port_link_registered) + dsa_shared_port_link_unregister_of(dp); + if (err) { + dsa_port_devlink_teardown(dp); + return err; + } + + dp->setup = true; + + return 0; +} + +static void dsa_port_teardown(struct dsa_port *dp) +{ + struct devlink_port *dlp = &dp->devlink_port; + + if (!dp->setup) + return; + + devlink_port_type_clear(dlp); + + switch (dp->type) { + case DSA_PORT_TYPE_UNUSED: + break; + case DSA_PORT_TYPE_CPU: + dsa_port_disable(dp); + if (dp->dn) + dsa_shared_port_link_unregister_of(dp); + break; + case DSA_PORT_TYPE_DSA: + dsa_port_disable(dp); + if (dp->dn) + dsa_shared_port_link_unregister_of(dp); + break; + case DSA_PORT_TYPE_USER: + if (dp->slave) { + dsa_slave_destroy(dp->slave); + dp->slave = NULL; + } + break; + } + + dsa_port_devlink_teardown(dp); + + dp->setup = false; +} + +static int dsa_port_setup_as_unused(struct dsa_port *dp) +{ + dp->type = DSA_PORT_TYPE_UNUSED; + return dsa_port_setup(dp); +} + +static int dsa_devlink_info_get(struct devlink *dl, + struct devlink_info_req *req, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (ds->ops->devlink_info_get) + return ds->ops->devlink_info_get(ds, req, extack); + + return -EOPNOTSUPP; +} + +static int dsa_devlink_sb_pool_get(struct devlink *dl, + unsigned int sb_index, u16 pool_index, + struct devlink_sb_pool_info *pool_info) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_pool_get(ds, sb_index, pool_index, + pool_info); +} + +static int dsa_devlink_sb_pool_set(struct devlink *dl, unsigned int sb_index, + u16 pool_index, u32 size, + enum devlink_sb_threshold_type threshold_type, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_pool_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_pool_set(ds, sb_index, pool_index, size, + threshold_type, extack); +} + +static int dsa_devlink_sb_port_pool_get(struct devlink_port *dlp, + unsigned int sb_index, u16 pool_index, + u32 *p_threshold) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + 
+ if (!ds->ops->devlink_sb_port_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_port_pool_get(ds, port, sb_index, + pool_index, p_threshold); +} + +static int dsa_devlink_sb_port_pool_set(struct devlink_port *dlp, + unsigned int sb_index, u16 pool_index, + u32 threshold, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_port_pool_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_port_pool_set(ds, port, sb_index, + pool_index, threshold, extack); +} + +static int +dsa_devlink_sb_tc_pool_bind_get(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u16 *p_pool_index, u32 *p_threshold) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_tc_pool_bind_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_tc_pool_bind_get(ds, port, sb_index, + tc_index, pool_type, + p_pool_index, p_threshold); +} + +static int +dsa_devlink_sb_tc_pool_bind_set(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u16 pool_index, u32 threshold, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_tc_pool_bind_set) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_tc_pool_bind_set(ds, port, sb_index, + tc_index, pool_type, + pool_index, threshold, + extack); +} + +static int dsa_devlink_sb_occ_snapshot(struct devlink *dl, + unsigned int sb_index) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_occ_snapshot) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_snapshot(ds, sb_index); +} + +static int dsa_devlink_sb_occ_max_clear(struct devlink *dl, + unsigned int sb_index) +{ + struct dsa_switch *ds = dsa_devlink_to_ds(dl); + + if (!ds->ops->devlink_sb_occ_max_clear) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_max_clear(ds, sb_index); +} + +static int dsa_devlink_sb_occ_port_pool_get(struct devlink_port *dlp, + unsigned int sb_index, + u16 pool_index, u32 *p_cur, + u32 *p_max) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_occ_port_pool_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_port_pool_get(ds, port, sb_index, + pool_index, p_cur, p_max); +} + +static int +dsa_devlink_sb_occ_tc_port_bind_get(struct devlink_port *dlp, + unsigned int sb_index, u16 tc_index, + enum devlink_sb_pool_type pool_type, + u32 *p_cur, u32 *p_max) +{ + struct dsa_switch *ds = dsa_devlink_port_to_ds(dlp); + int port = dsa_devlink_port_to_port(dlp); + + if (!ds->ops->devlink_sb_occ_tc_port_bind_get) + return -EOPNOTSUPP; + + return ds->ops->devlink_sb_occ_tc_port_bind_get(ds, port, + sb_index, tc_index, + pool_type, p_cur, + p_max); +} + +static const struct devlink_ops dsa_devlink_ops = { + .info_get = dsa_devlink_info_get, + .sb_pool_get = dsa_devlink_sb_pool_get, + .sb_pool_set = dsa_devlink_sb_pool_set, + .sb_port_pool_get = dsa_devlink_sb_port_pool_get, + .sb_port_pool_set = dsa_devlink_sb_port_pool_set, + .sb_tc_pool_bind_get = dsa_devlink_sb_tc_pool_bind_get, + .sb_tc_pool_bind_set = dsa_devlink_sb_tc_pool_bind_set, + .sb_occ_snapshot = dsa_devlink_sb_occ_snapshot, + .sb_occ_max_clear = dsa_devlink_sb_occ_max_clear, + .sb_occ_port_pool_get = 
dsa_devlink_sb_occ_port_pool_get, + .sb_occ_tc_port_bind_get = dsa_devlink_sb_occ_tc_port_bind_get, +}; + +static int dsa_switch_setup_tag_protocol(struct dsa_switch *ds) +{ + const struct dsa_device_ops *tag_ops = ds->dst->tag_ops; + struct dsa_switch_tree *dst = ds->dst; + int err; + + if (tag_ops->proto == dst->default_proto) + goto connect; + + rtnl_lock(); + err = ds->ops->change_tag_protocol(ds, tag_ops->proto); + rtnl_unlock(); + if (err) { + dev_err(ds->dev, "Unable to use tag protocol \"%s\": %pe\n", + tag_ops->name, ERR_PTR(err)); + return err; + } + +connect: + if (tag_ops->connect) { + err = tag_ops->connect(ds); + if (err) + return err; + } + + if (ds->ops->connect_tag_protocol) { + err = ds->ops->connect_tag_protocol(ds, tag_ops->proto); + if (err) { + dev_err(ds->dev, + "Unable to connect to tag protocol \"%s\": %pe\n", + tag_ops->name, ERR_PTR(err)); + goto disconnect; + } + } + + return 0; + +disconnect: + if (tag_ops->disconnect) + tag_ops->disconnect(ds); + + return err; +} + +static void dsa_switch_teardown_tag_protocol(struct dsa_switch *ds) +{ + const struct dsa_device_ops *tag_ops = ds->dst->tag_ops; + + if (tag_ops->disconnect) + tag_ops->disconnect(ds); +} + +static int dsa_switch_setup(struct dsa_switch *ds) +{ + struct dsa_devlink_priv *dl_priv; + struct device_node *dn; + int err; + + if (ds->setup) + return 0; + + /* Initialize ds->phys_mii_mask before registering the slave MDIO bus + * driver and before ops->setup() has run, since the switch drivers and + * the slave MDIO bus driver rely on these values for probing PHY + * devices or not + */ + ds->phys_mii_mask |= dsa_user_ports(ds); + + /* Add the switch to devlink before calling setup, so that setup can + * add dpipe tables + */ + ds->devlink = + devlink_alloc(&dsa_devlink_ops, sizeof(*dl_priv), ds->dev); + if (!ds->devlink) + return -ENOMEM; + dl_priv = devlink_priv(ds->devlink); + dl_priv->ds = ds; + + err = dsa_switch_register_notifier(ds); + if (err) + goto devlink_free; + + ds->configure_vlan_while_not_filtering = true; + + err = ds->ops->setup(ds); + if (err < 0) + goto unregister_notifier; + + err = dsa_switch_setup_tag_protocol(ds); + if (err) + goto teardown; + + if (!ds->slave_mii_bus && ds->ops->phy_read) { + ds->slave_mii_bus = mdiobus_alloc(); + if (!ds->slave_mii_bus) { + err = -ENOMEM; + goto teardown; + } + + dsa_slave_mii_bus_init(ds); + + dn = of_get_child_by_name(ds->dev->of_node, "mdio"); + + err = of_mdiobus_register(ds->slave_mii_bus, dn); + of_node_put(dn); + if (err < 0) + goto free_slave_mii_bus; + } + + ds->setup = true; + devlink_register(ds->devlink); + return 0; + +free_slave_mii_bus: + if (ds->slave_mii_bus && ds->ops->phy_read) + mdiobus_free(ds->slave_mii_bus); +teardown: + if (ds->ops->teardown) + ds->ops->teardown(ds); +unregister_notifier: + dsa_switch_unregister_notifier(ds); +devlink_free: + devlink_free(ds->devlink); + ds->devlink = NULL; + return err; +} + +static void dsa_switch_teardown(struct dsa_switch *ds) +{ + if (!ds->setup) + return; + + if (ds->devlink) + devlink_unregister(ds->devlink); + + if (ds->slave_mii_bus && ds->ops->phy_read) { + mdiobus_unregister(ds->slave_mii_bus); + mdiobus_free(ds->slave_mii_bus); + ds->slave_mii_bus = NULL; + } + + dsa_switch_teardown_tag_protocol(ds); + + if (ds->ops->teardown) + ds->ops->teardown(ds); + + dsa_switch_unregister_notifier(ds); + + if (ds->devlink) { + devlink_free(ds->devlink); + ds->devlink = NULL; + } + + ds->setup = false; +} + +/* First tear down the non-shared, then the shared ports. 
This ensures that + * all work items scheduled by our switchdev handlers for user ports have + * completed before we destroy the refcounting kept on the shared ports. + */ +static void dsa_tree_teardown_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) + dsa_port_teardown(dp); + + dsa_flush_workqueue(); + + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) + dsa_port_teardown(dp); +} + +static void dsa_tree_teardown_switches(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + dsa_switch_teardown(dp->ds); +} + +/* Bring shared ports up first, then non-shared ports */ +static int dsa_tree_setup_ports(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + int err = 0; + + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp)) { + err = dsa_port_setup(dp); + if (err) + goto teardown; + } + } + + list_for_each_entry(dp, &dst->ports, list) { + if (dsa_port_is_user(dp) || dsa_port_is_unused(dp)) { + err = dsa_port_setup(dp); + if (err) { + err = dsa_port_setup_as_unused(dp); + if (err) + goto teardown; + } + } + } + + return 0; + +teardown: + dsa_tree_teardown_ports(dst); + + return err; +} + +static int dsa_tree_setup_switches(struct dsa_switch_tree *dst) +{ + struct dsa_port *dp; + int err = 0; + + list_for_each_entry(dp, &dst->ports, list) { + err = dsa_switch_setup(dp->ds); + if (err) { + dsa_tree_teardown_switches(dst); + break; + } + } + + return err; +} + +static int dsa_tree_setup_master(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp; + int err = 0; + + rtnl_lock(); + + dsa_tree_for_each_cpu_port(cpu_dp, dst) { + struct net_device *master = cpu_dp->master; + bool admin_up = (master->flags & IFF_UP) && + !qdisc_tx_is_noop(master); + + err = dsa_master_setup(master, cpu_dp); + if (err) + break; + + /* Replay master state event */ + dsa_tree_master_admin_state_change(dst, master, admin_up); + dsa_tree_master_oper_state_change(dst, master, + netif_oper_up(master)); + } + + rtnl_unlock(); + + return err; +} + +static void dsa_tree_teardown_master(struct dsa_switch_tree *dst) +{ + struct dsa_port *cpu_dp; + + rtnl_lock(); + + dsa_tree_for_each_cpu_port(cpu_dp, dst) { + struct net_device *master = cpu_dp->master; + + /* Synthesizing an "admin down" state is sufficient for + * the switches to get a notification if the master is + * currently up and running. + */ + dsa_tree_master_admin_state_change(dst, master, false); + + dsa_master_teardown(master); + } + + rtnl_unlock(); +} + +static int dsa_tree_setup_lags(struct dsa_switch_tree *dst) +{ + unsigned int len = 0; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->ds->num_lag_ids > len) + len = dp->ds->num_lag_ids; + } + + if (!len) + return 0; + + dst->lags = kcalloc(len, sizeof(*dst->lags), GFP_KERNEL); + if (!dst->lags) + return -ENOMEM; + + dst->lags_len = len; + return 0; +} + +static void dsa_tree_teardown_lags(struct dsa_switch_tree *dst) +{ + kfree(dst->lags); +} + +static int dsa_tree_setup(struct dsa_switch_tree *dst) +{ + bool complete; + int err; + + if (dst->setup) { + pr_err("DSA: tree %d already setup! 
Disjoint trees?\n", + dst->index); + return -EEXIST; + } + + complete = dsa_tree_setup_routing_table(dst); + if (!complete) + return 0; + + err = dsa_tree_setup_cpu_ports(dst); + if (err) + return err; + + err = dsa_tree_setup_switches(dst); + if (err) + goto teardown_cpu_ports; + + err = dsa_tree_setup_ports(dst); + if (err) + goto teardown_switches; + + err = dsa_tree_setup_master(dst); + if (err) + goto teardown_ports; + + err = dsa_tree_setup_lags(dst); + if (err) + goto teardown_master; + + dst->setup = true; + + pr_info("DSA: tree %d setup\n", dst->index); + + return 0; + +teardown_master: + dsa_tree_teardown_master(dst); +teardown_ports: + dsa_tree_teardown_ports(dst); +teardown_switches: + dsa_tree_teardown_switches(dst); +teardown_cpu_ports: + dsa_tree_teardown_cpu_ports(dst); + + return err; +} + +static void dsa_tree_teardown(struct dsa_switch_tree *dst) +{ + struct dsa_link *dl, *next; + + if (!dst->setup) + return; + + dsa_tree_teardown_lags(dst); + + dsa_tree_teardown_master(dst); + + dsa_tree_teardown_ports(dst); + + dsa_tree_teardown_switches(dst); + + dsa_tree_teardown_cpu_ports(dst); + + list_for_each_entry_safe(dl, next, &dst->rtable, list) { + list_del(&dl->list); + kfree(dl); + } + + pr_info("DSA: tree %d torn down\n", dst->index); + + dst->setup = false; +} + +static int dsa_tree_bind_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops) +{ + const struct dsa_device_ops *old_tag_ops = dst->tag_ops; + struct dsa_notifier_tag_proto_info info; + int err; + + dst->tag_ops = tag_ops; + + /* Notify the switches from this tree about the connection + * to the new tagger + */ + info.tag_ops = tag_ops; + err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_CONNECT, &info); + if (err && err != -EOPNOTSUPP) + goto out_disconnect; + + /* Notify the old tagger about the disconnection from this tree */ + info.tag_ops = old_tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); + + return 0; + +out_disconnect: + info.tag_ops = tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO_DISCONNECT, &info); + dst->tag_ops = old_tag_ops; + + return err; +} + +/* Since the dsa/tagging sysfs device attribute is per master, the assumption + * is that all DSA switches within a tree share the same tagger, otherwise + * they would have formed disjoint trees (different "dsa,member" values). + */ +int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops, + const struct dsa_device_ops *old_tag_ops) +{ + struct dsa_notifier_tag_proto_info info; + struct dsa_port *dp; + int err = -EBUSY; + + if (!rtnl_trylock()) + return restart_syscall(); + + /* At the moment we don't allow changing the tag protocol under + * traffic. The rtnl_mutex also happens to serialize concurrent + * attempts to change the tagging protocol. If we ever lift the IFF_UP + * restriction, there needs to be another mutex which serializes this. 
+ */ + dsa_tree_for_each_user_port(dp, dst) { + if (dsa_port_to_master(dp)->flags & IFF_UP) + goto out_unlock; + + if (dp->slave->flags & IFF_UP) + goto out_unlock; + } + + /* Notify the tag protocol change */ + info.tag_ops = tag_ops; + err = dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); + if (err) + goto out_unwind_tagger; + + err = dsa_tree_bind_tag_proto(dst, tag_ops); + if (err) + goto out_unwind_tagger; + + rtnl_unlock(); + + return 0; + +out_unwind_tagger: + info.tag_ops = old_tag_ops; + dsa_tree_notify(dst, DSA_NOTIFIER_TAG_PROTO, &info); +out_unlock: + rtnl_unlock(); + return err; +} + +static void dsa_tree_master_state_change(struct dsa_switch_tree *dst, + struct net_device *master) +{ + struct dsa_notifier_master_state_info info; + struct dsa_port *cpu_dp = master->dsa_ptr; + + info.master = master; + info.operational = dsa_port_master_is_operational(cpu_dp); + + dsa_tree_notify(dst, DSA_NOTIFIER_MASTER_STATE_CHANGE, &info); +} + +void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + bool notify = false; + + /* Don't keep track of admin state on LAG DSA masters, + * but rather just of physical DSA masters + */ + if (netif_is_lag_master(master)) + return; + + if ((dsa_port_master_is_operational(cpu_dp)) != + (up && cpu_dp->master_oper_up)) + notify = true; + + cpu_dp->master_admin_up = up; + + if (notify) + dsa_tree_master_state_change(dst, master); +} + +void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + bool notify = false; + + /* Don't keep track of oper state on LAG DSA masters, + * but rather just of physical DSA masters + */ + if (netif_is_lag_master(master)) + return; + + if ((dsa_port_master_is_operational(cpu_dp)) != + (cpu_dp->master_admin_up && up)) + notify = true; + + cpu_dp->master_oper_up = up; + + if (notify) + dsa_tree_master_state_change(dst, master); +} + +static struct dsa_port *dsa_port_touch(struct dsa_switch *ds, int index) +{ + struct dsa_switch_tree *dst = ds->dst; + struct dsa_port *dp; + + dsa_switch_for_each_port(dp, ds) + if (dp->index == index) + return dp; + + dp = kzalloc(sizeof(*dp), GFP_KERNEL); + if (!dp) + return NULL; + + dp->ds = ds; + dp->index = index; + + mutex_init(&dp->addr_lists_lock); + mutex_init(&dp->vlans_lock); + INIT_LIST_HEAD(&dp->fdbs); + INIT_LIST_HEAD(&dp->mdbs); + INIT_LIST_HEAD(&dp->vlans); + INIT_LIST_HEAD(&dp->list); + list_add_tail(&dp->list, &dst->ports); + + return dp; +} + +static int dsa_port_parse_user(struct dsa_port *dp, const char *name) +{ + if (!name) + name = "eth%d"; + + dp->type = DSA_PORT_TYPE_USER; + dp->name = name; + + return 0; +} + +static int dsa_port_parse_dsa(struct dsa_port *dp) +{ + dp->type = DSA_PORT_TYPE_DSA; + + return 0; +} + +static enum dsa_tag_protocol dsa_get_tag_protocol(struct dsa_port *dp, + struct net_device *master) +{ + enum dsa_tag_protocol tag_protocol = DSA_TAG_PROTO_NONE; + struct dsa_switch *mds, *ds = dp->ds; + unsigned int mdp_upstream; + struct dsa_port *mdp; + + /* It is possible to stack DSA switches onto one another when that + * happens the switch driver may want to know if its tagging protocol + * is going to work in such a configuration. 
+ */ + if (dsa_slave_dev_check(master)) { + mdp = dsa_slave_to_port(master); + mds = mdp->ds; + mdp_upstream = dsa_upstream_port(mds, mdp->index); + tag_protocol = mds->ops->get_tag_protocol(mds, mdp_upstream, + DSA_TAG_PROTO_NONE); + } + + /* If the master device is not itself a DSA slave in a disjoint DSA + * tree, then return immediately. + */ + return ds->ops->get_tag_protocol(ds, dp->index, tag_protocol); +} + +static int dsa_port_parse_cpu(struct dsa_port *dp, struct net_device *master, + const char *user_protocol) +{ + const struct dsa_device_ops *tag_ops = NULL; + struct dsa_switch *ds = dp->ds; + struct dsa_switch_tree *dst = ds->dst; + enum dsa_tag_protocol default_proto; + + /* Find out which protocol the switch would prefer. */ + default_proto = dsa_get_tag_protocol(dp, master); + if (dst->default_proto) { + if (dst->default_proto != default_proto) { + dev_err(ds->dev, + "A DSA switch tree can have only one tagging protocol\n"); + return -EINVAL; + } + } else { + dst->default_proto = default_proto; + } + + /* See if the user wants to override that preference. */ + if (user_protocol) { + if (!ds->ops->change_tag_protocol) { + dev_err(ds->dev, "Tag protocol cannot be modified\n"); + return -EINVAL; + } + + tag_ops = dsa_find_tagger_by_name(user_protocol); + if (IS_ERR(tag_ops)) { + dev_warn(ds->dev, + "Failed to find a tagging driver for protocol %s, using default\n", + user_protocol); + tag_ops = NULL; + } + } + + if (!tag_ops) + tag_ops = dsa_tag_driver_get(default_proto); + + if (IS_ERR(tag_ops)) { + if (PTR_ERR(tag_ops) == -ENOPROTOOPT) + return -EPROBE_DEFER; + + dev_warn(ds->dev, "No tagger for this switch\n"); + return PTR_ERR(tag_ops); + } + + if (dst->tag_ops) { + if (dst->tag_ops != tag_ops) { + dev_err(ds->dev, + "A DSA switch tree can have only one tagging protocol\n"); + + dsa_tag_driver_put(tag_ops); + return -EINVAL; + } + + /* In the case of multiple CPU ports per switch, the tagging + * protocol is still reference-counted only per switch tree. + */ + dsa_tag_driver_put(tag_ops); + } else { + dst->tag_ops = tag_ops; + } + + dp->master = master; + dp->type = DSA_PORT_TYPE_CPU; + dsa_port_set_tag_protocol(dp, dst->tag_ops); + dp->dst = dst; + + /* At this point, the tree may be configured to use a different + * tagger than the one chosen by the switch driver during + * .setup, in the case when a user selects a custom protocol + * through the DT. + * + * This is resolved by syncing the driver with the tree in + * dsa_switch_setup_tag_protocol once .setup has run and the + * driver is ready to accept calls to .change_tag_protocol. If + * the driver does not support the custom protocol at that + * point, the tree is wholly rejected, thereby ensuring that the + * tree and driver are always in agreement on the protocol to + * use. 
+ */ + return 0; +} + +static int dsa_port_parse_of(struct dsa_port *dp, struct device_node *dn) +{ + struct device_node *ethernet = of_parse_phandle(dn, "ethernet", 0); + const char *name = of_get_property(dn, "label", NULL); + bool link = of_property_read_bool(dn, "link"); + + dp->dn = dn; + + if (ethernet) { + struct net_device *master; + const char *user_protocol; + + master = of_find_net_device_by_node(ethernet); + of_node_put(ethernet); + if (!master) + return -EPROBE_DEFER; + + user_protocol = of_get_property(dn, "dsa-tag-protocol", NULL); + return dsa_port_parse_cpu(dp, master, user_protocol); + } + + if (link) + return dsa_port_parse_dsa(dp); + + return dsa_port_parse_user(dp, name); +} + +static int dsa_switch_parse_ports_of(struct dsa_switch *ds, + struct device_node *dn) +{ + struct device_node *ports, *port; + struct dsa_port *dp; + int err = 0; + u32 reg; + + ports = of_get_child_by_name(dn, "ports"); + if (!ports) { + /* The second possibility is "ethernet-ports" */ + ports = of_get_child_by_name(dn, "ethernet-ports"); + if (!ports) { + dev_err(ds->dev, "no ports child node found\n"); + return -EINVAL; + } + } + + for_each_available_child_of_node(ports, port) { + err = of_property_read_u32(port, "reg", ®); + if (err) { + of_node_put(port); + goto out_put_node; + } + + if (reg >= ds->num_ports) { + dev_err(ds->dev, "port %pOF index %u exceeds num_ports (%u)\n", + port, reg, ds->num_ports); + of_node_put(port); + err = -EINVAL; + goto out_put_node; + } + + dp = dsa_to_port(ds, reg); + + err = dsa_port_parse_of(dp, port); + if (err) { + of_node_put(port); + goto out_put_node; + } + } + +out_put_node: + of_node_put(ports); + return err; +} + +static int dsa_switch_parse_member_of(struct dsa_switch *ds, + struct device_node *dn) +{ + u32 m[2] = { 0, 0 }; + int sz; + + /* Don't error out if this optional property isn't found */ + sz = of_property_read_variable_u32_array(dn, "dsa,member", m, 2, 2); + if (sz < 0 && sz != -EINVAL) + return sz; + + ds->index = m[1]; + + ds->dst = dsa_tree_touch(m[0]); + if (!ds->dst) + return -ENOMEM; + + if (dsa_switch_find(ds->dst->index, ds->index)) { + dev_err(ds->dev, + "A DSA switch with index %d already exists in tree %d\n", + ds->index, ds->dst->index); + return -EEXIST; + } + + if (ds->dst->last_switch < ds->index) + ds->dst->last_switch = ds->index; + + return 0; +} + +static int dsa_switch_touch_ports(struct dsa_switch *ds) +{ + struct dsa_port *dp; + int port; + + for (port = 0; port < ds->num_ports; port++) { + dp = dsa_port_touch(ds, port); + if (!dp) + return -ENOMEM; + } + + return 0; +} + +static int dsa_switch_parse_of(struct dsa_switch *ds, struct device_node *dn) +{ + int err; + + err = dsa_switch_parse_member_of(ds, dn); + if (err) + return err; + + err = dsa_switch_touch_ports(ds); + if (err) + return err; + + return dsa_switch_parse_ports_of(ds, dn); +} + +static int dsa_port_parse(struct dsa_port *dp, const char *name, + struct device *dev) +{ + if (!strcmp(name, "cpu")) { + struct net_device *master; + + master = dsa_dev_to_net_device(dev); + if (!master) + return -EPROBE_DEFER; + + dev_put(master); + + return dsa_port_parse_cpu(dp, master, NULL); + } + + if (!strcmp(name, "dsa")) + return dsa_port_parse_dsa(dp); + + return dsa_port_parse_user(dp, name); +} + +static int dsa_switch_parse_ports(struct dsa_switch *ds, + struct dsa_chip_data *cd) +{ + bool valid_name_found = false; + struct dsa_port *dp; + struct device *dev; + const char *name; + unsigned int i; + int err; + + for (i = 0; i < DSA_MAX_PORTS; i++) { + name = 
cd->port_names[i]; + dev = cd->netdev[i]; + dp = dsa_to_port(ds, i); + + if (!name) + continue; + + err = dsa_port_parse(dp, name, dev); + if (err) + return err; + + valid_name_found = true; + } + + if (!valid_name_found && i == DSA_MAX_PORTS) + return -EINVAL; + + return 0; +} + +static int dsa_switch_parse(struct dsa_switch *ds, struct dsa_chip_data *cd) +{ + int err; + + ds->cd = cd; + + /* We don't support interconnected switches nor multiple trees via + * platform data, so this is the unique switch of the tree. + */ + ds->index = 0; + ds->dst = dsa_tree_touch(0); + if (!ds->dst) + return -ENOMEM; + + err = dsa_switch_touch_ports(ds); + if (err) + return err; + + return dsa_switch_parse_ports(ds, cd); +} + +static void dsa_switch_release_ports(struct dsa_switch *ds) +{ + struct dsa_port *dp, *next; + + dsa_switch_for_each_port_safe(dp, next, ds) { + WARN_ON(!list_empty(&dp->fdbs)); + WARN_ON(!list_empty(&dp->mdbs)); + WARN_ON(!list_empty(&dp->vlans)); + list_del(&dp->list); + kfree(dp); + } +} + +static int dsa_switch_probe(struct dsa_switch *ds) +{ + struct dsa_switch_tree *dst; + struct dsa_chip_data *pdata; + struct device_node *np; + int err; + + if (!ds->dev) + return -ENODEV; + + pdata = ds->dev->platform_data; + np = ds->dev->of_node; + + if (!ds->num_ports) + return -EINVAL; + + if (np) { + err = dsa_switch_parse_of(ds, np); + if (err) + dsa_switch_release_ports(ds); + } else if (pdata) { + err = dsa_switch_parse(ds, pdata); + if (err) + dsa_switch_release_ports(ds); + } else { + err = -ENODEV; + } + + if (err) + return err; + + dst = ds->dst; + dsa_tree_get(dst); + err = dsa_tree_setup(dst); + if (err) { + dsa_switch_release_ports(ds); + dsa_tree_put(dst); + } + + return err; +} + +int dsa_register_switch(struct dsa_switch *ds) +{ + int err; + + mutex_lock(&dsa2_mutex); + err = dsa_switch_probe(ds); + dsa_tree_put(ds->dst); + mutex_unlock(&dsa2_mutex); + + return err; +} +EXPORT_SYMBOL_GPL(dsa_register_switch); + +static void dsa_switch_remove(struct dsa_switch *ds) +{ + struct dsa_switch_tree *dst = ds->dst; + + dsa_tree_teardown(dst); + dsa_switch_release_ports(ds); + dsa_tree_put(dst); +} + +void dsa_unregister_switch(struct dsa_switch *ds) +{ + mutex_lock(&dsa2_mutex); + dsa_switch_remove(ds); + mutex_unlock(&dsa2_mutex); +} +EXPORT_SYMBOL_GPL(dsa_unregister_switch); + +/* If the DSA master chooses to unregister its net_device on .shutdown, DSA is + * blocking that operation from completion, due to the dev_hold taken inside + * netdev_upper_dev_link. Unlink the DSA slave interfaces from being uppers of + * the DSA master, so that the system can reboot successfully. + */ +void dsa_switch_shutdown(struct dsa_switch *ds) +{ + struct net_device *master, *slave_dev; + struct dsa_port *dp; + + mutex_lock(&dsa2_mutex); + + if (!ds->setup) + goto out; + + rtnl_lock(); + + dsa_switch_for_each_user_port(dp, ds) { + master = dsa_port_to_master(dp); + slave_dev = dp->slave; + + netdev_upper_dev_unlink(master, slave_dev); + } + + /* Disconnect from further netdevice notifiers on the master, + * since netdev_uses_dsa() will now return false. 
+ */ + dsa_switch_for_each_cpu_port(dp, ds) + dp->master->dsa_ptr = NULL; + + rtnl_unlock(); +out: + mutex_unlock(&dsa2_mutex); +} +EXPORT_SYMBOL_GPL(dsa_switch_shutdown); diff --git a/net/dsa/dsa_priv.h b/net/dsa/dsa_priv.h new file mode 100644 index 000000000..71e9707d1 --- /dev/null +++ b/net/dsa/dsa_priv.h @@ -0,0 +1,588 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * net/dsa/dsa_priv.h - Hardware switch handling + * Copyright (c) 2008-2009 Marvell Semiconductor + */ + +#ifndef __DSA_PRIV_H +#define __DSA_PRIV_H + +#include +#include +#include +#include +#include +#include +#include + +#define DSA_MAX_NUM_OFFLOADING_BRIDGES BITS_PER_LONG + +enum { + DSA_NOTIFIER_AGEING_TIME, + DSA_NOTIFIER_BRIDGE_JOIN, + DSA_NOTIFIER_BRIDGE_LEAVE, + DSA_NOTIFIER_FDB_ADD, + DSA_NOTIFIER_FDB_DEL, + DSA_NOTIFIER_HOST_FDB_ADD, + DSA_NOTIFIER_HOST_FDB_DEL, + DSA_NOTIFIER_LAG_FDB_ADD, + DSA_NOTIFIER_LAG_FDB_DEL, + DSA_NOTIFIER_LAG_CHANGE, + DSA_NOTIFIER_LAG_JOIN, + DSA_NOTIFIER_LAG_LEAVE, + DSA_NOTIFIER_MDB_ADD, + DSA_NOTIFIER_MDB_DEL, + DSA_NOTIFIER_HOST_MDB_ADD, + DSA_NOTIFIER_HOST_MDB_DEL, + DSA_NOTIFIER_VLAN_ADD, + DSA_NOTIFIER_VLAN_DEL, + DSA_NOTIFIER_HOST_VLAN_ADD, + DSA_NOTIFIER_HOST_VLAN_DEL, + DSA_NOTIFIER_MTU, + DSA_NOTIFIER_TAG_PROTO, + DSA_NOTIFIER_TAG_PROTO_CONNECT, + DSA_NOTIFIER_TAG_PROTO_DISCONNECT, + DSA_NOTIFIER_TAG_8021Q_VLAN_ADD, + DSA_NOTIFIER_TAG_8021Q_VLAN_DEL, + DSA_NOTIFIER_MASTER_STATE_CHANGE, +}; + +/* DSA_NOTIFIER_AGEING_TIME */ +struct dsa_notifier_ageing_time_info { + unsigned int ageing_time; +}; + +/* DSA_NOTIFIER_BRIDGE_* */ +struct dsa_notifier_bridge_info { + const struct dsa_port *dp; + struct dsa_bridge bridge; + bool tx_fwd_offload; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_FDB_* */ +struct dsa_notifier_fdb_info { + const struct dsa_port *dp; + const unsigned char *addr; + u16 vid; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_LAG_FDB_* */ +struct dsa_notifier_lag_fdb_info { + struct dsa_lag *lag; + const unsigned char *addr; + u16 vid; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_MDB_* */ +struct dsa_notifier_mdb_info { + const struct dsa_port *dp; + const struct switchdev_obj_port_mdb *mdb; + struct dsa_db db; +}; + +/* DSA_NOTIFIER_LAG_* */ +struct dsa_notifier_lag_info { + const struct dsa_port *dp; + struct dsa_lag lag; + struct netdev_lag_upper_info *info; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_VLAN_* */ +struct dsa_notifier_vlan_info { + const struct dsa_port *dp; + const struct switchdev_obj_port_vlan *vlan; + struct netlink_ext_ack *extack; +}; + +/* DSA_NOTIFIER_MTU */ +struct dsa_notifier_mtu_info { + const struct dsa_port *dp; + int mtu; +}; + +/* DSA_NOTIFIER_TAG_PROTO_* */ +struct dsa_notifier_tag_proto_info { + const struct dsa_device_ops *tag_ops; +}; + +/* DSA_NOTIFIER_TAG_8021Q_VLAN_* */ +struct dsa_notifier_tag_8021q_vlan_info { + const struct dsa_port *dp; + u16 vid; +}; + +/* DSA_NOTIFIER_MASTER_STATE_CHANGE */ +struct dsa_notifier_master_state_info { + const struct net_device *master; + bool operational; +}; + +struct dsa_switchdev_event_work { + struct net_device *dev; + struct net_device *orig_dev; + struct work_struct work; + unsigned long event; + /* Specific for SWITCHDEV_FDB_ADD_TO_DEVICE and + * SWITCHDEV_FDB_DEL_TO_DEVICE + */ + unsigned char addr[ETH_ALEN]; + u16 vid; + bool host_addr; +}; + +enum dsa_standalone_event { + DSA_UC_ADD, + DSA_UC_DEL, + DSA_MC_ADD, + DSA_MC_DEL, +}; + +struct dsa_standalone_event_work { + struct work_struct work; + struct net_device *dev; + enum dsa_standalone_event 
event; + unsigned char addr[ETH_ALEN]; + u16 vid; +}; + +struct dsa_slave_priv { + /* Copy of CPU port xmit for faster access in slave transmit hot path */ + struct sk_buff * (*xmit)(struct sk_buff *skb, + struct net_device *dev); + + struct gro_cells gcells; + + /* DSA port data, such as switch, port index, etc. */ + struct dsa_port *dp; + +#ifdef CONFIG_NET_POLL_CONTROLLER + struct netpoll *netpoll; +#endif + + /* TC context */ + struct list_head mall_tc_list; +}; + +/* dsa.c */ +const struct dsa_device_ops *dsa_tag_driver_get(int tag_protocol); +void dsa_tag_driver_put(const struct dsa_device_ops *ops); +const struct dsa_device_ops *dsa_find_tagger_by_name(const char *buf); + +bool dsa_db_equal(const struct dsa_db *a, const struct dsa_db *b); + +bool dsa_schedule_work(struct work_struct *work); +const char *dsa_tag_protocol_to_str(const struct dsa_device_ops *ops); + +static inline int dsa_tag_protocol_overhead(const struct dsa_device_ops *ops) +{ + return ops->needed_headroom + ops->needed_tailroom; +} + +/* master.c */ +int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp); +void dsa_master_teardown(struct net_device *dev); +int dsa_master_lag_setup(struct net_device *lag_dev, struct dsa_port *cpu_dp, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack); +void dsa_master_lag_teardown(struct net_device *lag_dev, + struct dsa_port *cpu_dp); + +static inline struct net_device *dsa_master_find_slave(struct net_device *dev, + int device, int port) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct dsa_port *dp; + + list_for_each_entry(dp, &dst->ports, list) + if (dp->ds->index == device && dp->index == port && + dp->type == DSA_PORT_TYPE_USER) + return dp->slave; + + return NULL; +} + +/* netlink.c */ +extern struct rtnl_link_ops dsa_link_ops __read_mostly; + +/* port.c */ +bool dsa_port_supports_hwtstamp(struct dsa_port *dp, struct ifreq *ifr); +void dsa_port_set_tag_protocol(struct dsa_port *cpu_dp, + const struct dsa_device_ops *tag_ops); +int dsa_port_set_state(struct dsa_port *dp, u8 state, bool do_fast_age); +int dsa_port_set_mst_state(struct dsa_port *dp, + const struct switchdev_mst_state *state, + struct netlink_ext_ack *extack); +int dsa_port_enable_rt(struct dsa_port *dp, struct phy_device *phy); +int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy); +void dsa_port_disable_rt(struct dsa_port *dp); +void dsa_port_disable(struct dsa_port *dp); +int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br, + struct netlink_ext_ack *extack); +void dsa_port_pre_bridge_leave(struct dsa_port *dp, struct net_device *br); +void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br); +int dsa_port_lag_change(struct dsa_port *dp, + struct netdev_lag_lower_state_info *linfo); +int dsa_port_lag_join(struct dsa_port *dp, struct net_device *lag_dev, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack); +void dsa_port_pre_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); +void dsa_port_lag_leave(struct dsa_port *dp, struct net_device *lag_dev); +int dsa_port_vlan_filtering(struct dsa_port *dp, bool vlan_filtering, + struct netlink_ext_ack *extack); +bool dsa_port_skip_vlan_configuration(struct dsa_port *dp); +int dsa_port_ageing_time(struct dsa_port *dp, clock_t ageing_clock); +int dsa_port_mst_enable(struct dsa_port *dp, bool on, + struct netlink_ext_ack *extack); +int dsa_port_vlan_msti(struct dsa_port *dp, + const struct switchdev_vlan_msti *msti); 
+int dsa_port_mtu_change(struct dsa_port *dp, int new_mtu); +int dsa_port_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_standalone_host_fdb_add(struct dsa_port *dp, + const unsigned char *addr, u16 vid); +int dsa_port_standalone_host_fdb_del(struct dsa_port *dp, + const unsigned char *addr, u16 vid); +int dsa_port_bridge_host_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_bridge_host_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_lag_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_lag_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid); +int dsa_port_fdb_dump(struct dsa_port *dp, dsa_fdb_dump_cb_t *cb, void *data); +int dsa_port_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_standalone_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_standalone_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_bridge_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_bridge_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb); +int dsa_port_pre_bridge_flags(const struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack); +int dsa_port_bridge_flags(struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack); +int dsa_port_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack); +int dsa_port_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan); +int dsa_port_host_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack); +int dsa_port_host_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan); +int dsa_port_mrp_add(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp); +int dsa_port_mrp_del(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp); +int dsa_port_mrp_add_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp); +int dsa_port_mrp_del_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp); +int dsa_port_phylink_create(struct dsa_port *dp); +void dsa_port_phylink_destroy(struct dsa_port *dp); +int dsa_shared_port_link_register_of(struct dsa_port *dp); +void dsa_shared_port_link_unregister_of(struct dsa_port *dp); +int dsa_port_hsr_join(struct dsa_port *dp, struct net_device *hsr); +void dsa_port_hsr_leave(struct dsa_port *dp, struct net_device *hsr); +int dsa_port_tag_8021q_vlan_add(struct dsa_port *dp, u16 vid, bool broadcast); +void dsa_port_tag_8021q_vlan_del(struct dsa_port *dp, u16 vid, bool broadcast); +void dsa_port_set_host_flood(struct dsa_port *dp, bool uc, bool mc); +int dsa_port_change_master(struct dsa_port *dp, struct net_device *master, + struct netlink_ext_ack *extack); + +/* slave.c */ +extern const struct dsa_device_ops notag_netdev_ops; +extern struct notifier_block dsa_slave_switchdev_notifier; +extern struct notifier_block dsa_slave_switchdev_blocking_notifier; + +void dsa_slave_mii_bus_init(struct dsa_switch *ds); +int 
dsa_slave_create(struct dsa_port *dp); +void dsa_slave_destroy(struct net_device *slave_dev); +int dsa_slave_suspend(struct net_device *slave_dev); +int dsa_slave_resume(struct net_device *slave_dev); +int dsa_slave_register_notifier(void); +void dsa_slave_unregister_notifier(void); +void dsa_slave_sync_ha(struct net_device *dev); +void dsa_slave_unsync_ha(struct net_device *dev); +void dsa_slave_setup_tagger(struct net_device *slave); +int dsa_slave_change_mtu(struct net_device *dev, int new_mtu); +int dsa_slave_change_master(struct net_device *dev, struct net_device *master, + struct netlink_ext_ack *extack); +int dsa_slave_manage_vlan_filtering(struct net_device *dev, + bool vlan_filtering); + +static inline struct dsa_port *dsa_slave_to_port(const struct net_device *dev) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + + return p->dp; +} + +static inline struct net_device * +dsa_slave_to_master(const struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return dsa_port_to_master(dp); +} + +/* If under a bridge with vlan_filtering=0, make sure to send pvid-tagged + * frames as untagged, since the bridge will not untag them. + */ +static inline struct sk_buff *dsa_untag_bridge_pvid(struct sk_buff *skb) +{ + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct net_device *dev = skb->dev; + struct net_device *upper_dev; + u16 vid, pvid, proto; + int err; + + if (!br || br_vlan_enabled(br)) + return skb; + + err = br_vlan_get_proto(br, &proto); + if (err) + return skb; + + /* Move VLAN tag from data to hwaccel */ + if (!skb_vlan_tag_present(skb) && skb->protocol == htons(proto)) { + skb = skb_vlan_untag(skb); + if (!skb) + return NULL; + } + + if (!skb_vlan_tag_present(skb)) + return skb; + + vid = skb_vlan_tag_get_id(skb); + + /* We already run under an RCU read-side critical section since + * we are called from netif_receive_skb_list_internal(). + */ + err = br_vlan_get_pvid_rcu(dev, &pvid); + if (err) + return skb; + + if (vid != pvid) + return skb; + + /* The sad part about attempting to untag from DSA is that we + * don't know, unless we check, if the skb will end up in + * the bridge's data path - br_allowed_ingress() - or not. + * For example, there might be an 8021q upper for the + * default_pvid of the bridge, which will steal VLAN-tagged traffic + * from the bridge's data path. This is a configuration that DSA + * supports because vlan_filtering is 0. In that case, we should + * definitely keep the tag, to make sure it keeps working. + */ + upper_dev = __vlan_find_dev_deep_rcu(br, htons(proto), vid); + if (upper_dev) + return skb; + + __vlan_hwaccel_clear_tag(skb); + + return skb; +} + +/* For switches without hardware support for DSA tagging to be able + * to support termination through the bridge. + */ +static inline struct net_device * +dsa_find_designated_bridge_port_by_vid(struct net_device *master, u16 vid) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct bridge_vlan_info vinfo; + struct net_device *slave; + struct dsa_port *dp; + int err; + + list_for_each_entry(dp, &dst->ports, list) { + if (dp->type != DSA_PORT_TYPE_USER) + continue; + + if (!dp->bridge) + continue; + + if (dp->stp_state != BR_STATE_LEARNING && + dp->stp_state != BR_STATE_FORWARDING) + continue; + + /* Since the bridge might learn this packet, keep the CPU port + * affinity with the port that will be used for the reply on + * xmit. 
+ */ + if (dp->cpu_dp != cpu_dp) + continue; + + slave = dp->slave; + + err = br_vlan_get_info_rcu(slave, vid, &vinfo); + if (err) + continue; + + return slave; + } + + return NULL; +} + +/* If the ingress port offloads the bridge, we mark the frame as autonomously + * forwarded by hardware, so the software bridge doesn't forward in twice, back + * to us, because we already did. However, if we're in fallback mode and we do + * software bridging, we are not offloading it, therefore the dp->bridge + * pointer is not populated, and flooding needs to be done by software (we are + * effectively operating in standalone ports mode). + */ +static inline void dsa_default_offload_fwd_mark(struct sk_buff *skb) +{ + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + + skb->offload_fwd_mark = !!(dp->bridge); +} + +/* Helper for removing DSA header tags from packets in the RX path. + * Must not be called before skb_pull(len). + * skb->data + * | + * v + * | | | | | | | | | | | | | | | | | | | + * +-----------------------+-----------------------+---------------+-------+ + * | Destination MAC | Source MAC | DSA header | EType | + * +-----------------------+-----------------------+---------------+-------+ + * | | + * <----- len -----> <----- len -----> + * | + * >>>>>>> v + * >>>>>>> | | | | | | | | | | | | | | | + * >>>>>>> +-----------------------+-----------------------+-------+ + * >>>>>>> | Destination MAC | Source MAC | EType | + * +-----------------------+-----------------------+-------+ + * ^ + * | + * skb->data + */ +static inline void dsa_strip_etype_header(struct sk_buff *skb, int len) +{ + memmove(skb->data - ETH_HLEN, skb->data - ETH_HLEN - len, 2 * ETH_ALEN); +} + +/* Helper for creating space for DSA header tags in TX path packets. + * Must not be called before skb_push(len). + * + * Before: + * + * <<<<<<< | | | | | | | | | | | | | | | + * ^ <<<<<<< +-----------------------+-----------------------+-------+ + * | <<<<<<< | Destination MAC | Source MAC | EType | + * | +-----------------------+-----------------------+-------+ + * <----- len -----> + * | + * | + * skb->data + * + * After: + * + * | | | | | | | | | | | | | | | | | | | + * +-----------------------+-----------------------+---------------+-------+ + * | Destination MAC | Source MAC | DSA header | EType | + * +-----------------------+-----------------------+---------------+-------+ + * ^ | | + * | <----- len -----> + * skb->data + */ +static inline void dsa_alloc_etype_header(struct sk_buff *skb, int len) +{ + memmove(skb->data, skb->data + len, 2 * ETH_ALEN); +} + +/* On RX, eth_type_trans() on the DSA master pulls ETH_HLEN bytes starting from + * skb_mac_header(skb), which leaves skb->data pointing at the first byte after + * what the DSA master perceives as the EtherType (the beginning of the L3 + * protocol). Since DSA EtherType header taggers treat the EtherType as part of + * the DSA tag itself, and the EtherType is 2 bytes in length, the DSA header + * is located 2 bytes behind skb->data. Note that EtherType in this context + * means the first 2 bytes of the DSA header, not the encapsulated EtherType + * that will become visible after the DSA header is stripped. + */ +static inline void *dsa_etype_header_pos_rx(struct sk_buff *skb) +{ + return skb->data - 2; +} + +/* On TX, skb->data points to skb_mac_header(skb), which means that EtherType + * header taggers start exactly where the EtherType is (the EtherType is + * treated as part of the DSA header). 
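+ *
+ * Taken together with the helpers above, an EtherType-style tagger would
+ * typically shape its hot paths roughly like this (illustrative sketch
+ * only; "xyz" and XYZ_HLEN are hypothetical names, not a tagger from this
+ * file):
+ *
+ * xmit:	skb_push(skb, XYZ_HLEN);
+ *		dsa_alloc_etype_header(skb, XYZ_HLEN);
+ *		xyz_tag = dsa_etype_header_pos_tx(skb);
+ *		... fill in XYZ_HLEN bytes of tag ...
+ *
+ * rcv:		xyz_tag = dsa_etype_header_pos_rx(skb);
+ *		... parse tag, set skb->dev via dsa_master_find_slave() ...
+ *		skb_pull_rcsum(skb, XYZ_HLEN);
+ *		dsa_strip_etype_header(skb, XYZ_HLEN);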
+ */ +static inline void *dsa_etype_header_pos_tx(struct sk_buff *skb) +{ + return skb->data + 2 * ETH_ALEN; +} + +/* switch.c */ +int dsa_switch_register_notifier(struct dsa_switch *ds); +void dsa_switch_unregister_notifier(struct dsa_switch *ds); + +static inline bool dsa_switch_supports_uc_filtering(struct dsa_switch *ds) +{ + return ds->ops->port_fdb_add && ds->ops->port_fdb_del && + ds->fdb_isolation && !ds->vlan_filtering_is_global && + !ds->needs_standalone_vlan_filtering; +} + +static inline bool dsa_switch_supports_mc_filtering(struct dsa_switch *ds) +{ + return ds->ops->port_mdb_add && ds->ops->port_mdb_del && + ds->fdb_isolation && !ds->vlan_filtering_is_global && + !ds->needs_standalone_vlan_filtering; +} + +/* dsa2.c */ +void dsa_lag_map(struct dsa_switch_tree *dst, struct dsa_lag *lag); +void dsa_lag_unmap(struct dsa_switch_tree *dst, struct dsa_lag *lag); +struct dsa_lag *dsa_tree_lag_find(struct dsa_switch_tree *dst, + const struct net_device *lag_dev); +struct net_device *dsa_tree_find_first_master(struct dsa_switch_tree *dst); +int dsa_tree_notify(struct dsa_switch_tree *dst, unsigned long e, void *v); +int dsa_broadcast(unsigned long e, void *v); +int dsa_tree_change_tag_proto(struct dsa_switch_tree *dst, + const struct dsa_device_ops *tag_ops, + const struct dsa_device_ops *old_tag_ops); +void dsa_tree_master_admin_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up); +void dsa_tree_master_oper_state_change(struct dsa_switch_tree *dst, + struct net_device *master, + bool up); +unsigned int dsa_bridge_num_get(const struct net_device *bridge_dev, int max); +void dsa_bridge_num_put(const struct net_device *bridge_dev, + unsigned int bridge_num); +struct dsa_bridge *dsa_tree_bridge_find(struct dsa_switch_tree *dst, + const struct net_device *br); + +/* tag_8021q.c */ +int dsa_switch_tag_8021q_vlan_add(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info); +int dsa_switch_tag_8021q_vlan_del(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info); + +extern struct list_head dsa_tree_list; + +#endif diff --git a/net/dsa/master.c b/net/dsa/master.c new file mode 100644 index 000000000..421de1665 --- /dev/null +++ b/net/dsa/master.c @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handling of a master device, switching frames via its switch fabric CPU port + * + * Copyright (c) 2017 Savoir-faire Linux Inc. 
+ * Vivien Didelot + */ + +#include "dsa_priv.h" + +static int dsa_master_get_regs_len(struct net_device *dev) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + int port = cpu_dp->index; + int ret = 0; + int len; + + if (ops->get_regs_len) { + len = ops->get_regs_len(dev); + if (len < 0) + return len; + ret += len; + } + + ret += sizeof(struct ethtool_drvinfo); + ret += sizeof(struct ethtool_regs); + + if (ds->ops->get_regs_len) { + len = ds->ops->get_regs_len(ds, port); + if (len < 0) + return len; + ret += len; + } + + return ret; +} + +static void dsa_master_get_regs(struct net_device *dev, + struct ethtool_regs *regs, void *data) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + struct ethtool_drvinfo *cpu_info; + struct ethtool_regs *cpu_regs; + int port = cpu_dp->index; + int len; + + if (ops->get_regs_len && ops->get_regs) { + len = ops->get_regs_len(dev); + if (len < 0) + return; + regs->len = len; + ops->get_regs(dev, regs, data); + data += regs->len; + } + + cpu_info = (struct ethtool_drvinfo *)data; + strscpy(cpu_info->driver, "dsa", sizeof(cpu_info->driver)); + data += sizeof(*cpu_info); + cpu_regs = (struct ethtool_regs *)data; + data += sizeof(*cpu_regs); + + if (ds->ops->get_regs_len && ds->ops->get_regs) { + len = ds->ops->get_regs_len(ds, port); + if (len < 0) + return; + cpu_regs->len = len; + ds->ops->get_regs(ds, port, cpu_regs, data); + } +} + +static void dsa_master_get_ethtool_stats(struct net_device *dev, + struct ethtool_stats *stats, + uint64_t *data) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + int port = cpu_dp->index; + int count = 0; + + if (ops->get_sset_count && ops->get_ethtool_stats) { + count = ops->get_sset_count(dev, ETH_SS_STATS); + ops->get_ethtool_stats(dev, stats, data); + } + + if (ds->ops->get_ethtool_stats) + ds->ops->get_ethtool_stats(ds, port, data + count); +} + +static void dsa_master_get_ethtool_phy_stats(struct net_device *dev, + struct ethtool_stats *stats, + uint64_t *data) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + int port = cpu_dp->index; + int count = 0; + + if (dev->phydev && !ops->get_ethtool_phy_stats) { + count = phy_ethtool_get_sset_count(dev->phydev); + if (count >= 0) + phy_ethtool_get_stats(dev->phydev, stats, data); + } else if (ops->get_sset_count && ops->get_ethtool_phy_stats) { + count = ops->get_sset_count(dev, ETH_SS_PHY_STATS); + ops->get_ethtool_phy_stats(dev, stats, data); + } + + if (count < 0) + count = 0; + + if (ds->ops->get_ethtool_phy_stats) + ds->ops->get_ethtool_phy_stats(ds, port, data + count); +} + +static int dsa_master_get_sset_count(struct net_device *dev, int sset) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + int count = 0; + + if (sset == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats) + count = phy_ethtool_get_sset_count(dev->phydev); + else if (ops->get_sset_count) + count = ops->get_sset_count(dev, sset); + + if (count < 0) + count = 0; + + if (ds->ops->get_sset_count) + count += ds->ops->get_sset_count(ds, cpu_dp->index, sset); + + return count; +} + +static void dsa_master_get_strings(struct 
net_device *dev, uint32_t stringset, + uint8_t *data) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + const struct ethtool_ops *ops = cpu_dp->orig_ethtool_ops; + struct dsa_switch *ds = cpu_dp->ds; + int port = cpu_dp->index; + int len = ETH_GSTRING_LEN; + int mcount = 0, count, i; + uint8_t pfx[4]; + uint8_t *ndata; + + snprintf(pfx, sizeof(pfx), "p%.2d", port); + /* We do not want to be NULL-terminated, since this is a prefix */ + pfx[sizeof(pfx) - 1] = '_'; + + if (stringset == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats) { + mcount = phy_ethtool_get_sset_count(dev->phydev); + if (mcount < 0) + mcount = 0; + else + phy_ethtool_get_strings(dev->phydev, data); + } else if (ops->get_sset_count && ops->get_strings) { + mcount = ops->get_sset_count(dev, stringset); + if (mcount < 0) + mcount = 0; + ops->get_strings(dev, stringset, data); + } + + if (ds->ops->get_strings) { + ndata = data + mcount * len; + /* This function copies ETH_GSTRINGS_LEN bytes, we will mangle + * the output after to prepend our CPU port prefix we + * constructed earlier + */ + ds->ops->get_strings(ds, port, stringset, ndata); + count = ds->ops->get_sset_count(ds, port, stringset); + if (count < 0) + return; + for (i = 0; i < count; i++) { + memmove(ndata + (i * len + sizeof(pfx)), + ndata + i * len, len - sizeof(pfx)); + memcpy(ndata + i * len, pfx, sizeof(pfx)); + } + } +} + +static int dsa_master_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch *ds = cpu_dp->ds; + struct dsa_switch_tree *dst; + int err = -EOPNOTSUPP; + struct dsa_port *dp; + + dst = ds->dst; + + switch (cmd) { + case SIOCGHWTSTAMP: + case SIOCSHWTSTAMP: + /* Deny PTP operations on master if there is at least one + * switch in the tree that is PTP capable. 
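+ *
+ * Hardware timestamping is configured through the user ports instead,
+ * whose ioctl path hands these requests to the switch driver's
+ * port_hwtstamp_get/set ops; rejecting them here with -EBUSY avoids two
+ * entities configuring timestamping for the same traffic.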
+ */ + list_for_each_entry(dp, &dst->ports, list) + if (dsa_port_supports_hwtstamp(dp, ifr)) + return -EBUSY; + break; + } + + if (dev->netdev_ops->ndo_eth_ioctl) + err = dev->netdev_ops->ndo_eth_ioctl(dev, ifr, cmd); + + return err; +} + +static const struct dsa_netdevice_ops dsa_netdev_ops = { + .ndo_eth_ioctl = dsa_master_ioctl, +}; + +static int dsa_master_ethtool_setup(struct net_device *dev) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch *ds = cpu_dp->ds; + struct ethtool_ops *ops; + + if (netif_is_lag_master(dev)) + return 0; + + ops = devm_kzalloc(ds->dev, sizeof(*ops), GFP_KERNEL); + if (!ops) + return -ENOMEM; + + cpu_dp->orig_ethtool_ops = dev->ethtool_ops; + if (cpu_dp->orig_ethtool_ops) + memcpy(ops, cpu_dp->orig_ethtool_ops, sizeof(*ops)); + + ops->get_regs_len = dsa_master_get_regs_len; + ops->get_regs = dsa_master_get_regs; + ops->get_sset_count = dsa_master_get_sset_count; + ops->get_ethtool_stats = dsa_master_get_ethtool_stats; + ops->get_strings = dsa_master_get_strings; + ops->get_ethtool_phy_stats = dsa_master_get_ethtool_phy_stats; + + dev->ethtool_ops = ops; + + return 0; +} + +static void dsa_master_ethtool_teardown(struct net_device *dev) +{ + struct dsa_port *cpu_dp = dev->dsa_ptr; + + if (netif_is_lag_master(dev)) + return; + + dev->ethtool_ops = cpu_dp->orig_ethtool_ops; + cpu_dp->orig_ethtool_ops = NULL; +} + +static void dsa_netdev_ops_set(struct net_device *dev, + const struct dsa_netdevice_ops *ops) +{ + if (netif_is_lag_master(dev)) + return; + + dev->dsa_ptr->netdev_ops = ops; +} + +/* Keep the master always promiscuous if the tagging protocol requires that + * (garbles MAC DA) or if it doesn't support unicast filtering, case in which + * it would revert to promiscuous mode as soon as we call dev_uc_add() on it + * anyway. + */ +static void dsa_master_set_promiscuity(struct net_device *dev, int inc) +{ + const struct dsa_device_ops *ops = dev->dsa_ptr->tag_ops; + + if ((dev->priv_flags & IFF_UNICAST_FLT) && !ops->promisc_on_master) + return; + + ASSERT_RTNL(); + + dev_set_promiscuity(dev, inc); +} + +static ssize_t tagging_show(struct device *d, struct device_attribute *attr, + char *buf) +{ + struct net_device *dev = to_net_dev(d); + struct dsa_port *cpu_dp = dev->dsa_ptr; + + return sprintf(buf, "%s\n", + dsa_tag_protocol_to_str(cpu_dp->tag_ops)); +} + +static ssize_t tagging_store(struct device *d, struct device_attribute *attr, + const char *buf, size_t count) +{ + const struct dsa_device_ops *new_tag_ops, *old_tag_ops; + struct net_device *dev = to_net_dev(d); + struct dsa_port *cpu_dp = dev->dsa_ptr; + int err; + + old_tag_ops = cpu_dp->tag_ops; + new_tag_ops = dsa_find_tagger_by_name(buf); + /* Bad tagger name, or module is not loaded? */ + if (IS_ERR(new_tag_ops)) + return PTR_ERR(new_tag_ops); + + if (new_tag_ops == old_tag_ops) + /* Drop the temporarily held duplicate reference, since + * the DSA switch tree uses this tagger. + */ + goto out; + + err = dsa_tree_change_tag_proto(cpu_dp->ds->dst, new_tag_ops, + old_tag_ops); + if (err) { + /* On failure the old tagger is restored, so we don't need the + * driver for the new one. 
+ */ + dsa_tag_driver_put(new_tag_ops); + return err; + } + + /* On success we no longer need the module for the old tagging protocol + */ +out: + dsa_tag_driver_put(old_tag_ops); + return count; +} +static DEVICE_ATTR_RW(tagging); + +static struct attribute *dsa_slave_attrs[] = { + &dev_attr_tagging.attr, + NULL +}; + +static const struct attribute_group dsa_group = { + .name = "dsa", + .attrs = dsa_slave_attrs, +}; + +static void dsa_master_reset_mtu(struct net_device *dev) +{ + int err; + + err = dev_set_mtu(dev, ETH_DATA_LEN); + if (err) + netdev_dbg(dev, + "Unable to reset MTU to exclude DSA overheads\n"); +} + +int dsa_master_setup(struct net_device *dev, struct dsa_port *cpu_dp) +{ + const struct dsa_device_ops *tag_ops = cpu_dp->tag_ops; + struct dsa_switch *ds = cpu_dp->ds; + struct device_link *consumer_link; + int mtu, ret; + + mtu = ETH_DATA_LEN + dsa_tag_protocol_overhead(tag_ops); + + /* The DSA master must use SET_NETDEV_DEV for this to work. */ + if (!netif_is_lag_master(dev)) { + consumer_link = device_link_add(ds->dev, dev->dev.parent, + DL_FLAG_AUTOREMOVE_CONSUMER); + if (!consumer_link) + netdev_err(dev, + "Failed to create a device link to DSA switch %s\n", + dev_name(ds->dev)); + } + + /* The switch driver may not implement ->port_change_mtu(), case in + * which dsa_slave_change_mtu() will not update the master MTU either, + * so we need to do that here. + */ + ret = dev_set_mtu(dev, mtu); + if (ret) + netdev_warn(dev, "error %d setting MTU to %d to include DSA overhead\n", + ret, mtu); + + /* If we use a tagging format that doesn't have an ethertype + * field, make sure that all packets from this point on get + * sent to the tag format's receive function. + */ + wmb(); + + dev->dsa_ptr = cpu_dp; + + dsa_master_set_promiscuity(dev, 1); + + ret = dsa_master_ethtool_setup(dev); + if (ret) + goto out_err_reset_promisc; + + dsa_netdev_ops_set(dev, &dsa_netdev_ops); + + ret = sysfs_create_group(&dev->dev.kobj, &dsa_group); + if (ret) + goto out_err_ndo_teardown; + + return ret; + +out_err_ndo_teardown: + dsa_netdev_ops_set(dev, NULL); + dsa_master_ethtool_teardown(dev); +out_err_reset_promisc: + dsa_master_set_promiscuity(dev, -1); + return ret; +} + +void dsa_master_teardown(struct net_device *dev) +{ + sysfs_remove_group(&dev->dev.kobj, &dsa_group); + dsa_netdev_ops_set(dev, NULL); + dsa_master_ethtool_teardown(dev); + dsa_master_reset_mtu(dev); + dsa_master_set_promiscuity(dev, -1); + + dev->dsa_ptr = NULL; + + /* If we used a tagging format that doesn't have an ethertype + * field, make sure that all packets from this point get sent + * without the tag and go through the regular receive path. + */ + wmb(); +} + +int dsa_master_lag_setup(struct net_device *lag_dev, struct dsa_port *cpu_dp, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack) +{ + bool master_setup = false; + int err; + + if (!netdev_uses_dsa(lag_dev)) { + err = dsa_master_setup(lag_dev, cpu_dp); + if (err) + return err; + + master_setup = true; + } + + err = dsa_port_lag_join(cpu_dp, lag_dev, uinfo, extack); + if (err) { + if (extack && !extack->_msg) + NL_SET_ERR_MSG_MOD(extack, + "CPU port failed to join LAG"); + goto out_master_teardown; + } + + return 0; + +out_master_teardown: + if (master_setup) + dsa_master_teardown(lag_dev); + return err; +} + +/* Tear down a master if there isn't any other user port on it, + * optionally also destroying LAG information. 
+ */ +void dsa_master_lag_teardown(struct net_device *lag_dev, + struct dsa_port *cpu_dp) +{ + struct net_device *upper; + struct list_head *iter; + + dsa_port_lag_leave(cpu_dp, lag_dev); + + netdev_for_each_upper_dev_rcu(lag_dev, upper, iter) + if (dsa_slave_dev_check(upper)) + return; + + dsa_master_teardown(lag_dev); +} diff --git a/net/dsa/netlink.c b/net/dsa/netlink.c new file mode 100644 index 000000000..ecf9ed1de --- /dev/null +++ b/net/dsa/netlink.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2022 NXP + */ +#include +#include + +#include "dsa_priv.h" + +static const struct nla_policy dsa_policy[IFLA_DSA_MAX + 1] = { + [IFLA_DSA_MASTER] = { .type = NLA_U32 }, +}; + +static int dsa_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + int err; + + if (!data) + return 0; + + if (data[IFLA_DSA_MASTER]) { + u32 ifindex = nla_get_u32(data[IFLA_DSA_MASTER]); + struct net_device *master; + + master = __dev_get_by_index(dev_net(dev), ifindex); + if (!master) + return -EINVAL; + + err = dsa_slave_change_master(dev, master, extack); + if (err) + return err; + } + + return 0; +} + +static size_t dsa_get_size(const struct net_device *dev) +{ + return nla_total_size(sizeof(u32)) + /* IFLA_DSA_MASTER */ + 0; +} + +static int dsa_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct net_device *master = dsa_slave_to_master(dev); + + if (nla_put_u32(skb, IFLA_DSA_MASTER, master->ifindex)) + return -EMSGSIZE; + + return 0; +} + +struct rtnl_link_ops dsa_link_ops __read_mostly = { + .kind = "dsa", + .priv_size = sizeof(struct dsa_port), + .maxtype = IFLA_DSA_MAX, + .policy = dsa_policy, + .changelink = dsa_changelink, + .get_size = dsa_get_size, + .fill_info = dsa_fill_info, + .netns_refund = true, +}; diff --git a/net/dsa/port.c b/net/dsa/port.c new file mode 100644 index 000000000..750fe68d9 --- /dev/null +++ b/net/dsa/port.c @@ -0,0 +1,2083 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handling of a single switch port + * + * Copyright (c) 2017 Savoir-faire Linux Inc. + * Vivien Didelot + */ + +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +/** + * dsa_port_notify - Notify the switching fabric of changes to a port + * @dp: port on which change occurred + * @e: event, must be of type DSA_NOTIFIER_* + * @v: event-specific value. + * + * Notify all switches in the DSA tree that this port's switch belongs to, + * including this switch itself, of an event. Allows the other switches to + * reconfigure themselves for cross-chip operations. Can also be used to + * reconfigure ports without net_devices (CPU ports, DSA links) whenever + * a user port's state changes. + */ +static int dsa_port_notify(const struct dsa_port *dp, unsigned long e, void *v) +{ + return dsa_tree_notify(dp->ds->dst, e, v); +} + +static void dsa_port_notify_bridge_fdb_flush(const struct dsa_port *dp, u16 vid) +{ + struct net_device *brport_dev = dsa_port_to_bridge_port(dp); + struct switchdev_notifier_fdb_info info = { + .vid = vid, + }; + + /* When the port becomes standalone it has already left the bridge. + * Don't notify the bridge in that case. 
+ */ + if (!brport_dev) + return; + + call_switchdev_notifiers(SWITCHDEV_FDB_FLUSH_TO_BRIDGE, + brport_dev, &info.info, NULL); +} + +static void dsa_port_fast_age(const struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_fast_age) + return; + + ds->ops->port_fast_age(ds, dp->index); + + /* flush all VLANs */ + dsa_port_notify_bridge_fdb_flush(dp, 0); +} + +static int dsa_port_vlan_fast_age(const struct dsa_port *dp, u16 vid) +{ + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_vlan_fast_age) + return -EOPNOTSUPP; + + err = ds->ops->port_vlan_fast_age(ds, dp->index, vid); + + if (!err) + dsa_port_notify_bridge_fdb_flush(dp, vid); + + return err; +} + +static int dsa_port_msti_fast_age(const struct dsa_port *dp, u16 msti) +{ + DECLARE_BITMAP(vids, VLAN_N_VID) = { 0 }; + int err, vid; + + err = br_mst_get_info(dsa_port_bridge_dev_get(dp), msti, vids); + if (err) + return err; + + for_each_set_bit(vid, vids, VLAN_N_VID) { + err = dsa_port_vlan_fast_age(dp, vid); + if (err) + return err; + } + + return 0; +} + +static bool dsa_port_can_configure_learning(struct dsa_port *dp) +{ + struct switchdev_brport_flags flags = { + .mask = BR_LEARNING, + }; + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_bridge_flags || !ds->ops->port_pre_bridge_flags) + return false; + + err = ds->ops->port_pre_bridge_flags(ds, dp->index, flags, NULL); + return !err; +} + +bool dsa_port_supports_hwtstamp(struct dsa_port *dp, struct ifreq *ifr) +{ + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_hwtstamp_get || !ds->ops->port_hwtstamp_set) + return false; + + /* "See through" shim implementations of the "get" method. + * This will clobber the ifreq structure, but we will either return an + * error, or the master will overwrite it with proper values. + */ + err = ds->ops->port_hwtstamp_get(ds, dp->index, ifr); + return err != -EOPNOTSUPP; +} + +int dsa_port_set_state(struct dsa_port *dp, u8 state, bool do_fast_age) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (!ds->ops->port_stp_state_set) + return -EOPNOTSUPP; + + ds->ops->port_stp_state_set(ds, port, state); + + if (!dsa_port_can_configure_learning(dp) || + (do_fast_age && dp->learning)) { + /* Fast age FDB entries or flush appropriate forwarding database + * for the given port, if we are moving it from Learning or + * Forwarding state, to Disabled or Blocking or Listening state. + * Ports that were standalone before the STP state change don't + * need to fast age the FDB, since address learning is off in + * standalone mode. 
+ */ + + if ((dp->stp_state == BR_STATE_LEARNING || + dp->stp_state == BR_STATE_FORWARDING) && + (state == BR_STATE_DISABLED || + state == BR_STATE_BLOCKING || + state == BR_STATE_LISTENING)) + dsa_port_fast_age(dp); + } + + dp->stp_state = state; + + return 0; +} + +static void dsa_port_set_state_now(struct dsa_port *dp, u8 state, + bool do_fast_age) +{ + struct dsa_switch *ds = dp->ds; + int err; + + err = dsa_port_set_state(dp, state, do_fast_age); + if (err && err != -EOPNOTSUPP) { + dev_err(ds->dev, "port %d failed to set STP state %u: %pe\n", + dp->index, state, ERR_PTR(err)); + } +} + +int dsa_port_set_mst_state(struct dsa_port *dp, + const struct switchdev_mst_state *state, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + u8 prev_state; + int err; + + if (!ds->ops->port_mst_state_set) + return -EOPNOTSUPP; + + err = br_mst_get_state(dsa_port_to_bridge_port(dp), state->msti, + &prev_state); + if (err) + return err; + + err = ds->ops->port_mst_state_set(ds, dp->index, state); + if (err) + return err; + + if (!(dp->learning && + (prev_state == BR_STATE_LEARNING || + prev_state == BR_STATE_FORWARDING) && + (state->state == BR_STATE_DISABLED || + state->state == BR_STATE_BLOCKING || + state->state == BR_STATE_LISTENING))) + return 0; + + err = dsa_port_msti_fast_age(dp, state->msti); + if (err) + NL_SET_ERR_MSG_MOD(extack, + "Unable to flush associated VLANs"); + + return 0; +} + +int dsa_port_enable_rt(struct dsa_port *dp, struct phy_device *phy) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + int err; + + if (ds->ops->port_enable) { + err = ds->ops->port_enable(ds, port, phy); + if (err) + return err; + } + + if (!dp->bridge) + dsa_port_set_state_now(dp, BR_STATE_FORWARDING, false); + + if (dp->pl) + phylink_start(dp->pl); + + return 0; +} + +int dsa_port_enable(struct dsa_port *dp, struct phy_device *phy) +{ + int err; + + rtnl_lock(); + err = dsa_port_enable_rt(dp, phy); + rtnl_unlock(); + + return err; +} + +void dsa_port_disable_rt(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (dp->pl) + phylink_stop(dp->pl); + + if (!dp->bridge) + dsa_port_set_state_now(dp, BR_STATE_DISABLED, false); + + if (ds->ops->port_disable) + ds->ops->port_disable(ds, port); +} + +void dsa_port_disable(struct dsa_port *dp) +{ + rtnl_lock(); + dsa_port_disable_rt(dp); + rtnl_unlock(); +} + +static void dsa_port_reset_vlan_filtering(struct dsa_port *dp, + struct dsa_bridge bridge) +{ + struct netlink_ext_ack extack = {0}; + bool change_vlan_filtering = false; + struct dsa_switch *ds = dp->ds; + struct dsa_port *other_dp; + bool vlan_filtering; + int err; + + if (ds->needs_standalone_vlan_filtering && + !br_vlan_enabled(bridge.dev)) { + change_vlan_filtering = true; + vlan_filtering = true; + } else if (!ds->needs_standalone_vlan_filtering && + br_vlan_enabled(bridge.dev)) { + change_vlan_filtering = true; + vlan_filtering = false; + } + + /* If the bridge was vlan_filtering, the bridge core doesn't trigger an + * event for changing vlan_filtering setting upon slave ports leaving + * it. That is a good thing, because that lets us handle it and also + * handle the case where the switch's vlan_filtering setting is global + * (not per port). When that happens, the correct moment to trigger the + * vlan_filtering callback is only when the last port leaves the last + * VLAN-aware bridge. 
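+ * The loop below implements that: if any port of this switch is still
+ * under a VLAN-aware bridge, the global vlan_filtering setting is left
+ * untouched.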
+ */ + if (change_vlan_filtering && ds->vlan_filtering_is_global) { + dsa_switch_for_each_port(other_dp, ds) { + struct net_device *br = dsa_port_bridge_dev_get(other_dp); + + if (br && br_vlan_enabled(br)) { + change_vlan_filtering = false; + break; + } + } + } + + if (!change_vlan_filtering) + return; + + err = dsa_port_vlan_filtering(dp, vlan_filtering, &extack); + if (extack._msg) { + dev_err(ds->dev, "port %d: %s\n", dp->index, + extack._msg); + } + if (err && err != -EOPNOTSUPP) { + dev_err(ds->dev, + "port %d failed to reset VLAN filtering to %d: %pe\n", + dp->index, vlan_filtering, ERR_PTR(err)); + } +} + +static int dsa_port_inherit_brport_flags(struct dsa_port *dp, + struct netlink_ext_ack *extack) +{ + const unsigned long mask = BR_LEARNING | BR_FLOOD | BR_MCAST_FLOOD | + BR_BCAST_FLOOD | BR_PORT_LOCKED; + struct net_device *brport_dev = dsa_port_to_bridge_port(dp); + int flag, err; + + for_each_set_bit(flag, &mask, 32) { + struct switchdev_brport_flags flags = {0}; + + flags.mask = BIT(flag); + + if (br_port_flag_is_set(brport_dev, BIT(flag))) + flags.val = BIT(flag); + + err = dsa_port_bridge_flags(dp, flags, extack); + if (err && err != -EOPNOTSUPP) + return err; + } + + return 0; +} + +static void dsa_port_clear_brport_flags(struct dsa_port *dp) +{ + const unsigned long val = BR_FLOOD | BR_MCAST_FLOOD | BR_BCAST_FLOOD; + const unsigned long mask = BR_LEARNING | BR_FLOOD | BR_MCAST_FLOOD | + BR_BCAST_FLOOD | BR_PORT_LOCKED; + int flag, err; + + for_each_set_bit(flag, &mask, 32) { + struct switchdev_brport_flags flags = {0}; + + flags.mask = BIT(flag); + flags.val = val & BIT(flag); + + err = dsa_port_bridge_flags(dp, flags, NULL); + if (err && err != -EOPNOTSUPP) + dev_err(dp->ds->dev, + "failed to clear bridge port flag %lu: %pe\n", + flags.val, ERR_PTR(err)); + } +} + +static int dsa_port_switchdev_sync_attrs(struct dsa_port *dp, + struct netlink_ext_ack *extack) +{ + struct net_device *brport_dev = dsa_port_to_bridge_port(dp); + struct net_device *br = dsa_port_bridge_dev_get(dp); + int err; + + err = dsa_port_inherit_brport_flags(dp, extack); + if (err) + return err; + + err = dsa_port_set_state(dp, br_port_get_stp_state(brport_dev), false); + if (err && err != -EOPNOTSUPP) + return err; + + err = dsa_port_vlan_filtering(dp, br_vlan_enabled(br), extack); + if (err && err != -EOPNOTSUPP) + return err; + + err = dsa_port_ageing_time(dp, br_get_ageing_time(br)); + if (err && err != -EOPNOTSUPP) + return err; + + return 0; +} + +static void dsa_port_switchdev_unsync_attrs(struct dsa_port *dp, + struct dsa_bridge bridge) +{ + /* Configure the port for standalone mode (no address learning, + * flood everything). + * The bridge only emits SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS events + * when the user requests it through netlink or sysfs, but not + * automatically at port join or leave, so we need to handle resetting + * the brport flags ourselves. But we even prefer it that way, because + * otherwise, some setups might never get the notification they need, + * for example, when a port leaves a LAG that offloads the bridge, + * it becomes standalone, but as far as the bridge is concerned, no + * port ever left. 
+ */ + dsa_port_clear_brport_flags(dp); + + /* Port left the bridge, put in BR_STATE_DISABLED by the bridge layer, + * so allow it to be in BR_STATE_FORWARDING to be kept functional + */ + dsa_port_set_state_now(dp, BR_STATE_FORWARDING, true); + + dsa_port_reset_vlan_filtering(dp, bridge); + + /* Ageing time may be global to the switch chip, so don't change it + * here because we have no good reason (or value) to change it to. + */ +} + +static int dsa_port_bridge_create(struct dsa_port *dp, + struct net_device *br, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_bridge *bridge; + + bridge = dsa_tree_bridge_find(ds->dst, br); + if (bridge) { + refcount_inc(&bridge->refcount); + dp->bridge = bridge; + return 0; + } + + bridge = kzalloc(sizeof(*bridge), GFP_KERNEL); + if (!bridge) + return -ENOMEM; + + refcount_set(&bridge->refcount, 1); + + bridge->dev = br; + + bridge->num = dsa_bridge_num_get(br, ds->max_num_bridges); + if (ds->max_num_bridges && !bridge->num) { + NL_SET_ERR_MSG_MOD(extack, + "Range of offloadable bridges exceeded"); + kfree(bridge); + return -EOPNOTSUPP; + } + + dp->bridge = bridge; + + return 0; +} + +static void dsa_port_bridge_destroy(struct dsa_port *dp, + const struct net_device *br) +{ + struct dsa_bridge *bridge = dp->bridge; + + dp->bridge = NULL; + + if (!refcount_dec_and_test(&bridge->refcount)) + return; + + if (bridge->num) + dsa_bridge_num_put(br, bridge->num); + + kfree(bridge); +} + +static bool dsa_port_supports_mst(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + + return ds->ops->vlan_msti_set && + ds->ops->port_mst_state_set && + ds->ops->port_vlan_fast_age && + dsa_port_can_configure_learning(dp); +} + +int dsa_port_bridge_join(struct dsa_port *dp, struct net_device *br, + struct netlink_ext_ack *extack) +{ + struct dsa_notifier_bridge_info info = { + .dp = dp, + .extack = extack, + }; + struct net_device *dev = dp->slave; + struct net_device *brport_dev; + int err; + + if (br_mst_enabled(br) && !dsa_port_supports_mst(dp)) + return -EOPNOTSUPP; + + /* Here the interface is already bridged. Reflect the current + * configuration so that drivers can program their chips accordingly. 
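+ *
+ * The error paths below unwind in reverse order (switchdev unoffload,
+ * DSA_NOTIFIER_BRIDGE_LEAVE, then destroying the dsa_bridge bookkeeping),
+ * so a failed join leaves no partial state behind.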
+ */ + err = dsa_port_bridge_create(dp, br, extack); + if (err) + return err; + + brport_dev = dsa_port_to_bridge_port(dp); + + info.bridge = *dp->bridge; + err = dsa_broadcast(DSA_NOTIFIER_BRIDGE_JOIN, &info); + if (err) + goto out_rollback; + + /* Drivers which support bridge TX forwarding should set this */ + dp->bridge->tx_fwd_offload = info.tx_fwd_offload; + + err = switchdev_bridge_port_offload(brport_dev, dev, dp, + &dsa_slave_switchdev_notifier, + &dsa_slave_switchdev_blocking_notifier, + dp->bridge->tx_fwd_offload, extack); + if (err) + goto out_rollback_unbridge; + + err = dsa_port_switchdev_sync_attrs(dp, extack); + if (err) + goto out_rollback_unoffload; + + return 0; + +out_rollback_unoffload: + switchdev_bridge_port_unoffload(brport_dev, dp, + &dsa_slave_switchdev_notifier, + &dsa_slave_switchdev_blocking_notifier); + dsa_flush_workqueue(); +out_rollback_unbridge: + dsa_broadcast(DSA_NOTIFIER_BRIDGE_LEAVE, &info); +out_rollback: + dsa_port_bridge_destroy(dp, br); + return err; +} + +void dsa_port_pre_bridge_leave(struct dsa_port *dp, struct net_device *br) +{ + struct net_device *brport_dev = dsa_port_to_bridge_port(dp); + + /* Don't try to unoffload something that is not offloaded */ + if (!brport_dev) + return; + + switchdev_bridge_port_unoffload(brport_dev, dp, + &dsa_slave_switchdev_notifier, + &dsa_slave_switchdev_blocking_notifier); + + dsa_flush_workqueue(); +} + +void dsa_port_bridge_leave(struct dsa_port *dp, struct net_device *br) +{ + struct dsa_notifier_bridge_info info = { + .dp = dp, + }; + int err; + + /* If the port could not be offloaded to begin with, then + * there is nothing to do. + */ + if (!dp->bridge) + return; + + info.bridge = *dp->bridge; + + /* Here the port is already unbridged. Reflect the current configuration + * so that drivers can program their chips accordingly. + */ + dsa_port_bridge_destroy(dp, br); + + err = dsa_broadcast(DSA_NOTIFIER_BRIDGE_LEAVE, &info); + if (err) + dev_err(dp->ds->dev, + "port %d failed to notify DSA_NOTIFIER_BRIDGE_LEAVE: %pe\n", + dp->index, ERR_PTR(err)); + + dsa_port_switchdev_unsync_attrs(dp, info.bridge); +} + +int dsa_port_lag_change(struct dsa_port *dp, + struct netdev_lag_lower_state_info *linfo) +{ + struct dsa_notifier_lag_info info = { + .dp = dp, + }; + bool tx_enabled; + + if (!dp->lag) + return 0; + + /* On statically configured aggregates (e.g. loadbalance + * without LACP) ports will always be tx_enabled, even if the + * link is down. Thus we require both link_up and tx_enabled + * in order to include it in the tx set. 
+ */ + tx_enabled = linfo->link_up && linfo->tx_enabled; + + if (tx_enabled == dp->lag_tx_enabled) + return 0; + + dp->lag_tx_enabled = tx_enabled; + + return dsa_port_notify(dp, DSA_NOTIFIER_LAG_CHANGE, &info); +} + +static int dsa_port_lag_create(struct dsa_port *dp, + struct net_device *lag_dev) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_lag *lag; + + lag = dsa_tree_lag_find(ds->dst, lag_dev); + if (lag) { + refcount_inc(&lag->refcount); + dp->lag = lag; + return 0; + } + + lag = kzalloc(sizeof(*lag), GFP_KERNEL); + if (!lag) + return -ENOMEM; + + refcount_set(&lag->refcount, 1); + mutex_init(&lag->fdb_lock); + INIT_LIST_HEAD(&lag->fdbs); + lag->dev = lag_dev; + dsa_lag_map(ds->dst, lag); + dp->lag = lag; + + return 0; +} + +static void dsa_port_lag_destroy(struct dsa_port *dp) +{ + struct dsa_lag *lag = dp->lag; + + dp->lag = NULL; + dp->lag_tx_enabled = false; + + if (!refcount_dec_and_test(&lag->refcount)) + return; + + WARN_ON(!list_empty(&lag->fdbs)); + dsa_lag_unmap(dp->ds->dst, lag); + kfree(lag); +} + +int dsa_port_lag_join(struct dsa_port *dp, struct net_device *lag_dev, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack) +{ + struct dsa_notifier_lag_info info = { + .dp = dp, + .info = uinfo, + .extack = extack, + }; + struct net_device *bridge_dev; + int err; + + err = dsa_port_lag_create(dp, lag_dev); + if (err) + goto err_lag_create; + + info.lag = *dp->lag; + err = dsa_port_notify(dp, DSA_NOTIFIER_LAG_JOIN, &info); + if (err) + goto err_lag_join; + + bridge_dev = netdev_master_upper_dev_get(lag_dev); + if (!bridge_dev || !netif_is_bridge_master(bridge_dev)) + return 0; + + err = dsa_port_bridge_join(dp, bridge_dev, extack); + if (err) + goto err_bridge_join; + + return 0; + +err_bridge_join: + dsa_port_notify(dp, DSA_NOTIFIER_LAG_LEAVE, &info); +err_lag_join: + dsa_port_lag_destroy(dp); +err_lag_create: + return err; +} + +void dsa_port_pre_lag_leave(struct dsa_port *dp, struct net_device *lag_dev) +{ + struct net_device *br = dsa_port_bridge_dev_get(dp); + + if (br) + dsa_port_pre_bridge_leave(dp, br); +} + +void dsa_port_lag_leave(struct dsa_port *dp, struct net_device *lag_dev) +{ + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct dsa_notifier_lag_info info = { + .dp = dp, + }; + int err; + + if (!dp->lag) + return; + + /* Port might have been part of a LAG that in turn was + * attached to a bridge. + */ + if (br) + dsa_port_bridge_leave(dp, br); + + info.lag = *dp->lag; + + dsa_port_lag_destroy(dp); + + err = dsa_port_notify(dp, DSA_NOTIFIER_LAG_LEAVE, &info); + if (err) + dev_err(dp->ds->dev, + "port %d failed to notify DSA_NOTIFIER_LAG_LEAVE: %pe\n", + dp->index, ERR_PTR(err)); +} + +/* Must be called under rcu_read_lock() */ +static bool dsa_port_can_apply_vlan_filtering(struct dsa_port *dp, + bool vlan_filtering, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_port *other_dp; + int err; + + /* VLAN awareness was off, so the question is "can we turn it on". + * We may have had 8021q uppers, those need to go. Make sure we don't + * enter an inconsistent state: deny changing the VLAN awareness state + * as long as we have 8021q uppers. 
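+ *
+ * (An "8021q upper" here is a VLAN device stacked on the user port, e.g.
+ * a hypothetical swp0.100; the check below only refuses the change when
+ * the upper's VID is also configured in the bridge, since it would then
+ * be ambiguous whether traffic with that VID belongs to the upper or to
+ * the VLAN-aware bridge.)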
+ */ + if (vlan_filtering && dsa_port_is_user(dp)) { + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct net_device *upper_dev, *slave = dp->slave; + struct list_head *iter; + + netdev_for_each_upper_dev_rcu(slave, upper_dev, iter) { + struct bridge_vlan_info br_info; + u16 vid; + + if (!is_vlan_dev(upper_dev)) + continue; + + vid = vlan_dev_vlan_id(upper_dev); + + /* br_vlan_get_info() returns -EINVAL or -ENOENT if the + * device, respectively the VID is not found, returning + * 0 means success, which is a failure for us here. + */ + err = br_vlan_get_info(br, vid, &br_info); + if (err == 0) { + NL_SET_ERR_MSG_MOD(extack, + "Must first remove VLAN uppers having VIDs also present in bridge"); + return false; + } + } + } + + if (!ds->vlan_filtering_is_global) + return true; + + /* For cases where enabling/disabling VLAN awareness is global to the + * switch, we need to handle the case where multiple bridges span + * different ports of the same switch device and one of them has a + * different setting than what is being requested. + */ + dsa_switch_for_each_port(other_dp, ds) { + struct net_device *other_br = dsa_port_bridge_dev_get(other_dp); + + /* If it's the same bridge, it also has same + * vlan_filtering setting => no need to check + */ + if (!other_br || other_br == dsa_port_bridge_dev_get(dp)) + continue; + + if (br_vlan_enabled(other_br) != vlan_filtering) { + NL_SET_ERR_MSG_MOD(extack, + "VLAN filtering is a global setting"); + return false; + } + } + return true; +} + +int dsa_port_vlan_filtering(struct dsa_port *dp, bool vlan_filtering, + struct netlink_ext_ack *extack) +{ + bool old_vlan_filtering = dsa_port_is_vlan_filtering(dp); + struct dsa_switch *ds = dp->ds; + bool apply; + int err; + + if (!ds->ops->port_vlan_filtering) + return -EOPNOTSUPP; + + /* We are called from dsa_slave_switchdev_blocking_event(), + * which is not under rcu_read_lock(), unlike + * dsa_slave_switchdev_event(). + */ + rcu_read_lock(); + apply = dsa_port_can_apply_vlan_filtering(dp, vlan_filtering, extack); + rcu_read_unlock(); + if (!apply) + return -EINVAL; + + if (dsa_port_is_vlan_filtering(dp) == vlan_filtering) + return 0; + + err = ds->ops->port_vlan_filtering(ds, dp->index, vlan_filtering, + extack); + if (err) + return err; + + if (ds->vlan_filtering_is_global) { + struct dsa_port *other_dp; + + ds->vlan_filtering = vlan_filtering; + + dsa_switch_for_each_user_port(other_dp, ds) { + struct net_device *slave = other_dp->slave; + + /* We might be called in the unbind path, so not + * all slave devices might still be registered. 
+ */ + if (!slave) + continue; + + err = dsa_slave_manage_vlan_filtering(slave, + vlan_filtering); + if (err) + goto restore; + } + } else { + dp->vlan_filtering = vlan_filtering; + + err = dsa_slave_manage_vlan_filtering(dp->slave, + vlan_filtering); + if (err) + goto restore; + } + + return 0; + +restore: + ds->ops->port_vlan_filtering(ds, dp->index, old_vlan_filtering, NULL); + + if (ds->vlan_filtering_is_global) + ds->vlan_filtering = old_vlan_filtering; + else + dp->vlan_filtering = old_vlan_filtering; + + return err; +} + +/* This enforces legacy behavior for switch drivers which assume they can't + * receive VLAN configuration when enslaved to a bridge with vlan_filtering=0 + */ +bool dsa_port_skip_vlan_configuration(struct dsa_port *dp) +{ + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct dsa_switch *ds = dp->ds; + + if (!br) + return false; + + return !ds->configure_vlan_while_not_filtering && !br_vlan_enabled(br); +} + +int dsa_port_ageing_time(struct dsa_port *dp, clock_t ageing_clock) +{ + unsigned long ageing_jiffies = clock_t_to_jiffies(ageing_clock); + unsigned int ageing_time = jiffies_to_msecs(ageing_jiffies); + struct dsa_notifier_ageing_time_info info; + int err; + + info.ageing_time = ageing_time; + + err = dsa_port_notify(dp, DSA_NOTIFIER_AGEING_TIME, &info); + if (err) + return err; + + dp->ageing_time = ageing_time; + + return 0; +} + +int dsa_port_mst_enable(struct dsa_port *dp, bool on, + struct netlink_ext_ack *extack) +{ + if (on && !dsa_port_supports_mst(dp)) { + NL_SET_ERR_MSG_MOD(extack, "Hardware does not support MST"); + return -EINVAL; + } + + return 0; +} + +int dsa_port_pre_bridge_flags(const struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_pre_bridge_flags) + return -EINVAL; + + return ds->ops->port_pre_bridge_flags(ds, dp->index, flags, extack); +} + +int dsa_port_bridge_flags(struct dsa_port *dp, + struct switchdev_brport_flags flags, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_bridge_flags) + return -EOPNOTSUPP; + + err = ds->ops->port_bridge_flags(ds, dp->index, flags, extack); + if (err) + return err; + + if (flags.mask & BR_LEARNING) { + bool learning = flags.val & BR_LEARNING; + + if (learning == dp->learning) + return 0; + + if ((dp->learning && !learning) && + (dp->stp_state == BR_STATE_LEARNING || + dp->stp_state == BR_STATE_FORWARDING)) + dsa_port_fast_age(dp); + + dp->learning = learning; + } + + return 0; +} + +void dsa_port_set_host_flood(struct dsa_port *dp, bool uc, bool mc) +{ + struct dsa_switch *ds = dp->ds; + + if (ds->ops->port_set_host_flood) + ds->ops->port_set_host_flood(ds, dp->index, uc, mc); +} + +int dsa_port_vlan_msti(struct dsa_port *dp, + const struct switchdev_vlan_msti *msti) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->vlan_msti_set) + return -EOPNOTSUPP; + + return ds->ops->vlan_msti_set(ds, *dp->bridge, msti); +} + +int dsa_port_mtu_change(struct dsa_port *dp, int new_mtu) +{ + struct dsa_notifier_mtu_info info = { + .dp = dp, + .mtu = new_mtu, + }; + + return dsa_port_notify(dp, DSA_NOTIFIER_MTU, &info); +} + +int dsa_port_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid) +{ + struct dsa_notifier_fdb_info info = { + .dp = dp, + .addr = addr, + .vid = vid, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + /* Refcounting takes bridge.num as a key, and should be global for all + * bridges in 
the absence of FDB isolation, and per bridge otherwise. + * Force the bridge.num to zero here in the absence of FDB isolation. + */ + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_FDB_ADD, &info); +} + +int dsa_port_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid) +{ + struct dsa_notifier_fdb_info info = { + .dp = dp, + .addr = addr, + .vid = vid, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_FDB_DEL, &info); +} + +static int dsa_port_host_fdb_add(struct dsa_port *dp, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_notifier_fdb_info info = { + .dp = dp, + .addr = addr, + .vid = vid, + .db = db, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_HOST_FDB_ADD, &info); +} + +int dsa_port_standalone_host_fdb_add(struct dsa_port *dp, + const unsigned char *addr, u16 vid) +{ + struct dsa_db db = { + .type = DSA_DB_PORT, + .dp = dp, + }; + + return dsa_port_host_fdb_add(dp, addr, vid, db); +} + +int dsa_port_bridge_host_fdb_add(struct dsa_port *dp, + const unsigned char *addr, u16 vid) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_db db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }; + int err; + + /* Avoid a call to __dev_set_promiscuity() on the master, which + * requires rtnl_lock(), since we can't guarantee that is held here, + * and we can't take it either. + */ + if (master->priv_flags & IFF_UNICAST_FLT) { + err = dev_uc_add(master, addr); + if (err) + return err; + } + + return dsa_port_host_fdb_add(dp, addr, vid, db); +} + +static int dsa_port_host_fdb_del(struct dsa_port *dp, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_notifier_fdb_info info = { + .dp = dp, + .addr = addr, + .vid = vid, + .db = db, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_HOST_FDB_DEL, &info); +} + +int dsa_port_standalone_host_fdb_del(struct dsa_port *dp, + const unsigned char *addr, u16 vid) +{ + struct dsa_db db = { + .type = DSA_DB_PORT, + .dp = dp, + }; + + return dsa_port_host_fdb_del(dp, addr, vid, db); +} + +int dsa_port_bridge_host_fdb_del(struct dsa_port *dp, + const unsigned char *addr, u16 vid) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_db db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }; + int err; + + if (master->priv_flags & IFF_UNICAST_FLT) { + err = dev_uc_del(master, addr); + if (err) + return err; + } + + return dsa_port_host_fdb_del(dp, addr, vid, db); +} + +int dsa_port_lag_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid) +{ + struct dsa_notifier_lag_fdb_info info = { + .lag = dp->lag, + .addr = addr, + .vid = vid, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_LAG_FDB_ADD, &info); +} + +int dsa_port_lag_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid) +{ + struct dsa_notifier_lag_fdb_info info = { + .lag = dp->lag, + .addr = addr, + .vid = vid, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_LAG_FDB_DEL, &info); +} + +int dsa_port_fdb_dump(struct dsa_port *dp, dsa_fdb_dump_cb_t 
*cb, void *data) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (!ds->ops->port_fdb_dump) + return -EOPNOTSUPP; + + return ds->ops->port_fdb_dump(ds, port, cb, data); +} + +int dsa_port_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct dsa_notifier_mdb_info info = { + .dp = dp, + .mdb = mdb, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_MDB_ADD, &info); +} + +int dsa_port_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct dsa_notifier_mdb_info info = { + .dp = dp, + .mdb = mdb, + .db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_MDB_DEL, &info); +} + +static int dsa_port_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb, + struct dsa_db db) +{ + struct dsa_notifier_mdb_info info = { + .dp = dp, + .mdb = mdb, + .db = db, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_HOST_MDB_ADD, &info); +} + +int dsa_port_standalone_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct dsa_db db = { + .type = DSA_DB_PORT, + .dp = dp, + }; + + return dsa_port_host_mdb_add(dp, mdb, db); +} + +int dsa_port_bridge_host_mdb_add(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_db db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }; + int err; + + err = dev_mc_add(master, mdb->addr); + if (err) + return err; + + return dsa_port_host_mdb_add(dp, mdb, db); +} + +static int dsa_port_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb, + struct dsa_db db) +{ + struct dsa_notifier_mdb_info info = { + .dp = dp, + .mdb = mdb, + .db = db, + }; + + if (!dp->ds->fdb_isolation) + info.db.bridge.num = 0; + + return dsa_port_notify(dp, DSA_NOTIFIER_HOST_MDB_DEL, &info); +} + +int dsa_port_standalone_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct dsa_db db = { + .type = DSA_DB_PORT, + .dp = dp, + }; + + return dsa_port_host_mdb_del(dp, mdb, db); +} + +int dsa_port_bridge_host_mdb_del(const struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_db db = { + .type = DSA_DB_BRIDGE, + .bridge = *dp->bridge, + }; + int err; + + err = dev_mc_del(master, mdb->addr); + if (err) + return err; + + return dsa_port_host_mdb_del(dp, mdb, db); +} + +int dsa_port_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack) +{ + struct dsa_notifier_vlan_info info = { + .dp = dp, + .vlan = vlan, + .extack = extack, + }; + + return dsa_port_notify(dp, DSA_NOTIFIER_VLAN_ADD, &info); +} + +int dsa_port_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan) +{ + struct dsa_notifier_vlan_info info = { + .dp = dp, + .vlan = vlan, + }; + + return dsa_port_notify(dp, DSA_NOTIFIER_VLAN_DEL, &info); +} + +int dsa_port_host_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_notifier_vlan_info info = { + .dp = dp, + .vlan = vlan, 
+ .extack = extack, + }; + int err; + + err = dsa_port_notify(dp, DSA_NOTIFIER_HOST_VLAN_ADD, &info); + if (err && err != -EOPNOTSUPP) + return err; + + vlan_vid_add(master, htons(ETH_P_8021Q), vlan->vid); + + return err; +} + +int dsa_port_host_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan) +{ + struct net_device *master = dsa_port_to_master(dp); + struct dsa_notifier_vlan_info info = { + .dp = dp, + .vlan = vlan, + }; + int err; + + err = dsa_port_notify(dp, DSA_NOTIFIER_HOST_VLAN_DEL, &info); + if (err && err != -EOPNOTSUPP) + return err; + + vlan_vid_del(master, htons(ETH_P_8021Q), vlan->vid); + + return err; +} + +int dsa_port_mrp_add(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_mrp_add) + return -EOPNOTSUPP; + + return ds->ops->port_mrp_add(ds, dp->index, mrp); +} + +int dsa_port_mrp_del(const struct dsa_port *dp, + const struct switchdev_obj_mrp *mrp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_mrp_del) + return -EOPNOTSUPP; + + return ds->ops->port_mrp_del(ds, dp->index, mrp); +} + +int dsa_port_mrp_add_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_mrp_add_ring_role) + return -EOPNOTSUPP; + + return ds->ops->port_mrp_add_ring_role(ds, dp->index, mrp); +} + +int dsa_port_mrp_del_ring_role(const struct dsa_port *dp, + const struct switchdev_obj_ring_role_mrp *mrp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->port_mrp_del_ring_role) + return -EOPNOTSUPP; + + return ds->ops->port_mrp_del_ring_role(ds, dp->index, mrp); +} + +static int dsa_port_assign_master(struct dsa_port *dp, + struct net_device *master, + struct netlink_ext_ack *extack, + bool fail_on_err) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index, err; + + err = ds->ops->port_change_master(ds, port, master, extack); + if (err && !fail_on_err) + dev_err(ds->dev, "port %d failed to assign master %s: %pe\n", + port, master->name, ERR_PTR(err)); + + if (err && fail_on_err) + return err; + + dp->cpu_dp = master->dsa_ptr; + dp->cpu_port_in_lag = netif_is_lag_master(master); + + return 0; +} + +/* Change the dp->cpu_dp affinity for a user port. Note that both cross-chip + * notifiers and drivers have implicit assumptions about user-to-CPU-port + * mappings, so we unfortunately cannot delay the deletion of the objects + * (switchdev, standalone addresses, standalone VLANs) on the old CPU port + * until the new CPU port has been set up. So we need to completely tear down + * the old CPU port before changing it, and restore it on errors during the + * bringup of the new one. + */ +int dsa_port_change_master(struct dsa_port *dp, struct net_device *master, + struct netlink_ext_ack *extack) +{ + struct net_device *bridge_dev = dsa_port_bridge_dev_get(dp); + struct net_device *old_master = dsa_port_to_master(dp); + struct net_device *dev = dp->slave; + struct dsa_switch *ds = dp->ds; + bool vlan_filtering; + int err, tmp; + + /* Bridges may hold host FDB, MDB and VLAN objects. These need to be + * migrated, so dynamically unoffload and later reoffload the bridge + * port. + */ + if (bridge_dev) { + dsa_port_pre_bridge_leave(dp, bridge_dev); + dsa_port_bridge_leave(dp, bridge_dev); + } + + /* The port might still be VLAN filtering even if it's no longer + * under a bridge, either due to ds->vlan_filtering_is_global or + * ds->needs_standalone_vlan_filtering. 
In turn this means VLANs + * on the CPU port. + */ + vlan_filtering = dsa_port_is_vlan_filtering(dp); + if (vlan_filtering) { + err = dsa_slave_manage_vlan_filtering(dev, false); + if (err) { + NL_SET_ERR_MSG_MOD(extack, + "Failed to remove standalone VLANs"); + goto rewind_old_bridge; + } + } + + /* Standalone addresses, and addresses of upper interfaces like + * VLAN, LAG, HSR need to be migrated. + */ + dsa_slave_unsync_ha(dev); + + err = dsa_port_assign_master(dp, master, extack, true); + if (err) + goto rewind_old_addrs; + + dsa_slave_sync_ha(dev); + + if (vlan_filtering) { + err = dsa_slave_manage_vlan_filtering(dev, true); + if (err) { + NL_SET_ERR_MSG_MOD(extack, + "Failed to restore standalone VLANs"); + goto rewind_new_addrs; + } + } + + if (bridge_dev) { + err = dsa_port_bridge_join(dp, bridge_dev, extack); + if (err && err == -EOPNOTSUPP) { + NL_SET_ERR_MSG_MOD(extack, + "Failed to reoffload bridge"); + goto rewind_new_vlan; + } + } + + return 0; + +rewind_new_vlan: + if (vlan_filtering) + dsa_slave_manage_vlan_filtering(dev, false); + +rewind_new_addrs: + dsa_slave_unsync_ha(dev); + + dsa_port_assign_master(dp, old_master, NULL, false); + +/* Restore the objects on the old CPU port */ +rewind_old_addrs: + dsa_slave_sync_ha(dev); + + if (vlan_filtering) { + tmp = dsa_slave_manage_vlan_filtering(dev, true); + if (tmp) { + dev_err(ds->dev, + "port %d failed to restore standalone VLANs: %pe\n", + dp->index, ERR_PTR(tmp)); + } + } + +rewind_old_bridge: + if (bridge_dev) { + tmp = dsa_port_bridge_join(dp, bridge_dev, extack); + if (tmp) { + dev_err(ds->dev, + "port %d failed to rejoin bridge %s: %pe\n", + dp->index, bridge_dev->name, ERR_PTR(tmp)); + } + } + + return err; +} + +void dsa_port_set_tag_protocol(struct dsa_port *cpu_dp, + const struct dsa_device_ops *tag_ops) +{ + cpu_dp->rcv = tag_ops->rcv; + cpu_dp->tag_ops = tag_ops; +} + +static struct phy_device *dsa_port_get_phy_device(struct dsa_port *dp) +{ + struct device_node *phy_dn; + struct phy_device *phydev; + + phy_dn = of_parse_phandle(dp->dn, "phy-handle", 0); + if (!phy_dn) + return NULL; + + phydev = of_phy_find_device(phy_dn); + if (!phydev) { + of_node_put(phy_dn); + return ERR_PTR(-EPROBE_DEFER); + } + + of_node_put(phy_dn); + return phydev; +} + +static void dsa_port_phylink_validate(struct phylink_config *config, + unsigned long *supported, + struct phylink_link_state *state) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->phylink_validate) { + if (config->mac_capabilities) + phylink_generic_validate(config, supported, state); + return; + } + + ds->ops->phylink_validate(ds, dp->index, supported, state); +} + +static void dsa_port_phylink_mac_pcs_get_state(struct phylink_config *config, + struct phylink_link_state *state) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + int err; + + /* Only called for inband modes */ + if (!ds->ops->phylink_mac_link_state) { + state->link = 0; + return; + } + + err = ds->ops->phylink_mac_link_state(ds, dp->index, state); + if (err < 0) { + dev_err(ds->dev, "p%d: phylink_mac_link_state() failed: %d\n", + dp->index, err); + state->link = 0; + } +} + +static struct phylink_pcs * +dsa_port_phylink_mac_select_pcs(struct phylink_config *config, + phy_interface_t interface) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct phylink_pcs *pcs = ERR_PTR(-EOPNOTSUPP); + struct dsa_switch *ds = dp->ds; + + if 
(ds->ops->phylink_mac_select_pcs) + pcs = ds->ops->phylink_mac_select_pcs(ds, dp->index, interface); + + return pcs; +} + +static void dsa_port_phylink_mac_config(struct phylink_config *config, + unsigned int mode, + const struct phylink_link_state *state) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->phylink_mac_config) + return; + + ds->ops->phylink_mac_config(ds, dp->index, mode, state); +} + +static void dsa_port_phylink_mac_an_restart(struct phylink_config *config) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->phylink_mac_an_restart) + return; + + ds->ops->phylink_mac_an_restart(ds, dp->index); +} + +static void dsa_port_phylink_mac_link_down(struct phylink_config *config, + unsigned int mode, + phy_interface_t interface) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct phy_device *phydev = NULL; + struct dsa_switch *ds = dp->ds; + + if (dsa_port_is_user(dp)) + phydev = dp->slave->phydev; + + if (!ds->ops->phylink_mac_link_down) { + if (ds->ops->adjust_link && phydev) + ds->ops->adjust_link(ds, dp->index, phydev); + return; + } + + ds->ops->phylink_mac_link_down(ds, dp->index, mode, interface); +} + +static void dsa_port_phylink_mac_link_up(struct phylink_config *config, + struct phy_device *phydev, + unsigned int mode, + phy_interface_t interface, + int speed, int duplex, + bool tx_pause, bool rx_pause) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->phylink_mac_link_up) { + if (ds->ops->adjust_link && phydev) + ds->ops->adjust_link(ds, dp->index, phydev); + return; + } + + ds->ops->phylink_mac_link_up(ds, dp->index, mode, interface, phydev, + speed, duplex, tx_pause, rx_pause); +} + +static const struct phylink_mac_ops dsa_port_phylink_mac_ops = { + .validate = dsa_port_phylink_validate, + .mac_select_pcs = dsa_port_phylink_mac_select_pcs, + .mac_pcs_get_state = dsa_port_phylink_mac_pcs_get_state, + .mac_config = dsa_port_phylink_mac_config, + .mac_an_restart = dsa_port_phylink_mac_an_restart, + .mac_link_down = dsa_port_phylink_mac_link_down, + .mac_link_up = dsa_port_phylink_mac_link_up, +}; + +int dsa_port_phylink_create(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + phy_interface_t mode; + struct phylink *pl; + int err; + + err = of_get_phy_mode(dp->dn, &mode); + if (err) + mode = PHY_INTERFACE_MODE_NA; + + /* Presence of phylink_mac_link_state or phylink_mac_an_restart is + * an indicator of a legacy phylink driver. 
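+ * Roughly speaking, such drivers predate the March 2020 phylink changes,
+ * after which in-band link state and AN restart are expected to be handled
+ * by a phylink_pcs rather than by the MAC operations. A converted driver
+ * (names below are purely illustrative) would instead provide:
+ *
+ *	.phylink_get_caps	= foo_phylink_get_caps,
+ *	.phylink_mac_select_pcs	= foo_phylink_mac_select_pcs,
+ *
+ * and leave phylink_mac_link_state/phylink_mac_an_restart unset, so that
+ * legacy_pre_march2020 stays false.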
+ */ + if (ds->ops->phylink_mac_link_state || + ds->ops->phylink_mac_an_restart) + dp->pl_config.legacy_pre_march2020 = true; + + if (ds->ops->phylink_get_caps) + ds->ops->phylink_get_caps(ds, dp->index, &dp->pl_config); + + pl = phylink_create(&dp->pl_config, of_fwnode_handle(dp->dn), + mode, &dsa_port_phylink_mac_ops); + if (IS_ERR(pl)) { + pr_err("error creating PHYLINK: %ld\n", PTR_ERR(pl)); + return PTR_ERR(pl); + } + + dp->pl = pl; + + return 0; +} + +void dsa_port_phylink_destroy(struct dsa_port *dp) +{ + phylink_destroy(dp->pl); + dp->pl = NULL; +} + +static int dsa_shared_port_setup_phy_of(struct dsa_port *dp, bool enable) +{ + struct dsa_switch *ds = dp->ds; + struct phy_device *phydev; + int port = dp->index; + int err = 0; + + phydev = dsa_port_get_phy_device(dp); + if (!phydev) + return 0; + + if (IS_ERR(phydev)) + return PTR_ERR(phydev); + + if (enable) { + err = genphy_resume(phydev); + if (err < 0) + goto err_put_dev; + + err = genphy_read_status(phydev); + if (err < 0) + goto err_put_dev; + } else { + err = genphy_suspend(phydev); + if (err < 0) + goto err_put_dev; + } + + if (ds->ops->adjust_link) + ds->ops->adjust_link(ds, port, phydev); + + dev_dbg(ds->dev, "enabled port's phy: %s", phydev_name(phydev)); + +err_put_dev: + put_device(&phydev->mdio.dev); + return err; +} + +static int dsa_shared_port_fixed_link_register_of(struct dsa_port *dp) +{ + struct device_node *dn = dp->dn; + struct dsa_switch *ds = dp->ds; + struct phy_device *phydev; + int port = dp->index; + phy_interface_t mode; + int err; + + err = of_phy_register_fixed_link(dn); + if (err) { + dev_err(ds->dev, + "failed to register the fixed PHY of port %d\n", + port); + return err; + } + + phydev = of_phy_find_device(dn); + + err = of_get_phy_mode(dn, &mode); + if (err) + mode = PHY_INTERFACE_MODE_NA; + phydev->interface = mode; + + genphy_read_status(phydev); + + if (ds->ops->adjust_link) + ds->ops->adjust_link(ds, port, phydev); + + put_device(&phydev->mdio.dev); + + return 0; +} + +static int dsa_shared_port_phylink_register(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + struct device_node *port_dn = dp->dn; + int err; + + dp->pl_config.dev = ds->dev; + dp->pl_config.type = PHYLINK_DEV; + + err = dsa_port_phylink_create(dp); + if (err) + return err; + + err = phylink_of_phy_connect(dp->pl, port_dn, 0); + if (err && err != -ENODEV) { + pr_err("could not attach to PHY: %d\n", err); + goto err_phy_connect; + } + + return 0; + +err_phy_connect: + dsa_port_phylink_destroy(dp); + return err; +} + +/* During the initial DSA driver migration to OF, port nodes were sometimes + * added to device trees with no indication of how they should operate from a + * link management perspective (phy-handle, fixed-link, etc). Additionally, the + * phy-mode may be absent. The interpretation of these port OF nodes depends on + * their type. + * + * User ports with no phy-handle or fixed-link are expected to connect to an + * internal PHY located on the ds->slave_mii_bus at an MDIO address equal to + * the port number. This description is still actively supported. + * + * Shared (CPU and DSA) ports with no phy-handle or fixed-link are expected to + * operate at the maximum speed that their phy-mode is capable of. If the + * phy-mode is absent, they are expected to operate using the phy-mode + * supported by the port that gives the highest link speed. 
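+ * (For contrast, a fully described shared port -- what new device trees are
+ * expected to provide -- would look roughly like the following, with all
+ * names purely illustrative:
+ *
+ *	port@5 {
+ *		reg = <5>;
+ *		ethernet = <&gmac0>;
+ *		phy-mode = "rgmii-id";
+ *		fixed-link {
+ *			speed = <1000>;
+ *			full-duplex;
+ *		};
+ *	};
+ * )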
It is unspecified + * if the port should use flow control or not, half duplex or full duplex, or + * if the phy-mode is a SERDES link, whether in-band autoneg is expected to be + * enabled or not. + * + * In the latter case of shared ports, omitting the link management description + * from the firmware node is deprecated and strongly discouraged. DSA uses + * phylink, which rejects the firmware nodes of these ports for lacking + * required properties. + * + * For switches in this table, DSA will skip enforcing validation and will + * later omit registering a phylink instance for the shared ports, if they lack + * a fixed-link, a phy-handle, or a managed = "in-band-status" property. + * It becomes the responsibility of the driver to ensure that these ports + * operate at the maximum speed (whatever this means) and will interoperate + * with the DSA master or other cascade port, since phylink methods will not be + * invoked for them. + * + * If you are considering expanding this table for newly introduced switches, + * think again. It is OK to remove switches from this table if there aren't DT + * blobs in circulation which rely on defaulting the shared ports. + */ +static const char * const dsa_switches_apply_workarounds[] = { +#if IS_ENABLED(CONFIG_NET_DSA_XRS700X) + "arrow,xrs7003e", + "arrow,xrs7003f", + "arrow,xrs7004e", + "arrow,xrs7004f", +#endif +#if IS_ENABLED(CONFIG_B53) + "brcm,bcm5325", + "brcm,bcm53115", + "brcm,bcm53125", + "brcm,bcm53128", + "brcm,bcm5365", + "brcm,bcm5389", + "brcm,bcm5395", + "brcm,bcm5397", + "brcm,bcm5398", + "brcm,bcm53010-srab", + "brcm,bcm53011-srab", + "brcm,bcm53012-srab", + "brcm,bcm53018-srab", + "brcm,bcm53019-srab", + "brcm,bcm5301x-srab", + "brcm,bcm11360-srab", + "brcm,bcm58522-srab", + "brcm,bcm58525-srab", + "brcm,bcm58535-srab", + "brcm,bcm58622-srab", + "brcm,bcm58623-srab", + "brcm,bcm58625-srab", + "brcm,bcm88312-srab", + "brcm,cygnus-srab", + "brcm,nsp-srab", + "brcm,omega-srab", + "brcm,bcm3384-switch", + "brcm,bcm6328-switch", + "brcm,bcm6368-switch", + "brcm,bcm63xx-switch", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_BCM_SF2) + "brcm,bcm7445-switch-v4.0", + "brcm,bcm7278-switch-v4.0", + "brcm,bcm7278-switch-v4.8", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_LANTIQ_GSWIP) + "lantiq,xrx200-gswip", + "lantiq,xrx300-gswip", + "lantiq,xrx330-gswip", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_MV88E6060) + "marvell,mv88e6060", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_MV88E6XXX) + "marvell,mv88e6085", + "marvell,mv88e6190", + "marvell,mv88e6250", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_MICROCHIP_KSZ_COMMON) + "microchip,ksz8765", + "microchip,ksz8794", + "microchip,ksz8795", + "microchip,ksz8863", + "microchip,ksz8873", + "microchip,ksz9477", + "microchip,ksz9897", + "microchip,ksz9893", + "microchip,ksz9563", + "microchip,ksz8563", + "microchip,ksz9567", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_SMSC_LAN9303_MDIO) + "smsc,lan9303-mdio", +#endif +#if IS_ENABLED(CONFIG_NET_DSA_SMSC_LAN9303_I2C) + "smsc,lan9303-i2c", +#endif + NULL, +}; + +static void dsa_shared_port_validate_of(struct dsa_port *dp, + bool *missing_phy_mode, + bool *missing_link_description) +{ + struct device_node *dn = dp->dn, *phy_np; + struct dsa_switch *ds = dp->ds; + phy_interface_t mode; + + *missing_phy_mode = false; + *missing_link_description = false; + + if (of_get_phy_mode(dn, &mode)) { + *missing_phy_mode = true; + dev_err(ds->dev, + "OF node %pOF of %s port %d lacks the required \"phy-mode\" property\n", + dn, dsa_port_is_cpu(dp) ? 
"CPU" : "DSA", dp->index); + } + + /* Note: of_phy_is_fixed_link() also returns true for + * managed = "in-band-status" + */ + if (of_phy_is_fixed_link(dn)) + return; + + phy_np = of_parse_phandle(dn, "phy-handle", 0); + if (phy_np) { + of_node_put(phy_np); + return; + } + + *missing_link_description = true; + + dev_err(ds->dev, + "OF node %pOF of %s port %d lacks the required \"phy-handle\", \"fixed-link\" or \"managed\" properties\n", + dn, dsa_port_is_cpu(dp) ? "CPU" : "DSA", dp->index); +} + +int dsa_shared_port_link_register_of(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + bool missing_link_description; + bool missing_phy_mode; + int port = dp->index; + + dsa_shared_port_validate_of(dp, &missing_phy_mode, + &missing_link_description); + + if ((missing_phy_mode || missing_link_description) && + !of_device_compatible_match(ds->dev->of_node, + dsa_switches_apply_workarounds)) + return -EINVAL; + + if (!ds->ops->adjust_link) { + if (missing_link_description) { + dev_warn(ds->dev, + "Skipping phylink registration for %s port %d\n", + dsa_port_is_cpu(dp) ? "CPU" : "DSA", dp->index); + } else { + if (ds->ops->phylink_mac_link_down) + ds->ops->phylink_mac_link_down(ds, port, + MLO_AN_FIXED, PHY_INTERFACE_MODE_NA); + + return dsa_shared_port_phylink_register(dp); + } + return 0; + } + + dev_warn(ds->dev, + "Using legacy PHYLIB callbacks. Please migrate to PHYLINK!\n"); + + if (of_phy_is_fixed_link(dp->dn)) + return dsa_shared_port_fixed_link_register_of(dp); + else + return dsa_shared_port_setup_phy_of(dp, true); +} + +void dsa_shared_port_link_unregister_of(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->adjust_link && dp->pl) { + rtnl_lock(); + phylink_disconnect_phy(dp->pl); + rtnl_unlock(); + dsa_port_phylink_destroy(dp); + return; + } + + if (of_phy_is_fixed_link(dp->dn)) + of_phy_deregister_fixed_link(dp->dn); + else + dsa_shared_port_setup_phy_of(dp, false); +} + +int dsa_port_hsr_join(struct dsa_port *dp, struct net_device *hsr) +{ + struct dsa_switch *ds = dp->ds; + int err; + + if (!ds->ops->port_hsr_join) + return -EOPNOTSUPP; + + dp->hsr_dev = hsr; + + err = ds->ops->port_hsr_join(ds, dp->index, hsr); + if (err) + dp->hsr_dev = NULL; + + return err; +} + +void dsa_port_hsr_leave(struct dsa_port *dp, struct net_device *hsr) +{ + struct dsa_switch *ds = dp->ds; + int err; + + dp->hsr_dev = NULL; + + if (ds->ops->port_hsr_leave) { + err = ds->ops->port_hsr_leave(ds, dp->index, hsr); + if (err) + dev_err(dp->ds->dev, + "port %d failed to leave HSR %s: %pe\n", + dp->index, hsr->name, ERR_PTR(err)); + } +} + +int dsa_port_tag_8021q_vlan_add(struct dsa_port *dp, u16 vid, bool broadcast) +{ + struct dsa_notifier_tag_8021q_vlan_info info = { + .dp = dp, + .vid = vid, + }; + + if (broadcast) + return dsa_broadcast(DSA_NOTIFIER_TAG_8021Q_VLAN_ADD, &info); + + return dsa_port_notify(dp, DSA_NOTIFIER_TAG_8021Q_VLAN_ADD, &info); +} + +void dsa_port_tag_8021q_vlan_del(struct dsa_port *dp, u16 vid, bool broadcast) +{ + struct dsa_notifier_tag_8021q_vlan_info info = { + .dp = dp, + .vid = vid, + }; + int err; + + if (broadcast) + err = dsa_broadcast(DSA_NOTIFIER_TAG_8021Q_VLAN_DEL, &info); + else + err = dsa_port_notify(dp, DSA_NOTIFIER_TAG_8021Q_VLAN_DEL, &info); + if (err) + dev_err(dp->ds->dev, + "port %d failed to notify tag_8021q VLAN %d deletion: %pe\n", + dp->index, vid, ERR_PTR(err)); +} diff --git a/net/dsa/slave.c b/net/dsa/slave.c new file mode 100644 index 000000000..5fe075bf4 --- /dev/null +++ b/net/dsa/slave.c @@ -0,0 +1,3491 @@ +// 
SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/dsa/slave.c - Slave device handling + * Copyright (c) 2008-2009 Marvell Semiconductor + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +static void dsa_slave_standalone_event_work(struct work_struct *work) +{ + struct dsa_standalone_event_work *standalone_work = + container_of(work, struct dsa_standalone_event_work, work); + const unsigned char *addr = standalone_work->addr; + struct net_device *dev = standalone_work->dev; + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_mdb mdb; + struct dsa_switch *ds = dp->ds; + u16 vid = standalone_work->vid; + int err; + + switch (standalone_work->event) { + case DSA_UC_ADD: + err = dsa_port_standalone_host_fdb_add(dp, addr, vid); + if (err) { + dev_err(ds->dev, + "port %d failed to add %pM vid %d to fdb: %d\n", + dp->index, addr, vid, err); + break; + } + break; + + case DSA_UC_DEL: + err = dsa_port_standalone_host_fdb_del(dp, addr, vid); + if (err) { + dev_err(ds->dev, + "port %d failed to delete %pM vid %d from fdb: %d\n", + dp->index, addr, vid, err); + } + + break; + case DSA_MC_ADD: + ether_addr_copy(mdb.addr, addr); + mdb.vid = vid; + + err = dsa_port_standalone_host_mdb_add(dp, &mdb); + if (err) { + dev_err(ds->dev, + "port %d failed to add %pM vid %d to mdb: %d\n", + dp->index, addr, vid, err); + break; + } + break; + case DSA_MC_DEL: + ether_addr_copy(mdb.addr, addr); + mdb.vid = vid; + + err = dsa_port_standalone_host_mdb_del(dp, &mdb); + if (err) { + dev_err(ds->dev, + "port %d failed to delete %pM vid %d from mdb: %d\n", + dp->index, addr, vid, err); + } + + break; + } + + kfree(standalone_work); +} + +static int dsa_slave_schedule_standalone_work(struct net_device *dev, + enum dsa_standalone_event event, + const unsigned char *addr, + u16 vid) +{ + struct dsa_standalone_event_work *standalone_work; + + standalone_work = kzalloc(sizeof(*standalone_work), GFP_ATOMIC); + if (!standalone_work) + return -ENOMEM; + + INIT_WORK(&standalone_work->work, dsa_slave_standalone_event_work); + standalone_work->event = event; + standalone_work->dev = dev; + + ether_addr_copy(standalone_work->addr, addr); + standalone_work->vid = vid; + + dsa_schedule_work(&standalone_work->work); + + return 0; +} + +static int dsa_slave_sync_uc(struct net_device *dev, + const unsigned char *addr) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + + dev_uc_add(master, addr); + + if (!dsa_switch_supports_uc_filtering(dp->ds)) + return 0; + + return dsa_slave_schedule_standalone_work(dev, DSA_UC_ADD, addr, 0); +} + +static int dsa_slave_unsync_uc(struct net_device *dev, + const unsigned char *addr) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + + dev_uc_del(master, addr); + + if (!dsa_switch_supports_uc_filtering(dp->ds)) + return 0; + + return dsa_slave_schedule_standalone_work(dev, DSA_UC_DEL, addr, 0); +} + +static int dsa_slave_sync_mc(struct net_device *dev, + const unsigned char *addr) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + + dev_mc_add(master, addr); + + if (!dsa_switch_supports_mc_filtering(dp->ds)) + return 0; + + return dsa_slave_schedule_standalone_work(dev, DSA_MC_ADD, addr, 0); +} + +static int dsa_slave_unsync_mc(struct net_device *dev, + 
const unsigned char *addr) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + + dev_mc_del(master, addr); + + if (!dsa_switch_supports_mc_filtering(dp->ds)) + return 0; + + return dsa_slave_schedule_standalone_work(dev, DSA_MC_DEL, addr, 0); +} + +void dsa_slave_sync_ha(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + struct netdev_hw_addr *ha; + + netif_addr_lock_bh(dev); + + netdev_for_each_synced_mc_addr(ha, dev) + dsa_slave_sync_mc(dev, ha->addr); + + netdev_for_each_synced_uc_addr(ha, dev) + dsa_slave_sync_uc(dev, ha->addr); + + netif_addr_unlock_bh(dev); + + if (dsa_switch_supports_uc_filtering(ds) || + dsa_switch_supports_mc_filtering(ds)) + dsa_flush_workqueue(); +} + +void dsa_slave_unsync_ha(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + struct netdev_hw_addr *ha; + + netif_addr_lock_bh(dev); + + netdev_for_each_synced_uc_addr(ha, dev) + dsa_slave_unsync_uc(dev, ha->addr); + + netdev_for_each_synced_mc_addr(ha, dev) + dsa_slave_unsync_mc(dev, ha->addr); + + netif_addr_unlock_bh(dev); + + if (dsa_switch_supports_uc_filtering(ds) || + dsa_switch_supports_mc_filtering(ds)) + dsa_flush_workqueue(); +} + +/* slave mii_bus handling ***************************************************/ +static int dsa_slave_phy_read(struct mii_bus *bus, int addr, int reg) +{ + struct dsa_switch *ds = bus->priv; + + if (ds->phys_mii_mask & (1 << addr)) + return ds->ops->phy_read(ds, addr, reg); + + return 0xffff; +} + +static int dsa_slave_phy_write(struct mii_bus *bus, int addr, int reg, u16 val) +{ + struct dsa_switch *ds = bus->priv; + + if (ds->phys_mii_mask & (1 << addr)) + return ds->ops->phy_write(ds, addr, reg, val); + + return 0; +} + +void dsa_slave_mii_bus_init(struct dsa_switch *ds) +{ + ds->slave_mii_bus->priv = (void *)ds; + ds->slave_mii_bus->name = "dsa slave smi"; + ds->slave_mii_bus->read = dsa_slave_phy_read; + ds->slave_mii_bus->write = dsa_slave_phy_write; + snprintf(ds->slave_mii_bus->id, MII_BUS_ID_SIZE, "dsa-%d.%d", + ds->dst->index, ds->index); + ds->slave_mii_bus->parent = ds->dev; + ds->slave_mii_bus->phy_mask = ~ds->phys_mii_mask; +} + + +/* slave device handling ****************************************************/ +static int dsa_slave_get_iflink(const struct net_device *dev) +{ + return dsa_slave_to_master(dev)->ifindex; +} + +static int dsa_slave_open(struct net_device *dev) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int err; + + err = dev_open(master, NULL); + if (err < 0) { + netdev_err(dev, "failed to open master %s\n", master->name); + goto out; + } + + if (dsa_switch_supports_uc_filtering(ds)) { + err = dsa_port_standalone_host_fdb_add(dp, dev->dev_addr, 0); + if (err) + goto out; + } + + if (!ether_addr_equal(dev->dev_addr, master->dev_addr)) { + err = dev_uc_add(master, dev->dev_addr); + if (err < 0) + goto del_host_addr; + } + + err = dsa_port_enable_rt(dp, dev->phydev); + if (err) + goto del_unicast; + + return 0; + +del_unicast: + if (!ether_addr_equal(dev->dev_addr, master->dev_addr)) + dev_uc_del(master, dev->dev_addr); +del_host_addr: + if (dsa_switch_supports_uc_filtering(ds)) + dsa_port_standalone_host_fdb_del(dp, dev->dev_addr, 0); +out: + return err; +} + +static int dsa_slave_close(struct net_device *dev) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct 
dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + dsa_port_disable_rt(dp); + + if (!ether_addr_equal(dev->dev_addr, master->dev_addr)) + dev_uc_del(master, dev->dev_addr); + + if (dsa_switch_supports_uc_filtering(ds)) + dsa_port_standalone_host_fdb_del(dp, dev->dev_addr, 0); + + return 0; +} + +static void dsa_slave_manage_host_flood(struct net_device *dev) +{ + bool mc = dev->flags & (IFF_PROMISC | IFF_ALLMULTI); + struct dsa_port *dp = dsa_slave_to_port(dev); + bool uc = dev->flags & IFF_PROMISC; + + dsa_port_set_host_flood(dp, uc, mc); +} + +static void dsa_slave_change_rx_flags(struct net_device *dev, int change) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (change & IFF_ALLMULTI) + dev_set_allmulti(master, + dev->flags & IFF_ALLMULTI ? 1 : -1); + if (change & IFF_PROMISC) + dev_set_promiscuity(master, + dev->flags & IFF_PROMISC ? 1 : -1); + + if (dsa_switch_supports_uc_filtering(ds) && + dsa_switch_supports_mc_filtering(ds)) + dsa_slave_manage_host_flood(dev); +} + +static void dsa_slave_set_rx_mode(struct net_device *dev) +{ + __dev_mc_sync(dev, dsa_slave_sync_mc, dsa_slave_unsync_mc); + __dev_uc_sync(dev, dsa_slave_sync_uc, dsa_slave_unsync_uc); +} + +static int dsa_slave_set_mac_address(struct net_device *dev, void *a) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + struct sockaddr *addr = a; + int err; + + if (!is_valid_ether_addr(addr->sa_data)) + return -EADDRNOTAVAIL; + + /* If the port is down, the address isn't synced yet to hardware or + * to the DSA master, so there is nothing to change. + */ + if (!(dev->flags & IFF_UP)) + goto out_change_dev_addr; + + if (dsa_switch_supports_uc_filtering(ds)) { + err = dsa_port_standalone_host_fdb_add(dp, addr->sa_data, 0); + if (err) + return err; + } + + if (!ether_addr_equal(addr->sa_data, master->dev_addr)) { + err = dev_uc_add(master, addr->sa_data); + if (err < 0) + goto del_unicast; + } + + if (!ether_addr_equal(dev->dev_addr, master->dev_addr)) + dev_uc_del(master, dev->dev_addr); + + if (dsa_switch_supports_uc_filtering(ds)) + dsa_port_standalone_host_fdb_del(dp, dev->dev_addr, 0); + +out_change_dev_addr: + eth_hw_addr_set(dev, addr->sa_data); + + return 0; + +del_unicast: + if (dsa_switch_supports_uc_filtering(ds)) + dsa_port_standalone_host_fdb_del(dp, addr->sa_data, 0); + + return err; +} + +struct dsa_slave_dump_ctx { + struct net_device *dev; + struct sk_buff *skb; + struct netlink_callback *cb; + int idx; +}; + +static int +dsa_slave_port_fdb_do_dump(const unsigned char *addr, u16 vid, + bool is_static, void *data) +{ + struct dsa_slave_dump_ctx *dump = data; + u32 portid = NETLINK_CB(dump->cb->skb).portid; + u32 seq = dump->cb->nlh->nlmsg_seq; + struct nlmsghdr *nlh; + struct ndmsg *ndm; + + if (dump->idx < dump->cb->args[2]) + goto skip; + + nlh = nlmsg_put(dump->skb, portid, seq, RTM_NEWNEIGH, + sizeof(*ndm), NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + + ndm = nlmsg_data(nlh); + ndm->ndm_family = AF_BRIDGE; + ndm->ndm_pad1 = 0; + ndm->ndm_pad2 = 0; + ndm->ndm_flags = NTF_SELF; + ndm->ndm_type = 0; + ndm->ndm_ifindex = dump->dev->ifindex; + ndm->ndm_state = is_static ? 
NUD_NOARP : NUD_REACHABLE; + + if (nla_put(dump->skb, NDA_LLADDR, ETH_ALEN, addr)) + goto nla_put_failure; + + if (vid && nla_put_u16(dump->skb, NDA_VLAN, vid)) + goto nla_put_failure; + + nlmsg_end(dump->skb, nlh); + +skip: + dump->idx++; + return 0; + +nla_put_failure: + nlmsg_cancel(dump->skb, nlh); + return -EMSGSIZE; +} + +static int +dsa_slave_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb, + struct net_device *dev, struct net_device *filter_dev, + int *idx) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_slave_dump_ctx dump = { + .dev = dev, + .skb = skb, + .cb = cb, + .idx = *idx, + }; + int err; + + err = dsa_port_fdb_dump(dp, dsa_slave_port_fdb_do_dump, &dump); + *idx = dump.idx; + + return err; +} + +static int dsa_slave_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + struct dsa_switch *ds = p->dp->ds; + int port = p->dp->index; + + /* Pass through to switch driver if it supports timestamping */ + switch (cmd) { + case SIOCGHWTSTAMP: + if (ds->ops->port_hwtstamp_get) + return ds->ops->port_hwtstamp_get(ds, port, ifr); + break; + case SIOCSHWTSTAMP: + if (ds->ops->port_hwtstamp_set) + return ds->ops->port_hwtstamp_set(ds, port, ifr); + break; + } + + return phylink_mii_ioctl(p->dp->pl, ifr, cmd); +} + +static int dsa_slave_port_attr_set(struct net_device *dev, const void *ctx, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + int ret; + + if (ctx && ctx != dp) + return 0; + + switch (attr->id) { + case SWITCHDEV_ATTR_ID_PORT_STP_STATE: + if (!dsa_port_offloads_bridge_port(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_set_state(dp, attr->u.stp_state, true); + break; + case SWITCHDEV_ATTR_ID_PORT_MST_STATE: + if (!dsa_port_offloads_bridge_port(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_set_mst_state(dp, &attr->u.mst_state, extack); + break; + case SWITCHDEV_ATTR_ID_BRIDGE_VLAN_FILTERING: + if (!dsa_port_offloads_bridge_dev(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_vlan_filtering(dp, attr->u.vlan_filtering, + extack); + break; + case SWITCHDEV_ATTR_ID_BRIDGE_AGEING_TIME: + if (!dsa_port_offloads_bridge_dev(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_ageing_time(dp, attr->u.ageing_time); + break; + case SWITCHDEV_ATTR_ID_BRIDGE_MST: + if (!dsa_port_offloads_bridge_dev(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_mst_enable(dp, attr->u.mst, extack); + break; + case SWITCHDEV_ATTR_ID_PORT_PRE_BRIDGE_FLAGS: + if (!dsa_port_offloads_bridge_port(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_pre_bridge_flags(dp, attr->u.brport_flags, + extack); + break; + case SWITCHDEV_ATTR_ID_PORT_BRIDGE_FLAGS: + if (!dsa_port_offloads_bridge_port(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_bridge_flags(dp, attr->u.brport_flags, extack); + break; + case SWITCHDEV_ATTR_ID_VLAN_MSTI: + if (!dsa_port_offloads_bridge_dev(dp, attr->orig_dev)) + return -EOPNOTSUPP; + + ret = dsa_port_vlan_msti(dp, &attr->u.vlan_msti); + break; + default: + ret = -EOPNOTSUPP; + break; + } + + return ret; +} + +/* Must be called under rcu_read_lock() */ +static int +dsa_slave_vlan_check_for_8021q_uppers(struct net_device *slave, + const struct switchdev_obj_port_vlan *vlan) +{ + struct net_device *upper_dev; + struct list_head *iter; + + netdev_for_each_upper_dev_rcu(slave, upper_dev, iter) { + u16 vid; + + if (!is_vlan_dev(upper_dev)) + continue; 
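+		/* Only 8021q uppers matter here: a bridge VLAN that duplicates
+		 * the VID of such an upper would clash with it on this port,
+		 * which is why the comparison below returns -EBUSY.
+		 */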
+ + vid = vlan_dev_vlan_id(upper_dev); + if (vid == vlan->vid) + return -EBUSY; + } + + return 0; +} + +static int dsa_slave_vlan_add(struct net_device *dev, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan *vlan; + int err; + + if (dsa_port_skip_vlan_configuration(dp)) { + NL_SET_ERR_MSG_MOD(extack, "skipping configuration of VLAN"); + return 0; + } + + vlan = SWITCHDEV_OBJ_PORT_VLAN(obj); + + /* Deny adding a bridge VLAN when there is already an 802.1Q upper with + * the same VID. + */ + if (br_vlan_enabled(dsa_port_bridge_dev_get(dp))) { + rcu_read_lock(); + err = dsa_slave_vlan_check_for_8021q_uppers(dev, vlan); + rcu_read_unlock(); + if (err) { + NL_SET_ERR_MSG_MOD(extack, + "Port already has a VLAN upper with this VID"); + return err; + } + } + + return dsa_port_vlan_add(dp, vlan, extack); +} + +/* Offload a VLAN installed on the bridge or on a foreign interface by + * installing it as a VLAN towards the CPU port. + */ +static int dsa_slave_host_vlan_add(struct net_device *dev, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan vlan; + + /* Do nothing if this is a software bridge */ + if (!dp->bridge) + return -EOPNOTSUPP; + + if (dsa_port_skip_vlan_configuration(dp)) { + NL_SET_ERR_MSG_MOD(extack, "skipping configuration of VLAN"); + return 0; + } + + vlan = *SWITCHDEV_OBJ_PORT_VLAN(obj); + + /* Even though drivers often handle CPU membership in special ways, + * it doesn't make sense to program a PVID, so clear this flag. + */ + vlan.flags &= ~BRIDGE_VLAN_INFO_PVID; + + return dsa_port_host_vlan_add(dp, &vlan, extack); +} + +static int dsa_slave_port_obj_add(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + int err; + + if (ctx && ctx != dp) + return 0; + + switch (obj->id) { + case SWITCHDEV_OBJ_ID_PORT_MDB: + if (!dsa_port_offloads_bridge_port(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mdb_add(dp, SWITCHDEV_OBJ_PORT_MDB(obj)); + break; + case SWITCHDEV_OBJ_ID_HOST_MDB: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_bridge_host_mdb_add(dp, SWITCHDEV_OBJ_PORT_MDB(obj)); + break; + case SWITCHDEV_OBJ_ID_PORT_VLAN: + if (dsa_port_offloads_bridge_port(dp, obj->orig_dev)) + err = dsa_slave_vlan_add(dev, obj, extack); + else + err = dsa_slave_host_vlan_add(dev, obj, extack); + break; + case SWITCHDEV_OBJ_ID_MRP: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mrp_add(dp, SWITCHDEV_OBJ_MRP(obj)); + break; + case SWITCHDEV_OBJ_ID_RING_ROLE_MRP: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mrp_add_ring_role(dp, + SWITCHDEV_OBJ_RING_ROLE_MRP(obj)); + break; + default: + err = -EOPNOTSUPP; + break; + } + + return err; +} + +static int dsa_slave_vlan_del(struct net_device *dev, + const struct switchdev_obj *obj) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan *vlan; + + if (dsa_port_skip_vlan_configuration(dp)) + return 0; + + vlan = SWITCHDEV_OBJ_PORT_VLAN(obj); + + return dsa_port_vlan_del(dp, vlan); +} + +static int dsa_slave_host_vlan_del(struct net_device *dev, + const struct switchdev_obj *obj) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct 
switchdev_obj_port_vlan *vlan; + + /* Do nothing if this is a software bridge */ + if (!dp->bridge) + return -EOPNOTSUPP; + + if (dsa_port_skip_vlan_configuration(dp)) + return 0; + + vlan = SWITCHDEV_OBJ_PORT_VLAN(obj); + + return dsa_port_host_vlan_del(dp, vlan); +} + +static int dsa_slave_port_obj_del(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + int err; + + if (ctx && ctx != dp) + return 0; + + switch (obj->id) { + case SWITCHDEV_OBJ_ID_PORT_MDB: + if (!dsa_port_offloads_bridge_port(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mdb_del(dp, SWITCHDEV_OBJ_PORT_MDB(obj)); + break; + case SWITCHDEV_OBJ_ID_HOST_MDB: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_bridge_host_mdb_del(dp, SWITCHDEV_OBJ_PORT_MDB(obj)); + break; + case SWITCHDEV_OBJ_ID_PORT_VLAN: + if (dsa_port_offloads_bridge_port(dp, obj->orig_dev)) + err = dsa_slave_vlan_del(dev, obj); + else + err = dsa_slave_host_vlan_del(dev, obj); + break; + case SWITCHDEV_OBJ_ID_MRP: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mrp_del(dp, SWITCHDEV_OBJ_MRP(obj)); + break; + case SWITCHDEV_OBJ_ID_RING_ROLE_MRP: + if (!dsa_port_offloads_bridge_dev(dp, obj->orig_dev)) + return -EOPNOTSUPP; + + err = dsa_port_mrp_del_ring_role(dp, + SWITCHDEV_OBJ_RING_ROLE_MRP(obj)); + break; + default: + err = -EOPNOTSUPP; + break; + } + + return err; +} + +static inline netdev_tx_t dsa_slave_netpoll_send_skb(struct net_device *dev, + struct sk_buff *skb) +{ +#ifdef CONFIG_NET_POLL_CONTROLLER + struct dsa_slave_priv *p = netdev_priv(dev); + + return netpoll_send_skb(p->netpoll, skb); +#else + BUG(); + return NETDEV_TX_OK; +#endif +} + +static void dsa_skb_tx_timestamp(struct dsa_slave_priv *p, + struct sk_buff *skb) +{ + struct dsa_switch *ds = p->dp->ds; + + if (!(skb_shinfo(skb)->tx_flags & SKBTX_HW_TSTAMP)) + return; + + if (!ds->ops->port_txtstamp) + return; + + ds->ops->port_txtstamp(ds, p->dp->index, skb); +} + +netdev_tx_t dsa_enqueue_skb(struct sk_buff *skb, struct net_device *dev) +{ + /* SKB for netpoll still need to be mangled with the protocol-specific + * tag to be successfully transmitted + */ + if (unlikely(netpoll_tx_running(dev))) + return dsa_slave_netpoll_send_skb(dev, skb); + + /* Queue the SKB for transmission on the parent interface, but + * do not modify its EtherType + */ + skb->dev = dsa_slave_to_master(dev); + dev_queue_xmit(skb); + + return NETDEV_TX_OK; +} +EXPORT_SYMBOL_GPL(dsa_enqueue_skb); + +static int dsa_realloc_skb(struct sk_buff *skb, struct net_device *dev) +{ + int needed_headroom = dev->needed_headroom; + int needed_tailroom = dev->needed_tailroom; + + /* For tail taggers, we need to pad short frames ourselves, to ensure + * that the tail tag does not fail at its role of being at the end of + * the packet, once the master interface pads the frame. Account for + * that pad length here, and pad later. + */ + if (unlikely(needed_tailroom && skb->len < ETH_ZLEN)) + needed_tailroom += ETH_ZLEN - skb->len; + /* skb_headroom() returns unsigned int... */ + needed_headroom = max_t(int, needed_headroom - skb_headroom(skb), 0); + needed_tailroom = max_t(int, needed_tailroom - skb_tailroom(skb), 0); + + if (likely(!needed_headroom && !needed_tailroom && !skb_cloned(skb))) + /* No reallocation needed, yay! 
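+	 * Otherwise we fall through to pskb_expand_head() below, which grows
+	 * the head and tail room to what the tagging protocol requires.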
*/ + return 0; + + return pskb_expand_head(skb, needed_headroom, needed_tailroom, + GFP_ATOMIC); +} + +static netdev_tx_t dsa_slave_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + struct sk_buff *nskb; + + dev_sw_netstats_tx_add(dev, 1, skb->len); + + memset(skb->cb, 0, sizeof(skb->cb)); + + /* Handle tx timestamp if any */ + dsa_skb_tx_timestamp(p, skb); + + if (dsa_realloc_skb(skb, dev)) { + dev_kfree_skb_any(skb); + return NETDEV_TX_OK; + } + + /* needed_tailroom should still be 'warm' in the cache line from + * dsa_realloc_skb(), which has also ensured that padding is safe. + */ + if (dev->needed_tailroom) + eth_skb_pad(skb); + + /* Transmit function may have to reallocate the original SKB, + * in which case it must have freed it. Only free it here on error. + */ + nskb = p->xmit(skb, dev); + if (!nskb) { + kfree_skb(skb); + return NETDEV_TX_OK; + } + + return dsa_enqueue_skb(nskb, dev); +} + +/* ethtool operations *******************************************************/ + +static void dsa_slave_get_drvinfo(struct net_device *dev, + struct ethtool_drvinfo *drvinfo) +{ + strscpy(drvinfo->driver, "dsa", sizeof(drvinfo->driver)); + strscpy(drvinfo->fw_version, "N/A", sizeof(drvinfo->fw_version)); + strscpy(drvinfo->bus_info, "platform", sizeof(drvinfo->bus_info)); +} + +static int dsa_slave_get_regs_len(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_regs_len) + return ds->ops->get_regs_len(ds, dp->index); + + return -EOPNOTSUPP; +} + +static void +dsa_slave_get_regs(struct net_device *dev, struct ethtool_regs *regs, void *_p) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_regs) + ds->ops->get_regs(ds, dp->index, regs, _p); +} + +static int dsa_slave_nway_reset(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return phylink_ethtool_nway_reset(dp->pl); +} + +static int dsa_slave_get_eeprom_len(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->cd && ds->cd->eeprom_len) + return ds->cd->eeprom_len; + + if (ds->ops->get_eeprom_len) + return ds->ops->get_eeprom_len(ds); + + return 0; +} + +static int dsa_slave_get_eeprom(struct net_device *dev, + struct ethtool_eeprom *eeprom, u8 *data) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_eeprom) + return ds->ops->get_eeprom(ds, eeprom, data); + + return -EOPNOTSUPP; +} + +static int dsa_slave_set_eeprom(struct net_device *dev, + struct ethtool_eeprom *eeprom, u8 *data) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->set_eeprom) + return ds->ops->set_eeprom(ds, eeprom, data); + + return -EOPNOTSUPP; +} + +static void dsa_slave_get_strings(struct net_device *dev, + uint32_t stringset, uint8_t *data) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (stringset == ETH_SS_STATS) { + int len = ETH_GSTRING_LEN; + + strncpy(data, "tx_packets", len); + strncpy(data + len, "tx_bytes", len); + strncpy(data + 2 * len, "rx_packets", len); + strncpy(data + 3 * len, "rx_bytes", len); + if (ds->ops->get_strings) + ds->ops->get_strings(ds, dp->index, stringset, + data + 4 * len); + } else if (stringset == ETH_SS_TEST) { + net_selftest_get_strings(data); + } + +} + +static void dsa_slave_get_ethtool_stats(struct net_device 
*dev, + struct ethtool_stats *stats, + uint64_t *data) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + struct pcpu_sw_netstats *s; + unsigned int start; + int i; + + for_each_possible_cpu(i) { + u64 tx_packets, tx_bytes, rx_packets, rx_bytes; + + s = per_cpu_ptr(dev->tstats, i); + do { + start = u64_stats_fetch_begin_irq(&s->syncp); + tx_packets = u64_stats_read(&s->tx_packets); + tx_bytes = u64_stats_read(&s->tx_bytes); + rx_packets = u64_stats_read(&s->rx_packets); + rx_bytes = u64_stats_read(&s->rx_bytes); + } while (u64_stats_fetch_retry_irq(&s->syncp, start)); + data[0] += tx_packets; + data[1] += tx_bytes; + data[2] += rx_packets; + data[3] += rx_bytes; + } + if (ds->ops->get_ethtool_stats) + ds->ops->get_ethtool_stats(ds, dp->index, data + 4); +} + +static int dsa_slave_get_sset_count(struct net_device *dev, int sset) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (sset == ETH_SS_STATS) { + int count = 0; + + if (ds->ops->get_sset_count) { + count = ds->ops->get_sset_count(ds, dp->index, sset); + if (count < 0) + return count; + } + + return count + 4; + } else if (sset == ETH_SS_TEST) { + return net_selftest_get_count(); + } + + return -EOPNOTSUPP; +} + +static void dsa_slave_get_eth_phy_stats(struct net_device *dev, + struct ethtool_eth_phy_stats *phy_stats) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_eth_phy_stats) + ds->ops->get_eth_phy_stats(ds, dp->index, phy_stats); +} + +static void dsa_slave_get_eth_mac_stats(struct net_device *dev, + struct ethtool_eth_mac_stats *mac_stats) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_eth_mac_stats) + ds->ops->get_eth_mac_stats(ds, dp->index, mac_stats); +} + +static void +dsa_slave_get_eth_ctrl_stats(struct net_device *dev, + struct ethtool_eth_ctrl_stats *ctrl_stats) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_eth_ctrl_stats) + ds->ops->get_eth_ctrl_stats(ds, dp->index, ctrl_stats); +} + +static void +dsa_slave_get_rmon_stats(struct net_device *dev, + struct ethtool_rmon_stats *rmon_stats, + const struct ethtool_rmon_hist_range **ranges) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_rmon_stats) + ds->ops->get_rmon_stats(ds, dp->index, rmon_stats, ranges); +} + +static void dsa_slave_net_selftest(struct net_device *ndev, + struct ethtool_test *etest, u64 *buf) +{ + struct dsa_port *dp = dsa_slave_to_port(ndev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->self_test) { + ds->ops->self_test(ds, dp->index, etest, buf); + return; + } + + net_selftest(ndev, etest, buf); +} + +static void dsa_slave_get_wol(struct net_device *dev, struct ethtool_wolinfo *w) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + phylink_ethtool_get_wol(dp->pl, w); + + if (ds->ops->get_wol) + ds->ops->get_wol(ds, dp->index, w); +} + +static int dsa_slave_set_wol(struct net_device *dev, struct ethtool_wolinfo *w) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int ret = -EOPNOTSUPP; + + phylink_ethtool_set_wol(dp->pl, w); + + if (ds->ops->set_wol) + ret = ds->ops->set_wol(ds, dp->index, w); + + return ret; +} + +static int dsa_slave_set_eee(struct net_device *dev, struct ethtool_eee *e) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = 
dp->ds; + int ret; + + /* Port's PHY and MAC both need to be EEE capable */ + if (!dev->phydev || !dp->pl) + return -ENODEV; + + if (!ds->ops->set_mac_eee) + return -EOPNOTSUPP; + + ret = ds->ops->set_mac_eee(ds, dp->index, e); + if (ret) + return ret; + + return phylink_ethtool_set_eee(dp->pl, e); +} + +static int dsa_slave_get_eee(struct net_device *dev, struct ethtool_eee *e) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int ret; + + /* Port's PHY and MAC both need to be EEE capable */ + if (!dev->phydev || !dp->pl) + return -ENODEV; + + if (!ds->ops->get_mac_eee) + return -EOPNOTSUPP; + + ret = ds->ops->get_mac_eee(ds, dp->index, e); + if (ret) + return ret; + + return phylink_ethtool_get_eee(dp->pl, e); +} + +static int dsa_slave_get_link_ksettings(struct net_device *dev, + struct ethtool_link_ksettings *cmd) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return phylink_ethtool_ksettings_get(dp->pl, cmd); +} + +static int dsa_slave_set_link_ksettings(struct net_device *dev, + const struct ethtool_link_ksettings *cmd) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return phylink_ethtool_ksettings_set(dp->pl, cmd); +} + +static void dsa_slave_get_pause_stats(struct net_device *dev, + struct ethtool_pause_stats *pause_stats) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_pause_stats) + ds->ops->get_pause_stats(ds, dp->index, pause_stats); +} + +static void dsa_slave_get_pauseparam(struct net_device *dev, + struct ethtool_pauseparam *pause) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + phylink_ethtool_get_pauseparam(dp->pl, pause); +} + +static int dsa_slave_set_pauseparam(struct net_device *dev, + struct ethtool_pauseparam *pause) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return phylink_ethtool_set_pauseparam(dp->pl, pause); +} + +#ifdef CONFIG_NET_POLL_CONTROLLER +static int dsa_slave_netpoll_setup(struct net_device *dev, + struct netpoll_info *ni) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_slave_priv *p = netdev_priv(dev); + struct netpoll *netpoll; + int err = 0; + + netpoll = kzalloc(sizeof(*netpoll), GFP_KERNEL); + if (!netpoll) + return -ENOMEM; + + err = __netpoll_setup(netpoll, master); + if (err) { + kfree(netpoll); + goto out; + } + + p->netpoll = netpoll; +out: + return err; +} + +static void dsa_slave_netpoll_cleanup(struct net_device *dev) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + struct netpoll *netpoll = p->netpoll; + + if (!netpoll) + return; + + p->netpoll = NULL; + + __netpoll_free(netpoll); +} + +static void dsa_slave_poll_controller(struct net_device *dev) +{ +} +#endif + +static struct dsa_mall_tc_entry * +dsa_slave_mall_tc_entry_find(struct net_device *dev, unsigned long cookie) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + struct dsa_mall_tc_entry *mall_tc_entry; + + list_for_each_entry(mall_tc_entry, &p->mall_tc_list, list) + if (mall_tc_entry->cookie == cookie) + return mall_tc_entry; + + return NULL; +} + +static int +dsa_slave_add_cls_matchall_mirred(struct net_device *dev, + struct tc_cls_matchall_offload *cls, + bool ingress) +{ + struct netlink_ext_ack *extack = cls->common.extack; + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_slave_priv *p = netdev_priv(dev); + struct dsa_mall_mirror_tc_entry *mirror; + struct dsa_mall_tc_entry *mall_tc_entry; + struct dsa_switch *ds = dp->ds; + struct flow_action_entry *act; + struct dsa_port *to_dp; + int err; + + if 
(!ds->ops->port_mirror_add) + return -EOPNOTSUPP; + + if (!flow_action_basic_hw_stats_check(&cls->rule->action, + cls->common.extack)) + return -EOPNOTSUPP; + + act = &cls->rule->action.entries[0]; + + if (!act->dev) + return -EINVAL; + + if (!dsa_slave_dev_check(act->dev)) + return -EOPNOTSUPP; + + mall_tc_entry = kzalloc(sizeof(*mall_tc_entry), GFP_KERNEL); + if (!mall_tc_entry) + return -ENOMEM; + + mall_tc_entry->cookie = cls->cookie; + mall_tc_entry->type = DSA_PORT_MALL_MIRROR; + mirror = &mall_tc_entry->mirror; + + to_dp = dsa_slave_to_port(act->dev); + + mirror->to_local_port = to_dp->index; + mirror->ingress = ingress; + + err = ds->ops->port_mirror_add(ds, dp->index, mirror, ingress, extack); + if (err) { + kfree(mall_tc_entry); + return err; + } + + list_add_tail(&mall_tc_entry->list, &p->mall_tc_list); + + return err; +} + +static int +dsa_slave_add_cls_matchall_police(struct net_device *dev, + struct tc_cls_matchall_offload *cls, + bool ingress) +{ + struct netlink_ext_ack *extack = cls->common.extack; + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_slave_priv *p = netdev_priv(dev); + struct dsa_mall_policer_tc_entry *policer; + struct dsa_mall_tc_entry *mall_tc_entry; + struct dsa_switch *ds = dp->ds; + struct flow_action_entry *act; + int err; + + if (!ds->ops->port_policer_add) { + NL_SET_ERR_MSG_MOD(extack, + "Policing offload not implemented"); + return -EOPNOTSUPP; + } + + if (!ingress) { + NL_SET_ERR_MSG_MOD(extack, + "Only supported on ingress qdisc"); + return -EOPNOTSUPP; + } + + if (!flow_action_basic_hw_stats_check(&cls->rule->action, + cls->common.extack)) + return -EOPNOTSUPP; + + list_for_each_entry(mall_tc_entry, &p->mall_tc_list, list) { + if (mall_tc_entry->type == DSA_PORT_MALL_POLICER) { + NL_SET_ERR_MSG_MOD(extack, + "Only one port policer allowed"); + return -EEXIST; + } + } + + act = &cls->rule->action.entries[0]; + + mall_tc_entry = kzalloc(sizeof(*mall_tc_entry), GFP_KERNEL); + if (!mall_tc_entry) + return -ENOMEM; + + mall_tc_entry->cookie = cls->cookie; + mall_tc_entry->type = DSA_PORT_MALL_POLICER; + policer = &mall_tc_entry->policer; + policer->rate_bytes_per_sec = act->police.rate_bytes_ps; + policer->burst = act->police.burst; + + err = ds->ops->port_policer_add(ds, dp->index, policer); + if (err) { + kfree(mall_tc_entry); + return err; + } + + list_add_tail(&mall_tc_entry->list, &p->mall_tc_list); + + return err; +} + +static int dsa_slave_add_cls_matchall(struct net_device *dev, + struct tc_cls_matchall_offload *cls, + bool ingress) +{ + int err = -EOPNOTSUPP; + + if (cls->common.protocol == htons(ETH_P_ALL) && + flow_offload_has_one_action(&cls->rule->action) && + cls->rule->action.entries[0].id == FLOW_ACTION_MIRRED) + err = dsa_slave_add_cls_matchall_mirred(dev, cls, ingress); + else if (flow_offload_has_one_action(&cls->rule->action) && + cls->rule->action.entries[0].id == FLOW_ACTION_POLICE) + err = dsa_slave_add_cls_matchall_police(dev, cls, ingress); + + return err; +} + +static void dsa_slave_del_cls_matchall(struct net_device *dev, + struct tc_cls_matchall_offload *cls) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_mall_tc_entry *mall_tc_entry; + struct dsa_switch *ds = dp->ds; + + mall_tc_entry = dsa_slave_mall_tc_entry_find(dev, cls->cookie); + if (!mall_tc_entry) + return; + + list_del(&mall_tc_entry->list); + + switch (mall_tc_entry->type) { + case DSA_PORT_MALL_MIRROR: + if (ds->ops->port_mirror_del) + ds->ops->port_mirror_del(ds, dp->index, + &mall_tc_entry->mirror); + break; + case 
DSA_PORT_MALL_POLICER: + if (ds->ops->port_policer_del) + ds->ops->port_policer_del(ds, dp->index); + break; + default: + WARN_ON(1); + } + + kfree(mall_tc_entry); +} + +static int dsa_slave_setup_tc_cls_matchall(struct net_device *dev, + struct tc_cls_matchall_offload *cls, + bool ingress) +{ + if (cls->common.chain_index) + return -EOPNOTSUPP; + + switch (cls->command) { + case TC_CLSMATCHALL_REPLACE: + return dsa_slave_add_cls_matchall(dev, cls, ingress); + case TC_CLSMATCHALL_DESTROY: + dsa_slave_del_cls_matchall(dev, cls); + return 0; + default: + return -EOPNOTSUPP; + } +} + +static int dsa_slave_add_cls_flower(struct net_device *dev, + struct flow_cls_offload *cls, + bool ingress) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (!ds->ops->cls_flower_add) + return -EOPNOTSUPP; + + return ds->ops->cls_flower_add(ds, port, cls, ingress); +} + +static int dsa_slave_del_cls_flower(struct net_device *dev, + struct flow_cls_offload *cls, + bool ingress) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (!ds->ops->cls_flower_del) + return -EOPNOTSUPP; + + return ds->ops->cls_flower_del(ds, port, cls, ingress); +} + +static int dsa_slave_stats_cls_flower(struct net_device *dev, + struct flow_cls_offload *cls, + bool ingress) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int port = dp->index; + + if (!ds->ops->cls_flower_stats) + return -EOPNOTSUPP; + + return ds->ops->cls_flower_stats(ds, port, cls, ingress); +} + +static int dsa_slave_setup_tc_cls_flower(struct net_device *dev, + struct flow_cls_offload *cls, + bool ingress) +{ + switch (cls->command) { + case FLOW_CLS_REPLACE: + return dsa_slave_add_cls_flower(dev, cls, ingress); + case FLOW_CLS_DESTROY: + return dsa_slave_del_cls_flower(dev, cls, ingress); + case FLOW_CLS_STATS: + return dsa_slave_stats_cls_flower(dev, cls, ingress); + default: + return -EOPNOTSUPP; + } +} + +static int dsa_slave_setup_tc_block_cb(enum tc_setup_type type, void *type_data, + void *cb_priv, bool ingress) +{ + struct net_device *dev = cb_priv; + + if (!tc_can_offload(dev)) + return -EOPNOTSUPP; + + switch (type) { + case TC_SETUP_CLSMATCHALL: + return dsa_slave_setup_tc_cls_matchall(dev, type_data, ingress); + case TC_SETUP_CLSFLOWER: + return dsa_slave_setup_tc_cls_flower(dev, type_data, ingress); + default: + return -EOPNOTSUPP; + } +} + +static int dsa_slave_setup_tc_block_cb_ig(enum tc_setup_type type, + void *type_data, void *cb_priv) +{ + return dsa_slave_setup_tc_block_cb(type, type_data, cb_priv, true); +} + +static int dsa_slave_setup_tc_block_cb_eg(enum tc_setup_type type, + void *type_data, void *cb_priv) +{ + return dsa_slave_setup_tc_block_cb(type, type_data, cb_priv, false); +} + +static LIST_HEAD(dsa_slave_block_cb_list); + +static int dsa_slave_setup_tc_block(struct net_device *dev, + struct flow_block_offload *f) +{ + struct flow_block_cb *block_cb; + flow_setup_cb_t *cb; + + if (f->binder_type == FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS) + cb = dsa_slave_setup_tc_block_cb_ig; + else if (f->binder_type == FLOW_BLOCK_BINDER_TYPE_CLSACT_EGRESS) + cb = dsa_slave_setup_tc_block_cb_eg; + else + return -EOPNOTSUPP; + + f->driver_block_list = &dsa_slave_block_cb_list; + + switch (f->command) { + case FLOW_BLOCK_BIND: + if (flow_block_cb_is_busy(cb, dev, &dsa_slave_block_cb_list)) + return -EBUSY; + + block_cb = flow_block_cb_alloc(cb, dev, dev, NULL); + if (IS_ERR(block_cb)) + 
return PTR_ERR(block_cb); + + flow_block_cb_add(block_cb, f); + list_add_tail(&block_cb->driver_list, &dsa_slave_block_cb_list); + return 0; + case FLOW_BLOCK_UNBIND: + block_cb = flow_block_cb_lookup(f->block, cb, dev); + if (!block_cb) + return -ENOENT; + + flow_block_cb_remove(block_cb, f); + list_del(&block_cb->driver_list); + return 0; + default: + return -EOPNOTSUPP; + } +} + +static int dsa_slave_setup_ft_block(struct dsa_switch *ds, int port, + void *type_data) +{ + struct net_device *master = dsa_port_to_master(dsa_to_port(ds, port)); + + if (!master->netdev_ops->ndo_setup_tc) + return -EOPNOTSUPP; + + return master->netdev_ops->ndo_setup_tc(master, TC_SETUP_FT, type_data); +} + +static int dsa_slave_setup_tc(struct net_device *dev, enum tc_setup_type type, + void *type_data) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + switch (type) { + case TC_SETUP_BLOCK: + return dsa_slave_setup_tc_block(dev, type_data); + case TC_SETUP_FT: + return dsa_slave_setup_ft_block(ds, dp->index, type_data); + default: + break; + } + + if (!ds->ops->port_setup_tc) + return -EOPNOTSUPP; + + return ds->ops->port_setup_tc(ds, dp->index, type, type_data); +} + +static int dsa_slave_get_rxnfc(struct net_device *dev, + struct ethtool_rxnfc *nfc, u32 *rule_locs) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->get_rxnfc) + return -EOPNOTSUPP; + + return ds->ops->get_rxnfc(ds, dp->index, nfc, rule_locs); +} + +static int dsa_slave_set_rxnfc(struct net_device *dev, + struct ethtool_rxnfc *nfc) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (!ds->ops->set_rxnfc) + return -EOPNOTSUPP; + + return ds->ops->set_rxnfc(ds, dp->index, nfc); +} + +static int dsa_slave_get_ts_info(struct net_device *dev, + struct ethtool_ts_info *ts) +{ + struct dsa_slave_priv *p = netdev_priv(dev); + struct dsa_switch *ds = p->dp->ds; + + if (!ds->ops->get_ts_info) + return -EOPNOTSUPP; + + return ds->ops->get_ts_info(ds, p->dp->index, ts); +} + +static int dsa_slave_vlan_rx_add_vid(struct net_device *dev, __be16 proto, + u16 vid) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan vlan = { + .obj.id = SWITCHDEV_OBJ_ID_PORT_VLAN, + .vid = vid, + /* This API only allows programming tagged, non-PVID VIDs */ + .flags = 0, + }; + struct netlink_ext_ack extack = {0}; + int ret; + + /* User port... */ + ret = dsa_port_vlan_add(dp, &vlan, &extack); + if (ret) { + if (extack._msg) + netdev_err(dev, "%s\n", extack._msg); + return ret; + } + + /* And CPU port... */ + ret = dsa_port_host_vlan_add(dp, &vlan, &extack); + if (ret) { + if (extack._msg) + netdev_err(dev, "CPU port %d: %s\n", dp->cpu_dp->index, + extack._msg); + return ret; + } + + return 0; +} + +static int dsa_slave_vlan_rx_kill_vid(struct net_device *dev, __be16 proto, + u16 vid) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct switchdev_obj_port_vlan vlan = { + .vid = vid, + /* This API only allows programming tagged, non-PVID VIDs */ + .flags = 0, + }; + int err; + + err = dsa_port_vlan_del(dp, &vlan); + if (err) + return err; + + return dsa_port_host_vlan_del(dp, &vlan); +} + +static int dsa_slave_restore_vlan(struct net_device *vdev, int vid, void *arg) +{ + __be16 proto = vdev ? vlan_dev_vlan_proto(vdev) : htons(ETH_P_8021Q); + + return dsa_slave_vlan_rx_add_vid(arg, proto, vid); +} + +static int dsa_slave_clear_vlan(struct net_device *vdev, int vid, void *arg) +{ + __be16 proto = vdev ? 
vlan_dev_vlan_proto(vdev) : htons(ETH_P_8021Q); + + return dsa_slave_vlan_rx_kill_vid(arg, proto, vid); +} + +/* Keep the VLAN RX filtering list in sync with the hardware only if VLAN + * filtering is enabled. The baseline is that only ports that offload a + * VLAN-aware bridge are VLAN-aware, and standalone ports are VLAN-unaware, + * but there are exceptions for quirky hardware. + * + * If ds->vlan_filtering_is_global = true, then standalone ports which share + * the same switch with other ports that offload a VLAN-aware bridge are also + * inevitably VLAN-aware. + * + * To summarize, a DSA switch port offloads: + * + * - If standalone (this includes software bridge, software LAG): + * - if ds->needs_standalone_vlan_filtering = true, OR if + * (ds->vlan_filtering_is_global = true AND there are bridges spanning + * this switch chip which have vlan_filtering=1) + * - the 8021q upper VLANs + * - else (standalone VLAN filtering is not needed, VLAN filtering is not + * global, or it is, but no port is under a VLAN-aware bridge): + * - no VLAN (any 8021q upper is a software VLAN) + * + * - If under a vlan_filtering=0 bridge which it offload: + * - if ds->configure_vlan_while_not_filtering = true (default): + * - the bridge VLANs. These VLANs are committed to hardware but inactive. + * - else (deprecated): + * - no VLAN. The bridge VLANs are not restored when VLAN awareness is + * enabled, so this behavior is broken and discouraged. + * + * - If under a vlan_filtering=1 bridge which it offload: + * - the bridge VLANs + * - the 8021q upper VLANs + */ +int dsa_slave_manage_vlan_filtering(struct net_device *slave, + bool vlan_filtering) +{ + int err; + + if (vlan_filtering) { + slave->features |= NETIF_F_HW_VLAN_CTAG_FILTER; + + err = vlan_for_each(slave, dsa_slave_restore_vlan, slave); + if (err) { + vlan_for_each(slave, dsa_slave_clear_vlan, slave); + slave->features &= ~NETIF_F_HW_VLAN_CTAG_FILTER; + return err; + } + } else { + err = vlan_for_each(slave, dsa_slave_clear_vlan, slave); + if (err) + return err; + + slave->features &= ~NETIF_F_HW_VLAN_CTAG_FILTER; + } + + return 0; +} + +struct dsa_hw_port { + struct list_head list; + struct net_device *dev; + int old_mtu; +}; + +static int dsa_hw_port_list_set_mtu(struct list_head *hw_port_list, int mtu) +{ + const struct dsa_hw_port *p; + int err; + + list_for_each_entry(p, hw_port_list, list) { + if (p->dev->mtu == mtu) + continue; + + err = dev_set_mtu(p->dev, mtu); + if (err) + goto rollback; + } + + return 0; + +rollback: + list_for_each_entry_continue_reverse(p, hw_port_list, list) { + if (p->dev->mtu == p->old_mtu) + continue; + + if (dev_set_mtu(p->dev, p->old_mtu)) + netdev_err(p->dev, "Failed to restore MTU\n"); + } + + return err; +} + +static void dsa_hw_port_list_free(struct list_head *hw_port_list) +{ + struct dsa_hw_port *p, *n; + + list_for_each_entry_safe(p, n, hw_port_list, list) + kfree(p); +} + +/* Make the hardware datapath to/from @dev limited to a common MTU */ +static void dsa_bridge_mtu_normalization(struct dsa_port *dp) +{ + struct list_head hw_port_list; + struct dsa_switch_tree *dst; + int min_mtu = ETH_MAX_MTU; + struct dsa_port *other_dp; + int err; + + if (!dp->ds->mtu_enforcement_ingress) + return; + + if (!dp->bridge) + return; + + INIT_LIST_HEAD(&hw_port_list); + + /* Populate the list of ports that are part of the same bridge + * as the newly added/modified port + */ + list_for_each_entry(dst, &dsa_tree_list, list) { + list_for_each_entry(other_dp, &dst->ports, list) { + struct dsa_hw_port *hw_port; + struct 
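/*
 * Standalone sketch (not kernel code) of the rollback pattern used by
 * dsa_hw_port_list_set_mtu() above: walk a list applying a new MTU, and if
 * any entry fails, restore the saved old MTU on the entries already changed,
 * in reverse order. A plain array and a hypothetical set_mtu() stand in for
 * the kernel list and dev_set_mtu().
 */
#include <stdio.h>

struct hw_port { int mtu; int old_mtu; };

static int set_mtu(struct hw_port *p, int mtu)
{
	if (mtu > 9000)		/* pretend the hardware rejects larger sizes */
		return -1;
	p->mtu = mtu;
	return 0;
}

static int ports_set_mtu(struct hw_port *ports, int n, int mtu)
{
	int i, err;

	for (i = 0; i < n; i++) {
		ports[i].old_mtu = ports[i].mtu;
		err = set_mtu(&ports[i], mtu);
		if (err)
			goto rollback;
	}
	return 0;

rollback:
	while (--i >= 0)	/* undo only what was successfully applied */
		set_mtu(&ports[i], ports[i].old_mtu);
	return err;
}

int main(void)
{
	struct hw_port ports[3] = { { 1500 }, { 1500 }, { 1500 } };

	printf("%d\n", ports_set_mtu(ports, 3, 9216));	/* fails, rolled back */
	printf("%d\n", ports[0].mtu);			/* still 1500 */
	return 0;
}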
net_device *slave; + + if (other_dp->type != DSA_PORT_TYPE_USER) + continue; + + if (!dsa_port_bridge_same(dp, other_dp)) + continue; + + if (!other_dp->ds->mtu_enforcement_ingress) + continue; + + slave = other_dp->slave; + + if (min_mtu > slave->mtu) + min_mtu = slave->mtu; + + hw_port = kzalloc(sizeof(*hw_port), GFP_KERNEL); + if (!hw_port) + goto out; + + hw_port->dev = slave; + hw_port->old_mtu = slave->mtu; + + list_add(&hw_port->list, &hw_port_list); + } + } + + /* Attempt to configure the entire hardware bridge to the newly added + * interface's MTU first, regardless of whether the intention of the + * user was to raise or lower it. + */ + err = dsa_hw_port_list_set_mtu(&hw_port_list, dp->slave->mtu); + if (!err) + goto out; + + /* Clearly that didn't work out so well, so just set the minimum MTU on + * all hardware bridge ports now. If this fails too, then all ports will + * still have their old MTU rolled back anyway. + */ + dsa_hw_port_list_set_mtu(&hw_port_list, min_mtu); + +out: + dsa_hw_port_list_free(&hw_port_list); +} + +int dsa_slave_change_mtu(struct net_device *dev, int new_mtu) +{ + struct net_device *master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_port *cpu_dp = dp->cpu_dp; + struct dsa_switch *ds = dp->ds; + struct dsa_port *other_dp; + int largest_mtu = 0; + int new_master_mtu; + int old_master_mtu; + int mtu_limit; + int overhead; + int cpu_mtu; + int err; + + if (!ds->ops->port_change_mtu) + return -EOPNOTSUPP; + + dsa_tree_for_each_user_port(other_dp, ds->dst) { + int slave_mtu; + + /* During probe, this function will be called for each slave + * device, while not all of them have been allocated. That's + * ok, it doesn't change what the maximum is, so ignore it. + */ + if (!other_dp->slave) + continue; + + /* Pretend that we already applied the setting, which we + * actually haven't (still haven't done all integrity checks) + */ + if (dp == other_dp) + slave_mtu = new_mtu; + else + slave_mtu = other_dp->slave->mtu; + + if (largest_mtu < slave_mtu) + largest_mtu = slave_mtu; + } + + overhead = dsa_tag_protocol_overhead(cpu_dp->tag_ops); + mtu_limit = min_t(int, master->max_mtu, dev->max_mtu + overhead); + old_master_mtu = master->mtu; + new_master_mtu = largest_mtu + overhead; + if (new_master_mtu > mtu_limit) + return -ERANGE; + + /* If the master MTU isn't over limit, there's no need to check the CPU + * MTU, since that surely isn't either. + */ + cpu_mtu = largest_mtu; + + /* Start applying stuff */ + if (new_master_mtu != old_master_mtu) { + err = dev_set_mtu(master, new_master_mtu); + if (err < 0) + goto out_master_failed; + + /* We only need to propagate the MTU of the CPU port to + * upstream switches, so emit a notifier which updates them. 
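/*
 * Quick numeric illustration (not kernel code, simplified) of the MTU
 * bookkeeping in dsa_slave_change_mtu() above: the CPU-facing master must
 * carry the largest user-port MTU plus the tagging protocol overhead, and the
 * request is rejected when that exceeds the limit. The numbers and the
 * single-limit check are made up for the example.
 */
#include <stdio.h>

int main(void)
{
	int largest_user_mtu = 1500;	/* max over all user ports, incl. the new value */
	int tag_overhead = 8;		/* e.g. a hypothetical 8-byte switch tag */
	int master_max_mtu = 9000;

	int new_master_mtu = largest_user_mtu + tag_overhead;

	if (new_master_mtu > master_max_mtu) {
		fprintf(stderr, "out of range\n");
		return 1;
	}
	printf("master MTU becomes %d\n", new_master_mtu);	/* 1508 */
	return 0;
}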
+ */ + err = dsa_port_mtu_change(cpu_dp, cpu_mtu); + if (err) + goto out_cpu_failed; + } + + err = ds->ops->port_change_mtu(ds, dp->index, new_mtu); + if (err) + goto out_port_failed; + + dev->mtu = new_mtu; + + dsa_bridge_mtu_normalization(dp); + + return 0; + +out_port_failed: + if (new_master_mtu != old_master_mtu) + dsa_port_mtu_change(cpu_dp, old_master_mtu - overhead); +out_cpu_failed: + if (new_master_mtu != old_master_mtu) + dev_set_mtu(master, old_master_mtu); +out_master_failed: + return err; +} + +static int __maybe_unused +dsa_slave_dcbnl_set_default_prio(struct net_device *dev, struct dcb_app *app) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + unsigned long mask, new_prio; + int err, port = dp->index; + + if (!ds->ops->port_set_default_prio) + return -EOPNOTSUPP; + + err = dcb_ieee_setapp(dev, app); + if (err) + return err; + + mask = dcb_ieee_getapp_mask(dev, app); + new_prio = __fls(mask); + + err = ds->ops->port_set_default_prio(ds, port, new_prio); + if (err) { + dcb_ieee_delapp(dev, app); + return err; + } + + return 0; +} + +static int __maybe_unused +dsa_slave_dcbnl_add_dscp_prio(struct net_device *dev, struct dcb_app *app) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + unsigned long mask, new_prio; + int err, port = dp->index; + u8 dscp = app->protocol; + + if (!ds->ops->port_add_dscp_prio) + return -EOPNOTSUPP; + + if (dscp >= 64) { + netdev_err(dev, "DSCP APP entry with protocol value %u is invalid\n", + dscp); + return -EINVAL; + } + + err = dcb_ieee_setapp(dev, app); + if (err) + return err; + + mask = dcb_ieee_getapp_mask(dev, app); + new_prio = __fls(mask); + + err = ds->ops->port_add_dscp_prio(ds, port, dscp, new_prio); + if (err) { + dcb_ieee_delapp(dev, app); + return err; + } + + return 0; +} + +static int __maybe_unused dsa_slave_dcbnl_ieee_setapp(struct net_device *dev, + struct dcb_app *app) +{ + switch (app->selector) { + case IEEE_8021QAZ_APP_SEL_ETHERTYPE: + switch (app->protocol) { + case 0: + return dsa_slave_dcbnl_set_default_prio(dev, app); + default: + return -EOPNOTSUPP; + } + break; + case IEEE_8021QAZ_APP_SEL_DSCP: + return dsa_slave_dcbnl_add_dscp_prio(dev, app); + default: + return -EOPNOTSUPP; + } +} + +static int __maybe_unused +dsa_slave_dcbnl_del_default_prio(struct net_device *dev, struct dcb_app *app) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + unsigned long mask, new_prio; + int err, port = dp->index; + + if (!ds->ops->port_set_default_prio) + return -EOPNOTSUPP; + + err = dcb_ieee_delapp(dev, app); + if (err) + return err; + + mask = dcb_ieee_getapp_mask(dev, app); + new_prio = mask ? 
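/*
 * Illustrative-only sketch of the goto unwind ladder used by the error paths
 * above (out_port_failed / out_cpu_failed / out_master_failed): each
 * successfully applied step gets exactly one undo, executed in reverse order
 * when a later step fails. The steps here are hypothetical stand-ins.
 */
#include <stdio.h>

static int step_a(void) { return 0; }
static void undo_a(void) { puts("undo a"); }
static int step_b(void) { return 0; }
static void undo_b(void) { puts("undo b"); }
static int step_c(void) { return -1; }	/* pretend the last step fails */

static int apply_all(void)
{
	int err;

	err = step_a();
	if (err)
		goto out_a_failed;

	err = step_b();
	if (err)
		goto out_b_failed;

	err = step_c();
	if (err)
		goto out_c_failed;

	return 0;

out_c_failed:
	undo_b();
out_b_failed:
	undo_a();
out_a_failed:
	return err;
}

int main(void)
{
	return apply_all() ? 1 : 0;
}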
__fls(mask) : 0; + + err = ds->ops->port_set_default_prio(ds, port, new_prio); + if (err) { + dcb_ieee_setapp(dev, app); + return err; + } + + return 0; +} + +static int __maybe_unused +dsa_slave_dcbnl_del_dscp_prio(struct net_device *dev, struct dcb_app *app) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int err, port = dp->index; + u8 dscp = app->protocol; + + if (!ds->ops->port_del_dscp_prio) + return -EOPNOTSUPP; + + err = dcb_ieee_delapp(dev, app); + if (err) + return err; + + err = ds->ops->port_del_dscp_prio(ds, port, dscp, app->priority); + if (err) { + dcb_ieee_setapp(dev, app); + return err; + } + + return 0; +} + +static int __maybe_unused dsa_slave_dcbnl_ieee_delapp(struct net_device *dev, + struct dcb_app *app) +{ + switch (app->selector) { + case IEEE_8021QAZ_APP_SEL_ETHERTYPE: + switch (app->protocol) { + case 0: + return dsa_slave_dcbnl_del_default_prio(dev, app); + default: + return -EOPNOTSUPP; + } + break; + case IEEE_8021QAZ_APP_SEL_DSCP: + return dsa_slave_dcbnl_del_dscp_prio(dev, app); + default: + return -EOPNOTSUPP; + } +} + +/* Pre-populate the DCB application priority table with the priorities + * configured during switch setup, which we read from hardware here. + */ +static int dsa_slave_dcbnl_init(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + int port = dp->index; + int err; + + if (ds->ops->port_get_default_prio) { + int prio = ds->ops->port_get_default_prio(ds, port); + struct dcb_app app = { + .selector = IEEE_8021QAZ_APP_SEL_ETHERTYPE, + .protocol = 0, + .priority = prio, + }; + + if (prio < 0) + return prio; + + err = dcb_ieee_setapp(dev, &app); + if (err) + return err; + } + + if (ds->ops->port_get_dscp_prio) { + int protocol; + + for (protocol = 0; protocol < 64; protocol++) { + struct dcb_app app = { + .selector = IEEE_8021QAZ_APP_SEL_DSCP, + .protocol = protocol, + }; + int prio; + + prio = ds->ops->port_get_dscp_prio(ds, port, protocol); + if (prio == -EOPNOTSUPP) + continue; + if (prio < 0) + return prio; + + app.priority = prio; + + err = dcb_ieee_setapp(dev, &app); + if (err) + return err; + } + } + + return 0; +} + +static const struct ethtool_ops dsa_slave_ethtool_ops = { + .get_drvinfo = dsa_slave_get_drvinfo, + .get_regs_len = dsa_slave_get_regs_len, + .get_regs = dsa_slave_get_regs, + .nway_reset = dsa_slave_nway_reset, + .get_link = ethtool_op_get_link, + .get_eeprom_len = dsa_slave_get_eeprom_len, + .get_eeprom = dsa_slave_get_eeprom, + .set_eeprom = dsa_slave_set_eeprom, + .get_strings = dsa_slave_get_strings, + .get_ethtool_stats = dsa_slave_get_ethtool_stats, + .get_sset_count = dsa_slave_get_sset_count, + .get_eth_phy_stats = dsa_slave_get_eth_phy_stats, + .get_eth_mac_stats = dsa_slave_get_eth_mac_stats, + .get_eth_ctrl_stats = dsa_slave_get_eth_ctrl_stats, + .get_rmon_stats = dsa_slave_get_rmon_stats, + .set_wol = dsa_slave_set_wol, + .get_wol = dsa_slave_get_wol, + .set_eee = dsa_slave_set_eee, + .get_eee = dsa_slave_get_eee, + .get_link_ksettings = dsa_slave_get_link_ksettings, + .set_link_ksettings = dsa_slave_set_link_ksettings, + .get_pause_stats = dsa_slave_get_pause_stats, + .get_pauseparam = dsa_slave_get_pauseparam, + .set_pauseparam = dsa_slave_set_pauseparam, + .get_rxnfc = dsa_slave_get_rxnfc, + .set_rxnfc = dsa_slave_set_rxnfc, + .get_ts_info = dsa_slave_get_ts_info, + .self_test = dsa_slave_net_selftest, +}; + +static const struct dcbnl_rtnl_ops __maybe_unused dsa_slave_dcbnl_ops = { + .ieee_setapp = 
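/*
 * Small standalone illustration (not kernel code) of how the DCB handlers
 * above derive a port's default priority from the app-table bitmask: take the
 * highest set bit, or 0 when the mask is empty. A portable loop stands in for
 * the kernel's __fls().
 */
#include <stdio.h>

static unsigned int highest_set_bit(unsigned long mask)
{
	unsigned int bit = 0;

	while (mask >>= 1)
		bit++;
	return bit;	/* meaningless for mask == 0, so callers check first */
}

int main(void)
{
	unsigned long mask = (1UL << 3) | (1UL << 5);	/* priorities 3 and 5 configured */
	unsigned long new_prio = mask ? highest_set_bit(mask) : 0;

	printf("default priority: %lu\n", new_prio);	/* 5 */
	return 0;
}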
dsa_slave_dcbnl_ieee_setapp, + .ieee_delapp = dsa_slave_dcbnl_ieee_delapp, +}; + +static struct devlink_port *dsa_slave_get_devlink_port(struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + return &dp->devlink_port; +} + +static void dsa_slave_get_stats64(struct net_device *dev, + struct rtnl_link_stats64 *s) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + + if (ds->ops->get_stats64) + ds->ops->get_stats64(ds, dp->index, s); + else + dev_get_tstats64(dev, s); +} + +static int dsa_slave_fill_forward_path(struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + struct dsa_port *dp = dsa_slave_to_port(ctx->dev); + struct net_device *master = dsa_port_to_master(dp); + struct dsa_port *cpu_dp = dp->cpu_dp; + + path->dev = ctx->dev; + path->type = DEV_PATH_DSA; + path->dsa.proto = cpu_dp->tag_ops->proto; + path->dsa.port = dp->index; + ctx->dev = master; + + return 0; +} + +static const struct net_device_ops dsa_slave_netdev_ops = { + .ndo_open = dsa_slave_open, + .ndo_stop = dsa_slave_close, + .ndo_start_xmit = dsa_slave_xmit, + .ndo_change_rx_flags = dsa_slave_change_rx_flags, + .ndo_set_rx_mode = dsa_slave_set_rx_mode, + .ndo_set_mac_address = dsa_slave_set_mac_address, + .ndo_fdb_dump = dsa_slave_fdb_dump, + .ndo_eth_ioctl = dsa_slave_ioctl, + .ndo_get_iflink = dsa_slave_get_iflink, +#ifdef CONFIG_NET_POLL_CONTROLLER + .ndo_netpoll_setup = dsa_slave_netpoll_setup, + .ndo_netpoll_cleanup = dsa_slave_netpoll_cleanup, + .ndo_poll_controller = dsa_slave_poll_controller, +#endif + .ndo_setup_tc = dsa_slave_setup_tc, + .ndo_get_stats64 = dsa_slave_get_stats64, + .ndo_vlan_rx_add_vid = dsa_slave_vlan_rx_add_vid, + .ndo_vlan_rx_kill_vid = dsa_slave_vlan_rx_kill_vid, + .ndo_get_devlink_port = dsa_slave_get_devlink_port, + .ndo_change_mtu = dsa_slave_change_mtu, + .ndo_fill_forward_path = dsa_slave_fill_forward_path, +}; + +static struct device_type dsa_type = { + .name = "dsa", +}; + +void dsa_port_phylink_mac_change(struct dsa_switch *ds, int port, bool up) +{ + const struct dsa_port *dp = dsa_to_port(ds, port); + + if (dp->pl) + phylink_mac_change(dp->pl, up); +} +EXPORT_SYMBOL_GPL(dsa_port_phylink_mac_change); + +static void dsa_slave_phylink_fixed_state(struct phylink_config *config, + struct phylink_link_state *state) +{ + struct dsa_port *dp = container_of(config, struct dsa_port, pl_config); + struct dsa_switch *ds = dp->ds; + + /* No need to check that this operation is valid, the callback would + * not be called if it was not. 
+ */ + ds->ops->phylink_fixed_state(ds, dp->index, state); +} + +/* slave device setup *******************************************************/ +static int dsa_slave_phy_connect(struct net_device *slave_dev, int addr, + u32 flags) +{ + struct dsa_port *dp = dsa_slave_to_port(slave_dev); + struct dsa_switch *ds = dp->ds; + + slave_dev->phydev = mdiobus_get_phy(ds->slave_mii_bus, addr); + if (!slave_dev->phydev) { + netdev_err(slave_dev, "no phy at %d\n", addr); + return -ENODEV; + } + + slave_dev->phydev->dev_flags |= flags; + + return phylink_connect_phy(dp->pl, slave_dev->phydev); +} + +static int dsa_slave_phy_setup(struct net_device *slave_dev) +{ + struct dsa_port *dp = dsa_slave_to_port(slave_dev); + struct device_node *port_dn = dp->dn; + struct dsa_switch *ds = dp->ds; + u32 phy_flags = 0; + int ret; + + dp->pl_config.dev = &slave_dev->dev; + dp->pl_config.type = PHYLINK_NETDEV; + + /* The get_fixed_state callback takes precedence over polling the + * link GPIO in PHYLINK (see phylink_get_fixed_state). Only set + * this if the switch provides such a callback. + */ + if (ds->ops->phylink_fixed_state) { + dp->pl_config.get_fixed_state = dsa_slave_phylink_fixed_state; + dp->pl_config.poll_fixed_state = true; + } + + ret = dsa_port_phylink_create(dp); + if (ret) + return ret; + + if (ds->ops->get_phy_flags) + phy_flags = ds->ops->get_phy_flags(ds, dp->index); + + ret = phylink_of_phy_connect(dp->pl, port_dn, phy_flags); + if (ret == -ENODEV && ds->slave_mii_bus) { + /* We could not connect to a designated PHY or SFP, so try to + * use the switch internal MDIO bus instead + */ + ret = dsa_slave_phy_connect(slave_dev, dp->index, phy_flags); + } + if (ret) { + netdev_err(slave_dev, "failed to connect to PHY: %pe\n", + ERR_PTR(ret)); + dsa_port_phylink_destroy(dp); + } + + return ret; +} + +void dsa_slave_setup_tagger(struct net_device *slave) +{ + struct dsa_port *dp = dsa_slave_to_port(slave); + struct net_device *master = dsa_port_to_master(dp); + struct dsa_slave_priv *p = netdev_priv(slave); + const struct dsa_port *cpu_dp = dp->cpu_dp; + const struct dsa_switch *ds = dp->ds; + + slave->needed_headroom = cpu_dp->tag_ops->needed_headroom; + slave->needed_tailroom = cpu_dp->tag_ops->needed_tailroom; + /* Try to save one extra realloc later in the TX path (in the master) + * by also inheriting the master's needed headroom and tailroom. + * The 8021q driver also does this. 
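/*
 * Illustrative sketch (not kernel code) of the fallback in
 * dsa_slave_phy_setup() above: try the firmware-described PHY first, and only
 * when that specific attempt reports "no such device" fall back to the
 * switch-internal MDIO bus; any other error propagates unchanged. Both
 * connect helpers are hypothetical.
 */
#include <errno.h>
#include <stdio.h>

static int connect_described_phy(void) { return -ENODEV; }	/* nothing described */
static int connect_internal_mdio_phy(void) { return 0; }

static int phy_setup(int have_internal_bus)
{
	int err = connect_described_phy();

	if (err == -ENODEV && have_internal_bus)
		err = connect_internal_mdio_phy();

	return err;
}

int main(void)
{
	printf("%d\n", phy_setup(1));	/* 0: the internal bus was used */
	return 0;
}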
+ */ + slave->needed_headroom += master->needed_headroom; + slave->needed_tailroom += master->needed_tailroom; + + p->xmit = cpu_dp->tag_ops->xmit; + + slave->features = master->vlan_features | NETIF_F_HW_TC; + slave->hw_features |= NETIF_F_HW_TC; + slave->features |= NETIF_F_LLTX; + if (slave->needed_tailroom) + slave->features &= ~(NETIF_F_SG | NETIF_F_FRAGLIST); + if (ds->needs_standalone_vlan_filtering) + slave->features |= NETIF_F_HW_VLAN_CTAG_FILTER; +} + +int dsa_slave_suspend(struct net_device *slave_dev) +{ + struct dsa_port *dp = dsa_slave_to_port(slave_dev); + + if (!netif_running(slave_dev)) + return 0; + + netif_device_detach(slave_dev); + + rtnl_lock(); + phylink_stop(dp->pl); + rtnl_unlock(); + + return 0; +} + +int dsa_slave_resume(struct net_device *slave_dev) +{ + struct dsa_port *dp = dsa_slave_to_port(slave_dev); + + if (!netif_running(slave_dev)) + return 0; + + netif_device_attach(slave_dev); + + rtnl_lock(); + phylink_start(dp->pl); + rtnl_unlock(); + + return 0; +} + +int dsa_slave_create(struct dsa_port *port) +{ + struct net_device *master = dsa_port_to_master(port); + struct dsa_switch *ds = port->ds; + const char *name = port->name; + struct net_device *slave_dev; + struct dsa_slave_priv *p; + int ret; + + if (!ds->num_tx_queues) + ds->num_tx_queues = 1; + + slave_dev = alloc_netdev_mqs(sizeof(struct dsa_slave_priv), name, + NET_NAME_UNKNOWN, ether_setup, + ds->num_tx_queues, 1); + if (slave_dev == NULL) + return -ENOMEM; + + slave_dev->rtnl_link_ops = &dsa_link_ops; + slave_dev->ethtool_ops = &dsa_slave_ethtool_ops; +#if IS_ENABLED(CONFIG_DCB) + slave_dev->dcbnl_ops = &dsa_slave_dcbnl_ops; +#endif + if (!is_zero_ether_addr(port->mac)) + eth_hw_addr_set(slave_dev, port->mac); + else + eth_hw_addr_inherit(slave_dev, master); + slave_dev->priv_flags |= IFF_NO_QUEUE; + if (dsa_switch_supports_uc_filtering(ds)) + slave_dev->priv_flags |= IFF_UNICAST_FLT; + slave_dev->netdev_ops = &dsa_slave_netdev_ops; + if (ds->ops->port_max_mtu) + slave_dev->max_mtu = ds->ops->port_max_mtu(ds, port->index); + SET_NETDEV_DEVTYPE(slave_dev, &dsa_type); + + SET_NETDEV_DEV(slave_dev, port->ds->dev); + slave_dev->dev.of_node = port->dn; + slave_dev->vlan_features = master->vlan_features; + + p = netdev_priv(slave_dev); + slave_dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!slave_dev->tstats) { + free_netdev(slave_dev); + return -ENOMEM; + } + + ret = gro_cells_init(&p->gcells, slave_dev); + if (ret) + goto out_free; + + p->dp = port; + INIT_LIST_HEAD(&p->mall_tc_list); + port->slave = slave_dev; + dsa_slave_setup_tagger(slave_dev); + + netif_carrier_off(slave_dev); + + ret = dsa_slave_phy_setup(slave_dev); + if (ret) { + netdev_err(slave_dev, + "error %d setting up PHY for tree %d, switch %d, port %d\n", + ret, ds->dst->index, ds->index, port->index); + goto out_gcells; + } + + rtnl_lock(); + + ret = dsa_slave_change_mtu(slave_dev, ETH_DATA_LEN); + if (ret && ret != -EOPNOTSUPP) + dev_warn(ds->dev, "nonfatal error %d setting MTU to %d on port %d\n", + ret, ETH_DATA_LEN, port->index); + + ret = register_netdevice(slave_dev); + if (ret) { + netdev_err(master, "error %d registering interface %s\n", + ret, slave_dev->name); + rtnl_unlock(); + goto out_phy; + } + + if (IS_ENABLED(CONFIG_DCB)) { + ret = dsa_slave_dcbnl_init(slave_dev); + if (ret) { + netdev_err(slave_dev, + "failed to initialize DCB: %pe\n", + ERR_PTR(ret)); + rtnl_unlock(); + goto out_unregister; + } + } + + ret = netdev_upper_dev_link(master, slave_dev, NULL); + + rtnl_unlock(); + + if (ret) + goto 
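/*
 * Standalone sketch (not kernel code) of the allocation idiom used by
 * dsa_slave_create() above: the core object and the driver-private area
 * (struct dsa_slave_priv) come from one allocation, and a priv() accessor
 * returns the memory that follows the core object, as alloc_netdev_mqs() and
 * netdev_priv() do. The types below are made up for the illustration.
 */
#include <stdio.h>
#include <stdlib.h>

struct fake_netdev { char name[16]; /* private area follows */ };
struct fake_priv { int port_index; };

static struct fake_netdev *alloc_with_priv(size_t priv_size, const char *name)
{
	struct fake_netdev *dev = calloc(1, sizeof(*dev) + priv_size);

	if (dev)
		snprintf(dev->name, sizeof(dev->name), "%s", name);
	return dev;
}

static void *dev_priv(struct fake_netdev *dev)
{
	return dev + 1;		/* memory immediately after the core struct */
}

int main(void)
{
	struct fake_netdev *dev = alloc_with_priv(sizeof(struct fake_priv), "lan1");
	struct fake_priv *p;

	if (!dev)
		return 1;
	p = dev_priv(dev);
	p->port_index = 3;
	printf("%s priv port %d\n", dev->name, p->port_index);
	free(dev);
	return 0;
}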
out_unregister; + + return 0; + +out_unregister: + unregister_netdev(slave_dev); +out_phy: + rtnl_lock(); + phylink_disconnect_phy(p->dp->pl); + rtnl_unlock(); + dsa_port_phylink_destroy(p->dp); +out_gcells: + gro_cells_destroy(&p->gcells); +out_free: + free_percpu(slave_dev->tstats); + free_netdev(slave_dev); + port->slave = NULL; + return ret; +} + +void dsa_slave_destroy(struct net_device *slave_dev) +{ + struct net_device *master = dsa_slave_to_master(slave_dev); + struct dsa_port *dp = dsa_slave_to_port(slave_dev); + struct dsa_slave_priv *p = netdev_priv(slave_dev); + + netif_carrier_off(slave_dev); + rtnl_lock(); + netdev_upper_dev_unlink(master, slave_dev); + unregister_netdevice(slave_dev); + phylink_disconnect_phy(dp->pl); + rtnl_unlock(); + + dsa_port_phylink_destroy(dp); + gro_cells_destroy(&p->gcells); + free_percpu(slave_dev->tstats); + free_netdev(slave_dev); +} + +int dsa_slave_change_master(struct net_device *dev, struct net_device *master, + struct netlink_ext_ack *extack) +{ + struct net_device *old_master = dsa_slave_to_master(dev); + struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch *ds = dp->ds; + struct net_device *upper; + struct list_head *iter; + int err; + + if (master == old_master) + return 0; + + if (!ds->ops->port_change_master) { + NL_SET_ERR_MSG_MOD(extack, + "Driver does not support changing DSA master"); + return -EOPNOTSUPP; + } + + if (!netdev_uses_dsa(master)) { + NL_SET_ERR_MSG_MOD(extack, + "Interface not eligible as DSA master"); + return -EOPNOTSUPP; + } + + netdev_for_each_upper_dev_rcu(master, upper, iter) { + if (dsa_slave_dev_check(upper)) + continue; + if (netif_is_bridge_master(upper)) + continue; + NL_SET_ERR_MSG_MOD(extack, "Cannot join master with unknown uppers"); + return -EOPNOTSUPP; + } + + /* Since we allow live-changing the DSA master, plus we auto-open the + * DSA master when the user port opens => we need to ensure that the + * new DSA master is open too. + */ + if (dev->flags & IFF_UP) { + err = dev_open(master, extack); + if (err) + return err; + } + + netdev_upper_dev_unlink(old_master, dev); + + err = netdev_upper_dev_link(master, dev, extack); + if (err) + goto out_revert_old_master_unlink; + + err = dsa_port_change_master(dp, master, extack); + if (err) + goto out_revert_master_link; + + /* Update the MTU of the new CPU port through cross-chip notifiers */ + err = dsa_slave_change_mtu(dev, dev->mtu); + if (err && err != -EOPNOTSUPP) { + netdev_warn(dev, + "nonfatal error updating MTU with new master: %pe\n", + ERR_PTR(err)); + } + + /* If the port doesn't have its own MAC address and relies on the DSA + * master's one, inherit it again from the new DSA master. 
+ */ + if (is_zero_ether_addr(dp->mac)) + eth_hw_addr_inherit(dev, master); + + return 0; + +out_revert_master_link: + netdev_upper_dev_unlink(master, dev); +out_revert_old_master_unlink: + netdev_upper_dev_link(old_master, dev, NULL); + return err; +} + +bool dsa_slave_dev_check(const struct net_device *dev) +{ + return dev->netdev_ops == &dsa_slave_netdev_ops; +} +EXPORT_SYMBOL_GPL(dsa_slave_dev_check); + +static int dsa_slave_changeupper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct netlink_ext_ack *extack; + int err = NOTIFY_DONE; + + if (!dsa_slave_dev_check(dev)) + return err; + + extack = netdev_notifier_info_to_extack(&info->info); + + if (netif_is_bridge_master(info->upper_dev)) { + if (info->linking) { + err = dsa_port_bridge_join(dp, info->upper_dev, extack); + if (!err) + dsa_bridge_mtu_normalization(dp); + if (err == -EOPNOTSUPP) { + if (extack && !extack->_msg) + NL_SET_ERR_MSG_MOD(extack, + "Offloading not supported"); + err = 0; + } + err = notifier_from_errno(err); + } else { + dsa_port_bridge_leave(dp, info->upper_dev); + err = NOTIFY_OK; + } + } else if (netif_is_lag_master(info->upper_dev)) { + if (info->linking) { + err = dsa_port_lag_join(dp, info->upper_dev, + info->upper_info, extack); + if (err == -EOPNOTSUPP) { + NL_SET_ERR_MSG_MOD(info->info.extack, + "Offloading not supported"); + err = 0; + } + err = notifier_from_errno(err); + } else { + dsa_port_lag_leave(dp, info->upper_dev); + err = NOTIFY_OK; + } + } else if (is_hsr_master(info->upper_dev)) { + if (info->linking) { + err = dsa_port_hsr_join(dp, info->upper_dev); + if (err == -EOPNOTSUPP) { + NL_SET_ERR_MSG_MOD(info->info.extack, + "Offloading not supported"); + err = 0; + } + err = notifier_from_errno(err); + } else { + dsa_port_hsr_leave(dp, info->upper_dev); + err = NOTIFY_OK; + } + } + + return err; +} + +static int dsa_slave_prechangeupper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + + if (!dsa_slave_dev_check(dev)) + return NOTIFY_DONE; + + if (netif_is_bridge_master(info->upper_dev) && !info->linking) + dsa_port_pre_bridge_leave(dp, info->upper_dev); + else if (netif_is_lag_master(info->upper_dev) && !info->linking) + dsa_port_pre_lag_leave(dp, info->upper_dev); + /* dsa_port_pre_hsr_leave is not yet necessary since hsr cannot be + * meaningfully enslaved to a bridge yet + */ + + return NOTIFY_DONE; +} + +static int +dsa_slave_lag_changeupper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct net_device *lower; + struct list_head *iter; + int err = NOTIFY_DONE; + struct dsa_port *dp; + + if (!netif_is_lag_master(dev)) + return err; + + netdev_for_each_lower_dev(dev, lower, iter) { + if (!dsa_slave_dev_check(lower)) + continue; + + dp = dsa_slave_to_port(lower); + if (!dp->lag) + /* Software LAG */ + continue; + + err = dsa_slave_changeupper(lower, info); + if (notifier_to_errno(err)) + break; + } + + return err; +} + +/* Same as dsa_slave_lag_changeupper() except that it calls + * dsa_slave_prechangeupper() + */ +static int +dsa_slave_lag_prechangeupper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct net_device *lower; + struct list_head *iter; + int err = NOTIFY_DONE; + struct dsa_port *dp; + + if (!netif_is_lag_master(dev)) + return err; + + netdev_for_each_lower_dev(dev, lower, iter) { + if (!dsa_slave_dev_check(lower)) + continue; + + dp = 
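/*
 * Illustrative-only sketch of the policy in dsa_slave_changeupper() above:
 * when the driver cannot offload a bridge or LAG join it returns "operation
 * not supported", and the handler deliberately squashes that to success so
 * the upper device still works in software; any other error is reported.
 * Names are hypothetical.
 */
#include <errno.h>
#include <stdio.h>

static int driver_bridge_join(int can_offload)
{
	return can_offload ? 0 : -EOPNOTSUPP;
}

static int handle_bridge_join(int can_offload)
{
	int err = driver_bridge_join(can_offload);

	if (err == -EOPNOTSUPP) {
		fprintf(stderr, "offloading not supported, falling back to software\n");
		err = 0;	/* not fatal: traffic still flows via the CPU */
	}
	return err;
}

int main(void)
{
	return handle_bridge_join(0);	/* 0 even without offload support */
}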
dsa_slave_to_port(lower); + if (!dp->lag) + /* Software LAG */ + continue; + + err = dsa_slave_prechangeupper(lower, info); + if (notifier_to_errno(err)) + break; + } + + return err; +} + +static int +dsa_prevent_bridging_8021q_upper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct netlink_ext_ack *ext_ack; + struct net_device *slave, *br; + struct dsa_port *dp; + + ext_ack = netdev_notifier_info_to_extack(&info->info); + + if (!is_vlan_dev(dev)) + return NOTIFY_DONE; + + slave = vlan_dev_real_dev(dev); + if (!dsa_slave_dev_check(slave)) + return NOTIFY_DONE; + + dp = dsa_slave_to_port(slave); + br = dsa_port_bridge_dev_get(dp); + if (!br) + return NOTIFY_DONE; + + /* Deny enslaving a VLAN device into a VLAN-aware bridge */ + if (br_vlan_enabled(br) && + netif_is_bridge_master(info->upper_dev) && info->linking) { + NL_SET_ERR_MSG_MOD(ext_ack, + "Cannot enslave VLAN device into VLAN aware bridge"); + return notifier_from_errno(-EINVAL); + } + + return NOTIFY_DONE; +} + +static int +dsa_slave_check_8021q_upper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct bridge_vlan_info br_info; + struct netlink_ext_ack *extack; + int err = NOTIFY_DONE; + u16 vid; + + if (!br || !br_vlan_enabled(br)) + return NOTIFY_DONE; + + extack = netdev_notifier_info_to_extack(&info->info); + vid = vlan_dev_vlan_id(info->upper_dev); + + /* br_vlan_get_info() returns -EINVAL or -ENOENT if the + * device, respectively the VID is not found, returning + * 0 means success, which is a failure for us here. + */ + err = br_vlan_get_info(br, vid, &br_info); + if (err == 0) { + NL_SET_ERR_MSG_MOD(extack, + "This VLAN is already configured by the bridge"); + return notifier_from_errno(-EBUSY); + } + + return NOTIFY_DONE; +} + +static int +dsa_slave_prechangeupper_sanity_check(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct dsa_switch *ds; + struct dsa_port *dp; + int err; + + if (!dsa_slave_dev_check(dev)) + return dsa_prevent_bridging_8021q_upper(dev, info); + + dp = dsa_slave_to_port(dev); + ds = dp->ds; + + if (ds->ops->port_prechangeupper) { + err = ds->ops->port_prechangeupper(ds, dp->index, info); + if (err) + return notifier_from_errno(err); + } + + if (is_vlan_dev(info->upper_dev)) + return dsa_slave_check_8021q_upper(dev, info); + + return NOTIFY_DONE; +} + +/* To be eligible as a DSA master, a LAG must have all lower interfaces be + * eligible DSA masters. Additionally, all LAG slaves must be DSA masters of + * switches in the same switch tree. 
+ */ +static int dsa_lag_master_validate(struct net_device *lag_dev, + struct netlink_ext_ack *extack) +{ + struct net_device *lower1, *lower2; + struct list_head *iter1, *iter2; + + netdev_for_each_lower_dev(lag_dev, lower1, iter1) { + netdev_for_each_lower_dev(lag_dev, lower2, iter2) { + if (!netdev_uses_dsa(lower1) || + !netdev_uses_dsa(lower2)) { + NL_SET_ERR_MSG_MOD(extack, + "All LAG ports must be eligible as DSA masters"); + return notifier_from_errno(-EINVAL); + } + + if (lower1 == lower2) + continue; + + if (!dsa_port_tree_same(lower1->dsa_ptr, + lower2->dsa_ptr)) { + NL_SET_ERR_MSG_MOD(extack, + "LAG contains DSA masters of disjoint switch trees"); + return notifier_from_errno(-EINVAL); + } + } + } + + return NOTIFY_DONE; +} + +static int +dsa_master_prechangeupper_sanity_check(struct net_device *master, + struct netdev_notifier_changeupper_info *info) +{ + struct netlink_ext_ack *extack = netdev_notifier_info_to_extack(&info->info); + + if (!netdev_uses_dsa(master)) + return NOTIFY_DONE; + + if (!info->linking) + return NOTIFY_DONE; + + /* Allow DSA switch uppers */ + if (dsa_slave_dev_check(info->upper_dev)) + return NOTIFY_DONE; + + /* Allow bridge uppers of DSA masters, subject to further + * restrictions in dsa_bridge_prechangelower_sanity_check() + */ + if (netif_is_bridge_master(info->upper_dev)) + return NOTIFY_DONE; + + /* Allow LAG uppers, subject to further restrictions in + * dsa_lag_master_prechangelower_sanity_check() + */ + if (netif_is_lag_master(info->upper_dev)) + return dsa_lag_master_validate(info->upper_dev, extack); + + NL_SET_ERR_MSG_MOD(extack, + "DSA master cannot join unknown upper interfaces"); + return notifier_from_errno(-EBUSY); +} + +static int +dsa_lag_master_prechangelower_sanity_check(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct netlink_ext_ack *extack = netdev_notifier_info_to_extack(&info->info); + struct net_device *lag_dev = info->upper_dev; + struct net_device *lower; + struct list_head *iter; + + if (!netdev_uses_dsa(lag_dev) || !netif_is_lag_master(lag_dev)) + return NOTIFY_DONE; + + if (!info->linking) + return NOTIFY_DONE; + + if (!netdev_uses_dsa(dev)) { + NL_SET_ERR_MSG(extack, + "Only DSA masters can join a LAG DSA master"); + return notifier_from_errno(-EINVAL); + } + + netdev_for_each_lower_dev(lag_dev, lower, iter) { + if (!dsa_port_tree_same(dev->dsa_ptr, lower->dsa_ptr)) { + NL_SET_ERR_MSG(extack, + "Interface is DSA master for a different switch tree than this LAG"); + return notifier_from_errno(-EINVAL); + } + + break; + } + + return NOTIFY_DONE; +} + +/* Don't allow bridging of DSA masters, since the bridge layer rx_handler + * prevents the DSA fake ethertype handler to be invoked, so we don't get the + * chance to strip off and parse the DSA switch tag protocol header (the bridge + * layer just returns RX_HANDLER_CONSUMED, stopping RX processing for these + * frames). + * The only case where that would not be an issue is when bridging can already + * be offloaded, such as when the DSA master is itself a DSA or plain switchdev + * port, and is bridged only with other ports from the same hardware device. 
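/*
 * Standalone sketch (not kernel code) of the pairwise check performed by
 * dsa_lag_master_validate() above: every lower interface of the LAG must be
 * a DSA master, and every pair of them must belong to the same switch tree.
 * A small array stands in for the netdev lower list.
 */
#include <stdbool.h>
#include <stdio.h>

struct lower { bool uses_dsa; int tree_index; };

static bool lag_lowers_valid(const struct lower *lowers, int n)
{
	for (int i = 0; i < n; i++) {
		for (int j = 0; j < n; j++) {
			if (!lowers[i].uses_dsa || !lowers[j].uses_dsa)
				return false;	/* every member must be a DSA master */
			if (i != j &&
			    lowers[i].tree_index != lowers[j].tree_index)
				return false;	/* disjoint switch trees */
		}
	}
	return true;
}

int main(void)
{
	struct lower ok[2] = { { true, 0 }, { true, 0 } };
	struct lower bad[2] = { { true, 0 }, { true, 1 } };

	printf("%d %d\n", lag_lowers_valid(ok, 2), lag_lowers_valid(bad, 2));	/* 1 0 */
	return 0;
}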
+ */ +static int +dsa_bridge_prechangelower_sanity_check(struct net_device *new_lower, + struct netdev_notifier_changeupper_info *info) +{ + struct net_device *br = info->upper_dev; + struct netlink_ext_ack *extack; + struct net_device *lower; + struct list_head *iter; + + if (!netif_is_bridge_master(br)) + return NOTIFY_DONE; + + if (!info->linking) + return NOTIFY_DONE; + + extack = netdev_notifier_info_to_extack(&info->info); + + netdev_for_each_lower_dev(br, lower, iter) { + if (!netdev_uses_dsa(new_lower) && !netdev_uses_dsa(lower)) + continue; + + if (!netdev_port_same_parent_id(lower, new_lower)) { + NL_SET_ERR_MSG(extack, + "Cannot do software bridging with a DSA master"); + return notifier_from_errno(-EINVAL); + } + } + + return NOTIFY_DONE; +} + +static void dsa_tree_migrate_ports_from_lag_master(struct dsa_switch_tree *dst, + struct net_device *lag_dev) +{ + struct net_device *new_master = dsa_tree_find_first_master(dst); + struct dsa_port *dp; + int err; + + dsa_tree_for_each_user_port(dp, dst) { + if (dsa_port_to_master(dp) != lag_dev) + continue; + + err = dsa_slave_change_master(dp->slave, new_master, NULL); + if (err) { + netdev_err(dp->slave, + "failed to restore master to %s: %pe\n", + new_master->name, ERR_PTR(err)); + } + } +} + +static int dsa_master_lag_join(struct net_device *master, + struct net_device *lag_dev, + struct netdev_lag_upper_info *uinfo, + struct netlink_ext_ack *extack) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct dsa_port *dp; + int err; + + err = dsa_master_lag_setup(lag_dev, cpu_dp, uinfo, extack); + if (err) + return err; + + dsa_tree_for_each_user_port(dp, dst) { + if (dsa_port_to_master(dp) != master) + continue; + + err = dsa_slave_change_master(dp->slave, lag_dev, extack); + if (err) + goto restore; + } + + return 0; + +restore: + dsa_tree_for_each_user_port_continue_reverse(dp, dst) { + if (dsa_port_to_master(dp) != lag_dev) + continue; + + err = dsa_slave_change_master(dp->slave, master, NULL); + if (err) { + netdev_err(dp->slave, + "failed to restore master to %s: %pe\n", + master->name, ERR_PTR(err)); + } + } + + dsa_master_lag_teardown(lag_dev, master->dsa_ptr); + + return err; +} + +static void dsa_master_lag_leave(struct net_device *master, + struct net_device *lag_dev) +{ + struct dsa_port *dp, *cpu_dp = lag_dev->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct dsa_port *new_cpu_dp = NULL; + struct net_device *lower; + struct list_head *iter; + + netdev_for_each_lower_dev(lag_dev, lower, iter) { + if (netdev_uses_dsa(lower)) { + new_cpu_dp = lower->dsa_ptr; + break; + } + } + + if (new_cpu_dp) { + /* Update the CPU port of the user ports still under the LAG + * so that dsa_port_to_master() continues to work properly + */ + dsa_tree_for_each_user_port(dp, dst) + if (dsa_port_to_master(dp) == lag_dev) + dp->cpu_dp = new_cpu_dp; + + /* Update the index of the virtual CPU port to match the lowest + * physical CPU port + */ + lag_dev->dsa_ptr = new_cpu_dp; + wmb(); + } else { + /* If the LAG DSA master has no ports left, migrate back all + * user ports to the first physical CPU port + */ + dsa_tree_migrate_ports_from_lag_master(dst, lag_dev); + } + + /* This DSA master has left its LAG in any case, so let + * the CPU port leave the hardware LAG as well + */ + dsa_master_lag_teardown(lag_dev, master->dsa_ptr); +} + +static int dsa_master_changeupper(struct net_device *dev, + struct netdev_notifier_changeupper_info *info) +{ + struct netlink_ext_ack *extack; + int err = 
NOTIFY_DONE; + + if (!netdev_uses_dsa(dev)) + return err; + + extack = netdev_notifier_info_to_extack(&info->info); + + if (netif_is_lag_master(info->upper_dev)) { + if (info->linking) { + err = dsa_master_lag_join(dev, info->upper_dev, + info->upper_info, extack); + err = notifier_from_errno(err); + } else { + dsa_master_lag_leave(dev, info->upper_dev); + err = NOTIFY_OK; + } + } + + return err; +} + +static int dsa_slave_netdevice_event(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + switch (event) { + case NETDEV_PRECHANGEUPPER: { + struct netdev_notifier_changeupper_info *info = ptr; + int err; + + err = dsa_slave_prechangeupper_sanity_check(dev, info); + if (notifier_to_errno(err)) + return err; + + err = dsa_master_prechangeupper_sanity_check(dev, info); + if (notifier_to_errno(err)) + return err; + + err = dsa_lag_master_prechangelower_sanity_check(dev, info); + if (notifier_to_errno(err)) + return err; + + err = dsa_bridge_prechangelower_sanity_check(dev, info); + if (notifier_to_errno(err)) + return err; + + err = dsa_slave_prechangeupper(dev, ptr); + if (notifier_to_errno(err)) + return err; + + err = dsa_slave_lag_prechangeupper(dev, ptr); + if (notifier_to_errno(err)) + return err; + + break; + } + case NETDEV_CHANGEUPPER: { + int err; + + err = dsa_slave_changeupper(dev, ptr); + if (notifier_to_errno(err)) + return err; + + err = dsa_slave_lag_changeupper(dev, ptr); + if (notifier_to_errno(err)) + return err; + + err = dsa_master_changeupper(dev, ptr); + if (notifier_to_errno(err)) + return err; + + break; + } + case NETDEV_CHANGELOWERSTATE: { + struct netdev_notifier_changelowerstate_info *info = ptr; + struct dsa_port *dp; + int err = 0; + + if (dsa_slave_dev_check(dev)) { + dp = dsa_slave_to_port(dev); + + err = dsa_port_lag_change(dp, info->lower_state_info); + } + + /* Mirror LAG port events on DSA masters that are in + * a LAG towards their respective switch CPU ports + */ + if (netdev_uses_dsa(dev)) { + dp = dev->dsa_ptr; + + err = dsa_port_lag_change(dp, info->lower_state_info); + } + + return notifier_from_errno(err); + } + case NETDEV_CHANGE: + case NETDEV_UP: { + /* Track state of master port. + * DSA driver may require the master port (and indirectly + * the tagger) to be available for some special operation. + */ + if (netdev_uses_dsa(dev)) { + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->ds->dst; + + /* Track when the master port is UP */ + dsa_tree_master_oper_state_change(dst, dev, + netif_oper_up(dev)); + + /* Track when the master port is ready and can accept + * packet. + * NETDEV_UP event is not enough to flag a port as ready. + * We also have to wait for linkwatch_do_dev to dev_activate + * and emit a NETDEV_CHANGE event. + * We check if a master port is ready by checking if the dev + * have a qdisc assigned and is not noop. 
+ */ + dsa_tree_master_admin_state_change(dst, dev, + !qdisc_tx_is_noop(dev)); + + return NOTIFY_OK; + } + + return NOTIFY_DONE; + } + case NETDEV_GOING_DOWN: { + struct dsa_port *dp, *cpu_dp; + struct dsa_switch_tree *dst; + LIST_HEAD(close_list); + + if (!netdev_uses_dsa(dev)) + return NOTIFY_DONE; + + cpu_dp = dev->dsa_ptr; + dst = cpu_dp->ds->dst; + + dsa_tree_master_admin_state_change(dst, dev, false); + + list_for_each_entry(dp, &dst->ports, list) { + if (!dsa_port_is_user(dp)) + continue; + + if (dp->cpu_dp != cpu_dp) + continue; + + list_add(&dp->slave->close_list, &close_list); + } + + dev_close_many(&close_list, true); + + return NOTIFY_OK; + } + default: + break; + } + + return NOTIFY_DONE; +} + +static void +dsa_fdb_offload_notify(struct dsa_switchdev_event_work *switchdev_work) +{ + struct switchdev_notifier_fdb_info info = {}; + + info.addr = switchdev_work->addr; + info.vid = switchdev_work->vid; + info.offloaded = true; + call_switchdev_notifiers(SWITCHDEV_FDB_OFFLOADED, + switchdev_work->orig_dev, &info.info, NULL); +} + +static void dsa_slave_switchdev_event_work(struct work_struct *work) +{ + struct dsa_switchdev_event_work *switchdev_work = + container_of(work, struct dsa_switchdev_event_work, work); + const unsigned char *addr = switchdev_work->addr; + struct net_device *dev = switchdev_work->dev; + u16 vid = switchdev_work->vid; + struct dsa_switch *ds; + struct dsa_port *dp; + int err; + + dp = dsa_slave_to_port(dev); + ds = dp->ds; + + switch (switchdev_work->event) { + case SWITCHDEV_FDB_ADD_TO_DEVICE: + if (switchdev_work->host_addr) + err = dsa_port_bridge_host_fdb_add(dp, addr, vid); + else if (dp->lag) + err = dsa_port_lag_fdb_add(dp, addr, vid); + else + err = dsa_port_fdb_add(dp, addr, vid); + if (err) { + dev_err(ds->dev, + "port %d failed to add %pM vid %d to fdb: %d\n", + dp->index, addr, vid, err); + break; + } + dsa_fdb_offload_notify(switchdev_work); + break; + + case SWITCHDEV_FDB_DEL_TO_DEVICE: + if (switchdev_work->host_addr) + err = dsa_port_bridge_host_fdb_del(dp, addr, vid); + else if (dp->lag) + err = dsa_port_lag_fdb_del(dp, addr, vid); + else + err = dsa_port_fdb_del(dp, addr, vid); + if (err) { + dev_err(ds->dev, + "port %d failed to delete %pM vid %d from fdb: %d\n", + dp->index, addr, vid, err); + } + + break; + } + + kfree(switchdev_work); +} + +static bool dsa_foreign_dev_check(const struct net_device *dev, + const struct net_device *foreign_dev) +{ + const struct dsa_port *dp = dsa_slave_to_port(dev); + struct dsa_switch_tree *dst = dp->ds->dst; + + if (netif_is_bridge_master(foreign_dev)) + return !dsa_tree_offloads_bridge_dev(dst, foreign_dev); + + if (netif_is_bridge_port(foreign_dev)) + return !dsa_tree_offloads_bridge_port(dst, foreign_dev); + + /* Everything else is foreign */ + return true; +} + +static int dsa_slave_fdb_event(struct net_device *dev, + struct net_device *orig_dev, + unsigned long event, const void *ctx, + const struct switchdev_notifier_fdb_info *fdb_info) +{ + struct dsa_switchdev_event_work *switchdev_work; + struct dsa_port *dp = dsa_slave_to_port(dev); + bool host_addr = fdb_info->is_local; + struct dsa_switch *ds = dp->ds; + + if (ctx && ctx != dp) + return 0; + + if (!dp->bridge) + return 0; + + if (switchdev_fdb_is_dynamically_learned(fdb_info)) { + if (dsa_port_offloads_bridge_port(dp, orig_dev)) + return 0; + + /* FDB entries learned by the software bridge or by foreign + * bridge ports should be installed as host addresses only if + * the driver requests assisted learning. 
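/*
 * Illustrative sketch (not kernel code) of why dsa_slave_fdb_event() above
 * copies the notifier data into a heap-allocated work item: the notifier runs
 * in a context that cannot sleep, while programming the switch can, so the
 * address and VID are captured immediately and the hardware update happens
 * later from the scheduled work. A direct call stands in for the workqueue.
 */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

struct fdb_work {
	unsigned char addr[6];
	unsigned short vid;
	int add;		/* 1 = add, 0 = delete */
};

/* "Deferred" side: in the kernel this would run from a workqueue. */
static void fdb_work_run(struct fdb_work *w)
{
	printf("%s %02x:...:%02x vid %u in hardware\n",
	       w->add ? "add" : "del", w->addr[0], w->addr[5], w->vid);
	free(w);
}

/* "Atomic" side: only copies data and schedules, never touches hardware. */
static int fdb_event(const unsigned char *addr, unsigned short vid, int add)
{
	struct fdb_work *w = calloc(1, sizeof(*w));

	if (!w)
		return -1;
	memcpy(w->addr, addr, sizeof(w->addr));
	w->vid = vid;
	w->add = add;
	fdb_work_run(w);	/* stands in for dsa_schedule_work() */
	return 0;
}

int main(void)
{
	unsigned char mac[6] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x01 };

	return fdb_event(mac, 10, 1);
}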
+ */ + if (!ds->assisted_learning_on_cpu_port) + return 0; + } + + /* Also treat FDB entries on foreign interfaces bridged with us as host + * addresses. + */ + if (dsa_foreign_dev_check(dev, orig_dev)) + host_addr = true; + + /* Check early that we're not doing work in vain. + * Host addresses on LAG ports still require regular FDB ops, + * since the CPU port isn't in a LAG. + */ + if (dp->lag && !host_addr) { + if (!ds->ops->lag_fdb_add || !ds->ops->lag_fdb_del) + return -EOPNOTSUPP; + } else { + if (!ds->ops->port_fdb_add || !ds->ops->port_fdb_del) + return -EOPNOTSUPP; + } + + switchdev_work = kzalloc(sizeof(*switchdev_work), GFP_ATOMIC); + if (!switchdev_work) + return -ENOMEM; + + netdev_dbg(dev, "%s FDB entry towards %s, addr %pM vid %d%s\n", + event == SWITCHDEV_FDB_ADD_TO_DEVICE ? "Adding" : "Deleting", + orig_dev->name, fdb_info->addr, fdb_info->vid, + host_addr ? " as host address" : ""); + + INIT_WORK(&switchdev_work->work, dsa_slave_switchdev_event_work); + switchdev_work->event = event; + switchdev_work->dev = dev; + switchdev_work->orig_dev = orig_dev; + + ether_addr_copy(switchdev_work->addr, fdb_info->addr); + switchdev_work->vid = fdb_info->vid; + switchdev_work->host_addr = host_addr; + + dsa_schedule_work(&switchdev_work->work); + + return 0; +} + +/* Called under rcu_read_lock() */ +static int dsa_slave_switchdev_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = switchdev_notifier_info_to_dev(ptr); + int err; + + switch (event) { + case SWITCHDEV_PORT_ATTR_SET: + err = switchdev_handle_port_attr_set(dev, ptr, + dsa_slave_dev_check, + dsa_slave_port_attr_set); + return notifier_from_errno(err); + case SWITCHDEV_FDB_ADD_TO_DEVICE: + case SWITCHDEV_FDB_DEL_TO_DEVICE: + err = switchdev_handle_fdb_event_to_device(dev, event, ptr, + dsa_slave_dev_check, + dsa_foreign_dev_check, + dsa_slave_fdb_event); + return notifier_from_errno(err); + default: + return NOTIFY_DONE; + } + + return NOTIFY_OK; +} + +static int dsa_slave_switchdev_blocking_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = switchdev_notifier_info_to_dev(ptr); + int err; + + switch (event) { + case SWITCHDEV_PORT_OBJ_ADD: + err = switchdev_handle_port_obj_add_foreign(dev, ptr, + dsa_slave_dev_check, + dsa_foreign_dev_check, + dsa_slave_port_obj_add); + return notifier_from_errno(err); + case SWITCHDEV_PORT_OBJ_DEL: + err = switchdev_handle_port_obj_del_foreign(dev, ptr, + dsa_slave_dev_check, + dsa_foreign_dev_check, + dsa_slave_port_obj_del); + return notifier_from_errno(err); + case SWITCHDEV_PORT_ATTR_SET: + err = switchdev_handle_port_attr_set(dev, ptr, + dsa_slave_dev_check, + dsa_slave_port_attr_set); + return notifier_from_errno(err); + } + + return NOTIFY_DONE; +} + +static struct notifier_block dsa_slave_nb __read_mostly = { + .notifier_call = dsa_slave_netdevice_event, +}; + +struct notifier_block dsa_slave_switchdev_notifier = { + .notifier_call = dsa_slave_switchdev_event, +}; + +struct notifier_block dsa_slave_switchdev_blocking_notifier = { + .notifier_call = dsa_slave_switchdev_blocking_event, +}; + +int dsa_slave_register_notifier(void) +{ + struct notifier_block *nb; + int err; + + err = register_netdevice_notifier(&dsa_slave_nb); + if (err) + return err; + + err = register_switchdev_notifier(&dsa_slave_switchdev_notifier); + if (err) + goto err_switchdev_nb; + + nb = &dsa_slave_switchdev_blocking_notifier; + err = register_switchdev_blocking_notifier(nb); + if (err) + goto 
err_switchdev_blocking_nb; + + return 0; + +err_switchdev_blocking_nb: + unregister_switchdev_notifier(&dsa_slave_switchdev_notifier); +err_switchdev_nb: + unregister_netdevice_notifier(&dsa_slave_nb); + return err; +} + +void dsa_slave_unregister_notifier(void) +{ + struct notifier_block *nb; + int err; + + nb = &dsa_slave_switchdev_blocking_notifier; + err = unregister_switchdev_blocking_notifier(nb); + if (err) + pr_err("DSA: failed to unregister switchdev blocking notifier (%d)\n", err); + + err = unregister_switchdev_notifier(&dsa_slave_switchdev_notifier); + if (err) + pr_err("DSA: failed to unregister switchdev notifier (%d)\n", err); + + err = unregister_netdevice_notifier(&dsa_slave_nb); + if (err) + pr_err("DSA: failed to unregister slave notifier (%d)\n", err); +} diff --git a/net/dsa/switch.c b/net/dsa/switch.c new file mode 100644 index 000000000..ce56acdba --- /dev/null +++ b/net/dsa/switch.c @@ -0,0 +1,1030 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Handling of a single switch chip, part of a switch fabric + * + * Copyright (c) 2017 Savoir-faire Linux Inc. + * Vivien Didelot + */ + +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +static unsigned int dsa_switch_fastest_ageing_time(struct dsa_switch *ds, + unsigned int ageing_time) +{ + struct dsa_port *dp; + + dsa_switch_for_each_port(dp, ds) + if (dp->ageing_time && dp->ageing_time < ageing_time) + ageing_time = dp->ageing_time; + + return ageing_time; +} + +static int dsa_switch_ageing_time(struct dsa_switch *ds, + struct dsa_notifier_ageing_time_info *info) +{ + unsigned int ageing_time = info->ageing_time; + + if (ds->ageing_time_min && ageing_time < ds->ageing_time_min) + return -ERANGE; + + if (ds->ageing_time_max && ageing_time > ds->ageing_time_max) + return -ERANGE; + + /* Program the fastest ageing time in case of multiple bridges */ + ageing_time = dsa_switch_fastest_ageing_time(ds, ageing_time); + + if (ds->ops->set_ageing_time) + return ds->ops->set_ageing_time(ds, ageing_time); + + return 0; +} + +static bool dsa_port_mtu_match(struct dsa_port *dp, + struct dsa_notifier_mtu_info *info) +{ + return dp == info->dp || dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp); +} + +static int dsa_switch_mtu(struct dsa_switch *ds, + struct dsa_notifier_mtu_info *info) +{ + struct dsa_port *dp; + int ret; + + if (!ds->ops->port_change_mtu) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_mtu_match(dp, info)) { + ret = ds->ops->port_change_mtu(ds, dp->index, + info->mtu); + if (ret) + return ret; + } + } + + return 0; +} + +static int dsa_switch_bridge_join(struct dsa_switch *ds, + struct dsa_notifier_bridge_info *info) +{ + int err; + + if (info->dp->ds == ds) { + if (!ds->ops->port_bridge_join) + return -EOPNOTSUPP; + + err = ds->ops->port_bridge_join(ds, info->dp->index, + info->bridge, + &info->tx_fwd_offload, + info->extack); + if (err) + return err; + } + + if (info->dp->ds != ds && ds->ops->crosschip_bridge_join) { + err = ds->ops->crosschip_bridge_join(ds, + info->dp->ds->dst->index, + info->dp->ds->index, + info->dp->index, + info->bridge, + info->extack); + if (err) + return err; + } + + return 0; +} + +static int dsa_switch_bridge_leave(struct dsa_switch *ds, + struct dsa_notifier_bridge_info *info) +{ + if (info->dp->ds == ds && ds->ops->port_bridge_leave) + ds->ops->port_bridge_leave(ds, info->dp->index, info->bridge); + + if (info->dp->ds != ds && ds->ops->crosschip_bridge_leave) + ds->ops->crosschip_bridge_leave(ds, info->dp->ds->dst->index, + 
info->dp->ds->index, + info->dp->index, + info->bridge); + + return 0; +} + +/* Matches for all upstream-facing ports (the CPU port and all upstream-facing + * DSA links) that sit between the targeted port on which the notifier was + * emitted and its dedicated CPU port. + */ +static bool dsa_port_host_address_match(struct dsa_port *dp, + const struct dsa_port *targeted_dp) +{ + struct dsa_port *cpu_dp = targeted_dp->cpu_dp; + + if (dsa_switch_is_upstream_of(dp->ds, targeted_dp->ds)) + return dp->index == dsa_towards_port(dp->ds, cpu_dp->ds->index, + cpu_dp->index); + + return false; +} + +static struct dsa_mac_addr *dsa_mac_addr_find(struct list_head *addr_list, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_mac_addr *a; + + list_for_each_entry(a, addr_list, list) + if (ether_addr_equal(a->addr, addr) && a->vid == vid && + dsa_db_equal(&a->db, &db)) + return a; + + return NULL; +} + +static int dsa_port_do_mdb_add(struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb, + struct dsa_db db) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_mac_addr *a; + int port = dp->index; + int err = 0; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_mdb_add(ds, port, mdb, db); + + mutex_lock(&dp->addr_lists_lock); + + a = dsa_mac_addr_find(&dp->mdbs, mdb->addr, mdb->vid, db); + if (a) { + refcount_inc(&a->refcount); + goto out; + } + + a = kzalloc(sizeof(*a), GFP_KERNEL); + if (!a) { + err = -ENOMEM; + goto out; + } + + err = ds->ops->port_mdb_add(ds, port, mdb, db); + if (err) { + kfree(a); + goto out; + } + + ether_addr_copy(a->addr, mdb->addr); + a->vid = mdb->vid; + a->db = db; + refcount_set(&a->refcount, 1); + list_add_tail(&a->list, &dp->mdbs); + +out: + mutex_unlock(&dp->addr_lists_lock); + + return err; +} + +static int dsa_port_do_mdb_del(struct dsa_port *dp, + const struct switchdev_obj_port_mdb *mdb, + struct dsa_db db) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_mac_addr *a; + int port = dp->index; + int err = 0; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_mdb_del(ds, port, mdb, db); + + mutex_lock(&dp->addr_lists_lock); + + a = dsa_mac_addr_find(&dp->mdbs, mdb->addr, mdb->vid, db); + if (!a) { + err = -ENOENT; + goto out; + } + + if (!refcount_dec_and_test(&a->refcount)) + goto out; + + err = ds->ops->port_mdb_del(ds, port, mdb, db); + if (err) { + refcount_set(&a->refcount, 1); + goto out; + } + + list_del(&a->list); + kfree(a); + +out: + mutex_unlock(&dp->addr_lists_lock); + + return err; +} + +static int dsa_port_do_fdb_add(struct dsa_port *dp, const unsigned char *addr, + u16 vid, struct dsa_db db) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_mac_addr *a; + int port = dp->index; + int err = 0; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_fdb_add(ds, port, addr, vid, db); + + mutex_lock(&dp->addr_lists_lock); + + a = dsa_mac_addr_find(&dp->fdbs, addr, vid, db); + if (a) { + refcount_inc(&a->refcount); + goto out; + } + + a = kzalloc(sizeof(*a), GFP_KERNEL); + if (!a) { + err = -ENOMEM; + goto out; + } + + err = ds->ops->port_fdb_add(ds, port, addr, vid, db); + if (err) { + kfree(a); + goto out; + } + + ether_addr_copy(a->addr, addr); + a->vid = vid; + a->db = db; + refcount_set(&a->refcount, 1); + list_add_tail(&a->list, &dp->fdbs); + +out: + 
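/*
 * Standalone sketch (not kernel code) of the refcounting kept for shared
 * (CPU/DSA) ports by dsa_port_do_mdb_add()/_del() and the FDB equivalents
 * above: several user ports may request the same address on the shared port,
 * the hardware is programmed only on the first add, and the entry is removed
 * only when the last user deletes it. A fixed array and printf stand in for
 * the kernel list and the driver ops.
 */
#include <stdio.h>
#include <string.h>

struct entry { unsigned char addr[6]; unsigned short vid; int refcount; };

static struct entry table[16];

static struct entry *find(const unsigned char *addr, unsigned short vid)
{
	for (int i = 0; i < 16; i++)
		if (table[i].refcount &&
		    !memcmp(table[i].addr, addr, 6) && table[i].vid == vid)
			return &table[i];
	return NULL;
}

static int shared_port_addr_add(const unsigned char *addr, unsigned short vid)
{
	struct entry *e = find(addr, vid);

	if (e) {
		e->refcount++;	/* already in hardware, just track the user */
		return 0;
	}
	for (int i = 0; i < 16; i++) {
		if (!table[i].refcount) {
			memcpy(table[i].addr, addr, 6);
			table[i].vid = vid;
			table[i].refcount = 1;
			printf("program vid %u in hardware\n", vid);	/* first user */
			return 0;
		}
	}
	return -1;	/* table full */
}

static void shared_port_addr_del(const unsigned char *addr, unsigned short vid)
{
	struct entry *e = find(addr, vid);

	if (e && --e->refcount == 0)
		printf("remove vid %u from hardware\n", vid);	/* last user gone */
}

int main(void)
{
	unsigned char mac[6] = { 0x02, 0, 0, 0, 0, 0x01 };

	shared_port_addr_add(mac, 10);	/* programs hardware */
	shared_port_addr_add(mac, 10);	/* only bumps the refcount */
	shared_port_addr_del(mac, 10);	/* still referenced */
	shared_port_addr_del(mac, 10);	/* removes from hardware */
	return 0;
}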
mutex_unlock(&dp->addr_lists_lock); + + return err; +} + +static int dsa_port_do_fdb_del(struct dsa_port *dp, const unsigned char *addr, + u16 vid, struct dsa_db db) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_mac_addr *a; + int port = dp->index; + int err = 0; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_fdb_del(ds, port, addr, vid, db); + + mutex_lock(&dp->addr_lists_lock); + + a = dsa_mac_addr_find(&dp->fdbs, addr, vid, db); + if (!a) { + err = -ENOENT; + goto out; + } + + if (!refcount_dec_and_test(&a->refcount)) + goto out; + + err = ds->ops->port_fdb_del(ds, port, addr, vid, db); + if (err) { + refcount_set(&a->refcount, 1); + goto out; + } + + list_del(&a->list); + kfree(a); + +out: + mutex_unlock(&dp->addr_lists_lock); + + return err; +} + +static int dsa_switch_do_lag_fdb_add(struct dsa_switch *ds, struct dsa_lag *lag, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_mac_addr *a; + int err = 0; + + mutex_lock(&lag->fdb_lock); + + a = dsa_mac_addr_find(&lag->fdbs, addr, vid, db); + if (a) { + refcount_inc(&a->refcount); + goto out; + } + + a = kzalloc(sizeof(*a), GFP_KERNEL); + if (!a) { + err = -ENOMEM; + goto out; + } + + err = ds->ops->lag_fdb_add(ds, *lag, addr, vid, db); + if (err) { + kfree(a); + goto out; + } + + ether_addr_copy(a->addr, addr); + a->vid = vid; + a->db = db; + refcount_set(&a->refcount, 1); + list_add_tail(&a->list, &lag->fdbs); + +out: + mutex_unlock(&lag->fdb_lock); + + return err; +} + +static int dsa_switch_do_lag_fdb_del(struct dsa_switch *ds, struct dsa_lag *lag, + const unsigned char *addr, u16 vid, + struct dsa_db db) +{ + struct dsa_mac_addr *a; + int err = 0; + + mutex_lock(&lag->fdb_lock); + + a = dsa_mac_addr_find(&lag->fdbs, addr, vid, db); + if (!a) { + err = -ENOENT; + goto out; + } + + if (!refcount_dec_and_test(&a->refcount)) + goto out; + + err = ds->ops->lag_fdb_del(ds, *lag, addr, vid, db); + if (err) { + refcount_set(&a->refcount, 1); + goto out; + } + + list_del(&a->list); + kfree(a); + +out: + mutex_unlock(&lag->fdb_lock); + + return err; +} + +static int dsa_switch_host_fdb_add(struct dsa_switch *ds, + struct dsa_notifier_fdb_info *info) +{ + struct dsa_port *dp; + int err = 0; + + if (!ds->ops->port_fdb_add) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_address_match(dp, info->dp)) { + if (dsa_port_is_cpu(dp) && info->dp->cpu_port_in_lag) { + err = dsa_switch_do_lag_fdb_add(ds, dp->lag, + info->addr, + info->vid, + info->db); + } else { + err = dsa_port_do_fdb_add(dp, info->addr, + info->vid, info->db); + } + if (err) + break; + } + } + + return err; +} + +static int dsa_switch_host_fdb_del(struct dsa_switch *ds, + struct dsa_notifier_fdb_info *info) +{ + struct dsa_port *dp; + int err = 0; + + if (!ds->ops->port_fdb_del) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_address_match(dp, info->dp)) { + if (dsa_port_is_cpu(dp) && info->dp->cpu_port_in_lag) { + err = dsa_switch_do_lag_fdb_del(ds, dp->lag, + info->addr, + info->vid, + info->db); + } else { + err = dsa_port_do_fdb_del(dp, info->addr, + info->vid, info->db); + } + if (err) + break; + } + } + + return err; +} + +static int dsa_switch_fdb_add(struct dsa_switch *ds, + struct dsa_notifier_fdb_info *info) +{ + int port = dsa_towards_port(ds, info->dp->ds->index, info->dp->index); + struct dsa_port *dp = dsa_to_port(ds, port); + + if (!ds->ops->port_fdb_add) + return -EOPNOTSUPP; + + return 
dsa_port_do_fdb_add(dp, info->addr, info->vid, info->db); +} + +static int dsa_switch_fdb_del(struct dsa_switch *ds, + struct dsa_notifier_fdb_info *info) +{ + int port = dsa_towards_port(ds, info->dp->ds->index, info->dp->index); + struct dsa_port *dp = dsa_to_port(ds, port); + + if (!ds->ops->port_fdb_del) + return -EOPNOTSUPP; + + return dsa_port_do_fdb_del(dp, info->addr, info->vid, info->db); +} + +static int dsa_switch_lag_fdb_add(struct dsa_switch *ds, + struct dsa_notifier_lag_fdb_info *info) +{ + struct dsa_port *dp; + + if (!ds->ops->lag_fdb_add) + return -EOPNOTSUPP; + + /* Notify switch only if it has a port in this LAG */ + dsa_switch_for_each_port(dp, ds) + if (dsa_port_offloads_lag(dp, info->lag)) + return dsa_switch_do_lag_fdb_add(ds, info->lag, + info->addr, info->vid, + info->db); + + return 0; +} + +static int dsa_switch_lag_fdb_del(struct dsa_switch *ds, + struct dsa_notifier_lag_fdb_info *info) +{ + struct dsa_port *dp; + + if (!ds->ops->lag_fdb_del) + return -EOPNOTSUPP; + + /* Notify switch only if it has a port in this LAG */ + dsa_switch_for_each_port(dp, ds) + if (dsa_port_offloads_lag(dp, info->lag)) + return dsa_switch_do_lag_fdb_del(ds, info->lag, + info->addr, info->vid, + info->db); + + return 0; +} + +static int dsa_switch_lag_change(struct dsa_switch *ds, + struct dsa_notifier_lag_info *info) +{ + if (info->dp->ds == ds && ds->ops->port_lag_change) + return ds->ops->port_lag_change(ds, info->dp->index); + + if (info->dp->ds != ds && ds->ops->crosschip_lag_change) + return ds->ops->crosschip_lag_change(ds, info->dp->ds->index, + info->dp->index); + + return 0; +} + +static int dsa_switch_lag_join(struct dsa_switch *ds, + struct dsa_notifier_lag_info *info) +{ + if (info->dp->ds == ds && ds->ops->port_lag_join) + return ds->ops->port_lag_join(ds, info->dp->index, info->lag, + info->info, info->extack); + + if (info->dp->ds != ds && ds->ops->crosschip_lag_join) + return ds->ops->crosschip_lag_join(ds, info->dp->ds->index, + info->dp->index, info->lag, + info->info, info->extack); + + return -EOPNOTSUPP; +} + +static int dsa_switch_lag_leave(struct dsa_switch *ds, + struct dsa_notifier_lag_info *info) +{ + if (info->dp->ds == ds && ds->ops->port_lag_leave) + return ds->ops->port_lag_leave(ds, info->dp->index, info->lag); + + if (info->dp->ds != ds && ds->ops->crosschip_lag_leave) + return ds->ops->crosschip_lag_leave(ds, info->dp->ds->index, + info->dp->index, info->lag); + + return -EOPNOTSUPP; +} + +static int dsa_switch_mdb_add(struct dsa_switch *ds, + struct dsa_notifier_mdb_info *info) +{ + int port = dsa_towards_port(ds, info->dp->ds->index, info->dp->index); + struct dsa_port *dp = dsa_to_port(ds, port); + + if (!ds->ops->port_mdb_add) + return -EOPNOTSUPP; + + return dsa_port_do_mdb_add(dp, info->mdb, info->db); +} + +static int dsa_switch_mdb_del(struct dsa_switch *ds, + struct dsa_notifier_mdb_info *info) +{ + int port = dsa_towards_port(ds, info->dp->ds->index, info->dp->index); + struct dsa_port *dp = dsa_to_port(ds, port); + + if (!ds->ops->port_mdb_del) + return -EOPNOTSUPP; + + return dsa_port_do_mdb_del(dp, info->mdb, info->db); +} + +static int dsa_switch_host_mdb_add(struct dsa_switch *ds, + struct dsa_notifier_mdb_info *info) +{ + struct dsa_port *dp; + int err = 0; + + if (!ds->ops->port_mdb_add) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_address_match(dp, info->dp)) { + err = dsa_port_do_mdb_add(dp, info->mdb, info->db); + if (err) + break; + } + } + + return err; +} + +static int 
dsa_switch_host_mdb_del(struct dsa_switch *ds, + struct dsa_notifier_mdb_info *info) +{ + struct dsa_port *dp; + int err = 0; + + if (!ds->ops->port_mdb_del) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_address_match(dp, info->dp)) { + err = dsa_port_do_mdb_del(dp, info->mdb, info->db); + if (err) + break; + } + } + + return err; +} + +/* Port VLANs match on the targeted port and on all DSA ports */ +static bool dsa_port_vlan_match(struct dsa_port *dp, + struct dsa_notifier_vlan_info *info) +{ + return dsa_port_is_dsa(dp) || dp == info->dp; +} + +/* Host VLANs match on the targeted port's CPU port, and on all DSA ports + * (upstream and downstream) of that switch and its upstream switches. + */ +static bool dsa_port_host_vlan_match(struct dsa_port *dp, + const struct dsa_port *targeted_dp) +{ + struct dsa_port *cpu_dp = targeted_dp->cpu_dp; + + if (dsa_switch_is_upstream_of(dp->ds, targeted_dp->ds)) + return dsa_port_is_dsa(dp) || dp == cpu_dp; + + return false; +} + +static struct dsa_vlan *dsa_vlan_find(struct list_head *vlan_list, + const struct switchdev_obj_port_vlan *vlan) +{ + struct dsa_vlan *v; + + list_for_each_entry(v, vlan_list, list) + if (v->vid == vlan->vid) + return v; + + return NULL; +} + +static int dsa_port_do_vlan_add(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan, + struct netlink_ext_ack *extack) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + struct dsa_vlan *v; + int err = 0; + + /* No need to bother with refcounting for user ports. */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_vlan_add(ds, port, vlan, extack); + + /* No need to propagate on shared ports the existing VLANs that were + * re-notified after just the flags have changed. This would cause a + * refcount bump which we need to avoid, since it unbalances the + * additions with the deletions. 
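/* Editor's note: a standalone userspace sketch of the refcounted bookkeeping
 * shared by the dsa_port_do_{fdb,mdb,vlan}_add/del helpers in this file:
 * program the hardware only on the first add and the last delete, and keep
 * the entry alive if the hardware op fails on delete.  All names here
 * (entry, hw_install, hw_uninstall, ...) are invented for the illustration
 * and are not part of the kernel API.
 */
#include <stdlib.h>
#include <string.h>

struct entry {
	char key[32];		/* stands in for (addr, vid, db) */
	unsigned int refcount;
	struct entry *next;
};

static struct entry *entries;

static int hw_install(const char *key)   { return 0; }	/* pretend driver op */
static int hw_uninstall(const char *key) { return 0; }

static struct entry *entry_find(const char *key)
{
	struct entry *e;

	for (e = entries; e; e = e->next)
		if (!strcmp(e->key, key))
			return e;
	return NULL;
}

static int entry_add(const char *key)
{
	struct entry *e = entry_find(key);
	int err;

	if (e) {
		e->refcount++;	/* already programmed, just take a reference */
		return 0;
	}

	e = calloc(1, sizeof(*e));
	if (!e)
		return -1;

	err = hw_install(key);
	if (err) {
		free(e);
		return err;
	}

	strncpy(e->key, key, sizeof(e->key) - 1);
	e->refcount = 1;
	e->next = entries;
	entries = e;
	return 0;
}

static int entry_del(const char *key)
{
	struct entry *e = entry_find(key), **p;

	if (!e)
		return -1;
	if (--e->refcount)	/* still referenced by another user */
		return 0;

	if (hw_uninstall(key)) {
		e->refcount = 1;	/* keep the entry on failure, as above */
		return -1;
	}

	for (p = &entries; *p != e; p = &(*p)->next)
		;
	*p = e->next;
	free(e);
	return 0;
}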
+ */ + if (vlan->changed) + return 0; + + mutex_lock(&dp->vlans_lock); + + v = dsa_vlan_find(&dp->vlans, vlan); + if (v) { + refcount_inc(&v->refcount); + goto out; + } + + v = kzalloc(sizeof(*v), GFP_KERNEL); + if (!v) { + err = -ENOMEM; + goto out; + } + + err = ds->ops->port_vlan_add(ds, port, vlan, extack); + if (err) { + kfree(v); + goto out; + } + + v->vid = vlan->vid; + refcount_set(&v->refcount, 1); + list_add_tail(&v->list, &dp->vlans); + +out: + mutex_unlock(&dp->vlans_lock); + + return err; +} + +static int dsa_port_do_vlan_del(struct dsa_port *dp, + const struct switchdev_obj_port_vlan *vlan) +{ + struct dsa_switch *ds = dp->ds; + int port = dp->index; + struct dsa_vlan *v; + int err = 0; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->port_vlan_del(ds, port, vlan); + + mutex_lock(&dp->vlans_lock); + + v = dsa_vlan_find(&dp->vlans, vlan); + if (!v) { + err = -ENOENT; + goto out; + } + + if (!refcount_dec_and_test(&v->refcount)) + goto out; + + err = ds->ops->port_vlan_del(ds, port, vlan); + if (err) { + refcount_set(&v->refcount, 1); + goto out; + } + + list_del(&v->list); + kfree(v); + +out: + mutex_unlock(&dp->vlans_lock); + + return err; +} + +static int dsa_switch_vlan_add(struct dsa_switch *ds, + struct dsa_notifier_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + if (!ds->ops->port_vlan_add) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_vlan_match(dp, info)) { + err = dsa_port_do_vlan_add(dp, info->vlan, + info->extack); + if (err) + return err; + } + } + + return 0; +} + +static int dsa_switch_vlan_del(struct dsa_switch *ds, + struct dsa_notifier_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + if (!ds->ops->port_vlan_del) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_vlan_match(dp, info)) { + err = dsa_port_do_vlan_del(dp, info->vlan); + if (err) + return err; + } + } + + return 0; +} + +static int dsa_switch_host_vlan_add(struct dsa_switch *ds, + struct dsa_notifier_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + if (!ds->ops->port_vlan_add) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_vlan_match(dp, info->dp)) { + err = dsa_port_do_vlan_add(dp, info->vlan, + info->extack); + if (err) + return err; + } + } + + return 0; +} + +static int dsa_switch_host_vlan_del(struct dsa_switch *ds, + struct dsa_notifier_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + if (!ds->ops->port_vlan_del) + return -EOPNOTSUPP; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_host_vlan_match(dp, info->dp)) { + err = dsa_port_do_vlan_del(dp, info->vlan); + if (err) + return err; + } + } + + return 0; +} + +static int dsa_switch_change_tag_proto(struct dsa_switch *ds, + struct dsa_notifier_tag_proto_info *info) +{ + const struct dsa_device_ops *tag_ops = info->tag_ops; + struct dsa_port *dp, *cpu_dp; + int err; + + if (!ds->ops->change_tag_protocol) + return -EOPNOTSUPP; + + ASSERT_RTNL(); + + err = ds->ops->change_tag_protocol(ds, tag_ops->proto); + if (err) + return err; + + dsa_switch_for_each_cpu_port(cpu_dp, ds) + dsa_port_set_tag_protocol(cpu_dp, tag_ops); + + /* Now that changing the tag protocol can no longer fail, let's update + * the remaining bits which are "duplicated for faster access", and the + * bits that depend on the tagger, such as the MTU. 
+ */ + dsa_switch_for_each_user_port(dp, ds) { + struct net_device *slave = dp->slave; + + dsa_slave_setup_tagger(slave); + + /* rtnl_mutex is held in dsa_tree_change_tag_proto */ + dsa_slave_change_mtu(slave, slave->mtu); + } + + return 0; +} + +/* We use the same cross-chip notifiers to inform both the tagger side, as well + * as the switch side, of connection and disconnection events. + * Since ds->tagger_data is owned by the tagger, it isn't a hard error if the + * switch side doesn't support connecting to this tagger, and therefore, the + * fact that we don't disconnect the tagger side doesn't constitute a memory + * leak: the tagger will still operate with persistent per-switch memory, just + * with the switch side unconnected to it. What does constitute a hard error is + * when the switch side supports connecting but fails. + */ +static int +dsa_switch_connect_tag_proto(struct dsa_switch *ds, + struct dsa_notifier_tag_proto_info *info) +{ + const struct dsa_device_ops *tag_ops = info->tag_ops; + int err; + + /* Notify the new tagger about the connection to this switch */ + if (tag_ops->connect) { + err = tag_ops->connect(ds); + if (err) + return err; + } + + if (!ds->ops->connect_tag_protocol) + return -EOPNOTSUPP; + + /* Notify the switch about the connection to the new tagger */ + err = ds->ops->connect_tag_protocol(ds, tag_ops->proto); + if (err) { + /* Revert the new tagger's connection to this tree */ + if (tag_ops->disconnect) + tag_ops->disconnect(ds); + return err; + } + + return 0; +} + +static int +dsa_switch_disconnect_tag_proto(struct dsa_switch *ds, + struct dsa_notifier_tag_proto_info *info) +{ + const struct dsa_device_ops *tag_ops = info->tag_ops; + + /* Notify the tagger about the disconnection from this switch */ + if (tag_ops->disconnect && ds->tagger_data) + tag_ops->disconnect(ds); + + /* No need to notify the switch, since it shouldn't have any + * resources to tear down + */ + return 0; +} + +static int +dsa_switch_master_state_change(struct dsa_switch *ds, + struct dsa_notifier_master_state_info *info) +{ + if (!ds->ops->master_state_change) + return 0; + + ds->ops->master_state_change(ds, info->master, info->operational); + + return 0; +} + +static int dsa_switch_event(struct notifier_block *nb, + unsigned long event, void *info) +{ + struct dsa_switch *ds = container_of(nb, struct dsa_switch, nb); + int err; + + switch (event) { + case DSA_NOTIFIER_AGEING_TIME: + err = dsa_switch_ageing_time(ds, info); + break; + case DSA_NOTIFIER_BRIDGE_JOIN: + err = dsa_switch_bridge_join(ds, info); + break; + case DSA_NOTIFIER_BRIDGE_LEAVE: + err = dsa_switch_bridge_leave(ds, info); + break; + case DSA_NOTIFIER_FDB_ADD: + err = dsa_switch_fdb_add(ds, info); + break; + case DSA_NOTIFIER_FDB_DEL: + err = dsa_switch_fdb_del(ds, info); + break; + case DSA_NOTIFIER_HOST_FDB_ADD: + err = dsa_switch_host_fdb_add(ds, info); + break; + case DSA_NOTIFIER_HOST_FDB_DEL: + err = dsa_switch_host_fdb_del(ds, info); + break; + case DSA_NOTIFIER_LAG_FDB_ADD: + err = dsa_switch_lag_fdb_add(ds, info); + break; + case DSA_NOTIFIER_LAG_FDB_DEL: + err = dsa_switch_lag_fdb_del(ds, info); + break; + case DSA_NOTIFIER_LAG_CHANGE: + err = dsa_switch_lag_change(ds, info); + break; + case DSA_NOTIFIER_LAG_JOIN: + err = dsa_switch_lag_join(ds, info); + break; + case DSA_NOTIFIER_LAG_LEAVE: + err = dsa_switch_lag_leave(ds, info); + break; + case DSA_NOTIFIER_MDB_ADD: + err = dsa_switch_mdb_add(ds, info); + break; + case DSA_NOTIFIER_MDB_DEL: + err = dsa_switch_mdb_del(ds, info); + break; + case 
DSA_NOTIFIER_HOST_MDB_ADD: + err = dsa_switch_host_mdb_add(ds, info); + break; + case DSA_NOTIFIER_HOST_MDB_DEL: + err = dsa_switch_host_mdb_del(ds, info); + break; + case DSA_NOTIFIER_VLAN_ADD: + err = dsa_switch_vlan_add(ds, info); + break; + case DSA_NOTIFIER_VLAN_DEL: + err = dsa_switch_vlan_del(ds, info); + break; + case DSA_NOTIFIER_HOST_VLAN_ADD: + err = dsa_switch_host_vlan_add(ds, info); + break; + case DSA_NOTIFIER_HOST_VLAN_DEL: + err = dsa_switch_host_vlan_del(ds, info); + break; + case DSA_NOTIFIER_MTU: + err = dsa_switch_mtu(ds, info); + break; + case DSA_NOTIFIER_TAG_PROTO: + err = dsa_switch_change_tag_proto(ds, info); + break; + case DSA_NOTIFIER_TAG_PROTO_CONNECT: + err = dsa_switch_connect_tag_proto(ds, info); + break; + case DSA_NOTIFIER_TAG_PROTO_DISCONNECT: + err = dsa_switch_disconnect_tag_proto(ds, info); + break; + case DSA_NOTIFIER_TAG_8021Q_VLAN_ADD: + err = dsa_switch_tag_8021q_vlan_add(ds, info); + break; + case DSA_NOTIFIER_TAG_8021Q_VLAN_DEL: + err = dsa_switch_tag_8021q_vlan_del(ds, info); + break; + case DSA_NOTIFIER_MASTER_STATE_CHANGE: + err = dsa_switch_master_state_change(ds, info); + break; + default: + err = -EOPNOTSUPP; + break; + } + + if (err) + dev_dbg(ds->dev, "breaking chain for DSA event %lu (%d)\n", + event, err); + + return notifier_from_errno(err); +} + +int dsa_switch_register_notifier(struct dsa_switch *ds) +{ + ds->nb.notifier_call = dsa_switch_event; + + return raw_notifier_chain_register(&ds->dst->nh, &ds->nb); +} + +void dsa_switch_unregister_notifier(struct dsa_switch *ds) +{ + int err; + + err = raw_notifier_chain_unregister(&ds->dst->nh, &ds->nb); + if (err) + dev_err(ds->dev, "failed to unregister notifier (%d)\n", err); +} diff --git a/net/dsa/tag_8021q.c b/net/dsa/tag_8021q.c new file mode 100644 index 000000000..89371b164 --- /dev/null +++ b/net/dsa/tag_8021q.c @@ -0,0 +1,507 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2019, Vladimir Oltean + * + * This module is not a complete tagger implementation. It only provides + * primitives for taggers that rely on 802.1Q VLAN tags to use. + */ +#include +#include + +#include "dsa_priv.h" + +/* Binary structure of the fake 12-bit VID field (when the TPID is + * ETH_P_DSA_8021Q): + * + * | 11 | 10 | 9 | 8 | 7 | 6 | 5 | 4 | 3 | 2 | 1 | 0 | + * +-----------+-----+-----------------+-----------+-----------------------+ + * | RSV | VBID| SWITCH_ID | VBID | PORT | + * +-----------+-----+-----------------+-----------+-----------------------+ + * + * RSV - VID[11:10]: + * Reserved. Must be set to 3 (0b11). + * + * SWITCH_ID - VID[8:6]: + * Index of switch within DSA tree. Must be between 0 and 7. + * + * VBID - { VID[9], VID[5:4] }: + * Virtual bridge ID. If between 1 and 7, packet targets the broadcast + * domain of a bridge. If transmitted as zero, packet targets a single + * port. + * + * PORT - VID[3:0]: + * Index of switch port. Must be between 0 and 15. 
+ */ + +#define DSA_8021Q_RSV_VAL 3 +#define DSA_8021Q_RSV_SHIFT 10 +#define DSA_8021Q_RSV_MASK GENMASK(11, 10) +#define DSA_8021Q_RSV ((DSA_8021Q_RSV_VAL << DSA_8021Q_RSV_SHIFT) & \ + DSA_8021Q_RSV_MASK) + +#define DSA_8021Q_SWITCH_ID_SHIFT 6 +#define DSA_8021Q_SWITCH_ID_MASK GENMASK(8, 6) +#define DSA_8021Q_SWITCH_ID(x) (((x) << DSA_8021Q_SWITCH_ID_SHIFT) & \ + DSA_8021Q_SWITCH_ID_MASK) + +#define DSA_8021Q_VBID_HI_SHIFT 9 +#define DSA_8021Q_VBID_HI_MASK GENMASK(9, 9) +#define DSA_8021Q_VBID_LO_SHIFT 4 +#define DSA_8021Q_VBID_LO_MASK GENMASK(5, 4) +#define DSA_8021Q_VBID_HI(x) (((x) & GENMASK(2, 2)) >> 2) +#define DSA_8021Q_VBID_LO(x) ((x) & GENMASK(1, 0)) +#define DSA_8021Q_VBID(x) \ + (((DSA_8021Q_VBID_LO(x) << DSA_8021Q_VBID_LO_SHIFT) & \ + DSA_8021Q_VBID_LO_MASK) | \ + ((DSA_8021Q_VBID_HI(x) << DSA_8021Q_VBID_HI_SHIFT) & \ + DSA_8021Q_VBID_HI_MASK)) + +#define DSA_8021Q_PORT_SHIFT 0 +#define DSA_8021Q_PORT_MASK GENMASK(3, 0) +#define DSA_8021Q_PORT(x) (((x) << DSA_8021Q_PORT_SHIFT) & \ + DSA_8021Q_PORT_MASK) + +u16 dsa_tag_8021q_bridge_vid(unsigned int bridge_num) +{ + /* The VBID value of 0 is reserved for precise TX, but it is also + * reserved/invalid for the bridge_num, so all is well. + */ + return DSA_8021Q_RSV | DSA_8021Q_VBID(bridge_num); +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_bridge_vid); + +/* Returns the VID that will be installed as pvid for this switch port, sent as + * tagged egress towards the CPU port and decoded by the rcv function. + */ +u16 dsa_tag_8021q_standalone_vid(const struct dsa_port *dp) +{ + return DSA_8021Q_RSV | DSA_8021Q_SWITCH_ID(dp->ds->index) | + DSA_8021Q_PORT(dp->index); +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_standalone_vid); + +/* Returns the decoded switch ID from the RX VID. */ +int dsa_8021q_rx_switch_id(u16 vid) +{ + return (vid & DSA_8021Q_SWITCH_ID_MASK) >> DSA_8021Q_SWITCH_ID_SHIFT; +} +EXPORT_SYMBOL_GPL(dsa_8021q_rx_switch_id); + +/* Returns the decoded port ID from the RX VID. */ +int dsa_8021q_rx_source_port(u16 vid) +{ + return (vid & DSA_8021Q_PORT_MASK) >> DSA_8021Q_PORT_SHIFT; +} +EXPORT_SYMBOL_GPL(dsa_8021q_rx_source_port); + +/* Returns the decoded VBID from the RX VID. 
*/ +static int dsa_tag_8021q_rx_vbid(u16 vid) +{ + u16 vbid_hi = (vid & DSA_8021Q_VBID_HI_MASK) >> DSA_8021Q_VBID_HI_SHIFT; + u16 vbid_lo = (vid & DSA_8021Q_VBID_LO_MASK) >> DSA_8021Q_VBID_LO_SHIFT; + + return (vbid_hi << 2) | vbid_lo; +} + +bool vid_is_dsa_8021q(u16 vid) +{ + u16 rsv = (vid & DSA_8021Q_RSV_MASK) >> DSA_8021Q_RSV_SHIFT; + + return rsv == DSA_8021Q_RSV_VAL; +} +EXPORT_SYMBOL_GPL(vid_is_dsa_8021q); + +static struct dsa_tag_8021q_vlan * +dsa_tag_8021q_vlan_find(struct dsa_8021q_context *ctx, int port, u16 vid) +{ + struct dsa_tag_8021q_vlan *v; + + list_for_each_entry(v, &ctx->vlans, list) + if (v->vid == vid && v->port == port) + return v; + + return NULL; +} + +static int dsa_port_do_tag_8021q_vlan_add(struct dsa_port *dp, u16 vid, + u16 flags) +{ + struct dsa_8021q_context *ctx = dp->ds->tag_8021q_ctx; + struct dsa_switch *ds = dp->ds; + struct dsa_tag_8021q_vlan *v; + int port = dp->index; + int err; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->tag_8021q_vlan_add(ds, port, vid, flags); + + v = dsa_tag_8021q_vlan_find(ctx, port, vid); + if (v) { + refcount_inc(&v->refcount); + return 0; + } + + v = kzalloc(sizeof(*v), GFP_KERNEL); + if (!v) + return -ENOMEM; + + err = ds->ops->tag_8021q_vlan_add(ds, port, vid, flags); + if (err) { + kfree(v); + return err; + } + + v->vid = vid; + v->port = port; + refcount_set(&v->refcount, 1); + list_add_tail(&v->list, &ctx->vlans); + + return 0; +} + +static int dsa_port_do_tag_8021q_vlan_del(struct dsa_port *dp, u16 vid) +{ + struct dsa_8021q_context *ctx = dp->ds->tag_8021q_ctx; + struct dsa_switch *ds = dp->ds; + struct dsa_tag_8021q_vlan *v; + int port = dp->index; + int err; + + /* No need to bother with refcounting for user ports */ + if (!(dsa_port_is_cpu(dp) || dsa_port_is_dsa(dp))) + return ds->ops->tag_8021q_vlan_del(ds, port, vid); + + v = dsa_tag_8021q_vlan_find(ctx, port, vid); + if (!v) + return -ENOENT; + + if (!refcount_dec_and_test(&v->refcount)) + return 0; + + err = ds->ops->tag_8021q_vlan_del(ds, port, vid); + if (err) { + refcount_inc(&v->refcount); + return err; + } + + list_del(&v->list); + kfree(v); + + return 0; +} + +static bool +dsa_port_tag_8021q_vlan_match(struct dsa_port *dp, + struct dsa_notifier_tag_8021q_vlan_info *info) +{ + return dsa_port_is_dsa(dp) || dsa_port_is_cpu(dp) || dp == info->dp; +} + +int dsa_switch_tag_8021q_vlan_add(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + /* Since we use dsa_broadcast(), there might be other switches in other + * trees which don't support tag_8021q, so don't return an error. + * Or they might even support tag_8021q but have not registered yet to + * use it (maybe they use another tagger currently). 
+ */ + if (!ds->ops->tag_8021q_vlan_add || !ds->tag_8021q_ctx) + return 0; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_tag_8021q_vlan_match(dp, info)) { + u16 flags = 0; + + if (dsa_port_is_user(dp)) + flags |= BRIDGE_VLAN_INFO_UNTAGGED | + BRIDGE_VLAN_INFO_PVID; + + err = dsa_port_do_tag_8021q_vlan_add(dp, info->vid, + flags); + if (err) + return err; + } + } + + return 0; +} + +int dsa_switch_tag_8021q_vlan_del(struct dsa_switch *ds, + struct dsa_notifier_tag_8021q_vlan_info *info) +{ + struct dsa_port *dp; + int err; + + if (!ds->ops->tag_8021q_vlan_del || !ds->tag_8021q_ctx) + return 0; + + dsa_switch_for_each_port(dp, ds) { + if (dsa_port_tag_8021q_vlan_match(dp, info)) { + err = dsa_port_do_tag_8021q_vlan_del(dp, info->vid); + if (err) + return err; + } + } + + return 0; +} + +/* There are 2 ways of offloading tag_8021q VLANs. + * + * One is to use a hardware TCAM to push the port's standalone VLAN into the + * frame when forwarding it to the CPU, as an egress modification rule on the + * CPU port. This is preferable because it has no side effects for the + * autonomous forwarding path, and accomplishes tag_8021q's primary goal of + * identifying the source port of each packet based on VLAN ID. + * + * The other is to commit the tag_8021q VLAN as a PVID to the VLAN table, and + * to configure the port as VLAN-unaware. This is less preferable because + * unique source port identification can only be done for standalone ports; + * under a VLAN-unaware bridge, all ports share the same tag_8021q VLAN as + * PVID, and under a VLAN-aware bridge, packets received by software will not + * have tag_8021q VLANs appended, just bridge VLANs. + * + * For tag_8021q implementations of the second type, this method is used to + * replace the standalone tag_8021q VLAN of a port with the tag_8021q VLAN to + * be used for VLAN-unaware bridging. 
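/* Editor's note: a minimal sketch of the "replace one tag_8021q VLAN with
 * another" ordering used by dsa_tag_8021q_bridge_join()/leave() below: the
 * replacement VLAN is installed before the old one is removed, so the port
 * is never left without a pvid.  port_vlan_add()/port_vlan_del() are
 * illustrative stand-ins, not the kernel functions.
 */
static int port_vlan_add(int port, unsigned int vid) { return 0; }
static void port_vlan_del(int port, unsigned int vid) { }

static int replace_vid(int port, unsigned int old_vid, unsigned int new_vid)
{
	int err = port_vlan_add(port, new_vid);

	if (err)
		return err;	/* keep the old VLAN if the new one failed */

	port_vlan_del(port, old_vid);
	return 0;
}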
+ */ +int dsa_tag_8021q_bridge_join(struct dsa_switch *ds, int port, + struct dsa_bridge bridge) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + u16 standalone_vid, bridge_vid; + int err; + + /* Delete the standalone VLAN of the port and replace it with a + * bridging VLAN + */ + standalone_vid = dsa_tag_8021q_standalone_vid(dp); + bridge_vid = dsa_tag_8021q_bridge_vid(bridge.num); + + err = dsa_port_tag_8021q_vlan_add(dp, bridge_vid, true); + if (err) + return err; + + dsa_port_tag_8021q_vlan_del(dp, standalone_vid, false); + + return 0; +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_bridge_join); + +void dsa_tag_8021q_bridge_leave(struct dsa_switch *ds, int port, + struct dsa_bridge bridge) +{ + struct dsa_port *dp = dsa_to_port(ds, port); + u16 standalone_vid, bridge_vid; + int err; + + /* Delete the bridging VLAN of the port and replace it with a + * standalone VLAN + */ + standalone_vid = dsa_tag_8021q_standalone_vid(dp); + bridge_vid = dsa_tag_8021q_bridge_vid(bridge.num); + + err = dsa_port_tag_8021q_vlan_add(dp, standalone_vid, false); + if (err) { + dev_err(ds->dev, + "Failed to delete tag_8021q standalone VLAN %d from port %d: %pe\n", + standalone_vid, port, ERR_PTR(err)); + } + + dsa_port_tag_8021q_vlan_del(dp, bridge_vid, true); +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_bridge_leave); + +/* Set up a port's standalone tag_8021q VLAN */ +static int dsa_tag_8021q_port_setup(struct dsa_switch *ds, int port) +{ + struct dsa_8021q_context *ctx = ds->tag_8021q_ctx; + struct dsa_port *dp = dsa_to_port(ds, port); + u16 vid = dsa_tag_8021q_standalone_vid(dp); + struct net_device *master; + int err; + + /* The CPU port is implicitly configured by + * configuring the front-panel ports + */ + if (!dsa_port_is_user(dp)) + return 0; + + master = dsa_port_to_master(dp); + + err = dsa_port_tag_8021q_vlan_add(dp, vid, false); + if (err) { + dev_err(ds->dev, + "Failed to apply standalone VID %d to port %d: %pe\n", + vid, port, ERR_PTR(err)); + return err; + } + + /* Add the VLAN to the master's RX filter. 
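/* Editor's note: a short sketch of the pairing done by
 * dsa_tag_8021q_port_setup()/teardown() below: the standalone VID must
 * exist both on the switch port and in the DSA master's VLAN RX filter,
 * and teardown undoes both.  These helpers are invented stand-ins only.
 */
static int switch_vlan_add(int port, unsigned int vid)  { return 0; }
static void switch_vlan_del(int port, unsigned int vid) { }
static void master_filter_add(unsigned int vid)         { }
static void master_filter_del(unsigned int vid)         { }

static int setup_port(int port, unsigned int vid)
{
	int err = switch_vlan_add(port, vid);

	if (err)
		return err;
	master_filter_add(vid);	/* let the master accept this tagged VID */
	return 0;
}

static void teardown_port(int port, unsigned int vid)
{
	switch_vlan_del(port, vid);
	master_filter_del(vid);
}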
*/ + vlan_vid_add(master, ctx->proto, vid); + + return err; +} + +static void dsa_tag_8021q_port_teardown(struct dsa_switch *ds, int port) +{ + struct dsa_8021q_context *ctx = ds->tag_8021q_ctx; + struct dsa_port *dp = dsa_to_port(ds, port); + u16 vid = dsa_tag_8021q_standalone_vid(dp); + struct net_device *master; + + /* The CPU port is implicitly configured by + * configuring the front-panel ports + */ + if (!dsa_port_is_user(dp)) + return; + + master = dsa_port_to_master(dp); + + dsa_port_tag_8021q_vlan_del(dp, vid, false); + + vlan_vid_del(master, ctx->proto, vid); +} + +static int dsa_tag_8021q_setup(struct dsa_switch *ds) +{ + int err, port; + + ASSERT_RTNL(); + + for (port = 0; port < ds->num_ports; port++) { + err = dsa_tag_8021q_port_setup(ds, port); + if (err < 0) { + dev_err(ds->dev, + "Failed to setup VLAN tagging for port %d: %pe\n", + port, ERR_PTR(err)); + return err; + } + } + + return 0; +} + +static void dsa_tag_8021q_teardown(struct dsa_switch *ds) +{ + int port; + + ASSERT_RTNL(); + + for (port = 0; port < ds->num_ports; port++) + dsa_tag_8021q_port_teardown(ds, port); +} + +int dsa_tag_8021q_register(struct dsa_switch *ds, __be16 proto) +{ + struct dsa_8021q_context *ctx; + int err; + + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + if (!ctx) + return -ENOMEM; + + ctx->proto = proto; + ctx->ds = ds; + + INIT_LIST_HEAD(&ctx->vlans); + + ds->tag_8021q_ctx = ctx; + + err = dsa_tag_8021q_setup(ds); + if (err) + goto err_free; + + return 0; + +err_free: + kfree(ctx); + return err; +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_register); + +void dsa_tag_8021q_unregister(struct dsa_switch *ds) +{ + struct dsa_8021q_context *ctx = ds->tag_8021q_ctx; + struct dsa_tag_8021q_vlan *v, *n; + + dsa_tag_8021q_teardown(ds); + + list_for_each_entry_safe(v, n, &ctx->vlans, list) { + list_del(&v->list); + kfree(v); + } + + ds->tag_8021q_ctx = NULL; + + kfree(ctx); +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_unregister); + +struct sk_buff *dsa_8021q_xmit(struct sk_buff *skb, struct net_device *netdev, + u16 tpid, u16 tci) +{ + /* skb->data points at skb_mac_header, which + * is fine for vlan_insert_tag. 
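/* Editor's note: a userspace sketch of the TCI handling around
 * dsa_8021q_xmit()/dsa_8021q_rcv(): the 16-bit TCI carries the tag_8021q
 * VID in bits 11:0 and the priority in bits 15:13.  Pure bit arithmetic,
 * no kernel helpers; the demo VID is arbitrary.
 */
#include <stdio.h>
#include <stdint.h>

#define VID_MASK   0x0fff
#define PRIO_SHIFT 13

static uint16_t build_tci(uint16_t vid, unsigned int prio)
{
	return (vid & VID_MASK) | ((prio & 0x7) << PRIO_SHIFT);
}

int main(void)
{
	uint16_t tci = build_tci(0xc85, 5);

	printf("tci=0x%04x vid=0x%03x prio=%u\n",
	       tci, tci & VID_MASK, tci >> PRIO_SHIFT);
	return 0;
}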
+ */ + return vlan_insert_tag(skb, htons(tpid), tci); +} +EXPORT_SYMBOL_GPL(dsa_8021q_xmit); + +struct net_device *dsa_tag_8021q_find_port_by_vbid(struct net_device *master, + int vbid) +{ + struct dsa_port *cpu_dp = master->dsa_ptr; + struct dsa_switch_tree *dst = cpu_dp->dst; + struct dsa_port *dp; + + if (WARN_ON(!vbid)) + return NULL; + + dsa_tree_for_each_user_port(dp, dst) { + if (!dp->bridge) + continue; + + if (dp->stp_state != BR_STATE_LEARNING && + dp->stp_state != BR_STATE_FORWARDING) + continue; + + if (dp->cpu_dp != cpu_dp) + continue; + + if (dsa_port_bridge_num_get(dp) == vbid) + return dp->slave; + } + + return NULL; +} +EXPORT_SYMBOL_GPL(dsa_tag_8021q_find_port_by_vbid); + +void dsa_8021q_rcv(struct sk_buff *skb, int *source_port, int *switch_id, + int *vbid) +{ + u16 vid, tci; + + if (skb_vlan_tag_present(skb)) { + tci = skb_vlan_tag_get(skb); + __vlan_hwaccel_clear_tag(skb); + } else { + skb_push_rcsum(skb, ETH_HLEN); + __skb_vlan_pop(skb, &tci); + skb_pull_rcsum(skb, ETH_HLEN); + } + + vid = tci & VLAN_VID_MASK; + + *source_port = dsa_8021q_rx_source_port(vid); + *switch_id = dsa_8021q_rx_switch_id(vid); + + if (vbid) + *vbid = dsa_tag_8021q_rx_vbid(vid); + + skb->priority = (tci & VLAN_PRIO_MASK) >> VLAN_PRIO_SHIFT; +} +EXPORT_SYMBOL_GPL(dsa_8021q_rcv); diff --git a/net/dsa/tag_ar9331.c b/net/dsa/tag_ar9331.c new file mode 100644 index 000000000..8a02ac442 --- /dev/null +++ b/net/dsa/tag_ar9331.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2019 Pengutronix, Oleksij Rempel + */ + + +#include +#include + +#include "dsa_priv.h" + +#define AR9331_HDR_LEN 2 +#define AR9331_HDR_VERSION 1 + +#define AR9331_HDR_VERSION_MASK GENMASK(15, 14) +#define AR9331_HDR_PRIORITY_MASK GENMASK(13, 12) +#define AR9331_HDR_TYPE_MASK GENMASK(10, 8) +#define AR9331_HDR_BROADCAST BIT(7) +#define AR9331_HDR_FROM_CPU BIT(6) +/* AR9331_HDR_RESERVED - not used or may be version field. + * According to the AR8216 doc it should 0b10. On AR9331 it is 0b11 on RX path + * and should be set to 0b11 to make it work. 
+ */ +#define AR9331_HDR_RESERVED_MASK GENMASK(5, 4) +#define AR9331_HDR_PORT_NUM_MASK GENMASK(3, 0) + +static struct sk_buff *ar9331_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __le16 *phdr; + u16 hdr; + + phdr = skb_push(skb, AR9331_HDR_LEN); + + hdr = FIELD_PREP(AR9331_HDR_VERSION_MASK, AR9331_HDR_VERSION); + hdr |= AR9331_HDR_FROM_CPU | dp->index; + /* 0b10 for AR8216 and 0b11 for AR9331 */ + hdr |= AR9331_HDR_RESERVED_MASK; + + phdr[0] = cpu_to_le16(hdr); + + return skb; +} + +static struct sk_buff *ar9331_tag_rcv(struct sk_buff *skb, + struct net_device *ndev) +{ + u8 ver, port; + u16 hdr; + + if (unlikely(!pskb_may_pull(skb, AR9331_HDR_LEN))) + return NULL; + + hdr = le16_to_cpu(*(__le16 *)skb_mac_header(skb)); + + ver = FIELD_GET(AR9331_HDR_VERSION_MASK, hdr); + if (unlikely(ver != AR9331_HDR_VERSION)) { + netdev_warn_once(ndev, "%s:%i wrong header version 0x%2x\n", + __func__, __LINE__, hdr); + return NULL; + } + + if (unlikely(hdr & AR9331_HDR_FROM_CPU)) { + netdev_warn_once(ndev, "%s:%i packet should not be from cpu 0x%2x\n", + __func__, __LINE__, hdr); + return NULL; + } + + skb_pull_rcsum(skb, AR9331_HDR_LEN); + + /* Get source port information */ + port = FIELD_GET(AR9331_HDR_PORT_NUM_MASK, hdr); + + skb->dev = dsa_master_find_slave(ndev, 0, port); + if (!skb->dev) + return NULL; + + return skb; +} + +static const struct dsa_device_ops ar9331_netdev_ops = { + .name = "ar9331", + .proto = DSA_TAG_PROTO_AR9331, + .xmit = ar9331_tag_xmit, + .rcv = ar9331_tag_rcv, + .needed_headroom = AR9331_HDR_LEN, +}; + +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_AR9331); +module_dsa_tag_driver(ar9331_netdev_ops); diff --git a/net/dsa/tag_brcm.c b/net/dsa/tag_brcm.c new file mode 100644 index 000000000..a65d62fb9 --- /dev/null +++ b/net/dsa/tag_brcm.c @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * Broadcom tag support + * + * Copyright (C) 2014 Broadcom Corporation + */ + +#include +#include +#include +#include +#include + +#include "dsa_priv.h" + +/* Legacy Broadcom tag (6 bytes) */ +#define BRCM_LEG_TAG_LEN 6 + +/* Type fields */ +/* 1st byte in the tag */ +#define BRCM_LEG_TYPE_HI 0x88 +/* 2nd byte in the tag */ +#define BRCM_LEG_TYPE_LO 0x74 + +/* Tag fields */ +/* 3rd byte in the tag */ +#define BRCM_LEG_UNICAST (0 << 5) +#define BRCM_LEG_MULTICAST (1 << 5) +#define BRCM_LEG_EGRESS (2 << 5) +#define BRCM_LEG_INGRESS (3 << 5) + +/* 6th byte in the tag */ +#define BRCM_LEG_PORT_ID (0xf) + +/* Newer Broadcom tag (4 bytes) */ +#define BRCM_TAG_LEN 4 + +/* Tag is constructed and deconstructed using byte by byte access + * because the tag is placed after the MAC Source Address, which does + * not make it 4-bytes aligned, so this might cause unaligned accesses + * on most systems where this is used. 
+ */ + +/* Ingress and egress opcodes */ +#define BRCM_OPCODE_SHIFT 5 +#define BRCM_OPCODE_MASK 0x7 + +/* Ingress fields */ +/* 1st byte in the tag */ +#define BRCM_IG_TC_SHIFT 2 +#define BRCM_IG_TC_MASK 0x7 +/* 2nd byte in the tag */ +#define BRCM_IG_TE_MASK 0x3 +#define BRCM_IG_TS_SHIFT 7 +/* 3rd byte in the tag */ +#define BRCM_IG_DSTMAP2_MASK 1 +#define BRCM_IG_DSTMAP1_MASK 0xff + +/* Egress fields */ + +/* 2nd byte in the tag */ +#define BRCM_EG_CID_MASK 0xff + +/* 3rd byte in the tag */ +#define BRCM_EG_RC_MASK 0xff +#define BRCM_EG_RC_RSVD (3 << 6) +#define BRCM_EG_RC_EXCEPTION (1 << 5) +#define BRCM_EG_RC_PROT_SNOOP (1 << 4) +#define BRCM_EG_RC_PROT_TERM (1 << 3) +#define BRCM_EG_RC_SWITCH (1 << 2) +#define BRCM_EG_RC_MAC_LEARN (1 << 1) +#define BRCM_EG_RC_MIRROR (1 << 0) +#define BRCM_EG_TC_SHIFT 5 +#define BRCM_EG_TC_MASK 0x7 +#define BRCM_EG_PID_MASK 0x1f + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM) || \ + IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_PREPEND) + +static struct sk_buff *brcm_tag_xmit_ll(struct sk_buff *skb, + struct net_device *dev, + unsigned int offset) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u16 queue = skb_get_queue_mapping(skb); + u8 *brcm_tag; + + /* The Ethernet switch we are interfaced with needs packets to be at + * least 64 bytes (including FCS) otherwise they will be discarded when + * they enter the switch port logic. When Broadcom tags are enabled, we + * need to make sure that packets are at least 68 bytes + * (including FCS and tag) because the length verification is done after + * the Broadcom tag is stripped off the ingress packet. + * + * Let dsa_slave_xmit() free the SKB + */ + if (__skb_put_padto(skb, ETH_ZLEN + BRCM_TAG_LEN, false)) + return NULL; + + skb_push(skb, BRCM_TAG_LEN); + + if (offset) + dsa_alloc_etype_header(skb, BRCM_TAG_LEN); + + brcm_tag = skb->data + offset; + + /* Set the ingress opcode, traffic class, tag enforcement is + * deprecated + */ + brcm_tag[0] = (1 << BRCM_OPCODE_SHIFT) | + ((queue & BRCM_IG_TC_MASK) << BRCM_IG_TC_SHIFT); + brcm_tag[1] = 0; + brcm_tag[2] = 0; + if (dp->index == 8) + brcm_tag[2] = BRCM_IG_DSTMAP2_MASK; + brcm_tag[3] = (1 << dp->index) & BRCM_IG_DSTMAP1_MASK; + + /* Now tell the master network device about the desired output queue + * as well + */ + skb_set_queue_mapping(skb, BRCM_TAG_SET_PORT_QUEUE(dp->index, queue)); + + return skb; +} + +/* Frames with this tag have one of these two layouts: + * ----------------------------------- + * | MAC DA | MAC SA | 4b tag | Type | DSA_TAG_PROTO_BRCM + * ----------------------------------- + * ----------------------------------- + * | 4b tag | MAC DA | MAC SA | Type | DSA_TAG_PROTO_BRCM_PREPEND + * ----------------------------------- + * In both cases, at receive time, skb->data points 2 bytes before the actual + * Ethernet type field and we have an offset of 4bytes between where skb->data + * and where the payload starts. So the same low-level receive function can be + * used. 
+ */ +static struct sk_buff *brcm_tag_rcv_ll(struct sk_buff *skb, + struct net_device *dev, + unsigned int offset) +{ + int source_port; + u8 *brcm_tag; + + if (unlikely(!pskb_may_pull(skb, BRCM_TAG_LEN))) + return NULL; + + brcm_tag = skb->data - offset; + + /* The opcode should never be different than 0b000 */ + if (unlikely((brcm_tag[0] >> BRCM_OPCODE_SHIFT) & BRCM_OPCODE_MASK)) + return NULL; + + /* We should never see a reserved reason code without knowing how to + * handle it + */ + if (unlikely(brcm_tag[2] & BRCM_EG_RC_RSVD)) + return NULL; + + /* Locate which port this is coming from */ + source_port = brcm_tag[3] & BRCM_EG_PID_MASK; + + skb->dev = dsa_master_find_slave(dev, 0, source_port); + if (!skb->dev) + return NULL; + + /* Remove Broadcom tag and update checksum */ + skb_pull_rcsum(skb, BRCM_TAG_LEN); + + dsa_default_offload_fwd_mark(skb); + + return skb; +} +#endif + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM) +static struct sk_buff *brcm_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + /* Build the tag after the MAC Source Address */ + return brcm_tag_xmit_ll(skb, dev, 2 * ETH_ALEN); +} + + +static struct sk_buff *brcm_tag_rcv(struct sk_buff *skb, struct net_device *dev) +{ + struct sk_buff *nskb; + + /* skb->data points to the EtherType, the tag is right before it */ + nskb = brcm_tag_rcv_ll(skb, dev, 2); + if (!nskb) + return nskb; + + dsa_strip_etype_header(skb, BRCM_TAG_LEN); + + return nskb; +} + +static const struct dsa_device_ops brcm_netdev_ops = { + .name = "brcm", + .proto = DSA_TAG_PROTO_BRCM, + .xmit = brcm_tag_xmit, + .rcv = brcm_tag_rcv, + .needed_headroom = BRCM_TAG_LEN, +}; + +DSA_TAG_DRIVER(brcm_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM); +#endif + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_LEGACY) +static struct sk_buff *brcm_leg_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *brcm_tag; + + /* The Ethernet switch we are interfaced with needs packets to be at + * least 64 bytes (including FCS) otherwise they will be discarded when + * they enter the switch port logic. When Broadcom tags are enabled, we + * need to make sure that packets are at least 70 bytes + * (including FCS and tag) because the length verification is done after + * the Broadcom tag is stripped off the ingress packet. 
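/* Editor's note: a sketch of the padding rule explained above.  The switch
 * checks the 64-byte Ethernet minimum after stripping its tag, so the host
 * pads to 60 bytes (ETH_ZLEN, i.e. 64 minus FCS) plus the tag length before
 * transmitting.
 */
#include <stdio.h>

#define ETH_ZLEN_NO_FCS 60

static unsigned int min_xmit_len(unsigned int tag_len)
{
	return ETH_ZLEN_NO_FCS + tag_len;
}

int main(void)
{
	/* 64 and 66 bytes before FCS, i.e. 68 and 70 including FCS */
	printf("brcm: %u bytes, brcm-legacy: %u bytes\n",
	       min_xmit_len(4), min_xmit_len(6));
	return 0;
}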
+ * + * Let dsa_slave_xmit() free the SKB + */ + if (__skb_put_padto(skb, ETH_ZLEN + BRCM_LEG_TAG_LEN, false)) + return NULL; + + skb_push(skb, BRCM_LEG_TAG_LEN); + + dsa_alloc_etype_header(skb, BRCM_LEG_TAG_LEN); + + brcm_tag = skb->data + 2 * ETH_ALEN; + + /* Broadcom tag type */ + brcm_tag[0] = BRCM_LEG_TYPE_HI; + brcm_tag[1] = BRCM_LEG_TYPE_LO; + + /* Broadcom tag value */ + brcm_tag[2] = BRCM_LEG_EGRESS; + brcm_tag[3] = 0; + brcm_tag[4] = 0; + brcm_tag[5] = dp->index & BRCM_LEG_PORT_ID; + + return skb; +} + +static struct sk_buff *brcm_leg_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + int len = BRCM_LEG_TAG_LEN; + int source_port; + u8 *brcm_tag; + + if (unlikely(!pskb_may_pull(skb, BRCM_LEG_PORT_ID))) + return NULL; + + brcm_tag = dsa_etype_header_pos_rx(skb); + + source_port = brcm_tag[5] & BRCM_LEG_PORT_ID; + + skb->dev = dsa_master_find_slave(dev, 0, source_port); + if (!skb->dev) + return NULL; + + /* VLAN tag is added by BCM63xx internal switch */ + if (netdev_uses_dsa(skb->dev)) + len += VLAN_HLEN; + + /* Remove Broadcom tag and update checksum */ + skb_pull_rcsum(skb, len); + + dsa_default_offload_fwd_mark(skb); + + dsa_strip_etype_header(skb, len); + + return skb; +} + +static const struct dsa_device_ops brcm_legacy_netdev_ops = { + .name = "brcm-legacy", + .proto = DSA_TAG_PROTO_BRCM_LEGACY, + .xmit = brcm_leg_tag_xmit, + .rcv = brcm_leg_tag_rcv, + .needed_headroom = BRCM_LEG_TAG_LEN, +}; + +DSA_TAG_DRIVER(brcm_legacy_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_LEGACY); +#endif /* CONFIG_NET_DSA_TAG_BRCM_LEGACY */ + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_PREPEND) +static struct sk_buff *brcm_tag_xmit_prepend(struct sk_buff *skb, + struct net_device *dev) +{ + /* tag is prepended to the packet */ + return brcm_tag_xmit_ll(skb, dev, 0); +} + +static struct sk_buff *brcm_tag_rcv_prepend(struct sk_buff *skb, + struct net_device *dev) +{ + /* tag is prepended to the packet */ + return brcm_tag_rcv_ll(skb, dev, ETH_HLEN); +} + +static const struct dsa_device_ops brcm_prepend_netdev_ops = { + .name = "brcm-prepend", + .proto = DSA_TAG_PROTO_BRCM_PREPEND, + .xmit = brcm_tag_xmit_prepend, + .rcv = brcm_tag_rcv_prepend, + .needed_headroom = BRCM_TAG_LEN, +}; + +DSA_TAG_DRIVER(brcm_prepend_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_BRCM_PREPEND); +#endif + +static struct dsa_tag_driver *dsa_tag_driver_array[] = { +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM) + &DSA_TAG_DRIVER_NAME(brcm_netdev_ops), +#endif +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_LEGACY) + &DSA_TAG_DRIVER_NAME(brcm_legacy_netdev_ops), +#endif +#if IS_ENABLED(CONFIG_NET_DSA_TAG_BRCM_PREPEND) + &DSA_TAG_DRIVER_NAME(brcm_prepend_netdev_ops), +#endif +}; + +module_dsa_tag_drivers(dsa_tag_driver_array); + +MODULE_LICENSE("GPL"); diff --git a/net/dsa/tag_dsa.c b/net/dsa/tag_dsa.c new file mode 100644 index 000000000..e4b6e3f2a --- /dev/null +++ b/net/dsa/tag_dsa.c @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * Regular and Ethertype DSA tagging + * Copyright (c) 2008-2009 Marvell Semiconductor + * + * Regular DSA + * ----------- + + * For untagged (in 802.1Q terms) packets, the switch will splice in + * the tag between the SA and the ethertype of the original + * packet. Tagged frames will instead have their outermost .1Q tag + * converted to a DSA tag. It expects the same layout when receiving + * packets from the CPU. + * + * Example: + * + * .----.----.----.--------- + * Pu: | DA | SA | ET | Payload ... 
+ * '----'----'----'--------- + * 6 6 2 N + * .----.----.--------.-----.----.--------- + * Pt: | DA | SA | 0x8100 | TCI | ET | Payload ... + * '----'----'--------'-----'----'--------- + * 6 6 2 2 2 N + * .----.----.-----.----.--------- + * Pd: | DA | SA | DSA | ET | Payload ... + * '----'----'-----'----'--------- + * 6 6 4 2 N + * + * No matter if a packet is received untagged (Pu) or tagged (Pt), + * they will both have the same layout (Pd) when they are sent to the + * CPU. This is done by ignoring 802.3, replacing the ethertype field + * with more metadata, among which is a bit to signal if the original + * packet was tagged or not. + * + * Ethertype DSA + * ------------- + * Uses the exact same tag format as regular DSA, but also includes a + * proper ethertype field (which the mv88e6xxx driver sets to + * ETH_P_EDSA/0xdada) followed by two zero bytes: + * + * .----.----.--------.--------.-----.----.--------- + * | DA | SA | 0xdada | 0x0000 | DSA | ET | Payload ... + * '----'----'--------'--------'-----'----'--------- + * 6 6 2 2 4 2 N + */ + +#include +#include +#include +#include + +#include "dsa_priv.h" + +#define DSA_HLEN 4 + +/** + * enum dsa_cmd - DSA Command + * @DSA_CMD_TO_CPU: Set on packets that were trapped or mirrored to + * the CPU port. This is needed to implement control protocols, + * e.g. STP and LLDP, that must not allow those control packets to + * be switched according to the normal rules. + * @DSA_CMD_FROM_CPU: Used by the CPU to send a packet to a specific + * port, ignoring all the barriers that the switch normally + * enforces (VLANs, STP port states etc.). No source address + * learning takes place. "sudo send packet" + * @DSA_CMD_TO_SNIFFER: Set on the copies of packets that matched some + * user configured ingress or egress monitor criteria. These are + * forwarded by the switch tree to the user configured ingress or + * egress monitor port, which can be set to the CPU port or a + * regular port. If the destination is a regular port, the tag + * will be removed before egressing the port. If the destination + * is the CPU port, the tag will not be removed. + * @DSA_CMD_FORWARD: This tag is used on all bulk traffic passing + * through the switch tree, including the flows that are directed + * towards the CPU. Its device/port tuple encodes the original + * source port on which the packet ingressed. It can also be used + * on transmit by the CPU to defer the forwarding decision to the + * hardware, based on the current config of PVT/VTU/ATU + * etc. Source address learning takes places if enabled on the + * receiving DSA/CPU port. + */ +enum dsa_cmd { + DSA_CMD_TO_CPU = 0, + DSA_CMD_FROM_CPU = 1, + DSA_CMD_TO_SNIFFER = 2, + DSA_CMD_FORWARD = 3 +}; + +/** + * enum dsa_code - TO_CPU Code + * + * @DSA_CODE_MGMT_TRAP: DA was classified as a management + * address. Typical examples include STP BPDUs and LLDP. + * @DSA_CODE_FRAME2REG: Response to a "remote management" request. + * @DSA_CODE_IGMP_MLD_TRAP: IGMP/MLD signaling. + * @DSA_CODE_POLICY_TRAP: Frame matched some policy configuration on + * the device. Typical examples are matching on DA/SA/VID and DHCP + * snooping. + * @DSA_CODE_ARP_MIRROR: The name says it all really. + * @DSA_CODE_POLICY_MIRROR: Same as @DSA_CODE_POLICY_TRAP, but the + * particular policy was set to trigger a mirror instead of a + * trap. + * @DSA_CODE_RESERVED_6: Unused on all devices up to at least 6393X. + * @DSA_CODE_RESERVED_7: Unused on all devices up to at least 6393X. 
+ * + * A 3-bit code is used to relay why a particular frame was sent to + * the CPU. We only use this to determine if the packet was mirrored + * or trapped, i.e. whether the packet has been forwarded by hardware + * or not. + * + * This is the superset of all possible codes. Any particular device + * may only implement a subset. + */ +enum dsa_code { + DSA_CODE_MGMT_TRAP = 0, + DSA_CODE_FRAME2REG = 1, + DSA_CODE_IGMP_MLD_TRAP = 2, + DSA_CODE_POLICY_TRAP = 3, + DSA_CODE_ARP_MIRROR = 4, + DSA_CODE_POLICY_MIRROR = 5, + DSA_CODE_RESERVED_6 = 6, + DSA_CODE_RESERVED_7 = 7 +}; + +static struct sk_buff *dsa_xmit_ll(struct sk_buff *skb, struct net_device *dev, + u8 extra) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct net_device *br_dev; + u8 tag_dev, tag_port; + enum dsa_cmd cmd; + u8 *dsa_header; + + if (skb->offload_fwd_mark) { + unsigned int bridge_num = dsa_port_bridge_num_get(dp); + struct dsa_switch_tree *dst = dp->ds->dst; + + cmd = DSA_CMD_FORWARD; + + /* When offloading forwarding for a bridge, inject FORWARD + * packets on behalf of a virtual switch device with an index + * past the physical switches. + */ + tag_dev = dst->last_switch + bridge_num; + tag_port = 0; + } else { + cmd = DSA_CMD_FROM_CPU; + tag_dev = dp->ds->index; + tag_port = dp->index; + } + + br_dev = dsa_port_bridge_dev_get(dp); + + /* If frame is already 802.1Q tagged, we can convert it to a DSA + * tag (avoiding a memmove), but only if the port is standalone + * (in which case we always send FROM_CPU) or if the port's + * bridge has VLAN filtering enabled (in which case the CPU port + * will be a member of the VLAN). + */ + if (skb->protocol == htons(ETH_P_8021Q) && + (!br_dev || br_vlan_enabled(br_dev))) { + if (extra) { + skb_push(skb, extra); + dsa_alloc_etype_header(skb, extra); + } + + /* Construct tagged DSA tag from 802.1Q tag. */ + dsa_header = dsa_etype_header_pos_tx(skb) + extra; + dsa_header[0] = (cmd << 6) | 0x20 | tag_dev; + dsa_header[1] = tag_port << 3; + + /* Move CFI field from byte 2 to byte 1. */ + if (dsa_header[2] & 0x10) { + dsa_header[1] |= 0x01; + dsa_header[2] &= ~0x10; + } + } else { + u16 vid; + + vid = br_dev ? MV88E6XXX_VID_BRIDGED : MV88E6XXX_VID_STANDALONE; + + skb_push(skb, DSA_HLEN + extra); + dsa_alloc_etype_header(skb, DSA_HLEN + extra); + + /* Construct DSA header from untagged frame. */ + dsa_header = dsa_etype_header_pos_tx(skb) + extra; + + dsa_header[0] = (cmd << 6) | tag_dev; + dsa_header[1] = tag_port << 3; + dsa_header[2] = vid >> 8; + dsa_header[3] = vid & 0xff; + } + + return skb; +} + +static struct sk_buff *dsa_rcv_ll(struct sk_buff *skb, struct net_device *dev, + u8 extra) +{ + bool trap = false, trunk = false; + int source_device, source_port; + enum dsa_code code; + enum dsa_cmd cmd; + u8 *dsa_header; + + /* The ethertype field is part of the DSA header. */ + dsa_header = dsa_etype_header_pos_rx(skb); + + cmd = dsa_header[0] >> 6; + switch (cmd) { + case DSA_CMD_FORWARD: + trunk = !!(dsa_header[1] & 4); + break; + + case DSA_CMD_TO_CPU: + code = (dsa_header[1] & 0x6) | ((dsa_header[2] >> 4) & 1); + + switch (code) { + case DSA_CODE_FRAME2REG: + /* Remote management is not implemented yet, + * drop. + */ + return NULL; + case DSA_CODE_ARP_MIRROR: + case DSA_CODE_POLICY_MIRROR: + /* Mark mirrored packets to notify any upper + * device (like a bridge) that forwarding has + * already been done by hardware. 
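/* Editor's note: a sketch of how dsa_rcv_ll() below reassembles the 3-bit
 * TO_CPU code from the DSA header (bits 2:1 of byte 1 plus bit 4 of byte 2)
 * and decides whether the frame was trapped (not yet forwarded) or merely
 * mirrored.  Enum values match the table above; the helpers are illustrative.
 */
#include <stdbool.h>
#include <stdint.h>

enum { CODE_MGMT_TRAP = 0, CODE_FRAME2REG = 1, CODE_IGMP_MLD_TRAP = 2,
       CODE_POLICY_TRAP = 3, CODE_ARP_MIRROR = 4, CODE_POLICY_MIRROR = 5 };

static int to_cpu_code(const uint8_t *dsa_header)
{
	return (dsa_header[1] & 0x6) | ((dsa_header[2] >> 4) & 1);
}

static bool code_is_trap(int code)
{
	return code == CODE_MGMT_TRAP || code == CODE_IGMP_MLD_TRAP ||
	       code == CODE_POLICY_TRAP;
}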
+ */ + break; + case DSA_CODE_MGMT_TRAP: + case DSA_CODE_IGMP_MLD_TRAP: + case DSA_CODE_POLICY_TRAP: + /* Traps have, by definition, not been + * forwarded by hardware, so don't mark them. + */ + trap = true; + break; + default: + /* Reserved code, this could be anything. Drop + * seems like the safest option. + */ + return NULL; + } + + break; + + default: + return NULL; + } + + source_device = dsa_header[0] & 0x1f; + source_port = (dsa_header[1] >> 3) & 0x1f; + + if (trunk) { + struct dsa_port *cpu_dp = dev->dsa_ptr; + struct dsa_lag *lag; + + /* The exact source port is not available in the tag, + * so we inject the frame directly on the upper + * team/bond. + */ + lag = dsa_lag_by_id(cpu_dp->dst, source_port + 1); + skb->dev = lag ? lag->dev : NULL; + } else { + skb->dev = dsa_master_find_slave(dev, source_device, + source_port); + } + + if (!skb->dev) + return NULL; + + /* When using LAG offload, skb->dev is not a DSA slave interface, + * so we cannot call dsa_default_offload_fwd_mark and we need to + * special-case it. + */ + if (trunk) + skb->offload_fwd_mark = true; + else if (!trap) + dsa_default_offload_fwd_mark(skb); + + /* If the 'tagged' bit is set; convert the DSA tag to a 802.1Q + * tag, and delete the ethertype (extra) if applicable. If the + * 'tagged' bit is cleared; delete the DSA tag, and ethertype + * if applicable. + */ + if (dsa_header[0] & 0x20) { + u8 new_header[4]; + + /* Insert 802.1Q ethertype and copy the VLAN-related + * fields, but clear the bit that will hold CFI (since + * DSA uses that bit location for another purpose). + */ + new_header[0] = (ETH_P_8021Q >> 8) & 0xff; + new_header[1] = ETH_P_8021Q & 0xff; + new_header[2] = dsa_header[2] & ~0x10; + new_header[3] = dsa_header[3]; + + /* Move CFI bit from its place in the DSA header to + * its 802.1Q-designated place. + */ + if (dsa_header[1] & 0x01) + new_header[2] |= 0x10; + + /* Update packet checksum if skb is CHECKSUM_COMPLETE. 
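/* Editor's note: a userspace sketch of the CHECKSUM_COMPLETE fixup used when
 * the DSA header bytes are rewritten into an 802.1Q TCI: the running 16-bit
 * one's-complement sum only needs the old word subtracted and the new word
 * added.  Not the kernel csum API, just the underlying arithmetic.
 */
#include <stdint.h>

static uint32_t fold(uint32_t sum)
{
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);
	return sum;
}

static uint16_t csum_replace_word(uint16_t sum, uint16_t old_val,
				  uint16_t new_val)
{
	uint32_t s = sum;

	s += new_val;			/* add the new bytes ...         */
	s += (uint16_t)~old_val;	/* ... and subtract the old ones */
	return (uint16_t)fold(s);
}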
*/ + if (skb->ip_summed == CHECKSUM_COMPLETE) { + __wsum c = skb->csum; + c = csum_add(c, csum_partial(new_header + 2, 2, 0)); + c = csum_sub(c, csum_partial(dsa_header + 2, 2, 0)); + skb->csum = c; + } + + memcpy(dsa_header, new_header, DSA_HLEN); + + if (extra) + dsa_strip_etype_header(skb, extra); + } else { + skb_pull_rcsum(skb, DSA_HLEN); + dsa_strip_etype_header(skb, DSA_HLEN + extra); + } + + return skb; +} + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_DSA) + +static struct sk_buff *dsa_xmit(struct sk_buff *skb, struct net_device *dev) +{ + return dsa_xmit_ll(skb, dev, 0); +} + +static struct sk_buff *dsa_rcv(struct sk_buff *skb, struct net_device *dev) +{ + if (unlikely(!pskb_may_pull(skb, DSA_HLEN))) + return NULL; + + return dsa_rcv_ll(skb, dev, 0); +} + +static const struct dsa_device_ops dsa_netdev_ops = { + .name = "dsa", + .proto = DSA_TAG_PROTO_DSA, + .xmit = dsa_xmit, + .rcv = dsa_rcv, + .needed_headroom = DSA_HLEN, +}; + +DSA_TAG_DRIVER(dsa_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_DSA); +#endif /* CONFIG_NET_DSA_TAG_DSA */ + +#if IS_ENABLED(CONFIG_NET_DSA_TAG_EDSA) + +#define EDSA_HLEN 8 + +static struct sk_buff *edsa_xmit(struct sk_buff *skb, struct net_device *dev) +{ + u8 *edsa_header; + + skb = dsa_xmit_ll(skb, dev, EDSA_HLEN - DSA_HLEN); + if (!skb) + return NULL; + + edsa_header = dsa_etype_header_pos_tx(skb); + edsa_header[0] = (ETH_P_EDSA >> 8) & 0xff; + edsa_header[1] = ETH_P_EDSA & 0xff; + edsa_header[2] = 0x00; + edsa_header[3] = 0x00; + return skb; +} + +static struct sk_buff *edsa_rcv(struct sk_buff *skb, struct net_device *dev) +{ + if (unlikely(!pskb_may_pull(skb, EDSA_HLEN))) + return NULL; + + skb_pull_rcsum(skb, EDSA_HLEN - DSA_HLEN); + + return dsa_rcv_ll(skb, dev, EDSA_HLEN - DSA_HLEN); +} + +static const struct dsa_device_ops edsa_netdev_ops = { + .name = "edsa", + .proto = DSA_TAG_PROTO_EDSA, + .xmit = edsa_xmit, + .rcv = edsa_rcv, + .needed_headroom = EDSA_HLEN, +}; + +DSA_TAG_DRIVER(edsa_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_EDSA); +#endif /* CONFIG_NET_DSA_TAG_EDSA */ + +static struct dsa_tag_driver *dsa_tag_drivers[] = { +#if IS_ENABLED(CONFIG_NET_DSA_TAG_DSA) + &DSA_TAG_DRIVER_NAME(dsa_netdev_ops), +#endif +#if IS_ENABLED(CONFIG_NET_DSA_TAG_EDSA) + &DSA_TAG_DRIVER_NAME(edsa_netdev_ops), +#endif +}; + +module_dsa_tag_drivers(dsa_tag_drivers); + +MODULE_LICENSE("GPL"); diff --git a/net/dsa/tag_gswip.c b/net/dsa/tag_gswip.c new file mode 100644 index 000000000..df7140984 --- /dev/null +++ b/net/dsa/tag_gswip.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Intel / Lantiq GSWIP V2.0 PMAC tag support + * + * Copyright (C) 2017 - 2018 Hauke Mehrtens + */ + +#include +#include +#include +#include + +#include "dsa_priv.h" + +#define GSWIP_TX_HEADER_LEN 4 + +/* special tag in TX path header */ +/* Byte 0 */ +#define GSWIP_TX_SLPID_SHIFT 0 /* source port ID */ +#define GSWIP_TX_SLPID_CPU 2 +#define GSWIP_TX_SLPID_APP1 3 +#define GSWIP_TX_SLPID_APP2 4 +#define GSWIP_TX_SLPID_APP3 5 +#define GSWIP_TX_SLPID_APP4 6 +#define GSWIP_TX_SLPID_APP5 7 + +/* Byte 1 */ +#define GSWIP_TX_CRCGEN_DIS BIT(7) +#define GSWIP_TX_DPID_SHIFT 0 /* destination group ID */ +#define GSWIP_TX_DPID_ELAN 0 +#define GSWIP_TX_DPID_EWAN 1 +#define GSWIP_TX_DPID_CPU 2 +#define GSWIP_TX_DPID_APP1 3 +#define GSWIP_TX_DPID_APP2 4 +#define GSWIP_TX_DPID_APP3 5 +#define GSWIP_TX_DPID_APP4 6 +#define GSWIP_TX_DPID_APP5 7 + +/* Byte 2 */ +#define GSWIP_TX_PORT_MAP_EN BIT(7) +#define GSWIP_TX_PORT_MAP_SEL BIT(6) +#define GSWIP_TX_LRN_DIS BIT(5) +#define 
GSWIP_TX_CLASS_EN BIT(4) +#define GSWIP_TX_CLASS_SHIFT 0 +#define GSWIP_TX_CLASS_MASK GENMASK(3, 0) + +/* Byte 3 */ +#define GSWIP_TX_DPID_EN BIT(0) +#define GSWIP_TX_PORT_MAP_SHIFT 1 +#define GSWIP_TX_PORT_MAP_MASK GENMASK(6, 1) + +#define GSWIP_RX_HEADER_LEN 8 + +/* special tag in RX path header */ +/* Byte 7 */ +#define GSWIP_RX_SPPID_SHIFT 4 +#define GSWIP_RX_SPPID_MASK GENMASK(6, 4) + +static struct sk_buff *gswip_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *gswip_tag; + + skb_push(skb, GSWIP_TX_HEADER_LEN); + + gswip_tag = skb->data; + gswip_tag[0] = GSWIP_TX_SLPID_CPU; + gswip_tag[1] = GSWIP_TX_DPID_ELAN; + gswip_tag[2] = GSWIP_TX_PORT_MAP_EN | GSWIP_TX_PORT_MAP_SEL; + gswip_tag[3] = BIT(dp->index + GSWIP_TX_PORT_MAP_SHIFT) & GSWIP_TX_PORT_MAP_MASK; + gswip_tag[3] |= GSWIP_TX_DPID_EN; + + return skb; +} + +static struct sk_buff *gswip_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + int port; + u8 *gswip_tag; + + if (unlikely(!pskb_may_pull(skb, GSWIP_RX_HEADER_LEN))) + return NULL; + + gswip_tag = skb->data - ETH_HLEN; + + /* Get source port information */ + port = (gswip_tag[7] & GSWIP_RX_SPPID_MASK) >> GSWIP_RX_SPPID_SHIFT; + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) + return NULL; + + /* remove GSWIP tag */ + skb_pull_rcsum(skb, GSWIP_RX_HEADER_LEN); + + return skb; +} + +static const struct dsa_device_ops gswip_netdev_ops = { + .name = "gswip", + .proto = DSA_TAG_PROTO_GSWIP, + .xmit = gswip_tag_xmit, + .rcv = gswip_tag_rcv, + .needed_headroom = GSWIP_RX_HEADER_LEN, +}; + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_GSWIP); + +module_dsa_tag_driver(gswip_netdev_ops); diff --git a/net/dsa/tag_hellcreek.c b/net/dsa/tag_hellcreek.c new file mode 100644 index 000000000..53a206d11 --- /dev/null +++ b/net/dsa/tag_hellcreek.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: (GPL-2.0 OR MIT) +/* + * net/dsa/tag_hellcreek.c - Hirschmann Hellcreek switch tag format handling + * + * Copyright (C) 2019,2020 Linutronix GmbH + * Author Kurt Kanzenbach + * + * Based on tag_ksz.c. + */ + +#include +#include + +#include "dsa_priv.h" + +#define HELLCREEK_TAG_LEN 1 + +static struct sk_buff *hellcreek_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *tag; + + /* Calculate checksums (if required) before adding the trailer tag to + * avoid including it in calculations. That would lead to wrong + * checksums after the switch strips the tag. 
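/* Editor's note: a sketch of a one-byte tail tag in the style of
 * hellcreek_xmit()/rcv() above.  The checksum must already be final before
 * the tag is appended (the switch strips the tag again on egress), and on
 * receive the source port is read from the last byte and the frame trimmed.
 * "frame" and "len" stand in for the skb.
 */
#include <stdint.h>
#include <stddef.h>

static size_t tail_tag_xmit(uint8_t *frame, size_t len, unsigned int port)
{
	frame[len] = (uint8_t)(1u << port);	/* one destination bit per port */
	return len + 1;
}

static size_t tail_tag_rcv(const uint8_t *frame, size_t len, unsigned int *port)
{
	*port = frame[len - 1] & 0x03;		/* hellcreek has at most 4 ports */
	return len - 1;				/* trim the tag                  */
}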
+ */ + if (skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_help(skb)) + return NULL; + + /* Tag encoding */ + tag = skb_put(skb, HELLCREEK_TAG_LEN); + *tag = BIT(dp->index); + + return skb; +} + +static struct sk_buff *hellcreek_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + /* Tag decoding */ + u8 *tag = skb_tail_pointer(skb) - HELLCREEK_TAG_LEN; + unsigned int port = tag[0] & 0x03; + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) { + netdev_warn_once(dev, "Failed to get source port: %d\n", port); + return NULL; + } + + if (pskb_trim_rcsum(skb, skb->len - HELLCREEK_TAG_LEN)) + return NULL; + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops hellcreek_netdev_ops = { + .name = "hellcreek", + .proto = DSA_TAG_PROTO_HELLCREEK, + .xmit = hellcreek_xmit, + .rcv = hellcreek_rcv, + .needed_tailroom = HELLCREEK_TAG_LEN, +}; + +MODULE_LICENSE("Dual MIT/GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_HELLCREEK); + +module_dsa_tag_driver(hellcreek_netdev_ops); diff --git a/net/dsa/tag_ksz.c b/net/dsa/tag_ksz.c new file mode 100644 index 000000000..429250298 --- /dev/null +++ b/net/dsa/tag_ksz.c @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * net/dsa/tag_ksz.c - Microchip KSZ Switch tag format handling + * Copyright (c) 2017 Microchip Technology + */ + +#include +#include +#include +#include "dsa_priv.h" + +/* Typically only one byte is used for tail tag. */ +#define KSZ_EGRESS_TAG_LEN 1 +#define KSZ_INGRESS_TAG_LEN 1 + +static struct sk_buff *ksz_common_rcv(struct sk_buff *skb, + struct net_device *dev, + unsigned int port, unsigned int len) +{ + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) + return NULL; + + if (pskb_trim_rcsum(skb, skb->len - len)) + return NULL; + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +/* + * For Ingress (Host -> KSZ8795), 1 byte is added before FCS. + * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag : each bit represents port (eg, 0x01=port1, 0x02=port2, 0x10=port5) + * + * For Egress (KSZ8795 -> Host), 1 byte is added before FCS. 
+ * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag0(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag0 : zero-based value represents port + * (eg, 0x00=port1, 0x02=port3, 0x06=port7) + */ + +#define KSZ8795_TAIL_TAG_OVERRIDE BIT(6) +#define KSZ8795_TAIL_TAG_LOOKUP BIT(7) + +static struct sk_buff *ksz8795_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *tag; + u8 *addr; + + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + /* Tag encoding */ + tag = skb_put(skb, KSZ_INGRESS_TAG_LEN); + addr = skb_mac_header(skb); + + *tag = 1 << dp->index; + if (is_link_local_ether_addr(addr)) + *tag |= KSZ8795_TAIL_TAG_OVERRIDE; + + return skb; +} + +static struct sk_buff *ksz8795_rcv(struct sk_buff *skb, struct net_device *dev) +{ + u8 *tag = skb_tail_pointer(skb) - KSZ_EGRESS_TAG_LEN; + + return ksz_common_rcv(skb, dev, tag[0] & 7, KSZ_EGRESS_TAG_LEN); +} + +static const struct dsa_device_ops ksz8795_netdev_ops = { + .name = "ksz8795", + .proto = DSA_TAG_PROTO_KSZ8795, + .xmit = ksz8795_xmit, + .rcv = ksz8795_rcv, + .needed_tailroom = KSZ_INGRESS_TAG_LEN, +}; + +DSA_TAG_DRIVER(ksz8795_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ8795); + +/* + * For Ingress (Host -> KSZ9477), 2 bytes are added before FCS. + * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag0(1byte)|tag1(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag0 : Prioritization (not used now) + * tag1 : each bit represents port (eg, 0x01=port1, 0x02=port2, 0x10=port5) + * + * For Egress (KSZ9477 -> Host), 1 byte is added before FCS. 
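/* A minimal sketch of the two-byte KSZ9477 ingress tail tag and of the
 * egress length handling described above; the port index is hypothetical
 * and the big-endian store is done by hand instead of cpu_to_be16().
 */
#include <stdbool.h>
#include <stdint.h>

static void ksz9477_ingress_tag(uint8_t *tag, unsigned int port, bool link_local)
{
        uint16_t val = 1u << port;              /* tag1: bit per port */

        if (link_local)
                val |= 1u << 9;                 /* KSZ9477_TAIL_TAG_OVERRIDE */

        tag[0] = val >> 8;                      /* stored big endian */
        tag[1] = val & 0xff;
}

static unsigned int ksz9477_egress_len(uint8_t tag0)
{
        /* 1 tag byte, plus a 4-byte PTP timestamp when bit 7 is set */
        return 1 + ((tag0 & 0x80) ? 4 : 0);
}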
+ * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag0(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag0 : zero-based value represents port + * (eg, 0x00=port1, 0x02=port3, 0x06=port7) + */ + +#define KSZ9477_INGRESS_TAG_LEN 2 +#define KSZ9477_PTP_TAG_LEN 4 +#define KSZ9477_PTP_TAG_INDICATION 0x80 + +#define KSZ9477_TAIL_TAG_OVERRIDE BIT(9) +#define KSZ9477_TAIL_TAG_LOOKUP BIT(10) + +static struct sk_buff *ksz9477_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __be16 *tag; + u8 *addr; + u16 val; + + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + /* Tag encoding */ + tag = skb_put(skb, KSZ9477_INGRESS_TAG_LEN); + addr = skb_mac_header(skb); + + val = BIT(dp->index); + + if (is_link_local_ether_addr(addr)) + val |= KSZ9477_TAIL_TAG_OVERRIDE; + + *tag = cpu_to_be16(val); + + return skb; +} + +static struct sk_buff *ksz9477_rcv(struct sk_buff *skb, struct net_device *dev) +{ + /* Tag decoding */ + u8 *tag = skb_tail_pointer(skb) - KSZ_EGRESS_TAG_LEN; + unsigned int port = tag[0] & 7; + unsigned int len = KSZ_EGRESS_TAG_LEN; + + /* Extra 4-bytes PTP timestamp */ + if (tag[0] & KSZ9477_PTP_TAG_INDICATION) + len += KSZ9477_PTP_TAG_LEN; + + return ksz_common_rcv(skb, dev, port, len); +} + +static const struct dsa_device_ops ksz9477_netdev_ops = { + .name = "ksz9477", + .proto = DSA_TAG_PROTO_KSZ9477, + .xmit = ksz9477_xmit, + .rcv = ksz9477_rcv, + .needed_tailroom = KSZ9477_INGRESS_TAG_LEN, +}; + +DSA_TAG_DRIVER(ksz9477_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9477); + +#define KSZ9893_TAIL_TAG_OVERRIDE BIT(5) +#define KSZ9893_TAIL_TAG_LOOKUP BIT(6) + +static struct sk_buff *ksz9893_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *addr; + u8 *tag; + + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + /* Tag encoding */ + tag = skb_put(skb, KSZ_INGRESS_TAG_LEN); + addr = skb_mac_header(skb); + + *tag = BIT(dp->index); + + if (is_link_local_ether_addr(addr)) + *tag |= KSZ9893_TAIL_TAG_OVERRIDE; + + return skb; +} + +static const struct dsa_device_ops ksz9893_netdev_ops = { + .name = "ksz9893", + .proto = DSA_TAG_PROTO_KSZ9893, + .xmit = ksz9893_xmit, + .rcv = ksz9477_rcv, + .needed_tailroom = KSZ_INGRESS_TAG_LEN, +}; + +DSA_TAG_DRIVER(ksz9893_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_KSZ9893); + +/* For xmit, 2 bytes are added before FCS. + * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag0(1byte)|tag1(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag0 : represents tag override, lookup and valid + * tag1 : each bit represents port (eg, 0x01=port1, 0x02=port2, 0x80=port8) + * + * For rcv, 1 byte is added before FCS. 
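/* A minimal sketch of the two-byte LAN937x tail tag built by lan937x_xmit()
 * below; port index 3 is hypothetical. The VALID bit is always set by the
 * CPU, as noted in the real code.
 */
#include <stdbool.h>
#include <stdint.h>

static uint16_t lan937x_tail_tag(unsigned int port, bool link_local)
{
        uint16_t val = 1u << port;              /* tag1: bit per port */

        if (link_local)
                val |= 1u << 11;                /* LAN937X_TAIL_TAG_BLOCKING_OVERRIDE */

        val |= 1u << 13;                        /* LAN937X_TAIL_TAG_VALID */

        return val;                             /* e.g. port 3 -> 0x2008, stored big endian */
}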
+ * --------------------------------------------------------------------------- + * DA(6bytes)|SA(6bytes)|....|Data(nbytes)|tag0(1byte)|FCS(4bytes) + * --------------------------------------------------------------------------- + * tag0 : zero-based value represents port + * (eg, 0x00=port1, 0x02=port3, 0x07=port8) + */ +#define LAN937X_EGRESS_TAG_LEN 2 + +#define LAN937X_TAIL_TAG_BLOCKING_OVERRIDE BIT(11) +#define LAN937X_TAIL_TAG_LOOKUP BIT(12) +#define LAN937X_TAIL_TAG_VALID BIT(13) +#define LAN937X_TAIL_TAG_PORT_MASK 7 + +static struct sk_buff *lan937x_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + const struct ethhdr *hdr = eth_hdr(skb); + __be16 *tag; + u16 val; + + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + tag = skb_put(skb, LAN937X_EGRESS_TAG_LEN); + + val = BIT(dp->index); + + if (is_link_local_ether_addr(hdr->h_dest)) + val |= LAN937X_TAIL_TAG_BLOCKING_OVERRIDE; + + /* Tail tag valid bit - This bit should always be set by the CPU */ + val |= LAN937X_TAIL_TAG_VALID; + + put_unaligned_be16(val, tag); + + return skb; +} + +static const struct dsa_device_ops lan937x_netdev_ops = { + .name = "lan937x", + .proto = DSA_TAG_PROTO_LAN937X, + .xmit = lan937x_xmit, + .rcv = ksz9477_rcv, + .needed_tailroom = LAN937X_EGRESS_TAG_LEN, +}; + +DSA_TAG_DRIVER(lan937x_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN937X); + +static struct dsa_tag_driver *dsa_tag_driver_array[] = { + &DSA_TAG_DRIVER_NAME(ksz8795_netdev_ops), + &DSA_TAG_DRIVER_NAME(ksz9477_netdev_ops), + &DSA_TAG_DRIVER_NAME(ksz9893_netdev_ops), + &DSA_TAG_DRIVER_NAME(lan937x_netdev_ops), +}; + +module_dsa_tag_drivers(dsa_tag_driver_array); + +MODULE_LICENSE("GPL"); diff --git a/net/dsa/tag_lan9303.c b/net/dsa/tag_lan9303.c new file mode 100644 index 000000000..98d7d7120 --- /dev/null +++ b/net/dsa/tag_lan9303.c @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (C) 2017 Pengutronix, Juergen Borleis + */ +#include +#include +#include +#include + +#include "dsa_priv.h" + +/* To define the outgoing port and to discover the incoming port a regular + * VLAN tag is used by the LAN9303. But its VID meaning is 'special': + * + * Dest MAC Src MAC TAG Type + * ...| 1 2 3 4 5 6 | 1 2 3 4 5 6 | 1 2 3 4 | 1 2 |... + * |<------->| + * TAG: + * |<------------->| + * | 1 2 | 3 4 | + * TPID VID + * 0x8100 + * + * VID bit 3 indicates a request for an ALR lookup. + * + * If VID bit 3 is zero, then bits 0 and 1 specify the destination port + * (0, 1, 2) or broadcast (3) or the source port (1, 2). + * + * VID bit 4 is used to specify if the STP port state should be overridden. + * Required when no forwarding between the external ports should happen. + */ + +#define LAN9303_TAG_LEN 4 +# define LAN9303_TAG_TX_USE_ALR BIT(3) +# define LAN9303_TAG_TX_STP_OVERRIDE BIT(4) +# define LAN9303_TAG_RX_IGMP BIT(3) +# define LAN9303_TAG_RX_STP BIT(4) +# define LAN9303_TAG_RX_TRAPPED_TO_CPU (LAN9303_TAG_RX_IGMP | \ + LAN9303_TAG_RX_STP) + +/* Decide whether to transmit using ALR lookup, or transmit directly to + * port using tag. ALR learning is performed only when using ALR lookup. + * If the two external ports are bridged and the frame is unicast, + * then use ALR lookup to allow ALR learning on CPU port. + * Otherwise transmit directly to port with STP state override. 
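/* A minimal sketch of the 16-bit "VID" word that the LAN9303 tagger places
 * after the 0x8100 TPID, following the two TX cases described above; the
 * helper names and port index are sketch-only.
 */
#include <stdbool.h>
#include <stdint.h>

static uint16_t lan9303_tx_vid(unsigned int port, bool use_alr)
{
        if (use_alr)
                return 1u << 3;         /* LAN9303_TAG_TX_USE_ALR: let ALR pick the port */

        /* direct transmit: port in bits 1:0, STP state override in bit 4 */
        return (uint16_t)port | (1u << 4);
}

static unsigned int lan9303_rx_source_port(uint16_t vid)
{
        return vid & 0x3;               /* bits 1:0 carry the ingress port */
}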
+ * See also: lan9303_separate_ports() and lan9303.pdf 6.4.10.1 + */ +static int lan9303_xmit_use_arl(struct dsa_port *dp, u8 *dest_addr) +{ + struct lan9303 *chip = dp->ds->priv; + + return chip->is_bridged && !is_multicast_ether_addr(dest_addr); +} + +static struct sk_buff *lan9303_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __be16 *lan9303_tag; + u16 tag; + + /* provide 'LAN9303_TAG_LEN' bytes additional space */ + skb_push(skb, LAN9303_TAG_LEN); + + /* make room between MACs and Ether-Type */ + dsa_alloc_etype_header(skb, LAN9303_TAG_LEN); + + lan9303_tag = dsa_etype_header_pos_tx(skb); + + tag = lan9303_xmit_use_arl(dp, skb->data) ? + LAN9303_TAG_TX_USE_ALR : + dp->index | LAN9303_TAG_TX_STP_OVERRIDE; + lan9303_tag[0] = htons(ETH_P_8021Q); + lan9303_tag[1] = htons(tag); + + return skb; +} + +static struct sk_buff *lan9303_rcv(struct sk_buff *skb, struct net_device *dev) +{ + u16 lan9303_tag1; + unsigned int source_port; + + if (unlikely(!pskb_may_pull(skb, LAN9303_TAG_LEN))) { + dev_warn_ratelimited(&dev->dev, + "Dropping packet, cannot pull\n"); + return NULL; + } + + if (skb_vlan_tag_present(skb)) { + lan9303_tag1 = skb_vlan_tag_get(skb); + __vlan_hwaccel_clear_tag(skb); + } else { + skb_push_rcsum(skb, ETH_HLEN); + __skb_vlan_pop(skb, &lan9303_tag1); + skb_pull_rcsum(skb, ETH_HLEN); + } + + source_port = lan9303_tag1 & 0x3; + + skb->dev = dsa_master_find_slave(dev, 0, source_port); + if (!skb->dev) { + dev_warn_ratelimited(&dev->dev, "Dropping packet due to invalid source port\n"); + return NULL; + } + + if (!(lan9303_tag1 & LAN9303_TAG_RX_TRAPPED_TO_CPU)) + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops lan9303_netdev_ops = { + .name = "lan9303", + .proto = DSA_TAG_PROTO_LAN9303, + .xmit = lan9303_xmit, + .rcv = lan9303_rcv, + .needed_headroom = LAN9303_TAG_LEN, +}; + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_LAN9303); + +module_dsa_tag_driver(lan9303_netdev_ops); diff --git a/net/dsa/tag_mtk.c b/net/dsa/tag_mtk.c new file mode 100644 index 000000000..415d8ece2 --- /dev/null +++ b/net/dsa/tag_mtk.c @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Mediatek DSA Tag support + * Copyright (C) 2017 Landen Chao + * Sean Wang + */ + +#include +#include + +#include "dsa_priv.h" + +#define MTK_HDR_LEN 4 +#define MTK_HDR_XMIT_UNTAGGED 0 +#define MTK_HDR_XMIT_TAGGED_TPID_8100 1 +#define MTK_HDR_XMIT_TAGGED_TPID_88A8 2 +#define MTK_HDR_RECV_SOURCE_PORT_MASK GENMASK(2, 0) +#define MTK_HDR_XMIT_DP_BIT_MASK GENMASK(5, 0) +#define MTK_HDR_XMIT_SA_DIS BIT(6) + +static struct sk_buff *mtk_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 xmit_tpid; + u8 *mtk_tag; + + /* Build the special tag after the MAC Source Address. If VLAN header + * is present, it's required that VLAN header and special tag is + * being combined. Only in this way we can allow the switch can parse + * the both special and VLAN tag at the same time and then look up VLAN + * table with VID. 
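/* A minimal sketch of the first two bytes of the 4-byte MTK header filled in
 * just below: byte 0 records how the original frame was VLAN-tagged and
 * byte 1 carries a one-hot destination port map. The helper names and port
 * index are sketch-only.
 */
#include <stdint.h>

static void mtk_hdr_first_half(uint8_t *hdr, uint8_t xmit_tpid, unsigned int port)
{
        hdr[0] = xmit_tpid;                     /* 0 = untagged, 1 = 0x8100, 2 = 0x88a8 */
        hdr[1] = (uint8_t)(1u << port) & 0x3f;  /* MTK_HDR_XMIT_DP_BIT_MASK, GENMASK(5, 0) */
}

static unsigned int mtk_rx_source_port(uint16_t hdr)
{
        return hdr & 0x7;                       /* MTK_HDR_RECV_SOURCE_PORT_MASK, GENMASK(2, 0) */
}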
+ */ + switch (skb->protocol) { + case htons(ETH_P_8021Q): + xmit_tpid = MTK_HDR_XMIT_TAGGED_TPID_8100; + break; + case htons(ETH_P_8021AD): + xmit_tpid = MTK_HDR_XMIT_TAGGED_TPID_88A8; + break; + default: + xmit_tpid = MTK_HDR_XMIT_UNTAGGED; + skb_push(skb, MTK_HDR_LEN); + dsa_alloc_etype_header(skb, MTK_HDR_LEN); + } + + mtk_tag = dsa_etype_header_pos_tx(skb); + + /* Mark tag attribute on special tag insertion to notify hardware + * whether that's a combined special tag with 802.1Q header. + */ + mtk_tag[0] = xmit_tpid; + mtk_tag[1] = (1 << dp->index) & MTK_HDR_XMIT_DP_BIT_MASK; + + /* Tag control information is kept for 802.1Q */ + if (xmit_tpid == MTK_HDR_XMIT_UNTAGGED) { + mtk_tag[2] = 0; + mtk_tag[3] = 0; + } + + return skb; +} + +static struct sk_buff *mtk_tag_rcv(struct sk_buff *skb, struct net_device *dev) +{ + u16 hdr; + int port; + __be16 *phdr; + + if (unlikely(!pskb_may_pull(skb, MTK_HDR_LEN))) + return NULL; + + phdr = dsa_etype_header_pos_rx(skb); + hdr = ntohs(*phdr); + + /* Remove MTK tag and recalculate checksum. */ + skb_pull_rcsum(skb, MTK_HDR_LEN); + + dsa_strip_etype_header(skb, MTK_HDR_LEN); + + /* Get source port information */ + port = (hdr & MTK_HDR_RECV_SOURCE_PORT_MASK); + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) + return NULL; + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops mtk_netdev_ops = { + .name = "mtk", + .proto = DSA_TAG_PROTO_MTK, + .xmit = mtk_tag_xmit, + .rcv = mtk_tag_rcv, + .needed_headroom = MTK_HDR_LEN, +}; + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_MTK); + +module_dsa_tag_driver(mtk_netdev_ops); diff --git a/net/dsa/tag_ocelot.c b/net/dsa/tag_ocelot.c new file mode 100644 index 000000000..0d81f172b --- /dev/null +++ b/net/dsa/tag_ocelot.c @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2019 NXP + */ +#include +#include "dsa_priv.h" + +/* If the port is under a VLAN-aware bridge, remove the VLAN header from the + * payload and move it into the DSA tag, which will make the switch classify + * the packet to the bridge VLAN. Otherwise, leave the classified VLAN at zero, + * which is the pvid of standalone and VLAN-unaware bridge ports. + */ +static void ocelot_xmit_get_vlan_info(struct sk_buff *skb, struct dsa_port *dp, + u64 *vlan_tci, u64 *tag_type) +{ + struct net_device *br = dsa_port_bridge_dev_get(dp); + struct vlan_ethhdr *hdr; + u16 proto, tci; + + if (!br || !br_vlan_enabled(br)) { + *vlan_tci = 0; + *tag_type = IFH_TAG_TYPE_C; + return; + } + + hdr = (struct vlan_ethhdr *)skb_mac_header(skb); + br_vlan_get_proto(br, &proto); + + if (ntohs(hdr->h_vlan_proto) == proto) { + __skb_vlan_pop(skb, &tci); + *vlan_tci = tci; + } else { + rcu_read_lock(); + br_vlan_get_pvid_rcu(br, &tci); + rcu_read_unlock(); + *vlan_tci = tci; + } + + *tag_type = (proto != ETH_P_8021Q) ? IFH_TAG_TYPE_S : IFH_TAG_TYPE_C; +} + +static void ocelot_xmit_common(struct sk_buff *skb, struct net_device *netdev, + __be32 ifh_prefix, void **ifh) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + struct dsa_switch *ds = dp->ds; + u64 vlan_tci, tag_type; + void *injection; + __be32 *prefix; + u32 rew_op = 0; + u64 qos_class; + + ocelot_xmit_get_vlan_info(skb, dp, &vlan_tci, &tag_type); + + qos_class = netdev_get_num_tc(netdev) ? 
+ netdev_get_prio_tc_map(netdev, skb->priority) : skb->priority; + + injection = skb_push(skb, OCELOT_TAG_LEN); + prefix = skb_push(skb, OCELOT_SHORT_PREFIX_LEN); + + *prefix = ifh_prefix; + memset(injection, 0, OCELOT_TAG_LEN); + ocelot_ifh_set_bypass(injection, 1); + ocelot_ifh_set_src(injection, ds->num_ports); + ocelot_ifh_set_qos_class(injection, qos_class); + ocelot_ifh_set_vlan_tci(injection, vlan_tci); + ocelot_ifh_set_tag_type(injection, tag_type); + + rew_op = ocelot_ptp_rew_op(skb); + if (rew_op) + ocelot_ifh_set_rew_op(injection, rew_op); + + *ifh = injection; +} + +static struct sk_buff *ocelot_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + void *injection; + + ocelot_xmit_common(skb, netdev, cpu_to_be32(0x8880000a), &injection); + ocelot_ifh_set_dest(injection, BIT_ULL(dp->index)); + + return skb; +} + +static struct sk_buff *seville_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + void *injection; + + ocelot_xmit_common(skb, netdev, cpu_to_be32(0x88800005), &injection); + seville_ifh_set_dest(injection, BIT_ULL(dp->index)); + + return skb; +} + +static struct sk_buff *ocelot_rcv(struct sk_buff *skb, + struct net_device *netdev) +{ + u64 src_port, qos_class; + u64 vlan_tci, tag_type; + u8 *start = skb->data; + struct dsa_port *dp; + u8 *extraction; + u16 vlan_tpid; + u64 rew_val; + + /* Revert skb->data by the amount consumed by the DSA master, + * so it points to the beginning of the frame. + */ + skb_push(skb, ETH_HLEN); + /* We don't care about the short prefix, it is just for easy entrance + * into the DSA master's RX filter. Discard it now by moving it into + * the headroom. + */ + skb_pull(skb, OCELOT_SHORT_PREFIX_LEN); + /* And skb->data now points to the extraction frame header. + * Keep a pointer to it. + */ + extraction = skb->data; + /* Now the EFH is part of the headroom as well */ + skb_pull(skb, OCELOT_TAG_LEN); + /* Reset the pointer to the real MAC header */ + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + /* And move skb->data to the correct location again */ + skb_pull(skb, ETH_HLEN); + + /* Remove from inet csum the extraction header */ + skb_postpull_rcsum(skb, start, OCELOT_TOTAL_TAG_LEN); + + ocelot_xfh_get_src_port(extraction, &src_port); + ocelot_xfh_get_qos_class(extraction, &qos_class); + ocelot_xfh_get_tag_type(extraction, &tag_type); + ocelot_xfh_get_vlan_tci(extraction, &vlan_tci); + ocelot_xfh_get_rew_val(extraction, &rew_val); + + skb->dev = dsa_master_find_slave(netdev, 0, src_port); + if (!skb->dev) + /* The switch will reflect back some frames sent through + * sockets opened on the bare DSA master. These will come back + * with src_port equal to the index of the CPU port, for which + * there is no slave registered. So don't print any error + * message here (ignore and drop those frames). + */ + return NULL; + + dsa_default_offload_fwd_mark(skb); + skb->priority = qos_class; + OCELOT_SKB_CB(skb)->tstamp_lo = rew_val; + + /* Ocelot switches copy frames unmodified to the CPU. However, it is + * possible for the user to request a VLAN modification through + * VCAP_IS1_ACT_VID_REPLACE_ENA. In this case, what will happen is that + * the VLAN ID field from the Extraction Header gets updated, but the + * 802.1Q header does not (the classified VLAN only becomes visible on + * egress through the "port tag" of front-panel ports). 
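/* A minimal sketch of the retagging condition applied near the end of
 * ocelot_rcv() below: the 802.1Q header is rewritten with the classified
 * VLAN only when the user port is VLAN filtering and the packet already
 * carries a tag with the matching TPID. Types are simplified for the sketch.
 */
#include <stdbool.h>
#include <stdint.h>

static bool ocelot_should_retag(bool vlan_filtering, uint16_t pkt_proto,
                                uint64_t tag_type)
{
        /* tag_type 1 selects the S-tag TPID (0x88a8), 0 the C-tag TPID (0x8100) */
        uint16_t vlan_tpid = tag_type ? 0x88a8 : 0x8100;

        return vlan_filtering && pkt_proto == vlan_tpid;
}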
+ * So, for traffic extracted by the CPU, we want to pick up the + * classified VLAN and manually replace the existing 802.1Q header from + * the packet with it, so that the operating system is always up to + * date with the result of tc-vlan actions. + * NOTE: In VLAN-unaware mode, we don't want to do that, we want the + * frame to remain unmodified, because the classified VLAN is always + * equal to the pvid of the ingress port and should not be used for + * processing. + */ + dp = dsa_slave_to_port(skb->dev); + vlan_tpid = tag_type ? ETH_P_8021AD : ETH_P_8021Q; + + if (dsa_port_is_vlan_filtering(dp) && + eth_hdr(skb)->h_proto == htons(vlan_tpid)) { + u16 dummy_vlan_tci; + + skb_push_rcsum(skb, ETH_HLEN); + __skb_vlan_pop(skb, &dummy_vlan_tci); + skb_pull_rcsum(skb, ETH_HLEN); + __vlan_hwaccel_put_tag(skb, htons(vlan_tpid), vlan_tci); + } + + return skb; +} + +static const struct dsa_device_ops ocelot_netdev_ops = { + .name = "ocelot", + .proto = DSA_TAG_PROTO_OCELOT, + .xmit = ocelot_xmit, + .rcv = ocelot_rcv, + .needed_headroom = OCELOT_TOTAL_TAG_LEN, + .promisc_on_master = true, +}; + +DSA_TAG_DRIVER(ocelot_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT); + +static const struct dsa_device_ops seville_netdev_ops = { + .name = "seville", + .proto = DSA_TAG_PROTO_SEVILLE, + .xmit = seville_xmit, + .rcv = ocelot_rcv, + .needed_headroom = OCELOT_TOTAL_TAG_LEN, + .promisc_on_master = true, +}; + +DSA_TAG_DRIVER(seville_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SEVILLE); + +static struct dsa_tag_driver *ocelot_tag_driver_array[] = { + &DSA_TAG_DRIVER_NAME(ocelot_netdev_ops), + &DSA_TAG_DRIVER_NAME(seville_netdev_ops), +}; + +module_dsa_tag_drivers(ocelot_tag_driver_array); + +MODULE_LICENSE("GPL v2"); diff --git a/net/dsa/tag_ocelot_8021q.c b/net/dsa/tag_ocelot_8021q.c new file mode 100644 index 000000000..37ccf0040 --- /dev/null +++ b/net/dsa/tag_ocelot_8021q.c @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2020-2021 NXP + * + * An implementation of the software-defined tag_8021q.c tagger format, which + * also preserves full functionality under a vlan_filtering bridge. It does + * this by using the TCAM engines for: + * - pushing the RX VLAN as a second, outer tag, on egress towards the CPU port + * - redirecting towards the correct front port based on TX VLAN and popping + * that on egress + */ +#include +#include +#include "dsa_priv.h" + +struct ocelot_8021q_tagger_private { + struct ocelot_8021q_tagger_data data; /* Must be first */ + struct kthread_worker *xmit_worker; +}; + +static struct sk_buff *ocelot_defer_xmit(struct dsa_port *dp, + struct sk_buff *skb) +{ + struct ocelot_8021q_tagger_private *priv = dp->ds->tagger_data; + struct ocelot_8021q_tagger_data *data = &priv->data; + void (*xmit_work_fn)(struct kthread_work *work); + struct felix_deferred_xmit_work *xmit_work; + struct kthread_worker *xmit_worker; + + xmit_work_fn = data->xmit_work_fn; + xmit_worker = priv->xmit_worker; + + if (!xmit_work_fn || !xmit_worker) + return NULL; + + /* PTP over IP packets need UDP checksumming. We may have inherited + * NETIF_F_HW_CSUM from the DSA master, but these packets are not sent + * through the DSA master, so calculate the checksum here. 
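/* A minimal sketch of the VLAN TCI that the ocelot-8021q (and sja1105)
 * taggers hand to dsa_8021q_xmit() further below: the priority code point
 * occupies the top three bits and the tag_8021q VID the low twelve. The VID
 * used in the example is a made-up placeholder, not a real
 * dsa_tag_8021q_standalone_vid() result.
 */
#include <stdint.h>

#define VLAN_PRIO_SHIFT 13      /* PCP occupies TCI bits 15:13 */

static uint16_t tag_8021q_tx_tci(uint8_t pcp, uint16_t tx_vid)
{
        return (uint16_t)(pcp << VLAN_PRIO_SHIFT) | (tx_vid & 0x0fff);
}

/* e.g. pcp 5, vid 0x40a -> TCI 0xa40a */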
+ */ + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + xmit_work = kzalloc(sizeof(*xmit_work), GFP_ATOMIC); + if (!xmit_work) + return NULL; + + /* Calls felix_port_deferred_xmit in felix.c */ + kthread_init_work(&xmit_work->work, xmit_work_fn); + /* Increase refcount so the kfree_skb in dsa_slave_xmit + * won't really free the packet. + */ + xmit_work->dp = dp; + xmit_work->skb = skb_get(skb); + + kthread_queue_work(xmit_worker, &xmit_work->work); + + return NULL; +} + +static struct sk_buff *ocelot_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + u16 queue_mapping = skb_get_queue_mapping(skb); + u8 pcp = netdev_txq_to_tc(netdev, queue_mapping); + u16 tx_vid = dsa_tag_8021q_standalone_vid(dp); + struct ethhdr *hdr = eth_hdr(skb); + + if (ocelot_ptp_rew_op(skb) || is_link_local_ether_addr(hdr->h_dest)) + return ocelot_defer_xmit(dp, skb); + + return dsa_8021q_xmit(skb, netdev, ETH_P_8021Q, + ((pcp << VLAN_PRIO_SHIFT) | tx_vid)); +} + +static struct sk_buff *ocelot_rcv(struct sk_buff *skb, + struct net_device *netdev) +{ + int src_port, switch_id; + + dsa_8021q_rcv(skb, &src_port, &switch_id, NULL); + + skb->dev = dsa_master_find_slave(netdev, switch_id, src_port); + if (!skb->dev) + return NULL; + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static void ocelot_disconnect(struct dsa_switch *ds) +{ + struct ocelot_8021q_tagger_private *priv = ds->tagger_data; + + kthread_destroy_worker(priv->xmit_worker); + kfree(priv); + ds->tagger_data = NULL; +} + +static int ocelot_connect(struct dsa_switch *ds) +{ + struct ocelot_8021q_tagger_private *priv; + int err; + + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (!priv) + return -ENOMEM; + + priv->xmit_worker = kthread_create_worker(0, "felix_xmit"); + if (IS_ERR(priv->xmit_worker)) { + err = PTR_ERR(priv->xmit_worker); + kfree(priv); + return err; + } + + ds->tagger_data = priv; + + return 0; +} + +static const struct dsa_device_ops ocelot_8021q_netdev_ops = { + .name = "ocelot-8021q", + .proto = DSA_TAG_PROTO_OCELOT_8021Q, + .xmit = ocelot_xmit, + .rcv = ocelot_rcv, + .connect = ocelot_connect, + .disconnect = ocelot_disconnect, + .needed_headroom = VLAN_HLEN, + .promisc_on_master = true, +}; + +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_OCELOT_8021Q); + +module_dsa_tag_driver(ocelot_8021q_netdev_ops); diff --git a/net/dsa/tag_qca.c b/net/dsa/tag_qca.c new file mode 100644 index 000000000..57d2e00f1 --- /dev/null +++ b/net/dsa/tag_qca.c @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015, The Linux Foundation. All rights reserved. 
+ */ + +#include +#include +#include +#include + +#include "dsa_priv.h" + +static struct sk_buff *qca_tag_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __be16 *phdr; + u16 hdr; + + skb_push(skb, QCA_HDR_LEN); + + dsa_alloc_etype_header(skb, QCA_HDR_LEN); + phdr = dsa_etype_header_pos_tx(skb); + + /* Set the version field, and set destination port information */ + hdr = FIELD_PREP(QCA_HDR_XMIT_VERSION, QCA_HDR_VERSION); + hdr |= QCA_HDR_XMIT_FROM_CPU; + hdr |= FIELD_PREP(QCA_HDR_XMIT_DP_BIT, BIT(dp->index)); + + *phdr = htons(hdr); + + return skb; +} + +static struct sk_buff *qca_tag_rcv(struct sk_buff *skb, struct net_device *dev) +{ + struct qca_tagger_data *tagger_data; + struct dsa_port *dp = dev->dsa_ptr; + struct dsa_switch *ds = dp->ds; + u8 ver, pk_type; + __be16 *phdr; + int port; + u16 hdr; + + BUILD_BUG_ON(sizeof(struct qca_mgmt_ethhdr) != QCA_HDR_MGMT_HEADER_LEN + QCA_HDR_LEN); + + tagger_data = ds->tagger_data; + + if (unlikely(!pskb_may_pull(skb, QCA_HDR_LEN))) + return NULL; + + phdr = dsa_etype_header_pos_rx(skb); + hdr = ntohs(*phdr); + + /* Make sure the version is correct */ + ver = FIELD_GET(QCA_HDR_RECV_VERSION, hdr); + if (unlikely(ver != QCA_HDR_VERSION)) + return NULL; + + /* Get pk type */ + pk_type = FIELD_GET(QCA_HDR_RECV_TYPE, hdr); + + /* Ethernet mgmt read/write packet */ + if (pk_type == QCA_HDR_RECV_TYPE_RW_REG_ACK) { + if (likely(tagger_data->rw_reg_ack_handler)) + tagger_data->rw_reg_ack_handler(ds, skb); + return NULL; + } + + /* Ethernet MIB counter packet */ + if (pk_type == QCA_HDR_RECV_TYPE_MIB) { + if (likely(tagger_data->mib_autocast_handler)) + tagger_data->mib_autocast_handler(ds, skb); + return NULL; + } + + /* Remove QCA tag and recalculate checksum */ + skb_pull_rcsum(skb, QCA_HDR_LEN); + dsa_strip_etype_header(skb, QCA_HDR_LEN); + + /* Get source port information */ + port = FIELD_GET(QCA_HDR_RECV_SOURCE_PORT, hdr); + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) + return NULL; + + return skb; +} + +static int qca_tag_connect(struct dsa_switch *ds) +{ + struct qca_tagger_data *tagger_data; + + tagger_data = kzalloc(sizeof(*tagger_data), GFP_KERNEL); + if (!tagger_data) + return -ENOMEM; + + ds->tagger_data = tagger_data; + + return 0; +} + +static void qca_tag_disconnect(struct dsa_switch *ds) +{ + kfree(ds->tagger_data); + ds->tagger_data = NULL; +} + +static const struct dsa_device_ops qca_netdev_ops = { + .name = "qca", + .proto = DSA_TAG_PROTO_QCA, + .connect = qca_tag_connect, + .disconnect = qca_tag_disconnect, + .xmit = qca_tag_xmit, + .rcv = qca_tag_rcv, + .needed_headroom = QCA_HDR_LEN, + .promisc_on_master = true, +}; + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_QCA); + +module_dsa_tag_driver(qca_netdev_ops); diff --git a/net/dsa/tag_rtl4_a.c b/net/dsa/tag_rtl4_a.c new file mode 100644 index 000000000..6d928ee3e --- /dev/null +++ b/net/dsa/tag_rtl4_a.c @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Handler for Realtek 4 byte DSA switch tags + * Currently only supports protocol "A" found in RTL8366RB + * Copyright (c) 2020 Linus Walleij + * + * This "proprietary tag" header looks like so: + * + * ------------------------------------------------- + * | MAC DA | MAC SA | 0x8899 | 2 bytes tag | Type | + * ------------------------------------------------- + * + * The 2 bytes tag form a 16 bit big endian word. The exact + * meaning has been guessed from packet dumps from ingress + * frames. 
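/* A minimal sketch of the 16-bit protocol/port word that follows the 0x8899
 * EtherType in this tag, matching rtl4a_tag_xmit()/rtl4a_tag_rcv() below;
 * the helper names and port index are sketch-only.
 */
#include <stdint.h>

static uint16_t rtl4a_tx_word(unsigned int port)
{
        /* protocol 0xa (RTL8366RB) in bits 15:12, one-hot port in the low bits */
        return (uint16_t)(0xa << 12) | (uint16_t)(1u << port);
}

static unsigned int rtl4a_rx_protocol(uint16_t word)
{
        return (word >> 12) & 0xf;
}

static unsigned int rtl4a_rx_port(uint16_t word)
{
        return word & 0xff;     /* the receive direction reports a plain port number */
}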
+ */ + +#include +#include + +#include "dsa_priv.h" + +#define RTL4_A_HDR_LEN 4 +#define RTL4_A_ETHERTYPE 0x8899 +#define RTL4_A_PROTOCOL_SHIFT 12 +/* + * 0x1 = Realtek Remote Control protocol (RRCP) + * 0x2/0x3 seems to be used for loopback testing + * 0x9 = RTL8306 DSA protocol + * 0xa = RTL8366RB DSA protocol + */ +#define RTL4_A_PROTOCOL_RTL8366RB 0xa + +static struct sk_buff *rtl4a_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __be16 *p; + u8 *tag; + u16 out; + + /* Pad out to at least 60 bytes */ + if (unlikely(__skb_put_padto(skb, ETH_ZLEN, false))) + return NULL; + + netdev_dbg(dev, "add realtek tag to package to port %d\n", + dp->index); + skb_push(skb, RTL4_A_HDR_LEN); + + dsa_alloc_etype_header(skb, RTL4_A_HDR_LEN); + tag = dsa_etype_header_pos_tx(skb); + + /* Set Ethertype */ + p = (__be16 *)tag; + *p = htons(RTL4_A_ETHERTYPE); + + out = (RTL4_A_PROTOCOL_RTL8366RB << RTL4_A_PROTOCOL_SHIFT); + /* The lower bits indicate the port number */ + out |= BIT(dp->index); + + p = (__be16 *)(tag + 2); + *p = htons(out); + + return skb; +} + +static struct sk_buff *rtl4a_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + u16 protport; + __be16 *p; + u16 etype; + u8 *tag; + u8 prot; + u8 port; + + if (unlikely(!pskb_may_pull(skb, RTL4_A_HDR_LEN))) + return NULL; + + tag = dsa_etype_header_pos_rx(skb); + p = (__be16 *)tag; + etype = ntohs(*p); + if (etype != RTL4_A_ETHERTYPE) { + /* Not custom, just pass through */ + netdev_dbg(dev, "non-realtek ethertype 0x%04x\n", etype); + return skb; + } + p = (__be16 *)(tag + 2); + protport = ntohs(*p); + /* The 4 upper bits are the protocol */ + prot = (protport >> RTL4_A_PROTOCOL_SHIFT) & 0x0f; + if (prot != RTL4_A_PROTOCOL_RTL8366RB) { + netdev_err(dev, "unknown realtek protocol 0x%01x\n", prot); + return NULL; + } + port = protport & 0xff; + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) { + netdev_dbg(dev, "could not find slave for port %d\n", port); + return NULL; + } + + /* Remove RTL4 tag and recalculate checksum */ + skb_pull_rcsum(skb, RTL4_A_HDR_LEN); + + dsa_strip_etype_header(skb, RTL4_A_HDR_LEN); + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops rtl4a_netdev_ops = { + .name = "rtl4a", + .proto = DSA_TAG_PROTO_RTL4_A, + .xmit = rtl4a_tag_xmit, + .rcv = rtl4a_tag_rcv, + .needed_headroom = RTL4_A_HDR_LEN, +}; +module_dsa_tag_driver(rtl4a_netdev_ops); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL4_A); diff --git a/net/dsa/tag_rtl8_4.c b/net/dsa/tag_rtl8_4.c new file mode 100644 index 000000000..a593ead7f --- /dev/null +++ b/net/dsa/tag_rtl8_4.c @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Handler for Realtek 8 byte switch tags + * + * Copyright (C) 2021 Alvin Šipraga + * + * NOTE: Currently only supports protocol "4" found in the RTL8365MB, hence + * named tag_rtl8_4. 
+ * + * This tag has the following format: + * + * 0 7|8 15 + * |-----------------------------------+-----------------------------------|--- + * | (16-bit) | ^ + * | Realtek EtherType [0x8899] | | + * |-----------------------------------+-----------------------------------| 8 + * | (8-bit) | (8-bit) | + * | Protocol [0x04] | REASON | b + * |-----------------------------------+-----------------------------------| y + * | (1) | (1) | (2) | (1) | (3) | (1) | (1) | (1) | (5) | t + * | FID_EN | X | FID | PRI_EN | PRI | KEEP | X | LEARN_DIS | X | e + * |-----------------------------------+-----------------------------------| s + * | (1) | (15-bit) | | + * | ALLOW | TX/RX | v + * |-----------------------------------+-----------------------------------|--- + * + * With the following field descriptions: + * + * field | description + * ------------+------------- + * Realtek | 0x8899: indicates that this is a proprietary Realtek tag; + * EtherType | note that Realtek uses the same EtherType for + * | other incompatible tag formats (e.g. tag_rtl4_a.c) + * Protocol | 0x04: indicates that this tag conforms to this format + * X | reserved + * ------------+------------- + * REASON | reason for forwarding packet to CPU + * | 0: packet was forwarded or flooded to CPU + * | 80: packet was trapped to CPU + * FID_EN | 1: packet has an FID + * | 0: no FID + * FID | FID of packet (if FID_EN=1) + * PRI_EN | 1: force priority of packet + * | 0: don't force priority + * PRI | priority of packet (if PRI_EN=1) + * KEEP | preserve packet VLAN tag format + * LEARN_DIS | don't learn the source MAC address of the packet + * ALLOW | 1: treat TX/RX field as an allowance port mask, meaning the + * | packet may only be forwarded to ports specified in the + * | mask + * | 0: no allowance port mask, TX/RX field is the forwarding + * | port mask + * TX/RX | TX (switch->CPU): port number the packet was received on + * | RX (CPU->switch): forwarding port mask (if ALLOW=0) + * | allowance port mask (if ALLOW=1) + * + * The tag can be positioned before Ethertype, using tag "rtl8_4": + * + * +--------+--------+------------+------+----- + * | MAC DA | MAC SA | 8 byte tag | Type | ... + * +--------+--------+------------+------+----- + * + * The tag can also appear between the end of the payload and before the CRC, + * using tag "rtl8_4t": + * + * +--------+--------+------+-----+---------+------------+-----+ + * | MAC DA | MAC SA | TYPE | ... | payload | 8-byte tag | CRC | + * +--------+--------+------+-----+---------+------------+-----+ + * + * The added bytes after the payload will break most checksums, either in + * software or hardware. To avoid this issue, if the checksum is still pending, + * this tagger checksums the packet in software before adding the tag. 
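/* A minimal sketch of the four 16-bit tag words laid out above and written
 * by rtl8_4_write_tag() further below, for a hypothetical CPU->switch frame
 * destined to port index 2; the Realtek EtherType is 0x8899 as described in
 * this comment.
 */
#include <stdint.h>

static void rtl8_4_words(uint16_t *w, unsigned int port)
{
        w[0] = 0x8899;          /* Realtek EtherType */
        w[1] = 0x04 << 8;       /* protocol 0x04 in bits 15:8, REASON = 0 */
        w[2] = 1u << 5;         /* LEARN_DIS set, everything else zero */
        w[3] = 1u << port;      /* RX: forwarding port mask (ALLOW = 0) */
}

/* port 2 -> { 0x8899, 0x0400, 0x0020, 0x0004 }, each stored big endian */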
+ * + */ + +#include +#include +#include + +#include "dsa_priv.h" + +/* Protocols supported: + * + * 0x04 = RTL8365MB DSA protocol + */ + +#define RTL8_4_TAG_LEN 8 + +#define RTL8_4_PROTOCOL GENMASK(15, 8) +#define RTL8_4_PROTOCOL_RTL8365MB 0x04 +#define RTL8_4_REASON GENMASK(7, 0) +#define RTL8_4_REASON_FORWARD 0 +#define RTL8_4_REASON_TRAP 80 + +#define RTL8_4_LEARN_DIS BIT(5) + +#define RTL8_4_TX GENMASK(3, 0) +#define RTL8_4_RX GENMASK(10, 0) + +static void rtl8_4_write_tag(struct sk_buff *skb, struct net_device *dev, + void *tag) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + __be16 tag16[RTL8_4_TAG_LEN / 2]; + + /* Set Realtek EtherType */ + tag16[0] = htons(ETH_P_REALTEK); + + /* Set Protocol; zero REASON */ + tag16[1] = htons(FIELD_PREP(RTL8_4_PROTOCOL, RTL8_4_PROTOCOL_RTL8365MB)); + + /* Zero FID_EN, FID, PRI_EN, PRI, KEEP; set LEARN_DIS */ + tag16[2] = htons(FIELD_PREP(RTL8_4_LEARN_DIS, 1)); + + /* Zero ALLOW; set RX (CPU->switch) forwarding port mask */ + tag16[3] = htons(FIELD_PREP(RTL8_4_RX, BIT(dp->index))); + + memcpy(tag, tag16, RTL8_4_TAG_LEN); +} + +static struct sk_buff *rtl8_4_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + skb_push(skb, RTL8_4_TAG_LEN); + + dsa_alloc_etype_header(skb, RTL8_4_TAG_LEN); + + rtl8_4_write_tag(skb, dev, dsa_etype_header_pos_tx(skb)); + + return skb; +} + +static struct sk_buff *rtl8_4t_tag_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + /* Calculate the checksum here if not done yet as trailing tags will + * break either software or hardware based checksum + */ + if (skb->ip_summed == CHECKSUM_PARTIAL && skb_checksum_help(skb)) + return NULL; + + rtl8_4_write_tag(skb, dev, skb_put(skb, RTL8_4_TAG_LEN)); + + return skb; +} + +static int rtl8_4_read_tag(struct sk_buff *skb, struct net_device *dev, + void *tag) +{ + __be16 tag16[RTL8_4_TAG_LEN / 2]; + u16 etype; + u8 reason; + u8 proto; + u8 port; + + memcpy(tag16, tag, RTL8_4_TAG_LEN); + + /* Parse Realtek EtherType */ + etype = ntohs(tag16[0]); + if (unlikely(etype != ETH_P_REALTEK)) { + dev_warn_ratelimited(&dev->dev, + "non-realtek ethertype 0x%04x\n", etype); + return -EPROTO; + } + + /* Parse Protocol */ + proto = FIELD_GET(RTL8_4_PROTOCOL, ntohs(tag16[1])); + if (unlikely(proto != RTL8_4_PROTOCOL_RTL8365MB)) { + dev_warn_ratelimited(&dev->dev, + "unknown realtek protocol 0x%02x\n", + proto); + return -EPROTO; + } + + /* Parse REASON */ + reason = FIELD_GET(RTL8_4_REASON, ntohs(tag16[1])); + + /* Parse TX (switch->CPU) */ + port = FIELD_GET(RTL8_4_TX, ntohs(tag16[3])); + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) { + dev_warn_ratelimited(&dev->dev, + "could not find slave for port %d\n", + port); + return -ENOENT; + } + + if (reason != RTL8_4_REASON_TRAP) + dsa_default_offload_fwd_mark(skb); + + return 0; +} + +static struct sk_buff *rtl8_4_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + if (unlikely(!pskb_may_pull(skb, RTL8_4_TAG_LEN))) + return NULL; + + if (unlikely(rtl8_4_read_tag(skb, dev, dsa_etype_header_pos_rx(skb)))) + return NULL; + + /* Remove tag and recalculate checksum */ + skb_pull_rcsum(skb, RTL8_4_TAG_LEN); + + dsa_strip_etype_header(skb, RTL8_4_TAG_LEN); + + return skb; +} + +static struct sk_buff *rtl8_4t_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + if (skb_linearize(skb)) + return NULL; + + if (unlikely(rtl8_4_read_tag(skb, dev, skb_tail_pointer(skb) - RTL8_4_TAG_LEN))) + return NULL; + + if (pskb_trim_rcsum(skb, skb->len - RTL8_4_TAG_LEN)) + return NULL; + + return skb; +} + +/* Ethertype 
version */ +static const struct dsa_device_ops rtl8_4_netdev_ops = { + .name = "rtl8_4", + .proto = DSA_TAG_PROTO_RTL8_4, + .xmit = rtl8_4_tag_xmit, + .rcv = rtl8_4_tag_rcv, + .needed_headroom = RTL8_4_TAG_LEN, +}; + +DSA_TAG_DRIVER(rtl8_4_netdev_ops); + +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4); + +/* Tail version */ +static const struct dsa_device_ops rtl8_4t_netdev_ops = { + .name = "rtl8_4t", + .proto = DSA_TAG_PROTO_RTL8_4T, + .xmit = rtl8_4t_tag_xmit, + .rcv = rtl8_4t_tag_rcv, + .needed_tailroom = RTL8_4_TAG_LEN, +}; + +DSA_TAG_DRIVER(rtl8_4t_netdev_ops); + +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_RTL8_4T); + +static struct dsa_tag_driver *dsa_tag_drivers[] = { + &DSA_TAG_DRIVER_NAME(rtl8_4_netdev_ops), + &DSA_TAG_DRIVER_NAME(rtl8_4t_netdev_ops), +}; +module_dsa_tag_drivers(dsa_tag_drivers); + +MODULE_LICENSE("GPL"); diff --git a/net/dsa/tag_rzn1_a5psw.c b/net/dsa/tag_rzn1_a5psw.c new file mode 100644 index 000000000..e2a5ee6ae --- /dev/null +++ b/net/dsa/tag_rzn1_a5psw.c @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2022 Schneider Electric + * + * Clément Léger + */ + +#include +#include +#include +#include + +#include "dsa_priv.h" + +/* To define the outgoing port and to discover the incoming port a TAG is + * inserted after Src MAC : + * + * Dest MAC Src MAC TAG Type + * ...| 1 2 3 4 5 6 | 1 2 3 4 5 6 | 1 2 3 4 5 6 7 8 | 1 2 |... + * |<--------------->| + * + * See struct a5psw_tag for layout + */ + +#define ETH_P_DSA_A5PSW 0xE001 +#define A5PSW_TAG_LEN 8 +#define A5PSW_CTRL_DATA_FORCE_FORWARD BIT(0) +/* This is both used for xmit tag and rcv tagging */ +#define A5PSW_CTRL_DATA_PORT GENMASK(3, 0) + +struct a5psw_tag { + __be16 ctrl_tag; + __be16 ctrl_data; + __be16 ctrl_data2_hi; + __be16 ctrl_data2_lo; +}; + +static struct sk_buff *a5psw_tag_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + struct a5psw_tag *ptag; + u32 data2_val; + + BUILD_BUG_ON(sizeof(*ptag) != A5PSW_TAG_LEN); + + /* The Ethernet switch we are interfaced with needs packets to be at + * least 60 bytes otherwise they will be discarded when they enter the + * switch port logic. 
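/* A minimal sketch of the four 16-bit fields of the A5PSW tag as filled in
 * by a5psw_tag_xmit() below, for a hypothetical egress port index 1; the
 * byte order conversion done by htons() in the real code is left out.
 */
#include <stdint.h>

struct a5psw_tag_sketch {
        uint16_t ctrl_tag;      /* 0xe001, the A5PSW tag marker */
        uint16_t ctrl_data;     /* bit 0: force forward (TX); bits 3:0: port (RX) */
        uint16_t ctrl_data2_hi;
        uint16_t ctrl_data2_lo; /* TX: one-hot destination port in bits 3:0 */
};

static struct a5psw_tag_sketch a5psw_tx_tag(unsigned int port)
{
        struct a5psw_tag_sketch tag = {
                .ctrl_tag = 0xe001,
                .ctrl_data = 1u << 0,           /* A5PSW_CTRL_DATA_FORCE_FORWARD */
                .ctrl_data2_hi = 0,
                .ctrl_data2_lo = 1u << port,    /* e.g. port index 1 -> 0x0002 */
        };

        return tag;
}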
+ */ + if (__skb_put_padto(skb, ETH_ZLEN, false)) + return NULL; + + /* provide 'A5PSW_TAG_LEN' bytes additional space */ + skb_push(skb, A5PSW_TAG_LEN); + + /* make room between MACs and Ether-Type to insert tag */ + dsa_alloc_etype_header(skb, A5PSW_TAG_LEN); + + ptag = dsa_etype_header_pos_tx(skb); + + data2_val = FIELD_PREP(A5PSW_CTRL_DATA_PORT, BIT(dp->index)); + ptag->ctrl_tag = htons(ETH_P_DSA_A5PSW); + ptag->ctrl_data = htons(A5PSW_CTRL_DATA_FORCE_FORWARD); + ptag->ctrl_data2_lo = htons(data2_val); + ptag->ctrl_data2_hi = 0; + + return skb; +} + +static struct sk_buff *a5psw_tag_rcv(struct sk_buff *skb, + struct net_device *dev) +{ + struct a5psw_tag *tag; + int port; + + if (unlikely(!pskb_may_pull(skb, A5PSW_TAG_LEN))) { + dev_warn_ratelimited(&dev->dev, + "Dropping packet, cannot pull\n"); + return NULL; + } + + tag = dsa_etype_header_pos_rx(skb); + + if (tag->ctrl_tag != htons(ETH_P_DSA_A5PSW)) { + dev_warn_ratelimited(&dev->dev, "Dropping packet due to invalid TAG marker\n"); + return NULL; + } + + port = FIELD_GET(A5PSW_CTRL_DATA_PORT, ntohs(tag->ctrl_data)); + + skb->dev = dsa_master_find_slave(dev, 0, port); + if (!skb->dev) + return NULL; + + skb_pull_rcsum(skb, A5PSW_TAG_LEN); + dsa_strip_etype_header(skb, A5PSW_TAG_LEN); + + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops a5psw_netdev_ops = { + .name = "a5psw", + .proto = DSA_TAG_PROTO_RZN1_A5PSW, + .xmit = a5psw_tag_xmit, + .rcv = a5psw_tag_rcv, + .needed_headroom = A5PSW_TAG_LEN, +}; + +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_A5PSW); +module_dsa_tag_driver(a5psw_netdev_ops); diff --git a/net/dsa/tag_sja1105.c b/net/dsa/tag_sja1105.c new file mode 100644 index 000000000..143348962 --- /dev/null +++ b/net/dsa/tag_sja1105.c @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2019, Vladimir Oltean + */ +#include +#include +#include +#include +#include "dsa_priv.h" + +/* Is this a TX or an RX header? 
*/ +#define SJA1110_HEADER_HOST_TO_SWITCH BIT(15) + +/* RX header */ +#define SJA1110_RX_HEADER_IS_METADATA BIT(14) +#define SJA1110_RX_HEADER_HOST_ONLY BIT(13) +#define SJA1110_RX_HEADER_HAS_TRAILER BIT(12) + +/* Trap-to-host format (no trailer present) */ +#define SJA1110_RX_HEADER_SRC_PORT(x) (((x) & GENMASK(7, 4)) >> 4) +#define SJA1110_RX_HEADER_SWITCH_ID(x) ((x) & GENMASK(3, 0)) + +/* Timestamp format (trailer present) */ +#define SJA1110_RX_HEADER_TRAILER_POS(x) ((x) & GENMASK(11, 0)) + +#define SJA1110_RX_TRAILER_SWITCH_ID(x) (((x) & GENMASK(7, 4)) >> 4) +#define SJA1110_RX_TRAILER_SRC_PORT(x) ((x) & GENMASK(3, 0)) + +/* Meta frame format (for 2-step TX timestamps) */ +#define SJA1110_RX_HEADER_N_TS(x) (((x) & GENMASK(8, 4)) >> 4) + +/* TX header */ +#define SJA1110_TX_HEADER_UPDATE_TC BIT(14) +#define SJA1110_TX_HEADER_TAKE_TS BIT(13) +#define SJA1110_TX_HEADER_TAKE_TS_CASC BIT(12) +#define SJA1110_TX_HEADER_HAS_TRAILER BIT(11) + +/* Only valid if SJA1110_TX_HEADER_HAS_TRAILER is false */ +#define SJA1110_TX_HEADER_PRIO(x) (((x) << 7) & GENMASK(10, 7)) +#define SJA1110_TX_HEADER_TSTAMP_ID(x) ((x) & GENMASK(7, 0)) + +/* Only valid if SJA1110_TX_HEADER_HAS_TRAILER is true */ +#define SJA1110_TX_HEADER_TRAILER_POS(x) ((x) & GENMASK(10, 0)) + +#define SJA1110_TX_TRAILER_TSTAMP_ID(x) (((x) << 24) & GENMASK(31, 24)) +#define SJA1110_TX_TRAILER_PRIO(x) (((x) << 21) & GENMASK(23, 21)) +#define SJA1110_TX_TRAILER_SWITCHID(x) (((x) << 12) & GENMASK(15, 12)) +#define SJA1110_TX_TRAILER_DESTPORTS(x) (((x) << 1) & GENMASK(11, 1)) + +#define SJA1110_META_TSTAMP_SIZE 10 + +#define SJA1110_HEADER_LEN 4 +#define SJA1110_RX_TRAILER_LEN 13 +#define SJA1110_TX_TRAILER_LEN 4 +#define SJA1110_MAX_PADDING_LEN 15 + +struct sja1105_tagger_private { + struct sja1105_tagger_data data; /* Must be first */ + /* Protects concurrent access to the meta state machine + * from taggers running on multiple ports on SMP systems + */ + spinlock_t meta_lock; + struct sk_buff *stampable_skb; + struct kthread_worker *xmit_worker; +}; + +static struct sja1105_tagger_private * +sja1105_tagger_private(struct dsa_switch *ds) +{ + return ds->tagger_data; +} + +/* Similar to is_link_local_ether_addr(hdr->h_dest) but also covers PTP */ +static inline bool sja1105_is_link_local(const struct sk_buff *skb) +{ + const struct ethhdr *hdr = eth_hdr(skb); + u64 dmac = ether_addr_to_u64(hdr->h_dest); + + if (ntohs(hdr->h_proto) == ETH_P_SJA1105_META) + return false; + if ((dmac & SJA1105_LINKLOCAL_FILTER_A_MASK) == + SJA1105_LINKLOCAL_FILTER_A) + return true; + if ((dmac & SJA1105_LINKLOCAL_FILTER_B_MASK) == + SJA1105_LINKLOCAL_FILTER_B) + return true; + return false; +} + +struct sja1105_meta { + u64 tstamp; + u64 dmac_byte_4; + u64 dmac_byte_3; + u64 source_port; + u64 switch_id; +}; + +static void sja1105_meta_unpack(const struct sk_buff *skb, + struct sja1105_meta *meta) +{ + u8 *buf = skb_mac_header(skb) + ETH_HLEN; + + /* UM10944.pdf section 4.2.17 AVB Parameters: + * Structure of the meta-data follow-up frame. + * It is in network byte order, so there are no quirks + * while unpacking the meta frame. + * + * Also SJA1105 E/T only populates bits 23:0 of the timestamp + * whereas P/Q/R/S does 32 bits. Since the structure is the + * same and the E/T puts zeroes in the high-order byte, use + * a unified unpacking command for both device series. 
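/* A minimal sketch of the eight meta-frame payload bytes decoded by the
 * packing() calls just below, using plain big-endian arithmetic instead of
 * the kernel packing() helper; the struct and function names are sketch-only.
 */
#include <stdint.h>

struct sja1105_meta_sketch {
        uint64_t tstamp;
        uint8_t dmac_byte_3;
        uint8_t dmac_byte_4;
        uint8_t source_port;
        uint8_t switch_id;
};

static void sja1105_meta_unpack_sketch(const uint8_t *buf,
                                       struct sja1105_meta_sketch *meta)
{
        /* bytes 0..3: partial RX timestamp, network (big endian) order */
        meta->tstamp = ((uint32_t)buf[0] << 24) | ((uint32_t)buf[1] << 16) |
                       ((uint32_t)buf[2] << 8) | buf[3];
        meta->dmac_byte_3 = buf[4];
        meta->dmac_byte_4 = buf[5];
        meta->source_port = buf[6];
        meta->switch_id = buf[7];
}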
+ */ + packing(buf, &meta->tstamp, 31, 0, 4, UNPACK, 0); + packing(buf + 4, &meta->dmac_byte_3, 7, 0, 1, UNPACK, 0); + packing(buf + 5, &meta->dmac_byte_4, 7, 0, 1, UNPACK, 0); + packing(buf + 6, &meta->source_port, 7, 0, 1, UNPACK, 0); + packing(buf + 7, &meta->switch_id, 7, 0, 1, UNPACK, 0); +} + +static inline bool sja1105_is_meta_frame(const struct sk_buff *skb) +{ + const struct ethhdr *hdr = eth_hdr(skb); + u64 smac = ether_addr_to_u64(hdr->h_source); + u64 dmac = ether_addr_to_u64(hdr->h_dest); + + if (smac != SJA1105_META_SMAC) + return false; + if (dmac != SJA1105_META_DMAC) + return false; + if (ntohs(hdr->h_proto) != ETH_P_SJA1105_META) + return false; + return true; +} + +/* Calls sja1105_port_deferred_xmit in sja1105_main.c */ +static struct sk_buff *sja1105_defer_xmit(struct dsa_port *dp, + struct sk_buff *skb) +{ + struct sja1105_tagger_data *tagger_data = sja1105_tagger_data(dp->ds); + struct sja1105_tagger_private *priv = sja1105_tagger_private(dp->ds); + void (*xmit_work_fn)(struct kthread_work *work); + struct sja1105_deferred_xmit_work *xmit_work; + struct kthread_worker *xmit_worker; + + xmit_work_fn = tagger_data->xmit_work_fn; + xmit_worker = priv->xmit_worker; + + if (!xmit_work_fn || !xmit_worker) + return NULL; + + xmit_work = kzalloc(sizeof(*xmit_work), GFP_ATOMIC); + if (!xmit_work) + return NULL; + + kthread_init_work(&xmit_work->work, xmit_work_fn); + /* Increase refcount so the kfree_skb in dsa_slave_xmit + * won't really free the packet. + */ + xmit_work->dp = dp; + xmit_work->skb = skb_get(skb); + + kthread_queue_work(xmit_worker, &xmit_work->work); + + return NULL; +} + +/* Send VLAN tags with a TPID that blends in with whatever VLAN protocol a + * bridge spanning ports of this switch might have. + */ +static u16 sja1105_xmit_tpid(struct dsa_port *dp) +{ + struct dsa_switch *ds = dp->ds; + struct dsa_port *other_dp; + u16 proto; + + /* Since VLAN awareness is global, then if this port is VLAN-unaware, + * all ports are. Use the VLAN-unaware TPID used for tag_8021q. + */ + if (!dsa_port_is_vlan_filtering(dp)) + return ETH_P_SJA1105; + + /* Port is VLAN-aware, so there is a bridge somewhere (a single one, + * we're sure about that). It may not be on this port though, so we + * need to find it. + */ + dsa_switch_for_each_port(other_dp, ds) { + struct net_device *br = dsa_port_bridge_dev_get(other_dp); + + if (!br) + continue; + + /* Error is returned only if CONFIG_BRIDGE_VLAN_FILTERING, + * which seems pointless to handle, as our port cannot become + * VLAN-aware in that case. + */ + br_vlan_get_proto(br, &proto); + + return proto; + } + + WARN_ONCE(1, "Port is VLAN-aware but cannot find associated bridge!\n"); + + return ETH_P_SJA1105; +} + +static struct sk_buff *sja1105_imprecise_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + unsigned int bridge_num = dsa_port_bridge_num_get(dp); + struct net_device *br = dsa_port_bridge_dev_get(dp); + u16 tx_vid; + + /* If the port is under a VLAN-aware bridge, just slide the + * VLAN-tagged packet into the FDB and hope for the best. + * This works because we support a single VLAN-aware bridge + * across the entire dst, and its VLANs cannot be shared with + * any standalone port. + */ + if (br_vlan_enabled(br)) + return skb; + + /* If the port is under a VLAN-unaware bridge, use an imprecise + * TX VLAN that targets the bridge's entire broadcast domain, + * instead of just the specific port. 
+ */ + tx_vid = dsa_tag_8021q_bridge_vid(bridge_num); + + return dsa_8021q_xmit(skb, netdev, sja1105_xmit_tpid(dp), tx_vid); +} + +/* Transform untagged control packets into pvid-tagged control packets so that + * all packets sent by this tagger are VLAN-tagged and we can configure the + * switch to drop untagged packets coming from the DSA master. + */ +static struct sk_buff *sja1105_pvid_tag_control_pkt(struct dsa_port *dp, + struct sk_buff *skb, u8 pcp) +{ + __be16 xmit_tpid = htons(sja1105_xmit_tpid(dp)); + struct vlan_ethhdr *hdr; + + /* If VLAN tag is in hwaccel area, move it to the payload + * to deal with both cases uniformly and to ensure that + * the VLANs are added in the right order. + */ + if (unlikely(skb_vlan_tag_present(skb))) { + skb = __vlan_hwaccel_push_inside(skb); + if (!skb) + return NULL; + } + + hdr = (struct vlan_ethhdr *)skb_mac_header(skb); + + /* If skb is already VLAN-tagged, leave that VLAN ID in place */ + if (hdr->h_vlan_proto == xmit_tpid) + return skb; + + return vlan_insert_tag(skb, xmit_tpid, (pcp << VLAN_PRIO_SHIFT) | + SJA1105_DEFAULT_VLAN); +} + +static struct sk_buff *sja1105_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct dsa_port *dp = dsa_slave_to_port(netdev); + u16 queue_mapping = skb_get_queue_mapping(skb); + u8 pcp = netdev_txq_to_tc(netdev, queue_mapping); + u16 tx_vid = dsa_tag_8021q_standalone_vid(dp); + + if (skb->offload_fwd_mark) + return sja1105_imprecise_xmit(skb, netdev); + + /* Transmitting management traffic does not rely upon switch tagging, + * but instead SPI-installed management routes. Part 2 of this + * is the .port_deferred_xmit driver callback. + */ + if (unlikely(sja1105_is_link_local(skb))) { + skb = sja1105_pvid_tag_control_pkt(dp, skb, pcp); + if (!skb) + return NULL; + + return sja1105_defer_xmit(dp, skb); + } + + return dsa_8021q_xmit(skb, netdev, sja1105_xmit_tpid(dp), + ((pcp << VLAN_PRIO_SHIFT) | tx_vid)); +} + +static struct sk_buff *sja1110_xmit(struct sk_buff *skb, + struct net_device *netdev) +{ + struct sk_buff *clone = SJA1105_SKB_CB(skb)->clone; + struct dsa_port *dp = dsa_slave_to_port(netdev); + u16 queue_mapping = skb_get_queue_mapping(skb); + u8 pcp = netdev_txq_to_tc(netdev, queue_mapping); + u16 tx_vid = dsa_tag_8021q_standalone_vid(dp); + __be32 *tx_trailer; + __be16 *tx_header; + int trailer_pos; + + if (skb->offload_fwd_mark) + return sja1105_imprecise_xmit(skb, netdev); + + /* Transmitting control packets is done using in-band control + * extensions, while data packets are transmitted using + * tag_8021q TX VLANs. 
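/* A minimal sketch of the second in-band control header word and of the
 * 32-bit TX trailer assembled by sja1110_xmit() below; pcp, switch id, port
 * index and trailer position are hypothetical values and the helper names
 * are sketch-only.
 */
#include <stdint.h>

static uint16_t sja1110_tx_header_word1(uint16_t trailer_pos)
{
        return (1u << 15)               /* SJA1110_HEADER_HOST_TO_SWITCH */
             | (1u << 11)               /* SJA1110_TX_HEADER_HAS_TRAILER */
             | (trailer_pos & 0x7ff);   /* SJA1110_TX_HEADER_TRAILER_POS */
}

static uint32_t sja1110_tx_trailer(uint8_t pcp, unsigned int switch_id,
                                   unsigned int port)
{
        return ((uint32_t)pcp << 21)            /* SJA1110_TX_TRAILER_PRIO */
             | ((uint32_t)switch_id << 12)      /* SJA1110_TX_TRAILER_SWITCHID */
             | ((1u << port) << 1);             /* SJA1110_TX_TRAILER_DESTPORTS */
}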
+ */ + if (likely(!sja1105_is_link_local(skb))) + return dsa_8021q_xmit(skb, netdev, sja1105_xmit_tpid(dp), + ((pcp << VLAN_PRIO_SHIFT) | tx_vid)); + + skb = sja1105_pvid_tag_control_pkt(dp, skb, pcp); + if (!skb) + return NULL; + + skb_push(skb, SJA1110_HEADER_LEN); + + dsa_alloc_etype_header(skb, SJA1110_HEADER_LEN); + + trailer_pos = skb->len; + + tx_header = dsa_etype_header_pos_tx(skb); + tx_trailer = skb_put(skb, SJA1110_TX_TRAILER_LEN); + + tx_header[0] = htons(ETH_P_SJA1110); + tx_header[1] = htons(SJA1110_HEADER_HOST_TO_SWITCH | + SJA1110_TX_HEADER_HAS_TRAILER | + SJA1110_TX_HEADER_TRAILER_POS(trailer_pos)); + *tx_trailer = cpu_to_be32(SJA1110_TX_TRAILER_PRIO(pcp) | + SJA1110_TX_TRAILER_SWITCHID(dp->ds->index) | + SJA1110_TX_TRAILER_DESTPORTS(BIT(dp->index))); + if (clone) { + u8 ts_id = SJA1105_SKB_CB(clone)->ts_id; + + tx_header[1] |= htons(SJA1110_TX_HEADER_TAKE_TS); + *tx_trailer |= cpu_to_be32(SJA1110_TX_TRAILER_TSTAMP_ID(ts_id)); + } + + return skb; +} + +static void sja1105_transfer_meta(struct sk_buff *skb, + const struct sja1105_meta *meta) +{ + struct ethhdr *hdr = eth_hdr(skb); + + hdr->h_dest[3] = meta->dmac_byte_3; + hdr->h_dest[4] = meta->dmac_byte_4; + SJA1105_SKB_CB(skb)->tstamp = meta->tstamp; +} + +/* This is a simple state machine which follows the hardware mechanism of + * generating RX timestamps: + * + * After each timestampable skb (all traffic for which send_meta1 and + * send_meta0 is true, aka all MAC-filtered link-local traffic) a meta frame + * containing a partial timestamp is immediately generated by the switch and + * sent as a follow-up to the link-local frame on the CPU port. + * + * The meta frames have no unique identifier (such as sequence number) by which + * one may pair them to the correct timestampable frame. + * Instead, the switch has internal logic that ensures no frames are sent on + * the CPU port between a link-local timestampable frame and its corresponding + * meta follow-up. It also ensures strict ordering between ports (lower ports + * have higher priority towards the CPU port). For this reason, a per-port + * data structure is not needed/desirable. + * + * This function pairs the link-local frame with its partial timestamp from the + * meta follow-up frame. The full timestamp will be reconstructed later in a + * work queue. + */ +static struct sk_buff +*sja1105_rcv_meta_state_machine(struct sk_buff *skb, + struct sja1105_meta *meta, + bool is_link_local, + bool is_meta) +{ + /* Step 1: A timestampable frame was received. + * Buffer it until we get its meta frame. + */ + if (is_link_local) { + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + struct sja1105_tagger_private *priv; + struct dsa_switch *ds = dp->ds; + + priv = sja1105_tagger_private(ds); + + spin_lock(&priv->meta_lock); + /* Was this a link-local frame instead of the meta + * that we were expecting? + */ + if (priv->stampable_skb) { + dev_err_ratelimited(ds->dev, + "Expected meta frame, is %12llx " + "in the DSA master multicast filter?\n", + SJA1105_META_DMAC); + kfree_skb(priv->stampable_skb); + } + + /* Hold a reference to avoid dsa_switch_rcv + * from freeing the skb. + */ + priv->stampable_skb = skb_get(skb); + spin_unlock(&priv->meta_lock); + + /* Tell DSA we got nothing */ + return NULL; + + /* Step 2: The meta frame arrived. 
+ * Time to take the stampable skb out of the closet, annotate it + * with the partial timestamp, and pretend that we received it + * just now (basically masquerade the buffered frame as the meta + * frame, which serves no further purpose). + */ + } else if (is_meta) { + struct dsa_port *dp = dsa_slave_to_port(skb->dev); + struct sja1105_tagger_private *priv; + struct dsa_switch *ds = dp->ds; + struct sk_buff *stampable_skb; + + priv = sja1105_tagger_private(ds); + + spin_lock(&priv->meta_lock); + + stampable_skb = priv->stampable_skb; + priv->stampable_skb = NULL; + + /* Was this a meta frame instead of the link-local + * that we were expecting? + */ + if (!stampable_skb) { + dev_err_ratelimited(ds->dev, + "Unexpected meta frame\n"); + spin_unlock(&priv->meta_lock); + return NULL; + } + + if (stampable_skb->dev != skb->dev) { + dev_err_ratelimited(ds->dev, + "Meta frame on wrong port\n"); + spin_unlock(&priv->meta_lock); + return NULL; + } + + /* Free the meta frame and give DSA the buffered stampable_skb + * for further processing up the network stack. + */ + kfree_skb(skb); + skb = stampable_skb; + sja1105_transfer_meta(skb, meta); + + spin_unlock(&priv->meta_lock); + } + + return skb; +} + +static bool sja1105_skb_has_tag_8021q(const struct sk_buff *skb) +{ + u16 tpid = ntohs(eth_hdr(skb)->h_proto); + + return tpid == ETH_P_SJA1105 || tpid == ETH_P_8021Q || + skb_vlan_tag_present(skb); +} + +static bool sja1110_skb_has_inband_control_extension(const struct sk_buff *skb) +{ + return ntohs(eth_hdr(skb)->h_proto) == ETH_P_SJA1110; +} + +/* If the VLAN in the packet is a tag_8021q one, set @source_port and + * @switch_id and strip the header. Otherwise set @vid and keep it in the + * packet. + */ +static void sja1105_vlan_rcv(struct sk_buff *skb, int *source_port, + int *switch_id, int *vbid, u16 *vid) +{ + struct vlan_ethhdr *hdr = (struct vlan_ethhdr *)skb_mac_header(skb); + u16 vlan_tci; + + if (skb_vlan_tag_present(skb)) + vlan_tci = skb_vlan_tag_get(skb); + else + vlan_tci = ntohs(hdr->h_vlan_TCI); + + if (vid_is_dsa_8021q(vlan_tci & VLAN_VID_MASK)) + return dsa_8021q_rcv(skb, source_port, switch_id, vbid); + + /* Try our best with imprecise RX */ + *vid = vlan_tci & VLAN_VID_MASK; +} + +static struct sk_buff *sja1105_rcv(struct sk_buff *skb, + struct net_device *netdev) +{ + int source_port = -1, switch_id = -1, vbid = -1; + struct sja1105_meta meta = {0}; + struct ethhdr *hdr; + bool is_link_local; + bool is_meta; + u16 vid; + + hdr = eth_hdr(skb); + is_link_local = sja1105_is_link_local(skb); + is_meta = sja1105_is_meta_frame(skb); + + if (is_link_local) { + /* Management traffic path. Switch embeds the switch ID and + * port ID into bytes of the destination MAC, courtesy of + * the incl_srcpt options. + */ + source_port = hdr->h_dest[3]; + switch_id = hdr->h_dest[4]; + } else if (is_meta) { + sja1105_meta_unpack(skb, &meta); + source_port = meta.source_port; + switch_id = meta.switch_id; + } + + /* Normal data plane traffic and link-local frames are tagged with + * a tag_8021q VLAN which we have to strip + */ + if (sja1105_skb_has_tag_8021q(skb)) { + int tmp_source_port = -1, tmp_switch_id = -1; + + sja1105_vlan_rcv(skb, &tmp_source_port, &tmp_switch_id, &vbid, + &vid); + /* Preserve the source information from the INCL_SRCPT option, + * if available. This allows us to not overwrite a valid source + * port and switch ID with zeroes when receiving link-local + * frames from a VLAN-unaware bridged port (non-zero vbid) or a + * VLAN-aware bridged port (non-zero vid). 
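/* A minimal sketch of the lookup order used further below in sja1105_rcv():
 * precise source information wins, then the VLAN-unaware bridge ID, then an
 * imprecise lookup by VLAN ID. The helper name is sketch-only; the comments
 * name the DSA functions used by the real code.
 */
static const char *sja1105_rx_lookup_kind(int source_port, int switch_id,
                                          int vbid)
{
        if (source_port != -1 && switch_id != -1)
                return "precise";       /* dsa_master_find_slave() */
        if (vbid >= 1)
                return "bridge";        /* dsa_tag_8021q_find_port_by_vbid() */
        return "imprecise-vid";         /* dsa_find_designated_bridge_port_by_vid() */
}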
Furthermore, the + * tag_8021q source port information is only trustworthy when the + * vbid is 0 (precise port). Otherwise, tmp_source_port and + * tmp_switch_id will be zeroes. + */ + if (vbid == 0 && source_port == -1) + source_port = tmp_source_port; + if (vbid == 0 && switch_id == -1) + switch_id = tmp_switch_id; + } else if (source_port == -1 && switch_id == -1) { + /* Packets with no source information have no chance of + * getting accepted, drop them straight away. + */ + return NULL; + } + + if (source_port != -1 && switch_id != -1) + skb->dev = dsa_master_find_slave(netdev, switch_id, source_port); + else if (vbid >= 1) + skb->dev = dsa_tag_8021q_find_port_by_vbid(netdev, vbid); + else + skb->dev = dsa_find_designated_bridge_port_by_vid(netdev, vid); + if (!skb->dev) { + netdev_warn(netdev, "Couldn't decode source port\n"); + return NULL; + } + + if (!is_link_local) + dsa_default_offload_fwd_mark(skb); + + return sja1105_rcv_meta_state_machine(skb, &meta, is_link_local, + is_meta); +} + +static struct sk_buff *sja1110_rcv_meta(struct sk_buff *skb, u16 rx_header) +{ + u8 *buf = dsa_etype_header_pos_rx(skb) + SJA1110_HEADER_LEN; + int switch_id = SJA1110_RX_HEADER_SWITCH_ID(rx_header); + int n_ts = SJA1110_RX_HEADER_N_TS(rx_header); + struct sja1105_tagger_data *tagger_data; + struct net_device *master = skb->dev; + struct dsa_port *cpu_dp; + struct dsa_switch *ds; + int i; + + cpu_dp = master->dsa_ptr; + ds = dsa_switch_find(cpu_dp->dst->index, switch_id); + if (!ds) { + net_err_ratelimited("%s: cannot find switch id %d\n", + master->name, switch_id); + return NULL; + } + + tagger_data = sja1105_tagger_data(ds); + if (!tagger_data->meta_tstamp_handler) + return NULL; + + for (i = 0; i <= n_ts; i++) { + u8 ts_id, source_port, dir; + u64 tstamp; + + ts_id = buf[0]; + source_port = (buf[1] & GENMASK(7, 4)) >> 4; + dir = (buf[1] & BIT(3)) >> 3; + tstamp = be64_to_cpu(*(__be64 *)(buf + 2)); + + tagger_data->meta_tstamp_handler(ds, source_port, ts_id, dir, + tstamp); + + buf += SJA1110_META_TSTAMP_SIZE; + } + + /* Discard the meta frame, we've consumed the timestamps it contained */ + return NULL; +} + +static struct sk_buff *sja1110_rcv_inband_control_extension(struct sk_buff *skb, + int *source_port, + int *switch_id, + bool *host_only) +{ + u16 rx_header; + + if (unlikely(!pskb_may_pull(skb, SJA1110_HEADER_LEN))) + return NULL; + + /* skb->data points to skb_mac_header(skb) + ETH_HLEN, which is exactly + * what we need because the caller has checked the EtherType (which is + * located 2 bytes back) and we just need a pointer to the header that + * comes afterwards. + */ + rx_header = ntohs(*(__be16 *)skb->data); + + if (rx_header & SJA1110_RX_HEADER_HOST_ONLY) + *host_only = true; + + if (rx_header & SJA1110_RX_HEADER_IS_METADATA) + return sja1110_rcv_meta(skb, rx_header); + + /* Timestamp frame, we have a trailer */ + if (rx_header & SJA1110_RX_HEADER_HAS_TRAILER) { + int start_of_padding = SJA1110_RX_HEADER_TRAILER_POS(rx_header); + u8 *rx_trailer = skb_tail_pointer(skb) - SJA1110_RX_TRAILER_LEN; + u64 *tstamp = &SJA1105_SKB_CB(skb)->tstamp; + u8 last_byte = rx_trailer[12]; + + /* The timestamp is unaligned, so we need to use packing() + * to get it + */ + packing(rx_trailer, tstamp, 63, 0, 8, UNPACK, 0); + + *source_port = SJA1110_RX_TRAILER_SRC_PORT(last_byte); + *switch_id = SJA1110_RX_TRAILER_SWITCH_ID(last_byte); + + /* skb->len counts from skb->data, while start_of_padding + * counts from the destination MAC address.
Right now skb->data + * is still as set by the DSA master, so to trim away the + * padding and trailer we need to account for the fact that + * skb->data points to skb_mac_header(skb) + ETH_HLEN. + */ + if (pskb_trim_rcsum(skb, start_of_padding - ETH_HLEN)) + return NULL; + /* Trap-to-host frame, no timestamp trailer */ + } else { + *source_port = SJA1110_RX_HEADER_SRC_PORT(rx_header); + *switch_id = SJA1110_RX_HEADER_SWITCH_ID(rx_header); + } + + /* Advance skb->data past the DSA header */ + skb_pull_rcsum(skb, SJA1110_HEADER_LEN); + + dsa_strip_etype_header(skb, SJA1110_HEADER_LEN); + + /* With skb->data in its final place, update the MAC header + * so that eth_hdr() continues to work properly. + */ + skb_set_mac_header(skb, -ETH_HLEN); + + return skb; +} + +static struct sk_buff *sja1110_rcv(struct sk_buff *skb, + struct net_device *netdev) +{ + int source_port = -1, switch_id = -1, vbid = -1; + bool host_only = false; + u16 vid = 0; + + if (sja1110_skb_has_inband_control_extension(skb)) { + skb = sja1110_rcv_inband_control_extension(skb, &source_port, + &switch_id, + &host_only); + if (!skb) + return NULL; + } + + /* Packets with in-band control extensions might still have RX VLANs */ + if (likely(sja1105_skb_has_tag_8021q(skb))) + sja1105_vlan_rcv(skb, &source_port, &switch_id, &vbid, &vid); + + if (vbid >= 1) + skb->dev = dsa_tag_8021q_find_port_by_vbid(netdev, vbid); + else if (source_port == -1 || switch_id == -1) + skb->dev = dsa_find_designated_bridge_port_by_vid(netdev, vid); + else + skb->dev = dsa_master_find_slave(netdev, switch_id, source_port); + if (!skb->dev) { + netdev_warn(netdev, "Couldn't decode source port\n"); + return NULL; + } + + if (!host_only) + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static void sja1105_flow_dissect(const struct sk_buff *skb, __be16 *proto, + int *offset) +{ + /* No tag added for management frames, all ok */ + if (unlikely(sja1105_is_link_local(skb))) + return; + + dsa_tag_generic_flow_dissect(skb, proto, offset); +} + +static void sja1110_flow_dissect(const struct sk_buff *skb, __be16 *proto, + int *offset) +{ + /* Management frames have 2 DSA tags on RX, so the needed_headroom we + * declared is fine for the generic dissector adjustment procedure.
+ */ + if (unlikely(sja1105_is_link_local(skb))) + return dsa_tag_generic_flow_dissect(skb, proto, offset); + + /* For the rest, there is a single DSA tag, the tag_8021q one */ + *offset = VLAN_HLEN; + *proto = ((__be16 *)skb->data)[(VLAN_HLEN / 2) - 1]; +} + +static void sja1105_disconnect(struct dsa_switch *ds) +{ + struct sja1105_tagger_private *priv = ds->tagger_data; + + kthread_destroy_worker(priv->xmit_worker); + kfree(priv); + ds->tagger_data = NULL; +} + +static int sja1105_connect(struct dsa_switch *ds) +{ + struct sja1105_tagger_private *priv; + struct kthread_worker *xmit_worker; + int err; + + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (!priv) + return -ENOMEM; + + spin_lock_init(&priv->meta_lock); + + xmit_worker = kthread_create_worker(0, "dsa%d:%d_xmit", + ds->dst->index, ds->index); + if (IS_ERR(xmit_worker)) { + err = PTR_ERR(xmit_worker); + kfree(priv); + return err; + } + + priv->xmit_worker = xmit_worker; + ds->tagger_data = priv; + + return 0; +} + +static const struct dsa_device_ops sja1105_netdev_ops = { + .name = "sja1105", + .proto = DSA_TAG_PROTO_SJA1105, + .xmit = sja1105_xmit, + .rcv = sja1105_rcv, + .connect = sja1105_connect, + .disconnect = sja1105_disconnect, + .needed_headroom = VLAN_HLEN, + .flow_dissect = sja1105_flow_dissect, + .promisc_on_master = true, +}; + +DSA_TAG_DRIVER(sja1105_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1105); + +static const struct dsa_device_ops sja1110_netdev_ops = { + .name = "sja1110", + .proto = DSA_TAG_PROTO_SJA1110, + .xmit = sja1110_xmit, + .rcv = sja1110_rcv, + .connect = sja1105_connect, + .disconnect = sja1105_disconnect, + .flow_dissect = sja1110_flow_dissect, + .needed_headroom = SJA1110_HEADER_LEN + VLAN_HLEN, + .needed_tailroom = SJA1110_RX_TRAILER_LEN + SJA1110_MAX_PADDING_LEN, +}; + +DSA_TAG_DRIVER(sja1110_netdev_ops); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_SJA1110); + +static struct dsa_tag_driver *sja1105_tag_driver_array[] = { + &DSA_TAG_DRIVER_NAME(sja1105_netdev_ops), + &DSA_TAG_DRIVER_NAME(sja1110_netdev_ops), +}; + +module_dsa_tag_drivers(sja1105_tag_driver_array); + +MODULE_LICENSE("GPL v2"); diff --git a/net/dsa/tag_trailer.c b/net/dsa/tag_trailer.c new file mode 100644 index 000000000..5749ba85c --- /dev/null +++ b/net/dsa/tag_trailer.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * net/dsa/tag_trailer.c - Trailer tag format handling + * Copyright (c) 2008-2009 Marvell Semiconductor + */ + +#include +#include +#include + +#include "dsa_priv.h" + +static struct sk_buff *trailer_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *dp = dsa_slave_to_port(dev); + u8 *trailer; + + trailer = skb_put(skb, 4); + trailer[0] = 0x80; + trailer[1] = 1 << dp->index; + trailer[2] = 0x10; + trailer[3] = 0x00; + + return skb; +} + +static struct sk_buff *trailer_rcv(struct sk_buff *skb, struct net_device *dev) +{ + u8 *trailer; + int source_port; + + if (skb_linearize(skb)) + return NULL; + + trailer = skb_tail_pointer(skb) - 4; + if (trailer[0] != 0x80 || (trailer[1] & 0xf8) != 0x00 || + (trailer[2] & 0xef) != 0x00 || trailer[3] != 0x00) + return NULL; + + source_port = trailer[1] & 7; + + skb->dev = dsa_master_find_slave(dev, 0, source_port); + if (!skb->dev) + return NULL; + + if (pskb_trim_rcsum(skb, skb->len - 4)) + return NULL; + + return skb; +} + +static const struct dsa_device_ops trailer_netdev_ops = { + .name = "trailer", + .proto = DSA_TAG_PROTO_TRAILER, + .xmit = trailer_xmit, + .rcv = trailer_rcv, + .needed_tailroom = 4, +}; + 
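For context, the trailer tag handled above has a fixed 4-byte layout: the first byte is 0x80, the second byte carries the port (a destination port mask of 1 << port on transmit, and the source port number in its low three bits on receive), the third byte is 0x10 and the fourth is 0x00. The sketch below is illustrative only and not part of the patch; the function name and plain-buffer interface are hypothetical, mirroring the checks trailer_rcv() performs on an skb.

/* Illustrative sketch (not part of the patch): decode the 4-byte trailer
 * appended by trailer_xmit() and stripped by trailer_rcv() above.
 * Returns the source port (0-7), or -1 if the trailer is not recognized.
 */
static int example_decode_trailer(const unsigned char *frame, unsigned int len)
{
	const unsigned char *trailer;

	if (len < 4)
		return -1;

	/* The tag occupies the last 4 bytes of the frame */
	trailer = frame + len - 4;

	if (trailer[0] != 0x80 || (trailer[1] & 0xf8) != 0x00 ||
	    (trailer[2] & 0xef) != 0x00 || trailer[3] != 0x00)
		return -1;

	/* The low 3 bits of the second byte carry the source port */
	return trailer[1] & 7;
}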
+MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_TRAILER); + +module_dsa_tag_driver(trailer_netdev_ops); diff --git a/net/dsa/tag_xrs700x.c b/net/dsa/tag_xrs700x.c new file mode 100644 index 000000000..ff442b8af --- /dev/null +++ b/net/dsa/tag_xrs700x.c @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * XRS700x tag format handling + * Copyright (c) 2008-2009 Marvell Semiconductor + * Copyright (c) 2020 NovaTech LLC + */ + +#include + +#include "dsa_priv.h" + +static struct sk_buff *xrs700x_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct dsa_port *partner, *dp = dsa_slave_to_port(dev); + u8 *trailer; + + trailer = skb_put(skb, 1); + trailer[0] = BIT(dp->index); + + if (dp->hsr_dev) + dsa_hsr_foreach_port(partner, dp->ds, dp->hsr_dev) + if (partner != dp) + trailer[0] |= BIT(partner->index); + + return skb; +} + +static struct sk_buff *xrs700x_rcv(struct sk_buff *skb, struct net_device *dev) +{ + int source_port; + u8 *trailer; + + trailer = skb_tail_pointer(skb) - 1; + + source_port = ffs((int)trailer[0]) - 1; + + if (source_port < 0) + return NULL; + + skb->dev = dsa_master_find_slave(dev, 0, source_port); + if (!skb->dev) + return NULL; + + if (pskb_trim_rcsum(skb, skb->len - 1)) + return NULL; + + /* Frame is forwarded by hardware, don't forward in software. */ + dsa_default_offload_fwd_mark(skb); + + return skb; +} + +static const struct dsa_device_ops xrs700x_netdev_ops = { + .name = "xrs700x", + .proto = DSA_TAG_PROTO_XRS700X, + .xmit = xrs700x_xmit, + .rcv = xrs700x_rcv, + .needed_tailroom = 1, +}; + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_DSA_TAG_DRIVER(DSA_TAG_PROTO_XRS700X); + +module_dsa_tag_driver(xrs700x_netdev_ops); diff --git a/net/ethernet/Makefile b/net/ethernet/Makefile new file mode 100644 index 000000000..e03eff94e --- /dev/null +++ b/net/ethernet/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Ethernet layer. +# + +obj-y += eth.o diff --git a/net/ethernet/eth.c b/net/ethernet/eth.c new file mode 100644 index 000000000..e02daa74e --- /dev/null +++ b/net/ethernet/eth.c @@ -0,0 +1,650 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Ethernet-type device handling. + * + * Version: @(#)eth.c 1.0.7 05/25/93 + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Florian La Roche, + * Alan Cox, + * + * Fixes: + * Mr Linux : Arp problems + * Alan Cox : Generic queue tidyup (very tiny here) + * Alan Cox : eth_header ntohs should be htons + * Alan Cox : eth_rebuild_header missing an htons and + * minor other things. + * Tegge : Arp bug fixes. + * Florian : Removed many unnecessary functions, code cleanup + * and changes for new arp and skbuff. + * Alan Cox : Redid header building to reflect new format. + * Alan Cox : ARP only when compiled with CONFIG_INET + * Greg Page : 802.2 and SNAP stuff. + * Alan Cox : MAC layer pointers/new format. + * Paul Gortmaker : eth_copy_and_sum shouldn't csum padding. + * Alan Cox : Protect against forwarding explosions with + * older network drivers and IFF_ALLMULTI. + * Christer Weinigel : Better rebuild header message. + * Andrew Morton : 26Feb01: kill ether_setup() - use netdev_boot_setup(). 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * eth_header - create the Ethernet header + * @skb: buffer to alter + * @dev: source device + * @type: Ethernet type field + * @daddr: destination address (NULL leave destination address) + * @saddr: source address (NULL use device source address) + * @len: packet length (<= skb->len) + * + * + * Set the protocol type. For a packet of type ETH_P_802_3/2 we put the length + * in here instead. + */ +int eth_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + struct ethhdr *eth = skb_push(skb, ETH_HLEN); + + if (type != ETH_P_802_3 && type != ETH_P_802_2) + eth->h_proto = htons(type); + else + eth->h_proto = htons(len); + + /* + * Set the source hardware address. + */ + + if (!saddr) + saddr = dev->dev_addr; + memcpy(eth->h_source, saddr, ETH_ALEN); + + if (daddr) { + memcpy(eth->h_dest, daddr, ETH_ALEN); + return ETH_HLEN; + } + + /* + * Anyway, the loopback-device should never use this function... + */ + + if (dev->flags & (IFF_LOOPBACK | IFF_NOARP)) { + eth_zero_addr(eth->h_dest); + return ETH_HLEN; + } + + return -ETH_HLEN; +} +EXPORT_SYMBOL(eth_header); + +/** + * eth_get_headlen - determine the length of header for an ethernet frame + * @dev: pointer to network device + * @data: pointer to start of frame + * @len: total length of frame + * + * Make a best effort attempt to pull the length for all of the headers for + * a given frame in a linear buffer. + */ +u32 eth_get_headlen(const struct net_device *dev, const void *data, u32 len) +{ + const unsigned int flags = FLOW_DISSECTOR_F_PARSE_1ST_FRAG; + const struct ethhdr *eth = (const struct ethhdr *)data; + struct flow_keys_basic keys; + + /* this should never happen, but better safe than sorry */ + if (unlikely(len < sizeof(*eth))) + return len; + + /* parse any remaining L2/L3 headers, check for L4 */ + if (!skb_flow_dissect_flow_keys_basic(dev_net(dev), NULL, &keys, data, + eth->h_proto, sizeof(*eth), + len, flags)) + return max_t(u32, keys.control.thoff, sizeof(*eth)); + + /* parse for any L4 headers */ + return min_t(u32, __skb_get_poff(NULL, data, &keys, len), len); +} +EXPORT_SYMBOL(eth_get_headlen); + +/** + * eth_type_trans - determine the packet's protocol ID. + * @skb: received socket data + * @dev: receiving network device + * + * The rule here is that we + * assume 802.3 if the type field is short enough to be a length. + * This is normal practice and works for any 'now in use' protocol. 
+ */ +__be16 eth_type_trans(struct sk_buff *skb, struct net_device *dev) +{ + unsigned short _service_access_point; + const unsigned short *sap; + const struct ethhdr *eth; + + skb->dev = dev; + skb_reset_mac_header(skb); + + eth = (struct ethhdr *)skb->data; + skb_pull_inline(skb, ETH_HLEN); + + if (unlikely(!ether_addr_equal_64bits(eth->h_dest, + dev->dev_addr))) { + if (unlikely(is_multicast_ether_addr_64bits(eth->h_dest))) { + if (ether_addr_equal_64bits(eth->h_dest, dev->broadcast)) + skb->pkt_type = PACKET_BROADCAST; + else + skb->pkt_type = PACKET_MULTICAST; + } else { + skb->pkt_type = PACKET_OTHERHOST; + } + } + + /* + * Some variants of DSA tagging don't have an ethertype field + * at all, so we check here whether one of those tagging + * variants has been configured on the receiving interface, + * and if so, set skb->protocol without looking at the packet. + */ + if (unlikely(netdev_uses_dsa(dev))) + return htons(ETH_P_XDSA); + + if (likely(eth_proto_is_802_3(eth->h_proto))) + return eth->h_proto; + + /* + * This is a magic hack to spot IPX packets. Older Novell breaks + * the protocol design and runs IPX over 802.3 without an 802.2 LLC + * layer. We look for FFFF which isn't a used 802.2 SSAP/DSAP. This + * won't work for fault tolerant netware but does for the rest. + */ + sap = skb_header_pointer(skb, 0, sizeof(*sap), &_service_access_point); + if (sap && *sap == 0xFFFF) + return htons(ETH_P_802_3); + + /* + * Real 802.2 LLC + */ + return htons(ETH_P_802_2); +} +EXPORT_SYMBOL(eth_type_trans); + +/** + * eth_header_parse - extract hardware address from packet + * @skb: packet to extract header from + * @haddr: destination buffer + */ +int eth_header_parse(const struct sk_buff *skb, unsigned char *haddr) +{ + const struct ethhdr *eth = eth_hdr(skb); + memcpy(haddr, eth->h_source, ETH_ALEN); + return ETH_ALEN; +} +EXPORT_SYMBOL(eth_header_parse); + +/** + * eth_header_cache - fill cache entry from neighbour + * @neigh: source neighbour + * @hh: destination cache entry + * @type: Ethernet type field + * + * Create an Ethernet header template from the neighbour. + */ +int eth_header_cache(const struct neighbour *neigh, struct hh_cache *hh, __be16 type) +{ + struct ethhdr *eth; + const struct net_device *dev = neigh->dev; + + eth = (struct ethhdr *) + (((u8 *) hh->hh_data) + (HH_DATA_OFF(sizeof(*eth)))); + + if (type == htons(ETH_P_802_3)) + return -1; + + eth->h_proto = type; + memcpy(eth->h_source, dev->dev_addr, ETH_ALEN); + memcpy(eth->h_dest, neigh->ha, ETH_ALEN); + + /* Pairs with READ_ONCE() in neigh_resolve_output(), + * neigh_hh_output() and neigh_update_hhs(). + */ + smp_store_release(&hh->hh_len, ETH_HLEN); + + return 0; +} +EXPORT_SYMBOL(eth_header_cache); + +/** + * eth_header_cache_update - update cache entry + * @hh: destination cache entry + * @dev: network device + * @haddr: new hardware address + * + * Called by Address Resolution module to notify changes in address. 
+ */ +void eth_header_cache_update(struct hh_cache *hh, + const struct net_device *dev, + const unsigned char *haddr) +{ + memcpy(((u8 *) hh->hh_data) + HH_DATA_OFF(sizeof(struct ethhdr)), + haddr, ETH_ALEN); +} +EXPORT_SYMBOL(eth_header_cache_update); + +/** + * eth_header_parse_protocol - extract protocol from L2 header + * @skb: packet to extract protocol from + */ +__be16 eth_header_parse_protocol(const struct sk_buff *skb) +{ + const struct ethhdr *eth = eth_hdr(skb); + + return eth->h_proto; +} +EXPORT_SYMBOL(eth_header_parse_protocol); + +/** + * eth_prepare_mac_addr_change - prepare for mac change + * @dev: network device + * @p: socket address + */ +int eth_prepare_mac_addr_change(struct net_device *dev, void *p) +{ + struct sockaddr *addr = p; + + if (!(dev->priv_flags & IFF_LIVE_ADDR_CHANGE) && netif_running(dev)) + return -EBUSY; + if (!is_valid_ether_addr(addr->sa_data)) + return -EADDRNOTAVAIL; + return 0; +} +EXPORT_SYMBOL(eth_prepare_mac_addr_change); + +/** + * eth_commit_mac_addr_change - commit mac change + * @dev: network device + * @p: socket address + */ +void eth_commit_mac_addr_change(struct net_device *dev, void *p) +{ + struct sockaddr *addr = p; + + eth_hw_addr_set(dev, addr->sa_data); +} +EXPORT_SYMBOL(eth_commit_mac_addr_change); + +/** + * eth_mac_addr - set new Ethernet hardware address + * @dev: network device + * @p: socket address + * + * Change hardware address of device. + * + * This doesn't change hardware matching, so needs to be overridden + * for most real devices. + */ +int eth_mac_addr(struct net_device *dev, void *p) +{ + int ret; + + ret = eth_prepare_mac_addr_change(dev, p); + if (ret < 0) + return ret; + eth_commit_mac_addr_change(dev, p); + return 0; +} +EXPORT_SYMBOL(eth_mac_addr); + +int eth_validate_addr(struct net_device *dev) +{ + if (!is_valid_ether_addr(dev->dev_addr)) + return -EADDRNOTAVAIL; + + return 0; +} +EXPORT_SYMBOL(eth_validate_addr); + +const struct header_ops eth_header_ops ____cacheline_aligned = { + .create = eth_header, + .parse = eth_header_parse, + .cache = eth_header_cache, + .cache_update = eth_header_cache_update, + .parse_protocol = eth_header_parse_protocol, +}; + +/** + * ether_setup - setup Ethernet network device + * @dev: network device + * + * Fill in the fields of the device structure with Ethernet-generic values. + */ +void ether_setup(struct net_device *dev) +{ + dev->header_ops = ð_header_ops; + dev->type = ARPHRD_ETHER; + dev->hard_header_len = ETH_HLEN; + dev->min_header_len = ETH_HLEN; + dev->mtu = ETH_DATA_LEN; + dev->min_mtu = ETH_MIN_MTU; + dev->max_mtu = ETH_DATA_LEN; + dev->addr_len = ETH_ALEN; + dev->tx_queue_len = DEFAULT_TX_QUEUE_LEN; + dev->flags = IFF_BROADCAST|IFF_MULTICAST; + dev->priv_flags |= IFF_TX_SKB_SHARING; + + eth_broadcast_addr(dev->broadcast); + +} +EXPORT_SYMBOL(ether_setup); + +/** + * alloc_etherdev_mqs - Allocates and sets up an Ethernet device + * @sizeof_priv: Size of additional driver-private structure to be allocated + * for this Ethernet device + * @txqs: The number of TX queues this device has. + * @rxqs: The number of RX queues this device has. + * + * Fill in the fields of the device structure with Ethernet-generic + * values. Basically does everything except registering the device. + * + * Constructs a new net device, complete with a private data area of + * size (sizeof_priv). A 32-byte (not bit) alignment is enforced for + * this private data area. 
+ */ + +struct net_device *alloc_etherdev_mqs(int sizeof_priv, unsigned int txqs, + unsigned int rxqs) +{ + return alloc_netdev_mqs(sizeof_priv, "eth%d", NET_NAME_ENUM, + ether_setup, txqs, rxqs); +} +EXPORT_SYMBOL(alloc_etherdev_mqs); + +ssize_t sysfs_format_mac(char *buf, const unsigned char *addr, int len) +{ + return scnprintf(buf, PAGE_SIZE, "%*phC\n", len, addr); +} +EXPORT_SYMBOL(sysfs_format_mac); + +struct sk_buff *eth_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + const struct packet_offload *ptype; + unsigned int hlen, off_eth; + struct sk_buff *pp = NULL; + struct ethhdr *eh, *eh2; + struct sk_buff *p; + __be16 type; + int flush = 1; + + off_eth = skb_gro_offset(skb); + hlen = off_eth + sizeof(*eh); + eh = skb_gro_header(skb, hlen, off_eth); + if (unlikely(!eh)) + goto out; + + flush = 0; + + list_for_each_entry(p, head, list) { + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + eh2 = (struct ethhdr *)(p->data + off_eth); + if (compare_ether_header(eh, eh2)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + } + + type = eh->h_proto; + + ptype = gro_find_receive_by_type(type); + if (ptype == NULL) { + flush = 1; + goto out; + } + + skb_gro_pull(skb, sizeof(*eh)); + skb_gro_postpull_rcsum(skb, eh, sizeof(*eh)); + + pp = indirect_call_gro_receive_inet(ptype->callbacks.gro_receive, + ipv6_gro_receive, inet_gro_receive, + head, skb); + +out: + skb_gro_flush_final(skb, pp, flush); + + return pp; +} +EXPORT_SYMBOL(eth_gro_receive); + +int eth_gro_complete(struct sk_buff *skb, int nhoff) +{ + struct ethhdr *eh = (struct ethhdr *)(skb->data + nhoff); + __be16 type = eh->h_proto; + struct packet_offload *ptype; + int err = -ENOSYS; + + if (skb->encapsulation) + skb_set_inner_mac_header(skb, nhoff); + + ptype = gro_find_complete_by_type(type); + if (ptype != NULL) + err = INDIRECT_CALL_INET(ptype->callbacks.gro_complete, + ipv6_gro_complete, inet_gro_complete, + skb, nhoff + sizeof(*eh)); + + return err; +} +EXPORT_SYMBOL(eth_gro_complete); + +static struct packet_offload eth_packet_offload __read_mostly = { + .type = cpu_to_be16(ETH_P_TEB), + .priority = 10, + .callbacks = { + .gro_receive = eth_gro_receive, + .gro_complete = eth_gro_complete, + }, +}; + +static int __init eth_offload_init(void) +{ + dev_add_offload(ð_packet_offload); + + return 0; +} + +fs_initcall(eth_offload_init); + +unsigned char * __weak arch_get_platform_mac_address(void) +{ + return NULL; +} + +int eth_platform_get_mac_address(struct device *dev, u8 *mac_addr) +{ + unsigned char *addr; + int ret; + + ret = of_get_mac_address(dev->of_node, mac_addr); + if (!ret) + return 0; + + addr = arch_get_platform_mac_address(); + if (!addr) + return -ENODEV; + + ether_addr_copy(mac_addr, addr); + + return 0; +} +EXPORT_SYMBOL(eth_platform_get_mac_address); + +/** + * platform_get_ethdev_address - Set netdev's MAC address from a given device + * @dev: Pointer to the device + * @netdev: Pointer to netdev to write the address to + * + * Wrapper around eth_platform_get_mac_address() which writes the address + * directly to netdev->dev_addr. + */ +int platform_get_ethdev_address(struct device *dev, struct net_device *netdev) +{ + u8 addr[ETH_ALEN] __aligned(2); + int ret; + + ret = eth_platform_get_mac_address(dev, addr); + if (!ret) + eth_hw_addr_set(netdev, addr); + return ret; +} +EXPORT_SYMBOL(platform_get_ethdev_address); + +/** + * nvmem_get_mac_address - Obtain the MAC address from an nvmem cell named + * 'mac-address' associated with given device. 
+ * + * @dev: Device with which the mac-address cell is associated. + * @addrbuf: Buffer to which the MAC address will be copied on success. + * + * Returns 0 on success or a negative error number on failure. + */ +int nvmem_get_mac_address(struct device *dev, void *addrbuf) +{ + struct nvmem_cell *cell; + const void *mac; + size_t len; + + cell = nvmem_cell_get(dev, "mac-address"); + if (IS_ERR(cell)) + return PTR_ERR(cell); + + mac = nvmem_cell_read(cell, &len); + nvmem_cell_put(cell); + + if (IS_ERR(mac)) + return PTR_ERR(mac); + + if (len != ETH_ALEN || !is_valid_ether_addr(mac)) { + kfree(mac); + return -EINVAL; + } + + ether_addr_copy(addrbuf, mac); + kfree(mac); + + return 0; +} + +static int fwnode_get_mac_addr(struct fwnode_handle *fwnode, + const char *name, char *addr) +{ + int ret; + + ret = fwnode_property_read_u8_array(fwnode, name, addr, ETH_ALEN); + if (ret) + return ret; + + if (!is_valid_ether_addr(addr)) + return -EINVAL; + return 0; +} + +/** + * fwnode_get_mac_address - Get the MAC from the firmware node + * @fwnode: Pointer to the firmware node + * @addr: Address of buffer to store the MAC in + * + * Search the firmware node for the best MAC address to use. 'mac-address' is + * checked first, because that is supposed to contain the "most recent" MAC + * address. If that isn't set, then 'local-mac-address' is checked next, + * because that is the default address. If that isn't set, then the obsolete + * 'address' is checked, just in case we're using an old device tree. + * + * Note that the 'address' property is supposed to contain a virtual address of + * the register set, but some DTS files have redefined that property to be the + * MAC address. + * + * All-zero MAC addresses are rejected, because those could be properties that + * exist in the firmware tables, but were not updated by the firmware. For + * example, the DTS could define 'mac-address' and 'local-mac-address', with + * zero MAC addresses. Some older U-Boots only initialized 'local-mac-address'. + * In this case, the real MAC is in 'local-mac-address', and 'mac-address' + * exists but is all zeros. + */ +int fwnode_get_mac_address(struct fwnode_handle *fwnode, char *addr) +{ + if (!fwnode_get_mac_addr(fwnode, "mac-address", addr) || + !fwnode_get_mac_addr(fwnode, "local-mac-address", addr) || + !fwnode_get_mac_addr(fwnode, "address", addr)) + return 0; + + return -ENOENT; +} +EXPORT_SYMBOL(fwnode_get_mac_address); + +/** + * device_get_mac_address - Get the MAC for a given device + * @dev: Pointer to the device + * @addr: Address of buffer to store the MAC in + */ +int device_get_mac_address(struct device *dev, char *addr) +{ + return fwnode_get_mac_address(dev_fwnode(dev), addr); +} +EXPORT_SYMBOL(device_get_mac_address); + +/** + * device_get_ethdev_address - Set netdev's MAC address from a given device + * @dev: Pointer to the device + * @netdev: Pointer to netdev to write the address to + * + * Wrapper around device_get_mac_address() which writes the address + * directly to netdev->dev_addr.
+ */ +int device_get_ethdev_address(struct device *dev, struct net_device *netdev) +{ + u8 addr[ETH_ALEN]; + int ret; + + ret = device_get_mac_address(dev, addr); + if (!ret) + eth_hw_addr_set(netdev, addr); + return ret; +} +EXPORT_SYMBOL(device_get_ethdev_address); diff --git a/net/ethtool/Makefile b/net/ethtool/Makefile new file mode 100644 index 000000000..72ab09442 --- /dev/null +++ b/net/ethtool/Makefile @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0-only + +obj-y += ioctl.o common.o + +obj-$(CONFIG_ETHTOOL_NETLINK) += ethtool_nl.o + +ethtool_nl-y := netlink.o bitset.o strset.o linkinfo.o linkmodes.o \ + linkstate.o debug.o wol.o features.o privflags.o rings.o \ + channels.o coalesce.o pause.o eee.o tsinfo.o cabletest.o \ + tunnels.o fec.o eeprom.o stats.o phc_vclocks.o module.o \ + pse-pd.o diff --git a/net/ethtool/bitset.c b/net/ethtool/bitset.c new file mode 100644 index 000000000..0515d6604 --- /dev/null +++ b/net/ethtool/bitset.c @@ -0,0 +1,833 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include "netlink.h" +#include "bitset.h" + +/* Some bitmaps are internally represented as an array of unsigned long, some + * as an array of u32 (some even as single u32 for now). To avoid the need of + * wrappers on caller side, we provide two sets of functions: those with "32" + * suffix in their names expect u32 based bitmaps, those without it expect + * unsigned long bitmaps. + */ + +static u32 ethnl_lower_bits(unsigned int n) +{ + return ~(u32)0 >> (32 - n % 32); +} + +static u32 ethnl_upper_bits(unsigned int n) +{ + return ~(u32)0 << (n % 32); +} + +/** + * ethnl_bitmap32_clear() - Clear u32 based bitmap + * @dst: bitmap to clear + * @start: beginning of the interval + * @end: end of the interval + * @mod: set if bitmap was modified + * + * Clear @nbits bits of a bitmap with indices @start <= i < @end + */ +static void ethnl_bitmap32_clear(u32 *dst, unsigned int start, unsigned int end, + bool *mod) +{ + unsigned int start_word = start / 32; + unsigned int end_word = end / 32; + unsigned int i; + u32 mask; + + if (end <= start) + return; + + if (start % 32) { + mask = ethnl_upper_bits(start); + if (end_word == start_word) { + mask &= ethnl_lower_bits(end); + if (dst[start_word] & mask) { + dst[start_word] &= ~mask; + *mod = true; + } + return; + } + if (dst[start_word] & mask) { + dst[start_word] &= ~mask; + *mod = true; + } + start_word++; + } + + for (i = start_word; i < end_word; i++) { + if (dst[i]) { + dst[i] = 0; + *mod = true; + } + } + if (end % 32) { + mask = ethnl_lower_bits(end); + if (dst[end_word] & mask) { + dst[end_word] &= ~mask; + *mod = true; + } + } +} + +/** + * ethnl_bitmap32_not_zero() - Check if any bit is set in an interval + * @map: bitmap to test + * @start: beginning of the interval + * @end: end of the interval + * + * Return: true if there is non-zero bit with index @start <= i < @end, + * false if the whole interval is zero + */ +static bool ethnl_bitmap32_not_zero(const u32 *map, unsigned int start, + unsigned int end) +{ + unsigned int start_word = start / 32; + unsigned int end_word = end / 32; + u32 mask; + + if (end <= start) + return true; + + if (start % 32) { + mask = ethnl_upper_bits(start); + if (end_word == start_word) { + mask &= ethnl_lower_bits(end); + return map[start_word] & mask; + } + if (map[start_word] & mask) + return true; + start_word++; + } + + if (memchr_inv(map + start_word, '\0', + (end_word - start_word) * sizeof(u32))) + return true; + if (end % 32 == 0) + return false; + return map[end_word] &
ethnl_lower_bits(end); +} + +/** + * ethnl_bitmap32_update() - Modify u32 based bitmap according to value/mask + * pair + * @dst: bitmap to update + * @nbits: bit size of the bitmap + * @value: values to set + * @mask: mask of bits to set + * @mod: set to true if bitmap is modified, preserve if not + * + * Set bits in @dst bitmap which are set in @mask to values from @value, leave + * the rest untouched. If destination bitmap was modified, set @mod to true, + * leave as it is if not. + */ +static void ethnl_bitmap32_update(u32 *dst, unsigned int nbits, + const u32 *value, const u32 *mask, bool *mod) +{ + while (nbits > 0) { + u32 real_mask = mask ? *mask : ~(u32)0; + u32 new_value; + + if (nbits < 32) + real_mask &= ethnl_lower_bits(nbits); + new_value = (*dst & ~real_mask) | (*value & real_mask); + if (new_value != *dst) { + *dst = new_value; + *mod = true; + } + + if (nbits <= 32) + break; + dst++; + nbits -= 32; + value++; + if (mask) + mask++; + } +} + +static bool ethnl_bitmap32_test_bit(const u32 *map, unsigned int index) +{ + return map[index / 32] & (1U << (index % 32)); +} + +/** + * ethnl_bitset32_size() - Calculate size of bitset nested attribute + * @val: value bitmap (u32 based) + * @mask: mask bitmap (u32 based, optional) + * @nbits: bit length of the bitset + * @names: array of bit names (optional) + * @compact: assume compact format for output + * + * Estimate length of netlink attribute composed by a later call to + * ethnl_put_bitset32() with the same arguments. + * + * Return: negative error code or attribute length estimate + */ +int ethnl_bitset32_size(const u32 *val, const u32 *mask, unsigned int nbits, + ethnl_string_array_t names, bool compact) +{ + unsigned int len = 0; + + /* list flag */ + if (!mask) + len += nla_total_size(sizeof(u32)); + /* size */ + len += nla_total_size(sizeof(u32)); + + if (compact) { + unsigned int nwords = DIV_ROUND_UP(nbits, 32); + + /* value, mask */ + len += (mask ? 2 : 1) * nla_total_size(nwords * sizeof(u32)); + } else { + unsigned int bits_len = 0; + unsigned int bit_len, i; + + for (i = 0; i < nbits; i++) { + const char *name = names ? names[i] : NULL; + + if (!ethnl_bitmap32_test_bit(mask ?: val, i)) + continue; + /* index */ + bit_len = nla_total_size(sizeof(u32)); + /* name */ + if (name) + bit_len += ethnl_strz_size(name); + /* value */ + if (mask && ethnl_bitmap32_test_bit(val, i)) + bit_len += nla_total_size(0); + + /* bit nest */ + bits_len += nla_total_size(bit_len); + } + /* bits nest */ + len += nla_total_size(bits_len); + } + + /* outermost nest */ + return nla_total_size(len); +} + +/** + * ethnl_put_bitset32() - Put a bitset nest into a message + * @skb: skb with the message + * @attrtype: attribute type for the bitset nest + * @val: value bitmap (u32 based) + * @mask: mask bitmap (u32 based, optional) + * @nbits: bit length of the bitset + * @names: array of bit names (optional) + * @compact: use compact format for the output + * + * Compose a nested attribute representing a bitset. If @mask is null, simple + * bitmap (bit list) is created, if @mask is provided, represent a value/mask + * pair. Bit names are only used in verbose mode and when provided by caller.
+ * + * Return: 0 on success, negative error value on error + */ +int ethnl_put_bitset32(struct sk_buff *skb, int attrtype, const u32 *val, + const u32 *mask, unsigned int nbits, + ethnl_string_array_t names, bool compact) +{ + struct nlattr *nest; + struct nlattr *attr; + + nest = nla_nest_start(skb, attrtype); + if (!nest) + return -EMSGSIZE; + + if (!mask && nla_put_flag(skb, ETHTOOL_A_BITSET_NOMASK)) + goto nla_put_failure; + if (nla_put_u32(skb, ETHTOOL_A_BITSET_SIZE, nbits)) + goto nla_put_failure; + if (compact) { + unsigned int nwords = DIV_ROUND_UP(nbits, 32); + unsigned int nbytes = nwords * sizeof(u32); + u32 *dst; + + attr = nla_reserve(skb, ETHTOOL_A_BITSET_VALUE, nbytes); + if (!attr) + goto nla_put_failure; + dst = nla_data(attr); + memcpy(dst, val, nbytes); + if (nbits % 32) + dst[nwords - 1] &= ethnl_lower_bits(nbits); + + if (mask) { + attr = nla_reserve(skb, ETHTOOL_A_BITSET_MASK, nbytes); + if (!attr) + goto nla_put_failure; + dst = nla_data(attr); + memcpy(dst, mask, nbytes); + if (nbits % 32) + dst[nwords - 1] &= ethnl_lower_bits(nbits); + } + } else { + struct nlattr *bits; + unsigned int i; + + bits = nla_nest_start(skb, ETHTOOL_A_BITSET_BITS); + if (!bits) + goto nla_put_failure; + for (i = 0; i < nbits; i++) { + const char *name = names ? names[i] : NULL; + + if (!ethnl_bitmap32_test_bit(mask ?: val, i)) + continue; + attr = nla_nest_start(skb, ETHTOOL_A_BITSET_BITS_BIT); + if (!attr) + goto nla_put_failure; + if (nla_put_u32(skb, ETHTOOL_A_BITSET_BIT_INDEX, i)) + goto nla_put_failure; + if (name && + ethnl_put_strz(skb, ETHTOOL_A_BITSET_BIT_NAME, name)) + goto nla_put_failure; + if (mask && ethnl_bitmap32_test_bit(val, i) && + nla_put_flag(skb, ETHTOOL_A_BITSET_BIT_VALUE)) + goto nla_put_failure; + nla_nest_end(skb, attr); + } + nla_nest_end(skb, bits); + } + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static const struct nla_policy bitset_policy[] = { + [ETHTOOL_A_BITSET_NOMASK] = { .type = NLA_FLAG }, + [ETHTOOL_A_BITSET_SIZE] = NLA_POLICY_MAX(NLA_U32, + ETHNL_MAX_BITSET_SIZE), + [ETHTOOL_A_BITSET_BITS] = { .type = NLA_NESTED }, + [ETHTOOL_A_BITSET_VALUE] = { .type = NLA_BINARY }, + [ETHTOOL_A_BITSET_MASK] = { .type = NLA_BINARY }, +}; + +static const struct nla_policy bit_policy[] = { + [ETHTOOL_A_BITSET_BIT_INDEX] = { .type = NLA_U32 }, + [ETHTOOL_A_BITSET_BIT_NAME] = { .type = NLA_NUL_STRING }, + [ETHTOOL_A_BITSET_BIT_VALUE] = { .type = NLA_FLAG }, +}; + +/** + * ethnl_bitset_is_compact() - check if bitset attribute represents a compact + * bitset + * @bitset: nested attribute representing a bitset + * @compact: pointer for return value + * + * Return: 0 on success, negative error code on failure + */ +int ethnl_bitset_is_compact(const struct nlattr *bitset, bool *compact) +{ + struct nlattr *tb[ARRAY_SIZE(bitset_policy)]; + int ret; + + ret = nla_parse_nested(tb, ARRAY_SIZE(bitset_policy) - 1, bitset, + bitset_policy, NULL); + if (ret < 0) + return ret; + + if (tb[ETHTOOL_A_BITSET_BITS]) { + if (tb[ETHTOOL_A_BITSET_VALUE] || tb[ETHTOOL_A_BITSET_MASK]) + return -EINVAL; + *compact = false; + return 0; + } + if (!tb[ETHTOOL_A_BITSET_SIZE] || !tb[ETHTOOL_A_BITSET_VALUE]) + return -EINVAL; + + *compact = true; + return 0; +} + +/** + * ethnl_name_to_idx() - look up string index for a name + * @names: array of ETH_GSTRING_LEN sized strings + * @n_names: number of strings in the array + * @name: name to look up + * + * Return: index of the string if found, -ENOENT if not found + */ +static int 
ethnl_name_to_idx(ethnl_string_array_t names, unsigned int n_names, + const char *name) +{ + unsigned int i; + + if (!names) + return -ENOENT; + + for (i = 0; i < n_names; i++) { + /* names[i] may not be null terminated */ + if (!strncmp(names[i], name, ETH_GSTRING_LEN) && + strlen(name) <= ETH_GSTRING_LEN) + return i; + } + + return -ENOENT; +} + +static int ethnl_parse_bit(unsigned int *index, bool *val, unsigned int nbits, + const struct nlattr *bit_attr, bool no_mask, + ethnl_string_array_t names, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(bit_policy)]; + int ret, idx; + + ret = nla_parse_nested(tb, ARRAY_SIZE(bit_policy) - 1, bit_attr, + bit_policy, extack); + if (ret < 0) + return ret; + + if (tb[ETHTOOL_A_BITSET_BIT_INDEX]) { + const char *name; + + idx = nla_get_u32(tb[ETHTOOL_A_BITSET_BIT_INDEX]); + if (idx >= nbits) { + NL_SET_ERR_MSG_ATTR(extack, + tb[ETHTOOL_A_BITSET_BIT_INDEX], + "bit index too high"); + return -EOPNOTSUPP; + } + name = names ? names[idx] : NULL; + if (tb[ETHTOOL_A_BITSET_BIT_NAME] && name && + strncmp(nla_data(tb[ETHTOOL_A_BITSET_BIT_NAME]), name, + nla_len(tb[ETHTOOL_A_BITSET_BIT_NAME]))) { + NL_SET_ERR_MSG_ATTR(extack, bit_attr, + "bit index and name mismatch"); + return -EINVAL; + } + } else if (tb[ETHTOOL_A_BITSET_BIT_NAME]) { + idx = ethnl_name_to_idx(names, nbits, + nla_data(tb[ETHTOOL_A_BITSET_BIT_NAME])); + if (idx < 0) { + NL_SET_ERR_MSG_ATTR(extack, + tb[ETHTOOL_A_BITSET_BIT_NAME], + "bit name not found"); + return -EOPNOTSUPP; + } + } else { + NL_SET_ERR_MSG_ATTR(extack, bit_attr, + "neither bit index nor name specified"); + return -EINVAL; + } + + *index = idx; + *val = no_mask || tb[ETHTOOL_A_BITSET_BIT_VALUE]; + return 0; +} + +static int +ethnl_update_bitset32_verbose(u32 *bitmap, unsigned int nbits, + const struct nlattr *attr, struct nlattr **tb, + ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod) +{ + struct nlattr *bit_attr; + bool no_mask; + int rem; + int ret; + + if (tb[ETHTOOL_A_BITSET_VALUE]) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_VALUE], + "value only allowed in compact bitset"); + return -EINVAL; + } + if (tb[ETHTOOL_A_BITSET_MASK]) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_MASK], + "mask only allowed in compact bitset"); + return -EINVAL; + } + + no_mask = tb[ETHTOOL_A_BITSET_NOMASK]; + if (no_mask) + ethnl_bitmap32_clear(bitmap, 0, nbits, mod); + + nla_for_each_nested(bit_attr, tb[ETHTOOL_A_BITSET_BITS], rem) { + bool old_val, new_val; + unsigned int idx; + + if (nla_type(bit_attr) != ETHTOOL_A_BITSET_BITS_BIT) { + NL_SET_ERR_MSG_ATTR(extack, bit_attr, + "only ETHTOOL_A_BITSET_BITS_BIT allowed in ETHTOOL_A_BITSET_BITS"); + return -EINVAL; + } + ret = ethnl_parse_bit(&idx, &new_val, nbits, bit_attr, no_mask, + names, extack); + if (ret < 0) + return ret; + old_val = bitmap[idx / 32] & ((u32)1 << (idx % 32)); + if (new_val != old_val) { + if (new_val) + bitmap[idx / 32] |= ((u32)1 << (idx % 32)); + else + bitmap[idx / 32] &= ~((u32)1 << (idx % 32)); + *mod = true; + } + } + + return 0; +} + +static int ethnl_compact_sanity_checks(unsigned int nbits, + const struct nlattr *nest, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + bool no_mask = tb[ETHTOOL_A_BITSET_NOMASK]; + unsigned int attr_nbits, attr_nwords; + const struct nlattr *test_attr; + + if (no_mask && tb[ETHTOOL_A_BITSET_MASK]) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_MASK], + "mask not allowed in list bitset"); + return -EINVAL; + } + if (!tb[ETHTOOL_A_BITSET_SIZE]) { + 
NL_SET_ERR_MSG_ATTR(extack, nest, + "missing size in compact bitset"); + return -EINVAL; + } + if (!tb[ETHTOOL_A_BITSET_VALUE]) { + NL_SET_ERR_MSG_ATTR(extack, nest, + "missing value in compact bitset"); + return -EINVAL; + } + if (!no_mask && !tb[ETHTOOL_A_BITSET_MASK]) { + NL_SET_ERR_MSG_ATTR(extack, nest, + "missing mask in compact nonlist bitset"); + return -EINVAL; + } + + attr_nbits = nla_get_u32(tb[ETHTOOL_A_BITSET_SIZE]); + attr_nwords = DIV_ROUND_UP(attr_nbits, 32); + if (nla_len(tb[ETHTOOL_A_BITSET_VALUE]) != attr_nwords * sizeof(u32)) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_VALUE], + "bitset value length does not match size"); + return -EINVAL; + } + if (tb[ETHTOOL_A_BITSET_MASK] && + nla_len(tb[ETHTOOL_A_BITSET_MASK]) != attr_nwords * sizeof(u32)) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_MASK], + "bitset mask length does not match size"); + return -EINVAL; + } + if (attr_nbits <= nbits) + return 0; + + test_attr = no_mask ? tb[ETHTOOL_A_BITSET_VALUE] : + tb[ETHTOOL_A_BITSET_MASK]; + if (ethnl_bitmap32_not_zero(nla_data(test_attr), nbits, attr_nbits)) { + NL_SET_ERR_MSG_ATTR(extack, test_attr, + "cannot modify bits past kernel bitset size"); + return -EINVAL; + } + return 0; +} + +/** + * ethnl_update_bitset32() - Apply a bitset nest to a u32 based bitmap + * @bitmap: bitmap to update + * @nbits: size of the updated bitmap in bits + * @attr: nest attribute to parse and apply + * @names: array of bit names; may be null for compact format + * @extack: extack for error reporting + * @mod: set this to true if bitmap is modified, leave as it is if not + * + * Apply bitset nested attribute to a bitmap. If the attribute represents + * a bit list, @bitmap is set to its contents; otherwise, bits in mask are + * set to values from value. Bitmaps in the attribute may be longer than + * @nbits but the message must not request modifying any bits past @nbits. + * + * Return: negative error code on failure, 0 on success + */ +int ethnl_update_bitset32(u32 *bitmap, unsigned int nbits, + const struct nlattr *attr, ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod) +{ + struct nlattr *tb[ARRAY_SIZE(bitset_policy)]; + unsigned int change_bits; + bool no_mask; + int ret; + + if (!attr) + return 0; + ret = nla_parse_nested(tb, ARRAY_SIZE(bitset_policy) - 1, attr, + bitset_policy, extack); + if (ret < 0) + return ret; + + if (tb[ETHTOOL_A_BITSET_BITS]) + return ethnl_update_bitset32_verbose(bitmap, nbits, attr, tb, + names, extack, mod); + ret = ethnl_compact_sanity_checks(nbits, attr, tb, extack); + if (ret < 0) + return ret; + + no_mask = tb[ETHTOOL_A_BITSET_NOMASK]; + change_bits = min_t(unsigned int, + nla_get_u32(tb[ETHTOOL_A_BITSET_SIZE]), nbits); + ethnl_bitmap32_update(bitmap, change_bits, + nla_data(tb[ETHTOOL_A_BITSET_VALUE]), + no_mask ?
NULL : + nla_data(tb[ETHTOOL_A_BITSET_MASK]), + mod); + if (no_mask && change_bits < nbits) + ethnl_bitmap32_clear(bitmap, change_bits, nbits, mod); + + return 0; +} + +/** + * ethnl_parse_bitset() - Compute effective value and mask from bitset nest + * @val: unsigned long based bitmap to put value into + * @mask: unsigned long based bitmap to put mask into + * @nbits: size of @val and @mask bitmaps + * @attr: nest attribute to parse and apply + * @names: array of bit names; may be null for compact format + * @extack: extack for error reporting + * + * Provide @nbits size long bitmaps for value and mask so that + * x = (val & mask) | (x & ~mask) would modify any @nbits sized bitmap x + * the same way ethnl_update_bitset() with the same bitset attribute would. + * + * Return: negative error code on failure, 0 on success + */ +int ethnl_parse_bitset(unsigned long *val, unsigned long *mask, + unsigned int nbits, const struct nlattr *attr, + ethnl_string_array_t names, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(bitset_policy)]; + const struct nlattr *bit_attr; + bool no_mask; + int rem; + int ret; + + if (!attr) + return 0; + ret = nla_parse_nested(tb, ARRAY_SIZE(bitset_policy) - 1, attr, + bitset_policy, extack); + if (ret < 0) + return ret; + no_mask = tb[ETHTOOL_A_BITSET_NOMASK]; + + if (!tb[ETHTOOL_A_BITSET_BITS]) { + unsigned int change_bits; + + ret = ethnl_compact_sanity_checks(nbits, attr, tb, extack); + if (ret < 0) + return ret; + + change_bits = nla_get_u32(tb[ETHTOOL_A_BITSET_SIZE]); + if (change_bits > nbits) + change_bits = nbits; + bitmap_from_arr32(val, nla_data(tb[ETHTOOL_A_BITSET_VALUE]), + change_bits); + if (change_bits < nbits) + bitmap_clear(val, change_bits, nbits - change_bits); + if (no_mask) { + bitmap_fill(mask, nbits); + } else { + bitmap_from_arr32(mask, + nla_data(tb[ETHTOOL_A_BITSET_MASK]), + change_bits); + if (change_bits < nbits) + bitmap_clear(mask, change_bits, + nbits - change_bits); + } + + return 0; + } + + if (tb[ETHTOOL_A_BITSET_VALUE]) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_VALUE], + "value only allowed in compact bitset"); + return -EINVAL; + } + if (tb[ETHTOOL_A_BITSET_MASK]) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_BITSET_MASK], + "mask only allowed in compact bitset"); + return -EINVAL; + } + + bitmap_zero(val, nbits); + if (no_mask) + bitmap_fill(mask, nbits); + else + bitmap_zero(mask, nbits); + + nla_for_each_nested(bit_attr, tb[ETHTOOL_A_BITSET_BITS], rem) { + unsigned int idx; + bool bit_val; + + ret = ethnl_parse_bit(&idx, &bit_val, nbits, bit_attr, no_mask, + names, extack); + if (ret < 0) + return ret; + if (bit_val) + __set_bit(idx, val); + if (!no_mask) + __set_bit(idx, mask); + } + + return 0; +} + +#if BITS_PER_LONG == 64 && defined(__BIG_ENDIAN) + +/* 64-bit big endian architectures are the only case when u32 based bitmaps + * and unsigned long based bitmaps have different memory layout so that we + * cannot simply cast the latter to the former and need actual wrappers + * converting the latter to the former. + * + * To reduce the number of slab allocations, the wrappers use fixed size local + * variables for bitmaps up to ETHNL_SMALL_BITMAP_BITS bits which is the + * majority of bitmaps used by ethtool. 
+ */ +#define ETHNL_SMALL_BITMAP_BITS 128 +#define ETHNL_SMALL_BITMAP_WORDS DIV_ROUND_UP(ETHNL_SMALL_BITMAP_BITS, 32) + +int ethnl_bitset_size(const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact) +{ + u32 small_mask32[ETHNL_SMALL_BITMAP_WORDS]; + u32 small_val32[ETHNL_SMALL_BITMAP_WORDS]; + u32 *mask32; + u32 *val32; + int ret; + + if (nbits > ETHNL_SMALL_BITMAP_BITS) { + unsigned int nwords = DIV_ROUND_UP(nbits, 32); + + val32 = kmalloc_array(2 * nwords, sizeof(u32), GFP_KERNEL); + if (!val32) + return -ENOMEM; + mask32 = val32 + nwords; + } else { + val32 = small_val32; + mask32 = small_mask32; + } + + bitmap_to_arr32(val32, val, nbits); + if (mask) + bitmap_to_arr32(mask32, mask, nbits); + else + mask32 = NULL; + ret = ethnl_bitset32_size(val32, mask32, nbits, names, compact); + + if (nbits > ETHNL_SMALL_BITMAP_BITS) + kfree(val32); + + return ret; +} + +int ethnl_put_bitset(struct sk_buff *skb, int attrtype, + const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact) +{ + u32 small_mask32[ETHNL_SMALL_BITMAP_WORDS]; + u32 small_val32[ETHNL_SMALL_BITMAP_WORDS]; + u32 *mask32; + u32 *val32; + int ret; + + if (nbits > ETHNL_SMALL_BITMAP_BITS) { + unsigned int nwords = DIV_ROUND_UP(nbits, 32); + + val32 = kmalloc_array(2 * nwords, sizeof(u32), GFP_KERNEL); + if (!val32) + return -ENOMEM; + mask32 = val32 + nwords; + } else { + val32 = small_val32; + mask32 = small_mask32; + } + + bitmap_to_arr32(val32, val, nbits); + if (mask) + bitmap_to_arr32(mask32, mask, nbits); + else + mask32 = NULL; + ret = ethnl_put_bitset32(skb, attrtype, val32, mask32, nbits, names, + compact); + + if (nbits > ETHNL_SMALL_BITMAP_BITS) + kfree(val32); + + return ret; +} + +int ethnl_update_bitset(unsigned long *bitmap, unsigned int nbits, + const struct nlattr *attr, ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod) +{ + u32 small_bitmap32[ETHNL_SMALL_BITMAP_WORDS]; + u32 *bitmap32 = small_bitmap32; + bool u32_mod = false; + int ret; + + if (nbits > ETHNL_SMALL_BITMAP_BITS) { + unsigned int dst_words = DIV_ROUND_UP(nbits, 32); + + bitmap32 = kmalloc_array(dst_words, sizeof(u32), GFP_KERNEL); + if (!bitmap32) + return -ENOMEM; + } + + bitmap_to_arr32(bitmap32, bitmap, nbits); + ret = ethnl_update_bitset32(bitmap32, nbits, attr, names, extack, + &u32_mod); + if (u32_mod) { + bitmap_from_arr32(bitmap, bitmap32, nbits); + *mod = true; + } + + if (nbits > ETHNL_SMALL_BITMAP_BITS) + kfree(bitmap32); + + return ret; +} + +#else + +/* On little endian 64-bit and all 32-bit architectures, an unsigned long + * based bitmap can be interpreted as u32 based one using a simple cast. 
+ */ + +int ethnl_bitset_size(const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact) +{ + return ethnl_bitset32_size((const u32 *)val, (const u32 *)mask, nbits, + names, compact); +} + +int ethnl_put_bitset(struct sk_buff *skb, int attrtype, + const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact) +{ + return ethnl_put_bitset32(skb, attrtype, (const u32 *)val, + (const u32 *)mask, nbits, names, compact); +} + +int ethnl_update_bitset(unsigned long *bitmap, unsigned int nbits, + const struct nlattr *attr, ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod) +{ + return ethnl_update_bitset32((u32 *)bitmap, nbits, attr, names, extack, + mod); +} + +#endif /* BITS_PER_LONG == 64 && defined(__BIG_ENDIAN) */ diff --git a/net/ethtool/bitset.h b/net/ethtool/bitset.h new file mode 100644 index 000000000..c2c2e0051 --- /dev/null +++ b/net/ethtool/bitset.h @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ + +#ifndef _NET_ETHTOOL_BITSET_H +#define _NET_ETHTOOL_BITSET_H + +#define ETHNL_MAX_BITSET_SIZE S16_MAX + +typedef const char (*const ethnl_string_array_t)[ETH_GSTRING_LEN]; + +int ethnl_bitset_is_compact(const struct nlattr *bitset, bool *compact); +int ethnl_bitset_size(const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact); +int ethnl_bitset32_size(const u32 *val, const u32 *mask, unsigned int nbits, + ethnl_string_array_t names, bool compact); +int ethnl_put_bitset(struct sk_buff *skb, int attrtype, + const unsigned long *val, const unsigned long *mask, + unsigned int nbits, ethnl_string_array_t names, + bool compact); +int ethnl_put_bitset32(struct sk_buff *skb, int attrtype, const u32 *val, + const u32 *mask, unsigned int nbits, + ethnl_string_array_t names, bool compact); +int ethnl_update_bitset(unsigned long *bitmap, unsigned int nbits, + const struct nlattr *attr, ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod); +int ethnl_update_bitset32(u32 *bitmap, unsigned int nbits, + const struct nlattr *attr, ethnl_string_array_t names, + struct netlink_ext_ack *extack, bool *mod); +int ethnl_parse_bitset(unsigned long *val, unsigned long *mask, + unsigned int nbits, const struct nlattr *attr, + ethnl_string_array_t names, + struct netlink_ext_ack *extack); + +#endif /* _NET_ETHTOOL_BITSET_H */ diff --git a/net/ethtool/cabletest.c b/net/ethtool/cabletest.c new file mode 100644 index 000000000..06a151165 --- /dev/null +++ b/net/ethtool/cabletest.c @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include "netlink.h" +#include "common.h" + +/* 802.3 standard allows 100 meters for BaseT cables. However longer + * cables might work, depending on the quality of the cables and the + * PHY. So allow testing for up to 150 meters. 
+ */ +#define MAX_CABLE_LENGTH_CM (150 * 100) + +const struct nla_policy ethnl_cable_test_act_policy[] = { + [ETHTOOL_A_CABLE_TEST_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int ethnl_cable_test_started(struct phy_device *phydev, u8 cmd) +{ + struct sk_buff *skb; + int err = -ENOMEM; + void *ehdr; + + skb = genlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + goto out; + + ehdr = ethnl_bcastmsg_put(skb, cmd); + if (!ehdr) { + err = -EMSGSIZE; + goto out; + } + + err = ethnl_fill_reply_header(skb, phydev->attached_dev, + ETHTOOL_A_CABLE_TEST_NTF_HEADER); + if (err) + goto out; + + err = nla_put_u8(skb, ETHTOOL_A_CABLE_TEST_NTF_STATUS, + ETHTOOL_A_CABLE_TEST_NTF_STATUS_STARTED); + if (err) + goto out; + + genlmsg_end(skb, ehdr); + + return ethnl_multicast(skb, phydev->attached_dev); + +out: + nlmsg_free(skb); + phydev_err(phydev, "%s: Error %pe\n", __func__, ERR_PTR(err)); + + return err; +} + +int ethnl_act_cable_test(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + const struct ethtool_phy_ops *ops; + struct nlattr **tb = info->attrs; + struct net_device *dev; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_CABLE_TEST_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + + dev = req_info.dev; + if (!dev->phydev) { + ret = -EOPNOTSUPP; + goto out_dev_put; + } + + rtnl_lock(); + ops = ethtool_phy_ops; + if (!ops || !ops->start_cable_test) { + ret = -EOPNOTSUPP; + goto out_rtnl; + } + + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = ops->start_cable_test(dev->phydev, info->extack); + + ethnl_ops_complete(dev); + + if (!ret) + ethnl_cable_test_started(dev->phydev, + ETHTOOL_MSG_CABLE_TEST_NTF); + +out_rtnl: + rtnl_unlock(); +out_dev_put: + ethnl_parse_header_dev_put(&req_info); + return ret; +} + +int ethnl_cable_test_alloc(struct phy_device *phydev, u8 cmd) +{ + int err = -ENOMEM; + + /* One TDR sample occupies 20 bytes. For a 150 meter cable, + * with four pairs, around 12K is needed. 
+ */ + phydev->skb = genlmsg_new(SZ_16K, GFP_KERNEL); + if (!phydev->skb) + goto out; + + phydev->ehdr = ethnl_bcastmsg_put(phydev->skb, cmd); + if (!phydev->ehdr) { + err = -EMSGSIZE; + goto out; + } + + err = ethnl_fill_reply_header(phydev->skb, phydev->attached_dev, + ETHTOOL_A_CABLE_TEST_NTF_HEADER); + if (err) + goto out; + + err = nla_put_u8(phydev->skb, ETHTOOL_A_CABLE_TEST_NTF_STATUS, + ETHTOOL_A_CABLE_TEST_NTF_STATUS_COMPLETED); + if (err) + goto out; + + phydev->nest = nla_nest_start(phydev->skb, + ETHTOOL_A_CABLE_TEST_NTF_NEST); + if (!phydev->nest) { + err = -EMSGSIZE; + goto out; + } + + return 0; + +out: + nlmsg_free(phydev->skb); + phydev->skb = NULL; + return err; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_alloc); + +void ethnl_cable_test_free(struct phy_device *phydev) +{ + nlmsg_free(phydev->skb); + phydev->skb = NULL; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_free); + +void ethnl_cable_test_finished(struct phy_device *phydev) +{ + nla_nest_end(phydev->skb, phydev->nest); + + genlmsg_end(phydev->skb, phydev->ehdr); + + ethnl_multicast(phydev->skb, phydev->attached_dev); +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_finished); + +int ethnl_cable_test_result(struct phy_device *phydev, u8 pair, u8 result) +{ + struct nlattr *nest; + int ret = -EMSGSIZE; + + nest = nla_nest_start(phydev->skb, ETHTOOL_A_CABLE_NEST_RESULT); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u8(phydev->skb, ETHTOOL_A_CABLE_RESULT_PAIR, pair)) + goto err; + if (nla_put_u8(phydev->skb, ETHTOOL_A_CABLE_RESULT_CODE, result)) + goto err; + + nla_nest_end(phydev->skb, nest); + return 0; + +err: + nla_nest_cancel(phydev->skb, nest); + return ret; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_result); + +int ethnl_cable_test_fault_length(struct phy_device *phydev, u8 pair, u32 cm) +{ + struct nlattr *nest; + int ret = -EMSGSIZE; + + nest = nla_nest_start(phydev->skb, + ETHTOOL_A_CABLE_NEST_FAULT_LENGTH); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u8(phydev->skb, ETHTOOL_A_CABLE_FAULT_LENGTH_PAIR, pair)) + goto err; + if (nla_put_u32(phydev->skb, ETHTOOL_A_CABLE_FAULT_LENGTH_CM, cm)) + goto err; + + nla_nest_end(phydev->skb, nest); + return 0; + +err: + nla_nest_cancel(phydev->skb, nest); + return ret; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_fault_length); + +struct cable_test_tdr_req_info { + struct ethnl_req_info base; +}; + +static const struct nla_policy cable_test_tdr_act_cfg_policy[] = { + [ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST] = { .type = NLA_U32 }, + [ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST] = { .type = NLA_U32 }, + [ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP] = { .type = NLA_U32 }, + [ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR] = { .type = NLA_U8 }, +}; + +const struct nla_policy ethnl_cable_test_tdr_act_policy[] = { + [ETHTOOL_A_CABLE_TEST_TDR_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_CABLE_TEST_TDR_CFG] = { .type = NLA_NESTED }, +}; + +/* CABLE_TEST_TDR_ACT */ +static int ethnl_act_cable_test_tdr_cfg(const struct nlattr *nest, + struct genl_info *info, + struct phy_tdr_config *cfg) +{ + struct nlattr *tb[ARRAY_SIZE(cable_test_tdr_act_cfg_policy)]; + int ret; + + cfg->first = 100; + cfg->step = 100; + cfg->last = MAX_CABLE_LENGTH_CM; + cfg->pair = PHY_PAIR_ALL; + + if (!nest) + return 0; + + ret = nla_parse_nested(tb, + ARRAY_SIZE(cable_test_tdr_act_cfg_policy) - 1, + nest, cable_test_tdr_act_cfg_policy, + info->extack); + if (ret < 0) + return ret; + + if (tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST]) + cfg->first = nla_get_u32( + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST]); + + if (tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST]) + 
cfg->last = nla_get_u32(tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST]); + + if (tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP]) + cfg->step = nla_get_u32(tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP]); + + if (tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR]) { + cfg->pair = nla_get_u8(tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR]); + if (cfg->pair > ETHTOOL_A_CABLE_PAIR_D) { + NL_SET_ERR_MSG_ATTR( + info->extack, + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_PAIR], + "invalid pair parameter"); + return -EINVAL; + } + } + + if (cfg->first > MAX_CABLE_LENGTH_CM) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_FIRST], + "invalid first parameter"); + return -EINVAL; + } + + if (cfg->last > MAX_CABLE_LENGTH_CM) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_LAST], + "invalid last parameter"); + return -EINVAL; + } + + if (cfg->first > cfg->last) { + NL_SET_ERR_MSG(info->extack, "invalid first/last parameter"); + return -EINVAL; + } + + if (!cfg->step) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP], + "invalid step parameter"); + return -EINVAL; + } + + if (cfg->step > (cfg->last - cfg->first)) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_CABLE_TEST_TDR_CFG_STEP], + "step parameter too big"); + return -EINVAL; + } + + return 0; +} + +int ethnl_act_cable_test_tdr(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + const struct ethtool_phy_ops *ops; + struct nlattr **tb = info->attrs; + struct phy_tdr_config cfg; + struct net_device *dev; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_CABLE_TEST_TDR_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + + dev = req_info.dev; + if (!dev->phydev) { + ret = -EOPNOTSUPP; + goto out_dev_put; + } + + ret = ethnl_act_cable_test_tdr_cfg(tb[ETHTOOL_A_CABLE_TEST_TDR_CFG], + info, &cfg); + if (ret) + goto out_dev_put; + + rtnl_lock(); + ops = ethtool_phy_ops; + if (!ops || !ops->start_cable_test_tdr) { + ret = -EOPNOTSUPP; + goto out_rtnl; + } + + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = ops->start_cable_test_tdr(dev->phydev, info->extack, &cfg); + + ethnl_ops_complete(dev); + + if (!ret) + ethnl_cable_test_started(dev->phydev, + ETHTOOL_MSG_CABLE_TEST_TDR_NTF); + +out_rtnl: + rtnl_unlock(); +out_dev_put: + ethnl_parse_header_dev_put(&req_info); + return ret; +} + +int ethnl_cable_test_amplitude(struct phy_device *phydev, + u8 pair, s16 mV) +{ + struct nlattr *nest; + int ret = -EMSGSIZE; + + nest = nla_nest_start(phydev->skb, + ETHTOOL_A_CABLE_TDR_NEST_AMPLITUDE); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u8(phydev->skb, ETHTOOL_A_CABLE_AMPLITUDE_PAIR, pair)) + goto err; + if (nla_put_u16(phydev->skb, ETHTOOL_A_CABLE_AMPLITUDE_mV, mV)) + goto err; + + nla_nest_end(phydev->skb, nest); + return 0; + +err: + nla_nest_cancel(phydev->skb, nest); + return ret; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_amplitude); + +int ethnl_cable_test_pulse(struct phy_device *phydev, u16 mV) +{ + struct nlattr *nest; + int ret = -EMSGSIZE; + + nest = nla_nest_start(phydev->skb, ETHTOOL_A_CABLE_TDR_NEST_PULSE); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u16(phydev->skb, ETHTOOL_A_CABLE_PULSE_mV, mV)) + goto err; + + nla_nest_end(phydev->skb, nest); + return 0; + +err: + nla_nest_cancel(phydev->skb, nest); + return ret; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_pulse); + +int ethnl_cable_test_step(struct phy_device *phydev, u32 first, u32 last, + u32 step) +{ + struct nlattr *nest; + int ret = -EMSGSIZE; + + nest = 
nla_nest_start(phydev->skb, ETHTOOL_A_CABLE_TDR_NEST_STEP); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(phydev->skb, ETHTOOL_A_CABLE_STEP_FIRST_DISTANCE, + first)) + goto err; + + if (nla_put_u32(phydev->skb, ETHTOOL_A_CABLE_STEP_LAST_DISTANCE, last)) + goto err; + + if (nla_put_u32(phydev->skb, ETHTOOL_A_CABLE_STEP_STEP_DISTANCE, step)) + goto err; + + nla_nest_end(phydev->skb, nest); + return 0; + +err: + nla_nest_cancel(phydev->skb, nest); + return ret; +} +EXPORT_SYMBOL_GPL(ethnl_cable_test_step); diff --git a/net/ethtool/channels.c b/net/ethtool/channels.c new file mode 100644 index 000000000..403158862 --- /dev/null +++ b/net/ethtool/channels.c @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include + +#include "netlink.h" +#include "common.h" + +struct channels_req_info { + struct ethnl_req_info base; +}; + +struct channels_reply_data { + struct ethnl_reply_data base; + struct ethtool_channels channels; +}; + +#define CHANNELS_REPDATA(__reply_base) \ + container_of(__reply_base, struct channels_reply_data, base) + +const struct nla_policy ethnl_channels_get_policy[] = { + [ETHTOOL_A_CHANNELS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int channels_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct channels_reply_data *data = CHANNELS_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_channels) + return -EOPNOTSUPP; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + dev->ethtool_ops->get_channels(dev, &data->channels); + ethnl_ops_complete(dev); + + return 0; +} + +static int channels_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + return nla_total_size(sizeof(u32)) + /* _CHANNELS_RX_MAX */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_TX_MAX */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_OTHER_MAX */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_COMBINED_MAX */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_RX_COUNT */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_TX_COUNT */ + nla_total_size(sizeof(u32)) + /* _CHANNELS_OTHER_COUNT */ + nla_total_size(sizeof(u32)); /* _CHANNELS_COMBINED_COUNT */ +} + +static int channels_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct channels_reply_data *data = CHANNELS_REPDATA(reply_base); + const struct ethtool_channels *channels = &data->channels; + + if ((channels->max_rx && + (nla_put_u32(skb, ETHTOOL_A_CHANNELS_RX_MAX, + channels->max_rx) || + nla_put_u32(skb, ETHTOOL_A_CHANNELS_RX_COUNT, + channels->rx_count))) || + (channels->max_tx && + (nla_put_u32(skb, ETHTOOL_A_CHANNELS_TX_MAX, + channels->max_tx) || + nla_put_u32(skb, ETHTOOL_A_CHANNELS_TX_COUNT, + channels->tx_count))) || + (channels->max_other && + (nla_put_u32(skb, ETHTOOL_A_CHANNELS_OTHER_MAX, + channels->max_other) || + nla_put_u32(skb, ETHTOOL_A_CHANNELS_OTHER_COUNT, + channels->other_count))) || + (channels->max_combined && + (nla_put_u32(skb, ETHTOOL_A_CHANNELS_COMBINED_MAX, + channels->max_combined) || + nla_put_u32(skb, ETHTOOL_A_CHANNELS_COMBINED_COUNT, + channels->combined_count)))) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_channels_request_ops = { + .request_cmd = ETHTOOL_MSG_CHANNELS_GET, + .reply_cmd = ETHTOOL_MSG_CHANNELS_GET_REPLY, + .hdr_attr = ETHTOOL_A_CHANNELS_HEADER, + .req_info_size = sizeof(struct 
channels_req_info), + .reply_data_size = sizeof(struct channels_reply_data), + + .prepare_data = channels_prepare_data, + .reply_size = channels_reply_size, + .fill_reply = channels_fill_reply, +}; + +/* CHANNELS_SET */ + +const struct nla_policy ethnl_channels_set_policy[] = { + [ETHTOOL_A_CHANNELS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_CHANNELS_RX_COUNT] = { .type = NLA_U32 }, + [ETHTOOL_A_CHANNELS_TX_COUNT] = { .type = NLA_U32 }, + [ETHTOOL_A_CHANNELS_OTHER_COUNT] = { .type = NLA_U32 }, + [ETHTOOL_A_CHANNELS_COMBINED_COUNT] = { .type = NLA_U32 }, +}; + +int ethnl_set_channels(struct sk_buff *skb, struct genl_info *info) +{ + unsigned int from_channel, old_total, i; + bool mod = false, mod_combined = false; + struct ethtool_channels channels = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + u32 err_attr, max_rx_in_use = 0; + const struct ethtool_ops *ops; + struct net_device *dev; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_CHANNELS_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_channels || !ops->set_channels) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ops->get_channels(dev, &channels); + old_total = channels.combined_count + + max(channels.rx_count, channels.tx_count); + + ethnl_update_u32(&channels.rx_count, tb[ETHTOOL_A_CHANNELS_RX_COUNT], + &mod); + ethnl_update_u32(&channels.tx_count, tb[ETHTOOL_A_CHANNELS_TX_COUNT], + &mod); + ethnl_update_u32(&channels.other_count, + tb[ETHTOOL_A_CHANNELS_OTHER_COUNT], &mod); + ethnl_update_u32(&channels.combined_count, + tb[ETHTOOL_A_CHANNELS_COMBINED_COUNT], &mod_combined); + mod |= mod_combined; + ret = 0; + if (!mod) + goto out_ops; + + /* ensure new channel counts are within limits */ + if (channels.rx_count > channels.max_rx) + err_attr = ETHTOOL_A_CHANNELS_RX_COUNT; + else if (channels.tx_count > channels.max_tx) + err_attr = ETHTOOL_A_CHANNELS_TX_COUNT; + else if (channels.other_count > channels.max_other) + err_attr = ETHTOOL_A_CHANNELS_OTHER_COUNT; + else if (channels.combined_count > channels.max_combined) + err_attr = ETHTOOL_A_CHANNELS_COMBINED_COUNT; + else + err_attr = 0; + if (err_attr) { + ret = -EINVAL; + NL_SET_ERR_MSG_ATTR(info->extack, tb[err_attr], + "requested channel count exceeds maximum"); + goto out_ops; + } + + /* ensure there is at least one RX and one TX channel */ + if (!channels.combined_count && !channels.rx_count) + err_attr = ETHTOOL_A_CHANNELS_RX_COUNT; + else if (!channels.combined_count && !channels.tx_count) + err_attr = ETHTOOL_A_CHANNELS_TX_COUNT; + else + err_attr = 0; + if (err_attr) { + if (mod_combined) + err_attr = ETHTOOL_A_CHANNELS_COMBINED_COUNT; + ret = -EINVAL; + NL_SET_ERR_MSG_ATTR(info->extack, tb[err_attr], + "requested channel counts would result in no RX or TX channel being configured"); + goto out_ops; + } + + /* ensure the new Rx count fits within the configured Rx flow + * indirection table settings + */ + if (netif_is_rxfh_configured(dev) && + !ethtool_get_max_rxfh_channel(dev, &max_rx_in_use) && + (channels.combined_count + channels.rx_count) <= max_rx_in_use) { + ret = -EINVAL; + GENL_SET_ERR_MSG(info, "requested channel counts are too low for existing indirection table settings"); + goto out_ops; + } + + /* Disabling channels, query zero-copy AF_XDP sockets */ + from_channel = channels.combined_count + + 
min(channels.rx_count, channels.tx_count); + for (i = from_channel; i < old_total; i++) + if (xsk_get_pool_from_qid(dev, i)) { + ret = -EINVAL; + GENL_SET_ERR_MSG(info, "requested channel counts are too low for existing zerocopy AF_XDP sockets"); + goto out_ops; + } + + ret = dev->ethtool_ops->set_channels(dev, &channels); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_CHANNELS_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/coalesce.c b/net/ethtool/coalesce.c new file mode 100644 index 000000000..487bdf345 --- /dev/null +++ b/net/ethtool/coalesce.c @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" + +struct coalesce_req_info { + struct ethnl_req_info base; +}; + +struct coalesce_reply_data { + struct ethnl_reply_data base; + struct ethtool_coalesce coalesce; + struct kernel_ethtool_coalesce kernel_coalesce; + u32 supported_params; +}; + +#define COALESCE_REPDATA(__reply_base) \ + container_of(__reply_base, struct coalesce_reply_data, base) + +#define __SUPPORTED_OFFSET ETHTOOL_A_COALESCE_RX_USECS +static u32 attr_to_mask(unsigned int attr_type) +{ + return BIT(attr_type - __SUPPORTED_OFFSET); +} + +/* build time check that indices in ethtool_ops::supported_coalesce_params + * match corresponding attribute types with an offset + */ +#define __CHECK_SUPPORTED_OFFSET(x) \ + static_assert((ETHTOOL_ ## x) == \ + BIT((ETHTOOL_A_ ## x) - __SUPPORTED_OFFSET)) +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_USECS); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_MAX_FRAMES); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_USECS_IRQ); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_MAX_FRAMES_IRQ); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_USECS); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_MAX_FRAMES); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_USECS_IRQ); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_MAX_FRAMES_IRQ); +__CHECK_SUPPORTED_OFFSET(COALESCE_STATS_BLOCK_USECS); +__CHECK_SUPPORTED_OFFSET(COALESCE_USE_ADAPTIVE_RX); +__CHECK_SUPPORTED_OFFSET(COALESCE_USE_ADAPTIVE_TX); +__CHECK_SUPPORTED_OFFSET(COALESCE_PKT_RATE_LOW); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_USECS_LOW); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_MAX_FRAMES_LOW); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_USECS_LOW); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_MAX_FRAMES_LOW); +__CHECK_SUPPORTED_OFFSET(COALESCE_PKT_RATE_HIGH); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_USECS_HIGH); +__CHECK_SUPPORTED_OFFSET(COALESCE_RX_MAX_FRAMES_HIGH); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_USECS_HIGH); +__CHECK_SUPPORTED_OFFSET(COALESCE_TX_MAX_FRAMES_HIGH); +__CHECK_SUPPORTED_OFFSET(COALESCE_RATE_SAMPLE_INTERVAL); + +const struct nla_policy ethnl_coalesce_get_policy[] = { + [ETHTOOL_A_COALESCE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int coalesce_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct coalesce_reply_data *data = COALESCE_REPDATA(reply_base); + struct netlink_ext_ack *extack = info ? 
info->extack : NULL; + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_coalesce) + return -EOPNOTSUPP; + data->supported_params = dev->ethtool_ops->supported_coalesce_params; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + ret = dev->ethtool_ops->get_coalesce(dev, &data->coalesce, + &data->kernel_coalesce, extack); + ethnl_ops_complete(dev); + + return ret; +} + +static int coalesce_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + return nla_total_size(sizeof(u32)) + /* _RX_USECS */ + nla_total_size(sizeof(u32)) + /* _RX_MAX_FRAMES */ + nla_total_size(sizeof(u32)) + /* _RX_USECS_IRQ */ + nla_total_size(sizeof(u32)) + /* _RX_MAX_FRAMES_IRQ */ + nla_total_size(sizeof(u32)) + /* _TX_USECS */ + nla_total_size(sizeof(u32)) + /* _TX_MAX_FRAMES */ + nla_total_size(sizeof(u32)) + /* _TX_USECS_IRQ */ + nla_total_size(sizeof(u32)) + /* _TX_MAX_FRAMES_IRQ */ + nla_total_size(sizeof(u32)) + /* _STATS_BLOCK_USECS */ + nla_total_size(sizeof(u8)) + /* _USE_ADAPTIVE_RX */ + nla_total_size(sizeof(u8)) + /* _USE_ADAPTIVE_TX */ + nla_total_size(sizeof(u32)) + /* _PKT_RATE_LOW */ + nla_total_size(sizeof(u32)) + /* _RX_USECS_LOW */ + nla_total_size(sizeof(u32)) + /* _RX_MAX_FRAMES_LOW */ + nla_total_size(sizeof(u32)) + /* _TX_USECS_LOW */ + nla_total_size(sizeof(u32)) + /* _TX_MAX_FRAMES_LOW */ + nla_total_size(sizeof(u32)) + /* _PKT_RATE_HIGH */ + nla_total_size(sizeof(u32)) + /* _RX_USECS_HIGH */ + nla_total_size(sizeof(u32)) + /* _RX_MAX_FRAMES_HIGH */ + nla_total_size(sizeof(u32)) + /* _TX_USECS_HIGH */ + nla_total_size(sizeof(u32)) + /* _TX_MAX_FRAMES_HIGH */ + nla_total_size(sizeof(u32)) + /* _RATE_SAMPLE_INTERVAL */ + nla_total_size(sizeof(u8)) + /* _USE_CQE_MODE_TX */ + nla_total_size(sizeof(u8)); /* _USE_CQE_MODE_RX */ +} + +static bool coalesce_put_u32(struct sk_buff *skb, u16 attr_type, u32 val, + u32 supported_params) +{ + if (!val && !(supported_params & attr_to_mask(attr_type))) + return false; + return nla_put_u32(skb, attr_type, val); +} + +static bool coalesce_put_bool(struct sk_buff *skb, u16 attr_type, u32 val, + u32 supported_params) +{ + if (!val && !(supported_params & attr_to_mask(attr_type))) + return false; + return nla_put_u8(skb, attr_type, !!val); +} + +static int coalesce_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct coalesce_reply_data *data = COALESCE_REPDATA(reply_base); + const struct kernel_ethtool_coalesce *kcoal = &data->kernel_coalesce; + const struct ethtool_coalesce *coal = &data->coalesce; + u32 supported = data->supported_params; + + if (coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_USECS, + coal->rx_coalesce_usecs, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_MAX_FRAMES, + coal->rx_max_coalesced_frames, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_USECS_IRQ, + coal->rx_coalesce_usecs_irq, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_MAX_FRAMES_IRQ, + coal->rx_max_coalesced_frames_irq, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_USECS, + coal->tx_coalesce_usecs, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_MAX_FRAMES, + coal->tx_max_coalesced_frames, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_USECS_IRQ, + coal->tx_coalesce_usecs_irq, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_MAX_FRAMES_IRQ, + coal->tx_max_coalesced_frames_irq, supported) || + coalesce_put_u32(skb, 
ETHTOOL_A_COALESCE_STATS_BLOCK_USECS, + coal->stats_block_coalesce_usecs, supported) || + coalesce_put_bool(skb, ETHTOOL_A_COALESCE_USE_ADAPTIVE_RX, + coal->use_adaptive_rx_coalesce, supported) || + coalesce_put_bool(skb, ETHTOOL_A_COALESCE_USE_ADAPTIVE_TX, + coal->use_adaptive_tx_coalesce, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_PKT_RATE_LOW, + coal->pkt_rate_low, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_USECS_LOW, + coal->rx_coalesce_usecs_low, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_MAX_FRAMES_LOW, + coal->rx_max_coalesced_frames_low, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_USECS_LOW, + coal->tx_coalesce_usecs_low, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_MAX_FRAMES_LOW, + coal->tx_max_coalesced_frames_low, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_PKT_RATE_HIGH, + coal->pkt_rate_high, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_USECS_HIGH, + coal->rx_coalesce_usecs_high, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RX_MAX_FRAMES_HIGH, + coal->rx_max_coalesced_frames_high, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_USECS_HIGH, + coal->tx_coalesce_usecs_high, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_TX_MAX_FRAMES_HIGH, + coal->tx_max_coalesced_frames_high, supported) || + coalesce_put_u32(skb, ETHTOOL_A_COALESCE_RATE_SAMPLE_INTERVAL, + coal->rate_sample_interval, supported) || + coalesce_put_bool(skb, ETHTOOL_A_COALESCE_USE_CQE_MODE_TX, + kcoal->use_cqe_mode_tx, supported) || + coalesce_put_bool(skb, ETHTOOL_A_COALESCE_USE_CQE_MODE_RX, + kcoal->use_cqe_mode_rx, supported)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_coalesce_request_ops = { + .request_cmd = ETHTOOL_MSG_COALESCE_GET, + .reply_cmd = ETHTOOL_MSG_COALESCE_GET_REPLY, + .hdr_attr = ETHTOOL_A_COALESCE_HEADER, + .req_info_size = sizeof(struct coalesce_req_info), + .reply_data_size = sizeof(struct coalesce_reply_data), + + .prepare_data = coalesce_prepare_data, + .reply_size = coalesce_reply_size, + .fill_reply = coalesce_fill_reply, +}; + +/* COALESCE_SET */ + +const struct nla_policy ethnl_coalesce_set_policy[] = { + [ETHTOOL_A_COALESCE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_COALESCE_RX_USECS] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_MAX_FRAMES] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_USECS_IRQ] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_MAX_FRAMES_IRQ] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_USECS] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_MAX_FRAMES] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_USECS_IRQ] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_MAX_FRAMES_IRQ] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_STATS_BLOCK_USECS] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_USE_ADAPTIVE_RX] = { .type = NLA_U8 }, + [ETHTOOL_A_COALESCE_USE_ADAPTIVE_TX] = { .type = NLA_U8 }, + [ETHTOOL_A_COALESCE_PKT_RATE_LOW] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_USECS_LOW] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_MAX_FRAMES_LOW] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_USECS_LOW] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_MAX_FRAMES_LOW] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_PKT_RATE_HIGH] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_USECS_HIGH] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_RX_MAX_FRAMES_HIGH] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_USECS_HIGH] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_TX_MAX_FRAMES_HIGH] = { .type 
= NLA_U32 }, + [ETHTOOL_A_COALESCE_RATE_SAMPLE_INTERVAL] = { .type = NLA_U32 }, + [ETHTOOL_A_COALESCE_USE_CQE_MODE_TX] = NLA_POLICY_MAX(NLA_U8, 1), + [ETHTOOL_A_COALESCE_USE_CQE_MODE_RX] = NLA_POLICY_MAX(NLA_U8, 1), +}; + +int ethnl_set_coalesce(struct sk_buff *skb, struct genl_info *info) +{ + struct kernel_ethtool_coalesce kernel_coalesce = {}; + struct ethtool_coalesce coalesce = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + const struct ethtool_ops *ops; + struct net_device *dev; + u32 supported_params; + bool mod = false; + int ret; + u16 a; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_COALESCE_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_coalesce || !ops->set_coalesce) + goto out_dev; + + /* make sure that only supported parameters are present */ + supported_params = ops->supported_coalesce_params; + for (a = ETHTOOL_A_COALESCE_RX_USECS; a < __ETHTOOL_A_COALESCE_CNT; a++) + if (tb[a] && !(supported_params & attr_to_mask(a))) { + ret = -EINVAL; + NL_SET_ERR_MSG_ATTR(info->extack, tb[a], + "cannot modify an unsupported parameter"); + goto out_dev; + } + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ret = ops->get_coalesce(dev, &coalesce, &kernel_coalesce, + info->extack); + if (ret < 0) + goto out_ops; + + ethnl_update_u32(&coalesce.rx_coalesce_usecs, + tb[ETHTOOL_A_COALESCE_RX_USECS], &mod); + ethnl_update_u32(&coalesce.rx_max_coalesced_frames, + tb[ETHTOOL_A_COALESCE_RX_MAX_FRAMES], &mod); + ethnl_update_u32(&coalesce.rx_coalesce_usecs_irq, + tb[ETHTOOL_A_COALESCE_RX_USECS_IRQ], &mod); + ethnl_update_u32(&coalesce.rx_max_coalesced_frames_irq, + tb[ETHTOOL_A_COALESCE_RX_MAX_FRAMES_IRQ], &mod); + ethnl_update_u32(&coalesce.tx_coalesce_usecs, + tb[ETHTOOL_A_COALESCE_TX_USECS], &mod); + ethnl_update_u32(&coalesce.tx_max_coalesced_frames, + tb[ETHTOOL_A_COALESCE_TX_MAX_FRAMES], &mod); + ethnl_update_u32(&coalesce.tx_coalesce_usecs_irq, + tb[ETHTOOL_A_COALESCE_TX_USECS_IRQ], &mod); + ethnl_update_u32(&coalesce.tx_max_coalesced_frames_irq, + tb[ETHTOOL_A_COALESCE_TX_MAX_FRAMES_IRQ], &mod); + ethnl_update_u32(&coalesce.stats_block_coalesce_usecs, + tb[ETHTOOL_A_COALESCE_STATS_BLOCK_USECS], &mod); + ethnl_update_bool32(&coalesce.use_adaptive_rx_coalesce, + tb[ETHTOOL_A_COALESCE_USE_ADAPTIVE_RX], &mod); + ethnl_update_bool32(&coalesce.use_adaptive_tx_coalesce, + tb[ETHTOOL_A_COALESCE_USE_ADAPTIVE_TX], &mod); + ethnl_update_u32(&coalesce.pkt_rate_low, + tb[ETHTOOL_A_COALESCE_PKT_RATE_LOW], &mod); + ethnl_update_u32(&coalesce.rx_coalesce_usecs_low, + tb[ETHTOOL_A_COALESCE_RX_USECS_LOW], &mod); + ethnl_update_u32(&coalesce.rx_max_coalesced_frames_low, + tb[ETHTOOL_A_COALESCE_RX_MAX_FRAMES_LOW], &mod); + ethnl_update_u32(&coalesce.tx_coalesce_usecs_low, + tb[ETHTOOL_A_COALESCE_TX_USECS_LOW], &mod); + ethnl_update_u32(&coalesce.tx_max_coalesced_frames_low, + tb[ETHTOOL_A_COALESCE_TX_MAX_FRAMES_LOW], &mod); + ethnl_update_u32(&coalesce.pkt_rate_high, + tb[ETHTOOL_A_COALESCE_PKT_RATE_HIGH], &mod); + ethnl_update_u32(&coalesce.rx_coalesce_usecs_high, + tb[ETHTOOL_A_COALESCE_RX_USECS_HIGH], &mod); + ethnl_update_u32(&coalesce.rx_max_coalesced_frames_high, + tb[ETHTOOL_A_COALESCE_RX_MAX_FRAMES_HIGH], &mod); + ethnl_update_u32(&coalesce.tx_coalesce_usecs_high, + tb[ETHTOOL_A_COALESCE_TX_USECS_HIGH], &mod); + ethnl_update_u32(&coalesce.tx_max_coalesced_frames_high, + 
tb[ETHTOOL_A_COALESCE_TX_MAX_FRAMES_HIGH], &mod); + ethnl_update_u32(&coalesce.rate_sample_interval, + tb[ETHTOOL_A_COALESCE_RATE_SAMPLE_INTERVAL], &mod); + ethnl_update_u8(&kernel_coalesce.use_cqe_mode_tx, + tb[ETHTOOL_A_COALESCE_USE_CQE_MODE_TX], &mod); + ethnl_update_u8(&kernel_coalesce.use_cqe_mode_rx, + tb[ETHTOOL_A_COALESCE_USE_CQE_MODE_RX], &mod); + ret = 0; + if (!mod) + goto out_ops; + + ret = dev->ethtool_ops->set_coalesce(dev, &coalesce, &kernel_coalesce, + info->extack); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_COALESCE_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/common.c b/net/ethtool/common.c new file mode 100644 index 000000000..566adf85e --- /dev/null +++ b/net/ethtool/common.c @@ -0,0 +1,599 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include +#include + +#include "common.h" + +const char netdev_features_strings[NETDEV_FEATURE_COUNT][ETH_GSTRING_LEN] = { + [NETIF_F_SG_BIT] = "tx-scatter-gather", + [NETIF_F_IP_CSUM_BIT] = "tx-checksum-ipv4", + [NETIF_F_HW_CSUM_BIT] = "tx-checksum-ip-generic", + [NETIF_F_IPV6_CSUM_BIT] = "tx-checksum-ipv6", + [NETIF_F_HIGHDMA_BIT] = "highdma", + [NETIF_F_FRAGLIST_BIT] = "tx-scatter-gather-fraglist", + [NETIF_F_HW_VLAN_CTAG_TX_BIT] = "tx-vlan-hw-insert", + + [NETIF_F_HW_VLAN_CTAG_RX_BIT] = "rx-vlan-hw-parse", + [NETIF_F_HW_VLAN_CTAG_FILTER_BIT] = "rx-vlan-filter", + [NETIF_F_HW_VLAN_STAG_TX_BIT] = "tx-vlan-stag-hw-insert", + [NETIF_F_HW_VLAN_STAG_RX_BIT] = "rx-vlan-stag-hw-parse", + [NETIF_F_HW_VLAN_STAG_FILTER_BIT] = "rx-vlan-stag-filter", + [NETIF_F_VLAN_CHALLENGED_BIT] = "vlan-challenged", + [NETIF_F_GSO_BIT] = "tx-generic-segmentation", + [NETIF_F_LLTX_BIT] = "tx-lockless", + [NETIF_F_NETNS_LOCAL_BIT] = "netns-local", + [NETIF_F_GRO_BIT] = "rx-gro", + [NETIF_F_GRO_HW_BIT] = "rx-gro-hw", + [NETIF_F_LRO_BIT] = "rx-lro", + + [NETIF_F_TSO_BIT] = "tx-tcp-segmentation", + [NETIF_F_GSO_ROBUST_BIT] = "tx-gso-robust", + [NETIF_F_TSO_ECN_BIT] = "tx-tcp-ecn-segmentation", + [NETIF_F_TSO_MANGLEID_BIT] = "tx-tcp-mangleid-segmentation", + [NETIF_F_TSO6_BIT] = "tx-tcp6-segmentation", + [NETIF_F_FSO_BIT] = "tx-fcoe-segmentation", + [NETIF_F_GSO_GRE_BIT] = "tx-gre-segmentation", + [NETIF_F_GSO_GRE_CSUM_BIT] = "tx-gre-csum-segmentation", + [NETIF_F_GSO_IPXIP4_BIT] = "tx-ipxip4-segmentation", + [NETIF_F_GSO_IPXIP6_BIT] = "tx-ipxip6-segmentation", + [NETIF_F_GSO_UDP_TUNNEL_BIT] = "tx-udp_tnl-segmentation", + [NETIF_F_GSO_UDP_TUNNEL_CSUM_BIT] = "tx-udp_tnl-csum-segmentation", + [NETIF_F_GSO_PARTIAL_BIT] = "tx-gso-partial", + [NETIF_F_GSO_TUNNEL_REMCSUM_BIT] = "tx-tunnel-remcsum-segmentation", + [NETIF_F_GSO_SCTP_BIT] = "tx-sctp-segmentation", + [NETIF_F_GSO_ESP_BIT] = "tx-esp-segmentation", + [NETIF_F_GSO_UDP_L4_BIT] = "tx-udp-segmentation", + [NETIF_F_GSO_FRAGLIST_BIT] = "tx-gso-list", + + [NETIF_F_FCOE_CRC_BIT] = "tx-checksum-fcoe-crc", + [NETIF_F_SCTP_CRC_BIT] = "tx-checksum-sctp", + [NETIF_F_FCOE_MTU_BIT] = "fcoe-mtu", + [NETIF_F_NTUPLE_BIT] = "rx-ntuple-filter", + [NETIF_F_RXHASH_BIT] = "rx-hashing", + [NETIF_F_RXCSUM_BIT] = "rx-checksum", + [NETIF_F_NOCACHE_COPY_BIT] = "tx-nocache-copy", + [NETIF_F_LOOPBACK_BIT] = "loopback", + [NETIF_F_RXFCS_BIT] = "rx-fcs", + [NETIF_F_RXALL_BIT] = "rx-all", + [NETIF_F_HW_L2FW_DOFFLOAD_BIT] = "l2-fwd-offload", + [NETIF_F_HW_TC_BIT] = "hw-tc-offload", + [NETIF_F_HW_ESP_BIT] = "esp-hw-offload", + [NETIF_F_HW_ESP_TX_CSUM_BIT] = 
"esp-tx-csum-hw-offload", + [NETIF_F_RX_UDP_TUNNEL_PORT_BIT] = "rx-udp_tunnel-port-offload", + [NETIF_F_HW_TLS_RECORD_BIT] = "tls-hw-record", + [NETIF_F_HW_TLS_TX_BIT] = "tls-hw-tx-offload", + [NETIF_F_HW_TLS_RX_BIT] = "tls-hw-rx-offload", + [NETIF_F_GRO_FRAGLIST_BIT] = "rx-gro-list", + [NETIF_F_HW_MACSEC_BIT] = "macsec-hw-offload", + [NETIF_F_GRO_UDP_FWD_BIT] = "rx-udp-gro-forwarding", + [NETIF_F_HW_HSR_TAG_INS_BIT] = "hsr-tag-ins-offload", + [NETIF_F_HW_HSR_TAG_RM_BIT] = "hsr-tag-rm-offload", + [NETIF_F_HW_HSR_FWD_BIT] = "hsr-fwd-offload", + [NETIF_F_HW_HSR_DUP_BIT] = "hsr-dup-offload", +}; + +const char +rss_hash_func_strings[ETH_RSS_HASH_FUNCS_COUNT][ETH_GSTRING_LEN] = { + [ETH_RSS_HASH_TOP_BIT] = "toeplitz", + [ETH_RSS_HASH_XOR_BIT] = "xor", + [ETH_RSS_HASH_CRC32_BIT] = "crc32", +}; + +const char +tunable_strings[__ETHTOOL_TUNABLE_COUNT][ETH_GSTRING_LEN] = { + [ETHTOOL_ID_UNSPEC] = "Unspec", + [ETHTOOL_RX_COPYBREAK] = "rx-copybreak", + [ETHTOOL_TX_COPYBREAK] = "tx-copybreak", + [ETHTOOL_PFC_PREVENTION_TOUT] = "pfc-prevention-tout", + [ETHTOOL_TX_COPYBREAK_BUF_SIZE] = "tx-copybreak-buf-size", +}; + +const char +phy_tunable_strings[__ETHTOOL_PHY_TUNABLE_COUNT][ETH_GSTRING_LEN] = { + [ETHTOOL_ID_UNSPEC] = "Unspec", + [ETHTOOL_PHY_DOWNSHIFT] = "phy-downshift", + [ETHTOOL_PHY_FAST_LINK_DOWN] = "phy-fast-link-down", + [ETHTOOL_PHY_EDPD] = "phy-energy-detect-power-down", +}; + +#define __LINK_MODE_NAME(speed, type, duplex) \ + #speed "base" #type "/" #duplex +#define __DEFINE_LINK_MODE_NAME(speed, type, duplex) \ + [ETHTOOL_LINK_MODE(speed, type, duplex)] = \ + __LINK_MODE_NAME(speed, type, duplex) +#define __DEFINE_SPECIAL_MODE_NAME(_mode, _name) \ + [ETHTOOL_LINK_MODE_ ## _mode ## _BIT] = _name + +const char link_mode_names[][ETH_GSTRING_LEN] = { + __DEFINE_LINK_MODE_NAME(10, T, Half), + __DEFINE_LINK_MODE_NAME(10, T, Full), + __DEFINE_LINK_MODE_NAME(100, T, Half), + __DEFINE_LINK_MODE_NAME(100, T, Full), + __DEFINE_LINK_MODE_NAME(1000, T, Half), + __DEFINE_LINK_MODE_NAME(1000, T, Full), + __DEFINE_SPECIAL_MODE_NAME(Autoneg, "Autoneg"), + __DEFINE_SPECIAL_MODE_NAME(TP, "TP"), + __DEFINE_SPECIAL_MODE_NAME(AUI, "AUI"), + __DEFINE_SPECIAL_MODE_NAME(MII, "MII"), + __DEFINE_SPECIAL_MODE_NAME(FIBRE, "FIBRE"), + __DEFINE_SPECIAL_MODE_NAME(BNC, "BNC"), + __DEFINE_LINK_MODE_NAME(10000, T, Full), + __DEFINE_SPECIAL_MODE_NAME(Pause, "Pause"), + __DEFINE_SPECIAL_MODE_NAME(Asym_Pause, "Asym_Pause"), + __DEFINE_LINK_MODE_NAME(2500, X, Full), + __DEFINE_SPECIAL_MODE_NAME(Backplane, "Backplane"), + __DEFINE_LINK_MODE_NAME(1000, KX, Full), + __DEFINE_LINK_MODE_NAME(10000, KX4, Full), + __DEFINE_LINK_MODE_NAME(10000, KR, Full), + __DEFINE_SPECIAL_MODE_NAME(10000baseR_FEC, "10000baseR_FEC"), + __DEFINE_LINK_MODE_NAME(20000, MLD2, Full), + __DEFINE_LINK_MODE_NAME(20000, KR2, Full), + __DEFINE_LINK_MODE_NAME(40000, KR4, Full), + __DEFINE_LINK_MODE_NAME(40000, CR4, Full), + __DEFINE_LINK_MODE_NAME(40000, SR4, Full), + __DEFINE_LINK_MODE_NAME(40000, LR4, Full), + __DEFINE_LINK_MODE_NAME(56000, KR4, Full), + __DEFINE_LINK_MODE_NAME(56000, CR4, Full), + __DEFINE_LINK_MODE_NAME(56000, SR4, Full), + __DEFINE_LINK_MODE_NAME(56000, LR4, Full), + __DEFINE_LINK_MODE_NAME(25000, CR, Full), + __DEFINE_LINK_MODE_NAME(25000, KR, Full), + __DEFINE_LINK_MODE_NAME(25000, SR, Full), + __DEFINE_LINK_MODE_NAME(50000, CR2, Full), + __DEFINE_LINK_MODE_NAME(50000, KR2, Full), + __DEFINE_LINK_MODE_NAME(100000, KR4, Full), + __DEFINE_LINK_MODE_NAME(100000, SR4, Full), + __DEFINE_LINK_MODE_NAME(100000, CR4, Full), + 
__DEFINE_LINK_MODE_NAME(100000, LR4_ER4, Full), + __DEFINE_LINK_MODE_NAME(50000, SR2, Full), + __DEFINE_LINK_MODE_NAME(1000, X, Full), + __DEFINE_LINK_MODE_NAME(10000, CR, Full), + __DEFINE_LINK_MODE_NAME(10000, SR, Full), + __DEFINE_LINK_MODE_NAME(10000, LR, Full), + __DEFINE_LINK_MODE_NAME(10000, LRM, Full), + __DEFINE_LINK_MODE_NAME(10000, ER, Full), + __DEFINE_LINK_MODE_NAME(2500, T, Full), + __DEFINE_LINK_MODE_NAME(5000, T, Full), + __DEFINE_SPECIAL_MODE_NAME(FEC_NONE, "None"), + __DEFINE_SPECIAL_MODE_NAME(FEC_RS, "RS"), + __DEFINE_SPECIAL_MODE_NAME(FEC_BASER, "BASER"), + __DEFINE_LINK_MODE_NAME(50000, KR, Full), + __DEFINE_LINK_MODE_NAME(50000, SR, Full), + __DEFINE_LINK_MODE_NAME(50000, CR, Full), + __DEFINE_LINK_MODE_NAME(50000, LR_ER_FR, Full), + __DEFINE_LINK_MODE_NAME(50000, DR, Full), + __DEFINE_LINK_MODE_NAME(100000, KR2, Full), + __DEFINE_LINK_MODE_NAME(100000, SR2, Full), + __DEFINE_LINK_MODE_NAME(100000, CR2, Full), + __DEFINE_LINK_MODE_NAME(100000, LR2_ER2_FR2, Full), + __DEFINE_LINK_MODE_NAME(100000, DR2, Full), + __DEFINE_LINK_MODE_NAME(200000, KR4, Full), + __DEFINE_LINK_MODE_NAME(200000, SR4, Full), + __DEFINE_LINK_MODE_NAME(200000, LR4_ER4_FR4, Full), + __DEFINE_LINK_MODE_NAME(200000, DR4, Full), + __DEFINE_LINK_MODE_NAME(200000, CR4, Full), + __DEFINE_LINK_MODE_NAME(100, T1, Full), + __DEFINE_LINK_MODE_NAME(1000, T1, Full), + __DEFINE_LINK_MODE_NAME(400000, KR8, Full), + __DEFINE_LINK_MODE_NAME(400000, SR8, Full), + __DEFINE_LINK_MODE_NAME(400000, LR8_ER8_FR8, Full), + __DEFINE_LINK_MODE_NAME(400000, DR8, Full), + __DEFINE_LINK_MODE_NAME(400000, CR8, Full), + __DEFINE_SPECIAL_MODE_NAME(FEC_LLRS, "LLRS"), + __DEFINE_LINK_MODE_NAME(100000, KR, Full), + __DEFINE_LINK_MODE_NAME(100000, SR, Full), + __DEFINE_LINK_MODE_NAME(100000, LR_ER_FR, Full), + __DEFINE_LINK_MODE_NAME(100000, DR, Full), + __DEFINE_LINK_MODE_NAME(100000, CR, Full), + __DEFINE_LINK_MODE_NAME(200000, KR2, Full), + __DEFINE_LINK_MODE_NAME(200000, SR2, Full), + __DEFINE_LINK_MODE_NAME(200000, LR2_ER2_FR2, Full), + __DEFINE_LINK_MODE_NAME(200000, DR2, Full), + __DEFINE_LINK_MODE_NAME(200000, CR2, Full), + __DEFINE_LINK_MODE_NAME(400000, KR4, Full), + __DEFINE_LINK_MODE_NAME(400000, SR4, Full), + __DEFINE_LINK_MODE_NAME(400000, LR4_ER4_FR4, Full), + __DEFINE_LINK_MODE_NAME(400000, DR4, Full), + __DEFINE_LINK_MODE_NAME(400000, CR4, Full), + __DEFINE_LINK_MODE_NAME(100, FX, Half), + __DEFINE_LINK_MODE_NAME(100, FX, Full), + __DEFINE_LINK_MODE_NAME(10, T1L, Full), +}; +static_assert(ARRAY_SIZE(link_mode_names) == __ETHTOOL_LINK_MODE_MASK_NBITS); + +#define __LINK_MODE_LANES_CR 1 +#define __LINK_MODE_LANES_CR2 2 +#define __LINK_MODE_LANES_CR4 4 +#define __LINK_MODE_LANES_CR8 8 +#define __LINK_MODE_LANES_DR 1 +#define __LINK_MODE_LANES_DR2 2 +#define __LINK_MODE_LANES_DR4 4 +#define __LINK_MODE_LANES_DR8 8 +#define __LINK_MODE_LANES_KR 1 +#define __LINK_MODE_LANES_KR2 2 +#define __LINK_MODE_LANES_KR4 4 +#define __LINK_MODE_LANES_KR8 8 +#define __LINK_MODE_LANES_SR 1 +#define __LINK_MODE_LANES_SR2 2 +#define __LINK_MODE_LANES_SR4 4 +#define __LINK_MODE_LANES_SR8 8 +#define __LINK_MODE_LANES_ER 1 +#define __LINK_MODE_LANES_KX 1 +#define __LINK_MODE_LANES_KX4 4 +#define __LINK_MODE_LANES_LR 1 +#define __LINK_MODE_LANES_LR4 4 +#define __LINK_MODE_LANES_LR4_ER4 4 +#define __LINK_MODE_LANES_LR_ER_FR 1 +#define __LINK_MODE_LANES_LR2_ER2_FR2 2 +#define __LINK_MODE_LANES_LR4_ER4_FR4 4 +#define __LINK_MODE_LANES_LR8_ER8_FR8 8 +#define __LINK_MODE_LANES_LRM 1 +#define __LINK_MODE_LANES_MLD2 2 +#define __LINK_MODE_LANES_T 1 
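/* Illustrative expansion (annotation, not part of the upstream file): the
 * __LINK_MODE_NAME()/__DEFINE_LINK_MODE_NAME() helpers defined above,
 * together with ETHTOOL_LINK_MODE() from common.h, turn one
 * (speed, type, duplex) triple into both a bit index and a user-visible
 * string. For example,
 *
 *     __DEFINE_LINK_MODE_NAME(100, T, Full)
 *
 * expands to
 *
 *     [ETHTOOL_LINK_MODE_100baseT_Full_BIT] = "100baseT/Full",
 *
 * and the per-type __LINK_MODE_LANES_* constants in this block supply the
 * .lanes value when the same triple is passed to __DEFINE_LINK_MODE_PARAMS()
 * for the link_mode_params[] table that follows.
 */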
+#define __LINK_MODE_LANES_T1 1 +#define __LINK_MODE_LANES_X 1 +#define __LINK_MODE_LANES_FX 1 +#define __LINK_MODE_LANES_T1L 1 + +#define __DEFINE_LINK_MODE_PARAMS(_speed, _type, _duplex) \ + [ETHTOOL_LINK_MODE(_speed, _type, _duplex)] = { \ + .speed = SPEED_ ## _speed, \ + .lanes = __LINK_MODE_LANES_ ## _type, \ + .duplex = __DUPLEX_ ## _duplex \ + } +#define __DUPLEX_Half DUPLEX_HALF +#define __DUPLEX_Full DUPLEX_FULL +#define __DEFINE_SPECIAL_MODE_PARAMS(_mode) \ + [ETHTOOL_LINK_MODE_ ## _mode ## _BIT] = { \ + .speed = SPEED_UNKNOWN, \ + .lanes = 0, \ + .duplex = DUPLEX_UNKNOWN, \ + } + +const struct link_mode_info link_mode_params[] = { + __DEFINE_LINK_MODE_PARAMS(10, T, Half), + __DEFINE_LINK_MODE_PARAMS(10, T, Full), + __DEFINE_LINK_MODE_PARAMS(100, T, Half), + __DEFINE_LINK_MODE_PARAMS(100, T, Full), + __DEFINE_LINK_MODE_PARAMS(1000, T, Half), + __DEFINE_LINK_MODE_PARAMS(1000, T, Full), + __DEFINE_SPECIAL_MODE_PARAMS(Autoneg), + __DEFINE_SPECIAL_MODE_PARAMS(TP), + __DEFINE_SPECIAL_MODE_PARAMS(AUI), + __DEFINE_SPECIAL_MODE_PARAMS(MII), + __DEFINE_SPECIAL_MODE_PARAMS(FIBRE), + __DEFINE_SPECIAL_MODE_PARAMS(BNC), + __DEFINE_LINK_MODE_PARAMS(10000, T, Full), + __DEFINE_SPECIAL_MODE_PARAMS(Pause), + __DEFINE_SPECIAL_MODE_PARAMS(Asym_Pause), + __DEFINE_LINK_MODE_PARAMS(2500, X, Full), + __DEFINE_SPECIAL_MODE_PARAMS(Backplane), + __DEFINE_LINK_MODE_PARAMS(1000, KX, Full), + __DEFINE_LINK_MODE_PARAMS(10000, KX4, Full), + __DEFINE_LINK_MODE_PARAMS(10000, KR, Full), + [ETHTOOL_LINK_MODE_10000baseR_FEC_BIT] = { + .speed = SPEED_10000, + .lanes = 1, + .duplex = DUPLEX_FULL, + }, + __DEFINE_LINK_MODE_PARAMS(20000, MLD2, Full), + __DEFINE_LINK_MODE_PARAMS(20000, KR2, Full), + __DEFINE_LINK_MODE_PARAMS(40000, KR4, Full), + __DEFINE_LINK_MODE_PARAMS(40000, CR4, Full), + __DEFINE_LINK_MODE_PARAMS(40000, SR4, Full), + __DEFINE_LINK_MODE_PARAMS(40000, LR4, Full), + __DEFINE_LINK_MODE_PARAMS(56000, KR4, Full), + __DEFINE_LINK_MODE_PARAMS(56000, CR4, Full), + __DEFINE_LINK_MODE_PARAMS(56000, SR4, Full), + __DEFINE_LINK_MODE_PARAMS(56000, LR4, Full), + __DEFINE_LINK_MODE_PARAMS(25000, CR, Full), + __DEFINE_LINK_MODE_PARAMS(25000, KR, Full), + __DEFINE_LINK_MODE_PARAMS(25000, SR, Full), + __DEFINE_LINK_MODE_PARAMS(50000, CR2, Full), + __DEFINE_LINK_MODE_PARAMS(50000, KR2, Full), + __DEFINE_LINK_MODE_PARAMS(100000, KR4, Full), + __DEFINE_LINK_MODE_PARAMS(100000, SR4, Full), + __DEFINE_LINK_MODE_PARAMS(100000, CR4, Full), + __DEFINE_LINK_MODE_PARAMS(100000, LR4_ER4, Full), + __DEFINE_LINK_MODE_PARAMS(50000, SR2, Full), + __DEFINE_LINK_MODE_PARAMS(1000, X, Full), + __DEFINE_LINK_MODE_PARAMS(10000, CR, Full), + __DEFINE_LINK_MODE_PARAMS(10000, SR, Full), + __DEFINE_LINK_MODE_PARAMS(10000, LR, Full), + __DEFINE_LINK_MODE_PARAMS(10000, LRM, Full), + __DEFINE_LINK_MODE_PARAMS(10000, ER, Full), + __DEFINE_LINK_MODE_PARAMS(2500, T, Full), + __DEFINE_LINK_MODE_PARAMS(5000, T, Full), + __DEFINE_SPECIAL_MODE_PARAMS(FEC_NONE), + __DEFINE_SPECIAL_MODE_PARAMS(FEC_RS), + __DEFINE_SPECIAL_MODE_PARAMS(FEC_BASER), + __DEFINE_LINK_MODE_PARAMS(50000, KR, Full), + __DEFINE_LINK_MODE_PARAMS(50000, SR, Full), + __DEFINE_LINK_MODE_PARAMS(50000, CR, Full), + __DEFINE_LINK_MODE_PARAMS(50000, LR_ER_FR, Full), + __DEFINE_LINK_MODE_PARAMS(50000, DR, Full), + __DEFINE_LINK_MODE_PARAMS(100000, KR2, Full), + __DEFINE_LINK_MODE_PARAMS(100000, SR2, Full), + __DEFINE_LINK_MODE_PARAMS(100000, CR2, Full), + __DEFINE_LINK_MODE_PARAMS(100000, LR2_ER2_FR2, Full), + __DEFINE_LINK_MODE_PARAMS(100000, DR2, Full), + __DEFINE_LINK_MODE_PARAMS(200000, 
KR4, Full), + __DEFINE_LINK_MODE_PARAMS(200000, SR4, Full), + __DEFINE_LINK_MODE_PARAMS(200000, LR4_ER4_FR4, Full), + __DEFINE_LINK_MODE_PARAMS(200000, DR4, Full), + __DEFINE_LINK_MODE_PARAMS(200000, CR4, Full), + __DEFINE_LINK_MODE_PARAMS(100, T1, Full), + __DEFINE_LINK_MODE_PARAMS(1000, T1, Full), + __DEFINE_LINK_MODE_PARAMS(400000, KR8, Full), + __DEFINE_LINK_MODE_PARAMS(400000, SR8, Full), + __DEFINE_LINK_MODE_PARAMS(400000, LR8_ER8_FR8, Full), + __DEFINE_LINK_MODE_PARAMS(400000, DR8, Full), + __DEFINE_LINK_MODE_PARAMS(400000, CR8, Full), + __DEFINE_SPECIAL_MODE_PARAMS(FEC_LLRS), + __DEFINE_LINK_MODE_PARAMS(100000, KR, Full), + __DEFINE_LINK_MODE_PARAMS(100000, SR, Full), + __DEFINE_LINK_MODE_PARAMS(100000, LR_ER_FR, Full), + __DEFINE_LINK_MODE_PARAMS(100000, DR, Full), + __DEFINE_LINK_MODE_PARAMS(100000, CR, Full), + __DEFINE_LINK_MODE_PARAMS(200000, KR2, Full), + __DEFINE_LINK_MODE_PARAMS(200000, SR2, Full), + __DEFINE_LINK_MODE_PARAMS(200000, LR2_ER2_FR2, Full), + __DEFINE_LINK_MODE_PARAMS(200000, DR2, Full), + __DEFINE_LINK_MODE_PARAMS(200000, CR2, Full), + __DEFINE_LINK_MODE_PARAMS(400000, KR4, Full), + __DEFINE_LINK_MODE_PARAMS(400000, SR4, Full), + __DEFINE_LINK_MODE_PARAMS(400000, LR4_ER4_FR4, Full), + __DEFINE_LINK_MODE_PARAMS(400000, DR4, Full), + __DEFINE_LINK_MODE_PARAMS(400000, CR4, Full), + __DEFINE_LINK_MODE_PARAMS(100, FX, Half), + __DEFINE_LINK_MODE_PARAMS(100, FX, Full), + __DEFINE_LINK_MODE_PARAMS(10, T1L, Full), +}; +static_assert(ARRAY_SIZE(link_mode_params) == __ETHTOOL_LINK_MODE_MASK_NBITS); + +const char netif_msg_class_names[][ETH_GSTRING_LEN] = { + [NETIF_MSG_DRV_BIT] = "drv", + [NETIF_MSG_PROBE_BIT] = "probe", + [NETIF_MSG_LINK_BIT] = "link", + [NETIF_MSG_TIMER_BIT] = "timer", + [NETIF_MSG_IFDOWN_BIT] = "ifdown", + [NETIF_MSG_IFUP_BIT] = "ifup", + [NETIF_MSG_RX_ERR_BIT] = "rx_err", + [NETIF_MSG_TX_ERR_BIT] = "tx_err", + [NETIF_MSG_TX_QUEUED_BIT] = "tx_queued", + [NETIF_MSG_INTR_BIT] = "intr", + [NETIF_MSG_TX_DONE_BIT] = "tx_done", + [NETIF_MSG_RX_STATUS_BIT] = "rx_status", + [NETIF_MSG_PKTDATA_BIT] = "pktdata", + [NETIF_MSG_HW_BIT] = "hw", + [NETIF_MSG_WOL_BIT] = "wol", +}; +static_assert(ARRAY_SIZE(netif_msg_class_names) == NETIF_MSG_CLASS_COUNT); + +const char wol_mode_names[][ETH_GSTRING_LEN] = { + [const_ilog2(WAKE_PHY)] = "phy", + [const_ilog2(WAKE_UCAST)] = "ucast", + [const_ilog2(WAKE_MCAST)] = "mcast", + [const_ilog2(WAKE_BCAST)] = "bcast", + [const_ilog2(WAKE_ARP)] = "arp", + [const_ilog2(WAKE_MAGIC)] = "magic", + [const_ilog2(WAKE_MAGICSECURE)] = "magicsecure", + [const_ilog2(WAKE_FILTER)] = "filter", +}; +static_assert(ARRAY_SIZE(wol_mode_names) == WOL_MODE_COUNT); + +const char sof_timestamping_names[][ETH_GSTRING_LEN] = { + [const_ilog2(SOF_TIMESTAMPING_TX_HARDWARE)] = "hardware-transmit", + [const_ilog2(SOF_TIMESTAMPING_TX_SOFTWARE)] = "software-transmit", + [const_ilog2(SOF_TIMESTAMPING_RX_HARDWARE)] = "hardware-receive", + [const_ilog2(SOF_TIMESTAMPING_RX_SOFTWARE)] = "software-receive", + [const_ilog2(SOF_TIMESTAMPING_SOFTWARE)] = "software-system-clock", + [const_ilog2(SOF_TIMESTAMPING_SYS_HARDWARE)] = "hardware-legacy-clock", + [const_ilog2(SOF_TIMESTAMPING_RAW_HARDWARE)] = "hardware-raw-clock", + [const_ilog2(SOF_TIMESTAMPING_OPT_ID)] = "option-id", + [const_ilog2(SOF_TIMESTAMPING_TX_SCHED)] = "sched-transmit", + [const_ilog2(SOF_TIMESTAMPING_TX_ACK)] = "ack-transmit", + [const_ilog2(SOF_TIMESTAMPING_OPT_CMSG)] = "option-cmsg", + [const_ilog2(SOF_TIMESTAMPING_OPT_TSONLY)] = "option-tsonly", + [const_ilog2(SOF_TIMESTAMPING_OPT_STATS)] = 
"option-stats", + [const_ilog2(SOF_TIMESTAMPING_OPT_PKTINFO)] = "option-pktinfo", + [const_ilog2(SOF_TIMESTAMPING_OPT_TX_SWHW)] = "option-tx-swhw", + [const_ilog2(SOF_TIMESTAMPING_BIND_PHC)] = "bind-phc", +}; +static_assert(ARRAY_SIZE(sof_timestamping_names) == __SOF_TIMESTAMPING_CNT); + +const char ts_tx_type_names[][ETH_GSTRING_LEN] = { + [HWTSTAMP_TX_OFF] = "off", + [HWTSTAMP_TX_ON] = "on", + [HWTSTAMP_TX_ONESTEP_SYNC] = "onestep-sync", + [HWTSTAMP_TX_ONESTEP_P2P] = "onestep-p2p", +}; +static_assert(ARRAY_SIZE(ts_tx_type_names) == __HWTSTAMP_TX_CNT); + +const char ts_rx_filter_names[][ETH_GSTRING_LEN] = { + [HWTSTAMP_FILTER_NONE] = "none", + [HWTSTAMP_FILTER_ALL] = "all", + [HWTSTAMP_FILTER_SOME] = "some", + [HWTSTAMP_FILTER_PTP_V1_L4_EVENT] = "ptpv1-l4-event", + [HWTSTAMP_FILTER_PTP_V1_L4_SYNC] = "ptpv1-l4-sync", + [HWTSTAMP_FILTER_PTP_V1_L4_DELAY_REQ] = "ptpv1-l4-delay-req", + [HWTSTAMP_FILTER_PTP_V2_L4_EVENT] = "ptpv2-l4-event", + [HWTSTAMP_FILTER_PTP_V2_L4_SYNC] = "ptpv2-l4-sync", + [HWTSTAMP_FILTER_PTP_V2_L4_DELAY_REQ] = "ptpv2-l4-delay-req", + [HWTSTAMP_FILTER_PTP_V2_L2_EVENT] = "ptpv2-l2-event", + [HWTSTAMP_FILTER_PTP_V2_L2_SYNC] = "ptpv2-l2-sync", + [HWTSTAMP_FILTER_PTP_V2_L2_DELAY_REQ] = "ptpv2-l2-delay-req", + [HWTSTAMP_FILTER_PTP_V2_EVENT] = "ptpv2-event", + [HWTSTAMP_FILTER_PTP_V2_SYNC] = "ptpv2-sync", + [HWTSTAMP_FILTER_PTP_V2_DELAY_REQ] = "ptpv2-delay-req", + [HWTSTAMP_FILTER_NTP_ALL] = "ntp-all", +}; +static_assert(ARRAY_SIZE(ts_rx_filter_names) == __HWTSTAMP_FILTER_CNT); + +const char udp_tunnel_type_names[][ETH_GSTRING_LEN] = { + [ETHTOOL_UDP_TUNNEL_TYPE_VXLAN] = "vxlan", + [ETHTOOL_UDP_TUNNEL_TYPE_GENEVE] = "geneve", + [ETHTOOL_UDP_TUNNEL_TYPE_VXLAN_GPE] = "vxlan-gpe", +}; +static_assert(ARRAY_SIZE(udp_tunnel_type_names) == + __ETHTOOL_UDP_TUNNEL_TYPE_CNT); + +/* return false if legacy contained non-0 deprecated fields + * maxtxpkt/maxrxpkt. 
rest of ksettings always updated + */ +bool +convert_legacy_settings_to_link_ksettings( + struct ethtool_link_ksettings *link_ksettings, + const struct ethtool_cmd *legacy_settings) +{ + bool retval = true; + + memset(link_ksettings, 0, sizeof(*link_ksettings)); + + /* This is used to tell users that driver is still using these + * deprecated legacy fields, and they should not use + * %ETHTOOL_GLINKSETTINGS/%ETHTOOL_SLINKSETTINGS + */ + if (legacy_settings->maxtxpkt || + legacy_settings->maxrxpkt) + retval = false; + + ethtool_convert_legacy_u32_to_link_mode( + link_ksettings->link_modes.supported, + legacy_settings->supported); + ethtool_convert_legacy_u32_to_link_mode( + link_ksettings->link_modes.advertising, + legacy_settings->advertising); + ethtool_convert_legacy_u32_to_link_mode( + link_ksettings->link_modes.lp_advertising, + legacy_settings->lp_advertising); + link_ksettings->base.speed + = ethtool_cmd_speed(legacy_settings); + link_ksettings->base.duplex + = legacy_settings->duplex; + link_ksettings->base.port + = legacy_settings->port; + link_ksettings->base.phy_address + = legacy_settings->phy_address; + link_ksettings->base.autoneg + = legacy_settings->autoneg; + link_ksettings->base.mdio_support + = legacy_settings->mdio_support; + link_ksettings->base.eth_tp_mdix + = legacy_settings->eth_tp_mdix; + link_ksettings->base.eth_tp_mdix_ctrl + = legacy_settings->eth_tp_mdix_ctrl; + return retval; +} + +int __ethtool_get_link(struct net_device *dev) +{ + if (!dev->ethtool_ops->get_link) + return -EOPNOTSUPP; + + return netif_running(dev) && dev->ethtool_ops->get_link(dev); +} + +int ethtool_get_max_rxfh_channel(struct net_device *dev, u32 *max) +{ + u32 dev_size, current_max = 0; + u32 *indir; + int ret; + + if (!dev->ethtool_ops->get_rxfh_indir_size || + !dev->ethtool_ops->get_rxfh) + return -EOPNOTSUPP; + dev_size = dev->ethtool_ops->get_rxfh_indir_size(dev); + if (dev_size == 0) + return -EOPNOTSUPP; + + indir = kcalloc(dev_size, sizeof(indir[0]), GFP_USER); + if (!indir) + return -ENOMEM; + + ret = dev->ethtool_ops->get_rxfh(dev, indir, NULL, NULL); + if (ret) + goto out; + + while (dev_size--) + current_max = max(current_max, indir[dev_size]); + + *max = current_max; + +out: + kfree(indir); + return ret; +} + +int ethtool_check_ops(const struct ethtool_ops *ops) +{ + if (WARN_ON(ops->set_coalesce && !ops->supported_coalesce_params)) + return -EINVAL; + /* NOTE: sufficiently insane drivers may swap ethtool_ops at runtime, + * the fact that ops are checked at registration time does not + * mean the ops attached to a netdev later on are sane. 
+ */ + return 0; +} + +int __ethtool_get_ts_info(struct net_device *dev, struct ethtool_ts_info *info) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + struct phy_device *phydev = dev->phydev; + + memset(info, 0, sizeof(*info)); + info->cmd = ETHTOOL_GET_TS_INFO; + + if (phy_has_tsinfo(phydev)) + return phy_ts_info(phydev, info); + if (ops->get_ts_info) + return ops->get_ts_info(dev, info); + + info->so_timestamping = SOF_TIMESTAMPING_RX_SOFTWARE | + SOF_TIMESTAMPING_SOFTWARE; + info->phc_index = -1; + + return 0; +} + +int ethtool_get_phc_vclocks(struct net_device *dev, int **vclock_index) +{ + struct ethtool_ts_info info = { }; + int num = 0; + + if (!__ethtool_get_ts_info(dev, &info)) + num = ptp_get_vclocks_index(info.phc_index, vclock_index); + + return num; +} +EXPORT_SYMBOL(ethtool_get_phc_vclocks); + +const struct ethtool_phy_ops *ethtool_phy_ops; + +void ethtool_set_ethtool_phy_ops(const struct ethtool_phy_ops *ops) +{ + rtnl_lock(); + ethtool_phy_ops = ops; + rtnl_unlock(); +} +EXPORT_SYMBOL_GPL(ethtool_set_ethtool_phy_ops); + +void +ethtool_params_from_link_mode(struct ethtool_link_ksettings *link_ksettings, + enum ethtool_link_mode_bit_indices link_mode) +{ + const struct link_mode_info *link_info; + + if (WARN_ON_ONCE(link_mode >= __ETHTOOL_LINK_MODE_MASK_NBITS)) + return; + + link_info = &link_mode_params[link_mode]; + link_ksettings->base.speed = link_info->speed; + link_ksettings->lanes = link_info->lanes; + link_ksettings->base.duplex = link_info->duplex; +} +EXPORT_SYMBOL_GPL(ethtool_params_from_link_mode); diff --git a/net/ethtool/common.h b/net/ethtool/common.h new file mode 100644 index 000000000..c1779657e --- /dev/null +++ b/net/ethtool/common.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ + +#ifndef _ETHTOOL_COMMON_H +#define _ETHTOOL_COMMON_H + +#include +#include + +#define ETHTOOL_DEV_FEATURE_WORDS DIV_ROUND_UP(NETDEV_FEATURE_COUNT, 32) + +/* compose link mode index from speed, type and duplex */ +#define ETHTOOL_LINK_MODE(speed, type, duplex) \ + ETHTOOL_LINK_MODE_ ## speed ## base ## type ## _ ## duplex ## _BIT + +#define __SOF_TIMESTAMPING_CNT (const_ilog2(SOF_TIMESTAMPING_LAST) + 1) + +struct link_mode_info { + int speed; + u8 lanes; + u8 duplex; +}; + +extern const char +netdev_features_strings[NETDEV_FEATURE_COUNT][ETH_GSTRING_LEN]; +extern const char +rss_hash_func_strings[ETH_RSS_HASH_FUNCS_COUNT][ETH_GSTRING_LEN]; +extern const char +tunable_strings[__ETHTOOL_TUNABLE_COUNT][ETH_GSTRING_LEN]; +extern const char +phy_tunable_strings[__ETHTOOL_PHY_TUNABLE_COUNT][ETH_GSTRING_LEN]; +extern const char link_mode_names[][ETH_GSTRING_LEN]; +extern const struct link_mode_info link_mode_params[]; +extern const char netif_msg_class_names[][ETH_GSTRING_LEN]; +extern const char wol_mode_names[][ETH_GSTRING_LEN]; +extern const char sof_timestamping_names[][ETH_GSTRING_LEN]; +extern const char ts_tx_type_names[][ETH_GSTRING_LEN]; +extern const char ts_rx_filter_names[][ETH_GSTRING_LEN]; +extern const char udp_tunnel_type_names[][ETH_GSTRING_LEN]; + +int __ethtool_get_link(struct net_device *dev); + +bool convert_legacy_settings_to_link_ksettings( + struct ethtool_link_ksettings *link_ksettings, + const struct ethtool_cmd *legacy_settings); +int ethtool_get_max_rxfh_channel(struct net_device *dev, u32 *max); +int __ethtool_get_ts_info(struct net_device *dev, struct ethtool_ts_info *info); + +extern const struct ethtool_phy_ops *ethtool_phy_ops; +extern const struct ethtool_pse_ops *ethtool_pse_ops; + +int ethtool_get_module_info_call(struct 
net_device *dev, + struct ethtool_modinfo *modinfo); +int ethtool_get_module_eeprom_call(struct net_device *dev, + struct ethtool_eeprom *ee, u8 *data); + +#endif /* _ETHTOOL_COMMON_H */ diff --git a/net/ethtool/debug.c b/net/ethtool/debug.c new file mode 100644 index 000000000..d73888c7d --- /dev/null +++ b/net/ethtool/debug.c @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct debug_req_info { + struct ethnl_req_info base; +}; + +struct debug_reply_data { + struct ethnl_reply_data base; + u32 msg_mask; +}; + +#define DEBUG_REPDATA(__reply_base) \ + container_of(__reply_base, struct debug_reply_data, base) + +const struct nla_policy ethnl_debug_get_policy[] = { + [ETHTOOL_A_DEBUG_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int debug_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct debug_reply_data *data = DEBUG_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_msglevel) + return -EOPNOTSUPP; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + data->msg_mask = dev->ethtool_ops->get_msglevel(dev); + ethnl_ops_complete(dev); + + return 0; +} + +static int debug_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct debug_reply_data *data = DEBUG_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + + return ethnl_bitset32_size(&data->msg_mask, NULL, NETIF_MSG_CLASS_COUNT, + netif_msg_class_names, compact); +} + +static int debug_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct debug_reply_data *data = DEBUG_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + + return ethnl_put_bitset32(skb, ETHTOOL_A_DEBUG_MSGMASK, &data->msg_mask, + NULL, NETIF_MSG_CLASS_COUNT, + netif_msg_class_names, compact); +} + +const struct ethnl_request_ops ethnl_debug_request_ops = { + .request_cmd = ETHTOOL_MSG_DEBUG_GET, + .reply_cmd = ETHTOOL_MSG_DEBUG_GET_REPLY, + .hdr_attr = ETHTOOL_A_DEBUG_HEADER, + .req_info_size = sizeof(struct debug_req_info), + .reply_data_size = sizeof(struct debug_reply_data), + + .prepare_data = debug_prepare_data, + .reply_size = debug_reply_size, + .fill_reply = debug_fill_reply, +}; + +/* DEBUG_SET */ + +const struct nla_policy ethnl_debug_set_policy[] = { + [ETHTOOL_A_DEBUG_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_DEBUG_MSGMASK] = { .type = NLA_NESTED }, +}; + +int ethnl_set_debug(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod = false; + u32 msg_mask; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_DEBUG_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ret = -EOPNOTSUPP; + if (!dev->ethtool_ops->get_msglevel || !dev->ethtool_ops->set_msglevel) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + msg_mask = dev->ethtool_ops->get_msglevel(dev); + ret = ethnl_update_bitset32(&msg_mask, NETIF_MSG_CLASS_COUNT, + tb[ETHTOOL_A_DEBUG_MSGMASK], + netif_msg_class_names, info->extack, &mod); + if (ret < 0 || !mod) + goto out_ops; + + dev->ethtool_ops->set_msglevel(dev, 
msg_mask); + ethtool_notify(dev, ETHTOOL_MSG_DEBUG_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/eee.c b/net/ethtool/eee.c new file mode 100644 index 000000000..45c42b2d5 --- /dev/null +++ b/net/ethtool/eee.c @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +#define EEE_MODES_COUNT \ + (sizeof_field(struct ethtool_eee, supported) * BITS_PER_BYTE) + +struct eee_req_info { + struct ethnl_req_info base; +}; + +struct eee_reply_data { + struct ethnl_reply_data base; + struct ethtool_eee eee; +}; + +#define EEE_REPDATA(__reply_base) \ + container_of(__reply_base, struct eee_reply_data, base) + +const struct nla_policy ethnl_eee_get_policy[] = { + [ETHTOOL_A_EEE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int eee_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct eee_reply_data *data = EEE_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_eee) + return -EOPNOTSUPP; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + ret = dev->ethtool_ops->get_eee(dev, &data->eee); + ethnl_ops_complete(dev); + + return ret; +} + +static int eee_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct eee_reply_data *data = EEE_REPDATA(reply_base); + const struct ethtool_eee *eee = &data->eee; + int len = 0; + int ret; + + BUILD_BUG_ON(sizeof(eee->advertised) * BITS_PER_BYTE != + EEE_MODES_COUNT); + BUILD_BUG_ON(sizeof(eee->lp_advertised) * BITS_PER_BYTE != + EEE_MODES_COUNT); + + /* MODES_OURS */ + ret = ethnl_bitset32_size(&eee->advertised, &eee->supported, + EEE_MODES_COUNT, link_mode_names, compact); + if (ret < 0) + return ret; + len += ret; + /* MODES_PEERS */ + ret = ethnl_bitset32_size(&eee->lp_advertised, NULL, + EEE_MODES_COUNT, link_mode_names, compact); + if (ret < 0) + return ret; + len += ret; + + len += nla_total_size(sizeof(u8)) + /* _EEE_ACTIVE */ + nla_total_size(sizeof(u8)) + /* _EEE_ENABLED */ + nla_total_size(sizeof(u8)) + /* _EEE_TX_LPI_ENABLED */ + nla_total_size(sizeof(u32)); /* _EEE_TX_LPI_TIMER */ + + return len; +} + +static int eee_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct eee_reply_data *data = EEE_REPDATA(reply_base); + const struct ethtool_eee *eee = &data->eee; + int ret; + + ret = ethnl_put_bitset32(skb, ETHTOOL_A_EEE_MODES_OURS, + &eee->advertised, &eee->supported, + EEE_MODES_COUNT, link_mode_names, compact); + if (ret < 0) + return ret; + ret = ethnl_put_bitset32(skb, ETHTOOL_A_EEE_MODES_PEER, + &eee->lp_advertised, NULL, EEE_MODES_COUNT, + link_mode_names, compact); + if (ret < 0) + return ret; + + if (nla_put_u8(skb, ETHTOOL_A_EEE_ACTIVE, !!eee->eee_active) || + nla_put_u8(skb, ETHTOOL_A_EEE_ENABLED, !!eee->eee_enabled) || + nla_put_u8(skb, ETHTOOL_A_EEE_TX_LPI_ENABLED, + !!eee->tx_lpi_enabled) || + nla_put_u32(skb, ETHTOOL_A_EEE_TX_LPI_TIMER, eee->tx_lpi_timer)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_eee_request_ops = { + .request_cmd = ETHTOOL_MSG_EEE_GET, + .reply_cmd = ETHTOOL_MSG_EEE_GET_REPLY, 
+ .hdr_attr = ETHTOOL_A_EEE_HEADER, + .req_info_size = sizeof(struct eee_req_info), + .reply_data_size = sizeof(struct eee_reply_data), + + .prepare_data = eee_prepare_data, + .reply_size = eee_reply_size, + .fill_reply = eee_fill_reply, +}; + +/* EEE_SET */ + +const struct nla_policy ethnl_eee_set_policy[] = { + [ETHTOOL_A_EEE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_EEE_MODES_OURS] = { .type = NLA_NESTED }, + [ETHTOOL_A_EEE_ENABLED] = { .type = NLA_U8 }, + [ETHTOOL_A_EEE_TX_LPI_ENABLED] = { .type = NLA_U8 }, + [ETHTOOL_A_EEE_TX_LPI_TIMER] = { .type = NLA_U32 }, +}; + +int ethnl_set_eee(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + const struct ethtool_ops *ops; + struct ethtool_eee eee = {}; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_EEE_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_eee || !ops->set_eee) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ret = ops->get_eee(dev, &eee); + if (ret < 0) + goto out_ops; + + ret = ethnl_update_bitset32(&eee.advertised, EEE_MODES_COUNT, + tb[ETHTOOL_A_EEE_MODES_OURS], + link_mode_names, info->extack, &mod); + if (ret < 0) + goto out_ops; + ethnl_update_bool32(&eee.eee_enabled, tb[ETHTOOL_A_EEE_ENABLED], &mod); + ethnl_update_bool32(&eee.tx_lpi_enabled, + tb[ETHTOOL_A_EEE_TX_LPI_ENABLED], &mod); + ethnl_update_u32(&eee.tx_lpi_timer, tb[ETHTOOL_A_EEE_TX_LPI_TIMER], + &mod); + ret = 0; + if (!mod) + goto out_ops; + + ret = dev->ethtool_ops->set_eee(dev, &eee); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_EEE_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/eeprom.c b/net/ethtool/eeprom.c new file mode 100644 index 000000000..49c0a2a77 --- /dev/null +++ b/net/ethtool/eeprom.c @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include "netlink.h" +#include "common.h" + +struct eeprom_req_info { + struct ethnl_req_info base; + u32 offset; + u32 length; + u8 page; + u8 bank; + u8 i2c_address; +}; + +struct eeprom_reply_data { + struct ethnl_reply_data base; + u32 length; + u8 *data; +}; + +#define MODULE_EEPROM_REQINFO(__req_base) \ + container_of(__req_base, struct eeprom_req_info, base) + +#define MODULE_EEPROM_REPDATA(__reply_base) \ + container_of(__reply_base, struct eeprom_reply_data, base) + +static int fallback_set_params(struct eeprom_req_info *request, + struct ethtool_modinfo *modinfo, + struct ethtool_eeprom *eeprom) +{ + u32 offset = request->offset; + u32 length = request->length; + + if (request->page) + offset = request->page * ETH_MODULE_EEPROM_PAGE_LEN + offset; + + if (modinfo->type == ETH_MODULE_SFF_8472 && + request->i2c_address == 0x51) + offset += ETH_MODULE_EEPROM_PAGE_LEN * 2; + + if (offset >= modinfo->eeprom_len) + return -EINVAL; + + eeprom->cmd = ETHTOOL_GMODULEEEPROM; + eeprom->len = length; + eeprom->offset = offset; + + return 0; +} + +static int eeprom_fallback(struct eeprom_req_info *request, + struct eeprom_reply_data *reply, + struct genl_info *info) +{ + struct net_device *dev = reply->base.dev; + struct ethtool_modinfo modinfo = {0}; + struct ethtool_eeprom eeprom = {0}; + u8 *data; + int err; + + 
modinfo.cmd = ETHTOOL_GMODULEINFO; + err = ethtool_get_module_info_call(dev, &modinfo); + if (err < 0) + return err; + + err = fallback_set_params(request, &modinfo, &eeprom); + if (err < 0) + return err; + + data = kmalloc(eeprom.len, GFP_KERNEL); + if (!data) + return -ENOMEM; + err = ethtool_get_module_eeprom_call(dev, &eeprom, data); + if (err < 0) + goto err_out; + + reply->data = data; + reply->length = eeprom.len; + + return 0; + +err_out: + kfree(data); + return err; +} + +static int get_module_eeprom_by_page(struct net_device *dev, + struct ethtool_module_eeprom *page_data, + struct netlink_ext_ack *extack) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (dev->sfp_bus) + return sfp_get_module_eeprom_by_page(dev->sfp_bus, page_data, extack); + + if (ops->get_module_eeprom_by_page) + return ops->get_module_eeprom_by_page(dev, page_data, extack); + + return -EOPNOTSUPP; +} + +static int eeprom_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct eeprom_reply_data *reply = MODULE_EEPROM_REPDATA(reply_base); + struct eeprom_req_info *request = MODULE_EEPROM_REQINFO(req_base); + struct ethtool_module_eeprom page_data = {0}; + struct net_device *dev = reply_base->dev; + int ret; + + page_data.offset = request->offset; + page_data.length = request->length; + page_data.i2c_address = request->i2c_address; + page_data.page = request->page; + page_data.bank = request->bank; + page_data.data = kmalloc(page_data.length, GFP_KERNEL); + if (!page_data.data) + return -ENOMEM; + + ret = ethnl_ops_begin(dev); + if (ret) + goto err_free; + + ret = get_module_eeprom_by_page(dev, &page_data, info ? info->extack : NULL); + if (ret < 0) + goto err_ops; + + reply->length = ret; + reply->data = page_data.data; + + ethnl_ops_complete(dev); + return 0; + +err_ops: + ethnl_ops_complete(dev); +err_free: + kfree(page_data.data); + + if (ret == -EOPNOTSUPP) + return eeprom_fallback(request, reply, info); + return ret; +} + +static int eeprom_parse_request(struct ethnl_req_info *req_info, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct eeprom_req_info *request = MODULE_EEPROM_REQINFO(req_info); + + if (!tb[ETHTOOL_A_MODULE_EEPROM_OFFSET] || + !tb[ETHTOOL_A_MODULE_EEPROM_LENGTH] || + !tb[ETHTOOL_A_MODULE_EEPROM_PAGE] || + !tb[ETHTOOL_A_MODULE_EEPROM_I2C_ADDRESS]) + return -EINVAL; + + request->i2c_address = nla_get_u8(tb[ETHTOOL_A_MODULE_EEPROM_I2C_ADDRESS]); + request->offset = nla_get_u32(tb[ETHTOOL_A_MODULE_EEPROM_OFFSET]); + request->length = nla_get_u32(tb[ETHTOOL_A_MODULE_EEPROM_LENGTH]); + + /* The following set of conditions limit the API to only dump 1/2 + * EEPROM page without crossing low page boundary located at offset 128. + * This means user may only request dumps of length limited to 128 from + * either low 128 bytes or high 128 bytes. + * For pages higher than 0 only high 128 bytes are accessible. 
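+ * For example, as enforced by the checks below: page 1, offset 128,
+ * length 128 is the largest request that passes; page 1, offset 100 is
+ * rejected because only page 0 may read the lower half; and page 0,
+ * offset 100, length 50 is rejected for crossing the half page
+ * boundary at offset 128.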
+ */ + request->page = nla_get_u8(tb[ETHTOOL_A_MODULE_EEPROM_PAGE]); + if (request->page && request->offset < ETH_MODULE_EEPROM_PAGE_LEN) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_MODULE_EEPROM_PAGE], + "reading from lower half page is allowed for page 0 only"); + return -EINVAL; + } + + if (request->offset < ETH_MODULE_EEPROM_PAGE_LEN && + request->offset + request->length > ETH_MODULE_EEPROM_PAGE_LEN) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_MODULE_EEPROM_LENGTH], + "reading cross half page boundary is illegal"); + return -EINVAL; + } else if (request->offset + request->length > ETH_MODULE_EEPROM_PAGE_LEN * 2) { + NL_SET_ERR_MSG_ATTR(extack, tb[ETHTOOL_A_MODULE_EEPROM_LENGTH], + "reading cross page boundary is illegal"); + return -EINVAL; + } + + if (tb[ETHTOOL_A_MODULE_EEPROM_BANK]) + request->bank = nla_get_u8(tb[ETHTOOL_A_MODULE_EEPROM_BANK]); + + return 0; +} + +static int eeprom_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct eeprom_req_info *request = MODULE_EEPROM_REQINFO(req_base); + + return nla_total_size(sizeof(u8) * request->length); /* _EEPROM_DATA */ +} + +static int eeprom_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + struct eeprom_reply_data *reply = MODULE_EEPROM_REPDATA(reply_base); + + return nla_put(skb, ETHTOOL_A_MODULE_EEPROM_DATA, reply->length, reply->data); +} + +static void eeprom_cleanup_data(struct ethnl_reply_data *reply_base) +{ + struct eeprom_reply_data *reply = MODULE_EEPROM_REPDATA(reply_base); + + kfree(reply->data); +} + +const struct ethnl_request_ops ethnl_module_eeprom_request_ops = { + .request_cmd = ETHTOOL_MSG_MODULE_EEPROM_GET, + .reply_cmd = ETHTOOL_MSG_MODULE_EEPROM_GET_REPLY, + .hdr_attr = ETHTOOL_A_MODULE_EEPROM_HEADER, + .req_info_size = sizeof(struct eeprom_req_info), + .reply_data_size = sizeof(struct eeprom_reply_data), + + .parse_request = eeprom_parse_request, + .prepare_data = eeprom_prepare_data, + .reply_size = eeprom_reply_size, + .fill_reply = eeprom_fill_reply, + .cleanup_data = eeprom_cleanup_data, +}; + +const struct nla_policy ethnl_module_eeprom_get_policy[] = { + [ETHTOOL_A_MODULE_EEPROM_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_MODULE_EEPROM_OFFSET] = + NLA_POLICY_MAX(NLA_U32, ETH_MODULE_EEPROM_PAGE_LEN * 2 - 1), + [ETHTOOL_A_MODULE_EEPROM_LENGTH] = + NLA_POLICY_RANGE(NLA_U32, 1, ETH_MODULE_EEPROM_PAGE_LEN), + [ETHTOOL_A_MODULE_EEPROM_PAGE] = { .type = NLA_U8 }, + [ETHTOOL_A_MODULE_EEPROM_BANK] = { .type = NLA_U8 }, + [ETHTOOL_A_MODULE_EEPROM_I2C_ADDRESS] = + NLA_POLICY_RANGE(NLA_U8, 0, ETH_MODULE_MAX_I2C_ADDRESS), +}; + diff --git a/net/ethtool/features.c b/net/ethtool/features.c new file mode 100644 index 000000000..090e493f5 --- /dev/null +++ b/net/ethtool/features.c @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct features_req_info { + struct ethnl_req_info base; +}; + +struct features_reply_data { + struct ethnl_reply_data base; + u32 hw[ETHTOOL_DEV_FEATURE_WORDS]; + u32 wanted[ETHTOOL_DEV_FEATURE_WORDS]; + u32 active[ETHTOOL_DEV_FEATURE_WORDS]; + u32 nochange[ETHTOOL_DEV_FEATURE_WORDS]; + u32 all[ETHTOOL_DEV_FEATURE_WORDS]; +}; + +#define FEATURES_REPDATA(__reply_base) \ + container_of(__reply_base, struct features_reply_data, base) + +const struct nla_policy ethnl_features_get_policy[] = { + [ETHTOOL_A_FEATURES_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), 
+}; + +static void ethnl_features_to_bitmap32(u32 *dest, netdev_features_t src) +{ + unsigned int i; + + for (i = 0; i < ETHTOOL_DEV_FEATURE_WORDS; i++) + dest[i] = src >> (32 * i); +} + +static int features_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct features_reply_data *data = FEATURES_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + netdev_features_t all_features; + + ethnl_features_to_bitmap32(data->hw, dev->hw_features); + ethnl_features_to_bitmap32(data->wanted, dev->wanted_features); + ethnl_features_to_bitmap32(data->active, dev->features); + ethnl_features_to_bitmap32(data->nochange, NETIF_F_NEVER_CHANGE); + all_features = GENMASK_ULL(NETDEV_FEATURE_COUNT - 1, 0); + ethnl_features_to_bitmap32(data->all, all_features); + + return 0; +} + +static int features_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct features_reply_data *data = FEATURES_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + unsigned int len = 0; + int ret; + + ret = ethnl_bitset32_size(data->hw, data->all, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + len += ret; + ret = ethnl_bitset32_size(data->wanted, NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + len += ret; + ret = ethnl_bitset32_size(data->active, NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + len += ret; + ret = ethnl_bitset32_size(data->nochange, NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + len += ret; + + return len; +} + +static int features_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct features_reply_data *data = FEATURES_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + int ret; + + ret = ethnl_put_bitset32(skb, ETHTOOL_A_FEATURES_HW, data->hw, + data->all, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + ret = ethnl_put_bitset32(skb, ETHTOOL_A_FEATURES_WANTED, data->wanted, + NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + ret = ethnl_put_bitset32(skb, ETHTOOL_A_FEATURES_ACTIVE, data->active, + NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + return ret; + return ethnl_put_bitset32(skb, ETHTOOL_A_FEATURES_NOCHANGE, + data->nochange, NULL, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); +} + +const struct ethnl_request_ops ethnl_features_request_ops = { + .request_cmd = ETHTOOL_MSG_FEATURES_GET, + .reply_cmd = ETHTOOL_MSG_FEATURES_GET_REPLY, + .hdr_attr = ETHTOOL_A_FEATURES_HEADER, + .req_info_size = sizeof(struct features_req_info), + .reply_data_size = sizeof(struct features_reply_data), + + .prepare_data = features_prepare_data, + .reply_size = features_reply_size, + .fill_reply = features_fill_reply, +}; + +/* FEATURES_SET */ + +const struct nla_policy ethnl_features_set_policy[] = { + [ETHTOOL_A_FEATURES_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_FEATURES_WANTED] = { .type = NLA_NESTED }, +}; + +static void ethnl_features_to_bitmap(unsigned long *dest, netdev_features_t val) +{ + const unsigned int words = BITS_TO_LONGS(NETDEV_FEATURE_COUNT); + unsigned int i; + + for (i = 0; i < 
words; i++) + dest[i] = (unsigned long)(val >> (i * BITS_PER_LONG)); +} + +static netdev_features_t ethnl_bitmap_to_features(unsigned long *src) +{ + const unsigned int nft_bits = sizeof(netdev_features_t) * BITS_PER_BYTE; + const unsigned int words = BITS_TO_LONGS(NETDEV_FEATURE_COUNT); + netdev_features_t ret = 0; + unsigned int i; + + for (i = 0; i < words; i++) + ret |= (netdev_features_t)(src[i]) << (i * BITS_PER_LONG); + ret &= ~(netdev_features_t)0 >> (nft_bits - NETDEV_FEATURE_COUNT); + return ret; +} + +static int features_send_reply(struct net_device *dev, struct genl_info *info, + const unsigned long *wanted, + const unsigned long *wanted_mask, + const unsigned long *active, + const unsigned long *active_mask, bool compact) +{ + struct sk_buff *rskb; + void *reply_payload; + int reply_len = 0; + int ret; + + reply_len = ethnl_reply_header_size(); + ret = ethnl_bitset_size(wanted, wanted_mask, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + goto err; + reply_len += ret; + ret = ethnl_bitset_size(active, active_mask, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + goto err; + reply_len += ret; + + ret = -ENOMEM; + rskb = ethnl_reply_init(reply_len, dev, ETHTOOL_MSG_FEATURES_SET_REPLY, + ETHTOOL_A_FEATURES_HEADER, info, + &reply_payload); + if (!rskb) + goto err; + + ret = ethnl_put_bitset(rskb, ETHTOOL_A_FEATURES_WANTED, wanted, + wanted_mask, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + goto nla_put_failure; + ret = ethnl_put_bitset(rskb, ETHTOOL_A_FEATURES_ACTIVE, active, + active_mask, NETDEV_FEATURE_COUNT, + netdev_features_strings, compact); + if (ret < 0) + goto nla_put_failure; + + genlmsg_end(rskb, reply_payload); + ret = genlmsg_reply(rskb, info); + return ret; + +nla_put_failure: + nlmsg_free(rskb); + WARN_ONCE(1, "calculated message payload length (%d) not sufficient\n", + reply_len); +err: + GENL_SET_ERR_MSG(info, "failed to send reply message"); + return ret; +} + +int ethnl_set_features(struct sk_buff *skb, struct genl_info *info) +{ + DECLARE_BITMAP(wanted_diff_mask, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(active_diff_mask, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(old_active, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(old_wanted, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(new_active, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(new_wanted, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(req_wanted, NETDEV_FEATURE_COUNT); + DECLARE_BITMAP(req_mask, NETDEV_FEATURE_COUNT); + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod; + int ret; + + if (!tb[ETHTOOL_A_FEATURES_WANTED]) + return -EINVAL; + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_FEATURES_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ethnl_features_to_bitmap(old_active, dev->features); + ethnl_features_to_bitmap(old_wanted, dev->wanted_features); + ret = ethnl_parse_bitset(req_wanted, req_mask, NETDEV_FEATURE_COUNT, + tb[ETHTOOL_A_FEATURES_WANTED], + netdev_features_strings, info->extack); + if (ret < 0) + goto out_ops; + if (ethnl_bitmap_to_features(req_mask) & ~NETIF_F_ETHTOOL_BITS) { + GENL_SET_ERR_MSG(info, "attempt to change non-ethtool features"); + ret = -EINVAL; + goto out_ops; + } + + /* set req_wanted bits not in req_mask from old_wanted */ + bitmap_and(req_wanted, req_wanted, req_mask, NETDEV_FEATURE_COUNT); + 
bitmap_andnot(new_wanted, old_wanted, req_mask, NETDEV_FEATURE_COUNT); + bitmap_or(req_wanted, new_wanted, req_wanted, NETDEV_FEATURE_COUNT); + if (!bitmap_equal(req_wanted, old_wanted, NETDEV_FEATURE_COUNT)) { + dev->wanted_features &= ~dev->hw_features; + dev->wanted_features |= ethnl_bitmap_to_features(req_wanted) & dev->hw_features; + __netdev_update_features(dev); + } + ethnl_features_to_bitmap(new_active, dev->features); + mod = !bitmap_equal(old_active, new_active, NETDEV_FEATURE_COUNT); + + ret = 0; + if (!(req_info.flags & ETHTOOL_FLAG_OMIT_REPLY)) { + bool compact = req_info.flags & ETHTOOL_FLAG_COMPACT_BITSETS; + + bitmap_xor(wanted_diff_mask, req_wanted, new_active, + NETDEV_FEATURE_COUNT); + bitmap_xor(active_diff_mask, old_active, new_active, + NETDEV_FEATURE_COUNT); + bitmap_and(wanted_diff_mask, wanted_diff_mask, req_mask, + NETDEV_FEATURE_COUNT); + bitmap_and(req_wanted, req_wanted, wanted_diff_mask, + NETDEV_FEATURE_COUNT); + bitmap_and(new_active, new_active, active_diff_mask, + NETDEV_FEATURE_COUNT); + + ret = features_send_reply(dev, info, req_wanted, + wanted_diff_mask, new_active, + active_diff_mask, compact); + } + if (mod) + netdev_features_change(dev); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/fec.c b/net/ethtool/fec.c new file mode 100644 index 000000000..9f5a134e2 --- /dev/null +++ b/net/ethtool/fec.c @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct fec_req_info { + struct ethnl_req_info base; +}; + +struct fec_reply_data { + struct ethnl_reply_data base; + __ETHTOOL_DECLARE_LINK_MODE_MASK(fec_link_modes); + u32 active_fec; + u8 fec_auto; + struct fec_stat_grp { + u64 stats[1 + ETHTOOL_MAX_LANES]; + u8 cnt; + } corr, uncorr, corr_bits; +}; + +#define FEC_REPDATA(__reply_base) \ + container_of(__reply_base, struct fec_reply_data, base) + +#define ETHTOOL_FEC_MASK ((ETHTOOL_FEC_LLRS << 1) - 1) + +const struct nla_policy ethnl_fec_get_policy[ETHTOOL_A_FEC_HEADER + 1] = { + [ETHTOOL_A_FEC_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy_stats), +}; + +static void +ethtool_fec_to_link_modes(u32 fec, unsigned long *link_modes, u8 *fec_auto) +{ + if (fec_auto) + *fec_auto = !!(fec & ETHTOOL_FEC_AUTO); + + if (fec & ETHTOOL_FEC_OFF) + __set_bit(ETHTOOL_LINK_MODE_FEC_NONE_BIT, link_modes); + if (fec & ETHTOOL_FEC_RS) + __set_bit(ETHTOOL_LINK_MODE_FEC_RS_BIT, link_modes); + if (fec & ETHTOOL_FEC_BASER) + __set_bit(ETHTOOL_LINK_MODE_FEC_BASER_BIT, link_modes); + if (fec & ETHTOOL_FEC_LLRS) + __set_bit(ETHTOOL_LINK_MODE_FEC_LLRS_BIT, link_modes); +} + +static int +ethtool_link_modes_to_fecparam(struct ethtool_fecparam *fec, + unsigned long *link_modes, u8 fec_auto) +{ + memset(fec, 0, sizeof(*fec)); + + if (fec_auto) + fec->fec |= ETHTOOL_FEC_AUTO; + + if (__test_and_clear_bit(ETHTOOL_LINK_MODE_FEC_NONE_BIT, link_modes)) + fec->fec |= ETHTOOL_FEC_OFF; + if (__test_and_clear_bit(ETHTOOL_LINK_MODE_FEC_RS_BIT, link_modes)) + fec->fec |= ETHTOOL_FEC_RS; + if (__test_and_clear_bit(ETHTOOL_LINK_MODE_FEC_BASER_BIT, link_modes)) + fec->fec |= ETHTOOL_FEC_BASER; + if (__test_and_clear_bit(ETHTOOL_LINK_MODE_FEC_LLRS_BIT, link_modes)) + fec->fec |= ETHTOOL_FEC_LLRS; + + if (!bitmap_empty(link_modes, __ETHTOOL_LINK_MODE_MASK_NBITS)) + return -EINVAL; + + return 0; +} + +static void +fec_stats_recalc(struct fec_stat_grp *grp, struct ethtool_fec_stat *stats) +{ + int i; + + if (stats->lanes[0] == 
ETHTOOL_STAT_NOT_SET) { + grp->stats[0] = stats->total; + grp->cnt = stats->total != ETHTOOL_STAT_NOT_SET; + return; + } + + grp->cnt = 1; + grp->stats[0] = 0; + for (i = 0; i < ETHTOOL_MAX_LANES; i++) { + if (stats->lanes[i] == ETHTOOL_STAT_NOT_SET) + break; + + grp->stats[0] += stats->lanes[i]; + grp->stats[grp->cnt++] = stats->lanes[i]; + } +} + +static int fec_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + __ETHTOOL_DECLARE_LINK_MODE_MASK(active_fec_modes) = {}; + struct fec_reply_data *data = FEC_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + struct ethtool_fecparam fec = {}; + int ret; + + if (!dev->ethtool_ops->get_fecparam) + return -EOPNOTSUPP; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + ret = dev->ethtool_ops->get_fecparam(dev, &fec); + if (ret) + goto out_complete; + if (req_base->flags & ETHTOOL_FLAG_STATS && + dev->ethtool_ops->get_fec_stats) { + struct ethtool_fec_stats stats; + + ethtool_stats_init((u64 *)&stats, sizeof(stats) / 8); + dev->ethtool_ops->get_fec_stats(dev, &stats); + + fec_stats_recalc(&data->corr, &stats.corrected_blocks); + fec_stats_recalc(&data->uncorr, &stats.uncorrectable_blocks); + fec_stats_recalc(&data->corr_bits, &stats.corrected_bits); + } + + WARN_ON_ONCE(fec.reserved); + + ethtool_fec_to_link_modes(fec.fec, data->fec_link_modes, + &data->fec_auto); + + ethtool_fec_to_link_modes(fec.active_fec, active_fec_modes, NULL); + data->active_fec = find_first_bit(active_fec_modes, + __ETHTOOL_LINK_MODE_MASK_NBITS); + /* Don't report attr if no FEC mode set. Note that + * ethtool_fecparam_to_link_modes() ignores NONE and AUTO. + */ + if (data->active_fec == __ETHTOOL_LINK_MODE_MASK_NBITS) + data->active_fec = 0; + +out_complete: + ethnl_ops_complete(dev); + return ret; +} + +static int fec_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct fec_reply_data *data = FEC_REPDATA(reply_base); + int len = 0; + int ret; + + ret = ethnl_bitset_size(data->fec_link_modes, NULL, + __ETHTOOL_LINK_MODE_MASK_NBITS, + link_mode_names, compact); + if (ret < 0) + return ret; + len += ret; + + len += nla_total_size(sizeof(u8)) + /* _FEC_AUTO */ + nla_total_size(sizeof(u32)); /* _FEC_ACTIVE */ + + if (req_base->flags & ETHTOOL_FLAG_STATS) + len += 3 * nla_total_size_64bit(sizeof(u64) * + (1 + ETHTOOL_MAX_LANES)); + + return len; +} + +static int fec_put_stats(struct sk_buff *skb, const struct fec_reply_data *data) +{ + struct nlattr *nest; + + nest = nla_nest_start(skb, ETHTOOL_A_FEC_STATS); + if (!nest) + return -EMSGSIZE; + + if (nla_put_64bit(skb, ETHTOOL_A_FEC_STAT_CORRECTED, + sizeof(u64) * data->corr.cnt, + data->corr.stats, ETHTOOL_A_FEC_STAT_PAD) || + nla_put_64bit(skb, ETHTOOL_A_FEC_STAT_UNCORR, + sizeof(u64) * data->uncorr.cnt, + data->uncorr.stats, ETHTOOL_A_FEC_STAT_PAD) || + nla_put_64bit(skb, ETHTOOL_A_FEC_STAT_CORR_BITS, + sizeof(u64) * data->corr_bits.cnt, + data->corr_bits.stats, ETHTOOL_A_FEC_STAT_PAD)) + goto err_cancel; + + nla_nest_end(skb, nest); + return 0; + +err_cancel: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fec_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct fec_reply_data *data = FEC_REPDATA(reply_base); + int ret; + + ret = 
ethnl_put_bitset(skb, ETHTOOL_A_FEC_MODES, + data->fec_link_modes, NULL, + __ETHTOOL_LINK_MODE_MASK_NBITS, + link_mode_names, compact); + if (ret < 0) + return ret; + + if (nla_put_u8(skb, ETHTOOL_A_FEC_AUTO, data->fec_auto) || + (data->active_fec && + nla_put_u32(skb, ETHTOOL_A_FEC_ACTIVE, data->active_fec))) + return -EMSGSIZE; + + if (req_base->flags & ETHTOOL_FLAG_STATS && fec_put_stats(skb, data)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_fec_request_ops = { + .request_cmd = ETHTOOL_MSG_FEC_GET, + .reply_cmd = ETHTOOL_MSG_FEC_GET_REPLY, + .hdr_attr = ETHTOOL_A_FEC_HEADER, + .req_info_size = sizeof(struct fec_req_info), + .reply_data_size = sizeof(struct fec_reply_data), + + .prepare_data = fec_prepare_data, + .reply_size = fec_reply_size, + .fill_reply = fec_fill_reply, +}; + +/* FEC_SET */ + +const struct nla_policy ethnl_fec_set_policy[ETHTOOL_A_FEC_AUTO + 1] = { + [ETHTOOL_A_FEC_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_FEC_MODES] = { .type = NLA_NESTED }, + [ETHTOOL_A_FEC_AUTO] = NLA_POLICY_MAX(NLA_U8, 1), +}; + +int ethnl_set_fec(struct sk_buff *skb, struct genl_info *info) +{ + __ETHTOOL_DECLARE_LINK_MODE_MASK(fec_link_modes) = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct ethtool_fecparam fec = {}; + const struct ethtool_ops *ops; + struct net_device *dev; + bool mod = false; + u8 fec_auto; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, tb[ETHTOOL_A_FEC_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_fecparam || !ops->set_fecparam) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ret = ops->get_fecparam(dev, &fec); + if (ret < 0) + goto out_ops; + + ethtool_fec_to_link_modes(fec.fec, fec_link_modes, &fec_auto); + + ret = ethnl_update_bitset(fec_link_modes, + __ETHTOOL_LINK_MODE_MASK_NBITS, + tb[ETHTOOL_A_FEC_MODES], + link_mode_names, info->extack, &mod); + if (ret < 0) + goto out_ops; + ethnl_update_u8(&fec_auto, tb[ETHTOOL_A_FEC_AUTO], &mod); + + ret = 0; + if (!mod) + goto out_ops; + + ret = ethtool_link_modes_to_fecparam(&fec, fec_link_modes, fec_auto); + if (ret) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[ETHTOOL_A_FEC_MODES], + "invalid FEC modes requested"); + goto out_ops; + } + if (!fec.fec) { + ret = -EINVAL; + NL_SET_ERR_MSG_ATTR(info->extack, tb[ETHTOOL_A_FEC_MODES], + "no FEC modes set"); + goto out_ops; + } + + ret = dev->ethtool_ops->set_fecparam(dev, &fec); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_FEC_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/ioctl.c b/net/ethtool/ioctl.c new file mode 100644 index 000000000..e31d1247b --- /dev/null +++ b/net/ethtool/ioctl.c @@ -0,0 +1,3341 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/core/ethtool.c - Ethtool ioctl handler + * Copyright (c) 2003 Matthew Wilcox + * + * This file is where we call all the ethtool_ops commands to get + * the information ethtool needs. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "common.h" + +/* State held across locks and calls for commands which have devlink fallback */ +struct ethtool_devlink_compat { + struct devlink *devlink; + union { + struct ethtool_flash efl; + struct ethtool_drvinfo info; + }; +}; + +static struct devlink *netdev_to_devlink_get(struct net_device *dev) +{ + struct devlink_port *devlink_port; + + if (!dev->netdev_ops->ndo_get_devlink_port) + return NULL; + + devlink_port = dev->netdev_ops->ndo_get_devlink_port(dev); + if (!devlink_port) + return NULL; + + return devlink_try_get(devlink_port->devlink); +} + +/* + * Some useful ethtool_ops methods that're device independent. + * If we find that all drivers want to do the same thing here, + * we can turn these into dev_() function calls. + */ + +u32 ethtool_op_get_link(struct net_device *dev) +{ + return netif_carrier_ok(dev) ? 1 : 0; +} +EXPORT_SYMBOL(ethtool_op_get_link); + +int ethtool_op_get_ts_info(struct net_device *dev, struct ethtool_ts_info *info) +{ + info->so_timestamping = + SOF_TIMESTAMPING_TX_SOFTWARE | + SOF_TIMESTAMPING_RX_SOFTWARE | + SOF_TIMESTAMPING_SOFTWARE; + info->phc_index = -1; + return 0; +} +EXPORT_SYMBOL(ethtool_op_get_ts_info); + +/* Handlers for each ethtool command */ + +static int ethtool_get_features(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_gfeatures cmd = { + .cmd = ETHTOOL_GFEATURES, + .size = ETHTOOL_DEV_FEATURE_WORDS, + }; + struct ethtool_get_features_block features[ETHTOOL_DEV_FEATURE_WORDS]; + u32 __user *sizeaddr; + u32 copy_size; + int i; + + /* in case feature bits run out again */ + BUILD_BUG_ON(ETHTOOL_DEV_FEATURE_WORDS * sizeof(u32) > sizeof(netdev_features_t)); + + for (i = 0; i < ETHTOOL_DEV_FEATURE_WORDS; ++i) { + features[i].available = (u32)(dev->hw_features >> (32 * i)); + features[i].requested = (u32)(dev->wanted_features >> (32 * i)); + features[i].active = (u32)(dev->features >> (32 * i)); + features[i].never_changed = + (u32)(NETIF_F_NEVER_CHANGE >> (32 * i)); + } + + sizeaddr = useraddr + offsetof(struct ethtool_gfeatures, size); + if (get_user(copy_size, sizeaddr)) + return -EFAULT; + + if (copy_size > ETHTOOL_DEV_FEATURE_WORDS) + copy_size = ETHTOOL_DEV_FEATURE_WORDS; + + if (copy_to_user(useraddr, &cmd, sizeof(cmd))) + return -EFAULT; + useraddr += sizeof(cmd); + if (copy_to_user(useraddr, features, + array_size(copy_size, sizeof(*features)))) + return -EFAULT; + + return 0; +} + +static int ethtool_set_features(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_sfeatures cmd; + struct ethtool_set_features_block features[ETHTOOL_DEV_FEATURE_WORDS]; + netdev_features_t wanted = 0, valid = 0; + int i, ret = 0; + + if (copy_from_user(&cmd, useraddr, sizeof(cmd))) + return -EFAULT; + useraddr += sizeof(cmd); + + if (cmd.size != ETHTOOL_DEV_FEATURE_WORDS) + return -EINVAL; + + if (copy_from_user(features, useraddr, sizeof(features))) + return -EFAULT; + + for (i = 0; i < ETHTOOL_DEV_FEATURE_WORDS; ++i) { + valid |= (netdev_features_t)features[i].valid << (32 * i); + wanted |= (netdev_features_t)features[i].requested << (32 * i); + } + + if (valid & ~NETIF_F_ETHTOOL_BITS) + return -EINVAL; + + if (valid & ~dev->hw_features) { + valid &= dev->hw_features; + ret |= ETHTOOL_F_UNSUPPORTED; + } + + dev->wanted_features &= ~valid; + 
dev->wanted_features |= wanted & valid; + __netdev_update_features(dev); + + if ((dev->wanted_features ^ dev->features) & valid) + ret |= ETHTOOL_F_WISH; + + return ret; +} + +static int __ethtool_get_sset_count(struct net_device *dev, int sset) +{ + const struct ethtool_phy_ops *phy_ops = ethtool_phy_ops; + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (sset == ETH_SS_FEATURES) + return ARRAY_SIZE(netdev_features_strings); + + if (sset == ETH_SS_RSS_HASH_FUNCS) + return ARRAY_SIZE(rss_hash_func_strings); + + if (sset == ETH_SS_TUNABLES) + return ARRAY_SIZE(tunable_strings); + + if (sset == ETH_SS_PHY_TUNABLES) + return ARRAY_SIZE(phy_tunable_strings); + + if (sset == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats && + phy_ops && phy_ops->get_sset_count) + return phy_ops->get_sset_count(dev->phydev); + + if (sset == ETH_SS_LINK_MODES) + return __ETHTOOL_LINK_MODE_MASK_NBITS; + + if (ops->get_sset_count && ops->get_strings) + return ops->get_sset_count(dev, sset); + else + return -EOPNOTSUPP; +} + +static void __ethtool_get_strings(struct net_device *dev, + u32 stringset, u8 *data) +{ + const struct ethtool_phy_ops *phy_ops = ethtool_phy_ops; + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (stringset == ETH_SS_FEATURES) + memcpy(data, netdev_features_strings, + sizeof(netdev_features_strings)); + else if (stringset == ETH_SS_RSS_HASH_FUNCS) + memcpy(data, rss_hash_func_strings, + sizeof(rss_hash_func_strings)); + else if (stringset == ETH_SS_TUNABLES) + memcpy(data, tunable_strings, sizeof(tunable_strings)); + else if (stringset == ETH_SS_PHY_TUNABLES) + memcpy(data, phy_tunable_strings, sizeof(phy_tunable_strings)); + else if (stringset == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats && phy_ops && + phy_ops->get_strings) + phy_ops->get_strings(dev->phydev, data); + else if (stringset == ETH_SS_LINK_MODES) + memcpy(data, link_mode_names, + __ETHTOOL_LINK_MODE_MASK_NBITS * ETH_GSTRING_LEN); + else + /* ops->get_strings is valid because checked earlier */ + ops->get_strings(dev, stringset, data); +} + +static netdev_features_t ethtool_get_feature_mask(u32 eth_cmd) +{ + /* feature masks of legacy discrete ethtool ops */ + + switch (eth_cmd) { + case ETHTOOL_GTXCSUM: + case ETHTOOL_STXCSUM: + return NETIF_F_CSUM_MASK | NETIF_F_FCOE_CRC | + NETIF_F_SCTP_CRC; + case ETHTOOL_GRXCSUM: + case ETHTOOL_SRXCSUM: + return NETIF_F_RXCSUM; + case ETHTOOL_GSG: + case ETHTOOL_SSG: + return NETIF_F_SG | NETIF_F_FRAGLIST; + case ETHTOOL_GTSO: + case ETHTOOL_STSO: + return NETIF_F_ALL_TSO; + case ETHTOOL_GGSO: + case ETHTOOL_SGSO: + return NETIF_F_GSO; + case ETHTOOL_GGRO: + case ETHTOOL_SGRO: + return NETIF_F_GRO; + default: + BUG(); + } +} + +static int ethtool_get_one_feature(struct net_device *dev, + char __user *useraddr, u32 ethcmd) +{ + netdev_features_t mask = ethtool_get_feature_mask(ethcmd); + struct ethtool_value edata = { + .cmd = ethcmd, + .data = !!(dev->features & mask), + }; + + if (copy_to_user(useraddr, &edata, sizeof(edata))) + return -EFAULT; + return 0; +} + +static int ethtool_set_one_feature(struct net_device *dev, + void __user *useraddr, u32 ethcmd) +{ + struct ethtool_value edata; + netdev_features_t mask; + + if (copy_from_user(&edata, useraddr, sizeof(edata))) + return -EFAULT; + + mask = ethtool_get_feature_mask(ethcmd); + mask &= dev->hw_features; + if (!mask) + return -EOPNOTSUPP; + + if (edata.data) + dev->wanted_features |= mask; + else + dev->wanted_features &= ~mask; + + __netdev_update_features(dev); + + return 0; 
+} + +#define ETH_ALL_FLAGS (ETH_FLAG_LRO | ETH_FLAG_RXVLAN | ETH_FLAG_TXVLAN | \ + ETH_FLAG_NTUPLE | ETH_FLAG_RXHASH) +#define ETH_ALL_FEATURES (NETIF_F_LRO | NETIF_F_HW_VLAN_CTAG_RX | \ + NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_NTUPLE | \ + NETIF_F_RXHASH) + +static u32 __ethtool_get_flags(struct net_device *dev) +{ + u32 flags = 0; + + if (dev->features & NETIF_F_LRO) + flags |= ETH_FLAG_LRO; + if (dev->features & NETIF_F_HW_VLAN_CTAG_RX) + flags |= ETH_FLAG_RXVLAN; + if (dev->features & NETIF_F_HW_VLAN_CTAG_TX) + flags |= ETH_FLAG_TXVLAN; + if (dev->features & NETIF_F_NTUPLE) + flags |= ETH_FLAG_NTUPLE; + if (dev->features & NETIF_F_RXHASH) + flags |= ETH_FLAG_RXHASH; + + return flags; +} + +static int __ethtool_set_flags(struct net_device *dev, u32 data) +{ + netdev_features_t features = 0, changed; + + if (data & ~ETH_ALL_FLAGS) + return -EINVAL; + + if (data & ETH_FLAG_LRO) + features |= NETIF_F_LRO; + if (data & ETH_FLAG_RXVLAN) + features |= NETIF_F_HW_VLAN_CTAG_RX; + if (data & ETH_FLAG_TXVLAN) + features |= NETIF_F_HW_VLAN_CTAG_TX; + if (data & ETH_FLAG_NTUPLE) + features |= NETIF_F_NTUPLE; + if (data & ETH_FLAG_RXHASH) + features |= NETIF_F_RXHASH; + + /* allow changing only bits set in hw_features */ + changed = (features ^ dev->features) & ETH_ALL_FEATURES; + if (changed & ~dev->hw_features) + return (changed & dev->hw_features) ? -EINVAL : -EOPNOTSUPP; + + dev->wanted_features = + (dev->wanted_features & ~changed) | (features & changed); + + __netdev_update_features(dev); + + return 0; +} + +/* Given two link masks, AND them together and save the result in dst. */ +void ethtool_intersect_link_masks(struct ethtool_link_ksettings *dst, + struct ethtool_link_ksettings *src) +{ + unsigned int size = BITS_TO_LONGS(__ETHTOOL_LINK_MODE_MASK_NBITS); + unsigned int idx = 0; + + for (; idx < size; idx++) { + dst->link_modes.supported[idx] &= + src->link_modes.supported[idx]; + dst->link_modes.advertising[idx] &= + src->link_modes.advertising[idx]; + } +} +EXPORT_SYMBOL(ethtool_intersect_link_masks); + +void ethtool_convert_legacy_u32_to_link_mode(unsigned long *dst, + u32 legacy_u32) +{ + linkmode_zero(dst); + dst[0] = legacy_u32; +} +EXPORT_SYMBOL(ethtool_convert_legacy_u32_to_link_mode); + +/* return false if src had higher bits set. lower bits always updated. */ +bool ethtool_convert_link_mode_to_legacy_u32(u32 *legacy_u32, + const unsigned long *src) +{ + *legacy_u32 = src[0]; + return find_next_bit(src, __ETHTOOL_LINK_MODE_MASK_NBITS, 32) == + __ETHTOOL_LINK_MODE_MASK_NBITS; +} +EXPORT_SYMBOL(ethtool_convert_link_mode_to_legacy_u32); + +/* return false if ksettings link modes had higher bits + * set. 
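+ * (a mode such as 2500baseT/Full lives above bit 31, so it cannot be
+ * expressed in the legacy 32-bit masks and makes this return false)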
legacy_settings always updated (best effort) + */ +static bool +convert_link_ksettings_to_legacy_settings( + struct ethtool_cmd *legacy_settings, + const struct ethtool_link_ksettings *link_ksettings) +{ + bool retval = true; + + memset(legacy_settings, 0, sizeof(*legacy_settings)); + /* this also clears the deprecated fields in legacy structure: + * __u8 transceiver; + * __u32 maxtxpkt; + * __u32 maxrxpkt; + */ + + retval &= ethtool_convert_link_mode_to_legacy_u32( + &legacy_settings->supported, + link_ksettings->link_modes.supported); + retval &= ethtool_convert_link_mode_to_legacy_u32( + &legacy_settings->advertising, + link_ksettings->link_modes.advertising); + retval &= ethtool_convert_link_mode_to_legacy_u32( + &legacy_settings->lp_advertising, + link_ksettings->link_modes.lp_advertising); + ethtool_cmd_speed_set(legacy_settings, link_ksettings->base.speed); + legacy_settings->duplex + = link_ksettings->base.duplex; + legacy_settings->port + = link_ksettings->base.port; + legacy_settings->phy_address + = link_ksettings->base.phy_address; + legacy_settings->autoneg + = link_ksettings->base.autoneg; + legacy_settings->mdio_support + = link_ksettings->base.mdio_support; + legacy_settings->eth_tp_mdix + = link_ksettings->base.eth_tp_mdix; + legacy_settings->eth_tp_mdix_ctrl + = link_ksettings->base.eth_tp_mdix_ctrl; + legacy_settings->transceiver + = link_ksettings->base.transceiver; + return retval; +} + +/* number of 32-bit words to store the user's link mode bitmaps */ +#define __ETHTOOL_LINK_MODE_MASK_NU32 \ + DIV_ROUND_UP(__ETHTOOL_LINK_MODE_MASK_NBITS, 32) + +/* layout of the struct passed from/to userland */ +struct ethtool_link_usettings { + struct ethtool_link_settings base; + struct { + __u32 supported[__ETHTOOL_LINK_MODE_MASK_NU32]; + __u32 advertising[__ETHTOOL_LINK_MODE_MASK_NU32]; + __u32 lp_advertising[__ETHTOOL_LINK_MODE_MASK_NU32]; + } link_modes; +}; + +/* Internal kernel helper to query a device ethtool_link_settings. */ +int __ethtool_get_link_ksettings(struct net_device *dev, + struct ethtool_link_ksettings *link_ksettings) +{ + ASSERT_RTNL(); + + if (!dev->ethtool_ops->get_link_ksettings) + return -EOPNOTSUPP; + + memset(link_ksettings, 0, sizeof(*link_ksettings)); + return dev->ethtool_ops->get_link_ksettings(dev, link_ksettings); +} +EXPORT_SYMBOL(__ethtool_get_link_ksettings); + +/* convert ethtool_link_usettings in user space to a kernel internal + * ethtool_link_ksettings. return 0 on success, errno on error. 
+ */ +static int load_link_ksettings_from_user(struct ethtool_link_ksettings *to, + const void __user *from) +{ + struct ethtool_link_usettings link_usettings; + + if (copy_from_user(&link_usettings, from, sizeof(link_usettings))) + return -EFAULT; + + memcpy(&to->base, &link_usettings.base, sizeof(to->base)); + bitmap_from_arr32(to->link_modes.supported, + link_usettings.link_modes.supported, + __ETHTOOL_LINK_MODE_MASK_NBITS); + bitmap_from_arr32(to->link_modes.advertising, + link_usettings.link_modes.advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); + bitmap_from_arr32(to->link_modes.lp_advertising, + link_usettings.link_modes.lp_advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); + + return 0; +} + +/* Check if the user is trying to change anything besides speed/duplex */ +bool ethtool_virtdev_validate_cmd(const struct ethtool_link_ksettings *cmd) +{ + struct ethtool_link_settings base2 = {}; + + base2.speed = cmd->base.speed; + base2.port = PORT_OTHER; + base2.duplex = cmd->base.duplex; + base2.cmd = cmd->base.cmd; + base2.link_mode_masks_nwords = cmd->base.link_mode_masks_nwords; + + return !memcmp(&base2, &cmd->base, sizeof(base2)) && + bitmap_empty(cmd->link_modes.supported, + __ETHTOOL_LINK_MODE_MASK_NBITS) && + bitmap_empty(cmd->link_modes.lp_advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); +} + +/* convert a kernel internal ethtool_link_ksettings to + * ethtool_link_usettings in user space. return 0 on success, errno on + * error. + */ +static int +store_link_ksettings_for_user(void __user *to, + const struct ethtool_link_ksettings *from) +{ + struct ethtool_link_usettings link_usettings; + + memcpy(&link_usettings, from, sizeof(link_usettings)); + bitmap_to_arr32(link_usettings.link_modes.supported, + from->link_modes.supported, + __ETHTOOL_LINK_MODE_MASK_NBITS); + bitmap_to_arr32(link_usettings.link_modes.advertising, + from->link_modes.advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); + bitmap_to_arr32(link_usettings.link_modes.lp_advertising, + from->link_modes.lp_advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); + + if (copy_to_user(to, &link_usettings, sizeof(link_usettings))) + return -EFAULT; + + return 0; +} + +/* Query device for its ethtool_link_settings. 
*/ +static int ethtool_get_link_ksettings(struct net_device *dev, + void __user *useraddr) +{ + int err = 0; + struct ethtool_link_ksettings link_ksettings; + + ASSERT_RTNL(); + if (!dev->ethtool_ops->get_link_ksettings) + return -EOPNOTSUPP; + + /* handle bitmap nbits handshake */ + if (copy_from_user(&link_ksettings.base, useraddr, + sizeof(link_ksettings.base))) + return -EFAULT; + + if (__ETHTOOL_LINK_MODE_MASK_NU32 + != link_ksettings.base.link_mode_masks_nwords) { + /* wrong link mode nbits requested */ + memset(&link_ksettings, 0, sizeof(link_ksettings)); + link_ksettings.base.cmd = ETHTOOL_GLINKSETTINGS; + /* send back number of words required as negative val */ + compiletime_assert(__ETHTOOL_LINK_MODE_MASK_NU32 <= S8_MAX, + "need too many bits for link modes!"); + link_ksettings.base.link_mode_masks_nwords + = -((s8)__ETHTOOL_LINK_MODE_MASK_NU32); + + /* copy the base fields back to user, not the link + * mode bitmaps + */ + if (copy_to_user(useraddr, &link_ksettings.base, + sizeof(link_ksettings.base))) + return -EFAULT; + + return 0; + } + + /* handshake successful: user/kernel agree on + * link_mode_masks_nwords + */ + + memset(&link_ksettings, 0, sizeof(link_ksettings)); + err = dev->ethtool_ops->get_link_ksettings(dev, &link_ksettings); + if (err < 0) + return err; + + /* make sure we tell the right values to user */ + link_ksettings.base.cmd = ETHTOOL_GLINKSETTINGS; + link_ksettings.base.link_mode_masks_nwords + = __ETHTOOL_LINK_MODE_MASK_NU32; + link_ksettings.base.master_slave_cfg = MASTER_SLAVE_CFG_UNSUPPORTED; + link_ksettings.base.master_slave_state = MASTER_SLAVE_STATE_UNSUPPORTED; + link_ksettings.base.rate_matching = RATE_MATCH_NONE; + + return store_link_ksettings_for_user(useraddr, &link_ksettings); +} + +/* Update device ethtool_link_settings. */ +static int ethtool_set_link_ksettings(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_link_ksettings link_ksettings = {}; + int err; + + ASSERT_RTNL(); + + if (!dev->ethtool_ops->set_link_ksettings) + return -EOPNOTSUPP; + + /* make sure nbits field has expected value */ + if (copy_from_user(&link_ksettings.base, useraddr, + sizeof(link_ksettings.base))) + return -EFAULT; + + if (__ETHTOOL_LINK_MODE_MASK_NU32 + != link_ksettings.base.link_mode_masks_nwords) + return -EINVAL; + + /* copy the whole structure, now that we know it has expected + * format + */ + err = load_link_ksettings_from_user(&link_ksettings, useraddr); + if (err) + return err; + + /* re-check nwords field, just in case */ + if (__ETHTOOL_LINK_MODE_MASK_NU32 + != link_ksettings.base.link_mode_masks_nwords) + return -EINVAL; + + if (link_ksettings.base.master_slave_cfg || + link_ksettings.base.master_slave_state) + return -EINVAL; + + err = dev->ethtool_ops->set_link_ksettings(dev, &link_ksettings); + if (err >= 0) { + ethtool_notify(dev, ETHTOOL_MSG_LINKINFO_NTF, NULL); + ethtool_notify(dev, ETHTOOL_MSG_LINKMODES_NTF, NULL); + } + return err; +} + +int ethtool_virtdev_set_link_ksettings(struct net_device *dev, + const struct ethtool_link_ksettings *cmd, + u32 *dev_speed, u8 *dev_duplex) +{ + u32 speed; + u8 duplex; + + speed = cmd->base.speed; + duplex = cmd->base.duplex; + /* don't allow custom speed and duplex */ + if (!ethtool_validate_speed(speed) || + !ethtool_validate_duplex(duplex) || + !ethtool_virtdev_validate_cmd(cmd)) + return -EINVAL; + *dev_speed = speed; + *dev_duplex = duplex; + + return 0; +} +EXPORT_SYMBOL(ethtool_virtdev_set_link_ksettings); + +/* Query device for its ethtool_cmd settings. 
+ * + * Backward compatibility note: for compatibility with legacy ethtool, this is + * now implemented via get_link_ksettings. When driver reports higher link mode + * bits, a kernel warning is logged once (with name of 1st driver/device) to + * recommend user to upgrade ethtool, but the command is successful (only the + * lower link mode bits reported back to user). Deprecated fields from + * ethtool_cmd (transceiver/maxrxpkt/maxtxpkt) are always set to zero. + */ +static int ethtool_get_settings(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_link_ksettings link_ksettings; + struct ethtool_cmd cmd; + int err; + + ASSERT_RTNL(); + if (!dev->ethtool_ops->get_link_ksettings) + return -EOPNOTSUPP; + + memset(&link_ksettings, 0, sizeof(link_ksettings)); + err = dev->ethtool_ops->get_link_ksettings(dev, &link_ksettings); + if (err < 0) + return err; + convert_link_ksettings_to_legacy_settings(&cmd, &link_ksettings); + + /* send a sensible cmd tag back to user */ + cmd.cmd = ETHTOOL_GSET; + + if (copy_to_user(useraddr, &cmd, sizeof(cmd))) + return -EFAULT; + + return 0; +} + +/* Update device link settings with given ethtool_cmd. + * + * Backward compatibility note: for compatibility with legacy ethtool, this is + * now always implemented via set_link_settings. When user's request updates + * deprecated ethtool_cmd fields (transceiver/maxrxpkt/maxtxpkt), a kernel + * warning is logged once (with name of 1st driver/device) to recommend user to + * upgrade ethtool, and the request is rejected. + */ +static int ethtool_set_settings(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_link_ksettings link_ksettings; + struct ethtool_cmd cmd; + int ret; + + ASSERT_RTNL(); + + if (copy_from_user(&cmd, useraddr, sizeof(cmd))) + return -EFAULT; + if (!dev->ethtool_ops->set_link_ksettings) + return -EOPNOTSUPP; + + if (!convert_legacy_settings_to_link_ksettings(&link_ksettings, &cmd)) + return -EINVAL; + link_ksettings.base.link_mode_masks_nwords = + __ETHTOOL_LINK_MODE_MASK_NU32; + ret = dev->ethtool_ops->set_link_ksettings(dev, &link_ksettings); + if (ret >= 0) { + ethtool_notify(dev, ETHTOOL_MSG_LINKINFO_NTF, NULL); + ethtool_notify(dev, ETHTOOL_MSG_LINKMODES_NTF, NULL); + } + return ret; +} + +static int +ethtool_get_drvinfo(struct net_device *dev, struct ethtool_devlink_compat *rsp) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + + rsp->info.cmd = ETHTOOL_GDRVINFO; + strscpy(rsp->info.version, UTS_RELEASE, sizeof(rsp->info.version)); + if (ops->get_drvinfo) { + ops->get_drvinfo(dev, &rsp->info); + } else if (dev->dev.parent && dev->dev.parent->driver) { + strscpy(rsp->info.bus_info, dev_name(dev->dev.parent), + sizeof(rsp->info.bus_info)); + strscpy(rsp->info.driver, dev->dev.parent->driver->name, + sizeof(rsp->info.driver)); + } else if (dev->rtnl_link_ops) { + strscpy(rsp->info.driver, dev->rtnl_link_ops->kind, + sizeof(rsp->info.driver)); + } else { + return -EOPNOTSUPP; + } + + /* + * this method of obtaining string set info is deprecated; + * Use ETHTOOL_GSSET_INFO instead. 
+ */ + if (ops->get_sset_count) { + int rc; + + rc = ops->get_sset_count(dev, ETH_SS_TEST); + if (rc >= 0) + rsp->info.testinfo_len = rc; + rc = ops->get_sset_count(dev, ETH_SS_STATS); + if (rc >= 0) + rsp->info.n_stats = rc; + rc = ops->get_sset_count(dev, ETH_SS_PRIV_FLAGS); + if (rc >= 0) + rsp->info.n_priv_flags = rc; + } + if (ops->get_regs_len) { + int ret = ops->get_regs_len(dev); + + if (ret > 0) + rsp->info.regdump_len = ret; + } + + if (ops->get_eeprom_len) + rsp->info.eedump_len = ops->get_eeprom_len(dev); + + if (!rsp->info.fw_version[0]) + rsp->devlink = netdev_to_devlink_get(dev); + + return 0; +} + +static noinline_for_stack int ethtool_get_sset_info(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_sset_info info; + u64 sset_mask; + int i, idx = 0, n_bits = 0, ret, rc; + u32 *info_buf = NULL; + + if (copy_from_user(&info, useraddr, sizeof(info))) + return -EFAULT; + + /* store copy of mask, because we zero struct later on */ + sset_mask = info.sset_mask; + if (!sset_mask) + return 0; + + /* calculate size of return buffer */ + n_bits = hweight64(sset_mask); + + memset(&info, 0, sizeof(info)); + info.cmd = ETHTOOL_GSSET_INFO; + + info_buf = kcalloc(n_bits, sizeof(u32), GFP_USER); + if (!info_buf) + return -ENOMEM; + + /* + * fill return buffer based on input bitmask and successful + * get_sset_count return + */ + for (i = 0; i < 64; i++) { + if (!(sset_mask & (1ULL << i))) + continue; + + rc = __ethtool_get_sset_count(dev, i); + if (rc >= 0) { + info.sset_mask |= (1ULL << i); + info_buf[idx++] = rc; + } + } + + ret = -EFAULT; + if (copy_to_user(useraddr, &info, sizeof(info))) + goto out; + + useraddr += offsetof(struct ethtool_sset_info, data); + if (copy_to_user(useraddr, info_buf, array_size(idx, sizeof(u32)))) + goto out; + + ret = 0; + +out: + kfree(info_buf); + return ret; +} + +static noinline_for_stack int +ethtool_rxnfc_copy_from_compat(struct ethtool_rxnfc *rxnfc, + const struct compat_ethtool_rxnfc __user *useraddr, + size_t size) +{ + struct compat_ethtool_rxnfc crxnfc = {}; + + /* We expect there to be holes between fs.m_ext and + * fs.ring_cookie and at the end of fs, but nowhere else. + * On non-x86, no conversion should be needed. 
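+ * The holes exist because 32-bit x86 user space aligns __u64 members
+ * such as fs.ring_cookie to 4 bytes, while the native 64-bit layout
+ * aligns them to 8 bytes, hence this compat copy.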
+ */ + BUILD_BUG_ON(!IS_ENABLED(CONFIG_X86_64) && + sizeof(struct compat_ethtool_rxnfc) != + sizeof(struct ethtool_rxnfc)); + BUILD_BUG_ON(offsetof(struct compat_ethtool_rxnfc, fs.m_ext) + + sizeof(useraddr->fs.m_ext) != + offsetof(struct ethtool_rxnfc, fs.m_ext) + + sizeof(rxnfc->fs.m_ext)); + BUILD_BUG_ON(offsetof(struct compat_ethtool_rxnfc, fs.location) - + offsetof(struct compat_ethtool_rxnfc, fs.ring_cookie) != + offsetof(struct ethtool_rxnfc, fs.location) - + offsetof(struct ethtool_rxnfc, fs.ring_cookie)); + + if (copy_from_user(&crxnfc, useraddr, min(size, sizeof(crxnfc)))) + return -EFAULT; + + *rxnfc = (struct ethtool_rxnfc) { + .cmd = crxnfc.cmd, + .flow_type = crxnfc.flow_type, + .data = crxnfc.data, + .fs = { + .flow_type = crxnfc.fs.flow_type, + .h_u = crxnfc.fs.h_u, + .h_ext = crxnfc.fs.h_ext, + .m_u = crxnfc.fs.m_u, + .m_ext = crxnfc.fs.m_ext, + .ring_cookie = crxnfc.fs.ring_cookie, + .location = crxnfc.fs.location, + }, + .rule_cnt = crxnfc.rule_cnt, + }; + + return 0; +} + +static int ethtool_rxnfc_copy_from_user(struct ethtool_rxnfc *rxnfc, + const void __user *useraddr, + size_t size) +{ + if (compat_need_64bit_alignment_fixup()) + return ethtool_rxnfc_copy_from_compat(rxnfc, useraddr, size); + + if (copy_from_user(rxnfc, useraddr, size)) + return -EFAULT; + + return 0; +} + +static int ethtool_rxnfc_copy_to_compat(void __user *useraddr, + const struct ethtool_rxnfc *rxnfc, + size_t size, const u32 *rule_buf) +{ + struct compat_ethtool_rxnfc crxnfc; + + memset(&crxnfc, 0, sizeof(crxnfc)); + crxnfc = (struct compat_ethtool_rxnfc) { + .cmd = rxnfc->cmd, + .flow_type = rxnfc->flow_type, + .data = rxnfc->data, + .fs = { + .flow_type = rxnfc->fs.flow_type, + .h_u = rxnfc->fs.h_u, + .h_ext = rxnfc->fs.h_ext, + .m_u = rxnfc->fs.m_u, + .m_ext = rxnfc->fs.m_ext, + .ring_cookie = rxnfc->fs.ring_cookie, + .location = rxnfc->fs.location, + }, + .rule_cnt = rxnfc->rule_cnt, + }; + + if (copy_to_user(useraddr, &crxnfc, min(size, sizeof(crxnfc)))) + return -EFAULT; + + return 0; +} + +static int ethtool_rxnfc_copy_to_user(void __user *useraddr, + const struct ethtool_rxnfc *rxnfc, + size_t size, const u32 *rule_buf) +{ + int ret; + + if (compat_need_64bit_alignment_fixup()) { + ret = ethtool_rxnfc_copy_to_compat(useraddr, rxnfc, size, + rule_buf); + useraddr += offsetof(struct compat_ethtool_rxnfc, rule_locs); + } else { + ret = copy_to_user(useraddr, rxnfc, size); + useraddr += offsetof(struct ethtool_rxnfc, rule_locs); + } + + if (ret) + return -EFAULT; + + if (rule_buf) { + if (copy_to_user(useraddr, rule_buf, + rxnfc->rule_cnt * sizeof(u32))) + return -EFAULT; + } + + return 0; +} + +static noinline_for_stack int ethtool_set_rxnfc(struct net_device *dev, + u32 cmd, void __user *useraddr) +{ + struct ethtool_rxnfc info; + size_t info_size = sizeof(info); + int rc; + + if (!dev->ethtool_ops->set_rxnfc) + return -EOPNOTSUPP; + + /* struct ethtool_rxnfc was originally defined for + * ETHTOOL_{G,S}RXFH with only the cmd, flow_type and data + * members. User-space might still be using that + * definition. 
*/ + if (cmd == ETHTOOL_SRXFH) + info_size = (offsetof(struct ethtool_rxnfc, data) + + sizeof(info.data)); + + if (ethtool_rxnfc_copy_from_user(&info, useraddr, info_size)) + return -EFAULT; + + rc = dev->ethtool_ops->set_rxnfc(dev, &info); + if (rc) + return rc; + + if (cmd == ETHTOOL_SRXCLSRLINS && + ethtool_rxnfc_copy_to_user(useraddr, &info, info_size, NULL)) + return -EFAULT; + + return 0; +} + +static noinline_for_stack int ethtool_get_rxnfc(struct net_device *dev, + u32 cmd, void __user *useraddr) +{ + struct ethtool_rxnfc info; + size_t info_size = sizeof(info); + const struct ethtool_ops *ops = dev->ethtool_ops; + int ret; + void *rule_buf = NULL; + + if (!ops->get_rxnfc) + return -EOPNOTSUPP; + + /* struct ethtool_rxnfc was originally defined for + * ETHTOOL_{G,S}RXFH with only the cmd, flow_type and data + * members. User-space might still be using that + * definition. */ + if (cmd == ETHTOOL_GRXFH) + info_size = (offsetof(struct ethtool_rxnfc, data) + + sizeof(info.data)); + + if (ethtool_rxnfc_copy_from_user(&info, useraddr, info_size)) + return -EFAULT; + + /* If FLOW_RSS was requested then user-space must be using the + * new definition, as FLOW_RSS is newer. + */ + if (cmd == ETHTOOL_GRXFH && info.flow_type & FLOW_RSS) { + info_size = sizeof(info); + if (ethtool_rxnfc_copy_from_user(&info, useraddr, info_size)) + return -EFAULT; + /* Since malicious users may modify the original data, + * we need to check whether FLOW_RSS is still requested. + */ + if (!(info.flow_type & FLOW_RSS)) + return -EINVAL; + } + + if (info.cmd != cmd) + return -EINVAL; + + if (info.cmd == ETHTOOL_GRXCLSRLALL) { + if (info.rule_cnt > 0) { + if (info.rule_cnt <= KMALLOC_MAX_SIZE / sizeof(u32)) + rule_buf = kcalloc(info.rule_cnt, sizeof(u32), + GFP_USER); + if (!rule_buf) + return -ENOMEM; + } + } + + ret = ops->get_rxnfc(dev, &info, rule_buf); + if (ret < 0) + goto err_out; + + ret = ethtool_rxnfc_copy_to_user(useraddr, &info, info_size, rule_buf); +err_out: + kfree(rule_buf); + + return ret; +} + +static int ethtool_copy_validate_indir(u32 *indir, void __user *useraddr, + struct ethtool_rxnfc *rx_rings, + u32 size) +{ + int i; + + if (copy_from_user(indir, useraddr, array_size(size, sizeof(indir[0])))) + return -EFAULT; + + /* Validate ring indices */ + for (i = 0; i < size; i++) + if (indir[i] >= rx_rings->data) + return -EINVAL; + + return 0; +} + +u8 netdev_rss_key[NETDEV_RSS_KEY_LEN] __read_mostly; + +void netdev_rss_key_fill(void *buffer, size_t len) +{ + BUG_ON(len > sizeof(netdev_rss_key)); + net_get_random_once(netdev_rss_key, sizeof(netdev_rss_key)); + memcpy(buffer, netdev_rss_key, len); +} +EXPORT_SYMBOL(netdev_rss_key_fill); + +static noinline_for_stack int ethtool_get_rxfh_indir(struct net_device *dev, + void __user *useraddr) +{ + u32 user_size, dev_size; + u32 *indir; + int ret; + + if (!dev->ethtool_ops->get_rxfh_indir_size || + !dev->ethtool_ops->get_rxfh) + return -EOPNOTSUPP; + dev_size = dev->ethtool_ops->get_rxfh_indir_size(dev); + if (dev_size == 0) + return -EOPNOTSUPP; + + if (copy_from_user(&user_size, + useraddr + offsetof(struct ethtool_rxfh_indir, size), + sizeof(user_size))) + return -EFAULT; + + if (copy_to_user(useraddr + offsetof(struct ethtool_rxfh_indir, size), + &dev_size, sizeof(dev_size))) + return -EFAULT; + + /* If the user buffer size is 0, this is just a query for the + * device table size. Otherwise, if it's smaller than the + * device table size it's an error. + */ + if (user_size < dev_size) + return user_size == 0 ? 
0 : -EINVAL; + + indir = kcalloc(dev_size, sizeof(indir[0]), GFP_USER); + if (!indir) + return -ENOMEM; + + ret = dev->ethtool_ops->get_rxfh(dev, indir, NULL, NULL); + if (ret) + goto out; + + if (copy_to_user(useraddr + + offsetof(struct ethtool_rxfh_indir, ring_index[0]), + indir, dev_size * sizeof(indir[0]))) + ret = -EFAULT; + +out: + kfree(indir); + return ret; +} + +static noinline_for_stack int ethtool_set_rxfh_indir(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_rxnfc rx_rings; + u32 user_size, dev_size, i; + u32 *indir; + const struct ethtool_ops *ops = dev->ethtool_ops; + int ret; + u32 ringidx_offset = offsetof(struct ethtool_rxfh_indir, ring_index[0]); + + if (!ops->get_rxfh_indir_size || !ops->set_rxfh || + !ops->get_rxnfc) + return -EOPNOTSUPP; + + dev_size = ops->get_rxfh_indir_size(dev); + if (dev_size == 0) + return -EOPNOTSUPP; + + if (copy_from_user(&user_size, + useraddr + offsetof(struct ethtool_rxfh_indir, size), + sizeof(user_size))) + return -EFAULT; + + if (user_size != 0 && user_size != dev_size) + return -EINVAL; + + indir = kcalloc(dev_size, sizeof(indir[0]), GFP_USER); + if (!indir) + return -ENOMEM; + + rx_rings.cmd = ETHTOOL_GRXRINGS; + ret = ops->get_rxnfc(dev, &rx_rings, NULL); + if (ret) + goto out; + + if (user_size == 0) { + for (i = 0; i < dev_size; i++) + indir[i] = ethtool_rxfh_indir_default(i, rx_rings.data); + } else { + ret = ethtool_copy_validate_indir(indir, + useraddr + ringidx_offset, + &rx_rings, + dev_size); + if (ret) + goto out; + } + + ret = ops->set_rxfh(dev, indir, NULL, ETH_RSS_HASH_NO_CHANGE); + if (ret) + goto out; + + /* indicate whether rxfh was set to default */ + if (user_size == 0) + dev->priv_flags &= ~IFF_RXFH_CONFIGURED; + else + dev->priv_flags |= IFF_RXFH_CONFIGURED; + +out: + kfree(indir); + return ret; +} + +static noinline_for_stack int ethtool_get_rxfh(struct net_device *dev, + void __user *useraddr) +{ + int ret; + const struct ethtool_ops *ops = dev->ethtool_ops; + u32 user_indir_size, user_key_size; + u32 dev_indir_size = 0, dev_key_size = 0; + struct ethtool_rxfh rxfh; + u32 total_size; + u32 indir_bytes; + u32 *indir = NULL; + u8 dev_hfunc = 0; + u8 *hkey = NULL; + u8 *rss_config; + + if (!ops->get_rxfh) + return -EOPNOTSUPP; + + if (ops->get_rxfh_indir_size) + dev_indir_size = ops->get_rxfh_indir_size(dev); + if (ops->get_rxfh_key_size) + dev_key_size = ops->get_rxfh_key_size(dev); + + if (copy_from_user(&rxfh, useraddr, sizeof(rxfh))) + return -EFAULT; + user_indir_size = rxfh.indir_size; + user_key_size = rxfh.key_size; + + /* Check that reserved fields are 0 for now */ + if (rxfh.rsvd8[0] || rxfh.rsvd8[1] || rxfh.rsvd8[2] || rxfh.rsvd32) + return -EINVAL; + /* Most drivers don't handle rss_context, check it's 0 as well */ + if (rxfh.rss_context && !ops->get_rxfh_context) + return -EOPNOTSUPP; + + rxfh.indir_size = dev_indir_size; + rxfh.key_size = dev_key_size; + if (copy_to_user(useraddr, &rxfh, sizeof(rxfh))) + return -EFAULT; + + if ((user_indir_size && (user_indir_size != dev_indir_size)) || + (user_key_size && (user_key_size != dev_key_size))) + return -EINVAL; + + indir_bytes = user_indir_size * sizeof(indir[0]); + total_size = indir_bytes + user_key_size; + rss_config = kzalloc(total_size, GFP_USER); + if (!rss_config) + return -ENOMEM; + + if (user_indir_size) + indir = (u32 *)rss_config; + + if (user_key_size) + hkey = rss_config + indir_bytes; + + if (rxfh.rss_context) + ret = dev->ethtool_ops->get_rxfh_context(dev, indir, hkey, + &dev_hfunc, + rxfh.rss_context); + else + ret = 
dev->ethtool_ops->get_rxfh(dev, indir, hkey, &dev_hfunc); + if (ret) + goto out; + + if (copy_to_user(useraddr + offsetof(struct ethtool_rxfh, hfunc), + &dev_hfunc, sizeof(rxfh.hfunc))) { + ret = -EFAULT; + } else if (copy_to_user(useraddr + + offsetof(struct ethtool_rxfh, rss_config[0]), + rss_config, total_size)) { + ret = -EFAULT; + } +out: + kfree(rss_config); + + return ret; +} + +static noinline_for_stack int ethtool_set_rxfh(struct net_device *dev, + void __user *useraddr) +{ + int ret; + const struct ethtool_ops *ops = dev->ethtool_ops; + struct ethtool_rxnfc rx_rings; + struct ethtool_rxfh rxfh; + u32 dev_indir_size = 0, dev_key_size = 0, i; + u32 *indir = NULL, indir_bytes = 0; + u8 *hkey = NULL; + u8 *rss_config; + u32 rss_cfg_offset = offsetof(struct ethtool_rxfh, rss_config[0]); + bool delete = false; + + if (!ops->get_rxnfc || !ops->set_rxfh) + return -EOPNOTSUPP; + + if (ops->get_rxfh_indir_size) + dev_indir_size = ops->get_rxfh_indir_size(dev); + if (ops->get_rxfh_key_size) + dev_key_size = ops->get_rxfh_key_size(dev); + + if (copy_from_user(&rxfh, useraddr, sizeof(rxfh))) + return -EFAULT; + + /* Check that reserved fields are 0 for now */ + if (rxfh.rsvd8[0] || rxfh.rsvd8[1] || rxfh.rsvd8[2] || rxfh.rsvd32) + return -EINVAL; + /* Most drivers don't handle rss_context, check it's 0 as well */ + if (rxfh.rss_context && !ops->set_rxfh_context) + return -EOPNOTSUPP; + + /* If either indir, hash key or function is valid, proceed further. + * Must request at least one change: indir size, hash key or function. + */ + if ((rxfh.indir_size && + rxfh.indir_size != ETH_RXFH_INDIR_NO_CHANGE && + rxfh.indir_size != dev_indir_size) || + (rxfh.key_size && (rxfh.key_size != dev_key_size)) || + (rxfh.indir_size == ETH_RXFH_INDIR_NO_CHANGE && + rxfh.key_size == 0 && rxfh.hfunc == ETH_RSS_HASH_NO_CHANGE)) + return -EINVAL; + + if (rxfh.indir_size != ETH_RXFH_INDIR_NO_CHANGE) + indir_bytes = dev_indir_size * sizeof(indir[0]); + + rss_config = kzalloc(indir_bytes + rxfh.key_size, GFP_USER); + if (!rss_config) + return -ENOMEM; + + rx_rings.cmd = ETHTOOL_GRXRINGS; + ret = ops->get_rxnfc(dev, &rx_rings, NULL); + if (ret) + goto out; + + /* rxfh.indir_size == 0 means reset the indir table to default (master + * context) or delete the context (other RSS contexts). + * rxfh.indir_size == ETH_RXFH_INDIR_NO_CHANGE means leave it unchanged. 
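+ *
+ * Summary of the branches below (rss_context 0 is the master context):
+ *
+ *     indir_size                 context 0           other contexts
+ *     0                          reset to defaults   delete context
+ *     ETH_RXFH_INDIR_NO_CHANGE   keep current table  keep current table
+ *     == dev_indir_size          copy from user      copy from user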
+ */ + if (rxfh.indir_size && + rxfh.indir_size != ETH_RXFH_INDIR_NO_CHANGE) { + indir = (u32 *)rss_config; + ret = ethtool_copy_validate_indir(indir, + useraddr + rss_cfg_offset, + &rx_rings, + rxfh.indir_size); + if (ret) + goto out; + } else if (rxfh.indir_size == 0) { + if (rxfh.rss_context == 0) { + indir = (u32 *)rss_config; + for (i = 0; i < dev_indir_size; i++) + indir[i] = ethtool_rxfh_indir_default(i, rx_rings.data); + } else { + delete = true; + } + } + + if (rxfh.key_size) { + hkey = rss_config + indir_bytes; + if (copy_from_user(hkey, + useraddr + rss_cfg_offset + indir_bytes, + rxfh.key_size)) { + ret = -EFAULT; + goto out; + } + } + + if (rxfh.rss_context) + ret = ops->set_rxfh_context(dev, indir, hkey, rxfh.hfunc, + &rxfh.rss_context, delete); + else + ret = ops->set_rxfh(dev, indir, hkey, rxfh.hfunc); + if (ret) + goto out; + + if (copy_to_user(useraddr + offsetof(struct ethtool_rxfh, rss_context), + &rxfh.rss_context, sizeof(rxfh.rss_context))) + ret = -EFAULT; + + if (!rxfh.rss_context) { + /* indicate whether rxfh was set to default */ + if (rxfh.indir_size == 0) + dev->priv_flags &= ~IFF_RXFH_CONFIGURED; + else if (rxfh.indir_size != ETH_RXFH_INDIR_NO_CHANGE) + dev->priv_flags |= IFF_RXFH_CONFIGURED; + } + +out: + kfree(rss_config); + return ret; +} + +static int ethtool_get_regs(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_regs regs; + const struct ethtool_ops *ops = dev->ethtool_ops; + void *regbuf; + int reglen, ret; + + if (!ops->get_regs || !ops->get_regs_len) + return -EOPNOTSUPP; + + if (copy_from_user(®s, useraddr, sizeof(regs))) + return -EFAULT; + + reglen = ops->get_regs_len(dev); + if (reglen <= 0) + return reglen; + + if (regs.len > reglen) + regs.len = reglen; + + regbuf = vzalloc(reglen); + if (!regbuf) + return -ENOMEM; + + if (regs.len < reglen) + reglen = regs.len; + + ops->get_regs(dev, ®s, regbuf); + + ret = -EFAULT; + if (copy_to_user(useraddr, ®s, sizeof(regs))) + goto out; + useraddr += offsetof(struct ethtool_regs, data); + if (copy_to_user(useraddr, regbuf, reglen)) + goto out; + ret = 0; + + out: + vfree(regbuf); + return ret; +} + +static int ethtool_reset(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_value reset; + int ret; + + if (!dev->ethtool_ops->reset) + return -EOPNOTSUPP; + + if (copy_from_user(&reset, useraddr, sizeof(reset))) + return -EFAULT; + + ret = dev->ethtool_ops->reset(dev, &reset.data); + if (ret) + return ret; + + if (copy_to_user(useraddr, &reset, sizeof(reset))) + return -EFAULT; + return 0; +} + +static int ethtool_get_wol(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_wolinfo wol; + + if (!dev->ethtool_ops->get_wol) + return -EOPNOTSUPP; + + memset(&wol, 0, sizeof(struct ethtool_wolinfo)); + wol.cmd = ETHTOOL_GWOL; + dev->ethtool_ops->get_wol(dev, &wol); + + if (copy_to_user(useraddr, &wol, sizeof(wol))) + return -EFAULT; + return 0; +} + +static int ethtool_set_wol(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_wolinfo wol; + int ret; + + if (!dev->ethtool_ops->set_wol) + return -EOPNOTSUPP; + + if (copy_from_user(&wol, useraddr, sizeof(wol))) + return -EFAULT; + + ret = dev->ethtool_ops->set_wol(dev, &wol); + if (ret) + return ret; + + dev->wol_enabled = !!wol.wolopts; + ethtool_notify(dev, ETHTOOL_MSG_WOL_NTF, NULL); + + return 0; +} + +static int ethtool_get_eee(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_eee edata; + int rc; + + if (!dev->ethtool_ops->get_eee) + return -EOPNOTSUPP; + + memset(&edata, 0, 
sizeof(struct ethtool_eee)); + edata.cmd = ETHTOOL_GEEE; + rc = dev->ethtool_ops->get_eee(dev, &edata); + + if (rc) + return rc; + + if (copy_to_user(useraddr, &edata, sizeof(edata))) + return -EFAULT; + + return 0; +} + +static int ethtool_set_eee(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_eee edata; + int ret; + + if (!dev->ethtool_ops->set_eee) + return -EOPNOTSUPP; + + if (copy_from_user(&edata, useraddr, sizeof(edata))) + return -EFAULT; + + ret = dev->ethtool_ops->set_eee(dev, &edata); + if (!ret) + ethtool_notify(dev, ETHTOOL_MSG_EEE_NTF, NULL); + return ret; +} + +static int ethtool_nway_reset(struct net_device *dev) +{ + if (!dev->ethtool_ops->nway_reset) + return -EOPNOTSUPP; + + return dev->ethtool_ops->nway_reset(dev); +} + +static int ethtool_get_link(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_value edata = { .cmd = ETHTOOL_GLINK }; + int link = __ethtool_get_link(dev); + + if (link < 0) + return link; + + edata.data = link; + if (copy_to_user(useraddr, &edata, sizeof(edata))) + return -EFAULT; + return 0; +} + +static int ethtool_get_any_eeprom(struct net_device *dev, void __user *useraddr, + int (*getter)(struct net_device *, + struct ethtool_eeprom *, u8 *), + u32 total_len) +{ + struct ethtool_eeprom eeprom; + void __user *userbuf = useraddr + sizeof(eeprom); + u32 bytes_remaining; + u8 *data; + int ret = 0; + + if (copy_from_user(&eeprom, useraddr, sizeof(eeprom))) + return -EFAULT; + + /* Check for wrap and zero */ + if (eeprom.offset + eeprom.len <= eeprom.offset) + return -EINVAL; + + /* Check for exceeding total eeprom len */ + if (eeprom.offset + eeprom.len > total_len) + return -EINVAL; + + data = kzalloc(PAGE_SIZE, GFP_USER); + if (!data) + return -ENOMEM; + + bytes_remaining = eeprom.len; + while (bytes_remaining > 0) { + eeprom.len = min(bytes_remaining, (u32)PAGE_SIZE); + + ret = getter(dev, &eeprom, data); + if (ret) + break; + if (!eeprom.len) { + ret = -EIO; + break; + } + if (copy_to_user(userbuf, data, eeprom.len)) { + ret = -EFAULT; + break; + } + userbuf += eeprom.len; + eeprom.offset += eeprom.len; + bytes_remaining -= eeprom.len; + } + + eeprom.len = userbuf - (useraddr + sizeof(eeprom)); + eeprom.offset -= eeprom.len; + if (copy_to_user(useraddr, &eeprom, sizeof(eeprom))) + ret = -EFAULT; + + kfree(data); + return ret; +} + +static int ethtool_get_eeprom(struct net_device *dev, void __user *useraddr) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (!ops->get_eeprom || !ops->get_eeprom_len || + !ops->get_eeprom_len(dev)) + return -EOPNOTSUPP; + + return ethtool_get_any_eeprom(dev, useraddr, ops->get_eeprom, + ops->get_eeprom_len(dev)); +} + +static int ethtool_set_eeprom(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_eeprom eeprom; + const struct ethtool_ops *ops = dev->ethtool_ops; + void __user *userbuf = useraddr + sizeof(eeprom); + u32 bytes_remaining; + u8 *data; + int ret = 0; + + if (!ops->set_eeprom || !ops->get_eeprom_len || + !ops->get_eeprom_len(dev)) + return -EOPNOTSUPP; + + if (copy_from_user(&eeprom, useraddr, sizeof(eeprom))) + return -EFAULT; + + /* Check for wrap and zero */ + if (eeprom.offset + eeprom.len <= eeprom.offset) + return -EINVAL; + + /* Check for exceeding total eeprom len */ + if (eeprom.offset + eeprom.len > ops->get_eeprom_len(dev)) + return -EINVAL; + + data = kzalloc(PAGE_SIZE, GFP_USER); + if (!data) + return -ENOMEM; + + bytes_remaining = eeprom.len; + while (bytes_remaining > 0) { + eeprom.len = min(bytes_remaining, (u32)PAGE_SIZE); + 
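+ /* Feed the EEPROM to the driver at most one page per iteration:
+  * copy the next chunk from user space, let the driver write it,
+  * then advance the user pointer and the EEPROM offset by the
+  * amount just written.
+  */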
+ if (copy_from_user(data, userbuf, eeprom.len)) { + ret = -EFAULT; + break; + } + ret = ops->set_eeprom(dev, &eeprom, data); + if (ret) + break; + userbuf += eeprom.len; + eeprom.offset += eeprom.len; + bytes_remaining -= eeprom.len; + } + + kfree(data); + return ret; +} + +static noinline_for_stack int ethtool_get_coalesce(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_coalesce coalesce = { .cmd = ETHTOOL_GCOALESCE }; + struct kernel_ethtool_coalesce kernel_coalesce = {}; + int ret; + + if (!dev->ethtool_ops->get_coalesce) + return -EOPNOTSUPP; + + ret = dev->ethtool_ops->get_coalesce(dev, &coalesce, &kernel_coalesce, + NULL); + if (ret) + return ret; + + if (copy_to_user(useraddr, &coalesce, sizeof(coalesce))) + return -EFAULT; + return 0; +} + +static bool +ethtool_set_coalesce_supported(struct net_device *dev, + struct ethtool_coalesce *coalesce) +{ + u32 supported_params = dev->ethtool_ops->supported_coalesce_params; + u32 nonzero_params = 0; + + if (coalesce->rx_coalesce_usecs) + nonzero_params |= ETHTOOL_COALESCE_RX_USECS; + if (coalesce->rx_max_coalesced_frames) + nonzero_params |= ETHTOOL_COALESCE_RX_MAX_FRAMES; + if (coalesce->rx_coalesce_usecs_irq) + nonzero_params |= ETHTOOL_COALESCE_RX_USECS_IRQ; + if (coalesce->rx_max_coalesced_frames_irq) + nonzero_params |= ETHTOOL_COALESCE_RX_MAX_FRAMES_IRQ; + if (coalesce->tx_coalesce_usecs) + nonzero_params |= ETHTOOL_COALESCE_TX_USECS; + if (coalesce->tx_max_coalesced_frames) + nonzero_params |= ETHTOOL_COALESCE_TX_MAX_FRAMES; + if (coalesce->tx_coalesce_usecs_irq) + nonzero_params |= ETHTOOL_COALESCE_TX_USECS_IRQ; + if (coalesce->tx_max_coalesced_frames_irq) + nonzero_params |= ETHTOOL_COALESCE_TX_MAX_FRAMES_IRQ; + if (coalesce->stats_block_coalesce_usecs) + nonzero_params |= ETHTOOL_COALESCE_STATS_BLOCK_USECS; + if (coalesce->use_adaptive_rx_coalesce) + nonzero_params |= ETHTOOL_COALESCE_USE_ADAPTIVE_RX; + if (coalesce->use_adaptive_tx_coalesce) + nonzero_params |= ETHTOOL_COALESCE_USE_ADAPTIVE_TX; + if (coalesce->pkt_rate_low) + nonzero_params |= ETHTOOL_COALESCE_PKT_RATE_LOW; + if (coalesce->rx_coalesce_usecs_low) + nonzero_params |= ETHTOOL_COALESCE_RX_USECS_LOW; + if (coalesce->rx_max_coalesced_frames_low) + nonzero_params |= ETHTOOL_COALESCE_RX_MAX_FRAMES_LOW; + if (coalesce->tx_coalesce_usecs_low) + nonzero_params |= ETHTOOL_COALESCE_TX_USECS_LOW; + if (coalesce->tx_max_coalesced_frames_low) + nonzero_params |= ETHTOOL_COALESCE_TX_MAX_FRAMES_LOW; + if (coalesce->pkt_rate_high) + nonzero_params |= ETHTOOL_COALESCE_PKT_RATE_HIGH; + if (coalesce->rx_coalesce_usecs_high) + nonzero_params |= ETHTOOL_COALESCE_RX_USECS_HIGH; + if (coalesce->rx_max_coalesced_frames_high) + nonzero_params |= ETHTOOL_COALESCE_RX_MAX_FRAMES_HIGH; + if (coalesce->tx_coalesce_usecs_high) + nonzero_params |= ETHTOOL_COALESCE_TX_USECS_HIGH; + if (coalesce->tx_max_coalesced_frames_high) + nonzero_params |= ETHTOOL_COALESCE_TX_MAX_FRAMES_HIGH; + if (coalesce->rate_sample_interval) + nonzero_params |= ETHTOOL_COALESCE_RATE_SAMPLE_INTERVAL; + + return (supported_params & nonzero_params) == nonzero_params; +} + +static noinline_for_stack int ethtool_set_coalesce(struct net_device *dev, + void __user *useraddr) +{ + struct kernel_ethtool_coalesce kernel_coalesce = {}; + struct ethtool_coalesce coalesce; + int ret; + + if (!dev->ethtool_ops->set_coalesce || !dev->ethtool_ops->get_coalesce) + return -EOPNOTSUPP; + + ret = dev->ethtool_ops->get_coalesce(dev, &coalesce, &kernel_coalesce, + NULL); + if (ret) + return ret; + + if 
(copy_from_user(&coalesce, useraddr, sizeof(coalesce))) + return -EFAULT; + + if (!ethtool_set_coalesce_supported(dev, &coalesce)) + return -EOPNOTSUPP; + + ret = dev->ethtool_ops->set_coalesce(dev, &coalesce, &kernel_coalesce, + NULL); + if (!ret) + ethtool_notify(dev, ETHTOOL_MSG_COALESCE_NTF, NULL); + return ret; +} + +static int ethtool_get_ringparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_ringparam ringparam = { .cmd = ETHTOOL_GRINGPARAM }; + struct kernel_ethtool_ringparam kernel_ringparam = {}; + + if (!dev->ethtool_ops->get_ringparam) + return -EOPNOTSUPP; + + dev->ethtool_ops->get_ringparam(dev, &ringparam, + &kernel_ringparam, NULL); + + if (copy_to_user(useraddr, &ringparam, sizeof(ringparam))) + return -EFAULT; + return 0; +} + +static int ethtool_set_ringparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_ringparam ringparam, max = { .cmd = ETHTOOL_GRINGPARAM }; + struct kernel_ethtool_ringparam kernel_ringparam; + int ret; + + if (!dev->ethtool_ops->set_ringparam || !dev->ethtool_ops->get_ringparam) + return -EOPNOTSUPP; + + if (copy_from_user(&ringparam, useraddr, sizeof(ringparam))) + return -EFAULT; + + dev->ethtool_ops->get_ringparam(dev, &max, &kernel_ringparam, NULL); + + /* ensure new ring parameters are within the maximums */ + if (ringparam.rx_pending > max.rx_max_pending || + ringparam.rx_mini_pending > max.rx_mini_max_pending || + ringparam.rx_jumbo_pending > max.rx_jumbo_max_pending || + ringparam.tx_pending > max.tx_max_pending) + return -EINVAL; + + ret = dev->ethtool_ops->set_ringparam(dev, &ringparam, + &kernel_ringparam, NULL); + if (!ret) + ethtool_notify(dev, ETHTOOL_MSG_RINGS_NTF, NULL); + return ret; +} + +static noinline_for_stack int ethtool_get_channels(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_channels channels = { .cmd = ETHTOOL_GCHANNELS }; + + if (!dev->ethtool_ops->get_channels) + return -EOPNOTSUPP; + + dev->ethtool_ops->get_channels(dev, &channels); + + if (copy_to_user(useraddr, &channels, sizeof(channels))) + return -EFAULT; + return 0; +} + +static noinline_for_stack int ethtool_set_channels(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_channels channels, curr = { .cmd = ETHTOOL_GCHANNELS }; + u16 from_channel, to_channel; + u32 max_rx_in_use = 0; + unsigned int i; + int ret; + + if (!dev->ethtool_ops->set_channels || !dev->ethtool_ops->get_channels) + return -EOPNOTSUPP; + + if (copy_from_user(&channels, useraddr, sizeof(channels))) + return -EFAULT; + + dev->ethtool_ops->get_channels(dev, &curr); + + if (channels.rx_count == curr.rx_count && + channels.tx_count == curr.tx_count && + channels.combined_count == curr.combined_count && + channels.other_count == curr.other_count) + return 0; + + /* ensure new counts are within the maximums */ + if (channels.rx_count > curr.max_rx || + channels.tx_count > curr.max_tx || + channels.combined_count > curr.max_combined || + channels.other_count > curr.max_other) + return -EINVAL; + + /* ensure there is at least one RX and one TX channel */ + if (!channels.combined_count && + (!channels.rx_count || !channels.tx_count)) + return -EINVAL; + + /* ensure the new Rx count fits within the configured Rx flow + * indirection table settings */ + if (netif_is_rxfh_configured(dev) && + !ethtool_get_max_rxfh_channel(dev, &max_rx_in_use) && + (channels.combined_count + channels.rx_count) <= max_rx_in_use) + return -EINVAL; + + /* Disabling channels, query zero-copy AF_XDP sockets */ + from_channel = 
channels.combined_count + + min(channels.rx_count, channels.tx_count); + to_channel = curr.combined_count + max(curr.rx_count, curr.tx_count); + for (i = from_channel; i < to_channel; i++) + if (xsk_get_pool_from_qid(dev, i)) + return -EINVAL; + + ret = dev->ethtool_ops->set_channels(dev, &channels); + if (!ret) + ethtool_notify(dev, ETHTOOL_MSG_CHANNELS_NTF, NULL); + return ret; +} + +static int ethtool_get_pauseparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_pauseparam pauseparam = { .cmd = ETHTOOL_GPAUSEPARAM }; + + if (!dev->ethtool_ops->get_pauseparam) + return -EOPNOTSUPP; + + dev->ethtool_ops->get_pauseparam(dev, &pauseparam); + + if (copy_to_user(useraddr, &pauseparam, sizeof(pauseparam))) + return -EFAULT; + return 0; +} + +static int ethtool_set_pauseparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_pauseparam pauseparam; + int ret; + + if (!dev->ethtool_ops->set_pauseparam) + return -EOPNOTSUPP; + + if (copy_from_user(&pauseparam, useraddr, sizeof(pauseparam))) + return -EFAULT; + + ret = dev->ethtool_ops->set_pauseparam(dev, &pauseparam); + if (!ret) + ethtool_notify(dev, ETHTOOL_MSG_PAUSE_NTF, NULL); + return ret; +} + +static int ethtool_self_test(struct net_device *dev, char __user *useraddr) +{ + struct ethtool_test test; + const struct ethtool_ops *ops = dev->ethtool_ops; + u64 *data; + int ret, test_len; + + if (!ops->self_test || !ops->get_sset_count) + return -EOPNOTSUPP; + + test_len = ops->get_sset_count(dev, ETH_SS_TEST); + if (test_len < 0) + return test_len; + WARN_ON(test_len == 0); + + if (copy_from_user(&test, useraddr, sizeof(test))) + return -EFAULT; + + test.len = test_len; + data = kcalloc(test_len, sizeof(u64), GFP_USER); + if (!data) + return -ENOMEM; + + netif_testing_on(dev); + ops->self_test(dev, &test, data); + netif_testing_off(dev); + + ret = -EFAULT; + if (copy_to_user(useraddr, &test, sizeof(test))) + goto out; + useraddr += sizeof(test); + if (copy_to_user(useraddr, data, array_size(test.len, sizeof(u64)))) + goto out; + ret = 0; + + out: + kfree(data); + return ret; +} + +static int ethtool_get_strings(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_gstrings gstrings; + u8 *data; + int ret; + + if (copy_from_user(&gstrings, useraddr, sizeof(gstrings))) + return -EFAULT; + + ret = __ethtool_get_sset_count(dev, gstrings.string_set); + if (ret < 0) + return ret; + if (ret > S32_MAX / ETH_GSTRING_LEN) + return -ENOMEM; + WARN_ON_ONCE(!ret); + + gstrings.len = ret; + + if (gstrings.len) { + data = vzalloc(array_size(gstrings.len, ETH_GSTRING_LEN)); + if (!data) + return -ENOMEM; + + __ethtool_get_strings(dev, gstrings.string_set, data); + } else { + data = NULL; + } + + ret = -EFAULT; + if (copy_to_user(useraddr, &gstrings, sizeof(gstrings))) + goto out; + useraddr += sizeof(gstrings); + if (gstrings.len && + copy_to_user(useraddr, data, + array_size(gstrings.len, ETH_GSTRING_LEN))) + goto out; + ret = 0; + +out: + vfree(data); + return ret; +} + +__printf(2, 3) void ethtool_sprintf(u8 **data, const char *fmt, ...) 
+{ + va_list args; + + va_start(args, fmt); + vsnprintf(*data, ETH_GSTRING_LEN, fmt, args); + va_end(args); + + *data += ETH_GSTRING_LEN; +} +EXPORT_SYMBOL(ethtool_sprintf); + +static int ethtool_phys_id(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_value id; + static bool busy; + const struct ethtool_ops *ops = dev->ethtool_ops; + netdevice_tracker dev_tracker; + int rc; + + if (!ops->set_phys_id) + return -EOPNOTSUPP; + + if (busy) + return -EBUSY; + + if (copy_from_user(&id, useraddr, sizeof(id))) + return -EFAULT; + + rc = ops->set_phys_id(dev, ETHTOOL_ID_ACTIVE); + if (rc < 0) + return rc; + + /* Drop the RTNL lock while waiting, but prevent reentry or + * removal of the device. + */ + busy = true; + netdev_hold(dev, &dev_tracker, GFP_KERNEL); + rtnl_unlock(); + + if (rc == 0) { + /* Driver will handle this itself */ + schedule_timeout_interruptible( + id.data ? (id.data * HZ) : MAX_SCHEDULE_TIMEOUT); + } else { + /* Driver expects to be called at twice the frequency in rc */ + int n = rc * 2, interval = HZ / n; + u64 count = mul_u32_u32(n, id.data); + u64 i = 0; + + do { + rtnl_lock(); + rc = ops->set_phys_id(dev, + (i++ & 1) ? ETHTOOL_ID_OFF : ETHTOOL_ID_ON); + rtnl_unlock(); + if (rc) + break; + schedule_timeout_interruptible(interval); + } while (!signal_pending(current) && (!id.data || i < count)); + } + + rtnl_lock(); + netdev_put(dev, &dev_tracker); + busy = false; + + (void) ops->set_phys_id(dev, ETHTOOL_ID_INACTIVE); + return rc; +} + +static int ethtool_get_stats(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_stats stats; + const struct ethtool_ops *ops = dev->ethtool_ops; + u64 *data; + int ret, n_stats; + + if (!ops->get_ethtool_stats || !ops->get_sset_count) + return -EOPNOTSUPP; + + n_stats = ops->get_sset_count(dev, ETH_SS_STATS); + if (n_stats < 0) + return n_stats; + if (n_stats > S32_MAX / sizeof(u64)) + return -ENOMEM; + WARN_ON_ONCE(!n_stats); + if (copy_from_user(&stats, useraddr, sizeof(stats))) + return -EFAULT; + + stats.n_stats = n_stats; + + if (n_stats) { + data = vzalloc(array_size(n_stats, sizeof(u64))); + if (!data) + return -ENOMEM; + ops->get_ethtool_stats(dev, &stats, data); + } else { + data = NULL; + } + + ret = -EFAULT; + if (copy_to_user(useraddr, &stats, sizeof(stats))) + goto out; + useraddr += sizeof(stats); + if (n_stats && copy_to_user(useraddr, data, array_size(n_stats, sizeof(u64)))) + goto out; + ret = 0; + + out: + vfree(data); + return ret; +} + +static int ethtool_get_phy_stats(struct net_device *dev, void __user *useraddr) +{ + const struct ethtool_phy_ops *phy_ops = ethtool_phy_ops; + const struct ethtool_ops *ops = dev->ethtool_ops; + struct phy_device *phydev = dev->phydev; + struct ethtool_stats stats; + u64 *data; + int ret, n_stats; + + if (!phydev && (!ops->get_ethtool_phy_stats || !ops->get_sset_count)) + return -EOPNOTSUPP; + + if (phydev && !ops->get_ethtool_phy_stats && + phy_ops && phy_ops->get_sset_count) + n_stats = phy_ops->get_sset_count(phydev); + else + n_stats = ops->get_sset_count(dev, ETH_SS_PHY_STATS); + if (n_stats < 0) + return n_stats; + if (n_stats > S32_MAX / sizeof(u64)) + return -ENOMEM; + if (WARN_ON_ONCE(!n_stats)) + return -EOPNOTSUPP; + + if (copy_from_user(&stats, useraddr, sizeof(stats))) + return -EFAULT; + + stats.n_stats = n_stats; + + if (n_stats) { + data = vzalloc(array_size(n_stats, sizeof(u64))); + if (!data) + return -ENOMEM; + + if (phydev && !ops->get_ethtool_phy_stats && + phy_ops && phy_ops->get_stats) { + ret = phy_ops->get_stats(phydev, &stats, 
data); + if (ret < 0) + goto out; + } else { + ops->get_ethtool_phy_stats(dev, &stats, data); + } + } else { + data = NULL; + } + + ret = -EFAULT; + if (copy_to_user(useraddr, &stats, sizeof(stats))) + goto out; + useraddr += sizeof(stats); + if (n_stats && copy_to_user(useraddr, data, array_size(n_stats, sizeof(u64)))) + goto out; + ret = 0; + + out: + vfree(data); + return ret; +} + +static int ethtool_get_perm_addr(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_perm_addr epaddr; + + if (copy_from_user(&epaddr, useraddr, sizeof(epaddr))) + return -EFAULT; + + if (epaddr.size < dev->addr_len) + return -ETOOSMALL; + epaddr.size = dev->addr_len; + + if (copy_to_user(useraddr, &epaddr, sizeof(epaddr))) + return -EFAULT; + useraddr += sizeof(epaddr); + if (copy_to_user(useraddr, dev->perm_addr, epaddr.size)) + return -EFAULT; + return 0; +} + +static int ethtool_get_value(struct net_device *dev, char __user *useraddr, + u32 cmd, u32 (*actor)(struct net_device *)) +{ + struct ethtool_value edata = { .cmd = cmd }; + + if (!actor) + return -EOPNOTSUPP; + + edata.data = actor(dev); + + if (copy_to_user(useraddr, &edata, sizeof(edata))) + return -EFAULT; + return 0; +} + +static int ethtool_set_value_void(struct net_device *dev, char __user *useraddr, + void (*actor)(struct net_device *, u32)) +{ + struct ethtool_value edata; + + if (!actor) + return -EOPNOTSUPP; + + if (copy_from_user(&edata, useraddr, sizeof(edata))) + return -EFAULT; + + actor(dev, edata.data); + return 0; +} + +static int ethtool_set_value(struct net_device *dev, char __user *useraddr, + int (*actor)(struct net_device *, u32)) +{ + struct ethtool_value edata; + + if (!actor) + return -EOPNOTSUPP; + + if (copy_from_user(&edata, useraddr, sizeof(edata))) + return -EFAULT; + + return actor(dev, edata.data); +} + +static int +ethtool_flash_device(struct net_device *dev, struct ethtool_devlink_compat *req) +{ + if (!dev->ethtool_ops->flash_device) { + req->devlink = netdev_to_devlink_get(dev); + return 0; + } + + return dev->ethtool_ops->flash_device(dev, &req->efl); +} + +static int ethtool_set_dump(struct net_device *dev, + void __user *useraddr) +{ + struct ethtool_dump dump; + + if (!dev->ethtool_ops->set_dump) + return -EOPNOTSUPP; + + if (copy_from_user(&dump, useraddr, sizeof(dump))) + return -EFAULT; + + return dev->ethtool_ops->set_dump(dev, &dump); +} + +static int ethtool_get_dump_flag(struct net_device *dev, + void __user *useraddr) +{ + int ret; + struct ethtool_dump dump; + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (!ops->get_dump_flag) + return -EOPNOTSUPP; + + if (copy_from_user(&dump, useraddr, sizeof(dump))) + return -EFAULT; + + ret = ops->get_dump_flag(dev, &dump); + if (ret) + return ret; + + if (copy_to_user(useraddr, &dump, sizeof(dump))) + return -EFAULT; + return 0; +} + +static int ethtool_get_dump_data(struct net_device *dev, + void __user *useraddr) +{ + int ret; + __u32 len; + struct ethtool_dump dump, tmp; + const struct ethtool_ops *ops = dev->ethtool_ops; + void *data = NULL; + + if (!ops->get_dump_data || !ops->get_dump_flag) + return -EOPNOTSUPP; + + if (copy_from_user(&dump, useraddr, sizeof(dump))) + return -EFAULT; + + memset(&tmp, 0, sizeof(tmp)); + tmp.cmd = ETHTOOL_GET_DUMP_FLAG; + ret = ops->get_dump_flag(dev, &tmp); + if (ret) + return ret; + + len = min(tmp.len, dump.len); + if (!len) + return -EFAULT; + + /* Don't ever let the driver think there's more space available + * than it requested with .get_dump_flag(). 
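+ *
+ * For reference, the user-space sequence is roughly (illustrative
+ * sketch; error handling omitted, 'fd' is an AF_INET socket and
+ * 'ifr' a struct ifreq with ifr_name already set):
+ *
+ *     struct ethtool_dump flag = { .cmd = ETHTOOL_GET_DUMP_FLAG };
+ *     ifr.ifr_data = (char *)&flag;
+ *     ioctl(fd, SIOCETHTOOL, &ifr);             <- learn flag.len
+ *     struct ethtool_dump *dump = calloc(1, sizeof(*dump) + flag.len);
+ *     dump->cmd = ETHTOOL_GET_DUMP_DATA;
+ *     dump->len = flag.len;
+ *     ifr.ifr_data = (char *)dump;
+ *     ioctl(fd, SIOCETHTOOL, &ifr);             <- dump->data is filled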
+ */ + dump.len = len; + + /* Always allocate enough space to hold the whole thing so that the + * driver does not need to check the length and bother with partial + * dumping. + */ + data = vzalloc(tmp.len); + if (!data) + return -ENOMEM; + ret = ops->get_dump_data(dev, &dump, data); + if (ret) + goto out; + + /* There are two sane possibilities: + * 1. The driver's .get_dump_data() does not touch dump.len. + * 2. Or it may set dump.len to how much it really writes, which + * should be tmp.len (or len if it can do a partial dump). + * In any case respond to userspace with the actual length of data + * it's receiving. + */ + WARN_ON(dump.len != len && dump.len != tmp.len); + dump.len = len; + + if (copy_to_user(useraddr, &dump, sizeof(dump))) { + ret = -EFAULT; + goto out; + } + useraddr += offsetof(struct ethtool_dump, data); + if (copy_to_user(useraddr, data, len)) + ret = -EFAULT; +out: + vfree(data); + return ret; +} + +static int ethtool_get_ts_info(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_ts_info info; + int err; + + err = __ethtool_get_ts_info(dev, &info); + if (err) + return err; + + if (copy_to_user(useraddr, &info, sizeof(info))) + return -EFAULT; + + return 0; +} + +int ethtool_get_module_info_call(struct net_device *dev, + struct ethtool_modinfo *modinfo) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + struct phy_device *phydev = dev->phydev; + + if (dev->sfp_bus) + return sfp_get_module_info(dev->sfp_bus, modinfo); + + if (phydev && phydev->drv && phydev->drv->module_info) + return phydev->drv->module_info(phydev, modinfo); + + if (ops->get_module_info) + return ops->get_module_info(dev, modinfo); + + return -EOPNOTSUPP; +} + +static int ethtool_get_module_info(struct net_device *dev, + void __user *useraddr) +{ + int ret; + struct ethtool_modinfo modinfo; + + if (copy_from_user(&modinfo, useraddr, sizeof(modinfo))) + return -EFAULT; + + ret = ethtool_get_module_info_call(dev, &modinfo); + if (ret) + return ret; + + if (copy_to_user(useraddr, &modinfo, sizeof(modinfo))) + return -EFAULT; + + return 0; +} + +int ethtool_get_module_eeprom_call(struct net_device *dev, + struct ethtool_eeprom *ee, u8 *data) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + struct phy_device *phydev = dev->phydev; + + if (dev->sfp_bus) + return sfp_get_module_eeprom(dev->sfp_bus, ee, data); + + if (phydev && phydev->drv && phydev->drv->module_eeprom) + return phydev->drv->module_eeprom(phydev, ee, data); + + if (ops->get_module_eeprom) + return ops->get_module_eeprom(dev, ee, data); + + return -EOPNOTSUPP; +} + +static int ethtool_get_module_eeprom(struct net_device *dev, + void __user *useraddr) +{ + int ret; + struct ethtool_modinfo modinfo; + + ret = ethtool_get_module_info_call(dev, &modinfo); + if (ret) + return ret; + + return ethtool_get_any_eeprom(dev, useraddr, + ethtool_get_module_eeprom_call, + modinfo.eeprom_len); +} + +static int ethtool_tunable_valid(const struct ethtool_tunable *tuna) +{ + switch (tuna->id) { + case ETHTOOL_RX_COPYBREAK: + case ETHTOOL_TX_COPYBREAK: + case ETHTOOL_TX_COPYBREAK_BUF_SIZE: + if (tuna->len != sizeof(u32) || + tuna->type_id != ETHTOOL_TUNABLE_U32) + return -EINVAL; + break; + case ETHTOOL_PFC_PREVENTION_TOUT: + if (tuna->len != sizeof(u16) || + tuna->type_id != ETHTOOL_TUNABLE_U16) + return -EINVAL; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int ethtool_get_tunable(struct net_device *dev, void __user *useraddr) +{ + int ret; + struct ethtool_tunable tuna; + const struct ethtool_ops *ops = 
dev->ethtool_ops; + void *data; + + if (!ops->get_tunable) + return -EOPNOTSUPP; + if (copy_from_user(&tuna, useraddr, sizeof(tuna))) + return -EFAULT; + ret = ethtool_tunable_valid(&tuna); + if (ret) + return ret; + data = kzalloc(tuna.len, GFP_USER); + if (!data) + return -ENOMEM; + ret = ops->get_tunable(dev, &tuna, data); + if (ret) + goto out; + useraddr += sizeof(tuna); + ret = -EFAULT; + if (copy_to_user(useraddr, data, tuna.len)) + goto out; + ret = 0; + +out: + kfree(data); + return ret; +} + +static int ethtool_set_tunable(struct net_device *dev, void __user *useraddr) +{ + int ret; + struct ethtool_tunable tuna; + const struct ethtool_ops *ops = dev->ethtool_ops; + void *data; + + if (!ops->set_tunable) + return -EOPNOTSUPP; + if (copy_from_user(&tuna, useraddr, sizeof(tuna))) + return -EFAULT; + ret = ethtool_tunable_valid(&tuna); + if (ret) + return ret; + useraddr += sizeof(tuna); + data = memdup_user(useraddr, tuna.len); + if (IS_ERR(data)) + return PTR_ERR(data); + ret = ops->set_tunable(dev, &tuna, data); + + kfree(data); + return ret; +} + +static noinline_for_stack int +ethtool_get_per_queue_coalesce(struct net_device *dev, + void __user *useraddr, + struct ethtool_per_queue_op *per_queue_opt) +{ + u32 bit; + int ret; + DECLARE_BITMAP(queue_mask, MAX_NUM_QUEUE); + + if (!dev->ethtool_ops->get_per_queue_coalesce) + return -EOPNOTSUPP; + + useraddr += sizeof(*per_queue_opt); + + bitmap_from_arr32(queue_mask, per_queue_opt->queue_mask, + MAX_NUM_QUEUE); + + for_each_set_bit(bit, queue_mask, MAX_NUM_QUEUE) { + struct ethtool_coalesce coalesce = { .cmd = ETHTOOL_GCOALESCE }; + + ret = dev->ethtool_ops->get_per_queue_coalesce(dev, bit, &coalesce); + if (ret != 0) + return ret; + if (copy_to_user(useraddr, &coalesce, sizeof(coalesce))) + return -EFAULT; + useraddr += sizeof(coalesce); + } + + return 0; +} + +static noinline_for_stack int +ethtool_set_per_queue_coalesce(struct net_device *dev, + void __user *useraddr, + struct ethtool_per_queue_op *per_queue_opt) +{ + u32 bit; + int i, ret = 0; + int n_queue; + struct ethtool_coalesce *backup = NULL, *tmp = NULL; + DECLARE_BITMAP(queue_mask, MAX_NUM_QUEUE); + + if ((!dev->ethtool_ops->set_per_queue_coalesce) || + (!dev->ethtool_ops->get_per_queue_coalesce)) + return -EOPNOTSUPP; + + useraddr += sizeof(*per_queue_opt); + + bitmap_from_arr32(queue_mask, per_queue_opt->queue_mask, MAX_NUM_QUEUE); + n_queue = bitmap_weight(queue_mask, MAX_NUM_QUEUE); + tmp = backup = kmalloc_array(n_queue, sizeof(*backup), GFP_KERNEL); + if (!backup) + return -ENOMEM; + + for_each_set_bit(bit, queue_mask, MAX_NUM_QUEUE) { + struct ethtool_coalesce coalesce; + + ret = dev->ethtool_ops->get_per_queue_coalesce(dev, bit, tmp); + if (ret != 0) + goto roll_back; + + tmp++; + + if (copy_from_user(&coalesce, useraddr, sizeof(coalesce))) { + ret = -EFAULT; + goto roll_back; + } + + if (!ethtool_set_coalesce_supported(dev, &coalesce)) { + ret = -EOPNOTSUPP; + goto roll_back; + } + + ret = dev->ethtool_ops->set_per_queue_coalesce(dev, bit, &coalesce); + if (ret != 0) + goto roll_back; + + useraddr += sizeof(coalesce); + } + +roll_back: + if (ret != 0) { + tmp = backup; + for_each_set_bit(i, queue_mask, bit) { + dev->ethtool_ops->set_per_queue_coalesce(dev, i, tmp); + tmp++; + } + } + kfree(backup); + + return ret; +} + +static int noinline_for_stack ethtool_set_per_queue(struct net_device *dev, + void __user *useraddr, u32 sub_cmd) +{ + struct ethtool_per_queue_op per_queue_opt; + + if (copy_from_user(&per_queue_opt, useraddr, sizeof(per_queue_opt))) + return 
-EFAULT; + + if (per_queue_opt.sub_command != sub_cmd) + return -EINVAL; + + switch (per_queue_opt.sub_command) { + case ETHTOOL_GCOALESCE: + return ethtool_get_per_queue_coalesce(dev, useraddr, &per_queue_opt); + case ETHTOOL_SCOALESCE: + return ethtool_set_per_queue_coalesce(dev, useraddr, &per_queue_opt); + default: + return -EOPNOTSUPP; + } +} + +static int ethtool_phy_tunable_valid(const struct ethtool_tunable *tuna) +{ + switch (tuna->id) { + case ETHTOOL_PHY_DOWNSHIFT: + case ETHTOOL_PHY_FAST_LINK_DOWN: + if (tuna->len != sizeof(u8) || + tuna->type_id != ETHTOOL_TUNABLE_U8) + return -EINVAL; + break; + case ETHTOOL_PHY_EDPD: + if (tuna->len != sizeof(u16) || + tuna->type_id != ETHTOOL_TUNABLE_U16) + return -EINVAL; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int get_phy_tunable(struct net_device *dev, void __user *useraddr) +{ + struct phy_device *phydev = dev->phydev; + struct ethtool_tunable tuna; + bool phy_drv_tunable; + void *data; + int ret; + + phy_drv_tunable = phydev && phydev->drv && phydev->drv->get_tunable; + if (!phy_drv_tunable && !dev->ethtool_ops->get_phy_tunable) + return -EOPNOTSUPP; + if (copy_from_user(&tuna, useraddr, sizeof(tuna))) + return -EFAULT; + ret = ethtool_phy_tunable_valid(&tuna); + if (ret) + return ret; + data = kzalloc(tuna.len, GFP_USER); + if (!data) + return -ENOMEM; + if (phy_drv_tunable) { + mutex_lock(&phydev->lock); + ret = phydev->drv->get_tunable(phydev, &tuna, data); + mutex_unlock(&phydev->lock); + } else { + ret = dev->ethtool_ops->get_phy_tunable(dev, &tuna, data); + } + if (ret) + goto out; + useraddr += sizeof(tuna); + ret = -EFAULT; + if (copy_to_user(useraddr, data, tuna.len)) + goto out; + ret = 0; + +out: + kfree(data); + return ret; +} + +static int set_phy_tunable(struct net_device *dev, void __user *useraddr) +{ + struct phy_device *phydev = dev->phydev; + struct ethtool_tunable tuna; + bool phy_drv_tunable; + void *data; + int ret; + + phy_drv_tunable = phydev && phydev->drv && phydev->drv->get_tunable; + if (!phy_drv_tunable && !dev->ethtool_ops->set_phy_tunable) + return -EOPNOTSUPP; + if (copy_from_user(&tuna, useraddr, sizeof(tuna))) + return -EFAULT; + ret = ethtool_phy_tunable_valid(&tuna); + if (ret) + return ret; + useraddr += sizeof(tuna); + data = memdup_user(useraddr, tuna.len); + if (IS_ERR(data)) + return PTR_ERR(data); + if (phy_drv_tunable) { + mutex_lock(&phydev->lock); + ret = phydev->drv->set_tunable(phydev, &tuna, data); + mutex_unlock(&phydev->lock); + } else { + ret = dev->ethtool_ops->set_phy_tunable(dev, &tuna, data); + } + + kfree(data); + return ret; +} + +static int ethtool_get_fecparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_fecparam fecparam = { .cmd = ETHTOOL_GFECPARAM }; + int rc; + + if (!dev->ethtool_ops->get_fecparam) + return -EOPNOTSUPP; + + rc = dev->ethtool_ops->get_fecparam(dev, &fecparam); + if (rc) + return rc; + + if (WARN_ON_ONCE(fecparam.reserved)) + fecparam.reserved = 0; + + if (copy_to_user(useraddr, &fecparam, sizeof(fecparam))) + return -EFAULT; + return 0; +} + +static int ethtool_set_fecparam(struct net_device *dev, void __user *useraddr) +{ + struct ethtool_fecparam fecparam; + + if (!dev->ethtool_ops->set_fecparam) + return -EOPNOTSUPP; + + if (copy_from_user(&fecparam, useraddr, sizeof(fecparam))) + return -EFAULT; + + if (!fecparam.fec || fecparam.fec & ETHTOOL_FEC_NONE) + return -EINVAL; + + fecparam.active_fec = 0; + fecparam.reserved = 0; + + return dev->ethtool_ops->set_fecparam(dev, &fecparam); +} + +/* The main 
entry point in this file. Called from net/core/dev_ioctl.c */ + +static int +__dev_ethtool(struct net *net, struct ifreq *ifr, void __user *useraddr, + u32 ethcmd, struct ethtool_devlink_compat *devlink_state) +{ + struct net_device *dev; + u32 sub_cmd; + int rc; + netdev_features_t old_features; + + dev = __dev_get_by_name(net, ifr->ifr_name); + if (!dev) + return -ENODEV; + + if (ethcmd == ETHTOOL_PERQUEUE) { + if (copy_from_user(&sub_cmd, useraddr + sizeof(ethcmd), sizeof(sub_cmd))) + return -EFAULT; + } else { + sub_cmd = ethcmd; + } + /* Allow some commands to be done by anyone */ + switch (sub_cmd) { + case ETHTOOL_GSET: + case ETHTOOL_GDRVINFO: + case ETHTOOL_GMSGLVL: + case ETHTOOL_GLINK: + case ETHTOOL_GCOALESCE: + case ETHTOOL_GRINGPARAM: + case ETHTOOL_GPAUSEPARAM: + case ETHTOOL_GRXCSUM: + case ETHTOOL_GTXCSUM: + case ETHTOOL_GSG: + case ETHTOOL_GSSET_INFO: + case ETHTOOL_GSTRINGS: + case ETHTOOL_GSTATS: + case ETHTOOL_GPHYSTATS: + case ETHTOOL_GTSO: + case ETHTOOL_GPERMADDR: + case ETHTOOL_GUFO: + case ETHTOOL_GGSO: + case ETHTOOL_GGRO: + case ETHTOOL_GFLAGS: + case ETHTOOL_GPFLAGS: + case ETHTOOL_GRXFH: + case ETHTOOL_GRXRINGS: + case ETHTOOL_GRXCLSRLCNT: + case ETHTOOL_GRXCLSRULE: + case ETHTOOL_GRXCLSRLALL: + case ETHTOOL_GRXFHINDIR: + case ETHTOOL_GRSSH: + case ETHTOOL_GFEATURES: + case ETHTOOL_GCHANNELS: + case ETHTOOL_GET_TS_INFO: + case ETHTOOL_GEEE: + case ETHTOOL_GTUNABLE: + case ETHTOOL_PHY_GTUNABLE: + case ETHTOOL_GLINKSETTINGS: + case ETHTOOL_GFECPARAM: + break; + default: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + } + + if (dev->dev.parent) + pm_runtime_get_sync(dev->dev.parent); + + if (!netif_device_present(dev)) { + rc = -ENODEV; + goto out; + } + + if (dev->ethtool_ops->begin) { + rc = dev->ethtool_ops->begin(dev); + if (rc < 0) + goto out; + } + old_features = dev->features; + + switch (ethcmd) { + case ETHTOOL_GSET: + rc = ethtool_get_settings(dev, useraddr); + break; + case ETHTOOL_SSET: + rc = ethtool_set_settings(dev, useraddr); + break; + case ETHTOOL_GDRVINFO: + rc = ethtool_get_drvinfo(dev, devlink_state); + break; + case ETHTOOL_GREGS: + rc = ethtool_get_regs(dev, useraddr); + break; + case ETHTOOL_GWOL: + rc = ethtool_get_wol(dev, useraddr); + break; + case ETHTOOL_SWOL: + rc = ethtool_set_wol(dev, useraddr); + break; + case ETHTOOL_GMSGLVL: + rc = ethtool_get_value(dev, useraddr, ethcmd, + dev->ethtool_ops->get_msglevel); + break; + case ETHTOOL_SMSGLVL: + rc = ethtool_set_value_void(dev, useraddr, + dev->ethtool_ops->set_msglevel); + if (!rc) + ethtool_notify(dev, ETHTOOL_MSG_DEBUG_NTF, NULL); + break; + case ETHTOOL_GEEE: + rc = ethtool_get_eee(dev, useraddr); + break; + case ETHTOOL_SEEE: + rc = ethtool_set_eee(dev, useraddr); + break; + case ETHTOOL_NWAY_RST: + rc = ethtool_nway_reset(dev); + break; + case ETHTOOL_GLINK: + rc = ethtool_get_link(dev, useraddr); + break; + case ETHTOOL_GEEPROM: + rc = ethtool_get_eeprom(dev, useraddr); + break; + case ETHTOOL_SEEPROM: + rc = ethtool_set_eeprom(dev, useraddr); + break; + case ETHTOOL_GCOALESCE: + rc = ethtool_get_coalesce(dev, useraddr); + break; + case ETHTOOL_SCOALESCE: + rc = ethtool_set_coalesce(dev, useraddr); + break; + case ETHTOOL_GRINGPARAM: + rc = ethtool_get_ringparam(dev, useraddr); + break; + case ETHTOOL_SRINGPARAM: + rc = ethtool_set_ringparam(dev, useraddr); + break; + case ETHTOOL_GPAUSEPARAM: + rc = ethtool_get_pauseparam(dev, useraddr); + break; + case ETHTOOL_SPAUSEPARAM: + rc = ethtool_set_pauseparam(dev, useraddr); + break; + case ETHTOOL_TEST: + rc = 
ethtool_self_test(dev, useraddr); + break; + case ETHTOOL_GSTRINGS: + rc = ethtool_get_strings(dev, useraddr); + break; + case ETHTOOL_PHYS_ID: + rc = ethtool_phys_id(dev, useraddr); + break; + case ETHTOOL_GSTATS: + rc = ethtool_get_stats(dev, useraddr); + break; + case ETHTOOL_GPERMADDR: + rc = ethtool_get_perm_addr(dev, useraddr); + break; + case ETHTOOL_GFLAGS: + rc = ethtool_get_value(dev, useraddr, ethcmd, + __ethtool_get_flags); + break; + case ETHTOOL_SFLAGS: + rc = ethtool_set_value(dev, useraddr, __ethtool_set_flags); + break; + case ETHTOOL_GPFLAGS: + rc = ethtool_get_value(dev, useraddr, ethcmd, + dev->ethtool_ops->get_priv_flags); + if (!rc) + ethtool_notify(dev, ETHTOOL_MSG_PRIVFLAGS_NTF, NULL); + break; + case ETHTOOL_SPFLAGS: + rc = ethtool_set_value(dev, useraddr, + dev->ethtool_ops->set_priv_flags); + break; + case ETHTOOL_GRXFH: + case ETHTOOL_GRXRINGS: + case ETHTOOL_GRXCLSRLCNT: + case ETHTOOL_GRXCLSRULE: + case ETHTOOL_GRXCLSRLALL: + rc = ethtool_get_rxnfc(dev, ethcmd, useraddr); + break; + case ETHTOOL_SRXFH: + case ETHTOOL_SRXCLSRLDEL: + case ETHTOOL_SRXCLSRLINS: + rc = ethtool_set_rxnfc(dev, ethcmd, useraddr); + break; + case ETHTOOL_FLASHDEV: + rc = ethtool_flash_device(dev, devlink_state); + break; + case ETHTOOL_RESET: + rc = ethtool_reset(dev, useraddr); + break; + case ETHTOOL_GSSET_INFO: + rc = ethtool_get_sset_info(dev, useraddr); + break; + case ETHTOOL_GRXFHINDIR: + rc = ethtool_get_rxfh_indir(dev, useraddr); + break; + case ETHTOOL_SRXFHINDIR: + rc = ethtool_set_rxfh_indir(dev, useraddr); + break; + case ETHTOOL_GRSSH: + rc = ethtool_get_rxfh(dev, useraddr); + break; + case ETHTOOL_SRSSH: + rc = ethtool_set_rxfh(dev, useraddr); + break; + case ETHTOOL_GFEATURES: + rc = ethtool_get_features(dev, useraddr); + break; + case ETHTOOL_SFEATURES: + rc = ethtool_set_features(dev, useraddr); + break; + case ETHTOOL_GTXCSUM: + case ETHTOOL_GRXCSUM: + case ETHTOOL_GSG: + case ETHTOOL_GTSO: + case ETHTOOL_GGSO: + case ETHTOOL_GGRO: + rc = ethtool_get_one_feature(dev, useraddr, ethcmd); + break; + case ETHTOOL_STXCSUM: + case ETHTOOL_SRXCSUM: + case ETHTOOL_SSG: + case ETHTOOL_STSO: + case ETHTOOL_SGSO: + case ETHTOOL_SGRO: + rc = ethtool_set_one_feature(dev, useraddr, ethcmd); + break; + case ETHTOOL_GCHANNELS: + rc = ethtool_get_channels(dev, useraddr); + break; + case ETHTOOL_SCHANNELS: + rc = ethtool_set_channels(dev, useraddr); + break; + case ETHTOOL_SET_DUMP: + rc = ethtool_set_dump(dev, useraddr); + break; + case ETHTOOL_GET_DUMP_FLAG: + rc = ethtool_get_dump_flag(dev, useraddr); + break; + case ETHTOOL_GET_DUMP_DATA: + rc = ethtool_get_dump_data(dev, useraddr); + break; + case ETHTOOL_GET_TS_INFO: + rc = ethtool_get_ts_info(dev, useraddr); + break; + case ETHTOOL_GMODULEINFO: + rc = ethtool_get_module_info(dev, useraddr); + break; + case ETHTOOL_GMODULEEEPROM: + rc = ethtool_get_module_eeprom(dev, useraddr); + break; + case ETHTOOL_GTUNABLE: + rc = ethtool_get_tunable(dev, useraddr); + break; + case ETHTOOL_STUNABLE: + rc = ethtool_set_tunable(dev, useraddr); + break; + case ETHTOOL_GPHYSTATS: + rc = ethtool_get_phy_stats(dev, useraddr); + break; + case ETHTOOL_PERQUEUE: + rc = ethtool_set_per_queue(dev, useraddr, sub_cmd); + break; + case ETHTOOL_GLINKSETTINGS: + rc = ethtool_get_link_ksettings(dev, useraddr); + break; + case ETHTOOL_SLINKSETTINGS: + rc = ethtool_set_link_ksettings(dev, useraddr); + break; + case ETHTOOL_PHY_GTUNABLE: + rc = get_phy_tunable(dev, useraddr); + break; + case ETHTOOL_PHY_STUNABLE: + rc = set_phy_tunable(dev, useraddr); + break; 
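+ /* Forward Error Correction configuration; dispatched to
+  * ethtool_get_fecparam()/ethtool_set_fecparam() above.
+  */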
+ case ETHTOOL_GFECPARAM: + rc = ethtool_get_fecparam(dev, useraddr); + break; + case ETHTOOL_SFECPARAM: + rc = ethtool_set_fecparam(dev, useraddr); + break; + default: + rc = -EOPNOTSUPP; + } + + if (dev->ethtool_ops->complete) + dev->ethtool_ops->complete(dev); + + if (old_features != dev->features) + netdev_features_change(dev); +out: + if (dev->dev.parent) + pm_runtime_put(dev->dev.parent); + + return rc; +} + +int dev_ethtool(struct net *net, struct ifreq *ifr, void __user *useraddr) +{ + struct ethtool_devlink_compat *state; + u32 ethcmd; + int rc; + + if (copy_from_user(ðcmd, useraddr, sizeof(ethcmd))) + return -EFAULT; + + state = kzalloc(sizeof(*state), GFP_KERNEL); + if (!state) + return -ENOMEM; + + switch (ethcmd) { + case ETHTOOL_FLASHDEV: + if (copy_from_user(&state->efl, useraddr, sizeof(state->efl))) { + rc = -EFAULT; + goto exit_free; + } + state->efl.data[ETHTOOL_FLASH_MAX_FILENAME - 1] = 0; + break; + } + + rtnl_lock(); + rc = __dev_ethtool(net, ifr, useraddr, ethcmd, state); + rtnl_unlock(); + if (rc) + goto exit_free; + + switch (ethcmd) { + case ETHTOOL_FLASHDEV: + if (state->devlink) + rc = devlink_compat_flash_update(state->devlink, + state->efl.data); + break; + case ETHTOOL_GDRVINFO: + if (state->devlink) + devlink_compat_running_version(state->devlink, + state->info.fw_version, + sizeof(state->info.fw_version)); + if (copy_to_user(useraddr, &state->info, sizeof(state->info))) { + rc = -EFAULT; + goto exit_free; + } + break; + } + +exit_free: + if (state->devlink) + devlink_put(state->devlink); + kfree(state); + return rc; +} + +struct ethtool_rx_flow_key { + struct flow_dissector_key_basic basic; + union { + struct flow_dissector_key_ipv4_addrs ipv4; + struct flow_dissector_key_ipv6_addrs ipv6; + }; + struct flow_dissector_key_ports tp; + struct flow_dissector_key_ip ip; + struct flow_dissector_key_vlan vlan; + struct flow_dissector_key_eth_addrs eth_addrs; +} __aligned(BITS_PER_LONG / 8); /* Ensure that we can do comparisons as longs. */ + +struct ethtool_rx_flow_match { + struct flow_dissector dissector; + struct ethtool_rx_flow_key key; + struct ethtool_rx_flow_key mask; +}; + +struct ethtool_rx_flow_rule * +ethtool_rx_flow_rule_create(const struct ethtool_rx_flow_spec_input *input) +{ + const struct ethtool_rx_flow_spec *fs = input->fs; + struct ethtool_rx_flow_match *match; + struct ethtool_rx_flow_rule *flow; + struct flow_action_entry *act; + + flow = kzalloc(sizeof(struct ethtool_rx_flow_rule) + + sizeof(struct ethtool_rx_flow_match), GFP_KERNEL); + if (!flow) + return ERR_PTR(-ENOMEM); + + /* ethtool_rx supports only one single action per rule. 
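+ * The entry is filled in at the bottom of this function from
+ * fs->ring_cookie: RX_CLS_FLOW_DISC becomes FLOW_ACTION_DROP,
+ * RX_CLS_FLOW_WAKE becomes FLOW_ACTION_WAKE, and any other cookie
+ * becomes FLOW_ACTION_QUEUE with the VF and ring index decoded
+ * from it.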
*/ + flow->rule = flow_rule_alloc(1); + if (!flow->rule) { + kfree(flow); + return ERR_PTR(-ENOMEM); + } + + match = (struct ethtool_rx_flow_match *)flow->priv; + flow->rule->match.dissector = &match->dissector; + flow->rule->match.mask = &match->mask; + flow->rule->match.key = &match->key; + + match->mask.basic.n_proto = htons(0xffff); + + switch (fs->flow_type & ~(FLOW_EXT | FLOW_MAC_EXT | FLOW_RSS)) { + case ETHER_FLOW: { + const struct ethhdr *ether_spec, *ether_m_spec; + + ether_spec = &fs->h_u.ether_spec; + ether_m_spec = &fs->m_u.ether_spec; + + if (!is_zero_ether_addr(ether_m_spec->h_source)) { + ether_addr_copy(match->key.eth_addrs.src, + ether_spec->h_source); + ether_addr_copy(match->mask.eth_addrs.src, + ether_m_spec->h_source); + } + if (!is_zero_ether_addr(ether_m_spec->h_dest)) { + ether_addr_copy(match->key.eth_addrs.dst, + ether_spec->h_dest); + ether_addr_copy(match->mask.eth_addrs.dst, + ether_m_spec->h_dest); + } + if (ether_m_spec->h_proto) { + match->key.basic.n_proto = ether_spec->h_proto; + match->mask.basic.n_proto = ether_m_spec->h_proto; + } + } + break; + case TCP_V4_FLOW: + case UDP_V4_FLOW: { + const struct ethtool_tcpip4_spec *v4_spec, *v4_m_spec; + + match->key.basic.n_proto = htons(ETH_P_IP); + + v4_spec = &fs->h_u.tcp_ip4_spec; + v4_m_spec = &fs->m_u.tcp_ip4_spec; + + if (v4_m_spec->ip4src) { + match->key.ipv4.src = v4_spec->ip4src; + match->mask.ipv4.src = v4_m_spec->ip4src; + } + if (v4_m_spec->ip4dst) { + match->key.ipv4.dst = v4_spec->ip4dst; + match->mask.ipv4.dst = v4_m_spec->ip4dst; + } + if (v4_m_spec->ip4src || + v4_m_spec->ip4dst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IPV4_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_IPV4_ADDRS] = + offsetof(struct ethtool_rx_flow_key, ipv4); + } + if (v4_m_spec->psrc) { + match->key.tp.src = v4_spec->psrc; + match->mask.tp.src = v4_m_spec->psrc; + } + if (v4_m_spec->pdst) { + match->key.tp.dst = v4_spec->pdst; + match->mask.tp.dst = v4_m_spec->pdst; + } + if (v4_m_spec->psrc || + v4_m_spec->pdst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_PORTS); + match->dissector.offset[FLOW_DISSECTOR_KEY_PORTS] = + offsetof(struct ethtool_rx_flow_key, tp); + } + if (v4_m_spec->tos) { + match->key.ip.tos = v4_spec->tos; + match->mask.ip.tos = v4_m_spec->tos; + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IP); + match->dissector.offset[FLOW_DISSECTOR_KEY_IP] = + offsetof(struct ethtool_rx_flow_key, ip); + } + } + break; + case TCP_V6_FLOW: + case UDP_V6_FLOW: { + const struct ethtool_tcpip6_spec *v6_spec, *v6_m_spec; + + match->key.basic.n_proto = htons(ETH_P_IPV6); + + v6_spec = &fs->h_u.tcp_ip6_spec; + v6_m_spec = &fs->m_u.tcp_ip6_spec; + if (!ipv6_addr_any((struct in6_addr *)v6_m_spec->ip6src)) { + memcpy(&match->key.ipv6.src, v6_spec->ip6src, + sizeof(match->key.ipv6.src)); + memcpy(&match->mask.ipv6.src, v6_m_spec->ip6src, + sizeof(match->mask.ipv6.src)); + } + if (!ipv6_addr_any((struct in6_addr *)v6_m_spec->ip6dst)) { + memcpy(&match->key.ipv6.dst, v6_spec->ip6dst, + sizeof(match->key.ipv6.dst)); + memcpy(&match->mask.ipv6.dst, v6_m_spec->ip6dst, + sizeof(match->mask.ipv6.dst)); + } + if (!ipv6_addr_any((struct in6_addr *)v6_m_spec->ip6src) || + !ipv6_addr_any((struct in6_addr *)v6_m_spec->ip6dst)) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IPV6_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_IPV6_ADDRS] = + offsetof(struct ethtool_rx_flow_key, ipv6); + } + if (v6_m_spec->psrc) { + match->key.tp.src = v6_spec->psrc; + match->mask.tp.src = 
v6_m_spec->psrc; + } + if (v6_m_spec->pdst) { + match->key.tp.dst = v6_spec->pdst; + match->mask.tp.dst = v6_m_spec->pdst; + } + if (v6_m_spec->psrc || + v6_m_spec->pdst) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_PORTS); + match->dissector.offset[FLOW_DISSECTOR_KEY_PORTS] = + offsetof(struct ethtool_rx_flow_key, tp); + } + if (v6_m_spec->tclass) { + match->key.ip.tos = v6_spec->tclass; + match->mask.ip.tos = v6_m_spec->tclass; + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_IP); + match->dissector.offset[FLOW_DISSECTOR_KEY_IP] = + offsetof(struct ethtool_rx_flow_key, ip); + } + } + break; + default: + ethtool_rx_flow_rule_destroy(flow); + return ERR_PTR(-EINVAL); + } + + switch (fs->flow_type & ~(FLOW_EXT | FLOW_MAC_EXT | FLOW_RSS)) { + case TCP_V4_FLOW: + case TCP_V6_FLOW: + match->key.basic.ip_proto = IPPROTO_TCP; + match->mask.basic.ip_proto = 0xff; + break; + case UDP_V4_FLOW: + case UDP_V6_FLOW: + match->key.basic.ip_proto = IPPROTO_UDP; + match->mask.basic.ip_proto = 0xff; + break; + } + + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_BASIC); + match->dissector.offset[FLOW_DISSECTOR_KEY_BASIC] = + offsetof(struct ethtool_rx_flow_key, basic); + + if (fs->flow_type & FLOW_EXT) { + const struct ethtool_flow_ext *ext_h_spec = &fs->h_ext; + const struct ethtool_flow_ext *ext_m_spec = &fs->m_ext; + + if (ext_m_spec->vlan_etype) { + match->key.vlan.vlan_tpid = ext_h_spec->vlan_etype; + match->mask.vlan.vlan_tpid = ext_m_spec->vlan_etype; + } + + if (ext_m_spec->vlan_tci) { + match->key.vlan.vlan_id = + ntohs(ext_h_spec->vlan_tci) & 0x0fff; + match->mask.vlan.vlan_id = + ntohs(ext_m_spec->vlan_tci) & 0x0fff; + + match->key.vlan.vlan_dei = + !!(ext_h_spec->vlan_tci & htons(0x1000)); + match->mask.vlan.vlan_dei = + !!(ext_m_spec->vlan_tci & htons(0x1000)); + + match->key.vlan.vlan_priority = + (ntohs(ext_h_spec->vlan_tci) & 0xe000) >> 13; + match->mask.vlan.vlan_priority = + (ntohs(ext_m_spec->vlan_tci) & 0xe000) >> 13; + } + + if (ext_m_spec->vlan_etype || + ext_m_spec->vlan_tci) { + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_VLAN); + match->dissector.offset[FLOW_DISSECTOR_KEY_VLAN] = + offsetof(struct ethtool_rx_flow_key, vlan); + } + } + if (fs->flow_type & FLOW_MAC_EXT) { + const struct ethtool_flow_ext *ext_h_spec = &fs->h_ext; + const struct ethtool_flow_ext *ext_m_spec = &fs->m_ext; + + memcpy(match->key.eth_addrs.dst, ext_h_spec->h_dest, + ETH_ALEN); + memcpy(match->mask.eth_addrs.dst, ext_m_spec->h_dest, + ETH_ALEN); + + match->dissector.used_keys |= + BIT(FLOW_DISSECTOR_KEY_ETH_ADDRS); + match->dissector.offset[FLOW_DISSECTOR_KEY_ETH_ADDRS] = + offsetof(struct ethtool_rx_flow_key, eth_addrs); + } + + act = &flow->rule->action.entries[0]; + switch (fs->ring_cookie) { + case RX_CLS_FLOW_DISC: + act->id = FLOW_ACTION_DROP; + break; + case RX_CLS_FLOW_WAKE: + act->id = FLOW_ACTION_WAKE; + break; + default: + act->id = FLOW_ACTION_QUEUE; + if (fs->flow_type & FLOW_RSS) + act->queue.ctx = input->rss_ctx; + + act->queue.vf = ethtool_get_flow_spec_ring_vf(fs->ring_cookie); + act->queue.index = ethtool_get_flow_spec_ring(fs->ring_cookie); + break; + } + + return flow; +} +EXPORT_SYMBOL(ethtool_rx_flow_rule_create); + +void ethtool_rx_flow_rule_destroy(struct ethtool_rx_flow_rule *flow) +{ + kfree(flow->rule); + kfree(flow); +} +EXPORT_SYMBOL(ethtool_rx_flow_rule_destroy); diff --git a/net/ethtool/linkinfo.c b/net/ethtool/linkinfo.c new file mode 100644 index 000000000..efa0f7f48 --- /dev/null +++ b/net/ethtool/linkinfo.c @@ -0,0 +1,154 @@ +// 
SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" + +struct linkinfo_req_info { + struct ethnl_req_info base; +}; + +struct linkinfo_reply_data { + struct ethnl_reply_data base; + struct ethtool_link_ksettings ksettings; + struct ethtool_link_settings *lsettings; +}; + +#define LINKINFO_REPDATA(__reply_base) \ + container_of(__reply_base, struct linkinfo_reply_data, base) + +const struct nla_policy ethnl_linkinfo_get_policy[] = { + [ETHTOOL_A_LINKINFO_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int linkinfo_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct linkinfo_reply_data *data = LINKINFO_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + data->lsettings = &data->ksettings.base; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + ret = __ethtool_get_link_ksettings(dev, &data->ksettings); + if (ret < 0 && info) + GENL_SET_ERR_MSG(info, "failed to retrieve link settings"); + ethnl_ops_complete(dev); + + return ret; +} + +static int linkinfo_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + return nla_total_size(sizeof(u8)) /* LINKINFO_PORT */ + + nla_total_size(sizeof(u8)) /* LINKINFO_PHYADDR */ + + nla_total_size(sizeof(u8)) /* LINKINFO_TP_MDIX */ + + nla_total_size(sizeof(u8)) /* LINKINFO_TP_MDIX_CTRL */ + + nla_total_size(sizeof(u8)) /* LINKINFO_TRANSCEIVER */ + + 0; +} + +static int linkinfo_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct linkinfo_reply_data *data = LINKINFO_REPDATA(reply_base); + + if (nla_put_u8(skb, ETHTOOL_A_LINKINFO_PORT, data->lsettings->port) || + nla_put_u8(skb, ETHTOOL_A_LINKINFO_PHYADDR, + data->lsettings->phy_address) || + nla_put_u8(skb, ETHTOOL_A_LINKINFO_TP_MDIX, + data->lsettings->eth_tp_mdix) || + nla_put_u8(skb, ETHTOOL_A_LINKINFO_TP_MDIX_CTRL, + data->lsettings->eth_tp_mdix_ctrl) || + nla_put_u8(skb, ETHTOOL_A_LINKINFO_TRANSCEIVER, + data->lsettings->transceiver)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_linkinfo_request_ops = { + .request_cmd = ETHTOOL_MSG_LINKINFO_GET, + .reply_cmd = ETHTOOL_MSG_LINKINFO_GET_REPLY, + .hdr_attr = ETHTOOL_A_LINKINFO_HEADER, + .req_info_size = sizeof(struct linkinfo_req_info), + .reply_data_size = sizeof(struct linkinfo_reply_data), + + .prepare_data = linkinfo_prepare_data, + .reply_size = linkinfo_reply_size, + .fill_reply = linkinfo_fill_reply, +}; + +/* LINKINFO_SET */ + +const struct nla_policy ethnl_linkinfo_set_policy[] = { + [ETHTOOL_A_LINKINFO_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_LINKINFO_PORT] = { .type = NLA_U8 }, + [ETHTOOL_A_LINKINFO_PHYADDR] = { .type = NLA_U8 }, + [ETHTOOL_A_LINKINFO_TP_MDIX_CTRL] = { .type = NLA_U8 }, +}; + +int ethnl_set_linkinfo(struct sk_buff *skb, struct genl_info *info) +{ + struct ethtool_link_ksettings ksettings = {}; + struct ethtool_link_settings *lsettings; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_LINKINFO_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ret = -EOPNOTSUPP; + if (!dev->ethtool_ops->get_link_ksettings || + !dev->ethtool_ops->set_link_ksettings) + goto out_dev; + + rtnl_lock(); + ret 
= ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = __ethtool_get_link_ksettings(dev, &ksettings); + if (ret < 0) { + GENL_SET_ERR_MSG(info, "failed to retrieve link settings"); + goto out_ops; + } + lsettings = &ksettings.base; + + ethnl_update_u8(&lsettings->port, tb[ETHTOOL_A_LINKINFO_PORT], &mod); + ethnl_update_u8(&lsettings->phy_address, tb[ETHTOOL_A_LINKINFO_PHYADDR], + &mod); + ethnl_update_u8(&lsettings->eth_tp_mdix_ctrl, + tb[ETHTOOL_A_LINKINFO_TP_MDIX_CTRL], &mod); + ret = 0; + if (!mod) + goto out_ops; + + ret = dev->ethtool_ops->set_link_ksettings(dev, &ksettings); + if (ret < 0) + GENL_SET_ERR_MSG(info, "link settings update failed"); + else + ethtool_notify(dev, ETHTOOL_MSG_LINKINFO_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/linkmodes.c b/net/ethtool/linkmodes.c new file mode 100644 index 000000000..2d91f2a8c --- /dev/null +++ b/net/ethtool/linkmodes.c @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +/* LINKMODES_GET */ + +struct linkmodes_req_info { + struct ethnl_req_info base; +}; + +struct linkmodes_reply_data { + struct ethnl_reply_data base; + struct ethtool_link_ksettings ksettings; + struct ethtool_link_settings *lsettings; + bool peer_empty; +}; + +#define LINKMODES_REPDATA(__reply_base) \ + container_of(__reply_base, struct linkmodes_reply_data, base) + +const struct nla_policy ethnl_linkmodes_get_policy[] = { + [ETHTOOL_A_LINKMODES_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int linkmodes_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct linkmodes_reply_data *data = LINKMODES_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + data->lsettings = &data->ksettings.base; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + ret = __ethtool_get_link_ksettings(dev, &data->ksettings); + if (ret < 0 && info) { + GENL_SET_ERR_MSG(info, "failed to retrieve link settings"); + goto out; + } + + if (!dev->ethtool_ops->cap_link_lanes_supported) + data->ksettings.lanes = 0; + + data->peer_empty = + bitmap_empty(data->ksettings.link_modes.lp_advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); + +out: + ethnl_ops_complete(dev); + return ret; +} + +static int linkmodes_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct linkmodes_reply_data *data = LINKMODES_REPDATA(reply_base); + const struct ethtool_link_ksettings *ksettings = &data->ksettings; + const struct ethtool_link_settings *lsettings = &ksettings->base; + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + int len, ret; + + len = nla_total_size(sizeof(u8)) /* LINKMODES_AUTONEG */ + + nla_total_size(sizeof(u32)) /* LINKMODES_SPEED */ + + nla_total_size(sizeof(u32)) /* LINKMODES_LANES */ + + nla_total_size(sizeof(u8)) /* LINKMODES_DUPLEX */ + + nla_total_size(sizeof(u8)) /* LINKMODES_RATE_MATCHING */ + + 0; + ret = ethnl_bitset_size(ksettings->link_modes.advertising, + ksettings->link_modes.supported, + __ETHTOOL_LINK_MODE_MASK_NBITS, + link_mode_names, compact); + if (ret < 0) + return ret; + len += ret; + if (!data->peer_empty) { + ret = ethnl_bitset_size(ksettings->link_modes.lp_advertising, + NULL, __ETHTOOL_LINK_MODE_MASK_NBITS, + link_mode_names, compact); + if (ret < 0) + return ret; + len += 
ret; + } + + if (lsettings->master_slave_cfg != MASTER_SLAVE_CFG_UNSUPPORTED) + len += nla_total_size(sizeof(u8)); + + if (lsettings->master_slave_state != MASTER_SLAVE_STATE_UNSUPPORTED) + len += nla_total_size(sizeof(u8)); + + return len; +} + +static int linkmodes_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct linkmodes_reply_data *data = LINKMODES_REPDATA(reply_base); + const struct ethtool_link_ksettings *ksettings = &data->ksettings; + const struct ethtool_link_settings *lsettings = &ksettings->base; + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + int ret; + + if (nla_put_u8(skb, ETHTOOL_A_LINKMODES_AUTONEG, lsettings->autoneg)) + return -EMSGSIZE; + + ret = ethnl_put_bitset(skb, ETHTOOL_A_LINKMODES_OURS, + ksettings->link_modes.advertising, + ksettings->link_modes.supported, + __ETHTOOL_LINK_MODE_MASK_NBITS, link_mode_names, + compact); + if (ret < 0) + return -EMSGSIZE; + if (!data->peer_empty) { + ret = ethnl_put_bitset(skb, ETHTOOL_A_LINKMODES_PEER, + ksettings->link_modes.lp_advertising, + NULL, __ETHTOOL_LINK_MODE_MASK_NBITS, + link_mode_names, compact); + if (ret < 0) + return -EMSGSIZE; + } + + if (nla_put_u32(skb, ETHTOOL_A_LINKMODES_SPEED, lsettings->speed) || + nla_put_u8(skb, ETHTOOL_A_LINKMODES_DUPLEX, lsettings->duplex)) + return -EMSGSIZE; + + if (ksettings->lanes && + nla_put_u32(skb, ETHTOOL_A_LINKMODES_LANES, ksettings->lanes)) + return -EMSGSIZE; + + if (lsettings->master_slave_cfg != MASTER_SLAVE_CFG_UNSUPPORTED && + nla_put_u8(skb, ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG, + lsettings->master_slave_cfg)) + return -EMSGSIZE; + + if (lsettings->master_slave_state != MASTER_SLAVE_STATE_UNSUPPORTED && + nla_put_u8(skb, ETHTOOL_A_LINKMODES_MASTER_SLAVE_STATE, + lsettings->master_slave_state)) + return -EMSGSIZE; + + if (nla_put_u8(skb, ETHTOOL_A_LINKMODES_RATE_MATCHING, + lsettings->rate_matching)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_linkmodes_request_ops = { + .request_cmd = ETHTOOL_MSG_LINKMODES_GET, + .reply_cmd = ETHTOOL_MSG_LINKMODES_GET_REPLY, + .hdr_attr = ETHTOOL_A_LINKMODES_HEADER, + .req_info_size = sizeof(struct linkmodes_req_info), + .reply_data_size = sizeof(struct linkmodes_reply_data), + + .prepare_data = linkmodes_prepare_data, + .reply_size = linkmodes_reply_size, + .fill_reply = linkmodes_fill_reply, +}; + +/* LINKMODES_SET */ + +const struct nla_policy ethnl_linkmodes_set_policy[] = { + [ETHTOOL_A_LINKMODES_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_LINKMODES_AUTONEG] = { .type = NLA_U8 }, + [ETHTOOL_A_LINKMODES_OURS] = { .type = NLA_NESTED }, + [ETHTOOL_A_LINKMODES_SPEED] = { .type = NLA_U32 }, + [ETHTOOL_A_LINKMODES_DUPLEX] = { .type = NLA_U8 }, + [ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG] = { .type = NLA_U8 }, + [ETHTOOL_A_LINKMODES_LANES] = NLA_POLICY_RANGE(NLA_U32, 1, 8), +}; + +/* Set advertised link modes to all supported modes matching requested speed, + * lanes and duplex values. Called when autonegotiation is on, speed, lanes or + * duplex is requested but no link mode change. This is done in userspace with + * ioctl() interface, move it into kernel for netlink. + * Returns true if advertised modes bitmap was modified. 
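In practice this path is exercised by a request such as `ethtool -s eth0 speed 1000 duplex full` with autonegotiation enabled and no explicit advertising bitmap: the kernel recomputes the advertised set itself instead of leaving that to the ethtool binary as the old ioctl path did. A standalone userspace sketch of the same selection rule follows (made-up mode table, not the kernel's link_mode_params; illustration only):

#include <stdbool.h>
#include <stdio.h>

struct mode { const char *name; int speed; int duplex; };	/* duplex: 0 = half, 1 = full */

static const struct mode supported[] = {
	{ "100baseT/Half",  100,  0 },
	{ "100baseT/Full",  100,  1 },
	{ "1000baseT/Full", 1000, 1 },
};

int main(void)
{
	/* e.g. "ethtool -s eth0 speed 1000 duplex full autoneg on" */
	int req_speed = 1000, req_duplex = 1;

	for (size_t i = 0; i < sizeof(supported) / sizeof(supported[0]); i++) {
		/* keep a supported mode only if it matches every requested parameter */
		bool adv = supported[i].speed == req_speed &&
			   supported[i].duplex == req_duplex;

		printf("%-15s -> %s\n", supported[i].name,
		       adv ? "advertise" : "clear");
	}
	return 0;
}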
+ */ +static bool ethnl_auto_linkmodes(struct ethtool_link_ksettings *ksettings, + bool req_speed, bool req_lanes, bool req_duplex) +{ + unsigned long *advertising = ksettings->link_modes.advertising; + unsigned long *supported = ksettings->link_modes.supported; + DECLARE_BITMAP(old_adv, __ETHTOOL_LINK_MODE_MASK_NBITS); + unsigned int i; + + bitmap_copy(old_adv, advertising, __ETHTOOL_LINK_MODE_MASK_NBITS); + + for (i = 0; i < __ETHTOOL_LINK_MODE_MASK_NBITS; i++) { + const struct link_mode_info *info = &link_mode_params[i]; + + if (info->speed == SPEED_UNKNOWN) + continue; + if (test_bit(i, supported) && + (!req_speed || info->speed == ksettings->base.speed) && + (!req_lanes || info->lanes == ksettings->lanes) && + (!req_duplex || info->duplex == ksettings->base.duplex)) + set_bit(i, advertising); + else + clear_bit(i, advertising); + } + + return !bitmap_equal(old_adv, advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS); +} + +static bool ethnl_validate_master_slave_cfg(u8 cfg) +{ + switch (cfg) { + case MASTER_SLAVE_CFG_MASTER_PREFERRED: + case MASTER_SLAVE_CFG_SLAVE_PREFERRED: + case MASTER_SLAVE_CFG_MASTER_FORCE: + case MASTER_SLAVE_CFG_SLAVE_FORCE: + return true; + } + + return false; +} + +static int ethnl_check_linkmodes(struct genl_info *info, struct nlattr **tb) +{ + const struct nlattr *master_slave_cfg, *lanes_cfg; + + master_slave_cfg = tb[ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG]; + if (master_slave_cfg && + !ethnl_validate_master_slave_cfg(nla_get_u8(master_slave_cfg))) { + NL_SET_ERR_MSG_ATTR(info->extack, master_slave_cfg, + "master/slave value is invalid"); + return -EOPNOTSUPP; + } + + lanes_cfg = tb[ETHTOOL_A_LINKMODES_LANES]; + if (lanes_cfg && !is_power_of_2(nla_get_u32(lanes_cfg))) { + NL_SET_ERR_MSG_ATTR(info->extack, lanes_cfg, + "lanes value is invalid"); + return -EINVAL; + } + + return 0; +} + +static int ethnl_update_linkmodes(struct genl_info *info, struct nlattr **tb, + struct ethtool_link_ksettings *ksettings, + bool *mod, const struct net_device *dev) +{ + struct ethtool_link_settings *lsettings = &ksettings->base; + bool req_speed, req_lanes, req_duplex; + const struct nlattr *master_slave_cfg, *lanes_cfg; + int ret; + + master_slave_cfg = tb[ETHTOOL_A_LINKMODES_MASTER_SLAVE_CFG]; + if (master_slave_cfg) { + if (lsettings->master_slave_cfg == MASTER_SLAVE_CFG_UNSUPPORTED) { + NL_SET_ERR_MSG_ATTR(info->extack, master_slave_cfg, + "master/slave configuration not supported by device"); + return -EOPNOTSUPP; + } + } + + *mod = false; + req_speed = tb[ETHTOOL_A_LINKMODES_SPEED]; + req_lanes = tb[ETHTOOL_A_LINKMODES_LANES]; + req_duplex = tb[ETHTOOL_A_LINKMODES_DUPLEX]; + + ethnl_update_u8(&lsettings->autoneg, tb[ETHTOOL_A_LINKMODES_AUTONEG], + mod); + + lanes_cfg = tb[ETHTOOL_A_LINKMODES_LANES]; + if (lanes_cfg) { + /* If autoneg is off and lanes parameter is not supported by the + * driver, return an error. + */ + if (!lsettings->autoneg && + !dev->ethtool_ops->cap_link_lanes_supported) { + NL_SET_ERR_MSG_ATTR(info->extack, lanes_cfg, + "lanes configuration not supported by device"); + return -EOPNOTSUPP; + } + } else if (!lsettings->autoneg && ksettings->lanes) { + /* If autoneg is off and lanes parameter is not passed from user but + * it was defined previously then set the lanes parameter to 0. 
+ */ + ksettings->lanes = 0; + *mod = true; + } + + ret = ethnl_update_bitset(ksettings->link_modes.advertising, + __ETHTOOL_LINK_MODE_MASK_NBITS, + tb[ETHTOOL_A_LINKMODES_OURS], link_mode_names, + info->extack, mod); + if (ret < 0) + return ret; + ethnl_update_u32(&lsettings->speed, tb[ETHTOOL_A_LINKMODES_SPEED], + mod); + ethnl_update_u32(&ksettings->lanes, lanes_cfg, mod); + ethnl_update_u8(&lsettings->duplex, tb[ETHTOOL_A_LINKMODES_DUPLEX], + mod); + ethnl_update_u8(&lsettings->master_slave_cfg, master_slave_cfg, mod); + + if (!tb[ETHTOOL_A_LINKMODES_OURS] && lsettings->autoneg && + (req_speed || req_lanes || req_duplex) && + ethnl_auto_linkmodes(ksettings, req_speed, req_lanes, req_duplex)) + *mod = true; + + return 0; +} + +int ethnl_set_linkmodes(struct sk_buff *skb, struct genl_info *info) +{ + struct ethtool_link_ksettings ksettings = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_check_linkmodes(info, tb); + if (ret < 0) + return ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_LINKMODES_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ret = -EOPNOTSUPP; + if (!dev->ethtool_ops->get_link_ksettings || + !dev->ethtool_ops->set_link_ksettings) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = __ethtool_get_link_ksettings(dev, &ksettings); + if (ret < 0) { + GENL_SET_ERR_MSG(info, "failed to retrieve link settings"); + goto out_ops; + } + + ret = ethnl_update_linkmodes(info, tb, &ksettings, &mod, dev); + if (ret < 0) + goto out_ops; + + if (mod) { + ret = dev->ethtool_ops->set_link_ksettings(dev, &ksettings); + if (ret < 0) + GENL_SET_ERR_MSG(info, "link settings update failed"); + else + ethtool_notify(dev, ETHTOOL_MSG_LINKMODES_NTF, NULL); + } + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/linkstate.c b/net/ethtool/linkstate.c new file mode 100644 index 000000000..fb676f349 --- /dev/null +++ b/net/ethtool/linkstate.c @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include + +struct linkstate_req_info { + struct ethnl_req_info base; +}; + +struct linkstate_reply_data { + struct ethnl_reply_data base; + int link; + int sqi; + int sqi_max; + bool link_ext_state_provided; + struct ethtool_link_ext_state_info ethtool_link_ext_state_info; +}; + +#define LINKSTATE_REPDATA(__reply_base) \ + container_of(__reply_base, struct linkstate_reply_data, base) + +const struct nla_policy ethnl_linkstate_get_policy[] = { + [ETHTOOL_A_LINKSTATE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int linkstate_get_sqi(struct net_device *dev) +{ + struct phy_device *phydev = dev->phydev; + int ret; + + if (!phydev) + return -EOPNOTSUPP; + + mutex_lock(&phydev->lock); + if (!phydev->drv || !phydev->drv->get_sqi) + ret = -EOPNOTSUPP; + else + ret = phydev->drv->get_sqi(phydev); + mutex_unlock(&phydev->lock); + + return ret; +} + +static int linkstate_get_sqi_max(struct net_device *dev) +{ + struct phy_device *phydev = dev->phydev; + int ret; + + if (!phydev) + return -EOPNOTSUPP; + + mutex_lock(&phydev->lock); + if (!phydev->drv || !phydev->drv->get_sqi_max) + ret = -EOPNOTSUPP; + else + ret = phydev->drv->get_sqi_max(phydev); + mutex_unlock(&phydev->lock); + + return ret; +}; + +static 
int linkstate_get_link_ext_state(struct net_device *dev, + struct linkstate_reply_data *data) +{ + int err; + + if (!dev->ethtool_ops->get_link_ext_state) + return -EOPNOTSUPP; + + err = dev->ethtool_ops->get_link_ext_state(dev, &data->ethtool_link_ext_state_info); + if (err) + return err; + + data->link_ext_state_provided = true; + + return 0; +} + +static int linkstate_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct linkstate_reply_data *data = LINKSTATE_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + data->link = __ethtool_get_link(dev); + + ret = linkstate_get_sqi(dev); + if (ret < 0 && ret != -EOPNOTSUPP) + goto out; + data->sqi = ret; + + ret = linkstate_get_sqi_max(dev); + if (ret < 0 && ret != -EOPNOTSUPP) + goto out; + data->sqi_max = ret; + + if (dev->flags & IFF_UP) { + ret = linkstate_get_link_ext_state(dev, data); + if (ret < 0 && ret != -EOPNOTSUPP && ret != -ENODATA) + goto out; + } + + ret = 0; +out: + ethnl_ops_complete(dev); + return ret; +} + +static int linkstate_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + struct linkstate_reply_data *data = LINKSTATE_REPDATA(reply_base); + int len; + + len = nla_total_size(sizeof(u8)) /* LINKSTATE_LINK */ + + 0; + + if (data->sqi != -EOPNOTSUPP) + len += nla_total_size(sizeof(u32)); + + if (data->sqi_max != -EOPNOTSUPP) + len += nla_total_size(sizeof(u32)); + + if (data->link_ext_state_provided) + len += nla_total_size(sizeof(u8)); /* LINKSTATE_EXT_STATE */ + + if (data->ethtool_link_ext_state_info.__link_ext_substate) + len += nla_total_size(sizeof(u8)); /* LINKSTATE_EXT_SUBSTATE */ + + return len; +} + +static int linkstate_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + struct linkstate_reply_data *data = LINKSTATE_REPDATA(reply_base); + + if (data->link >= 0 && + nla_put_u8(skb, ETHTOOL_A_LINKSTATE_LINK, !!data->link)) + return -EMSGSIZE; + + if (data->sqi != -EOPNOTSUPP && + nla_put_u32(skb, ETHTOOL_A_LINKSTATE_SQI, data->sqi)) + return -EMSGSIZE; + + if (data->sqi_max != -EOPNOTSUPP && + nla_put_u32(skb, ETHTOOL_A_LINKSTATE_SQI_MAX, data->sqi_max)) + return -EMSGSIZE; + + if (data->link_ext_state_provided) { + if (nla_put_u8(skb, ETHTOOL_A_LINKSTATE_EXT_STATE, + data->ethtool_link_ext_state_info.link_ext_state)) + return -EMSGSIZE; + + if (data->ethtool_link_ext_state_info.__link_ext_substate && + nla_put_u8(skb, ETHTOOL_A_LINKSTATE_EXT_SUBSTATE, + data->ethtool_link_ext_state_info.__link_ext_substate)) + return -EMSGSIZE; + } + + return 0; +} + +const struct ethnl_request_ops ethnl_linkstate_request_ops = { + .request_cmd = ETHTOOL_MSG_LINKSTATE_GET, + .reply_cmd = ETHTOOL_MSG_LINKSTATE_GET_REPLY, + .hdr_attr = ETHTOOL_A_LINKSTATE_HEADER, + .req_info_size = sizeof(struct linkstate_req_info), + .reply_data_size = sizeof(struct linkstate_reply_data), + + .prepare_data = linkstate_prepare_data, + .reply_size = linkstate_reply_size, + .fill_reply = linkstate_fill_reply, +}; diff --git a/net/ethtool/module.c b/net/ethtool/module.c new file mode 100644 index 000000000..898ed436b --- /dev/null +++ b/net/ethtool/module.c @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct module_req_info { + struct ethnl_req_info base; +}; + +struct 
module_reply_data { + struct ethnl_reply_data base; + struct ethtool_module_power_mode_params power; +}; + +#define MODULE_REPDATA(__reply_base) \ + container_of(__reply_base, struct module_reply_data, base) + +/* MODULE_GET */ + +const struct nla_policy ethnl_module_get_policy[ETHTOOL_A_MODULE_HEADER + 1] = { + [ETHTOOL_A_MODULE_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int module_get_power_mode(struct net_device *dev, + struct module_reply_data *data, + struct netlink_ext_ack *extack) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + + if (!ops->get_module_power_mode) + return 0; + + return ops->get_module_power_mode(dev, &data->power, extack); +} + +static int module_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct module_reply_data *data = MODULE_REPDATA(reply_base); + struct netlink_ext_ack *extack = info ? info->extack : NULL; + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + ret = module_get_power_mode(dev, data, extack); + if (ret < 0) + goto out_complete; + +out_complete: + ethnl_ops_complete(dev); + return ret; +} + +static int module_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + struct module_reply_data *data = MODULE_REPDATA(reply_base); + int len = 0; + + if (data->power.policy) + len += nla_total_size(sizeof(u8)); /* _MODULE_POWER_MODE_POLICY */ + + if (data->power.mode) + len += nla_total_size(sizeof(u8)); /* _MODULE_POWER_MODE */ + + return len; +} + +static int module_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct module_reply_data *data = MODULE_REPDATA(reply_base); + + if (data->power.policy && + nla_put_u8(skb, ETHTOOL_A_MODULE_POWER_MODE_POLICY, + data->power.policy)) + return -EMSGSIZE; + + if (data->power.mode && + nla_put_u8(skb, ETHTOOL_A_MODULE_POWER_MODE, data->power.mode)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_module_request_ops = { + .request_cmd = ETHTOOL_MSG_MODULE_GET, + .reply_cmd = ETHTOOL_MSG_MODULE_GET_REPLY, + .hdr_attr = ETHTOOL_A_MODULE_HEADER, + .req_info_size = sizeof(struct module_req_info), + .reply_data_size = sizeof(struct module_reply_data), + + .prepare_data = module_prepare_data, + .reply_size = module_reply_size, + .fill_reply = module_fill_reply, +}; + +/* MODULE_SET */ + +const struct nla_policy ethnl_module_set_policy[ETHTOOL_A_MODULE_POWER_MODE_POLICY + 1] = { + [ETHTOOL_A_MODULE_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_MODULE_POWER_MODE_POLICY] = + NLA_POLICY_RANGE(NLA_U8, ETHTOOL_MODULE_POWER_MODE_POLICY_HIGH, + ETHTOOL_MODULE_POWER_MODE_POLICY_AUTO), +}; + +static int module_set_power_mode(struct net_device *dev, struct nlattr **tb, + bool *p_mod, struct netlink_ext_ack *extack) +{ + struct ethtool_module_power_mode_params power = {}; + struct ethtool_module_power_mode_params power_new; + const struct ethtool_ops *ops = dev->ethtool_ops; + int ret; + + if (!tb[ETHTOOL_A_MODULE_POWER_MODE_POLICY]) + return 0; + + if (!ops->get_module_power_mode || !ops->set_module_power_mode) { + NL_SET_ERR_MSG_ATTR(extack, + tb[ETHTOOL_A_MODULE_POWER_MODE_POLICY], + "Setting power mode policy is not supported by this device"); + return -EOPNOTSUPP; + } + + power_new.policy = nla_get_u8(tb[ETHTOOL_A_MODULE_POWER_MODE_POLICY]); + ret = ops->get_module_power_mode(dev, &power, 
extack); + if (ret < 0) + return ret; + + if (power_new.policy == power.policy) + return 0; + *p_mod = true; + + return ops->set_module_power_mode(dev, &power_new, extack); +} + +int ethnl_set_module(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, tb[ETHTOOL_A_MODULE_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = module_set_power_mode(dev, tb, &mod, info->extack); + if (ret < 0) + goto out_ops; + + if (!mod) + goto out_ops; + + ethtool_notify(dev, ETHTOOL_MSG_MODULE_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/netlink.c b/net/ethtool/netlink.c new file mode 100644 index 000000000..fc4ccecf9 --- /dev/null +++ b/net/ethtool/netlink.c @@ -0,0 +1,1077 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include "netlink.h" + +static struct genl_family ethtool_genl_family; + +static bool ethnl_ok __read_mostly; +static u32 ethnl_bcast_seq; + +#define ETHTOOL_FLAGS_BASIC (ETHTOOL_FLAG_COMPACT_BITSETS | \ + ETHTOOL_FLAG_OMIT_REPLY) +#define ETHTOOL_FLAGS_STATS (ETHTOOL_FLAGS_BASIC | ETHTOOL_FLAG_STATS) + +const struct nla_policy ethnl_header_policy[] = { + [ETHTOOL_A_HEADER_DEV_INDEX] = { .type = NLA_U32 }, + [ETHTOOL_A_HEADER_DEV_NAME] = { .type = NLA_NUL_STRING, + .len = ALTIFNAMSIZ - 1 }, + [ETHTOOL_A_HEADER_FLAGS] = NLA_POLICY_MASK(NLA_U32, + ETHTOOL_FLAGS_BASIC), +}; + +const struct nla_policy ethnl_header_policy_stats[] = { + [ETHTOOL_A_HEADER_DEV_INDEX] = { .type = NLA_U32 }, + [ETHTOOL_A_HEADER_DEV_NAME] = { .type = NLA_NUL_STRING, + .len = ALTIFNAMSIZ - 1 }, + [ETHTOOL_A_HEADER_FLAGS] = NLA_POLICY_MASK(NLA_U32, + ETHTOOL_FLAGS_STATS), +}; + +int ethnl_ops_begin(struct net_device *dev) +{ + int ret; + + if (!dev) + return -ENODEV; + + if (dev->dev.parent) + pm_runtime_get_sync(dev->dev.parent); + + if (!netif_device_present(dev) || + dev->reg_state == NETREG_UNREGISTERING) { + ret = -ENODEV; + goto err; + } + + if (dev->ethtool_ops->begin) { + ret = dev->ethtool_ops->begin(dev); + if (ret) + goto err; + } + + return 0; +err: + if (dev->dev.parent) + pm_runtime_put(dev->dev.parent); + + return ret; +} + +void ethnl_ops_complete(struct net_device *dev) +{ + if (dev->ethtool_ops->complete) + dev->ethtool_ops->complete(dev); + + if (dev->dev.parent) + pm_runtime_put(dev->dev.parent); +} + +/** + * ethnl_parse_header_dev_get() - parse request header + * @req_info: structure to put results into + * @header: nest attribute with request header + * @net: request netns + * @extack: netlink extack for error reporting + * @require_dev: fail if no device identified in header + * + * Parse request header in nested attribute @nest and puts results into + * the structure pointed to by @req_info. Extack from @info is used for error + * reporting. If req_info->dev is not null on return, reference to it has + * been taken. If error is returned, *req_info is null initialized and no + * reference is held. 
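Because of the reference-counting contract just described, every handler in this patch pairs this call with ethnl_parse_header_dev_put(). A caller-side sketch of that pattern (modelled on ethnl_set_linkinfo() earlier in this patch; ETHTOOL_A_FOO_HEADER is a hypothetical attribute):

static int foo_doit(struct sk_buff *skb, struct genl_info *info)
{
	struct ethnl_req_info req_info = {};
	int ret;

	ret = ethnl_parse_header_dev_get(&req_info,
					 info->attrs[ETHTOOL_A_FOO_HEADER],
					 genl_info_net(info), info->extack,
					 true /* require_dev */);
	if (ret < 0)
		return ret;		/* no device reference is held on error */

	/* ... operate on req_info.dev under rtnl_lock()/ethnl_ops_begin() ... */

	ethnl_parse_header_dev_put(&req_info);	/* drops the dev reference */
	return 0;
}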
+ * + * Return: 0 on success or negative error code + */ +int ethnl_parse_header_dev_get(struct ethnl_req_info *req_info, + const struct nlattr *header, struct net *net, + struct netlink_ext_ack *extack, bool require_dev) +{ + struct nlattr *tb[ARRAY_SIZE(ethnl_header_policy)]; + const struct nlattr *devname_attr; + struct net_device *dev = NULL; + u32 flags = 0; + int ret; + + if (!header) { + NL_SET_ERR_MSG(extack, "request header missing"); + return -EINVAL; + } + /* No validation here, command policy should have a nested policy set + * for the header, therefore validation should have already been done. + */ + ret = nla_parse_nested(tb, ARRAY_SIZE(ethnl_header_policy) - 1, header, + NULL, extack); + if (ret < 0) + return ret; + if (tb[ETHTOOL_A_HEADER_FLAGS]) + flags = nla_get_u32(tb[ETHTOOL_A_HEADER_FLAGS]); + + devname_attr = tb[ETHTOOL_A_HEADER_DEV_NAME]; + if (tb[ETHTOOL_A_HEADER_DEV_INDEX]) { + u32 ifindex = nla_get_u32(tb[ETHTOOL_A_HEADER_DEV_INDEX]); + + dev = dev_get_by_index(net, ifindex); + if (!dev) { + NL_SET_ERR_MSG_ATTR(extack, + tb[ETHTOOL_A_HEADER_DEV_INDEX], + "no device matches ifindex"); + return -ENODEV; + } + /* if both ifindex and ifname are passed, they must match */ + if (devname_attr && + strncmp(dev->name, nla_data(devname_attr), IFNAMSIZ)) { + dev_put(dev); + NL_SET_ERR_MSG_ATTR(extack, header, + "ifindex and name do not match"); + return -ENODEV; + } + } else if (devname_attr) { + dev = dev_get_by_name(net, nla_data(devname_attr)); + if (!dev) { + NL_SET_ERR_MSG_ATTR(extack, devname_attr, + "no device matches name"); + return -ENODEV; + } + } else if (require_dev) { + NL_SET_ERR_MSG_ATTR(extack, header, + "neither ifindex nor name specified"); + return -EINVAL; + } + + req_info->dev = dev; + if (dev) + netdev_tracker_alloc(dev, &req_info->dev_tracker, GFP_KERNEL); + req_info->flags = flags; + return 0; +} + +/** + * ethnl_fill_reply_header() - Put common header into a reply message + * @skb: skb with the message + * @dev: network device to describe in header + * @attrtype: attribute type to use for the nest + * + * Create a nested attribute with attributes describing given network device. + * + * Return: 0 on success, error value (-EMSGSIZE only) on error + */ +int ethnl_fill_reply_header(struct sk_buff *skb, struct net_device *dev, + u16 attrtype) +{ + struct nlattr *nest; + + if (!dev) + return 0; + nest = nla_nest_start(skb, attrtype); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, ETHTOOL_A_HEADER_DEV_INDEX, (u32)dev->ifindex) || + nla_put_string(skb, ETHTOOL_A_HEADER_DEV_NAME, dev->name)) + goto nla_put_failure; + /* If more attributes are put into reply header, ethnl_header_size() + * must be updated to account for them. 
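For scale, with the standard netlink attribute layout (4-byte attribute header, 4-byte alignment) the two attributes cost at most nla_total_size(4) = 8 bytes for the ifindex and nla_total_size(IFNAMSIZ) = 20 bytes for the name, so the whole nest fits within nla_total_size(8 + 20) = 32 bytes; this is the upper bound returned by ethnl_reply_header_size() in netlink.h later in this patch.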
+ */ + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +/** + * ethnl_reply_init() - Create skb for a reply and fill device identification + * @payload: payload length (without netlink and genetlink header) + * @dev: device the reply is about (may be null) + * @cmd: ETHTOOL_MSG_* message type for reply + * @hdr_attrtype: attribute type for common header + * @info: genetlink info of the received packet we respond to + * @ehdrp: place to store payload pointer returned by genlmsg_new() + * + * Return: pointer to allocated skb on success, NULL on error + */ +struct sk_buff *ethnl_reply_init(size_t payload, struct net_device *dev, u8 cmd, + u16 hdr_attrtype, struct genl_info *info, + void **ehdrp) +{ + struct sk_buff *skb; + + skb = genlmsg_new(payload, GFP_KERNEL); + if (!skb) + goto err; + *ehdrp = genlmsg_put_reply(skb, info, ðtool_genl_family, 0, cmd); + if (!*ehdrp) + goto err_free; + + if (dev) { + int ret; + + ret = ethnl_fill_reply_header(skb, dev, hdr_attrtype); + if (ret < 0) + goto err_free; + } + return skb; + +err_free: + nlmsg_free(skb); +err: + if (info) + GENL_SET_ERR_MSG(info, "failed to setup reply message"); + return NULL; +} + +void *ethnl_dump_put(struct sk_buff *skb, struct netlink_callback *cb, u8 cmd) +{ + return genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + ðtool_genl_family, 0, cmd); +} + +void *ethnl_bcastmsg_put(struct sk_buff *skb, u8 cmd) +{ + return genlmsg_put(skb, 0, ++ethnl_bcast_seq, ðtool_genl_family, 0, + cmd); +} + +int ethnl_multicast(struct sk_buff *skb, struct net_device *dev) +{ + return genlmsg_multicast_netns(ðtool_genl_family, dev_net(dev), skb, + 0, ETHNL_MCGRP_MONITOR, GFP_KERNEL); +} + +/* GET request helpers */ + +/** + * struct ethnl_dump_ctx - context structure for generic dumpit() callback + * @ops: request ops of currently processed message type + * @req_info: parsed request header of processed request + * @reply_data: data needed to compose the reply + * @pos_hash: saved iteration position - hashbucket + * @pos_idx: saved iteration position - index + * + * These parameters are kept in struct netlink_callback as context preserved + * between iterations. They are initialized by ethnl_default_start() and used + * in ethnl_default_dumpit() and ethnl_default_done(). 
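Note that this context is not allocated separately: it lives in the fixed-size cb->ctx scratch area of struct netlink_callback, which is why ethnl_default_start() below checks BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)) before casting.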
+ */ +struct ethnl_dump_ctx { + const struct ethnl_request_ops *ops; + struct ethnl_req_info *req_info; + struct ethnl_reply_data *reply_data; + int pos_hash; + int pos_idx; +}; + +static const struct ethnl_request_ops * +ethnl_default_requests[__ETHTOOL_MSG_USER_CNT] = { + [ETHTOOL_MSG_STRSET_GET] = ðnl_strset_request_ops, + [ETHTOOL_MSG_LINKINFO_GET] = ðnl_linkinfo_request_ops, + [ETHTOOL_MSG_LINKMODES_GET] = ðnl_linkmodes_request_ops, + [ETHTOOL_MSG_LINKSTATE_GET] = ðnl_linkstate_request_ops, + [ETHTOOL_MSG_DEBUG_GET] = ðnl_debug_request_ops, + [ETHTOOL_MSG_WOL_GET] = ðnl_wol_request_ops, + [ETHTOOL_MSG_FEATURES_GET] = ðnl_features_request_ops, + [ETHTOOL_MSG_PRIVFLAGS_GET] = ðnl_privflags_request_ops, + [ETHTOOL_MSG_RINGS_GET] = ðnl_rings_request_ops, + [ETHTOOL_MSG_CHANNELS_GET] = ðnl_channels_request_ops, + [ETHTOOL_MSG_COALESCE_GET] = ðnl_coalesce_request_ops, + [ETHTOOL_MSG_PAUSE_GET] = ðnl_pause_request_ops, + [ETHTOOL_MSG_EEE_GET] = ðnl_eee_request_ops, + [ETHTOOL_MSG_FEC_GET] = ðnl_fec_request_ops, + [ETHTOOL_MSG_TSINFO_GET] = ðnl_tsinfo_request_ops, + [ETHTOOL_MSG_MODULE_EEPROM_GET] = ðnl_module_eeprom_request_ops, + [ETHTOOL_MSG_STATS_GET] = ðnl_stats_request_ops, + [ETHTOOL_MSG_PHC_VCLOCKS_GET] = ðnl_phc_vclocks_request_ops, + [ETHTOOL_MSG_MODULE_GET] = ðnl_module_request_ops, + [ETHTOOL_MSG_PSE_GET] = ðnl_pse_request_ops, +}; + +static struct ethnl_dump_ctx *ethnl_dump_context(struct netlink_callback *cb) +{ + return (struct ethnl_dump_ctx *)cb->ctx; +} + +/** + * ethnl_default_parse() - Parse request message + * @req_info: pointer to structure to put data into + * @tb: parsed attributes + * @net: request netns + * @request_ops: struct request_ops for request type + * @extack: netlink extack for error reporting + * @require_dev: fail if no device identified in header + * + * Parse universal request header and call request specific ->parse_request() + * callback (if defined) to parse the rest of the message. + * + * Return: 0 on success or negative error code + */ +static int ethnl_default_parse(struct ethnl_req_info *req_info, + struct nlattr **tb, struct net *net, + const struct ethnl_request_ops *request_ops, + struct netlink_ext_ack *extack, bool require_dev) +{ + int ret; + + ret = ethnl_parse_header_dev_get(req_info, tb[request_ops->hdr_attr], + net, extack, require_dev); + if (ret < 0) + return ret; + + if (request_ops->parse_request) { + ret = request_ops->parse_request(req_info, tb, extack); + if (ret < 0) + return ret; + } + + return 0; +} + +/** + * ethnl_init_reply_data() - Initialize reply data for GET request + * @reply_data: pointer to embedded struct ethnl_reply_data + * @ops: instance of struct ethnl_request_ops describing the layout + * @dev: network device to initialize the reply for + * + * Fills the reply data part with zeros and sets the dev member. Must be called + * before calling the ->fill_reply() callback (for each iteration when handling + * dump requests). 
+ */ +static void ethnl_init_reply_data(struct ethnl_reply_data *reply_data, + const struct ethnl_request_ops *ops, + struct net_device *dev) +{ + memset(reply_data, 0, ops->reply_data_size); + reply_data->dev = dev; +} + +/* default ->doit() handler for GET type requests */ +static int ethnl_default_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_reply_data *reply_data = NULL; + struct ethnl_req_info *req_info = NULL; + const u8 cmd = info->genlhdr->cmd; + const struct ethnl_request_ops *ops; + int hdr_len, reply_len; + struct sk_buff *rskb; + void *reply_payload; + int ret; + + ops = ethnl_default_requests[cmd]; + if (WARN_ONCE(!ops, "cmd %u has no ethnl_request_ops\n", cmd)) + return -EOPNOTSUPP; + if (GENL_REQ_ATTR_CHECK(info, ops->hdr_attr)) + return -EINVAL; + + req_info = kzalloc(ops->req_info_size, GFP_KERNEL); + if (!req_info) + return -ENOMEM; + reply_data = kmalloc(ops->reply_data_size, GFP_KERNEL); + if (!reply_data) { + kfree(req_info); + return -ENOMEM; + } + + ret = ethnl_default_parse(req_info, info->attrs, genl_info_net(info), + ops, info->extack, !ops->allow_nodev_do); + if (ret < 0) + goto err_dev; + ethnl_init_reply_data(reply_data, ops, req_info->dev); + + rtnl_lock(); + ret = ops->prepare_data(req_info, reply_data, info); + rtnl_unlock(); + if (ret < 0) + goto err_cleanup; + ret = ops->reply_size(req_info, reply_data); + if (ret < 0) + goto err_cleanup; + reply_len = ret; + ret = -ENOMEM; + rskb = ethnl_reply_init(reply_len + ethnl_reply_header_size(), + req_info->dev, ops->reply_cmd, + ops->hdr_attr, info, &reply_payload); + if (!rskb) + goto err_cleanup; + hdr_len = rskb->len; + ret = ops->fill_reply(rskb, req_info, reply_data); + if (ret < 0) + goto err_msg; + WARN_ONCE(rskb->len - hdr_len > reply_len, + "ethnl cmd %d: calculated reply length %d, but consumed %d\n", + cmd, reply_len, rskb->len - hdr_len); + if (ops->cleanup_data) + ops->cleanup_data(reply_data); + + genlmsg_end(rskb, reply_payload); + netdev_put(req_info->dev, &req_info->dev_tracker); + kfree(reply_data); + kfree(req_info); + return genlmsg_reply(rskb, info); + +err_msg: + WARN_ONCE(ret == -EMSGSIZE, "calculated message payload length (%d) not sufficient\n", reply_len); + nlmsg_free(rskb); +err_cleanup: + if (ops->cleanup_data) + ops->cleanup_data(reply_data); +err_dev: + netdev_put(req_info->dev, &req_info->dev_tracker); + kfree(reply_data); + kfree(req_info); + return ret; +} + +static int ethnl_default_dump_one(struct sk_buff *skb, struct net_device *dev, + const struct ethnl_dump_ctx *ctx, + struct netlink_callback *cb) +{ + void *ehdr; + int ret; + + ehdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + ðtool_genl_family, NLM_F_MULTI, + ctx->ops->reply_cmd); + if (!ehdr) + return -EMSGSIZE; + + ethnl_init_reply_data(ctx->reply_data, ctx->ops, dev); + rtnl_lock(); + ret = ctx->ops->prepare_data(ctx->req_info, ctx->reply_data, NULL); + rtnl_unlock(); + if (ret < 0) + goto out; + ret = ethnl_fill_reply_header(skb, dev, ctx->ops->hdr_attr); + if (ret < 0) + goto out; + ret = ctx->ops->fill_reply(skb, ctx->req_info, ctx->reply_data); + +out: + if (ctx->ops->cleanup_data) + ctx->ops->cleanup_data(ctx->reply_data); + ctx->reply_data->dev = NULL; + if (ret < 0) + genlmsg_cancel(skb, ehdr); + else + genlmsg_end(skb, ehdr); + return ret; +} + +/* Default ->dumpit() handler for GET requests. 
Device iteration copied from + * rtnl_dump_ifinfo(); we have to be more careful about device hashtable + * persistence as we cannot guarantee to hold RTNL lock through the whole + * function as rtnetnlink does. + */ +static int ethnl_default_dumpit(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct ethnl_dump_ctx *ctx = ethnl_dump_context(cb); + struct net *net = sock_net(skb->sk); + int s_idx = ctx->pos_idx; + int h, idx = 0; + int ret = 0; + + rtnl_lock(); + for (h = ctx->pos_hash; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + struct hlist_head *head; + struct net_device *dev; + unsigned int seq; + + head = &net->dev_index_head[h]; + +restart_chain: + seq = net->dev_base_seq; + cb->seq = seq; + idx = 0; + hlist_for_each_entry(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + dev_hold(dev); + rtnl_unlock(); + + ret = ethnl_default_dump_one(skb, dev, ctx, cb); + dev_put(dev); + if (ret < 0) { + if (ret == -EOPNOTSUPP) + goto lock_and_cont; + if (likely(skb->len)) + ret = skb->len; + goto out; + } +lock_and_cont: + rtnl_lock(); + if (net->dev_base_seq != seq) { + s_idx = idx + 1; + goto restart_chain; + } +cont: + idx++; + } + ret = 0; + } + rtnl_unlock(); + +out: + ctx->pos_hash = h; + ctx->pos_idx = idx; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + + return ret; +} + +/* generic ->start() handler for GET requests */ +static int ethnl_default_start(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct ethnl_dump_ctx *ctx = ethnl_dump_context(cb); + struct ethnl_reply_data *reply_data; + const struct ethnl_request_ops *ops; + struct ethnl_req_info *req_info; + struct genlmsghdr *ghdr; + int ret; + + BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); + + ghdr = nlmsg_data(cb->nlh); + ops = ethnl_default_requests[ghdr->cmd]; + if (WARN_ONCE(!ops, "cmd %u has no ethnl_request_ops\n", ghdr->cmd)) + return -EOPNOTSUPP; + req_info = kzalloc(ops->req_info_size, GFP_KERNEL); + if (!req_info) + return -ENOMEM; + reply_data = kmalloc(ops->reply_data_size, GFP_KERNEL); + if (!reply_data) { + ret = -ENOMEM; + goto free_req_info; + } + + ret = ethnl_default_parse(req_info, info->attrs, sock_net(cb->skb->sk), + ops, cb->extack, false); + if (req_info->dev) { + /* We ignore device specification in dump requests but as the + * same parser as for non-dump (doit) requests is used, it + * would take reference to the device if it finds one + */ + netdev_put(req_info->dev, &req_info->dev_tracker); + req_info->dev = NULL; + } + if (ret < 0) + goto free_reply_data; + + ctx->ops = ops; + ctx->req_info = req_info; + ctx->reply_data = reply_data; + ctx->pos_hash = 0; + ctx->pos_idx = 0; + + return 0; + +free_reply_data: + kfree(reply_data); +free_req_info: + kfree(req_info); + + return ret; +} + +/* default ->done() handler for GET requests */ +static int ethnl_default_done(struct netlink_callback *cb) +{ + struct ethnl_dump_ctx *ctx = ethnl_dump_context(cb); + + kfree(ctx->reply_data); + kfree(ctx->req_info); + + return 0; +} + +static const struct ethnl_request_ops * +ethnl_default_notify_ops[ETHTOOL_MSG_KERNEL_MAX + 1] = { + [ETHTOOL_MSG_LINKINFO_NTF] = ðnl_linkinfo_request_ops, + [ETHTOOL_MSG_LINKMODES_NTF] = ðnl_linkmodes_request_ops, + [ETHTOOL_MSG_DEBUG_NTF] = ðnl_debug_request_ops, + [ETHTOOL_MSG_WOL_NTF] = ðnl_wol_request_ops, + [ETHTOOL_MSG_FEATURES_NTF] = ðnl_features_request_ops, + [ETHTOOL_MSG_PRIVFLAGS_NTF] = ðnl_privflags_request_ops, + [ETHTOOL_MSG_RINGS_NTF] = ðnl_rings_request_ops, + [ETHTOOL_MSG_CHANNELS_NTF] = 
ðnl_channels_request_ops, + [ETHTOOL_MSG_COALESCE_NTF] = ðnl_coalesce_request_ops, + [ETHTOOL_MSG_PAUSE_NTF] = ðnl_pause_request_ops, + [ETHTOOL_MSG_EEE_NTF] = ðnl_eee_request_ops, + [ETHTOOL_MSG_FEC_NTF] = ðnl_fec_request_ops, + [ETHTOOL_MSG_MODULE_NTF] = ðnl_module_request_ops, +}; + +/* default notification handler */ +static void ethnl_default_notify(struct net_device *dev, unsigned int cmd, + const void *data) +{ + struct ethnl_reply_data *reply_data; + const struct ethnl_request_ops *ops; + struct ethnl_req_info *req_info; + struct sk_buff *skb; + void *reply_payload; + int reply_len; + int ret; + + if (WARN_ONCE(cmd > ETHTOOL_MSG_KERNEL_MAX || + !ethnl_default_notify_ops[cmd], + "unexpected notification type %u\n", cmd)) + return; + ops = ethnl_default_notify_ops[cmd]; + req_info = kzalloc(ops->req_info_size, GFP_KERNEL); + if (!req_info) + return; + reply_data = kmalloc(ops->reply_data_size, GFP_KERNEL); + if (!reply_data) { + kfree(req_info); + return; + } + + req_info->dev = dev; + req_info->flags |= ETHTOOL_FLAG_COMPACT_BITSETS; + + ethnl_init_reply_data(reply_data, ops, dev); + ret = ops->prepare_data(req_info, reply_data, NULL); + if (ret < 0) + goto err_cleanup; + ret = ops->reply_size(req_info, reply_data); + if (ret < 0) + goto err_cleanup; + reply_len = ret + ethnl_reply_header_size(); + skb = genlmsg_new(reply_len, GFP_KERNEL); + if (!skb) + goto err_cleanup; + reply_payload = ethnl_bcastmsg_put(skb, cmd); + if (!reply_payload) + goto err_skb; + ret = ethnl_fill_reply_header(skb, dev, ops->hdr_attr); + if (ret < 0) + goto err_msg; + ret = ops->fill_reply(skb, req_info, reply_data); + if (ret < 0) + goto err_msg; + if (ops->cleanup_data) + ops->cleanup_data(reply_data); + + genlmsg_end(skb, reply_payload); + kfree(reply_data); + kfree(req_info); + ethnl_multicast(skb, dev); + return; + +err_msg: + WARN_ONCE(ret == -EMSGSIZE, + "calculated message payload length (%d) not sufficient\n", + reply_len); +err_skb: + nlmsg_free(skb); +err_cleanup: + if (ops->cleanup_data) + ops->cleanup_data(reply_data); + kfree(reply_data); + kfree(req_info); + return; +} + +/* notifications */ + +typedef void (*ethnl_notify_handler_t)(struct net_device *dev, unsigned int cmd, + const void *data); + +static const ethnl_notify_handler_t ethnl_notify_handlers[] = { + [ETHTOOL_MSG_LINKINFO_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_LINKMODES_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_DEBUG_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_WOL_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_FEATURES_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_PRIVFLAGS_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_RINGS_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_CHANNELS_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_COALESCE_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_PAUSE_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_EEE_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_FEC_NTF] = ethnl_default_notify, + [ETHTOOL_MSG_MODULE_NTF] = ethnl_default_notify, +}; + +void ethtool_notify(struct net_device *dev, unsigned int cmd, const void *data) +{ + if (unlikely(!ethnl_ok)) + return; + ASSERT_RTNL(); + + if (likely(cmd < ARRAY_SIZE(ethnl_notify_handlers) && + ethnl_notify_handlers[cmd])) + ethnl_notify_handlers[cmd](dev, cmd, data); + else + WARN_ONCE(1, "notification %u not implemented (dev=%s)\n", + cmd, netdev_name(dev)); +} +EXPORT_SYMBOL(ethtool_notify); + +static void ethnl_notify_features(struct netdev_notifier_info *info) +{ + struct net_device *dev = netdev_notifier_info_to_dev(info); + + ethtool_notify(dev, 
ETHTOOL_MSG_FEATURES_NTF, NULL); +} + +static int ethnl_netdev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + switch (event) { + case NETDEV_FEAT_CHANGE: + ethnl_notify_features(ptr); + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block ethnl_netdev_notifier = { + .notifier_call = ethnl_netdev_event, +}; + +/* genetlink setup */ + +static const struct genl_ops ethtool_genl_ops[] = { + { + .cmd = ETHTOOL_MSG_STRSET_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_strset_get_policy, + .maxattr = ARRAY_SIZE(ethnl_strset_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_LINKINFO_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_linkinfo_get_policy, + .maxattr = ARRAY_SIZE(ethnl_linkinfo_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_LINKINFO_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_linkinfo, + .policy = ethnl_linkinfo_set_policy, + .maxattr = ARRAY_SIZE(ethnl_linkinfo_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_LINKMODES_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_linkmodes_get_policy, + .maxattr = ARRAY_SIZE(ethnl_linkmodes_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_LINKMODES_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_linkmodes, + .policy = ethnl_linkmodes_set_policy, + .maxattr = ARRAY_SIZE(ethnl_linkmodes_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_LINKSTATE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_linkstate_get_policy, + .maxattr = ARRAY_SIZE(ethnl_linkstate_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_DEBUG_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_debug_get_policy, + .maxattr = ARRAY_SIZE(ethnl_debug_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_DEBUG_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_debug, + .policy = ethnl_debug_set_policy, + .maxattr = ARRAY_SIZE(ethnl_debug_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_WOL_GET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_wol_get_policy, + .maxattr = ARRAY_SIZE(ethnl_wol_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_WOL_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_wol, + .policy = ethnl_wol_set_policy, + .maxattr = ARRAY_SIZE(ethnl_wol_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_FEATURES_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_features_get_policy, + .maxattr = ARRAY_SIZE(ethnl_features_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_FEATURES_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_features, + .policy = ethnl_features_set_policy, + .maxattr = ARRAY_SIZE(ethnl_features_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PRIVFLAGS_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_privflags_get_policy, + .maxattr = ARRAY_SIZE(ethnl_privflags_get_policy) - 1, + }, + { + .cmd = 
ETHTOOL_MSG_PRIVFLAGS_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_privflags, + .policy = ethnl_privflags_set_policy, + .maxattr = ARRAY_SIZE(ethnl_privflags_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_RINGS_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_rings_get_policy, + .maxattr = ARRAY_SIZE(ethnl_rings_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_RINGS_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_rings, + .policy = ethnl_rings_set_policy, + .maxattr = ARRAY_SIZE(ethnl_rings_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_CHANNELS_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_channels_get_policy, + .maxattr = ARRAY_SIZE(ethnl_channels_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_CHANNELS_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_channels, + .policy = ethnl_channels_set_policy, + .maxattr = ARRAY_SIZE(ethnl_channels_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_COALESCE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_coalesce_get_policy, + .maxattr = ARRAY_SIZE(ethnl_coalesce_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_COALESCE_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_coalesce, + .policy = ethnl_coalesce_set_policy, + .maxattr = ARRAY_SIZE(ethnl_coalesce_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PAUSE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_pause_get_policy, + .maxattr = ARRAY_SIZE(ethnl_pause_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PAUSE_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_pause, + .policy = ethnl_pause_set_policy, + .maxattr = ARRAY_SIZE(ethnl_pause_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_EEE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_eee_get_policy, + .maxattr = ARRAY_SIZE(ethnl_eee_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_EEE_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_eee, + .policy = ethnl_eee_set_policy, + .maxattr = ARRAY_SIZE(ethnl_eee_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_TSINFO_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_tsinfo_get_policy, + .maxattr = ARRAY_SIZE(ethnl_tsinfo_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_CABLE_TEST_ACT, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_act_cable_test, + .policy = ethnl_cable_test_act_policy, + .maxattr = ARRAY_SIZE(ethnl_cable_test_act_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_CABLE_TEST_TDR_ACT, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_act_cable_test_tdr, + .policy = ethnl_cable_test_tdr_act_policy, + .maxattr = ARRAY_SIZE(ethnl_cable_test_tdr_act_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_TUNNEL_INFO_GET, + .doit = ethnl_tunnel_info_doit, + .start = ethnl_tunnel_info_start, + .dumpit = ethnl_tunnel_info_dumpit, + .policy = ethnl_tunnel_info_get_policy, + .maxattr = ARRAY_SIZE(ethnl_tunnel_info_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_FEC_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, 
+ .policy = ethnl_fec_get_policy, + .maxattr = ARRAY_SIZE(ethnl_fec_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_FEC_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_fec, + .policy = ethnl_fec_set_policy, + .maxattr = ARRAY_SIZE(ethnl_fec_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_MODULE_EEPROM_GET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_module_eeprom_get_policy, + .maxattr = ARRAY_SIZE(ethnl_module_eeprom_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_STATS_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_stats_get_policy, + .maxattr = ARRAY_SIZE(ethnl_stats_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PHC_VCLOCKS_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_phc_vclocks_get_policy, + .maxattr = ARRAY_SIZE(ethnl_phc_vclocks_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_MODULE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_module_get_policy, + .maxattr = ARRAY_SIZE(ethnl_module_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_MODULE_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_module, + .policy = ethnl_module_set_policy, + .maxattr = ARRAY_SIZE(ethnl_module_set_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PSE_GET, + .doit = ethnl_default_doit, + .start = ethnl_default_start, + .dumpit = ethnl_default_dumpit, + .done = ethnl_default_done, + .policy = ethnl_pse_get_policy, + .maxattr = ARRAY_SIZE(ethnl_pse_get_policy) - 1, + }, + { + .cmd = ETHTOOL_MSG_PSE_SET, + .flags = GENL_UNS_ADMIN_PERM, + .doit = ethnl_set_pse, + .policy = ethnl_pse_set_policy, + .maxattr = ARRAY_SIZE(ethnl_pse_set_policy) - 1, + }, +}; + +static const struct genl_multicast_group ethtool_nl_mcgrps[] = { + [ETHNL_MCGRP_MONITOR] = { .name = ETHTOOL_MCGRP_MONITOR_NAME }, +}; + +static struct genl_family ethtool_genl_family __ro_after_init = { + .name = ETHTOOL_GENL_NAME, + .version = ETHTOOL_GENL_VERSION, + .netnsok = true, + .parallel_ops = true, + .ops = ethtool_genl_ops, + .n_ops = ARRAY_SIZE(ethtool_genl_ops), + .resv_start_op = ETHTOOL_MSG_MODULE_GET + 1, + .mcgrps = ethtool_nl_mcgrps, + .n_mcgrps = ARRAY_SIZE(ethtool_nl_mcgrps), +}; + +/* module setup */ + +static int __init ethnl_init(void) +{ + int ret; + + ret = genl_register_family(ðtool_genl_family); + if (WARN(ret < 0, "ethtool: genetlink family registration failed")) + return ret; + ethnl_ok = true; + + ret = register_netdevice_notifier(ðnl_netdev_notifier); + WARN(ret < 0, "ethtool: net device notifier registration failed"); + return ret; +} + +subsys_initcall(ethnl_init); diff --git a/net/ethtool/netlink.h b/net/ethtool/netlink.h new file mode 100644 index 000000000..1bfd374f9 --- /dev/null +++ b/net/ethtool/netlink.h @@ -0,0 +1,416 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ + +#ifndef _NET_ETHTOOL_NETLINK_H +#define _NET_ETHTOOL_NETLINK_H + +#include +#include +#include +#include + +struct ethnl_req_info; + +int ethnl_parse_header_dev_get(struct ethnl_req_info *req_info, + const struct nlattr *nest, struct net *net, + struct netlink_ext_ack *extack, + bool require_dev); +int ethnl_fill_reply_header(struct sk_buff *skb, struct net_device *dev, + u16 attrtype); +struct sk_buff 
*ethnl_reply_init(size_t payload, struct net_device *dev, u8 cmd, + u16 hdr_attrtype, struct genl_info *info, + void **ehdrp); +void *ethnl_dump_put(struct sk_buff *skb, struct netlink_callback *cb, u8 cmd); +void *ethnl_bcastmsg_put(struct sk_buff *skb, u8 cmd); +int ethnl_multicast(struct sk_buff *skb, struct net_device *dev); + +/** + * ethnl_strz_size() - calculate attribute length for fixed size string + * @s: ETH_GSTRING_LEN sized string (may not be null terminated) + * + * Return: total length of an attribute with null terminated string from @s + */ +static inline int ethnl_strz_size(const char *s) +{ + return nla_total_size(strnlen(s, ETH_GSTRING_LEN) + 1); +} + +/** + * ethnl_put_strz() - put string attribute with fixed size string + * @skb: skb with the message + * @attrtype: attribute type + * @s: ETH_GSTRING_LEN sized string (may not be null terminated) + * + * Puts an attribute with null terminated string from @s into the message. + * + * Return: 0 on success, negative error code on failure + */ +static inline int ethnl_put_strz(struct sk_buff *skb, u16 attrtype, + const char *s) +{ + unsigned int len = strnlen(s, ETH_GSTRING_LEN); + struct nlattr *attr; + + attr = nla_reserve(skb, attrtype, len + 1); + if (!attr) + return -EMSGSIZE; + + memcpy(nla_data(attr), s, len); + ((char *)nla_data(attr))[len] = '\0'; + return 0; +} + +/** + * ethnl_update_u32() - update u32 value from NLA_U32 attribute + * @dst: value to update + * @attr: netlink attribute with new value or null + * @mod: pointer to bool for modification tracking + * + * Copy the u32 value from NLA_U32 netlink attribute @attr into variable + * pointed to by @dst; do nothing if @attr is null. Bool pointed to by @mod + * is set to true if this function changed the value of *dst, otherwise it + * is left as is. + */ +static inline void ethnl_update_u32(u32 *dst, const struct nlattr *attr, + bool *mod) +{ + u32 val; + + if (!attr) + return; + val = nla_get_u32(attr); + if (*dst == val) + return; + + *dst = val; + *mod = true; +} + +/** + * ethnl_update_u8() - update u8 value from NLA_U8 attribute + * @dst: value to update + * @attr: netlink attribute with new value or null + * @mod: pointer to bool for modification tracking + * + * Copy the u8 value from NLA_U8 netlink attribute @attr into variable + * pointed to by @dst; do nothing if @attr is null. Bool pointed to by @mod + * is set to true if this function changed the value of *dst, otherwise it + * is left as is. + */ +static inline void ethnl_update_u8(u8 *dst, const struct nlattr *attr, + bool *mod) +{ + u8 val; + + if (!attr) + return; + val = nla_get_u8(attr); + if (*dst == val) + return; + + *dst = val; + *mod = true; +} + +/** + * ethnl_update_bool32() - update u32 used as bool from NLA_U8 attribute + * @dst: value to update + * @attr: netlink attribute with new value or null + * @mod: pointer to bool for modification tracking + * + * Use the u8 value from NLA_U8 netlink attribute @attr to set u32 variable + * pointed to by @dst to 0 (if zero) or 1 (if not); do nothing if @attr is + * null. Bool pointed to by @mod is set to true if this function changed the + * logical value of *dst, otherwise it is left as is. 
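Taken together, these update helpers implement the "apply each optional attribute, remember whether anything changed" idiom used by the SET handlers in this patch. A condensed sketch of that idiom (hypothetical foo_settings structure and attribute indices, for illustration only):

struct foo_settings {
	u8 mode;
	u32 rate;
};

/* Apply optional attributes and report whether any field actually changed,
 * mirroring ethnl_set_linkinfo() earlier in this patch. */
static bool foo_apply_attrs(struct foo_settings *cfg, struct nlattr **tb)
{
	bool mod = false;

	ethnl_update_u8(&cfg->mode, tb[1 /* hypothetical FOO_MODE */], &mod);
	ethnl_update_u32(&cfg->rate, tb[2 /* hypothetical FOO_RATE */], &mod);

	return mod;	/* callers only push the change to the driver when true */
}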
+ */ +static inline void ethnl_update_bool32(u32 *dst, const struct nlattr *attr, + bool *mod) +{ + u8 val; + + if (!attr) + return; + val = !!nla_get_u8(attr); + if (!!*dst == val) + return; + + *dst = val; + *mod = true; +} + +/** + * ethnl_update_binary() - update binary data from NLA_BINARY attribute + * @dst: value to update + * @len: destination buffer length + * @attr: netlink attribute with new value or null + * @mod: pointer to bool for modification tracking + * + * Use the payload of NLA_BINARY netlink attribute @attr to rewrite the data + * block of length @len at @dst; do nothing if @attr is null. Bool pointed to + * by @mod is set to true if this function changed the contents at @dst, + * otherwise it is left as is. + */ +static inline void ethnl_update_binary(void *dst, unsigned int len, + const struct nlattr *attr, bool *mod) +{ + if (!attr) + return; + if (nla_len(attr) < len) + len = nla_len(attr); + if (!memcmp(dst, nla_data(attr), len)) + return; + + memcpy(dst, nla_data(attr), len); + *mod = true; +} + +/** + * ethnl_update_bitfield32() - update u32 value from NLA_BITFIELD32 attribute + * @dst: value to update + * @attr: netlink attribute with new value or null + * @mod: pointer to bool for modification tracking + * + * Update bits in u32 value which are set in attribute's mask to values from + * attribute's value. Do nothing if @attr is null or the value wouldn't change; + * otherwise, set bool pointed to by @mod to true. + */ +static inline void ethnl_update_bitfield32(u32 *dst, const struct nlattr *attr, + bool *mod) +{ + struct nla_bitfield32 change; + u32 newval; + + if (!attr) + return; + change = nla_get_bitfield32(attr); + newval = (*dst & ~change.selector) | (change.value & change.selector); + if (*dst == newval) + return; + + *dst = newval; + *mod = true; +} + +/** + * ethnl_reply_header_size() - total size of reply header + * + * This is an upper estimate so that we do not need to hold RTNL lock longer + * than necessary (to prevent rename between size estimate and composing the + * message). Accounts only for device ifindex and name as those are the only + * attributes ethnl_fill_reply_header() puts into the reply header. + */ +static inline unsigned int ethnl_reply_header_size(void) +{ + return nla_total_size(nla_total_size(sizeof(u32)) + + nla_total_size(IFNAMSIZ)); +} + +/* GET request handling */ + +/* Unified processing of GET requests uses two data structures: request info + * and reply data. Request info holds information parsed from client request + * and it stays constant through all request processing. Reply data holds data + * retrieved from ethtool_ops callbacks or other internal sources which is used + * to compose the reply. When processing a dump request, request info is filled + * only once (when the request message is parsed) but reply data is filled for + * each reply message. + * + * Both structures consist of a part common to all request types (struct + * ethnl_req_info and struct ethnl_reply_data defined below) and optional + * parts specific to each request type. The common part always starts at + * offset 0. + */ + +/** + * struct ethnl_req_info - base type of request information for GET requests + * @dev: network device the request is for (may be null) + * @dev_tracker: refcount tracker for @dev reference + * @flags: request flags common for all request types + * + * This is a common base for request specific structures holding data from + * parsed userspace request.
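+ * As an illustration, the request specific types in this directory follow
+ * the pattern below (pause_req_info from pause.c shown; variants such as the
+ * stats and strset ones add request specific members after the common base):
+ *
+ *	struct pause_req_info {
+ *		struct ethnl_req_info	base;
+ *	};
+ *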
These always embed struct ethnl_req_info at + * zero offset. + */ +struct ethnl_req_info { + struct net_device *dev; + netdevice_tracker dev_tracker; + u32 flags; +}; + +static inline void ethnl_parse_header_dev_put(struct ethnl_req_info *req_info) +{ + netdev_put(req_info->dev, &req_info->dev_tracker); +} + +/** + * struct ethnl_reply_data - base type of reply data for GET requests + * @dev: device for current reply message; in single shot requests it is + * equal to &ethnl_req_info.dev; in dumps it's different for each + * reply message + * + * This is a common base for request specific structures holding data for + * kernel reply message. These always embed struct ethnl_reply_data at zero + * offset. + */ +struct ethnl_reply_data { + struct net_device *dev; +}; + +int ethnl_ops_begin(struct net_device *dev); +void ethnl_ops_complete(struct net_device *dev); + +/** + * struct ethnl_request_ops - unified handling of GET requests + * @request_cmd: command id for request (GET) + * @reply_cmd: command id for reply (GET_REPLY) + * @hdr_attr: attribute type for request header + * @req_info_size: size of request info + * @reply_data_size: size of reply data + * @allow_nodev_do: allow non-dump request with no device identification + * @parse_request: + * Parse request except common header (struct ethnl_req_info). Common + * header is already filled on entry, the rest up to @repdata_offset + * is zero initialized. This callback should only modify type specific + * request info by parsed attributes from request message. + * @prepare_data: + * Retrieve and prepare data needed to compose a reply message. Calls to + * ethtool_ops handlers are limited to this callback. Common reply data + * (struct ethnl_reply_data) is filled on entry, type specific part after + * it is zero initialized. This callback should only modify the type + * specific part of reply data. Device identification from struct + * ethnl_reply_data is to be used: for dump requests, it iterates + * through network devices while the dev member of struct ethnl_req_info + * points to the device from the client request. + * @reply_size: + * Estimate reply message size. Returned value must be sufficient for + * message payload without common reply header. The callback may return + * an estimate higher than the actual message size if exact calculation + * would not be worth the saved memory space. + * @fill_reply: + * Fill reply message payload (except for common header) from reply data. + * The callback must not generate more payload than previously called + * ->reply_size() estimated. + * @cleanup_data: + * Optional cleanup called when reply data is no longer needed. Can be + * used e.g. to free any additional data structures outside the main + * structure which were allocated by ->prepare_data(). When processing + * dump requests, ->cleanup_data() is called for each message. + * + * Description of variable parts of GET request handling when using the + * unified infrastructure. When used, a pointer to an instance of this + * structure is to be added to &ethnl_default_requests array and generic + * handlers ethnl_default_doit(), ethnl_default_dumpit(), + * ethnl_default_start() and ethnl_default_done() used in @ethtool_genl_ops; + * ethnl_default_notify() can be used in @ethnl_notify_handlers to send + * notifications of the corresponding type.
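+ *
+ * A minimal GET-only instance therefore looks roughly like the one below,
+ * modelled on ethnl_pause_request_ops in pause.c later in this patch
+ * (parse_request and cleanup_data are left out because they are optional):
+ *
+ *	const struct ethnl_request_ops ethnl_pause_request_ops = {
+ *		.request_cmd		= ETHTOOL_MSG_PAUSE_GET,
+ *		.reply_cmd		= ETHTOOL_MSG_PAUSE_GET_REPLY,
+ *		.hdr_attr		= ETHTOOL_A_PAUSE_HEADER,
+ *		.req_info_size		= sizeof(struct pause_req_info),
+ *		.reply_data_size	= sizeof(struct pause_reply_data),
+ *		.prepare_data		= pause_prepare_data,
+ *		.reply_size		= pause_reply_size,
+ *		.fill_reply		= pause_fill_reply,
+ *	};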
+ */ +struct ethnl_request_ops { + u8 request_cmd; + u8 reply_cmd; + u16 hdr_attr; + unsigned int req_info_size; + unsigned int reply_data_size; + bool allow_nodev_do; + + int (*parse_request)(struct ethnl_req_info *req_info, + struct nlattr **tb, + struct netlink_ext_ack *extack); + int (*prepare_data)(const struct ethnl_req_info *req_info, + struct ethnl_reply_data *reply_data, + struct genl_info *info); + int (*reply_size)(const struct ethnl_req_info *req_info, + const struct ethnl_reply_data *reply_data); + int (*fill_reply)(struct sk_buff *skb, + const struct ethnl_req_info *req_info, + const struct ethnl_reply_data *reply_data); + void (*cleanup_data)(struct ethnl_reply_data *reply_data); +}; + +/* request handlers */ + +extern const struct ethnl_request_ops ethnl_strset_request_ops; +extern const struct ethnl_request_ops ethnl_linkinfo_request_ops; +extern const struct ethnl_request_ops ethnl_linkmodes_request_ops; +extern const struct ethnl_request_ops ethnl_linkstate_request_ops; +extern const struct ethnl_request_ops ethnl_debug_request_ops; +extern const struct ethnl_request_ops ethnl_wol_request_ops; +extern const struct ethnl_request_ops ethnl_features_request_ops; +extern const struct ethnl_request_ops ethnl_privflags_request_ops; +extern const struct ethnl_request_ops ethnl_rings_request_ops; +extern const struct ethnl_request_ops ethnl_channels_request_ops; +extern const struct ethnl_request_ops ethnl_coalesce_request_ops; +extern const struct ethnl_request_ops ethnl_pause_request_ops; +extern const struct ethnl_request_ops ethnl_eee_request_ops; +extern const struct ethnl_request_ops ethnl_tsinfo_request_ops; +extern const struct ethnl_request_ops ethnl_fec_request_ops; +extern const struct ethnl_request_ops ethnl_module_eeprom_request_ops; +extern const struct ethnl_request_ops ethnl_stats_request_ops; +extern const struct ethnl_request_ops ethnl_phc_vclocks_request_ops; +extern const struct ethnl_request_ops ethnl_module_request_ops; +extern const struct ethnl_request_ops ethnl_pse_request_ops; + +extern const struct nla_policy ethnl_header_policy[ETHTOOL_A_HEADER_FLAGS + 1]; +extern const struct nla_policy ethnl_header_policy_stats[ETHTOOL_A_HEADER_FLAGS + 1]; +extern const struct nla_policy ethnl_strset_get_policy[ETHTOOL_A_STRSET_COUNTS_ONLY + 1]; +extern const struct nla_policy ethnl_linkinfo_get_policy[ETHTOOL_A_LINKINFO_HEADER + 1]; +extern const struct nla_policy ethnl_linkinfo_set_policy[ETHTOOL_A_LINKINFO_TP_MDIX_CTRL + 1]; +extern const struct nla_policy ethnl_linkmodes_get_policy[ETHTOOL_A_LINKMODES_HEADER + 1]; +extern const struct nla_policy ethnl_linkmodes_set_policy[ETHTOOL_A_LINKMODES_LANES + 1]; +extern const struct nla_policy ethnl_linkstate_get_policy[ETHTOOL_A_LINKSTATE_HEADER + 1]; +extern const struct nla_policy ethnl_debug_get_policy[ETHTOOL_A_DEBUG_HEADER + 1]; +extern const struct nla_policy ethnl_debug_set_policy[ETHTOOL_A_DEBUG_MSGMASK + 1]; +extern const struct nla_policy ethnl_wol_get_policy[ETHTOOL_A_WOL_HEADER + 1]; +extern const struct nla_policy ethnl_wol_set_policy[ETHTOOL_A_WOL_SOPASS + 1]; +extern const struct nla_policy ethnl_features_get_policy[ETHTOOL_A_FEATURES_HEADER + 1]; +extern const struct nla_policy ethnl_features_set_policy[ETHTOOL_A_FEATURES_WANTED + 1]; +extern const struct nla_policy ethnl_privflags_get_policy[ETHTOOL_A_PRIVFLAGS_HEADER + 1]; +extern const struct nla_policy ethnl_privflags_set_policy[ETHTOOL_A_PRIVFLAGS_FLAGS + 1]; +extern const struct nla_policy ethnl_rings_get_policy[ETHTOOL_A_RINGS_HEADER + 1]; +extern 
const struct nla_policy ethnl_rings_set_policy[ETHTOOL_A_RINGS_TX_PUSH + 1]; +extern const struct nla_policy ethnl_channels_get_policy[ETHTOOL_A_CHANNELS_HEADER + 1]; +extern const struct nla_policy ethnl_channels_set_policy[ETHTOOL_A_CHANNELS_COMBINED_COUNT + 1]; +extern const struct nla_policy ethnl_coalesce_get_policy[ETHTOOL_A_COALESCE_HEADER + 1]; +extern const struct nla_policy ethnl_coalesce_set_policy[ETHTOOL_A_COALESCE_MAX + 1]; +extern const struct nla_policy ethnl_pause_get_policy[ETHTOOL_A_PAUSE_HEADER + 1]; +extern const struct nla_policy ethnl_pause_set_policy[ETHTOOL_A_PAUSE_TX + 1]; +extern const struct nla_policy ethnl_eee_get_policy[ETHTOOL_A_EEE_HEADER + 1]; +extern const struct nla_policy ethnl_eee_set_policy[ETHTOOL_A_EEE_TX_LPI_TIMER + 1]; +extern const struct nla_policy ethnl_tsinfo_get_policy[ETHTOOL_A_TSINFO_HEADER + 1]; +extern const struct nla_policy ethnl_cable_test_act_policy[ETHTOOL_A_CABLE_TEST_HEADER + 1]; +extern const struct nla_policy ethnl_cable_test_tdr_act_policy[ETHTOOL_A_CABLE_TEST_TDR_CFG + 1]; +extern const struct nla_policy ethnl_tunnel_info_get_policy[ETHTOOL_A_TUNNEL_INFO_HEADER + 1]; +extern const struct nla_policy ethnl_fec_get_policy[ETHTOOL_A_FEC_HEADER + 1]; +extern const struct nla_policy ethnl_fec_set_policy[ETHTOOL_A_FEC_AUTO + 1]; +extern const struct nla_policy ethnl_module_eeprom_get_policy[ETHTOOL_A_MODULE_EEPROM_I2C_ADDRESS + 1]; +extern const struct nla_policy ethnl_stats_get_policy[ETHTOOL_A_STATS_GROUPS + 1]; +extern const struct nla_policy ethnl_phc_vclocks_get_policy[ETHTOOL_A_PHC_VCLOCKS_HEADER + 1]; +extern const struct nla_policy ethnl_module_get_policy[ETHTOOL_A_MODULE_HEADER + 1]; +extern const struct nla_policy ethnl_module_set_policy[ETHTOOL_A_MODULE_POWER_MODE_POLICY + 1]; +extern const struct nla_policy ethnl_pse_get_policy[ETHTOOL_A_PSE_HEADER + 1]; +extern const struct nla_policy ethnl_pse_set_policy[ETHTOOL_A_PSE_MAX + 1]; + +int ethnl_set_linkinfo(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_linkmodes(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_debug(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_wol(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_features(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_privflags(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_rings(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_channels(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_coalesce(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_pause(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_eee(struct sk_buff *skb, struct genl_info *info); +int ethnl_act_cable_test(struct sk_buff *skb, struct genl_info *info); +int ethnl_act_cable_test_tdr(struct sk_buff *skb, struct genl_info *info); +int ethnl_tunnel_info_doit(struct sk_buff *skb, struct genl_info *info); +int ethnl_tunnel_info_start(struct netlink_callback *cb); +int ethnl_tunnel_info_dumpit(struct sk_buff *skb, struct netlink_callback *cb); +int ethnl_set_fec(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_module(struct sk_buff *skb, struct genl_info *info); +int ethnl_set_pse(struct sk_buff *skb, struct genl_info *info); + +extern const char stats_std_names[__ETHTOOL_STATS_CNT][ETH_GSTRING_LEN]; +extern const char stats_eth_phy_names[__ETHTOOL_A_STATS_ETH_PHY_CNT][ETH_GSTRING_LEN]; +extern const char stats_eth_mac_names[__ETHTOOL_A_STATS_ETH_MAC_CNT][ETH_GSTRING_LEN]; +extern const char 
stats_eth_ctrl_names[__ETHTOOL_A_STATS_ETH_CTRL_CNT][ETH_GSTRING_LEN]; +extern const char stats_rmon_names[__ETHTOOL_A_STATS_RMON_CNT][ETH_GSTRING_LEN]; + +#endif /* _NET_ETHTOOL_NETLINK_H */ diff --git a/net/ethtool/pause.c b/net/ethtool/pause.c new file mode 100644 index 000000000..a8c113d24 --- /dev/null +++ b/net/ethtool/pause.c @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" + +struct pause_req_info { + struct ethnl_req_info base; +}; + +struct pause_reply_data { + struct ethnl_reply_data base; + struct ethtool_pauseparam pauseparam; + struct ethtool_pause_stats pausestat; +}; + +#define PAUSE_REPDATA(__reply_base) \ + container_of(__reply_base, struct pause_reply_data, base) + +const struct nla_policy ethnl_pause_get_policy[] = { + [ETHTOOL_A_PAUSE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy_stats), +}; + +static int pause_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct pause_reply_data *data = PAUSE_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_pauseparam) + return -EOPNOTSUPP; + + ethtool_stats_init((u64 *)&data->pausestat, + sizeof(data->pausestat) / 8); + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + dev->ethtool_ops->get_pauseparam(dev, &data->pauseparam); + if (req_base->flags & ETHTOOL_FLAG_STATS && + dev->ethtool_ops->get_pause_stats) + dev->ethtool_ops->get_pause_stats(dev, &data->pausestat); + ethnl_ops_complete(dev); + + return 0; +} + +static int pause_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + int n = nla_total_size(sizeof(u8)) + /* _PAUSE_AUTONEG */ + nla_total_size(sizeof(u8)) + /* _PAUSE_RX */ + nla_total_size(sizeof(u8)); /* _PAUSE_TX */ + + if (req_base->flags & ETHTOOL_FLAG_STATS) + n += nla_total_size(0) + /* _PAUSE_STATS */ + nla_total_size_64bit(sizeof(u64)) * ETHTOOL_PAUSE_STAT_CNT; + return n; +} + +static int ethtool_put_stat(struct sk_buff *skb, u64 val, u16 attrtype, + u16 padtype) +{ + if (val == ETHTOOL_STAT_NOT_SET) + return 0; + if (nla_put_u64_64bit(skb, attrtype, val, padtype)) + return -EMSGSIZE; + + return 0; +} + +static int pause_put_stats(struct sk_buff *skb, + const struct ethtool_pause_stats *pause_stats) +{ + const u16 pad = ETHTOOL_A_PAUSE_STAT_PAD; + struct nlattr *nest; + + nest = nla_nest_start(skb, ETHTOOL_A_PAUSE_STATS); + if (!nest) + return -EMSGSIZE; + + if (ethtool_put_stat(skb, pause_stats->tx_pause_frames, + ETHTOOL_A_PAUSE_STAT_TX_FRAMES, pad) || + ethtool_put_stat(skb, pause_stats->rx_pause_frames, + ETHTOOL_A_PAUSE_STAT_RX_FRAMES, pad)) + goto err_cancel; + + nla_nest_end(skb, nest); + return 0; + +err_cancel: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int pause_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct pause_reply_data *data = PAUSE_REPDATA(reply_base); + const struct ethtool_pauseparam *pauseparam = &data->pauseparam; + + if (nla_put_u8(skb, ETHTOOL_A_PAUSE_AUTONEG, !!pauseparam->autoneg) || + nla_put_u8(skb, ETHTOOL_A_PAUSE_RX, !!pauseparam->rx_pause) || + nla_put_u8(skb, ETHTOOL_A_PAUSE_TX, !!pauseparam->tx_pause)) + return -EMSGSIZE; + + if (req_base->flags & ETHTOOL_FLAG_STATS && + pause_put_stats(skb, &data->pausestat)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_pause_request_ops = { + .request_cmd = 
ETHTOOL_MSG_PAUSE_GET, + .reply_cmd = ETHTOOL_MSG_PAUSE_GET_REPLY, + .hdr_attr = ETHTOOL_A_PAUSE_HEADER, + .req_info_size = sizeof(struct pause_req_info), + .reply_data_size = sizeof(struct pause_reply_data), + + .prepare_data = pause_prepare_data, + .reply_size = pause_reply_size, + .fill_reply = pause_fill_reply, +}; + +/* PAUSE_SET */ + +const struct nla_policy ethnl_pause_set_policy[] = { + [ETHTOOL_A_PAUSE_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_PAUSE_AUTONEG] = { .type = NLA_U8 }, + [ETHTOOL_A_PAUSE_RX] = { .type = NLA_U8 }, + [ETHTOOL_A_PAUSE_TX] = { .type = NLA_U8 }, +}; + +int ethnl_set_pause(struct sk_buff *skb, struct genl_info *info) +{ + struct ethtool_pauseparam params = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + const struct ethtool_ops *ops; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_PAUSE_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_pauseparam || !ops->set_pauseparam) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ops->get_pauseparam(dev, &params); + + ethnl_update_bool32(&params.autoneg, tb[ETHTOOL_A_PAUSE_AUTONEG], &mod); + ethnl_update_bool32(&params.rx_pause, tb[ETHTOOL_A_PAUSE_RX], &mod); + ethnl_update_bool32(&params.tx_pause, tb[ETHTOOL_A_PAUSE_TX], &mod); + ret = 0; + if (!mod) + goto out_ops; + + ret = dev->ethtool_ops->set_pauseparam(dev, &params); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_PAUSE_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/phc_vclocks.c b/net/ethtool/phc_vclocks.c new file mode 100644 index 000000000..637b2f529 --- /dev/null +++ b/net/ethtool/phc_vclocks.c @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2021 NXP + */ +#include "netlink.h" +#include "common.h" + +struct phc_vclocks_req_info { + struct ethnl_req_info base; +}; + +struct phc_vclocks_reply_data { + struct ethnl_reply_data base; + int num; + int *index; +}; + +#define PHC_VCLOCKS_REPDATA(__reply_base) \ + container_of(__reply_base, struct phc_vclocks_reply_data, base) + +const struct nla_policy ethnl_phc_vclocks_get_policy[] = { + [ETHTOOL_A_PHC_VCLOCKS_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int phc_vclocks_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct phc_vclocks_reply_data *data = PHC_VCLOCKS_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + data->num = ethtool_get_phc_vclocks(dev, &data->index); + ethnl_ops_complete(dev); + + return ret; +} + +static int phc_vclocks_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct phc_vclocks_reply_data *data = + PHC_VCLOCKS_REPDATA(reply_base); + int len = 0; + + if (data->num > 0) { + len += nla_total_size(sizeof(u32)); + len += nla_total_size(sizeof(s32) * data->num); + } + + return len; +} + +static int phc_vclocks_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct phc_vclocks_reply_data *data = + PHC_VCLOCKS_REPDATA(reply_base); + + if
(data->num <= 0) + return 0; + + if (nla_put_u32(skb, ETHTOOL_A_PHC_VCLOCKS_NUM, data->num) || + nla_put(skb, ETHTOOL_A_PHC_VCLOCKS_INDEX, + sizeof(s32) * data->num, data->index)) + return -EMSGSIZE; + + return 0; +} + +static void phc_vclocks_cleanup_data(struct ethnl_reply_data *reply_base) +{ + const struct phc_vclocks_reply_data *data = + PHC_VCLOCKS_REPDATA(reply_base); + + kfree(data->index); +} + +const struct ethnl_request_ops ethnl_phc_vclocks_request_ops = { + .request_cmd = ETHTOOL_MSG_PHC_VCLOCKS_GET, + .reply_cmd = ETHTOOL_MSG_PHC_VCLOCKS_GET_REPLY, + .hdr_attr = ETHTOOL_A_PHC_VCLOCKS_HEADER, + .req_info_size = sizeof(struct phc_vclocks_req_info), + .reply_data_size = sizeof(struct phc_vclocks_reply_data), + + .prepare_data = phc_vclocks_prepare_data, + .reply_size = phc_vclocks_reply_size, + .fill_reply = phc_vclocks_fill_reply, + .cleanup_data = phc_vclocks_cleanup_data, +}; diff --git a/net/ethtool/privflags.c b/net/ethtool/privflags.c new file mode 100644 index 000000000..4c7bfa81e --- /dev/null +++ b/net/ethtool/privflags.c @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct privflags_req_info { + struct ethnl_req_info base; +}; + +struct privflags_reply_data { + struct ethnl_reply_data base; + const char (*priv_flag_names)[ETH_GSTRING_LEN]; + unsigned int n_priv_flags; + u32 priv_flags; +}; + +#define PRIVFLAGS_REPDATA(__reply_base) \ + container_of(__reply_base, struct privflags_reply_data, base) + +const struct nla_policy ethnl_privflags_get_policy[] = { + [ETHTOOL_A_PRIVFLAGS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int ethnl_get_priv_flags_info(struct net_device *dev, + unsigned int *count, + const char (**names)[ETH_GSTRING_LEN]) +{ + const struct ethtool_ops *ops = dev->ethtool_ops; + int nflags; + + nflags = ops->get_sset_count(dev, ETH_SS_PRIV_FLAGS); + if (nflags < 0) + return nflags; + + if (names) { + *names = kcalloc(nflags, ETH_GSTRING_LEN, GFP_KERNEL); + if (!*names) + return -ENOMEM; + ops->get_strings(dev, ETH_SS_PRIV_FLAGS, (u8 *)*names); + } + + /* We can pass more than 32 private flags to userspace via netlink but + * we cannot get more with ethtool_ops::get_priv_flags(). Note that we + * must not adjust nflags before allocating the space for flag names + * as the buffer must be large enough for all flags. 
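+ * (For example, a driver reporting 40 private flags still gets all 40 names
+ * copied into the buffer allocated above; only the clamping below limits
+ * what is exposed through the 32 bit flag word.)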
+ */ + if (WARN_ONCE(nflags > 32, + "device %s reports more than 32 private flags (%d)\n", + netdev_name(dev), nflags)) + nflags = 32; + *count = nflags; + + return 0; +} + +static int privflags_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct privflags_reply_data *data = PRIVFLAGS_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + const char (*names)[ETH_GSTRING_LEN]; + const struct ethtool_ops *ops; + unsigned int nflags; + int ret; + + ops = dev->ethtool_ops; + if (!ops->get_priv_flags || !ops->get_sset_count || !ops->get_strings) + return -EOPNOTSUPP; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + ret = ethnl_get_priv_flags_info(dev, &nflags, &names); + if (ret < 0) + goto out_ops; + data->priv_flags = ops->get_priv_flags(dev); + data->priv_flag_names = names; + data->n_priv_flags = nflags; + +out_ops: + ethnl_ops_complete(dev); + return ret; +} + +static int privflags_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct privflags_reply_data *data = PRIVFLAGS_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const u32 all_flags = ~(u32)0 >> (32 - data->n_priv_flags); + + return ethnl_bitset32_size(&data->priv_flags, &all_flags, + data->n_priv_flags, + data->priv_flag_names, compact); +} + +static int privflags_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct privflags_reply_data *data = PRIVFLAGS_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const u32 all_flags = ~(u32)0 >> (32 - data->n_priv_flags); + + return ethnl_put_bitset32(skb, ETHTOOL_A_PRIVFLAGS_FLAGS, + &data->priv_flags, &all_flags, + data->n_priv_flags, data->priv_flag_names, + compact); +} + +static void privflags_cleanup_data(struct ethnl_reply_data *reply_data) +{ + struct privflags_reply_data *data = PRIVFLAGS_REPDATA(reply_data); + + kfree(data->priv_flag_names); +} + +const struct ethnl_request_ops ethnl_privflags_request_ops = { + .request_cmd = ETHTOOL_MSG_PRIVFLAGS_GET, + .reply_cmd = ETHTOOL_MSG_PRIVFLAGS_GET_REPLY, + .hdr_attr = ETHTOOL_A_PRIVFLAGS_HEADER, + .req_info_size = sizeof(struct privflags_req_info), + .reply_data_size = sizeof(struct privflags_reply_data), + + .prepare_data = privflags_prepare_data, + .reply_size = privflags_reply_size, + .fill_reply = privflags_fill_reply, + .cleanup_data = privflags_cleanup_data, +}; + +/* PRIVFLAGS_SET */ + +const struct nla_policy ethnl_privflags_set_policy[] = { + [ETHTOOL_A_PRIVFLAGS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_PRIVFLAGS_FLAGS] = { .type = NLA_NESTED }, +}; + +int ethnl_set_privflags(struct sk_buff *skb, struct genl_info *info) +{ + const char (*names)[ETH_GSTRING_LEN] = NULL; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + const struct ethtool_ops *ops; + struct net_device *dev; + unsigned int nflags; + bool mod = false; + bool compact; + u32 flags; + int ret; + + if (!tb[ETHTOOL_A_PRIVFLAGS_FLAGS]) + return -EINVAL; + ret = ethnl_bitset_is_compact(tb[ETHTOOL_A_PRIVFLAGS_FLAGS], &compact); + if (ret < 0) + return ret; + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_PRIVFLAGS_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_priv_flags 
|| !ops->set_priv_flags || + !ops->get_sset_count || !ops->get_strings) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ret = ethnl_get_priv_flags_info(dev, &nflags, compact ? NULL : &names); + if (ret < 0) + goto out_ops; + flags = ops->get_priv_flags(dev); + + ret = ethnl_update_bitset32(&flags, nflags, + tb[ETHTOOL_A_PRIVFLAGS_FLAGS], names, + info->extack, &mod); + if (ret < 0 || !mod) + goto out_free; + ret = ops->set_priv_flags(dev, flags); + if (ret < 0) + goto out_free; + ethtool_notify(dev, ETHTOOL_MSG_PRIVFLAGS_NTF, NULL); + +out_free: + kfree(names); +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/pse-pd.c b/net/ethtool/pse-pd.c new file mode 100644 index 000000000..e8683e485 --- /dev/null +++ b/net/ethtool/pse-pd.c @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: GPL-2.0-only +// +// ethtool interface for Ethernet PSE (Power Sourcing Equipment) +// and PD (Powered Device) +// +// Copyright (c) 2022 Pengutronix, Oleksij Rempel +// + +#include "common.h" +#include "linux/pse-pd/pse.h" +#include "netlink.h" +#include +#include +#include + +struct pse_req_info { + struct ethnl_req_info base; +}; + +struct pse_reply_data { + struct ethnl_reply_data base; + struct pse_control_status status; +}; + +#define PSE_REPDATA(__reply_base) \ + container_of(__reply_base, struct pse_reply_data, base) + +/* PSE_GET */ + +const struct nla_policy ethnl_pse_get_policy[ETHTOOL_A_PSE_HEADER + 1] = { + [ETHTOOL_A_PSE_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int pse_get_pse_attributes(struct net_device *dev, + struct netlink_ext_ack *extack, + struct pse_reply_data *data) +{ + struct phy_device *phydev = dev->phydev; + + if (!phydev) { + NL_SET_ERR_MSG(extack, "No PHY is attached"); + return -EOPNOTSUPP; + } + + if (!phydev->psec) { + NL_SET_ERR_MSG(extack, "No PSE is attached"); + return -EOPNOTSUPP; + } + + memset(&data->status, 0, sizeof(data->status)); + + return pse_ethtool_get_status(phydev->psec, extack, &data->status); +} + +static int pse_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct pse_reply_data *data = PSE_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + ret = pse_get_pse_attributes(dev, info ?
info->extack : NULL, data); + + ethnl_ops_complete(dev); + + return ret; +} + +static int pse_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct pse_reply_data *data = PSE_REPDATA(reply_base); + const struct pse_control_status *st = &data->status; + int len = 0; + + if (st->podl_admin_state > 0) + len += nla_total_size(sizeof(u32)); /* _PODL_PSE_ADMIN_STATE */ + if (st->podl_pw_status > 0) + len += nla_total_size(sizeof(u32)); /* _PODL_PSE_PW_D_STATUS */ + + return len; +} + +static int pse_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct pse_reply_data *data = PSE_REPDATA(reply_base); + const struct pse_control_status *st = &data->status; + + if (st->podl_admin_state > 0 && + nla_put_u32(skb, ETHTOOL_A_PODL_PSE_ADMIN_STATE, + st->podl_admin_state)) + return -EMSGSIZE; + + if (st->podl_pw_status > 0 && + nla_put_u32(skb, ETHTOOL_A_PODL_PSE_PW_D_STATUS, + st->podl_pw_status)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_pse_request_ops = { + .request_cmd = ETHTOOL_MSG_PSE_GET, + .reply_cmd = ETHTOOL_MSG_PSE_GET_REPLY, + .hdr_attr = ETHTOOL_A_PSE_HEADER, + .req_info_size = sizeof(struct pse_req_info), + .reply_data_size = sizeof(struct pse_reply_data), + + .prepare_data = pse_prepare_data, + .reply_size = pse_reply_size, + .fill_reply = pse_fill_reply, +}; + +/* PSE_SET */ + +const struct nla_policy ethnl_pse_set_policy[ETHTOOL_A_PSE_MAX + 1] = { + [ETHTOOL_A_PSE_HEADER] = NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_PODL_PSE_ADMIN_CONTROL] = + NLA_POLICY_RANGE(NLA_U32, ETHTOOL_PODL_PSE_ADMIN_STATE_DISABLED, + ETHTOOL_PODL_PSE_ADMIN_STATE_ENABLED), +}; + +static int pse_set_pse_config(struct net_device *dev, + struct netlink_ext_ack *extack, + struct nlattr **tb) +{ + struct phy_device *phydev = dev->phydev; + struct pse_control_config config = {}; + + /* Optional attribute. Do not return error if not set. 
*/ + if (!tb[ETHTOOL_A_PODL_PSE_ADMIN_CONTROL]) + return 0; + + /* these values are already validated by ethnl_pse_set_policy */ + config.admin_cotrol = nla_get_u32(tb[ETHTOOL_A_PODL_PSE_ADMIN_CONTROL]); + + if (!phydev) { + NL_SET_ERR_MSG(extack, "No PHY is attached"); + return -EOPNOTSUPP; + } + + if (!phydev->psec) { + NL_SET_ERR_MSG(extack, "No PSE is attached"); + return -EOPNOTSUPP; + } + + return pse_ethtool_set_config(phydev->psec, extack, &config); +} + +int ethnl_set_pse(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, tb[ETHTOOL_A_PSE_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + + dev = req_info.dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + ret = pse_set_pse_config(dev, info->extack, tb); + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); + + ethnl_parse_header_dev_put(&req_info); + + return ret; +} diff --git a/net/ethtool/rings.c b/net/ethtool/rings.c new file mode 100644 index 000000000..fa3ec8d43 --- /dev/null +++ b/net/ethtool/rings.c @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" + +struct rings_req_info { + struct ethnl_req_info base; +}; + +struct rings_reply_data { + struct ethnl_reply_data base; + struct ethtool_ringparam ringparam; + struct kernel_ethtool_ringparam kernel_ringparam; +}; + +#define RINGS_REPDATA(__reply_base) \ + container_of(__reply_base, struct rings_reply_data, base) + +const struct nla_policy ethnl_rings_get_policy[] = { + [ETHTOOL_A_RINGS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int rings_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct rings_reply_data *data = RINGS_REPDATA(reply_base); + struct netlink_ext_ack *extack = info ?
info->extack : NULL; + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_ringparam) + return -EOPNOTSUPP; + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + dev->ethtool_ops->get_ringparam(dev, &data->ringparam, + &data->kernel_ringparam, extack); + ethnl_ops_complete(dev); + + return 0; +} + +static int rings_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + return nla_total_size(sizeof(u32)) + /* _RINGS_RX_MAX */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX_MINI_MAX */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX_JUMBO_MAX */ + nla_total_size(sizeof(u32)) + /* _RINGS_TX_MAX */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX_MINI */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX_JUMBO */ + nla_total_size(sizeof(u32)) + /* _RINGS_TX */ + nla_total_size(sizeof(u32)) + /* _RINGS_RX_BUF_LEN */ + nla_total_size(sizeof(u8)) + /* _RINGS_TCP_DATA_SPLIT */ + nla_total_size(sizeof(u32) + /* _RINGS_CQE_SIZE */ + nla_total_size(sizeof(u8))); /* _RINGS_TX_PUSH */ +} + +static int rings_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct rings_reply_data *data = RINGS_REPDATA(reply_base); + const struct kernel_ethtool_ringparam *kr = &data->kernel_ringparam; + const struct ethtool_ringparam *ringparam = &data->ringparam; + + WARN_ON(kr->tcp_data_split > ETHTOOL_TCP_DATA_SPLIT_ENABLED); + + if ((ringparam->rx_max_pending && + (nla_put_u32(skb, ETHTOOL_A_RINGS_RX_MAX, + ringparam->rx_max_pending) || + nla_put_u32(skb, ETHTOOL_A_RINGS_RX, + ringparam->rx_pending))) || + (ringparam->rx_mini_max_pending && + (nla_put_u32(skb, ETHTOOL_A_RINGS_RX_MINI_MAX, + ringparam->rx_mini_max_pending) || + nla_put_u32(skb, ETHTOOL_A_RINGS_RX_MINI, + ringparam->rx_mini_pending))) || + (ringparam->rx_jumbo_max_pending && + (nla_put_u32(skb, ETHTOOL_A_RINGS_RX_JUMBO_MAX, + ringparam->rx_jumbo_max_pending) || + nla_put_u32(skb, ETHTOOL_A_RINGS_RX_JUMBO, + ringparam->rx_jumbo_pending))) || + (ringparam->tx_max_pending && + (nla_put_u32(skb, ETHTOOL_A_RINGS_TX_MAX, + ringparam->tx_max_pending) || + nla_put_u32(skb, ETHTOOL_A_RINGS_TX, + ringparam->tx_pending))) || + (kr->rx_buf_len && + (nla_put_u32(skb, ETHTOOL_A_RINGS_RX_BUF_LEN, kr->rx_buf_len))) || + (kr->tcp_data_split && + (nla_put_u8(skb, ETHTOOL_A_RINGS_TCP_DATA_SPLIT, + kr->tcp_data_split))) || + (kr->cqe_size && + (nla_put_u32(skb, ETHTOOL_A_RINGS_CQE_SIZE, kr->cqe_size))) || + nla_put_u8(skb, ETHTOOL_A_RINGS_TX_PUSH, !!kr->tx_push)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_rings_request_ops = { + .request_cmd = ETHTOOL_MSG_RINGS_GET, + .reply_cmd = ETHTOOL_MSG_RINGS_GET_REPLY, + .hdr_attr = ETHTOOL_A_RINGS_HEADER, + .req_info_size = sizeof(struct rings_req_info), + .reply_data_size = sizeof(struct rings_reply_data), + + .prepare_data = rings_prepare_data, + .reply_size = rings_reply_size, + .fill_reply = rings_fill_reply, +}; + +/* RINGS_SET */ + +const struct nla_policy ethnl_rings_set_policy[] = { + [ETHTOOL_A_RINGS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_RINGS_RX] = { .type = NLA_U32 }, + [ETHTOOL_A_RINGS_RX_MINI] = { .type = NLA_U32 }, + [ETHTOOL_A_RINGS_RX_JUMBO] = { .type = NLA_U32 }, + [ETHTOOL_A_RINGS_TX] = { .type = NLA_U32 }, + [ETHTOOL_A_RINGS_RX_BUF_LEN] = NLA_POLICY_MIN(NLA_U32, 1), + [ETHTOOL_A_RINGS_CQE_SIZE] = NLA_POLICY_MIN(NLA_U32, 1), + [ETHTOOL_A_RINGS_TX_PUSH] = 
NLA_POLICY_MAX(NLA_U8, 1), +}; + +int ethnl_set_rings(struct sk_buff *skb, struct genl_info *info) +{ + struct kernel_ethtool_ringparam kernel_ringparam = {}; + struct ethtool_ringparam ringparam = {}; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + const struct nlattr *err_attr; + const struct ethtool_ops *ops; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_RINGS_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ops = dev->ethtool_ops; + ret = -EOPNOTSUPP; + if (!ops->get_ringparam || !ops->set_ringparam) + goto out_dev; + + if (tb[ETHTOOL_A_RINGS_RX_BUF_LEN] && + !(ops->supported_ring_params & ETHTOOL_RING_USE_RX_BUF_LEN)) { + ret = -EOPNOTSUPP; + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_RINGS_RX_BUF_LEN], + "setting rx buf len not supported"); + goto out_dev; + } + + if (tb[ETHTOOL_A_RINGS_CQE_SIZE] && + !(ops->supported_ring_params & ETHTOOL_RING_USE_CQE_SIZE)) { + ret = -EOPNOTSUPP; + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_RINGS_CQE_SIZE], + "setting cqe size not supported"); + goto out_dev; + } + + if (tb[ETHTOOL_A_RINGS_TX_PUSH] && + !(ops->supported_ring_params & ETHTOOL_RING_USE_TX_PUSH)) { + ret = -EOPNOTSUPP; + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_RINGS_TX_PUSH], + "setting tx push not supported"); + goto out_dev; + } + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + ops->get_ringparam(dev, &ringparam, &kernel_ringparam, info->extack); + + ethnl_update_u32(&ringparam.rx_pending, tb[ETHTOOL_A_RINGS_RX], &mod); + ethnl_update_u32(&ringparam.rx_mini_pending, + tb[ETHTOOL_A_RINGS_RX_MINI], &mod); + ethnl_update_u32(&ringparam.rx_jumbo_pending, + tb[ETHTOOL_A_RINGS_RX_JUMBO], &mod); + ethnl_update_u32(&ringparam.tx_pending, tb[ETHTOOL_A_RINGS_TX], &mod); + ethnl_update_u32(&kernel_ringparam.rx_buf_len, + tb[ETHTOOL_A_RINGS_RX_BUF_LEN], &mod); + ethnl_update_u32(&kernel_ringparam.cqe_size, + tb[ETHTOOL_A_RINGS_CQE_SIZE], &mod); + ethnl_update_u8(&kernel_ringparam.tx_push, + tb[ETHTOOL_A_RINGS_TX_PUSH], &mod); + ret = 0; + if (!mod) + goto out_ops; + + /* ensure new ring parameters are within limits */ + if (ringparam.rx_pending > ringparam.rx_max_pending) + err_attr = tb[ETHTOOL_A_RINGS_RX]; + else if (ringparam.rx_mini_pending > ringparam.rx_mini_max_pending) + err_attr = tb[ETHTOOL_A_RINGS_RX_MINI]; + else if (ringparam.rx_jumbo_pending > ringparam.rx_jumbo_max_pending) + err_attr = tb[ETHTOOL_A_RINGS_RX_JUMBO]; + else if (ringparam.tx_pending > ringparam.tx_max_pending) + err_attr = tb[ETHTOOL_A_RINGS_TX]; + else + err_attr = NULL; + if (err_attr) { + ret = -EINVAL; + NL_SET_ERR_MSG_ATTR(info->extack, err_attr, + "requested ring size exceeds maximum"); + goto out_ops; + } + + ret = dev->ethtool_ops->set_ringparam(dev, &ringparam, + &kernel_ringparam, info->extack); + if (ret < 0) + goto out_ops; + ethtool_notify(dev, ETHTOOL_MSG_RINGS_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/ethtool/stats.c b/net/ethtool/stats.c new file mode 100644 index 000000000..a20e0a24f --- /dev/null +++ b/net/ethtool/stats.c @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct stats_req_info { + struct ethnl_req_info base; + DECLARE_BITMAP(stat_mask, __ETHTOOL_STATS_CNT); +}; + +#define 
STATS_REQINFO(__req_base) \ + container_of(__req_base, struct stats_req_info, base) + +struct stats_reply_data { + struct ethnl_reply_data base; + struct_group(stats, + struct ethtool_eth_phy_stats phy_stats; + struct ethtool_eth_mac_stats mac_stats; + struct ethtool_eth_ctrl_stats ctrl_stats; + struct ethtool_rmon_stats rmon_stats; + ); + const struct ethtool_rmon_hist_range *rmon_ranges; +}; + +#define STATS_REPDATA(__reply_base) \ + container_of(__reply_base, struct stats_reply_data, base) + +const char stats_std_names[__ETHTOOL_STATS_CNT][ETH_GSTRING_LEN] = { + [ETHTOOL_STATS_ETH_PHY] = "eth-phy", + [ETHTOOL_STATS_ETH_MAC] = "eth-mac", + [ETHTOOL_STATS_ETH_CTRL] = "eth-ctrl", + [ETHTOOL_STATS_RMON] = "rmon", +}; + +const char stats_eth_phy_names[__ETHTOOL_A_STATS_ETH_PHY_CNT][ETH_GSTRING_LEN] = { + [ETHTOOL_A_STATS_ETH_PHY_5_SYM_ERR] = "SymbolErrorDuringCarrier", +}; + +const char stats_eth_mac_names[__ETHTOOL_A_STATS_ETH_MAC_CNT][ETH_GSTRING_LEN] = { + [ETHTOOL_A_STATS_ETH_MAC_2_TX_PKT] = "FramesTransmittedOK", + [ETHTOOL_A_STATS_ETH_MAC_3_SINGLE_COL] = "SingleCollisionFrames", + [ETHTOOL_A_STATS_ETH_MAC_4_MULTI_COL] = "MultipleCollisionFrames", + [ETHTOOL_A_STATS_ETH_MAC_5_RX_PKT] = "FramesReceivedOK", + [ETHTOOL_A_STATS_ETH_MAC_6_FCS_ERR] = "FrameCheckSequenceErrors", + [ETHTOOL_A_STATS_ETH_MAC_7_ALIGN_ERR] = "AlignmentErrors", + [ETHTOOL_A_STATS_ETH_MAC_8_TX_BYTES] = "OctetsTransmittedOK", + [ETHTOOL_A_STATS_ETH_MAC_9_TX_DEFER] = "FramesWithDeferredXmissions", + [ETHTOOL_A_STATS_ETH_MAC_10_LATE_COL] = "LateCollisions", + [ETHTOOL_A_STATS_ETH_MAC_11_XS_COL] = "FramesAbortedDueToXSColls", + [ETHTOOL_A_STATS_ETH_MAC_12_TX_INT_ERR] = "FramesLostDueToIntMACXmitError", + [ETHTOOL_A_STATS_ETH_MAC_13_CS_ERR] = "CarrierSenseErrors", + [ETHTOOL_A_STATS_ETH_MAC_14_RX_BYTES] = "OctetsReceivedOK", + [ETHTOOL_A_STATS_ETH_MAC_15_RX_INT_ERR] = "FramesLostDueToIntMACRcvError", + [ETHTOOL_A_STATS_ETH_MAC_18_TX_MCAST] = "MulticastFramesXmittedOK", + [ETHTOOL_A_STATS_ETH_MAC_19_TX_BCAST] = "BroadcastFramesXmittedOK", + [ETHTOOL_A_STATS_ETH_MAC_20_XS_DEFER] = "FramesWithExcessiveDeferral", + [ETHTOOL_A_STATS_ETH_MAC_21_RX_MCAST] = "MulticastFramesReceivedOK", + [ETHTOOL_A_STATS_ETH_MAC_22_RX_BCAST] = "BroadcastFramesReceivedOK", + [ETHTOOL_A_STATS_ETH_MAC_23_IR_LEN_ERR] = "InRangeLengthErrors", + [ETHTOOL_A_STATS_ETH_MAC_24_OOR_LEN] = "OutOfRangeLengthField", + [ETHTOOL_A_STATS_ETH_MAC_25_TOO_LONG_ERR] = "FrameTooLongErrors", +}; + +const char stats_eth_ctrl_names[__ETHTOOL_A_STATS_ETH_CTRL_CNT][ETH_GSTRING_LEN] = { + [ETHTOOL_A_STATS_ETH_CTRL_3_TX] = "MACControlFramesTransmitted", + [ETHTOOL_A_STATS_ETH_CTRL_4_RX] = "MACControlFramesReceived", + [ETHTOOL_A_STATS_ETH_CTRL_5_RX_UNSUP] = "UnsupportedOpcodesReceived", +}; + +const char stats_rmon_names[__ETHTOOL_A_STATS_RMON_CNT][ETH_GSTRING_LEN] = { + [ETHTOOL_A_STATS_RMON_UNDERSIZE] = "etherStatsUndersizePkts", + [ETHTOOL_A_STATS_RMON_OVERSIZE] = "etherStatsOversizePkts", + [ETHTOOL_A_STATS_RMON_FRAG] = "etherStatsFragments", + [ETHTOOL_A_STATS_RMON_JABBER] = "etherStatsJabbers", +}; + +const struct nla_policy ethnl_stats_get_policy[ETHTOOL_A_STATS_GROUPS + 1] = { + [ETHTOOL_A_STATS_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_STATS_GROUPS] = { .type = NLA_NESTED }, +}; + +static int stats_parse_request(struct ethnl_req_info *req_base, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct stats_req_info *req_info = STATS_REQINFO(req_base); + bool mod = false; + int err; + + err = ethnl_update_bitset(req_info->stat_mask, 
__ETHTOOL_STATS_CNT, + tb[ETHTOOL_A_STATS_GROUPS], stats_std_names, + extack, &mod); + if (err) + return err; + + if (!mod) { + NL_SET_ERR_MSG(extack, "no stats requested"); + return -EINVAL; + } + + return 0; +} + +static int stats_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + const struct stats_req_info *req_info = STATS_REQINFO(req_base); + struct stats_reply_data *data = STATS_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + + /* Mark all stats as unset (see ETHTOOL_STAT_NOT_SET) to prevent them + * from being reported to user space in case driver did not set them. + */ + memset(&data->stats, 0xff, sizeof(data->stats)); + + if (test_bit(ETHTOOL_STATS_ETH_PHY, req_info->stat_mask) && + dev->ethtool_ops->get_eth_phy_stats) + dev->ethtool_ops->get_eth_phy_stats(dev, &data->phy_stats); + if (test_bit(ETHTOOL_STATS_ETH_MAC, req_info->stat_mask) && + dev->ethtool_ops->get_eth_mac_stats) + dev->ethtool_ops->get_eth_mac_stats(dev, &data->mac_stats); + if (test_bit(ETHTOOL_STATS_ETH_CTRL, req_info->stat_mask) && + dev->ethtool_ops->get_eth_ctrl_stats) + dev->ethtool_ops->get_eth_ctrl_stats(dev, &data->ctrl_stats); + if (test_bit(ETHTOOL_STATS_RMON, req_info->stat_mask) && + dev->ethtool_ops->get_rmon_stats) + dev->ethtool_ops->get_rmon_stats(dev, &data->rmon_stats, + &data->rmon_ranges); + + ethnl_ops_complete(dev); + return 0; +} + +static int stats_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct stats_req_info *req_info = STATS_REQINFO(req_base); + unsigned int n_grps = 0, n_stats = 0; + int len = 0; + + if (test_bit(ETHTOOL_STATS_ETH_PHY, req_info->stat_mask)) { + n_stats += sizeof(struct ethtool_eth_phy_stats) / sizeof(u64); + n_grps++; + } + if (test_bit(ETHTOOL_STATS_ETH_MAC, req_info->stat_mask)) { + n_stats += sizeof(struct ethtool_eth_mac_stats) / sizeof(u64); + n_grps++; + } + if (test_bit(ETHTOOL_STATS_ETH_CTRL, req_info->stat_mask)) { + n_stats += sizeof(struct ethtool_eth_ctrl_stats) / sizeof(u64); + n_grps++; + } + if (test_bit(ETHTOOL_STATS_RMON, req_info->stat_mask)) { + n_stats += sizeof(struct ethtool_rmon_stats) / sizeof(u64); + n_grps++; + /* Above includes the space for _A_STATS_GRP_HIST_VALs */ + + len += (nla_total_size(0) + /* _A_STATS_GRP_HIST */ + nla_total_size(4) + /* _A_STATS_GRP_HIST_BKT_LOW */ + nla_total_size(4)) * /* _A_STATS_GRP_HIST_BKT_HI */ + ETHTOOL_RMON_HIST_MAX * 2; + } + + len += n_grps * (nla_total_size(0) + /* _A_STATS_GRP */ + nla_total_size(4) + /* _A_STATS_GRP_ID */ + nla_total_size(4)); /* _A_STATS_GRP_SS_ID */ + len += n_stats * (nla_total_size(0) + /* _A_STATS_GRP_STAT */ + nla_total_size_64bit(sizeof(u64))); + + return len; +} + +static int stat_put(struct sk_buff *skb, u16 attrtype, u64 val) +{ + struct nlattr *nest; + int ret; + + if (val == ETHTOOL_STAT_NOT_SET) + return 0; + + /* We want to start stats attr types from 0, so we don't have a type + * for pad inside ETHTOOL_A_STATS_GRP_STAT. Pad things on the outside + * of ETHTOOL_A_STATS_GRP_STAT. Since we're one nest away from the + * actual attr we're 4B off - nla_need_padding_for_64bit() & co. + * can't be used. 
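+ * (Each nlattr header is 4 bytes: the u64 payload sits behind two headers,
+ * the ETHTOOL_A_STATS_GRP_STAT nest header and the stat attribute header,
+ * so it lands at the current tail + 8. Padding is therefore needed when the
+ * tail is *not* 8-byte aligned, the inverse of what
+ * nla_need_padding_for_64bit() tests.)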
+ */ +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + if (!IS_ALIGNED((unsigned long)skb_tail_pointer(skb), 8)) + if (!nla_reserve(skb, ETHTOOL_A_STATS_GRP_PAD, 0)) + return -EMSGSIZE; +#endif + + nest = nla_nest_start(skb, ETHTOOL_A_STATS_GRP_STAT); + if (!nest) + return -EMSGSIZE; + + ret = nla_put_u64_64bit(skb, attrtype, val, -1 /* not used */); + if (ret) { + nla_nest_cancel(skb, nest); + return ret; + } + + nla_nest_end(skb, nest); + return 0; +} + +static int stats_put_phy_stats(struct sk_buff *skb, + const struct stats_reply_data *data) +{ + if (stat_put(skb, ETHTOOL_A_STATS_ETH_PHY_5_SYM_ERR, + data->phy_stats.SymbolErrorDuringCarrier)) + return -EMSGSIZE; + return 0; +} + +static int stats_put_mac_stats(struct sk_buff *skb, + const struct stats_reply_data *data) +{ + if (stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_2_TX_PKT, + data->mac_stats.FramesTransmittedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_3_SINGLE_COL, + data->mac_stats.SingleCollisionFrames) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_4_MULTI_COL, + data->mac_stats.MultipleCollisionFrames) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_5_RX_PKT, + data->mac_stats.FramesReceivedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_6_FCS_ERR, + data->mac_stats.FrameCheckSequenceErrors) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_7_ALIGN_ERR, + data->mac_stats.AlignmentErrors) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_8_TX_BYTES, + data->mac_stats.OctetsTransmittedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_9_TX_DEFER, + data->mac_stats.FramesWithDeferredXmissions) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_10_LATE_COL, + data->mac_stats.LateCollisions) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_11_XS_COL, + data->mac_stats.FramesAbortedDueToXSColls) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_12_TX_INT_ERR, + data->mac_stats.FramesLostDueToIntMACXmitError) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_13_CS_ERR, + data->mac_stats.CarrierSenseErrors) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_14_RX_BYTES, + data->mac_stats.OctetsReceivedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_15_RX_INT_ERR, + data->mac_stats.FramesLostDueToIntMACRcvError) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_18_TX_MCAST, + data->mac_stats.MulticastFramesXmittedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_19_TX_BCAST, + data->mac_stats.BroadcastFramesXmittedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_20_XS_DEFER, + data->mac_stats.FramesWithExcessiveDeferral) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_21_RX_MCAST, + data->mac_stats.MulticastFramesReceivedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_22_RX_BCAST, + data->mac_stats.BroadcastFramesReceivedOK) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_23_IR_LEN_ERR, + data->mac_stats.InRangeLengthErrors) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_24_OOR_LEN, + data->mac_stats.OutOfRangeLengthField) || + stat_put(skb, ETHTOOL_A_STATS_ETH_MAC_25_TOO_LONG_ERR, + data->mac_stats.FrameTooLongErrors)) + return -EMSGSIZE; + return 0; +} + +static int stats_put_ctrl_stats(struct sk_buff *skb, + const struct stats_reply_data *data) +{ + if (stat_put(skb, ETHTOOL_A_STATS_ETH_CTRL_3_TX, + data->ctrl_stats.MACControlFramesTransmitted) || + stat_put(skb, ETHTOOL_A_STATS_ETH_CTRL_4_RX, + data->ctrl_stats.MACControlFramesReceived) || + stat_put(skb, ETHTOOL_A_STATS_ETH_CTRL_5_RX_UNSUP, + data->ctrl_stats.UnsupportedOpcodesReceived)) + return -EMSGSIZE; + return 0; +} + +static int stats_put_rmon_hist(struct sk_buff *skb, u32 attr, const u64 *hist, + const struct ethtool_rmon_hist_range *ranges) +{ + struct nlattr 
*nest; + int i; + + if (!ranges) + return 0; + + for (i = 0; i < ETHTOOL_RMON_HIST_MAX; i++) { + if (!ranges[i].low && !ranges[i].high) + break; + if (hist[i] == ETHTOOL_STAT_NOT_SET) + continue; + + nest = nla_nest_start(skb, attr); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, ETHTOOL_A_STATS_GRP_HIST_BKT_LOW, + ranges[i].low) || + nla_put_u32(skb, ETHTOOL_A_STATS_GRP_HIST_BKT_HI, + ranges[i].high) || + nla_put_u64_64bit(skb, ETHTOOL_A_STATS_GRP_HIST_VAL, + hist[i], ETHTOOL_A_STATS_GRP_PAD)) + goto err_cancel_hist; + + nla_nest_end(skb, nest); + } + + return 0; + +err_cancel_hist: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int stats_put_rmon_stats(struct sk_buff *skb, + const struct stats_reply_data *data) +{ + if (stats_put_rmon_hist(skb, ETHTOOL_A_STATS_GRP_HIST_RX, + data->rmon_stats.hist, data->rmon_ranges) || + stats_put_rmon_hist(skb, ETHTOOL_A_STATS_GRP_HIST_TX, + data->rmon_stats.hist_tx, data->rmon_ranges)) + return -EMSGSIZE; + + if (stat_put(skb, ETHTOOL_A_STATS_RMON_UNDERSIZE, + data->rmon_stats.undersize_pkts) || + stat_put(skb, ETHTOOL_A_STATS_RMON_OVERSIZE, + data->rmon_stats.oversize_pkts) || + stat_put(skb, ETHTOOL_A_STATS_RMON_FRAG, + data->rmon_stats.fragments) || + stat_put(skb, ETHTOOL_A_STATS_RMON_JABBER, + data->rmon_stats.jabbers)) + return -EMSGSIZE; + + return 0; +} + +static int stats_put_stats(struct sk_buff *skb, + const struct stats_reply_data *data, + u32 id, u32 ss_id, + int (*cb)(struct sk_buff *skb, + const struct stats_reply_data *data)) +{ + struct nlattr *nest; + + nest = nla_nest_start(skb, ETHTOOL_A_STATS_GRP); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, ETHTOOL_A_STATS_GRP_ID, id) || + nla_put_u32(skb, ETHTOOL_A_STATS_GRP_SS_ID, ss_id)) + goto err_cancel; + + if (cb(skb, data)) + goto err_cancel; + + nla_nest_end(skb, nest); + return 0; + +err_cancel: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int stats_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct stats_req_info *req_info = STATS_REQINFO(req_base); + const struct stats_reply_data *data = STATS_REPDATA(reply_base); + int ret = 0; + + if (!ret && test_bit(ETHTOOL_STATS_ETH_PHY, req_info->stat_mask)) + ret = stats_put_stats(skb, data, ETHTOOL_STATS_ETH_PHY, + ETH_SS_STATS_ETH_PHY, + stats_put_phy_stats); + if (!ret && test_bit(ETHTOOL_STATS_ETH_MAC, req_info->stat_mask)) + ret = stats_put_stats(skb, data, ETHTOOL_STATS_ETH_MAC, + ETH_SS_STATS_ETH_MAC, + stats_put_mac_stats); + if (!ret && test_bit(ETHTOOL_STATS_ETH_CTRL, req_info->stat_mask)) + ret = stats_put_stats(skb, data, ETHTOOL_STATS_ETH_CTRL, + ETH_SS_STATS_ETH_CTRL, + stats_put_ctrl_stats); + if (!ret && test_bit(ETHTOOL_STATS_RMON, req_info->stat_mask)) + ret = stats_put_stats(skb, data, ETHTOOL_STATS_RMON, + ETH_SS_STATS_RMON, stats_put_rmon_stats); + + return ret; +} + +const struct ethnl_request_ops ethnl_stats_request_ops = { + .request_cmd = ETHTOOL_MSG_STATS_GET, + .reply_cmd = ETHTOOL_MSG_STATS_GET_REPLY, + .hdr_attr = ETHTOOL_A_STATS_HEADER, + .req_info_size = sizeof(struct stats_req_info), + .reply_data_size = sizeof(struct stats_reply_data), + + .parse_request = stats_parse_request, + .prepare_data = stats_prepare_data, + .reply_size = stats_reply_size, + .fill_reply = stats_fill_reply, +}; diff --git a/net/ethtool/strset.c b/net/ethtool/strset.c new file mode 100644 index 000000000..3f7de54d8 --- /dev/null +++ b/net/ethtool/strset.c @@ -0,0 +1,480 @@ +// 
SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include "netlink.h" +#include "common.h" + +struct strset_info { + bool per_dev; + bool free_strings; + unsigned int count; + const char (*strings)[ETH_GSTRING_LEN]; +}; + +static const struct strset_info info_template[] = { + [ETH_SS_TEST] = { + .per_dev = true, + }, + [ETH_SS_STATS] = { + .per_dev = true, + }, + [ETH_SS_PRIV_FLAGS] = { + .per_dev = true, + }, + [ETH_SS_FEATURES] = { + .per_dev = false, + .count = ARRAY_SIZE(netdev_features_strings), + .strings = netdev_features_strings, + }, + [ETH_SS_RSS_HASH_FUNCS] = { + .per_dev = false, + .count = ARRAY_SIZE(rss_hash_func_strings), + .strings = rss_hash_func_strings, + }, + [ETH_SS_TUNABLES] = { + .per_dev = false, + .count = ARRAY_SIZE(tunable_strings), + .strings = tunable_strings, + }, + [ETH_SS_PHY_STATS] = { + .per_dev = true, + }, + [ETH_SS_PHY_TUNABLES] = { + .per_dev = false, + .count = ARRAY_SIZE(phy_tunable_strings), + .strings = phy_tunable_strings, + }, + [ETH_SS_LINK_MODES] = { + .per_dev = false, + .count = __ETHTOOL_LINK_MODE_MASK_NBITS, + .strings = link_mode_names, + }, + [ETH_SS_MSG_CLASSES] = { + .per_dev = false, + .count = NETIF_MSG_CLASS_COUNT, + .strings = netif_msg_class_names, + }, + [ETH_SS_WOL_MODES] = { + .per_dev = false, + .count = WOL_MODE_COUNT, + .strings = wol_mode_names, + }, + [ETH_SS_SOF_TIMESTAMPING] = { + .per_dev = false, + .count = __SOF_TIMESTAMPING_CNT, + .strings = sof_timestamping_names, + }, + [ETH_SS_TS_TX_TYPES] = { + .per_dev = false, + .count = __HWTSTAMP_TX_CNT, + .strings = ts_tx_type_names, + }, + [ETH_SS_TS_RX_FILTERS] = { + .per_dev = false, + .count = __HWTSTAMP_FILTER_CNT, + .strings = ts_rx_filter_names, + }, + [ETH_SS_UDP_TUNNEL_TYPES] = { + .per_dev = false, + .count = __ETHTOOL_UDP_TUNNEL_TYPE_CNT, + .strings = udp_tunnel_type_names, + }, + [ETH_SS_STATS_STD] = { + .per_dev = false, + .count = __ETHTOOL_STATS_CNT, + .strings = stats_std_names, + }, + [ETH_SS_STATS_ETH_PHY] = { + .per_dev = false, + .count = __ETHTOOL_A_STATS_ETH_PHY_CNT, + .strings = stats_eth_phy_names, + }, + [ETH_SS_STATS_ETH_MAC] = { + .per_dev = false, + .count = __ETHTOOL_A_STATS_ETH_MAC_CNT, + .strings = stats_eth_mac_names, + }, + [ETH_SS_STATS_ETH_CTRL] = { + .per_dev = false, + .count = __ETHTOOL_A_STATS_ETH_CTRL_CNT, + .strings = stats_eth_ctrl_names, + }, + [ETH_SS_STATS_RMON] = { + .per_dev = false, + .count = __ETHTOOL_A_STATS_RMON_CNT, + .strings = stats_rmon_names, + }, +}; + +struct strset_req_info { + struct ethnl_req_info base; + u32 req_ids; + bool counts_only; +}; + +#define STRSET_REQINFO(__req_base) \ + container_of(__req_base, struct strset_req_info, base) + +struct strset_reply_data { + struct ethnl_reply_data base; + struct strset_info sets[ETH_SS_COUNT]; +}; + +#define STRSET_REPDATA(__reply_base) \ + container_of(__reply_base, struct strset_reply_data, base) + +const struct nla_policy ethnl_strset_get_policy[] = { + [ETHTOOL_A_STRSET_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_STRSET_STRINGSETS] = { .type = NLA_NESTED }, + [ETHTOOL_A_STRSET_COUNTS_ONLY] = { .type = NLA_FLAG }, +}; + +static const struct nla_policy get_stringset_policy[] = { + [ETHTOOL_A_STRINGSET_ID] = { .type = NLA_U32 }, +}; + +/** + * strset_include() - test if a string set should be included in reply + * @info: parsed client request + * @data: pointer to request data structure + * @id: id of string set to check (ETH_SS_* constants) + */ +static bool strset_include(const struct strset_req_info *info, + const struct 
strset_reply_data *data, u32 id) +{ + bool per_dev; + + BUILD_BUG_ON(ETH_SS_COUNT >= BITS_PER_BYTE * sizeof(info->req_ids)); + + if (info->req_ids) + return info->req_ids & (1U << id); + per_dev = data->sets[id].per_dev; + if (!per_dev && !data->sets[id].strings) + return false; + + return data->base.dev ? per_dev : !per_dev; +} + +static int strset_get_id(const struct nlattr *nest, u32 *val, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(get_stringset_policy)]; + int ret; + + ret = nla_parse_nested(tb, ARRAY_SIZE(get_stringset_policy) - 1, nest, + get_stringset_policy, extack); + if (ret < 0) + return ret; + if (NL_REQ_ATTR_CHECK(extack, nest, tb, ETHTOOL_A_STRINGSET_ID)) + return -EINVAL; + + *val = nla_get_u32(tb[ETHTOOL_A_STRINGSET_ID]); + return 0; +} + +static const struct nla_policy strset_stringsets_policy[] = { + [ETHTOOL_A_STRINGSETS_STRINGSET] = { .type = NLA_NESTED }, +}; + +static int strset_parse_request(struct ethnl_req_info *req_base, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct strset_req_info *req_info = STRSET_REQINFO(req_base); + struct nlattr *nest = tb[ETHTOOL_A_STRSET_STRINGSETS]; + struct nlattr *attr; + int rem, ret; + + if (!nest) + return 0; + ret = nla_validate_nested(nest, + ARRAY_SIZE(strset_stringsets_policy) - 1, + strset_stringsets_policy, extack); + if (ret < 0) + return ret; + + req_info->counts_only = tb[ETHTOOL_A_STRSET_COUNTS_ONLY]; + nla_for_each_nested(attr, nest, rem) { + u32 id; + + if (WARN_ONCE(nla_type(attr) != ETHTOOL_A_STRINGSETS_STRINGSET, + "unexpected attrtype %u in ETHTOOL_A_STRSET_STRINGSETS\n", + nla_type(attr))) + return -EINVAL; + + ret = strset_get_id(attr, &id, extack); + if (ret < 0) + return ret; + if (id >= ETH_SS_COUNT) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "unknown string set id"); + return -EOPNOTSUPP; + } + + req_info->req_ids |= (1U << id); + } + + return 0; +} + +static void strset_cleanup_data(struct ethnl_reply_data *reply_base) +{ + struct strset_reply_data *data = STRSET_REPDATA(reply_base); + unsigned int i; + + for (i = 0; i < ETH_SS_COUNT; i++) + if (data->sets[i].free_strings) { + kfree(data->sets[i].strings); + data->sets[i].strings = NULL; + data->sets[i].free_strings = false; + } +} + +static int strset_prepare_set(struct strset_info *info, struct net_device *dev, + unsigned int id, bool counts_only) +{ + const struct ethtool_phy_ops *phy_ops = ethtool_phy_ops; + const struct ethtool_ops *ops = dev->ethtool_ops; + void *strings; + int count, ret; + + if (id == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats && phy_ops && + phy_ops->get_sset_count) + ret = phy_ops->get_sset_count(dev->phydev); + else if (ops->get_sset_count && ops->get_strings) + ret = ops->get_sset_count(dev, id); + else + ret = -EOPNOTSUPP; + if (ret <= 0) { + info->count = 0; + return 0; + } + + count = ret; + if (!counts_only) { + strings = kcalloc(count, ETH_GSTRING_LEN, GFP_KERNEL); + if (!strings) + return -ENOMEM; + if (id == ETH_SS_PHY_STATS && dev->phydev && + !ops->get_ethtool_phy_stats && phy_ops && + phy_ops->get_strings) + phy_ops->get_strings(dev->phydev, strings); + else + ops->get_strings(dev, id, strings); + info->strings = strings; + info->free_strings = true; + } + info->count = count; + + return 0; +} + +static int strset_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + const struct strset_req_info *req_info = STRSET_REQINFO(req_base); + struct strset_reply_data *data = 
STRSET_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + unsigned int i; + int ret; + + BUILD_BUG_ON(ARRAY_SIZE(info_template) != ETH_SS_COUNT); + memcpy(&data->sets, &info_template, sizeof(data->sets)); + + if (!dev) { + for (i = 0; i < ETH_SS_COUNT; i++) { + if ((req_info->req_ids & (1U << i)) && + data->sets[i].per_dev) { + if (info) + GENL_SET_ERR_MSG(info, "requested per device strings without dev"); + return -EINVAL; + } + } + return 0; + } + + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto err_strset; + for (i = 0; i < ETH_SS_COUNT; i++) { + if (!strset_include(req_info, data, i) || + !data->sets[i].per_dev) + continue; + + ret = strset_prepare_set(&data->sets[i], dev, i, + req_info->counts_only); + if (ret < 0) + goto err_ops; + } + ethnl_ops_complete(dev); + + return 0; +err_ops: + ethnl_ops_complete(dev); +err_strset: + strset_cleanup_data(reply_base); + return ret; +} + +/* calculate size of ETHTOOL_A_STRSET_STRINGSET nest for one string set */ +static int strset_set_size(const struct strset_info *info, bool counts_only) +{ + unsigned int len = 0; + unsigned int i; + + if (info->count == 0) + return 0; + if (counts_only) + return nla_total_size(2 * nla_total_size(sizeof(u32))); + + for (i = 0; i < info->count; i++) { + const char *str = info->strings[i]; + + /* ETHTOOL_A_STRING_INDEX, ETHTOOL_A_STRING_VALUE, nest */ + len += nla_total_size(nla_total_size(sizeof(u32)) + + ethnl_strz_size(str)); + } + /* ETHTOOL_A_STRINGSET_ID, ETHTOOL_A_STRINGSET_COUNT */ + len = 2 * nla_total_size(sizeof(u32)) + nla_total_size(len); + + return nla_total_size(len); +} + +static int strset_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct strset_req_info *req_info = STRSET_REQINFO(req_base); + const struct strset_reply_data *data = STRSET_REPDATA(reply_base); + unsigned int i; + int len = 0; + int ret; + + len += nla_total_size(0); /* ETHTOOL_A_STRSET_STRINGSETS */ + + for (i = 0; i < ETH_SS_COUNT; i++) { + const struct strset_info *set_info = &data->sets[i]; + + if (!strset_include(req_info, data, i)) + continue; + + ret = strset_set_size(set_info, req_info->counts_only); + if (ret < 0) + return ret; + len += ret; + } + + return len; +} + +/* fill one string into reply */ +static int strset_fill_string(struct sk_buff *skb, + const struct strset_info *set_info, u32 idx) +{ + struct nlattr *string_attr; + const char *value; + + value = set_info->strings[idx]; + + string_attr = nla_nest_start(skb, ETHTOOL_A_STRINGS_STRING); + if (!string_attr) + return -EMSGSIZE; + if (nla_put_u32(skb, ETHTOOL_A_STRING_INDEX, idx) || + ethnl_put_strz(skb, ETHTOOL_A_STRING_VALUE, value)) + goto nla_put_failure; + nla_nest_end(skb, string_attr); + + return 0; +nla_put_failure: + nla_nest_cancel(skb, string_attr); + return -EMSGSIZE; +} + +/* fill one string set into reply */ +static int strset_fill_set(struct sk_buff *skb, + const struct strset_info *set_info, u32 id, + bool counts_only) +{ + struct nlattr *stringset_attr; + struct nlattr *strings_attr; + unsigned int i; + + if (!set_info->per_dev && !set_info->strings) + return -EOPNOTSUPP; + if (set_info->count == 0) + return 0; + stringset_attr = nla_nest_start(skb, ETHTOOL_A_STRINGSETS_STRINGSET); + if (!stringset_attr) + return -EMSGSIZE; + + if (nla_put_u32(skb, ETHTOOL_A_STRINGSET_ID, id) || + nla_put_u32(skb, ETHTOOL_A_STRINGSET_COUNT, set_info->count)) + goto nla_put_failure; + + if (!counts_only) { + strings_attr = nla_nest_start(skb, ETHTOOL_A_STRINGSET_STRINGS); + if 
(!strings_attr) + goto nla_put_failure; + for (i = 0; i < set_info->count; i++) { + if (strset_fill_string(skb, set_info, i) < 0) + goto nla_put_failure; + } + nla_nest_end(skb, strings_attr); + } + + nla_nest_end(skb, stringset_attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, stringset_attr); + return -EMSGSIZE; +} + +static int strset_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct strset_req_info *req_info = STRSET_REQINFO(req_base); + const struct strset_reply_data *data = STRSET_REPDATA(reply_base); + struct nlattr *nest; + unsigned int i; + int ret; + + nest = nla_nest_start(skb, ETHTOOL_A_STRSET_STRINGSETS); + if (!nest) + return -EMSGSIZE; + + for (i = 0; i < ETH_SS_COUNT; i++) { + if (strset_include(req_info, data, i)) { + ret = strset_fill_set(skb, &data->sets[i], i, + req_info->counts_only); + if (ret < 0) + goto nla_put_failure; + } + } + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return ret; +} + +const struct ethnl_request_ops ethnl_strset_request_ops = { + .request_cmd = ETHTOOL_MSG_STRSET_GET, + .reply_cmd = ETHTOOL_MSG_STRSET_GET_REPLY, + .hdr_attr = ETHTOOL_A_STRSET_HEADER, + .req_info_size = sizeof(struct strset_req_info), + .reply_data_size = sizeof(struct strset_reply_data), + .allow_nodev_do = true, + + .parse_request = strset_parse_request, + .prepare_data = strset_prepare_data, + .reply_size = strset_reply_size, + .fill_reply = strset_fill_reply, + .cleanup_data = strset_cleanup_data, +}; diff --git a/net/ethtool/tsinfo.c b/net/ethtool/tsinfo.c new file mode 100644 index 000000000..63b5814bd --- /dev/null +++ b/net/ethtool/tsinfo.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct tsinfo_req_info { + struct ethnl_req_info base; +}; + +struct tsinfo_reply_data { + struct ethnl_reply_data base; + struct ethtool_ts_info ts_info; +}; + +#define TSINFO_REPDATA(__reply_base) \ + container_of(__reply_base, struct tsinfo_reply_data, base) + +const struct nla_policy ethnl_tsinfo_get_policy[] = { + [ETHTOOL_A_TSINFO_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int tsinfo_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct tsinfo_reply_data *data = TSINFO_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + ret = __ethtool_get_ts_info(dev, &data->ts_info); + ethnl_ops_complete(dev); + + return ret; +} + +static int tsinfo_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct tsinfo_reply_data *data = TSINFO_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct ethtool_ts_info *ts_info = &data->ts_info; + int len = 0; + int ret; + + BUILD_BUG_ON(__SOF_TIMESTAMPING_CNT > 32); + BUILD_BUG_ON(__HWTSTAMP_TX_CNT > 32); + BUILD_BUG_ON(__HWTSTAMP_FILTER_CNT > 32); + + if (ts_info->so_timestamping) { + ret = ethnl_bitset32_size(&ts_info->so_timestamping, NULL, + __SOF_TIMESTAMPING_CNT, + sof_timestamping_names, compact); + if (ret < 0) + return ret; + len += ret; /* _TSINFO_TIMESTAMPING */ + } + if (ts_info->tx_types) { + ret = ethnl_bitset32_size(&ts_info->tx_types, NULL, + __HWTSTAMP_TX_CNT, + ts_tx_type_names, compact); + if (ret < 0) + return ret; + len 
+= ret; /* _TSINFO_TX_TYPES */ + } + if (ts_info->rx_filters) { + ret = ethnl_bitset32_size(&ts_info->rx_filters, NULL, + __HWTSTAMP_FILTER_CNT, + ts_rx_filter_names, compact); + if (ret < 0) + return ret; + len += ret; /* _TSINFO_RX_FILTERS */ + } + if (ts_info->phc_index >= 0) + len += nla_total_size(sizeof(u32)); /* _TSINFO_PHC_INDEX */ + + return len; +} + +static int tsinfo_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + const struct tsinfo_reply_data *data = TSINFO_REPDATA(reply_base); + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct ethtool_ts_info *ts_info = &data->ts_info; + int ret; + + if (ts_info->so_timestamping) { + ret = ethnl_put_bitset32(skb, ETHTOOL_A_TSINFO_TIMESTAMPING, + &ts_info->so_timestamping, NULL, + __SOF_TIMESTAMPING_CNT, + sof_timestamping_names, compact); + if (ret < 0) + return ret; + } + if (ts_info->tx_types) { + ret = ethnl_put_bitset32(skb, ETHTOOL_A_TSINFO_TX_TYPES, + &ts_info->tx_types, NULL, + __HWTSTAMP_TX_CNT, + ts_tx_type_names, compact); + if (ret < 0) + return ret; + } + if (ts_info->rx_filters) { + ret = ethnl_put_bitset32(skb, ETHTOOL_A_TSINFO_RX_FILTERS, + &ts_info->rx_filters, NULL, + __HWTSTAMP_FILTER_CNT, + ts_rx_filter_names, compact); + if (ret < 0) + return ret; + } + if (ts_info->phc_index >= 0 && + nla_put_u32(skb, ETHTOOL_A_TSINFO_PHC_INDEX, ts_info->phc_index)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_tsinfo_request_ops = { + .request_cmd = ETHTOOL_MSG_TSINFO_GET, + .reply_cmd = ETHTOOL_MSG_TSINFO_GET_REPLY, + .hdr_attr = ETHTOOL_A_TSINFO_HEADER, + .req_info_size = sizeof(struct tsinfo_req_info), + .reply_data_size = sizeof(struct tsinfo_reply_data), + + .prepare_data = tsinfo_prepare_data, + .reply_size = tsinfo_reply_size, + .fill_reply = tsinfo_fill_reply, +}; diff --git a/net/ethtool/tunnels.c b/net/ethtool/tunnels.c new file mode 100644 index 000000000..67fb414ca --- /dev/null +++ b/net/ethtool/tunnels.c @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include + +#include "bitset.h" +#include "common.h" +#include "netlink.h" + +const struct nla_policy ethnl_tunnel_info_get_policy[] = { + [ETHTOOL_A_TUNNEL_INFO_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static_assert(ETHTOOL_UDP_TUNNEL_TYPE_VXLAN == ilog2(UDP_TUNNEL_TYPE_VXLAN)); +static_assert(ETHTOOL_UDP_TUNNEL_TYPE_GENEVE == ilog2(UDP_TUNNEL_TYPE_GENEVE)); +static_assert(ETHTOOL_UDP_TUNNEL_TYPE_VXLAN_GPE == + ilog2(UDP_TUNNEL_TYPE_VXLAN_GPE)); + +static ssize_t ethnl_udp_table_reply_size(unsigned int types, bool compact) +{ + ssize_t size; + + size = ethnl_bitset32_size(&types, NULL, __ETHTOOL_UDP_TUNNEL_TYPE_CNT, + udp_tunnel_type_names, compact); + if (size < 0) + return size; + + return size + + nla_total_size(0) + /* _UDP_TABLE */ + nla_total_size(sizeof(u32)); /* _UDP_TABLE_SIZE */ +} + +static ssize_t +ethnl_tunnel_info_reply_size(const struct ethnl_req_info *req_base, + struct netlink_ext_ack *extack) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct udp_tunnel_nic_info *info; + unsigned int i; + ssize_t ret; + size_t size; + + info = req_base->dev->udp_tunnel_nic_info; + if (!info) { + NL_SET_ERR_MSG(extack, + "device does not report tunnel offload info"); + return -EOPNOTSUPP; + } + + size = nla_total_size(0); /* _INFO_UDP_PORTS */ + + for (i = 0; i < UDP_TUNNEL_NIC_MAX_TABLES; i++) { + if (!info->tables[i].n_entries) + break; + + ret = 
ethnl_udp_table_reply_size(info->tables[i].tunnel_types, + compact); + if (ret < 0) + return ret; + size += ret; + + size += udp_tunnel_nic_dump_size(req_base->dev, i); + } + + if (info->flags & UDP_TUNNEL_NIC_INFO_STATIC_IANA_VXLAN) { + ret = ethnl_udp_table_reply_size(0, compact); + if (ret < 0) + return ret; + size += ret; + + size += nla_total_size(0) + /* _TABLE_ENTRY */ + nla_total_size(sizeof(__be16)) + /* _ENTRY_PORT */ + nla_total_size(sizeof(u32)); /* _ENTRY_TYPE */ + } + + return size; +} + +static int +ethnl_tunnel_info_fill_reply(const struct ethnl_req_info *req_base, + struct sk_buff *skb) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct udp_tunnel_nic_info *info; + struct nlattr *ports, *table, *entry; + unsigned int i; + + info = req_base->dev->udp_tunnel_nic_info; + if (!info) + return -EOPNOTSUPP; + + ports = nla_nest_start(skb, ETHTOOL_A_TUNNEL_INFO_UDP_PORTS); + if (!ports) + return -EMSGSIZE; + + for (i = 0; i < UDP_TUNNEL_NIC_MAX_TABLES; i++) { + if (!info->tables[i].n_entries) + break; + + table = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE); + if (!table) + goto err_cancel_ports; + + if (nla_put_u32(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_SIZE, + info->tables[i].n_entries)) + goto err_cancel_table; + + if (ethnl_put_bitset32(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_TYPES, + &info->tables[i].tunnel_types, NULL, + __ETHTOOL_UDP_TUNNEL_TYPE_CNT, + udp_tunnel_type_names, compact)) + goto err_cancel_table; + + if (udp_tunnel_nic_dump_write(req_base->dev, i, skb)) + goto err_cancel_table; + + nla_nest_end(skb, table); + } + + if (info->flags & UDP_TUNNEL_NIC_INFO_STATIC_IANA_VXLAN) { + u32 zero = 0; + + table = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE); + if (!table) + goto err_cancel_ports; + + if (nla_put_u32(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_SIZE, 1)) + goto err_cancel_table; + + if (ethnl_put_bitset32(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_TYPES, + &zero, NULL, + __ETHTOOL_UDP_TUNNEL_TYPE_CNT, + udp_tunnel_type_names, compact)) + goto err_cancel_table; + + entry = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY); + if (!entry) + goto err_cancel_entry; + + if (nla_put_be16(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT, + htons(IANA_VXLAN_UDP_PORT)) || + nla_put_u32(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_TYPE, + ilog2(UDP_TUNNEL_TYPE_VXLAN))) + goto err_cancel_entry; + + nla_nest_end(skb, entry); + nla_nest_end(skb, table); + } + + nla_nest_end(skb, ports); + + return 0; + +err_cancel_entry: + nla_nest_cancel(skb, entry); +err_cancel_table: + nla_nest_cancel(skb, table); +err_cancel_ports: + nla_nest_cancel(skb, ports); + return -EMSGSIZE; +} + +int ethnl_tunnel_info_doit(struct sk_buff *skb, struct genl_info *info) +{ + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct sk_buff *rskb; + void *reply_payload; + int reply_len; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, + tb[ETHTOOL_A_TUNNEL_INFO_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + + rtnl_lock(); + ret = ethnl_tunnel_info_reply_size(&req_info, info->extack); + if (ret < 0) + goto err_unlock_rtnl; + reply_len = ret + ethnl_reply_header_size(); + + rskb = ethnl_reply_init(reply_len, req_info.dev, + ETHTOOL_MSG_TUNNEL_INFO_GET_REPLY, + ETHTOOL_A_TUNNEL_INFO_HEADER, + info, &reply_payload); + if (!rskb) { + ret = -ENOMEM; + goto err_unlock_rtnl; + } + + ret = ethnl_tunnel_info_fill_reply(&req_info, rskb); + if (ret) + goto err_free_msg; + rtnl_unlock(); + ethnl_parse_header_dev_put(&req_info); + genlmsg_end(rskb, 
reply_payload); + + return genlmsg_reply(rskb, info); + +err_free_msg: + nlmsg_free(rskb); +err_unlock_rtnl: + rtnl_unlock(); + ethnl_parse_header_dev_put(&req_info); + return ret; +} + +struct ethnl_tunnel_info_dump_ctx { + struct ethnl_req_info req_info; + int pos_hash; + int pos_idx; +}; + +int ethnl_tunnel_info_start(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct ethnl_tunnel_info_dump_ctx *ctx = (void *)cb->ctx; + struct nlattr **tb = info->attrs; + int ret; + + BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); + + memset(ctx, 0, sizeof(*ctx)); + + ret = ethnl_parse_header_dev_get(&ctx->req_info, + tb[ETHTOOL_A_TUNNEL_INFO_HEADER], + sock_net(cb->skb->sk), cb->extack, + false); + if (ctx->req_info.dev) { + ethnl_parse_header_dev_put(&ctx->req_info); + ctx->req_info.dev = NULL; + } + + return ret; +} + +int ethnl_tunnel_info_dumpit(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ethnl_tunnel_info_dump_ctx *ctx = (void *)cb->ctx; + struct net *net = sock_net(skb->sk); + int s_idx = ctx->pos_idx; + int h, idx = 0; + int ret = 0; + void *ehdr; + + rtnl_lock(); + cb->seq = net->dev_base_seq; + for (h = ctx->pos_hash; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + struct hlist_head *head; + struct net_device *dev; + + head = &net->dev_index_head[h]; + idx = 0; + hlist_for_each_entry(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + + ehdr = ethnl_dump_put(skb, cb, + ETHTOOL_MSG_TUNNEL_INFO_GET_REPLY); + if (!ehdr) { + ret = -EMSGSIZE; + goto out; + } + + ret = ethnl_fill_reply_header(skb, dev, ETHTOOL_A_TUNNEL_INFO_HEADER); + if (ret < 0) { + genlmsg_cancel(skb, ehdr); + goto out; + } + + ctx->req_info.dev = dev; + ret = ethnl_tunnel_info_fill_reply(&ctx->req_info, skb); + ctx->req_info.dev = NULL; + if (ret < 0) { + genlmsg_cancel(skb, ehdr); + if (ret == -EOPNOTSUPP) + goto cont; + goto out; + } + genlmsg_end(skb, ehdr); +cont: + idx++; + } + } +out: + rtnl_unlock(); + + ctx->pos_hash = h; + ctx->pos_idx = idx; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + + if (ret == -EMSGSIZE && skb->len) + return skb->len; + return ret; +} diff --git a/net/ethtool/wol.c b/net/ethtool/wol.c new file mode 100644 index 000000000..88f435e76 --- /dev/null +++ b/net/ethtool/wol.c @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include "netlink.h" +#include "common.h" +#include "bitset.h" + +struct wol_req_info { + struct ethnl_req_info base; +}; + +struct wol_reply_data { + struct ethnl_reply_data base; + struct ethtool_wolinfo wol; + bool show_sopass; +}; + +#define WOL_REPDATA(__reply_base) \ + container_of(__reply_base, struct wol_reply_data, base) + +const struct nla_policy ethnl_wol_get_policy[] = { + [ETHTOOL_A_WOL_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), +}; + +static int wol_prepare_data(const struct ethnl_req_info *req_base, + struct ethnl_reply_data *reply_base, + struct genl_info *info) +{ + struct wol_reply_data *data = WOL_REPDATA(reply_base); + struct net_device *dev = reply_base->dev; + int ret; + + if (!dev->ethtool_ops->get_wol) + return -EOPNOTSUPP; + + ret = ethnl_ops_begin(dev); + if (ret < 0) + return ret; + dev->ethtool_ops->get_wol(dev, &data->wol); + ethnl_ops_complete(dev); + /* do not include password in notifications */ + data->show_sopass = info && (data->wol.supported & WAKE_MAGICSECURE); + + return 0; +} + +static int wol_reply_size(const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & 
ETHTOOL_FLAG_COMPACT_BITSETS; + const struct wol_reply_data *data = WOL_REPDATA(reply_base); + int len; + + len = ethnl_bitset32_size(&data->wol.wolopts, &data->wol.supported, + WOL_MODE_COUNT, wol_mode_names, compact); + if (len < 0) + return len; + if (data->show_sopass) + len += nla_total_size(sizeof(data->wol.sopass)); + + return len; +} + +static int wol_fill_reply(struct sk_buff *skb, + const struct ethnl_req_info *req_base, + const struct ethnl_reply_data *reply_base) +{ + bool compact = req_base->flags & ETHTOOL_FLAG_COMPACT_BITSETS; + const struct wol_reply_data *data = WOL_REPDATA(reply_base); + int ret; + + ret = ethnl_put_bitset32(skb, ETHTOOL_A_WOL_MODES, &data->wol.wolopts, + &data->wol.supported, WOL_MODE_COUNT, + wol_mode_names, compact); + if (ret < 0) + return ret; + if (data->show_sopass && + nla_put(skb, ETHTOOL_A_WOL_SOPASS, sizeof(data->wol.sopass), + data->wol.sopass)) + return -EMSGSIZE; + + return 0; +} + +const struct ethnl_request_ops ethnl_wol_request_ops = { + .request_cmd = ETHTOOL_MSG_WOL_GET, + .reply_cmd = ETHTOOL_MSG_WOL_GET_REPLY, + .hdr_attr = ETHTOOL_A_WOL_HEADER, + .req_info_size = sizeof(struct wol_req_info), + .reply_data_size = sizeof(struct wol_reply_data), + + .prepare_data = wol_prepare_data, + .reply_size = wol_reply_size, + .fill_reply = wol_fill_reply, +}; + +/* WOL_SET */ + +const struct nla_policy ethnl_wol_set_policy[] = { + [ETHTOOL_A_WOL_HEADER] = + NLA_POLICY_NESTED(ethnl_header_policy), + [ETHTOOL_A_WOL_MODES] = { .type = NLA_NESTED }, + [ETHTOOL_A_WOL_SOPASS] = { .type = NLA_BINARY, + .len = SOPASS_MAX }, +}; + +int ethnl_set_wol(struct sk_buff *skb, struct genl_info *info) +{ + struct ethtool_wolinfo wol = { .cmd = ETHTOOL_GWOL }; + struct ethnl_req_info req_info = {}; + struct nlattr **tb = info->attrs; + struct net_device *dev; + bool mod = false; + int ret; + + ret = ethnl_parse_header_dev_get(&req_info, tb[ETHTOOL_A_WOL_HEADER], + genl_info_net(info), info->extack, + true); + if (ret < 0) + return ret; + dev = req_info.dev; + ret = -EOPNOTSUPP; + if (!dev->ethtool_ops->get_wol || !dev->ethtool_ops->set_wol) + goto out_dev; + + rtnl_lock(); + ret = ethnl_ops_begin(dev); + if (ret < 0) + goto out_rtnl; + + dev->ethtool_ops->get_wol(dev, &wol); + ret = ethnl_update_bitset32(&wol.wolopts, WOL_MODE_COUNT, + tb[ETHTOOL_A_WOL_MODES], wol_mode_names, + info->extack, &mod); + if (ret < 0) + goto out_ops; + if (wol.wolopts & ~wol.supported) { + NL_SET_ERR_MSG_ATTR(info->extack, tb[ETHTOOL_A_WOL_MODES], + "cannot enable unsupported WoL mode"); + ret = -EINVAL; + goto out_ops; + } + if (tb[ETHTOOL_A_WOL_SOPASS]) { + if (!(wol.supported & WAKE_MAGICSECURE)) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[ETHTOOL_A_WOL_SOPASS], + "magicsecure not supported, cannot set password"); + ret = -EINVAL; + goto out_ops; + } + ethnl_update_binary(wol.sopass, sizeof(wol.sopass), + tb[ETHTOOL_A_WOL_SOPASS], &mod); + } + + if (!mod) + goto out_ops; + ret = dev->ethtool_ops->set_wol(dev, &wol); + if (ret) + goto out_ops; + dev->wol_enabled = !!wol.wolopts; + ethtool_notify(dev, ETHTOOL_MSG_WOL_NTF, NULL); + +out_ops: + ethnl_ops_complete(dev); +out_rtnl: + rtnl_unlock(); +out_dev: + ethnl_parse_header_dev_put(&req_info); + return ret; +} diff --git a/net/hsr/Kconfig b/net/hsr/Kconfig new file mode 100644 index 000000000..1b048c17b --- /dev/null +++ b/net/hsr/Kconfig @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IEC 62439-3 High-availability Seamless Redundancy +# + +config HSR + tristate "High-availability Seamless Redundancy (HSR & PRP)" + 
help + This enables IEC 62439 defined High-availability Seamless + Redundancy (HSR) and Parallel Redundancy Protocol (PRP). + + If you say Y here, then your Linux box will be able to act as a + DANH ("Doubly attached node implementing HSR") or DANP ("Doubly + attached node implementing PRP"). For this to work, your Linux box + needs (at least) two physical Ethernet interfaces. + + For DANH, it must be connected as a node in a ring network together + with other HSR capable nodes. All Ethernet frames sent over the HSR + device will be sent in both directions on the ring (over both slave + ports), giving a redundant, instant fail-over network. Each HSR node + in the ring acts like a bridge for HSR frames, but filters frames + that have been forwarded earlier. + + For DANP, it must be connected as a node connecting to two + separate networks over the two slave interfaces. Like HSR, Ethernet + frames sent over the PRP device will be sent to both networks giving + a redundant, instant fail-over network. Unlike HSR, PRP networks + can have Singly Attached Nodes (SAN) such as PC, printer, bridges + etc and will be able to communicate with DANP nodes. + + This code is a "best effort" to comply with the HSR standard as + described in IEC 62439-3:2010 (HSRv0) and IEC 62439-3:2012 (HSRv1), + and PRP standard described in IEC 62439-4:2012 (PRP), but no + compliancy tests have been made. Use iproute2 to select the protocol + you would like to use. + + You need to perform any and all necessary tests yourself before + relying on this code in a safety critical system! + + If unsure, say N. diff --git a/net/hsr/Makefile b/net/hsr/Makefile new file mode 100644 index 000000000..75df90d3b --- /dev/null +++ b/net/hsr/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for HSR +# + +obj-$(CONFIG_HSR) += hsr.o + +hsr-y := hsr_main.o hsr_framereg.o hsr_device.o \ + hsr_netlink.o hsr_slave.o hsr_forward.o +hsr-$(CONFIG_DEBUG_FS) += hsr_debugfs.o diff --git a/net/hsr/hsr_debugfs.c b/net/hsr/hsr_debugfs.c new file mode 100644 index 000000000..1a195efc7 --- /dev/null +++ b/net/hsr/hsr_debugfs.c @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * debugfs code for HSR & PRP + * Copyright (C) 2019 Texas Instruments Incorporated + * + * Author(s): + * Murali Karicheri + */ +#include +#include +#include +#include "hsr_main.h" +#include "hsr_framereg.h" + +static struct dentry *hsr_debugfs_root_dir; + +/* hsr_node_table_show - Formats and prints node_table entries */ +static int +hsr_node_table_show(struct seq_file *sfp, void *data) +{ + struct hsr_priv *priv = (struct hsr_priv *)sfp->private; + struct hsr_node *node; + + seq_printf(sfp, "Node Table entries for (%s) device\n", + (priv->prot_version == PRP_V1 ? 
"PRP" : "HSR")); + seq_puts(sfp, "MAC-Address-A, MAC-Address-B, time_in[A], "); + seq_puts(sfp, "time_in[B], Address-B port, "); + if (priv->prot_version == PRP_V1) + seq_puts(sfp, "SAN-A, SAN-B, DAN-P\n"); + else + seq_puts(sfp, "DAN-H\n"); + + rcu_read_lock(); + list_for_each_entry_rcu(node, &priv->node_db, mac_list) { + /* skip self node */ + if (hsr_addr_is_self(priv, node->macaddress_A)) + continue; + seq_printf(sfp, "%pM ", &node->macaddress_A[0]); + seq_printf(sfp, "%pM ", &node->macaddress_B[0]); + seq_printf(sfp, "%10lx, ", node->time_in[HSR_PT_SLAVE_A]); + seq_printf(sfp, "%10lx, ", node->time_in[HSR_PT_SLAVE_B]); + seq_printf(sfp, "%14x, ", node->addr_B_port); + + if (priv->prot_version == PRP_V1) + seq_printf(sfp, "%5x, %5x, %5x\n", + node->san_a, node->san_b, + (node->san_a == 0 && node->san_b == 0)); + else + seq_printf(sfp, "%5x\n", 1); + } + rcu_read_unlock(); + return 0; +} + +DEFINE_SHOW_ATTRIBUTE(hsr_node_table); + +void hsr_debugfs_rename(struct net_device *dev) +{ + struct hsr_priv *priv = netdev_priv(dev); + struct dentry *d; + + d = debugfs_rename(hsr_debugfs_root_dir, priv->node_tbl_root, + hsr_debugfs_root_dir, dev->name); + if (IS_ERR(d)) + netdev_warn(dev, "failed to rename\n"); + else + priv->node_tbl_root = d; +} + +/* hsr_debugfs_init - create hsr node_table file for dumping + * the node table + * + * Description: + * When debugfs is configured this routine sets up the node_table file per + * hsr device for dumping the node_table entries + */ +void hsr_debugfs_init(struct hsr_priv *priv, struct net_device *hsr_dev) +{ + struct dentry *de = NULL; + + de = debugfs_create_dir(hsr_dev->name, hsr_debugfs_root_dir); + if (IS_ERR(de)) { + pr_err("Cannot create hsr debugfs directory\n"); + return; + } + + priv->node_tbl_root = de; + + de = debugfs_create_file("node_table", S_IFREG | 0444, + priv->node_tbl_root, priv, + &hsr_node_table_fops); + if (IS_ERR(de)) { + pr_err("Cannot create hsr node_table file\n"); + debugfs_remove(priv->node_tbl_root); + priv->node_tbl_root = NULL; + return; + } +} + +/* hsr_debugfs_term - Tear down debugfs intrastructure + * + * Description: + * When Debugfs is configured this routine removes debugfs file system + * elements that are specific to hsr + */ +void +hsr_debugfs_term(struct hsr_priv *priv) +{ + debugfs_remove_recursive(priv->node_tbl_root); + priv->node_tbl_root = NULL; +} + +void hsr_debugfs_create_root(void) +{ + hsr_debugfs_root_dir = debugfs_create_dir("hsr", NULL); + if (IS_ERR(hsr_debugfs_root_dir)) { + pr_err("Cannot create hsr debugfs root directory\n"); + hsr_debugfs_root_dir = NULL; + } +} + +void hsr_debugfs_remove_root(void) +{ + /* debugfs_remove() internally checks NULL and ERROR */ + debugfs_remove(hsr_debugfs_root_dir); +} diff --git a/net/hsr/hsr_device.c b/net/hsr/hsr_device.c new file mode 100644 index 000000000..b1e86a726 --- /dev/null +++ b/net/hsr/hsr_device.c @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * This file contains device methods for creating, using and destroying + * virtual HSR or PRP devices. 
+ */ + +#include +#include +#include +#include +#include +#include "hsr_device.h" +#include "hsr_slave.h" +#include "hsr_framereg.h" +#include "hsr_main.h" +#include "hsr_forward.h" + +static bool is_admin_up(struct net_device *dev) +{ + return dev && (dev->flags & IFF_UP); +} + +static bool is_slave_up(struct net_device *dev) +{ + return dev && is_admin_up(dev) && netif_oper_up(dev); +} + +static void __hsr_set_operstate(struct net_device *dev, int transition) +{ + write_lock(&dev_base_lock); + if (dev->operstate != transition) { + dev->operstate = transition; + write_unlock(&dev_base_lock); + netdev_state_change(dev); + } else { + write_unlock(&dev_base_lock); + } +} + +static void hsr_set_operstate(struct hsr_port *master, bool has_carrier) +{ + if (!is_admin_up(master->dev)) { + __hsr_set_operstate(master->dev, IF_OPER_DOWN); + return; + } + + if (has_carrier) + __hsr_set_operstate(master->dev, IF_OPER_UP); + else + __hsr_set_operstate(master->dev, IF_OPER_LOWERLAYERDOWN); +} + +static bool hsr_check_carrier(struct hsr_port *master) +{ + struct hsr_port *port; + + ASSERT_RTNL(); + + hsr_for_each_port(master->hsr, port) { + if (port->type != HSR_PT_MASTER && is_slave_up(port->dev)) { + netif_carrier_on(master->dev); + return true; + } + } + + netif_carrier_off(master->dev); + + return false; +} + +static void hsr_check_announce(struct net_device *hsr_dev, + unsigned char old_operstate) +{ + struct hsr_priv *hsr; + + hsr = netdev_priv(hsr_dev); + + if (hsr_dev->operstate == IF_OPER_UP && old_operstate != IF_OPER_UP) { + /* Went up */ + hsr->announce_count = 0; + mod_timer(&hsr->announce_timer, + jiffies + msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL)); + } + + if (hsr_dev->operstate != IF_OPER_UP && old_operstate == IF_OPER_UP) + /* Went down */ + del_timer(&hsr->announce_timer); +} + +void hsr_check_carrier_and_operstate(struct hsr_priv *hsr) +{ + struct hsr_port *master; + unsigned char old_operstate; + bool has_carrier; + + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + /* netif_stacked_transfer_operstate() cannot be used here since + * it doesn't set IF_OPER_LOWERLAYERDOWN (?) 
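+ * hsr_set_operstate() below wants to report IF_OPER_LOWERLAYERDOWN when
+ * the master is administratively up but no slave has carrier, so the old
+ * operstate is sampled by hand here before it is recomputed.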
+ */ + old_operstate = master->dev->operstate; + has_carrier = hsr_check_carrier(master); + hsr_set_operstate(master, has_carrier); + hsr_check_announce(master->dev, old_operstate); +} + +int hsr_get_max_mtu(struct hsr_priv *hsr) +{ + unsigned int mtu_max; + struct hsr_port *port; + + mtu_max = ETH_DATA_LEN; + hsr_for_each_port(hsr, port) + if (port->type != HSR_PT_MASTER) + mtu_max = min(port->dev->mtu, mtu_max); + + if (mtu_max < HSR_HLEN) + return 0; + return mtu_max - HSR_HLEN; +} + +static int hsr_dev_change_mtu(struct net_device *dev, int new_mtu) +{ + struct hsr_priv *hsr; + + hsr = netdev_priv(dev); + + if (new_mtu > hsr_get_max_mtu(hsr)) { + netdev_info(dev, "A HSR master's MTU cannot be greater than the smallest MTU of its slaves minus the HSR Tag length (%d octets).\n", + HSR_HLEN); + return -EINVAL; + } + + dev->mtu = new_mtu; + + return 0; +} + +static int hsr_dev_open(struct net_device *dev) +{ + struct hsr_priv *hsr; + struct hsr_port *port; + char designation; + + hsr = netdev_priv(dev); + designation = '\0'; + + hsr_for_each_port(hsr, port) { + if (port->type == HSR_PT_MASTER) + continue; + switch (port->type) { + case HSR_PT_SLAVE_A: + designation = 'A'; + break; + case HSR_PT_SLAVE_B: + designation = 'B'; + break; + default: + designation = '?'; + } + if (!is_slave_up(port->dev)) + netdev_warn(dev, "Slave %c (%s) is not up; please bring it up to get a fully working HSR network\n", + designation, port->dev->name); + } + + if (designation == '\0') + netdev_warn(dev, "No slave devices configured\n"); + + return 0; +} + +static int hsr_dev_close(struct net_device *dev) +{ + /* Nothing to do here. */ + return 0; +} + +static netdev_features_t hsr_features_recompute(struct hsr_priv *hsr, + netdev_features_t features) +{ + netdev_features_t mask; + struct hsr_port *port; + + mask = features; + + /* Mask out all features that, if supported by one device, should be + * enabled for all devices (see NETIF_F_ONE_FOR_ALL). + * + * Anything that's off in mask will not be enabled - so only things + * that were in features originally, and also is in NETIF_F_ONE_FOR_ALL, + * may become enabled. 
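+ *
+ * Worked example (assuming the usual netdev_increment_features()
+ * semantics): NETIF_F_SG is a ONE_FOR_ALL feature, so it is cleared here
+ * and OR-ed back in below as soon as one slave advertises it - but only
+ * if it was part of the original request saved in "mask".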
+ */ + features &= ~NETIF_F_ONE_FOR_ALL; + hsr_for_each_port(hsr, port) + features = netdev_increment_features(features, + port->dev->features, + mask); + + return features; +} + +static netdev_features_t hsr_fix_features(struct net_device *dev, + netdev_features_t features) +{ + struct hsr_priv *hsr = netdev_priv(dev); + + return hsr_features_recompute(hsr, features); +} + +static netdev_tx_t hsr_dev_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct hsr_priv *hsr = netdev_priv(dev); + struct hsr_port *master; + + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + if (master) { + skb->dev = master->dev; + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + spin_lock_bh(&hsr->seqnr_lock); + hsr_forward_skb(skb, master); + spin_unlock_bh(&hsr->seqnr_lock); + } else { + dev_core_stats_tx_dropped_inc(dev); + dev_kfree_skb_any(skb); + } + return NETDEV_TX_OK; +} + +static const struct header_ops hsr_header_ops = { + .create = eth_header, + .parse = eth_header_parse, +}; + +static struct sk_buff *hsr_init_skb(struct hsr_port *master) +{ + struct hsr_priv *hsr = master->hsr; + struct sk_buff *skb; + int hlen, tlen; + + hlen = LL_RESERVED_SPACE(master->dev); + tlen = master->dev->needed_tailroom; + /* skb size is same for PRP/HSR frames, only difference + * being, for PRP it is a trailer and for HSR it is a + * header + */ + skb = dev_alloc_skb(sizeof(struct hsr_sup_tag) + + sizeof(struct hsr_sup_payload) + hlen + tlen); + + if (!skb) + return skb; + + skb_reserve(skb, hlen); + skb->dev = master->dev; + skb->priority = TC_PRIO_CONTROL; + + if (dev_hard_header(skb, skb->dev, ETH_P_PRP, + hsr->sup_multicast_addr, + skb->dev->dev_addr, skb->len) <= 0) + goto out; + + skb_reset_mac_header(skb); + skb_reset_mac_len(skb); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + + return skb; +out: + kfree_skb(skb); + + return NULL; +} + +static void send_hsr_supervision_frame(struct hsr_port *master, + unsigned long *interval) +{ + struct hsr_priv *hsr = master->hsr; + __u8 type = HSR_TLV_LIFE_CHECK; + struct hsr_sup_payload *hsr_sp; + struct hsr_sup_tag *hsr_stag; + struct sk_buff *skb; + + *interval = msecs_to_jiffies(HSR_LIFE_CHECK_INTERVAL); + if (hsr->announce_count < 3 && hsr->prot_version == 0) { + type = HSR_TLV_ANNOUNCE; + *interval = msecs_to_jiffies(HSR_ANNOUNCE_INTERVAL); + hsr->announce_count++; + } + + skb = hsr_init_skb(master); + if (!skb) { + WARN_ONCE(1, "HSR: Could not send supervision frame\n"); + return; + } + + hsr_stag = skb_put(skb, sizeof(struct hsr_sup_tag)); + set_hsr_stag_path(hsr_stag, (hsr->prot_version ? 0x0 : 0xf)); + set_hsr_stag_HSR_ver(hsr_stag, hsr->prot_version); + + /* From HSRv1 on we have separate supervision sequence numbers. */ + spin_lock_bh(&hsr->seqnr_lock); + if (hsr->prot_version > 0) { + hsr_stag->sequence_nr = htons(hsr->sup_sequence_nr); + hsr->sup_sequence_nr++; + } else { + hsr_stag->sequence_nr = htons(hsr->sequence_nr); + hsr->sequence_nr++; + } + + hsr_stag->tlv.HSR_TLV_type = type; + /* TODO: Why 12 in HSRv0? */ + hsr_stag->tlv.HSR_TLV_length = hsr->prot_version ? 
+ sizeof(struct hsr_sup_payload) : 12; + + /* Payload: MacAddressA */ + hsr_sp = skb_put(skb, sizeof(struct hsr_sup_payload)); + ether_addr_copy(hsr_sp->macaddress_A, master->dev->dev_addr); + + if (skb_put_padto(skb, ETH_ZLEN)) { + spin_unlock_bh(&hsr->seqnr_lock); + return; + } + + hsr_forward_skb(skb, master); + spin_unlock_bh(&hsr->seqnr_lock); + return; +} + +static void send_prp_supervision_frame(struct hsr_port *master, + unsigned long *interval) +{ + struct hsr_priv *hsr = master->hsr; + struct hsr_sup_payload *hsr_sp; + struct hsr_sup_tag *hsr_stag; + struct sk_buff *skb; + + skb = hsr_init_skb(master); + if (!skb) { + WARN_ONCE(1, "PRP: Could not send supervision frame\n"); + return; + } + + *interval = msecs_to_jiffies(HSR_LIFE_CHECK_INTERVAL); + hsr_stag = skb_put(skb, sizeof(struct hsr_sup_tag)); + set_hsr_stag_path(hsr_stag, (hsr->prot_version ? 0x0 : 0xf)); + set_hsr_stag_HSR_ver(hsr_stag, (hsr->prot_version ? 1 : 0)); + + /* From HSRv1 on we have separate supervision sequence numbers. */ + spin_lock_bh(&hsr->seqnr_lock); + hsr_stag->sequence_nr = htons(hsr->sup_sequence_nr); + hsr->sup_sequence_nr++; + hsr_stag->tlv.HSR_TLV_type = PRP_TLV_LIFE_CHECK_DD; + hsr_stag->tlv.HSR_TLV_length = sizeof(struct hsr_sup_payload); + + /* Payload: MacAddressA */ + hsr_sp = skb_put(skb, sizeof(struct hsr_sup_payload)); + ether_addr_copy(hsr_sp->macaddress_A, master->dev->dev_addr); + + if (skb_put_padto(skb, ETH_ZLEN)) { + spin_unlock_bh(&hsr->seqnr_lock); + return; + } + + hsr_forward_skb(skb, master); + spin_unlock_bh(&hsr->seqnr_lock); +} + +/* Announce (supervision frame) timer function + */ +static void hsr_announce(struct timer_list *t) +{ + struct hsr_priv *hsr; + struct hsr_port *master; + unsigned long interval; + + hsr = from_timer(hsr, t, announce_timer); + + rcu_read_lock(); + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + hsr->proto_ops->send_sv_frame(master, &interval); + + if (is_admin_up(master->dev)) + mod_timer(&hsr->announce_timer, jiffies + interval); + + rcu_read_unlock(); +} + +void hsr_del_ports(struct hsr_priv *hsr) +{ + struct hsr_port *port; + + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_A); + if (port) + hsr_del_port(port); + + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + if (port) + hsr_del_port(port); + + port = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + if (port) + hsr_del_port(port); +} + +static const struct net_device_ops hsr_device_ops = { + .ndo_change_mtu = hsr_dev_change_mtu, + .ndo_open = hsr_dev_open, + .ndo_stop = hsr_dev_close, + .ndo_start_xmit = hsr_dev_xmit, + .ndo_fix_features = hsr_fix_features, +}; + +static struct device_type hsr_type = { + .name = "hsr", +}; + +static struct hsr_proto_ops hsr_ops = { + .send_sv_frame = send_hsr_supervision_frame, + .create_tagged_frame = hsr_create_tagged_frame, + .get_untagged_frame = hsr_get_untagged_frame, + .drop_frame = hsr_drop_frame, + .fill_frame_info = hsr_fill_frame_info, + .invalid_dan_ingress_frame = hsr_invalid_dan_ingress_frame, +}; + +static struct hsr_proto_ops prp_ops = { + .send_sv_frame = send_prp_supervision_frame, + .create_tagged_frame = prp_create_tagged_frame, + .get_untagged_frame = prp_get_untagged_frame, + .drop_frame = prp_drop_frame, + .fill_frame_info = prp_fill_frame_info, + .handle_san_frame = prp_handle_san_frame, + .update_san_info = prp_update_san_info, +}; + +void hsr_dev_setup(struct net_device *dev) +{ + eth_hw_addr_random(dev); + + ether_setup(dev); + dev->min_mtu = 0; + dev->header_ops = &hsr_header_ops; + dev->netdev_ops = &hsr_device_ops; + SET_NETDEV_DEVTYPE(dev, 
&hsr_type); + dev->priv_flags |= IFF_NO_QUEUE | IFF_DISABLE_NETPOLL; + + dev->needs_free_netdev = true; + + dev->hw_features = NETIF_F_SG | NETIF_F_FRAGLIST | NETIF_F_HIGHDMA | + NETIF_F_GSO_MASK | NETIF_F_HW_CSUM | + NETIF_F_HW_VLAN_CTAG_TX; + + dev->features = dev->hw_features; + + /* Prevent recursive tx locking */ + dev->features |= NETIF_F_LLTX; + /* VLAN on top of HSR needs testing and probably some work on + * hsr_header_create() etc. + */ + dev->features |= NETIF_F_VLAN_CHALLENGED; + /* Not sure about this. Taken from bridge code. netdev_features.h says + * it means "Does not change network namespaces". + */ + dev->features |= NETIF_F_NETNS_LOCAL; +} + +/* Return true if dev is a HSR master; return false otherwise. + */ +bool is_hsr_master(struct net_device *dev) +{ + return (dev->netdev_ops->ndo_start_xmit == hsr_dev_xmit); +} +EXPORT_SYMBOL(is_hsr_master); + +/* Default multicast address for HSR Supervision frames */ +static const unsigned char def_multicast_addr[ETH_ALEN] __aligned(2) = { + 0x01, 0x15, 0x4e, 0x00, 0x01, 0x00 +}; + +int hsr_dev_finalize(struct net_device *hsr_dev, struct net_device *slave[2], + unsigned char multicast_spec, u8 protocol_version, + struct netlink_ext_ack *extack) +{ + bool unregister = false; + struct hsr_priv *hsr; + int res; + + hsr = netdev_priv(hsr_dev); + INIT_LIST_HEAD(&hsr->ports); + INIT_LIST_HEAD(&hsr->node_db); + INIT_LIST_HEAD(&hsr->self_node_db); + spin_lock_init(&hsr->list_lock); + + eth_hw_addr_set(hsr_dev, slave[0]->dev_addr); + + /* initialize protocol specific functions */ + if (protocol_version == PRP_V1) { + /* For PRP, lan_id has most significant 3 bits holding + * the net_id of PRP_LAN_ID + */ + hsr->net_id = PRP_LAN_ID << 1; + hsr->proto_ops = &prp_ops; + } else { + hsr->proto_ops = &hsr_ops; + } + + /* Make sure we recognize frames from ourselves in hsr_rcv() */ + res = hsr_create_self_node(hsr, hsr_dev->dev_addr, + slave[1]->dev_addr); + if (res < 0) + return res; + + spin_lock_init(&hsr->seqnr_lock); + /* Overflow soon to find bugs easier: */ + hsr->sequence_nr = HSR_SEQNR_START; + hsr->sup_sequence_nr = HSR_SUP_SEQNR_START; + + timer_setup(&hsr->announce_timer, hsr_announce, 0); + timer_setup(&hsr->prune_timer, hsr_prune_nodes, 0); + + ether_addr_copy(hsr->sup_multicast_addr, def_multicast_addr); + hsr->sup_multicast_addr[ETH_ALEN - 1] = multicast_spec; + + hsr->prot_version = protocol_version; + + /* Make sure the 1st call to netif_carrier_on() gets through */ + netif_carrier_off(hsr_dev); + + res = hsr_add_port(hsr, hsr_dev, HSR_PT_MASTER, extack); + if (res) + goto err_add_master; + + res = register_netdevice(hsr_dev); + if (res) + goto err_unregister; + + unregister = true; + + res = hsr_add_port(hsr, slave[0], HSR_PT_SLAVE_A, extack); + if (res) + goto err_unregister; + + res = hsr_add_port(hsr, slave[1], HSR_PT_SLAVE_B, extack); + if (res) + goto err_unregister; + + hsr_debugfs_init(hsr, hsr_dev); + mod_timer(&hsr->prune_timer, jiffies + msecs_to_jiffies(PRUNE_PERIOD)); + + return 0; + +err_unregister: + hsr_del_ports(hsr); +err_add_master: + hsr_del_self_node(hsr); + + if (unregister) + unregister_netdevice(hsr_dev); + return res; +} diff --git a/net/hsr/hsr_device.h b/net/hsr/hsr_device.h new file mode 100644 index 000000000..9060c9216 --- /dev/null +++ b/net/hsr/hsr_device.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. 
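+ * It declares the master-device helpers implemented in hsr_device.c:
+ * setting up and finalizing the virtual device, deleting its ports,
+ * tracking carrier/operstate of the slaves and computing the usable MTU.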
+ */ + +#ifndef __HSR_DEVICE_H +#define __HSR_DEVICE_H + +#include +#include "hsr_main.h" + +void hsr_del_ports(struct hsr_priv *hsr); +void hsr_dev_setup(struct net_device *dev); +int hsr_dev_finalize(struct net_device *hsr_dev, struct net_device *slave[2], + unsigned char multicast_spec, u8 protocol_version, + struct netlink_ext_ack *extack); +void hsr_check_carrier_and_operstate(struct hsr_priv *hsr); +int hsr_get_max_mtu(struct hsr_priv *hsr); +#endif /* __HSR_DEVICE_H */ diff --git a/net/hsr/hsr_forward.c b/net/hsr/hsr_forward.c new file mode 100644 index 000000000..80cdc6f6b --- /dev/null +++ b/net/hsr/hsr_forward.c @@ -0,0 +1,638 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * Frame router for HSR and PRP. + */ + +#include "hsr_forward.h" +#include +#include +#include +#include +#include "hsr_main.h" +#include "hsr_framereg.h" + +struct hsr_node; + +/* The uses I can see for these HSR supervision frames are: + * 1) Use the frames that are sent after node initialization ("HSR_TLV.Type = + * 22") to reset any sequence_nr counters belonging to that node. Useful if + * the other node's counter has been reset for some reason. + * -- + * Or not - resetting the counter and bridging the frame would create a + * loop, unfortunately. + * + * 2) Use the LifeCheck frames to detect ring breaks. I.e. if no LifeCheck + * frame is received from a particular node, we know something is wrong. + * We just register these (as with normal frames) and throw them away. + * + * 3) Allow different MAC addresses for the two slave interfaces, using the + * MacAddressA field. + */ +static bool is_supervision_frame(struct hsr_priv *hsr, struct sk_buff *skb) +{ + struct ethhdr *eth_hdr; + struct hsr_sup_tag *hsr_sup_tag; + struct hsrv1_ethhdr_sp *hsr_V1_hdr; + struct hsr_sup_tlv *hsr_sup_tlv; + u16 total_length = 0; + + WARN_ON_ONCE(!skb_mac_header_was_set(skb)); + eth_hdr = (struct ethhdr *)skb_mac_header(skb); + + /* Correct addr? */ + if (!ether_addr_equal(eth_hdr->h_dest, + hsr->sup_multicast_addr)) + return false; + + /* Correct ether type?. */ + if (!(eth_hdr->h_proto == htons(ETH_P_PRP) || + eth_hdr->h_proto == htons(ETH_P_HSR))) + return false; + + /* Get the supervision header from correct location. */ + if (eth_hdr->h_proto == htons(ETH_P_HSR)) { /* Okay HSRv1. 
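+ * The supervision tag sits behind the HSR tag (HSR_HLEN bytes) in this
+ * case, so the length checked below covers Ethernet header, HSR tag and
+ * supervision tag; the HSRv0/PRP branch further down has no tag in front
+ * of the supervision data.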
*/ + total_length = sizeof(struct hsrv1_ethhdr_sp); + if (!pskb_may_pull(skb, total_length)) + return false; + + hsr_V1_hdr = (struct hsrv1_ethhdr_sp *)skb_mac_header(skb); + if (hsr_V1_hdr->hsr.encap_proto != htons(ETH_P_PRP)) + return false; + + hsr_sup_tag = &hsr_V1_hdr->hsr_sup; + } else { + total_length = sizeof(struct hsrv0_ethhdr_sp); + if (!pskb_may_pull(skb, total_length)) + return false; + + hsr_sup_tag = + &((struct hsrv0_ethhdr_sp *)skb_mac_header(skb))->hsr_sup; + } + + if (hsr_sup_tag->tlv.HSR_TLV_type != HSR_TLV_ANNOUNCE && + hsr_sup_tag->tlv.HSR_TLV_type != HSR_TLV_LIFE_CHECK && + hsr_sup_tag->tlv.HSR_TLV_type != PRP_TLV_LIFE_CHECK_DD && + hsr_sup_tag->tlv.HSR_TLV_type != PRP_TLV_LIFE_CHECK_DA) + return false; + if (hsr_sup_tag->tlv.HSR_TLV_length != 12 && + hsr_sup_tag->tlv.HSR_TLV_length != sizeof(struct hsr_sup_payload)) + return false; + + /* Get next tlv */ + total_length += sizeof(struct hsr_sup_tlv) + hsr_sup_tag->tlv.HSR_TLV_length; + if (!pskb_may_pull(skb, total_length)) + return false; + skb_pull(skb, total_length); + hsr_sup_tlv = (struct hsr_sup_tlv *)skb->data; + skb_push(skb, total_length); + + /* if this is a redbox supervision frame we need to verify + * that more data is available + */ + if (hsr_sup_tlv->HSR_TLV_type == PRP_TLV_REDBOX_MAC) { + /* tlv length must be a length of a mac address */ + if (hsr_sup_tlv->HSR_TLV_length != sizeof(struct hsr_sup_payload)) + return false; + + /* make sure another tlv follows */ + total_length += sizeof(struct hsr_sup_tlv) + hsr_sup_tlv->HSR_TLV_length; + if (!pskb_may_pull(skb, total_length)) + return false; + + /* get next tlv */ + skb_pull(skb, total_length); + hsr_sup_tlv = (struct hsr_sup_tlv *)skb->data; + skb_push(skb, total_length); + } + + /* end of tlvs must follow at the end */ + if (hsr_sup_tlv->HSR_TLV_type == HSR_TLV_EOT && + hsr_sup_tlv->HSR_TLV_length != 0) + return false; + + return true; +} + +static struct sk_buff *create_stripped_skb_hsr(struct sk_buff *skb_in, + struct hsr_frame_info *frame) +{ + struct sk_buff *skb; + int copylen; + unsigned char *dst, *src; + + skb_pull(skb_in, HSR_HLEN); + skb = __pskb_copy(skb_in, skb_headroom(skb_in) - HSR_HLEN, GFP_ATOMIC); + skb_push(skb_in, HSR_HLEN); + if (!skb) + return NULL; + + skb_reset_mac_header(skb); + + if (skb->ip_summed == CHECKSUM_PARTIAL) + skb->csum_start -= HSR_HLEN; + + copylen = 2 * ETH_ALEN; + if (frame->is_vlan) + copylen += VLAN_HLEN; + src = skb_mac_header(skb_in); + dst = skb_mac_header(skb); + memcpy(dst, src, copylen); + + skb->protocol = eth_hdr(skb)->h_proto; + return skb; +} + +struct sk_buff *hsr_get_untagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port) +{ + if (!frame->skb_std) { + if (frame->skb_hsr) + frame->skb_std = + create_stripped_skb_hsr(frame->skb_hsr, frame); + else + netdev_warn_once(port->dev, + "Unexpected frame received in hsr_get_untagged_frame()\n"); + + if (!frame->skb_std) + return NULL; + } + + return skb_clone(frame->skb_std, GFP_ATOMIC); +} + +struct sk_buff *prp_get_untagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port) +{ + if (!frame->skb_std) { + if (frame->skb_prp) { + /* trim the skb by len - HSR_HLEN to exclude RCT */ + skb_trim(frame->skb_prp, + frame->skb_prp->len - HSR_HLEN); + frame->skb_std = + __pskb_copy(frame->skb_prp, + skb_headroom(frame->skb_prp), + GFP_ATOMIC); + } else { + /* Unexpected */ + WARN_ONCE(1, "%s:%d: Unexpected frame received (port_src %s)\n", + __FILE__, __LINE__, port->dev->name); + return NULL; + } + } + + return skb_clone(frame->skb_std, 
GFP_ATOMIC); +} + +static void prp_set_lan_id(struct prp_rct *trailer, + struct hsr_port *port) +{ + int lane_id; + + if (port->type == HSR_PT_SLAVE_A) + lane_id = 0; + else + lane_id = 1; + + /* Add net_id in the upper 3 bits of lane_id */ + lane_id |= port->hsr->net_id; + set_prp_lan_id(trailer, lane_id); +} + +/* Tailroom for PRP rct should have been created before calling this */ +static struct sk_buff *prp_fill_rct(struct sk_buff *skb, + struct hsr_frame_info *frame, + struct hsr_port *port) +{ + struct prp_rct *trailer; + int min_size = ETH_ZLEN; + int lsdu_size; + + if (!skb) + return skb; + + if (frame->is_vlan) + min_size = VLAN_ETH_ZLEN; + + if (skb_put_padto(skb, min_size)) + return NULL; + + trailer = (struct prp_rct *)skb_put(skb, HSR_HLEN); + lsdu_size = skb->len - 14; + if (frame->is_vlan) + lsdu_size -= 4; + prp_set_lan_id(trailer, port); + set_prp_LSDU_size(trailer, lsdu_size); + trailer->sequence_nr = htons(frame->sequence_nr); + trailer->PRP_suffix = htons(ETH_P_PRP); + skb->protocol = eth_hdr(skb)->h_proto; + + return skb; +} + +static void hsr_set_path_id(struct hsr_ethhdr *hsr_ethhdr, + struct hsr_port *port) +{ + int path_id; + + if (port->type == HSR_PT_SLAVE_A) + path_id = 0; + else + path_id = 1; + + set_hsr_tag_path(&hsr_ethhdr->hsr_tag, path_id); +} + +static struct sk_buff *hsr_fill_tag(struct sk_buff *skb, + struct hsr_frame_info *frame, + struct hsr_port *port, u8 proto_version) +{ + struct hsr_ethhdr *hsr_ethhdr; + int lsdu_size; + + /* pad to minimum packet size which is 60 + 6 (HSR tag) */ + if (skb_put_padto(skb, ETH_ZLEN + HSR_HLEN)) + return NULL; + + lsdu_size = skb->len - 14; + if (frame->is_vlan) + lsdu_size -= 4; + + hsr_ethhdr = (struct hsr_ethhdr *)skb_mac_header(skb); + + hsr_set_path_id(hsr_ethhdr, port); + set_hsr_tag_LSDU_size(&hsr_ethhdr->hsr_tag, lsdu_size); + hsr_ethhdr->hsr_tag.sequence_nr = htons(frame->sequence_nr); + hsr_ethhdr->hsr_tag.encap_proto = hsr_ethhdr->ethhdr.h_proto; + hsr_ethhdr->ethhdr.h_proto = htons(proto_version ? + ETH_P_HSR : ETH_P_PRP); + skb->protocol = hsr_ethhdr->ethhdr.h_proto; + + return skb; +} + +/* If the original frame was an HSR tagged frame, just clone it to be sent + * unchanged. Otherwise, create a private frame especially tagged for 'port'. 
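+ *
+ * Resulting layout produced by hsr_fill_tag() (sketch, VLAN tag omitted):
+ *
+ *   dst MAC | src MAC | ETH_P_HSR | path + LSDU size | sequence nr |
+ *   original EtherType | payload
+ *
+ * i.e. the original EtherType moves into the tag's encap_proto field and
+ * the outer protocol becomes ETH_P_HSR (ETH_P_PRP when prot_version is 0).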
+ */ +struct sk_buff *hsr_create_tagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port) +{ + unsigned char *dst, *src; + struct sk_buff *skb; + int movelen; + + if (frame->skb_hsr) { + struct hsr_ethhdr *hsr_ethhdr = + (struct hsr_ethhdr *)skb_mac_header(frame->skb_hsr); + + /* set the lane id properly */ + hsr_set_path_id(hsr_ethhdr, port); + return skb_clone(frame->skb_hsr, GFP_ATOMIC); + } else if (port->dev->features & NETIF_F_HW_HSR_TAG_INS) { + return skb_clone(frame->skb_std, GFP_ATOMIC); + } + + /* Create the new skb with enough headroom to fit the HSR tag */ + skb = __pskb_copy(frame->skb_std, + skb_headroom(frame->skb_std) + HSR_HLEN, GFP_ATOMIC); + if (!skb) + return NULL; + skb_reset_mac_header(skb); + + if (skb->ip_summed == CHECKSUM_PARTIAL) + skb->csum_start += HSR_HLEN; + + movelen = ETH_HLEN; + if (frame->is_vlan) + movelen += VLAN_HLEN; + + src = skb_mac_header(skb); + dst = skb_push(skb, HSR_HLEN); + memmove(dst, src, movelen); + skb_reset_mac_header(skb); + + /* skb_put_padto free skb on error and hsr_fill_tag returns NULL in + * that case + */ + return hsr_fill_tag(skb, frame, port, port->hsr->prot_version); +} + +struct sk_buff *prp_create_tagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port) +{ + struct sk_buff *skb; + + if (frame->skb_prp) { + struct prp_rct *trailer = skb_get_PRP_rct(frame->skb_prp); + + if (trailer) { + prp_set_lan_id(trailer, port); + } else { + WARN_ONCE(!trailer, "errored PRP skb"); + return NULL; + } + return skb_clone(frame->skb_prp, GFP_ATOMIC); + } else if (port->dev->features & NETIF_F_HW_HSR_TAG_INS) { + return skb_clone(frame->skb_std, GFP_ATOMIC); + } + + skb = skb_copy_expand(frame->skb_std, 0, + skb_tailroom(frame->skb_std) + HSR_HLEN, + GFP_ATOMIC); + return prp_fill_rct(skb, frame, port); +} + +static void hsr_deliver_master(struct sk_buff *skb, struct net_device *dev, + struct hsr_node *node_src) +{ + bool was_multicast_frame; + int res, recv_len; + + was_multicast_frame = (skb->pkt_type == PACKET_MULTICAST); + hsr_addr_subst_source(node_src, skb); + skb_pull(skb, ETH_HLEN); + recv_len = skb->len; + res = netif_rx(skb); + if (res == NET_RX_DROP) { + dev->stats.rx_dropped++; + } else { + dev->stats.rx_packets++; + dev->stats.rx_bytes += recv_len; + if (was_multicast_frame) + dev->stats.multicast++; + } +} + +static int hsr_xmit(struct sk_buff *skb, struct hsr_port *port, + struct hsr_frame_info *frame) +{ + if (frame->port_rcv->type == HSR_PT_MASTER) { + hsr_addr_subst_dest(frame->node_src, skb, port); + + /* Address substitution (IEC62439-3 pp 26, 50): replace mac + * address of outgoing frame with that of the outgoing slave's. 
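+ * This only applies to frames that originate on the master port
+ * (locally generated traffic, see the port_rcv check above); frames
+ * forwarded from the ring keep their original source address.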
+ */ + ether_addr_copy(eth_hdr(skb)->h_source, port->dev->dev_addr); + } + return dev_queue_xmit(skb); +} + +bool prp_drop_frame(struct hsr_frame_info *frame, struct hsr_port *port) +{ + return ((frame->port_rcv->type == HSR_PT_SLAVE_A && + port->type == HSR_PT_SLAVE_B) || + (frame->port_rcv->type == HSR_PT_SLAVE_B && + port->type == HSR_PT_SLAVE_A)); +} + +bool hsr_drop_frame(struct hsr_frame_info *frame, struct hsr_port *port) +{ + if (port->dev->features & NETIF_F_HW_HSR_FWD) + return prp_drop_frame(frame, port); + + return false; +} + +/* Forward the frame through all devices except: + * - Back through the receiving device + * - If it's a HSR frame: through a device where it has passed before + * - if it's a PRP frame: through another PRP slave device (no bridge) + * - To the local HSR master only if the frame is directly addressed to it, or + * a non-supervision multicast or broadcast frame. + * + * HSR slave devices should insert a HSR tag into the frame, or forward the + * frame unchanged if it's already tagged. Interlink devices should strip HSR + * tags if they're of the non-HSR type (but only after duplicate discard). The + * master device always strips HSR tags. + */ +static void hsr_forward_do(struct hsr_frame_info *frame) +{ + struct hsr_port *port; + struct sk_buff *skb; + bool sent = false; + + hsr_for_each_port(frame->port_rcv->hsr, port) { + struct hsr_priv *hsr = port->hsr; + /* Don't send frame back the way it came */ + if (port == frame->port_rcv) + continue; + + /* Don't deliver locally unless we should */ + if (port->type == HSR_PT_MASTER && !frame->is_local_dest) + continue; + + /* Deliver frames directly addressed to us to master only */ + if (port->type != HSR_PT_MASTER && frame->is_local_exclusive) + continue; + + /* If hardware duplicate generation is enabled, only send out + * one port. + */ + if ((port->dev->features & NETIF_F_HW_HSR_DUP) && sent) + continue; + + /* Don't send frame over port where it has been sent before. + * Also fro SAN, this shouldn't be done. + */ + if (!frame->is_from_san && + hsr_register_frame_out(port, frame->node_src, + frame->sequence_nr)) + continue; + + if (frame->is_supervision && port->type == HSR_PT_MASTER) { + hsr_handle_sup_frame(frame); + continue; + } + + /* Check if frame is to be dropped. Eg. for PRP no forward + * between ports. 
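+ * prp_drop_frame() above rejects slave A -> slave B and vice versa;
+ * hsr_drop_frame() only defers to that rule when the NIC forwards in
+ * hardware (NETIF_F_HW_HSR_FWD).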
+ */ + if (hsr->proto_ops->drop_frame && + hsr->proto_ops->drop_frame(frame, port)) + continue; + + if (port->type != HSR_PT_MASTER) + skb = hsr->proto_ops->create_tagged_frame(frame, port); + else + skb = hsr->proto_ops->get_untagged_frame(frame, port); + + if (!skb) { + frame->port_rcv->dev->stats.rx_dropped++; + continue; + } + + skb->dev = port->dev; + if (port->type == HSR_PT_MASTER) { + hsr_deliver_master(skb, port->dev, frame->node_src); + } else { + if (!hsr_xmit(skb, port, frame)) + sent = true; + } + } +} + +static void check_local_dest(struct hsr_priv *hsr, struct sk_buff *skb, + struct hsr_frame_info *frame) +{ + if (hsr_addr_is_self(hsr, eth_hdr(skb)->h_dest)) { + frame->is_local_exclusive = true; + skb->pkt_type = PACKET_HOST; + } else { + frame->is_local_exclusive = false; + } + + if (skb->pkt_type == PACKET_HOST || + skb->pkt_type == PACKET_MULTICAST || + skb->pkt_type == PACKET_BROADCAST) { + frame->is_local_dest = true; + } else { + frame->is_local_dest = false; + } +} + +static void handle_std_frame(struct sk_buff *skb, + struct hsr_frame_info *frame) +{ + struct hsr_port *port = frame->port_rcv; + struct hsr_priv *hsr = port->hsr; + + frame->skb_hsr = NULL; + frame->skb_prp = NULL; + frame->skb_std = skb; + + if (port->type != HSR_PT_MASTER) { + frame->is_from_san = true; + } else { + /* Sequence nr for the master node */ + lockdep_assert_held(&hsr->seqnr_lock); + frame->sequence_nr = hsr->sequence_nr; + hsr->sequence_nr++; + } +} + +int hsr_fill_frame_info(__be16 proto, struct sk_buff *skb, + struct hsr_frame_info *frame) +{ + struct hsr_port *port = frame->port_rcv; + struct hsr_priv *hsr = port->hsr; + + /* HSRv0 supervisory frames double as a tag so treat them as tagged. */ + if ((!hsr->prot_version && proto == htons(ETH_P_PRP)) || + proto == htons(ETH_P_HSR)) { + /* Check if skb contains hsr_ethhdr */ + if (skb->mac_len < sizeof(struct hsr_ethhdr)) + return -EINVAL; + + /* HSR tagged frame :- Data or Supervision */ + frame->skb_std = NULL; + frame->skb_prp = NULL; + frame->skb_hsr = skb; + frame->sequence_nr = hsr_get_skb_sequence_nr(skb); + return 0; + } + + /* Standard frame or PRP from master port */ + handle_std_frame(skb, frame); + + return 0; +} + +int prp_fill_frame_info(__be16 proto, struct sk_buff *skb, + struct hsr_frame_info *frame) +{ + /* Supervision frame */ + struct prp_rct *rct = skb_get_PRP_rct(skb); + + if (rct && + prp_check_lsdu_size(skb, rct, frame->is_supervision)) { + frame->skb_hsr = NULL; + frame->skb_std = NULL; + frame->skb_prp = skb; + frame->sequence_nr = prp_get_skb_sequence_nr(rct); + return 0; + } + handle_std_frame(skb, frame); + + return 0; +} + +static int fill_frame_info(struct hsr_frame_info *frame, + struct sk_buff *skb, struct hsr_port *port) +{ + struct hsr_priv *hsr = port->hsr; + struct hsr_vlan_ethhdr *vlan_hdr; + struct ethhdr *ethhdr; + __be16 proto; + int ret; + + /* Check if skb contains ethhdr */ + if (skb->mac_len < sizeof(struct ethhdr)) + return -EINVAL; + + memset(frame, 0, sizeof(*frame)); + frame->is_supervision = is_supervision_frame(port->hsr, skb); + frame->node_src = hsr_get_node(port, &hsr->node_db, skb, + frame->is_supervision, + port->type); + if (!frame->node_src) + return -1; /* Unknown node and !is_supervision, or no mem */ + + ethhdr = (struct ethhdr *)skb_mac_header(skb); + frame->is_vlan = false; + proto = ethhdr->h_proto; + + if (proto == htons(ETH_P_8021Q)) + frame->is_vlan = true; + + if (frame->is_vlan) { + vlan_hdr = (struct hsr_vlan_ethhdr *)ethhdr; + proto = 
vlan_hdr->vlanhdr.h_vlan_encapsulated_proto; + /* FIXME: */ + netdev_warn_once(skb->dev, "VLAN not yet supported"); + return -EINVAL; + } + + frame->is_from_san = false; + frame->port_rcv = port; + ret = hsr->proto_ops->fill_frame_info(proto, skb, frame); + if (ret) + return ret; + + check_local_dest(port->hsr, skb, frame); + + return 0; +} + +/* Must be called holding rcu read lock (because of the port parameter) */ +void hsr_forward_skb(struct sk_buff *skb, struct hsr_port *port) +{ + struct hsr_frame_info frame; + + rcu_read_lock(); + if (fill_frame_info(&frame, skb, port) < 0) + goto out_drop; + + hsr_register_frame_in(frame.node_src, port, frame.sequence_nr); + hsr_forward_do(&frame); + rcu_read_unlock(); + /* Gets called for ingress frames as well as egress from master port. + * So check and increment stats for master port only here. + */ + if (port->type == HSR_PT_MASTER) { + port->dev->stats.tx_packets++; + port->dev->stats.tx_bytes += skb->len; + } + + kfree_skb(frame.skb_hsr); + kfree_skb(frame.skb_prp); + kfree_skb(frame.skb_std); + return; + +out_drop: + rcu_read_unlock(); + port->dev->stats.tx_dropped++; + kfree_skb(skb); +} diff --git a/net/hsr/hsr_forward.h b/net/hsr/hsr_forward.h new file mode 100644 index 000000000..206636750 --- /dev/null +++ b/net/hsr/hsr_forward.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. + */ + +#ifndef __HSR_FORWARD_H +#define __HSR_FORWARD_H + +#include +#include "hsr_main.h" + +void hsr_forward_skb(struct sk_buff *skb, struct hsr_port *port); +struct sk_buff *prp_create_tagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port); +struct sk_buff *hsr_create_tagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port); +struct sk_buff *hsr_get_untagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port); +struct sk_buff *prp_get_untagged_frame(struct hsr_frame_info *frame, + struct hsr_port *port); +bool prp_drop_frame(struct hsr_frame_info *frame, struct hsr_port *port); +bool hsr_drop_frame(struct hsr_frame_info *frame, struct hsr_port *port); +int prp_fill_frame_info(__be16 proto, struct sk_buff *skb, + struct hsr_frame_info *frame); +int hsr_fill_frame_info(__be16 proto, struct sk_buff *skb, + struct hsr_frame_info *frame); +#endif /* __HSR_FORWARD_H */ diff --git a/net/hsr/hsr_framereg.c b/net/hsr/hsr_framereg.c new file mode 100644 index 000000000..0b0199878 --- /dev/null +++ b/net/hsr/hsr_framereg.c @@ -0,0 +1,640 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * The HSR spec says never to forward the same frame twice on the same + * interface. A frame is identified by its source MAC address and its HSR + * sequence number. This code keeps track of senders and their sequence numbers + * to allow filtering of duplicate frames, and to detect HSR ring errors. + * Same code handles filtering of duplicates for PRP as well. + */ + +#include +#include +#include +#include +#include "hsr_main.h" +#include "hsr_framereg.h" +#include "hsr_netlink.h" + +/* seq_nr_after(a, b) - return true if a is after (higher in sequence than) b, + * false otherwise. 
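A quick stand-alone check of the wraparound comparison that seq_nr_after(), defined just below, performs on 16-bit sequence numbers (the x_* names are illustrative):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Same arithmetic as seq_nr_after() below: two's-complement wraparound on
 * 16-bit counters, with the half-range tie broken in one direction.
 */
static bool x_seq_nr_after(uint16_t a, uint16_t b)
{
        if ((int)b - (int)a == 32768)
                return false;
        return (int16_t)(b - a) < 0;
}

int main(void)
{
        /* 5 follows 65530: the counter wrapped. */
        printf("%d\n", x_seq_nr_after(5, 65530));       /* 1 */
        /* 100 does not follow 200. */
        printf("%d\n", x_seq_nr_after(100, 200));       /* 0 */
        /* Exactly 32768 apart: the tie is broken one way so that
         * seq_nr_after() and seq_nr_before() can never both be true.
         */
        printf("%d\n", x_seq_nr_after(0, 32768));       /* 0 */
        printf("%d\n", x_seq_nr_after(32768, 0));       /* 1 */
        return 0;
}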
+ */ +static bool seq_nr_after(u16 a, u16 b) +{ + /* Remove inconsistency where + * seq_nr_after(a, b) == seq_nr_before(a, b) + */ + if ((int)b - a == 32768) + return false; + + return (((s16)(b - a)) < 0); +} + +#define seq_nr_before(a, b) seq_nr_after((b), (a)) +#define seq_nr_before_or_eq(a, b) (!seq_nr_after((a), (b))) + +bool hsr_addr_is_self(struct hsr_priv *hsr, unsigned char *addr) +{ + struct hsr_node *node; + + node = list_first_or_null_rcu(&hsr->self_node_db, struct hsr_node, + mac_list); + if (!node) { + WARN_ONCE(1, "HSR: No self node\n"); + return false; + } + + if (ether_addr_equal(addr, node->macaddress_A)) + return true; + if (ether_addr_equal(addr, node->macaddress_B)) + return true; + + return false; +} + +/* Search for mac entry. Caller must hold rcu read lock. + */ +static struct hsr_node *find_node_by_addr_A(struct list_head *node_db, + const unsigned char addr[ETH_ALEN]) +{ + struct hsr_node *node; + + list_for_each_entry_rcu(node, node_db, mac_list) { + if (ether_addr_equal(node->macaddress_A, addr)) + return node; + } + + return NULL; +} + +/* Helper for device init; the self_node_db is used in hsr_rcv() to recognize + * frames from self that's been looped over the HSR ring. + */ +int hsr_create_self_node(struct hsr_priv *hsr, + const unsigned char addr_a[ETH_ALEN], + const unsigned char addr_b[ETH_ALEN]) +{ + struct list_head *self_node_db = &hsr->self_node_db; + struct hsr_node *node, *oldnode; + + node = kmalloc(sizeof(*node), GFP_KERNEL); + if (!node) + return -ENOMEM; + + ether_addr_copy(node->macaddress_A, addr_a); + ether_addr_copy(node->macaddress_B, addr_b); + + spin_lock_bh(&hsr->list_lock); + oldnode = list_first_or_null_rcu(self_node_db, + struct hsr_node, mac_list); + if (oldnode) { + list_replace_rcu(&oldnode->mac_list, &node->mac_list); + spin_unlock_bh(&hsr->list_lock); + kfree_rcu(oldnode, rcu_head); + } else { + list_add_tail_rcu(&node->mac_list, self_node_db); + spin_unlock_bh(&hsr->list_lock); + } + + return 0; +} + +void hsr_del_self_node(struct hsr_priv *hsr) +{ + struct list_head *self_node_db = &hsr->self_node_db; + struct hsr_node *node; + + spin_lock_bh(&hsr->list_lock); + node = list_first_or_null_rcu(self_node_db, struct hsr_node, mac_list); + if (node) { + list_del_rcu(&node->mac_list); + kfree_rcu(node, rcu_head); + } + spin_unlock_bh(&hsr->list_lock); +} + +void hsr_del_nodes(struct list_head *node_db) +{ + struct hsr_node *node; + struct hsr_node *tmp; + + list_for_each_entry_safe(node, tmp, node_db, mac_list) + kfree(node); +} + +void prp_handle_san_frame(bool san, enum hsr_port_type port, + struct hsr_node *node) +{ + /* Mark if the SAN node is over LAN_A or LAN_B */ + if (port == HSR_PT_SLAVE_A) { + node->san_a = true; + return; + } + + if (port == HSR_PT_SLAVE_B) + node->san_b = true; +} + +/* Allocate an hsr_node and add it to node_db. 'addr' is the node's address_A; + * seq_out is used to initialize filtering of outgoing duplicate frames + * originating from the newly added node. + */ +static struct hsr_node *hsr_add_node(struct hsr_priv *hsr, + struct list_head *node_db, + unsigned char addr[], + u16 seq_out, bool san, + enum hsr_port_type rx_port) +{ + struct hsr_node *new_node, *node; + unsigned long now; + int i; + + new_node = kzalloc(sizeof(*new_node), GFP_ATOMIC); + if (!new_node) + return NULL; + + ether_addr_copy(new_node->macaddress_A, addr); + spin_lock_init(&new_node->seq_out_lock); + + /* We are only interested in time diffs here, so use current jiffies + * as initialization. 
(0 could trigger a spurious ring error warning). */ + now = jiffies; + for (i = 0; i < HSR_PT_PORTS; i++) { + new_node->time_in[i] = now; + new_node->time_out[i] = now; + } + for (i = 0; i < HSR_PT_PORTS; i++) + new_node->seq_out[i] = seq_out; + + if (san && hsr->proto_ops->handle_san_frame) + hsr->proto_ops->handle_san_frame(san, rx_port, new_node); + + spin_lock_bh(&hsr->list_lock); + list_for_each_entry_rcu(node, node_db, mac_list, + lockdep_is_held(&hsr->list_lock)) { + if (ether_addr_equal(node->macaddress_A, addr)) + goto out; + if (ether_addr_equal(node->macaddress_B, addr)) + goto out; + } + list_add_tail_rcu(&new_node->mac_list, node_db); + spin_unlock_bh(&hsr->list_lock); + return new_node; +out: + spin_unlock_bh(&hsr->list_lock); + kfree(new_node); + return node; +} + +void prp_update_san_info(struct hsr_node *node, bool is_sup) +{ + if (!is_sup) + return; + + node->san_a = false; + node->san_b = false; +} + +/* Get the hsr_node from which 'skb' was sent. + */ +struct hsr_node *hsr_get_node(struct hsr_port *port, struct list_head *node_db, + struct sk_buff *skb, bool is_sup, + enum hsr_port_type rx_port) +{ + struct hsr_priv *hsr = port->hsr; + struct hsr_node *node; + struct ethhdr *ethhdr; + struct prp_rct *rct; + bool san = false; + u16 seq_out; + + if (!skb_mac_header_was_set(skb)) + return NULL; + + ethhdr = (struct ethhdr *)skb_mac_header(skb); + + list_for_each_entry_rcu(node, node_db, mac_list) { + if (ether_addr_equal(node->macaddress_A, ethhdr->h_source)) { + if (hsr->proto_ops->update_san_info) + hsr->proto_ops->update_san_info(node, is_sup); + return node; + } + if (ether_addr_equal(node->macaddress_B, ethhdr->h_source)) { + if (hsr->proto_ops->update_san_info) + hsr->proto_ops->update_san_info(node, is_sup); + return node; + } + } + + /* Create a node entry for any node connected to the HSR/PRP + * device. + */ + if (ethhdr->h_proto == htons(ETH_P_PRP) || + ethhdr->h_proto == htons(ETH_P_HSR)) { + /* Use the existing sequence_nr from the tag as starting point + * for filtering duplicate frames. + */ + seq_out = hsr_get_skb_sequence_nr(skb) - 1; + } else { + rct = skb_get_PRP_rct(skb); + if (rct && prp_check_lsdu_size(skb, rct, is_sup)) { + seq_out = prp_get_skb_sequence_nr(rct); + } else { + if (rx_port != HSR_PT_MASTER) + san = true; + seq_out = HSR_SEQNR_START; + } + } + + return hsr_add_node(hsr, node_db, ethhdr->h_source, seq_out, + san, rx_port); +} + +/* Use the Supervision frame's info about a possible macaddress_B to merge + * nodes that have previously had their macaddress_B registered as a separate + * node. + */ +void hsr_handle_sup_frame(struct hsr_frame_info *frame) +{ + struct hsr_node *node_curr = frame->node_src; + struct hsr_port *port_rcv = frame->port_rcv; + struct hsr_priv *hsr = port_rcv->hsr; + struct hsr_sup_payload *hsr_sp; + struct hsr_sup_tlv *hsr_sup_tlv; + struct hsr_node *node_real; + struct sk_buff *skb = NULL; + struct list_head *node_db; + struct ethhdr *ethhdr; + int i; + unsigned int pull_size = 0; + unsigned int total_pull_size = 0; + + /* Here either frame->skb_hsr or frame->skb_prp should be + * valid, as a supervision frame will always have protocol + * header info. + */ + if (frame->skb_hsr) + skb = frame->skb_hsr; + else if (frame->skb_prp) + skb = frame->skb_prp; + else if (frame->skb_std) + skb = frame->skb_std; + if (!skb) + return; + + /* Leave the ethernet header.
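The successive skb_pull() calls that follow simply walk the supervision PDU: Ethernet header, an HSR tag when the outer EtherType is ETH_P_HSR, the supervision tag, then the payload carrying macaddress_A, optionally followed by a second TLV (e.g. the RedBox MAC). A compact restatement of those offsets, with sizes assumed from the struct definitions in hsr_main.h and x_* names made up for the sketch:

#include <stddef.h>

/* Sizes assumed from struct ethhdr and the hsr_main.h definitions. */
#define X_ETH_HLEN        14   /* struct ethhdr */
#define X_HSR_TAG_LEN      6   /* struct hsr_tag (only when EtherType is HSR) */
#define X_HSR_SUP_TAG_LEN  6   /* struct hsr_sup_tag */
#define X_SUP_PAYLOAD_LEN  6   /* struct hsr_sup_payload: macaddress_A */

/* Offset of macaddress_A inside a supervision frame, i.e. the sum of the
 * skb_pull() sizes performed by hsr_handle_sup_frame().
 */
static inline size_t x_sup_payload_offset(int outer_proto_is_hsr)
{
        size_t off = X_ETH_HLEN;

        if (outer_proto_is_hsr)          /* HSR-tagged supervision frame */
                off += X_HSR_TAG_LEN;
        return off + X_HSR_SUP_TAG_LEN;
}

/* Offset of an optional second TLV (e.g. the RedBox MAC address TLV). */
static inline size_t x_sup_second_tlv_offset(int outer_proto_is_hsr)
{
        return x_sup_payload_offset(outer_proto_is_hsr) + X_SUP_PAYLOAD_LEN;
}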
*/ + pull_size = sizeof(struct ethhdr); + skb_pull(skb, pull_size); + total_pull_size += pull_size; + + ethhdr = (struct ethhdr *)skb_mac_header(skb); + + /* And leave the HSR tag. */ + if (ethhdr->h_proto == htons(ETH_P_HSR)) { + pull_size = sizeof(struct hsr_tag); + skb_pull(skb, pull_size); + total_pull_size += pull_size; + } + + /* And leave the HSR sup tag. */ + pull_size = sizeof(struct hsr_sup_tag); + skb_pull(skb, pull_size); + total_pull_size += pull_size; + + /* get HSR sup payload */ + hsr_sp = (struct hsr_sup_payload *)skb->data; + + /* Merge node_curr (registered on macaddress_B) into node_real */ + node_db = &port_rcv->hsr->node_db; + node_real = find_node_by_addr_A(node_db, hsr_sp->macaddress_A); + if (!node_real) + /* No frame received from AddrA of this node yet */ + node_real = hsr_add_node(hsr, node_db, hsr_sp->macaddress_A, + HSR_SEQNR_START - 1, true, + port_rcv->type); + if (!node_real) + goto done; /* No mem */ + if (node_real == node_curr) + /* Node has already been merged */ + goto done; + + /* Leave the first HSR sup payload. */ + pull_size = sizeof(struct hsr_sup_payload); + skb_pull(skb, pull_size); + total_pull_size += pull_size; + + /* Get second supervision tlv */ + hsr_sup_tlv = (struct hsr_sup_tlv *)skb->data; + /* And check if it is a redbox mac TLV */ + if (hsr_sup_tlv->HSR_TLV_type == PRP_TLV_REDBOX_MAC) { + /* We could stop here after pushing hsr_sup_payload, + * or proceed and allow macaddress_B and for redboxes. + */ + /* Sanity check length */ + if (hsr_sup_tlv->HSR_TLV_length != 6) + goto done; + + /* Leave the second HSR sup tlv. */ + pull_size = sizeof(struct hsr_sup_tlv); + skb_pull(skb, pull_size); + total_pull_size += pull_size; + + /* Get redbox mac address. */ + hsr_sp = (struct hsr_sup_payload *)skb->data; + + /* Check if redbox mac and node mac are equal. */ + if (!ether_addr_equal(node_real->macaddress_A, hsr_sp->macaddress_A)) { + /* This is a redbox supervision frame for a VDAN! */ + goto done; + } + } + + ether_addr_copy(node_real->macaddress_B, ethhdr->h_source); + spin_lock_bh(&node_real->seq_out_lock); + for (i = 0; i < HSR_PT_PORTS; i++) { + if (!node_curr->time_in_stale[i] && + time_after(node_curr->time_in[i], node_real->time_in[i])) { + node_real->time_in[i] = node_curr->time_in[i]; + node_real->time_in_stale[i] = + node_curr->time_in_stale[i]; + } + if (seq_nr_after(node_curr->seq_out[i], node_real->seq_out[i])) + node_real->seq_out[i] = node_curr->seq_out[i]; + } + spin_unlock_bh(&node_real->seq_out_lock); + node_real->addr_B_port = port_rcv->type; + + spin_lock_bh(&hsr->list_lock); + if (!node_curr->removed) { + list_del_rcu(&node_curr->mac_list); + node_curr->removed = true; + kfree_rcu(node_curr, rcu_head); + } + spin_unlock_bh(&hsr->list_lock); + +done: + /* Push back here */ + skb_push(skb, total_pull_size); +} + +/* 'skb' is a frame meant for this host, that is to be passed to upper layers. + * + * If the frame was sent by a node's B interface, replace the source + * address with that node's "official" address (macaddress_A) so that upper + * layers recognize where it came from. + */ +void hsr_addr_subst_source(struct hsr_node *node, struct sk_buff *skb) +{ + if (!skb_mac_header_was_set(skb)) { + WARN_ONCE(1, "%s: Mac header not set\n", __func__); + return; + } + + memcpy(ð_hdr(skb)->h_source, node->macaddress_A, ETH_ALEN); +} + +/* 'skb' is a frame meant for another host. 
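The substitution described in the comment that continues below is essentially a table lookup keyed on the peer's address A: when the frame leaves through the slave on which that peer's B interface was learned, the destination MAC is rewritten to macaddress_B. A stand-alone sketch with illustrative x_* types, not the kernel structures:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define X_ETH_ALEN 6

enum x_port_type { X_PT_NONE, X_PT_SLAVE_A, X_PT_SLAVE_B };

struct x_node {
        uint8_t mac_a[X_ETH_ALEN];      /* the peer's "official" address */
        uint8_t mac_b[X_ETH_ALEN];      /* its second interface, if known */
        bool mac_b_valid;
        enum x_port_type addr_b_port;   /* slave its B traffic arrives on */
};

/* Rewrite 'dest' in place when sending towards the peer's B interface,
 * mirroring find_node_by_addr_A() plus hsr_addr_subst_dest() below.
 */
void x_subst_dest(const struct x_node *table, size_t n_nodes,
                  uint8_t dest[X_ETH_ALEN], enum x_port_type out_port)
{
        size_t i;

        for (i = 0; i < n_nodes; i++) {
                const struct x_node *node = &table[i];

                if (memcmp(node->mac_a, dest, X_ETH_ALEN) != 0)
                        continue;
                /* Only substitute on the slave where B was learned. */
                if (node->mac_b_valid && out_port == node->addr_b_port)
                        memcpy(dest, node->mac_b, X_ETH_ALEN);
                return;
        }
}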
+ * 'port' is the outgoing interface + * + * Substitute the target (dest) MAC address if necessary, so that it matches the + * recipient interface MAC address, regardless of whether that is the + * recipient's A or B interface. + * This is needed to keep the packets flowing through switches that learn on + * which "side" the different interfaces are. + */ +void hsr_addr_subst_dest(struct hsr_node *node_src, struct sk_buff *skb, + struct hsr_port *port) +{ + struct hsr_node *node_dst; + + if (!skb_mac_header_was_set(skb)) { + WARN_ONCE(1, "%s: Mac header not set\n", __func__); + return; + } + + if (!is_unicast_ether_addr(eth_hdr(skb)->h_dest)) + return; + + node_dst = find_node_by_addr_A(&port->hsr->node_db, + eth_hdr(skb)->h_dest); + if (!node_dst) { + if (port->hsr->prot_version != PRP_V1 && net_ratelimit()) + netdev_err(skb->dev, "%s: Unknown node\n", __func__); + return; + } + if (port->type != node_dst->addr_B_port) + return; + + if (is_valid_ether_addr(node_dst->macaddress_B)) + ether_addr_copy(eth_hdr(skb)->h_dest, node_dst->macaddress_B); +} + +void hsr_register_frame_in(struct hsr_node *node, struct hsr_port *port, + u16 sequence_nr) +{ + /* Don't register incoming frames without a valid sequence number. This + * ensures entries of restarted nodes get pruned so that they can + * re-register and resume communications. + */ + if (!(port->dev->features & NETIF_F_HW_HSR_TAG_RM) && + seq_nr_before(sequence_nr, node->seq_out[port->type])) + return; + + node->time_in[port->type] = jiffies; + node->time_in_stale[port->type] = false; +} + +/* 'skb' is an HSR Ethernet frame (with an HSR tag inserted), with a valid + * ethhdr->h_source address and skb->mac_header set. + * + * Return: + * 1 if frame can be shown to have been sent recently on this interface, + * 0 otherwise, or + * negative error code on error + */ +int hsr_register_frame_out(struct hsr_port *port, struct hsr_node *node, + u16 sequence_nr) +{ + spin_lock_bh(&node->seq_out_lock); + if (seq_nr_before_or_eq(sequence_nr, node->seq_out[port->type]) && + time_is_after_jiffies(node->time_out[port->type] + + msecs_to_jiffies(HSR_ENTRY_FORGET_TIME))) { + spin_unlock_bh(&node->seq_out_lock); + return 1; + } + + node->time_out[port->type] = jiffies; + node->seq_out[port->type] = sequence_nr; + spin_unlock_bh(&node->seq_out_lock); + return 0; +} + +static struct hsr_port *get_late_port(struct hsr_priv *hsr, + struct hsr_node *node) +{ + if (node->time_in_stale[HSR_PT_SLAVE_A]) + return hsr_port_get_hsr(hsr, HSR_PT_SLAVE_A); + if (node->time_in_stale[HSR_PT_SLAVE_B]) + return hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + + if (time_after(node->time_in[HSR_PT_SLAVE_B], + node->time_in[HSR_PT_SLAVE_A] + + msecs_to_jiffies(MAX_SLAVE_DIFF))) + return hsr_port_get_hsr(hsr, HSR_PT_SLAVE_A); + if (time_after(node->time_in[HSR_PT_SLAVE_A], + node->time_in[HSR_PT_SLAVE_B] + + msecs_to_jiffies(MAX_SLAVE_DIFF))) + return hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + + return NULL; +} + +/* Remove stale sequence_nr records. Called by timer every + * HSR_LIFE_CHECK_INTERVAL (two seconds or so). + */ +void hsr_prune_nodes(struct timer_list *t) +{ + struct hsr_priv *hsr = from_timer(hsr, t, prune_timer); + struct hsr_node *node; + struct hsr_node *tmp; + struct hsr_port *port; + unsigned long timestamp; + unsigned long time_a, time_b; + + spin_lock_bh(&hsr->list_lock); + list_for_each_entry_safe(node, tmp, &hsr->node_db, mac_list) { + /* Don't prune own node.
Neither time_in[HSR_PT_SLAVE_A] + * nor time_in[HSR_PT_SLAVE_B], will ever be updated for + * the master port. Thus the master node will be repeatedly + * pruned leading to packet loss. + */ + if (hsr_addr_is_self(hsr, node->macaddress_A)) + continue; + + /* Shorthand */ + time_a = node->time_in[HSR_PT_SLAVE_A]; + time_b = node->time_in[HSR_PT_SLAVE_B]; + + /* Check for timestamps old enough to risk wrap-around */ + if (time_after(jiffies, time_a + MAX_JIFFY_OFFSET / 2)) + node->time_in_stale[HSR_PT_SLAVE_A] = true; + if (time_after(jiffies, time_b + MAX_JIFFY_OFFSET / 2)) + node->time_in_stale[HSR_PT_SLAVE_B] = true; + + /* Get age of newest frame from node. + * At least one time_in is OK here; nodes get pruned long + * before both time_ins can get stale + */ + timestamp = time_a; + if (node->time_in_stale[HSR_PT_SLAVE_A] || + (!node->time_in_stale[HSR_PT_SLAVE_B] && + time_after(time_b, time_a))) + timestamp = time_b; + + /* Warn of ring error only as long as we get frames at all */ + if (time_is_after_jiffies(timestamp + + msecs_to_jiffies(1.5 * MAX_SLAVE_DIFF))) { + rcu_read_lock(); + port = get_late_port(hsr, node); + if (port) + hsr_nl_ringerror(hsr, node->macaddress_A, port); + rcu_read_unlock(); + } + + /* Prune old entries */ + if (time_is_before_jiffies(timestamp + + msecs_to_jiffies(HSR_NODE_FORGET_TIME))) { + hsr_nl_nodedown(hsr, node->macaddress_A); + if (!node->removed) { + list_del_rcu(&node->mac_list); + node->removed = true; + /* Note that we need to free this entry later: */ + kfree_rcu(node, rcu_head); + } + } + } + spin_unlock_bh(&hsr->list_lock); + + /* Restart timer */ + mod_timer(&hsr->prune_timer, + jiffies + msecs_to_jiffies(PRUNE_PERIOD)); +} + +void *hsr_get_next_node(struct hsr_priv *hsr, void *_pos, + unsigned char addr[ETH_ALEN]) +{ + struct hsr_node *node; + + if (!_pos) { + node = list_first_or_null_rcu(&hsr->node_db, + struct hsr_node, mac_list); + if (node) + ether_addr_copy(addr, node->macaddress_A); + return node; + } + + node = _pos; + list_for_each_entry_continue_rcu(node, &hsr->node_db, mac_list) { + ether_addr_copy(addr, node->macaddress_A); + return node; + } + + return NULL; +} + +int hsr_get_node_data(struct hsr_priv *hsr, + const unsigned char *addr, + unsigned char addr_b[ETH_ALEN], + unsigned int *addr_b_ifindex, + int *if1_age, + u16 *if1_seq, + int *if2_age, + u16 *if2_seq) +{ + struct hsr_node *node; + struct hsr_port *port; + unsigned long tdiff; + + node = find_node_by_addr_A(&hsr->node_db, addr); + if (!node) + return -ENOENT; + + ether_addr_copy(addr_b, node->macaddress_B); + + tdiff = jiffies - node->time_in[HSR_PT_SLAVE_A]; + if (node->time_in_stale[HSR_PT_SLAVE_A]) + *if1_age = INT_MAX; +#if HZ <= MSEC_PER_SEC + else if (tdiff > msecs_to_jiffies(INT_MAX)) + *if1_age = INT_MAX; +#endif + else + *if1_age = jiffies_to_msecs(tdiff); + + tdiff = jiffies - node->time_in[HSR_PT_SLAVE_B]; + if (node->time_in_stale[HSR_PT_SLAVE_B]) + *if2_age = INT_MAX; +#if HZ <= MSEC_PER_SEC + else if (tdiff > msecs_to_jiffies(INT_MAX)) + *if2_age = INT_MAX; +#endif + else + *if2_age = jiffies_to_msecs(tdiff); + + /* Present sequence numbers as if they were incoming on interface */ + *if1_seq = node->seq_out[HSR_PT_SLAVE_B]; + *if2_seq = node->seq_out[HSR_PT_SLAVE_A]; + + if (node->addr_B_port != HSR_PT_NONE) { + port = hsr_port_get_hsr(hsr, node->addr_B_port); + *addr_b_ifindex = port->dev->ifindex; + } else { + *addr_b_ifindex = -1; + } + + return 0; +} diff --git a/net/hsr/hsr_framereg.h b/net/hsr/hsr_framereg.h new file mode 100644 index 
000000000..b23556251 --- /dev/null +++ b/net/hsr/hsr_framereg.h @@ -0,0 +1,89 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. + */ + +#ifndef __HSR_FRAMEREG_H +#define __HSR_FRAMEREG_H + +#include "hsr_main.h" + +struct hsr_node; + +struct hsr_frame_info { + struct sk_buff *skb_std; + struct sk_buff *skb_hsr; + struct sk_buff *skb_prp; + struct hsr_port *port_rcv; + struct hsr_node *node_src; + u16 sequence_nr; + bool is_supervision; + bool is_vlan; + bool is_local_dest; + bool is_local_exclusive; + bool is_from_san; +}; + +void hsr_del_self_node(struct hsr_priv *hsr); +void hsr_del_nodes(struct list_head *node_db); +struct hsr_node *hsr_get_node(struct hsr_port *port, struct list_head *node_db, + struct sk_buff *skb, bool is_sup, + enum hsr_port_type rx_port); +void hsr_handle_sup_frame(struct hsr_frame_info *frame); +bool hsr_addr_is_self(struct hsr_priv *hsr, unsigned char *addr); + +void hsr_addr_subst_source(struct hsr_node *node, struct sk_buff *skb); +void hsr_addr_subst_dest(struct hsr_node *node_src, struct sk_buff *skb, + struct hsr_port *port); + +void hsr_register_frame_in(struct hsr_node *node, struct hsr_port *port, + u16 sequence_nr); +int hsr_register_frame_out(struct hsr_port *port, struct hsr_node *node, + u16 sequence_nr); + +void hsr_prune_nodes(struct timer_list *t); + +int hsr_create_self_node(struct hsr_priv *hsr, + const unsigned char addr_a[ETH_ALEN], + const unsigned char addr_b[ETH_ALEN]); + +void *hsr_get_next_node(struct hsr_priv *hsr, void *_pos, + unsigned char addr[ETH_ALEN]); + +int hsr_get_node_data(struct hsr_priv *hsr, + const unsigned char *addr, + unsigned char addr_b[ETH_ALEN], + unsigned int *addr_b_ifindex, + int *if1_age, + u16 *if1_seq, + int *if2_age, + u16 *if2_seq); + +void prp_handle_san_frame(bool san, enum hsr_port_type port, + struct hsr_node *node); +void prp_update_san_info(struct hsr_node *node, bool is_sup); + +struct hsr_node { + struct list_head mac_list; + /* Protect R/W access to seq_out */ + spinlock_t seq_out_lock; + unsigned char macaddress_A[ETH_ALEN]; + unsigned char macaddress_B[ETH_ALEN]; + /* Local slave through which AddrB frames are received from this node */ + enum hsr_port_type addr_B_port; + unsigned long time_in[HSR_PT_PORTS]; + bool time_in_stale[HSR_PT_PORTS]; + unsigned long time_out[HSR_PT_PORTS]; + /* if the node is a SAN */ + bool san_a; + bool san_b; + u16 seq_out[HSR_PT_PORTS]; + bool removed; + struct rcu_head rcu_head; +}; + +#endif /* __HSR_FRAMEREG_H */ diff --git a/net/hsr/hsr_main.c b/net/hsr/hsr_main.c new file mode 100644 index 000000000..b099c3150 --- /dev/null +++ b/net/hsr/hsr_main.c @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * Event handling for HSR and PRP devices. 
+ */ + +#include +#include +#include +#include +#include +#include "hsr_main.h" +#include "hsr_device.h" +#include "hsr_netlink.h" +#include "hsr_framereg.h" +#include "hsr_slave.h" + +static bool hsr_slave_empty(struct hsr_priv *hsr) +{ + struct hsr_port *port; + + hsr_for_each_port(hsr, port) + if (port->type != HSR_PT_MASTER) + return false; + return true; +} + +static int hsr_netdev_notify(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct hsr_port *port, *master; + struct net_device *dev; + struct hsr_priv *hsr; + LIST_HEAD(list_kill); + int mtu_max; + int res; + + dev = netdev_notifier_info_to_dev(ptr); + port = hsr_port_get_rtnl(dev); + if (!port) { + if (!is_hsr_master(dev)) + return NOTIFY_DONE; /* Not an HSR device */ + hsr = netdev_priv(dev); + port = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + if (!port) { + /* Resend of notification concerning removed device? */ + return NOTIFY_DONE; + } + } else { + hsr = port->hsr; + } + + switch (event) { + case NETDEV_UP: /* Administrative state DOWN */ + case NETDEV_DOWN: /* Administrative state UP */ + case NETDEV_CHANGE: /* Link (carrier) state changes */ + hsr_check_carrier_and_operstate(hsr); + break; + case NETDEV_CHANGENAME: + if (is_hsr_master(dev)) + hsr_debugfs_rename(dev); + break; + case NETDEV_CHANGEADDR: + if (port->type == HSR_PT_MASTER) { + /* This should not happen since there's no + * ndo_set_mac_address() for HSR devices - i.e. not + * supported. + */ + break; + } + + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + + if (port->type == HSR_PT_SLAVE_A) { + eth_hw_addr_set(master->dev, dev->dev_addr); + call_netdevice_notifiers(NETDEV_CHANGEADDR, + master->dev); + } + + /* Make sure we recognize frames from ourselves in hsr_rcv() */ + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + res = hsr_create_self_node(hsr, + master->dev->dev_addr, + port ? + port->dev->dev_addr : + master->dev->dev_addr); + if (res) + netdev_warn(master->dev, + "Could not update HSR node address.\n"); + break; + case NETDEV_CHANGEMTU: + if (port->type == HSR_PT_MASTER) + break; /* Handled in ndo_change_mtu() */ + mtu_max = hsr_get_max_mtu(port->hsr); + master = hsr_port_get_hsr(port->hsr, HSR_PT_MASTER); + master->dev->mtu = mtu_max; + break; + case NETDEV_UNREGISTER: + if (!is_hsr_master(dev)) { + master = hsr_port_get_hsr(port->hsr, HSR_PT_MASTER); + hsr_del_port(port); + if (hsr_slave_empty(master->hsr)) { + const struct rtnl_link_ops *ops; + + ops = master->dev->rtnl_link_ops; + ops->dellink(master->dev, &list_kill); + unregister_netdevice_many(&list_kill); + } + } + break; + case NETDEV_PRE_TYPE_CHANGE: + /* HSR works only on Ethernet devices. Refuse slave to change + * its type. 
+ */ + return NOTIFY_BAD; + } + + return NOTIFY_DONE; +} + +struct hsr_port *hsr_port_get_hsr(struct hsr_priv *hsr, enum hsr_port_type pt) +{ + struct hsr_port *port; + + hsr_for_each_port(hsr, port) + if (port->type == pt) + return port; + return NULL; +} + +int hsr_get_version(struct net_device *dev, enum hsr_version *ver) +{ + struct hsr_priv *hsr; + + hsr = netdev_priv(dev); + *ver = hsr->prot_version; + + return 0; +} +EXPORT_SYMBOL(hsr_get_version); + +static struct notifier_block hsr_nb = { + .notifier_call = hsr_netdev_notify, /* Slave event notifications */ +}; + +static int __init hsr_init(void) +{ + int res; + + BUILD_BUG_ON(sizeof(struct hsr_tag) != HSR_HLEN); + + register_netdevice_notifier(&hsr_nb); + res = hsr_netlink_init(); + + return res; +} + +static void __exit hsr_exit(void) +{ + hsr_netlink_exit(); + hsr_debugfs_remove_root(); + unregister_netdevice_notifier(&hsr_nb); +} + +module_init(hsr_init); +module_exit(hsr_exit); +MODULE_LICENSE("GPL"); diff --git a/net/hsr/hsr_main.h b/net/hsr/hsr_main.h new file mode 100644 index 000000000..58a5a8b38 --- /dev/null +++ b/net/hsr/hsr_main.h @@ -0,0 +1,287 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. + */ + +#ifndef __HSR_PRIVATE_H +#define __HSR_PRIVATE_H + +#include +#include +#include +#include + +/* Time constants as specified in the HSR specification (IEC-62439-3 2010) + * Table 8. + * All values in milliseconds. + */ +#define HSR_LIFE_CHECK_INTERVAL 2000 /* ms */ +#define HSR_NODE_FORGET_TIME 60000 /* ms */ +#define HSR_ANNOUNCE_INTERVAL 100 /* ms */ +#define HSR_ENTRY_FORGET_TIME 400 /* ms */ + +/* By how much may slave1 and slave2 timestamps of latest received frame from + * each node differ before we notify of communication problem? + */ +#define MAX_SLAVE_DIFF 3000 /* ms */ +#define HSR_SEQNR_START (USHRT_MAX - 1024) +#define HSR_SUP_SEQNR_START (HSR_SEQNR_START / 2) + +/* How often shall we check for broken ring and remove node entries older than + * HSR_NODE_FORGET_TIME? + */ +#define PRUNE_PERIOD 3000 /* ms */ +#define HSR_TLV_EOT 0 /* End of TLVs */ +#define HSR_TLV_ANNOUNCE 22 +#define HSR_TLV_LIFE_CHECK 23 +/* PRP V1 life check for Duplicate discard */ +#define PRP_TLV_LIFE_CHECK_DD 20 +/* PRP V1 life check for Duplicate Accept */ +#define PRP_TLV_LIFE_CHECK_DA 21 +/* PRP V1 life redundancy box MAC address */ +#define PRP_TLV_REDBOX_MAC 30 + +#define HSR_V1_SUP_LSDUSIZE 52 + +/* The helper functions below assumes that 'path' occupies the 4 most + * significant bits of the 16-bit field shared by 'path' and 'LSDU_size' (or + * equivalently, the 4 most significant bits of HSR tag byte 14). + * + * This is unclear in the IEC specification; its definition of MAC addresses + * indicates the spec is written with the least significant bit first (to the + * left). This, however, would mean that the LSDU field would be split in two + * with the path field in-between, which seems strange. I'm guessing the MAC + * address definition is in error. 
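The 4-bit/12-bit split described in the comment above is easy to get wrong around byte order; here is a stand-alone round trip of the packing used by set_hsr_tag_path() and set_hsr_tag_LSDU_size() below (x_* names are illustrative):

#include <arpa/inet.h>
#include <assert.h>
#include <stdint.h>

/* The 16-bit field as the helpers below assume it: path id in the top four
 * bits, LSDU size in the low twelve, stored big-endian on the wire.
 */
uint16_t x_pack(uint8_t path, uint16_t lsdu_size)
{
        return htons((uint16_t)(((path & 0xF) << 12) | (lsdu_size & 0x0FFF)));
}

uint8_t x_get_path(uint16_t field)
{
        return ntohs(field) >> 12;
}

uint16_t x_get_lsdu_size(uint16_t field)
{
        return ntohs(field) & 0x0FFF;
}

int main(void)
{
        /* An arbitrary 4-bit path value and a 60-byte LSDU. */
        uint16_t field = x_pack(0xA, 60);

        assert(x_get_path(field) == 0xA);
        assert(x_get_lsdu_size(field) == 60);
        return 0;
}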
+ */ + +static inline void set_hsr_tag_path(struct hsr_tag *ht, u16 path) +{ + ht->path_and_LSDU_size = + htons((ntohs(ht->path_and_LSDU_size) & 0x0FFF) | (path << 12)); +} + +static inline void set_hsr_tag_LSDU_size(struct hsr_tag *ht, u16 LSDU_size) +{ + ht->path_and_LSDU_size = htons((ntohs(ht->path_and_LSDU_size) & + 0xF000) | (LSDU_size & 0x0FFF)); +} + +struct hsr_ethhdr { + struct ethhdr ethhdr; + struct hsr_tag hsr_tag; +} __packed; + +struct hsr_vlan_ethhdr { + struct vlan_ethhdr vlanhdr; + struct hsr_tag hsr_tag; +} __packed; + +struct hsr_sup_tlv { + u8 HSR_TLV_type; + u8 HSR_TLV_length; +} __packed; + +/* HSR/PRP Supervision Frame data types. + * Field names as defined in the IEC:2010 standard for HSR. + */ +struct hsr_sup_tag { + __be16 path_and_HSR_ver; + __be16 sequence_nr; + struct hsr_sup_tlv tlv; +} __packed; + +struct hsr_sup_payload { + unsigned char macaddress_A[ETH_ALEN]; +} __packed; + +static inline void set_hsr_stag_path(struct hsr_sup_tag *hst, u16 path) +{ + set_hsr_tag_path((struct hsr_tag *)hst, path); +} + +static inline void set_hsr_stag_HSR_ver(struct hsr_sup_tag *hst, u16 HSR_ver) +{ + set_hsr_tag_LSDU_size((struct hsr_tag *)hst, HSR_ver); +} + +struct hsrv0_ethhdr_sp { + struct ethhdr ethhdr; + struct hsr_sup_tag hsr_sup; +} __packed; + +struct hsrv1_ethhdr_sp { + struct ethhdr ethhdr; + struct hsr_tag hsr; + struct hsr_sup_tag hsr_sup; +} __packed; + +enum hsr_port_type { + HSR_PT_NONE = 0, /* Must be 0, used by framereg */ + HSR_PT_SLAVE_A, + HSR_PT_SLAVE_B, + HSR_PT_INTERLINK, + HSR_PT_MASTER, + HSR_PT_PORTS, /* This must be the last item in the enum */ +}; + +/* PRP Redundancy Control Trailer (RCT). + * As defined in IEC-62439-4:2012, the PRP RCT is really { sequence Nr, + * LAN identifier (LanId), LSDU_size and PRP_suffix = 0x88FB }. + * + * Field names as defined in the IEC:2012 standard for PRP.
+ */ +struct prp_rct { + __be16 sequence_nr; + __be16 lan_id_and_LSDU_size; + __be16 PRP_suffix; +} __packed; + +static inline u16 get_prp_LSDU_size(struct prp_rct *rct) +{ + return ntohs(rct->lan_id_and_LSDU_size) & 0x0FFF; +} + +static inline void set_prp_lan_id(struct prp_rct *rct, u16 lan_id) +{ + rct->lan_id_and_LSDU_size = htons((ntohs(rct->lan_id_and_LSDU_size) & + 0x0FFF) | (lan_id << 12)); +} +static inline void set_prp_LSDU_size(struct prp_rct *rct, u16 LSDU_size) +{ + rct->lan_id_and_LSDU_size = htons((ntohs(rct->lan_id_and_LSDU_size) & + 0xF000) | (LSDU_size & 0x0FFF)); +} + +struct hsr_port { + struct list_head port_list; + struct net_device *dev; + struct hsr_priv *hsr; + enum hsr_port_type type; +}; + +struct hsr_frame_info; +struct hsr_node; + +struct hsr_proto_ops { + /* format and send supervision frame */ + void (*send_sv_frame)(struct hsr_port *port, unsigned long *interval); + void (*handle_san_frame)(bool san, enum hsr_port_type port, + struct hsr_node *node); + bool (*drop_frame)(struct hsr_frame_info *frame, struct hsr_port *port); + struct sk_buff * (*get_untagged_frame)(struct hsr_frame_info *frame, + struct hsr_port *port); + struct sk_buff * (*create_tagged_frame)(struct hsr_frame_info *frame, + struct hsr_port *port); + int (*fill_frame_info)(__be16 proto, struct sk_buff *skb, + struct hsr_frame_info *frame); + bool (*invalid_dan_ingress_frame)(__be16 protocol); + void (*update_san_info)(struct hsr_node *node, bool is_sup); +}; + +struct hsr_priv { + struct rcu_head rcu_head; + struct list_head ports; + struct list_head node_db; /* Known HSR nodes */ + struct list_head self_node_db; /* MACs of slaves */ + struct timer_list announce_timer; /* Supervision frame dispatch */ + struct timer_list prune_timer; + int announce_count; + u16 sequence_nr; + u16 sup_sequence_nr; /* For HSRv1 separate seq_nr for supervision */ + enum hsr_version prot_version; /* Indicate if HSRv0, HSRv1 or PRPv1 */ + spinlock_t seqnr_lock; /* locking for sequence_nr */ + spinlock_t list_lock; /* locking for node list */ + struct hsr_proto_ops *proto_ops; +#define PRP_LAN_ID 0x5 /* 0x1010 for A and 0x1011 for B. 
Bit 0 is set + * based on SLAVE_A or SLAVE_B + */ + u8 net_id; /* for PRP, it occupies most significant 3 bits + * of lan_id + */ + unsigned char sup_multicast_addr[ETH_ALEN] __aligned(sizeof(u16)); + /* Align to u16 boundary to avoid unaligned access + * in ether_addr_equal + */ +#ifdef CONFIG_DEBUG_FS + struct dentry *node_tbl_root; +#endif +}; + +#define hsr_for_each_port(hsr, port) \ + list_for_each_entry_rcu((port), &(hsr)->ports, port_list) + +struct hsr_port *hsr_port_get_hsr(struct hsr_priv *hsr, enum hsr_port_type pt); + +/* Caller must ensure skb is a valid HSR frame */ +static inline u16 hsr_get_skb_sequence_nr(struct sk_buff *skb) +{ + struct hsr_ethhdr *hsr_ethhdr; + + hsr_ethhdr = (struct hsr_ethhdr *)skb_mac_header(skb); + return ntohs(hsr_ethhdr->hsr_tag.sequence_nr); +} + +static inline struct prp_rct *skb_get_PRP_rct(struct sk_buff *skb) +{ + unsigned char *tail = skb_tail_pointer(skb) - HSR_HLEN; + + struct prp_rct *rct = (struct prp_rct *)tail; + + if (rct->PRP_suffix == htons(ETH_P_PRP)) + return rct; + + return NULL; +} + +/* Assume caller has confirmed this skb is PRP suffixed */ +static inline u16 prp_get_skb_sequence_nr(struct prp_rct *rct) +{ + return ntohs(rct->sequence_nr); +} + +/* assume there is a valid rct */ +static inline bool prp_check_lsdu_size(struct sk_buff *skb, + struct prp_rct *rct, + bool is_sup) +{ + struct ethhdr *ethhdr; + int expected_lsdu_size; + + if (is_sup) { + expected_lsdu_size = HSR_V1_SUP_LSDUSIZE; + } else { + ethhdr = (struct ethhdr *)skb_mac_header(skb); + expected_lsdu_size = skb->len - 14; + if (ethhdr->h_proto == htons(ETH_P_8021Q)) + expected_lsdu_size -= 4; + } + + return (expected_lsdu_size == get_prp_LSDU_size(rct)); +} + +#if IS_ENABLED(CONFIG_DEBUG_FS) +void hsr_debugfs_rename(struct net_device *dev); +void hsr_debugfs_init(struct hsr_priv *priv, struct net_device *hsr_dev); +void hsr_debugfs_term(struct hsr_priv *priv); +void hsr_debugfs_create_root(void); +void hsr_debugfs_remove_root(void); +#else +static inline void hsr_debugfs_rename(struct net_device *dev) +{ +} +static inline void hsr_debugfs_init(struct hsr_priv *priv, + struct net_device *hsr_dev) +{} +static inline void hsr_debugfs_term(struct hsr_priv *priv) +{} +static inline void hsr_debugfs_create_root(void) +{} +static inline void hsr_debugfs_remove_root(void) +{} +#endif + +#endif /* __HSR_PRIVATE_H */ diff --git a/net/hsr/hsr_netlink.c b/net/hsr/hsr_netlink.c new file mode 100644 index 000000000..78fe40eb9 --- /dev/null +++ b/net/hsr/hsr_netlink.c @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * Routines for handling Netlink messages for HSR and PRP. + */ + +#include "hsr_netlink.h" +#include +#include +#include +#include "hsr_main.h" +#include "hsr_device.h" +#include "hsr_framereg.h" + +static const struct nla_policy hsr_policy[IFLA_HSR_MAX + 1] = { + [IFLA_HSR_SLAVE1] = { .type = NLA_U32 }, + [IFLA_HSR_SLAVE2] = { .type = NLA_U32 }, + [IFLA_HSR_MULTICAST_SPEC] = { .type = NLA_U8 }, + [IFLA_HSR_VERSION] = { .type = NLA_U8 }, + [IFLA_HSR_SUPERVISION_ADDR] = { .len = ETH_ALEN }, + [IFLA_HSR_SEQ_NR] = { .type = NLA_U16 }, + [IFLA_HSR_PROTOCOL] = { .type = NLA_U8 }, +}; + +/* Here, it seems a netdevice has already been allocated for us, and the + * hsr_dev_setup routine has been executed. Nice! 
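For reference, the RCT handling done by skb_get_PRP_rct() and prp_check_lsdu_size() above can be reproduced in a few lines of user-space C: read the last six bytes of the frame, verify the 0x88FB suffix, then split the lan_id/LSDU_size word. The x_* names are illustrative, and the VLAN and supervision-frame special cases are omitted.

#include <arpa/inet.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define X_RCT_LEN    6
#define X_ETH_P_PRP  0x88FB
#define X_ETH_HLEN   14

struct x_prp_rct {
        uint16_t sequence_nr;
        uint16_t lan_id_and_LSDU_size;
        uint16_t prp_suffix;
} __attribute__((packed));

/* Parse the PRP Redundancy Control Trailer from the last six bytes of a
 * frame: a user-space analogue of skb_get_PRP_rct() plus a simplified
 * prp_check_lsdu_size().
 */
bool x_parse_rct(const uint8_t *frame, size_t frame_len,
                 uint16_t *seq, uint8_t *lan_id, uint16_t *lsdu_size)
{
        struct x_prp_rct rct;

        if (frame_len < X_ETH_HLEN + X_RCT_LEN)
                return false;
        memcpy(&rct, frame + frame_len - X_RCT_LEN, sizeof(rct));

        if (ntohs(rct.prp_suffix) != X_ETH_P_PRP)
                return false;           /* no RCT: frame from a SAN */

        *seq = ntohs(rct.sequence_nr);
        *lan_id = ntohs(rct.lan_id_and_LSDU_size) >> 12;
        *lsdu_size = ntohs(rct.lan_id_and_LSDU_size) & 0x0FFF;

        /* The LSDU covers everything after the Ethernet header, RCT included. */
        return *lsdu_size == frame_len - X_ETH_HLEN;
}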
+ */ +static int hsr_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + enum hsr_version proto_version; + unsigned char multicast_spec; + u8 proto = HSR_PROTOCOL_HSR; + struct net_device *link[2]; + + if (!data) { + NL_SET_ERR_MSG_MOD(extack, "No slave devices specified"); + return -EINVAL; + } + if (!data[IFLA_HSR_SLAVE1]) { + NL_SET_ERR_MSG_MOD(extack, "Slave1 device not specified"); + return -EINVAL; + } + link[0] = __dev_get_by_index(src_net, + nla_get_u32(data[IFLA_HSR_SLAVE1])); + if (!link[0]) { + NL_SET_ERR_MSG_MOD(extack, "Slave1 does not exist"); + return -EINVAL; + } + if (!data[IFLA_HSR_SLAVE2]) { + NL_SET_ERR_MSG_MOD(extack, "Slave2 device not specified"); + return -EINVAL; + } + link[1] = __dev_get_by_index(src_net, + nla_get_u32(data[IFLA_HSR_SLAVE2])); + if (!link[1]) { + NL_SET_ERR_MSG_MOD(extack, "Slave2 does not exist"); + return -EINVAL; + } + + if (link[0] == link[1]) { + NL_SET_ERR_MSG_MOD(extack, "Slave1 and Slave2 are same"); + return -EINVAL; + } + + if (!data[IFLA_HSR_MULTICAST_SPEC]) + multicast_spec = 0; + else + multicast_spec = nla_get_u8(data[IFLA_HSR_MULTICAST_SPEC]); + + if (data[IFLA_HSR_PROTOCOL]) + proto = nla_get_u8(data[IFLA_HSR_PROTOCOL]); + + if (proto >= HSR_PROTOCOL_MAX) { + NL_SET_ERR_MSG_MOD(extack, "Unsupported protocol"); + return -EINVAL; + } + + if (!data[IFLA_HSR_VERSION]) { + proto_version = HSR_V0; + } else { + if (proto == HSR_PROTOCOL_PRP) { + NL_SET_ERR_MSG_MOD(extack, "PRP version unsupported"); + return -EINVAL; + } + + proto_version = nla_get_u8(data[IFLA_HSR_VERSION]); + if (proto_version > HSR_V1) { + NL_SET_ERR_MSG_MOD(extack, + "Only HSR version 0/1 supported"); + return -EINVAL; + } + } + + if (proto == HSR_PROTOCOL_PRP) + proto_version = PRP_V1; + + return hsr_dev_finalize(dev, link, multicast_spec, proto_version, extack); +} + +static void hsr_dellink(struct net_device *dev, struct list_head *head) +{ + struct hsr_priv *hsr = netdev_priv(dev); + + del_timer_sync(&hsr->prune_timer); + del_timer_sync(&hsr->announce_timer); + + hsr_debugfs_term(hsr); + hsr_del_ports(hsr); + + hsr_del_self_node(hsr); + hsr_del_nodes(&hsr->node_db); + + unregister_netdevice_queue(dev, head); +} + +static int hsr_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct hsr_priv *hsr = netdev_priv(dev); + u8 proto = HSR_PROTOCOL_HSR; + struct hsr_port *port; + + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_A); + if (port) { + if (nla_put_u32(skb, IFLA_HSR_SLAVE1, port->dev->ifindex)) + goto nla_put_failure; + } + + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + if (port) { + if (nla_put_u32(skb, IFLA_HSR_SLAVE2, port->dev->ifindex)) + goto nla_put_failure; + } + + if (nla_put(skb, IFLA_HSR_SUPERVISION_ADDR, ETH_ALEN, + hsr->sup_multicast_addr) || + nla_put_u16(skb, IFLA_HSR_SEQ_NR, hsr->sequence_nr)) + goto nla_put_failure; + if (hsr->prot_version == PRP_V1) + proto = HSR_PROTOCOL_PRP; + if (nla_put_u8(skb, IFLA_HSR_PROTOCOL, proto)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static struct rtnl_link_ops hsr_link_ops __read_mostly = { + .kind = "hsr", + .maxtype = IFLA_HSR_MAX, + .policy = hsr_policy, + .priv_size = sizeof(struct hsr_priv), + .setup = hsr_dev_setup, + .newlink = hsr_newlink, + .dellink = hsr_dellink, + .fill_info = hsr_fill_info, +}; + +/* attribute policy */ +static const struct nla_policy hsr_genl_policy[HSR_A_MAX + 1] = { + [HSR_A_NODE_ADDR] = { .len = ETH_ALEN }, + [HSR_A_NODE_ADDR_B] = { 
.len = ETH_ALEN }, + [HSR_A_IFINDEX] = { .type = NLA_U32 }, + [HSR_A_IF1_AGE] = { .type = NLA_U32 }, + [HSR_A_IF2_AGE] = { .type = NLA_U32 }, + [HSR_A_IF1_SEQ] = { .type = NLA_U16 }, + [HSR_A_IF2_SEQ] = { .type = NLA_U16 }, +}; + +static struct genl_family hsr_genl_family; + +static const struct genl_multicast_group hsr_mcgrps[] = { + { .name = "hsr-network", }, +}; + +/* This is called if for some node with MAC address addr, we only get frames + * over one of the slave interfaces. This would indicate an open network ring + * (i.e. a link has failed somewhere). + */ +void hsr_nl_ringerror(struct hsr_priv *hsr, unsigned char addr[ETH_ALEN], + struct hsr_port *port) +{ + struct sk_buff *skb; + void *msg_head; + struct hsr_port *master; + int res; + + skb = genlmsg_new(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb) + goto fail; + + msg_head = genlmsg_put(skb, 0, 0, &hsr_genl_family, 0, + HSR_C_RING_ERROR); + if (!msg_head) + goto nla_put_failure; + + res = nla_put(skb, HSR_A_NODE_ADDR, ETH_ALEN, addr); + if (res < 0) + goto nla_put_failure; + + res = nla_put_u32(skb, HSR_A_IFINDEX, port->dev->ifindex); + if (res < 0) + goto nla_put_failure; + + genlmsg_end(skb, msg_head); + genlmsg_multicast(&hsr_genl_family, skb, 0, 0, GFP_ATOMIC); + + return; + +nla_put_failure: + kfree_skb(skb); + +fail: + rcu_read_lock(); + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + netdev_warn(master->dev, "Could not send HSR ring error message\n"); + rcu_read_unlock(); +} + +/* This is called when we haven't heard from the node with MAC address addr for + * some time (just before the node is removed from the node table/list). + */ +void hsr_nl_nodedown(struct hsr_priv *hsr, unsigned char addr[ETH_ALEN]) +{ + struct sk_buff *skb; + void *msg_head; + struct hsr_port *master; + int res; + + skb = genlmsg_new(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb) + goto fail; + + msg_head = genlmsg_put(skb, 0, 0, &hsr_genl_family, 0, HSR_C_NODE_DOWN); + if (!msg_head) + goto nla_put_failure; + + res = nla_put(skb, HSR_A_NODE_ADDR, ETH_ALEN, addr); + if (res < 0) + goto nla_put_failure; + + genlmsg_end(skb, msg_head); + genlmsg_multicast(&hsr_genl_family, skb, 0, 0, GFP_ATOMIC); + + return; + +nla_put_failure: + kfree_skb(skb); + +fail: + rcu_read_lock(); + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + netdev_warn(master->dev, "Could not send HSR node down\n"); + rcu_read_unlock(); +} + +/* HSR_C_GET_NODE_STATUS lets userspace query the internal HSR node table + * about the status of a specific node in the network, defined by its MAC + * address. 
+ * + * Input: hsr ifindex, node mac address + * Output: hsr ifindex, node mac address (copied from request), + * age of latest frame from node over slave 1, slave 2 [ms] + */ +static int hsr_get_node_status(struct sk_buff *skb_in, struct genl_info *info) +{ + /* For receiving */ + struct nlattr *na; + struct net_device *hsr_dev; + + /* For sending */ + struct sk_buff *skb_out; + void *msg_head; + struct hsr_priv *hsr; + struct hsr_port *port; + unsigned char hsr_node_addr_b[ETH_ALEN]; + int hsr_node_if1_age; + u16 hsr_node_if1_seq; + int hsr_node_if2_age; + u16 hsr_node_if2_seq; + int addr_b_ifindex; + int res; + + if (!info) + goto invalid; + + na = info->attrs[HSR_A_IFINDEX]; + if (!na) + goto invalid; + na = info->attrs[HSR_A_NODE_ADDR]; + if (!na) + goto invalid; + + rcu_read_lock(); + hsr_dev = dev_get_by_index_rcu(genl_info_net(info), + nla_get_u32(info->attrs[HSR_A_IFINDEX])); + if (!hsr_dev) + goto rcu_unlock; + if (!is_hsr_master(hsr_dev)) + goto rcu_unlock; + + /* Send reply */ + skb_out = genlmsg_new(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb_out) { + res = -ENOMEM; + goto fail; + } + + msg_head = genlmsg_put(skb_out, NETLINK_CB(skb_in).portid, + info->snd_seq, &hsr_genl_family, 0, + HSR_C_SET_NODE_STATUS); + if (!msg_head) { + res = -ENOMEM; + goto nla_put_failure; + } + + res = nla_put_u32(skb_out, HSR_A_IFINDEX, hsr_dev->ifindex); + if (res < 0) + goto nla_put_failure; + + hsr = netdev_priv(hsr_dev); + res = hsr_get_node_data(hsr, + (unsigned char *) + nla_data(info->attrs[HSR_A_NODE_ADDR]), + hsr_node_addr_b, + &addr_b_ifindex, + &hsr_node_if1_age, + &hsr_node_if1_seq, + &hsr_node_if2_age, + &hsr_node_if2_seq); + if (res < 0) + goto nla_put_failure; + + res = nla_put(skb_out, HSR_A_NODE_ADDR, ETH_ALEN, + nla_data(info->attrs[HSR_A_NODE_ADDR])); + if (res < 0) + goto nla_put_failure; + + if (addr_b_ifindex > -1) { + res = nla_put(skb_out, HSR_A_NODE_ADDR_B, ETH_ALEN, + hsr_node_addr_b); + if (res < 0) + goto nla_put_failure; + + res = nla_put_u32(skb_out, HSR_A_ADDR_B_IFINDEX, + addr_b_ifindex); + if (res < 0) + goto nla_put_failure; + } + + res = nla_put_u32(skb_out, HSR_A_IF1_AGE, hsr_node_if1_age); + if (res < 0) + goto nla_put_failure; + res = nla_put_u16(skb_out, HSR_A_IF1_SEQ, hsr_node_if1_seq); + if (res < 0) + goto nla_put_failure; + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_A); + if (port) + res = nla_put_u32(skb_out, HSR_A_IF1_IFINDEX, + port->dev->ifindex); + if (res < 0) + goto nla_put_failure; + + res = nla_put_u32(skb_out, HSR_A_IF2_AGE, hsr_node_if2_age); + if (res < 0) + goto nla_put_failure; + res = nla_put_u16(skb_out, HSR_A_IF2_SEQ, hsr_node_if2_seq); + if (res < 0) + goto nla_put_failure; + port = hsr_port_get_hsr(hsr, HSR_PT_SLAVE_B); + if (port) + res = nla_put_u32(skb_out, HSR_A_IF2_IFINDEX, + port->dev->ifindex); + if (res < 0) + goto nla_put_failure; + + rcu_read_unlock(); + + genlmsg_end(skb_out, msg_head); + genlmsg_unicast(genl_info_net(info), skb_out, info->snd_portid); + + return 0; + +rcu_unlock: + rcu_read_unlock(); +invalid: + netlink_ack(skb_in, nlmsg_hdr(skb_in), -EINVAL, NULL); + return 0; + +nla_put_failure: + kfree_skb(skb_out); + /* Fall through */ + +fail: + rcu_read_unlock(); + return res; +} + +/* Get a list of MacAddressA of all nodes known to this node (including self). 
+ */ +static int hsr_get_node_list(struct sk_buff *skb_in, struct genl_info *info) +{ + unsigned char addr[ETH_ALEN]; + struct net_device *hsr_dev; + struct sk_buff *skb_out; + struct hsr_priv *hsr; + bool restart = false; + struct nlattr *na; + void *pos = NULL; + void *msg_head; + int res; + + if (!info) + goto invalid; + + na = info->attrs[HSR_A_IFINDEX]; + if (!na) + goto invalid; + + rcu_read_lock(); + hsr_dev = dev_get_by_index_rcu(genl_info_net(info), + nla_get_u32(info->attrs[HSR_A_IFINDEX])); + if (!hsr_dev) + goto rcu_unlock; + if (!is_hsr_master(hsr_dev)) + goto rcu_unlock; + +restart: + /* Send reply */ + skb_out = genlmsg_new(GENLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb_out) { + res = -ENOMEM; + goto fail; + } + + msg_head = genlmsg_put(skb_out, NETLINK_CB(skb_in).portid, + info->snd_seq, &hsr_genl_family, 0, + HSR_C_SET_NODE_LIST); + if (!msg_head) { + res = -ENOMEM; + goto nla_put_failure; + } + + if (!restart) { + res = nla_put_u32(skb_out, HSR_A_IFINDEX, hsr_dev->ifindex); + if (res < 0) + goto nla_put_failure; + } + + hsr = netdev_priv(hsr_dev); + + if (!pos) + pos = hsr_get_next_node(hsr, NULL, addr); + while (pos) { + res = nla_put(skb_out, HSR_A_NODE_ADDR, ETH_ALEN, addr); + if (res < 0) { + if (res == -EMSGSIZE) { + genlmsg_end(skb_out, msg_head); + genlmsg_unicast(genl_info_net(info), skb_out, + info->snd_portid); + restart = true; + goto restart; + } + goto nla_put_failure; + } + pos = hsr_get_next_node(hsr, pos, addr); + } + rcu_read_unlock(); + + genlmsg_end(skb_out, msg_head); + genlmsg_unicast(genl_info_net(info), skb_out, info->snd_portid); + + return 0; + +rcu_unlock: + rcu_read_unlock(); +invalid: + netlink_ack(skb_in, nlmsg_hdr(skb_in), -EINVAL, NULL); + return 0; + +nla_put_failure: + nlmsg_free(skb_out); + /* Fall through */ + +fail: + rcu_read_unlock(); + return res; +} + +static const struct genl_small_ops hsr_ops[] = { + { + .cmd = HSR_C_GET_NODE_STATUS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = hsr_get_node_status, + .dumpit = NULL, + }, + { + .cmd = HSR_C_GET_NODE_LIST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = hsr_get_node_list, + .dumpit = NULL, + }, +}; + +static struct genl_family hsr_genl_family __ro_after_init = { + .hdrsize = 0, + .name = "HSR", + .version = 1, + .maxattr = HSR_A_MAX, + .policy = hsr_genl_policy, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = hsr_ops, + .n_small_ops = ARRAY_SIZE(hsr_ops), + .resv_start_op = HSR_C_SET_NODE_LIST + 1, + .mcgrps = hsr_mcgrps, + .n_mcgrps = ARRAY_SIZE(hsr_mcgrps), +}; + +int __init hsr_netlink_init(void) +{ + int rc; + + rc = rtnl_link_register(&hsr_link_ops); + if (rc) + goto fail_rtnl_link_register; + + rc = genl_register_family(&hsr_genl_family); + if (rc) + goto fail_genl_register_family; + + hsr_debugfs_create_root(); + return 0; + +fail_genl_register_family: + rtnl_link_unregister(&hsr_link_ops); +fail_rtnl_link_register: + + return rc; +} + +void __exit hsr_netlink_exit(void) +{ + genl_unregister_family(&hsr_genl_family); + rtnl_link_unregister(&hsr_link_ops); +} + +MODULE_ALIAS_RTNL_LINK("hsr"); diff --git a/net/hsr/hsr_netlink.h b/net/hsr/hsr_netlink.h new file mode 100644 index 000000000..501552d97 --- /dev/null +++ b/net/hsr/hsr_netlink.h @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. 
+ */ + +#ifndef __HSR_NETLINK_H +#define __HSR_NETLINK_H + +#include +#include +#include + +struct hsr_priv; +struct hsr_port; + +int __init hsr_netlink_init(void); +void __exit hsr_netlink_exit(void); + +void hsr_nl_ringerror(struct hsr_priv *hsr, unsigned char addr[ETH_ALEN], + struct hsr_port *port); +void hsr_nl_nodedown(struct hsr_priv *hsr, unsigned char addr[ETH_ALEN]); +void hsr_nl_framedrop(int dropcount, int dev_idx); +void hsr_nl_linkdown(int dev_idx); + +#endif /* __HSR_NETLINK_H */ diff --git a/net/hsr/hsr_slave.c b/net/hsr/hsr_slave.c new file mode 100644 index 000000000..b70e6bbf6 --- /dev/null +++ b/net/hsr/hsr_slave.c @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * Author(s): + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * Frame handler and other utility functions for HSR and PRP. + */ + +#include "hsr_slave.h" +#include +#include +#include +#include "hsr_main.h" +#include "hsr_device.h" +#include "hsr_forward.h" +#include "hsr_framereg.h" + +bool hsr_invalid_dan_ingress_frame(__be16 protocol) +{ + return (protocol != htons(ETH_P_PRP) && protocol != htons(ETH_P_HSR)); +} + +static rx_handler_result_t hsr_handle_frame(struct sk_buff **pskb) +{ + struct sk_buff *skb = *pskb; + struct hsr_port *port; + struct hsr_priv *hsr; + __be16 protocol; + + /* Packets from dev_loopback_xmit() do not have an L2 header, bail out */ + if (unlikely(skb->pkt_type == PACKET_LOOPBACK)) + return RX_HANDLER_PASS; + + if (!skb_mac_header_was_set(skb)) { + WARN_ONCE(1, "%s: skb invalid", __func__); + return RX_HANDLER_PASS; + } + + port = hsr_port_get_rcu(skb->dev); + if (!port) + goto finish_pass; + hsr = port->hsr; + + if (hsr_addr_is_self(port->hsr, eth_hdr(skb)->h_source)) { + /* Directly kill frames sent by ourselves */ + kfree_skb(skb); + goto finish_consume; + } + + /* For HSR, only tagged frames are expected (unless the device offloads + * HSR tag removal), but for PRP there could be non-tagged frames as + * well from singly attached nodes (SANs).
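The ingress decision that follows this comment boils down to one predicate plus two feature tests. A simplified restatement is given below (x_* names are illustrative; which protocol actually registers the invalid-DAN hook is decided in hsr_device.c, outside this hunk).

#include <stdbool.h>
#include <stdint.h>

#define X_ETH_P_HSR  0x892F
#define X_ETH_P_PRP  0x88FB

/* Same predicate as hsr_invalid_dan_ingress_frame() above: anything that is
 * neither HSR-tagged nor PRP/supervision is not expected from a DAN.
 */
static bool x_invalid_dan_ingress(uint16_t ethertype)
{
        return ethertype != X_ETH_P_PRP && ethertype != X_ETH_P_HSR;
}

enum x_rx_verdict { X_RX_PASS, X_RX_FORWARD };

/* Ingress decision sketch: when the hardware already strips tags
 * (NETIF_F_HW_HSR_TAG_RM) every frame is handed to the forwarding code;
 * otherwise, if the protocol has registered the invalid-DAN check and the
 * EtherType is unexpected, the frame is passed up the stack unmodified.
 */
enum x_rx_verdict x_ingress_verdict(uint16_t ethertype, bool hw_tag_rm,
                                    bool have_invalid_dan_check)
{
        if (!hw_tag_rm && have_invalid_dan_check &&
            x_invalid_dan_ingress(ethertype))
                return X_RX_PASS;
        return X_RX_FORWARD;
}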
+ */ + protocol = eth_hdr(skb)->h_proto; + + if (!(port->dev->features & NETIF_F_HW_HSR_TAG_RM) && + hsr->proto_ops->invalid_dan_ingress_frame && + hsr->proto_ops->invalid_dan_ingress_frame(protocol)) + goto finish_pass; + + skb_push(skb, ETH_HLEN); + skb_reset_mac_header(skb); + if ((!hsr->prot_version && protocol == htons(ETH_P_PRP)) || + protocol == htons(ETH_P_HSR)) + skb_set_network_header(skb, ETH_HLEN + HSR_HLEN); + skb_reset_mac_len(skb); + + hsr_forward_skb(skb, port); + +finish_consume: + return RX_HANDLER_CONSUMED; + +finish_pass: + return RX_HANDLER_PASS; +} + +bool hsr_port_exists(const struct net_device *dev) +{ + return rcu_access_pointer(dev->rx_handler) == hsr_handle_frame; +} + +static int hsr_check_dev_ok(struct net_device *dev, + struct netlink_ext_ack *extack) +{ + /* Don't allow HSR on non-ethernet like devices */ + if ((dev->flags & IFF_LOOPBACK) || dev->type != ARPHRD_ETHER || + dev->addr_len != ETH_ALEN) { + NL_SET_ERR_MSG_MOD(extack, "Cannot use loopback or non-ethernet device as HSR slave."); + return -EINVAL; + } + + /* Don't allow enslaving hsr devices */ + if (is_hsr_master(dev)) { + NL_SET_ERR_MSG_MOD(extack, + "Cannot create trees of HSR devices."); + return -EINVAL; + } + + if (hsr_port_exists(dev)) { + NL_SET_ERR_MSG_MOD(extack, + "This device is already a HSR slave."); + return -EINVAL; + } + + if (is_vlan_dev(dev)) { + NL_SET_ERR_MSG_MOD(extack, "HSR on top of VLAN is not yet supported in this driver."); + return -EINVAL; + } + + if (dev->priv_flags & IFF_DONT_BRIDGE) { + NL_SET_ERR_MSG_MOD(extack, + "This device does not support bridging."); + return -EOPNOTSUPP; + } + + /* HSR over bonded devices has not been tested, but I'm not sure it + * won't work... + */ + + return 0; +} + +/* Setup device to be added to the HSR bridge. 
*/ +static int hsr_portdev_setup(struct hsr_priv *hsr, struct net_device *dev, + struct hsr_port *port, + struct netlink_ext_ack *extack) + +{ + struct net_device *hsr_dev; + struct hsr_port *master; + int res; + + res = dev_set_promiscuity(dev, 1); + if (res) + return res; + + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + hsr_dev = master->dev; + + res = netdev_upper_dev_link(dev, hsr_dev, extack); + if (res) + goto fail_upper_dev_link; + + res = netdev_rx_handler_register(dev, hsr_handle_frame, port); + if (res) + goto fail_rx_handler; + dev_disable_lro(dev); + + return 0; + +fail_rx_handler: + netdev_upper_dev_unlink(dev, hsr_dev); +fail_upper_dev_link: + dev_set_promiscuity(dev, -1); + return res; +} + +int hsr_add_port(struct hsr_priv *hsr, struct net_device *dev, + enum hsr_port_type type, struct netlink_ext_ack *extack) +{ + struct hsr_port *port, *master; + int res; + + if (type != HSR_PT_MASTER) { + res = hsr_check_dev_ok(dev, extack); + if (res) + return res; + } + + port = hsr_port_get_hsr(hsr, type); + if (port) + return -EBUSY; /* This port already exists */ + + port = kzalloc(sizeof(*port), GFP_KERNEL); + if (!port) + return -ENOMEM; + + port->hsr = hsr; + port->dev = dev; + port->type = type; + + if (type != HSR_PT_MASTER) { + res = hsr_portdev_setup(hsr, dev, port, extack); + if (res) + goto fail_dev_setup; + } + + list_add_tail_rcu(&port->port_list, &hsr->ports); + synchronize_rcu(); + + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + netdev_update_features(master->dev); + dev_set_mtu(master->dev, hsr_get_max_mtu(hsr)); + + return 0; + +fail_dev_setup: + kfree(port); + return res; +} + +void hsr_del_port(struct hsr_port *port) +{ + struct hsr_priv *hsr; + struct hsr_port *master; + + hsr = port->hsr; + master = hsr_port_get_hsr(hsr, HSR_PT_MASTER); + list_del_rcu(&port->port_list); + + if (port != master) { + netdev_update_features(master->dev); + dev_set_mtu(master->dev, hsr_get_max_mtu(hsr)); + netdev_rx_handler_unregister(port->dev); + dev_set_promiscuity(port->dev, -1); + netdev_upper_dev_unlink(port->dev, master->dev); + } + + synchronize_rcu(); + + kfree(port); +} diff --git a/net/hsr/hsr_slave.h b/net/hsr/hsr_slave.h new file mode 100644 index 000000000..edc4612bb --- /dev/null +++ b/net/hsr/hsr_slave.h @@ -0,0 +1,37 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright 2011-2014 Autronica Fire and Security AS + * + * 2011-2014 Arvid Brodin, arvid.brodin@alten.se + * + * include file for HSR and PRP. + */ + +#ifndef __HSR_SLAVE_H +#define __HSR_SLAVE_H + +#include +#include +#include +#include "hsr_main.h" + +int hsr_add_port(struct hsr_priv *hsr, struct net_device *dev, + enum hsr_port_type pt, struct netlink_ext_ack *extack); +void hsr_del_port(struct hsr_port *port); +bool hsr_port_exists(const struct net_device *dev); + +static inline struct hsr_port *hsr_port_get_rtnl(const struct net_device *dev) +{ + ASSERT_RTNL(); + return hsr_port_exists(dev) ? + rtnl_dereference(dev->rx_handler_data) : NULL; +} + +static inline struct hsr_port *hsr_port_get_rcu(const struct net_device *dev) +{ + return hsr_port_exists(dev) ? 
+ rcu_dereference(dev->rx_handler_data) : NULL; +} + +bool hsr_invalid_dan_ingress_frame(__be16 protocol); + +#endif /* __HSR_SLAVE_H */ diff --git a/net/ieee802154/6lowpan/6lowpan_i.h b/net/ieee802154/6lowpan/6lowpan_i.h new file mode 100644 index 000000000..44a7e16bf --- /dev/null +++ b/net/ieee802154/6lowpan/6lowpan_i.h @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __IEEE802154_6LOWPAN_I_H__ +#define __IEEE802154_6LOWPAN_I_H__ + +#include + +#include +#include +#include + +typedef unsigned __bitwise lowpan_rx_result; +#define RX_CONTINUE ((__force lowpan_rx_result) 0u) +#define RX_DROP_UNUSABLE ((__force lowpan_rx_result) 1u) +#define RX_DROP ((__force lowpan_rx_result) 2u) +#define RX_QUEUED ((__force lowpan_rx_result) 3u) + +#define LOWPAN_DISPATCH_FRAG1 0xc0 +#define LOWPAN_DISPATCH_FRAGN 0xe0 + +struct frag_lowpan_compare_key { + u16 tag; + u16 d_size; + struct ieee802154_addr src; + struct ieee802154_addr dst; +}; + +/* Equivalent of ipv4 struct ipq + */ +struct lowpan_frag_queue { + struct inet_frag_queue q; +}; + +int lowpan_frag_rcv(struct sk_buff *skb, const u8 frag_type); +void lowpan_net_frag_exit(void); +int lowpan_net_frag_init(void); + +void lowpan_rx_init(void); +void lowpan_rx_exit(void); + +int lowpan_header_create(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *_daddr, + const void *_saddr, unsigned int len); +netdev_tx_t lowpan_xmit(struct sk_buff *skb, struct net_device *dev); + +int lowpan_iphc_decompress(struct sk_buff *skb); +lowpan_rx_result lowpan_rx_h_ipv6(struct sk_buff *skb); + +#endif /* __IEEE802154_6LOWPAN_I_H__ */ diff --git a/net/ieee802154/6lowpan/Kconfig b/net/ieee802154/6lowpan/Kconfig new file mode 100644 index 000000000..e808e4db2 --- /dev/null +++ b/net/ieee802154/6lowpan/Kconfig @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +config IEEE802154_6LOWPAN + tristate "6lowpan support over IEEE 802.15.4" + depends on 6LOWPAN + help + IPv6 compression over IEEE 802.15.4. diff --git a/net/ieee802154/6lowpan/Makefile b/net/ieee802154/6lowpan/Makefile new file mode 100644 index 000000000..f11d6376a --- /dev/null +++ b/net/ieee802154/6lowpan/Makefile @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_IEEE802154_6LOWPAN) += ieee802154_6lowpan.o + +ieee802154_6lowpan-y := core.o rx.o reassembly.o tx.o diff --git a/net/ieee802154/6lowpan/core.c b/net/ieee802154/6lowpan/core.c new file mode 100644 index 000000000..2c087b7f1 --- /dev/null +++ b/net/ieee802154/6lowpan/core.c @@ -0,0 +1,284 @@ +/* Copyright 2011, Siemens AG + * written by Alexander Smirnov + */ + +/* Based on patches from Jon Smirl + * Copyright (c) 2011 Jon Smirl + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +/* Jon's code is based on 6lowpan implementation for Contiki which is: + * Copyright (c) 2008, Swedish Institute of Computer Science. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. 
Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the Institute nor the names of its contributors + * may be used to endorse or promote products derived from this software + * without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include +#include +#include +#include + +#include + +#include "6lowpan_i.h" + +static int open_count; + +static const struct header_ops lowpan_header_ops = { + .create = lowpan_header_create, +}; + +static int lowpan_dev_init(struct net_device *ldev) +{ + netdev_lockdep_set_classes(ldev); + + return 0; +} + +static int lowpan_open(struct net_device *dev) +{ + if (!open_count) + lowpan_rx_init(); + open_count++; + return 0; +} + +static int lowpan_stop(struct net_device *dev) +{ + open_count--; + if (!open_count) + lowpan_rx_exit(); + return 0; +} + +static int lowpan_neigh_construct(struct net_device *dev, struct neighbour *n) +{ + struct lowpan_802154_neigh *neigh = lowpan_802154_neigh(neighbour_priv(n)); + + /* default no short_addr is available for a neighbour */ + neigh->short_addr = cpu_to_le16(IEEE802154_ADDR_SHORT_UNSPEC); + return 0; +} + +static int lowpan_get_iflink(const struct net_device *dev) +{ + return lowpan_802154_dev(dev)->wdev->ifindex; +} + +static const struct net_device_ops lowpan_netdev_ops = { + .ndo_init = lowpan_dev_init, + .ndo_start_xmit = lowpan_xmit, + .ndo_open = lowpan_open, + .ndo_stop = lowpan_stop, + .ndo_neigh_construct = lowpan_neigh_construct, + .ndo_get_iflink = lowpan_get_iflink, +}; + +static void lowpan_setup(struct net_device *ldev) +{ + memset(ldev->broadcast, 0xff, IEEE802154_ADDR_LEN); + /* We need an ipv6hdr as minimum len when calling xmit */ + ldev->hard_header_len = sizeof(struct ipv6hdr); + ldev->flags = IFF_BROADCAST | IFF_MULTICAST; + ldev->priv_flags |= IFF_NO_QUEUE; + + ldev->netdev_ops = &lowpan_netdev_ops; + ldev->header_ops = &lowpan_header_ops; + ldev->needs_free_netdev = true; + ldev->features |= NETIF_F_NETNS_LOCAL; +} + +static int lowpan_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + if (tb[IFLA_ADDRESS]) { + if (nla_len(tb[IFLA_ADDRESS]) != IEEE802154_ADDR_LEN) + return -EINVAL; + } + return 0; +} + +static int lowpan_newlink(struct net *src_net, struct net_device *ldev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net_device *wdev; + int ret; + + ASSERT_RTNL(); + + pr_debug("adding new link\n"); + + if (!tb[IFLA_LINK]) 
+ return -EINVAL; + /* find and hold wpan device */ + wdev = dev_get_by_index(dev_net(ldev), nla_get_u32(tb[IFLA_LINK])); + if (!wdev) + return -ENODEV; + if (wdev->type != ARPHRD_IEEE802154) { + dev_put(wdev); + return -EINVAL; + } + + if (wdev->ieee802154_ptr->lowpan_dev) { + dev_put(wdev); + return -EBUSY; + } + + lowpan_802154_dev(ldev)->wdev = wdev; + /* Set the lowpan hardware address to the wpan hardware address. */ + __dev_addr_set(ldev, wdev->dev_addr, IEEE802154_ADDR_LEN); + /* We need headroom for possible wpan_dev_hard_header call and tailroom + * for encryption/fcs handling. The lowpan interface will replace + * the IPv6 header with 6LoWPAN header. At worst case the 6LoWPAN + * header has LOWPAN_IPHC_MAX_HEADER_LEN more bytes than the IPv6 + * header. + */ + ldev->needed_headroom = LOWPAN_IPHC_MAX_HEADER_LEN + + wdev->needed_headroom; + ldev->needed_tailroom = wdev->needed_tailroom; + + ldev->neigh_priv_len = sizeof(struct lowpan_802154_neigh); + + ret = lowpan_register_netdevice(ldev, LOWPAN_LLTYPE_IEEE802154); + if (ret < 0) { + dev_put(wdev); + return ret; + } + + wdev->ieee802154_ptr->lowpan_dev = ldev; + return 0; +} + +static void lowpan_dellink(struct net_device *ldev, struct list_head *head) +{ + struct net_device *wdev = lowpan_802154_dev(ldev)->wdev; + + ASSERT_RTNL(); + + wdev->ieee802154_ptr->lowpan_dev = NULL; + lowpan_unregister_netdevice(ldev); + dev_put(wdev); +} + +static struct rtnl_link_ops lowpan_link_ops __read_mostly = { + .kind = "lowpan", + .priv_size = LOWPAN_PRIV_SIZE(sizeof(struct lowpan_802154_dev)), + .setup = lowpan_setup, + .newlink = lowpan_newlink, + .dellink = lowpan_dellink, + .validate = lowpan_validate, +}; + +static inline int __init lowpan_netlink_init(void) +{ + return rtnl_link_register(&lowpan_link_ops); +} + +static inline void lowpan_netlink_fini(void) +{ + rtnl_link_unregister(&lowpan_link_ops); +} + +static int lowpan_device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *ndev = netdev_notifier_info_to_dev(ptr); + struct wpan_dev *wpan_dev; + + if (ndev->type != ARPHRD_IEEE802154) + return NOTIFY_DONE; + wpan_dev = ndev->ieee802154_ptr; + if (!wpan_dev) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UNREGISTER: + /* Check if wpan interface is unregistered that we + * also delete possible lowpan interfaces which belongs + * to the wpan interface. 
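+ * lowpan_dellink() clears wpan_dev->lowpan_dev, unregisters the
+ * 6LoWPAN interface and drops the wpan device reference taken in
+ * lowpan_newlink().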
+ */ + if (wpan_dev->lowpan_dev) + lowpan_dellink(wpan_dev->lowpan_dev, NULL); + break; + default: + return NOTIFY_DONE; + } + + return NOTIFY_OK; +} + +static struct notifier_block lowpan_dev_notifier = { + .notifier_call = lowpan_device_event, +}; + +static int __init lowpan_init_module(void) +{ + int err = 0; + + err = lowpan_net_frag_init(); + if (err < 0) + goto out; + + err = lowpan_netlink_init(); + if (err < 0) + goto out_frag; + + err = register_netdevice_notifier(&lowpan_dev_notifier); + if (err < 0) + goto out_pack; + + return 0; + +out_pack: + lowpan_netlink_fini(); +out_frag: + lowpan_net_frag_exit(); +out: + return err; +} + +static void __exit lowpan_cleanup_module(void) +{ + lowpan_netlink_fini(); + + lowpan_net_frag_exit(); + + unregister_netdevice_notifier(&lowpan_dev_notifier); +} + +module_init(lowpan_init_module); +module_exit(lowpan_cleanup_module); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("lowpan"); diff --git a/net/ieee802154/6lowpan/reassembly.c b/net/ieee802154/6lowpan/reassembly.c new file mode 100644 index 000000000..a91283d1e --- /dev/null +++ b/net/ieee802154/6lowpan/reassembly.c @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* 6LoWPAN fragment reassembly + * + * Authors: + * Alexander Aring + * + * Based on: net/ipv6/reassembly.c + */ + +#define pr_fmt(fmt) "6LoWPAN: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "6lowpan_i.h" + +static const char lowpan_frags_cache_name[] = "lowpan-frags"; + +static struct inet_frags lowpan_frags; + +static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev, struct net_device *ldev); + +static void lowpan_frag_init(struct inet_frag_queue *q, const void *a) +{ + const struct frag_lowpan_compare_key *key = a; + + BUILD_BUG_ON(sizeof(*key) > sizeof(q->key)); + memcpy(&q->key, key, sizeof(*key)); +} + +static void lowpan_frag_expire(struct timer_list *t) +{ + struct inet_frag_queue *frag = from_timer(frag, t, timer); + struct frag_queue *fq; + + fq = container_of(frag, struct frag_queue, q); + + spin_lock(&fq->q.lock); + + if (fq->q.flags & INET_FRAG_COMPLETE) + goto out; + + inet_frag_kill(&fq->q); +out: + spin_unlock(&fq->q.lock); + inet_frag_put(&fq->q); +} + +static inline struct lowpan_frag_queue * +fq_find(struct net *net, const struct lowpan_802154_cb *cb, + const struct ieee802154_addr *src, + const struct ieee802154_addr *dst) +{ + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + struct frag_lowpan_compare_key key = {}; + struct inet_frag_queue *q; + + key.tag = cb->d_tag; + key.d_size = cb->d_size; + key.src = *src; + key.dst = *dst; + + q = inet_frag_find(ieee802154_lowpan->fqdir, &key); + if (!q) + return NULL; + + return container_of(q, struct lowpan_frag_queue, q); +} + +static int lowpan_frag_queue(struct lowpan_frag_queue *fq, + struct sk_buff *skb, u8 frag_type) +{ + struct sk_buff *prev_tail; + struct net_device *ldev; + int end, offset, err; + + /* inet_frag_queue_* functions use skb->cb; see struct ipfrag_skb_cb + * in inet_fragment.c + */ + BUILD_BUG_ON(sizeof(struct lowpan_802154_cb) > sizeof(struct inet_skb_parm)); + BUILD_BUG_ON(sizeof(struct lowpan_802154_cb) > sizeof(struct inet6_skb_parm)); + + if (fq->q.flags & INET_FRAG_COMPLETE) + goto err; + + offset = lowpan_802154_cb(skb)->d_offset << 3; + end = lowpan_802154_cb(skb)->d_size; + + /* Is this the final fragment? 
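+ * d_offset is carried in units of 8 octets, so "offset" above is
+ * already scaled to bytes; the final fragment is the one whose end
+ * (offset plus payload length) matches the advertised datagram size.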
*/ + if (offset + skb->len == end) { + /* If we already have some bits beyond end + * or have different end, the segment is corrupted. + */ + if (end < fq->q.len || + ((fq->q.flags & INET_FRAG_LAST_IN) && end != fq->q.len)) + goto err; + fq->q.flags |= INET_FRAG_LAST_IN; + fq->q.len = end; + } else { + if (end > fq->q.len) { + /* Some bits beyond end -> corruption. */ + if (fq->q.flags & INET_FRAG_LAST_IN) + goto err; + fq->q.len = end; + } + } + + ldev = skb->dev; + if (ldev) + skb->dev = NULL; + barrier(); + + prev_tail = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) + goto err; + + fq->q.stamp = skb->tstamp; + fq->q.mono_delivery_time = skb->mono_delivery_time; + if (frag_type == LOWPAN_DISPATCH_FRAG1) + fq->q.flags |= INET_FRAG_FIRST_IN; + + fq->q.meat += skb->len; + add_frag_mem_limit(fq->q.fqdir, skb->truesize); + + if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && + fq->q.meat == fq->q.len) { + int res; + unsigned long orefdst = skb->_skb_refdst; + + skb->_skb_refdst = 0UL; + res = lowpan_frag_reasm(fq, skb, prev_tail, ldev); + skb->_skb_refdst = orefdst; + return res; + } + skb_dst_drop(skb); + + return -1; +err: + kfree_skb(skb); + return -1; +} + +/* Check if this packet is complete. + * + * It is called with locked fq, and caller must check that + * queue is eligible for reassembly i.e. it is not COMPLETE, + * the last and the first frames arrived and all the bits are here. + */ +static int lowpan_frag_reasm(struct lowpan_frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *ldev) +{ + void *reasm_data; + + inet_frag_kill(&fq->q); + + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) + goto out_oom; + inet_frag_reasm_finish(&fq->q, skb, reasm_data, false); + + skb->dev = ldev; + skb->tstamp = fq->q.stamp; + fq->q.rb_fragments = RB_ROOT; + fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; + + return 1; +out_oom: + net_dbg_ratelimited("lowpan_frag_reasm: no memory for reassembly\n"); + return -1; +} + +static int lowpan_frag_rx_handlers_result(struct sk_buff *skb, + lowpan_rx_result res) +{ + switch (res) { + case RX_QUEUED: + return NET_RX_SUCCESS; + case RX_CONTINUE: + /* nobody cared about this packet */ + net_warn_ratelimited("%s: received unknown dispatch\n", + __func__); + + fallthrough; + default: + /* all others failure */ + return NET_RX_DROP; + } +} + +static lowpan_rx_result lowpan_frag_rx_h_iphc(struct sk_buff *skb) +{ + int ret; + + if (!lowpan_is_iphc(*skb_network_header(skb))) + return RX_CONTINUE; + + ret = lowpan_iphc_decompress(skb); + if (ret < 0) + return RX_DROP; + + return RX_QUEUED; +} + +static int lowpan_invoke_frag_rx_handlers(struct sk_buff *skb) +{ + lowpan_rx_result res; + +#define CALL_RXH(rxh) \ + do { \ + res = rxh(skb); \ + if (res != RX_CONTINUE) \ + goto rxh_next; \ + } while (0) + + /* likely at first */ + CALL_RXH(lowpan_frag_rx_h_iphc); + CALL_RXH(lowpan_rx_h_ipv6); + +rxh_next: + return lowpan_frag_rx_handlers_result(skb, res); +#undef CALL_RXH +} + +#define LOWPAN_FRAG_DGRAM_SIZE_HIGH_MASK 0x07 +#define LOWPAN_FRAG_DGRAM_SIZE_HIGH_SHIFT 8 + +static int lowpan_get_cb(struct sk_buff *skb, u8 frag_type, + struct lowpan_802154_cb *cb) +{ + bool fail; + u8 high = 0, low = 0; + __be16 d_tag = 0; + + fail = lowpan_fetch_skb(skb, &high, 1); + fail |= lowpan_fetch_skb(skb, &low, 1); + /* remove the dispatch value and use first three bits as high value + * for the datagram size + */ + cb->d_size = (high & 
LOWPAN_FRAG_DGRAM_SIZE_HIGH_MASK) << + LOWPAN_FRAG_DGRAM_SIZE_HIGH_SHIFT | low; + fail |= lowpan_fetch_skb(skb, &d_tag, 2); + cb->d_tag = ntohs(d_tag); + + if (frag_type == LOWPAN_DISPATCH_FRAGN) { + fail |= lowpan_fetch_skb(skb, &cb->d_offset, 1); + } else { + skb_reset_network_header(skb); + cb->d_offset = 0; + /* check if datagram_size has ipv6hdr on FRAG1 */ + fail |= cb->d_size < sizeof(struct ipv6hdr); + /* check if we can dereference the dispatch value */ + fail |= !skb->len; + } + + if (unlikely(fail)) + return -EIO; + + return 0; +} + +int lowpan_frag_rcv(struct sk_buff *skb, u8 frag_type) +{ + struct lowpan_frag_queue *fq; + struct net *net = dev_net(skb->dev); + struct lowpan_802154_cb *cb = lowpan_802154_cb(skb); + struct ieee802154_hdr hdr = {}; + int err; + + if (ieee802154_hdr_peek_addrs(skb, &hdr) < 0) + goto err; + + err = lowpan_get_cb(skb, frag_type, cb); + if (err < 0) + goto err; + + if (frag_type == LOWPAN_DISPATCH_FRAG1) { + err = lowpan_invoke_frag_rx_handlers(skb); + if (err == NET_RX_DROP) + goto err; + } + + if (cb->d_size > IPV6_MIN_MTU) { + net_warn_ratelimited("lowpan_frag_rcv: datagram size exceeds MTU\n"); + goto err; + } + + fq = fq_find(net, cb, &hdr.source, &hdr.dest); + if (fq != NULL) { + int ret; + + spin_lock(&fq->q.lock); + ret = lowpan_frag_queue(fq, skb, frag_type); + spin_unlock(&fq->q.lock); + + inet_frag_put(&fq->q); + return ret; + } + +err: + kfree_skb(skb); + return -1; +} + +#ifdef CONFIG_SYSCTL + +static struct ctl_table lowpan_frags_ns_ctl_table[] = { + { + .procname = "6lowpanfrag_high_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "6lowpanfrag_low_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "6lowpanfrag_time", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; + +/* secret interval has been deprecated */ +static int lowpan_frags_secret_interval_unused; +static struct ctl_table lowpan_frags_ctl_table[] = { + { + .procname = "6lowpanfrag_secret_interval", + .data = &lowpan_frags_secret_interval_unused, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; + +static int __net_init lowpan_frags_ns_sysctl_register(struct net *net) +{ + struct ctl_table *table; + struct ctl_table_header *hdr; + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + + table = lowpan_frags_ns_ctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(lowpan_frags_ns_ctl_table), + GFP_KERNEL); + if (table == NULL) + goto err_alloc; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + table[0].procname = NULL; + } + + table[0].data = &ieee802154_lowpan->fqdir->high_thresh; + table[0].extra1 = &ieee802154_lowpan->fqdir->low_thresh; + table[1].data = &ieee802154_lowpan->fqdir->low_thresh; + table[1].extra2 = &ieee802154_lowpan->fqdir->high_thresh; + table[2].data = &ieee802154_lowpan->fqdir->timeout; + + hdr = register_net_sysctl(net, "net/ieee802154/6lowpan", table); + if (hdr == NULL) + goto err_reg; + + ieee802154_lowpan->sysctl.frags_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void __net_exit lowpan_frags_ns_sysctl_unregister(struct net *net) +{ + struct ctl_table *table; + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + 
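+	/* The per-netns table was kmemdup()'d at registration time for
+	 * non-init namespaces, so it has to be freed again below.
+	 */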
+ table = ieee802154_lowpan->sysctl.frags_hdr->ctl_table_arg; + unregister_net_sysctl_table(ieee802154_lowpan->sysctl.frags_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} + +static struct ctl_table_header *lowpan_ctl_header; + +static int __init lowpan_frags_sysctl_register(void) +{ + lowpan_ctl_header = register_net_sysctl(&init_net, + "net/ieee802154/6lowpan", + lowpan_frags_ctl_table); + return lowpan_ctl_header == NULL ? -ENOMEM : 0; +} + +static void lowpan_frags_sysctl_unregister(void) +{ + unregister_net_sysctl_table(lowpan_ctl_header); +} +#else +static inline int lowpan_frags_ns_sysctl_register(struct net *net) +{ + return 0; +} + +static inline void lowpan_frags_ns_sysctl_unregister(struct net *net) +{ +} + +static inline int __init lowpan_frags_sysctl_register(void) +{ + return 0; +} + +static inline void lowpan_frags_sysctl_unregister(void) +{ +} +#endif + +static int __net_init lowpan_frags_init_net(struct net *net) +{ + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + int res; + + + res = fqdir_init(&ieee802154_lowpan->fqdir, &lowpan_frags, net); + if (res < 0) + return res; + + ieee802154_lowpan->fqdir->high_thresh = IPV6_FRAG_HIGH_THRESH; + ieee802154_lowpan->fqdir->low_thresh = IPV6_FRAG_LOW_THRESH; + ieee802154_lowpan->fqdir->timeout = IPV6_FRAG_TIMEOUT; + + res = lowpan_frags_ns_sysctl_register(net); + if (res < 0) + fqdir_exit(ieee802154_lowpan->fqdir); + return res; +} + +static void __net_exit lowpan_frags_pre_exit_net(struct net *net) +{ + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + + fqdir_pre_exit(ieee802154_lowpan->fqdir); +} + +static void __net_exit lowpan_frags_exit_net(struct net *net) +{ + struct netns_ieee802154_lowpan *ieee802154_lowpan = + net_ieee802154_lowpan(net); + + lowpan_frags_ns_sysctl_unregister(net); + fqdir_exit(ieee802154_lowpan->fqdir); +} + +static struct pernet_operations lowpan_frags_ops = { + .init = lowpan_frags_init_net, + .pre_exit = lowpan_frags_pre_exit_net, + .exit = lowpan_frags_exit_net, +}; + +static u32 lowpan_key_hashfn(const void *data, u32 len, u32 seed) +{ + return jhash2(data, + sizeof(struct frag_lowpan_compare_key) / sizeof(u32), seed); +} + +static u32 lowpan_obj_hashfn(const void *data, u32 len, u32 seed) +{ + const struct inet_frag_queue *fq = data; + + return jhash2((const u32 *)&fq->key, + sizeof(struct frag_lowpan_compare_key) / sizeof(u32), seed); +} + +static int lowpan_obj_cmpfn(struct rhashtable_compare_arg *arg, const void *ptr) +{ + const struct frag_lowpan_compare_key *key = arg->key; + const struct inet_frag_queue *fq = ptr; + + return !!memcmp(&fq->key, key, sizeof(*key)); +} + +static const struct rhashtable_params lowpan_rhash_params = { + .head_offset = offsetof(struct inet_frag_queue, node), + .hashfn = lowpan_key_hashfn, + .obj_hashfn = lowpan_obj_hashfn, + .obj_cmpfn = lowpan_obj_cmpfn, + .automatic_shrinking = true, +}; + +int __init lowpan_net_frag_init(void) +{ + int ret; + + lowpan_frags.constructor = lowpan_frag_init; + lowpan_frags.destructor = NULL; + lowpan_frags.qsize = sizeof(struct frag_queue); + lowpan_frags.frag_expire = lowpan_frag_expire; + lowpan_frags.frags_cache_name = lowpan_frags_cache_name; + lowpan_frags.rhash_params = lowpan_rhash_params; + ret = inet_frags_init(&lowpan_frags); + if (ret) + goto out; + + ret = lowpan_frags_sysctl_register(); + if (ret) + goto err_sysctl; + + ret = register_pernet_subsys(&lowpan_frags_ops); + if (ret) + goto err_pernet; +out: + return ret; +err_pernet: + 
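+	/* Unwind in reverse order: drop the sysctls first, then the
+	 * inet_frags state registered at the top of this function.
+	 */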
lowpan_frags_sysctl_unregister(); +err_sysctl: + inet_frags_fini(&lowpan_frags); + return ret; +} + +void lowpan_net_frag_exit(void) +{ + lowpan_frags_sysctl_unregister(); + unregister_pernet_subsys(&lowpan_frags_ops); + inet_frags_fini(&lowpan_frags); +} diff --git a/net/ieee802154/6lowpan/rx.c b/net/ieee802154/6lowpan/rx.c new file mode 100644 index 000000000..517e6493f --- /dev/null +++ b/net/ieee802154/6lowpan/rx.c @@ -0,0 +1,323 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include + +#include +#include +#include + +#include "6lowpan_i.h" + +#define LOWPAN_DISPATCH_FIRST 0xc0 +#define LOWPAN_DISPATCH_FRAG_MASK 0xf8 + +#define LOWPAN_DISPATCH_NALP 0x00 +#define LOWPAN_DISPATCH_ESC 0x40 +#define LOWPAN_DISPATCH_HC1 0x42 +#define LOWPAN_DISPATCH_DFF 0x43 +#define LOWPAN_DISPATCH_BC0 0x50 +#define LOWPAN_DISPATCH_MESH 0x80 + +static int lowpan_give_skb_to_device(struct sk_buff *skb) +{ + skb->protocol = htons(ETH_P_IPV6); + skb->dev->stats.rx_packets++; + skb->dev->stats.rx_bytes += skb->len; + + return netif_rx(skb); +} + +static int lowpan_rx_handlers_result(struct sk_buff *skb, lowpan_rx_result res) +{ + switch (res) { + case RX_CONTINUE: + /* nobody cared about this packet */ + net_warn_ratelimited("%s: received unknown dispatch\n", + __func__); + + fallthrough; + case RX_DROP_UNUSABLE: + kfree_skb(skb); + + fallthrough; + case RX_DROP: + return NET_RX_DROP; + case RX_QUEUED: + return lowpan_give_skb_to_device(skb); + default: + break; + } + + return NET_RX_DROP; +} + +static inline bool lowpan_is_frag1(u8 dispatch) +{ + return (dispatch & LOWPAN_DISPATCH_FRAG_MASK) == LOWPAN_DISPATCH_FRAG1; +} + +static inline bool lowpan_is_fragn(u8 dispatch) +{ + return (dispatch & LOWPAN_DISPATCH_FRAG_MASK) == LOWPAN_DISPATCH_FRAGN; +} + +static lowpan_rx_result lowpan_rx_h_frag(struct sk_buff *skb) +{ + int ret; + + if (!(lowpan_is_frag1(*skb_network_header(skb)) || + lowpan_is_fragn(*skb_network_header(skb)))) + return RX_CONTINUE; + + ret = lowpan_frag_rcv(skb, *skb_network_header(skb) & + LOWPAN_DISPATCH_FRAG_MASK); + if (ret == 1) + return RX_QUEUED; + + /* Packet is freed by lowpan_frag_rcv on error or put into the frag + * bucket. + */ + return RX_DROP; +} + +int lowpan_iphc_decompress(struct sk_buff *skb) +{ + struct ieee802154_hdr hdr; + + if (ieee802154_hdr_peek_addrs(skb, &hdr) < 0) + return -EINVAL; + + return lowpan_header_decompress(skb, skb->dev, &hdr.dest, &hdr.source); +} + +static lowpan_rx_result lowpan_rx_h_iphc(struct sk_buff *skb) +{ + int ret; + + if (!lowpan_is_iphc(*skb_network_header(skb))) + return RX_CONTINUE; + + /* Setting datagram_offset to zero indicates non frag handling + * while doing lowpan_header_decompress. + */ + lowpan_802154_cb(skb)->d_size = 0; + + ret = lowpan_iphc_decompress(skb); + if (ret < 0) + return RX_DROP_UNUSABLE; + + return RX_QUEUED; +} + +lowpan_rx_result lowpan_rx_h_ipv6(struct sk_buff *skb) +{ + if (!lowpan_is_ipv6(*skb_network_header(skb))) + return RX_CONTINUE; + + /* Pull off the 1-byte of 6lowpan header. 
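+ * The uncompressed-IPv6 dispatch is a single octet sitting in front
+ * of the plain IPv6 header, so one byte is all that has to go before
+ * the skb is handed up the stack.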
*/ + skb_pull(skb, 1); + return RX_QUEUED; +} + +static inline bool lowpan_is_esc(u8 dispatch) +{ + return dispatch == LOWPAN_DISPATCH_ESC; +} + +static lowpan_rx_result lowpan_rx_h_esc(struct sk_buff *skb) +{ + if (!lowpan_is_esc(*skb_network_header(skb))) + return RX_CONTINUE; + + net_warn_ratelimited("%s: %s\n", skb->dev->name, + "6LoWPAN ESC not supported\n"); + + return RX_DROP_UNUSABLE; +} + +static inline bool lowpan_is_hc1(u8 dispatch) +{ + return dispatch == LOWPAN_DISPATCH_HC1; +} + +static lowpan_rx_result lowpan_rx_h_hc1(struct sk_buff *skb) +{ + if (!lowpan_is_hc1(*skb_network_header(skb))) + return RX_CONTINUE; + + net_warn_ratelimited("%s: %s\n", skb->dev->name, + "6LoWPAN HC1 not supported\n"); + + return RX_DROP_UNUSABLE; +} + +static inline bool lowpan_is_dff(u8 dispatch) +{ + return dispatch == LOWPAN_DISPATCH_DFF; +} + +static lowpan_rx_result lowpan_rx_h_dff(struct sk_buff *skb) +{ + if (!lowpan_is_dff(*skb_network_header(skb))) + return RX_CONTINUE; + + net_warn_ratelimited("%s: %s\n", skb->dev->name, + "6LoWPAN DFF not supported\n"); + + return RX_DROP_UNUSABLE; +} + +static inline bool lowpan_is_bc0(u8 dispatch) +{ + return dispatch == LOWPAN_DISPATCH_BC0; +} + +static lowpan_rx_result lowpan_rx_h_bc0(struct sk_buff *skb) +{ + if (!lowpan_is_bc0(*skb_network_header(skb))) + return RX_CONTINUE; + + net_warn_ratelimited("%s: %s\n", skb->dev->name, + "6LoWPAN BC0 not supported\n"); + + return RX_DROP_UNUSABLE; +} + +static inline bool lowpan_is_mesh(u8 dispatch) +{ + return (dispatch & LOWPAN_DISPATCH_FIRST) == LOWPAN_DISPATCH_MESH; +} + +static lowpan_rx_result lowpan_rx_h_mesh(struct sk_buff *skb) +{ + if (!lowpan_is_mesh(*skb_network_header(skb))) + return RX_CONTINUE; + + net_warn_ratelimited("%s: %s\n", skb->dev->name, + "6LoWPAN MESH not supported\n"); + + return RX_DROP_UNUSABLE; +} + +static int lowpan_invoke_rx_handlers(struct sk_buff *skb) +{ + lowpan_rx_result res; + +#define CALL_RXH(rxh) \ + do { \ + res = rxh(skb); \ + if (res != RX_CONTINUE) \ + goto rxh_next; \ + } while (0) + + /* likely at first */ + CALL_RXH(lowpan_rx_h_iphc); + CALL_RXH(lowpan_rx_h_frag); + CALL_RXH(lowpan_rx_h_ipv6); + CALL_RXH(lowpan_rx_h_esc); + CALL_RXH(lowpan_rx_h_hc1); + CALL_RXH(lowpan_rx_h_dff); + CALL_RXH(lowpan_rx_h_bc0); + CALL_RXH(lowpan_rx_h_mesh); + +rxh_next: + return lowpan_rx_handlers_result(skb, res); +#undef CALL_RXH +} + +static inline bool lowpan_is_nalp(u8 dispatch) +{ + return (dispatch & LOWPAN_DISPATCH_FIRST) == LOWPAN_DISPATCH_NALP; +} + +/* Lookup for reserved dispatch values at: + * https://www.iana.org/assignments/_6lowpan-parameters/_6lowpan-parameters.xhtml#_6lowpan-parameters-1 + * + * Last Updated: 2015-01-22 + */ +static inline bool lowpan_is_reserved(u8 dispatch) +{ + return ((dispatch >= 0x44 && dispatch <= 0x4F) || + (dispatch >= 0x51 && dispatch <= 0x5F) || + (dispatch >= 0xc8 && dispatch <= 0xdf) || + dispatch >= 0xe8); +} + +/* lowpan_rx_h_check checks on generic 6LoWPAN requirements + * in MAC and 6LoWPAN header. + * + * Don't manipulate the skb here, it could be shared buffer. 
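+ * The caller (lowpan_rcv) only unshares the skb after these checks
+ * have passed, so everything here must treat the buffer as read-only.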
+ */ +static inline bool lowpan_rx_h_check(struct sk_buff *skb) +{ + __le16 fc = ieee802154_get_fc_from_skb(skb); + + /* check on ieee802154 conform 6LoWPAN header */ + if (!ieee802154_is_data(fc) || + !ieee802154_skb_is_intra_pan_addressing(fc, skb)) + return false; + + /* check if we can dereference the dispatch */ + if (unlikely(!skb->len)) + return false; + + if (lowpan_is_nalp(*skb_network_header(skb)) || + lowpan_is_reserved(*skb_network_header(skb))) + return false; + + return true; +} + +static int lowpan_rcv(struct sk_buff *skb, struct net_device *wdev, + struct packet_type *pt, struct net_device *orig_wdev) +{ + struct net_device *ldev; + + if (wdev->type != ARPHRD_IEEE802154 || + skb->pkt_type == PACKET_OTHERHOST || + !lowpan_rx_h_check(skb)) + goto drop; + + ldev = wdev->ieee802154_ptr->lowpan_dev; + if (!ldev || !netif_running(ldev)) + goto drop; + + /* Replacing skb->dev and followed rx handlers will manipulate skb. */ + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + goto out; + skb->dev = ldev; + + /* When receive frag1 it's likely that we manipulate the buffer. + * When recevie iphc we manipulate the data buffer. So we need + * to unshare the buffer. + */ + if (lowpan_is_frag1(*skb_network_header(skb)) || + lowpan_is_iphc(*skb_network_header(skb))) { + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + goto out; + } + + return lowpan_invoke_rx_handlers(skb); + +drop: + kfree_skb(skb); +out: + return NET_RX_DROP; +} + +static struct packet_type lowpan_packet_type = { + .type = htons(ETH_P_IEEE802154), + .func = lowpan_rcv, +}; + +void lowpan_rx_init(void) +{ + dev_add_pack(&lowpan_packet_type); +} + +void lowpan_rx_exit(void) +{ + dev_remove_pack(&lowpan_packet_type); +} diff --git a/net/ieee802154/6lowpan/tx.c b/net/ieee802154/6lowpan/tx.c new file mode 100644 index 000000000..0c07662b4 --- /dev/null +++ b/net/ieee802154/6lowpan/tx.c @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include + +#include "6lowpan_i.h" + +#define LOWPAN_FRAG1_HEAD_SIZE 0x4 +#define LOWPAN_FRAGN_HEAD_SIZE 0x5 + +struct lowpan_addr_info { + struct ieee802154_addr daddr; + struct ieee802154_addr saddr; +}; + +static inline struct +lowpan_addr_info *lowpan_skb_priv(const struct sk_buff *skb) +{ + WARN_ON_ONCE(skb_headroom(skb) < sizeof(struct lowpan_addr_info)); + return (struct lowpan_addr_info *)(skb->data - + sizeof(struct lowpan_addr_info)); +} + +/* This callback will be called from AF_PACKET and IPv6 stack, the AF_PACKET + * sockets gives an 8 byte array for addresses only! + * + * TODO I think AF_PACKET DGRAM (sending/receiving) RAW (sending) makes no + * sense here. We should disable it, the right use-case would be AF_INET6 + * RAW/DGRAM sockets. + */ +int lowpan_header_create(struct sk_buff *skb, struct net_device *ldev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + struct wpan_dev *wpan_dev = lowpan_802154_dev(ldev)->wdev->ieee802154_ptr; + struct lowpan_addr_info *info = lowpan_skb_priv(skb); + struct lowpan_802154_neigh *llneigh = NULL; + const struct ipv6hdr *hdr = ipv6_hdr(skb); + struct neighbour *n; + + if (!daddr) + return -EINVAL; + + /* TODO: + * if this package isn't ipv6 one, where should it be routed? 
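+ * For now any non-IPv6 type is accepted but gets no 802.15.4
+ * addressing info: the function returns 0 below without filling in
+ * lowpan_addr_info.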
+ */ + if (type != ETH_P_IPV6) + return 0; + + /* intra-pan communication */ + info->saddr.pan_id = wpan_dev->pan_id; + info->daddr.pan_id = info->saddr.pan_id; + + if (!memcmp(daddr, ldev->broadcast, EUI64_ADDR_LEN)) { + info->daddr.short_addr = cpu_to_le16(IEEE802154_ADDR_BROADCAST); + info->daddr.mode = IEEE802154_ADDR_SHORT; + } else { + __le16 short_addr = cpu_to_le16(IEEE802154_ADDR_SHORT_UNSPEC); + + n = neigh_lookup(&nd_tbl, &hdr->daddr, ldev); + if (n) { + llneigh = lowpan_802154_neigh(neighbour_priv(n)); + read_lock_bh(&n->lock); + short_addr = llneigh->short_addr; + read_unlock_bh(&n->lock); + } + + if (llneigh && + lowpan_802154_is_valid_src_short_addr(short_addr)) { + info->daddr.short_addr = short_addr; + info->daddr.mode = IEEE802154_ADDR_SHORT; + } else { + info->daddr.mode = IEEE802154_ADDR_LONG; + ieee802154_be64_to_le64(&info->daddr.extended_addr, + daddr); + } + + if (n) + neigh_release(n); + } + + if (!saddr) { + if (lowpan_802154_is_valid_src_short_addr(wpan_dev->short_addr)) { + info->saddr.mode = IEEE802154_ADDR_SHORT; + info->saddr.short_addr = wpan_dev->short_addr; + } else { + info->saddr.mode = IEEE802154_ADDR_LONG; + info->saddr.extended_addr = wpan_dev->extended_addr; + } + } else { + info->saddr.mode = IEEE802154_ADDR_LONG; + ieee802154_be64_to_le64(&info->saddr.extended_addr, saddr); + } + + return 0; +} + +static struct sk_buff* +lowpan_alloc_frag(struct sk_buff *skb, int size, + const struct ieee802154_hdr *master_hdr, bool frag1) +{ + struct net_device *wdev = lowpan_802154_dev(skb->dev)->wdev; + struct sk_buff *frag; + int rc; + + frag = alloc_skb(wdev->needed_headroom + wdev->needed_tailroom + size, + GFP_ATOMIC); + + if (likely(frag)) { + frag->dev = wdev; + frag->priority = skb->priority; + skb_reserve(frag, wdev->needed_headroom); + skb_reset_network_header(frag); + *mac_cb(frag) = *mac_cb(skb); + + if (frag1) { + skb_put_data(frag, skb_mac_header(skb), skb->mac_len); + } else { + rc = wpan_dev_hard_header(frag, wdev, + &master_hdr->dest, + &master_hdr->source, size); + if (rc < 0) { + kfree_skb(frag); + return ERR_PTR(rc); + } + } + } else { + frag = ERR_PTR(-ENOMEM); + } + + return frag; +} + +static int +lowpan_xmit_fragment(struct sk_buff *skb, const struct ieee802154_hdr *wpan_hdr, + u8 *frag_hdr, int frag_hdrlen, + int offset, int len, bool frag1) +{ + struct sk_buff *frag; + + raw_dump_inline(__func__, " fragment header", frag_hdr, frag_hdrlen); + + frag = lowpan_alloc_frag(skb, frag_hdrlen + len, wpan_hdr, frag1); + if (IS_ERR(frag)) + return PTR_ERR(frag); + + skb_put_data(frag, frag_hdr, frag_hdrlen); + skb_put_data(frag, skb_network_header(skb) + offset, len); + + raw_dump_table(__func__, " fragment dump", frag->data, frag->len); + + return dev_queue_xmit(frag); +} + +static int +lowpan_xmit_fragmented(struct sk_buff *skb, struct net_device *ldev, + const struct ieee802154_hdr *wpan_hdr, u16 dgram_size, + u16 dgram_offset) +{ + __be16 frag_tag; + u8 frag_hdr[5]; + int frag_cap, frag_len, payload_cap, rc; + int skb_unprocessed, skb_offset; + + frag_tag = htons(lowpan_802154_dev(ldev)->fragment_tag); + lowpan_802154_dev(ldev)->fragment_tag++; + + frag_hdr[0] = LOWPAN_DISPATCH_FRAG1 | ((dgram_size >> 8) & 0x07); + frag_hdr[1] = dgram_size & 0xff; + memcpy(frag_hdr + 2, &frag_tag, sizeof(frag_tag)); + + payload_cap = ieee802154_max_payload(wpan_hdr); + + frag_len = round_down(payload_cap - LOWPAN_FRAG1_HEAD_SIZE - + skb_network_header_len(skb), 8); + + skb_offset = skb_network_header_len(skb); + skb_unprocessed = skb->len - skb->mac_len - 
skb_offset; + + rc = lowpan_xmit_fragment(skb, wpan_hdr, frag_hdr, + LOWPAN_FRAG1_HEAD_SIZE, 0, + frag_len + skb_network_header_len(skb), + true); + if (rc) { + pr_debug("%s unable to send FRAG1 packet (tag: %d)", + __func__, ntohs(frag_tag)); + goto err; + } + + frag_hdr[0] &= ~LOWPAN_DISPATCH_FRAG1; + frag_hdr[0] |= LOWPAN_DISPATCH_FRAGN; + frag_cap = round_down(payload_cap - LOWPAN_FRAGN_HEAD_SIZE, 8); + + do { + dgram_offset += frag_len; + skb_offset += frag_len; + skb_unprocessed -= frag_len; + frag_len = min(frag_cap, skb_unprocessed); + + frag_hdr[4] = dgram_offset >> 3; + + rc = lowpan_xmit_fragment(skb, wpan_hdr, frag_hdr, + LOWPAN_FRAGN_HEAD_SIZE, skb_offset, + frag_len, false); + if (rc) { + pr_debug("%s unable to send a FRAGN packet. (tag: %d, offset: %d)\n", + __func__, ntohs(frag_tag), skb_offset); + goto err; + } + } while (skb_unprocessed > frag_cap); + + ldev->stats.tx_packets++; + ldev->stats.tx_bytes += dgram_size; + consume_skb(skb); + return NET_XMIT_SUCCESS; + +err: + kfree_skb(skb); + return rc; +} + +static int lowpan_header(struct sk_buff *skb, struct net_device *ldev, + u16 *dgram_size, u16 *dgram_offset) +{ + struct wpan_dev *wpan_dev = lowpan_802154_dev(ldev)->wdev->ieee802154_ptr; + struct ieee802154_mac_cb *cb = mac_cb_init(skb); + struct lowpan_addr_info info; + + memcpy(&info, lowpan_skb_priv(skb), sizeof(info)); + + *dgram_size = skb->len; + lowpan_header_compress(skb, ldev, &info.daddr, &info.saddr); + /* dgram_offset = (saved bytes after compression) + lowpan header len */ + *dgram_offset = (*dgram_size - skb->len) + skb_network_header_len(skb); + + cb->type = IEEE802154_FC_TYPE_DATA; + + if (info.daddr.mode == IEEE802154_ADDR_SHORT && + ieee802154_is_broadcast_short_addr(info.daddr.short_addr)) + cb->ackreq = false; + else + cb->ackreq = wpan_dev->ackreq; + + return wpan_dev_hard_header(skb, lowpan_802154_dev(ldev)->wdev, + &info.daddr, &info.saddr, 0); +} + +netdev_tx_t lowpan_xmit(struct sk_buff *skb, struct net_device *ldev) +{ + struct ieee802154_hdr wpan_hdr; + int max_single, ret; + u16 dgram_size, dgram_offset; + + pr_debug("package xmit\n"); + + WARN_ON_ONCE(skb->len > IPV6_MIN_MTU); + + /* We must take a copy of the skb before we modify/replace the ipv6 + * header as the header could be used elsewhere + */ + if (unlikely(skb_headroom(skb) < ldev->needed_headroom || + skb_tailroom(skb) < ldev->needed_tailroom)) { + struct sk_buff *nskb; + + nskb = skb_copy_expand(skb, ldev->needed_headroom, + ldev->needed_tailroom, GFP_ATOMIC); + if (likely(nskb)) { + consume_skb(skb); + skb = nskb; + } else { + kfree_skb(skb); + return NET_XMIT_DROP; + } + } else { + skb = skb_unshare(skb, GFP_ATOMIC); + if (!skb) + return NET_XMIT_DROP; + } + + ret = lowpan_header(skb, ldev, &dgram_size, &dgram_offset); + if (ret < 0) { + kfree_skb(skb); + return NET_XMIT_DROP; + } + + if (ieee802154_hdr_peek(skb, &wpan_hdr) < 0) { + kfree_skb(skb); + return NET_XMIT_DROP; + } + + max_single = ieee802154_max_payload(&wpan_hdr); + + if (skb_tail_pointer(skb) - skb_network_header(skb) <= max_single) { + skb->dev = lowpan_802154_dev(ldev)->wdev; + ldev->stats.tx_packets++; + ldev->stats.tx_bytes += dgram_size; + return dev_queue_xmit(skb); + } else { + netdev_tx_t rc; + + pr_debug("frame is too big, fragmentation is needed\n"); + rc = lowpan_xmit_fragmented(skb, ldev, &wpan_hdr, dgram_size, + dgram_offset); + + return rc < 0 ? 
NET_XMIT_DROP : rc; + } +} diff --git a/net/ieee802154/Kconfig b/net/ieee802154/Kconfig new file mode 100644 index 000000000..bcb05ba97 --- /dev/null +++ b/net/ieee802154/Kconfig @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig IEEE802154 + tristate "IEEE Std 802.15.4 Low-Rate Wireless Personal Area Networks support" + help + IEEE Std 802.15.4 defines a low data rate, low power and low + complexity short range wireless personal area networks. It was + designed to organise networks of sensors, switches, etc automation + devices. Maximum allowed data rate is 250 kb/s and typical personal + operating space around 10m. + + Say Y here to compile LR-WPAN support into the kernel or say M to + compile it as modules. + +if IEEE802154 + +config IEEE802154_NL802154_EXPERIMENTAL + bool "IEEE 802.15.4 experimental netlink support" + help + Adds experimental netlink support for nl802154. + +config IEEE802154_SOCKET + tristate "IEEE 802.15.4 socket interface" + default y + help + Socket interface for IEEE 802.15.4. Contains DGRAM sockets interface + for 802.15.4 dataframes. Also RAW socket interface to build MAC + header from userspace. + +source "net/ieee802154/6lowpan/Kconfig" + +endif diff --git a/net/ieee802154/Makefile b/net/ieee802154/Makefile new file mode 100644 index 000000000..f05b7bdae --- /dev/null +++ b/net/ieee802154/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_IEEE802154) += ieee802154.o +obj-$(CONFIG_IEEE802154_SOCKET) += ieee802154_socket.o +obj-y += 6lowpan/ + +ieee802154-y := netlink.o nl-mac.o nl-phy.o nl_policy.o core.o \ + header_ops.o sysfs.o nl802154.o trace.o +ieee802154_socket-y := socket.o + +CFLAGS_trace.o := -I$(src) diff --git a/net/ieee802154/core.c b/net/ieee802154/core.c new file mode 100644 index 000000000..de259b517 --- /dev/null +++ b/net/ieee802154/core.c @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007, 2008, 2009 Siemens AG + */ + +#include +#include +#include +#include + +#include +#include + +#include "ieee802154.h" +#include "nl802154.h" +#include "sysfs.h" +#include "core.h" + +/* name for sysfs, %d is appended */ +#define PHY_NAME "phy" + +/* RCU-protected (and RTNL for writers) */ +LIST_HEAD(cfg802154_rdev_list); +int cfg802154_rdev_list_generation; + +struct wpan_phy *wpan_phy_find(const char *str) +{ + struct device *dev; + + if (WARN_ON(!str)) + return NULL; + + dev = class_find_device_by_name(&wpan_phy_class, str); + if (!dev) + return NULL; + + return container_of(dev, struct wpan_phy, dev); +} +EXPORT_SYMBOL(wpan_phy_find); + +struct wpan_phy_iter_data { + int (*fn)(struct wpan_phy *phy, void *data); + void *data; +}; + +static int wpan_phy_iter(struct device *dev, void *_data) +{ + struct wpan_phy_iter_data *wpid = _data; + struct wpan_phy *phy = container_of(dev, struct wpan_phy, dev); + + return wpid->fn(phy, wpid->data); +} + +int wpan_phy_for_each(int (*fn)(struct wpan_phy *phy, void *data), + void *data) +{ + struct wpan_phy_iter_data wpid = { + .fn = fn, + .data = data, + }; + + return class_for_each_device(&wpan_phy_class, NULL, + &wpid, wpan_phy_iter); +} +EXPORT_SYMBOL(wpan_phy_for_each); + +struct cfg802154_registered_device * +cfg802154_rdev_by_wpan_phy_idx(int wpan_phy_idx) +{ + struct cfg802154_registered_device *result = NULL, *rdev; + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg802154_rdev_list, list) { + if (rdev->wpan_phy_idx == wpan_phy_idx) { + result = rdev; + break; + } + } + + return result; +} + +struct wpan_phy 
*wpan_phy_idx_to_wpan_phy(int wpan_phy_idx) +{ + struct cfg802154_registered_device *rdev; + + ASSERT_RTNL(); + + rdev = cfg802154_rdev_by_wpan_phy_idx(wpan_phy_idx); + if (!rdev) + return NULL; + return &rdev->wpan_phy; +} + +struct wpan_phy * +wpan_phy_new(const struct cfg802154_ops *ops, size_t priv_size) +{ + static atomic_t wpan_phy_counter = ATOMIC_INIT(0); + struct cfg802154_registered_device *rdev; + size_t alloc_size; + + alloc_size = sizeof(*rdev) + priv_size; + rdev = kzalloc(alloc_size, GFP_KERNEL); + if (!rdev) + return NULL; + + rdev->ops = ops; + + rdev->wpan_phy_idx = atomic_inc_return(&wpan_phy_counter); + + if (unlikely(rdev->wpan_phy_idx < 0)) { + /* ugh, wrapped! */ + atomic_dec(&wpan_phy_counter); + kfree(rdev); + return NULL; + } + + /* atomic_inc_return makes it start at 1, make it start at 0 */ + rdev->wpan_phy_idx--; + + INIT_LIST_HEAD(&rdev->wpan_dev_list); + device_initialize(&rdev->wpan_phy.dev); + dev_set_name(&rdev->wpan_phy.dev, PHY_NAME "%d", rdev->wpan_phy_idx); + + rdev->wpan_phy.dev.class = &wpan_phy_class; + rdev->wpan_phy.dev.platform_data = rdev; + + wpan_phy_net_set(&rdev->wpan_phy, &init_net); + + init_waitqueue_head(&rdev->dev_wait); + + return &rdev->wpan_phy; +} +EXPORT_SYMBOL(wpan_phy_new); + +int wpan_phy_register(struct wpan_phy *phy) +{ + struct cfg802154_registered_device *rdev = wpan_phy_to_rdev(phy); + int ret; + + rtnl_lock(); + ret = device_add(&phy->dev); + if (ret) { + rtnl_unlock(); + return ret; + } + + list_add_rcu(&rdev->list, &cfg802154_rdev_list); + cfg802154_rdev_list_generation++; + + /* TODO phy registered lock */ + rtnl_unlock(); + + /* TODO nl802154 phy notify */ + + return 0; +} +EXPORT_SYMBOL(wpan_phy_register); + +void wpan_phy_unregister(struct wpan_phy *phy) +{ + struct cfg802154_registered_device *rdev = wpan_phy_to_rdev(phy); + + wait_event(rdev->dev_wait, ({ + int __count; + rtnl_lock(); + __count = rdev->opencount; + rtnl_unlock(); + __count == 0; })); + + rtnl_lock(); + /* TODO nl802154 phy notify */ + /* TODO phy registered lock */ + + WARN_ON(!list_empty(&rdev->wpan_dev_list)); + + /* First remove the hardware from everywhere, this makes + * it impossible to find from userspace. 
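+ * list_del_rcu() together with the synchronize_rcu() below makes
+ * sure concurrent readers of cfg802154_rdev_list have finished
+ * before the device is deleted from sysfs.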
+ */ + list_del_rcu(&rdev->list); + synchronize_rcu(); + + cfg802154_rdev_list_generation++; + + device_del(&phy->dev); + + rtnl_unlock(); +} +EXPORT_SYMBOL(wpan_phy_unregister); + +void wpan_phy_free(struct wpan_phy *phy) +{ + put_device(&phy->dev); +} +EXPORT_SYMBOL(wpan_phy_free); + +int cfg802154_switch_netns(struct cfg802154_registered_device *rdev, + struct net *net) +{ + struct wpan_dev *wpan_dev; + int err = 0; + + list_for_each_entry(wpan_dev, &rdev->wpan_dev_list, list) { + if (!wpan_dev->netdev) + continue; + wpan_dev->netdev->features &= ~NETIF_F_NETNS_LOCAL; + err = dev_change_net_namespace(wpan_dev->netdev, net, "wpan%d"); + if (err) + break; + wpan_dev->netdev->features |= NETIF_F_NETNS_LOCAL; + } + + if (err) { + /* failed -- clean up to old netns */ + net = wpan_phy_net(&rdev->wpan_phy); + + list_for_each_entry_continue_reverse(wpan_dev, + &rdev->wpan_dev_list, + list) { + if (!wpan_dev->netdev) + continue; + wpan_dev->netdev->features &= ~NETIF_F_NETNS_LOCAL; + err = dev_change_net_namespace(wpan_dev->netdev, net, + "wpan%d"); + WARN_ON(err); + wpan_dev->netdev->features |= NETIF_F_NETNS_LOCAL; + } + + return err; + } + + wpan_phy_net_set(&rdev->wpan_phy, net); + + err = device_rename(&rdev->wpan_phy.dev, dev_name(&rdev->wpan_phy.dev)); + WARN_ON(err); + + return 0; +} + +void cfg802154_dev_free(struct cfg802154_registered_device *rdev) +{ + kfree(rdev); +} + +static void +cfg802154_update_iface_num(struct cfg802154_registered_device *rdev, + int iftype, int num) +{ + ASSERT_RTNL(); + + rdev->num_running_ifaces += num; +} + +static int cfg802154_netdev_notifier_call(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct cfg802154_registered_device *rdev; + + if (!wpan_dev) + return NOTIFY_DONE; + + rdev = wpan_phy_to_rdev(wpan_dev->wpan_phy); + + /* TODO WARN_ON unspec type */ + + switch (state) { + /* TODO NETDEV_DEVTYPE */ + case NETDEV_REGISTER: + dev->features |= NETIF_F_NETNS_LOCAL; + wpan_dev->identifier = ++rdev->wpan_dev_id; + list_add_rcu(&wpan_dev->list, &rdev->wpan_dev_list); + rdev->devlist_generation++; + + wpan_dev->netdev = dev; + break; + case NETDEV_DOWN: + cfg802154_update_iface_num(rdev, wpan_dev->iftype, -1); + + rdev->opencount--; + wake_up(&rdev->dev_wait); + break; + case NETDEV_UP: + cfg802154_update_iface_num(rdev, wpan_dev->iftype, 1); + + rdev->opencount++; + break; + case NETDEV_UNREGISTER: + /* It is possible to get NETDEV_UNREGISTER + * multiple times. To detect that, check + * that the interface is still on the list + * of registered interfaces, and only then + * remove and clean it up. + */ + if (!list_empty(&wpan_dev->list)) { + list_del_rcu(&wpan_dev->list); + rdev->devlist_generation++; + } + /* synchronize (so that we won't find this netdev + * from other code any more) and then clear the list + * head so that the above code can safely check for + * !list_empty() to avoid double-cleanup. 
+ */ + synchronize_rcu(); + INIT_LIST_HEAD(&wpan_dev->list); + break; + default: + return NOTIFY_DONE; + } + + return NOTIFY_OK; +} + +static struct notifier_block cfg802154_netdev_notifier = { + .notifier_call = cfg802154_netdev_notifier_call, +}; + +static void __net_exit cfg802154_pernet_exit(struct net *net) +{ + struct cfg802154_registered_device *rdev; + + rtnl_lock(); + list_for_each_entry(rdev, &cfg802154_rdev_list, list) { + if (net_eq(wpan_phy_net(&rdev->wpan_phy), net)) + WARN_ON(cfg802154_switch_netns(rdev, &init_net)); + } + rtnl_unlock(); +} + +static struct pernet_operations cfg802154_pernet_ops = { + .exit = cfg802154_pernet_exit, +}; + +static int __init wpan_phy_class_init(void) +{ + int rc; + + rc = register_pernet_device(&cfg802154_pernet_ops); + if (rc) + goto err; + + rc = wpan_phy_sysfs_init(); + if (rc) + goto err_sysfs; + + rc = register_netdevice_notifier(&cfg802154_netdev_notifier); + if (rc) + goto err_nl; + + rc = ieee802154_nl_init(); + if (rc) + goto err_notifier; + + rc = nl802154_init(); + if (rc) + goto err_ieee802154_nl; + + return 0; + +err_ieee802154_nl: + ieee802154_nl_exit(); + +err_notifier: + unregister_netdevice_notifier(&cfg802154_netdev_notifier); +err_nl: + wpan_phy_sysfs_exit(); +err_sysfs: + unregister_pernet_device(&cfg802154_pernet_ops); +err: + return rc; +} +subsys_initcall(wpan_phy_class_init); + +static void __exit wpan_phy_class_exit(void) +{ + nl802154_exit(); + ieee802154_nl_exit(); + unregister_netdevice_notifier(&cfg802154_netdev_notifier); + wpan_phy_sysfs_exit(); + unregister_pernet_device(&cfg802154_pernet_ops); +} +module_exit(wpan_phy_class_exit); + +MODULE_LICENSE("GPL v2"); +MODULE_DESCRIPTION("IEEE 802.15.4 configuration interface"); +MODULE_AUTHOR("Dmitry Eremin-Solenikov"); diff --git a/net/ieee802154/core.h b/net/ieee802154/core.h new file mode 100644 index 000000000..1c19f575d --- /dev/null +++ b/net/ieee802154/core.h @@ -0,0 +1,50 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __IEEE802154_CORE_H +#define __IEEE802154_CORE_H + +#include + +struct cfg802154_registered_device { + const struct cfg802154_ops *ops; + struct list_head list; + + /* wpan_phy index, internal only */ + int wpan_phy_idx; + + /* also protected by devlist_mtx */ + int opencount; + wait_queue_head_t dev_wait; + + /* protected by RTNL only */ + int num_running_ifaces; + + /* associated wpan interfaces, protected by rtnl or RCU */ + struct list_head wpan_dev_list; + int devlist_generation, wpan_dev_id; + + /* must be last because of the way we do wpan_phy_priv(), + * and it should at least be aligned to NETDEV_ALIGN + */ + struct wpan_phy wpan_phy __aligned(NETDEV_ALIGN); +}; + +static inline struct cfg802154_registered_device * +wpan_phy_to_rdev(struct wpan_phy *wpan_phy) +{ + BUG_ON(!wpan_phy); + return container_of(wpan_phy, struct cfg802154_registered_device, + wpan_phy); +} + +extern struct list_head cfg802154_rdev_list; +extern int cfg802154_rdev_list_generation; + +int cfg802154_switch_netns(struct cfg802154_registered_device *rdev, + struct net *net); +/* free object */ +void cfg802154_dev_free(struct cfg802154_registered_device *rdev); +struct cfg802154_registered_device * +cfg802154_rdev_by_wpan_phy_idx(int wpan_phy_idx); +struct wpan_phy *wpan_phy_idx_to_wpan_phy(int wpan_phy_idx); + +#endif /* __IEEE802154_CORE_H */ diff --git a/net/ieee802154/header_ops.c b/net/ieee802154/header_ops.c new file mode 100644 index 000000000..af337cf62 --- /dev/null +++ b/net/ieee802154/header_ops.c @@ -0,0 +1,318 @@ +// SPDX-License-Identifier: 
GPL-2.0-only +/* + * Copyright (C) 2014 Fraunhofer ITWM + * + * Written by: + * Phoebe Buckheister + */ + +#include + +#include +#include + +static int +ieee802154_hdr_push_addr(u8 *buf, const struct ieee802154_addr *addr, + bool omit_pan) +{ + int pos = 0; + + if (addr->mode == IEEE802154_ADDR_NONE) + return 0; + + if (!omit_pan) { + memcpy(buf + pos, &addr->pan_id, 2); + pos += 2; + } + + switch (addr->mode) { + case IEEE802154_ADDR_SHORT: + memcpy(buf + pos, &addr->short_addr, 2); + pos += 2; + break; + + case IEEE802154_ADDR_LONG: + memcpy(buf + pos, &addr->extended_addr, IEEE802154_ADDR_LEN); + pos += IEEE802154_ADDR_LEN; + break; + + default: + return -EINVAL; + } + + return pos; +} + +static int +ieee802154_hdr_push_sechdr(u8 *buf, const struct ieee802154_sechdr *hdr) +{ + int pos = 5; + + memcpy(buf, hdr, 1); + memcpy(buf + 1, &hdr->frame_counter, 4); + + switch (hdr->key_id_mode) { + case IEEE802154_SCF_KEY_IMPLICIT: + return pos; + + case IEEE802154_SCF_KEY_INDEX: + break; + + case IEEE802154_SCF_KEY_SHORT_INDEX: + memcpy(buf + pos, &hdr->short_src, 4); + pos += 4; + break; + + case IEEE802154_SCF_KEY_HW_INDEX: + memcpy(buf + pos, &hdr->extended_src, IEEE802154_ADDR_LEN); + pos += IEEE802154_ADDR_LEN; + break; + } + + buf[pos++] = hdr->key_id; + + return pos; +} + +int +ieee802154_hdr_push(struct sk_buff *skb, struct ieee802154_hdr *hdr) +{ + u8 buf[IEEE802154_MAX_HEADER_LEN]; + int pos = 2; + int rc; + struct ieee802154_hdr_fc *fc = &hdr->fc; + + buf[pos++] = hdr->seq; + + fc->dest_addr_mode = hdr->dest.mode; + + rc = ieee802154_hdr_push_addr(buf + pos, &hdr->dest, false); + if (rc < 0) + return -EINVAL; + pos += rc; + + fc->source_addr_mode = hdr->source.mode; + + if (hdr->source.pan_id == hdr->dest.pan_id && + hdr->dest.mode != IEEE802154_ADDR_NONE) + fc->intra_pan = true; + + rc = ieee802154_hdr_push_addr(buf + pos, &hdr->source, fc->intra_pan); + if (rc < 0) + return -EINVAL; + pos += rc; + + if (fc->security_enabled) { + fc->version = 1; + + rc = ieee802154_hdr_push_sechdr(buf + pos, &hdr->sec); + if (rc < 0) + return -EINVAL; + + pos += rc; + } + + memcpy(buf, fc, 2); + + memcpy(skb_push(skb, pos), buf, pos); + + return pos; +} +EXPORT_SYMBOL_GPL(ieee802154_hdr_push); + +static int +ieee802154_hdr_get_addr(const u8 *buf, int mode, bool omit_pan, + struct ieee802154_addr *addr) +{ + int pos = 0; + + addr->mode = mode; + + if (mode == IEEE802154_ADDR_NONE) + return 0; + + if (!omit_pan) { + memcpy(&addr->pan_id, buf + pos, 2); + pos += 2; + } + + if (mode == IEEE802154_ADDR_SHORT) { + memcpy(&addr->short_addr, buf + pos, 2); + return pos + 2; + } else { + memcpy(&addr->extended_addr, buf + pos, IEEE802154_ADDR_LEN); + return pos + IEEE802154_ADDR_LEN; + } +} + +static int ieee802154_hdr_addr_len(int mode, bool omit_pan) +{ + int pan_len = omit_pan ? 
0 : 2; + + switch (mode) { + case IEEE802154_ADDR_NONE: return 0; + case IEEE802154_ADDR_SHORT: return 2 + pan_len; + case IEEE802154_ADDR_LONG: return IEEE802154_ADDR_LEN + pan_len; + default: return -EINVAL; + } +} + +static int +ieee802154_hdr_get_sechdr(const u8 *buf, struct ieee802154_sechdr *hdr) +{ + int pos = 5; + + memcpy(hdr, buf, 1); + memcpy(&hdr->frame_counter, buf + 1, 4); + + switch (hdr->key_id_mode) { + case IEEE802154_SCF_KEY_IMPLICIT: + return pos; + + case IEEE802154_SCF_KEY_INDEX: + break; + + case IEEE802154_SCF_KEY_SHORT_INDEX: + memcpy(&hdr->short_src, buf + pos, 4); + pos += 4; + break; + + case IEEE802154_SCF_KEY_HW_INDEX: + memcpy(&hdr->extended_src, buf + pos, IEEE802154_ADDR_LEN); + pos += IEEE802154_ADDR_LEN; + break; + } + + hdr->key_id = buf[pos++]; + + return pos; +} + +static int ieee802154_sechdr_lengths[4] = { + [IEEE802154_SCF_KEY_IMPLICIT] = 5, + [IEEE802154_SCF_KEY_INDEX] = 6, + [IEEE802154_SCF_KEY_SHORT_INDEX] = 10, + [IEEE802154_SCF_KEY_HW_INDEX] = 14, +}; + +static int ieee802154_hdr_sechdr_len(u8 sc) +{ + return ieee802154_sechdr_lengths[IEEE802154_SCF_KEY_ID_MODE(sc)]; +} + +static int ieee802154_hdr_minlen(const struct ieee802154_hdr *hdr) +{ + int dlen, slen; + + dlen = ieee802154_hdr_addr_len(hdr->fc.dest_addr_mode, false); + slen = ieee802154_hdr_addr_len(hdr->fc.source_addr_mode, + hdr->fc.intra_pan); + + if (slen < 0 || dlen < 0) + return -EINVAL; + + return 3 + dlen + slen + hdr->fc.security_enabled; +} + +static int +ieee802154_hdr_get_addrs(const u8 *buf, struct ieee802154_hdr *hdr) +{ + int pos = 0; + + pos += ieee802154_hdr_get_addr(buf + pos, hdr->fc.dest_addr_mode, + false, &hdr->dest); + pos += ieee802154_hdr_get_addr(buf + pos, hdr->fc.source_addr_mode, + hdr->fc.intra_pan, &hdr->source); + + if (hdr->fc.intra_pan) + hdr->source.pan_id = hdr->dest.pan_id; + + return pos; +} + +int +ieee802154_hdr_pull(struct sk_buff *skb, struct ieee802154_hdr *hdr) +{ + int pos = 3, rc; + + if (!pskb_may_pull(skb, 3)) + return -EINVAL; + + memcpy(hdr, skb->data, 3); + + rc = ieee802154_hdr_minlen(hdr); + if (rc < 0 || !pskb_may_pull(skb, rc)) + return -EINVAL; + + pos += ieee802154_hdr_get_addrs(skb->data + pos, hdr); + + if (hdr->fc.security_enabled) { + int want = pos + ieee802154_hdr_sechdr_len(skb->data[pos]); + + if (!pskb_may_pull(skb, want)) + return -EINVAL; + + pos += ieee802154_hdr_get_sechdr(skb->data + pos, &hdr->sec); + } + + skb_pull(skb, pos); + return pos; +} +EXPORT_SYMBOL_GPL(ieee802154_hdr_pull); + +int +ieee802154_hdr_peek_addrs(const struct sk_buff *skb, struct ieee802154_hdr *hdr) +{ + const u8 *buf = skb_mac_header(skb); + int pos = 3, rc; + + if (buf + 3 > skb_tail_pointer(skb)) + return -EINVAL; + + memcpy(hdr, buf, 3); + + rc = ieee802154_hdr_minlen(hdr); + if (rc < 0 || buf + rc > skb_tail_pointer(skb)) + return -EINVAL; + + pos += ieee802154_hdr_get_addrs(buf + pos, hdr); + return pos; +} +EXPORT_SYMBOL_GPL(ieee802154_hdr_peek_addrs); + +int +ieee802154_hdr_peek(const struct sk_buff *skb, struct ieee802154_hdr *hdr) +{ + const u8 *buf = skb_mac_header(skb); + int pos; + + pos = ieee802154_hdr_peek_addrs(skb, hdr); + if (pos < 0) + return -EINVAL; + + if (hdr->fc.security_enabled) { + u8 key_id_mode = IEEE802154_SCF_KEY_ID_MODE(*(buf + pos)); + int want = pos + ieee802154_sechdr_lengths[key_id_mode]; + + if (buf + want > skb_tail_pointer(skb)) + return -EINVAL; + + pos += ieee802154_hdr_get_sechdr(buf + pos, &hdr->sec); + } + + return pos; +} +EXPORT_SYMBOL_GPL(ieee802154_hdr_peek); + +int ieee802154_max_payload(const 
struct ieee802154_hdr *hdr) +{ + int hlen = ieee802154_hdr_minlen(hdr); + + if (hdr->fc.security_enabled) { + hlen += ieee802154_sechdr_lengths[hdr->sec.key_id_mode] - 1; + hlen += ieee802154_sechdr_authtag_len(&hdr->sec); + } + + return IEEE802154_MTU - hlen - IEEE802154_MFR_SIZE; +} +EXPORT_SYMBOL_GPL(ieee802154_max_payload); diff --git a/net/ieee802154/ieee802154.h b/net/ieee802154/ieee802154.h new file mode 100644 index 000000000..c5d91f783 --- /dev/null +++ b/net/ieee802154/ieee802154.h @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (C) 2007, 2008, 2009 Siemens AG + */ +#ifndef IEEE_802154_LOCAL_H +#define IEEE_802154_LOCAL_H + +int __init ieee802154_nl_init(void); +void ieee802154_nl_exit(void); + +#define IEEE802154_OP(_cmd, _func) \ + { \ + .cmd = _cmd, \ + .doit = _func, \ + .dumpit = NULL, \ + .flags = GENL_ADMIN_PERM, \ + } + +#define IEEE802154_DUMP(_cmd, _func, _dump) \ + { \ + .cmd = _cmd, \ + .doit = _func, \ + .dumpit = _dump, \ + } + +struct genl_info; + +struct sk_buff *ieee802154_nl_create(int flags, u8 req); +int ieee802154_nl_mcast(struct sk_buff *msg, unsigned int group); +struct sk_buff *ieee802154_nl_new_reply(struct genl_info *info, + int flags, u8 req); +int ieee802154_nl_reply(struct sk_buff *msg, struct genl_info *info); + +extern struct genl_family nl802154_family; + +/* genetlink ops/groups */ +int ieee802154_list_phy(struct sk_buff *skb, struct genl_info *info); +int ieee802154_dump_phy(struct sk_buff *skb, struct netlink_callback *cb); +int ieee802154_add_iface(struct sk_buff *skb, struct genl_info *info); +int ieee802154_del_iface(struct sk_buff *skb, struct genl_info *info); + +enum ieee802154_mcgrp_ids { + IEEE802154_COORD_MCGRP, + IEEE802154_BEACON_MCGRP, +}; + +int ieee802154_associate_req(struct sk_buff *skb, struct genl_info *info); +int ieee802154_associate_resp(struct sk_buff *skb, struct genl_info *info); +int ieee802154_disassociate_req(struct sk_buff *skb, struct genl_info *info); +int ieee802154_scan_req(struct sk_buff *skb, struct genl_info *info); +int ieee802154_start_req(struct sk_buff *skb, struct genl_info *info); +int ieee802154_list_iface(struct sk_buff *skb, struct genl_info *info); +int ieee802154_dump_iface(struct sk_buff *skb, struct netlink_callback *cb); +int ieee802154_set_macparams(struct sk_buff *skb, struct genl_info *info); + +int ieee802154_llsec_getparams(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_setparams(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_add_key(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_del_key(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_dump_keys(struct sk_buff *skb, + struct netlink_callback *cb); +int ieee802154_llsec_add_dev(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_del_dev(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_dump_devs(struct sk_buff *skb, + struct netlink_callback *cb); +int ieee802154_llsec_add_devkey(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_del_devkey(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_dump_devkeys(struct sk_buff *skb, + struct netlink_callback *cb); +int ieee802154_llsec_add_seclevel(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_del_seclevel(struct sk_buff *skb, struct genl_info *info); +int ieee802154_llsec_dump_seclevels(struct sk_buff *skb, + struct netlink_callback *cb); + +#endif diff --git a/net/ieee802154/netlink.c 
b/net/ieee802154/netlink.c new file mode 100644 index 000000000..7d2de4ee6 --- /dev/null +++ b/net/ieee802154/netlink.c @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Netlink interface for IEEE 802.15.4 stack + * + * Copyright 2007, 2008 Siemens AG + * + * Written by: + * Sergey Lapin + * Dmitry Eremin-Solenikov + * Maxim Osipov + */ + +#include +#include +#include +#include + +#include "ieee802154.h" + +static unsigned int ieee802154_seq_num; +static DEFINE_SPINLOCK(ieee802154_seq_lock); + +/* Requests to userspace */ +struct sk_buff *ieee802154_nl_create(int flags, u8 req) +{ + void *hdr; + struct sk_buff *msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + unsigned long f; + + if (!msg) + return NULL; + + spin_lock_irqsave(&ieee802154_seq_lock, f); + hdr = genlmsg_put(msg, 0, ieee802154_seq_num++, + &nl802154_family, flags, req); + spin_unlock_irqrestore(&ieee802154_seq_lock, f); + if (!hdr) { + nlmsg_free(msg); + return NULL; + } + + return msg; +} + +int ieee802154_nl_mcast(struct sk_buff *msg, unsigned int group) +{ + struct nlmsghdr *nlh = nlmsg_hdr(msg); + void *hdr = genlmsg_data(nlmsg_data(nlh)); + + genlmsg_end(msg, hdr); + + return genlmsg_multicast(&nl802154_family, msg, 0, group, GFP_ATOMIC); +} + +struct sk_buff *ieee802154_nl_new_reply(struct genl_info *info, + int flags, u8 req) +{ + void *hdr; + struct sk_buff *msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + + if (!msg) + return NULL; + + hdr = genlmsg_put_reply(msg, info, + &nl802154_family, flags, req); + if (!hdr) { + nlmsg_free(msg); + return NULL; + } + + return msg; +} + +int ieee802154_nl_reply(struct sk_buff *msg, struct genl_info *info) +{ + struct nlmsghdr *nlh = nlmsg_hdr(msg); + void *hdr = genlmsg_data(nlmsg_data(nlh)); + + genlmsg_end(msg, hdr); + + return genlmsg_reply(msg, info); +} + +static const struct genl_small_ops ieee802154_ops[] = { + /* see nl-phy.c */ + IEEE802154_DUMP(IEEE802154_LIST_PHY, ieee802154_list_phy, + ieee802154_dump_phy), + IEEE802154_OP(IEEE802154_ADD_IFACE, ieee802154_add_iface), + IEEE802154_OP(IEEE802154_DEL_IFACE, ieee802154_del_iface), + /* see nl-mac.c */ + IEEE802154_OP(IEEE802154_ASSOCIATE_REQ, ieee802154_associate_req), + IEEE802154_OP(IEEE802154_ASSOCIATE_RESP, ieee802154_associate_resp), + IEEE802154_OP(IEEE802154_DISASSOCIATE_REQ, ieee802154_disassociate_req), + IEEE802154_OP(IEEE802154_SCAN_REQ, ieee802154_scan_req), + IEEE802154_OP(IEEE802154_START_REQ, ieee802154_start_req), + IEEE802154_DUMP(IEEE802154_LIST_IFACE, ieee802154_list_iface, + ieee802154_dump_iface), + IEEE802154_OP(IEEE802154_SET_MACPARAMS, ieee802154_set_macparams), + IEEE802154_OP(IEEE802154_LLSEC_GETPARAMS, ieee802154_llsec_getparams), + IEEE802154_OP(IEEE802154_LLSEC_SETPARAMS, ieee802154_llsec_setparams), + IEEE802154_DUMP(IEEE802154_LLSEC_LIST_KEY, NULL, + ieee802154_llsec_dump_keys), + IEEE802154_OP(IEEE802154_LLSEC_ADD_KEY, ieee802154_llsec_add_key), + IEEE802154_OP(IEEE802154_LLSEC_DEL_KEY, ieee802154_llsec_del_key), + IEEE802154_DUMP(IEEE802154_LLSEC_LIST_DEV, NULL, + ieee802154_llsec_dump_devs), + IEEE802154_OP(IEEE802154_LLSEC_ADD_DEV, ieee802154_llsec_add_dev), + IEEE802154_OP(IEEE802154_LLSEC_DEL_DEV, ieee802154_llsec_del_dev), + IEEE802154_DUMP(IEEE802154_LLSEC_LIST_DEVKEY, NULL, + ieee802154_llsec_dump_devkeys), + IEEE802154_OP(IEEE802154_LLSEC_ADD_DEVKEY, ieee802154_llsec_add_devkey), + IEEE802154_OP(IEEE802154_LLSEC_DEL_DEVKEY, ieee802154_llsec_del_devkey), + IEEE802154_DUMP(IEEE802154_LLSEC_LIST_SECLEVEL, NULL, + ieee802154_llsec_dump_seclevels), + 
IEEE802154_OP(IEEE802154_LLSEC_ADD_SECLEVEL, + ieee802154_llsec_add_seclevel), + IEEE802154_OP(IEEE802154_LLSEC_DEL_SECLEVEL, + ieee802154_llsec_del_seclevel), +}; + +static const struct genl_multicast_group ieee802154_mcgrps[] = { + [IEEE802154_COORD_MCGRP] = { .name = IEEE802154_MCAST_COORD_NAME, }, + [IEEE802154_BEACON_MCGRP] = { .name = IEEE802154_MCAST_BEACON_NAME, }, +}; + +struct genl_family nl802154_family __ro_after_init = { + .hdrsize = 0, + .name = IEEE802154_NL_NAME, + .version = 1, + .maxattr = IEEE802154_ATTR_MAX, + .policy = ieee802154_policy, + .module = THIS_MODULE, + .small_ops = ieee802154_ops, + .n_small_ops = ARRAY_SIZE(ieee802154_ops), + .resv_start_op = IEEE802154_LLSEC_DEL_SECLEVEL + 1, + .mcgrps = ieee802154_mcgrps, + .n_mcgrps = ARRAY_SIZE(ieee802154_mcgrps), +}; + +int __init ieee802154_nl_init(void) +{ + return genl_register_family(&nl802154_family); +} + +void ieee802154_nl_exit(void) +{ + genl_unregister_family(&nl802154_family); +} diff --git a/net/ieee802154/nl-mac.c b/net/ieee802154/nl-mac.c new file mode 100644 index 000000000..29bf97640 --- /dev/null +++ b/net/ieee802154/nl-mac.c @@ -0,0 +1,1344 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Netlink interface for IEEE 802.15.4 stack + * + * Copyright 2007, 2008 Siemens AG + * + * Written by: + * Sergey Lapin + * Dmitry Eremin-Solenikov + * Maxim Osipov + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee802154.h" + +static int nla_put_hwaddr(struct sk_buff *msg, int type, __le64 hwaddr, + int padattr) +{ + return nla_put_u64_64bit(msg, type, swab64((__force u64)hwaddr), + padattr); +} + +static __le64 nla_get_hwaddr(const struct nlattr *nla) +{ + return ieee802154_devaddr_from_raw(nla_data(nla)); +} + +static int nla_put_shortaddr(struct sk_buff *msg, int type, __le16 addr) +{ + return nla_put_u16(msg, type, le16_to_cpu(addr)); +} + +static __le16 nla_get_shortaddr(const struct nlattr *nla) +{ + return cpu_to_le16(nla_get_u16(nla)); +} + +static int ieee802154_nl_start_confirm(struct net_device *dev, u8 status) +{ + struct sk_buff *msg; + + pr_debug("%s\n", __func__); + + msg = ieee802154_nl_create(0, IEEE802154_START_CONF); + if (!msg) + return -ENOBUFS; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put(msg, IEEE802154_ATTR_HW_ADDR, IEEE802154_ADDR_LEN, + dev->dev_addr) || + nla_put_u8(msg, IEEE802154_ATTR_STATUS, status)) + goto nla_put_failure; + return ieee802154_nl_mcast(msg, IEEE802154_COORD_MCGRP); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +static int ieee802154_nl_fill_iface(struct sk_buff *msg, u32 portid, + u32 seq, int flags, struct net_device *dev) +{ + void *hdr; + struct wpan_phy *phy; + struct ieee802154_mlme_ops *ops; + __le16 short_addr, pan_id; + + pr_debug("%s\n", __func__); + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, flags, + IEEE802154_LIST_IFACE); + if (!hdr) + goto out; + + ops = ieee802154_mlme_ops(dev); + phy = dev->ieee802154_ptr->wpan_phy; + BUG_ON(!phy); + get_device(&phy->dev); + + rtnl_lock(); + short_addr = dev->ieee802154_ptr->short_addr; + pan_id = dev->ieee802154_ptr->pan_id; + rtnl_unlock(); + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_string(msg, IEEE802154_ATTR_PHY_NAME, wpan_phy_name(phy)) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put(msg, IEEE802154_ATTR_HW_ADDR, 
IEEE802154_ADDR_LEN, + dev->dev_addr) || + nla_put_shortaddr(msg, IEEE802154_ATTR_SHORT_ADDR, short_addr) || + nla_put_shortaddr(msg, IEEE802154_ATTR_PAN_ID, pan_id)) + goto nla_put_failure; + + if (ops->get_mac_params) { + struct ieee802154_mac_params params; + + rtnl_lock(); + ops->get_mac_params(dev, ¶ms); + rtnl_unlock(); + + if (nla_put_s8(msg, IEEE802154_ATTR_TXPOWER, + params.transmit_power / 100) || + nla_put_u8(msg, IEEE802154_ATTR_LBT_ENABLED, params.lbt) || + nla_put_u8(msg, IEEE802154_ATTR_CCA_MODE, + params.cca.mode) || + nla_put_s32(msg, IEEE802154_ATTR_CCA_ED_LEVEL, + params.cca_ed_level / 100) || + nla_put_u8(msg, IEEE802154_ATTR_CSMA_RETRIES, + params.csma_retries) || + nla_put_u8(msg, IEEE802154_ATTR_CSMA_MIN_BE, + params.min_be) || + nla_put_u8(msg, IEEE802154_ATTR_CSMA_MAX_BE, + params.max_be) || + nla_put_s8(msg, IEEE802154_ATTR_FRAME_RETRIES, + params.frame_retries)) + goto nla_put_failure; + } + + wpan_phy_put(phy); + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + wpan_phy_put(phy); + genlmsg_cancel(msg, hdr); +out: + return -EMSGSIZE; +} + +/* Requests from userspace */ +static struct net_device *ieee802154_nl_get_dev(struct genl_info *info) +{ + struct net_device *dev; + + if (info->attrs[IEEE802154_ATTR_DEV_NAME]) { + char name[IFNAMSIZ + 1]; + + nla_strscpy(name, info->attrs[IEEE802154_ATTR_DEV_NAME], + sizeof(name)); + dev = dev_get_by_name(&init_net, name); + } else if (info->attrs[IEEE802154_ATTR_DEV_INDEX]) { + dev = dev_get_by_index(&init_net, + nla_get_u32(info->attrs[IEEE802154_ATTR_DEV_INDEX])); + } else { + return NULL; + } + + if (!dev) + return NULL; + + if (dev->type != ARPHRD_IEEE802154) { + dev_put(dev); + return NULL; + } + + return dev; +} + +int ieee802154_associate_req(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev; + struct ieee802154_addr addr; + u8 page; + int ret = -EOPNOTSUPP; + + if (!info->attrs[IEEE802154_ATTR_CHANNEL] || + !info->attrs[IEEE802154_ATTR_COORD_PAN_ID] || + (!info->attrs[IEEE802154_ATTR_COORD_HW_ADDR] && + !info->attrs[IEEE802154_ATTR_COORD_SHORT_ADDR]) || + !info->attrs[IEEE802154_ATTR_CAPABILITY]) + return -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + if (!ieee802154_mlme_ops(dev)->assoc_req) + goto out; + + if (info->attrs[IEEE802154_ATTR_COORD_HW_ADDR]) { + addr.mode = IEEE802154_ADDR_LONG; + addr.extended_addr = nla_get_hwaddr( + info->attrs[IEEE802154_ATTR_COORD_HW_ADDR]); + } else { + addr.mode = IEEE802154_ADDR_SHORT; + addr.short_addr = nla_get_shortaddr( + info->attrs[IEEE802154_ATTR_COORD_SHORT_ADDR]); + } + addr.pan_id = nla_get_shortaddr( + info->attrs[IEEE802154_ATTR_COORD_PAN_ID]); + + if (info->attrs[IEEE802154_ATTR_PAGE]) + page = nla_get_u8(info->attrs[IEEE802154_ATTR_PAGE]); + else + page = 0; + + ret = ieee802154_mlme_ops(dev)->assoc_req(dev, &addr, + nla_get_u8(info->attrs[IEEE802154_ATTR_CHANNEL]), + page, + nla_get_u8(info->attrs[IEEE802154_ATTR_CAPABILITY])); + +out: + dev_put(dev); + return ret; +} + +int ieee802154_associate_resp(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev; + struct ieee802154_addr addr; + int ret = -EOPNOTSUPP; + + if (!info->attrs[IEEE802154_ATTR_STATUS] || + !info->attrs[IEEE802154_ATTR_DEST_HW_ADDR] || + !info->attrs[IEEE802154_ATTR_DEST_SHORT_ADDR]) + return -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + if (!ieee802154_mlme_ops(dev)->assoc_resp) + goto out; + + addr.mode = IEEE802154_ADDR_LONG; + addr.extended_addr = nla_get_hwaddr( + 
info->attrs[IEEE802154_ATTR_DEST_HW_ADDR]); + rtnl_lock(); + addr.pan_id = dev->ieee802154_ptr->pan_id; + rtnl_unlock(); + + ret = ieee802154_mlme_ops(dev)->assoc_resp(dev, &addr, + nla_get_shortaddr(info->attrs[IEEE802154_ATTR_DEST_SHORT_ADDR]), + nla_get_u8(info->attrs[IEEE802154_ATTR_STATUS])); + +out: + dev_put(dev); + return ret; +} + +int ieee802154_disassociate_req(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev; + struct ieee802154_addr addr; + int ret = -EOPNOTSUPP; + + if ((!info->attrs[IEEE802154_ATTR_DEST_HW_ADDR] && + !info->attrs[IEEE802154_ATTR_DEST_SHORT_ADDR]) || + !info->attrs[IEEE802154_ATTR_REASON]) + return -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + if (!ieee802154_mlme_ops(dev)->disassoc_req) + goto out; + + if (info->attrs[IEEE802154_ATTR_DEST_HW_ADDR]) { + addr.mode = IEEE802154_ADDR_LONG; + addr.extended_addr = nla_get_hwaddr( + info->attrs[IEEE802154_ATTR_DEST_HW_ADDR]); + } else { + addr.mode = IEEE802154_ADDR_SHORT; + addr.short_addr = nla_get_shortaddr( + info->attrs[IEEE802154_ATTR_DEST_SHORT_ADDR]); + } + rtnl_lock(); + addr.pan_id = dev->ieee802154_ptr->pan_id; + rtnl_unlock(); + + ret = ieee802154_mlme_ops(dev)->disassoc_req(dev, &addr, + nla_get_u8(info->attrs[IEEE802154_ATTR_REASON])); + +out: + dev_put(dev); + return ret; +} + +/* PANid, channel, beacon_order = 15, superframe_order = 15, + * PAN_coordinator, battery_life_extension = 0, + * coord_realignment = 0, security_enable = 0 +*/ +int ieee802154_start_req(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev; + struct ieee802154_addr addr; + + u8 channel, bcn_ord, sf_ord; + u8 page; + int pan_coord, blx, coord_realign; + int ret = -EBUSY; + + if (!info->attrs[IEEE802154_ATTR_COORD_PAN_ID] || + !info->attrs[IEEE802154_ATTR_COORD_SHORT_ADDR] || + !info->attrs[IEEE802154_ATTR_CHANNEL] || + !info->attrs[IEEE802154_ATTR_BCN_ORD] || + !info->attrs[IEEE802154_ATTR_SF_ORD] || + !info->attrs[IEEE802154_ATTR_PAN_COORD] || + !info->attrs[IEEE802154_ATTR_BAT_EXT] || + !info->attrs[IEEE802154_ATTR_COORD_REALIGN] + ) + return -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + if (netif_running(dev)) + goto out; + + if (!ieee802154_mlme_ops(dev)->start_req) { + ret = -EOPNOTSUPP; + goto out; + } + + addr.mode = IEEE802154_ADDR_SHORT; + addr.short_addr = nla_get_shortaddr( + info->attrs[IEEE802154_ATTR_COORD_SHORT_ADDR]); + addr.pan_id = nla_get_shortaddr( + info->attrs[IEEE802154_ATTR_COORD_PAN_ID]); + + channel = nla_get_u8(info->attrs[IEEE802154_ATTR_CHANNEL]); + bcn_ord = nla_get_u8(info->attrs[IEEE802154_ATTR_BCN_ORD]); + sf_ord = nla_get_u8(info->attrs[IEEE802154_ATTR_SF_ORD]); + pan_coord = nla_get_u8(info->attrs[IEEE802154_ATTR_PAN_COORD]); + blx = nla_get_u8(info->attrs[IEEE802154_ATTR_BAT_EXT]); + coord_realign = nla_get_u8(info->attrs[IEEE802154_ATTR_COORD_REALIGN]); + + if (info->attrs[IEEE802154_ATTR_PAGE]) + page = nla_get_u8(info->attrs[IEEE802154_ATTR_PAGE]); + else + page = 0; + + if (addr.short_addr == cpu_to_le16(IEEE802154_ADDR_BROADCAST)) { + ieee802154_nl_start_confirm(dev, IEEE802154_NO_SHORT_ADDRESS); + dev_put(dev); + return -EINVAL; + } + + rtnl_lock(); + ret = ieee802154_mlme_ops(dev)->start_req(dev, &addr, channel, page, + bcn_ord, sf_ord, pan_coord, blx, coord_realign); + rtnl_unlock(); + + /* FIXME: add validation for unused parameters to be sane + * for SoftMAC + */ + ieee802154_nl_start_confirm(dev, IEEE802154_SUCCESS); + +out: + dev_put(dev); + return ret; +} + +int 
ieee802154_scan_req(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev; + int ret = -EOPNOTSUPP; + u8 type; + u32 channels; + u8 duration; + u8 page; + + if (!info->attrs[IEEE802154_ATTR_SCAN_TYPE] || + !info->attrs[IEEE802154_ATTR_CHANNELS] || + !info->attrs[IEEE802154_ATTR_DURATION]) + return -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + if (!ieee802154_mlme_ops(dev)->scan_req) + goto out; + + type = nla_get_u8(info->attrs[IEEE802154_ATTR_SCAN_TYPE]); + channels = nla_get_u32(info->attrs[IEEE802154_ATTR_CHANNELS]); + duration = nla_get_u8(info->attrs[IEEE802154_ATTR_DURATION]); + + if (info->attrs[IEEE802154_ATTR_PAGE]) + page = nla_get_u8(info->attrs[IEEE802154_ATTR_PAGE]); + else + page = 0; + + ret = ieee802154_mlme_ops(dev)->scan_req(dev, type, channels, + page, duration); + +out: + dev_put(dev); + return ret; +} + +int ieee802154_list_iface(struct sk_buff *skb, struct genl_info *info) +{ + /* Request for interface name, index, type, IEEE address, + * PAN Id, short address + */ + struct sk_buff *msg; + struct net_device *dev = NULL; + int rc = -ENOBUFS; + + pr_debug("%s\n", __func__); + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + goto out_dev; + + rc = ieee802154_nl_fill_iface(msg, info->snd_portid, info->snd_seq, + 0, dev); + if (rc < 0) + goto out_free; + + dev_put(dev); + + return genlmsg_reply(msg, info); +out_free: + nlmsg_free(msg); +out_dev: + dev_put(dev); + return rc; +} + +int ieee802154_dump_iface(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct net_device *dev; + int idx; + int s_idx = cb->args[0]; + + pr_debug("%s\n", __func__); + + idx = 0; + for_each_netdev(net, dev) { + if (idx < s_idx || dev->type != ARPHRD_IEEE802154) + goto cont; + + if (ieee802154_nl_fill_iface(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, dev) < 0) + break; +cont: + idx++; + } + cb->args[0] = idx; + + return skb->len; +} + +int ieee802154_set_macparams(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev = NULL; + struct ieee802154_mlme_ops *ops; + struct ieee802154_mac_params params; + struct wpan_phy *phy; + int rc = -EINVAL; + + pr_debug("%s\n", __func__); + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + ops = ieee802154_mlme_ops(dev); + + if (!ops->get_mac_params || !ops->set_mac_params) { + rc = -EOPNOTSUPP; + goto out; + } + + if (netif_running(dev)) { + rc = -EBUSY; + goto out; + } + + if (!info->attrs[IEEE802154_ATTR_LBT_ENABLED] && + !info->attrs[IEEE802154_ATTR_CCA_MODE] && + !info->attrs[IEEE802154_ATTR_CCA_ED_LEVEL] && + !info->attrs[IEEE802154_ATTR_CSMA_RETRIES] && + !info->attrs[IEEE802154_ATTR_CSMA_MIN_BE] && + !info->attrs[IEEE802154_ATTR_CSMA_MAX_BE] && + !info->attrs[IEEE802154_ATTR_FRAME_RETRIES]) + goto out; + + phy = dev->ieee802154_ptr->wpan_phy; + get_device(&phy->dev); + + rtnl_lock(); + ops->get_mac_params(dev, ¶ms); + + if (info->attrs[IEEE802154_ATTR_TXPOWER]) + params.transmit_power = nla_get_s8(info->attrs[IEEE802154_ATTR_TXPOWER]) * 100; + + if (info->attrs[IEEE802154_ATTR_LBT_ENABLED]) + params.lbt = nla_get_u8(info->attrs[IEEE802154_ATTR_LBT_ENABLED]); + + if (info->attrs[IEEE802154_ATTR_CCA_MODE]) + params.cca.mode = nla_get_u8(info->attrs[IEEE802154_ATTR_CCA_MODE]); + + if (info->attrs[IEEE802154_ATTR_CCA_ED_LEVEL]) + params.cca_ed_level = nla_get_s32(info->attrs[IEEE802154_ATTR_CCA_ED_LEVEL]) * 100; + 
+ if (info->attrs[IEEE802154_ATTR_CSMA_RETRIES]) + params.csma_retries = nla_get_u8(info->attrs[IEEE802154_ATTR_CSMA_RETRIES]); + + if (info->attrs[IEEE802154_ATTR_CSMA_MIN_BE]) + params.min_be = nla_get_u8(info->attrs[IEEE802154_ATTR_CSMA_MIN_BE]); + + if (info->attrs[IEEE802154_ATTR_CSMA_MAX_BE]) + params.max_be = nla_get_u8(info->attrs[IEEE802154_ATTR_CSMA_MAX_BE]); + + if (info->attrs[IEEE802154_ATTR_FRAME_RETRIES]) + params.frame_retries = nla_get_s8(info->attrs[IEEE802154_ATTR_FRAME_RETRIES]); + + rc = ops->set_mac_params(dev, ¶ms); + rtnl_unlock(); + + wpan_phy_put(phy); + dev_put(dev); + + return 0; + +out: + dev_put(dev); + return rc; +} + +static int +ieee802154_llsec_parse_key_id(struct genl_info *info, + struct ieee802154_llsec_key_id *desc) +{ + memset(desc, 0, sizeof(*desc)); + + if (!info->attrs[IEEE802154_ATTR_LLSEC_KEY_MODE]) + return -EINVAL; + + desc->mode = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_KEY_MODE]); + + if (desc->mode == IEEE802154_SCF_KEY_IMPLICIT) { + if (!info->attrs[IEEE802154_ATTR_PAN_ID]) + return -EINVAL; + + desc->device_addr.pan_id = nla_get_shortaddr(info->attrs[IEEE802154_ATTR_PAN_ID]); + + if (info->attrs[IEEE802154_ATTR_SHORT_ADDR]) { + desc->device_addr.mode = IEEE802154_ADDR_SHORT; + desc->device_addr.short_addr = nla_get_shortaddr(info->attrs[IEEE802154_ATTR_SHORT_ADDR]); + } else { + if (!info->attrs[IEEE802154_ATTR_HW_ADDR]) + return -EINVAL; + + desc->device_addr.mode = IEEE802154_ADDR_LONG; + desc->device_addr.extended_addr = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_HW_ADDR]); + } + } + + if (desc->mode != IEEE802154_SCF_KEY_IMPLICIT && + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_ID]) + return -EINVAL; + + if (desc->mode == IEEE802154_SCF_KEY_SHORT_INDEX && + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_SOURCE_SHORT]) + return -EINVAL; + + if (desc->mode == IEEE802154_SCF_KEY_HW_INDEX && + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_SOURCE_EXTENDED]) + return -EINVAL; + + if (desc->mode != IEEE802154_SCF_KEY_IMPLICIT) + desc->id = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_KEY_ID]); + + switch (desc->mode) { + case IEEE802154_SCF_KEY_SHORT_INDEX: + { + u32 source = nla_get_u32(info->attrs[IEEE802154_ATTR_LLSEC_KEY_SOURCE_SHORT]); + + desc->short_source = cpu_to_le32(source); + break; + } + case IEEE802154_SCF_KEY_HW_INDEX: + desc->extended_source = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_LLSEC_KEY_SOURCE_EXTENDED]); + break; + } + + return 0; +} + +static int +ieee802154_llsec_fill_key_id(struct sk_buff *msg, + const struct ieee802154_llsec_key_id *desc) +{ + if (nla_put_u8(msg, IEEE802154_ATTR_LLSEC_KEY_MODE, desc->mode)) + return -EMSGSIZE; + + if (desc->mode == IEEE802154_SCF_KEY_IMPLICIT) { + if (nla_put_shortaddr(msg, IEEE802154_ATTR_PAN_ID, + desc->device_addr.pan_id)) + return -EMSGSIZE; + + if (desc->device_addr.mode == IEEE802154_ADDR_SHORT && + nla_put_shortaddr(msg, IEEE802154_ATTR_SHORT_ADDR, + desc->device_addr.short_addr)) + return -EMSGSIZE; + + if (desc->device_addr.mode == IEEE802154_ADDR_LONG && + nla_put_hwaddr(msg, IEEE802154_ATTR_HW_ADDR, + desc->device_addr.extended_addr, + IEEE802154_ATTR_PAD)) + return -EMSGSIZE; + } + + if (desc->mode != IEEE802154_SCF_KEY_IMPLICIT && + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_KEY_ID, desc->id)) + return -EMSGSIZE; + + if (desc->mode == IEEE802154_SCF_KEY_SHORT_INDEX && + nla_put_u32(msg, IEEE802154_ATTR_LLSEC_KEY_SOURCE_SHORT, + le32_to_cpu(desc->short_source))) + return -EMSGSIZE; + + if (desc->mode == IEEE802154_SCF_KEY_HW_INDEX && + nla_put_hwaddr(msg, 
IEEE802154_ATTR_LLSEC_KEY_SOURCE_EXTENDED, + desc->extended_source, IEEE802154_ATTR_PAD)) + return -EMSGSIZE; + + return 0; +} + +int ieee802154_llsec_getparams(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct net_device *dev = NULL; + int rc = -ENOBUFS; + struct ieee802154_mlme_ops *ops; + void *hdr; + struct ieee802154_llsec_params params; + + pr_debug("%s\n", __func__); + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + ops = ieee802154_mlme_ops(dev); + if (!ops->llsec) { + rc = -EOPNOTSUPP; + goto out_dev; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + goto out_dev; + + hdr = genlmsg_put(msg, 0, info->snd_seq, &nl802154_family, 0, + IEEE802154_LLSEC_GETPARAMS); + if (!hdr) + goto out_free; + + rc = ops->llsec->get_params(dev, ¶ms); + if (rc < 0) + goto out_free; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_ENABLED, params.enabled) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_SECLEVEL, params.out_level) || + nla_put_u32(msg, IEEE802154_ATTR_LLSEC_FRAME_COUNTER, + be32_to_cpu(params.frame_counter)) || + ieee802154_llsec_fill_key_id(msg, ¶ms.out_key)) { + rc = -ENOBUFS; + goto out_free; + } + + dev_put(dev); + + return ieee802154_nl_reply(msg, info); +out_free: + nlmsg_free(msg); +out_dev: + dev_put(dev); + return rc; +} + +int ieee802154_llsec_setparams(struct sk_buff *skb, struct genl_info *info) +{ + struct net_device *dev = NULL; + int rc = -EINVAL; + struct ieee802154_mlme_ops *ops; + struct ieee802154_llsec_params params; + int changed = 0; + + pr_debug("%s\n", __func__); + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + if (!info->attrs[IEEE802154_ATTR_LLSEC_ENABLED] && + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_MODE] && + !info->attrs[IEEE802154_ATTR_LLSEC_SECLEVEL]) + goto out; + + ops = ieee802154_mlme_ops(dev); + if (!ops->llsec) { + rc = -EOPNOTSUPP; + goto out; + } + + if (info->attrs[IEEE802154_ATTR_LLSEC_SECLEVEL] && + nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_SECLEVEL]) > 7) + goto out; + + if (info->attrs[IEEE802154_ATTR_LLSEC_ENABLED]) { + params.enabled = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_ENABLED]); + changed |= IEEE802154_LLSEC_PARAM_ENABLED; + } + + if (info->attrs[IEEE802154_ATTR_LLSEC_KEY_MODE]) { + if (ieee802154_llsec_parse_key_id(info, ¶ms.out_key)) + goto out; + + changed |= IEEE802154_LLSEC_PARAM_OUT_KEY; + } + + if (info->attrs[IEEE802154_ATTR_LLSEC_SECLEVEL]) { + params.out_level = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_SECLEVEL]); + changed |= IEEE802154_LLSEC_PARAM_OUT_LEVEL; + } + + if (info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER]) { + u32 fc = nla_get_u32(info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER]); + + params.frame_counter = cpu_to_be32(fc); + changed |= IEEE802154_LLSEC_PARAM_FRAME_COUNTER; + } + + rc = ops->llsec->set_params(dev, ¶ms, changed); + + dev_put(dev); + + return rc; +out: + dev_put(dev); + return rc; +} + +struct llsec_dump_data { + struct sk_buff *skb; + int s_idx, s_idx2; + int portid; + int nlmsg_seq; + struct net_device *dev; + struct ieee802154_mlme_ops *ops; + struct ieee802154_llsec_table *table; +}; + +static int +ieee802154_llsec_dump_table(struct sk_buff *skb, struct netlink_callback *cb, + int (*step)(struct llsec_dump_data *)) +{ + struct net *net = sock_net(skb->sk); + struct net_device *dev; + struct llsec_dump_data data; + int idx = 0; + int first_dev = cb->args[0]; + int rc; + + 
for_each_netdev(net, dev) { + if (idx < first_dev || dev->type != ARPHRD_IEEE802154) + goto skip; + + data.ops = ieee802154_mlme_ops(dev); + if (!data.ops->llsec) + goto skip; + + data.skb = skb; + data.s_idx = cb->args[1]; + data.s_idx2 = cb->args[2]; + data.dev = dev; + data.portid = NETLINK_CB(cb->skb).portid; + data.nlmsg_seq = cb->nlh->nlmsg_seq; + + data.ops->llsec->lock_table(dev); + data.ops->llsec->get_table(data.dev, &data.table); + rc = step(&data); + data.ops->llsec->unlock_table(dev); + + if (rc < 0) + break; + +skip: + idx++; + } + cb->args[0] = idx; + + return skb->len; +} + +static int +ieee802154_nl_llsec_change(struct sk_buff *skb, struct genl_info *info, + int (*fn)(struct net_device*, struct genl_info*)) +{ + struct net_device *dev = NULL; + int rc = -EINVAL; + + dev = ieee802154_nl_get_dev(info); + if (!dev) + return -ENODEV; + + if (!ieee802154_mlme_ops(dev)->llsec) + rc = -EOPNOTSUPP; + else + rc = fn(dev, info); + + dev_put(dev); + return rc; +} + +static int +ieee802154_llsec_parse_key(struct genl_info *info, + struct ieee802154_llsec_key *key) +{ + u8 frames; + u32 commands[256 / 32]; + + memset(key, 0, sizeof(*key)); + + if (!info->attrs[IEEE802154_ATTR_LLSEC_KEY_USAGE_FRAME_TYPES] || + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_BYTES]) + return -EINVAL; + + frames = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_KEY_USAGE_FRAME_TYPES]); + if ((frames & BIT(IEEE802154_FC_TYPE_MAC_CMD)) && + !info->attrs[IEEE802154_ATTR_LLSEC_KEY_USAGE_COMMANDS]) + return -EINVAL; + + if (info->attrs[IEEE802154_ATTR_LLSEC_KEY_USAGE_COMMANDS]) { + nla_memcpy(commands, + info->attrs[IEEE802154_ATTR_LLSEC_KEY_USAGE_COMMANDS], + 256 / 8); + + if (commands[0] || commands[1] || commands[2] || commands[3] || + commands[4] || commands[5] || commands[6] || + commands[7] >= BIT(IEEE802154_CMD_GTS_REQ + 1)) + return -EINVAL; + + key->cmd_frame_ids = commands[7]; + } + + key->frame_types = frames; + + nla_memcpy(key->key, info->attrs[IEEE802154_ATTR_LLSEC_KEY_BYTES], + IEEE802154_LLSEC_KEY_SIZE); + + return 0; +} + +static int llsec_add_key(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_key key; + struct ieee802154_llsec_key_id id; + + if (ieee802154_llsec_parse_key(info, &key) || + ieee802154_llsec_parse_key_id(info, &id)) + return -EINVAL; + + return ops->llsec->add_key(dev, &id, &key); +} + +int ieee802154_llsec_add_key(struct sk_buff *skb, struct genl_info *info) +{ + if ((info->nlhdr->nlmsg_flags & (NLM_F_CREATE | NLM_F_EXCL)) != + (NLM_F_CREATE | NLM_F_EXCL)) + return -EINVAL; + + return ieee802154_nl_llsec_change(skb, info, llsec_add_key); +} + +static int llsec_remove_key(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_key_id id; + + if (ieee802154_llsec_parse_key_id(info, &id)) + return -EINVAL; + + return ops->llsec->del_key(dev, &id); +} + +int ieee802154_llsec_del_key(struct sk_buff *skb, struct genl_info *info) +{ + return ieee802154_nl_llsec_change(skb, info, llsec_remove_key); +} + +static int +ieee802154_nl_fill_key(struct sk_buff *msg, u32 portid, u32 seq, + const struct ieee802154_llsec_key_entry *key, + const struct net_device *dev) +{ + void *hdr; + u32 commands[256 / 32]; + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, NLM_F_MULTI, + IEEE802154_LLSEC_LIST_KEY); + if (!hdr) + goto out; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, 
IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + ieee802154_llsec_fill_key_id(msg, &key->id) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_KEY_USAGE_FRAME_TYPES, + key->key->frame_types)) + goto nla_put_failure; + + if (key->key->frame_types & BIT(IEEE802154_FC_TYPE_MAC_CMD)) { + memset(commands, 0, sizeof(commands)); + commands[7] = key->key->cmd_frame_ids; + if (nla_put(msg, IEEE802154_ATTR_LLSEC_KEY_USAGE_COMMANDS, + sizeof(commands), commands)) + goto nla_put_failure; + } + + if (nla_put(msg, IEEE802154_ATTR_LLSEC_KEY_BYTES, + IEEE802154_LLSEC_KEY_SIZE, key->key->key)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); +out: + return -EMSGSIZE; +} + +static int llsec_iter_keys(struct llsec_dump_data *data) +{ + struct ieee802154_llsec_key_entry *pos; + int rc = 0, idx = 0; + + list_for_each_entry(pos, &data->table->keys, list) { + if (idx++ < data->s_idx) + continue; + + if (ieee802154_nl_fill_key(data->skb, data->portid, + data->nlmsg_seq, pos, data->dev)) { + rc = -EMSGSIZE; + break; + } + + data->s_idx++; + } + + return rc; +} + +int ieee802154_llsec_dump_keys(struct sk_buff *skb, struct netlink_callback *cb) +{ + return ieee802154_llsec_dump_table(skb, cb, llsec_iter_keys); +} + +static int +llsec_parse_dev(struct genl_info *info, + struct ieee802154_llsec_device *dev) +{ + memset(dev, 0, sizeof(*dev)); + + if (!info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER] || + !info->attrs[IEEE802154_ATTR_HW_ADDR] || + !info->attrs[IEEE802154_ATTR_LLSEC_DEV_OVERRIDE] || + !info->attrs[IEEE802154_ATTR_LLSEC_DEV_KEY_MODE] || + (!!info->attrs[IEEE802154_ATTR_PAN_ID] != + !!info->attrs[IEEE802154_ATTR_SHORT_ADDR])) + return -EINVAL; + + if (info->attrs[IEEE802154_ATTR_PAN_ID]) { + dev->pan_id = nla_get_shortaddr(info->attrs[IEEE802154_ATTR_PAN_ID]); + dev->short_addr = nla_get_shortaddr(info->attrs[IEEE802154_ATTR_SHORT_ADDR]); + } else { + dev->short_addr = cpu_to_le16(IEEE802154_ADDR_UNDEF); + } + + dev->hwaddr = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_HW_ADDR]); + dev->frame_counter = nla_get_u32(info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER]); + dev->seclevel_exempt = !!nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_DEV_OVERRIDE]); + dev->key_mode = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_DEV_KEY_MODE]); + + if (dev->key_mode >= __IEEE802154_LLSEC_DEVKEY_MAX) + return -EINVAL; + + return 0; +} + +static int llsec_add_dev(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_device desc; + + if (llsec_parse_dev(info, &desc)) + return -EINVAL; + + return ops->llsec->add_dev(dev, &desc); +} + +int ieee802154_llsec_add_dev(struct sk_buff *skb, struct genl_info *info) +{ + if ((info->nlhdr->nlmsg_flags & (NLM_F_CREATE | NLM_F_EXCL)) != + (NLM_F_CREATE | NLM_F_EXCL)) + return -EINVAL; + + return ieee802154_nl_llsec_change(skb, info, llsec_add_dev); +} + +static int llsec_del_dev(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + __le64 devaddr; + + if (!info->attrs[IEEE802154_ATTR_HW_ADDR]) + return -EINVAL; + + devaddr = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_HW_ADDR]); + + return ops->llsec->del_dev(dev, devaddr); +} + +int ieee802154_llsec_del_dev(struct sk_buff *skb, struct genl_info *info) +{ + return ieee802154_nl_llsec_change(skb, info, llsec_del_dev); +} + +static int +ieee802154_nl_fill_dev(struct sk_buff *msg, u32 portid, u32 seq, + const struct ieee802154_llsec_device *desc, + 
const struct net_device *dev) +{ + void *hdr; + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, NLM_F_MULTI, + IEEE802154_LLSEC_LIST_DEV); + if (!hdr) + goto out; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put_shortaddr(msg, IEEE802154_ATTR_PAN_ID, desc->pan_id) || + nla_put_shortaddr(msg, IEEE802154_ATTR_SHORT_ADDR, + desc->short_addr) || + nla_put_hwaddr(msg, IEEE802154_ATTR_HW_ADDR, desc->hwaddr, + IEEE802154_ATTR_PAD) || + nla_put_u32(msg, IEEE802154_ATTR_LLSEC_FRAME_COUNTER, + desc->frame_counter) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_DEV_OVERRIDE, + desc->seclevel_exempt) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_DEV_KEY_MODE, desc->key_mode)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); +out: + return -EMSGSIZE; +} + +static int llsec_iter_devs(struct llsec_dump_data *data) +{ + struct ieee802154_llsec_device *pos; + int rc = 0, idx = 0; + + list_for_each_entry(pos, &data->table->devices, list) { + if (idx++ < data->s_idx) + continue; + + if (ieee802154_nl_fill_dev(data->skb, data->portid, + data->nlmsg_seq, pos, data->dev)) { + rc = -EMSGSIZE; + break; + } + + data->s_idx++; + } + + return rc; +} + +int ieee802154_llsec_dump_devs(struct sk_buff *skb, struct netlink_callback *cb) +{ + return ieee802154_llsec_dump_table(skb, cb, llsec_iter_devs); +} + +static int llsec_add_devkey(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_device_key key; + __le64 devaddr; + + if (!info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER] || + !info->attrs[IEEE802154_ATTR_HW_ADDR] || + ieee802154_llsec_parse_key_id(info, &key.key_id)) + return -EINVAL; + + devaddr = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_HW_ADDR]); + key.frame_counter = nla_get_u32(info->attrs[IEEE802154_ATTR_LLSEC_FRAME_COUNTER]); + + return ops->llsec->add_devkey(dev, devaddr, &key); +} + +int ieee802154_llsec_add_devkey(struct sk_buff *skb, struct genl_info *info) +{ + if ((info->nlhdr->nlmsg_flags & (NLM_F_CREATE | NLM_F_EXCL)) != + (NLM_F_CREATE | NLM_F_EXCL)) + return -EINVAL; + + return ieee802154_nl_llsec_change(skb, info, llsec_add_devkey); +} + +static int llsec_del_devkey(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_device_key key; + __le64 devaddr; + + if (!info->attrs[IEEE802154_ATTR_HW_ADDR] || + ieee802154_llsec_parse_key_id(info, &key.key_id)) + return -EINVAL; + + devaddr = nla_get_hwaddr(info->attrs[IEEE802154_ATTR_HW_ADDR]); + + return ops->llsec->del_devkey(dev, devaddr, &key); +} + +int ieee802154_llsec_del_devkey(struct sk_buff *skb, struct genl_info *info) +{ + return ieee802154_nl_llsec_change(skb, info, llsec_del_devkey); +} + +static int +ieee802154_nl_fill_devkey(struct sk_buff *msg, u32 portid, u32 seq, + __le64 devaddr, + const struct ieee802154_llsec_device_key *devkey, + const struct net_device *dev) +{ + void *hdr; + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, NLM_F_MULTI, + IEEE802154_LLSEC_LIST_DEVKEY); + if (!hdr) + goto out; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put_hwaddr(msg, IEEE802154_ATTR_HW_ADDR, devaddr, + IEEE802154_ATTR_PAD) || + nla_put_u32(msg, IEEE802154_ATTR_LLSEC_FRAME_COUNTER, + devkey->frame_counter) || + 
ieee802154_llsec_fill_key_id(msg, &devkey->key_id)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); +out: + return -EMSGSIZE; +} + +static int llsec_iter_devkeys(struct llsec_dump_data *data) +{ + struct ieee802154_llsec_device *dpos; + struct ieee802154_llsec_device_key *kpos; + int idx = 0, idx2; + + list_for_each_entry(dpos, &data->table->devices, list) { + if (idx++ < data->s_idx) + continue; + + idx2 = 0; + + list_for_each_entry(kpos, &dpos->keys, list) { + if (idx2++ < data->s_idx2) + continue; + + if (ieee802154_nl_fill_devkey(data->skb, data->portid, + data->nlmsg_seq, + dpos->hwaddr, kpos, + data->dev)) { + return -EMSGSIZE; + } + + data->s_idx2++; + } + + data->s_idx++; + } + + return 0; +} + +int ieee802154_llsec_dump_devkeys(struct sk_buff *skb, + struct netlink_callback *cb) +{ + return ieee802154_llsec_dump_table(skb, cb, llsec_iter_devkeys); +} + +static int +llsec_parse_seclevel(struct genl_info *info, + struct ieee802154_llsec_seclevel *sl) +{ + memset(sl, 0, sizeof(*sl)); + + if (!info->attrs[IEEE802154_ATTR_LLSEC_FRAME_TYPE] || + !info->attrs[IEEE802154_ATTR_LLSEC_SECLEVELS] || + !info->attrs[IEEE802154_ATTR_LLSEC_DEV_OVERRIDE]) + return -EINVAL; + + sl->frame_type = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_FRAME_TYPE]); + if (sl->frame_type == IEEE802154_FC_TYPE_MAC_CMD) { + if (!info->attrs[IEEE802154_ATTR_LLSEC_CMD_FRAME_ID]) + return -EINVAL; + + sl->cmd_frame_id = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_CMD_FRAME_ID]); + } + + sl->sec_levels = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_SECLEVELS]); + sl->device_override = nla_get_u8(info->attrs[IEEE802154_ATTR_LLSEC_DEV_OVERRIDE]); + + return 0; +} + +static int llsec_add_seclevel(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_seclevel sl; + + if (llsec_parse_seclevel(info, &sl)) + return -EINVAL; + + return ops->llsec->add_seclevel(dev, &sl); +} + +int ieee802154_llsec_add_seclevel(struct sk_buff *skb, struct genl_info *info) +{ + if ((info->nlhdr->nlmsg_flags & (NLM_F_CREATE | NLM_F_EXCL)) != + (NLM_F_CREATE | NLM_F_EXCL)) + return -EINVAL; + + return ieee802154_nl_llsec_change(skb, info, llsec_add_seclevel); +} + +static int llsec_del_seclevel(struct net_device *dev, struct genl_info *info) +{ + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct ieee802154_llsec_seclevel sl; + + if (llsec_parse_seclevel(info, &sl)) + return -EINVAL; + + return ops->llsec->del_seclevel(dev, &sl); +} + +int ieee802154_llsec_del_seclevel(struct sk_buff *skb, struct genl_info *info) +{ + return ieee802154_nl_llsec_change(skb, info, llsec_del_seclevel); +} + +static int +ieee802154_nl_fill_seclevel(struct sk_buff *msg, u32 portid, u32 seq, + const struct ieee802154_llsec_seclevel *sl, + const struct net_device *dev) +{ + void *hdr; + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, NLM_F_MULTI, + IEEE802154_LLSEC_LIST_SECLEVEL); + if (!hdr) + goto out; + + if (nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name) || + nla_put_u32(msg, IEEE802154_ATTR_DEV_INDEX, dev->ifindex) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_FRAME_TYPE, sl->frame_type) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_SECLEVELS, sl->sec_levels) || + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_DEV_OVERRIDE, + sl->device_override)) + goto nla_put_failure; + + if (sl->frame_type == IEEE802154_FC_TYPE_MAC_CMD && + nla_put_u8(msg, IEEE802154_ATTR_LLSEC_CMD_FRAME_ID, + sl->cmd_frame_id)) + 
goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); +out: + return -EMSGSIZE; +} + +static int llsec_iter_seclevels(struct llsec_dump_data *data) +{ + struct ieee802154_llsec_seclevel *pos; + int rc = 0, idx = 0; + + list_for_each_entry(pos, &data->table->security_levels, list) { + if (idx++ < data->s_idx) + continue; + + if (ieee802154_nl_fill_seclevel(data->skb, data->portid, + data->nlmsg_seq, pos, + data->dev)) { + rc = -EMSGSIZE; + break; + } + + data->s_idx++; + } + + return rc; +} + +int ieee802154_llsec_dump_seclevels(struct sk_buff *skb, + struct netlink_callback *cb) +{ + return ieee802154_llsec_dump_table(skb, cb, llsec_iter_seclevels); +} diff --git a/net/ieee802154/nl-phy.c b/net/ieee802154/nl-phy.c new file mode 100644 index 000000000..359249ab7 --- /dev/null +++ b/net/ieee802154/nl-phy.c @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Netlink interface for IEEE 802.15.4 stack + * + * Copyright 2007, 2008 Siemens AG + * + * Written by: + * Sergey Lapin + * Dmitry Eremin-Solenikov + * Maxim Osipov + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for rtnl_{un,}lock */ +#include + +#include "ieee802154.h" +#include "rdev-ops.h" +#include "core.h" + +static int ieee802154_nl_fill_phy(struct sk_buff *msg, u32 portid, + u32 seq, int flags, struct wpan_phy *phy) +{ + void *hdr; + int i, pages = 0; + u32 *buf = kcalloc(IEEE802154_MAX_PAGE + 1, sizeof(u32), GFP_KERNEL); + + pr_debug("%s\n", __func__); + + if (!buf) + return -EMSGSIZE; + + hdr = genlmsg_put(msg, 0, seq, &nl802154_family, flags, + IEEE802154_LIST_PHY); + if (!hdr) + goto out; + + rtnl_lock(); + if (nla_put_string(msg, IEEE802154_ATTR_PHY_NAME, wpan_phy_name(phy)) || + nla_put_u8(msg, IEEE802154_ATTR_PAGE, phy->current_page) || + nla_put_u8(msg, IEEE802154_ATTR_CHANNEL, phy->current_channel)) + goto nla_put_failure; + for (i = 0; i <= IEEE802154_MAX_PAGE; i++) { + if (phy->supported.channels[i]) + buf[pages++] = phy->supported.channels[i] | (i << 27); + } + if (pages && + nla_put(msg, IEEE802154_ATTR_CHANNEL_PAGE_LIST, + pages * sizeof(uint32_t), buf)) + goto nla_put_failure; + rtnl_unlock(); + kfree(buf); + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + rtnl_unlock(); + genlmsg_cancel(msg, hdr); +out: + kfree(buf); + return -EMSGSIZE; +} + +int ieee802154_list_phy(struct sk_buff *skb, struct genl_info *info) +{ + /* Request for interface name, index, type, IEEE address, + * PAN Id, short address + */ + struct sk_buff *msg; + struct wpan_phy *phy; + const char *name; + int rc = -ENOBUFS; + + pr_debug("%s\n", __func__); + + if (!info->attrs[IEEE802154_ATTR_PHY_NAME]) + return -EINVAL; + + name = nla_data(info->attrs[IEEE802154_ATTR_PHY_NAME]); + if (name[nla_len(info->attrs[IEEE802154_ATTR_PHY_NAME]) - 1] != '\0') + return -EINVAL; /* phy name should be null-terminated */ + + phy = wpan_phy_find(name); + if (!phy) + return -ENODEV; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + goto out_dev; + + rc = ieee802154_nl_fill_phy(msg, info->snd_portid, info->snd_seq, + 0, phy); + if (rc < 0) + goto out_free; + + wpan_phy_put(phy); + + return genlmsg_reply(msg, info); +out_free: + nlmsg_free(msg); +out_dev: + wpan_phy_put(phy); + return rc; +} + +struct dump_phy_data { + struct sk_buff *skb; + struct netlink_callback *cb; + int idx, s_idx; +}; + +static int ieee802154_dump_phy_iter(struct wpan_phy *phy, void *_data) +{ + int rc; + struct dump_phy_data *data = _data; + + 
pr_debug("%s\n", __func__); + + if (data->idx++ < data->s_idx) + return 0; + + rc = ieee802154_nl_fill_phy(data->skb, + NETLINK_CB(data->cb->skb).portid, + data->cb->nlh->nlmsg_seq, + NLM_F_MULTI, + phy); + + if (rc < 0) { + data->idx--; + return rc; + } + + return 0; +} + +int ieee802154_dump_phy(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct dump_phy_data data = { + .cb = cb, + .skb = skb, + .s_idx = cb->args[0], + .idx = 0, + }; + + pr_debug("%s\n", __func__); + + wpan_phy_for_each(ieee802154_dump_phy_iter, &data); + + cb->args[0] = data.idx; + + return skb->len; +} + +int ieee802154_add_iface(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct wpan_phy *phy; + const char *name; + const char *devname; + int rc = -ENOBUFS; + struct net_device *dev; + int type = __IEEE802154_DEV_INVALID; + unsigned char name_assign_type; + + pr_debug("%s\n", __func__); + + if (!info->attrs[IEEE802154_ATTR_PHY_NAME]) + return -EINVAL; + + name = nla_data(info->attrs[IEEE802154_ATTR_PHY_NAME]); + if (name[nla_len(info->attrs[IEEE802154_ATTR_PHY_NAME]) - 1] != '\0') + return -EINVAL; /* phy name should be null-terminated */ + + if (info->attrs[IEEE802154_ATTR_DEV_NAME]) { + devname = nla_data(info->attrs[IEEE802154_ATTR_DEV_NAME]); + if (devname[nla_len(info->attrs[IEEE802154_ATTR_DEV_NAME]) - 1] + != '\0') + return -EINVAL; /* phy name should be null-terminated */ + name_assign_type = NET_NAME_USER; + } else { + devname = "wpan%d"; + name_assign_type = NET_NAME_ENUM; + } + + if (strlen(devname) >= IFNAMSIZ) + return -ENAMETOOLONG; + + phy = wpan_phy_find(name); + if (!phy) + return -ENODEV; + + msg = ieee802154_nl_new_reply(info, 0, IEEE802154_ADD_IFACE); + if (!msg) + goto out_dev; + + if (info->attrs[IEEE802154_ATTR_HW_ADDR] && + nla_len(info->attrs[IEEE802154_ATTR_HW_ADDR]) != + IEEE802154_ADDR_LEN) { + rc = -EINVAL; + goto nla_put_failure; + } + + if (info->attrs[IEEE802154_ATTR_DEV_TYPE]) { + type = nla_get_u8(info->attrs[IEEE802154_ATTR_DEV_TYPE]); + if (type >= __IEEE802154_DEV_MAX) { + rc = -EINVAL; + goto nla_put_failure; + } + } + + dev = rdev_add_virtual_intf_deprecated(wpan_phy_to_rdev(phy), devname, + name_assign_type, type); + if (IS_ERR(dev)) { + rc = PTR_ERR(dev); + goto nla_put_failure; + } + dev_hold(dev); + + if (info->attrs[IEEE802154_ATTR_HW_ADDR]) { + struct sockaddr addr; + + addr.sa_family = ARPHRD_IEEE802154; + nla_memcpy(&addr.sa_data, info->attrs[IEEE802154_ATTR_HW_ADDR], + IEEE802154_ADDR_LEN); + + /* strangely enough, some callbacks (inetdev_event) from + * dev_set_mac_address require RTNL_LOCK + */ + rtnl_lock(); + rc = dev_set_mac_address(dev, &addr, NULL); + rtnl_unlock(); + if (rc) + goto dev_unregister; + } + + if (nla_put_string(msg, IEEE802154_ATTR_PHY_NAME, wpan_phy_name(phy)) || + nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, dev->name)) { + rc = -EMSGSIZE; + goto nla_put_failure; + } + dev_put(dev); + + wpan_phy_put(phy); + + return ieee802154_nl_reply(msg, info); + +dev_unregister: + rtnl_lock(); /* del_iface must be called with RTNL lock */ + rdev_del_virtual_intf_deprecated(wpan_phy_to_rdev(phy), dev); + dev_put(dev); + rtnl_unlock(); +nla_put_failure: + nlmsg_free(msg); +out_dev: + wpan_phy_put(phy); + return rc; +} + +int ieee802154_del_iface(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct wpan_phy *phy; + const char *name; + int rc; + struct net_device *dev; + + pr_debug("%s\n", __func__); + + if (!info->attrs[IEEE802154_ATTR_DEV_NAME]) + return -EINVAL; + + name = 
nla_data(info->attrs[IEEE802154_ATTR_DEV_NAME]); + if (name[nla_len(info->attrs[IEEE802154_ATTR_DEV_NAME]) - 1] != '\0') + return -EINVAL; /* name should be null-terminated */ + + rc = -ENODEV; + dev = dev_get_by_name(genl_info_net(info), name); + if (!dev) + return rc; + if (dev->type != ARPHRD_IEEE802154) + goto out; + + phy = dev->ieee802154_ptr->wpan_phy; + BUG_ON(!phy); + get_device(&phy->dev); + + rc = -EINVAL; + /* phy name is optional, but should be checked if it's given */ + if (info->attrs[IEEE802154_ATTR_PHY_NAME]) { + struct wpan_phy *phy2; + + const char *pname = + nla_data(info->attrs[IEEE802154_ATTR_PHY_NAME]); + if (pname[nla_len(info->attrs[IEEE802154_ATTR_PHY_NAME]) - 1] + != '\0') + /* name should be null-terminated */ + goto out_dev; + + phy2 = wpan_phy_find(pname); + if (!phy2) + goto out_dev; + + if (phy != phy2) { + wpan_phy_put(phy2); + goto out_dev; + } + } + + rc = -ENOBUFS; + + msg = ieee802154_nl_new_reply(info, 0, IEEE802154_DEL_IFACE); + if (!msg) + goto out_dev; + + rtnl_lock(); + rdev_del_virtual_intf_deprecated(wpan_phy_to_rdev(phy), dev); + + /* We don't have device anymore */ + dev_put(dev); + dev = NULL; + + rtnl_unlock(); + + if (nla_put_string(msg, IEEE802154_ATTR_PHY_NAME, wpan_phy_name(phy)) || + nla_put_string(msg, IEEE802154_ATTR_DEV_NAME, name)) + goto nla_put_failure; + wpan_phy_put(phy); + + return ieee802154_nl_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); +out_dev: + wpan_phy_put(phy); +out: + dev_put(dev); + + return rc; +} diff --git a/net/ieee802154/nl802154.c b/net/ieee802154/nl802154.c new file mode 100644 index 000000000..38c4f3cb0 --- /dev/null +++ b/net/ieee802154/nl802154.c @@ -0,0 +1,2517 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * Alexander Aring + * + * Based on: net/wireless/nl80211.c + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include "nl802154.h" +#include "rdev-ops.h" +#include "core.h" + +/* the netlink family */ +static struct genl_family nl802154_fam; + +/* multicast groups */ +enum nl802154_multicast_groups { + NL802154_MCGRP_CONFIG, +}; + +static const struct genl_multicast_group nl802154_mcgrps[] = { + [NL802154_MCGRP_CONFIG] = { .name = "config", }, +}; + +/* returns ERR_PTR values */ +static struct wpan_dev * +__cfg802154_wpan_dev_from_attrs(struct net *netns, struct nlattr **attrs) +{ + struct cfg802154_registered_device *rdev; + struct wpan_dev *result = NULL; + bool have_ifidx = attrs[NL802154_ATTR_IFINDEX]; + bool have_wpan_dev_id = attrs[NL802154_ATTR_WPAN_DEV]; + u64 wpan_dev_id; + int wpan_phy_idx = -1; + int ifidx = -1; + + ASSERT_RTNL(); + + if (!have_ifidx && !have_wpan_dev_id) + return ERR_PTR(-EINVAL); + + if (have_ifidx) + ifidx = nla_get_u32(attrs[NL802154_ATTR_IFINDEX]); + if (have_wpan_dev_id) { + wpan_dev_id = nla_get_u64(attrs[NL802154_ATTR_WPAN_DEV]); + wpan_phy_idx = wpan_dev_id >> 32; + } + + list_for_each_entry(rdev, &cfg802154_rdev_list, list) { + struct wpan_dev *wpan_dev; + + if (wpan_phy_net(&rdev->wpan_phy) != netns) + continue; + + if (have_wpan_dev_id && rdev->wpan_phy_idx != wpan_phy_idx) + continue; + + list_for_each_entry(wpan_dev, &rdev->wpan_dev_list, list) { + if (have_ifidx && wpan_dev->netdev && + wpan_dev->netdev->ifindex == ifidx) { + result = wpan_dev; + break; + } + if (have_wpan_dev_id && + wpan_dev->identifier == (u32)wpan_dev_id) { + result = wpan_dev; + break; + } + } + + if (result) + break; + } + + if (result) + return result; + + return ERR_PTR(-ENODEV); +} + +static struct 
cfg802154_registered_device * +__cfg802154_rdev_from_attrs(struct net *netns, struct nlattr **attrs) +{ + struct cfg802154_registered_device *rdev = NULL, *tmp; + struct net_device *netdev; + + ASSERT_RTNL(); + + if (!attrs[NL802154_ATTR_WPAN_PHY] && + !attrs[NL802154_ATTR_IFINDEX] && + !attrs[NL802154_ATTR_WPAN_DEV]) + return ERR_PTR(-EINVAL); + + if (attrs[NL802154_ATTR_WPAN_PHY]) + rdev = cfg802154_rdev_by_wpan_phy_idx( + nla_get_u32(attrs[NL802154_ATTR_WPAN_PHY])); + + if (attrs[NL802154_ATTR_WPAN_DEV]) { + u64 wpan_dev_id = nla_get_u64(attrs[NL802154_ATTR_WPAN_DEV]); + struct wpan_dev *wpan_dev; + bool found = false; + + tmp = cfg802154_rdev_by_wpan_phy_idx(wpan_dev_id >> 32); + if (tmp) { + /* make sure wpan_dev exists */ + list_for_each_entry(wpan_dev, &tmp->wpan_dev_list, list) { + if (wpan_dev->identifier != (u32)wpan_dev_id) + continue; + found = true; + break; + } + + if (!found) + tmp = NULL; + + if (rdev && tmp != rdev) + return ERR_PTR(-EINVAL); + rdev = tmp; + } + } + + if (attrs[NL802154_ATTR_IFINDEX]) { + int ifindex = nla_get_u32(attrs[NL802154_ATTR_IFINDEX]); + + netdev = __dev_get_by_index(netns, ifindex); + if (netdev) { + if (netdev->ieee802154_ptr) + tmp = wpan_phy_to_rdev( + netdev->ieee802154_ptr->wpan_phy); + else + tmp = NULL; + + /* not wireless device -- return error */ + if (!tmp) + return ERR_PTR(-EINVAL); + + /* mismatch -- return error */ + if (rdev && tmp != rdev) + return ERR_PTR(-EINVAL); + + rdev = tmp; + } + } + + if (!rdev) + return ERR_PTR(-ENODEV); + + if (netns != wpan_phy_net(&rdev->wpan_phy)) + return ERR_PTR(-ENODEV); + + return rdev; +} + +/* This function returns a pointer to the driver + * that the genl_info item that is passed refers to. + * + * The result of this can be a PTR_ERR and hence must + * be checked with IS_ERR() for errors. 
+ */ +static struct cfg802154_registered_device * +cfg802154_get_dev_from_info(struct net *netns, struct genl_info *info) +{ + return __cfg802154_rdev_from_attrs(netns, info->attrs); +} + +/* policy for the attributes */ +static const struct nla_policy nl802154_policy[NL802154_ATTR_MAX+1] = { + [NL802154_ATTR_WPAN_PHY] = { .type = NLA_U32 }, + [NL802154_ATTR_WPAN_PHY_NAME] = { .type = NLA_NUL_STRING, + .len = 20-1 }, + + [NL802154_ATTR_IFINDEX] = { .type = NLA_U32 }, + [NL802154_ATTR_IFTYPE] = { .type = NLA_U32 }, + [NL802154_ATTR_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ-1 }, + + [NL802154_ATTR_WPAN_DEV] = { .type = NLA_U64 }, + + [NL802154_ATTR_PAGE] = { .type = NLA_U8, }, + [NL802154_ATTR_CHANNEL] = { .type = NLA_U8, }, + + [NL802154_ATTR_TX_POWER] = { .type = NLA_S32, }, + + [NL802154_ATTR_CCA_MODE] = { .type = NLA_U32, }, + [NL802154_ATTR_CCA_OPT] = { .type = NLA_U32, }, + [NL802154_ATTR_CCA_ED_LEVEL] = { .type = NLA_S32, }, + + [NL802154_ATTR_SUPPORTED_CHANNEL] = { .type = NLA_U32, }, + + [NL802154_ATTR_PAN_ID] = { .type = NLA_U16, }, + [NL802154_ATTR_EXTENDED_ADDR] = { .type = NLA_U64 }, + [NL802154_ATTR_SHORT_ADDR] = { .type = NLA_U16, }, + + [NL802154_ATTR_MIN_BE] = { .type = NLA_U8, }, + [NL802154_ATTR_MAX_BE] = { .type = NLA_U8, }, + [NL802154_ATTR_MAX_CSMA_BACKOFFS] = { .type = NLA_U8, }, + + [NL802154_ATTR_MAX_FRAME_RETRIES] = { .type = NLA_S8, }, + + [NL802154_ATTR_LBT_MODE] = { .type = NLA_U8, }, + + [NL802154_ATTR_WPAN_PHY_CAPS] = { .type = NLA_NESTED }, + + [NL802154_ATTR_SUPPORTED_COMMANDS] = { .type = NLA_NESTED }, + + [NL802154_ATTR_ACKREQ_DEFAULT] = { .type = NLA_U8 }, + + [NL802154_ATTR_PID] = { .type = NLA_U32 }, + [NL802154_ATTR_NETNS_FD] = { .type = NLA_U32 }, +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL + [NL802154_ATTR_SEC_ENABLED] = { .type = NLA_U8, }, + [NL802154_ATTR_SEC_OUT_LEVEL] = { .type = NLA_U32, }, + [NL802154_ATTR_SEC_OUT_KEY_ID] = { .type = NLA_NESTED, }, + [NL802154_ATTR_SEC_FRAME_COUNTER] = { .type = NLA_U32 }, + + [NL802154_ATTR_SEC_LEVEL] = { .type = NLA_NESTED }, + [NL802154_ATTR_SEC_DEVICE] = { .type = NLA_NESTED }, + [NL802154_ATTR_SEC_DEVKEY] = { .type = NLA_NESTED }, + [NL802154_ATTR_SEC_KEY] = { .type = NLA_NESTED }, +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ +}; + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL +static int +nl802154_prepare_wpan_dev_dump(struct sk_buff *skb, + struct netlink_callback *cb, + struct cfg802154_registered_device **rdev, + struct wpan_dev **wpan_dev) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + int err; + + rtnl_lock(); + + if (!cb->args[0]) { + *wpan_dev = __cfg802154_wpan_dev_from_attrs(sock_net(skb->sk), + info->attrs); + if (IS_ERR(*wpan_dev)) { + err = PTR_ERR(*wpan_dev); + goto out_unlock; + } + *rdev = wpan_phy_to_rdev((*wpan_dev)->wpan_phy); + /* 0 is the first index - add 1 to parse only once */ + cb->args[0] = (*rdev)->wpan_phy_idx + 1; + cb->args[1] = (*wpan_dev)->identifier; + } else { + /* subtract the 1 again here */ + struct wpan_phy *wpan_phy = wpan_phy_idx_to_wpan_phy(cb->args[0] - 1); + struct wpan_dev *tmp; + + if (!wpan_phy) { + err = -ENODEV; + goto out_unlock; + } + *rdev = wpan_phy_to_rdev(wpan_phy); + *wpan_dev = NULL; + + list_for_each_entry(tmp, &(*rdev)->wpan_dev_list, list) { + if (tmp->identifier == cb->args[1]) { + *wpan_dev = tmp; + break; + } + } + + if (!*wpan_dev) { + err = -ENODEV; + goto out_unlock; + } + } + + return 0; + out_unlock: + rtnl_unlock(); + return err; +} + +static void +nl802154_finish_wpan_dev_dump(struct 
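nl802154_prepare_wpan_dev_dump() above stores the phy index plus one in cb->args[0] so that zero can keep meaning "first call, attributes not parsed yet". A stand-alone sketch of that offset-by-one cursor for a resumable dump; the data and function names are illustrative, not the netlink callback machinery itself:

    #include <stdio.h>

    /* Resumable iteration in the style of a netlink dump: the cursor is 0 on
     * the first call and index + 1 afterwards, so a valid index 0 is never
     * confused with "not started".
     */
    static int dump_some(const int *items, int n_items, long *cursor,
                         int *out, int max_out)
    {
        int start = *cursor ? (int)(*cursor - 1) : 0;
        int emitted = 0;

        while (start < n_items && emitted < max_out)
            out[emitted++] = items[start++];

        *cursor = start + 1;    /* remember where to resume, offset by one */
        return emitted;
    }

    int main(void)
    {
        int items[] = { 10, 11, 12, 13, 14 };
        int batch[2], n;
        long cursor = 0;    /* 0 == first call */

        while ((n = dump_some(items, 5, &cursor, batch, 2)) > 0)
            printf("batch of %d starting with %d\n", n, batch[0]);
        return 0;
    }

The same trick is why the real code "subtracts the 1 again" before turning cb->args[0] back into a phy index.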
cfg802154_registered_device *rdev) +{ + rtnl_unlock(); +} +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + +/* message building helper */ +static inline void *nl802154hdr_put(struct sk_buff *skb, u32 portid, u32 seq, + int flags, u8 cmd) +{ + /* since there is no private header just add the generic one */ + return genlmsg_put(skb, portid, seq, &nl802154_fam, flags, cmd); +} + +static int +nl802154_put_flags(struct sk_buff *msg, int attr, u32 mask) +{ + struct nlattr *nl_flags = nla_nest_start_noflag(msg, attr); + int i; + + if (!nl_flags) + return -ENOBUFS; + + i = 0; + while (mask) { + if ((mask & 1) && nla_put_flag(msg, i)) + return -ENOBUFS; + + mask >>= 1; + i++; + } + + nla_nest_end(msg, nl_flags); + return 0; +} + +static int +nl802154_send_wpan_phy_channels(struct cfg802154_registered_device *rdev, + struct sk_buff *msg) +{ + struct nlattr *nl_page; + unsigned long page; + + nl_page = nla_nest_start_noflag(msg, NL802154_ATTR_CHANNELS_SUPPORTED); + if (!nl_page) + return -ENOBUFS; + + for (page = 0; page <= IEEE802154_MAX_PAGE; page++) { + if (nla_put_u32(msg, NL802154_ATTR_SUPPORTED_CHANNEL, + rdev->wpan_phy.supported.channels[page])) + return -ENOBUFS; + } + nla_nest_end(msg, nl_page); + + return 0; +} + +static int +nl802154_put_capabilities(struct sk_buff *msg, + struct cfg802154_registered_device *rdev) +{ + const struct wpan_phy_supported *caps = &rdev->wpan_phy.supported; + struct nlattr *nl_caps, *nl_channels; + int i; + + nl_caps = nla_nest_start_noflag(msg, NL802154_ATTR_WPAN_PHY_CAPS); + if (!nl_caps) + return -ENOBUFS; + + nl_channels = nla_nest_start_noflag(msg, NL802154_CAP_ATTR_CHANNELS); + if (!nl_channels) + return -ENOBUFS; + + for (i = 0; i <= IEEE802154_MAX_PAGE; i++) { + if (caps->channels[i]) { + if (nl802154_put_flags(msg, i, caps->channels[i])) + return -ENOBUFS; + } + } + + nla_nest_end(msg, nl_channels); + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_ED_LEVEL) { + struct nlattr *nl_ed_lvls; + + nl_ed_lvls = nla_nest_start_noflag(msg, + NL802154_CAP_ATTR_CCA_ED_LEVELS); + if (!nl_ed_lvls) + return -ENOBUFS; + + for (i = 0; i < caps->cca_ed_levels_size; i++) { + if (nla_put_s32(msg, i, caps->cca_ed_levels[i])) + return -ENOBUFS; + } + + nla_nest_end(msg, nl_ed_lvls); + } + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_TXPOWER) { + struct nlattr *nl_tx_pwrs; + + nl_tx_pwrs = nla_nest_start_noflag(msg, + NL802154_CAP_ATTR_TX_POWERS); + if (!nl_tx_pwrs) + return -ENOBUFS; + + for (i = 0; i < caps->tx_powers_size; i++) { + if (nla_put_s32(msg, i, caps->tx_powers[i])) + return -ENOBUFS; + } + + nla_nest_end(msg, nl_tx_pwrs); + } + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_MODE) { + if (nl802154_put_flags(msg, NL802154_CAP_ATTR_CCA_MODES, + caps->cca_modes) || + nl802154_put_flags(msg, NL802154_CAP_ATTR_CCA_OPTS, + caps->cca_opts)) + return -ENOBUFS; + } + + if (nla_put_u8(msg, NL802154_CAP_ATTR_MIN_MINBE, caps->min_minbe) || + nla_put_u8(msg, NL802154_CAP_ATTR_MAX_MINBE, caps->max_minbe) || + nla_put_u8(msg, NL802154_CAP_ATTR_MIN_MAXBE, caps->min_maxbe) || + nla_put_u8(msg, NL802154_CAP_ATTR_MAX_MAXBE, caps->max_maxbe) || + nla_put_u8(msg, NL802154_CAP_ATTR_MIN_CSMA_BACKOFFS, + caps->min_csma_backoffs) || + nla_put_u8(msg, NL802154_CAP_ATTR_MAX_CSMA_BACKOFFS, + caps->max_csma_backoffs) || + nla_put_s8(msg, NL802154_CAP_ATTR_MIN_FRAME_RETRIES, + caps->min_frame_retries) || + nla_put_s8(msg, NL802154_CAP_ATTR_MAX_FRAME_RETRIES, + caps->max_frame_retries) || + nl802154_put_flags(msg, NL802154_CAP_ATTR_IFTYPES, + caps->iftypes) || + nla_put_u32(msg, 
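nl802154_put_flags() above turns a capability bitmask into one numbered flag attribute per set bit. The shift-and-test loop on its own, as a runnable user-space sketch (the example channel mask is only illustrative):

    #include <stdint.h>
    #include <stdio.h>

    /* Walk a 32-bit capability mask and report every set bit by its
     * position - the same loop nl802154_put_flags() uses to expand a
     * channel or CCA-mode bitmap into per-value flag attributes.
     */
    static void put_flags(uint32_t mask)
    {
        int i = 0;

        while (mask) {
            if (mask & 1)
                printf("  flag %d supported\n", i);
            mask >>= 1;
            i++;
        }
    }

    int main(void)
    {
        /* e.g. 2.4 GHz O-QPSK channels 11-26 on page 0: bits 11..26 set */
        uint32_t channels = 0x07fff800;

        printf("page 0:\n");
        put_flags(channels);
        return 0;
    }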
NL802154_CAP_ATTR_LBT, caps->lbt)) + return -ENOBUFS; + + nla_nest_end(msg, nl_caps); + + return 0; +} + +static int nl802154_send_wpan_phy(struct cfg802154_registered_device *rdev, + enum nl802154_commands cmd, + struct sk_buff *msg, u32 portid, u32 seq, + int flags) +{ + struct nlattr *nl_cmds; + void *hdr; + int i; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_WPAN_PHY, rdev->wpan_phy_idx) || + nla_put_string(msg, NL802154_ATTR_WPAN_PHY_NAME, + wpan_phy_name(&rdev->wpan_phy)) || + nla_put_u32(msg, NL802154_ATTR_GENERATION, + cfg802154_rdev_list_generation)) + goto nla_put_failure; + + if (cmd != NL802154_CMD_NEW_WPAN_PHY) + goto finish; + + /* DUMP PHY PIB */ + + /* current channel settings */ + if (nla_put_u8(msg, NL802154_ATTR_PAGE, + rdev->wpan_phy.current_page) || + nla_put_u8(msg, NL802154_ATTR_CHANNEL, + rdev->wpan_phy.current_channel)) + goto nla_put_failure; + + /* TODO remove this behaviour, we still keep support it for a while + * so users can change the behaviour to the new one. + */ + if (nl802154_send_wpan_phy_channels(rdev, msg)) + goto nla_put_failure; + + /* cca mode */ + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_MODE) { + if (nla_put_u32(msg, NL802154_ATTR_CCA_MODE, + rdev->wpan_phy.cca.mode)) + goto nla_put_failure; + + if (rdev->wpan_phy.cca.mode == NL802154_CCA_ENERGY_CARRIER) { + if (nla_put_u32(msg, NL802154_ATTR_CCA_OPT, + rdev->wpan_phy.cca.opt)) + goto nla_put_failure; + } + } + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_TXPOWER) { + if (nla_put_s32(msg, NL802154_ATTR_TX_POWER, + rdev->wpan_phy.transmit_power)) + goto nla_put_failure; + } + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_ED_LEVEL) { + if (nla_put_s32(msg, NL802154_ATTR_CCA_ED_LEVEL, + rdev->wpan_phy.cca_ed_level)) + goto nla_put_failure; + } + + if (nl802154_put_capabilities(msg, rdev)) + goto nla_put_failure; + + nl_cmds = nla_nest_start_noflag(msg, NL802154_ATTR_SUPPORTED_COMMANDS); + if (!nl_cmds) + goto nla_put_failure; + + i = 0; +#define CMD(op, n) \ + do { \ + if (rdev->ops->op) { \ + i++; \ + if (nla_put_u32(msg, i, NL802154_CMD_ ## n)) \ + goto nla_put_failure; \ + } \ + } while (0) + + CMD(add_virtual_intf, NEW_INTERFACE); + CMD(del_virtual_intf, DEL_INTERFACE); + CMD(set_channel, SET_CHANNEL); + CMD(set_pan_id, SET_PAN_ID); + CMD(set_short_addr, SET_SHORT_ADDR); + CMD(set_backoff_exponent, SET_BACKOFF_EXPONENT); + CMD(set_max_csma_backoffs, SET_MAX_CSMA_BACKOFFS); + CMD(set_max_frame_retries, SET_MAX_FRAME_RETRIES); + CMD(set_lbt_mode, SET_LBT_MODE); + CMD(set_ackreq_default, SET_ACKREQ_DEFAULT); + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_TXPOWER) + CMD(set_tx_power, SET_TX_POWER); + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_ED_LEVEL) + CMD(set_cca_ed_level, SET_CCA_ED_LEVEL); + + if (rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_MODE) + CMD(set_cca_mode, SET_CCA_MODE); + +#undef CMD + nla_nest_end(msg, nl_cmds); + +finish: + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +struct nl802154_dump_wpan_phy_state { + s64 filter_wpan_phy; + long start; + +}; + +static int nl802154_dump_wpan_phy_parse(struct sk_buff *skb, + struct netlink_callback *cb, + struct nl802154_dump_wpan_phy_state *state) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct nlattr **tb = info->attrs; + + if (tb[NL802154_ATTR_WPAN_PHY]) + state->filter_wpan_phy = nla_get_u32(tb[NL802154_ATTR_WPAN_PHY]); + if (tb[NL802154_ATTR_WPAN_DEV]) + 
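The CMD() macro in nl802154_send_wpan_phy() below advertises one SUPPORTED_COMMANDS entry per callback the driver actually provides, so userspace learns which commands will not return -EOPNOTSUPP. A minimal sketch of that "advertise what the ops table implements" pattern; the ops structure and command numbers are invented for the example:

    #include <stdio.h>

    struct demo_ops {
        int (*set_channel)(int page, int channel);
        int (*set_tx_power)(int mbm);
        int (*set_lbt_mode)(int on);
    };

    enum demo_cmd {
        DEMO_CMD_SET_CHANNEL = 1,
        DEMO_CMD_SET_TX_POWER,
        DEMO_CMD_SET_LBT_MODE,
    };

    static int demo_set_channel(int page, int channel)
    {
        (void)page; (void)channel;
        return 0;
    }

    /* Emit a "supported" entry only for callbacks that are non-NULL. */
    static void advertise(const struct demo_ops *ops)
    {
    #define ADV(op, cmd) do { if (ops->op) printf("supports cmd %d\n", (int)(cmd)); } while (0)
        ADV(set_channel, DEMO_CMD_SET_CHANNEL);
        ADV(set_tx_power, DEMO_CMD_SET_TX_POWER);
        ADV(set_lbt_mode, DEMO_CMD_SET_LBT_MODE);
    #undef ADV
    }

    int main(void)
    {
        struct demo_ops ops = { .set_channel = demo_set_channel };

        advertise(&ops);    /* prints only the SET_CHANNEL line */
        return 0;
    }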
state->filter_wpan_phy = nla_get_u64(tb[NL802154_ATTR_WPAN_DEV]) >> 32; + if (tb[NL802154_ATTR_IFINDEX]) { + struct net_device *netdev; + struct cfg802154_registered_device *rdev; + int ifidx = nla_get_u32(tb[NL802154_ATTR_IFINDEX]); + + netdev = __dev_get_by_index(&init_net, ifidx); + if (!netdev) + return -ENODEV; + if (netdev->ieee802154_ptr) { + rdev = wpan_phy_to_rdev( + netdev->ieee802154_ptr->wpan_phy); + state->filter_wpan_phy = rdev->wpan_phy_idx; + } + } + + return 0; +} + +static int +nl802154_dump_wpan_phy(struct sk_buff *skb, struct netlink_callback *cb) +{ + int idx = 0, ret; + struct nl802154_dump_wpan_phy_state *state = (void *)cb->args[0]; + struct cfg802154_registered_device *rdev; + + rtnl_lock(); + if (!state) { + state = kzalloc(sizeof(*state), GFP_KERNEL); + if (!state) { + rtnl_unlock(); + return -ENOMEM; + } + state->filter_wpan_phy = -1; + ret = nl802154_dump_wpan_phy_parse(skb, cb, state); + if (ret) { + kfree(state); + rtnl_unlock(); + return ret; + } + cb->args[0] = (long)state; + } + + list_for_each_entry(rdev, &cfg802154_rdev_list, list) { + if (!net_eq(wpan_phy_net(&rdev->wpan_phy), sock_net(skb->sk))) + continue; + if (++idx <= state->start) + continue; + if (state->filter_wpan_phy != -1 && + state->filter_wpan_phy != rdev->wpan_phy_idx) + continue; + /* attempt to fit multiple wpan_phy data chunks into the skb */ + ret = nl802154_send_wpan_phy(rdev, + NL802154_CMD_NEW_WPAN_PHY, + skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI); + if (ret < 0) { + if ((ret == -ENOBUFS || ret == -EMSGSIZE) && + !skb->len && cb->min_dump_alloc < 4096) { + cb->min_dump_alloc = 4096; + rtnl_unlock(); + return 1; + } + idx--; + break; + } + break; + } + rtnl_unlock(); + + state->start = idx; + + return skb->len; +} + +static int nl802154_dump_wpan_phy_done(struct netlink_callback *cb) +{ + kfree((void *)cb->args[0]); + return 0; +} + +static int nl802154_get_wpan_phy(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl802154_send_wpan_phy(rdev, NL802154_CMD_NEW_WPAN_PHY, msg, + info->snd_portid, info->snd_seq, 0) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static inline u64 wpan_dev_id(struct wpan_dev *wpan_dev) +{ + return (u64)wpan_dev->identifier | + ((u64)wpan_phy_to_rdev(wpan_dev->wpan_phy)->wpan_phy_idx << 32); +} + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL +#include + +static int +ieee802154_llsec_send_key_id(struct sk_buff *msg, + const struct ieee802154_llsec_key_id *desc) +{ + struct nlattr *nl_dev_addr; + + if (nla_put_u32(msg, NL802154_KEY_ID_ATTR_MODE, desc->mode)) + return -ENOBUFS; + + switch (desc->mode) { + case NL802154_KEY_ID_MODE_IMPLICIT: + nl_dev_addr = nla_nest_start_noflag(msg, + NL802154_KEY_ID_ATTR_IMPLICIT); + if (!nl_dev_addr) + return -ENOBUFS; + + if (nla_put_le16(msg, NL802154_DEV_ADDR_ATTR_PAN_ID, + desc->device_addr.pan_id) || + nla_put_u32(msg, NL802154_DEV_ADDR_ATTR_MODE, + desc->device_addr.mode)) + return -ENOBUFS; + + switch (desc->device_addr.mode) { + case NL802154_DEV_ADDR_SHORT: + if (nla_put_le16(msg, NL802154_DEV_ADDR_ATTR_SHORT, + desc->device_addr.short_addr)) + return -ENOBUFS; + break; + case NL802154_DEV_ADDR_EXTENDED: + if (nla_put_le64(msg, NL802154_DEV_ADDR_ATTR_EXTENDED, + desc->device_addr.extended_addr, + NL802154_DEV_ADDR_ATTR_PAD)) + return -ENOBUFS; + break; + default: + /* 
userspace should handle unknown */ + break; + } + + nla_nest_end(msg, nl_dev_addr); + break; + case NL802154_KEY_ID_MODE_INDEX: + break; + case NL802154_KEY_ID_MODE_INDEX_SHORT: + /* TODO renmae short_source? */ + if (nla_put_le32(msg, NL802154_KEY_ID_ATTR_SOURCE_SHORT, + desc->short_source)) + return -ENOBUFS; + break; + case NL802154_KEY_ID_MODE_INDEX_EXTENDED: + if (nla_put_le64(msg, NL802154_KEY_ID_ATTR_SOURCE_EXTENDED, + desc->extended_source, + NL802154_KEY_ID_ATTR_PAD)) + return -ENOBUFS; + break; + default: + /* userspace should handle unknown */ + break; + } + + /* TODO key_id to key_idx ? Check naming */ + if (desc->mode != NL802154_KEY_ID_MODE_IMPLICIT) { + if (nla_put_u8(msg, NL802154_KEY_ID_ATTR_INDEX, desc->id)) + return -ENOBUFS; + } + + return 0; +} + +static int nl802154_get_llsec_params(struct sk_buff *msg, + struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev) +{ + struct nlattr *nl_key_id; + struct ieee802154_llsec_params params; + int ret; + + ret = rdev_get_llsec_params(rdev, wpan_dev, ¶ms); + if (ret < 0) + return ret; + + if (nla_put_u8(msg, NL802154_ATTR_SEC_ENABLED, params.enabled) || + nla_put_u32(msg, NL802154_ATTR_SEC_OUT_LEVEL, params.out_level) || + nla_put_be32(msg, NL802154_ATTR_SEC_FRAME_COUNTER, + params.frame_counter)) + return -ENOBUFS; + + nl_key_id = nla_nest_start_noflag(msg, NL802154_ATTR_SEC_OUT_KEY_ID); + if (!nl_key_id) + return -ENOBUFS; + + ret = ieee802154_llsec_send_key_id(msg, ¶ms.out_key); + if (ret < 0) + return ret; + + nla_nest_end(msg, nl_key_id); + + return 0; +} +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + +static int +nl802154_send_iface(struct sk_buff *msg, u32 portid, u32 seq, int flags, + struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev) +{ + struct net_device *dev = wpan_dev->netdev; + void *hdr; + + hdr = nl802154hdr_put(msg, portid, seq, flags, + NL802154_CMD_NEW_INTERFACE); + if (!hdr) + return -1; + + if (dev && + (nla_put_u32(msg, NL802154_ATTR_IFINDEX, dev->ifindex) || + nla_put_string(msg, NL802154_ATTR_IFNAME, dev->name))) + goto nla_put_failure; + + if (nla_put_u32(msg, NL802154_ATTR_WPAN_PHY, rdev->wpan_phy_idx) || + nla_put_u32(msg, NL802154_ATTR_IFTYPE, wpan_dev->iftype) || + nla_put_u64_64bit(msg, NL802154_ATTR_WPAN_DEV, + wpan_dev_id(wpan_dev), NL802154_ATTR_PAD) || + nla_put_u32(msg, NL802154_ATTR_GENERATION, + rdev->devlist_generation ^ + (cfg802154_rdev_list_generation << 2))) + goto nla_put_failure; + + /* address settings */ + if (nla_put_le64(msg, NL802154_ATTR_EXTENDED_ADDR, + wpan_dev->extended_addr, + NL802154_ATTR_PAD) || + nla_put_le16(msg, NL802154_ATTR_SHORT_ADDR, + wpan_dev->short_addr) || + nla_put_le16(msg, NL802154_ATTR_PAN_ID, wpan_dev->pan_id)) + goto nla_put_failure; + + /* ARET handling */ + if (nla_put_s8(msg, NL802154_ATTR_MAX_FRAME_RETRIES, + wpan_dev->frame_retries) || + nla_put_u8(msg, NL802154_ATTR_MAX_BE, wpan_dev->max_be) || + nla_put_u8(msg, NL802154_ATTR_MAX_CSMA_BACKOFFS, + wpan_dev->csma_retries) || + nla_put_u8(msg, NL802154_ATTR_MIN_BE, wpan_dev->min_be)) + goto nla_put_failure; + + /* listen before transmit */ + if (nla_put_u8(msg, NL802154_ATTR_LBT_MODE, wpan_dev->lbt)) + goto nla_put_failure; + + /* ackreq default behaviour */ + if (nla_put_u8(msg, NL802154_ATTR_ACKREQ_DEFAULT, wpan_dev->ackreq)) + goto nla_put_failure; + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + goto out; + + if (nl802154_get_llsec_params(msg, rdev, wpan_dev) < 0) + goto nla_put_failure; + +out: +#endif 
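The key-identifier attributes handled above form a tagged union: NL802154_KEY_ID_ATTR_MODE selects whether a device address (implicit mode), a bare index, or an index plus a short or extended key source is meaningful. A plain C sketch of the same shape, simplified to a short device address; all names here are illustrative, not the nl802154 structures:

    #include <stdint.h>
    #include <stdio.h>

    enum key_id_mode {
        KEY_ID_IMPLICIT, KEY_ID_INDEX, KEY_ID_INDEX_SHORT, KEY_ID_INDEX_EXTENDED,
    };

    struct key_id {
        enum key_id_mode mode;
        uint8_t index;                 /* valid for every mode except implicit */
        union {
            struct { uint16_t pan_id, short_addr; } implicit;
            uint32_t short_source;     /* INDEX_SHORT */
            uint64_t extended_source;  /* INDEX_EXTENDED */
        } u;
    };

    /* Interpret only the members the mode tag says are valid. */
    static void print_key_id(const struct key_id *id)
    {
        switch (id->mode) {
        case KEY_ID_IMPLICIT:
            printf("implicit: pan 0x%04x addr 0x%04x\n",
                   (unsigned)id->u.implicit.pan_id,
                   (unsigned)id->u.implicit.short_addr);
            return;
        case KEY_ID_INDEX:
            printf("index %u\n", (unsigned)id->index);
            return;
        case KEY_ID_INDEX_SHORT:
            printf("index %u, short source 0x%08lx\n",
                   (unsigned)id->index, (unsigned long)id->u.short_source);
            return;
        case KEY_ID_INDEX_EXTENDED:
            printf("index %u, extended source 0x%016llx\n",
                   (unsigned)id->index,
                   (unsigned long long)id->u.extended_source);
            return;
        }
    }

    int main(void)
    {
        struct key_id id = { .mode = KEY_ID_INDEX_SHORT, .index = 1,
                             .u.short_source = 0xcafe0001 };

        print_key_id(&id);
        return 0;
    }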
/* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl802154_dump_interface(struct sk_buff *skb, struct netlink_callback *cb) +{ + int wp_idx = 0; + int if_idx = 0; + int wp_start = cb->args[0]; + int if_start = cb->args[1]; + struct cfg802154_registered_device *rdev; + struct wpan_dev *wpan_dev; + + rtnl_lock(); + list_for_each_entry(rdev, &cfg802154_rdev_list, list) { + if (!net_eq(wpan_phy_net(&rdev->wpan_phy), sock_net(skb->sk))) + continue; + if (wp_idx < wp_start) { + wp_idx++; + continue; + } + if_idx = 0; + + list_for_each_entry(wpan_dev, &rdev->wpan_dev_list, list) { + if (if_idx < if_start) { + if_idx++; + continue; + } + if (nl802154_send_iface(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wpan_dev) < 0) { + goto out; + } + if_idx++; + } + + wp_idx++; + } +out: + rtnl_unlock(); + + cb->args[0] = wp_idx; + cb->args[1] = if_idx; + + return skb->len; +} + +static int nl802154_get_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct wpan_dev *wdev = info->user_ptr[1]; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl802154_send_iface(msg, info->snd_portid, info->snd_seq, 0, + rdev, wdev) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static int nl802154_new_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + enum nl802154_iftype type = NL802154_IFTYPE_UNSPEC; + __le64 extended_addr = cpu_to_le64(0x0000000000000000ULL); + + /* TODO avoid failing a new interface + * creation due to pending removal? + */ + + if (!info->attrs[NL802154_ATTR_IFNAME]) + return -EINVAL; + + if (info->attrs[NL802154_ATTR_IFTYPE]) { + type = nla_get_u32(info->attrs[NL802154_ATTR_IFTYPE]); + if (type > NL802154_IFTYPE_MAX || + !(rdev->wpan_phy.supported.iftypes & BIT(type))) + return -EINVAL; + } + + if (info->attrs[NL802154_ATTR_EXTENDED_ADDR]) + extended_addr = nla_get_le64(info->attrs[NL802154_ATTR_EXTENDED_ADDR]); + + if (!rdev->ops->add_virtual_intf) + return -EOPNOTSUPP; + + return rdev_add_virtual_intf(rdev, + nla_data(info->attrs[NL802154_ATTR_IFNAME]), + NET_NAME_USER, type, extended_addr); +} + +static int nl802154_del_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct wpan_dev *wpan_dev = info->user_ptr[1]; + + if (!rdev->ops->del_virtual_intf) + return -EOPNOTSUPP; + + /* If we remove a wpan device without a netdev then clear + * user_ptr[1] so that nl802154_post_doit won't dereference it + * to check if it needs to do dev_put(). Otherwise it crashes + * since the wpan_dev has been freed, unlike with a netdev where + * we need the dev_put() for the netdev to really be freed. 
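nl802154_dump_interface() below resumes with two cursors saved in cb->args: which phy the previous dump stopped at and which interface inside that phy. A sketch of the same two-level resume idea; it is not a line-for-line copy of the kernel loop, and the data is made up:

    #include <stdio.h>

    struct cursor { int wp_start, if_start; };

    /* Emit up to 'budget' (phy, interface) pairs per call, remembering the
     * outer and inner position so the next call continues mid-phy.
     */
    static int dump(const int ifaces_per_phy[], int n_phy,
                    struct cursor *c, int budget)
    {
        int wp_idx, if_idx = 0, emitted = 0;

        for (wp_idx = 0; wp_idx < n_phy; wp_idx++) {
            if (wp_idx < c->wp_start)
                continue;
            for (if_idx = 0; if_idx < ifaces_per_phy[wp_idx]; if_idx++) {
                if (wp_idx == c->wp_start && if_idx < c->if_start)
                    continue;
                if (emitted == budget)
                    goto out;
                printf("phy%d wpan%d\n", wp_idx, if_idx);
                emitted++;
            }
        }
    out:
        c->wp_start = wp_idx;
        c->if_start = if_idx;
        return emitted;
    }

    int main(void)
    {
        int ifaces_per_phy[] = { 2, 1, 3 };
        struct cursor c = { 0, 0 };

        while (dump(ifaces_per_phy, 3, &c, 2) > 0)
            printf("-- next batch --\n");
        return 0;
    }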
+ */ + if (!wpan_dev->netdev) + info->user_ptr[1] = NULL; + + return rdev_del_virtual_intf(rdev, wpan_dev); +} + +static int nl802154_set_channel(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + u8 channel, page; + + if (!info->attrs[NL802154_ATTR_PAGE] || + !info->attrs[NL802154_ATTR_CHANNEL]) + return -EINVAL; + + page = nla_get_u8(info->attrs[NL802154_ATTR_PAGE]); + channel = nla_get_u8(info->attrs[NL802154_ATTR_CHANNEL]); + + /* check 802.15.4 constraints */ + if (page > IEEE802154_MAX_PAGE || channel > IEEE802154_MAX_CHANNEL || + !(rdev->wpan_phy.supported.channels[page] & BIT(channel))) + return -EINVAL; + + return rdev_set_channel(rdev, page, channel); +} + +static int nl802154_set_cca_mode(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct wpan_phy_cca cca; + + if (!(rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_MODE)) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_CCA_MODE]) + return -EINVAL; + + cca.mode = nla_get_u32(info->attrs[NL802154_ATTR_CCA_MODE]); + /* checking 802.15.4 constraints */ + if (cca.mode < NL802154_CCA_ENERGY || + cca.mode > NL802154_CCA_ATTR_MAX || + !(rdev->wpan_phy.supported.cca_modes & BIT(cca.mode))) + return -EINVAL; + + if (cca.mode == NL802154_CCA_ENERGY_CARRIER) { + if (!info->attrs[NL802154_ATTR_CCA_OPT]) + return -EINVAL; + + cca.opt = nla_get_u32(info->attrs[NL802154_ATTR_CCA_OPT]); + if (cca.opt > NL802154_CCA_OPT_ATTR_MAX || + !(rdev->wpan_phy.supported.cca_opts & BIT(cca.opt))) + return -EINVAL; + } + + return rdev_set_cca_mode(rdev, &cca); +} + +static int nl802154_set_cca_ed_level(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + s32 ed_level; + int i; + + if (!(rdev->wpan_phy.flags & WPAN_PHY_FLAG_CCA_ED_LEVEL)) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_CCA_ED_LEVEL]) + return -EINVAL; + + ed_level = nla_get_s32(info->attrs[NL802154_ATTR_CCA_ED_LEVEL]); + + for (i = 0; i < rdev->wpan_phy.supported.cca_ed_levels_size; i++) { + if (ed_level == rdev->wpan_phy.supported.cca_ed_levels[i]) + return rdev_set_cca_ed_level(rdev, ed_level); + } + + return -EINVAL; +} + +static int nl802154_set_tx_power(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + s32 power; + int i; + + if (!(rdev->wpan_phy.flags & WPAN_PHY_FLAG_TXPOWER)) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_TX_POWER]) + return -EINVAL; + + power = nla_get_s32(info->attrs[NL802154_ATTR_TX_POWER]); + + for (i = 0; i < rdev->wpan_phy.supported.tx_powers_size; i++) { + if (power == rdev->wpan_phy.supported.tx_powers[i]) + return rdev_set_tx_power(rdev, power); + } + + return -EINVAL; +} + +static int nl802154_set_pan_id(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + __le16 pan_id; + + /* conflict here while tx/rx calls */ + if (netif_running(dev)) + return -EBUSY; + + if (wpan_dev->lowpan_dev) { + if (netif_running(wpan_dev->lowpan_dev)) + return -EBUSY; + } + + /* don't change address fields on monitor */ + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR || + !info->attrs[NL802154_ATTR_PAN_ID]) + return -EINVAL; + + pan_id = nla_get_le16(info->attrs[NL802154_ATTR_PAN_ID]); + + /* TODO + * I am not sure about to check here on broadcast 
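nl802154_set_channel() above rejects pages beyond IEEE802154_MAX_PAGE, channels beyond IEEE802154_MAX_CHANNEL, and any channel whose bit is clear in supported.channels[page]. The same range-plus-bitmask test as a runnable sketch; the capability table contents are invented:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define MAX_PAGE      31    /* IEEE 802.15.4 channel pages 0..31 */
    #define MAX_CHANNEL   26    /* channels 0..26 within a page */

    /* Validate a (page, channel) pair against a per-page bitmap of
     * supported channels before touching the hardware.
     */
    static bool channel_ok(const uint32_t supported[MAX_PAGE + 1],
                           uint8_t page, uint8_t channel)
    {
        if (page > MAX_PAGE || channel > MAX_CHANNEL)
            return false;
        return supported[page] & (UINT32_C(1) << channel);
    }

    int main(void)
    {
        /* channels 11-26 on page 0, nothing else */
        uint32_t supported[MAX_PAGE + 1] = { [0] = 0x07fff800 };

        printf("page 0 channel 15: %s\n", channel_ok(supported, 0, 15) ? "ok" : "rejected");
        printf("page 0 channel  5: %s\n", channel_ok(supported, 0, 5)  ? "ok" : "rejected");
        printf("page 2 channel  0: %s\n", channel_ok(supported, 2, 0)  ? "ok" : "rejected");
        return 0;
    }

Checking the per-page bitmap, rather than only the numeric ranges, is what lets a driver expose a sparse channel plan without extra callbacks.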
pan_id. + * Broadcast is a valid setting, comment from 802.15.4: + * If this value is 0xffff, the device is not associated. + * + * This could useful to simple deassociate an device. + */ + if (pan_id == cpu_to_le16(IEEE802154_PAN_ID_BROADCAST)) + return -EINVAL; + + return rdev_set_pan_id(rdev, wpan_dev, pan_id); +} + +static int nl802154_set_short_addr(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + __le16 short_addr; + + /* conflict here while tx/rx calls */ + if (netif_running(dev)) + return -EBUSY; + + if (wpan_dev->lowpan_dev) { + if (netif_running(wpan_dev->lowpan_dev)) + return -EBUSY; + } + + /* don't change address fields on monitor */ + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR || + !info->attrs[NL802154_ATTR_SHORT_ADDR]) + return -EINVAL; + + short_addr = nla_get_le16(info->attrs[NL802154_ATTR_SHORT_ADDR]); + + /* TODO + * I am not sure about to check here on broadcast short_addr. + * Broadcast is a valid setting, comment from 802.15.4: + * A value of 0xfffe indicates that the device has + * associated but has not been allocated an address. A + * value of 0xffff indicates that the device does not + * have a short address. + * + * I think we should allow to set these settings but + * don't allow to allow socket communication with it. + */ + if (short_addr == cpu_to_le16(IEEE802154_ADDR_SHORT_UNSPEC) || + short_addr == cpu_to_le16(IEEE802154_ADDR_SHORT_BROADCAST)) + return -EINVAL; + + return rdev_set_short_addr(rdev, wpan_dev, short_addr); +} + +static int +nl802154_set_backoff_exponent(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + u8 min_be, max_be; + + /* should be set on netif open inside phy settings */ + if (netif_running(dev)) + return -EBUSY; + + if (!info->attrs[NL802154_ATTR_MIN_BE] || + !info->attrs[NL802154_ATTR_MAX_BE]) + return -EINVAL; + + min_be = nla_get_u8(info->attrs[NL802154_ATTR_MIN_BE]); + max_be = nla_get_u8(info->attrs[NL802154_ATTR_MAX_BE]); + + /* check 802.15.4 constraints */ + if (min_be < rdev->wpan_phy.supported.min_minbe || + min_be > rdev->wpan_phy.supported.max_minbe || + max_be < rdev->wpan_phy.supported.min_maxbe || + max_be > rdev->wpan_phy.supported.max_maxbe || + min_be > max_be) + return -EINVAL; + + return rdev_set_backoff_exponent(rdev, wpan_dev, min_be, max_be); +} + +static int +nl802154_set_max_csma_backoffs(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + u8 max_csma_backoffs; + + /* conflict here while other running iface settings */ + if (netif_running(dev)) + return -EBUSY; + + if (!info->attrs[NL802154_ATTR_MAX_CSMA_BACKOFFS]) + return -EINVAL; + + max_csma_backoffs = nla_get_u8( + info->attrs[NL802154_ATTR_MAX_CSMA_BACKOFFS]); + + /* check 802.15.4 constraints */ + if (max_csma_backoffs < rdev->wpan_phy.supported.min_csma_backoffs || + max_csma_backoffs > rdev->wpan_phy.supported.max_csma_backoffs) + return -EINVAL; + + return rdev_set_max_csma_backoffs(rdev, wpan_dev, max_csma_backoffs); +} + +static int +nl802154_set_max_frame_retries(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = 
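The min_be/max_be pair validated in nl802154_set_backoff_exponent() below bounds the CSMA-CA random delay: each transmission attempt waits a random number of unit backoff periods drawn from [0, 2^BE - 1], with BE starting at macMinBE and growing toward macMaxBE after busy channel assessments. A sketch of the validation plus the resulting window sizes; the supported ranges used here are example values, not read from any phy:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Accept a backoff-exponent pair only if both values fall inside the
     * phy's advertised ranges and min does not exceed max.
     */
    static bool be_ok(uint8_t min_be, uint8_t max_be,
                      uint8_t lo_min, uint8_t hi_min,
                      uint8_t lo_max, uint8_t hi_max)
    {
        return min_be >= lo_min && min_be <= hi_min &&
               max_be >= lo_max && max_be <= hi_max &&
               min_be <= max_be;
    }

    int main(void)
    {
        uint8_t min_be = 3, max_be = 5, be;

        if (!be_ok(min_be, max_be, 0, 8, 3, 8)) {
            printf("rejected\n");
            return 1;
        }
        for (be = min_be; be <= max_be; be++)
            printf("BE=%u -> backoff window 0..%u periods\n",
                   (unsigned)be, (1u << be) - 1);
        return 0;
    }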
info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + s8 max_frame_retries; + + if (netif_running(dev)) + return -EBUSY; + + if (!info->attrs[NL802154_ATTR_MAX_FRAME_RETRIES]) + return -EINVAL; + + max_frame_retries = nla_get_s8( + info->attrs[NL802154_ATTR_MAX_FRAME_RETRIES]); + + /* check 802.15.4 constraints */ + if (max_frame_retries < rdev->wpan_phy.supported.min_frame_retries || + max_frame_retries > rdev->wpan_phy.supported.max_frame_retries) + return -EINVAL; + + return rdev_set_max_frame_retries(rdev, wpan_dev, max_frame_retries); +} + +static int nl802154_set_lbt_mode(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + int mode; + + if (netif_running(dev)) + return -EBUSY; + + if (!info->attrs[NL802154_ATTR_LBT_MODE]) + return -EINVAL; + + mode = nla_get_u8(info->attrs[NL802154_ATTR_LBT_MODE]); + + if (mode != 0 && mode != 1) + return -EINVAL; + + if (!wpan_phy_supported_bool(mode, rdev->wpan_phy.supported.lbt)) + return -EINVAL; + + return rdev_set_lbt_mode(rdev, wpan_dev, mode); +} + +static int +nl802154_set_ackreq_default(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + int ackreq; + + if (netif_running(dev)) + return -EBUSY; + + if (!info->attrs[NL802154_ATTR_ACKREQ_DEFAULT]) + return -EINVAL; + + ackreq = nla_get_u8(info->attrs[NL802154_ATTR_ACKREQ_DEFAULT]); + + if (ackreq != 0 && ackreq != 1) + return -EINVAL; + + return rdev_set_ackreq_default(rdev, wpan_dev, ackreq); +} + +static int nl802154_wpan_phy_netns(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net *net; + int err; + + if (info->attrs[NL802154_ATTR_PID]) { + u32 pid = nla_get_u32(info->attrs[NL802154_ATTR_PID]); + + net = get_net_ns_by_pid(pid); + } else if (info->attrs[NL802154_ATTR_NETNS_FD]) { + u32 fd = nla_get_u32(info->attrs[NL802154_ATTR_NETNS_FD]); + + net = get_net_ns_by_fd(fd); + } else { + return -EINVAL; + } + + if (IS_ERR(net)) + return PTR_ERR(net); + + err = 0; + + /* check if anything to do */ + if (!net_eq(wpan_phy_net(&rdev->wpan_phy), net)) + err = cfg802154_switch_netns(rdev, net); + + put_net(net); + return err; +} + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL +static const struct nla_policy nl802154_dev_addr_policy[NL802154_DEV_ADDR_ATTR_MAX + 1] = { + [NL802154_DEV_ADDR_ATTR_PAN_ID] = { .type = NLA_U16 }, + [NL802154_DEV_ADDR_ATTR_MODE] = { .type = NLA_U32 }, + [NL802154_DEV_ADDR_ATTR_SHORT] = { .type = NLA_U16 }, + [NL802154_DEV_ADDR_ATTR_EXTENDED] = { .type = NLA_U64 }, +}; + +static int +ieee802154_llsec_parse_dev_addr(struct nlattr *nla, + struct ieee802154_addr *addr) +{ + struct nlattr *attrs[NL802154_DEV_ADDR_ATTR_MAX + 1]; + + if (!nla || nla_parse_nested_deprecated(attrs, NL802154_DEV_ADDR_ATTR_MAX, nla, nl802154_dev_addr_policy, NULL)) + return -EINVAL; + + if (!attrs[NL802154_DEV_ADDR_ATTR_PAN_ID] || !attrs[NL802154_DEV_ADDR_ATTR_MODE]) + return -EINVAL; + + addr->pan_id = nla_get_le16(attrs[NL802154_DEV_ADDR_ATTR_PAN_ID]); + addr->mode = nla_get_u32(attrs[NL802154_DEV_ADDR_ATTR_MODE]); + switch (addr->mode) { + case NL802154_DEV_ADDR_SHORT: + if (!attrs[NL802154_DEV_ADDR_ATTR_SHORT]) + return -EINVAL; + addr->short_addr = 
nla_get_le16(attrs[NL802154_DEV_ADDR_ATTR_SHORT]); + break; + case NL802154_DEV_ADDR_EXTENDED: + if (!attrs[NL802154_DEV_ADDR_ATTR_EXTENDED]) + return -EINVAL; + addr->extended_addr = nla_get_le64(attrs[NL802154_DEV_ADDR_ATTR_EXTENDED]); + break; + default: + return -EINVAL; + } + + return 0; +} + +static const struct nla_policy nl802154_key_id_policy[NL802154_KEY_ID_ATTR_MAX + 1] = { + [NL802154_KEY_ID_ATTR_MODE] = { .type = NLA_U32 }, + [NL802154_KEY_ID_ATTR_INDEX] = { .type = NLA_U8 }, + [NL802154_KEY_ID_ATTR_IMPLICIT] = { .type = NLA_NESTED }, + [NL802154_KEY_ID_ATTR_SOURCE_SHORT] = { .type = NLA_U32 }, + [NL802154_KEY_ID_ATTR_SOURCE_EXTENDED] = { .type = NLA_U64 }, +}; + +static int +ieee802154_llsec_parse_key_id(struct nlattr *nla, + struct ieee802154_llsec_key_id *desc) +{ + struct nlattr *attrs[NL802154_KEY_ID_ATTR_MAX + 1]; + + if (!nla || nla_parse_nested_deprecated(attrs, NL802154_KEY_ID_ATTR_MAX, nla, nl802154_key_id_policy, NULL)) + return -EINVAL; + + if (!attrs[NL802154_KEY_ID_ATTR_MODE]) + return -EINVAL; + + desc->mode = nla_get_u32(attrs[NL802154_KEY_ID_ATTR_MODE]); + switch (desc->mode) { + case NL802154_KEY_ID_MODE_IMPLICIT: + if (!attrs[NL802154_KEY_ID_ATTR_IMPLICIT]) + return -EINVAL; + + if (ieee802154_llsec_parse_dev_addr(attrs[NL802154_KEY_ID_ATTR_IMPLICIT], + &desc->device_addr) < 0) + return -EINVAL; + break; + case NL802154_KEY_ID_MODE_INDEX: + break; + case NL802154_KEY_ID_MODE_INDEX_SHORT: + if (!attrs[NL802154_KEY_ID_ATTR_SOURCE_SHORT]) + return -EINVAL; + + desc->short_source = nla_get_le32(attrs[NL802154_KEY_ID_ATTR_SOURCE_SHORT]); + break; + case NL802154_KEY_ID_MODE_INDEX_EXTENDED: + if (!attrs[NL802154_KEY_ID_ATTR_SOURCE_EXTENDED]) + return -EINVAL; + + desc->extended_source = nla_get_le64(attrs[NL802154_KEY_ID_ATTR_SOURCE_EXTENDED]); + break; + default: + return -EINVAL; + } + + if (desc->mode != NL802154_KEY_ID_MODE_IMPLICIT) { + if (!attrs[NL802154_KEY_ID_ATTR_INDEX]) + return -EINVAL; + + /* TODO change id to idx */ + desc->id = nla_get_u8(attrs[NL802154_KEY_ID_ATTR_INDEX]); + } + + return 0; +} + +static int nl802154_set_llsec_params(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct ieee802154_llsec_params params; + u32 changed = 0; + int ret; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (info->attrs[NL802154_ATTR_SEC_ENABLED]) { + u8 enabled; + + enabled = nla_get_u8(info->attrs[NL802154_ATTR_SEC_ENABLED]); + if (enabled != 0 && enabled != 1) + return -EINVAL; + + params.enabled = nla_get_u8(info->attrs[NL802154_ATTR_SEC_ENABLED]); + changed |= IEEE802154_LLSEC_PARAM_ENABLED; + } + + if (info->attrs[NL802154_ATTR_SEC_OUT_KEY_ID]) { + ret = ieee802154_llsec_parse_key_id(info->attrs[NL802154_ATTR_SEC_OUT_KEY_ID], + ¶ms.out_key); + if (ret < 0) + return ret; + + changed |= IEEE802154_LLSEC_PARAM_OUT_KEY; + } + + if (info->attrs[NL802154_ATTR_SEC_OUT_LEVEL]) { + params.out_level = nla_get_u32(info->attrs[NL802154_ATTR_SEC_OUT_LEVEL]); + if (params.out_level > NL802154_SECLEVEL_MAX) + return -EINVAL; + + changed |= IEEE802154_LLSEC_PARAM_OUT_LEVEL; + } + + if (info->attrs[NL802154_ATTR_SEC_FRAME_COUNTER]) { + params.frame_counter = nla_get_be32(info->attrs[NL802154_ATTR_SEC_FRAME_COUNTER]); + changed |= IEEE802154_LLSEC_PARAM_FRAME_COUNTER; + } + + return rdev_set_llsec_params(rdev, wpan_dev, ¶ms, changed); +} + +static int nl802154_send_key(struct 
sk_buff *msg, u32 cmd, u32 portid, + u32 seq, int flags, + struct cfg802154_registered_device *rdev, + struct net_device *dev, + const struct ieee802154_llsec_key_entry *key) +{ + void *hdr; + u32 commands[NL802154_CMD_FRAME_NR_IDS / 32]; + struct nlattr *nl_key, *nl_key_id; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + nl_key = nla_nest_start_noflag(msg, NL802154_ATTR_SEC_KEY); + if (!nl_key) + goto nla_put_failure; + + nl_key_id = nla_nest_start_noflag(msg, NL802154_KEY_ATTR_ID); + if (!nl_key_id) + goto nla_put_failure; + + if (ieee802154_llsec_send_key_id(msg, &key->id) < 0) + goto nla_put_failure; + + nla_nest_end(msg, nl_key_id); + + if (nla_put_u8(msg, NL802154_KEY_ATTR_USAGE_FRAMES, + key->key->frame_types)) + goto nla_put_failure; + + if (key->key->frame_types & BIT(NL802154_FRAME_CMD)) { + /* TODO for each nested */ + memset(commands, 0, sizeof(commands)); + commands[7] = key->key->cmd_frame_ids; + if (nla_put(msg, NL802154_KEY_ATTR_USAGE_CMDS, + sizeof(commands), commands)) + goto nla_put_failure; + } + + if (nla_put(msg, NL802154_KEY_ATTR_BYTES, NL802154_KEY_SIZE, + key->key->key)) + goto nla_put_failure; + + nla_nest_end(msg, nl_key); + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl802154_dump_llsec_key(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct cfg802154_registered_device *rdev = NULL; + struct ieee802154_llsec_key_entry *key; + struct ieee802154_llsec_table *table; + struct wpan_dev *wpan_dev; + int err; + + err = nl802154_prepare_wpan_dev_dump(skb, cb, &rdev, &wpan_dev); + if (err) + return err; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) { + err = skb->len; + goto out_err; + } + + if (!wpan_dev->netdev) { + err = -EINVAL; + goto out_err; + } + + rdev_lock_llsec_table(rdev, wpan_dev); + rdev_get_llsec_table(rdev, wpan_dev, &table); + + /* TODO make it like station dump */ + if (cb->args[2]) + goto out; + + list_for_each_entry(key, &table->keys, list) { + if (nl802154_send_key(skb, NL802154_CMD_NEW_SEC_KEY, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wpan_dev->netdev, key) < 0) { + /* TODO */ + err = -EIO; + rdev_unlock_llsec_table(rdev, wpan_dev); + goto out_err; + } + } + + cb->args[2] = 1; + +out: + rdev_unlock_llsec_table(rdev, wpan_dev); + err = skb->len; +out_err: + nl802154_finish_wpan_dev_dump(rdev); + + return err; +} + +static const struct nla_policy nl802154_key_policy[NL802154_KEY_ATTR_MAX + 1] = { + [NL802154_KEY_ATTR_ID] = { NLA_NESTED }, + /* TODO handle it as for_each_nested and NLA_FLAG? */ + [NL802154_KEY_ATTR_USAGE_FRAMES] = { NLA_U8 }, + /* TODO handle it as for_each_nested, not static array? 
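NL802154_KEY_ATTR_USAGE_CMDS above is a 256-bit blob (NL802154_CMD_FRAME_NR_IDS bits carried as eight 32-bit words), and the code keeps the currently defined command-frame bits packed into the last word of that array. A generic sketch of a 256-entry bitmap with conventional id/32, id%32 addressing, which is close in spirit but not the exact on-wire layout used above:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    #define NR_IDS 256

    struct id_bitmap { uint32_t w[NR_IDS / 32]; };

    /* Mark a frame id as allowed for this key. */
    static void id_set(struct id_bitmap *b, unsigned int id)
    {
        b->w[id / 32] |= UINT32_C(1) << (id % 32);
    }

    /* Test whether a frame id is allowed. */
    static bool id_test(const struct id_bitmap *b, unsigned int id)
    {
        return b->w[id / 32] & (UINT32_C(1) << (id % 32));
    }

    int main(void)
    {
        struct id_bitmap b;

        memset(&b, 0, sizeof(b));
        id_set(&b, 4);      /* e.g. the data-request command frame id */
        id_set(&b, 200);

        printf("id 4: %d, id 5: %d, id 200: %d\n",
               id_test(&b, 4), id_test(&b, 5), id_test(&b, 200));
        return 0;
    }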
*/ + [NL802154_KEY_ATTR_USAGE_CMDS] = { .len = NL802154_CMD_FRAME_NR_IDS / 8 }, + [NL802154_KEY_ATTR_BYTES] = { .len = NL802154_KEY_SIZE }, +}; + +static int nl802154_add_llsec_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct nlattr *attrs[NL802154_KEY_ATTR_MAX + 1]; + struct ieee802154_llsec_key key = { }; + struct ieee802154_llsec_key_id id = { }; + u32 commands[NL802154_CMD_FRAME_NR_IDS / 32] = { }; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_KEY] || + nla_parse_nested_deprecated(attrs, NL802154_KEY_ATTR_MAX, info->attrs[NL802154_ATTR_SEC_KEY], nl802154_key_policy, info->extack)) + return -EINVAL; + + if (!attrs[NL802154_KEY_ATTR_USAGE_FRAMES] || + !attrs[NL802154_KEY_ATTR_BYTES]) + return -EINVAL; + + if (ieee802154_llsec_parse_key_id(attrs[NL802154_KEY_ATTR_ID], &id) < 0) + return -ENOBUFS; + + key.frame_types = nla_get_u8(attrs[NL802154_KEY_ATTR_USAGE_FRAMES]); + if (key.frame_types > BIT(NL802154_FRAME_MAX) || + ((key.frame_types & BIT(NL802154_FRAME_CMD)) && + !attrs[NL802154_KEY_ATTR_USAGE_CMDS])) + return -EINVAL; + + if (attrs[NL802154_KEY_ATTR_USAGE_CMDS]) { + /* TODO for each nested */ + nla_memcpy(commands, attrs[NL802154_KEY_ATTR_USAGE_CMDS], + NL802154_CMD_FRAME_NR_IDS / 8); + + /* TODO understand the -EINVAL logic here? last condition */ + if (commands[0] || commands[1] || commands[2] || commands[3] || + commands[4] || commands[5] || commands[6] || + commands[7] > BIT(NL802154_CMD_FRAME_MAX)) + return -EINVAL; + + key.cmd_frame_ids = commands[7]; + } else { + key.cmd_frame_ids = 0; + } + + nla_memcpy(key.key, attrs[NL802154_KEY_ATTR_BYTES], NL802154_KEY_SIZE); + + if (ieee802154_llsec_parse_key_id(attrs[NL802154_KEY_ATTR_ID], &id) < 0) + return -ENOBUFS; + + return rdev_add_llsec_key(rdev, wpan_dev, &id, &key); +} + +static int nl802154_del_llsec_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct nlattr *attrs[NL802154_KEY_ATTR_MAX + 1]; + struct ieee802154_llsec_key_id id; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_KEY] || + nla_parse_nested_deprecated(attrs, NL802154_KEY_ATTR_MAX, info->attrs[NL802154_ATTR_SEC_KEY], nl802154_key_policy, info->extack)) + return -EINVAL; + + if (ieee802154_llsec_parse_key_id(attrs[NL802154_KEY_ATTR_ID], &id) < 0) + return -ENOBUFS; + + return rdev_del_llsec_key(rdev, wpan_dev, &id); +} + +static int nl802154_send_device(struct sk_buff *msg, u32 cmd, u32 portid, + u32 seq, int flags, + struct cfg802154_registered_device *rdev, + struct net_device *dev, + const struct ieee802154_llsec_device *dev_desc) +{ + void *hdr; + struct nlattr *nl_device; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + nl_device = nla_nest_start_noflag(msg, NL802154_ATTR_SEC_DEVICE); + if (!nl_device) + goto nla_put_failure; + + if (nla_put_u32(msg, NL802154_DEV_ATTR_FRAME_COUNTER, + dev_desc->frame_counter) || + nla_put_le16(msg, NL802154_DEV_ATTR_PAN_ID, dev_desc->pan_id) || + nla_put_le16(msg, NL802154_DEV_ATTR_SHORT_ADDR, + dev_desc->short_addr) || + 
nla_put_le64(msg, NL802154_DEV_ATTR_EXTENDED_ADDR, + dev_desc->hwaddr, NL802154_DEV_ATTR_PAD) || + nla_put_u8(msg, NL802154_DEV_ATTR_SECLEVEL_EXEMPT, + dev_desc->seclevel_exempt) || + nla_put_u32(msg, NL802154_DEV_ATTR_KEY_MODE, dev_desc->key_mode)) + goto nla_put_failure; + + nla_nest_end(msg, nl_device); + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl802154_dump_llsec_dev(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct cfg802154_registered_device *rdev = NULL; + struct ieee802154_llsec_device *dev; + struct ieee802154_llsec_table *table; + struct wpan_dev *wpan_dev; + int err; + + err = nl802154_prepare_wpan_dev_dump(skb, cb, &rdev, &wpan_dev); + if (err) + return err; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) { + err = skb->len; + goto out_err; + } + + if (!wpan_dev->netdev) { + err = -EINVAL; + goto out_err; + } + + rdev_lock_llsec_table(rdev, wpan_dev); + rdev_get_llsec_table(rdev, wpan_dev, &table); + + /* TODO make it like station dump */ + if (cb->args[2]) + goto out; + + list_for_each_entry(dev, &table->devices, list) { + if (nl802154_send_device(skb, NL802154_CMD_NEW_SEC_LEVEL, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wpan_dev->netdev, dev) < 0) { + /* TODO */ + err = -EIO; + rdev_unlock_llsec_table(rdev, wpan_dev); + goto out_err; + } + } + + cb->args[2] = 1; + +out: + rdev_unlock_llsec_table(rdev, wpan_dev); + err = skb->len; +out_err: + nl802154_finish_wpan_dev_dump(rdev); + + return err; +} + +static const struct nla_policy nl802154_dev_policy[NL802154_DEV_ATTR_MAX + 1] = { + [NL802154_DEV_ATTR_FRAME_COUNTER] = { NLA_U32 }, + [NL802154_DEV_ATTR_PAN_ID] = { .type = NLA_U16 }, + [NL802154_DEV_ATTR_SHORT_ADDR] = { .type = NLA_U16 }, + [NL802154_DEV_ATTR_EXTENDED_ADDR] = { .type = NLA_U64 }, + [NL802154_DEV_ATTR_SECLEVEL_EXEMPT] = { NLA_U8 }, + [NL802154_DEV_ATTR_KEY_MODE] = { NLA_U32 }, +}; + +static int +ieee802154_llsec_parse_device(struct nlattr *nla, + struct ieee802154_llsec_device *dev) +{ + struct nlattr *attrs[NL802154_DEV_ATTR_MAX + 1]; + + if (!nla || nla_parse_nested_deprecated(attrs, NL802154_DEV_ATTR_MAX, nla, nl802154_dev_policy, NULL)) + return -EINVAL; + + memset(dev, 0, sizeof(*dev)); + + if (!attrs[NL802154_DEV_ATTR_FRAME_COUNTER] || + !attrs[NL802154_DEV_ATTR_PAN_ID] || + !attrs[NL802154_DEV_ATTR_SHORT_ADDR] || + !attrs[NL802154_DEV_ATTR_EXTENDED_ADDR] || + !attrs[NL802154_DEV_ATTR_SECLEVEL_EXEMPT] || + !attrs[NL802154_DEV_ATTR_KEY_MODE]) + return -EINVAL; + + /* TODO be32 */ + dev->frame_counter = nla_get_u32(attrs[NL802154_DEV_ATTR_FRAME_COUNTER]); + dev->pan_id = nla_get_le16(attrs[NL802154_DEV_ATTR_PAN_ID]); + dev->short_addr = nla_get_le16(attrs[NL802154_DEV_ATTR_SHORT_ADDR]); + /* TODO rename hwaddr to extended_addr */ + dev->hwaddr = nla_get_le64(attrs[NL802154_DEV_ATTR_EXTENDED_ADDR]); + dev->seclevel_exempt = nla_get_u8(attrs[NL802154_DEV_ATTR_SECLEVEL_EXEMPT]); + dev->key_mode = nla_get_u32(attrs[NL802154_DEV_ATTR_KEY_MODE]); + + if (dev->key_mode > NL802154_DEVKEY_MAX || + (dev->seclevel_exempt != 0 && dev->seclevel_exempt != 1)) + return -EINVAL; + + return 0; +} + +static int nl802154_add_llsec_dev(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct ieee802154_llsec_device dev_desc; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return 
-EOPNOTSUPP; + + if (ieee802154_llsec_parse_device(info->attrs[NL802154_ATTR_SEC_DEVICE], + &dev_desc) < 0) + return -EINVAL; + + return rdev_add_device(rdev, wpan_dev, &dev_desc); +} + +static int nl802154_del_llsec_dev(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct nlattr *attrs[NL802154_DEV_ATTR_MAX + 1]; + __le64 extended_addr; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_DEVICE] || + nla_parse_nested_deprecated(attrs, NL802154_DEV_ATTR_MAX, info->attrs[NL802154_ATTR_SEC_DEVICE], nl802154_dev_policy, info->extack)) + return -EINVAL; + + if (!attrs[NL802154_DEV_ATTR_EXTENDED_ADDR]) + return -EINVAL; + + extended_addr = nla_get_le64(attrs[NL802154_DEV_ATTR_EXTENDED_ADDR]); + return rdev_del_device(rdev, wpan_dev, extended_addr); +} + +static int nl802154_send_devkey(struct sk_buff *msg, u32 cmd, u32 portid, + u32 seq, int flags, + struct cfg802154_registered_device *rdev, + struct net_device *dev, __le64 extended_addr, + const struct ieee802154_llsec_device_key *devkey) +{ + void *hdr; + struct nlattr *nl_devkey, *nl_key_id; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + nl_devkey = nla_nest_start_noflag(msg, NL802154_ATTR_SEC_DEVKEY); + if (!nl_devkey) + goto nla_put_failure; + + if (nla_put_le64(msg, NL802154_DEVKEY_ATTR_EXTENDED_ADDR, + extended_addr, NL802154_DEVKEY_ATTR_PAD) || + nla_put_u32(msg, NL802154_DEVKEY_ATTR_FRAME_COUNTER, + devkey->frame_counter)) + goto nla_put_failure; + + nl_key_id = nla_nest_start_noflag(msg, NL802154_DEVKEY_ATTR_ID); + if (!nl_key_id) + goto nla_put_failure; + + if (ieee802154_llsec_send_key_id(msg, &devkey->key_id) < 0) + goto nla_put_failure; + + nla_nest_end(msg, nl_key_id); + nla_nest_end(msg, nl_devkey); + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl802154_dump_llsec_devkey(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct cfg802154_registered_device *rdev = NULL; + struct ieee802154_llsec_device_key *kpos; + struct ieee802154_llsec_device *dpos; + struct ieee802154_llsec_table *table; + struct wpan_dev *wpan_dev; + int err; + + err = nl802154_prepare_wpan_dev_dump(skb, cb, &rdev, &wpan_dev); + if (err) + return err; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) { + err = skb->len; + goto out_err; + } + + if (!wpan_dev->netdev) { + err = -EINVAL; + goto out_err; + } + + rdev_lock_llsec_table(rdev, wpan_dev); + rdev_get_llsec_table(rdev, wpan_dev, &table); + + /* TODO make it like station dump */ + if (cb->args[2]) + goto out; + + /* TODO look if remove devkey and do some nested attribute */ + list_for_each_entry(dpos, &table->devices, list) { + list_for_each_entry(kpos, &dpos->keys, list) { + if (nl802154_send_devkey(skb, + NL802154_CMD_NEW_SEC_LEVEL, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, rdev, + wpan_dev->netdev, + dpos->hwaddr, + kpos) < 0) { + /* TODO */ + err = -EIO; + rdev_unlock_llsec_table(rdev, wpan_dev); + goto out_err; + } + } + } + + cb->args[2] = 1; + +out: + rdev_unlock_llsec_table(rdev, wpan_dev); + err = skb->len; +out_err: + nl802154_finish_wpan_dev_dump(rdev); + + return err; +} + +static const struct nla_policy 
nl802154_devkey_policy[NL802154_DEVKEY_ATTR_MAX + 1] = { + [NL802154_DEVKEY_ATTR_FRAME_COUNTER] = { NLA_U32 }, + [NL802154_DEVKEY_ATTR_EXTENDED_ADDR] = { NLA_U64 }, + [NL802154_DEVKEY_ATTR_ID] = { NLA_NESTED }, +}; + +static int nl802154_add_llsec_devkey(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct nlattr *attrs[NL802154_DEVKEY_ATTR_MAX + 1]; + struct ieee802154_llsec_device_key key; + __le64 extended_addr; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_DEVKEY] || + nla_parse_nested_deprecated(attrs, NL802154_DEVKEY_ATTR_MAX, info->attrs[NL802154_ATTR_SEC_DEVKEY], nl802154_devkey_policy, info->extack) < 0) + return -EINVAL; + + if (!attrs[NL802154_DEVKEY_ATTR_FRAME_COUNTER] || + !attrs[NL802154_DEVKEY_ATTR_EXTENDED_ADDR]) + return -EINVAL; + + /* TODO change key.id ? */ + if (ieee802154_llsec_parse_key_id(attrs[NL802154_DEVKEY_ATTR_ID], + &key.key_id) < 0) + return -ENOBUFS; + + /* TODO be32 */ + key.frame_counter = nla_get_u32(attrs[NL802154_DEVKEY_ATTR_FRAME_COUNTER]); + /* TODO change naming hwaddr -> extended_addr + * check unique identifier short+pan OR extended_addr + */ + extended_addr = nla_get_le64(attrs[NL802154_DEVKEY_ATTR_EXTENDED_ADDR]); + return rdev_add_devkey(rdev, wpan_dev, extended_addr, &key); +} + +static int nl802154_del_llsec_devkey(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct nlattr *attrs[NL802154_DEVKEY_ATTR_MAX + 1]; + struct ieee802154_llsec_device_key key; + __le64 extended_addr; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_DEVKEY] || + nla_parse_nested_deprecated(attrs, NL802154_DEVKEY_ATTR_MAX, info->attrs[NL802154_ATTR_SEC_DEVKEY], nl802154_devkey_policy, info->extack)) + return -EINVAL; + + if (!attrs[NL802154_DEVKEY_ATTR_EXTENDED_ADDR]) + return -EINVAL; + + /* TODO change key.id ? 
*/ + if (ieee802154_llsec_parse_key_id(attrs[NL802154_DEVKEY_ATTR_ID], + &key.key_id) < 0) + return -ENOBUFS; + + /* TODO change naming hwaddr -> extended_addr + * check unique identifier short+pan OR extended_addr + */ + extended_addr = nla_get_le64(attrs[NL802154_DEVKEY_ATTR_EXTENDED_ADDR]); + return rdev_del_devkey(rdev, wpan_dev, extended_addr, &key); +} + +static int nl802154_send_seclevel(struct sk_buff *msg, u32 cmd, u32 portid, + u32 seq, int flags, + struct cfg802154_registered_device *rdev, + struct net_device *dev, + const struct ieee802154_llsec_seclevel *sl) +{ + void *hdr; + struct nlattr *nl_seclevel; + + hdr = nl802154hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL802154_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + nl_seclevel = nla_nest_start_noflag(msg, NL802154_ATTR_SEC_LEVEL); + if (!nl_seclevel) + goto nla_put_failure; + + if (nla_put_u32(msg, NL802154_SECLEVEL_ATTR_FRAME, sl->frame_type) || + nla_put_u32(msg, NL802154_SECLEVEL_ATTR_LEVELS, sl->sec_levels) || + nla_put_u8(msg, NL802154_SECLEVEL_ATTR_DEV_OVERRIDE, + sl->device_override)) + goto nla_put_failure; + + if (sl->frame_type == NL802154_FRAME_CMD) { + if (nla_put_u32(msg, NL802154_SECLEVEL_ATTR_CMD_FRAME, + sl->cmd_frame_id)) + goto nla_put_failure; + } + + nla_nest_end(msg, nl_seclevel); + genlmsg_end(msg, hdr); + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl802154_dump_llsec_seclevel(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct cfg802154_registered_device *rdev = NULL; + struct ieee802154_llsec_seclevel *sl; + struct ieee802154_llsec_table *table; + struct wpan_dev *wpan_dev; + int err; + + err = nl802154_prepare_wpan_dev_dump(skb, cb, &rdev, &wpan_dev); + if (err) + return err; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) { + err = skb->len; + goto out_err; + } + + if (!wpan_dev->netdev) { + err = -EINVAL; + goto out_err; + } + + rdev_lock_llsec_table(rdev, wpan_dev); + rdev_get_llsec_table(rdev, wpan_dev, &table); + + /* TODO make it like station dump */ + if (cb->args[2]) + goto out; + + list_for_each_entry(sl, &table->security_levels, list) { + if (nl802154_send_seclevel(skb, NL802154_CMD_NEW_SEC_LEVEL, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wpan_dev->netdev, sl) < 0) { + /* TODO */ + err = -EIO; + rdev_unlock_llsec_table(rdev, wpan_dev); + goto out_err; + } + } + + cb->args[2] = 1; + +out: + rdev_unlock_llsec_table(rdev, wpan_dev); + err = skb->len; +out_err: + nl802154_finish_wpan_dev_dump(rdev); + + return err; +} + +static const struct nla_policy nl802154_seclevel_policy[NL802154_SECLEVEL_ATTR_MAX + 1] = { + [NL802154_SECLEVEL_ATTR_LEVELS] = { .type = NLA_U8 }, + [NL802154_SECLEVEL_ATTR_FRAME] = { .type = NLA_U32 }, + [NL802154_SECLEVEL_ATTR_CMD_FRAME] = { .type = NLA_U32 }, + [NL802154_SECLEVEL_ATTR_DEV_OVERRIDE] = { .type = NLA_U8 }, +}; + +static int +llsec_parse_seclevel(struct nlattr *nla, struct ieee802154_llsec_seclevel *sl) +{ + struct nlattr *attrs[NL802154_SECLEVEL_ATTR_MAX + 1]; + + if (!nla || nla_parse_nested_deprecated(attrs, NL802154_SECLEVEL_ATTR_MAX, nla, nl802154_seclevel_policy, NULL)) + return -EINVAL; + + memset(sl, 0, sizeof(*sl)); + + if (!attrs[NL802154_SECLEVEL_ATTR_LEVELS] || + !attrs[NL802154_SECLEVEL_ATTR_FRAME] || + !attrs[NL802154_SECLEVEL_ATTR_DEV_OVERRIDE]) + return -EINVAL; + + sl->sec_levels = nla_get_u8(attrs[NL802154_SECLEVEL_ATTR_LEVELS]); + sl->frame_type = 
nla_get_u32(attrs[NL802154_SECLEVEL_ATTR_FRAME]); + sl->device_override = nla_get_u8(attrs[NL802154_SECLEVEL_ATTR_DEV_OVERRIDE]); + if (sl->frame_type > NL802154_FRAME_MAX || + (sl->device_override != 0 && sl->device_override != 1)) + return -EINVAL; + + if (sl->frame_type == NL802154_FRAME_CMD) { + if (!attrs[NL802154_SECLEVEL_ATTR_CMD_FRAME]) + return -EINVAL; + + sl->cmd_frame_id = nla_get_u32(attrs[NL802154_SECLEVEL_ATTR_CMD_FRAME]); + if (sl->cmd_frame_id > NL802154_CMD_FRAME_MAX) + return -EINVAL; + } + + return 0; +} + +static int nl802154_add_llsec_seclevel(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct ieee802154_llsec_seclevel sl; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (llsec_parse_seclevel(info->attrs[NL802154_ATTR_SEC_LEVEL], + &sl) < 0) + return -EINVAL; + + return rdev_add_seclevel(rdev, wpan_dev, &sl); +} + +static int nl802154_del_llsec_seclevel(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg802154_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wpan_dev *wpan_dev = dev->ieee802154_ptr; + struct ieee802154_llsec_seclevel sl; + + if (wpan_dev->iftype == NL802154_IFTYPE_MONITOR) + return -EOPNOTSUPP; + + if (!info->attrs[NL802154_ATTR_SEC_LEVEL] || + llsec_parse_seclevel(info->attrs[NL802154_ATTR_SEC_LEVEL], + &sl) < 0) + return -EINVAL; + + return rdev_del_seclevel(rdev, wpan_dev, &sl); +} +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + +#define NL802154_FLAG_NEED_WPAN_PHY 0x01 +#define NL802154_FLAG_NEED_NETDEV 0x02 +#define NL802154_FLAG_NEED_RTNL 0x04 +#define NL802154_FLAG_CHECK_NETDEV_UP 0x08 +#define NL802154_FLAG_NEED_WPAN_DEV 0x10 + +static int nl802154_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg802154_registered_device *rdev; + struct wpan_dev *wpan_dev; + struct net_device *dev; + bool rtnl = ops->internal_flags & NL802154_FLAG_NEED_RTNL; + + if (rtnl) + rtnl_lock(); + + if (ops->internal_flags & NL802154_FLAG_NEED_WPAN_PHY) { + rdev = cfg802154_get_dev_from_info(genl_info_net(info), info); + if (IS_ERR(rdev)) { + if (rtnl) + rtnl_unlock(); + return PTR_ERR(rdev); + } + info->user_ptr[0] = rdev; + } else if (ops->internal_flags & NL802154_FLAG_NEED_NETDEV || + ops->internal_flags & NL802154_FLAG_NEED_WPAN_DEV) { + ASSERT_RTNL(); + wpan_dev = __cfg802154_wpan_dev_from_attrs(genl_info_net(info), + info->attrs); + if (IS_ERR(wpan_dev)) { + if (rtnl) + rtnl_unlock(); + return PTR_ERR(wpan_dev); + } + + dev = wpan_dev->netdev; + rdev = wpan_phy_to_rdev(wpan_dev->wpan_phy); + + if (ops->internal_flags & NL802154_FLAG_NEED_NETDEV) { + if (!dev) { + if (rtnl) + rtnl_unlock(); + return -EINVAL; + } + + info->user_ptr[1] = dev; + } else { + info->user_ptr[1] = wpan_dev; + } + + if (dev) { + if (ops->internal_flags & NL802154_FLAG_CHECK_NETDEV_UP && + !netif_running(dev)) { + if (rtnl) + rtnl_unlock(); + return -ENETDOWN; + } + + dev_hold(dev); + } + + info->user_ptr[0] = rdev; + } + + return 0; +} + +static void nl802154_post_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + if (info->user_ptr[1]) { + if (ops->internal_flags & NL802154_FLAG_NEED_WPAN_DEV) { + struct wpan_dev *wpan_dev = info->user_ptr[1]; + + dev_put(wpan_dev->netdev); + } else { + dev_put(info->user_ptr[1]); + } + } + + if 
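The NL802154_FLAG_* bits attached to each operation tell nl802154_pre_doit() above which shared setup to perform before the handler runs: take the RTNL, resolve a wpan_phy or a netdev/wpan_dev, and optionally require the interface to be up. A compact sketch of that flags-driven pre-hook pattern; the flag names, handlers and messages are illustrative only:

    #include <stdio.h>

    #define FLAG_NEED_PHY   0x01
    #define FLAG_NEED_DEV   0x02
    #define FLAG_NEED_LOCK  0x04
    #define FLAG_CHECK_UP   0x08

    struct op {
        const char *name;
        unsigned int flags;
    };

    /* Perform only the setup steps the operation's flag bits request. */
    static void pre_doit(const struct op *op)
    {
        if (op->flags & FLAG_NEED_LOCK)
            printf("%s: take lock\n", op->name);
        if (op->flags & FLAG_NEED_PHY)
            printf("%s: resolve phy\n", op->name);
        if (op->flags & FLAG_NEED_DEV)
            printf("%s: resolve netdev\n", op->name);
        if (op->flags & FLAG_CHECK_UP)
            printf("%s: require interface up\n", op->name);
    }

    int main(void)
    {
        struct op set_channel = { "set_channel", FLAG_NEED_PHY | FLAG_NEED_LOCK };
        struct op set_pan_id  = { "set_pan_id",  FLAG_NEED_DEV | FLAG_NEED_LOCK };

        pre_doit(&set_channel);
        pre_doit(&set_pan_id);
        return 0;
    }

Keeping this logic in one pre/post pair, keyed off per-op flags, is what lets the individual handlers stay free of locking and lookup boilerplate.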
(ops->internal_flags & NL802154_FLAG_NEED_RTNL) + rtnl_unlock(); +} + +static const struct genl_ops nl802154_ops[] = { + { + .cmd = NL802154_CMD_GET_WPAN_PHY, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .doit = nl802154_get_wpan_phy, + .dumpit = nl802154_dump_wpan_phy, + .done = nl802154_dump_wpan_phy_done, + /* can be retrieved by unprivileged users */ + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_GET_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_get_interface, + .dumpit = nl802154_dump_interface, + /* can be retrieved by unprivileged users */ + .internal_flags = NL802154_FLAG_NEED_WPAN_DEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_NEW_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_new_interface, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_DEL_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_del_interface, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_DEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_CHANNEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_channel, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_CCA_MODE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_cca_mode, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_CCA_ED_LEVEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_cca_ed_level, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_TX_POWER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_tx_power, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_WPAN_PHY_NETNS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_wpan_phy_netns, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_WPAN_PHY | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_PAN_ID, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_pan_id, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_SHORT_ADDR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_short_addr, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_BACKOFF_EXPONENT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_backoff_exponent, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_MAX_CSMA_BACKOFFS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_max_csma_backoffs, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + 
NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_MAX_FRAME_RETRIES, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_max_frame_retries, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_LBT_MODE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_lbt_mode, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_SET_ACKREQ_DEFAULT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_ackreq_default, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL + { + .cmd = NL802154_CMD_SET_SEC_PARAMS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_set_llsec_params, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_GET_SEC_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + /* TODO .doit by matching key id? */ + .dumpit = nl802154_dump_llsec_key, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_NEW_SEC_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_add_llsec_key, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_DEL_SEC_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_del_llsec_key, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + /* TODO unique identifier must short+pan OR extended_addr */ + { + .cmd = NL802154_CMD_GET_SEC_DEV, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + /* TODO .doit by matching extended_addr? */ + .dumpit = nl802154_dump_llsec_dev, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_NEW_SEC_DEV, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_add_llsec_dev, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_DEL_SEC_DEV, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_del_llsec_dev, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + /* TODO remove complete devkey, put it as nested? */ + { + .cmd = NL802154_CMD_GET_SEC_DEVKEY, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + /* TODO doit by matching ??? 
*/ + .dumpit = nl802154_dump_llsec_devkey, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_NEW_SEC_DEVKEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_add_llsec_devkey, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_DEL_SEC_DEVKEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_del_llsec_devkey, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_GET_SEC_LEVEL, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + /* TODO .doit by matching frame_type? */ + .dumpit = nl802154_dump_llsec_seclevel, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_NEW_SEC_LEVEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl802154_add_llsec_seclevel, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, + { + .cmd = NL802154_CMD_DEL_SEC_LEVEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + /* TODO match frame_type only? */ + .doit = nl802154_del_llsec_seclevel, + .flags = GENL_ADMIN_PERM, + .internal_flags = NL802154_FLAG_NEED_NETDEV | + NL802154_FLAG_NEED_RTNL, + }, +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ +}; + +static struct genl_family nl802154_fam __ro_after_init = { + .name = NL802154_GENL_NAME, /* have users key off the name instead */ + .hdrsize = 0, /* no private header */ + .version = 1, /* no particular meaning now */ + .maxattr = NL802154_ATTR_MAX, + .policy = nl802154_policy, + .netnsok = true, + .pre_doit = nl802154_pre_doit, + .post_doit = nl802154_post_doit, + .module = THIS_MODULE, + .ops = nl802154_ops, + .n_ops = ARRAY_SIZE(nl802154_ops), + .resv_start_op = NL802154_CMD_DEL_SEC_LEVEL + 1, + .mcgrps = nl802154_mcgrps, + .n_mcgrps = ARRAY_SIZE(nl802154_mcgrps), +}; + +/* initialisation/exit functions */ +int __init nl802154_init(void) +{ + return genl_register_family(&nl802154_fam); +} + +void nl802154_exit(void) +{ + genl_unregister_family(&nl802154_fam); +} diff --git a/net/ieee802154/nl802154.h b/net/ieee802154/nl802154.h new file mode 100644 index 000000000..8c4b6d089 --- /dev/null +++ b/net/ieee802154/nl802154.h @@ -0,0 +1,8 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __IEEE802154_NL802154_H +#define __IEEE802154_NL802154_H + +int nl802154_init(void); +void nl802154_exit(void); + +#endif /* __IEEE802154_NL802154_H */ diff --git a/net/ieee802154/nl_policy.c b/net/ieee802154/nl_policy.c new file mode 100644 index 000000000..0672b2f01 --- /dev/null +++ b/net/ieee802154/nl_policy.c @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * nl802154.h + * + * Copyright (C) 2007, 2008 Siemens AG + */ + +#include +#include +#include + +#define NLA_HW_ADDR NLA_U64 + +const struct nla_policy ieee802154_policy[IEEE802154_ATTR_MAX + 1] = { + [IEEE802154_ATTR_DEV_NAME] = { .type = NLA_STRING, }, + [IEEE802154_ATTR_DEV_INDEX] = { .type = NLA_U32, }, + [IEEE802154_ATTR_PHY_NAME] = { .type = NLA_STRING, }, + + [IEEE802154_ATTR_STATUS] = { .type = NLA_U8, }, + [IEEE802154_ATTR_SHORT_ADDR] = { .type = NLA_U16, }, + [IEEE802154_ATTR_HW_ADDR] = { .type = NLA_HW_ADDR, }, + [IEEE802154_ATTR_PAN_ID] = { .type = NLA_U16, }, + 
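+	/* MLME attributes: channel and page, beacon/superframe order, and
+	 * coordinator, source and destination addressing.
+	 */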
[IEEE802154_ATTR_CHANNEL] = { .type = NLA_U8, }, + [IEEE802154_ATTR_BCN_ORD] = { .type = NLA_U8, }, + [IEEE802154_ATTR_SF_ORD] = { .type = NLA_U8, }, + [IEEE802154_ATTR_PAN_COORD] = { .type = NLA_U8, }, + [IEEE802154_ATTR_BAT_EXT] = { .type = NLA_U8, }, + [IEEE802154_ATTR_COORD_REALIGN] = { .type = NLA_U8, }, + [IEEE802154_ATTR_PAGE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_DEV_TYPE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_COORD_SHORT_ADDR] = { .type = NLA_U16, }, + [IEEE802154_ATTR_COORD_HW_ADDR] = { .type = NLA_HW_ADDR, }, + [IEEE802154_ATTR_COORD_PAN_ID] = { .type = NLA_U16, }, + [IEEE802154_ATTR_SRC_SHORT_ADDR] = { .type = NLA_U16, }, + [IEEE802154_ATTR_SRC_HW_ADDR] = { .type = NLA_HW_ADDR, }, + [IEEE802154_ATTR_SRC_PAN_ID] = { .type = NLA_U16, }, + [IEEE802154_ATTR_DEST_SHORT_ADDR] = { .type = NLA_U16, }, + [IEEE802154_ATTR_DEST_HW_ADDR] = { .type = NLA_HW_ADDR, }, + [IEEE802154_ATTR_DEST_PAN_ID] = { .type = NLA_U16, }, + + [IEEE802154_ATTR_CAPABILITY] = { .type = NLA_U8, }, + [IEEE802154_ATTR_REASON] = { .type = NLA_U8, }, + [IEEE802154_ATTR_SCAN_TYPE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_CHANNELS] = { .type = NLA_U32, }, + [IEEE802154_ATTR_DURATION] = { .type = NLA_U8, }, + [IEEE802154_ATTR_ED_LIST] = { .len = 27 }, + [IEEE802154_ATTR_CHANNEL_PAGE_LIST] = { .len = 32 * 4, }, + + [IEEE802154_ATTR_TXPOWER] = { .type = NLA_S8, }, + [IEEE802154_ATTR_LBT_ENABLED] = { .type = NLA_U8, }, + [IEEE802154_ATTR_CCA_MODE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_CCA_ED_LEVEL] = { .type = NLA_S32, }, + [IEEE802154_ATTR_CSMA_RETRIES] = { .type = NLA_U8, }, + [IEEE802154_ATTR_CSMA_MIN_BE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_CSMA_MAX_BE] = { .type = NLA_U8, }, + + [IEEE802154_ATTR_FRAME_RETRIES] = { .type = NLA_S8, }, + + [IEEE802154_ATTR_LLSEC_ENABLED] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_SECLEVEL] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_KEY_MODE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_KEY_SOURCE_SHORT] = { .type = NLA_U32, }, + [IEEE802154_ATTR_LLSEC_KEY_SOURCE_EXTENDED] = { .type = NLA_HW_ADDR, }, + [IEEE802154_ATTR_LLSEC_KEY_ID] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_FRAME_COUNTER] = { .type = NLA_U32 }, + [IEEE802154_ATTR_LLSEC_KEY_BYTES] = { .len = 16, }, + [IEEE802154_ATTR_LLSEC_KEY_USAGE_FRAME_TYPES] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_KEY_USAGE_COMMANDS] = { .len = 258 / 8 }, + [IEEE802154_ATTR_LLSEC_FRAME_TYPE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_CMD_FRAME_ID] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_SECLEVELS] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_DEV_OVERRIDE] = { .type = NLA_U8, }, + [IEEE802154_ATTR_LLSEC_DEV_KEY_MODE] = { .type = NLA_U8, }, +}; diff --git a/net/ieee802154/rdev-ops.h b/net/ieee802154/rdev-ops.h new file mode 100644 index 000000000..598f5af49 --- /dev/null +++ b/net/ieee802154/rdev-ops.h @@ -0,0 +1,321 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __CFG802154_RDEV_OPS +#define __CFG802154_RDEV_OPS + +#include + +#include "core.h" +#include "trace.h" + +static inline struct net_device * +rdev_add_virtual_intf_deprecated(struct cfg802154_registered_device *rdev, + const char *name, + unsigned char name_assign_type, + int type) +{ + return rdev->ops->add_virtual_intf_deprecated(&rdev->wpan_phy, name, + name_assign_type, type); +} + +static inline void +rdev_del_virtual_intf_deprecated(struct cfg802154_registered_device *rdev, + struct net_device *dev) +{ + rdev->ops->del_virtual_intf_deprecated(&rdev->wpan_phy, dev); +} + +static inline int +rdev_suspend(struct 
cfg802154_registered_device *rdev) +{ + int ret; + trace_802154_rdev_suspend(&rdev->wpan_phy); + ret = rdev->ops->suspend(&rdev->wpan_phy); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_resume(struct cfg802154_registered_device *rdev) +{ + int ret; + trace_802154_rdev_resume(&rdev->wpan_phy); + ret = rdev->ops->resume(&rdev->wpan_phy); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_add_virtual_intf(struct cfg802154_registered_device *rdev, char *name, + unsigned char name_assign_type, + enum nl802154_iftype type, __le64 extended_addr) +{ + int ret; + + trace_802154_rdev_add_virtual_intf(&rdev->wpan_phy, name, type, + extended_addr); + ret = rdev->ops->add_virtual_intf(&rdev->wpan_phy, name, + name_assign_type, type, + extended_addr); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_del_virtual_intf(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev) +{ + int ret; + + trace_802154_rdev_del_virtual_intf(&rdev->wpan_phy, wpan_dev); + ret = rdev->ops->del_virtual_intf(&rdev->wpan_phy, wpan_dev); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_channel(struct cfg802154_registered_device *rdev, u8 page, u8 channel) +{ + int ret; + + trace_802154_rdev_set_channel(&rdev->wpan_phy, page, channel); + ret = rdev->ops->set_channel(&rdev->wpan_phy, page, channel); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_cca_mode(struct cfg802154_registered_device *rdev, + const struct wpan_phy_cca *cca) +{ + int ret; + + trace_802154_rdev_set_cca_mode(&rdev->wpan_phy, cca); + ret = rdev->ops->set_cca_mode(&rdev->wpan_phy, cca); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_cca_ed_level(struct cfg802154_registered_device *rdev, s32 ed_level) +{ + int ret; + + trace_802154_rdev_set_cca_ed_level(&rdev->wpan_phy, ed_level); + ret = rdev->ops->set_cca_ed_level(&rdev->wpan_phy, ed_level); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_tx_power(struct cfg802154_registered_device *rdev, + s32 power) +{ + int ret; + + trace_802154_rdev_set_tx_power(&rdev->wpan_phy, power); + ret = rdev->ops->set_tx_power(&rdev->wpan_phy, power); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_pan_id(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, __le16 pan_id) +{ + int ret; + + trace_802154_rdev_set_pan_id(&rdev->wpan_phy, wpan_dev, pan_id); + ret = rdev->ops->set_pan_id(&rdev->wpan_phy, wpan_dev, pan_id); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_short_addr(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, __le16 short_addr) +{ + int ret; + + trace_802154_rdev_set_short_addr(&rdev->wpan_phy, wpan_dev, short_addr); + ret = rdev->ops->set_short_addr(&rdev->wpan_phy, wpan_dev, short_addr); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_backoff_exponent(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, u8 min_be, u8 max_be) +{ + int ret; + + trace_802154_rdev_set_backoff_exponent(&rdev->wpan_phy, wpan_dev, + min_be, max_be); + ret = rdev->ops->set_backoff_exponent(&rdev->wpan_phy, wpan_dev, + min_be, max_be); + 
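+	/* as in the other traced wrappers, record the driver's return value */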
trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_max_csma_backoffs(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, u8 max_csma_backoffs) +{ + int ret; + + trace_802154_rdev_set_csma_backoffs(&rdev->wpan_phy, wpan_dev, + max_csma_backoffs); + ret = rdev->ops->set_max_csma_backoffs(&rdev->wpan_phy, wpan_dev, + max_csma_backoffs); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_max_frame_retries(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, s8 max_frame_retries) +{ + int ret; + + trace_802154_rdev_set_max_frame_retries(&rdev->wpan_phy, wpan_dev, + max_frame_retries); + ret = rdev->ops->set_max_frame_retries(&rdev->wpan_phy, wpan_dev, + max_frame_retries); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_lbt_mode(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, bool mode) +{ + int ret; + + trace_802154_rdev_set_lbt_mode(&rdev->wpan_phy, wpan_dev, mode); + ret = rdev->ops->set_lbt_mode(&rdev->wpan_phy, wpan_dev, mode); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +static inline int +rdev_set_ackreq_default(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, bool ackreq) +{ + int ret; + + trace_802154_rdev_set_ackreq_default(&rdev->wpan_phy, wpan_dev, + ackreq); + ret = rdev->ops->set_ackreq_default(&rdev->wpan_phy, wpan_dev, ackreq); + trace_802154_rdev_return_int(&rdev->wpan_phy, ret); + return ret; +} + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL +/* TODO this is already a nl802154, so move into ieee802154 */ +static inline void +rdev_get_llsec_table(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + struct ieee802154_llsec_table **table) +{ + rdev->ops->get_llsec_table(&rdev->wpan_phy, wpan_dev, table); +} + +static inline void +rdev_lock_llsec_table(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev) +{ + rdev->ops->lock_llsec_table(&rdev->wpan_phy, wpan_dev); +} + +static inline void +rdev_unlock_llsec_table(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev) +{ + rdev->ops->unlock_llsec_table(&rdev->wpan_phy, wpan_dev); +} + +static inline int +rdev_get_llsec_params(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + struct ieee802154_llsec_params *params) +{ + return rdev->ops->get_llsec_params(&rdev->wpan_phy, wpan_dev, params); +} + +static inline int +rdev_set_llsec_params(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_params *params, + u32 changed) +{ + return rdev->ops->set_llsec_params(&rdev->wpan_phy, wpan_dev, params, + changed); +} + +static inline int +rdev_add_llsec_key(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key) +{ + return rdev->ops->add_llsec_key(&rdev->wpan_phy, wpan_dev, id, key); +} + +static inline int +rdev_del_llsec_key(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_key_id *id) +{ + return rdev->ops->del_llsec_key(&rdev->wpan_phy, wpan_dev, id); +} + +static inline int +rdev_add_seclevel(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_seclevel *sl) +{ + return rdev->ops->add_seclevel(&rdev->wpan_phy, wpan_dev, sl); +} + +static inline int 
+rdev_del_seclevel(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_seclevel *sl) +{ + return rdev->ops->del_seclevel(&rdev->wpan_phy, wpan_dev, sl); +} + +static inline int +rdev_add_device(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_device *dev_desc) +{ + return rdev->ops->add_device(&rdev->wpan_phy, wpan_dev, dev_desc); +} + +static inline int +rdev_del_device(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, __le64 extended_addr) +{ + return rdev->ops->del_device(&rdev->wpan_phy, wpan_dev, extended_addr); +} + +static inline int +rdev_add_devkey(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, __le64 extended_addr, + const struct ieee802154_llsec_device_key *devkey) +{ + return rdev->ops->add_devkey(&rdev->wpan_phy, wpan_dev, extended_addr, + devkey); +} + +static inline int +rdev_del_devkey(struct cfg802154_registered_device *rdev, + struct wpan_dev *wpan_dev, __le64 extended_addr, + const struct ieee802154_llsec_device_key *devkey) +{ + return rdev->ops->del_devkey(&rdev->wpan_phy, wpan_dev, extended_addr, + devkey); +} +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + +#endif /* __CFG802154_RDEV_OPS */ diff --git a/net/ieee802154/socket.c b/net/ieee802154/socket.c new file mode 100644 index 000000000..1fa2fe041 --- /dev/null +++ b/net/ieee802154/socket.c @@ -0,0 +1,1143 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IEEE802154.4 socket interface + * + * Copyright 2007, 2008 Siemens AG + * + * Written by: + * Sergey Lapin + * Maxim Gorbachyov + */ + +#include +#include +#include +#include +#include +#include /* For TIOCOUTQ/INQ */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +/* Utility function for families */ +static struct net_device* +ieee802154_get_dev(struct net *net, const struct ieee802154_addr *addr) +{ + struct net_device *dev = NULL; + struct net_device *tmp; + __le16 pan_id, short_addr; + u8 hwaddr[IEEE802154_ADDR_LEN]; + + switch (addr->mode) { + case IEEE802154_ADDR_LONG: + ieee802154_devaddr_to_raw(hwaddr, addr->extended_addr); + rcu_read_lock(); + dev = dev_getbyhwaddr_rcu(net, ARPHRD_IEEE802154, hwaddr); + dev_hold(dev); + rcu_read_unlock(); + break; + case IEEE802154_ADDR_SHORT: + if (addr->pan_id == cpu_to_le16(IEEE802154_PANID_BROADCAST) || + addr->short_addr == cpu_to_le16(IEEE802154_ADDR_UNDEF) || + addr->short_addr == cpu_to_le16(IEEE802154_ADDR_BROADCAST)) + break; + + rtnl_lock(); + + for_each_netdev(net, tmp) { + if (tmp->type != ARPHRD_IEEE802154) + continue; + + pan_id = tmp->ieee802154_ptr->pan_id; + short_addr = tmp->ieee802154_ptr->short_addr; + if (pan_id == addr->pan_id && + short_addr == addr->short_addr) { + dev = tmp; + dev_hold(dev); + break; + } + } + + rtnl_unlock(); + break; + default: + pr_warn("Unsupported ieee802154 address type: %d\n", + addr->mode); + break; + } + + return dev; +} + +static int ieee802154_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + sock->sk = NULL; + sk->sk_prot->close(sk, 0); + } + return 0; +} + +static int ieee802154_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + + return sk->sk_prot->sendmsg(sk, msg, len); +} + +static int ieee802154_sock_bind(struct socket *sock, struct sockaddr *uaddr, + int addr_len) +{ + struct sock *sk = sock->sk; + + if (sk->sk_prot->bind) + return sk->sk_prot->bind(sk, uaddr, 
addr_len); + + return sock_no_bind(sock, uaddr, addr_len); +} + +static int ieee802154_sock_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + + if (addr_len < sizeof(uaddr->sa_family)) + return -EINVAL; + + if (uaddr->sa_family == AF_UNSPEC) + return sk->sk_prot->disconnect(sk, flags); + + return sk->sk_prot->connect(sk, uaddr, addr_len); +} + +static int ieee802154_dev_ioctl(struct sock *sk, struct ifreq __user *arg, + unsigned int cmd) +{ + struct ifreq ifr; + int ret = -ENOIOCTLCMD; + struct net_device *dev; + + if (get_user_ifreq(&ifr, NULL, arg)) + return -EFAULT; + + ifr.ifr_name[IFNAMSIZ-1] = 0; + + dev_load(sock_net(sk), ifr.ifr_name); + dev = dev_get_by_name(sock_net(sk), ifr.ifr_name); + + if (!dev) + return -ENODEV; + + if (dev->type == ARPHRD_IEEE802154 && dev->netdev_ops->ndo_do_ioctl) + ret = dev->netdev_ops->ndo_do_ioctl(dev, &ifr, cmd); + + if (!ret && put_user_ifreq(&ifr, arg)) + ret = -EFAULT; + dev_put(dev); + + return ret; +} + +static int ieee802154_sock_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct sock *sk = sock->sk; + + switch (cmd) { + case SIOCGIFADDR: + case SIOCSIFADDR: + return ieee802154_dev_ioctl(sk, (struct ifreq __user *)arg, + cmd); + default: + if (!sk->sk_prot->ioctl) + return -ENOIOCTLCMD; + return sk->sk_prot->ioctl(sk, cmd, arg); + } +} + +/* RAW Sockets (802.15.4 created in userspace) */ +static HLIST_HEAD(raw_head); +static DEFINE_RWLOCK(raw_lock); + +static int raw_hash(struct sock *sk) +{ + write_lock_bh(&raw_lock); + sk_add_node(sk, &raw_head); + write_unlock_bh(&raw_lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + return 0; +} + +static void raw_unhash(struct sock *sk) +{ + write_lock_bh(&raw_lock); + if (sk_del_node_init(sk)) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + write_unlock_bh(&raw_lock); +} + +static void raw_close(struct sock *sk, long timeout) +{ + sk_common_release(sk); +} + +static int raw_bind(struct sock *sk, struct sockaddr *_uaddr, int len) +{ + struct ieee802154_addr addr; + struct sockaddr_ieee802154 *uaddr = (struct sockaddr_ieee802154 *)_uaddr; + int err = 0; + struct net_device *dev = NULL; + + err = ieee802154_sockaddr_check_size(uaddr, len); + if (err < 0) + return err; + + uaddr = (struct sockaddr_ieee802154 *)_uaddr; + if (uaddr->family != AF_IEEE802154) + return -EINVAL; + + lock_sock(sk); + + ieee802154_addr_from_sa(&addr, &uaddr->addr); + dev = ieee802154_get_dev(sock_net(sk), &addr); + if (!dev) { + err = -ENODEV; + goto out; + } + + sk->sk_bound_dev_if = dev->ifindex; + sk_dst_reset(sk); + + dev_put(dev); +out: + release_sock(sk); + + return err; +} + +static int raw_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + return -ENOTSUPP; +} + +static int raw_disconnect(struct sock *sk, int flags) +{ + return 0; +} + +static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct net_device *dev; + unsigned int mtu; + struct sk_buff *skb; + int hlen, tlen; + int err; + + if (msg->msg_flags & MSG_OOB) { + pr_debug("msg->msg_flags = 0x%x\n", msg->msg_flags); + return -EOPNOTSUPP; + } + + lock_sock(sk); + if (!sk->sk_bound_dev_if) + dev = dev_getfirstbyhwtype(sock_net(sk), ARPHRD_IEEE802154); + else + dev = dev_get_by_index(sock_net(sk), sk->sk_bound_dev_if); + release_sock(sk); + + if (!dev) { + pr_debug("no dev\n"); + err = -ENXIO; + goto out; + } + + mtu = IEEE802154_MTU; + pr_debug("name = %s, mtu = %u\n", dev->name, mtu); + + if (size > mtu) { + 
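+		/* payload does not fit into a single IEEE 802.15.4 frame */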
pr_debug("size = %zu, mtu = %u\n", size, mtu); + err = -EMSGSIZE; + goto out_dev; + } + if (!size) { + err = 0; + goto out_dev; + } + + hlen = LL_RESERVED_SPACE(dev); + tlen = dev->needed_tailroom; + skb = sock_alloc_send_skb(sk, hlen + tlen + size, + msg->msg_flags & MSG_DONTWAIT, &err); + if (!skb) + goto out_dev; + + skb_reserve(skb, hlen); + + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + + err = memcpy_from_msg(skb_put(skb, size), msg, size); + if (err < 0) + goto out_skb; + + skb->dev = dev; + skb->protocol = htons(ETH_P_IEEE802154); + + err = dev_queue_xmit(skb); + if (err > 0) + err = net_xmit_errno(err); + + dev_put(dev); + + return err ?: size; + +out_skb: + kfree_skb(skb); +out_dev: + dev_put(dev); +out: + return err; +} + +static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + size_t copied = 0; + int err = -EOPNOTSUPP; + struct sk_buff *skb; + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_cmsgs(msg, sk, skb); + + if (flags & MSG_TRUNC) + copied = skb->len; +done: + skb_free_datagram(sk, skb); +out: + if (err) + return err; + return copied; +} + +static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return NET_RX_DROP; + + if (sock_queue_rcv_skb(sk, skb) < 0) { + kfree_skb(skb); + return NET_RX_DROP; + } + + return NET_RX_SUCCESS; +} + +static void ieee802154_raw_deliver(struct net_device *dev, struct sk_buff *skb) +{ + struct sock *sk; + + read_lock(&raw_lock); + sk_for_each(sk, &raw_head) { + bh_lock_sock(sk); + if (!sk->sk_bound_dev_if || + sk->sk_bound_dev_if == dev->ifindex) { + struct sk_buff *clone; + + clone = skb_clone(skb, GFP_ATOMIC); + if (clone) + raw_rcv_skb(sk, clone); + } + bh_unlock_sock(sk); + } + read_unlock(&raw_lock); +} + +static int raw_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + return -EOPNOTSUPP; +} + +static int raw_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + return -EOPNOTSUPP; +} + +static struct proto ieee802154_raw_prot = { + .name = "IEEE-802.15.4-RAW", + .owner = THIS_MODULE, + .obj_size = sizeof(struct sock), + .close = raw_close, + .bind = raw_bind, + .sendmsg = raw_sendmsg, + .recvmsg = raw_recvmsg, + .hash = raw_hash, + .unhash = raw_unhash, + .connect = raw_connect, + .disconnect = raw_disconnect, + .getsockopt = raw_getsockopt, + .setsockopt = raw_setsockopt, +}; + +static const struct proto_ops ieee802154_raw_ops = { + .family = PF_IEEE802154, + .owner = THIS_MODULE, + .release = ieee802154_sock_release, + .bind = ieee802154_sock_bind, + .connect = ieee802154_sock_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = ieee802154_sock_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = ieee802154_sock_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +/* DGRAM Sockets (802.15.4 dataframes) */ +static HLIST_HEAD(dgram_head); +static DEFINE_RWLOCK(dgram_lock); + +struct dgram_sock { + struct sock sk; + + struct 
ieee802154_addr src_addr; + struct ieee802154_addr dst_addr; + + unsigned int bound:1; + unsigned int connected:1; + unsigned int want_ack:1; + unsigned int want_lqi:1; + unsigned int secen:1; + unsigned int secen_override:1; + unsigned int seclevel:3; + unsigned int seclevel_override:1; +}; + +static inline struct dgram_sock *dgram_sk(const struct sock *sk) +{ + return container_of(sk, struct dgram_sock, sk); +} + +static int dgram_hash(struct sock *sk) +{ + write_lock_bh(&dgram_lock); + sk_add_node(sk, &dgram_head); + write_unlock_bh(&dgram_lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + return 0; +} + +static void dgram_unhash(struct sock *sk) +{ + write_lock_bh(&dgram_lock); + if (sk_del_node_init(sk)) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + write_unlock_bh(&dgram_lock); +} + +static int dgram_init(struct sock *sk) +{ + struct dgram_sock *ro = dgram_sk(sk); + + ro->want_ack = 1; + ro->want_lqi = 0; + return 0; +} + +static void dgram_close(struct sock *sk, long timeout) +{ + sk_common_release(sk); +} + +static int dgram_bind(struct sock *sk, struct sockaddr *uaddr, int len) +{ + struct sockaddr_ieee802154 *addr = (struct sockaddr_ieee802154 *)uaddr; + struct ieee802154_addr haddr; + struct dgram_sock *ro = dgram_sk(sk); + int err = -EINVAL; + struct net_device *dev; + + lock_sock(sk); + + ro->bound = 0; + + err = ieee802154_sockaddr_check_size(addr, len); + if (err < 0) + goto out; + + if (addr->family != AF_IEEE802154) { + err = -EINVAL; + goto out; + } + + ieee802154_addr_from_sa(&haddr, &addr->addr); + dev = ieee802154_get_dev(sock_net(sk), &haddr); + if (!dev) { + err = -ENODEV; + goto out; + } + + if (dev->type != ARPHRD_IEEE802154) { + err = -ENODEV; + goto out_put; + } + + ro->src_addr = haddr; + + ro->bound = 1; + err = 0; +out_put: + dev_put(dev); +out: + release_sock(sk); + + return err; +} + +static int dgram_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: + { + int amount = sk_wmem_alloc_get(sk); + + return put_user(amount, (int __user *)arg); + } + + case SIOCINQ: + { + struct sk_buff *skb; + unsigned long amount; + + amount = 0; + spin_lock_bh(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + if (skb) { + /* We will only return the amount + * of this packet since that is all + * that will be read. 
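+			 * The MAC header length is subtracted, so only payload
+			 * bytes are reported.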
+ */ + amount = skb->len - ieee802154_hdr_length(skb); + } + spin_unlock_bh(&sk->sk_receive_queue.lock); + return put_user(amount, (int __user *)arg); + } + } + + return -ENOIOCTLCMD; +} + +/* FIXME: autobind */ +static int dgram_connect(struct sock *sk, struct sockaddr *uaddr, + int len) +{ + struct sockaddr_ieee802154 *addr = (struct sockaddr_ieee802154 *)uaddr; + struct dgram_sock *ro = dgram_sk(sk); + int err = 0; + + err = ieee802154_sockaddr_check_size(addr, len); + if (err < 0) + return err; + + if (addr->family != AF_IEEE802154) + return -EINVAL; + + lock_sock(sk); + + if (!ro->bound) { + err = -ENETUNREACH; + goto out; + } + + ieee802154_addr_from_sa(&ro->dst_addr, &addr->addr); + ro->connected = 1; + +out: + release_sock(sk); + return err; +} + +static int dgram_disconnect(struct sock *sk, int flags) +{ + struct dgram_sock *ro = dgram_sk(sk); + + lock_sock(sk); + ro->connected = 0; + release_sock(sk); + + return 0; +} + +static int dgram_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct net_device *dev; + unsigned int mtu; + struct sk_buff *skb; + struct ieee802154_mac_cb *cb; + struct dgram_sock *ro = dgram_sk(sk); + struct ieee802154_addr dst_addr; + DECLARE_SOCKADDR(struct sockaddr_ieee802154*, daddr, msg->msg_name); + int hlen, tlen; + int err; + + if (msg->msg_flags & MSG_OOB) { + pr_debug("msg->msg_flags = 0x%x\n", msg->msg_flags); + return -EOPNOTSUPP; + } + + if (msg->msg_name) { + if (ro->connected) + return -EISCONN; + if (msg->msg_namelen < IEEE802154_MIN_NAMELEN) + return -EINVAL; + err = ieee802154_sockaddr_check_size(daddr, msg->msg_namelen); + if (err < 0) + return err; + ieee802154_addr_from_sa(&dst_addr, &daddr->addr); + } else { + if (!ro->connected) + return -EDESTADDRREQ; + dst_addr = ro->dst_addr; + } + + if (!ro->bound) + dev = dev_getfirstbyhwtype(sock_net(sk), ARPHRD_IEEE802154); + else + dev = ieee802154_get_dev(sock_net(sk), &ro->src_addr); + + if (!dev) { + pr_debug("no dev\n"); + err = -ENXIO; + goto out; + } + mtu = IEEE802154_MTU; + pr_debug("name = %s, mtu = %u\n", dev->name, mtu); + + if (size > mtu) { + pr_debug("size = %zu, mtu = %u\n", size, mtu); + err = -EMSGSIZE; + goto out_dev; + } + + hlen = LL_RESERVED_SPACE(dev); + tlen = dev->needed_tailroom; + skb = sock_alloc_send_skb(sk, hlen + tlen + size, + msg->msg_flags & MSG_DONTWAIT, + &err); + if (!skb) + goto out_dev; + + skb_reserve(skb, hlen); + + skb_reset_network_header(skb); + + cb = mac_cb_init(skb); + cb->type = IEEE802154_FC_TYPE_DATA; + cb->ackreq = ro->want_ack; + cb->secen = ro->secen; + cb->secen_override = ro->secen_override; + cb->seclevel = ro->seclevel; + cb->seclevel_override = ro->seclevel_override; + + err = wpan_dev_hard_header(skb, dev, &dst_addr, + ro->bound ? 
&ro->src_addr : NULL, size); + if (err < 0) + goto out_skb; + + err = memcpy_from_msg(skb_put(skb, size), msg, size); + if (err < 0) + goto out_skb; + + skb->dev = dev; + skb->protocol = htons(ETH_P_IEEE802154); + + err = dev_queue_xmit(skb); + if (err > 0) + err = net_xmit_errno(err); + + dev_put(dev); + + return err ?: size; + +out_skb: + kfree_skb(skb); +out_dev: + dev_put(dev); +out: + return err; +} + +static int dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + size_t copied = 0; + int err = -EOPNOTSUPP; + struct sk_buff *skb; + struct dgram_sock *ro = dgram_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_ieee802154 *, saddr, msg->msg_name); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + /* FIXME: skip headers if necessary ?! */ + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_cmsgs(msg, sk, skb); + + if (saddr) { + /* Clear the implicit padding in struct sockaddr_ieee802154 + * (16 bits between 'family' and 'addr') and in struct + * ieee802154_addr_sa (16 bits at the end of the structure). + */ + memset(saddr, 0, sizeof(*saddr)); + + saddr->family = AF_IEEE802154; + ieee802154_addr_to_sa(&saddr->addr, &mac_cb(skb)->source); + *addr_len = sizeof(*saddr); + } + + if (ro->want_lqi) { + err = put_cmsg(msg, SOL_IEEE802154, WPAN_WANTLQI, + sizeof(uint8_t), &(mac_cb(skb)->lqi)); + if (err) + goto done; + } + + if (flags & MSG_TRUNC) + copied = skb->len; +done: + skb_free_datagram(sk, skb); +out: + if (err) + return err; + return copied; +} + +static int dgram_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return NET_RX_DROP; + + if (sock_queue_rcv_skb(sk, skb) < 0) { + kfree_skb(skb); + return NET_RX_DROP; + } + + return NET_RX_SUCCESS; +} + +static inline bool +ieee802154_match_sock(__le64 hw_addr, __le16 pan_id, __le16 short_addr, + struct dgram_sock *ro) +{ + if (!ro->bound) + return true; + + if (ro->src_addr.mode == IEEE802154_ADDR_LONG && + hw_addr == ro->src_addr.extended_addr) + return true; + + if (ro->src_addr.mode == IEEE802154_ADDR_SHORT && + pan_id == ro->src_addr.pan_id && + short_addr == ro->src_addr.short_addr) + return true; + + return false; +} + +static int ieee802154_dgram_deliver(struct net_device *dev, struct sk_buff *skb) +{ + struct sock *sk, *prev = NULL; + int ret = NET_RX_SUCCESS; + __le16 pan_id, short_addr; + __le64 hw_addr; + + /* Data frame processing */ + BUG_ON(dev->type != ARPHRD_IEEE802154); + + pan_id = dev->ieee802154_ptr->pan_id; + short_addr = dev->ieee802154_ptr->short_addr; + hw_addr = dev->ieee802154_ptr->extended_addr; + + read_lock(&dgram_lock); + sk_for_each(sk, &dgram_head) { + if (ieee802154_match_sock(hw_addr, pan_id, short_addr, + dgram_sk(sk))) { + if (prev) { + struct sk_buff *clone; + + clone = skb_clone(skb, GFP_ATOMIC); + if (clone) + dgram_rcv_skb(prev, clone); + } + + prev = sk; + } + } + + if (prev) { + dgram_rcv_skb(prev, skb); + } else { + kfree_skb(skb); + ret = NET_RX_DROP; + } + read_unlock(&dgram_lock); + + return ret; +} + +static int dgram_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct dgram_sock *ro = dgram_sk(sk); + + int val, len; + + if (level != SOL_IEEE802154) + return -EOPNOTSUPP; + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(int)); + + switch (optname) { + 
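+	/* all options are per-socket settings kept in struct dgram_sock and
+	 * are read as int values at level SOL_IEEE802154
+	 */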
case WPAN_WANTACK: + val = ro->want_ack; + break; + case WPAN_WANTLQI: + val = ro->want_lqi; + break; + case WPAN_SECURITY: + if (!ro->secen_override) + val = WPAN_SECURITY_DEFAULT; + else if (ro->secen) + val = WPAN_SECURITY_ON; + else + val = WPAN_SECURITY_OFF; + break; + case WPAN_SECURITY_LEVEL: + if (!ro->seclevel_override) + val = WPAN_SECURITY_LEVEL_DEFAULT; + else + val = ro->seclevel; + break; + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +static int dgram_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct dgram_sock *ro = dgram_sk(sk); + struct net *net = sock_net(sk); + int val; + int err = 0; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + lock_sock(sk); + + switch (optname) { + case WPAN_WANTACK: + ro->want_ack = !!val; + break; + case WPAN_WANTLQI: + ro->want_lqi = !!val; + break; + case WPAN_SECURITY: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN) && + !ns_capable(net->user_ns, CAP_NET_RAW)) { + err = -EPERM; + break; + } + + switch (val) { + case WPAN_SECURITY_DEFAULT: + ro->secen_override = 0; + break; + case WPAN_SECURITY_ON: + ro->secen_override = 1; + ro->secen = 1; + break; + case WPAN_SECURITY_OFF: + ro->secen_override = 1; + ro->secen = 0; + break; + default: + err = -EINVAL; + break; + } + break; + case WPAN_SECURITY_LEVEL: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN) && + !ns_capable(net->user_ns, CAP_NET_RAW)) { + err = -EPERM; + break; + } + + if (val < WPAN_SECURITY_LEVEL_DEFAULT || + val > IEEE802154_SCF_SECLEVEL_ENC_MIC128) { + err = -EINVAL; + } else if (val == WPAN_SECURITY_LEVEL_DEFAULT) { + ro->seclevel_override = 0; + } else { + ro->seclevel_override = 1; + ro->seclevel = val; + } + break; + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return err; +} + +static struct proto ieee802154_dgram_prot = { + .name = "IEEE-802.15.4-MAC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct dgram_sock), + .init = dgram_init, + .close = dgram_close, + .bind = dgram_bind, + .sendmsg = dgram_sendmsg, + .recvmsg = dgram_recvmsg, + .hash = dgram_hash, + .unhash = dgram_unhash, + .connect = dgram_connect, + .disconnect = dgram_disconnect, + .ioctl = dgram_ioctl, + .getsockopt = dgram_getsockopt, + .setsockopt = dgram_setsockopt, +}; + +static const struct proto_ops ieee802154_dgram_ops = { + .family = PF_IEEE802154, + .owner = THIS_MODULE, + .release = ieee802154_sock_release, + .bind = ieee802154_sock_bind, + .connect = ieee802154_sock_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = ieee802154_sock_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = ieee802154_sock_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static void ieee802154_sock_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); +} + +/* Create a socket. Initialise the socket, blank the addresses + * set the state. 
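+ * Only SOCK_RAW (requires CAP_NET_RAW) and SOCK_DGRAM types are supported,
+ * e.g. socket(PF_IEEE802154, SOCK_DGRAM, 0) from userspace.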
+ */ +static int ieee802154_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct sock *sk; + int rc; + struct proto *proto; + const struct proto_ops *ops; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + switch (sock->type) { + case SOCK_RAW: + rc = -EPERM; + if (!capable(CAP_NET_RAW)) + goto out; + proto = &ieee802154_raw_prot; + ops = &ieee802154_raw_ops; + break; + case SOCK_DGRAM: + proto = &ieee802154_dgram_prot; + ops = &ieee802154_dgram_ops; + break; + default: + rc = -ESOCKTNOSUPPORT; + goto out; + } + + rc = -ENOMEM; + sk = sk_alloc(net, PF_IEEE802154, GFP_KERNEL, proto, kern); + if (!sk) + goto out; + rc = 0; + + sock->ops = ops; + + sock_init_data(sock, sk); + sk->sk_destruct = ieee802154_sock_destruct; + sk->sk_family = PF_IEEE802154; + + /* Checksums on by default */ + sock_set_flag(sk, SOCK_ZAPPED); + + if (sk->sk_prot->hash) { + rc = sk->sk_prot->hash(sk); + if (rc) { + sk_common_release(sk); + goto out; + } + } + + if (sk->sk_prot->init) { + rc = sk->sk_prot->init(sk); + if (rc) + sk_common_release(sk); + } +out: + return rc; +} + +static const struct net_proto_family ieee802154_family_ops = { + .family = PF_IEEE802154, + .create = ieee802154_create, + .owner = THIS_MODULE, +}; + +static int ieee802154_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + if (!netif_running(dev)) + goto drop; + pr_debug("got frame, type %d, dev %p\n", dev->type, dev); +#ifdef DEBUG + print_hex_dump_bytes("ieee802154_rcv ", + DUMP_PREFIX_NONE, skb->data, skb->len); +#endif + + if (!net_eq(dev_net(dev), &init_net)) + goto drop; + + ieee802154_raw_deliver(dev, skb); + + if (dev->type != ARPHRD_IEEE802154) + goto drop; + + if (skb->pkt_type != PACKET_OTHERHOST) + return ieee802154_dgram_deliver(dev, skb); + +drop: + kfree_skb(skb); + return NET_RX_DROP; +} + +static struct packet_type ieee802154_packet_type = { + .type = htons(ETH_P_IEEE802154), + .func = ieee802154_rcv, +}; + +static int __init af_ieee802154_init(void) +{ + int rc; + + rc = proto_register(&ieee802154_raw_prot, 1); + if (rc) + goto out; + + rc = proto_register(&ieee802154_dgram_prot, 1); + if (rc) + goto err_dgram; + + /* Tell SOCKET that we are alive */ + rc = sock_register(&ieee802154_family_ops); + if (rc) + goto err_sock; + dev_add_pack(&ieee802154_packet_type); + + rc = 0; + goto out; + +err_sock: + proto_unregister(&ieee802154_dgram_prot); +err_dgram: + proto_unregister(&ieee802154_raw_prot); +out: + return rc; +} + +static void __exit af_ieee802154_remove(void) +{ + dev_remove_pack(&ieee802154_packet_type); + sock_unregister(PF_IEEE802154); + proto_unregister(&ieee802154_dgram_prot); + proto_unregister(&ieee802154_raw_prot); +} + +module_init(af_ieee802154_init); +module_exit(af_ieee802154_remove); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_IEEE802154); diff --git a/net/ieee802154/sysfs.c b/net/ieee802154/sysfs.c new file mode 100644 index 000000000..d29039338 --- /dev/null +++ b/net/ieee802154/sysfs.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * Alexander Aring + * + * Based on: net/wireless/sysfs.c + */ + +#include +#include + +#include + +#include "core.h" +#include "sysfs.h" +#include "rdev-ops.h" + +static inline struct cfg802154_registered_device * +dev_to_rdev(struct device *dev) +{ + return container_of(dev, struct cfg802154_registered_device, + wpan_phy.dev); +} + +#define SHOW_FMT(name, fmt, member) \ +static ssize_t name ## _show(struct device *dev, \ + struct 
device_attribute *attr, \ + char *buf) \ +{ \ + return sprintf(buf, fmt "\n", dev_to_rdev(dev)->member); \ +} \ +static DEVICE_ATTR_RO(name) + +SHOW_FMT(index, "%d", wpan_phy_idx); + +static ssize_t name_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct wpan_phy *wpan_phy = &dev_to_rdev(dev)->wpan_phy; + + return sprintf(buf, "%s\n", dev_name(&wpan_phy->dev)); +} +static DEVICE_ATTR_RO(name); + +static void wpan_phy_release(struct device *dev) +{ + struct cfg802154_registered_device *rdev = dev_to_rdev(dev); + + cfg802154_dev_free(rdev); +} + +static struct attribute *pmib_attrs[] = { + &dev_attr_index.attr, + &dev_attr_name.attr, + NULL, +}; +ATTRIBUTE_GROUPS(pmib); + +#ifdef CONFIG_PM_SLEEP +static int wpan_phy_suspend(struct device *dev) +{ + struct cfg802154_registered_device *rdev = dev_to_rdev(dev); + int ret = 0; + + if (rdev->ops->suspend) { + rtnl_lock(); + ret = rdev_suspend(rdev); + rtnl_unlock(); + } + + return ret; +} + +static int wpan_phy_resume(struct device *dev) +{ + struct cfg802154_registered_device *rdev = dev_to_rdev(dev); + int ret = 0; + + if (rdev->ops->resume) { + rtnl_lock(); + ret = rdev_resume(rdev); + rtnl_unlock(); + } + + return ret; +} + +static SIMPLE_DEV_PM_OPS(wpan_phy_pm_ops, wpan_phy_suspend, wpan_phy_resume); +#define WPAN_PHY_PM_OPS (&wpan_phy_pm_ops) +#else +#define WPAN_PHY_PM_OPS NULL +#endif + +struct class wpan_phy_class = { + .name = "ieee802154", + .dev_release = wpan_phy_release, + .dev_groups = pmib_groups, + .pm = WPAN_PHY_PM_OPS, +}; + +int wpan_phy_sysfs_init(void) +{ + return class_register(&wpan_phy_class); +} + +void wpan_phy_sysfs_exit(void) +{ + class_unregister(&wpan_phy_class); +} diff --git a/net/ieee802154/sysfs.h b/net/ieee802154/sysfs.h new file mode 100644 index 000000000..337545b63 --- /dev/null +++ b/net/ieee802154/sysfs.h @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __IEEE802154_SYSFS_H +#define __IEEE802154_SYSFS_H + +int wpan_phy_sysfs_init(void); +void wpan_phy_sysfs_exit(void); + +extern struct class wpan_phy_class; + +#endif /* __IEEE802154_SYSFS_H */ diff --git a/net/ieee802154/trace.c b/net/ieee802154/trace.c new file mode 100644 index 000000000..95f997fad --- /dev/null +++ b/net/ieee802154/trace.c @@ -0,0 +1,7 @@ +#include + +#ifndef __CHECKER__ +#define CREATE_TRACE_POINTS +#include "trace.h" + +#endif diff --git a/net/ieee802154/trace.h b/net/ieee802154/trace.h new file mode 100644 index 000000000..19c2e5d60 --- /dev/null +++ b/net/ieee802154/trace.h @@ -0,0 +1,319 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Based on net/wireless/trace.h */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM cfg802154 + +#if !defined(__RDEV_CFG802154_OPS_TRACE) || defined(TRACE_HEADER_MULTI_READ) +#define __RDEV_CFG802154_OPS_TRACE + +#include + +#include + +#define MAXNAME 32 +#define WPAN_PHY_ENTRY __array(char, wpan_phy_name, MAXNAME) +#define WPAN_PHY_ASSIGN strlcpy(__entry->wpan_phy_name, \ + wpan_phy_name(wpan_phy), \ + MAXNAME) +#define WPAN_PHY_PR_FMT "%s" +#define WPAN_PHY_PR_ARG __entry->wpan_phy_name + +#define WPAN_DEV_ENTRY __field(u32, identifier) +#define WPAN_DEV_ASSIGN (__entry->identifier) = (!IS_ERR_OR_NULL(wpan_dev) \ + ? 
wpan_dev->identifier : 0) +#define WPAN_DEV_PR_FMT "wpan_dev(%u)" +#define WPAN_DEV_PR_ARG (__entry->identifier) + +#define WPAN_CCA_ENTRY __field(enum nl802154_cca_modes, cca_mode) \ + __field(enum nl802154_cca_opts, cca_opt) +#define WPAN_CCA_ASSIGN \ + do { \ + (__entry->cca_mode) = cca->mode; \ + (__entry->cca_opt) = cca->opt; \ + } while (0) +#define WPAN_CCA_PR_FMT "cca_mode: %d, cca_opt: %d" +#define WPAN_CCA_PR_ARG __entry->cca_mode, __entry->cca_opt + +#define BOOL_TO_STR(bo) (bo) ? "true" : "false" + +/************************************************************* + * rdev->ops traces * + *************************************************************/ + +DECLARE_EVENT_CLASS(wpan_phy_only_evt, + TP_PROTO(struct wpan_phy *wpan_phy), + TP_ARGS(wpan_phy), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + ), + TP_printk(WPAN_PHY_PR_FMT, WPAN_PHY_PR_ARG) +); + +DEFINE_EVENT(wpan_phy_only_evt, 802154_rdev_suspend, + TP_PROTO(struct wpan_phy *wpan_phy), + TP_ARGS(wpan_phy) +); + +DEFINE_EVENT(wpan_phy_only_evt, 802154_rdev_resume, + TP_PROTO(struct wpan_phy *wpan_phy), + TP_ARGS(wpan_phy) +); + +TRACE_EVENT(802154_rdev_add_virtual_intf, + TP_PROTO(struct wpan_phy *wpan_phy, char *name, + enum nl802154_iftype type, __le64 extended_addr), + TP_ARGS(wpan_phy, name, type, extended_addr), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + __string(vir_intf_name, name ? name : "") + __field(enum nl802154_iftype, type) + __field(__le64, extended_addr) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + __assign_str(vir_intf_name, name ? name : ""); + __entry->type = type; + __entry->extended_addr = extended_addr; + ), + TP_printk(WPAN_PHY_PR_FMT ", virtual intf name: %s, type: %d, extended addr: 0x%llx", + WPAN_PHY_PR_ARG, __get_str(vir_intf_name), __entry->type, + __le64_to_cpu(__entry->extended_addr)) +); + +TRACE_EVENT(802154_rdev_del_virtual_intf, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev), + TP_ARGS(wpan_phy, wpan_dev), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + ), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT, WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG) +); + +TRACE_EVENT(802154_rdev_set_channel, + TP_PROTO(struct wpan_phy *wpan_phy, u8 page, u8 channel), + TP_ARGS(wpan_phy, page, channel), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + __field(u8, page) + __field(u8, channel) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + __entry->page = page; + __entry->channel = channel; + ), + TP_printk(WPAN_PHY_PR_FMT ", page: %d, channel: %d", WPAN_PHY_PR_ARG, + __entry->page, __entry->channel) +); + +TRACE_EVENT(802154_rdev_set_tx_power, + TP_PROTO(struct wpan_phy *wpan_phy, s32 power), + TP_ARGS(wpan_phy, power), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + __field(s32, power) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + __entry->power = power; + ), + TP_printk(WPAN_PHY_PR_FMT ", mbm: %d", WPAN_PHY_PR_ARG, + __entry->power) +); + +TRACE_EVENT(802154_rdev_set_cca_mode, + TP_PROTO(struct wpan_phy *wpan_phy, const struct wpan_phy_cca *cca), + TP_ARGS(wpan_phy, cca), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_CCA_ENTRY + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_CCA_ASSIGN; + ), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_CCA_PR_FMT, WPAN_PHY_PR_ARG, + WPAN_CCA_PR_ARG) +); + +TRACE_EVENT(802154_rdev_set_cca_ed_level, + TP_PROTO(struct wpan_phy *wpan_phy, s32 ed_level), + TP_ARGS(wpan_phy, ed_level), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + __field(s32, ed_level) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + __entry->ed_level 
= ed_level; + ), + TP_printk(WPAN_PHY_PR_FMT ", ed level: %d", WPAN_PHY_PR_ARG, + __entry->ed_level) +); + +DECLARE_EVENT_CLASS(802154_le16_template, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le16 le16arg), + TP_ARGS(wpan_phy, wpan_dev, le16arg), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(__le16, le16arg) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->le16arg = le16arg; + ), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT ", pan id: 0x%04x", + WPAN_PHY_PR_ARG, WPAN_DEV_PR_ARG, + __le16_to_cpu(__entry->le16arg)) +); + +DEFINE_EVENT(802154_le16_template, 802154_rdev_set_pan_id, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le16 le16arg), + TP_ARGS(wpan_phy, wpan_dev, le16arg) +); + +DEFINE_EVENT_PRINT(802154_le16_template, 802154_rdev_set_short_addr, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le16 le16arg), + TP_ARGS(wpan_phy, wpan_dev, le16arg), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT ", short addr: 0x%04x", + WPAN_PHY_PR_ARG, WPAN_DEV_PR_ARG, + __le16_to_cpu(__entry->le16arg)) +); + +TRACE_EVENT(802154_rdev_set_backoff_exponent, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + u8 min_be, u8 max_be), + TP_ARGS(wpan_phy, wpan_dev, min_be, max_be), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(u8, min_be) + __field(u8, max_be) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->min_be = min_be; + __entry->max_be = max_be; + ), + + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT + ", min be: %d, max be: %d", WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG, __entry->min_be, __entry->max_be) +); + +TRACE_EVENT(802154_rdev_set_csma_backoffs, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + u8 max_csma_backoffs), + TP_ARGS(wpan_phy, wpan_dev, max_csma_backoffs), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(u8, max_csma_backoffs) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->max_csma_backoffs = max_csma_backoffs; + ), + + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT + ", max csma backoffs: %d", WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG, __entry->max_csma_backoffs) +); + +TRACE_EVENT(802154_rdev_set_max_frame_retries, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + s8 max_frame_retries), + TP_ARGS(wpan_phy, wpan_dev, max_frame_retries), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(s8, max_frame_retries) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->max_frame_retries = max_frame_retries; + ), + + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT + ", max frame retries: %d", WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG, __entry->max_frame_retries) +); + +TRACE_EVENT(802154_rdev_set_lbt_mode, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + bool mode), + TP_ARGS(wpan_phy, wpan_dev, mode), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(bool, mode) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->mode = mode; + ), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT + ", lbt mode: %s", WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG, BOOL_TO_STR(__entry->mode)) +); + +TRACE_EVENT(802154_rdev_set_ackreq_default, + TP_PROTO(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + bool ackreq), + TP_ARGS(wpan_phy, wpan_dev, ackreq), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + WPAN_DEV_ENTRY + __field(bool, ackreq) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + WPAN_DEV_ASSIGN; + __entry->ackreq = 
ackreq; + ), + TP_printk(WPAN_PHY_PR_FMT ", " WPAN_DEV_PR_FMT + ", ackreq default: %s", WPAN_PHY_PR_ARG, + WPAN_DEV_PR_ARG, BOOL_TO_STR(__entry->ackreq)) +); + +TRACE_EVENT(802154_rdev_return_int, + TP_PROTO(struct wpan_phy *wpan_phy, int ret), + TP_ARGS(wpan_phy, ret), + TP_STRUCT__entry( + WPAN_PHY_ENTRY + __field(int, ret) + ), + TP_fast_assign( + WPAN_PHY_ASSIGN; + __entry->ret = ret; + ), + TP_printk(WPAN_PHY_PR_FMT ", returned: %d", WPAN_PHY_PR_ARG, + __entry->ret) +); + +#endif /* !__RDEV_CFG802154_OPS_TRACE || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/ife/Kconfig b/net/ife/Kconfig new file mode 100644 index 000000000..de36a5b91 --- /dev/null +++ b/net/ife/Kconfig @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IFE subsystem configuration +# + +menuconfig NET_IFE + tristate "Inter-FE based on IETF ForCES InterFE LFB" + default n + help + Say Y here to add support of IFE encapsulation protocol + For details refer to netdev01 paper: + "Distributing Linux Traffic Control Classifier-Action Subsystem" + Authors: Jamal Hadi Salim and Damascene M. Joachimpillai + + To compile this support as a module, choose M here: the module will + be called ife. diff --git a/net/ife/Makefile b/net/ife/Makefile new file mode 100644 index 000000000..1258fcb07 --- /dev/null +++ b/net/ife/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the IFE encapsulation protocol +# + +obj-$(CONFIG_NET_IFE) += ife.o diff --git a/net/ife/ife.c b/net/ife/ife.c new file mode 100644 index 000000000..be05b690b --- /dev/null +++ b/net/ife/ife.c @@ -0,0 +1,177 @@ +/* + * net/ife/ife.c - Inter-FE protocol based on ForCES WG InterFE LFB + * Copyright (c) 2015 Jamal Hadi Salim + * Copyright (c) 2017 Yotam Gigi + * + * Refer to: draft-ietf-forces-interfelfb-03 and netdev01 paper: + * "Distributing Linux Traffic Control Classifier-Action Subsystem" + * Authors: Jamal Hadi Salim and Damascene M. Joachimpillai + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ifeheadr { + __be16 metalen; + u8 tlv_data[]; +}; + +void *ife_encode(struct sk_buff *skb, u16 metalen) +{ + /* OUTERHDR:TOTMETALEN:{TLVHDR:Metadatum:TLVHDR..}:ORIGDATA + * where ORIGDATA = original ethernet header ... 
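+	 * (Layout sketch, following ife_tlv_meta_encode() below: each
+	 * metadatum is carried as a netlink-style TLV, i.e. a 4-byte header
+	 * of __be16 type and __be16 len, where len includes the header and
+	 * the value is padded out to a 4-byte boundary.  TOTMETALEN is the
+	 * sum of those padded TLVs plus the length field itself
+	 * (IFE_METAHDRLEN).  For example, a hypothetical metadatum of type
+	 * 0x0001 carrying 2 bytes of data is emitted as 8 bytes on the wire:
+	 * len field 6 (4 + 2), then 2 data bytes, then 2 bytes of zero padding.)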
+ */ + int hdrm = metalen + IFE_METAHDRLEN; + int total_push = hdrm + skb->dev->hard_header_len; + struct ifeheadr *ifehdr; + struct ethhdr *iethh; /* inner ether header */ + int skboff = 0; + int err; + + err = skb_cow_head(skb, total_push); + if (unlikely(err)) + return NULL; + + iethh = (struct ethhdr *) skb->data; + + __skb_push(skb, total_push); + memcpy(skb->data, iethh, skb->dev->hard_header_len); + skb_reset_mac_header(skb); + skboff += skb->dev->hard_header_len; + + /* total metadata length */ + ifehdr = (struct ifeheadr *) (skb->data + skboff); + metalen += IFE_METAHDRLEN; + ifehdr->metalen = htons(metalen); + + return ifehdr->tlv_data; +} +EXPORT_SYMBOL_GPL(ife_encode); + +void *ife_decode(struct sk_buff *skb, u16 *metalen) +{ + struct ifeheadr *ifehdr; + int total_pull; + u16 ifehdrln; + + if (!pskb_may_pull(skb, skb->dev->hard_header_len + IFE_METAHDRLEN)) + return NULL; + + ifehdr = (struct ifeheadr *) (skb->data + skb->dev->hard_header_len); + ifehdrln = ntohs(ifehdr->metalen); + total_pull = skb->dev->hard_header_len + ifehdrln; + + if (unlikely(ifehdrln < 2)) + return NULL; + + if (unlikely(!pskb_may_pull(skb, total_pull))) + return NULL; + + ifehdr = (struct ifeheadr *)(skb->data + skb->dev->hard_header_len); + skb_set_mac_header(skb, total_pull); + __skb_pull(skb, total_pull); + *metalen = ifehdrln - IFE_METAHDRLEN; + + return &ifehdr->tlv_data; +} +EXPORT_SYMBOL_GPL(ife_decode); + +struct meta_tlvhdr { + __be16 type; + __be16 len; +}; + +static bool __ife_tlv_meta_valid(const unsigned char *skbdata, + const unsigned char *ifehdr_end) +{ + const struct meta_tlvhdr *tlv; + u16 tlvlen; + + if (unlikely(skbdata + sizeof(*tlv) > ifehdr_end)) + return false; + + tlv = (const struct meta_tlvhdr *)skbdata; + tlvlen = ntohs(tlv->len); + + /* tlv length field is inc header, check on minimum */ + if (tlvlen < NLA_HDRLEN) + return false; + + /* overflow by NLA_ALIGN check */ + if (NLA_ALIGN(tlvlen) < tlvlen) + return false; + + if (unlikely(skbdata + NLA_ALIGN(tlvlen) > ifehdr_end)) + return false; + + return true; +} + +/* Caller takes care of presenting data in network order + */ +void *ife_tlv_meta_decode(void *skbdata, const void *ifehdr_end, u16 *attrtype, + u16 *dlen, u16 *totlen) +{ + struct meta_tlvhdr *tlv; + + if (!__ife_tlv_meta_valid(skbdata, ifehdr_end)) + return NULL; + + tlv = (struct meta_tlvhdr *)skbdata; + *dlen = ntohs(tlv->len) - NLA_HDRLEN; + *attrtype = ntohs(tlv->type); + + if (totlen) + *totlen = nla_total_size(*dlen); + + return skbdata + sizeof(struct meta_tlvhdr); +} +EXPORT_SYMBOL_GPL(ife_tlv_meta_decode); + +void *ife_tlv_meta_next(void *skbdata) +{ + struct meta_tlvhdr *tlv = (struct meta_tlvhdr *) skbdata; + u16 tlvlen = ntohs(tlv->len); + + tlvlen = NLA_ALIGN(tlvlen); + + return skbdata + tlvlen; +} +EXPORT_SYMBOL_GPL(ife_tlv_meta_next); + +/* Caller takes care of presenting data in network order + */ +int ife_tlv_meta_encode(void *skbdata, u16 attrtype, u16 dlen, const void *dval) +{ + __be32 *tlv = (__be32 *) (skbdata); + u16 totlen = nla_total_size(dlen); /*alignment + hdr */ + char *dptr = (char *) tlv + NLA_HDRLEN; + u32 htlv = attrtype << 16 | (dlen + NLA_HDRLEN); + + *tlv = htonl(htlv); + memset(dptr, 0, totlen - NLA_HDRLEN); + memcpy(dptr, dval, dlen); + + return totlen; +} +EXPORT_SYMBOL_GPL(ife_tlv_meta_encode); + +MODULE_AUTHOR("Jamal Hadi Salim "); +MODULE_AUTHOR("Yotam Gigi "); +MODULE_DESCRIPTION("Inter-FE LFB action"); +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig new file mode 100644 index 
000000000..2dfb12230 --- /dev/null +++ b/net/ipv4/Kconfig @@ -0,0 +1,753 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IP configuration +# +config IP_MULTICAST + bool "IP: multicasting" + help + This is code for addressing several networked computers at once, + enlarging your kernel by about 2 KB. You need multicasting if you + intend to participate in the MBONE, a high bandwidth network on top + of the Internet which carries audio and video broadcasts. More + information about the MBONE is on the WWW at + . For most people, it's safe to say N. + +config IP_ADVANCED_ROUTER + bool "IP: advanced router" + help + If you intend to run your Linux box mostly as a router, i.e. as a + computer that forwards and redistributes network packets, say Y; you + will then be presented with several options that allow more precise + control about the routing process. + + The answer to this question won't directly affect the kernel: + answering N will just cause the configurator to skip all the + questions about advanced routing. + + Note that your box can only act as a router if you enable IP + forwarding in your kernel; you can do that by saying Y to "/proc + file system support" and "Sysctl support" below and executing the + line + + echo "1" > /proc/sys/net/ipv4/ip_forward + + at boot time after the /proc file system has been mounted. + + If you turn on IP forwarding, you should consider the rp_filter, which + automatically rejects incoming packets if the routing table entry + for their source address doesn't match the network interface they're + arriving on. This has security advantages because it prevents the + so-called IP spoofing, however it can pose problems if you use + asymmetric routing (packets from you to a host take a different path + than packets from that host to you) or if you operate a non-routing + host which has several IP addresses on different interfaces. To turn + rp_filter on use: + + echo 1 > /proc/sys/net/ipv4/conf//rp_filter + or + echo 1 > /proc/sys/net/ipv4/conf/all/rp_filter + + Note that some distributions enable it in startup scripts. + For details about rp_filter strict and loose mode read + . + + If unsure, say N here. + +config IP_FIB_TRIE_STATS + bool "FIB TRIE statistics" + depends on IP_ADVANCED_ROUTER + help + Keep track of statistics on structure of FIB TRIE table. + Useful for testing and measuring TRIE performance. + +config IP_MULTIPLE_TABLES + bool "IP: policy routing" + depends on IP_ADVANCED_ROUTER + select FIB_RULES + help + Normally, a router decides what to do with a received packet based + solely on the packet's final destination address. If you say Y here, + the Linux router will also be able to take the packet's source + address into account. Furthermore, the TOS (Type-Of-Service) field + of the packet can be used for routing decisions as well. + + If you need more information, see the Linux Advanced + Routing and Traffic Control documentation at + + + If unsure, say N. + +config IP_ROUTE_MULTIPATH + bool "IP: equal cost multipath" + depends on IP_ADVANCED_ROUTER + help + Normally, the routing tables specify a single action to be taken in + a deterministic manner for a given packet. If you say Y here + however, it becomes possible to attach several actions to a packet + pattern, in effect specifying several alternative paths to travel + for those packets. The router considers all these paths to be of + equal "cost" and chooses one of them in a non-deterministic fashion + if a matching packet arrives. 
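+# (In practice the nexthop is chosen by hashing the flow; the
+# net.ipv4.fib_multipath_hash_policy sysctl selects which packet fields
+# feed that hash.)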
+ +config IP_ROUTE_VERBOSE + bool "IP: verbose route monitoring" + depends on IP_ADVANCED_ROUTER + help + If you say Y here, which is recommended, then the kernel will print + verbose messages regarding the routing, for example warnings about + received packets which look strange and could be evidence of an + attack or a misconfigured system somewhere. The information is + handled by the klogd daemon which is responsible for kernel messages + ("man klogd"). + +config IP_ROUTE_CLASSID + bool + +config IP_PNP + bool "IP: kernel level autoconfiguration" + help + This enables automatic configuration of IP addresses of devices and + of the routing table during kernel boot, based on either information + supplied on the kernel command line or by BOOTP or RARP protocols. + You need to say Y only for diskless machines requiring network + access to boot (in which case you want to say Y to "Root file system + on NFS" as well), because all other machines configure the network + in their startup scripts. + +config IP_PNP_DHCP + bool "IP: DHCP support" + depends on IP_PNP + help + If you want your Linux box to mount its whole root file system (the + one containing the directory /) from some other computer over the + net via NFS and you want the IP address of your computer to be + discovered automatically at boot time using the DHCP protocol (a + special protocol designed for doing this job), say Y here. In case + the boot ROM of your network card was designed for booting Linux and + does DHCP itself, providing all necessary information on the kernel + command line, you can say N here. + + If unsure, say Y. Note that if you want to use DHCP, a DHCP server + must be operating on your network. Read + for details. + +config IP_PNP_BOOTP + bool "IP: BOOTP support" + depends on IP_PNP + help + If you want your Linux box to mount its whole root file system (the + one containing the directory /) from some other computer over the + net via NFS and you want the IP address of your computer to be + discovered automatically at boot time using the BOOTP protocol (a + special protocol designed for doing this job), say Y here. In case + the boot ROM of your network card was designed for booting Linux and + does BOOTP itself, providing all necessary information on the kernel + command line, you can say N here. If unsure, say Y. Note that if you + want to use BOOTP, a BOOTP server must be operating on your network. + Read for details. + +config IP_PNP_RARP + bool "IP: RARP support" + depends on IP_PNP + help + If you want your Linux box to mount its whole root file system (the + one containing the directory /) from some other computer over the + net via NFS and you want the IP address of your computer to be + discovered automatically at boot time using the RARP protocol (an + older protocol which is being obsoleted by BOOTP and DHCP), say Y + here. Note that if you want to use RARP, a RARP server must be + operating on your network. Read + for details. + +config NET_IPIP + tristate "IP: tunneling" + select INET_TUNNEL + select NET_IP_TUNNEL + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. 
This particular tunneling driver implements + encapsulation of IP within IP, which sounds kind of pointless, but + can be useful if you want to make your (or some other) machine + appear on a different network than it physically is, or to use + mobile-IP facilities (allowing laptops to seamlessly move between + networks without changing their IP addresses). + + Saying Y to this option will produce two modules ( = code which can + be inserted in and removed from the running kernel whenever you + want). Most people won't need this and can say N. + +config NET_IPGRE_DEMUX + tristate "IP: GRE demultiplexer" + help + This is helper module to demultiplex GRE packets on GRE version field criteria. + Required by ip_gre and pptp modules. + +config NET_IP_TUNNEL + tristate + select DST_CACHE + select GRO_CELLS + default n + +config NET_IPGRE + tristate "IP: GRE tunnels over IP" + depends on (IPV6 || IPV6=n) && NET_IPGRE_DEMUX + select NET_IP_TUNNEL + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. This particular tunneling driver implements + GRE (Generic Routing Encapsulation) and at this time allows + encapsulating of IPv4 or IPv6 over existing IPv4 infrastructure. + This driver is useful if the other endpoint is a Cisco router: Cisco + likes GRE much better than the other Linux tunneling driver ("IP + tunneling" above). In addition, GRE allows multicast redistribution + through the tunnel. + +config NET_IPGRE_BROADCAST + bool "IP: broadcast GRE over IP" + depends on IP_MULTICAST && NET_IPGRE + help + One application of GRE/IP is to construct a broadcast WAN (Wide Area + Network), which looks like a normal Ethernet LAN (Local Area + Network), but can be distributed all over the Internet. If you want + to do that, say Y here and to "IP multicast routing" below. + +config IP_MROUTE_COMMON + bool + depends on IP_MROUTE || IPV6_MROUTE + +config IP_MROUTE + bool "IP: multicast routing" + depends on IP_MULTICAST + select IP_MROUTE_COMMON + help + This is used if you want your machine to act as a router for IP + packets that have several destination addresses. It is needed on the + MBONE, a high bandwidth network on top of the Internet which carries + audio and video broadcasts. In order to do that, you would most + likely run the program mrouted. If you haven't heard about it, you + don't need it. + +config IP_MROUTE_MULTIPLE_TABLES + bool "IP: multicast policy routing" + depends on IP_MROUTE && IP_ADVANCED_ROUTER + select FIB_RULES + help + Normally, a multicast router runs a userspace daemon and decides + what to do with a multicast packet based on the source and + destination addresses. If you say Y here, the multicast router + will also be able to take interfaces and packet marks into + account and run multiple instances of userspace daemons + simultaneously, each one handling a single table. + + If unsure, say N. + +config IP_PIMSM_V1 + bool "IP: PIM-SM version 1 support" + depends on IP_MROUTE + help + Kernel side support for Sparse Mode PIM (Protocol Independent + Multicast) version 1. This multicast routing protocol is used widely + because Cisco supports it. You need special software to use it + (pimd-v1). Please see for more + information about PIM. + + Say Y if you want to use PIM-SM v1. Note that you can say N here if + you just want to use Dense Mode PIM. 
+ +config IP_PIMSM_V2 + bool "IP: PIM-SM version 2 support" + depends on IP_MROUTE + help + Kernel side support for Sparse Mode PIM version 2. In order to use + this, you need an experimental routing daemon supporting it (pimd or + gated-5). This routing protocol is not used widely, so say N unless + you want to play with it. + +config SYN_COOKIES + bool "IP: TCP syncookie support" + help + Normal TCP/IP networking is open to an attack known as "SYN + flooding". This denial-of-service attack prevents legitimate remote + users from being able to connect to your computer during an ongoing + attack and requires very little work from the attacker, who can + operate from anywhere on the Internet. + + SYN cookies provide protection against this type of attack. If you + say Y here, the TCP/IP stack will use a cryptographic challenge + protocol known as "SYN cookies" to enable legitimate users to + continue to connect, even when your machine is under attack. There + is no need for the legitimate users to change their TCP/IP software; + SYN cookies work transparently to them. For technical information + about SYN cookies, check out . + + If you are SYN flooded, the source address reported by the kernel is + likely to have been forged by the attacker; it is only reported as + an aid in tracing the packets to their actual source and should not + be taken as absolute truth. + + SYN cookies may prevent correct error reporting on clients when the + server is really overloaded. If this happens frequently better turn + them off. + + If you say Y here, you can disable SYN cookies at run time by + saying Y to "/proc file system support" and + "Sysctl support" below and executing the command + + echo 0 > /proc/sys/net/ipv4/tcp_syncookies + + after the /proc file system has been mounted. + + If unsure, say N. + +config NET_IPVTI + tristate "Virtual (secure) IP: tunneling" + depends on IPV6 || IPV6=n + select INET_TUNNEL + select NET_IP_TUNNEL + select XFRM + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. This can be used with xfrm mode tunnel to give + the notion of a secure tunnel for IPSEC and then use routing protocol + on top. + +config NET_UDP_TUNNEL + tristate + select NET_IP_TUNNEL + default n + +config NET_FOU + tristate "IP: Foo (IP protocols) over UDP" + select NET_UDP_TUNNEL + help + Foo over UDP allows any IP protocol to be directly encapsulated + over UDP include tunnels (IPIP, GRE, SIT). By encapsulating in UDP + network mechanisms and optimizations for UDP (such as ECMP + and RSS) can be leveraged to provide better service. + +config NET_FOU_IP_TUNNELS + bool "IP: FOU encapsulation of IP tunnels" + depends on NET_IPIP || NET_IPGRE || IPV6_SIT + select NET_FOU + help + Allow configuration of FOU or GUE encapsulation for IP tunnels. + When this option is enabled IP tunnels can be configured to use + FOU or GUE encapsulation. + +config INET_AH + tristate "IP: AH transformation" + select XFRM_AH + help + Support for IPsec AH (Authentication Header). + + AH can be used with various authentication algorithms. Besides + enabling AH support itself, this option enables the generic + implementations of the algorithms that RFC 8221 lists as MUST be + implemented. If you need any other algorithms, you'll need to enable + them in the crypto API. You should also enable accelerated + implementations of any needed algorithms when available. + + If unsure, say Y. 
+ +config INET_ESP + tristate "IP: ESP transformation" + select XFRM_ESP + help + Support for IPsec ESP (Encapsulating Security Payload). + + ESP can be used with various encryption and authentication algorithms. + Besides enabling ESP support itself, this option enables the generic + implementations of the algorithms that RFC 8221 lists as MUST be + implemented. If you need any other algorithms, you'll need to enable + them in the crypto API. You should also enable accelerated + implementations of any needed algorithms when available. + + If unsure, say Y. + +config INET_ESP_OFFLOAD + tristate "IP: ESP transformation offload" + depends on INET_ESP + select XFRM_OFFLOAD + default n + help + Support for ESP transformation offload. This makes sense + only if this system really does IPsec and want to do it + with high throughput. A typical desktop system does not + need it, even if it does IPsec. + + If unsure, say N. + +config INET_ESPINTCP + bool "IP: ESP in TCP encapsulation (RFC 8229)" + depends on XFRM && INET_ESP + select STREAM_PARSER + select NET_SOCK_MSG + select XFRM_ESPINTCP + help + Support for RFC 8229 encapsulation of ESP and IKE over + TCP/IPv4 sockets. + + If unsure, say N. + +config INET_IPCOMP + tristate "IP: IPComp transformation" + select INET_XFRM_TUNNEL + select XFRM_IPCOMP + help + Support for IP Payload Compression Protocol (IPComp) (RFC3173), + typically needed for IPsec. + + If unsure, say Y. + +config INET_TABLE_PERTURB_ORDER + int "INET: Source port perturbation table size (as power of 2)" if EXPERT + default 16 + help + Source port perturbation table size (as power of 2) for + RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm. + + The default is almost always what you want. + Only change this if you know what you are doing. + +config INET_XFRM_TUNNEL + tristate + select INET_TUNNEL + default n + +config INET_TUNNEL + tristate + default n + +config INET_DIAG + tristate "INET: socket monitoring interface" + default y + help + Support for INET (TCP, DCCP, etc) socket monitoring interface used by + native Linux tools such as ss. ss is included in iproute2, currently + downloadable at: + + http://www.linuxfoundation.org/collaborate/workgroups/networking/iproute2 + + If unsure, say Y. + +config INET_TCP_DIAG + depends on INET_DIAG + def_tristate INET_DIAG + +config INET_UDP_DIAG + tristate "UDP: socket monitoring interface" + depends on INET_DIAG && (IPV6 || IPV6=n) + default n + help + Support for UDP socket monitoring interface used by the ss tool. + If unsure, say Y. + +config INET_RAW_DIAG + tristate "RAW: socket monitoring interface" + depends on INET_DIAG && (IPV6 || IPV6=n) + default n + help + Support for RAW socket monitoring interface used by the ss tool. + If unsure, say Y. + +config INET_DIAG_DESTROY + bool "INET: allow privileged process to administratively close sockets" + depends on INET_DIAG + default n + help + Provides a SOCK_DESTROY operation that allows privileged processes + (e.g., a connection manager or a network administration tool such as + ss) to close sockets opened by other processes. Closing a socket in + this way interrupts any blocking read/write/connect operations on + the socket and causes future socket calls to behave as if the socket + had been disconnected. + If unsure, say N. + +menuconfig TCP_CONG_ADVANCED + bool "TCP: advanced congestion control" + help + Support for selection of various TCP congestion control + modules. 
+ + Nearly all users can safely say no here, and a safe default + selection will be made (CUBIC with new Reno as a fallback). + + If unsure, say N. + +if TCP_CONG_ADVANCED + +config TCP_CONG_BIC + tristate "Binary Increase Congestion (BIC) control" + default m + help + BIC-TCP is a sender-side only change that ensures a linear RTT + fairness under large windows while offering both scalability and + bounded TCP-friendliness. The protocol combines two schemes + called additive increase and binary search increase. When the + congestion window is large, additive increase with a large + increment ensures linear RTT fairness as well as good + scalability. Under small congestion windows, binary search + increase provides TCP friendliness. + See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/ + +config TCP_CONG_CUBIC + tristate "CUBIC TCP" + default y + help + This is version 2.0 of BIC-TCP which uses a cubic growth function + among other techniques. + See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/cubic-paper.pdf + +config TCP_CONG_WESTWOOD + tristate "TCP Westwood+" + default m + help + TCP Westwood+ is a sender-side only modification of the TCP Reno + protocol stack that optimizes the performance of TCP congestion + control. It is based on end-to-end bandwidth estimation to set + congestion window and slow start threshold after a congestion + episode. Using this estimation, TCP Westwood+ adaptively sets a + slow start threshold and a congestion window which takes into + account the bandwidth used at the time congestion is experienced. + TCP Westwood+ significantly increases fairness wrt TCP Reno in + wired networks and throughput over wireless links. + +config TCP_CONG_HTCP + tristate "H-TCP" + default m + help + H-TCP is a send-side only modifications of the TCP Reno + protocol stack that optimizes the performance of TCP + congestion control for high speed network links. It uses a + modeswitch to change the alpha and beta parameters of TCP Reno + based on network conditions and in a way so as to be fair with + other Reno and H-TCP flows. + +config TCP_CONG_HSTCP + tristate "High Speed TCP" + default n + help + Sally Floyd's High Speed TCP (RFC 3649) congestion control. + A modification to TCP's congestion control mechanism for use + with large congestion windows. A table indicates how much to + increase the congestion window by when an ACK is received. + For more detail see https://www.icir.org/floyd/hstcp.html + +config TCP_CONG_HYBLA + tristate "TCP-Hybla congestion control algorithm" + default n + help + TCP-Hybla is a sender-side only change that eliminates penalization of + long-RTT, large-bandwidth connections, like when satellite legs are + involved, especially when sharing a common bottleneck with normal + terrestrial connections. + +config TCP_CONG_VEGAS + tristate "TCP Vegas" + default n + help + TCP Vegas is a sender-side only change to TCP that anticipates + the onset of congestion by estimating the bandwidth. TCP Vegas + adjusts the sending rate by modifying the congestion + window. TCP Vegas should provide less packet loss, but it is + not as aggressive as TCP Reno. + +config TCP_CONG_NV + tristate "TCP NV" + default n + help + TCP NV is a follow up to TCP Vegas. It has been modified to deal with + 10G networks, measurement noise introduced by LRO, GRO and interrupt + coalescence. In addition, it will decrease its cwnd multiplicatively + instead of linearly. 
+ + Note that in general congestion avoidance (cwnd decreased when # packets + queued grows) cannot coexist with congestion control (cwnd decreased only + when there is packet loss) due to fairness issues. One scenario when they + can coexist safely is when the CA flows have RTTs << CC flows RTTs. + + For further details see http://www.brakmo.org/networking/tcp-nv/ + +config TCP_CONG_SCALABLE + tristate "Scalable TCP" + default n + help + Scalable TCP is a sender-side only change to TCP which uses a + MIMD congestion control algorithm which has some nice scaling + properties, though is known to have fairness issues. + See http://www.deneholme.net/tom/scalable/ + +config TCP_CONG_LP + tristate "TCP Low Priority" + default n + help + TCP Low Priority (TCP-LP), a distributed algorithm whose goal is + to utilize only the excess network bandwidth as compared to the + ``fair share`` of bandwidth as targeted by TCP. + See http://www-ece.rice.edu/networks/TCP-LP/ + +config TCP_CONG_VENO + tristate "TCP Veno" + default n + help + TCP Veno is a sender-side only enhancement of TCP to obtain better + throughput over wireless networks. TCP Veno makes use of state + distinguishing to circumvent the difficult judgment of the packet loss + type. TCP Veno cuts down less congestion window in response to random + loss packets. + See + +config TCP_CONG_YEAH + tristate "YeAH TCP" + select TCP_CONG_VEGAS + default n + help + YeAH-TCP is a sender-side high-speed enabled TCP congestion control + algorithm, which uses a mixed loss/delay approach to compute the + congestion window. It's design goals target high efficiency, + internal, RTT and Reno fairness, resilience to link loss while + keeping network elements load as low as possible. + + For further details look here: + http://wil.cs.caltech.edu/pfldnet2007/paper/YeAH_TCP.pdf + +config TCP_CONG_ILLINOIS + tristate "TCP Illinois" + default n + help + TCP-Illinois is a sender-side modification of TCP Reno for + high speed long delay links. It uses round-trip-time to + adjust the alpha and beta parameters to achieve a higher average + throughput and maintain fairness. + + For further details see: + http://www.ews.uiuc.edu/~shaoliu/tcpillinois/index.html + +config TCP_CONG_DCTCP + tristate "DataCenter TCP (DCTCP)" + default n + help + DCTCP leverages Explicit Congestion Notification (ECN) in the network to + provide multi-bit feedback to the end hosts. It is designed to provide: + + - High burst tolerance (incast due to partition/aggregate), + - Low latency (short flows, queries), + - High throughput (continuous data updates, large file transfers) with + commodity, shallow-buffered switches. + + All switches in the data center network running DCTCP must support + ECN marking and be configured for marking when reaching defined switch + buffer thresholds. The default ECN marking threshold heuristic for + DCTCP on switches is 20 packets (30KB) at 1Gbps, and 65 packets + (~100KB) at 10Gbps, but might need further careful tweaking. + + For further details see: + http://simula.stanford.edu/~alizade/Site/DCTCP_files/dctcp-final.pdf + +config TCP_CONG_CDG + tristate "CAIA Delay-Gradient (CDG)" + default n + help + CAIA Delay-Gradient (CDG) is a TCP congestion control that modifies + the TCP sender in order to: + + o Use the delay gradient as a congestion signal. + o Back off with an average probability that is independent of the RTT. + o Coexist with flows that use loss-based congestion control. + o Tolerate packet loss unrelated to congestion. 
+ + For further details see: + D.A. Hayes and G. Armitage. "Revisiting TCP congestion control using + delay gradients." In Networking 2011. Preprint: http://goo.gl/No3vdg + +config TCP_CONG_BBR + tristate "BBR TCP" + default n + help + + BBR (Bottleneck Bandwidth and RTT) TCP congestion control aims to + maximize network utilization and minimize queues. It builds an explicit + model of the bottleneck delivery rate and path round-trip propagation + delay. It tolerates packet loss and delay unrelated to congestion. It + can operate over LAN, WAN, cellular, wifi, or cable modem links. It can + coexist with flows that use loss-based congestion control, and can + operate with shallow buffers, deep buffers, bufferbloat, policers, or + AQM schemes that do not provide a delay signal. It requires the fq + ("Fair Queue") pacing packet scheduler. + +choice + prompt "Default TCP congestion control" + default DEFAULT_CUBIC + help + Select the TCP congestion control that will be used by default + for all connections. + + config DEFAULT_BIC + bool "Bic" if TCP_CONG_BIC=y + + config DEFAULT_CUBIC + bool "Cubic" if TCP_CONG_CUBIC=y + + config DEFAULT_HTCP + bool "Htcp" if TCP_CONG_HTCP=y + + config DEFAULT_HYBLA + bool "Hybla" if TCP_CONG_HYBLA=y + + config DEFAULT_VEGAS + bool "Vegas" if TCP_CONG_VEGAS=y + + config DEFAULT_VENO + bool "Veno" if TCP_CONG_VENO=y + + config DEFAULT_WESTWOOD + bool "Westwood" if TCP_CONG_WESTWOOD=y + + config DEFAULT_DCTCP + bool "DCTCP" if TCP_CONG_DCTCP=y + + config DEFAULT_CDG + bool "CDG" if TCP_CONG_CDG=y + + config DEFAULT_BBR + bool "BBR" if TCP_CONG_BBR=y + + config DEFAULT_RENO + bool "Reno" +endchoice + +endif + +config TCP_CONG_CUBIC + tristate + depends on !TCP_CONG_ADVANCED + default y + +config DEFAULT_TCP_CONG + string + default "bic" if DEFAULT_BIC + default "cubic" if DEFAULT_CUBIC + default "htcp" if DEFAULT_HTCP + default "hybla" if DEFAULT_HYBLA + default "vegas" if DEFAULT_VEGAS + default "westwood" if DEFAULT_WESTWOOD + default "veno" if DEFAULT_VENO + default "reno" if DEFAULT_RENO + default "dctcp" if DEFAULT_DCTCP + default "cdg" if DEFAULT_CDG + default "bbr" if DEFAULT_BBR + default "cubic" + +config TCP_MD5SIG + bool "TCP: MD5 Signature Option support (RFC2385)" + select CRYPTO + select CRYPTO_MD5 + help + RFC2385 specifies a method of giving MD5 protection to TCP sessions. + Its main (only?) use is to protect BGP sessions between core routers + on the Internet. + + If unsure, say N. diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile new file mode 100644 index 000000000..bbdd9c44f --- /dev/null +++ b/net/ipv4/Makefile @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux TCP/IP (INET) layer. 
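+# (obj-y below lists the core objects that are always built in;
+#  optional features are added through obj-$(CONFIG_...) so they follow
+#  the selections made in net/ipv4/Kconfig.)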
+# + +obj-y := route.o inetpeer.o protocol.o \ + ip_input.o ip_fragment.o ip_forward.o ip_options.o \ + ip_output.o ip_sockglue.o inet_hashtables.o \ + inet_timewait_sock.o inet_connection_sock.o \ + tcp.o tcp_input.o tcp_output.o tcp_timer.o tcp_ipv4.o \ + tcp_minisocks.o tcp_cong.o tcp_metrics.o tcp_fastopen.o \ + tcp_rate.o tcp_recovery.o tcp_ulp.o \ + tcp_offload.o datagram.o raw.o udp.o udplite.o \ + udp_offload.o arp.o icmp.o devinet.o af_inet.o igmp.o \ + fib_frontend.o fib_semantics.o fib_trie.o fib_notifier.o \ + inet_fragment.o ping.o ip_tunnel_core.o gre_offload.o \ + metrics.o netlink.o nexthop.o udp_tunnel_stub.o + +obj-$(CONFIG_BPFILTER) += bpfilter/ + +obj-$(CONFIG_NET_IP_TUNNEL) += ip_tunnel.o +obj-$(CONFIG_SYSCTL) += sysctl_net_ipv4.o +obj-$(CONFIG_PROC_FS) += proc.o +obj-$(CONFIG_IP_MULTIPLE_TABLES) += fib_rules.o +obj-$(CONFIG_IP_MROUTE) += ipmr.o +obj-$(CONFIG_IP_MROUTE_COMMON) += ipmr_base.o +obj-$(CONFIG_NET_IPIP) += ipip.o +gre-y := gre_demux.o +obj-$(CONFIG_NET_FOU) += fou.o +obj-$(CONFIG_NET_IPGRE_DEMUX) += gre.o +obj-$(CONFIG_NET_IPGRE) += ip_gre.o +udp_tunnel-y := udp_tunnel_core.o udp_tunnel_nic.o +obj-$(CONFIG_NET_UDP_TUNNEL) += udp_tunnel.o +obj-$(CONFIG_NET_IPVTI) += ip_vti.o +obj-$(CONFIG_SYN_COOKIES) += syncookies.o +obj-$(CONFIG_INET_AH) += ah4.o +obj-$(CONFIG_INET_ESP) += esp4.o +obj-$(CONFIG_INET_ESP_OFFLOAD) += esp4_offload.o +obj-$(CONFIG_INET_IPCOMP) += ipcomp.o +obj-$(CONFIG_INET_XFRM_TUNNEL) += xfrm4_tunnel.o +obj-$(CONFIG_INET_TUNNEL) += tunnel4.o +obj-$(CONFIG_IP_PNP) += ipconfig.o +obj-$(CONFIG_NETFILTER) += netfilter.o netfilter/ +obj-$(CONFIG_INET_DIAG) += inet_diag.o +obj-$(CONFIG_INET_TCP_DIAG) += tcp_diag.o +obj-$(CONFIG_INET_UDP_DIAG) += udp_diag.o +obj-$(CONFIG_INET_RAW_DIAG) += raw_diag.o +obj-$(CONFIG_TCP_CONG_BBR) += tcp_bbr.o +obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o +obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o +obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o +obj-$(CONFIG_TCP_CONG_DCTCP) += tcp_dctcp.o +obj-$(CONFIG_TCP_CONG_WESTWOOD) += tcp_westwood.o +obj-$(CONFIG_TCP_CONG_HSTCP) += tcp_highspeed.o +obj-$(CONFIG_TCP_CONG_HYBLA) += tcp_hybla.o +obj-$(CONFIG_TCP_CONG_HTCP) += tcp_htcp.o +obj-$(CONFIG_TCP_CONG_VEGAS) += tcp_vegas.o +obj-$(CONFIG_TCP_CONG_NV) += tcp_nv.o +obj-$(CONFIG_TCP_CONG_VENO) += tcp_veno.o +obj-$(CONFIG_TCP_CONG_SCALABLE) += tcp_scalable.o +obj-$(CONFIG_TCP_CONG_LP) += tcp_lp.o +obj-$(CONFIG_TCP_CONG_YEAH) += tcp_yeah.o +obj-$(CONFIG_TCP_CONG_ILLINOIS) += tcp_illinois.o +obj-$(CONFIG_NET_SOCK_MSG) += tcp_bpf.o +obj-$(CONFIG_BPF_SYSCALL) += udp_bpf.o +obj-$(CONFIG_NETLABEL) += cipso_ipv4.o + +obj-$(CONFIG_XFRM) += xfrm4_policy.o xfrm4_state.o xfrm4_input.o \ + xfrm4_output.o xfrm4_protocol.o + +ifeq ($(CONFIG_BPF_JIT),y) +obj-$(CONFIG_BPF_SYSCALL) += bpf_tcp_ca.o +endif diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c new file mode 100644 index 000000000..2f646335d --- /dev/null +++ b/net/ipv4/af_inet.c @@ -0,0 +1,2125 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * PF_INET protocol family socket handler. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Florian La Roche, + * Alan Cox, + * + * Changes (see also sock.c) + * + * piggy, + * Karl Knutson : Socket protocol table + * A.N.Kuznetsov : Socket death error in accept(). 
+ * John Richardson : Fix non blocking error in connect() + * so sockets that fail to connect + * don't return -EINPROGRESS. + * Alan Cox : Asynchronous I/O support + * Alan Cox : Keep correct socket pointer on sock + * structures + * when accept() ed + * Alan Cox : Semantics of SO_LINGER aren't state + * moved to close when you look carefully. + * With this fixed and the accept bug fixed + * some RPC stuff seems happier. + * Niibe Yutaka : 4.4BSD style write async I/O + * Alan Cox, + * Tony Gale : Fixed reuse semantics. + * Alan Cox : bind() shouldn't abort existing but dead + * sockets. Stops FTP netin:.. I hope. + * Alan Cox : bind() works correctly for RAW sockets. + * Note that FreeBSD at least was broken + * in this respect so be careful with + * compatibility tests... + * Alan Cox : routing cache support + * Alan Cox : memzero the socket structure for + * compactness. + * Matt Day : nonblock connect error handler + * Alan Cox : Allow large numbers of pending sockets + * (eg for big web sites), but only if + * specifically application requested. + * Alan Cox : New buffering throughout IP. Used + * dumbly. + * Alan Cox : New buffering now used smartly. + * Alan Cox : BSD rather than common sense + * interpretation of listen. + * Germano Caronni : Assorted small races. + * Alan Cox : sendmsg/recvmsg basic support. + * Alan Cox : Only sendmsg/recvmsg now supported. + * Alan Cox : Locked down bind (see security list). + * Alan Cox : Loosened bind a little. + * Mike McLagan : ADD/DEL DLCI Ioctls + * Willy Konynenberg : Transparent proxying support. + * David S. Miller : New socket lookup architecture. + * Some other random speedups. + * Cyrus Durgin : Cleaned up file for kmod hacks. + * Andi Kleen : Fix inet_stream_connect TCP race. + */ + +#define pr_fmt(fmt) "IPv4: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_IP_MROUTE +#include +#endif +#include +#include + +#include + +/* The inetsw table contains everything that inet_create needs to + * build a new socket. 
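+ * It is an array of lists indexed by socket type (SOCK_STREAM,
+ * SOCK_DGRAM, SOCK_RAW, ...); each list holds the inet_protosw
+ * entries registered for that type, and inet_create() below walks
+ * the list for the requested type to pick a matching entry.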
+ */ +static struct list_head inetsw[SOCK_MAX]; +static DEFINE_SPINLOCK(inetsw_lock); + +/* New destruction routine */ + +void inet_sock_destruct(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + + __skb_queue_purge(&sk->sk_receive_queue); + __skb_queue_purge(&sk->sk_error_queue); + + sk_mem_reclaim_final(sk); + + if (sk->sk_type == SOCK_STREAM && sk->sk_state != TCP_CLOSE) { + pr_err("Attempt to release TCP socket in state %d %p\n", + sk->sk_state, sk); + return; + } + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Attempt to release alive inet socket %p\n", sk); + return; + } + + WARN_ON_ONCE(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON_ONCE(refcount_read(&sk->sk_wmem_alloc)); + WARN_ON_ONCE(sk->sk_wmem_queued); + WARN_ON_ONCE(sk_forward_alloc_get(sk)); + + kfree(rcu_dereference_protected(inet->inet_opt, 1)); + dst_release(rcu_dereference_protected(sk->sk_dst_cache, 1)); + dst_release(rcu_dereference_protected(sk->sk_rx_dst, 1)); + sk_refcnt_debug_dec(sk); +} +EXPORT_SYMBOL(inet_sock_destruct); + +/* + * The routines beyond this point handle the behaviour of an AF_INET + * socket object. Mostly it punts to the subprotocols of IP to do + * the work. + */ + +/* + * Automatically bind an unbound socket. + */ + +static int inet_autobind(struct sock *sk) +{ + struct inet_sock *inet; + /* We may need to bind the socket. */ + lock_sock(sk); + inet = inet_sk(sk); + if (!inet->inet_num) { + if (sk->sk_prot->get_port(sk, 0)) { + release_sock(sk); + return -EAGAIN; + } + inet->inet_sport = htons(inet->inet_num); + } + release_sock(sk); + return 0; +} + +/* + * Move a socket into listening state. + */ +int inet_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + unsigned char old_state; + int err, tcp_fastopen; + + lock_sock(sk); + + err = -EINVAL; + if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) + goto out; + + old_state = sk->sk_state; + if (!((1 << old_state) & (TCPF_CLOSE | TCPF_LISTEN))) + goto out; + + WRITE_ONCE(sk->sk_max_ack_backlog, backlog); + /* Really, if the socket is already in listen state + * we can only allow the backlog to be adjusted. + */ + if (old_state != TCP_LISTEN) { + /* Enable TFO w/o requiring TCP_FASTOPEN socket option. + * Note that only TCP sockets (SOCK_STREAM) will reach here. + * Also fastopen backlog may already been set via the option + * because the socket was in TCP_LISTEN state previously but + * was shutdown() rather than close(). + */ + tcp_fastopen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen); + if ((tcp_fastopen & TFO_SERVER_WO_SOCKOPT1) && + (tcp_fastopen & TFO_SERVER_ENABLE) && + !inet_csk(sk)->icsk_accept_queue.fastopenq.max_qlen) { + fastopen_queue_tune(sk, backlog); + tcp_fastopen_init_key_once(sock_net(sk)); + } + + err = inet_csk_listen_start(sk); + if (err) + goto out; + tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_LISTEN_CB, 0, NULL); + } + err = 0; + +out: + release_sock(sk); + return err; +} +EXPORT_SYMBOL(inet_listen); + +/* + * Create an inet socket. + */ + +static int inet_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct inet_protosw *answer; + struct inet_sock *inet; + struct proto *answer_prot; + unsigned char answer_flags; + int try_loading_module = 0; + int err; + + if (protocol < 0 || protocol >= IPPROTO_MAX) + return -EINVAL; + + sock->state = SS_UNCONNECTED; + + /* Look for the requested type/protocol pair. 
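+	 * IPPROTO_IP (0) acts as a wildcard on either side: a caller
+	 * passing 0 gets the first entry registered for this socket type,
+	 * while an entry registered with protocol IPPROTO_IP matches any
+	 * requested protocol.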
*/ +lookup_protocol: + err = -ESOCKTNOSUPPORT; + rcu_read_lock(); + list_for_each_entry_rcu(answer, &inetsw[sock->type], list) { + + err = 0; + /* Check the non-wild match. */ + if (protocol == answer->protocol) { + if (protocol != IPPROTO_IP) + break; + } else { + /* Check for the two wild cases. */ + if (IPPROTO_IP == protocol) { + protocol = answer->protocol; + break; + } + if (IPPROTO_IP == answer->protocol) + break; + } + err = -EPROTONOSUPPORT; + } + + if (unlikely(err)) { + if (try_loading_module < 2) { + rcu_read_unlock(); + /* + * Be more specific, e.g. net-pf-2-proto-132-type-1 + * (net-pf-PF_INET-proto-IPPROTO_SCTP-type-SOCK_STREAM) + */ + if (++try_loading_module == 1) + request_module("net-pf-%d-proto-%d-type-%d", + PF_INET, protocol, sock->type); + /* + * Fall back to generic, e.g. net-pf-2-proto-132 + * (net-pf-PF_INET-proto-IPPROTO_SCTP) + */ + else + request_module("net-pf-%d-proto-%d", + PF_INET, protocol); + goto lookup_protocol; + } else + goto out_rcu_unlock; + } + + err = -EPERM; + if (sock->type == SOCK_RAW && !kern && + !ns_capable(net->user_ns, CAP_NET_RAW)) + goto out_rcu_unlock; + + sock->ops = answer->ops; + answer_prot = answer->prot; + answer_flags = answer->flags; + rcu_read_unlock(); + + WARN_ON(!answer_prot->slab); + + err = -ENOMEM; + sk = sk_alloc(net, PF_INET, GFP_KERNEL, answer_prot, kern); + if (!sk) + goto out; + + err = 0; + if (INET_PROTOSW_REUSE & answer_flags) + sk->sk_reuse = SK_CAN_REUSE; + + if (INET_PROTOSW_ICSK & answer_flags) + inet_init_csk_locks(sk); + + inet = inet_sk(sk); + inet->is_icsk = (INET_PROTOSW_ICSK & answer_flags) != 0; + + inet->nodefrag = 0; + + if (SOCK_RAW == sock->type) { + inet->inet_num = protocol; + if (IPPROTO_RAW == protocol) + inet->hdrincl = 1; + } + + if (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) + inet->pmtudisc = IP_PMTUDISC_DONT; + else + inet->pmtudisc = IP_PMTUDISC_WANT; + + atomic_set(&inet->inet_id, 0); + + sock_init_data(sock, sk); + + sk->sk_destruct = inet_sock_destruct; + sk->sk_protocol = protocol; + sk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; + sk->sk_txrehash = READ_ONCE(net->core.sysctl_txrehash); + + inet->uc_ttl = -1; + inet->mc_loop = 1; + inet->mc_ttl = 1; + inet->mc_all = 1; + inet->mc_index = 0; + inet->mc_list = NULL; + inet->rcv_tos = 0; + + sk_refcnt_debug_inc(sk); + + if (inet->inet_num) { + /* It assumes that any protocol which allows + * the user to assign a number at socket + * creation time automatically + * shares. + */ + inet->inet_sport = htons(inet->inet_num); + /* Add to protocol hash chains. */ + err = sk->sk_prot->hash(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } + + if (sk->sk_prot->init) { + err = sk->sk_prot->init(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } + + if (!kern) { + err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } +out: + return err; +out_rcu_unlock: + rcu_read_unlock(); + goto out; +} + + +/* + * The peer socket should always be NULL (or else). When we call this + * function we are destroying the object and from then on nobody + * should refer to it. + */ +int inet_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + long timeout; + + if (!sk->sk_kern_sock) + BPF_CGROUP_RUN_PROG_INET_SOCK_RELEASE(sk); + + /* Applications forget to leave groups before exiting */ + ip_mc_drop_socket(sk); + + /* If linger is set, we don't return until the close + * is complete. Otherwise we return immediately. The + * actually closing is done the same either way. 
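+	 * (Whatever timeout we compute here is simply handed to
+	 * sk->sk_prot->close(); a zero timeout means we do not wait and
+	 * the protocol finishes the teardown asynchronously.)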
+ * + * If the close is due to the process exiting, we never + * linger.. + */ + timeout = 0; + if (sock_flag(sk, SOCK_LINGER) && + !(current->flags & PF_EXITING)) + timeout = sk->sk_lingertime; + sk->sk_prot->close(sk, timeout); + sock->sk = NULL; + } + return 0; +} +EXPORT_SYMBOL(inet_release); + +int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + u32 flags = BIND_WITH_LOCK; + int err; + + /* If the socket has its own bind function then use it. (RAW) */ + if (sk->sk_prot->bind) { + return sk->sk_prot->bind(sk, uaddr, addr_len); + } + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + /* BPF prog is run before any checks are done so that if the prog + * changes context in a wrong way it will be caught. + */ + err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, + CGROUP_INET4_BIND, &flags); + if (err) + return err; + + return __inet_bind(sk, uaddr, addr_len, flags); +} +EXPORT_SYMBOL(inet_bind); + +int __inet_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, + u32 flags) +{ + struct sockaddr_in *addr = (struct sockaddr_in *)uaddr; + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + unsigned short snum; + int chk_addr_ret; + u32 tb_id = RT_TABLE_LOCAL; + int err; + + if (addr->sin_family != AF_INET) { + /* Compatibility games : accept AF_UNSPEC (mapped to AF_INET) + * only if s_addr is INADDR_ANY. + */ + err = -EAFNOSUPPORT; + if (addr->sin_family != AF_UNSPEC || + addr->sin_addr.s_addr != htonl(INADDR_ANY)) + goto out; + } + + tb_id = l3mdev_fib_table_by_index(net, sk->sk_bound_dev_if) ? : tb_id; + chk_addr_ret = inet_addr_type_table(net, addr->sin_addr.s_addr, tb_id); + + /* Not specified by any standard per-se, however it breaks too + * many applications when removed. It is unfortunate since + * allowing applications to make a non-local bind solves + * several problems with systems using dynamic addressing. + * (ie. your servers still start up even if your ISDN link + * is temporarily down) + */ + err = -EADDRNOTAVAIL; + if (!inet_addr_valid_or_nonlocal(net, inet, addr->sin_addr.s_addr, + chk_addr_ret)) + goto out; + + snum = ntohs(addr->sin_port); + err = -EACCES; + if (!(flags & BIND_NO_CAP_NET_BIND_SERVICE) && + snum && inet_port_requires_bind_service(net, snum) && + !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) + goto out; + + /* We keep a pair of addresses. rcv_saddr is the one + * used by hash lookups, and saddr is used for transmit. + * + * In the BSD API these are the same except where it + * would be illegal to use them (multicast/broadcast) in + * which case the sending device address is used. + */ + if (flags & BIND_WITH_LOCK) + lock_sock(sk); + + /* Check these errors (active socket, double bind). */ + err = -EINVAL; + if (sk->sk_state != TCP_CLOSE || inet->inet_num) + goto out_release_sock; + + inet->inet_rcv_saddr = inet->inet_saddr = addr->sin_addr.s_addr; + if (chk_addr_ret == RTN_MULTICAST || chk_addr_ret == RTN_BROADCAST) + inet->inet_saddr = 0; /* Use device */ + + /* Make sure we are allowed to bind here. 
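+	 * get_port() is skipped only when no port was requested and either
+	 * IP_BIND_ADDRESS_NO_PORT is set on the socket or a BPF program
+	 * asked for BIND_FORCE_ADDRESS_NO_PORT; otherwise it reserves the
+	 * requested (or an ephemeral) local port.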
*/ + if (snum || !(inet->bind_address_no_port || + (flags & BIND_FORCE_ADDRESS_NO_PORT))) { + err = sk->sk_prot->get_port(sk, snum); + if (err) { + inet->inet_saddr = inet->inet_rcv_saddr = 0; + goto out_release_sock; + } + if (!(flags & BIND_FROM_BPF)) { + err = BPF_CGROUP_RUN_PROG_INET4_POST_BIND(sk); + if (err) { + inet->inet_saddr = inet->inet_rcv_saddr = 0; + if (sk->sk_prot->put_port) + sk->sk_prot->put_port(sk); + goto out_release_sock; + } + } + } + + if (inet->inet_rcv_saddr) + sk->sk_userlocks |= SOCK_BINDADDR_LOCK; + if (snum) + sk->sk_userlocks |= SOCK_BINDPORT_LOCK; + inet->inet_sport = htons(inet->inet_num); + inet->inet_daddr = 0; + inet->inet_dport = 0; + sk_dst_reset(sk); + err = 0; +out_release_sock: + if (flags & BIND_WITH_LOCK) + release_sock(sk); +out: + return err; +} + +int inet_dgram_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + const struct proto *prot; + int err; + + if (addr_len < sizeof(uaddr->sa_family)) + return -EINVAL; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + + if (uaddr->sa_family == AF_UNSPEC) + return prot->disconnect(sk, flags); + + if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) { + err = prot->pre_connect(sk, uaddr, addr_len); + if (err) + return err; + } + + if (data_race(!inet_sk(sk)->inet_num) && inet_autobind(sk)) + return -EAGAIN; + return prot->connect(sk, uaddr, addr_len); +} +EXPORT_SYMBOL(inet_dgram_connect); + +static long inet_wait_for_connect(struct sock *sk, long timeo, int writebias) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + sk->sk_write_pending += writebias; + + /* Basic assumption: if someone sets sk->sk_err, he _must_ + * change state of the socket from TCP_SYN_*. + * Connect() does not allow to get error notifications + * without closing the socket. + */ + while ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) { + release_sock(sk); + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + lock_sock(sk); + if (signal_pending(current) || !timeo) + break; + } + remove_wait_queue(sk_sleep(sk), &wait); + sk->sk_write_pending -= writebias; + return timeo; +} + +/* + * Connect to a remote host. There is regrettably still a little + * TCP 'magic' in here. + */ +int __inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags, int is_sendmsg) +{ + struct sock *sk = sock->sk; + int err; + long timeo; + + /* + * uaddr can be NULL and addr_len can be 0 if: + * sk is a TCP fastopen active socket and + * TCP_FASTOPEN_CONNECT sockopt is set and + * we already have a valid cookie for this socket. + * In this case, user can call write() after connect(). + * write() will invoke tcp_sendmsg_fastopen() which calls + * __inet_stream_connect(). + */ + if (uaddr) { + if (addr_len < sizeof(uaddr->sa_family)) + return -EINVAL; + + if (uaddr->sa_family == AF_UNSPEC) { + sk->sk_disconnects++; + err = sk->sk_prot->disconnect(sk, flags); + sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; + goto out; + } + } + + switch (sock->state) { + default: + err = -EINVAL; + goto out; + case SS_CONNECTED: + err = -EISCONN; + goto out; + case SS_CONNECTING: + if (inet_sk(sk)->defer_connect) + err = is_sendmsg ? 
-EINPROGRESS : -EISCONN; + else + err = -EALREADY; + /* Fall out of switch with err, set for this state */ + break; + case SS_UNCONNECTED: + err = -EISCONN; + if (sk->sk_state != TCP_CLOSE) + goto out; + + if (BPF_CGROUP_PRE_CONNECT_ENABLED(sk)) { + err = sk->sk_prot->pre_connect(sk, uaddr, addr_len); + if (err) + goto out; + } + + err = sk->sk_prot->connect(sk, uaddr, addr_len); + if (err < 0) + goto out; + + sock->state = SS_CONNECTING; + + if (!err && inet_sk(sk)->defer_connect) + goto out; + + /* Just entered SS_CONNECTING state; the only + * difference is that return value in non-blocking + * case is EINPROGRESS, rather than EALREADY. + */ + err = -EINPROGRESS; + break; + } + + timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) { + int writebias = (sk->sk_protocol == IPPROTO_TCP) && + tcp_sk(sk)->fastopen_req && + tcp_sk(sk)->fastopen_req->data ? 1 : 0; + int dis = sk->sk_disconnects; + + /* Error code is set above */ + if (!timeo || !inet_wait_for_connect(sk, timeo, writebias)) + goto out; + + err = sock_intr_errno(timeo); + if (signal_pending(current)) + goto out; + + if (dis != sk->sk_disconnects) { + err = -EPIPE; + goto out; + } + } + + /* Connection was closed by RST, timeout, ICMP error + * or another process disconnected us. + */ + if (sk->sk_state == TCP_CLOSE) + goto sock_error; + + /* sk->sk_err may be not zero now, if RECVERR was ordered by user + * and error was received after socket entered established state. + * Hence, it is handled normally after connect() return successfully. + */ + + sock->state = SS_CONNECTED; + err = 0; +out: + return err; + +sock_error: + err = sock_error(sk) ? : -ECONNABORTED; + sock->state = SS_UNCONNECTED; + sk->sk_disconnects++; + if (sk->sk_prot->disconnect(sk, flags)) + sock->state = SS_DISCONNECTING; + goto out; +} +EXPORT_SYMBOL(__inet_stream_connect); + +int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + int err; + + lock_sock(sock->sk); + err = __inet_stream_connect(sock, uaddr, addr_len, flags, 0); + release_sock(sock->sk); + return err; +} +EXPORT_SYMBOL(inet_stream_connect); + +/* + * Accept a pending connection. The TCP layer now gives BSD semantics. + */ + +int inet_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *sk1 = sock->sk, *sk2; + int err = -EINVAL; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + sk2 = READ_ONCE(sk1->sk_prot)->accept(sk1, flags, &err, kern); + if (!sk2) + goto do_err; + + lock_sock(sk2); + + sock_rps_record_flow(sk2); + WARN_ON(!((1 << sk2->sk_state) & + (TCPF_ESTABLISHED | TCPF_SYN_RECV | + TCPF_CLOSE_WAIT | TCPF_CLOSE))); + + if (test_bit(SOCK_SUPPORT_ZC, &sock->flags)) + set_bit(SOCK_SUPPORT_ZC, &newsock->flags); + sock_graft(sk2, newsock); + + newsock->state = SS_CONNECTED; + err = 0; + release_sock(sk2); +do_err: + return err; +} +EXPORT_SYMBOL(inet_accept); + +/* + * This does both peername and sockname. 
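+ * peer == 0 reports the locally bound address and port (getsockname());
+ * peer != 0 reports the remote end (getpeername()) and returns
+ * -ENOTCONN while the socket has no peer yet.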
+ */ +int inet_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sock *sk = sock->sk; + struct inet_sock *inet = inet_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_in *, sin, uaddr); + + sin->sin_family = AF_INET; + lock_sock(sk); + if (peer) { + if (!inet->inet_dport || + (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_SYN_SENT)) && + peer == 1)) { + release_sock(sk); + return -ENOTCONN; + } + sin->sin_port = inet->inet_dport; + sin->sin_addr.s_addr = inet->inet_daddr; + BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + CGROUP_INET4_GETPEERNAME); + } else { + __be32 addr = inet->inet_rcv_saddr; + if (!addr) + addr = inet->inet_saddr; + sin->sin_port = inet->inet_sport; + sin->sin_addr.s_addr = addr; + BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + CGROUP_INET4_GETSOCKNAME); + } + release_sock(sk); + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + return sizeof(*sin); +} +EXPORT_SYMBOL(inet_getname); + +int inet_send_prepare(struct sock *sk) +{ + sock_rps_record_flow(sk); + + /* We may need to bind the socket. */ + if (data_race(!inet_sk(sk)->inet_num) && !sk->sk_prot->no_autobind && + inet_autobind(sk)) + return -EAGAIN; + + return 0; +} +EXPORT_SYMBOL_GPL(inet_send_prepare); + +int inet_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) +{ + struct sock *sk = sock->sk; + + if (unlikely(inet_send_prepare(sk))) + return -EAGAIN; + + return INDIRECT_CALL_2(sk->sk_prot->sendmsg, tcp_sendmsg, udp_sendmsg, + sk, msg, size); +} +EXPORT_SYMBOL(inet_sendmsg); + +void inet_splice_eof(struct socket *sock) +{ + const struct proto *prot; + struct sock *sk = sock->sk; + + if (unlikely(inet_send_prepare(sk))) + return; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + if (prot->splice_eof) + prot->splice_eof(sock); +} +EXPORT_SYMBOL_GPL(inet_splice_eof); + +ssize_t inet_sendpage(struct socket *sock, struct page *page, int offset, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + const struct proto *prot; + + if (unlikely(inet_send_prepare(sk))) + return -EAGAIN; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + if (prot->sendpage) + return prot->sendpage(sk, page, offset, size, flags); + return sock_no_sendpage(sock, page, offset, size, flags); +} +EXPORT_SYMBOL(inet_sendpage); + +INDIRECT_CALLABLE_DECLARE(int udp_recvmsg(struct sock *, struct msghdr *, + size_t, int, int *)); +int inet_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + int addr_len = 0; + int err; + + if (likely(!(flags & MSG_ERRQUEUE))) + sock_rps_record_flow(sk); + + err = INDIRECT_CALL_2(sk->sk_prot->recvmsg, tcp_recvmsg, udp_recvmsg, + sk, msg, size, flags, &addr_len); + if (err >= 0) + msg->msg_namelen = addr_len; + return err; +} +EXPORT_SYMBOL(inet_recvmsg); + +int inet_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int err = 0; + + /* This should really check to make sure + * the socket is a TCP socket. (WHY AC...) + */ + how++; /* maps 0->1 has the advantage of making bit 1 rcvs and + 1->2 bit 2 snds. 
+ 2->3 */ + if ((how & ~SHUTDOWN_MASK) || !how) /* MAXINT->0 */ + return -EINVAL; + + lock_sock(sk); + if (sock->state == SS_CONNECTING) { + if ((1 << sk->sk_state) & + (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE)) + sock->state = SS_DISCONNECTING; + else + sock->state = SS_CONNECTED; + } + + switch (sk->sk_state) { + case TCP_CLOSE: + err = -ENOTCONN; + /* Hack to wake up other listeners, who can poll for + EPOLLHUP, even on eg. unconnected UDP sockets -- RR */ + fallthrough; + default: + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | how); + if (sk->sk_prot->shutdown) + sk->sk_prot->shutdown(sk, how); + break; + + /* Remaining two branches are temporary solution for missing + * close() in multithreaded environment. It is _not_ a good idea, + * but we have no choice until close() is repaired at VFS level. + */ + case TCP_LISTEN: + if (!(how & RCV_SHUTDOWN)) + break; + fallthrough; + case TCP_SYN_SENT: + err = sk->sk_prot->disconnect(sk, O_NONBLOCK); + sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; + break; + } + + /* Wake up anyone sleeping in poll. */ + sk->sk_state_change(sk); + release_sock(sk); + return err; +} +EXPORT_SYMBOL(inet_shutdown); + +/* + * ioctl() calls you can issue on an INET socket. Most of these are + * device configuration and stuff and very rarely used. Some ioctls + * pass on to the socket itself. + * + * NOTE: I like the idea of a module for the config stuff. ie ifconfig + * loads the devconfigure module does its configuring and unloads it. + * There's a good 20K of config code hanging around the kernel. + */ + +int inet_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + int err = 0; + struct net *net = sock_net(sk); + void __user *p = (void __user *)arg; + struct ifreq ifr; + struct rtentry rt; + + switch (cmd) { + case SIOCADDRT: + case SIOCDELRT: + if (copy_from_user(&rt, p, sizeof(struct rtentry))) + return -EFAULT; + err = ip_rt_ioctl(net, cmd, &rt); + break; + case SIOCRTMSG: + err = -EINVAL; + break; + case SIOCDARP: + case SIOCGARP: + case SIOCSARP: + err = arp_ioctl(net, cmd, (void __user *)arg); + break; + case SIOCGIFADDR: + case SIOCGIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCGIFDSTADDR: + case SIOCGIFPFLAGS: + if (get_user_ifreq(&ifr, NULL, p)) + return -EFAULT; + err = devinet_ioctl(net, cmd, &ifr); + if (!err && put_user_ifreq(&ifr, p)) + err = -EFAULT; + break; + + case SIOCSIFADDR: + case SIOCSIFBRDADDR: + case SIOCSIFNETMASK: + case SIOCSIFDSTADDR: + case SIOCSIFPFLAGS: + case SIOCSIFFLAGS: + if (get_user_ifreq(&ifr, NULL, p)) + return -EFAULT; + err = devinet_ioctl(net, cmd, &ifr); + break; + default: + if (sk->sk_prot->ioctl) + err = sk->sk_prot->ioctl(sk, cmd, arg); + else + err = -ENOIOCTLCMD; + break; + } + return err; +} +EXPORT_SYMBOL(inet_ioctl); + +#ifdef CONFIG_COMPAT +static int inet_compat_routing_ioctl(struct sock *sk, unsigned int cmd, + struct compat_rtentry __user *ur) +{ + compat_uptr_t rtdev; + struct rtentry rt; + + if (copy_from_user(&rt.rt_dst, &ur->rt_dst, + 3 * sizeof(struct sockaddr)) || + get_user(rt.rt_flags, &ur->rt_flags) || + get_user(rt.rt_metric, &ur->rt_metric) || + get_user(rt.rt_mtu, &ur->rt_mtu) || + get_user(rt.rt_window, &ur->rt_window) || + get_user(rt.rt_irtt, &ur->rt_irtt) || + get_user(rtdev, &ur->rt_dev)) + return -EFAULT; + + rt.rt_dev = compat_ptr(rtdev); + return ip_rt_ioctl(sock_net(sk), cmd, &rt); +} + +static int inet_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + struct sock 
*sk = sock->sk; + + switch (cmd) { + case SIOCADDRT: + case SIOCDELRT: + return inet_compat_routing_ioctl(sk, cmd, argp); + default: + if (!sk->sk_prot->compat_ioctl) + return -ENOIOCTLCMD; + return sk->sk_prot->compat_ioctl(sk, cmd, arg); + } +} +#endif /* CONFIG_COMPAT */ + +const struct proto_ops inet_stream_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = inet_bind, + .connect = inet_stream_connect, + .socketpair = sock_no_socketpair, + .accept = inet_accept, + .getname = inet_getname, + .poll = tcp_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = inet_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = inet_recvmsg, +#ifdef CONFIG_MMU + .mmap = tcp_mmap, +#endif + .splice_eof = inet_splice_eof, + .sendpage = inet_sendpage, + .splice_read = tcp_splice_read, + .read_sock = tcp_read_sock, + .read_skb = tcp_read_skb, + .sendmsg_locked = tcp_sendmsg_locked, + .sendpage_locked = tcp_sendpage_locked, + .peek_len = tcp_peek_len, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet_compat_ioctl, +#endif + .set_rcvlowat = tcp_set_rcvlowat, +}; +EXPORT_SYMBOL(inet_stream_ops); + +const struct proto_ops inet_dgram_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = inet_bind, + .connect = inet_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = inet_getname, + .poll = udp_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .read_skb = udp_read_skb, + .recvmsg = inet_recvmsg, + .mmap = sock_no_mmap, + .splice_eof = inet_splice_eof, + .sendpage = inet_sendpage, + .set_peek_off = sk_set_peek_off, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet_compat_ioctl, +#endif +}; +EXPORT_SYMBOL(inet_dgram_ops); + +/* + * For SOCK_RAW sockets; should be the same as inet_dgram_ops but without + * udp_poll + */ +static const struct proto_ops inet_sockraw_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = inet_bind, + .connect = inet_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = inet_getname, + .poll = datagram_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = inet_recvmsg, + .mmap = sock_no_mmap, + .splice_eof = inet_splice_eof, + .sendpage = inet_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet_compat_ioctl, +#endif +}; + +static const struct net_proto_family inet_family_ops = { + .family = PF_INET, + .create = inet_create, + .owner = THIS_MODULE, +}; + +/* Upon startup we insert all the elements in inetsw_array[] into + * the linked list inetsw. 
+ */ +static struct inet_protosw inetsw_array[] = +{ + { + .type = SOCK_STREAM, + .protocol = IPPROTO_TCP, + .prot = &tcp_prot, + .ops = &inet_stream_ops, + .flags = INET_PROTOSW_PERMANENT | + INET_PROTOSW_ICSK, + }, + + { + .type = SOCK_DGRAM, + .protocol = IPPROTO_UDP, + .prot = &udp_prot, + .ops = &inet_dgram_ops, + .flags = INET_PROTOSW_PERMANENT, + }, + + { + .type = SOCK_DGRAM, + .protocol = IPPROTO_ICMP, + .prot = &ping_prot, + .ops = &inet_sockraw_ops, + .flags = INET_PROTOSW_REUSE, + }, + + { + .type = SOCK_RAW, + .protocol = IPPROTO_IP, /* wild card */ + .prot = &raw_prot, + .ops = &inet_sockraw_ops, + .flags = INET_PROTOSW_REUSE, + } +}; + +#define INETSW_ARRAY_LEN ARRAY_SIZE(inetsw_array) + +void inet_register_protosw(struct inet_protosw *p) +{ + struct list_head *lh; + struct inet_protosw *answer; + int protocol = p->protocol; + struct list_head *last_perm; + + spin_lock_bh(&inetsw_lock); + + if (p->type >= SOCK_MAX) + goto out_illegal; + + /* If we are trying to override a permanent protocol, bail. */ + last_perm = &inetsw[p->type]; + list_for_each(lh, &inetsw[p->type]) { + answer = list_entry(lh, struct inet_protosw, list); + /* Check only the non-wild match. */ + if ((INET_PROTOSW_PERMANENT & answer->flags) == 0) + break; + if (protocol == answer->protocol) + goto out_permanent; + last_perm = lh; + } + + /* Add the new entry after the last permanent entry if any, so that + * the new entry does not override a permanent entry when matched with + * a wild-card protocol. But it is allowed to override any existing + * non-permanent entry. This means that when we remove this entry, the + * system automatically returns to the old behavior. + */ + list_add_rcu(&p->list, last_perm); +out: + spin_unlock_bh(&inetsw_lock); + + return; + +out_permanent: + pr_err("Attempt to override permanent protocol %d\n", protocol); + goto out; + +out_illegal: + pr_err("Ignoring attempt to register invalid socket type %d\n", + p->type); + goto out; +} +EXPORT_SYMBOL(inet_register_protosw); + +void inet_unregister_protosw(struct inet_protosw *p) +{ + if (INET_PROTOSW_PERMANENT & p->flags) { + pr_err("Attempt to unregister permanent protocol %d\n", + p->protocol); + } else { + spin_lock_bh(&inetsw_lock); + list_del_rcu(&p->list); + spin_unlock_bh(&inetsw_lock); + + synchronize_net(); + } +} +EXPORT_SYMBOL(inet_unregister_protosw); + +static int inet_sk_reselect_saddr(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + __be32 old_saddr = inet->inet_saddr; + __be32 daddr = inet->inet_daddr; + struct flowi4 *fl4; + struct rtable *rt; + __be32 new_saddr; + struct ip_options_rcu *inet_opt; + int err; + + inet_opt = rcu_dereference_protected(inet->inet_opt, + lockdep_sock_is_held(sk)); + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + + /* Query new route. 
*/ + fl4 = &inet->cork.fl.u.ip4; + rt = ip_route_connect(fl4, daddr, 0, sk->sk_bound_dev_if, + sk->sk_protocol, inet->inet_sport, + inet->inet_dport, sk); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + new_saddr = fl4->saddr; + + if (new_saddr == old_saddr) { + sk_setup_caps(sk, &rt->dst); + return 0; + } + + err = inet_bhash2_update_saddr(sk, &new_saddr, AF_INET); + if (err) { + ip_rt_put(rt); + return err; + } + + sk_setup_caps(sk, &rt->dst); + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) > 1) { + pr_info("%s(): shifting inet->saddr from %pI4 to %pI4\n", + __func__, &old_saddr, &new_saddr); + } + + /* + * XXX The only one ugly spot where we need to + * XXX really change the sockets identity after + * XXX it has entered the hashes. -DaveM + * + * Besides that, it does not check for connection + * uniqueness. Wait for troubles. + */ + return __sk_prot_rehash(sk); +} + +int inet_sk_rebuild_header(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct rtable *rt = (struct rtable *)__sk_dst_check(sk, 0); + __be32 daddr; + struct ip_options_rcu *inet_opt; + struct flowi4 *fl4; + int err; + + /* Route is OK, nothing to do. */ + if (rt) + return 0; + + /* Reroute. */ + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + daddr = inet->inet_daddr; + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + rcu_read_unlock(); + fl4 = &inet->cork.fl.u.ip4; + rt = ip_route_output_ports(sock_net(sk), fl4, sk, daddr, inet->inet_saddr, + inet->inet_dport, inet->inet_sport, + sk->sk_protocol, RT_CONN_FLAGS(sk), + sk->sk_bound_dev_if); + if (!IS_ERR(rt)) { + err = 0; + sk_setup_caps(sk, &rt->dst); + } else { + err = PTR_ERR(rt); + + /* Routing failed... */ + sk->sk_route_caps = 0; + /* + * Other protocols have to map its equivalent state to TCP_SYN_SENT. + * DCCP maps its DCCP_REQUESTING state to TCP_SYN_SENT. 
-acme + */ + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_ip_dynaddr) || + sk->sk_state != TCP_SYN_SENT || + (sk->sk_userlocks & SOCK_BINDADDR_LOCK) || + (err = inet_sk_reselect_saddr(sk)) != 0) + sk->sk_err_soft = -err; + } + + return err; +} +EXPORT_SYMBOL(inet_sk_rebuild_header); + +void inet_sk_set_state(struct sock *sk, int state) +{ + trace_inet_sock_set_state(sk, sk->sk_state, state); + sk->sk_state = state; +} +EXPORT_SYMBOL(inet_sk_set_state); + +void inet_sk_state_store(struct sock *sk, int newstate) +{ + trace_inet_sock_set_state(sk, sk->sk_state, newstate); + smp_store_release(&sk->sk_state, newstate); +} + +struct sk_buff *inet_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + bool udpfrag = false, fixedid = false, gso_partial, encap; + struct sk_buff *segs = ERR_PTR(-EINVAL); + const struct net_offload *ops; + unsigned int offset = 0; + struct iphdr *iph; + int proto, tot_len; + int nhoff; + int ihl; + int id; + + skb_reset_network_header(skb); + nhoff = skb_network_header(skb) - skb_mac_header(skb); + if (unlikely(!pskb_may_pull(skb, sizeof(*iph)))) + goto out; + + iph = ip_hdr(skb); + ihl = iph->ihl * 4; + if (ihl < sizeof(*iph)) + goto out; + + id = ntohs(iph->id); + proto = iph->protocol; + + /* Warning: after this point, iph might be no longer valid */ + if (unlikely(!pskb_may_pull(skb, ihl))) + goto out; + __skb_pull(skb, ihl); + + encap = SKB_GSO_CB(skb)->encap_level > 0; + if (encap) + features &= skb->dev->hw_enc_features; + SKB_GSO_CB(skb)->encap_level += ihl; + + skb_reset_transport_header(skb); + + segs = ERR_PTR(-EPROTONOSUPPORT); + + if (!skb->encapsulation || encap) { + udpfrag = !!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP); + fixedid = !!(skb_shinfo(skb)->gso_type & SKB_GSO_TCP_FIXEDID); + + /* fixed ID is invalid if DF bit is not set */ + if (fixedid && !(ip_hdr(skb)->frag_off & htons(IP_DF))) + goto out; + } + + ops = rcu_dereference(inet_offloads[proto]); + if (likely(ops && ops->callbacks.gso_segment)) { + segs = ops->callbacks.gso_segment(skb, features); + if (!segs) + skb->network_header = skb_mac_header(skb) + nhoff - skb->head; + } + + if (IS_ERR_OR_NULL(segs)) + goto out; + + gso_partial = !!(skb_shinfo(segs)->gso_type & SKB_GSO_PARTIAL); + + skb = segs; + do { + iph = (struct iphdr *)(skb_mac_header(skb) + nhoff); + if (udpfrag) { + iph->frag_off = htons(offset >> 3); + if (skb->next) + iph->frag_off |= htons(IP_MF); + offset += skb->len - nhoff - ihl; + tot_len = skb->len - nhoff; + } else if (skb_is_gso(skb)) { + if (!fixedid) { + iph->id = htons(id); + id += skb_shinfo(skb)->gso_segs; + } + + if (gso_partial) + tot_len = skb_shinfo(skb)->gso_size + + SKB_GSO_CB(skb)->data_offset + + skb->head - (unsigned char *)iph; + else + tot_len = skb->len - nhoff; + } else { + if (!fixedid) + iph->id = htons(id++); + tot_len = skb->len - nhoff; + } + iph->tot_len = htons(tot_len); + ip_send_check(iph); + if (encap) + skb_reset_inner_headers(skb); + skb->network_header = (u8 *)iph - skb->head; + skb_reset_mac_len(skb); + } while ((skb = skb->next)); + +out: + return segs; +} + +static struct sk_buff *ipip_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP4)) + return ERR_PTR(-EINVAL); + + return inet_gso_segment(skb, features); +} + +struct sk_buff *inet_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + const struct net_offload *ops; + struct sk_buff *pp = NULL; + const struct iphdr *iph; + struct sk_buff *p; + unsigned int hlen; + unsigned int off; + unsigned int id; + 
int flush = 1; + int proto; + + off = skb_gro_offset(skb); + hlen = off + sizeof(*iph); + iph = skb_gro_header(skb, hlen, off); + if (unlikely(!iph)) + goto out; + + proto = iph->protocol; + + ops = rcu_dereference(inet_offloads[proto]); + if (!ops || !ops->callbacks.gro_receive) + goto out; + + if (*(u8 *)iph != 0x45) + goto out; + + if (ip_is_fragment(iph)) + goto out; + + if (unlikely(ip_fast_csum((u8 *)iph, 5))) + goto out; + + id = ntohl(*(__be32 *)&iph->id); + flush = (u16)((ntohl(*(__be32 *)iph) ^ skb_gro_len(skb)) | (id & ~IP_DF)); + id >>= 16; + + list_for_each_entry(p, head, list) { + struct iphdr *iph2; + u16 flush_id; + + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + iph2 = (struct iphdr *)(p->data + off); + /* The above works because, with the exception of the top + * (inner most) layer, we only aggregate pkts with the same + * hdr length so all the hdrs we'll need to verify will start + * at the same offset. + */ + if ((iph->protocol ^ iph2->protocol) | + ((__force u32)iph->saddr ^ (__force u32)iph2->saddr) | + ((__force u32)iph->daddr ^ (__force u32)iph2->daddr)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + + /* All fields must match except length and checksum. */ + NAPI_GRO_CB(p)->flush |= + (iph->ttl ^ iph2->ttl) | + (iph->tos ^ iph2->tos) | + ((iph->frag_off ^ iph2->frag_off) & htons(IP_DF)); + + NAPI_GRO_CB(p)->flush |= flush; + + /* We need to store of the IP ID check to be included later + * when we can verify that this packet does in fact belong + * to a given flow. + */ + flush_id = (u16)(id - ntohs(iph2->id)); + + /* This bit of code makes it much easier for us to identify + * the cases where we are doing atomic vs non-atomic IP ID + * checks. Specifically an atomic check can return IP ID + * values 0 - 0xFFFF, while a non-atomic check can only + * return 0 or 0xFFFF. + */ + if (!NAPI_GRO_CB(p)->is_atomic || + !(iph->frag_off & htons(IP_DF))) { + flush_id ^= NAPI_GRO_CB(p)->count; + flush_id = flush_id ? 0xFFFF : 0; + } + + /* If the previous IP ID value was based on an atomic + * datagram we can overwrite the value and ignore it. + */ + if (NAPI_GRO_CB(skb)->is_atomic) + NAPI_GRO_CB(p)->flush_id = flush_id; + else + NAPI_GRO_CB(p)->flush_id |= flush_id; + } + + NAPI_GRO_CB(skb)->is_atomic = !!(iph->frag_off & htons(IP_DF)); + NAPI_GRO_CB(skb)->flush |= flush; + skb_set_network_header(skb, off); + /* The above will be needed by the transport layer if there is one + * immediately following this IP hdr. + */ + + /* Note : No need to call skb_gro_postpull_rcsum() here, + * as we already checked checksum over ipv4 header was 0 + */ + skb_gro_pull(skb, sizeof(*iph)); + skb_set_transport_header(skb, skb_gro_offset(skb)); + + pp = indirect_call_gro_receive(tcp4_gro_receive, udp4_gro_receive, + ops->callbacks.gro_receive, head, skb); + +out: + skb_gro_flush_final(skb, pp, flush); + + return pp; +} + +static struct sk_buff *ipip_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + if (NAPI_GRO_CB(skb)->encap_mark) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + NAPI_GRO_CB(skb)->encap_mark = 1; + + return inet_gro_receive(head, skb); +} + +#define SECONDS_PER_DAY 86400 + +/* inet_current_timestamp - Return IP network timestamp + * + * Return milliseconds since midnight in network byte order. + */ +__be32 inet_current_timestamp(void) +{ + u32 secs; + u32 msecs; + struct timespec64 ts; + + ktime_get_real_ts64(&ts); + + /* Get secs since midnight. */ + (void)div_u64_rem(ts.tv_sec, SECONDS_PER_DAY, &secs); + /* Convert to msecs. 
*/ + msecs = secs * MSEC_PER_SEC; + /* Convert nsec to msec. */ + msecs += (u32)ts.tv_nsec / NSEC_PER_MSEC; + + /* Convert to network byte order. */ + return htonl(msecs); +} +EXPORT_SYMBOL(inet_current_timestamp); + +int inet_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) +{ + if (sk->sk_family == AF_INET) + return ip_recv_error(sk, msg, len, addr_len); +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return pingv6_ops.ipv6_recv_error(sk, msg, len, addr_len); +#endif + return -EINVAL; +} +EXPORT_SYMBOL(inet_recv_error); + +int inet_gro_complete(struct sk_buff *skb, int nhoff) +{ + __be16 newlen = htons(skb->len - nhoff); + struct iphdr *iph = (struct iphdr *)(skb->data + nhoff); + const struct net_offload *ops; + int proto = iph->protocol; + int err = -ENOSYS; + + if (skb->encapsulation) { + skb_set_inner_protocol(skb, cpu_to_be16(ETH_P_IP)); + skb_set_inner_network_header(skb, nhoff); + } + + csum_replace2(&iph->check, iph->tot_len, newlen); + iph->tot_len = newlen; + + ops = rcu_dereference(inet_offloads[proto]); + if (WARN_ON(!ops || !ops->callbacks.gro_complete)) + goto out; + + /* Only need to add sizeof(*iph) to get to the next hdr below + * because any hdr with option will have been flushed in + * inet_gro_receive(). + */ + err = INDIRECT_CALL_2(ops->callbacks.gro_complete, + tcp4_gro_complete, udp4_gro_complete, + skb, nhoff + sizeof(*iph)); + +out: + return err; +} + +static int ipip_gro_complete(struct sk_buff *skb, int nhoff) +{ + skb->encapsulation = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP4; + return inet_gro_complete(skb, nhoff); +} + +int inet_ctl_sock_create(struct sock **sk, unsigned short family, + unsigned short type, unsigned char protocol, + struct net *net) +{ + struct socket *sock; + int rc = sock_create_kern(net, family, type, protocol, &sock); + + if (rc == 0) { + *sk = sock->sk; + (*sk)->sk_allocation = GFP_ATOMIC; + /* + * Unhash it so that IP input processing does not even see it, + * we do not wish this socket to see incoming packets. 
+ */ + (*sk)->sk_prot->unhash(*sk); + } + return rc; +} +EXPORT_SYMBOL_GPL(inet_ctl_sock_create); + +unsigned long snmp_fold_field(void __percpu *mib, int offt) +{ + unsigned long res = 0; + int i; + + for_each_possible_cpu(i) + res += snmp_get_cpu_field(mib, i, offt); + return res; +} +EXPORT_SYMBOL_GPL(snmp_fold_field); + +#if BITS_PER_LONG==32 + +u64 snmp_get_cpu_field64(void __percpu *mib, int cpu, int offt, + size_t syncp_offset) +{ + void *bhptr; + struct u64_stats_sync *syncp; + u64 v; + unsigned int start; + + bhptr = per_cpu_ptr(mib, cpu); + syncp = (struct u64_stats_sync *)(bhptr + syncp_offset); + do { + start = u64_stats_fetch_begin_irq(syncp); + v = *(((u64 *)bhptr) + offt); + } while (u64_stats_fetch_retry_irq(syncp, start)); + + return v; +} +EXPORT_SYMBOL_GPL(snmp_get_cpu_field64); + +u64 snmp_fold_field64(void __percpu *mib, int offt, size_t syncp_offset) +{ + u64 res = 0; + int cpu; + + for_each_possible_cpu(cpu) { + res += snmp_get_cpu_field64(mib, cpu, offt, syncp_offset); + } + return res; +} +EXPORT_SYMBOL_GPL(snmp_fold_field64); +#endif + +#ifdef CONFIG_IP_MULTICAST +static const struct net_protocol igmp_protocol = { + .handler = igmp_rcv, +}; +#endif + +static const struct net_protocol tcp_protocol = { + .handler = tcp_v4_rcv, + .err_handler = tcp_v4_err, + .no_policy = 1, + .icmp_strict_tag_validation = 1, +}; + +static const struct net_protocol udp_protocol = { + .handler = udp_rcv, + .err_handler = udp_err, + .no_policy = 1, +}; + +static const struct net_protocol icmp_protocol = { + .handler = icmp_rcv, + .err_handler = icmp_err, + .no_policy = 1, +}; + +static __net_init int ipv4_mib_init_net(struct net *net) +{ + int i; + + net->mib.tcp_statistics = alloc_percpu(struct tcp_mib); + if (!net->mib.tcp_statistics) + goto err_tcp_mib; + net->mib.ip_statistics = alloc_percpu(struct ipstats_mib); + if (!net->mib.ip_statistics) + goto err_ip_mib; + + for_each_possible_cpu(i) { + struct ipstats_mib *af_inet_stats; + af_inet_stats = per_cpu_ptr(net->mib.ip_statistics, i); + u64_stats_init(&af_inet_stats->syncp); + } + + net->mib.net_statistics = alloc_percpu(struct linux_mib); + if (!net->mib.net_statistics) + goto err_net_mib; + net->mib.udp_statistics = alloc_percpu(struct udp_mib); + if (!net->mib.udp_statistics) + goto err_udp_mib; + net->mib.udplite_statistics = alloc_percpu(struct udp_mib); + if (!net->mib.udplite_statistics) + goto err_udplite_mib; + net->mib.icmp_statistics = alloc_percpu(struct icmp_mib); + if (!net->mib.icmp_statistics) + goto err_icmp_mib; + net->mib.icmpmsg_statistics = kzalloc(sizeof(struct icmpmsg_mib), + GFP_KERNEL); + if (!net->mib.icmpmsg_statistics) + goto err_icmpmsg_mib; + + tcp_mib_init(net); + return 0; + +err_icmpmsg_mib: + free_percpu(net->mib.icmp_statistics); +err_icmp_mib: + free_percpu(net->mib.udplite_statistics); +err_udplite_mib: + free_percpu(net->mib.udp_statistics); +err_udp_mib: + free_percpu(net->mib.net_statistics); +err_net_mib: + free_percpu(net->mib.ip_statistics); +err_ip_mib: + free_percpu(net->mib.tcp_statistics); +err_tcp_mib: + return -ENOMEM; +} + +static __net_exit void ipv4_mib_exit_net(struct net *net) +{ + kfree(net->mib.icmpmsg_statistics); + free_percpu(net->mib.icmp_statistics); + free_percpu(net->mib.udplite_statistics); + free_percpu(net->mib.udp_statistics); + free_percpu(net->mib.net_statistics); + free_percpu(net->mib.ip_statistics); + free_percpu(net->mib.tcp_statistics); +#ifdef CONFIG_MPTCP + /* allocated on demand, see mptcp_init_sock() */ + free_percpu(net->mib.mptcp_statistics); +#endif +} 
+ +static __net_initdata struct pernet_operations ipv4_mib_ops = { + .init = ipv4_mib_init_net, + .exit = ipv4_mib_exit_net, +}; + +static int __init init_ipv4_mibs(void) +{ + return register_pernet_subsys(&ipv4_mib_ops); +} + +static __net_init int inet_init_net(struct net *net) +{ + /* + * Set defaults for local port range + */ + seqlock_init(&net->ipv4.ip_local_ports.lock); + net->ipv4.ip_local_ports.range[0] = 32768; + net->ipv4.ip_local_ports.range[1] = 60999; + + seqlock_init(&net->ipv4.ping_group_range.lock); + /* + * Sane defaults - nobody may create ping sockets. + * Boot scripts should set this to distro-specific group. + */ + net->ipv4.ping_group_range.range[0] = make_kgid(&init_user_ns, 1); + net->ipv4.ping_group_range.range[1] = make_kgid(&init_user_ns, 0); + + /* Default values for sysctl-controlled parameters. + * We set them here, in case sysctl is not compiled. + */ + net->ipv4.sysctl_ip_default_ttl = IPDEFTTL; + net->ipv4.sysctl_ip_fwd_update_priority = 1; + net->ipv4.sysctl_ip_dynaddr = 0; + net->ipv4.sysctl_ip_early_demux = 1; + net->ipv4.sysctl_udp_early_demux = 1; + net->ipv4.sysctl_tcp_early_demux = 1; + net->ipv4.sysctl_nexthop_compat_mode = 1; +#ifdef CONFIG_SYSCTL + net->ipv4.sysctl_ip_prot_sock = PROT_SOCK; +#endif + + /* Some igmp sysctl, whose values are always used */ + net->ipv4.sysctl_igmp_max_memberships = 20; + net->ipv4.sysctl_igmp_max_msf = 10; + /* IGMP reports for link-local multicast groups are enabled by default */ + net->ipv4.sysctl_igmp_llm_reports = 1; + net->ipv4.sysctl_igmp_qrv = 2; + + net->ipv4.sysctl_fib_notify_on_flag_change = 0; + + return 0; +} + +static __net_initdata struct pernet_operations af_inet_ops = { + .init = inet_init_net, +}; + +static int __init init_inet_pernet_ops(void) +{ + return register_pernet_subsys(&af_inet_ops); +} + +static int ipv4_proc_init(void); + +/* + * IP protocol layer initialiser + */ + +static struct packet_offload ip_packet_offload __read_mostly = { + .type = cpu_to_be16(ETH_P_IP), + .callbacks = { + .gso_segment = inet_gso_segment, + .gro_receive = inet_gro_receive, + .gro_complete = inet_gro_complete, + }, +}; + +static const struct net_offload ipip_offload = { + .callbacks = { + .gso_segment = ipip_gso_segment, + .gro_receive = ipip_gro_receive, + .gro_complete = ipip_gro_complete, + }, +}; + +static int __init ipip_offload_init(void) +{ + return inet_add_offload(&ipip_offload, IPPROTO_IPIP); +} + +static int __init ipv4_offload_init(void) +{ + /* + * Add offloads + */ + if (udpv4_offload_init() < 0) + pr_crit("%s: Cannot add UDP protocol offload\n", __func__); + if (tcpv4_offload_init() < 0) + pr_crit("%s: Cannot add TCP protocol offload\n", __func__); + if (ipip_offload_init() < 0) + pr_crit("%s: Cannot add IPIP protocol offload\n", __func__); + + dev_add_offload(&ip_packet_offload); + return 0; +} + +fs_initcall(ipv4_offload_init); + +static struct packet_type ip_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_IP), + .func = ip_rcv, + .list_func = ip_list_rcv, +}; + +static int __init inet_init(void) +{ + struct inet_protosw *q; + struct list_head *r; + int rc; + + sock_skb_cb_check_size(sizeof(struct inet_skb_parm)); + + raw_hashinfo_init(&raw_v4_hashinfo); + + rc = proto_register(&tcp_prot, 1); + if (rc) + goto out; + + rc = proto_register(&udp_prot, 1); + if (rc) + goto out_unregister_tcp_proto; + + rc = proto_register(&raw_prot, 1); + if (rc) + goto out_unregister_udp_proto; + + rc = proto_register(&ping_prot, 1); + if (rc) + goto out_unregister_raw_proto; + + /* + * Tell SOCKET that we 
are alive... + */ + + (void)sock_register(&inet_family_ops); + +#ifdef CONFIG_SYSCTL + ip_static_sysctl_init(); +#endif + + /* + * Add all the base protocols. + */ + + if (inet_add_protocol(&icmp_protocol, IPPROTO_ICMP) < 0) + pr_crit("%s: Cannot add ICMP protocol\n", __func__); + if (inet_add_protocol(&udp_protocol, IPPROTO_UDP) < 0) + pr_crit("%s: Cannot add UDP protocol\n", __func__); + if (inet_add_protocol(&tcp_protocol, IPPROTO_TCP) < 0) + pr_crit("%s: Cannot add TCP protocol\n", __func__); +#ifdef CONFIG_IP_MULTICAST + if (inet_add_protocol(&igmp_protocol, IPPROTO_IGMP) < 0) + pr_crit("%s: Cannot add IGMP protocol\n", __func__); +#endif + + /* Register the socket-side information for inet_create. */ + for (r = &inetsw[0]; r < &inetsw[SOCK_MAX]; ++r) + INIT_LIST_HEAD(r); + + for (q = inetsw_array; q < &inetsw_array[INETSW_ARRAY_LEN]; ++q) + inet_register_protosw(q); + + /* + * Set the ARP module up + */ + + arp_init(); + + /* + * Set the IP module up + */ + + ip_init(); + + /* Initialise per-cpu ipv4 mibs */ + if (init_ipv4_mibs()) + panic("%s: Cannot init ipv4 mibs\n", __func__); + + /* Setup TCP slab cache for open requests. */ + tcp_init(); + + /* Setup UDP memory threshold */ + udp_init(); + + /* Add UDP-Lite (RFC 3828) */ + udplite4_register(); + + raw_init(); + + ping_init(); + + /* + * Set the ICMP layer up + */ + + if (icmp_init() < 0) + panic("Failed to create the ICMP control socket.\n"); + + /* + * Initialise the multicast router + */ +#if defined(CONFIG_IP_MROUTE) + if (ip_mr_init()) + pr_crit("%s: Cannot init ipv4 mroute\n", __func__); +#endif + + if (init_inet_pernet_ops()) + pr_crit("%s: Cannot init ipv4 inet pernet ops\n", __func__); + + ipv4_proc_init(); + + ipfrag_init(); + + dev_add_pack(&ip_packet_type); + + ip_tunnel_core_init(); + + rc = 0; +out: + return rc; +out_unregister_raw_proto: + proto_unregister(&raw_prot); +out_unregister_udp_proto: + proto_unregister(&udp_prot); +out_unregister_tcp_proto: + proto_unregister(&tcp_prot); + goto out; +} + +fs_initcall(inet_init); + +/* ------------------------------------------------------------------------ */ + +#ifdef CONFIG_PROC_FS +static int __init ipv4_proc_init(void) +{ + int rc = 0; + + if (raw_proc_init()) + goto out_raw; + if (tcp4_proc_init()) + goto out_tcp; + if (udp4_proc_init()) + goto out_udp; + if (ping_proc_init()) + goto out_ping; + if (ip_misc_proc_init()) + goto out_misc; +out: + return rc; +out_misc: + ping_proc_exit(); +out_ping: + udp4_proc_exit(); +out_udp: + tcp4_proc_exit(); +out_tcp: + raw_proc_exit(); +out_raw: + rc = -ENOMEM; + goto out; +} + +#else /* CONFIG_PROC_FS */ +static int __init ipv4_proc_init(void) +{ + return 0; +} +#endif /* CONFIG_PROC_FS */ diff --git a/net/ipv4/ah4.c b/net/ipv4/ah4.c new file mode 100644 index 000000000..ee4e578c7 --- /dev/null +++ b/net/ipv4/ah4.c @@ -0,0 +1,604 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) "IPsec: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ah_skb_cb { + struct xfrm_skb_cb xfrm; + void *tmp; +}; + +#define AH_SKB_CB(__skb) ((struct ah_skb_cb *)&((__skb)->cb[0])) + +static void *ah_alloc_tmp(struct crypto_ahash *ahash, int nfrags, + unsigned int size) +{ + unsigned int len; + + len = size + crypto_ahash_digestsize(ahash) + + (crypto_ahash_alignmask(ahash) & + ~(crypto_tfm_ctx_alignment() - 1)); + + len = ALIGN(len, crypto_tfm_ctx_alignment()); + + len += sizeof(struct ahash_request) + crypto_ahash_reqsize(ahash); + len 
= ALIGN(len, __alignof__(struct scatterlist)); + + len += sizeof(struct scatterlist) * nfrags; + + return kmalloc(len, GFP_ATOMIC); +} + +static inline u8 *ah_tmp_auth(void *tmp, unsigned int offset) +{ + return tmp + offset; +} + +static inline u8 *ah_tmp_icv(struct crypto_ahash *ahash, void *tmp, + unsigned int offset) +{ + return PTR_ALIGN((u8 *)tmp + offset, crypto_ahash_alignmask(ahash) + 1); +} + +static inline struct ahash_request *ah_tmp_req(struct crypto_ahash *ahash, + u8 *icv) +{ + struct ahash_request *req; + + req = (void *)PTR_ALIGN(icv + crypto_ahash_digestsize(ahash), + crypto_tfm_ctx_alignment()); + + ahash_request_set_tfm(req, ahash); + + return req; +} + +static inline struct scatterlist *ah_req_sg(struct crypto_ahash *ahash, + struct ahash_request *req) +{ + return (void *)ALIGN((unsigned long)(req + 1) + + crypto_ahash_reqsize(ahash), + __alignof__(struct scatterlist)); +} + +/* Clear mutable options and find final destination to substitute + * into IP header for icv calculation. Options are already checked + * for validity, so paranoia is not required. */ + +static int ip_clear_mutable_options(const struct iphdr *iph, __be32 *daddr) +{ + unsigned char *optptr = (unsigned char *)(iph+1); + int l = iph->ihl*4 - sizeof(struct iphdr); + int optlen; + + while (l > 0) { + switch (*optptr) { + case IPOPT_END: + return 0; + case IPOPT_NOOP: + l--; + optptr++; + continue; + } + optlen = optptr[1]; + if (optlen<2 || optlen>l) + return -EINVAL; + switch (*optptr) { + case IPOPT_SEC: + case 0x85: /* Some "Extended Security" crap. */ + case IPOPT_CIPSO: + case IPOPT_RA: + case 0x80|21: /* RFC1770 */ + break; + case IPOPT_LSRR: + case IPOPT_SSRR: + if (optlen < 6) + return -EINVAL; + memcpy(daddr, optptr+optlen-4, 4); + fallthrough; + default: + memset(optptr, 0, optlen); + } + l -= optlen; + optptr += optlen; + } + return 0; +} + +static void ah_output_done(struct crypto_async_request *base, int err) +{ + u8 *icv; + struct iphdr *iph; + struct sk_buff *skb = base->data; + struct xfrm_state *x = skb_dst(skb)->xfrm; + struct ah_data *ahp = x->data; + struct iphdr *top_iph = ip_hdr(skb); + struct ip_auth_hdr *ah = ip_auth_hdr(skb); + int ihl = ip_hdrlen(skb); + + iph = AH_SKB_CB(skb)->tmp; + icv = ah_tmp_icv(ahp->ahash, iph, ihl); + memcpy(ah->auth_data, icv, ahp->icv_trunc_len); + + top_iph->tos = iph->tos; + top_iph->ttl = iph->ttl; + top_iph->frag_off = iph->frag_off; + if (top_iph->ihl != 5) { + top_iph->daddr = iph->daddr; + memcpy(top_iph+1, iph+1, top_iph->ihl*4 - sizeof(struct iphdr)); + } + + kfree(AH_SKB_CB(skb)->tmp); + xfrm_output_resume(skb->sk, skb, err); +} + +static int ah_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + int nfrags; + int ihl; + u8 *icv; + struct sk_buff *trailer; + struct crypto_ahash *ahash; + struct ahash_request *req; + struct scatterlist *sg; + struct iphdr *iph, *top_iph; + struct ip_auth_hdr *ah; + struct ah_data *ahp; + int seqhi_len = 0; + __be32 *seqhi; + int sglists = 0; + struct scatterlist *seqhisg; + + ahp = x->data; + ahash = ahp->ahash; + + if ((err = skb_cow_data(skb, 0, &trailer)) < 0) + goto out; + nfrags = err; + + skb_push(skb, -skb_network_offset(skb)); + ah = ip_auth_hdr(skb); + ihl = ip_hdrlen(skb); + + if (x->props.flags & XFRM_STATE_ESN) { + sglists = 1; + seqhi_len = sizeof(*seqhi); + } + err = -ENOMEM; + iph = ah_alloc_tmp(ahash, nfrags + sglists, ihl + seqhi_len); + if (!iph) + goto out; + seqhi = (__be32 *)((char *)iph + ihl); + icv = ah_tmp_icv(ahash, seqhi, seqhi_len); + req = ah_tmp_req(ahash, icv); + 
sg = ah_req_sg(ahash, req); + seqhisg = sg + nfrags; + + memset(ah->auth_data, 0, ahp->icv_trunc_len); + + top_iph = ip_hdr(skb); + + iph->tos = top_iph->tos; + iph->ttl = top_iph->ttl; + iph->frag_off = top_iph->frag_off; + + if (top_iph->ihl != 5) { + iph->daddr = top_iph->daddr; + memcpy(iph+1, top_iph+1, top_iph->ihl*4 - sizeof(struct iphdr)); + err = ip_clear_mutable_options(top_iph, &top_iph->daddr); + if (err) + goto out_free; + } + + ah->nexthdr = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_AH; + + top_iph->tos = 0; + top_iph->tot_len = htons(skb->len); + top_iph->frag_off = 0; + top_iph->ttl = 0; + top_iph->check = 0; + + if (x->props.flags & XFRM_STATE_ALIGN4) + ah->hdrlen = (XFRM_ALIGN4(sizeof(*ah) + ahp->icv_trunc_len) >> 2) - 2; + else + ah->hdrlen = (XFRM_ALIGN8(sizeof(*ah) + ahp->icv_trunc_len) >> 2) - 2; + + ah->reserved = 0; + ah->spi = x->id.spi; + ah->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + + sg_init_table(sg, nfrags + sglists); + err = skb_to_sgvec_nomark(skb, sg, 0, skb->len); + if (unlikely(err < 0)) + goto out_free; + + if (x->props.flags & XFRM_STATE_ESN) { + /* Attach seqhi sg right after packet payload */ + *seqhi = htonl(XFRM_SKB_CB(skb)->seq.output.hi); + sg_set_buf(seqhisg, seqhi, seqhi_len); + } + ahash_request_set_crypt(req, sg, icv, skb->len + seqhi_len); + ahash_request_set_callback(req, 0, ah_output_done, skb); + + AH_SKB_CB(skb)->tmp = iph; + + err = crypto_ahash_digest(req); + if (err) { + if (err == -EINPROGRESS) + goto out; + + if (err == -ENOSPC) + err = NET_XMIT_DROP; + goto out_free; + } + + memcpy(ah->auth_data, icv, ahp->icv_trunc_len); + + top_iph->tos = iph->tos; + top_iph->ttl = iph->ttl; + top_iph->frag_off = iph->frag_off; + if (top_iph->ihl != 5) { + top_iph->daddr = iph->daddr; + memcpy(top_iph+1, iph+1, top_iph->ihl*4 - sizeof(struct iphdr)); + } + +out_free: + kfree(iph); +out: + return err; +} + +static void ah_input_done(struct crypto_async_request *base, int err) +{ + u8 *auth_data; + u8 *icv; + struct iphdr *work_iph; + struct sk_buff *skb = base->data; + struct xfrm_state *x = xfrm_input_state(skb); + struct ah_data *ahp = x->data; + struct ip_auth_hdr *ah = ip_auth_hdr(skb); + int ihl = ip_hdrlen(skb); + int ah_hlen = (ah->hdrlen + 2) << 2; + + if (err) + goto out; + + work_iph = AH_SKB_CB(skb)->tmp; + auth_data = ah_tmp_auth(work_iph, ihl); + icv = ah_tmp_icv(ahp->ahash, auth_data, ahp->icv_trunc_len); + + err = crypto_memneq(icv, auth_data, ahp->icv_trunc_len) ? 
-EBADMSG : 0; + if (err) + goto out; + + err = ah->nexthdr; + + skb->network_header += ah_hlen; + memcpy(skb_network_header(skb), work_iph, ihl); + __skb_pull(skb, ah_hlen + ihl); + + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -ihl); +out: + kfree(AH_SKB_CB(skb)->tmp); + xfrm_input_resume(skb, err); +} + +static int ah_input(struct xfrm_state *x, struct sk_buff *skb) +{ + int ah_hlen; + int ihl; + int nexthdr; + int nfrags; + u8 *auth_data; + u8 *icv; + struct sk_buff *trailer; + struct crypto_ahash *ahash; + struct ahash_request *req; + struct scatterlist *sg; + struct iphdr *iph, *work_iph; + struct ip_auth_hdr *ah; + struct ah_data *ahp; + int err = -ENOMEM; + int seqhi_len = 0; + __be32 *seqhi; + int sglists = 0; + struct scatterlist *seqhisg; + + if (!pskb_may_pull(skb, sizeof(*ah))) + goto out; + + ah = (struct ip_auth_hdr *)skb->data; + ahp = x->data; + ahash = ahp->ahash; + + nexthdr = ah->nexthdr; + ah_hlen = (ah->hdrlen + 2) << 2; + + if (x->props.flags & XFRM_STATE_ALIGN4) { + if (ah_hlen != XFRM_ALIGN4(sizeof(*ah) + ahp->icv_full_len) && + ah_hlen != XFRM_ALIGN4(sizeof(*ah) + ahp->icv_trunc_len)) + goto out; + } else { + if (ah_hlen != XFRM_ALIGN8(sizeof(*ah) + ahp->icv_full_len) && + ah_hlen != XFRM_ALIGN8(sizeof(*ah) + ahp->icv_trunc_len)) + goto out; + } + + if (!pskb_may_pull(skb, ah_hlen)) + goto out; + + /* We are going to _remove_ AH header to keep sockets happy, + * so... Later this can change. */ + if (skb_unclone(skb, GFP_ATOMIC)) + goto out; + + skb->ip_summed = CHECKSUM_NONE; + + + if ((err = skb_cow_data(skb, 0, &trailer)) < 0) + goto out; + nfrags = err; + + ah = (struct ip_auth_hdr *)skb->data; + iph = ip_hdr(skb); + ihl = ip_hdrlen(skb); + + if (x->props.flags & XFRM_STATE_ESN) { + sglists = 1; + seqhi_len = sizeof(*seqhi); + } + + work_iph = ah_alloc_tmp(ahash, nfrags + sglists, ihl + + ahp->icv_trunc_len + seqhi_len); + if (!work_iph) { + err = -ENOMEM; + goto out; + } + + seqhi = (__be32 *)((char *)work_iph + ihl); + auth_data = ah_tmp_auth(seqhi, seqhi_len); + icv = ah_tmp_icv(ahash, auth_data, ahp->icv_trunc_len); + req = ah_tmp_req(ahash, icv); + sg = ah_req_sg(ahash, req); + seqhisg = sg + nfrags; + + memcpy(work_iph, iph, ihl); + memcpy(auth_data, ah->auth_data, ahp->icv_trunc_len); + memset(ah->auth_data, 0, ahp->icv_trunc_len); + + iph->ttl = 0; + iph->tos = 0; + iph->frag_off = 0; + iph->check = 0; + if (ihl > sizeof(*iph)) { + __be32 dummy; + err = ip_clear_mutable_options(iph, &dummy); + if (err) + goto out_free; + } + + skb_push(skb, ihl); + + sg_init_table(sg, nfrags + sglists); + err = skb_to_sgvec_nomark(skb, sg, 0, skb->len); + if (unlikely(err < 0)) + goto out_free; + + if (x->props.flags & XFRM_STATE_ESN) { + /* Attach seqhi sg right after packet payload */ + *seqhi = XFRM_SKB_CB(skb)->seq.input.hi; + sg_set_buf(seqhisg, seqhi, seqhi_len); + } + ahash_request_set_crypt(req, sg, icv, skb->len + seqhi_len); + ahash_request_set_callback(req, 0, ah_input_done, skb); + + AH_SKB_CB(skb)->tmp = work_iph; + + err = crypto_ahash_digest(req); + if (err) { + if (err == -EINPROGRESS) + goto out; + + goto out_free; + } + + err = crypto_memneq(icv, auth_data, ahp->icv_trunc_len) ? 
-EBADMSG : 0; + if (err) + goto out_free; + + skb->network_header += ah_hlen; + memcpy(skb_network_header(skb), work_iph, ihl); + __skb_pull(skb, ah_hlen + ihl); + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -ihl); + + err = nexthdr; + +out_free: + kfree (work_iph); +out: + return err; +} + +static int ah4_err(struct sk_buff *skb, u32 info) +{ + struct net *net = dev_net(skb->dev); + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct ip_auth_hdr *ah = (struct ip_auth_hdr *)(skb->data+(iph->ihl<<2)); + struct xfrm_state *x; + + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED) + return 0; + break; + case ICMP_REDIRECT: + break; + default: + return 0; + } + + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + ah->spi, IPPROTO_AH, AF_INET); + if (!x) + return 0; + + if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH) + ipv4_update_pmtu(skb, net, info, 0, IPPROTO_AH); + else + ipv4_redirect(skb, net, 0, IPPROTO_AH); + xfrm_state_put(x); + + return 0; +} + +static int ah_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + struct ah_data *ahp = NULL; + struct xfrm_algo_desc *aalg_desc; + struct crypto_ahash *ahash; + + if (!x->aalg) { + NL_SET_ERR_MSG(extack, "AH requires a state with an AUTH algorithm"); + goto error; + } + + if (x->encap) { + NL_SET_ERR_MSG(extack, "AH is not compatible with encapsulation"); + goto error; + } + + ahp = kzalloc(sizeof(*ahp), GFP_KERNEL); + if (!ahp) + return -ENOMEM; + + ahash = crypto_alloc_ahash(x->aalg->alg_name, 0, 0); + if (IS_ERR(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + ahp->ahash = ahash; + if (crypto_ahash_setkey(ahash, x->aalg->alg_key, + (x->aalg->alg_key_len + 7) / 8)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + /* + * Lookup the algorithm description maintained by xfrm_algo, + * verify crypto transform properties, and store information + * we need for AH processing. This lookup cannot fail here + * after a successful crypto_alloc_ahash(). 
+ */ + aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0); + BUG_ON(!aalg_desc); + + if (aalg_desc->uinfo.auth.icv_fullbits/8 != + crypto_ahash_digestsize(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + ahp->icv_full_len = aalg_desc->uinfo.auth.icv_fullbits/8; + ahp->icv_trunc_len = x->aalg->alg_trunc_len/8; + + if (x->props.flags & XFRM_STATE_ALIGN4) + x->props.header_len = XFRM_ALIGN4(sizeof(struct ip_auth_hdr) + + ahp->icv_trunc_len); + else + x->props.header_len = XFRM_ALIGN8(sizeof(struct ip_auth_hdr) + + ahp->icv_trunc_len); + if (x->props.mode == XFRM_MODE_TUNNEL) + x->props.header_len += sizeof(struct iphdr); + x->data = ahp; + + return 0; + +error: + if (ahp) { + crypto_free_ahash(ahp->ahash); + kfree(ahp); + } + return -EINVAL; +} + +static void ah_destroy(struct xfrm_state *x) +{ + struct ah_data *ahp = x->data; + + if (!ahp) + return; + + crypto_free_ahash(ahp->ahash); + kfree(ahp); +} + +static int ah4_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type ah_type = +{ + .owner = THIS_MODULE, + .proto = IPPROTO_AH, + .flags = XFRM_TYPE_REPLAY_PROT, + .init_state = ah_init_state, + .destructor = ah_destroy, + .input = ah_input, + .output = ah_output +}; + +static struct xfrm4_protocol ah4_protocol = { + .handler = xfrm4_rcv, + .input_handler = xfrm_input, + .cb_handler = ah4_rcv_cb, + .err_handler = ah4_err, + .priority = 0, +}; + +static int __init ah4_init(void) +{ + if (xfrm_register_type(&ah_type, AF_INET) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + if (xfrm4_protocol_register(&ah4_protocol, IPPROTO_AH) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&ah_type, AF_INET); + return -EAGAIN; + } + return 0; +} + +static void __exit ah4_fini(void) +{ + if (xfrm4_protocol_deregister(&ah4_protocol, IPPROTO_AH) < 0) + pr_info("%s: can't remove protocol\n", __func__); + xfrm_unregister_type(&ah_type, AF_INET); +} + +module_init(ah4_init); +module_exit(ah4_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_AH); diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c new file mode 100644 index 000000000..9456f5bb3 --- /dev/null +++ b/net/ipv4/arp.c @@ -0,0 +1,1472 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* linux/net/ipv4/arp.c + * + * Copyright (C) 1994 by Florian La Roche + * + * This module implements the Address Resolution Protocol ARP (RFC 826), + * which is used to convert IP addresses (or in the future maybe other + * high-level addresses) into a low-level hardware address (like an Ethernet + * address). + * + * Fixes: + * Alan Cox : Removed the Ethernet assumptions in + * Florian's code + * Alan Cox : Fixed some small errors in the ARP + * logic + * Alan Cox : Allow >4K in /proc + * Alan Cox : Make ARP add its own protocol entry + * Ross Martin : Rewrote arp_rcv() and arp_get_info() + * Stephen Henson : Add AX25 support to arp_get_info() + * Alan Cox : Drop data when a device is downed. + * Alan Cox : Use init_timer(). + * Alan Cox : Double lock fixes. + * Martin Seine : Move the arphdr structure + * to if_arp.h for compatibility. + * with BSD based programs. + * Andrew Tridgell : Added ARP netmask code and + * re-arranged proxy handling. + * Alan Cox : Changed to use notifiers. + * Niibe Yutaka : Reply for this device or proxies only. + * Alan Cox : Don't proxy across hardware types! + * Jonathan Naylor : Added support for NET/ROM. + * Mike Shaver : RFC1122 checks. 
+ * Jonathan Naylor : Only lookup the hardware address for + * the correct hardware type. + * Germano Caronni : Assorted subtle races. + * Craig Schlenter : Don't modify permanent entry + * during arp_rcv. + * Russ Nelson : Tidied up a few bits. + * Alexey Kuznetsov: Major changes to caching and behaviour, + * eg intelligent arp probing and + * generation + * of host down events. + * Alan Cox : Missing unlock in device events. + * Eckes : ARP ioctl control errors. + * Alexey Kuznetsov: Arp free fix. + * Manuel Rodriguez: Gratuitous ARP. + * Jonathan Layes : Added arpd support through kerneld + * message queue (960314) + * Mike Shaver : /proc/sys/net/ipv4/arp_* support + * Mike McLagan : Routing by source + * Stuart Cheshire : Metricom and grat arp fixes + * *** FOR 2.1 clean this up *** + * Lawrence V. Stefani: (08/12/96) Added FDDI support. + * Alan Cox : Took the AP1000 nasty FDDI hack and + * folded into the mainstream FDDI code. + * Ack spit, Linus how did you allow that + * one in... + * Jes Sorensen : Make FDDI work again in 2.1.x and + * clean up the APFDDI & gen. FDDI bits. + * Alexey Kuznetsov: new arp state machine; + * now it is in net/core/neighbour.c. + * Krzysztof Halasa: Added Frame Relay ARP support. + * Arnaldo C. Melo : convert /proc/net/arp to seq_file + * Shmulik Hen: Split arp_send to arp_create and + * arp_xmit so intermediate drivers like + * bonding can change the skb before + * sending (e.g. insert 8021q tag). + * Harald Welte : convert to make use of jenkins hash + * Jesper D. Brouer: Proxy ARP PVLAN RFC 3069 support. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +/* + * Interface to generic neighbour cache. 
+ */ +static u32 arp_hash(const void *pkey, const struct net_device *dev, __u32 *hash_rnd); +static bool arp_key_eq(const struct neighbour *n, const void *pkey); +static int arp_constructor(struct neighbour *neigh); +static void arp_solicit(struct neighbour *neigh, struct sk_buff *skb); +static void arp_error_report(struct neighbour *neigh, struct sk_buff *skb); +static void parp_redo(struct sk_buff *skb); +static int arp_is_multicast(const void *pkey); + +static const struct neigh_ops arp_generic_ops = { + .family = AF_INET, + .solicit = arp_solicit, + .error_report = arp_error_report, + .output = neigh_resolve_output, + .connected_output = neigh_connected_output, +}; + +static const struct neigh_ops arp_hh_ops = { + .family = AF_INET, + .solicit = arp_solicit, + .error_report = arp_error_report, + .output = neigh_resolve_output, + .connected_output = neigh_resolve_output, +}; + +static const struct neigh_ops arp_direct_ops = { + .family = AF_INET, + .output = neigh_direct_output, + .connected_output = neigh_direct_output, +}; + +struct neigh_table arp_tbl = { + .family = AF_INET, + .key_len = 4, + .protocol = cpu_to_be16(ETH_P_IP), + .hash = arp_hash, + .key_eq = arp_key_eq, + .constructor = arp_constructor, + .proxy_redo = parp_redo, + .is_multicast = arp_is_multicast, + .id = "arp_cache", + .parms = { + .tbl = &arp_tbl, + .reachable_time = 30 * HZ, + .data = { + [NEIGH_VAR_MCAST_PROBES] = 3, + [NEIGH_VAR_UCAST_PROBES] = 3, + [NEIGH_VAR_RETRANS_TIME] = 1 * HZ, + [NEIGH_VAR_BASE_REACHABLE_TIME] = 30 * HZ, + [NEIGH_VAR_DELAY_PROBE_TIME] = 5 * HZ, + [NEIGH_VAR_INTERVAL_PROBE_TIME_MS] = 5 * HZ, + [NEIGH_VAR_GC_STALETIME] = 60 * HZ, + [NEIGH_VAR_QUEUE_LEN_BYTES] = SK_WMEM_MAX, + [NEIGH_VAR_PROXY_QLEN] = 64, + [NEIGH_VAR_ANYCAST_DELAY] = 1 * HZ, + [NEIGH_VAR_PROXY_DELAY] = (8 * HZ) / 10, + [NEIGH_VAR_LOCKTIME] = 1 * HZ, + }, + }, + .gc_interval = 30 * HZ, + .gc_thresh1 = 128, + .gc_thresh2 = 512, + .gc_thresh3 = 1024, +}; +EXPORT_SYMBOL(arp_tbl); + +int arp_mc_map(__be32 addr, u8 *haddr, struct net_device *dev, int dir) +{ + switch (dev->type) { + case ARPHRD_ETHER: + case ARPHRD_FDDI: + case ARPHRD_IEEE802: + ip_eth_mc_map(addr, haddr); + return 0; + case ARPHRD_INFINIBAND: + ip_ib_mc_map(addr, dev->broadcast, haddr); + return 0; + case ARPHRD_IPGRE: + ip_ipgre_mc_map(addr, dev->broadcast, haddr); + return 0; + default: + if (dir) { + memcpy(haddr, dev->broadcast, dev->addr_len); + return 0; + } + } + return -EINVAL; +} + + +static u32 arp_hash(const void *pkey, + const struct net_device *dev, + __u32 *hash_rnd) +{ + return arp_hashfn(pkey, dev, hash_rnd); +} + +static bool arp_key_eq(const struct neighbour *neigh, const void *pkey) +{ + return neigh_key_eq32(neigh, pkey); +} + +static int arp_constructor(struct neighbour *neigh) +{ + __be32 addr; + struct net_device *dev = neigh->dev; + struct in_device *in_dev; + struct neigh_parms *parms; + u32 inaddr_any = INADDR_ANY; + + if (dev->flags & (IFF_LOOPBACK | IFF_POINTOPOINT)) + memcpy(neigh->primary_key, &inaddr_any, arp_tbl.key_len); + + addr = *(__be32 *)neigh->primary_key; + rcu_read_lock(); + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) { + rcu_read_unlock(); + return -EINVAL; + } + + neigh->type = inet_addr_type_dev_table(dev_net(dev), dev, addr); + + parms = in_dev->arp_parms; + __neigh_parms_put(neigh->parms); + neigh->parms = neigh_parms_clone(parms); + rcu_read_unlock(); + + if (!dev->header_ops) { + neigh->nud_state = NUD_NOARP; + neigh->ops = &arp_direct_ops; + neigh->output = neigh_direct_output; + } else { + /* Good devices 
(checked by reading texts, but only Ethernet is + tested) + + ARPHRD_ETHER: (ethernet, apfddi) + ARPHRD_FDDI: (fddi) + ARPHRD_IEEE802: (tr) + ARPHRD_METRICOM: (strip) + ARPHRD_ARCNET: + etc. etc. etc. + + ARPHRD_IPDDP will also work, if author repairs it. + I did not it, because this driver does not work even + in old paradigm. + */ + + if (neigh->type == RTN_MULTICAST) { + neigh->nud_state = NUD_NOARP; + arp_mc_map(addr, neigh->ha, dev, 1); + } else if (dev->flags & (IFF_NOARP | IFF_LOOPBACK)) { + neigh->nud_state = NUD_NOARP; + memcpy(neigh->ha, dev->dev_addr, dev->addr_len); + } else if (neigh->type == RTN_BROADCAST || + (dev->flags & IFF_POINTOPOINT)) { + neigh->nud_state = NUD_NOARP; + memcpy(neigh->ha, dev->broadcast, dev->addr_len); + } + + if (dev->header_ops->cache) + neigh->ops = &arp_hh_ops; + else + neigh->ops = &arp_generic_ops; + + if (neigh->nud_state & NUD_VALID) + neigh->output = neigh->ops->connected_output; + else + neigh->output = neigh->ops->output; + } + return 0; +} + +static void arp_error_report(struct neighbour *neigh, struct sk_buff *skb) +{ + dst_link_failure(skb); + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_FAILED); +} + +/* Create and send an arp packet. */ +static void arp_send_dst(int type, int ptype, __be32 dest_ip, + struct net_device *dev, __be32 src_ip, + const unsigned char *dest_hw, + const unsigned char *src_hw, + const unsigned char *target_hw, + struct dst_entry *dst) +{ + struct sk_buff *skb; + + /* arp on this interface. */ + if (dev->flags & IFF_NOARP) + return; + + skb = arp_create(type, ptype, dest_ip, dev, src_ip, + dest_hw, src_hw, target_hw); + if (!skb) + return; + + skb_dst_set(skb, dst_clone(dst)); + arp_xmit(skb); +} + +void arp_send(int type, int ptype, __be32 dest_ip, + struct net_device *dev, __be32 src_ip, + const unsigned char *dest_hw, const unsigned char *src_hw, + const unsigned char *target_hw) +{ + arp_send_dst(type, ptype, dest_ip, dev, src_ip, dest_hw, src_hw, + target_hw, NULL); +} +EXPORT_SYMBOL(arp_send); + +static void arp_solicit(struct neighbour *neigh, struct sk_buff *skb) +{ + __be32 saddr = 0; + u8 dst_ha[MAX_ADDR_LEN], *dst_hw = NULL; + struct net_device *dev = neigh->dev; + __be32 target = *(__be32 *)neigh->primary_key; + int probes = atomic_read(&neigh->probes); + struct in_device *in_dev; + struct dst_entry *dst = NULL; + + rcu_read_lock(); + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) { + rcu_read_unlock(); + return; + } + switch (IN_DEV_ARP_ANNOUNCE(in_dev)) { + default: + case 0: /* By default announce any local IP */ + if (skb && inet_addr_type_dev_table(dev_net(dev), dev, + ip_hdr(skb)->saddr) == RTN_LOCAL) + saddr = ip_hdr(skb)->saddr; + break; + case 1: /* Restrict announcements of saddr in same subnet */ + if (!skb) + break; + saddr = ip_hdr(skb)->saddr; + if (inet_addr_type_dev_table(dev_net(dev), dev, + saddr) == RTN_LOCAL) { + /* saddr should be known to target */ + if (inet_addr_onlink(in_dev, target, saddr)) + break; + } + saddr = 0; + break; + case 2: /* Avoid secondary IPs, get a primary/preferred one */ + break; + } + rcu_read_unlock(); + + if (!saddr) + saddr = inet_select_addr(dev, target, RT_SCOPE_LINK); + + probes -= NEIGH_VAR(neigh->parms, UCAST_PROBES); + if (probes < 0) { + if (!(READ_ONCE(neigh->nud_state) & NUD_VALID)) + pr_debug("trying to ucast probe in NUD_INVALID\n"); + neigh_ha_snapshot(dst_ha, neigh, dev); + dst_hw = dst_ha; + } else { + probes -= NEIGH_VAR(neigh->parms, APP_PROBES); + if (probes < 0) { + neigh_app_ns(neigh); + return; + } + } + + if (skb && !(dev->priv_flags & 
IFF_XMIT_DST_RELEASE)) + dst = skb_dst(skb); + arp_send_dst(ARPOP_REQUEST, ETH_P_ARP, target, dev, saddr, + dst_hw, dev->dev_addr, NULL, dst); +} + +static int arp_ignore(struct in_device *in_dev, __be32 sip, __be32 tip) +{ + struct net *net = dev_net(in_dev->dev); + int scope; + + switch (IN_DEV_ARP_IGNORE(in_dev)) { + case 0: /* Reply, the tip is already validated */ + return 0; + case 1: /* Reply only if tip is configured on the incoming interface */ + sip = 0; + scope = RT_SCOPE_HOST; + break; + case 2: /* + * Reply only if tip is configured on the incoming interface + * and is in same subnet as sip + */ + scope = RT_SCOPE_HOST; + break; + case 3: /* Do not reply for scope host addresses */ + sip = 0; + scope = RT_SCOPE_LINK; + in_dev = NULL; + break; + case 4: /* Reserved */ + case 5: + case 6: + case 7: + return 0; + case 8: /* Do not reply */ + return 1; + default: + return 0; + } + return !inet_confirm_addr(net, in_dev, sip, tip, scope); +} + +static int arp_accept(struct in_device *in_dev, __be32 sip) +{ + struct net *net = dev_net(in_dev->dev); + int scope = RT_SCOPE_LINK; + + switch (IN_DEV_ARP_ACCEPT(in_dev)) { + case 0: /* Don't create new entries from garp */ + return 0; + case 1: /* Create new entries from garp */ + return 1; + case 2: /* Create a neighbor in the arp table only if sip + * is in the same subnet as an address configured + * on the interface that received the garp message + */ + return !!inet_confirm_addr(net, in_dev, sip, 0, scope); + default: + return 0; + } +} + +static int arp_filter(__be32 sip, __be32 tip, struct net_device *dev) +{ + struct rtable *rt; + int flag = 0; + /*unsigned long now; */ + struct net *net = dev_net(dev); + + rt = ip_route_output(net, sip, tip, 0, l3mdev_master_ifindex_rcu(dev)); + if (IS_ERR(rt)) + return 1; + if (rt->dst.dev != dev) { + __NET_INC_STATS(net, LINUX_MIB_ARPFILTER); + flag = 1; + } + ip_rt_put(rt); + return flag; +} + +/* + * Check if we can use proxy ARP for this path + */ +static inline int arp_fwd_proxy(struct in_device *in_dev, + struct net_device *dev, struct rtable *rt) +{ + struct in_device *out_dev; + int imi, omi = -1; + + if (rt->dst.dev == dev) + return 0; + + if (!IN_DEV_PROXY_ARP(in_dev)) + return 0; + imi = IN_DEV_MEDIUM_ID(in_dev); + if (imi == 0) + return 1; + if (imi == -1) + return 0; + + /* place to check for proxy_arp for routes */ + + out_dev = __in_dev_get_rcu(rt->dst.dev); + if (out_dev) + omi = IN_DEV_MEDIUM_ID(out_dev); + + return omi != imi && omi != -1; +} + +/* + * Check for RFC3069 proxy arp private VLAN (allow to send back to same dev) + * + * RFC3069 supports proxy arp replies back to the same interface. This + * is done to support (ethernet) switch features, like RFC 3069, where + * the individual ports are not allowed to communicate with each + * other, BUT they are allowed to talk to the upstream router. As + * described in RFC 3069, it is possible to allow these hosts to + * communicate through the upstream router, by proxy_arp'ing. + * + * RFC 3069: "VLAN Aggregation for Efficient IP Address Allocation" + * + * This technology is known by different names: + * In RFC 3069 it is called VLAN Aggregation. + * Cisco and Allied Telesyn call it Private VLAN. + * Hewlett-Packard call it Source-Port filtering or port-isolation. + * Ericsson call it MAC-Forced Forwarding (RFC Draft). 
+ * + */ +static inline int arp_fwd_pvlan(struct in_device *in_dev, + struct net_device *dev, struct rtable *rt, + __be32 sip, __be32 tip) +{ + /* Private VLAN is only concerned about the same ethernet segment */ + if (rt->dst.dev != dev) + return 0; + + /* Don't reply on self probes (often done by windowz boxes)*/ + if (sip == tip) + return 0; + + if (IN_DEV_PROXY_ARP_PVLAN(in_dev)) + return 1; + else + return 0; +} + +/* + * Interface to link layer: send routine and receive handler. + */ + +/* + * Create an arp packet. If dest_hw is not set, we create a broadcast + * message. + */ +struct sk_buff *arp_create(int type, int ptype, __be32 dest_ip, + struct net_device *dev, __be32 src_ip, + const unsigned char *dest_hw, + const unsigned char *src_hw, + const unsigned char *target_hw) +{ + struct sk_buff *skb; + struct arphdr *arp; + unsigned char *arp_ptr; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + + /* + * Allocate a buffer + */ + + skb = alloc_skb(arp_hdr_len(dev) + hlen + tlen, GFP_ATOMIC); + if (!skb) + return NULL; + + skb_reserve(skb, hlen); + skb_reset_network_header(skb); + arp = skb_put(skb, arp_hdr_len(dev)); + skb->dev = dev; + skb->protocol = htons(ETH_P_ARP); + if (!src_hw) + src_hw = dev->dev_addr; + if (!dest_hw) + dest_hw = dev->broadcast; + + /* + * Fill the device header for the ARP frame + */ + if (dev_hard_header(skb, dev, ptype, dest_hw, src_hw, skb->len) < 0) + goto out; + + /* + * Fill out the arp protocol part. + * + * The arp hardware type should match the device type, except for FDDI, + * which (according to RFC 1390) should always equal 1 (Ethernet). + */ + /* + * Exceptions everywhere. AX.25 uses the AX.25 PID value not the + * DIX code for the protocol. Make these device structure fields. + */ + switch (dev->type) { + default: + arp->ar_hrd = htons(dev->type); + arp->ar_pro = htons(ETH_P_IP); + break; + +#if IS_ENABLED(CONFIG_AX25) + case ARPHRD_AX25: + arp->ar_hrd = htons(ARPHRD_AX25); + arp->ar_pro = htons(AX25_P_IP); + break; + +#if IS_ENABLED(CONFIG_NETROM) + case ARPHRD_NETROM: + arp->ar_hrd = htons(ARPHRD_NETROM); + arp->ar_pro = htons(AX25_P_IP); + break; +#endif +#endif + +#if IS_ENABLED(CONFIG_FDDI) + case ARPHRD_FDDI: + arp->ar_hrd = htons(ARPHRD_ETHER); + arp->ar_pro = htons(ETH_P_IP); + break; +#endif + } + + arp->ar_hln = dev->addr_len; + arp->ar_pln = 4; + arp->ar_op = htons(type); + + arp_ptr = (unsigned char *)(arp + 1); + + memcpy(arp_ptr, src_hw, dev->addr_len); + arp_ptr += dev->addr_len; + memcpy(arp_ptr, &src_ip, 4); + arp_ptr += 4; + + switch (dev->type) { +#if IS_ENABLED(CONFIG_FIREWIRE_NET) + case ARPHRD_IEEE1394: + break; +#endif + default: + if (target_hw) + memcpy(arp_ptr, target_hw, dev->addr_len); + else + memset(arp_ptr, 0, dev->addr_len); + arp_ptr += dev->addr_len; + } + memcpy(arp_ptr, &dest_ip, 4); + + return skb; + +out: + kfree_skb(skb); + return NULL; +} +EXPORT_SYMBOL(arp_create); + +static int arp_xmit_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + return dev_queue_xmit(skb); +} + +/* + * Send an arp packet. + */ +void arp_xmit(struct sk_buff *skb) +{ + /* Send it off, maybe filter it using firewalling first. 
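+	   (NF_ARP_OUT below is the hook where arptables and the nftables
+	   arp-family output rules can still filter the frame before it is
+	   handed to dev_queue_xmit() in arp_xmit_finish())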
*/ + NF_HOOK(NFPROTO_ARP, NF_ARP_OUT, + dev_net(skb->dev), NULL, skb, NULL, skb->dev, + arp_xmit_finish); +} +EXPORT_SYMBOL(arp_xmit); + +static bool arp_is_garp(struct net *net, struct net_device *dev, + int *addr_type, __be16 ar_op, + __be32 sip, __be32 tip, + unsigned char *sha, unsigned char *tha) +{ + bool is_garp = tip == sip; + + /* Gratuitous ARP _replies_ also require target hwaddr to be + * the same as source. + */ + if (is_garp && ar_op == htons(ARPOP_REPLY)) + is_garp = + /* IPv4 over IEEE 1394 doesn't provide target + * hardware address field in its ARP payload. + */ + tha && + !memcmp(tha, sha, dev->addr_len); + + if (is_garp) { + *addr_type = inet_addr_type_dev_table(net, dev, sip); + if (*addr_type != RTN_UNICAST) + is_garp = false; + } + return is_garp; +} + +/* + * Process an arp request. + */ + +static int arp_process(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + struct in_device *in_dev = __in_dev_get_rcu(dev); + struct arphdr *arp; + unsigned char *arp_ptr; + struct rtable *rt; + unsigned char *sha; + unsigned char *tha = NULL; + __be32 sip, tip; + u16 dev_type = dev->type; + int addr_type; + struct neighbour *n; + struct dst_entry *reply_dst = NULL; + bool is_garp = false; + + /* arp_rcv below verifies the ARP header and verifies the device + * is ARP'able. + */ + + if (!in_dev) + goto out_free_skb; + + arp = arp_hdr(skb); + + switch (dev_type) { + default: + if (arp->ar_pro != htons(ETH_P_IP) || + htons(dev_type) != arp->ar_hrd) + goto out_free_skb; + break; + case ARPHRD_ETHER: + case ARPHRD_FDDI: + case ARPHRD_IEEE802: + /* + * ETHERNET, and Fibre Channel (which are IEEE 802 + * devices, according to RFC 2625) devices will accept ARP + * hardware types of either 1 (Ethernet) or 6 (IEEE 802.2). + * This is the case also of FDDI, where the RFC 1390 says that + * FDDI devices should accept ARP hardware of (1) Ethernet, + * however, to be more robust, we'll accept both 1 (Ethernet) + * or 6 (IEEE 802.2) + */ + if ((arp->ar_hrd != htons(ARPHRD_ETHER) && + arp->ar_hrd != htons(ARPHRD_IEEE802)) || + arp->ar_pro != htons(ETH_P_IP)) + goto out_free_skb; + break; + case ARPHRD_AX25: + if (arp->ar_pro != htons(AX25_P_IP) || + arp->ar_hrd != htons(ARPHRD_AX25)) + goto out_free_skb; + break; + case ARPHRD_NETROM: + if (arp->ar_pro != htons(AX25_P_IP) || + arp->ar_hrd != htons(ARPHRD_NETROM)) + goto out_free_skb; + break; + } + + /* Understand only these message types */ + + if (arp->ar_op != htons(ARPOP_REPLY) && + arp->ar_op != htons(ARPOP_REQUEST)) + goto out_free_skb; + +/* + * Extract fields + */ + arp_ptr = (unsigned char *)(arp + 1); + sha = arp_ptr; + arp_ptr += dev->addr_len; + memcpy(&sip, arp_ptr, 4); + arp_ptr += 4; + switch (dev_type) { +#if IS_ENABLED(CONFIG_FIREWIRE_NET) + case ARPHRD_IEEE1394: + break; +#endif + default: + tha = arp_ptr; + arp_ptr += dev->addr_len; + } + memcpy(&tip, arp_ptr, 4); +/* + * Check for bad requests for 127.x.x.x and requests for multicast + * addresses. If this is one such, delete it. + */ + if (ipv4_is_multicast(tip) || + (!IN_DEV_ROUTE_LOCALNET(in_dev) && ipv4_is_loopback(tip))) + goto out_free_skb; + + /* + * For some 802.11 wireless deployments (and possibly other networks), + * there will be an ARP proxy and gratuitous ARP frames are attacks + * and thus should not be accepted. 
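+ *
+ * The sip == tip test below treats the frame as gratuitous ARP; it is
+ * dropped only when the drop_gratuitous_arp option is set for the device
+ * or globally (IN_DEV_ORCONF), and that option defaults to off.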
+ */ + if (sip == tip && IN_DEV_ORCONF(in_dev, DROP_GRATUITOUS_ARP)) + goto out_free_skb; + +/* + * Special case: We must set Frame Relay source Q.922 address + */ + if (dev_type == ARPHRD_DLCI) + sha = dev->broadcast; + +/* + * Process entry. The idea here is we want to send a reply if it is a + * request for us or if it is a request for someone else that we hold + * a proxy for. We want to add an entry to our cache if it is a reply + * to us or if it is a request for our address. + * (The assumption for this last is that if someone is requesting our + * address, they are probably intending to talk to us, so it saves time + * if we cache their address. Their address is also probably not in + * our cache, since ours is not in their cache.) + * + * Putting this another way, we only care about replies if they are to + * us, in which case we add them to the cache. For requests, we care + * about those for us and those for our proxies. We reply to both, + * and in the case of requests for us we add the requester to the arp + * cache. + */ + + if (arp->ar_op == htons(ARPOP_REQUEST) && skb_metadata_dst(skb)) + reply_dst = (struct dst_entry *) + iptunnel_metadata_reply(skb_metadata_dst(skb), + GFP_ATOMIC); + + /* Special case: IPv4 duplicate address detection packet (RFC2131) */ + if (sip == 0) { + if (arp->ar_op == htons(ARPOP_REQUEST) && + inet_addr_type_dev_table(net, dev, tip) == RTN_LOCAL && + !arp_ignore(in_dev, sip, tip)) + arp_send_dst(ARPOP_REPLY, ETH_P_ARP, sip, dev, tip, + sha, dev->dev_addr, sha, reply_dst); + goto out_consume_skb; + } + + if (arp->ar_op == htons(ARPOP_REQUEST) && + ip_route_input_noref(skb, tip, sip, 0, dev) == 0) { + + rt = skb_rtable(skb); + addr_type = rt->rt_type; + + if (addr_type == RTN_LOCAL) { + int dont_send; + + dont_send = arp_ignore(in_dev, sip, tip); + if (!dont_send && IN_DEV_ARPFILTER(in_dev)) + dont_send = arp_filter(sip, tip, dev); + if (!dont_send) { + n = neigh_event_ns(&arp_tbl, sha, &sip, dev); + if (n) { + arp_send_dst(ARPOP_REPLY, ETH_P_ARP, + sip, dev, tip, sha, + dev->dev_addr, sha, + reply_dst); + neigh_release(n); + } + } + goto out_consume_skb; + } else if (IN_DEV_FORWARD(in_dev)) { + if (addr_type == RTN_UNICAST && + (arp_fwd_proxy(in_dev, dev, rt) || + arp_fwd_pvlan(in_dev, dev, rt, sip, tip) || + (rt->dst.dev != dev && + pneigh_lookup(&arp_tbl, net, &tip, dev, 0)))) { + n = neigh_event_ns(&arp_tbl, sha, &sip, dev); + if (n) + neigh_release(n); + + if (NEIGH_CB(skb)->flags & LOCALLY_ENQUEUED || + skb->pkt_type == PACKET_HOST || + NEIGH_VAR(in_dev->arp_parms, PROXY_DELAY) == 0) { + arp_send_dst(ARPOP_REPLY, ETH_P_ARP, + sip, dev, tip, sha, + dev->dev_addr, sha, + reply_dst); + } else { + pneigh_enqueue(&arp_tbl, + in_dev->arp_parms, skb); + goto out_free_dst; + } + goto out_consume_skb; + } + } + } + + /* Update our ARP tables */ + + n = __neigh_lookup(&arp_tbl, &sip, dev, 0); + + addr_type = -1; + if (n || arp_accept(in_dev, sip)) { + is_garp = arp_is_garp(net, dev, &addr_type, arp->ar_op, + sip, tip, sha, tha); + } + + if (arp_accept(in_dev, sip)) { + /* Unsolicited ARP is not accepted by default. 
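+	   (whether it is accepted is governed by the per-device arp_accept
+	   option handled in arp_accept() above: 0 ignores gratuitous ARP,
+	   1 creates entries from it, 2 creates them only when the sender
+	   address is in the same subnet as an address configured on the
+	   receiving interface.)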
+ It is possible, that this option should be enabled for some + devices (strip is candidate) + */ + if (!n && + (is_garp || + (arp->ar_op == htons(ARPOP_REPLY) && + (addr_type == RTN_UNICAST || + (addr_type < 0 && + /* postpone calculation to as late as possible */ + inet_addr_type_dev_table(net, dev, sip) == + RTN_UNICAST))))) + n = __neigh_lookup(&arp_tbl, &sip, dev, 1); + } + + if (n) { + int state = NUD_REACHABLE; + int override; + + /* If several different ARP replies follows back-to-back, + use the FIRST one. It is possible, if several proxy + agents are active. Taking the first reply prevents + arp trashing and chooses the fastest router. + */ + override = time_after(jiffies, + n->updated + + NEIGH_VAR(n->parms, LOCKTIME)) || + is_garp; + + /* Broadcast replies and request packets + do not assert neighbour reachability. + */ + if (arp->ar_op != htons(ARPOP_REPLY) || + skb->pkt_type != PACKET_HOST) + state = NUD_STALE; + neigh_update(n, sha, state, + override ? NEIGH_UPDATE_F_OVERRIDE : 0, 0); + neigh_release(n); + } + +out_consume_skb: + consume_skb(skb); + +out_free_dst: + dst_release(reply_dst); + return NET_RX_SUCCESS; + +out_free_skb: + kfree_skb(skb); + return NET_RX_DROP; +} + +static void parp_redo(struct sk_buff *skb) +{ + arp_process(dev_net(skb->dev), NULL, skb); +} + +static int arp_is_multicast(const void *pkey) +{ + return ipv4_is_multicast(*((__be32 *)pkey)); +} + +/* + * Receive an arp request from the device layer. + */ + +static int arp_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + const struct arphdr *arp; + + /* do not tweak dropwatch on an ARP we will ignore */ + if (dev->flags & IFF_NOARP || + skb->pkt_type == PACKET_OTHERHOST || + skb->pkt_type == PACKET_LOOPBACK) + goto consumeskb; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + goto out_of_mem; + + /* ARP header, plus 2 device addresses, plus 2 IP addresses. */ + if (!pskb_may_pull(skb, arp_hdr_len(dev))) + goto freeskb; + + arp = arp_hdr(skb); + if (arp->ar_hln != dev->addr_len || arp->ar_pln != 4) + goto freeskb; + + memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb)); + + return NF_HOOK(NFPROTO_ARP, NF_ARP_IN, + dev_net(dev), NULL, skb, dev, NULL, + arp_process); + +consumeskb: + consume_skb(skb); + return NET_RX_SUCCESS; +freeskb: + kfree_skb(skb); +out_of_mem: + return NET_RX_DROP; +} + +/* + * User level interface (ioctl) + */ + +/* + * Set (create) an ARP cache entry. 
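+ *
+ *	For illustration, a minimal userspace sketch of the SIOCSARP request
+ *	serviced below (CAP_NET_ADMIN is required, see arp_ioctl(); the
+ *	device name, address and "mac" buffer are placeholders):
+ *
+ *		struct arpreq r = {};
+ *		struct sockaddr_in *sin = (struct sockaddr_in *)&r.arp_pa;
+ *		int fd = socket(AF_INET, SOCK_DGRAM, 0);
+ *
+ *		sin->sin_family = AF_INET;
+ *		inet_pton(AF_INET, "192.0.2.10", &sin->sin_addr);
+ *		r.arp_ha.sa_family = ARPHRD_ETHER;
+ *		memcpy(r.arp_ha.sa_data, mac, 6);
+ *		r.arp_flags = ATF_PERM | ATF_COM;
+ *		strncpy(r.arp_dev, "eth0", sizeof(r.arp_dev));
+ *		ioctl(fd, SIOCSARP, &r);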
+ */ + +static int arp_req_set_proxy(struct net *net, struct net_device *dev, int on) +{ + if (!dev) { + IPV4_DEVCONF_ALL(net, PROXY_ARP) = on; + return 0; + } + if (__in_dev_get_rtnl(dev)) { + IN_DEV_CONF_SET(__in_dev_get_rtnl(dev), PROXY_ARP, on); + return 0; + } + return -ENXIO; +} + +static int arp_req_set_public(struct net *net, struct arpreq *r, + struct net_device *dev) +{ + __be32 ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + __be32 mask = ((struct sockaddr_in *)&r->arp_netmask)->sin_addr.s_addr; + + if (mask && mask != htonl(0xFFFFFFFF)) + return -EINVAL; + if (!dev && (r->arp_flags & ATF_COM)) { + dev = dev_getbyhwaddr_rcu(net, r->arp_ha.sa_family, + r->arp_ha.sa_data); + if (!dev) + return -ENODEV; + } + if (mask) { + if (!pneigh_lookup(&arp_tbl, net, &ip, dev, 1)) + return -ENOBUFS; + return 0; + } + + return arp_req_set_proxy(net, dev, 1); +} + +static int arp_req_set(struct net *net, struct arpreq *r, + struct net_device *dev) +{ + __be32 ip; + struct neighbour *neigh; + int err; + + if (r->arp_flags & ATF_PUBL) + return arp_req_set_public(net, r, dev); + + ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + if (r->arp_flags & ATF_PERM) + r->arp_flags |= ATF_COM; + if (!dev) { + struct rtable *rt = ip_route_output(net, ip, 0, RTO_ONLINK, 0); + + if (IS_ERR(rt)) + return PTR_ERR(rt); + dev = rt->dst.dev; + ip_rt_put(rt); + if (!dev) + return -EINVAL; + } + switch (dev->type) { +#if IS_ENABLED(CONFIG_FDDI) + case ARPHRD_FDDI: + /* + * According to RFC 1390, FDDI devices should accept ARP + * hardware types of 1 (Ethernet). However, to be more + * robust, we'll accept hardware types of either 1 (Ethernet) + * or 6 (IEEE 802.2). + */ + if (r->arp_ha.sa_family != ARPHRD_FDDI && + r->arp_ha.sa_family != ARPHRD_ETHER && + r->arp_ha.sa_family != ARPHRD_IEEE802) + return -EINVAL; + break; +#endif + default: + if (r->arp_ha.sa_family != dev->type) + return -EINVAL; + break; + } + + neigh = __neigh_lookup_errno(&arp_tbl, &ip, dev); + err = PTR_ERR(neigh); + if (!IS_ERR(neigh)) { + unsigned int state = NUD_STALE; + if (r->arp_flags & ATF_PERM) + state = NUD_PERMANENT; + err = neigh_update(neigh, (r->arp_flags & ATF_COM) ? + r->arp_ha.sa_data : NULL, state, + NEIGH_UPDATE_F_OVERRIDE | + NEIGH_UPDATE_F_ADMIN, 0); + neigh_release(neigh); + } + return err; +} + +static unsigned int arp_state_to_flags(struct neighbour *neigh) +{ + if (neigh->nud_state&NUD_PERMANENT) + return ATF_PERM | ATF_COM; + else if (neigh->nud_state&NUD_VALID) + return ATF_COM; + else + return 0; +} + +/* + * Get an ARP cache entry. 
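+ *
+ *	Userspace reaches this through SIOCGARP with the same struct arpreq
+ *	as in the sketch above: fill in arp_pa and arp_dev, and on success
+ *	the hardware address is returned in arp_ha.sa_data and the entry
+ *	state in arp_flags.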
+ */ + +static int arp_req_get(struct arpreq *r, struct net_device *dev) +{ + __be32 ip = ((struct sockaddr_in *) &r->arp_pa)->sin_addr.s_addr; + struct neighbour *neigh; + int err = -ENXIO; + + neigh = neigh_lookup(&arp_tbl, &ip, dev); + if (neigh) { + if (!(READ_ONCE(neigh->nud_state) & NUD_NOARP)) { + read_lock_bh(&neigh->lock); + memcpy(r->arp_ha.sa_data, neigh->ha, dev->addr_len); + r->arp_flags = arp_state_to_flags(neigh); + read_unlock_bh(&neigh->lock); + r->arp_ha.sa_family = dev->type; + strscpy(r->arp_dev, dev->name, sizeof(r->arp_dev)); + err = 0; + } + neigh_release(neigh); + } + return err; +} + +int arp_invalidate(struct net_device *dev, __be32 ip, bool force) +{ + struct neighbour *neigh = neigh_lookup(&arp_tbl, &ip, dev); + int err = -ENXIO; + struct neigh_table *tbl = &arp_tbl; + + if (neigh) { + if ((READ_ONCE(neigh->nud_state) & NUD_VALID) && !force) { + neigh_release(neigh); + return 0; + } + + if (READ_ONCE(neigh->nud_state) & ~NUD_NOARP) + err = neigh_update(neigh, NULL, NUD_FAILED, + NEIGH_UPDATE_F_OVERRIDE| + NEIGH_UPDATE_F_ADMIN, 0); + write_lock_bh(&tbl->lock); + neigh_release(neigh); + neigh_remove_one(neigh, tbl); + write_unlock_bh(&tbl->lock); + } + + return err; +} + +static int arp_req_delete_public(struct net *net, struct arpreq *r, + struct net_device *dev) +{ + __be32 ip = ((struct sockaddr_in *) &r->arp_pa)->sin_addr.s_addr; + __be32 mask = ((struct sockaddr_in *)&r->arp_netmask)->sin_addr.s_addr; + + if (mask == htonl(0xFFFFFFFF)) + return pneigh_delete(&arp_tbl, net, &ip, dev); + + if (mask) + return -EINVAL; + + return arp_req_set_proxy(net, dev, 0); +} + +static int arp_req_delete(struct net *net, struct arpreq *r, + struct net_device *dev) +{ + __be32 ip; + + if (r->arp_flags & ATF_PUBL) + return arp_req_delete_public(net, r, dev); + + ip = ((struct sockaddr_in *)&r->arp_pa)->sin_addr.s_addr; + if (!dev) { + struct rtable *rt = ip_route_output(net, ip, 0, RTO_ONLINK, 0); + if (IS_ERR(rt)) + return PTR_ERR(rt); + dev = rt->dst.dev; + ip_rt_put(rt); + if (!dev) + return -EINVAL; + } + return arp_invalidate(dev, ip, true); +} + +/* + * Handle an ARP layer I/O control request. + */ + +int arp_ioctl(struct net *net, unsigned int cmd, void __user *arg) +{ + int err; + struct arpreq r; + struct net_device *dev = NULL; + + switch (cmd) { + case SIOCDARP: + case SIOCSARP: + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + fallthrough; + case SIOCGARP: + err = copy_from_user(&r, arg, sizeof(struct arpreq)); + if (err) + return -EFAULT; + break; + default: + return -EINVAL; + } + + if (r.arp_pa.sa_family != AF_INET) + return -EPFNOSUPPORT; + + if (!(r.arp_flags & ATF_PUBL) && + (r.arp_flags & (ATF_NETMASK | ATF_DONTPUB))) + return -EINVAL; + if (!(r.arp_flags & ATF_NETMASK)) + ((struct sockaddr_in *)&r.arp_netmask)->sin_addr.s_addr = + htonl(0xFFFFFFFFUL); + rtnl_lock(); + if (r.arp_dev[0]) { + err = -ENODEV; + dev = __dev_get_by_name(net, r.arp_dev); + if (!dev) + goto out; + + /* Mmmm... It is wrong... 
ARPHRD_NETROM==0 */ + if (!r.arp_ha.sa_family) + r.arp_ha.sa_family = dev->type; + err = -EINVAL; + if ((r.arp_flags & ATF_COM) && r.arp_ha.sa_family != dev->type) + goto out; + } else if (cmd == SIOCGARP) { + err = -ENODEV; + goto out; + } + + switch (cmd) { + case SIOCDARP: + err = arp_req_delete(net, &r, dev); + break; + case SIOCSARP: + err = arp_req_set(net, &r, dev); + break; + case SIOCGARP: + err = arp_req_get(&r, dev); + break; + } +out: + rtnl_unlock(); + if (cmd == SIOCGARP && !err && copy_to_user(arg, &r, sizeof(r))) + err = -EFAULT; + return err; +} + +static int arp_netdev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netdev_notifier_change_info *change_info; + struct in_device *in_dev; + bool evict_nocarrier; + + switch (event) { + case NETDEV_CHANGEADDR: + neigh_changeaddr(&arp_tbl, dev); + rt_cache_flush(dev_net(dev)); + break; + case NETDEV_CHANGE: + change_info = ptr; + if (change_info->flags_changed & IFF_NOARP) + neigh_changeaddr(&arp_tbl, dev); + + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) + evict_nocarrier = true; + else + evict_nocarrier = IN_DEV_ARP_EVICT_NOCARRIER(in_dev); + + if (evict_nocarrier && !netif_carrier_ok(dev)) + neigh_carrier_down(&arp_tbl, dev); + break; + default: + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block arp_netdev_notifier = { + .notifier_call = arp_netdev_event, +}; + +/* Note, that it is not on notifier chain. + It is necessary, that this routine was called after route cache will be + flushed. + */ +void arp_ifdown(struct net_device *dev) +{ + neigh_ifdown(&arp_tbl, dev); +} + + +/* + * Called once on startup. + */ + +static struct packet_type arp_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_ARP), + .func = arp_rcv, +}; + +#ifdef CONFIG_PROC_FS +#if IS_ENABLED(CONFIG_AX25) + +/* + * ax25 -> ASCII conversion + */ +static void ax2asc2(ax25_address *a, char *buf) +{ + char c, *s; + int n; + + for (n = 0, s = buf; n < 6; n++) { + c = (a->ax25_call[n] >> 1) & 0x7F; + + if (c != ' ') + *s++ = c; + } + + *s++ = '-'; + n = (a->ax25_call[6] >> 1) & 0x0F; + if (n > 9) { + *s++ = '1'; + n -= 10; + } + + *s++ = n + '0'; + *s++ = '\0'; + + if (*buf == '\0' || *buf == '-') { + buf[0] = '*'; + buf[1] = '\0'; + } +} +#endif /* CONFIG_AX25 */ + +#define HBUFFERLEN 30 + +static void arp_format_neigh_entry(struct seq_file *seq, + struct neighbour *n) +{ + char hbuffer[HBUFFERLEN]; + int k, j; + char tbuf[16]; + struct net_device *dev = n->dev; + int hatype = dev->type; + + read_lock(&n->lock); + /* Convert hardware address to XX:XX:XX:XX ... form. */ +#if IS_ENABLED(CONFIG_AX25) + if (hatype == ARPHRD_AX25 || hatype == ARPHRD_NETROM) + ax2asc2((ax25_address *)n->ha, hbuffer); + else { +#endif + for (k = 0, j = 0; k < HBUFFERLEN - 3 && j < dev->addr_len; j++) { + hbuffer[k++] = hex_asc_hi(n->ha[j]); + hbuffer[k++] = hex_asc_lo(n->ha[j]); + hbuffer[k++] = ':'; + } + if (k != 0) + --k; + hbuffer[k] = 0; +#if IS_ENABLED(CONFIG_AX25) + } +#endif + sprintf(tbuf, "%pI4", n->primary_key); + seq_printf(seq, "%-16s 0x%-10x0x%-10x%-17s * %s\n", + tbuf, hatype, arp_state_to_flags(n), hbuffer, dev->name); + read_unlock(&n->lock); +} + +static void arp_format_pneigh_entry(struct seq_file *seq, + struct pneigh_entry *n) +{ + struct net_device *dev = n->dev; + int hatype = dev ? 
dev->type : 0; + char tbuf[16]; + + sprintf(tbuf, "%pI4", n->key); + seq_printf(seq, "%-16s 0x%-10x0x%-10x%s * %s\n", + tbuf, hatype, ATF_PUBL | ATF_PERM, "00:00:00:00:00:00", + dev ? dev->name : "*"); +} + +static int arp_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "IP address HW type Flags " + "HW address Mask Device\n"); + } else { + struct neigh_seq_state *state = seq->private; + + if (state->flags & NEIGH_SEQ_IS_PNEIGH) + arp_format_pneigh_entry(seq, v); + else + arp_format_neigh_entry(seq, v); + } + + return 0; +} + +static void *arp_seq_start(struct seq_file *seq, loff_t *pos) +{ + /* Don't want to confuse "arp -a" w/ magic entries, + * so we tell the generic iterator to skip NUD_NOARP. + */ + return neigh_seq_start(seq, pos, &arp_tbl, NEIGH_SEQ_SKIP_NOARP); +} + +static const struct seq_operations arp_seq_ops = { + .start = arp_seq_start, + .next = neigh_seq_next, + .stop = neigh_seq_stop, + .show = arp_seq_show, +}; +#endif /* CONFIG_PROC_FS */ + +static int __net_init arp_net_init(struct net *net) +{ + if (!proc_create_net("arp", 0444, net->proc_net, &arp_seq_ops, + sizeof(struct neigh_seq_state))) + return -ENOMEM; + return 0; +} + +static void __net_exit arp_net_exit(struct net *net) +{ + remove_proc_entry("arp", net->proc_net); +} + +static struct pernet_operations arp_net_ops = { + .init = arp_net_init, + .exit = arp_net_exit, +}; + +void __init arp_init(void) +{ + neigh_table_init(NEIGH_ARP_TABLE, &arp_tbl); + + dev_add_pack(&arp_packet_type); + register_pernet_subsys(&arp_net_ops); +#ifdef CONFIG_SYSCTL + neigh_sysctl_register(NULL, &arp_tbl.parms, NULL); +#endif + register_netdevice_notifier(&arp_netdev_notifier); +} diff --git a/net/ipv4/bpf_tcp_ca.c b/net/ipv4/bpf_tcp_ca.c new file mode 100644 index 000000000..6da16ae6a --- /dev/null +++ b/net/ipv4/bpf_tcp_ca.c @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2019 Facebook */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* "extern" is to avoid sparse warning. It is only used in bpf_struct_ops.c. 
*/ +extern struct bpf_struct_ops bpf_tcp_congestion_ops; + +static u32 unsupported_ops[] = { + offsetof(struct tcp_congestion_ops, get_info), +}; + +static const struct btf_type *tcp_sock_type; +static u32 tcp_sock_id, sock_id; + +static int bpf_tcp_ca_init(struct btf *btf) +{ + s32 type_id; + + type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT); + if (type_id < 0) + return -EINVAL; + sock_id = type_id; + + type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT); + if (type_id < 0) + return -EINVAL; + tcp_sock_id = type_id; + tcp_sock_type = btf_type_by_id(btf, tcp_sock_id); + + return 0; +} + +static bool is_unsupported(u32 member_offset) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) { + if (member_offset == unsupported_ops[i]) + return true; + } + + return false; +} + +extern struct btf *btf_vmlinux; + +static bool bpf_tcp_ca_is_valid_access(int off, int size, + enum bpf_access_type type, + const struct bpf_prog *prog, + struct bpf_insn_access_aux *info) +{ + if (!bpf_tracing_btf_ctx_access(off, size, type, prog, info)) + return false; + + if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id) + /* promote it to tcp_sock */ + info->btf_id = tcp_sock_id; + + return true; +} + +static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id, + enum bpf_type_flag *flag) +{ + size_t end; + + if (atype == BPF_READ) + return btf_struct_access(log, btf, t, off, size, atype, next_btf_id, + flag); + + if (t != tcp_sock_type) { + bpf_log(log, "only read is supported\n"); + return -EACCES; + } + + switch (off) { + case offsetof(struct sock, sk_pacing_rate): + end = offsetofend(struct sock, sk_pacing_rate); + break; + case offsetof(struct sock, sk_pacing_status): + end = offsetofend(struct sock, sk_pacing_status); + break; + case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv): + end = offsetofend(struct inet_connection_sock, icsk_ca_priv); + break; + case offsetof(struct inet_connection_sock, icsk_ack.pending): + end = offsetofend(struct inet_connection_sock, + icsk_ack.pending); + break; + case offsetof(struct tcp_sock, snd_cwnd): + end = offsetofend(struct tcp_sock, snd_cwnd); + break; + case offsetof(struct tcp_sock, snd_cwnd_cnt): + end = offsetofend(struct tcp_sock, snd_cwnd_cnt); + break; + case offsetof(struct tcp_sock, snd_ssthresh): + end = offsetofend(struct tcp_sock, snd_ssthresh); + break; + case offsetof(struct tcp_sock, ecn_flags): + end = offsetofend(struct tcp_sock, ecn_flags); + break; + default: + bpf_log(log, "no write support to tcp_sock at off %d\n", off); + return -EACCES; + } + + if (off + size > end) { + bpf_log(log, + "write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n", + off, size, end); + return -EACCES; + } + + return 0; +} + +BPF_CALL_2(bpf_tcp_send_ack, struct tcp_sock *, tp, u32, rcv_nxt) +{ + /* bpf_tcp_ca prog cannot have NULL tp */ + __tcp_send_ack((struct sock *)tp, rcv_nxt); + return 0; +} + +static const struct bpf_func_proto bpf_tcp_send_ack_proto = { + .func = bpf_tcp_send_ack, + .gpl_only = false, + /* In case we want to report error later */ + .ret_type = RET_INTEGER, + .arg1_type = ARG_PTR_TO_BTF_ID, + .arg1_btf_id = &tcp_sock_id, + .arg2_type = ARG_ANYTHING, +}; + +static u32 prog_ops_moff(const struct bpf_prog *prog) +{ + const struct btf_member *m; + const struct btf_type *t; + u32 midx; + + midx = prog->expected_attach_type; + t = 
bpf_tcp_congestion_ops.type; + m = &btf_type_member(t)[midx]; + + return __btf_member_bit_offset(t, m) / 8; +} + +static const struct bpf_func_proto * +bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id, + const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_tcp_send_ack: + return &bpf_tcp_send_ack_proto; + case BPF_FUNC_sk_storage_get: + return &bpf_sk_storage_get_proto; + case BPF_FUNC_sk_storage_delete: + return &bpf_sk_storage_delete_proto; + case BPF_FUNC_setsockopt: + /* Does not allow release() to call setsockopt. + * release() is called when the current bpf-tcp-cc + * is retiring. It is not allowed to call + * setsockopt() to make further changes which + * may potentially allocate new resources. + */ + if (prog_ops_moff(prog) != + offsetof(struct tcp_congestion_ops, release)) + return &bpf_sk_setsockopt_proto; + return NULL; + case BPF_FUNC_getsockopt: + /* Since get/setsockopt is usually expected to + * be available together, disable getsockopt for + * release also to avoid usage surprise. + * The bpf-tcp-cc already has a more powerful way + * to read tcp_sock from the PTR_TO_BTF_ID. + */ + if (prog_ops_moff(prog) != + offsetof(struct tcp_congestion_ops, release)) + return &bpf_sk_getsockopt_proto; + return NULL; + case BPF_FUNC_ktime_get_coarse_ns: + return &bpf_ktime_get_coarse_ns_proto; + default: + return bpf_base_func_proto(func_id); + } +} + +BTF_SET8_START(bpf_tcp_ca_check_kfunc_ids) +BTF_ID_FLAGS(func, tcp_reno_ssthresh) +BTF_ID_FLAGS(func, tcp_reno_cong_avoid) +BTF_ID_FLAGS(func, tcp_reno_undo_cwnd) +BTF_ID_FLAGS(func, tcp_slow_start) +BTF_ID_FLAGS(func, tcp_cong_avoid_ai) +BTF_SET8_END(bpf_tcp_ca_check_kfunc_ids) + +static const struct btf_kfunc_id_set bpf_tcp_ca_kfunc_set = { + .owner = THIS_MODULE, + .set = &bpf_tcp_ca_check_kfunc_ids, +}; + +static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = { + .get_func_proto = bpf_tcp_ca_get_func_proto, + .is_valid_access = bpf_tcp_ca_is_valid_access, + .btf_struct_access = bpf_tcp_ca_btf_struct_access, +}; + +static int bpf_tcp_ca_init_member(const struct btf_type *t, + const struct btf_member *member, + void *kdata, const void *udata) +{ + const struct tcp_congestion_ops *utcp_ca; + struct tcp_congestion_ops *tcp_ca; + u32 moff; + + utcp_ca = (const struct tcp_congestion_ops *)udata; + tcp_ca = (struct tcp_congestion_ops *)kdata; + + moff = __btf_member_bit_offset(t, member) / 8; + switch (moff) { + case offsetof(struct tcp_congestion_ops, flags): + if (utcp_ca->flags & ~TCP_CONG_MASK) + return -EINVAL; + tcp_ca->flags = utcp_ca->flags; + return 1; + case offsetof(struct tcp_congestion_ops, name): + if (bpf_obj_name_cpy(tcp_ca->name, utcp_ca->name, + sizeof(tcp_ca->name)) <= 0) + return -EINVAL; + if (tcp_ca_find(utcp_ca->name)) + return -EEXIST; + return 1; + } + + return 0; +} + +static int bpf_tcp_ca_check_member(const struct btf_type *t, + const struct btf_member *member) +{ + if (is_unsupported(__btf_member_bit_offset(t, member) / 8)) + return -ENOTSUPP; + return 0; +} + +static int bpf_tcp_ca_reg(void *kdata) +{ + return tcp_register_congestion_control(kdata); +} + +static void bpf_tcp_ca_unreg(void *kdata) +{ + tcp_unregister_congestion_control(kdata); +} + +struct bpf_struct_ops bpf_tcp_congestion_ops = { + .verifier_ops = &bpf_tcp_ca_verifier_ops, + .reg = bpf_tcp_ca_reg, + .unreg = bpf_tcp_ca_unreg, + .check_member = bpf_tcp_ca_check_member, + .init_member = bpf_tcp_ca_init_member, + .init = bpf_tcp_ca_init, + .name = "tcp_congestion_ops", +}; + +static int __init bpf_tcp_ca_kfunc_init(void) 
+{ + return register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &bpf_tcp_ca_kfunc_set); +} +late_initcall(bpf_tcp_ca_kfunc_init); diff --git a/net/ipv4/bpfilter/Makefile b/net/ipv4/bpfilter/Makefile new file mode 100644 index 000000000..00af5305e --- /dev/null +++ b/net/ipv4/bpfilter/Makefile @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_BPFILTER) += sockopt.o diff --git a/net/ipv4/bpfilter/sockopt.c b/net/ipv4/bpfilter/sockopt.c new file mode 100644 index 000000000..1b34cb9a7 --- /dev/null +++ b/net/ipv4/bpfilter/sockopt.c @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct bpfilter_umh_ops bpfilter_ops; +EXPORT_SYMBOL_GPL(bpfilter_ops); + +void bpfilter_umh_cleanup(struct umd_info *info) +{ + fput(info->pipe_to_umh); + fput(info->pipe_from_umh); + put_pid(info->tgid); + info->tgid = NULL; +} +EXPORT_SYMBOL_GPL(bpfilter_umh_cleanup); + +static int bpfilter_mbox_request(struct sock *sk, int optname, sockptr_t optval, + unsigned int optlen, bool is_set) +{ + int err; + mutex_lock(&bpfilter_ops.lock); + if (!bpfilter_ops.sockopt) { + mutex_unlock(&bpfilter_ops.lock); + request_module("bpfilter"); + mutex_lock(&bpfilter_ops.lock); + + if (!bpfilter_ops.sockopt) { + err = -ENOPROTOOPT; + goto out; + } + } + if (bpfilter_ops.info.tgid && + thread_group_exited(bpfilter_ops.info.tgid)) + bpfilter_umh_cleanup(&bpfilter_ops.info); + + if (!bpfilter_ops.info.tgid) { + err = bpfilter_ops.start(); + if (err) + goto out; + } + err = bpfilter_ops.sockopt(sk, optname, optval, optlen, is_set); +out: + mutex_unlock(&bpfilter_ops.lock); + return err; +} + +int bpfilter_ip_set_sockopt(struct sock *sk, int optname, sockptr_t optval, + unsigned int optlen) +{ + return bpfilter_mbox_request(sk, optname, optval, optlen, true); +} + +int bpfilter_ip_get_sockopt(struct sock *sk, int optname, char __user *optval, + int __user *optlen) +{ + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + return bpfilter_mbox_request(sk, optname, USER_SOCKPTR(optval), len, + false); +} + +static int __init bpfilter_sockopt_init(void) +{ + mutex_init(&bpfilter_ops.lock); + bpfilter_ops.info.tgid = NULL; + bpfilter_ops.info.driver_name = "bpfilter_umh"; + + return 0; +} +device_initcall(bpfilter_sockopt_init); diff --git a/net/ipv4/cipso_ipv4.c b/net/ipv4/cipso_ipv4.c new file mode 100644 index 000000000..6cd3b6c55 --- /dev/null +++ b/net/ipv4/cipso_ipv4.c @@ -0,0 +1,2295 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * CIPSO - Commercial IP Security Option + * + * This is an implementation of the CIPSO 2.2 protocol as specified in + * draft-ietf-cipso-ipsecurity-01.txt with additional tag types as found in + * FIPS-188. While CIPSO never became a full IETF RFC standard many vendors + * have chosen to adopt the protocol and over the years it has become a + * de-facto standard for labeled networking. 
+ * + * The CIPSO draft specification can be found in the kernel's Documentation + * directory as well as the following URL: + * https://tools.ietf.org/id/draft-ietf-cipso-ipsecurity-01.txt + * The FIPS-188 specification can be found at the following URL: + * https://www.itl.nist.gov/fipspubs/fip188.htm + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* List of available DOI definitions */ +/* XXX - This currently assumes a minimal number of different DOIs in use, + * if in practice there are a lot of different DOIs this list should + * probably be turned into a hash table or something similar so we + * can do quick lookups. */ +static DEFINE_SPINLOCK(cipso_v4_doi_list_lock); +static LIST_HEAD(cipso_v4_doi_list); + +/* Label mapping cache */ +int cipso_v4_cache_enabled = 1; +int cipso_v4_cache_bucketsize = 10; +#define CIPSO_V4_CACHE_BUCKETBITS 7 +#define CIPSO_V4_CACHE_BUCKETS (1 << CIPSO_V4_CACHE_BUCKETBITS) +#define CIPSO_V4_CACHE_REORDERLIMIT 10 +struct cipso_v4_map_cache_bkt { + spinlock_t lock; + u32 size; + struct list_head list; +}; + +struct cipso_v4_map_cache_entry { + u32 hash; + unsigned char *key; + size_t key_len; + + struct netlbl_lsm_cache *lsm_data; + + u32 activity; + struct list_head list; +}; + +static struct cipso_v4_map_cache_bkt *cipso_v4_cache; + +/* Restricted bitmap (tag #1) flags */ +int cipso_v4_rbm_optfmt; +int cipso_v4_rbm_strictvalid = 1; + +/* + * Protocol Constants + */ + +/* Maximum size of the CIPSO IP option, derived from the fact that the maximum + * IPv4 header size is 60 bytes and the base IPv4 header is 20 bytes long. */ +#define CIPSO_V4_OPT_LEN_MAX 40 + +/* Length of the base CIPSO option, this includes the option type (1 byte), the + * option length (1 byte), and the DOI (4 bytes). */ +#define CIPSO_V4_HDR_LEN 6 + +/* Base length of the restrictive category bitmap tag (tag #1). */ +#define CIPSO_V4_TAG_RBM_BLEN 4 + +/* Base length of the enumerated category tag (tag #2). */ +#define CIPSO_V4_TAG_ENUM_BLEN 4 + +/* Base length of the ranged categories bitmap tag (tag #5). */ +#define CIPSO_V4_TAG_RNG_BLEN 4 +/* The maximum number of category ranges permitted in the ranged category tag + * (tag #5). You may note that the IETF draft states that the maximum number + * of category ranges is 7, but if the low end of the last category range is + * zero then it is possible to fit 8 category ranges because the zero should + * be omitted. */ +#define CIPSO_V4_TAG_RNG_CAT_MAX 8 + +/* Base length of the local tag (non-standard tag). + * Tag definition (may change between kernel versions) + * + * 0 8 16 24 32 + * +----------+----------+----------+----------+ + * | 10000000 | 00000110 | 32-bit secid value | + * +----------+----------+----------+----------+ + * | in (host byte order)| + * +----------+----------+ + * + */ +#define CIPSO_V4_TAG_LOC_BLEN 6 + +/* + * Helper Functions + */ + +/** + * cipso_v4_cache_entry_free - Frees a cache entry + * @entry: the entry to free + * + * Description: + * This function frees the memory associated with a cache entry including the + * LSM cache data if there are no longer any users, i.e. reference count == 0. 
+ * + */ +static void cipso_v4_cache_entry_free(struct cipso_v4_map_cache_entry *entry) +{ + if (entry->lsm_data) + netlbl_secattr_cache_free(entry->lsm_data); + kfree(entry->key); + kfree(entry); +} + +/** + * cipso_v4_map_cache_hash - Hashing function for the CIPSO cache + * @key: the hash key + * @key_len: the length of the key in bytes + * + * Description: + * The CIPSO tag hashing function. Returns a 32-bit hash value. + * + */ +static u32 cipso_v4_map_cache_hash(const unsigned char *key, u32 key_len) +{ + return jhash(key, key_len, 0); +} + +/* + * Label Mapping Cache Functions + */ + +/** + * cipso_v4_cache_init - Initialize the CIPSO cache + * + * Description: + * Initializes the CIPSO label mapping cache, this function should be called + * before any of the other functions defined in this file. Returns zero on + * success, negative values on error. + * + */ +static int __init cipso_v4_cache_init(void) +{ + u32 iter; + + cipso_v4_cache = kcalloc(CIPSO_V4_CACHE_BUCKETS, + sizeof(struct cipso_v4_map_cache_bkt), + GFP_KERNEL); + if (!cipso_v4_cache) + return -ENOMEM; + + for (iter = 0; iter < CIPSO_V4_CACHE_BUCKETS; iter++) { + spin_lock_init(&cipso_v4_cache[iter].lock); + cipso_v4_cache[iter].size = 0; + INIT_LIST_HEAD(&cipso_v4_cache[iter].list); + } + + return 0; +} + +/** + * cipso_v4_cache_invalidate - Invalidates the current CIPSO cache + * + * Description: + * Invalidates and frees any entries in the CIPSO cache. + * + */ +void cipso_v4_cache_invalidate(void) +{ + struct cipso_v4_map_cache_entry *entry, *tmp_entry; + u32 iter; + + for (iter = 0; iter < CIPSO_V4_CACHE_BUCKETS; iter++) { + spin_lock_bh(&cipso_v4_cache[iter].lock); + list_for_each_entry_safe(entry, + tmp_entry, + &cipso_v4_cache[iter].list, list) { + list_del(&entry->list); + cipso_v4_cache_entry_free(entry); + } + cipso_v4_cache[iter].size = 0; + spin_unlock_bh(&cipso_v4_cache[iter].lock); + } +} + +/** + * cipso_v4_cache_check - Check the CIPSO cache for a label mapping + * @key: the buffer to check + * @key_len: buffer length in bytes + * @secattr: the security attribute struct to use + * + * Description: + * This function checks the cache to see if a label mapping already exists for + * the given key. If there is a match then the cache is adjusted and the + * @secattr struct is populated with the correct LSM security attributes. The + * cache is adjusted in the following manner if the entry is not already the + * first in the cache bucket: + * + * 1. The cache entry's activity counter is incremented + * 2. The previous (higher ranking) entry's activity counter is decremented + * 3. If the difference between the two activity counters is geater than + * CIPSO_V4_CACHE_REORDERLIMIT the two entries are swapped + * + * Returns zero on success, -ENOENT for a cache miss, and other negative values + * on error. 
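+ *
+ * Worked example with CIPSO_V4_CACHE_REORDERLIMIT == 10 (defined above):
+ * if a bucket holds A (activity 3) followed by B (activity 16) and B is
+ * hit again, B becomes 17 and A drops to 2; since 17 - 2 > 10 the two
+ * entries are swapped, so the hotter entry is found earlier next time.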
+ * + */ +static int cipso_v4_cache_check(const unsigned char *key, + u32 key_len, + struct netlbl_lsm_secattr *secattr) +{ + u32 bkt; + struct cipso_v4_map_cache_entry *entry; + struct cipso_v4_map_cache_entry *prev_entry = NULL; + u32 hash; + + if (!READ_ONCE(cipso_v4_cache_enabled)) + return -ENOENT; + + hash = cipso_v4_map_cache_hash(key, key_len); + bkt = hash & (CIPSO_V4_CACHE_BUCKETS - 1); + spin_lock_bh(&cipso_v4_cache[bkt].lock); + list_for_each_entry(entry, &cipso_v4_cache[bkt].list, list) { + if (entry->hash == hash && + entry->key_len == key_len && + memcmp(entry->key, key, key_len) == 0) { + entry->activity += 1; + refcount_inc(&entry->lsm_data->refcount); + secattr->cache = entry->lsm_data; + secattr->flags |= NETLBL_SECATTR_CACHE; + secattr->type = NETLBL_NLTYPE_CIPSOV4; + if (!prev_entry) { + spin_unlock_bh(&cipso_v4_cache[bkt].lock); + return 0; + } + + if (prev_entry->activity > 0) + prev_entry->activity -= 1; + if (entry->activity > prev_entry->activity && + entry->activity - prev_entry->activity > + CIPSO_V4_CACHE_REORDERLIMIT) { + __list_del(entry->list.prev, entry->list.next); + __list_add(&entry->list, + prev_entry->list.prev, + &prev_entry->list); + } + + spin_unlock_bh(&cipso_v4_cache[bkt].lock); + return 0; + } + prev_entry = entry; + } + spin_unlock_bh(&cipso_v4_cache[bkt].lock); + + return -ENOENT; +} + +/** + * cipso_v4_cache_add - Add an entry to the CIPSO cache + * @cipso_ptr: pointer to CIPSO IP option + * @secattr: the packet's security attributes + * + * Description: + * Add a new entry into the CIPSO label mapping cache. Add the new entry to + * head of the cache bucket's list, if the cache bucket is out of room remove + * the last entry in the list first. It is important to note that there is + * currently no checking for duplicate keys. Returns zero on success, + * negative values on failure. + * + */ +int cipso_v4_cache_add(const unsigned char *cipso_ptr, + const struct netlbl_lsm_secattr *secattr) +{ + int bkt_size = READ_ONCE(cipso_v4_cache_bucketsize); + int ret_val = -EPERM; + u32 bkt; + struct cipso_v4_map_cache_entry *entry = NULL; + struct cipso_v4_map_cache_entry *old_entry = NULL; + u32 cipso_ptr_len; + + if (!READ_ONCE(cipso_v4_cache_enabled) || bkt_size <= 0) + return 0; + + cipso_ptr_len = cipso_ptr[1]; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + return -ENOMEM; + entry->key = kmemdup(cipso_ptr, cipso_ptr_len, GFP_ATOMIC); + if (!entry->key) { + ret_val = -ENOMEM; + goto cache_add_failure; + } + entry->key_len = cipso_ptr_len; + entry->hash = cipso_v4_map_cache_hash(cipso_ptr, cipso_ptr_len); + refcount_inc(&secattr->cache->refcount); + entry->lsm_data = secattr->cache; + + bkt = entry->hash & (CIPSO_V4_CACHE_BUCKETS - 1); + spin_lock_bh(&cipso_v4_cache[bkt].lock); + if (cipso_v4_cache[bkt].size < bkt_size) { + list_add(&entry->list, &cipso_v4_cache[bkt].list); + cipso_v4_cache[bkt].size += 1; + } else { + old_entry = list_entry(cipso_v4_cache[bkt].list.prev, + struct cipso_v4_map_cache_entry, list); + list_del(&old_entry->list); + list_add(&entry->list, &cipso_v4_cache[bkt].list); + cipso_v4_cache_entry_free(old_entry); + } + spin_unlock_bh(&cipso_v4_cache[bkt].lock); + + return 0; + +cache_add_failure: + if (entry) + cipso_v4_cache_entry_free(entry); + return ret_val; +} + +/* + * DOI List Functions + */ + +/** + * cipso_v4_doi_search - Searches for a DOI definition + * @doi: the DOI to search for + * + * Description: + * Search the DOI definition list for a DOI definition with a DOI value that + * matches @doi. 
The caller is responsible for calling rcu_read_[un]lock(). + * Returns a pointer to the DOI definition on success and NULL on failure. + */ +static struct cipso_v4_doi *cipso_v4_doi_search(u32 doi) +{ + struct cipso_v4_doi *iter; + + list_for_each_entry_rcu(iter, &cipso_v4_doi_list, list) + if (iter->doi == doi && refcount_read(&iter->refcount)) + return iter; + return NULL; +} + +/** + * cipso_v4_doi_add - Add a new DOI to the CIPSO protocol engine + * @doi_def: the DOI structure + * @audit_info: NetLabel audit information + * + * Description: + * The caller defines a new DOI for use by the CIPSO engine and calls this + * function to add it to the list of acceptable domains. The caller must + * ensure that the mapping table specified in @doi_def->map meets all of the + * requirements of the mapping type (see cipso_ipv4.h for details). Returns + * zero on success and non-zero on failure. + * + */ +int cipso_v4_doi_add(struct cipso_v4_doi *doi_def, + struct netlbl_audit *audit_info) +{ + int ret_val = -EINVAL; + u32 iter; + u32 doi; + u32 doi_type; + struct audit_buffer *audit_buf; + + doi = doi_def->doi; + doi_type = doi_def->type; + + if (doi_def->doi == CIPSO_V4_DOI_UNKNOWN) + goto doi_add_return; + for (iter = 0; iter < CIPSO_V4_TAG_MAXCNT; iter++) { + switch (doi_def->tags[iter]) { + case CIPSO_V4_TAG_RBITMAP: + break; + case CIPSO_V4_TAG_RANGE: + case CIPSO_V4_TAG_ENUM: + if (doi_def->type != CIPSO_V4_MAP_PASS) + goto doi_add_return; + break; + case CIPSO_V4_TAG_LOCAL: + if (doi_def->type != CIPSO_V4_MAP_LOCAL) + goto doi_add_return; + break; + case CIPSO_V4_TAG_INVALID: + if (iter == 0) + goto doi_add_return; + break; + default: + goto doi_add_return; + } + } + + refcount_set(&doi_def->refcount, 1); + + spin_lock(&cipso_v4_doi_list_lock); + if (cipso_v4_doi_search(doi_def->doi)) { + spin_unlock(&cipso_v4_doi_list_lock); + ret_val = -EEXIST; + goto doi_add_return; + } + list_add_tail_rcu(&doi_def->list, &cipso_v4_doi_list); + spin_unlock(&cipso_v4_doi_list_lock); + ret_val = 0; + +doi_add_return: + audit_buf = netlbl_audit_start(AUDIT_MAC_CIPSOV4_ADD, audit_info); + if (audit_buf) { + const char *type_str; + switch (doi_type) { + case CIPSO_V4_MAP_TRANS: + type_str = "trans"; + break; + case CIPSO_V4_MAP_PASS: + type_str = "pass"; + break; + case CIPSO_V4_MAP_LOCAL: + type_str = "local"; + break; + default: + type_str = "(unknown)"; + } + audit_log_format(audit_buf, + " cipso_doi=%u cipso_type=%s res=%u", + doi, type_str, ret_val == 0 ? 1 : 0); + audit_log_end(audit_buf); + } + + return ret_val; +} + +/** + * cipso_v4_doi_free - Frees a DOI definition + * @doi_def: the DOI definition + * + * Description: + * This function frees all of the memory associated with a DOI definition. + * + */ +void cipso_v4_doi_free(struct cipso_v4_doi *doi_def) +{ + if (!doi_def) + return; + + switch (doi_def->type) { + case CIPSO_V4_MAP_TRANS: + kfree(doi_def->map.std->lvl.cipso); + kfree(doi_def->map.std->lvl.local); + kfree(doi_def->map.std->cat.cipso); + kfree(doi_def->map.std->cat.local); + kfree(doi_def->map.std); + break; + } + kfree(doi_def); +} + +/** + * cipso_v4_doi_free_rcu - Frees a DOI definition via the RCU pointer + * @entry: the entry's RCU field + * + * Description: + * This function is designed to be used as a callback to the call_rcu() + * function so that the memory allocated to the DOI definition can be released + * safely. 
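+ *
+ * Readers traverse cipso_v4_doi_list under rcu_read_lock(), so the
+ * definition cannot be freed synchronously; cipso_v4_doi_putdef() hands
+ * this function to call_rcu() once the reference count drops to zero,
+ * deferring the free past a grace period.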
+ * + */ +static void cipso_v4_doi_free_rcu(struct rcu_head *entry) +{ + struct cipso_v4_doi *doi_def; + + doi_def = container_of(entry, struct cipso_v4_doi, rcu); + cipso_v4_doi_free(doi_def); +} + +/** + * cipso_v4_doi_remove - Remove an existing DOI from the CIPSO protocol engine + * @doi: the DOI value + * @audit_info: NetLabel audit information + * + * Description: + * Removes a DOI definition from the CIPSO engine. The NetLabel routines will + * be called to release their own LSM domain mappings as well as our own + * domain list. Returns zero on success and negative values on failure. + * + */ +int cipso_v4_doi_remove(u32 doi, struct netlbl_audit *audit_info) +{ + int ret_val; + struct cipso_v4_doi *doi_def; + struct audit_buffer *audit_buf; + + spin_lock(&cipso_v4_doi_list_lock); + doi_def = cipso_v4_doi_search(doi); + if (!doi_def) { + spin_unlock(&cipso_v4_doi_list_lock); + ret_val = -ENOENT; + goto doi_remove_return; + } + list_del_rcu(&doi_def->list); + spin_unlock(&cipso_v4_doi_list_lock); + + cipso_v4_doi_putdef(doi_def); + ret_val = 0; + +doi_remove_return: + audit_buf = netlbl_audit_start(AUDIT_MAC_CIPSOV4_DEL, audit_info); + if (audit_buf) { + audit_log_format(audit_buf, + " cipso_doi=%u res=%u", + doi, ret_val == 0 ? 1 : 0); + audit_log_end(audit_buf); + } + + return ret_val; +} + +/** + * cipso_v4_doi_getdef - Returns a reference to a valid DOI definition + * @doi: the DOI value + * + * Description: + * Searches for a valid DOI definition and if one is found it is returned to + * the caller. Otherwise NULL is returned. The caller must ensure that + * rcu_read_lock() is held while accessing the returned definition and the DOI + * definition reference count is decremented when the caller is done. + * + */ +struct cipso_v4_doi *cipso_v4_doi_getdef(u32 doi) +{ + struct cipso_v4_doi *doi_def; + + rcu_read_lock(); + doi_def = cipso_v4_doi_search(doi); + if (!doi_def) + goto doi_getdef_return; + if (!refcount_inc_not_zero(&doi_def->refcount)) + doi_def = NULL; + +doi_getdef_return: + rcu_read_unlock(); + return doi_def; +} + +/** + * cipso_v4_doi_putdef - Releases a reference for the given DOI definition + * @doi_def: the DOI definition + * + * Description: + * Releases a DOI definition reference obtained from cipso_v4_doi_getdef(). + * + */ +void cipso_v4_doi_putdef(struct cipso_v4_doi *doi_def) +{ + if (!doi_def) + return; + + if (!refcount_dec_and_test(&doi_def->refcount)) + return; + + cipso_v4_cache_invalidate(); + call_rcu(&doi_def->rcu, cipso_v4_doi_free_rcu); +} + +/** + * cipso_v4_doi_walk - Iterate through the DOI definitions + * @skip_cnt: skip past this number of DOI definitions, updated + * @callback: callback for each DOI definition + * @cb_arg: argument for the callback function + * + * Description: + * Iterate over the DOI definition list, skipping the first @skip_cnt entries. + * For each entry call @callback, if @callback returns a negative value stop + * 'walking' through the list and return. Updates the value in @skip_cnt upon + * return. Returns zero on success, negative values on failure. 
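+ *
+ * For illustration, a minimal (hypothetical) callback that only counts
+ * the defined DOIs:
+ *
+ *	static int doi_count_cb(struct cipso_v4_doi *doi_def, void *arg)
+ *	{
+ *		(*(u32 *)arg)++;
+ *		return 0;
+ *	}
+ *
+ *	u32 skip = 0, count = 0;
+ *	cipso_v4_doi_walk(&skip, doi_count_cb, &count);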
+ * + */ +int cipso_v4_doi_walk(u32 *skip_cnt, + int (*callback) (struct cipso_v4_doi *doi_def, void *arg), + void *cb_arg) +{ + int ret_val = -ENOENT; + u32 doi_cnt = 0; + struct cipso_v4_doi *iter_doi; + + rcu_read_lock(); + list_for_each_entry_rcu(iter_doi, &cipso_v4_doi_list, list) + if (refcount_read(&iter_doi->refcount) > 0) { + if (doi_cnt++ < *skip_cnt) + continue; + ret_val = callback(iter_doi, cb_arg); + if (ret_val < 0) { + doi_cnt--; + goto doi_walk_return; + } + } + +doi_walk_return: + rcu_read_unlock(); + *skip_cnt = doi_cnt; + return ret_val; +} + +/* + * Label Mapping Functions + */ + +/** + * cipso_v4_map_lvl_valid - Checks to see if the given level is understood + * @doi_def: the DOI definition + * @level: the level to check + * + * Description: + * Checks the given level against the given DOI definition and returns a + * negative value if the level does not have a valid mapping and a zero value + * if the level is defined by the DOI. + * + */ +static int cipso_v4_map_lvl_valid(const struct cipso_v4_doi *doi_def, u8 level) +{ + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + return 0; + case CIPSO_V4_MAP_TRANS: + if ((level < doi_def->map.std->lvl.cipso_size) && + (doi_def->map.std->lvl.cipso[level] < CIPSO_V4_INV_LVL)) + return 0; + break; + } + + return -EFAULT; +} + +/** + * cipso_v4_map_lvl_hton - Perform a level mapping from the host to the network + * @doi_def: the DOI definition + * @host_lvl: the host MLS level + * @net_lvl: the network/CIPSO MLS level + * + * Description: + * Perform a label mapping to translate a local MLS level to the correct + * CIPSO level using the given DOI definition. Returns zero on success, + * negative values otherwise. + * + */ +static int cipso_v4_map_lvl_hton(const struct cipso_v4_doi *doi_def, + u32 host_lvl, + u32 *net_lvl) +{ + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + *net_lvl = host_lvl; + return 0; + case CIPSO_V4_MAP_TRANS: + if (host_lvl < doi_def->map.std->lvl.local_size && + doi_def->map.std->lvl.local[host_lvl] < CIPSO_V4_INV_LVL) { + *net_lvl = doi_def->map.std->lvl.local[host_lvl]; + return 0; + } + return -EPERM; + } + + return -EINVAL; +} + +/** + * cipso_v4_map_lvl_ntoh - Perform a level mapping from the network to the host + * @doi_def: the DOI definition + * @net_lvl: the network/CIPSO MLS level + * @host_lvl: the host MLS level + * + * Description: + * Perform a label mapping to translate a CIPSO level to the correct local MLS + * level using the given DOI definition. Returns zero on success, negative + * values otherwise. + * + */ +static int cipso_v4_map_lvl_ntoh(const struct cipso_v4_doi *doi_def, + u32 net_lvl, + u32 *host_lvl) +{ + struct cipso_v4_std_map_tbl *map_tbl; + + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + *host_lvl = net_lvl; + return 0; + case CIPSO_V4_MAP_TRANS: + map_tbl = doi_def->map.std; + if (net_lvl < map_tbl->lvl.cipso_size && + map_tbl->lvl.cipso[net_lvl] < CIPSO_V4_INV_LVL) { + *host_lvl = doi_def->map.std->lvl.cipso[net_lvl]; + return 0; + } + return -EPERM; + } + + return -EINVAL; +} + +/** + * cipso_v4_map_cat_rbm_valid - Checks to see if the category bitmap is valid + * @doi_def: the DOI definition + * @bitmap: category bitmap + * @bitmap_len: bitmap length in bytes + * + * Description: + * Checks the given category bitmap against the given DOI definition and + * returns a negative value if any of the categories in the bitmap do not have + * a valid mapping and a zero value if all of the categories are valid. 
+ * + */ +static int cipso_v4_map_cat_rbm_valid(const struct cipso_v4_doi *doi_def, + const unsigned char *bitmap, + u32 bitmap_len) +{ + int cat = -1; + u32 bitmap_len_bits = bitmap_len * 8; + u32 cipso_cat_size; + u32 *cipso_array; + + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + return 0; + case CIPSO_V4_MAP_TRANS: + cipso_cat_size = doi_def->map.std->cat.cipso_size; + cipso_array = doi_def->map.std->cat.cipso; + for (;;) { + cat = netlbl_bitmap_walk(bitmap, + bitmap_len_bits, + cat + 1, + 1); + if (cat < 0) + break; + if (cat >= cipso_cat_size || + cipso_array[cat] >= CIPSO_V4_INV_CAT) + return -EFAULT; + } + + if (cat == -1) + return 0; + break; + } + + return -EFAULT; +} + +/** + * cipso_v4_map_cat_rbm_hton - Perform a category mapping from host to network + * @doi_def: the DOI definition + * @secattr: the security attributes + * @net_cat: the zero'd out category bitmap in network/CIPSO format + * @net_cat_len: the length of the CIPSO bitmap in bytes + * + * Description: + * Perform a label mapping to translate a local MLS category bitmap to the + * correct CIPSO bitmap using the given DOI definition. Returns the minimum + * size in bytes of the network bitmap on success, negative values otherwise. + * + */ +static int cipso_v4_map_cat_rbm_hton(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *net_cat, + u32 net_cat_len) +{ + int host_spot = -1; + u32 net_spot = CIPSO_V4_INV_CAT; + u32 net_spot_max = 0; + u32 net_clen_bits = net_cat_len * 8; + u32 host_cat_size = 0; + u32 *host_cat_array = NULL; + + if (doi_def->type == CIPSO_V4_MAP_TRANS) { + host_cat_size = doi_def->map.std->cat.local_size; + host_cat_array = doi_def->map.std->cat.local; + } + + for (;;) { + host_spot = netlbl_catmap_walk(secattr->attr.mls.cat, + host_spot + 1); + if (host_spot < 0) + break; + + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + net_spot = host_spot; + break; + case CIPSO_V4_MAP_TRANS: + if (host_spot >= host_cat_size) + return -EPERM; + net_spot = host_cat_array[host_spot]; + if (net_spot >= CIPSO_V4_INV_CAT) + return -EPERM; + break; + } + if (net_spot >= net_clen_bits) + return -ENOSPC; + netlbl_bitmap_setbit(net_cat, net_spot, 1); + + if (net_spot > net_spot_max) + net_spot_max = net_spot; + } + + if (++net_spot_max % 8) + return net_spot_max / 8 + 1; + return net_spot_max / 8; +} + +/** + * cipso_v4_map_cat_rbm_ntoh - Perform a category mapping from network to host + * @doi_def: the DOI definition + * @net_cat: the category bitmap in network/CIPSO format + * @net_cat_len: the length of the CIPSO bitmap in bytes + * @secattr: the security attributes + * + * Description: + * Perform a label mapping to translate a CIPSO bitmap to the correct local + * MLS category bitmap using the given DOI definition. Returns zero on + * success, negative values on failure. 
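+ *
+ * Layout reminder for the restricted bitmap (tag #1): bit 0 is the most
+ * significant bit of the first octet, so a bitmap carrying categories
+ * 0, 6 and 9 is the two octets 0x82 0x40 (sketch; with CIPSO_V4_MAP_TRANS
+ * each category is then remapped through the per-DOI table).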
+ * + */ +static int cipso_v4_map_cat_rbm_ntoh(const struct cipso_v4_doi *doi_def, + const unsigned char *net_cat, + u32 net_cat_len, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + int net_spot = -1; + u32 host_spot = CIPSO_V4_INV_CAT; + u32 net_clen_bits = net_cat_len * 8; + u32 net_cat_size = 0; + u32 *net_cat_array = NULL; + + if (doi_def->type == CIPSO_V4_MAP_TRANS) { + net_cat_size = doi_def->map.std->cat.cipso_size; + net_cat_array = doi_def->map.std->cat.cipso; + } + + for (;;) { + net_spot = netlbl_bitmap_walk(net_cat, + net_clen_bits, + net_spot + 1, + 1); + if (net_spot < 0) { + if (net_spot == -2) + return -EFAULT; + return 0; + } + + switch (doi_def->type) { + case CIPSO_V4_MAP_PASS: + host_spot = net_spot; + break; + case CIPSO_V4_MAP_TRANS: + if (net_spot >= net_cat_size) + return -EPERM; + host_spot = net_cat_array[net_spot]; + if (host_spot >= CIPSO_V4_INV_CAT) + return -EPERM; + break; + } + ret_val = netlbl_catmap_setbit(&secattr->attr.mls.cat, + host_spot, + GFP_ATOMIC); + if (ret_val != 0) + return ret_val; + } + + return -EINVAL; +} + +/** + * cipso_v4_map_cat_enum_valid - Checks to see if the categories are valid + * @doi_def: the DOI definition + * @enumcat: category list + * @enumcat_len: length of the category list in bytes + * + * Description: + * Checks the given categories against the given DOI definition and returns a + * negative value if any of the categories do not have a valid mapping and a + * zero value if all of the categories are valid. + * + */ +static int cipso_v4_map_cat_enum_valid(const struct cipso_v4_doi *doi_def, + const unsigned char *enumcat, + u32 enumcat_len) +{ + u16 cat; + int cat_prev = -1; + u32 iter; + + if (doi_def->type != CIPSO_V4_MAP_PASS || enumcat_len & 0x01) + return -EFAULT; + + for (iter = 0; iter < enumcat_len; iter += 2) { + cat = get_unaligned_be16(&enumcat[iter]); + if (cat <= cat_prev) + return -EFAULT; + cat_prev = cat; + } + + return 0; +} + +/** + * cipso_v4_map_cat_enum_hton - Perform a category mapping from host to network + * @doi_def: the DOI definition + * @secattr: the security attributes + * @net_cat: the zero'd out category list in network/CIPSO format + * @net_cat_len: the length of the CIPSO category list in bytes + * + * Description: + * Perform a label mapping to translate a local MLS category bitmap to the + * correct CIPSO category list using the given DOI definition. Returns the + * size in bytes of the network category bitmap on success, negative values + * otherwise. + * + */ +static int cipso_v4_map_cat_enum_hton(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *net_cat, + u32 net_cat_len) +{ + int cat = -1; + u32 cat_iter = 0; + + for (;;) { + cat = netlbl_catmap_walk(secattr->attr.mls.cat, cat + 1); + if (cat < 0) + break; + if ((cat_iter + 2) > net_cat_len) + return -ENOSPC; + + *((__be16 *)&net_cat[cat_iter]) = htons(cat); + cat_iter += 2; + } + + return cat_iter; +} + +/** + * cipso_v4_map_cat_enum_ntoh - Perform a category mapping from network to host + * @doi_def: the DOI definition + * @net_cat: the category list in network/CIPSO format + * @net_cat_len: the length of the CIPSO bitmap in bytes + * @secattr: the security attributes + * + * Description: + * Perform a label mapping to translate a CIPSO category list to the correct + * local MLS category bitmap using the given DOI definition. Returns zero on + * success, negative values on failure. 
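+ *
+ * The enumerated tag (#2) lists each category as a 16-bit big-endian
+ * value, so categories 5 and 200 arrive as the four octets 00 05 00 c8
+ * and are set one at a time in the local category map below.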
+ * + */ +static int cipso_v4_map_cat_enum_ntoh(const struct cipso_v4_doi *doi_def, + const unsigned char *net_cat, + u32 net_cat_len, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u32 iter; + + for (iter = 0; iter < net_cat_len; iter += 2) { + ret_val = netlbl_catmap_setbit(&secattr->attr.mls.cat, + get_unaligned_be16(&net_cat[iter]), + GFP_ATOMIC); + if (ret_val != 0) + return ret_val; + } + + return 0; +} + +/** + * cipso_v4_map_cat_rng_valid - Checks to see if the categories are valid + * @doi_def: the DOI definition + * @rngcat: category list + * @rngcat_len: length of the category list in bytes + * + * Description: + * Checks the given categories against the given DOI definition and returns a + * negative value if any of the categories do not have a valid mapping and a + * zero value if all of the categories are valid. + * + */ +static int cipso_v4_map_cat_rng_valid(const struct cipso_v4_doi *doi_def, + const unsigned char *rngcat, + u32 rngcat_len) +{ + u16 cat_high; + u16 cat_low; + u32 cat_prev = CIPSO_V4_MAX_REM_CATS + 1; + u32 iter; + + if (doi_def->type != CIPSO_V4_MAP_PASS || rngcat_len & 0x01) + return -EFAULT; + + for (iter = 0; iter < rngcat_len; iter += 4) { + cat_high = get_unaligned_be16(&rngcat[iter]); + if ((iter + 4) <= rngcat_len) + cat_low = get_unaligned_be16(&rngcat[iter + 2]); + else + cat_low = 0; + + if (cat_high > cat_prev) + return -EFAULT; + + cat_prev = cat_low; + } + + return 0; +} + +/** + * cipso_v4_map_cat_rng_hton - Perform a category mapping from host to network + * @doi_def: the DOI definition + * @secattr: the security attributes + * @net_cat: the zero'd out category list in network/CIPSO format + * @net_cat_len: the length of the CIPSO category list in bytes + * + * Description: + * Perform a label mapping to translate a local MLS category bitmap to the + * correct CIPSO category list using the given DOI definition. Returns the + * size in bytes of the network category bitmap on success, negative values + * otherwise. + * + */ +static int cipso_v4_map_cat_rng_hton(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *net_cat, + u32 net_cat_len) +{ + int iter = -1; + u16 array[CIPSO_V4_TAG_RNG_CAT_MAX * 2]; + u32 array_cnt = 0; + u32 cat_size = 0; + + /* make sure we don't overflow the 'array[]' variable */ + if (net_cat_len > + (CIPSO_V4_OPT_LEN_MAX - CIPSO_V4_HDR_LEN - CIPSO_V4_TAG_RNG_BLEN)) + return -ENOSPC; + + for (;;) { + iter = netlbl_catmap_walk(secattr->attr.mls.cat, iter + 1); + if (iter < 0) + break; + cat_size += (iter == 0 ? 
0 : sizeof(u16)); + if (cat_size > net_cat_len) + return -ENOSPC; + array[array_cnt++] = iter; + + iter = netlbl_catmap_walkrng(secattr->attr.mls.cat, iter); + if (iter < 0) + return -EFAULT; + cat_size += sizeof(u16); + if (cat_size > net_cat_len) + return -ENOSPC; + array[array_cnt++] = iter; + } + + for (iter = 0; array_cnt > 0;) { + *((__be16 *)&net_cat[iter]) = htons(array[--array_cnt]); + iter += 2; + array_cnt--; + if (array[array_cnt] != 0) { + *((__be16 *)&net_cat[iter]) = htons(array[array_cnt]); + iter += 2; + } + } + + return cat_size; +} + +/** + * cipso_v4_map_cat_rng_ntoh - Perform a category mapping from network to host + * @doi_def: the DOI definition + * @net_cat: the category list in network/CIPSO format + * @net_cat_len: the length of the CIPSO bitmap in bytes + * @secattr: the security attributes + * + * Description: + * Perform a label mapping to translate a CIPSO category list to the correct + * local MLS category bitmap using the given DOI definition. Returns zero on + * success, negative values on failure. + * + */ +static int cipso_v4_map_cat_rng_ntoh(const struct cipso_v4_doi *doi_def, + const unsigned char *net_cat, + u32 net_cat_len, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u32 net_iter; + u16 cat_low; + u16 cat_high; + + for (net_iter = 0; net_iter < net_cat_len; net_iter += 4) { + cat_high = get_unaligned_be16(&net_cat[net_iter]); + if ((net_iter + 4) <= net_cat_len) + cat_low = get_unaligned_be16(&net_cat[net_iter + 2]); + else + cat_low = 0; + + ret_val = netlbl_catmap_setrng(&secattr->attr.mls.cat, + cat_low, + cat_high, + GFP_ATOMIC); + if (ret_val != 0) + return ret_val; + } + + return 0; +} + +/* + * Protocol Handling Functions + */ + +/** + * cipso_v4_gentag_hdr - Generate a CIPSO option header + * @doi_def: the DOI definition + * @len: the total tag length in bytes, not including this header + * @buf: the CIPSO option buffer + * + * Description: + * Write a CIPSO header into the beginning of @buffer. + * + */ +static void cipso_v4_gentag_hdr(const struct cipso_v4_doi *doi_def, + unsigned char *buf, + u32 len) +{ + buf[0] = IPOPT_CIPSO; + buf[1] = CIPSO_V4_HDR_LEN + len; + put_unaligned_be32(doi_def->doi, &buf[2]); +} + +/** + * cipso_v4_gentag_rbm - Generate a CIPSO restricted bitmap tag (type #1) + * @doi_def: the DOI definition + * @secattr: the security attributes + * @buffer: the option buffer + * @buffer_len: length of buffer in bytes + * + * Description: + * Generate a CIPSO option using the restricted bitmap tag, tag type #1. The + * actual buffer length may be larger than the indicated size due to + * translation between host and network category bitmaps. Returns the size of + * the tag on success, negative values on failure. + * + */ +static int cipso_v4_gentag_rbm(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *buffer, + u32 buffer_len) +{ + int ret_val; + u32 tag_len; + u32 level; + + if ((secattr->flags & NETLBL_SECATTR_MLS_LVL) == 0) + return -EPERM; + + ret_val = cipso_v4_map_lvl_hton(doi_def, + secattr->attr.mls.lvl, + &level); + if (ret_val != 0) + return ret_val; + + if (secattr->flags & NETLBL_SECATTR_MLS_CAT) { + ret_val = cipso_v4_map_cat_rbm_hton(doi_def, + secattr, + &buffer[4], + buffer_len - 4); + if (ret_val < 0) + return ret_val; + + /* This will send packets using the "optimized" format when + * possible as specified in section 3.4.2.6 of the + * CIPSO draft. 
*/ + if (READ_ONCE(cipso_v4_rbm_optfmt) && ret_val > 0 && + ret_val <= 10) + tag_len = 14; + else + tag_len = 4 + ret_val; + } else + tag_len = 4; + + buffer[0] = CIPSO_V4_TAG_RBITMAP; + buffer[1] = tag_len; + buffer[3] = level; + + return tag_len; +} + +/** + * cipso_v4_parsetag_rbm - Parse a CIPSO restricted bitmap tag + * @doi_def: the DOI definition + * @tag: the CIPSO tag + * @secattr: the security attributes + * + * Description: + * Parse a CIPSO restricted bitmap tag (tag type #1) and return the security + * attributes in @secattr. Return zero on success, negatives values on + * failure. + * + */ +static int cipso_v4_parsetag_rbm(const struct cipso_v4_doi *doi_def, + const unsigned char *tag, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u8 tag_len = tag[1]; + u32 level; + + ret_val = cipso_v4_map_lvl_ntoh(doi_def, tag[3], &level); + if (ret_val != 0) + return ret_val; + secattr->attr.mls.lvl = level; + secattr->flags |= NETLBL_SECATTR_MLS_LVL; + + if (tag_len > 4) { + ret_val = cipso_v4_map_cat_rbm_ntoh(doi_def, + &tag[4], + tag_len - 4, + secattr); + if (ret_val != 0) { + netlbl_catmap_free(secattr->attr.mls.cat); + return ret_val; + } + + if (secattr->attr.mls.cat) + secattr->flags |= NETLBL_SECATTR_MLS_CAT; + } + + return 0; +} + +/** + * cipso_v4_gentag_enum - Generate a CIPSO enumerated tag (type #2) + * @doi_def: the DOI definition + * @secattr: the security attributes + * @buffer: the option buffer + * @buffer_len: length of buffer in bytes + * + * Description: + * Generate a CIPSO option using the enumerated tag, tag type #2. Returns the + * size of the tag on success, negative values on failure. + * + */ +static int cipso_v4_gentag_enum(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *buffer, + u32 buffer_len) +{ + int ret_val; + u32 tag_len; + u32 level; + + if (!(secattr->flags & NETLBL_SECATTR_MLS_LVL)) + return -EPERM; + + ret_val = cipso_v4_map_lvl_hton(doi_def, + secattr->attr.mls.lvl, + &level); + if (ret_val != 0) + return ret_val; + + if (secattr->flags & NETLBL_SECATTR_MLS_CAT) { + ret_val = cipso_v4_map_cat_enum_hton(doi_def, + secattr, + &buffer[4], + buffer_len - 4); + if (ret_val < 0) + return ret_val; + + tag_len = 4 + ret_val; + } else + tag_len = 4; + + buffer[0] = CIPSO_V4_TAG_ENUM; + buffer[1] = tag_len; + buffer[3] = level; + + return tag_len; +} + +/** + * cipso_v4_parsetag_enum - Parse a CIPSO enumerated tag + * @doi_def: the DOI definition + * @tag: the CIPSO tag + * @secattr: the security attributes + * + * Description: + * Parse a CIPSO enumerated tag (tag type #2) and return the security + * attributes in @secattr. Return zero on success, negatives values on + * failure. 
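For concreteness, a sketch of what the enumerated tag produced by cipso_v4_gentag_enum() looks like on the wire, assuming a mapped level of 3 and categories 5 and 200 (values chosen purely for illustration; the 6 byte option header written by cipso_v4_gentag_hdr() is not shown):

static const unsigned char demo_enum_tag[] = {
	0x02,		/* tag type: CIPSO_V4_TAG_ENUM */
	0x08,		/* tag length: 4 byte tag header + 2 * 2 category bytes */
	0x00,		/* alignment octet, left zero by the generator */
	0x03,		/* mapped sensitivity level */
	0x00, 0x05,	/* category 5, big endian u16 */
	0x00, 0xc8,	/* category 200, big endian u16 */
};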
+ * + */ +static int cipso_v4_parsetag_enum(const struct cipso_v4_doi *doi_def, + const unsigned char *tag, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u8 tag_len = tag[1]; + u32 level; + + ret_val = cipso_v4_map_lvl_ntoh(doi_def, tag[3], &level); + if (ret_val != 0) + return ret_val; + secattr->attr.mls.lvl = level; + secattr->flags |= NETLBL_SECATTR_MLS_LVL; + + if (tag_len > 4) { + ret_val = cipso_v4_map_cat_enum_ntoh(doi_def, + &tag[4], + tag_len - 4, + secattr); + if (ret_val != 0) { + netlbl_catmap_free(secattr->attr.mls.cat); + return ret_val; + } + + secattr->flags |= NETLBL_SECATTR_MLS_CAT; + } + + return 0; +} + +/** + * cipso_v4_gentag_rng - Generate a CIPSO ranged tag (type #5) + * @doi_def: the DOI definition + * @secattr: the security attributes + * @buffer: the option buffer + * @buffer_len: length of buffer in bytes + * + * Description: + * Generate a CIPSO option using the ranged tag, tag type #5. Returns the + * size of the tag on success, negative values on failure. + * + */ +static int cipso_v4_gentag_rng(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *buffer, + u32 buffer_len) +{ + int ret_val; + u32 tag_len; + u32 level; + + if (!(secattr->flags & NETLBL_SECATTR_MLS_LVL)) + return -EPERM; + + ret_val = cipso_v4_map_lvl_hton(doi_def, + secattr->attr.mls.lvl, + &level); + if (ret_val != 0) + return ret_val; + + if (secattr->flags & NETLBL_SECATTR_MLS_CAT) { + ret_val = cipso_v4_map_cat_rng_hton(doi_def, + secattr, + &buffer[4], + buffer_len - 4); + if (ret_val < 0) + return ret_val; + + tag_len = 4 + ret_val; + } else + tag_len = 4; + + buffer[0] = CIPSO_V4_TAG_RANGE; + buffer[1] = tag_len; + buffer[3] = level; + + return tag_len; +} + +/** + * cipso_v4_parsetag_rng - Parse a CIPSO ranged tag + * @doi_def: the DOI definition + * @tag: the CIPSO tag + * @secattr: the security attributes + * + * Description: + * Parse a CIPSO ranged tag (tag type #5) and return the security attributes + * in @secattr. Return zero on success, negatives values on failure. + * + */ +static int cipso_v4_parsetag_rng(const struct cipso_v4_doi *doi_def, + const unsigned char *tag, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u8 tag_len = tag[1]; + u32 level; + + ret_val = cipso_v4_map_lvl_ntoh(doi_def, tag[3], &level); + if (ret_val != 0) + return ret_val; + secattr->attr.mls.lvl = level; + secattr->flags |= NETLBL_SECATTR_MLS_LVL; + + if (tag_len > 4) { + ret_val = cipso_v4_map_cat_rng_ntoh(doi_def, + &tag[4], + tag_len - 4, + secattr); + if (ret_val != 0) { + netlbl_catmap_free(secattr->attr.mls.cat); + return ret_val; + } + + if (secattr->attr.mls.cat) + secattr->flags |= NETLBL_SECATTR_MLS_CAT; + } + + return 0; +} + +/** + * cipso_v4_gentag_loc - Generate a CIPSO local tag (non-standard) + * @doi_def: the DOI definition + * @secattr: the security attributes + * @buffer: the option buffer + * @buffer_len: length of buffer in bytes + * + * Description: + * Generate a CIPSO option using the local tag. Returns the size of the tag + * on success, negative values on failure. 
+ * + */ +static int cipso_v4_gentag_loc(const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *buffer, + u32 buffer_len) +{ + if (!(secattr->flags & NETLBL_SECATTR_SECID)) + return -EPERM; + + buffer[0] = CIPSO_V4_TAG_LOCAL; + buffer[1] = CIPSO_V4_TAG_LOC_BLEN; + *(u32 *)&buffer[2] = secattr->attr.secid; + + return CIPSO_V4_TAG_LOC_BLEN; +} + +/** + * cipso_v4_parsetag_loc - Parse a CIPSO local tag + * @doi_def: the DOI definition + * @tag: the CIPSO tag + * @secattr: the security attributes + * + * Description: + * Parse a CIPSO local tag and return the security attributes in @secattr. + * Return zero on success, negatives values on failure. + * + */ +static int cipso_v4_parsetag_loc(const struct cipso_v4_doi *doi_def, + const unsigned char *tag, + struct netlbl_lsm_secattr *secattr) +{ + secattr->attr.secid = *(u32 *)&tag[2]; + secattr->flags |= NETLBL_SECATTR_SECID; + + return 0; +} + +/** + * cipso_v4_optptr - Find the CIPSO option in the packet + * @skb: the packet + * + * Description: + * Parse the packet's IP header looking for a CIPSO option. Returns a pointer + * to the start of the CIPSO option on success, NULL if one is not found. + * + */ +unsigned char *cipso_v4_optptr(const struct sk_buff *skb) +{ + const struct iphdr *iph = ip_hdr(skb); + unsigned char *optptr = (unsigned char *)&(ip_hdr(skb)[1]); + int optlen; + int taglen; + + for (optlen = iph->ihl*4 - sizeof(struct iphdr); optlen > 1; ) { + switch (optptr[0]) { + case IPOPT_END: + return NULL; + case IPOPT_NOOP: + taglen = 1; + break; + default: + taglen = optptr[1]; + } + if (!taglen || taglen > optlen) + return NULL; + if (optptr[0] == IPOPT_CIPSO) + return optptr; + + optlen -= taglen; + optptr += taglen; + } + + return NULL; +} + +/** + * cipso_v4_validate - Validate a CIPSO option + * @skb: the packet + * @option: the start of the option, on error it is set to point to the error + * + * Description: + * This routine is called to validate a CIPSO option, it checks all of the + * fields to ensure that they are at least valid, see the draft snippet below + * for details. If the option is valid then a zero value is returned and + * the value of @option is unchanged. If the option is invalid then a + * non-zero value is returned and @option is adjusted to point to the + * offending portion of the option. From the IETF draft ... + * + * "If any field within the CIPSO options, such as the DOI identifier, is not + * recognized the IP datagram is discarded and an ICMP 'parameter problem' + * (type 12) is generated and returned. The ICMP code field is set to 'bad + * parameter' (code 0) and the pointer is set to the start of the CIPSO field + * that is unrecognized." 
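Before the validation routine below, a standalone sketch of the option walk performed by cipso_v4_optptr() above, written against a bare byte buffer rather than an sk_buff (demo_* is illustrative only; 0, 1 and 134 stand for IPOPT_END, IPOPT_NOOP and IPOPT_CIPSO):

static const unsigned char *demo_find_cipso(const unsigned char *opt,
					    int optlen)
{
	while (optlen > 1) {
		int taglen;

		if (opt[0] == 0)		/* IPOPT_END terminates the list */
			return NULL;
		if (opt[0] == 1)		/* IPOPT_NOOP is a single byte */
			taglen = 1;
		else
			taglen = opt[1];	/* all other options are TLVs */
		if (taglen == 0 || taglen > optlen)
			return NULL;		/* malformed options area */
		if (opt[0] == 134)		/* IPOPT_CIPSO */
			return opt;
		opt += taglen;
		optlen -= taglen;
	}
	return NULL;
}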
+ * + */ +int cipso_v4_validate(const struct sk_buff *skb, unsigned char **option) +{ + unsigned char *opt = *option; + unsigned char *tag; + unsigned char opt_iter; + unsigned char err_offset = 0; + u8 opt_len; + u8 tag_len; + struct cipso_v4_doi *doi_def = NULL; + u32 tag_iter; + + /* caller already checks for length values that are too large */ + opt_len = opt[1]; + if (opt_len < 8) { + err_offset = 1; + goto validate_return; + } + + rcu_read_lock(); + doi_def = cipso_v4_doi_search(get_unaligned_be32(&opt[2])); + if (!doi_def) { + err_offset = 2; + goto validate_return_locked; + } + + opt_iter = CIPSO_V4_HDR_LEN; + tag = opt + opt_iter; + while (opt_iter < opt_len) { + for (tag_iter = 0; doi_def->tags[tag_iter] != tag[0];) + if (doi_def->tags[tag_iter] == CIPSO_V4_TAG_INVALID || + ++tag_iter == CIPSO_V4_TAG_MAXCNT) { + err_offset = opt_iter; + goto validate_return_locked; + } + + if (opt_iter + 1 == opt_len) { + err_offset = opt_iter; + goto validate_return_locked; + } + tag_len = tag[1]; + if (tag_len > (opt_len - opt_iter)) { + err_offset = opt_iter + 1; + goto validate_return_locked; + } + + switch (tag[0]) { + case CIPSO_V4_TAG_RBITMAP: + if (tag_len < CIPSO_V4_TAG_RBM_BLEN) { + err_offset = opt_iter + 1; + goto validate_return_locked; + } + + /* We are already going to do all the verification + * necessary at the socket layer so from our point of + * view it is safe to turn these checks off (and less + * work), however, the CIPSO draft says we should do + * all the CIPSO validations here but it doesn't + * really specify _exactly_ what we need to validate + * ... so, just make it a sysctl tunable. */ + if (READ_ONCE(cipso_v4_rbm_strictvalid)) { + if (cipso_v4_map_lvl_valid(doi_def, + tag[3]) < 0) { + err_offset = opt_iter + 3; + goto validate_return_locked; + } + if (tag_len > CIPSO_V4_TAG_RBM_BLEN && + cipso_v4_map_cat_rbm_valid(doi_def, + &tag[4], + tag_len - 4) < 0) { + err_offset = opt_iter + 4; + goto validate_return_locked; + } + } + break; + case CIPSO_V4_TAG_ENUM: + if (tag_len < CIPSO_V4_TAG_ENUM_BLEN) { + err_offset = opt_iter + 1; + goto validate_return_locked; + } + + if (cipso_v4_map_lvl_valid(doi_def, + tag[3]) < 0) { + err_offset = opt_iter + 3; + goto validate_return_locked; + } + if (tag_len > CIPSO_V4_TAG_ENUM_BLEN && + cipso_v4_map_cat_enum_valid(doi_def, + &tag[4], + tag_len - 4) < 0) { + err_offset = opt_iter + 4; + goto validate_return_locked; + } + break; + case CIPSO_V4_TAG_RANGE: + if (tag_len < CIPSO_V4_TAG_RNG_BLEN) { + err_offset = opt_iter + 1; + goto validate_return_locked; + } + + if (cipso_v4_map_lvl_valid(doi_def, + tag[3]) < 0) { + err_offset = opt_iter + 3; + goto validate_return_locked; + } + if (tag_len > CIPSO_V4_TAG_RNG_BLEN && + cipso_v4_map_cat_rng_valid(doi_def, + &tag[4], + tag_len - 4) < 0) { + err_offset = opt_iter + 4; + goto validate_return_locked; + } + break; + case CIPSO_V4_TAG_LOCAL: + /* This is a non-standard tag that we only allow for + * local connections, so if the incoming interface is + * not the loopback device drop the packet. Further, + * there is no legitimate reason for setting this from + * userspace so reject it if skb is NULL. 
*/ + if (!skb || !(skb->dev->flags & IFF_LOOPBACK)) { + err_offset = opt_iter; + goto validate_return_locked; + } + if (tag_len != CIPSO_V4_TAG_LOC_BLEN) { + err_offset = opt_iter + 1; + goto validate_return_locked; + } + break; + default: + err_offset = opt_iter; + goto validate_return_locked; + } + + tag += tag_len; + opt_iter += tag_len; + } + +validate_return_locked: + rcu_read_unlock(); +validate_return: + *option = opt + err_offset; + return err_offset; +} + +/** + * cipso_v4_error - Send the correct response for a bad packet + * @skb: the packet + * @error: the error code + * @gateway: CIPSO gateway flag + * + * Description: + * Based on the error code given in @error, send an ICMP error message back to + * the originating host. From the IETF draft ... + * + * "If the contents of the CIPSO [option] are valid but the security label is + * outside of the configured host or port label range, the datagram is + * discarded and an ICMP 'destination unreachable' (type 3) is generated and + * returned. The code field of the ICMP is set to 'communication with + * destination network administratively prohibited' (code 9) or to + * 'communication with destination host administratively prohibited' + * (code 10). The value of the code is dependent on whether the originator + * of the ICMP message is acting as a CIPSO host or a CIPSO gateway. The + * recipient of the ICMP message MUST be able to handle either value. The + * same procedure is performed if a CIPSO [option] can not be added to an + * IP packet because it is too large to fit in the IP options area." + * + * "If the error is triggered by receipt of an ICMP message, the message is + * discarded and no response is permitted (consistent with general ICMP + * processing rules)." + * + */ +void cipso_v4_error(struct sk_buff *skb, int error, u32 gateway) +{ + unsigned char optbuf[sizeof(struct ip_options) + 40]; + struct ip_options *opt = (struct ip_options *)optbuf; + int res; + + if (ip_hdr(skb)->protocol == IPPROTO_ICMP || error != -EACCES) + return; + + /* + * We might be called above the IP layer, + * so we can not use icmp_send and IPCB here. + */ + + memset(opt, 0, sizeof(struct ip_options)); + opt->optlen = ip_hdr(skb)->ihl*4 - sizeof(struct iphdr); + rcu_read_lock(); + res = __ip_options_compile(dev_net(skb->dev), opt, skb, NULL); + rcu_read_unlock(); + + if (res) + return; + + if (gateway) + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_NET_ANO, 0, opt); + else + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_ANO, 0, opt); +} + +/** + * cipso_v4_genopt - Generate a CIPSO option + * @buf: the option buffer + * @buf_len: the size of opt_buf + * @doi_def: the CIPSO DOI to use + * @secattr: the security attributes + * + * Description: + * Generate a CIPSO option using the DOI definition and security attributes + * passed to the function. Returns the length of the option on success and + * negative values on failure. + * + */ +static int cipso_v4_genopt(unsigned char *buf, u32 buf_len, + const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u32 iter; + + if (buf_len <= CIPSO_V4_HDR_LEN) + return -ENOSPC; + + /* XXX - This code assumes only one tag per CIPSO option which isn't + * really a good assumption to make but since we only support the MAC + * tags right now it is a safe assumption. 
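To make the resulting layout concrete, a sketch of a complete minimal option as assembled by cipso_v4_genopt() and cipso_v4_gentag_hdr(), assuming an example DOI of 3, a restricted bitmap tag, a mapped level of 1 and no categories (all values chosen purely for illustration):

static const unsigned char demo_cipso_opt[] = {
	0x86,			/* IPOPT_CIPSO (134) */
	0x0a,			/* total length: 6 byte header + 4 byte tag */
	0x00, 0x00, 0x00, 0x03,	/* DOI 3, big endian */
	0x01,			/* tag type: CIPSO_V4_TAG_RBITMAP */
	0x04,			/* tag length, no category bitmap present */
	0x00,			/* alignment octet */
	0x01,			/* mapped sensitivity level */
};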
*/ + iter = 0; + do { + memset(buf, 0, buf_len); + switch (doi_def->tags[iter]) { + case CIPSO_V4_TAG_RBITMAP: + ret_val = cipso_v4_gentag_rbm(doi_def, + secattr, + &buf[CIPSO_V4_HDR_LEN], + buf_len - CIPSO_V4_HDR_LEN); + break; + case CIPSO_V4_TAG_ENUM: + ret_val = cipso_v4_gentag_enum(doi_def, + secattr, + &buf[CIPSO_V4_HDR_LEN], + buf_len - CIPSO_V4_HDR_LEN); + break; + case CIPSO_V4_TAG_RANGE: + ret_val = cipso_v4_gentag_rng(doi_def, + secattr, + &buf[CIPSO_V4_HDR_LEN], + buf_len - CIPSO_V4_HDR_LEN); + break; + case CIPSO_V4_TAG_LOCAL: + ret_val = cipso_v4_gentag_loc(doi_def, + secattr, + &buf[CIPSO_V4_HDR_LEN], + buf_len - CIPSO_V4_HDR_LEN); + break; + default: + return -EPERM; + } + + iter++; + } while (ret_val < 0 && + iter < CIPSO_V4_TAG_MAXCNT && + doi_def->tags[iter] != CIPSO_V4_TAG_INVALID); + if (ret_val < 0) + return ret_val; + cipso_v4_gentag_hdr(doi_def, buf, ret_val); + return CIPSO_V4_HDR_LEN + ret_val; +} + +/** + * cipso_v4_sock_setattr - Add a CIPSO option to a socket + * @sk: the socket + * @doi_def: the CIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. This function requires + * exclusive access to @sk, which means it either needs to be in the + * process of being created or locked. Returns zero on success and negative + * values on failure. + * + */ +int cipso_v4_sock_setattr(struct sock *sk, + const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -EPERM; + unsigned char *buf = NULL; + u32 buf_len; + u32 opt_len; + struct ip_options_rcu *old, *opt = NULL; + struct inet_sock *sk_inet; + struct inet_connection_sock *sk_conn; + + /* In the case of sock_create_lite(), the sock->sk field is not + * defined yet but it is not a problem as the only users of these + * "lite" PF_INET sockets are functions which do an accept() call + * afterwards so we will label the socket as part of the accept(). */ + if (!sk) + return 0; + + /* We allocate the maximum CIPSO option size here so we are probably + * being a little wasteful, but it makes our life _much_ easier later + * on and after all we are only talking about 40 bytes. */ + buf_len = CIPSO_V4_OPT_LEN_MAX; + buf = kmalloc(buf_len, GFP_ATOMIC); + if (!buf) { + ret_val = -ENOMEM; + goto socket_setattr_failure; + } + + ret_val = cipso_v4_genopt(buf, buf_len, doi_def, secattr); + if (ret_val < 0) + goto socket_setattr_failure; + buf_len = ret_val; + + /* We can't use ip_options_get() directly because it makes a call to + * ip_options_get_alloc() which allocates memory with GFP_KERNEL and + * we won't always have CAP_NET_RAW even though we _always_ want to + * set the IPOPT_CIPSO option. 
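A small aside on the length handling just below: the IPv4 options area must occupy a multiple of 4 bytes, so the generated CIPSO option is rounded up and the zero fill left by kzalloc() (zero being IPOPT_END) pads out the remainder. A minimal sketch of the rounding, under that assumption:

static unsigned int demo_round_optlen(unsigned int buf_len)
{
	/* e.g. 10 -> 12, 14 -> 16, 40 -> 40 (the IPv4 maximum) */
	return (buf_len + 3) & ~3u;
}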
*/ + opt_len = (buf_len + 3) & ~3; + opt = kzalloc(sizeof(*opt) + opt_len, GFP_ATOMIC); + if (!opt) { + ret_val = -ENOMEM; + goto socket_setattr_failure; + } + memcpy(opt->opt.__data, buf, buf_len); + opt->opt.optlen = opt_len; + opt->opt.cipso = sizeof(struct iphdr); + kfree(buf); + buf = NULL; + + sk_inet = inet_sk(sk); + + old = rcu_dereference_protected(sk_inet->inet_opt, + lockdep_sock_is_held(sk)); + if (sk_inet->is_icsk) { + sk_conn = inet_csk(sk); + if (old) + sk_conn->icsk_ext_hdr_len -= old->opt.optlen; + sk_conn->icsk_ext_hdr_len += opt->opt.optlen; + sk_conn->icsk_sync_mss(sk, sk_conn->icsk_pmtu_cookie); + } + rcu_assign_pointer(sk_inet->inet_opt, opt); + if (old) + kfree_rcu(old, rcu); + + return 0; + +socket_setattr_failure: + kfree(buf); + kfree(opt); + return ret_val; +} + +/** + * cipso_v4_req_setattr - Add a CIPSO option to a connection request socket + * @req: the connection request socket + * @doi_def: the CIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. Returns zero on success and + * negative values on failure. + * + */ +int cipso_v4_req_setattr(struct request_sock *req, + const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -EPERM; + unsigned char *buf = NULL; + u32 buf_len; + u32 opt_len; + struct ip_options_rcu *opt = NULL; + struct inet_request_sock *req_inet; + + /* We allocate the maximum CIPSO option size here so we are probably + * being a little wasteful, but it makes our life _much_ easier later + * on and after all we are only talking about 40 bytes. */ + buf_len = CIPSO_V4_OPT_LEN_MAX; + buf = kmalloc(buf_len, GFP_ATOMIC); + if (!buf) { + ret_val = -ENOMEM; + goto req_setattr_failure; + } + + ret_val = cipso_v4_genopt(buf, buf_len, doi_def, secattr); + if (ret_val < 0) + goto req_setattr_failure; + buf_len = ret_val; + + /* We can't use ip_options_get() directly because it makes a call to + * ip_options_get_alloc() which allocates memory with GFP_KERNEL and + * we won't always have CAP_NET_RAW even though we _always_ want to + * set the IPOPT_CIPSO option. */ + opt_len = (buf_len + 3) & ~3; + opt = kzalloc(sizeof(*opt) + opt_len, GFP_ATOMIC); + if (!opt) { + ret_val = -ENOMEM; + goto req_setattr_failure; + } + memcpy(opt->opt.__data, buf, buf_len); + opt->opt.optlen = opt_len; + opt->opt.cipso = sizeof(struct iphdr); + kfree(buf); + buf = NULL; + + req_inet = inet_rsk(req); + opt = xchg((__force struct ip_options_rcu **)&req_inet->ireq_opt, opt); + if (opt) + kfree_rcu(opt, rcu); + + return 0; + +req_setattr_failure: + kfree(buf); + kfree(opt); + return ret_val; +} + +/** + * cipso_v4_delopt - Delete the CIPSO option from a set of IP options + * @opt_ptr: IP option pointer + * + * Description: + * Deletes the CIPSO IP option from a set of IP options and makes the necessary + * adjustments to the IP option structure. Returns zero on success, negative + * values on failure. 
+ * + */ +static int cipso_v4_delopt(struct ip_options_rcu __rcu **opt_ptr) +{ + struct ip_options_rcu *opt = rcu_dereference_protected(*opt_ptr, 1); + int hdr_delta = 0; + + if (!opt || opt->opt.cipso == 0) + return 0; + if (opt->opt.srr || opt->opt.rr || opt->opt.ts || opt->opt.router_alert) { + u8 cipso_len; + u8 cipso_off; + unsigned char *cipso_ptr; + int iter; + int optlen_new; + + cipso_off = opt->opt.cipso - sizeof(struct iphdr); + cipso_ptr = &opt->opt.__data[cipso_off]; + cipso_len = cipso_ptr[1]; + + if (opt->opt.srr > opt->opt.cipso) + opt->opt.srr -= cipso_len; + if (opt->opt.rr > opt->opt.cipso) + opt->opt.rr -= cipso_len; + if (opt->opt.ts > opt->opt.cipso) + opt->opt.ts -= cipso_len; + if (opt->opt.router_alert > opt->opt.cipso) + opt->opt.router_alert -= cipso_len; + opt->opt.cipso = 0; + + memmove(cipso_ptr, cipso_ptr + cipso_len, + opt->opt.optlen - cipso_off - cipso_len); + + /* determining the new total option length is tricky because of + * the padding necessary, the only thing i can think to do at + * this point is walk the options one-by-one, skipping the + * padding at the end to determine the actual option size and + * from there we can determine the new total option length */ + iter = 0; + optlen_new = 0; + while (iter < opt->opt.optlen) + if (opt->opt.__data[iter] != IPOPT_NOP) { + iter += opt->opt.__data[iter + 1]; + optlen_new = iter; + } else + iter++; + hdr_delta = opt->opt.optlen; + opt->opt.optlen = (optlen_new + 3) & ~3; + hdr_delta -= opt->opt.optlen; + } else { + /* only the cipso option was present on the socket so we can + * remove the entire option struct */ + *opt_ptr = NULL; + hdr_delta = opt->opt.optlen; + kfree_rcu(opt, rcu); + } + + return hdr_delta; +} + +/** + * cipso_v4_sock_delattr - Delete the CIPSO option from a socket + * @sk: the socket + * + * Description: + * Removes the CIPSO option from a socket, if present. + * + */ +void cipso_v4_sock_delattr(struct sock *sk) +{ + struct inet_sock *sk_inet; + int hdr_delta; + + sk_inet = inet_sk(sk); + + hdr_delta = cipso_v4_delopt(&sk_inet->inet_opt); + if (sk_inet->is_icsk && hdr_delta > 0) { + struct inet_connection_sock *sk_conn = inet_csk(sk); + sk_conn->icsk_ext_hdr_len -= hdr_delta; + sk_conn->icsk_sync_mss(sk, sk_conn->icsk_pmtu_cookie); + } +} + +/** + * cipso_v4_req_delattr - Delete the CIPSO option from a request socket + * @req: the request socket + * + * Description: + * Removes the CIPSO option from a request socket, if present. + * + */ +void cipso_v4_req_delattr(struct request_sock *req) +{ + cipso_v4_delopt(&inet_rsk(req)->ireq_opt); +} + +/** + * cipso_v4_getattr - Helper function for the cipso_v4_*_getattr functions + * @cipso: the CIPSO v4 option + * @secattr: the security attributes + * + * Description: + * Inspect @cipso and return the security attributes in @secattr. Returns zero + * on success and negative values on failure. + * + */ +int cipso_v4_getattr(const unsigned char *cipso, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + u32 doi; + struct cipso_v4_doi *doi_def; + + if (cipso_v4_cache_check(cipso, cipso[1], secattr) == 0) + return 0; + + doi = get_unaligned_be32(&cipso[2]); + rcu_read_lock(); + doi_def = cipso_v4_doi_search(doi); + if (!doi_def) + goto getattr_return; + /* XXX - This code assumes only one tag per CIPSO option which isn't + * really a good assumption to make but since we only support the MAC + * tags right now it is a safe assumption. 
*/ + switch (cipso[6]) { + case CIPSO_V4_TAG_RBITMAP: + ret_val = cipso_v4_parsetag_rbm(doi_def, &cipso[6], secattr); + break; + case CIPSO_V4_TAG_ENUM: + ret_val = cipso_v4_parsetag_enum(doi_def, &cipso[6], secattr); + break; + case CIPSO_V4_TAG_RANGE: + ret_val = cipso_v4_parsetag_rng(doi_def, &cipso[6], secattr); + break; + case CIPSO_V4_TAG_LOCAL: + ret_val = cipso_v4_parsetag_loc(doi_def, &cipso[6], secattr); + break; + } + if (ret_val == 0) + secattr->type = NETLBL_NLTYPE_CIPSOV4; + +getattr_return: + rcu_read_unlock(); + return ret_val; +} + +/** + * cipso_v4_sock_getattr - Get the security attributes from a sock + * @sk: the sock + * @secattr: the security attributes + * + * Description: + * Query @sk to see if there is a CIPSO option attached to the sock and if + * there is return the CIPSO security attributes in @secattr. This function + * requires that @sk be locked, or privately held, but it does not do any + * locking itself. Returns zero on success and negative values on failure. + * + */ +int cipso_v4_sock_getattr(struct sock *sk, struct netlbl_lsm_secattr *secattr) +{ + struct ip_options_rcu *opt; + int res = -ENOMSG; + + rcu_read_lock(); + opt = rcu_dereference(inet_sk(sk)->inet_opt); + if (opt && opt->opt.cipso) + res = cipso_v4_getattr(opt->opt.__data + + opt->opt.cipso - + sizeof(struct iphdr), + secattr); + rcu_read_unlock(); + return res; +} + +/** + * cipso_v4_skbuff_setattr - Set the CIPSO option on a packet + * @skb: the packet + * @doi_def: the DOI structure + * @secattr: the security attributes + * + * Description: + * Set the CIPSO option on the given packet based on the security attributes. + * Returns a pointer to the IP header on success and NULL on failure. + * + */ +int cipso_v4_skbuff_setattr(struct sk_buff *skb, + const struct cipso_v4_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct iphdr *iph; + struct ip_options *opt = &IPCB(skb)->opt; + unsigned char buf[CIPSO_V4_OPT_LEN_MAX]; + u32 buf_len = CIPSO_V4_OPT_LEN_MAX; + u32 opt_len; + int len_delta; + + ret_val = cipso_v4_genopt(buf, buf_len, doi_def, secattr); + if (ret_val < 0) + return ret_val; + buf_len = ret_val; + opt_len = (buf_len + 3) & ~3; + + /* we overwrite any existing options to ensure that we have enough + * room for the CIPSO option, the reason is that we _need_ to guarantee + * that the security label is applied to the packet - we do the same + * thing when using the socket options and it hasn't caused a problem, + * if we need to we can always revisit this choice later */ + + len_delta = opt_len - opt->optlen; + /* if we don't ensure enough headroom we could panic on the skb_push() + * call below so make sure we have enough, we are also "mangling" the + * packet so we should probably do a copy-on-write call anyway */ + ret_val = skb_cow(skb, skb_headroom(skb) + len_delta); + if (ret_val < 0) + return ret_val; + + if (len_delta > 0) { + /* we assume that the header + opt->optlen have already been + * "pushed" in ip_options_build() or similar */ + iph = ip_hdr(skb); + skb_push(skb, len_delta); + memmove((char *)iph - len_delta, iph, iph->ihl << 2); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + } else if (len_delta < 0) { + iph = ip_hdr(skb); + memset(iph + 1, IPOPT_NOP, opt->optlen); + } else + iph = ip_hdr(skb); + + if (opt->optlen > 0) + memset(opt, 0, sizeof(*opt)); + opt->optlen = opt_len; + opt->cipso = sizeof(struct iphdr); + opt->is_changed = 1; + + /* we have to do the following because we are being called from a + * netfilter hook 
which means the packet already has had the header + * fields populated and the checksum calculated - yes this means we + * are doing more work than needed but we do it to keep the core + * stack clean and tidy */ + memcpy(iph + 1, buf, buf_len); + if (opt_len > buf_len) + memset((char *)(iph + 1) + buf_len, 0, opt_len - buf_len); + if (len_delta != 0) { + iph->ihl = 5 + (opt_len >> 2); + iph->tot_len = htons(skb->len); + } + ip_send_check(iph); + + return 0; +} + +/** + * cipso_v4_skbuff_delattr - Delete any CIPSO options from a packet + * @skb: the packet + * + * Description: + * Removes any and all CIPSO options from the given packet. Returns zero on + * success, negative values on failure. + * + */ +int cipso_v4_skbuff_delattr(struct sk_buff *skb) +{ + int ret_val; + struct iphdr *iph; + struct ip_options *opt = &IPCB(skb)->opt; + unsigned char *cipso_ptr; + + if (opt->cipso == 0) + return 0; + + /* since we are changing the packet we should make a copy */ + ret_val = skb_cow(skb, skb_headroom(skb)); + if (ret_val < 0) + return ret_val; + + /* the easiest thing to do is just replace the cipso option with noop + * options since we don't change the size of the packet, although we + * still need to recalculate the checksum */ + + iph = ip_hdr(skb); + cipso_ptr = (unsigned char *)iph + opt->cipso; + memset(cipso_ptr, IPOPT_NOOP, cipso_ptr[1]); + opt->cipso = 0; + opt->is_changed = 1; + + ip_send_check(iph); + + return 0; +} + +/* + * Setup Functions + */ + +/** + * cipso_v4_init - Initialize the CIPSO module + * + * Description: + * Initialize the CIPSO module and prepare it for use. Returns zero on success + * and negative values on failure. + * + */ +static int __init cipso_v4_init(void) +{ + int ret_val; + + ret_val = cipso_v4_cache_init(); + if (ret_val != 0) + panic("Failed to initialize the CIPSO/IPv4 cache (%d)\n", + ret_val); + + return 0; +} + +subsys_initcall(cipso_v4_init); diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c new file mode 100644 index 000000000..cb5dbee9e --- /dev/null +++ b/net/ipv4/datagram.c @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * common UDP/RAW code + * Linux INET implementation + * + * Authors: + * Hideaki YOSHIFUJI + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + struct sockaddr_in *usin = (struct sockaddr_in *) uaddr; + struct flowi4 *fl4; + struct rtable *rt; + __be32 saddr; + int oif; + int err; + + + if (addr_len < sizeof(*usin)) + return -EINVAL; + + if (usin->sin_family != AF_INET) + return -EAFNOSUPPORT; + + sk_dst_reset(sk); + + oif = sk->sk_bound_dev_if; + saddr = inet->inet_saddr; + if (ipv4_is_multicast(usin->sin_addr.s_addr)) { + if (!oif || netif_index_is_l3_master(sock_net(sk), oif)) + oif = inet->mc_index; + if (!saddr) + saddr = inet->mc_addr; + } else if (!oif) { + oif = inet->uc_index; + } + fl4 = &inet->cork.fl.u.ip4; + rt = ip_route_connect(fl4, usin->sin_addr.s_addr, saddr, oif, + sk->sk_protocol, inet->inet_sport, + usin->sin_port, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + if (err == -ENETUNREACH) + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES); + goto out; + } + + if ((rt->rt_flags & RTCF_BROADCAST) && !sock_flag(sk, SOCK_BROADCAST)) { + ip_rt_put(rt); + err = -EACCES; + goto out; + } + if (!inet->inet_saddr) + inet->inet_saddr = fl4->saddr; /* Update source address */ + if (!inet->inet_rcv_saddr) { + inet->inet_rcv_saddr = 
fl4->saddr; + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); + } + inet->inet_daddr = fl4->daddr; + inet->inet_dport = usin->sin_port; + reuseport_has_conns_set(sk); + sk->sk_state = TCP_ESTABLISHED; + sk_set_txhash(sk); + atomic_set(&inet->inet_id, get_random_u16()); + + sk_dst_set(sk, &rt->dst); + err = 0; +out: + return err; +} +EXPORT_SYMBOL(__ip4_datagram_connect); + +int ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + int res; + + lock_sock(sk); + res = __ip4_datagram_connect(sk, uaddr, addr_len); + release_sock(sk); + return res; +} +EXPORT_SYMBOL(ip4_datagram_connect); + +/* Because UDP xmit path can manipulate sk_dst_cache without holding + * socket lock, we need to use sk_dst_set() here, + * even if we own the socket lock. + */ +void ip4_datagram_release_cb(struct sock *sk) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct ip_options_rcu *inet_opt; + __be32 daddr = inet->inet_daddr; + struct dst_entry *dst; + struct flowi4 fl4; + struct rtable *rt; + + rcu_read_lock(); + + dst = __sk_dst_get(sk); + if (!dst || !dst->obsolete || dst->ops->check(dst, 0)) { + rcu_read_unlock(); + return; + } + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + rt = ip_route_output_ports(sock_net(sk), &fl4, sk, daddr, + inet->inet_saddr, inet->inet_dport, + inet->inet_sport, sk->sk_protocol, + RT_CONN_FLAGS(sk), sk->sk_bound_dev_if); + + dst = !IS_ERR(rt) ? &rt->dst : NULL; + sk_dst_set(sk, dst); + + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(ip4_datagram_release_cb); diff --git a/net/ipv4/devinet.c b/net/ipv4/devinet.c new file mode 100644 index 000000000..35d6e74be --- /dev/null +++ b/net/ipv4/devinet.c @@ -0,0 +1,2794 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3 IP device support routines. + * + * Derived from the IP parts of dev.c 1.0.19 + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * + * Additional Authors: + * Alan Cox, + * Alexey Kuznetsov, + * + * Changes: + * Alexey Kuznetsov: pa_* fields are replaced with ifaddr + * lists. + * Cyrus Durgin: updated for kmod + * Matthias Andree: in devinet_ioctl, compare label and + * address (4.4BSD alias style support), + * fall back to comparing just the label + * if no match found. 
+ */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define IPV6ONLY_FLAGS \ + (IFA_F_NODAD | IFA_F_OPTIMISTIC | IFA_F_DADFAILED | \ + IFA_F_HOMEADDRESS | IFA_F_TENTATIVE | \ + IFA_F_MANAGETEMPADDR | IFA_F_STABLE_PRIVACY) + +static struct ipv4_devconf ipv4_devconf = { + .data = { + [IPV4_DEVCONF_ACCEPT_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SEND_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SECURE_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SHARED_MEDIA - 1] = 1, + [IPV4_DEVCONF_IGMPV2_UNSOLICITED_REPORT_INTERVAL - 1] = 10000 /*ms*/, + [IPV4_DEVCONF_IGMPV3_UNSOLICITED_REPORT_INTERVAL - 1] = 1000 /*ms*/, + [IPV4_DEVCONF_ARP_EVICT_NOCARRIER - 1] = 1, + }, +}; + +static struct ipv4_devconf ipv4_devconf_dflt = { + .data = { + [IPV4_DEVCONF_ACCEPT_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SEND_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SECURE_REDIRECTS - 1] = 1, + [IPV4_DEVCONF_SHARED_MEDIA - 1] = 1, + [IPV4_DEVCONF_ACCEPT_SOURCE_ROUTE - 1] = 1, + [IPV4_DEVCONF_IGMPV2_UNSOLICITED_REPORT_INTERVAL - 1] = 10000 /*ms*/, + [IPV4_DEVCONF_IGMPV3_UNSOLICITED_REPORT_INTERVAL - 1] = 1000 /*ms*/, + [IPV4_DEVCONF_ARP_EVICT_NOCARRIER - 1] = 1, + }, +}; + +#define IPV4_DEVCONF_DFLT(net, attr) \ + IPV4_DEVCONF((*net->ipv4.devconf_dflt), attr) + +static const struct nla_policy ifa_ipv4_policy[IFA_MAX+1] = { + [IFA_LOCAL] = { .type = NLA_U32 }, + [IFA_ADDRESS] = { .type = NLA_U32 }, + [IFA_BROADCAST] = { .type = NLA_U32 }, + [IFA_LABEL] = { .type = NLA_STRING, .len = IFNAMSIZ - 1 }, + [IFA_CACHEINFO] = { .len = sizeof(struct ifa_cacheinfo) }, + [IFA_FLAGS] = { .type = NLA_U32 }, + [IFA_RT_PRIORITY] = { .type = NLA_U32 }, + [IFA_TARGET_NETNSID] = { .type = NLA_S32 }, + [IFA_PROTO] = { .type = NLA_U8 }, +}; + +struct inet_fill_args { + u32 portid; + u32 seq; + int event; + unsigned int flags; + int netnsid; + int ifindex; +}; + +#define IN4_ADDR_HSIZE_SHIFT 8 +#define IN4_ADDR_HSIZE (1U << IN4_ADDR_HSIZE_SHIFT) + +static struct hlist_head inet_addr_lst[IN4_ADDR_HSIZE]; + +static u32 inet_addr_hash(const struct net *net, __be32 addr) +{ + u32 val = (__force u32) addr ^ net_hash_mix(net); + + return hash_32(val, IN4_ADDR_HSIZE_SHIFT); +} + +static void inet_hash_insert(struct net *net, struct in_ifaddr *ifa) +{ + u32 hash = inet_addr_hash(net, ifa->ifa_local); + + ASSERT_RTNL(); + hlist_add_head_rcu(&ifa->hash, &inet_addr_lst[hash]); +} + +static void inet_hash_remove(struct in_ifaddr *ifa) +{ + ASSERT_RTNL(); + hlist_del_init_rcu(&ifa->hash); +} + +/** + * __ip_dev_find - find the first device with a given source address. + * @net: the net namespace + * @addr: the source address + * @devref: if true, take a reference on the found device + * + * If a caller uses devref=false, it should be protected by RCU, or RTNL + */ +struct net_device *__ip_dev_find(struct net *net, __be32 addr, bool devref) +{ + struct net_device *result = NULL; + struct in_ifaddr *ifa; + + rcu_read_lock(); + ifa = inet_lookup_ifaddr_rcu(net, addr); + if (!ifa) { + struct flowi4 fl4 = { .daddr = addr }; + struct fib_result res = { 0 }; + struct fib_table *local; + + /* Fallback to FIB local table so that communication + * over loopback subnets work. 
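As an aside on the hashing used by inet_addr_hash() above, a simplified standalone stand-in, assuming the usual 32-bit multiplicative hash_32() and treating the per-namespace mixing value as an opaque salt (demo_* names are illustrative only):

#define DEMO_IN4_ADDR_HSIZE_SHIFT 8	/* 256 buckets, as above */

static unsigned int demo_inet_addr_bucket(unsigned int addr, unsigned int salt)
{
	unsigned int val = addr ^ salt;

	/* multiply by the 32-bit golden ratio constant and keep the top
	 * DEMO_IN4_ADDR_HSIZE_SHIFT bits */
	return (val * 0x61C88647u) >> (32 - DEMO_IN4_ADDR_HSIZE_SHIFT);
}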
+ */ + local = fib_get_table(net, RT_TABLE_LOCAL); + if (local && + !fib_table_lookup(local, &fl4, &res, FIB_LOOKUP_NOREF) && + res.type == RTN_LOCAL) + result = FIB_RES_DEV(res); + } else { + result = ifa->ifa_dev->dev; + } + if (result && devref) + dev_hold(result); + rcu_read_unlock(); + return result; +} +EXPORT_SYMBOL(__ip_dev_find); + +/* called under RCU lock */ +struct in_ifaddr *inet_lookup_ifaddr_rcu(struct net *net, __be32 addr) +{ + u32 hash = inet_addr_hash(net, addr); + struct in_ifaddr *ifa; + + hlist_for_each_entry_rcu(ifa, &inet_addr_lst[hash], hash) + if (ifa->ifa_local == addr && + net_eq(dev_net(ifa->ifa_dev->dev), net)) + return ifa; + + return NULL; +} + +static void rtmsg_ifa(int event, struct in_ifaddr *, struct nlmsghdr *, u32); + +static BLOCKING_NOTIFIER_HEAD(inetaddr_chain); +static BLOCKING_NOTIFIER_HEAD(inetaddr_validator_chain); +static void inet_del_ifa(struct in_device *in_dev, + struct in_ifaddr __rcu **ifap, + int destroy); +#ifdef CONFIG_SYSCTL +static int devinet_sysctl_register(struct in_device *idev); +static void devinet_sysctl_unregister(struct in_device *idev); +#else +static int devinet_sysctl_register(struct in_device *idev) +{ + return 0; +} +static void devinet_sysctl_unregister(struct in_device *idev) +{ +} +#endif + +/* Locks all the inet devices. */ + +static struct in_ifaddr *inet_alloc_ifa(void) +{ + return kzalloc(sizeof(struct in_ifaddr), GFP_KERNEL_ACCOUNT); +} + +static void inet_rcu_free_ifa(struct rcu_head *head) +{ + struct in_ifaddr *ifa = container_of(head, struct in_ifaddr, rcu_head); + if (ifa->ifa_dev) + in_dev_put(ifa->ifa_dev); + kfree(ifa); +} + +static void inet_free_ifa(struct in_ifaddr *ifa) +{ + call_rcu(&ifa->rcu_head, inet_rcu_free_ifa); +} + +void in_dev_finish_destroy(struct in_device *idev) +{ + struct net_device *dev = idev->dev; + + WARN_ON(idev->ifa_list); + WARN_ON(idev->mc_list); + kfree(rcu_dereference_protected(idev->mc_hash, 1)); +#ifdef NET_REFCNT_DEBUG + pr_debug("%s: %p=%s\n", __func__, idev, dev ? 
dev->name : "NIL"); +#endif + netdev_put(dev, &idev->dev_tracker); + if (!idev->dead) + pr_err("Freeing alive in_device %p\n", idev); + else + kfree(idev); +} +EXPORT_SYMBOL(in_dev_finish_destroy); + +static struct in_device *inetdev_init(struct net_device *dev) +{ + struct in_device *in_dev; + int err = -ENOMEM; + + ASSERT_RTNL(); + + in_dev = kzalloc(sizeof(*in_dev), GFP_KERNEL); + if (!in_dev) + goto out; + memcpy(&in_dev->cnf, dev_net(dev)->ipv4.devconf_dflt, + sizeof(in_dev->cnf)); + in_dev->cnf.sysctl = NULL; + in_dev->dev = dev; + in_dev->arp_parms = neigh_parms_alloc(dev, &arp_tbl); + if (!in_dev->arp_parms) + goto out_kfree; + if (IPV4_DEVCONF(in_dev->cnf, FORWARDING)) + dev_disable_lro(dev); + /* Reference in_dev->dev */ + netdev_hold(dev, &in_dev->dev_tracker, GFP_KERNEL); + /* Account for reference dev->ip_ptr (below) */ + refcount_set(&in_dev->refcnt, 1); + + err = devinet_sysctl_register(in_dev); + if (err) { + in_dev->dead = 1; + neigh_parms_release(&arp_tbl, in_dev->arp_parms); + in_dev_put(in_dev); + in_dev = NULL; + goto out; + } + ip_mc_init_dev(in_dev); + if (dev->flags & IFF_UP) + ip_mc_up(in_dev); + + /* we can receive as soon as ip_ptr is set -- do this last */ + rcu_assign_pointer(dev->ip_ptr, in_dev); +out: + return in_dev ?: ERR_PTR(err); +out_kfree: + kfree(in_dev); + in_dev = NULL; + goto out; +} + +static void in_dev_rcu_put(struct rcu_head *head) +{ + struct in_device *idev = container_of(head, struct in_device, rcu_head); + in_dev_put(idev); +} + +static void inetdev_destroy(struct in_device *in_dev) +{ + struct net_device *dev; + struct in_ifaddr *ifa; + + ASSERT_RTNL(); + + dev = in_dev->dev; + + in_dev->dead = 1; + + ip_mc_destroy_dev(in_dev); + + while ((ifa = rtnl_dereference(in_dev->ifa_list)) != NULL) { + inet_del_ifa(in_dev, &in_dev->ifa_list, 0); + inet_free_ifa(ifa); + } + + RCU_INIT_POINTER(dev->ip_ptr, NULL); + + devinet_sysctl_unregister(in_dev); + neigh_parms_release(&arp_tbl, in_dev->arp_parms); + arp_ifdown(dev); + + call_rcu(&in_dev->rcu_head, in_dev_rcu_put); +} + +int inet_addr_onlink(struct in_device *in_dev, __be32 a, __be32 b) +{ + const struct in_ifaddr *ifa; + + rcu_read_lock(); + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (inet_ifa_match(a, ifa)) { + if (!b || inet_ifa_match(b, ifa)) { + rcu_read_unlock(); + return 1; + } + } + } + rcu_read_unlock(); + return 0; +} + +static void __inet_del_ifa(struct in_device *in_dev, + struct in_ifaddr __rcu **ifap, + int destroy, struct nlmsghdr *nlh, u32 portid) +{ + struct in_ifaddr *promote = NULL; + struct in_ifaddr *ifa, *ifa1; + struct in_ifaddr __rcu **last_prim; + struct in_ifaddr *prev_prom = NULL; + int do_promote = IN_DEV_PROMOTE_SECONDARIES(in_dev); + + ASSERT_RTNL(); + + ifa1 = rtnl_dereference(*ifap); + last_prim = ifap; + if (in_dev->dead) + goto no_promotions; + + /* 1. 
Deleting primary ifaddr forces deletion all secondaries + * unless alias promotion is set + **/ + + if (!(ifa1->ifa_flags & IFA_F_SECONDARY)) { + struct in_ifaddr __rcu **ifap1 = &ifa1->ifa_next; + + while ((ifa = rtnl_dereference(*ifap1)) != NULL) { + if (!(ifa->ifa_flags & IFA_F_SECONDARY) && + ifa1->ifa_scope <= ifa->ifa_scope) + last_prim = &ifa->ifa_next; + + if (!(ifa->ifa_flags & IFA_F_SECONDARY) || + ifa1->ifa_mask != ifa->ifa_mask || + !inet_ifa_match(ifa1->ifa_address, ifa)) { + ifap1 = &ifa->ifa_next; + prev_prom = ifa; + continue; + } + + if (!do_promote) { + inet_hash_remove(ifa); + *ifap1 = ifa->ifa_next; + + rtmsg_ifa(RTM_DELADDR, ifa, nlh, portid); + blocking_notifier_call_chain(&inetaddr_chain, + NETDEV_DOWN, ifa); + inet_free_ifa(ifa); + } else { + promote = ifa; + break; + } + } + } + + /* On promotion all secondaries from subnet are changing + * the primary IP, we must remove all their routes silently + * and later to add them back with new prefsrc. Do this + * while all addresses are on the device list. + */ + for (ifa = promote; ifa; ifa = rtnl_dereference(ifa->ifa_next)) { + if (ifa1->ifa_mask == ifa->ifa_mask && + inet_ifa_match(ifa1->ifa_address, ifa)) + fib_del_ifaddr(ifa, ifa1); + } + +no_promotions: + /* 2. Unlink it */ + + *ifap = ifa1->ifa_next; + inet_hash_remove(ifa1); + + /* 3. Announce address deletion */ + + /* Send message first, then call notifier. + At first sight, FIB update triggered by notifier + will refer to already deleted ifaddr, that could confuse + netlink listeners. It is not true: look, gated sees + that route deleted and if it still thinks that ifaddr + is valid, it will try to restore deleted routes... Grr. + So that, this order is correct. + */ + rtmsg_ifa(RTM_DELADDR, ifa1, nlh, portid); + blocking_notifier_call_chain(&inetaddr_chain, NETDEV_DOWN, ifa1); + + if (promote) { + struct in_ifaddr *next_sec; + + next_sec = rtnl_dereference(promote->ifa_next); + if (prev_prom) { + struct in_ifaddr *last_sec; + + rcu_assign_pointer(prev_prom->ifa_next, next_sec); + + last_sec = rtnl_dereference(*last_prim); + rcu_assign_pointer(promote->ifa_next, last_sec); + rcu_assign_pointer(*last_prim, promote); + } + + promote->ifa_flags &= ~IFA_F_SECONDARY; + rtmsg_ifa(RTM_NEWADDR, promote, nlh, portid); + blocking_notifier_call_chain(&inetaddr_chain, + NETDEV_UP, promote); + for (ifa = next_sec; ifa; + ifa = rtnl_dereference(ifa->ifa_next)) { + if (ifa1->ifa_mask != ifa->ifa_mask || + !inet_ifa_match(ifa1->ifa_address, ifa)) + continue; + fib_add_ifaddr(ifa); + } + + } + if (destroy) + inet_free_ifa(ifa1); +} + +static void inet_del_ifa(struct in_device *in_dev, + struct in_ifaddr __rcu **ifap, + int destroy) +{ + __inet_del_ifa(in_dev, ifap, destroy, NULL, 0); +} + +static void check_lifetime(struct work_struct *work); + +static DECLARE_DELAYED_WORK(check_lifetime_work, check_lifetime); + +static int __inet_insert_ifa(struct in_ifaddr *ifa, struct nlmsghdr *nlh, + u32 portid, struct netlink_ext_ack *extack) +{ + struct in_ifaddr __rcu **last_primary, **ifap; + struct in_device *in_dev = ifa->ifa_dev; + struct in_validator_info ivi; + struct in_ifaddr *ifa1; + int ret; + + ASSERT_RTNL(); + + if (!ifa->ifa_local) { + inet_free_ifa(ifa); + return 0; + } + + ifa->ifa_flags &= ~IFA_F_SECONDARY; + last_primary = &in_dev->ifa_list; + + /* Don't set IPv6 only flags to IPv4 addresses */ + ifa->ifa_flags &= ~IPV6ONLY_FLAGS; + + ifap = &in_dev->ifa_list; + ifa1 = rtnl_dereference(*ifap); + + while (ifa1) { + if (!(ifa1->ifa_flags & IFA_F_SECONDARY) && + 
ifa->ifa_scope <= ifa1->ifa_scope) + last_primary = &ifa1->ifa_next; + if (ifa1->ifa_mask == ifa->ifa_mask && + inet_ifa_match(ifa1->ifa_address, ifa)) { + if (ifa1->ifa_local == ifa->ifa_local) { + inet_free_ifa(ifa); + return -EEXIST; + } + if (ifa1->ifa_scope != ifa->ifa_scope) { + inet_free_ifa(ifa); + return -EINVAL; + } + ifa->ifa_flags |= IFA_F_SECONDARY; + } + + ifap = &ifa1->ifa_next; + ifa1 = rtnl_dereference(*ifap); + } + + /* Allow any devices that wish to register ifaddr validtors to weigh + * in now, before changes are committed. The rntl lock is serializing + * access here, so the state should not change between a validator call + * and a final notify on commit. This isn't invoked on promotion under + * the assumption that validators are checking the address itself, and + * not the flags. + */ + ivi.ivi_addr = ifa->ifa_address; + ivi.ivi_dev = ifa->ifa_dev; + ivi.extack = extack; + ret = blocking_notifier_call_chain(&inetaddr_validator_chain, + NETDEV_UP, &ivi); + ret = notifier_to_errno(ret); + if (ret) { + inet_free_ifa(ifa); + return ret; + } + + if (!(ifa->ifa_flags & IFA_F_SECONDARY)) + ifap = last_primary; + + rcu_assign_pointer(ifa->ifa_next, *ifap); + rcu_assign_pointer(*ifap, ifa); + + inet_hash_insert(dev_net(in_dev->dev), ifa); + + cancel_delayed_work(&check_lifetime_work); + queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, 0); + + /* Send message first, then call notifier. + Notifier will trigger FIB update, so that + listeners of netlink will know about new ifaddr */ + rtmsg_ifa(RTM_NEWADDR, ifa, nlh, portid); + blocking_notifier_call_chain(&inetaddr_chain, NETDEV_UP, ifa); + + return 0; +} + +static int inet_insert_ifa(struct in_ifaddr *ifa) +{ + return __inet_insert_ifa(ifa, NULL, 0, NULL); +} + +static int inet_set_ifa(struct net_device *dev, struct in_ifaddr *ifa) +{ + struct in_device *in_dev = __in_dev_get_rtnl(dev); + + ASSERT_RTNL(); + + if (!in_dev) { + inet_free_ifa(ifa); + return -ENOBUFS; + } + ipv4_devconf_setall(in_dev); + neigh_parms_data_state_setall(in_dev->arp_parms); + if (ifa->ifa_dev != in_dev) { + WARN_ON(ifa->ifa_dev); + in_dev_hold(in_dev); + ifa->ifa_dev = in_dev; + } + if (ipv4_is_loopback(ifa->ifa_local)) + ifa->ifa_scope = RT_SCOPE_HOST; + return inet_insert_ifa(ifa); +} + +/* Caller must hold RCU or RTNL : + * We dont take a reference on found in_device + */ +struct in_device *inetdev_by_index(struct net *net, int ifindex) +{ + struct net_device *dev; + struct in_device *in_dev = NULL; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, ifindex); + if (dev) + in_dev = rcu_dereference_rtnl(dev->ip_ptr); + rcu_read_unlock(); + return in_dev; +} +EXPORT_SYMBOL(inetdev_by_index); + +/* Called only from RTNL semaphored context. No locks. 
*/ + +struct in_ifaddr *inet_ifa_byprefix(struct in_device *in_dev, __be32 prefix, + __be32 mask) +{ + struct in_ifaddr *ifa; + + ASSERT_RTNL(); + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + if (ifa->ifa_mask == mask && inet_ifa_match(prefix, ifa)) + return ifa; + } + return NULL; +} + +static int ip_mc_autojoin_config(struct net *net, bool join, + const struct in_ifaddr *ifa) +{ +#if defined(CONFIG_IP_MULTICAST) + struct ip_mreqn mreq = { + .imr_multiaddr.s_addr = ifa->ifa_address, + .imr_ifindex = ifa->ifa_dev->dev->ifindex, + }; + struct sock *sk = net->ipv4.mc_autojoin_sk; + int ret; + + ASSERT_RTNL(); + + lock_sock(sk); + if (join) + ret = ip_mc_join_group(sk, &mreq); + else + ret = ip_mc_leave_group(sk, &mreq); + release_sock(sk); + + return ret; +#else + return -EOPNOTSUPP; +#endif +} + +static int inet_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct in_ifaddr __rcu **ifap; + struct nlattr *tb[IFA_MAX+1]; + struct in_device *in_dev; + struct ifaddrmsg *ifm; + struct in_ifaddr *ifa; + int err; + + ASSERT_RTNL(); + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv4_policy, extack); + if (err < 0) + goto errout; + + ifm = nlmsg_data(nlh); + in_dev = inetdev_by_index(net, ifm->ifa_index); + if (!in_dev) { + err = -ENODEV; + goto errout; + } + + for (ifap = &in_dev->ifa_list; (ifa = rtnl_dereference(*ifap)) != NULL; + ifap = &ifa->ifa_next) { + if (tb[IFA_LOCAL] && + ifa->ifa_local != nla_get_in_addr(tb[IFA_LOCAL])) + continue; + + if (tb[IFA_LABEL] && nla_strcmp(tb[IFA_LABEL], ifa->ifa_label)) + continue; + + if (tb[IFA_ADDRESS] && + (ifm->ifa_prefixlen != ifa->ifa_prefixlen || + !inet_ifa_match(nla_get_in_addr(tb[IFA_ADDRESS]), ifa))) + continue; + + if (ipv4_is_multicast(ifa->ifa_address)) + ip_mc_autojoin_config(net, false, ifa); + __inet_del_ifa(in_dev, ifap, 1, nlh, NETLINK_CB(skb).portid); + return 0; + } + + err = -EADDRNOTAVAIL; +errout: + return err; +} + +#define INFINITY_LIFE_TIME 0xFFFFFFFF + +static void check_lifetime(struct work_struct *work) +{ + unsigned long now, next, next_sec, next_sched; + struct in_ifaddr *ifa; + struct hlist_node *n; + int i; + + now = jiffies; + next = round_jiffies_up(now + ADDR_CHECK_FREQUENCY); + + for (i = 0; i < IN4_ADDR_HSIZE; i++) { + bool change_needed = false; + + rcu_read_lock(); + hlist_for_each_entry_rcu(ifa, &inet_addr_lst[i], hash) { + unsigned long age; + + if (ifa->ifa_flags & IFA_F_PERMANENT) + continue; + + /* We try to batch several events at once. */ + age = (now - ifa->ifa_tstamp + + ADDRCONF_TIMER_FUZZ_MINUS) / HZ; + + if (ifa->ifa_valid_lft != INFINITY_LIFE_TIME && + age >= ifa->ifa_valid_lft) { + change_needed = true; + } else if (ifa->ifa_preferred_lft == + INFINITY_LIFE_TIME) { + continue; + } else if (age >= ifa->ifa_preferred_lft) { + if (time_before(ifa->ifa_tstamp + + ifa->ifa_valid_lft * HZ, next)) + next = ifa->ifa_tstamp + + ifa->ifa_valid_lft * HZ; + + if (!(ifa->ifa_flags & IFA_F_DEPRECATED)) + change_needed = true; + } else if (time_before(ifa->ifa_tstamp + + ifa->ifa_preferred_lft * HZ, + next)) { + next = ifa->ifa_tstamp + + ifa->ifa_preferred_lft * HZ; + } + } + rcu_read_unlock(); + if (!change_needed) + continue; + rtnl_lock(); + hlist_for_each_entry_safe(ifa, n, &inet_addr_lst[i], hash) { + unsigned long age; + + if (ifa->ifa_flags & IFA_F_PERMANENT) + continue; + + /* We try to batch several events at once. 
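To summarise the aging policy that check_lifetime() applies in the passes above and below, a small standalone sketch, with age taken as roughly the whole seconds elapsed since ifa_tstamp and 0xFFFFFFFF standing in for INFINITY_LIFE_TIME (demo_* names are illustrative only):

enum demo_ifa_state { DEMO_IFA_OK, DEMO_IFA_DEPRECATED, DEMO_IFA_EXPIRED };

static enum demo_ifa_state demo_lifetime_state(unsigned long age,
					       unsigned long preferred_lft,
					       unsigned long valid_lft)
{
	if (valid_lft != 0xFFFFFFFFul && age >= valid_lft)
		return DEMO_IFA_EXPIRED;	/* address gets deleted */
	if (preferred_lft != 0xFFFFFFFFul && age >= preferred_lft)
		return DEMO_IFA_DEPRECATED;	/* IFA_F_DEPRECATED is set */
	return DEMO_IFA_OK;
}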
*/ + age = (now - ifa->ifa_tstamp + + ADDRCONF_TIMER_FUZZ_MINUS) / HZ; + + if (ifa->ifa_valid_lft != INFINITY_LIFE_TIME && + age >= ifa->ifa_valid_lft) { + struct in_ifaddr __rcu **ifap; + struct in_ifaddr *tmp; + + ifap = &ifa->ifa_dev->ifa_list; + tmp = rtnl_dereference(*ifap); + while (tmp) { + if (tmp == ifa) { + inet_del_ifa(ifa->ifa_dev, + ifap, 1); + break; + } + ifap = &tmp->ifa_next; + tmp = rtnl_dereference(*ifap); + } + } else if (ifa->ifa_preferred_lft != + INFINITY_LIFE_TIME && + age >= ifa->ifa_preferred_lft && + !(ifa->ifa_flags & IFA_F_DEPRECATED)) { + ifa->ifa_flags |= IFA_F_DEPRECATED; + rtmsg_ifa(RTM_NEWADDR, ifa, NULL, 0); + } + } + rtnl_unlock(); + } + + next_sec = round_jiffies_up(next); + next_sched = next; + + /* If rounded timeout is accurate enough, accept it. */ + if (time_before(next_sec, next + ADDRCONF_TIMER_FUZZ)) + next_sched = next_sec; + + now = jiffies; + /* And minimum interval is ADDRCONF_TIMER_FUZZ_MAX. */ + if (time_before(next_sched, now + ADDRCONF_TIMER_FUZZ_MAX)) + next_sched = now + ADDRCONF_TIMER_FUZZ_MAX; + + queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, + next_sched - now); +} + +static void set_ifa_lifetime(struct in_ifaddr *ifa, __u32 valid_lft, + __u32 prefered_lft) +{ + unsigned long timeout; + + ifa->ifa_flags &= ~(IFA_F_PERMANENT | IFA_F_DEPRECATED); + + timeout = addrconf_timeout_fixup(valid_lft, HZ); + if (addrconf_finite_timeout(timeout)) + ifa->ifa_valid_lft = timeout; + else + ifa->ifa_flags |= IFA_F_PERMANENT; + + timeout = addrconf_timeout_fixup(prefered_lft, HZ); + if (addrconf_finite_timeout(timeout)) { + if (timeout == 0) + ifa->ifa_flags |= IFA_F_DEPRECATED; + ifa->ifa_preferred_lft = timeout; + } + ifa->ifa_tstamp = jiffies; + if (!ifa->ifa_cstamp) + ifa->ifa_cstamp = ifa->ifa_tstamp; +} + +static struct in_ifaddr *rtm_to_ifaddr(struct net *net, struct nlmsghdr *nlh, + __u32 *pvalid_lft, __u32 *pprefered_lft, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFA_MAX+1]; + struct in_ifaddr *ifa; + struct ifaddrmsg *ifm; + struct net_device *dev; + struct in_device *in_dev; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv4_policy, extack); + if (err < 0) + goto errout; + + ifm = nlmsg_data(nlh); + err = -EINVAL; + if (ifm->ifa_prefixlen > 32 || !tb[IFA_LOCAL]) + goto errout; + + dev = __dev_get_by_index(net, ifm->ifa_index); + err = -ENODEV; + if (!dev) + goto errout; + + in_dev = __in_dev_get_rtnl(dev); + err = -ENOBUFS; + if (!in_dev) + goto errout; + + ifa = inet_alloc_ifa(); + if (!ifa) + /* + * A potential indev allocation can be left alive, it stays + * assigned to its device and is destroy with it. + */ + goto errout; + + ipv4_devconf_setall(in_dev); + neigh_parms_data_state_setall(in_dev->arp_parms); + in_dev_hold(in_dev); + + if (!tb[IFA_ADDRESS]) + tb[IFA_ADDRESS] = tb[IFA_LOCAL]; + + INIT_HLIST_NODE(&ifa->hash); + ifa->ifa_prefixlen = ifm->ifa_prefixlen; + ifa->ifa_mask = inet_make_mask(ifm->ifa_prefixlen); + ifa->ifa_flags = tb[IFA_FLAGS] ? 
nla_get_u32(tb[IFA_FLAGS]) : + ifm->ifa_flags; + ifa->ifa_scope = ifm->ifa_scope; + ifa->ifa_dev = in_dev; + + ifa->ifa_local = nla_get_in_addr(tb[IFA_LOCAL]); + ifa->ifa_address = nla_get_in_addr(tb[IFA_ADDRESS]); + + if (tb[IFA_BROADCAST]) + ifa->ifa_broadcast = nla_get_in_addr(tb[IFA_BROADCAST]); + + if (tb[IFA_LABEL]) + nla_strscpy(ifa->ifa_label, tb[IFA_LABEL], IFNAMSIZ); + else + memcpy(ifa->ifa_label, dev->name, IFNAMSIZ); + + if (tb[IFA_RT_PRIORITY]) + ifa->ifa_rt_priority = nla_get_u32(tb[IFA_RT_PRIORITY]); + + if (tb[IFA_PROTO]) + ifa->ifa_proto = nla_get_u8(tb[IFA_PROTO]); + + if (tb[IFA_CACHEINFO]) { + struct ifa_cacheinfo *ci; + + ci = nla_data(tb[IFA_CACHEINFO]); + if (!ci->ifa_valid || ci->ifa_prefered > ci->ifa_valid) { + err = -EINVAL; + goto errout_free; + } + *pvalid_lft = ci->ifa_valid; + *pprefered_lft = ci->ifa_prefered; + } + + return ifa; + +errout_free: + inet_free_ifa(ifa); +errout: + return ERR_PTR(err); +} + +static struct in_ifaddr *find_matching_ifa(struct in_ifaddr *ifa) +{ + struct in_device *in_dev = ifa->ifa_dev; + struct in_ifaddr *ifa1; + + if (!ifa->ifa_local) + return NULL; + + in_dev_for_each_ifa_rtnl(ifa1, in_dev) { + if (ifa1->ifa_mask == ifa->ifa_mask && + inet_ifa_match(ifa1->ifa_address, ifa) && + ifa1->ifa_local == ifa->ifa_local) + return ifa1; + } + return NULL; +} + +static int inet_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct in_ifaddr *ifa; + struct in_ifaddr *ifa_existing; + __u32 valid_lft = INFINITY_LIFE_TIME; + __u32 prefered_lft = INFINITY_LIFE_TIME; + + ASSERT_RTNL(); + + ifa = rtm_to_ifaddr(net, nlh, &valid_lft, &prefered_lft, extack); + if (IS_ERR(ifa)) + return PTR_ERR(ifa); + + ifa_existing = find_matching_ifa(ifa); + if (!ifa_existing) { + /* It would be best to check for !NLM_F_CREATE here but + * userspace already relies on not having to provide this. + */ + set_ifa_lifetime(ifa, valid_lft, prefered_lft); + if (ifa->ifa_flags & IFA_F_MCAUTOJOIN) { + int ret = ip_mc_autojoin_config(net, true, ifa); + + if (ret < 0) { + inet_free_ifa(ifa); + return ret; + } + } + return __inet_insert_ifa(ifa, nlh, NETLINK_CB(skb).portid, + extack); + } else { + u32 new_metric = ifa->ifa_rt_priority; + + inet_free_ifa(ifa); + + if (nlh->nlmsg_flags & NLM_F_EXCL || + !(nlh->nlmsg_flags & NLM_F_REPLACE)) + return -EEXIST; + ifa = ifa_existing; + + if (ifa->ifa_rt_priority != new_metric) { + fib_modify_prefix_metric(ifa, new_metric); + ifa->ifa_rt_priority = new_metric; + } + + set_ifa_lifetime(ifa, valid_lft, prefered_lft); + cancel_delayed_work(&check_lifetime_work); + queue_delayed_work(system_power_efficient_wq, + &check_lifetime_work, 0); + rtmsg_ifa(RTM_NEWADDR, ifa, nlh, NETLINK_CB(skb).portid); + } + return 0; +} + +/* + * Determine a default network mask, based on the IP address. + */ + +static int inet_abc_len(__be32 addr) +{ + int rc = -1; /* Something else, probably a multicast. 
*/ + + if (ipv4_is_zeronet(addr) || ipv4_is_lbcast(addr)) + rc = 0; + else { + __u32 haddr = ntohl(addr); + if (IN_CLASSA(haddr)) + rc = 8; + else if (IN_CLASSB(haddr)) + rc = 16; + else if (IN_CLASSC(haddr)) + rc = 24; + else if (IN_CLASSE(haddr)) + rc = 32; + } + + return rc; +} + + +int devinet_ioctl(struct net *net, unsigned int cmd, struct ifreq *ifr) +{ + struct sockaddr_in sin_orig; + struct sockaddr_in *sin = (struct sockaddr_in *)&ifr->ifr_addr; + struct in_ifaddr __rcu **ifap = NULL; + struct in_device *in_dev; + struct in_ifaddr *ifa = NULL; + struct net_device *dev; + char *colon; + int ret = -EFAULT; + int tryaddrmatch = 0; + + ifr->ifr_name[IFNAMSIZ - 1] = 0; + + /* save original address for comparison */ + memcpy(&sin_orig, sin, sizeof(*sin)); + + colon = strchr(ifr->ifr_name, ':'); + if (colon) + *colon = 0; + + dev_load(net, ifr->ifr_name); + + switch (cmd) { + case SIOCGIFADDR: /* Get interface address */ + case SIOCGIFBRDADDR: /* Get the broadcast address */ + case SIOCGIFDSTADDR: /* Get the destination address */ + case SIOCGIFNETMASK: /* Get the netmask for the interface */ + /* Note that these ioctls will not sleep, + so that we do not impose a lock. + One day we will be forced to put shlock here (I mean SMP) + */ + tryaddrmatch = (sin_orig.sin_family == AF_INET); + memset(sin, 0, sizeof(*sin)); + sin->sin_family = AF_INET; + break; + + case SIOCSIFFLAGS: + ret = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto out; + break; + case SIOCSIFADDR: /* Set interface address (and family) */ + case SIOCSIFBRDADDR: /* Set the broadcast address */ + case SIOCSIFDSTADDR: /* Set the destination address */ + case SIOCSIFNETMASK: /* Set the netmask for the interface */ + ret = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto out; + ret = -EINVAL; + if (sin->sin_family != AF_INET) + goto out; + break; + default: + ret = -EINVAL; + goto out; + } + + rtnl_lock(); + + ret = -ENODEV; + dev = __dev_get_by_name(net, ifr->ifr_name); + if (!dev) + goto done; + + if (colon) + *colon = ':'; + + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) { + if (tryaddrmatch) { + /* Matthias Andree */ + /* compare label and address (4.4BSD style) */ + /* note: we only do this for a limited set of ioctls + and only if the original address family was AF_INET. + This is checked above. 
*/ + + for (ifap = &in_dev->ifa_list; + (ifa = rtnl_dereference(*ifap)) != NULL; + ifap = &ifa->ifa_next) { + if (!strcmp(ifr->ifr_name, ifa->ifa_label) && + sin_orig.sin_addr.s_addr == + ifa->ifa_local) { + break; /* found */ + } + } + } + /* we didn't get a match, maybe the application is + 4.3BSD-style and passed in junk so we fall back to + comparing just the label */ + if (!ifa) { + for (ifap = &in_dev->ifa_list; + (ifa = rtnl_dereference(*ifap)) != NULL; + ifap = &ifa->ifa_next) + if (!strcmp(ifr->ifr_name, ifa->ifa_label)) + break; + } + } + + ret = -EADDRNOTAVAIL; + if (!ifa && cmd != SIOCSIFADDR && cmd != SIOCSIFFLAGS) + goto done; + + switch (cmd) { + case SIOCGIFADDR: /* Get interface address */ + ret = 0; + sin->sin_addr.s_addr = ifa->ifa_local; + break; + + case SIOCGIFBRDADDR: /* Get the broadcast address */ + ret = 0; + sin->sin_addr.s_addr = ifa->ifa_broadcast; + break; + + case SIOCGIFDSTADDR: /* Get the destination address */ + ret = 0; + sin->sin_addr.s_addr = ifa->ifa_address; + break; + + case SIOCGIFNETMASK: /* Get the netmask for the interface */ + ret = 0; + sin->sin_addr.s_addr = ifa->ifa_mask; + break; + + case SIOCSIFFLAGS: + if (colon) { + ret = -EADDRNOTAVAIL; + if (!ifa) + break; + ret = 0; + if (!(ifr->ifr_flags & IFF_UP)) + inet_del_ifa(in_dev, ifap, 1); + break; + } + ret = dev_change_flags(dev, ifr->ifr_flags, NULL); + break; + + case SIOCSIFADDR: /* Set interface address (and family) */ + ret = -EINVAL; + if (inet_abc_len(sin->sin_addr.s_addr) < 0) + break; + + if (!ifa) { + ret = -ENOBUFS; + ifa = inet_alloc_ifa(); + if (!ifa) + break; + INIT_HLIST_NODE(&ifa->hash); + if (colon) + memcpy(ifa->ifa_label, ifr->ifr_name, IFNAMSIZ); + else + memcpy(ifa->ifa_label, dev->name, IFNAMSIZ); + } else { + ret = 0; + if (ifa->ifa_local == sin->sin_addr.s_addr) + break; + inet_del_ifa(in_dev, ifap, 0); + ifa->ifa_broadcast = 0; + ifa->ifa_scope = 0; + } + + ifa->ifa_address = ifa->ifa_local = sin->sin_addr.s_addr; + + if (!(dev->flags & IFF_POINTOPOINT)) { + ifa->ifa_prefixlen = inet_abc_len(ifa->ifa_address); + ifa->ifa_mask = inet_make_mask(ifa->ifa_prefixlen); + if ((dev->flags & IFF_BROADCAST) && + ifa->ifa_prefixlen < 31) + ifa->ifa_broadcast = ifa->ifa_address | + ~ifa->ifa_mask; + } else { + ifa->ifa_prefixlen = 32; + ifa->ifa_mask = inet_make_mask(32); + } + set_ifa_lifetime(ifa, INFINITY_LIFE_TIME, INFINITY_LIFE_TIME); + ret = inet_set_ifa(dev, ifa); + break; + + case SIOCSIFBRDADDR: /* Set the broadcast address */ + ret = 0; + if (ifa->ifa_broadcast != sin->sin_addr.s_addr) { + inet_del_ifa(in_dev, ifap, 0); + ifa->ifa_broadcast = sin->sin_addr.s_addr; + inet_insert_ifa(ifa); + } + break; + + case SIOCSIFDSTADDR: /* Set the destination address */ + ret = 0; + if (ifa->ifa_address == sin->sin_addr.s_addr) + break; + ret = -EINVAL; + if (inet_abc_len(sin->sin_addr.s_addr) < 0) + break; + ret = 0; + inet_del_ifa(in_dev, ifap, 0); + ifa->ifa_address = sin->sin_addr.s_addr; + inet_insert_ifa(ifa); + break; + + case SIOCSIFNETMASK: /* Set the netmask for the interface */ + + /* + * The mask we set must be legal. + */ + ret = -EINVAL; + if (bad_mask(sin->sin_addr.s_addr, 0)) + break; + ret = 0; + if (ifa->ifa_mask != sin->sin_addr.s_addr) { + __be32 old_mask = ifa->ifa_mask; + inet_del_ifa(in_dev, ifap, 0); + ifa->ifa_mask = sin->sin_addr.s_addr; + ifa->ifa_prefixlen = inet_mask_len(ifa->ifa_mask); + + /* See if current broadcast address matches + * with current netmask, then recalculate + * the broadcast address. 
Otherwise it's a + * funny address, so don't touch it since + * the user seems to know what (s)he's doing... + */ + if ((dev->flags & IFF_BROADCAST) && + (ifa->ifa_prefixlen < 31) && + (ifa->ifa_broadcast == + (ifa->ifa_local|~old_mask))) { + ifa->ifa_broadcast = (ifa->ifa_local | + ~sin->sin_addr.s_addr); + } + inet_insert_ifa(ifa); + } + break; + } +done: + rtnl_unlock(); +out: + return ret; +} + +int inet_gifconf(struct net_device *dev, char __user *buf, int len, int size) +{ + struct in_device *in_dev = __in_dev_get_rtnl(dev); + const struct in_ifaddr *ifa; + struct ifreq ifr; + int done = 0; + + if (WARN_ON(size > sizeof(struct ifreq))) + goto out; + + if (!in_dev) + goto out; + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + if (!buf) { + done += size; + continue; + } + if (len < size) + break; + memset(&ifr, 0, sizeof(struct ifreq)); + strcpy(ifr.ifr_name, ifa->ifa_label); + + (*(struct sockaddr_in *)&ifr.ifr_addr).sin_family = AF_INET; + (*(struct sockaddr_in *)&ifr.ifr_addr).sin_addr.s_addr = + ifa->ifa_local; + + if (copy_to_user(buf + done, &ifr, size)) { + done = -EFAULT; + break; + } + len -= size; + done += size; + } +out: + return done; +} + +static __be32 in_dev_select_addr(const struct in_device *in_dev, + int scope) +{ + const struct in_ifaddr *ifa; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (ifa->ifa_flags & IFA_F_SECONDARY) + continue; + if (ifa->ifa_scope != RT_SCOPE_LINK && + ifa->ifa_scope <= scope) + return ifa->ifa_local; + } + + return 0; +} + +__be32 inet_select_addr(const struct net_device *dev, __be32 dst, int scope) +{ + const struct in_ifaddr *ifa; + __be32 addr = 0; + unsigned char localnet_scope = RT_SCOPE_HOST; + struct in_device *in_dev; + struct net *net = dev_net(dev); + int master_idx; + + rcu_read_lock(); + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + goto no_in_dev; + + if (unlikely(IN_DEV_ROUTE_LOCALNET(in_dev))) + localnet_scope = RT_SCOPE_LINK; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (ifa->ifa_flags & IFA_F_SECONDARY) + continue; + if (min(ifa->ifa_scope, localnet_scope) > scope) + continue; + if (!dst || inet_ifa_match(dst, ifa)) { + addr = ifa->ifa_local; + break; + } + if (!addr) + addr = ifa->ifa_local; + } + + if (addr) + goto out_unlock; +no_in_dev: + master_idx = l3mdev_master_ifindex_rcu(dev); + + /* For VRFs, the VRF device takes the place of the loopback device, + * with addresses on it being preferred. Note in such cases the + * loopback device will be among the devices that fail the master_idx + * equality check in the loop below. + */ + if (master_idx && + (dev = dev_get_by_index_rcu(net, master_idx)) && + (in_dev = __in_dev_get_rcu(dev))) { + addr = in_dev_select_addr(in_dev, scope); + if (addr) + goto out_unlock; + } + + /* Not loopback addresses on loopback should be preferred + in this case. It is important that lo is the first interface + in dev_base list. 
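+ If nothing acceptable was found on the original device (or on its VRF + master above), the loop below walks every device sharing the same L3 + master and returns the first suitable primary address it finds.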
+ */ + for_each_netdev_rcu(net, dev) { + if (l3mdev_master_ifindex_rcu(dev) != master_idx) + continue; + + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + continue; + + addr = in_dev_select_addr(in_dev, scope); + if (addr) + goto out_unlock; + } +out_unlock: + rcu_read_unlock(); + return addr; +} +EXPORT_SYMBOL(inet_select_addr); + +static __be32 confirm_addr_indev(struct in_device *in_dev, __be32 dst, + __be32 local, int scope) +{ + unsigned char localnet_scope = RT_SCOPE_HOST; + const struct in_ifaddr *ifa; + __be32 addr = 0; + int same = 0; + + if (unlikely(IN_DEV_ROUTE_LOCALNET(in_dev))) + localnet_scope = RT_SCOPE_LINK; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + unsigned char min_scope = min(ifa->ifa_scope, localnet_scope); + + if (!addr && + (local == ifa->ifa_local || !local) && + min_scope <= scope) { + addr = ifa->ifa_local; + if (same) + break; + } + if (!same) { + same = (!local || inet_ifa_match(local, ifa)) && + (!dst || inet_ifa_match(dst, ifa)); + if (same && addr) { + if (local || !dst) + break; + /* Is the selected addr into dst subnet? */ + if (inet_ifa_match(addr, ifa)) + break; + /* No, then can we use new local src? */ + if (min_scope <= scope) { + addr = ifa->ifa_local; + break; + } + /* search for large dst subnet for addr */ + same = 0; + } + } + } + + return same ? addr : 0; +} + +/* + * Confirm that local IP address exists using wildcards: + * - net: netns to check, cannot be NULL + * - in_dev: only on this interface, NULL=any interface + * - dst: only in the same subnet as dst, 0=any dst + * - local: address, 0=autoselect the local address + * - scope: maximum allowed scope value for the local address + */ +__be32 inet_confirm_addr(struct net *net, struct in_device *in_dev, + __be32 dst, __be32 local, int scope) +{ + __be32 addr = 0; + struct net_device *dev; + + if (in_dev) + return confirm_addr_indev(in_dev, dst, local, scope); + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + in_dev = __in_dev_get_rcu(dev); + if (in_dev) { + addr = confirm_addr_indev(in_dev, dst, local, scope); + if (addr) + break; + } + } + rcu_read_unlock(); + + return addr; +} +EXPORT_SYMBOL(inet_confirm_addr); + +/* + * Device notifier + */ + +int register_inetaddr_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_register(&inetaddr_chain, nb); +} +EXPORT_SYMBOL(register_inetaddr_notifier); + +int unregister_inetaddr_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_unregister(&inetaddr_chain, nb); +} +EXPORT_SYMBOL(unregister_inetaddr_notifier); + +int register_inetaddr_validator_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_register(&inetaddr_validator_chain, nb); +} +EXPORT_SYMBOL(register_inetaddr_validator_notifier); + +int unregister_inetaddr_validator_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_unregister(&inetaddr_validator_chain, + nb); +} +EXPORT_SYMBOL(unregister_inetaddr_validator_notifier); + +/* Rename ifa_labels for a device name change. Make some effort to preserve + * existing alias numbering and to create unique labels if possible. 
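+ * The first address simply takes the new device name; each later address + * gets the new name with its old ":N" alias suffix re-appended (or a + * freshly numbered suffix if it had none), overwriting the tail of the + * name when the result would not fit in IFNAMSIZ. For example, an alias + * labelled "eth0:1" becomes "dummy0:1" after a rename to "dummy0", and an + * RTM_NEWADDR message is sent for every relabelled address.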
+*/ +static void inetdev_changename(struct net_device *dev, struct in_device *in_dev) +{ + struct in_ifaddr *ifa; + int named = 0; + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + char old[IFNAMSIZ], *dot; + + memcpy(old, ifa->ifa_label, IFNAMSIZ); + memcpy(ifa->ifa_label, dev->name, IFNAMSIZ); + if (named++ == 0) + goto skip; + dot = strchr(old, ':'); + if (!dot) { + sprintf(old, ":%d", named); + dot = old; + } + if (strlen(dot) + strlen(dev->name) < IFNAMSIZ) + strcat(ifa->ifa_label, dot); + else + strcpy(ifa->ifa_label + (IFNAMSIZ - strlen(dot) - 1), dot); +skip: + rtmsg_ifa(RTM_NEWADDR, ifa, NULL, 0); + } +} + +static void inetdev_send_gratuitous_arp(struct net_device *dev, + struct in_device *in_dev) + +{ + const struct in_ifaddr *ifa; + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + arp_send(ARPOP_REQUEST, ETH_P_ARP, + ifa->ifa_local, dev, + ifa->ifa_local, NULL, + dev->dev_addr, NULL); + } +} + +/* Called only under RTNL semaphore */ + +static int inetdev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct in_device *in_dev = __in_dev_get_rtnl(dev); + + ASSERT_RTNL(); + + if (!in_dev) { + if (event == NETDEV_REGISTER) { + in_dev = inetdev_init(dev); + if (IS_ERR(in_dev)) + return notifier_from_errno(PTR_ERR(in_dev)); + if (dev->flags & IFF_LOOPBACK) { + IN_DEV_CONF_SET(in_dev, NOXFRM, 1); + IN_DEV_CONF_SET(in_dev, NOPOLICY, 1); + } + } else if (event == NETDEV_CHANGEMTU) { + /* Re-enabling IP */ + if (inetdev_valid_mtu(dev->mtu)) + in_dev = inetdev_init(dev); + } + goto out; + } + + switch (event) { + case NETDEV_REGISTER: + pr_debug("%s: bug\n", __func__); + RCU_INIT_POINTER(dev->ip_ptr, NULL); + break; + case NETDEV_UP: + if (!inetdev_valid_mtu(dev->mtu)) + break; + if (dev->flags & IFF_LOOPBACK) { + struct in_ifaddr *ifa = inet_alloc_ifa(); + + if (ifa) { + INIT_HLIST_NODE(&ifa->hash); + ifa->ifa_local = + ifa->ifa_address = htonl(INADDR_LOOPBACK); + ifa->ifa_prefixlen = 8; + ifa->ifa_mask = inet_make_mask(8); + in_dev_hold(in_dev); + ifa->ifa_dev = in_dev; + ifa->ifa_scope = RT_SCOPE_HOST; + memcpy(ifa->ifa_label, dev->name, IFNAMSIZ); + set_ifa_lifetime(ifa, INFINITY_LIFE_TIME, + INFINITY_LIFE_TIME); + ipv4_devconf_setall(in_dev); + neigh_parms_data_state_setall(in_dev->arp_parms); + inet_insert_ifa(ifa); + } + } + ip_mc_up(in_dev); + fallthrough; + case NETDEV_CHANGEADDR: + if (!IN_DEV_ARP_NOTIFY(in_dev)) + break; + fallthrough; + case NETDEV_NOTIFY_PEERS: + /* Send gratuitous ARP to notify of link change */ + inetdev_send_gratuitous_arp(dev, in_dev); + break; + case NETDEV_DOWN: + ip_mc_down(in_dev); + break; + case NETDEV_PRE_TYPE_CHANGE: + ip_mc_unmap(in_dev); + break; + case NETDEV_POST_TYPE_CHANGE: + ip_mc_remap(in_dev); + break; + case NETDEV_CHANGEMTU: + if (inetdev_valid_mtu(dev->mtu)) + break; + /* disable IP when MTU is not enough */ + fallthrough; + case NETDEV_UNREGISTER: + inetdev_destroy(in_dev); + break; + case NETDEV_CHANGENAME: + /* Do not notify about label change, this event is + * not interesting to applications using netlink. 
+ */ + inetdev_changename(dev, in_dev); + + devinet_sysctl_unregister(in_dev); + devinet_sysctl_register(in_dev); + break; + } +out: + return NOTIFY_DONE; +} + +static struct notifier_block ip_netdev_notifier = { + .notifier_call = inetdev_event, +}; + +static size_t inet_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(4) /* IFA_ADDRESS */ + + nla_total_size(4) /* IFA_LOCAL */ + + nla_total_size(4) /* IFA_BROADCAST */ + + nla_total_size(IFNAMSIZ) /* IFA_LABEL */ + + nla_total_size(4) /* IFA_FLAGS */ + + nla_total_size(1) /* IFA_PROTO */ + + nla_total_size(4) /* IFA_RT_PRIORITY */ + + nla_total_size(sizeof(struct ifa_cacheinfo)); /* IFA_CACHEINFO */ +} + +static inline u32 cstamp_delta(unsigned long cstamp) +{ + return (cstamp - INITIAL_JIFFIES) * 100UL / HZ; +} + +static int put_cacheinfo(struct sk_buff *skb, unsigned long cstamp, + unsigned long tstamp, u32 preferred, u32 valid) +{ + struct ifa_cacheinfo ci; + + ci.cstamp = cstamp_delta(cstamp); + ci.tstamp = cstamp_delta(tstamp); + ci.ifa_prefered = preferred; + ci.ifa_valid = valid; + + return nla_put(skb, IFA_CACHEINFO, sizeof(ci), &ci); +} + +static int inet_fill_ifaddr(struct sk_buff *skb, struct in_ifaddr *ifa, + struct inet_fill_args *args) +{ + struct ifaddrmsg *ifm; + struct nlmsghdr *nlh; + u32 preferred, valid; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->event, sizeof(*ifm), + args->flags); + if (!nlh) + return -EMSGSIZE; + + ifm = nlmsg_data(nlh); + ifm->ifa_family = AF_INET; + ifm->ifa_prefixlen = ifa->ifa_prefixlen; + ifm->ifa_flags = ifa->ifa_flags; + ifm->ifa_scope = ifa->ifa_scope; + ifm->ifa_index = ifa->ifa_dev->dev->ifindex; + + if (args->netnsid >= 0 && + nla_put_s32(skb, IFA_TARGET_NETNSID, args->netnsid)) + goto nla_put_failure; + + if (!(ifm->ifa_flags & IFA_F_PERMANENT)) { + preferred = ifa->ifa_preferred_lft; + valid = ifa->ifa_valid_lft; + if (preferred != INFINITY_LIFE_TIME) { + long tval = (jiffies - ifa->ifa_tstamp) / HZ; + + if (preferred > tval) + preferred -= tval; + else + preferred = 0; + if (valid != INFINITY_LIFE_TIME) { + if (valid > tval) + valid -= tval; + else + valid = 0; + } + } + } else { + preferred = INFINITY_LIFE_TIME; + valid = INFINITY_LIFE_TIME; + } + if ((ifa->ifa_address && + nla_put_in_addr(skb, IFA_ADDRESS, ifa->ifa_address)) || + (ifa->ifa_local && + nla_put_in_addr(skb, IFA_LOCAL, ifa->ifa_local)) || + (ifa->ifa_broadcast && + nla_put_in_addr(skb, IFA_BROADCAST, ifa->ifa_broadcast)) || + (ifa->ifa_label[0] && + nla_put_string(skb, IFA_LABEL, ifa->ifa_label)) || + (ifa->ifa_proto && + nla_put_u8(skb, IFA_PROTO, ifa->ifa_proto)) || + nla_put_u32(skb, IFA_FLAGS, ifa->ifa_flags) || + (ifa->ifa_rt_priority && + nla_put_u32(skb, IFA_RT_PRIORITY, ifa->ifa_rt_priority)) || + put_cacheinfo(skb, ifa->ifa_cstamp, ifa->ifa_tstamp, + preferred, valid)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int inet_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, + struct inet_fill_args *fillargs, + struct net **tgt_net, struct sock *sk, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[IFA_MAX+1]; + struct ifaddrmsg *ifm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for address dump request"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG(extack, "ipv4: 
Invalid values in header for address dump request"); + return -EINVAL; + } + + fillargs->ifindex = ifm->ifa_index; + if (fillargs->ifindex) { + cb->answer_flags |= NLM_F_DUMP_FILTERED; + fillargs->flags |= NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv4_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= IFA_MAX; ++i) { + if (!tb[i]) + continue; + + if (i == IFA_TARGET_NETNSID) { + struct net *net; + + fillargs->netnsid = nla_get_s32(tb[i]); + + net = rtnl_get_net_ns_capable(sk, fillargs->netnsid); + if (IS_ERR(net)) { + fillargs->netnsid = -1; + NL_SET_ERR_MSG(extack, "ipv4: Invalid target network namespace id"); + return PTR_ERR(net); + } + *tgt_net = net; + } else { + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int in_dev_dump_addr(struct in_device *in_dev, struct sk_buff *skb, + struct netlink_callback *cb, int s_ip_idx, + struct inet_fill_args *fillargs) +{ + struct in_ifaddr *ifa; + int ip_idx = 0; + int err; + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + if (ip_idx < s_ip_idx) { + ip_idx++; + continue; + } + err = inet_fill_ifaddr(skb, ifa, fillargs); + if (err < 0) + goto done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + ip_idx++; + } + err = 0; + +done: + cb->args[2] = ip_idx; + + return err; +} + +static int inet_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct inet_fill_args fillargs = { + .portid = NETLINK_CB(cb->skb).portid, + .seq = nlh->nlmsg_seq, + .event = RTM_NEWADDR, + .flags = NLM_F_MULTI, + .netnsid = -1, + }; + struct net *net = sock_net(skb->sk); + struct net *tgt_net = net; + int h, s_h; + int idx, s_idx; + int s_ip_idx; + struct net_device *dev; + struct in_device *in_dev; + struct hlist_head *head; + int err = 0; + + s_h = cb->args[0]; + s_idx = idx = cb->args[1]; + s_ip_idx = cb->args[2]; + + if (cb->strict_check) { + err = inet_valid_dump_ifaddr_req(nlh, &fillargs, &tgt_net, + skb->sk, cb); + if (err < 0) + goto put_tgt_net; + + err = 0; + if (fillargs.ifindex) { + dev = __dev_get_by_index(tgt_net, fillargs.ifindex); + if (!dev) { + err = -ENODEV; + goto put_tgt_net; + } + + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) { + err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, + &fillargs); + } + goto put_tgt_net; + } + } + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &tgt_net->dev_index_head[h]; + rcu_read_lock(); + cb->seq = atomic_read(&tgt_net->ipv4.dev_addr_genid) ^ + tgt_net->dev_base_seq; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + if (h > s_h || idx > s_idx) + s_ip_idx = 0; + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + goto cont; + + err = in_dev_dump_addr(in_dev, skb, cb, s_ip_idx, + &fillargs); + if (err < 0) { + rcu_read_unlock(); + goto done; + } +cont: + idx++; + } + rcu_read_unlock(); + } + +done: + cb->args[0] = h; + cb->args[1] = idx; +put_tgt_net: + if (fillargs.netnsid >= 0) + put_net(tgt_net); + + return skb->len ? : err; +} + +static void rtmsg_ifa(int event, struct in_ifaddr *ifa, struct nlmsghdr *nlh, + u32 portid) +{ + struct inet_fill_args fillargs = { + .portid = portid, + .seq = nlh ? 
nlh->nlmsg_seq : 0, + .event = event, + .flags = 0, + .netnsid = -1, + }; + struct sk_buff *skb; + int err = -ENOBUFS; + struct net *net; + + net = dev_net(ifa->ifa_dev->dev); + skb = nlmsg_new(inet_nlmsg_size(), GFP_KERNEL); + if (!skb) + goto errout; + + err = inet_fill_ifaddr(skb, ifa, &fillargs); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, portid, RTNLGRP_IPV4_IFADDR, nlh, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV4_IFADDR, err); +} + +static size_t inet_get_link_af_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + struct in_device *in_dev = rcu_dereference_rtnl(dev->ip_ptr); + + if (!in_dev) + return 0; + + return nla_total_size(IPV4_DEVCONF_MAX * 4); /* IFLA_INET_CONF */ +} + +static int inet_fill_link_af(struct sk_buff *skb, const struct net_device *dev, + u32 ext_filter_mask) +{ + struct in_device *in_dev = rcu_dereference_rtnl(dev->ip_ptr); + struct nlattr *nla; + int i; + + if (!in_dev) + return -ENODATA; + + nla = nla_reserve(skb, IFLA_INET_CONF, IPV4_DEVCONF_MAX * 4); + if (!nla) + return -EMSGSIZE; + + for (i = 0; i < IPV4_DEVCONF_MAX; i++) + ((u32 *) nla_data(nla))[i] = in_dev->cnf.data[i]; + + return 0; +} + +static const struct nla_policy inet_af_policy[IFLA_INET_MAX+1] = { + [IFLA_INET_CONF] = { .type = NLA_NESTED }, +}; + +static int inet_validate_link_af(const struct net_device *dev, + const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct nlattr *a, *tb[IFLA_INET_MAX+1]; + int err, rem; + + if (dev && !__in_dev_get_rtnl(dev)) + return -EAFNOSUPPORT; + + err = nla_parse_nested_deprecated(tb, IFLA_INET_MAX, nla, + inet_af_policy, extack); + if (err < 0) + return err; + + if (tb[IFLA_INET_CONF]) { + nla_for_each_nested(a, tb[IFLA_INET_CONF], rem) { + int cfgid = nla_type(a); + + if (nla_len(a) < 4) + return -EINVAL; + + if (cfgid <= 0 || cfgid > IPV4_DEVCONF_MAX) + return -EINVAL; + } + } + + return 0; +} + +static int inet_set_link_af(struct net_device *dev, const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct in_device *in_dev = __in_dev_get_rtnl(dev); + struct nlattr *a, *tb[IFLA_INET_MAX+1]; + int rem; + + if (!in_dev) + return -EAFNOSUPPORT; + + if (nla_parse_nested_deprecated(tb, IFLA_INET_MAX, nla, NULL, NULL) < 0) + return -EINVAL; + + if (tb[IFLA_INET_CONF]) { + nla_for_each_nested(a, tb[IFLA_INET_CONF], rem) + ipv4_devconf_set(in_dev, nla_type(a), nla_get_u32(a)); + } + + return 0; +} + +static int inet_netconf_msgsize_devconf(int type) +{ + int size = NLMSG_ALIGN(sizeof(struct netconfmsg)) + + nla_total_size(4); /* NETCONFA_IFINDEX */ + bool all = false; + + if (type == NETCONFA_ALL) + all = true; + + if (all || type == NETCONFA_FORWARDING) + size += nla_total_size(4); + if (all || type == NETCONFA_RP_FILTER) + size += nla_total_size(4); + if (all || type == NETCONFA_MC_FORWARDING) + size += nla_total_size(4); + if (all || type == NETCONFA_BC_FORWARDING) + size += nla_total_size(4); + if (all || type == NETCONFA_PROXY_NEIGH) + size += nla_total_size(4); + if (all || type == NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN) + size += nla_total_size(4); + + return size; +} + +static int inet_netconf_fill_devconf(struct sk_buff *skb, int ifindex, + struct ipv4_devconf *devconf, u32 portid, + u32 seq, int event, unsigned int flags, + int type) +{ + struct nlmsghdr *nlh; + struct netconfmsg *ncm; + bool all = false; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct 
netconfmsg), + flags); + if (!nlh) + return -EMSGSIZE; + + if (type == NETCONFA_ALL) + all = true; + + ncm = nlmsg_data(nlh); + ncm->ncm_family = AF_INET; + + if (nla_put_s32(skb, NETCONFA_IFINDEX, ifindex) < 0) + goto nla_put_failure; + + if (!devconf) + goto out; + + if ((all || type == NETCONFA_FORWARDING) && + nla_put_s32(skb, NETCONFA_FORWARDING, + IPV4_DEVCONF(*devconf, FORWARDING)) < 0) + goto nla_put_failure; + if ((all || type == NETCONFA_RP_FILTER) && + nla_put_s32(skb, NETCONFA_RP_FILTER, + IPV4_DEVCONF(*devconf, RP_FILTER)) < 0) + goto nla_put_failure; + if ((all || type == NETCONFA_MC_FORWARDING) && + nla_put_s32(skb, NETCONFA_MC_FORWARDING, + IPV4_DEVCONF(*devconf, MC_FORWARDING)) < 0) + goto nla_put_failure; + if ((all || type == NETCONFA_BC_FORWARDING) && + nla_put_s32(skb, NETCONFA_BC_FORWARDING, + IPV4_DEVCONF(*devconf, BC_FORWARDING)) < 0) + goto nla_put_failure; + if ((all || type == NETCONFA_PROXY_NEIGH) && + nla_put_s32(skb, NETCONFA_PROXY_NEIGH, + IPV4_DEVCONF(*devconf, PROXY_ARP)) < 0) + goto nla_put_failure; + if ((all || type == NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN) && + nla_put_s32(skb, NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + IPV4_DEVCONF(*devconf, IGNORE_ROUTES_WITH_LINKDOWN)) < 0) + goto nla_put_failure; + +out: + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +void inet_netconf_notify_devconf(struct net *net, int event, int type, + int ifindex, struct ipv4_devconf *devconf) +{ + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(inet_netconf_msgsize_devconf(type), GFP_KERNEL); + if (!skb) + goto errout; + + err = inet_netconf_fill_devconf(skb, ifindex, devconf, 0, 0, + event, 0, type); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_IPV4_NETCONF, NULL, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV4_NETCONF, err); +} + +static const struct nla_policy devconf_ipv4_policy[NETCONFA_MAX+1] = { + [NETCONFA_IFINDEX] = { .len = sizeof(int) }, + [NETCONFA_FORWARDING] = { .len = sizeof(int) }, + [NETCONFA_RP_FILTER] = { .len = sizeof(int) }, + [NETCONFA_PROXY_NEIGH] = { .len = sizeof(int) }, + [NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN] = { .len = sizeof(int) }, +}; + +static int inet_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_ipv4_policy, extack); + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_ipv4_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + +static int inet_netconf_get_devconf(struct sk_buff *in_skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[NETCONFA_MAX+1]; + struct sk_buff *skb; + struct ipv4_devconf *devconf; + struct in_device *in_dev; + struct net_device 
*dev; + int ifindex; + int err; + + err = inet_netconf_valid_get_req(in_skb, nlh, tb, extack); + if (err) + goto errout; + + err = -EINVAL; + if (!tb[NETCONFA_IFINDEX]) + goto errout; + + ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); + switch (ifindex) { + case NETCONFA_IFINDEX_ALL: + devconf = net->ipv4.devconf_all; + break; + case NETCONFA_IFINDEX_DEFAULT: + devconf = net->ipv4.devconf_dflt; + break; + default: + dev = __dev_get_by_index(net, ifindex); + if (!dev) + goto errout; + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) + goto errout; + devconf = &in_dev->cnf; + break; + } + + err = -ENOBUFS; + skb = nlmsg_new(inet_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL); + if (!skb) + goto errout; + + err = inet_netconf_fill_devconf(skb, ifindex, devconf, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, RTM_NEWNETCONF, 0, + NETCONFA_ALL); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +errout: + return err; +} + +static int inet_netconf_dump_devconf(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + int h, s_h; + int idx, s_idx; + struct net_device *dev; + struct in_device *in_dev; + struct hlist_head *head; + + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + + s_h = cb->args[0]; + s_idx = idx = cb->args[1]; + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + rcu_read_lock(); + cb->seq = atomic_read(&net->ipv4.dev_addr_genid) ^ + net->dev_base_seq; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + goto cont; + + if (inet_netconf_fill_devconf(skb, dev->ifindex, + &in_dev->cnf, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, + NLM_F_MULTI, + NETCONFA_ALL) < 0) { + rcu_read_unlock(); + goto done; + } + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + rcu_read_unlock(); + } + if (h == NETDEV_HASHENTRIES) { + if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, + net->ipv4.devconf_all, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL) < 0) + goto done; + else + h++; + } + if (h == NETDEV_HASHENTRIES + 1) { + if (inet_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, + net->ipv4.devconf_dflt, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL) < 0) + goto done; + else + h++; + } +done: + cb->args[0] = h; + cb->args[1] = idx; + + return skb->len; +} + +#ifdef CONFIG_SYSCTL + +static void devinet_copy_dflt_conf(struct net *net, int i) +{ + struct net_device *dev; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + struct in_device *in_dev; + + in_dev = __in_dev_get_rcu(dev); + if (in_dev && !test_bit(i, in_dev->cnf.state)) + in_dev->cnf.data[i] = net->ipv4.devconf_dflt->data[i]; + } + rcu_read_unlock(); +} + +/* called with RTNL locked */ +static void inet_forward_change(struct net *net) +{ + struct net_device *dev; + int on = 
IPV4_DEVCONF_ALL(net, FORWARDING); + + IPV4_DEVCONF_ALL(net, ACCEPT_REDIRECTS) = !on; + IPV4_DEVCONF_DFLT(net, FORWARDING) = on; + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv4.devconf_all); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_DEFAULT, + net->ipv4.devconf_dflt); + + for_each_netdev(net, dev) { + struct in_device *in_dev; + + if (on) + dev_disable_lro(dev); + + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) { + IN_DEV_CONF_SET(in_dev, FORWARDING, on); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + dev->ifindex, &in_dev->cnf); + } + } +} + +static int devinet_conf_ifindex(struct net *net, struct ipv4_devconf *cnf) +{ + if (cnf == net->ipv4.devconf_dflt) + return NETCONFA_IFINDEX_DEFAULT; + else if (cnf == net->ipv4.devconf_all) + return NETCONFA_IFINDEX_ALL; + else { + struct in_device *idev + = container_of(cnf, struct in_device, cnf); + return idev->dev->ifindex; + } +} + +static int devinet_conf_proc(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int old_value = *(int *)ctl->data; + int ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + int new_value = *(int *)ctl->data; + + if (write) { + struct ipv4_devconf *cnf = ctl->extra1; + struct net *net = ctl->extra2; + int i = (int *)ctl->data - cnf->data; + int ifindex; + + set_bit(i, cnf->state); + + if (cnf == net->ipv4.devconf_dflt) + devinet_copy_dflt_conf(net, i); + if (i == IPV4_DEVCONF_ACCEPT_LOCAL - 1 || + i == IPV4_DEVCONF_ROUTE_LOCALNET - 1) + if ((new_value == 0) && (old_value != 0)) + rt_cache_flush(net); + + if (i == IPV4_DEVCONF_BC_FORWARDING - 1 && + new_value != old_value) + rt_cache_flush(net); + + if (i == IPV4_DEVCONF_RP_FILTER - 1 && + new_value != old_value) { + ifindex = devinet_conf_ifindex(net, cnf); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_RP_FILTER, + ifindex, cnf); + } + if (i == IPV4_DEVCONF_PROXY_ARP - 1 && + new_value != old_value) { + ifindex = devinet_conf_ifindex(net, cnf); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_PROXY_NEIGH, + ifindex, cnf); + } + if (i == IPV4_DEVCONF_IGNORE_ROUTES_WITH_LINKDOWN - 1 && + new_value != old_value) { + ifindex = devinet_conf_ifindex(net, cnf); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + ifindex, cnf); + } + } + + return ret; +} + +static int devinet_sysctl_forward(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + loff_t pos = *ppos; + struct net *net = ctl->extra2; + int ret; + + if (write && !ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + + if (write && *valp != val) { + if (valp != &IPV4_DEVCONF_DFLT(net, FORWARDING)) { + if (!rtnl_trylock()) { + /* Restore the original values before restarting */ + *valp = val; + *ppos = pos; + return restart_syscall(); + } + if (valp == &IPV4_DEVCONF_ALL(net, FORWARDING)) { + inet_forward_change(net); + } else { + struct ipv4_devconf *cnf = ctl->extra1; + struct in_device *idev = + container_of(cnf, struct in_device, cnf); + if (*valp) + dev_disable_lro(idev->dev); + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + idev->dev->ifindex, + cnf); + } + rtnl_unlock(); + rt_cache_flush(net); + } else + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_DEFAULT, + 
net->ipv4.devconf_dflt); + } + + return ret; +} + +static int ipv4_doint_and_flush(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + int ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + struct net *net = ctl->extra2; + + if (write && *valp != val) + rt_cache_flush(net); + + return ret; +} + +#define DEVINET_SYSCTL_ENTRY(attr, name, mval, proc) \ + { \ + .procname = name, \ + .data = ipv4_devconf.data + \ + IPV4_DEVCONF_ ## attr - 1, \ + .maxlen = sizeof(int), \ + .mode = mval, \ + .proc_handler = proc, \ + .extra1 = &ipv4_devconf, \ + } + +#define DEVINET_SYSCTL_RW_ENTRY(attr, name) \ + DEVINET_SYSCTL_ENTRY(attr, name, 0644, devinet_conf_proc) + +#define DEVINET_SYSCTL_RO_ENTRY(attr, name) \ + DEVINET_SYSCTL_ENTRY(attr, name, 0444, devinet_conf_proc) + +#define DEVINET_SYSCTL_COMPLEX_ENTRY(attr, name, proc) \ + DEVINET_SYSCTL_ENTRY(attr, name, 0644, proc) + +#define DEVINET_SYSCTL_FLUSHING_ENTRY(attr, name) \ + DEVINET_SYSCTL_COMPLEX_ENTRY(attr, name, ipv4_doint_and_flush) + +static struct devinet_sysctl_table { + struct ctl_table_header *sysctl_header; + struct ctl_table devinet_vars[__IPV4_DEVCONF_MAX]; +} devinet_sysctl = { + .devinet_vars = { + DEVINET_SYSCTL_COMPLEX_ENTRY(FORWARDING, "forwarding", + devinet_sysctl_forward), + DEVINET_SYSCTL_RO_ENTRY(MC_FORWARDING, "mc_forwarding"), + DEVINET_SYSCTL_RW_ENTRY(BC_FORWARDING, "bc_forwarding"), + + DEVINET_SYSCTL_RW_ENTRY(ACCEPT_REDIRECTS, "accept_redirects"), + DEVINET_SYSCTL_RW_ENTRY(SECURE_REDIRECTS, "secure_redirects"), + DEVINET_SYSCTL_RW_ENTRY(SHARED_MEDIA, "shared_media"), + DEVINET_SYSCTL_RW_ENTRY(RP_FILTER, "rp_filter"), + DEVINET_SYSCTL_RW_ENTRY(SEND_REDIRECTS, "send_redirects"), + DEVINET_SYSCTL_RW_ENTRY(ACCEPT_SOURCE_ROUTE, + "accept_source_route"), + DEVINET_SYSCTL_RW_ENTRY(ACCEPT_LOCAL, "accept_local"), + DEVINET_SYSCTL_RW_ENTRY(SRC_VMARK, "src_valid_mark"), + DEVINET_SYSCTL_RW_ENTRY(PROXY_ARP, "proxy_arp"), + DEVINET_SYSCTL_RW_ENTRY(MEDIUM_ID, "medium_id"), + DEVINET_SYSCTL_RW_ENTRY(BOOTP_RELAY, "bootp_relay"), + DEVINET_SYSCTL_RW_ENTRY(LOG_MARTIANS, "log_martians"), + DEVINET_SYSCTL_RW_ENTRY(TAG, "tag"), + DEVINET_SYSCTL_RW_ENTRY(ARPFILTER, "arp_filter"), + DEVINET_SYSCTL_RW_ENTRY(ARP_ANNOUNCE, "arp_announce"), + DEVINET_SYSCTL_RW_ENTRY(ARP_IGNORE, "arp_ignore"), + DEVINET_SYSCTL_RW_ENTRY(ARP_ACCEPT, "arp_accept"), + DEVINET_SYSCTL_RW_ENTRY(ARP_NOTIFY, "arp_notify"), + DEVINET_SYSCTL_RW_ENTRY(ARP_EVICT_NOCARRIER, + "arp_evict_nocarrier"), + DEVINET_SYSCTL_RW_ENTRY(PROXY_ARP_PVLAN, "proxy_arp_pvlan"), + DEVINET_SYSCTL_RW_ENTRY(FORCE_IGMP_VERSION, + "force_igmp_version"), + DEVINET_SYSCTL_RW_ENTRY(IGMPV2_UNSOLICITED_REPORT_INTERVAL, + "igmpv2_unsolicited_report_interval"), + DEVINET_SYSCTL_RW_ENTRY(IGMPV3_UNSOLICITED_REPORT_INTERVAL, + "igmpv3_unsolicited_report_interval"), + DEVINET_SYSCTL_RW_ENTRY(IGNORE_ROUTES_WITH_LINKDOWN, + "ignore_routes_with_linkdown"), + DEVINET_SYSCTL_RW_ENTRY(DROP_GRATUITOUS_ARP, + "drop_gratuitous_arp"), + + DEVINET_SYSCTL_FLUSHING_ENTRY(NOXFRM, "disable_xfrm"), + DEVINET_SYSCTL_FLUSHING_ENTRY(NOPOLICY, "disable_policy"), + DEVINET_SYSCTL_FLUSHING_ENTRY(PROMOTE_SECONDARIES, + "promote_secondaries"), + DEVINET_SYSCTL_FLUSHING_ENTRY(ROUTE_LOCALNET, + "route_localnet"), + DEVINET_SYSCTL_FLUSHING_ENTRY(DROP_UNICAST_IN_L2_MULTICAST, + "drop_unicast_in_l2_multicast"), + }, +}; + +static int __devinet_sysctl_register(struct net *net, char *dev_name, + int ifindex, struct ipv4_devconf *p) +{ + int i; + struct 
devinet_sysctl_table *t; + char path[sizeof("net/ipv4/conf/") + IFNAMSIZ]; + + t = kmemdup(&devinet_sysctl, sizeof(*t), GFP_KERNEL_ACCOUNT); + if (!t) + goto out; + + for (i = 0; i < ARRAY_SIZE(t->devinet_vars) - 1; i++) { + t->devinet_vars[i].data += (char *)p - (char *)&ipv4_devconf; + t->devinet_vars[i].extra1 = p; + t->devinet_vars[i].extra2 = net; + } + + snprintf(path, sizeof(path), "net/ipv4/conf/%s", dev_name); + + t->sysctl_header = register_net_sysctl(net, path, t->devinet_vars); + if (!t->sysctl_header) + goto free; + + p->sysctl = t; + + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_ALL, + ifindex, p); + return 0; + +free: + kfree(t); +out: + return -ENOMEM; +} + +static void __devinet_sysctl_unregister(struct net *net, + struct ipv4_devconf *cnf, int ifindex) +{ + struct devinet_sysctl_table *t = cnf->sysctl; + + if (t) { + cnf->sysctl = NULL; + unregister_net_sysctl_table(t->sysctl_header); + kfree(t); + } + + inet_netconf_notify_devconf(net, RTM_DELNETCONF, 0, ifindex, NULL); +} + +static int devinet_sysctl_register(struct in_device *idev) +{ + int err; + + if (!sysctl_dev_name_is_allowed(idev->dev->name)) + return -EINVAL; + + err = neigh_sysctl_register(idev->dev, idev->arp_parms, NULL); + if (err) + return err; + err = __devinet_sysctl_register(dev_net(idev->dev), idev->dev->name, + idev->dev->ifindex, &idev->cnf); + if (err) + neigh_sysctl_unregister(idev->arp_parms); + return err; +} + +static void devinet_sysctl_unregister(struct in_device *idev) +{ + struct net *net = dev_net(idev->dev); + + __devinet_sysctl_unregister(net, &idev->cnf, idev->dev->ifindex); + neigh_sysctl_unregister(idev->arp_parms); +} + +static struct ctl_table ctl_forward_entry[] = { + { + .procname = "ip_forward", + .data = &ipv4_devconf.data[ + IPV4_DEVCONF_FORWARDING - 1], + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = devinet_sysctl_forward, + .extra1 = &ipv4_devconf, + .extra2 = &init_net, + }, + { }, +}; +#endif + +static __net_init int devinet_init_net(struct net *net) +{ + int err; + struct ipv4_devconf *all, *dflt; +#ifdef CONFIG_SYSCTL + struct ctl_table *tbl; + struct ctl_table_header *forw_hdr; +#endif + + err = -ENOMEM; + all = kmemdup(&ipv4_devconf, sizeof(ipv4_devconf), GFP_KERNEL); + if (!all) + goto err_alloc_all; + + dflt = kmemdup(&ipv4_devconf_dflt, sizeof(ipv4_devconf_dflt), GFP_KERNEL); + if (!dflt) + goto err_alloc_dflt; + +#ifdef CONFIG_SYSCTL + tbl = kmemdup(ctl_forward_entry, sizeof(ctl_forward_entry), GFP_KERNEL); + if (!tbl) + goto err_alloc_ctl; + + tbl[0].data = &all->data[IPV4_DEVCONF_FORWARDING - 1]; + tbl[0].extra1 = all; + tbl[0].extra2 = net; +#endif + + if (!net_eq(net, &init_net)) { + switch (net_inherit_devconf()) { + case 3: + /* copy from the current netns */ + memcpy(all, current->nsproxy->net_ns->ipv4.devconf_all, + sizeof(ipv4_devconf)); + memcpy(dflt, + current->nsproxy->net_ns->ipv4.devconf_dflt, + sizeof(ipv4_devconf_dflt)); + break; + case 0: + case 1: + /* copy from init_net */ + memcpy(all, init_net.ipv4.devconf_all, + sizeof(ipv4_devconf)); + memcpy(dflt, init_net.ipv4.devconf_dflt, + sizeof(ipv4_devconf_dflt)); + break; + case 2: + /* use compiled values */ + break; + } + } + +#ifdef CONFIG_SYSCTL + err = __devinet_sysctl_register(net, "all", NETCONFA_IFINDEX_ALL, all); + if (err < 0) + goto err_reg_all; + + err = __devinet_sysctl_register(net, "default", + NETCONFA_IFINDEX_DEFAULT, dflt); + if (err < 0) + goto err_reg_dflt; + + err = -ENOMEM; + forw_hdr = register_net_sysctl(net, "net/ipv4", tbl); + if (!forw_hdr) + goto 
err_reg_ctl; + net->ipv4.forw_hdr = forw_hdr; +#endif + + net->ipv4.devconf_all = all; + net->ipv4.devconf_dflt = dflt; + return 0; + +#ifdef CONFIG_SYSCTL +err_reg_ctl: + __devinet_sysctl_unregister(net, dflt, NETCONFA_IFINDEX_DEFAULT); +err_reg_dflt: + __devinet_sysctl_unregister(net, all, NETCONFA_IFINDEX_ALL); +err_reg_all: + kfree(tbl); +err_alloc_ctl: +#endif + kfree(dflt); +err_alloc_dflt: + kfree(all); +err_alloc_all: + return err; +} + +static __net_exit void devinet_exit_net(struct net *net) +{ +#ifdef CONFIG_SYSCTL + struct ctl_table *tbl; + + tbl = net->ipv4.forw_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv4.forw_hdr); + __devinet_sysctl_unregister(net, net->ipv4.devconf_dflt, + NETCONFA_IFINDEX_DEFAULT); + __devinet_sysctl_unregister(net, net->ipv4.devconf_all, + NETCONFA_IFINDEX_ALL); + kfree(tbl); +#endif + kfree(net->ipv4.devconf_dflt); + kfree(net->ipv4.devconf_all); +} + +static __net_initdata struct pernet_operations devinet_ops = { + .init = devinet_init_net, + .exit = devinet_exit_net, +}; + +static struct rtnl_af_ops inet_af_ops __read_mostly = { + .family = AF_INET, + .fill_link_af = inet_fill_link_af, + .get_link_af_size = inet_get_link_af_size, + .validate_link_af = inet_validate_link_af, + .set_link_af = inet_set_link_af, +}; + +void __init devinet_init(void) +{ + int i; + + for (i = 0; i < IN4_ADDR_HSIZE; i++) + INIT_HLIST_HEAD(&inet_addr_lst[i]); + + register_pernet_subsys(&devinet_ops); + register_netdevice_notifier(&ip_netdev_notifier); + + queue_delayed_work(system_power_efficient_wq, &check_lifetime_work, 0); + + rtnl_af_register(&inet_af_ops); + + rtnl_register(PF_INET, RTM_NEWADDR, inet_rtm_newaddr, NULL, 0); + rtnl_register(PF_INET, RTM_DELADDR, inet_rtm_deladdr, NULL, 0); + rtnl_register(PF_INET, RTM_GETADDR, NULL, inet_dump_ifaddr, 0); + rtnl_register(PF_INET, RTM_GETNETCONF, inet_netconf_get_devconf, + inet_netconf_dump_devconf, 0); +} diff --git a/net/ipv4/esp4.c b/net/ipv4/esp4.c new file mode 100644 index 000000000..e2546961a --- /dev/null +++ b/net/ipv4/esp4.c @@ -0,0 +1,1251 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) "IPsec: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct esp_skb_cb { + struct xfrm_skb_cb xfrm; + void *tmp; +}; + +struct esp_output_extra { + __be32 seqhi; + u32 esphoff; +}; + +#define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0])) + +/* + * Allocate an AEAD request structure with extra space for SG and IV. + * + * For alignment considerations the IV is placed at the front, followed + * by the request and finally the SG list. + * + * TODO: Use spare space in skb for this where possible. 
+ */ +static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int extralen) +{ + unsigned int len; + + len = extralen; + + len += crypto_aead_ivsize(aead); + + if (len) { + len += crypto_aead_alignmask(aead) & + ~(crypto_tfm_ctx_alignment() - 1); + len = ALIGN(len, crypto_tfm_ctx_alignment()); + } + + len += sizeof(struct aead_request) + crypto_aead_reqsize(aead); + len = ALIGN(len, __alignof__(struct scatterlist)); + + len += sizeof(struct scatterlist) * nfrags; + + return kmalloc(len, GFP_ATOMIC); +} + +static inline void *esp_tmp_extra(void *tmp) +{ + return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra)); +} + +static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int extralen) +{ + return crypto_aead_ivsize(aead) ? + PTR_ALIGN((u8 *)tmp + extralen, + crypto_aead_alignmask(aead) + 1) : tmp + extralen; +} + +static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv) +{ + struct aead_request *req; + + req = (void *)PTR_ALIGN(iv + crypto_aead_ivsize(aead), + crypto_tfm_ctx_alignment()); + aead_request_set_tfm(req, aead); + return req; +} + +static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead, + struct aead_request *req) +{ + return (void *)ALIGN((unsigned long)(req + 1) + + crypto_aead_reqsize(aead), + __alignof__(struct scatterlist)); +} + +static void esp_ssg_unref(struct xfrm_state *x, void *tmp) +{ + struct crypto_aead *aead = x->data; + int extralen = 0; + u8 *iv; + struct aead_request *req; + struct scatterlist *sg; + + if (x->props.flags & XFRM_STATE_ESN) + extralen += sizeof(struct esp_output_extra); + + iv = esp_tmp_iv(aead, tmp, extralen); + req = esp_tmp_req(aead, iv); + + /* Unref skb_frag_pages in the src scatterlist if necessary. + * Skip the first sg which comes from skb->data. 
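+ * (req->src and req->dst only differ on the out-of-place path that + * esp_output_head()/esp_output_tail() set up when the trailer lands on a + * separate page frag, i.e. esp->inplace == false; for in-place transforms + * this loop is skipped entirely.)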
+ */ + if (req->src != req->dst) + for (sg = sg_next(req->src); sg; sg = sg_next(sg)) + put_page(sg_page(sg)); +} + +#ifdef CONFIG_INET_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + +static struct sock *esp_find_tcp_sk(struct xfrm_state *x) +{ + struct xfrm_encap_tmpl *encap = x->encap; + struct net *net = xs_net(x); + struct esp_tcp_sk *esk; + __be16 sport, dport; + struct sock *nsk; + struct sock *sk; + + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + + spin_lock_bh(&x->lock); + sport = encap->encap_sport; + dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } + spin_unlock_bh(&x->lock); + + sk = inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, x->id.daddr.a4, + dport, x->props.saddr.a4, sport, 0); + if (!sk) + return ERR_PTR(-ENOENT); + + if (!tcp_is_ulp_esp(sk)) { + sock_put(sk); + return ERR_PTR(-EINVAL); + } + + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + + return sk; +} + +static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) +{ + struct sock *sk; + int err; + + rcu_read_lock(); + + sk = esp_find_tcp_sk(x); + err = PTR_ERR_OR_ZERO(sk); + if (err) + goto out; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) + err = espintcp_queue_out(sk, skb); + else + err = espintcp_push_skb(sk, skb); + bh_unlock_sock(sk); + +out: + rcu_read_unlock(); + return err; +} + +static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct xfrm_state *x = dst->xfrm; + + return esp_output_tcp_finish(x, skb); +} + +static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + + local_bh_disable(); + err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb); + local_bh_enable(); + + /* EINPROGRESS just happens to do the right thing. It + * actually means that the skb has been consumed and + * isn't coming back. 
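+ * To the caller this is indistinguishable from an asynchronous crypto + * request that is still in flight, so the skb is neither freed nor + * pushed any further on this path.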
+ */ + return err ?: -EINPROGRESS; +} +#else +static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) +{ + kfree_skb(skb); + + return -EOPNOTSUPP; +} +#endif + +static void esp_output_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + struct xfrm_offload *xo = xfrm_offload(skb); + void *tmp; + struct xfrm_state *x; + + if (xo && (xo->flags & XFRM_DEV_RESUME)) { + struct sec_path *sp = skb_sec_path(skb); + + x = sp->xvec[sp->len - 1]; + } else { + x = skb_dst(skb)->xfrm; + } + + tmp = ESP_SKB_CB(skb)->tmp; + esp_ssg_unref(x, tmp); + kfree(tmp); + + if (xo && (xo->flags & XFRM_DEV_RESUME)) { + if (err) { + XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR); + kfree_skb(skb); + return; + } + + skb_push(skb, skb->data - skb_mac_header(skb)); + secpath_reset(skb); + xfrm_dev_resume(skb); + } else { + if (!err && + x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) + esp_output_tail_tcp(x, skb); + else + xfrm_output_resume(skb->sk, skb, err); + } +} + +/* Move ESP header back into place. */ +static void esp_restore_header(struct sk_buff *skb, unsigned int offset) +{ + struct ip_esp_hdr *esph = (void *)(skb->data + offset); + void *tmp = ESP_SKB_CB(skb)->tmp; + __be32 *seqhi = esp_tmp_extra(tmp); + + esph->seq_no = esph->spi; + esph->spi = *seqhi; +} + +static void esp_output_restore_header(struct sk_buff *skb) +{ + void *tmp = ESP_SKB_CB(skb)->tmp; + struct esp_output_extra *extra = esp_tmp_extra(tmp); + + esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff - + sizeof(__be32)); +} + +static struct ip_esp_hdr *esp_output_set_extra(struct sk_buff *skb, + struct xfrm_state *x, + struct ip_esp_hdr *esph, + struct esp_output_extra *extra) +{ + /* For ESN we move the header forward by 4 bytes to + * accommodate the high bits. We will move it back after + * encryption. 
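+ * Concretely: the high 32 bits are written where the SPI normally sits, + * the SPI itself is written four bytes earlier, and the four clobbered + * bytes are saved in extra->seqhi; esp_output_done_esn() undoes the + * shuffle via esp_output_restore_header() once the request completes.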
+ */ + if ((x->props.flags & XFRM_STATE_ESN)) { + __u32 seqhi; + struct xfrm_offload *xo = xfrm_offload(skb); + + if (xo) + seqhi = xo->seq.hi; + else + seqhi = XFRM_SKB_CB(skb)->seq.output.hi; + + extra->esphoff = (unsigned char *)esph - + skb_transport_header(skb); + esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4); + extra->seqhi = esph->spi; + esph->seq_no = htonl(seqhi); + } + + esph->spi = x->id.spi; + + return esph; +} + +static void esp_output_done_esn(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + esp_output_restore_header(skb); + esp_output_done(base, err); +} + +static struct ip_esp_hdr *esp_output_udp_encap(struct sk_buff *skb, + int encap_type, + struct esp_info *esp, + __be16 sport, + __be16 dport) +{ + struct udphdr *uh; + __be32 *udpdata32; + unsigned int len; + + len = skb->len + esp->tailen - skb_transport_offset(skb); + if (len + sizeof(struct iphdr) > IP_MAX_MTU) + return ERR_PTR(-EMSGSIZE); + + uh = (struct udphdr *)esp->esph; + uh->source = sport; + uh->dest = dport; + uh->len = htons(len); + uh->check = 0; + + *skb_mac_header(skb) = IPPROTO_UDP; + + if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) { + udpdata32 = (__be32 *)(uh + 1); + udpdata32[0] = udpdata32[1] = 0; + return (struct ip_esp_hdr *)(udpdata32 + 2); + } + + return (struct ip_esp_hdr *)(uh + 1); +} + +#ifdef CONFIG_INET_ESPINTCP +static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x, + struct sk_buff *skb, + struct esp_info *esp) +{ + __be16 *lenp = (void *)esp->esph; + struct ip_esp_hdr *esph; + unsigned int len; + struct sock *sk; + + len = skb->len + esp->tailen - skb_transport_offset(skb); + if (len > IP_MAX_MTU) + return ERR_PTR(-EMSGSIZE); + + rcu_read_lock(); + sk = esp_find_tcp_sk(x); + rcu_read_unlock(); + + if (IS_ERR(sk)) + return ERR_CAST(sk); + + *lenp = htons(len); + esph = (struct ip_esp_hdr *)(lenp + 1); + + return esph; +} +#else +static struct ip_esp_hdr *esp_output_tcp_encap(struct xfrm_state *x, + struct sk_buff *skb, + struct esp_info *esp) +{ + return ERR_PTR(-EOPNOTSUPP); +} +#endif + +static int esp_output_encap(struct xfrm_state *x, struct sk_buff *skb, + struct esp_info *esp) +{ + struct xfrm_encap_tmpl *encap = x->encap; + struct ip_esp_hdr *esph; + __be16 sport, dport; + int encap_type; + + spin_lock_bh(&x->lock); + sport = encap->encap_sport; + dport = encap->encap_dport; + encap_type = encap->encap_type; + spin_unlock_bh(&x->lock); + + switch (encap_type) { + default: + case UDP_ENCAP_ESPINUDP: + case UDP_ENCAP_ESPINUDP_NON_IKE: + esph = esp_output_udp_encap(skb, encap_type, esp, sport, dport); + break; + case TCP_ENCAP_ESPINTCP: + esph = esp_output_tcp_encap(x, skb, esp); + break; + } + + if (IS_ERR(esph)) + return PTR_ERR(esph); + + esp->esph = esph; + + return 0; +} + +int esp_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp) +{ + u8 *tail; + int nfrags; + int esph_offset; + struct page *page; + struct sk_buff *trailer; + int tailen = esp->tailen; + + /* this is non-NULL only with TCP/UDP Encapsulation */ + if (x->encap) { + int err = esp_output_encap(x, skb, esp); + + if (err < 0) + return err; + } + + if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE || + ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE) + goto cow; + + if (!skb_cloned(skb)) { + if (tailen <= skb_tailroom(skb)) { + nfrags = 1; + trailer = skb; + tail = skb_tail_pointer(trailer); + + goto skip_cow; + } else if ((skb_shinfo(skb)->nr_frags < MAX_SKB_FRAGS) + && !skb_has_frag_list(skb)) { + int allocsize; + struct sock *sk = 
skb->sk; + struct page_frag *pfrag = &x->xfrag; + + esp->inplace = false; + + allocsize = ALIGN(tailen, L1_CACHE_BYTES); + + spin_lock_bh(&x->lock); + + if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) { + spin_unlock_bh(&x->lock); + goto cow; + } + + page = pfrag->page; + get_page(page); + + tail = page_address(page) + pfrag->offset; + + esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto); + + nfrags = skb_shinfo(skb)->nr_frags; + + __skb_fill_page_desc(skb, nfrags, page, pfrag->offset, + tailen); + skb_shinfo(skb)->nr_frags = ++nfrags; + + pfrag->offset = pfrag->offset + allocsize; + + spin_unlock_bh(&x->lock); + + nfrags++; + + skb_len_add(skb, tailen); + if (sk && sk_fullsock(sk)) + refcount_add(tailen, &sk->sk_wmem_alloc); + + goto out; + } + } + +cow: + esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb); + + nfrags = skb_cow_data(skb, tailen, &trailer); + if (nfrags < 0) + goto out; + tail = skb_tail_pointer(trailer); + esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset); + +skip_cow: + esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto); + pskb_put(skb, trailer, tailen); + +out: + return nfrags; +} +EXPORT_SYMBOL_GPL(esp_output_head); + +int esp_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp) +{ + u8 *iv; + int alen; + void *tmp; + int ivlen; + int assoclen; + int extralen; + struct page *page; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + struct aead_request *req; + struct scatterlist *sg, *dsg; + struct esp_output_extra *extra; + int err = -ENOMEM; + + assoclen = sizeof(struct ip_esp_hdr); + extralen = 0; + + if (x->props.flags & XFRM_STATE_ESN) { + extralen += sizeof(*extra); + assoclen += sizeof(__be32); + } + + aead = x->data; + alen = crypto_aead_authsize(aead); + ivlen = crypto_aead_ivsize(aead); + + tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen); + if (!tmp) + goto error; + + extra = esp_tmp_extra(tmp); + iv = esp_tmp_iv(aead, tmp, extralen); + req = esp_tmp_req(aead, iv); + sg = esp_req_sg(aead, req); + + if (esp->inplace) + dsg = sg; + else + dsg = &sg[esp->nfrags]; + + esph = esp_output_set_extra(skb, x, esp->esph, extra); + esp->esph = esph; + + sg_init_table(sg, esp->nfrags); + err = skb_to_sgvec(skb, sg, + (unsigned char *)esph - skb->data, + assoclen + ivlen + esp->clen + alen); + if (unlikely(err < 0)) + goto error_free; + + if (!esp->inplace) { + int allocsize; + struct page_frag *pfrag = &x->xfrag; + + allocsize = ALIGN(skb->data_len, L1_CACHE_BYTES); + + spin_lock_bh(&x->lock); + if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) { + spin_unlock_bh(&x->lock); + goto error_free; + } + + skb_shinfo(skb)->nr_frags = 1; + + page = pfrag->page; + get_page(page); + /* replace page frags in skb with new page */ + __skb_fill_page_desc(skb, 0, page, pfrag->offset, skb->data_len); + pfrag->offset = pfrag->offset + allocsize; + spin_unlock_bh(&x->lock); + + sg_init_table(dsg, skb_shinfo(skb)->nr_frags + 1); + err = skb_to_sgvec(skb, dsg, + (unsigned char *)esph - skb->data, + assoclen + ivlen + esp->clen + alen); + if (unlikely(err < 0)) + goto error_free; + } + + if ((x->props.flags & XFRM_STATE_ESN)) + aead_request_set_callback(req, 0, esp_output_done_esn, skb); + else + aead_request_set_callback(req, 0, esp_output_done, skb); + + aead_request_set_crypt(req, sg, dsg, ivlen + esp->clen, iv); + aead_request_set_ad(req, assoclen); + + memset(iv, 0, ivlen); + memcpy(iv + ivlen - min(ivlen, 8), (u8 *)&esp->seqno + 8 - min(ivlen, 
8), + min(ivlen, 8)); + + ESP_SKB_CB(skb)->tmp = tmp; + err = crypto_aead_encrypt(req); + + switch (err) { + case -EINPROGRESS: + goto error; + + case -ENOSPC: + err = NET_XMIT_DROP; + break; + + case 0: + if ((x->props.flags & XFRM_STATE_ESN)) + esp_output_restore_header(skb); + } + + if (sg != dsg) + esp_ssg_unref(x, tmp); + + if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) + err = esp_output_tail_tcp(x, skb); + +error_free: + kfree(tmp); +error: + return err; +} +EXPORT_SYMBOL_GPL(esp_output_tail); + +static int esp_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int alen; + int blksize; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + struct esp_info esp; + + esp.inplace = true; + + esp.proto = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_ESP; + + /* skb is pure payload to encrypt */ + + aead = x->data; + alen = crypto_aead_authsize(aead); + + esp.tfclen = 0; + if (x->tfcpad) { + struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb); + u32 padto; + + padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached)); + if (skb->len < padto) + esp.tfclen = padto - skb->len; + } + blksize = ALIGN(crypto_aead_blocksize(aead), 4); + esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize); + esp.plen = esp.clen - skb->len - esp.tfclen; + esp.tailen = esp.tfclen + esp.plen + alen; + + esp.esph = ip_esp_hdr(skb); + + esp.nfrags = esp_output_head(x, skb, &esp); + if (esp.nfrags < 0) + return esp.nfrags; + + esph = esp.esph; + esph->spi = x->id.spi; + + esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + esp.seqno = cpu_to_be64(XFRM_SKB_CB(skb)->seq.output.low + + ((u64)XFRM_SKB_CB(skb)->seq.output.hi << 32)); + + skb_push(skb, -skb_network_offset(skb)); + + return esp_output_tail(x, skb, &esp); +} + +static inline int esp_remove_trailer(struct sk_buff *skb) +{ + struct xfrm_state *x = xfrm_input_state(skb); + struct crypto_aead *aead = x->data; + int alen, hlen, elen; + int padlen, trimlen; + __wsum csumdiff; + u8 nexthdr[2]; + int ret; + + alen = crypto_aead_authsize(aead); + hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead); + elen = skb->len - hlen; + + if (skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2)) + BUG(); + + ret = -EINVAL; + padlen = nexthdr[0]; + if (padlen + 2 + alen >= elen) { + net_dbg_ratelimited("ipsec esp packet is garbage padlen=%d, elen=%d\n", + padlen + 2, elen - alen); + goto out; + } + + trimlen = alen + padlen + 2; + if (skb->ip_summed == CHECKSUM_COMPLETE) { + csumdiff = skb_checksum(skb, skb->len - trimlen, trimlen, 0); + skb->csum = csum_block_sub(skb->csum, csumdiff, + skb->len - trimlen); + } + ret = pskb_trim(skb, skb->len - trimlen); + if (unlikely(ret)) + return ret; + + ret = nexthdr[1]; + +out: + return ret; +} + +int esp_input_done2(struct sk_buff *skb, int err) +{ + const struct iphdr *iph; + struct xfrm_state *x = xfrm_input_state(skb); + struct xfrm_offload *xo = xfrm_offload(skb); + struct crypto_aead *aead = x->data; + int hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead); + int ihl; + + if (!xo || !(xo->flags & CRYPTO_DONE)) + kfree(ESP_SKB_CB(skb)->tmp); + + if (unlikely(err)) + goto out; + + err = esp_remove_trailer(skb); + if (unlikely(err < 0)) + goto out; + + iph = ip_hdr(skb); + ihl = iph->ihl * 4; + + if (x->encap) { + struct xfrm_encap_tmpl *encap = x->encap; + struct tcphdr *th = (void *)(skb_network_header(skb) + ihl); + struct udphdr *uh = (void *)(skb_network_header(skb) + ihl); + __be16 source; + + switch (x->encap->encap_type) { + case TCP_ENCAP_ESPINTCP: + source = 
th->source; + break; + case UDP_ENCAP_ESPINUDP: + case UDP_ENCAP_ESPINUDP_NON_IKE: + source = uh->source; + break; + default: + WARN_ON_ONCE(1); + err = -EINVAL; + goto out; + } + + /* + * 1) if the NAT-T peer's IP or port changed then + * advertize the change to the keying daemon. + * This is an inbound SA, so just compare + * SRC ports. + */ + if (iph->saddr != x->props.saddr.a4 || + source != encap->encap_sport) { + xfrm_address_t ipaddr; + + ipaddr.a4 = iph->saddr; + km_new_mapping(x, &ipaddr, source); + + /* XXX: perhaps add an extra + * policy check here, to see + * if we should allow or + * reject a packet from a + * different source + * address/port. + */ + } + + /* + * 2) ignore UDP/TCP checksums in case + * of NAT-T in Transport Mode, or + * perform other post-processing fixes + * as per draft-ietf-ipsec-udp-encaps-06, + * section 3.1.2 + */ + if (x->props.mode == XFRM_MODE_TRANSPORT) + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + skb_pull_rcsum(skb, hlen); + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -ihl); + + /* RFC4303: Drop dummy packets without any error */ + if (err == IPPROTO_NONE) + err = -EINVAL; + +out: + return err; +} +EXPORT_SYMBOL_GPL(esp_input_done2); + +static void esp_input_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + xfrm_input_resume(skb, esp_input_done2(skb, err)); +} + +static void esp_input_restore_header(struct sk_buff *skb) +{ + esp_restore_header(skb, 0); + __skb_pull(skb, 4); +} + +static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi) +{ + struct xfrm_state *x = xfrm_input_state(skb); + struct ip_esp_hdr *esph; + + /* For ESN we move the header forward by 4 bytes to + * accommodate the high bits. We will move it back after + * decryption. + */ + if ((x->props.flags & XFRM_STATE_ESN)) { + esph = skb_push(skb, 4); + *seqhi = esph->spi; + esph->spi = esph->seq_no; + esph->seq_no = XFRM_SKB_CB(skb)->seq.input.hi; + } +} + +static void esp_input_done_esn(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + esp_input_restore_header(skb); + esp_input_done(base, err); +} + +/* + * Note: detecting truncated vs. non-truncated authentication data is very + * expensive, so we only support truncated data, which is the recommended + * and common case. 
+ */ +static int esp_input(struct xfrm_state *x, struct sk_buff *skb) +{ + struct crypto_aead *aead = x->data; + struct aead_request *req; + struct sk_buff *trailer; + int ivlen = crypto_aead_ivsize(aead); + int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen; + int nfrags; + int assoclen; + int seqhilen; + __be32 *seqhi; + void *tmp; + u8 *iv; + struct scatterlist *sg; + int err = -EINVAL; + + if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen)) + goto out; + + if (elen <= 0) + goto out; + + assoclen = sizeof(struct ip_esp_hdr); + seqhilen = 0; + + if (x->props.flags & XFRM_STATE_ESN) { + seqhilen += sizeof(__be32); + assoclen += seqhilen; + } + + if (!skb_cloned(skb)) { + if (!skb_is_nonlinear(skb)) { + nfrags = 1; + + goto skip_cow; + } else if (!skb_has_frag_list(skb)) { + nfrags = skb_shinfo(skb)->nr_frags; + nfrags++; + + goto skip_cow; + } + } + + err = skb_cow_data(skb, 0, &trailer); + if (err < 0) + goto out; + + nfrags = err; + +skip_cow: + err = -ENOMEM; + tmp = esp_alloc_tmp(aead, nfrags, seqhilen); + if (!tmp) + goto out; + + ESP_SKB_CB(skb)->tmp = tmp; + seqhi = esp_tmp_extra(tmp); + iv = esp_tmp_iv(aead, tmp, seqhilen); + req = esp_tmp_req(aead, iv); + sg = esp_req_sg(aead, req); + + esp_input_set_header(skb, seqhi); + + sg_init_table(sg, nfrags); + err = skb_to_sgvec(skb, sg, 0, skb->len); + if (unlikely(err < 0)) { + kfree(tmp); + goto out; + } + + skb->ip_summed = CHECKSUM_NONE; + + if ((x->props.flags & XFRM_STATE_ESN)) + aead_request_set_callback(req, 0, esp_input_done_esn, skb); + else + aead_request_set_callback(req, 0, esp_input_done, skb); + + aead_request_set_crypt(req, sg, sg, elen + ivlen, iv); + aead_request_set_ad(req, assoclen); + + err = crypto_aead_decrypt(req); + if (err == -EINPROGRESS) + goto out; + + if ((x->props.flags & XFRM_STATE_ESN)) + esp_input_restore_header(skb); + + err = esp_input_done2(skb, err); + +out: + return err; +} + +static int esp4_err(struct sk_buff *skb, u32 info) +{ + struct net *net = dev_net(skb->dev); + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct ip_esp_hdr *esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2)); + struct xfrm_state *x; + + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED) + return 0; + break; + case ICMP_REDIRECT: + break; + default: + return 0; + } + + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + esph->spi, IPPROTO_ESP, AF_INET); + if (!x) + return 0; + + if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH) + ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ESP); + else + ipv4_redirect(skb, net, 0, IPPROTO_ESP); + xfrm_state_put(x); + + return 0; +} + +static void esp_destroy(struct xfrm_state *x) +{ + struct crypto_aead *aead = x->data; + + if (!aead) + return; + + crypto_free_aead(aead); +} + +static int esp_init_aead(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + char aead_name[CRYPTO_MAX_ALG_NAME]; + struct crypto_aead *aead; + int err; + + if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)", + x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + return -ENAMETOOLONG; + } + + aead = crypto_alloc_aead(aead_name, 0, 0); + err = PTR_ERR(aead); + if (IS_ERR(aead)) + goto error; + + x->data = aead; + + err = crypto_aead_setkey(aead, x->aead->alg_key, + (x->aead->alg_key_len + 7) / 8); + if (err) + goto error; + + err = crypto_aead_setauthsize(aead, x->aead->alg_icv_len / 8); + if (err) + goto error; + + return 0; + +error: + 
NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + return err; +} + +static int esp_init_authenc(struct xfrm_state *x, + struct netlink_ext_ack *extack) +{ + struct crypto_aead *aead; + struct crypto_authenc_key_param *param; + struct rtattr *rta; + char *key; + char *p; + char authenc_name[CRYPTO_MAX_ALG_NAME]; + unsigned int keylen; + int err; + + err = -ENAMETOOLONG; + + if ((x->props.flags & XFRM_STATE_ESN)) { + if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, + "%s%sauthencesn(%s,%s)%s", + x->geniv ?: "", x->geniv ? "(" : "", + x->aalg ? x->aalg->alg_name : "digest_null", + x->ealg->alg_name, + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + goto error; + } + } else { + if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, + "%s%sauthenc(%s,%s)%s", + x->geniv ?: "", x->geniv ? "(" : "", + x->aalg ? x->aalg->alg_name : "digest_null", + x->ealg->alg_name, + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + goto error; + } + } + + aead = crypto_alloc_aead(authenc_name, 0, 0); + err = PTR_ERR(aead); + if (IS_ERR(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + x->data = aead; + + keylen = (x->aalg ? (x->aalg->alg_key_len + 7) / 8 : 0) + + (x->ealg->alg_key_len + 7) / 8 + RTA_SPACE(sizeof(*param)); + err = -ENOMEM; + key = kmalloc(keylen, GFP_KERNEL); + if (!key) + goto error; + + p = key; + rta = (void *)p; + rta->rta_type = CRYPTO_AUTHENC_KEYA_PARAM; + rta->rta_len = RTA_LENGTH(sizeof(*param)); + param = RTA_DATA(rta); + p += RTA_SPACE(sizeof(*param)); + + if (x->aalg) { + struct xfrm_algo_desc *aalg_desc; + + memcpy(p, x->aalg->alg_key, (x->aalg->alg_key_len + 7) / 8); + p += (x->aalg->alg_key_len + 7) / 8; + + aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0); + BUG_ON(!aalg_desc); + + err = -EINVAL; + if (aalg_desc->uinfo.auth.icv_fullbits / 8 != + crypto_aead_authsize(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto free_key; + } + + err = crypto_aead_setauthsize( + aead, x->aalg->alg_trunc_len / 8); + if (err) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto free_key; + } + } + + param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8); + memcpy(p, x->ealg->alg_key, (x->ealg->alg_key_len + 7) / 8); + + err = crypto_aead_setkey(aead, key, keylen); + +free_key: + kfree_sensitive(key); + +error: + return err; +} + +static int esp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + struct crypto_aead *aead; + u32 align; + int err; + + x->data = NULL; + + if (x->aead) { + err = esp_init_aead(x, extack); + } else if (x->ealg) { + err = esp_init_authenc(x, extack); + } else { + NL_SET_ERR_MSG(extack, "ESP: AEAD or CRYPT must be provided"); + err = -EINVAL; + } + + if (err) + goto error; + + aead = x->data; + + x->props.header_len = sizeof(struct ip_esp_hdr) + + crypto_aead_ivsize(aead); + if (x->props.mode == XFRM_MODE_TUNNEL) + x->props.header_len += sizeof(struct iphdr); + else if (x->props.mode == XFRM_MODE_BEET && x->sel.family != AF_INET6) + x->props.header_len += IPV4_BEET_PHMAXLEN; + if (x->encap) { + struct xfrm_encap_tmpl *encap = x->encap; + + switch (encap->encap_type) { + default: + NL_SET_ERR_MSG(extack, "Unsupported encapsulation type for ESP"); + err = -EINVAL; + goto error; + case UDP_ENCAP_ESPINUDP: + x->props.header_len += sizeof(struct 
udphdr); + break; + case UDP_ENCAP_ESPINUDP_NON_IKE: + x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32); + break; +#ifdef CONFIG_INET_ESPINTCP + case TCP_ENCAP_ESPINTCP: + /* only the length field, TCP encap is done by + * the socket + */ + x->props.header_len += 2; + break; +#endif + } + } + + align = ALIGN(crypto_aead_blocksize(aead), 4); + x->props.trailer_len = align + 1 + crypto_aead_authsize(aead); + +error: + return err; +} + +static int esp4_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type esp_type = +{ + .owner = THIS_MODULE, + .proto = IPPROTO_ESP, + .flags = XFRM_TYPE_REPLAY_PROT, + .init_state = esp_init_state, + .destructor = esp_destroy, + .input = esp_input, + .output = esp_output, +}; + +static struct xfrm4_protocol esp4_protocol = { + .handler = xfrm4_rcv, + .input_handler = xfrm_input, + .cb_handler = esp4_rcv_cb, + .err_handler = esp4_err, + .priority = 0, +}; + +static int __init esp4_init(void) +{ + if (xfrm_register_type(&esp_type, AF_INET) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + if (xfrm4_protocol_register(&esp4_protocol, IPPROTO_ESP) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&esp_type, AF_INET); + return -EAGAIN; + } + return 0; +} + +static void __exit esp4_fini(void) +{ + if (xfrm4_protocol_deregister(&esp4_protocol, IPPROTO_ESP) < 0) + pr_info("%s: can't remove protocol\n", __func__); + xfrm_unregister_type(&esp_type, AF_INET); +} + +module_init(esp4_init); +module_exit(esp4_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP); diff --git a/net/ipv4/esp4_offload.c b/net/ipv4/esp4_offload.c new file mode 100644 index 000000000..ee848be59 --- /dev/null +++ b/net/ipv4/esp4_offload.c @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPV4 GSO/GRO offload support + * Linux INET implementation + * + * Copyright (C) 2016 secunet Security Networks AG + * Author: Steffen Klassert + * + * ESP GRO support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct sk_buff *esp4_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + int offset = skb_gro_offset(skb); + struct xfrm_offload *xo; + struct xfrm_state *x; + __be32 seq; + __be32 spi; + + if (!pskb_pull(skb, offset)) + return NULL; + + if (xfrm_parse_spi(skb, IPPROTO_ESP, &spi, &seq) != 0) + goto out; + + xo = xfrm_offload(skb); + if (!xo || !(xo->flags & CRYPTO_DONE)) { + struct sec_path *sp = secpath_set(skb); + + if (!sp) + goto out; + + if (sp->len == XFRM_MAX_DEPTH) + goto out_reset; + + x = xfrm_state_lookup(dev_net(skb->dev), skb->mark, + (xfrm_address_t *)&ip_hdr(skb)->daddr, + spi, IPPROTO_ESP, AF_INET); + if (!x) + goto out_reset; + + skb->mark = xfrm_smark_get(skb->mark, x); + + sp->xvec[sp->len++] = x; + sp->olen++; + + xo = xfrm_offload(skb); + if (!xo) + goto out_reset; + } + + xo->flags |= XFRM_GRO; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + XFRM_SPI_SKB_CB(skb)->family = AF_INET; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + XFRM_SPI_SKB_CB(skb)->seq = seq; + + /* We don't need to handle errors from xfrm_input, it does all + * the error handling and frees the resources on error. 
*/ + xfrm_input(skb, IPPROTO_ESP, spi, -2); + + return ERR_PTR(-EINPROGRESS); +out_reset: + secpath_reset(skb); +out: + skb_push(skb, offset); + NAPI_GRO_CB(skb)->same_flow = 0; + NAPI_GRO_CB(skb)->flush = 1; + + return NULL; +} + +static void esp4_gso_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ip_esp_hdr *esph; + struct iphdr *iph = ip_hdr(skb); + struct xfrm_offload *xo = xfrm_offload(skb); + int proto = iph->protocol; + + skb_push(skb, -skb_network_offset(skb)); + esph = ip_esp_hdr(skb); + *skb_mac_header(skb) = IPPROTO_ESP; + + esph->spi = x->id.spi; + esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + + xo->proto = proto; +} + +static struct sk_buff *xfrm4_tunnel_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + __be16 type = x->inner_mode.family == AF_INET6 ? htons(ETH_P_IPV6) + : htons(ETH_P_IP); + + return skb_eth_gso_segment(skb, features, type); +} + +static struct sk_buff *xfrm4_transport_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + const struct net_offload *ops; + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct xfrm_offload *xo = xfrm_offload(skb); + + skb->transport_header += x->props.header_len; + ops = rcu_dereference(inet_offloads[xo->proto]); + if (likely(ops && ops->callbacks.gso_segment)) + segs = ops->callbacks.gso_segment(skb, features); + + return segs; +} + +static struct sk_buff *xfrm4_beet_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + struct sk_buff *segs = ERR_PTR(-EINVAL); + const struct net_offload *ops; + u8 proto = xo->proto; + + skb->transport_header += x->props.header_len; + + if (x->sel.family != AF_INET6) { + if (proto == IPPROTO_BEETPH) { + struct ip_beet_phdr *ph = + (struct ip_beet_phdr *)skb->data; + + skb->transport_header += ph->hdrlen * 8; + proto = ph->nexthdr; + } else { + skb->transport_header -= IPV4_BEET_PHMAXLEN; + } + } else { + __be16 frag; + + skb->transport_header += + ipv6_skip_exthdr(skb, 0, &proto, &frag); + if (proto == IPPROTO_TCP) + skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV4; + } + + if (proto == IPPROTO_IPV6) + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP4; + + __skb_pull(skb, skb_transport_offset(skb)); + ops = rcu_dereference(inet_offloads[proto]); + if (likely(ops && ops->callbacks.gso_segment)) + segs = ops->callbacks.gso_segment(skb, features); + + return segs; +} + +static struct sk_buff *xfrm4_outer_mode_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + switch (x->outer_mode.encap) { + case XFRM_MODE_TUNNEL: + return xfrm4_tunnel_gso_segment(x, skb, features); + case XFRM_MODE_TRANSPORT: + return xfrm4_transport_gso_segment(x, skb, features); + case XFRM_MODE_BEET: + return xfrm4_beet_gso_segment(x, skb, features); + } + + return ERR_PTR(-EOPNOTSUPP); +} + +static struct sk_buff *esp4_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct xfrm_state *x; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + netdev_features_t esp_features = features; + struct xfrm_offload *xo = xfrm_offload(skb); + struct sec_path *sp; + + if (!xo) + return ERR_PTR(-EINVAL); + + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_ESP)) + return ERR_PTR(-EINVAL); + + sp = skb_sec_path(skb); + x = sp->xvec[sp->len - 1]; + aead = x->data; + esph = ip_esp_hdr(skb); + + if (esph->spi != x->id.spi) + return ERR_PTR(-EINVAL); + + if (!pskb_may_pull(skb, sizeof(*esph) + 
crypto_aead_ivsize(aead))) + return ERR_PTR(-EINVAL); + + __skb_pull(skb, sizeof(*esph) + crypto_aead_ivsize(aead)); + + skb->encap_hdr_csum = 1; + + if ((!(skb->dev->gso_partial_features & NETIF_F_HW_ESP) && + !(features & NETIF_F_HW_ESP)) || x->xso.dev != skb->dev) + esp_features = features & ~(NETIF_F_SG | NETIF_F_CSUM_MASK | + NETIF_F_SCTP_CRC); + else if (!(features & NETIF_F_HW_ESP_TX_CSUM) && + !(skb->dev->gso_partial_features & NETIF_F_HW_ESP_TX_CSUM)) + esp_features = features & ~(NETIF_F_CSUM_MASK | + NETIF_F_SCTP_CRC); + + xo->flags |= XFRM_GSO_SEGMENT; + + return xfrm4_outer_mode_gso_segment(x, skb, esp_features); +} + +static int esp_input_tail(struct xfrm_state *x, struct sk_buff *skb) +{ + struct crypto_aead *aead = x->data; + struct xfrm_offload *xo = xfrm_offload(skb); + + if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead))) + return -EINVAL; + + if (!(xo->flags & CRYPTO_DONE)) + skb->ip_summed = CHECKSUM_NONE; + + return esp_input_done2(skb, 0); +} + +static int esp_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_t features) +{ + int err; + int alen; + int blksize; + struct xfrm_offload *xo; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + struct esp_info esp; + bool hw_offload = true; + __u32 seq; + + esp.inplace = true; + + xo = xfrm_offload(skb); + + if (!xo) + return -EINVAL; + + if ((!(features & NETIF_F_HW_ESP) && + !(skb->dev->gso_partial_features & NETIF_F_HW_ESP)) || + x->xso.dev != skb->dev) { + xo->flags |= CRYPTO_FALLBACK; + hw_offload = false; + } + + esp.proto = xo->proto; + + /* skb is pure payload to encrypt */ + + aead = x->data; + alen = crypto_aead_authsize(aead); + + esp.tfclen = 0; + /* XXX: Add support for tfc padding here. */ + + blksize = ALIGN(crypto_aead_blocksize(aead), 4); + esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize); + esp.plen = esp.clen - skb->len - esp.tfclen; + esp.tailen = esp.tfclen + esp.plen + alen; + + esp.esph = ip_esp_hdr(skb); + + + if (!hw_offload || !skb_is_gso(skb)) { + esp.nfrags = esp_output_head(x, skb, &esp); + if (esp.nfrags < 0) + return esp.nfrags; + } + + seq = xo->seq.low; + + esph = esp.esph; + esph->spi = x->id.spi; + + skb_push(skb, -skb_network_offset(skb)); + + if (xo->flags & XFRM_GSO_SEGMENT) { + esph->seq_no = htonl(seq); + + if (!skb_is_gso(skb)) + xo->seq.low++; + else + xo->seq.low += skb_shinfo(skb)->gso_segs; + } + + if (xo->seq.low < seq) + xo->seq.hi++; + + esp.seqno = cpu_to_be64(seq + ((u64)xo->seq.hi << 32)); + + ip_hdr(skb)->tot_len = htons(skb->len); + ip_send_check(ip_hdr(skb)); + + if (hw_offload) { + if (!skb_ext_add(skb, SKB_EXT_SEC_PATH)) + return -ENOMEM; + + xo = xfrm_offload(skb); + if (!xo) + return -EINVAL; + + xo->flags |= XFRM_XMIT; + return 0; + } + + err = esp_output_tail(x, skb, &esp); + if (err) + return err; + + secpath_reset(skb); + + if (skb_needs_linearize(skb, skb->dev->features) && + __skb_linearize(skb)) + return -ENOMEM; + return 0; +} + +static const struct net_offload esp4_offload = { + .callbacks = { + .gro_receive = esp4_gro_receive, + .gso_segment = esp4_gso_segment, + }, +}; + +static const struct xfrm_type_offload esp_type_offload = { + .owner = THIS_MODULE, + .proto = IPPROTO_ESP, + .input_tail = esp_input_tail, + .xmit = esp_xmit, + .encap = esp4_gso_encap, +}; + +static int __init esp4_offload_init(void) +{ + if (xfrm_register_type_offload(&esp_type_offload, AF_INET) < 0) { + pr_info("%s: can't add xfrm type offload\n", __func__); + return -EAGAIN; + } + + return inet_add_offload(&esp4_offload, IPPROTO_ESP); 
+} + +static void __exit esp4_offload_exit(void) +{ + xfrm_unregister_type_offload(&esp_type_offload, AF_INET); + inet_del_offload(&esp4_offload, IPPROTO_ESP); +} + +module_init(esp4_offload_init); +module_exit(esp4_offload_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Steffen Klassert "); +MODULE_ALIAS_XFRM_OFFLOAD_TYPE(AF_INET, XFRM_PROTO_ESP); +MODULE_DESCRIPTION("IPV4 GSO/GRO offload support"); diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c new file mode 100644 index 000000000..390f4be7f --- /dev/null +++ b/net/ipv4/fib_frontend.c @@ -0,0 +1,1663 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * IPv4 Forwarding Information Base: FIB frontend. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef CONFIG_IP_MULTIPLE_TABLES + +static int __net_init fib4_rules_init(struct net *net) +{ + struct fib_table *local_table, *main_table; + + main_table = fib_trie_table(RT_TABLE_MAIN, NULL); + if (!main_table) + return -ENOMEM; + + local_table = fib_trie_table(RT_TABLE_LOCAL, main_table); + if (!local_table) + goto fail; + + hlist_add_head_rcu(&local_table->tb_hlist, + &net->ipv4.fib_table_hash[TABLE_LOCAL_INDEX]); + hlist_add_head_rcu(&main_table->tb_hlist, + &net->ipv4.fib_table_hash[TABLE_MAIN_INDEX]); + return 0; + +fail: + fib_free_table(main_table); + return -ENOMEM; +} +#else + +struct fib_table *fib_new_table(struct net *net, u32 id) +{ + struct fib_table *tb, *alias = NULL; + unsigned int h; + + if (id == 0) + id = RT_TABLE_MAIN; + tb = fib_get_table(net, id); + if (tb) + return tb; + + if (id == RT_TABLE_LOCAL && !net->ipv4.fib_has_custom_rules) + alias = fib_new_table(net, RT_TABLE_MAIN); + + tb = fib_trie_table(id, alias); + if (!tb) + return NULL; + + switch (id) { + case RT_TABLE_MAIN: + rcu_assign_pointer(net->ipv4.fib_main, tb); + break; + case RT_TABLE_DEFAULT: + rcu_assign_pointer(net->ipv4.fib_default, tb); + break; + default: + break; + } + + h = id & (FIB_TABLE_HASHSZ - 1); + hlist_add_head_rcu(&tb->tb_hlist, &net->ipv4.fib_table_hash[h]); + return tb; +} +EXPORT_SYMBOL_GPL(fib_new_table); + +/* caller must hold either rtnl or rcu read lock */ +struct fib_table *fib_get_table(struct net *net, u32 id) +{ + struct fib_table *tb; + struct hlist_head *head; + unsigned int h; + + if (id == 0) + id = RT_TABLE_MAIN; + h = id & (FIB_TABLE_HASHSZ - 1); + + head = &net->ipv4.fib_table_hash[h]; + hlist_for_each_entry_rcu(tb, head, tb_hlist, + lockdep_rtnl_is_held()) { + if (tb->tb_id == id) + return tb; + } + return NULL; +} +#endif /* CONFIG_IP_MULTIPLE_TABLES */ + +static void fib_replace_table(struct net *net, struct fib_table *old, + struct fib_table *new) +{ +#ifdef CONFIG_IP_MULTIPLE_TABLES + switch (new->tb_id) { + case RT_TABLE_MAIN: + rcu_assign_pointer(net->ipv4.fib_main, new); + break; + case RT_TABLE_DEFAULT: + rcu_assign_pointer(net->ipv4.fib_default, new); + break; + default: + break; + } + +#endif + /* replace the old table in the hlist */ + hlist_replace_rcu(&old->tb_hlist, &new->tb_hlist); +} + 
+int fib_unmerge(struct net *net) +{ + struct fib_table *old, *new, *main_table; + + /* attempt to fetch local table if it has been allocated */ + old = fib_get_table(net, RT_TABLE_LOCAL); + if (!old) + return 0; + + new = fib_trie_unmerge(old); + if (!new) + return -ENOMEM; + + /* table is already unmerged */ + if (new == old) + return 0; + + /* replace merged table with clean table */ + fib_replace_table(net, old, new); + fib_free_table(old); + + /* attempt to fetch main table if it has been allocated */ + main_table = fib_get_table(net, RT_TABLE_MAIN); + if (!main_table) + return 0; + + /* flush local entries from main table */ + fib_table_flush_external(main_table); + + return 0; +} + +void fib_flush(struct net *net) +{ + int flushed = 0; + unsigned int h; + + for (h = 0; h < FIB_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + struct hlist_node *tmp; + struct fib_table *tb; + + hlist_for_each_entry_safe(tb, tmp, head, tb_hlist) + flushed += fib_table_flush(net, tb, false); + } + + if (flushed) + rt_cache_flush(net); +} + +/* + * Find address type as if only "dev" was present in the system. If + * on_dev is NULL then all interfaces are taken into consideration. + */ +static inline unsigned int __inet_dev_addr_type(struct net *net, + const struct net_device *dev, + __be32 addr, u32 tb_id) +{ + struct flowi4 fl4 = { .daddr = addr }; + struct fib_result res; + unsigned int ret = RTN_BROADCAST; + struct fib_table *table; + + if (ipv4_is_zeronet(addr) || ipv4_is_lbcast(addr)) + return RTN_BROADCAST; + if (ipv4_is_multicast(addr)) + return RTN_MULTICAST; + + rcu_read_lock(); + + table = fib_get_table(net, tb_id); + if (table) { + ret = RTN_UNICAST; + if (!fib_table_lookup(table, &fl4, &res, FIB_LOOKUP_NOREF)) { + struct fib_nh_common *nhc = fib_info_nhc(res.fi, 0); + + if (!dev || dev == nhc->nhc_dev) + ret = res.type; + } + } + + rcu_read_unlock(); + return ret; +} + +unsigned int inet_addr_type_table(struct net *net, __be32 addr, u32 tb_id) +{ + return __inet_dev_addr_type(net, NULL, addr, tb_id); +} +EXPORT_SYMBOL(inet_addr_type_table); + +unsigned int inet_addr_type(struct net *net, __be32 addr) +{ + return __inet_dev_addr_type(net, NULL, addr, RT_TABLE_LOCAL); +} +EXPORT_SYMBOL(inet_addr_type); + +unsigned int inet_dev_addr_type(struct net *net, const struct net_device *dev, + __be32 addr) +{ + u32 rt_table = l3mdev_fib_table(dev) ? : RT_TABLE_LOCAL; + + return __inet_dev_addr_type(net, dev, addr, rt_table); +} +EXPORT_SYMBOL(inet_dev_addr_type); + +/* inet_addr_type with dev == NULL but using the table from a dev + * if one is associated + */ +unsigned int inet_addr_type_dev_table(struct net *net, + const struct net_device *dev, + __be32 addr) +{ + u32 rt_table = l3mdev_fib_table(dev) ? 
: RT_TABLE_LOCAL; + + return __inet_dev_addr_type(net, NULL, addr, rt_table); +} +EXPORT_SYMBOL(inet_addr_type_dev_table); + +__be32 fib_compute_spec_dst(struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + struct in_device *in_dev; + struct fib_result res; + struct rtable *rt; + struct net *net; + int scope; + + rt = skb_rtable(skb); + if ((rt->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST | RTCF_LOCAL)) == + RTCF_LOCAL) + return ip_hdr(skb)->daddr; + + in_dev = __in_dev_get_rcu(dev); + + net = dev_net(dev); + + scope = RT_SCOPE_UNIVERSE; + if (!ipv4_is_zeronet(ip_hdr(skb)->saddr)) { + bool vmark = in_dev && IN_DEV_SRC_VMARK(in_dev); + struct flowi4 fl4 = { + .flowi4_iif = LOOPBACK_IFINDEX, + .flowi4_l3mdev = l3mdev_master_ifindex_rcu(dev), + .daddr = ip_hdr(skb)->saddr, + .flowi4_tos = ip_hdr(skb)->tos & IPTOS_RT_MASK, + .flowi4_scope = scope, + .flowi4_mark = vmark ? skb->mark : 0, + }; + if (!fib_lookup(net, &fl4, &res, 0)) + return fib_result_prefsrc(net, &res); + } else { + scope = RT_SCOPE_LINK; + } + + return inet_select_addr(dev, ip_hdr(skb)->saddr, scope); +} + +bool fib_info_nh_uses_dev(struct fib_info *fi, const struct net_device *dev) +{ + bool dev_match = false; +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (unlikely(fi->nh)) { + dev_match = nexthop_uses_dev(fi->nh, dev); + } else { + int ret; + + for (ret = 0; ret < fib_info_num_path(fi); ret++) { + const struct fib_nh_common *nhc = fib_info_nhc(fi, ret); + + if (nhc_l3mdev_matches_dev(nhc, dev)) { + dev_match = true; + break; + } + } + } +#else + if (fib_info_nhc(fi, 0)->nhc_dev == dev) + dev_match = true; +#endif + + return dev_match; +} +EXPORT_SYMBOL_GPL(fib_info_nh_uses_dev); + +/* Given (packet source, input interface) and optional (dst, oif, tos): + * - (main) check, that source is valid i.e. not broadcast or our local + * address. + * - figure out what "logical" interface this packet arrived + * and calculate "specific destination" address. + * - check, that packet arrived from expected physical interface. + * called with rcu_read_lock() + */ +static int __fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, + u8 tos, int oif, struct net_device *dev, + int rpf, struct in_device *idev, u32 *itag) +{ + struct net *net = dev_net(dev); + struct flow_keys flkeys; + int ret, no_addr; + struct fib_result res; + struct flowi4 fl4; + bool dev_match; + + fl4.flowi4_oif = 0; + fl4.flowi4_l3mdev = l3mdev_master_ifindex_rcu(dev); + fl4.flowi4_iif = oif ? : LOOPBACK_IFINDEX; + fl4.daddr = src; + fl4.saddr = dst; + fl4.flowi4_tos = tos; + fl4.flowi4_scope = RT_SCOPE_UNIVERSE; + fl4.flowi4_tun_key.tun_id = 0; + fl4.flowi4_flags = 0; + fl4.flowi4_uid = sock_net_uid(net, NULL); + fl4.flowi4_multipath_hash = 0; + + no_addr = idev->ifa_list == NULL; + + fl4.flowi4_mark = IN_DEV_SRC_VMARK(idev) ? skb->mark : 0; + if (!fib4_rules_early_flow_dissect(net, skb, &fl4, &flkeys)) { + fl4.flowi4_proto = 0; + fl4.fl4_sport = 0; + fl4.fl4_dport = 0; + } else { + swap(fl4.fl4_sport, fl4.fl4_dport); + } + + if (fib_lookup(net, &fl4, &res, 0)) + goto last_resort; + if (res.type != RTN_UNICAST && + (res.type != RTN_LOCAL || !IN_DEV_ACCEPT_LOCAL(idev))) + goto e_inval; + fib_combine_itag(itag, &res); + + dev_match = fib_info_nh_uses_dev(res.fi, dev); + /* This is not common, loopback packets retain skb_dst so normally they + * would not even hit this slow path. 
+ */ + dev_match = dev_match || (res.type == RTN_LOCAL && + dev == net->loopback_dev); + if (dev_match) { + ret = FIB_RES_NHC(res)->nhc_scope >= RT_SCOPE_HOST; + return ret; + } + if (no_addr) + goto last_resort; + if (rpf == 1) + goto e_rpf; + fl4.flowi4_oif = dev->ifindex; + + ret = 0; + if (fib_lookup(net, &fl4, &res, FIB_LOOKUP_IGNORE_LINKSTATE) == 0) { + if (res.type == RTN_UNICAST) + ret = FIB_RES_NHC(res)->nhc_scope >= RT_SCOPE_HOST; + } + return ret; + +last_resort: + if (rpf) + goto e_rpf; + *itag = 0; + return 0; + +e_inval: + return -EINVAL; +e_rpf: + return -EXDEV; +} + +/* Ignore rp_filter for packets protected by IPsec. */ +int fib_validate_source(struct sk_buff *skb, __be32 src, __be32 dst, + u8 tos, int oif, struct net_device *dev, + struct in_device *idev, u32 *itag) +{ + int r = secpath_exists(skb) ? 0 : IN_DEV_RPFILTER(idev); + struct net *net = dev_net(dev); + + if (!r && !fib_num_tclassid_users(net) && + (dev->ifindex != oif || !IN_DEV_TX_REDIRECTS(idev))) { + if (IN_DEV_ACCEPT_LOCAL(idev)) + goto ok; + /* with custom local routes in place, checking local addresses + * only will be too optimistic, with custom rules, checking + * local addresses only can be too strict, e.g. due to vrf + */ + if (net->ipv4.fib_has_custom_local_routes || + fib4_has_custom_rules(net)) + goto full_check; + /* Within the same container, it is regarded as a martian source, + * and the same host but different containers are not. + */ + if (inet_lookup_ifaddr_rcu(net, src)) + return -EINVAL; + +ok: + *itag = 0; + return 0; + } + +full_check: + return __fib_validate_source(skb, src, dst, tos, oif, dev, r, idev, itag); +} + +static inline __be32 sk_extract_addr(struct sockaddr *addr) +{ + return ((struct sockaddr_in *) addr)->sin_addr.s_addr; +} + +static int put_rtax(struct nlattr *mx, int len, int type, u32 value) +{ + struct nlattr *nla; + + nla = (struct nlattr *) ((char *) mx + len); + nla->nla_type = type; + nla->nla_len = nla_attr_size(4); + *(u32 *) nla_data(nla) = value; + + return len + nla_total_size(4); +} + +static int rtentry_to_fib_config(struct net *net, int cmd, struct rtentry *rt, + struct fib_config *cfg) +{ + __be32 addr; + int plen; + + memset(cfg, 0, sizeof(*cfg)); + cfg->fc_nlinfo.nl_net = net; + + if (rt->rt_dst.sa_family != AF_INET) + return -EAFNOSUPPORT; + + /* + * Check mask for validity: + * a) it must be contiguous. + * b) destination must have all host bits clear. + * c) if application forgot to set correct family (AF_INET), + * reject request unless it is absolutely clear i.e. + * both family and mask are zero. 
+ */ + plen = 32; + addr = sk_extract_addr(&rt->rt_dst); + if (!(rt->rt_flags & RTF_HOST)) { + __be32 mask = sk_extract_addr(&rt->rt_genmask); + + if (rt->rt_genmask.sa_family != AF_INET) { + if (mask || rt->rt_genmask.sa_family) + return -EAFNOSUPPORT; + } + + if (bad_mask(mask, addr)) + return -EINVAL; + + plen = inet_mask_len(mask); + } + + cfg->fc_dst_len = plen; + cfg->fc_dst = addr; + + if (cmd != SIOCDELRT) { + cfg->fc_nlflags = NLM_F_CREATE; + cfg->fc_protocol = RTPROT_BOOT; + } + + if (rt->rt_metric) + cfg->fc_priority = rt->rt_metric - 1; + + if (rt->rt_flags & RTF_REJECT) { + cfg->fc_scope = RT_SCOPE_HOST; + cfg->fc_type = RTN_UNREACHABLE; + return 0; + } + + cfg->fc_scope = RT_SCOPE_NOWHERE; + cfg->fc_type = RTN_UNICAST; + + if (rt->rt_dev) { + char *colon; + struct net_device *dev; + char devname[IFNAMSIZ]; + + if (copy_from_user(devname, rt->rt_dev, IFNAMSIZ-1)) + return -EFAULT; + + devname[IFNAMSIZ-1] = 0; + colon = strchr(devname, ':'); + if (colon) + *colon = 0; + dev = __dev_get_by_name(net, devname); + if (!dev) + return -ENODEV; + cfg->fc_oif = dev->ifindex; + cfg->fc_table = l3mdev_fib_table(dev); + if (colon) { + const struct in_ifaddr *ifa; + struct in_device *in_dev; + + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) + return -ENODEV; + + *colon = ':'; + + rcu_read_lock(); + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (strcmp(ifa->ifa_label, devname) == 0) + break; + } + rcu_read_unlock(); + + if (!ifa) + return -ENODEV; + cfg->fc_prefsrc = ifa->ifa_local; + } + } + + addr = sk_extract_addr(&rt->rt_gateway); + if (rt->rt_gateway.sa_family == AF_INET && addr) { + unsigned int addr_type; + + cfg->fc_gw4 = addr; + cfg->fc_gw_family = AF_INET; + addr_type = inet_addr_type_table(net, addr, cfg->fc_table); + if (rt->rt_flags & RTF_GATEWAY && + addr_type == RTN_UNICAST) + cfg->fc_scope = RT_SCOPE_UNIVERSE; + } + + if (!cfg->fc_table) + cfg->fc_table = RT_TABLE_MAIN; + + if (cmd == SIOCDELRT) + return 0; + + if (rt->rt_flags & RTF_GATEWAY && !cfg->fc_gw_family) + return -EINVAL; + + if (cfg->fc_scope == RT_SCOPE_NOWHERE) + cfg->fc_scope = RT_SCOPE_LINK; + + if (rt->rt_flags & (RTF_MTU | RTF_WINDOW | RTF_IRTT)) { + struct nlattr *mx; + int len = 0; + + mx = kcalloc(3, nla_total_size(4), GFP_KERNEL); + if (!mx) + return -ENOMEM; + + if (rt->rt_flags & RTF_MTU) + len = put_rtax(mx, len, RTAX_ADVMSS, rt->rt_mtu - 40); + + if (rt->rt_flags & RTF_WINDOW) + len = put_rtax(mx, len, RTAX_WINDOW, rt->rt_window); + + if (rt->rt_flags & RTF_IRTT) + len = put_rtax(mx, len, RTAX_RTT, rt->rt_irtt << 3); + + cfg->fc_mx = mx; + cfg->fc_mx_len = len; + } + + return 0; +} + +/* + * Handle IP routing ioctl calls. 
+ * These are used to manipulate the routing tables + */ +int ip_rt_ioctl(struct net *net, unsigned int cmd, struct rtentry *rt) +{ + struct fib_config cfg; + int err; + + switch (cmd) { + case SIOCADDRT: /* Add a route */ + case SIOCDELRT: /* Delete a route */ + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + rtnl_lock(); + err = rtentry_to_fib_config(net, cmd, rt, &cfg); + if (err == 0) { + struct fib_table *tb; + + if (cmd == SIOCDELRT) { + tb = fib_get_table(net, cfg.fc_table); + if (tb) + err = fib_table_delete(net, tb, &cfg, + NULL); + else + err = -ESRCH; + } else { + tb = fib_new_table(net, cfg.fc_table); + if (tb) + err = fib_table_insert(net, tb, + &cfg, NULL); + else + err = -ENOBUFS; + } + + /* allocated by rtentry_to_fib_config() */ + kfree(cfg.fc_mx); + } + rtnl_unlock(); + return err; + } + return -EINVAL; +} + +const struct nla_policy rtm_ipv4_policy[RTA_MAX + 1] = { + [RTA_UNSPEC] = { .strict_start_type = RTA_DPORT + 1 }, + [RTA_DST] = { .type = NLA_U32 }, + [RTA_SRC] = { .type = NLA_U32 }, + [RTA_IIF] = { .type = NLA_U32 }, + [RTA_OIF] = { .type = NLA_U32 }, + [RTA_GATEWAY] = { .type = NLA_U32 }, + [RTA_PRIORITY] = { .type = NLA_U32 }, + [RTA_PREFSRC] = { .type = NLA_U32 }, + [RTA_METRICS] = { .type = NLA_NESTED }, + [RTA_MULTIPATH] = { .len = sizeof(struct rtnexthop) }, + [RTA_FLOW] = { .type = NLA_U32 }, + [RTA_ENCAP_TYPE] = { .type = NLA_U16 }, + [RTA_ENCAP] = { .type = NLA_NESTED }, + [RTA_UID] = { .type = NLA_U32 }, + [RTA_MARK] = { .type = NLA_U32 }, + [RTA_TABLE] = { .type = NLA_U32 }, + [RTA_IP_PROTO] = { .type = NLA_U8 }, + [RTA_SPORT] = { .type = NLA_U16 }, + [RTA_DPORT] = { .type = NLA_U16 }, + [RTA_NH_ID] = { .type = NLA_U32 }, +}; + +int fib_gw_from_via(struct fib_config *cfg, struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct rtvia *via; + int alen; + + if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr)) { + NL_SET_ERR_MSG(extack, "Invalid attribute length for RTA_VIA"); + return -EINVAL; + } + + via = nla_data(nla); + alen = nla_len(nla) - offsetof(struct rtvia, rtvia_addr); + + switch (via->rtvia_family) { + case AF_INET: + if (alen != sizeof(__be32)) { + NL_SET_ERR_MSG(extack, "Invalid IPv4 address in RTA_VIA"); + return -EINVAL; + } + cfg->fc_gw_family = AF_INET; + cfg->fc_gw4 = *((__be32 *)via->rtvia_addr); + break; + case AF_INET6: +#if IS_ENABLED(CONFIG_IPV6) + if (alen != sizeof(struct in6_addr)) { + NL_SET_ERR_MSG(extack, "Invalid IPv6 address in RTA_VIA"); + return -EINVAL; + } + cfg->fc_gw_family = AF_INET6; + cfg->fc_gw6 = *((struct in6_addr *)via->rtvia_addr); +#else + NL_SET_ERR_MSG(extack, "IPv6 support not enabled in kernel"); + return -EINVAL; +#endif + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported address family in RTA_VIA"); + return -EINVAL; + } + + return 0; +} + +static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, + struct nlmsghdr *nlh, struct fib_config *cfg, + struct netlink_ext_ack *extack) +{ + bool has_gw = false, has_via = false; + struct nlattr *attr; + int err, remaining; + struct rtmsg *rtm; + + err = nlmsg_validate_deprecated(nlh, sizeof(*rtm), RTA_MAX, + rtm_ipv4_policy, extack); + if (err < 0) + goto errout; + + memset(cfg, 0, sizeof(*cfg)); + + rtm = nlmsg_data(nlh); + + if (!inet_validate_dscp(rtm->rtm_tos)) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): ECN bits must be 0"); + err = -EINVAL; + goto errout; + } + cfg->fc_dscp = inet_dsfield_to_dscp(rtm->rtm_tos); + + cfg->fc_dst_len = rtm->rtm_dst_len; + cfg->fc_table = rtm->rtm_table; + cfg->fc_protocol 
= rtm->rtm_protocol; + cfg->fc_scope = rtm->rtm_scope; + cfg->fc_type = rtm->rtm_type; + cfg->fc_flags = rtm->rtm_flags; + cfg->fc_nlflags = nlh->nlmsg_flags; + + cfg->fc_nlinfo.portid = NETLINK_CB(skb).portid; + cfg->fc_nlinfo.nlh = nlh; + cfg->fc_nlinfo.nl_net = net; + + if (cfg->fc_type > RTN_MAX) { + NL_SET_ERR_MSG(extack, "Invalid route type"); + err = -EINVAL; + goto errout; + } + + nlmsg_for_each_attr(attr, nlh, sizeof(struct rtmsg), remaining) { + switch (nla_type(attr)) { + case RTA_DST: + cfg->fc_dst = nla_get_be32(attr); + break; + case RTA_OIF: + cfg->fc_oif = nla_get_u32(attr); + break; + case RTA_GATEWAY: + has_gw = true; + cfg->fc_gw4 = nla_get_be32(attr); + if (cfg->fc_gw4) + cfg->fc_gw_family = AF_INET; + break; + case RTA_VIA: + has_via = true; + err = fib_gw_from_via(cfg, attr, extack); + if (err) + goto errout; + break; + case RTA_PRIORITY: + cfg->fc_priority = nla_get_u32(attr); + break; + case RTA_PREFSRC: + cfg->fc_prefsrc = nla_get_be32(attr); + break; + case RTA_METRICS: + cfg->fc_mx = nla_data(attr); + cfg->fc_mx_len = nla_len(attr); + break; + case RTA_MULTIPATH: + err = lwtunnel_valid_encap_type_attr(nla_data(attr), + nla_len(attr), + extack); + if (err < 0) + goto errout; + cfg->fc_mp = nla_data(attr); + cfg->fc_mp_len = nla_len(attr); + break; + case RTA_FLOW: + cfg->fc_flow = nla_get_u32(attr); + break; + case RTA_TABLE: + cfg->fc_table = nla_get_u32(attr); + break; + case RTA_ENCAP: + cfg->fc_encap = attr; + break; + case RTA_ENCAP_TYPE: + cfg->fc_encap_type = nla_get_u16(attr); + err = lwtunnel_valid_encap_type(cfg->fc_encap_type, + extack); + if (err < 0) + goto errout; + break; + case RTA_NH_ID: + cfg->fc_nh_id = nla_get_u32(attr); + break; + } + } + + if (cfg->fc_nh_id) { + if (cfg->fc_oif || cfg->fc_gw_family || + cfg->fc_encap || cfg->fc_mp) { + NL_SET_ERR_MSG(extack, + "Nexthop specification and nexthop id are mutually exclusive"); + return -EINVAL; + } + } + + if (has_gw && has_via) { + NL_SET_ERR_MSG(extack, + "Nexthop configuration can not contain both GATEWAY and VIA"); + return -EINVAL; + } + + if (!cfg->fc_table) + cfg->fc_table = RT_TABLE_MAIN; + + return 0; +errout: + return err; +} + +static int inet_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct fib_config cfg; + struct fib_table *tb; + int err; + + err = rtm_to_fib_config(net, skb, nlh, &cfg, extack); + if (err < 0) + goto errout; + + if (cfg.fc_nh_id && !nexthop_find_by_id(net, cfg.fc_nh_id)) { + NL_SET_ERR_MSG(extack, "Nexthop id does not exist"); + err = -EINVAL; + goto errout; + } + + tb = fib_get_table(net, cfg.fc_table); + if (!tb) { + NL_SET_ERR_MSG(extack, "FIB table does not exist"); + err = -ESRCH; + goto errout; + } + + err = fib_table_delete(net, tb, &cfg, extack); +errout: + return err; +} + +static int inet_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct fib_config cfg; + struct fib_table *tb; + int err; + + err = rtm_to_fib_config(net, skb, nlh, &cfg, extack); + if (err < 0) + goto errout; + + tb = fib_new_table(net, cfg.fc_table); + if (!tb) { + err = -ENOBUFS; + goto errout; + } + + err = fib_table_insert(net, tb, &cfg, extack); + if (!err && cfg.fc_type == RTN_LOCAL) + net->ipv4.fib_has_custom_local_routes = true; +errout: + return err; +} + +int ip_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + struct 
netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[RTA_MAX + 1]; + struct rtmsg *rtm; + int err, i; + + ASSERT_RTNL(); + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, "Invalid header for FIB dump request"); + return -EINVAL; + } + + rtm = nlmsg_data(nlh); + if (rtm->rtm_dst_len || rtm->rtm_src_len || rtm->rtm_tos || + rtm->rtm_scope) { + NL_SET_ERR_MSG(extack, "Invalid values in header for FIB dump request"); + return -EINVAL; + } + + if (rtm->rtm_flags & ~(RTM_F_CLONED | RTM_F_PREFIX)) { + NL_SET_ERR_MSG(extack, "Invalid flags for FIB dump request"); + return -EINVAL; + } + if (rtm->rtm_flags & RTM_F_CLONED) + filter->dump_routes = false; + else + filter->dump_exceptions = false; + + filter->flags = rtm->rtm_flags; + filter->protocol = rtm->rtm_protocol; + filter->rt_type = rtm->rtm_type; + filter->table_id = rtm->rtm_table; + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= RTA_MAX; ++i) { + int ifindex; + + if (!tb[i]) + continue; + + switch (i) { + case RTA_TABLE: + filter->table_id = nla_get_u32(tb[i]); + break; + case RTA_OIF: + ifindex = nla_get_u32(tb[i]); + filter->dev = __dev_get_by_index(net, ifindex); + if (!filter->dev) + return -ENODEV; + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + if (filter->flags || filter->protocol || filter->rt_type || + filter->table_id || filter->dev) { + filter->filter_set = 1; + cb->answer_flags = NLM_F_DUMP_FILTERED; + } + + return 0; +} +EXPORT_SYMBOL_GPL(ip_valid_fib_dump_req); + +static int inet_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct fib_dump_filter filter = { .dump_routes = true, + .dump_exceptions = true }; + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + unsigned int h, s_h; + unsigned int e = 0, s_e; + struct fib_table *tb; + struct hlist_head *head; + int dumped = 0, err; + + if (cb->strict_check) { + err = ip_valid_fib_dump_req(net, nlh, &filter, cb); + if (err < 0) + return err; + } else if (nlmsg_len(nlh) >= sizeof(struct rtmsg)) { + struct rtmsg *rtm = nlmsg_data(nlh); + + filter.flags = rtm->rtm_flags & (RTM_F_PREFIX | RTM_F_CLONED); + } + + /* ipv4 does not use prefix flag */ + if (filter.flags & RTM_F_PREFIX) + return skb->len; + + if (filter.table_id) { + tb = fib_get_table(net, filter.table_id); + if (!tb) { + if (rtnl_msg_family(cb->nlh) != PF_INET) + return skb->len; + + NL_SET_ERR_MSG(cb->extack, "ipv4: FIB table does not exist"); + return -ENOENT; + } + + rcu_read_lock(); + err = fib_table_dump(tb, skb, cb, &filter); + rcu_read_unlock(); + return skb->len ? : err; + } + + s_h = cb->args[0]; + s_e = cb->args[1]; + + rcu_read_lock(); + + for (h = s_h; h < FIB_TABLE_HASHSZ; h++, s_e = 0) { + e = 0; + head = &net->ipv4.fib_table_hash[h]; + hlist_for_each_entry_rcu(tb, head, tb_hlist) { + if (e < s_e) + goto next; + if (dumped) + memset(&cb->args[2], 0, sizeof(cb->args) - + 2 * sizeof(cb->args[0])); + err = fib_table_dump(tb, skb, cb, &filter); + if (err < 0) { + if (likely(skb->len)) + goto out; + + goto out_err; + } + dumped = 1; +next: + e++; + } + } +out: + err = skb->len; +out_err: + rcu_read_unlock(); + + cb->args[1] = e; + cb->args[0] = h; + + return err; +} + +/* Prepare and feed intra-kernel routing request. + * Really, it should be netlink message, but :-( netlink + * can be not configured, so that we feed it directly + * to fib engine. 
It is legal, because all events occur + * only when netlink is already locked. + */ +static void fib_magic(int cmd, int type, __be32 dst, int dst_len, + struct in_ifaddr *ifa, u32 rt_priority) +{ + struct net *net = dev_net(ifa->ifa_dev->dev); + u32 tb_id = l3mdev_fib_table(ifa->ifa_dev->dev); + struct fib_table *tb; + struct fib_config cfg = { + .fc_protocol = RTPROT_KERNEL, + .fc_type = type, + .fc_dst = dst, + .fc_dst_len = dst_len, + .fc_priority = rt_priority, + .fc_prefsrc = ifa->ifa_local, + .fc_oif = ifa->ifa_dev->dev->ifindex, + .fc_nlflags = NLM_F_CREATE | NLM_F_APPEND, + .fc_nlinfo = { + .nl_net = net, + }, + }; + + if (!tb_id) + tb_id = (type == RTN_UNICAST) ? RT_TABLE_MAIN : RT_TABLE_LOCAL; + + tb = fib_new_table(net, tb_id); + if (!tb) + return; + + cfg.fc_table = tb->tb_id; + + if (type != RTN_LOCAL) + cfg.fc_scope = RT_SCOPE_LINK; + else + cfg.fc_scope = RT_SCOPE_HOST; + + if (cmd == RTM_NEWROUTE) + fib_table_insert(net, tb, &cfg, NULL); + else + fib_table_delete(net, tb, &cfg, NULL); +} + +void fib_add_ifaddr(struct in_ifaddr *ifa) +{ + struct in_device *in_dev = ifa->ifa_dev; + struct net_device *dev = in_dev->dev; + struct in_ifaddr *prim = ifa; + __be32 mask = ifa->ifa_mask; + __be32 addr = ifa->ifa_local; + __be32 prefix = ifa->ifa_address & mask; + + if (ifa->ifa_flags & IFA_F_SECONDARY) { + prim = inet_ifa_byprefix(in_dev, prefix, mask); + if (!prim) { + pr_warn("%s: bug: prim == NULL\n", __func__); + return; + } + } + + fib_magic(RTM_NEWROUTE, RTN_LOCAL, addr, 32, prim, 0); + + if (!(dev->flags & IFF_UP)) + return; + + /* Add broadcast address, if it is explicitly assigned. */ + if (ifa->ifa_broadcast && ifa->ifa_broadcast != htonl(0xFFFFFFFF)) { + fib_magic(RTM_NEWROUTE, RTN_BROADCAST, ifa->ifa_broadcast, 32, + prim, 0); + arp_invalidate(dev, ifa->ifa_broadcast, false); + } + + if (!ipv4_is_zeronet(prefix) && !(ifa->ifa_flags & IFA_F_SECONDARY) && + (prefix != addr || ifa->ifa_prefixlen < 32)) { + if (!(ifa->ifa_flags & IFA_F_NOPREFIXROUTE)) + fib_magic(RTM_NEWROUTE, + dev->flags & IFF_LOOPBACK ? RTN_LOCAL : RTN_UNICAST, + prefix, ifa->ifa_prefixlen, prim, + ifa->ifa_rt_priority); + + /* Add the network broadcast address, when it makes sense */ + if (ifa->ifa_prefixlen < 31) { + fib_magic(RTM_NEWROUTE, RTN_BROADCAST, prefix | ~mask, + 32, prim, 0); + arp_invalidate(dev, prefix | ~mask, false); + } + } +} + +void fib_modify_prefix_metric(struct in_ifaddr *ifa, u32 new_metric) +{ + __be32 prefix = ifa->ifa_address & ifa->ifa_mask; + struct in_device *in_dev = ifa->ifa_dev; + struct net_device *dev = in_dev->dev; + + if (!(dev->flags & IFF_UP) || + ifa->ifa_flags & (IFA_F_SECONDARY | IFA_F_NOPREFIXROUTE) || + ipv4_is_zeronet(prefix) || + (prefix == ifa->ifa_local && ifa->ifa_prefixlen == 32)) + return; + + /* add the new */ + fib_magic(RTM_NEWROUTE, + dev->flags & IFF_LOOPBACK ? RTN_LOCAL : RTN_UNICAST, + prefix, ifa->ifa_prefixlen, ifa, new_metric); + + /* delete the old */ + fib_magic(RTM_DELROUTE, + dev->flags & IFF_LOOPBACK ? RTN_LOCAL : RTN_UNICAST, + prefix, ifa->ifa_prefixlen, ifa, ifa->ifa_rt_priority); +} + +/* Delete primary or secondary address. + * Optionally, on secondary address promotion consider the addresses + * from subnet iprim as deleted, even if they are in device list. + * In this case the secondary ifa can be in device list. 
+ */ +void fib_del_ifaddr(struct in_ifaddr *ifa, struct in_ifaddr *iprim) +{ + struct in_device *in_dev = ifa->ifa_dev; + struct net_device *dev = in_dev->dev; + struct in_ifaddr *ifa1; + struct in_ifaddr *prim = ifa, *prim1 = NULL; + __be32 brd = ifa->ifa_address | ~ifa->ifa_mask; + __be32 any = ifa->ifa_address & ifa->ifa_mask; +#define LOCAL_OK 1 +#define BRD_OK 2 +#define BRD0_OK 4 +#define BRD1_OK 8 + unsigned int ok = 0; + int subnet = 0; /* Primary network */ + int gone = 1; /* Address is missing */ + int same_prefsrc = 0; /* Another primary with same IP */ + + if (ifa->ifa_flags & IFA_F_SECONDARY) { + prim = inet_ifa_byprefix(in_dev, any, ifa->ifa_mask); + if (!prim) { + /* if the device has been deleted, we don't perform + * address promotion + */ + if (!in_dev->dead) + pr_warn("%s: bug: prim == NULL\n", __func__); + return; + } + if (iprim && iprim != prim) { + pr_warn("%s: bug: iprim != prim\n", __func__); + return; + } + } else if (!ipv4_is_zeronet(any) && + (any != ifa->ifa_local || ifa->ifa_prefixlen < 32)) { + if (!(ifa->ifa_flags & IFA_F_NOPREFIXROUTE)) + fib_magic(RTM_DELROUTE, + dev->flags & IFF_LOOPBACK ? RTN_LOCAL : RTN_UNICAST, + any, ifa->ifa_prefixlen, prim, 0); + subnet = 1; + } + + if (in_dev->dead) + goto no_promotions; + + /* Deletion is more complicated than add. + * We should take care of not to delete too much :-) + * + * Scan address list to be sure that addresses are really gone. + */ + rcu_read_lock(); + in_dev_for_each_ifa_rcu(ifa1, in_dev) { + if (ifa1 == ifa) { + /* promotion, keep the IP */ + gone = 0; + continue; + } + /* Ignore IFAs from our subnet */ + if (iprim && ifa1->ifa_mask == iprim->ifa_mask && + inet_ifa_match(ifa1->ifa_address, iprim)) + continue; + + /* Ignore ifa1 if it uses different primary IP (prefsrc) */ + if (ifa1->ifa_flags & IFA_F_SECONDARY) { + /* Another address from our subnet? */ + if (ifa1->ifa_mask == prim->ifa_mask && + inet_ifa_match(ifa1->ifa_address, prim)) + prim1 = prim; + else { + /* We reached the secondaries, so + * same_prefsrc should be determined. 
+ */ + if (!same_prefsrc) + continue; + /* Search new prim1 if ifa1 is not + * using the current prim1 + */ + if (!prim1 || + ifa1->ifa_mask != prim1->ifa_mask || + !inet_ifa_match(ifa1->ifa_address, prim1)) + prim1 = inet_ifa_byprefix(in_dev, + ifa1->ifa_address, + ifa1->ifa_mask); + if (!prim1) + continue; + if (prim1->ifa_local != prim->ifa_local) + continue; + } + } else { + if (prim->ifa_local != ifa1->ifa_local) + continue; + prim1 = ifa1; + if (prim != prim1) + same_prefsrc = 1; + } + if (ifa->ifa_local == ifa1->ifa_local) + ok |= LOCAL_OK; + if (ifa->ifa_broadcast == ifa1->ifa_broadcast) + ok |= BRD_OK; + if (brd == ifa1->ifa_broadcast) + ok |= BRD1_OK; + if (any == ifa1->ifa_broadcast) + ok |= BRD0_OK; + /* primary has network specific broadcasts */ + if (prim1 == ifa1 && ifa1->ifa_prefixlen < 31) { + __be32 brd1 = ifa1->ifa_address | ~ifa1->ifa_mask; + __be32 any1 = ifa1->ifa_address & ifa1->ifa_mask; + + if (!ipv4_is_zeronet(any1)) { + if (ifa->ifa_broadcast == brd1 || + ifa->ifa_broadcast == any1) + ok |= BRD_OK; + if (brd == brd1 || brd == any1) + ok |= BRD1_OK; + if (any == brd1 || any == any1) + ok |= BRD0_OK; + } + } + } + rcu_read_unlock(); + +no_promotions: + if (!(ok & BRD_OK)) + fib_magic(RTM_DELROUTE, RTN_BROADCAST, ifa->ifa_broadcast, 32, + prim, 0); + if (subnet && ifa->ifa_prefixlen < 31) { + if (!(ok & BRD1_OK)) + fib_magic(RTM_DELROUTE, RTN_BROADCAST, brd, 32, + prim, 0); + if (!(ok & BRD0_OK)) + fib_magic(RTM_DELROUTE, RTN_BROADCAST, any, 32, + prim, 0); + } + if (!(ok & LOCAL_OK)) { + unsigned int addr_type; + + fib_magic(RTM_DELROUTE, RTN_LOCAL, ifa->ifa_local, 32, prim, 0); + + /* Check, that this local address finally disappeared. */ + addr_type = inet_addr_type_dev_table(dev_net(dev), dev, + ifa->ifa_local); + if (gone && addr_type != RTN_LOCAL) { + /* And the last, but not the least thing. + * We must flush stray FIB entries. + * + * First of all, we scan fib_info list searching + * for stray nexthop entries, then ignite fib_flush. 
+ */ + if (fib_sync_down_addr(dev, ifa->ifa_local)) + fib_flush(dev_net(dev)); + } + } +#undef LOCAL_OK +#undef BRD_OK +#undef BRD0_OK +#undef BRD1_OK +} + +static void nl_fib_lookup(struct net *net, struct fib_result_nl *frn) +{ + + struct fib_result res; + struct flowi4 fl4 = { + .flowi4_mark = frn->fl_mark, + .daddr = frn->fl_addr, + .flowi4_tos = frn->fl_tos, + .flowi4_scope = frn->fl_scope, + }; + struct fib_table *tb; + + rcu_read_lock(); + + tb = fib_get_table(net, frn->tb_id_in); + + frn->err = -ENOENT; + if (tb) { + local_bh_disable(); + + frn->tb_id = tb->tb_id; + frn->err = fib_table_lookup(tb, &fl4, &res, FIB_LOOKUP_NOREF); + + if (!frn->err) { + frn->prefixlen = res.prefixlen; + frn->nh_sel = res.nh_sel; + frn->type = res.type; + frn->scope = res.scope; + } + local_bh_enable(); + } + + rcu_read_unlock(); +} + +static void nl_fib_input(struct sk_buff *skb) +{ + struct net *net; + struct fib_result_nl *frn; + struct nlmsghdr *nlh; + u32 portid; + + net = sock_net(skb->sk); + nlh = nlmsg_hdr(skb); + if (skb->len < nlmsg_total_size(sizeof(*frn)) || + skb->len < nlh->nlmsg_len || + nlmsg_len(nlh) < sizeof(*frn)) + return; + + skb = netlink_skb_clone(skb, GFP_KERNEL); + if (!skb) + return; + nlh = nlmsg_hdr(skb); + + frn = nlmsg_data(nlh); + nl_fib_lookup(net, frn); + + portid = NETLINK_CB(skb).portid; /* netlink portid */ + NETLINK_CB(skb).portid = 0; /* from kernel */ + NETLINK_CB(skb).dst_group = 0; /* unicast */ + nlmsg_unicast(net->ipv4.fibnl, skb, portid); +} + +static int __net_init nl_fib_lookup_init(struct net *net) +{ + struct sock *sk; + struct netlink_kernel_cfg cfg = { + .input = nl_fib_input, + }; + + sk = netlink_kernel_create(net, NETLINK_FIB_LOOKUP, &cfg); + if (!sk) + return -EAFNOSUPPORT; + net->ipv4.fibnl = sk; + return 0; +} + +static void nl_fib_lookup_exit(struct net *net) +{ + netlink_kernel_release(net->ipv4.fibnl); + net->ipv4.fibnl = NULL; +} + +static void fib_disable_ip(struct net_device *dev, unsigned long event, + bool force) +{ + if (fib_sync_down_dev(dev, event, force)) + fib_flush(dev_net(dev)); + else + rt_cache_flush(dev_net(dev)); + arp_ifdown(dev); +} + +static int fib_inetaddr_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct in_ifaddr *ifa = ptr; + struct net_device *dev = ifa->ifa_dev->dev; + struct net *net = dev_net(dev); + + switch (event) { + case NETDEV_UP: + fib_add_ifaddr(ifa); +#ifdef CONFIG_IP_ROUTE_MULTIPATH + fib_sync_up(dev, RTNH_F_DEAD); +#endif + atomic_inc(&net->ipv4.dev_addr_genid); + rt_cache_flush(dev_net(dev)); + break; + case NETDEV_DOWN: + fib_del_ifaddr(ifa, NULL); + atomic_inc(&net->ipv4.dev_addr_genid); + if (!ifa->ifa_dev->ifa_list) { + /* Last address was deleted from this interface. + * Disable IP. 
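nl_fib_input() above services the NETLINK_FIB_LOOKUP protocol: userspace sends a message whose payload is a struct fib_result_nl with the query fields filled in and receives the same buffer back with tb_id, prefixlen, type, scope and err completed. A rough userspace sketch follows; fib_result_nl is not exported through UAPI, so its layout is copied by hand here and should be treated as an assumption against the 6.1 headers:

#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

/* Hand-copied mirror of the kernel's struct fib_result_nl (net/ip_fib.h). */
struct fib_result_nl {
	uint32_t fl_addr;		/* address to look up (network order) */
	uint32_t fl_mark;
	unsigned char fl_tos;
	unsigned char fl_scope;
	unsigned char tb_id_in;		/* table to query */
	unsigned char tb_id;		/* results from here on */
	unsigned char prefixlen;
	unsigned char nh_sel;
	unsigned char type;
	unsigned char scope;
	int err;
};

int main(void)
{
	struct sockaddr_nl kernel = { .nl_family = AF_NETLINK };
	struct {
		struct nlmsghdr nlh;
		struct fib_result_nl frn;
	} req;
	int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_FIB_LOOKUP);

	if (fd < 0)
		return 1;

	memset(&req, 0, sizeof(req));
	req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(req.frn));
	inet_pton(AF_INET, "192.0.2.1", &req.frn.fl_addr);
	req.frn.tb_id_in = RT_TABLE_MAIN;

	if (sendto(fd, &req, req.nlh.nlmsg_len, 0,
		   (struct sockaddr *)&kernel, sizeof(kernel)) < 0)
		return 1;
	if (recv(fd, &req, sizeof(req), 0) > 0)
		printf("err=%d table=%u prefixlen=%u type=%u scope=%u\n",
		       req.frn.err, req.frn.tb_id, req.frn.prefixlen,
		       req.frn.type, req.frn.scope);
	close(fd);
	return 0;
}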
+ */ + fib_disable_ip(dev, event, true); + } else { + rt_cache_flush(dev_net(dev)); + } + break; + } + return NOTIFY_DONE; +} + +static int fib_netdev_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netdev_notifier_changeupper_info *upper_info = ptr; + struct netdev_notifier_info_ext *info_ext = ptr; + struct in_device *in_dev; + struct net *net = dev_net(dev); + struct in_ifaddr *ifa; + unsigned int flags; + + if (event == NETDEV_UNREGISTER) { + fib_disable_ip(dev, event, true); + rt_flush_dev(dev); + return NOTIFY_DONE; + } + + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_UP: + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + fib_add_ifaddr(ifa); + } +#ifdef CONFIG_IP_ROUTE_MULTIPATH + fib_sync_up(dev, RTNH_F_DEAD); +#endif + atomic_inc(&net->ipv4.dev_addr_genid); + rt_cache_flush(net); + break; + case NETDEV_DOWN: + fib_disable_ip(dev, event, false); + break; + case NETDEV_CHANGE: + flags = dev_get_flags(dev); + if (flags & (IFF_RUNNING | IFF_LOWER_UP)) + fib_sync_up(dev, RTNH_F_LINKDOWN); + else + fib_sync_down_dev(dev, event, false); + rt_cache_flush(net); + break; + case NETDEV_CHANGEMTU: + fib_sync_mtu(dev, info_ext->ext.mtu); + rt_cache_flush(net); + break; + case NETDEV_CHANGEUPPER: + upper_info = ptr; + /* flush all routes if dev is linked to or unlinked from + * an L3 master device (e.g., VRF) + */ + if (upper_info->upper_dev && + netif_is_l3_master(upper_info->upper_dev)) + fib_disable_ip(dev, NETDEV_DOWN, true); + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block fib_inetaddr_notifier = { + .notifier_call = fib_inetaddr_event, +}; + +static struct notifier_block fib_netdev_notifier = { + .notifier_call = fib_netdev_event, +}; + +static int __net_init ip_fib_net_init(struct net *net) +{ + int err; + size_t size = sizeof(struct hlist_head) * FIB_TABLE_HASHSZ; + + err = fib4_notifier_init(net); + if (err) + return err; + +#ifdef CONFIG_IP_ROUTE_MULTIPATH + /* Default to 3-tuple */ + net->ipv4.sysctl_fib_multipath_hash_fields = + FIB_MULTIPATH_HASH_FIELD_DEFAULT_MASK; +#endif + + /* Avoid false sharing : Use at least a full cache line */ + size = max_t(size_t, size, L1_CACHE_BYTES); + + net->ipv4.fib_table_hash = kzalloc(size, GFP_KERNEL); + if (!net->ipv4.fib_table_hash) { + err = -ENOMEM; + goto err_table_hash_alloc; + } + + err = fib4_rules_init(net); + if (err < 0) + goto err_rules_init; + return 0; + +err_rules_init: + kfree(net->ipv4.fib_table_hash); +err_table_hash_alloc: + fib4_notifier_exit(net); + return err; +} + +static void ip_fib_net_exit(struct net *net) +{ + int i; + + ASSERT_RTNL(); +#ifdef CONFIG_IP_MULTIPLE_TABLES + RCU_INIT_POINTER(net->ipv4.fib_main, NULL); + RCU_INIT_POINTER(net->ipv4.fib_default, NULL); +#endif + /* Destroy the tables in reverse order to guarantee that the + * local table, ID 255, is destroyed before the main table, ID + * 254. This is necessary as the local table may contain + * references to data contained in the main table. 
+ */ + for (i = FIB_TABLE_HASHSZ - 1; i >= 0; i--) { + struct hlist_head *head = &net->ipv4.fib_table_hash[i]; + struct hlist_node *tmp; + struct fib_table *tb; + + hlist_for_each_entry_safe(tb, tmp, head, tb_hlist) { + hlist_del(&tb->tb_hlist); + fib_table_flush(net, tb, true); + fib_free_table(tb); + } + } + +#ifdef CONFIG_IP_MULTIPLE_TABLES + fib4_rules_exit(net); +#endif + + kfree(net->ipv4.fib_table_hash); + fib4_notifier_exit(net); +} + +static int __net_init fib_net_init(struct net *net) +{ + int error; + +#ifdef CONFIG_IP_ROUTE_CLASSID + atomic_set(&net->ipv4.fib_num_tclassid_users, 0); +#endif + error = ip_fib_net_init(net); + if (error < 0) + goto out; + error = nl_fib_lookup_init(net); + if (error < 0) + goto out_nlfl; + error = fib_proc_init(net); + if (error < 0) + goto out_proc; +out: + return error; + +out_proc: + nl_fib_lookup_exit(net); +out_nlfl: + rtnl_lock(); + ip_fib_net_exit(net); + rtnl_unlock(); + goto out; +} + +static void __net_exit fib_net_exit(struct net *net) +{ + fib_proc_exit(net); + nl_fib_lookup_exit(net); +} + +static void __net_exit fib_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ip_fib_net_exit(net); + + rtnl_unlock(); +} + +static struct pernet_operations fib_net_ops = { + .init = fib_net_init, + .exit = fib_net_exit, + .exit_batch = fib_net_exit_batch, +}; + +void __init ip_fib_init(void) +{ + fib_trie_init(); + + register_pernet_subsys(&fib_net_ops); + + register_netdevice_notifier(&fib_netdev_notifier); + register_inetaddr_notifier(&fib_inetaddr_notifier); + + rtnl_register(PF_INET, RTM_NEWROUTE, inet_rtm_newroute, NULL, 0); + rtnl_register(PF_INET, RTM_DELROUTE, inet_rtm_delroute, NULL, 0); + rtnl_register(PF_INET, RTM_GETROUTE, NULL, inet_dump_fib, 0); +} diff --git a/net/ipv4/fib_lookup.h b/net/ipv4/fib_lookup.h new file mode 100644 index 000000000..f9b9e26c3 --- /dev/null +++ b/net/ipv4/fib_lookup.h @@ -0,0 +1,63 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _FIB_LOOKUP_H +#define _FIB_LOOKUP_H + +#include +#include +#include +#include +#include + +struct fib_alias { + struct hlist_node fa_list; + struct fib_info *fa_info; + dscp_t fa_dscp; + u8 fa_type; + u8 fa_state; + u8 fa_slen; + u32 tb_id; + s16 fa_default; + u8 offload; + u8 trap; + u8 offload_failed; + struct rcu_head rcu; +}; + +#define FA_S_ACCESSED 0x01 + +/* Don't write on fa_state unless needed, to keep it shared on all cpus */ +static inline void fib_alias_accessed(struct fib_alias *fa) +{ + if (!(fa->fa_state & FA_S_ACCESSED)) + fa->fa_state |= FA_S_ACCESSED; +} + +/* Exported by fib_semantics.c */ +void fib_release_info(struct fib_info *); +struct fib_info *fib_create_info(struct fib_config *cfg, + struct netlink_ext_ack *extack); +int fib_nh_match(struct net *net, struct fib_config *cfg, struct fib_info *fi, + struct netlink_ext_ack *extack); +bool fib_metrics_match(struct fib_config *cfg, struct fib_info *fi); +int fib_dump_info(struct sk_buff *skb, u32 pid, u32 seq, int event, + const struct fib_rt_info *fri, unsigned int flags); +void rtmsg_fib(int event, __be32 key, struct fib_alias *fa, int dst_len, + u32 tb_id, const struct nl_info *info, unsigned int nlm_flags); +size_t fib_nlmsg_size(struct fib_info *fi); + +static inline void fib_result_assign(struct fib_result *res, + struct fib_info *fi) +{ + /* we used to play games with refcounts, but we now use RCU */ + res->fi = fi; + res->nhc = fib_info_nhc(fi, 0); +} + +struct fib_prop { + int error; + u8 scope; +}; + +extern const 
struct fib_prop fib_props[RTN_MAX + 1]; + +#endif /* _FIB_LOOKUP_H */ diff --git a/net/ipv4/fib_notifier.c b/net/ipv4/fib_notifier.c new file mode 100644 index 000000000..0e23ade74 --- /dev/null +++ b/net/ipv4/fib_notifier.c @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include + +int call_fib4_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + info->family = AF_INET; + return call_fib_notifier(nb, event_type, info); +} + +int call_fib4_notifiers(struct net *net, enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + ASSERT_RTNL(); + + info->family = AF_INET; + net->ipv4.fib_seq++; + return call_fib_notifiers(net, event_type, info); +} + +static unsigned int fib4_seq_read(struct net *net) +{ + ASSERT_RTNL(); + + return net->ipv4.fib_seq + fib4_rules_seq_read(net); +} + +static int fib4_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + int err; + + err = fib4_rules_dump(net, nb, extack); + if (err) + return err; + + return fib_notify(net, nb, extack); +} + +static const struct fib_notifier_ops fib4_notifier_ops_template = { + .family = AF_INET, + .fib_seq_read = fib4_seq_read, + .fib_dump = fib4_dump, + .owner = THIS_MODULE, +}; + +int __net_init fib4_notifier_init(struct net *net) +{ + struct fib_notifier_ops *ops; + + net->ipv4.fib_seq = 0; + + ops = fib_notifier_ops_register(&fib4_notifier_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + net->ipv4.notifier_ops = ops; + + return 0; +} + +void __net_exit fib4_notifier_exit(struct net *net) +{ + fib_notifier_ops_unregister(net->ipv4.notifier_ops); +} diff --git a/net/ipv4/fib_rules.c b/net/ipv4/fib_rules.c new file mode 100644 index 000000000..513f475c6 --- /dev/null +++ b/net/ipv4/fib_rules.c @@ -0,0 +1,436 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * IPv4 Forwarding Information Base: policy rules. 
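The call_fib4_notifiers() path above is what feeds in-kernel listeners (typically switch/router offload drivers) with route and rule events, including a replay of existing state when they attach. Below is a minimal sketch of such a consumer; the exact register_fib_notifier() signature and header names are assumptions taken from the surrounding 6.1 tree, not a verified recipe:

#include <linux/module.h>
#include <linux/notifier.h>
#include <linux/atomic.h>
#include <linux/socket.h>
#include <net/net_namespace.h>
#include <net/fib_notifier.h>

/* Count IPv4 FIB events (route/rule add, del, replace, ...) as they are
 * replayed on registration and raised afterwards.  Sketch only.
 */
static atomic_t demo_fib4_events = ATOMIC_INIT(0);

static int demo_fib_event(struct notifier_block *nb, unsigned long event,
			  void *ptr)
{
	struct fib_notifier_info *info = ptr;

	if (info->family != AF_INET)	/* set by call_fib4_notifier(s) above */
		return NOTIFY_DONE;

	atomic_inc(&demo_fib4_events);
	return NOTIFY_DONE;
}

static struct notifier_block demo_fib_nb = {
	.notifier_call = demo_fib_event,
};

static int __init demo_init(void)
{
	/* Third argument: callback invoked if the initial dump must be redone. */
	return register_fib_notifier(&init_net, &demo_fib_nb, NULL, NULL);
}

static void __exit demo_exit(void)
{
	unregister_fib_notifier(&init_net, &demo_fib_nb);
}

module_init(demo_init);
module_exit(demo_exit);
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("FIB notifier sketch");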
+ * + * Authors: Alexey Kuznetsov, + * Thomas Graf + * + * Fixes: + * Rani Assaf : local_rule cannot be deleted + * Marc Boucher : routing by fwmark + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct fib4_rule { + struct fib_rule common; + u8 dst_len; + u8 src_len; + dscp_t dscp; + __be32 src; + __be32 srcmask; + __be32 dst; + __be32 dstmask; +#ifdef CONFIG_IP_ROUTE_CLASSID + u32 tclassid; +#endif +}; + +static bool fib4_rule_matchall(const struct fib_rule *rule) +{ + struct fib4_rule *r = container_of(rule, struct fib4_rule, common); + + if (r->dst_len || r->src_len || r->dscp) + return false; + return fib_rule_matchall(rule); +} + +bool fib4_rule_default(const struct fib_rule *rule) +{ + if (!fib4_rule_matchall(rule) || rule->action != FR_ACT_TO_TBL || + rule->l3mdev) + return false; + if (rule->table != RT_TABLE_LOCAL && rule->table != RT_TABLE_MAIN && + rule->table != RT_TABLE_DEFAULT) + return false; + return true; +} +EXPORT_SYMBOL_GPL(fib4_rule_default); + +int fib4_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return fib_rules_dump(net, nb, AF_INET, extack); +} + +unsigned int fib4_rules_seq_read(struct net *net) +{ + return fib_rules_seq_read(net, AF_INET); +} + +int __fib_lookup(struct net *net, struct flowi4 *flp, + struct fib_result *res, unsigned int flags) +{ + struct fib_lookup_arg arg = { + .result = res, + .flags = flags, + }; + int err; + + /* update flow if oif or iif point to device enslaved to l3mdev */ + l3mdev_update_flow(net, flowi4_to_flowi(flp)); + + err = fib_rules_lookup(net->ipv4.rules_ops, flowi4_to_flowi(flp), 0, &arg); +#ifdef CONFIG_IP_ROUTE_CLASSID + if (arg.rule) + res->tclassid = ((struct fib4_rule *)arg.rule)->tclassid; + else + res->tclassid = 0; +#endif + + if (err == -ESRCH) + err = -ENETUNREACH; + + return err; +} +EXPORT_SYMBOL_GPL(__fib_lookup); + +INDIRECT_CALLABLE_SCOPE int fib4_rule_action(struct fib_rule *rule, + struct flowi *flp, int flags, + struct fib_lookup_arg *arg) +{ + int err = -EAGAIN; + struct fib_table *tbl; + u32 tb_id; + + switch (rule->action) { + case FR_ACT_TO_TBL: + break; + + case FR_ACT_UNREACHABLE: + return -ENETUNREACH; + + case FR_ACT_PROHIBIT: + return -EACCES; + + case FR_ACT_BLACKHOLE: + default: + return -EINVAL; + } + + rcu_read_lock(); + + tb_id = fib_rule_get_table(rule, arg); + tbl = fib_get_table(rule->fr_net, tb_id); + if (tbl) + err = fib_table_lookup(tbl, &flp->u.ip4, + (struct fib_result *)arg->result, + arg->flags); + + rcu_read_unlock(); + return err; +} + +INDIRECT_CALLABLE_SCOPE bool fib4_rule_suppress(struct fib_rule *rule, + int flags, + struct fib_lookup_arg *arg) +{ + struct fib_result *result = arg->result; + struct net_device *dev = NULL; + + if (result->fi) { + struct fib_nh_common *nhc = fib_info_nhc(result->fi, 0); + + dev = nhc->nhc_dev; + } + + /* do not accept result if the route does + * not meet the required prefix length + */ + if (result->prefixlen <= rule->suppress_prefixlen) + goto suppress_route; + + /* do not accept result if the route uses a device + * belonging to a forbidden interface group + */ + if (rule->suppress_ifgroup != -1 && dev && dev->group == rule->suppress_ifgroup) + goto suppress_route; + + return false; + +suppress_route: + if (!(arg->flags & FIB_LOOKUP_NOREF)) + fib_info_put(result->fi); + return true; +} + +INDIRECT_CALLABLE_SCOPE int fib4_rule_match(struct fib_rule *rule, + struct 
flowi *fl, int flags) +{ + struct fib4_rule *r = (struct fib4_rule *) rule; + struct flowi4 *fl4 = &fl->u.ip4; + __be32 daddr = fl4->daddr; + __be32 saddr = fl4->saddr; + + if (((saddr ^ r->src) & r->srcmask) || + ((daddr ^ r->dst) & r->dstmask)) + return 0; + + if (r->dscp && r->dscp != inet_dsfield_to_dscp(fl4->flowi4_tos)) + return 0; + + if (rule->ip_proto && (rule->ip_proto != fl4->flowi4_proto)) + return 0; + + if (fib_rule_port_range_set(&rule->sport_range) && + !fib_rule_port_inrange(&rule->sport_range, fl4->fl4_sport)) + return 0; + + if (fib_rule_port_range_set(&rule->dport_range) && + !fib_rule_port_inrange(&rule->dport_range, fl4->fl4_dport)) + return 0; + + return 1; +} + +static struct fib_table *fib_empty_table(struct net *net) +{ + u32 id = 1; + + while (1) { + if (!fib_get_table(net, id)) + return fib_new_table(net, id); + + if (id++ == RT_TABLE_MAX) + break; + } + return NULL; +} + +static int fib4_rule_configure(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + int err = -EINVAL; + struct fib4_rule *rule4 = (struct fib4_rule *) rule; + + if (!inet_validate_dscp(frh->tos)) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): ECN bits must be 0"); + goto errout; + } + /* IPv4 currently doesn't handle high order DSCP bits correctly */ + if (frh->tos & ~IPTOS_TOS_MASK) { + NL_SET_ERR_MSG(extack, "Invalid tos"); + goto errout; + } + rule4->dscp = inet_dsfield_to_dscp(frh->tos); + + /* split local/main if they are not already split */ + err = fib_unmerge(net); + if (err) + goto errout; + + if (rule->table == RT_TABLE_UNSPEC && !rule->l3mdev) { + if (rule->action == FR_ACT_TO_TBL) { + struct fib_table *table; + + table = fib_empty_table(net); + if (!table) { + err = -ENOBUFS; + goto errout; + } + + rule->table = table->tb_id; + } + } + + if (frh->src_len) + rule4->src = nla_get_in_addr(tb[FRA_SRC]); + + if (frh->dst_len) + rule4->dst = nla_get_in_addr(tb[FRA_DST]); + +#ifdef CONFIG_IP_ROUTE_CLASSID + if (tb[FRA_FLOW]) { + rule4->tclassid = nla_get_u32(tb[FRA_FLOW]); + if (rule4->tclassid) + atomic_inc(&net->ipv4.fib_num_tclassid_users); + } +#endif + + if (fib_rule_requires_fldissect(rule)) + net->ipv4.fib_rules_require_fldissect++; + + rule4->src_len = frh->src_len; + rule4->srcmask = inet_make_mask(rule4->src_len); + rule4->dst_len = frh->dst_len; + rule4->dstmask = inet_make_mask(rule4->dst_len); + + net->ipv4.fib_has_custom_rules = true; + + err = 0; +errout: + return err; +} + +static int fib4_rule_delete(struct fib_rule *rule) +{ + struct net *net = rule->fr_net; + int err; + + /* split local/main if they are not already split */ + err = fib_unmerge(net); + if (err) + goto errout; + +#ifdef CONFIG_IP_ROUTE_CLASSID + if (((struct fib4_rule *)rule)->tclassid) + atomic_dec(&net->ipv4.fib_num_tclassid_users); +#endif + net->ipv4.fib_has_custom_rules = true; + + if (net->ipv4.fib_rules_require_fldissect && + fib_rule_requires_fldissect(rule)) + net->ipv4.fib_rules_require_fldissect--; +errout: + return err; +} + +static int fib4_rule_compare(struct fib_rule *rule, struct fib_rule_hdr *frh, + struct nlattr **tb) +{ + struct fib4_rule *rule4 = (struct fib4_rule *) rule; + + if (frh->src_len && (rule4->src_len != frh->src_len)) + return 0; + + if (frh->dst_len && (rule4->dst_len != frh->dst_len)) + return 0; + + if (frh->tos && inet_dscp_to_dsfield(rule4->dscp) != frh->tos) + return 0; + +#ifdef CONFIG_IP_ROUTE_CLASSID + if (tb[FRA_FLOW] && (rule4->tclassid != 
nla_get_u32(tb[FRA_FLOW]))) + return 0; +#endif + + if (frh->src_len && (rule4->src != nla_get_in_addr(tb[FRA_SRC]))) + return 0; + + if (frh->dst_len && (rule4->dst != nla_get_in_addr(tb[FRA_DST]))) + return 0; + + return 1; +} + +static int fib4_rule_fill(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh) +{ + struct fib4_rule *rule4 = (struct fib4_rule *) rule; + + frh->dst_len = rule4->dst_len; + frh->src_len = rule4->src_len; + frh->tos = inet_dscp_to_dsfield(rule4->dscp); + + if ((rule4->dst_len && + nla_put_in_addr(skb, FRA_DST, rule4->dst)) || + (rule4->src_len && + nla_put_in_addr(skb, FRA_SRC, rule4->src))) + goto nla_put_failure; +#ifdef CONFIG_IP_ROUTE_CLASSID + if (rule4->tclassid && + nla_put_u32(skb, FRA_FLOW, rule4->tclassid)) + goto nla_put_failure; +#endif + return 0; + +nla_put_failure: + return -ENOBUFS; +} + +static size_t fib4_rule_nlmsg_payload(struct fib_rule *rule) +{ + return nla_total_size(4) /* dst */ + + nla_total_size(4) /* src */ + + nla_total_size(4); /* flow */ +} + +static void fib4_rule_flush_cache(struct fib_rules_ops *ops) +{ + rt_cache_flush(ops->fro_net); +} + +static const struct fib_rules_ops __net_initconst fib4_rules_ops_template = { + .family = AF_INET, + .rule_size = sizeof(struct fib4_rule), + .addr_size = sizeof(u32), + .action = fib4_rule_action, + .suppress = fib4_rule_suppress, + .match = fib4_rule_match, + .configure = fib4_rule_configure, + .delete = fib4_rule_delete, + .compare = fib4_rule_compare, + .fill = fib4_rule_fill, + .nlmsg_payload = fib4_rule_nlmsg_payload, + .flush_cache = fib4_rule_flush_cache, + .nlgroup = RTNLGRP_IPV4_RULE, + .owner = THIS_MODULE, +}; + +static int fib_default_rules_init(struct fib_rules_ops *ops) +{ + int err; + + err = fib_default_rule_add(ops, 0, RT_TABLE_LOCAL, 0); + if (err < 0) + return err; + err = fib_default_rule_add(ops, 0x7FFE, RT_TABLE_MAIN, 0); + if (err < 0) + return err; + err = fib_default_rule_add(ops, 0x7FFF, RT_TABLE_DEFAULT, 0); + if (err < 0) + return err; + return 0; +} + +int __net_init fib4_rules_init(struct net *net) +{ + int err; + struct fib_rules_ops *ops; + + ops = fib_rules_register(&fib4_rules_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + err = fib_default_rules_init(ops); + if (err < 0) + goto fail; + net->ipv4.rules_ops = ops; + net->ipv4.fib_has_custom_rules = false; + net->ipv4.fib_rules_require_fldissect = 0; + return 0; + +fail: + /* also cleans all rules already added */ + fib_rules_unregister(ops); + return err; +} + +void __net_exit fib4_rules_exit(struct net *net) +{ + fib_rules_unregister(net->ipv4.rules_ops); +} diff --git a/net/ipv4/fib_semantics.c b/net/ipv4/fib_semantics.c new file mode 100644 index 000000000..5eb1b8d30 --- /dev/null +++ b/net/ipv4/fib_semantics.c @@ -0,0 +1,2275 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * IPv4 Forwarding Information Base: semantics. 
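fib4_rule_match() above reduces a rule's from/to selectors to two XOR-and-mask tests against the precomputed srcmask/dstmask. A standalone sketch of the source-address check (host-order values for readability; the kernel compares __be32 directly):

#include <stdint.h>
#include <stdio.h>

/* Same test as fib4_rule_match(): the flow's saddr matches the rule's
 * source selector iff (saddr ^ rule_src) has no bits set under rule_srcmask.
 */
static int rule_src_match(uint32_t saddr, uint32_t rule_src, uint32_t rule_srcmask)
{
	return ((saddr ^ rule_src) & rule_srcmask) == 0;
}

int main(void)
{
	uint32_t src = 0xC0A80000;	/* 192.168.0.0 */
	uint32_t mask = 0xFFFF0000;	/* /16, cf. inet_make_mask(16) */

	printf("%d\n", rule_src_match(0xC0A80101, src, mask));	/* 192.168.1.1 -> 1 */
	printf("%d\n", rule_src_match(0x0A000001, src, mask));	/* 10.0.0.1    -> 0 */
	return 0;
}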
+ * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "fib_lookup.h" + +static DEFINE_SPINLOCK(fib_info_lock); +static struct hlist_head *fib_info_hash; +static struct hlist_head *fib_info_laddrhash; +static unsigned int fib_info_hash_size; +static unsigned int fib_info_hash_bits; +static unsigned int fib_info_cnt; + +#define DEVINDEX_HASHBITS 8 +#define DEVINDEX_HASHSIZE (1U << DEVINDEX_HASHBITS) +static struct hlist_head fib_info_devhash[DEVINDEX_HASHSIZE]; + +/* for_nexthops and change_nexthops only used when nexthop object + * is not set in a fib_info. The logic within can reference fib_nh. + */ +#ifdef CONFIG_IP_ROUTE_MULTIPATH + +#define for_nexthops(fi) { \ + int nhsel; const struct fib_nh *nh; \ + for (nhsel = 0, nh = (fi)->fib_nh; \ + nhsel < fib_info_num_path((fi)); \ + nh++, nhsel++) + +#define change_nexthops(fi) { \ + int nhsel; struct fib_nh *nexthop_nh; \ + for (nhsel = 0, nexthop_nh = (struct fib_nh *)((fi)->fib_nh); \ + nhsel < fib_info_num_path((fi)); \ + nexthop_nh++, nhsel++) + +#else /* CONFIG_IP_ROUTE_MULTIPATH */ + +/* Hope, that gcc will optimize it to get rid of dummy loop */ + +#define for_nexthops(fi) { \ + int nhsel; const struct fib_nh *nh = (fi)->fib_nh; \ + for (nhsel = 0; nhsel < 1; nhsel++) + +#define change_nexthops(fi) { \ + int nhsel; \ + struct fib_nh *nexthop_nh = (struct fib_nh *)((fi)->fib_nh); \ + for (nhsel = 0; nhsel < 1; nhsel++) + +#endif /* CONFIG_IP_ROUTE_MULTIPATH */ + +#define endfor_nexthops(fi) } + + +const struct fib_prop fib_props[RTN_MAX + 1] = { + [RTN_UNSPEC] = { + .error = 0, + .scope = RT_SCOPE_NOWHERE, + }, + [RTN_UNICAST] = { + .error = 0, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_LOCAL] = { + .error = 0, + .scope = RT_SCOPE_HOST, + }, + [RTN_BROADCAST] = { + .error = 0, + .scope = RT_SCOPE_LINK, + }, + [RTN_ANYCAST] = { + .error = 0, + .scope = RT_SCOPE_LINK, + }, + [RTN_MULTICAST] = { + .error = 0, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_BLACKHOLE] = { + .error = -EINVAL, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_UNREACHABLE] = { + .error = -EHOSTUNREACH, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_PROHIBIT] = { + .error = -EACCES, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_THROW] = { + .error = -EAGAIN, + .scope = RT_SCOPE_UNIVERSE, + }, + [RTN_NAT] = { + .error = -EINVAL, + .scope = RT_SCOPE_NOWHERE, + }, + [RTN_XRESOLVE] = { + .error = -EINVAL, + .scope = RT_SCOPE_NOWHERE, + }, +}; + +static void rt_fibinfo_free(struct rtable __rcu **rtp) +{ + struct rtable *rt = rcu_dereference_protected(*rtp, 1); + + if (!rt) + return; + + /* Not even needed : RCU_INIT_POINTER(*rtp, NULL); + * because we waited an RCU grace period before calling + * free_fib_info_rcu() + */ + + dst_dev_put(&rt->dst); + dst_release_immediate(&rt->dst); +} + +static void free_nh_exceptions(struct fib_nh_common *nhc) +{ + struct fnhe_hash_bucket *hash; + int i; + + hash = rcu_dereference_protected(nhc->nhc_exceptions, 1); + if (!hash) + return; + for (i = 0; i < FNHE_HASH_SIZE; i++) { + struct fib_nh_exception *fnhe; + + fnhe = rcu_dereference_protected(hash[i].chain, 1); + while (fnhe) { + struct fib_nh_exception *next; + + next = rcu_dereference_protected(fnhe->fnhe_next, 1); + + 
rt_fibinfo_free(&fnhe->fnhe_rth_input); + rt_fibinfo_free(&fnhe->fnhe_rth_output); + + kfree(fnhe); + + fnhe = next; + } + } + kfree(hash); +} + +static void rt_fibinfo_free_cpus(struct rtable __rcu * __percpu *rtp) +{ + int cpu; + + if (!rtp) + return; + + for_each_possible_cpu(cpu) { + struct rtable *rt; + + rt = rcu_dereference_protected(*per_cpu_ptr(rtp, cpu), 1); + if (rt) { + dst_dev_put(&rt->dst); + dst_release_immediate(&rt->dst); + } + } + free_percpu(rtp); +} + +void fib_nh_common_release(struct fib_nh_common *nhc) +{ + netdev_put(nhc->nhc_dev, &nhc->nhc_dev_tracker); + lwtstate_put(nhc->nhc_lwtstate); + rt_fibinfo_free_cpus(nhc->nhc_pcpu_rth_output); + rt_fibinfo_free(&nhc->nhc_rth_input); + free_nh_exceptions(nhc); +} +EXPORT_SYMBOL_GPL(fib_nh_common_release); + +void fib_nh_release(struct net *net, struct fib_nh *fib_nh) +{ +#ifdef CONFIG_IP_ROUTE_CLASSID + if (fib_nh->nh_tclassid) + atomic_dec(&net->ipv4.fib_num_tclassid_users); +#endif + fib_nh_common_release(&fib_nh->nh_common); +} + +/* Release a nexthop info record */ +static void free_fib_info_rcu(struct rcu_head *head) +{ + struct fib_info *fi = container_of(head, struct fib_info, rcu); + + if (fi->nh) { + nexthop_put(fi->nh); + } else { + change_nexthops(fi) { + fib_nh_release(fi->fib_net, nexthop_nh); + } endfor_nexthops(fi); + } + + ip_fib_metrics_put(fi->fib_metrics); + + kfree(fi); +} + +void free_fib_info(struct fib_info *fi) +{ + if (fi->fib_dead == 0) { + pr_warn("Freeing alive fib_info %p\n", fi); + return; + } + + call_rcu(&fi->rcu, free_fib_info_rcu); +} +EXPORT_SYMBOL_GPL(free_fib_info); + +void fib_release_info(struct fib_info *fi) +{ + spin_lock_bh(&fib_info_lock); + if (fi && refcount_dec_and_test(&fi->fib_treeref)) { + hlist_del(&fi->fib_hash); + + /* Paired with READ_ONCE() in fib_create_info(). */ + WRITE_ONCE(fib_info_cnt, fib_info_cnt - 1); + + if (fi->fib_prefsrc) + hlist_del(&fi->fib_lhash); + if (fi->nh) { + list_del(&fi->nh_list); + } else { + change_nexthops(fi) { + if (!nexthop_nh->fib_nh_dev) + continue; + hlist_del(&nexthop_nh->nh_hash); + } endfor_nexthops(fi) + } + /* Paired with READ_ONCE() from fib_table_lookup() */ + WRITE_ONCE(fi->fib_dead, 1); + fib_info_put(fi); + } + spin_unlock_bh(&fib_info_lock); +} + +static inline int nh_comp(struct fib_info *fi, struct fib_info *ofi) +{ + const struct fib_nh *onh; + + if (fi->nh || ofi->nh) + return nexthop_cmp(fi->nh, ofi->nh) ? 
0 : -1; + + if (ofi->fib_nhs == 0) + return 0; + + for_nexthops(fi) { + onh = fib_info_nh(ofi, nhsel); + + if (nh->fib_nh_oif != onh->fib_nh_oif || + nh->fib_nh_gw_family != onh->fib_nh_gw_family || + nh->fib_nh_scope != onh->fib_nh_scope || +#ifdef CONFIG_IP_ROUTE_MULTIPATH + nh->fib_nh_weight != onh->fib_nh_weight || +#endif +#ifdef CONFIG_IP_ROUTE_CLASSID + nh->nh_tclassid != onh->nh_tclassid || +#endif + lwtunnel_cmp_encap(nh->fib_nh_lws, onh->fib_nh_lws) || + ((nh->fib_nh_flags ^ onh->fib_nh_flags) & ~RTNH_COMPARE_MASK)) + return -1; + + if (nh->fib_nh_gw_family == AF_INET && + nh->fib_nh_gw4 != onh->fib_nh_gw4) + return -1; + + if (nh->fib_nh_gw_family == AF_INET6 && + ipv6_addr_cmp(&nh->fib_nh_gw6, &onh->fib_nh_gw6)) + return -1; + } endfor_nexthops(fi); + return 0; +} + +static inline unsigned int fib_devindex_hashfn(unsigned int val) +{ + return hash_32(val, DEVINDEX_HASHBITS); +} + +static struct hlist_head * +fib_info_devhash_bucket(const struct net_device *dev) +{ + u32 val = net_hash_mix(dev_net(dev)) ^ dev->ifindex; + + return &fib_info_devhash[fib_devindex_hashfn(val)]; +} + +static unsigned int fib_info_hashfn_1(int init_val, u8 protocol, u8 scope, + u32 prefsrc, u32 priority) +{ + unsigned int val = init_val; + + val ^= (protocol << 8) | scope; + val ^= prefsrc; + val ^= priority; + + return val; +} + +static unsigned int fib_info_hashfn_result(unsigned int val) +{ + unsigned int mask = (fib_info_hash_size - 1); + + return (val ^ (val >> 7) ^ (val >> 12)) & mask; +} + +static inline unsigned int fib_info_hashfn(struct fib_info *fi) +{ + unsigned int val; + + val = fib_info_hashfn_1(fi->fib_nhs, fi->fib_protocol, + fi->fib_scope, (__force u32)fi->fib_prefsrc, + fi->fib_priority); + + if (fi->nh) { + val ^= fib_devindex_hashfn(fi->nh->id); + } else { + for_nexthops(fi) { + val ^= fib_devindex_hashfn(nh->fib_nh_oif); + } endfor_nexthops(fi) + } + + return fib_info_hashfn_result(val); +} + +/* no metrics, only nexthop id */ +static struct fib_info *fib_find_info_nh(struct net *net, + const struct fib_config *cfg) +{ + struct hlist_head *head; + struct fib_info *fi; + unsigned int hash; + + hash = fib_info_hashfn_1(fib_devindex_hashfn(cfg->fc_nh_id), + cfg->fc_protocol, cfg->fc_scope, + (__force u32)cfg->fc_prefsrc, + cfg->fc_priority); + hash = fib_info_hashfn_result(hash); + head = &fib_info_hash[hash]; + + hlist_for_each_entry(fi, head, fib_hash) { + if (!net_eq(fi->fib_net, net)) + continue; + if (!fi->nh || fi->nh->id != cfg->fc_nh_id) + continue; + if (cfg->fc_protocol == fi->fib_protocol && + cfg->fc_scope == fi->fib_scope && + cfg->fc_prefsrc == fi->fib_prefsrc && + cfg->fc_priority == fi->fib_priority && + cfg->fc_type == fi->fib_type && + cfg->fc_table == fi->fib_tb_id && + !((cfg->fc_flags ^ fi->fib_flags) & ~RTNH_COMPARE_MASK)) + return fi; + } + + return NULL; +} + +static struct fib_info *fib_find_info(struct fib_info *nfi) +{ + struct hlist_head *head; + struct fib_info *fi; + unsigned int hash; + + hash = fib_info_hashfn(nfi); + head = &fib_info_hash[hash]; + + hlist_for_each_entry(fi, head, fib_hash) { + if (!net_eq(fi->fib_net, nfi->fib_net)) + continue; + if (fi->fib_nhs != nfi->fib_nhs) + continue; + if (nfi->fib_protocol == fi->fib_protocol && + nfi->fib_scope == fi->fib_scope && + nfi->fib_prefsrc == fi->fib_prefsrc && + nfi->fib_priority == fi->fib_priority && + nfi->fib_type == fi->fib_type && + nfi->fib_tb_id == fi->fib_tb_id && + memcmp(nfi->fib_metrics, fi->fib_metrics, + sizeof(u32) * RTAX_MAX) == 0 && + !((nfi->fib_flags ^ fi->fib_flags) & 
~RTNH_COMPARE_MASK) && + nh_comp(fi, nfi) == 0) + return fi; + } + + return NULL; +} + +/* Check, that the gateway is already configured. + * Used only by redirect accept routine. + */ +int ip_fib_check_default(__be32 gw, struct net_device *dev) +{ + struct hlist_head *head; + struct fib_nh *nh; + + spin_lock(&fib_info_lock); + + head = fib_info_devhash_bucket(dev); + + hlist_for_each_entry(nh, head, nh_hash) { + if (nh->fib_nh_dev == dev && + nh->fib_nh_gw4 == gw && + !(nh->fib_nh_flags & RTNH_F_DEAD)) { + spin_unlock(&fib_info_lock); + return 0; + } + } + + spin_unlock(&fib_info_lock); + + return -1; +} + +size_t fib_nlmsg_size(struct fib_info *fi) +{ + size_t payload = NLMSG_ALIGN(sizeof(struct rtmsg)) + + nla_total_size(4) /* RTA_TABLE */ + + nla_total_size(4) /* RTA_DST */ + + nla_total_size(4) /* RTA_PRIORITY */ + + nla_total_size(4) /* RTA_PREFSRC */ + + nla_total_size(TCP_CA_NAME_MAX); /* RTAX_CC_ALGO */ + unsigned int nhs = fib_info_num_path(fi); + + /* space for nested metrics */ + payload += nla_total_size((RTAX_MAX * nla_total_size(4))); + + if (fi->nh) + payload += nla_total_size(4); /* RTA_NH_ID */ + + if (nhs) { + size_t nh_encapsize = 0; + /* Also handles the special case nhs == 1 */ + + /* each nexthop is packed in an attribute */ + size_t nhsize = nla_total_size(sizeof(struct rtnexthop)); + unsigned int i; + + /* may contain flow and gateway attribute */ + nhsize += 2 * nla_total_size(4); + + /* grab encap info */ + for (i = 0; i < fib_info_num_path(fi); i++) { + struct fib_nh_common *nhc = fib_info_nhc(fi, i); + + if (nhc->nhc_lwtstate) { + /* RTA_ENCAP_TYPE */ + nh_encapsize += lwtunnel_get_encap_size( + nhc->nhc_lwtstate); + /* RTA_ENCAP */ + nh_encapsize += nla_total_size(2); + } + } + + /* all nexthops are packed in a nested attribute */ + payload += nla_total_size((nhs * nhsize) + nh_encapsize); + + } + + return payload; +} + +void rtmsg_fib(int event, __be32 key, struct fib_alias *fa, + int dst_len, u32 tb_id, const struct nl_info *info, + unsigned int nlm_flags) +{ + struct fib_rt_info fri; + struct sk_buff *skb; + u32 seq = info->nlh ? 
info->nlh->nlmsg_seq : 0; + int err = -ENOBUFS; + + skb = nlmsg_new(fib_nlmsg_size(fa->fa_info), GFP_KERNEL); + if (!skb) + goto errout; + + fri.fi = fa->fa_info; + fri.tb_id = tb_id; + fri.dst = key; + fri.dst_len = dst_len; + fri.dscp = fa->fa_dscp; + fri.type = fa->fa_type; + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = READ_ONCE(fa->offload_failed); + err = fib_dump_info(skb, info->portid, seq, event, &fri, nlm_flags); + if (err < 0) { + /* -EMSGSIZE implies BUG in fib_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, info->nl_net, info->portid, RTNLGRP_IPV4_ROUTE, + info->nlh, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(info->nl_net, RTNLGRP_IPV4_ROUTE, err); +} + +static int fib_detect_death(struct fib_info *fi, int order, + struct fib_info **last_resort, int *last_idx, + int dflt) +{ + const struct fib_nh_common *nhc = fib_info_nhc(fi, 0); + struct neighbour *n; + int state = NUD_NONE; + + if (likely(nhc->nhc_gw_family == AF_INET)) + n = neigh_lookup(&arp_tbl, &nhc->nhc_gw.ipv4, nhc->nhc_dev); + else if (nhc->nhc_gw_family == AF_INET6) + n = neigh_lookup(ipv6_stub->nd_tbl, &nhc->nhc_gw.ipv6, + nhc->nhc_dev); + else + n = NULL; + + if (n) { + state = READ_ONCE(n->nud_state); + neigh_release(n); + } else { + return 0; + } + if (state == NUD_REACHABLE) + return 0; + if ((state & NUD_VALID) && order != dflt) + return 0; + if ((state & NUD_VALID) || + (*last_idx < 0 && order > dflt && state != NUD_INCOMPLETE)) { + *last_resort = fi; + *last_idx = order; + } + return 1; +} + +int fib_nh_common_init(struct net *net, struct fib_nh_common *nhc, + struct nlattr *encap, u16 encap_type, + void *cfg, gfp_t gfp_flags, + struct netlink_ext_ack *extack) +{ + int err; + + nhc->nhc_pcpu_rth_output = alloc_percpu_gfp(struct rtable __rcu *, + gfp_flags); + if (!nhc->nhc_pcpu_rth_output) + return -ENOMEM; + + if (encap) { + struct lwtunnel_state *lwtstate; + + if (encap_type == LWTUNNEL_ENCAP_NONE) { + NL_SET_ERR_MSG(extack, "LWT encap type not specified"); + err = -EINVAL; + goto lwt_failure; + } + err = lwtunnel_build_state(net, encap_type, encap, + nhc->nhc_family, cfg, &lwtstate, + extack); + if (err) + goto lwt_failure; + + nhc->nhc_lwtstate = lwtstate_get(lwtstate); + } + + return 0; + +lwt_failure: + rt_fibinfo_free_cpus(nhc->nhc_pcpu_rth_output); + nhc->nhc_pcpu_rth_output = NULL; + return err; +} +EXPORT_SYMBOL_GPL(fib_nh_common_init); + +int fib_nh_init(struct net *net, struct fib_nh *nh, + struct fib_config *cfg, int nh_weight, + struct netlink_ext_ack *extack) +{ + int err; + + nh->fib_nh_family = AF_INET; + + err = fib_nh_common_init(net, &nh->nh_common, cfg->fc_encap, + cfg->fc_encap_type, cfg, GFP_KERNEL, extack); + if (err) + return err; + + nh->fib_nh_oif = cfg->fc_oif; + nh->fib_nh_gw_family = cfg->fc_gw_family; + if (cfg->fc_gw_family == AF_INET) + nh->fib_nh_gw4 = cfg->fc_gw4; + else if (cfg->fc_gw_family == AF_INET6) + nh->fib_nh_gw6 = cfg->fc_gw6; + + nh->fib_nh_flags = cfg->fc_flags; + +#ifdef CONFIG_IP_ROUTE_CLASSID + nh->nh_tclassid = cfg->fc_flow; + if (nh->nh_tclassid) + atomic_inc(&net->ipv4.fib_num_tclassid_users); +#endif +#ifdef CONFIG_IP_ROUTE_MULTIPATH + nh->fib_nh_weight = nh_weight; +#endif + return 0; +} + +#ifdef CONFIG_IP_ROUTE_MULTIPATH + +static int fib_count_nexthops(struct rtnexthop *rtnh, int remaining, + struct netlink_ext_ack *extack) +{ + int nhs = 0; + + while (rtnh_ok(rtnh, remaining)) { + nhs++; + rtnh = rtnh_next(rtnh, &remaining); + 
} + + /* leftover implies invalid nexthop configuration, discard it */ + if (remaining > 0) { + NL_SET_ERR_MSG(extack, + "Invalid nexthop configuration - extra data after nexthops"); + nhs = 0; + } + + return nhs; +} + +static int fib_gw_from_attr(__be32 *gw, struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + if (nla_len(nla) < sizeof(*gw)) { + NL_SET_ERR_MSG(extack, "Invalid IPv4 address in RTA_GATEWAY"); + return -EINVAL; + } + + *gw = nla_get_in_addr(nla); + + return 0; +} + +/* only called when fib_nh is integrated into fib_info */ +static int fib_get_nhs(struct fib_info *fi, struct rtnexthop *rtnh, + int remaining, struct fib_config *cfg, + struct netlink_ext_ack *extack) +{ + struct net *net = fi->fib_net; + struct fib_config fib_cfg; + struct fib_nh *nh; + int ret; + + change_nexthops(fi) { + int attrlen; + + memset(&fib_cfg, 0, sizeof(fib_cfg)); + + if (!rtnh_ok(rtnh, remaining)) { + NL_SET_ERR_MSG(extack, + "Invalid nexthop configuration - extra data after nexthop"); + return -EINVAL; + } + + if (rtnh->rtnh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)) { + NL_SET_ERR_MSG(extack, + "Invalid flags for nexthop - can not contain DEAD or LINKDOWN"); + return -EINVAL; + } + + fib_cfg.fc_flags = (cfg->fc_flags & ~0xFF) | rtnh->rtnh_flags; + fib_cfg.fc_oif = rtnh->rtnh_ifindex; + + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + struct nlattr *nla, *nlav, *attrs = rtnh_attrs(rtnh); + + nla = nla_find(attrs, attrlen, RTA_GATEWAY); + nlav = nla_find(attrs, attrlen, RTA_VIA); + if (nla && nlav) { + NL_SET_ERR_MSG(extack, + "Nexthop configuration can not contain both GATEWAY and VIA"); + return -EINVAL; + } + if (nla) { + ret = fib_gw_from_attr(&fib_cfg.fc_gw4, nla, + extack); + if (ret) + goto errout; + + if (fib_cfg.fc_gw4) + fib_cfg.fc_gw_family = AF_INET; + } else if (nlav) { + ret = fib_gw_from_via(&fib_cfg, nlav, extack); + if (ret) + goto errout; + } + + nla = nla_find(attrs, attrlen, RTA_FLOW); + if (nla) { + if (nla_len(nla) < sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid RTA_FLOW"); + return -EINVAL; + } + fib_cfg.fc_flow = nla_get_u32(nla); + } + + fib_cfg.fc_encap = nla_find(attrs, attrlen, RTA_ENCAP); + /* RTA_ENCAP_TYPE length checked in + * lwtunnel_valid_encap_type_attr + */ + nla = nla_find(attrs, attrlen, RTA_ENCAP_TYPE); + if (nla) + fib_cfg.fc_encap_type = nla_get_u16(nla); + } + + ret = fib_nh_init(net, nexthop_nh, &fib_cfg, + rtnh->rtnh_hops + 1, extack); + if (ret) + goto errout; + + rtnh = rtnh_next(rtnh, &remaining); + } endfor_nexthops(fi); + + ret = -EINVAL; + nh = fib_info_nh(fi, 0); + if (cfg->fc_oif && nh->fib_nh_oif != cfg->fc_oif) { + NL_SET_ERR_MSG(extack, + "Nexthop device index does not match RTA_OIF"); + goto errout; + } + if (cfg->fc_gw_family) { + if (cfg->fc_gw_family != nh->fib_nh_gw_family || + (cfg->fc_gw_family == AF_INET && + nh->fib_nh_gw4 != cfg->fc_gw4) || + (cfg->fc_gw_family == AF_INET6 && + ipv6_addr_cmp(&nh->fib_nh_gw6, &cfg->fc_gw6))) { + NL_SET_ERR_MSG(extack, + "Nexthop gateway does not match RTA_GATEWAY or RTA_VIA"); + goto errout; + } + } +#ifdef CONFIG_IP_ROUTE_CLASSID + if (cfg->fc_flow && nh->nh_tclassid != cfg->fc_flow) { + NL_SET_ERR_MSG(extack, + "Nexthop class id does not match RTA_FLOW"); + goto errout; + } +#endif + ret = 0; +errout: + return ret; +} + +/* only called when fib_nh is integrated into fib_info */ +static void fib_rebalance(struct fib_info *fi) +{ + int total; + int w; + + if (fib_info_num_path(fi) < 2) + return; + + total = 0; + for_nexthops(fi) { + if (nh->fib_nh_flags & RTNH_F_DEAD) + continue; + + if 
(ip_ignore_linkdown(nh->fib_nh_dev) && + nh->fib_nh_flags & RTNH_F_LINKDOWN) + continue; + + total += nh->fib_nh_weight; + } endfor_nexthops(fi); + + w = 0; + change_nexthops(fi) { + int upper_bound; + + if (nexthop_nh->fib_nh_flags & RTNH_F_DEAD) { + upper_bound = -1; + } else if (ip_ignore_linkdown(nexthop_nh->fib_nh_dev) && + nexthop_nh->fib_nh_flags & RTNH_F_LINKDOWN) { + upper_bound = -1; + } else { + w += nexthop_nh->fib_nh_weight; + upper_bound = DIV_ROUND_CLOSEST_ULL((u64)w << 31, + total) - 1; + } + + atomic_set(&nexthop_nh->fib_nh_upper_bound, upper_bound); + } endfor_nexthops(fi); +} +#else /* CONFIG_IP_ROUTE_MULTIPATH */ + +static int fib_get_nhs(struct fib_info *fi, struct rtnexthop *rtnh, + int remaining, struct fib_config *cfg, + struct netlink_ext_ack *extack) +{ + NL_SET_ERR_MSG(extack, "Multipath support not enabled in kernel"); + + return -EINVAL; +} + +#define fib_rebalance(fi) do { } while (0) + +#endif /* CONFIG_IP_ROUTE_MULTIPATH */ + +static int fib_encap_match(struct net *net, u16 encap_type, + struct nlattr *encap, + const struct fib_nh *nh, + const struct fib_config *cfg, + struct netlink_ext_ack *extack) +{ + struct lwtunnel_state *lwtstate; + int ret, result = 0; + + if (encap_type == LWTUNNEL_ENCAP_NONE) + return 0; + + ret = lwtunnel_build_state(net, encap_type, encap, AF_INET, + cfg, &lwtstate, extack); + if (!ret) { + result = lwtunnel_cmp_encap(lwtstate, nh->fib_nh_lws); + lwtstate_free(lwtstate); + } + + return result; +} + +int fib_nh_match(struct net *net, struct fib_config *cfg, struct fib_info *fi, + struct netlink_ext_ack *extack) +{ +#ifdef CONFIG_IP_ROUTE_MULTIPATH + struct rtnexthop *rtnh; + int remaining; +#endif + + if (cfg->fc_priority && cfg->fc_priority != fi->fib_priority) + return 1; + + if (cfg->fc_nh_id) { + if (fi->nh && cfg->fc_nh_id == fi->nh->id) + return 0; + return 1; + } + + if (fi->nh) { + if (cfg->fc_oif || cfg->fc_gw_family || cfg->fc_mp) + return 1; + return 0; + } + + if (cfg->fc_oif || cfg->fc_gw_family) { + struct fib_nh *nh; + + nh = fib_info_nh(fi, 0); + if (cfg->fc_encap) { + if (fib_encap_match(net, cfg->fc_encap_type, + cfg->fc_encap, nh, cfg, extack)) + return 1; + } +#ifdef CONFIG_IP_ROUTE_CLASSID + if (cfg->fc_flow && + cfg->fc_flow != nh->nh_tclassid) + return 1; +#endif + if ((cfg->fc_oif && cfg->fc_oif != nh->fib_nh_oif) || + (cfg->fc_gw_family && + cfg->fc_gw_family != nh->fib_nh_gw_family)) + return 1; + + if (cfg->fc_gw_family == AF_INET && + cfg->fc_gw4 != nh->fib_nh_gw4) + return 1; + + if (cfg->fc_gw_family == AF_INET6 && + ipv6_addr_cmp(&cfg->fc_gw6, &nh->fib_nh_gw6)) + return 1; + + return 0; + } + +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (!cfg->fc_mp) + return 0; + + rtnh = cfg->fc_mp; + remaining = cfg->fc_mp_len; + + for_nexthops(fi) { + int attrlen; + + if (!rtnh_ok(rtnh, remaining)) + return -EINVAL; + + if (rtnh->rtnh_ifindex && rtnh->rtnh_ifindex != nh->fib_nh_oif) + return 1; + + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + struct nlattr *nla, *nlav, *attrs = rtnh_attrs(rtnh); + int err; + + nla = nla_find(attrs, attrlen, RTA_GATEWAY); + nlav = nla_find(attrs, attrlen, RTA_VIA); + if (nla && nlav) { + NL_SET_ERR_MSG(extack, + "Nexthop configuration can not contain both GATEWAY and VIA"); + return -EINVAL; + } + + if (nla) { + __be32 gw; + + err = fib_gw_from_attr(&gw, nla, extack); + if (err) + return err; + + if (nh->fib_nh_gw_family != AF_INET || + gw != nh->fib_nh_gw4) + return 1; + } else if (nlav) { + struct fib_config cfg2; + + err = fib_gw_from_via(&cfg2, nlav, extack); + if (err) + 
return err; + + switch (nh->fib_nh_gw_family) { + case AF_INET: + if (cfg2.fc_gw_family != AF_INET || + cfg2.fc_gw4 != nh->fib_nh_gw4) + return 1; + break; + case AF_INET6: + if (cfg2.fc_gw_family != AF_INET6 || + ipv6_addr_cmp(&cfg2.fc_gw6, + &nh->fib_nh_gw6)) + return 1; + break; + } + } + +#ifdef CONFIG_IP_ROUTE_CLASSID + nla = nla_find(attrs, attrlen, RTA_FLOW); + if (nla) { + if (nla_len(nla) < sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid RTA_FLOW"); + return -EINVAL; + } + if (nla_get_u32(nla) != nh->nh_tclassid) + return 1; + } +#endif + } + + rtnh = rtnh_next(rtnh, &remaining); + } endfor_nexthops(fi); +#endif + return 0; +} + +bool fib_metrics_match(struct fib_config *cfg, struct fib_info *fi) +{ + struct nlattr *nla; + int remaining; + + if (!cfg->fc_mx) + return true; + + nla_for_each_attr(nla, cfg->fc_mx, cfg->fc_mx_len, remaining) { + int type = nla_type(nla); + u32 fi_val, val; + + if (!type) + continue; + if (type > RTAX_MAX) + return false; + + type = array_index_nospec(type, RTAX_MAX + 1); + if (type == RTAX_CC_ALGO) { + char tmp[TCP_CA_NAME_MAX]; + bool ecn_ca = false; + + nla_strscpy(tmp, nla, sizeof(tmp)); + val = tcp_ca_get_key_by_name(fi->fib_net, tmp, &ecn_ca); + } else { + if (nla_len(nla) != sizeof(u32)) + return false; + val = nla_get_u32(nla); + } + + fi_val = fi->fib_metrics->metrics[type - 1]; + if (type == RTAX_FEATURES) + fi_val &= ~DST_FEATURE_ECN_CA; + + if (fi_val != val) + return false; + } + + return true; +} + +static int fib_check_nh_v6_gw(struct net *net, struct fib_nh *nh, + u32 table, struct netlink_ext_ack *extack) +{ + struct fib6_config cfg = { + .fc_table = table, + .fc_flags = nh->fib_nh_flags | RTF_GATEWAY, + .fc_ifindex = nh->fib_nh_oif, + .fc_gateway = nh->fib_nh_gw6, + }; + struct fib6_nh fib6_nh = {}; + int err; + + err = ipv6_stub->fib6_nh_init(net, &fib6_nh, &cfg, GFP_KERNEL, extack); + if (!err) { + nh->fib_nh_dev = fib6_nh.fib_nh_dev; + netdev_hold(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, + GFP_KERNEL); + nh->fib_nh_oif = nh->fib_nh_dev->ifindex; + nh->fib_nh_scope = RT_SCOPE_LINK; + + ipv6_stub->fib6_nh_release(&fib6_nh); + } + + return err; +} + +/* + * Picture + * ------- + * + * Semantics of nexthop is very messy by historical reasons. + * We have to take into account, that: + * a) gateway can be actually local interface address, + * so that gatewayed route is direct. + * b) gateway must be on-link address, possibly + * described not by an ifaddr, but also by a direct route. + * c) If both gateway and interface are specified, they should not + * contradict. + * d) If we use tunnel routes, gateway could be not on-link. + * + * Attempt to reconcile all of these (alas, self-contradictory) conditions + * results in pretty ugly and hairy code with obscure logic. + * + * I chose to generalized it instead, so that the size + * of code does not increase practically, but it becomes + * much more general. + * Every prefix is assigned a "scope" value: "host" is local address, + * "link" is direct route, + * [ ... "site" ... "interior" ... ] + * and "universe" is true gateway route with global meaning. + * + * Every prefix refers to a set of "nexthop"s (gw, oif), + * where gw must have narrower scope. This recursion stops + * when gw has LOCAL scope or if "nexthop" is declared ONLINK, + * which means that gw is forced to be on link. + * + * Code is still hairy, but now it is apparently logically + * consistent and very flexible. F.e. as by-product it allows + * to co-exists in peace independent exterior and interior + * routing processes. 
+ * + * Normally it looks as following. + * + * {universe prefix} -> (gw, oif) [scope link] + * | + * |-> {link prefix} -> (gw, oif) [scope local] + * | + * |-> {local prefix} (terminal node) + */ +static int fib_check_nh_v4_gw(struct net *net, struct fib_nh *nh, u32 table, + u8 scope, struct netlink_ext_ack *extack) +{ + struct net_device *dev; + struct fib_result res; + int err = 0; + + if (nh->fib_nh_flags & RTNH_F_ONLINK) { + unsigned int addr_type; + + if (scope >= RT_SCOPE_LINK) { + NL_SET_ERR_MSG(extack, "Nexthop has invalid scope"); + return -EINVAL; + } + dev = __dev_get_by_index(net, nh->fib_nh_oif); + if (!dev) { + NL_SET_ERR_MSG(extack, "Nexthop device required for onlink"); + return -ENODEV; + } + if (!(dev->flags & IFF_UP)) { + NL_SET_ERR_MSG(extack, "Nexthop device is not up"); + return -ENETDOWN; + } + addr_type = inet_addr_type_dev_table(net, dev, nh->fib_nh_gw4); + if (addr_type != RTN_UNICAST) { + NL_SET_ERR_MSG(extack, "Nexthop has invalid gateway"); + return -EINVAL; + } + if (!netif_carrier_ok(dev)) + nh->fib_nh_flags |= RTNH_F_LINKDOWN; + nh->fib_nh_dev = dev; + netdev_hold(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + nh->fib_nh_scope = RT_SCOPE_LINK; + return 0; + } + rcu_read_lock(); + { + struct fib_table *tbl = NULL; + struct flowi4 fl4 = { + .daddr = nh->fib_nh_gw4, + .flowi4_scope = scope + 1, + .flowi4_oif = nh->fib_nh_oif, + .flowi4_iif = LOOPBACK_IFINDEX, + }; + + /* It is not necessary, but requires a bit of thinking */ + if (fl4.flowi4_scope < RT_SCOPE_LINK) + fl4.flowi4_scope = RT_SCOPE_LINK; + + if (table && table != RT_TABLE_MAIN) + tbl = fib_get_table(net, table); + + if (tbl) + err = fib_table_lookup(tbl, &fl4, &res, + FIB_LOOKUP_IGNORE_LINKSTATE | + FIB_LOOKUP_NOREF); + + /* on error or if no table given do full lookup. This + * is needed for example when nexthops are in the local + * table rather than the given table + */ + if (!tbl || err) { + err = fib_lookup(net, &fl4, &res, + FIB_LOOKUP_IGNORE_LINKSTATE); + } + + if (err) { + NL_SET_ERR_MSG(extack, "Nexthop has invalid gateway"); + goto out; + } + } + + err = -EINVAL; + if (res.type != RTN_UNICAST && res.type != RTN_LOCAL) { + NL_SET_ERR_MSG(extack, "Nexthop has invalid gateway"); + goto out; + } + nh->fib_nh_scope = res.scope; + nh->fib_nh_oif = FIB_RES_OIF(res); + nh->fib_nh_dev = dev = FIB_RES_DEV(res); + if (!dev) { + NL_SET_ERR_MSG(extack, + "No egress device for nexthop gateway"); + goto out; + } + netdev_hold(dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + if (!netif_carrier_ok(dev)) + nh->fib_nh_flags |= RTNH_F_LINKDOWN; + err = (dev->flags & IFF_UP) ? 
0 : -ENETDOWN; +out: + rcu_read_unlock(); + return err; +} + +static int fib_check_nh_nongw(struct net *net, struct fib_nh *nh, + struct netlink_ext_ack *extack) +{ + struct in_device *in_dev; + int err; + + if (nh->fib_nh_flags & (RTNH_F_PERVASIVE | RTNH_F_ONLINK)) { + NL_SET_ERR_MSG(extack, + "Invalid flags for nexthop - PERVASIVE and ONLINK can not be set"); + return -EINVAL; + } + + rcu_read_lock(); + + err = -ENODEV; + in_dev = inetdev_by_index(net, nh->fib_nh_oif); + if (!in_dev) + goto out; + err = -ENETDOWN; + if (!(in_dev->dev->flags & IFF_UP)) { + NL_SET_ERR_MSG(extack, "Device for nexthop is not up"); + goto out; + } + + nh->fib_nh_dev = in_dev->dev; + netdev_hold(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, GFP_ATOMIC); + nh->fib_nh_scope = RT_SCOPE_HOST; + if (!netif_carrier_ok(nh->fib_nh_dev)) + nh->fib_nh_flags |= RTNH_F_LINKDOWN; + err = 0; +out: + rcu_read_unlock(); + return err; +} + +int fib_check_nh(struct net *net, struct fib_nh *nh, u32 table, u8 scope, + struct netlink_ext_ack *extack) +{ + int err; + + if (nh->fib_nh_gw_family == AF_INET) + err = fib_check_nh_v4_gw(net, nh, table, scope, extack); + else if (nh->fib_nh_gw_family == AF_INET6) + err = fib_check_nh_v6_gw(net, nh, table, extack); + else + err = fib_check_nh_nongw(net, nh, extack); + + return err; +} + +static struct hlist_head * +fib_info_laddrhash_bucket(const struct net *net, __be32 val) +{ + u32 slot = hash_32(net_hash_mix(net) ^ (__force u32)val, + fib_info_hash_bits); + + return &fib_info_laddrhash[slot]; +} + +static void fib_info_hash_move(struct hlist_head *new_info_hash, + struct hlist_head *new_laddrhash, + unsigned int new_size) +{ + struct hlist_head *old_info_hash, *old_laddrhash; + unsigned int old_size = fib_info_hash_size; + unsigned int i; + + spin_lock_bh(&fib_info_lock); + old_info_hash = fib_info_hash; + old_laddrhash = fib_info_laddrhash; + fib_info_hash_size = new_size; + fib_info_hash_bits = ilog2(new_size); + + for (i = 0; i < old_size; i++) { + struct hlist_head *head = &fib_info_hash[i]; + struct hlist_node *n; + struct fib_info *fi; + + hlist_for_each_entry_safe(fi, n, head, fib_hash) { + struct hlist_head *dest; + unsigned int new_hash; + + new_hash = fib_info_hashfn(fi); + dest = &new_info_hash[new_hash]; + hlist_add_head(&fi->fib_hash, dest); + } + } + fib_info_hash = new_info_hash; + + fib_info_laddrhash = new_laddrhash; + for (i = 0; i < old_size; i++) { + struct hlist_head *lhead = &old_laddrhash[i]; + struct hlist_node *n; + struct fib_info *fi; + + hlist_for_each_entry_safe(fi, n, lhead, fib_lhash) { + struct hlist_head *ldest; + + ldest = fib_info_laddrhash_bucket(fi->fib_net, + fi->fib_prefsrc); + hlist_add_head(&fi->fib_lhash, ldest); + } + } + + spin_unlock_bh(&fib_info_lock); + + kvfree(old_info_hash); + kvfree(old_laddrhash); +} + +__be32 fib_info_update_nhc_saddr(struct net *net, struct fib_nh_common *nhc, + unsigned char scope) +{ + struct fib_nh *nh; + __be32 saddr; + + if (nhc->nhc_family != AF_INET) + return inet_select_addr(nhc->nhc_dev, 0, scope); + + nh = container_of(nhc, struct fib_nh, nh_common); + saddr = inet_select_addr(nh->fib_nh_dev, nh->fib_nh_gw4, scope); + + WRITE_ONCE(nh->nh_saddr, saddr); + WRITE_ONCE(nh->nh_saddr_genid, atomic_read(&net->ipv4.dev_addr_genid)); + + return saddr; +} + +__be32 fib_result_prefsrc(struct net *net, struct fib_result *res) +{ + struct fib_nh_common *nhc = res->nhc; + + if (res->fi->fib_prefsrc) + return res->fi->fib_prefsrc; + + if (nhc->nhc_family == AF_INET) { + struct fib_nh *nh; + + nh = container_of(nhc, struct 
fib_nh, nh_common); + if (READ_ONCE(nh->nh_saddr_genid) == + atomic_read(&net->ipv4.dev_addr_genid)) + return READ_ONCE(nh->nh_saddr); + } + + return fib_info_update_nhc_saddr(net, nhc, res->fi->fib_scope); +} + +static bool fib_valid_prefsrc(struct fib_config *cfg, __be32 fib_prefsrc) +{ + if (cfg->fc_type != RTN_LOCAL || !cfg->fc_dst || + fib_prefsrc != cfg->fc_dst) { + u32 tb_id = cfg->fc_table; + int rc; + + if (tb_id == RT_TABLE_MAIN) + tb_id = RT_TABLE_LOCAL; + + rc = inet_addr_type_table(cfg->fc_nlinfo.nl_net, + fib_prefsrc, tb_id); + + if (rc != RTN_LOCAL && tb_id != RT_TABLE_LOCAL) { + rc = inet_addr_type_table(cfg->fc_nlinfo.nl_net, + fib_prefsrc, RT_TABLE_LOCAL); + } + + if (rc != RTN_LOCAL) + return false; + } + return true; +} + +struct fib_info *fib_create_info(struct fib_config *cfg, + struct netlink_ext_ack *extack) +{ + int err; + struct fib_info *fi = NULL; + struct nexthop *nh = NULL; + struct fib_info *ofi; + int nhs = 1; + struct net *net = cfg->fc_nlinfo.nl_net; + + if (cfg->fc_type > RTN_MAX) + goto err_inval; + + /* Fast check to catch the most weird cases */ + if (fib_props[cfg->fc_type].scope > cfg->fc_scope) { + NL_SET_ERR_MSG(extack, "Invalid scope"); + goto err_inval; + } + + if (cfg->fc_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)) { + NL_SET_ERR_MSG(extack, + "Invalid rtm_flags - can not contain DEAD or LINKDOWN"); + goto err_inval; + } + + if (cfg->fc_nh_id) { + if (!cfg->fc_mx) { + fi = fib_find_info_nh(net, cfg); + if (fi) { + refcount_inc(&fi->fib_treeref); + return fi; + } + } + + nh = nexthop_find_by_id(net, cfg->fc_nh_id); + if (!nh) { + NL_SET_ERR_MSG(extack, "Nexthop id does not exist"); + goto err_inval; + } + nhs = 0; + } + +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (cfg->fc_mp) { + nhs = fib_count_nexthops(cfg->fc_mp, cfg->fc_mp_len, extack); + if (nhs == 0) + goto err_inval; + } +#endif + + err = -ENOBUFS; + + /* Paired with WRITE_ONCE() in fib_release_info() */ + if (READ_ONCE(fib_info_cnt) >= fib_info_hash_size) { + unsigned int new_size = fib_info_hash_size << 1; + struct hlist_head *new_info_hash; + struct hlist_head *new_laddrhash; + size_t bytes; + + if (!new_size) + new_size = 16; + bytes = (size_t)new_size * sizeof(struct hlist_head *); + new_info_hash = kvzalloc(bytes, GFP_KERNEL); + new_laddrhash = kvzalloc(bytes, GFP_KERNEL); + if (!new_info_hash || !new_laddrhash) { + kvfree(new_info_hash); + kvfree(new_laddrhash); + } else { + fib_info_hash_move(new_info_hash, new_laddrhash, new_size); + } + if (!fib_info_hash_size) + goto failure; + } + + fi = kzalloc(struct_size(fi, fib_nh, nhs), GFP_KERNEL); + if (!fi) + goto failure; + fi->fib_metrics = ip_fib_metrics_init(fi->fib_net, cfg->fc_mx, + cfg->fc_mx_len, extack); + if (IS_ERR(fi->fib_metrics)) { + err = PTR_ERR(fi->fib_metrics); + kfree(fi); + return ERR_PTR(err); + } + + fi->fib_net = net; + fi->fib_protocol = cfg->fc_protocol; + fi->fib_scope = cfg->fc_scope; + fi->fib_flags = cfg->fc_flags; + fi->fib_priority = cfg->fc_priority; + fi->fib_prefsrc = cfg->fc_prefsrc; + fi->fib_type = cfg->fc_type; + fi->fib_tb_id = cfg->fc_table; + + fi->fib_nhs = nhs; + if (nh) { + if (!nexthop_get(nh)) { + NL_SET_ERR_MSG(extack, "Nexthop has been deleted"); + err = -EINVAL; + } else { + err = 0; + fi->nh = nh; + } + } else { + change_nexthops(fi) { + nexthop_nh->nh_parent = fi; + } endfor_nexthops(fi) + + if (cfg->fc_mp) + err = fib_get_nhs(fi, cfg->fc_mp, cfg->fc_mp_len, cfg, + extack); + else + err = fib_nh_init(net, fi->fib_nh, cfg, 1, extack); + } + + if (err != 0) + goto failure; + + if 
(fib_props[cfg->fc_type].error) { + if (cfg->fc_gw_family || cfg->fc_oif || cfg->fc_mp) { + NL_SET_ERR_MSG(extack, + "Gateway, device and multipath can not be specified for this route type"); + goto err_inval; + } + goto link_it; + } else { + switch (cfg->fc_type) { + case RTN_UNICAST: + case RTN_LOCAL: + case RTN_BROADCAST: + case RTN_ANYCAST: + case RTN_MULTICAST: + break; + default: + NL_SET_ERR_MSG(extack, "Invalid route type"); + goto err_inval; + } + } + + if (cfg->fc_scope > RT_SCOPE_HOST) { + NL_SET_ERR_MSG(extack, "Invalid scope"); + goto err_inval; + } + + if (fi->nh) { + err = fib_check_nexthop(fi->nh, cfg->fc_scope, extack); + if (err) + goto failure; + } else if (cfg->fc_scope == RT_SCOPE_HOST) { + struct fib_nh *nh = fi->fib_nh; + + /* Local address is added. */ + if (nhs != 1) { + NL_SET_ERR_MSG(extack, + "Route with host scope can not have multiple nexthops"); + goto err_inval; + } + if (nh->fib_nh_gw_family) { + NL_SET_ERR_MSG(extack, + "Route with host scope can not have a gateway"); + goto err_inval; + } + nh->fib_nh_scope = RT_SCOPE_NOWHERE; + nh->fib_nh_dev = dev_get_by_index(net, nh->fib_nh_oif); + err = -ENODEV; + if (!nh->fib_nh_dev) + goto failure; + netdev_tracker_alloc(nh->fib_nh_dev, &nh->fib_nh_dev_tracker, + GFP_KERNEL); + } else { + int linkdown = 0; + + change_nexthops(fi) { + err = fib_check_nh(cfg->fc_nlinfo.nl_net, nexthop_nh, + cfg->fc_table, cfg->fc_scope, + extack); + if (err != 0) + goto failure; + if (nexthop_nh->fib_nh_flags & RTNH_F_LINKDOWN) + linkdown++; + } endfor_nexthops(fi) + if (linkdown == fi->fib_nhs) + fi->fib_flags |= RTNH_F_LINKDOWN; + } + + if (fi->fib_prefsrc && !fib_valid_prefsrc(cfg, fi->fib_prefsrc)) { + NL_SET_ERR_MSG(extack, "Invalid prefsrc address"); + goto err_inval; + } + + if (!fi->nh) { + change_nexthops(fi) { + fib_info_update_nhc_saddr(net, &nexthop_nh->nh_common, + fi->fib_scope); + if (nexthop_nh->fib_nh_gw_family == AF_INET6) + fi->fib_nh_is_v6 = true; + } endfor_nexthops(fi) + + fib_rebalance(fi); + } + +link_it: + ofi = fib_find_info(fi); + if (ofi) { + /* fib_table_lookup() should not see @fi yet. */ + fi->fib_dead = 1; + free_fib_info(fi); + refcount_inc(&ofi->fib_treeref); + return ofi; + } + + refcount_set(&fi->fib_treeref, 1); + refcount_set(&fi->fib_clntref, 1); + spin_lock_bh(&fib_info_lock); + fib_info_cnt++; + hlist_add_head(&fi->fib_hash, + &fib_info_hash[fib_info_hashfn(fi)]); + if (fi->fib_prefsrc) { + struct hlist_head *head; + + head = fib_info_laddrhash_bucket(net, fi->fib_prefsrc); + hlist_add_head(&fi->fib_lhash, head); + } + if (fi->nh) { + list_add(&fi->nh_list, &nh->fi_list); + } else { + change_nexthops(fi) { + struct hlist_head *head; + + if (!nexthop_nh->fib_nh_dev) + continue; + head = fib_info_devhash_bucket(nexthop_nh->fib_nh_dev); + hlist_add_head(&nexthop_nh->nh_hash, head); + } endfor_nexthops(fi) + } + spin_unlock_bh(&fib_info_lock); + return fi; + +err_inval: + err = -EINVAL; + +failure: + if (fi) { + /* fib_table_lookup() should not see @fi yet. 
*/ + fi->fib_dead = 1; + free_fib_info(fi); + } + + return ERR_PTR(err); +} + +int fib_nexthop_info(struct sk_buff *skb, const struct fib_nh_common *nhc, + u8 rt_family, unsigned char *flags, bool skip_oif) +{ + if (nhc->nhc_flags & RTNH_F_DEAD) + *flags |= RTNH_F_DEAD; + + if (nhc->nhc_flags & RTNH_F_LINKDOWN) { + *flags |= RTNH_F_LINKDOWN; + + rcu_read_lock(); + switch (nhc->nhc_family) { + case AF_INET: + if (ip_ignore_linkdown(nhc->nhc_dev)) + *flags |= RTNH_F_DEAD; + break; + case AF_INET6: + if (ip6_ignore_linkdown(nhc->nhc_dev)) + *flags |= RTNH_F_DEAD; + break; + } + rcu_read_unlock(); + } + + switch (nhc->nhc_gw_family) { + case AF_INET: + if (nla_put_in_addr(skb, RTA_GATEWAY, nhc->nhc_gw.ipv4)) + goto nla_put_failure; + break; + case AF_INET6: + /* if gateway family does not match nexthop family + * gateway is encoded as RTA_VIA + */ + if (rt_family != nhc->nhc_gw_family) { + int alen = sizeof(struct in6_addr); + struct nlattr *nla; + struct rtvia *via; + + nla = nla_reserve(skb, RTA_VIA, alen + 2); + if (!nla) + goto nla_put_failure; + + via = nla_data(nla); + via->rtvia_family = AF_INET6; + memcpy(via->rtvia_addr, &nhc->nhc_gw.ipv6, alen); + } else if (nla_put_in6_addr(skb, RTA_GATEWAY, + &nhc->nhc_gw.ipv6) < 0) { + goto nla_put_failure; + } + break; + } + + *flags |= (nhc->nhc_flags & + (RTNH_F_ONLINK | RTNH_F_OFFLOAD | RTNH_F_TRAP)); + + if (!skip_oif && nhc->nhc_dev && + nla_put_u32(skb, RTA_OIF, nhc->nhc_dev->ifindex)) + goto nla_put_failure; + + if (nhc->nhc_lwtstate && + lwtunnel_fill_encap(skb, nhc->nhc_lwtstate, + RTA_ENCAP, RTA_ENCAP_TYPE) < 0) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(fib_nexthop_info); + +#if IS_ENABLED(CONFIG_IP_ROUTE_MULTIPATH) || IS_ENABLED(CONFIG_IPV6) +int fib_add_nexthop(struct sk_buff *skb, const struct fib_nh_common *nhc, + int nh_weight, u8 rt_family, u32 nh_tclassid) +{ + const struct net_device *dev = nhc->nhc_dev; + struct rtnexthop *rtnh; + unsigned char flags = 0; + + rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh)); + if (!rtnh) + goto nla_put_failure; + + rtnh->rtnh_hops = nh_weight - 1; + rtnh->rtnh_ifindex = dev ? 
dev->ifindex : 0; + + if (fib_nexthop_info(skb, nhc, rt_family, &flags, true) < 0) + goto nla_put_failure; + + rtnh->rtnh_flags = flags; + + if (nh_tclassid && nla_put_u32(skb, RTA_FLOW, nh_tclassid)) + goto nla_put_failure; + + /* length of rtnetlink header + attributes */ + rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(fib_add_nexthop); +#endif + +#ifdef CONFIG_IP_ROUTE_MULTIPATH +static int fib_add_multipath(struct sk_buff *skb, struct fib_info *fi) +{ + struct nlattr *mp; + + mp = nla_nest_start_noflag(skb, RTA_MULTIPATH); + if (!mp) + goto nla_put_failure; + + if (unlikely(fi->nh)) { + if (nexthop_mpath_fill_node(skb, fi->nh, AF_INET) < 0) + goto nla_put_failure; + goto mp_end; + } + + for_nexthops(fi) { + u32 nh_tclassid = 0; +#ifdef CONFIG_IP_ROUTE_CLASSID + nh_tclassid = nh->nh_tclassid; +#endif + if (fib_add_nexthop(skb, &nh->nh_common, nh->fib_nh_weight, + AF_INET, nh_tclassid) < 0) + goto nla_put_failure; + } endfor_nexthops(fi); + +mp_end: + nla_nest_end(skb, mp); + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} +#else +static int fib_add_multipath(struct sk_buff *skb, struct fib_info *fi) +{ + return 0; +} +#endif + +int fib_dump_info(struct sk_buff *skb, u32 portid, u32 seq, int event, + const struct fib_rt_info *fri, unsigned int flags) +{ + unsigned int nhs = fib_info_num_path(fri->fi); + struct fib_info *fi = fri->fi; + u32 tb_id = fri->tb_id; + struct nlmsghdr *nlh; + struct rtmsg *rtm; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags); + if (!nlh) + return -EMSGSIZE; + + rtm = nlmsg_data(nlh); + rtm->rtm_family = AF_INET; + rtm->rtm_dst_len = fri->dst_len; + rtm->rtm_src_len = 0; + rtm->rtm_tos = inet_dscp_to_dsfield(fri->dscp); + if (tb_id < 256) + rtm->rtm_table = tb_id; + else + rtm->rtm_table = RT_TABLE_COMPAT; + if (nla_put_u32(skb, RTA_TABLE, tb_id)) + goto nla_put_failure; + rtm->rtm_type = fri->type; + rtm->rtm_flags = fi->fib_flags; + rtm->rtm_scope = fi->fib_scope; + rtm->rtm_protocol = fi->fib_protocol; + + if (rtm->rtm_dst_len && + nla_put_in_addr(skb, RTA_DST, fri->dst)) + goto nla_put_failure; + if (fi->fib_priority && + nla_put_u32(skb, RTA_PRIORITY, fi->fib_priority)) + goto nla_put_failure; + if (rtnetlink_put_metrics(skb, fi->fib_metrics->metrics) < 0) + goto nla_put_failure; + + if (fi->fib_prefsrc && + nla_put_in_addr(skb, RTA_PREFSRC, fi->fib_prefsrc)) + goto nla_put_failure; + + if (fi->nh) { + if (nla_put_u32(skb, RTA_NH_ID, fi->nh->id)) + goto nla_put_failure; + if (nexthop_is_blackhole(fi->nh)) + rtm->rtm_type = RTN_BLACKHOLE; + if (!READ_ONCE(fi->fib_net->ipv4.sysctl_nexthop_compat_mode)) + goto offload; + } + + if (nhs == 1) { + const struct fib_nh_common *nhc = fib_info_nhc(fi, 0); + unsigned char flags = 0; + + if (fib_nexthop_info(skb, nhc, AF_INET, &flags, false) < 0) + goto nla_put_failure; + + rtm->rtm_flags = flags; +#ifdef CONFIG_IP_ROUTE_CLASSID + if (nhc->nhc_family == AF_INET) { + struct fib_nh *nh; + + nh = container_of(nhc, struct fib_nh, nh_common); + if (nh->nh_tclassid && + nla_put_u32(skb, RTA_FLOW, nh->nh_tclassid)) + goto nla_put_failure; + } +#endif + } else { + if (fib_add_multipath(skb, fi) < 0) + goto nla_put_failure; + } + +offload: + if (fri->offload) + rtm->rtm_flags |= RTM_F_OFFLOAD; + if (fri->trap) + rtm->rtm_flags |= RTM_F_TRAP; + if (fri->offload_failed) + rtm->rtm_flags |= RTM_F_OFFLOAD_FAILED; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +/* + * 
Update FIB if: + * - local address disappeared -> we must delete all the entries + * referring to it. + * - device went down -> we must shutdown all nexthops going via it. + */ +int fib_sync_down_addr(struct net_device *dev, __be32 local) +{ + int tb_id = l3mdev_fib_table(dev) ? : RT_TABLE_MAIN; + struct net *net = dev_net(dev); + struct hlist_head *head; + struct fib_info *fi; + int ret = 0; + + if (!fib_info_laddrhash || local == 0) + return 0; + + head = fib_info_laddrhash_bucket(net, local); + hlist_for_each_entry(fi, head, fib_lhash) { + if (!net_eq(fi->fib_net, net) || + fi->fib_tb_id != tb_id) + continue; + if (fi->fib_prefsrc == local) { + fi->fib_flags |= RTNH_F_DEAD; + fi->pfsrc_removed = true; + ret++; + } + } + return ret; +} + +static int call_fib_nh_notifiers(struct fib_nh *nh, + enum fib_event_type event_type) +{ + bool ignore_link_down = ip_ignore_linkdown(nh->fib_nh_dev); + struct fib_nh_notifier_info info = { + .fib_nh = nh, + }; + + switch (event_type) { + case FIB_EVENT_NH_ADD: + if (nh->fib_nh_flags & RTNH_F_DEAD) + break; + if (ignore_link_down && nh->fib_nh_flags & RTNH_F_LINKDOWN) + break; + return call_fib4_notifiers(dev_net(nh->fib_nh_dev), event_type, + &info.info); + case FIB_EVENT_NH_DEL: + if ((ignore_link_down && nh->fib_nh_flags & RTNH_F_LINKDOWN) || + (nh->fib_nh_flags & RTNH_F_DEAD)) + return call_fib4_notifiers(dev_net(nh->fib_nh_dev), + event_type, &info.info); + break; + default: + break; + } + + return NOTIFY_DONE; +} + +/* Update the PMTU of exceptions when: + * - the new MTU of the first hop becomes smaller than the PMTU + * - the old MTU was the same as the PMTU, and it limited discovery of + * larger MTUs on the path. With that limit raised, we can now + * discover larger MTUs + * A special case is locked exceptions, for which the PMTU is smaller + * than the minimal accepted PMTU: + * - if the new MTU is greater than the PMTU, don't make any change + * - otherwise, unlock and set PMTU + */ +void fib_nhc_update_mtu(struct fib_nh_common *nhc, u32 new, u32 orig) +{ + struct fnhe_hash_bucket *bucket; + int i; + + bucket = rcu_dereference_protected(nhc->nhc_exceptions, 1); + if (!bucket) + return; + + for (i = 0; i < FNHE_HASH_SIZE; i++) { + struct fib_nh_exception *fnhe; + + for (fnhe = rcu_dereference_protected(bucket[i].chain, 1); + fnhe; + fnhe = rcu_dereference_protected(fnhe->fnhe_next, 1)) { + if (fnhe->fnhe_mtu_locked) { + if (new <= fnhe->fnhe_pmtu) { + fnhe->fnhe_pmtu = new; + fnhe->fnhe_mtu_locked = false; + } + } else if (new < fnhe->fnhe_pmtu || + orig == fnhe->fnhe_pmtu) { + fnhe->fnhe_pmtu = new; + } + } + } +} + +void fib_sync_mtu(struct net_device *dev, u32 orig_mtu) +{ + struct hlist_head *head = fib_info_devhash_bucket(dev); + struct fib_nh *nh; + + hlist_for_each_entry(nh, head, nh_hash) { + if (nh->fib_nh_dev == dev) + fib_nhc_update_mtu(&nh->nh_common, dev->mtu, orig_mtu); + } +} + +/* Event force Flags Description + * NETDEV_CHANGE 0 LINKDOWN Carrier OFF, not for scope host + * NETDEV_DOWN 0 LINKDOWN|DEAD Link down, not for scope host + * NETDEV_DOWN 1 LINKDOWN|DEAD Last address removed + * NETDEV_UNREGISTER 1 LINKDOWN|DEAD Device removed + * + * only used when fib_nh is built into fib_info + */ +int fib_sync_down_dev(struct net_device *dev, unsigned long event, bool force) +{ + struct hlist_head *head = fib_info_devhash_bucket(dev); + struct fib_info *prev_fi = NULL; + int scope = RT_SCOPE_NOWHERE; + struct fib_nh *nh; + int ret = 0; + + if (force) + scope = -1; + + hlist_for_each_entry(nh, head, nh_hash) { + struct fib_info *fi 
= nh->nh_parent; + int dead; + + BUG_ON(!fi->fib_nhs); + if (nh->fib_nh_dev != dev || fi == prev_fi) + continue; + prev_fi = fi; + dead = 0; + change_nexthops(fi) { + if (nexthop_nh->fib_nh_flags & RTNH_F_DEAD) + dead++; + else if (nexthop_nh->fib_nh_dev == dev && + nexthop_nh->fib_nh_scope != scope) { + switch (event) { + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + nexthop_nh->fib_nh_flags |= RTNH_F_DEAD; + fallthrough; + case NETDEV_CHANGE: + nexthop_nh->fib_nh_flags |= RTNH_F_LINKDOWN; + break; + } + call_fib_nh_notifiers(nexthop_nh, + FIB_EVENT_NH_DEL); + dead++; + } +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (event == NETDEV_UNREGISTER && + nexthop_nh->fib_nh_dev == dev) { + dead = fi->fib_nhs; + break; + } +#endif + } endfor_nexthops(fi) + if (dead == fi->fib_nhs) { + switch (event) { + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + fi->fib_flags |= RTNH_F_DEAD; + fallthrough; + case NETDEV_CHANGE: + fi->fib_flags |= RTNH_F_LINKDOWN; + break; + } + ret++; + } + + fib_rebalance(fi); + } + + return ret; +} + +/* Must be invoked inside of an RCU protected region. */ +static void fib_select_default(const struct flowi4 *flp, struct fib_result *res) +{ + struct fib_info *fi = NULL, *last_resort = NULL; + struct hlist_head *fa_head = res->fa_head; + struct fib_table *tb = res->table; + u8 slen = 32 - res->prefixlen; + int order = -1, last_idx = -1; + struct fib_alias *fa, *fa1 = NULL; + u32 last_prio = res->fi->fib_priority; + dscp_t last_dscp = 0; + + hlist_for_each_entry_rcu(fa, fa_head, fa_list) { + struct fib_info *next_fi = fa->fa_info; + struct fib_nh_common *nhc; + + if (fa->fa_slen != slen) + continue; + if (fa->fa_dscp && + fa->fa_dscp != inet_dsfield_to_dscp(flp->flowi4_tos)) + continue; + if (fa->tb_id != tb->tb_id) + continue; + if (next_fi->fib_priority > last_prio && + fa->fa_dscp == last_dscp) { + if (last_dscp) + continue; + break; + } + if (next_fi->fib_flags & RTNH_F_DEAD) + continue; + last_dscp = fa->fa_dscp; + last_prio = next_fi->fib_priority; + + if (next_fi->fib_scope != res->scope || + fa->fa_type != RTN_UNICAST) + continue; + + nhc = fib_info_nhc(next_fi, 0); + if (!nhc->nhc_gw_family || nhc->nhc_scope != RT_SCOPE_LINK) + continue; + + fib_alias_accessed(fa); + + if (!fi) { + if (next_fi != res->fi) + break; + fa1 = fa; + } else if (!fib_detect_death(fi, order, &last_resort, + &last_idx, fa1->fa_default)) { + fib_result_assign(res, fi); + fa1->fa_default = order; + goto out; + } + fi = next_fi; + order++; + } + + if (order <= 0 || !fi) { + if (fa1) + fa1->fa_default = -1; + goto out; + } + + if (!fib_detect_death(fi, order, &last_resort, &last_idx, + fa1->fa_default)) { + fib_result_assign(res, fi); + fa1->fa_default = order; + goto out; + } + + if (last_idx >= 0) + fib_result_assign(res, last_resort); + fa1->fa_default = last_idx; +out: + return; +} + +/* + * Dead device goes up. We wake up dead nexthops. + * It takes sense only on multipath routes. 
+ * + * only used when fib_nh is built into fib_info + */ +int fib_sync_up(struct net_device *dev, unsigned char nh_flags) +{ + struct fib_info *prev_fi; + struct hlist_head *head; + struct fib_nh *nh; + int ret; + + if (!(dev->flags & IFF_UP)) + return 0; + + if (nh_flags & RTNH_F_DEAD) { + unsigned int flags = dev_get_flags(dev); + + if (flags & (IFF_RUNNING | IFF_LOWER_UP)) + nh_flags |= RTNH_F_LINKDOWN; + } + + prev_fi = NULL; + head = fib_info_devhash_bucket(dev); + ret = 0; + + hlist_for_each_entry(nh, head, nh_hash) { + struct fib_info *fi = nh->nh_parent; + int alive; + + BUG_ON(!fi->fib_nhs); + if (nh->fib_nh_dev != dev || fi == prev_fi) + continue; + + prev_fi = fi; + alive = 0; + change_nexthops(fi) { + if (!(nexthop_nh->fib_nh_flags & nh_flags)) { + alive++; + continue; + } + if (!nexthop_nh->fib_nh_dev || + !(nexthop_nh->fib_nh_dev->flags & IFF_UP)) + continue; + if (nexthop_nh->fib_nh_dev != dev || + !__in_dev_get_rtnl(dev)) + continue; + alive++; + nexthop_nh->fib_nh_flags &= ~nh_flags; + call_fib_nh_notifiers(nexthop_nh, FIB_EVENT_NH_ADD); + } endfor_nexthops(fi) + + if (alive > 0) { + fi->fib_flags &= ~nh_flags; + ret++; + } + + fib_rebalance(fi); + } + + return ret; +} + +#ifdef CONFIG_IP_ROUTE_MULTIPATH +static bool fib_good_nh(const struct fib_nh *nh) +{ + int state = NUD_REACHABLE; + + if (nh->fib_nh_scope == RT_SCOPE_LINK) { + struct neighbour *n; + + rcu_read_lock(); + + if (likely(nh->fib_nh_gw_family == AF_INET)) + n = __ipv4_neigh_lookup_noref(nh->fib_nh_dev, + (__force u32)nh->fib_nh_gw4); + else if (nh->fib_nh_gw_family == AF_INET6) + n = __ipv6_neigh_lookup_noref_stub(nh->fib_nh_dev, + &nh->fib_nh_gw6); + else + n = NULL; + if (n) + state = READ_ONCE(n->nud_state); + + rcu_read_unlock(); + } + + return !!(state & NUD_VALID); +} + +void fib_select_multipath(struct fib_result *res, int hash) +{ + struct fib_info *fi = res->fi; + struct net *net = fi->fib_net; + bool first = false; + + if (unlikely(res->fi->nh)) { + nexthop_path_fib_result(res, hash); + return; + } + + change_nexthops(fi) { + if (READ_ONCE(net->ipv4.sysctl_fib_multipath_use_neigh)) { + if (!fib_good_nh(nexthop_nh)) + continue; + if (!first) { + res->nh_sel = nhsel; + res->nhc = &nexthop_nh->nh_common; + first = true; + } + } + + if (hash > atomic_read(&nexthop_nh->fib_nh_upper_bound)) + continue; + + res->nh_sel = nhsel; + res->nhc = &nexthop_nh->nh_common; + return; + } endfor_nexthops(fi); +} +#endif + +void fib_select_path(struct net *net, struct fib_result *res, + struct flowi4 *fl4, const struct sk_buff *skb) +{ + if (fl4->flowi4_oif) + goto check_saddr; + +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (fib_info_num_path(res->fi) > 1) { + int h = fib_multipath_hash(net, fl4, skb, NULL); + + fib_select_multipath(res, h); + } + else +#endif + if (!res->prefixlen && + res->table->tb_num_default > 1 && + res->type == RTN_UNICAST) + fib_select_default(fl4, res); + +check_saddr: + if (!fl4->saddr) + fl4->saddr = fib_result_prefsrc(net, res); +} diff --git a/net/ipv4/fib_trie.c b/net/ipv4/fib_trie.c new file mode 100644 index 000000000..9bdfdab90 --- /dev/null +++ b/net/ipv4/fib_trie.c @@ -0,0 +1,3067 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Robert Olsson Uppsala Universitet + * & Swedish University of Agricultural Sciences. + * + * Jens Laas Swedish University of + * Agricultural Sciences. 
+ * + * Hans Liss Uppsala Universitet + * + * This work is based on the LPC-trie which is originally described in: + * + * An experimental study of compression methods for dynamic tries + * Stefan Nilsson and Matti Tikkanen. Algorithmica, 33(1):19-33, 2002. + * https://www.csc.kth.se/~snilsson/software/dyntrie2/ + * + * IP-address lookup using LC-tries. Stefan Nilsson and Gunnar Karlsson + * IEEE Journal on Selected Areas in Communications, 17(6):1083-1092, June 1999 + * + * Code from fib_hash has been reused which includes the following header: + * + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * IPv4 FIB: lookup engine and maintenance routines. + * + * Authors: Alexey Kuznetsov, + * + * Substantial contributions to this work comes from: + * + * David S. Miller, + * Stephen Hemminger + * Paul E. McKenney + * Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "fib_lookup.h" + +static int call_fib_entry_notifier(struct notifier_block *nb, + enum fib_event_type event_type, u32 dst, + int dst_len, struct fib_alias *fa, + struct netlink_ext_ack *extack) +{ + struct fib_entry_notifier_info info = { + .info.extack = extack, + .dst = dst, + .dst_len = dst_len, + .fi = fa->fa_info, + .dscp = fa->fa_dscp, + .type = fa->fa_type, + .tb_id = fa->tb_id, + }; + return call_fib4_notifier(nb, event_type, &info.info); +} + +static int call_fib_entry_notifiers(struct net *net, + enum fib_event_type event_type, u32 dst, + int dst_len, struct fib_alias *fa, + struct netlink_ext_ack *extack) +{ + struct fib_entry_notifier_info info = { + .info.extack = extack, + .dst = dst, + .dst_len = dst_len, + .fi = fa->fa_info, + .dscp = fa->fa_dscp, + .type = fa->fa_type, + .tb_id = fa->tb_id, + }; + return call_fib4_notifiers(net, event_type, &info.info); +} + +#define MAX_STAT_DEPTH 32 + +#define KEYLENGTH (8*sizeof(t_key)) +#define KEY_MAX ((t_key)~0) + +typedef unsigned int t_key; + +#define IS_TRIE(n) ((n)->pos >= KEYLENGTH) +#define IS_TNODE(n) ((n)->bits) +#define IS_LEAF(n) (!(n)->bits) + +struct key_vector { + t_key key; + unsigned char pos; /* 2log(KEYLENGTH) bits needed */ + unsigned char bits; /* 2log(KEYLENGTH) bits needed */ + unsigned char slen; + union { + /* This list pointer if valid if (pos | bits) == 0 (LEAF) */ + struct hlist_head leaf; + /* This array is valid if (pos | bits) > 0 (TNODE) */ + DECLARE_FLEX_ARRAY(struct key_vector __rcu *, tnode); + }; +}; + +struct tnode { + struct rcu_head rcu; + t_key empty_children; /* KEYLENGTH bits needed */ + t_key full_children; /* KEYLENGTH bits needed */ + struct key_vector __rcu *parent; + struct key_vector kv[1]; +#define tn_bits kv[0].bits +}; + +#define TNODE_SIZE(n) offsetof(struct tnode, kv[0].tnode[n]) +#define LEAF_SIZE TNODE_SIZE(1) + +#ifdef CONFIG_IP_FIB_TRIE_STATS +struct trie_use_stats { + unsigned int gets; + unsigned int backtrack; + unsigned int semantic_match_passed; + unsigned int semantic_match_miss; + unsigned int null_node_hit; + unsigned int resize_node_skipped; +}; +#endif + +struct trie_stat { + unsigned int totdepth; + unsigned int maxdepth; + unsigned int 
tnodes; + unsigned int leaves; + unsigned int nullpointers; + unsigned int prefixes; + unsigned int nodesizes[MAX_STAT_DEPTH]; +}; + +struct trie { + struct key_vector kv[1]; +#ifdef CONFIG_IP_FIB_TRIE_STATS + struct trie_use_stats __percpu *stats; +#endif +}; + +static struct key_vector *resize(struct trie *t, struct key_vector *tn); +static unsigned int tnode_free_size; + +/* + * synchronize_rcu after call_rcu for outstanding dirty memory; it should be + * especially useful before resizing the root node with PREEMPT_NONE configs; + * the value was obtained experimentally, aiming to avoid visible slowdown. + */ +unsigned int sysctl_fib_sync_mem = 512 * 1024; +unsigned int sysctl_fib_sync_mem_min = 64 * 1024; +unsigned int sysctl_fib_sync_mem_max = 64 * 1024 * 1024; + +static struct kmem_cache *fn_alias_kmem __ro_after_init; +static struct kmem_cache *trie_leaf_kmem __ro_after_init; + +static inline struct tnode *tn_info(struct key_vector *kv) +{ + return container_of(kv, struct tnode, kv[0]); +} + +/* caller must hold RTNL */ +#define node_parent(tn) rtnl_dereference(tn_info(tn)->parent) +#define get_child(tn, i) rtnl_dereference((tn)->tnode[i]) + +/* caller must hold RCU read lock or RTNL */ +#define node_parent_rcu(tn) rcu_dereference_rtnl(tn_info(tn)->parent) +#define get_child_rcu(tn, i) rcu_dereference_rtnl((tn)->tnode[i]) + +/* wrapper for rcu_assign_pointer */ +static inline void node_set_parent(struct key_vector *n, struct key_vector *tp) +{ + if (n) + rcu_assign_pointer(tn_info(n)->parent, tp); +} + +#define NODE_INIT_PARENT(n, p) RCU_INIT_POINTER(tn_info(n)->parent, p) + +/* This provides us with the number of children in this node, in the case of a + * leaf this will return 0 meaning none of the children are accessible. + */ +static inline unsigned long child_length(const struct key_vector *tn) +{ + return (1ul << tn->bits) & ~(1ul); +} + +#define get_cindex(key, kv) (((key) ^ (kv)->key) >> (kv)->pos) + +static inline unsigned long get_index(t_key key, struct key_vector *kv) +{ + unsigned long index = key ^ kv->key; + + if ((BITS_PER_LONG <= KEYLENGTH) && (KEYLENGTH == kv->pos)) + return 0; + + return index >> kv->pos; +} + +/* To understand this stuff, an understanding of keys and all their bits is + * necessary. Every node in the trie has a key associated with it, but not + * all of the bits in that key are significant. + * + * Consider a node 'n' and its parent 'tp'. + * + * If n is a leaf, every bit in its key is significant. Its presence is + * necessitated by path compression, since during a tree traversal (when + * searching for a leaf - unless we are doing an insertion) we will completely + * ignore all skipped bits we encounter. Thus we need to verify, at the end of + * a potentially successful search, that we have indeed been walking the + * correct key path. + * + * Note that we can never "miss" the correct key in the tree if present by + * following the wrong path. Path compression ensures that segments of the key + * that are the same for all keys with a given prefix are skipped, but the + * skipped part *is* identical for each node in the subtrie below the skipped + * bit! trie_insert() in this implementation takes care of that. + * + * if n is an internal node - a 'tnode' here, the various parts of its key + * have many different meanings. 
+ * + * Example: + * _________________________________________________________________ + * | i | i | i | i | i | i | i | N | N | N | S | S | S | S | S | C | + * ----------------------------------------------------------------- + * 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 + * + * _________________________________________________________________ + * | C | C | C | u | u | u | u | u | u | u | u | u | u | u | u | u | + * ----------------------------------------------------------------- + * 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 + * + * tp->pos = 22 + * tp->bits = 3 + * n->pos = 13 + * n->bits = 4 + * + * First, let's just ignore the bits that come before the parent tp, that is + * the bits from (tp->pos + tp->bits) to 31. They are *known* but at this + * point we do not use them for anything. + * + * The bits from (tp->pos) to (tp->pos + tp->bits - 1) - "N", above - are the + * index into the parent's child array. That is, they will be used to find + * 'n' among tp's children. + * + * The bits from (n->pos + n->bits) to (tp->pos - 1) - "S" - are skipped bits + * for the node n. + * + * All the bits we have seen so far are significant to the node n. The rest + * of the bits are really not needed or indeed known in n->key. + * + * The bits from (n->pos) to (n->pos + n->bits - 1) - "C" - are the index into + * n's child array, and will of course be different for each child. + * + * The rest of the bits, from 0 to (n->pos -1) - "u" - are completely unknown + * at this point. + */ + +static const int halve_threshold = 25; +static const int inflate_threshold = 50; +static const int halve_threshold_root = 15; +static const int inflate_threshold_root = 30; + +static void __alias_free_mem(struct rcu_head *head) +{ + struct fib_alias *fa = container_of(head, struct fib_alias, rcu); + kmem_cache_free(fn_alias_kmem, fa); +} + +static inline void alias_free_mem_rcu(struct fib_alias *fa) +{ + call_rcu(&fa->rcu, __alias_free_mem); +} + +#define TNODE_VMALLOC_MAX \ + ilog2((SIZE_MAX - TNODE_SIZE(0)) / sizeof(struct key_vector *)) + +static void __node_free_rcu(struct rcu_head *head) +{ + struct tnode *n = container_of(head, struct tnode, rcu); + + if (!n->tn_bits) + kmem_cache_free(trie_leaf_kmem, n); + else + kvfree(n); +} + +#define node_free(n) call_rcu(&tn_info(n)->rcu, __node_free_rcu) + +static struct tnode *tnode_alloc(int bits) +{ + size_t size; + + /* verify bits is within bounds */ + if (bits > TNODE_VMALLOC_MAX) + return NULL; + + /* determine size and verify it is non-zero and didn't overflow */ + size = TNODE_SIZE(1ul << bits); + + if (size <= PAGE_SIZE) + return kzalloc(size, GFP_KERNEL); + else + return vzalloc(size); +} + +static inline void empty_child_inc(struct key_vector *n) +{ + tn_info(n)->empty_children++; + + if (!tn_info(n)->empty_children) + tn_info(n)->full_children++; +} + +static inline void empty_child_dec(struct key_vector *n) +{ + if (!tn_info(n)->empty_children) + tn_info(n)->full_children--; + + tn_info(n)->empty_children--; +} + +static struct key_vector *leaf_new(t_key key, struct fib_alias *fa) +{ + struct key_vector *l; + struct tnode *kv; + + kv = kmem_cache_alloc(trie_leaf_kmem, GFP_KERNEL); + if (!kv) + return NULL; + + /* initialize key vector */ + l = kv->kv; + l->key = key; + l->pos = 0; + l->bits = 0; + l->slen = fa->fa_slen; + + /* link leaf to fib alias */ + INIT_HLIST_HEAD(&l->leaf); + hlist_add_head(&fa->fa_list, &l->leaf); + + return l; +} + +static struct key_vector *tnode_new(t_key key, int pos, int bits) +{ + unsigned int shift = pos + bits; + struct 
key_vector *tn; + struct tnode *tnode; + + /* verify bits and pos their msb bits clear and values are valid */ + BUG_ON(!bits || (shift > KEYLENGTH)); + + tnode = tnode_alloc(bits); + if (!tnode) + return NULL; + + pr_debug("AT %p s=%zu %zu\n", tnode, TNODE_SIZE(0), + sizeof(struct key_vector *) << bits); + + if (bits == KEYLENGTH) + tnode->full_children = 1; + else + tnode->empty_children = 1ul << bits; + + tn = tnode->kv; + tn->key = (shift < KEYLENGTH) ? (key >> shift) << shift : 0; + tn->pos = pos; + tn->bits = bits; + tn->slen = pos; + + return tn; +} + +/* Check whether a tnode 'n' is "full", i.e. it is an internal node + * and no bits are skipped. See discussion in dyntree paper p. 6 + */ +static inline int tnode_full(struct key_vector *tn, struct key_vector *n) +{ + return n && ((n->pos + n->bits) == tn->pos) && IS_TNODE(n); +} + +/* Add a child at position i overwriting the old value. + * Update the value of full_children and empty_children. + */ +static void put_child(struct key_vector *tn, unsigned long i, + struct key_vector *n) +{ + struct key_vector *chi = get_child(tn, i); + int isfull, wasfull; + + BUG_ON(i >= child_length(tn)); + + /* update emptyChildren, overflow into fullChildren */ + if (!n && chi) + empty_child_inc(tn); + if (n && !chi) + empty_child_dec(tn); + + /* update fullChildren */ + wasfull = tnode_full(tn, chi); + isfull = tnode_full(tn, n); + + if (wasfull && !isfull) + tn_info(tn)->full_children--; + else if (!wasfull && isfull) + tn_info(tn)->full_children++; + + if (n && (tn->slen < n->slen)) + tn->slen = n->slen; + + rcu_assign_pointer(tn->tnode[i], n); +} + +static void update_children(struct key_vector *tn) +{ + unsigned long i; + + /* update all of the child parent pointers */ + for (i = child_length(tn); i;) { + struct key_vector *inode = get_child(tn, --i); + + if (!inode) + continue; + + /* Either update the children of a tnode that + * already belongs to us or update the child + * to point to ourselves. 
+ */ + if (node_parent(inode) == tn) + update_children(inode); + else + node_set_parent(inode, tn); + } +} + +static inline void put_child_root(struct key_vector *tp, t_key key, + struct key_vector *n) +{ + if (IS_TRIE(tp)) + rcu_assign_pointer(tp->tnode[0], n); + else + put_child(tp, get_index(key, tp), n); +} + +static inline void tnode_free_init(struct key_vector *tn) +{ + tn_info(tn)->rcu.next = NULL; +} + +static inline void tnode_free_append(struct key_vector *tn, + struct key_vector *n) +{ + tn_info(n)->rcu.next = tn_info(tn)->rcu.next; + tn_info(tn)->rcu.next = &tn_info(n)->rcu; +} + +static void tnode_free(struct key_vector *tn) +{ + struct callback_head *head = &tn_info(tn)->rcu; + + while (head) { + head = head->next; + tnode_free_size += TNODE_SIZE(1ul << tn->bits); + node_free(tn); + + tn = container_of(head, struct tnode, rcu)->kv; + } + + if (tnode_free_size >= READ_ONCE(sysctl_fib_sync_mem)) { + tnode_free_size = 0; + synchronize_rcu(); + } +} + +static struct key_vector *replace(struct trie *t, + struct key_vector *oldtnode, + struct key_vector *tn) +{ + struct key_vector *tp = node_parent(oldtnode); + unsigned long i; + + /* setup the parent pointer out of and back into this node */ + NODE_INIT_PARENT(tn, tp); + put_child_root(tp, tn->key, tn); + + /* update all of the child parent pointers */ + update_children(tn); + + /* all pointers should be clean so we are done */ + tnode_free(oldtnode); + + /* resize children now that oldtnode is freed */ + for (i = child_length(tn); i;) { + struct key_vector *inode = get_child(tn, --i); + + /* resize child node */ + if (tnode_full(tn, inode)) + tn = resize(t, inode); + } + + return tp; +} + +static struct key_vector *inflate(struct trie *t, + struct key_vector *oldtnode) +{ + struct key_vector *tn; + unsigned long i; + t_key m; + + pr_debug("In inflate\n"); + + tn = tnode_new(oldtnode->key, oldtnode->pos - 1, oldtnode->bits + 1); + if (!tn) + goto notnode; + + /* prepare oldtnode to be freed */ + tnode_free_init(oldtnode); + + /* Assemble all of the pointers in our cluster, in this case that + * represents all of the pointers out of our allocated nodes that + * point to existing tnodes and the links between our allocated + * nodes. + */ + for (i = child_length(oldtnode), m = 1u << tn->pos; i;) { + struct key_vector *inode = get_child(oldtnode, --i); + struct key_vector *node0, *node1; + unsigned long j, k; + + /* An empty child */ + if (!inode) + continue; + + /* A leaf or an internal node with skipped bits */ + if (!tnode_full(oldtnode, inode)) { + put_child(tn, get_index(inode->key, tn), inode); + continue; + } + + /* drop the node in the old tnode free list */ + tnode_free_append(oldtnode, inode); + + /* An internal node with two children */ + if (inode->bits == 1) { + put_child(tn, 2 * i + 1, get_child(inode, 1)); + put_child(tn, 2 * i, get_child(inode, 0)); + continue; + } + + /* We will replace this node 'inode' with two new + * ones, 'node0' and 'node1', each with half of the + * original children. The two new nodes will have + * a position one bit further down the key and this + * means that the "significant" part of their keys + * (see the discussion near the top of this file) + * will differ by one bit, which will be "0" in + * node0's key and "1" in node1's key. Since we are + * moving the key position by one step, the bit that + * we are moving away from - the bit at position + * (tn->pos) - is the one that will differ between + * node0 and node1. So... we synthesize that bit in the + * two new keys. 
+ */ + node1 = tnode_new(inode->key | m, inode->pos, inode->bits - 1); + if (!node1) + goto nomem; + node0 = tnode_new(inode->key, inode->pos, inode->bits - 1); + + tnode_free_append(tn, node1); + if (!node0) + goto nomem; + tnode_free_append(tn, node0); + + /* populate child pointers in new nodes */ + for (k = child_length(inode), j = k / 2; j;) { + put_child(node1, --j, get_child(inode, --k)); + put_child(node0, j, get_child(inode, j)); + put_child(node1, --j, get_child(inode, --k)); + put_child(node0, j, get_child(inode, j)); + } + + /* link new nodes to parent */ + NODE_INIT_PARENT(node1, tn); + NODE_INIT_PARENT(node0, tn); + + /* link parent to nodes */ + put_child(tn, 2 * i + 1, node1); + put_child(tn, 2 * i, node0); + } + + /* setup the parent pointers into and out of this node */ + return replace(t, oldtnode, tn); +nomem: + /* all pointers should be clean so we are done */ + tnode_free(tn); +notnode: + return NULL; +} + +static struct key_vector *halve(struct trie *t, + struct key_vector *oldtnode) +{ + struct key_vector *tn; + unsigned long i; + + pr_debug("In halve\n"); + + tn = tnode_new(oldtnode->key, oldtnode->pos + 1, oldtnode->bits - 1); + if (!tn) + goto notnode; + + /* prepare oldtnode to be freed */ + tnode_free_init(oldtnode); + + /* Assemble all of the pointers in our cluster, in this case that + * represents all of the pointers out of our allocated nodes that + * point to existing tnodes and the links between our allocated + * nodes. + */ + for (i = child_length(oldtnode); i;) { + struct key_vector *node1 = get_child(oldtnode, --i); + struct key_vector *node0 = get_child(oldtnode, --i); + struct key_vector *inode; + + /* At least one of the children is empty */ + if (!node1 || !node0) { + put_child(tn, i / 2, node1 ? : node0); + continue; + } + + /* Two nonempty children */ + inode = tnode_new(node0->key, oldtnode->pos, 1); + if (!inode) + goto nomem; + tnode_free_append(tn, inode); + + /* initialize pointers out of node */ + put_child(inode, 1, node1); + put_child(inode, 0, node0); + NODE_INIT_PARENT(inode, tn); + + /* link parent to node */ + put_child(tn, i / 2, inode); + } + + /* setup the parent pointers into and out of this node */ + return replace(t, oldtnode, tn); +nomem: + /* all pointers should be clean so we are done */ + tnode_free(tn); +notnode: + return NULL; +} + +static struct key_vector *collapse(struct trie *t, + struct key_vector *oldtnode) +{ + struct key_vector *n, *tp; + unsigned long i; + + /* scan the tnode looking for that one child that might still exist */ + for (n = NULL, i = child_length(oldtnode); !n && i;) + n = get_child(oldtnode, --i); + + /* compress one level */ + tp = node_parent(oldtnode); + put_child_root(tp, oldtnode->key, n); + node_set_parent(n, tp); + + /* drop dead node */ + node_free(oldtnode); + + return tp; +} + +static unsigned char update_suffix(struct key_vector *tn) +{ + unsigned char slen = tn->pos; + unsigned long stride, i; + unsigned char slen_max; + + /* only vector 0 can have a suffix length greater than or equal to + * tn->pos + tn->bits, the second highest node will have a suffix + * length at most of tn->pos + tn->bits - 1 + */ + slen_max = min_t(unsigned char, tn->pos + tn->bits - 1, tn->slen); + + /* search though the list of children looking for nodes that might + * have a suffix greater than the one we currently have. 
This is + * why we start with a stride of 2 since a stride of 1 would + * represent the nodes with suffix length equal to tn->pos + */ + for (i = 0, stride = 0x2ul ; i < child_length(tn); i += stride) { + struct key_vector *n = get_child(tn, i); + + if (!n || (n->slen <= slen)) + continue; + + /* update stride and slen based on new value */ + stride <<= (n->slen - slen); + slen = n->slen; + i &= ~(stride - 1); + + /* stop searching if we have hit the maximum possible value */ + if (slen >= slen_max) + break; + } + + tn->slen = slen; + + return slen; +} + +/* From "Implementing a dynamic compressed trie" by Stefan Nilsson of + * the Helsinki University of Technology and Matti Tikkanen of Nokia + * Telecommunications, page 6: + * "A node is doubled if the ratio of non-empty children to all + * children in the *doubled* node is at least 'high'." + * + * 'high' in this instance is the variable 'inflate_threshold'. It + * is expressed as a percentage, so we multiply it with + * child_length() and instead of multiplying by 2 (since the + * child array will be doubled by inflate()) and multiplying + * the left-hand side by 100 (to handle the percentage thing) we + * multiply the left-hand side by 50. + * + * The left-hand side may look a bit weird: child_length(tn) + * - tn->empty_children is of course the number of non-null children + * in the current node. tn->full_children is the number of "full" + * children, that is non-null tnodes with a skip value of 0. + * All of those will be doubled in the resulting inflated tnode, so + * we just count them one extra time here. + * + * A clearer way to write this would be: + * + * to_be_doubled = tn->full_children; + * not_to_be_doubled = child_length(tn) - tn->empty_children - + * tn->full_children; + * + * new_child_length = child_length(tn) * 2; + * + * new_fill_factor = 100 * (not_to_be_doubled + 2*to_be_doubled) / + * new_child_length; + * if (new_fill_factor >= inflate_threshold) + * + * ...and so on, tho it would mess up the while () loop. + * + * anyway, + * 100 * (not_to_be_doubled + 2*to_be_doubled) / new_child_length >= + * inflate_threshold + * + * avoid a division: + * 100 * (not_to_be_doubled + 2*to_be_doubled) >= + * inflate_threshold * new_child_length + * + * expand not_to_be_doubled and to_be_doubled, and shorten: + * 100 * (child_length(tn) - tn->empty_children + + * tn->full_children) >= inflate_threshold * new_child_length + * + * expand new_child_length: + * 100 * (child_length(tn) - tn->empty_children + + * tn->full_children) >= + * inflate_threshold * child_length(tn) * 2 + * + * shorten again: + * 50 * (tn->full_children + child_length(tn) - + * tn->empty_children) >= inflate_threshold * + * child_length(tn) + * + */ +static inline bool should_inflate(struct key_vector *tp, struct key_vector *tn) +{ + unsigned long used = child_length(tn); + unsigned long threshold = used; + + /* Keep root node larger */ + threshold *= IS_TRIE(tp) ? inflate_threshold_root : inflate_threshold; + used -= tn_info(tn)->empty_children; + used += tn_info(tn)->full_children; + + /* if bits == KEYLENGTH then pos = 0, and will fail below */ + + return (used > 1) && tn->pos && ((50 * used) >= threshold); +} + +static inline bool should_halve(struct key_vector *tp, struct key_vector *tn) +{ + unsigned long used = child_length(tn); + unsigned long threshold = used; + + /* Keep root node larger */ + threshold *= IS_TRIE(tp) ? 
halve_threshold_root : halve_threshold; + used -= tn_info(tn)->empty_children; + + /* if bits == KEYLENGTH then used = 100% on wrap, and will fail below */ + + return (used > 1) && (tn->bits > 1) && ((100 * used) < threshold); +} + +static inline bool should_collapse(struct key_vector *tn) +{ + unsigned long used = child_length(tn); + + used -= tn_info(tn)->empty_children; + + /* account for bits == KEYLENGTH case */ + if ((tn->bits == KEYLENGTH) && tn_info(tn)->full_children) + used -= KEY_MAX; + + /* One child or none, time to drop us from the trie */ + return used < 2; +} + +#define MAX_WORK 10 +static struct key_vector *resize(struct trie *t, struct key_vector *tn) +{ +#ifdef CONFIG_IP_FIB_TRIE_STATS + struct trie_use_stats __percpu *stats = t->stats; +#endif + struct key_vector *tp = node_parent(tn); + unsigned long cindex = get_index(tn->key, tp); + int max_work = MAX_WORK; + + pr_debug("In tnode_resize %p inflate_threshold=%d threshold=%d\n", + tn, inflate_threshold, halve_threshold); + + /* track the tnode via the pointer from the parent instead of + * doing it ourselves. This way we can let RCU fully do its + * thing without us interfering + */ + BUG_ON(tn != get_child(tp, cindex)); + + /* Double as long as the resulting node has a number of + * nonempty nodes that are above the threshold. + */ + while (should_inflate(tp, tn) && max_work) { + tp = inflate(t, tn); + if (!tp) { +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->resize_node_skipped); +#endif + break; + } + + max_work--; + tn = get_child(tp, cindex); + } + + /* update parent in case inflate failed */ + tp = node_parent(tn); + + /* Return if at least one inflate is run */ + if (max_work != MAX_WORK) + return tp; + + /* Halve as long as the number of empty children in this + * node is above threshold. + */ + while (should_halve(tp, tn) && max_work) { + tp = halve(t, tn); + if (!tp) { +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->resize_node_skipped); +#endif + break; + } + + max_work--; + tn = get_child(tp, cindex); + } + + /* Only one child remains */ + if (should_collapse(tn)) + return collapse(t, tn); + + /* update parent in case halve failed */ + return node_parent(tn); +} + +static void node_pull_suffix(struct key_vector *tn, unsigned char slen) +{ + unsigned char node_slen = tn->slen; + + while ((node_slen > tn->pos) && (node_slen > slen)) { + slen = update_suffix(tn); + if (node_slen == slen) + break; + + tn = node_parent(tn); + node_slen = tn->slen; + } +} + +static void node_push_suffix(struct key_vector *tn, unsigned char slen) +{ + while (tn->slen < slen) { + tn->slen = slen; + tn = node_parent(tn); + } +} + +/* rcu_read_lock needs to be hold by caller from readside */ +static struct key_vector *fib_find_node(struct trie *t, + struct key_vector **tp, u32 key) +{ + struct key_vector *pn, *n = t->kv; + unsigned long index = 0; + + do { + pn = n; + n = get_child_rcu(n, index); + + if (!n) + break; + + index = get_cindex(key, n); + + /* This bit of code is a bit tricky but it combines multiple + * checks into a single check. The prefix consists of the + * prefix plus zeros for the bits in the cindex. The index + * is the difference between the key and this value. From + * this we can actually derive several pieces of data. + * if (index >= (1ul << bits)) + * we have a mismatch in skip bits and failed + * else + * we know the value is cindex + * + * This check is safe even if bits == KEYLENGTH due to the + * fact that we can only allocate a node with 32 bits if a + * long is greater than 32 bits. 
+ */ + if (index >= (1ul << n->bits)) { + n = NULL; + break; + } + + /* keep searching until we find a perfect match leaf or NULL */ + } while (IS_TNODE(n)); + + *tp = pn; + + return n; +} + +/* Return the first fib alias matching DSCP with + * priority less than or equal to PRIO. + * If 'find_first' is set, return the first matching + * fib alias, regardless of DSCP and priority. + */ +static struct fib_alias *fib_find_alias(struct hlist_head *fah, u8 slen, + dscp_t dscp, u32 prio, u32 tb_id, + bool find_first) +{ + struct fib_alias *fa; + + if (!fah) + return NULL; + + hlist_for_each_entry(fa, fah, fa_list) { + /* Avoid Sparse warning when using dscp_t in inequalities */ + u8 __fa_dscp = inet_dscp_to_dsfield(fa->fa_dscp); + u8 __dscp = inet_dscp_to_dsfield(dscp); + + if (fa->fa_slen < slen) + continue; + if (fa->fa_slen != slen) + break; + if (fa->tb_id > tb_id) + continue; + if (fa->tb_id != tb_id) + break; + if (find_first) + return fa; + if (__fa_dscp > __dscp) + continue; + if (fa->fa_info->fib_priority >= prio || __fa_dscp < __dscp) + return fa; + } + + return NULL; +} + +static struct fib_alias * +fib_find_matching_alias(struct net *net, const struct fib_rt_info *fri) +{ + u8 slen = KEYLENGTH - fri->dst_len; + struct key_vector *l, *tp; + struct fib_table *tb; + struct fib_alias *fa; + struct trie *t; + + tb = fib_get_table(net, fri->tb_id); + if (!tb) + return NULL; + + t = (struct trie *)tb->tb_data; + l = fib_find_node(t, &tp, be32_to_cpu(fri->dst)); + if (!l) + return NULL; + + hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { + if (fa->fa_slen == slen && fa->tb_id == fri->tb_id && + fa->fa_dscp == fri->dscp && fa->fa_info == fri->fi && + fa->fa_type == fri->type) + return fa; + } + + return NULL; +} + +void fib_alias_hw_flags_set(struct net *net, const struct fib_rt_info *fri) +{ + u8 fib_notify_on_flag_change; + struct fib_alias *fa_match; + struct sk_buff *skb; + int err; + + rcu_read_lock(); + + fa_match = fib_find_matching_alias(net, fri); + if (!fa_match) + goto out; + + /* These are paired with the WRITE_ONCE() happening in this function. + * The reason is that we are only protected by RCU at this point. + */ + if (READ_ONCE(fa_match->offload) == fri->offload && + READ_ONCE(fa_match->trap) == fri->trap && + READ_ONCE(fa_match->offload_failed) == fri->offload_failed) + goto out; + + WRITE_ONCE(fa_match->offload, fri->offload); + WRITE_ONCE(fa_match->trap, fri->trap); + + fib_notify_on_flag_change = READ_ONCE(net->ipv4.sysctl_fib_notify_on_flag_change); + + /* 2 means send notifications only if offload_failed was changed. 
*/ + if (fib_notify_on_flag_change == 2 && + READ_ONCE(fa_match->offload_failed) == fri->offload_failed) + goto out; + + WRITE_ONCE(fa_match->offload_failed, fri->offload_failed); + + if (!fib_notify_on_flag_change) + goto out; + + skb = nlmsg_new(fib_nlmsg_size(fa_match->fa_info), GFP_ATOMIC); + if (!skb) { + err = -ENOBUFS; + goto errout; + } + + err = fib_dump_info(skb, 0, 0, RTM_NEWROUTE, fri, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in fib_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_IPV4_ROUTE, NULL, GFP_ATOMIC); + goto out; + +errout: + rtnl_set_sk_err(net, RTNLGRP_IPV4_ROUTE, err); +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(fib_alias_hw_flags_set); + +static void trie_rebalance(struct trie *t, struct key_vector *tn) +{ + while (!IS_TRIE(tn)) + tn = resize(t, tn); +} + +static int fib_insert_node(struct trie *t, struct key_vector *tp, + struct fib_alias *new, t_key key) +{ + struct key_vector *n, *l; + + l = leaf_new(key, new); + if (!l) + goto noleaf; + + /* retrieve child from parent node */ + n = get_child(tp, get_index(key, tp)); + + /* Case 2: n is a LEAF or a TNODE and the key doesn't match. + * + * Add a new tnode here + * first tnode need some special handling + * leaves us in position for handling as case 3 + */ + if (n) { + struct key_vector *tn; + + tn = tnode_new(key, __fls(key ^ n->key), 1); + if (!tn) + goto notnode; + + /* initialize routes out of node */ + NODE_INIT_PARENT(tn, tp); + put_child(tn, get_index(key, tn) ^ 1, n); + + /* start adding routes into the node */ + put_child_root(tp, key, tn); + node_set_parent(n, tn); + + /* parent now has a NULL spot where the leaf can go */ + tp = tn; + } + + /* Case 3: n is NULL, and will just insert a new leaf */ + node_push_suffix(tp, new->fa_slen); + NODE_INIT_PARENT(l, tp); + put_child_root(tp, key, l); + trie_rebalance(t, tp); + + return 0; +notnode: + node_free(l); +noleaf: + return -ENOMEM; +} + +static int fib_insert_alias(struct trie *t, struct key_vector *tp, + struct key_vector *l, struct fib_alias *new, + struct fib_alias *fa, t_key key) +{ + if (!l) + return fib_insert_node(t, tp, new, key); + + if (fa) { + hlist_add_before_rcu(&new->fa_list, &fa->fa_list); + } else { + struct fib_alias *last; + + hlist_for_each_entry(last, &l->leaf, fa_list) { + if (new->fa_slen < last->fa_slen) + break; + if ((new->fa_slen == last->fa_slen) && + (new->tb_id > last->tb_id)) + break; + fa = last; + } + + if (fa) + hlist_add_behind_rcu(&new->fa_list, &fa->fa_list); + else + hlist_add_head_rcu(&new->fa_list, &l->leaf); + } + + /* if we added to the tail node then we need to update slen */ + if (l->slen < new->fa_slen) { + l->slen = new->fa_slen; + node_push_suffix(tp, new->fa_slen); + } + + return 0; +} + +static bool fib_valid_key_len(u32 key, u8 plen, struct netlink_ext_ack *extack) +{ + if (plen > KEYLENGTH) { + NL_SET_ERR_MSG(extack, "Invalid prefix length"); + return false; + } + + if ((plen < KEYLENGTH) && (key << plen)) { + NL_SET_ERR_MSG(extack, + "Invalid prefix for given prefix length"); + return false; + } + + return true; +} + +static void fib_remove_alias(struct trie *t, struct key_vector *tp, + struct key_vector *l, struct fib_alias *old); + +/* Caller must hold RTNL. 
*/ +int fib_table_insert(struct net *net, struct fib_table *tb, + struct fib_config *cfg, struct netlink_ext_ack *extack) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct fib_alias *fa, *new_fa; + struct key_vector *l, *tp; + u16 nlflags = NLM_F_EXCL; + struct fib_info *fi; + u8 plen = cfg->fc_dst_len; + u8 slen = KEYLENGTH - plen; + dscp_t dscp; + u32 key; + int err; + + key = ntohl(cfg->fc_dst); + + if (!fib_valid_key_len(key, plen, extack)) + return -EINVAL; + + pr_debug("Insert table=%u %08x/%d\n", tb->tb_id, key, plen); + + fi = fib_create_info(cfg, extack); + if (IS_ERR(fi)) { + err = PTR_ERR(fi); + goto err; + } + + dscp = cfg->fc_dscp; + l = fib_find_node(t, &tp, key); + fa = l ? fib_find_alias(&l->leaf, slen, dscp, fi->fib_priority, + tb->tb_id, false) : NULL; + + /* Now fa, if non-NULL, points to the first fib alias + * with the same keys [prefix,dscp,priority], if such key already + * exists or to the node before which we will insert new one. + * + * If fa is NULL, we will need to allocate a new one and + * insert to the tail of the section matching the suffix length + * of the new alias. + */ + + if (fa && fa->fa_dscp == dscp && + fa->fa_info->fib_priority == fi->fib_priority) { + struct fib_alias *fa_first, *fa_match; + + err = -EEXIST; + if (cfg->fc_nlflags & NLM_F_EXCL) + goto out; + + nlflags &= ~NLM_F_EXCL; + + /* We have 2 goals: + * 1. Find exact match for type, scope, fib_info to avoid + * duplicate routes + * 2. Find next 'fa' (or head), NLM_F_APPEND inserts before it + */ + fa_match = NULL; + fa_first = fa; + hlist_for_each_entry_from(fa, fa_list) { + if ((fa->fa_slen != slen) || + (fa->tb_id != tb->tb_id) || + (fa->fa_dscp != dscp)) + break; + if (fa->fa_info->fib_priority != fi->fib_priority) + break; + if (fa->fa_type == cfg->fc_type && + fa->fa_info == fi) { + fa_match = fa; + break; + } + } + + if (cfg->fc_nlflags & NLM_F_REPLACE) { + struct fib_info *fi_drop; + u8 state; + + nlflags |= NLM_F_REPLACE; + fa = fa_first; + if (fa_match) { + if (fa == fa_match) + err = 0; + goto out; + } + err = -ENOBUFS; + new_fa = kmem_cache_alloc(fn_alias_kmem, GFP_KERNEL); + if (!new_fa) + goto out; + + fi_drop = fa->fa_info; + new_fa->fa_dscp = fa->fa_dscp; + new_fa->fa_info = fi; + new_fa->fa_type = cfg->fc_type; + state = fa->fa_state; + new_fa->fa_state = state & ~FA_S_ACCESSED; + new_fa->fa_slen = fa->fa_slen; + new_fa->tb_id = tb->tb_id; + new_fa->fa_default = -1; + new_fa->offload = 0; + new_fa->trap = 0; + new_fa->offload_failed = 0; + + hlist_replace_rcu(&fa->fa_list, &new_fa->fa_list); + + if (fib_find_alias(&l->leaf, fa->fa_slen, 0, 0, + tb->tb_id, true) == new_fa) { + enum fib_event_type fib_event; + + fib_event = FIB_EVENT_ENTRY_REPLACE; + err = call_fib_entry_notifiers(net, fib_event, + key, plen, + new_fa, extack); + if (err) { + hlist_replace_rcu(&new_fa->fa_list, + &fa->fa_list); + goto out_free_new_fa; + } + } + + rtmsg_fib(RTM_NEWROUTE, htonl(key), new_fa, plen, + tb->tb_id, &cfg->fc_nlinfo, nlflags); + + alias_free_mem_rcu(fa); + + fib_release_info(fi_drop); + if (state & FA_S_ACCESSED) + rt_cache_flush(cfg->fc_nlinfo.nl_net); + + goto succeeded; + } + /* Error if we find a perfect match which + * uses the same scope, type, and nexthop + * information. 
+ */ + if (fa_match) + goto out; + + if (cfg->fc_nlflags & NLM_F_APPEND) + nlflags |= NLM_F_APPEND; + else + fa = fa_first; + } + err = -ENOENT; + if (!(cfg->fc_nlflags & NLM_F_CREATE)) + goto out; + + nlflags |= NLM_F_CREATE; + err = -ENOBUFS; + new_fa = kmem_cache_alloc(fn_alias_kmem, GFP_KERNEL); + if (!new_fa) + goto out; + + new_fa->fa_info = fi; + new_fa->fa_dscp = dscp; + new_fa->fa_type = cfg->fc_type; + new_fa->fa_state = 0; + new_fa->fa_slen = slen; + new_fa->tb_id = tb->tb_id; + new_fa->fa_default = -1; + new_fa->offload = 0; + new_fa->trap = 0; + new_fa->offload_failed = 0; + + /* Insert new entry to the list. */ + err = fib_insert_alias(t, tp, l, new_fa, fa, key); + if (err) + goto out_free_new_fa; + + /* The alias was already inserted, so the node must exist. */ + l = l ? l : fib_find_node(t, &tp, key); + if (WARN_ON_ONCE(!l)) { + err = -ENOENT; + goto out_free_new_fa; + } + + if (fib_find_alias(&l->leaf, new_fa->fa_slen, 0, 0, tb->tb_id, true) == + new_fa) { + enum fib_event_type fib_event; + + fib_event = FIB_EVENT_ENTRY_REPLACE; + err = call_fib_entry_notifiers(net, fib_event, key, plen, + new_fa, extack); + if (err) + goto out_remove_new_fa; + } + + if (!plen) + tb->tb_num_default++; + + rt_cache_flush(cfg->fc_nlinfo.nl_net); + rtmsg_fib(RTM_NEWROUTE, htonl(key), new_fa, plen, new_fa->tb_id, + &cfg->fc_nlinfo, nlflags); +succeeded: + return 0; + +out_remove_new_fa: + fib_remove_alias(t, tp, l, new_fa); +out_free_new_fa: + kmem_cache_free(fn_alias_kmem, new_fa); +out: + fib_release_info(fi); +err: + return err; +} + +static inline t_key prefix_mismatch(t_key key, struct key_vector *n) +{ + t_key prefix = n->key; + + return (key ^ prefix) & (prefix | -prefix); +} + +bool fib_lookup_good_nhc(const struct fib_nh_common *nhc, int fib_flags, + const struct flowi4 *flp) +{ + if (nhc->nhc_flags & RTNH_F_DEAD) + return false; + + if (ip_ignore_linkdown(nhc->nhc_dev) && + nhc->nhc_flags & RTNH_F_LINKDOWN && + !(fib_flags & FIB_LOOKUP_IGNORE_LINKSTATE)) + return false; + + if (flp->flowi4_oif && flp->flowi4_oif != nhc->nhc_oif) + return false; + + return true; +} + +/* should be called with rcu_read_lock */ +int fib_table_lookup(struct fib_table *tb, const struct flowi4 *flp, + struct fib_result *res, int fib_flags) +{ + struct trie *t = (struct trie *) tb->tb_data; +#ifdef CONFIG_IP_FIB_TRIE_STATS + struct trie_use_stats __percpu *stats = t->stats; +#endif + const t_key key = ntohl(flp->daddr); + struct key_vector *n, *pn; + struct fib_alias *fa; + unsigned long index; + t_key cindex; + + pn = t->kv; + cindex = 0; + + n = get_child_rcu(pn, cindex); + if (!n) { + trace_fib_table_lookup(tb->tb_id, flp, NULL, -EAGAIN); + return -EAGAIN; + } + +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->gets); +#endif + + /* Step 1: Travel to the longest prefix match in the trie */ + for (;;) { + index = get_cindex(key, n); + + /* This bit of code is a bit tricky but it combines multiple + * checks into a single check. The prefix consists of the + * prefix plus zeros for the "bits" in the prefix. The index + * is the difference between the key and this value. From + * this we can actually derive several pieces of data. + * if (index >= (1ul << bits)) + * we have a mismatch in skip bits and failed + * else + * we know the value is cindex + * + * This check is safe even if bits == KEYLENGTH due to the + * fact that we can only allocate a node with 32 bits if a + * long is greater than 32 bits. + */ + if (index >= (1ul << n->bits)) + break; + + /* we have found a leaf. 
Prefixes have already been compared */ + if (IS_LEAF(n)) + goto found; + + /* only record pn and cindex if we are going to be chopping + * bits later. Otherwise we are just wasting cycles. + */ + if (n->slen > n->pos) { + pn = n; + cindex = index; + } + + n = get_child_rcu(n, index); + if (unlikely(!n)) + goto backtrace; + } + + /* Step 2: Sort out leaves and begin backtracing for longest prefix */ + for (;;) { + /* record the pointer where our next node pointer is stored */ + struct key_vector __rcu **cptr = n->tnode; + + /* This test verifies that none of the bits that differ + * between the key and the prefix exist in the region of + * the lsb and higher in the prefix. + */ + if (unlikely(prefix_mismatch(key, n)) || (n->slen == n->pos)) + goto backtrace; + + /* exit out and process leaf */ + if (unlikely(IS_LEAF(n))) + break; + + /* Don't bother recording parent info. Since we are in + * prefix match mode we will have to come back to wherever + * we started this traversal anyway + */ + + while ((n = rcu_dereference(*cptr)) == NULL) { +backtrace: +#ifdef CONFIG_IP_FIB_TRIE_STATS + if (!n) + this_cpu_inc(stats->null_node_hit); +#endif + /* If we are at cindex 0 there are no more bits for + * us to strip at this level so we must ascend back + * up one level to see if there are any more bits to + * be stripped there. + */ + while (!cindex) { + t_key pkey = pn->key; + + /* If we don't have a parent then there is + * nothing for us to do as we do not have any + * further nodes to parse. + */ + if (IS_TRIE(pn)) { + trace_fib_table_lookup(tb->tb_id, flp, + NULL, -EAGAIN); + return -EAGAIN; + } +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->backtrack); +#endif + /* Get Child's index */ + pn = node_parent_rcu(pn); + cindex = get_index(pkey, pn); + } + + /* strip the least significant bit from the cindex */ + cindex &= cindex - 1; + + /* grab pointer for next child node */ + cptr = &pn->tnode[cindex]; + } + } + +found: + /* this line carries forward the xor from earlier in the function */ + index = key ^ n->key; + + /* Step 3: Process the leaf, if that fails fall back to backtracing */ + hlist_for_each_entry_rcu(fa, &n->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + struct fib_nh_common *nhc; + int nhsel, err; + + if ((BITS_PER_LONG > KEYLENGTH) || (fa->fa_slen < KEYLENGTH)) { + if (index >= (1ul << fa->fa_slen)) + continue; + } + if (fa->fa_dscp && + inet_dscp_to_dsfield(fa->fa_dscp) != flp->flowi4_tos) + continue; + /* Paired with WRITE_ONCE() in fib_release_info() */ + if (READ_ONCE(fi->fib_dead)) + continue; + if (fa->fa_info->fib_scope < flp->flowi4_scope) + continue; + fib_alias_accessed(fa); + err = fib_props[fa->fa_type].error; + if (unlikely(err < 0)) { +out_reject: +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->semantic_match_passed); +#endif + trace_fib_table_lookup(tb->tb_id, flp, NULL, err); + return err; + } + if (fi->fib_flags & RTNH_F_DEAD) + continue; + + if (unlikely(fi->nh)) { + if (nexthop_is_blackhole(fi->nh)) { + err = fib_props[RTN_BLACKHOLE].error; + goto out_reject; + } + + nhc = nexthop_get_nhc_lookup(fi->nh, fib_flags, flp, + &nhsel); + if (nhc) + goto set_result; + goto miss; + } + + for (nhsel = 0; nhsel < fib_info_num_path(fi); nhsel++) { + nhc = fib_info_nhc(fi, nhsel); + + if (!fib_lookup_good_nhc(nhc, fib_flags, flp)) + continue; +set_result: + if (!(fib_flags & FIB_LOOKUP_NOREF)) + refcount_inc(&fi->fib_clntref); + + res->prefix = htonl(n->key); + res->prefixlen = KEYLENGTH - fa->fa_slen; + res->nh_sel = nhsel; + res->nhc = nhc; + 
res->type = fa->fa_type; + res->scope = fi->fib_scope; + res->fi = fi; + res->table = tb; + res->fa_head = &n->leaf; +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->semantic_match_passed); +#endif + trace_fib_table_lookup(tb->tb_id, flp, nhc, err); + + return err; + } + } +miss: +#ifdef CONFIG_IP_FIB_TRIE_STATS + this_cpu_inc(stats->semantic_match_miss); +#endif + goto backtrace; +} +EXPORT_SYMBOL_GPL(fib_table_lookup); + +static void fib_remove_alias(struct trie *t, struct key_vector *tp, + struct key_vector *l, struct fib_alias *old) +{ + /* record the location of the previous list_info entry */ + struct hlist_node **pprev = old->fa_list.pprev; + struct fib_alias *fa = hlist_entry(pprev, typeof(*fa), fa_list.next); + + /* remove the fib_alias from the list */ + hlist_del_rcu(&old->fa_list); + + /* if we emptied the list this leaf will be freed and we can sort + * out parent suffix lengths as a part of trie_rebalance + */ + if (hlist_empty(&l->leaf)) { + if (tp->slen == l->slen) + node_pull_suffix(tp, tp->pos); + put_child_root(tp, l->key, NULL); + node_free(l); + trie_rebalance(t, tp); + return; + } + + /* only access fa if it is pointing at the last valid hlist_node */ + if (*pprev) + return; + + /* update the trie with the latest suffix length */ + l->slen = fa->fa_slen; + node_pull_suffix(tp, fa->fa_slen); +} + +static void fib_notify_alias_delete(struct net *net, u32 key, + struct hlist_head *fah, + struct fib_alias *fa_to_delete, + struct netlink_ext_ack *extack) +{ + struct fib_alias *fa_next, *fa_to_notify; + u32 tb_id = fa_to_delete->tb_id; + u8 slen = fa_to_delete->fa_slen; + enum fib_event_type fib_event; + + /* Do not notify if we do not care about the route. */ + if (fib_find_alias(fah, slen, 0, 0, tb_id, true) != fa_to_delete) + return; + + /* Determine if the route should be replaced by the next route in the + * list. + */ + fa_next = hlist_entry_safe(fa_to_delete->fa_list.next, + struct fib_alias, fa_list); + if (fa_next && fa_next->fa_slen == slen && fa_next->tb_id == tb_id) { + fib_event = FIB_EVENT_ENTRY_REPLACE; + fa_to_notify = fa_next; + } else { + fib_event = FIB_EVENT_ENTRY_DEL; + fa_to_notify = fa_to_delete; + } + call_fib_entry_notifiers(net, fib_event, key, KEYLENGTH - slen, + fa_to_notify, extack); +} + +/* Caller must hold RTNL. 
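+ * fib_table_delete() finds the leaf for the prefix, scans its alias list
+ * for an entry matching the type, scope, prefsrc, protocol and nexthop
+ * given in cfg, notifies listeners and then unlinks and frees that alias.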
*/ +int fib_table_delete(struct net *net, struct fib_table *tb, + struct fib_config *cfg, struct netlink_ext_ack *extack) +{ + struct trie *t = (struct trie *) tb->tb_data; + struct fib_alias *fa, *fa_to_delete; + struct key_vector *l, *tp; + u8 plen = cfg->fc_dst_len; + u8 slen = KEYLENGTH - plen; + dscp_t dscp; + u32 key; + + key = ntohl(cfg->fc_dst); + + if (!fib_valid_key_len(key, plen, extack)) + return -EINVAL; + + l = fib_find_node(t, &tp, key); + if (!l) + return -ESRCH; + + dscp = cfg->fc_dscp; + fa = fib_find_alias(&l->leaf, slen, dscp, 0, tb->tb_id, false); + if (!fa) + return -ESRCH; + + pr_debug("Deleting %08x/%d dsfield=0x%02x t=%p\n", key, plen, + inet_dscp_to_dsfield(dscp), t); + + fa_to_delete = NULL; + hlist_for_each_entry_from(fa, fa_list) { + struct fib_info *fi = fa->fa_info; + + if ((fa->fa_slen != slen) || + (fa->tb_id != tb->tb_id) || + (fa->fa_dscp != dscp)) + break; + + if ((!cfg->fc_type || fa->fa_type == cfg->fc_type) && + (cfg->fc_scope == RT_SCOPE_NOWHERE || + fa->fa_info->fib_scope == cfg->fc_scope) && + (!cfg->fc_prefsrc || + fi->fib_prefsrc == cfg->fc_prefsrc) && + (!cfg->fc_protocol || + fi->fib_protocol == cfg->fc_protocol) && + fib_nh_match(net, cfg, fi, extack) == 0 && + fib_metrics_match(cfg, fi)) { + fa_to_delete = fa; + break; + } + } + + if (!fa_to_delete) + return -ESRCH; + + fib_notify_alias_delete(net, key, &l->leaf, fa_to_delete, extack); + rtmsg_fib(RTM_DELROUTE, htonl(key), fa_to_delete, plen, tb->tb_id, + &cfg->fc_nlinfo, 0); + + if (!plen) + tb->tb_num_default--; + + fib_remove_alias(t, tp, l, fa_to_delete); + + if (fa_to_delete->fa_state & FA_S_ACCESSED) + rt_cache_flush(cfg->fc_nlinfo.nl_net); + + fib_release_info(fa_to_delete->fa_info); + alias_free_mem_rcu(fa_to_delete); + return 0; +} + +/* Scan for the next leaf starting at the provided key value */ +static struct key_vector *leaf_walk_rcu(struct key_vector **tn, t_key key) +{ + struct key_vector *pn, *n = *tn; + unsigned long cindex; + + /* this loop is meant to try and find the key in the trie */ + do { + /* record parent and next child index */ + pn = n; + cindex = (key > pn->key) ? 
get_index(key, pn) : 0; + + if (cindex >> pn->bits) + break; + + /* descend into the next child */ + n = get_child_rcu(pn, cindex++); + if (!n) + break; + + /* guarantee forward progress on the keys */ + if (IS_LEAF(n) && (n->key >= key)) + goto found; + } while (IS_TNODE(n)); + + /* this loop will search for the next leaf with a greater key */ + while (!IS_TRIE(pn)) { + /* if we exhausted the parent node we will need to climb */ + if (cindex >= (1ul << pn->bits)) { + t_key pkey = pn->key; + + pn = node_parent_rcu(pn); + cindex = get_index(pkey, pn) + 1; + continue; + } + + /* grab the next available node */ + n = get_child_rcu(pn, cindex++); + if (!n) + continue; + + /* no need to compare keys since we bumped the index */ + if (IS_LEAF(n)) + goto found; + + /* Rescan start scanning in new node */ + pn = n; + cindex = 0; + } + + *tn = pn; + return NULL; /* Root of trie */ +found: + /* if we are at the limit for keys just return NULL for the tnode */ + *tn = pn; + return n; +} + +static void fib_trie_free(struct fib_table *tb) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct key_vector *pn = t->kv; + unsigned long cindex = 1; + struct hlist_node *tmp; + struct fib_alias *fa; + + /* walk trie in reverse order and free everything */ + for (;;) { + struct key_vector *n; + + if (!(cindex--)) { + t_key pkey = pn->key; + + if (IS_TRIE(pn)) + break; + + n = pn; + pn = node_parent(pn); + + /* drop emptied tnode */ + put_child_root(pn, n->key, NULL); + node_free(n); + + cindex = get_index(pkey, pn); + + continue; + } + + /* grab the next available node */ + n = get_child(pn, cindex); + if (!n) + continue; + + if (IS_TNODE(n)) { + /* record pn and cindex for leaf walking */ + pn = n; + cindex = 1ul << n->bits; + + continue; + } + + hlist_for_each_entry_safe(fa, tmp, &n->leaf, fa_list) { + hlist_del_rcu(&fa->fa_list); + alias_free_mem_rcu(fa); + } + + put_child_root(pn, n->key, NULL); + node_free(n); + } + +#ifdef CONFIG_IP_FIB_TRIE_STATS + free_percpu(t->stats); +#endif + kfree(tb); +} + +struct fib_table *fib_trie_unmerge(struct fib_table *oldtb) +{ + struct trie *ot = (struct trie *)oldtb->tb_data; + struct key_vector *l, *tp = ot->kv; + struct fib_table *local_tb; + struct fib_alias *fa; + struct trie *lt; + t_key key = 0; + + if (oldtb->tb_data == oldtb->__data) + return oldtb; + + local_tb = fib_trie_table(RT_TABLE_LOCAL, NULL); + if (!local_tb) + return NULL; + + lt = (struct trie *)local_tb->tb_data; + + while ((l = leaf_walk_rcu(&tp, key)) != NULL) { + struct key_vector *local_l = NULL, *local_tp; + + hlist_for_each_entry(fa, &l->leaf, fa_list) { + struct fib_alias *new_fa; + + if (local_tb->tb_id != fa->tb_id) + continue; + + /* clone fa for new local table */ + new_fa = kmem_cache_alloc(fn_alias_kmem, GFP_KERNEL); + if (!new_fa) + goto out; + + memcpy(new_fa, fa, sizeof(*fa)); + + /* insert clone into table */ + if (!local_l) + local_l = fib_find_node(lt, &local_tp, l->key); + + if (fib_insert_alias(lt, local_tp, local_l, new_fa, + NULL, l->key)) { + kmem_cache_free(fn_alias_kmem, new_fa); + goto out; + } + } + + /* stop loop if key wrapped back to 0 */ + key = l->key + 1; + if (key < l->key) + break; + } + + return local_tb; +out: + fib_trie_free(local_tb); + + return NULL; +} + +/* Caller must hold RTNL */ +void fib_table_flush_external(struct fib_table *tb) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct key_vector *pn = t->kv; + unsigned long cindex = 1; + struct hlist_node *tmp; + struct fib_alias *fa; + + /* walk trie in reverse order */ + for (;;) { + unsigned 
char slen = 0; + struct key_vector *n; + + if (!(cindex--)) { + t_key pkey = pn->key; + + /* cannot resize the trie vector */ + if (IS_TRIE(pn)) + break; + + /* update the suffix to address pulled leaves */ + if (pn->slen > pn->pos) + update_suffix(pn); + + /* resize completed node */ + pn = resize(t, pn); + cindex = get_index(pkey, pn); + + continue; + } + + /* grab the next available node */ + n = get_child(pn, cindex); + if (!n) + continue; + + if (IS_TNODE(n)) { + /* record pn and cindex for leaf walking */ + pn = n; + cindex = 1ul << n->bits; + + continue; + } + + hlist_for_each_entry_safe(fa, tmp, &n->leaf, fa_list) { + /* if alias was cloned to local then we just + * need to remove the local copy from main + */ + if (tb->tb_id != fa->tb_id) { + hlist_del_rcu(&fa->fa_list); + alias_free_mem_rcu(fa); + continue; + } + + /* record local slen */ + slen = fa->fa_slen; + } + + /* update leaf slen */ + n->slen = slen; + + if (hlist_empty(&n->leaf)) { + put_child_root(pn, n->key, NULL); + node_free(n); + } + } +} + +/* Caller must hold RTNL. */ +int fib_table_flush(struct net *net, struct fib_table *tb, bool flush_all) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct nl_info info = { .nl_net = net }; + struct key_vector *pn = t->kv; + unsigned long cindex = 1; + struct hlist_node *tmp; + struct fib_alias *fa; + int found = 0; + + /* walk trie in reverse order */ + for (;;) { + unsigned char slen = 0; + struct key_vector *n; + + if (!(cindex--)) { + t_key pkey = pn->key; + + /* cannot resize the trie vector */ + if (IS_TRIE(pn)) + break; + + /* update the suffix to address pulled leaves */ + if (pn->slen > pn->pos) + update_suffix(pn); + + /* resize completed node */ + pn = resize(t, pn); + cindex = get_index(pkey, pn); + + continue; + } + + /* grab the next available node */ + n = get_child(pn, cindex); + if (!n) + continue; + + if (IS_TNODE(n)) { + /* record pn and cindex for leaf walking */ + pn = n; + cindex = 1ul << n->bits; + + continue; + } + + hlist_for_each_entry_safe(fa, tmp, &n->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + + if (!fi || tb->tb_id != fa->tb_id || + (!(fi->fib_flags & RTNH_F_DEAD) && + !fib_props[fa->fa_type].error)) { + slen = fa->fa_slen; + continue; + } + + /* Do not flush error routes if network namespace is + * not being dismantled + */ + if (!flush_all && fib_props[fa->fa_type].error) { + slen = fa->fa_slen; + continue; + } + + fib_notify_alias_delete(net, n->key, &n->leaf, fa, + NULL); + if (fi->pfsrc_removed) + rtmsg_fib(RTM_DELROUTE, htonl(n->key), fa, + KEYLENGTH - fa->fa_slen, tb->tb_id, &info, 0); + hlist_del_rcu(&fa->fa_list); + fib_release_info(fa->fa_info); + alias_free_mem_rcu(fa); + found++; + } + + /* update leaf slen */ + n->slen = slen; + + if (hlist_empty(&n->leaf)) { + put_child_root(pn, n->key, NULL); + node_free(n); + } + } + + pr_debug("trie_flush found=%d\n", found); + return found; +} + +/* derived from fib_trie_free */ +static void __fib_info_notify_update(struct net *net, struct fib_table *tb, + struct nl_info *info) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct key_vector *pn = t->kv; + unsigned long cindex = 1; + struct fib_alias *fa; + + for (;;) { + struct key_vector *n; + + if (!(cindex--)) { + t_key pkey = pn->key; + + if (IS_TRIE(pn)) + break; + + pn = node_parent(pn); + cindex = get_index(pkey, pn); + continue; + } + + /* grab the next available node */ + n = get_child(pn, cindex); + if (!n) + continue; + + if (IS_TNODE(n)) { + /* record pn and cindex for leaf walking */ + pn = n; + cindex = 1ul 
<< n->bits; + + continue; + } + + hlist_for_each_entry(fa, &n->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + + if (!fi || !fi->nh_updated || fa->tb_id != tb->tb_id) + continue; + + rtmsg_fib(RTM_NEWROUTE, htonl(n->key), fa, + KEYLENGTH - fa->fa_slen, tb->tb_id, + info, NLM_F_REPLACE); + } + } +} + +void fib_info_notify_update(struct net *net, struct nl_info *info) +{ + unsigned int h; + + for (h = 0; h < FIB_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + struct fib_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb_hlist, + lockdep_rtnl_is_held()) + __fib_info_notify_update(net, tb, info); + } +} + +static int fib_leaf_notify(struct key_vector *l, struct fib_table *tb, + struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct fib_alias *fa; + int last_slen = -1; + int err; + + hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + + if (!fi) + continue; + + /* local and main table can share the same trie, + * so don't notify twice for the same entry. + */ + if (tb->tb_id != fa->tb_id) + continue; + + if (fa->fa_slen == last_slen) + continue; + + last_slen = fa->fa_slen; + err = call_fib_entry_notifier(nb, FIB_EVENT_ENTRY_REPLACE, + l->key, KEYLENGTH - fa->fa_slen, + fa, extack); + if (err) + return err; + } + return 0; +} + +static int fib_table_notify(struct fib_table *tb, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct key_vector *l, *tp = t->kv; + t_key key = 0; + int err; + + while ((l = leaf_walk_rcu(&tp, key)) != NULL) { + err = fib_leaf_notify(l, tb, nb, extack); + if (err) + return err; + + key = l->key + 1; + /* stop in case of wrap around */ + if (key < l->key) + break; + } + return 0; +} + +int fib_notify(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + unsigned int h; + int err; + + for (h = 0; h < FIB_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + struct fib_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb_hlist) { + err = fib_table_notify(tb, nb, extack); + if (err) + return err; + } + } + return 0; +} + +static void __trie_free_rcu(struct rcu_head *head) +{ + struct fib_table *tb = container_of(head, struct fib_table, rcu); +#ifdef CONFIG_IP_FIB_TRIE_STATS + struct trie *t = (struct trie *)tb->tb_data; + + if (tb->tb_data == tb->__data) + free_percpu(t->stats); +#endif /* CONFIG_IP_FIB_TRIE_STATS */ + kfree(tb); +} + +void fib_free_table(struct fib_table *tb) +{ + call_rcu(&tb->rcu, __trie_free_rcu); +} + +static int fn_trie_dump_leaf(struct key_vector *l, struct fib_table *tb, + struct sk_buff *skb, struct netlink_callback *cb, + struct fib_dump_filter *filter) +{ + unsigned int flags = NLM_F_MULTI; + __be32 xkey = htonl(l->key); + int i, s_i, i_fa, s_fa, err; + struct fib_alias *fa; + + if (filter->filter_set || + !filter->dump_exceptions || !filter->dump_routes) + flags |= NLM_F_DUMP_FILTERED; + + s_i = cb->args[4]; + s_fa = cb->args[5]; + i = 0; + + /* rcu_read_lock is hold by caller */ + hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + + if (i < s_i) + goto next; + + i_fa = 0; + + if (tb->tb_id != fa->tb_id) + goto next; + + if (filter->filter_set) { + if (filter->rt_type && fa->fa_type != filter->rt_type) + goto next; + + if ((filter->protocol && + fi->fib_protocol != filter->protocol)) + goto next; + + if (filter->dev && + !fib_info_nh_uses_dev(fi, filter->dev)) + goto next; + } + + if 
(filter->dump_routes) { + if (!s_fa) { + struct fib_rt_info fri; + + fri.fi = fi; + fri.tb_id = tb->tb_id; + fri.dst = xkey; + fri.dst_len = KEYLENGTH - fa->fa_slen; + fri.dscp = fa->fa_dscp; + fri.type = fa->fa_type; + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = READ_ONCE(fa->offload_failed); + err = fib_dump_info(skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWROUTE, &fri, flags); + if (err < 0) + goto stop; + } + + i_fa++; + } + + if (filter->dump_exceptions) { + err = fib_dump_info_fnhe(skb, cb, tb->tb_id, fi, + &i_fa, s_fa, flags); + if (err < 0) + goto stop; + } + +next: + i++; + } + + cb->args[4] = i; + return skb->len; + +stop: + cb->args[4] = i; + cb->args[5] = i_fa; + return err; +} + +/* rcu_read_lock needs to be hold by caller from readside */ +int fib_table_dump(struct fib_table *tb, struct sk_buff *skb, + struct netlink_callback *cb, struct fib_dump_filter *filter) +{ + struct trie *t = (struct trie *)tb->tb_data; + struct key_vector *l, *tp = t->kv; + /* Dump starting at last key. + * Note: 0.0.0.0/0 (ie default) is first key. + */ + int count = cb->args[2]; + t_key key = cb->args[3]; + + /* First time here, count and key are both always 0. Count > 0 + * and key == 0 means the dump has wrapped around and we are done. + */ + if (count && !key) + return skb->len; + + while ((l = leaf_walk_rcu(&tp, key)) != NULL) { + int err; + + err = fn_trie_dump_leaf(l, tb, skb, cb, filter); + if (err < 0) { + cb->args[3] = key; + cb->args[2] = count; + return err; + } + + ++count; + key = l->key + 1; + + memset(&cb->args[4], 0, + sizeof(cb->args) - 4*sizeof(cb->args[0])); + + /* stop loop if key wrapped back to 0 */ + if (key < l->key) + break; + } + + cb->args[3] = key; + cb->args[2] = count; + + return skb->len; +} + +void __init fib_trie_init(void) +{ + fn_alias_kmem = kmem_cache_create("ip_fib_alias", + sizeof(struct fib_alias), + 0, SLAB_PANIC | SLAB_ACCOUNT, NULL); + + trie_leaf_kmem = kmem_cache_create("ip_fib_trie", + LEAF_SIZE, + 0, SLAB_PANIC | SLAB_ACCOUNT, NULL); +} + +struct fib_table *fib_trie_table(u32 id, struct fib_table *alias) +{ + struct fib_table *tb; + struct trie *t; + size_t sz = sizeof(*tb); + + if (!alias) + sz += sizeof(struct trie); + + tb = kzalloc(sz, GFP_KERNEL); + if (!tb) + return NULL; + + tb->tb_id = id; + tb->tb_num_default = 0; + tb->tb_data = (alias ? 
alias->__data : tb->__data); + + if (alias) + return tb; + + t = (struct trie *) tb->tb_data; + t->kv[0].pos = KEYLENGTH; + t->kv[0].slen = KEYLENGTH; +#ifdef CONFIG_IP_FIB_TRIE_STATS + t->stats = alloc_percpu(struct trie_use_stats); + if (!t->stats) { + kfree(tb); + tb = NULL; + } +#endif + + return tb; +} + +#ifdef CONFIG_PROC_FS +/* Depth first Trie walk iterator */ +struct fib_trie_iter { + struct seq_net_private p; + struct fib_table *tb; + struct key_vector *tnode; + unsigned int index; + unsigned int depth; +}; + +static struct key_vector *fib_trie_get_next(struct fib_trie_iter *iter) +{ + unsigned long cindex = iter->index; + struct key_vector *pn = iter->tnode; + t_key pkey; + + pr_debug("get_next iter={node=%p index=%d depth=%d}\n", + iter->tnode, iter->index, iter->depth); + + while (!IS_TRIE(pn)) { + while (cindex < child_length(pn)) { + struct key_vector *n = get_child_rcu(pn, cindex++); + + if (!n) + continue; + + if (IS_LEAF(n)) { + iter->tnode = pn; + iter->index = cindex; + } else { + /* push down one level */ + iter->tnode = n; + iter->index = 0; + ++iter->depth; + } + + return n; + } + + /* Current node exhausted, pop back up */ + pkey = pn->key; + pn = node_parent_rcu(pn); + cindex = get_index(pkey, pn) + 1; + --iter->depth; + } + + /* record root node so further searches know we are done */ + iter->tnode = pn; + iter->index = 0; + + return NULL; +} + +static struct key_vector *fib_trie_get_first(struct fib_trie_iter *iter, + struct trie *t) +{ + struct key_vector *n, *pn; + + if (!t) + return NULL; + + pn = t->kv; + n = rcu_dereference(pn->tnode[0]); + if (!n) + return NULL; + + if (IS_TNODE(n)) { + iter->tnode = n; + iter->index = 0; + iter->depth = 1; + } else { + iter->tnode = pn; + iter->index = 0; + iter->depth = 0; + } + + return n; +} + +static void trie_collect_stats(struct trie *t, struct trie_stat *s) +{ + struct key_vector *n; + struct fib_trie_iter iter; + + memset(s, 0, sizeof(*s)); + + rcu_read_lock(); + for (n = fib_trie_get_first(&iter, t); n; n = fib_trie_get_next(&iter)) { + if (IS_LEAF(n)) { + struct fib_alias *fa; + + s->leaves++; + s->totdepth += iter.depth; + if (iter.depth > s->maxdepth) + s->maxdepth = iter.depth; + + hlist_for_each_entry_rcu(fa, &n->leaf, fa_list) + ++s->prefixes; + } else { + s->tnodes++; + if (n->bits < MAX_STAT_DEPTH) + s->nodesizes[n->bits]++; + s->nullpointers += tn_info(n)->empty_children; + } + } + rcu_read_unlock(); +} + +/* + * This outputs /proc/net/fib_triestats + */ +static void trie_show_stats(struct seq_file *seq, struct trie_stat *stat) +{ + unsigned int i, max, pointers, bytes, avdepth; + + if (stat->leaves) + avdepth = stat->totdepth*100 / stat->leaves; + else + avdepth = 0; + + seq_printf(seq, "\tAver depth: %u.%02d\n", + avdepth / 100, avdepth % 100); + seq_printf(seq, "\tMax depth: %u\n", stat->maxdepth); + + seq_printf(seq, "\tLeaves: %u\n", stat->leaves); + bytes = LEAF_SIZE * stat->leaves; + + seq_printf(seq, "\tPrefixes: %u\n", stat->prefixes); + bytes += sizeof(struct fib_alias) * stat->prefixes; + + seq_printf(seq, "\tInternal nodes: %u\n\t", stat->tnodes); + bytes += TNODE_SIZE(0) * stat->tnodes; + + max = MAX_STAT_DEPTH; + while (max > 0 && stat->nodesizes[max-1] == 0) + max--; + + pointers = 0; + for (i = 1; i < max; i++) + if (stat->nodesizes[i] != 0) { + seq_printf(seq, " %u: %u", i, stat->nodesizes[i]); + pointers += (1<nodesizes[i]; + } + seq_putc(seq, '\n'); + seq_printf(seq, "\tPointers: %u\n", pointers); + + bytes += sizeof(struct key_vector *) * pointers; + seq_printf(seq, "Null ptrs: %u\n", 
stat->nullpointers); + seq_printf(seq, "Total size: %u kB\n", (bytes + 1023) / 1024); +} + +#ifdef CONFIG_IP_FIB_TRIE_STATS +static void trie_show_usage(struct seq_file *seq, + const struct trie_use_stats __percpu *stats) +{ + struct trie_use_stats s = { 0 }; + int cpu; + + /* loop through all of the CPUs and gather up the stats */ + for_each_possible_cpu(cpu) { + const struct trie_use_stats *pcpu = per_cpu_ptr(stats, cpu); + + s.gets += pcpu->gets; + s.backtrack += pcpu->backtrack; + s.semantic_match_passed += pcpu->semantic_match_passed; + s.semantic_match_miss += pcpu->semantic_match_miss; + s.null_node_hit += pcpu->null_node_hit; + s.resize_node_skipped += pcpu->resize_node_skipped; + } + + seq_printf(seq, "\nCounters:\n---------\n"); + seq_printf(seq, "gets = %u\n", s.gets); + seq_printf(seq, "backtracks = %u\n", s.backtrack); + seq_printf(seq, "semantic match passed = %u\n", + s.semantic_match_passed); + seq_printf(seq, "semantic match miss = %u\n", s.semantic_match_miss); + seq_printf(seq, "null node hit= %u\n", s.null_node_hit); + seq_printf(seq, "skipped node resize = %u\n\n", s.resize_node_skipped); +} +#endif /* CONFIG_IP_FIB_TRIE_STATS */ + +static void fib_table_print(struct seq_file *seq, struct fib_table *tb) +{ + if (tb->tb_id == RT_TABLE_LOCAL) + seq_puts(seq, "Local:\n"); + else if (tb->tb_id == RT_TABLE_MAIN) + seq_puts(seq, "Main:\n"); + else + seq_printf(seq, "Id %d:\n", tb->tb_id); +} + + +static int fib_triestat_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = seq->private; + unsigned int h; + + seq_printf(seq, + "Basic info: size of leaf:" + " %zd bytes, size of tnode: %zd bytes.\n", + LEAF_SIZE, TNODE_SIZE(0)); + + rcu_read_lock(); + for (h = 0; h < FIB_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + struct fib_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb_hlist) { + struct trie *t = (struct trie *) tb->tb_data; + struct trie_stat stat; + + if (!t) + continue; + + fib_table_print(seq, tb); + + trie_collect_stats(t, &stat); + trie_show_stats(seq, &stat); +#ifdef CONFIG_IP_FIB_TRIE_STATS + trie_show_usage(seq, t->stats); +#endif + } + cond_resched_rcu(); + } + rcu_read_unlock(); + + return 0; +} + +static struct key_vector *fib_trie_get_idx(struct seq_file *seq, loff_t pos) +{ + struct fib_trie_iter *iter = seq->private; + struct net *net = seq_file_net(seq); + loff_t idx = 0; + unsigned int h; + + for (h = 0; h < FIB_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + struct fib_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb_hlist) { + struct key_vector *n; + + for (n = fib_trie_get_first(iter, + (struct trie *) tb->tb_data); + n; n = fib_trie_get_next(iter)) + if (pos == idx++) { + iter->tb = tb; + return n; + } + } + } + + return NULL; +} + +static void *fib_trie_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return fib_trie_get_idx(seq, *pos); +} + +static void *fib_trie_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct fib_trie_iter *iter = seq->private; + struct net *net = seq_file_net(seq); + struct fib_table *tb = iter->tb; + struct hlist_node *tb_node; + unsigned int h; + struct key_vector *n; + + ++*pos; + /* next node in same table */ + n = fib_trie_get_next(iter); + if (n) + return n; + + /* walk rest of this hash chain */ + h = tb->tb_id & (FIB_TABLE_HASHSZ - 1); + while ((tb_node = rcu_dereference(hlist_next_rcu(&tb->tb_hlist)))) { + tb = hlist_entry(tb_node, struct fib_table, tb_hlist); + n = 
fib_trie_get_first(iter, (struct trie *) tb->tb_data); + if (n) + goto found; + } + + /* new hash chain */ + while (++h < FIB_TABLE_HASHSZ) { + struct hlist_head *head = &net->ipv4.fib_table_hash[h]; + hlist_for_each_entry_rcu(tb, head, tb_hlist) { + n = fib_trie_get_first(iter, (struct trie *) tb->tb_data); + if (n) + goto found; + } + } + return NULL; + +found: + iter->tb = tb; + return n; +} + +static void fib_trie_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static void seq_indent(struct seq_file *seq, int n) +{ + while (n-- > 0) + seq_puts(seq, " "); +} + +static inline const char *rtn_scope(char *buf, size_t len, enum rt_scope_t s) +{ + switch (s) { + case RT_SCOPE_UNIVERSE: return "universe"; + case RT_SCOPE_SITE: return "site"; + case RT_SCOPE_LINK: return "link"; + case RT_SCOPE_HOST: return "host"; + case RT_SCOPE_NOWHERE: return "nowhere"; + default: + snprintf(buf, len, "scope=%d", s); + return buf; + } +} + +static const char *const rtn_type_names[__RTN_MAX] = { + [RTN_UNSPEC] = "UNSPEC", + [RTN_UNICAST] = "UNICAST", + [RTN_LOCAL] = "LOCAL", + [RTN_BROADCAST] = "BROADCAST", + [RTN_ANYCAST] = "ANYCAST", + [RTN_MULTICAST] = "MULTICAST", + [RTN_BLACKHOLE] = "BLACKHOLE", + [RTN_UNREACHABLE] = "UNREACHABLE", + [RTN_PROHIBIT] = "PROHIBIT", + [RTN_THROW] = "THROW", + [RTN_NAT] = "NAT", + [RTN_XRESOLVE] = "XRESOLVE", +}; + +static inline const char *rtn_type(char *buf, size_t len, unsigned int t) +{ + if (t < __RTN_MAX && rtn_type_names[t]) + return rtn_type_names[t]; + snprintf(buf, len, "type %u", t); + return buf; +} + +/* Pretty print the trie */ +static int fib_trie_seq_show(struct seq_file *seq, void *v) +{ + const struct fib_trie_iter *iter = seq->private; + struct key_vector *n = v; + + if (IS_TRIE(node_parent_rcu(n))) + fib_table_print(seq, iter->tb); + + if (IS_TNODE(n)) { + __be32 prf = htonl(n->key); + + seq_indent(seq, iter->depth-1); + seq_printf(seq, " +-- %pI4/%zu %u %u %u\n", + &prf, KEYLENGTH - n->pos - n->bits, n->bits, + tn_info(n)->full_children, + tn_info(n)->empty_children); + } else { + __be32 val = htonl(n->key); + struct fib_alias *fa; + + seq_indent(seq, iter->depth); + seq_printf(seq, " |-- %pI4\n", &val); + + hlist_for_each_entry_rcu(fa, &n->leaf, fa_list) { + char buf1[32], buf2[32]; + + seq_indent(seq, iter->depth + 1); + seq_printf(seq, " /%zu %s %s", + KEYLENGTH - fa->fa_slen, + rtn_scope(buf1, sizeof(buf1), + fa->fa_info->fib_scope), + rtn_type(buf2, sizeof(buf2), + fa->fa_type)); + if (fa->fa_dscp) + seq_printf(seq, " tos=%d", + inet_dscp_to_dsfield(fa->fa_dscp)); + seq_putc(seq, '\n'); + } + } + + return 0; +} + +static const struct seq_operations fib_trie_seq_ops = { + .start = fib_trie_seq_start, + .next = fib_trie_seq_next, + .stop = fib_trie_seq_stop, + .show = fib_trie_seq_show, +}; + +struct fib_route_iter { + struct seq_net_private p; + struct fib_table *main_tb; + struct key_vector *tnode; + loff_t pos; + t_key key; +}; + +static struct key_vector *fib_route_get_idx(struct fib_route_iter *iter, + loff_t pos) +{ + struct key_vector *l, **tp = &iter->tnode; + t_key key; + + /* use cached location of previously found key */ + if (iter->pos > 0 && pos >= iter->pos) { + key = iter->key; + } else { + iter->pos = 1; + key = 0; + } + + pos -= iter->pos; + + while ((l = leaf_walk_rcu(tp, key)) && (pos-- > 0)) { + key = l->key + 1; + iter->pos++; + l = NULL; + + /* handle unlikely case of a key wrap */ + if (!key) + break; + } + + if (l) + iter->key = l->key; /* remember it */ + else + iter->pos = 0; /* 
forget it */ + + return l; +} + +static void *fib_route_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct fib_route_iter *iter = seq->private; + struct fib_table *tb; + struct trie *t; + + rcu_read_lock(); + + tb = fib_get_table(seq_file_net(seq), RT_TABLE_MAIN); + if (!tb) + return NULL; + + iter->main_tb = tb; + t = (struct trie *)tb->tb_data; + iter->tnode = t->kv; + + if (*pos != 0) + return fib_route_get_idx(iter, *pos); + + iter->pos = 0; + iter->key = KEY_MAX; + + return SEQ_START_TOKEN; +} + +static void *fib_route_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct fib_route_iter *iter = seq->private; + struct key_vector *l = NULL; + t_key key = iter->key + 1; + + ++*pos; + + /* only allow key of 0 for start of sequence */ + if ((v == SEQ_START_TOKEN) || key) + l = leaf_walk_rcu(&iter->tnode, key); + + if (l) { + iter->key = l->key; + iter->pos++; + } else { + iter->pos = 0; + } + + return l; +} + +static void fib_route_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static unsigned int fib_flag_trans(int type, __be32 mask, struct fib_info *fi) +{ + unsigned int flags = 0; + + if (type == RTN_UNREACHABLE || type == RTN_PROHIBIT) + flags = RTF_REJECT; + if (fi) { + const struct fib_nh_common *nhc = fib_info_nhc(fi, 0); + + if (nhc->nhc_gw.ipv4) + flags |= RTF_GATEWAY; + } + if (mask == htonl(0xFFFFFFFF)) + flags |= RTF_HOST; + flags |= RTF_UP; + return flags; +} + +/* + * This outputs /proc/net/route. + * The format of the file is not supposed to be changed + * and needs to be same as fib_hash output to avoid breaking + * legacy utilities + */ +static int fib_route_seq_show(struct seq_file *seq, void *v) +{ + struct fib_route_iter *iter = seq->private; + struct fib_table *tb = iter->main_tb; + struct fib_alias *fa; + struct key_vector *l = v; + __be32 prefix; + + if (v == SEQ_START_TOKEN) { + seq_printf(seq, "%-127s\n", "Iface\tDestination\tGateway " + "\tFlags\tRefCnt\tUse\tMetric\tMask\t\tMTU" + "\tWindow\tIRTT"); + return 0; + } + + prefix = htonl(l->key); + + hlist_for_each_entry_rcu(fa, &l->leaf, fa_list) { + struct fib_info *fi = fa->fa_info; + __be32 mask = inet_make_mask(KEYLENGTH - fa->fa_slen); + unsigned int flags = fib_flag_trans(fa->fa_type, mask, fi); + + if ((fa->fa_type == RTN_BROADCAST) || + (fa->fa_type == RTN_MULTICAST)) + continue; + + if (fa->tb_id != tb->tb_id) + continue; + + seq_setwidth(seq, 127); + + if (fi) { + struct fib_nh_common *nhc = fib_info_nhc(fi, 0); + __be32 gw = 0; + + if (nhc->nhc_gw_family == AF_INET) + gw = nhc->nhc_gw.ipv4; + + seq_printf(seq, + "%s\t%08X\t%08X\t%04X\t%d\t%u\t" + "%d\t%08X\t%d\t%u\t%u", + nhc->nhc_dev ? nhc->nhc_dev->name : "*", + prefix, gw, flags, 0, 0, + fi->fib_priority, + mask, + (fi->fib_advmss ? 
+ fi->fib_advmss + 40 : 0), + fi->fib_window, + fi->fib_rtt >> 3); + } else { + seq_printf(seq, + "*\t%08X\t%08X\t%04X\t%d\t%u\t" + "%d\t%08X\t%d\t%u\t%u", + prefix, 0, flags, 0, 0, 0, + mask, 0, 0, 0); + } + seq_pad(seq, '\n'); + } + + return 0; +} + +static const struct seq_operations fib_route_seq_ops = { + .start = fib_route_seq_start, + .next = fib_route_seq_next, + .stop = fib_route_seq_stop, + .show = fib_route_seq_show, +}; + +int __net_init fib_proc_init(struct net *net) +{ + if (!proc_create_net("fib_trie", 0444, net->proc_net, &fib_trie_seq_ops, + sizeof(struct fib_trie_iter))) + goto out1; + + if (!proc_create_net_single("fib_triestat", 0444, net->proc_net, + fib_triestat_seq_show, NULL)) + goto out2; + + if (!proc_create_net("route", 0444, net->proc_net, &fib_route_seq_ops, + sizeof(struct fib_route_iter))) + goto out3; + + return 0; + +out3: + remove_proc_entry("fib_triestat", net->proc_net); +out2: + remove_proc_entry("fib_trie", net->proc_net); +out1: + return -ENOMEM; +} + +void __net_exit fib_proc_exit(struct net *net) +{ + remove_proc_entry("fib_trie", net->proc_net); + remove_proc_entry("fib_triestat", net->proc_net); + remove_proc_entry("route", net->proc_net); +} + +#endif /* CONFIG_PROC_FS */ diff --git a/net/ipv4/fou.c b/net/ipv4/fou.c new file mode 100644 index 000000000..0c3c6d0ce --- /dev/null +++ b/net/ipv4/fou.c @@ -0,0 +1,1294 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct fou { + struct socket *sock; + u8 protocol; + u8 flags; + __be16 port; + u8 family; + u16 type; + struct list_head list; + struct rcu_head rcu; +}; + +#define FOU_F_REMCSUM_NOPARTIAL BIT(0) + +struct fou_cfg { + u16 type; + u8 protocol; + u8 flags; + struct udp_port_cfg udp_config; +}; + +static unsigned int fou_net_id; + +struct fou_net { + struct list_head fou_list; + struct mutex fou_lock; +}; + +static inline struct fou *fou_from_sock(struct sock *sk) +{ + return sk->sk_user_data; +} + +static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len) +{ + /* Remove 'len' bytes from the packet (UDP header and + * FOU header if present). 
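+ * The outer IPv4 total length (or IPv6 payload length) is reduced by the
+ * same amount before the pull so the length fields stay consistent with
+ * the trimmed packet.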
+ */ + if (fou->family == AF_INET) + ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len); + else + ipv6_hdr(skb)->payload_len = + htons(ntohs(ipv6_hdr(skb)->payload_len) - len); + + __skb_pull(skb, len); + skb_postpull_rcsum(skb, udp_hdr(skb), len); + skb_reset_transport_header(skb); + return iptunnel_pull_offloads(skb); +} + +static int fou_udp_recv(struct sock *sk, struct sk_buff *skb) +{ + struct fou *fou = fou_from_sock(sk); + + if (!fou) + return 1; + + if (fou_recv_pull(skb, fou, sizeof(struct udphdr))) + goto drop; + + return -fou->protocol; + +drop: + kfree_skb(skb); + return 0; +} + +static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr, + void *data, size_t hdrlen, u8 ipproto, + bool nopartial) +{ + __be16 *pd = data; + size_t start = ntohs(pd[0]); + size_t offset = ntohs(pd[1]); + size_t plen = sizeof(struct udphdr) + hdrlen + + max_t(size_t, offset + sizeof(u16), start); + + if (skb->remcsum_offload) + return guehdr; + + if (!pskb_may_pull(skb, plen)) + return NULL; + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + + skb_remcsum_process(skb, (void *)guehdr + hdrlen, + start, offset, nopartial); + + return guehdr; +} + +static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr) +{ + /* No support yet */ + kfree_skb(skb); + return 0; +} + +static int gue_udp_recv(struct sock *sk, struct sk_buff *skb) +{ + struct fou *fou = fou_from_sock(sk); + size_t len, optlen, hdrlen; + struct guehdr *guehdr; + void *data; + u16 doffset = 0; + u8 proto_ctype; + + if (!fou) + return 1; + + len = sizeof(struct udphdr) + sizeof(struct guehdr); + if (!pskb_may_pull(skb, len)) + goto drop; + + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + + switch (guehdr->version) { + case 0: /* Full GUE header present */ + break; + + case 1: { + /* Direct encapsulation of IPv4 or IPv6 */ + + int prot; + + switch (((struct iphdr *)guehdr)->version) { + case 4: + prot = IPPROTO_IPIP; + break; + case 6: + prot = IPPROTO_IPV6; + break; + default: + goto drop; + } + + if (fou_recv_pull(skb, fou, sizeof(struct udphdr))) + goto drop; + + return -prot; + } + + default: /* Undefined version */ + goto drop; + } + + optlen = guehdr->hlen << 2; + len += optlen; + + if (!pskb_may_pull(skb, len)) + goto drop; + + /* guehdr may change after pull */ + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + + if (validate_gue_flags(guehdr, optlen)) + goto drop; + + hdrlen = sizeof(struct guehdr) + optlen; + + if (fou->family == AF_INET) + ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len); + else + ipv6_hdr(skb)->payload_len = + htons(ntohs(ipv6_hdr(skb)->payload_len) - len); + + /* Pull csum through the guehdr now . This can be used if + * there is a remote checksum offload. 
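+ * (Only skb->csum is adjusted here; the header bytes are not actually
+ * pulled until after the GUE options have been parsed below.)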
+ */ + skb_postpull_rcsum(skb, udp_hdr(skb), len); + + data = &guehdr[1]; + + if (guehdr->flags & GUE_FLAG_PRIV) { + __be32 flags = *(__be32 *)(data + doffset); + + doffset += GUE_LEN_PRIV; + + if (flags & GUE_PFLAG_REMCSUM) { + guehdr = gue_remcsum(skb, guehdr, data + doffset, + hdrlen, guehdr->proto_ctype, + !!(fou->flags & + FOU_F_REMCSUM_NOPARTIAL)); + if (!guehdr) + goto drop; + + data = &guehdr[1]; + + doffset += GUE_PLEN_REMCSUM; + } + } + + if (unlikely(guehdr->control)) + return gue_control_message(skb, guehdr); + + proto_ctype = guehdr->proto_ctype; + __skb_pull(skb, sizeof(struct udphdr) + hdrlen); + skb_reset_transport_header(skb); + + if (iptunnel_pull_offloads(skb)) + goto drop; + + return -proto_ctype; + +drop: + kfree_skb(skb); + return 0; +} + +static struct sk_buff *fou_gro_receive(struct sock *sk, + struct list_head *head, + struct sk_buff *skb) +{ + const struct net_offload __rcu **offloads; + u8 proto = fou_from_sock(sk)->protocol; + const struct net_offload *ops; + struct sk_buff *pp = NULL; + + /* We can clear the encap_mark for FOU as we are essentially doing + * one of two possible things. We are either adding an L4 tunnel + * header to the outer L3 tunnel header, or we are simply + * treating the GRE tunnel header as though it is a UDP protocol + * specific header such as VXLAN or GENEVE. + */ + NAPI_GRO_CB(skb)->encap_mark = 0; + + /* Flag this frame as already having an outer encap header */ + NAPI_GRO_CB(skb)->is_fou = 1; + + offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; + ops = rcu_dereference(offloads[proto]); + if (!ops || !ops->callbacks.gro_receive) + goto out; + + pp = call_gro_receive(ops->callbacks.gro_receive, head, skb); + +out: + return pp; +} + +static int fou_gro_complete(struct sock *sk, struct sk_buff *skb, + int nhoff) +{ + const struct net_offload __rcu **offloads; + u8 proto = fou_from_sock(sk)->protocol; + const struct net_offload *ops; + int err = -ENOSYS; + + offloads = NAPI_GRO_CB(skb)->is_ipv6 ? 
inet6_offloads : inet_offloads; + ops = rcu_dereference(offloads[proto]); + if (WARN_ON(!ops || !ops->callbacks.gro_complete)) + goto out; + + err = ops->callbacks.gro_complete(skb, nhoff); + + skb_set_inner_mac_header(skb, nhoff); + +out: + return err; +} + +static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off, + struct guehdr *guehdr, void *data, + size_t hdrlen, struct gro_remcsum *grc, + bool nopartial) +{ + __be16 *pd = data; + size_t start = ntohs(pd[0]); + size_t offset = ntohs(pd[1]); + + if (skb->remcsum_offload) + return guehdr; + + if (!NAPI_GRO_CB(skb)->csum_valid) + return NULL; + + guehdr = skb_gro_remcsum_process(skb, (void *)guehdr, off, hdrlen, + start, offset, grc, nopartial); + + skb->remcsum_offload = 1; + + return guehdr; +} + +static struct sk_buff *gue_gro_receive(struct sock *sk, + struct list_head *head, + struct sk_buff *skb) +{ + const struct net_offload __rcu **offloads; + const struct net_offload *ops; + struct sk_buff *pp = NULL; + struct sk_buff *p; + struct guehdr *guehdr; + size_t len, optlen, hdrlen, off; + void *data; + u16 doffset = 0; + int flush = 1; + struct fou *fou = fou_from_sock(sk); + struct gro_remcsum grc; + u8 proto; + + skb_gro_remcsum_init(&grc); + + off = skb_gro_offset(skb); + len = off + sizeof(*guehdr); + + guehdr = skb_gro_header(skb, len, off); + if (unlikely(!guehdr)) + goto out; + + switch (guehdr->version) { + case 0: + break; + case 1: + switch (((struct iphdr *)guehdr)->version) { + case 4: + proto = IPPROTO_IPIP; + break; + case 6: + proto = IPPROTO_IPV6; + break; + default: + goto out; + } + goto next_proto; + default: + goto out; + } + + optlen = guehdr->hlen << 2; + len += optlen; + + if (skb_gro_header_hard(skb, len)) { + guehdr = skb_gro_header_slow(skb, len, off); + if (unlikely(!guehdr)) + goto out; + } + + if (unlikely(guehdr->control) || guehdr->version != 0 || + validate_gue_flags(guehdr, optlen)) + goto out; + + hdrlen = sizeof(*guehdr) + optlen; + + /* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr, + * this is needed if there is a remote checkcsum offload. + */ + skb_gro_postpull_rcsum(skb, guehdr, hdrlen); + + data = &guehdr[1]; + + if (guehdr->flags & GUE_FLAG_PRIV) { + __be32 flags = *(__be32 *)(data + doffset); + + doffset += GUE_LEN_PRIV; + + if (flags & GUE_PFLAG_REMCSUM) { + guehdr = gue_gro_remcsum(skb, off, guehdr, + data + doffset, hdrlen, &grc, + !!(fou->flags & + FOU_F_REMCSUM_NOPARTIAL)); + + if (!guehdr) + goto out; + + data = &guehdr[1]; + + doffset += GUE_PLEN_REMCSUM; + } + } + + skb_gro_pull(skb, hdrlen); + + list_for_each_entry(p, head, list) { + const struct guehdr *guehdr2; + + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + guehdr2 = (struct guehdr *)(p->data + off); + + /* Compare base GUE header to be equal (covers + * hlen, version, proto_ctype, and flags. + */ + if (guehdr->word != guehdr2->word) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + + /* Compare optional fields are the same. */ + if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1], + guehdr->hlen << 2)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + } + + proto = guehdr->proto_ctype; + +next_proto: + + /* We can clear the encap_mark for GUE as we are essentially doing + * one of two possible things. We are either adding an L4 tunnel + * header to the outer L3 tunnel header, or we are simply + * treating the GRE tunnel header as though it is a UDP protocol + * specific header such as VXLAN or GENEVE. 
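+ * Clearing encap_mark lets the inner protocol's gro_receive callback
+ * run as though this were the first level of encapsulation.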
+ */ + NAPI_GRO_CB(skb)->encap_mark = 0; + + /* Flag this frame as already having an outer encap header */ + NAPI_GRO_CB(skb)->is_fou = 1; + + offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; + ops = rcu_dereference(offloads[proto]); + if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive)) + goto out; + + pp = call_gro_receive(ops->callbacks.gro_receive, head, skb); + flush = 0; + +out: + skb_gro_flush_final_remcsum(skb, pp, flush, &grc); + + return pp; +} + +static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff) +{ + struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff); + const struct net_offload __rcu **offloads; + const struct net_offload *ops; + unsigned int guehlen = 0; + u8 proto; + int err = -ENOENT; + + switch (guehdr->version) { + case 0: + proto = guehdr->proto_ctype; + guehlen = sizeof(*guehdr) + (guehdr->hlen << 2); + break; + case 1: + switch (((struct iphdr *)guehdr)->version) { + case 4: + proto = IPPROTO_IPIP; + break; + case 6: + proto = IPPROTO_IPV6; + break; + default: + return err; + } + break; + default: + return err; + } + + offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads; + ops = rcu_dereference(offloads[proto]); + if (WARN_ON(!ops || !ops->callbacks.gro_complete)) + goto out; + + err = ops->callbacks.gro_complete(skb, nhoff + guehlen); + + skb_set_inner_mac_header(skb, nhoff + guehlen); + +out: + return err; +} + +static bool fou_cfg_cmp(struct fou *fou, struct fou_cfg *cfg) +{ + struct sock *sk = fou->sock->sk; + struct udp_port_cfg *udp_cfg = &cfg->udp_config; + + if (fou->family != udp_cfg->family || + fou->port != udp_cfg->local_udp_port || + sk->sk_dport != udp_cfg->peer_udp_port || + sk->sk_bound_dev_if != udp_cfg->bind_ifindex) + return false; + + if (fou->family == AF_INET) { + if (sk->sk_rcv_saddr != udp_cfg->local_ip.s_addr || + sk->sk_daddr != udp_cfg->peer_ip.s_addr) + return false; + else + return true; +#if IS_ENABLED(CONFIG_IPV6) + } else { + if (ipv6_addr_cmp(&sk->sk_v6_rcv_saddr, &udp_cfg->local_ip6) || + ipv6_addr_cmp(&sk->sk_v6_daddr, &udp_cfg->peer_ip6)) + return false; + else + return true; +#endif + } + + return false; +} + +static int fou_add_to_port_list(struct net *net, struct fou *fou, + struct fou_cfg *cfg) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + struct fou *fout; + + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { + if (fou_cfg_cmp(fout, cfg)) { + mutex_unlock(&fn->fou_lock); + return -EALREADY; + } + } + + list_add(&fou->list, &fn->fou_list); + mutex_unlock(&fn->fou_lock); + + return 0; +} + +static void fou_release(struct fou *fou) +{ + struct socket *sock = fou->sock; + + list_del(&fou->list); + udp_tunnel_sock_release(sock); + + kfree_rcu(fou, rcu); +} + +static int fou_create(struct net *net, struct fou_cfg *cfg, + struct socket **sockp) +{ + struct socket *sock = NULL; + struct fou *fou = NULL; + struct sock *sk; + struct udp_tunnel_sock_cfg tunnel_cfg; + int err; + + /* Open UDP socket */ + err = udp_sock_create(net, &cfg->udp_config, &sock); + if (err < 0) + goto error; + + /* Allocate FOU port structure */ + fou = kzalloc(sizeof(*fou), GFP_KERNEL); + if (!fou) { + err = -ENOMEM; + goto error; + } + + sk = sock->sk; + + fou->port = cfg->udp_config.local_udp_port; + fou->family = cfg->udp_config.family; + fou->flags = cfg->flags; + fou->type = cfg->type; + fou->sock = sock; + + memset(&tunnel_cfg, 0, sizeof(tunnel_cfg)); + tunnel_cfg.encap_type = 1; + tunnel_cfg.sk_user_data = fou; + tunnel_cfg.encap_destroy = 
NULL; + + /* Initial for fou type */ + switch (cfg->type) { + case FOU_ENCAP_DIRECT: + tunnel_cfg.encap_rcv = fou_udp_recv; + tunnel_cfg.gro_receive = fou_gro_receive; + tunnel_cfg.gro_complete = fou_gro_complete; + fou->protocol = cfg->protocol; + break; + case FOU_ENCAP_GUE: + tunnel_cfg.encap_rcv = gue_udp_recv; + tunnel_cfg.gro_receive = gue_gro_receive; + tunnel_cfg.gro_complete = gue_gro_complete; + break; + default: + err = -EINVAL; + goto error; + } + + setup_udp_tunnel_sock(net, sock, &tunnel_cfg); + + sk->sk_allocation = GFP_ATOMIC; + + err = fou_add_to_port_list(net, fou, cfg); + if (err) + goto error; + + if (sockp) + *sockp = sock; + + return 0; + +error: + kfree(fou); + if (sock) + udp_tunnel_sock_release(sock); + + return err; +} + +static int fou_destroy(struct net *net, struct fou_cfg *cfg) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + int err = -EINVAL; + struct fou *fou; + + mutex_lock(&fn->fou_lock); + list_for_each_entry(fou, &fn->fou_list, list) { + if (fou_cfg_cmp(fou, cfg)) { + fou_release(fou); + err = 0; + break; + } + } + mutex_unlock(&fn->fou_lock); + + return err; +} + +static struct genl_family fou_nl_family; + +static const struct nla_policy fou_nl_policy[FOU_ATTR_MAX + 1] = { + [FOU_ATTR_PORT] = { .type = NLA_U16, }, + [FOU_ATTR_AF] = { .type = NLA_U8, }, + [FOU_ATTR_IPPROTO] = { .type = NLA_U8, }, + [FOU_ATTR_TYPE] = { .type = NLA_U8, }, + [FOU_ATTR_REMCSUM_NOPARTIAL] = { .type = NLA_FLAG, }, + [FOU_ATTR_LOCAL_V4] = { .type = NLA_U32, }, + [FOU_ATTR_PEER_V4] = { .type = NLA_U32, }, + [FOU_ATTR_LOCAL_V6] = { .len = sizeof(struct in6_addr), }, + [FOU_ATTR_PEER_V6] = { .len = sizeof(struct in6_addr), }, + [FOU_ATTR_PEER_PORT] = { .type = NLA_U16, }, + [FOU_ATTR_IFINDEX] = { .type = NLA_S32, }, +}; + +static int parse_nl_config(struct genl_info *info, + struct fou_cfg *cfg) +{ + bool has_local = false, has_peer = false; + struct nlattr *attr; + int ifindex; + __be16 port; + + memset(cfg, 0, sizeof(*cfg)); + + cfg->udp_config.family = AF_INET; + + if (info->attrs[FOU_ATTR_AF]) { + u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]); + + switch (family) { + case AF_INET: + break; + case AF_INET6: + cfg->udp_config.ipv6_v6only = 1; + break; + default: + return -EAFNOSUPPORT; + } + + cfg->udp_config.family = family; + } + + if (info->attrs[FOU_ATTR_PORT]) { + port = nla_get_be16(info->attrs[FOU_ATTR_PORT]); + cfg->udp_config.local_udp_port = port; + } + + if (info->attrs[FOU_ATTR_IPPROTO]) + cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]); + + if (info->attrs[FOU_ATTR_TYPE]) + cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]); + + if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL]) + cfg->flags |= FOU_F_REMCSUM_NOPARTIAL; + + if (cfg->udp_config.family == AF_INET) { + if (info->attrs[FOU_ATTR_LOCAL_V4]) { + attr = info->attrs[FOU_ATTR_LOCAL_V4]; + cfg->udp_config.local_ip.s_addr = nla_get_in_addr(attr); + has_local = true; + } + + if (info->attrs[FOU_ATTR_PEER_V4]) { + attr = info->attrs[FOU_ATTR_PEER_V4]; + cfg->udp_config.peer_ip.s_addr = nla_get_in_addr(attr); + has_peer = true; + } +#if IS_ENABLED(CONFIG_IPV6) + } else { + if (info->attrs[FOU_ATTR_LOCAL_V6]) { + attr = info->attrs[FOU_ATTR_LOCAL_V6]; + cfg->udp_config.local_ip6 = nla_get_in6_addr(attr); + has_local = true; + } + + if (info->attrs[FOU_ATTR_PEER_V6]) { + attr = info->attrs[FOU_ATTR_PEER_V6]; + cfg->udp_config.peer_ip6 = nla_get_in6_addr(attr); + has_peer = true; + } +#endif + } + + if (has_peer) { + if (info->attrs[FOU_ATTR_PEER_PORT]) { + port = 
nla_get_be16(info->attrs[FOU_ATTR_PEER_PORT]); + cfg->udp_config.peer_udp_port = port; + } else { + return -EINVAL; + } + } + + if (info->attrs[FOU_ATTR_IFINDEX]) { + if (!has_local) + return -EINVAL; + + ifindex = nla_get_s32(info->attrs[FOU_ATTR_IFINDEX]); + + cfg->udp_config.bind_ifindex = ifindex; + } + + return 0; +} + +static int fou_nl_cmd_add_port(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct fou_cfg cfg; + int err; + + err = parse_nl_config(info, &cfg); + if (err) + return err; + + return fou_create(net, &cfg, NULL); +} + +static int fou_nl_cmd_rm_port(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct fou_cfg cfg; + int err; + + err = parse_nl_config(info, &cfg); + if (err) + return err; + + return fou_destroy(net, &cfg); +} + +static int fou_fill_info(struct fou *fou, struct sk_buff *msg) +{ + struct sock *sk = fou->sock->sk; + + if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) || + nla_put_be16(msg, FOU_ATTR_PORT, fou->port) || + nla_put_be16(msg, FOU_ATTR_PEER_PORT, sk->sk_dport) || + nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) || + nla_put_u8(msg, FOU_ATTR_TYPE, fou->type) || + nla_put_s32(msg, FOU_ATTR_IFINDEX, sk->sk_bound_dev_if)) + return -1; + + if (fou->flags & FOU_F_REMCSUM_NOPARTIAL) + if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL)) + return -1; + + if (fou->sock->sk->sk_family == AF_INET) { + if (nla_put_in_addr(msg, FOU_ATTR_LOCAL_V4, sk->sk_rcv_saddr)) + return -1; + + if (nla_put_in_addr(msg, FOU_ATTR_PEER_V4, sk->sk_daddr)) + return -1; +#if IS_ENABLED(CONFIG_IPV6) + } else { + if (nla_put_in6_addr(msg, FOU_ATTR_LOCAL_V6, + &sk->sk_v6_rcv_saddr)) + return -1; + + if (nla_put_in6_addr(msg, FOU_ATTR_PEER_V6, &sk->sk_v6_daddr)) + return -1; +#endif + } + + return 0; +} + +static int fou_dump_info(struct fou *fou, u32 portid, u32 seq, + u32 flags, struct sk_buff *skb, u8 cmd) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + if (fou_fill_info(fou, skb) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int fou_nl_cmd_get_port(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct fou_net *fn = net_generic(net, fou_net_id); + struct sk_buff *msg; + struct fou_cfg cfg; + struct fou *fout; + __be16 port; + u8 family; + int ret; + + ret = parse_nl_config(info, &cfg); + if (ret) + return ret; + port = cfg.udp_config.local_udp_port; + if (port == 0) + return -EINVAL; + + family = cfg.udp_config.family; + if (family != AF_INET && family != AF_INET6) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = -ESRCH; + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { + if (fou_cfg_cmp(fout, &cfg)) { + ret = fou_dump_info(fout, info->snd_portid, + info->snd_seq, 0, msg, + info->genlhdr->cmd); + break; + } + } + mutex_unlock(&fn->fou_lock); + if (ret < 0) + goto out_free; + + return genlmsg_reply(msg, info); + +out_free: + nlmsg_free(msg); + return ret; +} + +static int fou_nl_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct fou_net *fn = net_generic(net, fou_net_id); + struct fou *fout; + int idx = 0, ret; + + mutex_lock(&fn->fou_lock); + list_for_each_entry(fout, &fn->fou_list, list) { + if (idx++ < cb->args[0]) + 
continue; + ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + skb, FOU_CMD_GET); + if (ret) + break; + } + mutex_unlock(&fn->fou_lock); + + cb->args[0] = idx; + return skb->len; +} + +static const struct genl_small_ops fou_nl_ops[] = { + { + .cmd = FOU_CMD_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_cmd_add_port, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = FOU_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_cmd_rm_port, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = FOU_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = fou_nl_cmd_get_port, + .dumpit = fou_nl_dump, + }, +}; + +static struct genl_family fou_nl_family __ro_after_init = { + .hdrsize = 0, + .name = FOU_GENL_NAME, + .version = FOU_GENL_VERSION, + .maxattr = FOU_ATTR_MAX, + .policy = fou_nl_policy, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = fou_nl_ops, + .n_small_ops = ARRAY_SIZE(fou_nl_ops), + .resv_start_op = FOU_CMD_GET + 1, +}; + +size_t fou_encap_hlen(struct ip_tunnel_encap *e) +{ + return sizeof(struct udphdr); +} +EXPORT_SYMBOL(fou_encap_hlen); + +size_t gue_encap_hlen(struct ip_tunnel_encap *e) +{ + size_t len; + bool need_priv = false; + + len = sizeof(struct udphdr) + sizeof(struct guehdr); + + if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) { + len += GUE_PLEN_REMCSUM; + need_priv = true; + } + + len += need_priv ? GUE_LEN_PRIV : 0; + + return len; +} +EXPORT_SYMBOL(gue_encap_hlen); + +int __fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, __be16 *sport, int type) +{ + int err; + + err = iptunnel_handle_offloads(skb, type); + if (err) + return err; + + *sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev), + skb, 0, 0, false); + + return 0; +} +EXPORT_SYMBOL(__fou_build_header); + +int __gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, __be16 *sport, int type) +{ + struct guehdr *guehdr; + size_t hdrlen, optlen = 0; + void *data; + bool need_priv = false; + int err; + + if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + optlen += GUE_PLEN_REMCSUM; + type |= SKB_GSO_TUNNEL_REMCSUM; + need_priv = true; + } + + optlen += need_priv ? GUE_LEN_PRIV : 0; + + err = iptunnel_handle_offloads(skb, type); + if (err) + return err; + + /* Get source port (based on flow hash) before skb_push */ + *sport = e->sport ? 
: udp_flow_src_port(dev_net(skb->dev), + skb, 0, 0, false); + + hdrlen = sizeof(struct guehdr) + optlen; + + skb_push(skb, hdrlen); + + guehdr = (struct guehdr *)skb->data; + + guehdr->control = 0; + guehdr->version = 0; + guehdr->hlen = optlen >> 2; + guehdr->flags = 0; + guehdr->proto_ctype = *protocol; + + data = &guehdr[1]; + + if (need_priv) { + __be32 *flags = data; + + guehdr->flags |= GUE_FLAG_PRIV; + *flags = 0; + data += GUE_LEN_PRIV; + + if (type & SKB_GSO_TUNNEL_REMCSUM) { + u16 csum_start = skb_checksum_start_offset(skb); + __be16 *pd = data; + + if (csum_start < hdrlen) + return -EINVAL; + + csum_start -= hdrlen; + pd[0] = htons(csum_start); + pd[1] = htons(csum_start + skb->csum_offset); + + if (!skb_is_gso(skb)) { + skb->ip_summed = CHECKSUM_NONE; + skb->encapsulation = 0; + } + + *flags |= GUE_PFLAG_REMCSUM; + data += GUE_PLEN_REMCSUM; + } + + } + + return 0; +} +EXPORT_SYMBOL(__gue_build_header); + +#ifdef CONFIG_NET_FOU_IP_TUNNELS + +static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e, + struct flowi4 *fl4, u8 *protocol, __be16 sport) +{ + struct udphdr *uh; + + skb_push(skb, sizeof(struct udphdr)); + skb_reset_transport_header(skb); + + uh = udp_hdr(skb); + + uh->dest = e->dport; + uh->source = sport; + uh->len = htons(skb->len); + udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb, + fl4->saddr, fl4->daddr, skb->len); + + *protocol = IPPROTO_UDP; +} + +static int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, struct flowi4 *fl4) +{ + int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM : + SKB_GSO_UDP_TUNNEL; + __be16 sport; + int err; + + err = __fou_build_header(skb, e, protocol, &sport, type); + if (err) + return err; + + fou_build_udp(skb, e, fl4, protocol, sport); + + return 0; +} + +static int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, struct flowi4 *fl4) +{ + int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? 
SKB_GSO_UDP_TUNNEL_CSUM : + SKB_GSO_UDP_TUNNEL; + __be16 sport; + int err; + + err = __gue_build_header(skb, e, protocol, &sport, type); + if (err) + return err; + + fou_build_udp(skb, e, fl4, protocol, sport); + + return 0; +} + +static int gue_err_proto_handler(int proto, struct sk_buff *skb, u32 info) +{ + const struct net_protocol *ipprot = rcu_dereference(inet_protos[proto]); + + if (ipprot && ipprot->err_handler) { + if (!ipprot->err_handler(skb, info)) + return 0; + } + + return -ENOENT; +} + +static int gue_err(struct sk_buff *skb, u32 info) +{ + int transport_offset = skb_transport_offset(skb); + struct guehdr *guehdr; + size_t len, optlen; + int ret; + + len = sizeof(struct udphdr) + sizeof(struct guehdr); + if (!pskb_may_pull(skb, transport_offset + len)) + return -EINVAL; + + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + + switch (guehdr->version) { + case 0: /* Full GUE header present */ + break; + case 1: { + /* Direct encapsulation of IPv4 or IPv6 */ + skb_set_transport_header(skb, -(int)sizeof(struct icmphdr)); + + switch (((struct iphdr *)guehdr)->version) { + case 4: + ret = gue_err_proto_handler(IPPROTO_IPIP, skb, info); + goto out; +#if IS_ENABLED(CONFIG_IPV6) + case 6: + ret = gue_err_proto_handler(IPPROTO_IPV6, skb, info); + goto out; +#endif + default: + ret = -EOPNOTSUPP; + goto out; + } + } + default: /* Undefined version */ + return -EOPNOTSUPP; + } + + if (guehdr->control) + return -ENOENT; + + optlen = guehdr->hlen << 2; + + if (!pskb_may_pull(skb, transport_offset + len + optlen)) + return -EINVAL; + + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + if (validate_gue_flags(guehdr, optlen)) + return -EINVAL; + + /* Handling exceptions for direct UDP encapsulation in GUE would lead to + * recursion. Besides, this kind of encapsulation can't even be + * configured currently. Discard this. 
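+ * (note: an inner IPPROTO_UDP/UDPLITE error would be handed straight back
+ * to the UDP error path, which could resolve to this same encapsulation
+ * socket and re-enter this handler)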
+ */ + if (guehdr->proto_ctype == IPPROTO_UDP || + guehdr->proto_ctype == IPPROTO_UDPLITE) + return -EOPNOTSUPP; + + skb_set_transport_header(skb, -(int)sizeof(struct icmphdr)); + ret = gue_err_proto_handler(guehdr->proto_ctype, skb, info); + +out: + skb_set_transport_header(skb, transport_offset); + return ret; +} + + +static const struct ip_tunnel_encap_ops fou_iptun_ops = { + .encap_hlen = fou_encap_hlen, + .build_header = fou_build_header, + .err_handler = gue_err, +}; + +static const struct ip_tunnel_encap_ops gue_iptun_ops = { + .encap_hlen = gue_encap_hlen, + .build_header = gue_build_header, + .err_handler = gue_err, +}; + +static int ip_tunnel_encap_add_fou_ops(void) +{ + int ret; + + ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU); + if (ret < 0) { + pr_err("can't add fou ops\n"); + return ret; + } + + ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE); + if (ret < 0) { + pr_err("can't add gue ops\n"); + ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU); + return ret; + } + + return 0; +} + +static void ip_tunnel_encap_del_fou_ops(void) +{ + ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU); + ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE); +} + +#else + +static int ip_tunnel_encap_add_fou_ops(void) +{ + return 0; +} + +static void ip_tunnel_encap_del_fou_ops(void) +{ +} + +#endif + +static __net_init int fou_init_net(struct net *net) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + + INIT_LIST_HEAD(&fn->fou_list); + mutex_init(&fn->fou_lock); + return 0; +} + +static __net_exit void fou_exit_net(struct net *net) +{ + struct fou_net *fn = net_generic(net, fou_net_id); + struct fou *fou, *next; + + /* Close all the FOU sockets */ + mutex_lock(&fn->fou_lock); + list_for_each_entry_safe(fou, next, &fn->fou_list, list) + fou_release(fou); + mutex_unlock(&fn->fou_lock); +} + +static struct pernet_operations fou_net_ops = { + .init = fou_init_net, + .exit = fou_exit_net, + .id = &fou_net_id, + .size = sizeof(struct fou_net), +}; + +static int __init fou_init(void) +{ + int ret; + + ret = register_pernet_device(&fou_net_ops); + if (ret) + goto exit; + + ret = genl_register_family(&fou_nl_family); + if (ret < 0) + goto unregister; + + ret = ip_tunnel_encap_add_fou_ops(); + if (ret == 0) + return 0; + + genl_unregister_family(&fou_nl_family); +unregister: + unregister_pernet_device(&fou_net_ops); +exit: + return ret; +} + +static void __exit fou_fini(void) +{ + ip_tunnel_encap_del_fou_ops(); + genl_unregister_family(&fou_nl_family); + unregister_pernet_device(&fou_net_ops); +} + +module_init(fou_init); +module_exit(fou_fini); +MODULE_AUTHOR("Tom Herbert "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Foo over UDP"); diff --git a/net/ipv4/gre_demux.c b/net/ipv4/gre_demux.c new file mode 100644 index 000000000..cbb2b4bb0 --- /dev/null +++ b/net/ipv4/gre_demux.c @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * GRE over IPv4 demultiplexer driver + * + * Authors: Dmitry Kozlov (xeb@mail.ru) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +static const struct gre_protocol __rcu *gre_proto[GREPROTO_MAX] __read_mostly; + +int gre_add_protocol(const struct gre_protocol *proto, u8 version) +{ + if (version >= GREPROTO_MAX) + return -EINVAL; + + return (cmpxchg((const struct gre_protocol **)&gre_proto[version], NULL, proto) == NULL) ? 
+ 0 : -EBUSY; +} +EXPORT_SYMBOL_GPL(gre_add_protocol); + +int gre_del_protocol(const struct gre_protocol *proto, u8 version) +{ + int ret; + + if (version >= GREPROTO_MAX) + return -EINVAL; + + ret = (cmpxchg((const struct gre_protocol **)&gre_proto[version], proto, NULL) == proto) ? + 0 : -EBUSY; + + if (ret) + return ret; + + synchronize_rcu(); + return 0; +} +EXPORT_SYMBOL_GPL(gre_del_protocol); + +/* Fills in tpi and returns header length to be pulled. + * Note that caller must use pskb_may_pull() before pulling GRE header. + */ +int gre_parse_header(struct sk_buff *skb, struct tnl_ptk_info *tpi, + bool *csum_err, __be16 proto, int nhs) +{ + const struct gre_base_hdr *greh; + __be32 *options; + int hdr_len; + + if (unlikely(!pskb_may_pull(skb, nhs + sizeof(struct gre_base_hdr)))) + return -EINVAL; + + greh = (struct gre_base_hdr *)(skb->data + nhs); + if (unlikely(greh->flags & (GRE_VERSION | GRE_ROUTING))) + return -EINVAL; + + tpi->flags = gre_flags_to_tnl_flags(greh->flags); + hdr_len = gre_calc_hlen(tpi->flags); + + if (!pskb_may_pull(skb, nhs + hdr_len)) + return -EINVAL; + + greh = (struct gre_base_hdr *)(skb->data + nhs); + tpi->proto = greh->protocol; + + options = (__be32 *)(greh + 1); + if (greh->flags & GRE_CSUM) { + if (!skb_checksum_simple_validate(skb)) { + skb_checksum_try_convert(skb, IPPROTO_GRE, + null_compute_pseudo); + } else if (csum_err) { + *csum_err = true; + return -EINVAL; + } + + options++; + } + + if (greh->flags & GRE_KEY) { + tpi->key = *options; + options++; + } else { + tpi->key = 0; + } + if (unlikely(greh->flags & GRE_SEQ)) { + tpi->seq = *options; + options++; + } else { + tpi->seq = 0; + } + /* WCCP version 1 and 2 protocol decoding. + * - Change protocol to IPv4/IPv6 + * - When dealing with WCCPv2, Skip extra 4 bytes in GRE header + */ + if (greh->flags == 0 && tpi->proto == htons(ETH_P_WCCP)) { + u8 _val, *val; + + val = skb_header_pointer(skb, nhs + hdr_len, + sizeof(_val), &_val); + if (!val) + return -EINVAL; + tpi->proto = proto; + if ((*val & 0xF0) != 0x40) + hdr_len += 4; + } + tpi->hdr_len = hdr_len; + + /* ERSPAN ver 1 and 2 protocol sets GRE key field + * to 0 and sets the configured key in the + * inner erspan header field + */ + if ((greh->protocol == htons(ETH_P_ERSPAN) && hdr_len != 4) || + greh->protocol == htons(ETH_P_ERSPAN2)) { + struct erspan_base_hdr *ershdr; + + if (!pskb_may_pull(skb, nhs + hdr_len + sizeof(*ershdr))) + return -EINVAL; + + ershdr = (struct erspan_base_hdr *)(skb->data + nhs + hdr_len); + tpi->key = cpu_to_be32(get_session_id(ershdr)); + } + + return hdr_len; +} +EXPORT_SYMBOL(gre_parse_header); + +static int gre_rcv(struct sk_buff *skb) +{ + const struct gre_protocol *proto; + u8 ver; + int ret; + + if (!pskb_may_pull(skb, 12)) + goto drop; + + ver = skb->data[1]&0x7f; + if (ver >= GREPROTO_MAX) + goto drop; + + rcu_read_lock(); + proto = rcu_dereference(gre_proto[ver]); + if (!proto || !proto->handler) + goto drop_unlock; + ret = proto->handler(skb); + rcu_read_unlock(); + return ret; + +drop_unlock: + rcu_read_unlock(); +drop: + kfree_skb(skb); + return NET_RX_DROP; +} + +static int gre_err(struct sk_buff *skb, u32 info) +{ + const struct gre_protocol *proto; + const struct iphdr *iph = (const struct iphdr *)skb->data; + u8 ver = skb->data[(iph->ihl<<2) + 1]&0x7f; + int err = 0; + + if (ver >= GREPROTO_MAX) + return -EINVAL; + + rcu_read_lock(); + proto = rcu_dereference(gre_proto[ver]); + if (proto && proto->err_handler) + proto->err_handler(skb, info); + else + err = -EPROTONOSUPPORT; + rcu_read_unlock(); + 
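+ /* err stays 0 when a handler registered for this GRE version was
+  * invoked; -EPROTONOSUPPORT means no handler exists for the version.
+  */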
+ return err; +} + +static const struct net_protocol net_gre_protocol = { + .handler = gre_rcv, + .err_handler = gre_err, +}; + +static int __init gre_init(void) +{ + pr_info("GRE over IPv4 demultiplexor driver\n"); + + if (inet_add_protocol(&net_gre_protocol, IPPROTO_GRE) < 0) { + pr_err("can't add protocol\n"); + return -EAGAIN; + } + return 0; +} + +static void __exit gre_exit(void) +{ + inet_del_protocol(&net_gre_protocol, IPPROTO_GRE); +} + +module_init(gre_init); +module_exit(gre_exit); + +MODULE_DESCRIPTION("GRE over IPv4 demultiplexer driver"); +MODULE_AUTHOR("D. Kozlov (xeb@mail.ru)"); +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/gre_offload.c b/net/ipv4/gre_offload.c new file mode 100644 index 000000000..2b9cb5398 --- /dev/null +++ b/net/ipv4/gre_offload.c @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV4 GSO/GRO offload support + * Linux INET implementation + * + * GRE GSO support + */ + +#include +#include +#include +#include +#include + +static struct sk_buff *gre_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb); + bool need_csum, offload_csum, gso_partial, need_ipsec; + struct sk_buff *segs = ERR_PTR(-EINVAL); + u16 mac_offset = skb->mac_header; + __be16 protocol = skb->protocol; + u16 mac_len = skb->mac_len; + int gre_offset, outer_hlen; + + if (!skb->encapsulation) + goto out; + + if (unlikely(tnl_hlen < sizeof(struct gre_base_hdr))) + goto out; + + if (unlikely(!pskb_may_pull(skb, tnl_hlen))) + goto out; + + /* setup inner skb. */ + skb->encapsulation = 0; + SKB_GSO_CB(skb)->encap_level = 0; + __skb_pull(skb, tnl_hlen); + skb_reset_mac_header(skb); + skb_set_network_header(skb, skb_inner_network_offset(skb)); + skb->mac_len = skb_inner_network_offset(skb); + skb->protocol = skb->inner_protocol; + + need_csum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_GRE_CSUM); + skb->encap_hdr_csum = need_csum; + + features &= skb->dev->hw_enc_features; + if (need_csum) + features &= ~NETIF_F_SCTP_CRC; + + need_ipsec = skb_dst(skb) && dst_xfrm(skb_dst(skb)); + /* Try to offload checksum if possible */ + offload_csum = !!(need_csum && !need_ipsec && + (skb->dev->features & NETIF_F_HW_CSUM)); + + /* segment inner packet. */ + segs = skb_mac_gso_segment(skb, features); + if (IS_ERR_OR_NULL(segs)) { + skb_gso_error_unwind(skb, protocol, tnl_hlen, mac_offset, + mac_len); + goto out; + } + + gso_partial = !!(skb_shinfo(segs)->gso_type & SKB_GSO_PARTIAL); + + outer_hlen = skb_tnl_header_len(skb); + gre_offset = outer_hlen - tnl_hlen; + skb = segs; + do { + struct gre_base_hdr *greh; + __sum16 *pcsum; + + /* Set up inner headers if we are offloading inner checksum */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + skb_reset_inner_headers(skb); + skb->encapsulation = 1; + } + + skb->mac_len = mac_len; + skb->protocol = protocol; + + __skb_push(skb, outer_hlen); + skb_reset_mac_header(skb); + skb_set_network_header(skb, mac_len); + skb_set_transport_header(skb, gre_offset); + + if (!need_csum) + continue; + + greh = (struct gre_base_hdr *)skb_transport_header(skb); + pcsum = (__sum16 *)(greh + 1); + + if (gso_partial && skb_is_gso(skb)) { + unsigned int partial_adj; + + /* Adjust checksum to account for the fact that + * the partial checksum is based on actual size + * whereas headers should be based on MSS size. 
+ */ + partial_adj = skb->len + skb_headroom(skb) - + SKB_GSO_CB(skb)->data_offset - + skb_shinfo(skb)->gso_size; + *pcsum = ~csum_fold((__force __wsum)htonl(partial_adj)); + } else { + *pcsum = 0; + } + + *(pcsum + 1) = 0; + if (skb->encapsulation || !offload_csum) { + *pcsum = gso_make_checksum(skb, 0); + } else { + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = sizeof(*greh); + } + } while ((skb = skb->next)); +out: + return segs; +} + +static struct sk_buff *gre_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + struct sk_buff *pp = NULL; + struct sk_buff *p; + const struct gre_base_hdr *greh; + unsigned int hlen, grehlen; + unsigned int off; + int flush = 1; + struct packet_offload *ptype; + __be16 type; + + if (NAPI_GRO_CB(skb)->encap_mark) + goto out; + + NAPI_GRO_CB(skb)->encap_mark = 1; + + off = skb_gro_offset(skb); + hlen = off + sizeof(*greh); + greh = skb_gro_header(skb, hlen, off); + if (unlikely(!greh)) + goto out; + + /* Only support version 0 and K (key), C (csum) flags. Note that + * although the support for the S (seq#) flag can be added easily + * for GRO, this is problematic for GSO hence can not be enabled + * here because a GRO pkt may end up in the forwarding path, thus + * requiring GSO support to break it up correctly. + */ + if ((greh->flags & ~(GRE_KEY|GRE_CSUM)) != 0) + goto out; + + /* We can only support GRE_CSUM if we can track the location of + * the GRE header. In the case of FOU/GUE we cannot because the + * outer UDP header displaces the GRE header leaving us in a state + * of limbo. + */ + if ((greh->flags & GRE_CSUM) && NAPI_GRO_CB(skb)->is_fou) + goto out; + + type = greh->protocol; + + ptype = gro_find_receive_by_type(type); + if (!ptype) + goto out; + + grehlen = GRE_HEADER_SECTION; + + if (greh->flags & GRE_KEY) + grehlen += GRE_HEADER_SECTION; + + if (greh->flags & GRE_CSUM) + grehlen += GRE_HEADER_SECTION; + + hlen = off + grehlen; + if (skb_gro_header_hard(skb, hlen)) { + greh = skb_gro_header_slow(skb, hlen, off); + if (unlikely(!greh)) + goto out; + } + + /* Don't bother verifying checksum if we're going to flush anyway. */ + if ((greh->flags & GRE_CSUM) && !NAPI_GRO_CB(skb)->flush) { + if (skb_gro_checksum_simple_validate(skb)) + goto out; + + skb_gro_checksum_try_convert(skb, IPPROTO_GRE, + null_compute_pseudo); + } + + list_for_each_entry(p, head, list) { + const struct gre_base_hdr *greh2; + + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + /* The following checks are needed to ensure only pkts + * from the same tunnel are considered for aggregation. + * The criteria for "the same tunnel" includes: + * 1) same version (we only support version 0 here) + * 2) same protocol (we only support ETH_P_IP for now) + * 3) same set of flags + * 4) same key if the key field is present. 
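+ * A held packet that fails any of these checks gets same_flow cleared
+ * and is therefore never merged with the packet currently being received.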
+ */ + greh2 = (struct gre_base_hdr *)(p->data + off); + + if (greh2->flags != greh->flags || + greh2->protocol != greh->protocol) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + if (greh->flags & GRE_KEY) { + /* compare keys */ + if (*(__be32 *)(greh2+1) != *(__be32 *)(greh+1)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + } + } + + skb_gro_pull(skb, grehlen); + + /* Adjusted NAPI_GRO_CB(skb)->csum after skb_gro_pull()*/ + skb_gro_postpull_rcsum(skb, greh, grehlen); + + pp = call_gro_receive(ptype->callbacks.gro_receive, head, skb); + flush = 0; + +out: + skb_gro_flush_final(skb, pp, flush); + + return pp; +} + +static int gre_gro_complete(struct sk_buff *skb, int nhoff) +{ + struct gre_base_hdr *greh = (struct gre_base_hdr *)(skb->data + nhoff); + struct packet_offload *ptype; + unsigned int grehlen = sizeof(*greh); + int err = -ENOENT; + __be16 type; + + skb->encapsulation = 1; + skb_shinfo(skb)->gso_type = SKB_GSO_GRE; + + type = greh->protocol; + if (greh->flags & GRE_KEY) + grehlen += GRE_HEADER_SECTION; + + if (greh->flags & GRE_CSUM) + grehlen += GRE_HEADER_SECTION; + + ptype = gro_find_complete_by_type(type); + if (ptype) + err = ptype->callbacks.gro_complete(skb, nhoff + grehlen); + + skb_set_inner_mac_header(skb, nhoff + grehlen); + + return err; +} + +static const struct net_offload gre_offload = { + .callbacks = { + .gso_segment = gre_gso_segment, + .gro_receive = gre_gro_receive, + .gro_complete = gre_gro_complete, + }, +}; + +static int __init gre_offload_init(void) +{ + int err; + + err = inet_add_offload(&gre_offload, IPPROTO_GRE); +#if IS_ENABLED(CONFIG_IPV6) + if (err) + return err; + + err = inet6_add_offload(&gre_offload, IPPROTO_GRE); + if (err) + inet_del_offload(&gre_offload, IPPROTO_GRE); +#endif + + return err; +} +device_initcall(gre_offload_init); diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c new file mode 100644 index 000000000..2b09ef707 --- /dev/null +++ b/net/ipv4/icmp.c @@ -0,0 +1,1507 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3: Implementation of the ICMP protocol layer. + * + * Alan Cox, + * + * Some of the function names and the icmp unreach table for this + * module were derived from [icmp.c 1.0.11 06/02/93] by + * Ross Biro, Fred N. van Kempen, Mark Evans, Alan Cox, Gerhard Koerting. + * Other than that this module is a complete rewrite. + * + * Fixes: + * Clemens Fruhwirth : introduce global icmp rate limiting + * with icmp type masking ability instead + * of broken per type icmp timeouts. + * Mike Shaver : RFC1122 checks. + * Alan Cox : Multicast ping reply as self. + * Alan Cox : Fix atomicity lockup in ip_build_xmit + * call. + * Alan Cox : Added 216,128 byte paths to the MTU + * code. + * Martin Mares : RFC1812 checks. + * Martin Mares : Can be configured to follow redirects + * if acting as a router _without_ a + * routing protocol (RFC 1812). + * Martin Mares : Echo requests may be configured to + * be ignored (RFC 1812). + * Martin Mares : Limitation of ICMP error message + * transmit rate (RFC 1812). + * Martin Mares : TOS and Precedence set correctly + * (RFC 1812). + * Martin Mares : Now copying as much data from the + * original packet as we can without + * exceeding 576 bytes (RFC 1812). + * Willy Konynenberg : Transparent proxying support. + * Keith Owens : RFC1191 correction for 4.2BSD based + * path MTU bug. + * Thomas Quinot : ICMP Dest Unreach codes up to 15 are + * valid (RFC 1812). + * Andi Kleen : Check all packet lengths properly + * and moved all kfree_skb() up to + * icmp_rcv. 
+ * Andi Kleen : Move the rate limit bookkeeping + * into the dest entry and use a token + * bucket filter (thanks to ANK). Make + * the rates sysctl configurable. + * Yu Tianli : Fixed two ugly bugs in icmp_send + * - IP option length was accounted wrongly + * - ICMP header length was not accounted + * at all. + * Tristan Greaves : Added sysctl option to ignore bogus + * broadcast responses from broken routers. + * + * To Fix: + * + * - Should use skb_pull() instead of all the manual checking. + * This would also greatly simply some upper layer error handlers. --AK + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Build xmit assembly blocks + */ + +struct icmp_bxm { + struct sk_buff *skb; + int offset; + int data_len; + + struct { + struct icmphdr icmph; + __be32 times[3]; + } data; + int head_len; + struct ip_options_data replyopts; +}; + +/* An array of errno for error messages from dest unreach. */ +/* RFC 1122: 3.2.2.1 States that NET_UNREACH, HOST_UNREACH and SR_FAILED MUST be considered 'transient errs'. */ + +const struct icmp_err icmp_err_convert[] = { + { + .errno = ENETUNREACH, /* ICMP_NET_UNREACH */ + .fatal = 0, + }, + { + .errno = EHOSTUNREACH, /* ICMP_HOST_UNREACH */ + .fatal = 0, + }, + { + .errno = ENOPROTOOPT /* ICMP_PROT_UNREACH */, + .fatal = 1, + }, + { + .errno = ECONNREFUSED, /* ICMP_PORT_UNREACH */ + .fatal = 1, + }, + { + .errno = EMSGSIZE, /* ICMP_FRAG_NEEDED */ + .fatal = 0, + }, + { + .errno = EOPNOTSUPP, /* ICMP_SR_FAILED */ + .fatal = 0, + }, + { + .errno = ENETUNREACH, /* ICMP_NET_UNKNOWN */ + .fatal = 1, + }, + { + .errno = EHOSTDOWN, /* ICMP_HOST_UNKNOWN */ + .fatal = 1, + }, + { + .errno = ENONET, /* ICMP_HOST_ISOLATED */ + .fatal = 1, + }, + { + .errno = ENETUNREACH, /* ICMP_NET_ANO */ + .fatal = 1, + }, + { + .errno = EHOSTUNREACH, /* ICMP_HOST_ANO */ + .fatal = 1, + }, + { + .errno = ENETUNREACH, /* ICMP_NET_UNR_TOS */ + .fatal = 0, + }, + { + .errno = EHOSTUNREACH, /* ICMP_HOST_UNR_TOS */ + .fatal = 0, + }, + { + .errno = EHOSTUNREACH, /* ICMP_PKT_FILTERED */ + .fatal = 1, + }, + { + .errno = EHOSTUNREACH, /* ICMP_PREC_VIOLATION */ + .fatal = 1, + }, + { + .errno = EHOSTUNREACH, /* ICMP_PREC_CUTOFF */ + .fatal = 1, + }, +}; +EXPORT_SYMBOL(icmp_err_convert); + +/* + * ICMP control array. This specifies what to do with each ICMP. + */ + +struct icmp_control { + enum skb_drop_reason (*handler)(struct sk_buff *skb); + short error; /* This ICMP is classed as an error message */ +}; + +static const struct icmp_control icmp_pointers[NR_ICMP_TYPES+1]; + +static DEFINE_PER_CPU(struct sock *, ipv4_icmp_sk); + +/* Called with BH disabled */ +static inline struct sock *icmp_xmit_lock(struct net *net) +{ + struct sock *sk; + + sk = this_cpu_read(ipv4_icmp_sk); + + if (unlikely(!spin_trylock(&sk->sk_lock.slock))) { + /* This can happen if the output path signals a + * dst_link_failure() for an outgoing ICMP packet. 
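+ * Rather than risk deadlocking on the per-CPU ICMP socket we give up;
+ * callers treat a NULL return as "do not send this ICMP message at all".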
+ */ + return NULL; + } + sock_net_set(sk, net); + return sk; +} + +static inline void icmp_xmit_unlock(struct sock *sk) +{ + sock_net_set(sk, &init_net); + spin_unlock(&sk->sk_lock.slock); +} + +int sysctl_icmp_msgs_per_sec __read_mostly = 1000; +int sysctl_icmp_msgs_burst __read_mostly = 50; + +static struct { + spinlock_t lock; + u32 credit; + u32 stamp; +} icmp_global = { + .lock = __SPIN_LOCK_UNLOCKED(icmp_global.lock), +}; + +/** + * icmp_global_allow - Are we allowed to send one more ICMP message ? + * + * Uses a token bucket to limit our ICMP messages to ~sysctl_icmp_msgs_per_sec. + * Returns false if we reached the limit and can not send another packet. + * Note: called with BH disabled + */ +bool icmp_global_allow(void) +{ + u32 credit, delta, incr = 0, now = (u32)jiffies; + bool rc = false; + + /* Check if token bucket is empty and cannot be refilled + * without taking the spinlock. The READ_ONCE() are paired + * with the following WRITE_ONCE() in this same function. + */ + if (!READ_ONCE(icmp_global.credit)) { + delta = min_t(u32, now - READ_ONCE(icmp_global.stamp), HZ); + if (delta < HZ / 50) + return false; + } + + spin_lock(&icmp_global.lock); + delta = min_t(u32, now - icmp_global.stamp, HZ); + if (delta >= HZ / 50) { + incr = READ_ONCE(sysctl_icmp_msgs_per_sec) * delta / HZ; + if (incr) + WRITE_ONCE(icmp_global.stamp, now); + } + credit = min_t(u32, icmp_global.credit + incr, + READ_ONCE(sysctl_icmp_msgs_burst)); + if (credit) { + /* We want to use a credit of one in average, but need to randomize + * it for security reasons. + */ + credit = max_t(int, credit - prandom_u32_max(3), 0); + rc = true; + } + WRITE_ONCE(icmp_global.credit, credit); + spin_unlock(&icmp_global.lock); + return rc; +} +EXPORT_SYMBOL(icmp_global_allow); + +static bool icmpv4_mask_allow(struct net *net, int type, int code) +{ + if (type > NR_ICMP_TYPES) + return true; + + /* Don't limit PMTU discovery. */ + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) + return true; + + /* Limit if icmp type is enabled in ratemask. */ + if (!((1 << type) & READ_ONCE(net->ipv4.sysctl_icmp_ratemask))) + return true; + + return false; +} + +static bool icmpv4_global_allow(struct net *net, int type, int code) +{ + if (icmpv4_mask_allow(net, type, code)) + return true; + + if (icmp_global_allow()) + return true; + + return false; +} + +/* + * Send an ICMP frame. + */ + +static bool icmpv4_xrlim_allow(struct net *net, struct rtable *rt, + struct flowi4 *fl4, int type, int code) +{ + struct dst_entry *dst = &rt->dst; + struct inet_peer *peer; + bool rc = true; + int vif; + + if (icmpv4_mask_allow(net, type, code)) + goto out; + + /* No rate limit on loopback */ + if (dst->dev && (dst->dev->flags&IFF_LOOPBACK)) + goto out; + + vif = l3mdev_master_ifindex(dst->dev); + peer = inet_getpeer_v4(net->ipv4.peers, fl4->daddr, vif, 1); + rc = inet_peer_xrlim_allow(peer, + READ_ONCE(net->ipv4.sysctl_icmp_ratelimit)); + if (peer) + inet_putpeer(peer); +out: + return rc; +} + +/* + * Maintain the counters used in the SNMP statistics for outgoing ICMP + */ +void icmp_out_count(struct net *net, unsigned char type) +{ + ICMPMSGOUT_INC_STATS(net, type); + ICMP_INC_STATS(net, ICMP_MIB_OUTMSGS); +} + +/* + * Checksum each fragment, and on the first include the headers and final + * checksum. 
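+ * ip_append_data() calls icmp_glue_bits() below as its getfrag callback;
+ * icmp_push_reply() then folds the per-fragment checksums together with
+ * the checksum of the ICMP header itself before the frames are sent.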
+ */ +static int icmp_glue_bits(void *from, char *to, int offset, int len, int odd, + struct sk_buff *skb) +{ + struct icmp_bxm *icmp_param = from; + __wsum csum; + + csum = skb_copy_and_csum_bits(icmp_param->skb, + icmp_param->offset + offset, + to, len); + + skb->csum = csum_block_add(skb->csum, csum, odd); + if (icmp_pointers[icmp_param->data.icmph.type].error) + nf_ct_attach(skb, icmp_param->skb); + return 0; +} + +static void icmp_push_reply(struct sock *sk, + struct icmp_bxm *icmp_param, + struct flowi4 *fl4, + struct ipcm_cookie *ipc, struct rtable **rt) +{ + struct sk_buff *skb; + + if (ip_append_data(sk, fl4, icmp_glue_bits, icmp_param, + icmp_param->data_len+icmp_param->head_len, + icmp_param->head_len, + ipc, rt, MSG_DONTWAIT) < 0) { + __ICMP_INC_STATS(sock_net(sk), ICMP_MIB_OUTERRORS); + ip_flush_pending_frames(sk); + } else if ((skb = skb_peek(&sk->sk_write_queue)) != NULL) { + struct icmphdr *icmph = icmp_hdr(skb); + __wsum csum; + struct sk_buff *skb1; + + csum = csum_partial_copy_nocheck((void *)&icmp_param->data, + (char *)icmph, + icmp_param->head_len); + skb_queue_walk(&sk->sk_write_queue, skb1) { + csum = csum_add(csum, skb1->csum); + } + icmph->checksum = csum_fold(csum); + skb->ip_summed = CHECKSUM_NONE; + ip_push_pending_frames(sk, fl4); + } +} + +/* + * Driving logic for building and sending ICMP messages. + */ + +static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb) +{ + struct ipcm_cookie ipc; + struct rtable *rt = skb_rtable(skb); + struct net *net = dev_net(rt->dst.dev); + struct flowi4 fl4; + struct sock *sk; + struct inet_sock *inet; + __be32 daddr, saddr; + u32 mark = IP4_REPLY_MARK(net, skb->mark); + int type = icmp_param->data.icmph.type; + int code = icmp_param->data.icmph.code; + + if (ip_options_echo(net, &icmp_param->replyopts.opt.opt, skb)) + return; + + /* Needed by both icmp_global_allow and icmp_xmit_lock */ + local_bh_disable(); + + /* global icmp_msgs_per_sec */ + if (!icmpv4_global_allow(net, type, code)) + goto out_bh_enable; + + sk = icmp_xmit_lock(net); + if (!sk) + goto out_bh_enable; + inet = inet_sk(sk); + + icmp_param->data.icmph.checksum = 0; + + ipcm_init(&ipc); + inet->tos = ip_hdr(skb)->tos; + ipc.sockc.mark = mark; + daddr = ipc.addr = ip_hdr(skb)->saddr; + saddr = fib_compute_spec_dst(skb); + + if (icmp_param->replyopts.opt.opt.optlen) { + ipc.opt = &icmp_param->replyopts.opt; + if (ipc.opt->opt.srr) + daddr = icmp_param->replyopts.opt.opt.faddr; + } + memset(&fl4, 0, sizeof(fl4)); + fl4.daddr = daddr; + fl4.saddr = saddr; + fl4.flowi4_mark = mark; + fl4.flowi4_uid = sock_net_uid(net, NULL); + fl4.flowi4_tos = RT_TOS(ip_hdr(skb)->tos); + fl4.flowi4_proto = IPPROTO_ICMP; + fl4.flowi4_oif = l3mdev_master_ifindex(skb->dev); + security_skb_classify_flow(skb, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) + goto out_unlock; + if (icmpv4_xrlim_allow(net, rt, &fl4, type, code)) + icmp_push_reply(sk, icmp_param, &fl4, &ipc, &rt); + ip_rt_put(rt); +out_unlock: + icmp_xmit_unlock(sk); +out_bh_enable: + local_bh_enable(); +} + +/* + * The device used for looking up which routing table to use for sending an ICMP + * error is preferably the source whenever it is set, which should ensure the + * icmp error can be sent to the source host, else lookup using the routing + * table of the destination device, else use the main routing table (index 0). 
+ */ +static struct net_device *icmp_get_route_lookup_dev(struct sk_buff *skb) +{ + struct net_device *route_lookup_dev = NULL; + + if (skb->dev) + route_lookup_dev = skb->dev; + else if (skb_dst(skb)) + route_lookup_dev = skb_dst(skb)->dev; + return route_lookup_dev; +} + +static struct rtable *icmp_route_lookup(struct net *net, + struct flowi4 *fl4, + struct sk_buff *skb_in, + const struct iphdr *iph, + __be32 saddr, u8 tos, u32 mark, + int type, int code, + struct icmp_bxm *param) +{ + struct net_device *route_lookup_dev; + struct rtable *rt, *rt2; + struct flowi4 fl4_dec; + int err; + + memset(fl4, 0, sizeof(*fl4)); + fl4->daddr = (param->replyopts.opt.opt.srr ? + param->replyopts.opt.opt.faddr : iph->saddr); + fl4->saddr = saddr; + fl4->flowi4_mark = mark; + fl4->flowi4_uid = sock_net_uid(net, NULL); + fl4->flowi4_tos = RT_TOS(tos); + fl4->flowi4_proto = IPPROTO_ICMP; + fl4->fl4_icmp_type = type; + fl4->fl4_icmp_code = code; + route_lookup_dev = icmp_get_route_lookup_dev(skb_in); + fl4->flowi4_oif = l3mdev_master_ifindex(route_lookup_dev); + + security_skb_classify_flow(skb_in, flowi4_to_flowi_common(fl4)); + rt = ip_route_output_key_hash(net, fl4, skb_in); + if (IS_ERR(rt)) + return rt; + + /* No need to clone since we're just using its address. */ + rt2 = rt; + + rt = (struct rtable *) xfrm_lookup(net, &rt->dst, + flowi4_to_flowi(fl4), NULL, 0); + if (!IS_ERR(rt)) { + if (rt != rt2) + return rt; + } else if (PTR_ERR(rt) == -EPERM) { + rt = NULL; + } else + return rt; + + err = xfrm_decode_session_reverse(skb_in, flowi4_to_flowi(&fl4_dec), AF_INET); + if (err) + goto relookup_failed; + + if (inet_addr_type_dev_table(net, route_lookup_dev, + fl4_dec.saddr) == RTN_LOCAL) { + rt2 = __ip_route_output_key(net, &fl4_dec); + if (IS_ERR(rt2)) + err = PTR_ERR(rt2); + } else { + struct flowi4 fl4_2 = {}; + unsigned long orefdst; + + fl4_2.daddr = fl4_dec.saddr; + rt2 = ip_route_output_key(net, &fl4_2); + if (IS_ERR(rt2)) { + err = PTR_ERR(rt2); + goto relookup_failed; + } + /* Ugh! */ + orefdst = skb_in->_skb_refdst; /* save old refdst */ + skb_dst_set(skb_in, NULL); + err = ip_route_input(skb_in, fl4_dec.daddr, fl4_dec.saddr, + RT_TOS(tos), rt2->dst.dev); + + dst_release(&rt2->dst); + rt2 = skb_rtable(skb_in); + skb_in->_skb_refdst = orefdst; /* restore old refdst */ + } + + if (err) + goto relookup_failed; + + rt2 = (struct rtable *) xfrm_lookup(net, &rt2->dst, + flowi4_to_flowi(&fl4_dec), NULL, + XFRM_LOOKUP_ICMP); + if (!IS_ERR(rt2)) { + dst_release(&rt->dst); + memcpy(fl4, &fl4_dec, sizeof(*fl4)); + rt = rt2; + } else if (PTR_ERR(rt2) == -EPERM) { + if (rt) + dst_release(&rt->dst); + return rt2; + } else { + err = PTR_ERR(rt2); + goto relookup_failed; + } + return rt; + +relookup_failed: + if (rt) + return rt; + return ERR_PTR(err); +} + +/* + * Send an ICMP message in response to a situation + * + * RFC 1122: 3.2.2 MUST send at least the IP header and 8 bytes of header. + * MAY send more (we do). + * MUST NOT change this header information. + * MUST NOT reply to a multicast/broadcast IP address. + * MUST NOT reply to a multicast/broadcast MAC address. + * MUST reply to only the first fragment. 
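+ * The reply restrictions above are enforced first: non-first fragments,
+ * link-layer broadcast/multicast frames and routed broadcast/multicast
+ * destinations are all rejected before any rate limiting is applied.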
+ */ + +void __icmp_send(struct sk_buff *skb_in, int type, int code, __be32 info, + const struct ip_options *opt) +{ + struct iphdr *iph; + int room; + struct icmp_bxm icmp_param; + struct rtable *rt = skb_rtable(skb_in); + struct ipcm_cookie ipc; + struct flowi4 fl4; + __be32 saddr; + u8 tos; + u32 mark; + struct net *net; + struct sock *sk; + + if (!rt) + goto out; + + if (rt->dst.dev) + net = dev_net(rt->dst.dev); + else if (skb_in->dev) + net = dev_net(skb_in->dev); + else + goto out; + + /* + * Find the original header. It is expected to be valid, of course. + * Check this, icmp_send is called from the most obscure devices + * sometimes. + */ + iph = ip_hdr(skb_in); + + if ((u8 *)iph < skb_in->head || + (skb_network_header(skb_in) + sizeof(*iph)) > + skb_tail_pointer(skb_in)) + goto out; + + /* + * No replies to physical multicast/broadcast + */ + if (skb_in->pkt_type != PACKET_HOST) + goto out; + + /* + * Now check at the protocol level + */ + if (rt->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) + goto out; + + /* + * Only reply to fragment 0. We byte re-order the constant + * mask for efficiency. + */ + if (iph->frag_off & htons(IP_OFFSET)) + goto out; + + /* + * If we send an ICMP error to an ICMP error a mess would result.. + */ + if (icmp_pointers[type].error) { + /* + * We are an error, check if we are replying to an + * ICMP error + */ + if (iph->protocol == IPPROTO_ICMP) { + u8 _inner_type, *itp; + + itp = skb_header_pointer(skb_in, + skb_network_header(skb_in) + + (iph->ihl << 2) + + offsetof(struct icmphdr, + type) - + skb_in->data, + sizeof(_inner_type), + &_inner_type); + if (!itp) + goto out; + + /* + * Assume any unknown ICMP type is an error. This + * isn't specified by the RFC, but think about it.. + */ + if (*itp > NR_ICMP_TYPES || + icmp_pointers[*itp].error) + goto out; + } + } + + /* Needed by both icmp_global_allow and icmp_xmit_lock */ + local_bh_disable(); + + /* Check global sysctl_icmp_msgs_per_sec ratelimit, unless + * incoming dev is loopback. If outgoing dev change to not be + * loopback, then peer ratelimit still work (in icmpv4_xrlim_allow) + */ + if (!(skb_in->dev && (skb_in->dev->flags&IFF_LOOPBACK)) && + !icmpv4_global_allow(net, type, code)) + goto out_bh_enable; + + sk = icmp_xmit_lock(net); + if (!sk) + goto out_bh_enable; + + /* + * Construct source address and options. + */ + + saddr = iph->daddr; + if (!(rt->rt_flags & RTCF_LOCAL)) { + struct net_device *dev = NULL; + + rcu_read_lock(); + if (rt_is_input_route(rt) && + READ_ONCE(net->ipv4.sysctl_icmp_errors_use_inbound_ifaddr)) + dev = dev_get_by_index_rcu(net, inet_iif(skb_in)); + + if (dev) + saddr = inet_select_addr(dev, iph->saddr, + RT_SCOPE_LINK); + else + saddr = 0; + rcu_read_unlock(); + } + + tos = icmp_pointers[type].error ? (RT_TOS(iph->tos) | + IPTOS_PREC_INTERNETCONTROL) : + iph->tos; + mark = IP4_REPLY_MARK(net, skb_in->mark); + + if (__ip_options_echo(net, &icmp_param.replyopts.opt.opt, skb_in, opt)) + goto out_unlock; + + + /* + * Prepare data for ICMP header. 
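+ * The reply quotes back as much of the offending datagram as fits within
+ * the 576-byte limit: room = min(path MTU, 576) minus the outgoing IP
+ * header, any echoed IP options and the 8-byte ICMP header, i.e. at most
+ * 576 - 20 - 8 = 548 bytes of original data when no options are present
+ * (see the 'room' computation below).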
+ */ + + icmp_param.data.icmph.type = type; + icmp_param.data.icmph.code = code; + icmp_param.data.icmph.un.gateway = info; + icmp_param.data.icmph.checksum = 0; + icmp_param.skb = skb_in; + icmp_param.offset = skb_network_offset(skb_in); + inet_sk(sk)->tos = tos; + ipcm_init(&ipc); + ipc.addr = iph->saddr; + ipc.opt = &icmp_param.replyopts.opt; + ipc.sockc.mark = mark; + + rt = icmp_route_lookup(net, &fl4, skb_in, iph, saddr, tos, mark, + type, code, &icmp_param); + if (IS_ERR(rt)) + goto out_unlock; + + /* peer icmp_ratelimit */ + if (!icmpv4_xrlim_allow(net, rt, &fl4, type, code)) + goto ende; + + /* RFC says return as much as we can without exceeding 576 bytes. */ + + room = dst_mtu(&rt->dst); + if (room > 576) + room = 576; + room -= sizeof(struct iphdr) + icmp_param.replyopts.opt.opt.optlen; + room -= sizeof(struct icmphdr); + /* Guard against tiny mtu. We need to include at least one + * IP network header for this message to make any sense. + */ + if (room <= (int)sizeof(struct iphdr)) + goto ende; + + icmp_param.data_len = skb_in->len - icmp_param.offset; + if (icmp_param.data_len > room) + icmp_param.data_len = room; + icmp_param.head_len = sizeof(struct icmphdr); + + /* if we don't have a source address at this point, fall back to the + * dummy address instead of sending out a packet with a source address + * of 0.0.0.0 + */ + if (!fl4.saddr) + fl4.saddr = htonl(INADDR_DUMMY); + + icmp_push_reply(sk, &icmp_param, &fl4, &ipc, &rt); +ende: + ip_rt_put(rt); +out_unlock: + icmp_xmit_unlock(sk); +out_bh_enable: + local_bh_enable(); +out:; +} +EXPORT_SYMBOL(__icmp_send); + +#if IS_ENABLED(CONFIG_NF_NAT) +#include +void icmp_ndo_send(struct sk_buff *skb_in, int type, int code, __be32 info) +{ + struct sk_buff *cloned_skb = NULL; + struct ip_options opts = { 0 }; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + __be32 orig_ip; + + ct = nf_ct_get(skb_in, &ctinfo); + if (!ct || !(ct->status & IPS_SRC_NAT)) { + __icmp_send(skb_in, type, code, info, &opts); + return; + } + + if (skb_shared(skb_in)) + skb_in = cloned_skb = skb_clone(skb_in, GFP_ATOMIC); + + if (unlikely(!skb_in || skb_network_header(skb_in) < skb_in->head || + (skb_network_header(skb_in) + sizeof(struct iphdr)) > + skb_tail_pointer(skb_in) || skb_ensure_writable(skb_in, + skb_network_offset(skb_in) + sizeof(struct iphdr)))) + goto out; + + orig_ip = ip_hdr(skb_in)->saddr; + ip_hdr(skb_in)->saddr = ct->tuplehash[0].tuple.src.u3.ip; + __icmp_send(skb_in, type, code, info, &opts); + ip_hdr(skb_in)->saddr = orig_ip; +out: + consume_skb(cloned_skb); +} +EXPORT_SYMBOL(icmp_ndo_send); +#endif + +static void icmp_socket_deliver(struct sk_buff *skb, u32 info) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + const struct net_protocol *ipprot; + int protocol = iph->protocol; + + /* Checkin full IP header plus 8 bytes of protocol to + * avoid additional coding at protocol handlers. + */ + if (!pskb_may_pull(skb, iph->ihl * 4 + 8)) { + __ICMP_INC_STATS(dev_net(skb->dev), ICMP_MIB_INERRORS); + return; + } + + raw_icmp_error(skb, protocol, info); + + ipprot = rcu_dereference(inet_protos[protocol]); + if (ipprot && ipprot->err_handler) + ipprot->err_handler(skb, info); +} + +static bool icmp_tag_validation(int proto) +{ + bool ok; + + rcu_read_lock(); + ok = rcu_dereference(inet_protos[proto])->icmp_strict_tag_validation; + rcu_read_unlock(); + return ok; +} + +/* + * Handle ICMP_DEST_UNREACH, ICMP_TIME_EXCEEDED, ICMP_QUENCH, and + * ICMP_PARAMETERPROB. 
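+ * For ICMP_FRAG_NEEDED the router's advertised next-hop MTU is read from
+ * icmph->un.frag.mtu (subject to the ip_no_pmtu_disc sysctl) and passed
+ * to the upper protocol's err_handler via icmp_socket_deliver(); this is
+ * how path MTU discovery learns the reduced MTU.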
+ */ + +static enum skb_drop_reason icmp_unreach(struct sk_buff *skb) +{ + enum skb_drop_reason reason = SKB_NOT_DROPPED_YET; + const struct iphdr *iph; + struct icmphdr *icmph; + struct net *net; + u32 info = 0; + + net = dev_net(skb_dst(skb)->dev); + + /* + * Incomplete header ? + * Only checks for the IP header, there should be an + * additional check for longer headers in upper levels. + */ + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto out_err; + + icmph = icmp_hdr(skb); + iph = (const struct iphdr *)skb->data; + + if (iph->ihl < 5) { /* Mangled header, drop. */ + reason = SKB_DROP_REASON_IP_INHDR; + goto out_err; + } + + switch (icmph->type) { + case ICMP_DEST_UNREACH: + switch (icmph->code & 15) { + case ICMP_NET_UNREACH: + case ICMP_HOST_UNREACH: + case ICMP_PROT_UNREACH: + case ICMP_PORT_UNREACH: + break; + case ICMP_FRAG_NEEDED: + /* for documentation of the ip_no_pmtu_disc + * values please see + * Documentation/networking/ip-sysctl.rst + */ + switch (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) { + default: + net_dbg_ratelimited("%pI4: fragmentation needed and DF set\n", + &iph->daddr); + break; + case 2: + goto out; + case 3: + if (!icmp_tag_validation(iph->protocol)) + goto out; + fallthrough; + case 0: + info = ntohs(icmph->un.frag.mtu); + } + break; + case ICMP_SR_FAILED: + net_dbg_ratelimited("%pI4: Source Route Failed\n", + &iph->daddr); + break; + default: + break; + } + if (icmph->code > NR_ICMP_UNREACH) + goto out; + break; + case ICMP_PARAMETERPROB: + info = ntohl(icmph->un.gateway) >> 24; + break; + case ICMP_TIME_EXCEEDED: + __ICMP_INC_STATS(net, ICMP_MIB_INTIMEEXCDS); + if (icmph->code == ICMP_EXC_FRAGTIME) + goto out; + break; + } + + /* + * Throw it at our lower layers + * + * RFC 1122: 3.2.2 MUST extract the protocol ID from the passed + * header. + * RFC 1122: 3.2.2.1 MUST pass ICMP unreach messages to the + * transport layer. + * RFC 1122: 3.2.2.2 MUST pass ICMP time expired messages to + * transport layer. + */ + + /* + * Check the other end isn't violating RFC 1122. Some routers send + * bogus responses to broadcast frames. If you see this message + * first check your netmask matches at both ends, if it does then + * get the other vendor to fix their kit. + */ + + if (!READ_ONCE(net->ipv4.sysctl_icmp_ignore_bogus_error_responses) && + inet_addr_type_dev_table(net, skb->dev, iph->daddr) == RTN_BROADCAST) { + net_warn_ratelimited("%pI4 sent an invalid ICMP type %u, code %u error to a broadcast: %pI4 on %s\n", + &ip_hdr(skb)->saddr, + icmph->type, icmph->code, + &iph->daddr, skb->dev->name); + goto out; + } + + icmp_socket_deliver(skb, info); + +out: + return reason; +out_err: + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return reason ?: SKB_DROP_REASON_NOT_SPECIFIED; +} + + +/* + * Handle ICMP_REDIRECT. + */ + +static enum skb_drop_reason icmp_redirect(struct sk_buff *skb) +{ + if (skb->len < sizeof(struct iphdr)) { + __ICMP_INC_STATS(dev_net(skb->dev), ICMP_MIB_INERRORS); + return SKB_DROP_REASON_PKT_TOO_SMALL; + } + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) { + /* there aught to be a stat */ + return SKB_DROP_REASON_NOMEM; + } + + icmp_socket_deliver(skb, ntohl(icmp_hdr(skb)->un.gateway)); + return SKB_NOT_DROPPED_YET; +} + +/* + * Handle ICMP_ECHO ("ping") and ICMP_EXT_ECHO ("PROBE") requests. + * + * RFC 1122: 3.2.2.6 MUST have an echo server that answers ICMP echo + * requests. + * RFC 1122: 3.2.2.6 Data received in the ICMP_ECHO request MUST be + * included in the reply. 
+ * RFC 1812: 4.3.3.6 SHOULD have a config option for silently ignoring + * echo requests, MUST have default=NOT. + * RFC 8335: 8 MUST have a config option to enable/disable ICMP + * Extended Echo Functionality, MUST be disabled by default + * See also WRT handling of options once they are done and working. + */ + +static enum skb_drop_reason icmp_echo(struct sk_buff *skb) +{ + struct icmp_bxm icmp_param; + struct net *net; + + net = dev_net(skb_dst(skb)->dev); + /* should there be an ICMP stat for ignored echos? */ + if (READ_ONCE(net->ipv4.sysctl_icmp_echo_ignore_all)) + return SKB_NOT_DROPPED_YET; + + icmp_param.data.icmph = *icmp_hdr(skb); + icmp_param.skb = skb; + icmp_param.offset = 0; + icmp_param.data_len = skb->len; + icmp_param.head_len = sizeof(struct icmphdr); + + if (icmp_param.data.icmph.type == ICMP_ECHO) + icmp_param.data.icmph.type = ICMP_ECHOREPLY; + else if (!icmp_build_probe(skb, &icmp_param.data.icmph)) + return SKB_NOT_DROPPED_YET; + + icmp_reply(&icmp_param, skb); + return SKB_NOT_DROPPED_YET; +} + +/* Helper for icmp_echo and icmpv6_echo_reply. + * Searches for net_device that matches PROBE interface identifier + * and builds PROBE reply message in icmphdr. + * + * Returns false if PROBE responses are disabled via sysctl + */ + +bool icmp_build_probe(struct sk_buff *skb, struct icmphdr *icmphdr) +{ + struct icmp_ext_hdr *ext_hdr, _ext_hdr; + struct icmp_ext_echo_iio *iio, _iio; + struct net *net = dev_net(skb->dev); + struct net_device *dev; + char buff[IFNAMSIZ]; + u16 ident_len; + u8 status; + + if (!READ_ONCE(net->ipv4.sysctl_icmp_echo_enable_probe)) + return false; + + /* We currently only support probing interfaces on the proxy node + * Check to ensure L-bit is set + */ + if (!(ntohs(icmphdr->un.echo.sequence) & 1)) + return false; + /* Clear status bits in reply message */ + icmphdr->un.echo.sequence &= htons(0xFF00); + if (icmphdr->type == ICMP_EXT_ECHO) + icmphdr->type = ICMP_EXT_ECHOREPLY; + else + icmphdr->type = ICMPV6_EXT_ECHO_REPLY; + ext_hdr = skb_header_pointer(skb, 0, sizeof(_ext_hdr), &_ext_hdr); + /* Size of iio is class_type dependent. 
+ * Only check header here and assign length based on ctype in the switch statement + */ + iio = skb_header_pointer(skb, sizeof(_ext_hdr), sizeof(iio->extobj_hdr), &_iio); + if (!ext_hdr || !iio) + goto send_mal_query; + if (ntohs(iio->extobj_hdr.length) <= sizeof(iio->extobj_hdr) || + ntohs(iio->extobj_hdr.length) > sizeof(_iio)) + goto send_mal_query; + ident_len = ntohs(iio->extobj_hdr.length) - sizeof(iio->extobj_hdr); + iio = skb_header_pointer(skb, sizeof(_ext_hdr), + sizeof(iio->extobj_hdr) + ident_len, &_iio); + if (!iio) + goto send_mal_query; + + status = 0; + dev = NULL; + switch (iio->extobj_hdr.class_type) { + case ICMP_EXT_ECHO_CTYPE_NAME: + if (ident_len >= IFNAMSIZ) + goto send_mal_query; + memset(buff, 0, sizeof(buff)); + memcpy(buff, &iio->ident.name, ident_len); + dev = dev_get_by_name(net, buff); + break; + case ICMP_EXT_ECHO_CTYPE_INDEX: + if (ident_len != sizeof(iio->ident.ifindex)) + goto send_mal_query; + dev = dev_get_by_index(net, ntohl(iio->ident.ifindex)); + break; + case ICMP_EXT_ECHO_CTYPE_ADDR: + if (ident_len < sizeof(iio->ident.addr.ctype3_hdr) || + ident_len != sizeof(iio->ident.addr.ctype3_hdr) + + iio->ident.addr.ctype3_hdr.addrlen) + goto send_mal_query; + switch (ntohs(iio->ident.addr.ctype3_hdr.afi)) { + case ICMP_AFI_IP: + if (iio->ident.addr.ctype3_hdr.addrlen != sizeof(struct in_addr)) + goto send_mal_query; + dev = ip_dev_find(net, iio->ident.addr.ip_addr.ipv4_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case ICMP_AFI_IP6: + if (iio->ident.addr.ctype3_hdr.addrlen != sizeof(struct in6_addr)) + goto send_mal_query; + dev = ipv6_stub->ipv6_dev_find(net, &iio->ident.addr.ip_addr.ipv6_addr, dev); + dev_hold(dev); + break; +#endif + default: + goto send_mal_query; + } + break; + default: + goto send_mal_query; + } + if (!dev) { + icmphdr->code = ICMP_EXT_CODE_NO_IF; + return true; + } + /* Fill bits in reply message */ + if (dev->flags & IFF_UP) + status |= ICMP_EXT_ECHOREPLY_ACTIVE; + if (__in_dev_get_rcu(dev) && __in_dev_get_rcu(dev)->ifa_list) + status |= ICMP_EXT_ECHOREPLY_IPV4; + if (!list_empty(&rcu_dereference(dev->ip6_ptr)->addr_list)) + status |= ICMP_EXT_ECHOREPLY_IPV6; + dev_put(dev); + icmphdr->un.echo.sequence |= htons(status); + return true; +send_mal_query: + icmphdr->code = ICMP_EXT_CODE_MAL_QUERY; + return true; +} +EXPORT_SYMBOL_GPL(icmp_build_probe); + +/* + * Handle ICMP Timestamp requests. + * RFC 1122: 3.2.2.8 MAY implement ICMP timestamp requests. + * SHOULD be in the kernel for minimum random latency. + * MUST be accurate to a few minutes. + * MUST be updated at least at 15Hz. + */ +static enum skb_drop_reason icmp_timestamp(struct sk_buff *skb) +{ + struct icmp_bxm icmp_param; + /* + * Too short. 
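+ * (the request must carry at least the 4-byte originate timestamp, which
+ * is copied into times[0] below; times[1] and times[2] carry the receive
+ * and transmit stamps, hence head_len = sizeof(struct icmphdr) + 12)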
+ */ + if (skb->len < 4) + goto out_err; + + /* + * Fill in the current time as ms since midnight UT: + */ + icmp_param.data.times[1] = inet_current_timestamp(); + icmp_param.data.times[2] = icmp_param.data.times[1]; + + BUG_ON(skb_copy_bits(skb, 0, &icmp_param.data.times[0], 4)); + + icmp_param.data.icmph = *icmp_hdr(skb); + icmp_param.data.icmph.type = ICMP_TIMESTAMPREPLY; + icmp_param.data.icmph.code = 0; + icmp_param.skb = skb; + icmp_param.offset = 0; + icmp_param.data_len = 0; + icmp_param.head_len = sizeof(struct icmphdr) + 12; + icmp_reply(&icmp_param, skb); + return SKB_NOT_DROPPED_YET; + +out_err: + __ICMP_INC_STATS(dev_net(skb_dst(skb)->dev), ICMP_MIB_INERRORS); + return SKB_DROP_REASON_PKT_TOO_SMALL; +} + +static enum skb_drop_reason icmp_discard(struct sk_buff *skb) +{ + /* pretend it was a success */ + return SKB_NOT_DROPPED_YET; +} + +/* + * Deal with incoming ICMP packets. + */ +int icmp_rcv(struct sk_buff *skb) +{ + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; + struct rtable *rt = skb_rtable(skb); + struct net *net = dev_net(rt->dst.dev); + struct icmphdr *icmph; + + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + struct sec_path *sp = skb_sec_path(skb); + int nh; + + if (!(sp && sp->xvec[sp->len - 1]->props.flags & + XFRM_STATE_ICMP)) { + reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop; + } + + if (!pskb_may_pull(skb, sizeof(*icmph) + sizeof(struct iphdr))) + goto drop; + + nh = skb_network_offset(skb); + skb_set_network_header(skb, sizeof(*icmph)); + + if (!xfrm4_policy_check_reverse(NULL, XFRM_POLICY_IN, + skb)) { + reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop; + } + + skb_set_network_header(skb, nh); + } + + __ICMP_INC_STATS(net, ICMP_MIB_INMSGS); + + if (skb_checksum_simple_validate(skb)) + goto csum_error; + + if (!pskb_pull(skb, sizeof(*icmph))) + goto error; + + icmph = icmp_hdr(skb); + + ICMPMSGIN_INC_STATS(net, icmph->type); + + /* Check for ICMP Extended Echo (PROBE) messages */ + if (icmph->type == ICMP_EXT_ECHO) { + /* We can't use icmp_pointers[].handler() because it is an array of + * size NR_ICMP_TYPES + 1 (19 elements) and PROBE has code 42. + */ + reason = icmp_echo(skb); + goto reason_check; + } + + if (icmph->type == ICMP_EXT_ECHOREPLY) { + reason = ping_rcv(skb); + goto reason_check; + } + + /* + * 18 is the highest 'known' ICMP type. Anything else is a mystery + * + * RFC 1122: 3.2.2 Unknown ICMP messages types MUST be silently + * discarded. + */ + if (icmph->type > NR_ICMP_TYPES) { + reason = SKB_DROP_REASON_UNHANDLED_PROTO; + goto error; + } + + /* + * Parse the ICMP message + */ + + if (rt->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) { + /* + * RFC 1122: 3.2.2.6 An ICMP_ECHO to broadcast MAY be + * silently ignored (we let user decide with a sysctl). + * RFC 1122: 3.2.2.8 An ICMP_TIMESTAMP MAY be silently + * discarded if to broadcast/multicast. 
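+ * In practice: broadcast/multicast ECHO and TIMESTAMP are let through
+ * only while icmp_echo_ignore_broadcasts is 0, and every other type
+ * except the address-mask request/reply is refused outright below.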
+ */ + if ((icmph->type == ICMP_ECHO || + icmph->type == ICMP_TIMESTAMP) && + READ_ONCE(net->ipv4.sysctl_icmp_echo_ignore_broadcasts)) { + reason = SKB_DROP_REASON_INVALID_PROTO; + goto error; + } + if (icmph->type != ICMP_ECHO && + icmph->type != ICMP_TIMESTAMP && + icmph->type != ICMP_ADDRESS && + icmph->type != ICMP_ADDRESSREPLY) { + reason = SKB_DROP_REASON_INVALID_PROTO; + goto error; + } + } + + reason = icmp_pointers[icmph->type].handler(skb); +reason_check: + if (!reason) { + consume_skb(skb); + return NET_RX_SUCCESS; + } + +drop: + kfree_skb_reason(skb, reason); + return NET_RX_DROP; +csum_error: + reason = SKB_DROP_REASON_ICMP_CSUM; + __ICMP_INC_STATS(net, ICMP_MIB_CSUMERRORS); +error: + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + goto drop; +} + +static bool ip_icmp_error_rfc4884_validate(const struct sk_buff *skb, int off) +{ + struct icmp_extobj_hdr *objh, _objh; + struct icmp_ext_hdr *exth, _exth; + u16 olen; + + exth = skb_header_pointer(skb, off, sizeof(_exth), &_exth); + if (!exth) + return false; + if (exth->version != 2) + return true; + + if (exth->checksum && + csum_fold(skb_checksum(skb, off, skb->len - off, 0))) + return false; + + off += sizeof(_exth); + while (off < skb->len) { + objh = skb_header_pointer(skb, off, sizeof(_objh), &_objh); + if (!objh) + return false; + + olen = ntohs(objh->length); + if (olen < sizeof(_objh)) + return false; + + off += olen; + if (off > skb->len) + return false; + } + + return true; +} + +void ip_icmp_error_rfc4884(const struct sk_buff *skb, + struct sock_ee_data_rfc4884 *out, + int thlen, int off) +{ + int hlen; + + /* original datagram headers: end of icmph to payload (skb->data) */ + hlen = -skb_transport_offset(skb) - thlen; + + /* per rfc 4884: minimal datagram length of 128 bytes */ + if (off < 128 || off < hlen) + return; + + /* kernel has stripped headers: return payload offset in bytes */ + off -= hlen; + if (off + sizeof(struct icmp_ext_hdr) > skb->len) + return; + + out->len = off; + + if (!ip_icmp_error_rfc4884_validate(skb, off)) + out->flags |= SO_EE_RFC4884_FLAG_INVALID; +} +EXPORT_SYMBOL_GPL(ip_icmp_error_rfc4884); + +int icmp_err(struct sk_buff *skb, u32 info) +{ + struct iphdr *iph = (struct iphdr *)skb->data; + int offset = iph->ihl<<2; + struct icmphdr *icmph = (struct icmphdr *)(skb->data + offset); + int type = icmp_hdr(skb)->type; + int code = icmp_hdr(skb)->code; + struct net *net = dev_net(skb->dev); + + /* + * Use ping_err to handle all icmp errors except those + * triggered by ICMP_ECHOREPLY which sent from kernel. + */ + if (icmph->type != ICMP_ECHOREPLY) { + ping_err(skb, offset, info); + return 0; + } + + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) + ipv4_update_pmtu(skb, net, info, 0, IPPROTO_ICMP); + else if (type == ICMP_REDIRECT) + ipv4_redirect(skb, net, 0, IPPROTO_ICMP); + + return 0; +} + +/* + * This table is the definition of how we handle ICMP. 
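+ * Entries marked .error are ICMP error messages; __icmp_send() consults
+ * this flag so that an error is never generated in reply to another
+ * error. Slots addressed only by number (1, 2, 6, 7, 9, 10) map to
+ * icmp_discard().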
+ */ +static const struct icmp_control icmp_pointers[NR_ICMP_TYPES + 1] = { + [ICMP_ECHOREPLY] = { + .handler = ping_rcv, + }, + [1] = { + .handler = icmp_discard, + .error = 1, + }, + [2] = { + .handler = icmp_discard, + .error = 1, + }, + [ICMP_DEST_UNREACH] = { + .handler = icmp_unreach, + .error = 1, + }, + [ICMP_SOURCE_QUENCH] = { + .handler = icmp_unreach, + .error = 1, + }, + [ICMP_REDIRECT] = { + .handler = icmp_redirect, + .error = 1, + }, + [6] = { + .handler = icmp_discard, + .error = 1, + }, + [7] = { + .handler = icmp_discard, + .error = 1, + }, + [ICMP_ECHO] = { + .handler = icmp_echo, + }, + [9] = { + .handler = icmp_discard, + .error = 1, + }, + [10] = { + .handler = icmp_discard, + .error = 1, + }, + [ICMP_TIME_EXCEEDED] = { + .handler = icmp_unreach, + .error = 1, + }, + [ICMP_PARAMETERPROB] = { + .handler = icmp_unreach, + .error = 1, + }, + [ICMP_TIMESTAMP] = { + .handler = icmp_timestamp, + }, + [ICMP_TIMESTAMPREPLY] = { + .handler = icmp_discard, + }, + [ICMP_INFO_REQUEST] = { + .handler = icmp_discard, + }, + [ICMP_INFO_REPLY] = { + .handler = icmp_discard, + }, + [ICMP_ADDRESS] = { + .handler = icmp_discard, + }, + [ICMP_ADDRESSREPLY] = { + .handler = icmp_discard, + }, +}; + +static int __net_init icmp_sk_init(struct net *net) +{ + /* Control parameters for ECHO replies. */ + net->ipv4.sysctl_icmp_echo_ignore_all = 0; + net->ipv4.sysctl_icmp_echo_enable_probe = 0; + net->ipv4.sysctl_icmp_echo_ignore_broadcasts = 1; + + /* Control parameter - ignore bogus broadcast responses? */ + net->ipv4.sysctl_icmp_ignore_bogus_error_responses = 1; + + /* + * Configurable global rate limit. + * + * ratelimit defines tokens/packet consumed for dst->rate_token + * bucket ratemask defines which icmp types are ratelimited by + * setting it's bit position. + * + * default: + * dest unreachable (3), source quench (4), + * time exceeded (11), parameter problem (12) + */ + + net->ipv4.sysctl_icmp_ratelimit = 1 * HZ; + net->ipv4.sysctl_icmp_ratemask = 0x1818; + net->ipv4.sysctl_icmp_errors_use_inbound_ifaddr = 0; + + return 0; +} + +static struct pernet_operations __net_initdata icmp_sk_ops = { + .init = icmp_sk_init, +}; + +int __init icmp_init(void) +{ + int err, i; + + for_each_possible_cpu(i) { + struct sock *sk; + + err = inet_ctl_sock_create(&sk, PF_INET, + SOCK_RAW, IPPROTO_ICMP, &init_net); + if (err < 0) + return err; + + per_cpu(ipv4_icmp_sk, i) = sk; + + /* Enough space for 2 64K ICMP packets, including + * sk_buff/skb_shared_info struct overhead. + */ + sk->sk_sndbuf = 2 * SKB_TRUESIZE(64 * 1024); + + /* + * Speedup sock_wfree() + */ + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + inet_sk(sk)->pmtudisc = IP_PMTUDISC_DONT; + } + return register_pernet_subsys(&icmp_sk_ops); +} diff --git a/net/ipv4/igmp.c b/net/ipv4/igmp.c new file mode 100644 index 000000000..ac53ef7ee --- /dev/null +++ b/net/ipv4/igmp.c @@ -0,0 +1,3110 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux NET3: Internet Group Management Protocol [IGMP] + * + * This code implements the IGMP protocol as defined in RFC1112. There has + * been a further revision of this protocol since which is now supported. + * + * If you have trouble with this module be careful what gcc you have used, + * the older version didn't come out right using gcc 2.5.8, the newer one + * seems to fall out with gcc 2.6.2. + * + * Authors: + * Alan Cox + * + * Fixes: + * + * Alan Cox : Added lots of __inline__ to optimise + * the memory usage of all the tiny little + * functions. + * Alan Cox : Dumped the header building experiment. 
+ * Alan Cox : Minor tweaks ready for multicast routing + * and extended IGMP protocol. + * Alan Cox : Removed a load of inline directives. Gcc 2.5.8 + * writes utterly bogus code otherwise (sigh) + * fixed IGMP loopback to behave in the manner + * desired by mrouted, fixed the fact it has been + * broken since 1.3.6 and cleaned up a few minor + * points. + * + * Chih-Jen Chang : Tried to revise IGMP to Version 2 + * Tsu-Sheng Tsao E-mail: chihjenc@scf.usc.edu and tsusheng@scf.usc.edu + * The enhancements are mainly based on Steve Deering's + * ipmulti-3.5 source code. + * Chih-Jen Chang : Added the igmp_get_mrouter_info and + * Tsu-Sheng Tsao igmp_set_mrouter_info to keep track of + * the mrouted version on that device. + * Chih-Jen Chang : Added the max_resp_time parameter to + * Tsu-Sheng Tsao igmp_heard_query(). Using this parameter + * to identify the multicast router version + * and do what the IGMP version 2 specified. + * Chih-Jen Chang : Added a timer to revert to IGMP V2 router + * Tsu-Sheng Tsao if the specified time expired. + * Alan Cox : Stop IGMP from 0.0.0.0 being accepted. + * Alan Cox : Use GFP_ATOMIC in the right places. + * Christian Daudt : igmp timer wasn't set for local group + * memberships but was being deleted, + * which caused a "del_timer() called + * from %p with timer not initialized\n" + * message (960131). + * Christian Daudt : removed del_timer from + * igmp_timer_expire function (960205). + * Christian Daudt : igmp_heard_report now only calls + * igmp_timer_expire if tm->running is + * true (960216). + * Malcolm Beattie : ttl comparison wrong in igmp_rcv made + * igmp_heard_query never trigger. Expiry + * miscalculation fixed in igmp_heard_query + * and random() made to return unsigned to + * prevent negative expiry times. + * Alexey Kuznetsov: Wrong group leaving behaviour, backport + * fix from pending 2.1.x patches. + * Alan Cox: Forget to enable FDDI support earlier. + * Alexey Kuznetsov: Fixed leaving groups on device down. + * Alexey Kuznetsov: Accordance to igmp-v2-06 draft. + * David L Stevens: IGMPv3 support, with help from + * Vinay Kulkarni + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_IP_MROUTE +#include +#endif +#ifdef CONFIG_PROC_FS +#include +#include +#endif + +#ifdef CONFIG_IP_MULTICAST +/* Parameter names and values are taken from igmp-v2-06 draft */ + +#define IGMP_QUERY_INTERVAL (125*HZ) +#define IGMP_QUERY_RESPONSE_INTERVAL (10*HZ) + +#define IGMP_INITIAL_REPORT_DELAY (1) + +/* IGMP_INITIAL_REPORT_DELAY is not from IGMP specs! + * IGMP specs require to report membership immediately after + * joining a group, but we delay the first report by a + * small interval. It seems more natural and still does not + * contradict to specs provided this delay is small enough. 
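The IGMP_V1_SEEN()/IGMP_V2_SEEN() macros defined just below decide whether this host must behave as an IGMPv1 or IGMPv2 member: either the force_igmp_version setting pins the version, or a lower-version querier was heard recently and mr_v1_seen/mr_v2_seen still hold a deadline in jiffies. A rough function form of the same test, under simplified assumptions (plain struct fields stand in for the per-netns and per-device conf lookups, and a signed difference stands in for time_before(), which in the kernel also copes with jiffies wrap-around):

#include <stdbool.h>
#include <stdint.h>

struct dev_state {
        int      force_version;   /* 0 = auto, 1 or 2 = pinned by sysctl   */
        uint64_t v1_deadline;     /* "act as v1 until this time" (0 = unset) */
        uint64_t v2_deadline;     /* "act as v2 until this time" (0 = unset) */
};

static bool v1_seen(const struct dev_state *s, uint64_t now)
{
        return s->force_version == 1 ||
               (s->v1_deadline && (int64_t)(now - s->v1_deadline) < 0);
}

static bool v2_seen(const struct dev_state *s, uint64_t now)
{
        return s->force_version == 2 ||
               (s->v2_deadline && (int64_t)(now - s->v2_deadline) < 0);
}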
+ */ + +#define IGMP_V1_SEEN(in_dev) \ + (IPV4_DEVCONF_ALL(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 1 || \ + IN_DEV_CONF_GET((in_dev), FORCE_IGMP_VERSION) == 1 || \ + ((in_dev)->mr_v1_seen && \ + time_before(jiffies, (in_dev)->mr_v1_seen))) +#define IGMP_V2_SEEN(in_dev) \ + (IPV4_DEVCONF_ALL(dev_net(in_dev->dev), FORCE_IGMP_VERSION) == 2 || \ + IN_DEV_CONF_GET((in_dev), FORCE_IGMP_VERSION) == 2 || \ + ((in_dev)->mr_v2_seen && \ + time_before(jiffies, (in_dev)->mr_v2_seen))) + +static int unsolicited_report_interval(struct in_device *in_dev) +{ + int interval_ms, interval_jiffies; + + if (IGMP_V1_SEEN(in_dev) || IGMP_V2_SEEN(in_dev)) + interval_ms = IN_DEV_CONF_GET( + in_dev, + IGMPV2_UNSOLICITED_REPORT_INTERVAL); + else /* v3 */ + interval_ms = IN_DEV_CONF_GET( + in_dev, + IGMPV3_UNSOLICITED_REPORT_INTERVAL); + + interval_jiffies = msecs_to_jiffies(interval_ms); + + /* _timer functions can't handle a delay of 0 jiffies so ensure + * we always return a positive value. + */ + if (interval_jiffies <= 0) + interval_jiffies = 1; + return interval_jiffies; +} + +static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im, + gfp_t gfp); +static void igmpv3_del_delrec(struct in_device *in_dev, struct ip_mc_list *im); +static void igmpv3_clear_delrec(struct in_device *in_dev); +static int sf_setstate(struct ip_mc_list *pmc); +static void sf_markstate(struct ip_mc_list *pmc); +#endif +static void ip_mc_clear_src(struct ip_mc_list *pmc); +static int ip_mc_add_src(struct in_device *in_dev, __be32 *pmca, int sfmode, + int sfcount, __be32 *psfsrc, int delta); + +static void ip_ma_put(struct ip_mc_list *im) +{ + if (refcount_dec_and_test(&im->refcnt)) { + in_dev_put(im->interface); + kfree_rcu(im, rcu); + } +} + +#define for_each_pmc_rcu(in_dev, pmc) \ + for (pmc = rcu_dereference(in_dev->mc_list); \ + pmc != NULL; \ + pmc = rcu_dereference(pmc->next_rcu)) + +#define for_each_pmc_rtnl(in_dev, pmc) \ + for (pmc = rtnl_dereference(in_dev->mc_list); \ + pmc != NULL; \ + pmc = rtnl_dereference(pmc->next_rcu)) + +static void ip_sf_list_clear_all(struct ip_sf_list *psf) +{ + struct ip_sf_list *next; + + while (psf) { + next = psf->sf_next; + kfree(psf); + psf = next; + } +} + +#ifdef CONFIG_IP_MULTICAST + +/* + * Timer management + */ + +static void igmp_stop_timer(struct ip_mc_list *im) +{ + spin_lock_bh(&im->lock); + if (del_timer(&im->timer)) + refcount_dec(&im->refcnt); + im->tm_running = 0; + im->reporter = 0; + im->unsolicit_count = 0; + spin_unlock_bh(&im->lock); +} + +/* It must be called with locked im->lock */ +static void igmp_start_timer(struct ip_mc_list *im, int max_delay) +{ + int tv = prandom_u32_max(max_delay); + + im->tm_running = 1; + if (refcount_inc_not_zero(&im->refcnt)) { + if (mod_timer(&im->timer, jiffies + tv + 2)) + ip_ma_put(im); + } +} + +static void igmp_gq_start_timer(struct in_device *in_dev) +{ + int tv = prandom_u32_max(in_dev->mr_maxdelay); + unsigned long exp = jiffies + tv + 2; + + if (in_dev->mr_gq_running && + time_after_eq(exp, (in_dev->mr_gq_timer).expires)) + return; + + in_dev->mr_gq_running = 1; + if (!mod_timer(&in_dev->mr_gq_timer, exp)) + in_dev_hold(in_dev); +} + +static void igmp_ifc_start_timer(struct in_device *in_dev, int delay) +{ + int tv = prandom_u32_max(delay); + + if (!mod_timer(&in_dev->mr_ifc_timer, jiffies+tv+2)) + in_dev_hold(in_dev); +} + +static void igmp_mod_timer(struct ip_mc_list *im, int max_delay) +{ + spin_lock_bh(&im->lock); + im->unsolicit_count = 0; + if (del_timer(&im->timer)) { + if 
((long)(im->timer.expires-jiffies) < max_delay) { + add_timer(&im->timer); + im->tm_running = 1; + spin_unlock_bh(&im->lock); + return; + } + refcount_dec(&im->refcnt); + } + igmp_start_timer(im, max_delay); + spin_unlock_bh(&im->lock); +} + + +/* + * Send an IGMP report. + */ + +#define IGMP_SIZE (sizeof(struct igmphdr)+sizeof(struct iphdr)+4) + + +static int is_in(struct ip_mc_list *pmc, struct ip_sf_list *psf, int type, + int gdeleted, int sdeleted) +{ + switch (type) { + case IGMPV3_MODE_IS_INCLUDE: + case IGMPV3_MODE_IS_EXCLUDE: + if (gdeleted || sdeleted) + return 0; + if (!(pmc->gsquery && !psf->sf_gsresp)) { + if (pmc->sfmode == MCAST_INCLUDE) + return 1; + /* don't include if this source is excluded + * in all filters + */ + if (psf->sf_count[MCAST_INCLUDE]) + return type == IGMPV3_MODE_IS_INCLUDE; + return pmc->sfcount[MCAST_EXCLUDE] == + psf->sf_count[MCAST_EXCLUDE]; + } + return 0; + case IGMPV3_CHANGE_TO_INCLUDE: + if (gdeleted || sdeleted) + return 0; + return psf->sf_count[MCAST_INCLUDE] != 0; + case IGMPV3_CHANGE_TO_EXCLUDE: + if (gdeleted || sdeleted) + return 0; + if (pmc->sfcount[MCAST_EXCLUDE] == 0 || + psf->sf_count[MCAST_INCLUDE]) + return 0; + return pmc->sfcount[MCAST_EXCLUDE] == + psf->sf_count[MCAST_EXCLUDE]; + case IGMPV3_ALLOW_NEW_SOURCES: + if (gdeleted || !psf->sf_crcount) + return 0; + return (pmc->sfmode == MCAST_INCLUDE) ^ sdeleted; + case IGMPV3_BLOCK_OLD_SOURCES: + if (pmc->sfmode == MCAST_INCLUDE) + return gdeleted || (psf->sf_crcount && sdeleted); + return psf->sf_crcount && !gdeleted && !sdeleted; + } + return 0; +} + +static int +igmp_scount(struct ip_mc_list *pmc, int type, int gdeleted, int sdeleted) +{ + struct ip_sf_list *psf; + int scount = 0; + + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (!is_in(pmc, psf, type, gdeleted, sdeleted)) + continue; + scount++; + } + return scount; +} + +/* source address selection per RFC 3376 section 4.2.13 */ +static __be32 igmpv3_get_srcaddr(struct net_device *dev, + const struct flowi4 *fl4) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + const struct in_ifaddr *ifa; + + if (!in_dev) + return htonl(INADDR_ANY); + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (fl4->saddr == ifa->ifa_local) + return fl4->saddr; + } + + return htonl(INADDR_ANY); +} + +static struct sk_buff *igmpv3_newpack(struct net_device *dev, unsigned int mtu) +{ + struct sk_buff *skb; + struct rtable *rt; + struct iphdr *pip; + struct igmpv3_report *pig; + struct net *net = dev_net(dev); + struct flowi4 fl4; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + unsigned int size; + + size = min(mtu, IP_MAX_MTU); + while (1) { + skb = alloc_skb(size + hlen + tlen, + GFP_ATOMIC | __GFP_NOWARN); + if (skb) + break; + size >>= 1; + if (size < 256) + return NULL; + } + skb->priority = TC_PRIO_CONTROL; + + rt = ip_route_output_ports(net, &fl4, NULL, IGMPV3_ALL_MCR, 0, + 0, 0, + IPPROTO_IGMP, 0, dev->ifindex); + if (IS_ERR(rt)) { + kfree_skb(skb); + return NULL; + } + + skb_dst_set(skb, &rt->dst); + skb->dev = dev; + + skb_reserve(skb, hlen); + skb_tailroom_reserve(skb, mtu, tlen); + + skb_reset_network_header(skb); + pip = ip_hdr(skb); + skb_put(skb, sizeof(struct iphdr) + 4); + + pip->version = 4; + pip->ihl = (sizeof(struct iphdr)+4)>>2; + pip->tos = 0xc0; + pip->frag_off = htons(IP_DF); + pip->ttl = 1; + pip->daddr = fl4.daddr; + + rcu_read_lock(); + pip->saddr = igmpv3_get_srcaddr(dev, &fl4); + rcu_read_unlock(); + + pip->protocol = IPPROTO_IGMP; + pip->tot_len = 0; /* filled in later */ + 
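The four bytes written right after ip_select_ident() below are the IPv4 Router Alert option (RFC 2113); that is why the header was built with ihl = (sizeof(struct iphdr) + 4) >> 2 and four extra bytes were reserved. IGMP packets carry this option so routers examine them even though the packets are not addressed to the router itself. A standalone sketch of those option bytes (148, i.e. 0x94, is IPOPT_RA: copy bit set, option number 20):

#include <stdint.h>

/* Router Alert: type=148, length=4, 16-bit value 0 ("every router shall examine") */
static void fill_router_alert(uint8_t opt[4])
{
        opt[0] = 148;   /* IPOPT_RA */
        opt[1] = 4;     /* option length, including the type and length bytes */
        opt[2] = 0;
        opt[3] = 0;
}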
ip_select_ident(net, skb, NULL); + ((u8 *)&pip[1])[0] = IPOPT_RA; + ((u8 *)&pip[1])[1] = 4; + ((u8 *)&pip[1])[2] = 0; + ((u8 *)&pip[1])[3] = 0; + + skb->transport_header = skb->network_header + sizeof(struct iphdr) + 4; + skb_put(skb, sizeof(*pig)); + pig = igmpv3_report_hdr(skb); + pig->type = IGMPV3_HOST_MEMBERSHIP_REPORT; + pig->resv1 = 0; + pig->csum = 0; + pig->resv2 = 0; + pig->ngrec = 0; + return skb; +} + +static int igmpv3_sendpack(struct sk_buff *skb) +{ + struct igmphdr *pig = igmp_hdr(skb); + const int igmplen = skb_tail_pointer(skb) - skb_transport_header(skb); + + pig->csum = ip_compute_csum(igmp_hdr(skb), igmplen); + + return ip_local_out(dev_net(skb_dst(skb)->dev), skb->sk, skb); +} + +static int grec_size(struct ip_mc_list *pmc, int type, int gdel, int sdel) +{ + return sizeof(struct igmpv3_grec) + 4*igmp_scount(pmc, type, gdel, sdel); +} + +static struct sk_buff *add_grhead(struct sk_buff *skb, struct ip_mc_list *pmc, + int type, struct igmpv3_grec **ppgr, unsigned int mtu) +{ + struct net_device *dev = pmc->interface->dev; + struct igmpv3_report *pih; + struct igmpv3_grec *pgr; + + if (!skb) { + skb = igmpv3_newpack(dev, mtu); + if (!skb) + return NULL; + } + pgr = skb_put(skb, sizeof(struct igmpv3_grec)); + pgr->grec_type = type; + pgr->grec_auxwords = 0; + pgr->grec_nsrcs = 0; + pgr->grec_mca = pmc->multiaddr; + pih = igmpv3_report_hdr(skb); + pih->ngrec = htons(ntohs(pih->ngrec)+1); + *ppgr = pgr; + return skb; +} + +#define AVAILABLE(skb) ((skb) ? skb_availroom(skb) : 0) + +static struct sk_buff *add_grec(struct sk_buff *skb, struct ip_mc_list *pmc, + int type, int gdeleted, int sdeleted) +{ + struct net_device *dev = pmc->interface->dev; + struct net *net = dev_net(dev); + struct igmpv3_report *pih; + struct igmpv3_grec *pgr = NULL; + struct ip_sf_list *psf, *psf_next, *psf_prev, **psf_list; + int scount, stotal, first, isquery, truncate; + unsigned int mtu; + + if (pmc->multiaddr == IGMP_ALL_HOSTS) + return skb; + if (ipv4_is_local_multicast(pmc->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + return skb; + + mtu = READ_ONCE(dev->mtu); + if (mtu < IPV4_MIN_MTU) + return skb; + + isquery = type == IGMPV3_MODE_IS_INCLUDE || + type == IGMPV3_MODE_IS_EXCLUDE; + truncate = type == IGMPV3_MODE_IS_EXCLUDE || + type == IGMPV3_CHANGE_TO_EXCLUDE; + + stotal = scount = 0; + + psf_list = sdeleted ? &pmc->tomb : &pmc->sources; + + if (!*psf_list) + goto empty_source; + + pih = skb ? igmpv3_report_hdr(skb) : NULL; + + /* EX and TO_EX get a fresh packet, if needed */ + if (truncate) { + if (pih && pih->ngrec && + AVAILABLE(skb) < grec_size(pmc, type, gdeleted, sdeleted)) { + if (skb) + igmpv3_sendpack(skb); + skb = igmpv3_newpack(dev, mtu); + } + } + first = 1; + psf_prev = NULL; + for (psf = *psf_list; psf; psf = psf_next) { + __be32 *psrc; + + psf_next = psf->sf_next; + + if (!is_in(pmc, psf, type, gdeleted, sdeleted)) { + psf_prev = psf; + continue; + } + + /* Based on RFC3376 5.1. Should not send source-list change + * records when there is a filter mode change. 
+ */ + if (((gdeleted && pmc->sfmode == MCAST_EXCLUDE) || + (!gdeleted && pmc->crcount)) && + (type == IGMPV3_ALLOW_NEW_SOURCES || + type == IGMPV3_BLOCK_OLD_SOURCES) && psf->sf_crcount) + goto decrease_sf_crcount; + + /* clear marks on query responses */ + if (isquery) + psf->sf_gsresp = 0; + + if (AVAILABLE(skb) < sizeof(__be32) + + first*sizeof(struct igmpv3_grec)) { + if (truncate && !first) + break; /* truncate these */ + if (pgr) + pgr->grec_nsrcs = htons(scount); + if (skb) + igmpv3_sendpack(skb); + skb = igmpv3_newpack(dev, mtu); + first = 1; + scount = 0; + } + if (first) { + skb = add_grhead(skb, pmc, type, &pgr, mtu); + first = 0; + } + if (!skb) + return NULL; + psrc = skb_put(skb, sizeof(__be32)); + *psrc = psf->sf_inaddr; + scount++; stotal++; + if ((type == IGMPV3_ALLOW_NEW_SOURCES || + type == IGMPV3_BLOCK_OLD_SOURCES) && psf->sf_crcount) { +decrease_sf_crcount: + psf->sf_crcount--; + if ((sdeleted || gdeleted) && psf->sf_crcount == 0) { + if (psf_prev) + psf_prev->sf_next = psf->sf_next; + else + *psf_list = psf->sf_next; + kfree(psf); + continue; + } + } + psf_prev = psf; + } + +empty_source: + if (!stotal) { + if (type == IGMPV3_ALLOW_NEW_SOURCES || + type == IGMPV3_BLOCK_OLD_SOURCES) + return skb; + if (pmc->crcount || isquery) { + /* make sure we have room for group header */ + if (skb && AVAILABLE(skb) < sizeof(struct igmpv3_grec)) { + igmpv3_sendpack(skb); + skb = NULL; /* add_grhead will get a new one */ + } + skb = add_grhead(skb, pmc, type, &pgr, mtu); + } + } + if (pgr) + pgr->grec_nsrcs = htons(scount); + + if (isquery) + pmc->gsquery = 0; /* clear query state on report */ + return skb; +} + +static int igmpv3_send_report(struct in_device *in_dev, struct ip_mc_list *pmc) +{ + struct sk_buff *skb = NULL; + struct net *net = dev_net(in_dev->dev); + int type; + + if (!pmc) { + rcu_read_lock(); + for_each_pmc_rcu(in_dev, pmc) { + if (pmc->multiaddr == IGMP_ALL_HOSTS) + continue; + if (ipv4_is_local_multicast(pmc->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + continue; + spin_lock_bh(&pmc->lock); + if (pmc->sfcount[MCAST_EXCLUDE]) + type = IGMPV3_MODE_IS_EXCLUDE; + else + type = IGMPV3_MODE_IS_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0); + spin_unlock_bh(&pmc->lock); + } + rcu_read_unlock(); + } else { + spin_lock_bh(&pmc->lock); + if (pmc->sfcount[MCAST_EXCLUDE]) + type = IGMPV3_MODE_IS_EXCLUDE; + else + type = IGMPV3_MODE_IS_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0); + spin_unlock_bh(&pmc->lock); + } + if (!skb) + return 0; + return igmpv3_sendpack(skb); +} + +/* + * remove zero-count source records from a source filter list + */ +static void igmpv3_clear_zeros(struct ip_sf_list **ppsf) +{ + struct ip_sf_list *psf_prev, *psf_next, *psf; + + psf_prev = NULL; + for (psf = *ppsf; psf; psf = psf_next) { + psf_next = psf->sf_next; + if (psf->sf_crcount == 0) { + if (psf_prev) + psf_prev->sf_next = psf->sf_next; + else + *ppsf = psf->sf_next; + kfree(psf); + } else + psf_prev = psf; + } +} + +static void kfree_pmc(struct ip_mc_list *pmc) +{ + ip_sf_list_clear_all(pmc->sources); + ip_sf_list_clear_all(pmc->tomb); + kfree(pmc); +} + +static void igmpv3_send_cr(struct in_device *in_dev) +{ + struct ip_mc_list *pmc, *pmc_prev, *pmc_next; + struct sk_buff *skb = NULL; + int type, dtype; + + rcu_read_lock(); + spin_lock_bh(&in_dev->mc_tomb_lock); + + /* deleted MCA's */ + pmc_prev = NULL; + for (pmc = in_dev->mc_tomb; pmc; pmc = pmc_next) { + pmc_next = pmc->next; + if (pmc->sfmode == MCAST_INCLUDE) { + type = IGMPV3_BLOCK_OLD_SOURCES; + dtype = 
IGMPV3_BLOCK_OLD_SOURCES; + skb = add_grec(skb, pmc, type, 1, 0); + skb = add_grec(skb, pmc, dtype, 1, 1); + } + if (pmc->crcount) { + if (pmc->sfmode == MCAST_EXCLUDE) { + type = IGMPV3_CHANGE_TO_INCLUDE; + skb = add_grec(skb, pmc, type, 1, 0); + } + pmc->crcount--; + if (pmc->crcount == 0) { + igmpv3_clear_zeros(&pmc->tomb); + igmpv3_clear_zeros(&pmc->sources); + } + } + if (pmc->crcount == 0 && !pmc->tomb && !pmc->sources) { + if (pmc_prev) + pmc_prev->next = pmc_next; + else + in_dev->mc_tomb = pmc_next; + in_dev_put(pmc->interface); + kfree_pmc(pmc); + } else + pmc_prev = pmc; + } + spin_unlock_bh(&in_dev->mc_tomb_lock); + + /* change recs */ + for_each_pmc_rcu(in_dev, pmc) { + spin_lock_bh(&pmc->lock); + if (pmc->sfcount[MCAST_EXCLUDE]) { + type = IGMPV3_BLOCK_OLD_SOURCES; + dtype = IGMPV3_ALLOW_NEW_SOURCES; + } else { + type = IGMPV3_ALLOW_NEW_SOURCES; + dtype = IGMPV3_BLOCK_OLD_SOURCES; + } + skb = add_grec(skb, pmc, type, 0, 0); + skb = add_grec(skb, pmc, dtype, 0, 1); /* deleted sources */ + + /* filter mode changes */ + if (pmc->crcount) { + if (pmc->sfmode == MCAST_EXCLUDE) + type = IGMPV3_CHANGE_TO_EXCLUDE; + else + type = IGMPV3_CHANGE_TO_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0); + pmc->crcount--; + } + spin_unlock_bh(&pmc->lock); + } + rcu_read_unlock(); + + if (!skb) + return; + (void) igmpv3_sendpack(skb); +} + +static int igmp_send_report(struct in_device *in_dev, struct ip_mc_list *pmc, + int type) +{ + struct sk_buff *skb; + struct iphdr *iph; + struct igmphdr *ih; + struct rtable *rt; + struct net_device *dev = in_dev->dev; + struct net *net = dev_net(dev); + __be32 group = pmc ? pmc->multiaddr : 0; + struct flowi4 fl4; + __be32 dst; + int hlen, tlen; + + if (type == IGMPV3_HOST_MEMBERSHIP_REPORT) + return igmpv3_send_report(in_dev, pmc); + + if (ipv4_is_local_multicast(group) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + return 0; + + if (type == IGMP_HOST_LEAVE_MESSAGE) + dst = IGMP_ALL_ROUTER; + else + dst = group; + + rt = ip_route_output_ports(net, &fl4, NULL, dst, 0, + 0, 0, + IPPROTO_IGMP, 0, dev->ifindex); + if (IS_ERR(rt)) + return -1; + + hlen = LL_RESERVED_SPACE(dev); + tlen = dev->needed_tailroom; + skb = alloc_skb(IGMP_SIZE + hlen + tlen, GFP_ATOMIC); + if (!skb) { + ip_rt_put(rt); + return -1; + } + skb->priority = TC_PRIO_CONTROL; + + skb_dst_set(skb, &rt->dst); + + skb_reserve(skb, hlen); + + skb_reset_network_header(skb); + iph = ip_hdr(skb); + skb_put(skb, sizeof(struct iphdr) + 4); + + iph->version = 4; + iph->ihl = (sizeof(struct iphdr)+4)>>2; + iph->tos = 0xc0; + iph->frag_off = htons(IP_DF); + iph->ttl = 1; + iph->daddr = dst; + iph->saddr = fl4.saddr; + iph->protocol = IPPROTO_IGMP; + ip_select_ident(net, skb, NULL); + ((u8 *)&iph[1])[0] = IPOPT_RA; + ((u8 *)&iph[1])[1] = 4; + ((u8 *)&iph[1])[2] = 0; + ((u8 *)&iph[1])[3] = 0; + + ih = skb_put(skb, sizeof(struct igmphdr)); + ih->type = type; + ih->code = 0; + ih->csum = 0; + ih->group = group; + ih->csum = ip_compute_csum((void *)ih, sizeof(struct igmphdr)); + + return ip_local_out(net, skb->sk, skb); +} + +static void igmp_gq_timer_expire(struct timer_list *t) +{ + struct in_device *in_dev = from_timer(in_dev, t, mr_gq_timer); + + in_dev->mr_gq_running = 0; + igmpv3_send_report(in_dev, NULL); + in_dev_put(in_dev); +} + +static void igmp_ifc_timer_expire(struct timer_list *t) +{ + struct in_device *in_dev = from_timer(in_dev, t, mr_ifc_timer); + u32 mr_ifc_count; + + igmpv3_send_cr(in_dev); +restart: + mr_ifc_count = READ_ONCE(in_dev->mr_ifc_count); + + if (mr_ifc_count) { + if 
(cmpxchg(&in_dev->mr_ifc_count, + mr_ifc_count, + mr_ifc_count - 1) != mr_ifc_count) + goto restart; + igmp_ifc_start_timer(in_dev, + unsolicited_report_interval(in_dev)); + } + in_dev_put(in_dev); +} + +static void igmp_ifc_event(struct in_device *in_dev) +{ + struct net *net = dev_net(in_dev->dev); + if (IGMP_V1_SEEN(in_dev) || IGMP_V2_SEEN(in_dev)) + return; + WRITE_ONCE(in_dev->mr_ifc_count, in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv)); + igmp_ifc_start_timer(in_dev, 1); +} + + +static void igmp_timer_expire(struct timer_list *t) +{ + struct ip_mc_list *im = from_timer(im, t, timer); + struct in_device *in_dev = im->interface; + + spin_lock(&im->lock); + im->tm_running = 0; + + if (im->unsolicit_count && --im->unsolicit_count) + igmp_start_timer(im, unsolicited_report_interval(in_dev)); + + im->reporter = 1; + spin_unlock(&im->lock); + + if (IGMP_V1_SEEN(in_dev)) + igmp_send_report(in_dev, im, IGMP_HOST_MEMBERSHIP_REPORT); + else if (IGMP_V2_SEEN(in_dev)) + igmp_send_report(in_dev, im, IGMPV2_HOST_MEMBERSHIP_REPORT); + else + igmp_send_report(in_dev, im, IGMPV3_HOST_MEMBERSHIP_REPORT); + + ip_ma_put(im); +} + +/* mark EXCLUDE-mode sources */ +static int igmp_xmarksources(struct ip_mc_list *pmc, int nsrcs, __be32 *srcs) +{ + struct ip_sf_list *psf; + int i, scount; + + scount = 0; + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (scount == nsrcs) + break; + for (i = 0; i < nsrcs; i++) { + /* skip inactive filters */ + if (psf->sf_count[MCAST_INCLUDE] || + pmc->sfcount[MCAST_EXCLUDE] != + psf->sf_count[MCAST_EXCLUDE]) + break; + if (srcs[i] == psf->sf_inaddr) { + scount++; + break; + } + } + } + pmc->gsquery = 0; + if (scount == nsrcs) /* all sources excluded */ + return 0; + return 1; +} + +static int igmp_marksources(struct ip_mc_list *pmc, int nsrcs, __be32 *srcs) +{ + struct ip_sf_list *psf; + int i, scount; + + if (pmc->sfmode == MCAST_EXCLUDE) + return igmp_xmarksources(pmc, nsrcs, srcs); + + /* mark INCLUDE-mode sources */ + scount = 0; + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (scount == nsrcs) + break; + for (i = 0; i < nsrcs; i++) + if (srcs[i] == psf->sf_inaddr) { + psf->sf_gsresp = 1; + scount++; + break; + } + } + if (!scount) { + pmc->gsquery = 0; + return 0; + } + pmc->gsquery = 1; + return 1; +} + +/* return true if packet was dropped */ +static bool igmp_heard_report(struct in_device *in_dev, __be32 group) +{ + struct ip_mc_list *im; + struct net *net = dev_net(in_dev->dev); + + /* Timers are only set for non-local groups */ + + if (group == IGMP_ALL_HOSTS) + return false; + if (ipv4_is_local_multicast(group) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + return false; + + rcu_read_lock(); + for_each_pmc_rcu(in_dev, im) { + if (im->multiaddr == group) { + igmp_stop_timer(im); + break; + } + } + rcu_read_unlock(); + return false; +} + +/* return true if packet was dropped */ +static bool igmp_heard_query(struct in_device *in_dev, struct sk_buff *skb, + int len) +{ + struct igmphdr *ih = igmp_hdr(skb); + struct igmpv3_query *ih3 = igmpv3_query_hdr(skb); + struct ip_mc_list *im; + __be32 group = ih->group; + int max_delay; + int mark = 0; + struct net *net = dev_net(in_dev->dev); + + + if (len == 8) { + if (ih->code == 0) { + /* Alas, old v1 router presents here. 
*/ + + max_delay = IGMP_QUERY_RESPONSE_INTERVAL; + in_dev->mr_v1_seen = jiffies + + (in_dev->mr_qrv * in_dev->mr_qi) + + in_dev->mr_qri; + group = 0; + } else { + /* v2 router present */ + max_delay = ih->code*(HZ/IGMP_TIMER_SCALE); + in_dev->mr_v2_seen = jiffies + + (in_dev->mr_qrv * in_dev->mr_qi) + + in_dev->mr_qri; + } + /* cancel the interface change timer */ + WRITE_ONCE(in_dev->mr_ifc_count, 0); + if (del_timer(&in_dev->mr_ifc_timer)) + __in_dev_put(in_dev); + /* clear deleted report items */ + igmpv3_clear_delrec(in_dev); + } else if (len < 12) { + return true; /* ignore bogus packet; freed by caller */ + } else if (IGMP_V1_SEEN(in_dev)) { + /* This is a v3 query with v1 queriers present */ + max_delay = IGMP_QUERY_RESPONSE_INTERVAL; + group = 0; + } else if (IGMP_V2_SEEN(in_dev)) { + /* this is a v3 query with v2 queriers present; + * Interpretation of the max_delay code is problematic here. + * A real v2 host would use ih_code directly, while v3 has a + * different encoding. We use the v3 encoding as more likely + * to be intended in a v3 query. + */ + max_delay = IGMPV3_MRC(ih3->code)*(HZ/IGMP_TIMER_SCALE); + if (!max_delay) + max_delay = 1; /* can't mod w/ 0 */ + } else { /* v3 */ + if (!pskb_may_pull(skb, sizeof(struct igmpv3_query))) + return true; + + ih3 = igmpv3_query_hdr(skb); + if (ih3->nsrcs) { + if (!pskb_may_pull(skb, sizeof(struct igmpv3_query) + + ntohs(ih3->nsrcs)*sizeof(__be32))) + return true; + ih3 = igmpv3_query_hdr(skb); + } + + max_delay = IGMPV3_MRC(ih3->code)*(HZ/IGMP_TIMER_SCALE); + if (!max_delay) + max_delay = 1; /* can't mod w/ 0 */ + in_dev->mr_maxdelay = max_delay; + + /* RFC3376, 4.1.6. QRV and 4.1.7. QQIC, when the most recently + * received value was zero, use the default or statically + * configured value. + */ + in_dev->mr_qrv = ih3->qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + in_dev->mr_qi = IGMPV3_QQIC(ih3->qqic)*HZ ?: IGMP_QUERY_INTERVAL; + + /* RFC3376, 8.3. Query Response Interval: + * The number of seconds represented by the [Query Response + * Interval] must be less than the [Query Interval]. + */ + if (in_dev->mr_qri >= in_dev->mr_qi) + in_dev->mr_qri = (in_dev->mr_qi/HZ - 1)*HZ; + + if (!group) { /* general query */ + if (ih3->nsrcs) + return true; /* no sources allowed */ + igmp_gq_start_timer(in_dev); + return false; + } + /* mark sources to include, if group & source-specific */ + mark = ih3->nsrcs != 0; + } + + /* + * - Start the timers in all of our membership records + * that the query applies to for the interface on + * which the query arrived excl. those that belong + * to a "local" group (224.0.0.X) + * - For timers already running check if they need to + * be reset. 
+ * - Use the igmp->igmp_code field as the maximum + * delay possible + */ + rcu_read_lock(); + for_each_pmc_rcu(in_dev, im) { + int changed; + + if (group && group != im->multiaddr) + continue; + if (im->multiaddr == IGMP_ALL_HOSTS) + continue; + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + continue; + spin_lock_bh(&im->lock); + if (im->tm_running) + im->gsquery = im->gsquery && mark; + else + im->gsquery = mark; + changed = !im->gsquery || + igmp_marksources(im, ntohs(ih3->nsrcs), ih3->srcs); + spin_unlock_bh(&im->lock); + if (changed) + igmp_mod_timer(im, max_delay); + } + rcu_read_unlock(); + return false; +} + +/* called in rcu_read_lock() section */ +int igmp_rcv(struct sk_buff *skb) +{ + /* This basically follows the spec line by line -- see RFC1112 */ + struct igmphdr *ih; + struct net_device *dev = skb->dev; + struct in_device *in_dev; + int len = skb->len; + bool dropped = true; + + if (netif_is_l3_master(dev)) { + dev = dev_get_by_index_rcu(dev_net(dev), IPCB(skb)->iif); + if (!dev) + goto drop; + } + + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + goto drop; + + if (!pskb_may_pull(skb, sizeof(struct igmphdr))) + goto drop; + + if (skb_checksum_simple_validate(skb)) + goto drop; + + ih = igmp_hdr(skb); + switch (ih->type) { + case IGMP_HOST_MEMBERSHIP_QUERY: + dropped = igmp_heard_query(in_dev, skb, len); + break; + case IGMP_HOST_MEMBERSHIP_REPORT: + case IGMPV2_HOST_MEMBERSHIP_REPORT: + /* Is it our report looped back? */ + if (rt_is_output_route(skb_rtable(skb))) + break; + /* don't rely on MC router hearing unicast reports */ + if (skb->pkt_type == PACKET_MULTICAST || + skb->pkt_type == PACKET_BROADCAST) + dropped = igmp_heard_report(in_dev, ih->group); + break; + case IGMP_PIM: +#ifdef CONFIG_IP_PIMSM_V1 + return pim_rcv_v1(skb); +#endif + case IGMPV3_HOST_MEMBERSHIP_REPORT: + case IGMP_DVMRP: + case IGMP_TRACE: + case IGMP_HOST_LEAVE_MESSAGE: + case IGMP_MTRACE: + case IGMP_MTRACE_RESP: + break; + default: + break; + } + +drop: + if (dropped) + kfree_skb(skb); + else + consume_skb(skb); + return 0; +} + +#endif + + +/* + * Add a filter to a device + */ + +static void ip_mc_filter_add(struct in_device *in_dev, __be32 addr) +{ + char buf[MAX_ADDR_LEN]; + struct net_device *dev = in_dev->dev; + + /* Checking for IFF_MULTICAST here is WRONG-WRONG-WRONG. + We will get multicast token leakage, when IFF_MULTICAST + is changed. This check should be done in ndo_set_rx_mode + routine. Something sort of: + if (dev->mc_list && dev->flags&IFF_MULTICAST) { do it; } + --ANK + */ + if (arp_mc_map(addr, buf, dev, 0) == 0) + dev_mc_add(dev, buf); +} + +/* + * Remove a filter from a device + */ + +static void ip_mc_filter_del(struct in_device *in_dev, __be32 addr) +{ + char buf[MAX_ADDR_LEN]; + struct net_device *dev = in_dev->dev; + + if (arp_mc_map(addr, buf, dev, 0) == 0) + dev_mc_del(dev, buf); +} + +#ifdef CONFIG_IP_MULTICAST +/* + * deleted ip_mc_list manipulation + */ +static void igmpv3_add_delrec(struct in_device *in_dev, struct ip_mc_list *im, + gfp_t gfp) +{ + struct ip_mc_list *pmc; + struct net *net = dev_net(in_dev->dev); + + /* this is an "ip_mc_list" for convenience; only the fields below + * are actually used. In particular, the refcnt and users are not + * used for management of the delete list. Using the same structure + * for deleted items allows change reports to use common code with + * non-deleted or query-response MCA's. 
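The comment above describes the tombstone ("delrec") scheme: when a group is dropped while IGMPv3 change reports are still pending, a stripped-down copy is parked on in_dev->mc_tomb so the remaining crcount retransmissions can still be generated, and igmpv3_del_delrec()/igmpv3_clear_delrec() later restore or flush those entries. A minimal plain-C model of such a tombstone list (struct and function names are illustrative, and the mc_tomb_lock serialization used by the kernel is omitted):

#include <stdint.h>
#include <stdlib.h>

struct tomb {
        uint32_t     group;     /* multicast address                        */
        int          crcount;   /* change-report retransmissions still owed */
        struct tomb *next;
};

/* Park a deleted group on the tombstone list; returns 0, or -1 on OOM. */
static int tomb_add(struct tomb **head, uint32_t group, int crcount)
{
        struct tomb *t = calloc(1, sizeof(*t));

        if (!t)
                return -1;
        t->group = group;
        t->crcount = crcount;
        t->next = *head;        /* push at the head, like in_dev->mc_tomb */
        *head = t;
        return 0;
}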
+ */ + pmc = kzalloc(sizeof(*pmc), gfp); + if (!pmc) + return; + spin_lock_init(&pmc->lock); + spin_lock_bh(&im->lock); + pmc->interface = im->interface; + in_dev_hold(in_dev); + pmc->multiaddr = im->multiaddr; + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + pmc->sfmode = im->sfmode; + if (pmc->sfmode == MCAST_INCLUDE) { + struct ip_sf_list *psf; + + pmc->tomb = im->tomb; + pmc->sources = im->sources; + im->tomb = im->sources = NULL; + for (psf = pmc->sources; psf; psf = psf->sf_next) + psf->sf_crcount = pmc->crcount; + } + spin_unlock_bh(&im->lock); + + spin_lock_bh(&in_dev->mc_tomb_lock); + pmc->next = in_dev->mc_tomb; + in_dev->mc_tomb = pmc; + spin_unlock_bh(&in_dev->mc_tomb_lock); +} + +/* + * restore ip_mc_list deleted records + */ +static void igmpv3_del_delrec(struct in_device *in_dev, struct ip_mc_list *im) +{ + struct ip_mc_list *pmc, *pmc_prev; + struct ip_sf_list *psf; + struct net *net = dev_net(in_dev->dev); + __be32 multiaddr = im->multiaddr; + + spin_lock_bh(&in_dev->mc_tomb_lock); + pmc_prev = NULL; + for (pmc = in_dev->mc_tomb; pmc; pmc = pmc->next) { + if (pmc->multiaddr == multiaddr) + break; + pmc_prev = pmc; + } + if (pmc) { + if (pmc_prev) + pmc_prev->next = pmc->next; + else + in_dev->mc_tomb = pmc->next; + } + spin_unlock_bh(&in_dev->mc_tomb_lock); + + spin_lock_bh(&im->lock); + if (pmc) { + im->interface = pmc->interface; + if (im->sfmode == MCAST_INCLUDE) { + swap(im->tomb, pmc->tomb); + swap(im->sources, pmc->sources); + for (psf = im->sources; psf; psf = psf->sf_next) + psf->sf_crcount = in_dev->mr_qrv ?: + READ_ONCE(net->ipv4.sysctl_igmp_qrv); + } else { + im->crcount = in_dev->mr_qrv ?: + READ_ONCE(net->ipv4.sysctl_igmp_qrv); + } + in_dev_put(pmc->interface); + kfree_pmc(pmc); + } + spin_unlock_bh(&im->lock); +} + +/* + * flush ip_mc_list deleted records + */ +static void igmpv3_clear_delrec(struct in_device *in_dev) +{ + struct ip_mc_list *pmc, *nextpmc; + + spin_lock_bh(&in_dev->mc_tomb_lock); + pmc = in_dev->mc_tomb; + in_dev->mc_tomb = NULL; + spin_unlock_bh(&in_dev->mc_tomb_lock); + + for (; pmc; pmc = nextpmc) { + nextpmc = pmc->next; + ip_mc_clear_src(pmc); + in_dev_put(pmc->interface); + kfree_pmc(pmc); + } + /* clear dead sources, too */ + rcu_read_lock(); + for_each_pmc_rcu(in_dev, pmc) { + struct ip_sf_list *psf; + + spin_lock_bh(&pmc->lock); + psf = pmc->tomb; + pmc->tomb = NULL; + spin_unlock_bh(&pmc->lock); + ip_sf_list_clear_all(psf); + } + rcu_read_unlock(); +} +#endif + +static void __igmp_group_dropped(struct ip_mc_list *im, gfp_t gfp) +{ + struct in_device *in_dev = im->interface; +#ifdef CONFIG_IP_MULTICAST + struct net *net = dev_net(in_dev->dev); + int reporter; +#endif + + if (im->loaded) { + im->loaded = 0; + ip_mc_filter_del(in_dev, im->multiaddr); + } + +#ifdef CONFIG_IP_MULTICAST + if (im->multiaddr == IGMP_ALL_HOSTS) + return; + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + return; + + reporter = im->reporter; + igmp_stop_timer(im); + + if (!in_dev->dead) { + if (IGMP_V1_SEEN(in_dev)) + return; + if (IGMP_V2_SEEN(in_dev)) { + if (reporter) + igmp_send_report(in_dev, im, IGMP_HOST_LEAVE_MESSAGE); + return; + } + /* IGMPv3 */ + igmpv3_add_delrec(in_dev, im, gfp); + + igmp_ifc_event(in_dev); + } +#endif +} + +static void igmp_group_dropped(struct ip_mc_list *im) +{ + __igmp_group_dropped(im, GFP_KERNEL); +} + +static void igmp_group_added(struct ip_mc_list *im) +{ + struct in_device *in_dev = im->interface; +#ifdef CONFIG_IP_MULTICAST + struct net *net = 
dev_net(in_dev->dev); +#endif + + if (im->loaded == 0) { + im->loaded = 1; + ip_mc_filter_add(in_dev, im->multiaddr); + } + +#ifdef CONFIG_IP_MULTICAST + if (im->multiaddr == IGMP_ALL_HOSTS) + return; + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + return; + + if (in_dev->dead) + return; + + im->unsolicit_count = READ_ONCE(net->ipv4.sysctl_igmp_qrv); + if (IGMP_V1_SEEN(in_dev) || IGMP_V2_SEEN(in_dev)) { + spin_lock_bh(&im->lock); + igmp_start_timer(im, IGMP_INITIAL_REPORT_DELAY); + spin_unlock_bh(&im->lock); + return; + } + /* else, v3 */ + + /* Based on RFC3376 5.1, for newly added INCLUDE SSM, we should + * not send filter-mode change record as the mode should be from + * IN() to IN(A). + */ + if (im->sfmode == MCAST_EXCLUDE) + im->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + + igmp_ifc_event(in_dev); +#endif +} + + +/* + * Multicast list managers + */ + +static u32 ip_mc_hash(const struct ip_mc_list *im) +{ + return hash_32((__force u32)im->multiaddr, MC_HASH_SZ_LOG); +} + +static void ip_mc_hash_add(struct in_device *in_dev, + struct ip_mc_list *im) +{ + struct ip_mc_list __rcu **mc_hash; + u32 hash; + + mc_hash = rtnl_dereference(in_dev->mc_hash); + if (mc_hash) { + hash = ip_mc_hash(im); + im->next_hash = mc_hash[hash]; + rcu_assign_pointer(mc_hash[hash], im); + return; + } + + /* do not use a hash table for small number of items */ + if (in_dev->mc_count < 4) + return; + + mc_hash = kzalloc(sizeof(struct ip_mc_list *) << MC_HASH_SZ_LOG, + GFP_KERNEL); + if (!mc_hash) + return; + + for_each_pmc_rtnl(in_dev, im) { + hash = ip_mc_hash(im); + im->next_hash = mc_hash[hash]; + RCU_INIT_POINTER(mc_hash[hash], im); + } + + rcu_assign_pointer(in_dev->mc_hash, mc_hash); +} + +static void ip_mc_hash_remove(struct in_device *in_dev, + struct ip_mc_list *im) +{ + struct ip_mc_list __rcu **mc_hash = rtnl_dereference(in_dev->mc_hash); + struct ip_mc_list *aux; + + if (!mc_hash) + return; + mc_hash += ip_mc_hash(im); + while ((aux = rtnl_dereference(*mc_hash)) != im) + mc_hash = &aux->next_hash; + *mc_hash = im->next_hash; +} + + +/* + * A socket has joined a multicast group on device dev. 
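ip_mc_hash_add() above only builds the per-device hash once the membership count reaches 4; below that threshold lookups simply walk the short mc_list, and ____ip_mc_inc_group() below either bumps users on an existing entry or allocates and links a new one. A small model of that "list first, hash after a threshold" join path (the structures are simplified, the threshold mirrors the check above, and all RCU/RTNL handling is left out):

#include <stdint.h>
#include <stdlib.h>

#define HASH_THRESHOLD 4

struct member {
        uint32_t       group;
        int            users;
        struct member *next;
};

struct dev_mc {
        struct member *list;
        int            count;
        int            hashed;   /* set once count reaches HASH_THRESHOLD */
};

static int join(struct dev_mc *d, uint32_t group)
{
        struct member *m;

        for (m = d->list; m; m = m->next)
                if (m->group == group) {
                        m->users++;      /* existing membership: count only */
                        return 0;
                }

        m = calloc(1, sizeof(*m));
        if (!m)
                return -1;
        m->group = group;
        m->users = 1;
        m->next = d->list;
        d->list = m;
        if (++d->count >= HASH_THRESHOLD)
                d->hashed = 1;           /* the real code builds mc_hash here */
        return 0;
}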
+ */ +static void ____ip_mc_inc_group(struct in_device *in_dev, __be32 addr, + unsigned int mode, gfp_t gfp) +{ + struct ip_mc_list *im; + + ASSERT_RTNL(); + + for_each_pmc_rtnl(in_dev, im) { + if (im->multiaddr == addr) { + im->users++; + ip_mc_add_src(in_dev, &addr, mode, 0, NULL, 0); + goto out; + } + } + + im = kzalloc(sizeof(*im), gfp); + if (!im) + goto out; + + im->users = 1; + im->interface = in_dev; + in_dev_hold(in_dev); + im->multiaddr = addr; + /* initial mode is (EX, empty) */ + im->sfmode = mode; + im->sfcount[mode] = 1; + refcount_set(&im->refcnt, 1); + spin_lock_init(&im->lock); +#ifdef CONFIG_IP_MULTICAST + timer_setup(&im->timer, igmp_timer_expire, 0); +#endif + + im->next_rcu = in_dev->mc_list; + in_dev->mc_count++; + rcu_assign_pointer(in_dev->mc_list, im); + + ip_mc_hash_add(in_dev, im); + +#ifdef CONFIG_IP_MULTICAST + igmpv3_del_delrec(in_dev, im); +#endif + igmp_group_added(im); + if (!in_dev->dead) + ip_rt_multicast_event(in_dev); +out: + return; +} + +void __ip_mc_inc_group(struct in_device *in_dev, __be32 addr, gfp_t gfp) +{ + ____ip_mc_inc_group(in_dev, addr, MCAST_EXCLUDE, gfp); +} +EXPORT_SYMBOL(__ip_mc_inc_group); + +void ip_mc_inc_group(struct in_device *in_dev, __be32 addr) +{ + __ip_mc_inc_group(in_dev, addr, GFP_KERNEL); +} +EXPORT_SYMBOL(ip_mc_inc_group); + +static int ip_mc_check_iphdr(struct sk_buff *skb) +{ + const struct iphdr *iph; + unsigned int len; + unsigned int offset = skb_network_offset(skb) + sizeof(*iph); + + if (!pskb_may_pull(skb, offset)) + return -EINVAL; + + iph = ip_hdr(skb); + + if (iph->version != 4 || ip_hdrlen(skb) < sizeof(*iph)) + return -EINVAL; + + offset += ip_hdrlen(skb) - sizeof(*iph); + + if (!pskb_may_pull(skb, offset)) + return -EINVAL; + + iph = ip_hdr(skb); + + if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl))) + return -EINVAL; + + len = skb_network_offset(skb) + ntohs(iph->tot_len); + if (skb->len < len || len < offset) + return -EINVAL; + + skb_set_transport_header(skb, offset); + + return 0; +} + +static int ip_mc_check_igmp_reportv3(struct sk_buff *skb) +{ + unsigned int len = skb_transport_offset(skb); + + len += sizeof(struct igmpv3_report); + + return ip_mc_may_pull(skb, len) ? 0 : -EINVAL; +} + +static int ip_mc_check_igmp_query(struct sk_buff *skb) +{ + unsigned int transport_len = ip_transport_len(skb); + unsigned int len; + + /* IGMPv{1,2}? */ + if (transport_len != sizeof(struct igmphdr)) { + /* or IGMPv3? 
*/ + if (transport_len < sizeof(struct igmpv3_query)) + return -EINVAL; + + len = skb_transport_offset(skb) + sizeof(struct igmpv3_query); + if (!ip_mc_may_pull(skb, len)) + return -EINVAL; + } + + /* RFC2236+RFC3376 (IGMPv2+IGMPv3) require the multicast link layer + * all-systems destination addresses (224.0.0.1) for general queries + */ + if (!igmp_hdr(skb)->group && + ip_hdr(skb)->daddr != htonl(INADDR_ALLHOSTS_GROUP)) + return -EINVAL; + + return 0; +} + +static int ip_mc_check_igmp_msg(struct sk_buff *skb) +{ + switch (igmp_hdr(skb)->type) { + case IGMP_HOST_LEAVE_MESSAGE: + case IGMP_HOST_MEMBERSHIP_REPORT: + case IGMPV2_HOST_MEMBERSHIP_REPORT: + return 0; + case IGMPV3_HOST_MEMBERSHIP_REPORT: + return ip_mc_check_igmp_reportv3(skb); + case IGMP_HOST_MEMBERSHIP_QUERY: + return ip_mc_check_igmp_query(skb); + default: + return -ENOMSG; + } +} + +static __sum16 ip_mc_validate_checksum(struct sk_buff *skb) +{ + return skb_checksum_simple_validate(skb); +} + +static int ip_mc_check_igmp_csum(struct sk_buff *skb) +{ + unsigned int len = skb_transport_offset(skb) + sizeof(struct igmphdr); + unsigned int transport_len = ip_transport_len(skb); + struct sk_buff *skb_chk; + + if (!ip_mc_may_pull(skb, len)) + return -EINVAL; + + skb_chk = skb_checksum_trimmed(skb, transport_len, + ip_mc_validate_checksum); + if (!skb_chk) + return -EINVAL; + + if (skb_chk != skb) + kfree_skb(skb_chk); + + return 0; +} + +/** + * ip_mc_check_igmp - checks whether this is a sane IGMP packet + * @skb: the skb to validate + * + * Checks whether an IPv4 packet is a valid IGMP packet. If so sets + * skb transport header accordingly and returns zero. + * + * -EINVAL: A broken packet was detected, i.e. it violates some internet + * standard + * -ENOMSG: IP header validation succeeded but it is not an IGMP packet. + * -ENOMEM: A memory allocation failure happened. + * + * Caller needs to set the skb network header and free any returned skb if it + * differs from the provided skb. + */ +int ip_mc_check_igmp(struct sk_buff *skb) +{ + int ret = ip_mc_check_iphdr(skb); + + if (ret < 0) + return ret; + + if (ip_hdr(skb)->protocol != IPPROTO_IGMP) + return -ENOMSG; + + ret = ip_mc_check_igmp_csum(skb); + if (ret < 0) + return ret; + + return ip_mc_check_igmp_msg(skb); +} +EXPORT_SYMBOL(ip_mc_check_igmp); + +/* + * Resend IGMP JOIN report; used by netdev notifier. 
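ip_mc_rejoin_groups() below re-announces every membership after a netdev event and, like igmp_timer_expire() above, picks the report type from the querier state: a v1 report if IGMP_V1_SEEN(), a v2 report if IGMP_V2_SEEN(), otherwise a v3 report. A tiny standalone version of that selection, using the standard IGMP message-type values (the two bool parameters stand in for the macros):

#include <stdbool.h>
#include <stdint.h>

#define IGMP_HOST_MEMBERSHIP_REPORT    0x12   /* v1 report */
#define IGMPV2_HOST_MEMBERSHIP_REPORT  0x16   /* v2 report */
#define IGMPV3_HOST_MEMBERSHIP_REPORT  0x22   /* v3 report */

static uint8_t report_type(bool v1_seen, bool v2_seen)
{
        if (v1_seen)
                return IGMP_HOST_MEMBERSHIP_REPORT;
        if (v2_seen)
                return IGMPV2_HOST_MEMBERSHIP_REPORT;
        return IGMPV3_HOST_MEMBERSHIP_REPORT;
}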
+ */ +static void ip_mc_rejoin_groups(struct in_device *in_dev) +{ +#ifdef CONFIG_IP_MULTICAST + struct ip_mc_list *im; + int type; + struct net *net = dev_net(in_dev->dev); + + ASSERT_RTNL(); + + for_each_pmc_rtnl(in_dev, im) { + if (im->multiaddr == IGMP_ALL_HOSTS) + continue; + if (ipv4_is_local_multicast(im->multiaddr) && + !READ_ONCE(net->ipv4.sysctl_igmp_llm_reports)) + continue; + + /* a failover is happening and switches + * must be notified immediately + */ + if (IGMP_V1_SEEN(in_dev)) + type = IGMP_HOST_MEMBERSHIP_REPORT; + else if (IGMP_V2_SEEN(in_dev)) + type = IGMPV2_HOST_MEMBERSHIP_REPORT; + else + type = IGMPV3_HOST_MEMBERSHIP_REPORT; + igmp_send_report(in_dev, im, type); + } +#endif +} + +/* + * A socket has left a multicast group on device dev + */ + +void __ip_mc_dec_group(struct in_device *in_dev, __be32 addr, gfp_t gfp) +{ + struct ip_mc_list *i; + struct ip_mc_list __rcu **ip; + + ASSERT_RTNL(); + + for (ip = &in_dev->mc_list; + (i = rtnl_dereference(*ip)) != NULL; + ip = &i->next_rcu) { + if (i->multiaddr == addr) { + if (--i->users == 0) { + ip_mc_hash_remove(in_dev, i); + *ip = i->next_rcu; + in_dev->mc_count--; + __igmp_group_dropped(i, gfp); + ip_mc_clear_src(i); + + if (!in_dev->dead) + ip_rt_multicast_event(in_dev); + + ip_ma_put(i); + return; + } + break; + } + } +} +EXPORT_SYMBOL(__ip_mc_dec_group); + +/* Device changing type */ + +void ip_mc_unmap(struct in_device *in_dev) +{ + struct ip_mc_list *pmc; + + ASSERT_RTNL(); + + for_each_pmc_rtnl(in_dev, pmc) + igmp_group_dropped(pmc); +} + +void ip_mc_remap(struct in_device *in_dev) +{ + struct ip_mc_list *pmc; + + ASSERT_RTNL(); + + for_each_pmc_rtnl(in_dev, pmc) { +#ifdef CONFIG_IP_MULTICAST + igmpv3_del_delrec(in_dev, pmc); +#endif + igmp_group_added(pmc); + } +} + +/* Device going down */ + +void ip_mc_down(struct in_device *in_dev) +{ + struct ip_mc_list *pmc; + + ASSERT_RTNL(); + + for_each_pmc_rtnl(in_dev, pmc) + igmp_group_dropped(pmc); + +#ifdef CONFIG_IP_MULTICAST + WRITE_ONCE(in_dev->mr_ifc_count, 0); + if (del_timer(&in_dev->mr_ifc_timer)) + __in_dev_put(in_dev); + in_dev->mr_gq_running = 0; + if (del_timer(&in_dev->mr_gq_timer)) + __in_dev_put(in_dev); +#endif + + ip_mc_dec_group(in_dev, IGMP_ALL_HOSTS); +} + +#ifdef CONFIG_IP_MULTICAST +static void ip_mc_reset(struct in_device *in_dev) +{ + struct net *net = dev_net(in_dev->dev); + + in_dev->mr_qi = IGMP_QUERY_INTERVAL; + in_dev->mr_qri = IGMP_QUERY_RESPONSE_INTERVAL; + in_dev->mr_qrv = READ_ONCE(net->ipv4.sysctl_igmp_qrv); +} +#else +static void ip_mc_reset(struct in_device *in_dev) +{ +} +#endif + +void ip_mc_init_dev(struct in_device *in_dev) +{ + ASSERT_RTNL(); + +#ifdef CONFIG_IP_MULTICAST + timer_setup(&in_dev->mr_gq_timer, igmp_gq_timer_expire, 0); + timer_setup(&in_dev->mr_ifc_timer, igmp_ifc_timer_expire, 0); +#endif + ip_mc_reset(in_dev); + + spin_lock_init(&in_dev->mc_tomb_lock); +} + +/* Device going up */ + +void ip_mc_up(struct in_device *in_dev) +{ + struct ip_mc_list *pmc; + + ASSERT_RTNL(); + + ip_mc_reset(in_dev); + ip_mc_inc_group(in_dev, IGMP_ALL_HOSTS); + + for_each_pmc_rtnl(in_dev, pmc) { +#ifdef CONFIG_IP_MULTICAST + igmpv3_del_delrec(in_dev, pmc); +#endif + igmp_group_added(pmc); + } +} + +/* + * Device is about to be destroyed: clean up. 
+ */ + +void ip_mc_destroy_dev(struct in_device *in_dev) +{ + struct ip_mc_list *i; + + ASSERT_RTNL(); + + /* Deactivate timers */ + ip_mc_down(in_dev); +#ifdef CONFIG_IP_MULTICAST + igmpv3_clear_delrec(in_dev); +#endif + + while ((i = rtnl_dereference(in_dev->mc_list)) != NULL) { + in_dev->mc_list = i->next_rcu; + in_dev->mc_count--; + ip_mc_clear_src(i); + ip_ma_put(i); + } +} + +/* RTNL is locked */ +static struct in_device *ip_mc_find_dev(struct net *net, struct ip_mreqn *imr) +{ + struct net_device *dev = NULL; + struct in_device *idev = NULL; + + if (imr->imr_ifindex) { + idev = inetdev_by_index(net, imr->imr_ifindex); + return idev; + } + if (imr->imr_address.s_addr) { + dev = __ip_dev_find(net, imr->imr_address.s_addr, false); + if (!dev) + return NULL; + } + + if (!dev) { + struct rtable *rt = ip_route_output(net, + imr->imr_multiaddr.s_addr, + 0, 0, 0); + if (!IS_ERR(rt)) { + dev = rt->dst.dev; + ip_rt_put(rt); + } + } + if (dev) { + imr->imr_ifindex = dev->ifindex; + idev = __in_dev_get_rtnl(dev); + } + return idev; +} + +/* + * Join a socket to a group + */ + +static int ip_mc_del1_src(struct ip_mc_list *pmc, int sfmode, + __be32 *psfsrc) +{ + struct ip_sf_list *psf, *psf_prev; + int rv = 0; + + psf_prev = NULL; + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (psf->sf_inaddr == *psfsrc) + break; + psf_prev = psf; + } + if (!psf || psf->sf_count[sfmode] == 0) { + /* source filter not found, or count wrong => bug */ + return -ESRCH; + } + psf->sf_count[sfmode]--; + if (psf->sf_count[sfmode] == 0) { + ip_rt_multicast_event(pmc->interface); + } + if (!psf->sf_count[MCAST_INCLUDE] && !psf->sf_count[MCAST_EXCLUDE]) { +#ifdef CONFIG_IP_MULTICAST + struct in_device *in_dev = pmc->interface; + struct net *net = dev_net(in_dev->dev); +#endif + + /* no more filters for this source */ + if (psf_prev) + psf_prev->sf_next = psf->sf_next; + else + pmc->sources = psf->sf_next; +#ifdef CONFIG_IP_MULTICAST + if (psf->sf_oldin && + !IGMP_V1_SEEN(in_dev) && !IGMP_V2_SEEN(in_dev)) { + psf->sf_crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + psf->sf_next = pmc->tomb; + pmc->tomb = psf; + rv = 1; + } else +#endif + kfree(psf); + } + return rv; +} + +#ifndef CONFIG_IP_MULTICAST +#define igmp_ifc_event(x) do { } while (0) +#endif + +static int ip_mc_del_src(struct in_device *in_dev, __be32 *pmca, int sfmode, + int sfcount, __be32 *psfsrc, int delta) +{ + struct ip_mc_list *pmc; + int changerec = 0; + int i, err; + + if (!in_dev) + return -ENODEV; + rcu_read_lock(); + for_each_pmc_rcu(in_dev, pmc) { + if (*pmca == pmc->multiaddr) + break; + } + if (!pmc) { + /* MCA not found?? 
bug */ + rcu_read_unlock(); + return -ESRCH; + } + spin_lock_bh(&pmc->lock); + rcu_read_unlock(); +#ifdef CONFIG_IP_MULTICAST + sf_markstate(pmc); +#endif + if (!delta) { + err = -EINVAL; + if (!pmc->sfcount[sfmode]) + goto out_unlock; + pmc->sfcount[sfmode]--; + } + err = 0; + for (i = 0; i < sfcount; i++) { + int rv = ip_mc_del1_src(pmc, sfmode, &psfsrc[i]); + + changerec |= rv > 0; + if (!err && rv < 0) + err = rv; + } + if (pmc->sfmode == MCAST_EXCLUDE && + pmc->sfcount[MCAST_EXCLUDE] == 0 && + pmc->sfcount[MCAST_INCLUDE]) { +#ifdef CONFIG_IP_MULTICAST + struct ip_sf_list *psf; + struct net *net = dev_net(in_dev->dev); +#endif + + /* filter mode change */ + pmc->sfmode = MCAST_INCLUDE; +#ifdef CONFIG_IP_MULTICAST + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + WRITE_ONCE(in_dev->mr_ifc_count, pmc->crcount); + for (psf = pmc->sources; psf; psf = psf->sf_next) + psf->sf_crcount = 0; + igmp_ifc_event(pmc->interface); + } else if (sf_setstate(pmc) || changerec) { + igmp_ifc_event(pmc->interface); +#endif + } +out_unlock: + spin_unlock_bh(&pmc->lock); + return err; +} + +/* + * Add multicast single-source filter to the interface list + */ +static int ip_mc_add1_src(struct ip_mc_list *pmc, int sfmode, + __be32 *psfsrc) +{ + struct ip_sf_list *psf, *psf_prev; + + psf_prev = NULL; + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (psf->sf_inaddr == *psfsrc) + break; + psf_prev = psf; + } + if (!psf) { + psf = kzalloc(sizeof(*psf), GFP_ATOMIC); + if (!psf) + return -ENOBUFS; + psf->sf_inaddr = *psfsrc; + if (psf_prev) { + psf_prev->sf_next = psf; + } else + pmc->sources = psf; + } + psf->sf_count[sfmode]++; + if (psf->sf_count[sfmode] == 1) { + ip_rt_multicast_event(pmc->interface); + } + return 0; +} + +#ifdef CONFIG_IP_MULTICAST +static void sf_markstate(struct ip_mc_list *pmc) +{ + struct ip_sf_list *psf; + int mca_xcount = pmc->sfcount[MCAST_EXCLUDE]; + + for (psf = pmc->sources; psf; psf = psf->sf_next) + if (pmc->sfcount[MCAST_EXCLUDE]) { + psf->sf_oldin = mca_xcount == + psf->sf_count[MCAST_EXCLUDE] && + !psf->sf_count[MCAST_INCLUDE]; + } else + psf->sf_oldin = psf->sf_count[MCAST_INCLUDE] != 0; +} + +static int sf_setstate(struct ip_mc_list *pmc) +{ + struct ip_sf_list *psf, *dpsf; + int mca_xcount = pmc->sfcount[MCAST_EXCLUDE]; + int qrv = pmc->interface->mr_qrv; + int new_in, rv; + + rv = 0; + for (psf = pmc->sources; psf; psf = psf->sf_next) { + if (pmc->sfcount[MCAST_EXCLUDE]) { + new_in = mca_xcount == psf->sf_count[MCAST_EXCLUDE] && + !psf->sf_count[MCAST_INCLUDE]; + } else + new_in = psf->sf_count[MCAST_INCLUDE] != 0; + if (new_in) { + if (!psf->sf_oldin) { + struct ip_sf_list *prev = NULL; + + for (dpsf = pmc->tomb; dpsf; dpsf = dpsf->sf_next) { + if (dpsf->sf_inaddr == psf->sf_inaddr) + break; + prev = dpsf; + } + if (dpsf) { + if (prev) + prev->sf_next = dpsf->sf_next; + else + pmc->tomb = dpsf->sf_next; + kfree(dpsf); + } + psf->sf_crcount = qrv; + rv++; + } + } else if (psf->sf_oldin) { + + psf->sf_crcount = 0; + /* + * add or update "delete" records if an active filter + * is now inactive + */ + for (dpsf = pmc->tomb; dpsf; dpsf = dpsf->sf_next) + if (dpsf->sf_inaddr == psf->sf_inaddr) + break; + if (!dpsf) { + dpsf = kmalloc(sizeof(*dpsf), GFP_ATOMIC); + if (!dpsf) + continue; + *dpsf = *psf; + /* pmc->lock held by callers */ + dpsf->sf_next = pmc->tomb; + pmc->tomb = dpsf; + } + dpsf->sf_crcount = qrv; + rv++; + } + } + return rv; +} +#endif + +/* + * Add multicast source filter list to the interface list + */ +static int 
ip_mc_add_src(struct in_device *in_dev, __be32 *pmca, int sfmode, + int sfcount, __be32 *psfsrc, int delta) +{ + struct ip_mc_list *pmc; + int isexclude; + int i, err; + + if (!in_dev) + return -ENODEV; + rcu_read_lock(); + for_each_pmc_rcu(in_dev, pmc) { + if (*pmca == pmc->multiaddr) + break; + } + if (!pmc) { + /* MCA not found?? bug */ + rcu_read_unlock(); + return -ESRCH; + } + spin_lock_bh(&pmc->lock); + rcu_read_unlock(); + +#ifdef CONFIG_IP_MULTICAST + sf_markstate(pmc); +#endif + isexclude = pmc->sfmode == MCAST_EXCLUDE; + if (!delta) + pmc->sfcount[sfmode]++; + err = 0; + for (i = 0; i < sfcount; i++) { + err = ip_mc_add1_src(pmc, sfmode, &psfsrc[i]); + if (err) + break; + } + if (err) { + int j; + + if (!delta) + pmc->sfcount[sfmode]--; + for (j = 0; j < i; j++) + (void) ip_mc_del1_src(pmc, sfmode, &psfsrc[j]); + } else if (isexclude != (pmc->sfcount[MCAST_EXCLUDE] != 0)) { +#ifdef CONFIG_IP_MULTICAST + struct ip_sf_list *psf; + struct net *net = dev_net(pmc->interface->dev); + in_dev = pmc->interface; +#endif + + /* filter mode change */ + if (pmc->sfcount[MCAST_EXCLUDE]) + pmc->sfmode = MCAST_EXCLUDE; + else if (pmc->sfcount[MCAST_INCLUDE]) + pmc->sfmode = MCAST_INCLUDE; +#ifdef CONFIG_IP_MULTICAST + /* else no filters; keep old mode for reports */ + + pmc->crcount = in_dev->mr_qrv ?: READ_ONCE(net->ipv4.sysctl_igmp_qrv); + WRITE_ONCE(in_dev->mr_ifc_count, pmc->crcount); + for (psf = pmc->sources; psf; psf = psf->sf_next) + psf->sf_crcount = 0; + igmp_ifc_event(in_dev); + } else if (sf_setstate(pmc)) { + igmp_ifc_event(in_dev); +#endif + } + spin_unlock_bh(&pmc->lock); + return err; +} + +static void ip_mc_clear_src(struct ip_mc_list *pmc) +{ + struct ip_sf_list *tomb, *sources; + + spin_lock_bh(&pmc->lock); + tomb = pmc->tomb; + pmc->tomb = NULL; + sources = pmc->sources; + pmc->sources = NULL; + pmc->sfmode = MCAST_EXCLUDE; + pmc->sfcount[MCAST_INCLUDE] = 0; + pmc->sfcount[MCAST_EXCLUDE] = 1; + spin_unlock_bh(&pmc->lock); + + ip_sf_list_clear_all(tomb); + ip_sf_list_clear_all(sources); +} + +/* Join a multicast group + */ +static int __ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr, + unsigned int mode) +{ + __be32 addr = imr->imr_multiaddr.s_addr; + struct ip_mc_socklist *iml, *i; + struct in_device *in_dev; + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + int ifindex; + int count = 0; + int err; + + ASSERT_RTNL(); + + if (!ipv4_is_multicast(addr)) + return -EINVAL; + + in_dev = ip_mc_find_dev(net, imr); + + if (!in_dev) { + err = -ENODEV; + goto done; + } + + err = -EADDRINUSE; + ifindex = imr->imr_ifindex; + for_each_pmc_rtnl(inet, i) { + if (i->multi.imr_multiaddr.s_addr == addr && + i->multi.imr_ifindex == ifindex) + goto done; + count++; + } + err = -ENOBUFS; + if (count >= READ_ONCE(net->ipv4.sysctl_igmp_max_memberships)) + goto done; + iml = sock_kmalloc(sk, sizeof(*iml), GFP_KERNEL); + if (!iml) + goto done; + + memcpy(&iml->multi, imr, sizeof(*imr)); + iml->next_rcu = inet->mc_list; + iml->sflist = NULL; + iml->sfmode = mode; + rcu_assign_pointer(inet->mc_list, iml); + ____ip_mc_inc_group(in_dev, addr, mode, GFP_KERNEL); + err = 0; +done: + return err; +} + +/* Join ASM (Any-Source Multicast) group + */ +int ip_mc_join_group(struct sock *sk, struct ip_mreqn *imr) +{ + return __ip_mc_join_group(sk, imr, MCAST_EXCLUDE); +} +EXPORT_SYMBOL(ip_mc_join_group); + +/* Join SSM (Source-Specific Multicast) group + */ +int ip_mc_join_group_ssm(struct sock *sk, struct ip_mreqn *imr, + unsigned int mode) +{ + return __ip_mc_join_group(sk, imr, 
mode); +} + +static int ip_mc_leave_src(struct sock *sk, struct ip_mc_socklist *iml, + struct in_device *in_dev) +{ + struct ip_sf_socklist *psf = rtnl_dereference(iml->sflist); + int err; + + if (!psf) { + /* any-source empty exclude case */ + return ip_mc_del_src(in_dev, &iml->multi.imr_multiaddr.s_addr, + iml->sfmode, 0, NULL, 0); + } + err = ip_mc_del_src(in_dev, &iml->multi.imr_multiaddr.s_addr, + iml->sfmode, psf->sl_count, psf->sl_addr, 0); + RCU_INIT_POINTER(iml->sflist, NULL); + /* decrease mem now to avoid the memleak warning */ + atomic_sub(struct_size(psf, sl_addr, psf->sl_max), &sk->sk_omem_alloc); + kfree_rcu(psf, rcu); + return err; +} + +int ip_mc_leave_group(struct sock *sk, struct ip_mreqn *imr) +{ + struct inet_sock *inet = inet_sk(sk); + struct ip_mc_socklist *iml; + struct ip_mc_socklist __rcu **imlp; + struct in_device *in_dev; + struct net *net = sock_net(sk); + __be32 group = imr->imr_multiaddr.s_addr; + u32 ifindex; + int ret = -EADDRNOTAVAIL; + + ASSERT_RTNL(); + + in_dev = ip_mc_find_dev(net, imr); + if (!imr->imr_ifindex && !imr->imr_address.s_addr && !in_dev) { + ret = -ENODEV; + goto out; + } + ifindex = imr->imr_ifindex; + for (imlp = &inet->mc_list; + (iml = rtnl_dereference(*imlp)) != NULL; + imlp = &iml->next_rcu) { + if (iml->multi.imr_multiaddr.s_addr != group) + continue; + if (ifindex) { + if (iml->multi.imr_ifindex != ifindex) + continue; + } else if (imr->imr_address.s_addr && imr->imr_address.s_addr != + iml->multi.imr_address.s_addr) + continue; + + (void) ip_mc_leave_src(sk, iml, in_dev); + + *imlp = iml->next_rcu; + + if (in_dev) + ip_mc_dec_group(in_dev, group); + + /* decrease mem now to avoid the memleak warning */ + atomic_sub(sizeof(*iml), &sk->sk_omem_alloc); + kfree_rcu(iml, rcu); + return 0; + } +out: + return ret; +} +EXPORT_SYMBOL(ip_mc_leave_group); + +int ip_mc_source(int add, int omode, struct sock *sk, struct + ip_mreq_source *mreqs, int ifindex) +{ + int err; + struct ip_mreqn imr; + __be32 addr = mreqs->imr_multiaddr; + struct ip_mc_socklist *pmc; + struct in_device *in_dev = NULL; + struct inet_sock *inet = inet_sk(sk); + struct ip_sf_socklist *psl; + struct net *net = sock_net(sk); + int leavegroup = 0; + int i, j, rv; + + if (!ipv4_is_multicast(addr)) + return -EINVAL; + + ASSERT_RTNL(); + + imr.imr_multiaddr.s_addr = mreqs->imr_multiaddr; + imr.imr_address.s_addr = mreqs->imr_interface; + imr.imr_ifindex = ifindex; + in_dev = ip_mc_find_dev(net, &imr); + + if (!in_dev) { + err = -ENODEV; + goto done; + } + err = -EADDRNOTAVAIL; + + for_each_pmc_rtnl(inet, pmc) { + if ((pmc->multi.imr_multiaddr.s_addr == + imr.imr_multiaddr.s_addr) && + (pmc->multi.imr_ifindex == imr.imr_ifindex)) + break; + } + if (!pmc) { /* must have a prior join */ + err = -EINVAL; + goto done; + } + /* if a source filter was set, must be the same mode as before */ + if (pmc->sflist) { + if (pmc->sfmode != omode) { + err = -EINVAL; + goto done; + } + } else if (pmc->sfmode != omode) { + /* allow mode switches for empty-set filters */ + ip_mc_add_src(in_dev, &mreqs->imr_multiaddr, omode, 0, NULL, 0); + ip_mc_del_src(in_dev, &mreqs->imr_multiaddr, pmc->sfmode, 0, + NULL, 0); + pmc->sfmode = omode; + } + + psl = rtnl_dereference(pmc->sflist); + if (!add) { + if (!psl) + goto done; /* err = -EADDRNOTAVAIL */ + rv = !0; + for (i = 0; i < psl->sl_count; i++) { + rv = memcmp(&psl->sl_addr[i], &mreqs->imr_sourceaddr, + sizeof(__be32)); + if (rv == 0) + break; + } + if (rv) /* source not found */ + goto done; /* err = -EADDRNOTAVAIL */ + + /* special case - 
(INCLUDE, empty) == LEAVE_GROUP */ + if (psl->sl_count == 1 && omode == MCAST_INCLUDE) { + leavegroup = 1; + goto done; + } + + /* update the interface filter */ + ip_mc_del_src(in_dev, &mreqs->imr_multiaddr, omode, 1, + &mreqs->imr_sourceaddr, 1); + + for (j = i+1; j < psl->sl_count; j++) + psl->sl_addr[j-1] = psl->sl_addr[j]; + psl->sl_count--; + err = 0; + goto done; + } + /* else, add a new source to the filter */ + + if (psl && psl->sl_count >= READ_ONCE(net->ipv4.sysctl_igmp_max_msf)) { + err = -ENOBUFS; + goto done; + } + if (!psl || psl->sl_count == psl->sl_max) { + struct ip_sf_socklist *newpsl; + int count = IP_SFBLOCK; + + if (psl) + count += psl->sl_max; + newpsl = sock_kmalloc(sk, struct_size(newpsl, sl_addr, count), + GFP_KERNEL); + if (!newpsl) { + err = -ENOBUFS; + goto done; + } + newpsl->sl_max = count; + newpsl->sl_count = count - IP_SFBLOCK; + if (psl) { + for (i = 0; i < psl->sl_count; i++) + newpsl->sl_addr[i] = psl->sl_addr[i]; + /* decrease mem now to avoid the memleak warning */ + atomic_sub(struct_size(psl, sl_addr, psl->sl_max), + &sk->sk_omem_alloc); + } + rcu_assign_pointer(pmc->sflist, newpsl); + if (psl) + kfree_rcu(psl, rcu); + psl = newpsl; + } + rv = 1; /* > 0 for insert logic below if sl_count is 0 */ + for (i = 0; i < psl->sl_count; i++) { + rv = memcmp(&psl->sl_addr[i], &mreqs->imr_sourceaddr, + sizeof(__be32)); + if (rv == 0) + break; + } + if (rv == 0) /* address already there is an error */ + goto done; + for (j = psl->sl_count-1; j >= i; j--) + psl->sl_addr[j+1] = psl->sl_addr[j]; + psl->sl_addr[i] = mreqs->imr_sourceaddr; + psl->sl_count++; + err = 0; + /* update the interface list */ + ip_mc_add_src(in_dev, &mreqs->imr_multiaddr, omode, 1, + &mreqs->imr_sourceaddr, 1); +done: + if (leavegroup) + err = ip_mc_leave_group(sk, &imr); + return err; +} + +int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex) +{ + int err = 0; + struct ip_mreqn imr; + __be32 addr = msf->imsf_multiaddr; + struct ip_mc_socklist *pmc; + struct in_device *in_dev; + struct inet_sock *inet = inet_sk(sk); + struct ip_sf_socklist *newpsl, *psl; + struct net *net = sock_net(sk); + int leavegroup = 0; + + if (!ipv4_is_multicast(addr)) + return -EINVAL; + if (msf->imsf_fmode != MCAST_INCLUDE && + msf->imsf_fmode != MCAST_EXCLUDE) + return -EINVAL; + + ASSERT_RTNL(); + + imr.imr_multiaddr.s_addr = msf->imsf_multiaddr; + imr.imr_address.s_addr = msf->imsf_interface; + imr.imr_ifindex = ifindex; + in_dev = ip_mc_find_dev(net, &imr); + + if (!in_dev) { + err = -ENODEV; + goto done; + } + + /* special case - (INCLUDE, empty) == LEAVE_GROUP */ + if (msf->imsf_fmode == MCAST_INCLUDE && msf->imsf_numsrc == 0) { + leavegroup = 1; + goto done; + } + + for_each_pmc_rtnl(inet, pmc) { + if (pmc->multi.imr_multiaddr.s_addr == msf->imsf_multiaddr && + pmc->multi.imr_ifindex == imr.imr_ifindex) + break; + } + if (!pmc) { /* must have a prior join */ + err = -EINVAL; + goto done; + } + if (msf->imsf_numsrc) { + newpsl = sock_kmalloc(sk, struct_size(newpsl, sl_addr, + msf->imsf_numsrc), + GFP_KERNEL); + if (!newpsl) { + err = -ENOBUFS; + goto done; + } + newpsl->sl_max = newpsl->sl_count = msf->imsf_numsrc; + memcpy(newpsl->sl_addr, msf->imsf_slist_flex, + flex_array_size(msf, imsf_slist_flex, msf->imsf_numsrc)); + err = ip_mc_add_src(in_dev, &msf->imsf_multiaddr, + msf->imsf_fmode, newpsl->sl_count, newpsl->sl_addr, 0); + if (err) { + sock_kfree_s(sk, newpsl, + struct_size(newpsl, sl_addr, + newpsl->sl_max)); + goto done; + } + } else { + newpsl = NULL; + (void) 
ip_mc_add_src(in_dev, &msf->imsf_multiaddr, + msf->imsf_fmode, 0, NULL, 0); + } + psl = rtnl_dereference(pmc->sflist); + if (psl) { + (void) ip_mc_del_src(in_dev, &msf->imsf_multiaddr, pmc->sfmode, + psl->sl_count, psl->sl_addr, 0); + /* decrease mem now to avoid the memleak warning */ + atomic_sub(struct_size(psl, sl_addr, psl->sl_max), + &sk->sk_omem_alloc); + } else { + (void) ip_mc_del_src(in_dev, &msf->imsf_multiaddr, pmc->sfmode, + 0, NULL, 0); + } + rcu_assign_pointer(pmc->sflist, newpsl); + if (psl) + kfree_rcu(psl, rcu); + pmc->sfmode = msf->imsf_fmode; + err = 0; +done: + if (leavegroup) + err = ip_mc_leave_group(sk, &imr); + return err; +} +int ip_mc_msfget(struct sock *sk, struct ip_msfilter *msf, + sockptr_t optval, sockptr_t optlen) +{ + int err, len, count, copycount, msf_size; + struct ip_mreqn imr; + __be32 addr = msf->imsf_multiaddr; + struct ip_mc_socklist *pmc; + struct in_device *in_dev; + struct inet_sock *inet = inet_sk(sk); + struct ip_sf_socklist *psl; + struct net *net = sock_net(sk); + + ASSERT_RTNL(); + + if (!ipv4_is_multicast(addr)) + return -EINVAL; + + imr.imr_multiaddr.s_addr = msf->imsf_multiaddr; + imr.imr_address.s_addr = msf->imsf_interface; + imr.imr_ifindex = 0; + in_dev = ip_mc_find_dev(net, &imr); + + if (!in_dev) { + err = -ENODEV; + goto done; + } + err = -EADDRNOTAVAIL; + + for_each_pmc_rtnl(inet, pmc) { + if (pmc->multi.imr_multiaddr.s_addr == msf->imsf_multiaddr && + pmc->multi.imr_ifindex == imr.imr_ifindex) + break; + } + if (!pmc) /* must have a prior join */ + goto done; + msf->imsf_fmode = pmc->sfmode; + psl = rtnl_dereference(pmc->sflist); + if (!psl) { + count = 0; + } else { + count = psl->sl_count; + } + copycount = count < msf->imsf_numsrc ? count : msf->imsf_numsrc; + len = flex_array_size(psl, sl_addr, copycount); + msf->imsf_numsrc = count; + msf_size = IP_MSFILTER_SIZE(copycount); + if (copy_to_sockptr(optlen, &msf_size, sizeof(int)) || + copy_to_sockptr(optval, msf, IP_MSFILTER_SIZE(0))) { + return -EFAULT; + } + if (len && + copy_to_sockptr_offset(optval, + offsetof(struct ip_msfilter, imsf_slist_flex), + psl->sl_addr, len)) + return -EFAULT; + return 0; +done: + return err; +} + +int ip_mc_gsfget(struct sock *sk, struct group_filter *gsf, + sockptr_t optval, size_t ss_offset) +{ + int i, count, copycount; + struct sockaddr_in *psin; + __be32 addr; + struct ip_mc_socklist *pmc; + struct inet_sock *inet = inet_sk(sk); + struct ip_sf_socklist *psl; + + ASSERT_RTNL(); + + psin = (struct sockaddr_in *)&gsf->gf_group; + if (psin->sin_family != AF_INET) + return -EINVAL; + addr = psin->sin_addr.s_addr; + if (!ipv4_is_multicast(addr)) + return -EINVAL; + + for_each_pmc_rtnl(inet, pmc) { + if (pmc->multi.imr_multiaddr.s_addr == addr && + pmc->multi.imr_ifindex == gsf->gf_interface) + break; + } + if (!pmc) /* must have a prior join */ + return -EADDRNOTAVAIL; + gsf->gf_fmode = pmc->sfmode; + psl = rtnl_dereference(pmc->sflist); + count = psl ? psl->sl_count : 0; + copycount = count < gsf->gf_numsrc ? 
count : gsf->gf_numsrc; + gsf->gf_numsrc = count; + for (i = 0; i < copycount; i++) { + struct sockaddr_storage ss; + + psin = (struct sockaddr_in *)&ss; + memset(&ss, 0, sizeof(ss)); + psin->sin_family = AF_INET; + psin->sin_addr.s_addr = psl->sl_addr[i]; + if (copy_to_sockptr_offset(optval, ss_offset, + &ss, sizeof(ss))) + return -EFAULT; + ss_offset += sizeof(ss); + } + return 0; +} + +/* + * check if a multicast source filter allows delivery for a given + */ +int ip_mc_sf_allow(struct sock *sk, __be32 loc_addr, __be32 rmt_addr, + int dif, int sdif) +{ + struct inet_sock *inet = inet_sk(sk); + struct ip_mc_socklist *pmc; + struct ip_sf_socklist *psl; + int i; + int ret; + + ret = 1; + if (!ipv4_is_multicast(loc_addr)) + goto out; + + rcu_read_lock(); + for_each_pmc_rcu(inet, pmc) { + if (pmc->multi.imr_multiaddr.s_addr == loc_addr && + (pmc->multi.imr_ifindex == dif || + (sdif && pmc->multi.imr_ifindex == sdif))) + break; + } + ret = inet->mc_all; + if (!pmc) + goto unlock; + psl = rcu_dereference(pmc->sflist); + ret = (pmc->sfmode == MCAST_EXCLUDE); + if (!psl) + goto unlock; + + for (i = 0; i < psl->sl_count; i++) { + if (psl->sl_addr[i] == rmt_addr) + break; + } + ret = 0; + if (pmc->sfmode == MCAST_INCLUDE && i >= psl->sl_count) + goto unlock; + if (pmc->sfmode == MCAST_EXCLUDE && i < psl->sl_count) + goto unlock; + ret = 1; +unlock: + rcu_read_unlock(); +out: + return ret; +} + +/* + * A socket is closing. + */ + +void ip_mc_drop_socket(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct ip_mc_socklist *iml; + struct net *net = sock_net(sk); + + if (!inet->mc_list) + return; + + rtnl_lock(); + while ((iml = rtnl_dereference(inet->mc_list)) != NULL) { + struct in_device *in_dev; + + inet->mc_list = iml->next_rcu; + in_dev = inetdev_by_index(net, iml->multi.imr_ifindex); + (void) ip_mc_leave_src(sk, iml, in_dev); + if (in_dev) + ip_mc_dec_group(in_dev, iml->multi.imr_multiaddr.s_addr); + /* decrease mem now to avoid the memleak warning */ + atomic_sub(sizeof(*iml), &sk->sk_omem_alloc); + kfree_rcu(iml, rcu); + } + rtnl_unlock(); +} + +/* called with rcu_read_lock() */ +int ip_check_mc_rcu(struct in_device *in_dev, __be32 mc_addr, __be32 src_addr, u8 proto) +{ + struct ip_mc_list *im; + struct ip_mc_list __rcu **mc_hash; + struct ip_sf_list *psf; + int rv = 0; + + mc_hash = rcu_dereference(in_dev->mc_hash); + if (mc_hash) { + u32 hash = hash_32((__force u32)mc_addr, MC_HASH_SZ_LOG); + + for (im = rcu_dereference(mc_hash[hash]); + im != NULL; + im = rcu_dereference(im->next_hash)) { + if (im->multiaddr == mc_addr) + break; + } + } else { + for_each_pmc_rcu(in_dev, im) { + if (im->multiaddr == mc_addr) + break; + } + } + if (im && proto == IPPROTO_IGMP) { + rv = 1; + } else if (im) { + if (src_addr) { + spin_lock_bh(&im->lock); + for (psf = im->sources; psf; psf = psf->sf_next) { + if (psf->sf_inaddr == src_addr) + break; + } + if (psf) + rv = psf->sf_count[MCAST_INCLUDE] || + psf->sf_count[MCAST_EXCLUDE] != + im->sfcount[MCAST_EXCLUDE]; + else + rv = im->sfcount[MCAST_EXCLUDE] != 0; + spin_unlock_bh(&im->lock); + } else + rv = 1; /* unspecified source; tentatively allow */ + } + return rv; +} + +#if defined(CONFIG_PROC_FS) +struct igmp_mc_iter_state { + struct seq_net_private p; + struct net_device *dev; + struct in_device *in_dev; +}; + +#define igmp_mc_seq_private(seq) ((struct igmp_mc_iter_state *)(seq)->private) + +static inline struct ip_mc_list *igmp_mc_get_first(struct seq_file *seq) +{ + struct net *net = seq_file_net(seq); + struct ip_mc_list *im = NULL; + 
struct igmp_mc_iter_state *state = igmp_mc_seq_private(seq); + + state->in_dev = NULL; + for_each_netdev_rcu(net, state->dev) { + struct in_device *in_dev; + + in_dev = __in_dev_get_rcu(state->dev); + if (!in_dev) + continue; + im = rcu_dereference(in_dev->mc_list); + if (im) { + state->in_dev = in_dev; + break; + } + } + return im; +} + +static struct ip_mc_list *igmp_mc_get_next(struct seq_file *seq, struct ip_mc_list *im) +{ + struct igmp_mc_iter_state *state = igmp_mc_seq_private(seq); + + im = rcu_dereference(im->next_rcu); + while (!im) { + state->dev = next_net_device_rcu(state->dev); + if (!state->dev) { + state->in_dev = NULL; + break; + } + state->in_dev = __in_dev_get_rcu(state->dev); + if (!state->in_dev) + continue; + im = rcu_dereference(state->in_dev->mc_list); + } + return im; +} + +static struct ip_mc_list *igmp_mc_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ip_mc_list *im = igmp_mc_get_first(seq); + if (im) + while (pos && (im = igmp_mc_get_next(seq, im)) != NULL) + --pos; + return pos ? NULL : im; +} + +static void *igmp_mc_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + rcu_read_lock(); + return *pos ? igmp_mc_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *igmp_mc_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip_mc_list *im; + if (v == SEQ_START_TOKEN) + im = igmp_mc_get_first(seq); + else + im = igmp_mc_get_next(seq, v); + ++*pos; + return im; +} + +static void igmp_mc_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + struct igmp_mc_iter_state *state = igmp_mc_seq_private(seq); + + state->in_dev = NULL; + state->dev = NULL; + rcu_read_unlock(); +} + +static int igmp_mc_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "Idx\tDevice : Count Querier\tGroup Users Timer\tReporter\n"); + else { + struct ip_mc_list *im = v; + struct igmp_mc_iter_state *state = igmp_mc_seq_private(seq); + char *querier; + long delta; + +#ifdef CONFIG_IP_MULTICAST + querier = IGMP_V1_SEEN(state->in_dev) ? "V1" : + IGMP_V2_SEEN(state->in_dev) ? "V2" : + "V3"; +#else + querier = "NONE"; +#endif + + if (rcu_access_pointer(state->in_dev->mc_list) == im) { + seq_printf(seq, "%d\t%-10s: %5d %7s\n", + state->dev->ifindex, state->dev->name, state->in_dev->mc_count, querier); + } + + delta = im->timer.expires - jiffies; + seq_printf(seq, + "\t\t\t\t%08X %5d %d:%08lX\t\t%d\n", + im->multiaddr, im->users, + im->tm_running, + im->tm_running ? 
jiffies_delta_to_clock_t(delta) : 0, + im->reporter); + } + return 0; +} + +static const struct seq_operations igmp_mc_seq_ops = { + .start = igmp_mc_seq_start, + .next = igmp_mc_seq_next, + .stop = igmp_mc_seq_stop, + .show = igmp_mc_seq_show, +}; + +struct igmp_mcf_iter_state { + struct seq_net_private p; + struct net_device *dev; + struct in_device *idev; + struct ip_mc_list *im; +}; + +#define igmp_mcf_seq_private(seq) ((struct igmp_mcf_iter_state *)(seq)->private) + +static inline struct ip_sf_list *igmp_mcf_get_first(struct seq_file *seq) +{ + struct net *net = seq_file_net(seq); + struct ip_sf_list *psf = NULL; + struct ip_mc_list *im = NULL; + struct igmp_mcf_iter_state *state = igmp_mcf_seq_private(seq); + + state->idev = NULL; + state->im = NULL; + for_each_netdev_rcu(net, state->dev) { + struct in_device *idev; + idev = __in_dev_get_rcu(state->dev); + if (unlikely(!idev)) + continue; + im = rcu_dereference(idev->mc_list); + if (likely(im)) { + spin_lock_bh(&im->lock); + psf = im->sources; + if (likely(psf)) { + state->im = im; + state->idev = idev; + break; + } + spin_unlock_bh(&im->lock); + } + } + return psf; +} + +static struct ip_sf_list *igmp_mcf_get_next(struct seq_file *seq, struct ip_sf_list *psf) +{ + struct igmp_mcf_iter_state *state = igmp_mcf_seq_private(seq); + + psf = psf->sf_next; + while (!psf) { + spin_unlock_bh(&state->im->lock); + state->im = state->im->next; + while (!state->im) { + state->dev = next_net_device_rcu(state->dev); + if (!state->dev) { + state->idev = NULL; + goto out; + } + state->idev = __in_dev_get_rcu(state->dev); + if (!state->idev) + continue; + state->im = rcu_dereference(state->idev->mc_list); + } + if (!state->im) + break; + spin_lock_bh(&state->im->lock); + psf = state->im->sources; + } +out: + return psf; +} + +static struct ip_sf_list *igmp_mcf_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ip_sf_list *psf = igmp_mcf_get_first(seq); + if (psf) + while (pos && (psf = igmp_mcf_get_next(seq, psf)) != NULL) + --pos; + return pos ? NULL : psf; +} + +static void *igmp_mcf_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + rcu_read_lock(); + return *pos ? 
igmp_mcf_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *igmp_mcf_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip_sf_list *psf; + if (v == SEQ_START_TOKEN) + psf = igmp_mcf_get_first(seq); + else + psf = igmp_mcf_get_next(seq, v); + ++*pos; + return psf; +} + +static void igmp_mcf_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + struct igmp_mcf_iter_state *state = igmp_mcf_seq_private(seq); + if (likely(state->im)) { + spin_unlock_bh(&state->im->lock); + state->im = NULL; + } + state->idev = NULL; + state->dev = NULL; + rcu_read_unlock(); +} + +static int igmp_mcf_seq_show(struct seq_file *seq, void *v) +{ + struct ip_sf_list *psf = v; + struct igmp_mcf_iter_state *state = igmp_mcf_seq_private(seq); + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Idx Device MCA SRC INC EXC\n"); + } else { + seq_printf(seq, + "%3d %6.6s 0x%08x " + "0x%08x %6lu %6lu\n", + state->dev->ifindex, state->dev->name, + ntohl(state->im->multiaddr), + ntohl(psf->sf_inaddr), + psf->sf_count[MCAST_INCLUDE], + psf->sf_count[MCAST_EXCLUDE]); + } + return 0; +} + +static const struct seq_operations igmp_mcf_seq_ops = { + .start = igmp_mcf_seq_start, + .next = igmp_mcf_seq_next, + .stop = igmp_mcf_seq_stop, + .show = igmp_mcf_seq_show, +}; + +static int __net_init igmp_net_init(struct net *net) +{ + struct proc_dir_entry *pde; + int err; + + pde = proc_create_net("igmp", 0444, net->proc_net, &igmp_mc_seq_ops, + sizeof(struct igmp_mc_iter_state)); + if (!pde) + goto out_igmp; + pde = proc_create_net("mcfilter", 0444, net->proc_net, + &igmp_mcf_seq_ops, sizeof(struct igmp_mcf_iter_state)); + if (!pde) + goto out_mcfilter; + err = inet_ctl_sock_create(&net->ipv4.mc_autojoin_sk, AF_INET, + SOCK_DGRAM, 0, net); + if (err < 0) { + pr_err("Failed to initialize the IGMP autojoin socket (err %d)\n", + err); + goto out_sock; + } + + return 0; + +out_sock: + remove_proc_entry("mcfilter", net->proc_net); +out_mcfilter: + remove_proc_entry("igmp", net->proc_net); +out_igmp: + return -ENOMEM; +} + +static void __net_exit igmp_net_exit(struct net *net) +{ + remove_proc_entry("mcfilter", net->proc_net); + remove_proc_entry("igmp", net->proc_net); + inet_ctl_sock_destroy(net->ipv4.mc_autojoin_sk); +} + +static struct pernet_operations igmp_net_ops = { + .init = igmp_net_init, + .exit = igmp_net_exit, +}; +#endif + +static int igmp_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct in_device *in_dev; + + switch (event) { + case NETDEV_RESEND_IGMP: + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) + ip_mc_rejoin_groups(in_dev); + break; + default: + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block igmp_notifier = { + .notifier_call = igmp_netdev_event, +}; + +int __init igmp_mc_init(void) +{ +#if defined(CONFIG_PROC_FS) + int err; + + err = register_pernet_subsys(&igmp_net_ops); + if (err) + return err; + err = register_netdevice_notifier(&igmp_notifier); + if (err) + goto reg_notif_fail; + return 0; + +reg_notif_fail: + unregister_pernet_subsys(&igmp_net_ops); + return err; +#else + return register_netdevice_notifier(&igmp_notifier); +#endif +} diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c new file mode 100644 index 000000000..79fa19a36 --- /dev/null +++ b/net/ipv4/inet_connection_sock.c @@ -0,0 +1,1501 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. 
INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Support for INET connection oriented protocols. + * + * Authors: See the TCP sources + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) +/* match_sk*_wildcard == true: IPV6_ADDR_ANY equals to any IPv6 addresses + * if IPv6 only, and any IPv4 addresses + * if not IPv6 only + * match_sk*_wildcard == false: addresses must be exactly the same, i.e. + * IPV6_ADDR_ANY only equals to IPV6_ADDR_ANY, + * and 0.0.0.0 equals to 0.0.0.0 only + */ +static bool ipv6_rcv_saddr_equal(const struct in6_addr *sk1_rcv_saddr6, + const struct in6_addr *sk2_rcv_saddr6, + __be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr, + bool sk1_ipv6only, bool sk2_ipv6only, + bool match_sk1_wildcard, + bool match_sk2_wildcard) +{ + int addr_type = ipv6_addr_type(sk1_rcv_saddr6); + int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED; + + /* if both are mapped, treat as IPv4 */ + if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) { + if (!sk2_ipv6only) { + if (sk1_rcv_saddr == sk2_rcv_saddr) + return true; + return (match_sk1_wildcard && !sk1_rcv_saddr) || + (match_sk2_wildcard && !sk2_rcv_saddr); + } + return false; + } + + if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY) + return true; + + if (addr_type2 == IPV6_ADDR_ANY && match_sk2_wildcard && + !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED)) + return true; + + if (addr_type == IPV6_ADDR_ANY && match_sk1_wildcard && + !(sk1_ipv6only && addr_type2 == IPV6_ADDR_MAPPED)) + return true; + + if (sk2_rcv_saddr6 && + ipv6_addr_equal(sk1_rcv_saddr6, sk2_rcv_saddr6)) + return true; + + return false; +} +#endif + +/* match_sk*_wildcard == true: 0.0.0.0 equals to any IPv4 addresses + * match_sk*_wildcard == false: addresses must be exactly the same, i.e. 
+ * 0.0.0.0 only equals to 0.0.0.0 + */ +static bool ipv4_rcv_saddr_equal(__be32 sk1_rcv_saddr, __be32 sk2_rcv_saddr, + bool sk2_ipv6only, bool match_sk1_wildcard, + bool match_sk2_wildcard) +{ + if (!sk2_ipv6only) { + if (sk1_rcv_saddr == sk2_rcv_saddr) + return true; + return (match_sk1_wildcard && !sk1_rcv_saddr) || + (match_sk2_wildcard && !sk2_rcv_saddr); + } + return false; +} + +bool inet_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2, + bool match_wildcard) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return ipv6_rcv_saddr_equal(&sk->sk_v6_rcv_saddr, + inet6_rcv_saddr(sk2), + sk->sk_rcv_saddr, + sk2->sk_rcv_saddr, + ipv6_only_sock(sk), + ipv6_only_sock(sk2), + match_wildcard, + match_wildcard); +#endif + return ipv4_rcv_saddr_equal(sk->sk_rcv_saddr, sk2->sk_rcv_saddr, + ipv6_only_sock(sk2), match_wildcard, + match_wildcard); +} +EXPORT_SYMBOL(inet_rcv_saddr_equal); + +bool inet_rcv_saddr_any(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return ipv6_addr_any(&sk->sk_v6_rcv_saddr); +#endif + return !sk->sk_rcv_saddr; +} + +void inet_get_local_port_range(const struct net *net, int *low, int *high) +{ + unsigned int seq; + + do { + seq = read_seqbegin(&net->ipv4.ip_local_ports.lock); + + *low = net->ipv4.ip_local_ports.range[0]; + *high = net->ipv4.ip_local_ports.range[1]; + } while (read_seqretry(&net->ipv4.ip_local_ports.lock, seq)); +} +EXPORT_SYMBOL(inet_get_local_port_range); + +void inet_sk_get_local_port_range(const struct sock *sk, int *low, int *high) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct net *net = sock_net(sk); + int lo, hi, sk_lo, sk_hi; + + inet_get_local_port_range(net, &lo, &hi); + + sk_lo = inet->local_port_range.lo; + sk_hi = inet->local_port_range.hi; + + if (unlikely(lo <= sk_lo && sk_lo <= hi)) + lo = sk_lo; + if (unlikely(lo <= sk_hi && sk_hi <= hi)) + hi = sk_hi; + + *low = lo; + *high = hi; +} +EXPORT_SYMBOL(inet_sk_get_local_port_range); + +static bool inet_use_bhash2_on_bind(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + int addr_type = ipv6_addr_type(&sk->sk_v6_rcv_saddr); + + return addr_type != IPV6_ADDR_ANY && + addr_type != IPV6_ADDR_MAPPED; + } +#endif + return sk->sk_rcv_saddr != htonl(INADDR_ANY); +} + +static bool inet_bind_conflict(const struct sock *sk, struct sock *sk2, + kuid_t sk_uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) +{ + int bound_dev_if2; + + if (sk == sk2) + return false; + + bound_dev_if2 = READ_ONCE(sk2->sk_bound_dev_if); + + if (!sk->sk_bound_dev_if || !bound_dev_if2 || + sk->sk_bound_dev_if == bound_dev_if2) { + if (sk->sk_reuse && sk2->sk_reuse && + sk2->sk_state != TCP_LISTEN) { + if (!relax || (!reuseport_ok && sk->sk_reuseport && + sk2->sk_reuseport && reuseport_cb_ok && + (sk2->sk_state == TCP_TIME_WAIT || + uid_eq(sk_uid, sock_i_uid(sk2))))) + return true; + } else if (!reuseport_ok || !sk->sk_reuseport || + !sk2->sk_reuseport || !reuseport_cb_ok || + (sk2->sk_state != TCP_TIME_WAIT && + !uid_eq(sk_uid, sock_i_uid(sk2)))) { + return true; + } + } + return false; +} + +static bool __inet_bhash2_conflict(const struct sock *sk, struct sock *sk2, + kuid_t sk_uid, bool relax, + bool reuseport_cb_ok, bool reuseport_ok) +{ + if (sk->sk_family == AF_INET && ipv6_only_sock(sk2)) + return false; + + return inet_bind_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok); +} + +static bool inet_bhash2_conflict(const struct sock *sk, + const struct 
inet_bind2_bucket *tb2, + kuid_t sk_uid, + bool relax, bool reuseport_cb_ok, + bool reuseport_ok) +{ + struct inet_timewait_sock *tw2; + struct sock *sk2; + + sk_for_each_bound_bhash2(sk2, &tb2->owners) { + if (__inet_bhash2_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok)) + return true; + } + + twsk_for_each_bound_bhash2(tw2, &tb2->deathrow) { + sk2 = (struct sock *)tw2; + + if (__inet_bhash2_conflict(sk, sk2, sk_uid, relax, + reuseport_cb_ok, reuseport_ok)) + return true; + } + + return false; +} + +/* This should be called only when the tb and tb2 hashbuckets' locks are held */ +static int inet_csk_bind_conflict(const struct sock *sk, + const struct inet_bind_bucket *tb, + const struct inet_bind2_bucket *tb2, /* may be null */ + bool relax, bool reuseport_ok) +{ + bool reuseport_cb_ok; + struct sock_reuseport *reuseport_cb; + kuid_t uid = sock_i_uid((struct sock *)sk); + + rcu_read_lock(); + reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); + /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */ + reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); + rcu_read_unlock(); + + /* + * Unlike other sk lookup places we do not check + * for sk_net here, since _all_ the socks listed + * in tb->owners and tb2->owners list belong + * to the same net - the one this bucket belongs to. + */ + + if (!inet_use_bhash2_on_bind(sk)) { + struct sock *sk2; + + sk_for_each_bound(sk2, &tb->owners) + if (inet_bind_conflict(sk, sk2, uid, relax, + reuseport_cb_ok, reuseport_ok) && + inet_rcv_saddr_equal(sk, sk2, true)) + return true; + + return false; + } + + /* Conflicts with an existing IPV6_ADDR_ANY (if ipv6) or INADDR_ANY (if + * ipv4) should have been checked already. We need to do these two + * checks separately because their spinlocks have to be acquired/released + * independently of each other, to prevent possible deadlocks + */ + return tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok, + reuseport_ok); +} + +/* Determine if there is a bind conflict with an existing IPV6_ADDR_ANY (if ipv6) or + * INADDR_ANY (if ipv4) socket. + * + * Caller must hold bhash hashbucket lock with local bh disabled, to protect + * against concurrent binds on the port for addr any + */ +static bool inet_bhash2_addr_any_conflict(const struct sock *sk, int port, int l3mdev, + bool relax, bool reuseport_ok) +{ + kuid_t uid = sock_i_uid((struct sock *)sk); + const struct net *net = sock_net(sk); + struct sock_reuseport *reuseport_cb; + struct inet_bind_hashbucket *head2; + struct inet_bind2_bucket *tb2; + bool reuseport_cb_ok; + + rcu_read_lock(); + reuseport_cb = rcu_dereference(sk->sk_reuseport_cb); + /* paired with WRITE_ONCE() in __reuseport_(add|detach)_closed_sock */ + reuseport_cb_ok = !reuseport_cb || READ_ONCE(reuseport_cb->num_closed_socks); + rcu_read_unlock(); + + head2 = inet_bhash2_addr_any_hashbucket(sk, net, port); + + spin_lock(&head2->lock); + + inet_bind_bucket_for_each(tb2, &head2->chain) + if (inet_bind2_bucket_match_addr_any(tb2, net, port, l3mdev, sk)) + break; + + if (tb2 && inet_bhash2_conflict(sk, tb2, uid, relax, reuseport_cb_ok, + reuseport_ok)) { + spin_unlock(&head2->lock); + return true; + } + + spin_unlock(&head2->lock); + return false; +} + +/* + * Find an open port number for the socket. Returns with the + * inet_bind_hashbucket locks held if successful. 
+ */ +static struct inet_bind_hashbucket * +inet_csk_find_open_port(const struct sock *sk, struct inet_bind_bucket **tb_ret, + struct inet_bind2_bucket **tb2_ret, + struct inet_bind_hashbucket **head2_ret, int *port_ret) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + int i, low, high, attempt_half, port, l3mdev; + struct inet_bind_hashbucket *head, *head2; + struct net *net = sock_net(sk); + struct inet_bind2_bucket *tb2; + struct inet_bind_bucket *tb; + u32 remaining, offset; + bool relax = false; + + l3mdev = inet_sk_bound_l3mdev(sk); +ports_exhausted: + attempt_half = (sk->sk_reuse == SK_CAN_REUSE) ? 1 : 0; +other_half_scan: + inet_sk_get_local_port_range(sk, &low, &high); + high++; /* [32768, 60999] -> [32768, 61000[ */ + if (high - low < 4) + attempt_half = 0; + if (attempt_half) { + int half = low + (((high - low) >> 2) << 1); + + if (attempt_half == 1) + high = half; + else + low = half; + } + remaining = high - low; + if (likely(remaining > 1)) + remaining &= ~1U; + + offset = prandom_u32_max(remaining); + /* __inet_hash_connect() favors ports having @low parity + * We do the opposite to not pollute connect() users. + */ + offset |= 1U; + +other_parity_scan: + port = low + offset; + for (i = 0; i < remaining; i += 2, port += 2) { + if (unlikely(port >= high)) + port -= remaining; + if (inet_is_local_reserved_port(net, port)) + continue; + head = &hinfo->bhash[inet_bhashfn(net, port, + hinfo->bhash_size)]; + spin_lock_bh(&head->lock); + if (inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, relax, false)) + goto next_port; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + inet_bind_bucket_for_each(tb, &head->chain) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { + if (!inet_csk_bind_conflict(sk, tb, tb2, + relax, false)) + goto success; + spin_unlock(&head2->lock); + goto next_port; + } + tb = NULL; + goto success; +next_port: + spin_unlock_bh(&head->lock); + cond_resched(); + } + + offset--; + if (!(offset & 1)) + goto other_parity_scan; + + if (attempt_half == 1) { + /* OK we now try the upper half of the range */ + attempt_half = 2; + goto other_half_scan; + } + + if (READ_ONCE(net->ipv4.sysctl_ip_autobind_reuse) && !relax) { + /* We still have a chance to connect to different destinations */ + relax = true; + goto ports_exhausted; + } + return NULL; +success: + *port_ret = port; + *tb_ret = tb; + *tb2_ret = tb2; + *head2_ret = head2; + return head; +} + +static inline int sk_reuseport_match(struct inet_bind_bucket *tb, + struct sock *sk) +{ + kuid_t uid = sock_i_uid(sk); + + if (tb->fastreuseport <= 0) + return 0; + if (!sk->sk_reuseport) + return 0; + if (rcu_access_pointer(sk->sk_reuseport_cb)) + return 0; + if (!uid_eq(tb->fastuid, uid)) + return 0; + /* We only need to check the rcv_saddr if this tb was once marked + * without fastreuseport and then was reset, as we can only know that + * the fast_*rcv_saddr doesn't have any conflicts with the socks on the + * owners list. 
+ */ + if (tb->fastreuseport == FASTREUSEPORT_ANY) + return 1; +#if IS_ENABLED(CONFIG_IPV6) + if (tb->fast_sk_family == AF_INET6) + return ipv6_rcv_saddr_equal(&tb->fast_v6_rcv_saddr, + inet6_rcv_saddr(sk), + tb->fast_rcv_saddr, + sk->sk_rcv_saddr, + tb->fast_ipv6_only, + ipv6_only_sock(sk), true, false); +#endif + return ipv4_rcv_saddr_equal(tb->fast_rcv_saddr, sk->sk_rcv_saddr, + ipv6_only_sock(sk), true, false); +} + +void inet_csk_update_fastreuse(struct inet_bind_bucket *tb, + struct sock *sk) +{ + kuid_t uid = sock_i_uid(sk); + bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; + + if (hlist_empty(&tb->owners)) { + tb->fastreuse = reuse; + if (sk->sk_reuseport) { + tb->fastreuseport = FASTREUSEPORT_ANY; + tb->fastuid = uid; + tb->fast_rcv_saddr = sk->sk_rcv_saddr; + tb->fast_ipv6_only = ipv6_only_sock(sk); + tb->fast_sk_family = sk->sk_family; +#if IS_ENABLED(CONFIG_IPV6) + tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr; +#endif + } else { + tb->fastreuseport = 0; + } + } else { + if (!reuse) + tb->fastreuse = 0; + if (sk->sk_reuseport) { + /* We didn't match or we don't have fastreuseport set on + * the tb, but we have sk_reuseport set on this socket + * and we know that there are no bind conflicts with + * this socket in this tb, so reset our tb's reuseport + * settings so that any subsequent sockets that match + * our current socket will be put on the fast path. + * + * If we reset we need to set FASTREUSEPORT_STRICT so we + * do extra checking for all subsequent sk_reuseport + * socks. + */ + if (!sk_reuseport_match(tb, sk)) { + tb->fastreuseport = FASTREUSEPORT_STRICT; + tb->fastuid = uid; + tb->fast_rcv_saddr = sk->sk_rcv_saddr; + tb->fast_ipv6_only = ipv6_only_sock(sk); + tb->fast_sk_family = sk->sk_family; +#if IS_ENABLED(CONFIG_IPV6) + tb->fast_v6_rcv_saddr = sk->sk_v6_rcv_saddr; +#endif + } + } else { + tb->fastreuseport = 0; + } + } +} + +/* Obtain a reference to a local port for the given sock, + * if snum is zero it means select any available local port. 
+ * We try to allocate an odd port (and leave even ports for connect()) + */ +int inet_csk_get_port(struct sock *sk, unsigned short snum) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + bool reuse = sk->sk_reuse && sk->sk_state != TCP_LISTEN; + bool found_port = false, check_bind_conflict = true; + bool bhash_created = false, bhash2_created = false; + int ret = -EADDRINUSE, port = snum, l3mdev; + struct inet_bind_hashbucket *head, *head2; + struct inet_bind2_bucket *tb2 = NULL; + struct inet_bind_bucket *tb = NULL; + bool head2_lock_acquired = false; + struct net *net = sock_net(sk); + + l3mdev = inet_sk_bound_l3mdev(sk); + + if (!port) { + head = inet_csk_find_open_port(sk, &tb, &tb2, &head2, &port); + if (!head) + return ret; + + head2_lock_acquired = true; + + if (tb && tb2) + goto success; + found_port = true; + } else { + head = &hinfo->bhash[inet_bhashfn(net, port, + hinfo->bhash_size)]; + spin_lock_bh(&head->lock); + inet_bind_bucket_for_each(tb, &head->chain) + if (inet_bind_bucket_match(tb, net, port, l3mdev)) + break; + } + + if (!tb) { + tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, net, + head, port, l3mdev); + if (!tb) + goto fail_unlock; + bhash_created = true; + } + + if (!found_port) { + if (!hlist_empty(&tb->owners)) { + if (sk->sk_reuse == SK_FORCE_REUSE || + (tb->fastreuse > 0 && reuse) || + sk_reuseport_match(tb, sk)) + check_bind_conflict = false; + } + + if (check_bind_conflict && inet_use_bhash2_on_bind(sk)) { + if (inet_bhash2_addr_any_conflict(sk, port, l3mdev, true, true)) + goto fail_unlock; + } + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + head2_lock_acquired = true; + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + } + + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, + net, head2, port, l3mdev, sk); + if (!tb2) + goto fail_unlock; + bhash2_created = true; + } + + if (!found_port && check_bind_conflict) { + if (inet_csk_bind_conflict(sk, tb, tb2, true, true)) + goto fail_unlock; + } + +success: + inet_csk_update_fastreuse(tb, sk); + + if (!inet_csk(sk)->icsk_bind_hash) + inet_bind_hash(sk, tb, tb2, port); + WARN_ON(inet_csk(sk)->icsk_bind_hash != tb); + WARN_ON(inet_csk(sk)->icsk_bind2_hash != tb2); + ret = 0; + +fail_unlock: + if (ret) { + if (bhash_created) + inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + if (bhash2_created) + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, + tb2); + } + if (head2_lock_acquired) + spin_unlock(&head2->lock); + spin_unlock_bh(&head->lock); + return ret; +} +EXPORT_SYMBOL_GPL(inet_csk_get_port); + +/* + * Wait for an incoming connection, avoid race conditions. This must be called + * with the socket locked. + */ +static int inet_csk_wait_for_connect(struct sock *sk, long timeo) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + DEFINE_WAIT(wait); + int err; + + /* + * True wake-one mechanism for incoming connections: only + * one process gets woken up, not the 'whole herd'. + * Since we do not 'race & poll' for established sockets + * anymore, the common case will execute the loop only once. + * + * Subtle issue: "add_wait_queue_exclusive()" will be added + * after any current non-exclusive waiters, and we know that + * it will always _stay_ after any new non-exclusive waiters + * because all non-exclusive waiters are added at the + * beginning of the wait-queue. 
As such, it's ok to "drop" + * our exclusiveness temporarily when we get woken up without + * having to remove and re-insert us on the wait queue. + */ + for (;;) { + prepare_to_wait_exclusive(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + release_sock(sk); + if (reqsk_queue_empty(&icsk->icsk_accept_queue)) + timeo = schedule_timeout(timeo); + sched_annotate_sleep(); + lock_sock(sk); + err = 0; + if (!reqsk_queue_empty(&icsk->icsk_accept_queue)) + break; + err = -EINVAL; + if (sk->sk_state != TCP_LISTEN) + break; + err = sock_intr_errno(timeo); + if (signal_pending(current)) + break; + err = -EAGAIN; + if (!timeo) + break; + } + finish_wait(sk_sleep(sk), &wait); + return err; +} + +/* + * This will accept the next outstanding connection. + */ +struct sock *inet_csk_accept(struct sock *sk, int flags, int *err, bool kern) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct request_sock_queue *queue = &icsk->icsk_accept_queue; + struct request_sock *req; + struct sock *newsk; + int error; + + lock_sock(sk); + + /* We need to make sure that this socket is listening, + * and that it has something pending. + */ + error = -EINVAL; + if (sk->sk_state != TCP_LISTEN) + goto out_err; + + /* Find already established connection */ + if (reqsk_queue_empty(queue)) { + long timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + /* If this is a non blocking socket don't sleep */ + error = -EAGAIN; + if (!timeo) + goto out_err; + + error = inet_csk_wait_for_connect(sk, timeo); + if (error) + goto out_err; + } + req = reqsk_queue_remove(queue, sk); + newsk = req->sk; + + if (sk->sk_protocol == IPPROTO_TCP && + tcp_rsk(req)->tfo_listener) { + spin_lock_bh(&queue->fastopenq.lock); + if (tcp_rsk(req)->tfo_listener) { + /* We are still waiting for the final ACK from 3WHS + * so can't free req now. Instead, we set req->sk to + * NULL to signify that the child socket is taken + * so reqsk_fastopen_remove() will free the req + * when 3WHS finishes (or is aborted). + */ + req->sk = NULL; + req = NULL; + } + spin_unlock_bh(&queue->fastopenq.lock); + } + +out: + release_sock(sk); + if (newsk && mem_cgroup_sockets_enabled) { + int amt; + + /* atomically get the memory usage, set and charge the + * newsk->sk_memcg. + */ + lock_sock(newsk); + + /* The socket has not been accepted yet, no need to look at + * newsk->sk_wmem_queued. + */ + amt = sk_mem_pages(newsk->sk_forward_alloc + + atomic_read(&newsk->sk_rmem_alloc)); + mem_cgroup_sk_alloc(newsk); + if (newsk->sk_memcg && amt) + mem_cgroup_charge_skmem(newsk->sk_memcg, amt, + GFP_KERNEL | __GFP_NOFAIL); + + release_sock(newsk); + } + if (req) + reqsk_put(req); + + if (newsk) + inet_init_csk_locks(newsk); + + return newsk; +out_err: + newsk = NULL; + req = NULL; + *err = error; + goto out; +} +EXPORT_SYMBOL(inet_csk_accept); + +/* + * Using different timers for retransmit, delayed acks and probes + * We may wish use just one timer maintaining a list of expire jiffies + * to optimize. 
+ */ +void inet_csk_init_xmit_timers(struct sock *sk, + void (*retransmit_handler)(struct timer_list *t), + void (*delack_handler)(struct timer_list *t), + void (*keepalive_handler)(struct timer_list *t)) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + timer_setup(&icsk->icsk_retransmit_timer, retransmit_handler, 0); + timer_setup(&icsk->icsk_delack_timer, delack_handler, 0); + timer_setup(&sk->sk_timer, keepalive_handler, 0); + icsk->icsk_pending = icsk->icsk_ack.pending = 0; +} +EXPORT_SYMBOL(inet_csk_init_xmit_timers); + +void inet_csk_clear_xmit_timers(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + icsk->icsk_pending = icsk->icsk_ack.pending = 0; + + sk_stop_timer(sk, &icsk->icsk_retransmit_timer); + sk_stop_timer(sk, &icsk->icsk_delack_timer); + sk_stop_timer(sk, &sk->sk_timer); +} +EXPORT_SYMBOL(inet_csk_clear_xmit_timers); + +void inet_csk_delete_keepalive_timer(struct sock *sk) +{ + sk_stop_timer(sk, &sk->sk_timer); +} +EXPORT_SYMBOL(inet_csk_delete_keepalive_timer); + +void inet_csk_reset_keepalive_timer(struct sock *sk, unsigned long len) +{ + sk_reset_timer(sk, &sk->sk_timer, jiffies + len); +} +EXPORT_SYMBOL(inet_csk_reset_keepalive_timer); + +struct dst_entry *inet_csk_route_req(const struct sock *sk, + struct flowi4 *fl4, + const struct request_sock *req) +{ + const struct inet_request_sock *ireq = inet_rsk(req); + struct net *net = read_pnet(&ireq->ireq_net); + struct ip_options_rcu *opt; + struct rtable *rt; + + rcu_read_lock(); + opt = rcu_dereference(ireq->ireq_opt); + + flowi4_init_output(fl4, ireq->ir_iif, ireq->ir_mark, + RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, + sk->sk_protocol, inet_sk_flowi_flags(sk), + (opt && opt->opt.srr) ? opt->opt.faddr : ireq->ir_rmt_addr, + ireq->ir_loc_addr, ireq->ir_rmt_port, + htons(ireq->ir_num), sk->sk_uid); + security_req_classify_flow(req, flowi4_to_flowi_common(fl4)); + rt = ip_route_output_flow(net, fl4, sk); + if (IS_ERR(rt)) + goto no_route; + if (opt && opt->opt.is_strictroute && rt->rt_uses_gateway) + goto route_err; + rcu_read_unlock(); + return &rt->dst; + +route_err: + ip_rt_put(rt); +no_route: + rcu_read_unlock(); + __IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + return NULL; +} +EXPORT_SYMBOL_GPL(inet_csk_route_req); + +struct dst_entry *inet_csk_route_child_sock(const struct sock *sk, + struct sock *newsk, + const struct request_sock *req) +{ + const struct inet_request_sock *ireq = inet_rsk(req); + struct net *net = read_pnet(&ireq->ireq_net); + struct inet_sock *newinet = inet_sk(newsk); + struct ip_options_rcu *opt; + struct flowi4 *fl4; + struct rtable *rt; + + opt = rcu_dereference(ireq->ireq_opt); + fl4 = &newinet->cork.fl.u.ip4; + + flowi4_init_output(fl4, ireq->ir_iif, ireq->ir_mark, + RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, + sk->sk_protocol, inet_sk_flowi_flags(sk), + (opt && opt->opt.srr) ? 
opt->opt.faddr : ireq->ir_rmt_addr, + ireq->ir_loc_addr, ireq->ir_rmt_port, + htons(ireq->ir_num), sk->sk_uid); + security_req_classify_flow(req, flowi4_to_flowi_common(fl4)); + rt = ip_route_output_flow(net, fl4, sk); + if (IS_ERR(rt)) + goto no_route; + if (opt && opt->opt.is_strictroute && rt->rt_uses_gateway) + goto route_err; + return &rt->dst; + +route_err: + ip_rt_put(rt); +no_route: + __IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + return NULL; +} +EXPORT_SYMBOL_GPL(inet_csk_route_child_sock); + +/* Decide when to expire the request and when to resend SYN-ACK */ +static void syn_ack_recalc(struct request_sock *req, + const int max_syn_ack_retries, + const u8 rskq_defer_accept, + int *expire, int *resend) +{ + if (!rskq_defer_accept) { + *expire = req->num_timeout >= max_syn_ack_retries; + *resend = 1; + return; + } + *expire = req->num_timeout >= max_syn_ack_retries && + (!inet_rsk(req)->acked || req->num_timeout >= rskq_defer_accept); + /* Do not resend while waiting for data after ACK, + * start to resend on end of deferring period to give + * last chance for data or ACK to create established socket. + */ + *resend = !inet_rsk(req)->acked || + req->num_timeout >= rskq_defer_accept - 1; +} + +int inet_rtx_syn_ack(const struct sock *parent, struct request_sock *req) +{ + int err = req->rsk_ops->rtx_syn_ack(parent, req); + + if (!err) + req->num_retrans++; + return err; +} +EXPORT_SYMBOL(inet_rtx_syn_ack); + +static struct request_sock *inet_reqsk_clone(struct request_sock *req, + struct sock *sk) +{ + struct sock *req_sk, *nreq_sk; + struct request_sock *nreq; + + nreq = kmem_cache_alloc(req->rsk_ops->slab, GFP_ATOMIC | __GFP_NOWARN); + if (!nreq) { + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQFAILURE); + + /* paired with refcount_inc_not_zero() in reuseport_migrate_sock() */ + sock_put(sk); + return NULL; + } + + req_sk = req_to_sk(req); + nreq_sk = req_to_sk(nreq); + + memcpy(nreq_sk, req_sk, + offsetof(struct sock, sk_dontcopy_begin)); + memcpy(&nreq_sk->sk_dontcopy_end, &req_sk->sk_dontcopy_end, + req->rsk_ops->obj_size - offsetof(struct sock, sk_dontcopy_end)); + + sk_node_init(&nreq_sk->sk_node); + nreq_sk->sk_tx_queue_mapping = req_sk->sk_tx_queue_mapping; +#ifdef CONFIG_SOCK_RX_QUEUE_MAPPING + nreq_sk->sk_rx_queue_mapping = req_sk->sk_rx_queue_mapping; +#endif + nreq_sk->sk_incoming_cpu = req_sk->sk_incoming_cpu; + + nreq->rsk_listener = sk; + + /* We need not acquire fastopenq->lock + * because the child socket is locked in inet_csk_listen_stop(). 
+ */ + if (sk->sk_protocol == IPPROTO_TCP && tcp_rsk(nreq)->tfo_listener) + rcu_assign_pointer(tcp_sk(nreq->sk)->fastopen_rsk, nreq); + + return nreq; +} + +static void reqsk_queue_migrated(struct request_sock_queue *queue, + const struct request_sock *req) +{ + if (req->num_timeout == 0) + atomic_inc(&queue->young); + atomic_inc(&queue->qlen); +} + +static void reqsk_migrate_reset(struct request_sock *req) +{ + req->saved_syn = NULL; +#if IS_ENABLED(CONFIG_IPV6) + inet_rsk(req)->ipv6_opt = NULL; + inet_rsk(req)->pktopts = NULL; +#else + inet_rsk(req)->ireq_opt = NULL; +#endif +} + +/* return true if req was found in the ehash table */ +static bool reqsk_queue_unlink(struct request_sock *req) +{ + struct sock *sk = req_to_sk(req); + bool found = false; + + if (sk_hashed(sk)) { + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + spinlock_t *lock = inet_ehash_lockp(hashinfo, req->rsk_hash); + + spin_lock(lock); + found = __sk_nulls_del_node_init_rcu(sk); + spin_unlock(lock); + } + if (timer_pending(&req->rsk_timer) && del_timer_sync(&req->rsk_timer)) + reqsk_put(req); + return found; +} + +bool inet_csk_reqsk_queue_drop(struct sock *sk, struct request_sock *req) +{ + bool unlinked = reqsk_queue_unlink(req); + + if (unlinked) { + reqsk_queue_removed(&inet_csk(sk)->icsk_accept_queue, req); + reqsk_put(req); + } + return unlinked; +} +EXPORT_SYMBOL(inet_csk_reqsk_queue_drop); + +void inet_csk_reqsk_queue_drop_and_put(struct sock *sk, struct request_sock *req) +{ + inet_csk_reqsk_queue_drop(sk, req); + reqsk_put(req); +} +EXPORT_SYMBOL(inet_csk_reqsk_queue_drop_and_put); + +static void reqsk_timer_handler(struct timer_list *t) +{ + struct request_sock *req = from_timer(req, t, rsk_timer); + struct request_sock *nreq = NULL, *oreq = req; + struct sock *sk_listener = req->rsk_listener; + struct inet_connection_sock *icsk; + struct request_sock_queue *queue; + struct net *net; + int max_syn_ack_retries, qlen, expire = 0, resend = 0; + + if (inet_sk_state_load(sk_listener) != TCP_LISTEN) { + struct sock *nsk; + + nsk = reuseport_migrate_sock(sk_listener, req_to_sk(req), NULL); + if (!nsk) + goto drop; + + nreq = inet_reqsk_clone(req, nsk); + if (!nreq) + goto drop; + + /* The new timer for the cloned req can decrease the 2 + * by calling inet_csk_reqsk_queue_drop_and_put(), so + * hold another count to prevent use-after-free and + * call reqsk_put() just before return. + */ + refcount_set(&nreq->rsk_refcnt, 2 + 1); + timer_setup(&nreq->rsk_timer, reqsk_timer_handler, TIMER_PINNED); + reqsk_queue_migrated(&inet_csk(nsk)->icsk_accept_queue, req); + + req = nreq; + sk_listener = nsk; + } + + icsk = inet_csk(sk_listener); + net = sock_net(sk_listener); + max_syn_ack_retries = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(net->ipv4.sysctl_tcp_synack_retries); + /* Normally all the openreqs are young and become mature + * (i.e. converted to established socket) for first timeout. + * If synack was not acknowledged for 1 second, it means + * one of the following things: synack was lost, ack was lost, + * rtt is high or nobody planned to ack (i.e. synflood). + * When server is a bit loaded, queue is populated with old + * open requests, reducing effective size of queue. + * When server is well loaded, queue size reduces to zero + * after several minutes of work. It is not synflood, + * it is normal operation. The solution is pruning + * too old entries overriding normal timeout, when + * situation becomes dangerous. 
+ * + * Essentially, we reserve half of room for young + * embrions; and abort old ones without pity, if old + * ones are about to clog our table. + */ + queue = &icsk->icsk_accept_queue; + qlen = reqsk_queue_len(queue); + if ((qlen << 1) > max(8U, READ_ONCE(sk_listener->sk_max_ack_backlog))) { + int young = reqsk_queue_len_young(queue) << 1; + + while (max_syn_ack_retries > 2) { + if (qlen < young) + break; + max_syn_ack_retries--; + young <<= 1; + } + } + syn_ack_recalc(req, max_syn_ack_retries, READ_ONCE(queue->rskq_defer_accept), + &expire, &resend); + req->rsk_ops->syn_ack_timeout(req); + if (!expire && + (!resend || + !inet_rtx_syn_ack(sk_listener, req) || + inet_rsk(req)->acked)) { + if (req->num_timeout++ == 0) + atomic_dec(&queue->young); + mod_timer(&req->rsk_timer, jiffies + reqsk_timeout(req, TCP_RTO_MAX)); + + if (!nreq) + return; + + if (!inet_ehash_insert(req_to_sk(nreq), req_to_sk(oreq), NULL)) { + /* delete timer */ + inet_csk_reqsk_queue_drop(sk_listener, nreq); + goto no_ownership; + } + + __NET_INC_STATS(net, LINUX_MIB_TCPMIGRATEREQSUCCESS); + reqsk_migrate_reset(oreq); + reqsk_queue_removed(&inet_csk(oreq->rsk_listener)->icsk_accept_queue, oreq); + reqsk_put(oreq); + + reqsk_put(nreq); + return; + } + + /* Even if we can clone the req, we may need not retransmit any more + * SYN+ACKs (nreq->num_timeout > max_syn_ack_retries, etc), or another + * CPU may win the "own_req" race so that inet_ehash_insert() fails. + */ + if (nreq) { + __NET_INC_STATS(net, LINUX_MIB_TCPMIGRATEREQFAILURE); +no_ownership: + reqsk_migrate_reset(nreq); + reqsk_queue_removed(queue, nreq); + __reqsk_free(nreq); + } + +drop: + inet_csk_reqsk_queue_drop_and_put(oreq->rsk_listener, oreq); +} + +static void reqsk_queue_hash_req(struct request_sock *req, + unsigned long timeout) +{ + timer_setup(&req->rsk_timer, reqsk_timer_handler, TIMER_PINNED); + mod_timer(&req->rsk_timer, jiffies + timeout); + + inet_ehash_insert(req_to_sk(req), NULL, NULL); + /* before letting lookups find us, make sure all req fields + * are committed to memory and refcnt initialized. 
+ */ + smp_wmb(); + refcount_set(&req->rsk_refcnt, 2 + 1); +} + +void inet_csk_reqsk_queue_hash_add(struct sock *sk, struct request_sock *req, + unsigned long timeout) +{ + reqsk_queue_hash_req(req, timeout); + inet_csk_reqsk_queue_added(sk); +} +EXPORT_SYMBOL_GPL(inet_csk_reqsk_queue_hash_add); + +static void inet_clone_ulp(const struct request_sock *req, struct sock *newsk, + const gfp_t priority) +{ + struct inet_connection_sock *icsk = inet_csk(newsk); + + if (!icsk->icsk_ulp_ops) + return; + + if (icsk->icsk_ulp_ops->clone) + icsk->icsk_ulp_ops->clone(req, newsk, priority); +} + +/** + * inet_csk_clone_lock - clone an inet socket, and lock its clone + * @sk: the socket to clone + * @req: request_sock + * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * + * Caller must unlock socket even in error path (bh_unlock_sock(newsk)) + */ +struct sock *inet_csk_clone_lock(const struct sock *sk, + const struct request_sock *req, + const gfp_t priority) +{ + struct sock *newsk = sk_clone_lock(sk, priority); + + if (newsk) { + struct inet_connection_sock *newicsk = inet_csk(newsk); + + inet_sk_set_state(newsk, TCP_SYN_RECV); + newicsk->icsk_bind_hash = NULL; + newicsk->icsk_bind2_hash = NULL; + + inet_sk(newsk)->inet_dport = inet_rsk(req)->ir_rmt_port; + inet_sk(newsk)->inet_num = inet_rsk(req)->ir_num; + inet_sk(newsk)->inet_sport = htons(inet_rsk(req)->ir_num); + + /* listeners have SOCK_RCU_FREE, not the children */ + sock_reset_flag(newsk, SOCK_RCU_FREE); + + inet_sk(newsk)->mc_list = NULL; + + newsk->sk_mark = inet_rsk(req)->ir_mark; + atomic64_set(&newsk->sk_cookie, + atomic64_read(&inet_rsk(req)->ir_cookie)); + + newicsk->icsk_retransmits = 0; + newicsk->icsk_backoff = 0; + newicsk->icsk_probes_out = 0; + newicsk->icsk_probes_tstamp = 0; + + /* Deinitialize accept_queue to trap illegal accesses. */ + memset(&newicsk->icsk_accept_queue, 0, sizeof(newicsk->icsk_accept_queue)); + + inet_clone_ulp(req, newsk, priority); + + security_inet_csk_clone(newsk, req); + } + return newsk; +} +EXPORT_SYMBOL_GPL(inet_csk_clone_lock); + +/* + * At this point, there should be no process reference to this + * socket, and thus no user references at all. Therefore we + * can assume the socket waitqueue is inactive and nobody will + * try to jump onto it. + */ +void inet_csk_destroy_sock(struct sock *sk) +{ + WARN_ON(sk->sk_state != TCP_CLOSE); + WARN_ON(!sock_flag(sk, SOCK_DEAD)); + + /* It cannot be in hash table! */ + WARN_ON(!sk_unhashed(sk)); + + /* If it has not 0 inet_sk(sk)->inet_num, it must be bound */ + WARN_ON(inet_sk(sk)->inet_num && !inet_csk(sk)->icsk_bind_hash); + + sk->sk_prot->destroy(sk); + + sk_stream_kill_queues(sk); + + xfrm_sk_free_policy(sk); + + sk_refcnt_debug_release(sk); + + this_cpu_dec(*sk->sk_prot->orphan_count); + + sock_put(sk); +} +EXPORT_SYMBOL(inet_csk_destroy_sock); + +/* This function allows to force a closure of a socket after the call to + * tcp/dccp_create_openreq_child(). 
+ */ +void inet_csk_prepare_forced_close(struct sock *sk) + __releases(&sk->sk_lock.slock) +{ + /* sk_clone_lock locked the socket and set refcnt to 2 */ + bh_unlock_sock(sk); + sock_put(sk); + inet_csk_prepare_for_destroy_sock(sk); + inet_sk(sk)->inet_num = 0; +} +EXPORT_SYMBOL(inet_csk_prepare_forced_close); + +static int inet_ulp_can_listen(const struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ulp_ops && !icsk->icsk_ulp_ops->clone) + return -EINVAL; + + return 0; +} + +int inet_csk_listen_start(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_sock *inet = inet_sk(sk); + int err; + + err = inet_ulp_can_listen(sk); + if (unlikely(err)) + return err; + + reqsk_queue_alloc(&icsk->icsk_accept_queue); + + sk->sk_ack_backlog = 0; + inet_csk_delack_init(sk); + + /* There is race window here: we announce ourselves listening, + * but this transition is still not validated by get_port(). + * It is OK, because this socket enters to hash table only + * after validation is complete. + */ + inet_sk_state_store(sk, TCP_LISTEN); + err = sk->sk_prot->get_port(sk, inet->inet_num); + if (!err) { + inet->inet_sport = htons(inet->inet_num); + + sk_dst_reset(sk); + err = sk->sk_prot->hash(sk); + + if (likely(!err)) + return 0; + } + + inet_sk_set_state(sk, TCP_CLOSE); + return err; +} +EXPORT_SYMBOL_GPL(inet_csk_listen_start); + +static void inet_child_forget(struct sock *sk, struct request_sock *req, + struct sock *child) +{ + sk->sk_prot->disconnect(child, O_NONBLOCK); + + sock_orphan(child); + + this_cpu_inc(*sk->sk_prot->orphan_count); + + if (sk->sk_protocol == IPPROTO_TCP && tcp_rsk(req)->tfo_listener) { + BUG_ON(rcu_access_pointer(tcp_sk(child)->fastopen_rsk) != req); + BUG_ON(sk != req->rsk_listener); + + /* Paranoid, to prevent race condition if + * an inbound pkt destined for child is + * blocked by sock lock in tcp_v4_rcv(). + * Also to satisfy an assertion in + * tcp_v4_destroy_sock(). + */ + RCU_INIT_POINTER(tcp_sk(child)->fastopen_rsk, NULL); + } + inet_csk_destroy_sock(child); +} + +struct sock *inet_csk_reqsk_queue_add(struct sock *sk, + struct request_sock *req, + struct sock *child) +{ + struct request_sock_queue *queue = &inet_csk(sk)->icsk_accept_queue; + + spin_lock(&queue->rskq_lock); + if (unlikely(sk->sk_state != TCP_LISTEN)) { + inet_child_forget(sk, req, child); + child = NULL; + } else { + req->sk = child; + req->dl_next = NULL; + if (queue->rskq_accept_head == NULL) + WRITE_ONCE(queue->rskq_accept_head, req); + else + queue->rskq_accept_tail->dl_next = req; + queue->rskq_accept_tail = req; + sk_acceptq_added(sk); + } + spin_unlock(&queue->rskq_lock); + return child; +} +EXPORT_SYMBOL(inet_csk_reqsk_queue_add); + +struct sock *inet_csk_complete_hashdance(struct sock *sk, struct sock *child, + struct request_sock *req, bool own_req) +{ + if (own_req) { + inet_csk_reqsk_queue_drop(req->rsk_listener, req); + reqsk_queue_removed(&inet_csk(req->rsk_listener)->icsk_accept_queue, req); + + if (sk != req->rsk_listener) { + /* another listening sk has been selected, + * migrate the req to it. 
+ */ + struct request_sock *nreq; + + /* hold a refcnt for the nreq->rsk_listener + * which is assigned in inet_reqsk_clone() + */ + sock_hold(sk); + nreq = inet_reqsk_clone(req, sk); + if (!nreq) { + inet_child_forget(sk, req, child); + goto child_put; + } + + refcount_set(&nreq->rsk_refcnt, 1); + if (inet_csk_reqsk_queue_add(sk, nreq, child)) { + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQSUCCESS); + reqsk_migrate_reset(req); + reqsk_put(req); + return child; + } + + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQFAILURE); + reqsk_migrate_reset(nreq); + __reqsk_free(nreq); + } else if (inet_csk_reqsk_queue_add(sk, req, child)) { + return child; + } + } + /* Too bad, another child took ownership of the request, undo. */ +child_put: + bh_unlock_sock(child); + sock_put(child); + return NULL; +} +EXPORT_SYMBOL(inet_csk_complete_hashdance); + +/* + * This routine closes sockets which have been at least partially + * opened, but not yet accepted. + */ +void inet_csk_listen_stop(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct request_sock_queue *queue = &icsk->icsk_accept_queue; + struct request_sock *next, *req; + + /* Following specs, it would be better either to send FIN + * (and enter FIN-WAIT-1, it is normal close) + * or to send active reset (abort). + * Certainly, it is pretty dangerous while synflood, but it is + * bad justification for our negligence 8) + * To be honest, we are not able to make either + * of the variants now. --ANK + */ + while ((req = reqsk_queue_remove(queue, sk)) != NULL) { + struct sock *child = req->sk, *nsk; + struct request_sock *nreq; + + local_bh_disable(); + bh_lock_sock(child); + WARN_ON(sock_owned_by_user(child)); + sock_hold(child); + + nsk = reuseport_migrate_sock(sk, child, NULL); + if (nsk) { + nreq = inet_reqsk_clone(req, nsk); + if (nreq) { + refcount_set(&nreq->rsk_refcnt, 1); + + if (inet_csk_reqsk_queue_add(nsk, nreq, child)) { + __NET_INC_STATS(sock_net(nsk), + LINUX_MIB_TCPMIGRATEREQSUCCESS); + reqsk_migrate_reset(req); + } else { + __NET_INC_STATS(sock_net(nsk), + LINUX_MIB_TCPMIGRATEREQFAILURE); + reqsk_migrate_reset(nreq); + __reqsk_free(nreq); + } + + /* inet_csk_reqsk_queue_add() has already + * called inet_child_forget() on failure case. + */ + goto skip_child_forget; + } + } + + inet_child_forget(sk, req, child); +skip_child_forget: + reqsk_put(req); + bh_unlock_sock(child); + local_bh_enable(); + sock_put(child); + + cond_resched(); + } + if (queue->fastopenq.rskq_rst_head) { + /* Free all the reqs queued in rskq_rst_head. 
*/ + spin_lock_bh(&queue->fastopenq.lock); + req = queue->fastopenq.rskq_rst_head; + queue->fastopenq.rskq_rst_head = NULL; + spin_unlock_bh(&queue->fastopenq.lock); + while (req != NULL) { + next = req->dl_next; + reqsk_put(req); + req = next; + } + } + WARN_ON_ONCE(sk->sk_ack_backlog); +} +EXPORT_SYMBOL_GPL(inet_csk_listen_stop); + +void inet_csk_addr2sockaddr(struct sock *sk, struct sockaddr *uaddr) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)uaddr; + const struct inet_sock *inet = inet_sk(sk); + + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = inet->inet_daddr; + sin->sin_port = inet->inet_dport; +} +EXPORT_SYMBOL_GPL(inet_csk_addr2sockaddr); + +static struct dst_entry *inet_csk_rebuild_route(struct sock *sk, struct flowi *fl) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct ip_options_rcu *inet_opt; + __be32 daddr = inet->inet_daddr; + struct flowi4 *fl4; + struct rtable *rt; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + fl4 = &fl->u.ip4; + rt = ip_route_output_ports(sock_net(sk), fl4, sk, daddr, + inet->inet_saddr, inet->inet_dport, + inet->inet_sport, sk->sk_protocol, + RT_CONN_FLAGS(sk), sk->sk_bound_dev_if); + if (IS_ERR(rt)) + rt = NULL; + if (rt) + sk_setup_caps(sk, &rt->dst); + rcu_read_unlock(); + + return &rt->dst; +} + +struct dst_entry *inet_csk_update_pmtu(struct sock *sk, u32 mtu) +{ + struct dst_entry *dst = __sk_dst_check(sk, 0); + struct inet_sock *inet = inet_sk(sk); + + if (!dst) { + dst = inet_csk_rebuild_route(sk, &inet->cork.fl); + if (!dst) + goto out; + } + dst->ops->update_pmtu(dst, sk, NULL, mtu, true); + + dst = __sk_dst_check(sk, 0); + if (!dst) + dst = inet_csk_rebuild_route(sk, &inet->cork.fl); +out: + return dst; +} +EXPORT_SYMBOL_GPL(inet_csk_update_pmtu); diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c new file mode 100644 index 000000000..f7426926a --- /dev/null +++ b/net/ipv4/inet_diag.c @@ -0,0 +1,1485 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * inet_diag.c Module for monitoring INET transport protocols sockets. 
+ * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +static const struct inet_diag_handler **inet_diag_table; + +struct inet_diag_entry { + const __be32 *saddr; + const __be32 *daddr; + u16 sport; + u16 dport; + u16 family; + u16 userlocks; + u32 ifindex; + u32 mark; +#ifdef CONFIG_SOCK_CGROUP_DATA + u64 cgroup_id; +#endif +}; + +static DEFINE_MUTEX(inet_diag_table_mutex); + +static const struct inet_diag_handler *inet_diag_lock_handler(int proto) +{ + if (proto < 0 || proto >= IPPROTO_MAX) { + mutex_lock(&inet_diag_table_mutex); + return ERR_PTR(-ENOENT); + } + + if (!inet_diag_table[proto]) + sock_load_diag_module(AF_INET, proto); + + mutex_lock(&inet_diag_table_mutex); + if (!inet_diag_table[proto]) + return ERR_PTR(-ENOENT); + + return inet_diag_table[proto]; +} + +static void inet_diag_unlock_handler(const struct inet_diag_handler *handler) +{ + mutex_unlock(&inet_diag_table_mutex); +} + +void inet_diag_msg_common_fill(struct inet_diag_msg *r, struct sock *sk) +{ + r->idiag_family = sk->sk_family; + + r->id.idiag_sport = htons(sk->sk_num); + r->id.idiag_dport = sk->sk_dport; + r->id.idiag_if = sk->sk_bound_dev_if; + sock_diag_save_cookie(sk, r->id.idiag_cookie); + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + *(struct in6_addr *)r->id.idiag_src = sk->sk_v6_rcv_saddr; + *(struct in6_addr *)r->id.idiag_dst = sk->sk_v6_daddr; + } else +#endif + { + memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); + memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); + + r->id.idiag_src[0] = sk->sk_rcv_saddr; + r->id.idiag_dst[0] = sk->sk_daddr; + } +} +EXPORT_SYMBOL_GPL(inet_diag_msg_common_fill); + +static size_t inet_sk_attr_size(struct sock *sk, + const struct inet_diag_req_v2 *req, + bool net_admin) +{ + const struct inet_diag_handler *handler; + size_t aux = 0; + + handler = inet_diag_table[req->sdiag_protocol]; + if (handler && handler->idiag_get_aux_size) + aux = handler->idiag_get_aux_size(sk, net_admin); + + return nla_total_size(sizeof(struct tcp_info)) + + nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + + nla_total_size(SK_MEMINFO_VARS * sizeof(u32)) + + nla_total_size(TCP_CA_NAME_MAX) + + nla_total_size(sizeof(struct tcpvegas_info)) + + aux + + 64; +} + +int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, + struct inet_diag_msg *r, int ext, + struct user_namespace *user_ns, + bool net_admin) +{ + const struct inet_sock *inet = inet_sk(sk); + struct inet_diag_sockopt inet_sockopt; + + if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown)) + goto errout; + + /* IPv6 dual-stack sockets use inet->tos for IPv4 connections, + * hence this needs to be included regardless of socket family. 
+ */ + if (ext & (1 << (INET_DIAG_TOS - 1))) + if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0) + goto errout; + +#if IS_ENABLED(CONFIG_IPV6) + if (r->idiag_family == AF_INET6) { + if (ext & (1 << (INET_DIAG_TCLASS - 1))) + if (nla_put_u8(skb, INET_DIAG_TCLASS, + inet6_sk(sk)->tclass) < 0) + goto errout; + + if (((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) && + nla_put_u8(skb, INET_DIAG_SKV6ONLY, ipv6_only_sock(sk))) + goto errout; + } +#endif + + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, READ_ONCE(sk->sk_mark))) + goto errout; + + if (ext & (1 << (INET_DIAG_CLASS_ID - 1)) || + ext & (1 << (INET_DIAG_TCLASS - 1))) { + u32 classid = 0; + +#ifdef CONFIG_SOCK_CGROUP_DATA + classid = sock_cgroup_classid(&sk->sk_cgrp_data); +#endif + /* Fallback to socket priority if class id isn't set. + * Classful qdiscs use it as direct reference to class. + * For cgroup2 classid is always zero. + */ + if (!classid) + classid = sk->sk_priority; + + if (nla_put_u32(skb, INET_DIAG_CLASS_ID, classid)) + goto errout; + } + +#ifdef CONFIG_SOCK_CGROUP_DATA + if (nla_put_u64_64bit(skb, INET_DIAG_CGROUP_ID, + cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)), + INET_DIAG_PAD)) + goto errout; +#endif + + r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); + r->idiag_inode = sock_i_ino(sk); + + memset(&inet_sockopt, 0, sizeof(inet_sockopt)); + inet_sockopt.recverr = inet->recverr; + inet_sockopt.is_icsk = inet->is_icsk; + inet_sockopt.freebind = inet->freebind; + inet_sockopt.hdrincl = inet->hdrincl; + inet_sockopt.mc_loop = inet->mc_loop; + inet_sockopt.transparent = inet->transparent; + inet_sockopt.mc_all = inet->mc_all; + inet_sockopt.nodefrag = inet->nodefrag; + inet_sockopt.bind_address_no_port = inet->bind_address_no_port; + inet_sockopt.recverr_rfc4884 = inet->recverr_rfc4884; + inet_sockopt.defer_connect = inet->defer_connect; + if (nla_put(skb, INET_DIAG_SOCKOPT, sizeof(inet_sockopt), + &inet_sockopt)) + goto errout; + + return 0; +errout: + return 1; +} +EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill); + +static int inet_diag_parse_attrs(const struct nlmsghdr *nlh, int hdrlen, + struct nlattr **req_nlas) +{ + struct nlattr *nla; + int remaining; + + nlmsg_for_each_attr(nla, nlh, hdrlen, remaining) { + int type = nla_type(nla); + + if (type == INET_DIAG_REQ_PROTOCOL && nla_len(nla) != sizeof(u32)) + return -EINVAL; + + if (type < __INET_DIAG_REQ_MAX) + req_nlas[type] = nla; + } + return 0; +} + +static int inet_diag_get_protocol(const struct inet_diag_req_v2 *req, + const struct inet_diag_dump_data *data) +{ + if (data->req_nlas[INET_DIAG_REQ_PROTOCOL]) + return nla_get_u32(data->req_nlas[INET_DIAG_REQ_PROTOCOL]); + return req->sdiag_protocol; +} + +#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info))) + +int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, + struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *req, + u16 nlmsg_flags, bool net_admin) +{ + const struct tcp_congestion_ops *ca_ops; + const struct inet_diag_handler *handler; + struct inet_diag_dump_data *cb_data; + int ext = req->idiag_ext; + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + struct nlattr *attr; + void *info = NULL; + + cb_data = cb->data; + handler = inet_diag_table[inet_diag_get_protocol(req, cb_data)]; + BUG_ON(!handler); + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + 
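inet_sk_diag_fill() is the producer side of the NETLINK_SOCK_DIAG interface that ss(8) uses. For context, a minimal userspace consumer that dumps IPv4/TCP sockets and reads the inet_diag_msg records built here — an illustrative sketch only, with error handling and attribute parsing elided:

/* Dump IPv4/TCP sockets via sock_diag and print a few header fields. */
#include <linux/inet_diag.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	struct {
		struct nlmsghdr nlh;
		struct inet_diag_req_v2 req;
	} msg = {
		.nlh = {
			.nlmsg_len	= sizeof(msg),
			.nlmsg_type	= SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP,
		},
		.req = {
			.sdiag_family	= AF_INET,
			.sdiag_protocol	= IPPROTO_TCP,
			.idiag_states	= ~0U,	/* every TCP state */
		},
	};
	char buf[32768];
	int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);

	if (fd < 0 || send(fd, &msg, sizeof(msg), 0) < 0)
		return 1;

	for (;;) {
		int len = recv(fd, buf, sizeof(buf), 0);
		struct nlmsghdr *h = (struct nlmsghdr *)buf;

		if (len <= 0)
			break;
		for (; NLMSG_OK(h, len); h = NLMSG_NEXT(h, len)) {
			struct inet_diag_msg *r = NLMSG_DATA(h);

			if (h->nlmsg_type == NLMSG_DONE ||
			    h->nlmsg_type == NLMSG_ERROR) {
				close(fd);
				return 0;
			}
			printf("state %u sport %u dport %u\n", r->idiag_state,
			       ntohs(r->id.idiag_sport),
			       ntohs(r->id.idiag_dport));
		}
	}
	close(fd);
	return 0;
}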
BUG_ON(!sk_fullsock(sk)); + + inet_diag_msg_common_fill(r, sk); + r->idiag_state = sk->sk_state; + r->idiag_timer = 0; + r->idiag_retrans = 0; + r->idiag_expires = 0; + + if (inet_diag_msg_attrs_fill(sk, skb, r, ext, + sk_user_ns(NETLINK_CB(cb->skb).sk), + net_admin)) + goto errout; + + if (ext & (1 << (INET_DIAG_MEMINFO - 1))) { + struct inet_diag_meminfo minfo = { + .idiag_rmem = sk_rmem_alloc_get(sk), + .idiag_wmem = READ_ONCE(sk->sk_wmem_queued), + .idiag_fmem = sk_forward_alloc_get(sk), + .idiag_tmem = sk_wmem_alloc_get(sk), + }; + + if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0) + goto errout; + } + + if (ext & (1 << (INET_DIAG_SKMEMINFO - 1))) + if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO)) + goto errout; + + /* + * RAW sockets might have user-defined protocols assigned, + * so report the one supplied on socket creation. + */ + if (sk->sk_type == SOCK_RAW) { + if (nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol)) + goto errout; + } + + if (!icsk) { + handler->idiag_get_info(sk, r, NULL); + goto out; + } + + if (icsk->icsk_pending == ICSK_TIME_RETRANS || + icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + r->idiag_timer = 1; + r->idiag_retrans = icsk->icsk_retransmits; + r->idiag_expires = + jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies); + } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) { + r->idiag_timer = 4; + r->idiag_retrans = icsk->icsk_probes_out; + r->idiag_expires = + jiffies_delta_to_msecs(icsk->icsk_timeout - jiffies); + } else if (timer_pending(&sk->sk_timer)) { + r->idiag_timer = 2; + r->idiag_retrans = icsk->icsk_probes_out; + r->idiag_expires = + jiffies_delta_to_msecs(sk->sk_timer.expires - jiffies); + } + + if ((ext & (1 << (INET_DIAG_INFO - 1))) && handler->idiag_info_size) { + attr = nla_reserve_64bit(skb, INET_DIAG_INFO, + handler->idiag_info_size, + INET_DIAG_PAD); + if (!attr) + goto errout; + + info = nla_data(attr); + } + + if (ext & (1 << (INET_DIAG_CONG - 1))) { + int err = 0; + + rcu_read_lock(); + ca_ops = READ_ONCE(icsk->icsk_ca_ops); + if (ca_ops) + err = nla_put_string(skb, INET_DIAG_CONG, ca_ops->name); + rcu_read_unlock(); + if (err < 0) + goto errout; + } + + handler->idiag_get_info(sk, r, info); + + if (ext & (1 << (INET_DIAG_INFO - 1)) && handler->idiag_get_aux) + if (handler->idiag_get_aux(sk, net_admin, skb) < 0) + goto errout; + + if (sk->sk_state < TCP_TIME_WAIT) { + union tcp_cc_info info; + size_t sz = 0; + int attr; + + rcu_read_lock(); + ca_ops = READ_ONCE(icsk->icsk_ca_ops); + if (ca_ops && ca_ops->get_info) + sz = ca_ops->get_info(sk, ext, &attr, &info); + rcu_read_unlock(); + if (sz && nla_put(skb, attr, sz, &info) < 0) + goto errout; + } + + /* Keep it at the end for potential retry with a larger skb, + * or else do best-effort fitting, which is only done for the + * first_nlmsg. 
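For TCP the INET_DIAG_INFO attribute reserved above is filled with struct tcp_info, the same blob a process can fetch for one of its own connected sockets; a small sketch of that userspace counterpart, illustrative only and meant to be called on a connected TCP socket:

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <sys/socket.h>

/* Print a few of the fields that also travel in INET_DIAG_INFO. */
void print_tcp_info(int tcp_fd)
{
	struct tcp_info ti;
	socklen_t len = sizeof(ti);

	if (getsockopt(tcp_fd, IPPROTO_TCP, TCP_INFO, &ti, &len) == 0)
		printf("state %u rtt %uus retransmits %u\n",
		       ti.tcpi_state, ti.tcpi_rtt, ti.tcpi_retransmits);
}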
+ */ + if (cb_data->bpf_stg_diag) { + bool first_nlmsg = ((unsigned char *)nlh == skb->data); + unsigned int prev_min_dump_alloc; + unsigned int total_nla_size = 0; + unsigned int msg_len; + int err; + + msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh; + err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb, + INET_DIAG_SK_BPF_STORAGES, + &total_nla_size); + + if (!err) + goto out; + + total_nla_size += msg_len; + prev_min_dump_alloc = cb->min_dump_alloc; + if (total_nla_size > prev_min_dump_alloc) + cb->min_dump_alloc = min_t(u32, total_nla_size, + MAX_DUMP_ALLOC_SIZE); + + if (!first_nlmsg) + goto errout; + + if (cb->min_dump_alloc > prev_min_dump_alloc) + /* Retry with pskb_expand_head() with + * __GFP_DIRECT_RECLAIM + */ + goto errout; + + WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc); + + /* Send what we have for this sk + * and move on to the next sk in the following + * dump() + */ + } + +out: + nlmsg_end(skb, nlh); + return 0; + +errout: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} +EXPORT_SYMBOL_GPL(inet_sk_diag_fill); + +static int inet_twsk_diag_fill(struct sock *sk, + struct sk_buff *skb, + struct netlink_callback *cb, + u16 nlmsg_flags, bool net_admin) +{ + struct inet_timewait_sock *tw = inet_twsk(sk); + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + long tmo; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type, + sizeof(*r), nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + BUG_ON(tw->tw_state != TCP_TIME_WAIT); + + inet_diag_msg_common_fill(r, sk); + r->idiag_retrans = 0; + + r->idiag_state = tw->tw_substate; + r->idiag_timer = 3; + tmo = tw->tw_timer.expires - jiffies; + r->idiag_expires = jiffies_delta_to_msecs(tmo); + r->idiag_rqueue = 0; + r->idiag_wqueue = 0; + r->idiag_uid = 0; + r->idiag_inode = 0; + + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, + tw->tw_mark)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + u16 nlmsg_flags, bool net_admin) +{ + struct request_sock *reqsk = inet_reqsk(sk); + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + long tmo; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + inet_diag_msg_common_fill(r, sk); + r->idiag_state = TCP_SYN_RECV; + r->idiag_timer = 1; + r->idiag_retrans = reqsk->num_retrans; + + BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) != + offsetof(struct sock, sk_cookie)); + + tmo = inet_reqsk(sk)->rsk_timer.expires - jiffies; + r->idiag_expires = jiffies_delta_to_msecs(tmo); + r->idiag_rqueue = 0; + r->idiag_wqueue = 0; + r->idiag_uid = 0; + r->idiag_inode = 0; + + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, + inet_rsk(reqsk)->ir_mark)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *r, + u16 nlmsg_flags, bool net_admin) +{ + if (sk->sk_state == TCP_TIME_WAIT) + return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); + + if (sk->sk_state == TCP_NEW_SYN_RECV) + return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin); + + return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags, + net_admin); +} + +struct sock 
*inet_diag_find_one_icsk(struct net *net, + struct inet_hashinfo *hashinfo, + const struct inet_diag_req_v2 *req) +{ + struct sock *sk; + + rcu_read_lock(); + if (req->sdiag_family == AF_INET) + sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[0], + req->id.idiag_dport, req->id.idiag_src[0], + req->id.idiag_sport, req->id.idiag_if); +#if IS_ENABLED(CONFIG_IPV6) + else if (req->sdiag_family == AF_INET6) { + if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && + ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src)) + sk = inet_lookup(net, hashinfo, NULL, 0, req->id.idiag_dst[3], + req->id.idiag_dport, req->id.idiag_src[3], + req->id.idiag_sport, req->id.idiag_if); + else + sk = inet6_lookup(net, hashinfo, NULL, 0, + (struct in6_addr *)req->id.idiag_dst, + req->id.idiag_dport, + (struct in6_addr *)req->id.idiag_src, + req->id.idiag_sport, + req->id.idiag_if); + } +#endif + else { + rcu_read_unlock(); + return ERR_PTR(-EINVAL); + } + rcu_read_unlock(); + if (!sk) + return ERR_PTR(-ENOENT); + + if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { + sock_gen_put(sk); + return ERR_PTR(-ENOENT); + } + + return sk; +} +EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk); + +int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + struct sk_buff *in_skb = cb->skb; + bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN); + struct net *net = sock_net(in_skb->sk); + struct sk_buff *rep; + struct sock *sk; + int err; + + sk = inet_diag_find_one_icsk(net, hashinfo, req); + if (IS_ERR(sk)) + return PTR_ERR(sk); + + rep = nlmsg_new(inet_sk_attr_size(sk, req, net_admin), GFP_KERNEL); + if (!rep) { + err = -ENOMEM; + goto out; + } + + err = sk_diag_fill(sk, rep, cb, req, 0, net_admin); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + nlmsg_free(rep); + goto out; + } + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + +out: + if (sk) + sock_gen_put(sk); + + return err; +} +EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk); + +static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb, + const struct nlmsghdr *nlh, + int hdrlen, + const struct inet_diag_req_v2 *req) +{ + const struct inet_diag_handler *handler; + struct inet_diag_dump_data dump_data; + int err, protocol; + + memset(&dump_data, 0, sizeof(dump_data)); + err = inet_diag_parse_attrs(nlh, hdrlen, dump_data.req_nlas); + if (err) + return err; + + protocol = inet_diag_get_protocol(req, &dump_data); + + handler = inet_diag_lock_handler(protocol); + if (IS_ERR(handler)) { + err = PTR_ERR(handler); + } else if (cmd == SOCK_DIAG_BY_FAMILY) { + struct netlink_callback cb = { + .nlh = nlh, + .skb = in_skb, + .data = &dump_data, + }; + err = handler->dump_one(&cb, req); + } else if (cmd == SOCK_DESTROY && handler->destroy) { + err = handler->destroy(in_skb, req); + } else { + err = -EOPNOTSUPP; + } + inet_diag_unlock_handler(handler); + + return err; +} + +static int bitstring_match(const __be32 *a1, const __be32 *a2, int bits) +{ + int words = bits >> 5; + + bits &= 0x1f; + + if (words) { + if (memcmp(a1, a2, words << 2)) + return 0; + } + if (bits) { + __be32 w1, w2; + __be32 mask; + + w1 = a1[words]; + w2 = a2[words]; + + mask = htonl((0xffffffff) << (32 - bits)); + + if ((w1 ^ w2) & mask) + return 0; + } + + return 1; +} + +static int inet_diag_bc_run(const struct nlattr *_bc, + const struct inet_diag_entry *entry) +{ + const void *bc = nla_data(_bc); + int len = nla_len(_bc); + + while (len > 0) { + int yes = 1; + const struct 
inet_diag_bc_op *op = bc; + + switch (op->code) { + case INET_DIAG_BC_NOP: + break; + case INET_DIAG_BC_JMP: + yes = 0; + break; + case INET_DIAG_BC_S_EQ: + yes = entry->sport == op[1].no; + break; + case INET_DIAG_BC_S_GE: + yes = entry->sport >= op[1].no; + break; + case INET_DIAG_BC_S_LE: + yes = entry->sport <= op[1].no; + break; + case INET_DIAG_BC_D_EQ: + yes = entry->dport == op[1].no; + break; + case INET_DIAG_BC_D_GE: + yes = entry->dport >= op[1].no; + break; + case INET_DIAG_BC_D_LE: + yes = entry->dport <= op[1].no; + break; + case INET_DIAG_BC_AUTO: + yes = !(entry->userlocks & SOCK_BINDPORT_LOCK); + break; + case INET_DIAG_BC_S_COND: + case INET_DIAG_BC_D_COND: { + const struct inet_diag_hostcond *cond; + const __be32 *addr; + + cond = (const struct inet_diag_hostcond *)(op + 1); + if (cond->port != -1 && + cond->port != (op->code == INET_DIAG_BC_S_COND ? + entry->sport : entry->dport)) { + yes = 0; + break; + } + + if (op->code == INET_DIAG_BC_S_COND) + addr = entry->saddr; + else + addr = entry->daddr; + + if (cond->family != AF_UNSPEC && + cond->family != entry->family) { + if (entry->family == AF_INET6 && + cond->family == AF_INET) { + if (addr[0] == 0 && addr[1] == 0 && + addr[2] == htonl(0xffff) && + bitstring_match(addr + 3, + cond->addr, + cond->prefix_len)) + break; + } + yes = 0; + break; + } + + if (cond->prefix_len == 0) + break; + if (bitstring_match(addr, cond->addr, + cond->prefix_len)) + break; + yes = 0; + break; + } + case INET_DIAG_BC_DEV_COND: { + u32 ifindex; + + ifindex = *((const u32 *)(op + 1)); + if (ifindex != entry->ifindex) + yes = 0; + break; + } + case INET_DIAG_BC_MARK_COND: { + struct inet_diag_markcond *cond; + + cond = (struct inet_diag_markcond *)(op + 1); + if ((entry->mark & cond->mask) != cond->mark) + yes = 0; + break; + } +#ifdef CONFIG_SOCK_CGROUP_DATA + case INET_DIAG_BC_CGROUP_COND: { + u64 cgroup_id; + + cgroup_id = get_unaligned((const u64 *)(op + 1)); + if (cgroup_id != entry->cgroup_id) + yes = 0; + break; + } +#endif + } + + if (yes) { + len -= op->yes; + bc += op->yes; + } else { + len -= op->no; + bc += op->no; + } + } + return len == 0; +} + +/* This helper is available for all sockets (ESTABLISH, TIMEWAIT, SYN_RECV) + */ +static void entry_fill_addrs(struct inet_diag_entry *entry, + const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + entry->saddr = sk->sk_v6_rcv_saddr.s6_addr32; + entry->daddr = sk->sk_v6_daddr.s6_addr32; + } else +#endif + { + entry->saddr = &sk->sk_rcv_saddr; + entry->daddr = &sk->sk_daddr; + } +} + +int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct inet_diag_entry entry; + + if (!bc) + return 1; + + entry.family = sk->sk_family; + entry_fill_addrs(&entry, sk); + entry.sport = inet->inet_num; + entry.dport = ntohs(inet->inet_dport); + entry.ifindex = sk->sk_bound_dev_if; + entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0; + if (sk_fullsock(sk)) + entry.mark = READ_ONCE(sk->sk_mark); + else if (sk->sk_state == TCP_NEW_SYN_RECV) + entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark; + else if (sk->sk_state == TCP_TIME_WAIT) + entry.mark = inet_twsk(sk)->tw_mark; + else + entry.mark = 0; +#ifdef CONFIG_SOCK_CGROUP_DATA + entry.cgroup_id = sk_fullsock(sk) ? 
+ cgroup_id(sock_cgroup_ptr(&sk->sk_cgrp_data)) : 0; +#endif + + return inet_diag_bc_run(bc, &entry); +} +EXPORT_SYMBOL_GPL(inet_diag_bc_sk); + +static int valid_cc(const void *bc, int len, int cc) +{ + while (len >= 0) { + const struct inet_diag_bc_op *op = bc; + + if (cc > len) + return 0; + if (cc == len) + return 1; + if (op->yes < 4 || op->yes & 3) + return 0; + len -= op->yes; + bc += op->yes; + } + return 0; +} + +/* data is u32 ifindex */ +static bool valid_devcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + /* Check ifindex space. */ + *min_len += sizeof(u32); + if (len < *min_len) + return false; + + return true; +} +/* Validate an inet_diag_hostcond. */ +static bool valid_hostcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + struct inet_diag_hostcond *cond; + int addr_len; + + /* Check hostcond space. */ + *min_len += sizeof(struct inet_diag_hostcond); + if (len < *min_len) + return false; + cond = (struct inet_diag_hostcond *)(op + 1); + + /* Check address family and address length. */ + switch (cond->family) { + case AF_UNSPEC: + addr_len = 0; + break; + case AF_INET: + addr_len = sizeof(struct in_addr); + break; + case AF_INET6: + addr_len = sizeof(struct in6_addr); + break; + default: + return false; + } + *min_len += addr_len; + if (len < *min_len) + return false; + + /* Check prefix length (in bits) vs address length (in bytes). */ + if (cond->prefix_len > 8 * addr_len) + return false; + + return true; +} + +/* Validate a port comparison operator. */ +static bool valid_port_comparison(const struct inet_diag_bc_op *op, + int len, int *min_len) +{ + /* Port comparisons put the port in a follow-on inet_diag_bc_op. */ + *min_len += sizeof(struct inet_diag_bc_op); + if (len < *min_len) + return false; + return true; +} + +static bool valid_markcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + *min_len += sizeof(struct inet_diag_markcond); + return len >= *min_len; +} + +#ifdef CONFIG_SOCK_CGROUP_DATA +static bool valid_cgroupcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + *min_len += sizeof(u64); + return len >= *min_len; +} +#endif + +static int inet_diag_bc_audit(const struct nlattr *attr, + const struct sk_buff *skb) +{ + bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN); + const void *bytecode, *bc; + int bytecode_len, len; + + if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op)) + return -EINVAL; + + bytecode = bc = nla_data(attr); + len = bytecode_len = nla_len(attr); + + while (len > 0) { + int min_len = sizeof(struct inet_diag_bc_op); + const struct inet_diag_bc_op *op = bc; + + switch (op->code) { + case INET_DIAG_BC_S_COND: + case INET_DIAG_BC_D_COND: + if (!valid_hostcond(bc, len, &min_len)) + return -EINVAL; + break; + case INET_DIAG_BC_DEV_COND: + if (!valid_devcond(bc, len, &min_len)) + return -EINVAL; + break; + case INET_DIAG_BC_S_EQ: + case INET_DIAG_BC_S_GE: + case INET_DIAG_BC_S_LE: + case INET_DIAG_BC_D_EQ: + case INET_DIAG_BC_D_GE: + case INET_DIAG_BC_D_LE: + if (!valid_port_comparison(bc, len, &min_len)) + return -EINVAL; + break; + case INET_DIAG_BC_MARK_COND: + if (!net_admin) + return -EPERM; + if (!valid_markcond(bc, len, &min_len)) + return -EINVAL; + break; +#ifdef CONFIG_SOCK_CGROUP_DATA + case INET_DIAG_BC_CGROUP_COND: + if (!valid_cgroupcond(bc, len, &min_len)) + return -EINVAL; + break; +#endif + case INET_DIAG_BC_AUTO: + case INET_DIAG_BC_JMP: + case INET_DIAG_BC_NOP: + break; + default: + return -EINVAL; + } + + if (op->code != INET_DIAG_BC_NOP) { + if 
(op->no < min_len || op->no > len + 4 || op->no & 3) + return -EINVAL; + if (op->no < len && + !valid_cc(bytecode, bytecode_len, len - op->no)) + return -EINVAL; + } + + if (op->yes < min_len || op->yes > len + 4 || op->yes & 3) + return -EINVAL; + bc += op->yes; + len -= op->yes; + } + return len == 0 ? 0 : -EINVAL; +} + +static void twsk_build_assert(void) +{ + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) != + offsetof(struct sock, sk_family)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_num) != + offsetof(struct inet_sock, inet_num)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_dport) != + offsetof(struct inet_sock, inet_dport)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_rcv_saddr) != + offsetof(struct inet_sock, inet_rcv_saddr)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_daddr) != + offsetof(struct inet_sock, inet_daddr)); + +#if IS_ENABLED(CONFIG_IPV6) + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_rcv_saddr) != + offsetof(struct sock, sk_v6_rcv_saddr)); + + BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_v6_daddr) != + offsetof(struct sock, sk_v6_daddr)); +#endif +} + +void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + struct inet_diag_dump_data *cb_data = cb->data; + struct net *net = sock_net(skb->sk); + u32 idiag_states = r->idiag_states; + int i, num, s_i, s_num; + struct nlattr *bc; + struct sock *sk; + + bc = cb_data->inet_diag_nla_bc; + if (idiag_states & TCPF_SYN_RECV) + idiag_states |= TCPF_NEW_SYN_RECV; + s_i = cb->args[1]; + s_num = num = cb->args[2]; + + if (cb->args[0] == 0) { + if (!(idiag_states & TCPF_LISTEN) || r->id.idiag_dport) + goto skip_listen_ht; + + for (i = s_i; i <= hashinfo->lhash2_mask; i++) { + struct inet_listen_hashbucket *ilb; + struct hlist_nulls_node *node; + + num = 0; + ilb = &hashinfo->lhash2[i]; + + spin_lock(&ilb->lock); + sk_nulls_for_each(sk, node, &ilb->nulls_head) { + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net)) + continue; + + if (num < s_num) { + num++; + continue; + } + + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_listen; + + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next_listen; + + if (!inet_diag_bc_sk(bc, sk)) + goto next_listen; + + if (inet_sk_diag_fill(sk, inet_csk(sk), skb, + cb, r, NLM_F_MULTI, + net_admin) < 0) { + spin_unlock(&ilb->lock); + goto done; + } + +next_listen: + ++num; + } + spin_unlock(&ilb->lock); + + s_num = 0; + } +skip_listen_ht: + cb->args[0] = 1; + s_i = num = s_num = 0; + } + + if (!(idiag_states & ~TCPF_LISTEN)) + goto out; + +#define SKARR_SZ 16 + for (i = s_i; i <= hashinfo->ehash_mask; i++) { + struct inet_ehash_bucket *head = &hashinfo->ehash[i]; + spinlock_t *lock = inet_ehash_lockp(hashinfo, i); + struct hlist_nulls_node *node; + struct sock *sk_arr[SKARR_SZ]; + int num_arr[SKARR_SZ]; + int idx, accum, res; + + if (hlist_nulls_empty(&head->chain)) + continue; + + if (i > s_i) + s_num = 0; + +next_chunk: + num = 0; + accum = 0; + spin_lock_bh(lock); + sk_nulls_for_each(sk, node, &head->chain) { + int state; + + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) + goto next_normal; + state = (sk->sk_state == TCP_TIME_WAIT) ? 
+ inet_twsk(sk)->tw_substate : sk->sk_state; + if (!(idiag_states & (1 << state))) + goto next_normal; + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_normal; + if (r->id.idiag_sport != htons(sk->sk_num) && + r->id.idiag_sport) + goto next_normal; + if (r->id.idiag_dport != sk->sk_dport && + r->id.idiag_dport) + goto next_normal; + twsk_build_assert(); + + if (!inet_diag_bc_sk(bc, sk)) + goto next_normal; + + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + goto next_normal; + + num_arr[accum] = num; + sk_arr[accum] = sk; + if (++accum == SKARR_SZ) + break; +next_normal: + ++num; + } + spin_unlock_bh(lock); + res = 0; + for (idx = 0; idx < accum; idx++) { + if (res >= 0) { + res = sk_diag_fill(sk_arr[idx], skb, cb, r, + NLM_F_MULTI, net_admin); + if (res < 0) + num = num_arr[idx]; + } + sock_gen_put(sk_arr[idx]); + } + if (res < 0) + break; + cond_resched(); + if (accum == SKARR_SZ) { + s_num = num + 1; + goto next_chunk; + } + } + +done: + cb->args[1] = i; + cb->args[2] = num; +out: + ; +} +EXPORT_SYMBOL_GPL(inet_diag_dump_icsk); + +static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + struct inet_diag_dump_data *cb_data = cb->data; + const struct inet_diag_handler *handler; + u32 prev_min_dump_alloc; + int protocol, err = 0; + + protocol = inet_diag_get_protocol(r, cb_data); + +again: + prev_min_dump_alloc = cb->min_dump_alloc; + handler = inet_diag_lock_handler(protocol); + if (!IS_ERR(handler)) + handler->dump(skb, cb, r); + else + err = PTR_ERR(handler); + inet_diag_unlock_handler(handler); + + /* The skb is not large enough to fit one sk info and + * inet_sk_diag_fill() has requested for a larger skb. + */ + if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) { + err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL); + if (!err) + goto again; + } + + return err ? 
: skb->len; +} + +static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh)); +} + +static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct inet_diag_dump_data *cb_data; + struct sk_buff *skb = cb->skb; + struct nlattr *nla; + int err; + + cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL); + if (!cb_data) + return -ENOMEM; + + err = inet_diag_parse_attrs(nlh, hdrlen, cb_data->req_nlas); + if (err) { + kfree(cb_data); + return err; + } + nla = cb_data->inet_diag_nla_bc; + if (nla) { + err = inet_diag_bc_audit(nla, skb); + if (err) { + kfree(cb_data); + return err; + } + } + + nla = cb_data->inet_diag_nla_bpf_stgs; + if (nla) { + struct bpf_sk_storage_diag *bpf_stg_diag; + + bpf_stg_diag = bpf_sk_storage_diag_alloc(nla); + if (IS_ERR(bpf_stg_diag)) { + kfree(cb_data); + return PTR_ERR(bpf_stg_diag); + } + cb_data->bpf_stg_diag = bpf_stg_diag; + } + + cb->data = cb_data; + return 0; +} + +static int inet_diag_dump_start(struct netlink_callback *cb) +{ + return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2)); +} + +static int inet_diag_dump_start_compat(struct netlink_callback *cb) +{ + return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req)); +} + +static int inet_diag_dump_done(struct netlink_callback *cb) +{ + struct inet_diag_dump_data *cb_data = cb->data; + + bpf_sk_storage_diag_free(cb_data->bpf_stg_diag); + kfree(cb->data); + + return 0; +} + +static int inet_diag_type2proto(int type) +{ + switch (type) { + case TCPDIAG_GETSOCK: + return IPPROTO_TCP; + case DCCPDIAG_GETSOCK: + return IPPROTO_DCCP; + default: + return 0; + } +} + +static int inet_diag_dump_compat(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct inet_diag_req *rc = nlmsg_data(cb->nlh); + struct inet_diag_req_v2 req; + + req.sdiag_family = AF_UNSPEC; /* compatibility */ + req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type); + req.idiag_ext = rc->idiag_ext; + req.idiag_states = rc->idiag_states; + req.id = rc->id; + + return __inet_diag_dump(skb, cb, &req); +} + +static int inet_diag_get_exact_compat(struct sk_buff *in_skb, + const struct nlmsghdr *nlh) +{ + struct inet_diag_req *rc = nlmsg_data(nlh); + struct inet_diag_req_v2 req; + + req.sdiag_family = rc->idiag_family; + req.sdiag_protocol = inet_diag_type2proto(nlh->nlmsg_type); + req.idiag_ext = rc->idiag_ext; + req.idiag_states = rc->idiag_states; + req.id = rc->id; + + return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, + sizeof(struct inet_diag_req), &req); +} + +static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) +{ + int hdrlen = sizeof(struct inet_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX || + nlmsg_len(nlh) < hdrlen) + return -EINVAL; + + if (nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = inet_diag_dump_start_compat, + .done = inet_diag_dump_done, + .dump = inet_diag_dump_compat, + }; + return netlink_dump_start(net->diag_nlsk, skb, nlh, &c); + } + + return inet_diag_get_exact_compat(skb, nlh); +} + +static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct inet_diag_req_v2); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY && + h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = 
inet_diag_dump_start, + .done = inet_diag_dump_done, + .dump = inet_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } + + return inet_diag_cmd_exact(h->nlmsg_type, skb, h, hdrlen, + nlmsg_data(h)); +} + +static +int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk) +{ + const struct inet_diag_handler *handler; + struct nlmsghdr *nlh; + struct nlattr *attr; + struct inet_diag_msg *r; + void *info = NULL; + int err = 0; + + nlh = nlmsg_put(skb, 0, 0, SOCK_DIAG_BY_FAMILY, sizeof(*r), 0); + if (!nlh) + return -ENOMEM; + + r = nlmsg_data(nlh); + memset(r, 0, sizeof(*r)); + inet_diag_msg_common_fill(r, sk); + if (sk->sk_type == SOCK_DGRAM || sk->sk_type == SOCK_STREAM) + r->id.idiag_sport = inet_sk(sk)->inet_sport; + r->idiag_state = sk->sk_state; + + if ((err = nla_put_u8(skb, INET_DIAG_PROTOCOL, sk->sk_protocol))) { + nlmsg_cancel(skb, nlh); + return err; + } + + handler = inet_diag_lock_handler(sk->sk_protocol); + if (IS_ERR(handler)) { + inet_diag_unlock_handler(handler); + nlmsg_cancel(skb, nlh); + return PTR_ERR(handler); + } + + attr = handler->idiag_info_size + ? nla_reserve_64bit(skb, INET_DIAG_INFO, + handler->idiag_info_size, + INET_DIAG_PAD) + : NULL; + if (attr) + info = nla_data(attr); + + handler->idiag_get_info(sk, r, info); + inet_diag_unlock_handler(handler); + + nlmsg_end(skb, nlh); + return 0; +} + +static const struct sock_diag_handler inet_diag_handler = { + .family = AF_INET, + .dump = inet_diag_handler_cmd, + .get_info = inet_diag_handler_get_info, + .destroy = inet_diag_handler_cmd, +}; + +static const struct sock_diag_handler inet6_diag_handler = { + .family = AF_INET6, + .dump = inet_diag_handler_cmd, + .get_info = inet_diag_handler_get_info, + .destroy = inet_diag_handler_cmd, +}; + +int inet_diag_register(const struct inet_diag_handler *h) +{ + const __u16 type = h->idiag_type; + int err = -EINVAL; + + if (type >= IPPROTO_MAX) + goto out; + + mutex_lock(&inet_diag_table_mutex); + err = -EEXIST; + if (!inet_diag_table[type]) { + inet_diag_table[type] = h; + err = 0; + } + mutex_unlock(&inet_diag_table_mutex); +out: + return err; +} +EXPORT_SYMBOL_GPL(inet_diag_register); + +void inet_diag_unregister(const struct inet_diag_handler *h) +{ + const __u16 type = h->idiag_type; + + if (type >= IPPROTO_MAX) + return; + + mutex_lock(&inet_diag_table_mutex); + inet_diag_table[type] = NULL; + mutex_unlock(&inet_diag_table_mutex); +} +EXPORT_SYMBOL_GPL(inet_diag_unregister); + +static int __init inet_diag_init(void) +{ + const int inet_diag_table_size = (IPPROTO_MAX * + sizeof(struct inet_diag_handler *)); + int err = -ENOMEM; + + inet_diag_table = kzalloc(inet_diag_table_size, GFP_KERNEL); + if (!inet_diag_table) + goto out; + + err = sock_diag_register(&inet_diag_handler); + if (err) + goto out_free_nl; + + err = sock_diag_register(&inet6_diag_handler); + if (err) + goto out_free_inet; + + sock_diag_register_inet_compat(inet_diag_rcv_msg_compat); +out: + return err; + +out_free_inet: + sock_diag_unregister(&inet_diag_handler); +out_free_nl: + kfree(inet_diag_table); + goto out; +} + +static void __exit inet_diag_exit(void) +{ + sock_diag_unregister(&inet6_diag_handler); + sock_diag_unregister(&inet_diag_handler); + sock_diag_unregister_inet_compat(inet_diag_rcv_msg_compat); + kfree(inet_diag_table); +} + +module_init(inet_diag_init); +module_exit(inet_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2 /* AF_INET */); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10 /* 
AF_INET6 */); diff --git a/net/ipv4/inet_fragment.c b/net/ipv4/inet_fragment.c new file mode 100644 index 000000000..c9f9ac501 --- /dev/null +++ b/net/ipv4/inet_fragment.c @@ -0,0 +1,602 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * inet fragments management + * + * Authors: Pavel Emelyanov + * Started as consolidation of ipv4/ip_fragment.c, + * ipv6/reassembly. and ipv6 nf conntrack reassembly + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +/* Use skb->cb to track consecutive/adjacent fragments coming at + * the end of the queue. Nodes in the rb-tree queue will + * contain "runs" of one or more adjacent fragments. + * + * Invariants: + * - next_frag is NULL at the tail of a "run"; + * - the head of a "run" has the sum of all fragment lengths in frag_run_len. + */ +struct ipfrag_skb_cb { + union { + struct inet_skb_parm h4; + struct inet6_skb_parm h6; + }; + struct sk_buff *next_frag; + int frag_run_len; +}; + +#define FRAG_CB(skb) ((struct ipfrag_skb_cb *)((skb)->cb)) + +static void fragcb_clear(struct sk_buff *skb) +{ + RB_CLEAR_NODE(&skb->rbnode); + FRAG_CB(skb)->next_frag = NULL; + FRAG_CB(skb)->frag_run_len = skb->len; +} + +/* Append skb to the last "run". */ +static void fragrun_append_to_last(struct inet_frag_queue *q, + struct sk_buff *skb) +{ + fragcb_clear(skb); + + FRAG_CB(q->last_run_head)->frag_run_len += skb->len; + FRAG_CB(q->fragments_tail)->next_frag = skb; + q->fragments_tail = skb; +} + +/* Create a new "run" with the skb. */ +static void fragrun_create(struct inet_frag_queue *q, struct sk_buff *skb) +{ + BUILD_BUG_ON(sizeof(struct ipfrag_skb_cb) > sizeof(skb->cb)); + fragcb_clear(skb); + + if (q->last_run_head) + rb_link_node(&skb->rbnode, &q->last_run_head->rbnode, + &q->last_run_head->rbnode.rb_right); + else + rb_link_node(&skb->rbnode, NULL, &q->rb_fragments.rb_node); + rb_insert_color(&skb->rbnode, &q->rb_fragments); + + q->fragments_tail = skb; + q->last_run_head = skb; +} + +/* Given the OR values of all fragments, apply RFC 3168 5.3 requirements + * Value : 0xff if frame should be dropped. 
+ * 0 or INET_ECN_CE value, to be ORed in to final iph->tos field + */ +const u8 ip_frag_ecn_table[16] = { + /* at least one fragment had CE, and others ECT_0 or ECT_1 */ + [IPFRAG_ECN_CE | IPFRAG_ECN_ECT_0] = INET_ECN_CE, + [IPFRAG_ECN_CE | IPFRAG_ECN_ECT_1] = INET_ECN_CE, + [IPFRAG_ECN_CE | IPFRAG_ECN_ECT_0 | IPFRAG_ECN_ECT_1] = INET_ECN_CE, + + /* invalid combinations : drop frame */ + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_CE] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_ECT_0] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_ECT_1] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_ECT_0 | IPFRAG_ECN_ECT_1] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_CE | IPFRAG_ECN_ECT_0] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_CE | IPFRAG_ECN_ECT_1] = 0xff, + [IPFRAG_ECN_NOT_ECT | IPFRAG_ECN_CE | IPFRAG_ECN_ECT_0 | IPFRAG_ECN_ECT_1] = 0xff, +}; +EXPORT_SYMBOL(ip_frag_ecn_table); + +int inet_frags_init(struct inet_frags *f) +{ + f->frags_cachep = kmem_cache_create(f->frags_cache_name, f->qsize, 0, 0, + NULL); + if (!f->frags_cachep) + return -ENOMEM; + + refcount_set(&f->refcnt, 1); + init_completion(&f->completion); + return 0; +} +EXPORT_SYMBOL(inet_frags_init); + +void inet_frags_fini(struct inet_frags *f) +{ + if (refcount_dec_and_test(&f->refcnt)) + complete(&f->completion); + + wait_for_completion(&f->completion); + + kmem_cache_destroy(f->frags_cachep); + f->frags_cachep = NULL; +} +EXPORT_SYMBOL(inet_frags_fini); + +/* called from rhashtable_free_and_destroy() at netns_frags dismantle */ +static void inet_frags_free_cb(void *ptr, void *arg) +{ + struct inet_frag_queue *fq = ptr; + int count; + + count = del_timer_sync(&fq->timer) ? 1 : 0; + + spin_lock_bh(&fq->lock); + if (!(fq->flags & INET_FRAG_COMPLETE)) { + fq->flags |= INET_FRAG_COMPLETE; + count++; + } else if (fq->flags & INET_FRAG_HASH_DEAD) { + count++; + } + spin_unlock_bh(&fq->lock); + + if (refcount_sub_and_test(count, &fq->refcnt)) + inet_frag_destroy(fq); +} + +static LLIST_HEAD(fqdir_free_list); + +static void fqdir_free_fn(struct work_struct *work) +{ + struct llist_node *kill_list; + struct fqdir *fqdir, *tmp; + struct inet_frags *f; + + /* Atomically snapshot the list of fqdirs to free */ + kill_list = llist_del_all(&fqdir_free_list); + + /* We need to make sure all ongoing call_rcu(..., inet_frag_destroy_rcu) + * have completed, since they need to dereference fqdir. + * Would it not be nice to have kfree_rcu_barrier() ? 
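Each fragment contributes one IPFRAG_ECN_* bit derived from its ECN codepoint (the low two bits of tos); the bits are OR-ed over the whole queue and the result indexes ip_frag_ecn_table above at reassembly time. A standalone sketch of that combination step — illustrative only, with the enum values mirroring include/net/inet_frag.h:

#include <stdio.h>

enum {
	ECN_NOT_ECT	= 0x01,	/* 1 << INET_ECN_NOT_ECT (0) */
	ECN_ECT_1	= 0x02,	/* 1 << INET_ECN_ECT_1   (1) */
	ECN_ECT_0	= 0x04,	/* 1 << INET_ECN_ECT_0   (2) */
	ECN_CE		= 0x08,	/* 1 << INET_ECN_CE      (3) */
};

/* Same idea as ip4_frag_ecn(): map a tos byte to its IPFRAG_ECN_* bit. */
static unsigned int frag_ecn_bit(unsigned char tos)
{
	return 1u << (tos & 0x03);
}

int main(void)
{
	/* One ECT(0) fragment and one CE fragment: index 0x0c, which the
	 * table maps to INET_ECN_CE, so the reassembled packet keeps CE.
	 */
	unsigned int key = frag_ecn_bit(0x02) | frag_ecn_bit(0x03);

	printf("table index 0x%02x\n", key);
	/* Mixing a Not-ECT fragment with ECT/CE ones lands on a 0xff entry
	 * and the whole datagram is dropped.
	 */
	return 0;
}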
:) + */ + rcu_barrier(); + + llist_for_each_entry_safe(fqdir, tmp, kill_list, free_list) { + f = fqdir->f; + if (refcount_dec_and_test(&f->refcnt)) + complete(&f->completion); + + kfree(fqdir); + } +} + +static DECLARE_WORK(fqdir_free_work, fqdir_free_fn); + +static void fqdir_work_fn(struct work_struct *work) +{ + struct fqdir *fqdir = container_of(work, struct fqdir, destroy_work); + + rhashtable_free_and_destroy(&fqdir->rhashtable, inet_frags_free_cb, NULL); + + if (llist_add(&fqdir->free_list, &fqdir_free_list)) + queue_work(system_wq, &fqdir_free_work); +} + +int fqdir_init(struct fqdir **fqdirp, struct inet_frags *f, struct net *net) +{ + struct fqdir *fqdir = kzalloc(sizeof(*fqdir), GFP_KERNEL); + int res; + + if (!fqdir) + return -ENOMEM; + fqdir->f = f; + fqdir->net = net; + res = rhashtable_init(&fqdir->rhashtable, &fqdir->f->rhash_params); + if (res < 0) { + kfree(fqdir); + return res; + } + refcount_inc(&f->refcnt); + *fqdirp = fqdir; + return 0; +} +EXPORT_SYMBOL(fqdir_init); + +static struct workqueue_struct *inet_frag_wq; + +static int __init inet_frag_wq_init(void) +{ + inet_frag_wq = create_workqueue("inet_frag_wq"); + if (!inet_frag_wq) + panic("Could not create inet frag workq"); + return 0; +} + +pure_initcall(inet_frag_wq_init); + +void fqdir_exit(struct fqdir *fqdir) +{ + INIT_WORK(&fqdir->destroy_work, fqdir_work_fn); + queue_work(inet_frag_wq, &fqdir->destroy_work); +} +EXPORT_SYMBOL(fqdir_exit); + +void inet_frag_kill(struct inet_frag_queue *fq) +{ + if (del_timer(&fq->timer)) + refcount_dec(&fq->refcnt); + + if (!(fq->flags & INET_FRAG_COMPLETE)) { + struct fqdir *fqdir = fq->fqdir; + + fq->flags |= INET_FRAG_COMPLETE; + rcu_read_lock(); + /* The RCU read lock provides a memory barrier + * guaranteeing that if fqdir->dead is false then + * the hash table destruction will not start until + * after we unlock. Paired with fqdir_pre_exit(). + */ + if (!READ_ONCE(fqdir->dead)) { + rhashtable_remove_fast(&fqdir->rhashtable, &fq->node, + fqdir->f->rhash_params); + refcount_dec(&fq->refcnt); + } else { + fq->flags |= INET_FRAG_HASH_DEAD; + } + rcu_read_unlock(); + } +} +EXPORT_SYMBOL(inet_frag_kill); + +static void inet_frag_destroy_rcu(struct rcu_head *head) +{ + struct inet_frag_queue *q = container_of(head, struct inet_frag_queue, + rcu); + struct inet_frags *f = q->fqdir->f; + + if (f->destructor) + f->destructor(q); + kmem_cache_free(f->frags_cachep, q); +} + +unsigned int inet_frag_rbtree_purge(struct rb_root *root) +{ + struct rb_node *p = rb_first(root); + unsigned int sum = 0; + + while (p) { + struct sk_buff *skb = rb_entry(p, struct sk_buff, rbnode); + + p = rb_next(p); + rb_erase(&skb->rbnode, root); + while (skb) { + struct sk_buff *next = FRAG_CB(skb)->next_frag; + + sum += skb->truesize; + kfree_skb(skb); + skb = next; + } + } + return sum; +} +EXPORT_SYMBOL(inet_frag_rbtree_purge); + +void inet_frag_destroy(struct inet_frag_queue *q) +{ + struct fqdir *fqdir; + unsigned int sum, sum_truesize = 0; + struct inet_frags *f; + + WARN_ON(!(q->flags & INET_FRAG_COMPLETE)); + WARN_ON(del_timer(&q->timer) != 0); + + /* Release all fragment data. 
*/ + fqdir = q->fqdir; + f = fqdir->f; + sum_truesize = inet_frag_rbtree_purge(&q->rb_fragments); + sum = sum_truesize + f->qsize; + + call_rcu(&q->rcu, inet_frag_destroy_rcu); + + sub_frag_mem_limit(fqdir, sum); +} +EXPORT_SYMBOL(inet_frag_destroy); + +static struct inet_frag_queue *inet_frag_alloc(struct fqdir *fqdir, + struct inet_frags *f, + void *arg) +{ + struct inet_frag_queue *q; + + q = kmem_cache_zalloc(f->frags_cachep, GFP_ATOMIC); + if (!q) + return NULL; + + q->fqdir = fqdir; + f->constructor(q, arg); + add_frag_mem_limit(fqdir, f->qsize); + + timer_setup(&q->timer, f->frag_expire, 0); + spin_lock_init(&q->lock); + refcount_set(&q->refcnt, 3); + + return q; +} + +static struct inet_frag_queue *inet_frag_create(struct fqdir *fqdir, + void *arg, + struct inet_frag_queue **prev) +{ + struct inet_frags *f = fqdir->f; + struct inet_frag_queue *q; + + q = inet_frag_alloc(fqdir, f, arg); + if (!q) { + *prev = ERR_PTR(-ENOMEM); + return NULL; + } + mod_timer(&q->timer, jiffies + fqdir->timeout); + + *prev = rhashtable_lookup_get_insert_key(&fqdir->rhashtable, &q->key, + &q->node, f->rhash_params); + if (*prev) { + q->flags |= INET_FRAG_COMPLETE; + inet_frag_kill(q); + inet_frag_destroy(q); + return NULL; + } + return q; +} + +/* TODO : call from rcu_read_lock() and no longer use refcount_inc_not_zero() */ +struct inet_frag_queue *inet_frag_find(struct fqdir *fqdir, void *key) +{ + /* This pairs with WRITE_ONCE() in fqdir_pre_exit(). */ + long high_thresh = READ_ONCE(fqdir->high_thresh); + struct inet_frag_queue *fq = NULL, *prev; + + if (!high_thresh || frag_mem_limit(fqdir) > high_thresh) + return NULL; + + rcu_read_lock(); + + prev = rhashtable_lookup(&fqdir->rhashtable, key, fqdir->f->rhash_params); + if (!prev) + fq = inet_frag_create(fqdir, key, &prev); + if (!IS_ERR_OR_NULL(prev)) { + fq = prev; + if (!refcount_inc_not_zero(&fq->refcnt)) + fq = NULL; + } + rcu_read_unlock(); + return fq; +} +EXPORT_SYMBOL(inet_frag_find); + +int inet_frag_queue_insert(struct inet_frag_queue *q, struct sk_buff *skb, + int offset, int end) +{ + struct sk_buff *last = q->fragments_tail; + + /* RFC5722, Section 4, amended by Errata ID : 3089 + * When reassembling an IPv6 datagram, if + * one or more its constituent fragments is determined to be an + * overlapping fragment, the entire datagram (and any constituent + * fragments) MUST be silently discarded. + * + * Duplicates, however, should be ignored (i.e. skb dropped, but the + * queue/fragments kept for later reassembly). + */ + if (!last) + fragrun_create(q, skb); /* First fragment. */ + else if (last->ip_defrag_offset + last->len < end) { + /* This is the common case: skb goes to the end. */ + /* Detect and discard overlaps. */ + if (offset < last->ip_defrag_offset + last->len) + return IPFRAG_OVERLAP; + if (offset == last->ip_defrag_offset + last->len) + fragrun_append_to_last(q, skb); + else + fragrun_create(q, skb); + } else { + /* Binary search. Note that skb can become the first fragment, + * but not the last (covered above). 
+ */ + struct rb_node **rbn, *parent; + + rbn = &q->rb_fragments.rb_node; + do { + struct sk_buff *curr; + int curr_run_end; + + parent = *rbn; + curr = rb_to_skb(parent); + curr_run_end = curr->ip_defrag_offset + + FRAG_CB(curr)->frag_run_len; + if (end <= curr->ip_defrag_offset) + rbn = &parent->rb_left; + else if (offset >= curr_run_end) + rbn = &parent->rb_right; + else if (offset >= curr->ip_defrag_offset && + end <= curr_run_end) + return IPFRAG_DUP; + else + return IPFRAG_OVERLAP; + } while (*rbn); + /* Here we have parent properly set, and rbn pointing to + * one of its NULL left/right children. Insert skb. + */ + fragcb_clear(skb); + rb_link_node(&skb->rbnode, parent, rbn); + rb_insert_color(&skb->rbnode, &q->rb_fragments); + } + + skb->ip_defrag_offset = offset; + + return IPFRAG_OK; +} +EXPORT_SYMBOL(inet_frag_queue_insert); + +void *inet_frag_reasm_prepare(struct inet_frag_queue *q, struct sk_buff *skb, + struct sk_buff *parent) +{ + struct sk_buff *fp, *head = skb_rb_first(&q->rb_fragments); + struct sk_buff **nextp; + int delta; + + if (head != skb) { + fp = skb_clone(skb, GFP_ATOMIC); + if (!fp) + return NULL; + FRAG_CB(fp)->next_frag = FRAG_CB(skb)->next_frag; + if (RB_EMPTY_NODE(&skb->rbnode)) + FRAG_CB(parent)->next_frag = fp; + else + rb_replace_node(&skb->rbnode, &fp->rbnode, + &q->rb_fragments); + if (q->fragments_tail == skb) + q->fragments_tail = fp; + skb_morph(skb, head); + FRAG_CB(skb)->next_frag = FRAG_CB(head)->next_frag; + rb_replace_node(&head->rbnode, &skb->rbnode, + &q->rb_fragments); + consume_skb(head); + head = skb; + } + WARN_ON(head->ip_defrag_offset != 0); + + delta = -head->truesize; + + /* Head of list must not be cloned. */ + if (skb_unclone(head, GFP_ATOMIC)) + return NULL; + + delta += head->truesize; + if (delta) + add_frag_mem_limit(q->fqdir, delta); + + /* If the first fragment is fragmented itself, we split + * it to two chunks: the first with data and paged part + * and the second, holding only fragments. + */ + if (skb_has_frag_list(head)) { + struct sk_buff *clone; + int i, plen = 0; + + clone = alloc_skb(0, GFP_ATOMIC); + if (!clone) + return NULL; + skb_shinfo(clone)->frag_list = skb_shinfo(head)->frag_list; + skb_frag_list_init(head); + for (i = 0; i < skb_shinfo(head)->nr_frags; i++) + plen += skb_frag_size(&skb_shinfo(head)->frags[i]); + clone->data_len = head->data_len - plen; + clone->len = clone->data_len; + head->truesize += clone->truesize; + clone->csum = 0; + clone->ip_summed = head->ip_summed; + add_frag_mem_limit(q->fqdir, clone->truesize); + skb_shinfo(head)->frag_list = clone; + nextp = &clone->next; + } else { + nextp = &skb_shinfo(head)->frag_list; + } + + return nextp; +} +EXPORT_SYMBOL(inet_frag_reasm_prepare); + +void inet_frag_reasm_finish(struct inet_frag_queue *q, struct sk_buff *head, + void *reasm_data, bool try_coalesce) +{ + struct sk_buff **nextp = reasm_data; + struct rb_node *rbn; + struct sk_buff *fp; + int sum_truesize; + + skb_push(head, head->data - skb_network_header(head)); + + /* Traverse the tree in order, to build frag_list. */ + fp = FRAG_CB(head)->next_frag; + rbn = rb_next(&head->rbnode); + rb_erase(&head->rbnode, &q->rb_fragments); + + sum_truesize = head->truesize; + while (rbn || fp) { + /* fp points to the next sk_buff in the current run; + * rbn points to the next run. + */ + /* Go through the current run. 
*/ + while (fp) { + struct sk_buff *next_frag = FRAG_CB(fp)->next_frag; + bool stolen; + int delta; + + sum_truesize += fp->truesize; + if (head->ip_summed != fp->ip_summed) + head->ip_summed = CHECKSUM_NONE; + else if (head->ip_summed == CHECKSUM_COMPLETE) + head->csum = csum_add(head->csum, fp->csum); + + if (try_coalesce && skb_try_coalesce(head, fp, &stolen, + &delta)) { + kfree_skb_partial(fp, stolen); + } else { + fp->prev = NULL; + memset(&fp->rbnode, 0, sizeof(fp->rbnode)); + fp->sk = NULL; + + head->data_len += fp->len; + head->len += fp->len; + head->truesize += fp->truesize; + + *nextp = fp; + nextp = &fp->next; + } + + fp = next_frag; + } + /* Move to the next run. */ + if (rbn) { + struct rb_node *rbnext = rb_next(rbn); + + fp = rb_to_skb(rbn); + rb_erase(rbn, &q->rb_fragments); + rbn = rbnext; + } + } + sub_frag_mem_limit(q->fqdir, sum_truesize); + + *nextp = NULL; + skb_mark_not_on_list(head); + head->prev = NULL; + head->tstamp = q->stamp; + head->mono_delivery_time = q->mono_delivery_time; +} +EXPORT_SYMBOL(inet_frag_reasm_finish); + +struct sk_buff *inet_frag_pull_head(struct inet_frag_queue *q) +{ + struct sk_buff *head, *skb; + + head = skb_rb_first(&q->rb_fragments); + if (!head) + return NULL; + skb = FRAG_CB(head)->next_frag; + if (skb) + rb_replace_node(&head->rbnode, &skb->rbnode, + &q->rb_fragments); + else + rb_erase(&head->rbnode, &q->rb_fragments); + memset(&head->rbnode, 0, sizeof(head->rbnode)); + barrier(); + + if (head == q->fragments_tail) + q->fragments_tail = NULL; + + sub_frag_mem_limit(q->fqdir, head->truesize); + + return head; +} +EXPORT_SYMBOL(inet_frag_pull_head); diff --git a/net/ipv4/inet_hashtables.c b/net/ipv4/inet_hashtables.c new file mode 100644 index 000000000..f2ed2aed0 --- /dev/null +++ b/net/ipv4/inet_hashtables.c @@ -0,0 +1,1257 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Generic INET transport hashtables + * + * Authors: Lotsa people, from code originally in tcp + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif +#include +#include +#include +#include + +static u32 inet_ehashfn(const struct net *net, const __be32 laddr, + const __u16 lport, const __be32 faddr, + const __be16 fport) +{ + static u32 inet_ehash_secret __read_mostly; + + net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret)); + + return __inet_ehashfn(laddr, lport, faddr, fport, + inet_ehash_secret + net_hash_mix(net)); +} + +/* This function handles inet_sock, but also timewait and request sockets + * for IPv4/IPv6. + */ +static u32 sk_ehashfn(const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6 && + !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) + return inet6_ehashfn(sock_net(sk), + &sk->sk_v6_rcv_saddr, sk->sk_num, + &sk->sk_v6_daddr, sk->sk_dport); +#endif + return inet_ehashfn(sock_net(sk), + sk->sk_rcv_saddr, sk->sk_num, + sk->sk_daddr, sk->sk_dport); +} + +/* + * Allocate and initialize a new local port bind bucket. + * The bindhash mutex for snum's hash chain must be held here. 
+ */ +struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, + struct net *net, + struct inet_bind_hashbucket *head, + const unsigned short snum, + int l3mdev) +{ + struct inet_bind_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); + + if (tb) { + write_pnet(&tb->ib_net, net); + tb->l3mdev = l3mdev; + tb->port = snum; + tb->fastreuse = 0; + tb->fastreuseport = 0; + INIT_HLIST_HEAD(&tb->owners); + hlist_add_head(&tb->node, &head->chain); + } + return tb; +} + +/* + * Caller must hold hashbucket lock for this tb with local BH disabled + */ +void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb) +{ + if (hlist_empty(&tb->owners)) { + __hlist_del(&tb->node); + kmem_cache_free(cachep, tb); + } +} + +bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, + unsigned short port, int l3mdev) +{ + return net_eq(ib_net(tb), net) && tb->port == port && + tb->l3mdev == l3mdev; +} + +static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb, + struct net *net, + struct inet_bind_hashbucket *head, + unsigned short port, int l3mdev, + const struct sock *sk) +{ + write_pnet(&tb->ib_net, net); + tb->l3mdev = l3mdev; + tb->port = port; +#if IS_ENABLED(CONFIG_IPV6) + tb->family = sk->sk_family; + if (sk->sk_family == AF_INET6) + tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr; + else +#endif + tb->rcv_saddr = sk->sk_rcv_saddr; + INIT_HLIST_HEAD(&tb->owners); + INIT_HLIST_HEAD(&tb->deathrow); + hlist_add_head(&tb->node, &head->chain); +} + +struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, + struct net *net, + struct inet_bind_hashbucket *head, + unsigned short port, + int l3mdev, + const struct sock *sk) +{ + struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); + + if (tb) + inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk); + + return tb; +} + +/* Caller must hold hashbucket lock for this tb with local BH disabled */ +void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) +{ + if (hlist_empty(&tb->owners) && hlist_empty(&tb->deathrow)) { + __hlist_del(&tb->node); + kmem_cache_free(cachep, tb); + } +} + +static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2, + const struct sock *sk) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family != tb2->family) { + if (sk->sk_family == AF_INET) + return ipv6_addr_v4mapped(&tb2->v6_rcv_saddr) && + tb2->v6_rcv_saddr.s6_addr32[3] == sk->sk_rcv_saddr; + + return ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr) && + sk->sk_v6_rcv_saddr.s6_addr32[3] == tb2->rcv_saddr; + } + + if (sk->sk_family == AF_INET6) + return ipv6_addr_equal(&tb2->v6_rcv_saddr, + &sk->sk_v6_rcv_saddr); +#endif + return tb2->rcv_saddr == sk->sk_rcv_saddr; +} + +void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, + struct inet_bind2_bucket *tb2, unsigned short port) +{ + inet_sk(sk)->inet_num = port; + sk_add_bind_node(sk, &tb->owners); + inet_csk(sk)->icsk_bind_hash = tb; + sk_add_bind2_node(sk, &tb2->owners); + inet_csk(sk)->icsk_bind2_hash = tb2; +} + +/* + * Get rid of any references to a local port held by the given sock. 
+ */ +static void __inet_put_port(struct sock *sk) +{ + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; + struct net *net = sock_net(sk); + struct inet_bind_bucket *tb; + int bhash; + + bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size); + head = &hashinfo->bhash[bhash]; + head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num); + + spin_lock(&head->lock); + tb = inet_csk(sk)->icsk_bind_hash; + __sk_del_bind_node(sk); + inet_csk(sk)->icsk_bind_hash = NULL; + inet_sk(sk)->inet_num = 0; + inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + + spin_lock(&head2->lock); + if (inet_csk(sk)->icsk_bind2_hash) { + struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash; + + __sk_del_bind2_node(sk); + inet_csk(sk)->icsk_bind2_hash = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + } + spin_unlock(&head2->lock); + + spin_unlock(&head->lock); +} + +void inet_put_port(struct sock *sk) +{ + local_bh_disable(); + __inet_put_port(sk); + local_bh_enable(); +} +EXPORT_SYMBOL(inet_put_port); + +int __inet_inherit_port(const struct sock *sk, struct sock *child) +{ + struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk); + unsigned short port = inet_sk(child)->inet_num; + struct inet_bind_hashbucket *head, *head2; + bool created_inet_bind_bucket = false; + struct net *net = sock_net(sk); + bool update_fastreuse = false; + struct inet_bind2_bucket *tb2; + struct inet_bind_bucket *tb; + int bhash, l3mdev; + + bhash = inet_bhashfn(net, port, table->bhash_size); + head = &table->bhash[bhash]; + head2 = inet_bhashfn_portaddr(table, child, net, port); + + spin_lock(&head->lock); + spin_lock(&head2->lock); + tb = inet_csk(sk)->icsk_bind_hash; + tb2 = inet_csk(sk)->icsk_bind2_hash; + if (unlikely(!tb || !tb2)) { + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + return -ENOENT; + } + if (tb->port != port) { + l3mdev = inet_sk_bound_l3mdev(sk); + + /* NOTE: using tproxy and redirecting skbs to a proxy + * on a different listener port breaks the assumption + * that the listener socket's icsk_bind_hash is the same + * as that of the child socket. We have to look up or + * create a new bind bucket for the child here. 
*/ + inet_bind_bucket_for_each(tb, &head->chain) { + if (inet_bind_bucket_match(tb, net, port, l3mdev)) + break; + } + if (!tb) { + tb = inet_bind_bucket_create(table->bind_bucket_cachep, + net, head, port, l3mdev); + if (!tb) { + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + return -ENOMEM; + } + created_inet_bind_bucket = true; + } + update_fastreuse = true; + + goto bhash2_find; + } else if (!inet_bind2_bucket_addr_match(tb2, child)) { + l3mdev = inet_sk_bound_l3mdev(sk); + +bhash2_find: + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child); + if (!tb2) { + tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep, + net, head2, port, + l3mdev, child); + if (!tb2) + goto error; + } + } + if (update_fastreuse) + inet_csk_update_fastreuse(tb, child); + inet_bind_hash(child, tb, tb2, port); + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + + return 0; + +error: + if (created_inet_bind_bucket) + inet_bind_bucket_destroy(table->bind_bucket_cachep, tb); + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + return -ENOMEM; +} +EXPORT_SYMBOL_GPL(__inet_inherit_port); + +static struct inet_listen_hashbucket * +inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk) +{ + u32 hash; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + hash = ipv6_portaddr_hash(sock_net(sk), + &sk->sk_v6_rcv_saddr, + inet_sk(sk)->inet_num); + else +#endif + hash = ipv4_portaddr_hash(sock_net(sk), + inet_sk(sk)->inet_rcv_saddr, + inet_sk(sk)->inet_num); + return inet_lhash2_bucket(h, hash); +} + +static inline int compute_score(struct sock *sk, struct net *net, + const unsigned short hnum, const __be32 daddr, + const int dif, const int sdif) +{ + int score = -1; + + if (net_eq(sock_net(sk), net) && sk->sk_num == hnum && + !ipv6_only_sock(sk)) { + if (sk->sk_rcv_saddr != daddr) + return -1; + + if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) + return -1; + score = sk->sk_bound_dev_if ? 2 : 1; + + if (sk->sk_family == PF_INET) + score++; + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + score++; + } + return score; +} + +static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk, + struct sk_buff *skb, int doff, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned short hnum) +{ + struct sock *reuse_sk = NULL; + u32 phash; + + if (sk->sk_reuseport) { + phash = inet_ehashfn(net, daddr, hnum, saddr, sport); + reuse_sk = reuseport_select_sock(sk, phash, skb, doff); + } + return reuse_sk; +} + +/* + * Here are some nice properties to exploit here. The BSD API + * does not allow a listening sock to specify the remote port nor the + * remote address for the connection. So always assume those are both + * wildcarded during the search since they can never be otherwise. 
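+ * compute_score() above then ranks candidates: the namespace, port and
+ * receive address must match (and the bound device, if any), after
+ * which a device-bound listener scores 2 rather than 1, an AF_INET
+ * socket gets +1, and a listener whose sk_incoming_cpu matches the
+ * current CPU gets another +1. Exact-address listeners effectively
+ * take precedence because __inet_lookup_listener() below probes the
+ * daddr bucket before the INADDR_ANY bucket.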
+ */ + +/* called with rcu_read_lock() : No refcount taken on the socket */ +static struct sock *inet_lhash2_lookup(struct net *net, + struct inet_listen_hashbucket *ilb2, + struct sk_buff *skb, int doff, + const __be32 saddr, __be16 sport, + const __be32 daddr, const unsigned short hnum, + const int dif, const int sdif) +{ + struct sock *sk, *result = NULL; + struct hlist_nulls_node *node; + int score, hiscore = 0; + + sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) { + score = compute_score(sk, net, hnum, daddr, dif, sdif); + if (score > hiscore) { + result = lookup_reuseport(net, sk, skb, doff, + saddr, sport, daddr, hnum); + if (result) + return result; + + result = sk; + hiscore = score; + } + } + + return result; +} + +static inline struct sock *inet_lookup_run_bpf(struct net *net, + struct inet_hashinfo *hashinfo, + struct sk_buff *skb, int doff, + __be32 saddr, __be16 sport, + __be32 daddr, u16 hnum, const int dif) +{ + struct sock *sk, *reuse_sk; + bool no_reuseport; + + if (hashinfo != net->ipv4.tcp_death_row.hashinfo) + return NULL; /* only TCP is supported */ + + no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport, + daddr, hnum, dif, &sk); + if (no_reuseport || IS_ERR_OR_NULL(sk)) + return sk; + + reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum); + if (reuse_sk) + sk = reuse_sk; + return sk; +} + +struct sock *__inet_lookup_listener(struct net *net, + struct inet_hashinfo *hashinfo, + struct sk_buff *skb, int doff, + const __be32 saddr, __be16 sport, + const __be32 daddr, const unsigned short hnum, + const int dif, const int sdif) +{ + struct inet_listen_hashbucket *ilb2; + struct sock *result = NULL; + unsigned int hash2; + + /* Lookup redirect from BPF */ + if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { + result = inet_lookup_run_bpf(net, hashinfo, skb, doff, + saddr, sport, daddr, hnum, dif); + if (result) + goto done; + } + + hash2 = ipv4_portaddr_hash(net, daddr, hnum); + ilb2 = inet_lhash2_bucket(hashinfo, hash2); + + result = inet_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, daddr, hnum, + dif, sdif); + if (result) + goto done; + + /* Lookup lhash2 with INADDR_ANY */ + hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum); + ilb2 = inet_lhash2_bucket(hashinfo, hash2); + + result = inet_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, htonl(INADDR_ANY), hnum, + dif, sdif); +done: + if (IS_ERR(result)) + return NULL; + return result; +} +EXPORT_SYMBOL_GPL(__inet_lookup_listener); + +/* All sockets share common refcount, but have different destructors */ +void sock_gen_put(struct sock *sk) +{ + if (!refcount_dec_and_test(&sk->sk_refcnt)) + return; + + if (sk->sk_state == TCP_TIME_WAIT) + inet_twsk_free(inet_twsk(sk)); + else if (sk->sk_state == TCP_NEW_SYN_RECV) + reqsk_free(inet_reqsk(sk)); + else + sk_free(sk); +} +EXPORT_SYMBOL_GPL(sock_gen_put); + +void sock_edemux(struct sk_buff *skb) +{ + sock_gen_put(skb->sk); +} +EXPORT_SYMBOL(sock_edemux); + +struct sock *__inet_lookup_established(struct net *net, + struct inet_hashinfo *hashinfo, + const __be32 saddr, const __be16 sport, + const __be32 daddr, const u16 hnum, + const int dif, const int sdif) +{ + INET_ADDR_COOKIE(acookie, saddr, daddr); + const __portpair ports = INET_COMBINED_PORTS(sport, hnum); + struct sock *sk; + const struct hlist_nulls_node *node; + /* Optimize here for direct hit, only listening connections can + * have wildcards anyways. 
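+ * The chain is selected by hashing the full (daddr, hnum, saddr, sport)
+ * tuple and each candidate is checked with inet_match() on the address
+ * cookie, port pair and bound device; a candidate is taken only if its
+ * refcount can be raised from non-zero, and the match is re-checked
+ * afterwards in case the socket was recycled meanwhile.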
+ */ + unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport); + unsigned int slot = hash & hashinfo->ehash_mask; + struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; + +begin: + sk_nulls_for_each_rcu(sk, node, &head->chain) { + if (sk->sk_hash != hash) + continue; + if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) { + if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) + goto out; + if (unlikely(!inet_match(net, sk, acookie, + ports, dif, sdif))) { + sock_gen_put(sk); + goto begin; + } + goto found; + } + } + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (get_nulls_value(node) != slot) + goto begin; +out: + sk = NULL; +found: + return sk; +} +EXPORT_SYMBOL_GPL(__inet_lookup_established); + +/* called with local bh disabled */ +static int __inet_check_established(struct inet_timewait_death_row *death_row, + struct sock *sk, __u16 lport, + struct inet_timewait_sock **twp) +{ + struct inet_hashinfo *hinfo = death_row->hashinfo; + struct inet_sock *inet = inet_sk(sk); + __be32 daddr = inet->inet_rcv_saddr; + __be32 saddr = inet->inet_daddr; + int dif = sk->sk_bound_dev_if; + struct net *net = sock_net(sk); + int sdif = l3mdev_master_ifindex_by_index(net, dif); + INET_ADDR_COOKIE(acookie, saddr, daddr); + const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); + unsigned int hash = inet_ehashfn(net, daddr, lport, + saddr, inet->inet_dport); + struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); + spinlock_t *lock = inet_ehash_lockp(hinfo, hash); + struct sock *sk2; + const struct hlist_nulls_node *node; + struct inet_timewait_sock *tw = NULL; + + spin_lock(lock); + + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash) + continue; + + if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) { + if (sk2->sk_state == TCP_TIME_WAIT) { + tw = inet_twsk(sk2); + if (twsk_unique(sk, sk2, twp)) + break; + } + goto not_unique; + } + } + + /* Must record num and sport now. Otherwise we will see + * in hash table socket with a funny identity. + */ + inet->inet_num = lport; + inet->inet_sport = htons(lport); + sk->sk_hash = hash; + WARN_ON(!sk_unhashed(sk)); + __sk_nulls_add_node_rcu(sk, &head->chain); + if (tw) { + sk_nulls_del_node_init_rcu((struct sock *)tw); + __NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED); + } + spin_unlock(lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + if (twp) { + *twp = tw; + } else if (tw) { + /* Silly. Should hash-dance instead... */ + inet_twsk_deschedule_put(tw); + } + return 0; + +not_unique: + spin_unlock(lock); + return -EADDRNOTAVAIL; +} + +static u64 inet_sk_port_offset(const struct sock *sk) +{ + const struct inet_sock *inet = inet_sk(sk); + + return secure_ipv4_port_ephemeral(inet->inet_rcv_saddr, + inet->inet_daddr, + inet->inet_dport); +} + +/* Searches for an exsiting socket in the ehash bucket list. + * Returns true if found, false otherwise. 
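+ * Called with the ehash bucket lock held from inet_ehash_insert()
+ * below, which reports a hit through *found_dup_sk so the duplicate
+ * socket is never added to the chain.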
+ */ +static bool inet_ehash_lookup_by_sk(struct sock *sk, + struct hlist_nulls_head *list) +{ + const __portpair ports = INET_COMBINED_PORTS(sk->sk_dport, sk->sk_num); + const int sdif = sk->sk_bound_dev_if; + const int dif = sk->sk_bound_dev_if; + const struct hlist_nulls_node *node; + struct net *net = sock_net(sk); + struct sock *esk; + + INET_ADDR_COOKIE(acookie, sk->sk_daddr, sk->sk_rcv_saddr); + + sk_nulls_for_each_rcu(esk, node, list) { + if (esk->sk_hash != sk->sk_hash) + continue; + if (sk->sk_family == AF_INET) { + if (unlikely(inet_match(net, esk, acookie, + ports, dif, sdif))) { + return true; + } + } +#if IS_ENABLED(CONFIG_IPV6) + else if (sk->sk_family == AF_INET6) { + if (unlikely(inet6_match(net, esk, + &sk->sk_v6_daddr, + &sk->sk_v6_rcv_saddr, + ports, dif, sdif))) { + return true; + } + } +#endif + } + return false; +} + +/* Insert a socket into ehash, and eventually remove another one + * (The another one can be a SYN_RECV or TIMEWAIT) + * If an existing socket already exists, socket sk is not inserted, + * and sets found_dup_sk parameter to true. + */ +bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) +{ + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_ehash_bucket *head; + struct hlist_nulls_head *list; + spinlock_t *lock; + bool ret = true; + + WARN_ON_ONCE(!sk_unhashed(sk)); + + sk->sk_hash = sk_ehashfn(sk); + head = inet_ehash_bucket(hashinfo, sk->sk_hash); + list = &head->chain; + lock = inet_ehash_lockp(hashinfo, sk->sk_hash); + + spin_lock(lock); + if (osk) { + WARN_ON_ONCE(sk->sk_hash != osk->sk_hash); + ret = sk_nulls_del_node_init_rcu(osk); + } else if (found_dup_sk) { + *found_dup_sk = inet_ehash_lookup_by_sk(sk, list); + if (*found_dup_sk) + ret = false; + } + + if (ret) + __sk_nulls_add_node_rcu(sk, list); + + spin_unlock(lock); + + return ret; +} + +bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk) +{ + bool ok = inet_ehash_insert(sk, osk, found_dup_sk); + + if (ok) { + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + } else { + this_cpu_inc(*sk->sk_prot->orphan_count); + inet_sk_set_state(sk, TCP_CLOSE); + sock_set_flag(sk, SOCK_DEAD); + inet_csk_destroy_sock(sk); + } + return ok; +} +EXPORT_SYMBOL_GPL(inet_ehash_nolisten); + +static int inet_reuseport_add_sock(struct sock *sk, + struct inet_listen_hashbucket *ilb) +{ + struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash; + const struct hlist_nulls_node *node; + struct sock *sk2; + kuid_t uid = sock_i_uid(sk); + + sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) { + if (sk2 != sk && + sk2->sk_family == sk->sk_family && + ipv6_only_sock(sk2) == ipv6_only_sock(sk) && + sk2->sk_bound_dev_if == sk->sk_bound_dev_if && + inet_csk(sk2)->icsk_bind_hash == tb && + sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && + inet_rcv_saddr_equal(sk, sk2, false)) + return reuseport_add_sock(sk, sk2, + inet_rcv_saddr_any(sk)); + } + + return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); +} + +int __inet_hash(struct sock *sk, struct sock *osk) +{ + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_listen_hashbucket *ilb2; + int err = 0; + + if (sk->sk_state != TCP_LISTEN) { + local_bh_disable(); + inet_ehash_nolisten(sk, osk, NULL); + local_bh_enable(); + return 0; + } + WARN_ON(!sk_unhashed(sk)); + ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); + + spin_lock(&ilb2->lock); + if (sk->sk_reuseport) { + err = inet_reuseport_add_sock(sk, ilb2); + if (err) + goto unlock; + } + sock_set_flag(sk, 
SOCK_RCU_FREE); + if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && + sk->sk_family == AF_INET6) + __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head); + else + __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); +unlock: + spin_unlock(&ilb2->lock); + + return err; +} +EXPORT_SYMBOL(__inet_hash); + +int inet_hash(struct sock *sk) +{ + int err = 0; + + if (sk->sk_state != TCP_CLOSE) + err = __inet_hash(sk, NULL); + + return err; +} +EXPORT_SYMBOL_GPL(inet_hash); + +void inet_unhash(struct sock *sk) +{ + struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); + + if (sk_unhashed(sk)) + return; + + if (sk->sk_state == TCP_LISTEN) { + struct inet_listen_hashbucket *ilb2; + + ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); + /* Don't disable bottom halves while acquiring the lock to + * avoid circular locking dependency on PREEMPT_RT. + */ + spin_lock(&ilb2->lock); + if (sk_unhashed(sk)) { + spin_unlock(&ilb2->lock); + return; + } + + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_stop_listen_sock(sk); + + __sk_nulls_del_node_init_rcu(sk); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock(&ilb2->lock); + } else { + spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); + + spin_lock_bh(lock); + if (sk_unhashed(sk)) { + spin_unlock_bh(lock); + return; + } + __sk_nulls_del_node_init_rcu(sk); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock_bh(lock); + } +} +EXPORT_SYMBOL_GPL(inet_unhash); + +static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb, + const struct net *net, unsigned short port, + int l3mdev, const struct sock *sk) +{ + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) + return false; + + return inet_bind2_bucket_addr_match(tb, sk); +} + +bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) +{ + if (!net_eq(ib2_net(tb), net) || tb->port != port || + tb->l3mdev != l3mdev) + return false; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family != tb->family) { + if (sk->sk_family == AF_INET) + return ipv6_addr_any(&tb->v6_rcv_saddr) || + ipv6_addr_v4mapped_any(&tb->v6_rcv_saddr); + + return false; + } + + if (sk->sk_family == AF_INET6) + return ipv6_addr_any(&tb->v6_rcv_saddr); +#endif + return tb->rcv_saddr == 0; +} + +/* The socket's bhash2 hashbucket spinlock must be held when this is called */ +struct inet_bind2_bucket * +inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net, + unsigned short port, int l3mdev, const struct sock *sk) +{ + struct inet_bind2_bucket *bhash2 = NULL; + + inet_bind_bucket_for_each(bhash2, &head->chain) + if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk)) + break; + + return bhash2; +} + +struct inet_bind_hashbucket * +inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + u32 hash; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + hash = ipv6_portaddr_hash(net, &in6addr_any, port); + else +#endif + hash = ipv4_portaddr_hash(net, 0, port); + + return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)]; +} + +static void inet_update_saddr(struct sock *sk, void *saddr, int family) +{ + if (family == AF_INET) { + inet_sk(sk)->inet_saddr = *(__be32 *)saddr; + sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr); + } +#if IS_ENABLED(CONFIG_IPV6) + else { + sk->sk_v6_rcv_saddr = 
*(struct in6_addr *)saddr; + } +#endif +} + +static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset) +{ + struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); + struct inet_bind_hashbucket *head, *head2; + struct inet_bind2_bucket *tb2, *new_tb2; + int l3mdev = inet_sk_bound_l3mdev(sk); + int port = inet_sk(sk)->inet_num; + struct net *net = sock_net(sk); + int bhash; + + if (!inet_csk(sk)->icsk_bind2_hash) { + /* Not bind()ed before. */ + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); + + return 0; + } + + /* Allocate a bind2 bucket ahead of time to avoid permanently putting + * the bhash2 table in an inconsistent state if a new tb2 bucket + * allocation fails. + */ + new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC); + if (!new_tb2) { + if (reset) { + /* The (INADDR_ANY, port) bucket might have already + * been freed, then we cannot fixup icsk_bind2_hash, + * so we give up and unlink sk from bhash/bhash2 not + * to leave inconsistency in bhash2. + */ + inet_put_port(sk); + inet_reset_saddr(sk); + } + + return -ENOMEM; + } + + bhash = inet_bhashfn(net, port, hinfo->bhash_size); + head = &hinfo->bhash[bhash]; + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + /* If we change saddr locklessly, another thread + * iterating over bhash might see corrupted address. + */ + spin_lock_bh(&head->lock); + + spin_lock(&head2->lock); + __sk_del_bind2_node(sk); + inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash); + spin_unlock(&head2->lock); + + if (reset) + inet_reset_saddr(sk); + else + inet_update_saddr(sk, saddr, family); + + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + + spin_lock(&head2->lock); + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = new_tb2; + inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk); + } + sk_add_bind2_node(sk, &tb2->owners); + inet_csk(sk)->icsk_bind2_hash = tb2; + spin_unlock(&head2->lock); + + spin_unlock_bh(&head->lock); + + if (tb2 != new_tb2) + kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2); + + return 0; +} + +int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family) +{ + return __inet_bhash2_update_saddr(sk, saddr, family, false); +} +EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr); + +void inet_bhash2_reset_saddr(struct sock *sk) +{ + if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) + __inet_bhash2_update_saddr(sk, NULL, 0, true); +} +EXPORT_SYMBOL_GPL(inet_bhash2_reset_saddr); + +/* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm + * Note that we use 32bit integers (vs RFC 'short integers') + * because 2^16 is not a multiple of num_ephemeral and this + * property might be used by clever attacker. + * + * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, though + * attacks were since demonstrated, thus we use 65536 by default instead + * to really give more isolation and privacy, at the expense of 256kB + * of kernel memory. 
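+ * In __inet_hash_connect() below this boils down to (illustrative):
+ *   index  = port_offset & (INET_TABLE_PERTURB_SIZE - 1);
+ *   offset = (table_perturb[index] + (port_offset >> 32)) % remaining;
+ * with offset then forced even, so the first pass probes only ports of
+ * the same parity as the bottom of the range while inet_csk_get_port()
+ * takes the opposite parity.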
+ */ +#define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER) +static u32 *table_perturb; + +int __inet_hash_connect(struct inet_timewait_death_row *death_row, + struct sock *sk, u64 port_offset, + int (*check_established)(struct inet_timewait_death_row *, + struct sock *, __u16, struct inet_timewait_sock **)) +{ + struct inet_hashinfo *hinfo = death_row->hashinfo; + struct inet_bind_hashbucket *head, *head2; + struct inet_timewait_sock *tw = NULL; + int port = inet_sk(sk)->inet_num; + struct net *net = sock_net(sk); + struct inet_bind2_bucket *tb2; + struct inet_bind_bucket *tb; + bool tb_created = false; + u32 remaining, offset; + int ret, i, low, high; + int l3mdev; + u32 index; + + if (port) { + local_bh_disable(); + ret = check_established(death_row, sk, port, NULL); + local_bh_enable(); + return ret; + } + + l3mdev = inet_sk_bound_l3mdev(sk); + + inet_sk_get_local_port_range(sk, &low, &high); + high++; /* [32768, 60999] -> [32768, 61000[ */ + remaining = high - low; + if (likely(remaining > 1)) + remaining &= ~1U; + + get_random_sleepable_once(table_perturb, + INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb)); + index = port_offset & (INET_TABLE_PERTURB_SIZE - 1); + + offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32); + offset %= remaining; + + /* In first pass we try ports of @low parity. + * inet_csk_get_port() does the opposite choice. + */ + offset &= ~1U; +other_parity_scan: + port = low + offset; + for (i = 0; i < remaining; i += 2, port += 2) { + if (unlikely(port >= high)) + port -= remaining; + if (inet_is_local_reserved_port(net, port)) + continue; + head = &hinfo->bhash[inet_bhashfn(net, port, + hinfo->bhash_size)]; + spin_lock_bh(&head->lock); + + /* Does not bother with rcv_saddr checks, because + * the established check is already unique enough. + */ + inet_bind_bucket_for_each(tb, &head->chain) { + if (inet_bind_bucket_match(tb, net, port, l3mdev)) { + if (tb->fastreuse >= 0 || + tb->fastreuseport >= 0) + goto next_port; + WARN_ON(hlist_empty(&tb->owners)); + if (!check_established(death_row, sk, + port, &tw)) + goto ok; + goto next_port; + } + } + + tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, + net, head, port, l3mdev); + if (!tb) { + spin_unlock_bh(&head->lock); + return -ENOMEM; + } + tb_created = true; + tb->fastreuse = -1; + tb->fastreuseport = -1; + goto ok; +next_port: + spin_unlock_bh(&head->lock); + cond_resched(); + } + + offset++; + if ((offset & 1) && remaining > 1) + goto other_parity_scan; + + return -EADDRNOTAVAIL; + +ok: + /* Find the corresponding tb2 bucket since we need to + * add the socket to the bhash2 table as well + */ + head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); + spin_lock(&head2->lock); + + tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); + if (!tb2) { + tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net, + head2, port, l3mdev, sk); + if (!tb2) + goto error; + } + + /* Here we want to add a little bit of randomness to the next source + * port that will be chosen. We use a max() with a random here so that + * on low contention the randomness is maximal and on high contention + * it may be inexistent. 
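+ * Concretely, table_perturb[index] is advanced by i + 2, where i is the
+ * larger of the scan offset at which a usable port was found above and
+ * a random even value in [0, 14], so consecutive connects to the same
+ * destination usually skip ahead by a small random amount rather than
+ * taking the next port.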
+ */ + i = max_t(int, i, prandom_u32_max(8) * 2); + WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); + + /* Head lock still held and bh's disabled */ + inet_bind_hash(sk, tb, tb2, port); + + if (sk_unhashed(sk)) { + inet_sk(sk)->inet_sport = htons(port); + inet_ehash_nolisten(sk, (struct sock *)tw, NULL); + } + if (tw) + inet_twsk_bind_unhash(tw, hinfo); + + spin_unlock(&head2->lock); + spin_unlock(&head->lock); + + if (tw) + inet_twsk_deschedule_put(tw); + local_bh_enable(); + return 0; + +error: + spin_unlock(&head2->lock); + if (tb_created) + inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); + spin_unlock_bh(&head->lock); + return -ENOMEM; +} + +/* + * Bind a port for a connect operation and hash it. + */ +int inet_hash_connect(struct inet_timewait_death_row *death_row, + struct sock *sk) +{ + u64 port_offset = 0; + + if (!inet_sk(sk)->inet_num) + port_offset = inet_sk_port_offset(sk); + return __inet_hash_connect(death_row, sk, port_offset, + __inet_check_established); +} +EXPORT_SYMBOL_GPL(inet_hash_connect); + +static void init_hashinfo_lhash2(struct inet_hashinfo *h) +{ + int i; + + for (i = 0; i <= h->lhash2_mask; i++) { + spin_lock_init(&h->lhash2[i].lock); + INIT_HLIST_NULLS_HEAD(&h->lhash2[i].nulls_head, + i + LISTENING_NULLS_BASE); + } +} + +void __init inet_hashinfo2_init(struct inet_hashinfo *h, const char *name, + unsigned long numentries, int scale, + unsigned long low_limit, + unsigned long high_limit) +{ + h->lhash2 = alloc_large_system_hash(name, + sizeof(*h->lhash2), + numentries, + scale, + 0, + NULL, + &h->lhash2_mask, + low_limit, + high_limit); + init_hashinfo_lhash2(h); + + /* this one is used for source ports of outgoing connections */ + table_perturb = alloc_large_system_hash("Table-perturb", + sizeof(*table_perturb), + INET_TABLE_PERTURB_SIZE, + 0, 0, NULL, NULL, + INET_TABLE_PERTURB_SIZE, + INET_TABLE_PERTURB_SIZE); +} + +int inet_hashinfo2_init_mod(struct inet_hashinfo *h) +{ + h->lhash2 = kmalloc_array(INET_LHTABLE_SIZE, sizeof(*h->lhash2), GFP_KERNEL); + if (!h->lhash2) + return -ENOMEM; + + h->lhash2_mask = INET_LHTABLE_SIZE - 1; + /* INET_LHTABLE_SIZE must be a power of 2 */ + BUG_ON(INET_LHTABLE_SIZE & h->lhash2_mask); + + init_hashinfo_lhash2(h); + return 0; +} +EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod); + +int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo) +{ + unsigned int locksz = sizeof(spinlock_t); + unsigned int i, nblocks = 1; + + if (locksz != 0) { + /* allocate 2 cache lines or at least one spinlock per cpu */ + nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U); + nblocks = roundup_pow_of_two(nblocks * num_possible_cpus()); + + /* no more locks than number of hash buckets */ + nblocks = min(nblocks, hashinfo->ehash_mask + 1); + + hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL); + if (!hashinfo->ehash_locks) + return -ENOMEM; + + for (i = 0; i < nblocks; i++) + spin_lock_init(&hashinfo->ehash_locks[i]); + } + hashinfo->ehash_locks_mask = nblocks - 1; + return 0; +} +EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc); + +struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo, + unsigned int ehash_entries) +{ + struct inet_hashinfo *new_hashinfo; + int i; + + new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL); + if (!new_hashinfo) + goto err; + + new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket), + GFP_KERNEL_ACCOUNT); + if (!new_hashinfo->ehash) + goto free_hashinfo; + + new_hashinfo->ehash_mask = ehash_entries - 1; + + if 
(inet_ehash_locks_alloc(new_hashinfo)) + goto free_ehash; + + for (i = 0; i < ehash_entries; i++) + INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i); + + new_hashinfo->pernet = true; + + return new_hashinfo; + +free_ehash: + vfree(new_hashinfo->ehash); +free_hashinfo: + kfree(new_hashinfo); +err: + return NULL; +} +EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc); + +void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) +{ + if (!hashinfo->pernet) + return; + + inet_ehash_locks_free(hashinfo); + vfree(hashinfo->ehash); + kfree(hashinfo); +} +EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free); diff --git a/net/ipv4/inet_timewait_sock.c b/net/ipv4/inet_timewait_sock.c new file mode 100644 index 000000000..1d77d992e --- /dev/null +++ b/net/ipv4/inet_timewait_sock.c @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Generic TIME_WAIT sockets functions + * + * From code orinally in TCP + */ + +#include +#include +#include +#include +#include +#include + + +/** + * inet_twsk_bind_unhash - unhash a timewait socket from bind hash + * @tw: timewait socket + * @hashinfo: hashinfo pointer + * + * unhash a timewait socket from bind hash, if hashed. + * bind hash lock must be held by caller. + * Returns 1 if caller should call inet_twsk_put() after lock release. + */ +void inet_twsk_bind_unhash(struct inet_timewait_sock *tw, + struct inet_hashinfo *hashinfo) +{ + struct inet_bind2_bucket *tb2 = tw->tw_tb2; + struct inet_bind_bucket *tb = tw->tw_tb; + + if (!tb) + return; + + __hlist_del(&tw->tw_bind_node); + tw->tw_tb = NULL; + inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); + + __hlist_del(&tw->tw_bind2_node); + tw->tw_tb2 = NULL; + inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); + + __sock_put((struct sock *)tw); +} + +/* Must be called with locally disabled BHs. */ +static void inet_twsk_kill(struct inet_timewait_sock *tw) +{ + struct inet_hashinfo *hashinfo = tw->tw_dr->hashinfo; + spinlock_t *lock = inet_ehash_lockp(hashinfo, tw->tw_hash); + struct inet_bind_hashbucket *bhead, *bhead2; + + spin_lock(lock); + sk_nulls_del_node_init_rcu((struct sock *)tw); + spin_unlock(lock); + + /* Disassociate with bind bucket. 
*/ + bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), tw->tw_num, + hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, (struct sock *)tw, + twsk_net(tw), tw->tw_num); + + spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); + inet_twsk_bind_unhash(tw, hashinfo); + spin_unlock(&bhead2->lock); + spin_unlock(&bhead->lock); + + refcount_dec(&tw->tw_dr->tw_refcount); + inet_twsk_put(tw); +} + +void inet_twsk_free(struct inet_timewait_sock *tw) +{ + struct module *owner = tw->tw_prot->owner; + twsk_destructor((struct sock *)tw); +#ifdef SOCK_REFCNT_DEBUG + pr_debug("%s timewait_sock %p released\n", tw->tw_prot->name, tw); +#endif + kmem_cache_free(tw->tw_prot->twsk_prot->twsk_slab, tw); + module_put(owner); +} + +void inet_twsk_put(struct inet_timewait_sock *tw) +{ + if (refcount_dec_and_test(&tw->tw_refcnt)) + inet_twsk_free(tw); +} +EXPORT_SYMBOL_GPL(inet_twsk_put); + +static void inet_twsk_add_node_rcu(struct inet_timewait_sock *tw, + struct hlist_nulls_head *list) +{ + hlist_nulls_add_head_rcu(&tw->tw_node, list); +} + +static void inet_twsk_add_bind_node(struct inet_timewait_sock *tw, + struct hlist_head *list) +{ + hlist_add_head(&tw->tw_bind_node, list); +} + +static void inet_twsk_add_bind2_node(struct inet_timewait_sock *tw, + struct hlist_head *list) +{ + hlist_add_head(&tw->tw_bind2_node, list); +} + +/* + * Enter the time wait state. This is called with locally disabled BH. + * Essentially we whip up a timewait bucket, copy the relevant info into it + * from the SK, and mess with hash chains and list linkage. + */ +void inet_twsk_hashdance(struct inet_timewait_sock *tw, struct sock *sk, + struct inet_hashinfo *hashinfo) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_ehash_bucket *ehead = inet_ehash_bucket(hashinfo, sk->sk_hash); + spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); + struct inet_bind_hashbucket *bhead, *bhead2; + + /* Step 1: Put TW into bind hash. Original socket stays there too. + Note, that any socket with inet->num != 0 MUST be bound in + binding cache, even if it is closed. + */ + bhead = &hashinfo->bhash[inet_bhashfn(twsk_net(tw), inet->inet_num, + hashinfo->bhash_size)]; + bhead2 = inet_bhashfn_portaddr(hashinfo, sk, twsk_net(tw), inet->inet_num); + + spin_lock(&bhead->lock); + spin_lock(&bhead2->lock); + + tw->tw_tb = icsk->icsk_bind_hash; + WARN_ON(!icsk->icsk_bind_hash); + inet_twsk_add_bind_node(tw, &tw->tw_tb->owners); + + tw->tw_tb2 = icsk->icsk_bind2_hash; + WARN_ON(!icsk->icsk_bind2_hash); + inet_twsk_add_bind2_node(tw, &tw->tw_tb2->deathrow); + + spin_unlock(&bhead2->lock); + spin_unlock(&bhead->lock); + + spin_lock(lock); + + inet_twsk_add_node_rcu(tw, &ehead->chain); + + /* Step 3: Remove SK from hash chain */ + if (__sk_nulls_del_node_init_rcu(sk)) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + + spin_unlock(lock); + + /* tw_refcnt is set to 3 because we have : + * - one reference for bhash chain. + * - one reference for ehash chain. + * - one reference for timer. + * We can use atomic_set() because prior spin_lock()/spin_unlock() + * committed into memory all tw fields. + * Also note that after this point, we lost our implicit reference + * so we are not allowed to use tw anymore. 
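+ * (The bhash reference is dropped through __sock_put() in
+ * inet_twsk_bind_unhash(), and inet_twsk_kill() releases a further
+ * reference with inet_twsk_put() once the timewait timer fires or
+ * inet_twsk_deschedule_put() cancels it.)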
+ */ + refcount_set(&tw->tw_refcnt, 3); +} +EXPORT_SYMBOL_GPL(inet_twsk_hashdance); + +static void tw_timer_handler(struct timer_list *t) +{ + struct inet_timewait_sock *tw = from_timer(tw, t, tw_timer); + + inet_twsk_kill(tw); +} + +struct inet_timewait_sock *inet_twsk_alloc(const struct sock *sk, + struct inet_timewait_death_row *dr, + const int state) +{ + struct inet_timewait_sock *tw; + + if (refcount_read(&dr->tw_refcount) - 1 >= + READ_ONCE(dr->sysctl_max_tw_buckets)) + return NULL; + + tw = kmem_cache_alloc(sk->sk_prot_creator->twsk_prot->twsk_slab, + GFP_ATOMIC); + if (tw) { + const struct inet_sock *inet = inet_sk(sk); + + tw->tw_dr = dr; + /* Give us an identity. */ + tw->tw_daddr = inet->inet_daddr; + tw->tw_rcv_saddr = inet->inet_rcv_saddr; + tw->tw_bound_dev_if = sk->sk_bound_dev_if; + tw->tw_tos = inet->tos; + tw->tw_num = inet->inet_num; + tw->tw_state = TCP_TIME_WAIT; + tw->tw_substate = state; + tw->tw_sport = inet->inet_sport; + tw->tw_dport = inet->inet_dport; + tw->tw_family = sk->sk_family; + tw->tw_reuse = sk->sk_reuse; + tw->tw_reuseport = sk->sk_reuseport; + tw->tw_hash = sk->sk_hash; + tw->tw_ipv6only = 0; + tw->tw_transparent = inet->transparent; + tw->tw_prot = sk->sk_prot_creator; + atomic64_set(&tw->tw_cookie, atomic64_read(&sk->sk_cookie)); + twsk_net_set(tw, sock_net(sk)); + timer_setup(&tw->tw_timer, tw_timer_handler, TIMER_PINNED); + /* + * Because we use RCU lookups, we should not set tw_refcnt + * to a non null value before everything is setup for this + * timewait socket. + */ + refcount_set(&tw->tw_refcnt, 0); + + __module_get(tw->tw_prot->owner); + } + + return tw; +} +EXPORT_SYMBOL_GPL(inet_twsk_alloc); + +/* These are always called from BH context. See callers in + * tcp_input.c to verify this. + */ + +/* This is for handling early-kills of TIME_WAIT sockets. + * Warning : consume reference. + * Caller should not access tw anymore. + */ +void inet_twsk_deschedule_put(struct inet_timewait_sock *tw) +{ + if (del_timer_sync(&tw->tw_timer)) + inet_twsk_kill(tw); + inet_twsk_put(tw); +} +EXPORT_SYMBOL(inet_twsk_deschedule_put); + +void __inet_twsk_schedule(struct inet_timewait_sock *tw, int timeo, bool rearm) +{ + /* timeout := RTO * 3.5 + * + * 3.5 = 1+2+0.5 to wait for two retransmits. + * + * RATIONALE: if FIN arrived and we entered TIME-WAIT state, + * our ACK acking that FIN can be lost. If N subsequent retransmitted + * FINs (or previous seqments) are lost (probability of such event + * is p^(N+1), where p is probability to lose single packet and + * time to detect the loss is about RTO*(2^N - 1) with exponential + * backoff). Normal timewait length is calculated so, that we + * waited at least for one retransmitted FIN (maximal RTO is 120sec). + * [ BTW Linux. following BSD, violates this requirement waiting + * only for 60sec, we should wait at least for 240 secs. + * Well, 240 consumes too much of resources 8) + * ] + * This interval is not reduced to catch old duplicate and + * responces to our wandering segments living for two MSLs. + * However, if we use PAWS to detect + * old duplicates, we can reduce the interval to bounds required + * by RTO, rather than MSL. So, if peer understands PAWS, we + * kill tw bucket after 3.5*RTO (it is important that this number + * is greater than TS tick!) and detect old duplicates with help + * of PAWS. + */ + + if (!rearm) { + bool kill = timeo <= 4*HZ; + + __NET_INC_STATS(twsk_net(tw), kill ? 
LINUX_MIB_TIMEWAITKILLED : + LINUX_MIB_TIMEWAITED); + BUG_ON(mod_timer(&tw->tw_timer, jiffies + timeo)); + refcount_inc(&tw->tw_dr->tw_refcount); + } else { + mod_timer_pending(&tw->tw_timer, jiffies + timeo); + } +} +EXPORT_SYMBOL_GPL(__inet_twsk_schedule); + +void inet_twsk_purge(struct inet_hashinfo *hashinfo, int family) +{ + struct inet_timewait_sock *tw; + struct sock *sk; + struct hlist_nulls_node *node; + unsigned int slot; + + for (slot = 0; slot <= hashinfo->ehash_mask; slot++) { + struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; +restart_rcu: + cond_resched(); + rcu_read_lock(); +restart: + sk_nulls_for_each_rcu(sk, node, &head->chain) { + if (sk->sk_state != TCP_TIME_WAIT) { + /* A kernel listener socket might not hold refcnt for net, + * so reqsk_timer_handler() could be fired after net is + * freed. Userspace listener and reqsk never exist here. + */ + if (unlikely(sk->sk_state == TCP_NEW_SYN_RECV && + hashinfo->pernet)) { + struct request_sock *req = inet_reqsk(sk); + + inet_csk_reqsk_queue_drop_and_put(req->rsk_listener, req); + } + + continue; + } + + tw = inet_twsk(sk); + if ((tw->tw_family != family) || + refcount_read(&twsk_net(tw)->ns.count)) + continue; + + if (unlikely(!refcount_inc_not_zero(&tw->tw_refcnt))) + continue; + + if (unlikely((tw->tw_family != family) || + refcount_read(&twsk_net(tw)->ns.count))) { + inet_twsk_put(tw); + goto restart; + } + + rcu_read_unlock(); + local_bh_disable(); + inet_twsk_deschedule_put(tw); + local_bh_enable(); + goto restart_rcu; + } + /* If the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (get_nulls_value(node) != slot) + goto restart; + rcu_read_unlock(); + } +} +EXPORT_SYMBOL_GPL(inet_twsk_purge); diff --git a/net/ipv4/inetpeer.c b/net/ipv4/inetpeer.c new file mode 100644 index 000000000..e9fed83e9 --- /dev/null +++ b/net/ipv4/inetpeer.c @@ -0,0 +1,308 @@ +/* + * INETPEER - A storage for permanent information about peers + * + * This source is covered by the GNU GPL, the same as all kernel sources. + * + * Authors: Andrey V. Savochkin + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Theory of operations. + * We keep one entry for each peer IP address. The nodes contains long-living + * information about the peer which doesn't depend on routes. + * + * Nodes are removed only when reference counter goes to 0. + * When it's happened the node may be removed when a sufficient amount of + * time has been passed since its last use. The less-recently-used entry can + * also be removed if the pool is overloaded i.e. if the total amount of + * entries is greater-or-equal than the threshold. + * + * Node pool is organised as an RB tree. + * Such an implementation has been chosen not just for fun. It's a way to + * prevent easy and efficient DoS attacks by creating hash collisions. A huge + * amount of long living nodes in a single hash slot would significantly delay + * lookups performed with disabled BHs. + * + * Serialisation issues. + * 1. Nodes may appear in the tree only with the pool lock held. + * 2. Nodes may disappear from the tree only with the pool lock held + * AND reference count being 0. + * 3. Global variable peer_total is modified under the pool lock. + * 4. 
struct inet_peer fields modification: + * rb_node: pool lock + * refcnt: atomically against modifications on other CPU; + * usually under some other lock to prevent node disappearing + * daddr: unchangeable + */ + +static struct kmem_cache *peer_cachep __ro_after_init; + +void inet_peer_base_init(struct inet_peer_base *bp) +{ + bp->rb_root = RB_ROOT; + seqlock_init(&bp->lock); + bp->total = 0; +} +EXPORT_SYMBOL_GPL(inet_peer_base_init); + +#define PEER_MAX_GC 32 + +/* Exported for sysctl_net_ipv4. */ +int inet_peer_threshold __read_mostly; /* start to throw entries more + * aggressively at this stage */ +int inet_peer_minttl __read_mostly = 120 * HZ; /* TTL under high load: 120 sec */ +int inet_peer_maxttl __read_mostly = 10 * 60 * HZ; /* usual time to live: 10 min */ + +/* Called from ip_output.c:ip_init */ +void __init inet_initpeers(void) +{ + u64 nr_entries; + + /* 1% of physical memory */ + nr_entries = div64_ul((u64)totalram_pages() << PAGE_SHIFT, + 100 * L1_CACHE_ALIGN(sizeof(struct inet_peer))); + + inet_peer_threshold = clamp_val(nr_entries, 4096, 65536 + 128); + + peer_cachep = kmem_cache_create("inet_peer_cache", + sizeof(struct inet_peer), + 0, SLAB_HWCACHE_ALIGN | SLAB_PANIC, + NULL); +} + +/* Called with rcu_read_lock() or base->lock held */ +static struct inet_peer *lookup(const struct inetpeer_addr *daddr, + struct inet_peer_base *base, + unsigned int seq, + struct inet_peer *gc_stack[], + unsigned int *gc_cnt, + struct rb_node **parent_p, + struct rb_node ***pp_p) +{ + struct rb_node **pp, *parent, *next; + struct inet_peer *p; + + pp = &base->rb_root.rb_node; + parent = NULL; + while (1) { + int cmp; + + next = rcu_dereference_raw(*pp); + if (!next) + break; + parent = next; + p = rb_entry(parent, struct inet_peer, rb_node); + cmp = inetpeer_addr_cmp(daddr, &p->daddr); + if (cmp == 0) { + if (!refcount_inc_not_zero(&p->refcnt)) + break; + return p; + } + if (gc_stack) { + if (*gc_cnt < PEER_MAX_GC) + gc_stack[(*gc_cnt)++] = p; + } else if (unlikely(read_seqretry(&base->lock, seq))) { + break; + } + if (cmp == -1) + pp = &next->rb_left; + else + pp = &next->rb_right; + } + *parent_p = parent; + *pp_p = pp; + return NULL; +} + +static void inetpeer_free_rcu(struct rcu_head *head) +{ + kmem_cache_free(peer_cachep, container_of(head, struct inet_peer, rcu)); +} + +/* perform garbage collect on all items stacked during a lookup */ +static void inet_peer_gc(struct inet_peer_base *base, + struct inet_peer *gc_stack[], + unsigned int gc_cnt) +{ + int peer_threshold, peer_maxttl, peer_minttl; + struct inet_peer *p; + __u32 delta, ttl; + int i; + + peer_threshold = READ_ONCE(inet_peer_threshold); + peer_maxttl = READ_ONCE(inet_peer_maxttl); + peer_minttl = READ_ONCE(inet_peer_minttl); + + if (base->total >= peer_threshold) + ttl = 0; /* be aggressive */ + else + ttl = peer_maxttl - (peer_maxttl - peer_minttl) / HZ * + base->total / peer_threshold * HZ; + for (i = 0; i < gc_cnt; i++) { + p = gc_stack[i]; + + /* The READ_ONCE() pairs with the WRITE_ONCE() + * in inet_putpeer() + */ + delta = (__u32)jiffies - READ_ONCE(p->dtime); + + if (delta < ttl || !refcount_dec_if_one(&p->refcnt)) + gc_stack[i] = NULL; + } + for (i = 0; i < gc_cnt; i++) { + p = gc_stack[i]; + if (p) { + rb_erase(&p->rb_node, &base->rb_root); + base->total--; + call_rcu(&p->rcu, inetpeer_free_rcu); + } + } +} + +struct inet_peer *inet_getpeer(struct inet_peer_base *base, + const struct inetpeer_addr *daddr, + int create) +{ + struct inet_peer *p, *gc_stack[PEER_MAX_GC]; + struct rb_node **pp, *parent; + unsigned 
int gc_cnt, seq; + int invalidated; + + /* Attempt a lockless lookup first. + * Because of a concurrent writer, we might not find an existing entry. + */ + rcu_read_lock(); + seq = read_seqbegin(&base->lock); + p = lookup(daddr, base, seq, NULL, &gc_cnt, &parent, &pp); + invalidated = read_seqretry(&base->lock, seq); + rcu_read_unlock(); + + if (p) + return p; + + /* If no writer did a change during our lookup, we can return early. */ + if (!create && !invalidated) + return NULL; + + /* retry an exact lookup, taking the lock before. + * At least, nodes should be hot in our cache. + */ + parent = NULL; + write_seqlock_bh(&base->lock); + + gc_cnt = 0; + p = lookup(daddr, base, seq, gc_stack, &gc_cnt, &parent, &pp); + if (!p && create) { + p = kmem_cache_alloc(peer_cachep, GFP_ATOMIC); + if (p) { + p->daddr = *daddr; + p->dtime = (__u32)jiffies; + refcount_set(&p->refcnt, 2); + atomic_set(&p->rid, 0); + p->metrics[RTAX_LOCK-1] = INETPEER_METRICS_NEW; + p->rate_tokens = 0; + p->n_redirects = 0; + /* 60*HZ is arbitrary, but chosen enough high so that the first + * calculation of tokens is at its maximum. + */ + p->rate_last = jiffies - 60*HZ; + + rb_link_node(&p->rb_node, parent, pp); + rb_insert_color(&p->rb_node, &base->rb_root); + base->total++; + } + } + if (gc_cnt) + inet_peer_gc(base, gc_stack, gc_cnt); + write_sequnlock_bh(&base->lock); + + return p; +} +EXPORT_SYMBOL_GPL(inet_getpeer); + +void inet_putpeer(struct inet_peer *p) +{ + /* The WRITE_ONCE() pairs with itself (we run lockless) + * and the READ_ONCE() in inet_peer_gc() + */ + WRITE_ONCE(p->dtime, (__u32)jiffies); + + if (refcount_dec_and_test(&p->refcnt)) + call_rcu(&p->rcu, inetpeer_free_rcu); +} +EXPORT_SYMBOL_GPL(inet_putpeer); + +/* + * Check transmit rate limitation for given message. + * The rate information is held in the inet_peer entries now. + * This function is generic and could be used for other purposes + * too. It uses a Token bucket filter as suggested by Alexey Kuznetsov. + * + * Note that the same inet_peer fields are modified by functions in + * route.c too, but these work for packet destinations while xrlim_allow + * works for icmp destinations. This means the rate limiting information + * for one "ip object" is shared - and these ICMPs are twice limited: + * by source and by destination. + * + * RFC 1812: 4.3.2.8 SHOULD be able to limit error message rate + * SHOULD allow setting of rate limits + * + * Shared between ICMPv4 and ICMPv6. 
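+ * For example (with illustrative numbers): at timeout = HZ a peer earns
+ * one token per elapsed jiffy, capped at XRLIM_BURST_FACTOR * timeout,
+ * i.e. a burst of six messages; every message that is allowed consumes
+ * 'timeout' worth of tokens, so the sustained rate converges to one
+ * message per timeout interval.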
+ */ +#define XRLIM_BURST_FACTOR 6 +bool inet_peer_xrlim_allow(struct inet_peer *peer, int timeout) +{ + unsigned long now, token; + bool rc = false; + + if (!peer) + return true; + + token = peer->rate_tokens; + now = jiffies; + token += now - peer->rate_last; + peer->rate_last = now; + if (token > XRLIM_BURST_FACTOR * timeout) + token = XRLIM_BURST_FACTOR * timeout; + if (token >= timeout) { + token -= timeout; + rc = true; + } + peer->rate_tokens = token; + return rc; +} +EXPORT_SYMBOL(inet_peer_xrlim_allow); + +void inetpeer_invalidate_tree(struct inet_peer_base *base) +{ + struct rb_node *p = rb_first(&base->rb_root); + + while (p) { + struct inet_peer *peer = rb_entry(p, struct inet_peer, rb_node); + + p = rb_next(p); + rb_erase(&peer->rb_node, &base->rb_root); + inet_putpeer(peer); + cond_resched(); + } + + base->total = 0; +} +EXPORT_SYMBOL(inetpeer_invalidate_tree); diff --git a/net/ipv4/ip_forward.c b/net/ipv4/ip_forward.c new file mode 100644 index 000000000..e18931a6d --- /dev/null +++ b/net/ipv4/ip_forward.c @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The IP forwarding functionality. + * + * Authors: see ip.c + * + * Fixes: + * Many : Split from ip.c , see ip_input.c for + * history. + * Dave Gregorich : NULL ip_rt_put fix for multicast + * routing. + * Jos Vos : Add call_out_firewall before sending, + * use output device for accounting. + * Jos Vos : Call forward firewall after routing + * (always use output device). + * Mike McLagan : Routing by source + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static bool ip_exceeds_mtu(const struct sk_buff *skb, unsigned int mtu) +{ + if (skb->len <= mtu) + return false; + + if (unlikely((ip_hdr(skb)->frag_off & htons(IP_DF)) == 0)) + return false; + + /* original fragment exceeds mtu and DF is set */ + if (unlikely(IPCB(skb)->frag_max_size > mtu)) + return true; + + if (skb->ignore_df) + return false; + + if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) + return false; + + return true; +} + + +static int ip_forward_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct ip_options *opt = &(IPCB(skb)->opt); + + __IP_INC_STATS(net, IPSTATS_MIB_OUTFORWDATAGRAMS); + __IP_ADD_STATS(net, IPSTATS_MIB_OUTOCTETS, skb->len); + +#ifdef CONFIG_NET_SWITCHDEV + if (skb->offload_l3_fwd_mark) { + consume_skb(skb); + return 0; + } +#endif + + if (unlikely(opt->optlen)) + ip_forward_options(skb); + + skb_clear_tstamp(skb); + return dst_output(net, sk, skb); +} + +int ip_forward(struct sk_buff *skb) +{ + u32 mtu; + struct iphdr *iph; /* Our header */ + struct rtable *rt; /* Route we use */ + struct ip_options *opt = &(IPCB(skb)->opt); + struct net *net; + SKB_DR(reason); + + /* that should never happen */ + if (skb->pkt_type != PACKET_HOST) + goto drop; + + if (unlikely(skb->sk)) + goto drop; + + if (skb_warn_if_lro(skb)) + goto drop; + + if (!xfrm4_policy_check(NULL, XFRM_POLICY_FWD, skb)) { + SKB_DR_SET(reason, XFRM_POLICY); + goto drop; + } + + if (IPCB(skb)->opt.router_alert && ip_call_ra_chain(skb)) + return NET_RX_SUCCESS; + + skb_forward_csum(skb); + net = dev_net(skb->dev); + + /* + * According to the RFC, we must first decrease the TTL field. 
If + * that reaches zero, we must reply an ICMP control message telling + * that the packet's lifetime expired. + */ + if (ip_hdr(skb)->ttl <= 1) + goto too_many_hops; + + if (!xfrm4_route_forward(skb)) { + SKB_DR_SET(reason, XFRM_POLICY); + goto drop; + } + + rt = skb_rtable(skb); + + if (opt->is_strictroute && rt->rt_uses_gateway) + goto sr_failed; + + IPCB(skb)->flags |= IPSKB_FORWARDED; + mtu = ip_dst_mtu_maybe_forward(&rt->dst, true); + if (ip_exceeds_mtu(skb, mtu)) { + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + SKB_DR_SET(reason, PKT_TOO_BIG); + goto drop; + } + + /* We are about to mangle packet. Copy it! */ + if (skb_cow(skb, LL_RESERVED_SPACE(rt->dst.dev)+rt->dst.header_len)) + goto drop; + iph = ip_hdr(skb); + + /* Decrease ttl after skb cow done */ + ip_decrease_ttl(iph); + + /* + * We now generate an ICMP HOST REDIRECT giving the route + * we calculated. + */ + if (IPCB(skb)->flags & IPSKB_DOREDIRECT && !opt->srr && + !skb_sec_path(skb)) + ip_rt_send_redirect(skb); + + if (READ_ONCE(net->ipv4.sysctl_ip_fwd_update_priority)) + skb->priority = rt_tos2priority(iph->tos); + + return NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, + net, NULL, skb, skb->dev, rt->dst.dev, + ip_forward_finish); + +sr_failed: + /* + * Strict routing permits no gatewaying + */ + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_SR_FAILED, 0); + goto drop; + +too_many_hops: + /* Tell the sender its packet died... */ + __IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); + icmp_send(skb, ICMP_TIME_EXCEEDED, ICMP_EXC_TTL, 0); + SKB_DR_SET(reason, IP_INHDR); +drop: + kfree_skb_reason(skb, reason); + return NET_RX_DROP; +} diff --git a/net/ipv4/ip_fragment.c b/net/ipv4/ip_fragment.c new file mode 100644 index 000000000..fb1535698 --- /dev/null +++ b/net/ipv4/ip_fragment.c @@ -0,0 +1,753 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The IP fragmentation functionality. + * + * Authors: Fred N. van Kempen + * Alan Cox + * + * Fixes: + * Alan Cox : Split from ip.c , see ip_input.c for history. + * David S. Miller : Begin massive cleanup... + * Andi Kleen : Add sysctls. + * xxxx : Overlapfrag bug. + * Ultima : ip_expire() kernel panic. + * Bill Hawes : Frag accounting and evictor fixes. + * John McDonald : 0 length frag bug. + * Alexey Kuznetsov: SMP races, threading, cleanup. + * Patrick McHardy : LRU queue of frag heads for evictor. + */ + +#define pr_fmt(fmt) "IPv4: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* NOTE. Logic of IP defragmentation is parallel to corresponding IPv6 + * code now. If you change something here, _PLEASE_ update ipv6/reassembly.c + * as well. Or notify me, at least. --ANK + */ +static const char ip_frag_cache_name[] = "ip4-frags"; + +/* Describe an entry in the "incomplete datagrams" queue. 
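+ * One ipq is kept per in-flight datagram; it embeds the generic
+ * inet_frag_queue and is keyed by the frag_v4_compare_key set up in
+ * ip4_frag_init() below.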
*/ +struct ipq { + struct inet_frag_queue q; + + u8 ecn; /* RFC3168 support */ + u16 max_df_size; /* largest frag with DF set seen */ + int iif; + unsigned int rid; + struct inet_peer *peer; +}; + +static u8 ip4_frag_ecn(u8 tos) +{ + return 1 << (tos & INET_ECN_MASK); +} + +static struct inet_frags ip4_frags; + +static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev); + + +static void ip4_frag_init(struct inet_frag_queue *q, const void *a) +{ + struct ipq *qp = container_of(q, struct ipq, q); + struct net *net = q->fqdir->net; + + const struct frag_v4_compare_key *key = a; + + q->key.v4 = *key; + qp->ecn = 0; + qp->peer = q->fqdir->max_dist ? + inet_getpeer_v4(net->ipv4.peers, key->saddr, key->vif, 1) : + NULL; +} + +static void ip4_frag_free(struct inet_frag_queue *q) +{ + struct ipq *qp; + + qp = container_of(q, struct ipq, q); + if (qp->peer) + inet_putpeer(qp->peer); +} + + +/* Destruction primitives. */ + +static void ipq_put(struct ipq *ipq) +{ + inet_frag_put(&ipq->q); +} + +/* Kill ipq entry. It is not destroyed immediately, + * because caller (and someone more) holds reference count. + */ +static void ipq_kill(struct ipq *ipq) +{ + inet_frag_kill(&ipq->q); +} + +static bool frag_expire_skip_icmp(u32 user) +{ + return user == IP_DEFRAG_AF_PACKET || + ip_defrag_user_in_between(user, IP_DEFRAG_CONNTRACK_IN, + __IP_DEFRAG_CONNTRACK_IN_END) || + ip_defrag_user_in_between(user, IP_DEFRAG_CONNTRACK_BRIDGE_IN, + __IP_DEFRAG_CONNTRACK_BRIDGE_IN); +} + +/* + * Oops, a fragment queue timed out. Kill it and send an ICMP reply. + */ +static void ip_expire(struct timer_list *t) +{ + struct inet_frag_queue *frag = from_timer(frag, t, timer); + const struct iphdr *iph; + struct sk_buff *head = NULL; + struct net *net; + struct ipq *qp; + int err; + + qp = container_of(frag, struct ipq, q); + net = qp->q.fqdir->net; + + rcu_read_lock(); + + /* Paired with WRITE_ONCE() in fqdir_pre_exit(). */ + if (READ_ONCE(qp->q.fqdir->dead)) + goto out_rcu_unlock; + + spin_lock(&qp->q.lock); + + if (qp->q.flags & INET_FRAG_COMPLETE) + goto out; + + ipq_kill(qp); + __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); + __IP_INC_STATS(net, IPSTATS_MIB_REASMTIMEOUT); + + if (!(qp->q.flags & INET_FRAG_FIRST_IN)) + goto out; + + /* sk_buff::dev and sk_buff::rbnode are unionized. So we + * pull the head out of the tree in order to be able to + * deal with head->dev. + */ + head = inet_frag_pull_head(&qp->q); + if (!head) + goto out; + head->dev = dev_get_by_index_rcu(net, qp->iif); + if (!head->dev) + goto out; + + + /* skb has no dst, perform route lookup again */ + iph = ip_hdr(head); + err = ip_route_input_noref(head, iph->daddr, iph->saddr, + iph->tos, head->dev); + if (err) + goto out; + + /* Only an end host needs to send an ICMP + * "Fragment Reassembly Timeout" message, per RFC792. + */ + if (frag_expire_skip_icmp(qp->q.key.v4.user) && + (skb_rtable(head)->rt_type != RTN_LOCAL)) + goto out; + + spin_unlock(&qp->q.lock); + icmp_send(head, ICMP_TIME_EXCEEDED, ICMP_EXC_FRAGTIME, 0); + goto out_rcu_unlock; + +out: + spin_unlock(&qp->q.lock); +out_rcu_unlock: + rcu_read_unlock(); + kfree_skb(head); + ipq_put(qp); +} + +/* Find the correct entry in the "incomplete datagrams" queue for + * this IP datagram, and create new one, if nothing is found. 
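+ * Two fragments are considered part of the same datagram only if they
+ * agree on (saddr, daddr, IP id, protocol) together with the
+ * defragmentation user and the vif value supplied by the caller,
+ * i.e. the frag_v4_compare_key filled in below.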
+ */ +static struct ipq *ip_find(struct net *net, struct iphdr *iph, + u32 user, int vif) +{ + struct frag_v4_compare_key key = { + .saddr = iph->saddr, + .daddr = iph->daddr, + .user = user, + .vif = vif, + .id = iph->id, + .protocol = iph->protocol, + }; + struct inet_frag_queue *q; + + q = inet_frag_find(net->ipv4.fqdir, &key); + if (!q) + return NULL; + + return container_of(q, struct ipq, q); +} + +/* Is the fragment too far ahead to be part of ipq? */ +static int ip_frag_too_far(struct ipq *qp) +{ + struct inet_peer *peer = qp->peer; + unsigned int max = qp->q.fqdir->max_dist; + unsigned int start, end; + + int rc; + + if (!peer || !max) + return 0; + + start = qp->rid; + end = atomic_inc_return(&peer->rid); + qp->rid = end; + + rc = qp->q.fragments_tail && (end - start) > max; + + if (rc) + __IP_INC_STATS(qp->q.fqdir->net, IPSTATS_MIB_REASMFAILS); + + return rc; +} + +static int ip_frag_reinit(struct ipq *qp) +{ + unsigned int sum_truesize = 0; + + if (!mod_timer(&qp->q.timer, jiffies + qp->q.fqdir->timeout)) { + refcount_inc(&qp->q.refcnt); + return -ETIMEDOUT; + } + + sum_truesize = inet_frag_rbtree_purge(&qp->q.rb_fragments); + sub_frag_mem_limit(qp->q.fqdir, sum_truesize); + + qp->q.flags = 0; + qp->q.len = 0; + qp->q.meat = 0; + qp->q.rb_fragments = RB_ROOT; + qp->q.fragments_tail = NULL; + qp->q.last_run_head = NULL; + qp->iif = 0; + qp->ecn = 0; + + return 0; +} + +/* Add new segment to existing queue. */ +static int ip_frag_queue(struct ipq *qp, struct sk_buff *skb) +{ + struct net *net = qp->q.fqdir->net; + int ihl, end, flags, offset; + struct sk_buff *prev_tail; + struct net_device *dev; + unsigned int fragsize; + int err = -ENOENT; + u8 ecn; + + if (qp->q.flags & INET_FRAG_COMPLETE) + goto err; + + if (!(IPCB(skb)->flags & IPSKB_FRAG_COMPLETE) && + unlikely(ip_frag_too_far(qp)) && + unlikely(err = ip_frag_reinit(qp))) { + ipq_kill(qp); + goto err; + } + + ecn = ip4_frag_ecn(ip_hdr(skb)->tos); + offset = ntohs(ip_hdr(skb)->frag_off); + flags = offset & ~IP_OFFSET; + offset &= IP_OFFSET; + offset <<= 3; /* offset is in 8-byte chunks */ + ihl = ip_hdrlen(skb); + + /* Determine the position of this fragment. */ + end = offset + skb->len - skb_network_offset(skb) - ihl; + err = -EINVAL; + + /* Is this the final fragment? */ + if ((flags & IP_MF) == 0) { + /* If we already have some bits beyond end + * or have different end, the segment is corrupted. + */ + if (end < qp->q.len || + ((qp->q.flags & INET_FRAG_LAST_IN) && end != qp->q.len)) + goto discard_qp; + qp->q.flags |= INET_FRAG_LAST_IN; + qp->q.len = end; + } else { + if (end&7) { + end &= ~7; + if (skb->ip_summed != CHECKSUM_UNNECESSARY) + skb->ip_summed = CHECKSUM_NONE; + } + if (end > qp->q.len) { + /* Some bits beyond end -> corruption. */ + if (qp->q.flags & INET_FRAG_LAST_IN) + goto discard_qp; + qp->q.len = end; + } + } + if (end == offset) + goto discard_qp; + + err = -ENOMEM; + if (!pskb_pull(skb, skb_network_offset(skb) + ihl)) + goto discard_qp; + + err = pskb_trim_rcsum(skb, end - offset); + if (err) + goto discard_qp; + + /* Note : skb->rbnode and skb->dev share the same location. 
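+ * The input device is therefore saved in a local variable before the
+ * skb is linked into the fragment rbtree; barrier() keeps the compiler
+ * from re-reading skb->dev through the reused union member afterwards.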
*/ + dev = skb->dev; + /* Makes sure compiler wont do silly aliasing games */ + barrier(); + + prev_tail = qp->q.fragments_tail; + err = inet_frag_queue_insert(&qp->q, skb, offset, end); + if (err) + goto insert_error; + + if (dev) + qp->iif = dev->ifindex; + + qp->q.stamp = skb->tstamp; + qp->q.mono_delivery_time = skb->mono_delivery_time; + qp->q.meat += skb->len; + qp->ecn |= ecn; + add_frag_mem_limit(qp->q.fqdir, skb->truesize); + if (offset == 0) + qp->q.flags |= INET_FRAG_FIRST_IN; + + fragsize = skb->len + ihl; + + if (fragsize > qp->q.max_size) + qp->q.max_size = fragsize; + + if (ip_hdr(skb)->frag_off & htons(IP_DF) && + fragsize > qp->max_df_size) + qp->max_df_size = fragsize; + + if (qp->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && + qp->q.meat == qp->q.len) { + unsigned long orefdst = skb->_skb_refdst; + + skb->_skb_refdst = 0UL; + err = ip_frag_reasm(qp, skb, prev_tail, dev); + skb->_skb_refdst = orefdst; + if (err) + inet_frag_kill(&qp->q); + return err; + } + + skb_dst_drop(skb); + return -EINPROGRESS; + +insert_error: + if (err == IPFRAG_DUP) { + kfree_skb(skb); + return -EINVAL; + } + err = -EINVAL; + __IP_INC_STATS(net, IPSTATS_MIB_REASM_OVERLAPS); +discard_qp: + inet_frag_kill(&qp->q); + __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); +err: + kfree_skb(skb); + return err; +} + +static bool ip_frag_coalesce_ok(const struct ipq *qp) +{ + return qp->q.key.v4.user == IP_DEFRAG_LOCAL_DELIVER; +} + +/* Build a new IP datagram from all its fragments. */ +static int ip_frag_reasm(struct ipq *qp, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev) +{ + struct net *net = qp->q.fqdir->net; + struct iphdr *iph; + void *reasm_data; + int len, err; + u8 ecn; + + ipq_kill(qp); + + ecn = ip_frag_ecn_table[qp->ecn]; + if (unlikely(ecn == 0xff)) { + err = -EINVAL; + goto out_fail; + } + + /* Make the one we just received the head. */ + reasm_data = inet_frag_reasm_prepare(&qp->q, skb, prev_tail); + if (!reasm_data) + goto out_nomem; + + len = ip_hdrlen(skb) + qp->q.len; + err = -E2BIG; + if (len > 65535) + goto out_oversize; + + inet_frag_reasm_finish(&qp->q, skb, reasm_data, + ip_frag_coalesce_ok(qp)); + + skb->dev = dev; + IPCB(skb)->frag_max_size = max(qp->max_df_size, qp->q.max_size); + + iph = ip_hdr(skb); + iph->tot_len = htons(len); + iph->tos |= ecn; + + /* When we set IP_DF on a refragmented skb we must also force a + * call to ip_fragment to avoid forwarding a DF-skb of size s while + * original sender only sent fragments of size f (where f < s). + * + * We only set DF/IPSKB_FRAG_PMTU if such DF fragment was the largest + * frag seen to avoid sending tiny DF-fragments in case skb was built + * from one very small df-fragment and one large non-df frag. + */ + if (qp->max_df_size == qp->q.max_size) { + IPCB(skb)->flags |= IPSKB_FRAG_PMTU; + iph->frag_off = htons(IP_DF); + } else { + iph->frag_off = 0; + } + + ip_send_check(iph); + + __IP_INC_STATS(net, IPSTATS_MIB_REASMOKS); + qp->q.rb_fragments = RB_ROOT; + qp->q.fragments_tail = NULL; + qp->q.last_run_head = NULL; + return 0; + +out_nomem: + net_dbg_ratelimited("queue_glue: no memory for gluing queue %p\n", qp); + err = -ENOMEM; + goto out_fail; +out_oversize: + net_info_ratelimited("Oversized IP packet from %pI4\n", &qp->q.key.v4.saddr); +out_fail: + __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); + return err; +} + +/* Process an incoming IP datagram fragment. */ +int ip_defrag(struct net *net, struct sk_buff *skb, u32 user) +{ + struct net_device *dev = skb->dev ? 
: skb_dst(skb)->dev; + int vif = l3mdev_master_ifindex_rcu(dev); + struct ipq *qp; + + __IP_INC_STATS(net, IPSTATS_MIB_REASMREQDS); + skb_orphan(skb); + + /* Lookup (or create) queue header */ + qp = ip_find(net, ip_hdr(skb), user, vif); + if (qp) { + int ret; + + spin_lock(&qp->q.lock); + + ret = ip_frag_queue(qp, skb); + + spin_unlock(&qp->q.lock); + ipq_put(qp); + return ret; + } + + __IP_INC_STATS(net, IPSTATS_MIB_REASMFAILS); + kfree_skb(skb); + return -ENOMEM; +} +EXPORT_SYMBOL(ip_defrag); + +struct sk_buff *ip_check_defrag(struct net *net, struct sk_buff *skb, u32 user) +{ + struct iphdr iph; + int netoff; + u32 len; + + if (skb->protocol != htons(ETH_P_IP)) + return skb; + + netoff = skb_network_offset(skb); + + if (skb_copy_bits(skb, netoff, &iph, sizeof(iph)) < 0) + return skb; + + if (iph.ihl < 5 || iph.version != 4) + return skb; + + len = ntohs(iph.tot_len); + if (skb->len < netoff + len || len < (iph.ihl * 4)) + return skb; + + if (ip_is_fragment(&iph)) { + skb = skb_share_check(skb, GFP_ATOMIC); + if (skb) { + if (!pskb_may_pull(skb, netoff + iph.ihl * 4)) { + kfree_skb(skb); + return NULL; + } + if (pskb_trim_rcsum(skb, netoff + len)) { + kfree_skb(skb); + return NULL; + } + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + if (ip_defrag(net, skb, user)) + return NULL; + skb_clear_hash(skb); + } + } + return skb; +} +EXPORT_SYMBOL(ip_check_defrag); + +#ifdef CONFIG_SYSCTL +static int dist_min; + +static struct ctl_table ip4_frags_ns_ctl_table[] = { + { + .procname = "ipfrag_high_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "ipfrag_low_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "ipfrag_time", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "ipfrag_max_dist", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &dist_min, + }, + { } +}; + +/* secret interval has been deprecated */ +static int ip4_frags_secret_interval_unused; +static struct ctl_table ip4_frags_ctl_table[] = { + { + .procname = "ipfrag_secret_interval", + .data = &ip4_frags_secret_interval_unused, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; + +static int __net_init ip4_frags_ns_ctl_register(struct net *net) +{ + struct ctl_table *table; + struct ctl_table_header *hdr; + + table = ip4_frags_ns_ctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(ip4_frags_ns_ctl_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + } + table[0].data = &net->ipv4.fqdir->high_thresh; + table[0].extra1 = &net->ipv4.fqdir->low_thresh; + table[1].data = &net->ipv4.fqdir->low_thresh; + table[1].extra2 = &net->ipv4.fqdir->high_thresh; + table[2].data = &net->ipv4.fqdir->timeout; + table[3].data = &net->ipv4.fqdir->max_dist; + + hdr = register_net_sysctl(net, "net/ipv4", table); + if (!hdr) + goto err_reg; + + net->ipv4.frags_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void __net_exit ip4_frags_ns_ctl_unregister(struct net *net) +{ + struct ctl_table *table; + + table = net->ipv4.frags_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv4.frags_hdr); + kfree(table); +} + +static void __init ip4_frags_ctl_register(void) +{ + register_net_sysctl(&init_net, "net/ipv4", ip4_frags_ctl_table); +} +#else +static int 
ip4_frags_ns_ctl_register(struct net *net) +{ + return 0; +} + +static void ip4_frags_ns_ctl_unregister(struct net *net) +{ +} + +static void __init ip4_frags_ctl_register(void) +{ +} +#endif + +static int __net_init ipv4_frags_init_net(struct net *net) +{ + int res; + + res = fqdir_init(&net->ipv4.fqdir, &ip4_frags, net); + if (res < 0) + return res; + /* Fragment cache limits. + * + * The fragment memory accounting code, (tries to) account for + * the real memory usage, by measuring both the size of frag + * queue struct (inet_frag_queue (ipv4:ipq/ipv6:frag_queue)) + * and the SKB's truesize. + * + * A 64K fragment consumes 129736 bytes (44*2944)+200 + * (1500 truesize == 2944, sizeof(struct ipq) == 200) + * + * We will commit 4MB at one time. Should we cross that limit + * we will prune down to 3MB, making room for approx 8 big 64K + * fragments 8x128k. + */ + net->ipv4.fqdir->high_thresh = 4 * 1024 * 1024; + net->ipv4.fqdir->low_thresh = 3 * 1024 * 1024; + /* + * Important NOTE! Fragment queue must be destroyed before MSL expires. + * RFC791 is wrong proposing to prolongate timer each fragment arrival + * by TTL. + */ + net->ipv4.fqdir->timeout = IP_FRAG_TIME; + + net->ipv4.fqdir->max_dist = 64; + + res = ip4_frags_ns_ctl_register(net); + if (res < 0) + fqdir_exit(net->ipv4.fqdir); + return res; +} + +static void __net_exit ipv4_frags_pre_exit_net(struct net *net) +{ + fqdir_pre_exit(net->ipv4.fqdir); +} + +static void __net_exit ipv4_frags_exit_net(struct net *net) +{ + ip4_frags_ns_ctl_unregister(net); + fqdir_exit(net->ipv4.fqdir); +} + +static struct pernet_operations ip4_frags_ops = { + .init = ipv4_frags_init_net, + .pre_exit = ipv4_frags_pre_exit_net, + .exit = ipv4_frags_exit_net, +}; + + +static u32 ip4_key_hashfn(const void *data, u32 len, u32 seed) +{ + return jhash2(data, + sizeof(struct frag_v4_compare_key) / sizeof(u32), seed); +} + +static u32 ip4_obj_hashfn(const void *data, u32 len, u32 seed) +{ + const struct inet_frag_queue *fq = data; + + return jhash2((const u32 *)&fq->key.v4, + sizeof(struct frag_v4_compare_key) / sizeof(u32), seed); +} + +static int ip4_obj_cmpfn(struct rhashtable_compare_arg *arg, const void *ptr) +{ + const struct frag_v4_compare_key *key = arg->key; + const struct inet_frag_queue *fq = ptr; + + return !!memcmp(&fq->key, key, sizeof(*key)); +} + +static const struct rhashtable_params ip4_rhash_params = { + .head_offset = offsetof(struct inet_frag_queue, node), + .key_offset = offsetof(struct inet_frag_queue, key), + .key_len = sizeof(struct frag_v4_compare_key), + .hashfn = ip4_key_hashfn, + .obj_hashfn = ip4_obj_hashfn, + .obj_cmpfn = ip4_obj_cmpfn, + .automatic_shrinking = true, +}; + +void __init ipfrag_init(void) +{ + ip4_frags.constructor = ip4_frag_init; + ip4_frags.destructor = ip4_frag_free; + ip4_frags.qsize = sizeof(struct ipq); + ip4_frags.frag_expire = ip_expire; + ip4_frags.frags_cache_name = ip_frag_cache_name; + ip4_frags.rhash_params = ip4_rhash_params; + if (inet_frags_init(&ip4_frags)) + panic("IP: failed to allocate ip4_frags cache\n"); + ip4_frags_ctl_register(); + register_pernet_subsys(&ip4_frags_ops); +} diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c new file mode 100644 index 000000000..d67d026d7 --- /dev/null +++ b/net/ipv4/ip_gre.c @@ -0,0 +1,1800 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux NET3: GRE over IP protocol decoder. 
+ * + * Authors: Alexey Kuznetsov (kuznet@ms2.inr.ac.ru) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + Problems & solutions + -------------------- + + 1. The most important issue is detecting local dead loops. + They would cause complete host lockup in transmit, which + would be "resolved" by stack overflow or, if queueing is enabled, + with infinite looping in net_bh. + + We cannot track such dead loops during route installation, + it is infeasible task. The most general solutions would be + to keep skb->encapsulation counter (sort of local ttl), + and silently drop packet when it expires. It is a good + solution, but it supposes maintaining new variable in ALL + skb, even if no tunneling is used. + + Current solution: xmit_recursion breaks dead loops. This is a percpu + counter, since when we enter the first ndo_xmit(), cpu migration is + forbidden. We force an exit if this counter reaches RECURSION_LIMIT + + 2. Networking dead loops would not kill routers, but would really + kill network. IP hop limit plays role of "t->recursion" in this case, + if we copy it from packet being encapsulated to upper header. + It is very good solution, but it introduces two problems: + + - Routing protocols, using packets with ttl=1 (OSPF, RIP2), + do not work over tunnels. + - traceroute does not work. I planned to relay ICMP from tunnel, + so that this problem would be solved and traceroute output + would even more informative. This idea appeared to be wrong: + only Linux complies to rfc1812 now (yes, guys, Linux is the only + true router now :-)), all routers (at least, in neighbourhood of mine) + return only 8 bytes of payload. It is the end. + + Hence, if we want that OSPF worked or traceroute said something reasonable, + we should search for another solution. + + One of them is to parse packet trying to detect inner encapsulation + made by our node. It is difficult or even impossible, especially, + taking into account fragmentation. TO be short, ttl is not solution at all. + + Current solution: The solution was UNEXPECTEDLY SIMPLE. + We force DF flag on tunnels with preconfigured hop limit, + that is ALL. :-) Well, it does not remove the problem completely, + but exponential growth of network traffic is changed to linear + (branches, that exceed pmtu are pruned) and tunnel mtu + rapidly degrades to value <68, where looping stops. + Yes, it is not good if there exists a router in the loop, + which does not force DF, even when encapsulating packets have DF set. + But it is not our problem! Nobody could accuse us, we made + all that we could make. Even if it is your gated who injected + fatal route to network, even if it were you who configured + fatal static route: you are innocent. :-) + + Alexey Kuznetsov. 
+ */ + +static bool log_ecn_error = true; +module_param(log_ecn_error, bool, 0644); +MODULE_PARM_DESC(log_ecn_error, "Log packets received with corrupted ECN"); + +static struct rtnl_link_ops ipgre_link_ops __read_mostly; +static const struct header_ops ipgre_header_ops; + +static int ipgre_tunnel_init(struct net_device *dev); +static void erspan_build_header(struct sk_buff *skb, + u32 id, u32 index, + bool truncate, bool is_ipv4); + +static unsigned int ipgre_net_id __read_mostly; +static unsigned int gre_tap_net_id __read_mostly; +static unsigned int erspan_net_id __read_mostly; + +static int ipgre_err(struct sk_buff *skb, u32 info, + const struct tnl_ptk_info *tpi) +{ + + /* All the routers (except for Linux) return only + 8 bytes of packet payload. It means, that precise relaying of + ICMP in the real Internet is absolutely infeasible. + + Moreover, Cisco "wise men" put GRE key to the third word + in GRE header. It makes impossible maintaining even soft + state for keyed GRE tunnels with enabled checksum. Tell + them "thank you". + + Well, I wonder, rfc1812 was written by Cisco employee, + what the hell these idiots break standards established + by themselves??? + */ + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn; + const struct iphdr *iph; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + unsigned int data_len = 0; + struct ip_tunnel *t; + + if (tpi->proto == htons(ETH_P_TEB)) + itn = net_generic(net, gre_tap_net_id); + else if (tpi->proto == htons(ETH_P_ERSPAN) || + tpi->proto == htons(ETH_P_ERSPAN2)) + itn = net_generic(net, erspan_net_id); + else + itn = net_generic(net, ipgre_net_id); + + iph = (const struct iphdr *)(icmp_hdr(skb) + 1); + t = ip_tunnel_lookup(itn, skb->dev->ifindex, tpi->flags, + iph->daddr, iph->saddr, tpi->key); + + if (!t) + return -ENOENT; + + switch (type) { + default: + case ICMP_PARAMETERPROB: + return 0; + + case ICMP_DEST_UNREACH: + switch (code) { + case ICMP_SR_FAILED: + case ICMP_PORT_UNREACH: + /* Impossible event. */ + return 0; + default: + /* All others are translated to HOST_UNREACH. + rfc2003 contains "deep thoughts" about NET_UNREACH, + I believe they are just ether pollution. --ANK + */ + break; + } + break; + + case ICMP_TIME_EXCEEDED: + if (code != ICMP_EXC_TTL) + return 0; + data_len = icmp_hdr(skb)->un.reserved[1] * 4; /* RFC 4884 4.1 */ + break; + + case ICMP_REDIRECT: + break; + } + +#if IS_ENABLED(CONFIG_IPV6) + if (tpi->proto == htons(ETH_P_IPV6) && + !ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4 + tpi->hdr_len, + type, data_len)) + return 0; +#endif + + if (t->parms.iph.daddr == 0 || + ipv4_is_multicast(t->parms.iph.daddr)) + return 0; + + if (t->parms.iph.ttl == 0 && type == ICMP_TIME_EXCEEDED) + return 0; + + if (time_before(jiffies, t->err_time + IPTUNNEL_ERR_TIMEO)) + t->err_count++; + else + t->err_count = 1; + t->err_time = jiffies; + + return 0; +} + +static void gre_err(struct sk_buff *skb, u32 info) +{ + /* All the routers (except for Linux) return only + * 8 bytes of packet payload. It means, that precise relaying of + * ICMP in the real Internet is absolutely infeasible. + * + * Moreover, Cisco "wise men" put GRE key to the third word + * in GRE header. It makes impossible maintaining even soft + * state for keyed + * GRE tunnels with enabled checksum. Tell them "thank you". + * + * Well, I wonder, rfc1812 was written by Cisco employee, + * what the hell these idiots break standards established + * by themselves??? 
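+ *
+ * With only the GRE header reliably available, PMTU updates
+ * (ICMP_FRAG_NEEDED) and redirects are handled right here, and every
+ * other ICMP error is passed on to ipgre_err() to update the matching
+ * tunnel's error state.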
+ */ + + const struct iphdr *iph = (struct iphdr *)skb->data; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + struct tnl_ptk_info tpi; + + if (gre_parse_header(skb, &tpi, NULL, htons(ETH_P_IP), + iph->ihl * 4) < 0) + return; + + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) { + ipv4_update_pmtu(skb, dev_net(skb->dev), info, + skb->dev->ifindex, IPPROTO_GRE); + return; + } + if (type == ICMP_REDIRECT) { + ipv4_redirect(skb, dev_net(skb->dev), skb->dev->ifindex, + IPPROTO_GRE); + return; + } + + ipgre_err(skb, info, &tpi); +} + +static bool is_erspan_type1(int gre_hdr_len) +{ + /* Both ERSPAN type I (version 0) and type II (version 1) use + * protocol 0x88BE, but the type I has only 4-byte GRE header, + * while type II has 8-byte. + */ + return gre_hdr_len == 4; +} + +static int erspan_rcv(struct sk_buff *skb, struct tnl_ptk_info *tpi, + int gre_hdr_len) +{ + struct net *net = dev_net(skb->dev); + struct metadata_dst *tun_dst = NULL; + struct erspan_base_hdr *ershdr; + struct ip_tunnel_net *itn; + struct ip_tunnel *tunnel; + const struct iphdr *iph; + struct erspan_md2 *md2; + int ver; + int len; + + itn = net_generic(net, erspan_net_id); + iph = ip_hdr(skb); + if (is_erspan_type1(gre_hdr_len)) { + ver = 0; + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, + tpi->flags | TUNNEL_NO_KEY, + iph->saddr, iph->daddr, 0); + } else { + ershdr = (struct erspan_base_hdr *)(skb->data + gre_hdr_len); + ver = ershdr->ver; + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, + tpi->flags | TUNNEL_KEY, + iph->saddr, iph->daddr, tpi->key); + } + + if (tunnel) { + if (is_erspan_type1(gre_hdr_len)) + len = gre_hdr_len; + else + len = gre_hdr_len + erspan_hdr_len(ver); + + if (unlikely(!pskb_may_pull(skb, len))) + return PACKET_REJECT; + + if (__iptunnel_pull_header(skb, + len, + htons(ETH_P_TEB), + false, false) < 0) + goto drop; + + if (tunnel->collect_md) { + struct erspan_metadata *pkt_md, *md; + struct ip_tunnel_info *info; + unsigned char *gh; + __be64 tun_id; + __be16 flags; + + tpi->flags |= TUNNEL_KEY; + flags = tpi->flags; + tun_id = key32_to_tunnel_id(tpi->key); + + tun_dst = ip_tun_rx_dst(skb, flags, + tun_id, sizeof(*md)); + if (!tun_dst) + return PACKET_REJECT; + + /* skb can be uncloned in __iptunnel_pull_header, so + * old pkt_md is no longer valid and we need to reset + * it + */ + gh = skb_network_header(skb) + + skb_network_header_len(skb); + pkt_md = (struct erspan_metadata *)(gh + gre_hdr_len + + sizeof(*ershdr)); + md = ip_tunnel_info_opts(&tun_dst->u.tun_info); + md->version = ver; + md2 = &md->u.md2; + memcpy(md2, pkt_md, ver == 1 ? 
ERSPAN_V1_MDSIZE : + ERSPAN_V2_MDSIZE); + + info = &tun_dst->u.tun_info; + info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + info->options_len = sizeof(*md); + } + + skb_reset_mac_header(skb); + ip_tunnel_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error); + return PACKET_RCVD; + } + return PACKET_REJECT; + +drop: + kfree_skb(skb); + return PACKET_RCVD; +} + +static int __ipgre_rcv(struct sk_buff *skb, const struct tnl_ptk_info *tpi, + struct ip_tunnel_net *itn, int hdr_len, bool raw_proto) +{ + struct metadata_dst *tun_dst = NULL; + const struct iphdr *iph; + struct ip_tunnel *tunnel; + + iph = ip_hdr(skb); + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, tpi->flags, + iph->saddr, iph->daddr, tpi->key); + + if (tunnel) { + const struct iphdr *tnl_params; + + if (__iptunnel_pull_header(skb, hdr_len, tpi->proto, + raw_proto, false) < 0) + goto drop; + + /* Special case for ipgre_header_parse(), which expects the + * mac_header to point to the outer IP header. + */ + if (tunnel->dev->header_ops == &ipgre_header_ops) + skb_pop_mac_header(skb); + else + skb_reset_mac_header(skb); + + tnl_params = &tunnel->parms.iph; + if (tunnel->collect_md || tnl_params->daddr == 0) { + __be16 flags; + __be64 tun_id; + + flags = tpi->flags & (TUNNEL_CSUM | TUNNEL_KEY); + tun_id = key32_to_tunnel_id(tpi->key); + tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0); + if (!tun_dst) + return PACKET_REJECT; + } + + ip_tunnel_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error); + return PACKET_RCVD; + } + return PACKET_NEXT; + +drop: + kfree_skb(skb); + return PACKET_RCVD; +} + +static int ipgre_rcv(struct sk_buff *skb, const struct tnl_ptk_info *tpi, + int hdr_len) +{ + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn; + int res; + + if (tpi->proto == htons(ETH_P_TEB)) + itn = net_generic(net, gre_tap_net_id); + else + itn = net_generic(net, ipgre_net_id); + + res = __ipgre_rcv(skb, tpi, itn, hdr_len, false); + if (res == PACKET_NEXT && tpi->proto == htons(ETH_P_TEB)) { + /* ipgre tunnels in collect metadata mode should receive + * also ETH_P_TEB traffic. + */ + itn = net_generic(net, ipgre_net_id); + res = __ipgre_rcv(skb, tpi, itn, hdr_len, true); + } + return res; +} + +static int gre_rcv(struct sk_buff *skb) +{ + struct tnl_ptk_info tpi; + bool csum_err = false; + int hdr_len; + +#ifdef CONFIG_NET_IPGRE_BROADCAST + if (ipv4_is_multicast(ip_hdr(skb)->daddr)) { + /* Looped back packet, drop it! */ + if (rt_is_output_route(skb_rtable(skb))) + goto drop; + } +#endif + + hdr_len = gre_parse_header(skb, &tpi, &csum_err, htons(ETH_P_IP), 0); + if (hdr_len < 0) + goto drop; + + if (unlikely(tpi.proto == htons(ETH_P_ERSPAN) || + tpi.proto == htons(ETH_P_ERSPAN2))) { + if (erspan_rcv(skb, &tpi, hdr_len) == PACKET_RCVD) + return 0; + goto out; + } + + if (ipgre_rcv(skb, &tpi, hdr_len) == PACKET_RCVD) + return 0; + +out: + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); +drop: + kfree_skb(skb); + return 0; +} + +static void __gre_xmit(struct sk_buff *skb, struct net_device *dev, + const struct iphdr *tnl_params, + __be16 proto) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + __be16 flags = tunnel->parms.o_flags; + + /* Push GRE header. */ + gre_build_header(skb, tunnel->tun_hlen, + flags, proto, tunnel->parms.o_key, + (flags & TUNNEL_SEQ) ? htonl(atomic_fetch_inc(&tunnel->o_seqno)) : 0); + + ip_tunnel_xmit(skb, dev, tnl_params, tnl_params->protocol); +} + +static int gre_handle_offloads(struct sk_buff *skb, bool csum) +{ + return iptunnel_handle_offloads(skb, csum ? 
SKB_GSO_GRE_CSUM : SKB_GSO_GRE); +} + +static void gre_fb_xmit(struct sk_buff *skb, struct net_device *dev, + __be16 proto) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_info *tun_info; + const struct ip_tunnel_key *key; + int tunnel_hlen; + __be16 flags; + + tun_info = skb_tunnel_info(skb); + if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || + ip_tunnel_info_af(tun_info) != AF_INET)) + goto err_free_skb; + + key = &tun_info->key; + tunnel_hlen = gre_calc_hlen(key->tun_flags); + + if (skb_cow_head(skb, dev->needed_headroom)) + goto err_free_skb; + + /* Push Tunnel header. */ + if (gre_handle_offloads(skb, !!(tun_info->key.tun_flags & TUNNEL_CSUM))) + goto err_free_skb; + + flags = tun_info->key.tun_flags & + (TUNNEL_CSUM | TUNNEL_KEY | TUNNEL_SEQ); + gre_build_header(skb, tunnel_hlen, flags, proto, + tunnel_id_to_key32(tun_info->key.tun_id), + (flags & TUNNEL_SEQ) ? htonl(atomic_fetch_inc(&tunnel->o_seqno)) : 0); + + ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); + + return; + +err_free_skb: + kfree_skb(skb); + dev->stats.tx_dropped++; +} + +static void erspan_fb_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_info *tun_info; + const struct ip_tunnel_key *key; + struct erspan_metadata *md; + bool truncate = false; + __be16 proto; + int tunnel_hlen; + int version; + int nhoff; + + tun_info = skb_tunnel_info(skb); + if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || + ip_tunnel_info_af(tun_info) != AF_INET)) + goto err_free_skb; + + key = &tun_info->key; + if (!(tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT)) + goto err_free_skb; + if (tun_info->options_len < sizeof(*md)) + goto err_free_skb; + md = ip_tunnel_info_opts(tun_info); + + /* ERSPAN has fixed 8 byte GRE header */ + version = md->version; + tunnel_hlen = 8 + erspan_hdr_len(version); + + if (skb_cow_head(skb, dev->needed_headroom)) + goto err_free_skb; + + if (gre_handle_offloads(skb, false)) + goto err_free_skb; + + if (skb->len > dev->mtu + dev->hard_header_len) { + pskb_trim(skb, dev->mtu + dev->hard_header_len); + truncate = true; + } + + nhoff = skb_network_offset(skb); + if (skb->protocol == htons(ETH_P_IP) && + (ntohs(ip_hdr(skb)->tot_len) > skb->len - nhoff)) + truncate = true; + + if (skb->protocol == htons(ETH_P_IPV6)) { + int thoff; + + if (skb_transport_header_was_set(skb)) + thoff = skb_transport_offset(skb); + else + thoff = nhoff + sizeof(struct ipv6hdr); + if (ntohs(ipv6_hdr(skb)->payload_len) > skb->len - thoff) + truncate = true; + } + + if (version == 1) { + erspan_build_header(skb, ntohl(tunnel_id_to_key32(key->tun_id)), + ntohl(md->u.index), truncate, true); + proto = htons(ETH_P_ERSPAN); + } else if (version == 2) { + erspan_build_header_v2(skb, + ntohl(tunnel_id_to_key32(key->tun_id)), + md->u.md2.dir, + get_hwid(&md->u.md2), + truncate, true); + proto = htons(ETH_P_ERSPAN2); + } else { + goto err_free_skb; + } + + gre_build_header(skb, 8, TUNNEL_SEQ, + proto, 0, htonl(atomic_fetch_inc(&tunnel->o_seqno))); + + ip_md_tunnel_xmit(skb, dev, IPPROTO_GRE, tunnel_hlen); + + return; + +err_free_skb: + kfree_skb(skb); + dev->stats.tx_dropped++; +} + +static int gre_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb) +{ + struct ip_tunnel_info *info = skb_tunnel_info(skb); + const struct ip_tunnel_key *key; + struct rtable *rt; + struct flowi4 fl4; + + if (ip_tunnel_info_af(info) != AF_INET) + return -EINVAL; + + key = &info->key; + ip_tunnel_init_flow(&fl4, IPPROTO_GRE, 
key->u.ipv4.dst, key->u.ipv4.src, + tunnel_id_to_key32(key->tun_id), + key->tos & ~INET_ECN_MASK, dev_net(dev), 0, + skb->mark, skb_get_hash(skb), key->flow_flags); + rt = ip_route_output_key(dev_net(dev), &fl4); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + ip_rt_put(rt); + info->key.u.ipv4.src = fl4.saddr; + return 0; +} + +static netdev_tx_t ipgre_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + const struct iphdr *tnl_params; + + if (!pskb_inet_may_pull(skb)) + goto free_skb; + + if (tunnel->collect_md) { + gre_fb_xmit(skb, dev, skb->protocol); + return NETDEV_TX_OK; + } + + if (dev->header_ops) { + int pull_len = tunnel->hlen + sizeof(struct iphdr); + + if (skb_cow_head(skb, 0)) + goto free_skb; + + tnl_params = (const struct iphdr *)skb->data; + + if (!pskb_network_may_pull(skb, pull_len)) + goto free_skb; + + /* ip_tunnel_xmit() needs skb->data pointing to gre header. */ + skb_pull(skb, pull_len); + skb_reset_mac_header(skb); + + if (skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_start(skb) < skb->data) + goto free_skb; + } else { + if (skb_cow_head(skb, dev->needed_headroom)) + goto free_skb; + + tnl_params = &tunnel->parms.iph; + } + + if (gre_handle_offloads(skb, !!(tunnel->parms.o_flags & TUNNEL_CSUM))) + goto free_skb; + + __gre_xmit(skb, dev, tnl_params, skb->protocol); + return NETDEV_TX_OK; + +free_skb: + kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; +} + +static netdev_tx_t erspan_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + bool truncate = false; + __be16 proto; + + if (!pskb_inet_may_pull(skb)) + goto free_skb; + + if (tunnel->collect_md) { + erspan_fb_xmit(skb, dev); + return NETDEV_TX_OK; + } + + if (gre_handle_offloads(skb, false)) + goto free_skb; + + if (skb_cow_head(skb, dev->needed_headroom)) + goto free_skb; + + if (skb->len > dev->mtu + dev->hard_header_len) { + pskb_trim(skb, dev->mtu + dev->hard_header_len); + truncate = true; + } + + /* Push ERSPAN header */ + if (tunnel->erspan_ver == 0) { + proto = htons(ETH_P_ERSPAN); + tunnel->parms.o_flags &= ~TUNNEL_SEQ; + } else if (tunnel->erspan_ver == 1) { + erspan_build_header(skb, ntohl(tunnel->parms.o_key), + tunnel->index, + truncate, true); + proto = htons(ETH_P_ERSPAN); + } else if (tunnel->erspan_ver == 2) { + erspan_build_header_v2(skb, ntohl(tunnel->parms.o_key), + tunnel->dir, tunnel->hwid, + truncate, true); + proto = htons(ETH_P_ERSPAN2); + } else { + goto free_skb; + } + + tunnel->parms.o_flags &= ~TUNNEL_KEY; + __gre_xmit(skb, dev, &tunnel->parms.iph, proto); + return NETDEV_TX_OK; + +free_skb: + kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; +} + +static netdev_tx_t gre_tap_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + if (!pskb_inet_may_pull(skb)) + goto free_skb; + + if (tunnel->collect_md) { + gre_fb_xmit(skb, dev, htons(ETH_P_TEB)); + return NETDEV_TX_OK; + } + + if (gre_handle_offloads(skb, !!(tunnel->parms.o_flags & TUNNEL_CSUM))) + goto free_skb; + + if (skb_cow_head(skb, dev->needed_headroom)) + goto free_skb; + + __gre_xmit(skb, dev, &tunnel->parms.iph, htons(ETH_P_TEB)); + return NETDEV_TX_OK; + +free_skb: + kfree_skb(skb); + dev->stats.tx_dropped++; + return NETDEV_TX_OK; +} + +static void ipgre_link_update(struct net_device *dev, bool set_mtu) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + __be16 flags; + int len; + + len = tunnel->tun_hlen; + tunnel->tun_hlen = 
gre_calc_hlen(tunnel->parms.o_flags); + len = tunnel->tun_hlen - len; + tunnel->hlen = tunnel->hlen + len; + + if (dev->header_ops) + dev->hard_header_len += len; + else + dev->needed_headroom += len; + + if (set_mtu) + dev->mtu = max_t(int, dev->mtu - len, 68); + + flags = tunnel->parms.o_flags; + + if (flags & TUNNEL_SEQ || + (flags & TUNNEL_CSUM && tunnel->encap.type != TUNNEL_ENCAP_NONE)) { + dev->features &= ~NETIF_F_GSO_SOFTWARE; + dev->hw_features &= ~NETIF_F_GSO_SOFTWARE; + } else { + dev->features |= NETIF_F_GSO_SOFTWARE; + dev->hw_features |= NETIF_F_GSO_SOFTWARE; + } +} + +static int ipgre_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, + int cmd) +{ + int err; + + if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { + if (p->iph.version != 4 || p->iph.protocol != IPPROTO_GRE || + p->iph.ihl != 5 || (p->iph.frag_off & htons(~IP_DF)) || + ((p->i_flags | p->o_flags) & (GRE_VERSION | GRE_ROUTING))) + return -EINVAL; + } + + p->i_flags = gre_flags_to_tnl_flags(p->i_flags); + p->o_flags = gre_flags_to_tnl_flags(p->o_flags); + + err = ip_tunnel_ctl(dev, p, cmd); + if (err) + return err; + + if (cmd == SIOCCHGTUNNEL) { + struct ip_tunnel *t = netdev_priv(dev); + + t->parms.i_flags = p->i_flags; + t->parms.o_flags = p->o_flags; + + if (strcmp(dev->rtnl_link_ops->kind, "erspan")) + ipgre_link_update(dev, true); + } + + p->i_flags = gre_tnl_flags_to_gre_flags(p->i_flags); + p->o_flags = gre_tnl_flags_to_gre_flags(p->o_flags); + return 0; +} + +/* Nice toy. Unfortunately, useless in real life :-) + It allows to construct virtual multiprotocol broadcast "LAN" + over the Internet, provided multicast routing is tuned. + + + I have no idea was this bicycle invented before me, + so that I had to set ARPHRD_IPGRE to a random value. + I have an impression, that Cisco could make something similar, + but this feature is apparently missing in IOS<=11.2(8). + + I set up 10.66.66/24 and fec0:6666:6666::0/96 as virtual networks + with broadcast 224.66.66.66. If you have access to mbone, play with me :-) + + ping -t 255 224.66.66.66 + + If nobody answers, mbone does not work. + + ip tunnel add Universe mode gre remote 224.66.66.66 local ttl 255 + ip addr add 10.66.66./24 dev Universe + ifconfig Universe up + ifconfig Universe add fe80::/10 + ifconfig Universe add fec0:6666:6666::/96 + ftp 10.66.66.66 + ... + ftp fec0:6666:6666::193.233.7.65 + ... + */ +static int ipgre_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct iphdr *iph; + struct gre_base_hdr *greh; + + iph = skb_push(skb, t->hlen + sizeof(*iph)); + greh = (struct gre_base_hdr *)(iph+1); + greh->flags = gre_tnl_flags_to_gre_flags(t->parms.o_flags); + greh->protocol = htons(type); + + memcpy(iph, &t->parms.iph, sizeof(struct iphdr)); + + /* Set the source hardware address. 
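+ * For ARPHRD_IPGRE devices the "hardware" addresses are simply the
+ * 4-byte tunnel endpoints, copied into the outer IPv4 header built
+ * above; the full header length is returned when the destination is
+ * known, and its negative otherwise to signal an incomplete header.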
*/ + if (saddr) + memcpy(&iph->saddr, saddr, 4); + if (daddr) + memcpy(&iph->daddr, daddr, 4); + if (iph->daddr) + return t->hlen + sizeof(*iph); + + return -(t->hlen + sizeof(*iph)); +} + +static int ipgre_header_parse(const struct sk_buff *skb, unsigned char *haddr) +{ + const struct iphdr *iph = (const struct iphdr *) skb_mac_header(skb); + memcpy(haddr, &iph->saddr, 4); + return 4; +} + +static const struct header_ops ipgre_header_ops = { + .create = ipgre_header, + .parse = ipgre_header_parse, +}; + +#ifdef CONFIG_NET_IPGRE_BROADCAST +static int ipgre_open(struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (ipv4_is_multicast(t->parms.iph.daddr)) { + struct flowi4 fl4; + struct rtable *rt; + + rt = ip_route_output_gre(t->net, &fl4, + t->parms.iph.daddr, + t->parms.iph.saddr, + t->parms.o_key, + RT_TOS(t->parms.iph.tos), + t->parms.link); + if (IS_ERR(rt)) + return -EADDRNOTAVAIL; + dev = rt->dst.dev; + ip_rt_put(rt); + if (!__in_dev_get_rtnl(dev)) + return -EADDRNOTAVAIL; + t->mlink = dev->ifindex; + ip_mc_inc_group(__in_dev_get_rtnl(dev), t->parms.iph.daddr); + } + return 0; +} + +static int ipgre_close(struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (ipv4_is_multicast(t->parms.iph.daddr) && t->mlink) { + struct in_device *in_dev; + in_dev = inetdev_by_index(t->net, t->mlink); + if (in_dev) + ip_mc_dec_group(in_dev, t->parms.iph.daddr); + } + return 0; +} +#endif + +static const struct net_device_ops ipgre_netdev_ops = { + .ndo_init = ipgre_tunnel_init, + .ndo_uninit = ip_tunnel_uninit, +#ifdef CONFIG_NET_IPGRE_BROADCAST + .ndo_open = ipgre_open, + .ndo_stop = ipgre_close, +#endif + .ndo_start_xmit = ipgre_xmit, + .ndo_siocdevprivate = ip_tunnel_siocdevprivate, + .ndo_change_mtu = ip_tunnel_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_tunnel_ctl = ipgre_tunnel_ctl, +}; + +#define GRE_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_HIGHDMA | \ + NETIF_F_HW_CSUM) + +static void ipgre_tunnel_setup(struct net_device *dev) +{ + dev->netdev_ops = &ipgre_netdev_ops; + dev->type = ARPHRD_IPGRE; + ip_tunnel_setup(dev, ipgre_net_id); +} + +static void __gre_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel; + __be16 flags; + + tunnel = netdev_priv(dev); + tunnel->tun_hlen = gre_calc_hlen(tunnel->parms.o_flags); + tunnel->parms.iph.protocol = IPPROTO_GRE; + + tunnel->hlen = tunnel->tun_hlen + tunnel->encap_hlen; + dev->needed_headroom = tunnel->hlen + sizeof(tunnel->parms.iph); + + dev->features |= GRE_FEATURES | NETIF_F_LLTX; + dev->hw_features |= GRE_FEATURES; + + flags = tunnel->parms.o_flags; + + /* TCP offload with GRE SEQ is not supported, nor can we support 2 + * levels of outer headers requiring an update. 
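+ *
+ * Hence NETIF_F_GSO_SOFTWARE is only advertised below when neither
+ * TUNNEL_SEQ nor TUNNEL_CSUM-with-encapsulation is configured.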
+ */ + if (flags & TUNNEL_SEQ) + return; + if (flags & TUNNEL_CSUM && tunnel->encap.type != TUNNEL_ENCAP_NONE) + return; + + dev->features |= NETIF_F_GSO_SOFTWARE; + dev->hw_features |= NETIF_F_GSO_SOFTWARE; +} + +static int ipgre_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct iphdr *iph = &tunnel->parms.iph; + + __gre_tunnel_init(dev); + + __dev_addr_set(dev, &iph->saddr, 4); + memcpy(dev->broadcast, &iph->daddr, 4); + + dev->flags = IFF_NOARP; + netif_keep_dst(dev); + dev->addr_len = 4; + + if (iph->daddr && !tunnel->collect_md) { +#ifdef CONFIG_NET_IPGRE_BROADCAST + if (ipv4_is_multicast(iph->daddr)) { + if (!iph->saddr) + return -EINVAL; + dev->flags = IFF_BROADCAST; + dev->header_ops = &ipgre_header_ops; + dev->hard_header_len = tunnel->hlen + sizeof(*iph); + dev->needed_headroom = 0; + } +#endif + } else if (!tunnel->collect_md) { + dev->header_ops = &ipgre_header_ops; + dev->hard_header_len = tunnel->hlen + sizeof(*iph); + dev->needed_headroom = 0; + } + + return ip_tunnel_init(dev); +} + +static const struct gre_protocol ipgre_protocol = { + .handler = gre_rcv, + .err_handler = gre_err, +}; + +static int __net_init ipgre_init_net(struct net *net) +{ + return ip_tunnel_init_net(net, ipgre_net_id, &ipgre_link_ops, NULL); +} + +static void __net_exit ipgre_exit_batch_net(struct list_head *list_net) +{ + ip_tunnel_delete_nets(list_net, ipgre_net_id, &ipgre_link_ops); +} + +static struct pernet_operations ipgre_net_ops = { + .init = ipgre_init_net, + .exit_batch = ipgre_exit_batch_net, + .id = &ipgre_net_id, + .size = sizeof(struct ip_tunnel_net), +}; + +static int ipgre_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + __be16 flags; + + if (!data) + return 0; + + flags = 0; + if (data[IFLA_GRE_IFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_IFLAGS]); + if (data[IFLA_GRE_OFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_OFLAGS]); + if (flags & (GRE_VERSION|GRE_ROUTING)) + return -EINVAL; + + if (data[IFLA_GRE_COLLECT_METADATA] && + data[IFLA_GRE_ENCAP_TYPE] && + nla_get_u16(data[IFLA_GRE_ENCAP_TYPE]) != TUNNEL_ENCAP_NONE) + return -EINVAL; + + return 0; +} + +static int ipgre_tap_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + __be32 daddr; + + if (tb[IFLA_ADDRESS]) { + if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN) + return -EINVAL; + if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS]))) + return -EADDRNOTAVAIL; + } + + if (!data) + goto out; + + if (data[IFLA_GRE_REMOTE]) { + memcpy(&daddr, nla_data(data[IFLA_GRE_REMOTE]), 4); + if (!daddr) + return -EINVAL; + } + +out: + return ipgre_tunnel_validate(tb, data, extack); +} + +static int erspan_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + __be16 flags = 0; + int ret; + + if (!data) + return 0; + + ret = ipgre_tap_validate(tb, data, extack); + if (ret) + return ret; + + if (data[IFLA_GRE_ERSPAN_VER] && + nla_get_u8(data[IFLA_GRE_ERSPAN_VER]) == 0) + return 0; + + /* ERSPAN type II/III should only have GRE sequence and key flag */ + if (data[IFLA_GRE_OFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_OFLAGS]); + if (data[IFLA_GRE_IFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_IFLAGS]); + if (!data[IFLA_GRE_COLLECT_METADATA] && + flags != (GRE_SEQ | GRE_KEY)) + return -EINVAL; + + /* ERSPAN Session ID only has 10-bit. Since we reuse + * 32-bit key field as ID, check it's range. 
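+ * Any configured i_key/o_key with bits set outside the low 10 bits
+ * (~ID_MASK) is therefore rejected.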
+ */ + if (data[IFLA_GRE_IKEY] && + (ntohl(nla_get_be32(data[IFLA_GRE_IKEY])) & ~ID_MASK)) + return -EINVAL; + + if (data[IFLA_GRE_OKEY] && + (ntohl(nla_get_be32(data[IFLA_GRE_OKEY])) & ~ID_MASK)) + return -EINVAL; + + return 0; +} + +static int ipgre_netlink_parms(struct net_device *dev, + struct nlattr *data[], + struct nlattr *tb[], + struct ip_tunnel_parm *parms, + __u32 *fwmark) +{ + struct ip_tunnel *t = netdev_priv(dev); + + memset(parms, 0, sizeof(*parms)); + + parms->iph.protocol = IPPROTO_GRE; + + if (!data) + return 0; + + if (data[IFLA_GRE_LINK]) + parms->link = nla_get_u32(data[IFLA_GRE_LINK]); + + if (data[IFLA_GRE_IFLAGS]) + parms->i_flags = gre_flags_to_tnl_flags(nla_get_be16(data[IFLA_GRE_IFLAGS])); + + if (data[IFLA_GRE_OFLAGS]) + parms->o_flags = gre_flags_to_tnl_flags(nla_get_be16(data[IFLA_GRE_OFLAGS])); + + if (data[IFLA_GRE_IKEY]) + parms->i_key = nla_get_be32(data[IFLA_GRE_IKEY]); + + if (data[IFLA_GRE_OKEY]) + parms->o_key = nla_get_be32(data[IFLA_GRE_OKEY]); + + if (data[IFLA_GRE_LOCAL]) + parms->iph.saddr = nla_get_in_addr(data[IFLA_GRE_LOCAL]); + + if (data[IFLA_GRE_REMOTE]) + parms->iph.daddr = nla_get_in_addr(data[IFLA_GRE_REMOTE]); + + if (data[IFLA_GRE_TTL]) + parms->iph.ttl = nla_get_u8(data[IFLA_GRE_TTL]); + + if (data[IFLA_GRE_TOS]) + parms->iph.tos = nla_get_u8(data[IFLA_GRE_TOS]); + + if (!data[IFLA_GRE_PMTUDISC] || nla_get_u8(data[IFLA_GRE_PMTUDISC])) { + if (t->ignore_df) + return -EINVAL; + parms->iph.frag_off = htons(IP_DF); + } + + if (data[IFLA_GRE_COLLECT_METADATA]) { + t->collect_md = true; + if (dev->type == ARPHRD_IPGRE) + dev->type = ARPHRD_NONE; + } + + if (data[IFLA_GRE_IGNORE_DF]) { + if (nla_get_u8(data[IFLA_GRE_IGNORE_DF]) + && (parms->iph.frag_off & htons(IP_DF))) + return -EINVAL; + t->ignore_df = !!nla_get_u8(data[IFLA_GRE_IGNORE_DF]); + } + + if (data[IFLA_GRE_FWMARK]) + *fwmark = nla_get_u32(data[IFLA_GRE_FWMARK]); + + return 0; +} + +static int erspan_netlink_parms(struct net_device *dev, + struct nlattr *data[], + struct nlattr *tb[], + struct ip_tunnel_parm *parms, + __u32 *fwmark) +{ + struct ip_tunnel *t = netdev_priv(dev); + int err; + + err = ipgre_netlink_parms(dev, data, tb, parms, fwmark); + if (err) + return err; + if (!data) + return 0; + + if (data[IFLA_GRE_ERSPAN_VER]) { + t->erspan_ver = nla_get_u8(data[IFLA_GRE_ERSPAN_VER]); + + if (t->erspan_ver > 2) + return -EINVAL; + } + + if (t->erspan_ver == 1) { + if (data[IFLA_GRE_ERSPAN_INDEX]) { + t->index = nla_get_u32(data[IFLA_GRE_ERSPAN_INDEX]); + if (t->index & ~INDEX_MASK) + return -EINVAL; + } + } else if (t->erspan_ver == 2) { + if (data[IFLA_GRE_ERSPAN_DIR]) { + t->dir = nla_get_u8(data[IFLA_GRE_ERSPAN_DIR]); + if (t->dir & ~(DIR_MASK >> DIR_OFFSET)) + return -EINVAL; + } + if (data[IFLA_GRE_ERSPAN_HWID]) { + t->hwid = nla_get_u16(data[IFLA_GRE_ERSPAN_HWID]); + if (t->hwid & ~(HWID_MASK >> HWID_OFFSET)) + return -EINVAL; + } + } + + return 0; +} + +/* This function returns true when ENCAP attributes are present in the nl msg */ +static bool ipgre_netlink_encap_parms(struct nlattr *data[], + struct ip_tunnel_encap *ipencap) +{ + bool ret = false; + + memset(ipencap, 0, sizeof(*ipencap)); + + if (!data) + return ret; + + if (data[IFLA_GRE_ENCAP_TYPE]) { + ret = true; + ipencap->type = nla_get_u16(data[IFLA_GRE_ENCAP_TYPE]); + } + + if (data[IFLA_GRE_ENCAP_FLAGS]) { + ret = true; + ipencap->flags = nla_get_u16(data[IFLA_GRE_ENCAP_FLAGS]); + } + + if (data[IFLA_GRE_ENCAP_SPORT]) { + ret = true; + ipencap->sport = nla_get_be16(data[IFLA_GRE_ENCAP_SPORT]); + } + + if 
(data[IFLA_GRE_ENCAP_DPORT]) { + ret = true; + ipencap->dport = nla_get_be16(data[IFLA_GRE_ENCAP_DPORT]); + } + + return ret; +} + +static int gre_tap_init(struct net_device *dev) +{ + __gre_tunnel_init(dev); + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + netif_keep_dst(dev); + + return ip_tunnel_init(dev); +} + +static const struct net_device_ops gre_tap_netdev_ops = { + .ndo_init = gre_tap_init, + .ndo_uninit = ip_tunnel_uninit, + .ndo_start_xmit = gre_tap_xmit, + .ndo_set_mac_address = eth_mac_addr, + .ndo_validate_addr = eth_validate_addr, + .ndo_change_mtu = ip_tunnel_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_fill_metadata_dst = gre_fill_metadata_dst, +}; + +static int erspan_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + if (tunnel->erspan_ver == 0) + tunnel->tun_hlen = 4; /* 4-byte GRE hdr. */ + else + tunnel->tun_hlen = 8; /* 8-byte GRE hdr. */ + + tunnel->parms.iph.protocol = IPPROTO_GRE; + tunnel->hlen = tunnel->tun_hlen + tunnel->encap_hlen + + erspan_hdr_len(tunnel->erspan_ver); + + dev->features |= GRE_FEATURES; + dev->hw_features |= GRE_FEATURES; + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + netif_keep_dst(dev); + + return ip_tunnel_init(dev); +} + +static const struct net_device_ops erspan_netdev_ops = { + .ndo_init = erspan_tunnel_init, + .ndo_uninit = ip_tunnel_uninit, + .ndo_start_xmit = erspan_xmit, + .ndo_set_mac_address = eth_mac_addr, + .ndo_validate_addr = eth_validate_addr, + .ndo_change_mtu = ip_tunnel_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_fill_metadata_dst = gre_fill_metadata_dst, +}; + +static void ipgre_tap_setup(struct net_device *dev) +{ + ether_setup(dev); + dev->max_mtu = 0; + dev->netdev_ops = &gre_tap_netdev_ops; + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + ip_tunnel_setup(dev, gre_tap_net_id); +} + +static int +ipgre_newlink_encap_setup(struct net_device *dev, struct nlattr *data[]) +{ + struct ip_tunnel_encap ipencap; + + if (ipgre_netlink_encap_parms(data, &ipencap)) { + struct ip_tunnel *t = netdev_priv(dev); + int err = ip_tunnel_encap_setup(t, &ipencap); + + if (err < 0) + return err; + } + + return 0; +} + +static int ipgre_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel_parm p; + __u32 fwmark = 0; + int err; + + err = ipgre_newlink_encap_setup(dev, data); + if (err) + return err; + + err = ipgre_netlink_parms(dev, data, tb, &p, &fwmark); + if (err < 0) + return err; + return ip_tunnel_newlink(dev, tb, &p, fwmark); +} + +static int erspan_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel_parm p; + __u32 fwmark = 0; + int err; + + err = ipgre_newlink_encap_setup(dev, data); + if (err) + return err; + + err = erspan_netlink_parms(dev, data, tb, &p, &fwmark); + if (err) + return err; + return ip_tunnel_newlink(dev, tb, &p, fwmark); +} + +static int ipgre_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = netdev_priv(dev); + __u32 fwmark = t->fwmark; + struct ip_tunnel_parm p; + int err; + + err = ipgre_newlink_encap_setup(dev, data); + if (err) + return err; + + err = ipgre_netlink_parms(dev, data, tb, &p, &fwmark); + if (err < 0) + return err; + + err = 
ip_tunnel_changelink(dev, tb, &p, fwmark); + if (err < 0) + return err; + + t->parms.i_flags = p.i_flags; + t->parms.o_flags = p.o_flags; + + ipgre_link_update(dev, !tb[IFLA_MTU]); + + return 0; +} + +static int erspan_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = netdev_priv(dev); + __u32 fwmark = t->fwmark; + struct ip_tunnel_parm p; + int err; + + err = ipgre_newlink_encap_setup(dev, data); + if (err) + return err; + + err = erspan_netlink_parms(dev, data, tb, &p, &fwmark); + if (err < 0) + return err; + + err = ip_tunnel_changelink(dev, tb, &p, fwmark); + if (err < 0) + return err; + + t->parms.i_flags = p.i_flags; + t->parms.o_flags = p.o_flags; + + return 0; +} + +static size_t ipgre_get_size(const struct net_device *dev) +{ + return + /* IFLA_GRE_LINK */ + nla_total_size(4) + + /* IFLA_GRE_IFLAGS */ + nla_total_size(2) + + /* IFLA_GRE_OFLAGS */ + nla_total_size(2) + + /* IFLA_GRE_IKEY */ + nla_total_size(4) + + /* IFLA_GRE_OKEY */ + nla_total_size(4) + + /* IFLA_GRE_LOCAL */ + nla_total_size(4) + + /* IFLA_GRE_REMOTE */ + nla_total_size(4) + + /* IFLA_GRE_TTL */ + nla_total_size(1) + + /* IFLA_GRE_TOS */ + nla_total_size(1) + + /* IFLA_GRE_PMTUDISC */ + nla_total_size(1) + + /* IFLA_GRE_ENCAP_TYPE */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_FLAGS */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_SPORT */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_DPORT */ + nla_total_size(2) + + /* IFLA_GRE_COLLECT_METADATA */ + nla_total_size(0) + + /* IFLA_GRE_IGNORE_DF */ + nla_total_size(1) + + /* IFLA_GRE_FWMARK */ + nla_total_size(4) + + /* IFLA_GRE_ERSPAN_INDEX */ + nla_total_size(4) + + /* IFLA_GRE_ERSPAN_VER */ + nla_total_size(1) + + /* IFLA_GRE_ERSPAN_DIR */ + nla_total_size(1) + + /* IFLA_GRE_ERSPAN_HWID */ + nla_total_size(2) + + 0; +} + +static int ipgre_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm *p = &t->parms; + __be16 o_flags = p->o_flags; + + if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || + nla_put_be16(skb, IFLA_GRE_IFLAGS, + gre_tnl_flags_to_gre_flags(p->i_flags)) || + nla_put_be16(skb, IFLA_GRE_OFLAGS, + gre_tnl_flags_to_gre_flags(o_flags)) || + nla_put_be32(skb, IFLA_GRE_IKEY, p->i_key) || + nla_put_be32(skb, IFLA_GRE_OKEY, p->o_key) || + nla_put_in_addr(skb, IFLA_GRE_LOCAL, p->iph.saddr) || + nla_put_in_addr(skb, IFLA_GRE_REMOTE, p->iph.daddr) || + nla_put_u8(skb, IFLA_GRE_TTL, p->iph.ttl) || + nla_put_u8(skb, IFLA_GRE_TOS, p->iph.tos) || + nla_put_u8(skb, IFLA_GRE_PMTUDISC, + !!(p->iph.frag_off & htons(IP_DF))) || + nla_put_u32(skb, IFLA_GRE_FWMARK, t->fwmark)) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_GRE_ENCAP_TYPE, + t->encap.type) || + nla_put_be16(skb, IFLA_GRE_ENCAP_SPORT, + t->encap.sport) || + nla_put_be16(skb, IFLA_GRE_ENCAP_DPORT, + t->encap.dport) || + nla_put_u16(skb, IFLA_GRE_ENCAP_FLAGS, + t->encap.flags)) + goto nla_put_failure; + + if (nla_put_u8(skb, IFLA_GRE_IGNORE_DF, t->ignore_df)) + goto nla_put_failure; + + if (t->collect_md) { + if (nla_put_flag(skb, IFLA_GRE_COLLECT_METADATA)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int erspan_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (t->erspan_ver <= 2) { + if (t->erspan_ver != 0 && !t->collect_md) + t->parms.o_flags |= TUNNEL_KEY; + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, t->erspan_ver)) + goto 
nla_put_failure; + + if (t->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, t->index)) + goto nla_put_failure; + } else if (t->erspan_ver == 2) { + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, t->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, t->hwid)) + goto nla_put_failure; + } + } + + return ipgre_fill_info(skb, dev); + +nla_put_failure: + return -EMSGSIZE; +} + +static void erspan_setup(struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + ether_setup(dev); + dev->max_mtu = 0; + dev->netdev_ops = &erspan_netdev_ops; + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + ip_tunnel_setup(dev, erspan_net_id); + t->erspan_ver = 1; +} + +static const struct nla_policy ipgre_policy[IFLA_GRE_MAX + 1] = { + [IFLA_GRE_LINK] = { .type = NLA_U32 }, + [IFLA_GRE_IFLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_OFLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_IKEY] = { .type = NLA_U32 }, + [IFLA_GRE_OKEY] = { .type = NLA_U32 }, + [IFLA_GRE_LOCAL] = { .len = sizeof_field(struct iphdr, saddr) }, + [IFLA_GRE_REMOTE] = { .len = sizeof_field(struct iphdr, daddr) }, + [IFLA_GRE_TTL] = { .type = NLA_U8 }, + [IFLA_GRE_TOS] = { .type = NLA_U8 }, + [IFLA_GRE_PMTUDISC] = { .type = NLA_U8 }, + [IFLA_GRE_ENCAP_TYPE] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_FLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_SPORT] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_DPORT] = { .type = NLA_U16 }, + [IFLA_GRE_COLLECT_METADATA] = { .type = NLA_FLAG }, + [IFLA_GRE_IGNORE_DF] = { .type = NLA_U8 }, + [IFLA_GRE_FWMARK] = { .type = NLA_U32 }, + [IFLA_GRE_ERSPAN_INDEX] = { .type = NLA_U32 }, + [IFLA_GRE_ERSPAN_VER] = { .type = NLA_U8 }, + [IFLA_GRE_ERSPAN_DIR] = { .type = NLA_U8 }, + [IFLA_GRE_ERSPAN_HWID] = { .type = NLA_U16 }, +}; + +static struct rtnl_link_ops ipgre_link_ops __read_mostly = { + .kind = "gre", + .maxtype = IFLA_GRE_MAX, + .policy = ipgre_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = ipgre_tunnel_setup, + .validate = ipgre_tunnel_validate, + .newlink = ipgre_newlink, + .changelink = ipgre_changelink, + .dellink = ip_tunnel_dellink, + .get_size = ipgre_get_size, + .fill_info = ipgre_fill_info, + .get_link_net = ip_tunnel_get_link_net, +}; + +static struct rtnl_link_ops ipgre_tap_ops __read_mostly = { + .kind = "gretap", + .maxtype = IFLA_GRE_MAX, + .policy = ipgre_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = ipgre_tap_setup, + .validate = ipgre_tap_validate, + .newlink = ipgre_newlink, + .changelink = ipgre_changelink, + .dellink = ip_tunnel_dellink, + .get_size = ipgre_get_size, + .fill_info = ipgre_fill_info, + .get_link_net = ip_tunnel_get_link_net, +}; + +static struct rtnl_link_ops erspan_link_ops __read_mostly = { + .kind = "erspan", + .maxtype = IFLA_GRE_MAX, + .policy = ipgre_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = erspan_setup, + .validate = erspan_validate, + .newlink = erspan_newlink, + .changelink = erspan_changelink, + .dellink = ip_tunnel_dellink, + .get_size = ipgre_get_size, + .fill_info = erspan_fill_info, + .get_link_net = ip_tunnel_get_link_net, +}; + +struct net_device *gretap_fb_dev_create(struct net *net, const char *name, + u8 name_assign_type) +{ + struct nlattr *tb[IFLA_MAX + 1]; + struct net_device *dev; + LIST_HEAD(list_kill); + struct ip_tunnel *t; + int err; + + memset(&tb, 0, sizeof(tb)); + + dev = rtnl_create_link(net, name, name_assign_type, + &ipgre_tap_ops, tb, NULL); + if (IS_ERR(dev)) + return dev; + + /* Configure flow based GRE device. 
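+ * collect_md puts the device in external (metadata) mode, so each skb
+ * is encapsulated according to its attached tunnel metadata rather than
+ * fixed tunnel parameters, which is what the openvswitch vport code
+ * relies on.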
*/ + t = netdev_priv(dev); + t->collect_md = true; + + err = ipgre_newlink(net, dev, tb, NULL, NULL); + if (err < 0) { + free_netdev(dev); + return ERR_PTR(err); + } + + /* openvswitch users expect packet sizes to be unrestricted, + * so set the largest MTU we can. + */ + err = __ip_tunnel_change_mtu(dev, IP_MAX_MTU, false); + if (err) + goto out; + + err = rtnl_configure_link(dev, NULL); + if (err < 0) + goto out; + + return dev; +out: + ip_tunnel_dellink(dev, &list_kill); + unregister_netdevice_many(&list_kill); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(gretap_fb_dev_create); + +static int __net_init ipgre_tap_init_net(struct net *net) +{ + return ip_tunnel_init_net(net, gre_tap_net_id, &ipgre_tap_ops, "gretap0"); +} + +static void __net_exit ipgre_tap_exit_batch_net(struct list_head *list_net) +{ + ip_tunnel_delete_nets(list_net, gre_tap_net_id, &ipgre_tap_ops); +} + +static struct pernet_operations ipgre_tap_net_ops = { + .init = ipgre_tap_init_net, + .exit_batch = ipgre_tap_exit_batch_net, + .id = &gre_tap_net_id, + .size = sizeof(struct ip_tunnel_net), +}; + +static int __net_init erspan_init_net(struct net *net) +{ + return ip_tunnel_init_net(net, erspan_net_id, + &erspan_link_ops, "erspan0"); +} + +static void __net_exit erspan_exit_batch_net(struct list_head *net_list) +{ + ip_tunnel_delete_nets(net_list, erspan_net_id, &erspan_link_ops); +} + +static struct pernet_operations erspan_net_ops = { + .init = erspan_init_net, + .exit_batch = erspan_exit_batch_net, + .id = &erspan_net_id, + .size = sizeof(struct ip_tunnel_net), +}; + +static int __init ipgre_init(void) +{ + int err; + + pr_info("GRE over IPv4 tunneling driver\n"); + + err = register_pernet_device(&ipgre_net_ops); + if (err < 0) + return err; + + err = register_pernet_device(&ipgre_tap_net_ops); + if (err < 0) + goto pnet_tap_failed; + + err = register_pernet_device(&erspan_net_ops); + if (err < 0) + goto pnet_erspan_failed; + + err = gre_add_protocol(&ipgre_protocol, GREPROTO_CISCO); + if (err < 0) { + pr_info("%s: can't add protocol\n", __func__); + goto add_proto_failed; + } + + err = rtnl_link_register(&ipgre_link_ops); + if (err < 0) + goto rtnl_link_failed; + + err = rtnl_link_register(&ipgre_tap_ops); + if (err < 0) + goto tap_ops_failed; + + err = rtnl_link_register(&erspan_link_ops); + if (err < 0) + goto erspan_link_failed; + + return 0; + +erspan_link_failed: + rtnl_link_unregister(&ipgre_tap_ops); +tap_ops_failed: + rtnl_link_unregister(&ipgre_link_ops); +rtnl_link_failed: + gre_del_protocol(&ipgre_protocol, GREPROTO_CISCO); +add_proto_failed: + unregister_pernet_device(&erspan_net_ops); +pnet_erspan_failed: + unregister_pernet_device(&ipgre_tap_net_ops); +pnet_tap_failed: + unregister_pernet_device(&ipgre_net_ops); + return err; +} + +static void __exit ipgre_fini(void) +{ + rtnl_link_unregister(&ipgre_tap_ops); + rtnl_link_unregister(&ipgre_link_ops); + rtnl_link_unregister(&erspan_link_ops); + gre_del_protocol(&ipgre_protocol, GREPROTO_CISCO); + unregister_pernet_device(&ipgre_tap_net_ops); + unregister_pernet_device(&ipgre_net_ops); + unregister_pernet_device(&erspan_net_ops); +} + +module_init(ipgre_init); +module_exit(ipgre_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("gre"); +MODULE_ALIAS_RTNL_LINK("gretap"); +MODULE_ALIAS_RTNL_LINK("erspan"); +MODULE_ALIAS_NETDEV("gre0"); +MODULE_ALIAS_NETDEV("gretap0"); +MODULE_ALIAS_NETDEV("erspan0"); diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c new file mode 100644 index 000000000..e7196ecff --- /dev/null +++ b/net/ipv4/ip_input.c @@ -0,0 
+1,675 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The Internet Protocol (IP) module. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Donald Becker, + * Alan Cox, + * Richard Underwood + * Stefan Becker, + * Jorge Cwik, + * Arnt Gulbrandsen, + * + * Fixes: + * Alan Cox : Commented a couple of minor bits of surplus code + * Alan Cox : Undefining IP_FORWARD doesn't include the code + * (just stops a compiler warning). + * Alan Cox : Frames with >=MAX_ROUTE record routes, strict routes or loose routes + * are junked rather than corrupting things. + * Alan Cox : Frames to bad broadcast subnets are dumped + * We used to process them non broadcast and + * boy could that cause havoc. + * Alan Cox : ip_forward sets the free flag on the + * new frame it queues. Still crap because + * it copies the frame but at least it + * doesn't eat memory too. + * Alan Cox : Generic queue code and memory fixes. + * Fred Van Kempen : IP fragment support (borrowed from NET2E) + * Gerhard Koerting: Forward fragmented frames correctly. + * Gerhard Koerting: Fixes to my fix of the above 8-). + * Gerhard Koerting: IP interface addressing fix. + * Linus Torvalds : More robustness checks + * Alan Cox : Even more checks: Still not as robust as it ought to be + * Alan Cox : Save IP header pointer for later + * Alan Cox : ip option setting + * Alan Cox : Use ip_tos/ip_ttl settings + * Alan Cox : Fragmentation bogosity removed + * (Thanks to Mark.Bush@prg.ox.ac.uk) + * Dmitry Gorodchanin : Send of a raw packet crash fix. + * Alan Cox : Silly ip bug when an overlength + * fragment turns up. Now frees the + * queue. + * Linus Torvalds/ : Memory leakage on fragmentation + * Alan Cox : handling. + * Gerhard Koerting: Forwarding uses IP priority hints + * Teemu Rantanen : Fragment problems. + * Alan Cox : General cleanup, comments and reformat + * Alan Cox : SNMP statistics + * Alan Cox : BSD address rule semantics. Also see + * UDP as there is a nasty checksum issue + * if you do things the wrong way. + * Alan Cox : Always defrag, moved IP_FORWARD to the config.in file + * Alan Cox : IP options adjust sk->priority. + * Pedro Roque : Fix mtu/length error in ip_forward. + * Alan Cox : Avoid ip_chk_addr when possible. + * Richard Underwood : IP multicasting. + * Alan Cox : Cleaned up multicast handlers. + * Alan Cox : RAW sockets demultiplex in the BSD style. + * Gunther Mayer : Fix the SNMP reporting typo + * Alan Cox : Always in group 224.0.0.1 + * Pauline Middelink : Fast ip_checksum update when forwarding + * Masquerading support. + * Alan Cox : Multicast loopback error for 224.0.0.1 + * Alan Cox : IP_MULTICAST_LOOP option. + * Alan Cox : Use notifiers. + * Bjorn Ekwall : Removed ip_csum (from slhc.c too) + * Bjorn Ekwall : Moved ip_fast_csum to ip.h (inline!) + * Stefan Becker : Send out ICMP HOST REDIRECT + * Arnt Gulbrandsen : ip_build_xmit + * Alan Cox : Per socket routing cache + * Alan Cox : Fixed routing cache, added header cache. + * Alan Cox : Loopback didn't work right in original ip_build_xmit - fixed it. + * Alan Cox : Only send ICMP_REDIRECT if src/dest are the same net. + * Alan Cox : Incoming IP option handling. + * Alan Cox : Set saddr on raw output frames as per BSD. + * Alan Cox : Stopped broadcast source route explosions. 
+ * Alan Cox : Can disable source routing + * Takeshi Sone : Masquerading didn't work. + * Dave Bonn,Alan Cox : Faster IP forwarding whenever possible. + * Alan Cox : Memory leaks, tramples, misc debugging. + * Alan Cox : Fixed multicast (by popular demand 8)) + * Alan Cox : Fixed forwarding (by even more popular demand 8)) + * Alan Cox : Fixed SNMP statistics [I think] + * Gerhard Koerting : IP fragmentation forwarding fix + * Alan Cox : Device lock against page fault. + * Alan Cox : IP_HDRINCL facility. + * Werner Almesberger : Zero fragment bug + * Alan Cox : RAW IP frame length bug + * Alan Cox : Outgoing firewall on build_xmit + * A.N.Kuznetsov : IP_OPTIONS support throughout the kernel + * Alan Cox : Multicast routing hooks + * Jos Vos : Do accounting *before* call_in_firewall + * Willy Konynenberg : Transparent proxying support + * + * To Fix: + * IP fragmentation wants rewriting cleanly. The RFC815 algorithm is much more efficient + * and could be made very efficient with the addition of some virtual memory hacks to permit + * the allocation of a buffer that can then be 'grown' by twiddling page tables. + * Output fragmentation wants updating along with the buffer management to use a single + * interleaved copy algorithm so that fragmenting has a one copy overhead. Actual packet + * output should probably do its own fragmentation at the UDP/RAW layer. TCP shouldn't cause + * fragmentation anyway. + */ + +#define pr_fmt(fmt) "IPv4: " fmt + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Process Router Attention IP option (RFC 2113) + */ +bool ip_call_ra_chain(struct sk_buff *skb) +{ + struct ip_ra_chain *ra; + u8 protocol = ip_hdr(skb)->protocol; + struct sock *last = NULL; + struct net_device *dev = skb->dev; + struct net *net = dev_net(dev); + + for (ra = rcu_dereference(net->ipv4.ra_chain); ra; ra = rcu_dereference(ra->next)) { + struct sock *sk = ra->sk; + + /* If socket is bound to an interface, only report + * the packet if it came from that interface. 
+ */ + if (sk && inet_sk(sk)->inet_num == protocol && + (!sk->sk_bound_dev_if || + sk->sk_bound_dev_if == dev->ifindex)) { + if (ip_is_fragment(ip_hdr(skb))) { + if (ip_defrag(net, skb, IP_DEFRAG_CALL_RA_CHAIN)) + return true; + } + if (last) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) + raw_rcv(last, skb2); + } + last = sk; + } + } + + if (last) { + raw_rcv(last, skb); + return true; + } + return false; +} + +INDIRECT_CALLABLE_DECLARE(int udp_rcv(struct sk_buff *)); +INDIRECT_CALLABLE_DECLARE(int tcp_v4_rcv(struct sk_buff *)); +void ip_protocol_deliver_rcu(struct net *net, struct sk_buff *skb, int protocol) +{ + const struct net_protocol *ipprot; + int raw, ret; + +resubmit: + raw = raw_local_deliver(skb, protocol); + + ipprot = rcu_dereference(inet_protos[protocol]); + if (ipprot) { + if (!ipprot->no_policy) { + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + kfree_skb_reason(skb, + SKB_DROP_REASON_XFRM_POLICY); + return; + } + nf_reset_ct(skb); + } + ret = INDIRECT_CALL_2(ipprot->handler, tcp_v4_rcv, udp_rcv, + skb); + if (ret < 0) { + protocol = -ret; + goto resubmit; + } + __IP_INC_STATS(net, IPSTATS_MIB_INDELIVERS); + } else { + if (!raw) { + if (xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + __IP_INC_STATS(net, IPSTATS_MIB_INUNKNOWNPROTOS); + icmp_send(skb, ICMP_DEST_UNREACH, + ICMP_PROT_UNREACH, 0); + } + kfree_skb_reason(skb, SKB_DROP_REASON_IP_NOPROTO); + } else { + __IP_INC_STATS(net, IPSTATS_MIB_INDELIVERS); + consume_skb(skb); + } + } +} + +static int ip_local_deliver_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb_clear_delivery_time(skb); + __skb_pull(skb, skb_network_header_len(skb)); + + rcu_read_lock(); + ip_protocol_deliver_rcu(net, skb, ip_hdr(skb)->protocol); + rcu_read_unlock(); + + return 0; +} + +/* + * Deliver IP Packets to the higher protocol layers. + */ +int ip_local_deliver(struct sk_buff *skb) +{ + /* + * Reassemble IP fragments. + */ + struct net *net = dev_net(skb->dev); + + if (ip_is_fragment(ip_hdr(skb))) { + if (ip_defrag(net, skb, IP_DEFRAG_LOCAL_DELIVER)) + return 0; + } + + return NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_IN, + net, NULL, skb, skb->dev, NULL, + ip_local_deliver_finish); +} +EXPORT_SYMBOL(ip_local_deliver); + +static inline bool ip_rcv_options(struct sk_buff *skb, struct net_device *dev) +{ + struct ip_options *opt; + const struct iphdr *iph; + + /* It looks as overkill, because not all + IP options require packet mangling. + But it is the easiest for now, especially taking + into account that combination of IP options + and running sniffer is extremely rare condition. 
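The dispatch in ip_protocol_deliver_rcu() above looks the transport handler up in inet_protos[] and, when a handler returns a negative value -N, resubmits the packet as protocol N. A hedged sketch of how a module registers such a handler; the protocol number and function names are assumptions for illustration:

#include <linux/init.h>
#include <linux/skbuff.h>
#include <net/protocol.h>

#define DEMO_IPPROTO	253	/* RFC 3692 experimental protocol number */

static int demo_proto_rcv(struct sk_buff *skb)
{
	/* Returning a negative value -N here would make
	 * ip_protocol_deliver_rcu() resubmit the skb as protocol N
	 * (the "goto resubmit" path above); 0 means delivered.
	 */
	consume_skb(skb);
	return 0;
}

static const struct net_protocol demo_protocol = {
	.handler   = demo_proto_rcv,
	.no_policy = 1,		/* skip the xfrm policy check, as ICMP does */
};

static int __init demo_proto_init(void)
{
	return inet_add_protocol(&demo_protocol, DEMO_IPPROTO);
}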
+ --ANK (980813) + */ + if (skb_cow(skb, skb_headroom(skb))) { + __IP_INC_STATS(dev_net(dev), IPSTATS_MIB_INDISCARDS); + goto drop; + } + + iph = ip_hdr(skb); + opt = &(IPCB(skb)->opt); + opt->optlen = iph->ihl*4 - sizeof(struct iphdr); + + if (ip_options_compile(dev_net(dev), opt, skb)) { + __IP_INC_STATS(dev_net(dev), IPSTATS_MIB_INHDRERRORS); + goto drop; + } + + if (unlikely(opt->srr)) { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + if (in_dev) { + if (!IN_DEV_SOURCE_ROUTE(in_dev)) { + if (IN_DEV_LOG_MARTIANS(in_dev)) + net_info_ratelimited("source route option %pI4 -> %pI4\n", + &iph->saddr, + &iph->daddr); + goto drop; + } + } + + if (ip_options_rcv_srr(skb, dev)) + goto drop; + } + + return false; +drop: + return true; +} + +static bool ip_can_use_hint(const struct sk_buff *skb, const struct iphdr *iph, + const struct sk_buff *hint) +{ + return hint && !skb_dst(skb) && ip_hdr(hint)->daddr == iph->daddr && + ip_hdr(hint)->tos == iph->tos; +} + +int tcp_v4_early_demux(struct sk_buff *skb); +int udp_v4_early_demux(struct sk_buff *skb); +static int ip_rcv_finish_core(struct net *net, struct sock *sk, + struct sk_buff *skb, struct net_device *dev, + const struct sk_buff *hint) +{ + const struct iphdr *iph = ip_hdr(skb); + int err, drop_reason; + struct rtable *rt; + + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + + if (ip_can_use_hint(skb, iph, hint)) { + err = ip_route_use_hint(skb, iph->daddr, iph->saddr, iph->tos, + dev, hint); + if (unlikely(err)) + goto drop_error; + } + + if (READ_ONCE(net->ipv4.sysctl_ip_early_demux) && + !skb_dst(skb) && + !skb->sk && + !ip_is_fragment(iph)) { + switch (iph->protocol) { + case IPPROTO_TCP: + if (READ_ONCE(net->ipv4.sysctl_tcp_early_demux)) { + tcp_v4_early_demux(skb); + + /* must reload iph, skb->head might have changed */ + iph = ip_hdr(skb); + } + break; + case IPPROTO_UDP: + if (READ_ONCE(net->ipv4.sysctl_udp_early_demux)) { + err = udp_v4_early_demux(skb); + if (unlikely(err)) + goto drop_error; + + /* must reload iph, skb->head might have changed */ + iph = ip_hdr(skb); + } + break; + } + } + + /* + * Initialise the virtual path cache for the packet. It describes + * how the packet travels inside Linux networking. + */ + if (!skb_valid_dst(skb)) { + err = ip_route_input_noref(skb, iph->daddr, iph->saddr, + iph->tos, dev); + if (unlikely(err)) + goto drop_error; + } else { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + if (in_dev && IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + } + +#ifdef CONFIG_IP_ROUTE_CLASSID + if (unlikely(skb_dst(skb)->tclassid)) { + struct ip_rt_acct *st = this_cpu_ptr(ip_rt_acct); + u32 idx = skb_dst(skb)->tclassid; + st[idx&0xFF].o_packets++; + st[idx&0xFF].o_bytes += skb->len; + st[(idx>>16)&0xFF].i_packets++; + st[(idx>>16)&0xFF].i_bytes += skb->len; + } +#endif + + if (iph->ihl > 5 && ip_rcv_options(skb, dev)) + goto drop; + + rt = skb_rtable(skb); + if (rt->rt_type == RTN_MULTICAST) { + __IP_UPD_PO_STATS(net, IPSTATS_MIB_INMCAST, skb->len); + } else if (rt->rt_type == RTN_BROADCAST) { + __IP_UPD_PO_STATS(net, IPSTATS_MIB_INBCAST, skb->len); + } else if (skb->pkt_type == PACKET_BROADCAST || + skb->pkt_type == PACKET_MULTICAST) { + struct in_device *in_dev = __in_dev_get_rcu(dev); + + /* RFC 1122 3.3.6: + * + * When a host sends a datagram to a link-layer broadcast + * address, the IP destination address MUST be a legal IP + * broadcast or IP multicast address. 
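The early-demux branch in ip_rcv_finish_core() above is gated by one global and two per-protocol sysctls. Purely for illustration, a small userspace helper that flips them through their standard procfs paths:

#include <stdio.h>

static int write_sysctl(const char *path, const char *val)
{
	FILE *f = fopen(path, "w");

	if (!f)
		return -1;
	fputs(val, f);
	return fclose(f);
}

int main(void)
{
	/* Disable the global switch, keep the per-protocol knobs. */
	write_sysctl("/proc/sys/net/ipv4/ip_early_demux", "0");
	write_sysctl("/proc/sys/net/ipv4/tcp_early_demux", "1");
	write_sysctl("/proc/sys/net/ipv4/udp_early_demux", "1");
	return 0;
}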
+ * + * A host SHOULD silently discard a datagram that is received + * via a link-layer broadcast (see Section 2.4) but does not + * specify an IP multicast or broadcast destination address. + * + * This doesn't explicitly say L2 *broadcast*, but broadcast is + * in a way a form of multicast and the most common use case for + * this is 802.11 protecting against cross-station spoofing (the + * so-called "hole-196" attack) so do it for both. + */ + if (in_dev && + IN_DEV_ORCONF(in_dev, DROP_UNICAST_IN_L2_MULTICAST)) { + drop_reason = SKB_DROP_REASON_UNICAST_IN_L2_MULTICAST; + goto drop; + } + } + + return NET_RX_SUCCESS; + +drop: + kfree_skb_reason(skb, drop_reason); + return NET_RX_DROP; + +drop_error: + if (err == -EXDEV) { + drop_reason = SKB_DROP_REASON_IP_RPFILTER; + __NET_INC_STATS(net, LINUX_MIB_IPRPFILTER); + } + goto drop; +} + +static int ip_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + int ret; + + /* if ingress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip_rcv(skb); + if (!skb) + return NET_RX_SUCCESS; + + ret = ip_rcv_finish_core(net, sk, skb, dev, NULL); + if (ret != NET_RX_DROP) + ret = dst_input(skb); + return ret; +} + +/* + * Main IP Receive routine. + */ +static struct sk_buff *ip_rcv_core(struct sk_buff *skb, struct net *net) +{ + const struct iphdr *iph; + int drop_reason; + u32 len; + + /* When the interface is in promisc. mode, drop all the crap + * that it receives, do not try to analyse it. + */ + if (skb->pkt_type == PACKET_OTHERHOST) { + dev_core_stats_rx_otherhost_dropped_inc(skb->dev); + drop_reason = SKB_DROP_REASON_OTHERHOST; + goto drop; + } + + __IP_UPD_PO_STATS(net, IPSTATS_MIB_IN, skb->len); + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) { + __IP_INC_STATS(net, IPSTATS_MIB_INDISCARDS); + goto out; + } + + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto inhdr_error; + + iph = ip_hdr(skb); + + /* + * RFC1122: 3.2.1.2 MUST silently discard any IP frame that fails the checksum. + * + * Is the datagram acceptable? + * + * 1. Length at least the size of an ip header + * 2. Version of 4 + * 3. Checksums correctly. [Speed optimisation for later, skip loopback checksums] + * 4. Doesn't have a bogus length + */ + + if (iph->ihl < 5 || iph->version != 4) + goto inhdr_error; + + BUILD_BUG_ON(IPSTATS_MIB_ECT1PKTS != IPSTATS_MIB_NOECTPKTS + INET_ECN_ECT_1); + BUILD_BUG_ON(IPSTATS_MIB_ECT0PKTS != IPSTATS_MIB_NOECTPKTS + INET_ECN_ECT_0); + BUILD_BUG_ON(IPSTATS_MIB_CEPKTS != IPSTATS_MIB_NOECTPKTS + INET_ECN_CE); + __IP_ADD_STATS(net, + IPSTATS_MIB_NOECTPKTS + (iph->tos & INET_ECN_MASK), + max_t(unsigned short, 1, skb_shinfo(skb)->gso_segs)); + + if (!pskb_may_pull(skb, iph->ihl*4)) + goto inhdr_error; + + iph = ip_hdr(skb); + + if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl))) + goto csum_error; + + len = ntohs(iph->tot_len); + if (skb->len < len) { + drop_reason = SKB_DROP_REASON_PKT_TOO_SMALL; + __IP_INC_STATS(net, IPSTATS_MIB_INTRUNCATEDPKTS); + goto drop; + } else if (len < (iph->ihl*4)) + goto inhdr_error; + + /* Our transport medium may have padded the buffer out. Now we know it + * is IP we can trim to the true length of the frame. + * Note this now means skb->len holds ntohs(iph->tot_len). 
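ip_rcv_core() above spells out the RFC 1122 acceptance tests: header present, version 4, ihl of at least 5, valid checksum, and a plausible total length. The same checks restated over a plain buffer, as a hedged userspace sketch (helper names are illustrative, and the buffer is assumed suitably aligned):

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <arpa/inet.h>		/* ntohs() */
#include <netinet/ip.h>		/* struct iphdr */

/* One's-complement sum over the header; a valid header checksums to 0. */
static uint16_t ipv4_csum(const void *buf, size_t words)
{
	const uint16_t *p = buf;
	uint32_t sum = 0;

	while (words--)
		sum += *p++;
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)~sum;
}

static bool ipv4_header_ok(const uint8_t *pkt, size_t caplen)
{
	const struct iphdr *iph = (const struct iphdr *)pkt;
	size_t hlen, tot_len;

	if (caplen < sizeof(*iph))		/* pskb_may_pull() analogue */
		return false;
	if (iph->version != 4 || iph->ihl < 5)
		return false;
	hlen = (size_t)iph->ihl * 4;
	if (caplen < hlen)
		return false;
	if (ipv4_csum(pkt, hlen / 2) != 0)	/* ip_fast_csum() analogue */
		return false;
	tot_len = ntohs(iph->tot_len);
	return tot_len >= hlen && tot_len <= caplen;
}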
+ */ + if (pskb_trim_rcsum(skb, len)) { + __IP_INC_STATS(net, IPSTATS_MIB_INDISCARDS); + goto drop; + } + + iph = ip_hdr(skb); + skb->transport_header = skb->network_header + iph->ihl*4; + + /* Remove any debris in the socket control block */ + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + IPCB(skb)->iif = skb->skb_iif; + + /* Must drop socket now because of tproxy. */ + if (!skb_sk_is_prefetched(skb)) + skb_orphan(skb); + + return skb; + +csum_error: + drop_reason = SKB_DROP_REASON_IP_CSUM; + __IP_INC_STATS(net, IPSTATS_MIB_CSUMERRORS); +inhdr_error: + if (drop_reason == SKB_DROP_REASON_NOT_SPECIFIED) + drop_reason = SKB_DROP_REASON_IP_INHDR; + __IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); +drop: + kfree_skb_reason(skb, drop_reason); +out: + return NULL; +} + +/* + * IP receive entry point + */ +int ip_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, + struct net_device *orig_dev) +{ + struct net *net = dev_net(dev); + + skb = ip_rcv_core(skb, net); + if (skb == NULL) + return NET_RX_DROP; + + return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, + net, NULL, skb, dev, NULL, + ip_rcv_finish); +} + +static void ip_sublist_rcv_finish(struct list_head *head) +{ + struct sk_buff *skb, *next; + + list_for_each_entry_safe(skb, next, head, list) { + skb_list_del_init(skb); + dst_input(skb); + } +} + +static struct sk_buff *ip_extract_route_hint(const struct net *net, + struct sk_buff *skb, int rt_type) +{ + if (fib4_has_custom_rules(net) || rt_type == RTN_BROADCAST || + IPCB(skb)->flags & IPSKB_MULTIPATH) + return NULL; + + return skb; +} + +static void ip_list_rcv_finish(struct net *net, struct sock *sk, + struct list_head *head) +{ + struct sk_buff *skb, *next, *hint = NULL; + struct dst_entry *curr_dst = NULL; + struct list_head sublist; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + struct net_device *dev = skb->dev; + struct dst_entry *dst; + + skb_list_del_init(skb); + /* if ingress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip_rcv(skb); + if (!skb) + continue; + if (ip_rcv_finish_core(net, sk, skb, dev, hint) == NET_RX_DROP) + continue; + + dst = skb_dst(skb); + if (curr_dst != dst) { + hint = ip_extract_route_hint(net, skb, + ((struct rtable *)dst)->rt_type); + + /* dispatch old sublist */ + if (!list_empty(&sublist)) + ip_sublist_rcv_finish(&sublist); + /* start new sublist */ + INIT_LIST_HEAD(&sublist); + curr_dst = dst; + } + list_add_tail(&skb->list, &sublist); + } + /* dispatch final sublist */ + ip_sublist_rcv_finish(&sublist); +} + +static void ip_sublist_rcv(struct list_head *head, struct net_device *dev, + struct net *net) +{ + NF_HOOK_LIST(NFPROTO_IPV4, NF_INET_PRE_ROUTING, net, NULL, + head, dev, NULL, ip_rcv_finish); + ip_list_rcv_finish(net, NULL, head); +} + +/* Receive a list of IP packets */ +void ip_list_rcv(struct list_head *head, struct packet_type *pt, + struct net_device *orig_dev) +{ + struct net_device *curr_dev = NULL; + struct net *curr_net = NULL; + struct sk_buff *skb, *next; + struct list_head sublist; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + struct net_device *dev = skb->dev; + struct net *net = dev_net(dev); + + skb_list_del_init(skb); + skb = ip_rcv_core(skb, net); + if (skb == NULL) + continue; + + if (curr_dev != dev || curr_net != net) { + /* dispatch old sublist */ + if (!list_empty(&sublist)) + ip_sublist_rcv(&sublist, curr_dev, curr_net); + /* start new sublist */ + 
INIT_LIST_HEAD(&sublist); + curr_dev = dev; + curr_net = net; + } + list_add_tail(&skb->list, &sublist); + } + /* dispatch final sublist */ + if (!list_empty(&sublist)) + ip_sublist_rcv(&sublist, curr_dev, curr_net); +} diff --git a/net/ipv4/ip_options.c b/net/ipv4/ip_options.c new file mode 100644 index 000000000..a9e22a098 --- /dev/null +++ b/net/ipv4/ip_options.c @@ -0,0 +1,641 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The options processing module for ip.c + * + * Authors: A.N.Kuznetsov + * + */ + +#define pr_fmt(fmt) "IPv4: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * Write options to IP header, record destination address to + * source route option, address of outgoing interface + * (we should already know it, so that this function is allowed be + * called only after routing decision) and timestamp, + * if we originate this datagram. + * + * daddr is real destination address, next hop is recorded in IP header. + * saddr is address of outgoing interface. + */ + +void ip_options_build(struct sk_buff *skb, struct ip_options *opt, + __be32 daddr, struct rtable *rt) +{ + unsigned char *iph = skb_network_header(skb); + + memcpy(&(IPCB(skb)->opt), opt, sizeof(struct ip_options)); + memcpy(iph + sizeof(struct iphdr), opt->__data, opt->optlen); + opt = &(IPCB(skb)->opt); + + if (opt->srr) + memcpy(iph + opt->srr + iph[opt->srr + 1] - 4, &daddr, 4); + + if (opt->rr_needaddr) + ip_rt_get_source(iph + opt->rr + iph[opt->rr + 2] - 5, skb, rt); + if (opt->ts_needaddr) + ip_rt_get_source(iph + opt->ts + iph[opt->ts + 2] - 9, skb, rt); + if (opt->ts_needtime) { + __be32 midtime; + + midtime = inet_current_timestamp(); + memcpy(iph + opt->ts + iph[opt->ts + 2] - 5, &midtime, 4); + } +} + +/* + * Provided (sopt, skb) points to received options, + * build in dopt compiled option set appropriate for answering. + * i.e. invert SRR option, copy anothers, + * and grab room in RR/TS options. + * + * NOTE: dopt cannot point to skb. 
+ */ + +int __ip_options_echo(struct net *net, struct ip_options *dopt, + struct sk_buff *skb, const struct ip_options *sopt) +{ + unsigned char *sptr, *dptr; + int soffset, doffset; + int optlen; + + memset(dopt, 0, sizeof(struct ip_options)); + + if (sopt->optlen == 0) + return 0; + + sptr = skb_network_header(skb); + dptr = dopt->__data; + + if (sopt->rr) { + optlen = sptr[sopt->rr+1]; + soffset = sptr[sopt->rr+2]; + dopt->rr = dopt->optlen + sizeof(struct iphdr); + memcpy(dptr, sptr+sopt->rr, optlen); + if (sopt->rr_needaddr && soffset <= optlen) { + if (soffset + 3 > optlen) + return -EINVAL; + dptr[2] = soffset + 4; + dopt->rr_needaddr = 1; + } + dptr += optlen; + dopt->optlen += optlen; + } + if (sopt->ts) { + optlen = sptr[sopt->ts+1]; + soffset = sptr[sopt->ts+2]; + dopt->ts = dopt->optlen + sizeof(struct iphdr); + memcpy(dptr, sptr+sopt->ts, optlen); + if (soffset <= optlen) { + if (sopt->ts_needaddr) { + if (soffset + 3 > optlen) + return -EINVAL; + dopt->ts_needaddr = 1; + soffset += 4; + } + if (sopt->ts_needtime) { + if (soffset + 3 > optlen) + return -EINVAL; + if ((dptr[3]&0xF) != IPOPT_TS_PRESPEC) { + dopt->ts_needtime = 1; + soffset += 4; + } else { + dopt->ts_needtime = 0; + + if (soffset + 7 <= optlen) { + __be32 addr; + + memcpy(&addr, dptr+soffset-1, 4); + if (inet_addr_type(net, addr) != RTN_UNICAST) { + dopt->ts_needtime = 1; + soffset += 8; + } + } + } + } + dptr[2] = soffset; + } + dptr += optlen; + dopt->optlen += optlen; + } + if (sopt->srr) { + unsigned char *start = sptr+sopt->srr; + __be32 faddr; + + optlen = start[1]; + soffset = start[2]; + doffset = 0; + if (soffset > optlen) + soffset = optlen + 1; + soffset -= 4; + if (soffset > 3) { + memcpy(&faddr, &start[soffset-1], 4); + for (soffset -= 4, doffset = 4; soffset > 3; soffset -= 4, doffset += 4) + memcpy(&dptr[doffset-1], &start[soffset-1], 4); + /* + * RFC1812 requires to fix illegal source routes. + */ + if (memcmp(&ip_hdr(skb)->saddr, + &start[soffset + 3], 4) == 0) + doffset -= 4; + } + if (doffset > 3) { + dopt->faddr = faddr; + dptr[0] = start[0]; + dptr[1] = doffset+3; + dptr[2] = 4; + dptr += doffset+3; + dopt->srr = dopt->optlen + sizeof(struct iphdr); + dopt->optlen += doffset+3; + dopt->is_strictroute = sopt->is_strictroute; + } + } + if (sopt->cipso) { + optlen = sptr[sopt->cipso+1]; + dopt->cipso = dopt->optlen+sizeof(struct iphdr); + memcpy(dptr, sptr+sopt->cipso, optlen); + dptr += optlen; + dopt->optlen += optlen; + } + while (dopt->optlen & 3) { + *dptr++ = IPOPT_END; + dopt->optlen++; + } + return 0; +} + +/* + * Options "fragmenting", just fill options not + * allowed in fragments with NOOPs. + * Simple and stupid 8), but the most efficient way. + */ + +void ip_options_fragment(struct sk_buff *skb) +{ + unsigned char *optptr = skb_network_header(skb) + sizeof(struct iphdr); + struct ip_options *opt = &(IPCB(skb)->opt); + int l = opt->optlen; + int optlen; + + while (l > 0) { + switch (*optptr) { + case IPOPT_END: + return; + case IPOPT_NOOP: + l--; + optptr++; + continue; + } + optlen = optptr[1]; + if (optlen < 2 || optlen > l) + return; + if (!IPOPT_COPIED(*optptr)) + memset(optptr, IPOPT_NOOP, optlen); + l -= optlen; + optptr += optlen; + } + opt->ts = 0; + opt->rr = 0; + opt->rr_needaddr = 0; + opt->ts_needaddr = 0; + opt->ts_needtime = 0; +} + +/* helper used by ip_options_compile() to call fib_compute_spec_dst() + * at most one time. 
+ */ +static void spec_dst_fill(__be32 *spec_dst, struct sk_buff *skb) +{ + if (*spec_dst == htonl(INADDR_ANY)) + *spec_dst = fib_compute_spec_dst(skb); +} + +/* + * Verify options and fill pointers in struct options. + * Caller should clear *opt, and set opt->data. + * If opt == NULL, then skb->data should point to IP header. + */ + +int __ip_options_compile(struct net *net, + struct ip_options *opt, struct sk_buff *skb, + __be32 *info) +{ + __be32 spec_dst = htonl(INADDR_ANY); + unsigned char *pp_ptr = NULL; + struct rtable *rt = NULL; + unsigned char *optptr; + unsigned char *iph; + int optlen, l; + + if (skb) { + rt = skb_rtable(skb); + optptr = (unsigned char *)&(ip_hdr(skb)[1]); + } else + optptr = opt->__data; + iph = optptr - sizeof(struct iphdr); + + for (l = opt->optlen; l > 0; ) { + switch (*optptr) { + case IPOPT_END: + for (optptr++, l--; l > 0; optptr++, l--) { + if (*optptr != IPOPT_END) { + *optptr = IPOPT_END; + opt->is_changed = 1; + } + } + goto eol; + case IPOPT_NOOP: + l--; + optptr++; + continue; + } + if (unlikely(l < 2)) { + pp_ptr = optptr; + goto error; + } + optlen = optptr[1]; + if (optlen < 2 || optlen > l) { + pp_ptr = optptr; + goto error; + } + switch (*optptr) { + case IPOPT_SSRR: + case IPOPT_LSRR: + if (optlen < 3) { + pp_ptr = optptr + 1; + goto error; + } + if (optptr[2] < 4) { + pp_ptr = optptr + 2; + goto error; + } + /* NB: cf RFC-1812 5.2.4.1 */ + if (opt->srr) { + pp_ptr = optptr; + goto error; + } + if (!skb) { + if (optptr[2] != 4 || optlen < 7 || ((optlen-3) & 3)) { + pp_ptr = optptr + 1; + goto error; + } + memcpy(&opt->faddr, &optptr[3], 4); + if (optlen > 7) + memmove(&optptr[3], &optptr[7], optlen-7); + } + opt->is_strictroute = (optptr[0] == IPOPT_SSRR); + opt->srr = optptr - iph; + break; + case IPOPT_RR: + if (opt->rr) { + pp_ptr = optptr; + goto error; + } + if (optlen < 3) { + pp_ptr = optptr + 1; + goto error; + } + if (optptr[2] < 4) { + pp_ptr = optptr + 2; + goto error; + } + if (optptr[2] <= optlen) { + if (optptr[2]+3 > optlen) { + pp_ptr = optptr + 2; + goto error; + } + if (rt) { + spec_dst_fill(&spec_dst, skb); + memcpy(&optptr[optptr[2]-1], &spec_dst, 4); + opt->is_changed = 1; + } + optptr[2] += 4; + opt->rr_needaddr = 1; + } + opt->rr = optptr - iph; + break; + case IPOPT_TIMESTAMP: + if (opt->ts) { + pp_ptr = optptr; + goto error; + } + if (optlen < 4) { + pp_ptr = optptr + 1; + goto error; + } + if (optptr[2] < 5) { + pp_ptr = optptr + 2; + goto error; + } + if (optptr[2] <= optlen) { + unsigned char *timeptr = NULL; + if (optptr[2]+3 > optlen) { + pp_ptr = optptr + 2; + goto error; + } + switch (optptr[3]&0xF) { + case IPOPT_TS_TSONLY: + if (skb) + timeptr = &optptr[optptr[2]-1]; + opt->ts_needtime = 1; + optptr[2] += 4; + break; + case IPOPT_TS_TSANDADDR: + if (optptr[2]+7 > optlen) { + pp_ptr = optptr + 2; + goto error; + } + if (rt) { + spec_dst_fill(&spec_dst, skb); + memcpy(&optptr[optptr[2]-1], &spec_dst, 4); + timeptr = &optptr[optptr[2]+3]; + } + opt->ts_needaddr = 1; + opt->ts_needtime = 1; + optptr[2] += 8; + break; + case IPOPT_TS_PRESPEC: + if (optptr[2]+7 > optlen) { + pp_ptr = optptr + 2; + goto error; + } + { + __be32 addr; + memcpy(&addr, &optptr[optptr[2]-1], 4); + if (inet_addr_type(net, addr) == RTN_UNICAST) + break; + if (skb) + timeptr = &optptr[optptr[2]+3]; + } + opt->ts_needtime = 1; + optptr[2] += 8; + break; + default: + if (!skb && !ns_capable(net->user_ns, CAP_NET_RAW)) { + pp_ptr = optptr + 3; + goto error; + } + break; + } + if (timeptr) { + __be32 midtime; + + midtime = 
inet_current_timestamp(); + memcpy(timeptr, &midtime, 4); + opt->is_changed = 1; + } + } else if ((optptr[3]&0xF) != IPOPT_TS_PRESPEC) { + unsigned int overflow = optptr[3]>>4; + if (overflow == 15) { + pp_ptr = optptr + 3; + goto error; + } + if (skb) { + optptr[3] = (optptr[3]&0xF)|((overflow+1)<<4); + opt->is_changed = 1; + } + } + opt->ts = optptr - iph; + break; + case IPOPT_RA: + if (optlen < 4) { + pp_ptr = optptr + 1; + goto error; + } + if (optptr[2] == 0 && optptr[3] == 0) + opt->router_alert = optptr - iph; + break; + case IPOPT_CIPSO: + if ((!skb && !ns_capable(net->user_ns, CAP_NET_RAW)) || opt->cipso) { + pp_ptr = optptr; + goto error; + } + opt->cipso = optptr - iph; + if (cipso_v4_validate(skb, &optptr)) { + pp_ptr = optptr; + goto error; + } + break; + case IPOPT_SEC: + case IPOPT_SID: + default: + if (!skb && !ns_capable(net->user_ns, CAP_NET_RAW)) { + pp_ptr = optptr; + goto error; + } + break; + } + l -= optlen; + optptr += optlen; + } + +eol: + if (!pp_ptr) + return 0; + +error: + if (info) + *info = htonl((pp_ptr-iph)<<24); + return -EINVAL; +} +EXPORT_SYMBOL(__ip_options_compile); + +int ip_options_compile(struct net *net, + struct ip_options *opt, struct sk_buff *skb) +{ + int ret; + __be32 info; + + ret = __ip_options_compile(net, opt, skb, &info); + if (ret != 0 && skb) + icmp_send(skb, ICMP_PARAMETERPROB, 0, info); + return ret; +} +EXPORT_SYMBOL(ip_options_compile); + +/* + * Undo all the changes done by ip_options_compile(). + */ + +void ip_options_undo(struct ip_options *opt) +{ + if (opt->srr) { + unsigned char *optptr = opt->__data + opt->srr - sizeof(struct iphdr); + + memmove(optptr + 7, optptr + 3, optptr[1] - 7); + memcpy(optptr + 3, &opt->faddr, 4); + } + if (opt->rr_needaddr) { + unsigned char *optptr = opt->__data + opt->rr - sizeof(struct iphdr); + + optptr[2] -= 4; + memset(&optptr[optptr[2] - 1], 0, 4); + } + if (opt->ts) { + unsigned char *optptr = opt->__data + opt->ts - sizeof(struct iphdr); + + if (opt->ts_needtime) { + optptr[2] -= 4; + memset(&optptr[optptr[2] - 1], 0, 4); + if ((optptr[3] & 0xF) == IPOPT_TS_PRESPEC) + optptr[2] -= 4; + } + if (opt->ts_needaddr) { + optptr[2] -= 4; + memset(&optptr[optptr[2] - 1], 0, 4); + } + } +} + +int ip_options_get(struct net *net, struct ip_options_rcu **optp, + sockptr_t data, int optlen) +{ + struct ip_options_rcu *opt; + + opt = kzalloc(sizeof(struct ip_options_rcu) + ((optlen + 3) & ~3), + GFP_KERNEL); + if (!opt) + return -ENOMEM; + if (optlen && copy_from_sockptr(opt->opt.__data, data, optlen)) { + kfree(opt); + return -EFAULT; + } + + while (optlen & 3) + opt->opt.__data[optlen++] = IPOPT_END; + opt->opt.optlen = optlen; + if (optlen && ip_options_compile(net, &opt->opt, NULL)) { + kfree(opt); + return -EINVAL; + } + kfree(*optp); + *optp = opt; + return 0; +} + +void ip_forward_options(struct sk_buff *skb) +{ + struct ip_options *opt = &(IPCB(skb)->opt); + unsigned char *optptr; + struct rtable *rt = skb_rtable(skb); + unsigned char *raw = skb_network_header(skb); + + if (opt->rr_needaddr) { + optptr = (unsigned char *)raw + opt->rr; + ip_rt_get_source(&optptr[optptr[2]-5], skb, rt); + opt->is_changed = 1; + } + if (opt->srr_is_hit) { + int srrptr, srrspace; + + optptr = raw + opt->srr; + + for ( srrptr = optptr[2], srrspace = optptr[1]; + srrptr <= srrspace; + srrptr += 4 + ) { + if (srrptr + 3 > srrspace) + break; + if (memcmp(&opt->nexthop, &optptr[srrptr-1], 4) == 0) + break; + } + if (srrptr + 3 <= srrspace) { + opt->is_changed = 1; + ip_hdr(skb)->daddr = opt->nexthop; + 
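ip_options_get() above is the backend for setsockopt(IPPROTO_IP, IP_OPTIONS, ...): it copies the user buffer, pads it with IPOPT_END to a four-byte multiple, and runs ip_options_compile() with skb == NULL. A small userspace counterpart, assuming an IPv4 socket; the nine-slot Record Route layout is just an example:

#include <string.h>
#include <netinet/in.h>
#include <sys/socket.h>

static int set_record_route(int fd)
{
	unsigned char opt[39];		/* IPv4 allows at most 40 option bytes */

	memset(opt, 0, sizeof(opt));
	opt[0] = 7;			/* IPOPT_RR */
	opt[1] = sizeof(opt);		/* length: 3 header bytes + 9 * 4 */
	opt[2] = 4;			/* pointer to the first free slot */

	/* The kernel pads this to 40 bytes with IPOPT_END and validates it
	 * via ip_options_compile() before attaching it to the socket.
	 */
	return setsockopt(fd, IPPROTO_IP, IP_OPTIONS, opt, sizeof(opt));
}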
ip_rt_get_source(&optptr[srrptr-1], skb, rt); + optptr[2] = srrptr+4; + } else { + net_crit_ratelimited("%s(): Argh! Destination lost!\n", + __func__); + } + if (opt->ts_needaddr) { + optptr = raw + opt->ts; + ip_rt_get_source(&optptr[optptr[2]-9], skb, rt); + opt->is_changed = 1; + } + } + if (opt->is_changed) { + opt->is_changed = 0; + ip_send_check(ip_hdr(skb)); + } +} + +int ip_options_rcv_srr(struct sk_buff *skb, struct net_device *dev) +{ + struct ip_options *opt = &(IPCB(skb)->opt); + int srrspace, srrptr; + __be32 nexthop; + struct iphdr *iph = ip_hdr(skb); + unsigned char *optptr = skb_network_header(skb) + opt->srr; + struct rtable *rt = skb_rtable(skb); + struct rtable *rt2; + unsigned long orefdst; + int err; + + if (!rt) + return 0; + + if (skb->pkt_type != PACKET_HOST) + return -EINVAL; + if (rt->rt_type == RTN_UNICAST) { + if (!opt->is_strictroute) + return 0; + icmp_send(skb, ICMP_PARAMETERPROB, 0, htonl(16<<24)); + return -EINVAL; + } + if (rt->rt_type != RTN_LOCAL) + return -EINVAL; + + for (srrptr = optptr[2], srrspace = optptr[1]; srrptr <= srrspace; srrptr += 4) { + if (srrptr + 3 > srrspace) { + icmp_send(skb, ICMP_PARAMETERPROB, 0, htonl((opt->srr+2)<<24)); + return -EINVAL; + } + memcpy(&nexthop, &optptr[srrptr-1], 4); + + orefdst = skb->_skb_refdst; + skb_dst_set(skb, NULL); + err = ip_route_input(skb, nexthop, iph->saddr, iph->tos, dev); + rt2 = skb_rtable(skb); + if (err || (rt2->rt_type != RTN_UNICAST && rt2->rt_type != RTN_LOCAL)) { + skb_dst_drop(skb); + skb->_skb_refdst = orefdst; + return -EINVAL; + } + refdst_drop(orefdst); + if (rt2->rt_type != RTN_LOCAL) + break; + /* Superfast 8) loopback forward */ + iph->daddr = nexthop; + opt->is_changed = 1; + } + if (srrptr <= srrspace) { + opt->srr_is_hit = 1; + opt->nexthop = nexthop; + opt->is_changed = 1; + } + return 0; +} +EXPORT_SYMBOL(ip_options_rcv_srr); diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c new file mode 100644 index 000000000..e19ef88ae --- /dev/null +++ b/net/ipv4/ip_output.c @@ -0,0 +1,1773 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The Internet Protocol (IP) output module. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Donald Becker, + * Alan Cox, + * Richard Underwood + * Stefan Becker, + * Jorge Cwik, + * Arnt Gulbrandsen, + * Hirokazu Takahashi, + * + * See ip_input.c for original log + * + * Fixes: + * Alan Cox : Missing nonblock feature in ip_build_xmit. + * Mike Kilburn : htons() missing in ip_build_xmit. + * Bradford Johnson: Fix faulty handling of some frames when + * no route is found. + * Alexander Demenshin: Missing sk/skb free in ip_queue_xmit + * (in case if packet not accepted by + * output firewall rules) + * Mike McLagan : Routing by source + * Alexey Kuznetsov: use new route cache + * Andi Kleen: Fix broken PMTU recovery and remove + * some redundant tests. + * Vitaly E. Lavrov : Transparent proxy revived after year coma. + * Andi Kleen : Replace ip_reply with ip_send_reply. + * Andi Kleen : Split fast and slow ip_build_xmit path + * for decreased register pressure on x86 + * and more readability. + * Marc Boucher : When call_out_firewall returns FW_QUEUE, + * silently drop skb instead of failing with -EPERM. + * Detlev Wengorz : Copy protocol for fragments. + * Hirokazu Takahashi: HW checksumming for outgoing UDP + * datagrams. 
+ * Hirokazu Takahashi: sendfile() on UDP works now. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int +ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + unsigned int mtu, + int (*output)(struct net *, struct sock *, struct sk_buff *)); + +/* Generate a checksum for an outgoing IP datagram. */ +void ip_send_check(struct iphdr *iph) +{ + iph->check = 0; + iph->check = ip_fast_csum((unsigned char *)iph, iph->ihl); +} +EXPORT_SYMBOL(ip_send_check); + +int __ip_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct iphdr *iph = ip_hdr(skb); + + iph->tot_len = htons(skb->len); + ip_send_check(iph); + + /* if egress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip_out(sk, skb); + if (unlikely(!skb)) + return 0; + + skb->protocol = htons(ETH_P_IP); + + return nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, + net, sk, skb, NULL, skb_dst(skb)->dev, + dst_output); +} + +int ip_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + int err; + + err = __ip_local_out(net, sk, skb); + if (likely(err == 1)) + err = dst_output(net, sk, skb); + + return err; +} +EXPORT_SYMBOL_GPL(ip_local_out); + +static inline int ip_select_ttl(struct inet_sock *inet, struct dst_entry *dst) +{ + int ttl = inet->uc_ttl; + + if (ttl < 0) + ttl = ip4_dst_hoplimit(dst); + return ttl; +} + +/* + * Add an ip header to a skbuff and send it out. + * + */ +int ip_build_and_send_pkt(struct sk_buff *skb, const struct sock *sk, + __be32 saddr, __be32 daddr, struct ip_options_rcu *opt, + u8 tos) +{ + struct inet_sock *inet = inet_sk(sk); + struct rtable *rt = skb_rtable(skb); + struct net *net = sock_net(sk); + struct iphdr *iph; + + /* Build the IP header. */ + skb_push(skb, sizeof(struct iphdr) + (opt ? opt->opt.optlen : 0)); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + iph->version = 4; + iph->ihl = 5; + iph->tos = tos; + iph->ttl = ip_select_ttl(inet, &rt->dst); + iph->daddr = (opt && opt->opt.srr ? opt->opt.faddr : daddr); + iph->saddr = saddr; + iph->protocol = sk->sk_protocol; + /* Do not bother generating IPID for small packets (eg SYNACK) */ + if (skb->len <= IPV4_MIN_MTU || ip_dont_fragment(sk, &rt->dst)) { + iph->frag_off = htons(IP_DF); + iph->id = 0; + } else { + iph->frag_off = 0; + /* TCP packets here are SYNACK with fat IPv4/TCP options. + * Avoid using the hashed IP ident generator. + */ + if (sk->sk_protocol == IPPROTO_TCP) + iph->id = (__force __be16)get_random_u16(); + else + __ip_select_ident(net, iph, 1); + } + + if (opt && opt->opt.optlen) { + iph->ihl += opt->opt.optlen>>2; + ip_options_build(skb, &opt->opt, daddr, rt); + } + + skb->priority = READ_ONCE(sk->sk_priority); + if (!skb->mark) + skb->mark = READ_ONCE(sk->sk_mark); + + /* Send it out. 
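ip_build_and_send_pkt() above shows the fields the stack fills in when it constructs an IPv4 header itself. For comparison, a hedged userspace sketch of the same header layout for an IP_HDRINCL raw socket (values such as TTL 64 are arbitrary; per raw(7), the kernel still fills in the total length and checksum for such sockets):

#include <stdint.h>
#include <string.h>
#include <arpa/inet.h>
#include <netinet/ip.h>

static void fill_ipv4_header(struct iphdr *iph, uint16_t payload_len,
			     uint8_t protocol, uint32_t saddr_be,
			     uint32_t daddr_be)
{
	memset(iph, 0, sizeof(*iph));
	iph->version  = 4;
	iph->ihl      = 5;			/* no options */
	iph->tos      = 0;
	iph->tot_len  = htons(sizeof(*iph) + payload_len);
	iph->id       = 0;			/* irrelevant once DF is set */
	iph->frag_off = htons(IP_DF);
	iph->ttl      = 64;
	iph->protocol = protocol;
	iph->saddr    = saddr_be;
	iph->daddr    = daddr_be;
	/* iph->check is computed last, over the finished header,
	 * exactly as ip_send_check() does above.
	 */
}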
*/ + return ip_local_out(net, skb->sk, skb); +} +EXPORT_SYMBOL_GPL(ip_build_and_send_pkt); + +static int ip_finish_output2(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct rtable *rt = (struct rtable *)dst; + struct net_device *dev = dst->dev; + unsigned int hh_len = LL_RESERVED_SPACE(dev); + struct neighbour *neigh; + bool is_v6gw = false; + + if (rt->rt_type == RTN_MULTICAST) { + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUTMCAST, skb->len); + } else if (rt->rt_type == RTN_BROADCAST) + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUTBCAST, skb->len); + + if (unlikely(skb_headroom(skb) < hh_len && dev->header_ops)) { + skb = skb_expand_head(skb, hh_len); + if (!skb) + return -ENOMEM; + } + + if (lwtunnel_xmit_redirect(dst->lwtstate)) { + int res = lwtunnel_xmit(skb); + + if (res != LWTUNNEL_XMIT_CONTINUE) + return res; + } + + rcu_read_lock(); + neigh = ip_neigh_for_gw(rt, skb, &is_v6gw); + if (!IS_ERR(neigh)) { + int res; + + sock_confirm_neigh(skb, neigh); + /* if crossing protocols, can not use the cached header */ + res = neigh_output(neigh, skb, is_v6gw); + rcu_read_unlock(); + return res; + } + rcu_read_unlock(); + + net_dbg_ratelimited("%s: No header cache and no neighbour!\n", + __func__); + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_CREATEFAIL); + return PTR_ERR(neigh); +} + +static int ip_finish_output_gso(struct net *net, struct sock *sk, + struct sk_buff *skb, unsigned int mtu) +{ + struct sk_buff *segs, *nskb; + netdev_features_t features; + int ret = 0; + + /* common case: seglen is <= mtu + */ + if (skb_gso_validate_network_len(skb, mtu)) + return ip_finish_output2(net, sk, skb); + + /* Slowpath - GSO segment length exceeds the egress MTU. + * + * This can happen in several cases: + * - Forwarding of a TCP GRO skb, when DF flag is not set. + * - Forwarding of an skb that arrived on a virtualization interface + * (virtio-net/vhost/tap) with TSO/GSO size set by other network + * stack. + * - Local GSO skb transmitted on an NETIF_F_TSO tunnel stacked over an + * interface with a smaller MTU. + * - Arriving GRO skb (or GSO skb in a virtualized environment) that is + * bridged to a NETIF_F_TSO tunnel stacked over an interface with an + * insufficient MTU. 
+ */ + features = netif_skb_features(skb); + BUILD_BUG_ON(sizeof(*IPCB(skb)) > SKB_GSO_CB_OFFSET); + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + if (IS_ERR_OR_NULL(segs)) { + kfree_skb(skb); + return -ENOMEM; + } + + consume_skb(skb); + + skb_list_walk_safe(segs, segs, nskb) { + int err; + + skb_mark_not_on_list(segs); + err = ip_fragment(net, sk, segs, mtu, ip_finish_output2); + + if (err && ret == 0) + ret = err; + } + + return ret; +} + +static int __ip_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + unsigned int mtu; + +#if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM) + /* Policy lookup after SNAT yielded a new policy */ + if (skb_dst(skb)->xfrm) { + IPCB(skb)->flags |= IPSKB_REROUTED; + return dst_output(net, sk, skb); + } +#endif + mtu = ip_skb_dst_mtu(sk, skb); + if (skb_is_gso(skb)) + return ip_finish_output_gso(net, sk, skb, mtu); + + if (skb->len > mtu || IPCB(skb)->frag_max_size) + return ip_fragment(net, sk, skb, mtu, ip_finish_output2); + + return ip_finish_output2(net, sk, skb); +} + +static int ip_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + int ret; + + ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb); + switch (ret) { + case NET_XMIT_SUCCESS: + return __ip_finish_output(net, sk, skb); + case NET_XMIT_CN: + return __ip_finish_output(net, sk, skb) ? : ret; + default: + kfree_skb_reason(skb, SKB_DROP_REASON_BPF_CGROUP_EGRESS); + return ret; + } +} + +static int ip_mc_finish_output(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct rtable *new_rt; + bool do_cn = false; + int ret, err; + + ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb); + switch (ret) { + case NET_XMIT_CN: + do_cn = true; + fallthrough; + case NET_XMIT_SUCCESS: + break; + default: + kfree_skb_reason(skb, SKB_DROP_REASON_BPF_CGROUP_EGRESS); + return ret; + } + + /* Reset rt_iif so that inet_iif() will return skb->skb_iif. Setting + * this to non-zero causes ipi_ifindex in in_pktinfo to be overwritten, + * see ipv4_pktinfo_prepare(). + */ + new_rt = rt_dst_clone(net->loopback_dev, skb_rtable(skb)); + if (new_rt) { + new_rt->rt_iif = 0; + skb_dst_drop(skb); + skb_dst_set(skb, &new_rt->dst); + } + + err = dev_loopback_xmit(net, sk, skb); + return (do_cn && err) ? ret : err; +} + +int ip_mc_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct net_device *dev = rt->dst.dev; + + /* + * If the indicated interface is up and running, send the packet. + */ + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); + + skb->dev = dev; + skb->protocol = htons(ETH_P_IP); + + /* + * Multicasts are looped back for other local users + */ + + if (rt->rt_flags&RTCF_MULTICAST) { + if (sk_mc_loop(sk) +#ifdef CONFIG_IP_MROUTE + /* Small optimization: do not loopback not local frames, + which returned after forwarding; they will be dropped + by ip_mr_input in any case. + Note, that local frames are looped back to be delivered + to local recipients. + + This check is duplicated in ip_mr_input at the moment. 
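The sk_mc_loop() test above decides whether locally originated multicast is cloned back to local receivers; from userspace that is the IP_MULTICAST_LOOP socket option. A minimal illustration (function name assumed):

#include <netinet/in.h>
#include <sys/socket.h>

static int disable_mcast_loopback(int udp_fd)
{
	unsigned char loop = 0;		/* the default is 1: loop back */

	return setsockopt(udp_fd, IPPROTO_IP, IP_MULTICAST_LOOP,
			  &loop, sizeof(loop));
}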
+ */ + && + ((rt->rt_flags & RTCF_LOCAL) || + !(IPCB(skb)->flags & IPSKB_FORWARDED)) +#endif + ) { + struct sk_buff *newskb = skb_clone(skb, GFP_ATOMIC); + if (newskb) + NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, newskb, NULL, newskb->dev, + ip_mc_finish_output); + } + + /* Multicasts with ttl 0 must not go beyond the host */ + + if (ip_hdr(skb)->ttl == 0) { + kfree_skb(skb); + return 0; + } + } + + if (rt->rt_flags&RTCF_BROADCAST) { + struct sk_buff *newskb = skb_clone(skb, GFP_ATOMIC); + if (newskb) + NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, newskb, NULL, newskb->dev, + ip_mc_finish_output); + } + + return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, skb, NULL, skb->dev, + ip_finish_output, + !(IPCB(skb)->flags & IPSKB_REROUTED)); +} + +int ip_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb_dst(skb)->dev, *indev = skb->dev; + + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); + + skb->dev = dev; + skb->protocol = htons(ETH_P_IP); + + return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, skb, indev, dev, + ip_finish_output, + !(IPCB(skb)->flags & IPSKB_REROUTED)); +} +EXPORT_SYMBOL(ip_output); + +/* + * copy saddr and daddr, possibly using 64bit load/stores + * Equivalent to : + * iph->saddr = fl4->saddr; + * iph->daddr = fl4->daddr; + */ +static void ip_copy_addrs(struct iphdr *iph, const struct flowi4 *fl4) +{ + BUILD_BUG_ON(offsetof(typeof(*fl4), daddr) != + offsetof(typeof(*fl4), saddr) + sizeof(fl4->saddr)); + + iph->saddr = fl4->saddr; + iph->daddr = fl4->daddr; +} + +/* Note: skb->sk can be different from sk, in case of tunnels */ +int __ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl, + __u8 tos) +{ + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + struct ip_options_rcu *inet_opt; + struct flowi4 *fl4; + struct rtable *rt; + struct iphdr *iph; + int res; + + /* Skip all of this if the packet is already routed, + * f.e. by something like SCTP. + */ + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + fl4 = &fl->u.ip4; + rt = skb_rtable(skb); + if (rt) + goto packet_routed; + + /* Make sure we can route this packet. */ + rt = (struct rtable *)__sk_dst_check(sk, 0); + if (!rt) { + __be32 daddr; + + /* Use correct destination address if we have options. */ + daddr = inet->inet_daddr; + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + + /* If this fails, retransmit mechanism of transport layer will + * keep trying until route appears or the connection times + * itself out. + */ + rt = ip_route_output_ports(net, fl4, sk, + daddr, inet->inet_saddr, + inet->inet_dport, + inet->inet_sport, + sk->sk_protocol, + RT_CONN_FLAGS_TOS(sk, tos), + sk->sk_bound_dev_if); + if (IS_ERR(rt)) + goto no_route; + sk_setup_caps(sk, &rt->dst); + } + skb_dst_set_noref(skb, &rt->dst); + +packet_routed: + if (inet_opt && inet_opt->opt.is_strictroute && rt->rt_uses_gateway) + goto no_route; + + /* OK, we know where to send it, allocate and build IP header. */ + skb_push(skb, sizeof(struct iphdr) + (inet_opt ? inet_opt->opt.optlen : 0)); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + *((__be16 *)iph) = htons((4 << 12) | (5 << 8) | (tos & 0xff)); + if (ip_dont_fragment(sk, &rt->dst) && !skb->ignore_df) + iph->frag_off = htons(IP_DF); + else + iph->frag_off = 0; + iph->ttl = ip_select_ttl(inet, &rt->dst); + iph->protocol = sk->sk_protocol; + ip_copy_addrs(iph, fl4); + + /* Transport layer set skb->h.foo itself. 
*/ + + if (inet_opt && inet_opt->opt.optlen) { + iph->ihl += inet_opt->opt.optlen >> 2; + ip_options_build(skb, &inet_opt->opt, inet->inet_daddr, rt); + } + + ip_select_ident_segs(net, skb, sk, + skb_shinfo(skb)->gso_segs ?: 1); + + /* TODO : should we use skb->sk here instead of sk ? */ + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = READ_ONCE(sk->sk_mark); + + res = ip_local_out(net, sk, skb); + rcu_read_unlock(); + return res; + +no_route: + rcu_read_unlock(); + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + kfree_skb_reason(skb, SKB_DROP_REASON_IP_OUTNOROUTES); + return -EHOSTUNREACH; +} +EXPORT_SYMBOL(__ip_queue_xmit); + +int ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl) +{ + return __ip_queue_xmit(sk, skb, fl, inet_sk(sk)->tos); +} +EXPORT_SYMBOL(ip_queue_xmit); + +static void ip_copy_metadata(struct sk_buff *to, struct sk_buff *from) +{ + to->pkt_type = from->pkt_type; + to->priority = from->priority; + to->protocol = from->protocol; + to->skb_iif = from->skb_iif; + skb_dst_drop(to); + skb_dst_copy(to, from); + to->dev = from->dev; + to->mark = from->mark; + + skb_copy_hash(to, from); + +#ifdef CONFIG_NET_SCHED + to->tc_index = from->tc_index; +#endif + nf_copy(to, from); + skb_ext_copy(to, from); +#if IS_ENABLED(CONFIG_IP_VS) + to->ipvs_property = from->ipvs_property; +#endif + skb_copy_secmark(to, from); +} + +static int ip_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + unsigned int mtu, + int (*output)(struct net *, struct sock *, struct sk_buff *)) +{ + struct iphdr *iph = ip_hdr(skb); + + if ((iph->frag_off & htons(IP_DF)) == 0) + return ip_do_fragment(net, sk, skb, output); + + if (unlikely(!skb->ignore_df || + (IPCB(skb)->frag_max_size && + IPCB(skb)->frag_max_size > mtu))) { + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + kfree_skb(skb); + return -EMSGSIZE; + } + + return ip_do_fragment(net, sk, skb, output); +} + +void ip_fraglist_init(struct sk_buff *skb, struct iphdr *iph, + unsigned int hlen, struct ip_fraglist_iter *iter) +{ + unsigned int first_len = skb_pagelen(skb); + + iter->frag = skb_shinfo(skb)->frag_list; + skb_frag_list_init(skb); + + iter->offset = 0; + iter->iph = iph; + iter->hlen = hlen; + + skb->data_len = first_len - skb_headlen(skb); + skb->len = first_len; + iph->tot_len = htons(first_len); + iph->frag_off = htons(IP_MF); + ip_send_check(iph); +} +EXPORT_SYMBOL(ip_fraglist_init); + +void ip_fraglist_prepare(struct sk_buff *skb, struct ip_fraglist_iter *iter) +{ + unsigned int hlen = iter->hlen; + struct iphdr *iph = iter->iph; + struct sk_buff *frag; + + frag = iter->frag; + frag->ip_summed = CHECKSUM_NONE; + skb_reset_transport_header(frag); + __skb_push(frag, hlen); + skb_reset_network_header(frag); + memcpy(skb_network_header(frag), iph, hlen); + iter->iph = ip_hdr(frag); + iph = iter->iph; + iph->tot_len = htons(frag->len); + ip_copy_metadata(frag, skb); + iter->offset += skb->len - hlen; + iph->frag_off = htons(iter->offset >> 3); + if (frag->next) + iph->frag_off |= htons(IP_MF); + /* Ready, complete checksum */ + ip_send_check(iph); +} +EXPORT_SYMBOL(ip_fraglist_prepare); + +void ip_frag_init(struct sk_buff *skb, unsigned int hlen, + unsigned int ll_rs, unsigned int mtu, bool DF, + struct ip_frag_state *state) +{ + struct iphdr *iph = ip_hdr(skb); + + state->DF = DF; + state->hlen = hlen; + state->ll_rs = ll_rs; + state->mtu = mtu; + + state->left = skb->len - hlen; /* Space per frame */ + state->ptr = hlen; /* Where to start 
from */ + + state->offset = (ntohs(iph->frag_off) & IP_OFFSET) << 3; + state->not_last_frag = iph->frag_off & htons(IP_MF); +} +EXPORT_SYMBOL(ip_frag_init); + +static void ip_frag_ipcb(struct sk_buff *from, struct sk_buff *to, + bool first_frag) +{ + /* Copy the flags to each fragment. */ + IPCB(to)->flags = IPCB(from)->flags; + + /* ANK: dirty, but effective trick. Upgrade options only if + * the segment to be fragmented was THE FIRST (otherwise, + * options are already fixed) and make it ONCE + * on the initial skb, so that all the following fragments + * will inherit fixed options. + */ + if (first_frag) + ip_options_fragment(from); +} + +struct sk_buff *ip_frag_next(struct sk_buff *skb, struct ip_frag_state *state) +{ + unsigned int len = state->left; + struct sk_buff *skb2; + struct iphdr *iph; + + /* IF: it doesn't fit, use 'mtu' - the data space left */ + if (len > state->mtu) + len = state->mtu; + /* IF: we are not sending up to and including the packet end + then align the next start on an eight byte boundary */ + if (len < state->left) { + len &= ~7; + } + + /* Allocate buffer */ + skb2 = alloc_skb(len + state->hlen + state->ll_rs, GFP_ATOMIC); + if (!skb2) + return ERR_PTR(-ENOMEM); + + /* + * Set up data on packet + */ + + ip_copy_metadata(skb2, skb); + skb_reserve(skb2, state->ll_rs); + skb_put(skb2, len + state->hlen); + skb_reset_network_header(skb2); + skb2->transport_header = skb2->network_header + state->hlen; + + /* + * Charge the memory for the fragment to any owner + * it might possess + */ + + if (skb->sk) + skb_set_owner_w(skb2, skb->sk); + + /* + * Copy the packet header into the new buffer. + */ + + skb_copy_from_linear_data(skb, skb_network_header(skb2), state->hlen); + + /* + * Copy a block of the IP datagram. + */ + if (skb_copy_bits(skb, state->ptr, skb_transport_header(skb2), len)) + BUG(); + state->left -= len; + + /* + * Fill in the new header fields. + */ + iph = ip_hdr(skb2); + iph->frag_off = htons((state->offset >> 3)); + if (state->DF) + iph->frag_off |= htons(IP_DF); + + /* + * Added AC : If we are fragmenting a fragment that's not the + * last fragment then keep MF on each bit + */ + if (state->left > 0 || state->not_last_frag) + iph->frag_off |= htons(IP_MF); + state->ptr += len; + state->offset += len; + + iph->tot_len = htons(len + state->hlen); + + ip_send_check(iph); + + return skb2; +} +EXPORT_SYMBOL(ip_frag_next); + +/* + * This IP datagram is too large to be sent in one piece. Break it up into + * smaller pieces (each of size equal to IP header plus + * a block of the data of the original IP data part) that will yet fit in a + * single device frame, and queue such a frame for sending. + */ + +int ip_do_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + int (*output)(struct net *, struct sock *, struct sk_buff *)) +{ + struct iphdr *iph; + struct sk_buff *skb2; + bool mono_delivery_time = skb->mono_delivery_time; + struct rtable *rt = skb_rtable(skb); + unsigned int mtu, hlen, ll_rs; + struct ip_fraglist_iter iter; + ktime_t tstamp = skb->tstamp; + struct ip_frag_state state; + int err = 0; + + /* for offloaded checksums cleanup checksum before fragmentation */ + if (skb->ip_summed == CHECKSUM_PARTIAL && + (err = skb_checksum_help(skb))) + goto fail; + + /* + * Point into the IP datagram header. + */ + + iph = ip_hdr(skb); + + mtu = ip_skb_dst_mtu(sk, skb); + if (IPCB(skb)->frag_max_size && IPCB(skb)->frag_max_size < mtu) + mtu = IPCB(skb)->frag_max_size; + + /* + * Setup starting values. 
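ip_frag_init()/ip_frag_next() above carry the slow-path geometry: every fragment except the last holds a multiple of eight payload bytes, and frag_off encodes the offset in 8-byte units plus the MF bit. The same arithmetic as a standalone sketch (macro and function names are illustrative):

#include <stdint.h>
#include <stdio.h>

#define DEMO_IP_MF	0x2000	/* "more fragments", cf. IP_MF */

static void print_fragments(unsigned int payload, unsigned int mtu,
			    unsigned int hlen)
{
	unsigned int space = mtu - hlen;	/* data bytes per fragment */
	unsigned int offset = 0;

	if (space < 8)
		return;				/* nothing sensible to do */

	while (payload > 0) {
		unsigned int len = payload;
		uint16_t frag_off;

		if (len > space)
			len = space & ~7U;	/* align all but the last */
		frag_off = (uint16_t)(offset >> 3);
		if (len < payload)
			frag_off |= DEMO_IP_MF;
		printf("offset=%u len=%u frag_off=0x%04x\n",
		       offset, len, frag_off);
		offset += len;
		payload -= len;
	}
}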
+ */ + + hlen = iph->ihl * 4; + mtu = mtu - hlen; /* Size of data space */ + IPCB(skb)->flags |= IPSKB_FRAG_COMPLETE; + ll_rs = LL_RESERVED_SPACE(rt->dst.dev); + + /* When frag_list is given, use it. First, check its validity: + * some transformers could create wrong frag_list or break existing + * one, it is not prohibited. In this case fall back to copying. + * + * LATER: this step can be merged to real generation of fragments, + * we can switch to copy when see the first bad fragment. + */ + if (skb_has_frag_list(skb)) { + struct sk_buff *frag, *frag2; + unsigned int first_len = skb_pagelen(skb); + + if (first_len - hlen > mtu || + ((first_len - hlen) & 7) || + ip_is_fragment(iph) || + skb_cloned(skb) || + skb_headroom(skb) < ll_rs) + goto slow_path; + + skb_walk_frags(skb, frag) { + /* Correct geometry. */ + if (frag->len > mtu || + ((frag->len & 7) && frag->next) || + skb_headroom(frag) < hlen + ll_rs) + goto slow_path_clean; + + /* Partially cloned skb? */ + if (skb_shared(frag)) + goto slow_path_clean; + + BUG_ON(frag->sk); + if (skb->sk) { + frag->sk = skb->sk; + frag->destructor = sock_wfree; + } + skb->truesize -= frag->truesize; + } + + /* Everything is OK. Generate! */ + ip_fraglist_init(skb, iph, hlen, &iter); + + for (;;) { + /* Prepare header of the next frame, + * before previous one went down. */ + if (iter.frag) { + bool first_frag = (iter.offset == 0); + + IPCB(iter.frag)->flags = IPCB(skb)->flags; + ip_fraglist_prepare(skb, &iter); + if (first_frag && IPCB(skb)->opt.optlen) { + /* ipcb->opt is not populated for frags + * coming from __ip_make_skb(), + * ip_options_fragment() needs optlen + */ + IPCB(iter.frag)->opt.optlen = + IPCB(skb)->opt.optlen; + ip_options_fragment(iter.frag); + ip_send_check(iter.iph); + } + } + + skb_set_delivery_time(skb, tstamp, mono_delivery_time); + err = output(net, sk, skb); + + if (!err) + IP_INC_STATS(net, IPSTATS_MIB_FRAGCREATES); + if (err || !iter.frag) + break; + + skb = ip_fraglist_next(&iter); + } + + if (err == 0) { + IP_INC_STATS(net, IPSTATS_MIB_FRAGOKS); + return 0; + } + + kfree_skb_list(iter.frag); + + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + return err; + +slow_path_clean: + skb_walk_frags(skb, frag2) { + if (frag2 == frag) + break; + frag2->sk = NULL; + frag2->destructor = NULL; + skb->truesize += frag2->truesize; + } + } + +slow_path: + /* + * Fragment the datagram. + */ + + ip_frag_init(skb, hlen, ll_rs, mtu, IPCB(skb)->flags & IPSKB_FRAG_PMTU, + &state); + + /* + * Keep copying data until we run out. + */ + + while (state.left > 0) { + bool first_frag = (state.offset == 0); + + skb2 = ip_frag_next(skb, &state); + if (IS_ERR(skb2)) { + err = PTR_ERR(skb2); + goto fail; + } + ip_frag_ipcb(skb, skb2, first_frag); + + /* + * Put this fragment into the sending queue. 
+ */ + skb_set_delivery_time(skb2, tstamp, mono_delivery_time); + err = output(net, sk, skb2); + if (err) + goto fail; + + IP_INC_STATS(net, IPSTATS_MIB_FRAGCREATES); + } + consume_skb(skb); + IP_INC_STATS(net, IPSTATS_MIB_FRAGOKS); + return err; + +fail: + kfree_skb(skb); + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + return err; +} +EXPORT_SYMBOL(ip_do_fragment); + +int +ip_generic_getfrag(void *from, char *to, int offset, int len, int odd, struct sk_buff *skb) +{ + struct msghdr *msg = from; + + if (skb->ip_summed == CHECKSUM_PARTIAL) { + if (!copy_from_iter_full(to, len, &msg->msg_iter)) + return -EFAULT; + } else { + __wsum csum = 0; + if (!csum_and_copy_from_iter_full(to, len, &csum, &msg->msg_iter)) + return -EFAULT; + skb->csum = csum_block_add(skb->csum, csum, odd); + } + return 0; +} +EXPORT_SYMBOL(ip_generic_getfrag); + +static inline __wsum +csum_page(struct page *page, int offset, int copy) +{ + char *kaddr; + __wsum csum; + kaddr = kmap(page); + csum = csum_partial(kaddr + offset, copy, 0); + kunmap(page); + return csum; +} + +static int __ip_append_data(struct sock *sk, + struct flowi4 *fl4, + struct sk_buff_head *queue, + struct inet_cork *cork, + struct page_frag *pfrag, + int getfrag(void *from, char *to, int offset, + int len, int odd, struct sk_buff *skb), + void *from, int length, int transhdrlen, + unsigned int flags) +{ + struct inet_sock *inet = inet_sk(sk); + struct ubuf_info *uarg = NULL; + struct sk_buff *skb; + struct ip_options *opt = cork->opt; + int hh_len; + int exthdrlen; + int mtu; + int copy; + int err; + int offset = 0; + bool zc = false; + unsigned int maxfraglen, fragheaderlen, maxnonfragsize; + int csummode = CHECKSUM_NONE; + struct rtable *rt = (struct rtable *)cork->dst; + unsigned int wmem_alloc_delta = 0; + bool paged, extra_uref = false; + u32 tskey = 0; + + skb = skb_peek_tail(queue); + + exthdrlen = !skb ? rt->dst.header_len : 0; + mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize; + paged = !!cork->gso_size; + + if (cork->tx_flags & SKBTX_ANY_TSTAMP && + READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID) + tskey = atomic_inc_return(&sk->sk_tskey) - 1; + + hh_len = LL_RESERVED_SPACE(rt->dst.dev); + + fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0); + maxfraglen = ((mtu - fragheaderlen) & ~7) + fragheaderlen; + maxnonfragsize = ip_sk_ignore_df(sk) ? IP_MAX_MTU : mtu; + + if (cork->length + length > maxnonfragsize - fragheaderlen) { + ip_local_error(sk, EMSGSIZE, fl4->daddr, inet->inet_dport, + mtu - (opt ? opt->optlen : 0)); + return -EMSGSIZE; + } + + /* + * transhdrlen > 0 means that this is the first fragment and we wish + * it won't be fragmented in the future. + */ + if (transhdrlen && + length + fragheaderlen <= mtu && + rt->dst.dev->features & (NETIF_F_HW_CSUM | NETIF_F_IP_CSUM) && + (!(flags & MSG_MORE) || cork->gso_size) && + (!exthdrlen || (rt->dst.dev->features & NETIF_F_HW_ESP_TX_CSUM))) + csummode = CHECKSUM_PARTIAL; + + if ((flags & MSG_ZEROCOPY) && length) { + struct msghdr *msg = from; + + if (getfrag == ip_generic_getfrag && msg->msg_ubuf) { + if (skb_zcopy(skb) && msg->msg_ubuf != skb_zcopy(skb)) + return -EINVAL; + + /* Leave uarg NULL if can't zerocopy, callers should + * be able to handle it. 
+ */ + if ((rt->dst.dev->features & NETIF_F_SG) && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + uarg = msg->msg_ubuf; + } + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + uarg = msg_zerocopy_realloc(sk, length, skb_zcopy(skb)); + if (!uarg) + return -ENOBUFS; + extra_uref = !skb_zcopy(skb); /* only ref on new uarg */ + if (rt->dst.dev->features & NETIF_F_SG && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + } else { + uarg_to_msgzc(uarg)->zerocopy = 0; + skb_zcopy_set(skb, uarg, &extra_uref); + } + } + } + + cork->length += length; + + /* So, what's going on in the loop below? + * + * We use calculated fragment length to generate chained skb, + * each of segments is IP fragment ready for sending to network after + * adding appropriate IP header. + */ + + if (!skb) + goto alloc_new_skb; + + while (length > 0) { + /* Check if the remaining data fits into current packet. */ + copy = mtu - skb->len; + if (copy < length) + copy = maxfraglen - skb->len; + if (copy <= 0) { + char *data; + unsigned int datalen; + unsigned int fraglen; + unsigned int fraggap; + unsigned int alloclen, alloc_extra; + unsigned int pagedlen; + struct sk_buff *skb_prev; +alloc_new_skb: + skb_prev = skb; + if (skb_prev) + fraggap = skb_prev->len - maxfraglen; + else + fraggap = 0; + + /* + * If remaining data exceeds the mtu, + * we know we need more fragment(s). + */ + datalen = length + fraggap; + if (datalen > mtu - fragheaderlen) + datalen = maxfraglen - fragheaderlen; + fraglen = datalen + fragheaderlen; + pagedlen = 0; + + alloc_extra = hh_len + 15; + alloc_extra += exthdrlen; + + /* The last fragment gets additional space at tail. + * Note, with MSG_MORE we overallocate on fragments, + * because we have no idea what fragment will be + * the last. + */ + if (datalen == length + fraggap) + alloc_extra += rt->dst.trailer_len; + + if ((flags & MSG_MORE) && + !(rt->dst.dev->features&NETIF_F_SG)) + alloclen = mtu; + else if (!paged && + (fraglen + alloc_extra < SKB_MAX_ALLOC || + !(rt->dst.dev->features & NETIF_F_SG))) + alloclen = fraglen; + else { + alloclen = fragheaderlen + transhdrlen; + pagedlen = datalen - transhdrlen; + } + + alloclen += alloc_extra; + + if (transhdrlen) { + skb = sock_alloc_send_skb(sk, alloclen, + (flags & MSG_DONTWAIT), &err); + } else { + skb = NULL; + if (refcount_read(&sk->sk_wmem_alloc) + wmem_alloc_delta <= + 2 * sk->sk_sndbuf) + skb = alloc_skb(alloclen, + sk->sk_allocation); + if (unlikely(!skb)) + err = -ENOBUFS; + } + if (!skb) + goto error; + + /* + * Fill in the control structures + */ + skb->ip_summed = csummode; + skb->csum = 0; + skb_reserve(skb, hh_len); + + /* + * Find where to start putting bytes. 
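+ *
+ * Sketch of how the MSG_ZEROCOPY path above is reached from user space
+ * (hypothetical snippet, error handling omitted):
+ *
+ *	int one = 1;
+ *
+ *	setsockopt(fd, SOL_SOCKET, SO_ZEROCOPY, &one, sizeof(one));
+ *	send(fd, buf, len, MSG_ZEROCOPY);
+ *
+ * Completion notifications are read back later from the socket error
+ * queue with recvmsg(fd, &msg, MSG_ERRQUEUE).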
+ */ + data = skb_put(skb, fraglen + exthdrlen - pagedlen); + skb_set_network_header(skb, exthdrlen); + skb->transport_header = (skb->network_header + + fragheaderlen); + data += fragheaderlen + exthdrlen; + + if (fraggap) { + skb->csum = skb_copy_and_csum_bits( + skb_prev, maxfraglen, + data + transhdrlen, fraggap); + skb_prev->csum = csum_sub(skb_prev->csum, + skb->csum); + data += fraggap; + pskb_trim_unique(skb_prev, maxfraglen); + } + + copy = datalen - transhdrlen - fraggap - pagedlen; + if (copy > 0 && getfrag(from, data + transhdrlen, offset, copy, fraggap, skb) < 0) { + err = -EFAULT; + kfree_skb(skb); + goto error; + } + + offset += copy; + length -= copy + transhdrlen; + transhdrlen = 0; + exthdrlen = 0; + csummode = CHECKSUM_NONE; + + /* only the initial fragment is time stamped */ + skb_shinfo(skb)->tx_flags = cork->tx_flags; + cork->tx_flags = 0; + skb_shinfo(skb)->tskey = tskey; + tskey = 0; + skb_zcopy_set(skb, uarg, &extra_uref); + + if ((flags & MSG_CONFIRM) && !skb_prev) + skb_set_dst_pending_confirm(skb, 1); + + /* + * Put the packet on the pending queue. + */ + if (!skb->destructor) { + skb->destructor = sock_wfree; + skb->sk = sk; + wmem_alloc_delta += skb->truesize; + } + __skb_queue_tail(queue, skb); + continue; + } + + if (copy > length) + copy = length; + + if (!(rt->dst.dev->features&NETIF_F_SG) && + skb_tailroom(skb) >= copy) { + unsigned int off; + + off = skb->len; + if (getfrag(from, skb_put(skb, copy), + offset, copy, off, skb) < 0) { + __skb_trim(skb, off); + err = -EFAULT; + goto error; + } + } else if (!zc) { + int i = skb_shinfo(skb)->nr_frags; + + err = -ENOMEM; + if (!sk_page_frag_refill(sk, pfrag)) + goto error; + + skb_zcopy_downgrade_managed(skb); + if (!skb_can_coalesce(skb, i, pfrag->page, + pfrag->offset)) { + err = -EMSGSIZE; + if (i == MAX_SKB_FRAGS) + goto error; + + __skb_fill_page_desc(skb, i, pfrag->page, + pfrag->offset, 0); + skb_shinfo(skb)->nr_frags = ++i; + get_page(pfrag->page); + } + copy = min_t(int, copy, pfrag->size - pfrag->offset); + if (getfrag(from, + page_address(pfrag->page) + pfrag->offset, + offset, copy, skb->len, skb) < 0) + goto error_efault; + + pfrag->offset += copy; + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + skb_len_add(skb, copy); + wmem_alloc_delta += copy; + } else { + err = skb_zerocopy_iter_dgram(skb, from, copy); + if (err < 0) + goto error; + } + offset += copy; + length -= copy; + } + + if (wmem_alloc_delta) + refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); + return 0; + +error_efault: + err = -EFAULT; +error: + net_zcopy_put_abort(uarg, extra_uref); + cork->length -= length; + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS); + refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); + return err; +} + +static int ip_setup_cork(struct sock *sk, struct inet_cork *cork, + struct ipcm_cookie *ipc, struct rtable **rtp) +{ + struct ip_options_rcu *opt; + struct rtable *rt; + + rt = *rtp; + if (unlikely(!rt)) + return -EFAULT; + + /* + * setup for corking. + */ + opt = ipc->opt; + if (opt) { + if (!cork->opt) { + cork->opt = kmalloc(sizeof(struct ip_options) + 40, + sk->sk_allocation); + if (unlikely(!cork->opt)) + return -ENOBUFS; + } + memcpy(cork->opt, &opt->opt, sizeof(struct ip_options) + opt->opt.optlen); + cork->flags |= IPCORK_OPT; + cork->addr = ipc->addr; + } + + cork->fragsize = ip_sk_use_pmtu(sk) ? 
+ dst_mtu(&rt->dst) : READ_ONCE(rt->dst.dev->mtu); + + if (!inetdev_valid_mtu(cork->fragsize)) + return -ENETUNREACH; + + cork->gso_size = ipc->gso_size; + + cork->dst = &rt->dst; + /* We stole this route, caller should not release it. */ + *rtp = NULL; + + cork->length = 0; + cork->ttl = ipc->ttl; + cork->tos = ipc->tos; + cork->mark = ipc->sockc.mark; + cork->priority = ipc->priority; + cork->transmit_time = ipc->sockc.transmit_time; + cork->tx_flags = 0; + sock_tx_timestamp(sk, ipc->sockc.tsflags, &cork->tx_flags); + + return 0; +} + +/* + * ip_append_data() and ip_append_page() can make one large IP datagram + * from many pieces of data. Each pieces will be holded on the socket + * until ip_push_pending_frames() is called. Each piece can be a page + * or non-page data. + * + * Not only UDP, other transport protocols - e.g. raw sockets - can use + * this interface potentially. + * + * LATER: length must be adjusted by pad at tail, when it is required. + */ +int ip_append_data(struct sock *sk, struct flowi4 *fl4, + int getfrag(void *from, char *to, int offset, int len, + int odd, struct sk_buff *skb), + void *from, int length, int transhdrlen, + struct ipcm_cookie *ipc, struct rtable **rtp, + unsigned int flags) +{ + struct inet_sock *inet = inet_sk(sk); + int err; + + if (flags&MSG_PROBE) + return 0; + + if (skb_queue_empty(&sk->sk_write_queue)) { + err = ip_setup_cork(sk, &inet->cork.base, ipc, rtp); + if (err) + return err; + } else { + transhdrlen = 0; + } + + return __ip_append_data(sk, fl4, &sk->sk_write_queue, &inet->cork.base, + sk_page_frag(sk), getfrag, + from, length, transhdrlen, flags); +} + +ssize_t ip_append_page(struct sock *sk, struct flowi4 *fl4, struct page *page, + int offset, size_t size, int flags) +{ + struct inet_sock *inet = inet_sk(sk); + struct sk_buff *skb; + struct rtable *rt; + struct ip_options *opt = NULL; + struct inet_cork *cork; + int hh_len; + int mtu; + int len; + int err; + unsigned int maxfraglen, fragheaderlen, fraggap, maxnonfragsize; + + if (inet->hdrincl) + return -EPERM; + + if (flags&MSG_PROBE) + return 0; + + if (skb_queue_empty(&sk->sk_write_queue)) + return -EINVAL; + + cork = &inet->cork.base; + rt = (struct rtable *)cork->dst; + if (cork->flags & IPCORK_OPT) + opt = cork->opt; + + if (!(rt->dst.dev->features & NETIF_F_SG)) + return -EOPNOTSUPP; + + hh_len = LL_RESERVED_SPACE(rt->dst.dev); + mtu = cork->gso_size ? IP_MAX_MTU : cork->fragsize; + + fragheaderlen = sizeof(struct iphdr) + (opt ? opt->optlen : 0); + maxfraglen = ((mtu - fragheaderlen) & ~7) + fragheaderlen; + maxnonfragsize = ip_sk_ignore_df(sk) ? 0xFFFF : mtu; + + if (cork->length + size > maxnonfragsize - fragheaderlen) { + ip_local_error(sk, EMSGSIZE, fl4->daddr, inet->inet_dport, + mtu - (opt ? opt->optlen : 0)); + return -EMSGSIZE; + } + + skb = skb_peek_tail(&sk->sk_write_queue); + if (!skb) + return -EINVAL; + + cork->length += size; + + while (size > 0) { + /* Check if the remaining data fits into current packet. */ + len = mtu - skb->len; + if (len < size) + len = maxfraglen - skb->len; + + if (len <= 0) { + struct sk_buff *skb_prev; + int alloclen; + + skb_prev = skb; + fraggap = skb_prev->len - maxfraglen; + + alloclen = fragheaderlen + hh_len + fraggap + 15; + skb = sock_wmalloc(sk, alloclen, 1, sk->sk_allocation); + if (unlikely(!skb)) { + err = -ENOBUFS; + goto error; + } + + /* + * Fill in the control structures + */ + skb->ip_summed = CHECKSUM_NONE; + skb->csum = 0; + skb_reserve(skb, hh_len); + + /* + * Find where to start putting bytes. 
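+ *
+ * Usage sketch of the append/push API implemented in this file
+ * (hypothetical caller, locals assumed, error handling trimmed);
+ * a datagram protocol typically does:
+ *
+ *	err = ip_append_data(sk, fl4, ip_generic_getfrag, msg, len,
+ *			     transhdrlen, &ipc, &rt, msg->msg_flags);
+ *	if (err)
+ *		ip_flush_pending_frames(sk);
+ *	else if (!(msg->msg_flags & MSG_MORE))
+ *		err = ip_push_pending_frames(sk, fl4);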
+ */ + skb_put(skb, fragheaderlen + fraggap); + skb_reset_network_header(skb); + skb->transport_header = (skb->network_header + + fragheaderlen); + if (fraggap) { + skb->csum = skb_copy_and_csum_bits(skb_prev, + maxfraglen, + skb_transport_header(skb), + fraggap); + skb_prev->csum = csum_sub(skb_prev->csum, + skb->csum); + pskb_trim_unique(skb_prev, maxfraglen); + } + + /* + * Put the packet on the pending queue. + */ + __skb_queue_tail(&sk->sk_write_queue, skb); + continue; + } + + if (len > size) + len = size; + + if (skb_append_pagefrags(skb, page, offset, len)) { + err = -EMSGSIZE; + goto error; + } + + if (skb->ip_summed == CHECKSUM_NONE) { + __wsum csum; + csum = csum_page(page, offset, len); + skb->csum = csum_block_add(skb->csum, csum, skb->len); + } + + skb_len_add(skb, len); + refcount_add(len, &sk->sk_wmem_alloc); + offset += len; + size -= len; + } + return 0; + +error: + cork->length -= size; + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTDISCARDS); + return err; +} + +static void ip_cork_release(struct inet_cork *cork) +{ + cork->flags &= ~IPCORK_OPT; + kfree(cork->opt); + cork->opt = NULL; + dst_release(cork->dst); + cork->dst = NULL; +} + +/* + * Combined all pending IP fragments on the socket as one IP datagram + * and push them out. + */ +struct sk_buff *__ip_make_skb(struct sock *sk, + struct flowi4 *fl4, + struct sk_buff_head *queue, + struct inet_cork *cork) +{ + struct sk_buff *skb, *tmp_skb; + struct sk_buff **tail_skb; + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + struct ip_options *opt = NULL; + struct rtable *rt = (struct rtable *)cork->dst; + struct iphdr *iph; + __be16 df = 0; + __u8 ttl; + + skb = __skb_dequeue(queue); + if (!skb) + goto out; + tail_skb = &(skb_shinfo(skb)->frag_list); + + /* move skb->data to ip header from ext header */ + if (skb->data < skb_network_header(skb)) + __skb_pull(skb, skb_network_offset(skb)); + while ((tmp_skb = __skb_dequeue(queue)) != NULL) { + __skb_pull(tmp_skb, skb_network_header_len(skb)); + *tail_skb = tmp_skb; + tail_skb = &(tmp_skb->next); + skb->len += tmp_skb->len; + skb->data_len += tmp_skb->len; + skb->truesize += tmp_skb->truesize; + tmp_skb->destructor = NULL; + tmp_skb->sk = NULL; + } + + /* Unless user demanded real pmtu discovery (IP_PMTUDISC_DO), we allow + * to fragment the frame generated here. No matter, what transforms + * how transforms change size of the packet, it will come out. + */ + skb->ignore_df = ip_sk_ignore_df(sk); + + /* DF bit is set when we want to see DF on outgoing frames. + * If ignore_df is set too, we still allow to fragment this frame + * locally. */ + if (inet->pmtudisc == IP_PMTUDISC_DO || + inet->pmtudisc == IP_PMTUDISC_PROBE || + (skb->len <= dst_mtu(&rt->dst) && + ip_dont_fragment(sk, &rt->dst))) + df = htons(IP_DF); + + if (cork->flags & IPCORK_OPT) + opt = cork->opt; + + if (cork->ttl != 0) + ttl = cork->ttl; + else if (rt->rt_type == RTN_MULTICAST) + ttl = inet->mc_ttl; + else + ttl = ip_select_ttl(inet, &rt->dst); + + iph = ip_hdr(skb); + iph->version = 4; + iph->ihl = 5; + iph->tos = (cork->tos != -1) ? cork->tos : inet->tos; + iph->frag_off = df; + iph->ttl = ttl; + iph->protocol = sk->sk_protocol; + ip_copy_addrs(iph, fl4); + ip_select_ident(net, skb, sk); + + if (opt) { + iph->ihl += opt->optlen >> 2; + ip_options_build(skb, opt, cork->addr, rt); + } + + skb->priority = (cork->tos != -1) ? 
cork->priority: sk->sk_priority;
+	skb->mark = cork->mark;
+	skb->tstamp = cork->transmit_time;
+	/*
+	 * Steal rt from cork.dst to avoid a pair of atomic_inc/atomic_dec
+	 * on dst refcount
+	 */
+	cork->dst = NULL;
+	skb_dst_set(skb, &rt->dst);
+
+	if (iph->protocol == IPPROTO_ICMP) {
+		u8 icmp_type;
+
+		/* For such sockets, transhdrlen is zero when ip_append_data()
+		 * is called, so the icmphdr is not in the skb linear region
+		 * and the ICMP type cannot be read via icmp_hdr(skb)->type.
+		 */
+		if (sk->sk_type == SOCK_RAW && !inet_sk(sk)->hdrincl)
+			icmp_type = fl4->fl4_icmp_type;
+		else
+			icmp_type = icmp_hdr(skb)->type;
+		icmp_out_count(net, icmp_type);
+	}
+
+	ip_cork_release(cork);
+out:
+	return skb;
+}
+
+int ip_send_skb(struct net *net, struct sk_buff *skb)
+{
+	int err;
+
+	err = ip_local_out(net, skb->sk, skb);
+	if (err) {
+		if (err > 0)
+			err = net_xmit_errno(err);
+		if (err)
+			IP_INC_STATS(net, IPSTATS_MIB_OUTDISCARDS);
+	}
+
+	return err;
+}
+
+int ip_push_pending_frames(struct sock *sk, struct flowi4 *fl4)
+{
+	struct sk_buff *skb;
+
+	skb = ip_finish_skb(sk, fl4);
+	if (!skb)
+		return 0;
+
+	/* Netfilter gets the whole, not yet fragmented skb. */
+	return ip_send_skb(sock_net(sk), skb);
+}
+
+/*
+ *	Throw away all pending data on the socket.
+ */
+static void __ip_flush_pending_frames(struct sock *sk,
+				      struct sk_buff_head *queue,
+				      struct inet_cork *cork)
+{
+	struct sk_buff *skb;
+
+	while ((skb = __skb_dequeue_tail(queue)) != NULL)
+		kfree_skb(skb);
+
+	ip_cork_release(cork);
+}
+
+void ip_flush_pending_frames(struct sock *sk)
+{
+	__ip_flush_pending_frames(sk, &sk->sk_write_queue, &inet_sk(sk)->cork.base);
+}
+
+struct sk_buff *ip_make_skb(struct sock *sk,
+			    struct flowi4 *fl4,
+			    int getfrag(void *from, char *to, int offset,
+					int len, int odd, struct sk_buff *skb),
+			    void *from, int length, int transhdrlen,
+			    struct ipcm_cookie *ipc, struct rtable **rtp,
+			    struct inet_cork *cork, unsigned int flags)
+{
+	struct sk_buff_head queue;
+	int err;
+
+	if (flags & MSG_PROBE)
+		return NULL;
+
+	__skb_queue_head_init(&queue);
+
+	cork->flags = 0;
+	cork->addr = 0;
+	cork->opt = NULL;
+	err = ip_setup_cork(sk, cork, ipc, rtp);
+	if (err)
+		return ERR_PTR(err);
+
+	err = __ip_append_data(sk, fl4, &queue, cork,
+			       &current->task_frag, getfrag,
+			       from, length, transhdrlen, flags);
+	if (err) {
+		__ip_flush_pending_frames(sk, &queue, cork);
+		return ERR_PTR(err);
+	}
+
+	return __ip_make_skb(sk, fl4, &queue, cork);
+}
+
+/*
+ *	Fetch data from kernel space and fill in checksum if needed.
+ */
+static int ip_reply_glue_bits(void *dptr, char *to, int offset,
+			      int len, int odd, struct sk_buff *skb)
+{
+	__wsum csum;
+
+	csum = csum_partial_copy_nocheck(dptr+offset, to, len);
+	skb->csum = csum_block_add(skb->csum, csum, odd);
+	return 0;
+}
+
+/*
+ *	Generic function to send a packet as reply to another packet.
+ *	Used to send some TCP resets/acks so far.
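+ *
+ * Sketch of a typical caller, loosely modelled on the TCP reset path
+ * ("rep" and "ctl_sk" are assumed names, not from this file):
+ *
+ *	struct ip_reply_arg arg = { 0 };
+ *
+ *	arg.iov[0].iov_base = &rep;
+ *	arg.iov[0].iov_len  = sizeof(rep);
+ *	arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr,
+ *				      ip_hdr(skb)->saddr,
+ *				      arg.iov[0].iov_len, IPPROTO_TCP, 0);
+ *	arg.csumoffset = offsetof(struct tcphdr, check) / 2;
+ *	ip_send_unicast_reply(ctl_sk, skb, &IPCB(skb)->opt,
+ *			      ip_hdr(skb)->saddr, ip_hdr(skb)->daddr,
+ *			      &arg, arg.iov[0].iov_len, 0, 0);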
+ */ +void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb, + const struct ip_options *sopt, + __be32 daddr, __be32 saddr, + const struct ip_reply_arg *arg, + unsigned int len, u64 transmit_time, u32 txhash) +{ + struct ip_options_data replyopts; + struct ipcm_cookie ipc; + struct flowi4 fl4; + struct rtable *rt = skb_rtable(skb); + struct net *net = sock_net(sk); + struct sk_buff *nskb; + int err; + int oif; + + if (__ip_options_echo(net, &replyopts.opt.opt, skb, sopt)) + return; + + ipcm_init(&ipc); + ipc.addr = daddr; + ipc.sockc.transmit_time = transmit_time; + + if (replyopts.opt.opt.optlen) { + ipc.opt = &replyopts.opt; + + if (replyopts.opt.opt.srr) + daddr = replyopts.opt.opt.faddr; + } + + oif = arg->bound_dev_if; + if (!oif && netif_index_is_l3_master(net, skb->skb_iif)) + oif = skb->skb_iif; + + flowi4_init_output(&fl4, oif, + IP4_REPLY_MARK(net, skb->mark) ?: sk->sk_mark, + RT_TOS(arg->tos), + RT_SCOPE_UNIVERSE, ip_hdr(skb)->protocol, + ip_reply_arg_flowi_flags(arg), + daddr, saddr, + tcp_hdr(skb)->source, tcp_hdr(skb)->dest, + arg->uid); + security_skb_classify_flow(skb, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_flow(net, &fl4, sk); + if (IS_ERR(rt)) + return; + + inet_sk(sk)->tos = arg->tos & ~INET_ECN_MASK; + + sk->sk_protocol = ip_hdr(skb)->protocol; + sk->sk_bound_dev_if = arg->bound_dev_if; + sk->sk_sndbuf = READ_ONCE(sysctl_wmem_default); + ipc.sockc.mark = fl4.flowi4_mark; + err = ip_append_data(sk, &fl4, ip_reply_glue_bits, arg->iov->iov_base, + len, 0, &ipc, &rt, MSG_DONTWAIT); + if (unlikely(err)) { + ip_flush_pending_frames(sk); + goto out; + } + + nskb = skb_peek(&sk->sk_write_queue); + if (nskb) { + if (arg->csumoffset >= 0) + *((__sum16 *)skb_transport_header(nskb) + + arg->csumoffset) = csum_fold(csum_add(nskb->csum, + arg->csum)); + nskb->ip_summed = CHECKSUM_NONE; + nskb->mono_delivery_time = !!transmit_time; + if (txhash) + skb_set_hash(nskb, txhash, PKT_HASH_TYPE_L4); + ip_push_pending_frames(sk, &fl4); + } +out: + ip_rt_put(rt); +} + +void __init ip_init(void) +{ + ip_rt_init(); + inet_initpeers(); + +#if defined(CONFIG_IP_MULTICAST) + igmp_mc_init(); +#endif +} diff --git a/net/ipv4/ip_sockglue.c b/net/ipv4/ip_sockglue.c new file mode 100644 index 000000000..c1fb7580e --- /dev/null +++ b/net/ipv4/ip_sockglue.c @@ -0,0 +1,1832 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The IP to API glue. + * + * Authors: see ip.c + * + * Fixes: + * Many : Split from ip.c , see ip.c for history. + * Martin Mares : TOS setting fixed. + * Alan Cox : Fixed a couple of oopses in Martin's + * TOS tweaks. + * Mike McLagan : Routing by source + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif +#include + +#include +#include + +#include + +/* + * SOL_IP control messages. 
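+ *
+ * Illustration of the user-space side served by the helpers below
+ * (hypothetical snippet): a receiver that wants the destination address
+ * and interface of each datagram does
+ *
+ *	int on = 1;
+ *	struct cmsghdr *cmsg;
+ *	struct in_pktinfo *pi;
+ *
+ *	setsockopt(fd, IPPROTO_IP, IP_PKTINFO, &on, sizeof(on));
+ *	recvmsg(fd, &msg, 0);
+ *	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
+ *		if (cmsg->cmsg_level == IPPROTO_IP &&
+ *		    cmsg->cmsg_type == IP_PKTINFO)
+ *			pi = (struct in_pktinfo *)CMSG_DATA(cmsg);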
+ */ + +static void ip_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb) +{ + struct in_pktinfo info = *PKTINFO_SKB_CB(skb); + + info.ipi_addr.s_addr = ip_hdr(skb)->daddr; + + put_cmsg(msg, SOL_IP, IP_PKTINFO, sizeof(info), &info); +} + +static void ip_cmsg_recv_ttl(struct msghdr *msg, struct sk_buff *skb) +{ + int ttl = ip_hdr(skb)->ttl; + put_cmsg(msg, SOL_IP, IP_TTL, sizeof(int), &ttl); +} + +static void ip_cmsg_recv_tos(struct msghdr *msg, struct sk_buff *skb) +{ + put_cmsg(msg, SOL_IP, IP_TOS, 1, &ip_hdr(skb)->tos); +} + +static void ip_cmsg_recv_opts(struct msghdr *msg, struct sk_buff *skb) +{ + if (IPCB(skb)->opt.optlen == 0) + return; + + put_cmsg(msg, SOL_IP, IP_RECVOPTS, IPCB(skb)->opt.optlen, + ip_hdr(skb) + 1); +} + + +static void ip_cmsg_recv_retopts(struct net *net, struct msghdr *msg, + struct sk_buff *skb) +{ + unsigned char optbuf[sizeof(struct ip_options) + 40]; + struct ip_options *opt = (struct ip_options *)optbuf; + + if (IPCB(skb)->opt.optlen == 0) + return; + + if (ip_options_echo(net, opt, skb)) { + msg->msg_flags |= MSG_CTRUNC; + return; + } + ip_options_undo(opt); + + put_cmsg(msg, SOL_IP, IP_RETOPTS, opt->optlen, opt->__data); +} + +static void ip_cmsg_recv_fragsize(struct msghdr *msg, struct sk_buff *skb) +{ + int val; + + if (IPCB(skb)->frag_max_size == 0) + return; + + val = IPCB(skb)->frag_max_size; + put_cmsg(msg, SOL_IP, IP_RECVFRAGSIZE, sizeof(val), &val); +} + +static void ip_cmsg_recv_checksum(struct msghdr *msg, struct sk_buff *skb, + int tlen, int offset) +{ + __wsum csum = skb->csum; + + if (skb->ip_summed != CHECKSUM_COMPLETE) + return; + + if (offset != 0) { + int tend_off = skb_transport_offset(skb) + tlen; + csum = csum_sub(csum, skb_checksum(skb, tend_off, offset, 0)); + } + + put_cmsg(msg, SOL_IP, IP_CHECKSUM, sizeof(__wsum), &csum); +} + +static void ip_cmsg_recv_security(struct msghdr *msg, struct sk_buff *skb) +{ + char *secdata; + u32 seclen, secid; + int err; + + err = security_socket_getpeersec_dgram(NULL, skb, &secid); + if (err) + return; + + err = security_secid_to_secctx(secid, &secdata, &seclen); + if (err) + return; + + put_cmsg(msg, SOL_IP, SCM_SECURITY, seclen, secdata); + security_release_secctx(secdata, seclen); +} + +static void ip_cmsg_recv_dstaddr(struct msghdr *msg, struct sk_buff *skb) +{ + __be16 _ports[2], *ports; + struct sockaddr_in sin; + + /* All current transport protocols have the port numbers in the + * first four bytes of the transport header and this function is + * written with this assumption in mind. 
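+ *
+ * Example: for both UDP and TCP the transport header begins with the
+ * 16-bit source port followed by the 16-bit destination port, so
+ * _ports[0] is the source port and _ports[1], used below, is the
+ * destination port reported via the IP_ORIGDSTADDR ancillary message.
+ * A transparent proxy enables this with, e.g.,
+ *
+ *	setsockopt(fd, IPPROTO_IP, IP_RECVORIGDSTADDR, &on, sizeof(on));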
+ */ + ports = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_ports), &_ports); + if (!ports) + return; + + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = ip_hdr(skb)->daddr; + sin.sin_port = ports[1]; + memset(sin.sin_zero, 0, sizeof(sin.sin_zero)); + + put_cmsg(msg, SOL_IP, IP_ORIGDSTADDR, sizeof(sin), &sin); +} + +void ip_cmsg_recv_offset(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb, int tlen, int offset) +{ + struct inet_sock *inet = inet_sk(sk); + unsigned int flags = inet->cmsg_flags; + + /* Ordered by supposed usage frequency */ + if (flags & IP_CMSG_PKTINFO) { + ip_cmsg_recv_pktinfo(msg, skb); + + flags &= ~IP_CMSG_PKTINFO; + if (!flags) + return; + } + + if (flags & IP_CMSG_TTL) { + ip_cmsg_recv_ttl(msg, skb); + + flags &= ~IP_CMSG_TTL; + if (!flags) + return; + } + + if (flags & IP_CMSG_TOS) { + ip_cmsg_recv_tos(msg, skb); + + flags &= ~IP_CMSG_TOS; + if (!flags) + return; + } + + if (flags & IP_CMSG_RECVOPTS) { + ip_cmsg_recv_opts(msg, skb); + + flags &= ~IP_CMSG_RECVOPTS; + if (!flags) + return; + } + + if (flags & IP_CMSG_RETOPTS) { + ip_cmsg_recv_retopts(sock_net(sk), msg, skb); + + flags &= ~IP_CMSG_RETOPTS; + if (!flags) + return; + } + + if (flags & IP_CMSG_PASSSEC) { + ip_cmsg_recv_security(msg, skb); + + flags &= ~IP_CMSG_PASSSEC; + if (!flags) + return; + } + + if (flags & IP_CMSG_ORIGDSTADDR) { + ip_cmsg_recv_dstaddr(msg, skb); + + flags &= ~IP_CMSG_ORIGDSTADDR; + if (!flags) + return; + } + + if (flags & IP_CMSG_CHECKSUM) + ip_cmsg_recv_checksum(msg, skb, tlen, offset); + + if (flags & IP_CMSG_RECVFRAGSIZE) + ip_cmsg_recv_fragsize(msg, skb); +} +EXPORT_SYMBOL(ip_cmsg_recv_offset); + +int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc, + bool allow_ipv6) +{ + int err, val; + struct cmsghdr *cmsg; + struct net *net = sock_net(sk); + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; +#if IS_ENABLED(CONFIG_IPV6) + if (allow_ipv6 && + cmsg->cmsg_level == SOL_IPV6 && + cmsg->cmsg_type == IPV6_PKTINFO) { + struct in6_pktinfo *src_info; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(*src_info))) + return -EINVAL; + src_info = (struct in6_pktinfo *)CMSG_DATA(cmsg); + if (!ipv6_addr_v4mapped(&src_info->ipi6_addr)) + return -EINVAL; + if (src_info->ipi6_ifindex) + ipc->oif = src_info->ipi6_ifindex; + ipc->addr = src_info->ipi6_addr.s6_addr32[3]; + continue; + } +#endif + if (cmsg->cmsg_level == SOL_SOCKET) { + err = __sock_cmsg_send(sk, msg, cmsg, &ipc->sockc); + if (err) + return err; + continue; + } + + if (cmsg->cmsg_level != SOL_IP) + continue; + switch (cmsg->cmsg_type) { + case IP_RETOPTS: + err = cmsg->cmsg_len - sizeof(struct cmsghdr); + + /* Our caller is responsible for freeing ipc->opt */ + err = ip_options_get(net, &ipc->opt, + KERNEL_SOCKPTR(CMSG_DATA(cmsg)), + err < 40 ? 
err : 40); + if (err) + return err; + break; + case IP_PKTINFO: + { + struct in_pktinfo *info; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct in_pktinfo))) + return -EINVAL; + info = (struct in_pktinfo *)CMSG_DATA(cmsg); + if (info->ipi_ifindex) + ipc->oif = info->ipi_ifindex; + ipc->addr = info->ipi_spec_dst.s_addr; + break; + } + case IP_TTL: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + return -EINVAL; + val = *(int *)CMSG_DATA(cmsg); + if (val < 1 || val > 255) + return -EINVAL; + ipc->ttl = val; + break; + case IP_TOS: + if (cmsg->cmsg_len == CMSG_LEN(sizeof(int))) + val = *(int *)CMSG_DATA(cmsg); + else if (cmsg->cmsg_len == CMSG_LEN(sizeof(u8))) + val = *(u8 *)CMSG_DATA(cmsg); + else + return -EINVAL; + if (val < 0 || val > 255) + return -EINVAL; + ipc->tos = val; + ipc->priority = rt_tos2priority(ipc->tos); + break; + case IP_PROTOCOL: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + return -EINVAL; + val = *(int *)CMSG_DATA(cmsg); + if (val < 1 || val > 255) + return -EINVAL; + ipc->protocol = val; + break; + default: + return -EINVAL; + } + } + return 0; +} + +static void ip_ra_destroy_rcu(struct rcu_head *head) +{ + struct ip_ra_chain *ra = container_of(head, struct ip_ra_chain, rcu); + + sock_put(ra->saved_sk); + kfree(ra); +} + +int ip_ra_control(struct sock *sk, unsigned char on, + void (*destructor)(struct sock *)) +{ + struct ip_ra_chain *ra, *new_ra; + struct ip_ra_chain __rcu **rap; + struct net *net = sock_net(sk); + + if (sk->sk_type != SOCK_RAW || inet_sk(sk)->inet_num == IPPROTO_RAW) + return -EINVAL; + + new_ra = on ? kmalloc(sizeof(*new_ra), GFP_KERNEL) : NULL; + if (on && !new_ra) + return -ENOMEM; + + mutex_lock(&net->ipv4.ra_mutex); + for (rap = &net->ipv4.ra_chain; + (ra = rcu_dereference_protected(*rap, + lockdep_is_held(&net->ipv4.ra_mutex))) != NULL; + rap = &ra->next) { + if (ra->sk == sk) { + if (on) { + mutex_unlock(&net->ipv4.ra_mutex); + kfree(new_ra); + return -EADDRINUSE; + } + /* dont let ip_call_ra_chain() use sk again */ + ra->sk = NULL; + RCU_INIT_POINTER(*rap, ra->next); + mutex_unlock(&net->ipv4.ra_mutex); + + if (ra->destructor) + ra->destructor(sk); + /* + * Delay sock_put(sk) and kfree(ra) after one rcu grace + * period. This guarantee ip_call_ra_chain() dont need + * to mess with socket refcounts. 
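+ *
+ * Illustration of how this chain is populated from user space
+ * (hypothetical snippet): a protocol daemon such as an RSVP
+ * implementation opens a raw socket and enables Router Alert
+ * interception with
+ *
+ *	int on = 1;
+ *	int fd = socket(AF_INET, SOCK_RAW, IPPROTO_RSVP);
+ *
+ *	setsockopt(fd, IPPROTO_IP, IP_ROUTER_ALERT, &on, sizeof(on));
+ *
+ * after which packets carrying the IP Router Alert option are diverted
+ * to that socket via ip_call_ra_chain().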
+ */ + ra->saved_sk = sk; + call_rcu(&ra->rcu, ip_ra_destroy_rcu); + return 0; + } + } + if (!new_ra) { + mutex_unlock(&net->ipv4.ra_mutex); + return -ENOBUFS; + } + new_ra->sk = sk; + new_ra->destructor = destructor; + + RCU_INIT_POINTER(new_ra->next, ra); + rcu_assign_pointer(*rap, new_ra); + sock_hold(sk); + mutex_unlock(&net->ipv4.ra_mutex); + + return 0; +} + +static void ipv4_icmp_error_rfc4884(const struct sk_buff *skb, + struct sock_ee_data_rfc4884 *out) +{ + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + case ICMP_TIME_EXCEEDED: + case ICMP_PARAMETERPROB: + ip_icmp_error_rfc4884(skb, out, sizeof(struct icmphdr), + icmp_hdr(skb)->un.reserved[1] * 4); + } +} + +void ip_icmp_error(struct sock *sk, struct sk_buff *skb, int err, + __be16 port, u32 info, u8 *payload) +{ + struct sock_exterr_skb *serr; + + skb = skb_clone(skb, GFP_ATOMIC); + if (!skb) + return; + + serr = SKB_EXT_ERR(skb); + serr->ee.ee_errno = err; + serr->ee.ee_origin = SO_EE_ORIGIN_ICMP; + serr->ee.ee_type = icmp_hdr(skb)->type; + serr->ee.ee_code = icmp_hdr(skb)->code; + serr->ee.ee_pad = 0; + serr->ee.ee_info = info; + serr->ee.ee_data = 0; + serr->addr_offset = (u8 *)&(((struct iphdr *)(icmp_hdr(skb) + 1))->daddr) - + skb_network_header(skb); + serr->port = port; + + if (skb_pull(skb, payload - skb->data)) { + if (inet_sk(sk)->recverr_rfc4884) + ipv4_icmp_error_rfc4884(skb, &serr->ee.ee_rfc4884); + + skb_reset_transport_header(skb); + if (sock_queue_err_skb(sk, skb) == 0) + return; + } + kfree_skb(skb); +} + +void ip_local_error(struct sock *sk, int err, __be32 daddr, __be16 port, u32 info) +{ + struct inet_sock *inet = inet_sk(sk); + struct sock_exterr_skb *serr; + struct iphdr *iph; + struct sk_buff *skb; + + if (!inet->recverr) + return; + + skb = alloc_skb(sizeof(struct iphdr), GFP_ATOMIC); + if (!skb) + return; + + skb_put(skb, sizeof(struct iphdr)); + skb_reset_network_header(skb); + iph = ip_hdr(skb); + iph->daddr = daddr; + + serr = SKB_EXT_ERR(skb); + serr->ee.ee_errno = err; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_type = 0; + serr->ee.ee_code = 0; + serr->ee.ee_pad = 0; + serr->ee.ee_info = info; + serr->ee.ee_data = 0; + serr->addr_offset = (u8 *)&iph->daddr - skb_network_header(skb); + serr->port = port; + + __skb_pull(skb, skb_tail_pointer(skb) - skb->data); + skb_reset_transport_header(skb); + + if (sock_queue_err_skb(sk, skb)) + kfree_skb(skb); +} + +/* For some errors we have valid addr_offset even with zero payload and + * zero port. Also, addr_offset should be supported if port is set. + */ +static inline bool ipv4_datagram_support_addr(struct sock_exterr_skb *serr) +{ + return serr->ee.ee_origin == SO_EE_ORIGIN_ICMP || + serr->ee.ee_origin == SO_EE_ORIGIN_LOCAL || serr->port; +} + +/* IPv4 supports cmsg on all imcp errors and some timestamps + * + * Timestamp code paths do not initialize the fields expected by cmsg: + * the PKTINFO fields in skb->cb[]. Fill those in here. + */ +static bool ipv4_datagram_support_cmsg(const struct sock *sk, + struct sk_buff *skb, + int ee_origin) +{ + struct in_pktinfo *info; + + if (ee_origin == SO_EE_ORIGIN_ICMP) + return true; + + if (ee_origin == SO_EE_ORIGIN_LOCAL) + return false; + + /* Support IP_PKTINFO on tstamp packets if requested, to correlate + * timestamp with egress dev. Not possible for packets without iif + * or without payload (SOF_TIMESTAMPING_OPT_TSONLY). 
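+ *
+ * Illustration of the consumer side of this error queue (hypothetical
+ * snippet): an application that enabled IP_RECVERR drains queued
+ * errors with
+ *
+ *	recvmsg(fd, &msg, MSG_ERRQUEUE);
+ *	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
+ *		if (cmsg->cmsg_level == IPPROTO_IP &&
+ *		    cmsg->cmsg_type == IP_RECVERR)
+ *			ee = (struct sock_extended_err *)CMSG_DATA(cmsg);
+ *
+ * where ee->ee_origin distinguishes ICMP-originated errors
+ * (SO_EE_ORIGIN_ICMP) from locally generated ones (SO_EE_ORIGIN_LOCAL).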
+ */ + info = PKTINFO_SKB_CB(skb); + if (!(READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_CMSG) || + !info->ipi_ifindex) + return false; + + info->ipi_spec_dst.s_addr = ip_hdr(skb)->saddr; + return true; +} + +/* + * Handle MSG_ERRQUEUE + */ +int ip_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) +{ + struct sock_exterr_skb *serr; + struct sk_buff *skb; + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + struct { + struct sock_extended_err ee; + struct sockaddr_in offender; + } errhdr; + int err; + int copied; + + err = -EAGAIN; + skb = sock_dequeue_err_skb(sk); + if (!skb) + goto out; + + copied = skb->len; + if (copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + sock_recv_timestamp(msg, sk, skb); + + serr = SKB_EXT_ERR(skb); + + if (sin && ipv4_datagram_support_addr(serr)) { + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = *(__be32 *)(skb_network_header(skb) + + serr->addr_offset); + sin->sin_port = serr->port; + memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); + *addr_len = sizeof(*sin); + } + + memcpy(&errhdr.ee, &serr->ee, sizeof(struct sock_extended_err)); + sin = &errhdr.offender; + memset(sin, 0, sizeof(*sin)); + + if (ipv4_datagram_support_cmsg(sk, skb, serr->ee.ee_origin)) { + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + if (inet_sk(sk)->cmsg_flags) + ip_cmsg_recv(msg, skb); + } + + put_cmsg(msg, SOL_IP, IP_RECVERR, sizeof(errhdr), &errhdr); + + /* Now we could try to dump offended packet options */ + + msg->msg_flags |= MSG_ERRQUEUE; + err = copied; + + consume_skb(skb); +out: + return err; +} + +void __ip_sock_set_tos(struct sock *sk, int val) +{ + if (sk->sk_type == SOCK_STREAM) { + val &= ~INET_ECN_MASK; + val |= inet_sk(sk)->tos & INET_ECN_MASK; + } + if (inet_sk(sk)->tos != val) { + inet_sk(sk)->tos = val; + WRITE_ONCE(sk->sk_priority, rt_tos2priority(val)); + sk_dst_reset(sk); + } +} + +void ip_sock_set_tos(struct sock *sk, int val) +{ + lock_sock(sk); + __ip_sock_set_tos(sk, val); + release_sock(sk); +} +EXPORT_SYMBOL(ip_sock_set_tos); + +void ip_sock_set_freebind(struct sock *sk) +{ + lock_sock(sk); + inet_sk(sk)->freebind = true; + release_sock(sk); +} +EXPORT_SYMBOL(ip_sock_set_freebind); + +void ip_sock_set_recverr(struct sock *sk) +{ + lock_sock(sk); + inet_sk(sk)->recverr = true; + release_sock(sk); +} +EXPORT_SYMBOL(ip_sock_set_recverr); + +int ip_sock_set_mtu_discover(struct sock *sk, int val) +{ + if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT) + return -EINVAL; + lock_sock(sk); + inet_sk(sk)->pmtudisc = val; + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(ip_sock_set_mtu_discover); + +void ip_sock_set_pktinfo(struct sock *sk) +{ + lock_sock(sk); + inet_sk(sk)->cmsg_flags |= IP_CMSG_PKTINFO; + release_sock(sk); +} +EXPORT_SYMBOL(ip_sock_set_pktinfo); + +/* + * Socket option code for IP. This is the end of the line after any + * TCP,UDP etc options on an IP socket. 
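+ *
+ * Note on the ip_sock_set_*() helpers above: they let in-kernel socket
+ * users flip these options directly instead of going through
+ * setsockopt().  A hypothetical kernel client might do:
+ *
+ *	struct socket *sock;
+ *
+ *	sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
+ *	ip_sock_set_tos(sock->sk, tos);
+ *	ip_sock_set_recverr(sock->sk);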
+ */ +static bool setsockopt_needs_rtnl(int optname) +{ + switch (optname) { + case IP_ADD_MEMBERSHIP: + case IP_ADD_SOURCE_MEMBERSHIP: + case IP_BLOCK_SOURCE: + case IP_DROP_MEMBERSHIP: + case IP_DROP_SOURCE_MEMBERSHIP: + case IP_MSFILTER: + case IP_UNBLOCK_SOURCE: + case MCAST_BLOCK_SOURCE: + case MCAST_MSFILTER: + case MCAST_JOIN_GROUP: + case MCAST_JOIN_SOURCE_GROUP: + case MCAST_LEAVE_GROUP: + case MCAST_LEAVE_SOURCE_GROUP: + case MCAST_UNBLOCK_SOURCE: + return true; + } + return false; +} + +static int set_mcast_msfilter(struct sock *sk, int ifindex, + int numsrc, int fmode, + struct sockaddr_storage *group, + struct sockaddr_storage *list) +{ + struct ip_msfilter *msf; + struct sockaddr_in *psin; + int err, i; + + msf = kmalloc(IP_MSFILTER_SIZE(numsrc), GFP_KERNEL); + if (!msf) + return -ENOBUFS; + + psin = (struct sockaddr_in *)group; + if (psin->sin_family != AF_INET) + goto Eaddrnotavail; + msf->imsf_multiaddr = psin->sin_addr.s_addr; + msf->imsf_interface = 0; + msf->imsf_fmode = fmode; + msf->imsf_numsrc = numsrc; + for (i = 0; i < numsrc; ++i) { + psin = (struct sockaddr_in *)&list[i]; + + if (psin->sin_family != AF_INET) + goto Eaddrnotavail; + msf->imsf_slist_flex[i] = psin->sin_addr.s_addr; + } + err = ip_mc_msfilter(sk, msf, ifindex); + kfree(msf); + return err; + +Eaddrnotavail: + kfree(msf); + return -EADDRNOTAVAIL; +} + +static int copy_group_source_from_sockptr(struct group_source_req *greqs, + sockptr_t optval, int optlen) +{ + if (in_compat_syscall()) { + struct compat_group_source_req gr32; + + if (optlen != sizeof(gr32)) + return -EINVAL; + if (copy_from_sockptr(&gr32, optval, sizeof(gr32))) + return -EFAULT; + greqs->gsr_interface = gr32.gsr_interface; + greqs->gsr_group = gr32.gsr_group; + greqs->gsr_source = gr32.gsr_source; + } else { + if (optlen != sizeof(*greqs)) + return -EINVAL; + if (copy_from_sockptr(greqs, optval, sizeof(*greqs))) + return -EFAULT; + } + + return 0; +} + +static int do_mcast_group_source(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct group_source_req greqs; + struct ip_mreq_source mreqs; + struct sockaddr_in *psin; + int omode, add, err; + + err = copy_group_source_from_sockptr(&greqs, optval, optlen); + if (err) + return err; + + if (greqs.gsr_group.ss_family != AF_INET || + greqs.gsr_source.ss_family != AF_INET) + return -EADDRNOTAVAIL; + + psin = (struct sockaddr_in *)&greqs.gsr_group; + mreqs.imr_multiaddr = psin->sin_addr.s_addr; + psin = (struct sockaddr_in *)&greqs.gsr_source; + mreqs.imr_sourceaddr = psin->sin_addr.s_addr; + mreqs.imr_interface = 0; /* use index for mc_source */ + + if (optname == MCAST_BLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 1; + } else if (optname == MCAST_UNBLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 0; + } else if (optname == MCAST_JOIN_SOURCE_GROUP) { + struct ip_mreqn mreq; + + psin = (struct sockaddr_in *)&greqs.gsr_group; + mreq.imr_multiaddr = psin->sin_addr; + mreq.imr_address.s_addr = 0; + mreq.imr_ifindex = greqs.gsr_interface; + err = ip_mc_join_group_ssm(sk, &mreq, MCAST_INCLUDE); + if (err && err != -EADDRINUSE) + return err; + greqs.gsr_interface = mreq.imr_ifindex; + omode = MCAST_INCLUDE; + add = 1; + } else /* MCAST_LEAVE_SOURCE_GROUP */ { + omode = MCAST_INCLUDE; + add = 0; + } + return ip_mc_source(add, omode, sk, &mreqs, greqs.gsr_interface); +} + +static int ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, int optlen) +{ + struct group_filter *gsf = NULL; + int err; + + if (optlen < GROUP_FILTER_SIZE(0)) + return -EINVAL; + if (optlen > 
READ_ONCE(sysctl_optmem_max)) + return -ENOBUFS; + + gsf = memdup_sockptr(optval, optlen); + if (IS_ERR(gsf)) + return PTR_ERR(gsf); + + /* numsrc >= (4G-140)/128 overflow in 32 bits */ + err = -ENOBUFS; + if (gsf->gf_numsrc >= 0x1ffffff || + gsf->gf_numsrc > READ_ONCE(sock_net(sk)->ipv4.sysctl_igmp_max_msf)) + goto out_free_gsf; + + err = -EINVAL; + if (GROUP_FILTER_SIZE(gsf->gf_numsrc) > optlen) + goto out_free_gsf; + + err = set_mcast_msfilter(sk, gsf->gf_interface, gsf->gf_numsrc, + gsf->gf_fmode, &gsf->gf_group, + gsf->gf_slist_flex); +out_free_gsf: + kfree(gsf); + return err; +} + +static int compat_ip_set_mcast_msfilter(struct sock *sk, sockptr_t optval, + int optlen) +{ + const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); + struct compat_group_filter *gf32; + unsigned int n; + void *p; + int err; + + if (optlen < size0) + return -EINVAL; + if (optlen > READ_ONCE(sysctl_optmem_max) - 4) + return -ENOBUFS; + + p = kmalloc(optlen + 4, GFP_KERNEL); + if (!p) + return -ENOMEM; + gf32 = p + 4; /* we want ->gf_group and ->gf_slist_flex aligned */ + + err = -EFAULT; + if (copy_from_sockptr(gf32, optval, optlen)) + goto out_free_gsf; + + /* numsrc >= (4G-140)/128 overflow in 32 bits */ + n = gf32->gf_numsrc; + err = -ENOBUFS; + if (n >= 0x1ffffff) + goto out_free_gsf; + + err = -EINVAL; + if (offsetof(struct compat_group_filter, gf_slist_flex[n]) > optlen) + goto out_free_gsf; + + /* numsrc >= (4G-140)/128 overflow in 32 bits */ + err = -ENOBUFS; + if (n > READ_ONCE(sock_net(sk)->ipv4.sysctl_igmp_max_msf)) + goto out_free_gsf; + err = set_mcast_msfilter(sk, gf32->gf_interface, n, gf32->gf_fmode, + &gf32->gf_group, gf32->gf_slist_flex); +out_free_gsf: + kfree(p); + return err; +} + +static int ip_mcast_join_leave(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct ip_mreqn mreq = { }; + struct sockaddr_in *psin; + struct group_req greq; + + if (optlen < sizeof(struct group_req)) + return -EINVAL; + if (copy_from_sockptr(&greq, optval, sizeof(greq))) + return -EFAULT; + + psin = (struct sockaddr_in *)&greq.gr_group; + if (psin->sin_family != AF_INET) + return -EINVAL; + mreq.imr_multiaddr = psin->sin_addr; + mreq.imr_ifindex = greq.gr_interface; + if (optname == MCAST_JOIN_GROUP) + return ip_mc_join_group(sk, &mreq); + return ip_mc_leave_group(sk, &mreq); +} + +static int compat_ip_mcast_join_leave(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct compat_group_req greq; + struct ip_mreqn mreq = { }; + struct sockaddr_in *psin; + + if (optlen < sizeof(struct compat_group_req)) + return -EINVAL; + if (copy_from_sockptr(&greq, optval, sizeof(greq))) + return -EFAULT; + + psin = (struct sockaddr_in *)&greq.gr_group; + if (psin->sin_family != AF_INET) + return -EINVAL; + mreq.imr_multiaddr = psin->sin_addr; + mreq.imr_ifindex = greq.gr_interface; + + if (optname == MCAST_JOIN_GROUP) + return ip_mc_join_group(sk, &mreq); + return ip_mc_leave_group(sk, &mreq); +} + +DEFINE_STATIC_KEY_FALSE(ip4_min_ttl); + +int do_ip_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + int val = 0, err; + bool needs_rtnl = setsockopt_needs_rtnl(optname); + + switch (optname) { + case IP_PKTINFO: + case IP_RECVTTL: + case IP_RECVOPTS: + case IP_RECVTOS: + case IP_RETOPTS: + case IP_TOS: + case IP_TTL: + case IP_HDRINCL: + case IP_MTU_DISCOVER: + case IP_RECVERR: + case IP_ROUTER_ALERT: + case IP_FREEBIND: + case IP_PASSSEC: + case 
IP_TRANSPARENT: + case IP_MINTTL: + case IP_NODEFRAG: + case IP_BIND_ADDRESS_NO_PORT: + case IP_UNICAST_IF: + case IP_MULTICAST_TTL: + case IP_MULTICAST_ALL: + case IP_MULTICAST_LOOP: + case IP_RECVORIGDSTADDR: + case IP_CHECKSUM: + case IP_RECVFRAGSIZE: + case IP_RECVERR_RFC4884: + case IP_LOCAL_PORT_RANGE: + if (optlen >= sizeof(int)) { + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + } else if (optlen >= sizeof(char)) { + unsigned char ucval; + + if (copy_from_sockptr(&ucval, optval, sizeof(ucval))) + return -EFAULT; + val = (int) ucval; + } + } + + /* If optlen==0, it is equivalent to val == 0 */ + + if (optname == IP_ROUTER_ALERT) + return ip_ra_control(sk, val ? 1 : 0, NULL); + if (ip_mroute_opt(optname)) + return ip_mroute_setsockopt(sk, optname, optval, optlen); + + err = 0; + if (needs_rtnl) + rtnl_lock(); + sockopt_lock_sock(sk); + + switch (optname) { + case IP_OPTIONS: + { + struct ip_options_rcu *old, *opt = NULL; + + if (optlen > 40) + goto e_inval; + err = ip_options_get(sock_net(sk), &opt, optval, optlen); + if (err) + break; + old = rcu_dereference_protected(inet->inet_opt, + lockdep_sock_is_held(sk)); + if (inet->is_icsk) { + struct inet_connection_sock *icsk = inet_csk(sk); +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == PF_INET || + (!((1 << sk->sk_state) & + (TCPF_LISTEN | TCPF_CLOSE)) && + inet->inet_daddr != LOOPBACK4_IPV6)) { +#endif + if (old) + icsk->icsk_ext_hdr_len -= old->opt.optlen; + if (opt) + icsk->icsk_ext_hdr_len += opt->opt.optlen; + icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); +#if IS_ENABLED(CONFIG_IPV6) + } +#endif + } + rcu_assign_pointer(inet->inet_opt, opt); + if (old) + kfree_rcu(old, rcu); + break; + } + case IP_PKTINFO: + if (val) + inet->cmsg_flags |= IP_CMSG_PKTINFO; + else + inet->cmsg_flags &= ~IP_CMSG_PKTINFO; + break; + case IP_RECVTTL: + if (val) + inet->cmsg_flags |= IP_CMSG_TTL; + else + inet->cmsg_flags &= ~IP_CMSG_TTL; + break; + case IP_RECVTOS: + if (val) + inet->cmsg_flags |= IP_CMSG_TOS; + else + inet->cmsg_flags &= ~IP_CMSG_TOS; + break; + case IP_RECVOPTS: + if (val) + inet->cmsg_flags |= IP_CMSG_RECVOPTS; + else + inet->cmsg_flags &= ~IP_CMSG_RECVOPTS; + break; + case IP_RETOPTS: + if (val) + inet->cmsg_flags |= IP_CMSG_RETOPTS; + else + inet->cmsg_flags &= ~IP_CMSG_RETOPTS; + break; + case IP_PASSSEC: + if (val) + inet->cmsg_flags |= IP_CMSG_PASSSEC; + else + inet->cmsg_flags &= ~IP_CMSG_PASSSEC; + break; + case IP_RECVORIGDSTADDR: + if (val) + inet->cmsg_flags |= IP_CMSG_ORIGDSTADDR; + else + inet->cmsg_flags &= ~IP_CMSG_ORIGDSTADDR; + break; + case IP_CHECKSUM: + if (val) { + if (!(inet->cmsg_flags & IP_CMSG_CHECKSUM)) { + inet_inc_convert_csum(sk); + inet->cmsg_flags |= IP_CMSG_CHECKSUM; + } + } else { + if (inet->cmsg_flags & IP_CMSG_CHECKSUM) { + inet_dec_convert_csum(sk); + inet->cmsg_flags &= ~IP_CMSG_CHECKSUM; + } + } + break; + case IP_RECVFRAGSIZE: + if (sk->sk_type != SOCK_RAW && sk->sk_type != SOCK_DGRAM) + goto e_inval; + if (val) + inet->cmsg_flags |= IP_CMSG_RECVFRAGSIZE; + else + inet->cmsg_flags &= ~IP_CMSG_RECVFRAGSIZE; + break; + case IP_TOS: /* This sets both TOS and Precedence */ + __ip_sock_set_tos(sk, val); + break; + case IP_TTL: + if (optlen < 1) + goto e_inval; + if (val != -1 && (val < 1 || val > 255)) + goto e_inval; + inet->uc_ttl = val; + break; + case IP_HDRINCL: + if (sk->sk_type != SOCK_RAW) { + err = -ENOPROTOOPT; + break; + } + inet->hdrincl = val ? 
1 : 0; + break; + case IP_NODEFRAG: + if (sk->sk_type != SOCK_RAW) { + err = -ENOPROTOOPT; + break; + } + inet->nodefrag = val ? 1 : 0; + break; + case IP_BIND_ADDRESS_NO_PORT: + inet->bind_address_no_port = val ? 1 : 0; + break; + case IP_MTU_DISCOVER: + if (val < IP_PMTUDISC_DONT || val > IP_PMTUDISC_OMIT) + goto e_inval; + inet->pmtudisc = val; + break; + case IP_RECVERR: + inet->recverr = !!val; + if (!val) + skb_queue_purge(&sk->sk_error_queue); + break; + case IP_RECVERR_RFC4884: + if (val < 0 || val > 1) + goto e_inval; + inet->recverr_rfc4884 = !!val; + break; + case IP_MULTICAST_TTL: + if (sk->sk_type == SOCK_STREAM) + goto e_inval; + if (optlen < 1) + goto e_inval; + if (val == -1) + val = 1; + if (val < 0 || val > 255) + goto e_inval; + inet->mc_ttl = val; + break; + case IP_MULTICAST_LOOP: + if (optlen < 1) + goto e_inval; + inet->mc_loop = !!val; + break; + case IP_UNICAST_IF: + { + struct net_device *dev = NULL; + int ifindex; + int midx; + + if (optlen != sizeof(int)) + goto e_inval; + + ifindex = (__force int)ntohl((__force __be32)val); + if (ifindex == 0) { + inet->uc_index = 0; + err = 0; + break; + } + + dev = dev_get_by_index(sock_net(sk), ifindex); + err = -EADDRNOTAVAIL; + if (!dev) + break; + + midx = l3mdev_master_ifindex(dev); + dev_put(dev); + + err = -EINVAL; + if (sk->sk_bound_dev_if && midx != sk->sk_bound_dev_if) + break; + + inet->uc_index = ifindex; + err = 0; + break; + } + case IP_MULTICAST_IF: + { + struct ip_mreqn mreq; + struct net_device *dev = NULL; + int midx; + + if (sk->sk_type == SOCK_STREAM) + goto e_inval; + /* + * Check the arguments are allowable + */ + + if (optlen < sizeof(struct in_addr)) + goto e_inval; + + err = -EFAULT; + if (optlen >= sizeof(struct ip_mreqn)) { + if (copy_from_sockptr(&mreq, optval, sizeof(mreq))) + break; + } else { + memset(&mreq, 0, sizeof(mreq)); + if (optlen >= sizeof(struct ip_mreq)) { + if (copy_from_sockptr(&mreq, optval, + sizeof(struct ip_mreq))) + break; + } else if (optlen >= sizeof(struct in_addr)) { + if (copy_from_sockptr(&mreq.imr_address, optval, + sizeof(struct in_addr))) + break; + } + } + + if (!mreq.imr_ifindex) { + if (mreq.imr_address.s_addr == htonl(INADDR_ANY)) { + inet->mc_index = 0; + inet->mc_addr = 0; + err = 0; + break; + } + dev = ip_dev_find(sock_net(sk), mreq.imr_address.s_addr); + if (dev) + mreq.imr_ifindex = dev->ifindex; + } else + dev = dev_get_by_index(sock_net(sk), mreq.imr_ifindex); + + + err = -EADDRNOTAVAIL; + if (!dev) + break; + + midx = l3mdev_master_ifindex(dev); + + dev_put(dev); + + err = -EINVAL; + if (sk->sk_bound_dev_if && + mreq.imr_ifindex != sk->sk_bound_dev_if && + midx != sk->sk_bound_dev_if) + break; + + inet->mc_index = mreq.imr_ifindex; + inet->mc_addr = mreq.imr_address.s_addr; + err = 0; + break; + } + + case IP_ADD_MEMBERSHIP: + case IP_DROP_MEMBERSHIP: + { + struct ip_mreqn mreq; + + err = -EPROTO; + if (inet_sk(sk)->is_icsk) + break; + + if (optlen < sizeof(struct ip_mreq)) + goto e_inval; + err = -EFAULT; + if (optlen >= sizeof(struct ip_mreqn)) { + if (copy_from_sockptr(&mreq, optval, sizeof(mreq))) + break; + } else { + memset(&mreq, 0, sizeof(mreq)); + if (copy_from_sockptr(&mreq, optval, + sizeof(struct ip_mreq))) + break; + } + + if (optname == IP_ADD_MEMBERSHIP) + err = ip_mc_join_group(sk, &mreq); + else + err = ip_mc_leave_group(sk, &mreq); + break; + } + case IP_MSFILTER: + { + struct ip_msfilter *msf; + + if (optlen < IP_MSFILTER_SIZE(0)) + goto e_inval; + if (optlen > READ_ONCE(sysctl_optmem_max)) { + err = -ENOBUFS; + break; + } + msf = 
memdup_sockptr(optval, optlen); + if (IS_ERR(msf)) { + err = PTR_ERR(msf); + break; + } + /* numsrc >= (1G-4) overflow in 32 bits */ + if (msf->imsf_numsrc >= 0x3ffffffcU || + msf->imsf_numsrc > READ_ONCE(net->ipv4.sysctl_igmp_max_msf)) { + kfree(msf); + err = -ENOBUFS; + break; + } + if (IP_MSFILTER_SIZE(msf->imsf_numsrc) > optlen) { + kfree(msf); + err = -EINVAL; + break; + } + err = ip_mc_msfilter(sk, msf, 0); + kfree(msf); + break; + } + case IP_BLOCK_SOURCE: + case IP_UNBLOCK_SOURCE: + case IP_ADD_SOURCE_MEMBERSHIP: + case IP_DROP_SOURCE_MEMBERSHIP: + { + struct ip_mreq_source mreqs; + int omode, add; + + if (optlen != sizeof(struct ip_mreq_source)) + goto e_inval; + if (copy_from_sockptr(&mreqs, optval, sizeof(mreqs))) { + err = -EFAULT; + break; + } + if (optname == IP_BLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 1; + } else if (optname == IP_UNBLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 0; + } else if (optname == IP_ADD_SOURCE_MEMBERSHIP) { + struct ip_mreqn mreq; + + mreq.imr_multiaddr.s_addr = mreqs.imr_multiaddr; + mreq.imr_address.s_addr = mreqs.imr_interface; + mreq.imr_ifindex = 0; + err = ip_mc_join_group_ssm(sk, &mreq, MCAST_INCLUDE); + if (err && err != -EADDRINUSE) + break; + omode = MCAST_INCLUDE; + add = 1; + } else /* IP_DROP_SOURCE_MEMBERSHIP */ { + omode = MCAST_INCLUDE; + add = 0; + } + err = ip_mc_source(add, omode, sk, &mreqs, 0); + break; + } + case MCAST_JOIN_GROUP: + case MCAST_LEAVE_GROUP: + if (in_compat_syscall()) + err = compat_ip_mcast_join_leave(sk, optname, optval, + optlen); + else + err = ip_mcast_join_leave(sk, optname, optval, optlen); + break; + case MCAST_JOIN_SOURCE_GROUP: + case MCAST_LEAVE_SOURCE_GROUP: + case MCAST_BLOCK_SOURCE: + case MCAST_UNBLOCK_SOURCE: + err = do_mcast_group_source(sk, optname, optval, optlen); + break; + case MCAST_MSFILTER: + if (in_compat_syscall()) + err = compat_ip_set_mcast_msfilter(sk, optval, optlen); + else + err = ip_set_mcast_msfilter(sk, optval, optlen); + break; + case IP_MULTICAST_ALL: + if (optlen < 1) + goto e_inval; + if (val != 0 && val != 1) + goto e_inval; + inet->mc_all = val; + break; + + case IP_FREEBIND: + if (optlen < 1) + goto e_inval; + inet->freebind = !!val; + break; + + case IP_IPSEC_POLICY: + case IP_XFRM_POLICY: + err = -EPERM; + if (!sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + break; + err = xfrm_user_policy(sk, optname, optval, optlen); + break; + + case IP_TRANSPARENT: + if (!!val && !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_RAW) && + !sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) { + err = -EPERM; + break; + } + if (optlen < 1) + goto e_inval; + inet->transparent = !!val; + break; + + case IP_MINTTL: + if (optlen < 1) + goto e_inval; + if (val < 0 || val > 255) + goto e_inval; + + if (val) + static_branch_enable(&ip4_min_ttl); + + /* tcp_v4_err() and tcp_v4_rcv() might read min_ttl + * while we are changint it. 
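+ *
+ * Example of the intended use of IP_MINTTL (hypothetical snippet): a
+ * BGP daemon implementing GTSM (RFC 5082) accepts only directly
+ * connected peers by requiring an incoming TTL of 255:
+ *
+ *	int minttl = 255;
+ *
+ *	setsockopt(fd, IPPROTO_IP, IP_MINTTL, &minttl, sizeof(minttl));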
+ */ + WRITE_ONCE(inet->min_ttl, val); + break; + + case IP_LOCAL_PORT_RANGE: + { + const __u16 lo = val; + const __u16 hi = val >> 16; + + if (optlen != sizeof(__u32)) + goto e_inval; + if (lo != 0 && hi != 0 && lo > hi) + goto e_inval; + + inet->local_port_range.lo = lo; + inet->local_port_range.hi = hi; + break; + } + default: + err = -ENOPROTOOPT; + break; + } + sockopt_release_sock(sk); + if (needs_rtnl) + rtnl_unlock(); + return err; + +e_inval: + sockopt_release_sock(sk); + if (needs_rtnl) + rtnl_unlock(); + return -EINVAL; +} + +/** + * ipv4_pktinfo_prepare - transfer some info from rtable to skb + * @sk: socket + * @skb: buffer + * + * To support IP_CMSG_PKTINFO option, we store rt_iif and specific + * destination in skb->cb[] before dst drop. + * This way, receiver doesn't make cache line misses to read rtable. + */ +void ipv4_pktinfo_prepare(const struct sock *sk, struct sk_buff *skb) +{ + struct in_pktinfo *pktinfo = PKTINFO_SKB_CB(skb); + bool prepare = (inet_sk(sk)->cmsg_flags & IP_CMSG_PKTINFO) || + ipv6_sk_rxinfo(sk); + + if (prepare && skb_rtable(skb)) { + /* skb->cb is overloaded: prior to this point it is IP{6}CB + * which has interface index (iif) as the first member of the + * underlying inet{6}_skb_parm struct. This code then overlays + * PKTINFO_SKB_CB and in_pktinfo also has iif as the first + * element so the iif is picked up from the prior IPCB. If iif + * is the loopback interface, then return the sending interface + * (e.g., process binds socket to eth0 for Tx which is + * redirected to loopback in the rtable/dst). + */ + struct rtable *rt = skb_rtable(skb); + bool l3slave = ipv4_l3mdev_skb(IPCB(skb)->flags); + + if (pktinfo->ipi_ifindex == LOOPBACK_IFINDEX) + pktinfo->ipi_ifindex = inet_iif(skb); + else if (l3slave && rt && rt->rt_iif) + pktinfo->ipi_ifindex = rt->rt_iif; + + pktinfo->ipi_spec_dst.s_addr = fib_compute_spec_dst(skb); + } else { + pktinfo->ipi_ifindex = 0; + pktinfo->ipi_spec_dst.s_addr = 0; + } + skb_dst_drop(skb); +} + +int ip_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + int err; + + if (level != SOL_IP) + return -ENOPROTOOPT; + + err = do_ip_setsockopt(sk, level, optname, optval, optlen); +#if IS_ENABLED(CONFIG_BPFILTER_UMH) + if (optname >= BPFILTER_IPT_SO_SET_REPLACE && + optname < BPFILTER_IPT_SET_MAX) + err = bpfilter_ip_set_sockopt(sk, optname, optval, optlen); +#endif +#ifdef CONFIG_NETFILTER + /* we need to exclude all possible ENOPROTOOPTs except default case */ + if (err == -ENOPROTOOPT && optname != IP_HDRINCL && + optname != IP_IPSEC_POLICY && + optname != IP_XFRM_POLICY && + !ip_mroute_opt(optname)) + err = nf_setsockopt(sk, PF_INET, optname, optval, optlen); +#endif + return err; +} +EXPORT_SYMBOL(ip_setsockopt); + +/* + * Get the options. Note for future reference. The GET of IP options gets + * the _received_ ones. The set sets the _sent_ ones. 
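+ *
+ * Example for the IP_LOCAL_PORT_RANGE option handled above
+ * (hypothetical snippet): the 32-bit value packs the lower bound into
+ * the low 16 bits and the upper bound into the high 16 bits, so
+ * restricting a socket to local ports 10000-20000 looks like
+ *
+ *	__u32 range = 10000 | (20000U << 16);
+ *
+ *	setsockopt(fd, IPPROTO_IP, IP_LOCAL_PORT_RANGE, &range, sizeof(range));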
+ */ + +static bool getsockopt_needs_rtnl(int optname) +{ + switch (optname) { + case IP_MSFILTER: + case MCAST_MSFILTER: + return true; + } + return false; +} + +static int ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) +{ + const int size0 = offsetof(struct group_filter, gf_slist_flex); + struct group_filter gsf; + int num, gsf_size; + int err; + + if (len < size0) + return -EINVAL; + if (copy_from_sockptr(&gsf, optval, size0)) + return -EFAULT; + + num = gsf.gf_numsrc; + err = ip_mc_gsfget(sk, &gsf, optval, + offsetof(struct group_filter, gf_slist_flex)); + if (err) + return err; + if (gsf.gf_numsrc < num) + num = gsf.gf_numsrc; + gsf_size = GROUP_FILTER_SIZE(num); + if (copy_to_sockptr(optlen, &gsf_size, sizeof(int)) || + copy_to_sockptr(optval, &gsf, size0)) + return -EFAULT; + return 0; +} + +static int compat_ip_get_mcast_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) +{ + const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); + struct compat_group_filter gf32; + struct group_filter gf; + int num; + int err; + + if (len < size0) + return -EINVAL; + if (copy_from_sockptr(&gf32, optval, size0)) + return -EFAULT; + + gf.gf_interface = gf32.gf_interface; + gf.gf_fmode = gf32.gf_fmode; + num = gf.gf_numsrc = gf32.gf_numsrc; + gf.gf_group = gf32.gf_group; + + err = ip_mc_gsfget(sk, &gf, optval, + offsetof(struct compat_group_filter, gf_slist_flex)); + if (err) + return err; + if (gf.gf_numsrc < num) + num = gf.gf_numsrc; + len = GROUP_FILTER_SIZE(num) - (sizeof(gf) - sizeof(gf32)); + if (copy_to_sockptr(optlen, &len, sizeof(int)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode), + &gf.gf_fmode, sizeof(gf.gf_fmode)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc), + &gf.gf_numsrc, sizeof(gf.gf_numsrc))) + return -EFAULT; + return 0; +} + +int do_ip_getsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, sockptr_t optlen) +{ + struct inet_sock *inet = inet_sk(sk); + bool needs_rtnl = getsockopt_needs_rtnl(optname); + int val, err = 0; + int len; + + if (level != SOL_IP) + return -EOPNOTSUPP; + + if (ip_mroute_opt(optname)) + return ip_mroute_getsockopt(sk, optname, optval, optlen); + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + if (len < 0) + return -EINVAL; + + if (needs_rtnl) + rtnl_lock(); + sockopt_lock_sock(sk); + + switch (optname) { + case IP_OPTIONS: + { + unsigned char optbuf[sizeof(struct ip_options)+40]; + struct ip_options *opt = (struct ip_options *)optbuf; + struct ip_options_rcu *inet_opt; + + inet_opt = rcu_dereference_protected(inet->inet_opt, + lockdep_sock_is_held(sk)); + opt->optlen = 0; + if (inet_opt) + memcpy(optbuf, &inet_opt->opt, + sizeof(struct ip_options) + + inet_opt->opt.optlen); + sockopt_release_sock(sk); + + if (opt->optlen == 0) { + len = 0; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } + + ip_options_undo(opt); + + len = min_t(unsigned int, len, opt->optlen); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, opt->__data, len)) + return -EFAULT; + return 0; + } + case IP_PKTINFO: + val = (inet->cmsg_flags & IP_CMSG_PKTINFO) != 0; + break; + case IP_RECVTTL: + val = (inet->cmsg_flags & IP_CMSG_TTL) != 0; + break; + case IP_RECVTOS: + val = (inet->cmsg_flags & IP_CMSG_TOS) != 0; + break; + case IP_RECVOPTS: + val = (inet->cmsg_flags & IP_CMSG_RECVOPTS) != 0; + break; + case IP_RETOPTS: + val = (inet->cmsg_flags 
& IP_CMSG_RETOPTS) != 0; + break; + case IP_PASSSEC: + val = (inet->cmsg_flags & IP_CMSG_PASSSEC) != 0; + break; + case IP_RECVORIGDSTADDR: + val = (inet->cmsg_flags & IP_CMSG_ORIGDSTADDR) != 0; + break; + case IP_CHECKSUM: + val = (inet->cmsg_flags & IP_CMSG_CHECKSUM) != 0; + break; + case IP_RECVFRAGSIZE: + val = (inet->cmsg_flags & IP_CMSG_RECVFRAGSIZE) != 0; + break; + case IP_TOS: + val = inet->tos; + break; + case IP_TTL: + { + struct net *net = sock_net(sk); + val = (inet->uc_ttl == -1 ? + READ_ONCE(net->ipv4.sysctl_ip_default_ttl) : + inet->uc_ttl); + break; + } + case IP_HDRINCL: + val = inet->hdrincl; + break; + case IP_NODEFRAG: + val = inet->nodefrag; + break; + case IP_BIND_ADDRESS_NO_PORT: + val = inet->bind_address_no_port; + break; + case IP_MTU_DISCOVER: + val = inet->pmtudisc; + break; + case IP_MTU: + { + struct dst_entry *dst; + val = 0; + dst = sk_dst_get(sk); + if (dst) { + val = dst_mtu(dst); + dst_release(dst); + } + if (!val) { + sockopt_release_sock(sk); + return -ENOTCONN; + } + break; + } + case IP_RECVERR: + val = inet->recverr; + break; + case IP_RECVERR_RFC4884: + val = inet->recverr_rfc4884; + break; + case IP_MULTICAST_TTL: + val = inet->mc_ttl; + break; + case IP_MULTICAST_LOOP: + val = inet->mc_loop; + break; + case IP_UNICAST_IF: + val = (__force int)htonl((__u32) inet->uc_index); + break; + case IP_MULTICAST_IF: + { + struct in_addr addr; + len = min_t(unsigned int, len, sizeof(struct in_addr)); + addr.s_addr = inet->mc_addr; + sockopt_release_sock(sk); + + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &addr, len)) + return -EFAULT; + return 0; + } + case IP_MSFILTER: + { + struct ip_msfilter msf; + + if (len < IP_MSFILTER_SIZE(0)) { + err = -EINVAL; + goto out; + } + if (copy_from_sockptr(&msf, optval, IP_MSFILTER_SIZE(0))) { + err = -EFAULT; + goto out; + } + err = ip_mc_msfget(sk, &msf, optval, optlen); + goto out; + } + case MCAST_MSFILTER: + if (in_compat_syscall()) + err = compat_ip_get_mcast_msfilter(sk, optval, optlen, + len); + else + err = ip_get_mcast_msfilter(sk, optval, optlen, len); + goto out; + case IP_MULTICAST_ALL: + val = inet->mc_all; + break; + case IP_PKTOPTIONS: + { + struct msghdr msg; + + sockopt_release_sock(sk); + + if (sk->sk_type != SOCK_STREAM) + return -ENOPROTOOPT; + + if (optval.is_kernel) { + msg.msg_control_is_user = false; + msg.msg_control = optval.kernel; + } else { + msg.msg_control_is_user = true; + msg.msg_control_user = optval.user; + } + msg.msg_controllen = len; + msg.msg_flags = in_compat_syscall() ? 
MSG_CMSG_COMPAT : 0; + + if (inet->cmsg_flags & IP_CMSG_PKTINFO) { + struct in_pktinfo info; + + info.ipi_addr.s_addr = inet->inet_rcv_saddr; + info.ipi_spec_dst.s_addr = inet->inet_rcv_saddr; + info.ipi_ifindex = inet->mc_index; + put_cmsg(&msg, SOL_IP, IP_PKTINFO, sizeof(info), &info); + } + if (inet->cmsg_flags & IP_CMSG_TTL) { + int hlim = inet->mc_ttl; + put_cmsg(&msg, SOL_IP, IP_TTL, sizeof(hlim), &hlim); + } + if (inet->cmsg_flags & IP_CMSG_TOS) { + int tos = inet->rcv_tos; + put_cmsg(&msg, SOL_IP, IP_TOS, sizeof(tos), &tos); + } + len -= msg.msg_controllen; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } + case IP_FREEBIND: + val = inet->freebind; + break; + case IP_TRANSPARENT: + val = inet->transparent; + break; + case IP_MINTTL: + val = inet->min_ttl; + break; + case IP_LOCAL_PORT_RANGE: + val = inet->local_port_range.hi << 16 | inet->local_port_range.lo; + break; + case IP_PROTOCOL: + val = inet_sk(sk)->inet_num; + break; + default: + sockopt_release_sock(sk); + return -ENOPROTOOPT; + } + sockopt_release_sock(sk); + + if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) { + unsigned char ucval = (unsigned char)val; + len = 1; + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &ucval, 1)) + return -EFAULT; + } else { + len = min_t(unsigned int, sizeof(int), len); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &val, len)) + return -EFAULT; + } + return 0; + +out: + sockopt_release_sock(sk); + if (needs_rtnl) + rtnl_unlock(); + return err; +} + +int ip_getsockopt(struct sock *sk, int level, + int optname, char __user *optval, int __user *optlen) +{ + int err; + + err = do_ip_getsockopt(sk, level, optname, + USER_SOCKPTR(optval), USER_SOCKPTR(optlen)); + +#if IS_ENABLED(CONFIG_BPFILTER_UMH) + if (optname >= BPFILTER_IPT_SO_GET_INFO && + optname < BPFILTER_IPT_GET_MAX) + err = bpfilter_ip_get_sockopt(sk, optname, optval, optlen); +#endif +#ifdef CONFIG_NETFILTER + /* we need to exclude all possible ENOPROTOOPTs except default case */ + if (err == -ENOPROTOOPT && optname != IP_PKTOPTIONS && + !ip_mroute_opt(optname)) { + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + err = nf_getsockopt(sk, PF_INET, optname, optval, &len); + if (err >= 0) + err = put_user(len, optlen); + return err; + } +#endif + return err; +} +EXPORT_SYMBOL(ip_getsockopt); diff --git a/net/ipv4/ip_tunnel.c b/net/ipv4/ip_tunnel.c new file mode 100644 index 000000000..24961b304 --- /dev/null +++ b/net/ipv4/ip_tunnel.c @@ -0,0 +1,1283 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2013 Nicira, Inc. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#include +#endif + +static unsigned int ip_tunnel_hash(__be32 key, __be32 remote) +{ + return hash_32((__force u32)key ^ (__force u32)remote, + IP_TNL_HASH_BITS); +} + +static bool ip_tunnel_key_match(const struct ip_tunnel_parm *p, + __be16 flags, __be32 key) +{ + if (p->i_flags & TUNNEL_KEY) { + if (flags & TUNNEL_KEY) + return key == p->i_key; + else + /* key expected, none present */ + return false; + } else + return !(flags & TUNNEL_KEY); +} + +/* Fallback tunnel: no source, no destination, no key, no options + + Tunnel hash table: + We require exact key match i.e. if a key is present in packet + it will match only tunnel with the same key; if it is not present, + it will match only keyless tunnel. + + All keysless packets, if not matched configured keyless tunnels + will match fallback tunnel. + Given src, dst and key, find appropriate for input tunnel. +*/ +struct ip_tunnel *ip_tunnel_lookup(struct ip_tunnel_net *itn, + int link, __be16 flags, + __be32 remote, __be32 local, + __be32 key) +{ + struct ip_tunnel *t, *cand = NULL; + struct hlist_head *head; + struct net_device *ndev; + unsigned int hash; + + hash = ip_tunnel_hash(key, remote); + head = &itn->tunnels[hash]; + + hlist_for_each_entry_rcu(t, head, hash_node) { + if (local != t->parms.iph.saddr || + remote != t->parms.iph.daddr || + !(t->dev->flags & IFF_UP)) + continue; + + if (!ip_tunnel_key_match(&t->parms, flags, key)) + continue; + + if (t->parms.link == link) + return t; + else + cand = t; + } + + hlist_for_each_entry_rcu(t, head, hash_node) { + if (remote != t->parms.iph.daddr || + t->parms.iph.saddr != 0 || + !(t->dev->flags & IFF_UP)) + continue; + + if (!ip_tunnel_key_match(&t->parms, flags, key)) + continue; + + if (t->parms.link == link) + return t; + else if (!cand) + cand = t; + } + + hash = ip_tunnel_hash(key, 0); + head = &itn->tunnels[hash]; + + hlist_for_each_entry_rcu(t, head, hash_node) { + if ((local != t->parms.iph.saddr || t->parms.iph.daddr != 0) && + (local != t->parms.iph.daddr || !ipv4_is_multicast(local))) + continue; + + if (!(t->dev->flags & IFF_UP)) + continue; + + if (!ip_tunnel_key_match(&t->parms, flags, key)) + continue; + + if (t->parms.link == link) + return t; + else if (!cand) + cand = t; + } + + hlist_for_each_entry_rcu(t, head, hash_node) { + if ((!(flags & TUNNEL_NO_KEY) && t->parms.i_key != key) || + t->parms.iph.saddr != 0 || + t->parms.iph.daddr != 0 || + !(t->dev->flags & IFF_UP)) + continue; + + if (t->parms.link == link) + return t; + else if (!cand) + cand = t; + } + + if (cand) + return cand; + + t = rcu_dereference(itn->collect_md_tun); + if (t && t->dev->flags & IFF_UP) + return t; + + ndev = READ_ONCE(itn->fb_tunnel_dev); + if (ndev && ndev->flags & IFF_UP) + return netdev_priv(ndev); + + return NULL; +} +EXPORT_SYMBOL_GPL(ip_tunnel_lookup); + +static struct hlist_head *ip_bucket(struct ip_tunnel_net *itn, + struct ip_tunnel_parm *parms) +{ + unsigned int h; + __be32 remote; + __be32 i_key = parms->i_key; + + if (parms->iph.daddr && !ipv4_is_multicast(parms->iph.daddr)) + remote = parms->iph.daddr; + 
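/*
 * Illustrative sketch (not part of this patch): how a keyless decapsulation
 * path typically drives ip_tunnel_lookup(), loosely modeled on the IPIP
 * receiver.  "example_net_id", "example_rcv" and the error handling are
 * assumptions for the example; a real receiver also pulls the outer header
 * (iptunnel_pull_header()) before handing the skb to ip_tunnel_rcv().
 */
static unsigned int example_net_id __read_mostly;

static const struct tnl_ptk_info example_tpi = {
	/* no TUNNEL_CSUM/TUNNEL_SEQ/TUNNEL_KEY flags, plain IPv4 payload */
	.proto = htons(ETH_P_IP),
};

static int example_rcv(struct sk_buff *skb)
{
	struct net *net = dev_net(skb->dev);
	struct ip_tunnel_net *itn = net_generic(net, example_net_id);
	const struct iphdr *iph = ip_hdr(skb);
	struct ip_tunnel *tunnel;

	/* TUNNEL_NO_KEY: only keyless tunnels (or the fallback) can match */
	tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY,
				  iph->saddr, iph->daddr, 0);
	if (!tunnel) {
		kfree_skb(skb);
		return 0;
	}
	return ip_tunnel_rcv(tunnel, skb, &example_tpi, NULL, true);
}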
else + remote = 0; + + if (!(parms->i_flags & TUNNEL_KEY) && (parms->i_flags & VTI_ISVTI)) + i_key = 0; + + h = ip_tunnel_hash(i_key, remote); + return &itn->tunnels[h]; +} + +static void ip_tunnel_add(struct ip_tunnel_net *itn, struct ip_tunnel *t) +{ + struct hlist_head *head = ip_bucket(itn, &t->parms); + + if (t->collect_md) + rcu_assign_pointer(itn->collect_md_tun, t); + hlist_add_head_rcu(&t->hash_node, head); +} + +static void ip_tunnel_del(struct ip_tunnel_net *itn, struct ip_tunnel *t) +{ + if (t->collect_md) + rcu_assign_pointer(itn->collect_md_tun, NULL); + hlist_del_init_rcu(&t->hash_node); +} + +static struct ip_tunnel *ip_tunnel_find(struct ip_tunnel_net *itn, + struct ip_tunnel_parm *parms, + int type) +{ + __be32 remote = parms->iph.daddr; + __be32 local = parms->iph.saddr; + __be32 key = parms->i_key; + __be16 flags = parms->i_flags; + int link = parms->link; + struct ip_tunnel *t = NULL; + struct hlist_head *head = ip_bucket(itn, parms); + + hlist_for_each_entry_rcu(t, head, hash_node) { + if (local == t->parms.iph.saddr && + remote == t->parms.iph.daddr && + link == t->parms.link && + type == t->dev->type && + ip_tunnel_key_match(&t->parms, flags, key)) + break; + } + return t; +} + +static struct net_device *__ip_tunnel_create(struct net *net, + const struct rtnl_link_ops *ops, + struct ip_tunnel_parm *parms) +{ + int err; + struct ip_tunnel *tunnel; + struct net_device *dev; + char name[IFNAMSIZ]; + + err = -E2BIG; + if (parms->name[0]) { + if (!dev_valid_name(parms->name)) + goto failed; + strscpy(name, parms->name, IFNAMSIZ); + } else { + if (strlen(ops->kind) > (IFNAMSIZ - 3)) + goto failed; + strcpy(name, ops->kind); + strcat(name, "%d"); + } + + ASSERT_RTNL(); + dev = alloc_netdev(ops->priv_size, name, NET_NAME_UNKNOWN, ops->setup); + if (!dev) { + err = -ENOMEM; + goto failed; + } + dev_net_set(dev, net); + + dev->rtnl_link_ops = ops; + + tunnel = netdev_priv(dev); + tunnel->parms = *parms; + tunnel->net = net; + + err = register_netdevice(dev); + if (err) + goto failed_free; + + return dev; + +failed_free: + free_netdev(dev); +failed: + return ERR_PTR(err); +} + +static int ip_tunnel_bind_dev(struct net_device *dev) +{ + struct net_device *tdev = NULL; + struct ip_tunnel *tunnel = netdev_priv(dev); + const struct iphdr *iph; + int hlen = LL_MAX_HEADER; + int mtu = ETH_DATA_LEN; + int t_hlen = tunnel->hlen + sizeof(struct iphdr); + + iph = &tunnel->parms.iph; + + /* Guess output device to choose reasonable mtu and needed_headroom */ + if (iph->daddr) { + struct flowi4 fl4; + struct rtable *rt; + + ip_tunnel_init_flow(&fl4, iph->protocol, iph->daddr, + iph->saddr, tunnel->parms.o_key, + RT_TOS(iph->tos), dev_net(dev), + tunnel->parms.link, tunnel->fwmark, 0, 0); + rt = ip_route_output_key(tunnel->net, &fl4); + + if (!IS_ERR(rt)) { + tdev = rt->dst.dev; + ip_rt_put(rt); + } + if (dev->type != ARPHRD_ETHER) + dev->flags |= IFF_POINTOPOINT; + + dst_cache_reset(&tunnel->dst_cache); + } + + if (!tdev && tunnel->parms.link) + tdev = __dev_get_by_index(tunnel->net, tunnel->parms.link); + + if (tdev) { + hlen = tdev->hard_header_len + tdev->needed_headroom; + mtu = min(tdev->mtu, IP_MAX_MTU); + } + + dev->needed_headroom = t_hlen + hlen; + mtu -= t_hlen + (dev->type == ARPHRD_ETHER ? 
dev->hard_header_len : 0); + + if (mtu < IPV4_MIN_MTU) + mtu = IPV4_MIN_MTU; + + return mtu; +} + +static struct ip_tunnel *ip_tunnel_create(struct net *net, + struct ip_tunnel_net *itn, + struct ip_tunnel_parm *parms) +{ + struct ip_tunnel *nt; + struct net_device *dev; + int t_hlen; + int mtu; + int err; + + dev = __ip_tunnel_create(net, itn->rtnl_link_ops, parms); + if (IS_ERR(dev)) + return ERR_CAST(dev); + + mtu = ip_tunnel_bind_dev(dev); + err = dev_set_mtu(dev, mtu); + if (err) + goto err_dev_set_mtu; + + nt = netdev_priv(dev); + t_hlen = nt->hlen + sizeof(struct iphdr); + dev->min_mtu = ETH_MIN_MTU; + dev->max_mtu = IP_MAX_MTU - t_hlen; + if (dev->type == ARPHRD_ETHER) + dev->max_mtu -= dev->hard_header_len; + + ip_tunnel_add(itn, nt); + return nt; + +err_dev_set_mtu: + unregister_netdevice(dev); + return ERR_PTR(err); +} + +int ip_tunnel_rcv(struct ip_tunnel *tunnel, struct sk_buff *skb, + const struct tnl_ptk_info *tpi, struct metadata_dst *tun_dst, + bool log_ecn_error) +{ + const struct iphdr *iph = ip_hdr(skb); + int err; + +#ifdef CONFIG_NET_IPGRE_BROADCAST + if (ipv4_is_multicast(iph->daddr)) { + tunnel->dev->stats.multicast++; + skb->pkt_type = PACKET_BROADCAST; + } +#endif + + if ((!(tpi->flags&TUNNEL_CSUM) && (tunnel->parms.i_flags&TUNNEL_CSUM)) || + ((tpi->flags&TUNNEL_CSUM) && !(tunnel->parms.i_flags&TUNNEL_CSUM))) { + tunnel->dev->stats.rx_crc_errors++; + tunnel->dev->stats.rx_errors++; + goto drop; + } + + if (tunnel->parms.i_flags&TUNNEL_SEQ) { + if (!(tpi->flags&TUNNEL_SEQ) || + (tunnel->i_seqno && (s32)(ntohl(tpi->seq) - tunnel->i_seqno) < 0)) { + tunnel->dev->stats.rx_fifo_errors++; + tunnel->dev->stats.rx_errors++; + goto drop; + } + tunnel->i_seqno = ntohl(tpi->seq) + 1; + } + + skb_set_network_header(skb, (tunnel->dev->type == ARPHRD_ETHER) ? ETH_HLEN : 0); + + err = IP_ECN_decapsulate(iph, skb); + if (unlikely(err)) { + if (log_ecn_error) + net_info_ratelimited("non-ECT from %pI4 with TOS=%#x\n", + &iph->saddr, iph->tos); + if (err > 1) { + ++tunnel->dev->stats.rx_frame_errors; + ++tunnel->dev->stats.rx_errors; + goto drop; + } + } + + dev_sw_netstats_rx_add(tunnel->dev, skb->len); + skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(tunnel->dev))); + + if (tunnel->dev->type == ARPHRD_ETHER) { + skb->protocol = eth_type_trans(skb, tunnel->dev); + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN); + } else { + skb->dev = tunnel->dev; + } + + if (tun_dst) + skb_dst_set(skb, (struct dst_entry *)tun_dst); + + gro_cells_receive(&tunnel->gro_cells, skb); + return 0; + +drop: + if (tun_dst) + dst_release((struct dst_entry *)tun_dst); + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL_GPL(ip_tunnel_rcv); + +int ip_tunnel_encap_add_ops(const struct ip_tunnel_encap_ops *ops, + unsigned int num) +{ + if (num >= MAX_IPTUN_ENCAP_OPS) + return -ERANGE; + + return !cmpxchg((const struct ip_tunnel_encap_ops **) + &iptun_encaps[num], + NULL, ops) ? 0 : -1; +} +EXPORT_SYMBOL(ip_tunnel_encap_add_ops); + +int ip_tunnel_encap_del_ops(const struct ip_tunnel_encap_ops *ops, + unsigned int num) +{ + int ret; + + if (num >= MAX_IPTUN_ENCAP_OPS) + return -ERANGE; + + ret = (cmpxchg((const struct ip_tunnel_encap_ops **) + &iptun_encaps[num], + ops, NULL) == ops) ? 
0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(ip_tunnel_encap_del_ops); + +int ip_tunnel_encap_setup(struct ip_tunnel *t, + struct ip_tunnel_encap *ipencap) +{ + int hlen; + + memset(&t->encap, 0, sizeof(t->encap)); + + hlen = ip_encap_hlen(ipencap); + if (hlen < 0) + return hlen; + + t->encap.type = ipencap->type; + t->encap.sport = ipencap->sport; + t->encap.dport = ipencap->dport; + t->encap.flags = ipencap->flags; + + t->encap_hlen = hlen; + t->hlen = t->encap_hlen + t->tun_hlen; + + return 0; +} +EXPORT_SYMBOL_GPL(ip_tunnel_encap_setup); + +static int tnl_update_pmtu(struct net_device *dev, struct sk_buff *skb, + struct rtable *rt, __be16 df, + const struct iphdr *inner_iph, + int tunnel_hlen, __be32 dst, bool md) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + int pkt_size; + int mtu; + + tunnel_hlen = md ? tunnel_hlen : tunnel->hlen; + pkt_size = skb->len - tunnel_hlen; + pkt_size -= dev->type == ARPHRD_ETHER ? dev->hard_header_len : 0; + + if (df) { + mtu = dst_mtu(&rt->dst) - (sizeof(struct iphdr) + tunnel_hlen); + mtu -= dev->type == ARPHRD_ETHER ? dev->hard_header_len : 0; + } else { + mtu = skb_valid_dst(skb) ? dst_mtu(skb_dst(skb)) : dev->mtu; + } + + if (skb_valid_dst(skb)) + skb_dst_update_pmtu_no_confirm(skb, mtu); + + if (skb->protocol == htons(ETH_P_IP)) { + if (!skb_is_gso(skb) && + (inner_iph->frag_off & htons(IP_DF)) && + mtu < pkt_size) { + icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, htonl(mtu)); + return -E2BIG; + } + } +#if IS_ENABLED(CONFIG_IPV6) + else if (skb->protocol == htons(ETH_P_IPV6)) { + struct rt6_info *rt6; + __be32 daddr; + + rt6 = skb_valid_dst(skb) ? (struct rt6_info *)skb_dst(skb) : + NULL; + daddr = md ? dst : tunnel->parms.iph.daddr; + + if (rt6 && mtu < dst_mtu(skb_dst(skb)) && + mtu >= IPV6_MIN_MTU) { + if ((daddr && !ipv4_is_multicast(daddr)) || + rt6->rt6i_dst.plen == 128) { + rt6->rt6i_flags |= RTF_MODIFIED; + dst_metric_set(skb_dst(skb), RTAX_MTU, mtu); + } + } + + if (!skb_is_gso(skb) && mtu >= IPV6_MIN_MTU && + mtu < pkt_size) { + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + return -E2BIG; + } + } +#endif + return 0; +} + +void ip_md_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, + u8 proto, int tunnel_hlen) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + u32 headroom = sizeof(struct iphdr); + struct ip_tunnel_info *tun_info; + const struct ip_tunnel_key *key; + const struct iphdr *inner_iph; + struct rtable *rt = NULL; + struct flowi4 fl4; + __be16 df = 0; + u8 tos, ttl; + bool use_cache; + + tun_info = skb_tunnel_info(skb); + if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || + ip_tunnel_info_af(tun_info) != AF_INET)) + goto tx_error; + key = &tun_info->key; + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + inner_iph = (const struct iphdr *)skb_inner_network_header(skb); + tos = key->tos; + if (tos == 1) { + if (skb->protocol == htons(ETH_P_IP)) + tos = inner_iph->tos; + else if (skb->protocol == htons(ETH_P_IPV6)) + tos = ipv6_get_dsfield((const struct ipv6hdr *)inner_iph); + } + ip_tunnel_init_flow(&fl4, proto, key->u.ipv4.dst, key->u.ipv4.src, + tunnel_id_to_key32(key->tun_id), RT_TOS(tos), + dev_net(dev), 0, skb->mark, skb_get_hash(skb), + key->flow_flags); + if (tunnel->encap.type != TUNNEL_ENCAP_NONE) + goto tx_error; + + use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); + if (use_cache) + rt = dst_cache_get_ip4(&tun_info->dst_cache, &fl4.saddr); + if (!rt) { + rt = ip_route_output_key(tunnel->net, &fl4); + if (IS_ERR(rt)) { + 
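/*
 * Illustrative sketch (not part of this patch): how a UDP encapsulation
 * module plugs into the iptun_encaps[] table handled above, loosely modeled
 * on FOU (net/ipv4/fou.c).  The "example_*" names are assumptions; only
 * .encap_hlen is shown, a real module also provides .build_header and
 * usually .err_handler.
 */
static size_t example_encap_hlen(struct ip_tunnel_encap *e)
{
	return sizeof(struct udphdr);	/* bytes this encap adds per packet */
}

static const struct ip_tunnel_encap_ops example_encap_ops = {
	.encap_hlen	= example_encap_hlen,
};

static int __init example_encap_init(void)
{
	/* claim one slot; -1 comes back if the slot is already taken */
	return ip_tunnel_encap_add_ops(&example_encap_ops, TUNNEL_ENCAP_FOU);
}

static void __exit example_encap_exit(void)
{
	ip_tunnel_encap_del_ops(&example_encap_ops, TUNNEL_ENCAP_FOU);
}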
dev->stats.tx_carrier_errors++; + goto tx_error; + } + if (use_cache) + dst_cache_set_ip4(&tun_info->dst_cache, &rt->dst, + fl4.saddr); + } + if (rt->dst.dev == dev) { + ip_rt_put(rt); + dev->stats.collisions++; + goto tx_error; + } + + if (key->tun_flags & TUNNEL_DONT_FRAGMENT) + df = htons(IP_DF); + if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, tunnel_hlen, + key->u.ipv4.dst, true)) { + ip_rt_put(rt); + goto tx_error; + } + + tos = ip_tunnel_ecn_encap(tos, inner_iph, skb); + ttl = key->ttl; + if (ttl == 0) { + if (skb->protocol == htons(ETH_P_IP)) + ttl = inner_iph->ttl; + else if (skb->protocol == htons(ETH_P_IPV6)) + ttl = ((const struct ipv6hdr *)inner_iph)->hop_limit; + else + ttl = ip4_dst_hoplimit(&rt->dst); + } + + headroom += LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len; + if (headroom > READ_ONCE(dev->needed_headroom)) + WRITE_ONCE(dev->needed_headroom, headroom); + + if (skb_cow_head(skb, READ_ONCE(dev->needed_headroom))) { + ip_rt_put(rt); + goto tx_dropped; + } + iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, proto, tos, ttl, + df, !net_eq(tunnel->net, dev_net(dev))); + return; +tx_error: + dev->stats.tx_errors++; + goto kfree; +tx_dropped: + dev->stats.tx_dropped++; +kfree: + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(ip_md_tunnel_xmit); + +void ip_tunnel_xmit(struct sk_buff *skb, struct net_device *dev, + const struct iphdr *tnl_params, u8 protocol) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_info *tun_info = NULL; + const struct iphdr *inner_iph; + unsigned int max_headroom; /* The extra header space needed */ + struct rtable *rt = NULL; /* Route to the other host */ + __be16 payload_protocol; + bool use_cache = false; + struct flowi4 fl4; + bool md = false; + bool connected; + u8 tos, ttl; + __be32 dst; + __be16 df; + + inner_iph = (const struct iphdr *)skb_inner_network_header(skb); + connected = (tunnel->parms.iph.daddr != 0); + payload_protocol = skb_protocol(skb, true); + + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + + dst = tnl_params->daddr; + if (dst == 0) { + /* NBMA tunnel */ + + if (!skb_dst(skb)) { + dev->stats.tx_fifo_errors++; + goto tx_error; + } + + tun_info = skb_tunnel_info(skb); + if (tun_info && (tun_info->mode & IP_TUNNEL_INFO_TX) && + ip_tunnel_info_af(tun_info) == AF_INET && + tun_info->key.u.ipv4.dst) { + dst = tun_info->key.u.ipv4.dst; + md = true; + connected = true; + } else if (payload_protocol == htons(ETH_P_IP)) { + rt = skb_rtable(skb); + dst = rt_nexthop(rt, inner_iph->daddr); + } +#if IS_ENABLED(CONFIG_IPV6) + else if (payload_protocol == htons(ETH_P_IPV6)) { + const struct in6_addr *addr6; + struct neighbour *neigh; + bool do_tx_error_icmp; + int addr_type; + + neigh = dst_neigh_lookup(skb_dst(skb), + &ipv6_hdr(skb)->daddr); + if (!neigh) + goto tx_error; + + addr6 = (const struct in6_addr *)&neigh->primary_key; + addr_type = ipv6_addr_type(addr6); + + if (addr_type == IPV6_ADDR_ANY) { + addr6 = &ipv6_hdr(skb)->daddr; + addr_type = ipv6_addr_type(addr6); + } + + if ((addr_type & IPV6_ADDR_COMPATv4) == 0) + do_tx_error_icmp = true; + else { + do_tx_error_icmp = false; + dst = addr6->s6_addr32[3]; + } + neigh_release(neigh); + if (do_tx_error_icmp) + goto tx_error_icmp; + } +#endif + else + goto tx_error; + + if (!md) + connected = false; + } + + tos = tnl_params->tos; + if (tos & 0x1) { + tos &= ~0x1; + if (payload_protocol == htons(ETH_P_IP)) { + tos = inner_iph->tos; + connected = false; + } else if (payload_protocol == htons(ETH_P_IPV6)) { + tos = ipv6_get_dsfield((const struct ipv6hdr 
*)inner_iph); + connected = false; + } + } + + ip_tunnel_init_flow(&fl4, protocol, dst, tnl_params->saddr, + tunnel->parms.o_key, RT_TOS(tos), + dev_net(dev), tunnel->parms.link, + tunnel->fwmark, skb_get_hash(skb), 0); + + if (ip_tunnel_encap(skb, tunnel, &protocol, &fl4) < 0) + goto tx_error; + + if (connected && md) { + use_cache = ip_tunnel_dst_cache_usable(skb, tun_info); + if (use_cache) + rt = dst_cache_get_ip4(&tun_info->dst_cache, + &fl4.saddr); + } else { + rt = connected ? dst_cache_get_ip4(&tunnel->dst_cache, + &fl4.saddr) : NULL; + } + + if (!rt) { + rt = ip_route_output_key(tunnel->net, &fl4); + + if (IS_ERR(rt)) { + dev->stats.tx_carrier_errors++; + goto tx_error; + } + if (use_cache) + dst_cache_set_ip4(&tun_info->dst_cache, &rt->dst, + fl4.saddr); + else if (!md && connected) + dst_cache_set_ip4(&tunnel->dst_cache, &rt->dst, + fl4.saddr); + } + + if (rt->dst.dev == dev) { + ip_rt_put(rt); + dev->stats.collisions++; + goto tx_error; + } + + df = tnl_params->frag_off; + if (payload_protocol == htons(ETH_P_IP) && !tunnel->ignore_df) + df |= (inner_iph->frag_off & htons(IP_DF)); + + if (tnl_update_pmtu(dev, skb, rt, df, inner_iph, 0, 0, false)) { + ip_rt_put(rt); + goto tx_error; + } + + if (tunnel->err_count > 0) { + if (time_before(jiffies, + tunnel->err_time + IPTUNNEL_ERR_TIMEO)) { + tunnel->err_count--; + + dst_link_failure(skb); + } else + tunnel->err_count = 0; + } + + tos = ip_tunnel_ecn_encap(tos, inner_iph, skb); + ttl = tnl_params->ttl; + if (ttl == 0) { + if (payload_protocol == htons(ETH_P_IP)) + ttl = inner_iph->ttl; +#if IS_ENABLED(CONFIG_IPV6) + else if (payload_protocol == htons(ETH_P_IPV6)) + ttl = ((const struct ipv6hdr *)inner_iph)->hop_limit; +#endif + else + ttl = ip4_dst_hoplimit(&rt->dst); + } + + max_headroom = LL_RESERVED_SPACE(rt->dst.dev) + sizeof(struct iphdr) + + rt->dst.header_len + ip_encap_hlen(&tunnel->encap); + if (max_headroom > READ_ONCE(dev->needed_headroom)) + WRITE_ONCE(dev->needed_headroom, max_headroom); + + if (skb_cow_head(skb, READ_ONCE(dev->needed_headroom))) { + ip_rt_put(rt); + dev->stats.tx_dropped++; + kfree_skb(skb); + return; + } + + iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, protocol, tos, ttl, + df, !net_eq(tunnel->net, dev_net(dev))); + return; + +#if IS_ENABLED(CONFIG_IPV6) +tx_error_icmp: + dst_link_failure(skb); +#endif +tx_error: + dev->stats.tx_errors++; + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(ip_tunnel_xmit); + +static void ip_tunnel_update(struct ip_tunnel_net *itn, + struct ip_tunnel *t, + struct net_device *dev, + struct ip_tunnel_parm *p, + bool set_mtu, + __u32 fwmark) +{ + ip_tunnel_del(itn, t); + t->parms.iph.saddr = p->iph.saddr; + t->parms.iph.daddr = p->iph.daddr; + t->parms.i_key = p->i_key; + t->parms.o_key = p->o_key; + if (dev->type != ARPHRD_ETHER) { + __dev_addr_set(dev, &p->iph.saddr, 4); + memcpy(dev->broadcast, &p->iph.daddr, 4); + } + ip_tunnel_add(itn, t); + + t->parms.iph.ttl = p->iph.ttl; + t->parms.iph.tos = p->iph.tos; + t->parms.iph.frag_off = p->iph.frag_off; + + if (t->parms.link != p->link || t->fwmark != fwmark) { + int mtu; + + t->parms.link = p->link; + t->fwmark = fwmark; + mtu = ip_tunnel_bind_dev(dev); + if (set_mtu) + dev->mtu = mtu; + } + dst_cache_reset(&t->dst_cache); + netdev_state_change(dev); +} + +int ip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +{ + int err = 0; + struct ip_tunnel *t = netdev_priv(dev); + struct net *net = t->net; + struct ip_tunnel_net *itn = net_generic(net, t->ip_tnl_net_id); + + switch (cmd) { + case SIOCGETTUNNEL: + 
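/*
 * Illustrative sketch (not part of this patch): a minimal ndo_start_xmit
 * built on ip_tunnel_xmit() above, loosely modeled on the IPIP driver.  The
 * function name is an assumption; a real driver also handles collect_md
 * metadata and validates skb->protocol before transmitting.
 */
static netdev_tx_t example_tunnel_xmit(struct sk_buff *skb,
				       struct net_device *dev)
{
	struct ip_tunnel *tunnel = netdev_priv(dev);
	const struct iphdr *tiph = &tunnel->parms.iph;

	if (iptunnel_handle_offloads(skb, SKB_GSO_IPXIP4))
		goto tx_error;

	skb_set_inner_ipproto(skb, IPPROTO_IPIP);
	ip_tunnel_xmit(skb, dev, tiph, IPPROTO_IPIP);
	return NETDEV_TX_OK;

tx_error:
	kfree_skb(skb);
	dev->stats.tx_dropped++;
	return NETDEV_TX_OK;
}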
if (dev == itn->fb_tunnel_dev) { + t = ip_tunnel_find(itn, p, itn->fb_tunnel_dev->type); + if (!t) + t = netdev_priv(dev); + } + memcpy(p, &t->parms, sizeof(*p)); + break; + + case SIOCADDTUNNEL: + case SIOCCHGTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto done; + if (p->iph.ttl) + p->iph.frag_off |= htons(IP_DF); + if (!(p->i_flags & VTI_ISVTI)) { + if (!(p->i_flags & TUNNEL_KEY)) + p->i_key = 0; + if (!(p->o_flags & TUNNEL_KEY)) + p->o_key = 0; + } + + t = ip_tunnel_find(itn, p, itn->type); + + if (cmd == SIOCADDTUNNEL) { + if (!t) { + t = ip_tunnel_create(net, itn, p); + err = PTR_ERR_OR_ZERO(t); + break; + } + + err = -EEXIST; + break; + } + if (dev != itn->fb_tunnel_dev && cmd == SIOCCHGTUNNEL) { + if (t) { + if (t->dev != dev) { + err = -EEXIST; + break; + } + } else { + unsigned int nflags = 0; + + if (ipv4_is_multicast(p->iph.daddr)) + nflags = IFF_BROADCAST; + else if (p->iph.daddr) + nflags = IFF_POINTOPOINT; + + if ((dev->flags^nflags)&(IFF_POINTOPOINT|IFF_BROADCAST)) { + err = -EINVAL; + break; + } + + t = netdev_priv(dev); + } + } + + if (t) { + err = 0; + ip_tunnel_update(itn, t, dev, p, true, 0); + } else { + err = -ENOENT; + } + break; + + case SIOCDELTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto done; + + if (dev == itn->fb_tunnel_dev) { + err = -ENOENT; + t = ip_tunnel_find(itn, p, itn->fb_tunnel_dev->type); + if (!t) + goto done; + err = -EPERM; + if (t == netdev_priv(itn->fb_tunnel_dev)) + goto done; + dev = t->dev; + } + unregister_netdevice(dev); + err = 0; + break; + + default: + err = -EINVAL; + } + +done: + return err; +} +EXPORT_SYMBOL_GPL(ip_tunnel_ctl); + +int ip_tunnel_siocdevprivate(struct net_device *dev, struct ifreq *ifr, + void __user *data, int cmd) +{ + struct ip_tunnel_parm p; + int err; + + if (copy_from_user(&p, data, sizeof(p))) + return -EFAULT; + err = dev->netdev_ops->ndo_tunnel_ctl(dev, &p, cmd); + if (!err && copy_to_user(data, &p, sizeof(p))) + return -EFAULT; + return err; +} +EXPORT_SYMBOL_GPL(ip_tunnel_siocdevprivate); + +int __ip_tunnel_change_mtu(struct net_device *dev, int new_mtu, bool strict) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + int t_hlen = tunnel->hlen + sizeof(struct iphdr); + int max_mtu = IP_MAX_MTU - t_hlen; + + if (dev->type == ARPHRD_ETHER) + max_mtu -= dev->hard_header_len; + + if (new_mtu < ETH_MIN_MTU) + return -EINVAL; + + if (new_mtu > max_mtu) { + if (strict) + return -EINVAL; + + new_mtu = max_mtu; + } + + dev->mtu = new_mtu; + return 0; +} +EXPORT_SYMBOL_GPL(__ip_tunnel_change_mtu); + +int ip_tunnel_change_mtu(struct net_device *dev, int new_mtu) +{ + return __ip_tunnel_change_mtu(dev, new_mtu, true); +} +EXPORT_SYMBOL_GPL(ip_tunnel_change_mtu); + +static void ip_tunnel_dev_free(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + gro_cells_destroy(&tunnel->gro_cells); + dst_cache_destroy(&tunnel->dst_cache); + free_percpu(dev->tstats); +} + +void ip_tunnel_dellink(struct net_device *dev, struct list_head *head) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_net *itn; + + itn = net_generic(tunnel->net, tunnel->ip_tnl_net_id); + + if (itn->fb_tunnel_dev != dev) { + ip_tunnel_del(itn, netdev_priv(dev)); + unregister_netdevice_queue(dev, head); + } +} +EXPORT_SYMBOL_GPL(ip_tunnel_dellink); + +struct net *ip_tunnel_get_link_net(const struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + return tunnel->net; +} +EXPORT_SYMBOL(ip_tunnel_get_link_net); + +int 
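/*
 * Illustrative userspace sketch (not part of this patch): creating an IPIP
 * tunnel through the ioctl path served by ip_tunnel_siocdevprivate() and
 * ip_tunnel_ctl() above, roughly what "ip tunnel add" does.  It assumes the
 * UAPI definitions from <linux/if_tunnel.h>; the helper name and the use of
 * the "tunl0" fallback device are choices made for the example.
 */
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <net/if.h>
#include <netinet/in.h>
#include <linux/if_tunnel.h>

static int add_ipip_tunnel(const char *name, in_addr_t local, in_addr_t remote)
{
	struct ip_tunnel_parm p;
	struct ifreq ifr;
	int fd, err;

	memset(&p, 0, sizeof(p));
	strncpy(p.name, name, IFNAMSIZ - 1);
	p.iph.version = 4;
	p.iph.ihl = 5;
	p.iph.protocol = IPPROTO_IPIP;
	p.iph.frag_off = htons(IP_DF);		/* pmtudisc on */
	p.iph.saddr = local;
	p.iph.daddr = remote;

	memset(&ifr, 0, sizeof(ifr));
	strncpy(ifr.ifr_name, "tunl0", IFNAMSIZ - 1);	/* IPIP fallback dev */
	ifr.ifr_ifru.ifru_data = (void *)&p;

	fd = socket(AF_INET, SOCK_DGRAM, 0);
	if (fd < 0)
		return -1;
	err = ioctl(fd, SIOCADDTUNNEL, &ifr);
	close(fd);
	return err;
}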
ip_tunnel_get_iflink(const struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + return tunnel->parms.link; +} +EXPORT_SYMBOL(ip_tunnel_get_iflink); + +int ip_tunnel_init_net(struct net *net, unsigned int ip_tnl_net_id, + struct rtnl_link_ops *ops, char *devname) +{ + struct ip_tunnel_net *itn = net_generic(net, ip_tnl_net_id); + struct ip_tunnel_parm parms; + unsigned int i; + + itn->rtnl_link_ops = ops; + for (i = 0; i < IP_TNL_HASH_SIZE; i++) + INIT_HLIST_HEAD(&itn->tunnels[i]); + + if (!ops || !net_has_fallback_tunnels(net)) { + struct ip_tunnel_net *it_init_net; + + it_init_net = net_generic(&init_net, ip_tnl_net_id); + itn->type = it_init_net->type; + itn->fb_tunnel_dev = NULL; + return 0; + } + + memset(&parms, 0, sizeof(parms)); + if (devname) + strscpy(parms.name, devname, IFNAMSIZ); + + rtnl_lock(); + itn->fb_tunnel_dev = __ip_tunnel_create(net, ops, &parms); + /* FB netdevice is special: we have one, and only one per netns. + * Allowing to move it to another netns is clearly unsafe. + */ + if (!IS_ERR(itn->fb_tunnel_dev)) { + itn->fb_tunnel_dev->features |= NETIF_F_NETNS_LOCAL; + itn->fb_tunnel_dev->mtu = ip_tunnel_bind_dev(itn->fb_tunnel_dev); + ip_tunnel_add(itn, netdev_priv(itn->fb_tunnel_dev)); + itn->type = itn->fb_tunnel_dev->type; + } + rtnl_unlock(); + + return PTR_ERR_OR_ZERO(itn->fb_tunnel_dev); +} +EXPORT_SYMBOL_GPL(ip_tunnel_init_net); + +static void ip_tunnel_destroy(struct net *net, struct ip_tunnel_net *itn, + struct list_head *head, + struct rtnl_link_ops *ops) +{ + struct net_device *dev, *aux; + int h; + + for_each_netdev_safe(net, dev, aux) + if (dev->rtnl_link_ops == ops) + unregister_netdevice_queue(dev, head); + + for (h = 0; h < IP_TNL_HASH_SIZE; h++) { + struct ip_tunnel *t; + struct hlist_node *n; + struct hlist_head *thead = &itn->tunnels[h]; + + hlist_for_each_entry_safe(t, n, thead, hash_node) + /* If dev is in the same netns, it has already + * been added to the list by the previous loop. 
+ */ + if (!net_eq(dev_net(t->dev), net)) + unregister_netdevice_queue(t->dev, head); + } +} + +void ip_tunnel_delete_nets(struct list_head *net_list, unsigned int id, + struct rtnl_link_ops *ops) +{ + struct ip_tunnel_net *itn; + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + itn = net_generic(net, id); + ip_tunnel_destroy(net, itn, &list, ops); + } + unregister_netdevice_many(&list); + rtnl_unlock(); +} +EXPORT_SYMBOL_GPL(ip_tunnel_delete_nets); + +int ip_tunnel_newlink(struct net_device *dev, struct nlattr *tb[], + struct ip_tunnel_parm *p, __u32 fwmark) +{ + struct ip_tunnel *nt; + struct net *net = dev_net(dev); + struct ip_tunnel_net *itn; + int mtu; + int err; + + nt = netdev_priv(dev); + itn = net_generic(net, nt->ip_tnl_net_id); + + if (nt->collect_md) { + if (rtnl_dereference(itn->collect_md_tun)) + return -EEXIST; + } else { + if (ip_tunnel_find(itn, p, dev->type)) + return -EEXIST; + } + + nt->net = net; + nt->parms = *p; + nt->fwmark = fwmark; + err = register_netdevice(dev); + if (err) + goto err_register_netdevice; + + if (dev->type == ARPHRD_ETHER && !tb[IFLA_ADDRESS]) + eth_hw_addr_random(dev); + + mtu = ip_tunnel_bind_dev(dev); + if (tb[IFLA_MTU]) { + unsigned int max = IP_MAX_MTU - (nt->hlen + sizeof(struct iphdr)); + + if (dev->type == ARPHRD_ETHER) + max -= dev->hard_header_len; + + mtu = clamp(dev->mtu, (unsigned int)ETH_MIN_MTU, max); + } + + err = dev_set_mtu(dev, mtu); + if (err) + goto err_dev_set_mtu; + + ip_tunnel_add(itn, nt); + return 0; + +err_dev_set_mtu: + unregister_netdevice(dev); +err_register_netdevice: + return err; +} +EXPORT_SYMBOL_GPL(ip_tunnel_newlink); + +int ip_tunnel_changelink(struct net_device *dev, struct nlattr *tb[], + struct ip_tunnel_parm *p, __u32 fwmark) +{ + struct ip_tunnel *t; + struct ip_tunnel *tunnel = netdev_priv(dev); + struct net *net = tunnel->net; + struct ip_tunnel_net *itn = net_generic(net, tunnel->ip_tnl_net_id); + + if (dev == itn->fb_tunnel_dev) + return -EINVAL; + + t = ip_tunnel_find(itn, p, dev->type); + + if (t) { + if (t->dev != dev) + return -EEXIST; + } else { + t = tunnel; + + if (dev->type != ARPHRD_ETHER) { + unsigned int nflags = 0; + + if (ipv4_is_multicast(p->iph.daddr)) + nflags = IFF_BROADCAST; + else if (p->iph.daddr) + nflags = IFF_POINTOPOINT; + + if ((dev->flags ^ nflags) & + (IFF_POINTOPOINT | IFF_BROADCAST)) + return -EINVAL; + } + } + + ip_tunnel_update(itn, t, dev, p, !tb[IFLA_MTU], fwmark); + return 0; +} +EXPORT_SYMBOL_GPL(ip_tunnel_changelink); + +int ip_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct iphdr *iph = &tunnel->parms.iph; + int err; + + dev->needs_free_netdev = true; + dev->priv_destructor = ip_tunnel_dev_free; + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + err = dst_cache_init(&tunnel->dst_cache, GFP_KERNEL); + if (err) { + free_percpu(dev->tstats); + return err; + } + + err = gro_cells_init(&tunnel->gro_cells, dev); + if (err) { + dst_cache_destroy(&tunnel->dst_cache); + free_percpu(dev->tstats); + return err; + } + + tunnel->dev = dev; + tunnel->net = dev_net(dev); + strcpy(tunnel->parms.name, dev->name); + iph->version = 4; + iph->ihl = 5; + + if (tunnel->collect_md) + netif_keep_dst(dev); + return 0; +} +EXPORT_SYMBOL_GPL(ip_tunnel_init); + +void ip_tunnel_uninit(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct net *net = tunnel->net; + struct ip_tunnel_net *itn; + + itn = 
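/*
 * Illustrative sketch (not part of this patch): the .newlink handler of a
 * tunnel driver built on ip_tunnel_newlink() above, loosely modeled on IPIP.
 * The function name is an assumption; ip_tunnel_netlink_parms() and
 * ip_tunnel_netlink_encap_parms() are provided by ip_tunnel_core.c further
 * down in this patch.
 */
static int example_newlink(struct net *src_net, struct net_device *dev,
			   struct nlattr *tb[], struct nlattr *data[],
			   struct netlink_ext_ack *extack)
{
	struct ip_tunnel *t = netdev_priv(dev);
	struct ip_tunnel_encap ipencap;
	struct ip_tunnel_parm p;

	if (ip_tunnel_netlink_encap_parms(data, &ipencap)) {
		int err = ip_tunnel_encap_setup(t, &ipencap);

		if (err < 0)
			return err;
	}

	memset(&p, 0, sizeof(p));
	ip_tunnel_netlink_parms(data, &p);
	return ip_tunnel_newlink(dev, tb, &p, 0 /* fwmark */);
}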
net_generic(net, tunnel->ip_tnl_net_id); + ip_tunnel_del(itn, netdev_priv(dev)); + if (itn->fb_tunnel_dev == dev) + WRITE_ONCE(itn->fb_tunnel_dev, NULL); + + dst_cache_reset(&tunnel->dst_cache); +} +EXPORT_SYMBOL_GPL(ip_tunnel_uninit); + +/* Do least required initialization, rest of init is done in tunnel_init call */ +void ip_tunnel_setup(struct net_device *dev, unsigned int net_id) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + tunnel->ip_tnl_net_id = net_id; +} +EXPORT_SYMBOL_GPL(ip_tunnel_setup); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/ip_tunnel_core.c b/net/ipv4/ip_tunnel_core.c new file mode 100644 index 000000000..586b1b3e3 --- /dev/null +++ b/net/ipv4/ip_tunnel_core.c @@ -0,0 +1,1148 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2013 Nicira, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const struct ip_tunnel_encap_ops __rcu * + iptun_encaps[MAX_IPTUN_ENCAP_OPS] __read_mostly; +EXPORT_SYMBOL(iptun_encaps); + +const struct ip6_tnl_encap_ops __rcu * + ip6tun_encaps[MAX_IPTUN_ENCAP_OPS] __read_mostly; +EXPORT_SYMBOL(ip6tun_encaps); + +void iptunnel_xmit(struct sock *sk, struct rtable *rt, struct sk_buff *skb, + __be32 src, __be32 dst, __u8 proto, + __u8 tos, __u8 ttl, __be16 df, bool xnet) +{ + int pkt_len = skb->len - skb_inner_network_offset(skb); + struct net *net = dev_net(rt->dst.dev); + struct net_device *dev = skb->dev; + struct iphdr *iph; + int err; + + skb_scrub_packet(skb, xnet); + + skb_clear_hash_if_not_l4(skb); + skb_dst_set(skb, &rt->dst); + memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + + /* Push down and install the IP header. */ + skb_push(skb, sizeof(struct iphdr)); + skb_reset_network_header(skb); + + iph = ip_hdr(skb); + + iph->version = 4; + iph->ihl = sizeof(struct iphdr) >> 2; + iph->frag_off = ip_mtu_locked(&rt->dst) ? 
0 : df; + iph->protocol = proto; + iph->tos = tos; + iph->daddr = dst; + iph->saddr = src; + iph->ttl = ttl; + __ip_select_ident(net, iph, skb_shinfo(skb)->gso_segs ?: 1); + + err = ip_local_out(net, sk, skb); + + if (dev) { + if (unlikely(net_xmit_eval(err))) + pkt_len = 0; + iptunnel_xmit_stats(dev, pkt_len); + } +} +EXPORT_SYMBOL_GPL(iptunnel_xmit); + +int __iptunnel_pull_header(struct sk_buff *skb, int hdr_len, + __be16 inner_proto, bool raw_proto, bool xnet) +{ + if (unlikely(!pskb_may_pull(skb, hdr_len))) + return -ENOMEM; + + skb_pull_rcsum(skb, hdr_len); + + if (!raw_proto && inner_proto == htons(ETH_P_TEB)) { + struct ethhdr *eh; + + if (unlikely(!pskb_may_pull(skb, ETH_HLEN))) + return -ENOMEM; + + eh = (struct ethhdr *)skb->data; + if (likely(eth_proto_is_802_3(eh->h_proto))) + skb->protocol = eh->h_proto; + else + skb->protocol = htons(ETH_P_802_2); + + } else { + skb->protocol = inner_proto; + } + + skb_clear_hash_if_not_l4(skb); + __vlan_hwaccel_clear_tag(skb); + skb_set_queue_mapping(skb, 0); + skb_scrub_packet(skb, xnet); + + return iptunnel_pull_offloads(skb); +} +EXPORT_SYMBOL_GPL(__iptunnel_pull_header); + +struct metadata_dst *iptunnel_metadata_reply(struct metadata_dst *md, + gfp_t flags) +{ + struct metadata_dst *res; + struct ip_tunnel_info *dst, *src; + + if (!md || md->type != METADATA_IP_TUNNEL || + md->u.tun_info.mode & IP_TUNNEL_INFO_TX) + return NULL; + + src = &md->u.tun_info; + res = metadata_dst_alloc(src->options_len, METADATA_IP_TUNNEL, flags); + if (!res) + return NULL; + + dst = &res->u.tun_info; + dst->key.tun_id = src->key.tun_id; + if (src->mode & IP_TUNNEL_INFO_IPV6) + memcpy(&dst->key.u.ipv6.dst, &src->key.u.ipv6.src, + sizeof(struct in6_addr)); + else + dst->key.u.ipv4.dst = src->key.u.ipv4.src; + dst->key.tun_flags = src->key.tun_flags; + dst->mode = src->mode | IP_TUNNEL_INFO_TX; + ip_tunnel_info_opts_set(dst, ip_tunnel_info_opts(src), + src->options_len, 0); + + return res; +} +EXPORT_SYMBOL_GPL(iptunnel_metadata_reply); + +int iptunnel_handle_offloads(struct sk_buff *skb, + int gso_type_mask) +{ + int err; + + if (likely(!skb->encapsulation)) { + skb_reset_inner_headers(skb); + skb->encapsulation = 1; + } + + if (skb_is_gso(skb)) { + err = skb_header_unclone(skb, GFP_ATOMIC); + if (unlikely(err)) + return err; + skb_shinfo(skb)->gso_type |= gso_type_mask; + return 0; + } + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + skb->ip_summed = CHECKSUM_NONE; + /* We clear encapsulation here to prevent badly-written + * drivers potentially deciding to offload an inner checksum + * if we set CHECKSUM_PARTIAL on the outer header. + * This should go away when the drivers are all fixed. + */ + skb->encapsulation = 0; + } + + return 0; +} +EXPORT_SYMBOL_GPL(iptunnel_handle_offloads); + +/** + * iptunnel_pmtud_build_icmp() - Build ICMP error message for PMTUD + * @skb: Original packet with L2 header + * @mtu: MTU value for ICMP error + * + * Return: length on success, negative error code if message couldn't be built. 
+ */ +static int iptunnel_pmtud_build_icmp(struct sk_buff *skb, int mtu) +{ + const struct iphdr *iph = ip_hdr(skb); + struct icmphdr *icmph; + struct iphdr *niph; + struct ethhdr eh; + int len, err; + + if (!pskb_may_pull(skb, ETH_HLEN + sizeof(struct iphdr))) + return -EINVAL; + + skb_copy_bits(skb, skb_mac_offset(skb), &eh, ETH_HLEN); + pskb_pull(skb, ETH_HLEN); + skb_reset_network_header(skb); + + err = pskb_trim(skb, 576 - sizeof(*niph) - sizeof(*icmph)); + if (err) + return err; + + len = skb->len + sizeof(*icmph); + err = skb_cow(skb, sizeof(*niph) + sizeof(*icmph) + ETH_HLEN); + if (err) + return err; + + icmph = skb_push(skb, sizeof(*icmph)); + *icmph = (struct icmphdr) { + .type = ICMP_DEST_UNREACH, + .code = ICMP_FRAG_NEEDED, + .checksum = 0, + .un.frag.__unused = 0, + .un.frag.mtu = htons(mtu), + }; + icmph->checksum = csum_fold(skb_checksum(skb, 0, len, 0)); + skb_reset_transport_header(skb); + + niph = skb_push(skb, sizeof(*niph)); + *niph = (struct iphdr) { + .ihl = sizeof(*niph) / 4u, + .version = 4, + .tos = 0, + .tot_len = htons(len + sizeof(*niph)), + .id = 0, + .frag_off = htons(IP_DF), + .ttl = iph->ttl, + .protocol = IPPROTO_ICMP, + .saddr = iph->daddr, + .daddr = iph->saddr, + }; + ip_send_check(niph); + skb_reset_network_header(skb); + + skb->ip_summed = CHECKSUM_NONE; + + eth_header(skb, skb->dev, ntohs(eh.h_proto), eh.h_source, eh.h_dest, 0); + skb_reset_mac_header(skb); + + return skb->len; +} + +/** + * iptunnel_pmtud_check_icmp() - Trigger ICMP reply if needed and allowed + * @skb: Buffer being sent by encapsulation, L2 headers expected + * @mtu: Network MTU for path + * + * Return: 0 for no ICMP reply, length if built, negative value on error. + */ +static int iptunnel_pmtud_check_icmp(struct sk_buff *skb, int mtu) +{ + const struct icmphdr *icmph = icmp_hdr(skb); + const struct iphdr *iph = ip_hdr(skb); + + if (mtu < 576 || iph->frag_off != htons(IP_DF)) + return 0; + + if (ipv4_is_lbcast(iph->daddr) || ipv4_is_multicast(iph->daddr) || + ipv4_is_zeronet(iph->saddr) || ipv4_is_loopback(iph->saddr) || + ipv4_is_lbcast(iph->saddr) || ipv4_is_multicast(iph->saddr)) + return 0; + + if (iph->protocol == IPPROTO_ICMP && icmp_is_err(icmph->type)) + return 0; + + return iptunnel_pmtud_build_icmp(skb, mtu); +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * iptunnel_pmtud_build_icmpv6() - Build ICMPv6 error message for PMTUD + * @skb: Original packet with L2 header + * @mtu: MTU value for ICMPv6 error + * + * Return: length on success, negative error code if message couldn't be built. 
+ */ +static int iptunnel_pmtud_build_icmpv6(struct sk_buff *skb, int mtu) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + struct icmp6hdr *icmp6h; + struct ipv6hdr *nip6h; + struct ethhdr eh; + int len, err; + __wsum csum; + + if (!pskb_may_pull(skb, ETH_HLEN + sizeof(struct ipv6hdr))) + return -EINVAL; + + skb_copy_bits(skb, skb_mac_offset(skb), &eh, ETH_HLEN); + pskb_pull(skb, ETH_HLEN); + skb_reset_network_header(skb); + + err = pskb_trim(skb, IPV6_MIN_MTU - sizeof(*nip6h) - sizeof(*icmp6h)); + if (err) + return err; + + len = skb->len + sizeof(*icmp6h); + err = skb_cow(skb, sizeof(*nip6h) + sizeof(*icmp6h) + ETH_HLEN); + if (err) + return err; + + icmp6h = skb_push(skb, sizeof(*icmp6h)); + *icmp6h = (struct icmp6hdr) { + .icmp6_type = ICMPV6_PKT_TOOBIG, + .icmp6_code = 0, + .icmp6_cksum = 0, + .icmp6_mtu = htonl(mtu), + }; + skb_reset_transport_header(skb); + + nip6h = skb_push(skb, sizeof(*nip6h)); + *nip6h = (struct ipv6hdr) { + .priority = 0, + .version = 6, + .flow_lbl = { 0 }, + .payload_len = htons(len), + .nexthdr = IPPROTO_ICMPV6, + .hop_limit = ip6h->hop_limit, + .saddr = ip6h->daddr, + .daddr = ip6h->saddr, + }; + skb_reset_network_header(skb); + + csum = csum_partial(icmp6h, len, 0); + icmp6h->icmp6_cksum = csum_ipv6_magic(&nip6h->saddr, &nip6h->daddr, len, + IPPROTO_ICMPV6, csum); + + skb->ip_summed = CHECKSUM_NONE; + + eth_header(skb, skb->dev, ntohs(eh.h_proto), eh.h_source, eh.h_dest, 0); + skb_reset_mac_header(skb); + + return skb->len; +} + +/** + * iptunnel_pmtud_check_icmpv6() - Trigger ICMPv6 reply if needed and allowed + * @skb: Buffer being sent by encapsulation, L2 headers expected + * @mtu: Network MTU for path + * + * Return: 0 for no ICMPv6 reply, length if built, negative value on error. + */ +static int iptunnel_pmtud_check_icmpv6(struct sk_buff *skb, int mtu) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + int stype = ipv6_addr_type(&ip6h->saddr); + u8 proto = ip6h->nexthdr; + __be16 frag_off; + int offset; + + if (mtu < IPV6_MIN_MTU) + return 0; + + if (stype == IPV6_ADDR_ANY || stype == IPV6_ADDR_MULTICAST || + stype == IPV6_ADDR_LOOPBACK) + return 0; + + offset = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &proto, + &frag_off); + if (offset < 0 || (frag_off & htons(~0x7))) + return 0; + + if (proto == IPPROTO_ICMPV6) { + struct icmp6hdr *icmp6h; + + if (!pskb_may_pull(skb, skb_network_header(skb) + + offset + 1 - skb->data)) + return 0; + + icmp6h = (struct icmp6hdr *)(skb_network_header(skb) + offset); + if (icmpv6_is_err(icmp6h->icmp6_type) || + icmp6h->icmp6_type == NDISC_REDIRECT) + return 0; + } + + return iptunnel_pmtud_build_icmpv6(skb, mtu); +} +#endif /* IS_ENABLED(CONFIG_IPV6) */ + +/** + * skb_tunnel_check_pmtu() - Check, update PMTU and trigger ICMP reply as needed + * @skb: Buffer being sent by encapsulation, L2 headers expected + * @encap_dst: Destination for tunnel encapsulation (outer IP) + * @headroom: Encapsulation header size, bytes + * @reply: Build matching ICMP or ICMPv6 message as a result + * + * L2 tunnel implementations that can carry IP and can be directly bridged + * (currently UDP tunnels) can't always rely on IP forwarding paths to handle + * PMTU discovery. In the bridged case, ICMP or ICMPv6 messages need to be built + * based on payload and sent back by the encapsulation itself. + * + * For routable interfaces, we just need to update the PMTU for the destination. 
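 *
 * A worked example (illustrative, not part of the original comment): with a
 * VXLAN-style encapsulation over a 1500 byte IPv4 underlay, the headroom is
 * roughly 20 (outer IPv4) + 8 (UDP) + 8 (VXLAN) + 14 (inner Ethernet) = 50
 * bytes, so mtu = dst_mtu(encap_dst) - headroom = 1450.  A bridged 1500 byte
 * inner IPv4 packet with DF set then fails the length check below, the
 * tunnel path MTU is clamped to 1450, and iptunnel_pmtud_check_icmp()
 * bounces an ICMP "fragmentation needed, MTU 1450" back toward the inner
 * sender.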
+ * + * Return: 0 if ICMP error not needed, length if built, negative value on error + */ +int skb_tunnel_check_pmtu(struct sk_buff *skb, struct dst_entry *encap_dst, + int headroom, bool reply) +{ + u32 mtu = dst_mtu(encap_dst) - headroom; + + if ((skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) || + (!skb_is_gso(skb) && (skb->len - skb_network_offset(skb)) <= mtu)) + return 0; + + skb_dst_update_pmtu_no_confirm(skb, mtu); + + if (!reply || skb->pkt_type == PACKET_HOST) + return 0; + + if (skb->protocol == htons(ETH_P_IP)) + return iptunnel_pmtud_check_icmp(skb, mtu); + +#if IS_ENABLED(CONFIG_IPV6) + if (skb->protocol == htons(ETH_P_IPV6)) + return iptunnel_pmtud_check_icmpv6(skb, mtu); +#endif + return 0; +} +EXPORT_SYMBOL(skb_tunnel_check_pmtu); + +static const struct nla_policy ip_tun_policy[LWTUNNEL_IP_MAX + 1] = { + [LWTUNNEL_IP_UNSPEC] = { .strict_start_type = LWTUNNEL_IP_OPTS }, + [LWTUNNEL_IP_ID] = { .type = NLA_U64 }, + [LWTUNNEL_IP_DST] = { .type = NLA_U32 }, + [LWTUNNEL_IP_SRC] = { .type = NLA_U32 }, + [LWTUNNEL_IP_TTL] = { .type = NLA_U8 }, + [LWTUNNEL_IP_TOS] = { .type = NLA_U8 }, + [LWTUNNEL_IP_FLAGS] = { .type = NLA_U16 }, + [LWTUNNEL_IP_OPTS] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy ip_opts_policy[LWTUNNEL_IP_OPTS_MAX + 1] = { + [LWTUNNEL_IP_OPTS_GENEVE] = { .type = NLA_NESTED }, + [LWTUNNEL_IP_OPTS_VXLAN] = { .type = NLA_NESTED }, + [LWTUNNEL_IP_OPTS_ERSPAN] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy +geneve_opt_policy[LWTUNNEL_IP_OPT_GENEVE_MAX + 1] = { + [LWTUNNEL_IP_OPT_GENEVE_CLASS] = { .type = NLA_U16 }, + [LWTUNNEL_IP_OPT_GENEVE_TYPE] = { .type = NLA_U8 }, + [LWTUNNEL_IP_OPT_GENEVE_DATA] = { .type = NLA_BINARY, .len = 128 }, +}; + +static const struct nla_policy +vxlan_opt_policy[LWTUNNEL_IP_OPT_VXLAN_MAX + 1] = { + [LWTUNNEL_IP_OPT_VXLAN_GBP] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +erspan_opt_policy[LWTUNNEL_IP_OPT_ERSPAN_MAX + 1] = { + [LWTUNNEL_IP_OPT_ERSPAN_VER] = { .type = NLA_U8 }, + [LWTUNNEL_IP_OPT_ERSPAN_INDEX] = { .type = NLA_U32 }, + [LWTUNNEL_IP_OPT_ERSPAN_DIR] = { .type = NLA_U8 }, + [LWTUNNEL_IP_OPT_ERSPAN_HWID] = { .type = NLA_U8 }, +}; + +static int ip_tun_parse_opts_geneve(struct nlattr *attr, + struct ip_tunnel_info *info, int opts_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWTUNNEL_IP_OPT_GENEVE_MAX + 1]; + int data_len, err; + + err = nla_parse_nested(tb, LWTUNNEL_IP_OPT_GENEVE_MAX, attr, + geneve_opt_policy, extack); + if (err) + return err; + + if (!tb[LWTUNNEL_IP_OPT_GENEVE_CLASS] || + !tb[LWTUNNEL_IP_OPT_GENEVE_TYPE] || + !tb[LWTUNNEL_IP_OPT_GENEVE_DATA]) + return -EINVAL; + + attr = tb[LWTUNNEL_IP_OPT_GENEVE_DATA]; + data_len = nla_len(attr); + if (data_len % 4) + return -EINVAL; + + if (info) { + struct geneve_opt *opt = ip_tunnel_info_opts(info) + opts_len; + + memcpy(opt->opt_data, nla_data(attr), data_len); + opt->length = data_len / 4; + attr = tb[LWTUNNEL_IP_OPT_GENEVE_CLASS]; + opt->opt_class = nla_get_be16(attr); + attr = tb[LWTUNNEL_IP_OPT_GENEVE_TYPE]; + opt->type = nla_get_u8(attr); + info->key.tun_flags |= TUNNEL_GENEVE_OPT; + } + + return sizeof(struct geneve_opt) + data_len; +} + +static int ip_tun_parse_opts_vxlan(struct nlattr *attr, + struct ip_tunnel_info *info, int opts_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWTUNNEL_IP_OPT_VXLAN_MAX + 1]; + int err; + + err = nla_parse_nested(tb, LWTUNNEL_IP_OPT_VXLAN_MAX, attr, + vxlan_opt_policy, extack); + if (err) + return err; + + if (!tb[LWTUNNEL_IP_OPT_VXLAN_GBP]) 
+ return -EINVAL; + + if (info) { + struct vxlan_metadata *md = + ip_tunnel_info_opts(info) + opts_len; + + attr = tb[LWTUNNEL_IP_OPT_VXLAN_GBP]; + md->gbp = nla_get_u32(attr); + md->gbp &= VXLAN_GBP_MASK; + info->key.tun_flags |= TUNNEL_VXLAN_OPT; + } + + return sizeof(struct vxlan_metadata); +} + +static int ip_tun_parse_opts_erspan(struct nlattr *attr, + struct ip_tunnel_info *info, int opts_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWTUNNEL_IP_OPT_ERSPAN_MAX + 1]; + int err; + u8 ver; + + err = nla_parse_nested(tb, LWTUNNEL_IP_OPT_ERSPAN_MAX, attr, + erspan_opt_policy, extack); + if (err) + return err; + + if (!tb[LWTUNNEL_IP_OPT_ERSPAN_VER]) + return -EINVAL; + + ver = nla_get_u8(tb[LWTUNNEL_IP_OPT_ERSPAN_VER]); + if (ver == 1) { + if (!tb[LWTUNNEL_IP_OPT_ERSPAN_INDEX]) + return -EINVAL; + } else if (ver == 2) { + if (!tb[LWTUNNEL_IP_OPT_ERSPAN_DIR] || + !tb[LWTUNNEL_IP_OPT_ERSPAN_HWID]) + return -EINVAL; + } else { + return -EINVAL; + } + + if (info) { + struct erspan_metadata *md = + ip_tunnel_info_opts(info) + opts_len; + + md->version = ver; + if (ver == 1) { + attr = tb[LWTUNNEL_IP_OPT_ERSPAN_INDEX]; + md->u.index = nla_get_be32(attr); + } else { + attr = tb[LWTUNNEL_IP_OPT_ERSPAN_DIR]; + md->u.md2.dir = nla_get_u8(attr); + attr = tb[LWTUNNEL_IP_OPT_ERSPAN_HWID]; + set_hwid(&md->u.md2, nla_get_u8(attr)); + } + + info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + } + + return sizeof(struct erspan_metadata); +} + +static int ip_tun_parse_opts(struct nlattr *attr, struct ip_tunnel_info *info, + struct netlink_ext_ack *extack) +{ + int err, rem, opt_len, opts_len = 0; + struct nlattr *nla; + __be16 type = 0; + + if (!attr) + return 0; + + err = nla_validate(nla_data(attr), nla_len(attr), LWTUNNEL_IP_OPTS_MAX, + ip_opts_policy, extack); + if (err) + return err; + + nla_for_each_attr(nla, nla_data(attr), nla_len(attr), rem) { + switch (nla_type(nla)) { + case LWTUNNEL_IP_OPTS_GENEVE: + if (type && type != TUNNEL_GENEVE_OPT) + return -EINVAL; + opt_len = ip_tun_parse_opts_geneve(nla, info, opts_len, + extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + if (opts_len > IP_TUNNEL_OPTS_MAX) + return -EINVAL; + type = TUNNEL_GENEVE_OPT; + break; + case LWTUNNEL_IP_OPTS_VXLAN: + if (type) + return -EINVAL; + opt_len = ip_tun_parse_opts_vxlan(nla, info, opts_len, + extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + type = TUNNEL_VXLAN_OPT; + break; + case LWTUNNEL_IP_OPTS_ERSPAN: + if (type) + return -EINVAL; + opt_len = ip_tun_parse_opts_erspan(nla, info, opts_len, + extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + type = TUNNEL_ERSPAN_OPT; + break; + default: + return -EINVAL; + } + } + + return opts_len; +} + +static int ip_tun_get_optlen(struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + return ip_tun_parse_opts(attr, NULL, extack); +} + +static int ip_tun_set_opts(struct nlattr *attr, struct ip_tunnel_info *info, + struct netlink_ext_ack *extack) +{ + return ip_tun_parse_opts(attr, info, extack); +} + +static int ip_tun_build_state(struct net *net, struct nlattr *attr, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWTUNNEL_IP_MAX + 1]; + struct lwtunnel_state *new_state; + struct ip_tunnel_info *tun_info; + int err, opt_len; + + err = nla_parse_nested_deprecated(tb, LWTUNNEL_IP_MAX, attr, + ip_tun_policy, extack); + if (err < 0) + return err; + + opt_len = ip_tun_get_optlen(tb[LWTUNNEL_IP_OPTS], extack); + if (opt_len < 0) + 
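/*
 * Illustrative note (not part of this patch): the attribute nesting accepted
 * by the parsers above is
 *	LWTUNNEL_IP_OPTS
 *	  LWTUNNEL_IP_OPTS_GENEVE
 *	    LWTUNNEL_IP_OPT_GENEVE_CLASS (u16)
 *	    LWTUNNEL_IP_OPT_GENEVE_TYPE  (u8)
 *	    LWTUNNEL_IP_OPT_GENEVE_DATA  (binary, length a multiple of 4)
 * with VXLAN (GBP, u32) and ERSPAN (version plus index, or dir and hwid) as
 * mutually exclusive alternatives.  From userspace this typically arrives
 * via an iproute2 route attribute along the lines of
 *	ip route add 198.51.100.0/24 encap ip id 42 dst 192.0.2.1 dev tnl0
 * (the command is shown as an assumption of typical usage, not taken from
 * this patch).
 */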
return opt_len; + + new_state = lwtunnel_state_alloc(sizeof(*tun_info) + opt_len); + if (!new_state) + return -ENOMEM; + + new_state->type = LWTUNNEL_ENCAP_IP; + + tun_info = lwt_tun_info(new_state); + + err = ip_tun_set_opts(tb[LWTUNNEL_IP_OPTS], tun_info, extack); + if (err < 0) { + lwtstate_free(new_state); + return err; + } + +#ifdef CONFIG_DST_CACHE + err = dst_cache_init(&tun_info->dst_cache, GFP_KERNEL); + if (err) { + lwtstate_free(new_state); + return err; + } +#endif + + if (tb[LWTUNNEL_IP_ID]) + tun_info->key.tun_id = nla_get_be64(tb[LWTUNNEL_IP_ID]); + + if (tb[LWTUNNEL_IP_DST]) + tun_info->key.u.ipv4.dst = nla_get_in_addr(tb[LWTUNNEL_IP_DST]); + + if (tb[LWTUNNEL_IP_SRC]) + tun_info->key.u.ipv4.src = nla_get_in_addr(tb[LWTUNNEL_IP_SRC]); + + if (tb[LWTUNNEL_IP_TTL]) + tun_info->key.ttl = nla_get_u8(tb[LWTUNNEL_IP_TTL]); + + if (tb[LWTUNNEL_IP_TOS]) + tun_info->key.tos = nla_get_u8(tb[LWTUNNEL_IP_TOS]); + + if (tb[LWTUNNEL_IP_FLAGS]) + tun_info->key.tun_flags |= + (nla_get_be16(tb[LWTUNNEL_IP_FLAGS]) & + ~TUNNEL_OPTIONS_PRESENT); + + tun_info->mode = IP_TUNNEL_INFO_TX; + tun_info->options_len = opt_len; + + *ts = new_state; + + return 0; +} + +static void ip_tun_destroy_state(struct lwtunnel_state *lwtstate) +{ +#ifdef CONFIG_DST_CACHE + struct ip_tunnel_info *tun_info = lwt_tun_info(lwtstate); + + dst_cache_destroy(&tun_info->dst_cache); +#endif +} + +static int ip_tun_fill_encap_opts_geneve(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + struct geneve_opt *opt; + struct nlattr *nest; + int offset = 0; + + nest = nla_nest_start_noflag(skb, LWTUNNEL_IP_OPTS_GENEVE); + if (!nest) + return -ENOMEM; + + while (tun_info->options_len > offset) { + opt = ip_tunnel_info_opts(tun_info) + offset; + if (nla_put_be16(skb, LWTUNNEL_IP_OPT_GENEVE_CLASS, + opt->opt_class) || + nla_put_u8(skb, LWTUNNEL_IP_OPT_GENEVE_TYPE, opt->type) || + nla_put(skb, LWTUNNEL_IP_OPT_GENEVE_DATA, opt->length * 4, + opt->opt_data)) { + nla_nest_cancel(skb, nest); + return -ENOMEM; + } + offset += sizeof(*opt) + opt->length * 4; + } + + nla_nest_end(skb, nest); + return 0; +} + +static int ip_tun_fill_encap_opts_vxlan(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + struct vxlan_metadata *md; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, LWTUNNEL_IP_OPTS_VXLAN); + if (!nest) + return -ENOMEM; + + md = ip_tunnel_info_opts(tun_info); + if (nla_put_u32(skb, LWTUNNEL_IP_OPT_VXLAN_GBP, md->gbp)) { + nla_nest_cancel(skb, nest); + return -ENOMEM; + } + + nla_nest_end(skb, nest); + return 0; +} + +static int ip_tun_fill_encap_opts_erspan(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + struct erspan_metadata *md; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, LWTUNNEL_IP_OPTS_ERSPAN); + if (!nest) + return -ENOMEM; + + md = ip_tunnel_info_opts(tun_info); + if (nla_put_u8(skb, LWTUNNEL_IP_OPT_ERSPAN_VER, md->version)) + goto err; + + if (md->version == 1 && + nla_put_be32(skb, LWTUNNEL_IP_OPT_ERSPAN_INDEX, md->u.index)) + goto err; + + if (md->version == 2 && + (nla_put_u8(skb, LWTUNNEL_IP_OPT_ERSPAN_DIR, md->u.md2.dir) || + nla_put_u8(skb, LWTUNNEL_IP_OPT_ERSPAN_HWID, + get_hwid(&md->u.md2)))) + goto err; + + nla_nest_end(skb, nest); + return 0; +err: + nla_nest_cancel(skb, nest); + return -ENOMEM; +} + +static int ip_tun_fill_encap_opts(struct sk_buff *skb, int type, + struct ip_tunnel_info *tun_info) +{ + struct nlattr *nest; + int err = 0; + + if (!(tun_info->key.tun_flags & TUNNEL_OPTIONS_PRESENT)) + return 0; + + nest = nla_nest_start_noflag(skb, 
type); + if (!nest) + return -ENOMEM; + + if (tun_info->key.tun_flags & TUNNEL_GENEVE_OPT) + err = ip_tun_fill_encap_opts_geneve(skb, tun_info); + else if (tun_info->key.tun_flags & TUNNEL_VXLAN_OPT) + err = ip_tun_fill_encap_opts_vxlan(skb, tun_info); + else if (tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT) + err = ip_tun_fill_encap_opts_erspan(skb, tun_info); + + if (err) { + nla_nest_cancel(skb, nest); + return err; + } + + nla_nest_end(skb, nest); + return 0; +} + +static int ip_tun_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct ip_tunnel_info *tun_info = lwt_tun_info(lwtstate); + + if (nla_put_be64(skb, LWTUNNEL_IP_ID, tun_info->key.tun_id, + LWTUNNEL_IP_PAD) || + nla_put_in_addr(skb, LWTUNNEL_IP_DST, tun_info->key.u.ipv4.dst) || + nla_put_in_addr(skb, LWTUNNEL_IP_SRC, tun_info->key.u.ipv4.src) || + nla_put_u8(skb, LWTUNNEL_IP_TOS, tun_info->key.tos) || + nla_put_u8(skb, LWTUNNEL_IP_TTL, tun_info->key.ttl) || + nla_put_be16(skb, LWTUNNEL_IP_FLAGS, tun_info->key.tun_flags) || + ip_tun_fill_encap_opts(skb, LWTUNNEL_IP_OPTS, tun_info)) + return -ENOMEM; + + return 0; +} + +static int ip_tun_opts_nlsize(struct ip_tunnel_info *info) +{ + int opt_len; + + if (!(info->key.tun_flags & TUNNEL_OPTIONS_PRESENT)) + return 0; + + opt_len = nla_total_size(0); /* LWTUNNEL_IP_OPTS */ + if (info->key.tun_flags & TUNNEL_GENEVE_OPT) { + struct geneve_opt *opt; + int offset = 0; + + opt_len += nla_total_size(0); /* LWTUNNEL_IP_OPTS_GENEVE */ + while (info->options_len > offset) { + opt = ip_tunnel_info_opts(info) + offset; + opt_len += nla_total_size(2) /* OPT_GENEVE_CLASS */ + + nla_total_size(1) /* OPT_GENEVE_TYPE */ + + nla_total_size(opt->length * 4); + /* OPT_GENEVE_DATA */ + offset += sizeof(*opt) + opt->length * 4; + } + } else if (info->key.tun_flags & TUNNEL_VXLAN_OPT) { + opt_len += nla_total_size(0) /* LWTUNNEL_IP_OPTS_VXLAN */ + + nla_total_size(4); /* OPT_VXLAN_GBP */ + } else if (info->key.tun_flags & TUNNEL_ERSPAN_OPT) { + struct erspan_metadata *md = ip_tunnel_info_opts(info); + + opt_len += nla_total_size(0) /* LWTUNNEL_IP_OPTS_ERSPAN */ + + nla_total_size(1) /* OPT_ERSPAN_VER */ + + (md->version == 1 ? 
nla_total_size(4) + /* OPT_ERSPAN_INDEX (v1) */ + : nla_total_size(1) + + nla_total_size(1)); + /* OPT_ERSPAN_DIR + HWID (v2) */ + } + + return opt_len; +} + +static int ip_tun_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + return nla_total_size_64bit(8) /* LWTUNNEL_IP_ID */ + + nla_total_size(4) /* LWTUNNEL_IP_DST */ + + nla_total_size(4) /* LWTUNNEL_IP_SRC */ + + nla_total_size(1) /* LWTUNNEL_IP_TOS */ + + nla_total_size(1) /* LWTUNNEL_IP_TTL */ + + nla_total_size(2) /* LWTUNNEL_IP_FLAGS */ + + ip_tun_opts_nlsize(lwt_tun_info(lwtstate)); + /* LWTUNNEL_IP_OPTS */ +} + +static int ip_tun_cmp_encap(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct ip_tunnel_info *info_a = lwt_tun_info(a); + struct ip_tunnel_info *info_b = lwt_tun_info(b); + + return memcmp(info_a, info_b, sizeof(info_a->key)) || + info_a->mode != info_b->mode || + info_a->options_len != info_b->options_len || + memcmp(ip_tunnel_info_opts(info_a), + ip_tunnel_info_opts(info_b), info_a->options_len); +} + +static const struct lwtunnel_encap_ops ip_tun_lwt_ops = { + .build_state = ip_tun_build_state, + .destroy_state = ip_tun_destroy_state, + .fill_encap = ip_tun_fill_encap_info, + .get_encap_size = ip_tun_encap_nlsize, + .cmp_encap = ip_tun_cmp_encap, + .owner = THIS_MODULE, +}; + +static const struct nla_policy ip6_tun_policy[LWTUNNEL_IP6_MAX + 1] = { + [LWTUNNEL_IP6_UNSPEC] = { .strict_start_type = LWTUNNEL_IP6_OPTS }, + [LWTUNNEL_IP6_ID] = { .type = NLA_U64 }, + [LWTUNNEL_IP6_DST] = { .len = sizeof(struct in6_addr) }, + [LWTUNNEL_IP6_SRC] = { .len = sizeof(struct in6_addr) }, + [LWTUNNEL_IP6_HOPLIMIT] = { .type = NLA_U8 }, + [LWTUNNEL_IP6_TC] = { .type = NLA_U8 }, + [LWTUNNEL_IP6_FLAGS] = { .type = NLA_U16 }, + [LWTUNNEL_IP6_OPTS] = { .type = NLA_NESTED }, +}; + +static int ip6_tun_build_state(struct net *net, struct nlattr *attr, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWTUNNEL_IP6_MAX + 1]; + struct lwtunnel_state *new_state; + struct ip_tunnel_info *tun_info; + int err, opt_len; + + err = nla_parse_nested_deprecated(tb, LWTUNNEL_IP6_MAX, attr, + ip6_tun_policy, extack); + if (err < 0) + return err; + + opt_len = ip_tun_get_optlen(tb[LWTUNNEL_IP6_OPTS], extack); + if (opt_len < 0) + return opt_len; + + new_state = lwtunnel_state_alloc(sizeof(*tun_info) + opt_len); + if (!new_state) + return -ENOMEM; + + new_state->type = LWTUNNEL_ENCAP_IP6; + + tun_info = lwt_tun_info(new_state); + + err = ip_tun_set_opts(tb[LWTUNNEL_IP6_OPTS], tun_info, extack); + if (err < 0) { + lwtstate_free(new_state); + return err; + } + + if (tb[LWTUNNEL_IP6_ID]) + tun_info->key.tun_id = nla_get_be64(tb[LWTUNNEL_IP6_ID]); + + if (tb[LWTUNNEL_IP6_DST]) + tun_info->key.u.ipv6.dst = nla_get_in6_addr(tb[LWTUNNEL_IP6_DST]); + + if (tb[LWTUNNEL_IP6_SRC]) + tun_info->key.u.ipv6.src = nla_get_in6_addr(tb[LWTUNNEL_IP6_SRC]); + + if (tb[LWTUNNEL_IP6_HOPLIMIT]) + tun_info->key.ttl = nla_get_u8(tb[LWTUNNEL_IP6_HOPLIMIT]); + + if (tb[LWTUNNEL_IP6_TC]) + tun_info->key.tos = nla_get_u8(tb[LWTUNNEL_IP6_TC]); + + if (tb[LWTUNNEL_IP6_FLAGS]) + tun_info->key.tun_flags |= + (nla_get_be16(tb[LWTUNNEL_IP6_FLAGS]) & + ~TUNNEL_OPTIONS_PRESENT); + + tun_info->mode = IP_TUNNEL_INFO_TX | IP_TUNNEL_INFO_IPV6; + tun_info->options_len = opt_len; + + *ts = new_state; + + return 0; +} + +static int ip6_tun_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct ip_tunnel_info *tun_info = lwt_tun_info(lwtstate); + + if (nla_put_be64(skb, 
LWTUNNEL_IP6_ID, tun_info->key.tun_id, + LWTUNNEL_IP6_PAD) || + nla_put_in6_addr(skb, LWTUNNEL_IP6_DST, &tun_info->key.u.ipv6.dst) || + nla_put_in6_addr(skb, LWTUNNEL_IP6_SRC, &tun_info->key.u.ipv6.src) || + nla_put_u8(skb, LWTUNNEL_IP6_TC, tun_info->key.tos) || + nla_put_u8(skb, LWTUNNEL_IP6_HOPLIMIT, tun_info->key.ttl) || + nla_put_be16(skb, LWTUNNEL_IP6_FLAGS, tun_info->key.tun_flags) || + ip_tun_fill_encap_opts(skb, LWTUNNEL_IP6_OPTS, tun_info)) + return -ENOMEM; + + return 0; +} + +static int ip6_tun_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + return nla_total_size_64bit(8) /* LWTUNNEL_IP6_ID */ + + nla_total_size(16) /* LWTUNNEL_IP6_DST */ + + nla_total_size(16) /* LWTUNNEL_IP6_SRC */ + + nla_total_size(1) /* LWTUNNEL_IP6_HOPLIMIT */ + + nla_total_size(1) /* LWTUNNEL_IP6_TC */ + + nla_total_size(2) /* LWTUNNEL_IP6_FLAGS */ + + ip_tun_opts_nlsize(lwt_tun_info(lwtstate)); + /* LWTUNNEL_IP6_OPTS */ +} + +static const struct lwtunnel_encap_ops ip6_tun_lwt_ops = { + .build_state = ip6_tun_build_state, + .fill_encap = ip6_tun_fill_encap_info, + .get_encap_size = ip6_tun_encap_nlsize, + .cmp_encap = ip_tun_cmp_encap, + .owner = THIS_MODULE, +}; + +void __init ip_tunnel_core_init(void) +{ + /* If you land here, make sure whether increasing ip_tunnel_info's + * options_len is a reasonable choice with its usage in front ends + * (f.e., it's part of flow keys, etc). + */ + BUILD_BUG_ON(IP_TUNNEL_OPTS_MAX != 255); + + lwtunnel_encap_add_ops(&ip_tun_lwt_ops, LWTUNNEL_ENCAP_IP); + lwtunnel_encap_add_ops(&ip6_tun_lwt_ops, LWTUNNEL_ENCAP_IP6); +} + +DEFINE_STATIC_KEY_FALSE(ip_tunnel_metadata_cnt); +EXPORT_SYMBOL(ip_tunnel_metadata_cnt); + +void ip_tunnel_need_metadata(void) +{ + static_branch_inc(&ip_tunnel_metadata_cnt); +} +EXPORT_SYMBOL_GPL(ip_tunnel_need_metadata); + +void ip_tunnel_unneed_metadata(void) +{ + static_branch_dec(&ip_tunnel_metadata_cnt); +} +EXPORT_SYMBOL_GPL(ip_tunnel_unneed_metadata); + +/* Returns either the correct skb->protocol value, or 0 if invalid. 
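/*
 * Editorial sketch, not part of the upstream patch: the static key above is
 * reference-counted, so code that starts relying on tunnel metadata is
 * expected to pair ip_tunnel_need_metadata() with ip_tunnel_unneed_metadata().
 * A hypothetical module that wants metadata handling for its whole lifetime
 * could take and drop the reference from its init/exit hooks; nothing beyond
 * the two helpers exported above is assumed.
 */
#include <linux/module.h>
#include <net/ip_tunnels.h>

static int __init md_user_init(void)
{
	ip_tunnel_need_metadata();	/* flip the collect-md fast path on */
	return 0;
}

static void __exit md_user_exit(void)
{
	ip_tunnel_unneed_metadata();	/* drop our reference again */
}

module_init(md_user_init);
module_exit(md_user_exit);
MODULE_LICENSE("GPL");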
*/ +__be16 ip_tunnel_parse_protocol(const struct sk_buff *skb) +{ + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct iphdr)) <= skb_tail_pointer(skb) && + ip_hdr(skb)->version == 4) + return htons(ETH_P_IP); + if (skb_network_header(skb) >= skb->head && + (skb_network_header(skb) + sizeof(struct ipv6hdr)) <= skb_tail_pointer(skb) && + ipv6_hdr(skb)->version == 6) + return htons(ETH_P_IPV6); + return 0; +} +EXPORT_SYMBOL(ip_tunnel_parse_protocol); + +const struct header_ops ip_tunnel_header_ops = { .parse_protocol = ip_tunnel_parse_protocol }; +EXPORT_SYMBOL(ip_tunnel_header_ops); + +/* This function returns true when ENCAP attributes are present in the nl msg */ +bool ip_tunnel_netlink_encap_parms(struct nlattr *data[], + struct ip_tunnel_encap *encap) +{ + bool ret = false; + + memset(encap, 0, sizeof(*encap)); + + if (!data) + return ret; + + if (data[IFLA_IPTUN_ENCAP_TYPE]) { + ret = true; + encap->type = nla_get_u16(data[IFLA_IPTUN_ENCAP_TYPE]); + } + + if (data[IFLA_IPTUN_ENCAP_FLAGS]) { + ret = true; + encap->flags = nla_get_u16(data[IFLA_IPTUN_ENCAP_FLAGS]); + } + + if (data[IFLA_IPTUN_ENCAP_SPORT]) { + ret = true; + encap->sport = nla_get_be16(data[IFLA_IPTUN_ENCAP_SPORT]); + } + + if (data[IFLA_IPTUN_ENCAP_DPORT]) { + ret = true; + encap->dport = nla_get_be16(data[IFLA_IPTUN_ENCAP_DPORT]); + } + + return ret; +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_encap_parms); + +void ip_tunnel_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm *parms) +{ + if (data[IFLA_IPTUN_LINK]) + parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); + + if (data[IFLA_IPTUN_LOCAL]) + parms->iph.saddr = nla_get_be32(data[IFLA_IPTUN_LOCAL]); + + if (data[IFLA_IPTUN_REMOTE]) + parms->iph.daddr = nla_get_be32(data[IFLA_IPTUN_REMOTE]); + + if (data[IFLA_IPTUN_TTL]) { + parms->iph.ttl = nla_get_u8(data[IFLA_IPTUN_TTL]); + if (parms->iph.ttl) + parms->iph.frag_off = htons(IP_DF); + } + + if (data[IFLA_IPTUN_TOS]) + parms->iph.tos = nla_get_u8(data[IFLA_IPTUN_TOS]); + + if (!data[IFLA_IPTUN_PMTUDISC] || nla_get_u8(data[IFLA_IPTUN_PMTUDISC])) + parms->iph.frag_off = htons(IP_DF); + + if (data[IFLA_IPTUN_FLAGS]) + parms->i_flags = nla_get_be16(data[IFLA_IPTUN_FLAGS]); + + if (data[IFLA_IPTUN_PROTO]) + parms->iph.protocol = nla_get_u8(data[IFLA_IPTUN_PROTO]); +} +EXPORT_SYMBOL_GPL(ip_tunnel_netlink_parms); diff --git a/net/ipv4/ip_vti.c b/net/ipv4/ip_vti.c new file mode 100644 index 000000000..615c1dcf3 --- /dev/null +++ b/net/ipv4/ip_vti.c @@ -0,0 +1,726 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux NET3: IP/IP protocol decoder modified to support + * virtual tunnel interface + * + * Authors: + * Saurabh Mohan (saurabh.mohan@vyatta.com) 05/07/2012 + */ + +/* + This version of net/ipv4/ip_vti.c is cloned of net/ipv4/ipip.c + + For comments look at net/ipv4/ip_gre.c --ANK + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +static struct rtnl_link_ops vti_link_ops __read_mostly; + +static unsigned int vti_net_id __read_mostly; +static int vti_tunnel_init(struct net_device *dev); + +static int vti_input(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type, bool update_skb_dev) +{ + struct ip_tunnel *tunnel; + const struct iphdr *iph = ip_hdr(skb); + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn = net_generic(net, vti_net_id); + + tunnel = 
ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + iph->saddr, iph->daddr, 0); + if (tunnel) { + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = tunnel; + + if (update_skb_dev) + skb->dev = tunnel->dev; + + return xfrm_input(skb, nexthdr, spi, encap_type); + } + + return -EINVAL; +drop: + kfree_skb(skb); + return 0; +} + +static int vti_input_proto(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type) +{ + return vti_input(skb, nexthdr, spi, encap_type, false); +} + +static int vti_rcv(struct sk_buff *skb, __be32 spi, bool update_skb_dev) +{ + XFRM_SPI_SKB_CB(skb)->family = AF_INET; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + + return vti_input(skb, ip_hdr(skb)->protocol, spi, 0, update_skb_dev); +} + +static int vti_rcv_proto(struct sk_buff *skb) +{ + return vti_rcv(skb, 0, false); +} + +static int vti_rcv_cb(struct sk_buff *skb, int err) +{ + unsigned short family; + struct net_device *dev; + struct xfrm_state *x; + const struct xfrm_mode *inner_mode; + struct ip_tunnel *tunnel = XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4; + u32 orig_mark = skb->mark; + int ret; + + if (!tunnel) + return 1; + + dev = tunnel->dev; + + if (err) { + dev->stats.rx_errors++; + dev->stats.rx_dropped++; + + return 0; + } + + x = xfrm_input_state(skb); + + inner_mode = &x->inner_mode; + + if (x->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol); + if (inner_mode == NULL) { + XFRM_INC_STATS(dev_net(skb->dev), + LINUX_MIB_XFRMINSTATEMODEERROR); + return -EINVAL; + } + } + + family = inner_mode->family; + + skb->mark = be32_to_cpu(tunnel->parms.i_key); + ret = xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, family); + skb->mark = orig_mark; + + if (!ret) + return -EPERM; + + skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(skb->dev))); + skb->dev = dev; + dev_sw_netstats_rx_add(dev, skb->len); + + return 0; +} + +static bool vti_state_check(const struct xfrm_state *x, __be32 dst, __be32 src) +{ + xfrm_address_t *daddr = (xfrm_address_t *)&dst; + xfrm_address_t *saddr = (xfrm_address_t *)&src; + + /* if there is no transform then this tunnel is not functional. + * Or if the xfrm is not mode tunnel. 
+ */ + if (!x || x->props.mode != XFRM_MODE_TUNNEL || + x->props.family != AF_INET) + return false; + + if (!dst) + return xfrm_addr_equal(saddr, &x->props.saddr, AF_INET); + + if (!xfrm_state_addr_check(x, daddr, saddr, AF_INET)) + return false; + + return true; +} + +static netdev_tx_t vti_xmit(struct sk_buff *skb, struct net_device *dev, + struct flowi *fl) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_parm *parms = &tunnel->parms; + struct dst_entry *dst = skb_dst(skb); + struct net_device *tdev; /* Device to other host */ + int pkt_len = skb->len; + int err; + int mtu; + + if (!dst) { + switch (skb->protocol) { + case htons(ETH_P_IP): { + struct rtable *rt; + + fl->u.ip4.flowi4_oif = dev->ifindex; + fl->u.ip4.flowi4_flags |= FLOWI_FLAG_ANYSRC; + rt = __ip_route_output_key(dev_net(dev), &fl->u.ip4); + if (IS_ERR(rt)) { + dev->stats.tx_carrier_errors++; + goto tx_error_icmp; + } + dst = &rt->dst; + skb_dst_set(skb, dst); + break; + } +#if IS_ENABLED(CONFIG_IPV6) + case htons(ETH_P_IPV6): + fl->u.ip6.flowi6_oif = dev->ifindex; + fl->u.ip6.flowi6_flags |= FLOWI_FLAG_ANYSRC; + dst = ip6_route_output(dev_net(dev), NULL, &fl->u.ip6); + if (dst->error) { + dst_release(dst); + dst = NULL; + dev->stats.tx_carrier_errors++; + goto tx_error_icmp; + } + skb_dst_set(skb, dst); + break; +#endif + default: + dev->stats.tx_carrier_errors++; + goto tx_error_icmp; + } + } + + dst_hold(dst); + dst = xfrm_lookup_route(tunnel->net, dst, fl, NULL, 0); + if (IS_ERR(dst)) { + dev->stats.tx_carrier_errors++; + goto tx_error_icmp; + } + + if (dst->flags & DST_XFRM_QUEUE) + goto xmit; + + if (!vti_state_check(dst->xfrm, parms->iph.daddr, parms->iph.saddr)) { + dev->stats.tx_carrier_errors++; + dst_release(dst); + goto tx_error_icmp; + } + + tdev = dst->dev; + + if (tdev == dev) { + dst_release(dst); + dev->stats.collisions++; + goto tx_error; + } + + mtu = dst_mtu(dst); + if (skb->len > mtu) { + skb_dst_update_pmtu_no_confirm(skb, mtu); + if (skb->protocol == htons(ETH_P_IP)) { + if (!(ip_hdr(skb)->frag_off & htons(IP_DF))) + goto xmit; + icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + } else { + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + } + + dst_release(dst); + goto tx_error; + } + +xmit: + skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(dev))); + skb_dst_set(skb, dst); + skb->dev = skb_dst(skb)->dev; + + err = dst_output(tunnel->net, skb->sk, skb); + if (net_xmit_eval(err) == 0) + err = pkt_len; + iptunnel_xmit_stats(dev, err); + return NETDEV_TX_OK; + +tx_error_icmp: + dst_link_failure(skb); +tx_error: + dev->stats.tx_errors++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +/* This function assumes it is being called from dev_queue_xmit() + * and that skb is filled properly by that function. 
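/*
 * Editorial sketch, not in the upstream patch: after recording the new path
 * MTU, the oversize-packet branch in vti_xmit() above reduces to a small
 * decision table.  The helper below restates it in plain C so the behaviour
 * is easy to eyeball; the enum and function names are invented for the demo,
 * and the two ICMP cases also imply dropping the packet, as vti_xmit() does.
 */
#include <stdbool.h>

#define IPV6_MIN_MTU	1280

enum pmtu_action { PMTU_FORWARD, PMTU_SEND_FRAG_NEEDED, PMTU_SEND_PKT_TOOBIG };

static enum pmtu_action pmtu_decide(bool is_ipv4, bool df_set,
				    unsigned int pkt_len, unsigned int *mtu)
{
	if (pkt_len <= *mtu)
		return PMTU_FORWARD;		/* fits, just transmit */

	if (is_ipv4)				/* only DF packets get ICMP back */
		return df_set ? PMTU_SEND_FRAG_NEEDED : PMTU_FORWARD;

	if (*mtu < IPV6_MIN_MTU)		/* never advertise less than 1280 */
		*mtu = IPV6_MIN_MTU;
	return PMTU_SEND_PKT_TOOBIG;
}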
+ */ +static netdev_tx_t vti_tunnel_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct flowi fl; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + memset(&fl, 0, sizeof(fl)); + + switch (skb->protocol) { + case htons(ETH_P_IP): + memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + xfrm_decode_session(skb, &fl, AF_INET); + break; + case htons(ETH_P_IPV6): + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + xfrm_decode_session(skb, &fl, AF_INET6); + break; + default: + goto tx_err; + } + + /* override mark with tunnel output key */ + fl.flowi_mark = be32_to_cpu(tunnel->parms.o_key); + + return vti_xmit(skb, dev, &fl); + +tx_err: + dev->stats.tx_errors++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int vti4_err(struct sk_buff *skb, u32 info) +{ + __be32 spi; + __u32 mark; + struct xfrm_state *x; + struct ip_tunnel *tunnel; + struct ip_esp_hdr *esph; + struct ip_auth_hdr *ah ; + struct ip_comp_hdr *ipch; + struct net *net = dev_net(skb->dev); + const struct iphdr *iph = (const struct iphdr *)skb->data; + int protocol = iph->protocol; + struct ip_tunnel_net *itn = net_generic(net, vti_net_id); + + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + iph->daddr, iph->saddr, 0); + if (!tunnel) + return -1; + + mark = be32_to_cpu(tunnel->parms.o_key); + + switch (protocol) { + case IPPROTO_ESP: + esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2)); + spi = esph->spi; + break; + case IPPROTO_AH: + ah = (struct ip_auth_hdr *)(skb->data+(iph->ihl<<2)); + spi = ah->spi; + break; + case IPPROTO_COMP: + ipch = (struct ip_comp_hdr *)(skb->data+(iph->ihl<<2)); + spi = htonl(ntohs(ipch->cpi)); + break; + default: + return 0; + } + + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED) + return 0; + break; + case ICMP_REDIRECT: + break; + default: + return 0; + } + + x = xfrm_state_lookup(net, mark, (const xfrm_address_t *)&iph->daddr, + spi, protocol, AF_INET); + if (!x) + return 0; + + if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH) + ipv4_update_pmtu(skb, net, info, 0, protocol); + else + ipv4_redirect(skb, net, 0, protocol); + xfrm_state_put(x); + + return 0; +} + +static int +vti_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +{ + int err = 0; + + if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { + if (p->iph.version != 4 || p->iph.protocol != IPPROTO_IPIP || + p->iph.ihl != 5) + return -EINVAL; + } + + if (!(p->i_flags & GRE_KEY)) + p->i_key = 0; + if (!(p->o_flags & GRE_KEY)) + p->o_key = 0; + + p->i_flags = VTI_ISVTI; + + err = ip_tunnel_ctl(dev, p, cmd); + if (err) + return err; + + if (cmd != SIOCDELTUNNEL) { + p->i_flags |= GRE_KEY; + p->o_flags |= GRE_KEY; + } + return 0; +} + +static const struct net_device_ops vti_netdev_ops = { + .ndo_init = vti_tunnel_init, + .ndo_uninit = ip_tunnel_uninit, + .ndo_start_xmit = vti_tunnel_xmit, + .ndo_siocdevprivate = ip_tunnel_siocdevprivate, + .ndo_change_mtu = ip_tunnel_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_tunnel_ctl = vti_tunnel_ctl, +}; + +static void vti_tunnel_setup(struct net_device *dev) +{ + dev->netdev_ops = &vti_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + dev->type = ARPHRD_TUNNEL; + ip_tunnel_setup(dev, vti_net_id); +} + +static int vti_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct iphdr *iph = &tunnel->parms.iph; + + __dev_addr_set(dev, &iph->saddr, 4); + 
memcpy(dev->broadcast, &iph->daddr, 4); + + dev->flags = IFF_NOARP; + dev->addr_len = 4; + dev->features |= NETIF_F_LLTX; + netif_keep_dst(dev); + + return ip_tunnel_init(dev); +} + +static void __net_init vti_fb_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct iphdr *iph = &tunnel->parms.iph; + + iph->version = 4; + iph->protocol = IPPROTO_IPIP; + iph->ihl = 5; +} + +static struct xfrm4_protocol vti_esp4_protocol __read_mostly = { + .handler = vti_rcv_proto, + .input_handler = vti_input_proto, + .cb_handler = vti_rcv_cb, + .err_handler = vti4_err, + .priority = 100, +}; + +static struct xfrm4_protocol vti_ah4_protocol __read_mostly = { + .handler = vti_rcv_proto, + .input_handler = vti_input_proto, + .cb_handler = vti_rcv_cb, + .err_handler = vti4_err, + .priority = 100, +}; + +static struct xfrm4_protocol vti_ipcomp4_protocol __read_mostly = { + .handler = vti_rcv_proto, + .input_handler = vti_input_proto, + .cb_handler = vti_rcv_cb, + .err_handler = vti4_err, + .priority = 100, +}; + +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) +static int vti_rcv_tunnel(struct sk_buff *skb) +{ + XFRM_SPI_SKB_CB(skb)->family = AF_INET; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + + return vti_input(skb, IPPROTO_IPIP, ip_hdr(skb)->saddr, 0, false); +} + +static struct xfrm_tunnel vti_ipip_handler __read_mostly = { + .handler = vti_rcv_tunnel, + .cb_handler = vti_rcv_cb, + .err_handler = vti4_err, + .priority = 0, +}; + +#if IS_ENABLED(CONFIG_IPV6) +static struct xfrm_tunnel vti_ipip6_handler __read_mostly = { + .handler = vti_rcv_tunnel, + .cb_handler = vti_rcv_cb, + .err_handler = vti4_err, + .priority = 0, +}; +#endif +#endif + +static int __net_init vti_init_net(struct net *net) +{ + int err; + struct ip_tunnel_net *itn; + + err = ip_tunnel_init_net(net, vti_net_id, &vti_link_ops, "ip_vti0"); + if (err) + return err; + itn = net_generic(net, vti_net_id); + if (itn->fb_tunnel_dev) + vti_fb_tunnel_init(itn->fb_tunnel_dev); + return 0; +} + +static void __net_exit vti_exit_batch_net(struct list_head *list_net) +{ + ip_tunnel_delete_nets(list_net, vti_net_id, &vti_link_ops); +} + +static struct pernet_operations vti_net_ops = { + .init = vti_init_net, + .exit_batch = vti_exit_batch_net, + .id = &vti_net_id, + .size = sizeof(struct ip_tunnel_net), +}; + +static int vti_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + return 0; +} + +static void vti_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm *parms, + __u32 *fwmark) +{ + memset(parms, 0, sizeof(*parms)); + + parms->iph.protocol = IPPROTO_IPIP; + + if (!data) + return; + + parms->i_flags = VTI_ISVTI; + + if (data[IFLA_VTI_LINK]) + parms->link = nla_get_u32(data[IFLA_VTI_LINK]); + + if (data[IFLA_VTI_IKEY]) + parms->i_key = nla_get_be32(data[IFLA_VTI_IKEY]); + + if (data[IFLA_VTI_OKEY]) + parms->o_key = nla_get_be32(data[IFLA_VTI_OKEY]); + + if (data[IFLA_VTI_LOCAL]) + parms->iph.saddr = nla_get_in_addr(data[IFLA_VTI_LOCAL]); + + if (data[IFLA_VTI_REMOTE]) + parms->iph.daddr = nla_get_in_addr(data[IFLA_VTI_REMOTE]); + + if (data[IFLA_VTI_FWMARK]) + *fwmark = nla_get_u32(data[IFLA_VTI_FWMARK]); +} + +static int vti_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel_parm parms; + __u32 fwmark = 0; + + vti_netlink_parms(data, &parms, &fwmark); + return ip_tunnel_newlink(dev, tb, &parms, fwmark); +} + +static int 
vti_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = netdev_priv(dev); + __u32 fwmark = t->fwmark; + struct ip_tunnel_parm p; + + vti_netlink_parms(data, &p, &fwmark); + return ip_tunnel_changelink(dev, tb, &p, fwmark); +} + +static size_t vti_get_size(const struct net_device *dev) +{ + return + /* IFLA_VTI_LINK */ + nla_total_size(4) + + /* IFLA_VTI_IKEY */ + nla_total_size(4) + + /* IFLA_VTI_OKEY */ + nla_total_size(4) + + /* IFLA_VTI_LOCAL */ + nla_total_size(4) + + /* IFLA_VTI_REMOTE */ + nla_total_size(4) + + /* IFLA_VTI_FWMARK */ + nla_total_size(4) + + 0; +} + +static int vti_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm *p = &t->parms; + + if (nla_put_u32(skb, IFLA_VTI_LINK, p->link) || + nla_put_be32(skb, IFLA_VTI_IKEY, p->i_key) || + nla_put_be32(skb, IFLA_VTI_OKEY, p->o_key) || + nla_put_in_addr(skb, IFLA_VTI_LOCAL, p->iph.saddr) || + nla_put_in_addr(skb, IFLA_VTI_REMOTE, p->iph.daddr) || + nla_put_u32(skb, IFLA_VTI_FWMARK, t->fwmark)) + return -EMSGSIZE; + + return 0; +} + +static const struct nla_policy vti_policy[IFLA_VTI_MAX + 1] = { + [IFLA_VTI_LINK] = { .type = NLA_U32 }, + [IFLA_VTI_IKEY] = { .type = NLA_U32 }, + [IFLA_VTI_OKEY] = { .type = NLA_U32 }, + [IFLA_VTI_LOCAL] = { .len = sizeof_field(struct iphdr, saddr) }, + [IFLA_VTI_REMOTE] = { .len = sizeof_field(struct iphdr, daddr) }, + [IFLA_VTI_FWMARK] = { .type = NLA_U32 }, +}; + +static struct rtnl_link_ops vti_link_ops __read_mostly = { + .kind = "vti", + .maxtype = IFLA_VTI_MAX, + .policy = vti_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = vti_tunnel_setup, + .validate = vti_tunnel_validate, + .newlink = vti_newlink, + .changelink = vti_changelink, + .dellink = ip_tunnel_dellink, + .get_size = vti_get_size, + .fill_info = vti_fill_info, + .get_link_net = ip_tunnel_get_link_net, +}; + +static int __init vti_init(void) +{ + const char *msg; + int err; + + pr_info("IPv4 over IPsec tunneling driver\n"); + + msg = "tunnel device"; + err = register_pernet_device(&vti_net_ops); + if (err < 0) + goto pernet_dev_failed; + + msg = "tunnel protocols"; + err = xfrm4_protocol_register(&vti_esp4_protocol, IPPROTO_ESP); + if (err < 0) + goto xfrm_proto_esp_failed; + err = xfrm4_protocol_register(&vti_ah4_protocol, IPPROTO_AH); + if (err < 0) + goto xfrm_proto_ah_failed; + err = xfrm4_protocol_register(&vti_ipcomp4_protocol, IPPROTO_COMP); + if (err < 0) + goto xfrm_proto_comp_failed; + +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) + msg = "ipip tunnel"; + err = xfrm4_tunnel_register(&vti_ipip_handler, AF_INET); + if (err < 0) + goto xfrm_tunnel_ipip_failed; +#if IS_ENABLED(CONFIG_IPV6) + err = xfrm4_tunnel_register(&vti_ipip6_handler, AF_INET6); + if (err < 0) + goto xfrm_tunnel_ipip6_failed; +#endif +#endif + + msg = "netlink interface"; + err = rtnl_link_register(&vti_link_ops); + if (err < 0) + goto rtnl_link_failed; + + return err; + +rtnl_link_failed: +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) +#if IS_ENABLED(CONFIG_IPV6) + xfrm4_tunnel_deregister(&vti_ipip6_handler, AF_INET6); +xfrm_tunnel_ipip6_failed: +#endif + xfrm4_tunnel_deregister(&vti_ipip_handler, AF_INET); +xfrm_tunnel_ipip_failed: +#endif + xfrm4_protocol_deregister(&vti_ipcomp4_protocol, IPPROTO_COMP); +xfrm_proto_comp_failed: + xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH); +xfrm_proto_ah_failed: + xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP); 
+xfrm_proto_esp_failed: + unregister_pernet_device(&vti_net_ops); +pernet_dev_failed: + pr_err("vti init: failed to register %s\n", msg); + return err; +} + +static void __exit vti_fini(void) +{ + rtnl_link_unregister(&vti_link_ops); +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) +#if IS_ENABLED(CONFIG_IPV6) + xfrm4_tunnel_deregister(&vti_ipip6_handler, AF_INET6); +#endif + xfrm4_tunnel_deregister(&vti_ipip_handler, AF_INET); +#endif + xfrm4_protocol_deregister(&vti_ipcomp4_protocol, IPPROTO_COMP); + xfrm4_protocol_deregister(&vti_ah4_protocol, IPPROTO_AH); + xfrm4_protocol_deregister(&vti_esp4_protocol, IPPROTO_ESP); + unregister_pernet_device(&vti_net_ops); +} + +module_init(vti_init); +module_exit(vti_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("vti"); +MODULE_ALIAS_NETDEV("ip_vti0"); diff --git a/net/ipv4/ipcomp.c b/net/ipv4/ipcomp.c new file mode 100644 index 000000000..5a4fb2539 --- /dev/null +++ b/net/ipv4/ipcomp.c @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IP Payload Compression Protocol (IPComp) - RFC3173. + * + * Copyright (c) 2003 James Morris + * + * Todo: + * - Tunable compression parameters. + * - Compression stats. + * - Adaptive compression. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int ipcomp4_err(struct sk_buff *skb, u32 info) +{ + struct net *net = dev_net(skb->dev); + __be32 spi; + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct ip_comp_hdr *ipch = (struct ip_comp_hdr *)(skb->data+(iph->ihl<<2)); + struct xfrm_state *x; + + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED) + return 0; + break; + case ICMP_REDIRECT: + break; + default: + return 0; + } + + spi = htonl(ntohs(ipch->cpi)); + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + spi, IPPROTO_COMP, AF_INET); + if (!x) + return 0; + + if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH) + ipv4_update_pmtu(skb, net, info, 0, IPPROTO_COMP); + else + ipv4_redirect(skb, net, 0, IPPROTO_COMP); + xfrm_state_put(x); + + return 0; +} + +/* We always hold one tunnel user reference to indicate a tunnel */ +static struct xfrm_state *ipcomp_tunnel_create(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + struct xfrm_state *t; + + t = xfrm_state_alloc(net); + if (!t) + goto out; + + t->id.proto = IPPROTO_IPIP; + t->id.spi = x->props.saddr.a4; + t->id.daddr.a4 = x->id.daddr.a4; + memcpy(&t->sel, &x->sel, sizeof(t->sel)); + t->props.family = AF_INET; + t->props.mode = x->props.mode; + t->props.saddr.a4 = x->props.saddr.a4; + t->props.flags = x->props.flags; + t->props.extra_flags = x->props.extra_flags; + memcpy(&t->mark, &x->mark, sizeof(t->mark)); + t->if_id = x->if_id; + + if (xfrm_init_state(t)) + goto error; + + atomic_set(&t->tunnel_users, 1); +out: + return t; + +error: + t->km.state = XFRM_STATE_DEAD; + xfrm_state_put(t); + t = NULL; + goto out; +} + +/* + * Must be protected by xfrm_cfg_mutex. State and tunnel user references are + * always incremented on success. 
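/*
 * Editorial aside, not part of the patch: both vti4_err() and ipcomp4_err()
 * above widen the 16-bit IPComp CPI into the 32-bit SPI slot with
 * htonl(ntohs(cpi)) before calling xfrm_state_lookup().  A self-contained
 * userspace demo of that conversion, using an arbitrary CPI of 0x1234:
 */
#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

int main(void)
{
	uint16_t cpi_wire = htons(0x1234);	/* CPI as it sits in the header */
	uint32_t spi_wire = htonl(ntohs(cpi_wire));

	printf("cpi 0x%04x -> spi 0x%08x (host order)\n",
	       ntohs(cpi_wire), ntohl(spi_wire));	/* 0x1234 -> 0x00001234 */
	return 0;
}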
+ */ +static int ipcomp_tunnel_attach(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + int err = 0; + struct xfrm_state *t; + u32 mark = x->mark.v & x->mark.m; + + t = xfrm_state_lookup(net, mark, (xfrm_address_t *)&x->id.daddr.a4, + x->props.saddr.a4, IPPROTO_IPIP, AF_INET); + if (!t) { + t = ipcomp_tunnel_create(x); + if (!t) { + err = -EINVAL; + goto out; + } + xfrm_state_insert(t); + xfrm_state_hold(t); + } + x->tunnel = t; + atomic_inc(&t->tunnel_users); +out: + return err; +} + +static int ipcomp4_init_state(struct xfrm_state *x, + struct netlink_ext_ack *extack) +{ + int err = -EINVAL; + + x->props.header_len = 0; + switch (x->props.mode) { + case XFRM_MODE_TRANSPORT: + break; + case XFRM_MODE_TUNNEL: + x->props.header_len += sizeof(struct iphdr); + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported XFRM mode for IPcomp"); + goto out; + } + + err = ipcomp_init_state(x, extack); + if (err) + goto out; + + if (x->props.mode == XFRM_MODE_TUNNEL) { + err = ipcomp_tunnel_attach(x); + if (err) { + NL_SET_ERR_MSG(extack, "Kernel error: failed to initialize the associated state"); + goto out; + } + } + + err = 0; +out: + return err; +} + +static int ipcomp4_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type ipcomp_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_COMP, + .init_state = ipcomp4_init_state, + .destructor = ipcomp_destroy, + .input = ipcomp_input, + .output = ipcomp_output +}; + +static struct xfrm4_protocol ipcomp4_protocol = { + .handler = xfrm4_rcv, + .input_handler = xfrm_input, + .cb_handler = ipcomp4_rcv_cb, + .err_handler = ipcomp4_err, + .priority = 0, +}; + +static int __init ipcomp4_init(void) +{ + if (xfrm_register_type(&ipcomp_type, AF_INET) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + if (xfrm4_protocol_register(&ipcomp4_protocol, IPPROTO_COMP) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&ipcomp_type, AF_INET); + return -EAGAIN; + } + return 0; +} + +static void __exit ipcomp4_fini(void) +{ + if (xfrm4_protocol_deregister(&ipcomp4_protocol, IPPROTO_COMP) < 0) + pr_info("%s: can't remove protocol\n", __func__); + xfrm_unregister_type(&ipcomp_type, AF_INET); +} + +module_init(ipcomp4_init); +module_exit(ipcomp4_fini); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp/IPv4) - RFC3173"); +MODULE_AUTHOR("James Morris "); + +MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_COMP); diff --git a/net/ipv4/ipconfig.c b/net/ipv4/ipconfig.c new file mode 100644 index 000000000..e90bc0aa8 --- /dev/null +++ b/net/ipv4/ipconfig.c @@ -0,0 +1,1845 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Automatic Configuration of IP -- use DHCP, BOOTP, RARP, or + * user-supplied information to configure own IP address and routes. + * + * Copyright (C) 1996-1998 Martin Mares + * + * Derived from network configuration code in fs/nfs/nfsroot.c, + * originally Copyright (C) 1995, 1996 Gero Kuhlmann and me. + * + * BOOTP rewritten to construct and analyse packets itself instead + * of misusing the IP layer. num_bugs_causing_wrong_arp_replies--; + * -- MJ, December 1998 + * + * Fixed ip_auto_config_setup calling at startup in the new "Linker Magic" + * initialization scheme. + * - Arnaldo Carvalho de Melo , 08/11/1999 + * + * DHCP support added. To users this looks like a whole separate + * protocol, but we know it's just a bag on the side of BOOTP. 
+ * -- Chip Salzenberg , May 2000 + * + * Ported DHCP support from 2.2.16 to 2.4.0-test4 + * -- Eric Biederman , 30 Aug 2000 + * + * Merged changes from 2.2.19 into 2.4.3 + * -- Eric Biederman , 22 April Aug 2001 + * + * Multiple Nameservers in /proc/net/pnp + * -- Josef Siemes , Aug 2002 + * + * NTP servers in /proc/net/ipconfig/ntp_servers + * -- Chris Novakovic , April 2018 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#if defined(CONFIG_IP_PNP_DHCP) +#define IPCONFIG_DHCP +#endif +#if defined(CONFIG_IP_PNP_BOOTP) || defined(CONFIG_IP_PNP_DHCP) +#define IPCONFIG_BOOTP +#endif +#if defined(CONFIG_IP_PNP_RARP) +#define IPCONFIG_RARP +#endif +#if defined(IPCONFIG_BOOTP) || defined(IPCONFIG_RARP) +#define IPCONFIG_DYNAMIC +#endif + +/* Define the friendly delay before and after opening net devices */ +#define CONF_POST_OPEN 10 /* After opening: 10 msecs */ + +/* Define the timeout for waiting for a DHCP/BOOTP/RARP reply */ +#define CONF_OPEN_RETRIES 2 /* (Re)open devices twice */ +#define CONF_SEND_RETRIES 6 /* Send six requests per open */ +#define CONF_BASE_TIMEOUT (HZ*2) /* Initial timeout: 2 seconds */ +#define CONF_TIMEOUT_RANDOM (HZ) /* Maximum amount of randomization */ +#define CONF_TIMEOUT_MULT *7/4 /* Rate of timeout growth */ +#define CONF_TIMEOUT_MAX (HZ*30) /* Maximum allowed timeout */ +#define CONF_NAMESERVERS_MAX 3 /* Maximum number of nameservers + - '3' from resolv.h */ +#define CONF_NTP_SERVERS_MAX 3 /* Maximum number of NTP servers */ + +#define NONE cpu_to_be32(INADDR_NONE) +#define ANY cpu_to_be32(INADDR_ANY) + +/* Wait for carrier timeout default in seconds */ +static unsigned int carrier_timeout = 120; + +/* + * Public IP configuration + */ + +/* This is used by platforms which might be able to set the ipconfig + * variables using firmware environment vars. If this is set, it will + * ignore such firmware variables. + */ +int ic_set_manually __initdata = 0; /* IPconfig parameters set manually */ + +static int ic_enable __initdata; /* IP config enabled? */ + +/* Protocol choice */ +int ic_proto_enabled __initdata = 0 +#ifdef IPCONFIG_BOOTP + | IC_BOOTP +#endif +#ifdef CONFIG_IP_PNP_DHCP + | IC_USE_DHCP +#endif +#ifdef IPCONFIG_RARP + | IC_RARP +#endif + ; + +static int ic_host_name_set __initdata; /* Host name set by us? 
*/ + +__be32 ic_myaddr = NONE; /* My IP address */ +static __be32 ic_netmask = NONE; /* Netmask for local subnet */ +__be32 ic_gateway = NONE; /* Gateway IP address */ + +#ifdef IPCONFIG_DYNAMIC +static __be32 ic_addrservaddr = NONE; /* IP Address of the IP addresses'server */ +#endif + +__be32 ic_servaddr = NONE; /* Boot server IP address */ + +__be32 root_server_addr = NONE; /* Address of NFS server */ +u8 root_server_path[256] = { 0, }; /* Path to mount as root */ + +/* vendor class identifier */ +static char vendor_class_identifier[253] __initdata; + +#if defined(CONFIG_IP_PNP_DHCP) +static char dhcp_client_identifier[253] __initdata; +#endif + +/* Persistent data: */ + +#ifdef IPCONFIG_DYNAMIC +static int ic_proto_used; /* Protocol used, if any */ +#else +#define ic_proto_used 0 +#endif +static __be32 ic_nameservers[CONF_NAMESERVERS_MAX]; /* DNS Server IP addresses */ +static __be32 ic_ntp_servers[CONF_NTP_SERVERS_MAX]; /* NTP server IP addresses */ +static u8 ic_domain[64]; /* DNS (not NIS) domain name */ + +/* + * Private state. + */ + +/* Name of user-selected boot device */ +static char user_dev_name[IFNAMSIZ] __initdata = { 0, }; + +/* Protocols supported by available interfaces */ +static int ic_proto_have_if __initdata; + +/* MTU for boot device */ +static int ic_dev_mtu __initdata; + +#ifdef IPCONFIG_DYNAMIC +static DEFINE_SPINLOCK(ic_recv_lock); +static volatile int ic_got_reply __initdata; /* Proto(s) that replied */ +#endif +#ifdef IPCONFIG_DHCP +static int ic_dhcp_msgtype __initdata; /* DHCP msg type received */ +#endif + + +/* + * Network devices + */ + +struct ic_device { + struct ic_device *next; + struct net_device *dev; + unsigned short flags; + short able; + __be32 xid; +}; + +static struct ic_device *ic_first_dev __initdata; /* List of open device */ +static struct ic_device *ic_dev __initdata; /* Selected device */ + +static bool __init ic_is_init_dev(struct net_device *dev) +{ + if (dev->flags & IFF_LOOPBACK) + return false; + return user_dev_name[0] ? 
!strcmp(dev->name, user_dev_name) : + (!(dev->flags & IFF_LOOPBACK) && + (dev->flags & (IFF_POINTOPOINT|IFF_BROADCAST)) && + strncmp(dev->name, "dummy", 5)); +} + +static int __init ic_open_devs(void) +{ + struct ic_device *d, **last; + struct net_device *dev; + unsigned short oflags; + unsigned long start, next_msg; + + last = &ic_first_dev; + rtnl_lock(); + + /* bring loopback device up first */ + for_each_netdev(&init_net, dev) { + if (!(dev->flags & IFF_LOOPBACK)) + continue; + if (dev_change_flags(dev, dev->flags | IFF_UP, NULL) < 0) + pr_err("IP-Config: Failed to open %s\n", dev->name); + } + + for_each_netdev(&init_net, dev) { + if (ic_is_init_dev(dev)) { + int able = 0; + if (dev->mtu >= 364) + able |= IC_BOOTP; + else + pr_warn("DHCP/BOOTP: Ignoring device %s, MTU %d too small\n", + dev->name, dev->mtu); + if (!(dev->flags & IFF_NOARP)) + able |= IC_RARP; + able &= ic_proto_enabled; + if (ic_proto_enabled && !able) + continue; + oflags = dev->flags; + if (dev_change_flags(dev, oflags | IFF_UP, NULL) < 0) { + pr_err("IP-Config: Failed to open %s\n", + dev->name); + continue; + } + if (!(d = kmalloc(sizeof(struct ic_device), GFP_KERNEL))) { + rtnl_unlock(); + return -ENOMEM; + } + d->dev = dev; + *last = d; + last = &d->next; + d->flags = oflags; + d->able = able; + if (able & IC_BOOTP) + get_random_bytes(&d->xid, sizeof(__be32)); + else + d->xid = 0; + ic_proto_have_if |= able; + pr_debug("IP-Config: %s UP (able=%d, xid=%08x)\n", + dev->name, able, d->xid); + } + } + /* Devices with a complex topology like SFP ethernet interfaces needs + * the rtnl_lock at init. The carrier wait-loop must therefore run + * without holding it. + */ + rtnl_unlock(); + + /* no point in waiting if we could not bring up at least one device */ + if (!ic_first_dev) + goto have_carrier; + + /* wait for a carrier on at least one device */ + start = jiffies; + next_msg = start + msecs_to_jiffies(20000); + while (time_before(jiffies, start + + msecs_to_jiffies(carrier_timeout * 1000))) { + int wait, elapsed; + + rtnl_lock(); + for_each_netdev(&init_net, dev) + if (ic_is_init_dev(dev) && netif_carrier_ok(dev)) { + rtnl_unlock(); + goto have_carrier; + } + rtnl_unlock(); + + msleep(1); + + if (time_before(jiffies, next_msg)) + continue; + + elapsed = jiffies_to_msecs(jiffies - start); + wait = (carrier_timeout * 1000 - elapsed + 500) / 1000; + pr_info("Waiting up to %d more seconds for network.\n", wait); + next_msg = jiffies + msecs_to_jiffies(20000); + } +have_carrier: + + *last = NULL; + + if (!ic_first_dev) { + if (user_dev_name[0]) + pr_err("IP-Config: Device `%s' not found\n", + user_dev_name); + else + pr_err("IP-Config: No network devices available\n"); + return -ENODEV; + } + return 0; +} + +/* Close all network interfaces except the one we've autoconfigured, and its + * lowers, in case it's a stacked virtual interface. + */ +static void __init ic_close_devs(void) +{ + struct net_device *selected_dev = ic_dev ? 
ic_dev->dev : NULL; + struct ic_device *d, *next; + struct net_device *dev; + + rtnl_lock(); + next = ic_first_dev; + while ((d = next)) { + bool bring_down = (d != ic_dev); + struct net_device *lower; + struct list_head *iter; + + next = d->next; + dev = d->dev; + + if (selected_dev) { + netdev_for_each_lower_dev(selected_dev, lower, iter) { + if (dev == lower) { + bring_down = false; + break; + } + } + } + if (bring_down) { + pr_debug("IP-Config: Downing %s\n", dev->name); + dev_change_flags(dev, d->flags, NULL); + } + kfree(d); + } + rtnl_unlock(); +} + +/* + * Interface to various network functions. + */ + +static inline void +set_sockaddr(struct sockaddr_in *sin, __be32 addr, __be16 port) +{ + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = addr; + sin->sin_port = port; +} + +/* + * Set up interface addresses and routes. + */ + +static int __init ic_setup_if(void) +{ + struct ifreq ir; + struct sockaddr_in *sin = (void *) &ir.ifr_ifru.ifru_addr; + int err; + + memset(&ir, 0, sizeof(ir)); + strcpy(ir.ifr_ifrn.ifrn_name, ic_dev->dev->name); + set_sockaddr(sin, ic_myaddr, 0); + if ((err = devinet_ioctl(&init_net, SIOCSIFADDR, &ir)) < 0) { + pr_err("IP-Config: Unable to set interface address (%d)\n", + err); + return -1; + } + set_sockaddr(sin, ic_netmask, 0); + if ((err = devinet_ioctl(&init_net, SIOCSIFNETMASK, &ir)) < 0) { + pr_err("IP-Config: Unable to set interface netmask (%d)\n", + err); + return -1; + } + set_sockaddr(sin, ic_myaddr | ~ic_netmask, 0); + if ((err = devinet_ioctl(&init_net, SIOCSIFBRDADDR, &ir)) < 0) { + pr_err("IP-Config: Unable to set interface broadcast address (%d)\n", + err); + return -1; + } + /* Handle the case where we need non-standard MTU on the boot link (a network + * using jumbo frames, for instance). If we can't set the mtu, don't error + * out, we'll try to muddle along. + */ + if (ic_dev_mtu != 0) { + rtnl_lock(); + if ((err = dev_set_mtu(ic_dev->dev, ic_dev_mtu)) < 0) + pr_err("IP-Config: Unable to set interface mtu to %d (%d)\n", + ic_dev_mtu, err); + rtnl_unlock(); + } + return 0; +} + +static int __init ic_setup_routes(void) +{ + /* No need to setup device routes, only the default route... */ + + if (ic_gateway != NONE) { + struct rtentry rm; + int err; + + memset(&rm, 0, sizeof(rm)); + if ((ic_gateway ^ ic_myaddr) & ic_netmask) { + pr_err("IP-Config: Gateway not on directly connected network\n"); + return -1; + } + set_sockaddr((struct sockaddr_in *) &rm.rt_dst, 0, 0); + set_sockaddr((struct sockaddr_in *) &rm.rt_genmask, 0, 0); + set_sockaddr((struct sockaddr_in *) &rm.rt_gateway, ic_gateway, 0); + rm.rt_flags = RTF_UP | RTF_GATEWAY; + if ((err = ip_rt_ioctl(&init_net, SIOCADDRT, &rm)) < 0) { + pr_err("IP-Config: Cannot add default route (%d)\n", + err); + return -1; + } + } + + return 0; +} + +/* + * Fill in default values for all missing parameters. 
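/*
 * Editorial sketch, not in the patch: ic_setup_if() above drives the same
 * SIOCSIFADDR/SIOCSIFNETMASK requests that classic userspace tools use, just
 * from inside the kernel via devinet_ioctl().  For comparison, a minimal
 * userspace program doing the address half of that job; the interface name
 * and address are placeholders, and it needs CAP_NET_ADMIN to succeed.
 */
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <net/if.h>

int main(void)
{
	struct ifreq ifr;
	struct sockaddr_in *sin = (struct sockaddr_in *)&ifr.ifr_addr;
	int fd = socket(AF_INET, SOCK_DGRAM, 0);

	if (fd < 0)
		return 1;

	memset(&ifr, 0, sizeof(ifr));
	strncpy(ifr.ifr_name, "eth0", IFNAMSIZ - 1);	/* placeholder device */
	sin->sin_family = AF_INET;
	inet_pton(AF_INET, "192.0.2.10", &sin->sin_addr);

	if (ioctl(fd, SIOCSIFADDR, &ifr) < 0)	/* same request ic_setup_if issues */
		perror("SIOCSIFADDR");

	close(fd);
	return 0;
}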
+ */ + +static int __init ic_defaults(void) +{ + /* + * At this point we have no userspace running so need not + * claim locks on system_utsname + */ + + if (!ic_host_name_set) + sprintf(init_utsname()->nodename, "%pI4", &ic_myaddr); + + if (root_server_addr == NONE) + root_server_addr = ic_servaddr; + + if (ic_netmask == NONE) { + if (IN_CLASSA(ntohl(ic_myaddr))) + ic_netmask = htonl(IN_CLASSA_NET); + else if (IN_CLASSB(ntohl(ic_myaddr))) + ic_netmask = htonl(IN_CLASSB_NET); + else if (IN_CLASSC(ntohl(ic_myaddr))) + ic_netmask = htonl(IN_CLASSC_NET); + else if (IN_CLASSE(ntohl(ic_myaddr))) + ic_netmask = htonl(IN_CLASSE_NET); + else { + pr_err("IP-Config: Unable to guess netmask for address %pI4\n", + &ic_myaddr); + return -1; + } + pr_notice("IP-Config: Guessing netmask %pI4\n", + &ic_netmask); + } + + return 0; +} + +/* + * RARP support. + */ + +#ifdef IPCONFIG_RARP + +static int ic_rarp_recv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev); + +static struct packet_type rarp_packet_type __initdata = { + .type = cpu_to_be16(ETH_P_RARP), + .func = ic_rarp_recv, +}; + +static inline void __init ic_rarp_init(void) +{ + dev_add_pack(&rarp_packet_type); +} + +static inline void __init ic_rarp_cleanup(void) +{ + dev_remove_pack(&rarp_packet_type); +} + +/* + * Process received RARP packet. + */ +static int __init +ic_rarp_recv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) +{ + struct arphdr *rarp; + unsigned char *rarp_ptr; + __be32 sip, tip; + unsigned char *tha; /* t for "target" */ + struct ic_device *d; + + if (!net_eq(dev_net(dev), &init_net)) + goto drop; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return NET_RX_DROP; + + if (!pskb_may_pull(skb, sizeof(struct arphdr))) + goto drop; + + /* Basic sanity checks can be done without the lock. */ + rarp = (struct arphdr *)skb_transport_header(skb); + + /* If this test doesn't pass, it's not IP, or we should + * ignore it anyway. + */ + if (rarp->ar_hln != dev->addr_len || dev->type != ntohs(rarp->ar_hrd)) + goto drop; + + /* If it's not a RARP reply, delete it. */ + if (rarp->ar_op != htons(ARPOP_RREPLY)) + goto drop; + + /* If it's not Ethernet, delete it. */ + if (rarp->ar_pro != htons(ETH_P_IP)) + goto drop; + + if (!pskb_may_pull(skb, arp_hdr_len(dev))) + goto drop; + + /* OK, it is all there and looks valid, process... */ + rarp = (struct arphdr *)skb_transport_header(skb); + rarp_ptr = (unsigned char *) (rarp + 1); + + /* One reply at a time, please. */ + spin_lock(&ic_recv_lock); + + /* If we already have a reply, just drop the packet */ + if (ic_got_reply) + goto drop_unlock; + + /* Find the ic_device that the packet arrived on */ + d = ic_first_dev; + while (d && d->dev != dev) + d = d->next; + if (!d) + goto drop_unlock; /* should never happen */ + + /* Extract variable-width fields */ + rarp_ptr += dev->addr_len; + memcpy(&sip, rarp_ptr, 4); + rarp_ptr += 4; + tha = rarp_ptr; + rarp_ptr += dev->addr_len; + memcpy(&tip, rarp_ptr, 4); + + /* Discard packets which are not meant for us. */ + if (memcmp(tha, dev->dev_addr, dev->addr_len)) + goto drop_unlock; + + /* Discard packets which are not from specified server. */ + if (ic_servaddr != NONE && ic_servaddr != sip) + goto drop_unlock; + + /* We have a winner! */ + ic_dev = d; + if (ic_myaddr == NONE) + ic_myaddr = tip; + ic_servaddr = sip; + ic_addrservaddr = sip; + ic_got_reply = IC_RARP; + +drop_unlock: + /* Show's over. Nothing to see here. 
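/*
 * Editorial sketch, not in the patch: when no netmask was supplied,
 * ic_defaults() above falls back to the classful default for the address.
 * The same table, exercised from userspace with glibc's IN_CLASSA/B/C
 * helpers (class E is omitted here); the sample addresses are documentation
 * and private prefixes chosen for the demo.
 */
#include <stdio.h>
#include <arpa/inet.h>
#include <netinet/in.h>

static in_addr_t guess_netmask(in_addr_t addr_host_order)
{
	if (IN_CLASSA(addr_host_order))
		return IN_CLASSA_NET;		/* 255.0.0.0 */
	if (IN_CLASSB(addr_host_order))
		return IN_CLASSB_NET;		/* 255.255.0.0 */
	if (IN_CLASSC(addr_host_order))
		return IN_CLASSC_NET;		/* 255.255.255.0 */
	return 0;				/* no sane default */
}

int main(void)
{
	const char *samples[] = { "10.1.2.3", "172.16.5.6", "192.0.2.7" };
	struct in_addr ip, mask;

	for (int i = 0; i < 3; i++) {
		inet_pton(AF_INET, samples[i], &ip);
		mask.s_addr = htonl(guess_netmask(ntohl(ip.s_addr)));
		printf("%-12s -> netmask %s\n", samples[i], inet_ntoa(mask));
	}
	return 0;
}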
*/ + spin_unlock(&ic_recv_lock); + +drop: + /* Throw the packet out. */ + kfree_skb(skb); + return 0; +} + + +/* + * Send RARP request packet over a single interface. + */ +static void __init ic_rarp_send_if(struct ic_device *d) +{ + struct net_device *dev = d->dev; + arp_send(ARPOP_RREQUEST, ETH_P_RARP, 0, dev, 0, NULL, + dev->dev_addr, dev->dev_addr); +} +#endif + +/* + * Predefine Nameservers + */ +static inline void __init ic_nameservers_predef(void) +{ + int i; + + for (i = 0; i < CONF_NAMESERVERS_MAX; i++) + ic_nameservers[i] = NONE; +} + +/* Predefine NTP servers */ +static inline void __init ic_ntp_servers_predef(void) +{ + int i; + + for (i = 0; i < CONF_NTP_SERVERS_MAX; i++) + ic_ntp_servers[i] = NONE; +} + +/* + * DHCP/BOOTP support. + */ + +#ifdef IPCONFIG_BOOTP + +struct bootp_pkt { /* BOOTP packet format */ + struct iphdr iph; /* IP header */ + struct udphdr udph; /* UDP header */ + u8 op; /* 1=request, 2=reply */ + u8 htype; /* HW address type */ + u8 hlen; /* HW address length */ + u8 hops; /* Used only by gateways */ + __be32 xid; /* Transaction ID */ + __be16 secs; /* Seconds since we started */ + __be16 flags; /* Just what it says */ + __be32 client_ip; /* Client's IP address if known */ + __be32 your_ip; /* Assigned IP address */ + __be32 server_ip; /* (Next, e.g. NFS) Server's IP address */ + __be32 relay_ip; /* IP address of BOOTP relay */ + u8 hw_addr[16]; /* Client's HW address */ + u8 serv_name[64]; /* Server host name */ + u8 boot_file[128]; /* Name of boot file */ + u8 exten[312]; /* DHCP options / BOOTP vendor extensions */ +}; + +/* packet ops */ +#define BOOTP_REQUEST 1 +#define BOOTP_REPLY 2 + +/* DHCP message types */ +#define DHCPDISCOVER 1 +#define DHCPOFFER 2 +#define DHCPREQUEST 3 +#define DHCPDECLINE 4 +#define DHCPACK 5 +#define DHCPNAK 6 +#define DHCPRELEASE 7 +#define DHCPINFORM 8 + +static int ic_bootp_recv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev); + +static struct packet_type bootp_packet_type __initdata = { + .type = cpu_to_be16(ETH_P_IP), + .func = ic_bootp_recv, +}; + +/* + * Initialize DHCP/BOOTP extension fields in the request. + */ + +static const u8 ic_bootp_cookie[4] = { 99, 130, 83, 99 }; + +#ifdef IPCONFIG_DHCP + +static void __init +ic_dhcp_init_options(u8 *options, struct ic_device *d) +{ + u8 mt = ((ic_servaddr == NONE) + ? DHCPDISCOVER : DHCPREQUEST); + u8 *e = options; + int len; + + pr_debug("DHCP: Sending message type %d (%s)\n", mt, d->dev->name); + + memcpy(e, ic_bootp_cookie, 4); /* RFC1048 Magic Cookie */ + e += 4; + + *e++ = 53; /* DHCP message type */ + *e++ = 1; + *e++ = mt; + + if (mt == DHCPREQUEST) { + *e++ = 54; /* Server ID (IP address) */ + *e++ = 4; + memcpy(e, &ic_servaddr, 4); + e += 4; + + *e++ = 50; /* Requested IP address */ + *e++ = 4; + memcpy(e, &ic_myaddr, 4); + e += 4; + } + + /* always? 
*/ + { + static const u8 ic_req_params[] = { + 1, /* Subnet mask */ + 3, /* Default gateway */ + 6, /* DNS server */ + 12, /* Host name */ + 15, /* Domain name */ + 17, /* Boot path */ + 26, /* MTU */ + 40, /* NIS domain name */ + 42, /* NTP servers */ + }; + + *e++ = 55; /* Parameter request list */ + *e++ = sizeof(ic_req_params); + memcpy(e, ic_req_params, sizeof(ic_req_params)); + e += sizeof(ic_req_params); + + if (ic_host_name_set) { + *e++ = 12; /* host-name */ + len = strlen(utsname()->nodename); + *e++ = len; + memcpy(e, utsname()->nodename, len); + e += len; + } + if (*vendor_class_identifier) { + pr_info("DHCP: sending class identifier \"%s\"\n", + vendor_class_identifier); + *e++ = 60; /* Class-identifier */ + len = strlen(vendor_class_identifier); + *e++ = len; + memcpy(e, vendor_class_identifier, len); + e += len; + } + len = strlen(dhcp_client_identifier + 1); + /* the minimum length of identifier is 2, include 1 byte type, + * and can not be larger than the length of options + */ + if (len >= 1 && len < 312 - (e - options) - 1) { + *e++ = 61; + *e++ = len + 1; + memcpy(e, dhcp_client_identifier, len + 1); + e += len + 1; + } + } + + *e++ = 255; /* End of the list */ +} + +#endif /* IPCONFIG_DHCP */ + +static void __init ic_bootp_init_ext(u8 *e) +{ + memcpy(e, ic_bootp_cookie, 4); /* RFC1048 Magic Cookie */ + e += 4; + *e++ = 1; /* Subnet mask request */ + *e++ = 4; + e += 4; + *e++ = 3; /* Default gateway request */ + *e++ = 4; + e += 4; +#if CONF_NAMESERVERS_MAX > 0 + *e++ = 6; /* (DNS) name server request */ + *e++ = 4 * CONF_NAMESERVERS_MAX; + e += 4 * CONF_NAMESERVERS_MAX; +#endif + *e++ = 12; /* Host name request */ + *e++ = 32; + e += 32; + *e++ = 40; /* NIS Domain name request */ + *e++ = 32; + e += 32; + *e++ = 17; /* Boot path */ + *e++ = 40; + e += 40; + + *e++ = 57; /* set extension buffer size for reply */ + *e++ = 2; + *e++ = 1; /* 128+236+8+20+14, see dhcpd sources */ + *e++ = 150; + + *e++ = 255; /* End of the list */ +} + + +/* + * Initialize the DHCP/BOOTP mechanism. + */ +static inline void __init ic_bootp_init(void) +{ + /* Re-initialise all name servers and NTP servers to NONE, in case any + * were set via the "ip=" or "nfsaddrs=" kernel command line parameters: + * any IP addresses specified there will already have been decoded but + * are no longer needed + */ + ic_nameservers_predef(); + ic_ntp_servers_predef(); + + dev_add_pack(&bootp_packet_type); +} + + +/* + * DHCP/BOOTP cleanup. + */ +static inline void __init ic_bootp_cleanup(void) +{ + dev_remove_pack(&bootp_packet_type); +} + + +/* + * Send DHCP/BOOTP request to single interface. 
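/*
 * Editorial sketch, not in the patch: ic_dhcp_init_options() above emits
 * options in the classic DHCP TLV shape, one code byte, one length byte,
 * then the payload, with a lone 255 terminating the list.  A tiny
 * self-contained helper that builds option 53 (message type) and option 55
 * (parameter request list) the same way; the buffer size is arbitrary and
 * the helper name is invented for the demo.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>

static uint8_t *dhcp_put_option(uint8_t *e, uint8_t code,
				const void *data, uint8_t len)
{
	*e++ = code;			/* option code */
	*e++ = len;			/* payload length */
	memcpy(e, data, len);		/* payload itself */
	return e + len;
}

int main(void)
{
	static const uint8_t magic[4] = { 99, 130, 83, 99 };	/* RFC 1048 cookie */
	static const uint8_t params[] = { 1, 3, 6, 12, 15, 17, 26, 40, 42 };
	uint8_t opts[64], *e = opts;

	memcpy(e, magic, 4);
	e += 4;
	e = dhcp_put_option(e, 53, (uint8_t[]){ 1 }, 1);	/* DHCPDISCOVER */
	e = dhcp_put_option(e, 55, params, sizeof(params));	/* request list */
	*e++ = 255;						/* end of options */

	printf("built %zu option bytes\n", (size_t)(e - opts));
	return 0;
}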
+ */ +static void __init ic_bootp_send_if(struct ic_device *d, unsigned long jiffies_diff) +{ + struct net_device *dev = d->dev; + struct sk_buff *skb; + struct bootp_pkt *b; + struct iphdr *h; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + + /* Allocate packet */ + skb = alloc_skb(sizeof(struct bootp_pkt) + hlen + tlen + 15, + GFP_KERNEL); + if (!skb) + return; + skb_reserve(skb, hlen); + b = skb_put_zero(skb, sizeof(struct bootp_pkt)); + + /* Construct IP header */ + skb_reset_network_header(skb); + h = ip_hdr(skb); + h->version = 4; + h->ihl = 5; + h->tot_len = htons(sizeof(struct bootp_pkt)); + h->frag_off = htons(IP_DF); + h->ttl = 64; + h->protocol = IPPROTO_UDP; + h->daddr = htonl(INADDR_BROADCAST); + h->check = ip_fast_csum((unsigned char *) h, h->ihl); + + /* Construct UDP header */ + b->udph.source = htons(68); + b->udph.dest = htons(67); + b->udph.len = htons(sizeof(struct bootp_pkt) - sizeof(struct iphdr)); + /* UDP checksum not calculated -- explicitly allowed in BOOTP RFC */ + + /* Construct DHCP/BOOTP header */ + b->op = BOOTP_REQUEST; + if (dev->type < 256) /* check for false types */ + b->htype = dev->type; + else if (dev->type == ARPHRD_FDDI) + b->htype = ARPHRD_ETHER; + else { + pr_warn("Unknown ARP type 0x%04x for device %s\n", dev->type, + dev->name); + b->htype = dev->type; /* can cause undefined behavior */ + } + + /* server_ip and your_ip address are both already zero per RFC2131 */ + b->hlen = dev->addr_len; + memcpy(b->hw_addr, dev->dev_addr, dev->addr_len); + b->secs = htons(jiffies_diff / HZ); + b->xid = d->xid; + + /* add DHCP options or BOOTP extensions */ +#ifdef IPCONFIG_DHCP + if (ic_proto_enabled & IC_USE_DHCP) + ic_dhcp_init_options(b->exten, d); + else +#endif + ic_bootp_init_ext(b->exten); + + /* Chain packet down the line... */ + skb->dev = dev; + skb->protocol = htons(ETH_P_IP); + if (dev_hard_header(skb, dev, ntohs(skb->protocol), + dev->broadcast, dev->dev_addr, skb->len) < 0) { + kfree_skb(skb); + printk("E"); + return; + } + + if (dev_queue_xmit(skb) < 0) + printk("E"); +} + + +/* + * Copy BOOTP-supplied string + */ +static int __init ic_bootp_string(char *dest, char *src, int len, int max) +{ + if (!len) + return 0; + if (len > max-1) + len = max-1; + memcpy(dest, src, len); + dest[len] = '\0'; + return 1; +} + + +/* + * Process BOOTP extensions. 
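/*
 * Editorial sketch, not in the patch: ic_bootp_send_if() above fills the
 * checksum field with ip_fast_csum(), the kernel's optimised RFC 1071
 * one's-complement sum over the IP header (and ic_bootp_recv() below uses
 * the same routine to validate replies).  A plain portable version of that
 * computation, taking the header length in 32-bit words just like
 * ip_fast_csum() does; the function name is made up for the demo.
 */
#include <stdint.h>

static uint16_t ip_header_csum(const void *hdr, unsigned int ihl)
{
	const uint16_t *p = hdr;
	uint32_t sum = 0;
	unsigned int i;

	for (i = 0; i < ihl * 2; i++)	/* ihl is in 32-bit words */
		sum += p[i];		/* checksum field assumed zeroed */

	while (sum >> 16)		/* fold carries back in */
		sum = (sum & 0xffff) + (sum >> 16);

	return (uint16_t)~sum;		/* one's complement of the sum */
}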
+ */ +static void __init ic_do_bootp_ext(u8 *ext) +{ + u8 servers; + int i; + __be16 mtu; + + u8 *c; + + pr_debug("DHCP/BOOTP: Got extension %d:", *ext); + for (c=ext+2; c CONF_NAMESERVERS_MAX) + servers = CONF_NAMESERVERS_MAX; + for (i = 0; i < servers; i++) { + if (ic_nameservers[i] == NONE) + memcpy(&ic_nameservers[i], ext+1+4*i, 4); + } + break; + case 12: /* Host name */ + if (!ic_host_name_set) { + ic_bootp_string(utsname()->nodename, ext+1, *ext, + __NEW_UTS_LEN); + ic_host_name_set = 1; + } + break; + case 15: /* Domain name (DNS) */ + if (!ic_domain[0]) + ic_bootp_string(ic_domain, ext+1, *ext, sizeof(ic_domain)); + break; + case 17: /* Root path */ + if (!root_server_path[0]) + ic_bootp_string(root_server_path, ext+1, *ext, + sizeof(root_server_path)); + break; + case 26: /* Interface MTU */ + memcpy(&mtu, ext+1, sizeof(mtu)); + ic_dev_mtu = ntohs(mtu); + break; + case 40: /* NIS Domain name (_not_ DNS) */ + ic_bootp_string(utsname()->domainname, ext+1, *ext, + __NEW_UTS_LEN); + break; + case 42: /* NTP servers */ + servers = *ext / 4; + if (servers > CONF_NTP_SERVERS_MAX) + servers = CONF_NTP_SERVERS_MAX; + for (i = 0; i < servers; i++) { + if (ic_ntp_servers[i] == NONE) + memcpy(&ic_ntp_servers[i], ext+1+4*i, 4); + } + break; + } +} + + +/* + * Receive BOOTP reply. + */ +static int __init ic_bootp_recv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) +{ + struct bootp_pkt *b; + struct iphdr *h; + struct ic_device *d; + int len, ext_len; + + if (!net_eq(dev_net(dev), &init_net)) + goto drop; + + /* Perform verifications before taking the lock. */ + if (skb->pkt_type == PACKET_OTHERHOST) + goto drop; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return NET_RX_DROP; + + if (!pskb_may_pull(skb, + sizeof(struct iphdr) + + sizeof(struct udphdr))) + goto drop; + + b = (struct bootp_pkt *)skb_network_header(skb); + h = &b->iph; + + if (h->ihl != 5 || h->version != 4 || h->protocol != IPPROTO_UDP) + goto drop; + + /* Fragments are not supported */ + if (ip_is_fragment(h)) { + net_err_ratelimited("DHCP/BOOTP: Ignoring fragmented reply\n"); + goto drop; + } + + if (skb->len < ntohs(h->tot_len)) + goto drop; + + if (ip_fast_csum((char *) h, h->ihl)) + goto drop; + + if (b->udph.source != htons(67) || b->udph.dest != htons(68)) + goto drop; + + if (ntohs(h->tot_len) < ntohs(b->udph.len) + sizeof(struct iphdr)) + goto drop; + + len = ntohs(b->udph.len) - sizeof(struct udphdr); + ext_len = len - (sizeof(*b) - + sizeof(struct iphdr) - + sizeof(struct udphdr) - + sizeof(b->exten)); + if (ext_len < 0) + goto drop; + + /* Ok the front looks good, make sure we can get at the rest. */ + if (!pskb_may_pull(skb, skb->len)) + goto drop; + + b = (struct bootp_pkt *)skb_network_header(skb); + h = &b->iph; + + /* One reply at a time, please. */ + spin_lock(&ic_recv_lock); + + /* If we already have a reply, just drop the packet */ + if (ic_got_reply) + goto drop_unlock; + + /* Find the ic_device that the packet arrived on */ + d = ic_first_dev; + while (d && d->dev != dev) + d = d->next; + if (!d) + goto drop_unlock; /* should never happen */ + + /* Is it a reply to our BOOTP request? 
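/*
 * Editorial sketch, not in the patch: the reply parsing just below walks the
 * options area as a TLV stream, where code 0 is one byte of padding, code
 * 255 ends the list, and everything else is <code, length, payload>.  The
 * same walk restated as a standalone helper with a caller-supplied callback;
 * the typedef and function names are invented for the demo, and the bound
 * check mirrors the conservative "ext < end" test used in ic_bootp_recv().
 */
#include <stdint.h>

typedef void (*opt_cb)(uint8_t code, const uint8_t *data, uint8_t len);

static void walk_dhcp_options(const uint8_t *ext, const uint8_t *end, opt_cb cb)
{
	while (ext < end && *ext != 0xff) {	/* 255 terminates the list */
		const uint8_t *opt = ext++;

		if (*opt == 0)			/* single-byte padding */
			continue;

		ext += *ext + 1;		/* skip length byte + payload */
		if (ext < end)
			cb(*opt, opt + 2, opt[1]);
	}
}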
*/ + if (b->op != BOOTP_REPLY || + b->xid != d->xid) { + net_err_ratelimited("DHCP/BOOTP: Reply not for us on %s, op[%x] xid[%x]\n", + d->dev->name, b->op, b->xid); + goto drop_unlock; + } + + /* Parse extensions */ + if (ext_len >= 4 && + !memcmp(b->exten, ic_bootp_cookie, 4)) { /* Check magic cookie */ + u8 *end = (u8 *) b + ntohs(b->iph.tot_len); + u8 *ext; + +#ifdef IPCONFIG_DHCP + if (ic_proto_enabled & IC_USE_DHCP) { + __be32 server_id = NONE; + int mt = 0; + + ext = &b->exten[4]; + while (ext < end && *ext != 0xff) { + u8 *opt = ext++; + if (*opt == 0) /* Padding */ + continue; + ext += *ext + 1; + if (ext >= end) + break; + switch (*opt) { + case 53: /* Message type */ + if (opt[1]) + mt = opt[2]; + break; + case 54: /* Server ID (IP address) */ + if (opt[1] >= 4) + memcpy(&server_id, opt + 2, 4); + break; + } + } + + pr_debug("DHCP: Got message type %d (%s)\n", mt, d->dev->name); + + switch (mt) { + case DHCPOFFER: + /* While in the process of accepting one offer, + * ignore all others. + */ + if (ic_myaddr != NONE) + goto drop_unlock; + + /* Let's accept that offer. */ + ic_myaddr = b->your_ip; + ic_servaddr = server_id; + pr_debug("DHCP: Offered address %pI4 by server %pI4\n", + &ic_myaddr, &b->iph.saddr); + /* The DHCP indicated server address takes + * precedence over the bootp header one if + * they are different. + */ + if ((server_id != NONE) && + (b->server_ip != server_id)) + b->server_ip = ic_servaddr; + break; + + case DHCPACK: + if (memcmp(dev->dev_addr, b->hw_addr, dev->addr_len) != 0) + goto drop_unlock; + + /* Yeah! */ + break; + + default: + /* Urque. Forget it*/ + ic_myaddr = NONE; + ic_servaddr = NONE; + goto drop_unlock; + } + + ic_dhcp_msgtype = mt; + + } +#endif /* IPCONFIG_DHCP */ + + ext = &b->exten[4]; + while (ext < end && *ext != 0xff) { + u8 *opt = ext++; + if (*opt == 0) /* Padding */ + continue; + ext += *ext + 1; + if (ext < end) + ic_do_bootp_ext(opt); + } + } + + /* We have a winner! */ + ic_dev = d; + ic_myaddr = b->your_ip; + ic_servaddr = b->server_ip; + ic_addrservaddr = b->iph.saddr; + if (ic_gateway == NONE && b->relay_ip) + ic_gateway = b->relay_ip; + if (ic_nameservers[0] == NONE) + ic_nameservers[0] = ic_servaddr; + ic_got_reply = IC_BOOTP; + +drop_unlock: + /* Show's over. Nothing to see here. */ + spin_unlock(&ic_recv_lock); + +drop: + /* Throw the packet out. */ + kfree_skb(skb); + + return 0; +} + + +#endif + + +/* + * Dynamic IP configuration -- DHCP, BOOTP, RARP. + */ + +#ifdef IPCONFIG_DYNAMIC + +static int __init ic_dynamic(void) +{ + int retries; + struct ic_device *d; + unsigned long start_jiffies, timeout, jiff; + int do_bootp = ic_proto_have_if & IC_BOOTP; + int do_rarp = ic_proto_have_if & IC_RARP; + + /* + * If none of DHCP/BOOTP/RARP was selected, return with an error. + * This routine gets only called when some pieces of information + * are missing, and without DHCP/BOOTP/RARP we are unable to get it. 
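/*
 * Editorial sketch, not in the patch: CONF_TIMEOUT_MULT earlier in this file
 * is defined as the bare token sequence "*7/4", so the statement
 * "timeout = timeout CONF_TIMEOUT_MULT;" in the send-and-wait loop of
 * ic_dynamic() just below expands to "timeout = timeout *7/4;".  The
 * userspace toy below replays that growth curve, assuming HZ is 100 ticks
 * per second and leaving out the initial randomisation.
 */
#include <stdio.h>

#define HZ			100	/* assumption for the demo only */
#define CONF_BASE_TIMEOUT	(HZ*2)
#define CONF_TIMEOUT_MULT	*7/4
#define CONF_TIMEOUT_MAX	(HZ*30)

int main(void)
{
	unsigned long timeout = CONF_BASE_TIMEOUT;
	int round;

	for (round = 1; round <= 8; round++) {
		printf("retry %d: timeout %lu ticks\n", round, timeout);
		timeout = timeout CONF_TIMEOUT_MULT;	/* grows by 7/4 */
		if (timeout > CONF_TIMEOUT_MAX)
			timeout = CONF_TIMEOUT_MAX;	/* capped at 30s */
	}
	return 0;
}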
+ */ + if (!ic_proto_enabled) { + pr_err("IP-Config: Incomplete network configuration information\n"); + return -1; + } + +#ifdef IPCONFIG_BOOTP + if ((ic_proto_enabled ^ ic_proto_have_if) & IC_BOOTP) + pr_err("DHCP/BOOTP: No suitable device found\n"); +#endif +#ifdef IPCONFIG_RARP + if ((ic_proto_enabled ^ ic_proto_have_if) & IC_RARP) + pr_err("RARP: No suitable device found\n"); +#endif + + if (!ic_proto_have_if) + /* Error message already printed */ + return -1; + + /* + * Setup protocols + */ +#ifdef IPCONFIG_BOOTP + if (do_bootp) + ic_bootp_init(); +#endif +#ifdef IPCONFIG_RARP + if (do_rarp) + ic_rarp_init(); +#endif + + /* + * Send requests and wait, until we get an answer. This loop + * seems to be a terrible waste of CPU time, but actually there is + * only one process running at all, so we don't need to use any + * scheduler functions. + * [Actually we could now, but the nothing else running note still + * applies.. - AC] + */ + pr_notice("Sending %s%s%s requests .", + do_bootp + ? ((ic_proto_enabled & IC_USE_DHCP) ? "DHCP" : "BOOTP") : "", + (do_bootp && do_rarp) ? " and " : "", + do_rarp ? "RARP" : ""); + + start_jiffies = jiffies; + d = ic_first_dev; + retries = CONF_SEND_RETRIES; + get_random_bytes(&timeout, sizeof(timeout)); + timeout = CONF_BASE_TIMEOUT + (timeout % (unsigned int) CONF_TIMEOUT_RANDOM); + for (;;) { +#ifdef IPCONFIG_BOOTP + if (do_bootp && (d->able & IC_BOOTP)) + ic_bootp_send_if(d, jiffies - start_jiffies); +#endif +#ifdef IPCONFIG_RARP + if (do_rarp && (d->able & IC_RARP)) + ic_rarp_send_if(d); +#endif + + if (!d->next) { + jiff = jiffies + timeout; + while (time_before(jiffies, jiff) && !ic_got_reply) + schedule_timeout_uninterruptible(1); + } +#ifdef IPCONFIG_DHCP + /* DHCP isn't done until we get a DHCPACK. */ + if ((ic_got_reply & IC_BOOTP) && + (ic_proto_enabled & IC_USE_DHCP) && + ic_dhcp_msgtype != DHCPACK) { + ic_got_reply = 0; + /* continue on device that got the reply */ + d = ic_dev; + pr_cont(","); + continue; + } +#endif /* IPCONFIG_DHCP */ + + if (ic_got_reply) { + pr_cont(" OK\n"); + break; + } + + if ((d = d->next)) + continue; + + if (! --retries) { + pr_cont(" timed out!\n"); + break; + } + + d = ic_first_dev; + + timeout = timeout CONF_TIMEOUT_MULT; + if (timeout > CONF_TIMEOUT_MAX) + timeout = CONF_TIMEOUT_MAX; + + pr_cont("."); + } + +#ifdef IPCONFIG_BOOTP + if (do_bootp) + ic_bootp_cleanup(); +#endif +#ifdef IPCONFIG_RARP + if (do_rarp) + ic_rarp_cleanup(); +#endif + + if (!ic_got_reply) { + ic_myaddr = NONE; + return -1; + } + + pr_info("IP-Config: Got %s answer from %pI4, my address is %pI4\n", + ((ic_got_reply & IC_RARP) ? "RARP" + : (ic_proto_enabled & IC_USE_DHCP) ? "DHCP" : "BOOTP"), + &ic_addrservaddr, &ic_myaddr); + + return 0; +} + +#endif /* IPCONFIG_DYNAMIC */ + +#ifdef CONFIG_PROC_FS +/* proc_dir_entry for /proc/net/ipconfig */ +static struct proc_dir_entry *ipconfig_dir; + +/* Name servers: */ +static int pnp_seq_show(struct seq_file *seq, void *v) +{ + int i; + + if (ic_proto_used & IC_PROTO) + seq_printf(seq, "#PROTO: %s\n", + (ic_proto_used & IC_RARP) ? "RARP" + : (ic_proto_used & IC_USE_DHCP) ? 
"DHCP" : "BOOTP"); + else + seq_puts(seq, "#MANUAL\n"); + + if (ic_domain[0]) + seq_printf(seq, + "domain %s\n", ic_domain); + for (i = 0; i < CONF_NAMESERVERS_MAX; i++) { + if (ic_nameservers[i] != NONE) + seq_printf(seq, "nameserver %pI4\n", + &ic_nameservers[i]); + } + if (ic_servaddr != NONE) + seq_printf(seq, "bootserver %pI4\n", + &ic_servaddr); + return 0; +} + +/* Create the /proc/net/ipconfig directory */ +static int __init ipconfig_proc_net_init(void) +{ + ipconfig_dir = proc_net_mkdir(&init_net, "ipconfig", init_net.proc_net); + if (!ipconfig_dir) + return -ENOMEM; + + return 0; +} + +/* Create a new file under /proc/net/ipconfig */ +static int ipconfig_proc_net_create(const char *name, + const struct proc_ops *proc_ops) +{ + char *pname; + struct proc_dir_entry *p; + + if (!ipconfig_dir) + return -ENOMEM; + + pname = kasprintf(GFP_KERNEL, "%s%s", "ipconfig/", name); + if (!pname) + return -ENOMEM; + + p = proc_create(pname, 0444, init_net.proc_net, proc_ops); + kfree(pname); + if (!p) + return -ENOMEM; + + return 0; +} + +/* Write NTP server IP addresses to /proc/net/ipconfig/ntp_servers */ +static int ntp_servers_show(struct seq_file *seq, void *v) +{ + int i; + + for (i = 0; i < CONF_NTP_SERVERS_MAX; i++) { + if (ic_ntp_servers[i] != NONE) + seq_printf(seq, "%pI4\n", &ic_ntp_servers[i]); + } + return 0; +} +DEFINE_PROC_SHOW_ATTRIBUTE(ntp_servers); +#endif /* CONFIG_PROC_FS */ + +/* + * Extract IP address from the parameter string if needed. Note that we + * need to have root_server_addr set _before_ IPConfig gets called as it + * can override it. + */ +__be32 __init root_nfs_parse_addr(char *name) +{ + __be32 addr; + int octets = 0; + char *cp, *cq; + + cp = cq = name; + while (octets < 4) { + while (*cp >= '0' && *cp <= '9') + cp++; + if (cp == cq || cp - cq > 3) + break; + if (*cp == '.' || octets == 3) + octets++; + if (octets < 4) + cp++; + cq = cp; + } + if (octets == 4 && (*cp == ':' || *cp == '\0')) { + if (*cp == ':') + *cp++ = '\0'; + addr = in_aton(name); + memmove(name, cp, strlen(cp) + 1); + } else + addr = NONE; + + return addr; +} + +#define DEVICE_WAIT_MAX 12 /* 12 seconds */ + +static int __init wait_for_devices(void) +{ + int i; + bool try_init_devs = true; + + for (i = 0; i < DEVICE_WAIT_MAX; i++) { + struct net_device *dev; + int found = 0; + + /* make sure deferred device probes are finished */ + wait_for_device_probe(); + + rtnl_lock(); + for_each_netdev(&init_net, dev) { + if (ic_is_init_dev(dev)) { + found = 1; + break; + } + } + rtnl_unlock(); + if (found) + return 0; + if (try_init_devs && + (ROOT_DEV == Root_NFS || ROOT_DEV == Root_CIFS)) { + try_init_devs = false; + wait_for_init_devices_probe(); + } + ssleep(1); + } + return -ENODEV; +} + +/* + * IP Autoconfig dispatcher. 
+ */ + +static int __init ip_auto_config(void) +{ + __be32 addr; +#ifdef IPCONFIG_DYNAMIC + int retries = CONF_OPEN_RETRIES; +#endif + int err; + unsigned int i, count; + + /* Initialise all name servers and NTP servers to NONE (but only if the + * "ip=" or "nfsaddrs=" kernel command line parameters weren't decoded, + * otherwise we'll overwrite the IP addresses specified there) + */ + if (ic_set_manually == 0) { + ic_nameservers_predef(); + ic_ntp_servers_predef(); + } + +#ifdef CONFIG_PROC_FS + proc_create_single("pnp", 0444, init_net.proc_net, pnp_seq_show); + + if (ipconfig_proc_net_init() == 0) + ipconfig_proc_net_create("ntp_servers", &ntp_servers_proc_ops); +#endif /* CONFIG_PROC_FS */ + + if (!ic_enable) + return 0; + + pr_debug("IP-Config: Entered.\n"); +#ifdef IPCONFIG_DYNAMIC + try_try_again: +#endif + /* Wait for devices to appear */ + err = wait_for_devices(); + if (err) + return err; + + /* Setup all network devices */ + err = ic_open_devs(); + if (err) + return err; + + /* Give drivers a chance to settle */ + msleep(CONF_POST_OPEN); + + /* + * If the config information is insufficient (e.g., our IP address or + * IP address of the boot server is missing or we have multiple network + * interfaces and no default was set), use BOOTP or RARP to get the + * missing values. + */ + if (ic_myaddr == NONE || +#if defined(CONFIG_ROOT_NFS) || defined(CONFIG_CIFS_ROOT) + (root_server_addr == NONE && + ic_servaddr == NONE && + (ROOT_DEV == Root_NFS || ROOT_DEV == Root_CIFS)) || +#endif + ic_first_dev->next) { +#ifdef IPCONFIG_DYNAMIC + if (ic_dynamic() < 0) { + ic_close_devs(); + + /* + * I don't know why, but sometimes the + * eepro100 driver (at least) gets upset and + * doesn't work the first time it's opened. + * But then if you close it and reopen it, it + * works just fine. So we need to try that at + * least once before giving up. + * + * Also, if the root will be NFS-mounted, we + * have nowhere to go if DHCP fails. So we + * just have to keep trying forever. + * + * -- Chip + */ +#ifdef CONFIG_ROOT_NFS + if (ROOT_DEV == Root_NFS) { + pr_err("IP-Config: Retrying forever (NFS root)...\n"); + goto try_try_again; + } +#endif +#ifdef CONFIG_CIFS_ROOT + if (ROOT_DEV == Root_CIFS) { + pr_err("IP-Config: Retrying forever (CIFS root)...\n"); + goto try_try_again; + } +#endif + + if (--retries) { + pr_err("IP-Config: Reopening network devices...\n"); + goto try_try_again; + } + + /* Oh, well. At least we tried. */ + pr_err("IP-Config: Auto-configuration of network failed\n"); + return -1; + } +#else /* !DYNAMIC */ + pr_err("IP-Config: Incomplete network configuration information\n"); + ic_close_devs(); + return -1; +#endif /* IPCONFIG_DYNAMIC */ + } else { + /* Device selected manually or only one device -> use it */ + ic_dev = ic_first_dev; + } + + addr = root_nfs_parse_addr(root_server_path); + if (root_server_addr == NONE) + root_server_addr = addr; + + /* + * Use defaults wherever applicable. + */ + if (ic_defaults() < 0) + return -1; + + /* + * Record which protocol was actually used. + */ +#ifdef IPCONFIG_DYNAMIC + ic_proto_used = ic_got_reply | (ic_proto_enabled & IC_USE_DHCP); +#endif + +#ifndef IPCONFIG_SILENT + /* + * Clue in the operator. 
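+ *
+ * The messages printed below come out roughly as (all values are examples):
+ *
+ *	IP-Config: Complete:
+ *	     device=eth0, hwaddr=52:54:00:12:34:56, ipaddr=10.0.2.15, mask=255.255.255.0, gw=10.0.2.2
+ *	     host=client, domain=example.org, nis-domain=(none)
+ *	     bootserver=10.0.2.2, rootserver=10.0.2.2, rootpath=
+ *	     nameserver0=10.0.2.3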
+ */ + pr_info("IP-Config: Complete:\n"); + + pr_info(" device=%s, hwaddr=%*phC, ipaddr=%pI4, mask=%pI4, gw=%pI4\n", + ic_dev->dev->name, ic_dev->dev->addr_len, ic_dev->dev->dev_addr, + &ic_myaddr, &ic_netmask, &ic_gateway); + pr_info(" host=%s, domain=%s, nis-domain=%s\n", + utsname()->nodename, ic_domain, utsname()->domainname); + pr_info(" bootserver=%pI4, rootserver=%pI4, rootpath=%s", + &ic_servaddr, &root_server_addr, root_server_path); + if (ic_dev_mtu) + pr_cont(", mtu=%d", ic_dev_mtu); + /* Name servers (if any): */ + for (i = 0, count = 0; i < CONF_NAMESERVERS_MAX; i++) { + if (ic_nameservers[i] != NONE) { + if (i == 0) + pr_info(" nameserver%u=%pI4", + i, &ic_nameservers[i]); + else + pr_cont(", nameserver%u=%pI4", + i, &ic_nameservers[i]); + + count++; + } + if ((i + 1 == CONF_NAMESERVERS_MAX) && count > 0) + pr_cont("\n"); + } + /* NTP servers (if any): */ + for (i = 0, count = 0; i < CONF_NTP_SERVERS_MAX; i++) { + if (ic_ntp_servers[i] != NONE) { + if (i == 0) + pr_info(" ntpserver%u=%pI4", + i, &ic_ntp_servers[i]); + else + pr_cont(", ntpserver%u=%pI4", + i, &ic_ntp_servers[i]); + + count++; + } + if ((i + 1 == CONF_NTP_SERVERS_MAX) && count > 0) + pr_cont("\n"); + } +#endif /* !SILENT */ + + /* + * Close all network devices except the device we've + * autoconfigured and set up routes. + */ + if (ic_setup_if() < 0 || ic_setup_routes() < 0) + err = -1; + else + err = 0; + + ic_close_devs(); + + return err; +} + +late_initcall(ip_auto_config); + + +/* + * Decode any IP configuration options in the "ip=" or "nfsaddrs=" kernel + * command line parameter. See Documentation/admin-guide/nfs/nfsroot.rst. + */ +static int __init ic_proto_name(char *name) +{ + if (!strcmp(name, "on") || !strcmp(name, "any")) { + return 1; + } + if (!strcmp(name, "off") || !strcmp(name, "none")) { + return 0; + } +#ifdef CONFIG_IP_PNP_DHCP + else if (!strncmp(name, "dhcp", 4)) { + char *client_id; + + ic_proto_enabled &= ~IC_RARP; + client_id = strstr(name, "dhcp,"); + if (client_id) { + char *v; + + client_id = client_id + 5; + v = strchr(client_id, ','); + if (!v) + return 1; + *v = 0; + if (kstrtou8(client_id, 0, dhcp_client_identifier)) + pr_debug("DHCP: Invalid client identifier type\n"); + strncpy(dhcp_client_identifier + 1, v + 1, 251); + *v = ','; + } + return 1; + } +#endif +#ifdef CONFIG_IP_PNP_BOOTP + else if (!strcmp(name, "bootp")) { + ic_proto_enabled &= ~(IC_RARP | IC_USE_DHCP); + return 1; + } +#endif +#ifdef CONFIG_IP_PNP_RARP + else if (!strcmp(name, "rarp")) { + ic_proto_enabled &= ~(IC_BOOTP | IC_USE_DHCP); + return 1; + } +#endif +#ifdef IPCONFIG_DYNAMIC + else if (!strcmp(name, "both")) { + ic_proto_enabled &= ~IC_USE_DHCP; /* backward compat :-( */ + return 1; + } +#endif + return 0; +} + +static int __init ip_auto_config_setup(char *addrs) +{ + char *cp, *ip, *dp; + int num = 0; + + ic_set_manually = 1; + ic_enable = 1; + + /* + * If any dhcp, bootp etc options are set, leave autoconfig on + * and skip the below static IP processing. + */ + if (ic_proto_name(addrs)) + return 1; + + /* If no static IP is given, turn off autoconfig and bail. */ + if (*addrs == 0 || + strcmp(addrs, "off") == 0 || + strcmp(addrs, "none") == 0) { + ic_enable = 0; + return 1; + } + + /* Initialise all name servers and NTP servers to NONE */ + ic_nameservers_predef(); + ic_ntp_servers_predef(); + + /* Parse string for static IP assignment. 
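+ *
+ * The colon-separated fields handled by the switch below are, in order
+ * (see Documentation/admin-guide/nfs/nfsroot.rst):
+ *
+ *	ip=<client-ip>:<server-ip>:<gw-ip>:<netmask>:<hostname>:<device>:<autoconf>:<dns0-ip>:<dns1-ip>:<ntp0-ip>
+ *
+ * e.g. ip=10.0.2.15::10.0.2.2:255.255.255.0:client:eth0:off:10.0.2.3
+ * (addresses here are only examples).  Empty fields keep their defaults.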
*/ + ip = addrs; + while (ip && *ip) { + if ((cp = strchr(ip, ':'))) + *cp++ = '\0'; + if (strlen(ip) > 0) { + pr_debug("IP-Config: Parameter #%d: `%s'\n", num, ip); + switch (num) { + case 0: + if ((ic_myaddr = in_aton(ip)) == ANY) + ic_myaddr = NONE; + break; + case 1: + if ((ic_servaddr = in_aton(ip)) == ANY) + ic_servaddr = NONE; + break; + case 2: + if ((ic_gateway = in_aton(ip)) == ANY) + ic_gateway = NONE; + break; + case 3: + if ((ic_netmask = in_aton(ip)) == ANY) + ic_netmask = NONE; + break; + case 4: + if ((dp = strchr(ip, '.'))) { + *dp++ = '\0'; + strscpy(utsname()->domainname, dp, + sizeof(utsname()->domainname)); + } + strscpy(utsname()->nodename, ip, + sizeof(utsname()->nodename)); + ic_host_name_set = 1; + break; + case 5: + strscpy(user_dev_name, ip, sizeof(user_dev_name)); + break; + case 6: + if (ic_proto_name(ip) == 0 && + ic_myaddr == NONE) { + ic_enable = 0; + } + break; + case 7: + if (CONF_NAMESERVERS_MAX >= 1) { + ic_nameservers[0] = in_aton(ip); + if (ic_nameservers[0] == ANY) + ic_nameservers[0] = NONE; + } + break; + case 8: + if (CONF_NAMESERVERS_MAX >= 2) { + ic_nameservers[1] = in_aton(ip); + if (ic_nameservers[1] == ANY) + ic_nameservers[1] = NONE; + } + break; + case 9: + if (CONF_NTP_SERVERS_MAX >= 1) { + ic_ntp_servers[0] = in_aton(ip); + if (ic_ntp_servers[0] == ANY) + ic_ntp_servers[0] = NONE; + } + break; + } + } + ip = cp; + num++; + } + + return 1; +} +__setup("ip=", ip_auto_config_setup); + +static int __init nfsaddrs_config_setup(char *addrs) +{ + return ip_auto_config_setup(addrs); +} +__setup("nfsaddrs=", nfsaddrs_config_setup); + +static int __init vendor_class_identifier_setup(char *addrs) +{ + if (strscpy(vendor_class_identifier, addrs, + sizeof(vendor_class_identifier)) + >= sizeof(vendor_class_identifier)) + pr_warn("DHCP: vendorclass too long, truncated to \"%s\"\n", + vendor_class_identifier); + return 1; +} +__setup("dhcpclass=", vendor_class_identifier_setup); + +static int __init set_carrier_timeout(char *str) +{ + ssize_t ret; + + if (!str) + return 0; + + ret = kstrtouint(str, 0, &carrier_timeout); + if (ret) + return 0; + + return 1; +} +__setup("carrier_timeout=", set_carrier_timeout); diff --git a/net/ipv4/ipip.c b/net/ipv4/ipip.c new file mode 100644 index 000000000..180f9daf5 --- /dev/null +++ b/net/ipv4/ipip.c @@ -0,0 +1,662 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux NET3: IP/IP protocol decoder. + * + * Authors: + * Sam Lantinga (slouken@cs.ucdavis.edu) 02/01/95 + * + * Fixes: + * Alan Cox : Merged and made usable non modular (its so tiny its silly as + * a module taking up 2 pages). + * Alan Cox : Fixed bug with 1.3.18 and IPIP not working (now needs to set skb->h.iph) + * to keep ip_forward happy. + * Alan Cox : More fixes for 1.3.21, and firewall fix. Maybe this will work soon 8). + * Kai Schulte : Fixed #defines for IP_FIREWALL->FIREWALL + * David Woodhouse : Perform some basic ICMP handling. + * IPIP Routing without decapsulation. + * Carlos Picoto : GRE over IP support + * Alexey Kuznetsov: Reworked. Really, now it is truncated version of ipv4/ip_gre.c. + * I do not want to merge them together. + */ + +/* tunnel.c: an IP tunnel driver + + The purpose of this driver is to provide an IP tunnel through + which you can tunnel network traffic transparently across subnets. + + This was written by looking at Nick Holloway's dummy driver + Thanks for the great code! + + -Sam Lantinga (slouken@cs.ucdavis.edu) 02/01/95 + + Minor tweaks: + Cleaned up the code a little and added some pre-1.3.0 tweaks. 
+ dev->hard_header/hard_header_len changed to use no headers. + Comments/bracketing tweaked. + Made the tunnels use dev->name not tunnel: when error reporting. + Added tx_dropped stat + + -Alan Cox (alan@lxorguk.ukuu.org.uk) 21 March 95 + + Reworked: + Changed to tunnel to destination gateway in addition to the + tunnel's pointopoint address + Almost completely rewritten + Note: There is currently no firewall or ICMP handling done. + + -Sam Lantinga (slouken@cs.ucdavis.edu) 02/13/96 + +*/ + +/* Things I wish I had known when writing the tunnel driver: + + When the tunnel_xmit() function is called, the skb contains the + packet to be sent (plus a great deal of extra info), and dev + contains the tunnel device that _we_ are. + + When we are passed a packet, we are expected to fill in the + source address with our source IP address. + + What is the proper way to allocate, copy and free a buffer? + After you allocate it, it is a "0 length" chunk of memory + starting at zero. If you want to add headers to the buffer + later, you'll have to call "skb_reserve(skb, amount)" with + the amount of memory you want reserved. Then, you call + "skb_put(skb, amount)" with the amount of space you want in + the buffer. skb_put() returns a pointer to the top (#0) of + that buffer. skb->len is set to the amount of space you have + "allocated" with skb_put(). You can then write up to skb->len + bytes to that buffer. If you need more, you can call skb_put() + again with the additional amount of space you need. You can + find out how much more space you can allocate by calling + "skb_tailroom(skb)". + Now, to add header space, call "skb_push(skb, header_len)". + This creates space at the beginning of the buffer and returns + a pointer to this new space. If later you need to strip a + header from a buffer, call "skb_pull(skb, header_len)". + skb_headroom() will return how much space is left at the top + of the buffer (before the main data). Remember, this headroom + space must be reserved before the skb_put() function is called. + */ + +/* + This version of net/ipv4/ipip.c is cloned of net/ipv4/ip_gre.c + + For comments look at net/ipv4/ip_gre.c --ANK + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static bool log_ecn_error = true; +module_param(log_ecn_error, bool, 0644); +MODULE_PARM_DESC(log_ecn_error, "Log packets received with corrupted ECN"); + +static unsigned int ipip_net_id __read_mostly; + +static int ipip_tunnel_init(struct net_device *dev); +static struct rtnl_link_ops ipip_link_ops __read_mostly; + +static int ipip_err(struct sk_buff *skb, u32 info) +{ + /* All the routers (except for Linux) return only + * 8 bytes of packet payload. It means, that precise relaying of + * ICMP in the real Internet is absolutely infeasible. + */ + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn = net_generic(net, ipip_net_id); + const struct iphdr *iph = (const struct iphdr *)skb->data; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + struct ip_tunnel *t; + int err = 0; + + t = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + iph->daddr, iph->saddr, 0); + if (!t) { + err = -ENOENT; + goto out; + } + + switch (type) { + case ICMP_DEST_UNREACH: + switch (code) { + case ICMP_SR_FAILED: + /* Impossible event. 
*/ + goto out; + default: + /* All others are translated to HOST_UNREACH. + * rfc2003 contains "deep thoughts" about NET_UNREACH, + * I believe they are just ether pollution. --ANK + */ + break; + } + break; + + case ICMP_TIME_EXCEEDED: + if (code != ICMP_EXC_TTL) + goto out; + break; + + case ICMP_REDIRECT: + break; + + default: + goto out; + } + + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) { + ipv4_update_pmtu(skb, net, info, t->parms.link, iph->protocol); + goto out; + } + + if (type == ICMP_REDIRECT) { + ipv4_redirect(skb, net, t->parms.link, iph->protocol); + goto out; + } + + if (t->parms.iph.daddr == 0) { + err = -ENOENT; + goto out; + } + + if (t->parms.iph.ttl == 0 && type == ICMP_TIME_EXCEEDED) + goto out; + + if (time_before(jiffies, t->err_time + IPTUNNEL_ERR_TIMEO)) + t->err_count++; + else + t->err_count = 1; + t->err_time = jiffies; + +out: + return err; +} + +static const struct tnl_ptk_info ipip_tpi = { + /* no tunnel info required for ipip. */ + .proto = htons(ETH_P_IP), +}; + +#if IS_ENABLED(CONFIG_MPLS) +static const struct tnl_ptk_info mplsip_tpi = { + /* no tunnel info required for mplsip. */ + .proto = htons(ETH_P_MPLS_UC), +}; +#endif + +static int ipip_tunnel_rcv(struct sk_buff *skb, u8 ipproto) +{ + struct net *net = dev_net(skb->dev); + struct ip_tunnel_net *itn = net_generic(net, ipip_net_id); + struct metadata_dst *tun_dst = NULL; + struct ip_tunnel *tunnel; + const struct iphdr *iph; + + iph = ip_hdr(skb); + tunnel = ip_tunnel_lookup(itn, skb->dev->ifindex, TUNNEL_NO_KEY, + iph->saddr, iph->daddr, 0); + if (tunnel) { + const struct tnl_ptk_info *tpi; + + if (tunnel->parms.iph.protocol != ipproto && + tunnel->parms.iph.protocol != 0) + goto drop; + + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; +#if IS_ENABLED(CONFIG_MPLS) + if (ipproto == IPPROTO_MPLS) + tpi = &mplsip_tpi; + else +#endif + tpi = &ipip_tpi; + if (iptunnel_pull_header(skb, 0, tpi->proto, false)) + goto drop; + if (tunnel->collect_md) { + tun_dst = ip_tun_rx_dst(skb, 0, 0, 0); + if (!tun_dst) + return 0; + } + skb_reset_mac_header(skb); + + return ip_tunnel_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error); + } + + return -1; + +drop: + kfree_skb(skb); + return 0; +} + +static int ipip_rcv(struct sk_buff *skb) +{ + return ipip_tunnel_rcv(skb, IPPROTO_IPIP); +} + +#if IS_ENABLED(CONFIG_MPLS) +static int mplsip_rcv(struct sk_buff *skb) +{ + return ipip_tunnel_rcv(skb, IPPROTO_MPLS); +} +#endif + +/* + * This function assumes it is being called from dev_queue_xmit() + * and that skb is filled properly by that function. 
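+ *
+ * For illustration, the skb geometry described in the long comment at the
+ * top of this file corresponds to call sequences such as (sizes are only
+ * examples):
+ *
+ *	skb = alloc_skb(hlen + dlen, GFP_ATOMIC);
+ *	skb_reserve(skb, hlen);			reserve headroom
+ *	skb_put(skb, dlen);			claim space for payload data
+ *	skb_push(skb, sizeof(struct iphdr));	prepend an outer IP header
+ *
+ * The same headroom mechanism is what lets ip_tunnel_xmit() below prepend
+ * the outer IPv4 header to the packet handed to us by dev_queue_xmit().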
+ */ +static netdev_tx_t ipip_tunnel_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + const struct iphdr *tiph = &tunnel->parms.iph; + u8 ipproto; + + if (!pskb_inet_may_pull(skb)) + goto tx_error; + + switch (skb->protocol) { + case htons(ETH_P_IP): + ipproto = IPPROTO_IPIP; + break; +#if IS_ENABLED(CONFIG_MPLS) + case htons(ETH_P_MPLS_UC): + ipproto = IPPROTO_MPLS; + break; +#endif + default: + goto tx_error; + } + + if (tiph->protocol != ipproto && tiph->protocol != 0) + goto tx_error; + + if (iptunnel_handle_offloads(skb, SKB_GSO_IPXIP4)) + goto tx_error; + + skb_set_inner_ipproto(skb, ipproto); + + if (tunnel->collect_md) + ip_md_tunnel_xmit(skb, dev, ipproto, 0); + else + ip_tunnel_xmit(skb, dev, tiph, ipproto); + return NETDEV_TX_OK; + +tx_error: + kfree_skb(skb); + + dev->stats.tx_errors++; + return NETDEV_TX_OK; +} + +static bool ipip_tunnel_ioctl_verify_protocol(u8 ipproto) +{ + switch (ipproto) { + case 0: + case IPPROTO_IPIP: +#if IS_ENABLED(CONFIG_MPLS) + case IPPROTO_MPLS: +#endif + return true; + } + + return false; +} + +static int +ipip_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +{ + if (cmd == SIOCADDTUNNEL || cmd == SIOCCHGTUNNEL) { + if (p->iph.version != 4 || + !ipip_tunnel_ioctl_verify_protocol(p->iph.protocol) || + p->iph.ihl != 5 || (p->iph.frag_off & htons(~IP_DF))) + return -EINVAL; + } + + p->i_key = p->o_key = 0; + p->i_flags = p->o_flags = 0; + return ip_tunnel_ctl(dev, p, cmd); +} + +static const struct net_device_ops ipip_netdev_ops = { + .ndo_init = ipip_tunnel_init, + .ndo_uninit = ip_tunnel_uninit, + .ndo_start_xmit = ipip_tunnel_xmit, + .ndo_siocdevprivate = ip_tunnel_siocdevprivate, + .ndo_change_mtu = ip_tunnel_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_tunnel_ctl = ipip_tunnel_ctl, +}; + +#define IPIP_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_HIGHDMA | \ + NETIF_F_GSO_SOFTWARE | \ + NETIF_F_HW_CSUM) + +static void ipip_tunnel_setup(struct net_device *dev) +{ + dev->netdev_ops = &ipip_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + + dev->type = ARPHRD_TUNNEL; + dev->flags = IFF_NOARP; + dev->addr_len = 4; + dev->features |= NETIF_F_LLTX; + netif_keep_dst(dev); + + dev->features |= IPIP_FEATURES; + dev->hw_features |= IPIP_FEATURES; + ip_tunnel_setup(dev, ipip_net_id); +} + +static int ipip_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + __dev_addr_set(dev, &tunnel->parms.iph.saddr, 4); + memcpy(dev->broadcast, &tunnel->parms.iph.daddr, 4); + + tunnel->tun_hlen = 0; + tunnel->hlen = tunnel->tun_hlen + tunnel->encap_hlen; + return ip_tunnel_init(dev); +} + +static int ipip_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + u8 proto; + + if (!data || !data[IFLA_IPTUN_PROTO]) + return 0; + + proto = nla_get_u8(data[IFLA_IPTUN_PROTO]); + if (proto != IPPROTO_IPIP && proto != IPPROTO_MPLS && proto != 0) + return -EINVAL; + + return 0; +} + +static void ipip_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm *parms, bool *collect_md, + __u32 *fwmark) +{ + memset(parms, 0, sizeof(*parms)); + + parms->iph.version = 4; + parms->iph.protocol = IPPROTO_IPIP; + parms->iph.ihl = 5; + *collect_md = false; + + if (!data) + return; + + ip_tunnel_netlink_parms(data, parms); + + if (data[IFLA_IPTUN_COLLECT_METADATA]) + *collect_md = true; + + if (data[IFLA_IPTUN_FWMARK]) + *fwmark = 
nla_get_u32(data[IFLA_IPTUN_FWMARK]); +} + +static int ipip_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm p; + struct ip_tunnel_encap ipencap; + __u32 fwmark = 0; + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + int err = ip_tunnel_encap_setup(t, &ipencap); + + if (err < 0) + return err; + } + + ipip_netlink_parms(data, &p, &t->collect_md, &fwmark); + return ip_tunnel_newlink(dev, tb, &p, fwmark); +} + +static int ipip_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_parm p; + struct ip_tunnel_encap ipencap; + bool collect_md; + __u32 fwmark = t->fwmark; + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + int err = ip_tunnel_encap_setup(t, &ipencap); + + if (err < 0) + return err; + } + + ipip_netlink_parms(data, &p, &collect_md, &fwmark); + if (collect_md) + return -EINVAL; + + if (((dev->flags & IFF_POINTOPOINT) && !p.iph.daddr) || + (!(dev->flags & IFF_POINTOPOINT) && p.iph.daddr)) + return -EINVAL; + + return ip_tunnel_changelink(dev, tb, &p, fwmark); +} + +static size_t ipip_get_size(const struct net_device *dev) +{ + return + /* IFLA_IPTUN_LINK */ + nla_total_size(4) + + /* IFLA_IPTUN_LOCAL */ + nla_total_size(4) + + /* IFLA_IPTUN_REMOTE */ + nla_total_size(4) + + /* IFLA_IPTUN_TTL */ + nla_total_size(1) + + /* IFLA_IPTUN_TOS */ + nla_total_size(1) + + /* IFLA_IPTUN_PROTO */ + nla_total_size(1) + + /* IFLA_IPTUN_PMTUDISC */ + nla_total_size(1) + + /* IFLA_IPTUN_ENCAP_TYPE */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_FLAGS */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_SPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_DPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_COLLECT_METADATA */ + nla_total_size(0) + + /* IFLA_IPTUN_FWMARK */ + nla_total_size(4) + + 0; +} + +static int ipip_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_parm *parm = &tunnel->parms; + + if (nla_put_u32(skb, IFLA_IPTUN_LINK, parm->link) || + nla_put_in_addr(skb, IFLA_IPTUN_LOCAL, parm->iph.saddr) || + nla_put_in_addr(skb, IFLA_IPTUN_REMOTE, parm->iph.daddr) || + nla_put_u8(skb, IFLA_IPTUN_TTL, parm->iph.ttl) || + nla_put_u8(skb, IFLA_IPTUN_TOS, parm->iph.tos) || + nla_put_u8(skb, IFLA_IPTUN_PROTO, parm->iph.protocol) || + nla_put_u8(skb, IFLA_IPTUN_PMTUDISC, + !!(parm->iph.frag_off & htons(IP_DF))) || + nla_put_u32(skb, IFLA_IPTUN_FWMARK, tunnel->fwmark)) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_IPTUN_ENCAP_TYPE, + tunnel->encap.type) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_SPORT, + tunnel->encap.sport) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_DPORT, + tunnel->encap.dport) || + nla_put_u16(skb, IFLA_IPTUN_ENCAP_FLAGS, + tunnel->encap.flags)) + goto nla_put_failure; + + if (tunnel->collect_md) + if (nla_put_flag(skb, IFLA_IPTUN_COLLECT_METADATA)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static const struct nla_policy ipip_policy[IFLA_IPTUN_MAX + 1] = { + [IFLA_IPTUN_LINK] = { .type = NLA_U32 }, + [IFLA_IPTUN_LOCAL] = { .type = NLA_U32 }, + [IFLA_IPTUN_REMOTE] = { .type = NLA_U32 }, + [IFLA_IPTUN_TTL] = { .type = NLA_U8 }, + [IFLA_IPTUN_TOS] = { .type = NLA_U8 }, + [IFLA_IPTUN_PROTO] = { .type = NLA_U8 }, + [IFLA_IPTUN_PMTUDISC] = { .type = NLA_U8 }, + [IFLA_IPTUN_ENCAP_TYPE] = { 
.type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_FLAGS] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_SPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_DPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_COLLECT_METADATA] = { .type = NLA_FLAG }, + [IFLA_IPTUN_FWMARK] = { .type = NLA_U32 }, +}; + +static struct rtnl_link_ops ipip_link_ops __read_mostly = { + .kind = "ipip", + .maxtype = IFLA_IPTUN_MAX, + .policy = ipip_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = ipip_tunnel_setup, + .validate = ipip_tunnel_validate, + .newlink = ipip_newlink, + .changelink = ipip_changelink, + .dellink = ip_tunnel_dellink, + .get_size = ipip_get_size, + .fill_info = ipip_fill_info, + .get_link_net = ip_tunnel_get_link_net, +}; + +static struct xfrm_tunnel ipip_handler __read_mostly = { + .handler = ipip_rcv, + .err_handler = ipip_err, + .priority = 1, +}; + +#if IS_ENABLED(CONFIG_MPLS) +static struct xfrm_tunnel mplsip_handler __read_mostly = { + .handler = mplsip_rcv, + .err_handler = ipip_err, + .priority = 1, +}; +#endif + +static int __net_init ipip_init_net(struct net *net) +{ + return ip_tunnel_init_net(net, ipip_net_id, &ipip_link_ops, "tunl0"); +} + +static void __net_exit ipip_exit_batch_net(struct list_head *list_net) +{ + ip_tunnel_delete_nets(list_net, ipip_net_id, &ipip_link_ops); +} + +static struct pernet_operations ipip_net_ops = { + .init = ipip_init_net, + .exit_batch = ipip_exit_batch_net, + .id = &ipip_net_id, + .size = sizeof(struct ip_tunnel_net), +}; + +static int __init ipip_init(void) +{ + int err; + + pr_info("ipip: IPv4 and MPLS over IPv4 tunneling driver\n"); + + err = register_pernet_device(&ipip_net_ops); + if (err < 0) + return err; + err = xfrm4_tunnel_register(&ipip_handler, AF_INET); + if (err < 0) { + pr_info("%s: can't register tunnel\n", __func__); + goto xfrm_tunnel_ipip_failed; + } +#if IS_ENABLED(CONFIG_MPLS) + err = xfrm4_tunnel_register(&mplsip_handler, AF_MPLS); + if (err < 0) { + pr_info("%s: can't register tunnel\n", __func__); + goto xfrm_tunnel_mplsip_failed; + } +#endif + err = rtnl_link_register(&ipip_link_ops); + if (err < 0) + goto rtnl_link_failed; + +out: + return err; + +rtnl_link_failed: +#if IS_ENABLED(CONFIG_MPLS) + xfrm4_tunnel_deregister(&mplsip_handler, AF_MPLS); +xfrm_tunnel_mplsip_failed: + +#endif + xfrm4_tunnel_deregister(&ipip_handler, AF_INET); +xfrm_tunnel_ipip_failed: + unregister_pernet_device(&ipip_net_ops); + goto out; +} + +static void __exit ipip_fini(void) +{ + rtnl_link_unregister(&ipip_link_ops); + if (xfrm4_tunnel_deregister(&ipip_handler, AF_INET)) + pr_info("%s: can't deregister tunnel\n", __func__); +#if IS_ENABLED(CONFIG_MPLS) + if (xfrm4_tunnel_deregister(&mplsip_handler, AF_MPLS)) + pr_info("%s: can't deregister tunnel\n", __func__); +#endif + unregister_pernet_device(&ipip_net_ops); +} + +module_init(ipip_init); +module_exit(ipip_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("ipip"); +MODULE_ALIAS_NETDEV("tunl0"); diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c new file mode 100644 index 000000000..b80719747 --- /dev/null +++ b/net/ipv4/ipmr.c @@ -0,0 +1,3167 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IP multicast routing support for mrouted 3.6/3.8 + * + * (c) 1995 Alan Cox, + * Linux Consultancy and Custom Driver Development + * + * Fixes: + * Michael Chastain : Incorrect size of copying. + * Alan Cox : Added the cache manager code + * Alan Cox : Fixed the clone/copy bug and device race. + * Mike McLagan : Routing by source + * Malcolm Beattie : Buffer handling fixes. 
+ * Alexey Kuznetsov : Double buffer free and other fixes. + * SVR Anand : Fixed several multicast bugs and problems. + * Alexey Kuznetsov : Status, optimisations and more. + * Brad Parker : Better behaviour on mrouted upcall + * overflow. + * Carlos Picoto : PIMv1 Support + * Pavlin Ivanov Radoslavov: PIMv2 Registers must checksum only PIM header + * Relax this requirement to work with older peers. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct ipmr_rule { + struct fib_rule common; +}; + +struct ipmr_result { + struct mr_table *mrt; +}; + +/* Big lock, protecting vif table, mrt cache and mroute socket state. + * Note that the changes are semaphored via rtnl_lock. + */ + +static DEFINE_SPINLOCK(mrt_lock); + +static struct net_device *vif_dev_read(const struct vif_device *vif) +{ + return rcu_dereference(vif->dev); +} + +/* Multicast router control variables */ + +/* Special spinlock for queue of unresolved entries */ +static DEFINE_SPINLOCK(mfc_unres_lock); + +/* We return to original Alan's scheme. Hash table of resolved + * entries is changed only in process context and protected + * with weak lock mrt_lock. Queue of unresolved entries is protected + * with strong spinlock mfc_unres_lock. + * + * In this case data path is free of exclusive locks at all. + */ + +static struct kmem_cache *mrt_cachep __ro_after_init; + +static struct mr_table *ipmr_new_table(struct net *net, u32 id); +static void ipmr_free_table(struct mr_table *mrt); + +static void ip_mr_forward(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc_cache *cache, int local); +static int ipmr_cache_report(const struct mr_table *mrt, + struct sk_buff *pkt, vifi_t vifi, int assert); +static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc, + int cmd); +static void igmpmsg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt); +static void mroute_clean_tables(struct mr_table *mrt, int flags); +static void ipmr_expire_process(struct timer_list *t); + +#ifdef CONFIG_IP_MROUTE_MULTIPLE_TABLES +#define ipmr_for_each_table(mrt, net) \ + list_for_each_entry_rcu(mrt, &net->ipv4.mr_tables, list, \ + lockdep_rtnl_is_held() || \ + list_empty(&net->ipv4.mr_tables)) + +static struct mr_table *ipmr_mr_table_iter(struct net *net, + struct mr_table *mrt) +{ + struct mr_table *ret; + + if (!mrt) + ret = list_entry_rcu(net->ipv4.mr_tables.next, + struct mr_table, list); + else + ret = list_entry_rcu(mrt->list.next, + struct mr_table, list); + + if (&ret->list == &net->ipv4.mr_tables) + return NULL; + return ret; +} + +static struct mr_table *ipmr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + ipmr_for_each_table(mrt, net) { + if (mrt->id == id) + return mrt; + } + return NULL; +} + +static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, + struct mr_table **mrt) +{ + int err; + struct ipmr_result res; + struct fib_lookup_arg arg = { + .result = &res, + .flags = FIB_LOOKUP_NOREF, + }; + + /* update flow if oif or iif point to device enslaved to l3mdev */ + l3mdev_update_flow(net, flowi4_to_flowi(flp4)); + + err = fib_rules_lookup(net->ipv4.mr_rules_ops, + 
flowi4_to_flowi(flp4), 0, &arg); + if (err < 0) + return err; + *mrt = res.mrt; + return 0; +} + +static int ipmr_rule_action(struct fib_rule *rule, struct flowi *flp, + int flags, struct fib_lookup_arg *arg) +{ + struct ipmr_result *res = arg->result; + struct mr_table *mrt; + + switch (rule->action) { + case FR_ACT_TO_TBL: + break; + case FR_ACT_UNREACHABLE: + return -ENETUNREACH; + case FR_ACT_PROHIBIT: + return -EACCES; + case FR_ACT_BLACKHOLE: + default: + return -EINVAL; + } + + arg->table = fib_rule_get_table(rule, arg); + + mrt = ipmr_get_table(rule->fr_net, arg->table); + if (!mrt) + return -EAGAIN; + res->mrt = mrt; + return 0; +} + +static int ipmr_rule_match(struct fib_rule *rule, struct flowi *fl, int flags) +{ + return 1; +} + +static int ipmr_rule_configure(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static int ipmr_rule_compare(struct fib_rule *rule, struct fib_rule_hdr *frh, + struct nlattr **tb) +{ + return 1; +} + +static int ipmr_rule_fill(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh) +{ + frh->dst_len = 0; + frh->src_len = 0; + frh->tos = 0; + return 0; +} + +static const struct fib_rules_ops __net_initconst ipmr_rules_ops_template = { + .family = RTNL_FAMILY_IPMR, + .rule_size = sizeof(struct ipmr_rule), + .addr_size = sizeof(u32), + .action = ipmr_rule_action, + .match = ipmr_rule_match, + .configure = ipmr_rule_configure, + .compare = ipmr_rule_compare, + .fill = ipmr_rule_fill, + .nlgroup = RTNLGRP_IPV4_RULE, + .owner = THIS_MODULE, +}; + +static int __net_init ipmr_rules_init(struct net *net) +{ + struct fib_rules_ops *ops; + struct mr_table *mrt; + int err; + + ops = fib_rules_register(&ipmr_rules_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + INIT_LIST_HEAD(&net->ipv4.mr_tables); + + mrt = ipmr_new_table(net, RT_TABLE_DEFAULT); + if (IS_ERR(mrt)) { + err = PTR_ERR(mrt); + goto err1; + } + + err = fib_default_rule_add(ops, 0x7fff, RT_TABLE_DEFAULT, 0); + if (err < 0) + goto err2; + + net->ipv4.mr_rules_ops = ops; + return 0; + +err2: + rtnl_lock(); + ipmr_free_table(mrt); + rtnl_unlock(); +err1: + fib_rules_unregister(ops); + return err; +} + +static void __net_exit ipmr_rules_exit(struct net *net) +{ + struct mr_table *mrt, *next; + + ASSERT_RTNL(); + list_for_each_entry_safe(mrt, next, &net->ipv4.mr_tables, list) { + list_del(&mrt->list); + ipmr_free_table(mrt); + } + fib_rules_unregister(net->ipv4.mr_rules_ops); +} + +static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return fib_rules_dump(net, nb, RTNL_FAMILY_IPMR, extack); +} + +static unsigned int ipmr_rules_seq_read(struct net *net) +{ + return fib_rules_seq_read(net, RTNL_FAMILY_IPMR); +} + +bool ipmr_rule_default(const struct fib_rule *rule) +{ + return fib_rule_matchall(rule) && rule->table == RT_TABLE_DEFAULT; +} +EXPORT_SYMBOL(ipmr_rule_default); +#else +#define ipmr_for_each_table(mrt, net) \ + for (mrt = net->ipv4.mrt; mrt; mrt = NULL) + +static struct mr_table *ipmr_mr_table_iter(struct net *net, + struct mr_table *mrt) +{ + if (!mrt) + return net->ipv4.mrt; + return NULL; +} + +static struct mr_table *ipmr_get_table(struct net *net, u32 id) +{ + return net->ipv4.mrt; +} + +static int ipmr_fib_lookup(struct net *net, struct flowi4 *flp4, + struct mr_table **mrt) +{ + *mrt = net->ipv4.mrt; + return 0; +} + +static int __net_init ipmr_rules_init(struct net *net) +{ + struct mr_table *mrt; + + mrt 
= ipmr_new_table(net, RT_TABLE_DEFAULT); + if (IS_ERR(mrt)) + return PTR_ERR(mrt); + net->ipv4.mrt = mrt; + return 0; +} + +static void __net_exit ipmr_rules_exit(struct net *net) +{ + ASSERT_RTNL(); + ipmr_free_table(net->ipv4.mrt); + net->ipv4.mrt = NULL; +} + +static int ipmr_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static unsigned int ipmr_rules_seq_read(struct net *net) +{ + return 0; +} + +bool ipmr_rule_default(const struct fib_rule *rule) +{ + return true; +} +EXPORT_SYMBOL(ipmr_rule_default); +#endif + +static inline int ipmr_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct mfc_cache_cmp_arg *cmparg = arg->key; + const struct mfc_cache *c = ptr; + + return cmparg->mfc_mcastgrp != c->mfc_mcastgrp || + cmparg->mfc_origin != c->mfc_origin; +} + +static const struct rhashtable_params ipmr_rht_params = { + .head_offset = offsetof(struct mr_mfc, mnode), + .key_offset = offsetof(struct mfc_cache, cmparg), + .key_len = sizeof(struct mfc_cache_cmp_arg), + .nelem_hint = 3, + .obj_cmpfn = ipmr_hash_cmp, + .automatic_shrinking = true, +}; + +static void ipmr_new_table_set(struct mr_table *mrt, + struct net *net) +{ +#ifdef CONFIG_IP_MROUTE_MULTIPLE_TABLES + list_add_tail_rcu(&mrt->list, &net->ipv4.mr_tables); +#endif +} + +static struct mfc_cache_cmp_arg ipmr_mr_table_ops_cmparg_any = { + .mfc_mcastgrp = htonl(INADDR_ANY), + .mfc_origin = htonl(INADDR_ANY), +}; + +static struct mr_table_ops ipmr_mr_table_ops = { + .rht_params = &ipmr_rht_params, + .cmparg_any = &ipmr_mr_table_ops_cmparg_any, +}; + +static struct mr_table *ipmr_new_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + /* "pimreg%u" should not exceed 16 bytes (IFNAMSIZ) */ + if (id != RT_TABLE_DEFAULT && id >= 1000000000) + return ERR_PTR(-EINVAL); + + mrt = ipmr_get_table(net, id); + if (mrt) + return mrt; + + return mr_table_alloc(net, id, &ipmr_mr_table_ops, + ipmr_expire_process, ipmr_new_table_set); +} + +static void ipmr_free_table(struct mr_table *mrt) +{ + del_timer_sync(&mrt->ipmr_expire_timer); + mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC | + MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC); + rhltable_destroy(&mrt->mfc_hash); + kfree(mrt); +} + +/* Service routines creating virtual interfaces: DVMRP tunnels and PIMREG */ + +/* Initialize ipmr pimreg/tunnel in_device */ +static bool ipmr_init_vif_indev(const struct net_device *dev) +{ + struct in_device *in_dev; + + ASSERT_RTNL(); + + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) + return false; + ipv4_devconf_setall(in_dev); + neigh_parms_data_state_setall(in_dev->arp_parms); + IPV4_DEVCONF(in_dev->cnf, RP_FILTER) = 0; + + return true; +} + +static struct net_device *ipmr_new_tunnel(struct net *net, struct vifctl *v) +{ + struct net_device *tunnel_dev, *new_dev; + struct ip_tunnel_parm p = { }; + int err; + + tunnel_dev = __dev_get_by_name(net, "tunl0"); + if (!tunnel_dev) + goto out; + + p.iph.daddr = v->vifc_rmt_addr.s_addr; + p.iph.saddr = v->vifc_lcl_addr.s_addr; + p.iph.version = 4; + p.iph.ihl = 5; + p.iph.protocol = IPPROTO_IPIP; + sprintf(p.name, "dvmrp%d", v->vifc_vifi); + + if (!tunnel_dev->netdev_ops->ndo_tunnel_ctl) + goto out; + err = tunnel_dev->netdev_ops->ndo_tunnel_ctl(tunnel_dev, &p, + SIOCADDTUNNEL); + if (err) + goto out; + + new_dev = __dev_get_by_name(net, p.name); + if (!new_dev) + goto out; + + new_dev->flags |= IFF_MULTICAST; + if (!ipmr_init_vif_indev(new_dev)) + goto out_unregister; + if (dev_open(new_dev, NULL)) + goto 
out_unregister; + dev_hold(new_dev); + err = dev_set_allmulti(new_dev, 1); + if (err) { + dev_close(new_dev); + tunnel_dev->netdev_ops->ndo_tunnel_ctl(tunnel_dev, &p, + SIOCDELTUNNEL); + dev_put(new_dev); + new_dev = ERR_PTR(err); + } + return new_dev; + +out_unregister: + unregister_netdevice(new_dev); +out: + return ERR_PTR(-ENOBUFS); +} + +#if defined(CONFIG_IP_PIMSM_V1) || defined(CONFIG_IP_PIMSM_V2) +static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct net *net = dev_net(dev); + struct mr_table *mrt; + struct flowi4 fl4 = { + .flowi4_oif = dev->ifindex, + .flowi4_iif = skb->skb_iif ? : LOOPBACK_IFINDEX, + .flowi4_mark = skb->mark, + }; + int err; + + err = ipmr_fib_lookup(net, &fl4, &mrt); + if (err < 0) { + kfree_skb(skb); + return err; + } + + dev->stats.tx_bytes += skb->len; + dev->stats.tx_packets++; + rcu_read_lock(); + + /* Pairs with WRITE_ONCE() in vif_add() and vif_delete() */ + ipmr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num), + IGMPMSG_WHOLEPKT); + + rcu_read_unlock(); + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int reg_vif_get_iflink(const struct net_device *dev) +{ + return 0; +} + +static const struct net_device_ops reg_vif_netdev_ops = { + .ndo_start_xmit = reg_vif_xmit, + .ndo_get_iflink = reg_vif_get_iflink, +}; + +static void reg_vif_setup(struct net_device *dev) +{ + dev->type = ARPHRD_PIMREG; + dev->mtu = ETH_DATA_LEN - sizeof(struct iphdr) - 8; + dev->flags = IFF_NOARP; + dev->netdev_ops = ®_vif_netdev_ops; + dev->needs_free_netdev = true; + dev->features |= NETIF_F_NETNS_LOCAL; +} + +static struct net_device *ipmr_reg_vif(struct net *net, struct mr_table *mrt) +{ + struct net_device *dev; + char name[IFNAMSIZ]; + + if (mrt->id == RT_TABLE_DEFAULT) + sprintf(name, "pimreg"); + else + sprintf(name, "pimreg%u", mrt->id); + + dev = alloc_netdev(0, name, NET_NAME_UNKNOWN, reg_vif_setup); + + if (!dev) + return NULL; + + dev_net_set(dev, net); + + if (register_netdevice(dev)) { + free_netdev(dev); + return NULL; + } + + if (!ipmr_init_vif_indev(dev)) + goto failure; + if (dev_open(dev, NULL)) + goto failure; + + dev_hold(dev); + + return dev; + +failure: + unregister_netdevice(dev); + return NULL; +} + +/* called with rcu_read_lock() */ +static int __pim_rcv(struct mr_table *mrt, struct sk_buff *skb, + unsigned int pimlen) +{ + struct net_device *reg_dev = NULL; + struct iphdr *encap; + int vif_num; + + encap = (struct iphdr *)(skb_transport_header(skb) + pimlen); + /* Check that: + * a. packet is really sent to a multicast group + * b. packet is not a NULL-REGISTER + * c. 
packet is not truncated + */ + if (!ipv4_is_multicast(encap->daddr) || + encap->tot_len == 0 || + ntohs(encap->tot_len) + pimlen > skb->len) + return 1; + + /* Pairs with WRITE_ONCE() in vif_add()/vid_delete() */ + vif_num = READ_ONCE(mrt->mroute_reg_vif_num); + if (vif_num >= 0) + reg_dev = vif_dev_read(&mrt->vif_table[vif_num]); + if (!reg_dev) + return 1; + + skb->mac_header = skb->network_header; + skb_pull(skb, (u8 *)encap - skb->data); + skb_reset_network_header(skb); + skb->protocol = htons(ETH_P_IP); + skb->ip_summed = CHECKSUM_NONE; + + skb_tunnel_rx(skb, reg_dev, dev_net(reg_dev)); + + netif_rx(skb); + + return NET_RX_SUCCESS; +} +#else +static struct net_device *ipmr_reg_vif(struct net *net, struct mr_table *mrt) +{ + return NULL; +} +#endif + +static int call_ipmr_vif_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct vif_device *vif, + struct net_device *vif_dev, + vifi_t vif_index, u32 tb_id) +{ + return mr_call_vif_notifiers(net, RTNL_FAMILY_IPMR, event_type, + vif, vif_dev, vif_index, tb_id, + &net->ipv4.ipmr_seq); +} + +static int call_ipmr_mfc_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct mfc_cache *mfc, u32 tb_id) +{ + return mr_call_mfc_notifiers(net, RTNL_FAMILY_IPMR, event_type, + &mfc->_c, tb_id, &net->ipv4.ipmr_seq); +} + +/** + * vif_delete - Delete a VIF entry + * @mrt: Table to delete from + * @vifi: VIF identifier to delete + * @notify: Set to 1, if the caller is a notifier_call + * @head: if unregistering the VIF, place it on this queue + */ +static int vif_delete(struct mr_table *mrt, int vifi, int notify, + struct list_head *head) +{ + struct net *net = read_pnet(&mrt->net); + struct vif_device *v; + struct net_device *dev; + struct in_device *in_dev; + + if (vifi < 0 || vifi >= mrt->maxvif) + return -EADDRNOTAVAIL; + + v = &mrt->vif_table[vifi]; + + dev = rtnl_dereference(v->dev); + if (!dev) + return -EADDRNOTAVAIL; + + spin_lock(&mrt_lock); + call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_DEL, v, dev, + vifi, mrt->id); + RCU_INIT_POINTER(v->dev, NULL); + + if (vifi == mrt->mroute_reg_vif_num) { + /* Pairs with READ_ONCE() in ipmr_cache_report() and reg_vif_xmit() */ + WRITE_ONCE(mrt->mroute_reg_vif_num, -1); + } + if (vifi + 1 == mrt->maxvif) { + int tmp; + + for (tmp = vifi - 1; tmp >= 0; tmp--) { + if (VIF_EXISTS(mrt, tmp)) + break; + } + WRITE_ONCE(mrt->maxvif, tmp + 1); + } + + spin_unlock(&mrt_lock); + + dev_set_allmulti(dev, -1); + + in_dev = __in_dev_get_rtnl(dev); + if (in_dev) { + IPV4_DEVCONF(in_dev->cnf, MC_FORWARDING)--; + inet_netconf_notify_devconf(dev_net(dev), RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + dev->ifindex, &in_dev->cnf); + ip_rt_multicast_event(in_dev); + } + + if (v->flags & (VIFF_TUNNEL | VIFF_REGISTER) && !notify) + unregister_netdevice_queue(dev, head); + + netdev_put(dev, &v->dev_tracker); + return 0; +} + +static void ipmr_cache_free_rcu(struct rcu_head *head) +{ + struct mr_mfc *c = container_of(head, struct mr_mfc, rcu); + + kmem_cache_free(mrt_cachep, (struct mfc_cache *)c); +} + +static void ipmr_cache_free(struct mfc_cache *c) +{ + call_rcu(&c->_c.rcu, ipmr_cache_free_rcu); +} + +/* Destroy an unresolved cache entry, killing queued skbs + * and reporting error to netlink readers. 
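+ * Queued skbs that actually carry a netlink request (recognised by the
+ * ip_hdr(skb)->version == 0 convention used when they were queued) are
+ * answered with an NLMSG_ERROR of -ETIMEDOUT; ordinary queued packets are
+ * simply freed.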
+ */ +static void ipmr_destroy_unres(struct mr_table *mrt, struct mfc_cache *c) +{ + struct net *net = read_pnet(&mrt->net); + struct sk_buff *skb; + struct nlmsgerr *e; + + atomic_dec(&mrt->cache_resolve_queue_len); + + while ((skb = skb_dequeue(&c->_c.mfc_un.unres.unresolved))) { + if (ip_hdr(skb)->version == 0) { + struct nlmsghdr *nlh = skb_pull(skb, + sizeof(struct iphdr)); + nlh->nlmsg_type = NLMSG_ERROR; + nlh->nlmsg_len = nlmsg_msg_size(sizeof(struct nlmsgerr)); + skb_trim(skb, nlh->nlmsg_len); + e = nlmsg_data(nlh); + e->error = -ETIMEDOUT; + memset(&e->msg, 0, sizeof(e->msg)); + + rtnl_unicast(skb, net, NETLINK_CB(skb).portid); + } else { + kfree_skb(skb); + } + } + + ipmr_cache_free(c); +} + +/* Timer process for the unresolved queue. */ +static void ipmr_expire_process(struct timer_list *t) +{ + struct mr_table *mrt = from_timer(mrt, t, ipmr_expire_timer); + struct mr_mfc *c, *next; + unsigned long expires; + unsigned long now; + + if (!spin_trylock(&mfc_unres_lock)) { + mod_timer(&mrt->ipmr_expire_timer, jiffies+HZ/10); + return; + } + + if (list_empty(&mrt->mfc_unres_queue)) + goto out; + + now = jiffies; + expires = 10*HZ; + + list_for_each_entry_safe(c, next, &mrt->mfc_unres_queue, list) { + if (time_after(c->mfc_un.unres.expires, now)) { + unsigned long interval = c->mfc_un.unres.expires - now; + if (interval < expires) + expires = interval; + continue; + } + + list_del(&c->list); + mroute_netlink_event(mrt, (struct mfc_cache *)c, RTM_DELROUTE); + ipmr_destroy_unres(mrt, (struct mfc_cache *)c); + } + + if (!list_empty(&mrt->mfc_unres_queue)) + mod_timer(&mrt->ipmr_expire_timer, jiffies + expires); + +out: + spin_unlock(&mfc_unres_lock); +} + +/* Fill oifs list. It is called under locked mrt_lock. */ +static void ipmr_update_thresholds(struct mr_table *mrt, struct mr_mfc *cache, + unsigned char *ttls) +{ + int vifi; + + cache->mfc_un.res.minvif = MAXVIFS; + cache->mfc_un.res.maxvif = 0; + memset(cache->mfc_un.res.ttls, 255, MAXVIFS); + + for (vifi = 0; vifi < mrt->maxvif; vifi++) { + if (VIF_EXISTS(mrt, vifi) && + ttls[vifi] && ttls[vifi] < 255) { + cache->mfc_un.res.ttls[vifi] = ttls[vifi]; + if (cache->mfc_un.res.minvif > vifi) + cache->mfc_un.res.minvif = vifi; + if (cache->mfc_un.res.maxvif <= vifi) + cache->mfc_un.res.maxvif = vifi + 1; + } + } + cache->mfc_un.res.lastuse = jiffies; +} + +static int vif_add(struct net *net, struct mr_table *mrt, + struct vifctl *vifc, int mrtsock) +{ + struct netdev_phys_item_id ppid = { }; + int vifi = vifc->vifc_vifi; + struct vif_device *v = &mrt->vif_table[vifi]; + struct net_device *dev; + struct in_device *in_dev; + int err; + + /* Is vif busy ? 
*/ + if (VIF_EXISTS(mrt, vifi)) + return -EADDRINUSE; + + switch (vifc->vifc_flags) { + case VIFF_REGISTER: + if (!ipmr_pimsm_enabled()) + return -EINVAL; + /* Special Purpose VIF in PIM + * All the packets will be sent to the daemon + */ + if (mrt->mroute_reg_vif_num >= 0) + return -EADDRINUSE; + dev = ipmr_reg_vif(net, mrt); + if (!dev) + return -ENOBUFS; + err = dev_set_allmulti(dev, 1); + if (err) { + unregister_netdevice(dev); + dev_put(dev); + return err; + } + break; + case VIFF_TUNNEL: + dev = ipmr_new_tunnel(net, vifc); + if (IS_ERR(dev)) + return PTR_ERR(dev); + break; + case VIFF_USE_IFINDEX: + case 0: + if (vifc->vifc_flags == VIFF_USE_IFINDEX) { + dev = dev_get_by_index(net, vifc->vifc_lcl_ifindex); + if (dev && !__in_dev_get_rtnl(dev)) { + dev_put(dev); + return -EADDRNOTAVAIL; + } + } else { + dev = ip_dev_find(net, vifc->vifc_lcl_addr.s_addr); + } + if (!dev) + return -EADDRNOTAVAIL; + err = dev_set_allmulti(dev, 1); + if (err) { + dev_put(dev); + return err; + } + break; + default: + return -EINVAL; + } + + in_dev = __in_dev_get_rtnl(dev); + if (!in_dev) { + dev_put(dev); + return -EADDRNOTAVAIL; + } + IPV4_DEVCONF(in_dev->cnf, MC_FORWARDING)++; + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_MC_FORWARDING, + dev->ifindex, &in_dev->cnf); + ip_rt_multicast_event(in_dev); + + /* Fill in the VIF structures */ + vif_device_init(v, dev, vifc->vifc_rate_limit, + vifc->vifc_threshold, + vifc->vifc_flags | (!mrtsock ? VIFF_STATIC : 0), + (VIFF_TUNNEL | VIFF_REGISTER)); + + err = dev_get_port_parent_id(dev, &ppid, true); + if (err == 0) { + memcpy(v->dev_parent_id.id, ppid.id, ppid.id_len); + v->dev_parent_id.id_len = ppid.id_len; + } else { + v->dev_parent_id.id_len = 0; + } + + v->local = vifc->vifc_lcl_addr.s_addr; + v->remote = vifc->vifc_rmt_addr.s_addr; + + /* And finish update writing critical data */ + spin_lock(&mrt_lock); + rcu_assign_pointer(v->dev, dev); + netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC); + if (v->flags & VIFF_REGISTER) { + /* Pairs with READ_ONCE() in ipmr_cache_report() and reg_vif_xmit() */ + WRITE_ONCE(mrt->mroute_reg_vif_num, vifi); + } + if (vifi+1 > mrt->maxvif) + WRITE_ONCE(mrt->maxvif, vifi + 1); + spin_unlock(&mrt_lock); + call_ipmr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, v, dev, + vifi, mrt->id); + return 0; +} + +/* called with rcu_read_lock() */ +static struct mfc_cache *ipmr_cache_find(struct mr_table *mrt, + __be32 origin, + __be32 mcastgrp) +{ + struct mfc_cache_cmp_arg arg = { + .mfc_mcastgrp = mcastgrp, + .mfc_origin = origin + }; + + return mr_mfc_find(mrt, &arg); +} + +/* Look for a (*,G) entry */ +static struct mfc_cache *ipmr_cache_find_any(struct mr_table *mrt, + __be32 mcastgrp, int vifi) +{ + struct mfc_cache_cmp_arg arg = { + .mfc_mcastgrp = mcastgrp, + .mfc_origin = htonl(INADDR_ANY) + }; + + if (mcastgrp == htonl(INADDR_ANY)) + return mr_mfc_find_any_parent(mrt, vifi); + return mr_mfc_find_any(mrt, vifi, &arg); +} + +/* Look for a (S,G,iif) entry if parent != -1 */ +static struct mfc_cache *ipmr_cache_find_parent(struct mr_table *mrt, + __be32 origin, __be32 mcastgrp, + int parent) +{ + struct mfc_cache_cmp_arg arg = { + .mfc_mcastgrp = mcastgrp, + .mfc_origin = origin, + }; + + return mr_mfc_find_parent(mrt, &arg, parent); +} + +/* Allocate a multicast cache entry */ +static struct mfc_cache *ipmr_cache_alloc(void) +{ + struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_KERNEL); + + if (c) { + c->_c.mfc_un.res.last_assert = jiffies - MFC_ASSERT_THRESH - 1; + c->_c.mfc_un.res.minvif = MAXVIFS; + 
c->_c.free = ipmr_cache_free_rcu; + refcount_set(&c->_c.mfc_un.res.refcount, 1); + } + return c; +} + +static struct mfc_cache *ipmr_cache_alloc_unres(void) +{ + struct mfc_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_ATOMIC); + + if (c) { + skb_queue_head_init(&c->_c.mfc_un.unres.unresolved); + c->_c.mfc_un.unres.expires = jiffies + 10 * HZ; + } + return c; +} + +/* A cache entry has gone into a resolved state from queued */ +static void ipmr_cache_resolve(struct net *net, struct mr_table *mrt, + struct mfc_cache *uc, struct mfc_cache *c) +{ + struct sk_buff *skb; + struct nlmsgerr *e; + + /* Play the pending entries through our router */ + while ((skb = __skb_dequeue(&uc->_c.mfc_un.unres.unresolved))) { + if (ip_hdr(skb)->version == 0) { + struct nlmsghdr *nlh = skb_pull(skb, + sizeof(struct iphdr)); + + if (mr_fill_mroute(mrt, skb, &c->_c, + nlmsg_data(nlh)) > 0) { + nlh->nlmsg_len = skb_tail_pointer(skb) - + (u8 *)nlh; + } else { + nlh->nlmsg_type = NLMSG_ERROR; + nlh->nlmsg_len = nlmsg_msg_size(sizeof(struct nlmsgerr)); + skb_trim(skb, nlh->nlmsg_len); + e = nlmsg_data(nlh); + e->error = -EMSGSIZE; + memset(&e->msg, 0, sizeof(e->msg)); + } + + rtnl_unicast(skb, net, NETLINK_CB(skb).portid); + } else { + rcu_read_lock(); + ip_mr_forward(net, mrt, skb->dev, skb, c, 0); + rcu_read_unlock(); + } + } +} + +/* Bounce a cache query up to mrouted and netlink. + * + * Called under rcu_read_lock(). + */ +static int ipmr_cache_report(const struct mr_table *mrt, + struct sk_buff *pkt, vifi_t vifi, int assert) +{ + const int ihl = ip_hdrlen(pkt); + struct sock *mroute_sk; + struct igmphdr *igmp; + struct igmpmsg *msg; + struct sk_buff *skb; + int ret; + + mroute_sk = rcu_dereference(mrt->mroute_sk); + if (!mroute_sk) + return -EINVAL; + + if (assert == IGMPMSG_WHOLEPKT || assert == IGMPMSG_WRVIFWHOLE) + skb = skb_realloc_headroom(pkt, sizeof(struct iphdr)); + else + skb = alloc_skb(128, GFP_ATOMIC); + + if (!skb) + return -ENOBUFS; + + if (assert == IGMPMSG_WHOLEPKT || assert == IGMPMSG_WRVIFWHOLE) { + /* Ugly, but we have no choice with this interface. + * Duplicate old header, fix ihl, length etc. 
+ * And all this only to mangle msg->im_msgtype and + * to set msg->im_mbz to "mbz" :-) + */ + skb_push(skb, sizeof(struct iphdr)); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + msg = (struct igmpmsg *)skb_network_header(skb); + memcpy(msg, skb_network_header(pkt), sizeof(struct iphdr)); + msg->im_msgtype = assert; + msg->im_mbz = 0; + if (assert == IGMPMSG_WRVIFWHOLE) { + msg->im_vif = vifi; + msg->im_vif_hi = vifi >> 8; + } else { + /* Pairs with WRITE_ONCE() in vif_add() and vif_delete() */ + int vif_num = READ_ONCE(mrt->mroute_reg_vif_num); + + msg->im_vif = vif_num; + msg->im_vif_hi = vif_num >> 8; + } + ip_hdr(skb)->ihl = sizeof(struct iphdr) >> 2; + ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(pkt)->tot_len) + + sizeof(struct iphdr)); + } else { + /* Copy the IP header */ + skb_set_network_header(skb, skb->len); + skb_put(skb, ihl); + skb_copy_to_linear_data(skb, pkt->data, ihl); + /* Flag to the kernel this is a route add */ + ip_hdr(skb)->protocol = 0; + msg = (struct igmpmsg *)skb_network_header(skb); + msg->im_vif = vifi; + msg->im_vif_hi = vifi >> 8; + ipv4_pktinfo_prepare(mroute_sk, pkt); + memcpy(skb->cb, pkt->cb, sizeof(skb->cb)); + /* Add our header */ + igmp = skb_put(skb, sizeof(struct igmphdr)); + igmp->type = assert; + msg->im_msgtype = assert; + igmp->code = 0; + ip_hdr(skb)->tot_len = htons(skb->len); /* Fix the length */ + skb->transport_header = skb->network_header; + } + + igmpmsg_netlink_event(mrt, skb); + + /* Deliver to mrouted */ + ret = sock_queue_rcv_skb(mroute_sk, skb); + + if (ret < 0) { + net_warn_ratelimited("mroute: pending queue full, dropping entries\n"); + kfree_skb(skb); + } + + return ret; +} + +/* Queue a packet for resolution. It gets locked cache entry! */ +/* Called under rcu_read_lock() */ +static int ipmr_cache_unresolved(struct mr_table *mrt, vifi_t vifi, + struct sk_buff *skb, struct net_device *dev) +{ + const struct iphdr *iph = ip_hdr(skb); + struct mfc_cache *c; + bool found = false; + int err; + + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry(c, &mrt->mfc_unres_queue, _c.list) { + if (c->mfc_mcastgrp == iph->daddr && + c->mfc_origin == iph->saddr) { + found = true; + break; + } + } + + if (!found) { + /* Create a new entry if allowable */ + c = ipmr_cache_alloc_unres(); + if (!c) { + spin_unlock_bh(&mfc_unres_lock); + + kfree_skb(skb); + return -ENOBUFS; + } + + /* Fill in the new cache entry */ + c->_c.mfc_parent = -1; + c->mfc_origin = iph->saddr; + c->mfc_mcastgrp = iph->daddr; + + /* Reflect first query at mrouted. 
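+ *
+ * The IGMPMSG_NOCACHE upcall asks the userspace daemon (e.g. mrouted or
+ * pimd) to install an MFC entry for this (S,G); meanwhile the packet is
+ * queued on the unresolved list (only a handful of packets per entry)
+ * until the daemon answers or the ~10 second resolve timer expires.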
*/ + err = ipmr_cache_report(mrt, skb, vifi, IGMPMSG_NOCACHE); + + if (err < 0) { + /* If the report failed throw the cache entry + out - Brad Parker + */ + spin_unlock_bh(&mfc_unres_lock); + + ipmr_cache_free(c); + kfree_skb(skb); + return err; + } + + atomic_inc(&mrt->cache_resolve_queue_len); + list_add(&c->_c.list, &mrt->mfc_unres_queue); + mroute_netlink_event(mrt, c, RTM_NEWROUTE); + + if (atomic_read(&mrt->cache_resolve_queue_len) == 1) + mod_timer(&mrt->ipmr_expire_timer, + c->_c.mfc_un.unres.expires); + } + + /* See if we can append the packet */ + if (c->_c.mfc_un.unres.unresolved.qlen > 3) { + kfree_skb(skb); + err = -ENOBUFS; + } else { + if (dev) { + skb->dev = dev; + skb->skb_iif = dev->ifindex; + } + skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb); + err = 0; + } + + spin_unlock_bh(&mfc_unres_lock); + return err; +} + +/* MFC cache manipulation by user space mroute daemon */ + +static int ipmr_mfc_delete(struct mr_table *mrt, struct mfcctl *mfc, int parent) +{ + struct net *net = read_pnet(&mrt->net); + struct mfc_cache *c; + + /* The entries are added/deleted only under RTNL */ + rcu_read_lock(); + c = ipmr_cache_find_parent(mrt, mfc->mfcc_origin.s_addr, + mfc->mfcc_mcastgrp.s_addr, parent); + rcu_read_unlock(); + if (!c) + return -ENOENT; + rhltable_remove(&mrt->mfc_hash, &c->_c.mnode, ipmr_rht_params); + list_del_rcu(&c->_c.list); + call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_DEL, c, mrt->id); + mroute_netlink_event(mrt, c, RTM_DELROUTE); + mr_cache_put(&c->_c); + + return 0; +} + +static int ipmr_mfc_add(struct net *net, struct mr_table *mrt, + struct mfcctl *mfc, int mrtsock, int parent) +{ + struct mfc_cache *uc, *c; + struct mr_mfc *_uc; + bool found; + int ret; + + if (mfc->mfcc_parent >= MAXVIFS) + return -ENFILE; + + /* The entries are added/deleted only under RTNL */ + rcu_read_lock(); + c = ipmr_cache_find_parent(mrt, mfc->mfcc_origin.s_addr, + mfc->mfcc_mcastgrp.s_addr, parent); + rcu_read_unlock(); + if (c) { + spin_lock(&mrt_lock); + c->_c.mfc_parent = mfc->mfcc_parent; + ipmr_update_thresholds(mrt, &c->_c, mfc->mfcc_ttls); + if (!mrtsock) + c->_c.mfc_flags |= MFC_STATIC; + spin_unlock(&mrt_lock); + call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_REPLACE, c, + mrt->id); + mroute_netlink_event(mrt, c, RTM_NEWROUTE); + return 0; + } + + if (mfc->mfcc_mcastgrp.s_addr != htonl(INADDR_ANY) && + !ipv4_is_multicast(mfc->mfcc_mcastgrp.s_addr)) + return -EINVAL; + + c = ipmr_cache_alloc(); + if (!c) + return -ENOMEM; + + c->mfc_origin = mfc->mfcc_origin.s_addr; + c->mfc_mcastgrp = mfc->mfcc_mcastgrp.s_addr; + c->_c.mfc_parent = mfc->mfcc_parent; + ipmr_update_thresholds(mrt, &c->_c, mfc->mfcc_ttls); + if (!mrtsock) + c->_c.mfc_flags |= MFC_STATIC; + + ret = rhltable_insert_key(&mrt->mfc_hash, &c->cmparg, &c->_c.mnode, + ipmr_rht_params); + if (ret) { + pr_err("ipmr: rhtable insert error %d\n", ret); + ipmr_cache_free(c); + return ret; + } + list_add_tail_rcu(&c->_c.list, &mrt->mfc_cache_list); + /* Check to see if we resolved a queued list. If so we + * need to send on the frames and tidy up. 
+ */ + found = false; + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry(_uc, &mrt->mfc_unres_queue, list) { + uc = (struct mfc_cache *)_uc; + if (uc->mfc_origin == c->mfc_origin && + uc->mfc_mcastgrp == c->mfc_mcastgrp) { + list_del(&_uc->list); + atomic_dec(&mrt->cache_resolve_queue_len); + found = true; + break; + } + } + if (list_empty(&mrt->mfc_unres_queue)) + del_timer(&mrt->ipmr_expire_timer); + spin_unlock_bh(&mfc_unres_lock); + + if (found) { + ipmr_cache_resolve(net, mrt, uc, c); + ipmr_cache_free(uc); + } + call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_ADD, c, mrt->id); + mroute_netlink_event(mrt, c, RTM_NEWROUTE); + return 0; +} + +/* Close the multicast socket, and clear the vif tables etc */ +static void mroute_clean_tables(struct mr_table *mrt, int flags) +{ + struct net *net = read_pnet(&mrt->net); + struct mr_mfc *c, *tmp; + struct mfc_cache *cache; + LIST_HEAD(list); + int i; + + /* Shut down all active vif entries */ + if (flags & (MRT_FLUSH_VIFS | MRT_FLUSH_VIFS_STATIC)) { + for (i = 0; i < mrt->maxvif; i++) { + if (((mrt->vif_table[i].flags & VIFF_STATIC) && + !(flags & MRT_FLUSH_VIFS_STATIC)) || + (!(mrt->vif_table[i].flags & VIFF_STATIC) && !(flags & MRT_FLUSH_VIFS))) + continue; + vif_delete(mrt, i, 0, &list); + } + unregister_netdevice_many(&list); + } + + /* Wipe the cache */ + if (flags & (MRT_FLUSH_MFC | MRT_FLUSH_MFC_STATIC)) { + list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { + if (((c->mfc_flags & MFC_STATIC) && !(flags & MRT_FLUSH_MFC_STATIC)) || + (!(c->mfc_flags & MFC_STATIC) && !(flags & MRT_FLUSH_MFC))) + continue; + rhltable_remove(&mrt->mfc_hash, &c->mnode, ipmr_rht_params); + list_del_rcu(&c->list); + cache = (struct mfc_cache *)c; + call_ipmr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_DEL, cache, + mrt->id); + mroute_netlink_event(mrt, cache, RTM_DELROUTE); + mr_cache_put(c); + } + } + + if (flags & MRT_FLUSH_MFC) { + if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { + list_del(&c->list); + cache = (struct mfc_cache *)c; + mroute_netlink_event(mrt, cache, RTM_DELROUTE); + ipmr_destroy_unres(mrt, cache); + } + spin_unlock_bh(&mfc_unres_lock); + } + } +} + +/* called from ip_ra_control(), before an RCU grace period, + * we don't need to call synchronize_rcu() here + */ +static void mrtsock_destruct(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct mr_table *mrt; + + rtnl_lock(); + ipmr_for_each_table(mrt, net) { + if (sk == rtnl_dereference(mrt->mroute_sk)) { + IPV4_DEVCONF_ALL(net, MC_FORWARDING)--; + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv4.devconf_all); + RCU_INIT_POINTER(mrt->mroute_sk, NULL); + mroute_clean_tables(mrt, MRT_FLUSH_VIFS | MRT_FLUSH_MFC); + } + } + rtnl_unlock(); +} + +/* Socket options and virtual interface manipulation. The whole + * virtual interface system is a complete heap, but unfortunately + * that's how BSD mrouted happens to think. Maybe one day with a proper + * MOSPF/PIM router set up we can clean this up. 
+ */ + +int ip_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval, + unsigned int optlen) +{ + struct net *net = sock_net(sk); + int val, ret = 0, parent = 0; + struct mr_table *mrt; + struct vifctl vif; + struct mfcctl mfc; + bool do_wrvifwhole; + u32 uval; + + /* There's one exception to the lock - MRT_DONE which needs to unlock */ + rtnl_lock(); + if (sk->sk_type != SOCK_RAW || + inet_sk(sk)->inet_num != IPPROTO_IGMP) { + ret = -EOPNOTSUPP; + goto out_unlock; + } + + mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); + if (!mrt) { + ret = -ENOENT; + goto out_unlock; + } + if (optname != MRT_INIT) { + if (sk != rcu_access_pointer(mrt->mroute_sk) && + !ns_capable(net->user_ns, CAP_NET_ADMIN)) { + ret = -EACCES; + goto out_unlock; + } + } + + switch (optname) { + case MRT_INIT: + if (optlen != sizeof(int)) { + ret = -EINVAL; + break; + } + if (rtnl_dereference(mrt->mroute_sk)) { + ret = -EADDRINUSE; + break; + } + + ret = ip_ra_control(sk, 1, mrtsock_destruct); + if (ret == 0) { + rcu_assign_pointer(mrt->mroute_sk, sk); + IPV4_DEVCONF_ALL(net, MC_FORWARDING)++; + inet_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv4.devconf_all); + } + break; + case MRT_DONE: + if (sk != rcu_access_pointer(mrt->mroute_sk)) { + ret = -EACCES; + } else { + /* We need to unlock here because mrtsock_destruct takes + * care of rtnl itself and we can't change that due to + * the IP_ROUTER_ALERT setsockopt which runs without it. + */ + rtnl_unlock(); + ret = ip_ra_control(sk, 0, NULL); + goto out; + } + break; + case MRT_ADD_VIF: + case MRT_DEL_VIF: + if (optlen != sizeof(vif)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&vif, optval, sizeof(vif))) { + ret = -EFAULT; + break; + } + if (vif.vifc_vifi >= MAXVIFS) { + ret = -ENFILE; + break; + } + if (optname == MRT_ADD_VIF) { + ret = vif_add(net, mrt, &vif, + sk == rtnl_dereference(mrt->mroute_sk)); + } else { + ret = vif_delete(mrt, vif.vifc_vifi, 0, NULL); + } + break; + /* Manipulate the forwarding caches. These live + * in a sort of kernel/user symbiosis. + */ + case MRT_ADD_MFC: + case MRT_DEL_MFC: + parent = -1; + fallthrough; + case MRT_ADD_MFC_PROXY: + case MRT_DEL_MFC_PROXY: + if (optlen != sizeof(mfc)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&mfc, optval, sizeof(mfc))) { + ret = -EFAULT; + break; + } + if (parent == 0) + parent = mfc.mfcc_parent; + if (optname == MRT_DEL_MFC || optname == MRT_DEL_MFC_PROXY) + ret = ipmr_mfc_delete(mrt, &mfc, parent); + else + ret = ipmr_mfc_add(net, mrt, &mfc, + sk == rtnl_dereference(mrt->mroute_sk), + parent); + break; + case MRT_FLUSH: + if (optlen != sizeof(val)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&val, optval, sizeof(val))) { + ret = -EFAULT; + break; + } + mroute_clean_tables(mrt, val); + break; + /* Control PIM assert. 
*/ + case MRT_ASSERT: + if (optlen != sizeof(val)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&val, optval, sizeof(val))) { + ret = -EFAULT; + break; + } + mrt->mroute_do_assert = val; + break; + case MRT_PIM: + if (!ipmr_pimsm_enabled()) { + ret = -ENOPROTOOPT; + break; + } + if (optlen != sizeof(val)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&val, optval, sizeof(val))) { + ret = -EFAULT; + break; + } + + do_wrvifwhole = (val == IGMPMSG_WRVIFWHOLE); + val = !!val; + if (val != mrt->mroute_do_pim) { + mrt->mroute_do_pim = val; + mrt->mroute_do_assert = val; + mrt->mroute_do_wrvifwhole = do_wrvifwhole; + } + break; + case MRT_TABLE: + if (!IS_BUILTIN(CONFIG_IP_MROUTE_MULTIPLE_TABLES)) { + ret = -ENOPROTOOPT; + break; + } + if (optlen != sizeof(uval)) { + ret = -EINVAL; + break; + } + if (copy_from_sockptr(&uval, optval, sizeof(uval))) { + ret = -EFAULT; + break; + } + + if (sk == rtnl_dereference(mrt->mroute_sk)) { + ret = -EBUSY; + } else { + mrt = ipmr_new_table(net, uval); + if (IS_ERR(mrt)) + ret = PTR_ERR(mrt); + else + raw_sk(sk)->ipmr_table = uval; + } + break; + /* Spurious command, or MRT_VERSION which you cannot set. */ + default: + ret = -ENOPROTOOPT; + } +out_unlock: + rtnl_unlock(); +out: + return ret; +} + +/* Getsock opt support for the multicast routing system. */ +int ip_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval, + sockptr_t optlen) +{ + int olr; + int val; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + if (sk->sk_type != SOCK_RAW || + inet_sk(sk)->inet_num != IPPROTO_IGMP) + return -EOPNOTSUPP; + + mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? : RT_TABLE_DEFAULT); + if (!mrt) + return -ENOENT; + + switch (optname) { + case MRT_VERSION: + val = 0x0305; + break; + case MRT_PIM: + if (!ipmr_pimsm_enabled()) + return -ENOPROTOOPT; + val = mrt->mroute_do_pim; + break; + case MRT_ASSERT: + val = mrt->mroute_do_assert; + break; + default: + return -ENOPROTOOPT; + } + + if (copy_from_sockptr(&olr, optlen, sizeof(int))) + return -EFAULT; + olr = min_t(unsigned int, olr, sizeof(int)); + if (olr < 0) + return -EINVAL; + if (copy_to_sockptr(optlen, &olr, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &val, olr)) + return -EFAULT; + return 0; +} + +/* The IP multicast ioctl support routines. */ +int ipmr_ioctl(struct sock *sk, int cmd, void __user *arg) +{ + struct sioc_sg_req sr; + struct sioc_vif_req vr; + struct vif_device *vif; + struct mfc_cache *c; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? 
: RT_TABLE_DEFAULT); + if (!mrt) + return -ENOENT; + + switch (cmd) { + case SIOCGETVIFCNT: + if (copy_from_user(&vr, arg, sizeof(vr))) + return -EFAULT; + if (vr.vifi >= mrt->maxvif) + return -EINVAL; + vr.vifi = array_index_nospec(vr.vifi, mrt->maxvif); + rcu_read_lock(); + vif = &mrt->vif_table[vr.vifi]; + if (VIF_EXISTS(mrt, vr.vifi)) { + vr.icount = READ_ONCE(vif->pkt_in); + vr.ocount = READ_ONCE(vif->pkt_out); + vr.ibytes = READ_ONCE(vif->bytes_in); + vr.obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); + + if (copy_to_user(arg, &vr, sizeof(vr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + case SIOCGETSGCNT: + if (copy_from_user(&sr, arg, sizeof(sr))) + return -EFAULT; + + rcu_read_lock(); + c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr); + if (c) { + sr.pktcnt = c->_c.mfc_un.res.pkt; + sr.bytecnt = c->_c.mfc_un.res.bytes; + sr.wrong_if = c->_c.mfc_un.res.wrong_if; + rcu_read_unlock(); + + if (copy_to_user(arg, &sr, sizeof(sr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + default: + return -ENOIOCTLCMD; + } +} + +#ifdef CONFIG_COMPAT +struct compat_sioc_sg_req { + struct in_addr src; + struct in_addr grp; + compat_ulong_t pktcnt; + compat_ulong_t bytecnt; + compat_ulong_t wrong_if; +}; + +struct compat_sioc_vif_req { + vifi_t vifi; /* Which iface */ + compat_ulong_t icount; + compat_ulong_t ocount; + compat_ulong_t ibytes; + compat_ulong_t obytes; +}; + +int ipmr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) +{ + struct compat_sioc_sg_req sr; + struct compat_sioc_vif_req vr; + struct vif_device *vif; + struct mfc_cache *c; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + mrt = ipmr_get_table(net, raw_sk(sk)->ipmr_table ? 
: RT_TABLE_DEFAULT); + if (!mrt) + return -ENOENT; + + switch (cmd) { + case SIOCGETVIFCNT: + if (copy_from_user(&vr, arg, sizeof(vr))) + return -EFAULT; + if (vr.vifi >= mrt->maxvif) + return -EINVAL; + vr.vifi = array_index_nospec(vr.vifi, mrt->maxvif); + rcu_read_lock(); + vif = &mrt->vif_table[vr.vifi]; + if (VIF_EXISTS(mrt, vr.vifi)) { + vr.icount = READ_ONCE(vif->pkt_in); + vr.ocount = READ_ONCE(vif->pkt_out); + vr.ibytes = READ_ONCE(vif->bytes_in); + vr.obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); + + if (copy_to_user(arg, &vr, sizeof(vr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + case SIOCGETSGCNT: + if (copy_from_user(&sr, arg, sizeof(sr))) + return -EFAULT; + + rcu_read_lock(); + c = ipmr_cache_find(mrt, sr.src.s_addr, sr.grp.s_addr); + if (c) { + sr.pktcnt = c->_c.mfc_un.res.pkt; + sr.bytecnt = c->_c.mfc_un.res.bytes; + sr.wrong_if = c->_c.mfc_un.res.wrong_if; + rcu_read_unlock(); + + if (copy_to_user(arg, &sr, sizeof(sr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + default: + return -ENOIOCTLCMD; + } +} +#endif + +static int ipmr_device_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct mr_table *mrt; + struct vif_device *v; + int ct; + + if (event != NETDEV_UNREGISTER) + return NOTIFY_DONE; + + ipmr_for_each_table(mrt, net) { + v = &mrt->vif_table[0]; + for (ct = 0; ct < mrt->maxvif; ct++, v++) { + if (rcu_access_pointer(v->dev) == dev) + vif_delete(mrt, ct, 1, NULL); + } + } + return NOTIFY_DONE; +} + +static struct notifier_block ip_mr_notifier = { + .notifier_call = ipmr_device_event, +}; + +/* Encapsulate a packet by attaching a valid IPIP header to it. + * This avoids tunnel drivers and other mess and gives us the speed so + * important for multicast video. 
+ */ +static void ip_encap(struct net *net, struct sk_buff *skb, + __be32 saddr, __be32 daddr) +{ + struct iphdr *iph; + const struct iphdr *old_iph = ip_hdr(skb); + + skb_push(skb, sizeof(struct iphdr)); + skb->transport_header = skb->network_header; + skb_reset_network_header(skb); + iph = ip_hdr(skb); + + iph->version = 4; + iph->tos = old_iph->tos; + iph->ttl = old_iph->ttl; + iph->frag_off = 0; + iph->daddr = daddr; + iph->saddr = saddr; + iph->protocol = IPPROTO_IPIP; + iph->ihl = 5; + iph->tot_len = htons(skb->len); + ip_select_ident(net, skb, NULL); + ip_send_check(iph); + + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + nf_reset_ct(skb); +} + +static inline int ipmr_forward_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct ip_options *opt = &(IPCB(skb)->opt); + + IP_INC_STATS(net, IPSTATS_MIB_OUTFORWDATAGRAMS); + IP_ADD_STATS(net, IPSTATS_MIB_OUTOCTETS, skb->len); + + if (unlikely(opt->optlen)) + ip_forward_options(skb); + + return dst_output(net, sk, skb); +} + +#ifdef CONFIG_NET_SWITCHDEV +static bool ipmr_forward_offloaded(struct sk_buff *skb, struct mr_table *mrt, + int in_vifi, int out_vifi) +{ + struct vif_device *out_vif = &mrt->vif_table[out_vifi]; + struct vif_device *in_vif = &mrt->vif_table[in_vifi]; + + if (!skb->offload_l3_fwd_mark) + return false; + if (!out_vif->dev_parent_id.id_len || !in_vif->dev_parent_id.id_len) + return false; + return netdev_phys_item_id_same(&out_vif->dev_parent_id, + &in_vif->dev_parent_id); +} +#else +static bool ipmr_forward_offloaded(struct sk_buff *skb, struct mr_table *mrt, + int in_vifi, int out_vifi) +{ + return false; +} +#endif + +/* Processing handlers for ipmr_forward, under rcu_read_lock() */ + +static void ipmr_queue_xmit(struct net *net, struct mr_table *mrt, + int in_vifi, struct sk_buff *skb, int vifi) +{ + const struct iphdr *iph = ip_hdr(skb); + struct vif_device *vif = &mrt->vif_table[vifi]; + struct net_device *vif_dev; + struct net_device *dev; + struct rtable *rt; + struct flowi4 fl4; + int encap = 0; + + vif_dev = vif_dev_read(vif); + if (!vif_dev) + goto out_free; + + if (vif->flags & VIFF_REGISTER) { + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); + vif_dev->stats.tx_bytes += skb->len; + vif_dev->stats.tx_packets++; + ipmr_cache_report(mrt, skb, vifi, IGMPMSG_WHOLEPKT); + goto out_free; + } + + if (ipmr_forward_offloaded(skb, mrt, in_vifi, vifi)) + goto out_free; + + if (vif->flags & VIFF_TUNNEL) { + rt = ip_route_output_ports(net, &fl4, NULL, + vif->remote, vif->local, + 0, 0, + IPPROTO_IPIP, + RT_TOS(iph->tos), vif->link); + if (IS_ERR(rt)) + goto out_free; + encap = sizeof(struct iphdr); + } else { + rt = ip_route_output_ports(net, &fl4, NULL, iph->daddr, 0, + 0, 0, + IPPROTO_IPIP, + RT_TOS(iph->tos), vif->link); + if (IS_ERR(rt)) + goto out_free; + } + + dev = rt->dst.dev; + + if (skb->len+encap > dst_mtu(&rt->dst) && (ntohs(iph->frag_off) & IP_DF)) { + /* Do not fragment multicasts. Alas, IPv4 does not + * allow to send ICMP, so that packets will disappear + * to blackhole. 
+ */ + IP_INC_STATS(net, IPSTATS_MIB_FRAGFAILS); + ip_rt_put(rt); + goto out_free; + } + + encap += LL_RESERVED_SPACE(dev) + rt->dst.header_len; + + if (skb_cow(skb, encap)) { + ip_rt_put(rt); + goto out_free; + } + + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); + + skb_dst_drop(skb); + skb_dst_set(skb, &rt->dst); + ip_decrease_ttl(ip_hdr(skb)); + + /* FIXME: forward and output firewalls used to be called here. + * What do we do with netfilter? -- RR + */ + if (vif->flags & VIFF_TUNNEL) { + ip_encap(net, skb, vif->local, vif->remote); + /* FIXME: extra output firewall step used to be here. --RR */ + vif_dev->stats.tx_packets++; + vif_dev->stats.tx_bytes += skb->len; + } + + IPCB(skb)->flags |= IPSKB_FORWARDED; + + /* RFC1584 teaches, that DVMRP/PIM router must deliver packets locally + * not only before forwarding, but after forwarding on all output + * interfaces. It is clear, if mrouter runs a multicasting + * program, it should receive packets not depending to what interface + * program is joined. + * If we will not make it, the program will have to join on all + * interfaces. On the other hand, multihoming host (or router, but + * not mrouter) cannot join to more than one interface - it will + * result in receiving multiple packets. + */ + NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, + net, NULL, skb, skb->dev, dev, + ipmr_forward_finish); + return; + +out_free: + kfree_skb(skb); +} + +/* Called with mrt_lock or rcu_read_lock() */ +static int ipmr_find_vif(const struct mr_table *mrt, struct net_device *dev) +{ + int ct; + /* Pairs with WRITE_ONCE() in vif_delete()/vif_add() */ + for (ct = READ_ONCE(mrt->maxvif) - 1; ct >= 0; ct--) { + if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev) + break; + } + return ct; +} + +/* "local" means that we should preserve one skb (for local delivery) */ +/* Called uner rcu_read_lock() */ +static void ip_mr_forward(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc_cache *c, int local) +{ + int true_vifi = ipmr_find_vif(mrt, dev); + int psend = -1; + int vif, ct; + + vif = c->_c.mfc_parent; + c->_c.mfc_un.res.pkt++; + c->_c.mfc_un.res.bytes += skb->len; + c->_c.mfc_un.res.lastuse = jiffies; + + if (c->mfc_origin == htonl(INADDR_ANY) && true_vifi >= 0) { + struct mfc_cache *cache_proxy; + + /* For an (*,G) entry, we only check that the incoming + * interface is part of the static tree. + */ + cache_proxy = mr_mfc_find_any_parent(mrt, vif); + if (cache_proxy && + cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255) + goto forward; + } + + /* Wrong interface: drop packet and (maybe) send PIM assert. */ + if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) { + if (rt_is_output_route(skb_rtable(skb))) { + /* It is our own packet, looped back. + * Very complicated situation... + * + * The best workaround until routing daemons will be + * fixed is not to redistribute packet, if it was + * send through wrong interface. It means, that + * multicast applications WILL NOT work for + * (S,G), which have default multicast route pointing + * to wrong oif. In any case, it is not a good + * idea to use multicasting applications on router. + */ + goto dont_forward; + } + + c->_c.mfc_un.res.wrong_if++; + + if (true_vifi >= 0 && mrt->mroute_do_assert && + /* pimsm uses asserts, when switching from RPT to SPT, + * so that we cannot check that packet arrived on an oif. + * It is bad, but otherwise we would need to move pretty + * large chunk of pimd to kernel. Ough... 
--ANK + */ + (mrt->mroute_do_pim || + c->_c.mfc_un.res.ttls[true_vifi] < 255) && + time_after(jiffies, + c->_c.mfc_un.res.last_assert + + MFC_ASSERT_THRESH)) { + c->_c.mfc_un.res.last_assert = jiffies; + ipmr_cache_report(mrt, skb, true_vifi, IGMPMSG_WRONGVIF); + if (mrt->mroute_do_wrvifwhole) + ipmr_cache_report(mrt, skb, true_vifi, + IGMPMSG_WRVIFWHOLE); + } + goto dont_forward; + } + +forward: + WRITE_ONCE(mrt->vif_table[vif].pkt_in, + mrt->vif_table[vif].pkt_in + 1); + WRITE_ONCE(mrt->vif_table[vif].bytes_in, + mrt->vif_table[vif].bytes_in + skb->len); + + /* Forward the frame */ + if (c->mfc_origin == htonl(INADDR_ANY) && + c->mfc_mcastgrp == htonl(INADDR_ANY)) { + if (true_vifi >= 0 && + true_vifi != c->_c.mfc_parent && + ip_hdr(skb)->ttl > + c->_c.mfc_un.res.ttls[c->_c.mfc_parent]) { + /* It's an (*,*) entry and the packet is not coming from + * the upstream: forward the packet to the upstream + * only. + */ + psend = c->_c.mfc_parent; + goto last_forward; + } + goto dont_forward; + } + for (ct = c->_c.mfc_un.res.maxvif - 1; + ct >= c->_c.mfc_un.res.minvif; ct--) { + /* For (*,G) entry, don't forward to the incoming interface */ + if ((c->mfc_origin != htonl(INADDR_ANY) || + ct != true_vifi) && + ip_hdr(skb)->ttl > c->_c.mfc_un.res.ttls[ct]) { + if (psend != -1) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + + if (skb2) + ipmr_queue_xmit(net, mrt, true_vifi, + skb2, psend); + } + psend = ct; + } + } +last_forward: + if (psend != -1) { + if (local) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + + if (skb2) + ipmr_queue_xmit(net, mrt, true_vifi, skb2, + psend); + } else { + ipmr_queue_xmit(net, mrt, true_vifi, skb, psend); + return; + } + } + +dont_forward: + if (!local) + kfree_skb(skb); +} + +static struct mr_table *ipmr_rt_fib_lookup(struct net *net, struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct iphdr *iph = ip_hdr(skb); + struct flowi4 fl4 = { + .daddr = iph->daddr, + .saddr = iph->saddr, + .flowi4_tos = RT_TOS(iph->tos), + .flowi4_oif = (rt_is_output_route(rt) ? + skb->dev->ifindex : 0), + .flowi4_iif = (rt_is_output_route(rt) ? + LOOPBACK_IFINDEX : + skb->dev->ifindex), + .flowi4_mark = skb->mark, + }; + struct mr_table *mrt; + int err; + + err = ipmr_fib_lookup(net, &fl4, &mrt); + if (err) + return ERR_PTR(err); + return mrt; +} + +/* Multicast packets for forwarding arrive here + * Called with rcu_read_lock(); + */ +int ip_mr_input(struct sk_buff *skb) +{ + struct mfc_cache *cache; + struct net *net = dev_net(skb->dev); + int local = skb_rtable(skb)->rt_flags & RTCF_LOCAL; + struct mr_table *mrt; + struct net_device *dev; + + /* skb->dev passed in is the loX master dev for vrfs. + * As there are no vifs associated with loopback devices, + * get the proper interface that does have a vif associated with it. + */ + dev = skb->dev; + if (netif_is_l3_master(skb->dev)) { + dev = dev_get_by_index_rcu(net, IPCB(skb)->iif); + if (!dev) { + kfree_skb(skb); + return -ENODEV; + } + } + + /* Packet is looped back after forward, it should not be + * forwarded second time, but still can be delivered locally. 
+ */ + if (IPCB(skb)->flags & IPSKB_FORWARDED) + goto dont_forward; + + mrt = ipmr_rt_fib_lookup(net, skb); + if (IS_ERR(mrt)) { + kfree_skb(skb); + return PTR_ERR(mrt); + } + if (!local) { + if (IPCB(skb)->opt.router_alert) { + if (ip_call_ra_chain(skb)) + return 0; + } else if (ip_hdr(skb)->protocol == IPPROTO_IGMP) { + /* IGMPv1 (and broken IGMPv2 implementations sort of + * Cisco IOS <= 11.2(8)) do not put router alert + * option to IGMP packets destined to routable + * groups. It is very bad, because it means + * that we can forward NO IGMP messages. + */ + struct sock *mroute_sk; + + mroute_sk = rcu_dereference(mrt->mroute_sk); + if (mroute_sk) { + nf_reset_ct(skb); + raw_rcv(mroute_sk, skb); + return 0; + } + } + } + + /* already under rcu_read_lock() */ + cache = ipmr_cache_find(mrt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr); + if (!cache) { + int vif = ipmr_find_vif(mrt, dev); + + if (vif >= 0) + cache = ipmr_cache_find_any(mrt, ip_hdr(skb)->daddr, + vif); + } + + /* No usable cache entry */ + if (!cache) { + int vif; + + if (local) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + ip_local_deliver(skb); + if (!skb2) + return -ENOBUFS; + skb = skb2; + } + + vif = ipmr_find_vif(mrt, dev); + if (vif >= 0) + return ipmr_cache_unresolved(mrt, vif, skb, dev); + kfree_skb(skb); + return -ENODEV; + } + + ip_mr_forward(net, mrt, dev, skb, cache, local); + + if (local) + return ip_local_deliver(skb); + + return 0; + +dont_forward: + if (local) + return ip_local_deliver(skb); + kfree_skb(skb); + return 0; +} + +#ifdef CONFIG_IP_PIMSM_V1 +/* Handle IGMP messages of PIMv1 */ +int pim_rcv_v1(struct sk_buff *skb) +{ + struct igmphdr *pim; + struct net *net = dev_net(skb->dev); + struct mr_table *mrt; + + if (!pskb_may_pull(skb, sizeof(*pim) + sizeof(struct iphdr))) + goto drop; + + pim = igmp_hdr(skb); + + mrt = ipmr_rt_fib_lookup(net, skb); + if (IS_ERR(mrt)) + goto drop; + if (!mrt->mroute_do_pim || + pim->group != PIM_V1_VERSION || pim->code != PIM_V1_REGISTER) + goto drop; + + if (__pim_rcv(mrt, skb, sizeof(*pim))) { +drop: + kfree_skb(skb); + } + return 0; +} +#endif + +#ifdef CONFIG_IP_PIMSM_V2 +static int pim_rcv(struct sk_buff *skb) +{ + struct pimreghdr *pim; + struct net *net = dev_net(skb->dev); + struct mr_table *mrt; + + if (!pskb_may_pull(skb, sizeof(*pim) + sizeof(struct iphdr))) + goto drop; + + pim = (struct pimreghdr *)skb_transport_header(skb); + if (pim->type != ((PIM_VERSION << 4) | (PIM_TYPE_REGISTER)) || + (pim->flags & PIM_NULL_REGISTER) || + (ip_compute_csum((void *)pim, sizeof(*pim)) != 0 && + csum_fold(skb_checksum(skb, 0, skb->len, 0)))) + goto drop; + + mrt = ipmr_rt_fib_lookup(net, skb); + if (IS_ERR(mrt)) + goto drop; + if (__pim_rcv(mrt, skb, sizeof(*pim))) { +drop: + kfree_skb(skb); + } + return 0; +} +#endif + +int ipmr_get_route(struct net *net, struct sk_buff *skb, + __be32 saddr, __be32 daddr, + struct rtmsg *rtm, u32 portid) +{ + struct mfc_cache *cache; + struct mr_table *mrt; + int err; + + mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) + return -ENOENT; + + rcu_read_lock(); + cache = ipmr_cache_find(mrt, saddr, daddr); + if (!cache && skb->dev) { + int vif = ipmr_find_vif(mrt, skb->dev); + + if (vif >= 0) + cache = ipmr_cache_find_any(mrt, daddr, vif); + } + if (!cache) { + struct sk_buff *skb2; + struct iphdr *iph; + struct net_device *dev; + int vif = -1; + + dev = skb->dev; + if (dev) + vif = ipmr_find_vif(mrt, dev); + if (vif < 0) { + rcu_read_unlock(); + return -ENODEV; + } + + skb2 = skb_realloc_headroom(skb, sizeof(struct iphdr)); + 
if (!skb2) { + rcu_read_unlock(); + return -ENOMEM; + } + + NETLINK_CB(skb2).portid = portid; + skb_push(skb2, sizeof(struct iphdr)); + skb_reset_network_header(skb2); + iph = ip_hdr(skb2); + iph->ihl = sizeof(struct iphdr) >> 2; + iph->saddr = saddr; + iph->daddr = daddr; + iph->version = 0; + err = ipmr_cache_unresolved(mrt, vif, skb2, dev); + rcu_read_unlock(); + return err; + } + + err = mr_fill_mroute(mrt, skb, &cache->_c, rtm); + rcu_read_unlock(); + return err; +} + +static int ipmr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mfc_cache *c, int cmd, + int flags) +{ + struct nlmsghdr *nlh; + struct rtmsg *rtm; + int err; + + nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rtm), flags); + if (!nlh) + return -EMSGSIZE; + + rtm = nlmsg_data(nlh); + rtm->rtm_family = RTNL_FAMILY_IPMR; + rtm->rtm_dst_len = 32; + rtm->rtm_src_len = 32; + rtm->rtm_tos = 0; + rtm->rtm_table = mrt->id; + if (nla_put_u32(skb, RTA_TABLE, mrt->id)) + goto nla_put_failure; + rtm->rtm_type = RTN_MULTICAST; + rtm->rtm_scope = RT_SCOPE_UNIVERSE; + if (c->_c.mfc_flags & MFC_STATIC) + rtm->rtm_protocol = RTPROT_STATIC; + else + rtm->rtm_protocol = RTPROT_MROUTED; + rtm->rtm_flags = 0; + + if (nla_put_in_addr(skb, RTA_SRC, c->mfc_origin) || + nla_put_in_addr(skb, RTA_DST, c->mfc_mcastgrp)) + goto nla_put_failure; + err = mr_fill_mroute(mrt, skb, &c->_c, rtm); + /* do not break the dump if cache is unresolved */ + if (err < 0 && err != -ENOENT) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int _ipmr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mr_mfc *c, int cmd, + int flags) +{ + return ipmr_fill_mroute(mrt, skb, portid, seq, (struct mfc_cache *)c, + cmd, flags); +} + +static size_t mroute_msgsize(bool unresolved, int maxvif) +{ + size_t len = + NLMSG_ALIGN(sizeof(struct rtmsg)) + + nla_total_size(4) /* RTA_TABLE */ + + nla_total_size(4) /* RTA_SRC */ + + nla_total_size(4) /* RTA_DST */ + ; + + if (!unresolved) + len = len + + nla_total_size(4) /* RTA_IIF */ + + nla_total_size(0) /* RTA_MULTIPATH */ + + maxvif * NLA_ALIGN(sizeof(struct rtnexthop)) + /* RTA_MFC_STATS */ + + nla_total_size_64bit(sizeof(struct rta_mfc_stats)) + ; + + return len; +} + +static void mroute_netlink_event(struct mr_table *mrt, struct mfc_cache *mfc, + int cmd) +{ + struct net *net = read_pnet(&mrt->net); + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(mroute_msgsize(mfc->_c.mfc_parent >= MAXVIFS, + mrt->maxvif), + GFP_ATOMIC); + if (!skb) + goto errout; + + err = ipmr_fill_mroute(mrt, skb, 0, 0, mfc, cmd, 0); + if (err < 0) + goto errout; + + rtnl_notify(skb, net, 0, RTNLGRP_IPV4_MROUTE, NULL, GFP_ATOMIC); + return; + +errout: + kfree_skb(skb); + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV4_MROUTE, err); +} + +static size_t igmpmsg_netlink_msgsize(size_t payloadlen) +{ + size_t len = + NLMSG_ALIGN(sizeof(struct rtgenmsg)) + + nla_total_size(1) /* IPMRA_CREPORT_MSGTYPE */ + + nla_total_size(4) /* IPMRA_CREPORT_VIF_ID */ + + nla_total_size(4) /* IPMRA_CREPORT_SRC_ADDR */ + + nla_total_size(4) /* IPMRA_CREPORT_DST_ADDR */ + + nla_total_size(4) /* IPMRA_CREPORT_TABLE */ + /* IPMRA_CREPORT_PKT */ + + nla_total_size(payloadlen) + ; + + return len; +} + +static void igmpmsg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt) +{ + struct net *net = read_pnet(&mrt->net); + struct nlmsghdr *nlh; + struct rtgenmsg *rtgenm; + struct igmpmsg *msg; + 
struct sk_buff *skb; + struct nlattr *nla; + int payloadlen; + + payloadlen = pkt->len - sizeof(struct igmpmsg); + msg = (struct igmpmsg *)skb_network_header(pkt); + + skb = nlmsg_new(igmpmsg_netlink_msgsize(payloadlen), GFP_ATOMIC); + if (!skb) + goto errout; + + nlh = nlmsg_put(skb, 0, 0, RTM_NEWCACHEREPORT, + sizeof(struct rtgenmsg), 0); + if (!nlh) + goto errout; + rtgenm = nlmsg_data(nlh); + rtgenm->rtgen_family = RTNL_FAMILY_IPMR; + if (nla_put_u8(skb, IPMRA_CREPORT_MSGTYPE, msg->im_msgtype) || + nla_put_u32(skb, IPMRA_CREPORT_VIF_ID, msg->im_vif | (msg->im_vif_hi << 8)) || + nla_put_in_addr(skb, IPMRA_CREPORT_SRC_ADDR, + msg->im_src.s_addr) || + nla_put_in_addr(skb, IPMRA_CREPORT_DST_ADDR, + msg->im_dst.s_addr) || + nla_put_u32(skb, IPMRA_CREPORT_TABLE, mrt->id)) + goto nla_put_failure; + + nla = nla_reserve(skb, IPMRA_CREPORT_PKT, payloadlen); + if (!nla || skb_copy_bits(pkt, sizeof(struct igmpmsg), + nla_data(nla), payloadlen)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + + rtnl_notify(skb, net, 0, RTNLGRP_IPV4_MROUTE_R, NULL, GFP_ATOMIC); + return; + +nla_put_failure: + nlmsg_cancel(skb, nlh); +errout: + kfree_skb(skb); + rtnl_set_sk_err(net, RTNLGRP_IPV4_MROUTE_R, -ENOBUFS); +} + +static int ipmr_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for multicast route get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || + rtm->rtm_tos || rtm->rtm_table || rtm->rtm_protocol || + rtm->rtm_scope || rtm->rtm_type || rtm->rtm_flags) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for multicast route get request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG(extack, "ipv4: rtm_src_len and rtm_dst_len must be 32 for IPv4"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_SRC: + case RTA_DST: + case RTA_TABLE: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in multicast route get request"); + return -EINVAL; + } + } + + return 0; +} + +static int ipmr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[RTA_MAX + 1]; + struct sk_buff *skb = NULL; + struct mfc_cache *cache; + struct mr_table *mrt; + __be32 src, grp; + u32 tableid; + int err; + + err = ipmr_rtm_valid_getroute_req(in_skb, nlh, tb, extack); + if (err < 0) + goto errout; + + src = tb[RTA_SRC] ? nla_get_in_addr(tb[RTA_SRC]) : 0; + grp = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; + tableid = tb[RTA_TABLE] ? nla_get_u32(tb[RTA_TABLE]) : 0; + + mrt = ipmr_get_table(net, tableid ? 
tableid : RT_TABLE_DEFAULT); + if (!mrt) { + err = -ENOENT; + goto errout_free; + } + + /* entries are added/deleted only under RTNL */ + rcu_read_lock(); + cache = ipmr_cache_find(mrt, src, grp); + rcu_read_unlock(); + if (!cache) { + err = -ENOENT; + goto errout_free; + } + + skb = nlmsg_new(mroute_msgsize(false, mrt->maxvif), GFP_KERNEL); + if (!skb) { + err = -ENOBUFS; + goto errout_free; + } + + err = ipmr_fill_mroute(mrt, skb, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, cache, + RTM_NEWROUTE, 0); + if (err < 0) + goto errout_free; + + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); + +errout: + return err; + +errout_free: + kfree_skb(skb); + goto errout; +} + +static int ipmr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct fib_dump_filter filter = {}; + int err; + + if (cb->strict_check) { + err = ip_valid_fib_dump_req(sock_net(skb->sk), cb->nlh, + &filter, cb); + if (err < 0) + return err; + } + + if (filter.table_id) { + struct mr_table *mrt; + + mrt = ipmr_get_table(sock_net(skb->sk), filter.table_id); + if (!mrt) { + if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IPMR) + return skb->len; + + NL_SET_ERR_MSG(cb->extack, "ipv4: MR table does not exist"); + return -ENOENT; + } + err = mr_table_dump(mrt, skb, cb, _ipmr_fill_mroute, + &mfc_unres_lock, &filter); + return skb->len ? : err; + } + + return mr_rtm_dumproute(skb, cb, ipmr_mr_table_iter, + _ipmr_fill_mroute, &mfc_unres_lock, &filter); +} + +static const struct nla_policy rtm_ipmr_policy[RTA_MAX + 1] = { + [RTA_SRC] = { .type = NLA_U32 }, + [RTA_DST] = { .type = NLA_U32 }, + [RTA_IIF] = { .type = NLA_U32 }, + [RTA_TABLE] = { .type = NLA_U32 }, + [RTA_MULTIPATH] = { .len = sizeof(struct rtnexthop) }, +}; + +static bool ipmr_rtm_validate_proto(unsigned char rtm_protocol) +{ + switch (rtm_protocol) { + case RTPROT_STATIC: + case RTPROT_MROUTED: + return true; + } + return false; +} + +static int ipmr_nla_get_ttls(const struct nlattr *nla, struct mfcctl *mfcc) +{ + struct rtnexthop *rtnh = nla_data(nla); + int remaining = nla_len(nla), vifi = 0; + + while (rtnh_ok(rtnh, remaining)) { + mfcc->mfcc_ttls[vifi] = rtnh->rtnh_hops; + if (++vifi == MAXVIFS) + break; + rtnh = rtnh_next(rtnh, &remaining); + } + + return remaining > 0 ? 
-EINVAL : vifi; +} + +/* returns < 0 on error, 0 for ADD_MFC and 1 for ADD_MFC_PROXY */ +static int rtm_to_ipmr_mfcc(struct net *net, struct nlmsghdr *nlh, + struct mfcctl *mfcc, int *mrtsock, + struct mr_table **mrtret, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = NULL; + u32 tblid = RT_TABLE_DEFAULT; + struct mr_table *mrt; + struct nlattr *attr; + struct rtmsg *rtm; + int ret, rem; + + ret = nlmsg_validate_deprecated(nlh, sizeof(*rtm), RTA_MAX, + rtm_ipmr_policy, extack); + if (ret < 0) + goto out; + rtm = nlmsg_data(nlh); + + ret = -EINVAL; + if (rtm->rtm_family != RTNL_FAMILY_IPMR || rtm->rtm_dst_len != 32 || + rtm->rtm_type != RTN_MULTICAST || + rtm->rtm_scope != RT_SCOPE_UNIVERSE || + !ipmr_rtm_validate_proto(rtm->rtm_protocol)) + goto out; + + memset(mfcc, 0, sizeof(*mfcc)); + mfcc->mfcc_parent = -1; + ret = 0; + nlmsg_for_each_attr(attr, nlh, sizeof(struct rtmsg), rem) { + switch (nla_type(attr)) { + case RTA_SRC: + mfcc->mfcc_origin.s_addr = nla_get_be32(attr); + break; + case RTA_DST: + mfcc->mfcc_mcastgrp.s_addr = nla_get_be32(attr); + break; + case RTA_IIF: + dev = __dev_get_by_index(net, nla_get_u32(attr)); + if (!dev) { + ret = -ENODEV; + goto out; + } + break; + case RTA_MULTIPATH: + if (ipmr_nla_get_ttls(attr, mfcc) < 0) { + ret = -EINVAL; + goto out; + } + break; + case RTA_PREFSRC: + ret = 1; + break; + case RTA_TABLE: + tblid = nla_get_u32(attr); + break; + } + } + mrt = ipmr_get_table(net, tblid); + if (!mrt) { + ret = -ENOENT; + goto out; + } + *mrtret = mrt; + *mrtsock = rtm->rtm_protocol == RTPROT_MROUTED ? 1 : 0; + if (dev) + mfcc->mfcc_parent = ipmr_find_vif(mrt, dev); + +out: + return ret; +} + +/* takes care of both newroute and delroute */ +static int ipmr_rtm_route(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + int ret, mrtsock, parent; + struct mr_table *tbl; + struct mfcctl mfcc; + + mrtsock = 0; + tbl = NULL; + ret = rtm_to_ipmr_mfcc(net, nlh, &mfcc, &mrtsock, &tbl, extack); + if (ret < 0) + return ret; + + parent = ret ? 
mfcc.mfcc_parent : -1; + if (nlh->nlmsg_type == RTM_NEWROUTE) + return ipmr_mfc_add(net, tbl, &mfcc, mrtsock, parent); + else + return ipmr_mfc_delete(tbl, &mfcc, parent); +} + +static bool ipmr_fill_table(struct mr_table *mrt, struct sk_buff *skb) +{ + u32 queue_len = atomic_read(&mrt->cache_resolve_queue_len); + + if (nla_put_u32(skb, IPMRA_TABLE_ID, mrt->id) || + nla_put_u32(skb, IPMRA_TABLE_CACHE_RES_QUEUE_LEN, queue_len) || + nla_put_s32(skb, IPMRA_TABLE_MROUTE_REG_VIF_NUM, + mrt->mroute_reg_vif_num) || + nla_put_u8(skb, IPMRA_TABLE_MROUTE_DO_ASSERT, + mrt->mroute_do_assert) || + nla_put_u8(skb, IPMRA_TABLE_MROUTE_DO_PIM, mrt->mroute_do_pim) || + nla_put_u8(skb, IPMRA_TABLE_MROUTE_DO_WRVIFWHOLE, + mrt->mroute_do_wrvifwhole)) + return false; + + return true; +} + +static bool ipmr_fill_vif(struct mr_table *mrt, u32 vifid, struct sk_buff *skb) +{ + struct net_device *vif_dev; + struct nlattr *vif_nest; + struct vif_device *vif; + + vif = &mrt->vif_table[vifid]; + vif_dev = rtnl_dereference(vif->dev); + /* if the VIF doesn't exist just continue */ + if (!vif_dev) + return true; + + vif_nest = nla_nest_start_noflag(skb, IPMRA_VIF); + if (!vif_nest) + return false; + + if (nla_put_u32(skb, IPMRA_VIFA_IFINDEX, vif_dev->ifindex) || + nla_put_u32(skb, IPMRA_VIFA_VIF_ID, vifid) || + nla_put_u16(skb, IPMRA_VIFA_FLAGS, vif->flags) || + nla_put_u64_64bit(skb, IPMRA_VIFA_BYTES_IN, vif->bytes_in, + IPMRA_VIFA_PAD) || + nla_put_u64_64bit(skb, IPMRA_VIFA_BYTES_OUT, vif->bytes_out, + IPMRA_VIFA_PAD) || + nla_put_u64_64bit(skb, IPMRA_VIFA_PACKETS_IN, vif->pkt_in, + IPMRA_VIFA_PAD) || + nla_put_u64_64bit(skb, IPMRA_VIFA_PACKETS_OUT, vif->pkt_out, + IPMRA_VIFA_PAD) || + nla_put_be32(skb, IPMRA_VIFA_LOCAL_ADDR, vif->local) || + nla_put_be32(skb, IPMRA_VIFA_REMOTE_ADDR, vif->remote)) { + nla_nest_cancel(skb, vif_nest); + return false; + } + nla_nest_end(skb, vif_nest); + + return true; +} + +static int ipmr_valid_dumplink(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid header for ipmr link dump"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifm))) { + NL_SET_ERR_MSG(extack, "Invalid data after header in ipmr link dump"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG(extack, "Invalid values in header for ipmr link dump request"); + return -EINVAL; + } + + return 0; +} + +static int ipmr_rtm_dumplink(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nlmsghdr *nlh = NULL; + unsigned int t = 0, s_t; + unsigned int e = 0, s_e; + struct mr_table *mrt; + + if (cb->strict_check) { + int err = ipmr_valid_dumplink(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + + s_t = cb->args[0]; + s_e = cb->args[1]; + + ipmr_for_each_table(mrt, net) { + struct nlattr *vifs, *af; + struct ifinfomsg *hdr; + u32 i; + + if (t < s_t) + goto skip_table; + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_NEWLINK, + sizeof(*hdr), NLM_F_MULTI); + if (!nlh) + break; + + hdr = nlmsg_data(nlh); + memset(hdr, 0, sizeof(*hdr)); + hdr->ifi_family = RTNL_FAMILY_IPMR; + + af = nla_nest_start_noflag(skb, IFLA_AF_SPEC); + if (!af) { + nlmsg_cancel(skb, nlh); + goto out; + } + + if (!ipmr_fill_table(mrt, skb)) { + nlmsg_cancel(skb, nlh); + goto out; + } + + vifs = nla_nest_start_noflag(skb, 
IPMRA_TABLE_VIFS); + if (!vifs) { + nla_nest_end(skb, af); + nlmsg_end(skb, nlh); + goto out; + } + for (i = 0; i < mrt->maxvif; i++) { + if (e < s_e) + goto skip_entry; + if (!ipmr_fill_vif(mrt, i, skb)) { + nla_nest_end(skb, vifs); + nla_nest_end(skb, af); + nlmsg_end(skb, nlh); + goto out; + } +skip_entry: + e++; + } + s_e = 0; + e = 0; + nla_nest_end(skb, vifs); + nla_nest_end(skb, af); + nlmsg_end(skb, nlh); +skip_table: + t++; + } + +out: + cb->args[1] = e; + cb->args[0] = t; + + return skb->len; +} + +#ifdef CONFIG_PROC_FS +/* The /proc interfaces to multicast routing : + * /proc/net/ip_mr_cache & /proc/net/ip_mr_vif + */ + +static void *ipmr_vif_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct mr_vif_iter *iter = seq->private; + struct net *net = seq_file_net(seq); + struct mr_table *mrt; + + mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) + return ERR_PTR(-ENOENT); + + iter->mrt = mrt; + + rcu_read_lock(); + return mr_vif_seq_start(seq, pos); +} + +static void ipmr_vif_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int ipmr_vif_seq_show(struct seq_file *seq, void *v) +{ + struct mr_vif_iter *iter = seq->private; + struct mr_table *mrt = iter->mrt; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Interface BytesIn PktsIn BytesOut PktsOut Flags Local Remote\n"); + } else { + const struct vif_device *vif = v; + const struct net_device *vif_dev; + const char *name; + + vif_dev = vif_dev_read(vif); + name = vif_dev ? vif_dev->name : "none"; + seq_printf(seq, + "%2td %-10s %8ld %7ld %8ld %7ld %05X %08X %08X\n", + vif - mrt->vif_table, + name, vif->bytes_in, vif->pkt_in, + vif->bytes_out, vif->pkt_out, + vif->flags, vif->local, vif->remote); + } + return 0; +} + +static const struct seq_operations ipmr_vif_seq_ops = { + .start = ipmr_vif_seq_start, + .next = mr_vif_seq_next, + .stop = ipmr_vif_seq_stop, + .show = ipmr_vif_seq_show, +}; + +static void *ipmr_mfc_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + struct mr_table *mrt; + + mrt = ipmr_get_table(net, RT_TABLE_DEFAULT); + if (!mrt) + return ERR_PTR(-ENOENT); + + return mr_mfc_seq_start(seq, pos, mrt, &mfc_unres_lock); +} + +static int ipmr_mfc_seq_show(struct seq_file *seq, void *v) +{ + int n; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Group Origin Iif Pkts Bytes Wrong Oifs\n"); + } else { + const struct mfc_cache *mfc = v; + const struct mr_mfc_iter *it = seq->private; + const struct mr_table *mrt = it->mrt; + + seq_printf(seq, "%08X %08X %-3hd", + (__force u32) mfc->mfc_mcastgrp, + (__force u32) mfc->mfc_origin, + mfc->_c.mfc_parent); + + if (it->cache != &mrt->mfc_unres_queue) { + seq_printf(seq, " %8lu %8lu %8lu", + mfc->_c.mfc_un.res.pkt, + mfc->_c.mfc_un.res.bytes, + mfc->_c.mfc_un.res.wrong_if); + for (n = mfc->_c.mfc_un.res.minvif; + n < mfc->_c.mfc_un.res.maxvif; n++) { + if (VIF_EXISTS(mrt, n) && + mfc->_c.mfc_un.res.ttls[n] < 255) + seq_printf(seq, + " %2d:%-3d", + n, mfc->_c.mfc_un.res.ttls[n]); + } + } else { + /* unresolved mfc_caches don't contain + * pkt, bytes and wrong_if values + */ + seq_printf(seq, " %8lu %8lu %8lu", 0ul, 0ul, 0ul); + } + seq_putc(seq, '\n'); + } + return 0; +} + +static const struct seq_operations ipmr_mfc_seq_ops = { + .start = ipmr_mfc_seq_start, + .next = mr_mfc_seq_next, + .stop = mr_mfc_seq_stop, + .show = ipmr_mfc_seq_show, +}; +#endif + +#ifdef CONFIG_IP_PIMSM_V2 +static const struct net_protocol pim_protocol = { + .handler = pim_rcv, +}; +#endif + 
+static unsigned int ipmr_seq_read(struct net *net) +{ + ASSERT_RTNL(); + + return net->ipv4.ipmr_seq + ipmr_rules_seq_read(net); +} + +static int ipmr_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return mr_dump(net, nb, RTNL_FAMILY_IPMR, ipmr_rules_dump, + ipmr_mr_table_iter, extack); +} + +static const struct fib_notifier_ops ipmr_notifier_ops_template = { + .family = RTNL_FAMILY_IPMR, + .fib_seq_read = ipmr_seq_read, + .fib_dump = ipmr_dump, + .owner = THIS_MODULE, +}; + +static int __net_init ipmr_notifier_init(struct net *net) +{ + struct fib_notifier_ops *ops; + + net->ipv4.ipmr_seq = 0; + + ops = fib_notifier_ops_register(&ipmr_notifier_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + net->ipv4.ipmr_notifier_ops = ops; + + return 0; +} + +static void __net_exit ipmr_notifier_exit(struct net *net) +{ + fib_notifier_ops_unregister(net->ipv4.ipmr_notifier_ops); + net->ipv4.ipmr_notifier_ops = NULL; +} + +/* Setup for IP multicast routing */ +static int __net_init ipmr_net_init(struct net *net) +{ + int err; + + err = ipmr_notifier_init(net); + if (err) + goto ipmr_notifier_fail; + + err = ipmr_rules_init(net); + if (err < 0) + goto ipmr_rules_fail; + +#ifdef CONFIG_PROC_FS + err = -ENOMEM; + if (!proc_create_net("ip_mr_vif", 0, net->proc_net, &ipmr_vif_seq_ops, + sizeof(struct mr_vif_iter))) + goto proc_vif_fail; + if (!proc_create_net("ip_mr_cache", 0, net->proc_net, &ipmr_mfc_seq_ops, + sizeof(struct mr_mfc_iter))) + goto proc_cache_fail; +#endif + return 0; + +#ifdef CONFIG_PROC_FS +proc_cache_fail: + remove_proc_entry("ip_mr_vif", net->proc_net); +proc_vif_fail: + rtnl_lock(); + ipmr_rules_exit(net); + rtnl_unlock(); +#endif +ipmr_rules_fail: + ipmr_notifier_exit(net); +ipmr_notifier_fail: + return err; +} + +static void __net_exit ipmr_net_exit(struct net *net) +{ +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip_mr_cache", net->proc_net); + remove_proc_entry("ip_mr_vif", net->proc_net); +#endif + ipmr_notifier_exit(net); +} + +static void __net_exit ipmr_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ipmr_rules_exit(net); + rtnl_unlock(); +} + +static struct pernet_operations ipmr_net_ops = { + .init = ipmr_net_init, + .exit = ipmr_net_exit, + .exit_batch = ipmr_net_exit_batch, +}; + +int __init ip_mr_init(void) +{ + int err; + + mrt_cachep = kmem_cache_create("ip_mrt_cache", + sizeof(struct mfc_cache), + 0, SLAB_HWCACHE_ALIGN | SLAB_PANIC, + NULL); + + err = register_pernet_subsys(&ipmr_net_ops); + if (err) + goto reg_pernet_fail; + + err = register_netdevice_notifier(&ip_mr_notifier); + if (err) + goto reg_notif_fail; +#ifdef CONFIG_IP_PIMSM_V2 + if (inet_add_protocol(&pim_protocol, IPPROTO_PIM) < 0) { + pr_err("%s: can't add PIM protocol\n", __func__); + err = -EAGAIN; + goto add_proto_fail; + } +#endif + rtnl_register(RTNL_FAMILY_IPMR, RTM_GETROUTE, + ipmr_rtm_getroute, ipmr_rtm_dumproute, 0); + rtnl_register(RTNL_FAMILY_IPMR, RTM_NEWROUTE, + ipmr_rtm_route, NULL, 0); + rtnl_register(RTNL_FAMILY_IPMR, RTM_DELROUTE, + ipmr_rtm_route, NULL, 0); + + rtnl_register(RTNL_FAMILY_IPMR, RTM_GETLINK, + NULL, ipmr_rtm_dumplink, 0); + return 0; + +#ifdef CONFIG_IP_PIMSM_V2 +add_proto_fail: + unregister_netdevice_notifier(&ip_mr_notifier); +#endif +reg_notif_fail: + unregister_pernet_subsys(&ipmr_net_ops); +reg_pernet_fail: + kmem_cache_destroy(mrt_cachep); + return err; +} diff --git a/net/ipv4/ipmr_base.c b/net/ipv4/ipmr_base.c new file mode 100644 
index 000000000..271dc03fc
--- /dev/null
+++ b/net/ipv4/ipmr_base.c
@@ -0,0 +1,448 @@
+/* Linux multicast routing support
+ * Common logic shared by IPv4 [ipmr] and IPv6 [ip6mr] implementation
+ */
+
+#include <linux/rhashtable.h>
+#include <linux/mroute_base.h>
+
+/* Sets everything common except 'dev', since that is done under locking */
+void vif_device_init(struct vif_device *v,
+		     struct net_device *dev,
+		     unsigned long rate_limit,
+		     unsigned char threshold,
+		     unsigned short flags,
+		     unsigned short get_iflink_mask)
+{
+	RCU_INIT_POINTER(v->dev, NULL);
+	v->bytes_in = 0;
+	v->bytes_out = 0;
+	v->pkt_in = 0;
+	v->pkt_out = 0;
+	v->rate_limit = rate_limit;
+	v->flags = flags;
+	v->threshold = threshold;
+	if (v->flags & get_iflink_mask)
+		v->link = dev_get_iflink(dev);
+	else
+		v->link = dev->ifindex;
+}
+EXPORT_SYMBOL(vif_device_init);
+
+struct mr_table *
+mr_table_alloc(struct net *net, u32 id,
+	       struct mr_table_ops *ops,
+	       void (*expire_func)(struct timer_list *t),
+	       void (*table_set)(struct mr_table *mrt,
+				 struct net *net))
+{
+	struct mr_table *mrt;
+	int err;
+
+	mrt = kzalloc(sizeof(*mrt), GFP_KERNEL);
+	if (!mrt)
+		return ERR_PTR(-ENOMEM);
+	mrt->id = id;
+	write_pnet(&mrt->net, net);
+
+	mrt->ops = *ops;
+	err = rhltable_init(&mrt->mfc_hash, mrt->ops.rht_params);
+	if (err) {
+		kfree(mrt);
+		return ERR_PTR(err);
+	}
+	INIT_LIST_HEAD(&mrt->mfc_cache_list);
+	INIT_LIST_HEAD(&mrt->mfc_unres_queue);
+
+	timer_setup(&mrt->ipmr_expire_timer, expire_func, 0);
+
+	mrt->mroute_reg_vif_num = -1;
+	table_set(mrt, net);
+	return mrt;
+}
+EXPORT_SYMBOL(mr_table_alloc);
+
+void *mr_mfc_find_parent(struct mr_table *mrt, void *hasharg, int parent)
+{
+	struct rhlist_head *tmp, *list;
+	struct mr_mfc *c;
+
+	list = rhltable_lookup(&mrt->mfc_hash, hasharg, *mrt->ops.rht_params);
+	rhl_for_each_entry_rcu(c, tmp, list, mnode)
+		if (parent == -1 || parent == c->mfc_parent)
+			return c;
+
+	return NULL;
+}
+EXPORT_SYMBOL(mr_mfc_find_parent);
+
+void *mr_mfc_find_any_parent(struct mr_table *mrt, int vifi)
+{
+	struct rhlist_head *tmp, *list;
+	struct mr_mfc *c;
+
+	list = rhltable_lookup(&mrt->mfc_hash, mrt->ops.cmparg_any,
+			       *mrt->ops.rht_params);
+	rhl_for_each_entry_rcu(c, tmp, list, mnode)
+		if (c->mfc_un.res.ttls[vifi] < 255)
+			return c;
+
+	return NULL;
+}
+EXPORT_SYMBOL(mr_mfc_find_any_parent);
+
+void *mr_mfc_find_any(struct mr_table *mrt, int vifi, void *hasharg)
+{
+	struct rhlist_head *tmp, *list;
+	struct mr_mfc *c, *proxy;
+
+	list = rhltable_lookup(&mrt->mfc_hash, hasharg, *mrt->ops.rht_params);
+	rhl_for_each_entry_rcu(c, tmp, list, mnode) {
+		if (c->mfc_un.res.ttls[vifi] < 255)
+			return c;
+
+		/* It's ok if the vifi is part of the static tree */
+		proxy = mr_mfc_find_any_parent(mrt, c->mfc_parent);
+		if (proxy && proxy->mfc_un.res.ttls[vifi] < 255)
+			return c;
+	}
+
+	return mr_mfc_find_any_parent(mrt, vifi);
+}
+EXPORT_SYMBOL(mr_mfc_find_any);
+
+#ifdef CONFIG_PROC_FS
+void *mr_vif_seq_idx(struct net *net, struct mr_vif_iter *iter, loff_t pos)
+{
+	struct mr_table *mrt = iter->mrt;
+
+	for (iter->ct = 0; iter->ct < mrt->maxvif; ++iter->ct) {
+		if (!VIF_EXISTS(mrt, iter->ct))
+			continue;
+		if (pos-- == 0)
+			return &mrt->vif_table[iter->ct];
+	}
+	return NULL;
+}
+EXPORT_SYMBOL(mr_vif_seq_idx);
+
+void *mr_vif_seq_next(struct seq_file *seq, void *v, loff_t *pos)
+{
+	struct mr_vif_iter *iter = seq->private;
+	struct net *net = seq_file_net(seq);
+	struct mr_table *mrt = iter->mrt;
+
+	++*pos;
+	if (v == SEQ_START_TOKEN)
+		return mr_vif_seq_idx(net, iter, 0);
+
+	while (++iter->ct < mrt->maxvif) {
+		if (!VIF_EXISTS(mrt,
iter->ct)) + continue; + return &mrt->vif_table[iter->ct]; + } + return NULL; +} +EXPORT_SYMBOL(mr_vif_seq_next); + +void *mr_mfc_seq_idx(struct net *net, + struct mr_mfc_iter *it, loff_t pos) +{ + struct mr_table *mrt = it->mrt; + struct mr_mfc *mfc; + + rcu_read_lock(); + it->cache = &mrt->mfc_cache_list; + list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) + if (pos-- == 0) + return mfc; + rcu_read_unlock(); + + spin_lock_bh(it->lock); + it->cache = &mrt->mfc_unres_queue; + list_for_each_entry(mfc, it->cache, list) + if (pos-- == 0) + return mfc; + spin_unlock_bh(it->lock); + + it->cache = NULL; + return NULL; +} +EXPORT_SYMBOL(mr_mfc_seq_idx); + +void *mr_mfc_seq_next(struct seq_file *seq, void *v, + loff_t *pos) +{ + struct mr_mfc_iter *it = seq->private; + struct net *net = seq_file_net(seq); + struct mr_table *mrt = it->mrt; + struct mr_mfc *c = v; + + ++*pos; + + if (v == SEQ_START_TOKEN) + return mr_mfc_seq_idx(net, seq->private, 0); + + if (c->list.next != it->cache) + return list_entry(c->list.next, struct mr_mfc, list); + + if (it->cache == &mrt->mfc_unres_queue) + goto end_of_list; + + /* exhausted cache_array, show unresolved */ + rcu_read_unlock(); + it->cache = &mrt->mfc_unres_queue; + + spin_lock_bh(it->lock); + if (!list_empty(it->cache)) + return list_first_entry(it->cache, struct mr_mfc, list); + +end_of_list: + spin_unlock_bh(it->lock); + it->cache = NULL; + + return NULL; +} +EXPORT_SYMBOL(mr_mfc_seq_next); +#endif + +int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, + struct mr_mfc *c, struct rtmsg *rtm) +{ + struct net_device *vif_dev; + struct rta_mfc_stats mfcs; + struct nlattr *mp_attr; + struct rtnexthop *nhp; + unsigned long lastuse; + int ct; + + /* If cache is unresolved, don't try to parse IIF and OIF */ + if (c->mfc_parent >= MAXVIFS) { + rtm->rtm_flags |= RTNH_F_UNRESOLVED; + return -ENOENT; + } + + rcu_read_lock(); + vif_dev = rcu_dereference(mrt->vif_table[c->mfc_parent].dev); + if (vif_dev && nla_put_u32(skb, RTA_IIF, vif_dev->ifindex) < 0) { + rcu_read_unlock(); + return -EMSGSIZE; + } + rcu_read_unlock(); + + if (c->mfc_flags & MFC_OFFLOAD) + rtm->rtm_flags |= RTNH_F_OFFLOAD; + + mp_attr = nla_nest_start_noflag(skb, RTA_MULTIPATH); + if (!mp_attr) + return -EMSGSIZE; + + rcu_read_lock(); + for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) { + struct vif_device *vif = &mrt->vif_table[ct]; + + vif_dev = rcu_dereference(vif->dev); + if (vif_dev && c->mfc_un.res.ttls[ct] < 255) { + + nhp = nla_reserve_nohdr(skb, sizeof(*nhp)); + if (!nhp) { + rcu_read_unlock(); + nla_nest_cancel(skb, mp_attr); + return -EMSGSIZE; + } + + nhp->rtnh_flags = 0; + nhp->rtnh_hops = c->mfc_un.res.ttls[ct]; + nhp->rtnh_ifindex = vif_dev->ifindex; + nhp->rtnh_len = sizeof(*nhp); + } + } + rcu_read_unlock(); + + nla_nest_end(skb, mp_attr); + + lastuse = READ_ONCE(c->mfc_un.res.lastuse); + lastuse = time_after_eq(jiffies, lastuse) ? 
jiffies - lastuse : 0; + + mfcs.mfcs_packets = c->mfc_un.res.pkt; + mfcs.mfcs_bytes = c->mfc_un.res.bytes; + mfcs.mfcs_wrong_if = c->mfc_un.res.wrong_if; + if (nla_put_64bit(skb, RTA_MFC_STATS, sizeof(mfcs), &mfcs, RTA_PAD) || + nla_put_u64_64bit(skb, RTA_EXPIRES, jiffies_to_clock_t(lastuse), + RTA_PAD)) + return -EMSGSIZE; + + rtm->rtm_type = RTN_MULTICAST; + return 1; +} +EXPORT_SYMBOL(mr_fill_mroute); + +static bool mr_mfc_uses_dev(const struct mr_table *mrt, + const struct mr_mfc *c, + const struct net_device *dev) +{ + int ct; + + for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) { + const struct net_device *vif_dev; + const struct vif_device *vif; + + vif = &mrt->vif_table[ct]; + vif_dev = rcu_access_pointer(vif->dev); + if (vif_dev && c->mfc_un.res.ttls[ct] < 255 && + vif_dev == dev) + return true; + } + return false; +} + +int mr_table_dump(struct mr_table *mrt, struct sk_buff *skb, + struct netlink_callback *cb, + int (*fill)(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mr_mfc *c, + int cmd, int flags), + spinlock_t *lock, struct fib_dump_filter *filter) +{ + unsigned int e = 0, s_e = cb->args[1]; + unsigned int flags = NLM_F_MULTI; + struct mr_mfc *mfc; + int err; + + if (filter->filter_set) + flags |= NLM_F_DUMP_FILTERED; + + list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) { + if (e < s_e) + goto next_entry; + if (filter->dev && + !mr_mfc_uses_dev(mrt, mfc, filter->dev)) + goto next_entry; + + err = fill(mrt, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags); + if (err < 0) + goto out; +next_entry: + e++; + } + + spin_lock_bh(lock); + list_for_each_entry(mfc, &mrt->mfc_unres_queue, list) { + if (e < s_e) + goto next_entry2; + if (filter->dev && + !mr_mfc_uses_dev(mrt, mfc, filter->dev)) + goto next_entry2; + + err = fill(mrt, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags); + if (err < 0) { + spin_unlock_bh(lock); + goto out; + } +next_entry2: + e++; + } + spin_unlock_bh(lock); + err = 0; +out: + cb->args[1] = e; + return err; +} +EXPORT_SYMBOL(mr_table_dump); + +int mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb, + struct mr_table *(*iter)(struct net *net, + struct mr_table *mrt), + int (*fill)(struct mr_table *mrt, + struct sk_buff *skb, + u32 portid, u32 seq, struct mr_mfc *c, + int cmd, int flags), + spinlock_t *lock, struct fib_dump_filter *filter) +{ + unsigned int t = 0, s_t = cb->args[0]; + struct net *net = sock_net(skb->sk); + struct mr_table *mrt; + int err; + + /* multicast does not track protocol or have route type other + * than RTN_MULTICAST + */ + if (filter->filter_set) { + if (filter->protocol || filter->flags || + (filter->rt_type && filter->rt_type != RTN_MULTICAST)) + return skb->len; + } + + rcu_read_lock(); + for (mrt = iter(net, NULL); mrt; mrt = iter(net, mrt)) { + if (t < s_t) + goto next_table; + + err = mr_table_dump(mrt, skb, cb, fill, lock, filter); + if (err < 0) + break; + cb->args[1] = 0; +next_table: + t++; + } + rcu_read_unlock(); + + cb->args[0] = t; + + return skb->len; +} +EXPORT_SYMBOL(mr_rtm_dumproute); + +int mr_dump(struct net *net, struct notifier_block *nb, unsigned short family, + int (*rules_dump)(struct net *net, + struct notifier_block *nb, + struct netlink_ext_ack *extack), + struct mr_table *(*mr_iter)(struct net *net, + struct mr_table *mrt), + struct netlink_ext_ack *extack) +{ + struct mr_table *mrt; + int err; + + err = rules_dump(net, nb, extack); + if (err) + return err; + + for (mrt = 
mr_iter(net, NULL); mrt; mrt = mr_iter(net, mrt)) { + struct vif_device *v = &mrt->vif_table[0]; + struct net_device *vif_dev; + struct mr_mfc *mfc; + int vifi; + + /* Notifiy on table VIF entries */ + rcu_read_lock(); + for (vifi = 0; vifi < mrt->maxvif; vifi++, v++) { + vif_dev = rcu_dereference(v->dev); + if (!vif_dev) + continue; + + err = mr_call_vif_notifier(nb, family, + FIB_EVENT_VIF_ADD, v, + vif_dev, vifi, + mrt->id, extack); + if (err) + break; + } + rcu_read_unlock(); + + if (err) + return err; + + /* Notify on table MFC entries */ + list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) { + err = mr_call_mfc_notifier(nb, family, + FIB_EVENT_ENTRY_ADD, + mfc, mrt->id, extack); + if (err) + return err; + } + } + + return 0; +} +EXPORT_SYMBOL(mr_dump); diff --git a/net/ipv4/metrics.c b/net/ipv4/metrics.c new file mode 100644 index 000000000..6a1427916 --- /dev/null +++ b/net/ipv4/metrics.c @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include + +static int ip_metrics_convert(struct net *net, struct nlattr *fc_mx, + int fc_mx_len, u32 *metrics, + struct netlink_ext_ack *extack) +{ + bool ecn_ca = false; + struct nlattr *nla; + int remaining; + + if (!fc_mx) + return 0; + + nla_for_each_attr(nla, fc_mx, fc_mx_len, remaining) { + int type = nla_type(nla); + u32 val; + + if (!type) + continue; + if (type > RTAX_MAX) { + NL_SET_ERR_MSG(extack, "Invalid metric type"); + return -EINVAL; + } + + type = array_index_nospec(type, RTAX_MAX + 1); + if (type == RTAX_CC_ALGO) { + char tmp[TCP_CA_NAME_MAX]; + + nla_strscpy(tmp, nla, sizeof(tmp)); + val = tcp_ca_get_key_by_name(net, tmp, &ecn_ca); + if (val == TCP_CA_UNSPEC) { + NL_SET_ERR_MSG(extack, "Unknown tcp congestion algorithm"); + return -EINVAL; + } + } else { + if (nla_len(nla) != sizeof(u32)) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Invalid attribute in metrics"); + return -EINVAL; + } + val = nla_get_u32(nla); + } + if (type == RTAX_ADVMSS && val > 65535 - 40) + val = 65535 - 40; + if (type == RTAX_MTU && val > 65535 - 15) + val = 65535 - 15; + if (type == RTAX_HOPLIMIT && val > 255) + val = 255; + if (type == RTAX_FEATURES && (val & ~RTAX_FEATURE_MASK)) { + NL_SET_ERR_MSG(extack, "Unknown flag set in feature mask in metrics attribute"); + return -EINVAL; + } + metrics[type - 1] = val; + } + + if (ecn_ca) + metrics[RTAX_FEATURES - 1] |= DST_FEATURE_ECN_CA; + + return 0; +} + +struct dst_metrics *ip_fib_metrics_init(struct net *net, struct nlattr *fc_mx, + int fc_mx_len, + struct netlink_ext_ack *extack) +{ + struct dst_metrics *fib_metrics; + int err; + + if (!fc_mx) + return (struct dst_metrics *)&dst_default_metrics; + + fib_metrics = kzalloc(sizeof(*fib_metrics), GFP_KERNEL); + if (unlikely(!fib_metrics)) + return ERR_PTR(-ENOMEM); + + err = ip_metrics_convert(net, fc_mx, fc_mx_len, fib_metrics->metrics, + extack); + if (!err) { + refcount_set(&fib_metrics->refcnt, 1); + } else { + kfree(fib_metrics); + fib_metrics = ERR_PTR(err); + } + + return fib_metrics; +} +EXPORT_SYMBOL_GPL(ip_fib_metrics_init); diff --git a/net/ipv4/netfilter.c b/net/ipv4/netfilter.c new file mode 100644 index 000000000..bd1351654 --- /dev/null +++ b/net/ipv4/netfilter.c @@ -0,0 +1,95 @@ +/* + * IPv4 specific functions of netfilter core + * + * Rusty Russell (C) 2000 -- This code is GPL. 
+ * Patrick McHardy (C) 2006-2012 + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* route_me_harder function, used by iptable_nat, iptable_mangle + ip_queue */ +int ip_route_me_harder(struct net *net, struct sock *sk, struct sk_buff *skb, unsigned int addr_type) +{ + const struct iphdr *iph = ip_hdr(skb); + struct rtable *rt; + struct flowi4 fl4 = {}; + __be32 saddr = iph->saddr; + __u8 flags; + struct net_device *dev = skb_dst(skb)->dev; + struct flow_keys flkeys; + unsigned int hh_len; + + sk = sk_to_full_sk(sk); + flags = sk ? inet_sk_flowi_flags(sk) : 0; + + if (addr_type == RTN_UNSPEC) + addr_type = inet_addr_type_dev_table(net, dev, saddr); + if (addr_type == RTN_LOCAL || addr_type == RTN_UNICAST) + flags |= FLOWI_FLAG_ANYSRC; + else + saddr = 0; + + /* some non-standard hacks like ipt_REJECT.c:send_reset() can cause + * packets with foreign saddr to appear on the NF_INET_LOCAL_OUT hook. + */ + fl4.daddr = iph->daddr; + fl4.saddr = saddr; + fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_oif = sk ? sk->sk_bound_dev_if : 0; + fl4.flowi4_l3mdev = l3mdev_master_ifindex(dev); + fl4.flowi4_mark = skb->mark; + fl4.flowi4_flags = flags; + fib4_rules_early_flow_dissect(net, skb, &fl4, &flkeys); + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + /* Drop old route. */ + skb_dst_drop(skb); + skb_dst_set(skb, &rt->dst); + + if (skb_dst(skb)->error) + return skb_dst(skb)->error; + +#ifdef CONFIG_XFRM + if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && + xfrm_decode_session(skb, flowi4_to_flowi(&fl4), AF_INET) == 0) { + struct dst_entry *dst = skb_dst(skb); + skb_dst_set(skb, NULL); + dst = xfrm_lookup(net, dst, flowi4_to_flowi(&fl4), sk, 0); + if (IS_ERR(dst)) + return PTR_ERR(dst); + skb_dst_set(skb, dst); + } +#endif + + /* Change in oif may mean change in hh_len. */ + hh_len = skb_dst(skb)->dev->hard_header_len; + if (skb_headroom(skb) < hh_len && + pskb_expand_head(skb, HH_DATA_ALIGN(hh_len - skb_headroom(skb)), + 0, GFP_ATOMIC)) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL(ip_route_me_harder); + +int nf_ip_route(struct net *net, struct dst_entry **dst, struct flowi *fl, + bool strict __always_unused) +{ + struct rtable *rt = ip_route_output_key(net, &fl->u.ip4); + if (IS_ERR(rt)) + return PTR_ERR(rt); + *dst = &rt->dst; + return 0; +} +EXPORT_SYMBOL_GPL(nf_ip_route); diff --git a/net/ipv4/netfilter/Kconfig b/net/ipv4/netfilter/Kconfig new file mode 100644 index 000000000..aab384126 --- /dev/null +++ b/net/ipv4/netfilter/Kconfig @@ -0,0 +1,358 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IP netfilter configuration +# + +menu "IP: Netfilter Configuration" + depends on INET && NETFILTER + +config NF_DEFRAG_IPV4 + tristate + default n + +config NF_SOCKET_IPV4 + tristate "IPv4 socket lookup support" + help + This option enables the IPv4 socket lookup infrastructure. This is + is required by the {ip,nf}tables socket match. + +config NF_TPROXY_IPV4 + tristate "IPv4 tproxy support" + +if NF_TABLES + +config NF_TABLES_IPV4 + bool "IPv4 nf_tables support" + help + This option enables the IPv4 support for nf_tables. + +if NF_TABLES_IPV4 + +config NFT_REJECT_IPV4 + select NF_REJECT_IPV4 + default NFT_REJECT + tristate + +config NFT_DUP_IPV4 + tristate "IPv4 nf_tables packet duplication support" + depends on !NF_CONNTRACK || NF_CONNTRACK + select NF_DUP_IPV4 + help + This module enables IPv4 packet duplication support for nf_tables. 
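+# Note on the dependency above: "depends on !NF_CONNTRACK || NF_CONNTRACK" is
+# the usual Kconfig idiom that caps this tristate at 'm' whenever nf_conntrack
+# itself is built as a module, so built-in code never ends up referencing
+# modular conntrack symbols.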
+ +config NFT_FIB_IPV4 + select NFT_FIB + tristate "nf_tables fib / ip route lookup support" + help + This module enables IPv4 FIB lookups, e.g. for reverse path filtering. + It also allows query of the FIB for the route type, e.g. local, unicast, + multicast or blackhole. + +endif # NF_TABLES_IPV4 + +config NF_TABLES_ARP + bool "ARP nf_tables support" + select NETFILTER_FAMILY_ARP + help + This option enables the ARP support for nf_tables. + +endif # NF_TABLES + +config NF_DUP_IPV4 + tristate "Netfilter IPv4 packet duplication to alternate destination" + depends on !NF_CONNTRACK || NF_CONNTRACK + help + This option enables the nf_dup_ipv4 core, which duplicates an IPv4 + packet to be rerouted to another destination. + +config NF_LOG_ARP + tristate "ARP packet logging" + default m if NETFILTER_ADVANCED=n + select NF_LOG_SYSLOG + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects CONFIG_NF_LOG_SYSLOG. + +config NF_LOG_IPV4 + tristate "IPv4 packet logging" + default m if NETFILTER_ADVANCED=n + select NF_LOG_SYSLOG + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects CONFIG_NF_LOG_SYSLOG. + +config NF_REJECT_IPV4 + tristate "IPv4 packet rejection" + default m if NETFILTER_ADVANCED=n + +if NF_NAT +config NF_NAT_SNMP_BASIC + tristate "Basic SNMP-ALG support" + depends on NF_CONNTRACK_SNMP + depends on NETFILTER_ADVANCED + default NF_NAT && NF_CONNTRACK_SNMP + select ASN1 + help + + This module implements an Application Layer Gateway (ALG) for + SNMP payloads. In conjunction with NAT, it allows a network + management system to access multiple private networks with + conflicting addresses. It works by modifying IP addresses + inside SNMP payloads to match IP-layer NAT mapping. + + This is the "basic" form of SNMP-ALG, as described in RFC 2962 + + To compile it as a module, choose M here. If unsure, say N. + +config NF_NAT_PPTP + tristate + depends on NF_CONNTRACK + default NF_CONNTRACK_PPTP + +config NF_NAT_H323 + tristate + depends on NF_CONNTRACK + default NF_CONNTRACK_H323 + +endif # NF_NAT + +config IP_NF_IPTABLES + tristate "IP tables support (required for filtering/masq/NAT)" + default m if NETFILTER_ADVANCED=n + select NETFILTER_XTABLES + help + iptables is a general, extensible packet identification framework. + The packet filtering and full NAT (masquerading, port forwarding, + etc) subsystems now use this: say `Y' or `M' here if you want to use + either of those. + + To compile it as a module, choose M here. If unsure, say N. + +if IP_NF_IPTABLES + +# The matches. +config IP_NF_MATCH_AH + tristate '"ah" match support' + depends on NETFILTER_ADVANCED + help + This match extension allows you to match a range of SPIs + inside AH header of IPSec packets. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_MATCH_ECN + tristate '"ecn" match support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_MATCH_ECN + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_MATCH_ECN. + +config IP_NF_MATCH_RPFILTER + tristate '"rpfilter" reverse path filter match support' + depends on NETFILTER_ADVANCED + depends on IP_NF_MANGLE || IP_NF_RAW + help + This option allows you to match packets whose replies would + go out via the interface the packet came in. + + To compile it as a module, choose M here. If unsure, say N. + The module will be called ipt_rpfilter. 
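+# Example userspace rule exercising this match (illustrative only):
+#   iptables -t raw -A PREROUTING -m rpfilter --invert -j DROP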
+ +config IP_NF_MATCH_TTL + tristate '"ttl" match support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_MATCH_HL + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_MATCH_HL. + +# `filter', generic and specific targets +config IP_NF_FILTER + tristate "Packet filtering" + default m if NETFILTER_ADVANCED=n + help + Packet filtering defines a table `filter', which has a series of + rules for simple packet filtering at local input, forwarding and + local output. See the man page for iptables(8). + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_TARGET_REJECT + tristate "REJECT target support" + depends on IP_NF_FILTER + select NF_REJECT_IPV4 + default m if NETFILTER_ADVANCED=n + help + The REJECT target allows a filtering rule to specify that an ICMP + error should be issued in response to an incoming packet, rather + than silently being dropped. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_TARGET_SYNPROXY + tristate "SYNPROXY target support" + depends on NF_CONNTRACK && NETFILTER_ADVANCED + select NETFILTER_SYNPROXY + select SYN_COOKIES + help + The SYNPROXY target allows you to intercept TCP connections and + establish them using syncookies before they are passed on to the + server. This allows to avoid conntrack and server resource usage + during SYN-flood attacks. + + To compile it as a module, choose M here. If unsure, say N. + +# NAT + specific targets: nf_conntrack +config IP_NF_NAT + tristate "iptables NAT support" + depends on NF_CONNTRACK + default m if NETFILTER_ADVANCED=n + select NF_NAT + select NETFILTER_XT_NAT + help + This enables the `nat' table in iptables. This allows masquerading, + port forwarding and other forms of full Network Address Port + Translation. + + To compile it as a module, choose M here. If unsure, say N. + +if IP_NF_NAT + +config IP_NF_TARGET_MASQUERADE + tristate "MASQUERADE target support" + select NETFILTER_XT_TARGET_MASQUERADE + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects NETFILTER_XT_TARGET_MASQUERADE. + +config IP_NF_TARGET_NETMAP + tristate "NETMAP target support" + depends on NETFILTER_ADVANCED + select NETFILTER_XT_TARGET_NETMAP + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_TARGET_NETMAP. + +config IP_NF_TARGET_REDIRECT + tristate "REDIRECT target support" + depends on NETFILTER_ADVANCED + select NETFILTER_XT_TARGET_REDIRECT + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_TARGET_REDIRECT. + +endif # IP_NF_NAT + +# mangle + specific targets +config IP_NF_MANGLE + tristate "Packet mangling" + default m if NETFILTER_ADVANCED=n + help + This option adds a `mangle' table to iptables: see the man page for + iptables(8). This table is used for various packet alterations + which can effect how the packet is routed. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_TARGET_CLUSTERIP + tristate "CLUSTERIP target support" + depends on IP_NF_MANGLE + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NF_CONNTRACK_MARK + select NETFILTER_FAMILY_ARP + help + The CLUSTERIP target allows you to build load-balancing clusters of + network servers without having a dedicated load-balancing + router/server/switch. 
+ + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_TARGET_ECN + tristate "ECN target support" + depends on IP_NF_MANGLE + depends on NETFILTER_ADVANCED + help + This option adds a `ECN' target, which can be used in the iptables mangle + table. + + You can use this target to remove the ECN bits from the IPv4 header of + an IP packet. This is particularly useful, if you need to work around + existing ECN blackholes on the internet, but don't want to disable + ECN support in general. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_TARGET_TTL + tristate '"TTL" target support' + depends on NETFILTER_ADVANCED && IP_NF_MANGLE + select NETFILTER_XT_TARGET_HL + help + This is a backwards-compatible option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_TARGET_HL. + +# raw + specific targets +config IP_NF_RAW + tristate 'raw table support (required for NOTRACK/TRACE)' + help + This option adds a `raw' table to iptables. This table is the very + first in the netfilter framework and hooks in at the PREROUTING + and OUTPUT chains. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +# security table for MAC policy +config IP_NF_SECURITY + tristate "Security table" + depends on SECURITY + depends on NETFILTER_ADVANCED + help + This option adds a `security' table to iptables, for use + with Mandatory Access Control (MAC) policy. + + If unsure, say N. + +endif # IP_NF_IPTABLES + +# ARP tables +config IP_NF_ARPTABLES + tristate "ARP tables support" + select NETFILTER_XTABLES + select NETFILTER_FAMILY_ARP + depends on NETFILTER_ADVANCED + help + arptables is a general, extensible packet identification framework. + The ARP packet filtering and mangling (manipulation)subsystems + use this: say Y or M here if you want to use either of those. + + To compile it as a module, choose M here. If unsure, say N. + +if IP_NF_ARPTABLES + +config IP_NF_ARPFILTER + tristate "ARP packet filtering" + help + ARP packet filtering defines a table `filter', which has a series of + rules for simple ARP packet filtering at local input and + local output. On a bridge, you can also specify filtering rules + for forwarded ARP packets. See the man page for arptables(8). + + To compile it as a module, choose M here. If unsure, say N. + +config IP_NF_ARP_MANGLE + tristate "ARP payload mangling" + help + Allows altering the ARP packet payload: source and destination + hardware and network addresses. + +endif # IP_NF_ARPTABLES + +endmenu + diff --git a/net/ipv4/netfilter/Makefile b/net/ipv4/netfilter/Makefile new file mode 100644 index 000000000..93bad1184 --- /dev/null +++ b/net/ipv4/netfilter/Makefile @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the netfilter modules on top of IPv4. 
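+# Each obj-$(CONFIG_FOO) line below follows the matching Kconfig symbol:
+# the object is built in for =y, built as a module for =m, and skipped for =n.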
+# + +# defrag +obj-$(CONFIG_NF_DEFRAG_IPV4) += nf_defrag_ipv4.o + +obj-$(CONFIG_NF_SOCKET_IPV4) += nf_socket_ipv4.o +obj-$(CONFIG_NF_TPROXY_IPV4) += nf_tproxy_ipv4.o + +# reject +obj-$(CONFIG_NF_REJECT_IPV4) += nf_reject_ipv4.o + +# NAT helpers (nf_conntrack) +obj-$(CONFIG_NF_NAT_H323) += nf_nat_h323.o +obj-$(CONFIG_NF_NAT_PPTP) += nf_nat_pptp.o + +nf_nat_snmp_basic-y := nf_nat_snmp_basic.asn1.o nf_nat_snmp_basic_main.o +$(obj)/nf_nat_snmp_basic_main.o: $(obj)/nf_nat_snmp_basic.asn1.h +obj-$(CONFIG_NF_NAT_SNMP_BASIC) += nf_nat_snmp_basic.o + +obj-$(CONFIG_NFT_REJECT_IPV4) += nft_reject_ipv4.o +obj-$(CONFIG_NFT_FIB_IPV4) += nft_fib_ipv4.o +obj-$(CONFIG_NFT_DUP_IPV4) += nft_dup_ipv4.o + +# generic IP tables +obj-$(CONFIG_IP_NF_IPTABLES) += ip_tables.o + +# the three instances of ip_tables +obj-$(CONFIG_IP_NF_FILTER) += iptable_filter.o +obj-$(CONFIG_IP_NF_MANGLE) += iptable_mangle.o +obj-$(CONFIG_IP_NF_NAT) += iptable_nat.o +obj-$(CONFIG_IP_NF_RAW) += iptable_raw.o +obj-$(CONFIG_IP_NF_SECURITY) += iptable_security.o + +# matches +obj-$(CONFIG_IP_NF_MATCH_AH) += ipt_ah.o +obj-$(CONFIG_IP_NF_MATCH_RPFILTER) += ipt_rpfilter.o + +# targets +obj-$(CONFIG_IP_NF_TARGET_CLUSTERIP) += ipt_CLUSTERIP.o +obj-$(CONFIG_IP_NF_TARGET_ECN) += ipt_ECN.o +obj-$(CONFIG_IP_NF_TARGET_REJECT) += ipt_REJECT.o +obj-$(CONFIG_IP_NF_TARGET_SYNPROXY) += ipt_SYNPROXY.o + +# generic ARP tables +obj-$(CONFIG_IP_NF_ARPTABLES) += arp_tables.o +obj-$(CONFIG_IP_NF_ARP_MANGLE) += arpt_mangle.o + +# just filtering instance of ARP tables for now +obj-$(CONFIG_IP_NF_ARPFILTER) += arptable_filter.o + +obj-$(CONFIG_NF_DUP_IPV4) += nf_dup_ipv4.o diff --git a/net/ipv4/netfilter/arp_tables.c b/net/ipv4/netfilter/arp_tables.c new file mode 100644 index 000000000..2407066b0 --- /dev/null +++ b/net/ipv4/netfilter/arp_tables.c @@ -0,0 +1,1667 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Packet matching code for ARP packets. + * + * Based heavily, if not almost entirely, upon ip_tables.c framework. + * + * Some ARP specific bits are: + * + * Copyright (C) 2002 David S. Miller (davem@redhat.com) + * Copyright (C) 2006-2009 Patrick McHardy + * + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "../../netfilter/xt_repldata.h" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("David S. Miller "); +MODULE_DESCRIPTION("arptables core"); + +void *arpt_alloc_initial_table(const struct xt_table *info) +{ + return xt_alloc_initial_table(arpt, ARPT); +} +EXPORT_SYMBOL_GPL(arpt_alloc_initial_table); + +static inline int arp_devaddr_compare(const struct arpt_devaddr_info *ap, + const char *hdr_addr, int len) +{ + int i, ret; + + if (len > ARPT_DEV_ADDR_LEN_MAX) + len = ARPT_DEV_ADDR_LEN_MAX; + + ret = 0; + for (i = 0; i < len; i++) + ret |= (hdr_addr[i] ^ ap->addr[i]) & ap->mask[i]; + + return ret != 0; +} + +/* + * Unfortunately, _b and _mask are not aligned to an int (or long int) + * Some arches dont care, unrolling the loop is a win on them. + * For other arches, we only have a 16bit alignement. 
+ */ +static unsigned long ifname_compare(const char *_a, const char *_b, const char *_mask) +{ +#ifdef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + unsigned long ret = ifname_compare_aligned(_a, _b, _mask); +#else + unsigned long ret = 0; + const u16 *a = (const u16 *)_a; + const u16 *b = (const u16 *)_b; + const u16 *mask = (const u16 *)_mask; + int i; + + for (i = 0; i < IFNAMSIZ/sizeof(u16); i++) + ret |= (a[i] ^ b[i]) & mask[i]; +#endif + return ret; +} + +/* Returns whether packet matches rule or not. */ +static inline int arp_packet_match(const struct arphdr *arphdr, + struct net_device *dev, + const char *indev, + const char *outdev, + const struct arpt_arp *arpinfo) +{ + const char *arpptr = (char *)(arphdr + 1); + const char *src_devaddr, *tgt_devaddr; + __be32 src_ipaddr, tgt_ipaddr; + long ret; + + if (NF_INVF(arpinfo, ARPT_INV_ARPOP, + (arphdr->ar_op & arpinfo->arpop_mask) != arpinfo->arpop)) + return 0; + + if (NF_INVF(arpinfo, ARPT_INV_ARPHRD, + (arphdr->ar_hrd & arpinfo->arhrd_mask) != arpinfo->arhrd)) + return 0; + + if (NF_INVF(arpinfo, ARPT_INV_ARPPRO, + (arphdr->ar_pro & arpinfo->arpro_mask) != arpinfo->arpro)) + return 0; + + if (NF_INVF(arpinfo, ARPT_INV_ARPHLN, + (arphdr->ar_hln & arpinfo->arhln_mask) != arpinfo->arhln)) + return 0; + + src_devaddr = arpptr; + arpptr += dev->addr_len; + memcpy(&src_ipaddr, arpptr, sizeof(u32)); + arpptr += sizeof(u32); + tgt_devaddr = arpptr; + arpptr += dev->addr_len; + memcpy(&tgt_ipaddr, arpptr, sizeof(u32)); + + if (NF_INVF(arpinfo, ARPT_INV_SRCDEVADDR, + arp_devaddr_compare(&arpinfo->src_devaddr, src_devaddr, + dev->addr_len)) || + NF_INVF(arpinfo, ARPT_INV_TGTDEVADDR, + arp_devaddr_compare(&arpinfo->tgt_devaddr, tgt_devaddr, + dev->addr_len))) + return 0; + + if (NF_INVF(arpinfo, ARPT_INV_SRCIP, + (src_ipaddr & arpinfo->smsk.s_addr) != arpinfo->src.s_addr) || + NF_INVF(arpinfo, ARPT_INV_TGTIP, + (tgt_ipaddr & arpinfo->tmsk.s_addr) != arpinfo->tgt.s_addr)) + return 0; + + /* Look for ifname matches. 
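+ * The interface masks are built by arptables userspace: a trailing '+'
+ * wildcard leaves only the literal prefix bytes set in the mask, so the
+ * XOR-and-mask comparison below ignores the rest of the name.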
*/ + ret = ifname_compare(indev, arpinfo->iniface, arpinfo->iniface_mask); + + if (NF_INVF(arpinfo, ARPT_INV_VIA_IN, ret != 0)) + return 0; + + ret = ifname_compare(outdev, arpinfo->outiface, arpinfo->outiface_mask); + + if (NF_INVF(arpinfo, ARPT_INV_VIA_OUT, ret != 0)) + return 0; + + return 1; +} + +static inline int arp_checkentry(const struct arpt_arp *arp) +{ + if (arp->flags & ~ARPT_F_MASK) + return 0; + if (arp->invflags & ~ARPT_INV_MASK) + return 0; + + return 1; +} + +static unsigned int +arpt_error(struct sk_buff *skb, const struct xt_action_param *par) +{ + net_err_ratelimited("arp_tables: error: '%s'\n", + (const char *)par->targinfo); + + return NF_DROP; +} + +static inline const struct xt_entry_target * +arpt_get_target_c(const struct arpt_entry *e) +{ + return arpt_get_target((struct arpt_entry *)e); +} + +static inline struct arpt_entry * +get_entry(const void *base, unsigned int offset) +{ + return (struct arpt_entry *)(base + offset); +} + +static inline +struct arpt_entry *arpt_next_entry(const struct arpt_entry *entry) +{ + return (void *)entry + entry->next_offset; +} + +unsigned int arpt_do_table(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct xt_table *table = priv; + unsigned int hook = state->hook; + static const char nulldevname[IFNAMSIZ] __attribute__((aligned(sizeof(long)))); + unsigned int verdict = NF_DROP; + const struct arphdr *arp; + struct arpt_entry *e, **jumpstack; + const char *indev, *outdev; + const void *table_base; + unsigned int cpu, stackidx = 0; + const struct xt_table_info *private; + struct xt_action_param acpar; + unsigned int addend; + + if (!pskb_may_pull(skb, arp_hdr_len(skb->dev))) + return NF_DROP; + + indev = state->in ? state->in->name : nulldevname; + outdev = state->out ? state->out->name : nulldevname; + + local_bh_disable(); + addend = xt_write_recseq_begin(); + private = READ_ONCE(table->private); /* Address dependency. */ + cpu = smp_processor_id(); + table_base = private->entries; + jumpstack = (struct arpt_entry **)private->jumpstack[cpu]; + + /* No TEE support for arptables, so no need to switch to alternate + * stack. All targets that reenter must return absolute verdicts. + */ + e = get_entry(table_base, private->hook_entry[hook]); + + acpar.state = state; + acpar.hotdrop = false; + + arp = arp_hdr(skb); + do { + const struct xt_entry_target *t; + struct xt_counters *counter; + + if (!arp_packet_match(arp, skb->dev, indev, outdev, &e->arp)) { + e = arpt_next_entry(e); + continue; + } + + counter = xt_get_this_cpu_counter(&e->counters); + ADD_COUNTER(*counter, arp_hdr_len(skb->dev), 1); + + t = arpt_get_target_c(e); + + /* Standard target? */ + if (!t->u.kernel.target->target) { + int v; + + v = ((struct xt_standard_target *)t)->verdict; + if (v < 0) { + /* Pop from stack? */ + if (v != XT_RETURN) { + verdict = (unsigned int)(-v) - 1; + break; + } + if (stackidx == 0) { + e = get_entry(table_base, + private->underflow[hook]); + } else { + e = jumpstack[--stackidx]; + e = arpt_next_entry(e); + } + continue; + } + if (table_base + v + != arpt_next_entry(e)) { + if (unlikely(stackidx >= private->stacksize)) { + verdict = NF_DROP; + break; + } + jumpstack[stackidx++] = e; + } + + e = get_entry(table_base, v); + continue; + } + + acpar.target = t->u.kernel.target; + acpar.targinfo = t->data; + verdict = t->u.kernel.target->target(skb, &acpar); + + if (verdict == XT_CONTINUE) { + /* Target might have changed stuff. 
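+ * A non-standard target that returns XT_CONTINUE may have mangled or even
+ * reallocated the skb, so re-read the ARP header before matching the next
+ * rule.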
*/ + arp = arp_hdr(skb); + e = arpt_next_entry(e); + } else { + /* Verdict */ + break; + } + } while (!acpar.hotdrop); + xt_write_recseq_end(addend); + local_bh_enable(); + + if (acpar.hotdrop) + return NF_DROP; + else + return verdict; +} + +/* All zeroes == unconditional rule. */ +static inline bool unconditional(const struct arpt_entry *e) +{ + static const struct arpt_arp uncond; + + return e->target_offset == sizeof(struct arpt_entry) && + memcmp(&e->arp, &uncond, sizeof(uncond)) == 0; +} + +/* Figures out from what hook each rule can be called: returns 0 if + * there are loops. Puts hook bitmask in comefrom. + */ +static int mark_source_chains(const struct xt_table_info *newinfo, + unsigned int valid_hooks, void *entry0, + unsigned int *offsets) +{ + unsigned int hook; + + /* No recursion; use packet counter to save back ptrs (reset + * to 0 as we leave), and comefrom to save source hook bitmask. + */ + for (hook = 0; hook < NF_ARP_NUMHOOKS; hook++) { + unsigned int pos = newinfo->hook_entry[hook]; + struct arpt_entry *e = entry0 + pos; + + if (!(valid_hooks & (1 << hook))) + continue; + + /* Set initial back pointer. */ + e->counters.pcnt = pos; + + for (;;) { + const struct xt_standard_target *t + = (void *)arpt_get_target_c(e); + int visited = e->comefrom & (1 << hook); + + if (e->comefrom & (1 << NF_ARP_NUMHOOKS)) + return 0; + + e->comefrom + |= ((1 << hook) | (1 << NF_ARP_NUMHOOKS)); + + /* Unconditional return/END. */ + if ((unconditional(e) && + (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0) && + t->verdict < 0) || visited) { + unsigned int oldpos, size; + + /* Return: backtrack through the last + * big jump. + */ + do { + e->comefrom ^= (1<counters.pcnt; + e->counters.pcnt = 0; + + /* We're at the start. */ + if (pos == oldpos) + goto next; + + e = entry0 + pos; + } while (oldpos == pos + e->next_offset); + + /* Move along one */ + size = e->next_offset; + e = entry0 + pos + size; + if (pos + size >= newinfo->size) + return 0; + e->counters.pcnt = pos; + pos += size; + } else { + int newpos = t->verdict; + + if (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0 && + newpos >= 0) { + /* This a jump; chase it. */ + if (!xt_find_jump_offset(offsets, newpos, + newinfo->number)) + return 0; + } else { + /* ... 
this is a fallthru */ + newpos = pos + e->next_offset; + if (newpos >= newinfo->size) + return 0; + } + e = entry0 + newpos; + e->counters.pcnt = pos; + pos = newpos; + } + } +next: ; + } + return 1; +} + +static int check_target(struct arpt_entry *e, struct net *net, const char *name) +{ + struct xt_entry_target *t = arpt_get_target(e); + struct xt_tgchk_param par = { + .net = net, + .table = name, + .entryinfo = e, + .target = t->u.kernel.target, + .targinfo = t->data, + .hook_mask = e->comefrom, + .family = NFPROTO_ARP, + }; + + return xt_check_target(&par, t->u.target_size - sizeof(*t), 0, false); +} + +static int +find_check_entry(struct arpt_entry *e, struct net *net, const char *name, + unsigned int size, + struct xt_percpu_counter_alloc_state *alloc_state) +{ + struct xt_entry_target *t; + struct xt_target *target; + int ret; + + if (!xt_percpu_counter_alloc(alloc_state, &e->counters)) + return -ENOMEM; + + t = arpt_get_target(e); + target = xt_request_find_target(NFPROTO_ARP, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto out; + } + t->u.kernel.target = target; + + ret = check_target(e, net, name); + if (ret) + goto err; + return 0; +err: + module_put(t->u.kernel.target->me); +out: + xt_percpu_counter_free(&e->counters); + + return ret; +} + +static bool check_underflow(const struct arpt_entry *e) +{ + const struct xt_entry_target *t; + unsigned int verdict; + + if (!unconditional(e)) + return false; + t = arpt_get_target_c(e); + if (strcmp(t->u.user.name, XT_STANDARD_TARGET) != 0) + return false; + verdict = ((struct xt_standard_target *)t)->verdict; + verdict = -verdict - 1; + return verdict == NF_DROP || verdict == NF_ACCEPT; +} + +static inline int check_entry_size_and_hooks(struct arpt_entry *e, + struct xt_table_info *newinfo, + const unsigned char *base, + const unsigned char *limit, + const unsigned int *hook_entries, + const unsigned int *underflows, + unsigned int valid_hooks) +{ + unsigned int h; + int err; + + if ((unsigned long)e % __alignof__(struct arpt_entry) != 0 || + (unsigned char *)e + sizeof(struct arpt_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset + < sizeof(struct arpt_entry) + sizeof(struct xt_entry_target)) + return -EINVAL; + + if (!arp_checkentry(&e->arp)) + return -EINVAL; + + err = xt_check_entry_offsets(e, e->elems, e->target_offset, + e->next_offset); + if (err) + return err; + + /* Check hooks & underflows */ + for (h = 0; h < NF_ARP_NUMHOOKS; h++) { + if (!(valid_hooks & (1 << h))) + continue; + if ((unsigned char *)e - base == hook_entries[h]) + newinfo->hook_entry[h] = hook_entries[h]; + if ((unsigned char *)e - base == underflows[h]) { + if (!check_underflow(e)) + return -EINVAL; + + newinfo->underflow[h] = underflows[h]; + } + } + + /* Clear counters and comefrom */ + e->counters = ((struct xt_counters) { 0, 0 }); + e->comefrom = 0; + return 0; +} + +static void cleanup_entry(struct arpt_entry *e, struct net *net) +{ + struct xt_tgdtor_param par; + struct xt_entry_target *t; + + t = arpt_get_target(e); + par.net = net; + par.target = t->u.kernel.target; + par.targinfo = t->data; + par.family = NFPROTO_ARP; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); + xt_percpu_counter_free(&e->counters); +} + +/* Checks and translates the user-supplied table segment (held in + * newinfo). 
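+ * It validates each entry's offsets and hook/underflow markers, rejects
+ * rulesets containing loops via mark_source_chains(), and then runs the
+ * per-entry target checks in find_check_entry(), unwinding on failure.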
+ */ +static int translate_table(struct net *net, + struct xt_table_info *newinfo, + void *entry0, + const struct arpt_replace *repl) +{ + struct xt_percpu_counter_alloc_state alloc_state = { 0 }; + struct arpt_entry *iter; + unsigned int *offsets; + unsigned int i; + int ret = 0; + + newinfo->size = repl->size; + newinfo->number = repl->num_entries; + + /* Init all hooks to impossible value. */ + for (i = 0; i < NF_ARP_NUMHOOKS; i++) { + newinfo->hook_entry[i] = 0xFFFFFFFF; + newinfo->underflow[i] = 0xFFFFFFFF; + } + + offsets = xt_alloc_entry_offsets(newinfo->number); + if (!offsets) + return -ENOMEM; + i = 0; + + /* Walk through entries, checking offsets. */ + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = check_entry_size_and_hooks(iter, newinfo, entry0, + entry0 + repl->size, + repl->hook_entry, + repl->underflow, + repl->valid_hooks); + if (ret != 0) + goto out_free; + if (i < repl->num_entries) + offsets[i] = (void *)iter - entry0; + ++i; + if (strcmp(arpt_get_target(iter)->u.user.name, + XT_ERROR_TARGET) == 0) + ++newinfo->stacksize; + } + + ret = -EINVAL; + if (i != repl->num_entries) + goto out_free; + + ret = xt_check_table_hooks(newinfo, repl->valid_hooks); + if (ret) + goto out_free; + + if (!mark_source_chains(newinfo, repl->valid_hooks, entry0, offsets)) { + ret = -ELOOP; + goto out_free; + } + kvfree(offsets); + + /* Finally, each sanity check must pass */ + i = 0; + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = find_check_entry(iter, net, repl->name, repl->size, + &alloc_state); + if (ret != 0) + break; + ++i; + } + + if (ret != 0) { + xt_entry_foreach(iter, entry0, newinfo->size) { + if (i-- == 0) + break; + cleanup_entry(iter, net); + } + return ret; + } + + return ret; + out_free: + kvfree(offsets); + return ret; +} + +static void get_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct arpt_entry *iter; + unsigned int cpu; + unsigned int i; + + for_each_possible_cpu(cpu) { + seqcount_t *s = &per_cpu(xt_recseq, cpu); + + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + struct xt_counters *tmp; + u64 bcnt, pcnt; + unsigned int start; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + do { + start = read_seqcount_begin(s); + bcnt = tmp->bcnt; + pcnt = tmp->pcnt; + } while (read_seqcount_retry(s, start)); + + ADD_COUNTER(counters[i], bcnt, pcnt); + ++i; + cond_resched(); + } + } +} + +static void get_old_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct arpt_entry *iter; + unsigned int cpu, i; + + for_each_possible_cpu(cpu) { + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + struct xt_counters *tmp; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + ADD_COUNTER(counters[i], tmp->bcnt, tmp->pcnt); + ++i; + } + cond_resched(); + } +} + +static struct xt_counters *alloc_counters(const struct xt_table *table) +{ + unsigned int countersize; + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + + /* We need atomic snapshot of counters: rest doesn't change + * (other than comefrom, which userspace doesn't care + * about). 
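+ * Each per-cpu counter pair is sampled under the xt_recseq sequence counter
+ * in get_counters(), so the snapshot stays consistent without stopping
+ * packet processing.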
+ */ + countersize = sizeof(struct xt_counters) * private->number; + counters = vzalloc(countersize); + + if (counters == NULL) + return ERR_PTR(-ENOMEM); + + get_counters(private, counters); + + return counters; +} + +static int copy_entries_to_user(unsigned int total_size, + const struct xt_table *table, + void __user *userptr) +{ + unsigned int off, num; + const struct arpt_entry *e; + struct xt_counters *counters; + struct xt_table_info *private = table->private; + int ret = 0; + void *loc_cpu_entry; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + loc_cpu_entry = private->entries; + + /* FIXME: use iterator macros --RR */ + /* ... then go back and fix counters and names */ + for (off = 0, num = 0; off < total_size; off += e->next_offset, num++){ + const struct xt_entry_target *t; + + e = loc_cpu_entry + off; + if (copy_to_user(userptr + off, e, sizeof(*e))) { + ret = -EFAULT; + goto free_counters; + } + if (copy_to_user(userptr + off + + offsetof(struct arpt_entry, counters), + &counters[num], + sizeof(counters[num])) != 0) { + ret = -EFAULT; + goto free_counters; + } + + t = arpt_get_target_c(e); + if (xt_target_to_user(t, userptr + off + e->target_offset)) { + ret = -EFAULT; + goto free_counters; + } + } + + free_counters: + vfree(counters); + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +static void compat_standard_from_user(void *dst, const void *src) +{ + int v = *(compat_int_t *)src; + + if (v > 0) + v += xt_compat_calc_jump(NFPROTO_ARP, v); + memcpy(dst, &v, sizeof(v)); +} + +static int compat_standard_to_user(void __user *dst, const void *src) +{ + compat_int_t cv = *(int *)src; + + if (cv > 0) + cv -= xt_compat_calc_jump(NFPROTO_ARP, cv); + return copy_to_user(dst, &cv, sizeof(cv)) ? 
-EFAULT : 0; +} + +static int compat_calc_entry(const struct arpt_entry *e, + const struct xt_table_info *info, + const void *base, struct xt_table_info *newinfo) +{ + const struct xt_entry_target *t; + unsigned int entry_offset; + int off, i, ret; + + off = sizeof(struct arpt_entry) - sizeof(struct compat_arpt_entry); + entry_offset = (void *)e - base; + + t = arpt_get_target_c(e); + off += xt_compat_target_offset(t->u.kernel.target); + newinfo->size -= off; + ret = xt_compat_add_offset(NFPROTO_ARP, entry_offset, off); + if (ret) + return ret; + + for (i = 0; i < NF_ARP_NUMHOOKS; i++) { + if (info->hook_entry[i] && + (e < (struct arpt_entry *)(base + info->hook_entry[i]))) + newinfo->hook_entry[i] -= off; + if (info->underflow[i] && + (e < (struct arpt_entry *)(base + info->underflow[i]))) + newinfo->underflow[i] -= off; + } + return 0; +} + +static int compat_table_info(const struct xt_table_info *info, + struct xt_table_info *newinfo) +{ + struct arpt_entry *iter; + const void *loc_cpu_entry; + int ret; + + if (!newinfo || !info) + return -EINVAL; + + /* we dont care about newinfo->entries */ + memcpy(newinfo, info, offsetof(struct xt_table_info, entries)); + newinfo->initial_entries = 0; + loc_cpu_entry = info->entries; + ret = xt_compat_init_offsets(NFPROTO_ARP, info->number); + if (ret) + return ret; + xt_entry_foreach(iter, loc_cpu_entry, info->size) { + ret = compat_calc_entry(iter, info, loc_cpu_entry, newinfo); + if (ret != 0) + return ret; + } + return 0; +} +#endif + +static int get_info(struct net *net, void __user *user, const int *len) +{ + char name[XT_TABLE_MAXNAMELEN]; + struct xt_table *t; + int ret; + + if (*len != sizeof(struct arpt_getinfo)) + return -EINVAL; + + if (copy_from_user(name, user, sizeof(name)) != 0) + return -EFAULT; + + name[XT_TABLE_MAXNAMELEN-1] = '\0'; +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_lock(NFPROTO_ARP); +#endif + t = xt_request_find_table_lock(net, NFPROTO_ARP, name); + if (!IS_ERR(t)) { + struct arpt_getinfo info; + const struct xt_table_info *private = t->private; +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + struct xt_table_info tmp; + + if (in_compat_syscall()) { + ret = compat_table_info(private, &tmp); + xt_compat_flush_offsets(NFPROTO_ARP); + private = &tmp; + } +#endif + memset(&info, 0, sizeof(info)); + info.valid_hooks = t->valid_hooks; + memcpy(info.hook_entry, private->hook_entry, + sizeof(info.hook_entry)); + memcpy(info.underflow, private->underflow, + sizeof(info.underflow)); + info.num_entries = private->number; + info.size = private->size; + strcpy(info.name, name); + + if (copy_to_user(user, &info, *len) != 0) + ret = -EFAULT; + else + ret = 0; + xt_table_unlock(t); + module_put(t->me); + } else + ret = PTR_ERR(t); +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_unlock(NFPROTO_ARP); +#endif + return ret; +} + +static int get_entries(struct net *net, struct arpt_get_entries __user *uptr, + const int *len) +{ + int ret; + struct arpt_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + if (*len != sizeof(struct arpt_get_entries) + get.size) + return -EINVAL; + + get.name[sizeof(get.name) - 1] = '\0'; + + t = xt_find_table_lock(net, NFPROTO_ARP, get.name); + if (!IS_ERR(t)) { + const struct xt_table_info *private = t->private; + + if (get.size == private->size) + ret = copy_entries_to_user(private->size, + t, uptr->entrytable); + else + ret = -EAGAIN; + + 
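+ /* The -EAGAIN above means the table was replaced after the caller
+  * sized its buffer via ARPT_SO_GET_INFO; userspace should fetch the
+  * new size and retry.
+  */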
module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + return ret; +} + +static int __do_replace(struct net *net, const char *name, + unsigned int valid_hooks, + struct xt_table_info *newinfo, + unsigned int num_counters, + void __user *counters_ptr) +{ + int ret; + struct xt_table *t; + struct xt_table_info *oldinfo; + struct xt_counters *counters; + void *loc_cpu_old_entry; + struct arpt_entry *iter; + + ret = 0; + counters = xt_counters_alloc(num_counters); + if (!counters) { + ret = -ENOMEM; + goto out; + } + + t = xt_request_find_table_lock(net, NFPROTO_ARP, name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free_newinfo_counters_untrans; + } + + /* You lied! */ + if (valid_hooks != t->valid_hooks) { + ret = -EINVAL; + goto put_module; + } + + oldinfo = xt_replace_table(t, num_counters, newinfo, &ret); + if (!oldinfo) + goto put_module; + + /* Update module usage count based on number of rules */ + if ((oldinfo->number > oldinfo->initial_entries) || + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + if ((oldinfo->number > oldinfo->initial_entries) && + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + + xt_table_unlock(t); + + get_old_counters(oldinfo, counters); + + /* Decrease module usage counts and free resource */ + loc_cpu_old_entry = oldinfo->entries; + xt_entry_foreach(iter, loc_cpu_old_entry, oldinfo->size) + cleanup_entry(iter, net); + + xt_free_table_info(oldinfo); + if (copy_to_user(counters_ptr, counters, + sizeof(struct xt_counters) * num_counters) != 0) { + /* Silent error, can't fail, new table is already in place */ + net_warn_ratelimited("arptables: counters copy to user failed while replacing table\n"); + } + vfree(counters); + return ret; + + put_module: + module_put(t->me); + xt_table_unlock(t); + free_newinfo_counters_untrans: + vfree(counters); + out: + return ret; +} + +static int do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct arpt_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct arpt_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_table(net, newinfo, loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, tmp.counters); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +static int do_add_counters(struct net *net, sockptr_t arg, unsigned int len) +{ + unsigned int i; + struct xt_counters_info tmp; + struct xt_counters *paddc; + struct xt_table *t; + const struct xt_table_info *private; + int ret = 0; + struct arpt_entry *iter; + unsigned int addend; + + paddc = xt_copy_counters(arg, len, &tmp); + if (IS_ERR(paddc)) + return PTR_ERR(paddc); + + t = xt_find_table_lock(net, NFPROTO_ARP, tmp.name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free; + } + + local_bh_disable(); + private = 
t->private; + if (private->number != tmp.num_counters) { + ret = -EINVAL; + goto unlock_up_free; + } + + i = 0; + + addend = xt_write_recseq_begin(); + xt_entry_foreach(iter, private->entries, private->size) { + struct xt_counters *tmp; + + tmp = xt_get_this_cpu_counter(&iter->counters); + ADD_COUNTER(*tmp, paddc[i].bcnt, paddc[i].pcnt); + ++i; + } + xt_write_recseq_end(addend); + unlock_up_free: + local_bh_enable(); + xt_table_unlock(t); + module_put(t->me); + free: + vfree(paddc); + + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_arpt_replace { + char name[XT_TABLE_MAXNAMELEN]; + u32 valid_hooks; + u32 num_entries; + u32 size; + u32 hook_entry[NF_ARP_NUMHOOKS]; + u32 underflow[NF_ARP_NUMHOOKS]; + u32 num_counters; + compat_uptr_t counters; + struct compat_arpt_entry entries[]; +}; + +static inline void compat_release_entry(struct compat_arpt_entry *e) +{ + struct xt_entry_target *t; + + t = compat_arpt_get_target(e); + module_put(t->u.kernel.target->me); +} + +static int +check_compat_entry_size_and_hooks(struct compat_arpt_entry *e, + struct xt_table_info *newinfo, + unsigned int *size, + const unsigned char *base, + const unsigned char *limit) +{ + struct xt_entry_target *t; + struct xt_target *target; + unsigned int entry_offset; + int ret, off; + + if ((unsigned long)e % __alignof__(struct compat_arpt_entry) != 0 || + (unsigned char *)e + sizeof(struct compat_arpt_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset < sizeof(struct compat_arpt_entry) + + sizeof(struct compat_xt_entry_target)) + return -EINVAL; + + if (!arp_checkentry(&e->arp)) + return -EINVAL; + + ret = xt_compat_check_entry_offsets(e, e->elems, e->target_offset, + e->next_offset); + if (ret) + return ret; + + off = sizeof(struct arpt_entry) - sizeof(struct compat_arpt_entry); + entry_offset = (void *)e - (void *)base; + + t = compat_arpt_get_target(e); + target = xt_request_find_target(NFPROTO_ARP, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto out; + } + t->u.kernel.target = target; + + off += xt_compat_target_offset(target); + *size += off; + ret = xt_compat_add_offset(NFPROTO_ARP, entry_offset, off); + if (ret) + goto release_target; + + return 0; + +release_target: + module_put(t->u.kernel.target->me); +out: + return ret; +} + +static void +compat_copy_entry_from_user(struct compat_arpt_entry *e, void **dstptr, + unsigned int *size, + struct xt_table_info *newinfo, unsigned char *base) +{ + struct xt_entry_target *t; + struct arpt_entry *de; + unsigned int origsize; + int h; + + origsize = *size; + de = *dstptr; + memcpy(de, e, sizeof(struct arpt_entry)); + memcpy(&de->counters, &e->counters, sizeof(e->counters)); + + *dstptr += sizeof(struct arpt_entry); + *size += sizeof(struct arpt_entry) - sizeof(struct compat_arpt_entry); + + de->target_offset = e->target_offset - (origsize - *size); + t = compat_arpt_get_target(e); + xt_compat_target_from_user(t, dstptr, size); + + de->next_offset = e->next_offset - (origsize - *size); + for (h = 0; h < NF_ARP_NUMHOOKS; h++) { + if ((unsigned char *)de - base < newinfo->hook_entry[h]) + newinfo->hook_entry[h] -= origsize - *size; + if ((unsigned char *)de - base < newinfo->underflow[h]) + newinfo->underflow[h] -= origsize - *size; + } +} + +static int translate_compat_table(struct net *net, + struct xt_table_info **pinfo, + void **pentry0, + const struct compat_arpt_replace *compatr) +{ + unsigned int i, j; + struct xt_table_info *newinfo, *info; 
+ void *pos, *entry0, *entry1; + struct compat_arpt_entry *iter0; + struct arpt_replace repl; + unsigned int size; + int ret; + + info = *pinfo; + entry0 = *pentry0; + size = compatr->size; + info->number = compatr->num_entries; + + j = 0; + xt_compat_lock(NFPROTO_ARP); + ret = xt_compat_init_offsets(NFPROTO_ARP, compatr->num_entries); + if (ret) + goto out_unlock; + /* Walk through entries, checking offsets. */ + xt_entry_foreach(iter0, entry0, compatr->size) { + ret = check_compat_entry_size_and_hooks(iter0, info, &size, + entry0, + entry0 + compatr->size); + if (ret != 0) + goto out_unlock; + ++j; + } + + ret = -EINVAL; + if (j != compatr->num_entries) + goto out_unlock; + + ret = -ENOMEM; + newinfo = xt_alloc_table_info(size); + if (!newinfo) + goto out_unlock; + + memset(newinfo->entries, 0, size); + + newinfo->number = compatr->num_entries; + for (i = 0; i < NF_ARP_NUMHOOKS; i++) { + newinfo->hook_entry[i] = compatr->hook_entry[i]; + newinfo->underflow[i] = compatr->underflow[i]; + } + entry1 = newinfo->entries; + pos = entry1; + size = compatr->size; + xt_entry_foreach(iter0, entry0, compatr->size) + compat_copy_entry_from_user(iter0, &pos, &size, + newinfo, entry1); + + /* all module references in entry0 are now gone */ + + xt_compat_flush_offsets(NFPROTO_ARP); + xt_compat_unlock(NFPROTO_ARP); + + memcpy(&repl, compatr, sizeof(*compatr)); + + for (i = 0; i < NF_ARP_NUMHOOKS; i++) { + repl.hook_entry[i] = newinfo->hook_entry[i]; + repl.underflow[i] = newinfo->underflow[i]; + } + + repl.num_counters = 0; + repl.counters = NULL; + repl.size = newinfo->size; + ret = translate_table(net, newinfo, entry1, &repl); + if (ret) + goto free_newinfo; + + *pinfo = newinfo; + *pentry0 = entry1; + xt_free_table_info(info); + return 0; + +free_newinfo: + xt_free_table_info(newinfo); + return ret; +out_unlock: + xt_compat_flush_offsets(NFPROTO_ARP); + xt_compat_unlock(NFPROTO_ARP); + xt_entry_foreach(iter0, entry0, compatr->size) { + if (j-- == 0) + break; + compat_release_entry(iter0); + } + return ret; +} + +static int compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct compat_arpt_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct arpt_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_compat_table(net, &newinfo, &loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, compat_ptr(tmp.counters)); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +static int compat_copy_entry_to_user(struct arpt_entry *e, void __user **dstptr, + compat_uint_t *size, + struct xt_counters *counters, + unsigned int i) +{ + struct xt_entry_target *t; + struct compat_arpt_entry __user *ce; + u_int16_t target_offset, next_offset; + compat_uint_t origsize; + int ret; + + origsize = *size; + ce = *dstptr; + if 
(copy_to_user(ce, e, sizeof(struct arpt_entry)) != 0 || + copy_to_user(&ce->counters, &counters[i], + sizeof(counters[i])) != 0) + return -EFAULT; + + *dstptr += sizeof(struct compat_arpt_entry); + *size -= sizeof(struct arpt_entry) - sizeof(struct compat_arpt_entry); + + target_offset = e->target_offset - (origsize - *size); + + t = arpt_get_target(e); + ret = xt_compat_target_to_user(t, dstptr, size); + if (ret) + return ret; + next_offset = e->next_offset - (origsize - *size); + if (put_user(target_offset, &ce->target_offset) != 0 || + put_user(next_offset, &ce->next_offset) != 0) + return -EFAULT; + return 0; +} + +static int compat_copy_entries_to_user(unsigned int total_size, + struct xt_table *table, + void __user *userptr) +{ + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + void __user *pos; + unsigned int size; + int ret = 0; + unsigned int i = 0; + struct arpt_entry *iter; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + pos = userptr; + size = total_size; + xt_entry_foreach(iter, private->entries, total_size) { + ret = compat_copy_entry_to_user(iter, &pos, + &size, counters, i++); + if (ret != 0) + break; + } + vfree(counters); + return ret; +} + +struct compat_arpt_get_entries { + char name[XT_TABLE_MAXNAMELEN]; + compat_uint_t size; + struct compat_arpt_entry entrytable[]; +}; + +static int compat_get_entries(struct net *net, + struct compat_arpt_get_entries __user *uptr, + int *len) +{ + int ret; + struct compat_arpt_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + if (*len != sizeof(struct compat_arpt_get_entries) + get.size) + return -EINVAL; + + get.name[sizeof(get.name) - 1] = '\0'; + + xt_compat_lock(NFPROTO_ARP); + t = xt_find_table_lock(net, NFPROTO_ARP, get.name); + if (!IS_ERR(t)) { + const struct xt_table_info *private = t->private; + struct xt_table_info info; + + ret = compat_table_info(private, &info); + if (!ret && get.size == info.size) { + ret = compat_copy_entries_to_user(private->size, + t, uptr->entrytable); + } else if (!ret) + ret = -EAGAIN; + + xt_compat_flush_offsets(NFPROTO_ARP); + module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + xt_compat_unlock(NFPROTO_ARP); + return ret; +} +#endif + +static int do_arpt_set_ctl(struct sock *sk, int cmd, sockptr_t arg, + unsigned int len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case ARPT_SO_SET_REPLACE: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_do_replace(sock_net(sk), arg, len); + else +#endif + ret = do_replace(sock_net(sk), arg, len); + break; + + case ARPT_SO_SET_ADD_COUNTERS: + ret = do_add_counters(sock_net(sk), arg, len); + break; + + default: + ret = -EINVAL; + } + + return ret; +} + +static int do_arpt_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case ARPT_SO_GET_INFO: + ret = get_info(sock_net(sk), user, len); + break; + + case ARPT_SO_GET_ENTRIES: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_get_entries(sock_net(sk), user, len); + else +#endif + ret = get_entries(sock_net(sk), user, len); + break; + + case ARPT_SO_GET_REVISION_TARGET: { + struct xt_get_revision rev; + + if (*len != sizeof(rev)) { + ret = -EINVAL; + 
break; + } + if (copy_from_user(&rev, user, sizeof(rev)) != 0) { + ret = -EFAULT; + break; + } + rev.name[sizeof(rev.name)-1] = 0; + + try_then_request_module(xt_find_revision(NFPROTO_ARP, rev.name, + rev.revision, 1, &ret), + "arpt_%s", rev.name); + break; + } + + default: + ret = -EINVAL; + } + + return ret; +} + +static void __arpt_unregister_table(struct net *net, struct xt_table *table) +{ + struct xt_table_info *private; + void *loc_cpu_entry; + struct module *table_owner = table->me; + struct arpt_entry *iter; + + private = xt_unregister_table(table); + + /* Decrease module usage counts and free resources */ + loc_cpu_entry = private->entries; + xt_entry_foreach(iter, loc_cpu_entry, private->size) + cleanup_entry(iter, net); + if (private->number > private->initial_entries) + module_put(table_owner); + xt_free_table_info(private); +} + +int arpt_register_table(struct net *net, + const struct xt_table *table, + const struct arpt_replace *repl, + const struct nf_hook_ops *template_ops) +{ + struct nf_hook_ops *ops; + unsigned int num_ops; + int ret, i; + struct xt_table_info *newinfo; + struct xt_table_info bootstrap = {0}; + void *loc_cpu_entry; + struct xt_table *new_table; + + newinfo = xt_alloc_table_info(repl->size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + memcpy(loc_cpu_entry, repl->entries, repl->size); + + ret = translate_table(net, newinfo, loc_cpu_entry, repl); + if (ret != 0) { + xt_free_table_info(newinfo); + return ret; + } + + new_table = xt_register_table(net, table, &bootstrap, newinfo); + if (IS_ERR(new_table)) { + struct arpt_entry *iter; + + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + xt_free_table_info(newinfo); + return PTR_ERR(new_table); + } + + num_ops = hweight32(table->valid_hooks); + if (num_ops == 0) { + ret = -EINVAL; + goto out_free; + } + + ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + if (!ops) { + ret = -ENOMEM; + goto out_free; + } + + for (i = 0; i < num_ops; i++) + ops[i].priv = new_table; + + new_table->ops = ops; + + ret = nf_register_net_hooks(net, ops, num_ops); + if (ret != 0) + goto out_free; + + return ret; + +out_free: + __arpt_unregister_table(net, new_table); + return ret; +} + +void arpt_unregister_table_pre_exit(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_ARP, name); + + if (table) + nf_unregister_net_hooks(net, table->ops, hweight32(table->valid_hooks)); +} +EXPORT_SYMBOL(arpt_unregister_table_pre_exit); + +void arpt_unregister_table(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_ARP, name); + + if (table) + __arpt_unregister_table(net, table); +} + +/* The built-in targets: standard (NULL) and error. 
*/ +static struct xt_target arpt_builtin_tg[] __read_mostly = { + { + .name = XT_STANDARD_TARGET, + .targetsize = sizeof(int), + .family = NFPROTO_ARP, +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(compat_int_t), + .compat_from_user = compat_standard_from_user, + .compat_to_user = compat_standard_to_user, +#endif + }, + { + .name = XT_ERROR_TARGET, + .target = arpt_error, + .targetsize = XT_FUNCTION_MAXNAMELEN, + .family = NFPROTO_ARP, + }, +}; + +static struct nf_sockopt_ops arpt_sockopts = { + .pf = PF_INET, + .set_optmin = ARPT_BASE_CTL, + .set_optmax = ARPT_SO_SET_MAX+1, + .set = do_arpt_set_ctl, + .get_optmin = ARPT_BASE_CTL, + .get_optmax = ARPT_SO_GET_MAX+1, + .get = do_arpt_get_ctl, + .owner = THIS_MODULE, +}; + +static int __net_init arp_tables_net_init(struct net *net) +{ + return xt_proto_init(net, NFPROTO_ARP); +} + +static void __net_exit arp_tables_net_exit(struct net *net) +{ + xt_proto_fini(net, NFPROTO_ARP); +} + +static struct pernet_operations arp_tables_net_ops = { + .init = arp_tables_net_init, + .exit = arp_tables_net_exit, +}; + +static int __init arp_tables_init(void) +{ + int ret; + + ret = register_pernet_subsys(&arp_tables_net_ops); + if (ret < 0) + goto err1; + + /* No one else will be downing sem now, so we won't sleep */ + ret = xt_register_targets(arpt_builtin_tg, ARRAY_SIZE(arpt_builtin_tg)); + if (ret < 0) + goto err2; + + /* Register setsockopt */ + ret = nf_register_sockopt(&arpt_sockopts); + if (ret < 0) + goto err4; + + return 0; + +err4: + xt_unregister_targets(arpt_builtin_tg, ARRAY_SIZE(arpt_builtin_tg)); +err2: + unregister_pernet_subsys(&arp_tables_net_ops); +err1: + return ret; +} + +static void __exit arp_tables_fini(void) +{ + nf_unregister_sockopt(&arpt_sockopts); + xt_unregister_targets(arpt_builtin_tg, ARRAY_SIZE(arpt_builtin_tg)); + unregister_pernet_subsys(&arp_tables_net_ops); +} + +EXPORT_SYMBOL(arpt_register_table); +EXPORT_SYMBOL(arpt_unregister_table); +EXPORT_SYMBOL(arpt_do_table); + +module_init(arp_tables_init); +module_exit(arp_tables_fini); diff --git a/net/ipv4/netfilter/arpt_mangle.c b/net/ipv4/netfilter/arpt_mangle.c new file mode 100644 index 000000000..a4e07e5e9 --- /dev/null +++ b/net/ipv4/netfilter/arpt_mangle.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* module that allows mangling of the arp payload */ +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Bart De Schuymer "); +MODULE_DESCRIPTION("arptables arp payload mangle target"); + +static unsigned int +target(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct arpt_mangle *mangle = par->targinfo; + const struct arphdr *arp; + unsigned char *arpptr; + int pln, hln; + + if (skb_ensure_writable(skb, skb->len)) + return NF_DROP; + + arp = arp_hdr(skb); + arpptr = skb_network_header(skb) + sizeof(*arp); + pln = arp->ar_pln; + hln = arp->ar_hln; + /* We assume that pln and hln were checked in the match */ + if (mangle->flags & ARPT_MANGLE_SDEV) { + if (ARPT_DEV_ADDR_LEN_MAX < hln || + (arpptr + hln > skb_tail_pointer(skb))) + return NF_DROP; + memcpy(arpptr, mangle->src_devaddr, hln); + } + arpptr += hln; + if (mangle->flags & ARPT_MANGLE_SIP) { + if (ARPT_MANGLE_ADDR_LEN_MAX < pln || + (arpptr + pln > skb_tail_pointer(skb))) + return NF_DROP; + memcpy(arpptr, &mangle->u_s.src_ip, pln); + } + arpptr += pln; + if (mangle->flags & ARPT_MANGLE_TDEV) { + if (ARPT_DEV_ADDR_LEN_MAX < hln || + (arpptr + hln > skb_tail_pointer(skb))) + return NF_DROP; + memcpy(arpptr, mangle->tgt_devaddr, hln); 
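The mangle target being defined here rewrites the variable-length ARP payload in place: after the fixed header come the sender hardware address (ar_hln bytes), the sender protocol address (ar_pln bytes), then the target pair, and every flag-guarded copy is bounds-checked against the end of the packet. A stand-alone userspace sketch of that same layout walk, using an illustrative arp_fixed struct and a hand-built packet rather than the kernel's arphdr and skb:

/* Illustrative sketch (not part of the patch): walking the variable-length
 * ARP payload in the order sender-hw, sender-ip, target-hw, target-ip,
 * bounds-checking each step, the way arpt_mangle does. */
#include <stdint.h>
#include <stdio.h>
#include <arpa/inet.h>

struct arp_fixed {             /* fixed part of an ARP header (RFC 826) */
    uint16_t ar_hrd, ar_pro;
    uint8_t  ar_hln, ar_pln;
    uint16_t ar_op;
} __attribute__((packed));

static int dump_arp(const uint8_t *pkt, size_t len)
{
    const struct arp_fixed *arp = (const void *)pkt;
    const uint8_t *p, *end = pkt + len;
    char ip[INET_ADDRSTRLEN];

    if (len < sizeof(*arp))
        return -1;
    p = pkt + sizeof(*arp);

    /* sender hardware address */
    if (p + arp->ar_hln > end) return -1;
    printf("sender hw : %02x:%02x:%02x:%02x:%02x:%02x\n",
           p[0], p[1], p[2], p[3], p[4], p[5]);
    p += arp->ar_hln;

    /* sender protocol (IPv4) address */
    if (p + arp->ar_pln > end || arp->ar_pln != 4) return -1;
    printf("sender ip : %s\n", inet_ntop(AF_INET, p, ip, sizeof(ip)));
    p += arp->ar_pln;

    /* target hardware address */
    if (p + arp->ar_hln > end) return -1;
    p += arp->ar_hln;

    /* target protocol address */
    if (p + arp->ar_pln > end) return -1;
    printf("target ip : %s\n", inet_ntop(AF_INET, p, ip, sizeof(ip)));
    return 0;
}

int main(void)
{
    /* hand-built Ethernet/IPv4 ARP request: 192.168.0.1 asks about 192.168.0.2 */
    uint8_t pkt[] = {
        0x00, 0x01, 0x08, 0x00, 6, 4, 0x00, 0x01,
        0x02, 0x00, 0x00, 0x00, 0x00, 0x01,   /* sender MAC */
        192, 168, 0, 1,                        /* sender IP  */
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00,   /* target MAC */
        192, 168, 0, 2,                        /* target IP  */
    };
    return dump_arp(pkt, sizeof(pkt));
}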
+ } + arpptr += hln; + if (mangle->flags & ARPT_MANGLE_TIP) { + if (ARPT_MANGLE_ADDR_LEN_MAX < pln || + (arpptr + pln > skb_tail_pointer(skb))) + return NF_DROP; + memcpy(arpptr, &mangle->u_t.tgt_ip, pln); + } + return mangle->target; +} + +static int checkentry(const struct xt_tgchk_param *par) +{ + const struct arpt_mangle *mangle = par->targinfo; + + if (mangle->flags & ~ARPT_MANGLE_MASK || + !(mangle->flags & ARPT_MANGLE_MASK)) + return -EINVAL; + + if (mangle->target != NF_DROP && mangle->target != NF_ACCEPT && + mangle->target != XT_CONTINUE) + return -EINVAL; + return 0; +} + +static struct xt_target arpt_mangle_reg __read_mostly = { + .name = "mangle", + .family = NFPROTO_ARP, + .target = target, + .targetsize = sizeof(struct arpt_mangle), + .checkentry = checkentry, + .me = THIS_MODULE, +}; + +static int __init arpt_mangle_init(void) +{ + return xt_register_target(&arpt_mangle_reg); +} + +static void __exit arpt_mangle_fini(void) +{ + xt_unregister_target(&arpt_mangle_reg); +} + +module_init(arpt_mangle_init); +module_exit(arpt_mangle_fini); diff --git a/net/ipv4/netfilter/arptable_filter.c b/net/ipv4/netfilter/arptable_filter.c new file mode 100644 index 000000000..78cd5ee24 --- /dev/null +++ b/net/ipv4/netfilter/arptable_filter.c @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Filtering ARP tables module. + * + * Copyright (C) 2002 David S. Miller (davem@redhat.com) + * + */ + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("David S. Miller "); +MODULE_DESCRIPTION("arptables filter table"); + +#define FILTER_VALID_HOOKS ((1 << NF_ARP_IN) | (1 << NF_ARP_OUT) | \ + (1 << NF_ARP_FORWARD)) + +static const struct xt_table packet_filter = { + .name = "filter", + .valid_hooks = FILTER_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_ARP, + .priority = NF_IP_PRI_FILTER, +}; + +static struct nf_hook_ops *arpfilter_ops __read_mostly; + +static int arptable_filter_table_init(struct net *net) +{ + struct arpt_replace *repl; + int err; + + repl = arpt_alloc_initial_table(&packet_filter); + if (repl == NULL) + return -ENOMEM; + err = arpt_register_table(net, &packet_filter, repl, arpfilter_ops); + kfree(repl); + return err; +} + +static void __net_exit arptable_filter_net_pre_exit(struct net *net) +{ + arpt_unregister_table_pre_exit(net, "filter"); +} + +static void __net_exit arptable_filter_net_exit(struct net *net) +{ + arpt_unregister_table(net, "filter"); +} + +static struct pernet_operations arptable_filter_net_ops = { + .exit = arptable_filter_net_exit, + .pre_exit = arptable_filter_net_pre_exit, +}; + +static int __init arptable_filter_init(void) +{ + int ret = xt_register_template(&packet_filter, + arptable_filter_table_init); + + if (ret < 0) + return ret; + + arpfilter_ops = xt_hook_ops_alloc(&packet_filter, arpt_do_table); + if (IS_ERR(arpfilter_ops)) { + xt_unregister_template(&packet_filter); + return PTR_ERR(arpfilter_ops); + } + + ret = register_pernet_subsys(&arptable_filter_net_ops); + if (ret < 0) { + xt_unregister_template(&packet_filter); + kfree(arpfilter_ops); + return ret; + } + + return ret; +} + +static void __exit arptable_filter_fini(void) +{ + unregister_pernet_subsys(&arptable_filter_net_ops); + xt_unregister_template(&packet_filter); + kfree(arpfilter_ops); +} + +module_init(arptable_filter_init); +module_exit(arptable_filter_fini); diff --git a/net/ipv4/netfilter/ip_tables.c b/net/ipv4/netfilter/ip_tables.c new file mode 100644 index 000000000..da5998011 --- /dev/null +++ b/net/ipv4/netfilter/ip_tables.c @@ 
-0,0 +1,1952 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Packet matching code. + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2005 Netfilter Core Team + * Copyright (C) 2006-2010 Patrick McHardy + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include "../../netfilter/xt_repldata.h" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("IPv4 packet filter"); +MODULE_ALIAS("ipt_icmp"); + +void *ipt_alloc_initial_table(const struct xt_table *info) +{ + return xt_alloc_initial_table(ipt, IPT); +} +EXPORT_SYMBOL_GPL(ipt_alloc_initial_table); + +/* Returns whether matches rule or not. */ +/* Performance critical - called for every packet */ +static inline bool +ip_packet_match(const struct iphdr *ip, + const char *indev, + const char *outdev, + const struct ipt_ip *ipinfo, + int isfrag) +{ + unsigned long ret; + + if (NF_INVF(ipinfo, IPT_INV_SRCIP, + (ip->saddr & ipinfo->smsk.s_addr) != ipinfo->src.s_addr) || + NF_INVF(ipinfo, IPT_INV_DSTIP, + (ip->daddr & ipinfo->dmsk.s_addr) != ipinfo->dst.s_addr)) + return false; + + ret = ifname_compare_aligned(indev, ipinfo->iniface, ipinfo->iniface_mask); + + if (NF_INVF(ipinfo, IPT_INV_VIA_IN, ret != 0)) + return false; + + ret = ifname_compare_aligned(outdev, ipinfo->outiface, ipinfo->outiface_mask); + + if (NF_INVF(ipinfo, IPT_INV_VIA_OUT, ret != 0)) + return false; + + /* Check specific protocol */ + if (ipinfo->proto && + NF_INVF(ipinfo, IPT_INV_PROTO, ip->protocol != ipinfo->proto)) + return false; + + /* If we have a fragment rule but the packet is not a fragment + * then we return zero */ + if (NF_INVF(ipinfo, IPT_INV_FRAG, + (ipinfo->flags & IPT_F_FRAG) && !isfrag)) + return false; + + return true; +} + +static bool +ip_checkentry(const struct ipt_ip *ip) +{ + if (ip->flags & ~IPT_F_MASK) + return false; + if (ip->invflags & ~IPT_INV_MASK) + return false; + return true; +} + +static unsigned int +ipt_error(struct sk_buff *skb, const struct xt_action_param *par) +{ + net_info_ratelimited("error: `%s'\n", (const char *)par->targinfo); + + return NF_DROP; +} + +/* Performance critical */ +static inline struct ipt_entry * +get_entry(const void *base, unsigned int offset) +{ + return (struct ipt_entry *)(base + offset); +} + +/* All zeroes == unconditional rule. 
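ip_packet_match() above folds rule negation into the comparison itself: each predicate is evaluated normally and the result is XORed with the rule's invert flag (the NF_INVF() helper), so a "! -s 10.0.0.0/8" style rule shares the code path of the plain rule. A small stand-alone sketch of that XOR-with-invert-bit shape; the struct and function names here are illustrative, not kernel ones:

/* Illustrative sketch of the invert-flag matching idiom used above:
 * compute the plain predicate, then XOR with the rule's invert bit. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct addr_rule {
    uint32_t addr, mask;   /* match (packet & mask) == addr  */
    bool     invert;       /* true for "! -s ..." style rules */
};

static bool addr_matches(const struct addr_rule *r, uint32_t packet_addr)
{
    bool hit = (packet_addr & r->mask) == r->addr;

    return hit ^ r->invert;   /* same shape as the kernel's NF_INVF() */
}

int main(void)
{
    struct addr_rule plain  = { .addr = 0x0a000000, .mask = 0xff000000, .invert = false };
    struct addr_rule negate = { .addr = 0x0a000000, .mask = 0xff000000, .invert = true  };
    uint32_t pkt = 0x0a010203;   /* 10.1.2.3 */

    printf("plain rule matches : %d\n", addr_matches(&plain, pkt));   /* 1 */
    printf("negated rule       : %d\n", addr_matches(&negate, pkt));  /* 0 */
    return 0;
}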
*/ +/* Mildly perf critical (only if packet tracing is on) */ +static inline bool unconditional(const struct ipt_entry *e) +{ + static const struct ipt_ip uncond; + + return e->target_offset == sizeof(struct ipt_entry) && + memcmp(&e->ip, &uncond, sizeof(uncond)) == 0; +} + +/* for const-correctness */ +static inline const struct xt_entry_target * +ipt_get_target_c(const struct ipt_entry *e) +{ + return ipt_get_target((struct ipt_entry *)e); +} + +#if IS_ENABLED(CONFIG_NETFILTER_XT_TARGET_TRACE) +static const char *const hooknames[] = { + [NF_INET_PRE_ROUTING] = "PREROUTING", + [NF_INET_LOCAL_IN] = "INPUT", + [NF_INET_FORWARD] = "FORWARD", + [NF_INET_LOCAL_OUT] = "OUTPUT", + [NF_INET_POST_ROUTING] = "POSTROUTING", +}; + +enum nf_ip_trace_comments { + NF_IP_TRACE_COMMENT_RULE, + NF_IP_TRACE_COMMENT_RETURN, + NF_IP_TRACE_COMMENT_POLICY, +}; + +static const char *const comments[] = { + [NF_IP_TRACE_COMMENT_RULE] = "rule", + [NF_IP_TRACE_COMMENT_RETURN] = "return", + [NF_IP_TRACE_COMMENT_POLICY] = "policy", +}; + +static const struct nf_loginfo trace_loginfo = { + .type = NF_LOG_TYPE_LOG, + .u = { + .log = { + .level = 4, + .logflags = NF_LOG_DEFAULT_MASK, + }, + }, +}; + +/* Mildly perf critical (only if packet tracing is on) */ +static inline int +get_chainname_rulenum(const struct ipt_entry *s, const struct ipt_entry *e, + const char *hookname, const char **chainname, + const char **comment, unsigned int *rulenum) +{ + const struct xt_standard_target *t = (void *)ipt_get_target_c(s); + + if (strcmp(t->target.u.kernel.target->name, XT_ERROR_TARGET) == 0) { + /* Head of user chain: ERROR target with chainname */ + *chainname = t->target.data; + (*rulenum) = 0; + } else if (s == e) { + (*rulenum)++; + + if (unconditional(s) && + strcmp(t->target.u.kernel.target->name, + XT_STANDARD_TARGET) == 0 && + t->verdict < 0) { + /* Tail of chains: STANDARD target (return/policy) */ + *comment = *chainname == hookname + ? comments[NF_IP_TRACE_COMMENT_POLICY] + : comments[NF_IP_TRACE_COMMENT_RETURN]; + } + return 1; + } else + (*rulenum)++; + + return 0; +} + +static void trace_packet(struct net *net, + const struct sk_buff *skb, + unsigned int hook, + const struct net_device *in, + const struct net_device *out, + const char *tablename, + const struct xt_table_info *private, + const struct ipt_entry *e) +{ + const struct ipt_entry *root; + const char *hookname, *chainname, *comment; + const struct ipt_entry *iter; + unsigned int rulenum = 0; + + root = get_entry(private->entries, private->hook_entry[hook]); + + hookname = chainname = hooknames[hook]; + comment = comments[NF_IP_TRACE_COMMENT_RULE]; + + xt_entry_foreach(iter, root, private->size - private->hook_entry[hook]) + if (get_chainname_rulenum(iter, e, hookname, + &chainname, &comment, &rulenum) != 0) + break; + + nf_log_trace(net, AF_INET, hook, skb, in, out, &trace_loginfo, + "TRACE: %s:%s:%s:%u ", + tablename, chainname, comment, rulenum); +} +#endif + +static inline +struct ipt_entry *ipt_next_entry(const struct ipt_entry *entry) +{ + return (void *)entry + entry->next_offset; +} + +/* Returns one of the generic firewall policies, like NF_ACCEPT. */ +unsigned int +ipt_do_table(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct xt_table *table = priv; + unsigned int hook = state->hook; + static const char nulldevname[IFNAMSIZ] __attribute__((aligned(sizeof(long)))); + const struct iphdr *ip; + /* Initializing verdict to NF_DROP keeps gcc happy. 
*/ + unsigned int verdict = NF_DROP; + const char *indev, *outdev; + const void *table_base; + struct ipt_entry *e, **jumpstack; + unsigned int stackidx, cpu; + const struct xt_table_info *private; + struct xt_action_param acpar; + unsigned int addend; + + /* Initialization */ + stackidx = 0; + ip = ip_hdr(skb); + indev = state->in ? state->in->name : nulldevname; + outdev = state->out ? state->out->name : nulldevname; + /* We handle fragments by dealing with the first fragment as + * if it was a normal packet. All other fragments are treated + * normally, except that they will NEVER match rules that ask + * things we don't know, ie. tcp syn flag or ports). If the + * rule is also a fragment-specific rule, non-fragments won't + * match it. */ + acpar.fragoff = ntohs(ip->frag_off) & IP_OFFSET; + acpar.thoff = ip_hdrlen(skb); + acpar.hotdrop = false; + acpar.state = state; + + WARN_ON(!(table->valid_hooks & (1 << hook))); + local_bh_disable(); + addend = xt_write_recseq_begin(); + private = READ_ONCE(table->private); /* Address dependency. */ + cpu = smp_processor_id(); + table_base = private->entries; + jumpstack = (struct ipt_entry **)private->jumpstack[cpu]; + + /* Switch to alternate jumpstack if we're being invoked via TEE. + * TEE issues XT_CONTINUE verdict on original skb so we must not + * clobber the jumpstack. + * + * For recursion via REJECT or SYNPROXY the stack will be clobbered + * but it is no problem since absolute verdict is issued by these. + */ + if (static_key_false(&xt_tee_enabled)) + jumpstack += private->stacksize * __this_cpu_read(nf_skb_duplicated); + + e = get_entry(table_base, private->hook_entry[hook]); + + do { + const struct xt_entry_target *t; + const struct xt_entry_match *ematch; + struct xt_counters *counter; + + WARN_ON(!e); + if (!ip_packet_match(ip, indev, outdev, + &e->ip, acpar.fragoff)) { + no_match: + e = ipt_next_entry(e); + continue; + } + + xt_ematch_foreach(ematch, e) { + acpar.match = ematch->u.kernel.match; + acpar.matchinfo = ematch->data; + if (!acpar.match->match(skb, &acpar)) + goto no_match; + } + + counter = xt_get_this_cpu_counter(&e->counters); + ADD_COUNTER(*counter, skb->len, 1); + + t = ipt_get_target_c(e); + WARN_ON(!t->u.kernel.target); + +#if IS_ENABLED(CONFIG_NETFILTER_XT_TARGET_TRACE) + /* The packet is traced: log it */ + if (unlikely(skb->nf_trace)) + trace_packet(state->net, skb, hook, state->in, + state->out, table->name, private, e); +#endif + /* Standard target? */ + if (!t->u.kernel.target->target) { + int v; + + v = ((struct xt_standard_target *)t)->verdict; + if (v < 0) { + /* Pop from stack? */ + if (v != XT_RETURN) { + verdict = (unsigned int)(-v) - 1; + break; + } + if (stackidx == 0) { + e = get_entry(table_base, + private->underflow[hook]); + } else { + e = jumpstack[--stackidx]; + e = ipt_next_entry(e); + } + continue; + } + if (table_base + v != ipt_next_entry(e) && + !(e->ip.flags & IPT_F_GOTO)) { + if (unlikely(stackidx >= private->stacksize)) { + verdict = NF_DROP; + break; + } + jumpstack[stackidx++] = e; + } + + e = get_entry(table_base, v); + continue; + } + + acpar.target = t->u.kernel.target; + acpar.targinfo = t->data; + + verdict = t->u.kernel.target->target(skb, &acpar); + if (verdict == XT_CONTINUE) { + /* Target might have changed stuff. 
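The traversal above never follows pointers: rules live in a single flat blob, each ipt_entry records its own total length in next_offset, and both "fall through to the next rule" and "jump to offset v" are plain byte offsets from the start of the blob. A toy stand-alone model of that layout, with a made-up toy_entry type standing in for an ipt_entry plus its matches and target:

/* Toy model of the flat rule blob walked above: each entry stores its
 * own size, so iteration and jumps are byte-offset arithmetic. */
#include <stdint.h>
#include <stdio.h>

struct toy_entry {
    uint16_t next_offset;   /* total size of this entry in bytes      */
    int      verdict;       /* payload stands in for matches + target */
};

static const struct toy_entry *entry_at(const uint8_t *blob, unsigned int off)
{
    return (const struct toy_entry *)(blob + off);
}

int main(void)
{
    struct toy_entry storage[3] = {
        { .next_offset = sizeof(struct toy_entry), .verdict = -1 },
        { .next_offset = sizeof(struct toy_entry), .verdict = -2 },
        { .next_offset = sizeof(struct toy_entry), .verdict = 42 },
    };
    const uint8_t *blob = (const uint8_t *)storage;
    unsigned int off, size = sizeof(storage);

    /* xt_entry_foreach()-style walk: advance by each entry's own size */
    for (off = 0; off < size; off += entry_at(blob, off)->next_offset)
        printf("entry at offset %2u, verdict %d\n",
               off, entry_at(blob, off)->verdict);
    return 0;
}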
*/ + ip = ip_hdr(skb); + e = ipt_next_entry(e); + } else { + /* Verdict */ + break; + } + } while (!acpar.hotdrop); + + xt_write_recseq_end(addend); + local_bh_enable(); + + if (acpar.hotdrop) + return NF_DROP; + else return verdict; +} + +/* Figures out from what hook each rule can be called: returns 0 if + there are loops. Puts hook bitmask in comefrom. */ +static int +mark_source_chains(const struct xt_table_info *newinfo, + unsigned int valid_hooks, void *entry0, + unsigned int *offsets) +{ + unsigned int hook; + + /* No recursion; use packet counter to save back ptrs (reset + to 0 as we leave), and comefrom to save source hook bitmask */ + for (hook = 0; hook < NF_INET_NUMHOOKS; hook++) { + unsigned int pos = newinfo->hook_entry[hook]; + struct ipt_entry *e = entry0 + pos; + + if (!(valid_hooks & (1 << hook))) + continue; + + /* Set initial back pointer. */ + e->counters.pcnt = pos; + + for (;;) { + const struct xt_standard_target *t + = (void *)ipt_get_target_c(e); + int visited = e->comefrom & (1 << hook); + + if (e->comefrom & (1 << NF_INET_NUMHOOKS)) + return 0; + + e->comefrom |= ((1 << hook) | (1 << NF_INET_NUMHOOKS)); + + /* Unconditional return/END. */ + if ((unconditional(e) && + (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0) && + t->verdict < 0) || visited) { + unsigned int oldpos, size; + + /* Return: backtrack through the last + big jump. */ + do { + e->comefrom ^= (1<<NF_INET_NUMHOOKS); + oldpos = pos; + pos = e->counters.pcnt; + e->counters.pcnt = 0; + + /* We're at the start. */ + if (pos == oldpos) + goto next; + + e = entry0 + pos; + } while (oldpos == pos + e->next_offset); + + /* Move along one */ + size = e->next_offset; + e = entry0 + pos + size; + if (pos + size >= newinfo->size) + return 0; + e->counters.pcnt = pos; + pos += size; + } else { + int newpos = t->verdict; + + if (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0 && + newpos >= 0) { + /* This a jump; chase it. */ + if (!xt_find_jump_offset(offsets, newpos, + newinfo->number)) + return 0; + } else { + /* ... 
this is a fallthru */ + newpos = pos + e->next_offset; + if (newpos >= newinfo->size) + return 0; + } + e = entry0 + newpos; + e->counters.pcnt = pos; + pos = newpos; + } + } +next: ; + } + return 1; +} + +static void cleanup_match(struct xt_entry_match *m, struct net *net) +{ + struct xt_mtdtor_param par; + + par.net = net; + par.match = m->u.kernel.match; + par.matchinfo = m->data; + par.family = NFPROTO_IPV4; + if (par.match->destroy != NULL) + par.match->destroy(&par); + module_put(par.match->me); +} + +static int +check_match(struct xt_entry_match *m, struct xt_mtchk_param *par) +{ + const struct ipt_ip *ip = par->entryinfo; + + par->match = m->u.kernel.match; + par->matchinfo = m->data; + + return xt_check_match(par, m->u.match_size - sizeof(*m), + ip->proto, ip->invflags & IPT_INV_PROTO); +} + +static int +find_check_match(struct xt_entry_match *m, struct xt_mtchk_param *par) +{ + struct xt_match *match; + int ret; + + match = xt_request_find_match(NFPROTO_IPV4, m->u.user.name, + m->u.user.revision); + if (IS_ERR(match)) + return PTR_ERR(match); + m->u.kernel.match = match; + + ret = check_match(m, par); + if (ret) + goto err; + + return 0; +err: + module_put(m->u.kernel.match->me); + return ret; +} + +static int check_target(struct ipt_entry *e, struct net *net, const char *name) +{ + struct xt_entry_target *t = ipt_get_target(e); + struct xt_tgchk_param par = { + .net = net, + .table = name, + .entryinfo = e, + .target = t->u.kernel.target, + .targinfo = t->data, + .hook_mask = e->comefrom, + .family = NFPROTO_IPV4, + }; + + return xt_check_target(&par, t->u.target_size - sizeof(*t), + e->ip.proto, e->ip.invflags & IPT_INV_PROTO); +} + +static int +find_check_entry(struct ipt_entry *e, struct net *net, const char *name, + unsigned int size, + struct xt_percpu_counter_alloc_state *alloc_state) +{ + struct xt_entry_target *t; + struct xt_target *target; + int ret; + unsigned int j; + struct xt_mtchk_param mtpar; + struct xt_entry_match *ematch; + + if (!xt_percpu_counter_alloc(alloc_state, &e->counters)) + return -ENOMEM; + + j = 0; + memset(&mtpar, 0, sizeof(mtpar)); + mtpar.net = net; + mtpar.table = name; + mtpar.entryinfo = &e->ip; + mtpar.hook_mask = e->comefrom; + mtpar.family = NFPROTO_IPV4; + xt_ematch_foreach(ematch, e) { + ret = find_check_match(ematch, &mtpar); + if (ret != 0) + goto cleanup_matches; + ++j; + } + + t = ipt_get_target(e); + target = xt_request_find_target(NFPROTO_IPV4, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto cleanup_matches; + } + t->u.kernel.target = target; + + ret = check_target(e, net, name); + if (ret) + goto err; + + return 0; + err: + module_put(t->u.kernel.target->me); + cleanup_matches: + xt_ematch_foreach(ematch, e) { + if (j-- == 0) + break; + cleanup_match(ematch, net); + } + + xt_percpu_counter_free(&e->counters); + + return ret; +} + +static bool check_underflow(const struct ipt_entry *e) +{ + const struct xt_entry_target *t; + unsigned int verdict; + + if (!unconditional(e)) + return false; + t = ipt_get_target_c(e); + if (strcmp(t->u.user.name, XT_STANDARD_TARGET) != 0) + return false; + verdict = ((struct xt_standard_target *)t)->verdict; + verdict = -verdict - 1; + return verdict == NF_DROP || verdict == NF_ACCEPT; +} + +static int +check_entry_size_and_hooks(struct ipt_entry *e, + struct xt_table_info *newinfo, + const unsigned char *base, + const unsigned char *limit, + const unsigned int *hook_entries, + const unsigned int *underflows, + unsigned int valid_hooks) +{ + unsigned int 
h; + int err; + + if ((unsigned long)e % __alignof__(struct ipt_entry) != 0 || + (unsigned char *)e + sizeof(struct ipt_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset + < sizeof(struct ipt_entry) + sizeof(struct xt_entry_target)) + return -EINVAL; + + if (!ip_checkentry(&e->ip)) + return -EINVAL; + + err = xt_check_entry_offsets(e, e->elems, e->target_offset, + e->next_offset); + if (err) + return err; + + /* Check hooks & underflows */ + for (h = 0; h < NF_INET_NUMHOOKS; h++) { + if (!(valid_hooks & (1 << h))) + continue; + if ((unsigned char *)e - base == hook_entries[h]) + newinfo->hook_entry[h] = hook_entries[h]; + if ((unsigned char *)e - base == underflows[h]) { + if (!check_underflow(e)) + return -EINVAL; + + newinfo->underflow[h] = underflows[h]; + } + } + + /* Clear counters and comefrom */ + e->counters = ((struct xt_counters) { 0, 0 }); + e->comefrom = 0; + return 0; +} + +static void +cleanup_entry(struct ipt_entry *e, struct net *net) +{ + struct xt_tgdtor_param par; + struct xt_entry_target *t; + struct xt_entry_match *ematch; + + /* Cleanup all matches */ + xt_ematch_foreach(ematch, e) + cleanup_match(ematch, net); + t = ipt_get_target(e); + + par.net = net; + par.target = t->u.kernel.target; + par.targinfo = t->data; + par.family = NFPROTO_IPV4; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); + xt_percpu_counter_free(&e->counters); +} + +/* Checks and translates the user-supplied table segment (held in + newinfo) */ +static int +translate_table(struct net *net, struct xt_table_info *newinfo, void *entry0, + const struct ipt_replace *repl) +{ + struct xt_percpu_counter_alloc_state alloc_state = { 0 }; + struct ipt_entry *iter; + unsigned int *offsets; + unsigned int i; + int ret = 0; + + newinfo->size = repl->size; + newinfo->number = repl->num_entries; + + /* Init all hooks to impossible value. */ + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + newinfo->hook_entry[i] = 0xFFFFFFFF; + newinfo->underflow[i] = 0xFFFFFFFF; + } + + offsets = xt_alloc_entry_offsets(newinfo->number); + if (!offsets) + return -ENOMEM; + i = 0; + /* Walk through entries, checking offsets. 
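check_underflow(), used just above when validating the per-hook underflow entries, depends on how a standard target encodes its verdict: values greater than or equal to zero are jump offsets into the blob, while absolute verdicts are stored as -(verdict) - 1, which the check undoes with verdict = -verdict - 1 before comparing against NF_DROP and NF_ACCEPT. A few stand-alone lines making that mapping concrete; the NF_* constants are copied here so the sketch builds without kernel headers:

/* Stand-alone illustration of the standard-target verdict encoding
 * decoded above: >= 0 means "jump to this byte offset", < 0 means an
 * absolute verdict stored as -(verdict) - 1. */
#include <stdio.h>

#define NF_DROP    0
#define NF_ACCEPT  1

static int encode_verdict(unsigned int v)  { return -(int)v - 1; }
static unsigned int decode_verdict(int v)  { return (unsigned int)(-v) - 1; }

int main(void)
{
    int stored_drop = encode_verdict(NF_DROP);     /* -1 */
    int stored_accept = encode_verdict(NF_ACCEPT); /* -2 */

    printf("DROP   stored as %d, decoded back to %u\n",
           stored_drop, decode_verdict(stored_drop));
    printf("ACCEPT stored as %d, decoded back to %u\n",
           stored_accept, decode_verdict(stored_accept));
    printf("a stored value of %d would instead be a jump to offset %d\n",
           112, 112);
    return 0;
}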
*/ + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = check_entry_size_and_hooks(iter, newinfo, entry0, + entry0 + repl->size, + repl->hook_entry, + repl->underflow, + repl->valid_hooks); + if (ret != 0) + goto out_free; + if (i < repl->num_entries) + offsets[i] = (void *)iter - entry0; + ++i; + if (strcmp(ipt_get_target(iter)->u.user.name, + XT_ERROR_TARGET) == 0) + ++newinfo->stacksize; + } + + ret = -EINVAL; + if (i != repl->num_entries) + goto out_free; + + ret = xt_check_table_hooks(newinfo, repl->valid_hooks); + if (ret) + goto out_free; + + if (!mark_source_chains(newinfo, repl->valid_hooks, entry0, offsets)) { + ret = -ELOOP; + goto out_free; + } + kvfree(offsets); + + /* Finally, each sanity check must pass */ + i = 0; + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = find_check_entry(iter, net, repl->name, repl->size, + &alloc_state); + if (ret != 0) + break; + ++i; + } + + if (ret != 0) { + xt_entry_foreach(iter, entry0, newinfo->size) { + if (i-- == 0) + break; + cleanup_entry(iter, net); + } + return ret; + } + + return ret; + out_free: + kvfree(offsets); + return ret; +} + +static void +get_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct ipt_entry *iter; + unsigned int cpu; + unsigned int i; + + for_each_possible_cpu(cpu) { + seqcount_t *s = &per_cpu(xt_recseq, cpu); + + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + struct xt_counters *tmp; + u64 bcnt, pcnt; + unsigned int start; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + do { + start = read_seqcount_begin(s); + bcnt = tmp->bcnt; + pcnt = tmp->pcnt; + } while (read_seqcount_retry(s, start)); + + ADD_COUNTER(counters[i], bcnt, pcnt); + ++i; /* macro does multi eval of i */ + cond_resched(); + } + } +} + +static void get_old_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct ipt_entry *iter; + unsigned int cpu, i; + + for_each_possible_cpu(cpu) { + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + const struct xt_counters *tmp; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + ADD_COUNTER(counters[i], tmp->bcnt, tmp->pcnt); + ++i; /* macro does multi eval of i */ + } + + cond_resched(); + } +} + +static struct xt_counters *alloc_counters(const struct xt_table *table) +{ + unsigned int countersize; + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + + /* We need atomic snapshot of counters: rest doesn't change + (other than comefrom, which userspace doesn't care + about). */ + countersize = sizeof(struct xt_counters) * private->number; + counters = vzalloc(countersize); + + if (counters == NULL) + return ERR_PTR(-ENOMEM); + + get_counters(private, counters); + + return counters; +} + +static int +copy_entries_to_user(unsigned int total_size, + const struct xt_table *table, + void __user *userptr) +{ + unsigned int off, num; + const struct ipt_entry *e; + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + int ret = 0; + const void *loc_cpu_entry; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + loc_cpu_entry = private->entries; + + /* FIXME: use iterator macros --RR */ + /* ... 
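get_counters() above sums the per-CPU byte and packet counters while other CPUs may still be updating them, so each per-entry read is wrapped in a seqcount begin/retry pair. A simplified, single-threaded userspace sketch of that snapshot pattern; the pcpu_counter type and its plain seq field are stand-ins for the kernel's per-CPU counters and xt_recseq machinery:

/* Sketch of the seqcount-style snapshot used above: readers retry if
 * the sequence number changed or is odd (writer in progress). */
#include <stdint.h>
#include <stdio.h>

#define NR_CPUS 4

struct pcpu_counter {
    unsigned int seq;        /* bumped before and after each update */
    uint64_t bcnt, pcnt;
};

static struct pcpu_counter counters[NR_CPUS];

static void writer_update(struct pcpu_counter *c, uint64_t bytes)
{
    c->seq++;                /* odd: update in flight */
    c->bcnt += bytes;
    c->pcnt += 1;
    c->seq++;                /* even again: stable */
}

static void snapshot(uint64_t *bytes, uint64_t *pkts)
{
    *bytes = *pkts = 0;
    for (int cpu = 0; cpu < NR_CPUS; cpu++) {
        const struct pcpu_counter *c = &counters[cpu];
        unsigned int start;
        uint64_t b, p;

        do {
            start = c->seq;                  /* read_seqcount_begin() */
            b = c->bcnt;
            p = c->pcnt;
        } while (c->seq != start || (start & 1)); /* ..._retry() */

        *bytes += b;
        *pkts += p;
    }
}

int main(void)
{
    uint64_t b, p;

    writer_update(&counters[0], 1500);
    writer_update(&counters[2], 60);
    snapshot(&b, &p);
    printf("total: %llu bytes, %llu packets\n",
           (unsigned long long)b, (unsigned long long)p);
    return 0;
}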
then go back and fix counters and names */ + for (off = 0, num = 0; off < total_size; off += e->next_offset, num++){ + unsigned int i; + const struct xt_entry_match *m; + const struct xt_entry_target *t; + + e = loc_cpu_entry + off; + if (copy_to_user(userptr + off, e, sizeof(*e))) { + ret = -EFAULT; + goto free_counters; + } + if (copy_to_user(userptr + off + + offsetof(struct ipt_entry, counters), + &counters[num], + sizeof(counters[num])) != 0) { + ret = -EFAULT; + goto free_counters; + } + + for (i = sizeof(struct ipt_entry); + i < e->target_offset; + i += m->u.match_size) { + m = (void *)e + i; + + if (xt_match_to_user(m, userptr + off + i)) { + ret = -EFAULT; + goto free_counters; + } + } + + t = ipt_get_target_c(e); + if (xt_target_to_user(t, userptr + off + e->target_offset)) { + ret = -EFAULT; + goto free_counters; + } + } + + free_counters: + vfree(counters); + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +static void compat_standard_from_user(void *dst, const void *src) +{ + int v = *(compat_int_t *)src; + + if (v > 0) + v += xt_compat_calc_jump(AF_INET, v); + memcpy(dst, &v, sizeof(v)); +} + +static int compat_standard_to_user(void __user *dst, const void *src) +{ + compat_int_t cv = *(int *)src; + + if (cv > 0) + cv -= xt_compat_calc_jump(AF_INET, cv); + return copy_to_user(dst, &cv, sizeof(cv)) ? -EFAULT : 0; +} + +static int compat_calc_entry(const struct ipt_entry *e, + const struct xt_table_info *info, + const void *base, struct xt_table_info *newinfo) +{ + const struct xt_entry_match *ematch; + const struct xt_entry_target *t; + unsigned int entry_offset; + int off, i, ret; + + off = sizeof(struct ipt_entry) - sizeof(struct compat_ipt_entry); + entry_offset = (void *)e - base; + xt_ematch_foreach(ematch, e) + off += xt_compat_match_offset(ematch->u.kernel.match); + t = ipt_get_target_c(e); + off += xt_compat_target_offset(t->u.kernel.target); + newinfo->size -= off; + ret = xt_compat_add_offset(AF_INET, entry_offset, off); + if (ret) + return ret; + + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + if (info->hook_entry[i] && + (e < (struct ipt_entry *)(base + info->hook_entry[i]))) + newinfo->hook_entry[i] -= off; + if (info->underflow[i] && + (e < (struct ipt_entry *)(base + info->underflow[i]))) + newinfo->underflow[i] -= off; + } + return 0; +} + +static int compat_table_info(const struct xt_table_info *info, + struct xt_table_info *newinfo) +{ + struct ipt_entry *iter; + const void *loc_cpu_entry; + int ret; + + if (!newinfo || !info) + return -EINVAL; + + /* we dont care about newinfo->entries */ + memcpy(newinfo, info, offsetof(struct xt_table_info, entries)); + newinfo->initial_entries = 0; + loc_cpu_entry = info->entries; + ret = xt_compat_init_offsets(AF_INET, info->number); + if (ret) + return ret; + xt_entry_foreach(iter, loc_cpu_entry, info->size) { + ret = compat_calc_entry(iter, info, loc_cpu_entry, newinfo); + if (ret != 0) + return ret; + } + return 0; +} +#endif + +static int get_info(struct net *net, void __user *user, const int *len) +{ + char name[XT_TABLE_MAXNAMELEN]; + struct xt_table *t; + int ret; + + if (*len != sizeof(struct ipt_getinfo)) + return -EINVAL; + + if (copy_from_user(name, user, sizeof(name)) != 0) + return -EFAULT; + + name[XT_TABLE_MAXNAMELEN-1] = '\0'; +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_lock(AF_INET); +#endif + t = xt_request_find_table_lock(net, AF_INET, name); + if (!IS_ERR(t)) { + struct ipt_getinfo info; + const struct xt_table_info *private = t->private; +#ifdef 
CONFIG_NETFILTER_XTABLES_COMPAT + struct xt_table_info tmp; + + if (in_compat_syscall()) { + ret = compat_table_info(private, &tmp); + xt_compat_flush_offsets(AF_INET); + private = &tmp; + } +#endif + memset(&info, 0, sizeof(info)); + info.valid_hooks = t->valid_hooks; + memcpy(info.hook_entry, private->hook_entry, + sizeof(info.hook_entry)); + memcpy(info.underflow, private->underflow, + sizeof(info.underflow)); + info.num_entries = private->number; + info.size = private->size; + strcpy(info.name, name); + + if (copy_to_user(user, &info, *len) != 0) + ret = -EFAULT; + else + ret = 0; + + xt_table_unlock(t); + module_put(t->me); + } else + ret = PTR_ERR(t); +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_unlock(AF_INET); +#endif + return ret; +} + +static int +get_entries(struct net *net, struct ipt_get_entries __user *uptr, + const int *len) +{ + int ret; + struct ipt_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + if (*len != sizeof(struct ipt_get_entries) + get.size) + return -EINVAL; + get.name[sizeof(get.name) - 1] = '\0'; + + t = xt_find_table_lock(net, AF_INET, get.name); + if (!IS_ERR(t)) { + const struct xt_table_info *private = t->private; + if (get.size == private->size) + ret = copy_entries_to_user(private->size, + t, uptr->entrytable); + else + ret = -EAGAIN; + + module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + return ret; +} + +static int +__do_replace(struct net *net, const char *name, unsigned int valid_hooks, + struct xt_table_info *newinfo, unsigned int num_counters, + void __user *counters_ptr) +{ + int ret; + struct xt_table *t; + struct xt_table_info *oldinfo; + struct xt_counters *counters; + struct ipt_entry *iter; + + counters = xt_counters_alloc(num_counters); + if (!counters) { + ret = -ENOMEM; + goto out; + } + + t = xt_request_find_table_lock(net, AF_INET, name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free_newinfo_counters_untrans; + } + + /* You lied! 
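get_info() above is the kernel half of a getsockopt() round trip: userspace passes in the table name and gets back entry counts, sizes, and hook offsets. A sketch of roughly what the iptables binary does on its side, assuming the standard uapi header <linux/netfilter_ipv4/ip_tables.h> and CAP_NET_ADMIN privileges:

/* Rough userspace counterpart of get_info(): ask the kernel how big
 * the "filter" table currently is. Needs CAP_NET_ADMIN to run. */
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <linux/netfilter_ipv4/ip_tables.h>

int main(void)
{
    struct ipt_getinfo info;
    socklen_t len = sizeof(info);
    int fd = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);

    if (fd < 0) {
        perror("socket (need CAP_NET_ADMIN)");
        return 1;
    }

    memset(&info, 0, sizeof(info));
    strncpy(info.name, "filter", sizeof(info.name) - 1);

    /* this lands in do_ipt_get_ctl() -> get_info() shown above */
    if (getsockopt(fd, IPPROTO_IP, IPT_SO_GET_INFO, &info, &len) < 0) {
        perror("getsockopt(IPT_SO_GET_INFO)");
        close(fd);
        return 1;
    }

    printf("table %s: %u entries, %u bytes, valid hooks 0x%x\n",
           info.name, info.num_entries, info.size, info.valid_hooks);
    close(fd);
    return 0;
}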
*/ + if (valid_hooks != t->valid_hooks) { + ret = -EINVAL; + goto put_module; + } + + oldinfo = xt_replace_table(t, num_counters, newinfo, &ret); + if (!oldinfo) + goto put_module; + + /* Update module usage count based on number of rules */ + if ((oldinfo->number > oldinfo->initial_entries) || + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + if ((oldinfo->number > oldinfo->initial_entries) && + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + + xt_table_unlock(t); + + get_old_counters(oldinfo, counters); + + /* Decrease module usage counts and free resource */ + xt_entry_foreach(iter, oldinfo->entries, oldinfo->size) + cleanup_entry(iter, net); + + xt_free_table_info(oldinfo); + if (copy_to_user(counters_ptr, counters, + sizeof(struct xt_counters) * num_counters) != 0) { + /* Silent error, can't fail, new table is already in place */ + net_warn_ratelimited("iptables: counters copy to user failed while replacing table\n"); + } + vfree(counters); + return 0; + + put_module: + module_put(t->me); + xt_table_unlock(t); + free_newinfo_counters_untrans: + vfree(counters); + out: + return ret; +} + +static int +do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct ipt_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct ipt_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_table(net, newinfo, loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, tmp.counters); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +static int +do_add_counters(struct net *net, sockptr_t arg, unsigned int len) +{ + unsigned int i; + struct xt_counters_info tmp; + struct xt_counters *paddc; + struct xt_table *t; + const struct xt_table_info *private; + int ret = 0; + struct ipt_entry *iter; + unsigned int addend; + + paddc = xt_copy_counters(arg, len, &tmp); + if (IS_ERR(paddc)) + return PTR_ERR(paddc); + + t = xt_find_table_lock(net, AF_INET, tmp.name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free; + } + + local_bh_disable(); + private = t->private; + if (private->number != tmp.num_counters) { + ret = -EINVAL; + goto unlock_up_free; + } + + i = 0; + addend = xt_write_recseq_begin(); + xt_entry_foreach(iter, private->entries, private->size) { + struct xt_counters *tmp; + + tmp = xt_get_this_cpu_counter(&iter->counters); + ADD_COUNTER(*tmp, paddc[i].bcnt, paddc[i].pcnt); + ++i; + } + xt_write_recseq_end(addend); + unlock_up_free: + local_bh_enable(); + xt_table_unlock(t); + module_put(t->me); + free: + vfree(paddc); + + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_ipt_replace { + char name[XT_TABLE_MAXNAMELEN]; + u32 valid_hooks; + u32 num_entries; + u32 size; + u32 hook_entry[NF_INET_NUMHOOKS]; + u32 underflow[NF_INET_NUMHOOKS]; + u32 num_counters; 
+ compat_uptr_t counters; /* struct xt_counters * */ + struct compat_ipt_entry entries[]; +}; + +static int +compat_copy_entry_to_user(struct ipt_entry *e, void __user **dstptr, + unsigned int *size, struct xt_counters *counters, + unsigned int i) +{ + struct xt_entry_target *t; + struct compat_ipt_entry __user *ce; + u_int16_t target_offset, next_offset; + compat_uint_t origsize; + const struct xt_entry_match *ematch; + int ret = 0; + + origsize = *size; + ce = *dstptr; + if (copy_to_user(ce, e, sizeof(struct ipt_entry)) != 0 || + copy_to_user(&ce->counters, &counters[i], + sizeof(counters[i])) != 0) + return -EFAULT; + + *dstptr += sizeof(struct compat_ipt_entry); + *size -= sizeof(struct ipt_entry) - sizeof(struct compat_ipt_entry); + + xt_ematch_foreach(ematch, e) { + ret = xt_compat_match_to_user(ematch, dstptr, size); + if (ret != 0) + return ret; + } + target_offset = e->target_offset - (origsize - *size); + t = ipt_get_target(e); + ret = xt_compat_target_to_user(t, dstptr, size); + if (ret) + return ret; + next_offset = e->next_offset - (origsize - *size); + if (put_user(target_offset, &ce->target_offset) != 0 || + put_user(next_offset, &ce->next_offset) != 0) + return -EFAULT; + return 0; +} + +static int +compat_find_calc_match(struct xt_entry_match *m, + const struct ipt_ip *ip, + int *size) +{ + struct xt_match *match; + + match = xt_request_find_match(NFPROTO_IPV4, m->u.user.name, + m->u.user.revision); + if (IS_ERR(match)) + return PTR_ERR(match); + + m->u.kernel.match = match; + *size += xt_compat_match_offset(match); + return 0; +} + +static void compat_release_entry(struct compat_ipt_entry *e) +{ + struct xt_entry_target *t; + struct xt_entry_match *ematch; + + /* Cleanup all matches */ + xt_ematch_foreach(ematch, e) + module_put(ematch->u.kernel.match->me); + t = compat_ipt_get_target(e); + module_put(t->u.kernel.target->me); +} + +static int +check_compat_entry_size_and_hooks(struct compat_ipt_entry *e, + struct xt_table_info *newinfo, + unsigned int *size, + const unsigned char *base, + const unsigned char *limit) +{ + struct xt_entry_match *ematch; + struct xt_entry_target *t; + struct xt_target *target; + unsigned int entry_offset; + unsigned int j; + int ret, off; + + if ((unsigned long)e % __alignof__(struct compat_ipt_entry) != 0 || + (unsigned char *)e + sizeof(struct compat_ipt_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset < sizeof(struct compat_ipt_entry) + + sizeof(struct compat_xt_entry_target)) + return -EINVAL; + + if (!ip_checkentry(&e->ip)) + return -EINVAL; + + ret = xt_compat_check_entry_offsets(e, e->elems, + e->target_offset, e->next_offset); + if (ret) + return ret; + + off = sizeof(struct ipt_entry) - sizeof(struct compat_ipt_entry); + entry_offset = (void *)e - (void *)base; + j = 0; + xt_ematch_foreach(ematch, e) { + ret = compat_find_calc_match(ematch, &e->ip, &off); + if (ret != 0) + goto release_matches; + ++j; + } + + t = compat_ipt_get_target(e); + target = xt_request_find_target(NFPROTO_IPV4, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto release_matches; + } + t->u.kernel.target = target; + + off += xt_compat_target_offset(target); + *size += off; + ret = xt_compat_add_offset(AF_INET, entry_offset, off); + if (ret) + goto out; + + return 0; + +out: + module_put(t->u.kernel.target->me); +release_matches: + xt_ematch_foreach(ematch, e) { + if (j-- == 0) + break; + module_put(ematch->u.kernel.match->me); + } + return ret; +} + +static 
void +compat_copy_entry_from_user(struct compat_ipt_entry *e, void **dstptr, + unsigned int *size, + struct xt_table_info *newinfo, unsigned char *base) +{ + struct xt_entry_target *t; + struct ipt_entry *de; + unsigned int origsize; + int h; + struct xt_entry_match *ematch; + + origsize = *size; + de = *dstptr; + memcpy(de, e, sizeof(struct ipt_entry)); + memcpy(&de->counters, &e->counters, sizeof(e->counters)); + + *dstptr += sizeof(struct ipt_entry); + *size += sizeof(struct ipt_entry) - sizeof(struct compat_ipt_entry); + + xt_ematch_foreach(ematch, e) + xt_compat_match_from_user(ematch, dstptr, size); + + de->target_offset = e->target_offset - (origsize - *size); + t = compat_ipt_get_target(e); + xt_compat_target_from_user(t, dstptr, size); + + de->next_offset = e->next_offset - (origsize - *size); + + for (h = 0; h < NF_INET_NUMHOOKS; h++) { + if ((unsigned char *)de - base < newinfo->hook_entry[h]) + newinfo->hook_entry[h] -= origsize - *size; + if ((unsigned char *)de - base < newinfo->underflow[h]) + newinfo->underflow[h] -= origsize - *size; + } +} + +static int +translate_compat_table(struct net *net, + struct xt_table_info **pinfo, + void **pentry0, + const struct compat_ipt_replace *compatr) +{ + unsigned int i, j; + struct xt_table_info *newinfo, *info; + void *pos, *entry0, *entry1; + struct compat_ipt_entry *iter0; + struct ipt_replace repl; + unsigned int size; + int ret; + + info = *pinfo; + entry0 = *pentry0; + size = compatr->size; + info->number = compatr->num_entries; + + j = 0; + xt_compat_lock(AF_INET); + ret = xt_compat_init_offsets(AF_INET, compatr->num_entries); + if (ret) + goto out_unlock; + /* Walk through entries, checking offsets. */ + xt_entry_foreach(iter0, entry0, compatr->size) { + ret = check_compat_entry_size_and_hooks(iter0, info, &size, + entry0, + entry0 + compatr->size); + if (ret != 0) + goto out_unlock; + ++j; + } + + ret = -EINVAL; + if (j != compatr->num_entries) + goto out_unlock; + + ret = -ENOMEM; + newinfo = xt_alloc_table_info(size); + if (!newinfo) + goto out_unlock; + + memset(newinfo->entries, 0, size); + + newinfo->number = compatr->num_entries; + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + newinfo->hook_entry[i] = compatr->hook_entry[i]; + newinfo->underflow[i] = compatr->underflow[i]; + } + entry1 = newinfo->entries; + pos = entry1; + size = compatr->size; + xt_entry_foreach(iter0, entry0, compatr->size) + compat_copy_entry_from_user(iter0, &pos, &size, + newinfo, entry1); + + /* all module references in entry0 are now gone. + * entry1/newinfo contains a 64bit ruleset that looks exactly as + * generated by 64bit userspace. + * + * Call standard translate_table() to validate all hook_entrys, + * underflows, check for loops, etc. 
+ */ + xt_compat_flush_offsets(AF_INET); + xt_compat_unlock(AF_INET); + + memcpy(&repl, compatr, sizeof(*compatr)); + + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + repl.hook_entry[i] = newinfo->hook_entry[i]; + repl.underflow[i] = newinfo->underflow[i]; + } + + repl.num_counters = 0; + repl.counters = NULL; + repl.size = newinfo->size; + ret = translate_table(net, newinfo, entry1, &repl); + if (ret) + goto free_newinfo; + + *pinfo = newinfo; + *pentry0 = entry1; + xt_free_table_info(info); + return 0; + +free_newinfo: + xt_free_table_info(newinfo); + return ret; +out_unlock: + xt_compat_flush_offsets(AF_INET); + xt_compat_unlock(AF_INET); + xt_entry_foreach(iter0, entry0, compatr->size) { + if (j-- == 0) + break; + compat_release_entry(iter0); + } + return ret; +} + +static int +compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct compat_ipt_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct ipt_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_compat_table(net, &newinfo, &loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, compat_ptr(tmp.counters)); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +struct compat_ipt_get_entries { + char name[XT_TABLE_MAXNAMELEN]; + compat_uint_t size; + struct compat_ipt_entry entrytable[]; +}; + +static int +compat_copy_entries_to_user(unsigned int total_size, struct xt_table *table, + void __user *userptr) +{ + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + void __user *pos; + unsigned int size; + int ret = 0; + unsigned int i = 0; + struct ipt_entry *iter; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + pos = userptr; + size = total_size; + xt_entry_foreach(iter, private->entries, total_size) { + ret = compat_copy_entry_to_user(iter, &pos, + &size, counters, i++); + if (ret != 0) + break; + } + + vfree(counters); + return ret; +} + +static int +compat_get_entries(struct net *net, struct compat_ipt_get_entries __user *uptr, + int *len) +{ + int ret; + struct compat_ipt_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + + if (*len != sizeof(struct compat_ipt_get_entries) + get.size) + return -EINVAL; + + get.name[sizeof(get.name) - 1] = '\0'; + + xt_compat_lock(AF_INET); + t = xt_find_table_lock(net, AF_INET, get.name); + if (!IS_ERR(t)) { + const struct xt_table_info *private = t->private; + struct xt_table_info info; + ret = compat_table_info(private, &info); + if (!ret && get.size == info.size) + ret = compat_copy_entries_to_user(private->size, + t, uptr->entrytable); + else if (!ret) + ret = -EAGAIN; + + xt_compat_flush_offsets(AF_INET); + 
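translate_compat_table() above converts a 32-bit ruleset into the native layout: each entry grows by a fixed delta (the native structs carry pointers where the compat ones carry 32-bit handles), so every stored byte offset, including jump targets and the per-hook entry points, has to be shifted by the space added in front of it. A toy calculation showing that adjustment; the 32-byte entry size and 8-byte delta are made-up numbers, not the real struct sizes:

/* Toy illustration of the compat -> native offset fix-up: entry i is
 * preceded by i entries that each grew by "delta" bytes. */
#include <stdio.h>

int main(void)
{
    const unsigned int compat_size = 32, delta = 8, nrules = 4;

    for (unsigned int i = 0; i < nrules; i++) {
        unsigned int compat_off = i * compat_size;
        unsigned int native_off = compat_off + i * delta;

        printf("rule %u: compat offset %3u -> native offset %3u\n",
               i, compat_off, native_off);
    }
    printf("a compat jump target of %u must therefore be rewritten to %u\n",
           2 * compat_size, 2 * compat_size + 2 * delta);
    return 0;
}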
module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + xt_compat_unlock(AF_INET); + return ret; +} +#endif + +static int +do_ipt_set_ctl(struct sock *sk, int cmd, sockptr_t arg, unsigned int len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case IPT_SO_SET_REPLACE: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_do_replace(sock_net(sk), arg, len); + else +#endif + ret = do_replace(sock_net(sk), arg, len); + break; + + case IPT_SO_SET_ADD_COUNTERS: + ret = do_add_counters(sock_net(sk), arg, len); + break; + + default: + ret = -EINVAL; + } + + return ret; +} + +static int +do_ipt_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case IPT_SO_GET_INFO: + ret = get_info(sock_net(sk), user, len); + break; + + case IPT_SO_GET_ENTRIES: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_get_entries(sock_net(sk), user, len); + else +#endif + ret = get_entries(sock_net(sk), user, len); + break; + + case IPT_SO_GET_REVISION_MATCH: + case IPT_SO_GET_REVISION_TARGET: { + struct xt_get_revision rev; + int target; + + if (*len != sizeof(rev)) { + ret = -EINVAL; + break; + } + if (copy_from_user(&rev, user, sizeof(rev)) != 0) { + ret = -EFAULT; + break; + } + rev.name[sizeof(rev.name)-1] = 0; + + if (cmd == IPT_SO_GET_REVISION_TARGET) + target = 1; + else + target = 0; + + try_then_request_module(xt_find_revision(AF_INET, rev.name, + rev.revision, + target, &ret), + "ipt_%s", rev.name); + break; + } + + default: + ret = -EINVAL; + } + + return ret; +} + +static void __ipt_unregister_table(struct net *net, struct xt_table *table) +{ + struct xt_table_info *private; + void *loc_cpu_entry; + struct module *table_owner = table->me; + struct ipt_entry *iter; + + private = xt_unregister_table(table); + + /* Decrease module usage counts and free resources */ + loc_cpu_entry = private->entries; + xt_entry_foreach(iter, loc_cpu_entry, private->size) + cleanup_entry(iter, net); + if (private->number > private->initial_entries) + module_put(table_owner); + xt_free_table_info(private); +} + +int ipt_register_table(struct net *net, const struct xt_table *table, + const struct ipt_replace *repl, + const struct nf_hook_ops *template_ops) +{ + struct nf_hook_ops *ops; + unsigned int num_ops; + int ret, i; + struct xt_table_info *newinfo; + struct xt_table_info bootstrap = {0}; + void *loc_cpu_entry; + struct xt_table *new_table; + + newinfo = xt_alloc_table_info(repl->size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + memcpy(loc_cpu_entry, repl->entries, repl->size); + + ret = translate_table(net, newinfo, loc_cpu_entry, repl); + if (ret != 0) { + xt_free_table_info(newinfo); + return ret; + } + + new_table = xt_register_table(net, table, &bootstrap, newinfo); + if (IS_ERR(new_table)) { + struct ipt_entry *iter; + + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + xt_free_table_info(newinfo); + return PTR_ERR(new_table); + } + + /* No template? No need to do anything. This is used by 'nat' table, it registers + * with the nat core instead of the netfilter core. 
+ */ + if (!template_ops) + return 0; + + num_ops = hweight32(table->valid_hooks); + if (num_ops == 0) { + ret = -EINVAL; + goto out_free; + } + + ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + if (!ops) { + ret = -ENOMEM; + goto out_free; + } + + for (i = 0; i < num_ops; i++) + ops[i].priv = new_table; + + new_table->ops = ops; + + ret = nf_register_net_hooks(net, ops, num_ops); + if (ret != 0) + goto out_free; + + return ret; + +out_free: + __ipt_unregister_table(net, new_table); + return ret; +} + +void ipt_unregister_table_pre_exit(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_IPV4, name); + + if (table) + nf_unregister_net_hooks(net, table->ops, hweight32(table->valid_hooks)); +} + +void ipt_unregister_table_exit(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_IPV4, name); + + if (table) + __ipt_unregister_table(net, table); +} + +/* Returns 1 if the type and code is matched by the range, 0 otherwise */ +static inline bool +icmp_type_code_match(u_int8_t test_type, u_int8_t min_code, u_int8_t max_code, + u_int8_t type, u_int8_t code, + bool invert) +{ + return ((test_type == 0xFF) || + (type == test_type && code >= min_code && code <= max_code)) + ^ invert; +} + +static bool +icmp_match(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct icmphdr *ic; + struct icmphdr _icmph; + const struct ipt_icmp *icmpinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + ic = skb_header_pointer(skb, par->thoff, sizeof(_icmph), &_icmph); + if (ic == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + par->hotdrop = true; + return false; + } + + return icmp_type_code_match(icmpinfo->type, + icmpinfo->code[0], + icmpinfo->code[1], + ic->type, ic->code, + !!(icmpinfo->invflags&IPT_ICMP_INV)); +} + +static int icmp_checkentry(const struct xt_mtchk_param *par) +{ + const struct ipt_icmp *icmpinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + return (icmpinfo->invflags & ~IPT_ICMP_INV) ? 
-EINVAL : 0; +} + +static struct xt_target ipt_builtin_tg[] __read_mostly = { + { + .name = XT_STANDARD_TARGET, + .targetsize = sizeof(int), + .family = NFPROTO_IPV4, +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(compat_int_t), + .compat_from_user = compat_standard_from_user, + .compat_to_user = compat_standard_to_user, +#endif + }, + { + .name = XT_ERROR_TARGET, + .target = ipt_error, + .targetsize = XT_FUNCTION_MAXNAMELEN, + .family = NFPROTO_IPV4, + }, +}; + +static struct nf_sockopt_ops ipt_sockopts = { + .pf = PF_INET, + .set_optmin = IPT_BASE_CTL, + .set_optmax = IPT_SO_SET_MAX+1, + .set = do_ipt_set_ctl, + .get_optmin = IPT_BASE_CTL, + .get_optmax = IPT_SO_GET_MAX+1, + .get = do_ipt_get_ctl, + .owner = THIS_MODULE, +}; + +static struct xt_match ipt_builtin_mt[] __read_mostly = { + { + .name = "icmp", + .match = icmp_match, + .matchsize = sizeof(struct ipt_icmp), + .checkentry = icmp_checkentry, + .proto = IPPROTO_ICMP, + .family = NFPROTO_IPV4, + .me = THIS_MODULE, + }, +}; + +static int __net_init ip_tables_net_init(struct net *net) +{ + return xt_proto_init(net, NFPROTO_IPV4); +} + +static void __net_exit ip_tables_net_exit(struct net *net) +{ + xt_proto_fini(net, NFPROTO_IPV4); +} + +static struct pernet_operations ip_tables_net_ops = { + .init = ip_tables_net_init, + .exit = ip_tables_net_exit, +}; + +static int __init ip_tables_init(void) +{ + int ret; + + ret = register_pernet_subsys(&ip_tables_net_ops); + if (ret < 0) + goto err1; + + /* No one else will be downing sem now, so we won't sleep */ + ret = xt_register_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); + if (ret < 0) + goto err2; + ret = xt_register_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); + if (ret < 0) + goto err4; + + /* Register setsockopt */ + ret = nf_register_sockopt(&ipt_sockopts); + if (ret < 0) + goto err5; + + return 0; + +err5: + xt_unregister_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); +err4: + xt_unregister_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); +err2: + unregister_pernet_subsys(&ip_tables_net_ops); +err1: + return ret; +} + +static void __exit ip_tables_fini(void) +{ + nf_unregister_sockopt(&ipt_sockopts); + + xt_unregister_matches(ipt_builtin_mt, ARRAY_SIZE(ipt_builtin_mt)); + xt_unregister_targets(ipt_builtin_tg, ARRAY_SIZE(ipt_builtin_tg)); + unregister_pernet_subsys(&ip_tables_net_ops); +} + +EXPORT_SYMBOL(ipt_register_table); +EXPORT_SYMBOL(ipt_unregister_table_pre_exit); +EXPORT_SYMBOL(ipt_unregister_table_exit); +EXPORT_SYMBOL(ipt_do_table); +module_init(ip_tables_init); +module_exit(ip_tables_fini); diff --git a/net/ipv4/netfilter/ipt_CLUSTERIP.c b/net/ipv4/netfilter/ipt_CLUSTERIP.c new file mode 100644 index 000000000..b3cc416ed --- /dev/null +++ b/net/ipv4/netfilter/ipt_CLUSTERIP.c @@ -0,0 +1,929 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Cluster IP hashmark target + * (C) 2003-2004 by Harald Welte + * based on ideas of Fabio Olive Leite + * + * Development of this code funded by SuSE Linux AG, https://www.suse.com/ + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CLUSTERIP_VERSION "0.8" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: CLUSTERIP target"); + +struct clusterip_config { + struct list_head list; /* list of all configs */ + refcount_t refcount; /* 
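The built-in "icmp" match registered above reduces to a type/code range test: a configured type of 0xFF acts as a wildcard, otherwise the type must match exactly and the code must fall within [code[0], code[1]], and the result is XORed against the IPT_ICMP_INV flag. A stand-alone copy of that check exercised with a few example packets:

/* Stand-alone version of the ICMP type/code range check used by the
 * built-in "icmp" match above. */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool icmp_type_code_match(uint8_t test_type, uint8_t min_code,
                                 uint8_t max_code, uint8_t type,
                                 uint8_t code, bool invert)
{
    return ((test_type == 0xFF) ||
            (type == test_type && code >= min_code && code <= max_code))
           ^ invert;
}

int main(void)
{
    /* "-p icmp --icmp-type 8" style rule: echo-request, any code */
    printf("echo-request vs rule:      %d\n",
           icmp_type_code_match(8, 0, 0xFF, 8, 0, false));   /* 1 */
    printf("dest-unreachable vs rule:  %d\n",
           icmp_type_code_match(8, 0, 0xFF, 3, 1, false));   /* 0 */
    printf("inverted rule, echo-reply: %d\n",
           icmp_type_code_match(8, 0, 0xFF, 0, 0, true));    /* 1 */
    return 0;
}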
reference count */ + refcount_t entries; /* number of entries/rules + * referencing us */ + + __be32 clusterip; /* the IP address */ + u_int8_t clustermac[ETH_ALEN]; /* the MAC address */ + int ifindex; /* device ifindex */ + u_int16_t num_total_nodes; /* total number of nodes */ + unsigned long local_nodes; /* node number array */ + +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *pde; /* proc dir entry */ +#endif + enum clusterip_hashmode hash_mode; /* which hashing mode */ + u_int32_t hash_initval; /* hash initialization */ + struct rcu_head rcu; /* for call_rcu */ + struct net *net; /* netns for pernet list */ + char ifname[IFNAMSIZ]; /* device ifname */ +}; + +#ifdef CONFIG_PROC_FS +static const struct proc_ops clusterip_proc_ops; +#endif + +struct clusterip_net { + struct list_head configs; + /* lock protects the configs list */ + spinlock_t lock; + + bool clusterip_deprecated_warning; +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *procdir; + /* mutex protects the config->pde*/ + struct mutex mutex; +#endif + unsigned int hook_users; +}; + +static unsigned int clusterip_arp_mangle(void *priv, struct sk_buff *skb, const struct nf_hook_state *state); + +static const struct nf_hook_ops cip_arp_ops = { + .hook = clusterip_arp_mangle, + .pf = NFPROTO_ARP, + .hooknum = NF_ARP_OUT, + .priority = -1 +}; + +static unsigned int clusterip_net_id __read_mostly; +static inline struct clusterip_net *clusterip_pernet(struct net *net) +{ + return net_generic(net, clusterip_net_id); +} + +static inline void +clusterip_config_get(struct clusterip_config *c) +{ + refcount_inc(&c->refcount); +} + +static void clusterip_config_rcu_free(struct rcu_head *head) +{ + struct clusterip_config *config; + struct net_device *dev; + + config = container_of(head, struct clusterip_config, rcu); + dev = dev_get_by_name(config->net, config->ifname); + if (dev) { + dev_mc_del(dev, config->clustermac); + dev_put(dev); + } + kfree(config); +} + +static inline void +clusterip_config_put(struct clusterip_config *c) +{ + if (refcount_dec_and_test(&c->refcount)) + call_rcu(&c->rcu, clusterip_config_rcu_free); +} + +/* decrease the count of entries using/referencing this config. If last + * entry(rule) is removed, remove the config from lists, but don't free it + * yet, since proc-files could still be holding references */ +static inline void +clusterip_config_entry_put(struct clusterip_config *c) +{ + struct clusterip_net *cn = clusterip_pernet(c->net); + + local_bh_disable(); + if (refcount_dec_and_lock(&c->entries, &cn->lock)) { + list_del_rcu(&c->list); + spin_unlock(&cn->lock); + local_bh_enable(); + /* In case anyone still accesses the file, the open/close + * functions are also incrementing the refcount on their own, + * so it's safe to remove the entry even if it's in use. 
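+	 * (clusterip_proc_open()/clusterip_proc_release() below take and drop
+	 * their own reference on the config, so the config cannot go away
+	 * under a reader that still has the file open.)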
*/ +#ifdef CONFIG_PROC_FS + mutex_lock(&cn->mutex); + if (cn->procdir) + proc_remove(c->pde); + mutex_unlock(&cn->mutex); +#endif + return; + } + local_bh_enable(); +} + +static struct clusterip_config * +__clusterip_config_find(struct net *net, __be32 clusterip) +{ + struct clusterip_config *c; + struct clusterip_net *cn = clusterip_pernet(net); + + list_for_each_entry_rcu(c, &cn->configs, list) { + if (c->clusterip == clusterip) + return c; + } + + return NULL; +} + +static inline struct clusterip_config * +clusterip_config_find_get(struct net *net, __be32 clusterip, int entry) +{ + struct clusterip_config *c; + + rcu_read_lock_bh(); + c = __clusterip_config_find(net, clusterip); + if (c) { +#ifdef CONFIG_PROC_FS + if (!c->pde) + c = NULL; + else +#endif + if (unlikely(!refcount_inc_not_zero(&c->refcount))) + c = NULL; + else if (entry) { + if (unlikely(!refcount_inc_not_zero(&c->entries))) { + clusterip_config_put(c); + c = NULL; + } + } + } + rcu_read_unlock_bh(); + + return c; +} + +static void +clusterip_config_init_nodelist(struct clusterip_config *c, + const struct ipt_clusterip_tgt_info *i) +{ + int n; + + for (n = 0; n < i->num_local_nodes; n++) + set_bit(i->local_nodes[n] - 1, &c->local_nodes); +} + +static int +clusterip_netdev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct clusterip_net *cn = clusterip_pernet(net); + struct clusterip_config *c; + + spin_lock_bh(&cn->lock); + list_for_each_entry_rcu(c, &cn->configs, list) { + switch (event) { + case NETDEV_REGISTER: + if (!strcmp(dev->name, c->ifname)) { + c->ifindex = dev->ifindex; + dev_mc_add(dev, c->clustermac); + } + break; + case NETDEV_UNREGISTER: + if (dev->ifindex == c->ifindex) { + dev_mc_del(dev, c->clustermac); + c->ifindex = -1; + } + break; + case NETDEV_CHANGENAME: + if (!strcmp(dev->name, c->ifname)) { + c->ifindex = dev->ifindex; + dev_mc_add(dev, c->clustermac); + } else if (dev->ifindex == c->ifindex) { + dev_mc_del(dev, c->clustermac); + c->ifindex = -1; + } + break; + } + } + spin_unlock_bh(&cn->lock); + + return NOTIFY_DONE; +} + +static struct clusterip_config * +clusterip_config_init(struct net *net, const struct ipt_clusterip_tgt_info *i, + __be32 ip, const char *iniface) +{ + struct clusterip_net *cn = clusterip_pernet(net); + struct clusterip_config *c; + struct net_device *dev; + int err; + + if (iniface[0] == '\0') { + pr_info("Please specify an interface name\n"); + return ERR_PTR(-EINVAL); + } + + c = kzalloc(sizeof(*c), GFP_ATOMIC); + if (!c) + return ERR_PTR(-ENOMEM); + + dev = dev_get_by_name(net, iniface); + if (!dev) { + pr_info("no such interface %s\n", iniface); + kfree(c); + return ERR_PTR(-ENOENT); + } + c->ifindex = dev->ifindex; + strcpy(c->ifname, dev->name); + memcpy(&c->clustermac, &i->clustermac, ETH_ALEN); + dev_mc_add(dev, c->clustermac); + dev_put(dev); + + c->clusterip = ip; + c->num_total_nodes = i->num_total_nodes; + clusterip_config_init_nodelist(c, i); + c->hash_mode = i->hash_mode; + c->hash_initval = i->hash_initval; + c->net = net; + refcount_set(&c->refcount, 1); + + spin_lock_bh(&cn->lock); + if (__clusterip_config_find(net, ip)) { + err = -EBUSY; + goto out_config_put; + } + + list_add_rcu(&c->list, &cn->configs); + spin_unlock_bh(&cn->lock); + +#ifdef CONFIG_PROC_FS + { + char buffer[16]; + + /* create proc dir entry */ + sprintf(buffer, "%pI4", &ip); + mutex_lock(&cn->mutex); + c->pde = proc_create_data(buffer, 0600, + cn->procdir, + 
&clusterip_proc_ops, c); + mutex_unlock(&cn->mutex); + if (!c->pde) { + err = -ENOMEM; + goto err; + } + } +#endif + + refcount_set(&c->entries, 1); + return c; + +#ifdef CONFIG_PROC_FS +err: +#endif + spin_lock_bh(&cn->lock); + list_del_rcu(&c->list); +out_config_put: + spin_unlock_bh(&cn->lock); + clusterip_config_put(c); + return ERR_PTR(err); +} + +#ifdef CONFIG_PROC_FS +static int +clusterip_add_node(struct clusterip_config *c, u_int16_t nodenum) +{ + + if (nodenum == 0 || + nodenum > c->num_total_nodes) + return 1; + + /* check if we already have this number in our bitfield */ + if (test_and_set_bit(nodenum - 1, &c->local_nodes)) + return 1; + + return 0; +} + +static bool +clusterip_del_node(struct clusterip_config *c, u_int16_t nodenum) +{ + if (nodenum == 0 || + nodenum > c->num_total_nodes) + return true; + + if (test_and_clear_bit(nodenum - 1, &c->local_nodes)) + return false; + + return true; +} +#endif + +static inline u_int32_t +clusterip_hashfn(const struct sk_buff *skb, + const struct clusterip_config *config) +{ + const struct iphdr *iph = ip_hdr(skb); + unsigned long hashval; + u_int16_t sport = 0, dport = 0; + int poff; + + poff = proto_ports_offset(iph->protocol); + if (poff >= 0) { + const u_int16_t *ports; + u16 _ports[2]; + + ports = skb_header_pointer(skb, iph->ihl * 4 + poff, 4, _ports); + if (ports) { + sport = ports[0]; + dport = ports[1]; + } + } else { + net_info_ratelimited("unknown protocol %u\n", iph->protocol); + } + + switch (config->hash_mode) { + case CLUSTERIP_HASHMODE_SIP: + hashval = jhash_1word(ntohl(iph->saddr), + config->hash_initval); + break; + case CLUSTERIP_HASHMODE_SIP_SPT: + hashval = jhash_2words(ntohl(iph->saddr), sport, + config->hash_initval); + break; + case CLUSTERIP_HASHMODE_SIP_SPT_DPT: + hashval = jhash_3words(ntohl(iph->saddr), sport, dport, + config->hash_initval); + break; + default: + /* to make gcc happy */ + hashval = 0; + /* This cannot happen, unless the check function wasn't called + * at rule load time */ + pr_info("unknown mode %u\n", config->hash_mode); + BUG(); + break; + } + + /* node numbers are 1..n, not 0..n */ + return reciprocal_scale(hashval, config->num_total_nodes) + 1; +} + +static inline int +clusterip_responsible(const struct clusterip_config *config, u_int32_t hash) +{ + return test_bit(hash - 1, &config->local_nodes); +} + +/*********************************************************************** + * IPTABLES TARGET + ***********************************************************************/ + +static unsigned int +clusterip_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + u_int32_t hash; + + /* don't need to clusterip_config_get() here, since refcount + * is only decremented by destroy() - and ip_tables guarantees + * that the ->target() function isn't called after ->destroy() */ + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return NF_DROP; + + /* special case: ICMP error handling. conntrack distinguishes between + * error messages (RELATED) and information requests (see below) */ + if (ip_hdr(skb)->protocol == IPPROTO_ICMP && + (ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY)) + return XT_CONTINUE; + + /* nf_conntrack_proto_icmp guarantees us that we only have ICMP_ECHO, + * TIMESTAMP, INFO_REQUEST or ICMP_ADDRESS type icmp packets from here + * on, which all have an ID field [relevant for hashing]. 
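+	 * Error messages were already passed through above as RELATED /
+	 * RELATED_REPLY, so only request/reply ICMP ever reaches the hash below.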
*/ + + hash = clusterip_hashfn(skb, cipinfo->config); + + switch (ctinfo) { + case IP_CT_NEW: + WRITE_ONCE(ct->mark, hash); + break; + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + /* FIXME: we don't handle expectations at the moment. + * They can arrive on a different node than + * the master connection (e.g. FTP passive mode) */ + case IP_CT_ESTABLISHED: + case IP_CT_ESTABLISHED_REPLY: + break; + default: /* Prevent gcc warnings */ + break; + } + +#ifdef DEBUG + nf_ct_dump_tuple_ip(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); +#endif + pr_debug("hash=%u ct_hash=%u ", hash, READ_ONCE(ct->mark)); + if (!clusterip_responsible(cipinfo->config, hash)) { + pr_debug("not responsible\n"); + return NF_DROP; + } + pr_debug("responsible\n"); + + /* despite being received via linklayer multicast, this is + * actually a unicast IP packet. TCP doesn't like PACKET_MULTICAST */ + skb->pkt_type = PACKET_HOST; + + return XT_CONTINUE; +} + +static int clusterip_tg_check(const struct xt_tgchk_param *par) +{ + struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; + struct clusterip_net *cn = clusterip_pernet(par->net); + const struct ipt_entry *e = par->entryinfo; + struct clusterip_config *config; + int ret, i; + + if (par->nft_compat) { + pr_err("cannot use CLUSTERIP target from nftables compat\n"); + return -EOPNOTSUPP; + } + + if (cn->hook_users == UINT_MAX) + return -EOVERFLOW; + + if (cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP && + cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP_SPT && + cipinfo->hash_mode != CLUSTERIP_HASHMODE_SIP_SPT_DPT) { + pr_info("unknown mode %u\n", cipinfo->hash_mode); + return -EINVAL; + + } + if (e->ip.dmsk.s_addr != htonl(0xffffffff) || + e->ip.dst.s_addr == 0) { + pr_info("Please specify destination IP\n"); + return -EINVAL; + } + if (cipinfo->num_local_nodes > ARRAY_SIZE(cipinfo->local_nodes)) { + pr_info("bad num_local_nodes %u\n", cipinfo->num_local_nodes); + return -EINVAL; + } + for (i = 0; i < cipinfo->num_local_nodes; i++) { + if (cipinfo->local_nodes[i] - 1 >= + sizeof(config->local_nodes) * 8) { + pr_info("bad local_nodes[%d] %u\n", + i, cipinfo->local_nodes[i]); + return -EINVAL; + } + } + + config = clusterip_config_find_get(par->net, e->ip.dst.s_addr, 1); + if (!config) { + if (!(cipinfo->flags & CLUSTERIP_FLAG_NEW)) { + pr_info("no config found for %pI4, need 'new'\n", + &e->ip.dst.s_addr); + return -EINVAL; + } else { + config = clusterip_config_init(par->net, cipinfo, + e->ip.dst.s_addr, + e->ip.iniface); + if (IS_ERR(config)) + return PTR_ERR(config); + } + } else if (memcmp(&config->clustermac, &cipinfo->clustermac, ETH_ALEN)) { + clusterip_config_entry_put(config); + clusterip_config_put(config); + return -EINVAL; + } + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) { + pr_info("cannot load conntrack support for proto=%u\n", + par->family); + clusterip_config_entry_put(config); + clusterip_config_put(config); + return ret; + } + + if (cn->hook_users == 0) { + ret = nf_register_net_hook(par->net, &cip_arp_ops); + + if (ret < 0) { + clusterip_config_entry_put(config); + clusterip_config_put(config); + nf_ct_netns_put(par->net, par->family); + return ret; + } + } + + cn->hook_users++; + + if (!cn->clusterip_deprecated_warning) { + pr_info("ipt_CLUSTERIP is deprecated and it will removed soon, " + "use xt_cluster instead\n"); + cn->clusterip_deprecated_warning = true; + } + + cipinfo->config = config; + return ret; +} + +/* drop reference count of cluster config when rule is deleted */ +static void clusterip_tg_destroy(const struct 
xt_tgdtor_param *par) +{ + const struct ipt_clusterip_tgt_info *cipinfo = par->targinfo; + struct clusterip_net *cn = clusterip_pernet(par->net); + + /* if no more entries are referencing the config, remove it + * from the list and destroy the proc entry */ + clusterip_config_entry_put(cipinfo->config); + + clusterip_config_put(cipinfo->config); + + nf_ct_netns_put(par->net, par->family); + cn->hook_users--; + + if (cn->hook_users == 0) + nf_unregister_net_hook(par->net, &cip_arp_ops); +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_ipt_clusterip_tgt_info +{ + u_int32_t flags; + u_int8_t clustermac[6]; + u_int16_t num_total_nodes; + u_int16_t num_local_nodes; + u_int16_t local_nodes[CLUSTERIP_MAX_NODES]; + u_int32_t hash_mode; + u_int32_t hash_initval; + compat_uptr_t config; +}; +#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ + +static struct xt_target clusterip_tg_reg __read_mostly = { + .name = "CLUSTERIP", + .family = NFPROTO_IPV4, + .target = clusterip_tg, + .checkentry = clusterip_tg_check, + .destroy = clusterip_tg_destroy, + .targetsize = sizeof(struct ipt_clusterip_tgt_info), + .usersize = offsetof(struct ipt_clusterip_tgt_info, config), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(struct compat_ipt_clusterip_tgt_info), +#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ + .me = THIS_MODULE +}; + + +/*********************************************************************** + * ARP MANGLING CODE + ***********************************************************************/ + +/* hardcoded for 48bit ethernet and 32bit ipv4 addresses */ +struct arp_payload { + u_int8_t src_hw[ETH_ALEN]; + __be32 src_ip; + u_int8_t dst_hw[ETH_ALEN]; + __be32 dst_ip; +} __packed; + +#ifdef DEBUG +static void arp_print(struct arp_payload *payload) +{ +#define HBUFFERLEN 30 + char hbuffer[HBUFFERLEN]; + int j, k; + + for (k = 0, j = 0; k < HBUFFERLEN - 3 && j < ETH_ALEN; j++) { + hbuffer[k++] = hex_asc_hi(payload->src_hw[j]); + hbuffer[k++] = hex_asc_lo(payload->src_hw[j]); + hbuffer[k++] = ':'; + } + hbuffer[--k] = '\0'; + + pr_debug("src %pI4@%s, dst %pI4\n", + &payload->src_ip, hbuffer, &payload->dst_ip); +} +#endif + +static unsigned int +clusterip_arp_mangle(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct arphdr *arp = arp_hdr(skb); + struct arp_payload *payload; + struct clusterip_config *c; + struct net *net = state->net; + + /* we don't care about non-ethernet and non-ipv4 ARP */ + if (arp->ar_hrd != htons(ARPHRD_ETHER) || + arp->ar_pro != htons(ETH_P_IP) || + arp->ar_pln != 4 || arp->ar_hln != ETH_ALEN) + return NF_ACCEPT; + + /* we only want to mangle arp requests and replies */ + if (arp->ar_op != htons(ARPOP_REPLY) && + arp->ar_op != htons(ARPOP_REQUEST)) + return NF_ACCEPT; + + payload = (void *)(arp+1); + + /* if there is no clusterip configuration for the arp reply's + * source ip, we don't want to mangle it */ + c = clusterip_config_find_get(net, payload->src_ip, 0); + if (!c) + return NF_ACCEPT; + + /* normally the linux kernel always replies to arp queries of + * addresses on different interfacs. 
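+	 * (a query for any local address is answered, no matter which
+	 * interface it arrives on).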
However, in the CLUSTERIP case + * this wouldn't work, since we didn't subscribe the mcast group on + * other interfaces */ + if (c->ifindex != state->out->ifindex) { + pr_debug("not mangling arp reply on different interface: cip'%d'-skb'%d'\n", + c->ifindex, state->out->ifindex); + clusterip_config_put(c); + return NF_ACCEPT; + } + + /* mangle reply hardware address */ + memcpy(payload->src_hw, c->clustermac, arp->ar_hln); + +#ifdef DEBUG + pr_debug("mangled arp reply: "); + arp_print(payload); +#endif + + clusterip_config_put(c); + + return NF_ACCEPT; +} + +/*********************************************************************** + * PROC DIR HANDLING + ***********************************************************************/ + +#ifdef CONFIG_PROC_FS + +struct clusterip_seq_position { + unsigned int pos; /* position */ + unsigned int weight; /* number of bits set == size */ + unsigned int bit; /* current bit */ + unsigned long val; /* current value */ +}; + +static void *clusterip_seq_start(struct seq_file *s, loff_t *pos) +{ + struct clusterip_config *c = s->private; + unsigned int weight; + u_int32_t local_nodes; + struct clusterip_seq_position *idx; + + /* FIXME: possible race */ + local_nodes = c->local_nodes; + weight = hweight32(local_nodes); + if (*pos >= weight) + return NULL; + + idx = kmalloc(sizeof(struct clusterip_seq_position), GFP_KERNEL); + if (!idx) + return ERR_PTR(-ENOMEM); + + idx->pos = *pos; + idx->weight = weight; + idx->bit = ffs(local_nodes); + idx->val = local_nodes; + clear_bit(idx->bit - 1, &idx->val); + + return idx; +} + +static void *clusterip_seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + struct clusterip_seq_position *idx = v; + + *pos = ++idx->pos; + if (*pos >= idx->weight) { + kfree(v); + return NULL; + } + idx->bit = ffs(idx->val); + clear_bit(idx->bit - 1, &idx->val); + return idx; +} + +static void clusterip_seq_stop(struct seq_file *s, void *v) +{ + if (!IS_ERR(v)) + kfree(v); +} + +static int clusterip_seq_show(struct seq_file *s, void *v) +{ + struct clusterip_seq_position *idx = v; + + if (idx->pos != 0) + seq_putc(s, ','); + + seq_printf(s, "%u", idx->bit); + + if (idx->pos == idx->weight - 1) + seq_putc(s, '\n'); + + return 0; +} + +static const struct seq_operations clusterip_seq_ops = { + .start = clusterip_seq_start, + .next = clusterip_seq_next, + .stop = clusterip_seq_stop, + .show = clusterip_seq_show, +}; + +static int clusterip_proc_open(struct inode *inode, struct file *file) +{ + int ret = seq_open(file, &clusterip_seq_ops); + + if (!ret) { + struct seq_file *sf = file->private_data; + struct clusterip_config *c = pde_data(inode); + + sf->private = c; + + clusterip_config_get(c); + } + + return ret; +} + +static int clusterip_proc_release(struct inode *inode, struct file *file) +{ + struct clusterip_config *c = pde_data(inode); + int ret; + + ret = seq_release(inode, file); + + if (!ret) + clusterip_config_put(c); + + return ret; +} + +static ssize_t clusterip_proc_write(struct file *file, const char __user *input, + size_t size, loff_t *ofs) +{ + struct clusterip_config *c = pde_data(file_inode(file)); +#define PROC_WRITELEN 10 + char buffer[PROC_WRITELEN+1]; + unsigned long nodenum; + int rc; + + if (size > PROC_WRITELEN) + return -EIO; + if (copy_from_user(buffer, input, size)) + return -EFAULT; + buffer[size] = 0; + + if (*buffer == '+') { + rc = kstrtoul(buffer+1, 10, &nodenum); + if (rc) + return rc; + if (clusterip_add_node(c, nodenum)) + return -ENOMEM; + } else if (*buffer == '-') { + rc = kstrtoul(buffer+1, 10, 
&nodenum); + if (rc) + return rc; + if (clusterip_del_node(c, nodenum)) + return -ENOENT; + } else + return -EIO; + + return size; +} + +static const struct proc_ops clusterip_proc_ops = { + .proc_open = clusterip_proc_open, + .proc_read = seq_read, + .proc_write = clusterip_proc_write, + .proc_lseek = seq_lseek, + .proc_release = clusterip_proc_release, +}; + +#endif /* CONFIG_PROC_FS */ + +static int clusterip_net_init(struct net *net) +{ + struct clusterip_net *cn = clusterip_pernet(net); + + INIT_LIST_HEAD(&cn->configs); + + spin_lock_init(&cn->lock); + +#ifdef CONFIG_PROC_FS + cn->procdir = proc_mkdir("ipt_CLUSTERIP", net->proc_net); + if (!cn->procdir) { + pr_err("Unable to proc dir entry\n"); + return -ENOMEM; + } + mutex_init(&cn->mutex); +#endif /* CONFIG_PROC_FS */ + + return 0; +} + +static void clusterip_net_exit(struct net *net) +{ +#ifdef CONFIG_PROC_FS + struct clusterip_net *cn = clusterip_pernet(net); + + mutex_lock(&cn->mutex); + proc_remove(cn->procdir); + cn->procdir = NULL; + mutex_unlock(&cn->mutex); +#endif +} + +static struct pernet_operations clusterip_net_ops = { + .init = clusterip_net_init, + .exit = clusterip_net_exit, + .id = &clusterip_net_id, + .size = sizeof(struct clusterip_net), +}; + +static struct notifier_block cip_netdev_notifier = { + .notifier_call = clusterip_netdev_event +}; + +static int __init clusterip_tg_init(void) +{ + int ret; + + ret = register_pernet_subsys(&clusterip_net_ops); + if (ret < 0) + return ret; + + ret = xt_register_target(&clusterip_tg_reg); + if (ret < 0) + goto cleanup_subsys; + + ret = register_netdevice_notifier(&cip_netdev_notifier); + if (ret < 0) + goto unregister_target; + + pr_info("ClusterIP Version %s loaded successfully\n", + CLUSTERIP_VERSION); + + return 0; + +unregister_target: + xt_unregister_target(&clusterip_tg_reg); +cleanup_subsys: + unregister_pernet_subsys(&clusterip_net_ops); + return ret; +} + +static void __exit clusterip_tg_exit(void) +{ + pr_info("ClusterIP Version %s unloading\n", CLUSTERIP_VERSION); + + unregister_netdevice_notifier(&cip_netdev_notifier); + xt_unregister_target(&clusterip_tg_reg); + unregister_pernet_subsys(&clusterip_net_ops); + + /* Wait for completion of call_rcu()'s (clusterip_config_rcu_free) */ + rcu_barrier(); +} + +module_init(clusterip_tg_init); +module_exit(clusterip_tg_exit); diff --git a/net/ipv4/netfilter/ipt_ECN.c b/net/ipv4/netfilter/ipt_ECN.c new file mode 100644 index 000000000..5930d3b02 --- /dev/null +++ b/net/ipv4/netfilter/ipt_ECN.c @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* iptables module for the IPv4 and TCP ECN bits, Version 1.5 + * + * (C) 2002 by Harald Welte +*/ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: Explicit Congestion Notification (ECN) flag modification"); + +/* set ECT codepoint from IP header. + * return false if there was an error. 
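+ * (the only failure mode is skb_ensure_writable() failing to make the
+ * IP header writable).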
*/ +static inline bool +set_ect_ip(struct sk_buff *skb, const struct ipt_ECN_info *einfo) +{ + struct iphdr *iph = ip_hdr(skb); + + if ((iph->tos & IPT_ECN_IP_MASK) != (einfo->ip_ect & IPT_ECN_IP_MASK)) { + __u8 oldtos; + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + return false; + iph = ip_hdr(skb); + oldtos = iph->tos; + iph->tos &= ~IPT_ECN_IP_MASK; + iph->tos |= (einfo->ip_ect & IPT_ECN_IP_MASK); + csum_replace2(&iph->check, htons(oldtos), htons(iph->tos)); + } + return true; +} + +/* Return false if there was an error. */ +static inline bool +set_ect_tcp(struct sk_buff *skb, const struct ipt_ECN_info *einfo) +{ + struct tcphdr _tcph, *tcph; + __be16 oldval; + + /* Not enough header? */ + tcph = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(_tcph), &_tcph); + if (!tcph) + return false; + + if ((!(einfo->operation & IPT_ECN_OP_SET_ECE) || + tcph->ece == einfo->proto.tcp.ece) && + (!(einfo->operation & IPT_ECN_OP_SET_CWR) || + tcph->cwr == einfo->proto.tcp.cwr)) + return true; + + if (skb_ensure_writable(skb, ip_hdrlen(skb) + sizeof(*tcph))) + return false; + tcph = (void *)ip_hdr(skb) + ip_hdrlen(skb); + + oldval = ((__be16 *)tcph)[6]; + if (einfo->operation & IPT_ECN_OP_SET_ECE) + tcph->ece = einfo->proto.tcp.ece; + if (einfo->operation & IPT_ECN_OP_SET_CWR) + tcph->cwr = einfo->proto.tcp.cwr; + + inet_proto_csum_replace2(&tcph->check, skb, + oldval, ((__be16 *)tcph)[6], false); + return true; +} + +static unsigned int +ecn_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ipt_ECN_info *einfo = par->targinfo; + + if (einfo->operation & IPT_ECN_OP_SET_IP) + if (!set_ect_ip(skb, einfo)) + return NF_DROP; + + if (einfo->operation & (IPT_ECN_OP_SET_ECE | IPT_ECN_OP_SET_CWR) && + ip_hdr(skb)->protocol == IPPROTO_TCP) + if (!set_ect_tcp(skb, einfo)) + return NF_DROP; + + return XT_CONTINUE; +} + +static int ecn_tg_check(const struct xt_tgchk_param *par) +{ + const struct ipt_ECN_info *einfo = par->targinfo; + const struct ipt_entry *e = par->entryinfo; + + if (einfo->operation & IPT_ECN_OP_MASK) + return -EINVAL; + + if (einfo->ip_ect & ~IPT_ECN_IP_MASK) + return -EINVAL; + + if ((einfo->operation & (IPT_ECN_OP_SET_ECE|IPT_ECN_OP_SET_CWR)) && + (e->ip.proto != IPPROTO_TCP || (e->ip.invflags & XT_INV_PROTO))) { + pr_info_ratelimited("cannot use operation on non-tcp rule\n"); + return -EINVAL; + } + return 0; +} + +static struct xt_target ecn_tg_reg __read_mostly = { + .name = "ECN", + .family = NFPROTO_IPV4, + .target = ecn_tg, + .targetsize = sizeof(struct ipt_ECN_info), + .table = "mangle", + .checkentry = ecn_tg_check, + .me = THIS_MODULE, +}; + +static int __init ecn_tg_init(void) +{ + return xt_register_target(&ecn_tg_reg); +} + +static void __exit ecn_tg_exit(void) +{ + xt_unregister_target(&ecn_tg_reg); +} + +module_init(ecn_tg_init); +module_exit(ecn_tg_exit); diff --git a/net/ipv4/netfilter/ipt_REJECT.c b/net/ipv4/netfilter/ipt_REJECT.c new file mode 100644 index 000000000..4b8840734 --- /dev/null +++ b/net/ipv4/netfilter/ipt_REJECT.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for rejecting packets. 
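+ * Instead of silently dropping, it answers with an ICMP error or a TCP
+ * reset, as selected by the rule's reject type, and then drops the
+ * original packet.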
+ */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +#include +#endif + +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("Xtables: packet \"rejection\" target for IPv4"); + +static unsigned int +reject_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ipt_reject_info *reject = par->targinfo; + int hook = xt_hooknum(par); + + switch (reject->with) { + case IPT_ICMP_NET_UNREACHABLE: + nf_send_unreach(skb, ICMP_NET_UNREACH, hook); + break; + case IPT_ICMP_HOST_UNREACHABLE: + nf_send_unreach(skb, ICMP_HOST_UNREACH, hook); + break; + case IPT_ICMP_PROT_UNREACHABLE: + nf_send_unreach(skb, ICMP_PROT_UNREACH, hook); + break; + case IPT_ICMP_PORT_UNREACHABLE: + nf_send_unreach(skb, ICMP_PORT_UNREACH, hook); + break; + case IPT_ICMP_NET_PROHIBITED: + nf_send_unreach(skb, ICMP_NET_ANO, hook); + break; + case IPT_ICMP_HOST_PROHIBITED: + nf_send_unreach(skb, ICMP_HOST_ANO, hook); + break; + case IPT_ICMP_ADMIN_PROHIBITED: + nf_send_unreach(skb, ICMP_PKT_FILTERED, hook); + break; + case IPT_TCP_RESET: + nf_send_reset(xt_net(par), par->state->sk, skb, hook); + break; + case IPT_ICMP_ECHOREPLY: + /* Doesn't happen. */ + break; + } + + return NF_DROP; +} + +static int reject_tg_check(const struct xt_tgchk_param *par) +{ + const struct ipt_reject_info *rejinfo = par->targinfo; + const struct ipt_entry *e = par->entryinfo; + + if (rejinfo->with == IPT_ICMP_ECHOREPLY) { + pr_info_ratelimited("ECHOREPLY no longer supported.\n"); + return -EINVAL; + } else if (rejinfo->with == IPT_TCP_RESET) { + /* Must specify that it's a TCP packet */ + if (e->ip.proto != IPPROTO_TCP || + (e->ip.invflags & XT_INV_PROTO)) { + pr_info_ratelimited("TCP_RESET invalid for non-tcp\n"); + return -EINVAL; + } + } + return 0; +} + +static struct xt_target reject_tg_reg __read_mostly = { + .name = "REJECT", + .family = NFPROTO_IPV4, + .target = reject_tg, + .targetsize = sizeof(struct ipt_reject_info), + .table = "filter", + .hooks = (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT), + .checkentry = reject_tg_check, + .me = THIS_MODULE, +}; + +static int __init reject_tg_init(void) +{ + return xt_register_target(&reject_tg_reg); +} + +static void __exit reject_tg_exit(void) +{ + xt_unregister_target(&reject_tg_reg); +} + +module_init(reject_tg_init); +module_exit(reject_tg_exit); diff --git a/net/ipv4/netfilter/ipt_SYNPROXY.c b/net/ipv4/netfilter/ipt_SYNPROXY.c new file mode 100644 index 000000000..f2984c7ee --- /dev/null +++ b/net/ipv4/netfilter/ipt_SYNPROXY.c @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2013 Patrick McHardy + */ + +#include +#include +#include + +#include + +static unsigned int +synproxy_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_synproxy_info *info = par->targinfo; + struct net *net = xt_net(par); + struct synproxy_net *snet = synproxy_pernet(net); + struct synproxy_options opts = {}; + struct tcphdr *th, _th; + + if (nf_ip_checksum(skb, xt_hooknum(par), par->thoff, IPPROTO_TCP)) + return NF_DROP; + + th = skb_header_pointer(skb, par->thoff, sizeof(_th), &_th); + if (th == NULL) + return NF_DROP; + + if (!synproxy_parse_options(skb, par->thoff, th, &opts)) + return NF_DROP; + + if (th->syn && !(th->ack || th->fin || 
th->rst)) { + /* Initial SYN from client */ + this_cpu_inc(snet->stats->syn_received); + + if (th->ece && th->cwr) + opts.options |= XT_SYNPROXY_OPT_ECN; + + opts.options &= info->options; + opts.mss_encode = opts.mss_option; + opts.mss_option = info->mss; + if (opts.options & XT_SYNPROXY_OPT_TIMESTAMP) + synproxy_init_timestamp_cookie(info, &opts); + else + opts.options &= ~(XT_SYNPROXY_OPT_WSCALE | + XT_SYNPROXY_OPT_SACK_PERM | + XT_SYNPROXY_OPT_ECN); + + synproxy_send_client_synack(net, skb, th, &opts); + consume_skb(skb); + return NF_STOLEN; + } else if (th->ack && !(th->fin || th->rst || th->syn)) { + /* ACK from client */ + if (synproxy_recv_client_ack(net, skb, th, &opts, ntohl(th->seq))) { + consume_skb(skb); + return NF_STOLEN; + } else { + return NF_DROP; + } + } + + return XT_CONTINUE; +} + +static int synproxy_tg4_check(const struct xt_tgchk_param *par) +{ + struct synproxy_net *snet = synproxy_pernet(par->net); + const struct ipt_entry *e = par->entryinfo; + int err; + + if (e->ip.proto != IPPROTO_TCP || + e->ip.invflags & XT_INV_PROTO) + return -EINVAL; + + err = nf_ct_netns_get(par->net, par->family); + if (err) + return err; + + err = nf_synproxy_ipv4_init(snet, par->net); + if (err) { + nf_ct_netns_put(par->net, par->family); + return err; + } + + return err; +} + +static void synproxy_tg4_destroy(const struct xt_tgdtor_param *par) +{ + struct synproxy_net *snet = synproxy_pernet(par->net); + + nf_synproxy_ipv4_fini(snet, par->net); + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_target synproxy_tg4_reg __read_mostly = { + .name = "SYNPROXY", + .family = NFPROTO_IPV4, + .hooks = (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD), + .target = synproxy_tg4, + .targetsize = sizeof(struct xt_synproxy_info), + .checkentry = synproxy_tg4_check, + .destroy = synproxy_tg4_destroy, + .me = THIS_MODULE, +}; + +static int __init synproxy_tg4_init(void) +{ + return xt_register_target(&synproxy_tg4_reg); +} + +static void __exit synproxy_tg4_exit(void) +{ + xt_unregister_target(&synproxy_tg4_reg); +} + +module_init(synproxy_tg4_init); +module_exit(synproxy_tg4_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Intercept TCP connections and establish them using syncookies"); diff --git a/net/ipv4/netfilter/ipt_ah.c b/net/ipv4/netfilter/ipt_ah.c new file mode 100644 index 000000000..161ba412c --- /dev/null +++ b/net/ipv4/netfilter/ipt_ah.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match AH parameters. */ +/* (C) 1999-2000 Yon Uriarte + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Yon Uriarte "); +MODULE_DESCRIPTION("Xtables: IPv4 IPsec-AH SPI match"); + +/* Returns 1 if the spi is matched by the range, 0 otherwise */ +static inline bool +spi_match(u_int32_t min, u_int32_t max, u_int32_t spi, bool invert) +{ + bool r; + pr_debug("spi_match:%c 0x%x <= 0x%x <= 0x%x\n", + invert ? '!' : ' ', min, spi, max); + r = (spi >= min && spi <= max) ^ invert; + pr_debug(" result %s\n", r ? "PASS" : "FAILED"); + return r; +} + +static bool ah_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ip_auth_hdr _ahdr; + const struct ip_auth_hdr *ah; + const struct ipt_ah *ahinfo = par->matchinfo; + + /* Must not be a fragment. 
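+	 * A non-first fragment carries no AH header, so there is no SPI to
+	 * match against.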
*/ + if (par->fragoff != 0) + return false; + + ah = skb_header_pointer(skb, par->thoff, sizeof(_ahdr), &_ahdr); + if (ah == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + pr_debug("Dropping evil AH tinygram.\n"); + par->hotdrop = true; + return false; + } + + return spi_match(ahinfo->spis[0], ahinfo->spis[1], + ntohl(ah->spi), + !!(ahinfo->invflags & IPT_AH_INV_SPI)); +} + +static int ah_mt_check(const struct xt_mtchk_param *par) +{ + const struct ipt_ah *ahinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + if (ahinfo->invflags & ~IPT_AH_INV_MASK) { + pr_debug("unknown flags %X\n", ahinfo->invflags); + return -EINVAL; + } + return 0; +} + +static struct xt_match ah_mt_reg __read_mostly = { + .name = "ah", + .family = NFPROTO_IPV4, + .match = ah_mt, + .matchsize = sizeof(struct ipt_ah), + .proto = IPPROTO_AH, + .checkentry = ah_mt_check, + .me = THIS_MODULE, +}; + +static int __init ah_mt_init(void) +{ + return xt_register_match(&ah_mt_reg); +} + +static void __exit ah_mt_exit(void) +{ + xt_unregister_match(&ah_mt_reg); +} + +module_init(ah_mt_init); +module_exit(ah_mt_exit); diff --git a/net/ipv4/netfilter/ipt_rpfilter.c b/net/ipv4/netfilter/ipt_rpfilter.c new file mode 100644 index 000000000..ded5bef02 --- /dev/null +++ b/net/ipv4/netfilter/ipt_rpfilter.c @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2011 Florian Westphal + * + * based on fib_frontend.c; Author: Alexey Kuznetsov, + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("iptables: ipv4 reverse path filter match"); + +/* don't try to find route from mcast/bcast/zeronet */ +static __be32 rpfilter_get_saddr(__be32 addr) +{ + if (ipv4_is_multicast(addr) || ipv4_is_lbcast(addr) || + ipv4_is_zeronet(addr)) + return 0; + return addr; +} + +static bool rpfilter_lookup_reverse(struct net *net, struct flowi4 *fl4, + const struct net_device *dev, u8 flags) +{ + struct fib_result res; + + if (fib_lookup(net, fl4, &res, FIB_LOOKUP_IGNORE_LINKSTATE)) + return false; + + if (res.type != RTN_UNICAST) { + if (res.type != RTN_LOCAL || !(flags & XT_RPFILTER_ACCEPT_LOCAL)) + return false; + } + return fib_info_nh_uses_dev(res.fi, dev) || flags & XT_RPFILTER_LOOSE; +} + +static bool +rpfilter_is_loopback(const struct sk_buff *skb, const struct net_device *in) +{ + return skb->pkt_type == PACKET_LOOPBACK || in->flags & IFF_LOOPBACK; +} + +static bool rpfilter_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_rpfilter_info *info; + const struct iphdr *iph; + struct flowi4 flow; + bool invert; + + info = par->matchinfo; + invert = info->flags & XT_RPFILTER_INVERT; + + if (rpfilter_is_loopback(skb, xt_in(par))) + return true ^ invert; + + iph = ip_hdr(skb); + if (ipv4_is_zeronet(iph->saddr)) { + if (ipv4_is_lbcast(iph->daddr) || + ipv4_is_local_multicast(iph->daddr)) + return true ^ invert; + } + + memset(&flow, 0, sizeof(flow)); + flow.flowi4_iif = LOOPBACK_IFINDEX; + flow.daddr = iph->saddr; + flow.saddr = rpfilter_get_saddr(iph->daddr); + flow.flowi4_mark = info->flags & XT_RPFILTER_VALID_MARK ? 
skb->mark : 0; + flow.flowi4_tos = iph->tos & IPTOS_RT_MASK; + flow.flowi4_scope = RT_SCOPE_UNIVERSE; + flow.flowi4_l3mdev = l3mdev_master_ifindex_rcu(xt_in(par)); + flow.flowi4_uid = sock_net_uid(xt_net(par), NULL); + + return rpfilter_lookup_reverse(xt_net(par), &flow, xt_in(par), info->flags) ^ invert; +} + +static int rpfilter_check(const struct xt_mtchk_param *par) +{ + const struct xt_rpfilter_info *info = par->matchinfo; + unsigned int options = ~XT_RPFILTER_OPTION_MASK; + if (info->flags & options) { + pr_info_ratelimited("unknown options\n"); + return -EINVAL; + } + + if (strcmp(par->table, "mangle") != 0 && + strcmp(par->table, "raw") != 0) { + pr_info_ratelimited("only valid in \'raw\' or \'mangle\' table, not \'%s\'\n", + par->table); + return -EINVAL; + } + + return 0; +} + +static struct xt_match rpfilter_mt_reg __read_mostly = { + .name = "rpfilter", + .family = NFPROTO_IPV4, + .checkentry = rpfilter_check, + .match = rpfilter_mt, + .matchsize = sizeof(struct xt_rpfilter_info), + .hooks = (1 << NF_INET_PRE_ROUTING), + .me = THIS_MODULE +}; + +static int __init rpfilter_mt_init(void) +{ + return xt_register_match(&rpfilter_mt_reg); +} + +static void __exit rpfilter_mt_exit(void) +{ + xt_unregister_match(&rpfilter_mt_reg); +} + +module_init(rpfilter_mt_init); +module_exit(rpfilter_mt_exit); diff --git a/net/ipv4/netfilter/iptable_filter.c b/net/ipv4/netfilter/iptable_filter.c new file mode 100644 index 000000000..b9062f455 --- /dev/null +++ b/net/ipv4/netfilter/iptable_filter.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is the 1999 rewrite of IP Firewalling, aiming for kernel 2.3.x. + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("iptables filter table"); + +#define FILTER_VALID_HOOKS ((1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT)) + +static const struct xt_table packet_filter = { + .name = "filter", + .valid_hooks = FILTER_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV4, + .priority = NF_IP_PRI_FILTER, +}; + +static struct nf_hook_ops *filter_ops __read_mostly; + +/* Default to forward because I got too much mail already. */ +static bool forward __read_mostly = true; +module_param(forward, bool, 0000); + +static int iptable_filter_table_init(struct net *net) +{ + struct ipt_replace *repl; + int err; + + repl = ipt_alloc_initial_table(&packet_filter); + if (repl == NULL) + return -ENOMEM; + /* Entry 1 is the FORWARD hook */ + ((struct ipt_standard *)repl->entries)[1].target.verdict = + forward ? 
-NF_ACCEPT - 1 : -NF_DROP - 1; + + err = ipt_register_table(net, &packet_filter, repl, filter_ops); + kfree(repl); + return err; +} + +static int __net_init iptable_filter_net_init(struct net *net) +{ + if (!forward) + return iptable_filter_table_init(net); + + return 0; +} + +static void __net_exit iptable_filter_net_pre_exit(struct net *net) +{ + ipt_unregister_table_pre_exit(net, "filter"); +} + +static void __net_exit iptable_filter_net_exit(struct net *net) +{ + ipt_unregister_table_exit(net, "filter"); +} + +static struct pernet_operations iptable_filter_net_ops = { + .init = iptable_filter_net_init, + .pre_exit = iptable_filter_net_pre_exit, + .exit = iptable_filter_net_exit, +}; + +static int __init iptable_filter_init(void) +{ + int ret = xt_register_template(&packet_filter, + iptable_filter_table_init); + + if (ret < 0) + return ret; + + filter_ops = xt_hook_ops_alloc(&packet_filter, ipt_do_table); + if (IS_ERR(filter_ops)) { + xt_unregister_template(&packet_filter); + return PTR_ERR(filter_ops); + } + + ret = register_pernet_subsys(&iptable_filter_net_ops); + if (ret < 0) { + xt_unregister_template(&packet_filter); + kfree(filter_ops); + return ret; + } + + return 0; +} + +static void __exit iptable_filter_fini(void) +{ + unregister_pernet_subsys(&iptable_filter_net_ops); + xt_unregister_template(&packet_filter); + kfree(filter_ops); +} + +module_init(iptable_filter_init); +module_exit(iptable_filter_fini); diff --git a/net/ipv4/netfilter/iptable_mangle.c b/net/ipv4/netfilter/iptable_mangle.c new file mode 100644 index 000000000..3abb430af --- /dev/null +++ b/net/ipv4/netfilter/iptable_mangle.c @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is the 1999 rewrite of IP Firewalling, aiming for kernel 2.3.x. + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2004 Netfilter Core Team + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("iptables mangle table"); + +#define MANGLE_VALID_HOOKS ((1 << NF_INET_PRE_ROUTING) | \ + (1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT) | \ + (1 << NF_INET_POST_ROUTING)) + +static const struct xt_table packet_mangler = { + .name = "mangle", + .valid_hooks = MANGLE_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV4, + .priority = NF_IP_PRI_MANGLE, +}; + +static unsigned int +ipt_mangle_out(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) +{ + unsigned int ret; + const struct iphdr *iph; + u_int8_t tos; + __be32 saddr, daddr; + u_int32_t mark; + int err; + + /* Save things which could affect route */ + mark = skb->mark; + iph = ip_hdr(skb); + saddr = iph->saddr; + daddr = iph->daddr; + tos = iph->tos; + + ret = ipt_do_table(priv, skb, state); + /* Reroute for ANY change. */ + if (ret != NF_DROP && ret != NF_STOLEN) { + iph = ip_hdr(skb); + + if (iph->saddr != saddr || + iph->daddr != daddr || + skb->mark != mark || + iph->tos != tos) { + err = ip_route_me_harder(state->net, state->sk, skb, RTN_UNSPEC); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } + + return ret; +} + +/* The work comes in here from netfilter.c. 
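+ * LOCAL_OUT traffic goes through ipt_mangle_out() so that any change to
+ * the addresses, mark or TOS forces a re-route; the other hooks call
+ * ipt_do_table() directly.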
*/ +static unsigned int +iptable_mangle_hook(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + if (state->hook == NF_INET_LOCAL_OUT) + return ipt_mangle_out(priv, skb, state); + return ipt_do_table(priv, skb, state); +} + +static struct nf_hook_ops *mangle_ops __read_mostly; +static int iptable_mangle_table_init(struct net *net) +{ + struct ipt_replace *repl; + int ret; + + repl = ipt_alloc_initial_table(&packet_mangler); + if (repl == NULL) + return -ENOMEM; + ret = ipt_register_table(net, &packet_mangler, repl, mangle_ops); + kfree(repl); + return ret; +} + +static void __net_exit iptable_mangle_net_pre_exit(struct net *net) +{ + ipt_unregister_table_pre_exit(net, "mangle"); +} + +static void __net_exit iptable_mangle_net_exit(struct net *net) +{ + ipt_unregister_table_exit(net, "mangle"); +} + +static struct pernet_operations iptable_mangle_net_ops = { + .pre_exit = iptable_mangle_net_pre_exit, + .exit = iptable_mangle_net_exit, +}; + +static int __init iptable_mangle_init(void) +{ + int ret = xt_register_template(&packet_mangler, + iptable_mangle_table_init); + if (ret < 0) + return ret; + + mangle_ops = xt_hook_ops_alloc(&packet_mangler, iptable_mangle_hook); + if (IS_ERR(mangle_ops)) { + xt_unregister_template(&packet_mangler); + ret = PTR_ERR(mangle_ops); + return ret; + } + + ret = register_pernet_subsys(&iptable_mangle_net_ops); + if (ret < 0) { + xt_unregister_template(&packet_mangler); + kfree(mangle_ops); + return ret; + } + + return ret; +} + +static void __exit iptable_mangle_fini(void) +{ + unregister_pernet_subsys(&iptable_mangle_net_ops); + xt_unregister_template(&packet_mangler); + kfree(mangle_ops); +} + +module_init(iptable_mangle_init); +module_exit(iptable_mangle_fini); diff --git a/net/ipv4/netfilter/iptable_nat.c b/net/ipv4/netfilter/iptable_nat.c new file mode 100644 index 000000000..56f6ecc43 --- /dev/null +++ b/net/ipv4/netfilter/iptable_nat.c @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2011 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include + +#include + +struct iptable_nat_pernet { + struct nf_hook_ops *nf_nat_ops; +}; + +static unsigned int iptable_nat_net_id __read_mostly; + +static const struct xt_table nf_nat_ipv4_table = { + .name = "nat", + .valid_hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + .af = NFPROTO_IPV4, +}; + +static const struct nf_hook_ops nf_nat_ipv4_ops[] = { + { + .hook = ipt_do_table, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_NAT_DST, + }, + { + .hook = ipt_do_table, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_NAT_SRC, + }, + { + .hook = ipt_do_table, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_NAT_DST, + }, + { + .hook = ipt_do_table, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_NAT_SRC, + }, +}; + +static int ipt_nat_register_lookups(struct net *net) +{ + struct iptable_nat_pernet *xt_nat_net; + struct nf_hook_ops *ops; + struct xt_table *table; + int i, ret; + + xt_nat_net = net_generic(net, iptable_nat_net_id); + table = xt_find_table(net, NFPROTO_IPV4, "nat"); + if (WARN_ON_ONCE(!table)) + return -ENOENT; + + ops = kmemdup(nf_nat_ipv4_ops, sizeof(nf_nat_ipv4_ops), GFP_KERNEL); + if (!ops) + return -ENOMEM; + + for (i = 0; i < 
ARRAY_SIZE(nf_nat_ipv4_ops); i++) { + ops[i].priv = table; + ret = nf_nat_ipv4_register_fn(net, &ops[i]); + if (ret) { + while (i) + nf_nat_ipv4_unregister_fn(net, &ops[--i]); + + kfree(ops); + return ret; + } + } + + xt_nat_net->nf_nat_ops = ops; + return 0; +} + +static void ipt_nat_unregister_lookups(struct net *net) +{ + struct iptable_nat_pernet *xt_nat_net = net_generic(net, iptable_nat_net_id); + struct nf_hook_ops *ops = xt_nat_net->nf_nat_ops; + int i; + + if (!ops) + return; + + for (i = 0; i < ARRAY_SIZE(nf_nat_ipv4_ops); i++) + nf_nat_ipv4_unregister_fn(net, &ops[i]); + + kfree(ops); +} + +static int iptable_nat_table_init(struct net *net) +{ + struct ipt_replace *repl; + int ret; + + repl = ipt_alloc_initial_table(&nf_nat_ipv4_table); + if (repl == NULL) + return -ENOMEM; + + ret = ipt_register_table(net, &nf_nat_ipv4_table, repl, NULL); + if (ret < 0) { + kfree(repl); + return ret; + } + + ret = ipt_nat_register_lookups(net); + if (ret < 0) + ipt_unregister_table_exit(net, "nat"); + + kfree(repl); + return ret; +} + +static void __net_exit iptable_nat_net_pre_exit(struct net *net) +{ + ipt_nat_unregister_lookups(net); +} + +static void __net_exit iptable_nat_net_exit(struct net *net) +{ + ipt_unregister_table_exit(net, "nat"); +} + +static struct pernet_operations iptable_nat_net_ops = { + .pre_exit = iptable_nat_net_pre_exit, + .exit = iptable_nat_net_exit, + .id = &iptable_nat_net_id, + .size = sizeof(struct iptable_nat_pernet), +}; + +static int __init iptable_nat_init(void) +{ + int ret = xt_register_template(&nf_nat_ipv4_table, + iptable_nat_table_init); + + if (ret < 0) + return ret; + + ret = register_pernet_subsys(&iptable_nat_net_ops); + if (ret < 0) { + xt_unregister_template(&nf_nat_ipv4_table); + return ret; + } + + return ret; +} + +static void __exit iptable_nat_exit(void) +{ + unregister_pernet_subsys(&iptable_nat_net_ops); + xt_unregister_template(&nf_nat_ipv4_table); +} + +module_init(iptable_nat_init); +module_exit(iptable_nat_exit); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/netfilter/iptable_raw.c b/net/ipv4/netfilter/iptable_raw.c new file mode 100644 index 000000000..ca5e5b215 --- /dev/null +++ b/net/ipv4/netfilter/iptable_raw.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 'raw' table, which is the very first hooked in at PRE_ROUTING and LOCAL_OUT . 
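+ * With the raw_before_defrag module parameter it is registered at an
+ * even earlier priority, ahead of connection-tracking defragmentation.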
+ * + * Copyright (C) 2003 Jozsef Kadlecsik + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include + +#define RAW_VALID_HOOKS ((1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_OUT)) + +static bool raw_before_defrag __read_mostly; +MODULE_PARM_DESC(raw_before_defrag, "Enable raw table before defrag"); +module_param(raw_before_defrag, bool, 0000); + +static const struct xt_table packet_raw = { + .name = "raw", + .valid_hooks = RAW_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV4, + .priority = NF_IP_PRI_RAW, +}; + +static const struct xt_table packet_raw_before_defrag = { + .name = "raw", + .valid_hooks = RAW_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV4, + .priority = NF_IP_PRI_RAW_BEFORE_DEFRAG, +}; + +static struct nf_hook_ops *rawtable_ops __read_mostly; + +static int iptable_raw_table_init(struct net *net) +{ + struct ipt_replace *repl; + const struct xt_table *table = &packet_raw; + int ret; + + if (raw_before_defrag) + table = &packet_raw_before_defrag; + + repl = ipt_alloc_initial_table(table); + if (repl == NULL) + return -ENOMEM; + ret = ipt_register_table(net, table, repl, rawtable_ops); + kfree(repl); + return ret; +} + +static void __net_exit iptable_raw_net_pre_exit(struct net *net) +{ + ipt_unregister_table_pre_exit(net, "raw"); +} + +static void __net_exit iptable_raw_net_exit(struct net *net) +{ + ipt_unregister_table_exit(net, "raw"); +} + +static struct pernet_operations iptable_raw_net_ops = { + .pre_exit = iptable_raw_net_pre_exit, + .exit = iptable_raw_net_exit, +}; + +static int __init iptable_raw_init(void) +{ + int ret; + const struct xt_table *table = &packet_raw; + + if (raw_before_defrag) { + table = &packet_raw_before_defrag; + + pr_info("Enabling raw table before defrag\n"); + } + + ret = xt_register_template(table, + iptable_raw_table_init); + if (ret < 0) + return ret; + + rawtable_ops = xt_hook_ops_alloc(table, ipt_do_table); + if (IS_ERR(rawtable_ops)) { + xt_unregister_template(table); + return PTR_ERR(rawtable_ops); + } + + ret = register_pernet_subsys(&iptable_raw_net_ops); + if (ret < 0) { + xt_unregister_template(table); + kfree(rawtable_ops); + return ret; + } + + return ret; +} + +static void __exit iptable_raw_fini(void) +{ + unregister_pernet_subsys(&iptable_raw_net_ops); + kfree(rawtable_ops); + xt_unregister_template(&packet_raw); +} + +module_init(iptable_raw_init); +module_exit(iptable_raw_fini); +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/netfilter/iptable_security.c b/net/ipv4/netfilter/iptable_security.c new file mode 100644 index 000000000..d885443cb --- /dev/null +++ b/net/ipv4/netfilter/iptable_security.c @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * "security" table + * + * This is for use by Mandatory Access Control (MAC) security models, + * which need to be able to manage security policy in separate context + * to DAC. + * + * Based on iptable_mangle.c + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. 
Neuling + * Copyright (C) 2000-2004 Netfilter Core Team netfilter.org> + * Copyright (C) 2008 Red Hat, Inc., James Morris redhat.com> + */ +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Morris redhat.com>"); +MODULE_DESCRIPTION("iptables security table, for MAC rules"); + +#define SECURITY_VALID_HOOKS (1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT) + +static const struct xt_table security_table = { + .name = "security", + .valid_hooks = SECURITY_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV4, + .priority = NF_IP_PRI_SECURITY, +}; + +static struct nf_hook_ops *sectbl_ops __read_mostly; + +static int iptable_security_table_init(struct net *net) +{ + struct ipt_replace *repl; + int ret; + + repl = ipt_alloc_initial_table(&security_table); + if (repl == NULL) + return -ENOMEM; + ret = ipt_register_table(net, &security_table, repl, sectbl_ops); + kfree(repl); + return ret; +} + +static void __net_exit iptable_security_net_pre_exit(struct net *net) +{ + ipt_unregister_table_pre_exit(net, "security"); +} + +static void __net_exit iptable_security_net_exit(struct net *net) +{ + ipt_unregister_table_exit(net, "security"); +} + +static struct pernet_operations iptable_security_net_ops = { + .pre_exit = iptable_security_net_pre_exit, + .exit = iptable_security_net_exit, +}; + +static int __init iptable_security_init(void) +{ + int ret = xt_register_template(&security_table, + iptable_security_table_init); + + if (ret < 0) + return ret; + + sectbl_ops = xt_hook_ops_alloc(&security_table, ipt_do_table); + if (IS_ERR(sectbl_ops)) { + xt_unregister_template(&security_table); + return PTR_ERR(sectbl_ops); + } + + ret = register_pernet_subsys(&iptable_security_net_ops); + if (ret < 0) { + xt_unregister_template(&security_table); + kfree(sectbl_ops); + return ret; + } + + return ret; +} + +static void __exit iptable_security_fini(void) +{ + unregister_pernet_subsys(&iptable_security_net_ops); + kfree(sectbl_ops); + xt_unregister_template(&security_table); +} + +module_init(iptable_security_init); +module_exit(iptable_security_fini); diff --git a/net/ipv4/netfilter/nf_defrag_ipv4.c b/net/ipv4/netfilter/nf_defrag_ipv4.c new file mode 100644 index 000000000..e61ea428e --- /dev/null +++ b/net/ipv4/netfilter/nf_defrag_ipv4.c @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif +#include + +static DEFINE_MUTEX(defrag4_mutex); + +static int nf_ct_ipv4_gather_frags(struct net *net, struct sk_buff *skb, + u_int32_t user) +{ + int err; + + local_bh_disable(); + err = ip_defrag(net, skb, user); + local_bh_enable(); + + if (!err) + skb->ignore_df = 1; + + return err; +} + +static enum ip_defrag_users nf_ct_defrag_user(unsigned int hooknum, + struct sk_buff *skb) +{ + u16 zone_id = NF_CT_DEFAULT_ZONE_ID; +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (skb_nfct(skb)) { + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + zone_id = nf_ct_zone_id(nf_ct_zone(ct), CTINFO2DIR(ctinfo)); + } +#endif + if (nf_bridge_in_prerouting(skb)) + return IP_DEFRAG_CONNTRACK_BRIDGE_IN + zone_id; + + if (hooknum == NF_INET_PRE_ROUTING) + return IP_DEFRAG_CONNTRACK_IN + zone_id; + else + return IP_DEFRAG_CONNTRACK_OUT + zone_id; +} + +static unsigned int 
ipv4_conntrack_defrag(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct sock *sk = skb->sk; + + if (sk && sk_fullsock(sk) && (sk->sk_family == PF_INET) && + inet_sk(sk)->nodefrag) + return NF_ACCEPT; + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#if !IS_ENABLED(CONFIG_NF_NAT) + /* Previously seen (loopback)? Ignore. Do this before + fragment check. */ + if (skb_nfct(skb) && !nf_ct_is_template((struct nf_conn *)skb_nfct(skb))) + return NF_ACCEPT; +#endif + if (skb->_nfct == IP_CT_UNTRACKED) + return NF_ACCEPT; +#endif + /* Gather fragments. */ + if (ip_is_fragment(ip_hdr(skb))) { + enum ip_defrag_users user = + nf_ct_defrag_user(state->hook, skb); + + if (nf_ct_ipv4_gather_frags(state->net, skb, user)) + return NF_STOLEN; + } + return NF_ACCEPT; +} + +static const struct nf_hook_ops ipv4_defrag_ops[] = { + { + .hook = ipv4_conntrack_defrag, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_CONNTRACK_DEFRAG, + }, + { + .hook = ipv4_conntrack_defrag, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_CONNTRACK_DEFRAG, + }, +}; + +static void __net_exit defrag4_net_exit(struct net *net) +{ + if (net->nf.defrag_ipv4_users) { + nf_unregister_net_hooks(net, ipv4_defrag_ops, + ARRAY_SIZE(ipv4_defrag_ops)); + net->nf.defrag_ipv4_users = 0; + } +} + +static struct pernet_operations defrag4_net_ops = { + .exit = defrag4_net_exit, +}; + +static int __init nf_defrag_init(void) +{ + return register_pernet_subsys(&defrag4_net_ops); +} + +static void __exit nf_defrag_fini(void) +{ + unregister_pernet_subsys(&defrag4_net_ops); +} + +int nf_defrag_ipv4_enable(struct net *net) +{ + int err = 0; + + mutex_lock(&defrag4_mutex); + if (net->nf.defrag_ipv4_users == UINT_MAX) { + err = -EOVERFLOW; + goto out_unlock; + } + + if (net->nf.defrag_ipv4_users) { + net->nf.defrag_ipv4_users++; + goto out_unlock; + } + + err = nf_register_net_hooks(net, ipv4_defrag_ops, + ARRAY_SIZE(ipv4_defrag_ops)); + if (err == 0) + net->nf.defrag_ipv4_users = 1; + + out_unlock: + mutex_unlock(&defrag4_mutex); + return err; +} +EXPORT_SYMBOL_GPL(nf_defrag_ipv4_enable); + +void nf_defrag_ipv4_disable(struct net *net) +{ + mutex_lock(&defrag4_mutex); + if (net->nf.defrag_ipv4_users) { + net->nf.defrag_ipv4_users--; + if (net->nf.defrag_ipv4_users == 0) + nf_unregister_net_hooks(net, ipv4_defrag_ops, + ARRAY_SIZE(ipv4_defrag_ops)); + } + + mutex_unlock(&defrag4_mutex); +} +EXPORT_SYMBOL_GPL(nf_defrag_ipv4_disable); + +module_init(nf_defrag_init); +module_exit(nf_defrag_fini); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/netfilter/nf_dup_ipv4.c b/net/ipv4/netfilter/nf_dup_ipv4.c new file mode 100644 index 000000000..6cc5743c5 --- /dev/null +++ b/net/ipv4/netfilter/nf_dup_ipv4.c @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2007 by Sebastian Claßen + * (C) 2007-2010 by Jan Engelhardt + * + * Extracted from xt_TEE.c + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +static bool nf_dup_ipv4_route(struct net *net, struct sk_buff *skb, + const struct in_addr *gw, int oif) +{ + const struct iphdr *iph = ip_hdr(skb); + struct rtable *rt; + struct flowi4 fl4; + + memset(&fl4, 0, sizeof(fl4)); + if (oif != -1) + fl4.flowi4_oif = oif; + + fl4.daddr = gw->s_addr; + fl4.flowi4_tos = RT_TOS(iph->tos); + fl4.flowi4_scope = RT_SCOPE_UNIVERSE; + fl4.flowi4_flags = FLOWI_FLAG_KNOWN_NH; + rt = ip_route_output_key(net, 
&fl4); + if (IS_ERR(rt)) + return false; + + skb_dst_drop(skb); + skb_dst_set(skb, &rt->dst); + skb->dev = rt->dst.dev; + skb->protocol = htons(ETH_P_IP); + + return true; +} + +void nf_dup_ipv4(struct net *net, struct sk_buff *skb, unsigned int hooknum, + const struct in_addr *gw, int oif) +{ + struct iphdr *iph; + + if (this_cpu_read(nf_skb_duplicated)) + return; + /* + * Copy the skb, and route the copy. Will later return %XT_CONTINUE for + * the original skb, which should continue on its way as if nothing has + * happened. The copy should be independently delivered to the gateway. + */ + skb = pskb_copy(skb, GFP_ATOMIC); + if (skb == NULL) + return; + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + /* Avoid counting cloned packets towards the original connection. */ + nf_reset_ct(skb); + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); +#endif + /* + * If we are in PREROUTING/INPUT, decrease the TTL to mitigate potential + * loops between two hosts. + * + * Set %IP_DF so that the original source is notified of a potentially + * decreased MTU on the clone route. IPv6 does this too. + * + * IP header checksum will be recalculated at ip_local_out. + */ + iph = ip_hdr(skb); + iph->frag_off |= htons(IP_DF); + if (hooknum == NF_INET_PRE_ROUTING || + hooknum == NF_INET_LOCAL_IN) + --iph->ttl; + + if (nf_dup_ipv4_route(net, skb, gw, oif)) { + __this_cpu_write(nf_skb_duplicated, true); + ip_local_out(net, skb->sk, skb); + __this_cpu_write(nf_skb_duplicated, false); + } else { + kfree_skb(skb); + } +} +EXPORT_SYMBOL_GPL(nf_dup_ipv4); + +MODULE_AUTHOR("Sebastian Claßen "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("nf_dup_ipv4: Duplicate IPv4 packet"); +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/netfilter/nf_nat_h323.c b/net/ipv4/netfilter/nf_nat_h323.c new file mode 100644 index 000000000..faee20af4 --- /dev/null +++ b/net/ipv4/netfilter/nf_nat_h323.c @@ -0,0 +1,567 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * H.323 extension for NAT alteration. 
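The helper rewrites the IPv4 addresses and ports embedded in H.225/H.245 signalling and registers the NAT-translated expectations for the related RTP/RTCP, T.120, H.245, Q.931 and call-forwarding channels.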
+ * + * Copyright (c) 2006 Jing Min Zhao + * Copyright (c) 2006-2012 Patrick McHardy + * + * Based on the 'brute force' H.323 NAT module by + * Jozsef Kadlecsik + */ + +#include +#include +#include + +#include +#include +#include +#include +#include + +/****************************************************************************/ +static int set_addr(struct sk_buff *skb, unsigned int protoff, + unsigned char **data, int dataoff, + unsigned int addroff, __be32 ip, __be16 port) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct { + __be32 ip; + __be16 port; + } __attribute__ ((__packed__)) buf; + const struct tcphdr *th; + struct tcphdr _tcph; + + buf.ip = ip; + buf.port = port; + addroff += dataoff; + + if (ip_hdr(skb)->protocol == IPPROTO_TCP) { + if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, + protoff, addroff, sizeof(buf), + (char *) &buf, sizeof(buf))) { + net_notice_ratelimited("nf_nat_h323: nf_nat_mangle_tcp_packet error\n"); + return -1; + } + + /* Relocate data pointer */ + th = skb_header_pointer(skb, ip_hdrlen(skb), + sizeof(_tcph), &_tcph); + if (th == NULL) + return -1; + *data = skb->data + ip_hdrlen(skb) + th->doff * 4 + dataoff; + } else { + if (!nf_nat_mangle_udp_packet(skb, ct, ctinfo, + protoff, addroff, sizeof(buf), + (char *) &buf, sizeof(buf))) { + net_notice_ratelimited("nf_nat_h323: nf_nat_mangle_udp_packet error\n"); + return -1; + } + /* nf_nat_mangle_udp_packet uses skb_ensure_writable() to copy + * or pull everything in a linear buffer, so we can safely + * use the skb pointers now */ + *data = skb->data + ip_hdrlen(skb) + sizeof(struct udphdr); + } + + return 0; +} + +/****************************************************************************/ +static int set_h225_addr(struct sk_buff *skb, unsigned int protoff, + unsigned char **data, int dataoff, + TransportAddress *taddr, + union nf_inet_addr *addr, __be16 port) +{ + return set_addr(skb, protoff, data, dataoff, taddr->ipAddress.ip, + addr->ip, port); +} + +/****************************************************************************/ +static int set_h245_addr(struct sk_buff *skb, unsigned protoff, + unsigned char **data, int dataoff, + H245_TransportAddress *taddr, + union nf_inet_addr *addr, __be16 port) +{ + return set_addr(skb, protoff, data, dataoff, + taddr->unicastAddress.iPAddress.network, + addr->ip, port); +} + +/****************************************************************************/ +static int set_sig_addr(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, + TransportAddress *taddr, int count) +{ + const struct nf_ct_h323_master *info = nfct_help_data(ct); + int dir = CTINFO2DIR(ctinfo); + int i; + __be16 port; + union nf_inet_addr addr; + + for (i = 0; i < count; i++) { + if (get_h225_addr(ct, *data, &taddr[i], &addr, &port)) { + if (addr.ip == ct->tuplehash[dir].tuple.src.u3.ip && + port == info->sig_port[dir]) { + /* GW->GK */ + + /* Fix for Gnomemeeting */ + if (i > 0 && + get_h225_addr(ct, *data, &taddr[0], + &addr, &port) && + (ntohl(addr.ip) & 0xff000000) == 0x7f000000) + i = 0; + + pr_debug("nf_nat_ras: set signal address %pI4:%hu->%pI4:%hu\n", + &addr.ip, port, + &ct->tuplehash[!dir].tuple.dst.u3.ip, + info->sig_port[!dir]); + return set_h225_addr(skb, protoff, data, 0, + &taddr[i], + &ct->tuplehash[!dir]. 
+ tuple.dst.u3, + info->sig_port[!dir]); + } else if (addr.ip == ct->tuplehash[dir].tuple.dst.u3.ip && + port == info->sig_port[dir]) { + /* GK->GW */ + pr_debug("nf_nat_ras: set signal address %pI4:%hu->%pI4:%hu\n", + &addr.ip, port, + &ct->tuplehash[!dir].tuple.src.u3.ip, + info->sig_port[!dir]); + return set_h225_addr(skb, protoff, data, 0, + &taddr[i], + &ct->tuplehash[!dir]. + tuple.src.u3, + info->sig_port[!dir]); + } + } + } + + return 0; +} + +/****************************************************************************/ +static int set_ras_addr(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, + TransportAddress *taddr, int count) +{ + int dir = CTINFO2DIR(ctinfo); + int i; + __be16 port; + union nf_inet_addr addr; + + for (i = 0; i < count; i++) { + if (get_h225_addr(ct, *data, &taddr[i], &addr, &port) && + addr.ip == ct->tuplehash[dir].tuple.src.u3.ip && + port == ct->tuplehash[dir].tuple.src.u.udp.port) { + pr_debug("nf_nat_ras: set rasAddress %pI4:%hu->%pI4:%hu\n", + &addr.ip, ntohs(port), + &ct->tuplehash[!dir].tuple.dst.u3.ip, + ntohs(ct->tuplehash[!dir].tuple.dst.u.udp.port)); + return set_h225_addr(skb, protoff, data, 0, &taddr[i], + &ct->tuplehash[!dir].tuple.dst.u3, + ct->tuplehash[!dir].tuple. + dst.u.udp.port); + } + } + + return 0; +} + +/****************************************************************************/ +static int nat_rtp_rtcp(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + H245_TransportAddress *taddr, + __be16 port, __be16 rtp_port, + struct nf_conntrack_expect *rtp_exp, + struct nf_conntrack_expect *rtcp_exp) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + int dir = CTINFO2DIR(ctinfo); + int i; + u_int16_t nated_port; + + /* Set expectations for NAT */ + rtp_exp->saved_proto.udp.port = rtp_exp->tuple.dst.u.udp.port; + rtp_exp->expectfn = nf_nat_follow_master; + rtp_exp->dir = !dir; + rtcp_exp->saved_proto.udp.port = rtcp_exp->tuple.dst.u.udp.port; + rtcp_exp->expectfn = nf_nat_follow_master; + rtcp_exp->dir = !dir; + + /* Lookup existing expects */ + for (i = 0; i < H323_RTP_CHANNEL_MAX; i++) { + if (info->rtp_port[i][dir] == rtp_port) { + /* Expected */ + + /* Use allocated ports first. This will refresh + * the expects */ + rtp_exp->tuple.dst.u.udp.port = info->rtp_port[i][dir]; + rtcp_exp->tuple.dst.u.udp.port = + htons(ntohs(info->rtp_port[i][dir]) + 1); + break; + } else if (info->rtp_port[i][dir] == 0) { + /* Not expected */ + break; + } + } + + /* Run out of expectations */ + if (i >= H323_RTP_CHANNEL_MAX) { + net_notice_ratelimited("nf_nat_h323: out of expectations\n"); + return 0; + } + + /* Try to get a pair of ports. 
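RTP is expected on the base port and RTCP on the port just above it; if the RTCP port is already taken the RTP expectation is dropped and the candidate pair advances by two.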
*/ + for (nated_port = ntohs(rtp_exp->tuple.dst.u.udp.port); + nated_port != 0; nated_port += 2) { + int ret; + + rtp_exp->tuple.dst.u.udp.port = htons(nated_port); + ret = nf_ct_expect_related(rtp_exp, 0); + if (ret == 0) { + rtcp_exp->tuple.dst.u.udp.port = + htons(nated_port + 1); + ret = nf_ct_expect_related(rtcp_exp, 0); + if (ret == 0) + break; + else if (ret == -EBUSY) { + nf_ct_unexpect_related(rtp_exp); + continue; + } else if (ret < 0) { + nf_ct_unexpect_related(rtp_exp); + nated_port = 0; + break; + } + } else if (ret != -EBUSY) { + nated_port = 0; + break; + } + } + + if (nated_port == 0) { /* No port available */ + net_notice_ratelimited("nf_nat_h323: out of RTP ports\n"); + return 0; + } + + /* Modify signal */ + if (set_h245_addr(skb, protoff, data, dataoff, taddr, + &ct->tuplehash[!dir].tuple.dst.u3, + htons((port & htons(1)) ? nated_port + 1 : + nated_port))) { + nf_ct_unexpect_related(rtp_exp); + nf_ct_unexpect_related(rtcp_exp); + return -1; + } + + /* Save ports */ + info->rtp_port[i][dir] = rtp_port; + info->rtp_port[i][!dir] = htons(nated_port); + + /* Success */ + pr_debug("nf_nat_h323: expect RTP %pI4:%hu->%pI4:%hu\n", + &rtp_exp->tuple.src.u3.ip, + ntohs(rtp_exp->tuple.src.u.udp.port), + &rtp_exp->tuple.dst.u3.ip, + ntohs(rtp_exp->tuple.dst.u.udp.port)); + pr_debug("nf_nat_h323: expect RTCP %pI4:%hu->%pI4:%hu\n", + &rtcp_exp->tuple.src.u3.ip, + ntohs(rtcp_exp->tuple.src.u.udp.port), + &rtcp_exp->tuple.dst.u3.ip, + ntohs(rtcp_exp->tuple.dst.u.udp.port)); + + return 0; +} + +/****************************************************************************/ +static int nat_t120(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + H245_TransportAddress *taddr, __be16 port, + struct nf_conntrack_expect *exp) +{ + int dir = CTINFO2DIR(ctinfo); + u_int16_t nated_port = ntohs(port); + + /* Set expectations for NAT */ + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->expectfn = nf_nat_follow_master; + exp->dir = !dir; + + nated_port = nf_nat_exp_find_port(exp, nated_port); + if (nated_port == 0) { /* No port available */ + net_notice_ratelimited("nf_nat_h323: out of TCP ports\n"); + return 0; + } + + /* Modify signal */ + if (set_h245_addr(skb, protoff, data, dataoff, taddr, + &ct->tuplehash[!dir].tuple.dst.u3, + htons(nated_port)) < 0) { + nf_ct_unexpect_related(exp); + return -1; + } + + pr_debug("nf_nat_h323: expect T.120 %pI4:%hu->%pI4:%hu\n", + &exp->tuple.src.u3.ip, + ntohs(exp->tuple.src.u.tcp.port), + &exp->tuple.dst.u3.ip, + ntohs(exp->tuple.dst.u.tcp.port)); + + return 0; +} + +/****************************************************************************/ +static int nat_h245(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + TransportAddress *taddr, __be16 port, + struct nf_conntrack_expect *exp) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + int dir = CTINFO2DIR(ctinfo); + u_int16_t nated_port = ntohs(port); + + /* Set expectations for NAT */ + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->expectfn = nf_nat_follow_master; + exp->dir = !dir; + + /* Check existing expects */ + if (info->sig_port[dir] == port) + nated_port = ntohs(info->sig_port[!dir]); + + nated_port = nf_nat_exp_find_port(exp, nated_port); + if (nated_port == 0) { /* No port available */ + net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); + return 0; + } + + /* Modify signal */ + if 
(set_h225_addr(skb, protoff, data, dataoff, taddr, + &ct->tuplehash[!dir].tuple.dst.u3, + htons(nated_port))) { + nf_ct_unexpect_related(exp); + return -1; + } + + /* Save ports */ + info->sig_port[dir] = port; + info->sig_port[!dir] = htons(nated_port); + + pr_debug("nf_nat_q931: expect H.245 %pI4:%hu->%pI4:%hu\n", + &exp->tuple.src.u3.ip, + ntohs(exp->tuple.src.u.tcp.port), + &exp->tuple.dst.u3.ip, + ntohs(exp->tuple.dst.u.tcp.port)); + + return 0; +} + +/**************************************************************************** + * This conntrack expect function replaces nf_conntrack_q931_expect() + * which was set by nf_conntrack_h323.c. + ****************************************************************************/ +static void ip_nat_q931_expect(struct nf_conn *new, + struct nf_conntrack_expect *this) +{ + struct nf_nat_range2 range; + + if (this->tuple.src.u3.ip != 0) { /* Only accept calls from GK */ + nf_nat_follow_master(new, this); + return; + } + + /* This must be a fresh one. */ + BUG_ON(new->status & IPS_NAT_DONE_MASK); + + /* Change src to where master sends to */ + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr = + new->tuplehash[!this->dir].tuple.src.u3; + nf_nat_setup_info(new, &range, NF_NAT_MANIP_SRC); + + /* For DST manip, map port here to where it's expected. */ + range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED); + range.min_proto = range.max_proto = this->saved_proto; + range.min_addr = range.max_addr = + new->master->tuplehash[!this->dir].tuple.src.u3; + nf_nat_setup_info(new, &range, NF_NAT_MANIP_DST); +} + +/****************************************************************************/ +static int nat_q931(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, + TransportAddress *taddr, int idx, + __be16 port, struct nf_conntrack_expect *exp) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + int dir = CTINFO2DIR(ctinfo); + u_int16_t nated_port = ntohs(port); + union nf_inet_addr addr; + + /* Set expectations for NAT */ + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->expectfn = ip_nat_q931_expect; + exp->dir = !dir; + + /* Check existing expects */ + if (info->sig_port[dir] == port) + nated_port = ntohs(info->sig_port[!dir]); + + nated_port = nf_nat_exp_find_port(exp, nated_port); + if (nated_port == 0) { /* No port available */ + net_notice_ratelimited("nf_nat_ras: out of TCP ports\n"); + return 0; + } + + /* Modify signal */ + if (set_h225_addr(skb, protoff, data, 0, &taddr[idx], + &ct->tuplehash[!dir].tuple.dst.u3, + htons(nated_port))) { + nf_ct_unexpect_related(exp); + return -1; + } + + /* Save ports */ + info->sig_port[dir] = port; + info->sig_port[!dir] = htons(nated_port); + + /* Fix for Gnomemeeting */ + if (idx > 0 && + get_h225_addr(ct, *data, &taddr[0], &addr, &port) && + (ntohl(addr.ip) & 0xff000000) == 0x7f000000) { + if (set_h225_addr(skb, protoff, data, 0, &taddr[0], + &ct->tuplehash[!dir].tuple.dst.u3, + info->sig_port[!dir])) { + nf_ct_unexpect_related(exp); + return -1; + } + } + + /* Success */ + pr_debug("nf_nat_ras: expect Q.931 %pI4:%hu->%pI4:%hu\n", + &exp->tuple.src.u3.ip, + ntohs(exp->tuple.src.u.tcp.port), + &exp->tuple.dst.u3.ip, + ntohs(exp->tuple.dst.u.tcp.port)); + + return 0; +} + +/****************************************************************************/ +static void ip_nat_callforwarding_expect(struct nf_conn *new, + struct nf_conntrack_expect *this) +{ + struct nf_nat_range2 range; + + /* This 
must be a fresh one. */ + BUG_ON(new->status & IPS_NAT_DONE_MASK); + + /* Change src to where master sends to */ + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr = + new->tuplehash[!this->dir].tuple.src.u3; + nf_nat_setup_info(new, &range, NF_NAT_MANIP_SRC); + + /* For DST manip, map port here to where it's expected. */ + range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED); + range.min_proto = range.max_proto = this->saved_proto; + range.min_addr = range.max_addr = this->saved_addr; + nf_nat_setup_info(new, &range, NF_NAT_MANIP_DST); +} + +/****************************************************************************/ +static int nat_callforwarding(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + TransportAddress *taddr, __be16 port, + struct nf_conntrack_expect *exp) +{ + int dir = CTINFO2DIR(ctinfo); + u_int16_t nated_port; + + /* Set expectations for NAT */ + exp->saved_addr = exp->tuple.dst.u3; + exp->tuple.dst.u3.ip = ct->tuplehash[!dir].tuple.dst.u3.ip; + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->expectfn = ip_nat_callforwarding_expect; + exp->dir = !dir; + + nated_port = nf_nat_exp_find_port(exp, ntohs(port)); + if (nated_port == 0) { /* No port available */ + net_notice_ratelimited("nf_nat_q931: out of TCP ports\n"); + return 0; + } + + /* Modify signal */ + if (set_h225_addr(skb, protoff, data, dataoff, taddr, + &ct->tuplehash[!dir].tuple.dst.u3, + htons(nated_port))) { + nf_ct_unexpect_related(exp); + return -1; + } + + /* Success */ + pr_debug("nf_nat_q931: expect Call Forwarding %pI4:%hu->%pI4:%hu\n", + &exp->tuple.src.u3.ip, + ntohs(exp->tuple.src.u.tcp.port), + &exp->tuple.dst.u3.ip, + ntohs(exp->tuple.dst.u.tcp.port)); + + return 0; +} + +static struct nf_ct_helper_expectfn q931_nat = { + .name = "Q.931", + .expectfn = ip_nat_q931_expect, +}; + +static struct nf_ct_helper_expectfn callforwarding_nat = { + .name = "callforwarding", + .expectfn = ip_nat_callforwarding_expect, +}; + +static const struct nfct_h323_nat_hooks nathooks = { + .set_h245_addr = set_h245_addr, + .set_h225_addr = set_h225_addr, + .set_sig_addr = set_sig_addr, + .set_ras_addr = set_ras_addr, + .nat_rtp_rtcp = nat_rtp_rtcp, + .nat_t120 = nat_t120, + .nat_h245 = nat_h245, + .nat_callforwarding = nat_callforwarding, + .nat_q931 = nat_q931, +}; + +/****************************************************************************/ +static int __init nf_nat_h323_init(void) +{ + RCU_INIT_POINTER(nfct_h323_nat_hook, &nathooks); + nf_ct_helper_expectfn_register(&q931_nat); + nf_ct_helper_expectfn_register(&callforwarding_nat); + return 0; +} + +/****************************************************************************/ +static void __exit nf_nat_h323_fini(void) +{ + RCU_INIT_POINTER(nfct_h323_nat_hook, NULL); + nf_ct_helper_expectfn_unregister(&q931_nat); + nf_ct_helper_expectfn_unregister(&callforwarding_nat); + synchronize_rcu(); +} + +/****************************************************************************/ +module_init(nf_nat_h323_init); +module_exit(nf_nat_h323_fini); + +MODULE_AUTHOR("Jing Min Zhao "); +MODULE_DESCRIPTION("H.323 NAT helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NF_NAT_HELPER("h323"); diff --git a/net/ipv4/netfilter/nf_nat_pptp.c b/net/ipv4/netfilter/nf_nat_pptp.c new file mode 100644 index 000000000..fab357cc8 --- /dev/null +++ b/net/ipv4/netfilter/nf_nat_pptp.c @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * nf_nat_pptp.c 
+ * + * NAT support for PPTP (Point to Point Tunneling Protocol). + * PPTP is a protocol for creating virtual private networks. + * It is a specification defined by Microsoft and some vendors + * working with Microsoft. PPTP is built on top of a modified + * version of the Internet Generic Routing Encapsulation Protocol. + * GRE is defined in RFC 1701 and RFC 1702. Documentation of + * PPTP can be found in RFC 2637 + * + * (C) 2000-2005 by Harald Welte + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + * + * (C) 2006-2012 Patrick McHardy + * + * TODO: - NAT to a unique tuple, not to TCP source port + * (needs netfilter tuple reservation) + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define NF_NAT_PPTP_VERSION "3.0" + +#define REQ_CID(req, off) (*(__be16 *)((char *)(req) + (off))) + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Netfilter NAT helper module for PPTP"); +MODULE_ALIAS_NF_NAT_HELPER("pptp"); + +static void pptp_nat_expected(struct nf_conn *ct, + struct nf_conntrack_expect *exp) +{ + struct net *net = nf_ct_net(ct); + const struct nf_conn *master = ct->master; + struct nf_conntrack_expect *other_exp; + struct nf_conntrack_tuple t = {}; + const struct nf_ct_pptp_master *ct_pptp_info; + const struct nf_nat_pptp *nat_pptp_info; + struct nf_nat_range2 range; + struct nf_conn_nat *nat; + + nat = nf_ct_nat_ext_add(ct); + if (WARN_ON_ONCE(!nat)) + return; + + nat_pptp_info = &nat->help.nat_pptp_info; + ct_pptp_info = nfct_help_data(master); + + /* And here goes the grand finale of corrosion... */ + if (exp->dir == IP_CT_DIR_ORIGINAL) { + pr_debug("we are PNS->PAC\n"); + /* therefore, build tuple for PAC->PNS */ + t.src.l3num = AF_INET; + t.src.u3.ip = master->tuplehash[!exp->dir].tuple.src.u3.ip; + t.src.u.gre.key = ct_pptp_info->pac_call_id; + t.dst.u3.ip = master->tuplehash[!exp->dir].tuple.dst.u3.ip; + t.dst.u.gre.key = ct_pptp_info->pns_call_id; + t.dst.protonum = IPPROTO_GRE; + } else { + pr_debug("we are PAC->PNS\n"); + /* build tuple for PNS->PAC */ + t.src.l3num = AF_INET; + t.src.u3.ip = master->tuplehash[!exp->dir].tuple.src.u3.ip; + t.src.u.gre.key = nat_pptp_info->pns_call_id; + t.dst.u3.ip = master->tuplehash[!exp->dir].tuple.dst.u3.ip; + t.dst.u.gre.key = nat_pptp_info->pac_call_id; + t.dst.protonum = IPPROTO_GRE; + } + + pr_debug("trying to unexpect other dir: "); + nf_ct_dump_tuple_ip(&t); + other_exp = nf_ct_expect_find_get(net, nf_ct_zone(ct), &t); + if (other_exp) { + nf_ct_unexpect_related(other_exp); + nf_ct_expect_put(other_exp); + pr_debug("success\n"); + } else { + pr_debug("not found!\n"); + } + + /* This must be a fresh one. */ + BUG_ON(ct->status & IPS_NAT_DONE_MASK); + + /* Change src to where master sends to */ + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr + = ct->master->tuplehash[!exp->dir].tuple.dst.u3; + if (exp->dir == IP_CT_DIR_ORIGINAL) { + range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + range.min_proto = range.max_proto = exp->saved_proto; + } + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC); + + /* For DST manip, map port here to where it's expected. 
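For reply-direction expectations the call ID saved in the expectation is also pinned as the specific GRE key.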
*/ + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr + = ct->master->tuplehash[!exp->dir].tuple.src.u3; + if (exp->dir == IP_CT_DIR_REPLY) { + range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + range.min_proto = range.max_proto = exp->saved_proto; + } + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST); +} + +/* outbound packets == from PNS to PAC */ +static int +pptp_outbound_pkt(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + struct PptpControlHeader *ctlh, + union pptp_ctrl_union *pptpReq) + +{ + struct nf_ct_pptp_master *ct_pptp_info; + struct nf_conn_nat *nat = nfct_nat(ct); + struct nf_nat_pptp *nat_pptp_info; + u_int16_t msg; + __be16 new_callid; + unsigned int cid_off; + + if (WARN_ON_ONCE(!nat)) + return NF_DROP; + + nat_pptp_info = &nat->help.nat_pptp_info; + ct_pptp_info = nfct_help_data(ct); + + new_callid = ct_pptp_info->pns_call_id; + + switch (msg = ntohs(ctlh->messageType)) { + case PPTP_OUT_CALL_REQUEST: + cid_off = offsetof(union pptp_ctrl_union, ocreq.callID); + /* FIXME: ideally we would want to reserve a call ID + * here. current netfilter NAT core is not able to do + * this :( For now we use TCP source port. This breaks + * multiple calls within one control session */ + + /* save original call ID in nat_info */ + nat_pptp_info->pns_call_id = ct_pptp_info->pns_call_id; + + /* don't use tcph->source since we are at a DSTmanip + * hook (e.g. PREROUTING) and pkt is not mangled yet */ + new_callid = ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u.tcp.port; + + /* save new call ID in ct info */ + ct_pptp_info->pns_call_id = new_callid; + break; + case PPTP_IN_CALL_REPLY: + cid_off = offsetof(union pptp_ctrl_union, icack.callID); + break; + case PPTP_CALL_CLEAR_REQUEST: + cid_off = offsetof(union pptp_ctrl_union, clrreq.callID); + break; + default: + pr_debug("unknown outbound packet 0x%04x:%s\n", msg, + pptp_msg_name(msg)); + fallthrough; + case PPTP_SET_LINK_INFO: + /* only need to NAT in case PAC is behind NAT box */ + case PPTP_START_SESSION_REQUEST: + case PPTP_START_SESSION_REPLY: + case PPTP_STOP_SESSION_REQUEST: + case PPTP_STOP_SESSION_REPLY: + case PPTP_ECHO_REQUEST: + case PPTP_ECHO_REPLY: + /* no need to alter packet */ + return NF_ACCEPT; + } + + /* only OUT_CALL_REQUEST, IN_CALL_REPLY, CALL_CLEAR_REQUEST pass + * down to here */ + pr_debug("altering call id from 0x%04x to 0x%04x\n", + ntohs(REQ_CID(pptpReq, cid_off)), ntohs(new_callid)); + + /* mangle packet */ + if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff, + cid_off + sizeof(struct pptp_pkt_hdr) + + sizeof(struct PptpControlHeader), + sizeof(new_callid), (char *)&new_callid, + sizeof(new_callid))) + return NF_DROP; + return NF_ACCEPT; +} + +static void +pptp_exp_gre(struct nf_conntrack_expect *expect_orig, + struct nf_conntrack_expect *expect_reply) +{ + const struct nf_conn *ct = expect_orig->master; + struct nf_conn_nat *nat = nfct_nat(ct); + struct nf_ct_pptp_master *ct_pptp_info; + struct nf_nat_pptp *nat_pptp_info; + + if (WARN_ON_ONCE(!nat)) + return; + + nat_pptp_info = &nat->help.nat_pptp_info; + ct_pptp_info = nfct_help_data(ct); + + /* save original PAC call ID in nat_info */ + nat_pptp_info->pac_call_id = ct_pptp_info->pac_call_id; + + /* alter expectation for PNS->PAC direction */ + expect_orig->saved_proto.gre.key = ct_pptp_info->pns_call_id; + expect_orig->tuple.src.u.gre.key = nat_pptp_info->pns_call_id; + expect_orig->tuple.dst.u.gre.key = ct_pptp_info->pac_call_id; + expect_orig->dir = IP_CT_DIR_ORIGINAL; + + /* alter 
expectation for PAC->PNS direction */ + expect_reply->saved_proto.gre.key = nat_pptp_info->pns_call_id; + expect_reply->tuple.src.u.gre.key = nat_pptp_info->pac_call_id; + expect_reply->tuple.dst.u.gre.key = ct_pptp_info->pns_call_id; + expect_reply->dir = IP_CT_DIR_REPLY; +} + +/* inbound packets == from PAC to PNS */ +static int +pptp_inbound_pkt(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + struct PptpControlHeader *ctlh, + union pptp_ctrl_union *pptpReq) +{ + const struct nf_nat_pptp *nat_pptp_info; + struct nf_conn_nat *nat = nfct_nat(ct); + u_int16_t msg; + __be16 new_pcid; + unsigned int pcid_off; + + if (WARN_ON_ONCE(!nat)) + return NF_DROP; + + nat_pptp_info = &nat->help.nat_pptp_info; + new_pcid = nat_pptp_info->pns_call_id; + + switch (msg = ntohs(ctlh->messageType)) { + case PPTP_OUT_CALL_REPLY: + pcid_off = offsetof(union pptp_ctrl_union, ocack.peersCallID); + break; + case PPTP_IN_CALL_CONNECT: + pcid_off = offsetof(union pptp_ctrl_union, iccon.peersCallID); + break; + case PPTP_IN_CALL_REQUEST: + /* only need to nat in case PAC is behind NAT box */ + return NF_ACCEPT; + case PPTP_WAN_ERROR_NOTIFY: + pcid_off = offsetof(union pptp_ctrl_union, wanerr.peersCallID); + break; + case PPTP_CALL_DISCONNECT_NOTIFY: + pcid_off = offsetof(union pptp_ctrl_union, disc.callID); + break; + case PPTP_SET_LINK_INFO: + pcid_off = offsetof(union pptp_ctrl_union, setlink.peersCallID); + break; + default: + pr_debug("unknown inbound packet %s\n", pptp_msg_name(msg)); + fallthrough; + case PPTP_START_SESSION_REQUEST: + case PPTP_START_SESSION_REPLY: + case PPTP_STOP_SESSION_REQUEST: + case PPTP_STOP_SESSION_REPLY: + case PPTP_ECHO_REQUEST: + case PPTP_ECHO_REPLY: + /* no need to alter packet */ + return NF_ACCEPT; + } + + /* only OUT_CALL_REPLY, IN_CALL_CONNECT, IN_CALL_REQUEST, + * WAN_ERROR_NOTIFY, CALL_DISCONNECT_NOTIFY pass down here */ + + /* mangle packet */ + pr_debug("altering peer call id from 0x%04x to 0x%04x\n", + ntohs(REQ_CID(pptpReq, pcid_off)), ntohs(new_pcid)); + + if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff, + pcid_off + sizeof(struct pptp_pkt_hdr) + + sizeof(struct PptpControlHeader), + sizeof(new_pcid), (char *)&new_pcid, + sizeof(new_pcid))) + return NF_DROP; + return NF_ACCEPT; +} + +static const struct nf_nat_pptp_hook pptp_hooks = { + .outbound = pptp_outbound_pkt, + .inbound = pptp_inbound_pkt, + .exp_gre = pptp_exp_gre, + .expectfn = pptp_nat_expected, +}; + +static int __init nf_nat_helper_pptp_init(void) +{ + WARN_ON(nf_nat_pptp_hook != NULL); + RCU_INIT_POINTER(nf_nat_pptp_hook, &pptp_hooks); + + return 0; +} + +static void __exit nf_nat_helper_pptp_fini(void) +{ + RCU_INIT_POINTER(nf_nat_pptp_hook, NULL); + synchronize_rcu(); +} + +module_init(nf_nat_helper_pptp_init); +module_exit(nf_nat_helper_pptp_fini); diff --git a/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 b/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 new file mode 100644 index 000000000..24b73268f --- /dev/null +++ b/net/ipv4/netfilter/nf_nat_snmp_basic.asn1 @@ -0,0 +1,177 @@ +Message ::= + SEQUENCE { + version + INTEGER ({snmp_version}), + + community + OCTET STRING, + + pdu + PDUs + } + + +ObjectName ::= + OBJECT IDENTIFIER + +ObjectSyntax ::= + CHOICE { + simple + SimpleSyntax, + + application-wide + ApplicationSyntax + } + +SimpleSyntax ::= + CHOICE { + integer-value + INTEGER, + + string-value + OCTET STRING, + + objectID-value + OBJECT IDENTIFIER + } + +ApplicationSyntax ::= + CHOICE { + ipAddress-value + IpAddress, + + counter-value + 
Counter32, + + timeticks-value + TimeTicks, + + arbitrary-value + Opaque, + + big-counter-value + Counter64, + + unsigned-integer-value + Unsigned32 + } + +IpAddress ::= + [APPLICATION 0] + IMPLICIT OCTET STRING OPTIONAL ({snmp_helper}) + +Counter32 ::= + [APPLICATION 1] + IMPLICIT INTEGER OPTIONAL + +Unsigned32 ::= + [APPLICATION 2] + IMPLICIT INTEGER OPTIONAL + +Gauge32 ::= Unsigned32 OPTIONAL + +TimeTicks ::= + [APPLICATION 3] + IMPLICIT INTEGER OPTIONAL + +Opaque ::= + [APPLICATION 4] + IMPLICIT OCTET STRING OPTIONAL + +Counter64 ::= + [APPLICATION 6] + IMPLICIT INTEGER OPTIONAL + +PDUs ::= + CHOICE { + get-request + GetRequest-PDU, + + get-next-request + GetNextRequest-PDU, + + get-bulk-request + GetBulkRequest-PDU, + + response + Response-PDU, + + set-request + SetRequest-PDU, + + inform-request + InformRequest-PDU, + + snmpV2-trap + SNMPv2-Trap-PDU, + + report + Report-PDU + } + +GetRequest-PDU ::= + [0] IMPLICIT PDU OPTIONAL + +GetNextRequest-PDU ::= + [1] IMPLICIT PDU OPTIONAL + +Response-PDU ::= + [2] IMPLICIT PDU OPTIONAL + +SetRequest-PDU ::= + [3] IMPLICIT PDU OPTIONAL + +-- [4] is obsolete + +GetBulkRequest-PDU ::= + [5] IMPLICIT PDU OPTIONAL + +InformRequest-PDU ::= + [6] IMPLICIT PDU OPTIONAL + +SNMPv2-Trap-PDU ::= + [7] IMPLICIT PDU OPTIONAL + +Report-PDU ::= + [8] IMPLICIT PDU OPTIONAL + +PDU ::= + SEQUENCE { + request-id + INTEGER, + + error-status + INTEGER, + + error-index + INTEGER, + + variable-bindings + VarBindList + } + + +VarBind ::= + SEQUENCE { + name + ObjectName, + + CHOICE { + value + ObjectSyntax, + + unSpecified + NULL, + + noSuchObject + [0] IMPLICIT NULL, + + noSuchInstance + [1] IMPLICIT NULL, + + endOfMibView + [2] IMPLICIT NULL + } +} + +VarBindList ::= SEQUENCE OF VarBind diff --git a/net/ipv4/netfilter/nf_nat_snmp_basic_main.c b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c new file mode 100644 index 000000000..717b72650 --- /dev/null +++ b/net/ipv4/netfilter/nf_nat_snmp_basic_main.c @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * nf_nat_snmp_basic.c + * + * Basic SNMP Application Layer Gateway + * + * This IP NAT module is intended for use with SNMP network + * discovery and monitoring applications where target networks use + * conflicting private address realms. + * + * Static NAT is used to remap the networks from the view of the network + * management system at the IP layer, and this module remaps some application + * layer addresses to match. + * + * The simplest form of ALG is performed, where only tagged IP addresses + * are modified. The module does not need to be MIB aware and only scans + * messages at the ASN.1/BER level. + * + * Currently, only SNMPv1 and SNMPv2 are supported. + * + * More information on ALG and associated issues can be found in + * RFC 2962 + * + * The ASB.1/BER parsing code is derived from the gxsnmp package by Gregory + * McLean & Jochen Friedrich, stripped down for use in the kernel. + * + * Copyright (c) 2000 RP Internet (www.rpi.net.au). 
+ * + * Author: James Morris + * + * Copyright (c) 2006-2010 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include "nf_nat_snmp_basic.asn1.h" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Morris "); +MODULE_DESCRIPTION("Basic SNMP Application Layer Gateway"); +MODULE_ALIAS("ip_nat_snmp_basic"); +MODULE_ALIAS_NFCT_HELPER("snmp_trap"); + +#define SNMP_PORT 161 +#define SNMP_TRAP_PORT 162 + +static DEFINE_SPINLOCK(snmp_lock); + +struct snmp_ctx { + unsigned char *begin; + __sum16 *check; + __be32 from; + __be32 to; +}; + +static void fast_csum(struct snmp_ctx *ctx, unsigned char offset) +{ + unsigned char s[12] = {0,}; + int size; + + if (offset & 1) { + memcpy(&s[1], &ctx->from, 4); + memcpy(&s[7], &ctx->to, 4); + s[0] = ~0; + s[1] = ~s[1]; + s[2] = ~s[2]; + s[3] = ~s[3]; + s[4] = ~s[4]; + s[5] = ~0; + size = 12; + } else { + memcpy(&s[0], &ctx->from, 4); + memcpy(&s[4], &ctx->to, 4); + s[0] = ~s[0]; + s[1] = ~s[1]; + s[2] = ~s[2]; + s[3] = ~s[3]; + size = 8; + } + *ctx->check = csum_fold(csum_partial(s, size, + ~csum_unfold(*ctx->check))); +} + +int snmp_version(void *context, size_t hdrlen, unsigned char tag, + const void *data, size_t datalen) +{ + if (datalen != 1) + return -EINVAL; + if (*(unsigned char *)data > 1) + return -ENOTSUPP; + return 1; +} + +int snmp_helper(void *context, size_t hdrlen, unsigned char tag, + const void *data, size_t datalen) +{ + struct snmp_ctx *ctx = (struct snmp_ctx *)context; + __be32 *pdata; + + if (datalen != 4) + return -EINVAL; + pdata = (__be32 *)data; + if (*pdata == ctx->from) { + pr_debug("%s: %pI4 to %pI4\n", __func__, + (void *)&ctx->from, (void *)&ctx->to); + + if (*ctx->check) + fast_csum(ctx, (unsigned char *)data - ctx->begin); + *pdata = ctx->to; + } + + return 1; +} + +static int snmp_translate(struct nf_conn *ct, int dir, struct sk_buff *skb) +{ + struct iphdr *iph = ip_hdr(skb); + struct udphdr *udph = (struct udphdr *)((__be32 *)iph + iph->ihl); + u16 datalen = ntohs(udph->len) - sizeof(struct udphdr); + char *data = (unsigned char *)udph + sizeof(struct udphdr); + struct snmp_ctx ctx; + int ret; + + if (dir == IP_CT_DIR_ORIGINAL) { + ctx.from = ct->tuplehash[dir].tuple.src.u3.ip; + ctx.to = ct->tuplehash[!dir].tuple.dst.u3.ip; + } else { + ctx.from = ct->tuplehash[!dir].tuple.src.u3.ip; + ctx.to = ct->tuplehash[dir].tuple.dst.u3.ip; + } + + if (ctx.from == ctx.to) + return NF_ACCEPT; + + ctx.begin = (unsigned char *)udph + sizeof(struct udphdr); + ctx.check = &udph->check; + ret = asn1_ber_decoder(&nf_nat_snmp_basic_decoder, &ctx, data, datalen); + if (ret < 0) { + nf_ct_helper_log(skb, ct, "parser failed\n"); + return NF_DROP; + } + + return NF_ACCEPT; +} + +/* We don't actually set up expectations, just adjust internal IP + * addresses if this is being NATted + */ +static int help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + int dir = CTINFO2DIR(ctinfo); + unsigned int ret; + const struct iphdr *iph = ip_hdr(skb); + const struct udphdr *udph = (struct udphdr *)((__be32 *)iph + iph->ihl); + + /* SNMP replies and originating SNMP traps get mangled */ + if (udph->source == htons(SNMP_PORT) && dir != IP_CT_DIR_REPLY) + return NF_ACCEPT; + if (udph->dest == htons(SNMP_TRAP_PORT) && dir != IP_CT_DIR_ORIGINAL) + return NF_ACCEPT; + + /* No NAT? */ + if (!(ct->status & IPS_NAT_MASK)) + return NF_ACCEPT; + + /* Make sure the packet length is ok. 
So far, we were only guaranteed + * to have a valid length IP header plus 8 bytes, which means we have + * enough room for a UDP header. Just verify the UDP length field so we + * can mess around with the payload. + */ + if (ntohs(udph->len) != skb->len - (iph->ihl << 2)) { + nf_ct_helper_log(skb, ct, "dropping malformed packet\n"); + return NF_DROP; + } + + if (skb_ensure_writable(skb, skb->len)) { + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + return NF_DROP; + } + + spin_lock_bh(&snmp_lock); + ret = snmp_translate(ct, dir, skb); + spin_unlock_bh(&snmp_lock); + return ret; +} + +static const struct nf_conntrack_expect_policy snmp_exp_policy = { + .max_expected = 0, + .timeout = 180, +}; + +static struct nf_conntrack_helper snmp_trap_helper __read_mostly = { + .me = THIS_MODULE, + .help = help, + .expect_policy = &snmp_exp_policy, + .name = "snmp_trap", + .tuple.src.l3num = AF_INET, + .tuple.src.u.udp.port = cpu_to_be16(SNMP_TRAP_PORT), + .tuple.dst.protonum = IPPROTO_UDP, +}; + +static int __init nf_nat_snmp_basic_init(void) +{ + BUG_ON(nf_nat_snmp_hook != NULL); + RCU_INIT_POINTER(nf_nat_snmp_hook, help); + + return nf_conntrack_helper_register(&snmp_trap_helper); +} + +static void __exit nf_nat_snmp_basic_fini(void) +{ + RCU_INIT_POINTER(nf_nat_snmp_hook, NULL); + synchronize_rcu(); + nf_conntrack_helper_unregister(&snmp_trap_helper); +} + +module_init(nf_nat_snmp_basic_init); +module_exit(nf_nat_snmp_basic_fini); diff --git a/net/ipv4/netfilter/nf_reject_ipv4.c b/net/ipv4/netfilter/nf_reject_ipv4.c new file mode 100644 index 000000000..407376299 --- /dev/null +++ b/net/ipv4/netfilter/nf_reject_ipv4.c @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +static int nf_reject_iphdr_validate(struct sk_buff *skb) +{ + struct iphdr *iph; + u32 len; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + return 0; + + iph = ip_hdr(skb); + if (iph->ihl < 5 || iph->version != 4) + return 0; + + len = ntohs(iph->tot_len); + if (skb->len < len) + return 0; + else if (len < (iph->ihl*4)) + return 0; + + if (!pskb_may_pull(skb, iph->ihl*4)) + return 0; + + return 1; +} + +struct sk_buff *nf_reject_skb_v4_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + const struct tcphdr *oth; + struct sk_buff *nskb; + struct iphdr *niph; + struct tcphdr _oth; + + if (!nf_reject_iphdr_validate(oldskb)) + return NULL; + + oth = nf_reject_ip_tcphdr_get(oldskb, &_oth, hook); + if (!oth) + return NULL; + + nskb = alloc_skb(sizeof(struct iphdr) + sizeof(struct tcphdr) + + LL_MAX_HEADER, GFP_ATOMIC); + if (!nskb) + return NULL; + + nskb->dev = (struct net_device *)dev; + + skb_reserve(nskb, LL_MAX_HEADER); + niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); + nf_reject_ip_tcphdr_put(nskb, oldskb, oth); + niph->tot_len = htons(nskb->len); + ip_send_check(niph); + + return nskb; +} +EXPORT_SYMBOL_GPL(nf_reject_skb_v4_tcp_reset); + +struct sk_buff *nf_reject_skb_v4_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; + struct iphdr *niph; + struct icmphdr *icmph; + unsigned int len; + int dataoff; + __wsum csum; + u8 proto; + + if (!nf_reject_iphdr_validate(oldskb)) + return NULL; + + /* IP header checks: fragment. 
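A non-zero fragment offset means this is not the first fragment, so it carries no transport header and no useful ICMP error can be built from it.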
*/ + if (ip_hdr(oldskb)->frag_off & htons(IP_OFFSET)) + return NULL; + + /* RFC says return as much as we can without exceeding 576 bytes. */ + len = min_t(unsigned int, 536, oldskb->len); + + if (!pskb_may_pull(oldskb, len)) + return NULL; + + if (pskb_trim_rcsum(oldskb, ntohs(ip_hdr(oldskb)->tot_len))) + return NULL; + + dataoff = ip_hdrlen(oldskb); + proto = ip_hdr(oldskb)->protocol; + + if (!skb_csum_unnecessary(oldskb) && + nf_reject_verify_csum(oldskb, dataoff, proto) && + nf_ip_checksum(oldskb, hook, ip_hdrlen(oldskb), proto)) + return NULL; + + nskb = alloc_skb(sizeof(struct iphdr) + sizeof(struct icmphdr) + + LL_MAX_HEADER + len, GFP_ATOMIC); + if (!nskb) + return NULL; + + nskb->dev = (struct net_device *)dev; + + skb_reserve(nskb, LL_MAX_HEADER); + niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_ICMP, + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); + + skb_reset_transport_header(nskb); + icmph = skb_put_zero(nskb, sizeof(struct icmphdr)); + icmph->type = ICMP_DEST_UNREACH; + icmph->code = code; + + skb_put_data(nskb, skb_network_header(oldskb), len); + + csum = csum_partial((void *)icmph, len + sizeof(struct icmphdr), 0); + icmph->checksum = csum_fold(csum); + + niph->tot_len = htons(nskb->len); + ip_send_check(niph); + + return nskb; +} +EXPORT_SYMBOL_GPL(nf_reject_skb_v4_unreach); + +const struct tcphdr *nf_reject_ip_tcphdr_get(struct sk_buff *oldskb, + struct tcphdr *_oth, int hook) +{ + const struct tcphdr *oth; + + /* IP header checks: fragment. */ + if (ip_hdr(oldskb)->frag_off & htons(IP_OFFSET)) + return NULL; + + if (ip_hdr(oldskb)->protocol != IPPROTO_TCP) + return NULL; + + oth = skb_header_pointer(oldskb, ip_hdrlen(oldskb), + sizeof(struct tcphdr), _oth); + if (oth == NULL) + return NULL; + + /* No RST for RST. */ + if (oth->rst) + return NULL; + + /* Check checksum */ + if (nf_ip_checksum(oldskb, hook, ip_hdrlen(oldskb), IPPROTO_TCP)) + return NULL; + + return oth; +} +EXPORT_SYMBOL_GPL(nf_reject_ip_tcphdr_get); + +struct iphdr *nf_reject_iphdr_put(struct sk_buff *nskb, + const struct sk_buff *oldskb, + __u8 protocol, int ttl) +{ + struct iphdr *niph, *oiph = ip_hdr(oldskb); + + skb_reset_network_header(nskb); + niph = skb_put(nskb, sizeof(struct iphdr)); + niph->version = 4; + niph->ihl = sizeof(struct iphdr) / 4; + niph->tos = 0; + niph->id = 0; + niph->frag_off = htons(IP_DF); + niph->protocol = protocol; + niph->check = 0; + niph->saddr = oiph->daddr; + niph->daddr = oiph->saddr; + niph->ttl = ttl; + + nskb->protocol = htons(ETH_P_IP); + + return niph; +} +EXPORT_SYMBOL_GPL(nf_reject_iphdr_put); + +void nf_reject_ip_tcphdr_put(struct sk_buff *nskb, const struct sk_buff *oldskb, + const struct tcphdr *oth) +{ + struct iphdr *niph = ip_hdr(nskb); + struct tcphdr *tcph; + + skb_reset_transport_header(nskb); + tcph = skb_put_zero(nskb, sizeof(struct tcphdr)); + tcph->source = oth->dest; + tcph->dest = oth->source; + tcph->doff = sizeof(struct tcphdr) / 4; + + if (oth->ack) { + tcph->seq = oth->ack_seq; + } else { + tcph->ack_seq = htonl(ntohl(oth->seq) + oth->syn + oth->fin + + oldskb->len - ip_hdrlen(oldskb) - + (oth->doff << 2)); + tcph->ack = 1; + } + + tcph->rst = 1; + tcph->check = ~tcp_v4_check(sizeof(struct tcphdr), niph->saddr, + niph->daddr, 0); + nskb->ip_summed = CHECKSUM_PARTIAL; + nskb->csum_start = (unsigned char *)tcph - nskb->head; + nskb->csum_offset = offsetof(struct tcphdr, check); +} +EXPORT_SYMBOL_GPL(nf_reject_ip_tcphdr_put); + +static int nf_reject_fill_skb_dst(struct sk_buff *skb_in) +{ + struct dst_entry *dst = NULL; + struct flowi fl; + + 
memset(&fl, 0, sizeof(struct flowi)); + fl.u.ip4.daddr = ip_hdr(skb_in)->saddr; + nf_ip_route(dev_net(skb_in->dev), &dst, &fl, false); + if (!dst) + return -1; + + skb_dst_set(skb_in, dst); + return 0; +} + +/* Send RST reply */ +void nf_send_reset(struct net *net, struct sock *sk, struct sk_buff *oldskb, + int hook) +{ + struct sk_buff *nskb; + struct iphdr *niph; + const struct tcphdr *oth; + struct tcphdr _oth; + + oth = nf_reject_ip_tcphdr_get(oldskb, &_oth, hook); + if (!oth) + return; + + if ((hook == NF_INET_PRE_ROUTING || hook == NF_INET_INGRESS) && + nf_reject_fill_skb_dst(oldskb) < 0) + return; + + if (skb_rtable(oldskb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) + return; + + nskb = alloc_skb(sizeof(struct iphdr) + sizeof(struct tcphdr) + + LL_MAX_HEADER, GFP_ATOMIC); + if (!nskb) + return; + + /* ip_route_me_harder expects skb->dst to be set */ + skb_dst_set_noref(nskb, skb_dst(oldskb)); + + nskb->mark = IP4_REPLY_MARK(net, oldskb->mark); + + skb_reserve(nskb, LL_MAX_HEADER); + niph = nf_reject_iphdr_put(nskb, oldskb, IPPROTO_TCP, + ip4_dst_hoplimit(skb_dst(nskb))); + nf_reject_ip_tcphdr_put(nskb, oldskb, oth); + if (ip_route_me_harder(net, sk, nskb, RTN_UNSPEC)) + goto free_nskb; + + niph = ip_hdr(nskb); + + /* "Never happens" */ + if (nskb->len > dst_mtu(skb_dst(nskb))) + goto free_nskb; + + nf_ct_attach(nskb, oldskb); + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + /* If we use ip_local_out for bridged traffic, the MAC source on + * the RST will be ours, instead of the destination's. This confuses + * some routers/firewalls, and they drop the packet. So we need to + * build the eth header using the original destination's MAC as the + * source, and send the RST packet directly. + */ + if (nf_bridge_info_exists(oldskb)) { + struct ethhdr *oeth = eth_hdr(oldskb); + struct net_device *br_indev; + + br_indev = nf_bridge_get_physindev(oldskb, net); + if (!br_indev) + goto free_nskb; + + nskb->dev = br_indev; + niph->tot_len = htons(nskb->len); + ip_send_check(niph); + if (dev_hard_header(nskb, nskb->dev, ntohs(nskb->protocol), + oeth->h_source, oeth->h_dest, nskb->len) < 0) + goto free_nskb; + dev_queue_xmit(nskb); + } else +#endif + ip_local_out(net, nskb->sk, nskb); + + return; + + free_nskb: + kfree_skb(nskb); +} +EXPORT_SYMBOL_GPL(nf_send_reset); + +void nf_send_unreach(struct sk_buff *skb_in, int code, int hook) +{ + struct iphdr *iph = ip_hdr(skb_in); + int dataoff = ip_hdrlen(skb_in); + u8 proto = iph->protocol; + + if (iph->frag_off & htons(IP_OFFSET)) + return; + + if ((hook == NF_INET_PRE_ROUTING || hook == NF_INET_INGRESS) && + nf_reject_fill_skb_dst(skb_in) < 0) + return; + + if (skb_csum_unnecessary(skb_in) || + !nf_reject_verify_csum(skb_in, dataoff, proto)) { + icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); + return; + } + + if (nf_ip_checksum(skb_in, hook, dataoff, proto) == 0) + icmp_send(skb_in, ICMP_DEST_UNREACH, code, 0); +} +EXPORT_SYMBOL_GPL(nf_send_unreach); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/netfilter/nf_socket_ipv4.c b/net/ipv4/netfilter/nf_socket_ipv4.c new file mode 100644 index 000000000..a1350fc25 --- /dev/null +++ b/net/ipv4/netfilter/nf_socket_ipv4.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007-2008 BalaBit IT Ltd. 
+ * Author: Krisztian Kovacs + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +static int +extract_icmp4_fields(const struct sk_buff *skb, u8 *protocol, + __be32 *raddr, __be32 *laddr, + __be16 *rport, __be16 *lport) +{ + unsigned int outside_hdrlen = ip_hdrlen(skb); + struct iphdr *inside_iph, _inside_iph; + struct icmphdr *icmph, _icmph; + __be16 *ports, _ports[2]; + + icmph = skb_header_pointer(skb, outside_hdrlen, + sizeof(_icmph), &_icmph); + if (icmph == NULL) + return 1; + + if (!icmp_is_err(icmph->type)) + return 1; + + inside_iph = skb_header_pointer(skb, outside_hdrlen + + sizeof(struct icmphdr), + sizeof(_inside_iph), &_inside_iph); + if (inside_iph == NULL) + return 1; + + if (inside_iph->protocol != IPPROTO_TCP && + inside_iph->protocol != IPPROTO_UDP) + return 1; + + ports = skb_header_pointer(skb, outside_hdrlen + + sizeof(struct icmphdr) + + (inside_iph->ihl << 2), + sizeof(_ports), &_ports); + if (ports == NULL) + return 1; + + /* the inside IP packet is the one quoted from our side, thus + * its saddr is the local address */ + *protocol = inside_iph->protocol; + *laddr = inside_iph->saddr; + *lport = ports[0]; + *raddr = inside_iph->daddr; + *rport = ports[1]; + + return 0; +} + +static struct sock * +nf_socket_get_sock_v4(struct net *net, struct sk_buff *skb, const int doff, + const u8 protocol, + const __be32 saddr, const __be32 daddr, + const __be16 sport, const __be16 dport, + const struct net_device *in) +{ + switch (protocol) { + case IPPROTO_TCP: + return inet_lookup(net, net->ipv4.tcp_death_row.hashinfo, + skb, doff, saddr, sport, daddr, dport, + in->ifindex); + case IPPROTO_UDP: + return udp4_lib_lookup(net, saddr, sport, daddr, dport, + in->ifindex); + } + return NULL; +} + +struct sock *nf_sk_lookup_slow_v4(struct net *net, const struct sk_buff *skb, + const struct net_device *indev) +{ + __be32 daddr, saddr; + __be16 dport, sport; + const struct iphdr *iph = ip_hdr(skb); + struct sk_buff *data_skb = NULL; + u8 protocol; +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + enum ip_conntrack_info ctinfo; + struct nf_conn const *ct; +#endif + int doff = 0; + + if (iph->protocol == IPPROTO_UDP || iph->protocol == IPPROTO_TCP) { + struct tcphdr _hdr; + struct udphdr *hp; + + hp = skb_header_pointer(skb, ip_hdrlen(skb), + iph->protocol == IPPROTO_UDP ? + sizeof(*hp) : sizeof(_hdr), &_hdr); + if (hp == NULL) + return NULL; + + protocol = iph->protocol; + saddr = iph->saddr; + sport = hp->source; + daddr = iph->daddr; + dport = hp->dest; + data_skb = (struct sk_buff *)skb; + doff = iph->protocol == IPPROTO_TCP ? + ip_hdrlen(skb) + __tcp_hdrlen((struct tcphdr *)hp) : + ip_hdrlen(skb) + sizeof(*hp); + + } else if (iph->protocol == IPPROTO_ICMP) { + if (extract_icmp4_fields(skb, &protocol, &saddr, &daddr, + &sport, &dport)) + return NULL; + } else { + return NULL; + } + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + /* Do the lookup with the original socket address in + * case this is a reply packet of an established + * SNAT-ted connection. + */ + ct = nf_ct_get(skb, &ctinfo); + if (ct && + ((iph->protocol != IPPROTO_ICMP && + ctinfo == IP_CT_ESTABLISHED_REPLY) || + (iph->protocol == IPPROTO_ICMP && + ctinfo == IP_CT_RELATED_REPLY)) && + (ct->status & IPS_SRC_NAT_DONE)) { + + daddr = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.ip; + dport = (iph->protocol == IPPROTO_TCP) ? 
+ ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.tcp.port : + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.udp.port; + } +#endif + + return nf_socket_get_sock_v4(net, data_skb, doff, protocol, saddr, + daddr, sport, dport, indev); +} +EXPORT_SYMBOL_GPL(nf_sk_lookup_slow_v4); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Krisztian Kovacs, Balazs Scheidler"); +MODULE_DESCRIPTION("Netfilter IPv4 socket lookup infrastructure"); diff --git a/net/ipv4/netfilter/nf_tproxy_ipv4.c b/net/ipv4/netfilter/nf_tproxy_ipv4.c new file mode 100644 index 000000000..69e331799 --- /dev/null +++ b/net/ipv4/netfilter/nf_tproxy_ipv4.c @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007-2008 BalaBit IT Ltd. + * Author: Krisztian Kovacs + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct sock * +nf_tproxy_handle_time_wait4(struct net *net, struct sk_buff *skb, + __be32 laddr, __be16 lport, struct sock *sk) +{ + const struct iphdr *iph = ip_hdr(skb); + struct tcphdr _hdr, *hp; + + hp = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(_hdr), &_hdr); + if (hp == NULL) { + inet_twsk_put(inet_twsk(sk)); + return NULL; + } + + if (hp->syn && !hp->rst && !hp->ack && !hp->fin) { + /* SYN to a TIME_WAIT socket, we'd rather redirect it + * to a listener socket if there's one */ + struct sock *sk2; + + sk2 = nf_tproxy_get_sock_v4(net, skb, iph->protocol, + iph->saddr, laddr ? laddr : iph->daddr, + hp->source, lport ? lport : hp->dest, + skb->dev, NF_TPROXY_LOOKUP_LISTENER); + if (sk2) { + nf_tproxy_twsk_deschedule_put(inet_twsk(sk)); + sk = sk2; + } + } + + return sk; +} +EXPORT_SYMBOL_GPL(nf_tproxy_handle_time_wait4); + +__be32 nf_tproxy_laddr4(struct sk_buff *skb, __be32 user_laddr, __be32 daddr) +{ + const struct in_ifaddr *ifa; + struct in_device *indev; + __be32 laddr; + + if (user_laddr) + return user_laddr; + + laddr = 0; + indev = __in_dev_get_rcu(skb->dev); + + in_dev_for_each_ifa_rcu(ifa, indev) { + if (ifa->ifa_flags & IFA_F_SECONDARY) + continue; + + laddr = ifa->ifa_local; + break; + } + + return laddr ? 
laddr : daddr; +} +EXPORT_SYMBOL_GPL(nf_tproxy_laddr4); + +struct sock * +nf_tproxy_get_sock_v4(struct net *net, struct sk_buff *skb, + const u8 protocol, + const __be32 saddr, const __be32 daddr, + const __be16 sport, const __be16 dport, + const struct net_device *in, + const enum nf_tproxy_lookup_t lookup_type) +{ + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + struct sock *sk; + + switch (protocol) { + case IPPROTO_TCP: { + struct tcphdr _hdr, *hp; + + hp = skb_header_pointer(skb, ip_hdrlen(skb), + sizeof(struct tcphdr), &_hdr); + if (hp == NULL) + return NULL; + + switch (lookup_type) { + case NF_TPROXY_LOOKUP_LISTENER: + sk = inet_lookup_listener(net, hinfo, skb, + ip_hdrlen(skb) + __tcp_hdrlen(hp), + saddr, sport, daddr, dport, + in->ifindex, 0); + + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + /* NOTE: we return listeners even if bound to + * 0.0.0.0, those are filtered out in + * xt_socket, since xt_TPROXY needs 0 bound + * listeners too + */ + break; + case NF_TPROXY_LOOKUP_ESTABLISHED: + sk = inet_lookup_established(net, hinfo, saddr, sport, + daddr, dport, in->ifindex); + break; + default: + BUG(); + } + break; + } + case IPPROTO_UDP: + sk = udp4_lib_lookup(net, saddr, sport, daddr, dport, + in->ifindex); + if (sk) { + int connected = (sk->sk_state == TCP_ESTABLISHED); + int wildcard = (inet_sk(sk)->inet_rcv_saddr == 0); + + /* NOTE: we return listeners even if bound to + * 0.0.0.0, those are filtered out in + * xt_socket, since xt_TPROXY needs 0 bound + * listeners too + */ + if ((lookup_type == NF_TPROXY_LOOKUP_ESTABLISHED && + (!connected || wildcard)) || + (lookup_type == NF_TPROXY_LOOKUP_LISTENER && connected)) { + sock_put(sk); + sk = NULL; + } + } + break; + default: + WARN_ON(1); + sk = NULL; + } + + pr_debug("tproxy socket lookup: proto %u %08x:%u -> %08x:%u, lookup type: %d, sock %p\n", + protocol, ntohl(saddr), ntohs(sport), ntohl(daddr), ntohs(dport), lookup_type, sk); + + return sk; +} +EXPORT_SYMBOL_GPL(nf_tproxy_get_sock_v4); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Balazs Scheidler, Krisztian Kovacs"); +MODULE_DESCRIPTION("Netfilter IPv4 transparent proxy support"); diff --git a/net/ipv4/netfilter/nft_dup_ipv4.c b/net/ipv4/netfilter/nft_dup_ipv4.c new file mode 100644 index 000000000..0bcd6aee6 --- /dev/null +++ b/net/ipv4/netfilter/nft_dup_ipv4.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_dup_ipv4 { + u8 sreg_addr; + u8 sreg_dev; +}; + +static void nft_dup_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_dup_ipv4 *priv = nft_expr_priv(expr); + struct in_addr gw = { + .s_addr = (__force __be32)regs->data[priv->sreg_addr], + }; + int oif = priv->sreg_dev ? 
regs->data[priv->sreg_dev] : -1; + + nf_dup_ipv4(nft_net(pkt), pkt->skb, nft_hook(pkt), &gw, oif); +} + +static int nft_dup_ipv4_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_dup_ipv4 *priv = nft_expr_priv(expr); + int err; + + if (tb[NFTA_DUP_SREG_ADDR] == NULL) + return -EINVAL; + + err = nft_parse_register_load(tb[NFTA_DUP_SREG_ADDR], &priv->sreg_addr, + sizeof(struct in_addr)); + if (err < 0) + return err; + + if (tb[NFTA_DUP_SREG_DEV]) + err = nft_parse_register_load(tb[NFTA_DUP_SREG_DEV], + &priv->sreg_dev, sizeof(int)); + + return err; +} + +static int nft_dup_ipv4_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_dup_ipv4 *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_DUP_SREG_ADDR, priv->sreg_addr)) + goto nla_put_failure; + if (priv->sreg_dev && + nft_dump_register(skb, NFTA_DUP_SREG_DEV, priv->sreg_dev)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_dup_ipv4_type; +static const struct nft_expr_ops nft_dup_ipv4_ops = { + .type = &nft_dup_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_dup_ipv4)), + .eval = nft_dup_ipv4_eval, + .init = nft_dup_ipv4_init, + .dump = nft_dup_ipv4_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nla_policy nft_dup_ipv4_policy[NFTA_DUP_MAX + 1] = { + [NFTA_DUP_SREG_ADDR] = { .type = NLA_U32 }, + [NFTA_DUP_SREG_DEV] = { .type = NLA_U32 }, +}; + +static struct nft_expr_type nft_dup_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "dup", + .ops = &nft_dup_ipv4_ops, + .policy = nft_dup_ipv4_policy, + .maxattr = NFTA_DUP_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_dup_ipv4_module_init(void) +{ + return nft_register_expr(&nft_dup_ipv4_type); +} + +static void __exit nft_dup_ipv4_module_exit(void) +{ + nft_unregister_expr(&nft_dup_ipv4_type); +} + +module_init(nft_dup_ipv4_module_init); +module_exit(nft_dup_ipv4_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET, "dup"); +MODULE_DESCRIPTION("IPv4 nftables packet duplication support"); diff --git a/net/ipv4/netfilter/nft_fib_ipv4.c b/net/ipv4/netfilter/nft_fib_ipv4.c new file mode 100644 index 000000000..fc65d69f2 --- /dev/null +++ b/net/ipv4/netfilter/nft_fib_ipv4.c @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +/* don't try to find route from mcast/bcast/zeronet */ +static __be32 get_saddr(__be32 addr) +{ + if (ipv4_is_multicast(addr) || ipv4_is_lbcast(addr) || + ipv4_is_zeronet(addr)) + return 0; + return addr; +} + +#define DSCP_BITS 0xfc + +void nft_fib4_eval_type(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + int noff = skb_network_offset(pkt->skb); + u32 *dst = ®s->data[priv->dreg]; + const struct net_device *dev = NULL; + struct iphdr *iph, _iph; + __be32 addr; + + if (priv->flags & NFTA_FIB_F_IIF) + dev = nft_in(pkt); + else if (priv->flags & NFTA_FIB_F_OIF) + dev = nft_out(pkt); + + iph = skb_header_pointer(pkt->skb, noff, sizeof(_iph), &_iph); + if (!iph) { + regs->verdict.code = NFT_BREAK; + return; + } + + if (priv->flags & NFTA_FIB_F_DADDR) + addr = iph->daddr; + else + addr = iph->saddr; + + *dst = inet_dev_addr_type(nft_net(pkt), dev, addr); +} +EXPORT_SYMBOL_GPL(nft_fib4_eval_type); + +void 
nft_fib4_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + int noff = skb_network_offset(pkt->skb); + u32 *dest = ®s->data[priv->dreg]; + struct iphdr *iph, _iph; + struct fib_result res; + struct flowi4 fl4 = { + .flowi4_scope = RT_SCOPE_UNIVERSE, + .flowi4_iif = LOOPBACK_IFINDEX, + .flowi4_uid = sock_net_uid(nft_net(pkt), NULL), + }; + const struct net_device *oif; + const struct net_device *found; + + /* + * Do not set flowi4_oif, it restricts results (for example, asking + * for oif 3 will get RTN_UNICAST result even if the daddr exits + * on another interface. + * + * Search results for the desired outinterface instead. + */ + if (priv->flags & NFTA_FIB_F_OIF) + oif = nft_out(pkt); + else if (priv->flags & NFTA_FIB_F_IIF) + oif = nft_in(pkt); + else + oif = NULL; + + if (priv->flags & NFTA_FIB_F_IIF) + fl4.flowi4_l3mdev = l3mdev_master_ifindex_rcu(oif); + + if (nft_hook(pkt) == NF_INET_PRE_ROUTING && + nft_fib_is_loopback(pkt->skb, nft_in(pkt))) { + nft_fib_store_result(dest, priv, nft_in(pkt)); + return; + } + + iph = skb_header_pointer(pkt->skb, noff, sizeof(_iph), &_iph); + if (!iph) { + regs->verdict.code = NFT_BREAK; + return; + } + + if (ipv4_is_zeronet(iph->saddr)) { + if (ipv4_is_lbcast(iph->daddr) || + ipv4_is_local_multicast(iph->daddr)) { + nft_fib_store_result(dest, priv, pkt->skb->dev); + return; + } + } + + if (priv->flags & NFTA_FIB_F_MARK) + fl4.flowi4_mark = pkt->skb->mark; + + fl4.flowi4_tos = iph->tos & DSCP_BITS; + + if (priv->flags & NFTA_FIB_F_DADDR) { + fl4.daddr = iph->daddr; + fl4.saddr = get_saddr(iph->saddr); + } else { + if (nft_hook(pkt) == NF_INET_FORWARD && + priv->flags & NFTA_FIB_F_IIF) + fl4.flowi4_iif = nft_out(pkt)->ifindex; + + fl4.daddr = iph->saddr; + fl4.saddr = get_saddr(iph->daddr); + } + + *dest = 0; + + if (fib_lookup(nft_net(pkt), &fl4, &res, FIB_LOOKUP_IGNORE_LINKSTATE)) + return; + + switch (res.type) { + case RTN_UNICAST: + break; + case RTN_LOCAL: /* Should not see RTN_LOCAL here */ + return; + default: + break; + } + + if (!oif) { + found = FIB_RES_DEV(res); + } else { + if (!fib_info_nh_uses_dev(res.fi, oif)) + return; + + found = oif; + } + + nft_fib_store_result(dest, priv, found); +} +EXPORT_SYMBOL_GPL(nft_fib4_eval); + +static struct nft_expr_type nft_fib4_type; + +static const struct nft_expr_ops nft_fib4_type_ops = { + .type = &nft_fib4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib4_eval_type, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static const struct nft_expr_ops nft_fib4_ops = { + .type = &nft_fib4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib4_eval, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static const struct nft_expr_ops * +nft_fib4_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + enum nft_fib_result result; + + if (!tb[NFTA_FIB_RESULT]) + return ERR_PTR(-EINVAL); + + result = ntohl(nla_get_be32(tb[NFTA_FIB_RESULT])); + + switch (result) { + case NFT_FIB_RESULT_OIF: + return &nft_fib4_ops; + case NFT_FIB_RESULT_OIFNAME: + return &nft_fib4_ops; + case NFT_FIB_RESULT_ADDRTYPE: + return &nft_fib4_type_ops; + default: + return ERR_PTR(-EOPNOTSUPP); + } +} + +static struct nft_expr_type nft_fib4_type __read_mostly = { + .name = "fib", + .select_ops = nft_fib4_select_ops, + .policy = nft_fib_policy, + 
.maxattr = NFTA_FIB_MAX, + .family = NFPROTO_IPV4, + .owner = THIS_MODULE, +}; + +static int __init nft_fib4_module_init(void) +{ + return nft_register_expr(&nft_fib4_type); +} + +static void __exit nft_fib4_module_exit(void) +{ + nft_unregister_expr(&nft_fib4_type); +} + +module_init(nft_fib4_module_init); +module_exit(nft_fib4_module_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_ALIAS_NFT_AF_EXPR(2, "fib"); +MODULE_DESCRIPTION("nftables fib / ip route lookup support"); diff --git a/net/ipv4/netfilter/nft_reject_ipv4.c b/net/ipv4/netfilter/nft_reject_ipv4.c new file mode 100644 index 000000000..6cb213bb7 --- /dev/null +++ b/net/ipv4/netfilter/nft_reject_ipv4.c @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2013 Eric Leblond + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void nft_reject_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_reject *priv = nft_expr_priv(expr); + + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nf_send_unreach(pkt->skb, priv->icmp_code, nft_hook(pkt)); + break; + case NFT_REJECT_TCP_RST: + nf_send_reset(nft_net(pkt), nft_sk(pkt), pkt->skb, + nft_hook(pkt)); + break; + default: + break; + } + + regs->verdict.code = NF_DROP; +} + +static struct nft_expr_type nft_reject_ipv4_type; +static const struct nft_expr_ops nft_reject_ipv4_ops = { + .type = &nft_reject_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_reject)), + .eval = nft_reject_ipv4_eval, + .init = nft_reject_init, + .dump = nft_reject_dump, + .validate = nft_reject_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_reject_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "reject", + .ops = &nft_reject_ipv4_ops, + .policy = nft_reject_policy, + .maxattr = NFTA_REJECT_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_reject_ipv4_module_init(void) +{ + return nft_register_expr(&nft_reject_ipv4_type); +} + +static void __exit nft_reject_ipv4_module_exit(void) +{ + nft_unregister_expr(&nft_reject_ipv4_type); +} + +module_init(nft_reject_ipv4_module_init); +module_exit(nft_reject_ipv4_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET, "reject"); +MODULE_DESCRIPTION("IPv4 packet rejection for nftables"); diff --git a/net/ipv4/netlink.c b/net/ipv4/netlink.c new file mode 100644 index 000000000..b920e1bdc --- /dev/null +++ b/net/ipv4/netlink.c @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include + +int rtm_getroute_parse_ip_proto(struct nlattr *attr, u8 *ip_proto, u8 family, + struct netlink_ext_ack *extack) +{ + *ip_proto = nla_get_u8(attr); + + switch (*ip_proto) { + case IPPROTO_TCP: + case IPPROTO_UDP: + return 0; + case IPPROTO_ICMP: + if (family != AF_INET) + break; + return 0; +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + if (family != AF_INET6) + break; + return 0; +#endif + } + NL_SET_ERR_MSG(extack, "Unsupported ip proto"); + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(rtm_getroute_parse_ip_proto); diff --git a/net/ipv4/nexthop.c b/net/ipv4/nexthop.c new file mode 100644 index 000000000..be5498f5d --- /dev/null +++ b/net/ipv4/nexthop.c @@ -0,0 +1,3775 @@ +// SPDX-License-Identifier: GPL-2.0 +/* 
Generic nexthop implementation + * + * Copyright (c) 2017-19 Cumulus Networks + * Copyright (c) 2017-19 David Ahern + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NH_RES_DEFAULT_IDLE_TIMER (120 * HZ) +#define NH_RES_DEFAULT_UNBALANCED_TIMER 0 /* No forced rebalancing. */ + +static void remove_nexthop(struct net *net, struct nexthop *nh, + struct nl_info *nlinfo); + +#define NH_DEV_HASHBITS 8 +#define NH_DEV_HASHSIZE (1U << NH_DEV_HASHBITS) + +static const struct nla_policy rtm_nh_policy_new[] = { + [NHA_ID] = { .type = NLA_U32 }, + [NHA_GROUP] = { .type = NLA_BINARY }, + [NHA_GROUP_TYPE] = { .type = NLA_U16 }, + [NHA_BLACKHOLE] = { .type = NLA_FLAG }, + [NHA_OIF] = { .type = NLA_U32 }, + [NHA_GATEWAY] = { .type = NLA_BINARY }, + [NHA_ENCAP_TYPE] = { .type = NLA_U16 }, + [NHA_ENCAP] = { .type = NLA_NESTED }, + [NHA_FDB] = { .type = NLA_FLAG }, + [NHA_RES_GROUP] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy rtm_nh_policy_get[] = { + [NHA_ID] = { .type = NLA_U32 }, +}; + +static const struct nla_policy rtm_nh_policy_dump[] = { + [NHA_OIF] = { .type = NLA_U32 }, + [NHA_GROUPS] = { .type = NLA_FLAG }, + [NHA_MASTER] = { .type = NLA_U32 }, + [NHA_FDB] = { .type = NLA_FLAG }, +}; + +static const struct nla_policy rtm_nh_res_policy_new[] = { + [NHA_RES_GROUP_BUCKETS] = { .type = NLA_U16 }, + [NHA_RES_GROUP_IDLE_TIMER] = { .type = NLA_U32 }, + [NHA_RES_GROUP_UNBALANCED_TIMER] = { .type = NLA_U32 }, +}; + +static const struct nla_policy rtm_nh_policy_dump_bucket[] = { + [NHA_ID] = { .type = NLA_U32 }, + [NHA_OIF] = { .type = NLA_U32 }, + [NHA_MASTER] = { .type = NLA_U32 }, + [NHA_RES_BUCKET] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy rtm_nh_res_bucket_policy_dump[] = { + [NHA_RES_BUCKET_NH_ID] = { .type = NLA_U32 }, +}; + +static const struct nla_policy rtm_nh_policy_get_bucket[] = { + [NHA_ID] = { .type = NLA_U32 }, + [NHA_RES_BUCKET] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy rtm_nh_res_bucket_policy_get[] = { + [NHA_RES_BUCKET_INDEX] = { .type = NLA_U16 }, +}; + +static bool nexthop_notifiers_is_empty(struct net *net) +{ + return !net->nexthop.notifier_chain.head; +} + +static void +__nh_notifier_single_info_init(struct nh_notifier_single_info *nh_info, + const struct nh_info *nhi) +{ + nh_info->dev = nhi->fib_nhc.nhc_dev; + nh_info->gw_family = nhi->fib_nhc.nhc_gw_family; + if (nh_info->gw_family == AF_INET) + nh_info->ipv4 = nhi->fib_nhc.nhc_gw.ipv4; + else if (nh_info->gw_family == AF_INET6) + nh_info->ipv6 = nhi->fib_nhc.nhc_gw.ipv6; + + nh_info->is_reject = nhi->reject_nh; + nh_info->is_fdb = nhi->fdb_nh; + nh_info->has_encap = !!nhi->fib_nhc.nhc_lwtstate; +} + +static int nh_notifier_single_info_init(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + struct nh_info *nhi = rtnl_dereference(nh->nh_info); + + info->type = NH_NOTIFIER_INFO_TYPE_SINGLE; + info->nh = kzalloc(sizeof(*info->nh), GFP_KERNEL); + if (!info->nh) + return -ENOMEM; + + __nh_notifier_single_info_init(info->nh, nhi); + + return 0; +} + +static void nh_notifier_single_info_fini(struct nh_notifier_info *info) +{ + kfree(info->nh); +} + +static int nh_notifier_mpath_info_init(struct nh_notifier_info *info, + struct nh_group *nhg) +{ + u16 num_nh = nhg->num_nh; + int i; + + info->type = NH_NOTIFIER_INFO_TYPE_GRP; + info->nh_grp = kzalloc(struct_size(info->nh_grp, nh_entries, num_nh), + GFP_KERNEL); + if (!info->nh_grp) + return -ENOMEM; + + info->nh_grp->num_nh = num_nh; + 
info->nh_grp->is_fdb = nhg->fdb_nh; + + for (i = 0; i < num_nh; i++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + struct nh_info *nhi; + + nhi = rtnl_dereference(nhge->nh->nh_info); + info->nh_grp->nh_entries[i].id = nhge->nh->id; + info->nh_grp->nh_entries[i].weight = nhge->weight; + __nh_notifier_single_info_init(&info->nh_grp->nh_entries[i].nh, + nhi); + } + + return 0; +} + +static int nh_notifier_res_table_info_init(struct nh_notifier_info *info, + struct nh_group *nhg) +{ + struct nh_res_table *res_table = rtnl_dereference(nhg->res_table); + u16 num_nh_buckets = res_table->num_nh_buckets; + unsigned long size; + u16 i; + + info->type = NH_NOTIFIER_INFO_TYPE_RES_TABLE; + size = struct_size(info->nh_res_table, nhs, num_nh_buckets); + info->nh_res_table = __vmalloc(size, GFP_KERNEL | __GFP_ZERO | + __GFP_NOWARN); + if (!info->nh_res_table) + return -ENOMEM; + + info->nh_res_table->num_nh_buckets = num_nh_buckets; + + for (i = 0; i < num_nh_buckets; i++) { + struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; + struct nh_grp_entry *nhge; + struct nh_info *nhi; + + nhge = rtnl_dereference(bucket->nh_entry); + nhi = rtnl_dereference(nhge->nh->nh_info); + __nh_notifier_single_info_init(&info->nh_res_table->nhs[i], + nhi); + } + + return 0; +} + +static int nh_notifier_grp_info_init(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + + if (nhg->hash_threshold) + return nh_notifier_mpath_info_init(info, nhg); + else if (nhg->resilient) + return nh_notifier_res_table_info_init(info, nhg); + return -EINVAL; +} + +static void nh_notifier_grp_info_fini(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + + if (nhg->hash_threshold) + kfree(info->nh_grp); + else if (nhg->resilient) + vfree(info->nh_res_table); +} + +static int nh_notifier_info_init(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + info->id = nh->id; + + if (nh->is_group) + return nh_notifier_grp_info_init(info, nh); + else + return nh_notifier_single_info_init(info, nh); +} + +static void nh_notifier_info_fini(struct nh_notifier_info *info, + const struct nexthop *nh) +{ + if (nh->is_group) + nh_notifier_grp_info_fini(info, nh); + else + nh_notifier_single_info_fini(info); +} + +static int call_nexthop_notifiers(struct net *net, + enum nexthop_event_type event_type, + struct nexthop *nh, + struct netlink_ext_ack *extack) +{ + struct nh_notifier_info info = { + .net = net, + .extack = extack, + }; + int err; + + ASSERT_RTNL(); + + if (nexthop_notifiers_is_empty(net)) + return 0; + + err = nh_notifier_info_init(&info, nh); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to initialize nexthop notifier info"); + return err; + } + + err = blocking_notifier_call_chain(&net->nexthop.notifier_chain, + event_type, &info); + nh_notifier_info_fini(&info, nh); + + return notifier_to_errno(err); +} + +static int +nh_notifier_res_bucket_idle_timer_get(const struct nh_notifier_info *info, + bool force, unsigned int *p_idle_timer_ms) +{ + struct nh_res_table *res_table; + struct nh_group *nhg; + struct nexthop *nh; + int err = 0; + + /* When 'force' is false, nexthop bucket replacement is performed + * because the bucket was deemed to be idle. In this case, capable + * listeners can choose to perform an atomic replacement: The bucket is + * only replaced if it is inactive. 
However, if the idle timer interval + * is smaller than the interval in which a listener is querying + * buckets' activity from the device, then atomic replacement should + * not be tried. Pass the idle timer value to listeners, so that they + * could determine which type of replacement to perform. + */ + if (force) { + *p_idle_timer_ms = 0; + return 0; + } + + rcu_read_lock(); + + nh = nexthop_find_by_id(info->net, info->id); + if (!nh) { + err = -EINVAL; + goto out; + } + + nhg = rcu_dereference(nh->nh_grp); + res_table = rcu_dereference(nhg->res_table); + *p_idle_timer_ms = jiffies_to_msecs(res_table->idle_timer); + +out: + rcu_read_unlock(); + + return err; +} + +static int nh_notifier_res_bucket_info_init(struct nh_notifier_info *info, + u16 bucket_index, bool force, + struct nh_info *oldi, + struct nh_info *newi) +{ + unsigned int idle_timer_ms; + int err; + + err = nh_notifier_res_bucket_idle_timer_get(info, force, + &idle_timer_ms); + if (err) + return err; + + info->type = NH_NOTIFIER_INFO_TYPE_RES_BUCKET; + info->nh_res_bucket = kzalloc(sizeof(*info->nh_res_bucket), + GFP_KERNEL); + if (!info->nh_res_bucket) + return -ENOMEM; + + info->nh_res_bucket->bucket_index = bucket_index; + info->nh_res_bucket->idle_timer_ms = idle_timer_ms; + info->nh_res_bucket->force = force; + __nh_notifier_single_info_init(&info->nh_res_bucket->old_nh, oldi); + __nh_notifier_single_info_init(&info->nh_res_bucket->new_nh, newi); + return 0; +} + +static void nh_notifier_res_bucket_info_fini(struct nh_notifier_info *info) +{ + kfree(info->nh_res_bucket); +} + +static int __call_nexthop_res_bucket_notifiers(struct net *net, u32 nhg_id, + u16 bucket_index, bool force, + struct nh_info *oldi, + struct nh_info *newi, + struct netlink_ext_ack *extack) +{ + struct nh_notifier_info info = { + .net = net, + .extack = extack, + .id = nhg_id, + }; + int err; + + if (nexthop_notifiers_is_empty(net)) + return 0; + + err = nh_notifier_res_bucket_info_init(&info, bucket_index, force, + oldi, newi); + if (err) + return err; + + err = blocking_notifier_call_chain(&net->nexthop.notifier_chain, + NEXTHOP_EVENT_BUCKET_REPLACE, &info); + nh_notifier_res_bucket_info_fini(&info); + + return notifier_to_errno(err); +} + +/* There are three users of RES_TABLE, and NHs etc. referenced from there: + * + * 1) a collection of callbacks for NH maintenance. This operates under + * RTNL, + * 2) the delayed work that gradually balances the resilient table, + * 3) and nexthop_select_path(), operating under RCU. + * + * Both the delayed work and the RTNL block are writers, and need to + * maintain mutual exclusion. Since there are only two and well-known + * writers for each table, the RTNL code can make sure it has exclusive + * access thus: + * + * - Have the DW operate without locking; + * - synchronously cancel the DW; + * - do the writing; + * - if the write was not actually a delete, call upkeep, which schedules + * DW again if necessary. + * + * The functions that are always called from the RTNL context use + * rtnl_dereference(). The functions that can also be called from the DW do + * a raw dereference and rely on the above mutual exclusion scheme. 
+ */ +#define nh_res_dereference(p) (rcu_dereference_raw(p)) + +static int call_nexthop_res_bucket_notifiers(struct net *net, u32 nhg_id, + u16 bucket_index, bool force, + struct nexthop *old_nh, + struct nexthop *new_nh, + struct netlink_ext_ack *extack) +{ + struct nh_info *oldi = nh_res_dereference(old_nh->nh_info); + struct nh_info *newi = nh_res_dereference(new_nh->nh_info); + + return __call_nexthop_res_bucket_notifiers(net, nhg_id, bucket_index, + force, oldi, newi, extack); +} + +static int call_nexthop_res_table_notifiers(struct net *net, struct nexthop *nh, + struct netlink_ext_ack *extack) +{ + struct nh_notifier_info info = { + .net = net, + .extack = extack, + }; + struct nh_group *nhg; + int err; + + ASSERT_RTNL(); + + if (nexthop_notifiers_is_empty(net)) + return 0; + + /* At this point, the nexthop buckets are still not populated. Only + * emit a notification with the logical nexthops, so that a listener + * could potentially veto it in case of unsupported configuration. + */ + nhg = rtnl_dereference(nh->nh_grp); + err = nh_notifier_mpath_info_init(&info, nhg); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to initialize nexthop notifier info"); + return err; + } + + err = blocking_notifier_call_chain(&net->nexthop.notifier_chain, + NEXTHOP_EVENT_RES_TABLE_PRE_REPLACE, + &info); + kfree(info.nh_grp); + + return notifier_to_errno(err); +} + +static int call_nexthop_notifier(struct notifier_block *nb, struct net *net, + enum nexthop_event_type event_type, + struct nexthop *nh, + struct netlink_ext_ack *extack) +{ + struct nh_notifier_info info = { + .net = net, + .extack = extack, + }; + int err; + + err = nh_notifier_info_init(&info, nh); + if (err) + return err; + + err = nb->notifier_call(nb, event_type, &info); + nh_notifier_info_fini(&info, nh); + + return notifier_to_errno(err); +} + +static unsigned int nh_dev_hashfn(unsigned int val) +{ + unsigned int mask = NH_DEV_HASHSIZE - 1; + + return (val ^ + (val >> NH_DEV_HASHBITS) ^ + (val >> (NH_DEV_HASHBITS * 2))) & mask; +} + +static void nexthop_devhash_add(struct net *net, struct nh_info *nhi) +{ + struct net_device *dev = nhi->fib_nhc.nhc_dev; + struct hlist_head *head; + unsigned int hash; + + WARN_ON(!dev); + + hash = nh_dev_hashfn(dev->ifindex); + head = &net->nexthop.devhash[hash]; + hlist_add_head(&nhi->dev_hash, head); +} + +static void nexthop_free_group(struct nexthop *nh) +{ + struct nh_group *nhg; + int i; + + nhg = rcu_dereference_raw(nh->nh_grp); + for (i = 0; i < nhg->num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + + WARN_ON(!list_empty(&nhge->nh_list)); + nexthop_put(nhge->nh); + } + + WARN_ON(nhg->spare == nhg); + + if (nhg->resilient) + vfree(rcu_dereference_raw(nhg->res_table)); + + kfree(nhg->spare); + kfree(nhg); +} + +static void nexthop_free_single(struct nexthop *nh) +{ + struct nh_info *nhi; + + nhi = rcu_dereference_raw(nh->nh_info); + switch (nhi->family) { + case AF_INET: + fib_nh_release(nh->net, &nhi->fib_nh); + break; + case AF_INET6: + ipv6_stub->fib6_nh_release(&nhi->fib6_nh); + break; + } + kfree(nhi); +} + +void nexthop_free_rcu(struct rcu_head *head) +{ + struct nexthop *nh = container_of(head, struct nexthop, rcu); + + if (nh->is_group) + nexthop_free_group(nh); + else + nexthop_free_single(nh); + + kfree(nh); +} +EXPORT_SYMBOL_GPL(nexthop_free_rcu); + +static struct nexthop *nexthop_alloc(void) +{ + struct nexthop *nh; + + nh = kzalloc(sizeof(struct nexthop), GFP_KERNEL); + if (nh) { + INIT_LIST_HEAD(&nh->fi_list); + INIT_LIST_HEAD(&nh->f6i_list); + 
INIT_LIST_HEAD(&nh->grp_list); + INIT_LIST_HEAD(&nh->fdb_list); + } + return nh; +} + +static struct nh_group *nexthop_grp_alloc(u16 num_nh) +{ + struct nh_group *nhg; + + nhg = kzalloc(struct_size(nhg, nh_entries, num_nh), GFP_KERNEL); + if (nhg) + nhg->num_nh = num_nh; + + return nhg; +} + +static void nh_res_table_upkeep_dw(struct work_struct *work); + +static struct nh_res_table * +nexthop_res_table_alloc(struct net *net, u32 nhg_id, struct nh_config *cfg) +{ + const u16 num_nh_buckets = cfg->nh_grp_res_num_buckets; + struct nh_res_table *res_table; + unsigned long size; + + size = struct_size(res_table, nh_buckets, num_nh_buckets); + res_table = __vmalloc(size, GFP_KERNEL | __GFP_ZERO | __GFP_NOWARN); + if (!res_table) + return NULL; + + res_table->net = net; + res_table->nhg_id = nhg_id; + INIT_DELAYED_WORK(&res_table->upkeep_dw, &nh_res_table_upkeep_dw); + INIT_LIST_HEAD(&res_table->uw_nh_entries); + res_table->idle_timer = cfg->nh_grp_res_idle_timer; + res_table->unbalanced_timer = cfg->nh_grp_res_unbalanced_timer; + res_table->num_nh_buckets = num_nh_buckets; + return res_table; +} + +static void nh_base_seq_inc(struct net *net) +{ + while (++net->nexthop.seq == 0) + ; +} + +/* no reference taken; rcu lock or rtnl must be held */ +struct nexthop *nexthop_find_by_id(struct net *net, u32 id) +{ + struct rb_node **pp, *parent = NULL, *next; + + pp = &net->nexthop.rb_root.rb_node; + while (1) { + struct nexthop *nh; + + next = rcu_dereference_raw(*pp); + if (!next) + break; + parent = next; + + nh = rb_entry(parent, struct nexthop, rb_node); + if (id < nh->id) + pp = &next->rb_left; + else if (id > nh->id) + pp = &next->rb_right; + else + return nh; + } + return NULL; +} +EXPORT_SYMBOL_GPL(nexthop_find_by_id); + +/* used for auto id allocation; called with rtnl held */ +static u32 nh_find_unused_id(struct net *net) +{ + u32 id_start = net->nexthop.last_id_allocated; + + while (1) { + net->nexthop.last_id_allocated++; + if (net->nexthop.last_id_allocated == id_start) + break; + + if (!nexthop_find_by_id(net, net->nexthop.last_id_allocated)) + return net->nexthop.last_id_allocated; + } + return 0; +} + +static void nh_res_time_set_deadline(unsigned long next_time, + unsigned long *deadline) +{ + if (time_before(next_time, *deadline)) + *deadline = next_time; +} + +static clock_t nh_res_table_unbalanced_time(struct nh_res_table *res_table) +{ + if (list_empty(&res_table->uw_nh_entries)) + return 0; + return jiffies_delta_to_clock_t(jiffies - res_table->unbalanced_since); +} + +static int nla_put_nh_group_res(struct sk_buff *skb, struct nh_group *nhg) +{ + struct nh_res_table *res_table = rtnl_dereference(nhg->res_table); + struct nlattr *nest; + + nest = nla_nest_start(skb, NHA_RES_GROUP); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u16(skb, NHA_RES_GROUP_BUCKETS, + res_table->num_nh_buckets) || + nla_put_u32(skb, NHA_RES_GROUP_IDLE_TIMER, + jiffies_to_clock_t(res_table->idle_timer)) || + nla_put_u32(skb, NHA_RES_GROUP_UNBALANCED_TIMER, + jiffies_to_clock_t(res_table->unbalanced_timer)) || + nla_put_u64_64bit(skb, NHA_RES_GROUP_UNBALANCED_TIME, + nh_res_table_unbalanced_time(res_table), + NHA_RES_GROUP_PAD)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int nla_put_nh_group(struct sk_buff *skb, struct nh_group *nhg) +{ + struct nexthop_grp *p; + size_t len = nhg->num_nh * sizeof(*p); + struct nlattr *nla; + u16 group_type = 0; + int i; + + if (nhg->hash_threshold) + group_type = 
NEXTHOP_GRP_TYPE_MPATH; + else if (nhg->resilient) + group_type = NEXTHOP_GRP_TYPE_RES; + + if (nla_put_u16(skb, NHA_GROUP_TYPE, group_type)) + goto nla_put_failure; + + nla = nla_reserve(skb, NHA_GROUP, len); + if (!nla) + goto nla_put_failure; + + p = nla_data(nla); + for (i = 0; i < nhg->num_nh; ++i) { + p->id = nhg->nh_entries[i].nh->id; + p->weight = nhg->nh_entries[i].weight - 1; + p += 1; + } + + if (nhg->resilient && nla_put_nh_group_res(skb, nhg)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int nh_fill_node(struct sk_buff *skb, struct nexthop *nh, + int event, u32 portid, u32 seq, unsigned int nlflags) +{ + struct fib6_nh *fib6_nh; + struct fib_nh *fib_nh; + struct nlmsghdr *nlh; + struct nh_info *nhi; + struct nhmsg *nhm; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*nhm), nlflags); + if (!nlh) + return -EMSGSIZE; + + nhm = nlmsg_data(nlh); + nhm->nh_family = AF_UNSPEC; + nhm->nh_flags = nh->nh_flags; + nhm->nh_protocol = nh->protocol; + nhm->nh_scope = 0; + nhm->resvd = 0; + + if (nla_put_u32(skb, NHA_ID, nh->id)) + goto nla_put_failure; + + if (nh->is_group) { + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + + if (nhg->fdb_nh && nla_put_flag(skb, NHA_FDB)) + goto nla_put_failure; + if (nla_put_nh_group(skb, nhg)) + goto nla_put_failure; + goto out; + } + + nhi = rtnl_dereference(nh->nh_info); + nhm->nh_family = nhi->family; + if (nhi->reject_nh) { + if (nla_put_flag(skb, NHA_BLACKHOLE)) + goto nla_put_failure; + goto out; + } else if (nhi->fdb_nh) { + if (nla_put_flag(skb, NHA_FDB)) + goto nla_put_failure; + } else { + const struct net_device *dev; + + dev = nhi->fib_nhc.nhc_dev; + if (dev && nla_put_u32(skb, NHA_OIF, dev->ifindex)) + goto nla_put_failure; + } + + nhm->nh_scope = nhi->fib_nhc.nhc_scope; + switch (nhi->family) { + case AF_INET: + fib_nh = &nhi->fib_nh; + if (fib_nh->fib_nh_gw_family && + nla_put_be32(skb, NHA_GATEWAY, fib_nh->fib_nh_gw4)) + goto nla_put_failure; + break; + + case AF_INET6: + fib6_nh = &nhi->fib6_nh; + if (fib6_nh->fib_nh_gw_family && + nla_put_in6_addr(skb, NHA_GATEWAY, &fib6_nh->fib_nh_gw6)) + goto nla_put_failure; + break; + } + + if (nhi->fib_nhc.nhc_lwtstate && + lwtunnel_fill_encap(skb, nhi->fib_nhc.nhc_lwtstate, + NHA_ENCAP, NHA_ENCAP_TYPE) < 0) + goto nla_put_failure; + +out: + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static size_t nh_nlmsg_size_grp_res(struct nh_group *nhg) +{ + return nla_total_size(0) + /* NHA_RES_GROUP */ + nla_total_size(2) + /* NHA_RES_GROUP_BUCKETS */ + nla_total_size(4) + /* NHA_RES_GROUP_IDLE_TIMER */ + nla_total_size(4) + /* NHA_RES_GROUP_UNBALANCED_TIMER */ + nla_total_size_64bit(8);/* NHA_RES_GROUP_UNBALANCED_TIME */ +} + +static size_t nh_nlmsg_size_grp(struct nexthop *nh) +{ + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + size_t sz = sizeof(struct nexthop_grp) * nhg->num_nh; + size_t tot = nla_total_size(sz) + + nla_total_size(2); /* NHA_GROUP_TYPE */ + + if (nhg->resilient) + tot += nh_nlmsg_size_grp_res(nhg); + + return tot; +} + +static size_t nh_nlmsg_size_single(struct nexthop *nh) +{ + struct nh_info *nhi = rtnl_dereference(nh->nh_info); + size_t sz; + + /* covers NHA_BLACKHOLE since NHA_OIF and BLACKHOLE + * are mutually exclusive + */ + sz = nla_total_size(4); /* NHA_OIF */ + + switch (nhi->family) { + case AF_INET: + if (nhi->fib_nh.fib_nh_gw_family) + sz += nla_total_size(4); /* NHA_GATEWAY */ + break; + + case AF_INET6: + /* NHA_GATEWAY */ + if 
(nhi->fib6_nh.fib_nh_gw_family) + sz += nla_total_size(sizeof(const struct in6_addr)); + break; + } + + if (nhi->fib_nhc.nhc_lwtstate) { + sz += lwtunnel_get_encap_size(nhi->fib_nhc.nhc_lwtstate); + sz += nla_total_size(2); /* NHA_ENCAP_TYPE */ + } + + return sz; +} + +static size_t nh_nlmsg_size(struct nexthop *nh) +{ + size_t sz = NLMSG_ALIGN(sizeof(struct nhmsg)); + + sz += nla_total_size(4); /* NHA_ID */ + + if (nh->is_group) + sz += nh_nlmsg_size_grp(nh); + else + sz += nh_nlmsg_size_single(nh); + + return sz; +} + +static void nexthop_notify(int event, struct nexthop *nh, struct nl_info *info) +{ + unsigned int nlflags = info->nlh ? info->nlh->nlmsg_flags : 0; + u32 seq = info->nlh ? info->nlh->nlmsg_seq : 0; + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(nh_nlmsg_size(nh), gfp_any()); + if (!skb) + goto errout; + + err = nh_fill_node(skb, nh, event, info->portid, seq, nlflags); + if (err < 0) { + /* -EMSGSIZE implies BUG in nh_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, info->nl_net, info->portid, RTNLGRP_NEXTHOP, + info->nlh, gfp_any()); + return; +errout: + if (err < 0) + rtnl_set_sk_err(info->nl_net, RTNLGRP_NEXTHOP, err); +} + +static unsigned long nh_res_bucket_used_time(const struct nh_res_bucket *bucket) +{ + return (unsigned long)atomic_long_read(&bucket->used_time); +} + +static unsigned long +nh_res_bucket_idle_point(const struct nh_res_table *res_table, + const struct nh_res_bucket *bucket, + unsigned long now) +{ + unsigned long time = nh_res_bucket_used_time(bucket); + + /* Bucket was not used since it was migrated. The idle time is now. */ + if (time == bucket->migrated_time) + return now; + + return time + res_table->idle_timer; +} + +static unsigned long +nh_res_table_unb_point(const struct nh_res_table *res_table) +{ + return res_table->unbalanced_since + res_table->unbalanced_timer; +} + +static void nh_res_bucket_set_idle(const struct nh_res_table *res_table, + struct nh_res_bucket *bucket) +{ + unsigned long now = jiffies; + + atomic_long_set(&bucket->used_time, (long)now); + bucket->migrated_time = now; +} + +static void nh_res_bucket_set_busy(struct nh_res_bucket *bucket) +{ + atomic_long_set(&bucket->used_time, (long)jiffies); +} + +static clock_t nh_res_bucket_idle_time(const struct nh_res_bucket *bucket) +{ + unsigned long used_time = nh_res_bucket_used_time(bucket); + + return jiffies_delta_to_clock_t(jiffies - used_time); +} + +static int nh_fill_res_bucket(struct sk_buff *skb, struct nexthop *nh, + struct nh_res_bucket *bucket, u16 bucket_index, + int event, u32 portid, u32 seq, + unsigned int nlflags, + struct netlink_ext_ack *extack) +{ + struct nh_grp_entry *nhge = nh_res_dereference(bucket->nh_entry); + struct nlmsghdr *nlh; + struct nlattr *nest; + struct nhmsg *nhm; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*nhm), nlflags); + if (!nlh) + return -EMSGSIZE; + + nhm = nlmsg_data(nlh); + nhm->nh_family = AF_UNSPEC; + nhm->nh_flags = bucket->nh_flags; + nhm->nh_protocol = nh->protocol; + nhm->nh_scope = 0; + nhm->resvd = 0; + + if (nla_put_u32(skb, NHA_ID, nh->id)) + goto nla_put_failure; + + nest = nla_nest_start(skb, NHA_RES_BUCKET); + if (!nest) + goto nla_put_failure; + + if (nla_put_u16(skb, NHA_RES_BUCKET_INDEX, bucket_index) || + nla_put_u32(skb, NHA_RES_BUCKET_NH_ID, nhge->nh->id) || + nla_put_u64_64bit(skb, NHA_RES_BUCKET_IDLE_TIME, + nh_res_bucket_idle_time(bucket), + NHA_RES_BUCKET_PAD)) + goto nla_put_failure_nest; + + nla_nest_end(skb, nest); + 
nlmsg_end(skb, nlh); + return 0; + +nla_put_failure_nest: + nla_nest_cancel(skb, nest); +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static void nexthop_bucket_notify(struct nh_res_table *res_table, + u16 bucket_index) +{ + struct nh_res_bucket *bucket = &res_table->nh_buckets[bucket_index]; + struct nh_grp_entry *nhge = nh_res_dereference(bucket->nh_entry); + struct nexthop *nh = nhge->nh_parent; + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + goto errout; + + err = nh_fill_res_bucket(skb, nh, bucket, bucket_index, + RTM_NEWNEXTHOPBUCKET, 0, 0, NLM_F_REPLACE, + NULL); + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, nh->net, 0, RTNLGRP_NEXTHOP, NULL, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(nh->net, RTNLGRP_NEXTHOP, err); +} + +static bool valid_group_nh(struct nexthop *nh, unsigned int npaths, + bool *is_fdb, struct netlink_ext_ack *extack) +{ + if (nh->is_group) { + struct nh_group *nhg = rtnl_dereference(nh->nh_grp); + + /* Nesting groups within groups is not supported. */ + if (nhg->hash_threshold) { + NL_SET_ERR_MSG(extack, + "Hash-threshold group can not be a nexthop within a group"); + return false; + } + if (nhg->resilient) { + NL_SET_ERR_MSG(extack, + "Resilient group can not be a nexthop within a group"); + return false; + } + *is_fdb = nhg->fdb_nh; + } else { + struct nh_info *nhi = rtnl_dereference(nh->nh_info); + + if (nhi->reject_nh && npaths > 1) { + NL_SET_ERR_MSG(extack, + "Blackhole nexthop can not be used in a group with more than 1 path"); + return false; + } + *is_fdb = nhi->fdb_nh; + } + + return true; +} + +static int nh_check_attr_fdb_group(struct nexthop *nh, u8 *nh_family, + struct netlink_ext_ack *extack) +{ + struct nh_info *nhi; + + nhi = rtnl_dereference(nh->nh_info); + + if (!nhi->fdb_nh) { + NL_SET_ERR_MSG(extack, "FDB nexthop group can only have fdb nexthops"); + return -EINVAL; + } + + if (*nh_family == AF_UNSPEC) { + *nh_family = nhi->family; + } else if (*nh_family != nhi->family) { + NL_SET_ERR_MSG(extack, "FDB nexthop group cannot have mixed family nexthops"); + return -EINVAL; + } + + return 0; +} + +static int nh_check_attr_group(struct net *net, + struct nlattr *tb[], size_t tb_size, + u16 nh_grp_type, struct netlink_ext_ack *extack) +{ + unsigned int len = nla_len(tb[NHA_GROUP]); + u8 nh_family = AF_UNSPEC; + struct nexthop_grp *nhg; + unsigned int i, j; + u8 nhg_fdb = 0; + + if (!len || len & (sizeof(struct nexthop_grp) - 1)) { + NL_SET_ERR_MSG(extack, + "Invalid length for nexthop group attribute"); + return -EINVAL; + } + + /* convert len to number of nexthop ids */ + len /= sizeof(*nhg); + + nhg = nla_data(tb[NHA_GROUP]); + for (i = 0; i < len; ++i) { + if (nhg[i].resvd1 || nhg[i].resvd2) { + NL_SET_ERR_MSG(extack, "Reserved fields in nexthop_grp must be 0"); + return -EINVAL; + } + if (nhg[i].weight > 254) { + NL_SET_ERR_MSG(extack, "Invalid value for weight"); + return -EINVAL; + } + for (j = i + 1; j < len; ++j) { + if (nhg[i].id == nhg[j].id) { + NL_SET_ERR_MSG(extack, "Nexthop id can not be used twice in a group"); + return -EINVAL; + } + } + } + + if (tb[NHA_FDB]) + nhg_fdb = 1; + nhg = nla_data(tb[NHA_GROUP]); + for (i = 0; i < len; ++i) { + struct nexthop *nh; + bool is_fdb_nh; + + nh = nexthop_find_by_id(net, nhg[i].id); + if (!nh) { + NL_SET_ERR_MSG(extack, "Invalid nexthop id"); + return -EINVAL; + } + if (!valid_group_nh(nh, len, &is_fdb_nh, extack)) + return -EINVAL; + + if (nhg_fdb && 
nh_check_attr_fdb_group(nh, &nh_family, extack)) + return -EINVAL; + + if (!nhg_fdb && is_fdb_nh) { + NL_SET_ERR_MSG(extack, "Non FDB nexthop group cannot have fdb nexthops"); + return -EINVAL; + } + } + for (i = NHA_GROUP_TYPE + 1; i < tb_size; ++i) { + if (!tb[i]) + continue; + switch (i) { + case NHA_FDB: + continue; + case NHA_RES_GROUP: + if (nh_grp_type == NEXTHOP_GRP_TYPE_RES) + continue; + break; + } + NL_SET_ERR_MSG(extack, + "No other attributes can be set in nexthop groups"); + return -EINVAL; + } + + return 0; +} + +static bool ipv6_good_nh(const struct fib6_nh *nh) +{ + int state = NUD_REACHABLE; + struct neighbour *n; + + rcu_read_lock(); + + n = __ipv6_neigh_lookup_noref_stub(nh->fib_nh_dev, &nh->fib_nh_gw6); + if (n) + state = READ_ONCE(n->nud_state); + + rcu_read_unlock(); + + return !!(state & NUD_VALID); +} + +static bool ipv4_good_nh(const struct fib_nh *nh) +{ + int state = NUD_REACHABLE; + struct neighbour *n; + + rcu_read_lock(); + + n = __ipv4_neigh_lookup_noref(nh->fib_nh_dev, + (__force u32)nh->fib_nh_gw4); + if (n) + state = READ_ONCE(n->nud_state); + + rcu_read_unlock(); + + return !!(state & NUD_VALID); +} + +static struct nexthop *nexthop_select_path_hthr(struct nh_group *nhg, int hash) +{ + struct nexthop *rc = NULL; + int i; + + for (i = 0; i < nhg->num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + struct nh_info *nhi; + + if (hash > atomic_read(&nhge->hthr.upper_bound)) + continue; + + nhi = rcu_dereference(nhge->nh->nh_info); + if (nhi->fdb_nh) + return nhge->nh; + + /* nexthops always check if it is good and does + * not rely on a sysctl for this behavior + */ + switch (nhi->family) { + case AF_INET: + if (ipv4_good_nh(&nhi->fib_nh)) + return nhge->nh; + break; + case AF_INET6: + if (ipv6_good_nh(&nhi->fib6_nh)) + return nhge->nh; + break; + } + + if (!rc) + rc = nhge->nh; + } + + return rc; +} + +static struct nexthop *nexthop_select_path_res(struct nh_group *nhg, int hash) +{ + struct nh_res_table *res_table = rcu_dereference(nhg->res_table); + u16 bucket_index = hash % res_table->num_nh_buckets; + struct nh_res_bucket *bucket; + struct nh_grp_entry *nhge; + + /* nexthop_select_path() is expected to return a non-NULL value, so + * skip protocol validation and just hand out whatever there is. + */ + bucket = &res_table->nh_buckets[bucket_index]; + nh_res_bucket_set_busy(bucket); + nhge = rcu_dereference(bucket->nh_entry); + return nhge->nh; +} + +struct nexthop *nexthop_select_path(struct nexthop *nh, int hash) +{ + struct nh_group *nhg; + + if (!nh->is_group) + return nh; + + nhg = rcu_dereference(nh->nh_grp); + if (nhg->hash_threshold) + return nexthop_select_path_hthr(nhg, hash); + else if (nhg->resilient) + return nexthop_select_path_res(nhg, hash); + + /* Unreachable. 
*/ + return NULL; +} +EXPORT_SYMBOL_GPL(nexthop_select_path); + +int nexthop_for_each_fib6_nh(struct nexthop *nh, + int (*cb)(struct fib6_nh *nh, void *arg), + void *arg) +{ + struct nh_info *nhi; + int err; + + if (nh->is_group) { + struct nh_group *nhg; + int i; + + nhg = rcu_dereference_rtnl(nh->nh_grp); + for (i = 0; i < nhg->num_nh; i++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + + nhi = rcu_dereference_rtnl(nhge->nh->nh_info); + err = cb(&nhi->fib6_nh, arg); + if (err) + return err; + } + } else { + nhi = rcu_dereference_rtnl(nh->nh_info); + err = cb(&nhi->fib6_nh, arg); + if (err) + return err; + } + + return 0; +} +EXPORT_SYMBOL_GPL(nexthop_for_each_fib6_nh); + +static int check_src_addr(const struct in6_addr *saddr, + struct netlink_ext_ack *extack) +{ + if (!ipv6_addr_any(saddr)) { + NL_SET_ERR_MSG(extack, "IPv6 routes using source address can not use nexthop objects"); + return -EINVAL; + } + return 0; +} + +int fib6_check_nexthop(struct nexthop *nh, struct fib6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nh_info *nhi; + bool is_fdb_nh; + + /* fib6_src is unique to a fib6_info and limits the ability to cache + * routes in fib6_nh within a nexthop that is potentially shared + * across multiple fib entries. If the config wants to use source + * routing it can not use nexthop objects. mlxsw also does not allow + * fib6_src on routes. + */ + if (cfg && check_src_addr(&cfg->fc_src, extack) < 0) + return -EINVAL; + + if (nh->is_group) { + struct nh_group *nhg; + + nhg = rtnl_dereference(nh->nh_grp); + if (nhg->has_v4) + goto no_v4_nh; + is_fdb_nh = nhg->fdb_nh; + } else { + nhi = rtnl_dereference(nh->nh_info); + if (nhi->family == AF_INET) + goto no_v4_nh; + is_fdb_nh = nhi->fdb_nh; + } + + if (is_fdb_nh) { + NL_SET_ERR_MSG(extack, "Route cannot point to a fdb nexthop"); + return -EINVAL; + } + + return 0; +no_v4_nh: + NL_SET_ERR_MSG(extack, "IPv6 routes can not use an IPv4 nexthop"); + return -EINVAL; +} +EXPORT_SYMBOL_GPL(fib6_check_nexthop); + +/* if existing nexthop has ipv6 routes linked to it, need + * to verify this new spec works with ipv6 + */ +static int fib6_check_nh_list(struct nexthop *old, struct nexthop *new, + struct netlink_ext_ack *extack) +{ + struct fib6_info *f6i; + + if (list_empty(&old->f6i_list)) + return 0; + + list_for_each_entry(f6i, &old->f6i_list, nh_list) { + if (check_src_addr(&f6i->fib6_src.addr, extack) < 0) + return -EINVAL; + } + + return fib6_check_nexthop(new, NULL, extack); +} + +static int nexthop_check_scope(struct nh_info *nhi, u8 scope, + struct netlink_ext_ack *extack) +{ + if (scope == RT_SCOPE_HOST && nhi->fib_nhc.nhc_gw_family) { + NL_SET_ERR_MSG(extack, + "Route with host scope can not have a gateway"); + return -EINVAL; + } + + if (nhi->fib_nhc.nhc_flags & RTNH_F_ONLINK && scope >= RT_SCOPE_LINK) { + NL_SET_ERR_MSG(extack, "Scope mismatch with nexthop"); + return -EINVAL; + } + + return 0; +} + +/* Invoked by fib add code to verify nexthop by id is ok with + * config for prefix; parts of fib_check_nh not done when nexthop + * object is used. 
+ */ +int fib_check_nexthop(struct nexthop *nh, u8 scope, + struct netlink_ext_ack *extack) +{ + struct nh_info *nhi; + int err = 0; + + if (nh->is_group) { + struct nh_group *nhg; + + nhg = rtnl_dereference(nh->nh_grp); + if (nhg->fdb_nh) { + NL_SET_ERR_MSG(extack, "Route cannot point to a fdb nexthop"); + err = -EINVAL; + goto out; + } + + if (scope == RT_SCOPE_HOST) { + NL_SET_ERR_MSG(extack, "Route with host scope can not have multiple nexthops"); + err = -EINVAL; + goto out; + } + + /* all nexthops in a group have the same scope */ + nhi = rtnl_dereference(nhg->nh_entries[0].nh->nh_info); + err = nexthop_check_scope(nhi, scope, extack); + } else { + nhi = rtnl_dereference(nh->nh_info); + if (nhi->fdb_nh) { + NL_SET_ERR_MSG(extack, "Route cannot point to a fdb nexthop"); + err = -EINVAL; + goto out; + } + err = nexthop_check_scope(nhi, scope, extack); + } + +out: + return err; +} + +static int fib_check_nh_list(struct nexthop *old, struct nexthop *new, + struct netlink_ext_ack *extack) +{ + struct fib_info *fi; + + list_for_each_entry(fi, &old->fi_list, nh_list) { + int err; + + err = fib_check_nexthop(new, fi->fib_scope, extack); + if (err) + return err; + } + return 0; +} + +static bool nh_res_nhge_is_balanced(const struct nh_grp_entry *nhge) +{ + return nhge->res.count_buckets == nhge->res.wants_buckets; +} + +static bool nh_res_nhge_is_ow(const struct nh_grp_entry *nhge) +{ + return nhge->res.count_buckets > nhge->res.wants_buckets; +} + +static bool nh_res_nhge_is_uw(const struct nh_grp_entry *nhge) +{ + return nhge->res.count_buckets < nhge->res.wants_buckets; +} + +static bool nh_res_table_is_balanced(const struct nh_res_table *res_table) +{ + return list_empty(&res_table->uw_nh_entries); +} + +static void nh_res_bucket_unset_nh(struct nh_res_bucket *bucket) +{ + struct nh_grp_entry *nhge; + + if (bucket->occupied) { + nhge = nh_res_dereference(bucket->nh_entry); + nhge->res.count_buckets--; + bucket->occupied = false; + } +} + +static void nh_res_bucket_set_nh(struct nh_res_bucket *bucket, + struct nh_grp_entry *nhge) +{ + nh_res_bucket_unset_nh(bucket); + + bucket->occupied = true; + rcu_assign_pointer(bucket->nh_entry, nhge); + nhge->res.count_buckets++; +} + +static bool nh_res_bucket_should_migrate(struct nh_res_table *res_table, + struct nh_res_bucket *bucket, + unsigned long *deadline, bool *force) +{ + unsigned long now = jiffies; + struct nh_grp_entry *nhge; + unsigned long idle_point; + + if (!bucket->occupied) { + /* The bucket is not occupied, its NHGE pointer is either + * NULL or obsolete. We _have to_ migrate: set force. + */ + *force = true; + return true; + } + + nhge = nh_res_dereference(bucket->nh_entry); + + /* If the bucket is populated by an underweight or balanced + * nexthop, do not migrate. + */ + if (!nh_res_nhge_is_ow(nhge)) + return false; + + /* At this point we know that the bucket is populated with an + * overweight nexthop. It needs to be migrated to a new nexthop if + * the idle timer of unbalanced timer expired. + */ + + idle_point = nh_res_bucket_idle_point(res_table, bucket, now); + if (time_after_eq(now, idle_point)) { + /* The bucket is idle. We _can_ migrate: unset force. */ + *force = false; + return true; + } + + /* Unbalanced timer of 0 means "never force". */ + if (res_table->unbalanced_timer) { + unsigned long unb_point; + + unb_point = nh_res_table_unb_point(res_table); + if (time_after(now, unb_point)) { + /* The bucket is not idle, but the unbalanced timer + * expired. 
We _can_ migrate, but set force anyway, + * so that drivers know to ignore activity reports + * from the HW. + */ + *force = true; + return true; + } + + nh_res_time_set_deadline(unb_point, deadline); + } + + nh_res_time_set_deadline(idle_point, deadline); + return false; +} + +static bool nh_res_bucket_migrate(struct nh_res_table *res_table, + u16 bucket_index, bool notify, + bool notify_nl, bool force) +{ + struct nh_res_bucket *bucket = &res_table->nh_buckets[bucket_index]; + struct nh_grp_entry *new_nhge; + struct netlink_ext_ack extack; + int err; + + new_nhge = list_first_entry_or_null(&res_table->uw_nh_entries, + struct nh_grp_entry, + res.uw_nh_entry); + if (WARN_ON_ONCE(!new_nhge)) + /* If this function is called, "bucket" is either not + * occupied, or it belongs to a next hop that is + * overweight. In either case, there ought to be a + * corresponding underweight next hop. + */ + return false; + + if (notify) { + struct nh_grp_entry *old_nhge; + + old_nhge = nh_res_dereference(bucket->nh_entry); + err = call_nexthop_res_bucket_notifiers(res_table->net, + res_table->nhg_id, + bucket_index, force, + old_nhge->nh, + new_nhge->nh, &extack); + if (err) { + pr_err_ratelimited("%s\n", extack._msg); + if (!force) + return false; + /* It is not possible to veto a forced replacement, so + * just clear the hardware flags from the nexthop + * bucket to indicate to user space that this bucket is + * not correctly populated in hardware. + */ + bucket->nh_flags &= ~(RTNH_F_OFFLOAD | RTNH_F_TRAP); + } + } + + nh_res_bucket_set_nh(bucket, new_nhge); + nh_res_bucket_set_idle(res_table, bucket); + + if (notify_nl) + nexthop_bucket_notify(res_table, bucket_index); + + if (nh_res_nhge_is_balanced(new_nhge)) + list_del(&new_nhge->res.uw_nh_entry); + return true; +} + +#define NH_RES_UPKEEP_DW_MINIMUM_INTERVAL (HZ / 2) + +static void nh_res_table_upkeep(struct nh_res_table *res_table, + bool notify, bool notify_nl) +{ + unsigned long now = jiffies; + unsigned long deadline; + u16 i; + + /* Deadline is the next time that upkeep should be run. It is the + * earliest time at which one of the buckets might be migrated. + * Start at the most pessimistic estimate: either unbalanced_timer + * from now, or if there is none, idle_timer from now. For each + * encountered time point, call nh_res_time_set_deadline() to + * refine the estimate. + */ + if (res_table->unbalanced_timer) + deadline = now + res_table->unbalanced_timer; + else + deadline = now + res_table->idle_timer; + + for (i = 0; i < res_table->num_nh_buckets; i++) { + struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; + bool force; + + if (nh_res_bucket_should_migrate(res_table, bucket, + &deadline, &force)) { + if (!nh_res_bucket_migrate(res_table, i, notify, + notify_nl, force)) { + unsigned long idle_point; + + /* A driver can override the migration + * decision if the HW reports that the + * bucket is actually not idle. Therefore + * remark the bucket as busy again and + * update the deadline. + */ + nh_res_bucket_set_busy(bucket); + idle_point = nh_res_bucket_idle_point(res_table, + bucket, + now); + nh_res_time_set_deadline(idle_point, &deadline); + } + } + } + + /* If the group is still unbalanced, schedule the next upkeep to + * either the deadline computed above, or the minimum deadline, + * whichever comes later. 
+ */ + if (!nh_res_table_is_balanced(res_table)) { + unsigned long now = jiffies; + unsigned long min_deadline; + + min_deadline = now + NH_RES_UPKEEP_DW_MINIMUM_INTERVAL; + if (time_before(deadline, min_deadline)) + deadline = min_deadline; + + queue_delayed_work(system_power_efficient_wq, + &res_table->upkeep_dw, deadline - now); + } +} + +static void nh_res_table_upkeep_dw(struct work_struct *work) +{ + struct delayed_work *dw = to_delayed_work(work); + struct nh_res_table *res_table; + + res_table = container_of(dw, struct nh_res_table, upkeep_dw); + nh_res_table_upkeep(res_table, true, true); +} + +static void nh_res_table_cancel_upkeep(struct nh_res_table *res_table) +{ + cancel_delayed_work_sync(&res_table->upkeep_dw); +} + +static void nh_res_group_rebalance(struct nh_group *nhg, + struct nh_res_table *res_table) +{ + int prev_upper_bound = 0; + int total = 0; + int w = 0; + int i; + + INIT_LIST_HEAD(&res_table->uw_nh_entries); + + for (i = 0; i < nhg->num_nh; ++i) + total += nhg->nh_entries[i].weight; + + for (i = 0; i < nhg->num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + int upper_bound; + + w += nhge->weight; + upper_bound = DIV_ROUND_CLOSEST(res_table->num_nh_buckets * w, + total); + nhge->res.wants_buckets = upper_bound - prev_upper_bound; + prev_upper_bound = upper_bound; + + if (nh_res_nhge_is_uw(nhge)) { + if (list_empty(&res_table->uw_nh_entries)) + res_table->unbalanced_since = jiffies; + list_add(&nhge->res.uw_nh_entry, + &res_table->uw_nh_entries); + } + } +} + +/* Migrate buckets in res_table so that they reference NHGE's from NHG with + * the right NH ID. Set those buckets that do not have a corresponding NHGE + * entry in NHG as not occupied. + */ +static void nh_res_table_migrate_buckets(struct nh_res_table *res_table, + struct nh_group *nhg) +{ + u16 i; + + for (i = 0; i < res_table->num_nh_buckets; i++) { + struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; + u32 id = rtnl_dereference(bucket->nh_entry)->nh->id; + bool found = false; + int j; + + for (j = 0; j < nhg->num_nh; j++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[j]; + + if (nhge->nh->id == id) { + nh_res_bucket_set_nh(bucket, nhge); + found = true; + break; + } + } + + if (!found) + nh_res_bucket_unset_nh(bucket); + } +} + +static void replace_nexthop_grp_res(struct nh_group *oldg, + struct nh_group *newg) +{ + /* For NH group replacement, the new NHG might only have a stub + * hash table with 0 buckets, because the number of buckets was not + * specified. For NH removal, oldg and newg both reference the same + * res_table. So in any case, in the following, we want to work + * with oldg->res_table. 
+ */ + struct nh_res_table *old_res_table = rtnl_dereference(oldg->res_table); + unsigned long prev_unbalanced_since = old_res_table->unbalanced_since; + bool prev_has_uw = !list_empty(&old_res_table->uw_nh_entries); + + nh_res_table_cancel_upkeep(old_res_table); + nh_res_table_migrate_buckets(old_res_table, newg); + nh_res_group_rebalance(newg, old_res_table); + if (prev_has_uw && !list_empty(&old_res_table->uw_nh_entries)) + old_res_table->unbalanced_since = prev_unbalanced_since; + nh_res_table_upkeep(old_res_table, true, false); +} + +static void nh_hthr_group_rebalance(struct nh_group *nhg) +{ + int total = 0; + int w = 0; + int i; + + for (i = 0; i < nhg->num_nh; ++i) + total += nhg->nh_entries[i].weight; + + for (i = 0; i < nhg->num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + int upper_bound; + + w += nhge->weight; + upper_bound = DIV_ROUND_CLOSEST_ULL((u64)w << 31, total) - 1; + atomic_set(&nhge->hthr.upper_bound, upper_bound); + } +} + +static void remove_nh_grp_entry(struct net *net, struct nh_grp_entry *nhge, + struct nl_info *nlinfo) +{ + struct nh_grp_entry *nhges, *new_nhges; + struct nexthop *nhp = nhge->nh_parent; + struct netlink_ext_ack extack; + struct nexthop *nh = nhge->nh; + struct nh_group *nhg, *newg; + int i, j, err; + + WARN_ON(!nh); + + nhg = rtnl_dereference(nhp->nh_grp); + newg = nhg->spare; + + /* last entry, keep it visible and remove the parent */ + if (nhg->num_nh == 1) { + remove_nexthop(net, nhp, nlinfo); + return; + } + + newg->has_v4 = false; + newg->is_multipath = nhg->is_multipath; + newg->hash_threshold = nhg->hash_threshold; + newg->resilient = nhg->resilient; + newg->fdb_nh = nhg->fdb_nh; + newg->num_nh = nhg->num_nh; + + /* copy old entries to new except the one getting removed */ + nhges = nhg->nh_entries; + new_nhges = newg->nh_entries; + for (i = 0, j = 0; i < nhg->num_nh; ++i) { + struct nh_info *nhi; + + /* current nexthop getting removed */ + if (nhg->nh_entries[i].nh == nh) { + newg->num_nh--; + continue; + } + + nhi = rtnl_dereference(nhges[i].nh->nh_info); + if (nhi->family == AF_INET) + newg->has_v4 = true; + + list_del(&nhges[i].nh_list); + new_nhges[j].nh_parent = nhges[i].nh_parent; + new_nhges[j].nh = nhges[i].nh; + new_nhges[j].weight = nhges[i].weight; + list_add(&new_nhges[j].nh_list, &new_nhges[j].nh->grp_list); + j++; + } + + if (newg->hash_threshold) + nh_hthr_group_rebalance(newg); + else if (newg->resilient) + replace_nexthop_grp_res(nhg, newg); + + rcu_assign_pointer(nhp->nh_grp, newg); + + list_del(&nhge->nh_list); + nexthop_put(nhge->nh); + + /* Removal of a NH from a resilient group is notified through + * bucket notifications. 
+ */ + if (newg->hash_threshold) { + err = call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, nhp, + &extack); + if (err) + pr_err("%s\n", extack._msg); + } + + if (nlinfo) + nexthop_notify(RTM_NEWNEXTHOP, nhp, nlinfo); +} + +static void remove_nexthop_from_groups(struct net *net, struct nexthop *nh, + struct nl_info *nlinfo) +{ + struct nh_grp_entry *nhge, *tmp; + + list_for_each_entry_safe(nhge, tmp, &nh->grp_list, nh_list) + remove_nh_grp_entry(net, nhge, nlinfo); + + /* make sure all see the newly published array before releasing rtnl */ + synchronize_net(); +} + +static void remove_nexthop_group(struct nexthop *nh, struct nl_info *nlinfo) +{ + struct nh_group *nhg = rcu_dereference_rtnl(nh->nh_grp); + struct nh_res_table *res_table; + int i, num_nh = nhg->num_nh; + + for (i = 0; i < num_nh; ++i) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + + if (WARN_ON(!nhge->nh)) + continue; + + list_del_init(&nhge->nh_list); + } + + if (nhg->resilient) { + res_table = rtnl_dereference(nhg->res_table); + nh_res_table_cancel_upkeep(res_table); + } +} + +/* not called for nexthop replace */ +static void __remove_nexthop_fib(struct net *net, struct nexthop *nh) +{ + struct fib6_info *f6i, *tmp; + bool do_flush = false; + struct fib_info *fi; + + list_for_each_entry(fi, &nh->fi_list, nh_list) { + fi->fib_flags |= RTNH_F_DEAD; + do_flush = true; + } + if (do_flush) + fib_flush(net); + + /* ip6_del_rt removes the entry from this list hence the _safe */ + list_for_each_entry_safe(f6i, tmp, &nh->f6i_list, nh_list) { + /* __ip6_del_rt does a release, so do a hold here */ + fib6_info_hold(f6i); + ipv6_stub->ip6_del_rt(net, f6i, + !READ_ONCE(net->ipv4.sysctl_nexthop_compat_mode)); + } +} + +static void __remove_nexthop(struct net *net, struct nexthop *nh, + struct nl_info *nlinfo) +{ + __remove_nexthop_fib(net, nh); + + if (nh->is_group) { + remove_nexthop_group(nh, nlinfo); + } else { + struct nh_info *nhi; + + nhi = rtnl_dereference(nh->nh_info); + if (nhi->fib_nhc.nhc_dev) + hlist_del(&nhi->dev_hash); + + remove_nexthop_from_groups(net, nh, nlinfo); + } +} + +static void remove_nexthop(struct net *net, struct nexthop *nh, + struct nl_info *nlinfo) +{ + call_nexthop_notifiers(net, NEXTHOP_EVENT_DEL, nh, NULL); + + /* remove from the tree */ + rb_erase(&nh->rb_node, &net->nexthop.rb_root); + + if (nlinfo) + nexthop_notify(RTM_DELNEXTHOP, nh, nlinfo); + + __remove_nexthop(net, nh, nlinfo); + nh_base_seq_inc(net); + + nexthop_put(nh); +} + +/* if any FIB entries reference this nexthop, any dst entries + * need to be regenerated + */ +static void nh_rt_cache_flush(struct net *net, struct nexthop *nh, + struct nexthop *replaced_nh) +{ + struct fib6_info *f6i; + struct nh_group *nhg; + int i; + + if (!list_empty(&nh->fi_list)) + rt_cache_flush(net); + + list_for_each_entry(f6i, &nh->f6i_list, nh_list) + ipv6_stub->fib6_update_sernum(net, f6i); + + /* if an IPv6 group was replaced, we have to release all old + * dsts to make sure all refcounts are released + */ + if (!replaced_nh->is_group) + return; + + nhg = rtnl_dereference(replaced_nh->nh_grp); + for (i = 0; i < nhg->num_nh; i++) { + struct nh_grp_entry *nhge = &nhg->nh_entries[i]; + struct nh_info *nhi = rtnl_dereference(nhge->nh->nh_info); + + if (nhi->family == AF_INET6) + ipv6_stub->fib6_nh_release_dsts(&nhi->fib6_nh); + } +} + +static int replace_nexthop_grp(struct net *net, struct nexthop *old, + struct nexthop *new, const struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nh_res_table *tmp_table = NULL; + struct nh_res_table 
*new_res_table; + struct nh_res_table *old_res_table; + struct nh_group *oldg, *newg; + int i, err; + + if (!new->is_group) { + NL_SET_ERR_MSG(extack, "Can not replace a nexthop group with a nexthop."); + return -EINVAL; + } + + oldg = rtnl_dereference(old->nh_grp); + newg = rtnl_dereference(new->nh_grp); + + if (newg->hash_threshold != oldg->hash_threshold) { + NL_SET_ERR_MSG(extack, "Can not replace a nexthop group with one of a different type."); + return -EINVAL; + } + + if (newg->hash_threshold) { + err = call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, new, + extack); + if (err) + return err; + } else if (newg->resilient) { + new_res_table = rtnl_dereference(newg->res_table); + old_res_table = rtnl_dereference(oldg->res_table); + + /* Accept if num_nh_buckets was not given, but if it was + * given, demand that the value be correct. + */ + if (cfg->nh_grp_res_has_num_buckets && + cfg->nh_grp_res_num_buckets != + old_res_table->num_nh_buckets) { + NL_SET_ERR_MSG(extack, "Can not change number of buckets of a resilient nexthop group."); + return -EINVAL; + } + + /* Emit a pre-replace notification so that listeners could veto + * a potentially unsupported configuration. Otherwise, + * individual bucket replacement notifications would need to be + * vetoed, which is something that should only happen if the + * bucket is currently active. + */ + err = call_nexthop_res_table_notifiers(net, new, extack); + if (err) + return err; + + if (cfg->nh_grp_res_has_idle_timer) + old_res_table->idle_timer = cfg->nh_grp_res_idle_timer; + if (cfg->nh_grp_res_has_unbalanced_timer) + old_res_table->unbalanced_timer = + cfg->nh_grp_res_unbalanced_timer; + + replace_nexthop_grp_res(oldg, newg); + + tmp_table = new_res_table; + rcu_assign_pointer(newg->res_table, old_res_table); + rcu_assign_pointer(newg->spare->res_table, old_res_table); + } + + /* update parents - used by nexthop code for cleanup */ + for (i = 0; i < newg->num_nh; i++) + newg->nh_entries[i].nh_parent = old; + + rcu_assign_pointer(old->nh_grp, newg); + + /* Make sure concurrent readers are not using 'oldg' anymore. 
*/ + synchronize_net(); + + if (newg->resilient) { + rcu_assign_pointer(oldg->res_table, tmp_table); + rcu_assign_pointer(oldg->spare->res_table, tmp_table); + } + + for (i = 0; i < oldg->num_nh; i++) + oldg->nh_entries[i].nh_parent = new; + + rcu_assign_pointer(new->nh_grp, oldg); + + return 0; +} + +static void nh_group_v4_update(struct nh_group *nhg) +{ + struct nh_grp_entry *nhges; + bool has_v4 = false; + int i; + + nhges = nhg->nh_entries; + for (i = 0; i < nhg->num_nh; i++) { + struct nh_info *nhi; + + nhi = rtnl_dereference(nhges[i].nh->nh_info); + if (nhi->family == AF_INET) + has_v4 = true; + } + nhg->has_v4 = has_v4; +} + +static int replace_nexthop_single_notify_res(struct net *net, + struct nh_res_table *res_table, + struct nexthop *old, + struct nh_info *oldi, + struct nh_info *newi, + struct netlink_ext_ack *extack) +{ + u32 nhg_id = res_table->nhg_id; + int err; + u16 i; + + for (i = 0; i < res_table->num_nh_buckets; i++) { + struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; + struct nh_grp_entry *nhge; + + nhge = rtnl_dereference(bucket->nh_entry); + if (nhge->nh == old) { + err = __call_nexthop_res_bucket_notifiers(net, nhg_id, + i, true, + oldi, newi, + extack); + if (err) + goto err_notify; + } + } + + return 0; + +err_notify: + while (i-- > 0) { + struct nh_res_bucket *bucket = &res_table->nh_buckets[i]; + struct nh_grp_entry *nhge; + + nhge = rtnl_dereference(bucket->nh_entry); + if (nhge->nh == old) + __call_nexthop_res_bucket_notifiers(net, nhg_id, i, + true, newi, oldi, + extack); + } + return err; +} + +static int replace_nexthop_single_notify(struct net *net, + struct nexthop *group_nh, + struct nexthop *old, + struct nh_info *oldi, + struct nh_info *newi, + struct netlink_ext_ack *extack) +{ + struct nh_group *nhg = rtnl_dereference(group_nh->nh_grp); + struct nh_res_table *res_table; + + if (nhg->hash_threshold) { + return call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, + group_nh, extack); + } else if (nhg->resilient) { + res_table = rtnl_dereference(nhg->res_table); + return replace_nexthop_single_notify_res(net, res_table, + old, oldi, newi, + extack); + } + + return -EINVAL; +} + +static int replace_nexthop_single(struct net *net, struct nexthop *old, + struct nexthop *new, + struct netlink_ext_ack *extack) +{ + u8 old_protocol, old_nh_flags; + struct nh_info *oldi, *newi; + struct nh_grp_entry *nhge; + int err; + + if (new->is_group) { + NL_SET_ERR_MSG(extack, "Can not replace a nexthop with a nexthop group."); + return -EINVAL; + } + + err = call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, new, extack); + if (err) + return err; + + /* Hardware flags were set on 'old' as 'new' is not in the red-black + * tree. Therefore, inherit the flags from 'old' to 'new'. + */ + new->nh_flags |= old->nh_flags & (RTNH_F_OFFLOAD | RTNH_F_TRAP); + + oldi = rtnl_dereference(old->nh_info); + newi = rtnl_dereference(new->nh_info); + + newi->nh_parent = old; + oldi->nh_parent = new; + + old_protocol = old->protocol; + old_nh_flags = old->nh_flags; + + old->protocol = new->protocol; + old->nh_flags = new->nh_flags; + + rcu_assign_pointer(old->nh_info, newi); + rcu_assign_pointer(new->nh_info, oldi); + + /* Send a replace notification for all the groups using the nexthop. 
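/*
 * Illustrative sketch (editor's note, not part of the imported file): the
 * replace paths above publish the new nh_grp/nh_info pointer, wait for all
 * concurrent readers via synchronize_net(), and only then hand the old
 * object to the replaced nexthop or tear it down. A rough userspace analogy
 * of that publish-then-reclaim pattern; C11 atomics and a stubbed
 * wait_for_readers() stand in for kernel RCU and the grace period, and all
 * names below are invented for the example.
 */
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>

struct demo_group {
	int num_paths;
};

static _Atomic(struct demo_group *) live_group;

static void wait_for_readers(void)
{
	/* placeholder for a real grace period (kernel: synchronize_net()) */
}

/* reader side: acquire-load the published pointer, never cache it */
static int demo_reader(void)
{
	struct demo_group *g = atomic_load_explicit(&live_group,
						    memory_order_acquire);
	return g ? g->num_paths : 0;
}

/* writer side: publish the replacement first, then reclaim the old group */
static void demo_replace(struct demo_group *newg)
{
	struct demo_group *oldg;

	oldg = atomic_exchange_explicit(&live_group, newg,
					memory_order_acq_rel);
	wait_for_readers();	/* no reader can still hold 'oldg' now */
	free(oldg);
}

int main(void)
{
	struct demo_group *g = malloc(sizeof(*g));
	struct demo_group *g2 = malloc(sizeof(*g2));

	if (!g || !g2)
		return 1;
	g->num_paths = 2;
	atomic_store_explicit(&live_group, g, memory_order_release);
	printf("paths before replace: %d\n", demo_reader());

	g2->num_paths = 4;
	demo_replace(g2);
	printf("paths after replace: %d\n", demo_reader());
	free(g2);
	return 0;
}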
*/ + list_for_each_entry(nhge, &old->grp_list, nh_list) { + struct nexthop *nhp = nhge->nh_parent; + + err = replace_nexthop_single_notify(net, nhp, old, oldi, newi, + extack); + if (err) + goto err_notify; + } + + /* When replacing an IPv4 nexthop with an IPv6 nexthop, potentially + * update IPv4 indication in all the groups using the nexthop. + */ + if (oldi->family == AF_INET && newi->family == AF_INET6) { + list_for_each_entry(nhge, &old->grp_list, nh_list) { + struct nexthop *nhp = nhge->nh_parent; + struct nh_group *nhg; + + nhg = rtnl_dereference(nhp->nh_grp); + nh_group_v4_update(nhg); + } + } + + return 0; + +err_notify: + rcu_assign_pointer(new->nh_info, newi); + rcu_assign_pointer(old->nh_info, oldi); + old->nh_flags = old_nh_flags; + old->protocol = old_protocol; + oldi->nh_parent = old; + newi->nh_parent = new; + list_for_each_entry_continue_reverse(nhge, &old->grp_list, nh_list) { + struct nexthop *nhp = nhge->nh_parent; + + replace_nexthop_single_notify(net, nhp, old, newi, oldi, NULL); + } + call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, old, extack); + return err; +} + +static void __nexthop_replace_notify(struct net *net, struct nexthop *nh, + struct nl_info *info) +{ + struct fib6_info *f6i; + + if (!list_empty(&nh->fi_list)) { + struct fib_info *fi; + + /* expectation is a few fib_info per nexthop and then + * a lot of routes per fib_info. So mark the fib_info + * and then walk the fib tables once + */ + list_for_each_entry(fi, &nh->fi_list, nh_list) + fi->nh_updated = true; + + fib_info_notify_update(net, info); + + list_for_each_entry(fi, &nh->fi_list, nh_list) + fi->nh_updated = false; + } + + list_for_each_entry(f6i, &nh->f6i_list, nh_list) + ipv6_stub->fib6_rt_update(net, f6i, info); +} + +/* send RTM_NEWROUTE with REPLACE flag set for all FIB entries + * linked to this nexthop and for all groups that the nexthop + * is a member of + */ +static void nexthop_replace_notify(struct net *net, struct nexthop *nh, + struct nl_info *info) +{ + struct nh_grp_entry *nhge; + + __nexthop_replace_notify(net, nh, info); + + list_for_each_entry(nhge, &nh->grp_list, nh_list) + __nexthop_replace_notify(net, nhge->nh_parent, info); +} + +static int replace_nexthop(struct net *net, struct nexthop *old, + struct nexthop *new, const struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + bool new_is_reject = false; + struct nh_grp_entry *nhge; + int err; + + /* check that existing FIB entries are ok with the + * new nexthop definition + */ + err = fib_check_nh_list(old, new, extack); + if (err) + return err; + + err = fib6_check_nh_list(old, new, extack); + if (err) + return err; + + if (!new->is_group) { + struct nh_info *nhi = rtnl_dereference(new->nh_info); + + new_is_reject = nhi->reject_nh; + } + + list_for_each_entry(nhge, &old->grp_list, nh_list) { + /* if new nexthop is a blackhole, any groups using this + * nexthop cannot have more than 1 path + */ + if (new_is_reject && + nexthop_num_path(nhge->nh_parent) > 1) { + NL_SET_ERR_MSG(extack, "Blackhole nexthop can not be a member of a group with more than one path"); + return -EINVAL; + } + + err = fib_check_nh_list(nhge->nh_parent, new, extack); + if (err) + return err; + + err = fib6_check_nh_list(nhge->nh_parent, new, extack); + if (err) + return err; + } + + if (old->is_group) + err = replace_nexthop_grp(net, old, new, cfg, extack); + else + err = replace_nexthop_single(net, old, new, extack); + + if (!err) { + nh_rt_cache_flush(net, old, new); + + __remove_nexthop(net, new, NULL); + nexthop_put(new); + } + + return 
err; +} + +/* called with rtnl_lock held */ +static int insert_nexthop(struct net *net, struct nexthop *new_nh, + struct nh_config *cfg, struct netlink_ext_ack *extack) +{ + struct rb_node **pp, *parent = NULL, *next; + struct rb_root *root = &net->nexthop.rb_root; + bool replace = !!(cfg->nlflags & NLM_F_REPLACE); + bool create = !!(cfg->nlflags & NLM_F_CREATE); + u32 new_id = new_nh->id; + int replace_notify = 0; + int rc = -EEXIST; + + pp = &root->rb_node; + while (1) { + struct nexthop *nh; + + next = *pp; + if (!next) + break; + + parent = next; + + nh = rb_entry(parent, struct nexthop, rb_node); + if (new_id < nh->id) { + pp = &next->rb_left; + } else if (new_id > nh->id) { + pp = &next->rb_right; + } else if (replace) { + rc = replace_nexthop(net, nh, new_nh, cfg, extack); + if (!rc) { + new_nh = nh; /* send notification with old nh */ + replace_notify = 1; + } + goto out; + } else { + /* id already exists and not a replace */ + goto out; + } + } + + if (replace && !create) { + NL_SET_ERR_MSG(extack, "Replace specified without create and no entry exists"); + rc = -ENOENT; + goto out; + } + + if (new_nh->is_group) { + struct nh_group *nhg = rtnl_dereference(new_nh->nh_grp); + struct nh_res_table *res_table; + + if (nhg->resilient) { + res_table = rtnl_dereference(nhg->res_table); + + /* Not passing the number of buckets is OK when + * replacing, but not when creating a new group. + */ + if (!cfg->nh_grp_res_has_num_buckets) { + NL_SET_ERR_MSG(extack, "Number of buckets not specified for nexthop group insertion"); + rc = -EINVAL; + goto out; + } + + nh_res_group_rebalance(nhg, res_table); + + /* Do not send bucket notifications, we do full + * notification below. + */ + nh_res_table_upkeep(res_table, false, false); + } + } + + rb_link_node_rcu(&new_nh->rb_node, parent, pp); + rb_insert_color(&new_nh->rb_node, root); + + /* The initial insertion is a full notification for hash-threshold as + * well as resilient groups. 
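/*
 * Illustrative sketch (editor's note, not part of the imported file): the
 * insert_nexthop() walk above boils down to id-keyed insert-or-replace
 * driven by NLM_F_REPLACE and NLM_F_CREATE. The same decision logic over a
 * sorted array instead of the rb-tree, with the same error codes: -EEXIST
 * for a duplicate id without replace, -ENOENT for replace without create
 * when no entry exists. All names are invented for the example.
 */
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

static int demo_insert(uint32_t *ids, size_t *n, size_t cap, uint32_t new_id,
		       bool replace, bool create)
{
	size_t i;

	for (i = 0; i < *n && ids[i] < new_id; i++)
		;

	if (i < *n && ids[i] == new_id) {
		if (!replace)
			return -EEXIST;	/* id already exists, not a replace */
		return 0;		/* replace path keeps the slot (id) */
	}

	if (replace && !create)
		return -ENOENT;		/* nothing to replace */

	if (*n == cap)
		return -ENOMEM;

	/* shift the tail up and insert in id order */
	for (size_t j = *n; j > i; j--)
		ids[j] = ids[j - 1];
	ids[i] = new_id;
	(*n)++;
	return 0;
}

int main(void)
{
	uint32_t ids[8] = { 1, 5 };
	size_t n = 2;

	printf("%d\n", demo_insert(ids, &n, 8, 5, false, true)); /* -EEXIST */
	printf("%d\n", demo_insert(ids, &n, 8, 5, true, false)); /* 0       */
	printf("%d\n", demo_insert(ids, &n, 8, 9, true, false)); /* -ENOENT */
	printf("%d\n", demo_insert(ids, &n, 8, 3, false, true)); /* 0       */
	return 0;
}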
+ */ + rc = call_nexthop_notifiers(net, NEXTHOP_EVENT_REPLACE, new_nh, extack); + if (rc) + rb_erase(&new_nh->rb_node, &net->nexthop.rb_root); + +out: + if (!rc) { + nh_base_seq_inc(net); + nexthop_notify(RTM_NEWNEXTHOP, new_nh, &cfg->nlinfo); + if (replace_notify && + READ_ONCE(net->ipv4.sysctl_nexthop_compat_mode)) + nexthop_replace_notify(net, new_nh, &cfg->nlinfo); + } + + return rc; +} + +/* rtnl */ +/* remove all nexthops tied to a device being deleted */ +static void nexthop_flush_dev(struct net_device *dev, unsigned long event) +{ + unsigned int hash = nh_dev_hashfn(dev->ifindex); + struct net *net = dev_net(dev); + struct hlist_head *head = &net->nexthop.devhash[hash]; + struct hlist_node *n; + struct nh_info *nhi; + + hlist_for_each_entry_safe(nhi, n, head, dev_hash) { + if (nhi->fib_nhc.nhc_dev != dev) + continue; + + if (nhi->reject_nh && + (event == NETDEV_DOWN || event == NETDEV_CHANGE)) + continue; + + remove_nexthop(net, nhi->nh_parent, NULL); + } +} + +/* rtnl; called when net namespace is deleted */ +static void flush_all_nexthops(struct net *net) +{ + struct rb_root *root = &net->nexthop.rb_root; + struct rb_node *node; + struct nexthop *nh; + + while ((node = rb_first(root))) { + nh = rb_entry(node, struct nexthop, rb_node); + remove_nexthop(net, nh, NULL); + cond_resched(); + } +} + +static struct nexthop *nexthop_create_group(struct net *net, + struct nh_config *cfg) +{ + struct nlattr *grps_attr = cfg->nh_grp; + struct nexthop_grp *entry = nla_data(grps_attr); + u16 num_nh = nla_len(grps_attr) / sizeof(*entry); + struct nh_group *nhg; + struct nexthop *nh; + int err; + int i; + + if (WARN_ON(!num_nh)) + return ERR_PTR(-EINVAL); + + nh = nexthop_alloc(); + if (!nh) + return ERR_PTR(-ENOMEM); + + nh->is_group = 1; + + nhg = nexthop_grp_alloc(num_nh); + if (!nhg) { + kfree(nh); + return ERR_PTR(-ENOMEM); + } + + /* spare group used for removals */ + nhg->spare = nexthop_grp_alloc(num_nh); + if (!nhg->spare) { + kfree(nhg); + kfree(nh); + return ERR_PTR(-ENOMEM); + } + nhg->spare->spare = nhg; + + for (i = 0; i < nhg->num_nh; ++i) { + struct nexthop *nhe; + struct nh_info *nhi; + + nhe = nexthop_find_by_id(net, entry[i].id); + if (!nexthop_get(nhe)) { + err = -ENOENT; + goto out_no_nh; + } + + nhi = rtnl_dereference(nhe->nh_info); + if (nhi->family == AF_INET) + nhg->has_v4 = true; + + nhg->nh_entries[i].nh = nhe; + nhg->nh_entries[i].weight = entry[i].weight + 1; + list_add(&nhg->nh_entries[i].nh_list, &nhe->grp_list); + nhg->nh_entries[i].nh_parent = nh; + } + + if (cfg->nh_grp_type == NEXTHOP_GRP_TYPE_MPATH) { + nhg->hash_threshold = 1; + nhg->is_multipath = true; + } else if (cfg->nh_grp_type == NEXTHOP_GRP_TYPE_RES) { + struct nh_res_table *res_table; + + res_table = nexthop_res_table_alloc(net, cfg->nh_id, cfg); + if (!res_table) { + err = -ENOMEM; + goto out_no_nh; + } + + rcu_assign_pointer(nhg->spare->res_table, res_table); + rcu_assign_pointer(nhg->res_table, res_table); + nhg->resilient = true; + nhg->is_multipath = true; + } + + WARN_ON_ONCE(nhg->hash_threshold + nhg->resilient != 1); + + if (nhg->hash_threshold) + nh_hthr_group_rebalance(nhg); + + if (cfg->nh_fdb) + nhg->fdb_nh = 1; + + rcu_assign_pointer(nh->nh_grp, nhg); + + return nh; + +out_no_nh: + for (i--; i >= 0; --i) { + list_del(&nhg->nh_entries[i].nh_list); + nexthop_put(nhg->nh_entries[i].nh); + } + + kfree(nhg->spare); + kfree(nhg); + kfree(nh); + + return ERR_PTR(err); +} + +static int nh_create_ipv4(struct net *net, struct nexthop *nh, + struct nh_info *nhi, struct nh_config *cfg, + struct 
netlink_ext_ack *extack) +{ + struct fib_nh *fib_nh = &nhi->fib_nh; + struct fib_config fib_cfg = { + .fc_oif = cfg->nh_ifindex, + .fc_gw4 = cfg->gw.ipv4, + .fc_gw_family = cfg->gw.ipv4 ? AF_INET : 0, + .fc_flags = cfg->nh_flags, + .fc_nlinfo = cfg->nlinfo, + .fc_encap = cfg->nh_encap, + .fc_encap_type = cfg->nh_encap_type, + }; + u32 tb_id = (cfg->dev ? l3mdev_fib_table(cfg->dev) : RT_TABLE_MAIN); + int err; + + err = fib_nh_init(net, fib_nh, &fib_cfg, 1, extack); + if (err) { + fib_nh_release(net, fib_nh); + goto out; + } + + if (nhi->fdb_nh) + goto out; + + /* sets nh_dev if successful */ + err = fib_check_nh(net, fib_nh, tb_id, 0, extack); + if (!err) { + nh->nh_flags = fib_nh->fib_nh_flags; + fib_info_update_nhc_saddr(net, &fib_nh->nh_common, + !fib_nh->fib_nh_scope ? 0 : fib_nh->fib_nh_scope - 1); + } else { + fib_nh_release(net, fib_nh); + } +out: + return err; +} + +static int nh_create_ipv6(struct net *net, struct nexthop *nh, + struct nh_info *nhi, struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct fib6_nh *fib6_nh = &nhi->fib6_nh; + struct fib6_config fib6_cfg = { + .fc_table = l3mdev_fib_table(cfg->dev), + .fc_ifindex = cfg->nh_ifindex, + .fc_gateway = cfg->gw.ipv6, + .fc_flags = cfg->nh_flags, + .fc_nlinfo = cfg->nlinfo, + .fc_encap = cfg->nh_encap, + .fc_encap_type = cfg->nh_encap_type, + .fc_is_fdb = cfg->nh_fdb, + }; + int err; + + if (!ipv6_addr_any(&cfg->gw.ipv6)) + fib6_cfg.fc_flags |= RTF_GATEWAY; + + /* sets nh_dev if successful */ + err = ipv6_stub->fib6_nh_init(net, fib6_nh, &fib6_cfg, GFP_KERNEL, + extack); + if (err) { + /* IPv6 is not enabled, don't call fib6_nh_release */ + if (err == -EAFNOSUPPORT) + goto out; + ipv6_stub->fib6_nh_release(fib6_nh); + } else { + nh->nh_flags = fib6_nh->fib_nh_flags; + } +out: + return err; +} + +static struct nexthop *nexthop_create(struct net *net, struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nh_info *nhi; + struct nexthop *nh; + int err = 0; + + nh = nexthop_alloc(); + if (!nh) + return ERR_PTR(-ENOMEM); + + nhi = kzalloc(sizeof(*nhi), GFP_KERNEL); + if (!nhi) { + kfree(nh); + return ERR_PTR(-ENOMEM); + } + + nh->nh_flags = cfg->nh_flags; + nh->net = net; + + nhi->nh_parent = nh; + nhi->family = cfg->nh_family; + nhi->fib_nhc.nhc_scope = RT_SCOPE_LINK; + + if (cfg->nh_fdb) + nhi->fdb_nh = 1; + + if (cfg->nh_blackhole) { + nhi->reject_nh = 1; + cfg->nh_ifindex = net->loopback_dev->ifindex; + } + + switch (cfg->nh_family) { + case AF_INET: + err = nh_create_ipv4(net, nh, nhi, cfg, extack); + break; + case AF_INET6: + err = nh_create_ipv6(net, nh, nhi, cfg, extack); + break; + } + + if (err) { + kfree(nhi); + kfree(nh); + return ERR_PTR(err); + } + + /* add the entry to the device based hash */ + if (!nhi->fdb_nh) + nexthop_devhash_add(net, nhi); + + rcu_assign_pointer(nh->nh_info, nhi); + + return nh; +} + +/* called with rtnl lock held */ +static struct nexthop *nexthop_add(struct net *net, struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nexthop *nh; + int err; + + if (cfg->nlflags & NLM_F_REPLACE && !cfg->nh_id) { + NL_SET_ERR_MSG(extack, "Replace requires nexthop id"); + return ERR_PTR(-EINVAL); + } + + if (!cfg->nh_id) { + cfg->nh_id = nh_find_unused_id(net); + if (!cfg->nh_id) { + NL_SET_ERR_MSG(extack, "No unused id"); + return ERR_PTR(-EINVAL); + } + } + + if (cfg->nh_grp) + nh = nexthop_create_group(net, cfg); + else + nh = nexthop_create(net, cfg, extack); + + if (IS_ERR(nh)) + return nh; + + refcount_set(&nh->refcnt, 1); + nh->id = cfg->nh_id; + nh->protocol 
= cfg->nh_protocol; + nh->net = net; + + err = insert_nexthop(net, nh, cfg, extack); + if (err) { + __remove_nexthop(net, nh, NULL); + nexthop_put(nh); + nh = ERR_PTR(err); + } + + return nh; +} + +static int rtm_nh_get_timer(struct nlattr *attr, unsigned long fallback, + unsigned long *timer_p, bool *has_p, + struct netlink_ext_ack *extack) +{ + unsigned long timer; + u32 value; + + if (!attr) { + *timer_p = fallback; + *has_p = false; + return 0; + } + + value = nla_get_u32(attr); + timer = clock_t_to_jiffies(value); + if (timer == ~0UL) { + NL_SET_ERR_MSG(extack, "Timer value too large"); + return -EINVAL; + } + + *timer_p = timer; + *has_p = true; + return 0; +} + +static int rtm_to_nh_config_grp_res(struct nlattr *res, struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(rtm_nh_res_policy_new)] = {}; + int err; + + if (res) { + err = nla_parse_nested(tb, + ARRAY_SIZE(rtm_nh_res_policy_new) - 1, + res, rtm_nh_res_policy_new, extack); + if (err < 0) + return err; + } + + if (tb[NHA_RES_GROUP_BUCKETS]) { + cfg->nh_grp_res_num_buckets = + nla_get_u16(tb[NHA_RES_GROUP_BUCKETS]); + cfg->nh_grp_res_has_num_buckets = true; + if (!cfg->nh_grp_res_num_buckets) { + NL_SET_ERR_MSG(extack, "Number of buckets needs to be non-0"); + return -EINVAL; + } + } + + err = rtm_nh_get_timer(tb[NHA_RES_GROUP_IDLE_TIMER], + NH_RES_DEFAULT_IDLE_TIMER, + &cfg->nh_grp_res_idle_timer, + &cfg->nh_grp_res_has_idle_timer, + extack); + if (err) + return err; + + return rtm_nh_get_timer(tb[NHA_RES_GROUP_UNBALANCED_TIMER], + NH_RES_DEFAULT_UNBALANCED_TIMER, + &cfg->nh_grp_res_unbalanced_timer, + &cfg->nh_grp_res_has_unbalanced_timer, + extack); +} + +static int rtm_to_nh_config(struct net *net, struct sk_buff *skb, + struct nlmsghdr *nlh, struct nh_config *cfg, + struct netlink_ext_ack *extack) +{ + struct nhmsg *nhm = nlmsg_data(nlh); + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_new)]; + int err; + + err = nlmsg_parse(nlh, sizeof(*nhm), tb, + ARRAY_SIZE(rtm_nh_policy_new) - 1, + rtm_nh_policy_new, extack); + if (err < 0) + return err; + + err = -EINVAL; + if (nhm->resvd || nhm->nh_scope) { + NL_SET_ERR_MSG(extack, "Invalid values in ancillary header"); + goto out; + } + if (nhm->nh_flags & ~NEXTHOP_VALID_USER_FLAGS) { + NL_SET_ERR_MSG(extack, "Invalid nexthop flags in ancillary header"); + goto out; + } + + switch (nhm->nh_family) { + case AF_INET: + case AF_INET6: + break; + case AF_UNSPEC: + if (tb[NHA_GROUP]) + break; + fallthrough; + default: + NL_SET_ERR_MSG(extack, "Invalid address family"); + goto out; + } + + memset(cfg, 0, sizeof(*cfg)); + cfg->nlflags = nlh->nlmsg_flags; + cfg->nlinfo.portid = NETLINK_CB(skb).portid; + cfg->nlinfo.nlh = nlh; + cfg->nlinfo.nl_net = net; + + cfg->nh_family = nhm->nh_family; + cfg->nh_protocol = nhm->nh_protocol; + cfg->nh_flags = nhm->nh_flags; + + if (tb[NHA_ID]) + cfg->nh_id = nla_get_u32(tb[NHA_ID]); + + if (tb[NHA_FDB]) { + if (tb[NHA_OIF] || tb[NHA_BLACKHOLE] || + tb[NHA_ENCAP] || tb[NHA_ENCAP_TYPE]) { + NL_SET_ERR_MSG(extack, "Fdb attribute can not be used with encap, oif or blackhole"); + goto out; + } + if (nhm->nh_flags) { + NL_SET_ERR_MSG(extack, "Unsupported nexthop flags in ancillary header"); + goto out; + } + cfg->nh_fdb = nla_get_flag(tb[NHA_FDB]); + } + + if (tb[NHA_GROUP]) { + if (nhm->nh_family != AF_UNSPEC) { + NL_SET_ERR_MSG(extack, "Invalid family for group"); + goto out; + } + cfg->nh_grp = tb[NHA_GROUP]; + + cfg->nh_grp_type = NEXTHOP_GRP_TYPE_MPATH; + if (tb[NHA_GROUP_TYPE]) + cfg->nh_grp_type = 
nla_get_u16(tb[NHA_GROUP_TYPE]); + + if (cfg->nh_grp_type > NEXTHOP_GRP_TYPE_MAX) { + NL_SET_ERR_MSG(extack, "Invalid group type"); + goto out; + } + err = nh_check_attr_group(net, tb, ARRAY_SIZE(tb), + cfg->nh_grp_type, extack); + if (err) + goto out; + + if (cfg->nh_grp_type == NEXTHOP_GRP_TYPE_RES) + err = rtm_to_nh_config_grp_res(tb[NHA_RES_GROUP], + cfg, extack); + + /* no other attributes should be set */ + goto out; + } + + if (tb[NHA_BLACKHOLE]) { + if (tb[NHA_GATEWAY] || tb[NHA_OIF] || + tb[NHA_ENCAP] || tb[NHA_ENCAP_TYPE] || tb[NHA_FDB]) { + NL_SET_ERR_MSG(extack, "Blackhole attribute can not be used with gateway, oif, encap or fdb"); + goto out; + } + + cfg->nh_blackhole = 1; + err = 0; + goto out; + } + + if (!cfg->nh_fdb && !tb[NHA_OIF]) { + NL_SET_ERR_MSG(extack, "Device attribute required for non-blackhole and non-fdb nexthops"); + goto out; + } + + if (!cfg->nh_fdb && tb[NHA_OIF]) { + cfg->nh_ifindex = nla_get_u32(tb[NHA_OIF]); + if (cfg->nh_ifindex) + cfg->dev = __dev_get_by_index(net, cfg->nh_ifindex); + + if (!cfg->dev) { + NL_SET_ERR_MSG(extack, "Invalid device index"); + goto out; + } else if (!(cfg->dev->flags & IFF_UP)) { + NL_SET_ERR_MSG(extack, "Nexthop device is not up"); + err = -ENETDOWN; + goto out; + } else if (!netif_carrier_ok(cfg->dev)) { + NL_SET_ERR_MSG(extack, "Carrier for nexthop device is down"); + err = -ENETDOWN; + goto out; + } + } + + err = -EINVAL; + if (tb[NHA_GATEWAY]) { + struct nlattr *gwa = tb[NHA_GATEWAY]; + + switch (cfg->nh_family) { + case AF_INET: + if (nla_len(gwa) != sizeof(u32)) { + NL_SET_ERR_MSG(extack, "Invalid gateway"); + goto out; + } + cfg->gw.ipv4 = nla_get_be32(gwa); + break; + case AF_INET6: + if (nla_len(gwa) != sizeof(struct in6_addr)) { + NL_SET_ERR_MSG(extack, "Invalid gateway"); + goto out; + } + cfg->gw.ipv6 = nla_get_in6_addr(gwa); + break; + default: + NL_SET_ERR_MSG(extack, + "Unknown address family for gateway"); + goto out; + } + } else { + /* device only nexthop (no gateway) */ + if (cfg->nh_flags & RTNH_F_ONLINK) { + NL_SET_ERR_MSG(extack, + "ONLINK flag can not be set for nexthop without a gateway"); + goto out; + } + } + + if (tb[NHA_ENCAP]) { + cfg->nh_encap = tb[NHA_ENCAP]; + + if (!tb[NHA_ENCAP_TYPE]) { + NL_SET_ERR_MSG(extack, "LWT encapsulation type is missing"); + goto out; + } + + cfg->nh_encap_type = nla_get_u16(tb[NHA_ENCAP_TYPE]); + err = lwtunnel_valid_encap_type(cfg->nh_encap_type, extack); + if (err < 0) + goto out; + + } else if (tb[NHA_ENCAP_TYPE]) { + NL_SET_ERR_MSG(extack, "LWT encapsulation attribute is missing"); + goto out; + } + + + err = 0; +out: + return err; +} + +/* rtnl */ +static int rtm_new_nexthop(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nh_config cfg; + struct nexthop *nh; + int err; + + err = rtm_to_nh_config(net, skb, nlh, &cfg, extack); + if (!err) { + nh = nexthop_add(net, &cfg, extack); + if (IS_ERR(nh)) + err = PTR_ERR(nh); + } + + return err; +} + +static int __nh_valid_get_del_req(const struct nlmsghdr *nlh, + struct nlattr **tb, u32 *id, + struct netlink_ext_ack *extack) +{ + struct nhmsg *nhm = nlmsg_data(nlh); + + if (nhm->nh_protocol || nhm->resvd || nhm->nh_scope || nhm->nh_flags) { + NL_SET_ERR_MSG(extack, "Invalid values in header"); + return -EINVAL; + } + + if (!tb[NHA_ID]) { + NL_SET_ERR_MSG(extack, "Nexthop id is missing"); + return -EINVAL; + } + + *id = nla_get_u32(tb[NHA_ID]); + if (!(*id)) { + NL_SET_ERR_MSG(extack, "Invalid nexthop id"); + return -EINVAL; + } + + return 
0; +} + +static int nh_valid_get_del_req(const struct nlmsghdr *nlh, u32 *id, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_get)]; + int err; + + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_get) - 1, + rtm_nh_policy_get, extack); + if (err < 0) + return err; + + return __nh_valid_get_del_req(nlh, tb, id, extack); +} + +/* rtnl */ +static int rtm_del_nexthop(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nl_info nlinfo = { + .nlh = nlh, + .nl_net = net, + .portid = NETLINK_CB(skb).portid, + }; + struct nexthop *nh; + int err; + u32 id; + + err = nh_valid_get_del_req(nlh, &id, extack); + if (err) + return err; + + nh = nexthop_find_by_id(net, id); + if (!nh) + return -ENOENT; + + remove_nexthop(net, nh, &nlinfo); + + return 0; +} + +/* rtnl */ +static int rtm_get_nexthop(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct sk_buff *skb = NULL; + struct nexthop *nh; + int err; + u32 id; + + err = nh_valid_get_del_req(nlh, &id, extack); + if (err) + return err; + + err = -ENOBUFS; + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + goto out; + + err = -ENOENT; + nh = nexthop_find_by_id(net, id); + if (!nh) + goto errout_free; + + err = nh_fill_node(skb, nh, RTM_NEWNEXTHOP, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + goto errout_free; + } + + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +out: + return err; +errout_free: + kfree_skb(skb); + goto out; +} + +struct nh_dump_filter { + u32 nh_id; + int dev_idx; + int master_idx; + bool group_filter; + bool fdb_filter; + u32 res_bucket_nh_id; +}; + +static bool nh_dump_filtered(struct nexthop *nh, + struct nh_dump_filter *filter, u8 family) +{ + const struct net_device *dev; + const struct nh_info *nhi; + + if (filter->group_filter && !nh->is_group) + return true; + + if (!filter->dev_idx && !filter->master_idx && !family) + return false; + + if (nh->is_group) + return true; + + nhi = rtnl_dereference(nh->nh_info); + if (family && nhi->family != family) + return true; + + dev = nhi->fib_nhc.nhc_dev; + if (filter->dev_idx && (!dev || dev->ifindex != filter->dev_idx)) + return true; + + if (filter->master_idx) { + struct net_device *master; + + if (!dev) + return true; + + master = netdev_master_upper_dev_get((struct net_device *)dev); + if (!master || master->ifindex != filter->master_idx) + return true; + } + + return false; +} + +static int __nh_valid_dump_req(const struct nlmsghdr *nlh, struct nlattr **tb, + struct nh_dump_filter *filter, + struct netlink_ext_ack *extack) +{ + struct nhmsg *nhm; + u32 idx; + + if (tb[NHA_OIF]) { + idx = nla_get_u32(tb[NHA_OIF]); + if (idx > INT_MAX) { + NL_SET_ERR_MSG(extack, "Invalid device index"); + return -EINVAL; + } + filter->dev_idx = idx; + } + if (tb[NHA_MASTER]) { + idx = nla_get_u32(tb[NHA_MASTER]); + if (idx > INT_MAX) { + NL_SET_ERR_MSG(extack, "Invalid master device index"); + return -EINVAL; + } + filter->master_idx = idx; + } + filter->group_filter = nla_get_flag(tb[NHA_GROUPS]); + filter->fdb_filter = nla_get_flag(tb[NHA_FDB]); + + nhm = nlmsg_data(nlh); + if (nhm->nh_protocol || nhm->resvd || nhm->nh_scope || nhm->nh_flags) { + NL_SET_ERR_MSG(extack, "Invalid values in header for nexthop dump request"); + return -EINVAL; + } + + return 0; +} + +static int nh_valid_dump_req(const struct nlmsghdr 
*nlh, + struct nh_dump_filter *filter, + struct netlink_callback *cb) +{ + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_dump)]; + int err; + + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_dump) - 1, + rtm_nh_policy_dump, cb->extack); + if (err < 0) + return err; + + return __nh_valid_dump_req(nlh, tb, filter, cb->extack); +} + +struct rtm_dump_nh_ctx { + u32 idx; +}; + +static struct rtm_dump_nh_ctx * +rtm_dump_nh_ctx(struct netlink_callback *cb) +{ + struct rtm_dump_nh_ctx *ctx = (void *)cb->ctx; + + BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); + return ctx; +} + +static int rtm_dump_walk_nexthops(struct sk_buff *skb, + struct netlink_callback *cb, + struct rb_root *root, + struct rtm_dump_nh_ctx *ctx, + int (*nh_cb)(struct sk_buff *skb, + struct netlink_callback *cb, + struct nexthop *nh, void *data), + void *data) +{ + struct rb_node *node; + int s_idx; + int err; + + s_idx = ctx->idx; + for (node = rb_first(root); node; node = rb_next(node)) { + struct nexthop *nh; + + nh = rb_entry(node, struct nexthop, rb_node); + if (nh->id < s_idx) + continue; + + ctx->idx = nh->id; + err = nh_cb(skb, cb, nh, data); + if (err) + return err; + } + + ctx->idx++; + return 0; +} + +static int rtm_dump_nexthop_cb(struct sk_buff *skb, struct netlink_callback *cb, + struct nexthop *nh, void *data) +{ + struct nhmsg *nhm = nlmsg_data(cb->nlh); + struct nh_dump_filter *filter = data; + + if (nh_dump_filtered(nh, filter, nhm->nh_family)) + return 0; + + return nh_fill_node(skb, nh, RTM_NEWNEXTHOP, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI); +} + +/* rtnl */ +static int rtm_dump_nexthop(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rtm_dump_nh_ctx *ctx = rtm_dump_nh_ctx(cb); + struct net *net = sock_net(skb->sk); + struct rb_root *root = &net->nexthop.rb_root; + struct nh_dump_filter filter = {}; + int err; + + err = nh_valid_dump_req(cb->nlh, &filter, cb); + if (err < 0) + return err; + + err = rtm_dump_walk_nexthops(skb, cb, root, ctx, + &rtm_dump_nexthop_cb, &filter); + if (err < 0) { + if (likely(skb->len)) + err = skb->len; + } + + cb->seq = net->nexthop.seq; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + return err; +} + +static struct nexthop * +nexthop_find_group_resilient(struct net *net, u32 id, + struct netlink_ext_ack *extack) +{ + struct nh_group *nhg; + struct nexthop *nh; + + nh = nexthop_find_by_id(net, id); + if (!nh) + return ERR_PTR(-ENOENT); + + if (!nh->is_group) { + NL_SET_ERR_MSG(extack, "Not a nexthop group"); + return ERR_PTR(-EINVAL); + } + + nhg = rtnl_dereference(nh->nh_grp); + if (!nhg->resilient) { + NL_SET_ERR_MSG(extack, "Nexthop group not of type resilient"); + return ERR_PTR(-EINVAL); + } + + return nh; +} + +static int nh_valid_dump_nhid(struct nlattr *attr, u32 *nh_id_p, + struct netlink_ext_ack *extack) +{ + u32 idx; + + if (attr) { + idx = nla_get_u32(attr); + if (!idx) { + NL_SET_ERR_MSG(extack, "Invalid nexthop id"); + return -EINVAL; + } + *nh_id_p = idx; + } else { + *nh_id_p = 0; + } + + return 0; +} + +static int nh_valid_dump_bucket_req(const struct nlmsghdr *nlh, + struct nh_dump_filter *filter, + struct netlink_callback *cb) +{ + struct nlattr *res_tb[ARRAY_SIZE(rtm_nh_res_bucket_policy_dump)]; + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_dump_bucket)]; + int err; + + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_dump_bucket) - 1, + rtm_nh_policy_dump_bucket, NULL); + if (err < 0) + return err; + + err = nh_valid_dump_nhid(tb[NHA_ID], &filter->nh_id, cb->extack); 
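/*
 * Illustrative sketch (editor's note, not part of the imported file): the
 * dump path above answers RTM_GETNEXTHOP requests carrying NLM_F_DUMP. A
 * minimal userspace client that requests such a dump over NETLINK_ROUTE and
 * prints each nexthop id plus, when present, its outgoing device index.
 * Error handling is trimmed; message layout follows struct nhmsg and the
 * NHA_* attributes from <linux/nexthop.h>.
 */
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/nexthop.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	struct {
		struct nlmsghdr nlh;
		struct nhmsg nhm;
	} req;
	char buf[16384];
	int fd;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
	if (fd < 0)
		return 1;

	memset(&req, 0, sizeof(req));
	req.nlh.nlmsg_len = NLMSG_LENGTH(sizeof(struct nhmsg));
	req.nlh.nlmsg_type = RTM_GETNEXTHOP;
	req.nlh.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
	req.nlh.nlmsg_seq = 1;
	req.nhm.nh_family = AF_UNSPEC;	/* no family filter */

	if (send(fd, &req, req.nlh.nlmsg_len, 0) < 0)
		return 1;

	for (;;) {
		ssize_t len = recv(fd, buf, sizeof(buf), 0);
		struct nlmsghdr *nlh;

		if (len <= 0)
			break;

		for (nlh = (struct nlmsghdr *)buf; NLMSG_OK(nlh, len);
		     nlh = NLMSG_NEXT(nlh, len)) {
			struct rtattr *rta;
			int alen;

			if (nlh->nlmsg_type == NLMSG_DONE ||
			    nlh->nlmsg_type == NLMSG_ERROR)
				goto done;
			if (nlh->nlmsg_type != RTM_NEWNEXTHOP)
				continue;

			/* attributes follow the fixed struct nhmsg header */
			rta = (struct rtattr *)((char *)NLMSG_DATA(nlh) +
					NLMSG_ALIGN(sizeof(struct nhmsg)));
			alen = nlh->nlmsg_len - NLMSG_LENGTH(sizeof(struct nhmsg));

			for (; RTA_OK(rta, alen); rta = RTA_NEXT(rta, alen)) {
				if (rta->rta_type == NHA_ID)
					printf("nexthop id %u\n",
					       *(unsigned int *)RTA_DATA(rta));
				else if (rta->rta_type == NHA_OIF)
					printf("  oif %u\n",
					       *(unsigned int *)RTA_DATA(rta));
			}
		}
	}
done:
	close(fd);
	return 0;
}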
+ if (err) + return err; + + if (tb[NHA_RES_BUCKET]) { + size_t max = ARRAY_SIZE(rtm_nh_res_bucket_policy_dump) - 1; + + err = nla_parse_nested(res_tb, max, + tb[NHA_RES_BUCKET], + rtm_nh_res_bucket_policy_dump, + cb->extack); + if (err < 0) + return err; + + err = nh_valid_dump_nhid(res_tb[NHA_RES_BUCKET_NH_ID], + &filter->res_bucket_nh_id, + cb->extack); + if (err) + return err; + } + + return __nh_valid_dump_req(nlh, tb, filter, cb->extack); +} + +struct rtm_dump_res_bucket_ctx { + struct rtm_dump_nh_ctx nh; + u16 bucket_index; + u32 done_nh_idx; /* 1 + the index of the last fully processed NH. */ +}; + +static struct rtm_dump_res_bucket_ctx * +rtm_dump_res_bucket_ctx(struct netlink_callback *cb) +{ + struct rtm_dump_res_bucket_ctx *ctx = (void *)cb->ctx; + + BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); + return ctx; +} + +struct rtm_dump_nexthop_bucket_data { + struct rtm_dump_res_bucket_ctx *ctx; + struct nh_dump_filter filter; +}; + +static int rtm_dump_nexthop_bucket_nh(struct sk_buff *skb, + struct netlink_callback *cb, + struct nexthop *nh, + struct rtm_dump_nexthop_bucket_data *dd) +{ + u32 portid = NETLINK_CB(cb->skb).portid; + struct nhmsg *nhm = nlmsg_data(cb->nlh); + struct nh_res_table *res_table; + struct nh_group *nhg; + u16 bucket_index; + int err; + + if (dd->ctx->nh.idx < dd->ctx->done_nh_idx) + return 0; + + nhg = rtnl_dereference(nh->nh_grp); + res_table = rtnl_dereference(nhg->res_table); + for (bucket_index = dd->ctx->bucket_index; + bucket_index < res_table->num_nh_buckets; + bucket_index++) { + struct nh_res_bucket *bucket; + struct nh_grp_entry *nhge; + + bucket = &res_table->nh_buckets[bucket_index]; + nhge = rtnl_dereference(bucket->nh_entry); + if (nh_dump_filtered(nhge->nh, &dd->filter, nhm->nh_family)) + continue; + + if (dd->filter.res_bucket_nh_id && + dd->filter.res_bucket_nh_id != nhge->nh->id) + continue; + + dd->ctx->bucket_index = bucket_index; + err = nh_fill_res_bucket(skb, nh, bucket, bucket_index, + RTM_NEWNEXTHOPBUCKET, portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->extack); + if (err) + return err; + } + + dd->ctx->done_nh_idx = dd->ctx->nh.idx + 1; + dd->ctx->bucket_index = 0; + + return 0; +} + +static int rtm_dump_nexthop_bucket_cb(struct sk_buff *skb, + struct netlink_callback *cb, + struct nexthop *nh, void *data) +{ + struct rtm_dump_nexthop_bucket_data *dd = data; + struct nh_group *nhg; + + if (!nh->is_group) + return 0; + + nhg = rtnl_dereference(nh->nh_grp); + if (!nhg->resilient) + return 0; + + return rtm_dump_nexthop_bucket_nh(skb, cb, nh, dd); +} + +/* rtnl */ +static int rtm_dump_nexthop_bucket(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct rtm_dump_res_bucket_ctx *ctx = rtm_dump_res_bucket_ctx(cb); + struct rtm_dump_nexthop_bucket_data dd = { .ctx = ctx }; + struct net *net = sock_net(skb->sk); + struct nexthop *nh; + int err; + + err = nh_valid_dump_bucket_req(cb->nlh, &dd.filter, cb); + if (err) + return err; + + if (dd.filter.nh_id) { + nh = nexthop_find_group_resilient(net, dd.filter.nh_id, + cb->extack); + if (IS_ERR(nh)) + return PTR_ERR(nh); + err = rtm_dump_nexthop_bucket_nh(skb, cb, nh, &dd); + } else { + struct rb_root *root = &net->nexthop.rb_root; + + err = rtm_dump_walk_nexthops(skb, cb, root, &ctx->nh, + &rtm_dump_nexthop_bucket_cb, &dd); + } + + if (err < 0) { + if (likely(skb->len)) + err = skb->len; + } + + cb->seq = net->nexthop.seq; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); + return err; +} + +static int nh_valid_get_bucket_req_res_bucket(struct nlattr *res, + u16 *bucket_index, + 
struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(rtm_nh_res_bucket_policy_get)]; + int err; + + err = nla_parse_nested(tb, ARRAY_SIZE(rtm_nh_res_bucket_policy_get) - 1, + res, rtm_nh_res_bucket_policy_get, extack); + if (err < 0) + return err; + + if (!tb[NHA_RES_BUCKET_INDEX]) { + NL_SET_ERR_MSG(extack, "Bucket index is missing"); + return -EINVAL; + } + + *bucket_index = nla_get_u16(tb[NHA_RES_BUCKET_INDEX]); + return 0; +} + +static int nh_valid_get_bucket_req(const struct nlmsghdr *nlh, + u32 *id, u16 *bucket_index, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[ARRAY_SIZE(rtm_nh_policy_get_bucket)]; + int err; + + err = nlmsg_parse(nlh, sizeof(struct nhmsg), tb, + ARRAY_SIZE(rtm_nh_policy_get_bucket) - 1, + rtm_nh_policy_get_bucket, extack); + if (err < 0) + return err; + + err = __nh_valid_get_del_req(nlh, tb, id, extack); + if (err) + return err; + + if (!tb[NHA_RES_BUCKET]) { + NL_SET_ERR_MSG(extack, "Bucket information is missing"); + return -EINVAL; + } + + err = nh_valid_get_bucket_req_res_bucket(tb[NHA_RES_BUCKET], + bucket_index, extack); + if (err) + return err; + + return 0; +} + +/* rtnl */ +static int rtm_get_nexthop_bucket(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nh_res_table *res_table; + struct sk_buff *skb = NULL; + struct nh_group *nhg; + struct nexthop *nh; + u16 bucket_index; + int err; + u32 id; + + err = nh_valid_get_bucket_req(nlh, &id, &bucket_index, extack); + if (err) + return err; + + nh = nexthop_find_group_resilient(net, id, extack); + if (IS_ERR(nh)) + return PTR_ERR(nh); + + nhg = rtnl_dereference(nh->nh_grp); + res_table = rtnl_dereference(nhg->res_table); + if (bucket_index >= res_table->num_nh_buckets) { + NL_SET_ERR_MSG(extack, "Bucket index out of bounds"); + return -ENOENT; + } + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + err = nh_fill_res_bucket(skb, nh, &res_table->nh_buckets[bucket_index], + bucket_index, RTM_NEWNEXTHOPBUCKET, + NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, + 0, extack); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + goto errout_free; + } + + return rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); + +errout_free: + kfree_skb(skb); + return err; +} + +static void nexthop_sync_mtu(struct net_device *dev, u32 orig_mtu) +{ + unsigned int hash = nh_dev_hashfn(dev->ifindex); + struct net *net = dev_net(dev); + struct hlist_head *head = &net->nexthop.devhash[hash]; + struct hlist_node *n; + struct nh_info *nhi; + + hlist_for_each_entry_safe(nhi, n, head, dev_hash) { + if (nhi->fib_nhc.nhc_dev == dev) { + if (nhi->family == AF_INET) + fib_nhc_update_mtu(&nhi->fib_nhc, dev->mtu, + orig_mtu); + } + } +} + +/* rtnl */ +static int nh_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netdev_notifier_info_ext *info_ext; + + switch (event) { + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + nexthop_flush_dev(dev, event); + break; + case NETDEV_CHANGE: + if (!(dev_get_flags(dev) & (IFF_RUNNING | IFF_LOWER_UP))) + nexthop_flush_dev(dev, event); + break; + case NETDEV_CHANGEMTU: + info_ext = ptr; + nexthop_sync_mtu(dev, info_ext->ext.mtu); + rt_cache_flush(dev_net(dev)); + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block nh_netdev_notifier = { + .notifier_call = nh_netdev_event, +}; + +static int nexthops_dump(struct net *net, struct notifier_block *nb, + enum 
nexthop_event_type event_type, + struct netlink_ext_ack *extack) +{ + struct rb_root *root = &net->nexthop.rb_root; + struct rb_node *node; + int err = 0; + + for (node = rb_first(root); node; node = rb_next(node)) { + struct nexthop *nh; + + nh = rb_entry(node, struct nexthop, rb_node); + err = call_nexthop_notifier(nb, net, event_type, nh, extack); + if (err) + break; + } + + return err; +} + +int register_nexthop_notifier(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + int err; + + rtnl_lock(); + err = nexthops_dump(net, nb, NEXTHOP_EVENT_REPLACE, extack); + if (err) + goto unlock; + err = blocking_notifier_chain_register(&net->nexthop.notifier_chain, + nb); +unlock: + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(register_nexthop_notifier); + +int unregister_nexthop_notifier(struct net *net, struct notifier_block *nb) +{ + int err; + + rtnl_lock(); + err = blocking_notifier_chain_unregister(&net->nexthop.notifier_chain, + nb); + if (err) + goto unlock; + nexthops_dump(net, nb, NEXTHOP_EVENT_DEL, NULL); +unlock: + rtnl_unlock(); + return err; +} +EXPORT_SYMBOL(unregister_nexthop_notifier); + +void nexthop_set_hw_flags(struct net *net, u32 id, bool offload, bool trap) +{ + struct nexthop *nexthop; + + rcu_read_lock(); + + nexthop = nexthop_find_by_id(net, id); + if (!nexthop) + goto out; + + nexthop->nh_flags &= ~(RTNH_F_OFFLOAD | RTNH_F_TRAP); + if (offload) + nexthop->nh_flags |= RTNH_F_OFFLOAD; + if (trap) + nexthop->nh_flags |= RTNH_F_TRAP; + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(nexthop_set_hw_flags); + +void nexthop_bucket_set_hw_flags(struct net *net, u32 id, u16 bucket_index, + bool offload, bool trap) +{ + struct nh_res_table *res_table; + struct nh_res_bucket *bucket; + struct nexthop *nexthop; + struct nh_group *nhg; + + rcu_read_lock(); + + nexthop = nexthop_find_by_id(net, id); + if (!nexthop || !nexthop->is_group) + goto out; + + nhg = rcu_dereference(nexthop->nh_grp); + if (!nhg->resilient) + goto out; + + if (bucket_index >= nhg->res_table->num_nh_buckets) + goto out; + + res_table = rcu_dereference(nhg->res_table); + bucket = &res_table->nh_buckets[bucket_index]; + bucket->nh_flags &= ~(RTNH_F_OFFLOAD | RTNH_F_TRAP); + if (offload) + bucket->nh_flags |= RTNH_F_OFFLOAD; + if (trap) + bucket->nh_flags |= RTNH_F_TRAP; + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(nexthop_bucket_set_hw_flags); + +void nexthop_res_grp_activity_update(struct net *net, u32 id, u16 num_buckets, + unsigned long *activity) +{ + struct nh_res_table *res_table; + struct nexthop *nexthop; + struct nh_group *nhg; + u16 i; + + rcu_read_lock(); + + nexthop = nexthop_find_by_id(net, id); + if (!nexthop || !nexthop->is_group) + goto out; + + nhg = rcu_dereference(nexthop->nh_grp); + if (!nhg->resilient) + goto out; + + /* Instead of silently ignoring some buckets, demand that the sizes + * be the same. 
+ */ + res_table = rcu_dereference(nhg->res_table); + if (num_buckets != res_table->num_nh_buckets) + goto out; + + for (i = 0; i < num_buckets; i++) { + if (test_bit(i, activity)) + nh_res_bucket_set_busy(&res_table->nh_buckets[i]); + } + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(nexthop_res_grp_activity_update); + +static void __net_exit nexthop_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + flush_all_nexthops(net); + kfree(net->nexthop.devhash); + } + rtnl_unlock(); +} + +static int __net_init nexthop_net_init(struct net *net) +{ + size_t sz = sizeof(struct hlist_head) * NH_DEV_HASHSIZE; + + net->nexthop.rb_root = RB_ROOT; + net->nexthop.devhash = kzalloc(sz, GFP_KERNEL); + if (!net->nexthop.devhash) + return -ENOMEM; + BLOCKING_INIT_NOTIFIER_HEAD(&net->nexthop.notifier_chain); + + return 0; +} + +static struct pernet_operations nexthop_net_ops = { + .init = nexthop_net_init, + .exit_batch = nexthop_net_exit_batch, +}; + +static int __init nexthop_init(void) +{ + register_pernet_subsys(&nexthop_net_ops); + + register_netdevice_notifier(&nh_netdev_notifier); + + rtnl_register(PF_UNSPEC, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELNEXTHOP, rtm_del_nexthop, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETNEXTHOP, rtm_get_nexthop, + rtm_dump_nexthop, 0); + + rtnl_register(PF_INET, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); + rtnl_register(PF_INET, RTM_GETNEXTHOP, NULL, rtm_dump_nexthop, 0); + + rtnl_register(PF_INET6, RTM_NEWNEXTHOP, rtm_new_nexthop, NULL, 0); + rtnl_register(PF_INET6, RTM_GETNEXTHOP, NULL, rtm_dump_nexthop, 0); + + rtnl_register(PF_UNSPEC, RTM_GETNEXTHOPBUCKET, rtm_get_nexthop_bucket, + rtm_dump_nexthop_bucket, 0); + + return 0; +} +subsys_initcall(nexthop_init); diff --git a/net/ipv4/ping.c b/net/ipv4/ping.c new file mode 100644 index 000000000..5178a3f3c --- /dev/null +++ b/net/ipv4/ping.c @@ -0,0 +1,1211 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * "Ping" sockets + * + * Based on ipv4/udp.c code. + * + * Authors: Vasiliy Kulikov / Openwall (for Linux 2.6), + * Pavel Kankovsky (for Linux 2.4.32) + * + * Pavel gave all rights to bugs to Vasiliy, + * none of the bugs are Pavel's now. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#include +#include +#include +#endif + +#define ping_portaddr_for_each_entry(__sk, node, list) \ + hlist_nulls_for_each_entry(__sk, node, list, sk_nulls_node) +#define ping_portaddr_for_each_entry_rcu(__sk, node, list) \ + hlist_nulls_for_each_entry_rcu(__sk, node, list, sk_nulls_node) + +struct ping_table { + struct hlist_nulls_head hash[PING_HTABLE_SIZE]; + spinlock_t lock; +}; + +static struct ping_table ping_table; +struct pingv6_ops pingv6_ops; +EXPORT_SYMBOL_GPL(pingv6_ops); + +static u16 ping_port_rover; + +static inline u32 ping_hashfn(const struct net *net, u32 num, u32 mask) +{ + u32 res = (num + net_hash_mix(net)) & mask; + + pr_debug("hash(%u) = %u\n", num, res); + return res; +} +EXPORT_SYMBOL_GPL(ping_hash); + +static inline struct hlist_nulls_head *ping_hashslot(struct ping_table *table, + struct net *net, unsigned int num) +{ + return &table->hash[ping_hashfn(net, num, PING_HTABLE_MASK)]; +} + +int ping_get_port(struct sock *sk, unsigned short ident) +{ + struct hlist_nulls_node *node; + struct hlist_nulls_head *hlist; + struct inet_sock *isk, *isk2; + struct sock *sk2 = NULL; + + isk = inet_sk(sk); + spin_lock(&ping_table.lock); + if (ident == 0) { + u32 i; + u16 result = ping_port_rover + 1; + + for (i = 0; i < (1L << 16); i++, result++) { + if (!result) + result++; /* avoid zero */ + hlist = ping_hashslot(&ping_table, sock_net(sk), + result); + ping_portaddr_for_each_entry(sk2, node, hlist) { + isk2 = inet_sk(sk2); + + if (isk2->inet_num == result) + goto next_port; + } + + /* found */ + ping_port_rover = ident = result; + break; +next_port: + ; + } + if (i >= (1L << 16)) + goto fail; + } else { + hlist = ping_hashslot(&ping_table, sock_net(sk), ident); + ping_portaddr_for_each_entry(sk2, node, hlist) { + isk2 = inet_sk(sk2); + + /* BUG? Why is this reuse and not reuseaddr? ping.c + * doesn't turn off SO_REUSEADDR, and it doesn't expect + * that other ping processes can steal its packets. + */ + if ((isk2->inet_num == ident) && + (sk2 != sk) && + (!sk2->sk_reuse || !sk->sk_reuse)) + goto fail; + } + } + + pr_debug("found port/ident = %d\n", ident); + isk->inet_num = ident; + if (sk_unhashed(sk)) { + pr_debug("was not hashed\n"); + sock_hold(sk); + sock_set_flag(sk, SOCK_RCU_FREE); + hlist_nulls_add_head_rcu(&sk->sk_nulls_node, hlist); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + } + spin_unlock(&ping_table.lock); + return 0; + +fail: + spin_unlock(&ping_table.lock); + return -EADDRINUSE; +} +EXPORT_SYMBOL_GPL(ping_get_port); + +int ping_hash(struct sock *sk) +{ + pr_debug("ping_hash(sk->port=%u)\n", inet_sk(sk)->inet_num); + BUG(); /* "Please do not press this button again." 
*/ + + return 0; +} + +void ping_unhash(struct sock *sk) +{ + struct inet_sock *isk = inet_sk(sk); + + pr_debug("ping_unhash(isk=%p,isk->num=%u)\n", isk, isk->inet_num); + spin_lock(&ping_table.lock); + if (sk_hashed(sk)) { + hlist_nulls_del_init_rcu(&sk->sk_nulls_node); + sock_put(sk); + isk->inet_num = 0; + isk->inet_sport = 0; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + } + spin_unlock(&ping_table.lock); +} +EXPORT_SYMBOL_GPL(ping_unhash); + +/* Called under rcu_read_lock() */ +static struct sock *ping_lookup(struct net *net, struct sk_buff *skb, u16 ident) +{ + struct hlist_nulls_head *hslot = ping_hashslot(&ping_table, net, ident); + struct sock *sk = NULL; + struct inet_sock *isk; + struct hlist_nulls_node *hnode; + int dif, sdif; + + if (skb->protocol == htons(ETH_P_IP)) { + dif = inet_iif(skb); + sdif = inet_sdif(skb); + pr_debug("try to find: num = %d, daddr = %pI4, dif = %d\n", + (int)ident, &ip_hdr(skb)->daddr, dif); +#if IS_ENABLED(CONFIG_IPV6) + } else if (skb->protocol == htons(ETH_P_IPV6)) { + dif = inet6_iif(skb); + sdif = inet6_sdif(skb); + pr_debug("try to find: num = %d, daddr = %pI6c, dif = %d\n", + (int)ident, &ipv6_hdr(skb)->daddr, dif); +#endif + } else { + return NULL; + } + + ping_portaddr_for_each_entry_rcu(sk, hnode, hslot) { + isk = inet_sk(sk); + + pr_debug("iterate\n"); + if (isk->inet_num != ident) + continue; + + if (skb->protocol == htons(ETH_P_IP) && + sk->sk_family == AF_INET) { + pr_debug("found: %p: num=%d, daddr=%pI4, dif=%d\n", sk, + (int) isk->inet_num, &isk->inet_rcv_saddr, + sk->sk_bound_dev_if); + + if (isk->inet_rcv_saddr && + isk->inet_rcv_saddr != ip_hdr(skb)->daddr) + continue; +#if IS_ENABLED(CONFIG_IPV6) + } else if (skb->protocol == htons(ETH_P_IPV6) && + sk->sk_family == AF_INET6) { + + pr_debug("found: %p: num=%d, daddr=%pI6c, dif=%d\n", sk, + (int) isk->inet_num, + &sk->sk_v6_rcv_saddr, + sk->sk_bound_dev_if); + + if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr) && + !ipv6_addr_equal(&sk->sk_v6_rcv_saddr, + &ipv6_hdr(skb)->daddr)) + continue; +#endif + } else { + continue; + } + + if (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif && + sk->sk_bound_dev_if != sdif) + continue; + + goto exit; + } + + sk = NULL; +exit: + + return sk; +} + +static void inet_get_ping_group_range_net(struct net *net, kgid_t *low, + kgid_t *high) +{ + kgid_t *data = net->ipv4.ping_group_range.range; + unsigned int seq; + + do { + seq = read_seqbegin(&net->ipv4.ping_group_range.lock); + + *low = data[0]; + *high = data[1]; + } while (read_seqretry(&net->ipv4.ping_group_range.lock, seq)); +} + + +int ping_init_sock(struct sock *sk) +{ + struct net *net = sock_net(sk); + kgid_t group = current_egid(); + struct group_info *group_info; + int i; + kgid_t low, high; + int ret = 0; + + if (sk->sk_family == AF_INET6) + sk->sk_ipv6only = 1; + + inet_get_ping_group_range_net(net, &low, &high); + if (gid_lte(low, group) && gid_lte(group, high)) + return 0; + + group_info = get_current_groups(); + for (i = 0; i < group_info->ngroups; i++) { + kgid_t gid = group_info->gid[i]; + + if (gid_lte(low, gid) && gid_lte(gid, high)) + goto out_release_group; + } + + ret = -EACCES; + +out_release_group: + put_group_info(group_info); + return ret; +} +EXPORT_SYMBOL_GPL(ping_init_sock); + +void ping_close(struct sock *sk, long timeout) +{ + pr_debug("ping_close(sk=%p,sk->num=%u)\n", + inet_sk(sk), inet_sk(sk)->inet_num); + pr_debug("isk->refcnt = %d\n", refcount_read(&sk->sk_refcnt)); + + sk_common_release(sk); +} +EXPORT_SYMBOL_GPL(ping_close); + +static int 
ping_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from __ip4_datagram_connect() and + * intended to prevent BPF program called below from accessing bytes + * that are out of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr); +} + +/* Checks the bind address and possibly modifies sk->sk_bound_dev_if. */ +static int ping_check_bind_addr(struct sock *sk, struct inet_sock *isk, + struct sockaddr *uaddr, int addr_len) +{ + struct net *net = sock_net(sk); + if (sk->sk_family == AF_INET) { + struct sockaddr_in *addr = (struct sockaddr_in *) uaddr; + u32 tb_id = RT_TABLE_LOCAL; + int chk_addr_ret; + + if (addr_len < sizeof(*addr)) + return -EINVAL; + + if (addr->sin_family != AF_INET && + !(addr->sin_family == AF_UNSPEC && + addr->sin_addr.s_addr == htonl(INADDR_ANY))) + return -EAFNOSUPPORT; + + pr_debug("ping_check_bind_addr(sk=%p,addr=%pI4,port=%d)\n", + sk, &addr->sin_addr.s_addr, ntohs(addr->sin_port)); + + if (addr->sin_addr.s_addr == htonl(INADDR_ANY)) + return 0; + + tb_id = l3mdev_fib_table_by_index(net, sk->sk_bound_dev_if) ? : tb_id; + chk_addr_ret = inet_addr_type_table(net, addr->sin_addr.s_addr, tb_id); + + if (chk_addr_ret == RTN_MULTICAST || + chk_addr_ret == RTN_BROADCAST || + (chk_addr_ret != RTN_LOCAL && + !inet_can_nonlocal_bind(net, isk))) + return -EADDRNOTAVAIL; + +#if IS_ENABLED(CONFIG_IPV6) + } else if (sk->sk_family == AF_INET6) { + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) uaddr; + int addr_type, scoped, has_addr; + struct net_device *dev = NULL; + + if (addr_len < sizeof(*addr)) + return -EINVAL; + + if (addr->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + pr_debug("ping_check_bind_addr(sk=%p,addr=%pI6c,port=%d)\n", + sk, addr->sin6_addr.s6_addr, ntohs(addr->sin6_port)); + + addr_type = ipv6_addr_type(&addr->sin6_addr); + scoped = __ipv6_addr_needs_scope_id(addr_type); + if ((addr_type != IPV6_ADDR_ANY && + !(addr_type & IPV6_ADDR_UNICAST)) || + (scoped && !addr->sin6_scope_id)) + return -EINVAL; + + rcu_read_lock(); + if (addr->sin6_scope_id) { + dev = dev_get_by_index_rcu(net, addr->sin6_scope_id); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } + } + + if (!dev && sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(net, sk->sk_bound_dev_if); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } + } + has_addr = pingv6_ops.ipv6_chk_addr(net, &addr->sin6_addr, dev, + scoped); + rcu_read_unlock(); + + if (!(ipv6_can_nonlocal_bind(net, isk) || has_addr || + addr_type == IPV6_ADDR_ANY)) + return -EADDRNOTAVAIL; + + if (scoped) + sk->sk_bound_dev_if = addr->sin6_scope_id; +#endif + } else { + return -EAFNOSUPPORT; + } + return 0; +} + +static void ping_set_saddr(struct sock *sk, struct sockaddr *saddr) +{ + if (saddr->sa_family == AF_INET) { + struct inet_sock *isk = inet_sk(sk); + struct sockaddr_in *addr = (struct sockaddr_in *) saddr; + isk->inet_rcv_saddr = isk->inet_saddr = addr->sin_addr.s_addr; +#if IS_ENABLED(CONFIG_IPV6) + } else if (saddr->sa_family == AF_INET6) { + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) saddr; + struct ipv6_pinfo *np = inet6_sk(sk); + sk->sk_v6_rcv_saddr = np->saddr = addr->sin6_addr; +#endif + } +} + +/* + * We need our own bind because there are no privileged id's == local ports. + * Moreover, we don't allow binding to multi- and broadcast addresses. 
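/*
 * Illustrative sketch (editor's note, not part of the imported file): from
 * userspace, the bind path described above and continued below is reached
 * through an ordinary bind() on an unprivileged ICMP datagram socket. The
 * sin_port field is not a port here; it carries the ICMP echo identifier,
 * and passing 0 asks ping_get_port() to pick a free one, which getsockname()
 * then reports back. Creating the socket requires the caller's group to fall
 * inside the net.ipv4.ping_group_range sysctl.
 */
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	struct sockaddr_in addr;
	socklen_t alen = sizeof(addr);
	int fd;

	fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP);
	if (fd < 0) {
		perror("socket");	/* likely EACCES: check ping_group_range */
		return 1;
	}

	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_ANY);
	addr.sin_port = 0;		/* 0 = let the kernel pick an ident */

	if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("bind");
		close(fd);
		return 1;
	}

	if (getsockname(fd, (struct sockaddr *)&addr, &alen) == 0)
		printf("assigned icmp ident: %u\n", ntohs(addr.sin_port));

	close(fd);
	return 0;
}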
+ */ + +int ping_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *isk = inet_sk(sk); + unsigned short snum; + int err; + int dif = sk->sk_bound_dev_if; + + err = ping_check_bind_addr(sk, isk, uaddr, addr_len); + if (err) + return err; + + lock_sock(sk); + + err = -EINVAL; + if (isk->inet_num != 0) + goto out; + + err = -EADDRINUSE; + snum = ntohs(((struct sockaddr_in *)uaddr)->sin_port); + if (ping_get_port(sk, snum) != 0) { + /* Restore possibly modified sk->sk_bound_dev_if by ping_check_bind_addr(). */ + sk->sk_bound_dev_if = dif; + goto out; + } + ping_set_saddr(sk, uaddr); + + pr_debug("after bind(): num = %hu, dif = %d\n", + isk->inet_num, + sk->sk_bound_dev_if); + + err = 0; + if (sk->sk_family == AF_INET && isk->inet_rcv_saddr) + sk->sk_userlocks |= SOCK_BINDADDR_LOCK; +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6 && !ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + sk->sk_userlocks |= SOCK_BINDADDR_LOCK; +#endif + + if (snum) + sk->sk_userlocks |= SOCK_BINDPORT_LOCK; + isk->inet_sport = htons(isk->inet_num); + isk->inet_daddr = 0; + isk->inet_dport = 0; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + memset(&sk->sk_v6_daddr, 0, sizeof(sk->sk_v6_daddr)); +#endif + + sk_dst_reset(sk); +out: + release_sock(sk); + pr_debug("ping_v4_bind -> %d\n", err); + return err; +} +EXPORT_SYMBOL_GPL(ping_bind); + +/* + * Is this a supported type of ICMP message? + */ + +static inline int ping_supported(int family, int type, int code) +{ + return (family == AF_INET && type == ICMP_ECHO && code == 0) || + (family == AF_INET && type == ICMP_EXT_ECHO && code == 0) || + (family == AF_INET6 && type == ICMPV6_ECHO_REQUEST && code == 0) || + (family == AF_INET6 && type == ICMPV6_EXT_ECHO_REQUEST && code == 0); +} + +/* + * This routine is called by the ICMP module when it gets some + * sort of error condition. + */ + +void ping_err(struct sk_buff *skb, int offset, u32 info) +{ + int family; + struct icmphdr *icmph; + struct inet_sock *inet_sock; + int type; + int code; + struct net *net = dev_net(skb->dev); + struct sock *sk; + int harderr; + int err; + + if (skb->protocol == htons(ETH_P_IP)) { + family = AF_INET; + type = icmp_hdr(skb)->type; + code = icmp_hdr(skb)->code; + icmph = (struct icmphdr *)(skb->data + offset); + } else if (skb->protocol == htons(ETH_P_IPV6)) { + family = AF_INET6; + type = icmp6_hdr(skb)->icmp6_type; + code = icmp6_hdr(skb)->icmp6_code; + icmph = (struct icmphdr *) (skb->data + offset); + } else { + BUG(); + } + + /* We assume the packet has already been checked by icmp_unreach */ + + if (!ping_supported(family, icmph->type, icmph->code)) + return; + + pr_debug("ping_err(proto=0x%x,type=%d,code=%d,id=%04x,seq=%04x)\n", + skb->protocol, type, code, ntohs(icmph->un.echo.id), + ntohs(icmph->un.echo.sequence)); + + sk = ping_lookup(net, skb, ntohs(icmph->un.echo.id)); + if (!sk) { + pr_debug("no socket, dropping\n"); + return; /* No socket for error */ + } + pr_debug("err on socket %p\n", sk); + + err = 0; + harderr = 0; + inet_sock = inet_sk(sk); + + if (skb->protocol == htons(ETH_P_IP)) { + switch (type) { + default: + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + case ICMP_SOURCE_QUENCH: + /* This is not a real error but ping wants to see it. + * Report it with some fake errno. 
+ */ + err = EREMOTEIO; + break; + case ICMP_PARAMETERPROB: + err = EPROTO; + harderr = 1; + break; + case ICMP_DEST_UNREACH: + if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */ + ipv4_sk_update_pmtu(skb, sk, info); + if (inet_sock->pmtudisc != IP_PMTUDISC_DONT) { + err = EMSGSIZE; + harderr = 1; + break; + } + goto out; + } + err = EHOSTUNREACH; + if (code <= NR_ICMP_UNREACH) { + harderr = icmp_err_convert[code].fatal; + err = icmp_err_convert[code].errno; + } + break; + case ICMP_REDIRECT: + /* See ICMP_SOURCE_QUENCH */ + ipv4_sk_redirect(skb, sk); + err = EREMOTEIO; + break; + } +#if IS_ENABLED(CONFIG_IPV6) + } else if (skb->protocol == htons(ETH_P_IPV6)) { + harderr = pingv6_ops.icmpv6_err_convert(type, code, &err); +#endif + } + + /* + * RFC1122: OK. Passes ICMP errors back to application, as per + * 4.1.3.3. + */ + if ((family == AF_INET && !inet_sock->recverr) || + (family == AF_INET6 && !inet6_sk(sk)->recverr)) { + if (!harderr || sk->sk_state != TCP_ESTABLISHED) + goto out; + } else { + if (family == AF_INET) { + ip_icmp_error(sk, skb, err, 0 /* no remote port */, + info, (u8 *)icmph); +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + pingv6_ops.ipv6_icmp_error(sk, skb, err, 0, + info, (u8 *)icmph); +#endif + } + } + sk->sk_err = err; + sk_error_report(sk); +out: + return; +} +EXPORT_SYMBOL_GPL(ping_err); + +/* + * Copy and checksum an ICMP Echo packet from user space into a buffer + * starting from the payload. + */ + +int ping_getfrag(void *from, char *to, + int offset, int fraglen, int odd, struct sk_buff *skb) +{ + struct pingfakehdr *pfh = from; + + if (!csum_and_copy_from_iter_full(to, fraglen, &pfh->wcheck, + &pfh->msg->msg_iter)) + return -EFAULT; + +#if IS_ENABLED(CONFIG_IPV6) + /* For IPv6, checksum each skb as we go along, as expected by + * icmpv6_push_pending_frames. For IPv4, accumulate the checksum in + * wcheck, it will be finalized in ping_v4_push_pending_frames. + */ + if (pfh->family == AF_INET6) { + skb->csum = csum_block_add(skb->csum, pfh->wcheck, odd); + skb->ip_summed = CHECKSUM_NONE; + pfh->wcheck = 0; + } +#endif + + return 0; +} +EXPORT_SYMBOL_GPL(ping_getfrag); + +static int ping_v4_push_pending_frames(struct sock *sk, struct pingfakehdr *pfh, + struct flowi4 *fl4) +{ + struct sk_buff *skb = skb_peek(&sk->sk_write_queue); + + if (!skb) + return 0; + pfh->wcheck = csum_partial((char *)&pfh->icmph, + sizeof(struct icmphdr), pfh->wcheck); + pfh->icmph.checksum = csum_fold(pfh->wcheck); + memcpy(icmp_hdr(skb), &pfh->icmph, sizeof(struct icmphdr)); + skb->ip_summed = CHECKSUM_NONE; + return ip_push_pending_frames(sk, fl4); +} + +int ping_common_sendmsg(int family, struct msghdr *msg, size_t len, + void *user_icmph, size_t icmph_len) +{ + u8 type, code; + + if (len > 0xFFFF) + return -EMSGSIZE; + + /* Must have at least a full ICMP header. */ + if (len < icmph_len) + return -EINVAL; + + /* + * Check the flags. + */ + + /* Mirror BSD error message compatibility */ + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* + * Fetch the ICMP header provided by the userland. + * iovec is modified! The ICMP header is consumed. 
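/*
 * Illustrative sketch (editor's note, not part of the imported file): the
 * wcheck accumulation in ping_getfrag() plus the final csum_fold() in
 * ping_v4_push_pending_frames() compute the standard Internet checksum
 * (RFC 1071) over the ICMP header and payload. The same result, written as
 * a plain userspace function over an example echo request:
 */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* 16-bit one's-complement sum over buf, carries folded back in, inverted */
static uint16_t inet_csum(const void *buf, size_t len)
{
	const uint8_t *p = buf;
	uint32_t sum = 0;

	while (len > 1) {
		sum += (uint32_t)p[0] << 8 | p[1];
		p += 2;
		len -= 2;
	}
	if (len)			/* odd trailing byte, zero-padded */
		sum += (uint32_t)p[0] << 8;

	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);

	return (uint16_t)~sum;
}

int main(void)
{
	/* ICMP echo request: type 8, code 0, csum 0, id 1, seq 1, data "hi" */
	uint8_t pkt[] = { 8, 0, 0, 0, 0, 1, 0, 1, 'h', 'i' };
	uint16_t csum = inet_csum(pkt, sizeof(pkt));

	/* store the checksum in network byte order */
	pkt[2] = csum >> 8;
	pkt[3] = csum & 0xff;
	printf("icmp checksum: 0x%04x\n", csum);

	/* re-summing with the checksum in place must yield 0 */
	printf("verify: 0x%04x\n", inet_csum(pkt, sizeof(pkt)));
	return 0;
}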
+ */ + if (memcpy_from_msg(user_icmph, msg, icmph_len)) + return -EFAULT; + + if (family == AF_INET) { + type = ((struct icmphdr *) user_icmph)->type; + code = ((struct icmphdr *) user_icmph)->code; +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + type = ((struct icmp6hdr *) user_icmph)->icmp6_type; + code = ((struct icmp6hdr *) user_icmph)->icmp6_code; +#endif + } else { + BUG(); + } + + if (!ping_supported(family, type, code)) + return -EINVAL; + + return 0; +} +EXPORT_SYMBOL_GPL(ping_common_sendmsg); + +static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct net *net = sock_net(sk); + struct flowi4 fl4; + struct inet_sock *inet = inet_sk(sk); + struct ipcm_cookie ipc; + struct icmphdr user_icmph; + struct pingfakehdr pfh; + struct rtable *rt = NULL; + struct ip_options_data opt_copy; + int free = 0; + __be32 saddr, daddr, faddr; + u8 tos; + int err; + + pr_debug("ping_v4_sendmsg(sk=%p,sk->num=%u)\n", inet, inet->inet_num); + + err = ping_common_sendmsg(AF_INET, msg, len, &user_icmph, + sizeof(user_icmph)); + if (err) + return err; + + /* + * Get and verify the address. + */ + + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_in *, usin, msg->msg_name); + if (msg->msg_namelen < sizeof(*usin)) + return -EINVAL; + if (usin->sin_family != AF_INET) + return -EAFNOSUPPORT; + daddr = usin->sin_addr.s_addr; + /* no remote port */ + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + daddr = inet->inet_daddr; + /* no remote port */ + } + + ipcm_init_sk(&ipc, inet); + + if (msg->msg_controllen) { + err = ip_cmsg_send(sk, msg, &ipc, false); + if (unlikely(err)) { + kfree(ipc.opt); + return err; + } + if (ipc.opt) + free = 1; + } + if (!ipc.opt) { + struct ip_options_rcu *inet_opt; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt) { + memcpy(&opt_copy, inet_opt, + sizeof(*inet_opt) + inet_opt->opt.optlen); + ipc.opt = &opt_copy.opt; + } + rcu_read_unlock(); + } + + saddr = ipc.addr; + ipc.addr = faddr = daddr; + + if (ipc.opt && ipc.opt->opt.srr) { + if (!daddr) { + err = -EINVAL; + goto out_free; + } + faddr = ipc.opt->opt.faddr; + } + tos = get_rttos(&ipc, inet); + if (sock_flag(sk, SOCK_LOCALROUTE) || + (msg->msg_flags & MSG_DONTROUTE) || + (ipc.opt && ipc.opt->opt.is_strictroute)) { + tos |= RTO_ONLINK; + } + + if (ipv4_is_multicast(daddr)) { + if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) + ipc.oif = inet->mc_index; + if (!saddr) + saddr = inet->mc_addr; + } else if (!ipc.oif) + ipc.oif = inet->uc_index; + + flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, + RT_SCOPE_UNIVERSE, sk->sk_protocol, + inet_sk_flowi_flags(sk), faddr, saddr, 0, 0, + sk->sk_uid); + + fl4.fl4_icmp_type = user_icmph.type; + fl4.fl4_icmp_code = user_icmph.code; + + security_sk_classify_flow(sk, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_flow(net, &fl4, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + if (err == -ENETUNREACH) + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + goto out; + } + + err = -EACCES; + if ((rt->rt_flags & RTCF_BROADCAST) && + !sock_flag(sk, SOCK_BROADCAST)) + goto out; + + if (msg->msg_flags & MSG_CONFIRM) + goto do_confirm; +back_from_confirm: + + if (!ipc.addr) + ipc.addr = fl4.daddr; + + lock_sock(sk); + + pfh.icmph.type = user_icmph.type; /* already checked */ + pfh.icmph.code = user_icmph.code; /* ditto */ + pfh.icmph.checksum = 0; + pfh.icmph.un.echo.id = inet->inet_sport; + pfh.icmph.un.echo.sequence = user_icmph.un.echo.sequence; + 
pfh.msg = msg; + pfh.wcheck = 0; + pfh.family = AF_INET; + + err = ip_append_data(sk, &fl4, ping_getfrag, &pfh, len, + sizeof(struct icmphdr), &ipc, &rt, + msg->msg_flags); + if (err) + ip_flush_pending_frames(sk); + else + err = ping_v4_push_pending_frames(sk, &pfh, &fl4); + release_sock(sk); + +out: + ip_rt_put(rt); +out_free: + if (free) + kfree(ipc.opt); + if (!err) { + icmp_out_count(sock_net(sk), user_icmph.type); + return len; + } + return err; + +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(&rt->dst, &fl4.daddr); + if (!(msg->msg_flags & MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto out; +} + +int ping_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) +{ + struct inet_sock *isk = inet_sk(sk); + int family = sk->sk_family; + struct sk_buff *skb; + int copied, err; + + pr_debug("ping_recvmsg(sk=%p,sk->num=%u)\n", isk, isk->inet_num); + + err = -EOPNOTSUPP; + if (flags & MSG_OOB) + goto out; + + if (flags & MSG_ERRQUEUE) + return inet_recv_error(sk, msg, len, addr_len); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + /* Don't bother checking the checksum */ + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_timestamp(msg, sk, skb); + + /* Copy the address and add cmsg data. */ + if (family == AF_INET) { + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + + if (sin) { + sin->sin_family = AF_INET; + sin->sin_port = 0 /* skb->h.uh->source */; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + *addr_len = sizeof(*sin); + } + + if (isk->cmsg_flags) + ip_cmsg_recv(msg, skb); + +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6hdr *ip6 = ipv6_hdr(skb); + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + + if (sin6) { + sin6->sin6_family = AF_INET6; + sin6->sin6_port = 0; + sin6->sin6_addr = ip6->saddr; + sin6->sin6_flowinfo = 0; + if (np->sndflow) + sin6->sin6_flowinfo = ip6_flowinfo(ip6); + sin6->sin6_scope_id = + ipv6_iface_scope_id(&sin6->sin6_addr, + inet6_iif(skb)); + *addr_len = sizeof(*sin6); + } + + if (inet6_sk(sk)->rxopt.all) + pingv6_ops.ip6_datagram_recv_common_ctl(sk, msg, skb); + if (skb->protocol == htons(ETH_P_IPV6) && + inet6_sk(sk)->rxopt.all) + pingv6_ops.ip6_datagram_recv_specific_ctl(sk, msg, skb); + else if (skb->protocol == htons(ETH_P_IP) && isk->cmsg_flags) + ip_cmsg_recv(msg, skb); +#endif + } else { + BUG(); + } + + err = copied; + +done: + skb_free_datagram(sk, skb); +out: + pr_debug("ping_recvmsg -> %d\n", err); + return err; +} +EXPORT_SYMBOL_GPL(ping_recvmsg); + +static enum skb_drop_reason __ping_queue_rcv_skb(struct sock *sk, + struct sk_buff *skb) +{ + enum skb_drop_reason reason; + + pr_debug("ping_queue_rcv_skb(sk=%p,sk->num=%d,skb=%p)\n", + inet_sk(sk), inet_sk(sk)->inet_num, skb); + if (sock_queue_rcv_skb_reason(sk, skb, &reason) < 0) { + kfree_skb_reason(skb, reason); + pr_debug("ping_queue_rcv_skb -> failed\n"); + return reason; + } + return SKB_NOT_DROPPED_YET; +} + +int ping_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + return __ping_queue_rcv_skb(sk, skb) ? -1 : 0; +} +EXPORT_SYMBOL_GPL(ping_queue_rcv_skb); + + +/* + * All we need to do is get the socket. 
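ping_v4_sendmsg() and ping_recvmsg() above are what service unprivileged ICMP "ping" sockets. A minimal sketch of one echo round trip, assuming net.ipv4.ping_group_range covers the caller's group; 192.0.2.1 is a documentation placeholder address and the final recv() simply blocks until a reply arrives:

#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip_icmp.h>

int main(void)
{
	struct sockaddr_in dst = { .sin_family = AF_INET };
	struct icmphdr req = { .type = ICMP_ECHO, .code = 0 };
	char reply[1500];
	ssize_t n;
	int fd;

	inet_pton(AF_INET, "192.0.2.1", &dst.sin_addr);

	fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP);
	if (fd < 0) {
		perror("ping socket");	/* ping_group_range probably excludes us */
		return 1;
	}

	/* The kernel overwrites the echo id with the socket's own identifier
	 * (ping_v4_sendmsg() above) and computes the checksum, so only
	 * type, code and sequence matter here.
	 */
	req.un.echo.sequence = htons(1);
	sendto(fd, &req, sizeof(req), 0, (struct sockaddr *)&dst, sizeof(dst));

	/* The queued reply starts at the ICMP header; no IP header is included. */
	n = recv(fd, reply, sizeof(reply), 0);
	if (n >= (ssize_t)sizeof(struct icmphdr)) {
		struct icmphdr ans;

		memcpy(&ans, reply, sizeof(ans));
		printf("type=%u seq=%u\n", ans.type, ntohs(ans.un.echo.sequence));
	}

	close(fd);
	return 0;
}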
+ */ + +enum skb_drop_reason ping_rcv(struct sk_buff *skb) +{ + enum skb_drop_reason reason = SKB_DROP_REASON_NO_SOCKET; + struct sock *sk; + struct net *net = dev_net(skb->dev); + struct icmphdr *icmph = icmp_hdr(skb); + + /* We assume the packet has already been checked by icmp_rcv */ + + pr_debug("ping_rcv(skb=%p,id=%04x,seq=%04x)\n", + skb, ntohs(icmph->un.echo.id), ntohs(icmph->un.echo.sequence)); + + /* Push ICMP header back */ + skb_push(skb, skb->data - (u8 *)icmph); + + sk = ping_lookup(net, skb, ntohs(icmph->un.echo.id)); + if (sk) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + + pr_debug("rcv on socket %p\n", sk); + if (skb2) + reason = __ping_queue_rcv_skb(sk, skb2); + else + reason = SKB_DROP_REASON_NOMEM; + } + + if (reason) + pr_debug("no socket, dropping\n"); + + return reason; +} +EXPORT_SYMBOL_GPL(ping_rcv); + +struct proto ping_prot = { + .name = "PING", + .owner = THIS_MODULE, + .init = ping_init_sock, + .close = ping_close, + .pre_connect = ping_pre_connect, + .connect = ip4_datagram_connect, + .disconnect = __udp_disconnect, + .setsockopt = ip_setsockopt, + .getsockopt = ip_getsockopt, + .sendmsg = ping_v4_sendmsg, + .recvmsg = ping_recvmsg, + .bind = ping_bind, + .backlog_rcv = ping_queue_rcv_skb, + .release_cb = ip4_datagram_release_cb, + .hash = ping_hash, + .unhash = ping_unhash, + .get_port = ping_get_port, + .put_port = ping_unhash, + .obj_size = sizeof(struct inet_sock), +}; +EXPORT_SYMBOL(ping_prot); + +#ifdef CONFIG_PROC_FS + +static struct sock *ping_get_first(struct seq_file *seq, int start) +{ + struct sock *sk; + struct ping_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + + for (state->bucket = start; state->bucket < PING_HTABLE_SIZE; + ++state->bucket) { + struct hlist_nulls_node *node; + struct hlist_nulls_head *hslot; + + hslot = &ping_table.hash[state->bucket]; + + if (hlist_nulls_empty(hslot)) + continue; + + sk_nulls_for_each(sk, node, hslot) { + if (net_eq(sock_net(sk), net) && + sk->sk_family == state->family) + goto found; + } + } + sk = NULL; +found: + return sk; +} + +static struct sock *ping_get_next(struct seq_file *seq, struct sock *sk) +{ + struct ping_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + + do { + sk = sk_nulls_next(sk); + } while (sk && (!net_eq(sock_net(sk), net))); + + if (!sk) + return ping_get_first(seq, state->bucket + 1); + return sk; +} + +static struct sock *ping_get_idx(struct seq_file *seq, loff_t pos) +{ + struct sock *sk = ping_get_first(seq, 0); + + if (sk) + while (pos && (sk = ping_get_next(seq, sk)) != NULL) + --pos; + return pos ? NULL : sk; +} + +void *ping_seq_start(struct seq_file *seq, loff_t *pos, sa_family_t family) + __acquires(ping_table.lock) +{ + struct ping_iter_state *state = seq->private; + state->bucket = 0; + state->family = family; + + spin_lock(&ping_table.lock); + + return *pos ? 
ping_get_idx(seq, *pos-1) : SEQ_START_TOKEN; +} +EXPORT_SYMBOL_GPL(ping_seq_start); + +static void *ping_v4_seq_start(struct seq_file *seq, loff_t *pos) +{ + return ping_seq_start(seq, pos, AF_INET); +} + +void *ping_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock *sk; + + if (v == SEQ_START_TOKEN) + sk = ping_get_idx(seq, 0); + else + sk = ping_get_next(seq, v); + + ++*pos; + return sk; +} +EXPORT_SYMBOL_GPL(ping_seq_next); + +void ping_seq_stop(struct seq_file *seq, void *v) + __releases(ping_table.lock) +{ + spin_unlock(&ping_table.lock); +} +EXPORT_SYMBOL_GPL(ping_seq_stop); + +static void ping_v4_format_sock(struct sock *sp, struct seq_file *f, + int bucket) +{ + struct inet_sock *inet = inet_sk(sp); + __be32 dest = inet->inet_daddr; + __be32 src = inet->inet_rcv_saddr; + __u16 destp = ntohs(inet->inet_dport); + __u16 srcp = ntohs(inet->inet_sport); + + seq_printf(f, "%5d: %08X:%04X %08X:%04X" + " %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u", + bucket, src, srcp, dest, destp, sp->sk_state, + sk_wmem_alloc_get(sp), + sk_rmem_alloc_get(sp), + 0, 0L, 0, + from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)), + 0, sock_i_ino(sp), + refcount_read(&sp->sk_refcnt), sp, + atomic_read(&sp->sk_drops)); +} + +static int ping_v4_seq_show(struct seq_file *seq, void *v) +{ + seq_setwidth(seq, 127); + if (v == SEQ_START_TOKEN) + seq_puts(seq, " sl local_address rem_address st tx_queue " + "rx_queue tr tm->when retrnsmt uid timeout " + "inode ref pointer drops"); + else { + struct ping_iter_state *state = seq->private; + + ping_v4_format_sock(v, seq, state->bucket); + } + seq_pad(seq, '\n'); + return 0; +} + +static const struct seq_operations ping_v4_seq_ops = { + .start = ping_v4_seq_start, + .show = ping_v4_seq_show, + .next = ping_seq_next, + .stop = ping_seq_stop, +}; + +static int __net_init ping_v4_proc_init_net(struct net *net) +{ + if (!proc_create_net("icmp", 0444, net->proc_net, &ping_v4_seq_ops, + sizeof(struct ping_iter_state))) + return -ENOMEM; + return 0; +} + +static void __net_exit ping_v4_proc_exit_net(struct net *net) +{ + remove_proc_entry("icmp", net->proc_net); +} + +static struct pernet_operations ping_v4_net_ops = { + .init = ping_v4_proc_init_net, + .exit = ping_v4_proc_exit_net, +}; + +int __init ping_proc_init(void) +{ + return register_pernet_subsys(&ping_v4_net_ops); +} + +void ping_proc_exit(void) +{ + unregister_pernet_subsys(&ping_v4_net_ops); +} + +#endif + +void __init ping_init(void) +{ + int i; + + for (i = 0; i < PING_HTABLE_SIZE; i++) + INIT_HLIST_NULLS_HEAD(&ping_table.hash[i], i); + spin_lock_init(&ping_table.lock); +} diff --git a/net/ipv4/proc.c b/net/ipv4/proc.c new file mode 100644 index 000000000..5386f460b --- /dev/null +++ b/net/ipv4/proc.c @@ -0,0 +1,557 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * This file implements the various access functions for the + * PROC file system. It is mainly used for debugging and + * statistics. + * + * Authors: Fred N. van Kempen, + * Gerald J. Heim, + * Fred Baumgarten, + * Erik Schoenfelder, + * + * Fixes: + * Alan Cox : UDP sockets show the rxqueue/txqueue + * using hint flag for the netinfo. + * Pauline Middelink : identd support + * Alan Cox : Make /proc safer. + * Erik Schoenfelder : /proc/net/snmp + * Alan Cox : Handle dead sockets properly. 
+ * Gerhard Koerting : Show both timers + * Alan Cox : Allow inode to be NULL (kernel socket) + * Andi Kleen : Add support for open_requests and + * split functions for more readibility. + * Andi Kleen : Add support for /proc/net/netstat + * Arnaldo C. Melo : Convert to seq_file + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TCPUDP_MIB_MAX max_t(u32, UDP_MIB_MAX, TCP_MIB_MAX) + +/* + * Report socket allocation statistics [mea@utu.fi] + */ +static int sockstat_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = seq->private; + int orphans, sockets; + + orphans = tcp_orphan_count_sum(); + sockets = proto_sockets_allocated_sum_positive(&tcp_prot); + + socket_seq_show(seq); + seq_printf(seq, "TCP: inuse %d orphan %d tw %d alloc %d mem %ld\n", + sock_prot_inuse_get(net, &tcp_prot), orphans, + refcount_read(&net->ipv4.tcp_death_row.tw_refcount) - 1, + sockets, proto_memory_allocated(&tcp_prot)); + seq_printf(seq, "UDP: inuse %d mem %ld\n", + sock_prot_inuse_get(net, &udp_prot), + proto_memory_allocated(&udp_prot)); + seq_printf(seq, "UDPLITE: inuse %d\n", + sock_prot_inuse_get(net, &udplite_prot)); + seq_printf(seq, "RAW: inuse %d\n", + sock_prot_inuse_get(net, &raw_prot)); + seq_printf(seq, "FRAG: inuse %u memory %lu\n", + atomic_read(&net->ipv4.fqdir->rhashtable.nelems), + frag_mem_limit(net->ipv4.fqdir)); + return 0; +} + +/* snmp items */ +static const struct snmp_mib snmp4_ipstats_list[] = { + SNMP_MIB_ITEM("InReceives", IPSTATS_MIB_INPKTS), + SNMP_MIB_ITEM("InHdrErrors", IPSTATS_MIB_INHDRERRORS), + SNMP_MIB_ITEM("InAddrErrors", IPSTATS_MIB_INADDRERRORS), + SNMP_MIB_ITEM("ForwDatagrams", IPSTATS_MIB_OUTFORWDATAGRAMS), + SNMP_MIB_ITEM("InUnknownProtos", IPSTATS_MIB_INUNKNOWNPROTOS), + SNMP_MIB_ITEM("InDiscards", IPSTATS_MIB_INDISCARDS), + SNMP_MIB_ITEM("InDelivers", IPSTATS_MIB_INDELIVERS), + SNMP_MIB_ITEM("OutRequests", IPSTATS_MIB_OUTPKTS), + SNMP_MIB_ITEM("OutDiscards", IPSTATS_MIB_OUTDISCARDS), + SNMP_MIB_ITEM("OutNoRoutes", IPSTATS_MIB_OUTNOROUTES), + SNMP_MIB_ITEM("ReasmTimeout", IPSTATS_MIB_REASMTIMEOUT), + SNMP_MIB_ITEM("ReasmReqds", IPSTATS_MIB_REASMREQDS), + SNMP_MIB_ITEM("ReasmOKs", IPSTATS_MIB_REASMOKS), + SNMP_MIB_ITEM("ReasmFails", IPSTATS_MIB_REASMFAILS), + SNMP_MIB_ITEM("FragOKs", IPSTATS_MIB_FRAGOKS), + SNMP_MIB_ITEM("FragFails", IPSTATS_MIB_FRAGFAILS), + SNMP_MIB_ITEM("FragCreates", IPSTATS_MIB_FRAGCREATES), + SNMP_MIB_SENTINEL +}; + +/* Following items are displayed in /proc/net/netstat */ +static const struct snmp_mib snmp4_ipextstats_list[] = { + SNMP_MIB_ITEM("InNoRoutes", IPSTATS_MIB_INNOROUTES), + SNMP_MIB_ITEM("InTruncatedPkts", IPSTATS_MIB_INTRUNCATEDPKTS), + SNMP_MIB_ITEM("InMcastPkts", IPSTATS_MIB_INMCASTPKTS), + SNMP_MIB_ITEM("OutMcastPkts", IPSTATS_MIB_OUTMCASTPKTS), + SNMP_MIB_ITEM("InBcastPkts", IPSTATS_MIB_INBCASTPKTS), + SNMP_MIB_ITEM("OutBcastPkts", IPSTATS_MIB_OUTBCASTPKTS), + SNMP_MIB_ITEM("InOctets", IPSTATS_MIB_INOCTETS), + SNMP_MIB_ITEM("OutOctets", IPSTATS_MIB_OUTOCTETS), + SNMP_MIB_ITEM("InMcastOctets", IPSTATS_MIB_INMCASTOCTETS), + SNMP_MIB_ITEM("OutMcastOctets", IPSTATS_MIB_OUTMCASTOCTETS), + SNMP_MIB_ITEM("InBcastOctets", IPSTATS_MIB_INBCASTOCTETS), + SNMP_MIB_ITEM("OutBcastOctets", IPSTATS_MIB_OUTBCASTOCTETS), + /* Non RFC4293 fields */ + SNMP_MIB_ITEM("InCsumErrors", IPSTATS_MIB_CSUMERRORS), + SNMP_MIB_ITEM("InNoECTPkts", IPSTATS_MIB_NOECTPKTS), + SNMP_MIB_ITEM("InECT1Pkts", IPSTATS_MIB_ECT1PKTS), + 
SNMP_MIB_ITEM("InECT0Pkts", IPSTATS_MIB_ECT0PKTS), + SNMP_MIB_ITEM("InCEPkts", IPSTATS_MIB_CEPKTS), + SNMP_MIB_ITEM("ReasmOverlaps", IPSTATS_MIB_REASM_OVERLAPS), + SNMP_MIB_SENTINEL +}; + +static const struct { + const char *name; + int index; +} icmpmibmap[] = { + { "DestUnreachs", ICMP_DEST_UNREACH }, + { "TimeExcds", ICMP_TIME_EXCEEDED }, + { "ParmProbs", ICMP_PARAMETERPROB }, + { "SrcQuenchs", ICMP_SOURCE_QUENCH }, + { "Redirects", ICMP_REDIRECT }, + { "Echos", ICMP_ECHO }, + { "EchoReps", ICMP_ECHOREPLY }, + { "Timestamps", ICMP_TIMESTAMP }, + { "TimestampReps", ICMP_TIMESTAMPREPLY }, + { "AddrMasks", ICMP_ADDRESS }, + { "AddrMaskReps", ICMP_ADDRESSREPLY }, + { NULL, 0 } +}; + + +static const struct snmp_mib snmp4_tcp_list[] = { + SNMP_MIB_ITEM("RtoAlgorithm", TCP_MIB_RTOALGORITHM), + SNMP_MIB_ITEM("RtoMin", TCP_MIB_RTOMIN), + SNMP_MIB_ITEM("RtoMax", TCP_MIB_RTOMAX), + SNMP_MIB_ITEM("MaxConn", TCP_MIB_MAXCONN), + SNMP_MIB_ITEM("ActiveOpens", TCP_MIB_ACTIVEOPENS), + SNMP_MIB_ITEM("PassiveOpens", TCP_MIB_PASSIVEOPENS), + SNMP_MIB_ITEM("AttemptFails", TCP_MIB_ATTEMPTFAILS), + SNMP_MIB_ITEM("EstabResets", TCP_MIB_ESTABRESETS), + SNMP_MIB_ITEM("CurrEstab", TCP_MIB_CURRESTAB), + SNMP_MIB_ITEM("InSegs", TCP_MIB_INSEGS), + SNMP_MIB_ITEM("OutSegs", TCP_MIB_OUTSEGS), + SNMP_MIB_ITEM("RetransSegs", TCP_MIB_RETRANSSEGS), + SNMP_MIB_ITEM("InErrs", TCP_MIB_INERRS), + SNMP_MIB_ITEM("OutRsts", TCP_MIB_OUTRSTS), + SNMP_MIB_ITEM("InCsumErrors", TCP_MIB_CSUMERRORS), + SNMP_MIB_SENTINEL +}; + +static const struct snmp_mib snmp4_udp_list[] = { + SNMP_MIB_ITEM("InDatagrams", UDP_MIB_INDATAGRAMS), + SNMP_MIB_ITEM("NoPorts", UDP_MIB_NOPORTS), + SNMP_MIB_ITEM("InErrors", UDP_MIB_INERRORS), + SNMP_MIB_ITEM("OutDatagrams", UDP_MIB_OUTDATAGRAMS), + SNMP_MIB_ITEM("RcvbufErrors", UDP_MIB_RCVBUFERRORS), + SNMP_MIB_ITEM("SndbufErrors", UDP_MIB_SNDBUFERRORS), + SNMP_MIB_ITEM("InCsumErrors", UDP_MIB_CSUMERRORS), + SNMP_MIB_ITEM("IgnoredMulti", UDP_MIB_IGNOREDMULTI), + SNMP_MIB_ITEM("MemErrors", UDP_MIB_MEMERRORS), + SNMP_MIB_SENTINEL +}; + +static const struct snmp_mib snmp4_net_list[] = { + SNMP_MIB_ITEM("SyncookiesSent", LINUX_MIB_SYNCOOKIESSENT), + SNMP_MIB_ITEM("SyncookiesRecv", LINUX_MIB_SYNCOOKIESRECV), + SNMP_MIB_ITEM("SyncookiesFailed", LINUX_MIB_SYNCOOKIESFAILED), + SNMP_MIB_ITEM("EmbryonicRsts", LINUX_MIB_EMBRYONICRSTS), + SNMP_MIB_ITEM("PruneCalled", LINUX_MIB_PRUNECALLED), + SNMP_MIB_ITEM("RcvPruned", LINUX_MIB_RCVPRUNED), + SNMP_MIB_ITEM("OfoPruned", LINUX_MIB_OFOPRUNED), + SNMP_MIB_ITEM("OutOfWindowIcmps", LINUX_MIB_OUTOFWINDOWICMPS), + SNMP_MIB_ITEM("LockDroppedIcmps", LINUX_MIB_LOCKDROPPEDICMPS), + SNMP_MIB_ITEM("ArpFilter", LINUX_MIB_ARPFILTER), + SNMP_MIB_ITEM("TW", LINUX_MIB_TIMEWAITED), + SNMP_MIB_ITEM("TWRecycled", LINUX_MIB_TIMEWAITRECYCLED), + SNMP_MIB_ITEM("TWKilled", LINUX_MIB_TIMEWAITKILLED), + SNMP_MIB_ITEM("PAWSActive", LINUX_MIB_PAWSACTIVEREJECTED), + SNMP_MIB_ITEM("PAWSEstab", LINUX_MIB_PAWSESTABREJECTED), + SNMP_MIB_ITEM("DelayedACKs", LINUX_MIB_DELAYEDACKS), + SNMP_MIB_ITEM("DelayedACKLocked", LINUX_MIB_DELAYEDACKLOCKED), + SNMP_MIB_ITEM("DelayedACKLost", LINUX_MIB_DELAYEDACKLOST), + SNMP_MIB_ITEM("ListenOverflows", LINUX_MIB_LISTENOVERFLOWS), + SNMP_MIB_ITEM("ListenDrops", LINUX_MIB_LISTENDROPS), + SNMP_MIB_ITEM("TCPHPHits", LINUX_MIB_TCPHPHITS), + SNMP_MIB_ITEM("TCPPureAcks", LINUX_MIB_TCPPUREACKS), + SNMP_MIB_ITEM("TCPHPAcks", LINUX_MIB_TCPHPACKS), + SNMP_MIB_ITEM("TCPRenoRecovery", LINUX_MIB_TCPRENORECOVERY), + SNMP_MIB_ITEM("TCPSackRecovery", LINUX_MIB_TCPSACKRECOVERY), + 
SNMP_MIB_ITEM("TCPSACKReneging", LINUX_MIB_TCPSACKRENEGING), + SNMP_MIB_ITEM("TCPSACKReorder", LINUX_MIB_TCPSACKREORDER), + SNMP_MIB_ITEM("TCPRenoReorder", LINUX_MIB_TCPRENOREORDER), + SNMP_MIB_ITEM("TCPTSReorder", LINUX_MIB_TCPTSREORDER), + SNMP_MIB_ITEM("TCPFullUndo", LINUX_MIB_TCPFULLUNDO), + SNMP_MIB_ITEM("TCPPartialUndo", LINUX_MIB_TCPPARTIALUNDO), + SNMP_MIB_ITEM("TCPDSACKUndo", LINUX_MIB_TCPDSACKUNDO), + SNMP_MIB_ITEM("TCPLossUndo", LINUX_MIB_TCPLOSSUNDO), + SNMP_MIB_ITEM("TCPLostRetransmit", LINUX_MIB_TCPLOSTRETRANSMIT), + SNMP_MIB_ITEM("TCPRenoFailures", LINUX_MIB_TCPRENOFAILURES), + SNMP_MIB_ITEM("TCPSackFailures", LINUX_MIB_TCPSACKFAILURES), + SNMP_MIB_ITEM("TCPLossFailures", LINUX_MIB_TCPLOSSFAILURES), + SNMP_MIB_ITEM("TCPFastRetrans", LINUX_MIB_TCPFASTRETRANS), + SNMP_MIB_ITEM("TCPSlowStartRetrans", LINUX_MIB_TCPSLOWSTARTRETRANS), + SNMP_MIB_ITEM("TCPTimeouts", LINUX_MIB_TCPTIMEOUTS), + SNMP_MIB_ITEM("TCPLossProbes", LINUX_MIB_TCPLOSSPROBES), + SNMP_MIB_ITEM("TCPLossProbeRecovery", LINUX_MIB_TCPLOSSPROBERECOVERY), + SNMP_MIB_ITEM("TCPRenoRecoveryFail", LINUX_MIB_TCPRENORECOVERYFAIL), + SNMP_MIB_ITEM("TCPSackRecoveryFail", LINUX_MIB_TCPSACKRECOVERYFAIL), + SNMP_MIB_ITEM("TCPRcvCollapsed", LINUX_MIB_TCPRCVCOLLAPSED), + SNMP_MIB_ITEM("TCPBacklogCoalesce", LINUX_MIB_TCPBACKLOGCOALESCE), + SNMP_MIB_ITEM("TCPDSACKOldSent", LINUX_MIB_TCPDSACKOLDSENT), + SNMP_MIB_ITEM("TCPDSACKOfoSent", LINUX_MIB_TCPDSACKOFOSENT), + SNMP_MIB_ITEM("TCPDSACKRecv", LINUX_MIB_TCPDSACKRECV), + SNMP_MIB_ITEM("TCPDSACKOfoRecv", LINUX_MIB_TCPDSACKOFORECV), + SNMP_MIB_ITEM("TCPAbortOnData", LINUX_MIB_TCPABORTONDATA), + SNMP_MIB_ITEM("TCPAbortOnClose", LINUX_MIB_TCPABORTONCLOSE), + SNMP_MIB_ITEM("TCPAbortOnMemory", LINUX_MIB_TCPABORTONMEMORY), + SNMP_MIB_ITEM("TCPAbortOnTimeout", LINUX_MIB_TCPABORTONTIMEOUT), + SNMP_MIB_ITEM("TCPAbortOnLinger", LINUX_MIB_TCPABORTONLINGER), + SNMP_MIB_ITEM("TCPAbortFailed", LINUX_MIB_TCPABORTFAILED), + SNMP_MIB_ITEM("TCPMemoryPressures", LINUX_MIB_TCPMEMORYPRESSURES), + SNMP_MIB_ITEM("TCPMemoryPressuresChrono", LINUX_MIB_TCPMEMORYPRESSURESCHRONO), + SNMP_MIB_ITEM("TCPSACKDiscard", LINUX_MIB_TCPSACKDISCARD), + SNMP_MIB_ITEM("TCPDSACKIgnoredOld", LINUX_MIB_TCPDSACKIGNOREDOLD), + SNMP_MIB_ITEM("TCPDSACKIgnoredNoUndo", LINUX_MIB_TCPDSACKIGNOREDNOUNDO), + SNMP_MIB_ITEM("TCPSpuriousRTOs", LINUX_MIB_TCPSPURIOUSRTOS), + SNMP_MIB_ITEM("TCPMD5NotFound", LINUX_MIB_TCPMD5NOTFOUND), + SNMP_MIB_ITEM("TCPMD5Unexpected", LINUX_MIB_TCPMD5UNEXPECTED), + SNMP_MIB_ITEM("TCPMD5Failure", LINUX_MIB_TCPMD5FAILURE), + SNMP_MIB_ITEM("TCPSackShifted", LINUX_MIB_SACKSHIFTED), + SNMP_MIB_ITEM("TCPSackMerged", LINUX_MIB_SACKMERGED), + SNMP_MIB_ITEM("TCPSackShiftFallback", LINUX_MIB_SACKSHIFTFALLBACK), + SNMP_MIB_ITEM("TCPBacklogDrop", LINUX_MIB_TCPBACKLOGDROP), + SNMP_MIB_ITEM("PFMemallocDrop", LINUX_MIB_PFMEMALLOCDROP), + SNMP_MIB_ITEM("TCPMinTTLDrop", LINUX_MIB_TCPMINTTLDROP), + SNMP_MIB_ITEM("TCPDeferAcceptDrop", LINUX_MIB_TCPDEFERACCEPTDROP), + SNMP_MIB_ITEM("IPReversePathFilter", LINUX_MIB_IPRPFILTER), + SNMP_MIB_ITEM("TCPTimeWaitOverflow", LINUX_MIB_TCPTIMEWAITOVERFLOW), + SNMP_MIB_ITEM("TCPReqQFullDoCookies", LINUX_MIB_TCPREQQFULLDOCOOKIES), + SNMP_MIB_ITEM("TCPReqQFullDrop", LINUX_MIB_TCPREQQFULLDROP), + SNMP_MIB_ITEM("TCPRetransFail", LINUX_MIB_TCPRETRANSFAIL), + SNMP_MIB_ITEM("TCPRcvCoalesce", LINUX_MIB_TCPRCVCOALESCE), + SNMP_MIB_ITEM("TCPOFOQueue", LINUX_MIB_TCPOFOQUEUE), + SNMP_MIB_ITEM("TCPOFODrop", LINUX_MIB_TCPOFODROP), + SNMP_MIB_ITEM("TCPOFOMerge", LINUX_MIB_TCPOFOMERGE), + 
SNMP_MIB_ITEM("TCPChallengeACK", LINUX_MIB_TCPCHALLENGEACK), + SNMP_MIB_ITEM("TCPSYNChallenge", LINUX_MIB_TCPSYNCHALLENGE), + SNMP_MIB_ITEM("TCPFastOpenActive", LINUX_MIB_TCPFASTOPENACTIVE), + SNMP_MIB_ITEM("TCPFastOpenActiveFail", LINUX_MIB_TCPFASTOPENACTIVEFAIL), + SNMP_MIB_ITEM("TCPFastOpenPassive", LINUX_MIB_TCPFASTOPENPASSIVE), + SNMP_MIB_ITEM("TCPFastOpenPassiveFail", LINUX_MIB_TCPFASTOPENPASSIVEFAIL), + SNMP_MIB_ITEM("TCPFastOpenListenOverflow", LINUX_MIB_TCPFASTOPENLISTENOVERFLOW), + SNMP_MIB_ITEM("TCPFastOpenCookieReqd", LINUX_MIB_TCPFASTOPENCOOKIEREQD), + SNMP_MIB_ITEM("TCPFastOpenBlackhole", LINUX_MIB_TCPFASTOPENBLACKHOLE), + SNMP_MIB_ITEM("TCPSpuriousRtxHostQueues", LINUX_MIB_TCPSPURIOUS_RTX_HOSTQUEUES), + SNMP_MIB_ITEM("BusyPollRxPackets", LINUX_MIB_BUSYPOLLRXPACKETS), + SNMP_MIB_ITEM("TCPAutoCorking", LINUX_MIB_TCPAUTOCORKING), + SNMP_MIB_ITEM("TCPFromZeroWindowAdv", LINUX_MIB_TCPFROMZEROWINDOWADV), + SNMP_MIB_ITEM("TCPToZeroWindowAdv", LINUX_MIB_TCPTOZEROWINDOWADV), + SNMP_MIB_ITEM("TCPWantZeroWindowAdv", LINUX_MIB_TCPWANTZEROWINDOWADV), + SNMP_MIB_ITEM("TCPSynRetrans", LINUX_MIB_TCPSYNRETRANS), + SNMP_MIB_ITEM("TCPOrigDataSent", LINUX_MIB_TCPORIGDATASENT), + SNMP_MIB_ITEM("TCPHystartTrainDetect", LINUX_MIB_TCPHYSTARTTRAINDETECT), + SNMP_MIB_ITEM("TCPHystartTrainCwnd", LINUX_MIB_TCPHYSTARTTRAINCWND), + SNMP_MIB_ITEM("TCPHystartDelayDetect", LINUX_MIB_TCPHYSTARTDELAYDETECT), + SNMP_MIB_ITEM("TCPHystartDelayCwnd", LINUX_MIB_TCPHYSTARTDELAYCWND), + SNMP_MIB_ITEM("TCPACKSkippedSynRecv", LINUX_MIB_TCPACKSKIPPEDSYNRECV), + SNMP_MIB_ITEM("TCPACKSkippedPAWS", LINUX_MIB_TCPACKSKIPPEDPAWS), + SNMP_MIB_ITEM("TCPACKSkippedSeq", LINUX_MIB_TCPACKSKIPPEDSEQ), + SNMP_MIB_ITEM("TCPACKSkippedFinWait2", LINUX_MIB_TCPACKSKIPPEDFINWAIT2), + SNMP_MIB_ITEM("TCPACKSkippedTimeWait", LINUX_MIB_TCPACKSKIPPEDTIMEWAIT), + SNMP_MIB_ITEM("TCPACKSkippedChallenge", LINUX_MIB_TCPACKSKIPPEDCHALLENGE), + SNMP_MIB_ITEM("TCPWinProbe", LINUX_MIB_TCPWINPROBE), + SNMP_MIB_ITEM("TCPKeepAlive", LINUX_MIB_TCPKEEPALIVE), + SNMP_MIB_ITEM("TCPMTUPFail", LINUX_MIB_TCPMTUPFAIL), + SNMP_MIB_ITEM("TCPMTUPSuccess", LINUX_MIB_TCPMTUPSUCCESS), + SNMP_MIB_ITEM("TCPDelivered", LINUX_MIB_TCPDELIVERED), + SNMP_MIB_ITEM("TCPDeliveredCE", LINUX_MIB_TCPDELIVEREDCE), + SNMP_MIB_ITEM("TCPAckCompressed", LINUX_MIB_TCPACKCOMPRESSED), + SNMP_MIB_ITEM("TCPZeroWindowDrop", LINUX_MIB_TCPZEROWINDOWDROP), + SNMP_MIB_ITEM("TCPRcvQDrop", LINUX_MIB_TCPRCVQDROP), + SNMP_MIB_ITEM("TCPWqueueTooBig", LINUX_MIB_TCPWQUEUETOOBIG), + SNMP_MIB_ITEM("TCPFastOpenPassiveAltKey", LINUX_MIB_TCPFASTOPENPASSIVEALTKEY), + SNMP_MIB_ITEM("TcpTimeoutRehash", LINUX_MIB_TCPTIMEOUTREHASH), + SNMP_MIB_ITEM("TcpDuplicateDataRehash", LINUX_MIB_TCPDUPLICATEDATAREHASH), + SNMP_MIB_ITEM("TCPDSACKRecvSegs", LINUX_MIB_TCPDSACKRECVSEGS), + SNMP_MIB_ITEM("TCPDSACKIgnoredDubious", LINUX_MIB_TCPDSACKIGNOREDDUBIOUS), + SNMP_MIB_ITEM("TCPMigrateReqSuccess", LINUX_MIB_TCPMIGRATEREQSUCCESS), + SNMP_MIB_ITEM("TCPMigrateReqFailure", LINUX_MIB_TCPMIGRATEREQFAILURE), + SNMP_MIB_SENTINEL +}; + +static void icmpmsg_put_line(struct seq_file *seq, unsigned long *vals, + unsigned short *type, int count) +{ + int j; + + if (count) { + seq_puts(seq, "\nIcmpMsg:"); + for (j = 0; j < count; ++j) + seq_printf(seq, " %sType%u", + type[j] & 0x100 ? 
"Out" : "In", + type[j] & 0xff); + seq_puts(seq, "\nIcmpMsg:"); + for (j = 0; j < count; ++j) + seq_printf(seq, " %lu", vals[j]); + } +} + +static void icmpmsg_put(struct seq_file *seq) +{ +#define PERLINE 16 + + int i, count; + unsigned short type[PERLINE]; + unsigned long vals[PERLINE], val; + struct net *net = seq->private; + + count = 0; + for (i = 0; i < ICMPMSG_MIB_MAX; i++) { + val = atomic_long_read(&net->mib.icmpmsg_statistics->mibs[i]); + if (val) { + type[count] = i; + vals[count++] = val; + } + if (count == PERLINE) { + icmpmsg_put_line(seq, vals, type, count); + count = 0; + } + } + icmpmsg_put_line(seq, vals, type, count); + +#undef PERLINE +} + +static void icmp_put(struct seq_file *seq) +{ + int i; + struct net *net = seq->private; + atomic_long_t *ptr = net->mib.icmpmsg_statistics->mibs; + + seq_puts(seq, "\nIcmp: InMsgs InErrors InCsumErrors"); + for (i = 0; icmpmibmap[i].name; i++) + seq_printf(seq, " In%s", icmpmibmap[i].name); + seq_puts(seq, " OutMsgs OutErrors"); + for (i = 0; icmpmibmap[i].name; i++) + seq_printf(seq, " Out%s", icmpmibmap[i].name); + seq_printf(seq, "\nIcmp: %lu %lu %lu", + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_INMSGS), + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_INERRORS), + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_CSUMERRORS)); + for (i = 0; icmpmibmap[i].name; i++) + seq_printf(seq, " %lu", + atomic_long_read(ptr + icmpmibmap[i].index)); + seq_printf(seq, " %lu %lu", + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_OUTMSGS), + snmp_fold_field(net->mib.icmp_statistics, ICMP_MIB_OUTERRORS)); + for (i = 0; icmpmibmap[i].name; i++) + seq_printf(seq, " %lu", + atomic_long_read(ptr + (icmpmibmap[i].index | 0x100))); +} + +/* + * Called from the PROCfs module. This outputs /proc/net/snmp. + */ +static int snmp_seq_show_ipstats(struct seq_file *seq, void *v) +{ + struct net *net = seq->private; + u64 buff64[IPSTATS_MIB_MAX]; + int i; + + memset(buff64, 0, IPSTATS_MIB_MAX * sizeof(u64)); + + seq_puts(seq, "Ip: Forwarding DefaultTTL"); + for (i = 0; snmp4_ipstats_list[i].name; i++) + seq_printf(seq, " %s", snmp4_ipstats_list[i].name); + + seq_printf(seq, "\nIp: %d %d", + IPV4_DEVCONF_ALL(net, FORWARDING) ? 
1 : 2, + READ_ONCE(net->ipv4.sysctl_ip_default_ttl)); + + BUILD_BUG_ON(offsetof(struct ipstats_mib, mibs) != 0); + snmp_get_cpu_field64_batch(buff64, snmp4_ipstats_list, + net->mib.ip_statistics, + offsetof(struct ipstats_mib, syncp)); + for (i = 0; snmp4_ipstats_list[i].name; i++) + seq_printf(seq, " %llu", buff64[i]); + + return 0; +} + +static int snmp_seq_show_tcp_udp(struct seq_file *seq, void *v) +{ + unsigned long buff[TCPUDP_MIB_MAX]; + struct net *net = seq->private; + int i; + + memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + + seq_puts(seq, "\nTcp:"); + for (i = 0; snmp4_tcp_list[i].name; i++) + seq_printf(seq, " %s", snmp4_tcp_list[i].name); + + seq_puts(seq, "\nTcp:"); + snmp_get_cpu_field_batch(buff, snmp4_tcp_list, + net->mib.tcp_statistics); + for (i = 0; snmp4_tcp_list[i].name; i++) { + /* MaxConn field is signed, RFC 2012 */ + if (snmp4_tcp_list[i].entry == TCP_MIB_MAXCONN) + seq_printf(seq, " %ld", buff[i]); + else + seq_printf(seq, " %lu", buff[i]); + } + + memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + + snmp_get_cpu_field_batch(buff, snmp4_udp_list, + net->mib.udp_statistics); + seq_puts(seq, "\nUdp:"); + for (i = 0; snmp4_udp_list[i].name; i++) + seq_printf(seq, " %s", snmp4_udp_list[i].name); + seq_puts(seq, "\nUdp:"); + for (i = 0; snmp4_udp_list[i].name; i++) + seq_printf(seq, " %lu", buff[i]); + + memset(buff, 0, TCPUDP_MIB_MAX * sizeof(unsigned long)); + + /* the UDP and UDP-Lite MIBs are the same */ + seq_puts(seq, "\nUdpLite:"); + snmp_get_cpu_field_batch(buff, snmp4_udp_list, + net->mib.udplite_statistics); + for (i = 0; snmp4_udp_list[i].name; i++) + seq_printf(seq, " %s", snmp4_udp_list[i].name); + seq_puts(seq, "\nUdpLite:"); + for (i = 0; snmp4_udp_list[i].name; i++) + seq_printf(seq, " %lu", buff[i]); + + seq_putc(seq, '\n'); + return 0; +} + +static int snmp_seq_show(struct seq_file *seq, void *v) +{ + snmp_seq_show_ipstats(seq, v); + + icmp_put(seq); /* RFC 2011 compatibility */ + icmpmsg_put(seq); + + snmp_seq_show_tcp_udp(seq, v); + + return 0; +} + +/* + * Output /proc/net/netstat + */ +static int netstat_seq_show(struct seq_file *seq, void *v) +{ + const int ip_cnt = ARRAY_SIZE(snmp4_ipextstats_list) - 1; + const int tcp_cnt = ARRAY_SIZE(snmp4_net_list) - 1; + struct net *net = seq->private; + unsigned long *buff; + int i; + + seq_puts(seq, "TcpExt:"); + for (i = 0; i < tcp_cnt; i++) + seq_printf(seq, " %s", snmp4_net_list[i].name); + + seq_puts(seq, "\nTcpExt:"); + buff = kzalloc(max(tcp_cnt * sizeof(long), ip_cnt * sizeof(u64)), + GFP_KERNEL); + if (buff) { + snmp_get_cpu_field_batch(buff, snmp4_net_list, + net->mib.net_statistics); + for (i = 0; i < tcp_cnt; i++) + seq_printf(seq, " %lu", buff[i]); + } else { + for (i = 0; i < tcp_cnt; i++) + seq_printf(seq, " %lu", + snmp_fold_field(net->mib.net_statistics, + snmp4_net_list[i].entry)); + } + seq_puts(seq, "\nIpExt:"); + for (i = 0; i < ip_cnt; i++) + seq_printf(seq, " %s", snmp4_ipextstats_list[i].name); + + seq_puts(seq, "\nIpExt:"); + if (buff) { + u64 *buff64 = (u64 *)buff; + + memset(buff64, 0, ip_cnt * sizeof(u64)); + snmp_get_cpu_field64_batch(buff64, snmp4_ipextstats_list, + net->mib.ip_statistics, + offsetof(struct ipstats_mib, syncp)); + for (i = 0; i < ip_cnt; i++) + seq_printf(seq, " %llu", buff64[i]); + } else { + for (i = 0; i < ip_cnt; i++) + seq_printf(seq, " %llu", + snmp_fold_field64(net->mib.ip_statistics, + snmp4_ipextstats_list[i].entry, + offsetof(struct ipstats_mib, syncp))); + } + kfree(buff); + seq_putc(seq, '\n'); + mptcp_seq_show(seq); + return 
0; +} + +static __net_init int ip_proc_init_net(struct net *net) +{ + if (!proc_create_net_single("sockstat", 0444, net->proc_net, + sockstat_seq_show, NULL)) + goto out_sockstat; + if (!proc_create_net_single("netstat", 0444, net->proc_net, + netstat_seq_show, NULL)) + goto out_netstat; + if (!proc_create_net_single("snmp", 0444, net->proc_net, snmp_seq_show, + NULL)) + goto out_snmp; + + return 0; + +out_snmp: + remove_proc_entry("netstat", net->proc_net); +out_netstat: + remove_proc_entry("sockstat", net->proc_net); +out_sockstat: + return -ENOMEM; +} + +static __net_exit void ip_proc_exit_net(struct net *net) +{ + remove_proc_entry("snmp", net->proc_net); + remove_proc_entry("netstat", net->proc_net); + remove_proc_entry("sockstat", net->proc_net); +} + +static __net_initdata struct pernet_operations ip_proc_ops = { + .init = ip_proc_init_net, + .exit = ip_proc_exit_net, +}; + +int __init ip_misc_proc_init(void) +{ + return register_pernet_subsys(&ip_proc_ops); +} diff --git a/net/ipv4/protocol.c b/net/ipv4/protocol.c new file mode 100644 index 000000000..691397994 --- /dev/null +++ b/net/ipv4/protocol.c @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * INET protocol dispatch tables. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * + * Fixes: + * Alan Cox : Ahah! udp icmp errors don't work because + * udp_err is never called! + * Alan Cox : Added new fields for init and ready for + * proper fragmentation (_NO_ 4K limits!) + * Richard Colella : Hang on hash collision + * Vince Laviano : Modified inet_del_protocol() to correctly + * maintain copy bit. + */ +#include +#include +#include +#include +#include + +struct net_protocol __rcu *inet_protos[MAX_INET_PROTOS] __read_mostly; +EXPORT_SYMBOL(inet_protos); +const struct net_offload __rcu *inet_offloads[MAX_INET_PROTOS] __read_mostly; +EXPORT_SYMBOL(inet_offloads); + +int inet_add_protocol(const struct net_protocol *prot, unsigned char protocol) +{ + return !cmpxchg((const struct net_protocol **)&inet_protos[protocol], + NULL, prot) ? 0 : -1; +} +EXPORT_SYMBOL(inet_add_protocol); + +int inet_add_offload(const struct net_offload *prot, unsigned char protocol) +{ + return !cmpxchg((const struct net_offload **)&inet_offloads[protocol], + NULL, prot) ? 0 : -1; +} +EXPORT_SYMBOL(inet_add_offload); + +int inet_del_protocol(const struct net_protocol *prot, unsigned char protocol) +{ + int ret; + + ret = (cmpxchg((const struct net_protocol **)&inet_protos[protocol], + prot, NULL) == prot) ? 0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(inet_del_protocol); + +int inet_del_offload(const struct net_offload *prot, unsigned char protocol) +{ + int ret; + + ret = (cmpxchg((const struct net_offload **)&inet_offloads[protocol], + prot, NULL) == prot) ? 0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(inet_del_offload); diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c new file mode 100644 index 000000000..639aa5abd --- /dev/null +++ b/net/ipv4/raw.c @@ -0,0 +1,1112 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * RAW - implementation of IP "raw" sockets. + * + * Authors: Ross Biro + * Fred N. 
van Kempen, + * + * Fixes: + * Alan Cox : verify_area() fixed up + * Alan Cox : ICMP error handling + * Alan Cox : EMSGSIZE if you send too big a packet + * Alan Cox : Now uses generic datagrams and shared + * skbuff library. No more peek crashes, + * no more backlogs + * Alan Cox : Checks sk->broadcast. + * Alan Cox : Uses skb_free_datagram/skb_copy_datagram + * Alan Cox : Raw passes ip options too + * Alan Cox : Setsocketopt added + * Alan Cox : Fixed error return for broadcasts + * Alan Cox : Removed wake_up calls + * Alan Cox : Use ttl/tos + * Alan Cox : Cleaned up old debugging + * Alan Cox : Use new kernel side addresses + * Arnt Gulbrandsen : Fixed MSG_DONTROUTE in raw sockets. + * Alan Cox : BSD style RAW socket demultiplexing. + * Alan Cox : Beginnings of mrouted support. + * Alan Cox : Added IP_HDRINCL option. + * Alan Cox : Skip broadcast check if BSDism set. + * David S. Miller : New socket lookup architecture. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct raw_frag_vec { + struct msghdr *msg; + union { + struct icmphdr icmph; + char c[1]; + } hdr; + int hlen; +}; + +struct raw_hashinfo raw_v4_hashinfo; +EXPORT_SYMBOL_GPL(raw_v4_hashinfo); + +int raw_hash_sk(struct sock *sk) +{ + struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; + struct hlist_head *hlist; + + hlist = &h->ht[raw_hashfunc(sock_net(sk), inet_sk(sk)->inet_num)]; + + spin_lock(&h->lock); + sk_add_node_rcu(sk, hlist); + sock_set_flag(sk, SOCK_RCU_FREE); + spin_unlock(&h->lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + return 0; +} +EXPORT_SYMBOL_GPL(raw_hash_sk); + +void raw_unhash_sk(struct sock *sk) +{ + struct raw_hashinfo *h = sk->sk_prot->h.raw_hash; + + spin_lock(&h->lock); + if (sk_del_node_init_rcu(sk)) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + spin_unlock(&h->lock); +} +EXPORT_SYMBOL_GPL(raw_unhash_sk); + +bool raw_v4_match(struct net *net, struct sock *sk, unsigned short num, + __be32 raddr, __be32 laddr, int dif, int sdif) +{ + struct inet_sock *inet = inet_sk(sk); + + if (net_eq(sock_net(sk), net) && inet->inet_num == num && + !(inet->inet_daddr && inet->inet_daddr != raddr) && + !(inet->inet_rcv_saddr && inet->inet_rcv_saddr != laddr) && + raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) + return true; + return false; +} +EXPORT_SYMBOL_GPL(raw_v4_match); + +/* + * 0 - deliver + * 1 - block + */ +static int icmp_filter(const struct sock *sk, const struct sk_buff *skb) +{ + struct icmphdr _hdr; + const struct icmphdr *hdr; + + hdr = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_hdr), &_hdr); + if (!hdr) + return 1; + + if (hdr->type < 32) { + __u32 data = raw_sk(sk)->filter.data; + + return ((1U << hdr->type) & data) != 0; + } + + /* Do not block unknown ICMP types */ + return 0; +} + +/* IP input processing comes here for RAW socket delivery. + * Caller owns SKB, so we must make clones. + * + * RFC 1122: SHOULD pass TOS value up to the transport layer. + * -> It does. And not only TOS, but all IP header. 
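icmp_filter() above implements the ICMP_FILTER socket option for raw ICMP sockets: a set bit in the 32-bit mask suppresses delivery of that ICMP type, and types of 32 or above always pass. A hedged userspace sketch of the classic ping-style setup, which needs CAP_NET_RAW to open the raw socket:

#include <stdio.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <linux/icmp.h>	/* struct icmp_filter, ICMP_FILTER, ICMP_ECHOREPLY */

int open_echo_reply_socket(void)
{
	struct icmp_filter filt = {
		/* A set bit blocks that ICMP type; keep only ECHOREPLY. */
		.data = ~(1U << ICMP_ECHOREPLY),
	};
	int fd = socket(AF_INET, SOCK_RAW, IPPROTO_ICMP);

	if (fd < 0)
		return -1;
	if (setsockopt(fd, SOL_RAW, ICMP_FILTER, &filt, sizeof(filt)) < 0)
		perror("ICMP_FILTER");
	return fd;
}

As do_raw_setsockopt() later in this file shows, ICMP_FILTER is accepted only on raw sockets bound to IPPROTO_ICMP, which is why the filter is applied to exactly this kind of socket.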
+ */ +static int raw_v4_input(struct net *net, struct sk_buff *skb, + const struct iphdr *iph, int hash) +{ + int sdif = inet_sdif(skb); + struct hlist_head *hlist; + int dif = inet_iif(skb); + int delivered = 0; + struct sock *sk; + + hlist = &raw_v4_hashinfo.ht[hash]; + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { + if (!raw_v4_match(net, sk, iph->protocol, + iph->saddr, iph->daddr, dif, sdif)) + continue; + delivered = 1; + if ((iph->protocol != IPPROTO_ICMP || !icmp_filter(sk, skb)) && + ip_mc_sf_allow(sk, iph->daddr, iph->saddr, + skb->dev->ifindex, sdif)) { + struct sk_buff *clone = skb_clone(skb, GFP_ATOMIC); + + /* Not releasing hash table! */ + if (clone) + raw_rcv(sk, clone); + } + } + rcu_read_unlock(); + return delivered; +} + +int raw_local_deliver(struct sk_buff *skb, int protocol) +{ + struct net *net = dev_net(skb->dev); + + return raw_v4_input(net, skb, ip_hdr(skb), + raw_hashfunc(net, protocol)); +} + +static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info) +{ + struct inet_sock *inet = inet_sk(sk); + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + int err = 0; + int harderr = 0; + + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) + ipv4_sk_update_pmtu(skb, sk, info); + else if (type == ICMP_REDIRECT) { + ipv4_sk_redirect(skb, sk); + return; + } + + /* Report error on raw socket, if: + 1. User requested ip_recverr. + 2. Socket is connected (otherwise the error indication + is useless without ip_recverr and error is hard. + */ + if (!inet->recverr && sk->sk_state != TCP_ESTABLISHED) + return; + + switch (type) { + default: + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + case ICMP_SOURCE_QUENCH: + return; + case ICMP_PARAMETERPROB: + err = EPROTO; + harderr = 1; + break; + case ICMP_DEST_UNREACH: + err = EHOSTUNREACH; + if (code > NR_ICMP_UNREACH) + break; + if (code == ICMP_FRAG_NEEDED) { + harderr = inet->pmtudisc != IP_PMTUDISC_DONT; + err = EMSGSIZE; + } else { + err = icmp_err_convert[code].errno; + harderr = icmp_err_convert[code].fatal; + } + } + + if (inet->recverr) { + const struct iphdr *iph = (const struct iphdr *)skb->data; + u8 *payload = skb->data + (iph->ihl << 2); + + if (inet->hdrincl) + payload = skb->data; + ip_icmp_error(sk, skb, err, 0, info, payload); + } + + if (inet->recverr || harderr) { + sk->sk_err = err; + sk_error_report(sk); + } +} + +void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info) +{ + struct net *net = dev_net(skb->dev); + int dif = skb->dev->ifindex; + int sdif = inet_sdif(skb); + struct hlist_head *hlist; + const struct iphdr *iph; + struct sock *sk; + int hash; + + hash = raw_hashfunc(net, protocol); + hlist = &raw_v4_hashinfo.ht[hash]; + + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { + iph = (const struct iphdr *)skb->data; + if (!raw_v4_match(net, sk, iph->protocol, + iph->daddr, iph->saddr, dif, sdif)) + continue; + raw_err(sk, skb, info); + } + rcu_read_unlock(); +} + +static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + /* Charge it to the socket. 
*/ + + ipv4_pktinfo_prepare(sk, skb); + if (sock_queue_rcv_skb(sk, skb) < 0) { + kfree_skb(skb); + return NET_RX_DROP; + } + + return NET_RX_SUCCESS; +} + +int raw_rcv(struct sock *sk, struct sk_buff *skb) +{ + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { + atomic_inc(&sk->sk_drops); + kfree_skb(skb); + return NET_RX_DROP; + } + nf_reset_ct(skb); + + skb_push(skb, skb->data - skb_network_header(skb)); + + raw_rcv_skb(sk, skb); + return 0; +} + +static int raw_send_hdrinc(struct sock *sk, struct flowi4 *fl4, + struct msghdr *msg, size_t length, + struct rtable **rtp, unsigned int flags, + const struct sockcm_cookie *sockc) +{ + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + struct iphdr *iph; + struct sk_buff *skb; + unsigned int iphlen; + int err; + struct rtable *rt = *rtp; + int hlen, tlen; + + if (length > rt->dst.dev->mtu) { + ip_local_error(sk, EMSGSIZE, fl4->daddr, inet->inet_dport, + rt->dst.dev->mtu); + return -EMSGSIZE; + } + if (length < sizeof(struct iphdr)) + return -EINVAL; + + if (flags&MSG_PROBE) + goto out; + + hlen = LL_RESERVED_SPACE(rt->dst.dev); + tlen = rt->dst.dev->needed_tailroom; + skb = sock_alloc_send_skb(sk, + length + hlen + tlen + 15, + flags & MSG_DONTWAIT, &err); + if (!skb) + goto error; + skb_reserve(skb, hlen); + + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = sockc->mark; + skb->tstamp = sockc->transmit_time; + skb_dst_set(skb, &rt->dst); + *rtp = NULL; + + skb_reset_network_header(skb); + iph = ip_hdr(skb); + skb_put(skb, length); + + skb->ip_summed = CHECKSUM_NONE; + + skb_setup_tx_timestamp(skb, sockc->tsflags); + + if (flags & MSG_CONFIRM) + skb_set_dst_pending_confirm(skb, 1); + + skb->transport_header = skb->network_header; + err = -EFAULT; + if (memcpy_from_msg(iph, msg, length)) + goto error_free; + + iphlen = iph->ihl * 4; + + /* + * We don't want to modify the ip header, but we do need to + * be sure that it won't cause problems later along the network + * stack. Specifically we want to make sure that iph->ihl is a + * sane value. If ihl points beyond the length of the buffer passed + * in, reject the frame as invalid + */ + err = -EINVAL; + if (iphlen > length) + goto error_free; + + if (iphlen >= sizeof(*iph)) { + if (!iph->saddr) + iph->saddr = fl4->saddr; + iph->check = 0; + iph->tot_len = htons(length); + if (!iph->id) + ip_select_ident(net, skb, NULL); + + iph->check = ip_fast_csum((unsigned char *)iph, iph->ihl); + skb->transport_header += iphlen; + if (iph->protocol == IPPROTO_ICMP && + length >= iphlen + sizeof(struct icmphdr)) + icmp_out_count(net, ((struct icmphdr *) + skb_transport_header(skb))->type); + } + + err = NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_OUT, + net, sk, skb, NULL, rt->dst.dev, + dst_output); + if (err > 0) + err = net_xmit_errno(err); + if (err) + goto error; +out: + return 0; + +error_free: + kfree_skb(skb); +error: + IP_INC_STATS(net, IPSTATS_MIB_OUTDISCARDS); + if (err == -ENOBUFS && !inet->recverr) + err = 0; + return err; +} + +static int raw_probe_proto_opt(struct raw_frag_vec *rfv, struct flowi4 *fl4) +{ + int err; + + if (fl4->flowi4_proto != IPPROTO_ICMP) + return 0; + + /* We only need the first two bytes. 
*/ + rfv->hlen = 2; + + err = memcpy_from_msg(rfv->hdr.c, rfv->msg, rfv->hlen); + if (err) + return err; + + fl4->fl4_icmp_type = rfv->hdr.icmph.type; + fl4->fl4_icmp_code = rfv->hdr.icmph.code; + + return 0; +} + +static int raw_getfrag(void *from, char *to, int offset, int len, int odd, + struct sk_buff *skb) +{ + struct raw_frag_vec *rfv = from; + + if (offset < rfv->hlen) { + int copy = min(rfv->hlen - offset, len); + + if (skb->ip_summed == CHECKSUM_PARTIAL) + memcpy(to, rfv->hdr.c + offset, copy); + else + skb->csum = csum_block_add( + skb->csum, + csum_partial_copy_nocheck(rfv->hdr.c + offset, + to, copy), + odd); + + odd = 0; + offset += copy; + to += copy; + len -= copy; + + if (!len) + return 0; + } + + offset -= rfv->hlen; + + return ip_generic_getfrag(rfv->msg, to, offset, len, odd, skb); +} + +static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct inet_sock *inet = inet_sk(sk); + struct net *net = sock_net(sk); + struct ipcm_cookie ipc; + struct rtable *rt = NULL; + struct flowi4 fl4; + int free = 0; + __be32 daddr; + __be32 saddr; + u8 tos; + int err; + struct ip_options_data opt_copy; + struct raw_frag_vec rfv; + int hdrincl; + + err = -EMSGSIZE; + if (len > 0xFFFF) + goto out; + + /* hdrincl should be READ_ONCE(inet->hdrincl) + * but READ_ONCE() doesn't work with bit fields. + * Doing this indirectly yields the same result. + */ + hdrincl = inet->hdrincl; + hdrincl = READ_ONCE(hdrincl); + /* + * Check the flags. + */ + + err = -EOPNOTSUPP; + if (msg->msg_flags & MSG_OOB) /* Mirror BSD error message */ + goto out; /* compatibility */ + + /* + * Get and verify the address. + */ + + if (msg->msg_namelen) { + DECLARE_SOCKADDR(struct sockaddr_in *, usin, msg->msg_name); + err = -EINVAL; + if (msg->msg_namelen < sizeof(*usin)) + goto out; + if (usin->sin_family != AF_INET) { + pr_info_once("%s: %s forgot to set AF_INET. Fix it!\n", + __func__, current->comm); + err = -EAFNOSUPPORT; + if (usin->sin_family) + goto out; + } + daddr = usin->sin_addr.s_addr; + /* ANK: I did not forget to get protocol from port field. + * I just do not know, who uses this weirdness. + * IP_HDRINCL is much more convenient. + */ + } else { + err = -EDESTADDRREQ; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + daddr = inet->inet_daddr; + } + + ipcm_init_sk(&ipc, inet); + /* Keep backward compat */ + if (hdrincl) + ipc.protocol = IPPROTO_RAW; + + if (msg->msg_controllen) { + err = ip_cmsg_send(sk, msg, &ipc, false); + if (unlikely(err)) { + kfree(ipc.opt); + goto out; + } + if (ipc.opt) + free = 1; + } + + saddr = ipc.addr; + ipc.addr = daddr; + + if (!ipc.opt) { + struct ip_options_rcu *inet_opt; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt) { + memcpy(&opt_copy, inet_opt, + sizeof(*inet_opt) + inet_opt->opt.optlen); + ipc.opt = &opt_copy.opt; + } + rcu_read_unlock(); + } + + if (ipc.opt) { + err = -EINVAL; + /* Linux does not mangle headers on raw sockets, + * so that IP options + IP_HDRINCL is non-sense. 
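raw_send_hdrinc() above leaves a hand-built IP header largely untouched, but it does fill in saddr when it is zero, picks an IP id when that is zero, and always rewrites tot_len and the header checksum. A hedged sketch of a sender that relies on exactly that behaviour; it assumes CAP_NET_RAW, uses the documentation address 192.0.2.1 and an arbitrary port, and relies on the raw(7) convention that an IPPROTO_RAW socket has IP_HDRINCL implied:

#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <linux/ip.h>
#include <linux/udp.h>

int send_hdrincl_probe(void)
{
	struct sockaddr_in dst = { .sin_family = AF_INET };
	struct iphdr iph = { .version = 4, .ihl = 5, .ttl = 64,
			     .protocol = IPPROTO_UDP };
	struct udphdr udph = { .source = htons(33434), .dest = htons(33434),
			       .len = htons(sizeof(struct udphdr)) };
	unsigned char pkt[sizeof(iph) + sizeof(udph)];
	int fd, ret;

	fd = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
	if (fd < 0)
		return -1;

	inet_pton(AF_INET, "192.0.2.1", &dst.sin_addr);
	iph.daddr = dst.sin_addr.s_addr;
	/* saddr, id, tot_len and check stay 0: raw_send_hdrinc() fills them in. */

	memcpy(pkt, &iph, sizeof(iph));
	memcpy(pkt + sizeof(iph), &udph, sizeof(udph));

	ret = sendto(fd, pkt, sizeof(pkt), 0,
		     (struct sockaddr *)&dst, sizeof(dst)) < 0 ? -1 : 0;
	close(fd);
	return ret;
}

Leaving the UDP checksum at zero is legal for UDP over IPv4; the IP-level fields left at zero are precisely the ones the function above treats as kernel-filled.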
+ */ + if (hdrincl) + goto done; + if (ipc.opt->opt.srr) { + if (!daddr) + goto done; + daddr = ipc.opt->opt.faddr; + } + } + tos = get_rtconn_flags(&ipc, sk); + if (msg->msg_flags & MSG_DONTROUTE) + tos |= RTO_ONLINK; + + if (ipv4_is_multicast(daddr)) { + if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) + ipc.oif = inet->mc_index; + if (!saddr) + saddr = inet->mc_addr; + } else if (!ipc.oif) { + ipc.oif = inet->uc_index; + } else if (ipv4_is_lbcast(daddr) && inet->uc_index) { + /* oif is set, packet is to local broadcast + * and uc_index is set. oif is most likely set + * by sk_bound_dev_if. If uc_index != oif check if the + * oif is an L3 master and uc_index is an L3 slave. + * If so, we want to allow the send using the uc_index. + */ + if (ipc.oif != inet->uc_index && + ipc.oif == l3mdev_master_ifindex_by_index(sock_net(sk), + inet->uc_index)) { + ipc.oif = inet->uc_index; + } + } + + flowi4_init_output(&fl4, ipc.oif, ipc.sockc.mark, tos, + RT_SCOPE_UNIVERSE, + hdrincl ? ipc.protocol : sk->sk_protocol, + inet_sk_flowi_flags(sk) | + (hdrincl ? FLOWI_FLAG_KNOWN_NH : 0), + daddr, saddr, 0, 0, sk->sk_uid); + + if (!hdrincl) { + rfv.msg = msg; + rfv.hlen = 0; + + err = raw_probe_proto_opt(&rfv, &fl4); + if (err) + goto done; + } + + security_sk_classify_flow(sk, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_flow(net, &fl4, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + goto done; + } + + err = -EACCES; + if (rt->rt_flags & RTCF_BROADCAST && !sock_flag(sk, SOCK_BROADCAST)) + goto done; + + if (msg->msg_flags & MSG_CONFIRM) + goto do_confirm; +back_from_confirm: + + if (hdrincl) + err = raw_send_hdrinc(sk, &fl4, msg, len, + &rt, msg->msg_flags, &ipc.sockc); + + else { + if (!ipc.addr) + ipc.addr = fl4.daddr; + lock_sock(sk); + err = ip_append_data(sk, &fl4, raw_getfrag, + &rfv, len, 0, + &ipc, &rt, msg->msg_flags); + if (err) + ip_flush_pending_frames(sk); + else if (!(msg->msg_flags & MSG_MORE)) { + err = ip_push_pending_frames(sk, &fl4); + if (err == -ENOBUFS && !inet->recverr) + err = 0; + } + release_sock(sk); + } +done: + if (free) + kfree(ipc.opt); + ip_rt_put(rt); + +out: + if (err < 0) + return err; + return len; + +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(&rt->dst, &fl4.daddr); + if (!(msg->msg_flags & MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto done; +} + +static void raw_close(struct sock *sk, long timeout) +{ + /* + * Raw sockets may have direct kernel references. Kill them. + */ + ip_ra_control(sk, 0, NULL); + + sk_common_release(sk); +} + +static void raw_destroy(struct sock *sk) +{ + lock_sock(sk); + ip_flush_pending_frames(sk); + release_sock(sk); +} + +/* This gets rid of all the nasties in af_inet. -DaveM */ +static int raw_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + struct sockaddr_in *addr = (struct sockaddr_in *) uaddr; + struct net *net = sock_net(sk); + u32 tb_id = RT_TABLE_LOCAL; + int ret = -EINVAL; + int chk_addr_ret; + + lock_sock(sk); + if (sk->sk_state != TCP_CLOSE || addr_len < sizeof(struct sockaddr_in)) + goto out; + + if (sk->sk_bound_dev_if) + tb_id = l3mdev_fib_table_by_index(net, + sk->sk_bound_dev_if) ? 
: tb_id; + + chk_addr_ret = inet_addr_type_table(net, addr->sin_addr.s_addr, tb_id); + + ret = -EADDRNOTAVAIL; + if (!inet_addr_valid_or_nonlocal(net, inet, addr->sin_addr.s_addr, + chk_addr_ret)) + goto out; + + inet->inet_rcv_saddr = inet->inet_saddr = addr->sin_addr.s_addr; + if (chk_addr_ret == RTN_MULTICAST || chk_addr_ret == RTN_BROADCAST) + inet->inet_saddr = 0; /* Use device */ + sk_dst_reset(sk); + ret = 0; +out: + release_sock(sk); + return ret; +} + +/* + * This should be easy, if there is something there + * we return it, otherwise we block. + */ + +static int raw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + size_t copied = 0; + int err = -EOPNOTSUPP; + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + struct sk_buff *skb; + + if (flags & MSG_OOB) + goto out; + + if (flags & MSG_ERRQUEUE) { + err = ip_recv_error(sk, msg, len, addr_len); + goto out; + } + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_cmsgs(msg, sk, skb); + + /* Copy the address. */ + if (sin) { + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + sin->sin_port = 0; + memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); + *addr_len = sizeof(*sin); + } + if (inet->cmsg_flags) + ip_cmsg_recv(msg, skb); + if (flags & MSG_TRUNC) + copied = skb->len; +done: + skb_free_datagram(sk, skb); +out: + if (err) + return err; + return copied; +} + +static int raw_sk_init(struct sock *sk) +{ + struct raw_sock *rp = raw_sk(sk); + + if (inet_sk(sk)->inet_num == IPPROTO_ICMP) + memset(&rp->filter, 0, sizeof(rp->filter)); + return 0; +} + +static int raw_seticmpfilter(struct sock *sk, sockptr_t optval, int optlen) +{ + if (optlen > sizeof(struct icmp_filter)) + optlen = sizeof(struct icmp_filter); + if (copy_from_sockptr(&raw_sk(sk)->filter, optval, optlen)) + return -EFAULT; + return 0; +} + +static int raw_geticmpfilter(struct sock *sk, char __user *optval, int __user *optlen) +{ + int len, ret = -EFAULT; + + if (get_user(len, optlen)) + goto out; + ret = -EINVAL; + if (len < 0) + goto out; + if (len > sizeof(struct icmp_filter)) + len = sizeof(struct icmp_filter); + ret = -EFAULT; + if (put_user(len, optlen) || + copy_to_user(optval, &raw_sk(sk)->filter, len)) + goto out; + ret = 0; +out: return ret; +} + +static int do_raw_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + if (optname == ICMP_FILTER) { + if (inet_sk(sk)->inet_num != IPPROTO_ICMP) + return -EOPNOTSUPP; + else + return raw_seticmpfilter(sk, optval, optlen); + } + return -ENOPROTOOPT; +} + +static int raw_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + if (level != SOL_RAW) + return ip_setsockopt(sk, level, optname, optval, optlen); + return do_raw_setsockopt(sk, level, optname, optval, optlen); +} + +static int do_raw_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + if (optname == ICMP_FILTER) { + if (inet_sk(sk)->inet_num != IPPROTO_ICMP) + return -EOPNOTSUPP; + else + return raw_geticmpfilter(sk, optval, optlen); + } + return -ENOPROTOOPT; +} + +static int raw_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + if (level != SOL_RAW) + return 
ip_getsockopt(sk, level, optname, optval, optlen); + return do_raw_getsockopt(sk, level, optname, optval, optlen); +} + +static int raw_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: { + int amount = sk_wmem_alloc_get(sk); + + return put_user(amount, (int __user *)arg); + } + case SIOCINQ: { + struct sk_buff *skb; + int amount = 0; + + spin_lock_bh(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + amount = skb->len; + spin_unlock_bh(&sk->sk_receive_queue.lock); + return put_user(amount, (int __user *)arg); + } + + default: +#ifdef CONFIG_IP_MROUTE + return ipmr_ioctl(sk, cmd, (void __user *)arg); +#else + return -ENOIOCTLCMD; +#endif + } +} + +#ifdef CONFIG_COMPAT +static int compat_raw_ioctl(struct sock *sk, unsigned int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: + case SIOCINQ: + return -ENOIOCTLCMD; + default: +#ifdef CONFIG_IP_MROUTE + return ipmr_compat_ioctl(sk, cmd, compat_ptr(arg)); +#else + return -ENOIOCTLCMD; +#endif + } +} +#endif + +int raw_abort(struct sock *sk, int err) +{ + lock_sock(sk); + + sk->sk_err = err; + sk_error_report(sk); + __udp_disconnect(sk, 0); + + release_sock(sk); + + return 0; +} +EXPORT_SYMBOL_GPL(raw_abort); + +struct proto raw_prot = { + .name = "RAW", + .owner = THIS_MODULE, + .close = raw_close, + .destroy = raw_destroy, + .connect = ip4_datagram_connect, + .disconnect = __udp_disconnect, + .ioctl = raw_ioctl, + .init = raw_sk_init, + .setsockopt = raw_setsockopt, + .getsockopt = raw_getsockopt, + .sendmsg = raw_sendmsg, + .recvmsg = raw_recvmsg, + .bind = raw_bind, + .backlog_rcv = raw_rcv_skb, + .release_cb = ip4_datagram_release_cb, + .hash = raw_hash_sk, + .unhash = raw_unhash_sk, + .obj_size = sizeof(struct raw_sock), + .useroffset = offsetof(struct raw_sock, filter), + .usersize = sizeof_field(struct raw_sock, filter), + .h.raw_hash = &raw_v4_hashinfo, +#ifdef CONFIG_COMPAT + .compat_ioctl = compat_raw_ioctl, +#endif + .diag_destroy = raw_abort, +}; + +#ifdef CONFIG_PROC_FS +static struct sock *raw_get_first(struct seq_file *seq, int bucket) +{ + struct raw_hashinfo *h = pde_data(file_inode(seq->file)); + struct raw_iter_state *state = raw_seq_private(seq); + struct hlist_head *hlist; + struct sock *sk; + + for (state->bucket = bucket; state->bucket < RAW_HTABLE_SIZE; + ++state->bucket) { + hlist = &h->ht[state->bucket]; + sk_for_each(sk, hlist) { + if (sock_net(sk) == seq_file_net(seq)) + return sk; + } + } + return NULL; +} + +static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk) +{ + struct raw_iter_state *state = raw_seq_private(seq); + + do { + sk = sk_next(sk); + } while (sk && sock_net(sk) != seq_file_net(seq)); + + if (!sk) + return raw_get_first(seq, state->bucket + 1); + return sk; +} + +static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos) +{ + struct sock *sk = raw_get_first(seq, 0); + + if (sk) + while (pos && (sk = raw_get_next(seq, sk)) != NULL) + --pos; + return pos ? NULL : sk; +} + +void *raw_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(&h->lock) +{ + struct raw_hashinfo *h = pde_data(file_inode(seq->file)); + + spin_lock(&h->lock); + + return *pos ? 
raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} +EXPORT_SYMBOL_GPL(raw_seq_start); + +void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock *sk; + + if (v == SEQ_START_TOKEN) + sk = raw_get_first(seq, 0); + else + sk = raw_get_next(seq, v); + ++*pos; + return sk; +} +EXPORT_SYMBOL_GPL(raw_seq_next); + +void raw_seq_stop(struct seq_file *seq, void *v) + __releases(&h->lock) +{ + struct raw_hashinfo *h = pde_data(file_inode(seq->file)); + + spin_unlock(&h->lock); +} +EXPORT_SYMBOL_GPL(raw_seq_stop); + +static void raw_sock_seq_show(struct seq_file *seq, struct sock *sp, int i) +{ + struct inet_sock *inet = inet_sk(sp); + __be32 dest = inet->inet_daddr, + src = inet->inet_rcv_saddr; + __u16 destp = 0, + srcp = inet->inet_num; + + seq_printf(seq, "%4d: %08X:%04X %08X:%04X" + " %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u\n", + i, src, srcp, dest, destp, sp->sk_state, + sk_wmem_alloc_get(sp), + sk_rmem_alloc_get(sp), + 0, 0L, 0, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sp)), + 0, sock_i_ino(sp), + refcount_read(&sp->sk_refcnt), sp, atomic_read(&sp->sk_drops)); +} + +static int raw_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_printf(seq, " sl local_address rem_address st tx_queue " + "rx_queue tr tm->when retrnsmt uid timeout " + "inode ref pointer drops\n"); + else + raw_sock_seq_show(seq, v, raw_seq_private(seq)->bucket); + return 0; +} + +static const struct seq_operations raw_seq_ops = { + .start = raw_seq_start, + .next = raw_seq_next, + .stop = raw_seq_stop, + .show = raw_seq_show, +}; + +static __net_init int raw_init_net(struct net *net) +{ + if (!proc_create_net_data("raw", 0444, net->proc_net, &raw_seq_ops, + sizeof(struct raw_iter_state), &raw_v4_hashinfo)) + return -ENOMEM; + + return 0; +} + +static __net_exit void raw_exit_net(struct net *net) +{ + remove_proc_entry("raw", net->proc_net); +} + +static __net_initdata struct pernet_operations raw_net_ops = { + .init = raw_init_net, + .exit = raw_exit_net, +}; + +int __init raw_proc_init(void) +{ + + return register_pernet_subsys(&raw_net_ops); +} + +void __init raw_proc_exit(void) +{ + unregister_pernet_subsys(&raw_net_ops); +} +#endif /* CONFIG_PROC_FS */ + +static void raw_sysctl_init_net(struct net *net) +{ +#ifdef CONFIG_NET_L3_MASTER_DEV + net->ipv4.sysctl_raw_l3mdev_accept = 1; +#endif +} + +static int __net_init raw_sysctl_init(struct net *net) +{ + raw_sysctl_init_net(net); + return 0; +} + +static struct pernet_operations __net_initdata raw_sysctl_ops = { + .init = raw_sysctl_init, +}; + +void __init raw_init(void) +{ + raw_sysctl_init_net(&init_net); + if (register_pernet_subsys(&raw_sysctl_ops)) + panic("RAW: failed to init sysctl parameters.\n"); +} diff --git a/net/ipv4/raw_diag.c b/net/ipv4/raw_diag.c new file mode 100644 index 000000000..da3591a66 --- /dev/null +++ b/net/ipv4/raw_diag.c @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include + +#include +#include + +#include +#include +#include + +#ifdef pr_fmt +# undef pr_fmt +#endif + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +static struct raw_hashinfo * +raw_get_hashinfo(const struct inet_diag_req_v2 *r) +{ + if (r->sdiag_family == AF_INET) { + return &raw_v4_hashinfo; +#if IS_ENABLED(CONFIG_IPV6) + } else if (r->sdiag_family == AF_INET6) { + return &raw_v6_hashinfo; +#endif + } else { + return ERR_PTR(-EINVAL); + } +} + +/* + * Due to requirement of not breaking user API we can't simply + * rename @pad field in inet_diag_req_v2 structure, instead + * use helper to figure 
it out. + */ + +static bool raw_lookup(struct net *net, struct sock *sk, + const struct inet_diag_req_v2 *req) +{ + struct inet_diag_req_raw *r = (void *)req; + + if (r->sdiag_family == AF_INET) + return raw_v4_match(net, sk, r->sdiag_raw_protocol, + r->id.idiag_dst[0], + r->id.idiag_src[0], + r->id.idiag_if, 0); +#if IS_ENABLED(CONFIG_IPV6) + else + return raw_v6_match(net, sk, r->sdiag_raw_protocol, + (const struct in6_addr *)r->id.idiag_src, + (const struct in6_addr *)r->id.idiag_dst, + r->id.idiag_if, 0); +#endif + return false; +} + +static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r) +{ + struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); + struct hlist_head *hlist; + struct sock *sk; + int slot; + + if (IS_ERR(hashinfo)) + return ERR_CAST(hashinfo); + + rcu_read_lock(); + for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) { + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { + if (raw_lookup(net, sk, r)) { + /* + * Grab it and keep until we fill + * diag message to be reported, so + * caller should call sock_put then. + */ + if (refcount_inc_not_zero(&sk->sk_refcnt)) + goto out_unlock; + } + } + } + sk = ERR_PTR(-ENOENT); +out_unlock: + rcu_read_unlock(); + + return sk; +} + +static int raw_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + struct sk_buff *in_skb = cb->skb; + struct sk_buff *rep; + struct sock *sk; + struct net *net; + int err; + + net = sock_net(in_skb->sk); + sk = raw_sock_get(net, r); + if (IS_ERR(sk)) + return PTR_ERR(sk); + + rep = nlmsg_new(nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + 64, + GFP_KERNEL); + if (!rep) { + sock_put(sk); + return -ENOMEM; + } + + err = inet_sk_diag_fill(sk, NULL, rep, cb, r, 0, + netlink_net_capable(in_skb, CAP_NET_ADMIN)); + sock_put(sk); + + if (err < 0) { + kfree_skb(rep); + return err; + } + + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + + return err; +} + +static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *r, + struct nlattr *bc, bool net_admin) +{ + if (!inet_diag_bc_sk(bc, sk)) + return 0; + + return inet_sk_diag_fill(sk, NULL, skb, cb, r, NLM_F_MULTI, net_admin); +} + +static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); + struct net *net = sock_net(skb->sk); + struct inet_diag_dump_data *cb_data; + int num, s_num, slot, s_slot; + struct hlist_head *hlist; + struct sock *sk = NULL; + struct nlattr *bc; + + if (IS_ERR(hashinfo)) + return; + + cb_data = cb->data; + bc = cb_data->inet_diag_nla_bc; + s_slot = cb->args[0]; + num = s_num = cb->args[1]; + + rcu_read_lock(); + for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) { + num = 0; + + hlist = &hashinfo->ht[slot]; + sk_for_each_rcu(sk, hlist) { + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) + goto next; + if (sk->sk_family != r->sdiag_family) + goto next; + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next; + if (r->id.idiag_dport != inet->inet_dport && + r->id.idiag_dport) + goto next; + if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) + goto out_unlock; +next: + num++; + } + } + +out_unlock: + rcu_read_unlock(); + + cb->args[0] = slot; + 
cb->args[1] = num; +} + +static void raw_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *info) +{ + r->idiag_rqueue = sk_rmem_alloc_get(sk); + r->idiag_wqueue = sk_wmem_alloc_get(sk); +} + +#ifdef CONFIG_INET_DIAG_DESTROY +static int raw_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *r) +{ + struct net *net = sock_net(in_skb->sk); + struct sock *sk; + int err; + + sk = raw_sock_get(net, r); + if (IS_ERR(sk)) + return PTR_ERR(sk); + err = sock_diag_destroy(sk, ECONNABORTED); + sock_put(sk); + return err; +} +#endif + +static const struct inet_diag_handler raw_diag_handler = { + .dump = raw_diag_dump, + .dump_one = raw_diag_dump_one, + .idiag_get_info = raw_diag_get_info, + .idiag_type = IPPROTO_RAW, + .idiag_info_size = 0, +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = raw_diag_destroy, +#endif +}; + +static void __always_unused __check_inet_diag_req_raw(void) +{ + /* + * Make sure the two structures are identical, + * except the @pad field. + */ +#define __offset_mismatch(m1, m2) \ + (offsetof(struct inet_diag_req_v2, m1) != \ + offsetof(struct inet_diag_req_raw, m2)) + + BUILD_BUG_ON(sizeof(struct inet_diag_req_v2) != + sizeof(struct inet_diag_req_raw)); + BUILD_BUG_ON(__offset_mismatch(sdiag_family, sdiag_family)); + BUILD_BUG_ON(__offset_mismatch(sdiag_protocol, sdiag_protocol)); + BUILD_BUG_ON(__offset_mismatch(idiag_ext, idiag_ext)); + BUILD_BUG_ON(__offset_mismatch(pad, sdiag_raw_protocol)); + BUILD_BUG_ON(__offset_mismatch(idiag_states, idiag_states)); + BUILD_BUG_ON(__offset_mismatch(id, id)); +#undef __offset_mismatch +} + +static int __init raw_diag_init(void) +{ + return inet_diag_register(&raw_diag_handler); +} + +static void __exit raw_diag_exit(void) +{ + inet_diag_unregister(&raw_diag_handler); +} + +module_init(raw_diag_init); +module_exit(raw_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-255 /* AF_INET - IPPROTO_RAW */); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 10-255 /* AF_INET6 - IPPROTO_RAW */); diff --git a/net/ipv4/route.c b/net/ipv4/route.c new file mode 100644 index 000000000..474f391fa --- /dev/null +++ b/net/ipv4/route.c @@ -0,0 +1,3789 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * ROUTE - implementation of the IP router. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Alan Cox, + * Linus Torvalds, + * Alexey Kuznetsov, + * + * Fixes: + * Alan Cox : Verify area fixes. + * Alan Cox : cli() protects routing changes + * Rui Oliveira : ICMP routing table updates + * (rco@di.uminho.pt) Routing table insertion and update + * Linus Torvalds : Rewrote bits to be sensible + * Alan Cox : Added BSD route gw semantics + * Alan Cox : Super /proc >4K + * Alan Cox : MTU in route table + * Alan Cox : MSS actually. Also added the window + * clamper. + * Sam Lantinga : Fixed route matching in rt_del() + * Alan Cox : Routing cache support. + * Alan Cox : Removed compatibility cruft. + * Alan Cox : RTF_REJECT support. + * Alan Cox : TCP irtt support. + * Jonathan Naylor : Added Metric support. + * Miquel van Smoorenburg : BSD API fixes. + * Miquel van Smoorenburg : Metrics. + * Alan Cox : Use __u32 properly + * Alan Cox : Aligned routing errors more closely with BSD + * our system is still very different. 
+ * Alan Cox : Faster /proc handling + * Alexey Kuznetsov : Massive rework to support tree based routing, + * routing caches and better behaviour. + * + * Olaf Erb : irtt wasn't being copied right. + * Bjorn Ekwall : Kerneld route support. + * Alan Cox : Multicast fixed (I hope) + * Pavel Krauz : Limited broadcast fixed + * Mike McLagan : Routing by source + * Alexey Kuznetsov : End of old history. Split to fib.c and + * route.c and rewritten from scratch. + * Andi Kleen : Load-limit warning messages. + * Vitaly E. Lavrov : Transparent proxy revived after year coma. + * Vitaly E. Lavrov : Race condition in ip_route_input_slow. + * Tobias Ringstrom : Uninitialized res.type in ip_route_output_slow. + * Vladimir V. Ivanov : IP rule info (flowid) is really useful. + * Marc Boucher : routing by fwmark + * Robert Olsson : Added rt_cache statistics + * Arnaldo C. Melo : Convert proc stuff to seq_file + * Eric Dumazet : hashed spinlocks and rt_check_expire() fixes. + * Ilia Sotnikov : Ignore TOS on PMTUD and Redirect + * Ilia Sotnikov : Removed TOS from hash calculations + */ + +#define pr_fmt(fmt) "IPv4: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif +#include +#include + +#include "fib_lookup.h" + +#define RT_FL_TOS(oldflp4) \ + ((oldflp4)->flowi4_tos & (IPTOS_RT_MASK | RTO_ONLINK)) + +#define RT_GC_TIMEOUT (300*HZ) + +#define DEFAULT_MIN_PMTU (512 + 20 + 20) +#define DEFAULT_MTU_EXPIRES (10 * 60 * HZ) +#define DEFAULT_MIN_ADVMSS 256 +static int ip_rt_max_size; +static int ip_rt_redirect_number __read_mostly = 9; +static int ip_rt_redirect_load __read_mostly = HZ / 50; +static int ip_rt_redirect_silence __read_mostly = ((HZ / 50) << (9 + 1)); +static int ip_rt_error_cost __read_mostly = HZ; +static int ip_rt_error_burst __read_mostly = 5 * HZ; + +static int ip_rt_gc_timeout __read_mostly = RT_GC_TIMEOUT; + +/* + * Interface to generic destination cache. 
+ */ + +INDIRECT_CALLABLE_SCOPE +struct dst_entry *ipv4_dst_check(struct dst_entry *dst, u32 cookie); +static unsigned int ipv4_default_advmss(const struct dst_entry *dst); +INDIRECT_CALLABLE_SCOPE +unsigned int ipv4_mtu(const struct dst_entry *dst); +static struct dst_entry *ipv4_negative_advice(struct dst_entry *dst); +static void ipv4_link_failure(struct sk_buff *skb); +static void ip_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh); +static void ip_do_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb); +static void ipv4_dst_destroy(struct dst_entry *dst); + +static u32 *ipv4_cow_metrics(struct dst_entry *dst, unsigned long old) +{ + WARN_ON(1); + return NULL; +} + +static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr); +static void ipv4_confirm_neigh(const struct dst_entry *dst, const void *daddr); + +static struct dst_ops ipv4_dst_ops = { + .family = AF_INET, + .check = ipv4_dst_check, + .default_advmss = ipv4_default_advmss, + .mtu = ipv4_mtu, + .cow_metrics = ipv4_cow_metrics, + .destroy = ipv4_dst_destroy, + .negative_advice = ipv4_negative_advice, + .link_failure = ipv4_link_failure, + .update_pmtu = ip_rt_update_pmtu, + .redirect = ip_do_redirect, + .local_out = __ip_local_out, + .neigh_lookup = ipv4_neigh_lookup, + .confirm_neigh = ipv4_confirm_neigh, +}; + +#define ECN_OR_COST(class) TC_PRIO_##class + +const __u8 ip_tos2prio[16] = { + TC_PRIO_BESTEFFORT, + ECN_OR_COST(BESTEFFORT), + TC_PRIO_BESTEFFORT, + ECN_OR_COST(BESTEFFORT), + TC_PRIO_BULK, + ECN_OR_COST(BULK), + TC_PRIO_BULK, + ECN_OR_COST(BULK), + TC_PRIO_INTERACTIVE, + ECN_OR_COST(INTERACTIVE), + TC_PRIO_INTERACTIVE, + ECN_OR_COST(INTERACTIVE), + TC_PRIO_INTERACTIVE_BULK, + ECN_OR_COST(INTERACTIVE_BULK), + TC_PRIO_INTERACTIVE_BULK, + ECN_OR_COST(INTERACTIVE_BULK) +}; +EXPORT_SYMBOL(ip_tos2prio); + +static DEFINE_PER_CPU(struct rt_cache_stat, rt_cache_stat); +#define RT_CACHE_STAT_INC(field) raw_cpu_inc(rt_cache_stat.field) + +#ifdef CONFIG_PROC_FS +static void *rt_cache_seq_start(struct seq_file *seq, loff_t *pos) +{ + if (*pos) + return NULL; + return SEQ_START_TOKEN; +} + +static void *rt_cache_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return NULL; +} + +static void rt_cache_seq_stop(struct seq_file *seq, void *v) +{ +} + +static int rt_cache_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_printf(seq, "%-127s\n", + "Iface\tDestination\tGateway \tFlags\t\tRefCnt\tUse\t" + "Metric\tSource\t\tMTU\tWindow\tIRTT\tTOS\tHHRef\t" + "HHUptod\tSpecDst"); + return 0; +} + +static const struct seq_operations rt_cache_seq_ops = { + .start = rt_cache_seq_start, + .next = rt_cache_seq_next, + .stop = rt_cache_seq_stop, + .show = rt_cache_seq_show, +}; + +static void *rt_cpu_seq_start(struct seq_file *seq, loff_t *pos) +{ + int cpu; + + if (*pos == 0) + return SEQ_START_TOKEN; + + for (cpu = *pos-1; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu+1; + return &per_cpu(rt_cache_stat, cpu); + } + return NULL; +} + +static void *rt_cpu_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + int cpu; + + for (cpu = *pos; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu+1; + return &per_cpu(rt_cache_stat, cpu); + } + (*pos)++; + return NULL; + +} + +static void rt_cpu_seq_stop(struct seq_file *seq, void *v) +{ + +} + +static int rt_cpu_seq_show(struct seq_file *seq, void *v) +{ + struct 
rt_cache_stat *st = v; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "entries in_hit in_slow_tot in_slow_mc in_no_route in_brd in_martian_dst in_martian_src out_hit out_slow_tot out_slow_mc gc_total gc_ignored gc_goal_miss gc_dst_overflow in_hlist_search out_hlist_search\n"); + return 0; + } + + seq_printf(seq, "%08x %08x %08x %08x %08x %08x %08x " + "%08x %08x %08x %08x %08x %08x " + "%08x %08x %08x %08x\n", + dst_entries_get_slow(&ipv4_dst_ops), + 0, /* st->in_hit */ + st->in_slow_tot, + st->in_slow_mc, + st->in_no_route, + st->in_brd, + st->in_martian_dst, + st->in_martian_src, + + 0, /* st->out_hit */ + st->out_slow_tot, + st->out_slow_mc, + + 0, /* st->gc_total */ + 0, /* st->gc_ignored */ + 0, /* st->gc_goal_miss */ + 0, /* st->gc_dst_overflow */ + 0, /* st->in_hlist_search */ + 0 /* st->out_hlist_search */ + ); + return 0; +} + +static const struct seq_operations rt_cpu_seq_ops = { + .start = rt_cpu_seq_start, + .next = rt_cpu_seq_next, + .stop = rt_cpu_seq_stop, + .show = rt_cpu_seq_show, +}; + +#ifdef CONFIG_IP_ROUTE_CLASSID +static int rt_acct_proc_show(struct seq_file *m, void *v) +{ + struct ip_rt_acct *dst, *src; + unsigned int i, j; + + dst = kcalloc(256, sizeof(struct ip_rt_acct), GFP_KERNEL); + if (!dst) + return -ENOMEM; + + for_each_possible_cpu(i) { + src = (struct ip_rt_acct *)per_cpu_ptr(ip_rt_acct, i); + for (j = 0; j < 256; j++) { + dst[j].o_bytes += src[j].o_bytes; + dst[j].o_packets += src[j].o_packets; + dst[j].i_bytes += src[j].i_bytes; + dst[j].i_packets += src[j].i_packets; + } + } + + seq_write(m, dst, 256 * sizeof(struct ip_rt_acct)); + kfree(dst); + return 0; +} +#endif + +static int __net_init ip_rt_do_proc_init(struct net *net) +{ + struct proc_dir_entry *pde; + + pde = proc_create_seq("rt_cache", 0444, net->proc_net, + &rt_cache_seq_ops); + if (!pde) + goto err1; + + pde = proc_create_seq("rt_cache", 0444, net->proc_net_stat, + &rt_cpu_seq_ops); + if (!pde) + goto err2; + +#ifdef CONFIG_IP_ROUTE_CLASSID + pde = proc_create_single("rt_acct", 0, net->proc_net, + rt_acct_proc_show); + if (!pde) + goto err3; +#endif + return 0; + +#ifdef CONFIG_IP_ROUTE_CLASSID +err3: + remove_proc_entry("rt_cache", net->proc_net_stat); +#endif +err2: + remove_proc_entry("rt_cache", net->proc_net); +err1: + return -ENOMEM; +} + +static void __net_exit ip_rt_do_proc_exit(struct net *net) +{ + remove_proc_entry("rt_cache", net->proc_net_stat); + remove_proc_entry("rt_cache", net->proc_net); +#ifdef CONFIG_IP_ROUTE_CLASSID + remove_proc_entry("rt_acct", net->proc_net); +#endif +} + +static struct pernet_operations ip_rt_proc_ops __net_initdata = { + .init = ip_rt_do_proc_init, + .exit = ip_rt_do_proc_exit, +}; + +static int __init ip_rt_proc_init(void) +{ + return register_pernet_subsys(&ip_rt_proc_ops); +} + +#else +static inline int ip_rt_proc_init(void) +{ + return 0; +} +#endif /* CONFIG_PROC_FS */ + +static inline bool rt_is_expired(const struct rtable *rth) +{ + return rth->rt_genid != rt_genid_ipv4(dev_net(rth->dst.dev)); +} + +void rt_cache_flush(struct net *net) +{ + rt_genid_bump_ipv4(net); +} + +static struct neighbour *ipv4_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr) +{ + const struct rtable *rt = container_of(dst, struct rtable, dst); + struct net_device *dev = dst->dev; + struct neighbour *n; + + rcu_read_lock(); + + if (likely(rt->rt_gw_family == AF_INET)) { + n = ip_neigh_gw4(dev, rt->rt_gw4); + } else if (rt->rt_gw_family == AF_INET6) { + n = ip_neigh_gw6(dev, &rt->rt_gw6); + } else { + __be32 pkey; + + pkey = skb ? 
ip_hdr(skb)->daddr : *((__be32 *) daddr); + n = ip_neigh_gw4(dev, pkey); + } + + if (!IS_ERR(n) && !refcount_inc_not_zero(&n->refcnt)) + n = NULL; + + rcu_read_unlock(); + + return n; +} + +static void ipv4_confirm_neigh(const struct dst_entry *dst, const void *daddr) +{ + const struct rtable *rt = container_of(dst, struct rtable, dst); + struct net_device *dev = dst->dev; + const __be32 *pkey = daddr; + + if (rt->rt_gw_family == AF_INET) { + pkey = (const __be32 *)&rt->rt_gw4; + } else if (rt->rt_gw_family == AF_INET6) { + return __ipv6_confirm_neigh_stub(dev, &rt->rt_gw6); + } else if (!daddr || + (rt->rt_flags & + (RTCF_MULTICAST | RTCF_BROADCAST | RTCF_LOCAL))) { + return; + } + __ipv4_confirm_neigh(dev, *(__force u32 *)pkey); +} + +/* Hash tables of size 2048..262144 depending on RAM size. + * Each bucket uses 8 bytes. + */ +static u32 ip_idents_mask __read_mostly; +static atomic_t *ip_idents __read_mostly; +static u32 *ip_tstamps __read_mostly; + +/* In order to protect privacy, we add a perturbation to identifiers + * if one generator is seldom used. This makes hard for an attacker + * to infer how many packets were sent between two points in time. + */ +static u32 ip_idents_reserve(u32 hash, int segs) +{ + u32 bucket, old, now = (u32)jiffies; + atomic_t *p_id; + u32 *p_tstamp; + u32 delta = 0; + + bucket = hash & ip_idents_mask; + p_tstamp = ip_tstamps + bucket; + p_id = ip_idents + bucket; + old = READ_ONCE(*p_tstamp); + + if (old != now && cmpxchg(p_tstamp, old, now) == old) + delta = prandom_u32_max(now - old); + + /* If UBSAN reports an error there, please make sure your compiler + * supports -fno-strict-overflow before reporting it that was a bug + * in UBSAN, and it has been fixed in GCC-8. + */ + return atomic_add_return(segs + delta, p_id) - segs; +} + +void __ip_select_ident(struct net *net, struct iphdr *iph, int segs) +{ + u32 hash, id; + + /* Note the following code is not safe, but this is okay. */ + if (unlikely(siphash_key_is_zero(&net->ipv4.ip_id_key))) + get_random_bytes(&net->ipv4.ip_id_key, + sizeof(net->ipv4.ip_id_key)); + + hash = siphash_3u32((__force u32)iph->daddr, + (__force u32)iph->saddr, + iph->protocol, + &net->ipv4.ip_id_key); + id = ip_idents_reserve(hash, segs); + iph->id = htons(id); +} +EXPORT_SYMBOL(__ip_select_ident); + +static void ip_rt_fix_tos(struct flowi4 *fl4) +{ + __u8 tos = RT_FL_TOS(fl4); + + fl4->flowi4_tos = tos & IPTOS_RT_MASK; + if (tos & RTO_ONLINK) + fl4->flowi4_scope = RT_SCOPE_LINK; +} + +static void __build_flow_key(const struct net *net, struct flowi4 *fl4, + const struct sock *sk, const struct iphdr *iph, + int oif, __u8 tos, u8 prot, u32 mark, + int flow_flags) +{ + __u8 scope = RT_SCOPE_UNIVERSE; + + if (sk) { + const struct inet_sock *inet = inet_sk(sk); + + oif = sk->sk_bound_dev_if; + mark = READ_ONCE(sk->sk_mark); + tos = ip_sock_rt_tos(sk); + scope = ip_sock_rt_scope(sk); + prot = inet->hdrincl ? 
IPPROTO_RAW : sk->sk_protocol; + } + + flowi4_init_output(fl4, oif, mark, tos & IPTOS_RT_MASK, scope, + prot, flow_flags, iph->daddr, iph->saddr, 0, 0, + sock_net_uid(net, sk)); +} + +static void build_skb_flow_key(struct flowi4 *fl4, const struct sk_buff *skb, + const struct sock *sk) +{ + const struct net *net = dev_net(skb->dev); + const struct iphdr *iph = ip_hdr(skb); + int oif = skb->dev->ifindex; + u8 prot = iph->protocol; + u32 mark = skb->mark; + __u8 tos = iph->tos; + + __build_flow_key(net, fl4, sk, iph, oif, tos, prot, mark, 0); +} + +static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct ip_options_rcu *inet_opt; + __be32 daddr = inet->inet_daddr; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + flowi4_init_output(fl4, sk->sk_bound_dev_if, READ_ONCE(sk->sk_mark), + ip_sock_rt_tos(sk) & IPTOS_RT_MASK, + ip_sock_rt_scope(sk), + inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol, + inet_sk_flowi_flags(sk), + daddr, inet->inet_saddr, 0, 0, sk->sk_uid); + rcu_read_unlock(); +} + +static void ip_rt_build_flow_key(struct flowi4 *fl4, const struct sock *sk, + const struct sk_buff *skb) +{ + if (skb) + build_skb_flow_key(fl4, skb, sk); + else + build_sk_flow_key(fl4, sk); +} + +static DEFINE_SPINLOCK(fnhe_lock); + +static void fnhe_flush_routes(struct fib_nh_exception *fnhe) +{ + struct rtable *rt; + + rt = rcu_dereference(fnhe->fnhe_rth_input); + if (rt) { + RCU_INIT_POINTER(fnhe->fnhe_rth_input, NULL); + dst_dev_put(&rt->dst); + dst_release(&rt->dst); + } + rt = rcu_dereference(fnhe->fnhe_rth_output); + if (rt) { + RCU_INIT_POINTER(fnhe->fnhe_rth_output, NULL); + dst_dev_put(&rt->dst); + dst_release(&rt->dst); + } +} + +static void fnhe_remove_oldest(struct fnhe_hash_bucket *hash) +{ + struct fib_nh_exception __rcu **fnhe_p, **oldest_p; + struct fib_nh_exception *fnhe, *oldest = NULL; + + for (fnhe_p = &hash->chain; ; fnhe_p = &fnhe->fnhe_next) { + fnhe = rcu_dereference_protected(*fnhe_p, + lockdep_is_held(&fnhe_lock)); + if (!fnhe) + break; + if (!oldest || + time_before(fnhe->fnhe_stamp, oldest->fnhe_stamp)) { + oldest = fnhe; + oldest_p = fnhe_p; + } + } + fnhe_flush_routes(oldest); + *oldest_p = oldest->fnhe_next; + kfree_rcu(oldest, rcu); +} + +static u32 fnhe_hashfun(__be32 daddr) +{ + static siphash_aligned_key_t fnhe_hash_key; + u64 hval; + + net_get_random_once(&fnhe_hash_key, sizeof(fnhe_hash_key)); + hval = siphash_1u32((__force u32)daddr, &fnhe_hash_key); + return hash_64(hval, FNHE_HASH_SHIFT); +} + +static void fill_route_from_fnhe(struct rtable *rt, struct fib_nh_exception *fnhe) +{ + rt->rt_pmtu = fnhe->fnhe_pmtu; + rt->rt_mtu_locked = fnhe->fnhe_mtu_locked; + rt->dst.expires = fnhe->fnhe_expires; + + if (fnhe->fnhe_gw) { + rt->rt_flags |= RTCF_REDIRECTED; + rt->rt_uses_gateway = 1; + rt->rt_gw_family = AF_INET; + rt->rt_gw4 = fnhe->fnhe_gw; + } +} + +static void update_or_create_fnhe(struct fib_nh_common *nhc, __be32 daddr, + __be32 gw, u32 pmtu, bool lock, + unsigned long expires) +{ + struct fnhe_hash_bucket *hash; + struct fib_nh_exception *fnhe; + struct rtable *rt; + u32 genid, hval; + unsigned int i; + int depth; + + genid = fnhe_genid(dev_net(nhc->nhc_dev)); + hval = fnhe_hashfun(daddr); + + spin_lock_bh(&fnhe_lock); + + hash = rcu_dereference(nhc->nhc_exceptions); + if (!hash) { + hash = kcalloc(FNHE_HASH_SIZE, sizeof(*hash), GFP_ATOMIC); + if (!hash) + goto out_unlock; + 
rcu_assign_pointer(nhc->nhc_exceptions, hash); + } + + hash += hval; + + depth = 0; + for (fnhe = rcu_dereference(hash->chain); fnhe; + fnhe = rcu_dereference(fnhe->fnhe_next)) { + if (fnhe->fnhe_daddr == daddr) + break; + depth++; + } + + if (fnhe) { + if (fnhe->fnhe_genid != genid) + fnhe->fnhe_genid = genid; + if (gw) + fnhe->fnhe_gw = gw; + if (pmtu) { + fnhe->fnhe_pmtu = pmtu; + fnhe->fnhe_mtu_locked = lock; + } + fnhe->fnhe_expires = max(1UL, expires); + /* Update all cached dsts too */ + rt = rcu_dereference(fnhe->fnhe_rth_input); + if (rt) + fill_route_from_fnhe(rt, fnhe); + rt = rcu_dereference(fnhe->fnhe_rth_output); + if (rt) + fill_route_from_fnhe(rt, fnhe); + } else { + /* Randomize max depth to avoid some side channels attacks. */ + int max_depth = FNHE_RECLAIM_DEPTH + + prandom_u32_max(FNHE_RECLAIM_DEPTH); + + while (depth > max_depth) { + fnhe_remove_oldest(hash); + depth--; + } + + fnhe = kzalloc(sizeof(*fnhe), GFP_ATOMIC); + if (!fnhe) + goto out_unlock; + + fnhe->fnhe_next = hash->chain; + + fnhe->fnhe_genid = genid; + fnhe->fnhe_daddr = daddr; + fnhe->fnhe_gw = gw; + fnhe->fnhe_pmtu = pmtu; + fnhe->fnhe_mtu_locked = lock; + fnhe->fnhe_expires = max(1UL, expires); + + rcu_assign_pointer(hash->chain, fnhe); + + /* Exception created; mark the cached routes for the nexthop + * stale, so anyone caching it rechecks if this exception + * applies to them. + */ + rt = rcu_dereference(nhc->nhc_rth_input); + if (rt) + rt->dst.obsolete = DST_OBSOLETE_KILL; + + for_each_possible_cpu(i) { + struct rtable __rcu **prt; + + prt = per_cpu_ptr(nhc->nhc_pcpu_rth_output, i); + rt = rcu_dereference(*prt); + if (rt) + rt->dst.obsolete = DST_OBSOLETE_KILL; + } + } + + fnhe->fnhe_stamp = jiffies; + +out_unlock: + spin_unlock_bh(&fnhe_lock); +} + +static void __ip_do_redirect(struct rtable *rt, struct sk_buff *skb, struct flowi4 *fl4, + bool kill_route) +{ + __be32 new_gw = icmp_hdr(skb)->un.gateway; + __be32 old_gw = ip_hdr(skb)->saddr; + struct net_device *dev = skb->dev; + struct in_device *in_dev; + struct fib_result res; + struct neighbour *n; + struct net *net; + + switch (icmp_hdr(skb)->code & 7) { + case ICMP_REDIR_NET: + case ICMP_REDIR_NETTOS: + case ICMP_REDIR_HOST: + case ICMP_REDIR_HOSTTOS: + break; + + default: + return; + } + + if (rt->rt_gw_family != AF_INET || rt->rt_gw4 != old_gw) + return; + + in_dev = __in_dev_get_rcu(dev); + if (!in_dev) + return; + + net = dev_net(dev); + if (new_gw == old_gw || !IN_DEV_RX_REDIRECTS(in_dev) || + ipv4_is_multicast(new_gw) || ipv4_is_lbcast(new_gw) || + ipv4_is_zeronet(new_gw)) + goto reject_redirect; + + if (!IN_DEV_SHARED_MEDIA(in_dev)) { + if (!inet_addr_onlink(in_dev, new_gw, old_gw)) + goto reject_redirect; + if (IN_DEV_SEC_REDIRECTS(in_dev) && ip_fib_check_default(new_gw, dev)) + goto reject_redirect; + } else { + if (inet_addr_type(net, new_gw) != RTN_UNICAST) + goto reject_redirect; + } + + n = __ipv4_neigh_lookup(rt->dst.dev, (__force u32)new_gw); + if (!n) + n = neigh_create(&arp_tbl, &new_gw, rt->dst.dev); + if (!IS_ERR(n)) { + if (!(READ_ONCE(n->nud_state) & NUD_VALID)) { + neigh_event_send(n, NULL); + } else { + if (fib_lookup(net, fl4, &res, 0) == 0) { + struct fib_nh_common *nhc; + + fib_select_path(net, &res, fl4, skb); + nhc = FIB_RES_NHC(res); + update_or_create_fnhe(nhc, fl4->daddr, new_gw, + 0, false, + jiffies + ip_rt_gc_timeout); + } + if (kill_route) + rt->dst.obsolete = DST_OBSOLETE_KILL; + call_netevent_notifiers(NETEVENT_NEIGH_UPDATE, n); + } + neigh_release(n); + } + return; + +reject_redirect: +#ifdef 
CONFIG_IP_ROUTE_VERBOSE + if (IN_DEV_LOG_MARTIANS(in_dev)) { + const struct iphdr *iph = (const struct iphdr *) skb->data; + __be32 daddr = iph->daddr; + __be32 saddr = iph->saddr; + + net_info_ratelimited("Redirect from %pI4 on %s about %pI4 ignored\n" + " Advised path = %pI4 -> %pI4\n", + &old_gw, dev->name, &new_gw, + &saddr, &daddr); + } +#endif + ; +} + +static void ip_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_buff *skb) +{ + struct rtable *rt; + struct flowi4 fl4; + const struct iphdr *iph = (const struct iphdr *) skb->data; + struct net *net = dev_net(skb->dev); + int oif = skb->dev->ifindex; + u8 prot = iph->protocol; + u32 mark = skb->mark; + __u8 tos = iph->tos; + + rt = (struct rtable *) dst; + + __build_flow_key(net, &fl4, sk, iph, oif, tos, prot, mark, 0); + __ip_do_redirect(rt, skb, &fl4, true); +} + +static struct dst_entry *ipv4_negative_advice(struct dst_entry *dst) +{ + struct rtable *rt = (struct rtable *)dst; + struct dst_entry *ret = dst; + + if (rt) { + if (dst->obsolete > 0) { + ip_rt_put(rt); + ret = NULL; + } else if ((rt->rt_flags & RTCF_REDIRECTED) || + rt->dst.expires) { + ip_rt_put(rt); + ret = NULL; + } + } + return ret; +} + +/* + * Algorithm: + * 1. The first ip_rt_redirect_number redirects are sent + * with exponential backoff, then we stop sending them at all, + * assuming that the host ignores our redirects. + * 2. If we did not see packets requiring redirects + * during ip_rt_redirect_silence, we assume that the host + * forgot redirected route and start to send redirects again. + * + * This algorithm is much cheaper and more intelligent than dumb load limiting + * in icmp.c. + * + * NOTE. Do not forget to inhibit load limiting for redirects (redundant) + * and "frag. need" (breaks PMTU discovery) in icmp.c. + */ + +void ip_rt_send_redirect(struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct in_device *in_dev; + struct inet_peer *peer; + struct net *net; + int log_martians; + int vif; + + rcu_read_lock(); + in_dev = __in_dev_get_rcu(rt->dst.dev); + if (!in_dev || !IN_DEV_TX_REDIRECTS(in_dev)) { + rcu_read_unlock(); + return; + } + log_martians = IN_DEV_LOG_MARTIANS(in_dev); + vif = l3mdev_master_ifindex_rcu(rt->dst.dev); + rcu_read_unlock(); + + net = dev_net(rt->dst.dev); + peer = inet_getpeer_v4(net->ipv4.peers, ip_hdr(skb)->saddr, vif, 1); + if (!peer) { + icmp_send(skb, ICMP_REDIRECT, ICMP_REDIR_HOST, + rt_nexthop(rt, ip_hdr(skb)->daddr)); + return; + } + + /* No redirected packets during ip_rt_redirect_silence; + * reset the algorithm. + */ + if (time_after(jiffies, peer->rate_last + ip_rt_redirect_silence)) { + peer->rate_tokens = 0; + peer->n_redirects = 0; + } + + /* Too many ignored redirects; do not send anything + * set dst.rate_last to the last seen redirected packet. + */ + if (peer->n_redirects >= ip_rt_redirect_number) { + peer->rate_last = jiffies; + goto out_put_peer; + } + + /* Check for load limit; set rate_last to the latest sent + * redirect. 
+ */ + if (peer->n_redirects == 0 || + time_after(jiffies, + (peer->rate_last + + (ip_rt_redirect_load << peer->n_redirects)))) { + __be32 gw = rt_nexthop(rt, ip_hdr(skb)->daddr); + + icmp_send(skb, ICMP_REDIRECT, ICMP_REDIR_HOST, gw); + peer->rate_last = jiffies; + ++peer->n_redirects; +#ifdef CONFIG_IP_ROUTE_VERBOSE + if (log_martians && + peer->n_redirects == ip_rt_redirect_number) + net_warn_ratelimited("host %pI4/if%d ignores redirects for %pI4 to %pI4\n", + &ip_hdr(skb)->saddr, inet_iif(skb), + &ip_hdr(skb)->daddr, &gw); +#endif + } +out_put_peer: + inet_putpeer(peer); +} + +static int ip_error(struct sk_buff *skb) +{ + struct rtable *rt = skb_rtable(skb); + struct net_device *dev = skb->dev; + struct in_device *in_dev; + struct inet_peer *peer; + unsigned long now; + struct net *net; + SKB_DR(reason); + bool send; + int code; + + if (netif_is_l3_master(skb->dev)) { + dev = __dev_get_by_index(dev_net(skb->dev), IPCB(skb)->iif); + if (!dev) + goto out; + } + + in_dev = __in_dev_get_rcu(dev); + + /* IP on this device is disabled. */ + if (!in_dev) + goto out; + + net = dev_net(rt->dst.dev); + if (!IN_DEV_FORWARD(in_dev)) { + switch (rt->dst.error) { + case EHOSTUNREACH: + SKB_DR_SET(reason, IP_INADDRERRORS); + __IP_INC_STATS(net, IPSTATS_MIB_INADDRERRORS); + break; + + case ENETUNREACH: + SKB_DR_SET(reason, IP_INNOROUTES); + __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES); + break; + } + goto out; + } + + switch (rt->dst.error) { + case EINVAL: + default: + goto out; + case EHOSTUNREACH: + code = ICMP_HOST_UNREACH; + break; + case ENETUNREACH: + code = ICMP_NET_UNREACH; + SKB_DR_SET(reason, IP_INNOROUTES); + __IP_INC_STATS(net, IPSTATS_MIB_INNOROUTES); + break; + case EACCES: + code = ICMP_PKT_FILTERED; + break; + } + + peer = inet_getpeer_v4(net->ipv4.peers, ip_hdr(skb)->saddr, + l3mdev_master_ifindex(skb->dev), 1); + + send = true; + if (peer) { + now = jiffies; + peer->rate_tokens += now - peer->rate_last; + if (peer->rate_tokens > ip_rt_error_burst) + peer->rate_tokens = ip_rt_error_burst; + peer->rate_last = now; + if (peer->rate_tokens >= ip_rt_error_cost) + peer->rate_tokens -= ip_rt_error_cost; + else + send = false; + inet_putpeer(peer); + } + if (send) + icmp_send(skb, ICMP_DEST_UNREACH, code, 0); + +out: kfree_skb_reason(skb, reason); + return 0; +} + +static void __ip_rt_update_pmtu(struct rtable *rt, struct flowi4 *fl4, u32 mtu) +{ + struct dst_entry *dst = &rt->dst; + struct net *net = dev_net(dst->dev); + struct fib_result res; + bool lock = false; + u32 old_mtu; + + if (ip_mtu_locked(dst)) + return; + + old_mtu = ipv4_mtu(dst); + if (old_mtu < mtu) + return; + + if (mtu < net->ipv4.ip_rt_min_pmtu) { + lock = true; + mtu = min(old_mtu, net->ipv4.ip_rt_min_pmtu); + } + + if (rt->rt_pmtu == mtu && !lock && + time_before(jiffies, dst->expires - net->ipv4.ip_rt_mtu_expires / 2)) + return; + + rcu_read_lock(); + if (fib_lookup(net, fl4, &res, 0) == 0) { + struct fib_nh_common *nhc; + + fib_select_path(net, &res, fl4, NULL); + nhc = FIB_RES_NHC(res); + update_or_create_fnhe(nhc, fl4->daddr, 0, mtu, lock, + jiffies + net->ipv4.ip_rt_mtu_expires); + } + rcu_read_unlock(); +} + +static void ip_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ + struct rtable *rt = (struct rtable *) dst; + struct flowi4 fl4; + + ip_rt_build_flow_key(&fl4, sk, skb); + + /* Don't make lookup fail for bridged encapsulations */ + if (skb && netif_is_any_bridge_port(skb->dev)) + fl4.flowi4_oif = 0; + + __ip_rt_update_pmtu(rt, &fl4, mtu); +} + 
+void ipv4_update_pmtu(struct sk_buff *skb, struct net *net, u32 mtu, + int oif, u8 protocol) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct flowi4 fl4; + struct rtable *rt; + u32 mark = IP4_REPLY_MARK(net, skb->mark); + + __build_flow_key(net, &fl4, NULL, iph, oif, iph->tos, protocol, mark, + 0); + rt = __ip_route_output_key(net, &fl4); + if (!IS_ERR(rt)) { + __ip_rt_update_pmtu(rt, &fl4, mtu); + ip_rt_put(rt); + } +} +EXPORT_SYMBOL_GPL(ipv4_update_pmtu); + +static void __ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct flowi4 fl4; + struct rtable *rt; + + __build_flow_key(sock_net(sk), &fl4, sk, iph, 0, 0, 0, 0, 0); + + if (!fl4.flowi4_mark) + fl4.flowi4_mark = IP4_REPLY_MARK(sock_net(sk), skb->mark); + + rt = __ip_route_output_key(sock_net(sk), &fl4); + if (!IS_ERR(rt)) { + __ip_rt_update_pmtu(rt, &fl4, mtu); + ip_rt_put(rt); + } +} + +void ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct flowi4 fl4; + struct rtable *rt; + struct dst_entry *odst = NULL; + bool new = false; + struct net *net = sock_net(sk); + + bh_lock_sock(sk); + + if (!ip_sk_accept_pmtu(sk)) + goto out; + + odst = sk_dst_get(sk); + + if (sock_owned_by_user(sk) || !odst) { + __ipv4_sk_update_pmtu(skb, sk, mtu); + goto out; + } + + __build_flow_key(net, &fl4, sk, iph, 0, 0, 0, 0, 0); + + rt = (struct rtable *)odst; + if (odst->obsolete && !odst->ops->check(odst, 0)) { + rt = ip_route_output_flow(sock_net(sk), &fl4, sk); + if (IS_ERR(rt)) + goto out; + + new = true; + } + + __ip_rt_update_pmtu((struct rtable *)xfrm_dst_path(&rt->dst), &fl4, mtu); + + if (!dst_check(&rt->dst, 0)) { + if (new) + dst_release(&rt->dst); + + rt = ip_route_output_flow(sock_net(sk), &fl4, sk); + if (IS_ERR(rt)) + goto out; + + new = true; + } + + if (new) + sk_dst_set(sk, &rt->dst); + +out: + bh_unlock_sock(sk); + dst_release(odst); +} +EXPORT_SYMBOL_GPL(ipv4_sk_update_pmtu); + +void ipv4_redirect(struct sk_buff *skb, struct net *net, + int oif, u8 protocol) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct flowi4 fl4; + struct rtable *rt; + + __build_flow_key(net, &fl4, NULL, iph, oif, iph->tos, protocol, 0, 0); + rt = __ip_route_output_key(net, &fl4); + if (!IS_ERR(rt)) { + __ip_do_redirect(rt, skb, &fl4, false); + ip_rt_put(rt); + } +} +EXPORT_SYMBOL_GPL(ipv4_redirect); + +void ipv4_sk_redirect(struct sk_buff *skb, struct sock *sk) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct flowi4 fl4; + struct rtable *rt; + struct net *net = sock_net(sk); + + __build_flow_key(net, &fl4, sk, iph, 0, 0, 0, 0, 0); + rt = __ip_route_output_key(net, &fl4); + if (!IS_ERR(rt)) { + __ip_do_redirect(rt, skb, &fl4, false); + ip_rt_put(rt); + } +} +EXPORT_SYMBOL_GPL(ipv4_sk_redirect); + +INDIRECT_CALLABLE_SCOPE struct dst_entry *ipv4_dst_check(struct dst_entry *dst, + u32 cookie) +{ + struct rtable *rt = (struct rtable *) dst; + + /* All IPV4 dsts are created with ->obsolete set to the value + * DST_OBSOLETE_FORCE_CHK which forces validation calls down + * into this function always. + * + * When a PMTU/redirect information update invalidates a route, + * this is indicated by setting obsolete to DST_OBSOLETE_KILL or + * DST_OBSOLETE_DEAD. 
+ */ + if (dst->obsolete != DST_OBSOLETE_FORCE_CHK || rt_is_expired(rt)) + return NULL; + return dst; +} +EXPORT_INDIRECT_CALLABLE(ipv4_dst_check); + +static void ipv4_send_dest_unreach(struct sk_buff *skb) +{ + struct net_device *dev; + struct ip_options opt; + int res; + + /* Recompile ip options since IPCB may not be valid anymore. + * Also check we have a reasonable ipv4 header. + */ + if (!pskb_network_may_pull(skb, sizeof(struct iphdr)) || + ip_hdr(skb)->version != 4 || ip_hdr(skb)->ihl < 5) + return; + + memset(&opt, 0, sizeof(opt)); + if (ip_hdr(skb)->ihl > 5) { + if (!pskb_network_may_pull(skb, ip_hdr(skb)->ihl * 4)) + return; + opt.optlen = ip_hdr(skb)->ihl * 4 - sizeof(struct iphdr); + + rcu_read_lock(); + dev = skb->dev ? skb->dev : skb_rtable(skb)->dst.dev; + res = __ip_options_compile(dev_net(dev), &opt, skb, NULL); + rcu_read_unlock(); + + if (res) + return; + } + __icmp_send(skb, ICMP_DEST_UNREACH, ICMP_HOST_UNREACH, 0, &opt); +} + +static void ipv4_link_failure(struct sk_buff *skb) +{ + struct rtable *rt; + + ipv4_send_dest_unreach(skb); + + rt = skb_rtable(skb); + if (rt) + dst_set_expires(&rt->dst, 0); +} + +static int ip_rt_bug(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + pr_debug("%s: %pI4 -> %pI4, %s\n", + __func__, &ip_hdr(skb)->saddr, &ip_hdr(skb)->daddr, + skb->dev ? skb->dev->name : "?"); + kfree_skb(skb); + WARN_ON(1); + return 0; +} + +/* + * We do not cache source address of outgoing interface, + * because it is used only by IP RR, TS and SRR options, + * so that it out of fast path. + * + * BTW remember: "addr" is allowed to be not aligned + * in IP options! + */ + +void ip_rt_get_source(u8 *addr, struct sk_buff *skb, struct rtable *rt) +{ + __be32 src; + + if (rt_is_output_route(rt)) + src = ip_hdr(skb)->saddr; + else { + struct fib_result res; + struct iphdr *iph = ip_hdr(skb); + struct flowi4 fl4 = { + .daddr = iph->daddr, + .saddr = iph->saddr, + .flowi4_tos = RT_TOS(iph->tos), + .flowi4_oif = rt->dst.dev->ifindex, + .flowi4_iif = skb->dev->ifindex, + .flowi4_mark = skb->mark, + }; + + rcu_read_lock(); + if (fib_lookup(dev_net(rt->dst.dev), &fl4, &res, 0) == 0) + src = fib_result_prefsrc(dev_net(rt->dst.dev), &res); + else + src = inet_select_addr(rt->dst.dev, + rt_nexthop(rt, iph->daddr), + RT_SCOPE_UNIVERSE); + rcu_read_unlock(); + } + memcpy(addr, &src, 4); +} + +#ifdef CONFIG_IP_ROUTE_CLASSID +static void set_class_tag(struct rtable *rt, u32 tag) +{ + if (!(rt->dst.tclassid & 0xFFFF)) + rt->dst.tclassid |= tag & 0xFFFF; + if (!(rt->dst.tclassid & 0xFFFF0000)) + rt->dst.tclassid |= tag & 0xFFFF0000; +} +#endif + +static unsigned int ipv4_default_advmss(const struct dst_entry *dst) +{ + struct net *net = dev_net(dst->dev); + unsigned int header_size = sizeof(struct tcphdr) + sizeof(struct iphdr); + unsigned int advmss = max_t(unsigned int, ipv4_mtu(dst) - header_size, + net->ipv4.ip_rt_min_advmss); + + return min(advmss, IPV4_MAX_PMTU - header_size); +} + +INDIRECT_CALLABLE_SCOPE unsigned int ipv4_mtu(const struct dst_entry *dst) +{ + return ip_dst_mtu_maybe_forward(dst, false); +} +EXPORT_INDIRECT_CALLABLE(ipv4_mtu); + +static void ip_del_fnhe(struct fib_nh_common *nhc, __be32 daddr) +{ + struct fnhe_hash_bucket *hash; + struct fib_nh_exception *fnhe, __rcu **fnhe_p; + u32 hval = fnhe_hashfun(daddr); + + spin_lock_bh(&fnhe_lock); + + hash = rcu_dereference_protected(nhc->nhc_exceptions, + lockdep_is_held(&fnhe_lock)); + hash += hval; + + fnhe_p = &hash->chain; + fnhe = rcu_dereference_protected(*fnhe_p, lockdep_is_held(&fnhe_lock)); + 
while (fnhe) { + if (fnhe->fnhe_daddr == daddr) { + rcu_assign_pointer(*fnhe_p, rcu_dereference_protected( + fnhe->fnhe_next, lockdep_is_held(&fnhe_lock))); + /* set fnhe_daddr to 0 to ensure it won't bind with + * new dsts in rt_bind_exception(). + */ + fnhe->fnhe_daddr = 0; + fnhe_flush_routes(fnhe); + kfree_rcu(fnhe, rcu); + break; + } + fnhe_p = &fnhe->fnhe_next; + fnhe = rcu_dereference_protected(fnhe->fnhe_next, + lockdep_is_held(&fnhe_lock)); + } + + spin_unlock_bh(&fnhe_lock); +} + +static struct fib_nh_exception *find_exception(struct fib_nh_common *nhc, + __be32 daddr) +{ + struct fnhe_hash_bucket *hash = rcu_dereference(nhc->nhc_exceptions); + struct fib_nh_exception *fnhe; + u32 hval; + + if (!hash) + return NULL; + + hval = fnhe_hashfun(daddr); + + for (fnhe = rcu_dereference(hash[hval].chain); fnhe; + fnhe = rcu_dereference(fnhe->fnhe_next)) { + if (fnhe->fnhe_daddr == daddr) { + if (fnhe->fnhe_expires && + time_after(jiffies, fnhe->fnhe_expires)) { + ip_del_fnhe(nhc, daddr); + break; + } + return fnhe; + } + } + return NULL; +} + +/* MTU selection: + * 1. mtu on route is locked - use it + * 2. mtu from nexthop exception + * 3. mtu from egress device + */ + +u32 ip_mtu_from_fib_result(struct fib_result *res, __be32 daddr) +{ + struct fib_nh_common *nhc = res->nhc; + struct net_device *dev = nhc->nhc_dev; + struct fib_info *fi = res->fi; + u32 mtu = 0; + + if (READ_ONCE(dev_net(dev)->ipv4.sysctl_ip_fwd_use_pmtu) || + fi->fib_metrics->metrics[RTAX_LOCK - 1] & (1 << RTAX_MTU)) + mtu = fi->fib_mtu; + + if (likely(!mtu)) { + struct fib_nh_exception *fnhe; + + fnhe = find_exception(nhc, daddr); + if (fnhe && !time_after_eq(jiffies, fnhe->fnhe_expires)) + mtu = fnhe->fnhe_pmtu; + } + + if (likely(!mtu)) + mtu = min(READ_ONCE(dev->mtu), IP_MAX_MTU); + + return mtu - lwtunnel_headroom(nhc->nhc_lwtstate, mtu); +} + +static bool rt_bind_exception(struct rtable *rt, struct fib_nh_exception *fnhe, + __be32 daddr, const bool do_cache) +{ + bool ret = false; + + spin_lock_bh(&fnhe_lock); + + if (daddr == fnhe->fnhe_daddr) { + struct rtable __rcu **porig; + struct rtable *orig; + int genid = fnhe_genid(dev_net(rt->dst.dev)); + + if (rt_is_input_route(rt)) + porig = &fnhe->fnhe_rth_input; + else + porig = &fnhe->fnhe_rth_output; + orig = rcu_dereference(*porig); + + if (fnhe->fnhe_genid != genid) { + fnhe->fnhe_genid = genid; + fnhe->fnhe_gw = 0; + fnhe->fnhe_pmtu = 0; + fnhe->fnhe_expires = 0; + fnhe->fnhe_mtu_locked = false; + fnhe_flush_routes(fnhe); + orig = NULL; + } + fill_route_from_fnhe(rt, fnhe); + if (!rt->rt_gw4) { + rt->rt_gw4 = daddr; + rt->rt_gw_family = AF_INET; + } + + if (do_cache) { + dst_hold(&rt->dst); + rcu_assign_pointer(*porig, rt); + if (orig) { + dst_dev_put(&orig->dst); + dst_release(&orig->dst); + } + ret = true; + } + + fnhe->fnhe_stamp = jiffies; + } + spin_unlock_bh(&fnhe_lock); + + return ret; +} + +static bool rt_cache_route(struct fib_nh_common *nhc, struct rtable *rt) +{ + struct rtable *orig, *prev, **p; + bool ret = true; + + if (rt_is_input_route(rt)) { + p = (struct rtable **)&nhc->nhc_rth_input; + } else { + p = (struct rtable **)raw_cpu_ptr(nhc->nhc_pcpu_rth_output); + } + orig = *p; + + /* hold dst before doing cmpxchg() to avoid race condition + * on this dst + */ + dst_hold(&rt->dst); + prev = cmpxchg(p, orig, rt); + if (prev == orig) { + if (orig) { + rt_add_uncached_list(orig); + dst_release(&orig->dst); + } + } else { + dst_release(&rt->dst); + ret = false; + } + + return ret; +} + +struct uncached_list { + spinlock_t lock; + struct list_head 
head; + struct list_head quarantine; +}; + +static DEFINE_PER_CPU_ALIGNED(struct uncached_list, rt_uncached_list); + +void rt_add_uncached_list(struct rtable *rt) +{ + struct uncached_list *ul = raw_cpu_ptr(&rt_uncached_list); + + rt->rt_uncached_list = ul; + + spin_lock_bh(&ul->lock); + list_add_tail(&rt->rt_uncached, &ul->head); + spin_unlock_bh(&ul->lock); +} + +void rt_del_uncached_list(struct rtable *rt) +{ + if (!list_empty(&rt->rt_uncached)) { + struct uncached_list *ul = rt->rt_uncached_list; + + spin_lock_bh(&ul->lock); + list_del_init(&rt->rt_uncached); + spin_unlock_bh(&ul->lock); + } +} + +static void ipv4_dst_destroy(struct dst_entry *dst) +{ + struct rtable *rt = (struct rtable *)dst; + + ip_dst_metrics_put(dst); + rt_del_uncached_list(rt); +} + +void rt_flush_dev(struct net_device *dev) +{ + struct rtable *rt, *safe; + int cpu; + + for_each_possible_cpu(cpu) { + struct uncached_list *ul = &per_cpu(rt_uncached_list, cpu); + + if (list_empty(&ul->head)) + continue; + + spin_lock_bh(&ul->lock); + list_for_each_entry_safe(rt, safe, &ul->head, rt_uncached) { + if (rt->dst.dev != dev) + continue; + rt->dst.dev = blackhole_netdev; + netdev_ref_replace(dev, blackhole_netdev, + &rt->dst.dev_tracker, GFP_ATOMIC); + list_move(&rt->rt_uncached, &ul->quarantine); + } + spin_unlock_bh(&ul->lock); + } +} + +static bool rt_cache_valid(const struct rtable *rt) +{ + return rt && + rt->dst.obsolete == DST_OBSOLETE_FORCE_CHK && + !rt_is_expired(rt); +} + +static void rt_set_nexthop(struct rtable *rt, __be32 daddr, + const struct fib_result *res, + struct fib_nh_exception *fnhe, + struct fib_info *fi, u16 type, u32 itag, + const bool do_cache) +{ + bool cached = false; + + if (fi) { + struct fib_nh_common *nhc = FIB_RES_NHC(*res); + + if (nhc->nhc_gw_family && nhc->nhc_scope == RT_SCOPE_LINK) { + rt->rt_uses_gateway = 1; + rt->rt_gw_family = nhc->nhc_gw_family; + /* only INET and INET6 are supported */ + if (likely(nhc->nhc_gw_family == AF_INET)) + rt->rt_gw4 = nhc->nhc_gw.ipv4; + else + rt->rt_gw6 = nhc->nhc_gw.ipv6; + } + + ip_dst_init_metrics(&rt->dst, fi->fib_metrics); + +#ifdef CONFIG_IP_ROUTE_CLASSID + if (nhc->nhc_family == AF_INET) { + struct fib_nh *nh; + + nh = container_of(nhc, struct fib_nh, nh_common); + rt->dst.tclassid = nh->nh_tclassid; + } +#endif + rt->dst.lwtstate = lwtstate_get(nhc->nhc_lwtstate); + if (unlikely(fnhe)) + cached = rt_bind_exception(rt, fnhe, daddr, do_cache); + else if (do_cache) + cached = rt_cache_route(nhc, rt); + if (unlikely(!cached)) { + /* Routes we intend to cache in nexthop exception or + * FIB nexthop have the DST_NOCACHE bit clear. + * However, if we are unsuccessful at storing this + * route into the cache we really need to set it. + */ + if (!rt->rt_gw4) { + rt->rt_gw_family = AF_INET; + rt->rt_gw4 = daddr; + } + rt_add_uncached_list(rt); + } + } else + rt_add_uncached_list(rt); + +#ifdef CONFIG_IP_ROUTE_CLASSID +#ifdef CONFIG_IP_MULTIPLE_TABLES + set_class_tag(rt, res->tclassid); +#endif + set_class_tag(rt, itag); +#endif +} + +struct rtable *rt_dst_alloc(struct net_device *dev, + unsigned int flags, u16 type, + bool noxfrm) +{ + struct rtable *rt; + + rt = dst_alloc(&ipv4_dst_ops, dev, 1, DST_OBSOLETE_FORCE_CHK, + (noxfrm ? 
DST_NOXFRM : 0)); + + if (rt) { + rt->rt_genid = rt_genid_ipv4(dev_net(dev)); + rt->rt_flags = flags; + rt->rt_type = type; + rt->rt_is_input = 0; + rt->rt_iif = 0; + rt->rt_pmtu = 0; + rt->rt_mtu_locked = 0; + rt->rt_uses_gateway = 0; + rt->rt_gw_family = 0; + rt->rt_gw4 = 0; + INIT_LIST_HEAD(&rt->rt_uncached); + + rt->dst.output = ip_output; + if (flags & RTCF_LOCAL) + rt->dst.input = ip_local_deliver; + } + + return rt; +} +EXPORT_SYMBOL(rt_dst_alloc); + +struct rtable *rt_dst_clone(struct net_device *dev, struct rtable *rt) +{ + struct rtable *new_rt; + + new_rt = dst_alloc(&ipv4_dst_ops, dev, 1, DST_OBSOLETE_FORCE_CHK, + rt->dst.flags); + + if (new_rt) { + new_rt->rt_genid = rt_genid_ipv4(dev_net(dev)); + new_rt->rt_flags = rt->rt_flags; + new_rt->rt_type = rt->rt_type; + new_rt->rt_is_input = rt->rt_is_input; + new_rt->rt_iif = rt->rt_iif; + new_rt->rt_pmtu = rt->rt_pmtu; + new_rt->rt_mtu_locked = rt->rt_mtu_locked; + new_rt->rt_gw_family = rt->rt_gw_family; + if (rt->rt_gw_family == AF_INET) + new_rt->rt_gw4 = rt->rt_gw4; + else if (rt->rt_gw_family == AF_INET6) + new_rt->rt_gw6 = rt->rt_gw6; + INIT_LIST_HEAD(&new_rt->rt_uncached); + + new_rt->dst.input = rt->dst.input; + new_rt->dst.output = rt->dst.output; + new_rt->dst.error = rt->dst.error; + new_rt->dst.lastuse = jiffies; + new_rt->dst.lwtstate = lwtstate_get(rt->dst.lwtstate); + } + return new_rt; +} +EXPORT_SYMBOL(rt_dst_clone); + +/* called in rcu_read_lock() section */ +int ip_mc_validate_source(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev, + struct in_device *in_dev, u32 *itag) +{ + int err; + + /* Primary sanity checks. */ + if (!in_dev) + return -EINVAL; + + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr) || + skb->protocol != htons(ETH_P_IP)) + return -EINVAL; + + if (ipv4_is_loopback(saddr) && !IN_DEV_ROUTE_LOCALNET(in_dev)) + return -EINVAL; + + if (ipv4_is_zeronet(saddr)) { + if (!ipv4_is_local_multicast(daddr) && + ip_hdr(skb)->protocol != IPPROTO_IGMP) + return -EINVAL; + } else { + err = fib_validate_source(skb, saddr, 0, tos, 0, dev, + in_dev, itag); + if (err < 0) + return err; + } + return 0; +} + +/* called in rcu_read_lock() section */ +static int ip_route_input_mc(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev, int our) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + unsigned int flags = RTCF_MULTICAST; + struct rtable *rth; + u32 itag = 0; + int err; + + err = ip_mc_validate_source(skb, daddr, saddr, tos, dev, in_dev, &itag); + if (err) + return err; + + if (our) + flags |= RTCF_LOCAL; + + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + + rth = rt_dst_alloc(dev_net(dev)->loopback_dev, flags, RTN_MULTICAST, + false); + if (!rth) + return -ENOBUFS; + +#ifdef CONFIG_IP_ROUTE_CLASSID + rth->dst.tclassid = itag; +#endif + rth->dst.output = ip_rt_bug; + rth->rt_is_input= 1; + +#ifdef CONFIG_IP_MROUTE + if (!ipv4_is_local_multicast(daddr) && IN_DEV_MFORWARD(in_dev)) + rth->dst.input = ip_mr_input; +#endif + RT_CACHE_STAT_INC(in_slow_mc); + + skb_dst_drop(skb); + skb_dst_set(skb, &rth->dst); + return 0; +} + + +static void ip_handle_martian_source(struct net_device *dev, + struct in_device *in_dev, + struct sk_buff *skb, + __be32 daddr, + __be32 saddr) +{ + RT_CACHE_STAT_INC(in_martian_src); +#ifdef CONFIG_IP_ROUTE_VERBOSE + if (IN_DEV_LOG_MARTIANS(in_dev) && net_ratelimit()) { + /* + * RFC1812 recommendation, if source is martian, + * the only hint is MAC header. 
+ */ + pr_warn("martian source %pI4 from %pI4, on dev %s\n", + &daddr, &saddr, dev->name); + if (dev->hard_header_len && skb_mac_header_was_set(skb)) { + print_hex_dump(KERN_WARNING, "ll header: ", + DUMP_PREFIX_OFFSET, 16, 1, + skb_mac_header(skb), + dev->hard_header_len, false); + } + } +#endif +} + +/* called in rcu_read_lock() section */ +static int __mkroute_input(struct sk_buff *skb, + const struct fib_result *res, + struct in_device *in_dev, + __be32 daddr, __be32 saddr, u32 tos) +{ + struct fib_nh_common *nhc = FIB_RES_NHC(*res); + struct net_device *dev = nhc->nhc_dev; + struct fib_nh_exception *fnhe; + struct rtable *rth; + int err; + struct in_device *out_dev; + bool do_cache; + u32 itag = 0; + + /* get a working reference to the output device */ + out_dev = __in_dev_get_rcu(dev); + if (!out_dev) { + net_crit_ratelimited("Bug in ip_route_input_slow(). Please report.\n"); + return -EINVAL; + } + + err = fib_validate_source(skb, saddr, daddr, tos, FIB_RES_OIF(*res), + in_dev->dev, in_dev, &itag); + if (err < 0) { + ip_handle_martian_source(in_dev->dev, in_dev, skb, daddr, + saddr); + + goto cleanup; + } + + do_cache = res->fi && !itag; + if (out_dev == in_dev && err && IN_DEV_TX_REDIRECTS(out_dev) && + skb->protocol == htons(ETH_P_IP)) { + __be32 gw; + + gw = nhc->nhc_gw_family == AF_INET ? nhc->nhc_gw.ipv4 : 0; + if (IN_DEV_SHARED_MEDIA(out_dev) || + inet_addr_onlink(out_dev, saddr, gw)) + IPCB(skb)->flags |= IPSKB_DOREDIRECT; + } + + if (skb->protocol != htons(ETH_P_IP)) { + /* Not IP (i.e. ARP). Do not create route, if it is + * invalid for proxy arp. DNAT routes are always valid. + * + * Proxy arp feature have been extended to allow, ARP + * replies back to the same interface, to support + * Private VLAN switch technologies. See arp.c. + */ + if (out_dev == in_dev && + IN_DEV_PROXY_ARP_PVLAN(in_dev) == 0) { + err = -EINVAL; + goto cleanup; + } + } + + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + + fnhe = find_exception(nhc, daddr); + if (do_cache) { + if (fnhe) + rth = rcu_dereference(fnhe->fnhe_rth_input); + else + rth = rcu_dereference(nhc->nhc_rth_input); + if (rt_cache_valid(rth)) { + skb_dst_set_noref(skb, &rth->dst); + goto out; + } + } + + rth = rt_dst_alloc(out_dev->dev, 0, res->type, + IN_DEV_ORCONF(out_dev, NOXFRM)); + if (!rth) { + err = -ENOBUFS; + goto cleanup; + } + + rth->rt_is_input = 1; + RT_CACHE_STAT_INC(in_slow_tot); + + rth->dst.input = ip_forward; + + rt_set_nexthop(rth, daddr, res, fnhe, res->fi, res->type, itag, + do_cache); + lwtunnel_set_redirect(&rth->dst); + skb_dst_set(skb, &rth->dst); +out: + err = 0; + cleanup: + return err; +} + +#ifdef CONFIG_IP_ROUTE_MULTIPATH +/* To make ICMP packets follow the right flow, the multipath hash is + * calculated from the inner IP addresses. 
+ */ +static void ip_multipath_l3_keys(const struct sk_buff *skb, + struct flow_keys *hash_keys) +{ + const struct iphdr *outer_iph = ip_hdr(skb); + const struct iphdr *key_iph = outer_iph; + const struct iphdr *inner_iph; + const struct icmphdr *icmph; + struct iphdr _inner_iph; + struct icmphdr _icmph; + + if (likely(outer_iph->protocol != IPPROTO_ICMP)) + goto out; + + if (unlikely((outer_iph->frag_off & htons(IP_OFFSET)) != 0)) + goto out; + + icmph = skb_header_pointer(skb, outer_iph->ihl * 4, sizeof(_icmph), + &_icmph); + if (!icmph) + goto out; + + if (!icmp_is_err(icmph->type)) + goto out; + + inner_iph = skb_header_pointer(skb, + outer_iph->ihl * 4 + sizeof(_icmph), + sizeof(_inner_iph), &_inner_iph); + if (!inner_iph) + goto out; + + key_iph = inner_iph; +out: + hash_keys->addrs.v4addrs.src = key_iph->saddr; + hash_keys->addrs.v4addrs.dst = key_iph->daddr; +} + +static u32 fib_multipath_custom_hash_outer(const struct net *net, + const struct sk_buff *skb, + bool *p_has_inner) +{ + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); + struct flow_keys keys, hash_keys; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + skb_flow_dissect_flow_keys(skb, &keys, FLOW_DISSECTOR_F_STOP_AT_ENCAP); + + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_IP) + hash_keys.addrs.v4addrs.src = keys.addrs.v4addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_IP) + hash_keys.addrs.v4addrs.dst = keys.addrs.v4addrs.dst; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_IP_PROTO) + hash_keys.basic.ip_proto = keys.basic.ip_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) + hash_keys.ports.src = keys.ports.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_PORT) + hash_keys.ports.dst = keys.ports.dst; + + *p_has_inner = !!(keys.control.flags & FLOW_DIS_ENCAPSULATION); + return flow_hash_from_keys(&hash_keys); +} + +static u32 fib_multipath_custom_hash_inner(const struct net *net, + const struct sk_buff *skb, + bool has_inner) +{ + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); + struct flow_keys keys, hash_keys; + + /* We assume the packet carries an encapsulation, but if none was + * encountered during dissection of the outer flow, then there is no + * point in calling the flow dissector again. 
+ */ + if (!has_inner) + return 0; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + skb_flow_dissect_flow_keys(skb, &keys, 0); + + if (!(keys.control.flags & FLOW_DIS_ENCAPSULATION)) + return 0; + + if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_IP) + hash_keys.addrs.v4addrs.src = keys.addrs.v4addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_IP) + hash_keys.addrs.v4addrs.dst = keys.addrs.v4addrs.dst; + } else if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_IP) + hash_keys.addrs.v6addrs.src = keys.addrs.v6addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_IP) + hash_keys.addrs.v6addrs.dst = keys.addrs.v6addrs.dst; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_FLOWLABEL) + hash_keys.tags.flow_label = keys.tags.flow_label; + } + + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_IP_PROTO) + hash_keys.basic.ip_proto = keys.basic.ip_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_PORT) + hash_keys.ports.src = keys.ports.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_PORT) + hash_keys.ports.dst = keys.ports.dst; + + return flow_hash_from_keys(&hash_keys); +} + +static u32 fib_multipath_custom_hash_skb(const struct net *net, + const struct sk_buff *skb) +{ + u32 mhash, mhash_inner; + bool has_inner = true; + + mhash = fib_multipath_custom_hash_outer(net, skb, &has_inner); + mhash_inner = fib_multipath_custom_hash_inner(net, skb, has_inner); + + return jhash_2words(mhash, mhash_inner, 0); +} + +static u32 fib_multipath_custom_hash_fl4(const struct net *net, + const struct flowi4 *fl4) +{ + u32 hash_fields = READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_fields); + struct flow_keys hash_keys; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_IP) + hash_keys.addrs.v4addrs.src = fl4->saddr; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_IP) + hash_keys.addrs.v4addrs.dst = fl4->daddr; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_IP_PROTO) + hash_keys.basic.ip_proto = fl4->flowi4_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) + hash_keys.ports.src = fl4->fl4_sport; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_PORT) + hash_keys.ports.dst = fl4->fl4_dport; + + return flow_hash_from_keys(&hash_keys); +} + +/* if skb is set it will be used and fl4 can be NULL */ +int fib_multipath_hash(const struct net *net, const struct flowi4 *fl4, + const struct sk_buff *skb, struct flow_keys *flkeys) +{ + u32 multipath_hash = fl4 ? 
fl4->flowi4_multipath_hash : 0; + struct flow_keys hash_keys; + u32 mhash = 0; + + switch (READ_ONCE(net->ipv4.sysctl_fib_multipath_hash_policy)) { + case 0: + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + if (skb) { + ip_multipath_l3_keys(skb, &hash_keys); + } else { + hash_keys.addrs.v4addrs.src = fl4->saddr; + hash_keys.addrs.v4addrs.dst = fl4->daddr; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 1: + /* skb is currently provided only when forwarding */ + if (skb) { + unsigned int flag = FLOW_DISSECTOR_F_STOP_AT_ENCAP; + struct flow_keys keys; + + /* short-circuit if we already have L4 hash present */ + if (skb->l4_hash) + return skb_get_hash_raw(skb) >> 1; + + memset(&hash_keys, 0, sizeof(hash_keys)); + + if (!flkeys) { + skb_flow_dissect_flow_keys(skb, &keys, flag); + flkeys = &keys; + } + + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + hash_keys.addrs.v4addrs.src = flkeys->addrs.v4addrs.src; + hash_keys.addrs.v4addrs.dst = flkeys->addrs.v4addrs.dst; + hash_keys.ports.src = flkeys->ports.src; + hash_keys.ports.dst = flkeys->ports.dst; + hash_keys.basic.ip_proto = flkeys->basic.ip_proto; + } else { + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + hash_keys.addrs.v4addrs.src = fl4->saddr; + hash_keys.addrs.v4addrs.dst = fl4->daddr; + hash_keys.ports.src = fl4->fl4_sport; + hash_keys.ports.dst = fl4->fl4_dport; + hash_keys.basic.ip_proto = fl4->flowi4_proto; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 2: + memset(&hash_keys, 0, sizeof(hash_keys)); + /* skb is currently provided only when forwarding */ + if (skb) { + struct flow_keys keys; + + skb_flow_dissect_flow_keys(skb, &keys, 0); + /* Inner can be v4 or v6 */ + if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + hash_keys.addrs.v4addrs.src = keys.addrs.v4addrs.src; + hash_keys.addrs.v4addrs.dst = keys.addrs.v4addrs.dst; + } else if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + hash_keys.addrs.v6addrs.src = keys.addrs.v6addrs.src; + hash_keys.addrs.v6addrs.dst = keys.addrs.v6addrs.dst; + hash_keys.tags.flow_label = keys.tags.flow_label; + hash_keys.basic.ip_proto = keys.basic.ip_proto; + } else { + /* Same as case 0 */ + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + ip_multipath_l3_keys(skb, &hash_keys); + } + } else { + /* Same as case 0 */ + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + hash_keys.addrs.v4addrs.src = fl4->saddr; + hash_keys.addrs.v4addrs.dst = fl4->daddr; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 3: + if (skb) + mhash = fib_multipath_custom_hash_skb(net, skb); + else + mhash = fib_multipath_custom_hash_fl4(net, fl4); + break; + } + + if (multipath_hash) + mhash = jhash_2words(mhash, multipath_hash, 0); + + return mhash >> 1; +} +#endif /* CONFIG_IP_ROUTE_MULTIPATH */ + +static int ip_mkroute_input(struct sk_buff *skb, + struct fib_result *res, + struct in_device *in_dev, + __be32 daddr, __be32 saddr, u32 tos, + struct flow_keys *hkeys) +{ +#ifdef CONFIG_IP_ROUTE_MULTIPATH + if (res->fi && fib_info_num_path(res->fi) > 1) { + int h = fib_multipath_hash(res->fi->fib_net, NULL, skb, hkeys); + + fib_select_multipath(res, h); + IPCB(skb)->flags |= IPSKB_MULTIPATH; + } +#endif + + /* create a routing cache entry */ + return __mkroute_input(skb, 
res, in_dev, daddr, saddr, tos); +} + +/* Implements all the saddr-related checks as ip_route_input_slow(), + * assuming daddr is valid and the destination is not a local broadcast one. + * Uses the provided hint instead of performing a route lookup. + */ +int ip_route_use_hint(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev, + const struct sk_buff *hint) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + struct rtable *rt = skb_rtable(hint); + struct net *net = dev_net(dev); + int err = -EINVAL; + u32 tag = 0; + + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) + goto martian_source; + + if (ipv4_is_zeronet(saddr)) + goto martian_source; + + if (ipv4_is_loopback(saddr) && !IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + goto martian_source; + + if (rt->rt_type != RTN_LOCAL) + goto skip_validate_source; + + tos &= IPTOS_RT_MASK; + err = fib_validate_source(skb, saddr, daddr, tos, 0, dev, in_dev, &tag); + if (err < 0) + goto martian_source; + +skip_validate_source: + skb_dst_copy(skb, hint); + return 0; + +martian_source: + ip_handle_martian_source(dev, in_dev, skb, daddr, saddr); + return err; +} + +/* get device for dst_alloc with local routes */ +static struct net_device *ip_rt_get_dev(struct net *net, + const struct fib_result *res) +{ + struct fib_nh_common *nhc = res->fi ? res->nhc : NULL; + struct net_device *dev = NULL; + + if (nhc) + dev = l3mdev_master_dev_rcu(nhc->nhc_dev); + + return dev ? : net->loopback_dev; +} + +/* + * NOTE. We drop all the packets that has local source + * addresses, because every properly looped back packet + * must have correct destination already attached by output routine. + * Changes in the enforced policies must be applied also to + * ip_route_use_hint(). + * + * Such approach solves two big problems: + * 1. Not simplex devices are handled properly. + * 2. IP spoofing attempts are filtered with 100% of guarantee. + * called with rcu_read_lock() + */ + +static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev, + struct fib_result *res) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + struct flow_keys *flkeys = NULL, _flkeys; + struct net *net = dev_net(dev); + struct ip_tunnel_info *tun_info; + int err = -EINVAL; + unsigned int flags = 0; + u32 itag = 0; + struct rtable *rth; + struct flowi4 fl4; + bool do_cache = true; + + /* IP on this device is disabled. */ + + if (!in_dev) + goto out; + + /* Check for the most weird martians, which can be not detected + * by fib_lookup. + */ + + tun_info = skb_tunnel_info(skb); + if (tun_info && !(tun_info->mode & IP_TUNNEL_INFO_TX)) + fl4.flowi4_tun_key.tun_id = tun_info->key.tun_id; + else + fl4.flowi4_tun_key.tun_id = 0; + skb_dst_drop(skb); + + if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr)) + goto martian_source; + + res->fi = NULL; + res->table = NULL; + if (ipv4_is_lbcast(daddr) || (saddr == 0 && daddr == 0)) + goto brd_input; + + /* Accept zero addresses only to limited broadcast; + * I even do not know to fix it or not. 
Waiting for complains :-) + */ + if (ipv4_is_zeronet(saddr)) + goto martian_source; + + if (ipv4_is_zeronet(daddr)) + goto martian_destination; + + /* Following code try to avoid calling IN_DEV_NET_ROUTE_LOCALNET(), + * and call it once if daddr or/and saddr are loopback addresses + */ + if (ipv4_is_loopback(daddr)) { + if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + goto martian_destination; + } else if (ipv4_is_loopback(saddr)) { + if (!IN_DEV_NET_ROUTE_LOCALNET(in_dev, net)) + goto martian_source; + } + + /* + * Now we are ready to route packet. + */ + fl4.flowi4_l3mdev = 0; + fl4.flowi4_oif = 0; + fl4.flowi4_iif = dev->ifindex; + fl4.flowi4_mark = skb->mark; + fl4.flowi4_tos = tos; + fl4.flowi4_scope = RT_SCOPE_UNIVERSE; + fl4.flowi4_flags = 0; + fl4.daddr = daddr; + fl4.saddr = saddr; + fl4.flowi4_uid = sock_net_uid(net, NULL); + fl4.flowi4_multipath_hash = 0; + + if (fib4_rules_early_flow_dissect(net, skb, &fl4, &_flkeys)) { + flkeys = &_flkeys; + } else { + fl4.flowi4_proto = 0; + fl4.fl4_sport = 0; + fl4.fl4_dport = 0; + } + + err = fib_lookup(net, &fl4, res, 0); + if (err != 0) { + if (!IN_DEV_FORWARD(in_dev)) + err = -EHOSTUNREACH; + goto no_route; + } + + if (res->type == RTN_BROADCAST) { + if (IN_DEV_BFORWARD(in_dev)) + goto make_route; + /* not do cache if bc_forwarding is enabled */ + if (IPV4_DEVCONF_ALL(net, BC_FORWARDING)) + do_cache = false; + goto brd_input; + } + + if (res->type == RTN_LOCAL) { + err = fib_validate_source(skb, saddr, daddr, tos, + 0, dev, in_dev, &itag); + if (err < 0) + goto martian_source; + goto local_input; + } + + if (!IN_DEV_FORWARD(in_dev)) { + err = -EHOSTUNREACH; + goto no_route; + } + if (res->type != RTN_UNICAST) + goto martian_destination; + +make_route: + err = ip_mkroute_input(skb, res, in_dev, daddr, saddr, tos, flkeys); +out: return err; + +brd_input: + if (skb->protocol != htons(ETH_P_IP)) + goto e_inval; + + if (!ipv4_is_zeronet(saddr)) { + err = fib_validate_source(skb, saddr, 0, tos, 0, dev, + in_dev, &itag); + if (err < 0) + goto martian_source; + } + flags |= RTCF_BROADCAST; + res->type = RTN_BROADCAST; + RT_CACHE_STAT_INC(in_brd); + +local_input: + if (IN_DEV_ORCONF(in_dev, NOPOLICY)) + IPCB(skb)->flags |= IPSKB_NOPOLICY; + + do_cache &= res->fi && !itag; + if (do_cache) { + struct fib_nh_common *nhc = FIB_RES_NHC(*res); + + rth = rcu_dereference(nhc->nhc_rth_input); + if (rt_cache_valid(rth)) { + skb_dst_set_noref(skb, &rth->dst); + err = 0; + goto out; + } + } + + rth = rt_dst_alloc(ip_rt_get_dev(net, res), + flags | RTCF_LOCAL, res->type, false); + if (!rth) + goto e_nobufs; + + rth->dst.output= ip_rt_bug; +#ifdef CONFIG_IP_ROUTE_CLASSID + rth->dst.tclassid = itag; +#endif + rth->rt_is_input = 1; + + RT_CACHE_STAT_INC(in_slow_tot); + if (res->type == RTN_UNREACHABLE) { + rth->dst.input= ip_error; + rth->dst.error= -err; + rth->rt_flags &= ~RTCF_LOCAL; + } + + if (do_cache) { + struct fib_nh_common *nhc = FIB_RES_NHC(*res); + + rth->dst.lwtstate = lwtstate_get(nhc->nhc_lwtstate); + if (lwtunnel_input_redirect(rth->dst.lwtstate)) { + WARN_ON(rth->dst.input == lwtunnel_input); + rth->dst.lwtstate->orig_input = rth->dst.input; + rth->dst.input = lwtunnel_input; + } + + if (unlikely(!rt_cache_route(nhc, rth))) + rt_add_uncached_list(rth); + } + skb_dst_set(skb, &rth->dst); + err = 0; + goto out; + +no_route: + RT_CACHE_STAT_INC(in_no_route); + res->type = RTN_UNREACHABLE; + res->fi = NULL; + res->table = NULL; + goto local_input; + + /* + * Do not cache martian addresses: they should be logged (RFC1812) + */ +martian_destination: + 
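/* daddr was in 0.0.0.0/8, a loopback address without route_localnet, + * or the FIB returned a route type we are not allowed to forward to. */ + 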
RT_CACHE_STAT_INC(in_martian_dst); +#ifdef CONFIG_IP_ROUTE_VERBOSE + if (IN_DEV_LOG_MARTIANS(in_dev)) + net_warn_ratelimited("martian destination %pI4 from %pI4, dev %s\n", + &daddr, &saddr, dev->name); +#endif + +e_inval: + err = -EINVAL; + goto out; + +e_nobufs: + err = -ENOBUFS; + goto out; + +martian_source: + ip_handle_martian_source(dev, in_dev, skb, daddr, saddr); + goto out; +} + +/* called with rcu_read_lock held */ +static int ip_route_input_rcu(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev, struct fib_result *res) +{ + /* Multicast recognition logic is moved from route cache to here. + * The problem was that too many Ethernet cards have broken/missing + * hardware multicast filters :-( As result the host on multicasting + * network acquires a lot of useless route cache entries, sort of + * SDR messages from all the world. Now we try to get rid of them. + * Really, provided software IP multicast filter is organized + * reasonably (at least, hashed), it does not result in a slowdown + * comparing with route cache reject entries. + * Note, that multicast routers are not affected, because + * route cache entry is created eventually. + */ + if (ipv4_is_multicast(daddr)) { + struct in_device *in_dev = __in_dev_get_rcu(dev); + int our = 0; + int err = -EINVAL; + + if (!in_dev) + return err; + our = ip_check_mc_rcu(in_dev, daddr, saddr, + ip_hdr(skb)->protocol); + + /* check l3 master if no match yet */ + if (!our && netif_is_l3_slave(dev)) { + struct in_device *l3_in_dev; + + l3_in_dev = __in_dev_get_rcu(skb->dev); + if (l3_in_dev) + our = ip_check_mc_rcu(l3_in_dev, daddr, saddr, + ip_hdr(skb)->protocol); + } + + if (our +#ifdef CONFIG_IP_MROUTE + || + (!ipv4_is_local_multicast(daddr) && + IN_DEV_MFORWARD(in_dev)) +#endif + ) { + err = ip_route_input_mc(skb, daddr, saddr, + tos, dev, our); + } + return err; + } + + return ip_route_input_slow(skb, daddr, saddr, tos, dev, res); +} + +int ip_route_input_noref(struct sk_buff *skb, __be32 daddr, __be32 saddr, + u8 tos, struct net_device *dev) +{ + struct fib_result res; + int err; + + tos &= IPTOS_RT_MASK; + rcu_read_lock(); + err = ip_route_input_rcu(skb, daddr, saddr, tos, dev, &res); + rcu_read_unlock(); + + return err; +} +EXPORT_SYMBOL(ip_route_input_noref); + +/* called with rcu_read_lock() */ +static struct rtable *__mkroute_output(const struct fib_result *res, + const struct flowi4 *fl4, int orig_oif, + struct net_device *dev_out, + unsigned int flags) +{ + struct fib_info *fi = res->fi; + struct fib_nh_exception *fnhe; + struct in_device *in_dev; + u16 type = res->type; + struct rtable *rth; + bool do_cache; + + in_dev = __in_dev_get_rcu(dev_out); + if (!in_dev) + return ERR_PTR(-EINVAL); + + if (likely(!IN_DEV_ROUTE_LOCALNET(in_dev))) + if (ipv4_is_loopback(fl4->saddr) && + !(dev_out->flags & IFF_LOOPBACK) && + !netif_is_l3_master(dev_out)) + return ERR_PTR(-EINVAL); + + if (ipv4_is_lbcast(fl4->daddr)) + type = RTN_BROADCAST; + else if (ipv4_is_multicast(fl4->daddr)) + type = RTN_MULTICAST; + else if (ipv4_is_zeronet(fl4->daddr)) + return ERR_PTR(-EINVAL); + + if (dev_out->flags & IFF_LOOPBACK) + flags |= RTCF_LOCAL; + + do_cache = true; + if (type == RTN_BROADCAST) { + flags |= RTCF_BROADCAST | RTCF_LOCAL; + fi = NULL; + } else if (type == RTN_MULTICAST) { + flags |= RTCF_MULTICAST | RTCF_LOCAL; + if (!ip_check_mc_rcu(in_dev, fl4->daddr, fl4->saddr, + fl4->flowi4_proto)) + flags &= ~RTCF_LOCAL; + else + do_cache = false; + /* If multicast route do not exist use + * default one, but do not 
gateway in this case. + * Yes, it is hack. + */ + if (fi && res->prefixlen < 4) + fi = NULL; + } else if ((type == RTN_LOCAL) && (orig_oif != 0) && + (orig_oif != dev_out->ifindex)) { + /* For local routes that require a particular output interface + * we do not want to cache the result. Caching the result + * causes incorrect behaviour when there are multiple source + * addresses on the interface, the end result being that if the + * intended recipient is waiting on that interface for the + * packet he won't receive it because it will be delivered on + * the loopback interface and the IP_PKTINFO ipi_ifindex will + * be set to the loopback interface as well. + */ + do_cache = false; + } + + fnhe = NULL; + do_cache &= fi != NULL; + if (fi) { + struct fib_nh_common *nhc = FIB_RES_NHC(*res); + struct rtable __rcu **prth; + + fnhe = find_exception(nhc, fl4->daddr); + if (!do_cache) + goto add; + if (fnhe) { + prth = &fnhe->fnhe_rth_output; + } else { + if (unlikely(fl4->flowi4_flags & + FLOWI_FLAG_KNOWN_NH && + !(nhc->nhc_gw_family && + nhc->nhc_scope == RT_SCOPE_LINK))) { + do_cache = false; + goto add; + } + prth = raw_cpu_ptr(nhc->nhc_pcpu_rth_output); + } + rth = rcu_dereference(*prth); + if (rt_cache_valid(rth) && dst_hold_safe(&rth->dst)) + return rth; + } + +add: + rth = rt_dst_alloc(dev_out, flags, type, + IN_DEV_ORCONF(in_dev, NOXFRM)); + if (!rth) + return ERR_PTR(-ENOBUFS); + + rth->rt_iif = orig_oif; + + RT_CACHE_STAT_INC(out_slow_tot); + + if (flags & (RTCF_BROADCAST | RTCF_MULTICAST)) { + if (flags & RTCF_LOCAL && + !(dev_out->flags & IFF_LOOPBACK)) { + rth->dst.output = ip_mc_output; + RT_CACHE_STAT_INC(out_slow_mc); + } +#ifdef CONFIG_IP_MROUTE + if (type == RTN_MULTICAST) { + if (IN_DEV_MFORWARD(in_dev) && + !ipv4_is_local_multicast(fl4->daddr)) { + rth->dst.input = ip_mr_input; + rth->dst.output = ip_mc_output; + } + } +#endif + } + + rt_set_nexthop(rth, fl4->daddr, res, fnhe, fi, type, 0, do_cache); + lwtunnel_set_redirect(&rth->dst); + + return rth; +} + +/* + * Major route resolver routine. + */ + +struct rtable *ip_route_output_key_hash(struct net *net, struct flowi4 *fl4, + const struct sk_buff *skb) +{ + struct fib_result res = { + .type = RTN_UNSPEC, + .fi = NULL, + .table = NULL, + .tclassid = 0, + }; + struct rtable *rth; + + fl4->flowi4_iif = LOOPBACK_IFINDEX; + ip_rt_fix_tos(fl4); + + rcu_read_lock(); + rth = ip_route_output_key_hash_rcu(net, fl4, &res, skb); + rcu_read_unlock(); + + return rth; +} +EXPORT_SYMBOL_GPL(ip_route_output_key_hash); + +struct rtable *ip_route_output_key_hash_rcu(struct net *net, struct flowi4 *fl4, + struct fib_result *res, + const struct sk_buff *skb) +{ + struct net_device *dev_out = NULL; + int orig_oif = fl4->flowi4_oif; + unsigned int flags = 0; + struct rtable *rth; + int err; + + if (fl4->saddr) { + if (ipv4_is_multicast(fl4->saddr) || + ipv4_is_lbcast(fl4->saddr) || + ipv4_is_zeronet(fl4->saddr)) { + rth = ERR_PTR(-EINVAL); + goto out; + } + + rth = ERR_PTR(-ENETUNREACH); + + /* I removed check for oif == dev_out->oif here. + * It was wrong for two reasons: + * 1. ip_dev_find(net, saddr) can return wrong iface, if saddr + * is assigned to multiple interfaces. + * 2. Moreover, we are allowed to send packets with saddr + * of another iface. 
--ANK + */ + + if (fl4->flowi4_oif == 0 && + (ipv4_is_multicast(fl4->daddr) || + ipv4_is_lbcast(fl4->daddr))) { + /* It is equivalent to inet_addr_type(saddr) == RTN_LOCAL */ + dev_out = __ip_dev_find(net, fl4->saddr, false); + if (!dev_out) + goto out; + + /* Special hack: user can direct multicasts + * and limited broadcast via necessary interface + * without fiddling with IP_MULTICAST_IF or IP_PKTINFO. + * This hack is not just for fun, it allows + * vic,vat and friends to work. + * They bind socket to loopback, set ttl to zero + * and expect that it will work. + * From the viewpoint of routing cache they are broken, + * because we are not allowed to build multicast path + * with loopback source addr (look, routing cache + * cannot know, that ttl is zero, so that packet + * will not leave this host and route is valid). + * Luckily, this hack is good workaround. + */ + + fl4->flowi4_oif = dev_out->ifindex; + goto make_route; + } + + if (!(fl4->flowi4_flags & FLOWI_FLAG_ANYSRC)) { + /* It is equivalent to inet_addr_type(saddr) == RTN_LOCAL */ + if (!__ip_dev_find(net, fl4->saddr, false)) + goto out; + } + } + + + if (fl4->flowi4_oif) { + dev_out = dev_get_by_index_rcu(net, fl4->flowi4_oif); + rth = ERR_PTR(-ENODEV); + if (!dev_out) + goto out; + + /* RACE: Check return value of inet_select_addr instead. */ + if (!(dev_out->flags & IFF_UP) || !__in_dev_get_rcu(dev_out)) { + rth = ERR_PTR(-ENETUNREACH); + goto out; + } + if (ipv4_is_local_multicast(fl4->daddr) || + ipv4_is_lbcast(fl4->daddr) || + fl4->flowi4_proto == IPPROTO_IGMP) { + if (!fl4->saddr) + fl4->saddr = inet_select_addr(dev_out, 0, + RT_SCOPE_LINK); + goto make_route; + } + if (!fl4->saddr) { + if (ipv4_is_multicast(fl4->daddr)) + fl4->saddr = inet_select_addr(dev_out, 0, + fl4->flowi4_scope); + else if (!fl4->daddr) + fl4->saddr = inet_select_addr(dev_out, 0, + RT_SCOPE_HOST); + } + } + + if (!fl4->daddr) { + fl4->daddr = fl4->saddr; + if (!fl4->daddr) + fl4->daddr = fl4->saddr = htonl(INADDR_LOOPBACK); + dev_out = net->loopback_dev; + fl4->flowi4_oif = LOOPBACK_IFINDEX; + res->type = RTN_LOCAL; + flags |= RTCF_LOCAL; + goto make_route; + } + + err = fib_lookup(net, fl4, res, 0); + if (err) { + res->fi = NULL; + res->table = NULL; + if (fl4->flowi4_oif && + (ipv4_is_multicast(fl4->daddr) || !fl4->flowi4_l3mdev)) { + /* Apparently, routing tables are wrong. Assume, + * that the destination is on link. + * + * WHY? DW. + * Because we are allowed to send to iface + * even if it has NO routes and NO assigned + * addresses. When oif is specified, routing + * tables are looked up with only one purpose: + * to catch if destination is gatewayed, rather than + * direct. Moreover, if MSG_DONTROUTE is set, + * we send packet, ignoring both routing tables + * and ifaddr state. --ANK + * + * + * We could make it even if oif is unknown, + * likely IPv6, but we do not. + */ + + if (fl4->saddr == 0) + fl4->saddr = inet_select_addr(dev_out, 0, + RT_SCOPE_LINK); + res->type = RTN_UNICAST; + goto make_route; + } + rth = ERR_PTR(err); + goto out; + } + + if (res->type == RTN_LOCAL) { + if (!fl4->saddr) { + if (res->fi->fib_prefsrc) + fl4->saddr = res->fi->fib_prefsrc; + else + fl4->saddr = fl4->daddr; + } + + /* L3 master device is the loopback for that domain */ + dev_out = l3mdev_master_dev_rcu(FIB_RES_DEV(*res)) ? 
: + net->loopback_dev; + + /* make sure orig_oif points to fib result device even + * though packet rx/tx happens over loopback or l3mdev + */ + orig_oif = FIB_RES_OIF(*res); + + fl4->flowi4_oif = dev_out->ifindex; + flags |= RTCF_LOCAL; + goto make_route; + } + + fib_select_path(net, res, fl4, skb); + + dev_out = FIB_RES_DEV(*res); + +make_route: + rth = __mkroute_output(res, fl4, orig_oif, dev_out, flags); + +out: + return rth; +} + +static struct dst_ops ipv4_dst_blackhole_ops = { + .family = AF_INET, + .default_advmss = ipv4_default_advmss, + .neigh_lookup = ipv4_neigh_lookup, + .check = dst_blackhole_check, + .cow_metrics = dst_blackhole_cow_metrics, + .update_pmtu = dst_blackhole_update_pmtu, + .redirect = dst_blackhole_redirect, + .mtu = dst_blackhole_mtu, +}; + +struct dst_entry *ipv4_blackhole_route(struct net *net, struct dst_entry *dst_orig) +{ + struct rtable *ort = (struct rtable *) dst_orig; + struct rtable *rt; + + rt = dst_alloc(&ipv4_dst_blackhole_ops, NULL, 1, DST_OBSOLETE_DEAD, 0); + if (rt) { + struct dst_entry *new = &rt->dst; + + new->__use = 1; + new->input = dst_discard; + new->output = dst_discard_out; + + new->dev = net->loopback_dev; + netdev_hold(new->dev, &new->dev_tracker, GFP_ATOMIC); + + rt->rt_is_input = ort->rt_is_input; + rt->rt_iif = ort->rt_iif; + rt->rt_pmtu = ort->rt_pmtu; + rt->rt_mtu_locked = ort->rt_mtu_locked; + + rt->rt_genid = rt_genid_ipv4(net); + rt->rt_flags = ort->rt_flags; + rt->rt_type = ort->rt_type; + rt->rt_uses_gateway = ort->rt_uses_gateway; + rt->rt_gw_family = ort->rt_gw_family; + if (rt->rt_gw_family == AF_INET) + rt->rt_gw4 = ort->rt_gw4; + else if (rt->rt_gw_family == AF_INET6) + rt->rt_gw6 = ort->rt_gw6; + + INIT_LIST_HEAD(&rt->rt_uncached); + } + + dst_release(dst_orig); + + return rt ? &rt->dst : ERR_PTR(-ENOMEM); +} + +struct rtable *ip_route_output_flow(struct net *net, struct flowi4 *flp4, + const struct sock *sk) +{ + struct rtable *rt = __ip_route_output_key(net, flp4); + + if (IS_ERR(rt)) + return rt; + + if (flp4->flowi4_proto) { + flp4->flowi4_oif = rt->dst.dev->ifindex; + rt = (struct rtable *)xfrm_lookup_route(net, &rt->dst, + flowi4_to_flowi(flp4), + sk, 0); + } + + return rt; +} +EXPORT_SYMBOL_GPL(ip_route_output_flow); + +struct rtable *ip_route_output_tunnel(struct sk_buff *skb, + struct net_device *dev, + struct net *net, __be32 *saddr, + const struct ip_tunnel_info *info, + u8 protocol, bool use_cache) +{ +#ifdef CONFIG_DST_CACHE + struct dst_cache *dst_cache; +#endif + struct rtable *rt = NULL; + struct flowi4 fl4; + __u8 tos; + +#ifdef CONFIG_DST_CACHE + dst_cache = (struct dst_cache *)&info->dst_cache; + if (use_cache) { + rt = dst_cache_get_ip4(dst_cache, saddr); + if (rt) + return rt; + } +#endif + memset(&fl4, 0, sizeof(fl4)); + fl4.flowi4_mark = skb->mark; + fl4.flowi4_proto = protocol; + fl4.daddr = info->key.u.ipv4.dst; + fl4.saddr = info->key.u.ipv4.src; + tos = info->key.tos; + fl4.flowi4_tos = RT_TOS(tos); + + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) { + netdev_dbg(dev, "no route to %pI4\n", &fl4.daddr); + return ERR_PTR(-ENETUNREACH); + } + if (rt->dst.dev == dev) { /* is this necessary? 
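It prevents the looked-up route from resolving back onto the tunnel device itself, i.e. a circular route, hence the -ELOOP below.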
*/ + netdev_dbg(dev, "circular route to %pI4\n", &fl4.daddr); + ip_rt_put(rt); + return ERR_PTR(-ELOOP); + } +#ifdef CONFIG_DST_CACHE + if (use_cache) + dst_cache_set_ip4(dst_cache, &rt->dst, fl4.saddr); +#endif + *saddr = fl4.saddr; + return rt; +} +EXPORT_SYMBOL_GPL(ip_route_output_tunnel); + +/* called with rcu_read_lock held */ +static int rt_fill_info(struct net *net, __be32 dst, __be32 src, + struct rtable *rt, u32 table_id, struct flowi4 *fl4, + struct sk_buff *skb, u32 portid, u32 seq, + unsigned int flags) +{ + struct rtmsg *r; + struct nlmsghdr *nlh; + unsigned long expires = 0; + u32 error; + u32 metrics[RTAX_MAX]; + + nlh = nlmsg_put(skb, portid, seq, RTM_NEWROUTE, sizeof(*r), flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + r->rtm_family = AF_INET; + r->rtm_dst_len = 32; + r->rtm_src_len = 0; + r->rtm_tos = fl4 ? fl4->flowi4_tos : 0; + r->rtm_table = table_id < 256 ? table_id : RT_TABLE_COMPAT; + if (nla_put_u32(skb, RTA_TABLE, table_id)) + goto nla_put_failure; + r->rtm_type = rt->rt_type; + r->rtm_scope = RT_SCOPE_UNIVERSE; + r->rtm_protocol = RTPROT_UNSPEC; + r->rtm_flags = (rt->rt_flags & ~0xFFFF) | RTM_F_CLONED; + if (rt->rt_flags & RTCF_NOTIFY) + r->rtm_flags |= RTM_F_NOTIFY; + if (IPCB(skb)->flags & IPSKB_DOREDIRECT) + r->rtm_flags |= RTCF_DOREDIRECT; + + if (nla_put_in_addr(skb, RTA_DST, dst)) + goto nla_put_failure; + if (src) { + r->rtm_src_len = 32; + if (nla_put_in_addr(skb, RTA_SRC, src)) + goto nla_put_failure; + } + if (rt->dst.dev && + nla_put_u32(skb, RTA_OIF, rt->dst.dev->ifindex)) + goto nla_put_failure; + if (rt->dst.lwtstate && + lwtunnel_fill_encap(skb, rt->dst.lwtstate, RTA_ENCAP, RTA_ENCAP_TYPE) < 0) + goto nla_put_failure; +#ifdef CONFIG_IP_ROUTE_CLASSID + if (rt->dst.tclassid && + nla_put_u32(skb, RTA_FLOW, rt->dst.tclassid)) + goto nla_put_failure; +#endif + if (fl4 && !rt_is_input_route(rt) && + fl4->saddr != src) { + if (nla_put_in_addr(skb, RTA_PREFSRC, fl4->saddr)) + goto nla_put_failure; + } + if (rt->rt_uses_gateway) { + if (rt->rt_gw_family == AF_INET && + nla_put_in_addr(skb, RTA_GATEWAY, rt->rt_gw4)) { + goto nla_put_failure; + } else if (rt->rt_gw_family == AF_INET6) { + int alen = sizeof(struct in6_addr); + struct nlattr *nla; + struct rtvia *via; + + nla = nla_reserve(skb, RTA_VIA, alen + 2); + if (!nla) + goto nla_put_failure; + + via = nla_data(nla); + via->rtvia_family = AF_INET6; + memcpy(via->rtvia_addr, &rt->rt_gw6, alen); + } + } + + expires = rt->dst.expires; + if (expires) { + unsigned long now = jiffies; + + if (time_before(now, expires)) + expires -= now; + else + expires = 0; + } + + memcpy(metrics, dst_metrics_ptr(&rt->dst), sizeof(metrics)); + if (rt->rt_pmtu && expires) + metrics[RTAX_MTU - 1] = rt->rt_pmtu; + if (rt->rt_mtu_locked && expires) + metrics[RTAX_LOCK - 1] |= BIT(RTAX_MTU); + if (rtnetlink_put_metrics(skb, metrics) < 0) + goto nla_put_failure; + + if (fl4) { + if (fl4->flowi4_mark && + nla_put_u32(skb, RTA_MARK, fl4->flowi4_mark)) + goto nla_put_failure; + + if (!uid_eq(fl4->flowi4_uid, INVALID_UID) && + nla_put_u32(skb, RTA_UID, + from_kuid_munged(current_user_ns(), + fl4->flowi4_uid))) + goto nla_put_failure; + + if (rt_is_input_route(rt)) { +#ifdef CONFIG_IP_MROUTE + if (ipv4_is_multicast(dst) && + !ipv4_is_local_multicast(dst) && + IPV4_DEVCONF_ALL(net, MC_FORWARDING)) { + int err = ipmr_get_route(net, skb, + fl4->saddr, fl4->daddr, + r, portid); + + if (err <= 0) { + if (err == 0) + return 0; + goto nla_put_failure; + } + } else +#endif + if (nla_put_u32(skb, RTA_IIF, fl4->flowi4_iif)) + 
goto nla_put_failure; + } + } + + error = rt->dst.error; + + if (rtnl_put_cacheinfo(skb, &rt->dst, 0, expires, error) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int fnhe_dump_bucket(struct net *net, struct sk_buff *skb, + struct netlink_callback *cb, u32 table_id, + struct fnhe_hash_bucket *bucket, int genid, + int *fa_index, int fa_start, unsigned int flags) +{ + int i; + + for (i = 0; i < FNHE_HASH_SIZE; i++) { + struct fib_nh_exception *fnhe; + + for (fnhe = rcu_dereference(bucket[i].chain); fnhe; + fnhe = rcu_dereference(fnhe->fnhe_next)) { + struct rtable *rt; + int err; + + if (*fa_index < fa_start) + goto next; + + if (fnhe->fnhe_genid != genid) + goto next; + + if (fnhe->fnhe_expires && + time_after(jiffies, fnhe->fnhe_expires)) + goto next; + + rt = rcu_dereference(fnhe->fnhe_rth_input); + if (!rt) + rt = rcu_dereference(fnhe->fnhe_rth_output); + if (!rt) + goto next; + + err = rt_fill_info(net, fnhe->fnhe_daddr, 0, rt, + table_id, NULL, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, flags); + if (err) + return err; +next: + (*fa_index)++; + } + } + + return 0; +} + +int fib_dump_info_fnhe(struct sk_buff *skb, struct netlink_callback *cb, + u32 table_id, struct fib_info *fi, + int *fa_index, int fa_start, unsigned int flags) +{ + struct net *net = sock_net(cb->skb->sk); + int nhsel, genid = fnhe_genid(net); + + for (nhsel = 0; nhsel < fib_info_num_path(fi); nhsel++) { + struct fib_nh_common *nhc = fib_info_nhc(fi, nhsel); + struct fnhe_hash_bucket *bucket; + int err; + + if (nhc->nhc_flags & RTNH_F_DEAD) + continue; + + rcu_read_lock(); + bucket = rcu_dereference(nhc->nhc_exceptions); + err = 0; + if (bucket) + err = fnhe_dump_bucket(net, skb, cb, table_id, bucket, + genid, fa_index, fa_start, + flags); + rcu_read_unlock(); + if (err) + return err; + } + + return 0; +} + +static struct sk_buff *inet_rtm_getroute_build_skb(__be32 src, __be32 dst, + u8 ip_proto, __be16 sport, + __be16 dport) +{ + struct sk_buff *skb; + struct iphdr *iph; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return NULL; + + /* Reserve room for dummy headers, this skb can pass + * through good chunk of routing engine. 
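+ * Only the handful of header fields the route lookup and flow
+ * dissection need are filled in; inet_rtm_getroute() trims the skb
+ * again before reusing it for the netlink reply.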
+ */ + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb->protocol = htons(ETH_P_IP); + iph = skb_put(skb, sizeof(struct iphdr)); + iph->protocol = ip_proto; + iph->saddr = src; + iph->daddr = dst; + iph->version = 0x4; + iph->frag_off = 0; + iph->ihl = 0x5; + skb_set_transport_header(skb, skb->len); + + switch (iph->protocol) { + case IPPROTO_UDP: { + struct udphdr *udph; + + udph = skb_put_zero(skb, sizeof(struct udphdr)); + udph->source = sport; + udph->dest = dport; + udph->len = htons(sizeof(struct udphdr)); + udph->check = 0; + break; + } + case IPPROTO_TCP: { + struct tcphdr *tcph; + + tcph = skb_put_zero(skb, sizeof(struct tcphdr)); + tcph->source = sport; + tcph->dest = dport; + tcph->doff = sizeof(struct tcphdr) / 4; + tcph->rst = 1; + tcph->check = ~tcp_v4_check(sizeof(struct tcphdr), + src, dst, 0); + break; + } + case IPPROTO_ICMP: { + struct icmphdr *icmph; + + icmph = skb_put_zero(skb, sizeof(struct icmphdr)); + icmph->type = ICMP_ECHO; + icmph->code = 0; + } + } + + return skb; +} + +static int inet_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG(extack, + "ipv4: Invalid header for route get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 32) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 32) || + rtm->rtm_table || rtm->rtm_protocol || + rtm->rtm_scope || rtm->rtm_type) { + NL_SET_ERR_MSG(extack, "ipv4: Invalid values in header for route get request"); + return -EINVAL; + } + + if (rtm->rtm_flags & ~(RTM_F_NOTIFY | + RTM_F_LOOKUP_TABLE | + RTM_F_FIB_MATCH)) { + NL_SET_ERR_MSG(extack, "ipv4: Unsupported rtm_flags for route get request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv4_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG(extack, "ipv4: rtm_src_len and rtm_dst_len must be 32 for IPv4"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_IIF: + case RTA_OIF: + case RTA_SRC: + case RTA_DST: + case RTA_IP_PROTO: + case RTA_SPORT: + case RTA_DPORT: + case RTA_MARK: + case RTA_UID: + break; + default: + NL_SET_ERR_MSG(extack, "ipv4: Unsupported attribute in route get request"); + return -EINVAL; + } + } + + return 0; +} + +static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[RTA_MAX+1]; + u32 table_id = RT_TABLE_MAIN; + __be16 sport = 0, dport = 0; + struct fib_result res = {}; + u8 ip_proto = IPPROTO_UDP; + struct rtable *rt = NULL; + struct sk_buff *skb; + struct rtmsg *rtm; + struct flowi4 fl4 = {}; + __be32 dst = 0; + __be32 src = 0; + kuid_t uid; + u32 iif; + int err; + int mark; + + err = inet_rtm_valid_getroute_req(in_skb, nlh, tb, extack); + if (err < 0) + return err; + + rtm = nlmsg_data(nlh); + src = tb[RTA_SRC] ? nla_get_in_addr(tb[RTA_SRC]) : 0; + dst = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; + iif = tb[RTA_IIF] ? nla_get_u32(tb[RTA_IIF]) : 0; + mark = tb[RTA_MARK] ? 
nla_get_u32(tb[RTA_MARK]) : 0; + if (tb[RTA_UID]) + uid = make_kuid(current_user_ns(), nla_get_u32(tb[RTA_UID])); + else + uid = (iif ? INVALID_UID : current_uid()); + + if (tb[RTA_IP_PROTO]) { + err = rtm_getroute_parse_ip_proto(tb[RTA_IP_PROTO], + &ip_proto, AF_INET, extack); + if (err) + return err; + } + + if (tb[RTA_SPORT]) + sport = nla_get_be16(tb[RTA_SPORT]); + + if (tb[RTA_DPORT]) + dport = nla_get_be16(tb[RTA_DPORT]); + + skb = inet_rtm_getroute_build_skb(src, dst, ip_proto, sport, dport); + if (!skb) + return -ENOBUFS; + + fl4.daddr = dst; + fl4.saddr = src; + fl4.flowi4_tos = rtm->rtm_tos & IPTOS_RT_MASK; + fl4.flowi4_oif = tb[RTA_OIF] ? nla_get_u32(tb[RTA_OIF]) : 0; + fl4.flowi4_mark = mark; + fl4.flowi4_uid = uid; + if (sport) + fl4.fl4_sport = sport; + if (dport) + fl4.fl4_dport = dport; + fl4.flowi4_proto = ip_proto; + + rcu_read_lock(); + + if (iif) { + struct net_device *dev; + + dev = dev_get_by_index_rcu(net, iif); + if (!dev) { + err = -ENODEV; + goto errout_rcu; + } + + fl4.flowi4_iif = iif; /* for rt_fill_info */ + skb->dev = dev; + skb->mark = mark; + err = ip_route_input_rcu(skb, dst, src, + rtm->rtm_tos & IPTOS_RT_MASK, dev, + &res); + + rt = skb_rtable(skb); + if (err == 0 && rt->dst.error) + err = -rt->dst.error; + } else { + fl4.flowi4_iif = LOOPBACK_IFINDEX; + skb->dev = net->loopback_dev; + rt = ip_route_output_key_hash_rcu(net, &fl4, &res, skb); + err = 0; + if (IS_ERR(rt)) + err = PTR_ERR(rt); + else + skb_dst_set(skb, &rt->dst); + } + + if (err) + goto errout_rcu; + + if (rtm->rtm_flags & RTM_F_NOTIFY) + rt->rt_flags |= RTCF_NOTIFY; + + if (rtm->rtm_flags & RTM_F_LOOKUP_TABLE) + table_id = res.table ? res.table->tb_id : 0; + + /* reset skb for netlink reply msg */ + skb_trim(skb, 0); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb_reset_mac_header(skb); + + if (rtm->rtm_flags & RTM_F_FIB_MATCH) { + struct fib_rt_info fri; + + if (!res.fi) { + err = fib_props[res.type].error; + if (!err) + err = -EHOSTUNREACH; + goto errout_rcu; + } + fri.fi = res.fi; + fri.tb_id = table_id; + fri.dst = res.prefix; + fri.dst_len = res.prefixlen; + fri.dscp = inet_dsfield_to_dscp(fl4.flowi4_tos); + fri.type = rt->rt_type; + fri.offload = 0; + fri.trap = 0; + fri.offload_failed = 0; + if (res.fa_head) { + struct fib_alias *fa; + + hlist_for_each_entry_rcu(fa, res.fa_head, fa_list) { + u8 slen = 32 - fri.dst_len; + + if (fa->fa_slen == slen && + fa->tb_id == fri.tb_id && + fa->fa_dscp == fri.dscp && + fa->fa_info == res.fi && + fa->fa_type == fri.type) { + fri.offload = READ_ONCE(fa->offload); + fri.trap = READ_ONCE(fa->trap); + fri.offload_failed = + READ_ONCE(fa->offload_failed); + break; + } + } + } + err = fib_dump_info(skb, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, RTM_NEWROUTE, &fri, 0); + } else { + err = rt_fill_info(net, dst, src, rt, table_id, &fl4, skb, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0); + } + if (err < 0) + goto errout_rcu; + + rcu_read_unlock(); + + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); + +errout_free: + return err; +errout_rcu: + rcu_read_unlock(); + kfree_skb(skb); + goto errout_free; +} + +void ip_rt_multicast_event(struct in_device *in_dev) +{ + rt_cache_flush(dev_net(in_dev->dev)); +} + +#ifdef CONFIG_SYSCTL +static int ip_rt_gc_interval __read_mostly = 60 * HZ; +static int ip_rt_gc_min_interval __read_mostly = HZ / 2; +static int ip_rt_gc_elasticity __read_mostly = 8; +static int ip_min_valid_pmtu __read_mostly = IPV4_MIN_MTU; + +static int ipv4_sysctl_rtcache_flush(struct ctl_table *__ctl, int 
write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = (struct net *)__ctl->extra1; + + if (write) { + rt_cache_flush(net); + fnhe_genid_bump(net); + return 0; + } + + return -EINVAL; +} + +static struct ctl_table ipv4_route_table[] = { + { + .procname = "gc_thresh", + .data = &ipv4_dst_ops.gc_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "max_size", + .data = &ip_rt_max_size, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + /* Deprecated. Use gc_min_interval_ms */ + + .procname = "gc_min_interval", + .data = &ip_rt_gc_min_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "gc_min_interval_ms", + .data = &ip_rt_gc_min_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "gc_timeout", + .data = &ip_rt_gc_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "gc_interval", + .data = &ip_rt_gc_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "redirect_load", + .data = &ip_rt_redirect_load, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "redirect_number", + .data = &ip_rt_redirect_number, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "redirect_silence", + .data = &ip_rt_redirect_silence, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "error_cost", + .data = &ip_rt_error_cost, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "error_burst", + .data = &ip_rt_error_burst, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "gc_elasticity", + .data = &ip_rt_gc_elasticity, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +static const char ipv4_route_flush_procname[] = "flush"; + +static struct ctl_table ipv4_route_netns_table[] = { + { + .procname = ipv4_route_flush_procname, + .maxlen = sizeof(int), + .mode = 0200, + .proc_handler = ipv4_sysctl_rtcache_flush, + }, + { + .procname = "min_pmtu", + .data = &init_net.ipv4.ip_rt_min_pmtu, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &ip_min_valid_pmtu, + }, + { + .procname = "mtu_expires", + .data = &init_net.ipv4.ip_rt_mtu_expires, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "min_adv_mss", + .data = &init_net.ipv4.ip_rt_min_advmss, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { }, +}; + +static __net_init int sysctl_route_net_init(struct net *net) +{ + struct ctl_table *tbl; + + tbl = ipv4_route_netns_table; + if (!net_eq(net, &init_net)) { + int i; + + tbl = kmemdup(tbl, sizeof(ipv4_route_netns_table), GFP_KERNEL); + if (!tbl) + goto err_dup; + + /* Don't export non-whitelisted sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) { + if (tbl[0].procname != ipv4_route_flush_procname) + tbl[0].procname = NULL; + } + + /* Update the variables to point into the current struct net + * except for the first element flush + */ + for (i = 1; i < ARRAY_SIZE(ipv4_route_netns_table) - 1; i++) + tbl[i].data += (void *)net - (void *)&init_net; + } + tbl[0].extra1 
= net; + + net->ipv4.route_hdr = register_net_sysctl(net, "net/ipv4/route", tbl); + if (!net->ipv4.route_hdr) + goto err_reg; + return 0; + +err_reg: + if (tbl != ipv4_route_netns_table) + kfree(tbl); +err_dup: + return -ENOMEM; +} + +static __net_exit void sysctl_route_net_exit(struct net *net) +{ + struct ctl_table *tbl; + + tbl = net->ipv4.route_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv4.route_hdr); + BUG_ON(tbl == ipv4_route_netns_table); + kfree(tbl); +} + +static __net_initdata struct pernet_operations sysctl_route_ops = { + .init = sysctl_route_net_init, + .exit = sysctl_route_net_exit, +}; +#endif + +static __net_init int netns_ip_rt_init(struct net *net) +{ + /* Set default value for namespaceified sysctls */ + net->ipv4.ip_rt_min_pmtu = DEFAULT_MIN_PMTU; + net->ipv4.ip_rt_mtu_expires = DEFAULT_MTU_EXPIRES; + net->ipv4.ip_rt_min_advmss = DEFAULT_MIN_ADVMSS; + return 0; +} + +static struct pernet_operations __net_initdata ip_rt_ops = { + .init = netns_ip_rt_init, +}; + +static __net_init int rt_genid_init(struct net *net) +{ + atomic_set(&net->ipv4.rt_genid, 0); + atomic_set(&net->fnhe_genid, 0); + atomic_set(&net->ipv4.dev_addr_genid, get_random_u32()); + return 0; +} + +static __net_initdata struct pernet_operations rt_genid_ops = { + .init = rt_genid_init, +}; + +static int __net_init ipv4_inetpeer_init(struct net *net) +{ + struct inet_peer_base *bp = kmalloc(sizeof(*bp), GFP_KERNEL); + + if (!bp) + return -ENOMEM; + inet_peer_base_init(bp); + net->ipv4.peers = bp; + return 0; +} + +static void __net_exit ipv4_inetpeer_exit(struct net *net) +{ + struct inet_peer_base *bp = net->ipv4.peers; + + net->ipv4.peers = NULL; + inetpeer_invalidate_tree(bp); + kfree(bp); +} + +static __net_initdata struct pernet_operations ipv4_inetpeer_ops = { + .init = ipv4_inetpeer_init, + .exit = ipv4_inetpeer_exit, +}; + +#ifdef CONFIG_IP_ROUTE_CLASSID +struct ip_rt_acct __percpu *ip_rt_acct __read_mostly; +#endif /* CONFIG_IP_ROUTE_CLASSID */ + +int __init ip_rt_init(void) +{ + void *idents_hash; + int cpu; + + /* For modern hosts, this will use 2 MB of memory */ + idents_hash = alloc_large_system_hash("IP idents", + sizeof(*ip_idents) + sizeof(*ip_tstamps), + 0, + 16, /* one bucket per 64 KB */ + HASH_ZERO, + NULL, + &ip_idents_mask, + 2048, + 256*1024); + + ip_idents = idents_hash; + + get_random_bytes(ip_idents, (ip_idents_mask + 1) * sizeof(*ip_idents)); + + ip_tstamps = idents_hash + (ip_idents_mask + 1) * sizeof(*ip_idents); + + for_each_possible_cpu(cpu) { + struct uncached_list *ul = &per_cpu(rt_uncached_list, cpu); + + INIT_LIST_HEAD(&ul->head); + INIT_LIST_HEAD(&ul->quarantine); + spin_lock_init(&ul->lock); + } +#ifdef CONFIG_IP_ROUTE_CLASSID + ip_rt_acct = __alloc_percpu(256 * sizeof(struct ip_rt_acct), __alignof__(struct ip_rt_acct)); + if (!ip_rt_acct) + panic("IP: failed to allocate ip_rt_acct\n"); +#endif + + ipv4_dst_ops.kmem_cachep = + kmem_cache_create("ip_dst_cache", sizeof(struct rtable), 0, + SLAB_HWCACHE_ALIGN|SLAB_PANIC, NULL); + + ipv4_dst_blackhole_ops.kmem_cachep = ipv4_dst_ops.kmem_cachep; + + if (dst_entries_init(&ipv4_dst_ops) < 0) + panic("IP: failed to allocate ipv4_dst_ops counter\n"); + + if (dst_entries_init(&ipv4_dst_blackhole_ops) < 0) + panic("IP: failed to allocate ipv4_dst_blackhole_ops counter\n"); + + ipv4_dst_ops.gc_thresh = ~0; + ip_rt_max_size = INT_MAX; + + devinet_init(); + ip_fib_init(); + + if (ip_rt_proc_init()) + pr_err("Unable to create route proc files\n"); +#ifdef CONFIG_XFRM + xfrm_init(); + xfrm4_init(); +#endif + 
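/* inet_rtm_getroute() only takes rcu_read_lock(), so the doit handler + * is registered with RTNL_FLAG_DOIT_UNLOCKED. */ + 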
rtnl_register(PF_INET, RTM_GETROUTE, inet_rtm_getroute, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + +#ifdef CONFIG_SYSCTL + register_pernet_subsys(&sysctl_route_ops); +#endif + register_pernet_subsys(&ip_rt_ops); + register_pernet_subsys(&rt_genid_ops); + register_pernet_subsys(&ipv4_inetpeer_ops); + return 0; +} + +#ifdef CONFIG_SYSCTL +/* + * We really need to sanitize the damn ipv4 init order, then all + * this nonsense will go away. + */ +void __init ip_static_sysctl_init(void) +{ + register_net_sysctl(&init_net, "net/ipv4/route", ipv4_route_table); +} +#endif diff --git a/net/ipv4/syncookies.c b/net/ipv4/syncookies.c new file mode 100644 index 000000000..f9514cf87 --- /dev/null +++ b/net/ipv4/syncookies.c @@ -0,0 +1,449 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Syncookies implementation for the Linux kernel + * + * Copyright (C) 1997 Andi Kleen + * Based on ideas by D.J.Bernstein and Eric Schenk. + */ + +#include +#include +#include +#include +#include +#include +#include + +static siphash_aligned_key_t syncookie_secret[2]; + +#define COOKIEBITS 24 /* Upper bits store count */ +#define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1) + +/* TCP Timestamp: 6 lowest bits of timestamp sent in the cookie SYN-ACK + * stores TCP options: + * + * MSB LSB + * | 31 ... 6 | 5 | 4 | 3 2 1 0 | + * | Timestamp | ECN | SACK | WScale | + * + * When we receive a valid cookie-ACK, we look at the echoed tsval (if + * any) to figure out which TCP options we should use for the rebuilt + * connection. + * + * A WScale setting of '0xf' (which is an invalid scaling value) + * means that original syn did not include the TCP window scaling option. + */ +#define TS_OPT_WSCALE_MASK 0xf +#define TS_OPT_SACK BIT(4) +#define TS_OPT_ECN BIT(5) +/* There is no TS_OPT_TIMESTAMP: + * if ACK contains timestamp option, we already know it was + * requested/supported by the syn/synack exchange. + */ +#define TSBITS 6 + +static u32 cookie_hash(__be32 saddr, __be32 daddr, __be16 sport, __be16 dport, + u32 count, int c) +{ + net_get_random_once(syncookie_secret, sizeof(syncookie_secret)); + return siphash_4u32((__force u32)saddr, (__force u32)daddr, + (__force u32)sport << 16 | (__force u32)dport, + count, &syncookie_secret[c]); +} + + +/* + * when syncookies are in effect and tcp timestamps are enabled we encode + * tcp options in the lower bits of the timestamp value that will be + * sent in the syn-ack. + * Since subsequent timestamps use the normal tcp_time_stamp value, we + * must make sure that the resulting initial timestamp is <= tcp_time_stamp. + */ +u64 cookie_init_timestamp(struct request_sock *req, u64 now) +{ + const struct inet_request_sock *ireq = inet_rsk(req); + u64 ts, ts_now = tcp_ns_to_ts(now); + u32 options = 0; + + options = ireq->wscale_ok ? ireq->snd_wscale : TS_OPT_WSCALE_MASK; + if (ireq->sack_ok) + options |= TS_OPT_SACK; + if (ireq->ecn_ok) + options |= TS_OPT_ECN; + + ts = (ts_now >> TSBITS) << TSBITS; + ts |= options; + if (ts > ts_now) + ts -= (1UL << TSBITS); + + return ts * (NSEC_PER_SEC / TCP_TS_HZ); +} + + +static __u32 secure_tcp_syn_cookie(__be32 saddr, __be32 daddr, __be16 sport, + __be16 dport, __u32 sseq, __u32 data) +{ + /* + * Compute the secure sequence number. + * The output should be: + * HASH(sec1,saddr,sport,daddr,dport,sec1) + sseq + (count * 2^24) + * + (HASH(sec2,saddr,sport,daddr,dport,count,sec2) % 2^24). + * Where sseq is their sequence number and count increases every + * minute by 1. 
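+ * Once the first hash and sseq are subtracted back out, the bits above
+ * COOKIEBITS hold the generation count (in minutes) and the low
+ * COOKIEBITS bits hold the second, count-keyed hash, which is how
+ * check_tcp_syn_cookie() later recovers the cookie's age.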
+ * As an extra hack, we add a small "data" value that encodes the + * MSS into the second hash value. + */ + u32 count = tcp_cookie_time(); + return (cookie_hash(saddr, daddr, sport, dport, 0, 0) + + sseq + (count << COOKIEBITS) + + ((cookie_hash(saddr, daddr, sport, dport, count, 1) + data) + & COOKIEMASK)); +} + +/* + * This retrieves the small "data" value from the syncookie. + * If the syncookie is bad, the data returned will be out of + * range. This must be checked by the caller. + * + * The count value used to generate the cookie must be less than + * MAX_SYNCOOKIE_AGE minutes in the past. + * The return value (__u32)-1 if this test fails. + */ +static __u32 check_tcp_syn_cookie(__u32 cookie, __be32 saddr, __be32 daddr, + __be16 sport, __be16 dport, __u32 sseq) +{ + u32 diff, count = tcp_cookie_time(); + + /* Strip away the layers from the cookie */ + cookie -= cookie_hash(saddr, daddr, sport, dport, 0, 0) + sseq; + + /* Cookie is now reduced to (count * 2^24) ^ (hash % 2^24) */ + diff = (count - (cookie >> COOKIEBITS)) & ((__u32) -1 >> COOKIEBITS); + if (diff >= MAX_SYNCOOKIE_AGE) + return (__u32)-1; + + return (cookie - + cookie_hash(saddr, daddr, sport, dport, count - diff, 1)) + & COOKIEMASK; /* Leaving the data behind */ +} + +/* + * MSS Values are chosen based on the 2011 paper + * 'An Analysis of TCP Maximum Segement Sizes' by S. Alcock and R. Nelson. + * Values .. + * .. lower than 536 are rare (< 0.2%) + * .. between 537 and 1299 account for less than < 1.5% of observed values + * .. in the 1300-1349 range account for about 15 to 20% of observed mss values + * .. exceeding 1460 are very rare (< 0.04%) + * + * 1460 is the single most frequently announced mss value (30 to 46% depending + * on monitor location). Table must be sorted. + */ +static __u16 const msstab[] = { + 536, + 1300, + 1440, /* 1440, 1452: PPPoE */ + 1460, +}; + +/* + * Generate a syncookie. mssp points to the mss, which is returned + * rounded down to the value encoded in the cookie. + */ +u32 __cookie_v4_init_sequence(const struct iphdr *iph, const struct tcphdr *th, + u16 *mssp) +{ + int mssind; + const __u16 mss = *mssp; + + for (mssind = ARRAY_SIZE(msstab) - 1; mssind ; mssind--) + if (mss >= msstab[mssind]) + break; + *mssp = msstab[mssind]; + + return secure_tcp_syn_cookie(iph->saddr, iph->daddr, + th->source, th->dest, ntohl(th->seq), + mssind); +} +EXPORT_SYMBOL_GPL(__cookie_v4_init_sequence); + +__u32 cookie_v4_init_sequence(const struct sk_buff *skb, __u16 *mssp) +{ + const struct iphdr *iph = ip_hdr(skb); + const struct tcphdr *th = tcp_hdr(skb); + + return __cookie_v4_init_sequence(iph, th, mssp); +} + +/* + * Check if a ack sequence number is a valid syncookie. + * Return the decoded mss if it is, or 0 if not. + */ +int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th, + u32 cookie) +{ + __u32 seq = ntohl(th->seq) - 1; + __u32 mssind = check_tcp_syn_cookie(cookie, iph->saddr, iph->daddr, + th->source, th->dest, seq); + + return mssind < ARRAY_SIZE(msstab) ? 
msstab[mssind] : 0; +} +EXPORT_SYMBOL_GPL(__cookie_v4_check); + +struct sock *tcp_get_cookie_sock(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, u32 tsoff) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct sock *child; + bool own_req; + + child = icsk->icsk_af_ops->syn_recv_sock(sk, skb, req, dst, + NULL, &own_req); + if (child) { + refcount_set(&req->rsk_refcnt, 1); + tcp_sk(child)->tsoffset = tsoff; + sock_rps_save_rxhash(child, skb); + + if (rsk_drop_req(req)) { + reqsk_put(req); + return child; + } + + if (inet_csk_reqsk_queue_add(sk, req, child)) + return child; + + bh_unlock_sock(child); + sock_put(child); + } + __reqsk_free(req); + + return NULL; +} +EXPORT_SYMBOL(tcp_get_cookie_sock); + +/* + * when syncookies are in effect and tcp timestamps are enabled we stored + * additional tcp options in the timestamp. + * This extracts these options from the timestamp echo. + * + * return false if we decode a tcp option that is disabled + * on the host. + */ +bool cookie_timestamp_decode(const struct net *net, + struct tcp_options_received *tcp_opt) +{ + /* echoed timestamp, lowest bits contain options */ + u32 options = tcp_opt->rcv_tsecr; + + if (!tcp_opt->saw_tstamp) { + tcp_clear_options(tcp_opt); + return true; + } + + if (!READ_ONCE(net->ipv4.sysctl_tcp_timestamps)) + return false; + + tcp_opt->sack_ok = (options & TS_OPT_SACK) ? TCP_SACK_SEEN : 0; + + if (tcp_opt->sack_ok && !READ_ONCE(net->ipv4.sysctl_tcp_sack)) + return false; + + if ((options & TS_OPT_WSCALE_MASK) == TS_OPT_WSCALE_MASK) + return true; /* no window scaling */ + + tcp_opt->wscale_ok = 1; + tcp_opt->snd_wscale = options & TS_OPT_WSCALE_MASK; + + return READ_ONCE(net->ipv4.sysctl_tcp_window_scaling) != 0; +} +EXPORT_SYMBOL(cookie_timestamp_decode); + +bool cookie_ecn_ok(const struct tcp_options_received *tcp_opt, + const struct net *net, const struct dst_entry *dst) +{ + bool ecn_ok = tcp_opt->rcv_tsecr & TS_OPT_ECN; + + if (!ecn_ok) + return false; + + if (READ_ONCE(net->ipv4.sysctl_tcp_ecn)) + return true; + + return dst_feature(dst, RTAX_FEATURE_ECN); +} +EXPORT_SYMBOL(cookie_ecn_ok); + +struct request_sock *cookie_tcp_reqsk_alloc(const struct request_sock_ops *ops, + const struct tcp_request_sock_ops *af_ops, + struct sock *sk, + struct sk_buff *skb) +{ + struct tcp_request_sock *treq; + struct request_sock *req; + + if (sk_is_mptcp(sk)) + req = mptcp_subflow_reqsk_alloc(ops, sk, false); + else + req = inet_reqsk_alloc(ops, sk, false); + + if (!req) + return NULL; + + treq = tcp_rsk(req); + + /* treq->af_specific might be used to perform TCP_MD5 lookup */ + treq->af_specific = af_ops; + + treq->syn_tos = TCP_SKB_CB(skb)->ip_dsfield; +#if IS_ENABLED(CONFIG_MPTCP) + treq->is_mptcp = sk_is_mptcp(sk); + if (treq->is_mptcp) { + int err = mptcp_subflow_init_cookie_req(req, sk, skb); + + if (err) { + reqsk_free(req); + return NULL; + } + } +#endif + + return req; +} +EXPORT_SYMBOL_GPL(cookie_tcp_reqsk_alloc); + +/* On input, sk is a listener. + * Output is listener if incoming packet would not create a child + * NULL if memory could not be allocated. 
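+ * On success the return value is the newly created child socket for
+ * the connection described by the validated cookie.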
+ */ +struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) +{ + struct ip_options *opt = &TCP_SKB_CB(skb)->header.h4.opt; + struct tcp_options_received tcp_opt; + struct inet_request_sock *ireq; + struct tcp_request_sock *treq; + struct tcp_sock *tp = tcp_sk(sk); + const struct tcphdr *th = tcp_hdr(skb); + __u32 cookie = ntohl(th->ack_seq) - 1; + struct sock *ret = sk; + struct request_sock *req; + int full_space, mss; + struct rtable *rt; + __u8 rcv_wscale; + struct flowi4 fl4; + u32 tsoff = 0; + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies) || + !th->ack || th->rst) + goto out; + + if (tcp_synq_no_recent_overflow(sk)) + goto out; + + mss = __cookie_v4_check(ip_hdr(skb), th, cookie); + if (mss == 0) { + __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESFAILED); + goto out; + } + + __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESRECV); + + /* check for timestamp cookie support */ + memset(&tcp_opt, 0, sizeof(tcp_opt)); + tcp_parse_options(sock_net(sk), skb, &tcp_opt, 0, NULL); + + if (tcp_opt.saw_tstamp && tcp_opt.rcv_tsecr) { + tsoff = secure_tcp_ts_off(sock_net(sk), + ip_hdr(skb)->daddr, + ip_hdr(skb)->saddr); + tcp_opt.rcv_tsecr -= tsoff; + } + + if (!cookie_timestamp_decode(sock_net(sk), &tcp_opt)) + goto out; + + ret = NULL; + req = cookie_tcp_reqsk_alloc(&tcp_request_sock_ops, + &tcp_request_sock_ipv4_ops, sk, skb); + if (!req) + goto out; + + ireq = inet_rsk(req); + treq = tcp_rsk(req); + treq->rcv_isn = ntohl(th->seq) - 1; + treq->snt_isn = cookie; + treq->ts_off = 0; + treq->txhash = net_tx_rndhash(); + req->mss = mss; + ireq->ir_num = ntohs(th->dest); + ireq->ir_rmt_port = th->source; + sk_rcv_saddr_set(req_to_sk(req), ip_hdr(skb)->daddr); + sk_daddr_set(req_to_sk(req), ip_hdr(skb)->saddr); + ireq->ir_mark = inet_request_mark(sk, skb); + ireq->snd_wscale = tcp_opt.snd_wscale; + ireq->sack_ok = tcp_opt.sack_ok; + ireq->wscale_ok = tcp_opt.wscale_ok; + ireq->tstamp_ok = tcp_opt.saw_tstamp; + req->ts_recent = tcp_opt.saw_tstamp ? tcp_opt.rcv_tsval : 0; + treq->snt_synack = 0; + treq->tfo_listener = false; + + if (IS_ENABLED(CONFIG_SMC)) + ireq->smc_ok = 0; + + ireq->ir_iif = inet_request_bound_dev_if(sk, skb); + + /* We throwed the options of the initial SYN away, so we hope + * the ACK carries the same options again (see RFC1122 4.2.3.8) + */ + RCU_INIT_POINTER(ireq->ireq_opt, tcp_v4_save_options(sock_net(sk), skb)); + + if (security_inet_conn_request(sk, skb, req)) { + reqsk_free(req); + goto out; + } + + req->num_retrans = 0; + + /* + * We need to lookup the route here to get at the correct + * window size. We should better make sure that the window size + * hasn't changed since we received the original syn, but I see + * no easy way to do this. + */ + flowi4_init_output(&fl4, ireq->ir_iif, ireq->ir_mark, + RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, IPPROTO_TCP, + inet_sk_flowi_flags(sk), + opt->srr ? opt->faddr : ireq->ir_rmt_addr, + ireq->ir_loc_addr, th->source, th->dest, sk->sk_uid); + security_req_classify_flow(req, flowi4_to_flowi_common(&fl4)); + rt = ip_route_output_key(sock_net(sk), &fl4); + if (IS_ERR(rt)) { + reqsk_free(req); + goto out; + } + + /* Try to redo what tcp_v4_send_synack did. */ + req->rsk_window_clamp = tp->window_clamp ? 
:dst_metric(&rt->dst, RTAX_WINDOW); + /* limit the window selection if the user enforce a smaller rx buffer */ + full_space = tcp_full_space(sk); + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK && + (req->rsk_window_clamp > full_space || req->rsk_window_clamp == 0)) + req->rsk_window_clamp = full_space; + + tcp_select_initial_window(sk, full_space, req->mss, + &req->rsk_rcv_wnd, &req->rsk_window_clamp, + ireq->wscale_ok, &rcv_wscale, + dst_metric(&rt->dst, RTAX_INITRWND)); + + ireq->rcv_wscale = rcv_wscale; + ireq->ecn_ok = cookie_ecn_ok(&tcp_opt, sock_net(sk), &rt->dst); + + ret = tcp_get_cookie_sock(sk, skb, req, &rt->dst, tsoff); + /* ip_queue_xmit() depends on our flow being setup + * Normal sockets get it right from inet_csk_route_child_sock() + */ + if (ret) + inet_sk(ret)->cork.fl.u.ip4 = fl4; +out: return ret; +} diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c new file mode 100644 index 000000000..73e582158 --- /dev/null +++ b/net/ipv4/sysctl_net_ipv4.c @@ -0,0 +1,1479 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * sysctl_net_ipv4.c: sysctl interface to net IPV4 subsystem. + * + * Begun April 1, 1996, Mike Shaver. + * Added /proc/sys/net/ipv4 directory entry (empty =) ). [MS] + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int tcp_retr1_max = 255; +static int ip_local_port_range_min[] = { 1, 1 }; +static int ip_local_port_range_max[] = { 65535, 65535 }; +static int tcp_adv_win_scale_min = -31; +static int tcp_adv_win_scale_max = 31; +static int tcp_app_win_max = 31; +static int tcp_min_snd_mss_min = TCP_MIN_SND_MSS; +static int tcp_min_snd_mss_max = 65535; +static int ip_privileged_port_min; +static int ip_privileged_port_max = 65535; +static int ip_ttl_min = 1; +static int ip_ttl_max = 255; +static int tcp_syn_retries_min = 1; +static int tcp_syn_retries_max = MAX_TCP_SYNCNT; +static unsigned long ip_ping_group_range_min[] = { 0, 0 }; +static unsigned long ip_ping_group_range_max[] = { GID_T_MAX, GID_T_MAX }; +static u32 u32_max_div_HZ = UINT_MAX / HZ; +static int one_day_secs = 24 * 3600; +static u32 fib_multipath_hash_fields_all_mask __maybe_unused = + FIB_MULTIPATH_HASH_FIELD_ALL_MASK; +static unsigned int tcp_child_ehash_entries_max = 16 * 1024 * 1024; + +/* obsolete */ +static int sysctl_tcp_low_latency __read_mostly; + +/* Update system visible IP port range */ +static void set_local_port_range(struct net *net, int range[2]) +{ + bool same_parity = !((range[0] ^ range[1]) & 1); + + write_seqlock_bh(&net->ipv4.ip_local_ports.lock); + if (same_parity && !net->ipv4.ip_local_ports.warned) { + net->ipv4.ip_local_ports.warned = true; + pr_err_ratelimited("ip_local_port_range: prefer different parity for start/end values.\n"); + } + net->ipv4.ip_local_ports.range[0] = range[0]; + net->ipv4.ip_local_ports.range[1] = range[1]; + write_sequnlock_bh(&net->ipv4.ip_local_ports.lock); +} + +/* Validate changes from /proc interface. 
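Editor's aside: set_local_port_range() above warns when both endpoints of the ephemeral port range share the same parity. XOR keeps only the bits in which the two values differ, so bit 0 of (range[0] ^ range[1]) is clear exactly when both ends are even or both are odd. A standalone sketch of that test (hypothetical helper, not kernel code):

#include <stdbool.h>
#include <stdio.h>

/* Mirrors the same_parity test used by set_local_port_range() above. */
static bool same_parity(int lo, int hi)
{
	/* bit 0 of the XOR is set iff exactly one endpoint is odd */
	return !((lo ^ hi) & 1);
}

int main(void)
{
	printf("%d\n", same_parity(32768, 60999));	/* 0: even/odd  -> no warning */
	printf("%d\n", same_parity(32768, 61000));	/* 1: even/even -> warning    */
	return 0;
}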
*/ +static int ipv4_local_port_range(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = + container_of(table->data, struct net, ipv4.ip_local_ports.range); + int ret; + int range[2]; + struct ctl_table tmp = { + .data = &range, + .maxlen = sizeof(range), + .mode = table->mode, + .extra1 = &ip_local_port_range_min, + .extra2 = &ip_local_port_range_max, + }; + + inet_get_local_port_range(net, &range[0], &range[1]); + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) { + /* Ensure that the upper limit is not smaller than the lower, + * and that the lower does not encroach upon the privileged + * port limit. + */ + if ((range[1] < range[0]) || + (range[0] < READ_ONCE(net->ipv4.sysctl_ip_prot_sock))) + ret = -EINVAL; + else + set_local_port_range(net, range); + } + + return ret; +} + +/* Validate changes from /proc interface. */ +static int ipv4_privileged_ports(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_ip_prot_sock); + int ret; + int pports; + int range[2]; + struct ctl_table tmp = { + .data = &pports, + .maxlen = sizeof(pports), + .mode = table->mode, + .extra1 = &ip_privileged_port_min, + .extra2 = &ip_privileged_port_max, + }; + + pports = READ_ONCE(net->ipv4.sysctl_ip_prot_sock); + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) { + inet_get_local_port_range(net, &range[0], &range[1]); + /* Ensure that the local port range doesn't overlap with the + * privileged port range. + */ + if (range[0] < pports) + ret = -EINVAL; + else + WRITE_ONCE(net->ipv4.sysctl_ip_prot_sock, pports); + } + + return ret; +} + +static void inet_get_ping_group_range_table(struct ctl_table *table, kgid_t *low, kgid_t *high) +{ + kgid_t *data = table->data; + struct net *net = + container_of(table->data, struct net, ipv4.ping_group_range.range); + unsigned int seq; + do { + seq = read_seqbegin(&net->ipv4.ping_group_range.lock); + + *low = data[0]; + *high = data[1]; + } while (read_seqretry(&net->ipv4.ping_group_range.lock, seq)); +} + +/* Update system visible IP port range */ +static void set_ping_group_range(struct ctl_table *table, kgid_t low, kgid_t high) +{ + kgid_t *data = table->data; + struct net *net = + container_of(table->data, struct net, ipv4.ping_group_range.range); + write_seqlock(&net->ipv4.ping_group_range.lock); + data[0] = low; + data[1] = high; + write_sequnlock(&net->ipv4.ping_group_range.lock); +} + +/* Validate changes from /proc interface. 
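Editor's aside: ipv4_local_port_range() and ipv4_privileged_ports() (below and above, respectively, in this hunk) enforce the same invariant from opposite directions: the ephemeral range must stay ordered, and its lower bound must not dip below ip_unprivileged_port_start. A condensed, hypothetical view of those rules (not kernel code); the example values reflect the common defaults of 32768-60999 and 1024:

#include <stdbool.h>

/* Hypothetical helper condensing the cross-checks made by
 * ipv4_local_port_range() and ipv4_privileged_ports().
 */
static bool port_ranges_consistent(int local_lo, int local_hi, int unpriv_start)
{
	if (local_hi < local_lo)
		return false;	/* inverted range is rejected with -EINVAL */
	if (local_lo < unpriv_start)
		return false;	/* ephemeral range may not enter privileged ports */
	return true;
}

/* port_ranges_consistent(32768, 60999, 1024) -> true   (defaults pass)      */
/* port_ranges_consistent(500,   60999, 1024) -> false  (overlaps privileged) */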
*/ +static int ipv4_ping_group_range(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct user_namespace *user_ns = current_user_ns(); + int ret; + unsigned long urange[2]; + kgid_t low, high; + struct ctl_table tmp = { + .data = &urange, + .maxlen = sizeof(urange), + .mode = table->mode, + .extra1 = &ip_ping_group_range_min, + .extra2 = &ip_ping_group_range_max, + }; + + inet_get_ping_group_range_table(table, &low, &high); + urange[0] = from_kgid_munged(user_ns, low); + urange[1] = from_kgid_munged(user_ns, high); + ret = proc_doulongvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) { + low = make_kgid(user_ns, urange[0]); + high = make_kgid(user_ns, urange[1]); + if (!gid_valid(low) || !gid_valid(high)) + return -EINVAL; + if (urange[1] < urange[0] || gid_lt(high, low)) { + low = make_kgid(&init_user_ns, 1); + high = make_kgid(&init_user_ns, 0); + } + set_ping_group_range(table, low, high); + } + + return ret; +} + +static int ipv4_fwd_update_priority(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net; + int ret; + + net = container_of(table->data, struct net, + ipv4.sysctl_ip_fwd_update_priority); + ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + call_netevent_notifiers(NETEVENT_IPV4_FWD_UPDATE_PRIORITY_UPDATE, + net); + + return ret; +} + +static int proc_tcp_congestion_control(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(ctl->data, struct net, + ipv4.tcp_congestion_control); + char val[TCP_CA_NAME_MAX]; + struct ctl_table tbl = { + .data = val, + .maxlen = TCP_CA_NAME_MAX, + }; + int ret; + + tcp_get_default_congestion_control(net, val); + + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) + ret = tcp_set_default_congestion_control(net, val); + return ret; +} + +static int proc_tcp_available_congestion_control(struct ctl_table *ctl, + int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + struct ctl_table tbl = { .maxlen = TCP_CA_BUF_MAX, }; + int ret; + + tbl.data = kmalloc(tbl.maxlen, GFP_USER); + if (!tbl.data) + return -ENOMEM; + tcp_get_available_congestion_control(tbl.data, TCP_CA_BUF_MAX); + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + kfree(tbl.data); + return ret; +} + +static int proc_allowed_congestion_control(struct ctl_table *ctl, + int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + struct ctl_table tbl = { .maxlen = TCP_CA_BUF_MAX }; + int ret; + + tbl.data = kmalloc(tbl.maxlen, GFP_USER); + if (!tbl.data) + return -ENOMEM; + + tcp_get_allowed_congestion_control(tbl.data, tbl.maxlen); + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) + ret = tcp_set_allowed_congestion_control(tbl.data); + kfree(tbl.data); + return ret; +} + +static int sscanf_key(char *buf, __le32 *key) +{ + u32 user_key[4]; + int i, ret = 0; + + if (sscanf(buf, "%x-%x-%x-%x", user_key, user_key + 1, + user_key + 2, user_key + 3) != 4) { + ret = -EINVAL; + } else { + for (i = 0; i < ARRAY_SIZE(user_key); i++) + key[i] = cpu_to_le32(user_key[i]); + } + pr_debug("proc TFO key set 0x%x-%x-%x-%x <- 0x%s: %u\n", + user_key[0], user_key[1], user_key[2], user_key[3], buf, ret); + + return ret; +} + +static int proc_tcp_fastopen_key(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_tcp_fastopen); + /* maxlen 
to print the list of keys in hex (*2), with dashes + * separating doublewords and a comma in between keys. + */ + struct ctl_table tbl = { .maxlen = ((TCP_FASTOPEN_KEY_LENGTH * + 2 * TCP_FASTOPEN_KEY_MAX) + + (TCP_FASTOPEN_KEY_MAX * 5)) }; + u32 user_key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u32)]; + __le32 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(__le32)]; + char *backup_data; + int ret, i = 0, off = 0, n_keys; + + tbl.data = kmalloc(tbl.maxlen, GFP_KERNEL); + if (!tbl.data) + return -ENOMEM; + + n_keys = tcp_fastopen_get_cipher(net, NULL, (u64 *)key); + if (!n_keys) { + memset(&key[0], 0, TCP_FASTOPEN_KEY_LENGTH); + n_keys = 1; + } + + for (i = 0; i < n_keys * 4; i++) + user_key[i] = le32_to_cpu(key[i]); + + for (i = 0; i < n_keys; i++) { + off += snprintf(tbl.data + off, tbl.maxlen - off, + "%08x-%08x-%08x-%08x", + user_key[i * 4], + user_key[i * 4 + 1], + user_key[i * 4 + 2], + user_key[i * 4 + 3]); + + if (WARN_ON_ONCE(off >= tbl.maxlen - 1)) + break; + + if (i + 1 < n_keys) + off += snprintf(tbl.data + off, tbl.maxlen - off, ","); + } + + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + + if (write && ret == 0) { + backup_data = strchr(tbl.data, ','); + if (backup_data) { + *backup_data = '\0'; + backup_data++; + } + if (sscanf_key(tbl.data, key)) { + ret = -EINVAL; + goto bad_key; + } + if (backup_data) { + if (sscanf_key(backup_data, key + 4)) { + ret = -EINVAL; + goto bad_key; + } + } + tcp_fastopen_reset_cipher(net, NULL, key, + backup_data ? key + 4 : NULL); + } + +bad_key: + kfree(tbl.data); + return ret; +} + +static int proc_tfo_blackhole_detect_timeout(struct ctl_table *table, + int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_tcp_fastopen_blackhole_timeout); + int ret; + + ret = proc_dointvec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + atomic_set(&net->ipv4.tfo_active_disable_times, 0); + + return ret; +} + +static int proc_tcp_available_ulp(struct ctl_table *ctl, + int write, void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct ctl_table tbl = { .maxlen = TCP_ULP_BUF_MAX, }; + int ret; + + tbl.data = kmalloc(tbl.maxlen, GFP_USER); + if (!tbl.data) + return -ENOMEM; + tcp_get_available_ulp(tbl.data, TCP_ULP_BUF_MAX); + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + kfree(tbl.data); + + return ret; +} + +static int proc_tcp_ehash_entries(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_tcp_child_ehash_entries); + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + int tcp_ehash_entries; + struct ctl_table tbl; + + tcp_ehash_entries = hinfo->ehash_mask + 1; + + /* A negative number indicates that the child netns + * shares the global ehash. 
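Editor's aside: the maxlen expression used by proc_tcp_fastopen_key() above works out as follows, assuming the mainline values TCP_FASTOPEN_KEY_LENGTH == 16 and TCP_FASTOPEN_KEY_MAX == 2: each 16-byte key prints as 32 hex digits, and each key needs 5 more bytes for its three '-' separators plus a ',' or terminating NUL, giving (16 * 2 * 2) + (2 * 5) = 74 bytes. A worked form of that arithmetic, with the two macro values stated as assumptions:

/* Worked form of the fastopen-key buffer size; the two macro values are
 * assumptions matching mainline at the time of writing.
 */
#define TCP_FASTOPEN_KEY_LENGTH	16	/* one siphash key        */
#define TCP_FASTOPEN_KEY_MAX	2	/* primary + backup key   */

enum {
	/* 32 hex digits per key + 5 separator bytes per key = 74 */
	TFO_KEY_PROC_BUFLEN = (TCP_FASTOPEN_KEY_LENGTH * 2 * TCP_FASTOPEN_KEY_MAX)
			      + (TCP_FASTOPEN_KEY_MAX * 5),
};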
+ */ + if (!net_eq(net, &init_net) && !hinfo->pernet) + tcp_ehash_entries *= -1; + + tbl.data = &tcp_ehash_entries; + tbl.maxlen = sizeof(int); + + return proc_dointvec(&tbl, write, buffer, lenp, ppos); +} + +#ifdef CONFIG_IP_ROUTE_MULTIPATH +static int proc_fib_multipath_hash_policy(struct ctl_table *table, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct net *net = container_of(table->data, struct net, + ipv4.sysctl_fib_multipath_hash_policy); + int ret; + + ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + call_netevent_notifiers(NETEVENT_IPV4_MPATH_HASH_UPDATE, net); + + return ret; +} + +static int proc_fib_multipath_hash_fields(struct ctl_table *table, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + struct net *net; + int ret; + + net = container_of(table->data, struct net, + ipv4.sysctl_fib_multipath_hash_fields); + ret = proc_douintvec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + call_netevent_notifiers(NETEVENT_IPV4_MPATH_HASH_UPDATE, net); + + return ret; +} +#endif + +static struct ctl_table ipv4_table[] = { + { + .procname = "tcp_max_orphans", + .data = &sysctl_tcp_max_orphans, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "inet_peer_threshold", + .data = &inet_peer_threshold, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "inet_peer_minttl", + .data = &inet_peer_minttl, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "inet_peer_maxttl", + .data = &inet_peer_maxttl, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "tcp_mem", + .maxlen = sizeof(sysctl_tcp_mem), + .data = &sysctl_tcp_mem, + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "tcp_low_latency", + .data = &sysctl_tcp_low_latency, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, +#ifdef CONFIG_NETLABEL + { + .procname = "cipso_cache_enable", + .data = &cipso_v4_cache_enabled, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "cipso_cache_bucket_size", + .data = &cipso_v4_cache_bucketsize, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "cipso_rbm_optfmt", + .data = &cipso_v4_rbm_optfmt, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "cipso_rbm_strictvalid", + .data = &cipso_v4_rbm_strictvalid, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif /* CONFIG_NETLABEL */ + { + .procname = "tcp_available_ulp", + .maxlen = TCP_ULP_BUF_MAX, + .mode = 0444, + .proc_handler = proc_tcp_available_ulp, + }, + { + .procname = "icmp_msgs_per_sec", + .data = &sysctl_icmp_msgs_per_sec, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "icmp_msgs_burst", + .data = &sysctl_icmp_msgs_burst, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "udp_mem", + .data = &sysctl_udp_mem, + .maxlen = sizeof(sysctl_udp_mem), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "fib_sync_mem", + .data = &sysctl_fib_sync_mem, + .maxlen = sizeof(sysctl_fib_sync_mem), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = 
&sysctl_fib_sync_mem_min, + .extra2 = &sysctl_fib_sync_mem_max, + }, + { } +}; + +static struct ctl_table ipv4_net_table[] = { + { + .procname = "tcp_max_tw_buckets", + .data = &init_net.ipv4.tcp_death_row.sysctl_max_tw_buckets, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "icmp_echo_ignore_all", + .data = &init_net.ipv4.sysctl_icmp_echo_ignore_all, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_echo_enable_probe", + .data = &init_net.ipv4.sysctl_icmp_echo_enable_probe, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_echo_ignore_broadcasts", + .data = &init_net.ipv4.sysctl_icmp_echo_ignore_broadcasts, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_ignore_bogus_error_responses", + .data = &init_net.ipv4.sysctl_icmp_ignore_bogus_error_responses, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_errors_use_inbound_ifaddr", + .data = &init_net.ipv4.sysctl_icmp_errors_use_inbound_ifaddr, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "icmp_ratelimit", + .data = &init_net.ipv4.sysctl_icmp_ratelimit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "icmp_ratemask", + .data = &init_net.ipv4.sysctl_icmp_ratemask, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "ping_group_range", + .data = &init_net.ipv4.ping_group_range.range, + .maxlen = sizeof(gid_t)*2, + .mode = 0644, + .proc_handler = ipv4_ping_group_range, + }, +#ifdef CONFIG_NET_L3_MASTER_DEV + { + .procname = "raw_l3mdev_accept", + .data = &init_net.ipv4.sysctl_raw_l3mdev_accept, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif + { + .procname = "tcp_ecn", + .data = &init_net.ipv4.sysctl_tcp_ecn, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "tcp_ecn_fallback", + .data = &init_net.ipv4.sysctl_tcp_ecn_fallback, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "ip_dynaddr", + .data = &init_net.ipv4.sysctl_ip_dynaddr, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ip_early_demux", + .data = &init_net.ipv4.sysctl_ip_early_demux, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "udp_early_demux", + .data = &init_net.ipv4.sysctl_udp_early_demux, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_early_demux", + .data = &init_net.ipv4.sysctl_tcp_early_demux, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "nexthop_compat_mode", + .data = &init_net.ipv4.sysctl_nexthop_compat_mode, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + 
.extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "ip_default_ttl", + .data = &init_net.ipv4.sysctl_ip_default_ttl, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = &ip_ttl_min, + .extra2 = &ip_ttl_max, + }, + { + .procname = "ip_local_port_range", + .maxlen = sizeof(init_net.ipv4.ip_local_ports.range), + .data = &init_net.ipv4.ip_local_ports.range, + .mode = 0644, + .proc_handler = ipv4_local_port_range, + }, + { + .procname = "ip_local_reserved_ports", + .data = &init_net.ipv4.sysctl_local_reserved_ports, + .maxlen = 65536, + .mode = 0644, + .proc_handler = proc_do_large_bitmap, + }, + { + .procname = "ip_no_pmtu_disc", + .data = &init_net.ipv4.sysctl_ip_no_pmtu_disc, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ip_forward_use_pmtu", + .data = &init_net.ipv4.sysctl_ip_fwd_use_pmtu, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ip_forward_update_priority", + .data = &init_net.ipv4.sysctl_ip_fwd_update_priority, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = ipv4_fwd_update_priority, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "ip_nonlocal_bind", + .data = &init_net.ipv4.sysctl_ip_nonlocal_bind, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ip_autobind_reuse", + .data = &init_net.ipv4.sysctl_ip_autobind_reuse, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "fwmark_reflect", + .data = &init_net.ipv4.sysctl_fwmark_reflect, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_fwmark_accept", + .data = &init_net.ipv4.sysctl_tcp_fwmark_accept, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, +#ifdef CONFIG_NET_L3_MASTER_DEV + { + .procname = "tcp_l3mdev_accept", + .data = &init_net.ipv4.sysctl_tcp_l3mdev_accept, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif + { + .procname = "tcp_mtu_probing", + .data = &init_net.ipv4.sysctl_tcp_mtu_probing, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_base_mss", + .data = &init_net.ipv4.sysctl_tcp_base_mss, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "tcp_min_snd_mss", + .data = &init_net.ipv4.sysctl_tcp_min_snd_mss, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &tcp_min_snd_mss_min, + .extra2 = &tcp_min_snd_mss_max, + }, + { + .procname = "tcp_mtu_probe_floor", + .data = &init_net.ipv4.sysctl_tcp_mtu_probe_floor, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &tcp_min_snd_mss_min, + .extra2 = &tcp_min_snd_mss_max, + }, + { + .procname = "tcp_probe_threshold", + .data = &init_net.ipv4.sysctl_tcp_probe_threshold, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "tcp_probe_interval", + .data = &init_net.ipv4.sysctl_tcp_probe_interval, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra2 = &u32_max_div_HZ, + }, + { + .procname = "igmp_link_local_mcast_reports", + .data = &init_net.ipv4.sysctl_igmp_llm_reports, + .maxlen = 
sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "igmp_max_memberships", + .data = &init_net.ipv4.sysctl_igmp_max_memberships, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "igmp_max_msf", + .data = &init_net.ipv4.sysctl_igmp_max_msf, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, +#ifdef CONFIG_IP_MULTICAST + { + .procname = "igmp_qrv", + .data = &init_net.ipv4.sysctl_igmp_qrv, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE + }, +#endif + { + .procname = "tcp_congestion_control", + .data = &init_net.ipv4.tcp_congestion_control, + .mode = 0644, + .maxlen = TCP_CA_NAME_MAX, + .proc_handler = proc_tcp_congestion_control, + }, + { + .procname = "tcp_available_congestion_control", + .maxlen = TCP_CA_BUF_MAX, + .mode = 0444, + .proc_handler = proc_tcp_available_congestion_control, + }, + { + .procname = "tcp_allowed_congestion_control", + .maxlen = TCP_CA_BUF_MAX, + .mode = 0644, + .proc_handler = proc_allowed_congestion_control, + }, + { + .procname = "tcp_keepalive_time", + .data = &init_net.ipv4.sysctl_tcp_keepalive_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "tcp_keepalive_probes", + .data = &init_net.ipv4.sysctl_tcp_keepalive_probes, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_keepalive_intvl", + .data = &init_net.ipv4.sysctl_tcp_keepalive_intvl, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "tcp_syn_retries", + .data = &init_net.ipv4.sysctl_tcp_syn_retries, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = &tcp_syn_retries_min, + .extra2 = &tcp_syn_retries_max + }, + { + .procname = "tcp_synack_retries", + .data = &init_net.ipv4.sysctl_tcp_synack_retries, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, +#ifdef CONFIG_SYN_COOKIES + { + .procname = "tcp_syncookies", + .data = &init_net.ipv4.sysctl_tcp_syncookies, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, +#endif + { + .procname = "tcp_migrate_req", + .data = &init_net.ipv4.sysctl_tcp_migrate_req, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "tcp_reordering", + .data = &init_net.ipv4.sysctl_tcp_reordering, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tcp_retries1", + .data = &init_net.ipv4.sysctl_tcp_retries1, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &tcp_retr1_max + }, + { + .procname = "tcp_retries2", + .data = &init_net.ipv4.sysctl_tcp_retries2, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_orphan_retries", + .data = &init_net.ipv4.sysctl_tcp_orphan_retries, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_fin_timeout", + .data = &init_net.ipv4.sysctl_tcp_fin_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "tcp_notsent_lowat", + .data = &init_net.ipv4.sysctl_tcp_notsent_lowat, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec, + }, + { + 
.procname = "tcp_tw_reuse", + .data = &init_net.ipv4.sysctl_tcp_tw_reuse, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "tcp_max_syn_backlog", + .data = &init_net.ipv4.sysctl_max_syn_backlog, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tcp_fastopen", + .data = &init_net.ipv4.sysctl_tcp_fastopen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "tcp_fastopen_key", + .mode = 0600, + .data = &init_net.ipv4.sysctl_tcp_fastopen, + /* maxlen to print the list of keys in hex (*2), with dashes + * separating doublewords and a comma in between keys. + */ + .maxlen = ((TCP_FASTOPEN_KEY_LENGTH * + 2 * TCP_FASTOPEN_KEY_MAX) + + (TCP_FASTOPEN_KEY_MAX * 5)), + .proc_handler = proc_tcp_fastopen_key, + }, + { + .procname = "tcp_fastopen_blackhole_timeout_sec", + .data = &init_net.ipv4.sysctl_tcp_fastopen_blackhole_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_tfo_blackhole_detect_timeout, + .extra1 = SYSCTL_ZERO, + }, +#ifdef CONFIG_IP_ROUTE_MULTIPATH + { + .procname = "fib_multipath_use_neigh", + .data = &init_net.ipv4.sysctl_fib_multipath_use_neigh, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "fib_multipath_hash_policy", + .data = &init_net.ipv4.sysctl_fib_multipath_hash_policy, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_fib_multipath_hash_policy, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_THREE, + }, + { + .procname = "fib_multipath_hash_fields", + .data = &init_net.ipv4.sysctl_fib_multipath_hash_fields, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_fib_multipath_hash_fields, + .extra1 = SYSCTL_ONE, + .extra2 = &fib_multipath_hash_fields_all_mask, + }, +#endif + { + .procname = "ip_unprivileged_port_start", + .maxlen = sizeof(int), + .data = &init_net.ipv4.sysctl_ip_prot_sock, + .mode = 0644, + .proc_handler = ipv4_privileged_ports, + }, +#ifdef CONFIG_NET_L3_MASTER_DEV + { + .procname = "udp_l3mdev_accept", + .data = &init_net.ipv4.sysctl_udp_l3mdev_accept, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif + { + .procname = "tcp_sack", + .data = &init_net.ipv4.sysctl_tcp_sack, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_window_scaling", + .data = &init_net.ipv4.sysctl_tcp_window_scaling, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_timestamps", + .data = &init_net.ipv4.sysctl_tcp_timestamps, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_early_retrans", + .data = &init_net.ipv4.sysctl_tcp_early_retrans, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_FOUR, + }, + { + .procname = "tcp_recovery", + .data = &init_net.ipv4.sysctl_tcp_recovery, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_thin_linear_timeouts", + .data = &init_net.ipv4.sysctl_tcp_thin_linear_timeouts, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_slow_start_after_idle", + .data = 
&init_net.ipv4.sysctl_tcp_slow_start_after_idle, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_retrans_collapse", + .data = &init_net.ipv4.sysctl_tcp_retrans_collapse, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_stdurg", + .data = &init_net.ipv4.sysctl_tcp_stdurg, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_rfc1337", + .data = &init_net.ipv4.sysctl_tcp_rfc1337, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_abort_on_overflow", + .data = &init_net.ipv4.sysctl_tcp_abort_on_overflow, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_fack", + .data = &init_net.ipv4.sysctl_tcp_fack, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_max_reordering", + .data = &init_net.ipv4.sysctl_tcp_max_reordering, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tcp_dsack", + .data = &init_net.ipv4.sysctl_tcp_dsack, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_app_win", + .data = &init_net.ipv4.sysctl_tcp_app_win, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_app_win_max, + }, + { + .procname = "tcp_adv_win_scale", + .data = &init_net.ipv4.sysctl_tcp_adv_win_scale, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &tcp_adv_win_scale_min, + .extra2 = &tcp_adv_win_scale_max, + }, + { + .procname = "tcp_frto", + .data = &init_net.ipv4.sysctl_tcp_frto, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_no_metrics_save", + .data = &init_net.ipv4.sysctl_tcp_nometrics_save, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_no_ssthresh_metrics_save", + .data = &init_net.ipv4.sysctl_tcp_no_ssthresh_metrics_save, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_moderate_rcvbuf", + .data = &init_net.ipv4.sysctl_tcp_moderate_rcvbuf, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_tso_win_divisor", + .data = &init_net.ipv4.sysctl_tcp_tso_win_divisor, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_workaround_signed_windows", + .data = &init_net.ipv4.sysctl_tcp_workaround_signed_windows, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_limit_output_bytes", + .data = &init_net.ipv4.sysctl_tcp_limit_output_bytes, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tcp_challenge_ack_limit", + .data = &init_net.ipv4.sysctl_tcp_challenge_ack_limit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "tcp_min_tso_segs", + .data = &init_net.ipv4.sysctl_tcp_min_tso_segs, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "tcp_tso_rtt_log", + .data = &init_net.ipv4.sysctl_tcp_tso_rtt_log, + .maxlen = 
sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "tcp_min_rtt_wlen", + .data = &init_net.ipv4.sysctl_tcp_min_rtt_wlen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &one_day_secs + }, + { + .procname = "tcp_autocorking", + .data = &init_net.ipv4.sysctl_tcp_autocorking, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_invalid_ratelimit", + .data = &init_net.ipv4.sysctl_tcp_invalid_ratelimit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "tcp_pacing_ss_ratio", + .data = &init_net.ipv4.sysctl_tcp_pacing_ss_ratio, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE_THOUSAND, + }, + { + .procname = "tcp_pacing_ca_ratio", + .data = &init_net.ipv4.sysctl_tcp_pacing_ca_ratio, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE_THOUSAND, + }, + { + .procname = "tcp_wmem", + .data = &init_net.ipv4.sysctl_tcp_wmem, + .maxlen = sizeof(init_net.ipv4.sysctl_tcp_wmem), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "tcp_rmem", + .data = &init_net.ipv4.sysctl_tcp_rmem, + .maxlen = sizeof(init_net.ipv4.sysctl_tcp_rmem), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "tcp_comp_sack_delay_ns", + .data = &init_net.ipv4.sysctl_tcp_comp_sack_delay_ns, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "tcp_comp_sack_slack_ns", + .data = &init_net.ipv4.sysctl_tcp_comp_sack_slack_ns, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "tcp_comp_sack_nr", + .data = &init_net.ipv4.sysctl_tcp_comp_sack_nr, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "tcp_reflect_tos", + .data = &init_net.ipv4.sysctl_tcp_reflect_tos, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "tcp_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .mode = 0444, + .proc_handler = proc_tcp_ehash_entries, + }, + { + .procname = "tcp_child_ehash_entries", + .data = &init_net.ipv4.sysctl_tcp_child_ehash_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &tcp_child_ehash_entries_max, + }, + { + .procname = "udp_rmem_min", + .data = &init_net.ipv4.sysctl_udp_rmem_min, + .maxlen = sizeof(init_net.ipv4.sysctl_udp_rmem_min), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE + }, + { + .procname = "udp_wmem_min", + .data = &init_net.ipv4.sysctl_udp_wmem_min, + .maxlen = sizeof(init_net.ipv4.sysctl_udp_wmem_min), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE + }, + { + .procname = "fib_notify_on_flag_change", + .data = &init_net.ipv4.sysctl_fib_notify_on_flag_change, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "tcp_shrink_window", + 
.data = &init_net.ipv4.sysctl_tcp_shrink_window, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { } +}; + +static __net_init int ipv4_sysctl_init_net(struct net *net) +{ + struct ctl_table *table; + + table = ipv4_net_table; + if (!net_eq(net, &init_net)) { + int i; + + table = kmemdup(table, sizeof(ipv4_net_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + for (i = 0; i < ARRAY_SIZE(ipv4_net_table) - 1; i++) { + if (table[i].data) { + /* Update the variables to point into + * the current struct net + */ + table[i].data += (void *)net - (void *)&init_net; + } else { + /* Entries without data pointer are global; + * Make them read-only in non-init_net ns + */ + table[i].mode &= ~0222; + } + } + } + + net->ipv4.ipv4_hdr = register_net_sysctl(net, "net/ipv4", table); + if (!net->ipv4.ipv4_hdr) + goto err_reg; + + net->ipv4.sysctl_local_reserved_ports = kzalloc(65536 / 8, GFP_KERNEL); + if (!net->ipv4.sysctl_local_reserved_ports) + goto err_ports; + + return 0; + +err_ports: + unregister_net_sysctl_table(net->ipv4.ipv4_hdr); +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static __net_exit void ipv4_sysctl_exit_net(struct net *net) +{ + struct ctl_table *table; + + kfree(net->ipv4.sysctl_local_reserved_ports); + table = net->ipv4.ipv4_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv4.ipv4_hdr); + kfree(table); +} + +static __net_initdata struct pernet_operations ipv4_sysctl_ops = { + .init = ipv4_sysctl_init_net, + .exit = ipv4_sysctl_exit_net, +}; + +static __init int sysctl_ipv4_init(void) +{ + struct ctl_table_header *hdr; + + hdr = register_net_sysctl(&init_net, "net/ipv4", ipv4_table); + if (!hdr) + return -ENOMEM; + + if (register_pernet_subsys(&ipv4_sysctl_ops)) { + unregister_net_sysctl_table(hdr); + return -ENOMEM; + } + + return 0; +} + +__initcall(sysctl_ipv4_init); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c new file mode 100644 index 000000000..90e24c3f6 --- /dev/null +++ b/net/ipv4/tcp.c @@ -0,0 +1,4864 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Corey Minyard + * Florian La Roche, + * Charles Hedrick, + * Linus Torvalds, + * Alan Cox, + * Matthew Dillon, + * Arnt Gulbrandsen, + * Jorge Cwik, + * + * Fixes: + * Alan Cox : Numerous verify_area() calls + * Alan Cox : Set the ACK bit on a reset + * Alan Cox : Stopped it crashing if it closed while + * sk->inuse=1 and was trying to connect + * (tcp_err()). + * Alan Cox : All icmp error handling was broken + * pointers passed where wrong and the + * socket was looked up backwards. Nobody + * tested any icmp error code obviously. + * Alan Cox : tcp_err() now handled properly. It + * wakes people on errors. poll + * behaves and the icmp error race + * has gone by moving it into sock.c + * Alan Cox : tcp_send_reset() fixed to work for + * everything not just packets for + * unknown sockets. + * Alan Cox : tcp option processing. + * Alan Cox : Reset tweaked (still not 100%) [Had + * syn rule wrong] + * Herp Rosmanith : More reset fixes + * Alan Cox : No longer acks invalid rst frames. + * Acking any kind of RST is right out. 
+ * Alan Cox : Sets an ignore me flag on an rst + * receive otherwise odd bits of prattle + * escape still + * Alan Cox : Fixed another acking RST frame bug. + * Should stop LAN workplace lockups. + * Alan Cox : Some tidyups using the new skb list + * facilities + * Alan Cox : sk->keepopen now seems to work + * Alan Cox : Pulls options out correctly on accepts + * Alan Cox : Fixed assorted sk->rqueue->next errors + * Alan Cox : PSH doesn't end a TCP read. Switched a + * bit to skb ops. + * Alan Cox : Tidied tcp_data to avoid a potential + * nasty. + * Alan Cox : Added some better commenting, as the + * tcp is hard to follow + * Alan Cox : Removed incorrect check for 20 * psh + * Michael O'Reilly : ack < copied bug fix. + * Johannes Stille : Misc tcp fixes (not all in yet). + * Alan Cox : FIN with no memory -> CRASH + * Alan Cox : Added socket option proto entries. + * Also added awareness of them to accept. + * Alan Cox : Added TCP options (SOL_TCP) + * Alan Cox : Switched wakeup calls to callbacks, + * so the kernel can layer network + * sockets. + * Alan Cox : Use ip_tos/ip_ttl settings. + * Alan Cox : Handle FIN (more) properly (we hope). + * Alan Cox : RST frames sent on unsynchronised + * state ack error. + * Alan Cox : Put in missing check for SYN bit. + * Alan Cox : Added tcp_select_window() aka NET2E + * window non shrink trick. + * Alan Cox : Added a couple of small NET2E timer + * fixes + * Charles Hedrick : TCP fixes + * Toomas Tamm : TCP window fixes + * Alan Cox : Small URG fix to rlogin ^C ack fight + * Charles Hedrick : Rewrote most of it to actually work + * Linus : Rewrote tcp_read() and URG handling + * completely + * Gerhard Koerting: Fixed some missing timer handling + * Matthew Dillon : Reworked TCP machine states as per RFC + * Gerhard Koerting: PC/TCP workarounds + * Adam Caldwell : Assorted timer/timing errors + * Matthew Dillon : Fixed another RST bug + * Alan Cox : Move to kernel side addressing changes. + * Alan Cox : Beginning work on TCP fastpathing + * (not yet usable) + * Arnt Gulbrandsen: Turbocharged tcp_check() routine. + * Alan Cox : TCP fast path debugging + * Alan Cox : Window clamping + * Michael Riepe : Bug in tcp_check() + * Matt Dillon : More TCP improvements and RST bug fixes + * Matt Dillon : Yet more small nasties remove from the + * TCP code (Be very nice to this man if + * tcp finally works 100%) 8) + * Alan Cox : BSD accept semantics. + * Alan Cox : Reset on closedown bug. + * Peter De Schrijver : ENOTCONN check missing in tcp_sendto(). + * Michael Pall : Handle poll() after URG properly in + * all cases. + * Michael Pall : Undo the last fix in tcp_read_urg() + * (multi URG PUSH broke rlogin). + * Michael Pall : Fix the multi URG PUSH problem in + * tcp_readable(), poll() after URG + * works now. + * Michael Pall : recv(...,MSG_OOB) never blocks in the + * BSD api. + * Alan Cox : Changed the semantics of sk->socket to + * fix a race and a signal problem with + * accept() and async I/O. + * Alan Cox : Relaxed the rules on tcp_sendto(). + * Yury Shevchuk : Really fixed accept() blocking problem. + * Craig I. Hagan : Allow for BSD compatible TIME_WAIT for + * clients/servers which listen in on + * fixed ports. + * Alan Cox : Cleaned the above up and shrank it to + * a sensible code size. + * Alan Cox : Self connect lockup fix. + * Alan Cox : No connect to multicast. + * Ross Biro : Close unaccepted children on master + * socket close. + * Alan Cox : Reset tracing code. + * Alan Cox : Spurious resets on shutdown. 
+ * Alan Cox : Giant 15 minute/60 second timer error + * Alan Cox : Small whoops in polling before an + * accept. + * Alan Cox : Kept the state trace facility since + * it's handy for debugging. + * Alan Cox : More reset handler fixes. + * Alan Cox : Started rewriting the code based on + * the RFC's for other useful protocol + * references see: Comer, KA9Q NOS, and + * for a reference on the difference + * between specifications and how BSD + * works see the 4.4lite source. + * A.N.Kuznetsov : Don't time wait on completion of tidy + * close. + * Linus Torvalds : Fin/Shutdown & copied_seq changes. + * Linus Torvalds : Fixed BSD port reuse to work first syn + * Alan Cox : Reimplemented timers as per the RFC + * and using multiple timers for sanity. + * Alan Cox : Small bug fixes, and a lot of new + * comments. + * Alan Cox : Fixed dual reader crash by locking + * the buffers (much like datagram.c) + * Alan Cox : Fixed stuck sockets in probe. A probe + * now gets fed up of retrying without + * (even a no space) answer. + * Alan Cox : Extracted closing code better + * Alan Cox : Fixed the closing state machine to + * resemble the RFC. + * Alan Cox : More 'per spec' fixes. + * Jorge Cwik : Even faster checksumming. + * Alan Cox : tcp_data() doesn't ack illegal PSH + * only frames. At least one pc tcp stack + * generates them. + * Alan Cox : Cache last socket. + * Alan Cox : Per route irtt. + * Matt Day : poll()->select() match BSD precisely on error + * Alan Cox : New buffers + * Marc Tamsky : Various sk->prot->retransmits and + * sk->retransmits misupdating fixed. + * Fixed tcp_write_timeout: stuck close, + * and TCP syn retries gets used now. + * Mark Yarvis : In tcp_read_wakeup(), don't send an + * ack if state is TCP_CLOSED. + * Alan Cox : Look up device on a retransmit - routes may + * change. Doesn't yet cope with MSS shrink right + * but it's a start! + * Marc Tamsky : Closing in closing fixes. + * Mike Shaver : RFC1122 verifications. + * Alan Cox : rcv_saddr errors. + * Alan Cox : Block double connect(). + * Alan Cox : Small hooks for enSKIP. + * Alexey Kuznetsov: Path MTU discovery. + * Alan Cox : Support soft errors. + * Alan Cox : Fix MTU discovery pathological case + * when the remote claims no mtu! + * Marc Tamsky : TCP_CLOSE fix. + * Colin (G3TNE) : Send a reset on syn ack replies in + * window but wrong (fixes NT lpd problems) + * Pedro Roque : Better TCP window handling, delayed ack. + * Joerg Reuter : No modification of locked buffers in + * tcp_do_retransmit() + * Eric Schenk : Changed receiver side silly window + * avoidance algorithm to BSD style + * algorithm. This doubles throughput + * against machines running Solaris, + * and seems to result in general + * improvement. + * Stefan Magdalinski : adjusted tcp_readable() to fix FIONREAD + * Willy Konynenberg : Transparent proxying support. + * Mike McLagan : Routing by source + * Keith Owens : Do proper merging with partial SKB's in + * tcp_do_sendmsg to avoid burstiness. + * Eric Schenk : Fix fast close down bug with + * shutdown() followed by close(). + * Andi Kleen : Make poll agree with SIGIO + * Salvatore Sanfilippo : Support SO_LINGER with linger == 1 and + * lingertime == 0 (RFC 793 ABORT Call) + * Hirokazu Takahashi : Use copy_from_user() instead of + * csum_and_copy_from_user() if possible. + * + * Description of States: + * + * TCP_SYN_SENT sent a connection request, waiting for ack + * + * TCP_SYN_RECV received a connection request, sent ack, + * waiting for final ack in three-way handshake. 
+ * + * TCP_ESTABLISHED connection established + * + * TCP_FIN_WAIT1 our side has shutdown, waiting to complete + * transmission of remaining buffered data + * + * TCP_FIN_WAIT2 all buffered data sent, waiting for remote + * to shutdown + * + * TCP_CLOSING both sides have shutdown but we still have + * data we have to finish sending + * + * TCP_TIME_WAIT timeout to catch resent junk before entering + * closed, can only be entered from FIN_WAIT2 + * or CLOSING. Required because the other end + * may not have gotten our last ACK causing it + * to retransmit the data packet (which we ignore) + * + * TCP_CLOSE_WAIT remote side has shutdown and is waiting for + * us to finish writing our data and to shutdown + * (we have to close() to move on to LAST_ACK) + * + * TCP_LAST_ACK out side has shutdown after remote has + * shutdown. There may still be data in our + * buffer that we have to finish sending + * + * TCP_CLOSE socket is finished + */ + +#define pr_fmt(fmt) "TCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +/* Track pending CMSGs. */ +enum { + TCP_CMSG_INQ = 1, + TCP_CMSG_TS = 2 +}; + +DEFINE_PER_CPU(unsigned int, tcp_orphan_count); +EXPORT_PER_CPU_SYMBOL_GPL(tcp_orphan_count); + +long sysctl_tcp_mem[3] __read_mostly; +EXPORT_SYMBOL(sysctl_tcp_mem); + +atomic_long_t tcp_memory_allocated ____cacheline_aligned_in_smp; /* Current allocated memory. */ +EXPORT_SYMBOL(tcp_memory_allocated); +DEFINE_PER_CPU(int, tcp_memory_per_cpu_fw_alloc); +EXPORT_PER_CPU_SYMBOL_GPL(tcp_memory_per_cpu_fw_alloc); + +#if IS_ENABLED(CONFIG_SMC) +DEFINE_STATIC_KEY_FALSE(tcp_have_smc); +EXPORT_SYMBOL(tcp_have_smc); +#endif + +/* + * Current number of TCP sockets. + */ +struct percpu_counter tcp_sockets_allocated ____cacheline_aligned_in_smp; +EXPORT_SYMBOL(tcp_sockets_allocated); + +/* + * TCP splice context + */ +struct tcp_splice_state { + struct pipe_inode_info *pipe; + size_t len; + unsigned int flags; +}; + +/* + * Pressure flag: try to collapse. + * Technical note: it is used by multiple contexts non atomically. + * All the __sk_mem_schedule() is of this nature: accounting + * is strict, actions are advisory and have some latency. 
+ */ +unsigned long tcp_memory_pressure __read_mostly; +EXPORT_SYMBOL_GPL(tcp_memory_pressure); + +void tcp_enter_memory_pressure(struct sock *sk) +{ + unsigned long val; + + if (READ_ONCE(tcp_memory_pressure)) + return; + val = jiffies; + + if (!val) + val--; + if (!cmpxchg(&tcp_memory_pressure, 0, val)) + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMEMORYPRESSURES); +} +EXPORT_SYMBOL_GPL(tcp_enter_memory_pressure); + +void tcp_leave_memory_pressure(struct sock *sk) +{ + unsigned long val; + + if (!READ_ONCE(tcp_memory_pressure)) + return; + val = xchg(&tcp_memory_pressure, 0); + if (val) + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPMEMORYPRESSURESCHRONO, + jiffies_to_msecs(jiffies - val)); +} +EXPORT_SYMBOL_GPL(tcp_leave_memory_pressure); + +/* Convert seconds to retransmits based on initial and max timeout */ +static u8 secs_to_retrans(int seconds, int timeout, int rto_max) +{ + u8 res = 0; + + if (seconds > 0) { + int period = timeout; + + res = 1; + while (seconds > period && res < 255) { + res++; + timeout <<= 1; + if (timeout > rto_max) + timeout = rto_max; + period += timeout; + } + } + return res; +} + +/* Convert retransmits to seconds based on initial and max timeout */ +static int retrans_to_secs(u8 retrans, int timeout, int rto_max) +{ + int period = 0; + + if (retrans > 0) { + period = timeout; + while (--retrans) { + timeout <<= 1; + if (timeout > rto_max) + timeout = rto_max; + period += timeout; + } + } + return period; +} + +static u64 tcp_compute_delivery_rate(const struct tcp_sock *tp) +{ + u32 rate = READ_ONCE(tp->rate_delivered); + u32 intv = READ_ONCE(tp->rate_interval_us); + u64 rate64 = 0; + + if (rate && intv) { + rate64 = (u64)rate * tp->mss_cache * USEC_PER_SEC; + do_div(rate64, intv); + } + return rate64; +} + +/* Address-family independent initialization for a tcp_sock. + * + * NOTE: A lot of things set to zero explicitly by call to + * sk_alloc() so need not be done here. + */ +void tcp_init_sock(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + tp->out_of_order_queue = RB_ROOT; + sk->tcp_rtx_queue = RB_ROOT; + tcp_init_xmit_timers(sk); + INIT_LIST_HEAD(&tp->tsq_node); + INIT_LIST_HEAD(&tp->tsorted_sent_queue); + + icsk->icsk_rto = TCP_TIMEOUT_INIT; + icsk->icsk_rto_min = TCP_RTO_MIN; + icsk->icsk_delack_max = TCP_DELACK_MAX; + tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); + minmax_reset(&tp->rtt_min, tcp_jiffies32, ~0U); + + /* So many TCP implementations out there (incorrectly) count the + * initial SYN frame in their delayed-ACK and congestion control + * algorithms that we must have the following bandaid to talk + * efficiently to them. -DaveM + */ + tcp_snd_cwnd_set(tp, TCP_INIT_CWND); + + /* There's a bubble in the pipe until at least the first ACK. */ + tp->app_limited = ~0U; + tp->rate_app_limited = 1; + + /* See draft-stevens-tcpca-spec-01 for discussion of the + * initialization of these values. 
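Editor's aside: secs_to_retrans() and retrans_to_secs() above convert between a retransmission count and the wall-clock time that exponential backoff covers. Assuming a 1-second initial timeout and a 120-second RTO ceiling (the usual mainline constants with HZ factored out), 5 retransmits span 1 + 2 + 4 + 8 + 16 = 31 seconds, so the helpers round-trip as retrans_to_secs(5, 1, 120) == 31 and secs_to_retrans(31, 1, 120) == 5. A standalone copy of the conversion for experimentation:

#include <assert.h>

/* Standalone copy of retrans_to_secs() above, for trying the numbers. */
static int retrans_to_secs_demo(unsigned char retrans, int timeout, int rto_max)
{
	int period = 0;

	if (retrans > 0) {
		period = timeout;
		while (--retrans) {
			timeout <<= 1;
			if (timeout > rto_max)
				timeout = rto_max;
			period += timeout;
		}
	}
	return period;
}

int main(void)
{
	/* 1 + 2 + 4 + 8 + 16 = 31 seconds for 5 retransmits */
	assert(retrans_to_secs_demo(5, 1, 120) == 31);
	return 0;
}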
+ */ + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + tp->snd_cwnd_clamp = ~0; + tp->mss_cache = TCP_MSS_DEFAULT; + + tp->reordering = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reordering); + tcp_assign_congestion_control(sk); + + tp->tsoffset = 0; + tp->rack.reo_wnd_steps = 1; + + sk->sk_write_space = sk_stream_write_space; + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + + icsk->icsk_sync_mss = tcp_sync_mss; + + WRITE_ONCE(sk->sk_sndbuf, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1])); + WRITE_ONCE(sk->sk_rcvbuf, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1])); + + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + sk_sockets_allocated_inc(sk); +} +EXPORT_SYMBOL(tcp_init_sock); + +static void tcp_tx_timestamp(struct sock *sk, u16 tsflags) +{ + struct sk_buff *skb = tcp_write_queue_tail(sk); + + if (tsflags && skb) { + struct skb_shared_info *shinfo = skb_shinfo(skb); + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + + sock_tx_timestamp(sk, tsflags, &shinfo->tx_flags); + if (tsflags & SOF_TIMESTAMPING_TX_ACK) + tcb->txstamp_ack = 1; + if (tsflags & SOF_TIMESTAMPING_TX_RECORD_MASK) + shinfo->tskey = TCP_SKB_CB(skb)->seq + skb->len - 1; + } +} + +static bool tcp_stream_is_readable(struct sock *sk, int target) +{ + if (tcp_epollin_ready(sk, target)) + return true; + return sk_is_readable(sk); +} + +/* + * Wait for a TCP event. + * + * Note that we don't need to lock the socket, as the upper poll layers + * take care of normal races (between the test and the event) and we don't + * go look at any of the socket buffers directly. + */ +__poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + __poll_t mask; + struct sock *sk = sock->sk; + const struct tcp_sock *tp = tcp_sk(sk); + u8 shutdown; + int state; + + sock_poll_wait(file, sock, wait); + + state = inet_sk_state_load(sk); + if (state == TCP_LISTEN) + return inet_csk_listen_poll(sk); + + /* Socket is not locked. We are protected from async events + * by poll logic and correct handling of state changes + * made by other threads is impossible in any case. + */ + + mask = 0; + + /* + * EPOLLHUP is certainly not done right. But poll() doesn't + * have a notion of HUP in just one direction, and for a + * socket the read side is more interesting. + * + * Some poll() documentation says that EPOLLHUP is incompatible + * with the EPOLLOUT/POLLWR flags, so somebody should check this + * all. But careful, it tends to be safer to return too many + * bits than too few, and you can easily break real applications + * if you don't tell them that something has hung up! + * + * Check-me. + * + * Check number 1. EPOLLHUP is _UNMASKABLE_ event (see UNIX98 and + * our fs/select.c). It means that after we received EOF, + * poll always returns immediately, making impossible poll() on write() + * in state CLOSE_WAIT. One solution is evident --- to set EPOLLHUP + * if and only if shutdown has been made in both directions. + * Actually, it is interesting to look how Solaris and DUX + * solve this dilemma. I would prefer, if EPOLLHUP were maskable, + * then we could set it on SND_SHUTDOWN. BTW examples given + * in Stevens' books assume exactly this behaviour, it explains + * why EPOLLHUP is incompatible with EPOLLOUT. --ANK + * + * NOTE. Check for TCP_CLOSE is added. The goal is to prevent + * blocking on fresh not-connected or disconnected socket. 
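Editor's aside: the poll discussion above distinguishes a read-side hangup from a full hangup. Seen from user space, that is the difference between EPOLLRDHUP (the receive direction is shut down, typically because the peer sent a FIN) and EPOLLHUP (both directions are gone, or the socket is closed). A minimal sketch, assuming fd is an already-connected TCP socket and with error handling elided:

#define _GNU_SOURCE
#include <stdio.h>
#include <sys/epoll.h>
#include <unistd.h>

/* Watch a connected TCP socket and report which hangup condition fired. */
static void watch_peer_shutdown(int fd)
{
	struct epoll_event ev = { .events = EPOLLIN | EPOLLRDHUP, .data.fd = fd };
	struct epoll_event out;
	int ep = epoll_create1(0);

	epoll_ctl(ep, EPOLL_CTL_ADD, fd, &ev);
	if (epoll_wait(ep, &out, 1, -1) == 1) {
		if (out.events & EPOLLRDHUP)
			printf("receive side shut down (peer half-close)\n");
		if (out.events & EPOLLHUP)
			printf("both directions are shut down\n");
	}
	close(ep);
}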
--ANK + */ + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) + mask |= EPOLLHUP; + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; + + /* Connected or passive Fast Open socket? */ + if (state != TCP_SYN_SENT && + (state != TCP_SYN_RECV || rcu_access_pointer(tp->fastopen_rsk))) { + int target = sock_rcvlowat(sk, 0, INT_MAX); + u16 urg_data = READ_ONCE(tp->urg_data); + + if (unlikely(urg_data) && + READ_ONCE(tp->urg_seq) == READ_ONCE(tp->copied_seq) && + !sock_flag(sk, SOCK_URGINLINE)) + target++; + + if (tcp_stream_is_readable(sk, target)) + mask |= EPOLLIN | EPOLLRDNORM; + + if (!(shutdown & SEND_SHUTDOWN)) { + if (__sk_stream_is_writeable(sk, 1)) { + mask |= EPOLLOUT | EPOLLWRNORM; + } else { /* send SIGIO later */ + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + + /* Race breaker. If space is freed after + * wspace test but before the flags are set, + * IO signal will be lost. Memory barrier + * pairs with the input side. + */ + smp_mb__after_atomic(); + if (__sk_stream_is_writeable(sk, 1)) + mask |= EPOLLOUT | EPOLLWRNORM; + } + } else + mask |= EPOLLOUT | EPOLLWRNORM; + + if (urg_data & TCP_URG_VALID) + mask |= EPOLLPRI; + } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) { + /* Active TCP fastopen socket with defer_connect + * Return EPOLLOUT so application can call write() + * in order for kernel to generate SYN+data + */ + mask |= EPOLLOUT | EPOLLWRNORM; + } + /* This barrier is coupled with smp_wmb() in tcp_reset() */ + smp_rmb(); + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR; + + return mask; +} +EXPORT_SYMBOL(tcp_poll); + +int tcp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct tcp_sock *tp = tcp_sk(sk); + int answ; + bool slow; + + switch (cmd) { + case SIOCINQ: + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + slow = lock_sock_fast(sk); + answ = tcp_inq(sk); + unlock_sock_fast(sk, slow); + break; + case SIOCATMARK: + answ = READ_ONCE(tp->urg_data) && + READ_ONCE(tp->urg_seq) == READ_ONCE(tp->copied_seq); + break; + case SIOCOUTQ: + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) + answ = 0; + else + answ = READ_ONCE(tp->write_seq) - tp->snd_una; + break; + case SIOCOUTQNSD: + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) + answ = 0; + else + answ = READ_ONCE(tp->write_seq) - + READ_ONCE(tp->snd_nxt); + break; + default: + return -ENOIOCTLCMD; + } + + return put_user(answ, (int __user *)arg); +} +EXPORT_SYMBOL(tcp_ioctl); + +void tcp_mark_push(struct tcp_sock *tp, struct sk_buff *skb) +{ + TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_PSH; + tp->pushed_seq = tp->write_seq; +} + +static inline bool forced_push(const struct tcp_sock *tp) +{ + return after(tp->write_seq, tp->pushed_seq + (tp->max_window >> 1)); +} + +void tcp_skb_entail(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + + tcb->seq = tcb->end_seq = tp->write_seq; + tcb->tcp_flags = TCPHDR_ACK; + __skb_header_release(skb); + tcp_add_write_queue_tail(sk, skb); + sk_wmem_queued_add(sk, skb->truesize); + sk_mem_charge(sk, skb->truesize); + if (tp->nonagle & TCP_NAGLE_PUSH) + tp->nonagle &= ~TCP_NAGLE_PUSH; + + tcp_slow_start_after_idle_check(sk); +} + +static inline void tcp_mark_urg(struct tcp_sock *tp, int flags) +{ + if (flags & MSG_OOB) + 
tp->snd_up = tp->write_seq; +} + +/* If a not yet filled skb is pushed, do not send it if + * we have data packets in Qdisc or NIC queues : + * Because TX completion will happen shortly, it gives a chance + * to coalesce future sendmsg() payload into this skb, without + * need for a timer, and with no latency trade off. + * As packets containing data payload have a bigger truesize + * than pure acks (dataless) packets, the last checks prevent + * autocorking if we only have an ACK in Qdisc/NIC queues, + * or if TX completion was delayed after we processed ACK packet. + */ +static bool tcp_should_autocork(struct sock *sk, struct sk_buff *skb, + int size_goal) +{ + return skb->len < size_goal && + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_autocorking) && + !tcp_rtx_queue_empty(sk) && + refcount_read(&sk->sk_wmem_alloc) > skb->truesize && + tcp_skb_can_collapse_to(skb); +} + +void tcp_push(struct sock *sk, int flags, int mss_now, + int nonagle, int size_goal) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + + skb = tcp_write_queue_tail(sk); + if (!skb) + return; + if (!(flags & MSG_MORE) || forced_push(tp)) + tcp_mark_push(tp, skb); + + tcp_mark_urg(tp, flags); + + if (tcp_should_autocork(sk, skb, size_goal)) { + + /* avoid atomic op if TSQ_THROTTLED bit is already set */ + if (!test_bit(TSQ_THROTTLED, &sk->sk_tsq_flags)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPAUTOCORKING); + set_bit(TSQ_THROTTLED, &sk->sk_tsq_flags); + smp_mb__after_atomic(); + } + /* It is possible TX completion already happened + * before we set TSQ_THROTTLED. + */ + if (refcount_read(&sk->sk_wmem_alloc) > skb->truesize) + return; + } + + if (flags & MSG_MORE) + nonagle = TCP_NAGLE_CORK; + + __tcp_push_pending_frames(sk, mss_now, nonagle); +} + +static int tcp_splice_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, + unsigned int offset, size_t len) +{ + struct tcp_splice_state *tss = rd_desc->arg.data; + int ret; + + ret = skb_splice_bits(skb, skb->sk, offset, tss->pipe, + min(rd_desc->count, len), tss->flags); + if (ret > 0) + rd_desc->count -= ret; + return ret; +} + +static int __tcp_splice_read(struct sock *sk, struct tcp_splice_state *tss) +{ + /* Store TCP splice context information in read_descriptor_t. */ + read_descriptor_t rd_desc = { + .arg.data = tss, + .count = tss->len, + }; + + return tcp_read_sock(sk, &rd_desc, tcp_splice_data_recv); +} + +/** + * tcp_splice_read - splice data from TCP socket to a pipe + * @sock: socket to splice from + * @ppos: position (not valid) + * @pipe: pipe to splice to + * @len: number of bytes to splice + * @flags: splice modifier flags + * + * Description: + * Will read pages from given socket and fill them into a pipe. 
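/*
 * Illustrative userspace sketch (not part of this file): tcp_splice_read()
 * below is what backs splice(2) from a connected TCP socket into a pipe.
 *
 *	#define _GNU_SOURCE
 *	#include <fcntl.h>
 *	#include <unistd.h>
 *
 *	int p[2];
 *	pipe(p);
 *	// move up to 64 KiB of queued payload into the pipe without a user copy
 *	ssize_t n = splice(sock_fd, NULL, p[1], NULL, 65536, SPLICE_F_MOVE);
 */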
+ * + **/ +ssize_t tcp_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags) +{ + struct sock *sk = sock->sk; + struct tcp_splice_state tss = { + .pipe = pipe, + .len = len, + .flags = flags, + }; + long timeo; + ssize_t spliced; + int ret; + + sock_rps_record_flow(sk); + /* + * We can't seek on a socket input + */ + if (unlikely(*ppos)) + return -ESPIPE; + + ret = spliced = 0; + + lock_sock(sk); + + timeo = sock_rcvtimeo(sk, sock->file->f_flags & O_NONBLOCK); + while (tss.len) { + ret = __tcp_splice_read(sk, &tss); + if (ret < 0) + break; + else if (!ret) { + if (spliced) + break; + if (sock_flag(sk, SOCK_DONE)) + break; + if (sk->sk_err) { + ret = sock_error(sk); + break; + } + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + if (sk->sk_state == TCP_CLOSE) { + /* + * This occurs when user tries to read + * from never connected socket. + */ + ret = -ENOTCONN; + break; + } + if (!timeo) { + ret = -EAGAIN; + break; + } + /* if __tcp_splice_read() got nothing while we have + * an skb in receive queue, we do not want to loop. + * This might happen with URG data. + */ + if (!skb_queue_empty(&sk->sk_receive_queue)) + break; + ret = sk_wait_data(sk, &timeo, NULL); + if (ret < 0) + break; + if (signal_pending(current)) { + ret = sock_intr_errno(timeo); + break; + } + continue; + } + tss.len -= ret; + spliced += ret; + + if (!timeo) + break; + release_sock(sk); + lock_sock(sk); + + if (sk->sk_err || sk->sk_state == TCP_CLOSE || + (sk->sk_shutdown & RCV_SHUTDOWN) || + signal_pending(current)) + break; + } + + release_sock(sk); + + if (spliced) + return spliced; + + return ret; +} +EXPORT_SYMBOL(tcp_splice_read); + +struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, int size, gfp_t gfp, + bool force_schedule) +{ + struct sk_buff *skb; + + skb = alloc_skb_fclone(size + MAX_TCP_HEADER, gfp); + if (likely(skb)) { + bool mem_scheduled; + + skb->truesize = SKB_TRUESIZE(skb_end_offset(skb)); + if (force_schedule) { + mem_scheduled = true; + sk_forced_mem_schedule(sk, skb->truesize); + } else { + mem_scheduled = sk_wmem_schedule(sk, skb->truesize); + } + if (likely(mem_scheduled)) { + skb_reserve(skb, MAX_TCP_HEADER); + skb->ip_summed = CHECKSUM_PARTIAL; + INIT_LIST_HEAD(&skb->tcp_tsorted_anchor); + return skb; + } + __kfree_skb(skb); + } else { + sk->sk_prot->enter_memory_pressure(sk); + sk_stream_moderate_sndbuf(sk); + } + return NULL; +} + +static unsigned int tcp_xmit_size_goal(struct sock *sk, u32 mss_now, + int large_allowed) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 new_size_goal, size_goal; + + if (!large_allowed) + return mss_now; + + /* Note : tcp_tso_autosize() will eventually split this later */ + new_size_goal = tcp_bound_to_half_wnd(tp, sk->sk_gso_max_size); + + /* We try hard to avoid divides here */ + size_goal = tp->gso_segs * mss_now; + if (unlikely(new_size_goal < size_goal || + new_size_goal >= size_goal + mss_now)) { + tp->gso_segs = min_t(u16, new_size_goal / mss_now, + sk->sk_gso_max_segs); + size_goal = tp->gso_segs * mss_now; + } + + return max(size_goal, mss_now); +} + +int tcp_send_mss(struct sock *sk, int *size_goal, int flags) +{ + int mss_now; + + mss_now = tcp_current_mss(sk); + *size_goal = tcp_xmit_size_goal(sk, mss_now, !(flags & MSG_OOB)); + + return mss_now; +} + +/* In some cases, both sendpage() and sendmsg() could have added + * an skb to the write queue, but failed adding payload on it. 
+ * We need to remove it to consume less memory, but more + * importantly be able to generate EPOLLOUT for Edge Trigger epoll() + * users. + */ +void tcp_remove_empty_skb(struct sock *sk) +{ + struct sk_buff *skb = tcp_write_queue_tail(sk); + + if (skb && TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq) { + tcp_unlink_write_queue(skb, sk); + if (tcp_write_queue_empty(sk)) + tcp_chrono_stop(sk, TCP_CHRONO_BUSY); + tcp_wmem_free_skb(sk, skb); + } +} + +/* skb changing from pure zc to mixed, must charge zc */ +static int tcp_downgrade_zcopy_pure(struct sock *sk, struct sk_buff *skb) +{ + if (unlikely(skb_zcopy_pure(skb))) { + u32 extra = skb->truesize - + SKB_TRUESIZE(skb_end_offset(skb)); + + if (!sk_wmem_schedule(sk, extra)) + return -ENOMEM; + + sk_mem_charge(sk, extra); + skb_shinfo(skb)->flags &= ~SKBFL_PURE_ZEROCOPY; + } + return 0; +} + + +static int tcp_wmem_schedule(struct sock *sk, int copy) +{ + int left; + + if (likely(sk_wmem_schedule(sk, copy))) + return copy; + + /* We could be in trouble if we have nothing queued. + * Use whatever is left in sk->sk_forward_alloc and tcp_wmem[0] + * to guarantee some progress. + */ + left = sock_net(sk)->ipv4.sysctl_tcp_wmem[0] - sk->sk_wmem_queued; + if (left > 0) + sk_forced_mem_schedule(sk, min(left, copy)); + return min(copy, sk->sk_forward_alloc); +} + +static struct sk_buff *tcp_build_frag(struct sock *sk, int size_goal, int flags, + struct page *page, int offset, size_t *size) +{ + struct sk_buff *skb = tcp_write_queue_tail(sk); + struct tcp_sock *tp = tcp_sk(sk); + bool can_coalesce; + int copy, i; + + if (!skb || (copy = size_goal - skb->len) <= 0 || + !tcp_skb_can_collapse_to(skb)) { +new_segment: + if (!sk_stream_memory_free(sk)) + return NULL; + + skb = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, + tcp_rtx_and_write_queues_empty(sk)); + if (!skb) + return NULL; + +#ifdef CONFIG_TLS_DEVICE + skb->decrypted = !!(flags & MSG_SENDPAGE_DECRYPTED); +#endif + tcp_skb_entail(sk, skb); + copy = size_goal; + } + + if (copy > *size) + copy = *size; + + i = skb_shinfo(skb)->nr_frags; + can_coalesce = skb_can_coalesce(skb, i, page, offset); + if (!can_coalesce && i >= READ_ONCE(sysctl_max_skb_frags)) { + tcp_mark_push(tp, skb); + goto new_segment; + } + if (tcp_downgrade_zcopy_pure(sk, skb)) + return NULL; + + copy = tcp_wmem_schedule(sk, copy); + if (!copy) + return NULL; + + if (can_coalesce) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + } else { + get_page(page); + skb_fill_page_desc_noacc(skb, i, page, offset, copy); + } + + if (!(flags & MSG_NO_SHARED_FRAGS)) + skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG; + + skb->len += copy; + skb->data_len += copy; + skb->truesize += copy; + sk_wmem_queued_add(sk, copy); + sk_mem_charge(sk, copy); + WRITE_ONCE(tp->write_seq, tp->write_seq + copy); + TCP_SKB_CB(skb)->end_seq += copy; + tcp_skb_pcount_set(skb, 0); + + *size = copy; + return skb; +} + +ssize_t do_tcp_sendpages(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct tcp_sock *tp = tcp_sk(sk); + int mss_now, size_goal; + int err; + ssize_t copied; + long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + + if (IS_ENABLED(CONFIG_DEBUG_VM) && + WARN_ONCE(!sendpage_ok(page), + "page must not be a Slab one and have page_count > 0")) + return -EINVAL; + + /* Wait for a connection to finish. One exception is TCP Fast Open + * (passive side) where data is allowed to be sent before a connection + * is fully established. 
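/*
 * Illustrative userspace sketch (not part of this file): one common way into
 * the do_tcp_sendpages() path being entered here is sendfile(2), reached via
 * tcp_sendpage() further below, which hands page references to the socket
 * instead of copying file data through a user buffer.
 *
 *	#include <sys/sendfile.h>
 *
 *	off_t off = 0;
 *	ssize_t sent = sendfile(sock_fd, file_fd, &off, file_size);
 *
 * file_size here stands for the caller's own length bookkeeping.
 */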
+ */ + if (((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) && + !tcp_passive_fastopen(sk)) { + err = sk_stream_wait_connect(sk, &timeo); + if (err != 0) + goto out_err; + } + + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + mss_now = tcp_send_mss(sk, &size_goal, flags); + copied = 0; + + err = -EPIPE; + if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) + goto out_err; + + while (size > 0) { + struct sk_buff *skb; + size_t copy = size; + + skb = tcp_build_frag(sk, size_goal, flags, page, offset, ©); + if (!skb) + goto wait_for_space; + + if (!copied) + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_PSH; + + copied += copy; + offset += copy; + size -= copy; + if (!size) + goto out; + + if (skb->len < size_goal || (flags & MSG_OOB)) + continue; + + if (forced_push(tp)) { + tcp_mark_push(tp, skb); + __tcp_push_pending_frames(sk, mss_now, TCP_NAGLE_PUSH); + } else if (skb == tcp_send_head(sk)) + tcp_push_one(sk, mss_now); + continue; + +wait_for_space: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + tcp_push(sk, flags & ~MSG_MORE, mss_now, + TCP_NAGLE_PUSH, size_goal); + + err = sk_stream_wait_memory(sk, &timeo); + if (err != 0) + goto do_error; + + mss_now = tcp_send_mss(sk, &size_goal, flags); + } + +out: + if (copied) { + tcp_tx_timestamp(sk, sk->sk_tsflags); + if (!(flags & MSG_SENDPAGE_NOTLAST)) + tcp_push(sk, flags, mss_now, tp->nonagle, size_goal); + } + return copied; + +do_error: + tcp_remove_empty_skb(sk); + if (copied) + goto out; +out_err: + /* make sure we wake any epoll edge trigger waiter */ + if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) { + sk->sk_write_space(sk); + tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED); + } + return sk_stream_error(sk, flags, err); +} +EXPORT_SYMBOL_GPL(do_tcp_sendpages); + +int tcp_sendpage_locked(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + if (!(sk->sk_route_caps & NETIF_F_SG)) + return sock_no_sendpage_locked(sk, page, offset, size, flags); + + tcp_rate_check_app_limited(sk); /* is sending application-limited? 
*/ + + return do_tcp_sendpages(sk, page, offset, size, flags); +} +EXPORT_SYMBOL_GPL(tcp_sendpage_locked); + +int tcp_sendpage(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + int ret; + + lock_sock(sk); + ret = tcp_sendpage_locked(sk, page, offset, size, flags); + release_sock(sk); + + return ret; +} +EXPORT_SYMBOL(tcp_sendpage); + +void tcp_free_fastopen_req(struct tcp_sock *tp) +{ + if (tp->fastopen_req) { + kfree(tp->fastopen_req); + tp->fastopen_req = NULL; + } +} + +int tcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, int *copied, + size_t size, struct ubuf_info *uarg) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_sock *inet = inet_sk(sk); + struct sockaddr *uaddr = msg->msg_name; + int err, flags; + + if (!(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen) & + TFO_CLIENT_ENABLE) || + (uaddr && msg->msg_namelen >= sizeof(uaddr->sa_family) && + uaddr->sa_family == AF_UNSPEC)) + return -EOPNOTSUPP; + if (tp->fastopen_req) + return -EALREADY; /* Another Fast Open is in progress */ + + tp->fastopen_req = kzalloc(sizeof(struct tcp_fastopen_request), + sk->sk_allocation); + if (unlikely(!tp->fastopen_req)) + return -ENOBUFS; + tp->fastopen_req->data = msg; + tp->fastopen_req->size = size; + tp->fastopen_req->uarg = uarg; + + if (inet->defer_connect) { + err = tcp_connect(sk); + /* Same failure procedure as in tcp_v4/6_connect */ + if (err) { + tcp_set_state(sk, TCP_CLOSE); + inet->inet_dport = 0; + sk->sk_route_caps = 0; + } + } + flags = (msg->msg_flags & MSG_DONTWAIT) ? O_NONBLOCK : 0; + err = __inet_stream_connect(sk->sk_socket, uaddr, + msg->msg_namelen, flags, 1); + /* fastopen_req could already be freed in __inet_stream_connect + * if the connection times out or gets rst + */ + if (tp->fastopen_req) { + *copied = tp->fastopen_req->copied; + tcp_free_fastopen_req(tp); + inet->defer_connect = 0; + } + return err; +} + +int tcp_sendmsg_locked(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct ubuf_info *uarg = NULL; + struct sk_buff *skb; + struct sockcm_cookie sockc; + int flags, err, copied = 0; + int mss_now = 0, size_goal, copied_syn = 0; + int process_backlog = 0; + bool zc = false; + long timeo; + + flags = msg->msg_flags; + + if ((flags & MSG_ZEROCOPY) && size) { + skb = tcp_write_queue_tail(sk); + + if (msg->msg_ubuf) { + uarg = msg->msg_ubuf; + net_zcopy_get(uarg); + zc = sk->sk_route_caps & NETIF_F_SG; + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + uarg = msg_zerocopy_realloc(sk, size, skb_zcopy(skb)); + if (!uarg) { + err = -ENOBUFS; + goto out_err; + } + zc = sk->sk_route_caps & NETIF_F_SG; + if (!zc) + uarg_to_msgzc(uarg)->zerocopy = 0; + } + } + + if (unlikely(flags & MSG_FASTOPEN || inet_sk(sk)->defer_connect) && + !tp->repair) { + err = tcp_sendmsg_fastopen(sk, msg, &copied_syn, size, uarg); + if (err == -EINPROGRESS && copied_syn > 0) + goto out; + else if (err) + goto out_err; + } + + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + + tcp_rate_check_app_limited(sk); /* is sending application-limited? */ + + /* Wait for a connection to finish. One exception is TCP Fast Open + * (passive side) where data is allowed to be sent before a connection + * is fully established. 
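/*
 * Illustrative userspace sketch (not part of this file): the two client-side
 * Fast Open paths served by tcp_sendmsg_fastopen() above.
 *
 *	// (a) one-shot: the SYN carries payload when a TFO cookie is cached
 *	sendto(fd, buf, len, MSG_FASTOPEN,
 *	       (struct sockaddr *)&dst, sizeof(dst));
 *
 *	// (b) defer_connect: connect() returns at once, the first write
 *	// generates SYN plus data
 *	int one = 1;
 *	setsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &one, sizeof(one));
 *	connect(fd, (struct sockaddr *)&dst, sizeof(dst));
 *	write(fd, buf, len);
 *
 * Both assume net.ipv4.tcp_fastopen has the client bit (TFO_CLIENT_ENABLE) set.
 */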
+ */ + if (((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) && + !tcp_passive_fastopen(sk)) { + err = sk_stream_wait_connect(sk, &timeo); + if (err != 0) + goto do_error; + } + + if (unlikely(tp->repair)) { + if (tp->repair_queue == TCP_RECV_QUEUE) { + copied = tcp_send_rcvq(sk, msg, size); + goto out_nopush; + } + + err = -EINVAL; + if (tp->repair_queue == TCP_NO_QUEUE) + goto out_err; + + /* 'common' sending to sendq */ + } + + sockcm_init(&sockc, sk); + if (msg->msg_controllen) { + err = sock_cmsg_send(sk, msg, &sockc); + if (unlikely(err)) { + err = -EINVAL; + goto out_err; + } + } + + /* This should be in poll */ + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + /* Ok commence sending. */ + copied = 0; + +restart: + mss_now = tcp_send_mss(sk, &size_goal, flags); + + err = -EPIPE; + if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) + goto do_error; + + while (msg_data_left(msg)) { + int copy = 0; + + skb = tcp_write_queue_tail(sk); + if (skb) + copy = size_goal - skb->len; + + if (copy <= 0 || !tcp_skb_can_collapse_to(skb)) { + bool first_skb; + +new_segment: + if (!sk_stream_memory_free(sk)) + goto wait_for_space; + + if (unlikely(process_backlog >= 16)) { + process_backlog = 0; + if (sk_flush_backlog(sk)) + goto restart; + } + first_skb = tcp_rtx_and_write_queues_empty(sk); + skb = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, + first_skb); + if (!skb) + goto wait_for_space; + + process_backlog++; + + tcp_skb_entail(sk, skb); + copy = size_goal; + + /* All packets are restored as if they have + * already been sent. skb_mstamp_ns isn't set to + * avoid wrong rtt estimation. + */ + if (tp->repair) + TCP_SKB_CB(skb)->sacked |= TCPCB_REPAIRED; + } + + /* Try to append data to the end of skb. */ + if (copy > msg_data_left(msg)) + copy = msg_data_left(msg); + + if (!zc) { + bool merge = true; + int i = skb_shinfo(skb)->nr_frags; + struct page_frag *pfrag = sk_page_frag(sk); + + if (!sk_page_frag_refill(sk, pfrag)) + goto wait_for_space; + + if (!skb_can_coalesce(skb, i, pfrag->page, + pfrag->offset)) { + if (i >= READ_ONCE(sysctl_max_skb_frags)) { + tcp_mark_push(tp, skb); + goto new_segment; + } + merge = false; + } + + copy = min_t(int, copy, pfrag->size - pfrag->offset); + + if (unlikely(skb_zcopy_pure(skb) || skb_zcopy_managed(skb))) { + if (tcp_downgrade_zcopy_pure(sk, skb)) + goto wait_for_space; + skb_zcopy_downgrade_managed(skb); + } + + copy = tcp_wmem_schedule(sk, copy); + if (!copy) + goto wait_for_space; + + err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb, + pfrag->page, + pfrag->offset, + copy); + if (err) + goto do_error; + + /* Update the skb. 
*/ + if (merge) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + } else { + skb_fill_page_desc(skb, i, pfrag->page, + pfrag->offset, copy); + page_ref_inc(pfrag->page); + } + pfrag->offset += copy; + } else { + /* First append to a fragless skb builds initial + * pure zerocopy skb + */ + if (!skb->len) + skb_shinfo(skb)->flags |= SKBFL_PURE_ZEROCOPY; + + if (!skb_zcopy_pure(skb)) { + copy = tcp_wmem_schedule(sk, copy); + if (!copy) + goto wait_for_space; + } + + err = skb_zerocopy_iter_stream(sk, skb, msg, copy, uarg); + if (err == -EMSGSIZE || err == -EEXIST) { + tcp_mark_push(tp, skb); + goto new_segment; + } + if (err < 0) + goto do_error; + copy = err; + } + + if (!copied) + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_PSH; + + WRITE_ONCE(tp->write_seq, tp->write_seq + copy); + TCP_SKB_CB(skb)->end_seq += copy; + tcp_skb_pcount_set(skb, 0); + + copied += copy; + if (!msg_data_left(msg)) { + if (unlikely(flags & MSG_EOR)) + TCP_SKB_CB(skb)->eor = 1; + goto out; + } + + if (skb->len < size_goal || (flags & MSG_OOB) || unlikely(tp->repair)) + continue; + + if (forced_push(tp)) { + tcp_mark_push(tp, skb); + __tcp_push_pending_frames(sk, mss_now, TCP_NAGLE_PUSH); + } else if (skb == tcp_send_head(sk)) + tcp_push_one(sk, mss_now); + continue; + +wait_for_space: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + if (copied) + tcp_push(sk, flags & ~MSG_MORE, mss_now, + TCP_NAGLE_PUSH, size_goal); + + err = sk_stream_wait_memory(sk, &timeo); + if (err != 0) + goto do_error; + + mss_now = tcp_send_mss(sk, &size_goal, flags); + } + +out: + if (copied) { + tcp_tx_timestamp(sk, sockc.tsflags); + tcp_push(sk, flags, mss_now, tp->nonagle, size_goal); + } +out_nopush: + net_zcopy_put(uarg); + return copied + copied_syn; + +do_error: + tcp_remove_empty_skb(sk); + + if (copied + copied_syn) + goto out; +out_err: + net_zcopy_put_abort(uarg, true); + err = sk_stream_error(sk, flags, err); + /* make sure we wake any epoll edge trigger waiter */ + if (unlikely(tcp_rtx_and_write_queues_empty(sk) && err == -EAGAIN)) { + sk->sk_write_space(sk); + tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED); + } + return err; +} +EXPORT_SYMBOL_GPL(tcp_sendmsg_locked); + +int tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + int ret; + + lock_sock(sk); + ret = tcp_sendmsg_locked(sk, msg, size); + release_sock(sk); + + return ret; +} +EXPORT_SYMBOL(tcp_sendmsg); + +void tcp_splice_eof(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct tcp_sock *tp = tcp_sk(sk); + int mss_now, size_goal; + + if (!tcp_write_queue_tail(sk)) + return; + + lock_sock(sk); + mss_now = tcp_send_mss(sk, &size_goal, 0); + tcp_push(sk, 0, mss_now, tp->nonagle, size_goal); + release_sock(sk); +} +EXPORT_SYMBOL_GPL(tcp_splice_eof); + +/* + * Handle reading urgent data. BSD has very simple semantics for + * this, no blocking and very strange errors 8) + */ + +static int tcp_recv_urg(struct sock *sk, struct msghdr *msg, int len, int flags) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* No URG data to read. */ + if (sock_flag(sk, SOCK_URGINLINE) || !tp->urg_data || + tp->urg_data == TCP_URG_READ) + return -EINVAL; /* Yes this is right ! */ + + if (sk->sk_state == TCP_CLOSE && !sock_flag(sk, SOCK_DONE)) + return -ENOTCONN; + + if (tp->urg_data & TCP_URG_VALID) { + int err = 0; + char c = tp->urg_data; + + if (!(flags & MSG_PEEK)) + WRITE_ONCE(tp->urg_data, TCP_URG_READ); + + /* Read urgent data. 
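/*
 * Illustrative userspace sketch (not part of this file): the urgent-data
 * semantics implemented by tcp_recv_urg() here.
 *
 *	send(fd, "!", 1, MSG_OOB);          // sender marks one byte as urgent
 *
 *	int at_mark = 0;
 *	ioctl(fd, SIOCATMARK, &at_mark);    // receiver checks for the mark
 *	char oob;
 *	recv(fd, &oob, 1, MSG_OOB);         // fetch the single out-of-band byte
 *
 * With SO_OOBINLINE set the byte stays in the normal stream instead, and this
 * path returns -EINVAL as checked above.
 */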
*/ + msg->msg_flags |= MSG_OOB; + + if (len > 0) { + if (!(flags & MSG_TRUNC)) + err = memcpy_to_msg(msg, &c, 1); + len = 1; + } else + msg->msg_flags |= MSG_TRUNC; + + return err ? -EFAULT : len; + } + + if (sk->sk_state == TCP_CLOSE || (sk->sk_shutdown & RCV_SHUTDOWN)) + return 0; + + /* Fixed the recv(..., MSG_OOB) behaviour. BSD docs and + * the available implementations agree in this case: + * this call should never block, independent of the + * blocking state of the socket. + * Mike + */ + return -EAGAIN; +} + +static int tcp_peek_sndq(struct sock *sk, struct msghdr *msg, int len) +{ + struct sk_buff *skb; + int copied = 0, err = 0; + + /* XXX -- need to support SO_PEEK_OFF */ + + skb_rbtree_walk(skb, &sk->tcp_rtx_queue) { + err = skb_copy_datagram_msg(skb, 0, msg, skb->len); + if (err) + return err; + copied += skb->len; + } + + skb_queue_walk(&sk->sk_write_queue, skb) { + err = skb_copy_datagram_msg(skb, 0, msg, skb->len); + if (err) + break; + + copied += skb->len; + } + + return err ?: copied; +} + +/* Clean up the receive buffer for full frames taken by the user, + * then send an ACK if necessary. COPIED is the number of bytes + * tcp_recvmsg has given to the user so far, it speeds up the + * calculation of whether or not we must ACK for the sake of + * a window update. + */ +void __tcp_cleanup_rbuf(struct sock *sk, int copied) +{ + struct tcp_sock *tp = tcp_sk(sk); + bool time_to_ack = false; + + if (inet_csk_ack_scheduled(sk)) { + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (/* Once-per-two-segments ACK was not sent by tcp_input.c */ + tp->rcv_nxt - tp->rcv_wup > icsk->icsk_ack.rcv_mss || + /* + * If this read emptied read buffer, we send ACK, if + * connection is not bidirectional, user drained + * receive buffer and there was a small segment + * in queue. + */ + (copied > 0 && + ((icsk->icsk_ack.pending & ICSK_ACK_PUSHED2) || + ((icsk->icsk_ack.pending & ICSK_ACK_PUSHED) && + !inet_csk_in_pingpong_mode(sk))) && + !atomic_read(&sk->sk_rmem_alloc))) + time_to_ack = true; + } + + /* We send an ACK if we can now advertise a non-zero window + * which has been raised "significantly". + * + * Even if window raised up to infinity, do not send window open ACK + * in states, where we will not receive more. It is useless. + */ + if (copied > 0 && !time_to_ack && !(sk->sk_shutdown & RCV_SHUTDOWN)) { + __u32 rcv_window_now = tcp_receive_window(tp); + + /* Optimize, __tcp_select_window() is not cheap. */ + if (2*rcv_window_now <= tp->window_clamp) { + __u32 new_window = __tcp_select_window(sk); + + /* Send ACK now, if this read freed lots of space + * in our buffer. Certainly, new_window is new window. + * We can advertise it now, if it is not less than current one. + * "Lots" means "at least twice" here. 
+ */ + if (new_window && new_window >= 2 * rcv_window_now) + time_to_ack = true; + } + } + if (time_to_ack) + tcp_send_ack(sk); +} + +void tcp_cleanup_rbuf(struct sock *sk, int copied) +{ + struct sk_buff *skb = skb_peek(&sk->sk_receive_queue); + struct tcp_sock *tp = tcp_sk(sk); + + WARN(skb && !before(tp->copied_seq, TCP_SKB_CB(skb)->end_seq), + "cleanup rbuf bug: copied %X seq %X rcvnxt %X\n", + tp->copied_seq, TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt); + __tcp_cleanup_rbuf(sk, copied); +} + +static void tcp_eat_recv_skb(struct sock *sk, struct sk_buff *skb) +{ + __skb_unlink(skb, &sk->sk_receive_queue); + if (likely(skb->destructor == sock_rfree)) { + sock_rfree(skb); + skb->destructor = NULL; + skb->sk = NULL; + return skb_attempt_defer_free(skb); + } + __kfree_skb(skb); +} + +struct sk_buff *tcp_recv_skb(struct sock *sk, u32 seq, u32 *off) +{ + struct sk_buff *skb; + u32 offset; + + while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) { + offset = seq - TCP_SKB_CB(skb)->seq; + if (unlikely(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + pr_err_once("%s: found a SYN, please report !\n", __func__); + offset--; + } + if (offset < skb->len || (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)) { + *off = offset; + return skb; + } + /* This looks weird, but this can happen if TCP collapsing + * splitted a fat GRO packet, while we released socket lock + * in skb_splice_bits() + */ + tcp_eat_recv_skb(sk, skb); + } + return NULL; +} +EXPORT_SYMBOL(tcp_recv_skb); + +/* + * This routine provides an alternative to tcp_recvmsg() for routines + * that would like to handle copying from skbuffs directly in 'sendfile' + * fashion. + * Note: + * - It is assumed that the socket was locked by the caller. + * - The routine does not block. + * - At present, there is no support for reading OOB data + * or for 'peeking' the socket using this routine + * (although both would be easy to implement). + */ +int tcp_read_sock(struct sock *sk, read_descriptor_t *desc, + sk_read_actor_t recv_actor) +{ + struct sk_buff *skb; + struct tcp_sock *tp = tcp_sk(sk); + u32 seq = tp->copied_seq; + u32 offset; + int copied = 0; + + if (sk->sk_state == TCP_LISTEN) + return -ENOTCONN; + while ((skb = tcp_recv_skb(sk, seq, &offset)) != NULL) { + if (offset < skb->len) { + int used; + size_t len; + + len = skb->len - offset; + /* Stop reading if we hit a patch of urgent data */ + if (unlikely(tp->urg_data)) { + u32 urg_offset = tp->urg_seq - seq; + if (urg_offset < len) + len = urg_offset; + if (!len) + break; + } + used = recv_actor(desc, skb, offset, len); + if (used <= 0) { + if (!copied) + copied = used; + break; + } + if (WARN_ON_ONCE(used > len)) + used = len; + seq += used; + copied += used; + offset += used; + + /* If recv_actor drops the lock (e.g. TCP splice + * receive) the skb pointer might be invalid when + * getting here: tcp_collapse might have deleted it + * while aggregating skbs from the socket queue. + */ + skb = tcp_recv_skb(sk, seq - 1, &offset); + if (!skb) + break; + /* TCP coalescing might have appended data to the skb. + * Try to splice more frags + */ + if (offset + 1 != skb->len) + continue; + } + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) { + tcp_eat_recv_skb(sk, skb); + ++seq; + break; + } + tcp_eat_recv_skb(sk, skb); + if (!desc->count) + break; + WRITE_ONCE(tp->copied_seq, seq); + } + WRITE_ONCE(tp->copied_seq, seq); + + tcp_rcv_space_adjust(sk); + + /* Clean up data we have read: This will do ACK frames. 
*/ + if (copied > 0) { + tcp_recv_skb(sk, seq, &offset); + tcp_cleanup_rbuf(sk, copied); + } + return copied; +} +EXPORT_SYMBOL(tcp_read_sock); + +int tcp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) +{ + struct sk_buff *skb; + int copied = 0; + + if (sk->sk_state == TCP_LISTEN) + return -ENOTCONN; + + while ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) { + u8 tcp_flags; + int used; + + __skb_unlink(skb, &sk->sk_receive_queue); + WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); + tcp_flags = TCP_SKB_CB(skb)->tcp_flags; + used = recv_actor(sk, skb); + if (used < 0) { + if (!copied) + copied = used; + break; + } + copied += used; + + if (tcp_flags & TCPHDR_FIN) + break; + } + return copied; +} +EXPORT_SYMBOL(tcp_read_skb); + +void tcp_read_done(struct sock *sk, size_t len) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 seq = tp->copied_seq; + struct sk_buff *skb; + size_t left; + u32 offset; + + if (sk->sk_state == TCP_LISTEN) + return; + + left = len; + while (left && (skb = tcp_recv_skb(sk, seq, &offset)) != NULL) { + int used; + + used = min_t(size_t, skb->len - offset, left); + seq += used; + left -= used; + + if (skb->len > offset + used) + break; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) { + tcp_eat_recv_skb(sk, skb); + ++seq; + break; + } + tcp_eat_recv_skb(sk, skb); + } + WRITE_ONCE(tp->copied_seq, seq); + + tcp_rcv_space_adjust(sk); + + /* Clean up data we have read: This will do ACK frames. */ + if (left != len) + tcp_cleanup_rbuf(sk, len - left); +} +EXPORT_SYMBOL(tcp_read_done); + +int tcp_peek_len(struct socket *sock) +{ + return tcp_inq(sock->sk); +} +EXPORT_SYMBOL(tcp_peek_len); + +/* Make sure sk_rcvbuf is big enough to satisfy SO_RCVLOWAT hint */ +int tcp_set_rcvlowat(struct sock *sk, int val) +{ + int cap; + + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) + cap = sk->sk_rcvbuf >> 1; + else + cap = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2]) >> 1; + val = min(val, cap); + WRITE_ONCE(sk->sk_rcvlowat, val ? 
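/*
 * Illustrative userspace sketch (not part of this file): tcp_set_rcvlowat()
 * here services SO_RCVLOWAT, so a reader can ask to be woken only once a
 * reasonable batch of data has arrived.
 *
 *	int lowat = 64 * 1024;
 *	setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT, &lowat, sizeof(lowat));
 *	// poll()/epoll now report the socket readable only when roughly 64 KiB
 *	// is queued (or the connection is closing); note the val << 1 growth of
 *	// sk_rcvbuf below so that much data can actually be buffered.
 */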
: 1); + + /* Check if we need to signal EPOLLIN right now */ + tcp_data_ready(sk); + + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) + return 0; + + val <<= 1; + if (val > sk->sk_rcvbuf) { + WRITE_ONCE(sk->sk_rcvbuf, val); + tcp_sk(sk)->window_clamp = tcp_win_from_space(sk, val); + } + return 0; +} +EXPORT_SYMBOL(tcp_set_rcvlowat); + +void tcp_update_recv_tstamps(struct sk_buff *skb, + struct scm_timestamping_internal *tss) +{ + if (skb->tstamp) + tss->ts[0] = ktime_to_timespec64(skb->tstamp); + else + tss->ts[0] = (struct timespec64) {0}; + + if (skb_hwtstamps(skb)->hwtstamp) + tss->ts[2] = ktime_to_timespec64(skb_hwtstamps(skb)->hwtstamp); + else + tss->ts[2] = (struct timespec64) {0}; +} + +#ifdef CONFIG_MMU +static const struct vm_operations_struct tcp_vm_ops = { +}; + +int tcp_mmap(struct file *file, struct socket *sock, + struct vm_area_struct *vma) +{ + if (vma->vm_flags & (VM_WRITE | VM_EXEC)) + return -EPERM; + vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC); + + /* Instruct vm_insert_page() to not mmap_read_lock(mm) */ + vma->vm_flags |= VM_MIXEDMAP; + + vma->vm_ops = &tcp_vm_ops; + return 0; +} +EXPORT_SYMBOL(tcp_mmap); + +static skb_frag_t *skb_advance_to_frag(struct sk_buff *skb, u32 offset_skb, + u32 *offset_frag) +{ + skb_frag_t *frag; + + if (unlikely(offset_skb >= skb->len)) + return NULL; + + offset_skb -= skb_headlen(skb); + if ((int)offset_skb < 0 || skb_has_frag_list(skb)) + return NULL; + + frag = skb_shinfo(skb)->frags; + while (offset_skb) { + if (skb_frag_size(frag) > offset_skb) { + *offset_frag = offset_skb; + return frag; + } + offset_skb -= skb_frag_size(frag); + ++frag; + } + *offset_frag = 0; + return frag; +} + +static bool can_map_frag(const skb_frag_t *frag) +{ + return skb_frag_size(frag) == PAGE_SIZE && !skb_frag_off(frag); +} + +static int find_next_mappable_frag(const skb_frag_t *frag, + int remaining_in_skb) +{ + int offset = 0; + + if (likely(can_map_frag(frag))) + return 0; + + while (offset < remaining_in_skb && !can_map_frag(frag)) { + offset += skb_frag_size(frag); + ++frag; + } + return offset; +} + +static void tcp_zerocopy_set_hint_for_skb(struct sock *sk, + struct tcp_zerocopy_receive *zc, + struct sk_buff *skb, u32 offset) +{ + u32 frag_offset, partial_frag_remainder = 0; + int mappable_offset; + skb_frag_t *frag; + + /* worst case: skip to next skb. try to improve on this case below */ + zc->recv_skip_hint = skb->len - offset; + + /* Find the frag containing this offset (and how far into that frag) */ + frag = skb_advance_to_frag(skb, offset, &frag_offset); + if (!frag) + return; + + if (frag_offset) { + struct skb_shared_info *info = skb_shinfo(skb); + + /* We read part of the last frag, must recvmsg() rest of skb. */ + if (frag == &info->frags[info->nr_frags - 1]) + return; + + /* Else, we must at least read the remainder in this frag. */ + partial_frag_remainder = skb_frag_size(frag) - frag_offset; + zc->recv_skip_hint -= partial_frag_remainder; + ++frag; + } + + /* partial_frag_remainder: If part way through a frag, must read rest. + * mappable_offset: Bytes till next mappable frag, *not* counting bytes + * in partial_frag_remainder. 
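/*
 * Illustrative userspace sketch (not part of this file): the receive-zerocopy
 * flow built on tcp_mmap() above and tcp_zerocopy_receive() below, roughly as
 * exercised by tools/testing/selftests/net/tcp_mmap.c; field names assume a
 * 6.1-era <linux/tcp.h>.
 *
 *	size_t chunk = 1 << 20;   // page-aligned window to map payload into
 *	void *addr = mmap(NULL, chunk, PROT_READ, MAP_SHARED, fd, 0);
 *
 *	struct tcp_zerocopy_receive zc = {
 *		.address = (__u64)(unsigned long)addr,
 *		.length  = chunk,
 *	};
 *	socklen_t zlen = sizeof(zc);
 *	getsockopt(fd, IPPROTO_TCP, TCP_ZEROCOPY_RECEIVE, &zc, &zlen);
 *	// zc.length bytes are now mapped at addr; zc.recv_skip_hint bytes
 *	// (headers, partial pages) must still be read with a normal recv().
 */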
+ */ + mappable_offset = find_next_mappable_frag(frag, zc->recv_skip_hint); + zc->recv_skip_hint = mappable_offset + partial_frag_remainder; +} + +static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, + int flags, struct scm_timestamping_internal *tss, + int *cmsg_flags); +static int receive_fallback_to_copy(struct sock *sk, + struct tcp_zerocopy_receive *zc, int inq, + struct scm_timestamping_internal *tss) +{ + unsigned long copy_address = (unsigned long)zc->copybuf_address; + struct msghdr msg = {}; + struct iovec iov; + int err; + + zc->length = 0; + zc->recv_skip_hint = 0; + + if (copy_address != zc->copybuf_address) + return -EINVAL; + + err = import_single_range(ITER_DEST, (void __user *)copy_address, + inq, &iov, &msg.msg_iter); + if (err) + return err; + + err = tcp_recvmsg_locked(sk, &msg, inq, MSG_DONTWAIT, + tss, &zc->msg_flags); + if (err < 0) + return err; + + zc->copybuf_len = err; + if (likely(zc->copybuf_len)) { + struct sk_buff *skb; + u32 offset; + + skb = tcp_recv_skb(sk, tcp_sk(sk)->copied_seq, &offset); + if (skb) + tcp_zerocopy_set_hint_for_skb(sk, zc, skb, offset); + } + return 0; +} + +static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc, + struct sk_buff *skb, u32 copylen, + u32 *offset, u32 *seq) +{ + unsigned long copy_address = (unsigned long)zc->copybuf_address; + struct msghdr msg = {}; + struct iovec iov; + int err; + + if (copy_address != zc->copybuf_address) + return -EINVAL; + + err = import_single_range(ITER_DEST, (void __user *)copy_address, + copylen, &iov, &msg.msg_iter); + if (err) + return err; + err = skb_copy_datagram_msg(skb, *offset, &msg, copylen); + if (err) + return err; + zc->recv_skip_hint -= copylen; + *offset += copylen; + *seq += copylen; + return (__s32)copylen; +} + +static int tcp_zc_handle_leftover(struct tcp_zerocopy_receive *zc, + struct sock *sk, + struct sk_buff *skb, + u32 *seq, + s32 copybuf_len, + struct scm_timestamping_internal *tss) +{ + u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint); + + if (!copylen) + return 0; + /* skb is null if inq < PAGE_SIZE. */ + if (skb) { + offset = *seq - TCP_SKB_CB(skb)->seq; + } else { + skb = tcp_recv_skb(sk, *seq, &offset); + if (TCP_SKB_CB(skb)->has_rxtstamp) { + tcp_update_recv_tstamps(skb, tss); + zc->msg_flags |= TCP_CMSG_TS; + } + } + + zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset, + seq); + return zc->copybuf_len < 0 ? 0 : copylen; +} + +static int tcp_zerocopy_vm_insert_batch_error(struct vm_area_struct *vma, + struct page **pending_pages, + unsigned long pages_remaining, + unsigned long *address, + u32 *length, + u32 *seq, + struct tcp_zerocopy_receive *zc, + u32 total_bytes_to_map, + int err) +{ + /* At least one page did not map. Try zapping if we skipped earlier. */ + if (err == -EBUSY && + zc->flags & TCP_RECEIVE_ZEROCOPY_FLAG_TLB_CLEAN_HINT) { + u32 maybe_zap_len; + + maybe_zap_len = total_bytes_to_map - /* All bytes to map */ + *length + /* Mapped or pending */ + (pages_remaining * PAGE_SIZE); /* Failed map. */ + zap_page_range(vma, *address, maybe_zap_len); + err = 0; + } + + if (!err) { + unsigned long leftover_pages = pages_remaining; + int bytes_mapped; + + /* We called zap_page_range, try to reinsert. 
*/ + err = vm_insert_pages(vma, *address, + pending_pages, + &pages_remaining); + bytes_mapped = PAGE_SIZE * (leftover_pages - pages_remaining); + *seq += bytes_mapped; + *address += bytes_mapped; + } + if (err) { + /* Either we were unable to zap, OR we zapped, retried an + * insert, and still had an issue. Either ways, pages_remaining + * is the number of pages we were unable to map, and we unroll + * some state we speculatively touched before. + */ + const int bytes_not_mapped = PAGE_SIZE * pages_remaining; + + *length -= bytes_not_mapped; + zc->recv_skip_hint += bytes_not_mapped; + } + return err; +} + +static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma, + struct page **pages, + unsigned int pages_to_map, + unsigned long *address, + u32 *length, + u32 *seq, + struct tcp_zerocopy_receive *zc, + u32 total_bytes_to_map) +{ + unsigned long pages_remaining = pages_to_map; + unsigned int pages_mapped; + unsigned int bytes_mapped; + int err; + + err = vm_insert_pages(vma, *address, pages, &pages_remaining); + pages_mapped = pages_to_map - (unsigned int)pages_remaining; + bytes_mapped = PAGE_SIZE * pages_mapped; + /* Even if vm_insert_pages fails, it may have partially succeeded in + * mapping (some but not all of the pages). + */ + *seq += bytes_mapped; + *address += bytes_mapped; + + if (likely(!err)) + return 0; + + /* Error: maybe zap and retry + rollback state for failed inserts. */ + return tcp_zerocopy_vm_insert_batch_error(vma, pages + pages_mapped, + pages_remaining, address, length, seq, zc, total_bytes_to_map, + err); +} + +#define TCP_VALID_ZC_MSG_FLAGS (TCP_CMSG_TS) +static void tcp_zc_finalize_rx_tstamp(struct sock *sk, + struct tcp_zerocopy_receive *zc, + struct scm_timestamping_internal *tss) +{ + unsigned long msg_control_addr; + struct msghdr cmsg_dummy; + + msg_control_addr = (unsigned long)zc->msg_control; + cmsg_dummy.msg_control = (void *)msg_control_addr; + cmsg_dummy.msg_controllen = + (__kernel_size_t)zc->msg_controllen; + cmsg_dummy.msg_flags = in_compat_syscall() + ? 
MSG_CMSG_COMPAT : 0; + cmsg_dummy.msg_control_is_user = true; + zc->msg_flags = 0; + if (zc->msg_control == msg_control_addr && + zc->msg_controllen == cmsg_dummy.msg_controllen) { + tcp_recv_timestamp(&cmsg_dummy, sk, tss); + zc->msg_control = (__u64) + ((uintptr_t)cmsg_dummy.msg_control); + zc->msg_controllen = + (__u64)cmsg_dummy.msg_controllen; + zc->msg_flags = (__u32)cmsg_dummy.msg_flags; + } +} + +#define TCP_ZEROCOPY_PAGE_BATCH_SIZE 32 +static int tcp_zerocopy_receive(struct sock *sk, + struct tcp_zerocopy_receive *zc, + struct scm_timestamping_internal *tss) +{ + u32 length = 0, offset, vma_len, avail_len, copylen = 0; + unsigned long address = (unsigned long)zc->address; + struct page *pages[TCP_ZEROCOPY_PAGE_BATCH_SIZE]; + s32 copybuf_len = zc->copybuf_len; + struct tcp_sock *tp = tcp_sk(sk); + const skb_frag_t *frags = NULL; + unsigned int pages_to_map = 0; + struct vm_area_struct *vma; + struct sk_buff *skb = NULL; + u32 seq = tp->copied_seq; + u32 total_bytes_to_map; + int inq = tcp_inq(sk); + int ret; + + zc->copybuf_len = 0; + zc->msg_flags = 0; + + if (address & (PAGE_SIZE - 1) || address != zc->address) + return -EINVAL; + + if (sk->sk_state == TCP_LISTEN) + return -ENOTCONN; + + sock_rps_record_flow(sk); + + if (inq && inq <= copybuf_len) + return receive_fallback_to_copy(sk, zc, inq, tss); + + if (inq < PAGE_SIZE) { + zc->length = 0; + zc->recv_skip_hint = inq; + if (!inq && sock_flag(sk, SOCK_DONE)) + return -EIO; + return 0; + } + + mmap_read_lock(current->mm); + + vma = vma_lookup(current->mm, address); + if (!vma || vma->vm_ops != &tcp_vm_ops) { + mmap_read_unlock(current->mm); + return -EINVAL; + } + vma_len = min_t(unsigned long, zc->length, vma->vm_end - address); + avail_len = min_t(u32, vma_len, inq); + total_bytes_to_map = avail_len & ~(PAGE_SIZE - 1); + if (total_bytes_to_map) { + if (!(zc->flags & TCP_RECEIVE_ZEROCOPY_FLAG_TLB_CLEAN_HINT)) + zap_page_range(vma, address, total_bytes_to_map); + zc->length = total_bytes_to_map; + zc->recv_skip_hint = 0; + } else { + zc->length = avail_len; + zc->recv_skip_hint = avail_len; + } + ret = 0; + while (length + PAGE_SIZE <= zc->length) { + int mappable_offset; + struct page *page; + + if (zc->recv_skip_hint < PAGE_SIZE) { + u32 offset_frag; + + if (skb) { + if (zc->recv_skip_hint > 0) + break; + skb = skb->next; + offset = seq - TCP_SKB_CB(skb)->seq; + } else { + skb = tcp_recv_skb(sk, seq, &offset); + } + + if (TCP_SKB_CB(skb)->has_rxtstamp) { + tcp_update_recv_tstamps(skb, tss); + zc->msg_flags |= TCP_CMSG_TS; + } + zc->recv_skip_hint = skb->len - offset; + frags = skb_advance_to_frag(skb, offset, &offset_frag); + if (!frags || offset_frag) + break; + } + + mappable_offset = find_next_mappable_frag(frags, + zc->recv_skip_hint); + if (mappable_offset) { + zc->recv_skip_hint = mappable_offset; + break; + } + page = skb_frag_page(frags); + prefetchw(page); + pages[pages_to_map++] = page; + length += PAGE_SIZE; + zc->recv_skip_hint -= PAGE_SIZE; + frags++; + if (pages_to_map == TCP_ZEROCOPY_PAGE_BATCH_SIZE || + zc->recv_skip_hint < PAGE_SIZE) { + /* Either full batch, or we're about to go to next skb + * (and we cannot unroll failed ops across skbs). 
+ */ + ret = tcp_zerocopy_vm_insert_batch(vma, pages, + pages_to_map, + &address, &length, + &seq, zc, + total_bytes_to_map); + if (ret) + goto out; + pages_to_map = 0; + } + } + if (pages_to_map) { + ret = tcp_zerocopy_vm_insert_batch(vma, pages, pages_to_map, + &address, &length, &seq, + zc, total_bytes_to_map); + } +out: + mmap_read_unlock(current->mm); + /* Try to copy straggler data. */ + if (!ret) + copylen = tcp_zc_handle_leftover(zc, sk, skb, &seq, copybuf_len, tss); + + if (length + copylen) { + WRITE_ONCE(tp->copied_seq, seq); + tcp_rcv_space_adjust(sk); + + /* Clean up data we have read: This will do ACK frames. */ + tcp_recv_skb(sk, seq, &offset); + tcp_cleanup_rbuf(sk, length + copylen); + ret = 0; + if (length == zc->length) + zc->recv_skip_hint = 0; + } else { + if (!zc->recv_skip_hint && sock_flag(sk, SOCK_DONE)) + ret = -EIO; + } + zc->length = length; + return ret; +} +#endif + +/* Similar to __sock_recv_timestamp, but does not require an skb */ +void tcp_recv_timestamp(struct msghdr *msg, const struct sock *sk, + struct scm_timestamping_internal *tss) +{ + int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); + bool has_timestamping = false; + + if (tss->ts[0].tv_sec || tss->ts[0].tv_nsec) { + if (sock_flag(sk, SOCK_RCVTSTAMP)) { + if (sock_flag(sk, SOCK_RCVTSTAMPNS)) { + if (new_tstamp) { + struct __kernel_timespec kts = { + .tv_sec = tss->ts[0].tv_sec, + .tv_nsec = tss->ts[0].tv_nsec, + }; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_NEW, + sizeof(kts), &kts); + } else { + struct __kernel_old_timespec ts_old = { + .tv_sec = tss->ts[0].tv_sec, + .tv_nsec = tss->ts[0].tv_nsec, + }; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_OLD, + sizeof(ts_old), &ts_old); + } + } else { + if (new_tstamp) { + struct __kernel_sock_timeval stv = { + .tv_sec = tss->ts[0].tv_sec, + .tv_usec = tss->ts[0].tv_nsec / 1000, + }; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(stv), &stv); + } else { + struct __kernel_old_timeval tv = { + .tv_sec = tss->ts[0].tv_sec, + .tv_usec = tss->ts[0].tv_nsec / 1000, + }; + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } + } + } + + if (READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_SOFTWARE) + has_timestamping = true; + else + tss->ts[0] = (struct timespec64) {0}; + } + + if (tss->ts[2].tv_sec || tss->ts[2].tv_nsec) { + if (READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_RAW_HARDWARE) + has_timestamping = true; + else + tss->ts[2] = (struct timespec64) {0}; + } + + if (has_timestamping) { + tss->ts[1] = (struct timespec64) {0}; + if (sock_flag(sk, SOCK_TSTAMP_NEW)) + put_cmsg_scm_timestamping64(msg, tss); + else + put_cmsg_scm_timestamping(msg, tss); + } +} + +static int tcp_inq_hint(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u32 copied_seq = READ_ONCE(tp->copied_seq); + u32 rcv_nxt = READ_ONCE(tp->rcv_nxt); + int inq; + + inq = rcv_nxt - copied_seq; + if (unlikely(inq < 0 || copied_seq != READ_ONCE(tp->copied_seq))) { + lock_sock(sk); + inq = tp->rcv_nxt - tp->copied_seq; + release_sock(sk); + } + /* After receiving a FIN, tell the user-space to continue reading + * by returning a non-zero inq. + */ + if (inq == 0 && sock_flag(sk, SOCK_DONE)) + inq = 1; + return inq; +} + +/* + * This routine copies from a sock struct into the user buffer. + * + * Technical note: in 2.3 we work on _locked_ socket, so that + * tricks with *seq access order and skb->users are not required. + * Probably, code can be easily improved even more. 
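/*
 * Illustrative userspace sketch (not part of this file): the in-queue hint
 * computed by tcp_inq_hint() above can be requested as a TCP_CM_INQ control
 * message on every tcp_recvmsg() call.
 *
 *	int one = 1, inq = 0;
 *	setsockopt(fd, IPPROTO_TCP, TCP_INQ, &one, sizeof(one));
 *	recvmsg(fd, &msg, 0);
 *	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg))
 *		if (cmsg->cmsg_level == SOL_TCP && cmsg->cmsg_type == TCP_CM_INQ)
 *			inq = *(int *)CMSG_DATA(cmsg);
 *
 * msg and cmsg are caller-declared; msg.msg_control must point at a buffer
 * large enough for the control message.
 */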
+ */ + +static int tcp_recvmsg_locked(struct sock *sk, struct msghdr *msg, size_t len, + int flags, struct scm_timestamping_internal *tss, + int *cmsg_flags) +{ + struct tcp_sock *tp = tcp_sk(sk); + int copied = 0; + u32 peek_seq; + u32 *seq; + unsigned long used; + int err; + int target; /* Read at least this many bytes */ + long timeo; + struct sk_buff *skb, *last; + u32 urg_hole = 0; + + err = -ENOTCONN; + if (sk->sk_state == TCP_LISTEN) + goto out; + + if (tp->recvmsg_inq) { + *cmsg_flags = TCP_CMSG_INQ; + msg->msg_get_inq = 1; + } + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + /* Urgent data needs to be handled specially. */ + if (flags & MSG_OOB) + goto recv_urg; + + if (unlikely(tp->repair)) { + err = -EPERM; + if (!(flags & MSG_PEEK)) + goto out; + + if (tp->repair_queue == TCP_SEND_QUEUE) + goto recv_sndq; + + err = -EINVAL; + if (tp->repair_queue == TCP_NO_QUEUE) + goto out; + + /* 'common' recv queue MSG_PEEK-ing */ + } + + seq = &tp->copied_seq; + if (flags & MSG_PEEK) { + peek_seq = tp->copied_seq; + seq = &peek_seq; + } + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + + do { + u32 offset; + + /* Are we at urgent data? Stop if we have read anything or have SIGURG pending. */ + if (unlikely(tp->urg_data) && tp->urg_seq == *seq) { + if (copied) + break; + if (signal_pending(current)) { + copied = timeo ? sock_intr_errno(timeo) : -EAGAIN; + break; + } + } + + /* Next get a buffer. */ + + last = skb_peek_tail(&sk->sk_receive_queue); + skb_queue_walk(&sk->sk_receive_queue, skb) { + last = skb; + /* Now that we have two receive queues this + * shouldn't happen. + */ + if (WARN(before(*seq, TCP_SKB_CB(skb)->seq), + "TCP recvmsg seq # bug: copied %X, seq %X, rcvnxt %X, fl %X\n", + *seq, TCP_SKB_CB(skb)->seq, tp->rcv_nxt, + flags)) + break; + + offset = *seq - TCP_SKB_CB(skb)->seq; + if (unlikely(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + pr_err_once("%s: found a SYN, please report !\n", __func__); + offset--; + } + if (offset < skb->len) + goto found_ok_skb; + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + goto found_fin_ok; + WARN(!(flags & MSG_PEEK), + "TCP recvmsg seq # bug 2: copied %X, seq %X, rcvnxt %X, fl %X\n", + *seq, TCP_SKB_CB(skb)->seq, tp->rcv_nxt, flags); + } + + /* Well, if we have backlog, try to process it now yet. */ + + if (copied >= target && !READ_ONCE(sk->sk_backlog.tail)) + break; + + if (copied) { + if (!timeo || + sk->sk_err || + sk->sk_state == TCP_CLOSE || + (sk->sk_shutdown & RCV_SHUTDOWN) || + signal_pending(current)) + break; + } else { + if (sock_flag(sk, SOCK_DONE)) + break; + + if (sk->sk_err) { + copied = sock_error(sk); + break; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + + if (sk->sk_state == TCP_CLOSE) { + /* This occurs when user tries to read + * from never connected socket. + */ + copied = -ENOTCONN; + break; + } + + if (!timeo) { + copied = -EAGAIN; + break; + } + + if (signal_pending(current)) { + copied = sock_intr_errno(timeo); + break; + } + } + + if (copied >= target) { + /* Do not sleep, just process backlog. */ + __sk_flush_backlog(sk); + } else { + tcp_cleanup_rbuf(sk, copied); + err = sk_wait_data(sk, &timeo, last); + if (err < 0) { + err = copied ? : err; + goto out; + } + } + + if ((flags & MSG_PEEK) && + (peek_seq - copied - urg_hole != tp->copied_seq)) { + net_dbg_ratelimited("TCP(%s:%d): Application bug, race in MSG_PEEK\n", + current->comm, + task_pid_nr(current)); + peek_seq = tp->copied_seq; + } + continue; + +found_ok_skb: + /* Ok so how much can we use? 
*/ + used = skb->len - offset; + if (len < used) + used = len; + + /* Do we have urgent data here? */ + if (unlikely(tp->urg_data)) { + u32 urg_offset = tp->urg_seq - *seq; + if (urg_offset < used) { + if (!urg_offset) { + if (!sock_flag(sk, SOCK_URGINLINE)) { + WRITE_ONCE(*seq, *seq + 1); + urg_hole++; + offset++; + used--; + if (!used) + goto skip_copy; + } + } else + used = urg_offset; + } + } + + if (!(flags & MSG_TRUNC)) { + err = skb_copy_datagram_msg(skb, offset, msg, used); + if (err) { + /* Exception. Bailout! */ + if (!copied) + copied = -EFAULT; + break; + } + } + + WRITE_ONCE(*seq, *seq + used); + copied += used; + len -= used; + + tcp_rcv_space_adjust(sk); + +skip_copy: + if (unlikely(tp->urg_data) && after(tp->copied_seq, tp->urg_seq)) { + WRITE_ONCE(tp->urg_data, 0); + tcp_fast_path_check(sk); + } + + if (TCP_SKB_CB(skb)->has_rxtstamp) { + tcp_update_recv_tstamps(skb, tss); + *cmsg_flags |= TCP_CMSG_TS; + } + + if (used + offset < skb->len) + continue; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + goto found_fin_ok; + if (!(flags & MSG_PEEK)) + tcp_eat_recv_skb(sk, skb); + continue; + +found_fin_ok: + /* Process the FIN. */ + WRITE_ONCE(*seq, *seq + 1); + if (!(flags & MSG_PEEK)) + tcp_eat_recv_skb(sk, skb); + break; + } while (len > 0); + + /* According to UNIX98, msg_name/msg_namelen are ignored + * on connected socket. I was just happy when found this 8) --ANK + */ + + /* Clean up data we have read: This will do ACK frames. */ + tcp_cleanup_rbuf(sk, copied); + return copied; + +out: + return err; + +recv_urg: + err = tcp_recv_urg(sk, msg, len, flags); + goto out; + +recv_sndq: + err = tcp_peek_sndq(sk, msg, len); + goto out; +} + +int tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) +{ + int cmsg_flags = 0, ret; + struct scm_timestamping_internal tss; + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + if (sk_can_busy_loop(sk) && + skb_queue_empty_lockless(&sk->sk_receive_queue) && + sk->sk_state == TCP_ESTABLISHED) + sk_busy_loop(sk, flags & MSG_DONTWAIT); + + lock_sock(sk); + ret = tcp_recvmsg_locked(sk, msg, len, flags, &tss, &cmsg_flags); + release_sock(sk); + + if ((cmsg_flags || msg->msg_get_inq) && ret >= 0) { + if (cmsg_flags & TCP_CMSG_TS) + tcp_recv_timestamp(msg, sk, &tss); + if (msg->msg_get_inq) { + msg->msg_inq = tcp_inq_hint(sk); + if (cmsg_flags & TCP_CMSG_INQ) + put_cmsg(msg, SOL_TCP, TCP_CM_INQ, + sizeof(msg->msg_inq), &msg->msg_inq); + } + } + return ret; +} +EXPORT_SYMBOL(tcp_recvmsg); + +void tcp_set_state(struct sock *sk, int state) +{ + int oldstate = sk->sk_state; + + /* We defined a new enum for TCP states that are exported in BPF + * so as not force the internal TCP states to be frozen. The + * following checks will detect if an internal state value ever + * differs from the BPF value. If this ever happens, then we will + * need to remap the internal value to the BPF value before calling + * tcp_call_bpf_2arg. 
+ */ + BUILD_BUG_ON((int)BPF_TCP_ESTABLISHED != (int)TCP_ESTABLISHED); + BUILD_BUG_ON((int)BPF_TCP_SYN_SENT != (int)TCP_SYN_SENT); + BUILD_BUG_ON((int)BPF_TCP_SYN_RECV != (int)TCP_SYN_RECV); + BUILD_BUG_ON((int)BPF_TCP_FIN_WAIT1 != (int)TCP_FIN_WAIT1); + BUILD_BUG_ON((int)BPF_TCP_FIN_WAIT2 != (int)TCP_FIN_WAIT2); + BUILD_BUG_ON((int)BPF_TCP_TIME_WAIT != (int)TCP_TIME_WAIT); + BUILD_BUG_ON((int)BPF_TCP_CLOSE != (int)TCP_CLOSE); + BUILD_BUG_ON((int)BPF_TCP_CLOSE_WAIT != (int)TCP_CLOSE_WAIT); + BUILD_BUG_ON((int)BPF_TCP_LAST_ACK != (int)TCP_LAST_ACK); + BUILD_BUG_ON((int)BPF_TCP_LISTEN != (int)TCP_LISTEN); + BUILD_BUG_ON((int)BPF_TCP_CLOSING != (int)TCP_CLOSING); + BUILD_BUG_ON((int)BPF_TCP_NEW_SYN_RECV != (int)TCP_NEW_SYN_RECV); + BUILD_BUG_ON((int)BPF_TCP_MAX_STATES != (int)TCP_MAX_STATES); + + /* bpf uapi header bpf.h defines an anonymous enum with values + * BPF_TCP_* used by bpf programs. Currently gcc built vmlinux + * is able to emit this enum in DWARF due to the above BUILD_BUG_ON. + * But clang built vmlinux does not have this enum in DWARF + * since clang removes the above code before generating IR/debuginfo. + * Let us explicitly emit the type debuginfo to ensure the + * above-mentioned anonymous enum in the vmlinux DWARF and hence BTF + * regardless of which compiler is used. + */ + BTF_TYPE_EMIT_ENUM(BPF_TCP_ESTABLISHED); + + if (BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk), BPF_SOCK_OPS_STATE_CB_FLAG)) + tcp_call_bpf_2arg(sk, BPF_SOCK_OPS_STATE_CB, oldstate, state); + + switch (state) { + case TCP_ESTABLISHED: + if (oldstate != TCP_ESTABLISHED) + TCP_INC_STATS(sock_net(sk), TCP_MIB_CURRESTAB); + break; + + case TCP_CLOSE: + if (oldstate == TCP_CLOSE_WAIT || oldstate == TCP_ESTABLISHED) + TCP_INC_STATS(sock_net(sk), TCP_MIB_ESTABRESETS); + + sk->sk_prot->unhash(sk); + if (inet_csk(sk)->icsk_bind_hash && + !(sk->sk_userlocks & SOCK_BINDPORT_LOCK)) + inet_put_port(sk); + fallthrough; + default: + if (oldstate == TCP_ESTABLISHED) + TCP_DEC_STATS(sock_net(sk), TCP_MIB_CURRESTAB); + } + + /* Change state AFTER socket is unhashed to avoid closed + * socket sitting in hash tables. + */ + inet_sk_state_store(sk, state); +} +EXPORT_SYMBOL_GPL(tcp_set_state); + +/* + * State processing on a close. This implements the state shift for + * sending our FIN frame. Note that we only send a FIN for some + * states. A shutdown() may have already sent the FIN, or we may be + * closed. + */ + +static const unsigned char new_state[16] = { + /* current state: new state: action: */ + [0 /* (Invalid) */] = TCP_CLOSE, + [TCP_ESTABLISHED] = TCP_FIN_WAIT1 | TCP_ACTION_FIN, + [TCP_SYN_SENT] = TCP_CLOSE, + [TCP_SYN_RECV] = TCP_FIN_WAIT1 | TCP_ACTION_FIN, + [TCP_FIN_WAIT1] = TCP_FIN_WAIT1, + [TCP_FIN_WAIT2] = TCP_FIN_WAIT2, + [TCP_TIME_WAIT] = TCP_CLOSE, + [TCP_CLOSE] = TCP_CLOSE, + [TCP_CLOSE_WAIT] = TCP_LAST_ACK | TCP_ACTION_FIN, + [TCP_LAST_ACK] = TCP_LAST_ACK, + [TCP_LISTEN] = TCP_CLOSE, + [TCP_CLOSING] = TCP_CLOSING, + [TCP_NEW_SYN_RECV] = TCP_CLOSE, /* should not happen ! */ +}; + +static int tcp_close_state(struct sock *sk) +{ + int next = (int)new_state[sk->sk_state]; + int ns = next & TCP_STATE_MASK; + + tcp_set_state(sk, ns); + + return next & TCP_ACTION_FIN; +} + +/* + * Shutdown the sending side of a connection. Much like close except + * that we don't receive shut down or sock_set_flag(sk, SOCK_DEAD). + */ + +void tcp_shutdown(struct sock *sk, int how) +{ + /* We need to grab some memory, and put together a FIN, + * and then put it into the queue to be sent. 
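/*
 * Illustrative userspace sketch (not part of this file): the half-close this
 * function implements for SEND_SHUTDOWN, i.e. shutdown(fd, SHUT_WR).
 *
 *	shutdown(fd, SHUT_WR);        // queue a FIN; the peer's read() sees EOF
 *	// the receive side stays usable until the peer closes its end
 *	while ((n = read(fd, buf, sizeof(buf))) > 0)
 *		consume(buf, n);
 *
 * consume() is just a stand-in for whatever the application does with the
 * trailing data.
 */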
+ * Tim MacKenzie(tym@dibbler.cs.monash.edu.au) 4 Dec '92. + */ + if (!(how & SEND_SHUTDOWN)) + return; + + /* If we've already sent a FIN, or it's a closed state, skip this. */ + if ((1 << sk->sk_state) & + (TCPF_ESTABLISHED | TCPF_SYN_SENT | + TCPF_SYN_RECV | TCPF_CLOSE_WAIT)) { + /* Clear out any half completed packets. FIN if needed. */ + if (tcp_close_state(sk)) + tcp_send_fin(sk); + } +} +EXPORT_SYMBOL(tcp_shutdown); + +int tcp_orphan_count_sum(void) +{ + int i, total = 0; + + for_each_possible_cpu(i) + total += per_cpu(tcp_orphan_count, i); + + return max(total, 0); +} + +static int tcp_orphan_cache; +static struct timer_list tcp_orphan_timer; +#define TCP_ORPHAN_TIMER_PERIOD msecs_to_jiffies(100) + +static void tcp_orphan_update(struct timer_list *unused) +{ + WRITE_ONCE(tcp_orphan_cache, tcp_orphan_count_sum()); + mod_timer(&tcp_orphan_timer, jiffies + TCP_ORPHAN_TIMER_PERIOD); +} + +static bool tcp_too_many_orphans(int shift) +{ + return READ_ONCE(tcp_orphan_cache) << shift > + READ_ONCE(sysctl_tcp_max_orphans); +} + +bool tcp_check_oom(struct sock *sk, int shift) +{ + bool too_many_orphans, out_of_socket_memory; + + too_many_orphans = tcp_too_many_orphans(shift); + out_of_socket_memory = tcp_out_of_memory(sk); + + if (too_many_orphans) + net_info_ratelimited("too many orphaned sockets\n"); + if (out_of_socket_memory) + net_info_ratelimited("out of memory -- consider tuning tcp_mem\n"); + return too_many_orphans || out_of_socket_memory; +} + +void __tcp_close(struct sock *sk, long timeout) +{ + struct sk_buff *skb; + int data_was_unread = 0; + int state; + + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); + + if (sk->sk_state == TCP_LISTEN) { + tcp_set_state(sk, TCP_CLOSE); + + /* Special case. */ + inet_csk_listen_stop(sk); + + goto adjudge_to_death; + } + + /* We need to flush the recv. buffs. We do this only on the + * descriptor close, not protocol-sourced closes, because the + * reader process may not have drained the data yet! + */ + while ((skb = __skb_dequeue(&sk->sk_receive_queue)) != NULL) { + u32 len = TCP_SKB_CB(skb)->end_seq - TCP_SKB_CB(skb)->seq; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + len--; + data_was_unread += len; + __kfree_skb(skb); + } + + /* If socket has been already reset (e.g. in tcp_reset()) - kill it. */ + if (sk->sk_state == TCP_CLOSE) + goto adjudge_to_death; + + /* As outlined in RFC 2525, section 2.17, we send a RST here because + * data was lost. To witness the awful effects of the old behavior of + * always doing a FIN, run an older 2.1.x kernel or 2.0.x, start a bulk + * GET in an FTP client, suspend the process, wait for the client to + * advertise a zero window, then kill -9 the FTP client, wheee... + * Note: timeout is always zero in such a case. + */ + if (unlikely(tcp_sk(sk)->repair)) { + sk->sk_prot->disconnect(sk, 0); + } else if (data_was_unread) { + /* Unread data was tossed, zap the connection. */ + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONCLOSE); + tcp_set_state(sk, TCP_CLOSE); + tcp_send_active_reset(sk, sk->sk_allocation); + } else if (sock_flag(sk, SOCK_LINGER) && !sk->sk_lingertime) { + /* Check zero linger _after_ checking for unread data. */ + sk->sk_prot->disconnect(sk, 0); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); + } else if (tcp_close_state(sk)) { + /* We FIN if the application ate all the data before + * zapping the connection. + */ + + /* RED-PEN. Formally speaking, we have broken TCP state + * machine. 
State transitions: + * + * TCP_ESTABLISHED -> TCP_FIN_WAIT1 + * TCP_SYN_RECV -> TCP_FIN_WAIT1 (forget it, it's impossible) + * TCP_CLOSE_WAIT -> TCP_LAST_ACK + * + * are legal only when FIN has been sent (i.e. in window), + * rather than queued out of window. Purists blame. + * + * F.e. "RFC state" is ESTABLISHED, + * if Linux state is FIN-WAIT-1, but FIN is still not sent. + * + * The visible declinations are that sometimes + * we enter time-wait state, when it is not required really + * (harmless), do not send active resets, when they are + * required by specs (TCP_ESTABLISHED, TCP_CLOSE_WAIT, when + * they look as CLOSING or LAST_ACK for Linux) + * Probably, I missed some more holelets. + * --ANK + * XXX (TFO) - To start off we don't support SYN+ACK+FIN + * in a single packet! (May consider it later but will + * probably need API support or TCP_CORK SYN-ACK until + * data is written and socket is closed.) + */ + tcp_send_fin(sk); + } + + sk_stream_wait_close(sk, timeout); + +adjudge_to_death: + state = sk->sk_state; + sock_hold(sk); + sock_orphan(sk); + + local_bh_disable(); + bh_lock_sock(sk); + /* remove backlog if any, without releasing ownership. */ + __release_sock(sk); + + this_cpu_inc(tcp_orphan_count); + + /* Have we already been destroyed by a softirq or backlog? */ + if (state != TCP_CLOSE && sk->sk_state == TCP_CLOSE) + goto out; + + /* This is a (useful) BSD violating of the RFC. There is a + * problem with TCP as specified in that the other end could + * keep a socket open forever with no application left this end. + * We use a 1 minute timeout (about the same as BSD) then kill + * our end. If they send after that then tough - BUT: long enough + * that we won't make the old 4*rto = almost no time - whoops + * reset mistake. + * + * Nope, it was not mistake. It is really desired behaviour + * f.e. on http servers, when such sockets are useless, but + * consume significant resources. Let's do it with special + * linger2 option. --ANK + */ + + if (sk->sk_state == TCP_FIN_WAIT2) { + struct tcp_sock *tp = tcp_sk(sk); + if (tp->linger2 < 0) { + tcp_set_state(sk, TCP_CLOSE); + tcp_send_active_reset(sk, GFP_ATOMIC); + __NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPABORTONLINGER); + } else { + const int tmo = tcp_fin_time(sk); + + if (tmo > TCP_TIMEWAIT_LEN) { + inet_csk_reset_keepalive_timer(sk, + tmo - TCP_TIMEWAIT_LEN); + } else { + tcp_time_wait(sk, TCP_FIN_WAIT2, tmo); + goto out; + } + } + } + if (sk->sk_state != TCP_CLOSE) { + if (tcp_check_oom(sk, 0)) { + tcp_set_state(sk, TCP_CLOSE); + tcp_send_active_reset(sk, GFP_ATOMIC); + __NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPABORTONMEMORY); + } else if (!check_net(sock_net(sk))) { + /* Not possible to send reset; just close */ + tcp_set_state(sk, TCP_CLOSE); + } + } + + if (sk->sk_state == TCP_CLOSE) { + struct request_sock *req; + + req = rcu_dereference_protected(tcp_sk(sk)->fastopen_rsk, + lockdep_sock_is_held(sk)); + /* We could get here with a non-NULL req if the socket is + * aborted (e.g., closed with unread data) before 3WHS + * finishes. + */ + if (req) + reqsk_fastopen_remove(sk, req, false); + inet_csk_destroy_sock(sk); + } + /* Otherwise, socket is reprieved until protocol close. 
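+ *
+ * (Illustration: the FIN-WAIT-2 handling above is what the TCP_LINGER2
+ *  socket option tunes.  A minimal userspace sketch, error handling
+ *  omitted and the 10-second value chosen purely as an example:
+ *
+ *	int fin_wait2_secs = 10;
+ *	setsockopt(fd, IPPROTO_TCP, TCP_LINGER2,
+ *		   &fin_wait2_secs, sizeof(fin_wait2_secs));
+ *
+ *  A negative value makes the close path above skip FIN-WAIT-2 and
+ *  reset the connection instead.)
+ *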
*/ + +out: + bh_unlock_sock(sk); + local_bh_enable(); +} + +void tcp_close(struct sock *sk, long timeout) +{ + lock_sock(sk); + __tcp_close(sk, timeout); + release_sock(sk); + sock_put(sk); +} +EXPORT_SYMBOL(tcp_close); + +/* These states need RST on ABORT according to RFC793 */ + +static inline bool tcp_need_reset(int state) +{ + return (1 << state) & + (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT | TCPF_FIN_WAIT1 | + TCPF_FIN_WAIT2 | TCPF_SYN_RECV); +} + +static void tcp_rtx_queue_purge(struct sock *sk) +{ + struct rb_node *p = rb_first(&sk->tcp_rtx_queue); + + tcp_sk(sk)->highest_sack = NULL; + while (p) { + struct sk_buff *skb = rb_to_skb(p); + + p = rb_next(p); + /* Since we are deleting whole queue, no need to + * list_del(&skb->tcp_tsorted_anchor) + */ + tcp_rtx_queue_unlink(skb, sk); + tcp_wmem_free_skb(sk, skb); + } +} + +void tcp_write_queue_purge(struct sock *sk) +{ + struct sk_buff *skb; + + tcp_chrono_stop(sk, TCP_CHRONO_BUSY); + while ((skb = __skb_dequeue(&sk->sk_write_queue)) != NULL) { + tcp_skb_tsorted_anchor_cleanup(skb); + tcp_wmem_free_skb(sk, skb); + } + tcp_rtx_queue_purge(sk); + INIT_LIST_HEAD(&tcp_sk(sk)->tsorted_sent_queue); + tcp_clear_all_retrans_hints(tcp_sk(sk)); + tcp_sk(sk)->packets_out = 0; + inet_csk(sk)->icsk_backoff = 0; +} + +int tcp_disconnect(struct sock *sk, int flags) +{ + struct inet_sock *inet = inet_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + int old_state = sk->sk_state; + u32 seq; + + if (old_state != TCP_CLOSE) + tcp_set_state(sk, TCP_CLOSE); + + /* ABORT function of RFC793 */ + if (old_state == TCP_LISTEN) { + inet_csk_listen_stop(sk); + } else if (unlikely(tp->repair)) { + sk->sk_err = ECONNABORTED; + } else if (tcp_need_reset(old_state) || + (tp->snd_nxt != tp->write_seq && + (1 << old_state) & (TCPF_CLOSING | TCPF_LAST_ACK))) { + /* The last check adjusts for discrepancy of Linux wrt. 
RFC + * states + */ + tcp_send_active_reset(sk, gfp_any()); + sk->sk_err = ECONNRESET; + } else if (old_state == TCP_SYN_SENT) + sk->sk_err = ECONNRESET; + + tcp_clear_xmit_timers(sk); + __skb_queue_purge(&sk->sk_receive_queue); + WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); + WRITE_ONCE(tp->urg_data, 0); + tcp_write_queue_purge(sk); + tcp_fastopen_active_disable_ofo_check(sk); + skb_rbtree_purge(&tp->out_of_order_queue); + + inet->inet_dport = 0; + + inet_bhash2_reset_saddr(sk); + + WRITE_ONCE(sk->sk_shutdown, 0); + sock_reset_flag(sk, SOCK_DONE); + tp->srtt_us = 0; + tp->mdev_us = jiffies_to_usecs(TCP_TIMEOUT_INIT); + tp->rcv_rtt_last_tsecr = 0; + + seq = tp->write_seq + tp->max_window + 2; + if (!seq) + seq = 1; + WRITE_ONCE(tp->write_seq, seq); + + icsk->icsk_backoff = 0; + icsk->icsk_probes_out = 0; + icsk->icsk_probes_tstamp = 0; + icsk->icsk_rto = TCP_TIMEOUT_INIT; + icsk->icsk_rto_min = TCP_RTO_MIN; + icsk->icsk_delack_max = TCP_DELACK_MAX; + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + tcp_snd_cwnd_set(tp, TCP_INIT_CWND); + tp->snd_cwnd_cnt = 0; + tp->is_cwnd_limited = 0; + tp->max_packets_out = 0; + tp->window_clamp = 0; + tp->delivered = 0; + tp->delivered_ce = 0; + if (icsk->icsk_ca_ops->release) + icsk->icsk_ca_ops->release(sk); + memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv)); + icsk->icsk_ca_initialized = 0; + tcp_set_ca_state(sk, TCP_CA_Open); + tp->is_sack_reneg = 0; + tcp_clear_retrans(tp); + tp->total_retrans = 0; + inet_csk_delack_init(sk); + /* Initialize rcv_mss to TCP_MIN_MSS to avoid division by 0 + * issue in __tcp_select_window() + */ + icsk->icsk_ack.rcv_mss = TCP_MIN_MSS; + memset(&tp->rx_opt, 0, sizeof(tp->rx_opt)); + __sk_dst_reset(sk); + dst_release(xchg((__force struct dst_entry **)&sk->sk_rx_dst, NULL)); + tcp_saved_syn_free(tp); + tp->compressed_ack = 0; + tp->segs_in = 0; + tp->segs_out = 0; + tp->bytes_sent = 0; + tp->bytes_acked = 0; + tp->bytes_received = 0; + tp->bytes_retrans = 0; + tp->data_segs_in = 0; + tp->data_segs_out = 0; + tp->duplicate_sack[0].start_seq = 0; + tp->duplicate_sack[0].end_seq = 0; + tp->dsack_dups = 0; + tp->reord_seen = 0; + tp->retrans_out = 0; + tp->sacked_out = 0; + tp->tlp_high_seq = 0; + tp->last_oow_ack_time = 0; + /* There's a bubble in the pipe until at least the first ACK. 
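+ *
+ * (For illustration: the tcp_disconnect() path that resets the fields
+ *  above is typically reached from user space by connect()ing the
+ *  socket again with an AF_UNSPEC address -- a minimal sketch, error
+ *  handling omitted:
+ *
+ *	struct sockaddr sa = { .sa_family = AF_UNSPEC };
+ *	connect(fd, &sa, sizeof(sa));
+ *
+ *  after which the socket may be bound and connected afresh.)
+ *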
*/ + tp->app_limited = ~0U; + tp->rate_app_limited = 1; + tp->rack.mstamp = 0; + tp->rack.advanced = 0; + tp->rack.reo_wnd_steps = 1; + tp->rack.last_delivered = 0; + tp->rack.reo_wnd_persist = 0; + tp->rack.dsack_seen = 0; + tp->syn_data_acked = 0; + tp->rx_opt.saw_tstamp = 0; + tp->rx_opt.dsack = 0; + tp->rx_opt.num_sacks = 0; + tp->rcv_ooopack = 0; + + + /* Clean up fastopen related fields */ + tcp_free_fastopen_req(tp); + inet->defer_connect = 0; + tp->fastopen_client_fail = 0; + + WARN_ON(inet->inet_num && !icsk->icsk_bind_hash); + + if (sk->sk_frag.page) { + put_page(sk->sk_frag.page); + sk->sk_frag.page = NULL; + sk->sk_frag.offset = 0; + } + sk_error_report(sk); + return 0; +} +EXPORT_SYMBOL(tcp_disconnect); + +static inline bool tcp_can_repair_sock(const struct sock *sk) +{ + return sockopt_ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN) && + (sk->sk_state != TCP_LISTEN); +} + +static int tcp_repair_set_window(struct tcp_sock *tp, sockptr_t optbuf, int len) +{ + struct tcp_repair_window opt; + + if (!tp->repair) + return -EPERM; + + if (len != sizeof(opt)) + return -EINVAL; + + if (copy_from_sockptr(&opt, optbuf, sizeof(opt))) + return -EFAULT; + + if (opt.max_window < opt.snd_wnd) + return -EINVAL; + + if (after(opt.snd_wl1, tp->rcv_nxt + opt.rcv_wnd)) + return -EINVAL; + + if (after(opt.rcv_wup, tp->rcv_nxt)) + return -EINVAL; + + tp->snd_wl1 = opt.snd_wl1; + tp->snd_wnd = opt.snd_wnd; + tp->max_window = opt.max_window; + + tp->rcv_wnd = opt.rcv_wnd; + tp->rcv_wup = opt.rcv_wup; + + return 0; +} + +static int tcp_repair_options_est(struct sock *sk, sockptr_t optbuf, + unsigned int len) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_repair_opt opt; + size_t offset = 0; + + while (len >= sizeof(opt)) { + if (copy_from_sockptr_offset(&opt, optbuf, offset, sizeof(opt))) + return -EFAULT; + + offset += sizeof(opt); + len -= sizeof(opt); + + switch (opt.opt_code) { + case TCPOPT_MSS: + tp->rx_opt.mss_clamp = opt.opt_val; + tcp_mtup_init(sk); + break; + case TCPOPT_WINDOW: + { + u16 snd_wscale = opt.opt_val & 0xFFFF; + u16 rcv_wscale = opt.opt_val >> 16; + + if (snd_wscale > TCP_MAX_WSCALE || rcv_wscale > TCP_MAX_WSCALE) + return -EFBIG; + + tp->rx_opt.snd_wscale = snd_wscale; + tp->rx_opt.rcv_wscale = rcv_wscale; + tp->rx_opt.wscale_ok = 1; + } + break; + case TCPOPT_SACK_PERM: + if (opt.opt_val != 0) + return -EINVAL; + + tp->rx_opt.sack_ok |= TCP_SACK_SEEN; + break; + case TCPOPT_TIMESTAMP: + if (opt.opt_val != 0) + return -EINVAL; + + tp->rx_opt.tstamp_ok = 1; + break; + } + } + + return 0; +} + +DEFINE_STATIC_KEY_FALSE(tcp_tx_delay_enabled); +EXPORT_SYMBOL(tcp_tx_delay_enabled); + +static void tcp_enable_tx_delay(void) +{ + if (!static_branch_unlikely(&tcp_tx_delay_enabled)) { + static int __tcp_tx_delay_enabled = 0; + + if (cmpxchg(&__tcp_tx_delay_enabled, 0, 1) == 0) { + static_branch_enable(&tcp_tx_delay_enabled); + pr_info("TCP_TX_DELAY enabled\n"); + } + } +} + +/* When set indicates to always queue non-full frames. Later the user clears + * this option and we transmit any pending partial frames in the queue. This is + * meant to be used alongside sendfile() to get properly filled frames when the + * user (for example) must write out headers with a write() call first and then + * use sendfile to send out the data parts. + * + * TCP_CORK can be set together with TCP_NODELAY and it is stronger than + * TCP_NODELAY. 
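+ *
+ * A typical userspace pattern (illustrative sketch only; error handling
+ * and the buffer/file names are omitted or invented):
+ *
+ *	int on = 1, off = 0;
+ *	setsockopt(fd, IPPROTO_TCP, TCP_CORK, &on, sizeof(on));
+ *	write(fd, header, header_len);          // queued, not yet pushed
+ *	sendfile(fd, file_fd, NULL, file_len);  // coalesced with the header
+ *	setsockopt(fd, IPPROTO_TCP, TCP_CORK, &off, sizeof(off));  // flush
+ *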
+ */ +void __tcp_sock_set_cork(struct sock *sk, bool on) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (on) { + tp->nonagle |= TCP_NAGLE_CORK; + } else { + tp->nonagle &= ~TCP_NAGLE_CORK; + if (tp->nonagle & TCP_NAGLE_OFF) + tp->nonagle |= TCP_NAGLE_PUSH; + tcp_push_pending_frames(sk); + } +} + +void tcp_sock_set_cork(struct sock *sk, bool on) +{ + lock_sock(sk); + __tcp_sock_set_cork(sk, on); + release_sock(sk); +} +EXPORT_SYMBOL(tcp_sock_set_cork); + +/* TCP_NODELAY is weaker than TCP_CORK, so that this option on corked socket is + * remembered, but it is not activated until cork is cleared. + * + * However, when TCP_NODELAY is set we make an explicit push, which overrides + * even TCP_CORK for currently queued segments. + */ +void __tcp_sock_set_nodelay(struct sock *sk, bool on) +{ + if (on) { + tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF|TCP_NAGLE_PUSH; + tcp_push_pending_frames(sk); + } else { + tcp_sk(sk)->nonagle &= ~TCP_NAGLE_OFF; + } +} + +void tcp_sock_set_nodelay(struct sock *sk) +{ + lock_sock(sk); + __tcp_sock_set_nodelay(sk, true); + release_sock(sk); +} +EXPORT_SYMBOL(tcp_sock_set_nodelay); + +static void __tcp_sock_set_quickack(struct sock *sk, int val) +{ + if (!val) { + inet_csk_enter_pingpong_mode(sk); + return; + } + + inet_csk_exit_pingpong_mode(sk); + if ((1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT) && + inet_csk_ack_scheduled(sk)) { + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_PUSHED; + tcp_cleanup_rbuf(sk, 1); + if (!(val & 1)) + inet_csk_enter_pingpong_mode(sk); + } +} + +void tcp_sock_set_quickack(struct sock *sk, int val) +{ + lock_sock(sk); + __tcp_sock_set_quickack(sk, val); + release_sock(sk); +} +EXPORT_SYMBOL(tcp_sock_set_quickack); + +int tcp_sock_set_syncnt(struct sock *sk, int val) +{ + if (val < 1 || val > MAX_TCP_SYNCNT) + return -EINVAL; + + lock_sock(sk); + WRITE_ONCE(inet_csk(sk)->icsk_syn_retries, val); + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(tcp_sock_set_syncnt); + +void tcp_sock_set_user_timeout(struct sock *sk, u32 val) +{ + lock_sock(sk); + WRITE_ONCE(inet_csk(sk)->icsk_user_timeout, val); + release_sock(sk); +} +EXPORT_SYMBOL(tcp_sock_set_user_timeout); + +int tcp_sock_set_keepidle_locked(struct sock *sk, int val) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (val < 1 || val > MAX_TCP_KEEPIDLE) + return -EINVAL; + + /* Paired with WRITE_ONCE() in keepalive_time_when() */ + WRITE_ONCE(tp->keepalive_time, val * HZ); + if (sock_flag(sk, SOCK_KEEPOPEN) && + !((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) { + u32 elapsed = keepalive_time_elapsed(tp); + + if (tp->keepalive_time > elapsed) + elapsed = tp->keepalive_time - elapsed; + else + elapsed = 0; + inet_csk_reset_keepalive_timer(sk, elapsed); + } + + return 0; +} + +int tcp_sock_set_keepidle(struct sock *sk, int val) +{ + int err; + + lock_sock(sk); + err = tcp_sock_set_keepidle_locked(sk, val); + release_sock(sk); + return err; +} +EXPORT_SYMBOL(tcp_sock_set_keepidle); + +int tcp_sock_set_keepintvl(struct sock *sk, int val) +{ + if (val < 1 || val > MAX_TCP_KEEPINTVL) + return -EINVAL; + + lock_sock(sk); + WRITE_ONCE(tcp_sk(sk)->keepalive_intvl, val * HZ); + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(tcp_sock_set_keepintvl); + +int tcp_sock_set_keepcnt(struct sock *sk, int val) +{ + if (val < 1 || val > MAX_TCP_KEEPCNT) + return -EINVAL; + + lock_sock(sk); + /* Paired with READ_ONCE() in keepalive_probes() */ + WRITE_ONCE(tcp_sk(sk)->keepalive_probes, val); + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(tcp_sock_set_keepcnt); + +int tcp_set_window_clamp(struct sock 
*sk, int val) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!val) { + if (sk->sk_state != TCP_CLOSE) + return -EINVAL; + tp->window_clamp = 0; + } else { + u32 new_rcv_ssthresh, old_window_clamp = tp->window_clamp; + u32 new_window_clamp = val < SOCK_MIN_RCVBUF / 2 ? + SOCK_MIN_RCVBUF / 2 : val; + + if (new_window_clamp == old_window_clamp) + return 0; + + tp->window_clamp = new_window_clamp; + if (new_window_clamp < old_window_clamp) { + /* need to apply the reserved mem provisioning only + * when shrinking the window clamp + */ + __tcp_adjust_rcv_ssthresh(sk, tp->window_clamp); + + } else { + new_rcv_ssthresh = min(tp->rcv_wnd, tp->window_clamp); + tp->rcv_ssthresh = max(new_rcv_ssthresh, + tp->rcv_ssthresh); + } + } + return 0; +} + +/* + * Socket option code for TCP. + */ +int do_tcp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + struct net *net = sock_net(sk); + int val; + int err = 0; + + /* These are data/string values, all the others are ints */ + switch (optname) { + case TCP_CONGESTION: { + char name[TCP_CA_NAME_MAX]; + + if (optlen < 1) + return -EINVAL; + + val = strncpy_from_sockptr(name, optval, + min_t(long, TCP_CA_NAME_MAX-1, optlen)); + if (val < 0) + return -EFAULT; + name[val] = 0; + + sockopt_lock_sock(sk); + err = tcp_set_congestion_control(sk, name, !has_current_bpf_ctx(), + sockopt_ns_capable(sock_net(sk)->user_ns, + CAP_NET_ADMIN)); + sockopt_release_sock(sk); + return err; + } + case TCP_ULP: { + char name[TCP_ULP_NAME_MAX]; + + if (optlen < 1) + return -EINVAL; + + val = strncpy_from_sockptr(name, optval, + min_t(long, TCP_ULP_NAME_MAX - 1, + optlen)); + if (val < 0) + return -EFAULT; + name[val] = 0; + + sockopt_lock_sock(sk); + err = tcp_set_ulp(sk, name); + sockopt_release_sock(sk); + return err; + } + case TCP_FASTOPEN_KEY: { + __u8 key[TCP_FASTOPEN_KEY_BUF_LENGTH]; + __u8 *backup_key = NULL; + + /* Allow a backup key as well to facilitate key rotation + * First key is the active one. + */ + if (optlen != TCP_FASTOPEN_KEY_LENGTH && + optlen != TCP_FASTOPEN_KEY_BUF_LENGTH) + return -EINVAL; + + if (copy_from_sockptr(key, optval, optlen)) + return -EFAULT; + + if (optlen == TCP_FASTOPEN_KEY_BUF_LENGTH) + backup_key = key + TCP_FASTOPEN_KEY_LENGTH; + + return tcp_fastopen_reset_cipher(net, sk, key, backup_key); + } + default: + /* fallthru */ + break; + } + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + sockopt_lock_sock(sk); + + switch (optname) { + case TCP_MAXSEG: + /* Values greater than interface MTU won't take effect. 
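+ *
+ * A userspace sketch of how this is normally driven (1400 is only an
+ * example value; it must be 0 or fall in the TCP_MIN_MSS..MAX_TCP_WINDOW
+ * range checked below):
+ *
+ *	int mss = 1400;
+ *	setsockopt(fd, IPPROTO_TCP, TCP_MAXSEG, &mss, sizeof(mss));
+ *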
However + * at the point when this call is done we typically don't yet + * know which interface is going to be used + */ + if (val && (val < TCP_MIN_MSS || val > MAX_TCP_WINDOW)) { + err = -EINVAL; + break; + } + tp->rx_opt.user_mss = val; + break; + + case TCP_NODELAY: + __tcp_sock_set_nodelay(sk, val); + break; + + case TCP_THIN_LINEAR_TIMEOUTS: + if (val < 0 || val > 1) + err = -EINVAL; + else + tp->thin_lto = val; + break; + + case TCP_THIN_DUPACK: + if (val < 0 || val > 1) + err = -EINVAL; + break; + + case TCP_REPAIR: + if (!tcp_can_repair_sock(sk)) + err = -EPERM; + else if (val == TCP_REPAIR_ON) { + tp->repair = 1; + sk->sk_reuse = SK_FORCE_REUSE; + tp->repair_queue = TCP_NO_QUEUE; + } else if (val == TCP_REPAIR_OFF) { + tp->repair = 0; + sk->sk_reuse = SK_NO_REUSE; + tcp_send_window_probe(sk); + } else if (val == TCP_REPAIR_OFF_NO_WP) { + tp->repair = 0; + sk->sk_reuse = SK_NO_REUSE; + } else + err = -EINVAL; + + break; + + case TCP_REPAIR_QUEUE: + if (!tp->repair) + err = -EPERM; + else if ((unsigned int)val < TCP_QUEUES_NR) + tp->repair_queue = val; + else + err = -EINVAL; + break; + + case TCP_QUEUE_SEQ: + if (sk->sk_state != TCP_CLOSE) { + err = -EPERM; + } else if (tp->repair_queue == TCP_SEND_QUEUE) { + if (!tcp_rtx_queue_empty(sk)) + err = -EPERM; + else + WRITE_ONCE(tp->write_seq, val); + } else if (tp->repair_queue == TCP_RECV_QUEUE) { + if (tp->rcv_nxt != tp->copied_seq) { + err = -EPERM; + } else { + WRITE_ONCE(tp->rcv_nxt, val); + WRITE_ONCE(tp->copied_seq, val); + } + } else { + err = -EINVAL; + } + break; + + case TCP_REPAIR_OPTIONS: + if (!tp->repair) + err = -EINVAL; + else if (sk->sk_state == TCP_ESTABLISHED && !tp->bytes_sent) + err = tcp_repair_options_est(sk, optval, optlen); + else + err = -EPERM; + break; + + case TCP_CORK: + __tcp_sock_set_cork(sk, val); + break; + + case TCP_KEEPIDLE: + err = tcp_sock_set_keepidle_locked(sk, val); + break; + case TCP_KEEPINTVL: + if (val < 1 || val > MAX_TCP_KEEPINTVL) + err = -EINVAL; + else + WRITE_ONCE(tp->keepalive_intvl, val * HZ); + break; + case TCP_KEEPCNT: + if (val < 1 || val > MAX_TCP_KEEPCNT) + err = -EINVAL; + else + WRITE_ONCE(tp->keepalive_probes, val); + break; + case TCP_SYNCNT: + if (val < 1 || val > MAX_TCP_SYNCNT) + err = -EINVAL; + else + WRITE_ONCE(icsk->icsk_syn_retries, val); + break; + + case TCP_SAVE_SYN: + /* 0: disable, 1: enable, 2: start from ether_header */ + if (val < 0 || val > 2) + err = -EINVAL; + else + tp->save_syn = val; + break; + + case TCP_LINGER2: + if (val < 0) + WRITE_ONCE(tp->linger2, -1); + else if (val > TCP_FIN_TIMEOUT_MAX / HZ) + WRITE_ONCE(tp->linger2, TCP_FIN_TIMEOUT_MAX); + else + WRITE_ONCE(tp->linger2, val * HZ); + break; + + case TCP_DEFER_ACCEPT: + /* Translate value in seconds to number of retransmits */ + WRITE_ONCE(icsk->icsk_accept_queue.rskq_defer_accept, + secs_to_retrans(val, TCP_TIMEOUT_INIT / HZ, + TCP_RTO_MAX / HZ)); + break; + + case TCP_WINDOW_CLAMP: + err = tcp_set_window_clamp(sk, val); + break; + + case TCP_QUICKACK: + __tcp_sock_set_quickack(sk, val); + break; + +#ifdef CONFIG_TCP_MD5SIG + case TCP_MD5SIG: + case TCP_MD5SIG_EXT: + err = tp->af_specific->md5_parse(sk, optname, optval, optlen); + break; +#endif + case TCP_USER_TIMEOUT: + /* Cap the max time in ms TCP will retry or probe the window + * before giving up and aborting (ETIMEDOUT) a connection. 
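+ *
+ * Userspace sketch (the 30000 ms value is only an example):
+ *
+ *	unsigned int tmo_ms = 30000;
+ *	setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT,
+ *		   &tmo_ms, sizeof(tmo_ms));
+ *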
+ */ + if (val < 0) + err = -EINVAL; + else + WRITE_ONCE(icsk->icsk_user_timeout, val); + break; + + case TCP_FASTOPEN: + if (val >= 0 && ((1 << sk->sk_state) & (TCPF_CLOSE | + TCPF_LISTEN))) { + tcp_fastopen_init_key_once(net); + + fastopen_queue_tune(sk, val); + } else { + err = -EINVAL; + } + break; + case TCP_FASTOPEN_CONNECT: + if (val > 1 || val < 0) { + err = -EINVAL; + } else if (READ_ONCE(net->ipv4.sysctl_tcp_fastopen) & + TFO_CLIENT_ENABLE) { + if (sk->sk_state == TCP_CLOSE) + tp->fastopen_connect = val; + else + err = -EINVAL; + } else { + err = -EOPNOTSUPP; + } + break; + case TCP_FASTOPEN_NO_COOKIE: + if (val > 1 || val < 0) + err = -EINVAL; + else if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) + err = -EINVAL; + else + tp->fastopen_no_cookie = val; + break; + case TCP_TIMESTAMP: + if (!tp->repair) + err = -EPERM; + else + WRITE_ONCE(tp->tsoffset, val - tcp_time_stamp_raw()); + break; + case TCP_REPAIR_WINDOW: + err = tcp_repair_set_window(tp, optval, optlen); + break; + case TCP_NOTSENT_LOWAT: + WRITE_ONCE(tp->notsent_lowat, val); + sk->sk_write_space(sk); + break; + case TCP_INQ: + if (val > 1 || val < 0) + err = -EINVAL; + else + tp->recvmsg_inq = val; + break; + case TCP_TX_DELAY: + if (val) + tcp_enable_tx_delay(); + WRITE_ONCE(tp->tcp_tx_delay, val); + break; + default: + err = -ENOPROTOOPT; + break; + } + + sockopt_release_sock(sk); + return err; +} + +int tcp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (level != SOL_TCP) + /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */ + return READ_ONCE(icsk->icsk_af_ops)->setsockopt(sk, level, optname, + optval, optlen); + return do_tcp_setsockopt(sk, level, optname, optval, optlen); +} +EXPORT_SYMBOL(tcp_setsockopt); + +static void tcp_get_info_chrono_stats(const struct tcp_sock *tp, + struct tcp_info *info) +{ + u64 stats[__TCP_CHRONO_MAX], total = 0; + enum tcp_chrono i; + + for (i = TCP_CHRONO_BUSY; i < __TCP_CHRONO_MAX; ++i) { + stats[i] = tp->chrono_stat[i - 1]; + if (i == tp->chrono_type) + stats[i] += tcp_jiffies32 - tp->chrono_start; + stats[i] *= USEC_PER_SEC / HZ; + total += stats[i]; + } + + info->tcpi_busy_time = total; + info->tcpi_rwnd_limited = stats[TCP_CHRONO_RWND_LIMITED]; + info->tcpi_sndbuf_limited = stats[TCP_CHRONO_SNDBUF_LIMITED]; +} + +/* Return information about state of tcp endpoint in API format. */ +void tcp_get_info(struct sock *sk, struct tcp_info *info) +{ + const struct tcp_sock *tp = tcp_sk(sk); /* iff sk_type == SOCK_STREAM */ + const struct inet_connection_sock *icsk = inet_csk(sk); + unsigned long rate; + u32 now; + u64 rate64; + bool slow; + + memset(info, 0, sizeof(*info)); + if (sk->sk_type != SOCK_STREAM) + return; + + info->tcpi_state = inet_sk_state_load(sk); + + /* Report meaningful fields for all TCP states, including listeners */ + rate = READ_ONCE(sk->sk_pacing_rate); + rate64 = (rate != ~0UL) ? rate : ~0ULL; + info->tcpi_pacing_rate = rate64; + + rate = READ_ONCE(sk->sk_max_pacing_rate); + rate64 = (rate != ~0UL) ? 
rate : ~0ULL; + info->tcpi_max_pacing_rate = rate64; + + info->tcpi_reordering = tp->reordering; + info->tcpi_snd_cwnd = tcp_snd_cwnd(tp); + + if (info->tcpi_state == TCP_LISTEN) { + /* listeners aliased fields : + * tcpi_unacked -> Number of children ready for accept() + * tcpi_sacked -> max backlog + */ + info->tcpi_unacked = READ_ONCE(sk->sk_ack_backlog); + info->tcpi_sacked = READ_ONCE(sk->sk_max_ack_backlog); + return; + } + + slow = lock_sock_fast(sk); + + info->tcpi_ca_state = icsk->icsk_ca_state; + info->tcpi_retransmits = icsk->icsk_retransmits; + info->tcpi_probes = icsk->icsk_probes_out; + info->tcpi_backoff = icsk->icsk_backoff; + + if (tp->rx_opt.tstamp_ok) + info->tcpi_options |= TCPI_OPT_TIMESTAMPS; + if (tcp_is_sack(tp)) + info->tcpi_options |= TCPI_OPT_SACK; + if (tp->rx_opt.wscale_ok) { + info->tcpi_options |= TCPI_OPT_WSCALE; + info->tcpi_snd_wscale = tp->rx_opt.snd_wscale; + info->tcpi_rcv_wscale = tp->rx_opt.rcv_wscale; + } + + if (tp->ecn_flags & TCP_ECN_OK) + info->tcpi_options |= TCPI_OPT_ECN; + if (tp->ecn_flags & TCP_ECN_SEEN) + info->tcpi_options |= TCPI_OPT_ECN_SEEN; + if (tp->syn_data_acked) + info->tcpi_options |= TCPI_OPT_SYN_DATA; + + info->tcpi_rto = jiffies_to_usecs(icsk->icsk_rto); + info->tcpi_ato = jiffies_to_usecs(icsk->icsk_ack.ato); + info->tcpi_snd_mss = tp->mss_cache; + info->tcpi_rcv_mss = icsk->icsk_ack.rcv_mss; + + info->tcpi_unacked = tp->packets_out; + info->tcpi_sacked = tp->sacked_out; + + info->tcpi_lost = tp->lost_out; + info->tcpi_retrans = tp->retrans_out; + + now = tcp_jiffies32; + info->tcpi_last_data_sent = jiffies_to_msecs(now - tp->lsndtime); + info->tcpi_last_data_recv = jiffies_to_msecs(now - icsk->icsk_ack.lrcvtime); + info->tcpi_last_ack_recv = jiffies_to_msecs(now - tp->rcv_tstamp); + + info->tcpi_pmtu = icsk->icsk_pmtu_cookie; + info->tcpi_rcv_ssthresh = tp->rcv_ssthresh; + info->tcpi_rtt = tp->srtt_us >> 3; + info->tcpi_rttvar = tp->mdev_us >> 2; + info->tcpi_snd_ssthresh = tp->snd_ssthresh; + info->tcpi_advmss = tp->advmss; + + info->tcpi_rcv_rtt = tp->rcv_rtt_est.rtt_us >> 3; + info->tcpi_rcv_space = tp->rcvq_space.space; + + info->tcpi_total_retrans = tp->total_retrans; + + info->tcpi_bytes_acked = tp->bytes_acked; + info->tcpi_bytes_received = tp->bytes_received; + info->tcpi_notsent_bytes = max_t(int, 0, tp->write_seq - tp->snd_nxt); + tcp_get_info_chrono_stats(tp, info); + + info->tcpi_segs_out = tp->segs_out; + + /* segs_in and data_segs_in can be updated from tcp_segs_in() from BH */ + info->tcpi_segs_in = READ_ONCE(tp->segs_in); + info->tcpi_data_segs_in = READ_ONCE(tp->data_segs_in); + + info->tcpi_min_rtt = tcp_min_rtt(tp); + info->tcpi_data_segs_out = tp->data_segs_out; + + info->tcpi_delivery_rate_app_limited = tp->rate_app_limited ? 
1 : 0; + rate64 = tcp_compute_delivery_rate(tp); + if (rate64) + info->tcpi_delivery_rate = rate64; + info->tcpi_delivered = tp->delivered; + info->tcpi_delivered_ce = tp->delivered_ce; + info->tcpi_bytes_sent = tp->bytes_sent; + info->tcpi_bytes_retrans = tp->bytes_retrans; + info->tcpi_dsack_dups = tp->dsack_dups; + info->tcpi_reord_seen = tp->reord_seen; + info->tcpi_rcv_ooopack = tp->rcv_ooopack; + info->tcpi_snd_wnd = tp->snd_wnd; + info->tcpi_fastopen_client_fail = tp->fastopen_client_fail; + unlock_sock_fast(sk, slow); +} +EXPORT_SYMBOL_GPL(tcp_get_info); + +static size_t tcp_opt_stats_get_size(void) +{ + return + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_BUSY */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_RWND_LIMITED */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_SNDBUF_LIMITED */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_DATA_SEGS_OUT */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_TOTAL_RETRANS */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_PACING_RATE */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_DELIVERY_RATE */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_SND_CWND */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_REORDERING */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_MIN_RTT */ + nla_total_size(sizeof(u8)) + /* TCP_NLA_RECUR_RETRANS */ + nla_total_size(sizeof(u8)) + /* TCP_NLA_DELIVERY_RATE_APP_LMT */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_SNDQ_SIZE */ + nla_total_size(sizeof(u8)) + /* TCP_NLA_CA_STATE */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_SND_SSTHRESH */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_DELIVERED */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_DELIVERED_CE */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_BYTES_SENT */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_BYTES_RETRANS */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_DSACK_DUPS */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_REORD_SEEN */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_SRTT */ + nla_total_size(sizeof(u16)) + /* TCP_NLA_TIMEOUT_REHASH */ + nla_total_size(sizeof(u32)) + /* TCP_NLA_BYTES_NOTSENT */ + nla_total_size_64bit(sizeof(u64)) + /* TCP_NLA_EDT */ + nla_total_size(sizeof(u8)) + /* TCP_NLA_TTL */ + 0; +} + +/* Returns TTL or hop limit of an incoming packet from skb. */ +static u8 tcp_skb_ttl_or_hop_limit(const struct sk_buff *skb) +{ + if (skb->protocol == htons(ETH_P_IP)) + return ip_hdr(skb)->ttl; + else if (skb->protocol == htons(ETH_P_IPV6)) + return ipv6_hdr(skb)->hop_limit; + else + return 0; +} + +struct sk_buff *tcp_get_timestamping_opt_stats(const struct sock *sk, + const struct sk_buff *orig_skb, + const struct sk_buff *ack_skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *stats; + struct tcp_info info; + unsigned long rate; + u64 rate64; + + stats = alloc_skb(tcp_opt_stats_get_size(), GFP_ATOMIC); + if (!stats) + return NULL; + + tcp_get_info_chrono_stats(tp, &info); + nla_put_u64_64bit(stats, TCP_NLA_BUSY, + info.tcpi_busy_time, TCP_NLA_PAD); + nla_put_u64_64bit(stats, TCP_NLA_RWND_LIMITED, + info.tcpi_rwnd_limited, TCP_NLA_PAD); + nla_put_u64_64bit(stats, TCP_NLA_SNDBUF_LIMITED, + info.tcpi_sndbuf_limited, TCP_NLA_PAD); + nla_put_u64_64bit(stats, TCP_NLA_DATA_SEGS_OUT, + tp->data_segs_out, TCP_NLA_PAD); + nla_put_u64_64bit(stats, TCP_NLA_TOTAL_RETRANS, + tp->total_retrans, TCP_NLA_PAD); + + rate = READ_ONCE(sk->sk_pacing_rate); + rate64 = (rate != ~0UL) ? 
rate : ~0ULL; + nla_put_u64_64bit(stats, TCP_NLA_PACING_RATE, rate64, TCP_NLA_PAD); + + rate64 = tcp_compute_delivery_rate(tp); + nla_put_u64_64bit(stats, TCP_NLA_DELIVERY_RATE, rate64, TCP_NLA_PAD); + + nla_put_u32(stats, TCP_NLA_SND_CWND, tcp_snd_cwnd(tp)); + nla_put_u32(stats, TCP_NLA_REORDERING, tp->reordering); + nla_put_u32(stats, TCP_NLA_MIN_RTT, tcp_min_rtt(tp)); + + nla_put_u8(stats, TCP_NLA_RECUR_RETRANS, inet_csk(sk)->icsk_retransmits); + nla_put_u8(stats, TCP_NLA_DELIVERY_RATE_APP_LMT, !!tp->rate_app_limited); + nla_put_u32(stats, TCP_NLA_SND_SSTHRESH, tp->snd_ssthresh); + nla_put_u32(stats, TCP_NLA_DELIVERED, tp->delivered); + nla_put_u32(stats, TCP_NLA_DELIVERED_CE, tp->delivered_ce); + + nla_put_u32(stats, TCP_NLA_SNDQ_SIZE, tp->write_seq - tp->snd_una); + nla_put_u8(stats, TCP_NLA_CA_STATE, inet_csk(sk)->icsk_ca_state); + + nla_put_u64_64bit(stats, TCP_NLA_BYTES_SENT, tp->bytes_sent, + TCP_NLA_PAD); + nla_put_u64_64bit(stats, TCP_NLA_BYTES_RETRANS, tp->bytes_retrans, + TCP_NLA_PAD); + nla_put_u32(stats, TCP_NLA_DSACK_DUPS, tp->dsack_dups); + nla_put_u32(stats, TCP_NLA_REORD_SEEN, tp->reord_seen); + nla_put_u32(stats, TCP_NLA_SRTT, tp->srtt_us >> 3); + nla_put_u16(stats, TCP_NLA_TIMEOUT_REHASH, tp->timeout_rehash); + nla_put_u32(stats, TCP_NLA_BYTES_NOTSENT, + max_t(int, 0, tp->write_seq - tp->snd_nxt)); + nla_put_u64_64bit(stats, TCP_NLA_EDT, orig_skb->skb_mstamp_ns, + TCP_NLA_PAD); + if (ack_skb) + nla_put_u8(stats, TCP_NLA_TTL, + tcp_skb_ttl_or_hop_limit(ack_skb)); + + return stats; +} + +int do_tcp_getsockopt(struct sock *sk, int level, + int optname, sockptr_t optval, sockptr_t optlen) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + int val, len; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(int)); + + if (len < 0) + return -EINVAL; + + switch (optname) { + case TCP_MAXSEG: + val = tp->mss_cache; + if (tp->rx_opt.user_mss && + ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) + val = tp->rx_opt.user_mss; + if (tp->repair) + val = tp->rx_opt.mss_clamp; + break; + case TCP_NODELAY: + val = !!(tp->nonagle&TCP_NAGLE_OFF); + break; + case TCP_CORK: + val = !!(tp->nonagle&TCP_NAGLE_CORK); + break; + case TCP_KEEPIDLE: + val = keepalive_time_when(tp) / HZ; + break; + case TCP_KEEPINTVL: + val = keepalive_intvl_when(tp) / HZ; + break; + case TCP_KEEPCNT: + val = keepalive_probes(tp); + break; + case TCP_SYNCNT: + val = READ_ONCE(icsk->icsk_syn_retries) ? : + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries); + break; + case TCP_LINGER2: + val = READ_ONCE(tp->linger2); + if (val >= 0) + val = (val ? 
: READ_ONCE(net->ipv4.sysctl_tcp_fin_timeout)) / HZ; + break; + case TCP_DEFER_ACCEPT: + val = READ_ONCE(icsk->icsk_accept_queue.rskq_defer_accept); + val = retrans_to_secs(val, TCP_TIMEOUT_INIT / HZ, + TCP_RTO_MAX / HZ); + break; + case TCP_WINDOW_CLAMP: + val = tp->window_clamp; + break; + case TCP_INFO: { + struct tcp_info info; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + tcp_get_info(sk, &info); + + len = min_t(unsigned int, len, sizeof(info)); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &info, len)) + return -EFAULT; + return 0; + } + case TCP_CC_INFO: { + const struct tcp_congestion_ops *ca_ops; + union tcp_cc_info info; + size_t sz = 0; + int attr; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + ca_ops = icsk->icsk_ca_ops; + if (ca_ops && ca_ops->get_info) + sz = ca_ops->get_info(sk, ~0U, &attr, &info); + + len = min_t(unsigned int, len, sz); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &info, len)) + return -EFAULT; + return 0; + } + case TCP_QUICKACK: + val = !inet_csk_in_pingpong_mode(sk); + break; + + case TCP_CONGESTION: + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + len = min_t(unsigned int, len, TCP_CA_NAME_MAX); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, icsk->icsk_ca_ops->name, len)) + return -EFAULT; + return 0; + + case TCP_ULP: + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + len = min_t(unsigned int, len, TCP_ULP_NAME_MAX); + if (!icsk->icsk_ulp_ops) { + len = 0; + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + return 0; + } + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, icsk->icsk_ulp_ops->name, len)) + return -EFAULT; + return 0; + + case TCP_FASTOPEN_KEY: { + u64 key[TCP_FASTOPEN_KEY_BUF_LENGTH / sizeof(u64)]; + unsigned int key_len; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + key_len = tcp_fastopen_get_cipher(net, icsk, key) * + TCP_FASTOPEN_KEY_LENGTH; + len = min_t(unsigned int, len, key_len); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, key, len)) + return -EFAULT; + return 0; + } + case TCP_THIN_LINEAR_TIMEOUTS: + val = tp->thin_lto; + break; + + case TCP_THIN_DUPACK: + val = 0; + break; + + case TCP_REPAIR: + val = tp->repair; + break; + + case TCP_REPAIR_QUEUE: + if (tp->repair) + val = tp->repair_queue; + else + return -EINVAL; + break; + + case TCP_REPAIR_WINDOW: { + struct tcp_repair_window opt; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + if (len != sizeof(opt)) + return -EINVAL; + + if (!tp->repair) + return -EPERM; + + opt.snd_wl1 = tp->snd_wl1; + opt.snd_wnd = tp->snd_wnd; + opt.max_window = tp->max_window; + opt.rcv_wnd = tp->rcv_wnd; + opt.rcv_wup = tp->rcv_wup; + + if (copy_to_sockptr(optval, &opt, len)) + return -EFAULT; + return 0; + } + case TCP_QUEUE_SEQ: + if (tp->repair_queue == TCP_SEND_QUEUE) + val = tp->write_seq; + else if (tp->repair_queue == TCP_RECV_QUEUE) + val = tp->rcv_nxt; + else + return -EINVAL; + break; + + case TCP_USER_TIMEOUT: + val = READ_ONCE(icsk->icsk_user_timeout); + break; + + case TCP_FASTOPEN: + val = READ_ONCE(icsk->icsk_accept_queue.fastopenq.max_qlen); + break; + + case TCP_FASTOPEN_CONNECT: + val = tp->fastopen_connect; + break; + + case TCP_FASTOPEN_NO_COOKIE: + val 
= tp->fastopen_no_cookie; + break; + + case TCP_TX_DELAY: + val = READ_ONCE(tp->tcp_tx_delay); + break; + + case TCP_TIMESTAMP: + val = tcp_time_stamp_raw() + READ_ONCE(tp->tsoffset); + break; + case TCP_NOTSENT_LOWAT: + val = READ_ONCE(tp->notsent_lowat); + break; + case TCP_INQ: + val = tp->recvmsg_inq; + break; + case TCP_SAVE_SYN: + val = tp->save_syn; + break; + case TCP_SAVED_SYN: { + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + + sockopt_lock_sock(sk); + if (tp->saved_syn) { + if (len < tcp_saved_syn_len(tp->saved_syn)) { + len = tcp_saved_syn_len(tp->saved_syn); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); + return -EFAULT; + } + sockopt_release_sock(sk); + return -EINVAL; + } + len = tcp_saved_syn_len(tp->saved_syn); + if (copy_to_sockptr(optlen, &len, sizeof(int))) { + sockopt_release_sock(sk); + return -EFAULT; + } + if (copy_to_sockptr(optval, tp->saved_syn->data, len)) { + sockopt_release_sock(sk); + return -EFAULT; + } + tcp_saved_syn_free(tp); + sockopt_release_sock(sk); + } else { + sockopt_release_sock(sk); + len = 0; + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + } + return 0; + } +#ifdef CONFIG_MMU + case TCP_ZEROCOPY_RECEIVE: { + struct scm_timestamping_internal tss; + struct tcp_zerocopy_receive zc = {}; + int err; + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + if (len < 0 || + len < offsetofend(struct tcp_zerocopy_receive, length)) + return -EINVAL; + if (unlikely(len > sizeof(zc))) { + err = check_zeroed_sockptr(optval, sizeof(zc), + len - sizeof(zc)); + if (err < 1) + return err == 0 ? -EINVAL : err; + len = sizeof(zc); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + } + if (copy_from_sockptr(&zc, optval, len)) + return -EFAULT; + if (zc.reserved) + return -EINVAL; + if (zc.msg_flags & ~(TCP_VALID_ZC_MSG_FLAGS)) + return -EINVAL; + sockopt_lock_sock(sk); + err = tcp_zerocopy_receive(sk, &zc, &tss); + err = BPF_CGROUP_RUN_PROG_GETSOCKOPT_KERN(sk, level, optname, + &zc, &len, err); + sockopt_release_sock(sk); + if (len >= offsetofend(struct tcp_zerocopy_receive, msg_flags)) + goto zerocopy_rcv_cmsg; + switch (len) { + case offsetofend(struct tcp_zerocopy_receive, msg_flags): + goto zerocopy_rcv_cmsg; + case offsetofend(struct tcp_zerocopy_receive, msg_controllen): + case offsetofend(struct tcp_zerocopy_receive, msg_control): + case offsetofend(struct tcp_zerocopy_receive, flags): + case offsetofend(struct tcp_zerocopy_receive, copybuf_len): + case offsetofend(struct tcp_zerocopy_receive, copybuf_address): + case offsetofend(struct tcp_zerocopy_receive, err): + goto zerocopy_rcv_sk_err; + case offsetofend(struct tcp_zerocopy_receive, inq): + goto zerocopy_rcv_inq; + case offsetofend(struct tcp_zerocopy_receive, length): + default: + goto zerocopy_rcv_out; + } +zerocopy_rcv_cmsg: + if (zc.msg_flags & TCP_CMSG_TS) + tcp_zc_finalize_rx_tstamp(sk, &zc, &tss); + else + zc.msg_flags = 0; +zerocopy_rcv_sk_err: + if (!err) + zc.err = sock_error(sk); +zerocopy_rcv_inq: + zc.inq = tcp_inq_hint(sk); +zerocopy_rcv_out: + if (!err && copy_to_sockptr(optval, &zc, len)) + err = -EFAULT; + return err; + } +#endif + default: + return -ENOPROTOOPT; + } + + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &val, len)) + return -EFAULT; + return 0; +} + +bool tcp_bpf_bypass_getsockopt(int level, int optname) +{ + /* TCP do_tcp_getsockopt has optimized getsockopt implementation + * to avoid extra socket lock for 
TCP_ZEROCOPY_RECEIVE. + */ + if (level == SOL_TCP && optname == TCP_ZEROCOPY_RECEIVE) + return true; + + return false; +} +EXPORT_SYMBOL(tcp_bpf_bypass_getsockopt); + +int tcp_getsockopt(struct sock *sk, int level, int optname, char __user *optval, + int __user *optlen) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + if (level != SOL_TCP) + /* Paired with WRITE_ONCE() in do_ipv6_setsockopt() and tcp_v6_connect() */ + return READ_ONCE(icsk->icsk_af_ops)->getsockopt(sk, level, optname, + optval, optlen); + return do_tcp_getsockopt(sk, level, optname, USER_SOCKPTR(optval), + USER_SOCKPTR(optlen)); +} +EXPORT_SYMBOL(tcp_getsockopt); + +#ifdef CONFIG_TCP_MD5SIG +static DEFINE_PER_CPU(struct tcp_md5sig_pool, tcp_md5sig_pool); +static DEFINE_MUTEX(tcp_md5sig_mutex); +static bool tcp_md5sig_pool_populated = false; + +static void __tcp_alloc_md5sig_pool(void) +{ + struct crypto_ahash *hash; + int cpu; + + hash = crypto_alloc_ahash("md5", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(hash)) + return; + + for_each_possible_cpu(cpu) { + void *scratch = per_cpu(tcp_md5sig_pool, cpu).scratch; + struct ahash_request *req; + + if (!scratch) { + scratch = kmalloc_node(sizeof(union tcp_md5sum_block) + + sizeof(struct tcphdr), + GFP_KERNEL, + cpu_to_node(cpu)); + if (!scratch) + return; + per_cpu(tcp_md5sig_pool, cpu).scratch = scratch; + } + if (per_cpu(tcp_md5sig_pool, cpu).md5_req) + continue; + + req = ahash_request_alloc(hash, GFP_KERNEL); + if (!req) + return; + + ahash_request_set_callback(req, 0, NULL, NULL); + + per_cpu(tcp_md5sig_pool, cpu).md5_req = req; + } + /* before setting tcp_md5sig_pool_populated, we must commit all writes + * to memory. See smp_rmb() in tcp_get_md5sig_pool() + */ + smp_wmb(); + /* Paired with READ_ONCE() from tcp_alloc_md5sig_pool() + * and tcp_get_md5sig_pool(). + */ + WRITE_ONCE(tcp_md5sig_pool_populated, true); +} + +bool tcp_alloc_md5sig_pool(void) +{ + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) { + mutex_lock(&tcp_md5sig_mutex); + + if (!tcp_md5sig_pool_populated) { + __tcp_alloc_md5sig_pool(); + if (tcp_md5sig_pool_populated) + static_branch_inc(&tcp_md5_needed); + } + + mutex_unlock(&tcp_md5sig_mutex); + } + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + return READ_ONCE(tcp_md5sig_pool_populated); +} +EXPORT_SYMBOL(tcp_alloc_md5sig_pool); + + +/** + * tcp_get_md5sig_pool - get md5sig_pool for this user + * + * We use percpu structure, so if we succeed, we exit with preemption + * and BH disabled, to make sure another thread or softirq handling + * wont try to get same context. + */ +struct tcp_md5sig_pool *tcp_get_md5sig_pool(void) +{ + local_bh_disable(); + + /* Paired with WRITE_ONCE() from __tcp_alloc_md5sig_pool() */ + if (READ_ONCE(tcp_md5sig_pool_populated)) { + /* coupled with smp_wmb() in __tcp_alloc_md5sig_pool() */ + smp_rmb(); + return this_cpu_ptr(&tcp_md5sig_pool); + } + local_bh_enable(); + return NULL; +} +EXPORT_SYMBOL(tcp_get_md5sig_pool); + +int tcp_md5_hash_skb_data(struct tcp_md5sig_pool *hp, + const struct sk_buff *skb, unsigned int header_len) +{ + struct scatterlist sg; + const struct tcphdr *tp = tcp_hdr(skb); + struct ahash_request *req = hp->md5_req; + unsigned int i; + const unsigned int head_data_len = skb_headlen(skb) > header_len ? 
+ skb_headlen(skb) - header_len : 0; + const struct skb_shared_info *shi = skb_shinfo(skb); + struct sk_buff *frag_iter; + + sg_init_table(&sg, 1); + + sg_set_buf(&sg, ((u8 *) tp) + header_len, head_data_len); + ahash_request_set_crypt(req, &sg, NULL, head_data_len); + if (crypto_ahash_update(req)) + return 1; + + for (i = 0; i < shi->nr_frags; ++i) { + const skb_frag_t *f = &shi->frags[i]; + unsigned int offset = skb_frag_off(f); + struct page *page = skb_frag_page(f) + (offset >> PAGE_SHIFT); + + sg_set_page(&sg, page, skb_frag_size(f), + offset_in_page(offset)); + ahash_request_set_crypt(req, &sg, NULL, skb_frag_size(f)); + if (crypto_ahash_update(req)) + return 1; + } + + skb_walk_frags(skb, frag_iter) + if (tcp_md5_hash_skb_data(hp, frag_iter, 0)) + return 1; + + return 0; +} +EXPORT_SYMBOL(tcp_md5_hash_skb_data); + +int tcp_md5_hash_key(struct tcp_md5sig_pool *hp, const struct tcp_md5sig_key *key) +{ + u8 keylen = READ_ONCE(key->keylen); /* paired with WRITE_ONCE() in tcp_md5_do_add */ + struct scatterlist sg; + + sg_init_one(&sg, key->key, keylen); + ahash_request_set_crypt(hp->md5_req, &sg, NULL, keylen); + + /* We use data_race() because tcp_md5_do_add() might change key->key under us */ + return data_race(crypto_ahash_update(hp->md5_req)); +} +EXPORT_SYMBOL(tcp_md5_hash_key); + +/* Called with rcu_read_lock() */ +enum skb_drop_reason +tcp_inbound_md5_hash(const struct sock *sk, const struct sk_buff *skb, + const void *saddr, const void *daddr, + int family, int dif, int sdif) +{ + /* + * This gets called for each TCP segment that arrives + * so we want to be efficient. + * We have 3 drop cases: + * o No MD5 hash and one expected. + * o MD5 hash and we're not expecting one. + * o MD5 hash and its wrong. + */ + const __u8 *hash_location = NULL; + struct tcp_md5sig_key *hash_expected; + const struct tcphdr *th = tcp_hdr(skb); + struct tcp_sock *tp = tcp_sk(sk); + int genhash, l3index; + u8 newhash[16]; + + /* sdif set, means packet ingressed via a device + * in an L3 domain and dif is set to the l3mdev + */ + l3index = sdif ? dif : 0; + + hash_expected = tcp_md5_do_lookup(sk, l3index, saddr, family); + hash_location = tcp_parse_md5sig_option(th); + + /* We've parsed the options - do we have a hash? */ + if (!hash_expected && !hash_location) + return SKB_NOT_DROPPED_YET; + + if (hash_expected && !hash_location) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5NOTFOUND); + return SKB_DROP_REASON_TCP_MD5NOTFOUND; + } + + if (!hash_expected && hash_location) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5UNEXPECTED); + return SKB_DROP_REASON_TCP_MD5UNEXPECTED; + } + + /* Check the signature. + * To support dual stack listeners, we need to handle + * IPv4-mapped case. + */ + if (family == AF_INET) + genhash = tcp_v4_md5_hash_skb(newhash, + hash_expected, + NULL, skb); + else + genhash = tp->af_specific->calc_md5_hash(newhash, + hash_expected, + NULL, skb); + + if (genhash || memcmp(hash_location, newhash, 16) != 0) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMD5FAILURE); + if (family == AF_INET) { + net_info_ratelimited("MD5 Hash failed for (%pI4, %d)->(%pI4, %d)%s L3 index %d\n", + saddr, ntohs(th->source), + daddr, ntohs(th->dest), + genhash ? " tcp_v4_calc_md5_hash failed" + : "", l3index); + } else { + net_info_ratelimited("MD5 Hash %s for [%pI6c]:%u->[%pI6c]:%u L3 index %d\n", + genhash ? 
"failed" : "mismatch", + saddr, ntohs(th->source), + daddr, ntohs(th->dest), l3index); + } + return SKB_DROP_REASON_TCP_MD5FAILURE; + } + return SKB_NOT_DROPPED_YET; +} +EXPORT_SYMBOL(tcp_inbound_md5_hash); + +#endif + +void tcp_done(struct sock *sk) +{ + struct request_sock *req; + + /* We might be called with a new socket, after + * inet_csk_prepare_forced_close() has been called + * so we can not use lockdep_sock_is_held(sk) + */ + req = rcu_dereference_protected(tcp_sk(sk)->fastopen_rsk, 1); + + if (sk->sk_state == TCP_SYN_SENT || sk->sk_state == TCP_SYN_RECV) + TCP_INC_STATS(sock_net(sk), TCP_MIB_ATTEMPTFAILS); + + tcp_set_state(sk, TCP_CLOSE); + tcp_clear_xmit_timers(sk); + if (req) + reqsk_fastopen_remove(sk, req, false); + + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + else + inet_csk_destroy_sock(sk); +} +EXPORT_SYMBOL_GPL(tcp_done); + +int tcp_abort(struct sock *sk, int err) +{ + int state = inet_sk_state_load(sk); + + if (state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + + local_bh_disable(); + inet_csk_reqsk_queue_drop(req->rsk_listener, req); + local_bh_enable(); + return 0; + } + if (state == TCP_TIME_WAIT) { + struct inet_timewait_sock *tw = inet_twsk(sk); + + refcount_inc(&tw->tw_refcnt); + local_bh_disable(); + inet_twsk_deschedule_put(tw); + local_bh_enable(); + return 0; + } + + /* Don't race with userspace socket closes such as tcp_close. */ + lock_sock(sk); + + if (sk->sk_state == TCP_LISTEN) { + tcp_set_state(sk, TCP_CLOSE); + inet_csk_listen_stop(sk); + } + + /* Don't race with BH socket closes such as inet_csk_listen_stop. */ + local_bh_disable(); + bh_lock_sock(sk); + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_err = err; + /* This barrier is coupled with smp_rmb() in tcp_poll() */ + smp_wmb(); + sk_error_report(sk); + if (tcp_need_reset(sk->sk_state)) + tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_done(sk); + } + + bh_unlock_sock(sk); + local_bh_enable(); + tcp_write_queue_purge(sk); + release_sock(sk); + return 0; +} +EXPORT_SYMBOL_GPL(tcp_abort); + +extern struct tcp_congestion_ops tcp_reno; + +static __initdata unsigned long thash_entries; +static int __init set_thash_entries(char *str) +{ + ssize_t ret; + + if (!str) + return 0; + + ret = kstrtoul(str, 0, &thash_entries); + if (ret) + return 0; + + return 1; +} +__setup("thash_entries=", set_thash_entries); + +static void __init tcp_init_mem(void) +{ + unsigned long limit = nr_free_buffer_pages() / 16; + + limit = max(limit, 128UL); + sysctl_tcp_mem[0] = limit / 4 * 3; /* 4.68 % */ + sysctl_tcp_mem[1] = limit; /* 6.25 % */ + sysctl_tcp_mem[2] = sysctl_tcp_mem[0] * 2; /* 9.37 % */ +} + +void __init tcp_init(void) +{ + int max_rshare, max_wshare, cnt; + unsigned long limit; + unsigned int i; + + BUILD_BUG_ON(TCP_MIN_SND_MSS <= MAX_TCP_OPTION_SPACE); + BUILD_BUG_ON(sizeof(struct tcp_skb_cb) > + sizeof_field(struct sk_buff, cb)); + + percpu_counter_init(&tcp_sockets_allocated, 0, GFP_KERNEL); + + timer_setup(&tcp_orphan_timer, tcp_orphan_update, TIMER_DEFERRABLE); + mod_timer(&tcp_orphan_timer, jiffies + TCP_ORPHAN_TIMER_PERIOD); + + inet_hashinfo2_init(&tcp_hashinfo, "tcp_listen_portaddr_hash", + thash_entries, 21, /* one slot per 2 MB*/ + 0, 64 * 1024); + tcp_hashinfo.bind_bucket_cachep = + kmem_cache_create("tcp_bind_bucket", + sizeof(struct inet_bind_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_PANIC | + SLAB_ACCOUNT, + NULL); + tcp_hashinfo.bind2_bucket_cachep = + kmem_cache_create("tcp_bind2_bucket", + sizeof(struct 
inet_bind2_bucket), 0, + SLAB_HWCACHE_ALIGN | SLAB_PANIC | + SLAB_ACCOUNT, + NULL); + + /* Size and allocate the main established and bind bucket + * hash tables. + * + * The methodology is similar to that of the buffer cache. + */ + tcp_hashinfo.ehash = + alloc_large_system_hash("TCP established", + sizeof(struct inet_ehash_bucket), + thash_entries, + 17, /* one slot per 128 KB of memory */ + 0, + NULL, + &tcp_hashinfo.ehash_mask, + 0, + thash_entries ? 0 : 512 * 1024); + for (i = 0; i <= tcp_hashinfo.ehash_mask; i++) + INIT_HLIST_NULLS_HEAD(&tcp_hashinfo.ehash[i].chain, i); + + if (inet_ehash_locks_alloc(&tcp_hashinfo)) + panic("TCP: failed to alloc ehash_locks"); + tcp_hashinfo.bhash = + alloc_large_system_hash("TCP bind", + 2 * sizeof(struct inet_bind_hashbucket), + tcp_hashinfo.ehash_mask + 1, + 17, /* one slot per 128 KB of memory */ + 0, + &tcp_hashinfo.bhash_size, + NULL, + 0, + 64 * 1024); + tcp_hashinfo.bhash_size = 1U << tcp_hashinfo.bhash_size; + tcp_hashinfo.bhash2 = tcp_hashinfo.bhash + tcp_hashinfo.bhash_size; + for (i = 0; i < tcp_hashinfo.bhash_size; i++) { + spin_lock_init(&tcp_hashinfo.bhash[i].lock); + INIT_HLIST_HEAD(&tcp_hashinfo.bhash[i].chain); + spin_lock_init(&tcp_hashinfo.bhash2[i].lock); + INIT_HLIST_HEAD(&tcp_hashinfo.bhash2[i].chain); + } + + tcp_hashinfo.pernet = false; + + cnt = tcp_hashinfo.ehash_mask + 1; + sysctl_tcp_max_orphans = cnt / 2; + + tcp_init_mem(); + /* Set per-socket limits to no more than 1/128 the pressure threshold */ + limit = nr_free_buffer_pages() << (PAGE_SHIFT - 7); + max_wshare = min(4UL*1024*1024, limit); + max_rshare = min(6UL*1024*1024, limit); + + init_net.ipv4.sysctl_tcp_wmem[0] = PAGE_SIZE; + init_net.ipv4.sysctl_tcp_wmem[1] = 16*1024; + init_net.ipv4.sysctl_tcp_wmem[2] = max(64*1024, max_wshare); + + init_net.ipv4.sysctl_tcp_rmem[0] = PAGE_SIZE; + init_net.ipv4.sysctl_tcp_rmem[1] = 131072; + init_net.ipv4.sysctl_tcp_rmem[2] = max(131072, max_rshare); + + pr_info("Hash tables configured (established %u bind %u)\n", + tcp_hashinfo.ehash_mask + 1, tcp_hashinfo.bhash_size); + + tcp_v4_init(); + tcp_metrics_init(); + BUG_ON(tcp_register_congestion_control(&tcp_reno) != 0); + tcp_tasklet_init(); + mptcp_init(); +} diff --git a/net/ipv4/tcp_bbr.c b/net/ipv4/tcp_bbr.c new file mode 100644 index 000000000..54eec33c6 --- /dev/null +++ b/net/ipv4/tcp_bbr.c @@ -0,0 +1,1202 @@ +/* Bottleneck Bandwidth and RTT (BBR) congestion control + * + * BBR congestion control computes the sending rate based on the delivery + * rate (throughput) estimated from ACKs. In a nutshell: + * + * On each ACK, update our model of the network path: + * bottleneck_bandwidth = windowed_max(delivered / elapsed, 10 round trips) + * min_rtt = windowed_min(rtt, 10 seconds) + * pacing_rate = pacing_gain * bottleneck_bandwidth + * cwnd = max(cwnd_gain * bottleneck_bandwidth * min_rtt, 4) + * + * The core algorithm does not react directly to packet losses or delays, + * although BBR may adjust the size of next send per ACK when loss is + * observed, or adjust the sending rate if it estimates there is a + * traffic policer, in order to keep the drop rate reasonable. + * + * Here is a state transition diagram for BBR: + * + * | + * V + * +---> STARTUP ----+ + * | | | + * | V | + * | DRAIN ----+ + * | | | + * | V | + * +---> PROBE_BW ----+ + * | ^ | | + * | | | | + * | +----+ | + * | | + * +---- PROBE_RTT <--+ + * + * A BBR flow starts in STARTUP, and ramps up its sending rate quickly. + * When it estimates the pipe is full, it enters DRAIN to drain the queue. 
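+ *
+ * As a rough worked example of the model above (the numbers are invented
+ * purely for illustration): with bottleneck_bandwidth ~= 10,000 pkts/sec
+ * and min_rtt ~= 10 ms,
+ *
+ *	pacing_rate ~= pacing_gain * 10,000 pkts/sec
+ *	cwnd        ~= cwnd_gain * 10,000 * 0.010 = cwnd_gain * 100 packets
+ *
+ * i.e. roughly two bandwidth-delay products of data in flight with the
+ * default cwnd_gain of 2.
+ *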
+ * In steady state a BBR flow only uses PROBE_BW and PROBE_RTT. + * A long-lived BBR flow spends the vast majority of its time remaining + * (repeatedly) in PROBE_BW, fully probing and utilizing the pipe's bandwidth + * in a fair manner, with a small, bounded queue. *If* a flow has been + * continuously sending for the entire min_rtt window, and hasn't seen an RTT + * sample that matches or decreases its min_rtt estimate for 10 seconds, then + * it briefly enters PROBE_RTT to cut inflight to a minimum value to re-probe + * the path's two-way propagation delay (min_rtt). When exiting PROBE_RTT, if + * we estimated that we reached the full bw of the pipe then we enter PROBE_BW; + * otherwise we enter STARTUP to try to fill the pipe. + * + * BBR is described in detail in: + * "BBR: Congestion-Based Congestion Control", + * Neal Cardwell, Yuchung Cheng, C. Stephen Gunn, Soheil Hassas Yeganeh, + * Van Jacobson. ACM Queue, Vol. 14 No. 5, September-October 2016. + * + * There is a public e-mail list for discussing BBR development and testing: + * https://groups.google.com/forum/#!forum/bbr-dev + * + * NOTE: BBR might be used with the fq qdisc ("man tc-fq") with pacing enabled, + * otherwise TCP stack falls back to an internal pacing using one high + * resolution timer per TCP socket and may use more resources. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +/* Scale factor for rate in pkt/uSec unit to avoid truncation in bandwidth + * estimation. The rate unit ~= (1500 bytes / 1 usec / 2^24) ~= 715 bps. + * This handles bandwidths from 0.06pps (715bps) to 256Mpps (3Tbps) in a u32. + * Since the minimum window is >=4 packets, the lower bound isn't + * an issue. The upper bound isn't an issue with existing technologies. + */ +#define BW_SCALE 24 +#define BW_UNIT (1 << BW_SCALE) + +#define BBR_SCALE 8 /* scaling factor for fractions in BBR (e.g. gains) */ +#define BBR_UNIT (1 << BBR_SCALE) + +/* BBR has the following modes for deciding how fast to send: */ +enum bbr_mode { + BBR_STARTUP, /* ramp up sending rate rapidly to fill pipe */ + BBR_DRAIN, /* drain any queue created during startup */ + BBR_PROBE_BW, /* discover, share bw: pace around estimated bw */ + BBR_PROBE_RTT, /* cut inflight to min to probe min_rtt */ +}; + +/* BBR congestion control block */ +struct bbr { + u32 min_rtt_us; /* min RTT in min_rtt_win_sec window */ + u32 min_rtt_stamp; /* timestamp of min_rtt_us */ + u32 probe_rtt_done_stamp; /* end time for BBR_PROBE_RTT mode */ + struct minmax bw; /* Max recent delivery rate in pkts/uS << 24 */ + u32 rtt_cnt; /* count of packet-timed rounds elapsed */ + u32 next_rtt_delivered; /* scb->tx.delivered at end of round */ + u64 cycle_mstamp; /* time of this cycle phase start */ + u32 mode:3, /* current bbr_mode in state machine */ + prev_ca_state:3, /* CA state on previous ACK */ + packet_conservation:1, /* use packet conservation? */ + round_start:1, /* start of packet-timed tx->ack round? */ + idle_restart:1, /* restarting after idle? */ + probe_rtt_round_done:1, /* a BBR_PROBE_RTT round at 4 pkts? */ + unused:13, + lt_is_sampling:1, /* taking long-term ("LT") samples now? */ + lt_rtt_cnt:7, /* round trips in long-term interval */ + lt_use_bw:1; /* use lt_bw as our bw estimate? 
*/ + u32 lt_bw; /* LT est delivery rate in pkts/uS << 24 */ + u32 lt_last_delivered; /* LT intvl start: tp->delivered */ + u32 lt_last_stamp; /* LT intvl start: tp->delivered_mstamp */ + u32 lt_last_lost; /* LT intvl start: tp->lost */ + u32 pacing_gain:10, /* current gain for setting pacing rate */ + cwnd_gain:10, /* current gain for setting cwnd */ + full_bw_reached:1, /* reached full bw in Startup? */ + full_bw_cnt:2, /* number of rounds without large bw gains */ + cycle_idx:3, /* current index in pacing_gain cycle array */ + has_seen_rtt:1, /* have we seen an RTT sample yet? */ + unused_b:5; + u32 prior_cwnd; /* prior cwnd upon entering loss recovery */ + u32 full_bw; /* recent bw, to estimate if pipe is full */ + + /* For tracking ACK aggregation: */ + u64 ack_epoch_mstamp; /* start of ACK sampling epoch */ + u16 extra_acked[2]; /* max excess data ACKed in epoch */ + u32 ack_epoch_acked:20, /* packets (S)ACKed in sampling epoch */ + extra_acked_win_rtts:5, /* age of extra_acked, in round trips */ + extra_acked_win_idx:1, /* current index in extra_acked array */ + unused_c:6; +}; + +#define CYCLE_LEN 8 /* number of phases in a pacing gain cycle */ + +/* Window length of bw filter (in rounds): */ +static const int bbr_bw_rtts = CYCLE_LEN + 2; +/* Window length of min_rtt filter (in sec): */ +static const u32 bbr_min_rtt_win_sec = 10; +/* Minimum time (in ms) spent at bbr_cwnd_min_target in BBR_PROBE_RTT mode: */ +static const u32 bbr_probe_rtt_mode_ms = 200; +/* Skip TSO below the following bandwidth (bits/sec): */ +static const int bbr_min_tso_rate = 1200000; + +/* Pace at ~1% below estimated bw, on average, to reduce queue at bottleneck. + * In order to help drive the network toward lower queues and low latency while + * maintaining high utilization, the average pacing rate aims to be slightly + * lower than the estimated bandwidth. This is an important aspect of the + * design. + */ +static const int bbr_pacing_margin_percent = 1; + +/* We use a high_gain value of 2/ln(2) because it's the smallest pacing gain + * that will allow a smoothly increasing pacing rate that will double each RTT + * and send the same number of packets per RTT that an un-paced, slow-starting + * Reno or CUBIC flow would: + */ +static const int bbr_high_gain = BBR_UNIT * 2885 / 1000 + 1; +/* The pacing gain of 1/high_gain in BBR_DRAIN is calculated to typically drain + * the queue created in BBR_STARTUP in a single round: + */ +static const int bbr_drain_gain = BBR_UNIT * 1000 / 2885; +/* The gain for deriving steady-state cwnd tolerates delayed/stretched ACKs: */ +static const int bbr_cwnd_gain = BBR_UNIT * 2; +/* The pacing_gain values for the PROBE_BW gain cycle, to discover/share bw: */ +static const int bbr_pacing_gain[] = { + BBR_UNIT * 5 / 4, /* probe for more available bw */ + BBR_UNIT * 3 / 4, /* drain queue and/or yield bw to other flows */ + BBR_UNIT, BBR_UNIT, BBR_UNIT, /* cruise at 1.0*bw to utilize pipe, */ + BBR_UNIT, BBR_UNIT, BBR_UNIT /* without creating excess queue... */ +}; +/* Randomize the starting gain cycling phase over N phases: */ +static const u32 bbr_cycle_rand = 7; + +/* Try to keep at least this many packets in flight, if things go smoothly. For + * smooth functioning, a sliding window protocol ACKing every other packet + * needs at least 4 packets in flight: + */ +static const u32 bbr_cwnd_min_target = 4; + +/* To estimate if BBR_STARTUP mode (i.e. high_gain) has filled pipe... 
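+ *
+ * Concretely, reading the two constants below: the pipe is considered
+ * full once bandwidth growth stalls, i.e. once
+ *
+ *	bw < full_bw * bbr_full_bw_thresh / BBR_UNIT   (5/4 => +25%)
+ *
+ * holds for bbr_full_bw_cnt (= 3) rounds in a row without a
+ * significantly larger bandwidth sample arriving.
+ *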
*/ +/* If bw has increased significantly (1.25x), there may be more bw available: */ +static const u32 bbr_full_bw_thresh = BBR_UNIT * 5 / 4; +/* But after 3 rounds w/o significant bw growth, estimate pipe is full: */ +static const u32 bbr_full_bw_cnt = 3; + +/* "long-term" ("LT") bandwidth estimator parameters... */ +/* The minimum number of rounds in an LT bw sampling interval: */ +static const u32 bbr_lt_intvl_min_rtts = 4; +/* If lost/delivered ratio > 20%, interval is "lossy" and we may be policed: */ +static const u32 bbr_lt_loss_thresh = 50; +/* If 2 intervals have a bw ratio <= 1/8, their bw is "consistent": */ +static const u32 bbr_lt_bw_ratio = BBR_UNIT / 8; +/* If 2 intervals have a bw diff <= 4 Kbit/sec their bw is "consistent": */ +static const u32 bbr_lt_bw_diff = 4000 / 8; +/* If we estimate we're policed, use lt_bw for this many round trips: */ +static const u32 bbr_lt_bw_max_rtts = 48; + +/* Gain factor for adding extra_acked to target cwnd: */ +static const int bbr_extra_acked_gain = BBR_UNIT; +/* Window length of extra_acked window. */ +static const u32 bbr_extra_acked_win_rtts = 5; +/* Max allowed val for ack_epoch_acked, after which sampling epoch is reset */ +static const u32 bbr_ack_epoch_acked_reset_thresh = 1U << 20; +/* Time period for clamping cwnd increment due to ack aggregation */ +static const u32 bbr_extra_acked_max_us = 100 * 1000; + +static void bbr_check_probe_rtt_done(struct sock *sk); + +/* Do we estimate that STARTUP filled the pipe? */ +static bool bbr_full_bw_reached(const struct sock *sk) +{ + const struct bbr *bbr = inet_csk_ca(sk); + + return bbr->full_bw_reached; +} + +/* Return the windowed max recent bandwidth sample, in pkts/uS << BW_SCALE. */ +static u32 bbr_max_bw(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return minmax_get(&bbr->bw); +} + +/* Return the estimated bandwidth of the path, in pkts/uS << BW_SCALE. */ +static u32 bbr_bw(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return bbr->lt_use_bw ? bbr->lt_bw : bbr_max_bw(sk); +} + +/* Return maximum extra acked in past k-2k round trips, + * where k = bbr_extra_acked_win_rtts. + */ +static u16 bbr_extra_acked(const struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + return max(bbr->extra_acked[0], bbr->extra_acked[1]); +} + +/* Return rate in bytes per second, optionally with a gain. + * The order here is chosen carefully to avoid overflow of u64. This should + * work for input rates of up to 2.9Tbit/sec and gain of 2.89x. + */ +static u64 bbr_rate_bytes_per_sec(struct sock *sk, u64 rate, int gain) +{ + unsigned int mss = tcp_sk(sk)->mss_cache; + + rate *= mss; + rate *= gain; + rate >>= BBR_SCALE; + rate *= USEC_PER_SEC / 100 * (100 - bbr_pacing_margin_percent); + return rate >> BW_SCALE; +} + +/* Convert a BBR bw and gain factor to a pacing rate in bytes per second. */ +static unsigned long bbr_bw_to_pacing_rate(struct sock *sk, u32 bw, int gain) +{ + u64 rate = bw; + + rate = bbr_rate_bytes_per_sec(sk, rate, gain); + rate = min_t(u64, rate, sk->sk_max_pacing_rate); + return rate; +} + +/* Initialize pacing rate to: high_gain * init_cwnd / RTT. */ +static void bbr_init_pacing_rate_from_rtt(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + u32 rtt_us; + + if (tp->srtt_us) { /* any RTT sample yet? 
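+ *
+ * (Rough worked example, not part of the upstream source, assuming an
+ *  initial cwnd of 10 and an MSS of 1448 bytes: before any RTT sample
+ *  exists this falls through to the nominal 1 ms RTT below, giving an
+ *  initial pacing rate of about high_gain * 10 * 1448 B / 1 ms
+ *  ~= 2.89 * 14.5 MB/s ~= 41 MB/s (~330 Mbit/s), less the ~1%
+ *  bbr_pacing_margin_percent.)
+ *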
*/ + rtt_us = max(tp->srtt_us >> 3, 1U); + bbr->has_seen_rtt = 1; + } else { /* no RTT sample yet */ + rtt_us = USEC_PER_MSEC; /* use nominal default RTT */ + } + bw = (u64)tcp_snd_cwnd(tp) * BW_UNIT; + do_div(bw, rtt_us); + sk->sk_pacing_rate = bbr_bw_to_pacing_rate(sk, bw, bbr_high_gain); +} + +/* Pace using current bw estimate and a gain factor. */ +static void bbr_set_pacing_rate(struct sock *sk, u32 bw, int gain) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + unsigned long rate = bbr_bw_to_pacing_rate(sk, bw, gain); + + if (unlikely(!bbr->has_seen_rtt && tp->srtt_us)) + bbr_init_pacing_rate_from_rtt(sk); + if (bbr_full_bw_reached(sk) || rate > sk->sk_pacing_rate) + sk->sk_pacing_rate = rate; +} + +/* override sysctl_tcp_min_tso_segs */ +static u32 bbr_min_tso_segs(struct sock *sk) +{ + return sk->sk_pacing_rate < (bbr_min_tso_rate >> 3) ? 1 : 2; +} + +static u32 bbr_tso_segs_goal(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 segs, bytes; + + /* Sort of tcp_tso_autosize() but ignoring + * driver provided sk_gso_max_size. + */ + bytes = min_t(unsigned long, + sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift), + GSO_LEGACY_MAX_SIZE - 1 - MAX_TCP_HEADER); + segs = max_t(u32, bytes / tp->mss_cache, bbr_min_tso_segs(sk)); + + return min(segs, 0x7FU); +} + +/* Save "last known good" cwnd so we can restore it after losses or PROBE_RTT */ +static void bbr_save_cwnd(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->prev_ca_state < TCP_CA_Recovery && bbr->mode != BBR_PROBE_RTT) + bbr->prior_cwnd = tcp_snd_cwnd(tp); /* this cwnd is good enough */ + else /* loss recovery or BBR_PROBE_RTT have temporarily cut cwnd */ + bbr->prior_cwnd = max(bbr->prior_cwnd, tcp_snd_cwnd(tp)); +} + +static void bbr_cwnd_event(struct sock *sk, enum tcp_ca_event event) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (event == CA_EVENT_TX_START && tp->app_limited) { + bbr->idle_restart = 1; + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + /* Avoid pointless buffer overflows: pace at est. bw if we don't + * need more speed (we're restarting from idle and app-limited). + */ + if (bbr->mode == BBR_PROBE_BW) + bbr_set_pacing_rate(sk, bbr_bw(sk), BBR_UNIT); + else if (bbr->mode == BBR_PROBE_RTT) + bbr_check_probe_rtt_done(sk); + } +} + +/* Calculate bdp based on min RTT and the estimated bottleneck bandwidth: + * + * bdp = ceil(bw * min_rtt * gain) + * + * The key factor, gain, controls the amount of queue. While a small gain + * builds a smaller queue, it becomes more vulnerable to noise in RTT + * measurements (e.g., delayed ACKs or other ACK compression effects). This + * noise may cause BBR to under-estimate the rate. + */ +static u32 bbr_bdp(struct sock *sk, u32 bw, int gain) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bdp; + u64 w; + + /* If we've never had a valid RTT sample, cap cwnd at the initial + * default. This should only happen when the connection is not using TCP + * timestamps and has retransmitted all of the SYN/SYNACK/data packets + * ACKed so far. In this case, an RTO can cut cwnd to 1, in which + * case we need to slow-start up toward something safe: TCP_INIT_CWND. + */ + if (unlikely(bbr->min_rtt_us == ~0U)) /* no valid RTT samples yet? 
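+ *
+ * (Rough worked example, not part of the upstream source: for an
+ *  estimated bw of 12500 pkt/s (~150 Mbit/s at 1500 bytes) and
+ *  min_rtt = 40 ms, bw * min_rtt ~= 500 packets, so this returns
+ *  ~500 with a gain of BBR_UNIT and ~1000 with the steady-state
+ *  cwnd_gain of 2, before bbr_quantization_budget() adds its extra
+ *  allowance for TSO and aggregation.)
+ *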
*/ + return TCP_INIT_CWND; /* be safe: cap at default initial cwnd*/ + + w = (u64)bw * bbr->min_rtt_us; + + /* Apply a gain to the given value, remove the BW_SCALE shift, and + * round the value up to avoid a negative feedback loop. + */ + bdp = (((w * gain) >> BBR_SCALE) + BW_UNIT - 1) / BW_UNIT; + + return bdp; +} + +/* To achieve full performance in high-speed paths, we budget enough cwnd to + * fit full-sized skbs in-flight on both end hosts to fully utilize the path: + * - one skb in sending host Qdisc, + * - one skb in sending host TSO/GSO engine + * - one skb being received by receiver host LRO/GRO/delayed-ACK engine + * Don't worry, at low rates (bbr_min_tso_rate) this won't bloat cwnd because + * in such cases tso_segs_goal is 1. The minimum cwnd is 4 packets, + * which allows 2 outstanding 2-packet sequences, to try to keep pipe + * full even with ACK-every-other-packet delayed ACKs. + */ +static u32 bbr_quantization_budget(struct sock *sk, u32 cwnd) +{ + struct bbr *bbr = inet_csk_ca(sk); + + /* Allow enough full-sized skbs in flight to utilize end systems. */ + cwnd += 3 * bbr_tso_segs_goal(sk); + + /* Reduce delayed ACKs by rounding up cwnd to the next even number. */ + cwnd = (cwnd + 1) & ~1U; + + /* Ensure gain cycling gets inflight above BDP even for small BDPs. */ + if (bbr->mode == BBR_PROBE_BW && bbr->cycle_idx == 0) + cwnd += 2; + + return cwnd; +} + +/* Find inflight based on min RTT and the estimated bottleneck bandwidth. */ +static u32 bbr_inflight(struct sock *sk, u32 bw, int gain) +{ + u32 inflight; + + inflight = bbr_bdp(sk, bw, gain); + inflight = bbr_quantization_budget(sk, inflight); + + return inflight; +} + +/* With pacing at lower layers, there's often less data "in the network" than + * "in flight". With TSQ and departure time pacing at lower layers (e.g. fq), + * we often have several skbs queued in the pacing layer with a pre-scheduled + * earliest departure time (EDT). BBR adapts its pacing rate based on the + * inflight level that it estimates has already been "baked in" by previous + * departure time decisions. We calculate a rough estimate of the number of our + * packets that might be in the network at the earliest departure time for the + * next skb scheduled: + * in_network_at_edt = inflight_at_edt - (EDT - now) * bw + * If we're increasing inflight, then we want to know if the transmit of the + * EDT skb will push inflight above the target, so inflight_at_edt includes + * bbr_tso_segs_goal() from the skb departing at EDT. If decreasing inflight, + * then estimate if inflight will sink too low just before the EDT transmit. 
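+ *
+ * (Rough worked example, not part of the upstream source: with an
+ *  estimated bw of 0.0125 pkt/uS (~150 Mbit/s at 1500 bytes) and the
+ *  next earliest departure time 2 ms in the future,
+ *  interval_delivered = 0.0125 * 2000 = 25, so roughly 25 packets are
+ *  treated as already "baked in" and subtracted from the current
+ *  inflight count.)
+ *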
+ */ +static u32 bbr_packets_in_net_at_edt(struct sock *sk, u32 inflight_now) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 now_ns, edt_ns, interval_us; + u32 interval_delivered, inflight_at_edt; + + now_ns = tp->tcp_clock_cache; + edt_ns = max(tp->tcp_wstamp_ns, now_ns); + interval_us = div_u64(edt_ns - now_ns, NSEC_PER_USEC); + interval_delivered = (u64)bbr_bw(sk) * interval_us >> BW_SCALE; + inflight_at_edt = inflight_now; + if (bbr->pacing_gain > BBR_UNIT) /* increasing inflight */ + inflight_at_edt += bbr_tso_segs_goal(sk); /* include EDT skb */ + if (interval_delivered >= inflight_at_edt) + return 0; + return inflight_at_edt - interval_delivered; +} + +/* Find the cwnd increment based on estimate of ack aggregation */ +static u32 bbr_ack_aggregation_cwnd(struct sock *sk) +{ + u32 max_aggr_cwnd, aggr_cwnd = 0; + + if (bbr_extra_acked_gain && bbr_full_bw_reached(sk)) { + max_aggr_cwnd = ((u64)bbr_bw(sk) * bbr_extra_acked_max_us) + / BW_UNIT; + aggr_cwnd = (bbr_extra_acked_gain * bbr_extra_acked(sk)) + >> BBR_SCALE; + aggr_cwnd = min(aggr_cwnd, max_aggr_cwnd); + } + + return aggr_cwnd; +} + +/* An optimization in BBR to reduce losses: On the first round of recovery, we + * follow the packet conservation principle: send P packets per P packets acked. + * After that, we slow-start and send at most 2*P packets per P packets acked. + * After recovery finishes, or upon undo, we restore the cwnd we had when + * recovery started (capped by the target cwnd based on estimated BDP). + * + * TODO(ycheng/ncardwell): implement a rate-based approach. + */ +static bool bbr_set_cwnd_to_recover_or_restore( + struct sock *sk, const struct rate_sample *rs, u32 acked, u32 *new_cwnd) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u8 prev_state = bbr->prev_ca_state, state = inet_csk(sk)->icsk_ca_state; + u32 cwnd = tcp_snd_cwnd(tp); + + /* An ACK for P pkts should release at most 2*P packets. We do this + * in two steps. First, here we deduct the number of lost packets. + * Then, in bbr_set_cwnd() we slow start up toward the target cwnd. + */ + if (rs->losses > 0) + cwnd = max_t(s32, cwnd - rs->losses, 1); + + if (state == TCP_CA_Recovery && prev_state != TCP_CA_Recovery) { + /* Starting 1st round of Recovery, so do packet conservation. */ + bbr->packet_conservation = 1; + bbr->next_rtt_delivered = tp->delivered; /* start round now */ + /* Cut unused cwnd from app behavior, TSQ, or TSO deferral: */ + cwnd = tcp_packets_in_flight(tp) + acked; + } else if (prev_state >= TCP_CA_Recovery && state < TCP_CA_Recovery) { + /* Exiting loss recovery; restore cwnd saved before recovery. */ + cwnd = max(cwnd, bbr->prior_cwnd); + bbr->packet_conservation = 0; + } + bbr->prev_ca_state = state; + + if (bbr->packet_conservation) { + *new_cwnd = max(cwnd, tcp_packets_in_flight(tp) + acked); + return true; /* yes, using packet conservation */ + } + *new_cwnd = cwnd; + return false; +} + +/* Slow-start up toward target cwnd (if bw estimate is growing, or packet loss + * has drawn us down below target), or snap down to target if we're above it. 
+ */ +static void bbr_set_cwnd(struct sock *sk, const struct rate_sample *rs, + u32 acked, u32 bw, int gain) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 cwnd = tcp_snd_cwnd(tp), target_cwnd = 0; + + if (!acked) + goto done; /* no packet fully ACKed; just apply caps */ + + if (bbr_set_cwnd_to_recover_or_restore(sk, rs, acked, &cwnd)) + goto done; + + target_cwnd = bbr_bdp(sk, bw, gain); + + /* Increment the cwnd to account for excess ACKed data that seems + * due to aggregation (of data and/or ACKs) visible in the ACK stream. + */ + target_cwnd += bbr_ack_aggregation_cwnd(sk); + target_cwnd = bbr_quantization_budget(sk, target_cwnd); + + /* If we're below target cwnd, slow start cwnd toward target cwnd. */ + if (bbr_full_bw_reached(sk)) /* only cut cwnd if we filled the pipe */ + cwnd = min(cwnd + acked, target_cwnd); + else if (cwnd < target_cwnd || tp->delivered < TCP_INIT_CWND) + cwnd = cwnd + acked; + cwnd = max(cwnd, bbr_cwnd_min_target); + +done: + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); /* apply global cap */ + if (bbr->mode == BBR_PROBE_RTT) /* drain queue, refresh min_rtt */ + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), bbr_cwnd_min_target)); +} + +/* End cycle phase if it's time and/or we hit the phase's in-flight target. */ +static bool bbr_is_next_cycle_phase(struct sock *sk, + const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + bool is_full_length = + tcp_stamp_us_delta(tp->delivered_mstamp, bbr->cycle_mstamp) > + bbr->min_rtt_us; + u32 inflight, bw; + + /* The pacing_gain of 1.0 paces at the estimated bw to try to fully + * use the pipe without increasing the queue. + */ + if (bbr->pacing_gain == BBR_UNIT) + return is_full_length; /* just use wall clock time */ + + inflight = bbr_packets_in_net_at_edt(sk, rs->prior_in_flight); + bw = bbr_max_bw(sk); + + /* A pacing_gain > 1.0 probes for bw by trying to raise inflight to at + * least pacing_gain*BDP; this may take more than min_rtt if min_rtt is + * small (e.g. on a LAN). We do not persist if packets are lost, since + * a path with small buffers may not hold that much. + */ + if (bbr->pacing_gain > BBR_UNIT) + return is_full_length && + (rs->losses || /* perhaps pacing_gain*BDP won't fit */ + inflight >= bbr_inflight(sk, bw, bbr->pacing_gain)); + + /* A pacing_gain < 1.0 tries to drain extra queue we added if bw + * probing didn't find more bw. If inflight falls to match BDP then we + * estimate queue is drained; persisting would underutilize the pipe. + */ + return is_full_length || + inflight <= bbr_inflight(sk, bw, BBR_UNIT); +} + +static void bbr_advance_cycle_phase(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->cycle_idx = (bbr->cycle_idx + 1) & (CYCLE_LEN - 1); + bbr->cycle_mstamp = tp->delivered_mstamp; +} + +/* Gain cycling: cycle pacing gain to converge to fair share of available bw. 
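+ *
+ * (Illustrative note, not part of the upstream source: over one
+ *  8-phase cycle the pacing gains above average to
+ *  (5/4 + 3/4 + 6 * 1.0) / 8 = 1.0, so the flow paces at the
+ *  estimated bw on average while spending one phase probing above it
+ *  and one phase draining below it.)
+ *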
*/ +static void bbr_update_cycle_phase(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->mode == BBR_PROBE_BW && bbr_is_next_cycle_phase(sk, rs)) + bbr_advance_cycle_phase(sk); +} + +static void bbr_reset_startup_mode(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->mode = BBR_STARTUP; +} + +static void bbr_reset_probe_bw_mode(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->mode = BBR_PROBE_BW; + bbr->cycle_idx = CYCLE_LEN - 1 - prandom_u32_max(bbr_cycle_rand); + bbr_advance_cycle_phase(sk); /* flip to next phase of gain cycle */ +} + +static void bbr_reset_mode(struct sock *sk) +{ + if (!bbr_full_bw_reached(sk)) + bbr_reset_startup_mode(sk); + else + bbr_reset_probe_bw_mode(sk); +} + +/* Start a new long-term sampling interval. */ +static void bbr_reset_lt_bw_sampling_interval(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->lt_last_stamp = div_u64(tp->delivered_mstamp, USEC_PER_MSEC); + bbr->lt_last_delivered = tp->delivered; + bbr->lt_last_lost = tp->lost; + bbr->lt_rtt_cnt = 0; +} + +/* Completely reset long-term bandwidth sampling. */ +static void bbr_reset_lt_bw_sampling(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->lt_bw = 0; + bbr->lt_use_bw = 0; + bbr->lt_is_sampling = false; + bbr_reset_lt_bw_sampling_interval(sk); +} + +/* Long-term bw sampling interval is done. Estimate whether we're policed. */ +static void bbr_lt_bw_interval_done(struct sock *sk, u32 bw) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 diff; + + if (bbr->lt_bw) { /* do we have bw from a previous interval? */ + /* Is new bw close to the lt_bw from the previous interval? */ + diff = abs(bw - bbr->lt_bw); + if ((diff * BBR_UNIT <= bbr_lt_bw_ratio * bbr->lt_bw) || + (bbr_rate_bytes_per_sec(sk, diff, BBR_UNIT) <= + bbr_lt_bw_diff)) { + /* All criteria are met; estimate we're policed. */ + bbr->lt_bw = (bw + bbr->lt_bw) >> 1; /* avg 2 intvls */ + bbr->lt_use_bw = 1; + bbr->pacing_gain = BBR_UNIT; /* try to avoid drops */ + bbr->lt_rtt_cnt = 0; + return; + } + } + bbr->lt_bw = bw; + bbr_reset_lt_bw_sampling_interval(sk); +} + +/* Token-bucket traffic policers are common (see "An Internet-Wide Analysis of + * Traffic Policing", SIGCOMM 2016). BBR detects token-bucket policers and + * explicitly models their policed rate, to reduce unnecessary losses. We + * estimate that we're policed if we see 2 consecutive sampling intervals with + * consistent throughput and high packet loss. If we think we're being policed, + * set lt_bw to the "long-term" average delivery rate from those 2 intervals. + */ +static void bbr_lt_bw_sampling(struct sock *sk, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u32 lost, delivered; + u64 bw; + u32 t; + + if (bbr->lt_use_bw) { /* already using long-term rate, lt_bw? */ + if (bbr->mode == BBR_PROBE_BW && bbr->round_start && + ++bbr->lt_rtt_cnt >= bbr_lt_bw_max_rtts) { + bbr_reset_lt_bw_sampling(sk); /* stop using lt_bw */ + bbr_reset_probe_bw_mode(sk); /* restart gain cycling */ + } + return; + } + + /* Wait for the first loss before sampling, to let the policer exhaust + * its tokens and estimate the steady-state rate allowed by the policer. + * Starting samples earlier includes bursts that over-estimate the bw. 
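+ *
+ * (Rough worked example of the thresholds above, not part of the
+ *  upstream source: an interval counts as lossy when
+ *  lost/delivered >= 50/256 (~20%), and two lossy intervals are
+ *  "consistent" when their bw differs by at most 1/8 or by at most
+ *  4 Kbit/s. E.g. intervals measuring ~1000 and ~1080 pkt/s differ by
+ *  80 <= 1000/8 = 125, so lt_bw becomes the average (~1040 pkt/s) and
+ *  BBR paces at 1.0 * lt_bw for up to bbr_lt_bw_max_rtts = 48 round
+ *  trips.)
+ *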
+ */ + if (!bbr->lt_is_sampling) { + if (!rs->losses) + return; + bbr_reset_lt_bw_sampling_interval(sk); + bbr->lt_is_sampling = true; + } + + /* To avoid underestimates, reset sampling if we run out of data. */ + if (rs->is_app_limited) { + bbr_reset_lt_bw_sampling(sk); + return; + } + + if (bbr->round_start) + bbr->lt_rtt_cnt++; /* count round trips in this interval */ + if (bbr->lt_rtt_cnt < bbr_lt_intvl_min_rtts) + return; /* sampling interval needs to be longer */ + if (bbr->lt_rtt_cnt > 4 * bbr_lt_intvl_min_rtts) { + bbr_reset_lt_bw_sampling(sk); /* interval is too long */ + return; + } + + /* End sampling interval when a packet is lost, so we estimate the + * policer tokens were exhausted. Stopping the sampling before the + * tokens are exhausted under-estimates the policed rate. + */ + if (!rs->losses) + return; + + /* Calculate packets lost and delivered in sampling interval. */ + lost = tp->lost - bbr->lt_last_lost; + delivered = tp->delivered - bbr->lt_last_delivered; + /* Is loss rate (lost/delivered) >= lt_loss_thresh? If not, wait. */ + if (!delivered || (lost << BBR_SCALE) < bbr_lt_loss_thresh * delivered) + return; + + /* Find average delivery rate in this sampling interval. */ + t = div_u64(tp->delivered_mstamp, USEC_PER_MSEC) - bbr->lt_last_stamp; + if ((s32)t < 1) + return; /* interval is less than one ms, so wait */ + /* Check if can multiply without overflow */ + if (t >= ~0U / USEC_PER_MSEC) { + bbr_reset_lt_bw_sampling(sk); /* interval too long; reset */ + return; + } + t *= USEC_PER_MSEC; + bw = (u64)delivered * BW_UNIT; + do_div(bw, t); + bbr_lt_bw_interval_done(sk, bw); +} + +/* Estimate the bandwidth based on how fast packets are delivered */ +static void bbr_update_bw(struct sock *sk, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 bw; + + bbr->round_start = 0; + if (rs->delivered < 0 || rs->interval_us <= 0) + return; /* Not a valid observation */ + + /* See if we've reached the next RTT */ + if (!before(rs->prior_delivered, bbr->next_rtt_delivered)) { + bbr->next_rtt_delivered = tp->delivered; + bbr->rtt_cnt++; + bbr->round_start = 1; + bbr->packet_conservation = 0; + } + + bbr_lt_bw_sampling(sk, rs); + + /* Divide delivered by the interval to find a (lower bound) bottleneck + * bandwidth sample. Delivered is in packets and interval_us in uS and + * ratio will be <<1 for most connections. So delivered is first scaled. + */ + bw = div64_long((u64)rs->delivered * BW_UNIT, rs->interval_us); + + /* If this sample is application-limited, it is likely to have a very + * low delivered count that represents application behavior rather than + * the available network rate. Such a sample could drag down estimated + * bw, causing needless slow-down. Thus, to continue to send at the + * last measured network rate, we filter out app-limited samples unless + * they describe the path bw at least as well as our bw model. + * + * So the goal during app-limited phase is to proceed with the best + * network rate no matter how long. We automatically leave this + * phase when app writes faster than the network can deliver :) + */ + if (!rs->is_app_limited || bw >= bbr_max_bw(sk)) { + /* Incorporate new sample into our max bw filter. */ + minmax_running_max(&bbr->bw, bbr_bw_rtts, bbr->rtt_cnt, bw); + } +} + +/* Estimates the windowed max degree of ack aggregation. + * This is used to provision extra in-flight data to keep sending during + * inter-ACK silences. 
+ * + * Degree of ack aggregation is estimated as extra data acked beyond expected. + * + * max_extra_acked = "maximum recent excess data ACKed beyond max_bw * interval" + * cwnd += max_extra_acked + * + * Max extra_acked is clamped by cwnd and bw * bbr_extra_acked_max_us (100 ms). + * Max filter is an approximate sliding window of 5-10 (packet timed) round + * trips. + */ +static void bbr_update_ack_aggregation(struct sock *sk, + const struct rate_sample *rs) +{ + u32 epoch_us, expected_acked, extra_acked; + struct bbr *bbr = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (!bbr_extra_acked_gain || rs->acked_sacked <= 0 || + rs->delivered < 0 || rs->interval_us <= 0) + return; + + if (bbr->round_start) { + bbr->extra_acked_win_rtts = min(0x1F, + bbr->extra_acked_win_rtts + 1); + if (bbr->extra_acked_win_rtts >= bbr_extra_acked_win_rtts) { + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = bbr->extra_acked_win_idx ? + 0 : 1; + bbr->extra_acked[bbr->extra_acked_win_idx] = 0; + } + } + + /* Compute how many packets we expected to be delivered over epoch. */ + epoch_us = tcp_stamp_us_delta(tp->delivered_mstamp, + bbr->ack_epoch_mstamp); + expected_acked = ((u64)bbr_bw(sk) * epoch_us) / BW_UNIT; + + /* Reset the aggregation epoch if ACK rate is below expected rate or + * significantly large no. of ack received since epoch (potentially + * quite old epoch). + */ + if (bbr->ack_epoch_acked <= expected_acked || + (bbr->ack_epoch_acked + rs->acked_sacked >= + bbr_ack_epoch_acked_reset_thresh)) { + bbr->ack_epoch_acked = 0; + bbr->ack_epoch_mstamp = tp->delivered_mstamp; + expected_acked = 0; + } + + /* Compute excess data delivered, beyond what was expected. */ + bbr->ack_epoch_acked = min_t(u32, 0xFFFFF, + bbr->ack_epoch_acked + rs->acked_sacked); + extra_acked = bbr->ack_epoch_acked - expected_acked; + extra_acked = min(extra_acked, tcp_snd_cwnd(tp)); + if (extra_acked > bbr->extra_acked[bbr->extra_acked_win_idx]) + bbr->extra_acked[bbr->extra_acked_win_idx] = extra_acked; +} + +/* Estimate when the pipe is full, using the change in delivery rate: BBR + * estimates that STARTUP filled the pipe if the estimated bw hasn't changed by + * at least bbr_full_bw_thresh (25%) after bbr_full_bw_cnt (3) non-app-limited + * rounds. Why 3 rounds: 1: rwin autotuning grows the rwin, 2: we fill the + * higher rwin, 3: we get higher delivery rate samples. Or transient + * cross-traffic or radio noise can go away. CUBIC Hystart shares a similar + * design goal, but uses delay and inter-ACK spacing instead of bandwidth. + */ +static void bbr_check_full_bw_reached(struct sock *sk, + const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bw_thresh; + + if (bbr_full_bw_reached(sk) || !bbr->round_start || rs->is_app_limited) + return; + + bw_thresh = (u64)bbr->full_bw * bbr_full_bw_thresh >> BBR_SCALE; + if (bbr_max_bw(sk) >= bw_thresh) { + bbr->full_bw = bbr_max_bw(sk); + bbr->full_bw_cnt = 0; + return; + } + ++bbr->full_bw_cnt; + bbr->full_bw_reached = bbr->full_bw_cnt >= bbr_full_bw_cnt; +} + +/* If pipe is probably full, drain the queue and then enter steady-state. 
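+ *
+ * (Rough worked example of the test above, not part of the upstream
+ *  source: bw_thresh = full_bw * 320 >> 8 = 1.25 * full_bw, so if
+ *  full_bw currently corresponds to 80 Mbit/s the max bw filter must
+ *  reach 100 Mbit/s to reset the count; after 3 non-app-limited round
+ *  starts without such growth, full_bw_reached is set and STARTUP
+ *  hands off to DRAIN below.)
+ *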
*/ +static void bbr_check_drain(struct sock *sk, const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (bbr->mode == BBR_STARTUP && bbr_full_bw_reached(sk)) { + bbr->mode = BBR_DRAIN; /* drain queue we created */ + tcp_sk(sk)->snd_ssthresh = + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT); + } /* fall through to check if in-flight is already small: */ + if (bbr->mode == BBR_DRAIN && + bbr_packets_in_net_at_edt(sk, tcp_packets_in_flight(tcp_sk(sk))) <= + bbr_inflight(sk, bbr_max_bw(sk), BBR_UNIT)) + bbr_reset_probe_bw_mode(sk); /* we estimate queue is drained */ +} + +static void bbr_check_probe_rtt_done(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + if (!(bbr->probe_rtt_done_stamp && + after(tcp_jiffies32, bbr->probe_rtt_done_stamp))) + return; + + bbr->min_rtt_stamp = tcp_jiffies32; /* wait a while until PROBE_RTT */ + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), bbr->prior_cwnd)); + bbr_reset_mode(sk); +} + +/* The goal of PROBE_RTT mode is to have BBR flows cooperatively and + * periodically drain the bottleneck queue, to converge to measure the true + * min_rtt (unloaded propagation delay). This allows the flows to keep queues + * small (reducing queuing delay and packet loss) and achieve fairness among + * BBR flows. + * + * The min_rtt filter window is 10 seconds. When the min_rtt estimate expires, + * we enter PROBE_RTT mode and cap the cwnd at bbr_cwnd_min_target=4 packets. + * After at least bbr_probe_rtt_mode_ms=200ms and at least one packet-timed + * round trip elapsed with that flight size <= 4, we leave PROBE_RTT mode and + * re-enter the previous mode. BBR uses 200ms to approximately bound the + * performance penalty of PROBE_RTT's cwnd capping to roughly 2% (200ms/10s). + * + * Note that flows need only pay 2% if they are busy sending over the last 10 + * seconds. Interactive applications (e.g., Web, RPCs, video chunks) often have + * natural silences or low-rate periods within 10 seconds where the rate is low + * enough for long enough to drain its queue in the bottleneck. We pick up + * these min RTT measurements opportunistically with our min_rtt filter. :-) + */ +static void bbr_update_min_rtt(struct sock *sk, const struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + bool filter_expired; + + /* Track min RTT seen in the min_rtt_win_sec filter window: */ + filter_expired = after(tcp_jiffies32, + bbr->min_rtt_stamp + bbr_min_rtt_win_sec * HZ); + if (rs->rtt_us >= 0 && + (rs->rtt_us < bbr->min_rtt_us || + (filter_expired && !rs->is_ack_delayed))) { + bbr->min_rtt_us = rs->rtt_us; + bbr->min_rtt_stamp = tcp_jiffies32; + } + + if (bbr_probe_rtt_mode_ms > 0 && filter_expired && + !bbr->idle_restart && bbr->mode != BBR_PROBE_RTT) { + bbr->mode = BBR_PROBE_RTT; /* dip, drain queue */ + bbr_save_cwnd(sk); /* note cwnd so we can restore it */ + bbr->probe_rtt_done_stamp = 0; + } + + if (bbr->mode == BBR_PROBE_RTT) { + /* Ignore low rate samples during this mode. */ + tp->app_limited = + (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; + /* Maintain min packets in flight for max(200 ms, 1 round). 
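+ *
+ * (Illustrative timeline, not part of the upstream source: e.g. with
+ *  min_rtt = 40 ms, once inflight has fallen to <= 4 packets the flow
+ *  sits at cwnd 4 for max(200 ms, one packet-timed round) ~= 200 ms,
+ *  i.e. roughly five 40 ms rounds, before restoring the prior cwnd
+ *  and mode, which is where the ~2% (200 ms / 10 s) worst-case cost
+ *  quoted above comes from.)
+ *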
*/ + if (!bbr->probe_rtt_done_stamp && + tcp_packets_in_flight(tp) <= bbr_cwnd_min_target) { + bbr->probe_rtt_done_stamp = tcp_jiffies32 + + msecs_to_jiffies(bbr_probe_rtt_mode_ms); + bbr->probe_rtt_round_done = 0; + bbr->next_rtt_delivered = tp->delivered; + } else if (bbr->probe_rtt_done_stamp) { + if (bbr->round_start) + bbr->probe_rtt_round_done = 1; + if (bbr->probe_rtt_round_done) + bbr_check_probe_rtt_done(sk); + } + } + /* Restart after idle ends only once we process a new S/ACK for data */ + if (rs->delivered > 0) + bbr->idle_restart = 0; +} + +static void bbr_update_gains(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + switch (bbr->mode) { + case BBR_STARTUP: + bbr->pacing_gain = bbr_high_gain; + bbr->cwnd_gain = bbr_high_gain; + break; + case BBR_DRAIN: + bbr->pacing_gain = bbr_drain_gain; /* slow, to drain */ + bbr->cwnd_gain = bbr_high_gain; /* keep cwnd */ + break; + case BBR_PROBE_BW: + bbr->pacing_gain = (bbr->lt_use_bw ? + BBR_UNIT : + bbr_pacing_gain[bbr->cycle_idx]); + bbr->cwnd_gain = bbr_cwnd_gain; + break; + case BBR_PROBE_RTT: + bbr->pacing_gain = BBR_UNIT; + bbr->cwnd_gain = BBR_UNIT; + break; + default: + WARN_ONCE(1, "BBR bad mode: %u\n", bbr->mode); + break; + } +} + +static void bbr_update_model(struct sock *sk, const struct rate_sample *rs) +{ + bbr_update_bw(sk, rs); + bbr_update_ack_aggregation(sk, rs); + bbr_update_cycle_phase(sk, rs); + bbr_check_full_bw_reached(sk, rs); + bbr_check_drain(sk, rs); + bbr_update_min_rtt(sk, rs); + bbr_update_gains(sk); +} + +static void bbr_main(struct sock *sk, const struct rate_sample *rs) +{ + struct bbr *bbr = inet_csk_ca(sk); + u32 bw; + + bbr_update_model(sk, rs); + + bw = bbr_bw(sk); + bbr_set_pacing_rate(sk, bw, bbr->pacing_gain); + bbr_set_cwnd(sk, rs, rs->acked_sacked, bw, bbr->cwnd_gain); +} + +static void bbr_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + + bbr->prior_cwnd = 0; + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + bbr->rtt_cnt = 0; + bbr->next_rtt_delivered = tp->delivered; + bbr->prev_ca_state = TCP_CA_Open; + bbr->packet_conservation = 0; + + bbr->probe_rtt_done_stamp = 0; + bbr->probe_rtt_round_done = 0; + bbr->min_rtt_us = tcp_min_rtt(tp); + bbr->min_rtt_stamp = tcp_jiffies32; + + minmax_reset(&bbr->bw, bbr->rtt_cnt, 0); /* init max bw to 0 */ + + bbr->has_seen_rtt = 0; + bbr_init_pacing_rate_from_rtt(sk); + + bbr->round_start = 0; + bbr->idle_restart = 0; + bbr->full_bw_reached = 0; + bbr->full_bw = 0; + bbr->full_bw_cnt = 0; + bbr->cycle_mstamp = 0; + bbr->cycle_idx = 0; + bbr_reset_lt_bw_sampling(sk); + bbr_reset_startup_mode(sk); + + bbr->ack_epoch_mstamp = tp->tcp_mstamp; + bbr->ack_epoch_acked = 0; + bbr->extra_acked_win_rtts = 0; + bbr->extra_acked_win_idx = 0; + bbr->extra_acked[0] = 0; + bbr->extra_acked[1] = 0; + + cmpxchg(&sk->sk_pacing_status, SK_PACING_NONE, SK_PACING_NEEDED); +} + +static u32 bbr_sndbuf_expand(struct sock *sk) +{ + /* Provision 3 * cwnd since BBR may slow-start even during recovery. */ + return 3; +} + +/* In theory BBR does not need to undo the cwnd since it does not + * always reduce cwnd on losses (see bbr_main()). Keep it for now. + */ +static u32 bbr_undo_cwnd(struct sock *sk) +{ + struct bbr *bbr = inet_csk_ca(sk); + + bbr->full_bw = 0; /* spurious slow-down; reset full pipe detection */ + bbr->full_bw_cnt = 0; + bbr_reset_lt_bw_sampling(sk); + return tcp_snd_cwnd(tcp_sk(sk)); +} + +/* Entering loss recovery, so save cwnd for when we exit or undo recovery. 
*/ +static u32 bbr_ssthresh(struct sock *sk) +{ + bbr_save_cwnd(sk); + return tcp_sk(sk)->snd_ssthresh; +} + +static size_t bbr_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + if (ext & (1 << (INET_DIAG_BBRINFO - 1)) || + ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + struct tcp_sock *tp = tcp_sk(sk); + struct bbr *bbr = inet_csk_ca(sk); + u64 bw = bbr_bw(sk); + + bw = bw * tp->mss_cache * USEC_PER_SEC >> BW_SCALE; + memset(&info->bbr, 0, sizeof(info->bbr)); + info->bbr.bbr_bw_lo = (u32)bw; + info->bbr.bbr_bw_hi = (u32)(bw >> 32); + info->bbr.bbr_min_rtt = bbr->min_rtt_us; + info->bbr.bbr_pacing_gain = bbr->pacing_gain; + info->bbr.bbr_cwnd_gain = bbr->cwnd_gain; + *attr = INET_DIAG_BBRINFO; + return sizeof(info->bbr); + } + return 0; +} + +static void bbr_set_state(struct sock *sk, u8 new_state) +{ + struct bbr *bbr = inet_csk_ca(sk); + + if (new_state == TCP_CA_Loss) { + struct rate_sample rs = { .losses = 1 }; + + bbr->prev_ca_state = TCP_CA_Loss; + bbr->full_bw = 0; + bbr->round_start = 1; /* treat RTO like end of a round */ + bbr_lt_bw_sampling(sk, &rs); + } +} + +static struct tcp_congestion_ops tcp_bbr_cong_ops __read_mostly = { + .flags = TCP_CONG_NON_RESTRICTED, + .name = "bbr", + .owner = THIS_MODULE, + .init = bbr_init, + .cong_control = bbr_main, + .sndbuf_expand = bbr_sndbuf_expand, + .undo_cwnd = bbr_undo_cwnd, + .cwnd_event = bbr_cwnd_event, + .ssthresh = bbr_ssthresh, + .min_tso_segs = bbr_min_tso_segs, + .get_info = bbr_get_info, + .set_state = bbr_set_state, +}; + +BTF_SET8_START(tcp_bbr_check_kfunc_ids) +#ifdef CONFIG_X86 +#ifdef CONFIG_DYNAMIC_FTRACE +BTF_ID_FLAGS(func, bbr_init) +BTF_ID_FLAGS(func, bbr_main) +BTF_ID_FLAGS(func, bbr_sndbuf_expand) +BTF_ID_FLAGS(func, bbr_undo_cwnd) +BTF_ID_FLAGS(func, bbr_cwnd_event) +BTF_ID_FLAGS(func, bbr_ssthresh) +BTF_ID_FLAGS(func, bbr_min_tso_segs) +BTF_ID_FLAGS(func, bbr_set_state) +#endif +#endif +BTF_SET8_END(tcp_bbr_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_bbr_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_bbr_check_kfunc_ids, +}; + +static int __init bbr_register(void) +{ + int ret; + + BUILD_BUG_ON(sizeof(struct bbr) > ICSK_CA_PRIV_SIZE); + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_bbr_kfunc_set); + if (ret < 0) + return ret; + return tcp_register_congestion_control(&tcp_bbr_cong_ops); +} + +static void __exit bbr_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_bbr_cong_ops); +} + +module_init(bbr_register); +module_exit(bbr_unregister); + +MODULE_AUTHOR("Van Jacobson "); +MODULE_AUTHOR("Neal Cardwell "); +MODULE_AUTHOR("Yuchung Cheng "); +MODULE_AUTHOR("Soheil Hassas Yeganeh "); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("TCP BBR (Bottleneck Bandwidth and RTT)"); diff --git a/net/ipv4/tcp_bic.c b/net/ipv4/tcp_bic.c new file mode 100644 index 000000000..58358bf92 --- /dev/null +++ b/net/ipv4/tcp_bic.c @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Binary Increase Congestion control for TCP + * Home page: + * http://netsrv.csc.ncsu.edu/twiki/bin/view/Main/BIC + * This is from the implementation of BICTCP in + * Lison-Xu, Kahaled Harfoush, and Injong Rhee. + * "Binary Increase Congestion Control for Fast, Long Distance + * Networks" in InfoComm 2004 + * Available from: + * http://netsrv.csc.ncsu.edu/export/bitcp.pdf + * + * Unless BIC is enabled and congestion window is large + * this behaves the same as the original Reno. 
+ */ + +#include +#include +#include + +#define BICTCP_BETA_SCALE 1024 /* Scale factor beta calculation + * max_cwnd = snd_cwnd * beta + */ +#define BICTCP_B 4 /* + * In binary search, + * go to point (max+min)/N + */ + +static int fast_convergence = 1; +static int max_increment = 16; +static int low_window = 14; +static int beta = 819; /* = 819/1024 (BICTCP_BETA_SCALE) */ +static int initial_ssthresh; +static int smooth_part = 20; + +module_param(fast_convergence, int, 0644); +MODULE_PARM_DESC(fast_convergence, "turn on/off fast convergence"); +module_param(max_increment, int, 0644); +MODULE_PARM_DESC(max_increment, "Limit on increment allowed during binary search"); +module_param(low_window, int, 0644); +MODULE_PARM_DESC(low_window, "lower bound on congestion window (for TCP friendliness)"); +module_param(beta, int, 0644); +MODULE_PARM_DESC(beta, "beta for multiplicative increase"); +module_param(initial_ssthresh, int, 0644); +MODULE_PARM_DESC(initial_ssthresh, "initial value of slow start threshold"); +module_param(smooth_part, int, 0644); +MODULE_PARM_DESC(smooth_part, "log(B/(B*Smin))/log(B/(B-1))+B, # of RTT from Wmax-B to Wmax"); + +/* BIC TCP Parameters */ +struct bictcp { + u32 cnt; /* increase cwnd by 1 after ACKs */ + u32 last_max_cwnd; /* last maximum snd_cwnd */ + u32 last_cwnd; /* the last snd_cwnd */ + u32 last_time; /* time when updated last_cwnd */ + u32 epoch_start; /* beginning of an epoch */ +#define ACK_RATIO_SHIFT 4 + u32 delayed_ack; /* estimate the ratio of Packets/ACKs << 4 */ +}; + +static inline void bictcp_reset(struct bictcp *ca) +{ + ca->cnt = 0; + ca->last_max_cwnd = 0; + ca->last_cwnd = 0; + ca->last_time = 0; + ca->epoch_start = 0; + ca->delayed_ack = 2 << ACK_RATIO_SHIFT; +} + +static void bictcp_init(struct sock *sk) +{ + struct bictcp *ca = inet_csk_ca(sk); + + bictcp_reset(ca); + + if (initial_ssthresh) + tcp_sk(sk)->snd_ssthresh = initial_ssthresh; +} + +/* + * Compute congestion window to use. 
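+ *
+ * (Rough worked example of the window search below, not part of the
+ *  upstream source and ignoring the delayed-ACK correction: growth
+ *  per RTT is roughly cwnd/cnt. With last_max_cwnd = 1000 and
+ *  cwnd = 600, dist = (1000 - 600) / BICTCP_B = 100 > max_increment,
+ *  so cnt = 600/16 and cwnd climbs linearly by ~16 packets per RTT.
+ *  Once dist drops to <= max_increment, cnt = cwnd/dist and the gap
+ *  to last_max_cwnd shrinks by about a quarter each RTT, the binary
+ *  search; within a few packets of last_max_cwnd, growth slows to
+ *  about BICTCP_B/smooth_part = 4/20 of a packet per RTT.)
+ *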
+ */ +static inline void bictcp_update(struct bictcp *ca, u32 cwnd) +{ + if (ca->last_cwnd == cwnd && + (s32)(tcp_jiffies32 - ca->last_time) <= HZ / 32) + return; + + ca->last_cwnd = cwnd; + ca->last_time = tcp_jiffies32; + + if (ca->epoch_start == 0) /* record the beginning of an epoch */ + ca->epoch_start = tcp_jiffies32; + + /* start off normal */ + if (cwnd <= low_window) { + ca->cnt = cwnd; + return; + } + + /* binary increase */ + if (cwnd < ca->last_max_cwnd) { + __u32 dist = (ca->last_max_cwnd - cwnd) + / BICTCP_B; + + if (dist > max_increment) + /* linear increase */ + ca->cnt = cwnd / max_increment; + else if (dist <= 1U) + /* binary search increase */ + ca->cnt = (cwnd * smooth_part) / BICTCP_B; + else + /* binary search increase */ + ca->cnt = cwnd / dist; + } else { + /* slow start AMD linear increase */ + if (cwnd < ca->last_max_cwnd + BICTCP_B) + /* slow start */ + ca->cnt = (cwnd * smooth_part) / BICTCP_B; + else if (cwnd < ca->last_max_cwnd + max_increment*(BICTCP_B-1)) + /* slow start */ + ca->cnt = (cwnd * (BICTCP_B-1)) + / (cwnd - ca->last_max_cwnd); + else + /* linear increase */ + ca->cnt = cwnd / max_increment; + } + + /* if in slow start or link utilization is very low */ + if (ca->last_max_cwnd == 0) { + if (ca->cnt > 20) /* increase cwnd 5% per RTT */ + ca->cnt = 20; + } + + ca->cnt = (ca->cnt << ACK_RATIO_SHIFT) / ca->delayed_ack; + if (ca->cnt == 0) /* cannot be zero */ + ca->cnt = 1; +} + +static void bictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + return; + } + bictcp_update(ca, tcp_snd_cwnd(tp)); + tcp_cong_avoid_ai(tp, ca->cnt, acked); +} + +/* + * behave like Reno until low_window is reached, + * then increase congestion window slowly + */ +static u32 bictcp_recalc_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + + ca->epoch_start = 0; /* end of epoch */ + + /* Wmax and fast convergence */ + if (tcp_snd_cwnd(tp) < ca->last_max_cwnd && fast_convergence) + ca->last_max_cwnd = (tcp_snd_cwnd(tp) * (BICTCP_BETA_SCALE + beta)) + / (2 * BICTCP_BETA_SCALE); + else + ca->last_max_cwnd = tcp_snd_cwnd(tp); + + if (tcp_snd_cwnd(tp) <= low_window) + return max(tcp_snd_cwnd(tp) >> 1U, 2U); + else + return max((tcp_snd_cwnd(tp) * beta) / BICTCP_BETA_SCALE, 2U); +} + +static void bictcp_state(struct sock *sk, u8 new_state) +{ + if (new_state == TCP_CA_Loss) + bictcp_reset(inet_csk_ca(sk)); +} + +/* Track delayed acknowledgment ratio using sliding window + * ratio = (15*ratio + sample) / 16 + */ +static void bictcp_acked(struct sock *sk, const struct ack_sample *sample) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ca_state == TCP_CA_Open) { + struct bictcp *ca = inet_csk_ca(sk); + + ca->delayed_ack += sample->pkts_acked - + (ca->delayed_ack >> ACK_RATIO_SHIFT); + } +} + +static struct tcp_congestion_ops bictcp __read_mostly = { + .init = bictcp_init, + .ssthresh = bictcp_recalc_ssthresh, + .cong_avoid = bictcp_cong_avoid, + .set_state = bictcp_state, + .undo_cwnd = tcp_reno_undo_cwnd, + .pkts_acked = bictcp_acked, + .owner = THIS_MODULE, + .name = "bic", +}; + +static int __init bictcp_register(void) +{ + BUILD_BUG_ON(sizeof(struct bictcp) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&bictcp); +} + +static void __exit bictcp_unregister(void) +{ + 
tcp_unregister_congestion_control(&bictcp); +} + +module_init(bictcp_register); +module_exit(bictcp_unregister); + +MODULE_AUTHOR("Stephen Hemminger"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("BIC TCP"); diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c new file mode 100644 index 000000000..f8037d142 --- /dev/null +++ b/net/ipv4/tcp_bpf.c @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +void tcp_eat_skb(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tcp; + int copied; + + if (!skb || !skb->len || !sk_is_tcp(sk)) + return; + + if (skb_bpf_strparser(skb)) + return; + + tcp = tcp_sk(sk); + copied = tcp->copied_seq + skb->len; + WRITE_ONCE(tcp->copied_seq, copied); + tcp_rcv_space_adjust(sk); + __tcp_cleanup_rbuf(sk, skb->len); +} + +static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg, u32 apply_bytes, int flags) +{ + bool apply = apply_bytes; + struct scatterlist *sge; + u32 size, copied = 0; + struct sk_msg *tmp; + int i, ret = 0; + + tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL); + if (unlikely(!tmp)) + return -ENOMEM; + + lock_sock(sk); + tmp->sg.start = msg->sg.start; + i = msg->sg.start; + do { + sge = sk_msg_elem(msg, i); + size = (apply && apply_bytes < sge->length) ? + apply_bytes : sge->length; + if (!sk_wmem_schedule(sk, size)) { + if (!copied) + ret = -ENOMEM; + break; + } + + sk_mem_charge(sk, size); + sk_msg_xfer(tmp, msg, i, size); + copied += size; + if (sge->length) + get_page(sk_msg_page(tmp, i)); + sk_msg_iter_var_next(i); + tmp->sg.end = i; + if (apply) { + apply_bytes -= size; + if (!apply_bytes) { + if (sge->length) + sk_msg_iter_var_prev(i); + break; + } + } + } while (i != msg->sg.end); + + if (!ret) { + msg->sg.start = i; + sk_psock_queue_msg(psock, tmp); + sk_psock_data_ready(sk, psock); + } else { + sk_msg_free(sk, tmp); + kfree(tmp); + } + + release_sock(sk); + return ret; +} + +static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes, + int flags, bool uncharge) +{ + bool apply = apply_bytes; + struct scatterlist *sge; + struct page *page; + int size, ret = 0; + u32 off; + + while (1) { + bool has_tx_ulp; + + sge = sk_msg_elem(msg, msg->sg.start); + size = (apply && apply_bytes < sge->length) ? 
+ apply_bytes : sge->length; + off = sge->offset; + page = sg_page(sge); + + tcp_rate_check_app_limited(sk); +retry: + has_tx_ulp = tls_sw_has_ctx_tx(sk); + if (has_tx_ulp) { + flags |= MSG_SENDPAGE_NOPOLICY; + ret = kernel_sendpage_locked(sk, + page, off, size, flags); + } else { + ret = do_tcp_sendpages(sk, page, off, size, flags); + } + + if (ret <= 0) + return ret; + if (apply) + apply_bytes -= ret; + msg->sg.size -= ret; + sge->offset += ret; + sge->length -= ret; + if (uncharge) + sk_mem_uncharge(sk, ret); + if (ret != size) { + size -= ret; + off += ret; + goto retry; + } + if (!sge->length) { + put_page(page); + sk_msg_iter_next(msg, start); + sg_init_table(sge, 1); + if (msg->sg.start == msg->sg.end) + break; + } + if (apply && !apply_bytes) + break; + } + + return 0; +} + +static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg, + u32 apply_bytes, int flags, bool uncharge) +{ + int ret; + + lock_sock(sk); + ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge); + release_sock(sk); + return ret; +} + +int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress, + struct sk_msg *msg, u32 bytes, int flags) +{ + struct sk_psock *psock = sk_psock_get(sk); + int ret; + + if (unlikely(!psock)) + return -EPIPE; + + ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) : + tcp_bpf_push_locked(sk, msg, bytes, flags, false); + sk_psock_put(sk, psock); + return ret; +} +EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); + +#ifdef CONFIG_BPF_SYSCALL +static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, + long timeo) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret = 0; + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 1; + + if (!timeo) + return ret; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + ret = sk_wait_event(sk, &timeo, + !list_empty(&psock->ingress_msg) || + !skb_queue_empty_lockless(&sk->sk_receive_queue), &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return ret; +} + +static bool is_next_msg_fin(struct sk_psock *psock) +{ + struct scatterlist *sge; + struct sk_msg *msg_rx; + int i; + + msg_rx = sk_psock_peek_msg(psock); + i = msg_rx->sg.start; + sge = sk_msg_elem(msg_rx, i); + if (!sge->length) { + struct sk_buff *skb = msg_rx->skb; + + if (skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + return true; + } + return false; +} + +static int tcp_bpf_recvmsg_parser(struct sock *sk, + struct msghdr *msg, + size_t len, + int flags, + int *addr_len) +{ + struct tcp_sock *tcp = tcp_sk(sk); + int peek = flags & MSG_PEEK; + u32 seq = tcp->copied_seq; + struct sk_psock *psock; + int copied = 0; + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + if (!len) + return 0; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_recvmsg(sk, msg, len, flags, addr_len); + + lock_sock(sk); + + /* We may have received data on the sk_receive_queue pre-accept and + * then we can not use read_skb in this context because we haven't + * assigned a sk_socket yet so have no link to the ops. The work-around + * is to check the sk_receive_queue and in these cases read skbs off + * queue again. The read_skb hook is not running at this point because + * of lock_sock so we avoid having multiple runners in read_skb. + */ + if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) { + tcp_data_ready(sk); + /* This handles the ENOMEM errors if we both receive data + * pre accept and are already under memory pressure. 
At least + * let user know to retry. + */ + if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) { + copied = -EAGAIN; + goto out; + } + } + +msg_bytes_ready: + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + /* The typical case for EFAULT is the socket was gracefully + * shutdown with a FIN pkt. So check here the other case is + * some error on copy_page_to_iter which would be unexpected. + * On fin return correct return code to zero. + */ + if (copied == -EFAULT) { + bool is_fin = is_next_msg_fin(psock); + + if (is_fin) { + copied = 0; + seq++; + goto out; + } + } + seq += copied; + if (!copied) { + long timeo; + int data; + + if (sock_flag(sk, SOCK_DONE)) + goto out; + + if (sk->sk_err) { + copied = sock_error(sk); + goto out; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto out; + + if (sk->sk_state == TCP_CLOSE) { + copied = -ENOTCONN; + goto out; + } + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + if (!timeo) { + copied = -EAGAIN; + goto out; + } + + if (signal_pending(current)) { + copied = sock_intr_errno(timeo); + goto out; + } + + data = tcp_msg_wait_data(sk, psock, timeo); + if (data < 0) { + copied = data; + goto unlock; + } + if (data && !sk_psock_queue_empty(psock)) + goto msg_bytes_ready; + copied = -EAGAIN; + } +out: + if (!peek) + WRITE_ONCE(tcp->copied_seq, seq); + tcp_rcv_space_adjust(sk); + if (copied > 0) + __tcp_cleanup_rbuf(sk, copied); + +unlock: + release_sock(sk); + sk_psock_put(sk, psock); + return copied; +} + +static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct sk_psock *psock; + int copied, ret; + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + if (!len) + return 0; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_recvmsg(sk, msg, len, flags, addr_len); + if (!skb_queue_empty(&sk->sk_receive_queue) && + sk_psock_queue_empty(psock)) { + sk_psock_put(sk, psock); + return tcp_recvmsg(sk, msg, len, flags, addr_len); + } + lock_sock(sk); +msg_bytes_ready: + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + if (!copied) { + long timeo; + int data; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + data = tcp_msg_wait_data(sk, psock, timeo); + if (data < 0) { + ret = data; + goto unlock; + } + if (data) { + if (!sk_psock_queue_empty(psock)) + goto msg_bytes_ready; + release_sock(sk); + sk_psock_put(sk, psock); + return tcp_recvmsg(sk, msg, len, flags, addr_len); + } + copied = -EAGAIN; + } + ret = copied; + +unlock: + release_sock(sk); + sk_psock_put(sk, psock); + return ret; +} + +static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, + struct sk_msg *msg, int *copied, int flags) +{ + bool cork = false, enospc = sk_msg_full(msg), redir_ingress; + struct sock *sk_redir; + u32 tosend, origsize, sent, delta = 0; + u32 eval; + int ret; + +more_data: + if (psock->eval == __SK_NONE) { + /* Track delta in msg size to add/subtract it on SK_DROP from + * returned to user copied size. This ensures user doesn't + * get a positive return code with msg_cut_data and SK_DROP + * verdict. 
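+ *
+ * (Rough illustration of this accounting, not part of the upstream
+ *  source: if the verdict program trims an 8 KB send down to 4 KB,
+ *  delta ends up as 4 KB; on SK_DROP the code below subtracts
+ *  tosend + delta = 8 KB from *copied, so the caller does not see the
+ *  dropped-and-trimmed bytes reported as sent.)
+ *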
+ */ + delta = msg->sg.size; + psock->eval = sk_psock_msg_verdict(sk, psock, msg); + delta -= msg->sg.size; + } + + if (msg->cork_bytes && + msg->cork_bytes > msg->sg.size && !enospc) { + psock->cork_bytes = msg->cork_bytes - msg->sg.size; + if (!psock->cork) { + psock->cork = kzalloc(sizeof(*psock->cork), + GFP_ATOMIC | __GFP_NOWARN); + if (!psock->cork) + return -ENOMEM; + } + memcpy(psock->cork, msg, sizeof(*msg)); + return 0; + } + + tosend = msg->sg.size; + if (psock->apply_bytes && psock->apply_bytes < tosend) + tosend = psock->apply_bytes; + eval = __SK_NONE; + + switch (psock->eval) { + case __SK_PASS: + ret = tcp_bpf_push(sk, msg, tosend, flags, true); + if (unlikely(ret)) { + *copied -= sk_msg_free(sk, msg); + break; + } + sk_msg_apply_bytes(psock, tosend); + break; + case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; + sk_redir = psock->sk_redir; + sk_msg_apply_bytes(psock, tosend); + if (!psock->apply_bytes) { + /* Clean up before releasing the sock lock. */ + eval = psock->eval; + psock->eval = __SK_NONE; + psock->sk_redir = NULL; + } + if (psock->cork) { + cork = true; + psock->cork = NULL; + } + sk_msg_return(sk, msg, tosend); + release_sock(sk); + + origsize = msg->sg.size; + ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + msg, tosend, flags); + sent = origsize - msg->sg.size; + + if (eval == __SK_REDIRECT) + sock_put(sk_redir); + + lock_sock(sk); + if (unlikely(ret < 0)) { + int free = sk_msg_free_nocharge(sk, msg); + + if (!cork) + *copied -= free; + } + if (cork) { + sk_msg_free(sk, msg); + kfree(msg); + msg = NULL; + ret = 0; + } + break; + case __SK_DROP: + default: + sk_msg_free_partial(sk, msg, tosend); + sk_msg_apply_bytes(psock, tosend); + *copied -= (tosend + delta); + return -EACCES; + } + + if (likely(!ret)) { + if (!psock->apply_bytes) { + psock->eval = __SK_NONE; + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + } + if (msg && + msg->sg.data[msg->sg.start].page_link && + msg->sg.data[msg->sg.start].length) { + if (eval == __SK_REDIRECT) + sk_mem_charge(sk, tosend - sent); + goto more_data; + } + } + return ret; +} + +static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct sk_msg tmp, *msg_tx = NULL; + int copied = 0, err = 0; + struct sk_psock *psock; + long timeo; + int flags; + + /* Don't let internal do_tcp_sendpages() flags through */ + flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED); + flags |= MSG_NO_SHARED_FRAGS; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_sendmsg(sk, msg, size); + + lock_sock(sk); + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + while (msg_data_left(msg)) { + bool enospc = false; + u32 copy, osize; + + if (sk->sk_err) { + err = -sk->sk_err; + goto out_err; + } + + copy = msg_data_left(msg); + if (!sk_stream_memory_free(sk)) + goto wait_for_sndbuf; + if (psock->cork) { + msg_tx = psock->cork; + } else { + msg_tx = &tmp; + sk_msg_init(msg_tx); + } + + osize = msg_tx->sg.size; + err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1); + if (err) { + if (err != -ENOSPC) + goto wait_for_memory; + enospc = true; + copy = msg_tx->sg.size - osize; + } + + err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx, + copy); + if (err < 0) { + sk_msg_trim(sk, msg_tx, osize); + goto out_err; + } + + copied += copy; + if (psock->cork_bytes) { + if (size > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= size; + if (psock->cork_bytes && !enospc) + goto out_err; + /* All cork bytes are 
accounted, rerun the prog. */ + psock->eval = __SK_NONE; + psock->cork_bytes = 0; + } + + err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags); + if (unlikely(err < 0)) + goto out_err; + continue; +wait_for_sndbuf: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); +wait_for_memory: + err = sk_stream_wait_memory(sk, &timeo); + if (err) { + if (msg_tx && msg_tx != psock->cork) + sk_msg_free(sk, msg_tx); + goto out_err; + } + } +out_err: + if (err < 0) + err = sk_stream_error(sk, msg->msg_flags, err); + release_sock(sk); + sk_psock_put(sk, psock); + return copied ? copied : err; +} + +static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct sk_msg tmp, *msg = NULL; + int err = 0, copied = 0; + struct sk_psock *psock; + bool enospc = false; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return tcp_sendpage(sk, page, offset, size, flags); + + lock_sock(sk); + if (psock->cork) { + msg = psock->cork; + } else { + msg = &tmp; + sk_msg_init(msg); + } + + /* Catch case where ring is full and sendpage is stalled. */ + if (unlikely(sk_msg_full(msg))) + goto out_err; + + sk_msg_page_add(msg, page, size, offset); + sk_mem_charge(sk, size); + copied = size; + if (sk_msg_full(msg)) + enospc = true; + if (psock->cork_bytes) { + if (size > psock->cork_bytes) + psock->cork_bytes = 0; + else + psock->cork_bytes -= size; + if (psock->cork_bytes && !enospc) + goto out_err; + /* All cork bytes are accounted, rerun the prog. */ + psock->eval = __SK_NONE; + psock->cork_bytes = 0; + } + + err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags); +out_err: + release_sock(sk); + sk_psock_put(sk, psock); + return copied ? copied : err; +} + +enum { + TCP_BPF_IPV4, + TCP_BPF_IPV6, + TCP_BPF_NUM_PROTS, +}; + +enum { + TCP_BPF_BASE, + TCP_BPF_TX, + TCP_BPF_RX, + TCP_BPF_TXRX, + TCP_BPF_NUM_CFGS, +}; + +static struct proto *tcpv6_prot_saved __read_mostly; +static DEFINE_SPINLOCK(tcpv6_prot_lock); +static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS]; + +static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], + struct proto *base) +{ + prot[TCP_BPF_BASE] = *base; + prot[TCP_BPF_BASE].destroy = sock_map_destroy; + prot[TCP_BPF_BASE].close = sock_map_close; + prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; + prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable; + + prot[TCP_BPF_TX] = prot[TCP_BPF_BASE]; + prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg; + prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage; + + prot[TCP_BPF_RX] = prot[TCP_BPF_BASE]; + prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser; + + prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX]; + prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser; +} + +static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops) +{ + if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) { + spin_lock_bh(&tcpv6_prot_lock); + if (likely(ops != tcpv6_prot_saved)) { + tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops); + smp_store_release(&tcpv6_prot_saved, ops); + } + spin_unlock_bh(&tcpv6_prot_lock); + } +} + +static int __init tcp_bpf_v4_build_proto(void) +{ + tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot); + return 0; +} +late_initcall(tcp_bpf_v4_build_proto); + +static int tcp_bpf_assert_proto_ops(struct proto *ops) +{ + /* In order to avoid retpoline, we make assumptions when we call + * into ops if e.g. a psock is not present. Make sure they are + * indeed valid assumptions. 
+ */ + return ops->recvmsg == tcp_recvmsg && + ops->sendmsg == tcp_sendmsg && + ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP; +} + +int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) +{ + int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4; + int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE; + + if (psock->progs.stream_verdict || psock->progs.skb_verdict) { + config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX; + } + + if (restore) { + if (inet_csk_has_ulp(sk)) { + /* TLS does not have an unhash proto in SW cases, + * but we need to ensure we stop using the sock_map + * unhash routine because the associated psock is being + * removed. So use the original unhash handler. + */ + WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash); + tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space); + } else { + sk->sk_write_space = psock->saved_write_space; + /* Pairs with lockless read in sk_clone_lock() */ + sock_replace_proto(sk, psock->sk_proto); + } + return 0; + } + + if (sk->sk_family == AF_INET6) { + if (tcp_bpf_assert_proto_ops(psock->sk_proto)) + return -EINVAL; + + tcp_bpf_check_v6_needs_rebuild(psock->sk_proto); + } + + /* Pairs with lockless read in sk_clone_lock() */ + sock_replace_proto(sk, &tcp_bpf_prots[family][config]); + return 0; +} +EXPORT_SYMBOL_GPL(tcp_bpf_update_proto); + +/* If a child got cloned from a listening socket that had tcp_bpf + * protocol callbacks installed, we need to restore the callbacks to + * the default ones because the child does not inherit the psock state + * that tcp_bpf callbacks expect. + */ +void tcp_bpf_clone(const struct sock *sk, struct sock *newsk) +{ + struct proto *prot = newsk->sk_prot; + + if (is_insidevar(prot, tcp_bpf_prots)) + newsk->sk_prot = sk->sk_prot_creator; +} +#endif /* CONFIG_BPF_SYSCALL */ diff --git a/net/ipv4/tcp_cdg.c b/net/ipv4/tcp_cdg.c new file mode 100644 index 000000000..ba4d98e51 --- /dev/null +++ b/net/ipv4/tcp_cdg.c @@ -0,0 +1,428 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * CAIA Delay-Gradient (CDG) congestion control + * + * This implementation is based on the paper: + * D.A. Hayes and G. Armitage. "Revisiting TCP congestion control using + * delay gradients." In IFIP Networking, pages 328-341. Springer, 2011. + * + * Scavenger traffic (Less-than-Best-Effort) should disable coexistence + * heuristics using parameters use_shadow=0 and use_ineff=0. + * + * Parameters window, backoff_beta, and backoff_factor are crucial for + * throughput and delay. Future work is needed to determine better defaults, + * and to provide guidelines for use in different environments/contexts. + * + * Except for window, knobs are configured via /sys/module/tcp_cdg/parameters/. + * Parameter window is only configurable when loading tcp_cdg as a module. + * + * Notable differences from paper/FreeBSD: + * o Using Hybrid Slow start and Proportional Rate Reduction. + * o Add toggle for shadow window mechanism. Suggested by David Hayes. + * o Add toggle for non-congestion loss tolerance. + * o Scaling parameter G is changed to a backoff factor; + * conversion is given by: backoff_factor = 1000/(G * window). + * o Limit shadow window to 2 * cwnd, or to cwnd when application limited. + * o More accurate e^-x. 
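As a worked example of the conversion above: with the default gradient window of 8 used below, the default backoff_factor of 42 corresponds to a scaling parameter G of about 3, since 1000 / (3 * 8) is roughly 42.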
+ */ +#include +#include +#include +#include + +#include + +#define HYSTART_ACK_TRAIN 1 +#define HYSTART_DELAY 2 + +static int window __read_mostly = 8; +static unsigned int backoff_beta __read_mostly = 0.7071 * 1024; /* sqrt 0.5 */ +static unsigned int backoff_factor __read_mostly = 42; +static unsigned int hystart_detect __read_mostly = 3; +static unsigned int use_ineff __read_mostly = 5; +static bool use_shadow __read_mostly = true; +static bool use_tolerance __read_mostly; + +module_param(window, int, 0444); +MODULE_PARM_DESC(window, "gradient window size (power of two <= 256)"); +module_param(backoff_beta, uint, 0644); +MODULE_PARM_DESC(backoff_beta, "backoff beta (0-1024)"); +module_param(backoff_factor, uint, 0644); +MODULE_PARM_DESC(backoff_factor, "backoff probability scale factor"); +module_param(hystart_detect, uint, 0644); +MODULE_PARM_DESC(hystart_detect, "use Hybrid Slow start " + "(0: disabled, 1: ACK train, 2: delay threshold, 3: both)"); +module_param(use_ineff, uint, 0644); +MODULE_PARM_DESC(use_ineff, "use ineffectual backoff detection (threshold)"); +module_param(use_shadow, bool, 0644); +MODULE_PARM_DESC(use_shadow, "use shadow window heuristic"); +module_param(use_tolerance, bool, 0644); +MODULE_PARM_DESC(use_tolerance, "use loss tolerance heuristic"); + +struct cdg_minmax { + union { + struct { + s32 min; + s32 max; + }; + u64 v64; + }; +}; + +enum cdg_state { + CDG_UNKNOWN = 0, + CDG_NONFULL = 1, + CDG_FULL = 2, + CDG_BACKOFF = 3, +}; + +struct cdg { + struct cdg_minmax rtt; + struct cdg_minmax rtt_prev; + struct cdg_minmax *gradients; + struct cdg_minmax gsum; + bool gfilled; + u8 tail; + u8 state; + u8 delack; + u32 rtt_seq; + u32 shadow_wnd; + u16 backoff_cnt; + u16 sample_cnt; + s32 delay_min; + u32 last_ack; + u32 round_start; +}; + +/** + * nexp_u32 - negative base-e exponential + * @ux: x in units of micro + * + * Returns exp(ux * -1e-6) * U32_MAX. + */ +static u32 __pure nexp_u32(u32 ux) +{ + static const u16 v[] = { + /* exp(-x)*65536-1 for x = 0, 0.000256, 0.000512, ... */ + 65535, + 65518, 65501, 65468, 65401, 65267, 65001, 64470, 63422, + 61378, 57484, 50423, 38795, 22965, 8047, 987, 14, + }; + u32 msb = ux >> 8; + u32 res; + int i; + + /* Cut off when ux >= 2^24 (actual result is <= 222/U32_MAX). */ + if (msb > U16_MAX) + return 0; + + /* Scale first eight bits linearly: */ + res = U32_MAX - (ux & 0xff) * (U32_MAX / 1000000); + + /* Obtain e^(x + y + ...) by computing e^x * e^y * ...: */ + for (i = 1; msb; i++, msb >>= 1) { + u32 y = v[i & -(msb & 1)] + U32_C(1); + + res = ((u64)res * y) >> 16; + } + + return res; +} + +/* Based on the HyStart algorithm (by Ha et al.) that is implemented in + * tcp_cubic. Differences/experimental changes: + * o Using Hayes' delayed ACK filter. + * o Using a usec clock for the ACK train. + * o Reset ACK train when application limited. + * o Invoked at any cwnd (i.e. also when cwnd < 16). + * o Invoked only when cwnd < ssthresh (i.e. not when cwnd == ssthresh). 
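To see how closely the fixed-point nexp_u32() above tracks the real negative exponential, the following standalone userspace harness (an illustrative sketch, not kernel code; it assumes only the C library and libm, and the file name is arbitrary) re-runs the same table-driven computation against exp():

/*
 * Standalone harness: compare the fixed-point nexp_u32() approximation
 * against libm's exp(). Build with something like: cc nexp_demo.c -lm
 */
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static uint32_t nexp_u32(uint32_t ux)
{
	static const uint16_t v[] = {
		/* exp(-x)*65536-1 for x = 0, 0.000256, 0.000512, ... */
		65535,
		65518, 65501, 65468, 65401, 65267, 65001, 64470, 63422,
		61378, 57484, 50423, 38795, 22965, 8047, 987, 14,
	};
	uint32_t msb = ux >> 8;
	uint32_t res;
	int i;

	if (msb > UINT16_MAX)
		return 0;

	/* Scale the low eight bits linearly... */
	res = UINT32_MAX - (ux & 0xff) * (UINT32_MAX / 1000000);

	/* ...then fold in one table entry per set bit of the high bits,
	 * using exp(a + b) == exp(a) * exp(b).
	 */
	for (i = 1; msb; i++, msb >>= 1) {
		uint32_t y = v[i & -(msb & 1)] + UINT32_C(1);

		res = (uint32_t)(((uint64_t)res * y) >> 16);
	}
	return res;
}

int main(void)
{
	uint32_t ux;

	for (ux = 1000; ux <= 4096000; ux *= 4) {
		double approx = nexp_u32(ux) / (double)UINT32_MAX;
		double exact = exp((double)ux * -1e-6);

		printf("ux=%8lu approx=%.6f exact=%.6f\n",
		       (unsigned long)ux, approx, exact);
	}
	return 0;
}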
+ */ +static void tcp_cdg_hystart_update(struct sock *sk) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + ca->delay_min = min_not_zero(ca->delay_min, ca->rtt.min); + if (ca->delay_min == 0) + return; + + if (hystart_detect & HYSTART_ACK_TRAIN) { + u32 now_us = tp->tcp_mstamp; + + if (ca->last_ack == 0 || !tcp_is_cwnd_limited(sk)) { + ca->last_ack = now_us; + ca->round_start = now_us; + } else if (before(now_us, ca->last_ack + 3000)) { + u32 base_owd = max(ca->delay_min / 2U, 125U); + + ca->last_ack = now_us; + if (after(now_us, ca->round_start + base_owd)) { + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTTRAINDETECT); + NET_ADD_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTTRAINCWND, + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); + return; + } + } + } + + if (hystart_detect & HYSTART_DELAY) { + if (ca->sample_cnt < 8) { + ca->sample_cnt++; + } else { + s32 thresh = max(ca->delay_min + ca->delay_min / 8U, + 125U); + + if (ca->rtt.min > thresh) { + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTDELAYDETECT); + NET_ADD_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTDELAYCWND, + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); + } + } + } +} + +static s32 tcp_cdg_grad(struct cdg *ca) +{ + s32 gmin = ca->rtt.min - ca->rtt_prev.min; + s32 gmax = ca->rtt.max - ca->rtt_prev.max; + s32 grad; + + if (ca->gradients) { + ca->gsum.min += gmin - ca->gradients[ca->tail].min; + ca->gsum.max += gmax - ca->gradients[ca->tail].max; + ca->gradients[ca->tail].min = gmin; + ca->gradients[ca->tail].max = gmax; + ca->tail = (ca->tail + 1) & (window - 1); + gmin = ca->gsum.min; + gmax = ca->gsum.max; + } + + /* We keep sums to ignore gradients during cwnd reductions; + * the paper's smoothed gradients otherwise simplify to: + * (rtt_latest - rtt_oldest) / window. + * + * We also drop division by window here. + */ + grad = gmin > 0 ? gmin : gmax; + + /* Extrapolate missing values in gradient window: */ + if (!ca->gfilled) { + if (!ca->gradients && window > 1) + grad *= window; /* Memory allocation failed. */ + else if (ca->tail == 0) + ca->gfilled = true; + else + grad = (grad * window) / (int)ca->tail; + } + + /* Backoff was effectual: */ + if (gmin <= -32 || gmax <= -32) + ca->backoff_cnt = 0; + + if (use_tolerance) { + /* Reduce small variations to zero: */ + gmin = DIV_ROUND_CLOSEST(gmin, 64); + gmax = DIV_ROUND_CLOSEST(gmax, 64); + + if (gmin > 0 && gmax <= 0) + ca->state = CDG_FULL; + else if ((gmin > 0 && gmax > 0) || gmax < 0) + ca->state = CDG_NONFULL; + } + return grad; +} + +static bool tcp_cdg_backoff(struct sock *sk, u32 grad) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (get_random_u32() <= nexp_u32(grad * backoff_factor)) + return false; + + if (use_ineff) { + ca->backoff_cnt++; + if (ca->backoff_cnt > use_ineff) + return false; + } + + ca->shadow_wnd = max(ca->shadow_wnd, tcp_snd_cwnd(tp)); + ca->state = CDG_BACKOFF; + tcp_enter_cwr(sk); + return true; +} + +/* Not called in CWR or Recovery state. 
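The windowed gradient sum kept by tcp_cdg_grad() and the probabilistic test in tcp_cdg_backoff() can be exercised in isolation. Below is a userspace sketch (illustration only, with made-up gradient samples; it tracks only the min-RTT gradient and omits the window fill-up extrapolation) that keeps the same ring-buffer sum and then computes, via libm, the backoff probability 1 - exp(-grad * backoff_factor / 1e6) that the kernel's nexp_u32() draw approximates:

#include <math.h>
#include <stdio.h>

#define WINDOW		8	/* must be a power of two, as in the module */
#define BACKOFF_FACTOR	42

static int gradients[WINDOW];
static int gsum;
static unsigned int tail;

/* Push one per-RTT gradient (rtt_min - prev_rtt_min, in usec) and return the
 * windowed sum, i.e. roughly rtt_latest - rtt_oldest without dividing by the
 * window size.
 */
static int push_gradient(int gmin)
{
	gsum += gmin - gradients[tail];
	gradients[tail] = gmin;
	tail = (tail + 1) & (WINDOW - 1);
	return gsum;
}

int main(void)
{
	/* A queue that builds up for a few RTTs, then drains. */
	static const int samples[] = { 100, 200, 400, 800, 400, -300, -900, 0 };
	unsigned int i;

	for (i = 0; i < sizeof(samples) / sizeof(samples[0]); i++) {
		int grad = push_gradient(samples[i]);
		double p = grad > 0 ?
			1.0 - exp(-(double)grad * BACKOFF_FACTOR * 1e-6) : 0.0;

		printf("gradient sum %5d -> backoff probability %.3f\n",
		       grad, p);
	}
	return 0;
}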
*/ +static void tcp_cdg_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 prior_snd_cwnd; + u32 incr; + + if (tcp_in_slow_start(tp) && hystart_detect) + tcp_cdg_hystart_update(sk); + + if (after(ack, ca->rtt_seq) && ca->rtt.v64) { + s32 grad = 0; + + if (ca->rtt_prev.v64) + grad = tcp_cdg_grad(ca); + ca->rtt_seq = tp->snd_nxt; + ca->rtt_prev = ca->rtt; + ca->rtt.v64 = 0; + ca->last_ack = 0; + ca->sample_cnt = 0; + + if (grad > 0 && tcp_cdg_backoff(sk, grad)) + return; + } + + if (!tcp_is_cwnd_limited(sk)) { + ca->shadow_wnd = min(ca->shadow_wnd, tcp_snd_cwnd(tp)); + return; + } + + prior_snd_cwnd = tcp_snd_cwnd(tp); + tcp_reno_cong_avoid(sk, ack, acked); + + incr = tcp_snd_cwnd(tp) - prior_snd_cwnd; + ca->shadow_wnd = max(ca->shadow_wnd, ca->shadow_wnd + incr); +} + +static void tcp_cdg_acked(struct sock *sk, const struct ack_sample *sample) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (sample->rtt_us <= 0) + return; + + /* A heuristic for filtering delayed ACKs, adapted from: + * D.A. Hayes. "Timing enhancements to the FreeBSD kernel to support + * delay and rate based TCP mechanisms." TR 100219A. CAIA, 2010. + */ + if (tp->sacked_out == 0) { + if (sample->pkts_acked == 1 && ca->delack) { + /* A delayed ACK is only used for the minimum if it is + * provenly lower than an existing non-zero minimum. + */ + ca->rtt.min = min(ca->rtt.min, sample->rtt_us); + ca->delack--; + return; + } else if (sample->pkts_acked > 1 && ca->delack < 5) { + ca->delack++; + } + } + + ca->rtt.min = min_not_zero(ca->rtt.min, sample->rtt_us); + ca->rtt.max = max(ca->rtt.max, sample->rtt_us); +} + +static u32 tcp_cdg_ssthresh(struct sock *sk) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if (ca->state == CDG_BACKOFF) + return max(2U, (tcp_snd_cwnd(tp) * min(1024U, backoff_beta)) >> 10); + + if (ca->state == CDG_NONFULL && use_tolerance) + return tcp_snd_cwnd(tp); + + ca->shadow_wnd = min(ca->shadow_wnd >> 1, tcp_snd_cwnd(tp)); + if (use_shadow) + return max3(2U, ca->shadow_wnd, tcp_snd_cwnd(tp) >> 1); + return max(2U, tcp_snd_cwnd(tp) >> 1); +} + +static void tcp_cdg_cwnd_event(struct sock *sk, const enum tcp_ca_event ev) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct cdg_minmax *gradients; + + switch (ev) { + case CA_EVENT_CWND_RESTART: + gradients = ca->gradients; + if (gradients) + memset(gradients, 0, window * sizeof(gradients[0])); + memset(ca, 0, sizeof(*ca)); + + ca->gradients = gradients; + ca->rtt_seq = tp->snd_nxt; + ca->shadow_wnd = tcp_snd_cwnd(tp); + break; + case CA_EVENT_COMPLETE_CWR: + ca->state = CDG_UNKNOWN; + ca->rtt_seq = tp->snd_nxt; + ca->rtt_prev = ca->rtt; + ca->rtt.v64 = 0; + break; + default: + break; + } +} + +static void tcp_cdg_init(struct sock *sk) +{ + struct cdg *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + ca->gradients = NULL; + /* We silently fall back to window = 1 if allocation fails. 
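tcp_cdg_ssthresh() above reduces cwnd differently depending on how the loss or backoff was classified. A small userspace sketch of the three outcomes for one example cwnd and shadow window (illustrative numbers only; the floor of 2 packets applies in all cases):

#include <stdio.h>

#define BACKOFF_BETA 724U	/* 0.7071 * 1024, the module default */

int main(void)
{
	unsigned int cwnd = 100, shadow_wnd = 160;
	unsigned int sh, half, res;

	/* Delay-gradient backoff: multiplicative decrease by beta ~= 0.71. */
	printf("backoff: %u\n", (cwnd * BACKOFF_BETA) >> 10);

	/* Loss while the queue looked non-full (loss tolerance on): keep cwnd. */
	printf("nonfull: %u\n", cwnd);

	/* Ordinary loss: halve cwnd, but a halved (and cwnd-capped) shadow
	 * window may keep the decrease smaller after earlier backoffs.
	 */
	sh = shadow_wnd >> 1;
	if (sh > cwnd)
		sh = cwnd;
	half = cwnd >> 1;
	res = sh > half ? sh : half;
	if (res < 2)
		res = 2;
	printf("loss:    %u\n", res);
	return 0;
}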
*/ + if (window > 1) + ca->gradients = kcalloc(window, sizeof(ca->gradients[0]), + GFP_NOWAIT | __GFP_NOWARN); + ca->rtt_seq = tp->snd_nxt; + ca->shadow_wnd = tcp_snd_cwnd(tp); +} + +static void tcp_cdg_release(struct sock *sk) +{ + struct cdg *ca = inet_csk_ca(sk); + + kfree(ca->gradients); + ca->gradients = NULL; +} + +static struct tcp_congestion_ops tcp_cdg __read_mostly = { + .cong_avoid = tcp_cdg_cong_avoid, + .cwnd_event = tcp_cdg_cwnd_event, + .pkts_acked = tcp_cdg_acked, + .undo_cwnd = tcp_reno_undo_cwnd, + .ssthresh = tcp_cdg_ssthresh, + .release = tcp_cdg_release, + .init = tcp_cdg_init, + .owner = THIS_MODULE, + .name = "cdg", +}; + +static int __init tcp_cdg_register(void) +{ + if (backoff_beta > 1024 || window < 1 || window > 256) + return -ERANGE; + if (!is_power_of_2(window)) + return -EINVAL; + + BUILD_BUG_ON(sizeof(struct cdg) > ICSK_CA_PRIV_SIZE); + tcp_register_congestion_control(&tcp_cdg); + return 0; +} + +static void __exit tcp_cdg_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_cdg); +} + +module_init(tcp_cdg_register); +module_exit(tcp_cdg_unregister); +MODULE_AUTHOR("Kenneth Klette Jonassen"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP CDG"); diff --git a/net/ipv4/tcp_cong.c b/net/ipv4/tcp_cong.c new file mode 100644 index 000000000..d3cae4074 --- /dev/null +++ b/net/ipv4/tcp_cong.c @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Pluggable TCP congestion control support and newReno + * congestion control. + * Based on ideas from I/O scheduler support and Web100. + * + * Copyright (C) 2005 Stephen Hemminger + */ + +#define pr_fmt(fmt) "TCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(tcp_cong_list_lock); +static LIST_HEAD(tcp_cong_list); + +/* Simple linear search, don't expect many entries! */ +struct tcp_congestion_ops *tcp_ca_find(const char *name) +{ + struct tcp_congestion_ops *e; + + list_for_each_entry_rcu(e, &tcp_cong_list, list) { + if (strcmp(e->name, name) == 0) + return e; + } + + return NULL; +} + +void tcp_set_ca_state(struct sock *sk, const u8 ca_state) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + trace_tcp_cong_state_set(sk, ca_state); + + if (icsk->icsk_ca_ops->set_state) + icsk->icsk_ca_ops->set_state(sk, ca_state); + icsk->icsk_ca_state = ca_state; +} + +/* Must be called with rcu lock held */ +static struct tcp_congestion_ops *tcp_ca_find_autoload(struct net *net, + const char *name) +{ + struct tcp_congestion_ops *ca = tcp_ca_find(name); + +#ifdef CONFIG_MODULES + if (!ca && capable(CAP_NET_ADMIN)) { + rcu_read_unlock(); + request_module("tcp_%s", name); + rcu_read_lock(); + ca = tcp_ca_find(name); + } +#endif + return ca; +} + +/* Simple linear search, not much in here. */ +struct tcp_congestion_ops *tcp_ca_find_key(u32 key) +{ + struct tcp_congestion_ops *e; + + list_for_each_entry_rcu(e, &tcp_cong_list, list) { + if (e->key == key) + return e; + } + + return NULL; +} + +/* + * Attach new congestion control algorithm to the list + * of available options. 
+ */ +int tcp_register_congestion_control(struct tcp_congestion_ops *ca) +{ + int ret = 0; + + /* all algorithms must implement these */ + if (!ca->ssthresh || !ca->undo_cwnd || + !(ca->cong_avoid || ca->cong_control)) { + pr_err("%s does not implement required ops\n", ca->name); + return -EINVAL; + } + + ca->key = jhash(ca->name, sizeof(ca->name), strlen(ca->name)); + + spin_lock(&tcp_cong_list_lock); + if (ca->key == TCP_CA_UNSPEC || tcp_ca_find_key(ca->key)) { + pr_notice("%s already registered or non-unique key\n", + ca->name); + ret = -EEXIST; + } else { + list_add_tail_rcu(&ca->list, &tcp_cong_list); + pr_debug("%s registered\n", ca->name); + } + spin_unlock(&tcp_cong_list_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(tcp_register_congestion_control); + +/* + * Remove congestion control algorithm, called from + * the module's remove function. Module ref counts are used + * to ensure that this can't be done till all sockets using + * that method are closed. + */ +void tcp_unregister_congestion_control(struct tcp_congestion_ops *ca) +{ + spin_lock(&tcp_cong_list_lock); + list_del_rcu(&ca->list); + spin_unlock(&tcp_cong_list_lock); + + /* Wait for outstanding readers to complete before the + * module gets removed entirely. + * + * A try_module_get() should fail by now as our module is + * in "going" state since no refs are held anymore and + * module_exit() handler being called. + */ + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(tcp_unregister_congestion_control); + +u32 tcp_ca_get_key_by_name(struct net *net, const char *name, bool *ecn_ca) +{ + const struct tcp_congestion_ops *ca; + u32 key = TCP_CA_UNSPEC; + + might_sleep(); + + rcu_read_lock(); + ca = tcp_ca_find_autoload(net, name); + if (ca) { + key = ca->key; + *ecn_ca = ca->flags & TCP_CONG_NEEDS_ECN; + } + rcu_read_unlock(); + + return key; +} + +char *tcp_ca_get_name_by_key(u32 key, char *buffer) +{ + const struct tcp_congestion_ops *ca; + char *ret = NULL; + + rcu_read_lock(); + ca = tcp_ca_find_key(key); + if (ca) + ret = strncpy(buffer, ca->name, + TCP_CA_NAME_MAX); + rcu_read_unlock(); + + return ret; +} + +/* Assign choice of congestion control. */ +void tcp_assign_congestion_control(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_congestion_ops *ca; + + rcu_read_lock(); + ca = rcu_dereference(net->ipv4.tcp_congestion_control); + if (unlikely(!bpf_try_module_get(ca, ca->owner))) + ca = &tcp_reno; + icsk->icsk_ca_ops = ca; + rcu_read_unlock(); + + memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv)); + if (ca->flags & TCP_CONG_NEEDS_ECN) + INET_ECN_xmit(sk); + else + INET_ECN_dontxmit(sk); +} + +void tcp_init_congestion_control(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_sk(sk)->prior_ssthresh = 0; + if (icsk->icsk_ca_ops->init) + icsk->icsk_ca_ops->init(sk); + if (tcp_ca_needs_ecn(sk)) + INET_ECN_xmit(sk); + else + INET_ECN_dontxmit(sk); + icsk->icsk_ca_initialized = 1; +} + +static void tcp_reinit_congestion_control(struct sock *sk, + const struct tcp_congestion_ops *ca) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_cleanup_congestion_control(sk); + icsk->icsk_ca_ops = ca; + icsk->icsk_ca_setsockopt = 1; + memset(icsk->icsk_ca_priv, 0, sizeof(icsk->icsk_ca_priv)); + + if (ca->flags & TCP_CONG_NEEDS_ECN) + INET_ECN_xmit(sk); + else + INET_ECN_dontxmit(sk); + + if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) + tcp_init_congestion_control(sk); +} + +/* Manage refcounts on socket close. 
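The only ops that tcp_register_congestion_control() insists on are ssthresh, undo_cwnd, and one of cong_avoid or cong_control. A hypothetical out-of-tree module (names invented for illustration) that simply reuses the Reno helpers exported later in this file would therefore reduce to:

/* Hypothetical registration sketch, not part of this patch. */
#include <linux/module.h>
#include <net/tcp.h>

static struct tcp_congestion_ops tcp_minimal __read_mostly = {
	.ssthresh	= tcp_reno_ssthresh,	/* mandatory */
	.undo_cwnd	= tcp_reno_undo_cwnd,	/* mandatory */
	.cong_avoid	= tcp_reno_cong_avoid,	/* or .cong_control */
	.owner		= THIS_MODULE,
	.name		= "minimal",
};

static int __init tcp_minimal_register(void)
{
	return tcp_register_congestion_control(&tcp_minimal);
}

static void __exit tcp_minimal_unregister(void)
{
	tcp_unregister_congestion_control(&tcp_minimal);
}

module_init(tcp_minimal_register);
module_exit(tcp_minimal_unregister);
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Minimal congestion control registration sketch");

Once loaded, such a module shows up in net.ipv4.tcp_available_congestion_control and can be selected per socket with the TCP_CONGESTION socket option.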
*/ +void tcp_cleanup_congestion_control(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ca_ops->release) + icsk->icsk_ca_ops->release(sk); + bpf_module_put(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner); +} + +/* Used by sysctl to change default congestion control */ +int tcp_set_default_congestion_control(struct net *net, const char *name) +{ + struct tcp_congestion_ops *ca; + const struct tcp_congestion_ops *prev; + int ret; + + rcu_read_lock(); + ca = tcp_ca_find_autoload(net, name); + if (!ca) { + ret = -ENOENT; + } else if (!bpf_try_module_get(ca, ca->owner)) { + ret = -EBUSY; + } else if (!net_eq(net, &init_net) && + !(ca->flags & TCP_CONG_NON_RESTRICTED)) { + /* Only init netns can set default to a restricted algorithm */ + ret = -EPERM; + } else { + prev = xchg(&net->ipv4.tcp_congestion_control, ca); + if (prev) + bpf_module_put(prev, prev->owner); + + ca->flags |= TCP_CONG_NON_RESTRICTED; + ret = 0; + } + rcu_read_unlock(); + + return ret; +} + +/* Set default value from kernel configuration at bootup */ +static int __init tcp_congestion_default(void) +{ + return tcp_set_default_congestion_control(&init_net, + CONFIG_DEFAULT_TCP_CONG); +} +late_initcall(tcp_congestion_default); + +/* Build string with list of available congestion control values */ +void tcp_get_available_congestion_control(char *buf, size_t maxlen) +{ + struct tcp_congestion_ops *ca; + size_t offs = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(ca, &tcp_cong_list, list) { + offs += snprintf(buf + offs, maxlen - offs, + "%s%s", + offs == 0 ? "" : " ", ca->name); + + if (WARN_ON_ONCE(offs >= maxlen)) + break; + } + rcu_read_unlock(); +} + +/* Get current default congestion control */ +void tcp_get_default_congestion_control(struct net *net, char *name) +{ + const struct tcp_congestion_ops *ca; + + rcu_read_lock(); + ca = rcu_dereference(net->ipv4.tcp_congestion_control); + strncpy(name, ca->name, TCP_CA_NAME_MAX); + rcu_read_unlock(); +} + +/* Built list of non-restricted congestion control values */ +void tcp_get_allowed_congestion_control(char *buf, size_t maxlen) +{ + struct tcp_congestion_ops *ca; + size_t offs = 0; + + *buf = '\0'; + rcu_read_lock(); + list_for_each_entry_rcu(ca, &tcp_cong_list, list) { + if (!(ca->flags & TCP_CONG_NON_RESTRICTED)) + continue; + offs += snprintf(buf + offs, maxlen - offs, + "%s%s", + offs == 0 ? "" : " ", ca->name); + + if (WARN_ON_ONCE(offs >= maxlen)) + break; + } + rcu_read_unlock(); +} + +/* Change list of non-restricted congestion control */ +int tcp_set_allowed_congestion_control(char *val) +{ + struct tcp_congestion_ops *ca; + char *saved_clone, *clone, *name; + int ret = 0; + + saved_clone = clone = kstrdup(val, GFP_USER); + if (!clone) + return -ENOMEM; + + spin_lock(&tcp_cong_list_lock); + /* pass 1 check for bad entries */ + while ((name = strsep(&clone, " ")) && *name) { + ca = tcp_ca_find(name); + if (!ca) { + ret = -ENOENT; + goto out; + } + } + + /* pass 2 clear old values */ + list_for_each_entry_rcu(ca, &tcp_cong_list, list) + ca->flags &= ~TCP_CONG_NON_RESTRICTED; + + /* pass 3 mark as allowed */ + while ((name = strsep(&val, " ")) && *name) { + ca = tcp_ca_find(name); + WARN_ON(!ca); + if (ca) + ca->flags |= TCP_CONG_NON_RESTRICTED; + } +out: + spin_unlock(&tcp_cong_list_lock); + kfree(saved_clone); + + return ret; +} + +/* Change congestion control for socket. 
If load is false, then it is the + * responsibility of the caller to call tcp_init_congestion_control or + * tcp_reinit_congestion_control (if the current congestion control was + * already initialized. + */ +int tcp_set_congestion_control(struct sock *sk, const char *name, bool load, + bool cap_net_admin) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_congestion_ops *ca; + int err = 0; + + if (icsk->icsk_ca_dst_locked) + return -EPERM; + + rcu_read_lock(); + if (!load) + ca = tcp_ca_find(name); + else + ca = tcp_ca_find_autoload(sock_net(sk), name); + + /* No change asking for existing value */ + if (ca == icsk->icsk_ca_ops) { + icsk->icsk_ca_setsockopt = 1; + goto out; + } + + if (!ca) + err = -ENOENT; + else if (!((ca->flags & TCP_CONG_NON_RESTRICTED) || cap_net_admin)) + err = -EPERM; + else if (!bpf_try_module_get(ca, ca->owner)) + err = -EBUSY; + else + tcp_reinit_congestion_control(sk, ca); + out: + rcu_read_unlock(); + return err; +} + +/* Slow start is used when congestion window is no greater than the slow start + * threshold. We base on RFC2581 and also handle stretch ACKs properly. + * We do not implement RFC3465 Appropriate Byte Counting (ABC) per se but + * something better;) a packet is only considered (s)acked in its entirety to + * defend the ACK attacks described in the RFC. Slow start processes a stretch + * ACK of degree N as if N acks of degree 1 are received back to back except + * ABC caps N to 2. Slow start exits when cwnd grows over ssthresh and + * returns the leftover acks to adjust cwnd in congestion avoidance mode. + */ +u32 tcp_slow_start(struct tcp_sock *tp, u32 acked) +{ + u32 cwnd = min(tcp_snd_cwnd(tp) + acked, tp->snd_ssthresh); + + acked -= cwnd - tcp_snd_cwnd(tp); + tcp_snd_cwnd_set(tp, min(cwnd, tp->snd_cwnd_clamp)); + + return acked; +} +EXPORT_SYMBOL_GPL(tcp_slow_start); + +/* In theory this is tp->snd_cwnd += 1 / tp->snd_cwnd (or alternative w), + * for every packet that was ACKed. + */ +void tcp_cong_avoid_ai(struct tcp_sock *tp, u32 w, u32 acked) +{ + /* If credits accumulated at a higher w, apply them gently now. */ + if (tp->snd_cwnd_cnt >= w) { + tp->snd_cwnd_cnt = 0; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + } + + tp->snd_cwnd_cnt += acked; + if (tp->snd_cwnd_cnt >= w) { + u32 delta = tp->snd_cwnd_cnt / w; + + tp->snd_cwnd_cnt -= delta * w; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + delta); + } + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_cwnd_clamp)); +} +EXPORT_SYMBOL_GPL(tcp_cong_avoid_ai); + +/* + * TCP Reno congestion control + * This is special case used for fallback as well. + */ +/* This is Jacobson's slow start and congestion avoidance. + * SIGCOMM '88, p. 328. + */ +void tcp_reno_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + /* In "safe" area, increase. */ + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + return; + } + /* In dangerous area, increase slowly. 
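The interplay of tcp_slow_start() and tcp_cong_avoid_ai() above is easiest to see with concrete numbers. The following userspace sketch (illustration only; it drops the snd_cwnd_clamp cap and the accumulated-credit special case) feeds a stream of stretch ACKs through the same arithmetic:

/* cwnd, ssthresh and ACK counts are in packets, as in the kernel. */
#include <stdio.h>

static unsigned int cwnd = 10, ssthresh = 16, cwnd_cnt;

static unsigned int slow_start(unsigned int acked)
{
	unsigned int new_cwnd = cwnd + acked;

	if (new_cwnd > ssthresh)
		new_cwnd = ssthresh;
	acked -= new_cwnd - cwnd;	/* leftover ACKs for additive increase */
	cwnd = new_cwnd;
	return acked;
}

static void cong_avoid_ai(unsigned int w, unsigned int acked)
{
	cwnd_cnt += acked;
	if (cwnd_cnt >= w) {		/* +1 cwnd per w packets ACKed */
		cwnd += cwnd_cnt / w;
		cwnd_cnt %= w;
	}
}

static void on_ack(unsigned int acked)
{
	if (cwnd < ssthresh)		/* i.e. tcp_in_slow_start() */
		acked = slow_start(acked);
	if (acked)
		cong_avoid_ai(cwnd, acked);
	printf("cwnd=%u cwnd_cnt=%u\n", cwnd, cwnd_cnt);
}

int main(void)
{
	unsigned int i;

	for (i = 0; i < 10; i++)
		on_ack(3);		/* each ACK covers three packets */
	return 0;
}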
*/ + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); +} +EXPORT_SYMBOL_GPL(tcp_reno_cong_avoid); + +/* Slow start threshold is half the congestion window (min 2) */ +u32 tcp_reno_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + return max(tcp_snd_cwnd(tp) >> 1U, 2U); +} +EXPORT_SYMBOL_GPL(tcp_reno_ssthresh); + +u32 tcp_reno_undo_cwnd(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + return max(tcp_snd_cwnd(tp), tp->prior_cwnd); +} +EXPORT_SYMBOL_GPL(tcp_reno_undo_cwnd); + +struct tcp_congestion_ops tcp_reno = { + .flags = TCP_CONG_NON_RESTRICTED, + .name = "reno", + .owner = THIS_MODULE, + .ssthresh = tcp_reno_ssthresh, + .cong_avoid = tcp_reno_cong_avoid, + .undo_cwnd = tcp_reno_undo_cwnd, +}; diff --git a/net/ipv4/tcp_cubic.c b/net/ipv4/tcp_cubic.c new file mode 100644 index 000000000..768c10c1f --- /dev/null +++ b/net/ipv4/tcp_cubic.c @@ -0,0 +1,557 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP CUBIC: Binary Increase Congestion control for TCP v2.3 + * Home page: + * http://netsrv.csc.ncsu.edu/twiki/bin/view/Main/BIC + * This is from the implementation of CUBIC TCP in + * Sangtae Ha, Injong Rhee and Lisong Xu, + * "CUBIC: A New TCP-Friendly High-Speed TCP Variant" + * in ACM SIGOPS Operating System Review, July 2008. + * Available from: + * http://netsrv.csc.ncsu.edu/export/cubic_a_new_tcp_2008.pdf + * + * CUBIC integrates a new slow start algorithm, called HyStart. + * The details of HyStart are presented in + * Sangtae Ha and Injong Rhee, + * "Taming the Elephants: New TCP Slow Start", NCSU TechReport 2008. + * Available from: + * http://netsrv.csc.ncsu.edu/export/hystart_techreport_2008.pdf + * + * All testing results are available from: + * http://netsrv.csc.ncsu.edu/wiki/index.php/TCP_Testing + * + * Unless CUBIC is enabled and congestion window is large + * this behaves the same as the original Reno. 
+ */ + +#include +#include +#include +#include +#include +#include + +#define BICTCP_BETA_SCALE 1024 /* Scale factor beta calculation + * max_cwnd = snd_cwnd * beta + */ +#define BICTCP_HZ 10 /* BIC HZ 2^10 = 1024 */ + +/* Two methods of hybrid slow start */ +#define HYSTART_ACK_TRAIN 0x1 +#define HYSTART_DELAY 0x2 + +/* Number of delay samples for detecting the increase of delay */ +#define HYSTART_MIN_SAMPLES 8 +#define HYSTART_DELAY_MIN (4000U) /* 4 ms */ +#define HYSTART_DELAY_MAX (16000U) /* 16 ms */ +#define HYSTART_DELAY_THRESH(x) clamp(x, HYSTART_DELAY_MIN, HYSTART_DELAY_MAX) + +static int fast_convergence __read_mostly = 1; +static int beta __read_mostly = 717; /* = 717/1024 (BICTCP_BETA_SCALE) */ +static int initial_ssthresh __read_mostly; +static int bic_scale __read_mostly = 41; +static int tcp_friendliness __read_mostly = 1; + +static int hystart __read_mostly = 1; +static int hystart_detect __read_mostly = HYSTART_ACK_TRAIN | HYSTART_DELAY; +static int hystart_low_window __read_mostly = 16; +static int hystart_ack_delta_us __read_mostly = 2000; + +static u32 cube_rtt_scale __read_mostly; +static u32 beta_scale __read_mostly; +static u64 cube_factor __read_mostly; + +/* Note parameters that are used for precomputing scale factors are read-only */ +module_param(fast_convergence, int, 0644); +MODULE_PARM_DESC(fast_convergence, "turn on/off fast convergence"); +module_param(beta, int, 0644); +MODULE_PARM_DESC(beta, "beta for multiplicative increase"); +module_param(initial_ssthresh, int, 0644); +MODULE_PARM_DESC(initial_ssthresh, "initial value of slow start threshold"); +module_param(bic_scale, int, 0444); +MODULE_PARM_DESC(bic_scale, "scale (scaled by 1024) value for bic function (bic_scale/1024)"); +module_param(tcp_friendliness, int, 0644); +MODULE_PARM_DESC(tcp_friendliness, "turn on/off tcp friendliness"); +module_param(hystart, int, 0644); +MODULE_PARM_DESC(hystart, "turn on/off hybrid slow start algorithm"); +module_param(hystart_detect, int, 0644); +MODULE_PARM_DESC(hystart_detect, "hybrid slow start detection mechanisms" + " 1: packet-train 2: delay 3: both packet-train and delay"); +module_param(hystart_low_window, int, 0644); +MODULE_PARM_DESC(hystart_low_window, "lower bound cwnd for hybrid slow start"); +module_param(hystart_ack_delta_us, int, 0644); +MODULE_PARM_DESC(hystart_ack_delta_us, "spacing between ack's indicating train (usecs)"); + +/* BIC TCP Parameters */ +struct bictcp { + u32 cnt; /* increase cwnd by 1 after ACKs */ + u32 last_max_cwnd; /* last maximum snd_cwnd */ + u32 last_cwnd; /* the last snd_cwnd */ + u32 last_time; /* time when updated last_cwnd */ + u32 bic_origin_point;/* origin point of bic function */ + u32 bic_K; /* time to origin point + from the beginning of the current epoch */ + u32 delay_min; /* min delay (usec) */ + u32 epoch_start; /* beginning of an epoch */ + u32 ack_cnt; /* number of acks */ + u32 tcp_cwnd; /* estimated tcp cwnd */ + u16 unused; + u8 sample_cnt; /* number of samples to decide curr_rtt */ + u8 found; /* the exit point is found? 
*/ + u32 round_start; /* beginning of each round */ + u32 end_seq; /* end_seq of the round */ + u32 last_ack; /* last time when the ACK spacing is close */ + u32 curr_rtt; /* the minimum rtt of current round */ +}; + +static inline void bictcp_reset(struct bictcp *ca) +{ + memset(ca, 0, offsetof(struct bictcp, unused)); + ca->found = 0; +} + +static inline u32 bictcp_clock_us(const struct sock *sk) +{ + return tcp_sk(sk)->tcp_mstamp; +} + +static inline void bictcp_hystart_reset(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + + ca->round_start = ca->last_ack = bictcp_clock_us(sk); + ca->end_seq = tp->snd_nxt; + ca->curr_rtt = ~0U; + ca->sample_cnt = 0; +} + +static void cubictcp_init(struct sock *sk) +{ + struct bictcp *ca = inet_csk_ca(sk); + + bictcp_reset(ca); + + if (hystart) + bictcp_hystart_reset(sk); + + if (!hystart && initial_ssthresh) + tcp_sk(sk)->snd_ssthresh = initial_ssthresh; +} + +static void cubictcp_cwnd_event(struct sock *sk, enum tcp_ca_event event) +{ + if (event == CA_EVENT_TX_START) { + struct bictcp *ca = inet_csk_ca(sk); + u32 now = tcp_jiffies32; + s32 delta; + + delta = now - tcp_sk(sk)->lsndtime; + + /* We were application limited (idle) for a while. + * Shift epoch_start to keep cwnd growth to cubic curve. + */ + if (ca->epoch_start && delta > 0) { + ca->epoch_start += delta; + if (after(ca->epoch_start, now)) + ca->epoch_start = now; + } + return; + } +} + +/* calculate the cubic root of x using a table lookup followed by one + * Newton-Raphson iteration. + * Avg err ~= 0.195% + */ +static u32 cubic_root(u64 a) +{ + u32 x, b, shift; + /* + * cbrt(x) MSB values for x MSB values in [0..63]. + * Precomputed then refined by hand - Willy Tarreau + * + * For x in [0..63], + * v = cbrt(x << 18) - 1 + * cbrt(x) = (v[x] + 10) >> 6 + */ + static const u8 v[] = { + /* 0x00 */ 0, 54, 54, 54, 118, 118, 118, 118, + /* 0x08 */ 123, 129, 134, 138, 143, 147, 151, 156, + /* 0x10 */ 157, 161, 164, 168, 170, 173, 176, 179, + /* 0x18 */ 181, 185, 187, 190, 192, 194, 197, 199, + /* 0x20 */ 200, 202, 204, 206, 209, 211, 213, 215, + /* 0x28 */ 217, 219, 221, 222, 224, 225, 227, 229, + /* 0x30 */ 231, 232, 234, 236, 237, 239, 240, 242, + /* 0x38 */ 244, 245, 246, 248, 250, 251, 252, 254, + }; + + b = fls64(a); + if (b < 7) { + /* a in [0..63] */ + return ((u32)v[(u32)a] + 35) >> 6; + } + + b = ((b * 84) >> 8) - 1; + shift = (a >> (b * 3)); + + x = ((u32)(((u32)v[shift] + 10) << b)) >> 6; + + /* + * Newton-Raphson iteration + * 2 + * x = ( 2 * x + a / x ) / 3 + * k+1 k k + */ + x = (2 * x + (u32)div64_u64(a, (u64)x * (u64)(x - 1))); + x = ((x * 341) >> 10); + return x; +} + +/* + * Compute congestion window to use. + */ +static inline void bictcp_update(struct bictcp *ca, u32 cwnd, u32 acked) +{ + u32 delta, bic_target, max_cnt; + u64 offs, t; + + ca->ack_cnt += acked; /* count the number of ACKed packets */ + + if (ca->last_cwnd == cwnd && + (s32)(tcp_jiffies32 - ca->last_time) <= HZ / 32) + return; + + /* The CUBIC function can update ca->cnt at most once per jiffy. + * On all cwnd reduction events, ca->epoch_start is set to 0, + * which will force a recalculation of ca->cnt. 
+ */ + if (ca->epoch_start && tcp_jiffies32 == ca->last_time) + goto tcp_friendliness; + + ca->last_cwnd = cwnd; + ca->last_time = tcp_jiffies32; + + if (ca->epoch_start == 0) { + ca->epoch_start = tcp_jiffies32; /* record beginning */ + ca->ack_cnt = acked; /* start counting */ + ca->tcp_cwnd = cwnd; /* syn with cubic */ + + if (ca->last_max_cwnd <= cwnd) { + ca->bic_K = 0; + ca->bic_origin_point = cwnd; + } else { + /* Compute new K based on + * (wmax-cwnd) * (srtt>>3 / HZ) / c * 2^(3*bictcp_HZ) + */ + ca->bic_K = cubic_root(cube_factor + * (ca->last_max_cwnd - cwnd)); + ca->bic_origin_point = ca->last_max_cwnd; + } + } + + /* cubic function - calc*/ + /* calculate c * time^3 / rtt, + * while considering overflow in calculation of time^3 + * (so time^3 is done by using 64 bit) + * and without the support of division of 64bit numbers + * (so all divisions are done by using 32 bit) + * also NOTE the unit of those veriables + * time = (t - K) / 2^bictcp_HZ + * c = bic_scale >> 10 + * rtt = (srtt >> 3) / HZ + * !!! The following code does not have overflow problems, + * if the cwnd < 1 million packets !!! + */ + + t = (s32)(tcp_jiffies32 - ca->epoch_start); + t += usecs_to_jiffies(ca->delay_min); + /* change the unit from HZ to bictcp_HZ */ + t <<= BICTCP_HZ; + do_div(t, HZ); + + if (t < ca->bic_K) /* t - K */ + offs = ca->bic_K - t; + else + offs = t - ca->bic_K; + + /* c/rtt * (t-K)^3 */ + delta = (cube_rtt_scale * offs * offs * offs) >> (10+3*BICTCP_HZ); + if (t < ca->bic_K) /* below origin*/ + bic_target = ca->bic_origin_point - delta; + else /* above origin*/ + bic_target = ca->bic_origin_point + delta; + + /* cubic function - calc bictcp_cnt*/ + if (bic_target > cwnd) { + ca->cnt = cwnd / (bic_target - cwnd); + } else { + ca->cnt = 100 * cwnd; /* very small increment*/ + } + + /* + * The initial growth of cubic function may be too conservative + * when the available bandwidth is still unknown. + */ + if (ca->last_max_cwnd == 0 && ca->cnt > 20) + ca->cnt = 20; /* increase cwnd 5% per RTT */ + +tcp_friendliness: + /* TCP Friendly */ + if (tcp_friendliness) { + u32 scale = beta_scale; + + delta = (cwnd * scale) >> 3; + while (ca->ack_cnt > delta) { /* update tcp cwnd */ + ca->ack_cnt -= delta; + ca->tcp_cwnd++; + } + + if (ca->tcp_cwnd > cwnd) { /* if bic is slower than tcp */ + delta = ca->tcp_cwnd - cwnd; + max_cnt = cwnd / delta; + if (ca->cnt > max_cnt) + ca->cnt = max_cnt; + } + } + + /* The maximum rate of cwnd increase CUBIC allows is 1 packet per + * 2 packets ACKed, meaning cwnd grows at 1.5x per RTT. 
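In floating point, the curve that this fixed-point code tracks is W(t) = C*(t - K)^3 + Wmax with K = cbrt(Wmax*(1 - beta)/C); with the default bic_scale of 41 and the 100 ms SRTT baked into the precomputed factors, C works out to roughly 0.4. A userspace sketch (illustration only, using libm):

#include <math.h>
#include <stdio.h>

#define CUBIC_C    0.4	/* ~ (bic_scale / 1024) / 0.1 s with bic_scale = 41 */
#define CUBIC_BETA 0.7	/* ~ 717 / 1024 */

int main(void)
{
	double wmax = 100.0;	/* cwnd at the last reduction (Wmax), packets */
	double K = cbrt(wmax * (1.0 - CUBIC_BETA) / CUBIC_C);
	double t;

	printf("K = %.2f s (time to climb back to Wmax)\n", K);
	for (t = 0.0; t <= 2.0 * K; t += K / 4.0) {
		double w = CUBIC_C * pow(t - K, 3.0) + wmax;

		printf("t=%5.2f s  target cwnd=%6.1f\n", t, w);
	}
	return 0;
}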
+ */ + ca->cnt = max(ca->cnt, 2U); +} + +static void cubictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + return; + } + bictcp_update(ca, tcp_snd_cwnd(tp), acked); + tcp_cong_avoid_ai(tp, ca->cnt, acked); +} + +static u32 cubictcp_recalc_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + + ca->epoch_start = 0; /* end of epoch */ + + /* Wmax and fast convergence */ + if (tcp_snd_cwnd(tp) < ca->last_max_cwnd && fast_convergence) + ca->last_max_cwnd = (tcp_snd_cwnd(tp) * (BICTCP_BETA_SCALE + beta)) + / (2 * BICTCP_BETA_SCALE); + else + ca->last_max_cwnd = tcp_snd_cwnd(tp); + + return max((tcp_snd_cwnd(tp) * beta) / BICTCP_BETA_SCALE, 2U); +} + +static void cubictcp_state(struct sock *sk, u8 new_state) +{ + if (new_state == TCP_CA_Loss) { + bictcp_reset(inet_csk_ca(sk)); + bictcp_hystart_reset(sk); + } +} + +/* Account for TSO/GRO delays. + * Otherwise short RTT flows could get too small ssthresh, since during + * slow start we begin with small TSO packets and ca->delay_min would + * not account for long aggregation delay when TSO packets get bigger. + * Ideally even with a very small RTT we would like to have at least one + * TSO packet being sent and received by GRO, and another one in qdisc layer. + * We apply another 100% factor because @rate is doubled at this point. + * We cap the cushion to 1ms. + */ +static u32 hystart_ack_delay(const struct sock *sk) +{ + unsigned long rate; + + rate = READ_ONCE(sk->sk_pacing_rate); + if (!rate) + return 0; + return min_t(u64, USEC_PER_MSEC, + div64_ul((u64)sk->sk_gso_max_size * 4 * USEC_PER_SEC, rate)); +} + +static void hystart_update(struct sock *sk, u32 delay) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + u32 threshold; + + if (after(tp->snd_una, ca->end_seq)) + bictcp_hystart_reset(sk); + + if (hystart_detect & HYSTART_ACK_TRAIN) { + u32 now = bictcp_clock_us(sk); + + /* first detection parameter - ack-train detection */ + if ((s32)(now - ca->last_ack) <= hystart_ack_delta_us) { + ca->last_ack = now; + + threshold = ca->delay_min + hystart_ack_delay(sk); + + /* Hystart ack train triggers if we get ack past + * ca->delay_min/2. + * Pacing might have delayed packets up to RTT/2 + * during slow start. 
+ */ + if (sk->sk_pacing_status == SK_PACING_NONE) + threshold >>= 1; + + if ((s32)(now - ca->round_start) > threshold) { + ca->found = 1; + pr_debug("hystart_ack_train (%u > %u) delay_min %u (+ ack_delay %u) cwnd %u\n", + now - ca->round_start, threshold, + ca->delay_min, hystart_ack_delay(sk), tcp_snd_cwnd(tp)); + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTTRAINDETECT); + NET_ADD_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTTRAINCWND, + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); + } + } + } + + if (hystart_detect & HYSTART_DELAY) { + /* obtain the minimum delay of more than sampling packets */ + if (ca->curr_rtt > delay) + ca->curr_rtt = delay; + if (ca->sample_cnt < HYSTART_MIN_SAMPLES) { + ca->sample_cnt++; + } else { + if (ca->curr_rtt > ca->delay_min + + HYSTART_DELAY_THRESH(ca->delay_min >> 3)) { + ca->found = 1; + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTDELAYDETECT); + NET_ADD_STATS(sock_net(sk), + LINUX_MIB_TCPHYSTARTDELAYCWND, + tcp_snd_cwnd(tp)); + tp->snd_ssthresh = tcp_snd_cwnd(tp); + } + } + } +} + +static void cubictcp_acked(struct sock *sk, const struct ack_sample *sample) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct bictcp *ca = inet_csk_ca(sk); + u32 delay; + + /* Some calls are for duplicates without timetamps */ + if (sample->rtt_us < 0) + return; + + /* Discard delay samples right after fast recovery */ + if (ca->epoch_start && (s32)(tcp_jiffies32 - ca->epoch_start) < HZ) + return; + + delay = sample->rtt_us; + if (delay == 0) + delay = 1; + + /* first time call or link delay decreases */ + if (ca->delay_min == 0 || ca->delay_min > delay) + ca->delay_min = delay; + + /* hystart triggers when cwnd is larger than some threshold */ + if (!ca->found && tcp_in_slow_start(tp) && hystart && + tcp_snd_cwnd(tp) >= hystart_low_window) + hystart_update(sk, delay); +} + +static struct tcp_congestion_ops cubictcp __read_mostly = { + .init = cubictcp_init, + .ssthresh = cubictcp_recalc_ssthresh, + .cong_avoid = cubictcp_cong_avoid, + .set_state = cubictcp_state, + .undo_cwnd = tcp_reno_undo_cwnd, + .cwnd_event = cubictcp_cwnd_event, + .pkts_acked = cubictcp_acked, + .owner = THIS_MODULE, + .name = "cubic", +}; + +BTF_SET8_START(tcp_cubic_check_kfunc_ids) +#ifdef CONFIG_X86 +#ifdef CONFIG_DYNAMIC_FTRACE +BTF_ID_FLAGS(func, cubictcp_init) +BTF_ID_FLAGS(func, cubictcp_recalc_ssthresh) +BTF_ID_FLAGS(func, cubictcp_cong_avoid) +BTF_ID_FLAGS(func, cubictcp_state) +BTF_ID_FLAGS(func, cubictcp_cwnd_event) +BTF_ID_FLAGS(func, cubictcp_acked) +#endif +#endif +BTF_SET8_END(tcp_cubic_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_cubic_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_cubic_check_kfunc_ids, +}; + +static int __init cubictcp_register(void) +{ + int ret; + + BUILD_BUG_ON(sizeof(struct bictcp) > ICSK_CA_PRIV_SIZE); + + /* Precompute a bunch of the scaling factors that are used per-packet + * based on SRTT of 100ms + */ + + beta_scale = 8*(BICTCP_BETA_SCALE+beta) / 3 + / (BICTCP_BETA_SCALE - beta); + + cube_rtt_scale = (bic_scale * 10); /* 1024*c/rtt */ + + /* calculate the "K" for (wmax-cwnd) = c/rtt * K^3 + * so K = cubic_root( (wmax-cwnd)*rtt/c ) + * the unit of K is bictcp_HZ=2^10, not HZ + * + * c = bic_scale >> 10 + * rtt = 100ms + * + * the following code has been designed and tested for + * cwnd < 1 million packets + * RTT < 100 seconds + * HZ < 1,000,00 (corresponding to 10 nano-second) + */ + + /* 1/c * 2^2*bictcp_HZ * srtt */ + cube_factor = 1ull << (10+3*BICTCP_HZ); /* 2^40 */ + + /* divide by bic_scale and by 
constant Srtt (100ms) */ + do_div(cube_factor, bic_scale * 10); + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_cubic_kfunc_set); + if (ret < 0) + return ret; + return tcp_register_congestion_control(&cubictcp); +} + +static void __exit cubictcp_unregister(void) +{ + tcp_unregister_congestion_control(&cubictcp); +} + +module_init(cubictcp_register); +module_exit(cubictcp_unregister); + +MODULE_AUTHOR("Sangtae Ha, Stephen Hemminger"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("CUBIC TCP"); +MODULE_VERSION("2.3"); diff --git a/net/ipv4/tcp_dctcp.c b/net/ipv4/tcp_dctcp.c new file mode 100644 index 000000000..2a6c0dd66 --- /dev/null +++ b/net/ipv4/tcp_dctcp.c @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* DataCenter TCP (DCTCP) congestion control. + * + * http://simula.stanford.edu/~alizade/Site/DCTCP.html + * + * This is an implementation of DCTCP over Reno, an enhancement to the + * TCP congestion control algorithm designed for data centers. DCTCP + * leverages Explicit Congestion Notification (ECN) in the network to + * provide multi-bit feedback to the end hosts. DCTCP's goal is to meet + * the following three data center transport requirements: + * + * - High burst tolerance (incast due to partition/aggregate) + * - Low latency (short flows, queries) + * - High throughput (continuous data updates, large file transfers) + * with commodity shallow buffered switches + * + * The algorithm is described in detail in the following two papers: + * + * 1) Mohammad Alizadeh, Albert Greenberg, David A. Maltz, Jitendra Padhye, + * Parveen Patel, Balaji Prabhakar, Sudipta Sengupta, and Murari Sridharan: + * "Data Center TCP (DCTCP)", Data Center Networks session + * Proc. ACM SIGCOMM, New Delhi, 2010. + * http://simula.stanford.edu/~alizade/Site/DCTCP_files/dctcp-final.pdf + * + * 2) Mohammad Alizadeh, Adel Javanmard, and Balaji Prabhakar: + * "Analysis of DCTCP: Stability, Convergence, and Fairness" + * Proc. ACM SIGMETRICS, San Jose, 2011. + * http://simula.stanford.edu/~alizade/Site/DCTCP_files/dctcp_analysis-full.pdf + * + * Initial prototype from Abdul Kabbani, Masato Yasuda and Mohammad Alizadeh. 
+ * + * Authors: + * + * Daniel Borkmann + * Florian Westphal + * Glenn Judd + */ + +#include +#include +#include +#include +#include +#include +#include "tcp_dctcp.h" + +#define DCTCP_MAX_ALPHA 1024U + +struct dctcp { + u32 old_delivered; + u32 old_delivered_ce; + u32 prior_rcv_nxt; + u32 dctcp_alpha; + u32 next_seq; + u32 ce_state; + u32 loss_cwnd; +}; + +static unsigned int dctcp_shift_g __read_mostly = 4; /* g = 1/2^4 */ +module_param(dctcp_shift_g, uint, 0644); +MODULE_PARM_DESC(dctcp_shift_g, "parameter g for updating dctcp_alpha"); + +static unsigned int dctcp_alpha_on_init __read_mostly = DCTCP_MAX_ALPHA; +module_param(dctcp_alpha_on_init, uint, 0644); +MODULE_PARM_DESC(dctcp_alpha_on_init, "parameter for initial alpha value"); + +static struct tcp_congestion_ops dctcp_reno; + +static void dctcp_reset(const struct tcp_sock *tp, struct dctcp *ca) +{ + ca->next_seq = tp->snd_nxt; + + ca->old_delivered = tp->delivered; + ca->old_delivered_ce = tp->delivered_ce; +} + +static void dctcp_init(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + if ((tp->ecn_flags & TCP_ECN_OK) || + (sk->sk_state == TCP_LISTEN || + sk->sk_state == TCP_CLOSE)) { + struct dctcp *ca = inet_csk_ca(sk); + + ca->prior_rcv_nxt = tp->rcv_nxt; + + ca->dctcp_alpha = min(dctcp_alpha_on_init, DCTCP_MAX_ALPHA); + + ca->loss_cwnd = 0; + ca->ce_state = 0; + + dctcp_reset(tp, ca); + return; + } + + /* No ECN support? Fall back to Reno. Also need to clear + * ECT from sk since it is set during 3WHS for DCTCP. + */ + inet_csk(sk)->icsk_ca_ops = &dctcp_reno; + INET_ECN_dontxmit(sk); +} + +static u32 dctcp_ssthresh(struct sock *sk) +{ + struct dctcp *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + ca->loss_cwnd = tcp_snd_cwnd(tp); + return max(tcp_snd_cwnd(tp) - ((tcp_snd_cwnd(tp) * ca->dctcp_alpha) >> 11U), 2U); +} + +static void dctcp_update_alpha(struct sock *sk, u32 flags) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct dctcp *ca = inet_csk_ca(sk); + + /* Expired RTT */ + if (!before(tp->snd_una, ca->next_seq)) { + u32 delivered_ce = tp->delivered_ce - ca->old_delivered_ce; + u32 alpha = ca->dctcp_alpha; + + /* alpha = (1 - g) * alpha + g * F */ + + alpha -= min_not_zero(alpha, alpha >> dctcp_shift_g); + if (delivered_ce) { + u32 delivered = tp->delivered - ca->old_delivered; + + /* If dctcp_shift_g == 1, a 32bit value would overflow + * after 8 M packets. + */ + delivered_ce <<= (10 - dctcp_shift_g); + delivered_ce /= max(1U, delivered); + + alpha = min(alpha + delivered_ce, DCTCP_MAX_ALPHA); + } + /* dctcp_alpha can be read from dctcp_get_info() without + * synchro, so we ask compiler to not use dctcp_alpha + * as a temporary variable in prior operations. + */ + WRITE_ONCE(ca->dctcp_alpha, alpha); + dctcp_reset(tp, ca); + } +} + +static void dctcp_react_to_loss(struct sock *sk) +{ + struct dctcp *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + ca->loss_cwnd = tcp_snd_cwnd(tp); + tp->snd_ssthresh = max(tcp_snd_cwnd(tp) >> 1U, 2U); +} + +static void dctcp_state(struct sock *sk, u8 new_state) +{ + if (new_state == TCP_CA_Recovery && + new_state != inet_csk(sk)->icsk_ca_state) + dctcp_react_to_loss(sk); + /* We handle RTO in dctcp_cwnd_event to ensure that we perform only + * one loss-adjustment per RTT. 
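The fixed-point update in dctcp_update_alpha() above implements alpha = (1 - g)*alpha + g*F, with g = 1/2^dctcp_shift_g and F the fraction of CE-marked packets in the last observation window; dctcp_ssthresh() then cuts cwnd by the fraction alpha/2. A userspace sketch of the arithmetic (illustration only; the min_not_zero() detail for very small alpha values is omitted):

#include <stdio.h>

#define DCTCP_MAX_ALPHA	1024U
#define SHIFT_G		4U		/* g = 1/16, the module default */

static unsigned int alpha = DCTCP_MAX_ALPHA;

static void update_alpha(unsigned int delivered, unsigned int delivered_ce)
{
	/* alpha -= g * alpha ... */
	alpha -= alpha >> SHIFT_G;

	/* ... then alpha += g * F, scaled so that F == 1 adds 2^(10 - g). */
	if (delivered_ce) {
		unsigned int add = (delivered_ce << (10 - SHIFT_G)) / delivered;

		alpha += add;
		if (alpha > DCTCP_MAX_ALPHA)
			alpha = DCTCP_MAX_ALPHA;
	}
}

int main(void)
{
	unsigned int cwnd = 100, rtt;

	/* 25% of the packets in every window carry CE marks. */
	for (rtt = 0; rtt < 8; rtt++) {
		update_alpha(100, 25);
		printf("rtt %u: alpha=%4u  ssthresh=%u\n", rtt, alpha,
		       cwnd - ((cwnd * alpha) >> 11));
	}
	return 0;
}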
+ */ +} + +static void dctcp_cwnd_event(struct sock *sk, enum tcp_ca_event ev) +{ + struct dctcp *ca = inet_csk_ca(sk); + + switch (ev) { + case CA_EVENT_ECN_IS_CE: + case CA_EVENT_ECN_NO_CE: + dctcp_ece_ack_update(sk, ev, &ca->prior_rcv_nxt, &ca->ce_state); + break; + case CA_EVENT_LOSS: + dctcp_react_to_loss(sk); + break; + default: + /* Don't care for the rest. */ + break; + } +} + +static size_t dctcp_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + const struct dctcp *ca = inet_csk_ca(sk); + const struct tcp_sock *tp = tcp_sk(sk); + + /* Fill it also in case of VEGASINFO due to req struct limits. + * We can still correctly retrieve it later. + */ + if (ext & (1 << (INET_DIAG_DCTCPINFO - 1)) || + ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + memset(&info->dctcp, 0, sizeof(info->dctcp)); + if (inet_csk(sk)->icsk_ca_ops != &dctcp_reno) { + info->dctcp.dctcp_enabled = 1; + info->dctcp.dctcp_ce_state = (u16) ca->ce_state; + info->dctcp.dctcp_alpha = ca->dctcp_alpha; + info->dctcp.dctcp_ab_ecn = tp->mss_cache * + (tp->delivered_ce - ca->old_delivered_ce); + info->dctcp.dctcp_ab_tot = tp->mss_cache * + (tp->delivered - ca->old_delivered); + } + + *attr = INET_DIAG_DCTCPINFO; + return sizeof(info->dctcp); + } + return 0; +} + +static u32 dctcp_cwnd_undo(struct sock *sk) +{ + const struct dctcp *ca = inet_csk_ca(sk); + struct tcp_sock *tp = tcp_sk(sk); + + return max(tcp_snd_cwnd(tp), ca->loss_cwnd); +} + +static struct tcp_congestion_ops dctcp __read_mostly = { + .init = dctcp_init, + .in_ack_event = dctcp_update_alpha, + .cwnd_event = dctcp_cwnd_event, + .ssthresh = dctcp_ssthresh, + .cong_avoid = tcp_reno_cong_avoid, + .undo_cwnd = dctcp_cwnd_undo, + .set_state = dctcp_state, + .get_info = dctcp_get_info, + .flags = TCP_CONG_NEEDS_ECN, + .owner = THIS_MODULE, + .name = "dctcp", +}; + +static struct tcp_congestion_ops dctcp_reno __read_mostly = { + .ssthresh = tcp_reno_ssthresh, + .cong_avoid = tcp_reno_cong_avoid, + .undo_cwnd = tcp_reno_undo_cwnd, + .get_info = dctcp_get_info, + .owner = THIS_MODULE, + .name = "dctcp-reno", +}; + +BTF_SET8_START(tcp_dctcp_check_kfunc_ids) +#ifdef CONFIG_X86 +#ifdef CONFIG_DYNAMIC_FTRACE +BTF_ID_FLAGS(func, dctcp_init) +BTF_ID_FLAGS(func, dctcp_update_alpha) +BTF_ID_FLAGS(func, dctcp_cwnd_event) +BTF_ID_FLAGS(func, dctcp_ssthresh) +BTF_ID_FLAGS(func, dctcp_cwnd_undo) +BTF_ID_FLAGS(func, dctcp_state) +#endif +#endif +BTF_SET8_END(tcp_dctcp_check_kfunc_ids) + +static const struct btf_kfunc_id_set tcp_dctcp_kfunc_set = { + .owner = THIS_MODULE, + .set = &tcp_dctcp_check_kfunc_ids, +}; + +static int __init dctcp_register(void) +{ + int ret; + + BUILD_BUG_ON(sizeof(struct dctcp) > ICSK_CA_PRIV_SIZE); + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_STRUCT_OPS, &tcp_dctcp_kfunc_set); + if (ret < 0) + return ret; + return tcp_register_congestion_control(&dctcp); +} + +static void __exit dctcp_unregister(void) +{ + tcp_unregister_congestion_control(&dctcp); +} + +module_init(dctcp_register); +module_exit(dctcp_unregister); + +MODULE_AUTHOR("Daniel Borkmann "); +MODULE_AUTHOR("Florian Westphal "); +MODULE_AUTHOR("Glenn Judd "); + +MODULE_LICENSE("GPL v2"); +MODULE_DESCRIPTION("DataCenter TCP (DCTCP)"); diff --git a/net/ipv4/tcp_dctcp.h b/net/ipv4/tcp_dctcp.h new file mode 100644 index 000000000..d69a77cbd --- /dev/null +++ b/net/ipv4/tcp_dctcp.h @@ -0,0 +1,40 @@ +#ifndef _TCP_DCTCP_H +#define _TCP_DCTCP_H + +static inline void dctcp_ece_ack_cwr(struct sock *sk, u32 ce_state) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (ce_state == 1) + 
tp->ecn_flags |= TCP_ECN_DEMAND_CWR; + else + tp->ecn_flags &= ~TCP_ECN_DEMAND_CWR; +} + +/* Minimal DCTP CE state machine: + * + * S: 0 <- last pkt was non-CE + * 1 <- last pkt was CE + */ +static inline void dctcp_ece_ack_update(struct sock *sk, enum tcp_ca_event evt, + u32 *prior_rcv_nxt, u32 *ce_state) +{ + u32 new_ce_state = (evt == CA_EVENT_ECN_IS_CE) ? 1 : 0; + + if (*ce_state != new_ce_state) { + /* CE state has changed, force an immediate ACK to + * reflect the new CE state. If an ACK was delayed, + * send that first to reflect the prior CE state. + */ + if (inet_csk(sk)->icsk_ack.pending & ICSK_ACK_TIMER) { + dctcp_ece_ack_cwr(sk, *ce_state); + __tcp_send_ack(sk, *prior_rcv_nxt); + } + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; + } + *prior_rcv_nxt = tcp_sk(sk)->rcv_nxt; + *ce_state = new_ce_state; + dctcp_ece_ack_cwr(sk, new_ce_state); +} + +#endif diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c new file mode 100644 index 000000000..01b50fa79 --- /dev/null +++ b/net/ipv4/tcp_diag.c @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * tcp_diag.c Module for monitoring TCP transport protocols sockets. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include + +#include + +#include +#include + +static void tcp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *_info) +{ + struct tcp_info *info = _info; + + if (inet_sk_state_load(sk) == TCP_LISTEN) { + r->idiag_rqueue = READ_ONCE(sk->sk_ack_backlog); + r->idiag_wqueue = READ_ONCE(sk->sk_max_ack_backlog); + } else if (sk->sk_type == SOCK_STREAM) { + const struct tcp_sock *tp = tcp_sk(sk); + + r->idiag_rqueue = max_t(int, READ_ONCE(tp->rcv_nxt) - + READ_ONCE(tp->copied_seq), 0); + r->idiag_wqueue = READ_ONCE(tp->write_seq) - tp->snd_una; + } + if (info) + tcp_get_info(sk, info); +} + +#ifdef CONFIG_TCP_MD5SIG +static void tcp_diag_md5sig_fill(struct tcp_diag_md5sig *info, + const struct tcp_md5sig_key *key) +{ + info->tcpm_family = key->family; + info->tcpm_prefixlen = key->prefixlen; + info->tcpm_keylen = key->keylen; + memcpy(info->tcpm_key, key->key, key->keylen); + + if (key->family == AF_INET) + info->tcpm_addr[0] = key->addr.a4.s_addr; + #if IS_ENABLED(CONFIG_IPV6) + else if (key->family == AF_INET6) + memcpy(&info->tcpm_addr, &key->addr.a6, + sizeof(info->tcpm_addr)); + #endif +} + +static int tcp_diag_put_md5sig(struct sk_buff *skb, + const struct tcp_md5sig_info *md5sig) +{ + const struct tcp_md5sig_key *key; + struct tcp_diag_md5sig *info; + struct nlattr *attr; + int md5sig_count = 0; + + hlist_for_each_entry_rcu(key, &md5sig->head, node) + md5sig_count++; + if (md5sig_count == 0) + return 0; + + attr = nla_reserve(skb, INET_DIAG_MD5SIG, + md5sig_count * sizeof(struct tcp_diag_md5sig)); + if (!attr) + return -EMSGSIZE; + + info = nla_data(attr); + memset(info, 0, md5sig_count * sizeof(struct tcp_diag_md5sig)); + hlist_for_each_entry_rcu(key, &md5sig->head, node) { + tcp_diag_md5sig_fill(info++, key); + if (--md5sig_count == 0) + break; + } + + return 0; +} +#endif + +static int tcp_diag_put_ulp(struct sk_buff *skb, struct sock *sk, + const struct tcp_ulp_ops *ulp_ops) +{ + struct nlattr *nest; + int err; + + nest = nla_nest_start_noflag(skb, INET_DIAG_ULP_INFO); + if (!nest) + return -EMSGSIZE; + + err = nla_put_string(skb, INET_ULP_INFO_NAME, ulp_ops->name); + if (err) + goto nla_failure; + + if (ulp_ops->get_info) + err = ulp_ops->get_info(sk, skb); + if (err) + goto nla_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_failure: + 
nla_nest_cancel(skb, nest); + return err; +} + +static int tcp_diag_get_aux(struct sock *sk, bool net_admin, + struct sk_buff *skb) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + int err = 0; + +#ifdef CONFIG_TCP_MD5SIG + if (net_admin) { + struct tcp_md5sig_info *md5sig; + + rcu_read_lock(); + md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info); + if (md5sig) + err = tcp_diag_put_md5sig(skb, md5sig); + rcu_read_unlock(); + if (err < 0) + return err; + } +#endif + + if (net_admin) { + const struct tcp_ulp_ops *ulp_ops; + + ulp_ops = icsk->icsk_ulp_ops; + if (ulp_ops) + err = tcp_diag_put_ulp(skb, sk, ulp_ops); + if (err) + return err; + } + return 0; +} + +static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + size_t size = 0; + +#ifdef CONFIG_TCP_MD5SIG + if (net_admin && sk_fullsock(sk)) { + const struct tcp_md5sig_info *md5sig; + const struct tcp_md5sig_key *key; + size_t md5sig_count = 0; + + rcu_read_lock(); + md5sig = rcu_dereference(tcp_sk(sk)->md5sig_info); + if (md5sig) { + hlist_for_each_entry_rcu(key, &md5sig->head, node) + md5sig_count++; + } + rcu_read_unlock(); + size += nla_total_size(md5sig_count * + sizeof(struct tcp_diag_md5sig)); + } +#endif + + if (net_admin && sk_fullsock(sk)) { + const struct tcp_ulp_ops *ulp_ops; + + ulp_ops = icsk->icsk_ulp_ops; + if (ulp_ops) { + size += nla_total_size(0) + + nla_total_size(TCP_ULP_NAME_MAX); + if (ulp_ops->get_info_size) + size += ulp_ops->get_info_size(sk); + } + } + return size; +} + +static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + struct inet_hashinfo *hinfo; + + hinfo = sock_net(cb->skb->sk)->ipv4.tcp_death_row.hashinfo; + + inet_diag_dump_icsk(hinfo, skb, cb, r); +} + +static int tcp_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + struct inet_hashinfo *hinfo; + + hinfo = sock_net(cb->skb->sk)->ipv4.tcp_death_row.hashinfo; + + return inet_diag_dump_one_icsk(hinfo, cb, req); +} + +#ifdef CONFIG_INET_DIAG_DESTROY +static int tcp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + struct net *net = sock_net(in_skb->sk); + struct inet_hashinfo *hinfo; + struct sock *sk; + int err; + + hinfo = net->ipv4.tcp_death_row.hashinfo; + sk = inet_diag_find_one_icsk(net, hinfo, req); + + if (IS_ERR(sk)) + return PTR_ERR(sk); + + err = sock_diag_destroy(sk, ECONNABORTED); + + sock_gen_put(sk); + + return err; +} +#endif + +static const struct inet_diag_handler tcp_diag_handler = { + .dump = tcp_diag_dump, + .dump_one = tcp_diag_dump_one, + .idiag_get_info = tcp_diag_get_info, + .idiag_get_aux = tcp_diag_get_aux, + .idiag_get_aux_size = tcp_diag_get_aux_size, + .idiag_type = IPPROTO_TCP, + .idiag_info_size = sizeof(struct tcp_info), +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = tcp_diag_destroy, +#endif +}; + +static int __init tcp_diag_init(void) +{ + return inet_diag_register(&tcp_diag_handler); +} + +static void __exit tcp_diag_exit(void) +{ + inet_diag_unregister(&tcp_diag_handler); +} + +module_init(tcp_diag_init); +module_exit(tcp_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-6 /* AF_INET - IPPROTO_TCP */); diff --git a/net/ipv4/tcp_fastopen.c b/net/ipv4/tcp_fastopen.c new file mode 100644 index 000000000..85e4953f1 --- /dev/null +++ b/net/ipv4/tcp_fastopen.c @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +void 
tcp_fastopen_init_key_once(struct net *net) +{ + u8 key[TCP_FASTOPEN_KEY_LENGTH]; + struct tcp_fastopen_context *ctxt; + + rcu_read_lock(); + ctxt = rcu_dereference(net->ipv4.tcp_fastopen_ctx); + if (ctxt) { + rcu_read_unlock(); + return; + } + rcu_read_unlock(); + + /* tcp_fastopen_reset_cipher publishes the new context + * atomically, so we allow this race happening here. + * + * All call sites of tcp_fastopen_cookie_gen also check + * for a valid cookie, so this is an acceptable risk. + */ + get_random_bytes(key, sizeof(key)); + tcp_fastopen_reset_cipher(net, NULL, key, NULL); +} + +static void tcp_fastopen_ctx_free(struct rcu_head *head) +{ + struct tcp_fastopen_context *ctx = + container_of(head, struct tcp_fastopen_context, rcu); + + kfree_sensitive(ctx); +} + +void tcp_fastopen_destroy_cipher(struct sock *sk) +{ + struct tcp_fastopen_context *ctx; + + ctx = rcu_dereference_protected( + inet_csk(sk)->icsk_accept_queue.fastopenq.ctx, 1); + if (ctx) + call_rcu(&ctx->rcu, tcp_fastopen_ctx_free); +} + +void tcp_fastopen_ctx_destroy(struct net *net) +{ + struct tcp_fastopen_context *ctxt; + + ctxt = xchg((__force struct tcp_fastopen_context **)&net->ipv4.tcp_fastopen_ctx, NULL); + + if (ctxt) + call_rcu(&ctxt->rcu, tcp_fastopen_ctx_free); +} + +int tcp_fastopen_reset_cipher(struct net *net, struct sock *sk, + void *primary_key, void *backup_key) +{ + struct tcp_fastopen_context *ctx, *octx; + struct fastopen_queue *q; + int err = 0; + + ctx = kmalloc(sizeof(*ctx), GFP_KERNEL); + if (!ctx) { + err = -ENOMEM; + goto out; + } + + ctx->key[0].key[0] = get_unaligned_le64(primary_key); + ctx->key[0].key[1] = get_unaligned_le64(primary_key + 8); + if (backup_key) { + ctx->key[1].key[0] = get_unaligned_le64(backup_key); + ctx->key[1].key[1] = get_unaligned_le64(backup_key + 8); + ctx->num = 2; + } else { + ctx->num = 1; + } + + if (sk) { + q = &inet_csk(sk)->icsk_accept_queue.fastopenq; + octx = xchg((__force struct tcp_fastopen_context **)&q->ctx, ctx); + } else { + octx = xchg((__force struct tcp_fastopen_context **)&net->ipv4.tcp_fastopen_ctx, ctx); + } + + if (octx) + call_rcu(&octx->rcu, tcp_fastopen_ctx_free); +out: + return err; +} + +int tcp_fastopen_get_cipher(struct net *net, struct inet_connection_sock *icsk, + u64 *key) +{ + struct tcp_fastopen_context *ctx; + int n_keys = 0, i; + + rcu_read_lock(); + if (icsk) + ctx = rcu_dereference(icsk->icsk_accept_queue.fastopenq.ctx); + else + ctx = rcu_dereference(net->ipv4.tcp_fastopen_ctx); + if (ctx) { + n_keys = tcp_fastopen_context_len(ctx); + for (i = 0; i < n_keys; i++) { + put_unaligned_le64(ctx->key[i].key[0], key + (i * 2)); + put_unaligned_le64(ctx->key[i].key[1], key + (i * 2) + 1); + } + } + rcu_read_unlock(); + + return n_keys; +} + +static bool __tcp_fastopen_cookie_gen_cipher(struct request_sock *req, + struct sk_buff *syn, + const siphash_key_t *key, + struct tcp_fastopen_cookie *foc) +{ + BUILD_BUG_ON(TCP_FASTOPEN_COOKIE_SIZE != sizeof(u64)); + + if (req->rsk_ops->family == AF_INET) { + const struct iphdr *iph = ip_hdr(syn); + + foc->val[0] = cpu_to_le64(siphash(&iph->saddr, + sizeof(iph->saddr) + + sizeof(iph->daddr), + key)); + foc->len = TCP_FASTOPEN_COOKIE_SIZE; + return true; + } +#if IS_ENABLED(CONFIG_IPV6) + if (req->rsk_ops->family == AF_INET6) { + const struct ipv6hdr *ip6h = ipv6_hdr(syn); + + foc->val[0] = cpu_to_le64(siphash(&ip6h->saddr, + sizeof(ip6h->saddr) + + sizeof(ip6h->daddr), + key)); + foc->len = TCP_FASTOPEN_COOKIE_SIZE; + return true; + } +#endif + return false; +} + +/* Generate the fastopen cookie by 
applying SipHash to both the source and + * destination addresses. + */ +static void tcp_fastopen_cookie_gen(struct sock *sk, + struct request_sock *req, + struct sk_buff *syn, + struct tcp_fastopen_cookie *foc) +{ + struct tcp_fastopen_context *ctx; + + rcu_read_lock(); + ctx = tcp_fastopen_get_ctx(sk); + if (ctx) + __tcp_fastopen_cookie_gen_cipher(req, syn, &ctx->key[0], foc); + rcu_read_unlock(); +} + +/* If an incoming SYN or SYNACK frame contains a payload and/or FIN, + * queue this additional data / FIN. + */ +void tcp_fastopen_add_skb(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (TCP_SKB_CB(skb)->end_seq == tp->rcv_nxt) + return; + + skb = skb_clone(skb, GFP_ATOMIC); + if (!skb) + return; + + skb_dst_drop(skb); + /* segs_in has been initialized to 1 in tcp_create_openreq_child(). + * Hence, reset segs_in to 0 before calling tcp_segs_in() + * to avoid double counting. Also, tcp_segs_in() expects + * skb->len to include the tcp_hdrlen. Hence, it should + * be called before __skb_pull(). + */ + tp->segs_in = 0; + tcp_segs_in(tp, skb); + __skb_pull(skb, tcp_hdrlen(skb)); + sk_forced_mem_schedule(sk, skb->truesize); + skb_set_owner_r(skb, sk); + + TCP_SKB_CB(skb)->seq++; + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_SYN; + + tp->rcv_nxt = TCP_SKB_CB(skb)->end_seq; + __skb_queue_tail(&sk->sk_receive_queue, skb); + tp->syn_data_acked = 1; + + /* u64_stats_update_begin(&tp->syncp) not needed here, + * as we certainly are not changing upper 32bit value (0) + */ + tp->bytes_received = skb->len; + + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + tcp_fin(sk); +} + +/* returns 0 - no key match, 1 for primary, 2 for backup */ +static int tcp_fastopen_cookie_gen_check(struct sock *sk, + struct request_sock *req, + struct sk_buff *syn, + struct tcp_fastopen_cookie *orig, + struct tcp_fastopen_cookie *valid_foc) +{ + struct tcp_fastopen_cookie search_foc = { .len = -1 }; + struct tcp_fastopen_cookie *foc = valid_foc; + struct tcp_fastopen_context *ctx; + int i, ret = 0; + + rcu_read_lock(); + ctx = tcp_fastopen_get_ctx(sk); + if (!ctx) + goto out; + for (i = 0; i < tcp_fastopen_context_len(ctx); i++) { + __tcp_fastopen_cookie_gen_cipher(req, syn, &ctx->key[i], foc); + if (tcp_fastopen_cookie_match(foc, orig)) { + ret = i + 1; + goto out; + } + foc = &search_foc; + } +out: + rcu_read_unlock(); + return ret; +} + +static struct sock *tcp_fastopen_create_child(struct sock *sk, + struct sk_buff *skb, + struct request_sock *req) +{ + struct tcp_sock *tp; + struct request_sock_queue *queue = &inet_csk(sk)->icsk_accept_queue; + struct sock *child; + bool own_req; + + child = inet_csk(sk)->icsk_af_ops->syn_recv_sock(sk, skb, req, NULL, + NULL, &own_req); + if (!child) + return NULL; + + spin_lock(&queue->fastopenq.lock); + queue->fastopenq.qlen++; + spin_unlock(&queue->fastopenq.lock); + + /* Initialize the child socket. Have to fix some values to take + * into account the child is a Fast Open socket and is created + * only out of the bits carried in the SYN packet. + */ + tp = tcp_sk(child); + + rcu_assign_pointer(tp->fastopen_rsk, req); + tcp_rsk(req)->tfo_listener = true; + + /* RFC1323: The window in SYN & SYN/ACK segments is never + * scaled. So correct it appropriately. + */ + tp->snd_wnd = ntohs(tcp_hdr(skb)->window); + tp->max_window = tp->snd_wnd; + + /* Activate the retrans timer so that SYNACK can be retransmitted. + * The request socket is not added to the ehash + * because it's been added to the accept queue directly. 
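A rough userspace rendering of the cookie handling above: an 8-byte cookie is derived from the address pair under each configured key, and the check reports 0 for no match, 1 for the primary key and 2 for the backup key, while always producing the currently valid (primary) cookie for the reply. The keyed mix below is a trivial stand-in, not SipHash, and toy_mac/cookie_check are invented names.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Placeholder keyed mix; the kernel uses siphash() over saddr||daddr. */
static uint64_t toy_mac(const uint8_t key[16], const void *data, size_t len)
{
	uint64_t h = 0x9e3779b97f4a7c15ULL;
	const uint8_t *p = data;
	size_t i;

	for (i = 0; i < 16; i++)
		h = (h ^ key[i]) * 0x100000001b3ULL;
	for (i = 0; i < len; i++)
		h = (h ^ p[i]) * 0x100000001b3ULL;
	return h;
}

struct addr_pair { uint32_t saddr, daddr; };

/* Returns 0 on no match, 1 for the primary key, 2 for the backup key,
 * and always writes the currently valid (primary) cookie to valid_out.
 */
static int cookie_check(uint8_t keys[][16], int nkeys,
			const struct addr_pair *ap,
			uint64_t client_cookie, uint64_t *valid_out)
{
	int i;

	for (i = 0; i < nkeys; i++) {
		uint64_t c = toy_mac(keys[i], ap, sizeof(*ap));

		if (i == 0)
			*valid_out = c;
		if (c == client_cookie)
			return i + 1;
	}
	return 0;
}

int main(void)
{
	uint8_t keys[2][16] = { { 1 }, { 2 } };   /* primary, backup */
	struct addr_pair ap = { 0x0a000001, 0x0a000002 };
	uint64_t valid, old = toy_mac(keys[1], &ap, sizeof(ap));

	printf("old cookie matches key %d\n",
	       cookie_check(keys, 2, &ap, old, &valid)); /* 2: backup key */
	printf("fresh cookie would be %016llx\n", (unsigned long long)valid);
	return 0;
}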
+ */ + req->timeout = tcp_timeout_init(child); + inet_csk_reset_xmit_timer(child, ICSK_TIME_RETRANS, + req->timeout, TCP_RTO_MAX); + + refcount_set(&req->rsk_refcnt, 2); + + /* Now finish processing the fastopen child socket. */ + tcp_init_transfer(child, BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB, skb); + + tp->rcv_nxt = TCP_SKB_CB(skb)->seq + 1; + + tcp_fastopen_add_skb(child, skb); + + tcp_rsk(req)->rcv_nxt = tp->rcv_nxt; + tp->rcv_wup = tp->rcv_nxt; + /* tcp_conn_request() is sending the SYNACK, + * and queues the child into listener accept queue. + */ + return child; +} + +static bool tcp_fastopen_queue_check(struct sock *sk) +{ + struct fastopen_queue *fastopenq; + int max_qlen; + + /* Make sure the listener has enabled fastopen, and we don't + * exceed the max # of pending TFO requests allowed before trying + * to validating the cookie in order to avoid burning CPU cycles + * unnecessarily. + * + * XXX (TFO) - The implication of checking the max_qlen before + * processing a cookie request is that clients can't differentiate + * between qlen overflow causing Fast Open to be disabled + * temporarily vs a server not supporting Fast Open at all. + */ + fastopenq = &inet_csk(sk)->icsk_accept_queue.fastopenq; + max_qlen = READ_ONCE(fastopenq->max_qlen); + if (max_qlen == 0) + return false; + + if (fastopenq->qlen >= max_qlen) { + struct request_sock *req1; + spin_lock(&fastopenq->lock); + req1 = fastopenq->rskq_rst_head; + if (!req1 || time_after(req1->rsk_timer.expires, jiffies)) { + __NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENLISTENOVERFLOW); + spin_unlock(&fastopenq->lock); + return false; + } + fastopenq->rskq_rst_head = req1->dl_next; + fastopenq->qlen--; + spin_unlock(&fastopenq->lock); + reqsk_put(req1); + } + return true; +} + +static bool tcp_fastopen_no_cookie(const struct sock *sk, + const struct dst_entry *dst, + int flag) +{ + return (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen) & flag) || + tcp_sk(sk)->fastopen_no_cookie || + (dst && dst_metric(dst, RTAX_FASTOPEN_NO_COOKIE)); +} + +/* Returns true if we should perform Fast Open on the SYN. The cookie (foc) + * may be updated and return the client in the SYN-ACK later. E.g., Fast Open + * cookie request (foc->len == 0). + */ +struct sock *tcp_try_fastopen(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + const struct dst_entry *dst) +{ + bool syn_data = TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq + 1; + int tcp_fastopen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen); + struct tcp_fastopen_cookie valid_foc = { .len = -1 }; + struct sock *child; + int ret = 0; + + if (foc->len == 0) /* Client requests a cookie */ + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFASTOPENCOOKIEREQD); + + if (!((tcp_fastopen & TFO_SERVER_ENABLE) && + (syn_data || foc->len >= 0) && + tcp_fastopen_queue_check(sk))) { + foc->len = -1; + return NULL; + } + + if (tcp_fastopen_no_cookie(sk, dst, TFO_SERVER_COOKIE_NOT_REQD)) + goto fastopen; + + if (foc->len == 0) { + /* Client requests a cookie. */ + tcp_fastopen_cookie_gen(sk, req, skb, &valid_foc); + } else if (foc->len > 0) { + ret = tcp_fastopen_cookie_gen_check(sk, req, skb, foc, + &valid_foc); + if (!ret) { + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENPASSIVEFAIL); + } else { + /* Cookie is valid. Create a (full) child socket to + * accept the data in SYN before returning a SYN-ACK to + * ack the data. If we fail to create the socket, fall + * back and ack the ISN only but includes the same + * cookie. 
+ * + * Note: Data-less SYN with valid cookie is allowed to + * send data in SYN_RECV state. + */ +fastopen: + child = tcp_fastopen_create_child(sk, skb, req); + if (child) { + if (ret == 2) { + valid_foc.exp = foc->exp; + *foc = valid_foc; + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENPASSIVEALTKEY); + } else { + foc->len = -1; + } + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENPASSIVE); + return child; + } + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENPASSIVEFAIL); + } + } + valid_foc.exp = foc->exp; + *foc = valid_foc; + return NULL; +} + +bool tcp_fastopen_cookie_check(struct sock *sk, u16 *mss, + struct tcp_fastopen_cookie *cookie) +{ + const struct dst_entry *dst; + + tcp_fastopen_cache_get(sk, mss, cookie); + + /* Firewall blackhole issue check */ + if (tcp_fastopen_active_should_disable(sk)) { + cookie->len = -1; + return false; + } + + dst = __sk_dst_get(sk); + + if (tcp_fastopen_no_cookie(sk, dst, TFO_CLIENT_NO_COOKIE)) { + cookie->len = -1; + return true; + } + if (cookie->len > 0) + return true; + tcp_sk(sk)->fastopen_client_fail = TFO_COOKIE_UNAVAILABLE; + return false; +} + +/* This function checks if we want to defer sending SYN until the first + * write(). We defer under the following conditions: + * 1. fastopen_connect sockopt is set + * 2. we have a valid cookie + * Return value: return true if we want to defer until application writes data + * return false if we want to send out SYN immediately + */ +bool tcp_fastopen_defer_connect(struct sock *sk, int *err) +{ + struct tcp_fastopen_cookie cookie = { .len = 0 }; + struct tcp_sock *tp = tcp_sk(sk); + u16 mss; + + if (tp->fastopen_connect && !tp->fastopen_req) { + if (tcp_fastopen_cookie_check(sk, &mss, &cookie)) { + inet_sk(sk)->defer_connect = 1; + return true; + } + + /* Alloc fastopen_req in order for FO option to be included + * in SYN + */ + tp->fastopen_req = kzalloc(sizeof(*tp->fastopen_req), + sk->sk_allocation); + if (tp->fastopen_req) + tp->fastopen_req->cookie = cookie; + else + *err = -ENOBUFS; + } + return false; +} +EXPORT_SYMBOL(tcp_fastopen_defer_connect); + +/* + * The following code block is to deal with middle box issues with TFO: + * Middlebox firewall issues can potentially cause server's data being + * blackholed after a successful 3WHS using TFO. + * The proposed solution is to disable active TFO globally under the + * following circumstances: + * 1. client side TFO socket receives out of order FIN + * 2. client side TFO socket receives out of order RST + * 3. client side TFO socket has timed out three times consecutively during + * or after handshake + * We disable active side TFO globally for 1hr at first. Then if it + * happens again, we disable it for 2h, then 4h, 8h, ... + * And we reset the timeout back to 1hr when we see a successful active + * TFO connection with data exchanges. + */ + +/* Disable active TFO and record current jiffies and + * tfo_active_disable_times + */ +void tcp_fastopen_active_disable(struct sock *sk) +{ + struct net *net = sock_net(sk); + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout)) + return; + + /* Paired with READ_ONCE() in tcp_fastopen_active_should_disable() */ + WRITE_ONCE(net->ipv4.tfo_active_disable_stamp, jiffies); + + /* Paired with smp_rmb() in tcp_fastopen_active_should_disable(). + * We want net->ipv4.tfo_active_disable_stamp to be updated first. 
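The accept decision in tcp_try_fastopen() above boils down to a small amount of boolean logic. The sketch below mirrors that flow with the sysctl bits reduced to two booleans and invented names (tfo_decide, tfo_req); it is a simplification for reading, not the kernel path itself.

#include <stdbool.h>
#include <stdio.h>

enum tfo_verdict { TFO_REJECT, TFO_ACCEPT_DATA, TFO_SEND_COOKIE_ONLY };

struct tfo_req {
	bool has_syn_data;  /* payload carried in the SYN */
	int  cookie_len;    /* -1: none, 0: cookie request, >0: cookie bytes */
	bool cookie_valid;  /* result of the keyed-hash check */
};

static enum tfo_verdict tfo_decide(bool server_enabled, bool cookie_not_reqd,
				   bool queue_has_room, const struct tfo_req *req)
{
	if (!server_enabled || !queue_has_room)
		return TFO_REJECT;
	if (!req->has_syn_data && req->cookie_len < 0)
		return TFO_REJECT;              /* nothing TFO-related in the SYN */
	if (cookie_not_reqd)
		return TFO_ACCEPT_DATA;         /* cookies not required: trust the SYN */
	if (req->cookie_len > 0 && req->cookie_valid)
		return TFO_ACCEPT_DATA;         /* cookie verified */
	return TFO_SEND_COOKIE_ONLY;            /* fall back: SYN-ACK carries a fresh cookie */
}

int main(void)
{
	struct tfo_req r = { .has_syn_data = true, .cookie_len = 8, .cookie_valid = false };

	printf("verdict=%d\n", tfo_decide(true, false, true, &r)); /* 2: cookie only */
	r.cookie_valid = true;
	printf("verdict=%d\n", tfo_decide(true, false, true, &r)); /* 1: accept data */
	return 0;
}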
+ */ + smp_mb__before_atomic(); + atomic_inc(&net->ipv4.tfo_active_disable_times); + + NET_INC_STATS(net, LINUX_MIB_TCPFASTOPENBLACKHOLE); +} + +/* Calculate timeout for tfo active disable + * Return true if we are still in the active TFO disable period + * Return false if timeout already expired and we should use active TFO + */ +bool tcp_fastopen_active_should_disable(struct sock *sk) +{ + unsigned int tfo_bh_timeout = + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_fastopen_blackhole_timeout); + unsigned long timeout; + int tfo_da_times; + int multiplier; + + if (!tfo_bh_timeout) + return false; + + tfo_da_times = atomic_read(&sock_net(sk)->ipv4.tfo_active_disable_times); + if (!tfo_da_times) + return false; + + /* Paired with smp_mb__before_atomic() in tcp_fastopen_active_disable() */ + smp_rmb(); + + /* Limit timeout to max: 2^6 * initial timeout */ + multiplier = 1 << min(tfo_da_times - 1, 6); + + /* Paired with the WRITE_ONCE() in tcp_fastopen_active_disable(). */ + timeout = READ_ONCE(sock_net(sk)->ipv4.tfo_active_disable_stamp) + + multiplier * tfo_bh_timeout * HZ; + if (time_before(jiffies, timeout)) + return true; + + /* Mark check bit so we can check for successful active TFO + * condition and reset tfo_active_disable_times + */ + tcp_sk(sk)->syn_fastopen_ch = 1; + return false; +} + +/* Disable active TFO if FIN is the only packet in the ofo queue + * and no data is received. + * Also check if we can reset tfo_active_disable_times if data is + * received successfully on a marked active TFO sockets opened on + * a non-loopback interface + */ +void tcp_fastopen_active_disable_ofo_check(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct dst_entry *dst; + struct sk_buff *skb; + + if (!tp->syn_fastopen) + return; + + if (!tp->data_segs_in) { + skb = skb_rb_first(&tp->out_of_order_queue); + if (skb && !skb_rb_next(skb)) { + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) { + tcp_fastopen_active_disable(sk); + return; + } + } + } else if (tp->syn_fastopen_ch && + atomic_read(&sock_net(sk)->ipv4.tfo_active_disable_times)) { + dst = sk_dst_get(sk); + if (!(dst && dst->dev && (dst->dev->flags & IFF_LOOPBACK))) + atomic_set(&sock_net(sk)->ipv4.tfo_active_disable_times, 0); + dst_release(dst); + } +} + +void tcp_fastopen_active_detect_blackhole(struct sock *sk, bool expired) +{ + u32 timeouts = inet_csk(sk)->icsk_retransmits; + struct tcp_sock *tp = tcp_sk(sk); + + /* Broken middle-boxes may black-hole Fast Open connection during or + * even after the handshake. Be extremely conservative and pause + * Fast Open globally after hitting the third consecutive timeout or + * exceeding the configured timeout limit. + */ + if ((tp->syn_fastopen || tp->syn_data || tp->syn_data_acked) && + (timeouts == 2 || (timeouts < 2 && expired))) { + tcp_fastopen_active_disable(sk); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFASTOPENACTIVEFAIL); + } +} diff --git a/net/ipv4/tcp_highspeed.c b/net/ipv4/tcp_highspeed.c new file mode 100644 index 000000000..c6de5ce79 --- /dev/null +++ b/net/ipv4/tcp_highspeed.c @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Sally Floyd's High Speed TCP (RFC 3649) congestion control + * + * See https://www.icir.org/floyd/hstcp.html + * + * John Heffner + */ + +#include +#include + +/* From AIMD tables from RFC 3649 appendix B, + * with fixed-point MD scaled <<8. 
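The blackhole pause computed above doubles with every consecutive detection and is capped at 2^6 times the base period. A standalone version of that timeout arithmetic, with jiffies replaced by seconds and invented names:

#include <stdio.h>

/* Seconds left in the pause, or 0 if active TFO may be retried.
 * base_timeout is the sysctl value, disable_count the number of
 * consecutive blackhole events, stamp/now are timestamps in seconds.
 */
static unsigned long tfo_pause_remaining(unsigned long base_timeout,
					 int disable_count,
					 unsigned long stamp, unsigned long now)
{
	unsigned long mult, deadline;

	if (!base_timeout || !disable_count)
		return 0;
	mult = 1UL << (disable_count - 1 < 6 ? disable_count - 1 : 6);
	deadline = stamp + mult * base_timeout;
	return now < deadline ? deadline - now : 0;
}

int main(void)
{
	/* 1h base: 1st event pauses 1h, 2nd 2h, ..., 7th and later 64h. */
	printf("%lu\n", tfo_pause_remaining(3600, 1, 0, 1800)); /* 1800 left */
	printf("%lu\n", tfo_pause_remaining(3600, 3, 0, 7200)); /* 7200 left of 4h */
	printf("%lu\n", tfo_pause_remaining(3600, 9, 0, 0));    /* 230400: 64h cap */
	return 0;
}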
+ */ +static const struct hstcp_aimd_val { + unsigned int cwnd; + unsigned int md; +} hstcp_aimd_vals[] = { + { 38, 128, /* 0.50 */ }, + { 118, 112, /* 0.44 */ }, + { 221, 104, /* 0.41 */ }, + { 347, 98, /* 0.38 */ }, + { 495, 93, /* 0.37 */ }, + { 663, 89, /* 0.35 */ }, + { 851, 86, /* 0.34 */ }, + { 1058, 83, /* 0.33 */ }, + { 1284, 81, /* 0.32 */ }, + { 1529, 78, /* 0.31 */ }, + { 1793, 76, /* 0.30 */ }, + { 2076, 74, /* 0.29 */ }, + { 2378, 72, /* 0.28 */ }, + { 2699, 71, /* 0.28 */ }, + { 3039, 69, /* 0.27 */ }, + { 3399, 68, /* 0.27 */ }, + { 3778, 66, /* 0.26 */ }, + { 4177, 65, /* 0.26 */ }, + { 4596, 64, /* 0.25 */ }, + { 5036, 62, /* 0.25 */ }, + { 5497, 61, /* 0.24 */ }, + { 5979, 60, /* 0.24 */ }, + { 6483, 59, /* 0.23 */ }, + { 7009, 58, /* 0.23 */ }, + { 7558, 57, /* 0.22 */ }, + { 8130, 56, /* 0.22 */ }, + { 8726, 55, /* 0.22 */ }, + { 9346, 54, /* 0.21 */ }, + { 9991, 53, /* 0.21 */ }, + { 10661, 52, /* 0.21 */ }, + { 11358, 52, /* 0.20 */ }, + { 12082, 51, /* 0.20 */ }, + { 12834, 50, /* 0.20 */ }, + { 13614, 49, /* 0.19 */ }, + { 14424, 48, /* 0.19 */ }, + { 15265, 48, /* 0.19 */ }, + { 16137, 47, /* 0.19 */ }, + { 17042, 46, /* 0.18 */ }, + { 17981, 45, /* 0.18 */ }, + { 18955, 45, /* 0.18 */ }, + { 19965, 44, /* 0.17 */ }, + { 21013, 43, /* 0.17 */ }, + { 22101, 43, /* 0.17 */ }, + { 23230, 42, /* 0.17 */ }, + { 24402, 41, /* 0.16 */ }, + { 25618, 41, /* 0.16 */ }, + { 26881, 40, /* 0.16 */ }, + { 28193, 39, /* 0.16 */ }, + { 29557, 39, /* 0.15 */ }, + { 30975, 38, /* 0.15 */ }, + { 32450, 38, /* 0.15 */ }, + { 33986, 37, /* 0.15 */ }, + { 35586, 36, /* 0.14 */ }, + { 37253, 36, /* 0.14 */ }, + { 38992, 35, /* 0.14 */ }, + { 40808, 35, /* 0.14 */ }, + { 42707, 34, /* 0.13 */ }, + { 44694, 33, /* 0.13 */ }, + { 46776, 33, /* 0.13 */ }, + { 48961, 32, /* 0.13 */ }, + { 51258, 32, /* 0.13 */ }, + { 53677, 31, /* 0.12 */ }, + { 56230, 30, /* 0.12 */ }, + { 58932, 30, /* 0.12 */ }, + { 61799, 29, /* 0.12 */ }, + { 64851, 28, /* 0.11 */ }, + { 68113, 28, /* 0.11 */ }, + { 71617, 27, /* 0.11 */ }, + { 75401, 26, /* 0.10 */ }, + { 79517, 26, /* 0.10 */ }, + { 84035, 25, /* 0.10 */ }, + { 89053, 24, /* 0.10 */ }, +}; + +#define HSTCP_AIMD_MAX ARRAY_SIZE(hstcp_aimd_vals) + +struct hstcp { + u32 ai; +}; + +static void hstcp_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct hstcp *ca = inet_csk_ca(sk); + + ca->ai = 0; + + /* Ensure the MD arithmetic works. This is somewhat pedantic, + * since I don't think we will see a cwnd this large. :) */ + tp->snd_cwnd_clamp = min_t(u32, tp->snd_cwnd_clamp, 0xffffffff/128); +} + +static void hstcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct hstcp *ca = inet_csk_ca(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) + tcp_slow_start(tp, acked); + else { + /* Update AIMD parameters. 
+ * + * We want to guarantee that: + * hstcp_aimd_vals[ca->ai-1].cwnd < + * snd_cwnd <= + * hstcp_aimd_vals[ca->ai].cwnd + */ + if (tcp_snd_cwnd(tp) > hstcp_aimd_vals[ca->ai].cwnd) { + while (tcp_snd_cwnd(tp) > hstcp_aimd_vals[ca->ai].cwnd && + ca->ai < HSTCP_AIMD_MAX - 1) + ca->ai++; + } else if (ca->ai && tcp_snd_cwnd(tp) <= hstcp_aimd_vals[ca->ai-1].cwnd) { + while (ca->ai && tcp_snd_cwnd(tp) <= hstcp_aimd_vals[ca->ai-1].cwnd) + ca->ai--; + } + + /* Do additive increase */ + if (tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) { + /* cwnd = cwnd + a(w) / cwnd */ + tp->snd_cwnd_cnt += ca->ai + 1; + if (tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { + tp->snd_cwnd_cnt -= tcp_snd_cwnd(tp); + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + } + } + } +} + +static u32 hstcp_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct hstcp *ca = inet_csk_ca(sk); + + /* Do multiplicative decrease */ + return max(tcp_snd_cwnd(tp) - ((tcp_snd_cwnd(tp) * hstcp_aimd_vals[ca->ai].md) >> 8), 2U); +} + +static struct tcp_congestion_ops tcp_highspeed __read_mostly = { + .init = hstcp_init, + .ssthresh = hstcp_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = hstcp_cong_avoid, + + .owner = THIS_MODULE, + .name = "highspeed" +}; + +static int __init hstcp_register(void) +{ + BUILD_BUG_ON(sizeof(struct hstcp) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_highspeed); +} + +static void __exit hstcp_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_highspeed); +} + +module_init(hstcp_register); +module_exit(hstcp_unregister); + +MODULE_AUTHOR("John Heffner"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("High Speed TCP"); diff --git a/net/ipv4/tcp_htcp.c b/net/ipv4/tcp_htcp.c new file mode 100644 index 000000000..52b1f2665 --- /dev/null +++ b/net/ipv4/tcp_htcp.c @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * H-TCP congestion control. The algorithm is detailed in: + * R.N.Shorten, D.J.Leith: + * "H-TCP: TCP for high-speed and long-distance networks" + * Proc. PFLDnet, Argonne, 2004. 
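HighSpeed TCP's response curve above is purely table-driven: the index ai is kept so that hstcp_aimd_vals[ai] brackets the current cwnd, the additive increase credits ai + 1 per ACK (roughly ai + 1 segments per RTT), and the multiplicative decrease applies md as an 8-bit fixed-point fraction. A minimal standalone version with a truncated copy of the table, for illustration only:

#include <stdio.h>

struct aimd { unsigned int cwnd, md; };

/* First few rows of the RFC 3649 table; md is a fraction scaled by 256. */
static const struct aimd tbl[] = {
	{ 38, 128 }, { 118, 112 }, { 221, 104 }, { 347, 98 }, { 495, 93 },
};
#define TBL_MAX (sizeof(tbl) / sizeof(tbl[0]))

/* Move ai so that tbl[ai - 1].cwnd < cwnd <= tbl[ai].cwnd. */
static unsigned int aimd_index(unsigned int cwnd, unsigned int ai)
{
	while (cwnd > tbl[ai].cwnd && ai < TBL_MAX - 1)
		ai++;
	while (ai && cwnd <= tbl[ai - 1].cwnd)
		ai--;
	return ai;
}

int main(void)
{
	unsigned int cwnd = 300, ai = 0;

	ai = aimd_index(cwnd, ai);
	printf("ai=%u additive credit a(w)=%u per ACK\n", ai, ai + 1);
	/* Multiplicative decrease: cwnd - cwnd * md / 256 (floored at 2 in the kernel). */
	printf("ssthresh=%u\n", cwnd - (cwnd * tbl[ai].md >> 8));
	return 0;
}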
+ * https://www.hamilton.ie/net/htcp3.pdf + */ + +#include +#include +#include + +#define ALPHA_BASE (1<<7) /* 1.0 with shift << 7 */ +#define BETA_MIN (1<<6) /* 0.5 with shift << 7 */ +#define BETA_MAX 102 /* 0.8 with shift << 7 */ + +static int use_rtt_scaling __read_mostly = 1; +module_param(use_rtt_scaling, int, 0644); +MODULE_PARM_DESC(use_rtt_scaling, "turn on/off RTT scaling"); + +static int use_bandwidth_switch __read_mostly = 1; +module_param(use_bandwidth_switch, int, 0644); +MODULE_PARM_DESC(use_bandwidth_switch, "turn on/off bandwidth switcher"); + +struct htcp { + u32 alpha; /* Fixed point arith, << 7 */ + u8 beta; /* Fixed point arith, << 7 */ + u8 modeswitch; /* Delay modeswitch + until we had at least one congestion event */ + u16 pkts_acked; + u32 packetcount; + u32 minRTT; + u32 maxRTT; + u32 last_cong; /* Time since last congestion event end */ + u32 undo_last_cong; + + u32 undo_maxRTT; + u32 undo_old_maxB; + + /* Bandwidth estimation */ + u32 minB; + u32 maxB; + u32 old_maxB; + u32 Bi; + u32 lasttime; +}; + +static inline u32 htcp_cong_time(const struct htcp *ca) +{ + return jiffies - ca->last_cong; +} + +static inline u32 htcp_ccount(const struct htcp *ca) +{ + return htcp_cong_time(ca) / ca->minRTT; +} + +static inline void htcp_reset(struct htcp *ca) +{ + ca->undo_last_cong = ca->last_cong; + ca->undo_maxRTT = ca->maxRTT; + ca->undo_old_maxB = ca->old_maxB; + + ca->last_cong = jiffies; +} + +static u32 htcp_cwnd_undo(struct sock *sk) +{ + struct htcp *ca = inet_csk_ca(sk); + + if (ca->undo_last_cong) { + ca->last_cong = ca->undo_last_cong; + ca->maxRTT = ca->undo_maxRTT; + ca->old_maxB = ca->undo_old_maxB; + ca->undo_last_cong = 0; + } + + return tcp_reno_undo_cwnd(sk); +} + +static inline void measure_rtt(struct sock *sk, u32 srtt) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct htcp *ca = inet_csk_ca(sk); + + /* keep track of minimum RTT seen so far, minRTT is zero at first */ + if (ca->minRTT > srtt || !ca->minRTT) + ca->minRTT = srtt; + + /* max RTT */ + if (icsk->icsk_ca_state == TCP_CA_Open) { + if (ca->maxRTT < ca->minRTT) + ca->maxRTT = ca->minRTT; + if (ca->maxRTT < srtt && + srtt <= ca->maxRTT + msecs_to_jiffies(20)) + ca->maxRTT = srtt; + } +} + +static void measure_achieved_throughput(struct sock *sk, + const struct ack_sample *sample) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_sock *tp = tcp_sk(sk); + struct htcp *ca = inet_csk_ca(sk); + u32 now = tcp_jiffies32; + + if (icsk->icsk_ca_state == TCP_CA_Open) + ca->pkts_acked = sample->pkts_acked; + + if (sample->rtt_us > 0) + measure_rtt(sk, usecs_to_jiffies(sample->rtt_us)); + + if (!use_bandwidth_switch) + return; + + /* achieved throughput calculations */ + if (!((1 << icsk->icsk_ca_state) & (TCPF_CA_Open | TCPF_CA_Disorder))) { + ca->packetcount = 0; + ca->lasttime = now; + return; + } + + ca->packetcount += sample->pkts_acked; + + if (ca->packetcount >= tcp_snd_cwnd(tp) - (ca->alpha >> 7 ? 
: 1) && + now - ca->lasttime >= ca->minRTT && + ca->minRTT > 0) { + __u32 cur_Bi = ca->packetcount * HZ / (now - ca->lasttime); + + if (htcp_ccount(ca) <= 3) { + /* just after backoff */ + ca->minB = ca->maxB = ca->Bi = cur_Bi; + } else { + ca->Bi = (3 * ca->Bi + cur_Bi) / 4; + if (ca->Bi > ca->maxB) + ca->maxB = ca->Bi; + if (ca->minB > ca->maxB) + ca->minB = ca->maxB; + } + ca->packetcount = 0; + ca->lasttime = now; + } +} + +static inline void htcp_beta_update(struct htcp *ca, u32 minRTT, u32 maxRTT) +{ + if (use_bandwidth_switch) { + u32 maxB = ca->maxB; + u32 old_maxB = ca->old_maxB; + + ca->old_maxB = ca->maxB; + if (!between(5 * maxB, 4 * old_maxB, 6 * old_maxB)) { + ca->beta = BETA_MIN; + ca->modeswitch = 0; + return; + } + } + + if (ca->modeswitch && minRTT > msecs_to_jiffies(10) && maxRTT) { + ca->beta = (minRTT << 7) / maxRTT; + if (ca->beta < BETA_MIN) + ca->beta = BETA_MIN; + else if (ca->beta > BETA_MAX) + ca->beta = BETA_MAX; + } else { + ca->beta = BETA_MIN; + ca->modeswitch = 1; + } +} + +static inline void htcp_alpha_update(struct htcp *ca) +{ + u32 minRTT = ca->minRTT; + u32 factor = 1; + u32 diff = htcp_cong_time(ca); + + if (diff > HZ) { + diff -= HZ; + factor = 1 + (10 * diff + ((diff / 2) * (diff / 2) / HZ)) / HZ; + } + + if (use_rtt_scaling && minRTT) { + u32 scale = (HZ << 3) / (10 * minRTT); + + /* clamping ratio to interval [0.5,10]<<3 */ + scale = min(max(scale, 1U << 2), 10U << 3); + factor = (factor << 3) / scale; + if (!factor) + factor = 1; + } + + ca->alpha = 2 * factor * ((1 << 7) - ca->beta); + if (!ca->alpha) + ca->alpha = ALPHA_BASE; +} + +/* + * After we have the rtt data to calculate beta, we'd still prefer to wait one + * rtt before we adjust our beta to ensure we are working from a consistent + * data. + * + * This function should be called when we hit a congestion event since only at + * that point do we really have a real sense of maxRTT (the queues en route + * were getting just too full now). + */ +static void htcp_param_update(struct sock *sk) +{ + struct htcp *ca = inet_csk_ca(sk); + u32 minRTT = ca->minRTT; + u32 maxRTT = ca->maxRTT; + + htcp_beta_update(ca, minRTT, maxRTT); + htcp_alpha_update(ca); + + /* add slowly fading memory for maxRTT to accommodate routing changes */ + if (minRTT > 0 && maxRTT > minRTT) + ca->maxRTT = minRTT + ((maxRTT - minRTT) * 95) / 100; +} + +static u32 htcp_recalc_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct htcp *ca = inet_csk_ca(sk); + + htcp_param_update(sk); + return max((tcp_snd_cwnd(tp) * ca->beta) >> 7, 2U); +} + +static void htcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct htcp *ca = inet_csk_ca(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) + tcp_slow_start(tp, acked); + else { + /* In dangerous area, increase slowly. 
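In H-TCP the backoff factor beta tracks minRTT/maxRTT clamped to [0.5, 0.8], and the increase factor alpha grows with the time the flow has been congestion-free, both kept in <<7 fixed point. The sketch below reproduces that arithmetic with RTTs in milliseconds and elapsed time in seconds; the optional RTT scaling and the bandwidth switch are omitted, and the names are invented.

#include <stdio.h>

#define ALPHA_BASE (1 << 7) /* 1.0 in <<7 fixed point */
#define BETA_MIN   (1 << 6) /* 0.5 */
#define BETA_MAX   102      /* ~0.8 */

/* beta ~= minRTT / maxRTT, clamped to [0.5, 0.8]; RTTs in ms. */
static unsigned int htcp_beta(unsigned int min_rtt_ms, unsigned int max_rtt_ms)
{
	unsigned int b;

	if (min_rtt_ms <= 10 || !max_rtt_ms)
		return BETA_MIN;
	b = (min_rtt_ms << 7) / max_rtt_ms;
	if (b < BETA_MIN)
		b = BETA_MIN;
	else if (b > BETA_MAX)
		b = BETA_MAX;
	return b;
}

/* alpha = 2 * f(t) * (1 - beta), where f(t) grows once the flow has been
 * congestion-free for more than one second (t given in whole seconds).
 */
static unsigned int htcp_alpha(unsigned int secs_since_cong, unsigned int beta)
{
	unsigned int factor = 1, a;

	if (secs_since_cong > 1) {
		unsigned int d = secs_since_cong - 1;

		factor = 1 + 10 * d + (d / 2) * (d / 2);
	}
	a = 2 * factor * ((1 << 7) - beta);
	return a ? a : ALPHA_BASE;
}

int main(void)
{
	unsigned int beta = htcp_beta(40, 120); /* ~0.33, clamped up to 0.5 */

	printf("beta=%u/128 alpha(5s)=%u/128\n", beta, htcp_alpha(5, beta));
	return 0;
}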
+ * In theory this is tp->snd_cwnd += alpha / tp->snd_cwnd + */ + if ((tp->snd_cwnd_cnt * ca->alpha)>>7 >= tcp_snd_cwnd(tp)) { + if (tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + tp->snd_cwnd_cnt = 0; + htcp_alpha_update(ca); + } else + tp->snd_cwnd_cnt += ca->pkts_acked; + + ca->pkts_acked = 1; + } +} + +static void htcp_init(struct sock *sk) +{ + struct htcp *ca = inet_csk_ca(sk); + + memset(ca, 0, sizeof(struct htcp)); + ca->alpha = ALPHA_BASE; + ca->beta = BETA_MIN; + ca->pkts_acked = 1; + ca->last_cong = jiffies; +} + +static void htcp_state(struct sock *sk, u8 new_state) +{ + switch (new_state) { + case TCP_CA_Open: + { + struct htcp *ca = inet_csk_ca(sk); + + if (ca->undo_last_cong) { + ca->last_cong = jiffies; + ca->undo_last_cong = 0; + } + } + break; + case TCP_CA_CWR: + case TCP_CA_Recovery: + case TCP_CA_Loss: + htcp_reset(inet_csk_ca(sk)); + break; + } +} + +static struct tcp_congestion_ops htcp __read_mostly = { + .init = htcp_init, + .ssthresh = htcp_recalc_ssthresh, + .cong_avoid = htcp_cong_avoid, + .set_state = htcp_state, + .undo_cwnd = htcp_cwnd_undo, + .pkts_acked = measure_achieved_throughput, + .owner = THIS_MODULE, + .name = "htcp", +}; + +static int __init htcp_register(void) +{ + BUILD_BUG_ON(sizeof(struct htcp) > ICSK_CA_PRIV_SIZE); + BUILD_BUG_ON(BETA_MIN >= BETA_MAX); + return tcp_register_congestion_control(&htcp); +} + +static void __exit htcp_unregister(void) +{ + tcp_unregister_congestion_control(&htcp); +} + +module_init(htcp_register); +module_exit(htcp_unregister); + +MODULE_AUTHOR("Baruch Even"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("H-TCP"); diff --git a/net/ipv4/tcp_hybla.c b/net/ipv4/tcp_hybla.c new file mode 100644 index 000000000..abd7d9180 --- /dev/null +++ b/net/ipv4/tcp_hybla.c @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP HYBLA + * + * TCP-HYBLA Congestion control algorithm, based on: + * C.Caini, R.Firrincieli, "TCP-Hybla: A TCP Enhancement + * for Heterogeneous Networks", + * International Journal on satellite Communications, + * September 2004 + * Daniele Lacamera + * root at danielinux.net + */ + +#include +#include + +/* Tcp Hybla structure. 
*/ +struct hybla { + bool hybla_en; + u32 snd_cwnd_cents; /* Keeps increment values when it is <1, <<7 */ + u32 rho; /* Rho parameter, integer part */ + u32 rho2; /* Rho * Rho, integer part */ + u32 rho_3ls; /* Rho parameter, <<3 */ + u32 rho2_7ls; /* Rho^2, <<7 */ + u32 minrtt_us; /* Minimum smoothed round trip time value seen */ +}; + +/* Hybla reference round trip time (default= 1/40 sec = 25 ms), in ms */ +static int rtt0 = 25; +module_param(rtt0, int, 0644); +MODULE_PARM_DESC(rtt0, "reference rout trip time (ms)"); + +/* This is called to refresh values for hybla parameters */ +static inline void hybla_recalc_param (struct sock *sk) +{ + struct hybla *ca = inet_csk_ca(sk); + + ca->rho_3ls = max_t(u32, + tcp_sk(sk)->srtt_us / (rtt0 * USEC_PER_MSEC), + 8U); + ca->rho = ca->rho_3ls >> 3; + ca->rho2_7ls = (ca->rho_3ls * ca->rho_3ls) << 1; + ca->rho2 = ca->rho2_7ls >> 7; +} + +static void hybla_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct hybla *ca = inet_csk_ca(sk); + + ca->rho = 0; + ca->rho2 = 0; + ca->rho_3ls = 0; + ca->rho2_7ls = 0; + ca->snd_cwnd_cents = 0; + ca->hybla_en = true; + tcp_snd_cwnd_set(tp, 2); + tp->snd_cwnd_clamp = 65535; + + /* 1st Rho measurement based on initial srtt */ + hybla_recalc_param(sk); + + /* set minimum rtt as this is the 1st ever seen */ + ca->minrtt_us = tp->srtt_us; + tcp_snd_cwnd_set(tp, ca->rho); +} + +static void hybla_state(struct sock *sk, u8 ca_state) +{ + struct hybla *ca = inet_csk_ca(sk); + + ca->hybla_en = (ca_state == TCP_CA_Open); +} + +static inline u32 hybla_fraction(u32 odds) +{ + static const u32 fractions[] = { + 128, 139, 152, 165, 181, 197, 215, 234, + }; + + return (odds < ARRAY_SIZE(fractions)) ? fractions[odds] : 128; +} + +/* TCP Hybla main routine. + * This is the algorithm behavior: + * o Recalc Hybla parameters if min_rtt has changed + * o Give cwnd a new value based on the model proposed + * o remember increments <1 + */ +static void hybla_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct hybla *ca = inet_csk_ca(sk); + u32 increment, odd, rho_fractions; + int is_slowstart = 0; + + /* Recalculate rho only if this srtt is the lowest */ + if (tp->srtt_us < ca->minrtt_us) { + hybla_recalc_param(sk); + ca->minrtt_us = tp->srtt_us; + } + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (!ca->hybla_en) { + tcp_reno_cong_avoid(sk, ack, acked); + return; + } + + if (ca->rho == 0) + hybla_recalc_param(sk); + + rho_fractions = ca->rho_3ls - (ca->rho << 3); + + if (tcp_in_slow_start(tp)) { + /* + * slow start + * INC = 2^RHO - 1 + * This is done by splitting the rho parameter + * into 2 parts: an integer part and a fraction part. + * Inrement<<7 is estimated by doing: + * [2^(int+fract)]<<7 + * that is equal to: + * (2^int) * [(2^fract) <<7] + * 2^int is straightly computed as 1<rho, 16U)) * + hybla_fraction(rho_fractions)) - 128; + } else { + /* + * congestion avoidance + * INC = RHO^2 / W + * as long as increment is estimated as (rho<<7)/window + * it already is <<7 and we can easily count its fractions. + */ + increment = ca->rho2_7ls / tcp_snd_cwnd(tp); + if (increment < 128) + tp->snd_cwnd_cnt++; + } + + odd = increment % 128; + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + (increment >> 7)); + ca->snd_cwnd_cents += odd; + + /* check when fractions goes >=128 and increase cwnd by 1. 
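Hybla's growth is driven by rho = max(srtt/rtt0, 1), kept with three fractional bits: per ACK the increment is 2^rho - 1 in slow start and roughly rho^2/cwnd in congestion avoidance, accumulated in <<7 fixed point with the sub-segment remainders carried exactly as in the promotion loop that follows just below. A compact standalone version (invented names; the fractional rho bits are dropped in the congestion-avoidance branch for brevity):

#include <stdio.h>

/* fractions[i] ~= 2^(i/8) scaled by 128, i = 0..7 */
static const unsigned int fractions[] = {
	128, 139, 152, 165, 181, 197, 215, 234,
};

struct hybla_st {
	unsigned int rho_3ls; /* rho with 3 fractional bits, floored at 8 (rho >= 1) */
	unsigned int cwnd;    /* segments */
	unsigned int cents;   /* carried sub-segment credit, <<7 */
};

static void hybla_set_rtt(struct hybla_st *h, unsigned int srtt_us, unsigned int rtt0_us)
{
	h->rho_3ls = (srtt_us << 3) / rtt0_us;
	if (h->rho_3ls < 8)
		h->rho_3ls = 8;
}

/* One ACK worth of growth; slow_start selects 2^rho - 1 vs rho^2/cwnd. */
static void hybla_on_ack(struct hybla_st *h, int slow_start)
{
	unsigned int rho = h->rho_3ls >> 3;
	unsigned int frac = h->rho_3ls & 7;
	unsigned int inc; /* increment scaled by 128 */

	if (slow_start)
		inc = (1u << (rho < 16 ? rho : 16)) * fractions[frac] - 128;
	else
		/* fractional rho ignored here; the kernel keeps it via rho2_7ls */
		inc = ((rho * rho) << 7) / h->cwnd;

	h->cwnd += inc >> 7;
	h->cents += inc & 127;
	while (h->cents >= 128) { /* promote carried fractions to whole segments */
		h->cwnd++;
		h->cents -= 128;
	}
}

int main(void)
{
	struct hybla_st h = { 0, 10, 0 };

	hybla_set_rtt(&h, 100000, 25000); /* srtt 100ms, rtt0 25ms -> rho = 4 */
	hybla_on_ack(&h, 1);
	printf("slow start: cwnd=%u cents=%u\n", h.cwnd, h.cents); /* +15 segments */
	hybla_on_ack(&h, 0);
	printf("cong avoid: cwnd=%u cents=%u\n", h.cwnd, h.cents);
	return 0;
}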
*/ + while (ca->snd_cwnd_cents >= 128) { + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + ca->snd_cwnd_cents -= 128; + tp->snd_cwnd_cnt = 0; + } + /* check when cwnd has not been incremented for a while */ + if (increment == 0 && odd == 0 && tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + tp->snd_cwnd_cnt = 0; + } + /* clamp down slowstart cwnd to ssthresh value. */ + if (is_slowstart) + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_ssthresh)); + + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), tp->snd_cwnd_clamp)); +} + +static struct tcp_congestion_ops tcp_hybla __read_mostly = { + .init = hybla_init, + .ssthresh = tcp_reno_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = hybla_cong_avoid, + .set_state = hybla_state, + + .owner = THIS_MODULE, + .name = "hybla" +}; + +static int __init hybla_register(void) +{ + BUILD_BUG_ON(sizeof(struct hybla) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_hybla); +} + +static void __exit hybla_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_hybla); +} + +module_init(hybla_register); +module_exit(hybla_unregister); + +MODULE_AUTHOR("Daniele Lacamera"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Hybla"); diff --git a/net/ipv4/tcp_illinois.c b/net/ipv4/tcp_illinois.c new file mode 100644 index 000000000..c0c81a2c7 --- /dev/null +++ b/net/ipv4/tcp_illinois.c @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP Illinois congestion control. + * Home page: + * http://www.ews.uiuc.edu/~shaoliu/tcpillinois/index.html + * + * The algorithm is described in: + * "TCP-Illinois: A Loss and Delay-Based Congestion Control Algorithm + * for High-Speed Networks" + * http://tamerbasar.csl.illinois.edu/LiuBasarSrikantPerfEvalArtJun2008.pdf + * + * Implemented from description in paper and ns-2 simulation. + * Copyright (C) 2007 Stephen Hemminger + */ + +#include +#include +#include +#include +#include + +#define ALPHA_SHIFT 7 +#define ALPHA_SCALE (1u<end_seq = tp->snd_nxt; + ca->cnt_rtt = 0; + ca->sum_rtt = 0; + + /* TODO: age max_rtt? */ +} + +static void tcp_illinois_init(struct sock *sk) +{ + struct illinois *ca = inet_csk_ca(sk); + + ca->alpha = ALPHA_MAX; + ca->beta = BETA_BASE; + ca->base_rtt = 0x7fffffff; + ca->max_rtt = 0; + + ca->acked = 0; + ca->rtt_low = 0; + ca->rtt_above = 0; + + rtt_reset(sk); +} + +/* Measure RTT for each ack. */ +static void tcp_illinois_acked(struct sock *sk, const struct ack_sample *sample) +{ + struct illinois *ca = inet_csk_ca(sk); + s32 rtt_us = sample->rtt_us; + + ca->acked = sample->pkts_acked; + + /* dup ack, no rtt sample */ + if (rtt_us < 0) + return; + + /* ignore bogus values, this prevents wraparound in alpha math */ + if (rtt_us > RTT_MAX) + rtt_us = RTT_MAX; + + /* keep track of minimum RTT seen so far */ + if (ca->base_rtt > rtt_us) + ca->base_rtt = rtt_us; + + /* and max */ + if (ca->max_rtt < rtt_us) + ca->max_rtt = rtt_us; + + ++ca->cnt_rtt; + ca->sum_rtt += rtt_us; +} + +/* Maximum queuing delay */ +static inline u32 max_delay(const struct illinois *ca) +{ + return ca->max_rtt - ca->base_rtt; +} + +/* Average queuing delay */ +static inline u32 avg_delay(const struct illinois *ca) +{ + u64 t = ca->sum_rtt; + + do_div(t, ca->cnt_rtt); + return t - ca->base_rtt; +} + +/* + * Compute value of alpha used for additive increase. + * If small window then use 1.0, equivalent to Reno. + * + * For larger windows, adjust based on average delay. + * A. 
If average delay is at minimum (we are uncongested), + * then use large alpha (10.0) to increase faster. + * B. If average delay is at maximum (getting congested) + * then use small alpha (0.3) + * + * The result is a convex window growth curve. + */ +static u32 alpha(struct illinois *ca, u32 da, u32 dm) +{ + u32 d1 = dm / 100; /* Low threshold */ + + if (da <= d1) { + /* If never got out of low delay zone, then use max */ + if (!ca->rtt_above) + return ALPHA_MAX; + + /* Wait for 5 good RTT's before allowing alpha to go alpha max. + * This prevents one good RTT from causing sudden window increase. + */ + if (++ca->rtt_low < theta) + return ca->alpha; + + ca->rtt_low = 0; + ca->rtt_above = 0; + return ALPHA_MAX; + } + + ca->rtt_above = 1; + + /* + * Based on: + * + * (dm - d1) amin amax + * k1 = ------------------- + * amax - amin + * + * (dm - d1) amin + * k2 = ---------------- - d1 + * amax - amin + * + * k1 + * alpha = ---------- + * k2 + da + */ + + dm -= d1; + da -= d1; + return (dm * ALPHA_MAX) / + (dm + (da * (ALPHA_MAX - ALPHA_MIN)) / ALPHA_MIN); +} + +/* + * Beta used for multiplicative decrease. + * For small window sizes returns same value as Reno (0.5) + * + * If delay is small (10% of max) then beta = 1/8 + * If delay is up to 80% of max then beta = 1/2 + * In between is a linear function + */ +static u32 beta(u32 da, u32 dm) +{ + u32 d2, d3; + + d2 = dm / 10; + if (da <= d2) + return BETA_MIN; + + d3 = (8 * dm) / 10; + if (da >= d3 || d3 <= d2) + return BETA_MAX; + + /* + * Based on: + * + * bmin d3 - bmax d2 + * k3 = ------------------- + * d3 - d2 + * + * bmax - bmin + * k4 = ------------- + * d3 - d2 + * + * b = k3 + k4 da + */ + return (BETA_MIN * d3 - BETA_MAX * d2 + (BETA_MAX - BETA_MIN) * da) + / (d3 - d2); +} + +/* Update alpha and beta values once per RTT */ +static void update_params(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct illinois *ca = inet_csk_ca(sk); + + if (tcp_snd_cwnd(tp) < win_thresh) { + ca->alpha = ALPHA_BASE; + ca->beta = BETA_BASE; + } else if (ca->cnt_rtt > 0) { + u32 dm = max_delay(ca); + u32 da = avg_delay(ca); + + ca->alpha = alpha(ca, da, dm); + ca->beta = beta(da, dm); + } + + rtt_reset(sk); +} + +/* + * In case of loss, reset to default values + */ +static void tcp_illinois_state(struct sock *sk, u8 new_state) +{ + struct illinois *ca = inet_csk_ca(sk); + + if (new_state == TCP_CA_Loss) { + ca->alpha = ALPHA_BASE; + ca->beta = BETA_BASE; + ca->rtt_low = 0; + ca->rtt_above = 0; + rtt_reset(sk); + } +} + +/* + * Increase window in response to successful acknowledgment. 
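TCP Illinois derives its AI step and MD factor from the measured queueing delay: alpha falls from 10 toward 0.3 as the average delay approaches the maximum delay, and beta climbs linearly from 1/8 to 1/2 between 10% and 80% of it. The standalone sketch below keeps the kernel's fixed-point scales but omits the rtt_above/rtt_low (theta) hysteresis; the function names are invented.

#include <stdio.h>

#define ALPHA_SHIFT 7
#define ALPHA_SCALE (1u << ALPHA_SHIFT)
#define ALPHA_MIN   ((3 * ALPHA_SCALE) / 10) /* ~0.3 */
#define ALPHA_MAX   (10 * ALPHA_SCALE)       /* 10.0 */
#define BETA_SHIFT  6
#define BETA_SCALE  (1u << BETA_SHIFT)
#define BETA_MIN    (BETA_SCALE / 8)         /* 0.125 */
#define BETA_MAX    (BETA_SCALE / 2)         /* 0.5 */

/* da = average queueing delay, dm = maximum queueing delay (same units). */
static unsigned int illinois_alpha(unsigned int da, unsigned int dm)
{
	unsigned int d1 = dm / 100;

	if (da <= d1)
		return ALPHA_MAX; /* effectively uncongested */
	dm -= d1;
	da -= d1;
	return (dm * ALPHA_MAX) /
	       (dm + (da * (ALPHA_MAX - ALPHA_MIN)) / ALPHA_MIN);
}

static unsigned int illinois_beta(unsigned int da, unsigned int dm)
{
	unsigned int d2 = dm / 10, d3 = (8 * dm) / 10;

	if (da <= d2)
		return BETA_MIN;
	if (da >= d3 || d3 <= d2)
		return BETA_MAX;
	return (BETA_MIN * d3 - BETA_MAX * d2 + (BETA_MAX - BETA_MIN) * da) /
	       (d3 - d2);
}

int main(void)
{
	unsigned int dm = 100000; /* e.g. 100 ms of max queueing delay, in us */
	unsigned int da;

	for (da = 0; da <= dm; da += 25000)
		printf("da=%6u alpha=%4u/128 beta=%2u/64\n",
		       da, illinois_alpha(da, dm), illinois_beta(da, dm));
	return 0;
}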
+ */ +static void tcp_illinois_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct illinois *ca = inet_csk_ca(sk); + + if (after(ack, ca->end_seq)) + update_params(sk); + + /* RFC2861 only increase cwnd if fully utilized */ + if (!tcp_is_cwnd_limited(sk)) + return; + + /* In slow start */ + if (tcp_in_slow_start(tp)) + tcp_slow_start(tp, acked); + + else { + u32 delta; + + /* snd_cwnd_cnt is # of packets since last cwnd increment */ + tp->snd_cwnd_cnt += ca->acked; + ca->acked = 1; + + /* This is close approximation of: + * tp->snd_cwnd += alpha/tp->snd_cwnd + */ + delta = (tp->snd_cwnd_cnt * ca->alpha) >> ALPHA_SHIFT; + if (delta >= tcp_snd_cwnd(tp)) { + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp) + delta / tcp_snd_cwnd(tp), + (u32)tp->snd_cwnd_clamp)); + tp->snd_cwnd_cnt = 0; + } + } +} + +static u32 tcp_illinois_ssthresh(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct illinois *ca = inet_csk_ca(sk); + u32 decr; + + /* Multiplicative decrease */ + decr = (tcp_snd_cwnd(tp) * ca->beta) >> BETA_SHIFT; + return max(tcp_snd_cwnd(tp) - decr, 2U); +} + +/* Extract info for Tcp socket info provided via netlink. */ +static size_t tcp_illinois_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + const struct illinois *ca = inet_csk_ca(sk); + + if (ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + info->vegas.tcpv_enabled = 1; + info->vegas.tcpv_rttcnt = ca->cnt_rtt; + info->vegas.tcpv_minrtt = ca->base_rtt; + info->vegas.tcpv_rtt = 0; + + if (info->vegas.tcpv_rttcnt > 0) { + u64 t = ca->sum_rtt; + + do_div(t, info->vegas.tcpv_rttcnt); + info->vegas.tcpv_rtt = t; + } + *attr = INET_DIAG_VEGASINFO; + return sizeof(struct tcpvegas_info); + } + return 0; +} + +static struct tcp_congestion_ops tcp_illinois __read_mostly = { + .init = tcp_illinois_init, + .ssthresh = tcp_illinois_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_illinois_cong_avoid, + .set_state = tcp_illinois_state, + .get_info = tcp_illinois_info, + .pkts_acked = tcp_illinois_acked, + + .owner = THIS_MODULE, + .name = "illinois", +}; + +static int __init tcp_illinois_register(void) +{ + BUILD_BUG_ON(sizeof(struct illinois) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_illinois); +} + +static void __exit tcp_illinois_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_illinois); +} + +module_init(tcp_illinois_register); +module_exit(tcp_illinois_unregister); + +MODULE_AUTHOR("Stephen Hemminger, Shao Liu"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Illinois"); +MODULE_VERSION("1.0"); diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c new file mode 100644 index 000000000..34460c9b3 --- /dev/null +++ b/net/ipv4/tcp_input.c @@ -0,0 +1,7075 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Corey Minyard + * Florian La Roche, + * Charles Hedrick, + * Linus Torvalds, + * Alan Cox, + * Matthew Dillon, + * Arnt Gulbrandsen, + * Jorge Cwik, + */ + +/* + * Changes: + * Pedro Roque : Fast Retransmit/Recovery. + * Two receive queues. + * Retransmit queue handled by TCP. + * Better retransmit timer handling. + * New congestion avoidance. + * Header prediction. + * Variable renaming. 
+ * + * Eric : Fast Retransmit. + * Randy Scott : MSS option defines. + * Eric Schenk : Fixes to slow start algorithm. + * Eric Schenk : Yet another double ACK bug. + * Eric Schenk : Delayed ACK bug fixes. + * Eric Schenk : Floyd style fast retrans war avoidance. + * David S. Miller : Don't allow zero congestion window. + * Eric Schenk : Fix retransmitter so that it sends + * next packet on ack of previous packet. + * Andi Kleen : Moved open_request checking here + * and process RSTs for open_requests. + * Andi Kleen : Better prune_queue, and other fixes. + * Andrey Savochkin: Fix RTT measurements in the presence of + * timestamps. + * Andrey Savochkin: Check sequence numbers correctly when + * removing SACKs due to in sequence incoming + * data segments. + * Andi Kleen: Make sure we never ack data there is not + * enough room for. Also make this condition + * a fatal error if it might still happen. + * Andi Kleen: Add tcp_measure_rcv_mss to make + * connections with MSS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int sysctl_tcp_max_orphans __read_mostly = NR_FILE; + +#define FLAG_DATA 0x01 /* Incoming frame contained data. */ +#define FLAG_WIN_UPDATE 0x02 /* Incoming ACK was a window update. */ +#define FLAG_DATA_ACKED 0x04 /* This ACK acknowledged new data. */ +#define FLAG_RETRANS_DATA_ACKED 0x08 /* "" "" some of which was retransmitted. */ +#define FLAG_SYN_ACKED 0x10 /* This ACK acknowledged SYN. */ +#define FLAG_DATA_SACKED 0x20 /* New SACK. */ +#define FLAG_ECE 0x40 /* ECE in this ACK */ +#define FLAG_LOST_RETRANS 0x80 /* This ACK marks some retransmission lost */ +#define FLAG_SLOWPATH 0x100 /* Do not skip RFC checks for window update.*/ +#define FLAG_ORIG_SACK_ACKED 0x200 /* Never retransmitted data are (s)acked */ +#define FLAG_SND_UNA_ADVANCED 0x400 /* Snd_una was changed (!= FLAG_DATA_ACKED) */ +#define FLAG_DSACKING_ACK 0x800 /* SACK blocks contained D-SACK info */ +#define FLAG_SET_XMIT_TIMER 0x1000 /* Set TLP or RTO timer */ +#define FLAG_SACK_RENEGING 0x2000 /* snd_una advanced to a sacked seq */ +#define FLAG_UPDATE_TS_RECENT 0x4000 /* tcp_replace_ts_recent() */ +#define FLAG_NO_CHALLENGE_ACK 0x8000 /* do not call tcp_send_challenge_ack() */ +#define FLAG_ACK_MAYBE_DELAYED 0x10000 /* Likely a delayed ACK */ +#define FLAG_DSACK_TLP 0x20000 /* DSACK for tail loss probe */ + +#define FLAG_ACKED (FLAG_DATA_ACKED|FLAG_SYN_ACKED) +#define FLAG_NOT_DUP (FLAG_DATA|FLAG_WIN_UPDATE|FLAG_ACKED) +#define FLAG_CA_ALERT (FLAG_DATA_SACKED|FLAG_ECE|FLAG_DSACKING_ACK) +#define FLAG_FORWARD_PROGRESS (FLAG_ACKED|FLAG_DATA_SACKED) + +#define TCP_REMNANT (TCP_FLAG_FIN|TCP_FLAG_URG|TCP_FLAG_SYN|TCP_FLAG_PSH) +#define TCP_HP_BITS (~(TCP_RESERVED_BITS|TCP_FLAG_PSH)) + +#define REXMIT_NONE 0 /* no loss recovery to do */ +#define REXMIT_LOST 1 /* retransmit packets marked lost */ +#define REXMIT_NEW 2 /* FRTO-style transmit of unsent/new packets */ + +#if IS_ENABLED(CONFIG_TLS_DEVICE) +static DEFINE_STATIC_KEY_DEFERRED_FALSE(clean_acked_data_enabled, HZ); + +void clean_acked_data_enable(struct inet_connection_sock *icsk, + void (*cad)(struct sock *sk, u32 ack_seq)) +{ + icsk->icsk_clean_acked = cad; + static_branch_deferred_inc(&clean_acked_data_enabled); +} +EXPORT_SYMBOL_GPL(clean_acked_data_enable); + +void clean_acked_data_disable(struct inet_connection_sock *icsk) +{ + static_branch_slow_dec_deferred(&clean_acked_data_enabled); + icsk->icsk_clean_acked = NULL; +} 
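The FLAG_* bits above are folded into composites such as FLAG_NOT_DUP and FLAG_CA_ALERT. Purely as a reading aid, the snippet below (invented classify helper; bit values copied from the definitions above) shows how an ACK's flag word would be categorised:

#include <stdio.h>

#define FLAG_DATA         0x01
#define FLAG_WIN_UPDATE   0x02
#define FLAG_DATA_ACKED   0x04
#define FLAG_SYN_ACKED    0x10
#define FLAG_DATA_SACKED  0x20
#define FLAG_ECE          0x40
#define FLAG_DSACKING_ACK 0x800

#define FLAG_ACKED    (FLAG_DATA_ACKED | FLAG_SYN_ACKED)
#define FLAG_NOT_DUP  (FLAG_DATA | FLAG_WIN_UPDATE | FLAG_ACKED)
#define FLAG_CA_ALERT (FLAG_DATA_SACKED | FLAG_ECE | FLAG_DSACKING_ACK)

static void classify(unsigned int flag)
{
	printf("flag=%#x dup=%s ca_alert=%s\n", flag,
	       (flag & FLAG_NOT_DUP) ? "no" : "yes",
	       (flag & FLAG_CA_ALERT) ? "yes" : "no");
}

int main(void)
{
	classify(0);                           /* pure duplicate ACK */
	classify(FLAG_DATA_ACKED);             /* ACK advanced snd_una */
	classify(FLAG_ECE | FLAG_DATA_SACKED); /* congestion signal */
	return 0;
}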
+EXPORT_SYMBOL_GPL(clean_acked_data_disable); + +void clean_acked_data_flush(void) +{ + static_key_deferred_flush(&clean_acked_data_enabled); +} +EXPORT_SYMBOL_GPL(clean_acked_data_flush); +#endif + +#ifdef CONFIG_CGROUP_BPF +static void bpf_skops_parse_hdr(struct sock *sk, struct sk_buff *skb) +{ + bool unknown_opt = tcp_sk(sk)->rx_opt.saw_unknown && + BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk), + BPF_SOCK_OPS_PARSE_UNKNOWN_HDR_OPT_CB_FLAG); + bool parse_all_opt = BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk), + BPF_SOCK_OPS_PARSE_ALL_HDR_OPT_CB_FLAG); + struct bpf_sock_ops_kern sock_ops; + + if (likely(!unknown_opt && !parse_all_opt)) + return; + + /* The skb will be handled in the + * bpf_skops_established() or + * bpf_skops_write_hdr_opt(). + */ + switch (sk->sk_state) { + case TCP_SYN_RECV: + case TCP_SYN_SENT: + case TCP_LISTEN: + return; + } + + sock_owned_by_me(sk); + + memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); + sock_ops.op = BPF_SOCK_OPS_PARSE_HDR_OPT_CB; + sock_ops.is_fullsock = 1; + sock_ops.sk = sk; + bpf_skops_init_skb(&sock_ops, skb, tcp_hdrlen(skb)); + + BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops); +} + +static void bpf_skops_established(struct sock *sk, int bpf_op, + struct sk_buff *skb) +{ + struct bpf_sock_ops_kern sock_ops; + + sock_owned_by_me(sk); + + memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); + sock_ops.op = bpf_op; + sock_ops.is_fullsock = 1; + sock_ops.sk = sk; + /* sk with TCP_REPAIR_ON does not have skb in tcp_finish_connect */ + if (skb) + bpf_skops_init_skb(&sock_ops, skb, tcp_hdrlen(skb)); + + BPF_CGROUP_RUN_PROG_SOCK_OPS(&sock_ops); +} +#else +static void bpf_skops_parse_hdr(struct sock *sk, struct sk_buff *skb) +{ +} + +static void bpf_skops_established(struct sock *sk, int bpf_op, + struct sk_buff *skb) +{ +} +#endif + +static void tcp_gro_dev_warn(struct sock *sk, const struct sk_buff *skb, + unsigned int len) +{ + static bool __once __read_mostly; + + if (!__once) { + struct net_device *dev; + + __once = true; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); + if (!dev || len >= dev->mtu) + pr_warn("%s: Driver has suspect GRO implementation, TCP performance may be compromised.\n", + dev ? dev->name : "Unknown driver"); + rcu_read_unlock(); + } +} + +/* Adapt the MSS value used to make delayed ack decision to the + * real world. + */ +static void tcp_measure_rcv_mss(struct sock *sk, const struct sk_buff *skb) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + const unsigned int lss = icsk->icsk_ack.last_seg_size; + unsigned int len; + + icsk->icsk_ack.last_seg_size = 0; + + /* skb->len may jitter because of SACKs, even if peer + * sends good full-sized frames. + */ + len = skb_shinfo(skb)->gso_size ? : skb->len; + if (len >= icsk->icsk_ack.rcv_mss) { + icsk->icsk_ack.rcv_mss = min_t(unsigned int, len, + tcp_sk(sk)->advmss); + /* Account for possibly-removed options */ + if (unlikely(len > icsk->icsk_ack.rcv_mss + + MAX_TCP_OPTION_SPACE)) + tcp_gro_dev_warn(sk, skb, len); + /* If the skb has a len of exactly 1*MSS and has the PSH bit + * set then it is likely the end of an application write. So + * more data may not be arriving soon, and yet the data sender + * may be waiting for an ACK if cwnd-bound or using TX zero + * copy. So we set ICSK_ACK_PUSHED here so that + * tcp_cleanup_rbuf() will send an ACK immediately if the app + * reads all of the data and is not ping-pong. 
If len > MSS + * then this logic does not matter (and does not hurt) because + * tcp_cleanup_rbuf() will always ACK immediately if the app + * reads data and there is more than an MSS of unACKed data. + */ + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_PSH) + icsk->icsk_ack.pending |= ICSK_ACK_PUSHED; + } else { + /* Otherwise, we make more careful check taking into account, + * that SACKs block is variable. + * + * "len" is invariant segment length, including TCP header. + */ + len += skb->data - skb_transport_header(skb); + if (len >= TCP_MSS_DEFAULT + sizeof(struct tcphdr) || + /* If PSH is not set, packet should be + * full sized, provided peer TCP is not badly broken. + * This observation (if it is correct 8)) allows + * to handle super-low mtu links fairly. + */ + (len >= TCP_MIN_MSS + sizeof(struct tcphdr) && + !(tcp_flag_word(tcp_hdr(skb)) & TCP_REMNANT))) { + /* Subtract also invariant (if peer is RFC compliant), + * tcp header plus fixed timestamp option length. + * Resulting "len" is MSS free of SACK jitter. + */ + len -= tcp_sk(sk)->tcp_header_len; + icsk->icsk_ack.last_seg_size = len; + if (len == lss) { + icsk->icsk_ack.rcv_mss = len; + return; + } + } + if (icsk->icsk_ack.pending & ICSK_ACK_PUSHED) + icsk->icsk_ack.pending |= ICSK_ACK_PUSHED2; + icsk->icsk_ack.pending |= ICSK_ACK_PUSHED; + } +} + +static void tcp_incr_quickack(struct sock *sk, unsigned int max_quickacks) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + unsigned int quickacks = tcp_sk(sk)->rcv_wnd / (2 * icsk->icsk_ack.rcv_mss); + + if (quickacks == 0) + quickacks = 2; + quickacks = min(quickacks, max_quickacks); + if (quickacks > icsk->icsk_ack.quick) + icsk->icsk_ack.quick = quickacks; +} + +static void tcp_enter_quickack_mode(struct sock *sk, unsigned int max_quickacks) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_incr_quickack(sk, max_quickacks); + inet_csk_exit_pingpong_mode(sk); + icsk->icsk_ack.ato = TCP_ATO_MIN; +} + +/* Send ACKs quickly, if "quick" count is not exhausted + * and the session is not interactive. + */ + +static bool tcp_in_quickack_mode(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + const struct dst_entry *dst = __sk_dst_get(sk); + + return (dst && dst_metric(dst, RTAX_QUICKACK)) || + (icsk->icsk_ack.quick && !inet_csk_in_pingpong_mode(sk)); +} + +static void tcp_ecn_queue_cwr(struct tcp_sock *tp) +{ + if (tp->ecn_flags & TCP_ECN_OK) + tp->ecn_flags |= TCP_ECN_QUEUE_CWR; +} + +static void tcp_ecn_accept_cwr(struct sock *sk, const struct sk_buff *skb) +{ + if (tcp_hdr(skb)->cwr) { + tcp_sk(sk)->ecn_flags &= ~TCP_ECN_DEMAND_CWR; + + /* If the sender is telling us it has entered CWR, then its + * cwnd may be very low (even just 1 packet), so we should ACK + * immediately. + */ + if (TCP_SKB_CB(skb)->seq != TCP_SKB_CB(skb)->end_seq) + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; + } +} + +static void tcp_ecn_withdraw_cwr(struct tcp_sock *tp) +{ + tp->ecn_flags &= ~TCP_ECN_QUEUE_CWR; +} + +static void __tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + switch (TCP_SKB_CB(skb)->ip_dsfield & INET_ECN_MASK) { + case INET_ECN_NOT_ECT: + /* Funny extension: if ECT is not set on a segment, + * and we already seen ECT on a previous segment, + * it is probably a retransmit. 
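tcp_incr_quickack() above budgets quick ACKs as half the receive window measured in receive-MSS units, with a floor of 2 and a caller-supplied cap. A tiny standalone rendering of that calculation (invented name):

#include <stdio.h>

/* Number of quick ACKs to grant: rcv_wnd / (2 * rcv_mss), at least 2,
 * never more than max_quickacks.
 */
static unsigned int quickack_budget(unsigned int rcv_wnd, unsigned int rcv_mss,
				    unsigned int max_quickacks)
{
	unsigned int q = rcv_wnd / (2 * rcv_mss);

	if (q == 0)
		q = 2;
	return q < max_quickacks ? q : max_quickacks;
}

int main(void)
{
	printf("%u\n", quickack_budget(65535, 1460, 16)); /* 16 (capped) */
	printf("%u\n", quickack_budget(8192, 1460, 16));  /* 2 */
	printf("%u\n", quickack_budget(1000, 1460, 16));  /* 2 (floor) */
	return 0;
}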
+ */ + if (tp->ecn_flags & TCP_ECN_SEEN) + tcp_enter_quickack_mode(sk, 2); + break; + case INET_ECN_CE: + if (tcp_ca_needs_ecn(sk)) + tcp_ca_event(sk, CA_EVENT_ECN_IS_CE); + + if (!(tp->ecn_flags & TCP_ECN_DEMAND_CWR)) { + /* Better not delay acks, sender can have a very low cwnd */ + tcp_enter_quickack_mode(sk, 2); + tp->ecn_flags |= TCP_ECN_DEMAND_CWR; + } + tp->ecn_flags |= TCP_ECN_SEEN; + break; + default: + if (tcp_ca_needs_ecn(sk)) + tcp_ca_event(sk, CA_EVENT_ECN_NO_CE); + tp->ecn_flags |= TCP_ECN_SEEN; + break; + } +} + +static void tcp_ecn_check_ce(struct sock *sk, const struct sk_buff *skb) +{ + if (tcp_sk(sk)->ecn_flags & TCP_ECN_OK) + __tcp_ecn_check_ce(sk, skb); +} + +static void tcp_ecn_rcv_synack(struct tcp_sock *tp, const struct tcphdr *th) +{ + if ((tp->ecn_flags & TCP_ECN_OK) && (!th->ece || th->cwr)) + tp->ecn_flags &= ~TCP_ECN_OK; +} + +static void tcp_ecn_rcv_syn(struct tcp_sock *tp, const struct tcphdr *th) +{ + if ((tp->ecn_flags & TCP_ECN_OK) && (!th->ece || !th->cwr)) + tp->ecn_flags &= ~TCP_ECN_OK; +} + +static bool tcp_ecn_rcv_ecn_echo(const struct tcp_sock *tp, const struct tcphdr *th) +{ + if (th->ece && !th->syn && (tp->ecn_flags & TCP_ECN_OK)) + return true; + return false; +} + +/* Buffer size and advertised window tuning. + * + * 1. Tuning sk->sk_sndbuf, when connection enters established state. + */ + +static void tcp_sndbuf_expand(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + int sndmem, per_mss; + u32 nr_segs; + + /* Worst case is non GSO/TSO : each frame consumes one skb + * and skb->head is kmalloced using power of two area of memory + */ + per_mss = max_t(u32, tp->rx_opt.mss_clamp, tp->mss_cache) + + MAX_TCP_HEADER + + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + + per_mss = roundup_pow_of_two(per_mss) + + SKB_DATA_ALIGN(sizeof(struct sk_buff)); + + nr_segs = max_t(u32, TCP_INIT_CWND, tcp_snd_cwnd(tp)); + nr_segs = max_t(u32, nr_segs, tp->reordering + 1); + + /* Fast Recovery (RFC 5681 3.2) : + * Cubic needs 1.7 factor, rounded to 2 to include + * extra cushion (application might react slowly to EPOLLOUT) + */ + sndmem = ca_ops->sndbuf_expand ? ca_ops->sndbuf_expand(sk) : 2; + sndmem *= nr_segs * per_mss; + + if (sk->sk_sndbuf < sndmem) + WRITE_ONCE(sk->sk_sndbuf, + min(sndmem, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[2]))); +} + +/* 2. Tuning advertised window (window_clamp, rcv_ssthresh) + * + * All tcp_full_space() is split to two parts: "network" buffer, allocated + * forward and advertised in receiver window (tp->rcv_wnd) and + * "application buffer", required to isolate scheduling/application + * latencies from network. + * window_clamp is maximal advertised window. It can be less than + * tcp_full_space(), in this case tcp_full_space() - window_clamp + * is reserved for "application" buffer. The less window_clamp is + * the smoother our behaviour from viewpoint of network, but the lower + * throughput and the higher sensitivity of the connection to losses. 8) + * + * rcv_ssthresh is more strict window_clamp used at "slow start" + * phase to predict further behaviour of this connection. + * It is used for two goals: + * - to enforce header prediction at sender, even when application + * requires some significant "application buffer". It is check #1. + * - to prevent pruning of receive queue because of misprediction + * of receiver window. Check #2. 
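tcp_sndbuf_expand() above sizes the send buffer for the worst case of one skb per MSS: each segment is charged a power-of-two rounded data area plus per-skb overhead, multiplied by the larger of the initial and current cwnd (and reordering + 1), then by the congestion module's cushion factor (2 by default). A standalone approximation of that arithmetic; the overhead constants below are stand-ins, not the kernel's exact MAX_TCP_HEADER or structure sizes.

#include <stdio.h>

/* Round v up to the next power of two. */
static unsigned int roundup_pow2(unsigned int v)
{
	unsigned int p = 1;

	while (p < v)
		p <<= 1;
	return p;
}

static unsigned int sndbuf_estimate(unsigned int mss, unsigned int cwnd,
				    unsigned int reordering, unsigned int factor)
{
	const unsigned int hdr_overhead = 320; /* stand-in for MAX_TCP_HEADER + shinfo */
	const unsigned int skb_overhead = 256; /* stand-in for aligned sizeof(struct sk_buff) */
	unsigned int per_mss = roundup_pow2(mss + hdr_overhead) + skb_overhead;
	unsigned int nr_segs = cwnd > 10 ? cwnd : 10; /* TCP_INIT_CWND = 10 */

	if (nr_segs < reordering + 1)
		nr_segs = reordering + 1;
	return factor * nr_segs * per_mss;
}

int main(void)
{
	/* ~1460-byte MSS, cwnd of 40, default cushion factor of 2. */
	printf("suggested sndbuf: %u bytes\n", sndbuf_estimate(1460, 40, 3, 2));
	return 0;
}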
+ * + * The scheme does not work when sender sends good segments opening + * window and then starts to feed us spaghetti. But it should work + * in common situations. Otherwise, we have to rely on queue collapsing. + */ + +/* Slow part of check#2. */ +static int __tcp_grow_window(const struct sock *sk, const struct sk_buff *skb, + unsigned int skbtruesize) +{ + struct tcp_sock *tp = tcp_sk(sk); + /* Optimize this! */ + int truesize = tcp_win_from_space(sk, skbtruesize) >> 1; + int window = tcp_win_from_space(sk, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])) >> 1; + + while (tp->rcv_ssthresh <= window) { + if (truesize <= skb->len) + return 2 * inet_csk(sk)->icsk_ack.rcv_mss; + + truesize >>= 1; + window >>= 1; + } + return 0; +} + +/* Even if skb appears to have a bad len/truesize ratio, TCP coalescing + * can play nice with us, as sk_buff and skb->head might be either + * freed or shared with up to MAX_SKB_FRAGS segments. + * Only give a boost to drivers using page frag(s) to hold the frame(s), + * and if no payload was pulled in skb->head before reaching us. + */ +static u32 truesize_adjust(bool adjust, const struct sk_buff *skb) +{ + u32 truesize = skb->truesize; + + if (adjust && !skb_headlen(skb)) { + truesize -= SKB_TRUESIZE(skb_end_offset(skb)); + /* paranoid check, some drivers might be buggy */ + if (unlikely((int)truesize < (int)skb->len)) + truesize = skb->truesize; + } + return truesize; +} + +static void tcp_grow_window(struct sock *sk, const struct sk_buff *skb, + bool adjust) +{ + struct tcp_sock *tp = tcp_sk(sk); + int room; + + room = min_t(int, tp->window_clamp, tcp_space(sk)) - tp->rcv_ssthresh; + + if (room <= 0) + return; + + /* Check #1 */ + if (!tcp_under_memory_pressure(sk)) { + unsigned int truesize = truesize_adjust(adjust, skb); + int incr; + + /* Check #2. Increase window, if skb with such overhead + * will fit to rcvbuf in future. + */ + if (tcp_win_from_space(sk, truesize) <= skb->len) + incr = 2 * tp->advmss; + else + incr = __tcp_grow_window(sk, skb, truesize); + + if (incr) { + incr = max_t(int, incr, 2 * skb->len); + tp->rcv_ssthresh += min(room, incr); + inet_csk(sk)->icsk_ack.quick |= 1; + } + } else { + /* Under pressure: + * Adjust rcv_ssthresh according to reserved mem + */ + tcp_adjust_rcv_ssthresh(sk); + } +} + +/* 3. Try to fixup all. It is made immediately after connection enters + * established state. + */ +static void tcp_init_buffer_space(struct sock *sk) +{ + int tcp_app_win = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_app_win); + struct tcp_sock *tp = tcp_sk(sk); + int maxwin; + + if (!(sk->sk_userlocks & SOCK_SNDBUF_LOCK)) + tcp_sndbuf_expand(sk); + + tcp_mstamp_refresh(tp); + tp->rcvq_space.time = tp->tcp_mstamp; + tp->rcvq_space.seq = tp->copied_seq; + + maxwin = tcp_full_space(sk); + + if (tp->window_clamp >= maxwin) { + tp->window_clamp = maxwin; + + if (tcp_app_win && maxwin > 4 * tp->advmss) + tp->window_clamp = max(maxwin - + (maxwin >> tcp_app_win), + 4 * tp->advmss); + } + + /* Force reservation of one segment. */ + if (tcp_app_win && + tp->window_clamp > 2 * tp->advmss && + tp->window_clamp + tp->advmss > maxwin) + tp->window_clamp = max(2 * tp->advmss, maxwin - tp->advmss); + + tp->rcv_ssthresh = min(tp->rcv_ssthresh, tp->window_clamp); + tp->snd_cwnd_stamp = tcp_jiffies32; + tp->rcvq_space.space = min3(tp->rcv_ssthresh, tp->rcv_wnd, + (u32)TCP_INIT_CWND * tp->advmss); +} + +/* 4. Recalculate window clamp after socket hit its memory bounds. 
*/ +static void tcp_clamp_window(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + struct net *net = sock_net(sk); + int rmem2; + + icsk->icsk_ack.quick = 0; + rmem2 = READ_ONCE(net->ipv4.sysctl_tcp_rmem[2]); + + if (sk->sk_rcvbuf < rmem2 && + !(sk->sk_userlocks & SOCK_RCVBUF_LOCK) && + !tcp_under_memory_pressure(sk) && + sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)) { + WRITE_ONCE(sk->sk_rcvbuf, + min(atomic_read(&sk->sk_rmem_alloc), rmem2)); + } + if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf) + tp->rcv_ssthresh = min(tp->window_clamp, 2U * tp->advmss); +} + +/* Initialize RCV_MSS value. + * RCV_MSS is an our guess about MSS used by the peer. + * We haven't any direct information about the MSS. + * It's better to underestimate the RCV_MSS rather than overestimate. + * Overestimations make us ACKing less frequently than needed. + * Underestimations are more easy to detect and fix by tcp_measure_rcv_mss(). + */ +void tcp_initialize_rcv_mss(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + unsigned int hint = min_t(unsigned int, tp->advmss, tp->mss_cache); + + hint = min(hint, tp->rcv_wnd / 2); + hint = min(hint, TCP_MSS_DEFAULT); + hint = max(hint, TCP_MIN_MSS); + + inet_csk(sk)->icsk_ack.rcv_mss = hint; +} +EXPORT_SYMBOL(tcp_initialize_rcv_mss); + +/* Receiver "autotuning" code. + * + * The algorithm for RTT estimation w/o timestamps is based on + * Dynamic Right-Sizing (DRS) by Wu Feng and Mike Fisk of LANL. + * + * + * More detail on this code can be found at + * , + * though this reference is out of date. A new paper + * is pending. + */ +static void tcp_rcv_rtt_update(struct tcp_sock *tp, u32 sample, int win_dep) +{ + u32 new_sample = tp->rcv_rtt_est.rtt_us; + long m = sample; + + if (new_sample != 0) { + /* If we sample in larger samples in the non-timestamp + * case, we could grossly overestimate the RTT especially + * with chatty applications or bulk transfer apps which + * are stalled on filesystem I/O. + * + * Also, since we are only going for a minimum in the + * non-timestamp case, we do not smooth things out + * else with timestamps disabled convergence takes too + * long. + */ + if (!win_dep) { + m -= (new_sample >> 3); + new_sample += m; + } else { + m <<= 3; + if (m < new_sample) + new_sample = m; + } + } else { + /* No previous measure. 
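tcp_initialize_rcv_mss() above deliberately under-guesses the peer's MSS by clamping it from several directions. A sketch of those clamps; the 536/88 byte bounds passed in below are the usual TCP_MSS_DEFAULT / TCP_MIN_MSS values but are treated as plain parameters here:

#include <stdio.h>

/* Initial guess of the peer's MSS: never more than what we advertise,
 * our own mss_cache, half the receive window or the protocol default,
 * and never less than the minimum MSS.
 */
static unsigned int initial_rcv_mss(unsigned int advmss, unsigned int mss_cache,
                                    unsigned int rcv_wnd,
                                    unsigned int mss_default, unsigned int mss_min)
{
        unsigned int hint = advmss < mss_cache ? advmss : mss_cache;

        if (hint > rcv_wnd / 2)
                hint = rcv_wnd / 2;
        if (hint > mss_default)
                hint = mss_default;
        if (hint < mss_min)
                hint = mss_min;
        return hint;
}

int main(void)
{
        /* prints 536: underestimating keeps the ACK clock conservative */
        printf("rcv_mss hint = %u\n",
               initial_rcv_mss(1460, 1460, 65535, 536, 88));
        return 0;
}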
*/ + new_sample = m << 3; + } + + tp->rcv_rtt_est.rtt_us = new_sample; +} + +static inline void tcp_rcv_rtt_measure(struct tcp_sock *tp) +{ + u32 delta_us; + + if (tp->rcv_rtt_est.time == 0) + goto new_measure; + if (before(tp->rcv_nxt, tp->rcv_rtt_est.seq)) + return; + delta_us = tcp_stamp_us_delta(tp->tcp_mstamp, tp->rcv_rtt_est.time); + if (!delta_us) + delta_us = 1; + tcp_rcv_rtt_update(tp, delta_us, 1); + +new_measure: + tp->rcv_rtt_est.seq = tp->rcv_nxt + tp->rcv_wnd; + tp->rcv_rtt_est.time = tp->tcp_mstamp; +} + +static inline void tcp_rcv_rtt_measure_ts(struct sock *sk, + const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->rx_opt.rcv_tsecr == tp->rcv_rtt_last_tsecr) + return; + tp->rcv_rtt_last_tsecr = tp->rx_opt.rcv_tsecr; + + if (TCP_SKB_CB(skb)->end_seq - + TCP_SKB_CB(skb)->seq >= inet_csk(sk)->icsk_ack.rcv_mss) { + u32 delta = tcp_time_stamp(tp) - tp->rx_opt.rcv_tsecr; + u32 delta_us; + + if (likely(delta < INT_MAX / (USEC_PER_SEC / TCP_TS_HZ))) { + if (!delta) + delta = 1; + delta_us = delta * (USEC_PER_SEC / TCP_TS_HZ); + tcp_rcv_rtt_update(tp, delta_us, 0); + } + } +} + +/* + * This function should be called every time data is copied to user space. + * It calculates the appropriate TCP receive buffer space. + */ +void tcp_rcv_space_adjust(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 copied; + int time; + + trace_tcp_rcv_space_adjust(sk); + + tcp_mstamp_refresh(tp); + time = tcp_stamp_us_delta(tp->tcp_mstamp, tp->rcvq_space.time); + if (time < (tp->rcv_rtt_est.rtt_us >> 3) || tp->rcv_rtt_est.rtt_us == 0) + return; + + /* Number of bytes copied to user in last RTT */ + copied = tp->copied_seq - tp->rcvq_space.seq; + if (copied <= tp->rcvq_space.space) + goto new_measure; + + /* A bit of theory : + * copied = bytes received in previous RTT, our base window + * To cope with packet losses, we need a 2x factor + * To cope with slow start, and sender growing its cwin by 100 % + * every RTT, we need a 4x factor, because the ACK we are sending + * now is for the next RTT, not the current one : + * + */ + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf) && + !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { + int rcvmem, rcvbuf; + u64 rcvwin, grow; + + /* minimal window to cope with packet losses, assuming + * steady state. Add some cushion because of small variations. + */ + rcvwin = ((u64)copied << 1) + 16 * tp->advmss; + + /* Accommodate for sender rate increase (eg. slow start) */ + grow = rcvwin * (copied - tp->rcvq_space.space); + do_div(grow, tp->rcvq_space.space); + rcvwin += (grow << 1); + + rcvmem = SKB_TRUESIZE(tp->advmss + MAX_TCP_HEADER); + while (tcp_win_from_space(sk, rcvmem) < tp->advmss) + rcvmem += 128; + + do_div(rcvwin, tp->advmss); + rcvbuf = min_t(u64, rcvwin * rcvmem, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])); + if (rcvbuf > sk->sk_rcvbuf) { + WRITE_ONCE(sk->sk_rcvbuf, rcvbuf); + + /* Make the window clamp follow along. */ + tp->window_clamp = tcp_win_from_space(sk, rcvbuf); + } + } + tp->rcvq_space.space = copied; + +new_measure: + tp->rcvq_space.seq = tp->copied_seq; + tp->rcvq_space.time = tp->tcp_mstamp; +} + +/* There is something which you must keep in mind when you analyze the + * behavior of the tp->ato delayed ack timeout interval. When a + * connection starts up, we want to ack as quickly as possible. The + * problem is that "good" TCP's do slow start at the beginning of data + * transmission. 
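The receive-buffer autotuning in tcp_rcv_space_adjust() above sizes the window from what the application actually consumed in the last RTT: double it for losses, add a cushion, then grow it in proportion to how much faster the sender got. A user-space sketch of that arithmetic, assuming the caller only invokes it when copied exceeds the previous estimate and approximating the per-segment truesize with a fixed 1 KB of overhead:

#include <stdio.h>

static unsigned long long drs_rcvbuf(unsigned long long copied,
                                     unsigned long long prev_space,
                                     unsigned int advmss,
                                     unsigned long long rmem_max)
{
        unsigned long long rcvwin, grow, rcvmem, rcvbuf;

        /* window to cope with losses, plus some cushion for variations */
        rcvwin = 2 * copied + 16ULL * advmss;

        /* accommodate the sender growing its cwnd (e.g. slow start);
         * caller guarantees copied > prev_space, as the kernel does
         */
        grow = rcvwin * (copied - prev_space) / prev_space;
        rcvwin += 2 * grow;

        /* assumed truesize of one advmss-sized segment, incl. overhead */
        rcvmem = advmss + 1024;

        rcvbuf = (rcvwin / advmss) * rcvmem;
        return rcvbuf < rmem_max ? rcvbuf : rmem_max;
}

int main(void)
{
        /* e.g. 1 MB copied last RTT after a 512 KB estimate, MSS 1448 */
        printf("rcvbuf = %llu bytes\n",
               drs_rcvbuf(1 << 20, 512 << 10, 1448, 6291456));
        return 0;
}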
The means that until we send the first few ACK's the + * sender will sit on his end and only queue most of his data, because + * he can only send snd_cwnd unacked packets at any given time. For + * each ACK we send, he increments snd_cwnd and transmits more of his + * queue. -DaveM + */ +static void tcp_event_data_recv(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + u32 now; + + inet_csk_schedule_ack(sk); + + tcp_measure_rcv_mss(sk, skb); + + tcp_rcv_rtt_measure(tp); + + now = tcp_jiffies32; + + if (!icsk->icsk_ack.ato) { + /* The _first_ data packet received, initialize + * delayed ACK engine. + */ + tcp_incr_quickack(sk, TCP_MAX_QUICKACKS); + icsk->icsk_ack.ato = TCP_ATO_MIN; + } else { + int m = now - icsk->icsk_ack.lrcvtime; + + if (m <= TCP_ATO_MIN / 2) { + /* The fastest case is the first. */ + icsk->icsk_ack.ato = (icsk->icsk_ack.ato >> 1) + TCP_ATO_MIN / 2; + } else if (m < icsk->icsk_ack.ato) { + icsk->icsk_ack.ato = (icsk->icsk_ack.ato >> 1) + m; + if (icsk->icsk_ack.ato > icsk->icsk_rto) + icsk->icsk_ack.ato = icsk->icsk_rto; + } else if (m > icsk->icsk_rto) { + /* Too long gap. Apparently sender failed to + * restart window, so that we send ACKs quickly. + */ + tcp_incr_quickack(sk, TCP_MAX_QUICKACKS); + } + } + icsk->icsk_ack.lrcvtime = now; + + tcp_ecn_check_ce(sk, skb); + + if (skb->len >= 128) + tcp_grow_window(sk, skb, true); +} + +/* Called to compute a smoothed rtt estimate. The data fed to this + * routine either comes from timestamps, or from segments that were + * known _not_ to have been retransmitted [see Karn/Partridge + * Proceedings SIGCOMM 87]. The algorithm is from the SIGCOMM 88 + * piece by Van Jacobson. + * NOTE: the next three routines used to be one big routine. + * To save cycles in the RFC 1323 implementation it was better to break + * it up into three procedures. -- erics + */ +static void tcp_rtt_estimator(struct sock *sk, long mrtt_us) +{ + struct tcp_sock *tp = tcp_sk(sk); + long m = mrtt_us; /* RTT */ + u32 srtt = tp->srtt_us; + + /* The following amusing code comes from Jacobson's + * article in SIGCOMM '88. Note that rtt and mdev + * are scaled versions of rtt and mean deviation. + * This is designed to be as fast as possible + * m stands for "measurement". + * + * On a 1990 paper the rto value is changed to: + * RTO = rtt + 4 * mdev + * + * Funny. This algorithm seems to be very broken. + * These formulae increase RTO, when it should be decreased, increase + * too slowly, when it should be increased quickly, decrease too quickly + * etc. I guess in BSD RTO takes ONE value, so that it is absolutely + * does not matter how to _calculate_ it. Seems, it was trap + * that VJ failed to avoid. 8) + */ + if (srtt != 0) { + m -= (srtt >> 3); /* m is now error in rtt est */ + srtt += m; /* rtt = 7/8 rtt + 1/8 new */ + if (m < 0) { + m = -m; /* m is now abs(error) */ + m -= (tp->mdev_us >> 2); /* similar update on mdev */ + /* This is similar to one of Eifel findings. + * Eifel blocks mdev updates when rtt decreases. + * This solution is a bit different: we use finer gain + * for mdev in this case (alpha*beta). + * Like Eifel it also prevents growth of rto, + * but also it limits too fast rto decreases, + * happening in pure Eifel. 
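The delayed-ACK timeout adaptation in tcp_event_data_recv() above halves the previous ato and blends in the observed inter-arrival gap. A sketch of that update; the 40 ms ATO_MIN and 200 ms RTO below are illustrative millisecond values, not taken from this patch:

#include <stdio.h>

#define ATO_MIN 40U

static unsigned int ato_update(unsigned int ato, unsigned int m, unsigned int rto)
{
        if (m <= ATO_MIN / 2)
                return (ato >> 1) + ATO_MIN / 2;        /* fast arrivals */
        if (m < ato) {
                ato = (ato >> 1) + m;                   /* blend in the gap */
                return ato > rto ? rto : ato;
        }
        return ato;     /* very large gap: the kernel re-enters quick-ACK mode */
}

int main(void)
{
        unsigned int ato = ATO_MIN, rto = 200;
        unsigned int gaps[] = { 5, 30, 10, 150 };

        for (unsigned int i = 0; i < sizeof(gaps) / sizeof(gaps[0]); i++) {
                ato = ato_update(ato, gaps[i], rto);
                printf("gap=%3ums -> ato=%ums\n", gaps[i], ato);
        }
        return 0;
}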
+ */ + if (m > 0) + m >>= 3; + } else { + m -= (tp->mdev_us >> 2); /* similar update on mdev */ + } + tp->mdev_us += m; /* mdev = 3/4 mdev + 1/4 new */ + if (tp->mdev_us > tp->mdev_max_us) { + tp->mdev_max_us = tp->mdev_us; + if (tp->mdev_max_us > tp->rttvar_us) + tp->rttvar_us = tp->mdev_max_us; + } + if (after(tp->snd_una, tp->rtt_seq)) { + if (tp->mdev_max_us < tp->rttvar_us) + tp->rttvar_us -= (tp->rttvar_us - tp->mdev_max_us) >> 2; + tp->rtt_seq = tp->snd_nxt; + tp->mdev_max_us = tcp_rto_min_us(sk); + + tcp_bpf_rtt(sk); + } + } else { + /* no previous measure. */ + srtt = m << 3; /* take the measured time to be rtt */ + tp->mdev_us = m << 1; /* make sure rto = 3*rtt */ + tp->rttvar_us = max(tp->mdev_us, tcp_rto_min_us(sk)); + tp->mdev_max_us = tp->rttvar_us; + tp->rtt_seq = tp->snd_nxt; + + tcp_bpf_rtt(sk); + } + tp->srtt_us = max(1U, srtt); +} + +static void tcp_update_pacing_rate(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u64 rate; + + /* set sk_pacing_rate to 200 % of current rate (mss * cwnd / srtt) */ + rate = (u64)tp->mss_cache * ((USEC_PER_SEC / 100) << 3); + + /* current rate is (cwnd * mss) / srtt + * In Slow Start [1], set sk_pacing_rate to 200 % the current rate. + * In Congestion Avoidance phase, set it to 120 % the current rate. + * + * [1] : Normal Slow Start condition is (tp->snd_cwnd < tp->snd_ssthresh) + * If snd_cwnd >= (tp->snd_ssthresh / 2), we are approaching + * end of slow start and should slow down. + */ + if (tcp_snd_cwnd(tp) < tp->snd_ssthresh / 2) + rate *= READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_pacing_ss_ratio); + else + rate *= READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_pacing_ca_ratio); + + rate *= max(tcp_snd_cwnd(tp), tp->packets_out); + + if (likely(tp->srtt_us)) + do_div(rate, tp->srtt_us); + + /* WRITE_ONCE() is needed because sch_fq fetches sk_pacing_rate + * without any lock. We want to make sure compiler wont store + * intermediate values in this location. + */ + WRITE_ONCE(sk->sk_pacing_rate, min_t(u64, rate, + sk->sk_max_pacing_rate)); +} + +/* Calculate rto without backoff. This is the second half of Van Jacobson's + * routine referred to above. + */ +static void tcp_set_rto(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + /* Old crap is replaced with new one. 8) + * + * More seriously: + * 1. If rtt variance happened to be less 50msec, it is hallucination. + * It cannot be less due to utterly erratic ACK generation made + * at least by solaris and freebsd. "Erratic ACKs" has _nothing_ + * to do with delayed acks, because at cwnd>2 true delack timeout + * is invisible. Actually, Linux-2.4 also generates erratic + * ACKs in some circumstances. + */ + inet_csk(sk)->icsk_rto = __tcp_set_rto(tp); + + /* 2. Fixups made earlier cannot be right. + * If we do not estimate RTO correctly without them, + * all the algo is pure shit and should be replaced + * with correct one. It is exactly, which we pretend to do. + */ + + /* NOTE: clamping at TCP_RTO_MIN is not required, current algo + * guarantees that rto is higher. + */ + tcp_bound_rto(sk); +} + +__u32 tcp_init_cwnd(const struct tcp_sock *tp, const struct dst_entry *dst) +{ + __u32 cwnd = (dst ? dst_metric(dst, RTAX_INITCWND) : 0); + + if (!cwnd) + cwnd = TCP_INIT_CWND; + return min_t(__u32, cwnd, tp->snd_cwnd_clamp); +} + +struct tcp_sacktag_state { + /* Timestamps for earliest and latest never-retransmitted segment + * that was SACKed. RTO needs the earliest RTT to stay conservative, + * but congestion control should still get an accurate delay signal. 
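tcp_update_pacing_rate() above derives the pacing rate from the current delivery rate cwnd * mss / srtt, inflated by 200 % in slow start and 120 % in congestion avoidance (the defaults documented in the comment for the two pacing-ratio sysctls). The same computation in plain units, with srtt expressed in microseconds:

#include <stdio.h>

static unsigned long long pacing_rate(unsigned int mss, unsigned int cwnd,
                                      unsigned long long srtt_us, int slow_start)
{
        unsigned int ratio = slow_start ? 200 : 120;
        unsigned long long rate = (unsigned long long)mss * cwnd * 1000000ULL;

        rate = rate * ratio / 100;
        return srtt_us ? rate / srtt_us : rate;
}

int main(void)
{
        /* 1448-byte MSS, cwnd 10, 20 ms srtt, still in slow start */
        printf("pacing rate = %llu bytes/sec\n",
               pacing_rate(1448, 10, 20000, 1));
        return 0;
}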
+ */ + u64 first_sackt; + u64 last_sackt; + u32 reord; + u32 sack_delivered; + int flag; + unsigned int mss_now; + struct rate_sample *rate; +}; + +/* Take a notice that peer is sending D-SACKs. Skip update of data delivery + * and spurious retransmission information if this DSACK is unlikely caused by + * sender's action: + * - DSACKed sequence range is larger than maximum receiver's window. + * - Total no. of DSACKed segments exceed the total no. of retransmitted segs. + */ +static u32 tcp_dsack_seen(struct tcp_sock *tp, u32 start_seq, + u32 end_seq, struct tcp_sacktag_state *state) +{ + u32 seq_len, dup_segs = 1; + + if (!before(start_seq, end_seq)) + return 0; + + seq_len = end_seq - start_seq; + /* Dubious DSACK: DSACKed range greater than maximum advertised rwnd */ + if (seq_len > tp->max_window) + return 0; + if (seq_len > tp->mss_cache) + dup_segs = DIV_ROUND_UP(seq_len, tp->mss_cache); + else if (tp->tlp_high_seq && tp->tlp_high_seq == end_seq) + state->flag |= FLAG_DSACK_TLP; + + tp->dsack_dups += dup_segs; + /* Skip the DSACK if dup segs weren't retransmitted by sender */ + if (tp->dsack_dups > tp->total_retrans) + return 0; + + tp->rx_opt.sack_ok |= TCP_DSACK_SEEN; + /* We increase the RACK ordering window in rounds where we receive + * DSACKs that may have been due to reordering causing RACK to trigger + * a spurious fast recovery. Thus RACK ignores DSACKs that happen + * without having seen reordering, or that match TLP probes (TLP + * is timer-driven, not triggered by RACK). + */ + if (tp->reord_seen && !(state->flag & FLAG_DSACK_TLP)) + tp->rack.dsack_seen = 1; + + state->flag |= FLAG_DSACKING_ACK; + /* A spurious retransmission is delivered */ + state->sack_delivered += dup_segs; + + return dup_segs; +} + +/* It's reordering when higher sequence was delivered (i.e. sacked) before + * some lower never-retransmitted sequence ("low_seq"). The maximum reordering + * distance is approximated in full-mss packet distance ("reordering"). + */ +static void tcp_check_sack_reordering(struct sock *sk, const u32 low_seq, + const int ts) +{ + struct tcp_sock *tp = tcp_sk(sk); + const u32 mss = tp->mss_cache; + u32 fack, metric; + + fack = tcp_highest_sack_seq(tp); + if (!before(low_seq, fack)) + return; + + metric = fack - low_seq; + if ((metric > tp->reordering * mss) && mss) { +#if FASTRETRANS_DEBUG > 1 + pr_debug("Disorder%d %d %u f%u s%u rr%d\n", + tp->rx_opt.sack_ok, inet_csk(sk)->icsk_ca_state, + tp->reordering, + 0, + tp->sacked_out, + tp->undo_marker ? tp->undo_retrans : 0); +#endif + tp->reordering = min_t(u32, (metric + mss - 1) / mss, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_max_reordering)); + } + + /* This exciting event is worth to be remembered. 8) */ + tp->reord_seen++; + NET_INC_STATS(sock_net(sk), + ts ? LINUX_MIB_TCPTSREORDER : LINUX_MIB_TCPSACKREORDER); +} + + /* This must be called before lost_out or retrans_out are updated + * on a new loss, because we want to know if all skbs previously + * known to be lost have already been retransmitted, indicating + * that this newly lost skb is our next skb to retransmit. + */ +static void tcp_verify_retransmit_hint(struct tcp_sock *tp, struct sk_buff *skb) +{ + if ((!tp->retransmit_skb_hint && tp->retrans_out >= tp->lost_out) || + (tp->retransmit_skb_hint && + before(TCP_SKB_CB(skb)->seq, + TCP_SKB_CB(tp->retransmit_skb_hint)->seq))) + tp->retransmit_skb_hint = skb; +} + +/* Sum the number of packets on the wire we have marked as lost, and + * notify the congestion control module that the given skb was marked lost. 
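The reordering metric maintained by tcp_check_sack_reordering() above is the displacement between the highest SACKed sequence and a lower, never-retransmitted one, rounded up to whole MSS-sized segments and capped by the max-reordering sysctl. A sketch of that update with illustrative sequence numbers:

#include <stdio.h>

static unsigned int update_reordering(unsigned int reordering,
                                      unsigned int fack, unsigned int low_seq,
                                      unsigned int mss, unsigned int max_reordering)
{
        unsigned int metric = fack - low_seq;

        if (mss && metric > reordering * mss) {
                reordering = (metric + mss - 1) / mss;  /* round up to segments */
                if (reordering > max_reordering)
                        reordering = max_reordering;
        }
        return reordering;
}

int main(void)
{
        /* a segment 20 kB below the highest SACK was just delivered */
        printf("reordering = %u segments\n",
               update_reordering(3, 120000, 100000, 1448, 300));
        return 0;
}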
+ */ +static void tcp_notify_skb_loss_event(struct tcp_sock *tp, const struct sk_buff *skb) +{ + tp->lost += tcp_skb_pcount(skb); +} + +void tcp_mark_skb_lost(struct sock *sk, struct sk_buff *skb) +{ + __u8 sacked = TCP_SKB_CB(skb)->sacked; + struct tcp_sock *tp = tcp_sk(sk); + + if (sacked & TCPCB_SACKED_ACKED) + return; + + tcp_verify_retransmit_hint(tp, skb); + if (sacked & TCPCB_LOST) { + if (sacked & TCPCB_SACKED_RETRANS) { + /* Account for retransmits that are lost again */ + TCP_SKB_CB(skb)->sacked &= ~TCPCB_SACKED_RETRANS; + tp->retrans_out -= tcp_skb_pcount(skb); + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPLOSTRETRANSMIT, + tcp_skb_pcount(skb)); + tcp_notify_skb_loss_event(tp, skb); + } + } else { + tp->lost_out += tcp_skb_pcount(skb); + TCP_SKB_CB(skb)->sacked |= TCPCB_LOST; + tcp_notify_skb_loss_event(tp, skb); + } +} + +/* Updates the delivered and delivered_ce counts */ +static void tcp_count_delivered(struct tcp_sock *tp, u32 delivered, + bool ece_ack) +{ + tp->delivered += delivered; + if (ece_ack) + tp->delivered_ce += delivered; +} + +/* This procedure tags the retransmission queue when SACKs arrive. + * + * We have three tag bits: SACKED(S), RETRANS(R) and LOST(L). + * Packets in queue with these bits set are counted in variables + * sacked_out, retrans_out and lost_out, correspondingly. + * + * Valid combinations are: + * Tag InFlight Description + * 0 1 - orig segment is in flight. + * S 0 - nothing flies, orig reached receiver. + * L 0 - nothing flies, orig lost by net. + * R 2 - both orig and retransmit are in flight. + * L|R 1 - orig is lost, retransmit is in flight. + * S|R 1 - orig reached receiver, retrans is still in flight. + * (L|S|R is logically valid, it could occur when L|R is sacked, + * but it is equivalent to plain S and code short-curcuits it to S. + * L|S is logically invalid, it would mean -1 packet in flight 8)) + * + * These 6 states form finite state machine, controlled by the following events: + * 1. New ACK (+SACK) arrives. (tcp_sacktag_write_queue()) + * 2. Retransmission. (tcp_retransmit_skb(), tcp_xmit_retransmit_queue()) + * 3. Loss detection event of two flavors: + * A. Scoreboard estimator decided the packet is lost. + * A'. Reno "three dupacks" marks head of queue lost. + * B. SACK arrives sacking SND.NXT at the moment, when the + * segment was retransmitted. + * 4. D-SACK added new rule: D-SACK changes any tag to S. + * + * It is pleasant to note, that state diagram turns out to be commutative, + * so that we are allowed not to be bothered by order of our actions, + * when multiple events arrive simultaneously. (see the function below). + * + * Reordering detection. + * -------------------- + * Reordering metric is maximal distance, which a packet can be displaced + * in packet stream. With SACKs we can estimate it: + * + * 1. SACK fills old hole and the corresponding segment was not + * ever retransmitted -> reordering. Alas, we cannot use it + * when segment was retransmitted. + * 2. The last flaw is solved with D-SACK. D-SACK arrives + * for retransmitted and already SACKed segment -> reordering.. + * Both of these heuristics are not used in Loss state, when we cannot + * account for retransmits accurately. + * + * SACK block validation. + * ---------------------- + * + * SACK block range validation checks that the received SACK block fits to + * the expected sequence limits, i.e., it is between SND.UNA and SND.NXT. 
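The SACK block validation that follows leans entirely on wraparound-safe sequence comparisons. They work by interpreting the 32-bit difference as a signed value, so "before" and "after" stay correct across the 2^32 wrap whenever the two sequence numbers are less than 2^31 apart; a minimal standalone version:

#include <stdio.h>
#include <stdint.h>

static int seq_before(uint32_t seq1, uint32_t seq2)
{
        return (int32_t)(seq1 - seq2) < 0;
}

static int seq_after(uint32_t seq1, uint32_t seq2)
{
        return seq_before(seq2, seq1);
}

int main(void)
{
        /* 0xfffffff0 wrapped around to 0x10 and is still "before" it */
        printf("%d %d\n", seq_before(0xfffffff0u, 0x10u), seq_after(0x10u, 0xfffffff0u));
        return 0;
}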
+ * Note that SND.UNA is not included to the range though being valid because + * it means that the receiver is rather inconsistent with itself reporting + * SACK reneging when it should advance SND.UNA. Such SACK block this is + * perfectly valid, however, in light of RFC2018 which explicitly states + * that "SACK block MUST reflect the newest segment. Even if the newest + * segment is going to be discarded ...", not that it looks very clever + * in case of head skb. Due to potentional receiver driven attacks, we + * choose to avoid immediate execution of a walk in write queue due to + * reneging and defer head skb's loss recovery to standard loss recovery + * procedure that will eventually trigger (nothing forbids us doing this). + * + * Implements also blockage to start_seq wrap-around. Problem lies in the + * fact that though start_seq (s) is before end_seq (i.e., not reversed), + * there's no guarantee that it will be before snd_nxt (n). The problem + * happens when start_seq resides between end_seq wrap (e_w) and snd_nxt + * wrap (s_w): + * + * <- outs wnd -> <- wrapzone -> + * u e n u_w e_w s n_w + * | | | | | | | + * |<------------+------+----- TCP seqno space --------------+---------->| + * ...-- <2^31 ->| |<--------... + * ...---- >2^31 ------>| |<--------... + * + * Current code wouldn't be vulnerable but it's better still to discard such + * crazy SACK blocks. Doing this check for start_seq alone closes somewhat + * similar case (end_seq after snd_nxt wrap) as earlier reversed check in + * snd_nxt wrap -> snd_una region will then become "well defined", i.e., + * equal to the ideal case (infinite seqno space without wrap caused issues). + * + * With D-SACK the lower bound is extended to cover sequence space below + * SND.UNA down to undo_marker, which is the last point of interest. Yet + * again, D-SACK block must not to go across snd_una (for the same reason as + * for the normal SACK blocks, explained above). But there all simplicity + * ends, TCP might receive valid D-SACKs below that. As long as they reside + * fully below undo_marker they do not affect behavior in anyway and can + * therefore be safely ignored. In rare cases (which are more or less + * theoretical ones), the D-SACK will nicely cross that boundary due to skb + * fragmentation and packet reordering past skb's retransmission. To consider + * them correctly, the acceptable range must be extended even more though + * the exact amount is rather hard to quantify. However, tp->max_window can + * be used as an exaggerated estimate. + */ +static bool tcp_is_sackblock_valid(struct tcp_sock *tp, bool is_dsack, + u32 start_seq, u32 end_seq) +{ + /* Too far in future, or reversed (interpretation is ambiguous) */ + if (after(end_seq, tp->snd_nxt) || !before(start_seq, end_seq)) + return false; + + /* Nasty start_seq wrap-around check (see comments above) */ + if (!before(start_seq, tp->snd_nxt)) + return false; + + /* In outstanding window? ...This is valid exit for D-SACKs too. + * start_seq == snd_una is non-sensical (see comments above) + */ + if (after(start_seq, tp->snd_una)) + return true; + + if (!is_dsack || !tp->undo_marker) + return false; + + /* ...Then it's D-SACK, and must reside below snd_una completely */ + if (after(end_seq, tp->snd_una)) + return false; + + if (!before(start_seq, tp->undo_marker)) + return true; + + /* Too old */ + if (!after(end_seq, tp->undo_marker)) + return false; + + /* Undo_marker boundary crossing (overestimates a lot). 
Known already: + * start_seq < undo_marker and end_seq >= undo_marker. + */ + return !before(start_seq, end_seq - tp->max_window); +} + +static bool tcp_check_dsack(struct sock *sk, const struct sk_buff *ack_skb, + struct tcp_sack_block_wire *sp, int num_sacks, + u32 prior_snd_una, struct tcp_sacktag_state *state) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 start_seq_0 = get_unaligned_be32(&sp[0].start_seq); + u32 end_seq_0 = get_unaligned_be32(&sp[0].end_seq); + u32 dup_segs; + + if (before(start_seq_0, TCP_SKB_CB(ack_skb)->ack_seq)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDSACKRECV); + } else if (num_sacks > 1) { + u32 end_seq_1 = get_unaligned_be32(&sp[1].end_seq); + u32 start_seq_1 = get_unaligned_be32(&sp[1].start_seq); + + if (after(end_seq_0, end_seq_1) || before(start_seq_0, start_seq_1)) + return false; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDSACKOFORECV); + } else { + return false; + } + + dup_segs = tcp_dsack_seen(tp, start_seq_0, end_seq_0, state); + if (!dup_segs) { /* Skip dubious DSACK */ + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDSACKIGNOREDDUBIOUS); + return false; + } + + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPDSACKRECVSEGS, dup_segs); + + /* D-SACK for already forgotten data... Do dumb counting. */ + if (tp->undo_marker && tp->undo_retrans > 0 && + !after(end_seq_0, prior_snd_una) && + after(end_seq_0, tp->undo_marker)) + tp->undo_retrans = max_t(int, 0, tp->undo_retrans - dup_segs); + + return true; +} + +/* Check if skb is fully within the SACK block. In presence of GSO skbs, + * the incoming SACK may not exactly match but we can find smaller MSS + * aligned portion of it that matches. Therefore we might need to fragment + * which may fail and creates some hassle (caller must handle error case + * returns). + * + * FIXME: this could be merged to shift decision code + */ +static int tcp_match_skb_to_sack(struct sock *sk, struct sk_buff *skb, + u32 start_seq, u32 end_seq) +{ + int err; + bool in_sack; + unsigned int pkt_len; + unsigned int mss; + + in_sack = !after(start_seq, TCP_SKB_CB(skb)->seq) && + !before(end_seq, TCP_SKB_CB(skb)->end_seq); + + if (tcp_skb_pcount(skb) > 1 && !in_sack && + after(TCP_SKB_CB(skb)->end_seq, start_seq)) { + mss = tcp_skb_mss(skb); + in_sack = !after(start_seq, TCP_SKB_CB(skb)->seq); + + if (!in_sack) { + pkt_len = start_seq - TCP_SKB_CB(skb)->seq; + if (pkt_len < mss) + pkt_len = mss; + } else { + pkt_len = end_seq - TCP_SKB_CB(skb)->seq; + if (pkt_len < mss) + return -EINVAL; + } + + /* Round if necessary so that SACKs cover only full MSSes + * and/or the remaining small portion (if present) + */ + if (pkt_len > mss) { + unsigned int new_len = (pkt_len / mss) * mss; + if (!in_sack && new_len < pkt_len) + new_len += mss; + pkt_len = new_len; + } + + if (pkt_len >= skb->len && !in_sack) + return 0; + + err = tcp_fragment(sk, TCP_FRAG_IN_RTX_QUEUE, skb, + pkt_len, mss, GFP_ATOMIC); + if (err < 0) + return err; + } + + return in_sack; +} + +/* Mark the given newly-SACKed range as such, adjusting counters and hints. */ +static u8 tcp_sacktag_one(struct sock *sk, + struct tcp_sacktag_state *state, u8 sacked, + u32 start_seq, u32 end_seq, + int dup_sack, int pcount, + u64 xmit_time) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* Account D-SACK for retransmitted packet. 
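tcp_check_dsack() above applies the RFC 2883 rule for recognising a duplicate-SACK report: the first SACK block is a D-SACK if it lies below the cumulative ACK, or if a second block exists and fully contains the first. A compact sketch of that test:

#include <stdio.h>
#include <stdint.h>

static int seq_before(uint32_t a, uint32_t b) { return (int32_t)(a - b) < 0; }
static int seq_after(uint32_t a, uint32_t b)  { return seq_before(b, a); }

static int is_dsack(uint32_t start0, uint32_t end0,
                    uint32_t start1, uint32_t end1, int have_second,
                    uint32_t ack_seq)
{
        /* first block below the cumulative ACK: already-received data */
        if (seq_before(start0, ack_seq))
                return 1;
        /* first block entirely contained in the second one */
        if (have_second &&
            !seq_after(end0, end1) && !seq_before(start0, start1))
                return 1;
        return 0;
}

int main(void)
{
        /* block [1000,2000) sits below a cumulative ACK of 5000 */
        printf("dsack = %d\n", is_dsack(1000, 2000, 0, 0, 0, 5000));
        return 0;
}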
*/ + if (dup_sack && (sacked & TCPCB_RETRANS)) { + if (tp->undo_marker && tp->undo_retrans > 0 && + after(end_seq, tp->undo_marker)) + tp->undo_retrans = max_t(int, 0, tp->undo_retrans - pcount); + if ((sacked & TCPCB_SACKED_ACKED) && + before(start_seq, state->reord)) + state->reord = start_seq; + } + + /* Nothing to do; acked frame is about to be dropped (was ACKed). */ + if (!after(end_seq, tp->snd_una)) + return sacked; + + if (!(sacked & TCPCB_SACKED_ACKED)) { + tcp_rack_advance(tp, sacked, end_seq, xmit_time); + + if (sacked & TCPCB_SACKED_RETRANS) { + /* If the segment is not tagged as lost, + * we do not clear RETRANS, believing + * that retransmission is still in flight. + */ + if (sacked & TCPCB_LOST) { + sacked &= ~(TCPCB_LOST|TCPCB_SACKED_RETRANS); + tp->lost_out -= pcount; + tp->retrans_out -= pcount; + } + } else { + if (!(sacked & TCPCB_RETRANS)) { + /* New sack for not retransmitted frame, + * which was in hole. It is reordering. + */ + if (before(start_seq, + tcp_highest_sack_seq(tp)) && + before(start_seq, state->reord)) + state->reord = start_seq; + + if (!after(end_seq, tp->high_seq)) + state->flag |= FLAG_ORIG_SACK_ACKED; + if (state->first_sackt == 0) + state->first_sackt = xmit_time; + state->last_sackt = xmit_time; + } + + if (sacked & TCPCB_LOST) { + sacked &= ~TCPCB_LOST; + tp->lost_out -= pcount; + } + } + + sacked |= TCPCB_SACKED_ACKED; + state->flag |= FLAG_DATA_SACKED; + tp->sacked_out += pcount; + /* Out-of-order packets delivered */ + state->sack_delivered += pcount; + + /* Lost marker hint past SACKed? Tweak RFC3517 cnt */ + if (tp->lost_skb_hint && + before(start_seq, TCP_SKB_CB(tp->lost_skb_hint)->seq)) + tp->lost_cnt_hint += pcount; + } + + /* D-SACK. We can detect redundant retransmission in S|R and plain R + * frames and clear it. undo_retrans is decreased above, L|R frames + * are accounted above as well. + */ + if (dup_sack && (sacked & TCPCB_SACKED_RETRANS)) { + sacked &= ~TCPCB_SACKED_RETRANS; + tp->retrans_out -= pcount; + } + + return sacked; +} + +/* Shift newly-SACKed bytes from this skb to the immediately previous + * already-SACKed sk_buff. Mark the newly-SACKed bytes as such. + */ +static bool tcp_shifted_skb(struct sock *sk, struct sk_buff *prev, + struct sk_buff *skb, + struct tcp_sacktag_state *state, + unsigned int pcount, int shifted, int mss, + bool dup_sack) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 start_seq = TCP_SKB_CB(skb)->seq; /* start of newly-SACKed */ + u32 end_seq = start_seq + shifted; /* end of newly-SACKed */ + + BUG_ON(!pcount); + + /* Adjust counters and hints for the newly sacked sequence + * range but discard the return value since prev is already + * marked. We must tag the range first because the seq + * advancement below implicitly advances + * tcp_highest_sack_seq() when skb is highest_sack. + */ + tcp_sacktag_one(sk, state, TCP_SKB_CB(skb)->sacked, + start_seq, end_seq, dup_sack, pcount, + tcp_skb_timestamp_us(skb)); + tcp_rate_skb_delivered(sk, skb, state->rate); + + if (skb == tp->lost_skb_hint) + tp->lost_cnt_hint += pcount; + + TCP_SKB_CB(prev)->end_seq += shifted; + TCP_SKB_CB(skb)->seq += shifted; + + tcp_skb_pcount_add(prev, pcount); + WARN_ON_ONCE(tcp_skb_pcount(skb) < pcount); + tcp_skb_pcount_add(skb, -pcount); + + /* When we're adding to gso_segs == 1, gso_size will be zero, + * in theory this shouldn't be necessary but as long as DSACK + * code can come after this skb later on it's better to keep + * setting gso_size to something. 
+ */ + if (!TCP_SKB_CB(prev)->tcp_gso_size) + TCP_SKB_CB(prev)->tcp_gso_size = mss; + + /* CHECKME: To clear or not to clear? Mimics normal skb currently */ + if (tcp_skb_pcount(skb) <= 1) + TCP_SKB_CB(skb)->tcp_gso_size = 0; + + /* Difference in this won't matter, both ACKed by the same cumul. ACK */ + TCP_SKB_CB(prev)->sacked |= (TCP_SKB_CB(skb)->sacked & TCPCB_EVER_RETRANS); + + if (skb->len > 0) { + BUG_ON(!tcp_skb_pcount(skb)); + NET_INC_STATS(sock_net(sk), LINUX_MIB_SACKSHIFTED); + return false; + } + + /* Whole SKB was eaten :-) */ + + if (skb == tp->retransmit_skb_hint) + tp->retransmit_skb_hint = prev; + if (skb == tp->lost_skb_hint) { + tp->lost_skb_hint = prev; + tp->lost_cnt_hint -= tcp_skb_pcount(prev); + } + + TCP_SKB_CB(prev)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags; + TCP_SKB_CB(prev)->eor = TCP_SKB_CB(skb)->eor; + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + TCP_SKB_CB(prev)->end_seq++; + + if (skb == tcp_highest_sack(sk)) + tcp_advance_highest_sack(sk, skb); + + tcp_skb_collapse_tstamp(prev, skb); + if (unlikely(TCP_SKB_CB(prev)->tx.delivered_mstamp)) + TCP_SKB_CB(prev)->tx.delivered_mstamp = 0; + + tcp_rtx_queue_unlink_and_free(skb, sk); + + NET_INC_STATS(sock_net(sk), LINUX_MIB_SACKMERGED); + + return true; +} + +/* I wish gso_size would have a bit more sane initialization than + * something-or-zero which complicates things + */ +static int tcp_skb_seglen(const struct sk_buff *skb) +{ + return tcp_skb_pcount(skb) == 1 ? skb->len : tcp_skb_mss(skb); +} + +/* Shifting pages past head area doesn't work */ +static int skb_can_shift(const struct sk_buff *skb) +{ + return !skb_headlen(skb) && skb_is_nonlinear(skb); +} + +int tcp_skb_shift(struct sk_buff *to, struct sk_buff *from, + int pcount, int shiftlen) +{ + /* TCP min gso_size is 8 bytes (TCP_MIN_GSO_SIZE) + * Since TCP_SKB_CB(skb)->tcp_gso_segs is 16 bits, we need + * to make sure not storing more than 65535 * 8 bytes per skb, + * even if current MSS is bigger. + */ + if (unlikely(to->len + shiftlen >= 65535 * TCP_MIN_GSO_SIZE)) + return 0; + if (unlikely(tcp_skb_pcount(to) + pcount > 65535)) + return 0; + return skb_shift(to, from, shiftlen); +} + +/* Try collapsing SACK blocks spanning across multiple skbs to a single + * skb. + */ +static struct sk_buff *tcp_shift_skb_data(struct sock *sk, struct sk_buff *skb, + struct tcp_sacktag_state *state, + u32 start_seq, u32 end_seq, + bool dup_sack) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *prev; + int mss; + int pcount = 0; + int len; + int in_sack; + + /* Normally R but no L won't result in plain S */ + if (!dup_sack && + (TCP_SKB_CB(skb)->sacked & (TCPCB_LOST|TCPCB_SACKED_RETRANS)) == TCPCB_SACKED_RETRANS) + goto fallback; + if (!skb_can_shift(skb)) + goto fallback; + /* This frame is about to be dropped (was ACKed). 
*/ + if (!after(TCP_SKB_CB(skb)->end_seq, tp->snd_una)) + goto fallback; + + /* Can only happen with delayed DSACK + discard craziness */ + prev = skb_rb_prev(skb); + if (!prev) + goto fallback; + + if ((TCP_SKB_CB(prev)->sacked & TCPCB_TAGBITS) != TCPCB_SACKED_ACKED) + goto fallback; + + if (!tcp_skb_can_collapse(prev, skb)) + goto fallback; + + in_sack = !after(start_seq, TCP_SKB_CB(skb)->seq) && + !before(end_seq, TCP_SKB_CB(skb)->end_seq); + + if (in_sack) { + len = skb->len; + pcount = tcp_skb_pcount(skb); + mss = tcp_skb_seglen(skb); + + /* TODO: Fix DSACKs to not fragment already SACKed and we can + * drop this restriction as unnecessary + */ + if (mss != tcp_skb_seglen(prev)) + goto fallback; + } else { + if (!after(TCP_SKB_CB(skb)->end_seq, start_seq)) + goto noop; + /* CHECKME: This is non-MSS split case only?, this will + * cause skipped skbs due to advancing loop btw, original + * has that feature too + */ + if (tcp_skb_pcount(skb) <= 1) + goto noop; + + in_sack = !after(start_seq, TCP_SKB_CB(skb)->seq); + if (!in_sack) { + /* TODO: head merge to next could be attempted here + * if (!after(TCP_SKB_CB(skb)->end_seq, end_seq)), + * though it might not be worth of the additional hassle + * + * ...we can probably just fallback to what was done + * previously. We could try merging non-SACKed ones + * as well but it probably isn't going to buy off + * because later SACKs might again split them, and + * it would make skb timestamp tracking considerably + * harder problem. + */ + goto fallback; + } + + len = end_seq - TCP_SKB_CB(skb)->seq; + BUG_ON(len < 0); + BUG_ON(len > skb->len); + + /* MSS boundaries should be honoured or else pcount will + * severely break even though it makes things bit trickier. + * Optimize common case to avoid most of the divides + */ + mss = tcp_skb_mss(skb); + + /* TODO: Fix DSACKs to not fragment already SACKed and we can + * drop this restriction as unnecessary + */ + if (mss != tcp_skb_seglen(prev)) + goto fallback; + + if (len == mss) { + pcount = 1; + } else if (len < mss) { + goto noop; + } else { + pcount = len / mss; + len = pcount * mss; + } + } + + /* tcp_sacktag_one() won't SACK-tag ranges below snd_una */ + if (!after(TCP_SKB_CB(skb)->seq + len, tp->snd_una)) + goto fallback; + + if (!tcp_skb_shift(prev, skb, pcount, len)) + goto fallback; + if (!tcp_shifted_skb(sk, prev, skb, state, pcount, len, mss, dup_sack)) + goto out; + + /* Hole filled allows collapsing with the next as well, this is very + * useful when hole on every nth skb pattern happens + */ + skb = skb_rb_next(prev); + if (!skb) + goto out; + + if (!skb_can_shift(skb) || + ((TCP_SKB_CB(skb)->sacked & TCPCB_TAGBITS) != TCPCB_SACKED_ACKED) || + (mss != tcp_skb_seglen(skb))) + goto out; + + if (!tcp_skb_can_collapse(prev, skb)) + goto out; + len = skb->len; + pcount = tcp_skb_pcount(skb); + if (tcp_skb_shift(prev, skb, pcount, len)) + tcp_shifted_skb(sk, prev, skb, state, pcount, + len, mss, 0); + +out: + return prev; + +noop: + return skb; + +fallback: + NET_INC_STATS(sock_net(sk), LINUX_MIB_SACKSHIFTFALLBACK); + return NULL; +} + +static struct sk_buff *tcp_sacktag_walk(struct sk_buff *skb, struct sock *sk, + struct tcp_sack_block *next_dup, + struct tcp_sacktag_state *state, + u32 start_seq, u32 end_seq, + bool dup_sack_in) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *tmp; + + skb_rbtree_walk_from(skb) { + int in_sack = 0; + bool dup_sack = dup_sack_in; + + /* queue is in-order => we can short-circuit the walk early */ + if (!before(TCP_SKB_CB(skb)->seq, end_seq)) + 
break; + + if (next_dup && + before(TCP_SKB_CB(skb)->seq, next_dup->end_seq)) { + in_sack = tcp_match_skb_to_sack(sk, skb, + next_dup->start_seq, + next_dup->end_seq); + if (in_sack > 0) + dup_sack = true; + } + + /* skb reference here is a bit tricky to get right, since + * shifting can eat and free both this skb and the next, + * so not even _safe variant of the loop is enough. + */ + if (in_sack <= 0) { + tmp = tcp_shift_skb_data(sk, skb, state, + start_seq, end_seq, dup_sack); + if (tmp) { + if (tmp != skb) { + skb = tmp; + continue; + } + + in_sack = 0; + } else { + in_sack = tcp_match_skb_to_sack(sk, skb, + start_seq, + end_seq); + } + } + + if (unlikely(in_sack < 0)) + break; + + if (in_sack) { + TCP_SKB_CB(skb)->sacked = + tcp_sacktag_one(sk, + state, + TCP_SKB_CB(skb)->sacked, + TCP_SKB_CB(skb)->seq, + TCP_SKB_CB(skb)->end_seq, + dup_sack, + tcp_skb_pcount(skb), + tcp_skb_timestamp_us(skb)); + tcp_rate_skb_delivered(sk, skb, state->rate); + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) + list_del_init(&skb->tcp_tsorted_anchor); + + if (!before(TCP_SKB_CB(skb)->seq, + tcp_highest_sack_seq(tp))) + tcp_advance_highest_sack(sk, skb); + } + } + return skb; +} + +static struct sk_buff *tcp_sacktag_bsearch(struct sock *sk, u32 seq) +{ + struct rb_node *parent, **p = &sk->tcp_rtx_queue.rb_node; + struct sk_buff *skb; + + while (*p) { + parent = *p; + skb = rb_to_skb(parent); + if (before(seq, TCP_SKB_CB(skb)->seq)) { + p = &parent->rb_left; + continue; + } + if (!before(seq, TCP_SKB_CB(skb)->end_seq)) { + p = &parent->rb_right; + continue; + } + return skb; + } + return NULL; +} + +static struct sk_buff *tcp_sacktag_skip(struct sk_buff *skb, struct sock *sk, + u32 skip_to_seq) +{ + if (skb && after(TCP_SKB_CB(skb)->seq, skip_to_seq)) + return skb; + + return tcp_sacktag_bsearch(sk, skip_to_seq); +} + +static struct sk_buff *tcp_maybe_skipping_dsack(struct sk_buff *skb, + struct sock *sk, + struct tcp_sack_block *next_dup, + struct tcp_sacktag_state *state, + u32 skip_to_seq) +{ + if (!next_dup) + return skb; + + if (before(next_dup->start_seq, skip_to_seq)) { + skb = tcp_sacktag_skip(skb, sk, next_dup->start_seq); + skb = tcp_sacktag_walk(skb, sk, NULL, state, + next_dup->start_seq, next_dup->end_seq, + 1); + } + + return skb; +} + +static int tcp_sack_cache_ok(const struct tcp_sock *tp, const struct tcp_sack_block *cache) +{ + return cache < tp->recv_sack_cache + ARRAY_SIZE(tp->recv_sack_cache); +} + +static int +tcp_sacktag_write_queue(struct sock *sk, const struct sk_buff *ack_skb, + u32 prior_snd_una, struct tcp_sacktag_state *state) +{ + struct tcp_sock *tp = tcp_sk(sk); + const unsigned char *ptr = (skb_transport_header(ack_skb) + + TCP_SKB_CB(ack_skb)->sacked); + struct tcp_sack_block_wire *sp_wire = (struct tcp_sack_block_wire *)(ptr+2); + struct tcp_sack_block sp[TCP_NUM_SACKS]; + struct tcp_sack_block *cache; + struct sk_buff *skb; + int num_sacks = min(TCP_NUM_SACKS, (ptr[1] - TCPOLEN_SACK_BASE) >> 3); + int used_sacks; + bool found_dup_sack = false; + int i, j; + int first_sack_index; + + state->flag = 0; + state->reord = tp->snd_nxt; + + if (!tp->sacked_out) + tcp_highest_sack_reset(sk); + + found_dup_sack = tcp_check_dsack(sk, ack_skb, sp_wire, + num_sacks, prior_snd_una, state); + + /* Eliminate too old ACKs, but take into + * account more or less fresh ones, they can + * contain valid SACK info. 
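tcp_sacktag_bsearch() above descends the retransmit-queue rbtree to find the segment whose [seq, end_seq) range covers a target sequence number. The same lookup expressed as a binary search over a sorted array of non-overlapping ranges, purely for illustration:

#include <stdio.h>
#include <stdint.h>

struct seg {
        uint32_t seq;
        uint32_t end_seq;
};

static int seq_before(uint32_t a, uint32_t b) { return (int32_t)(a - b) < 0; }

static const struct seg *find_seg(const struct seg *segs, int n, uint32_t target)
{
        int lo = 0, hi = n - 1;

        while (lo <= hi) {
                int mid = lo + (hi - lo) / 2;

                if (seq_before(target, segs[mid].seq))
                        hi = mid - 1;                   /* go left */
                else if (!seq_before(target, segs[mid].end_seq))
                        lo = mid + 1;                   /* go right */
                else
                        return &segs[mid];              /* seq <= target < end_seq */
        }
        return NULL;
}

int main(void)
{
        const struct seg q[] = { {1000, 2448}, {2448, 3896}, {3896, 5344} };
        const struct seg *s = find_seg(q, 3, 3000);

        if (s)
                printf("3000 falls in [%u, %u)\n", s->seq, s->end_seq);
        return 0;
}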
+ */ + if (before(TCP_SKB_CB(ack_skb)->ack_seq, prior_snd_una - tp->max_window)) + return 0; + + if (!tp->packets_out) + goto out; + + used_sacks = 0; + first_sack_index = 0; + for (i = 0; i < num_sacks; i++) { + bool dup_sack = !i && found_dup_sack; + + sp[used_sacks].start_seq = get_unaligned_be32(&sp_wire[i].start_seq); + sp[used_sacks].end_seq = get_unaligned_be32(&sp_wire[i].end_seq); + + if (!tcp_is_sackblock_valid(tp, dup_sack, + sp[used_sacks].start_seq, + sp[used_sacks].end_seq)) { + int mib_idx; + + if (dup_sack) { + if (!tp->undo_marker) + mib_idx = LINUX_MIB_TCPDSACKIGNOREDNOUNDO; + else + mib_idx = LINUX_MIB_TCPDSACKIGNOREDOLD; + } else { + /* Don't count olds caused by ACK reordering */ + if ((TCP_SKB_CB(ack_skb)->ack_seq != tp->snd_una) && + !after(sp[used_sacks].end_seq, tp->snd_una)) + continue; + mib_idx = LINUX_MIB_TCPSACKDISCARD; + } + + NET_INC_STATS(sock_net(sk), mib_idx); + if (i == 0) + first_sack_index = -1; + continue; + } + + /* Ignore very old stuff early */ + if (!after(sp[used_sacks].end_seq, prior_snd_una)) { + if (i == 0) + first_sack_index = -1; + continue; + } + + used_sacks++; + } + + /* order SACK blocks to allow in order walk of the retrans queue */ + for (i = used_sacks - 1; i > 0; i--) { + for (j = 0; j < i; j++) { + if (after(sp[j].start_seq, sp[j + 1].start_seq)) { + swap(sp[j], sp[j + 1]); + + /* Track where the first SACK block goes to */ + if (j == first_sack_index) + first_sack_index = j + 1; + } + } + } + + state->mss_now = tcp_current_mss(sk); + skb = NULL; + i = 0; + + if (!tp->sacked_out) { + /* It's already past, so skip checking against it */ + cache = tp->recv_sack_cache + ARRAY_SIZE(tp->recv_sack_cache); + } else { + cache = tp->recv_sack_cache; + /* Skip empty blocks in at head of the cache */ + while (tcp_sack_cache_ok(tp, cache) && !cache->start_seq && + !cache->end_seq) + cache++; + } + + while (i < used_sacks) { + u32 start_seq = sp[i].start_seq; + u32 end_seq = sp[i].end_seq; + bool dup_sack = (found_dup_sack && (i == first_sack_index)); + struct tcp_sack_block *next_dup = NULL; + + if (found_dup_sack && ((i + 1) == first_sack_index)) + next_dup = &sp[i + 1]; + + /* Skip too early cached blocks */ + while (tcp_sack_cache_ok(tp, cache) && + !before(start_seq, cache->end_seq)) + cache++; + + /* Can skip some work by looking recv_sack_cache? */ + if (tcp_sack_cache_ok(tp, cache) && !dup_sack && + after(end_seq, cache->start_seq)) { + + /* Head todo? */ + if (before(start_seq, cache->start_seq)) { + skb = tcp_sacktag_skip(skb, sk, start_seq); + skb = tcp_sacktag_walk(skb, sk, next_dup, + state, + start_seq, + cache->start_seq, + dup_sack); + } + + /* Rest of the block already fully processed? */ + if (!after(end_seq, cache->end_seq)) + goto advance_sp; + + skb = tcp_maybe_skipping_dsack(skb, sk, next_dup, + state, + cache->end_seq); + + /* ...tail remains todo... */ + if (tcp_highest_sack_seq(tp) == cache->end_seq) { + /* ...but better entrypoint exists! 
*/ + skb = tcp_highest_sack(sk); + if (!skb) + break; + cache++; + goto walk; + } + + skb = tcp_sacktag_skip(skb, sk, cache->end_seq); + /* Check overlap against next cached too (past this one already) */ + cache++; + continue; + } + + if (!before(start_seq, tcp_highest_sack_seq(tp))) { + skb = tcp_highest_sack(sk); + if (!skb) + break; + } + skb = tcp_sacktag_skip(skb, sk, start_seq); + +walk: + skb = tcp_sacktag_walk(skb, sk, next_dup, state, + start_seq, end_seq, dup_sack); + +advance_sp: + i++; + } + + /* Clear the head of the cache sack blocks so we can skip it next time */ + for (i = 0; i < ARRAY_SIZE(tp->recv_sack_cache) - used_sacks; i++) { + tp->recv_sack_cache[i].start_seq = 0; + tp->recv_sack_cache[i].end_seq = 0; + } + for (j = 0; j < used_sacks; j++) + tp->recv_sack_cache[i++] = sp[j]; + + if (inet_csk(sk)->icsk_ca_state != TCP_CA_Loss || tp->undo_marker) + tcp_check_sack_reordering(sk, state->reord, 0); + + tcp_verify_left_out(tp); +out: + +#if FASTRETRANS_DEBUG > 0 + WARN_ON((int)tp->sacked_out < 0); + WARN_ON((int)tp->lost_out < 0); + WARN_ON((int)tp->retrans_out < 0); + WARN_ON((int)tcp_packets_in_flight(tp) < 0); +#endif + return state->flag; +} + +/* Limits sacked_out so that sum with lost_out isn't ever larger than + * packets_out. Returns false if sacked_out adjustement wasn't necessary. + */ +static bool tcp_limit_reno_sacked(struct tcp_sock *tp) +{ + u32 holes; + + holes = max(tp->lost_out, 1U); + holes = min(holes, tp->packets_out); + + if ((tp->sacked_out + holes) > tp->packets_out) { + tp->sacked_out = tp->packets_out - holes; + return true; + } + return false; +} + +/* If we receive more dupacks than we expected counting segments + * in assumption of absent reordering, interpret this as reordering. + * The only another reason could be bug in receiver TCP. + */ +static void tcp_check_reno_reordering(struct sock *sk, const int addend) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tcp_limit_reno_sacked(tp)) + return; + + tp->reordering = min_t(u32, tp->packets_out + addend, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_max_reordering)); + tp->reord_seen++; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRENOREORDER); +} + +/* Emulate SACKs for SACKless connection: account for a new dupack. */ + +static void tcp_add_reno_sack(struct sock *sk, int num_dupack, bool ece_ack) +{ + if (num_dupack) { + struct tcp_sock *tp = tcp_sk(sk); + u32 prior_sacked = tp->sacked_out; + s32 delivered; + + tp->sacked_out += num_dupack; + tcp_check_reno_reordering(sk, 0); + delivered = tp->sacked_out - prior_sacked; + if (delivered > 0) + tcp_count_delivered(tp, delivered, ece_ack); + tcp_verify_left_out(tp); + } +} + +/* Account for ACK, ACKing some data in Reno Recovery phase. */ + +static void tcp_remove_reno_sacks(struct sock *sk, int acked, bool ece_ack) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (acked > 0) { + /* One ACK acked hole. The rest eat duplicate ACKs. 
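For SACKless peers, tcp_limit_reno_sacked() above keeps the dupACK-emulated bookkeeping consistent: the inferred sacked_out plus the assumed holes must never exceed the packets actually in flight. A sketch of that clamp:

#include <stdio.h>

static unsigned int limit_reno_sacked(unsigned int sacked_out,
                                      unsigned int lost_out,
                                      unsigned int packets_out)
{
        unsigned int holes = lost_out ? lost_out : 1;   /* assume >= 1 hole */

        if (holes > packets_out)
                holes = packets_out;
        if (sacked_out + holes > packets_out)
                sacked_out = packets_out - holes;
        return sacked_out;
}

int main(void)
{
        /* 10 packets out, 9 dupacks counted, 2 already marked lost */
        printf("sacked_out = %u\n", limit_reno_sacked(9, 2, 10));
        return 0;
}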
*/ + tcp_count_delivered(tp, max_t(int, acked - tp->sacked_out, 1), + ece_ack); + if (acked - 1 >= tp->sacked_out) + tp->sacked_out = 0; + else + tp->sacked_out -= acked - 1; + } + tcp_check_reno_reordering(sk, acked); + tcp_verify_left_out(tp); +} + +static inline void tcp_reset_reno_sack(struct tcp_sock *tp) +{ + tp->sacked_out = 0; +} + +void tcp_clear_retrans(struct tcp_sock *tp) +{ + tp->retrans_out = 0; + tp->lost_out = 0; + tp->undo_marker = 0; + tp->undo_retrans = -1; + tp->sacked_out = 0; +} + +static inline void tcp_init_undo(struct tcp_sock *tp) +{ + tp->undo_marker = tp->snd_una; + /* Retransmission still in flight may cause DSACKs later. */ + tp->undo_retrans = tp->retrans_out ? : -1; +} + +static bool tcp_is_rack(const struct sock *sk) +{ + return READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_recovery) & + TCP_RACK_LOSS_DETECTION; +} + +/* If we detect SACK reneging, forget all SACK information + * and reset tags completely, otherwise preserve SACKs. If receiver + * dropped its ofo queue, we will know this due to reneging detection. + */ +static void tcp_timeout_mark_lost(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb, *head; + bool is_reneg; /* is receiver reneging on SACKs? */ + + head = tcp_rtx_queue_head(sk); + is_reneg = head && (TCP_SKB_CB(head)->sacked & TCPCB_SACKED_ACKED); + if (is_reneg) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSACKRENEGING); + tp->sacked_out = 0; + /* Mark SACK reneging until we recover from this loss event. */ + tp->is_sack_reneg = 1; + } else if (tcp_is_reno(tp)) { + tcp_reset_reno_sack(tp); + } + + skb = head; + skb_rbtree_walk_from(skb) { + if (is_reneg) + TCP_SKB_CB(skb)->sacked &= ~TCPCB_SACKED_ACKED; + else if (tcp_is_rack(sk) && skb != head && + tcp_rack_skb_timeout(tp, skb, 0) > 0) + continue; /* Don't mark recently sent ones lost yet */ + tcp_mark_skb_lost(sk, skb); + } + tcp_verify_left_out(tp); + tcp_clear_all_retrans_hints(tp); +} + +/* Enter Loss state. */ +void tcp_enter_loss(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + bool new_recovery = icsk->icsk_ca_state < TCP_CA_Recovery; + u8 reordering; + + tcp_timeout_mark_lost(sk); + + /* Reduce ssthresh if it has not yet been made inside this window. */ + if (icsk->icsk_ca_state <= TCP_CA_Disorder || + !after(tp->high_seq, tp->snd_una) || + (icsk->icsk_ca_state == TCP_CA_Loss && !icsk->icsk_retransmits)) { + tp->prior_ssthresh = tcp_current_ssthresh(sk); + tp->prior_cwnd = tcp_snd_cwnd(tp); + tp->snd_ssthresh = icsk->icsk_ca_ops->ssthresh(sk); + tcp_ca_event(sk, CA_EVENT_LOSS); + tcp_init_undo(tp); + } + tcp_snd_cwnd_set(tp, tcp_packets_in_flight(tp) + 1); + tp->snd_cwnd_cnt = 0; + tp->snd_cwnd_stamp = tcp_jiffies32; + + /* Timeout in disordered state after receiving substantial DUPACKs + * suggests that the degree of reordering is over-estimated. + */ + reordering = READ_ONCE(net->ipv4.sysctl_tcp_reordering); + if (icsk->icsk_ca_state <= TCP_CA_Disorder && + tp->sacked_out >= reordering) + tp->reordering = min_t(unsigned int, tp->reordering, + reordering); + + tcp_set_ca_state(sk, TCP_CA_Loss); + tp->high_seq = tp->snd_nxt; + tcp_ecn_queue_cwr(tp); + + /* F-RTO RFC5682 sec 3.1 step 1: retransmit SND.UNA if no previous + * loss recovery is underway except recurring timeout(s) on + * the same SND.UNA (sec 3.2). 
Disable F-RTO on path MTU probing + */ + tp->frto = READ_ONCE(net->ipv4.sysctl_tcp_frto) && + (new_recovery || icsk->icsk_retransmits) && + !inet_csk(sk)->icsk_mtup.probe_size; +} + +/* If ACK arrived pointing to a remembered SACK, it means that our + * remembered SACKs do not reflect real state of receiver i.e. + * receiver _host_ is heavily congested (or buggy). + * + * To avoid big spurious retransmission bursts due to transient SACK + * scoreboard oddities that look like reneging, we give the receiver a + * little time (max(RTT/2, 10ms)) to send us some more ACKs that will + * restore sanity to the SACK scoreboard. If the apparent reneging + * persists until this RTO then we'll clear the SACK scoreboard. + */ +static bool tcp_check_sack_reneging(struct sock *sk, int *ack_flag) +{ + if (*ack_flag & FLAG_SACK_RENEGING && + *ack_flag & FLAG_SND_UNA_ADVANCED) { + struct tcp_sock *tp = tcp_sk(sk); + unsigned long delay = max(usecs_to_jiffies(tp->srtt_us >> 4), + msecs_to_jiffies(10)); + + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + delay, TCP_RTO_MAX); + *ack_flag &= ~FLAG_SET_XMIT_TIMER; + return true; + } + return false; +} + +/* Heurestics to calculate number of duplicate ACKs. There's no dupACKs + * counter when SACK is enabled (without SACK, sacked_out is used for + * that purpose). + * + * With reordering, holes may still be in flight, so RFC3517 recovery + * uses pure sacked_out (total number of SACKed segments) even though + * it violates the RFC that uses duplicate ACKs, often these are equal + * but when e.g. out-of-window ACKs or packet duplication occurs, + * they differ. Since neither occurs due to loss, TCP should really + * ignore them. + */ +static inline int tcp_dupack_heuristics(const struct tcp_sock *tp) +{ + return tp->sacked_out + 1; +} + +/* Linux NewReno/SACK/ECN state machine. + * -------------------------------------- + * + * "Open" Normal state, no dubious events, fast path. + * "Disorder" In all the respects it is "Open", + * but requires a bit more attention. It is entered when + * we see some SACKs or dupacks. It is split of "Open" + * mainly to move some processing from fast path to slow one. + * "CWR" CWND was reduced due to some Congestion Notification event. + * It can be ECN, ICMP source quench, local device congestion. + * "Recovery" CWND was reduced, we are fast-retransmitting. + * "Loss" CWND was reduced due to RTO timeout or SACK reneging. + * + * tcp_fastretrans_alert() is entered: + * - each incoming ACK, if state is not "Open" + * - when arrived ACK is unusual, namely: + * * SACK + * * Duplicate ACK. + * * ECN ECE. + * + * Counting packets in flight is pretty simple. + * + * in_flight = packets_out - left_out + retrans_out + * + * packets_out is SND.NXT-SND.UNA counted in packets. + * + * retrans_out is number of retransmitted segments. + * + * left_out is number of segments left network, but not ACKed yet. + * + * left_out = sacked_out + lost_out + * + * sacked_out: Packets, which arrived to receiver out of order + * and hence not ACKed. With SACKs this number is simply + * amount of SACKed data. Even without SACKs + * it is easy to give pretty reliable estimate of this number, + * counting duplicate ACKs. + * + * lost_out: Packets lost by network. TCP has no explicit + * "loss notification" feedback from network (for now). + * It means that this number can be only _guessed_. + * Actually, it is the heuristics to predict lossage that + * distinguishes different algorithms. + * + * F.e. 
after RTO, when all the queue is considered as lost, + * lost_out = packets_out and in_flight = retrans_out. + * + * Essentially, we have now a few algorithms detecting + * lost packets. + * + * If the receiver supports SACK: + * + * RFC6675/3517: It is the conventional algorithm. A packet is + * considered lost if the number of higher sequence packets + * SACKed is greater than or equal the DUPACK thoreshold + * (reordering). This is implemented in tcp_mark_head_lost and + * tcp_update_scoreboard. + * + * RACK (draft-ietf-tcpm-rack-01): it is a newer algorithm + * (2017-) that checks timing instead of counting DUPACKs. + * Essentially a packet is considered lost if it's not S/ACKed + * after RTT + reordering_window, where both metrics are + * dynamically measured and adjusted. This is implemented in + * tcp_rack_mark_lost. + * + * If the receiver does not support SACK: + * + * NewReno (RFC6582): in Recovery we assume that one segment + * is lost (classic Reno). While we are in Recovery and + * a partial ACK arrives, we assume that one more packet + * is lost (NewReno). This heuristics are the same in NewReno + * and SACK. + * + * Really tricky (and requiring careful tuning) part of algorithm + * is hidden in functions tcp_time_to_recover() and tcp_xmit_retransmit_queue(). + * The first determines the moment _when_ we should reduce CWND and, + * hence, slow down forward transmission. In fact, it determines the moment + * when we decide that hole is caused by loss, rather than by a reorder. + * + * tcp_xmit_retransmit_queue() decides, _what_ we should retransmit to fill + * holes, caused by lost packets. + * + * And the most logically complicated part of algorithm is undo + * heuristics. We detect false retransmits due to both too early + * fast retransmit (reordering) and underestimated RTO, analyzing + * timestamps and D-SACKs. When we detect that some segments were + * retransmitted by mistake and CWND reduction was wrong, we undo + * window reduction and abort recovery phase. This logic is hidden + * inside several functions named tcp_try_undo_. + */ + +/* This function decides, when we should leave Disordered state + * and enter Recovery phase, reducing congestion window. + * + * Main question: may we further continue forward transmission + * with the same cwnd? + */ +static bool tcp_time_to_recover(struct sock *sk, int flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* Trick#1: The loss is proven. */ + if (tp->lost_out) + return true; + + /* Not-A-Trick#2 : Classic rule... */ + if (!tcp_is_rack(sk) && tcp_dupack_heuristics(tp) > tp->reordering) + return true; + + return false; +} + +/* Detect loss in event "A" above by marking head of queue up as lost. + * For RFC3517 SACK, a segment is considered lost if it + * has at least tp->reordering SACKed seqments above it; "packets" refers to + * the maximum SACKed segments to pass before reaching this limit. + */ +static void tcp_mark_head_lost(struct sock *sk, int packets, int mark_head) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + int cnt; + /* Use SACK to deduce losses of new sequences sent during recovery */ + const u32 loss_high = tp->snd_nxt; + + WARN_ON(packets > tp->packets_out); + skb = tp->lost_skb_hint; + if (skb) { + /* Head already handled? */ + if (mark_head && after(TCP_SKB_CB(skb)->seq, tp->snd_una)) + return; + cnt = tp->lost_cnt_hint; + } else { + skb = tcp_rtx_queue_head(sk); + cnt = 0; + } + + skb_rbtree_walk_from(skb) { + /* TODO: do this better */ + /* this is not the most efficient way to do this... 
*/ + tp->lost_skb_hint = skb; + tp->lost_cnt_hint = cnt; + + if (after(TCP_SKB_CB(skb)->end_seq, loss_high)) + break; + + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) + cnt += tcp_skb_pcount(skb); + + if (cnt > packets) + break; + + if (!(TCP_SKB_CB(skb)->sacked & TCPCB_LOST)) + tcp_mark_skb_lost(sk, skb); + + if (mark_head) + break; + } + tcp_verify_left_out(tp); +} + +/* Account newly detected lost packet(s) */ + +static void tcp_update_scoreboard(struct sock *sk, int fast_rexmit) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tcp_is_sack(tp)) { + int sacked_upto = tp->sacked_out - tp->reordering; + if (sacked_upto >= 0) + tcp_mark_head_lost(sk, sacked_upto, 0); + else if (fast_rexmit) + tcp_mark_head_lost(sk, 1, 1); + } +} + +static bool tcp_tsopt_ecr_before(const struct tcp_sock *tp, u32 when) +{ + return tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr && + before(tp->rx_opt.rcv_tsecr, when); +} + +/* skb is spurious retransmitted if the returned timestamp echo + * reply is prior to the skb transmission time + */ +static bool tcp_skb_spurious_retrans(const struct tcp_sock *tp, + const struct sk_buff *skb) +{ + return (TCP_SKB_CB(skb)->sacked & TCPCB_RETRANS) && + tcp_tsopt_ecr_before(tp, tcp_skb_timestamp(skb)); +} + +/* Nothing was retransmitted or returned timestamp is less + * than timestamp of the first retransmission. + */ +static inline bool tcp_packet_delayed(const struct tcp_sock *tp) +{ + return tp->retrans_stamp && + tcp_tsopt_ecr_before(tp, tp->retrans_stamp); +} + +/* Undo procedures. */ + +/* We can clear retrans_stamp when there are no retransmissions in the + * window. It would seem that it is trivially available for us in + * tp->retrans_out, however, that kind of assumptions doesn't consider + * what will happen if errors occur when sending retransmission for the + * second time. ...It could the that such segment has only + * TCPCB_EVER_RETRANS set at the present time. It seems that checking + * the head skb is enough except for some reneging corner cases that + * are not worth the effort. + * + * Main reason for all this complexity is the fact that connection dying + * time now depends on the validity of the retrans_stamp, in particular, + * that successive retransmissions of a segment must not advance + * retrans_stamp under any conditions. 
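+ *
+ * (Editorial aside, not upstream code: the undo logic above leans on the
+ * timestamp test in tcp_packet_delayed()/tcp_tsopt_ecr_before(). A minimal
+ * standalone sketch of that test, assuming 32-bit TCP timestamp clock
+ * arithmetic and using hypothetical names, could look like this:
+ *
+ *	#include <stdbool.h>
+ *	#include <stdint.h>
+ *
+ *	// Wrap-safe "a is earlier than b" on a 32-bit timestamp clock.
+ *	static bool ts_before(uint32_t a, uint32_t b)
+ *	{
+ *		return (int32_t)(a - b) < 0;
+ *	}
+ *
+ *	// If the peer echoes a timestamp older than the first retransmission
+ *	// time, the ACK was triggered by the original transmission, so the
+ *	// retransmit (and the cwnd reduction) was spurious and may be undone.
+ *	static bool packet_delayed(uint32_t rcv_tsecr, uint32_t retrans_stamp)
+ *	{
+ *		return retrans_stamp && ts_before(rcv_tsecr, retrans_stamp);
+ *	}
+ *
+ * The real code additionally requires that a timestamp was actually echoed
+ * on this ACK; see tcp_tsopt_ecr_before() above.)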
+ */ +static bool tcp_any_retrans_done(const struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + + if (tp->retrans_out) + return true; + + skb = tcp_rtx_queue_head(sk); + if (unlikely(skb && TCP_SKB_CB(skb)->sacked & TCPCB_EVER_RETRANS)) + return true; + + return false; +} + +static void DBGUNDO(struct sock *sk, const char *msg) +{ +#if FASTRETRANS_DEBUG > 1 + struct tcp_sock *tp = tcp_sk(sk); + struct inet_sock *inet = inet_sk(sk); + + if (sk->sk_family == AF_INET) { + pr_debug("Undo %s %pI4/%u c%u l%u ss%u/%u p%u\n", + msg, + &inet->inet_daddr, ntohs(inet->inet_dport), + tcp_snd_cwnd(tp), tcp_left_out(tp), + tp->snd_ssthresh, tp->prior_ssthresh, + tp->packets_out); + } +#if IS_ENABLED(CONFIG_IPV6) + else if (sk->sk_family == AF_INET6) { + pr_debug("Undo %s %pI6/%u c%u l%u ss%u/%u p%u\n", + msg, + &sk->sk_v6_daddr, ntohs(inet->inet_dport), + tcp_snd_cwnd(tp), tcp_left_out(tp), + tp->snd_ssthresh, tp->prior_ssthresh, + tp->packets_out); + } +#endif +#endif +} + +static void tcp_undo_cwnd_reduction(struct sock *sk, bool unmark_loss) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (unmark_loss) { + struct sk_buff *skb; + + skb_rbtree_walk(skb, &sk->tcp_rtx_queue) { + TCP_SKB_CB(skb)->sacked &= ~TCPCB_LOST; + } + tp->lost_out = 0; + tcp_clear_all_retrans_hints(tp); + } + + if (tp->prior_ssthresh) { + const struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_snd_cwnd_set(tp, icsk->icsk_ca_ops->undo_cwnd(sk)); + + if (tp->prior_ssthresh > tp->snd_ssthresh) { + tp->snd_ssthresh = tp->prior_ssthresh; + tcp_ecn_withdraw_cwr(tp); + } + } + tp->snd_cwnd_stamp = tcp_jiffies32; + tp->undo_marker = 0; + tp->rack.advanced = 1; /* Force RACK to re-exam losses */ +} + +static inline bool tcp_may_undo(const struct tcp_sock *tp) +{ + return tp->undo_marker && (!tp->undo_retrans || tcp_packet_delayed(tp)); +} + +static bool tcp_is_non_sack_preventing_reopen(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->snd_una == tp->high_seq && tcp_is_reno(tp)) { + /* Hold old state until something *above* high_seq + * is ACKed. For Reno it is MUST to prevent false + * fast retransmits (RFC2582). SACK TCP is safe. */ + if (!tcp_any_retrans_done(sk)) + tp->retrans_stamp = 0; + return true; + } + return false; +} + +/* People celebrate: "We love our President!" */ +static bool tcp_try_undo_recovery(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tcp_may_undo(tp)) { + int mib_idx; + + /* Happy end! We did not retransmit anything + * or our original transmission succeeded. + */ + DBGUNDO(sk, inet_csk(sk)->icsk_ca_state == TCP_CA_Loss ? 
"loss" : "retrans"); + tcp_undo_cwnd_reduction(sk, false); + if (inet_csk(sk)->icsk_ca_state == TCP_CA_Loss) + mib_idx = LINUX_MIB_TCPLOSSUNDO; + else + mib_idx = LINUX_MIB_TCPFULLUNDO; + + NET_INC_STATS(sock_net(sk), mib_idx); + } else if (tp->rack.reo_wnd_persist) { + tp->rack.reo_wnd_persist--; + } + if (tcp_is_non_sack_preventing_reopen(sk)) + return true; + tcp_set_ca_state(sk, TCP_CA_Open); + tp->is_sack_reneg = 0; + return false; +} + +/* Try to undo cwnd reduction, because D-SACKs acked all retransmitted data */ +static bool tcp_try_undo_dsack(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->undo_marker && !tp->undo_retrans) { + tp->rack.reo_wnd_persist = min(TCP_RACK_RECOVERY_THRESH, + tp->rack.reo_wnd_persist + 1); + DBGUNDO(sk, "D-SACK"); + tcp_undo_cwnd_reduction(sk, false); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDSACKUNDO); + return true; + } + return false; +} + +/* Undo during loss recovery after partial ACK or using F-RTO. */ +static bool tcp_try_undo_loss(struct sock *sk, bool frto_undo) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (frto_undo || tcp_may_undo(tp)) { + tcp_undo_cwnd_reduction(sk, true); + + DBGUNDO(sk, "partial loss"); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPLOSSUNDO); + if (frto_undo) + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPSPURIOUSRTOS); + inet_csk(sk)->icsk_retransmits = 0; + if (tcp_is_non_sack_preventing_reopen(sk)) + return true; + if (frto_undo || tcp_is_sack(tp)) { + tcp_set_ca_state(sk, TCP_CA_Open); + tp->is_sack_reneg = 0; + } + return true; + } + return false; +} + +/* The cwnd reduction in CWR and Recovery uses the PRR algorithm in RFC 6937. + * It computes the number of packets to send (sndcnt) based on packets newly + * delivered: + * 1) If the packets in flight is larger than ssthresh, PRR spreads the + * cwnd reductions across a full RTT. + * 2) Otherwise PRR uses packet conservation to send as much as delivered. + * But when SND_UNA is acked without further losses, + * slow starts cwnd up to ssthresh to speed up the recovery. + */ +static void tcp_init_cwnd_reduction(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + tp->high_seq = tp->snd_nxt; + tp->tlp_high_seq = 0; + tp->snd_cwnd_cnt = 0; + tp->prior_cwnd = tcp_snd_cwnd(tp); + tp->prr_delivered = 0; + tp->prr_out = 0; + tp->snd_ssthresh = inet_csk(sk)->icsk_ca_ops->ssthresh(sk); + tcp_ecn_queue_cwr(tp); +} + +void tcp_cwnd_reduction(struct sock *sk, int newly_acked_sacked, int newly_lost, int flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + int sndcnt = 0; + int delta = tp->snd_ssthresh - tcp_packets_in_flight(tp); + + if (newly_acked_sacked <= 0 || WARN_ON_ONCE(!tp->prior_cwnd)) + return; + + tp->prr_delivered += newly_acked_sacked; + if (delta < 0) { + u64 dividend = (u64)tp->snd_ssthresh * tp->prr_delivered + + tp->prior_cwnd - 1; + sndcnt = div_u64(dividend, tp->prior_cwnd) - tp->prr_out; + } else { + sndcnt = max_t(int, tp->prr_delivered - tp->prr_out, + newly_acked_sacked); + if (flag & FLAG_SND_UNA_ADVANCED && !newly_lost) + sndcnt++; + sndcnt = min(delta, sndcnt); + } + /* Force a fast retransmit upon entering fast recovery */ + sndcnt = max(sndcnt, (tp->prr_out ? 
0 : 1)); + tcp_snd_cwnd_set(tp, tcp_packets_in_flight(tp) + sndcnt); +} + +static inline void tcp_end_cwnd_reduction(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (inet_csk(sk)->icsk_ca_ops->cong_control) + return; + + /* Reset cwnd to ssthresh in CWR or Recovery (unless it's undone) */ + if (tp->snd_ssthresh < TCP_INFINITE_SSTHRESH && + (inet_csk(sk)->icsk_ca_state == TCP_CA_CWR || tp->undo_marker)) { + tcp_snd_cwnd_set(tp, tp->snd_ssthresh); + tp->snd_cwnd_stamp = tcp_jiffies32; + } + tcp_ca_event(sk, CA_EVENT_COMPLETE_CWR); +} + +/* Enter CWR state. Disable cwnd undo since congestion is proven with ECN */ +void tcp_enter_cwr(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + tp->prior_ssthresh = 0; + if (inet_csk(sk)->icsk_ca_state < TCP_CA_CWR) { + tp->undo_marker = 0; + tcp_init_cwnd_reduction(sk); + tcp_set_ca_state(sk, TCP_CA_CWR); + } +} +EXPORT_SYMBOL(tcp_enter_cwr); + +static void tcp_try_keep_open(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + int state = TCP_CA_Open; + + if (tcp_left_out(tp) || tcp_any_retrans_done(sk)) + state = TCP_CA_Disorder; + + if (inet_csk(sk)->icsk_ca_state != state) { + tcp_set_ca_state(sk, state); + tp->high_seq = tp->snd_nxt; + } +} + +static void tcp_try_to_open(struct sock *sk, int flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + + tcp_verify_left_out(tp); + + if (!tcp_any_retrans_done(sk)) + tp->retrans_stamp = 0; + + if (flag & FLAG_ECE) + tcp_enter_cwr(sk); + + if (inet_csk(sk)->icsk_ca_state != TCP_CA_CWR) { + tcp_try_keep_open(sk); + } +} + +static void tcp_mtup_probe_failed(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + icsk->icsk_mtup.search_high = icsk->icsk_mtup.probe_size - 1; + icsk->icsk_mtup.probe_size = 0; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMTUPFAIL); +} + +static void tcp_mtup_probe_success(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + u64 val; + + tp->prior_ssthresh = tcp_current_ssthresh(sk); + + val = (u64)tcp_snd_cwnd(tp) * tcp_mss_to_mtu(sk, tp->mss_cache); + do_div(val, icsk->icsk_mtup.probe_size); + DEBUG_NET_WARN_ON_ONCE((u32)val != val); + tcp_snd_cwnd_set(tp, max_t(u32, 1U, val)); + + tp->snd_cwnd_cnt = 0; + tp->snd_cwnd_stamp = tcp_jiffies32; + tp->snd_ssthresh = tcp_current_ssthresh(sk); + + icsk->icsk_mtup.search_low = icsk->icsk_mtup.probe_size; + icsk->icsk_mtup.probe_size = 0; + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMTUPSUCCESS); +} + +/* Do a simple retransmit without using the backoff mechanisms in + * tcp_timer. This is used for path mtu discovery. + * The socket is already locked here. + */ +void tcp_simple_retransmit(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + int mss; + + /* A fastopen SYN request is stored as two separate packets within + * the retransmit queue, this is done by tcp_send_syn_data(). + * As a result simply checking the MSS of the frames in the queue + * will not work for the SYN packet. + * + * Us being here is an indication of a path MTU issue so we can + * assume that the fastopen SYN was lost and just mark all the + * frames in the retransmit queue as lost. We will use an MSS of + * -1 to mark all frames as lost, otherwise compute the current MSS. 
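+ *
+ * (Editorial aside, not upstream code: the "-1" trick works because both
+ * sides of the comparison below are signed, so every positive segment
+ * length tests greater than -1. A minimal sketch with hypothetical values:
+ *
+ *	#include <stdio.h>
+ *
+ *	static int oversized(int seglen, int mss)
+ *	{
+ *		// Mirrors the "seglen > mss" test below: with mss == -1 every
+ *		// segment qualifies, with a real MSS only segments built for
+ *		// the old, larger MSS qualify.
+ *		return seglen > mss;
+ *	}
+ *
+ *	int main(void)
+ *	{
+ *		printf("%d\n", oversized(1460, -1));	// 1: fastopen SYN case
+ *		printf("%d\n", oversized(1460, 1400));	// 1: old MSS too big
+ *		printf("%d\n", oversized(1200, 1400));	// 0: still fits
+ *		return 0;
+ *	}
+ *
+ * That tcp_skb_seglen() yields a signed value is an assumption of this
+ * sketch; the point is only the direction of the comparison.)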
+ */ + if (tp->syn_data && sk->sk_state == TCP_SYN_SENT) + mss = -1; + else + mss = tcp_current_mss(sk); + + skb_rbtree_walk(skb, &sk->tcp_rtx_queue) { + if (tcp_skb_seglen(skb) > mss) + tcp_mark_skb_lost(sk, skb); + } + + tcp_clear_retrans_hints_partial(tp); + + if (!tp->lost_out) + return; + + if (tcp_is_reno(tp)) + tcp_limit_reno_sacked(tp); + + tcp_verify_left_out(tp); + + /* Don't muck with the congestion window here. + * Reason is that we do not increase amount of _data_ + * in network, but units changed and effective + * cwnd/ssthresh really reduced now. + */ + if (icsk->icsk_ca_state != TCP_CA_Loss) { + tp->high_seq = tp->snd_nxt; + tp->snd_ssthresh = tcp_current_ssthresh(sk); + tp->prior_ssthresh = 0; + tp->undo_marker = 0; + tcp_set_ca_state(sk, TCP_CA_Loss); + } + tcp_xmit_retransmit_queue(sk); +} +EXPORT_SYMBOL(tcp_simple_retransmit); + +void tcp_enter_recovery(struct sock *sk, bool ece_ack) +{ + struct tcp_sock *tp = tcp_sk(sk); + int mib_idx; + + if (tcp_is_reno(tp)) + mib_idx = LINUX_MIB_TCPRENORECOVERY; + else + mib_idx = LINUX_MIB_TCPSACKRECOVERY; + + NET_INC_STATS(sock_net(sk), mib_idx); + + tp->prior_ssthresh = 0; + tcp_init_undo(tp); + + if (!tcp_in_cwnd_reduction(sk)) { + if (!ece_ack) + tp->prior_ssthresh = tcp_current_ssthresh(sk); + tcp_init_cwnd_reduction(sk); + } + tcp_set_ca_state(sk, TCP_CA_Recovery); +} + +/* Process an ACK in CA_Loss state. Move to CA_Open if lost data are + * recovered or spurious. Otherwise retransmits more on partial ACKs. + */ +static void tcp_process_loss(struct sock *sk, int flag, int num_dupack, + int *rexmit) +{ + struct tcp_sock *tp = tcp_sk(sk); + bool recovered = !before(tp->snd_una, tp->high_seq); + + if ((flag & FLAG_SND_UNA_ADVANCED || rcu_access_pointer(tp->fastopen_rsk)) && + tcp_try_undo_loss(sk, false)) + return; + + if (tp->frto) { /* F-RTO RFC5682 sec 3.1 (sack enhanced version). */ + /* Step 3.b. A timeout is spurious if not all data are + * lost, i.e., never-retransmitted data are (s)acked. + */ + if ((flag & FLAG_ORIG_SACK_ACKED) && + tcp_try_undo_loss(sk, true)) + return; + + if (after(tp->snd_nxt, tp->high_seq)) { + if (flag & FLAG_DATA_SACKED || num_dupack) + tp->frto = 0; /* Step 3.a. loss was real */ + } else if (flag & FLAG_SND_UNA_ADVANCED && !recovered) { + tp->high_seq = tp->snd_nxt; + /* Step 2.b. Try send new data (but deferred until cwnd + * is updated in tcp_ack()). Otherwise fall back to + * the conventional recovery. + */ + if (!tcp_write_queue_empty(sk) && + after(tcp_wnd_end(tp), tp->snd_nxt)) { + *rexmit = REXMIT_NEW; + return; + } + tp->frto = 0; + } + } + + if (recovered) { + /* F-RTO RFC5682 sec 3.1 step 2.a and 1st part of step 3.a */ + tcp_try_undo_recovery(sk); + return; + } + if (tcp_is_reno(tp)) { + /* A Reno DUPACK means new data in F-RTO step 2.b above are + * delivered. Lower inflight to clock out (re)tranmissions. + */ + if (after(tp->snd_nxt, tp->high_seq) && num_dupack) + tcp_add_reno_sack(sk, num_dupack, flag & FLAG_ECE); + else if (flag & FLAG_SND_UNA_ADVANCED) + tcp_reset_reno_sack(tp); + } + *rexmit = REXMIT_LOST; +} + +static bool tcp_force_fast_retransmit(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + return after(tcp_highest_sack_seq(tp), + tp->snd_una + tp->reordering * tp->mss_cache); +} + +/* Undo during fast recovery after partial ACK. */ +static bool tcp_try_undo_partial(struct sock *sk, u32 prior_snd_una, + bool *do_lost) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->undo_marker && tcp_packet_delayed(tp)) { + /* Plain luck! 
Hole if filled with delayed + * packet, rather than with a retransmit. Check reordering. + */ + tcp_check_sack_reordering(sk, prior_snd_una, 1); + + /* We are getting evidence that the reordering degree is higher + * than we realized. If there are no retransmits out then we + * can undo. Otherwise we clock out new packets but do not + * mark more packets lost or retransmit more. + */ + if (tp->retrans_out) + return true; + + if (!tcp_any_retrans_done(sk)) + tp->retrans_stamp = 0; + + DBGUNDO(sk, "partial recovery"); + tcp_undo_cwnd_reduction(sk, true); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPPARTIALUNDO); + tcp_try_keep_open(sk); + } else { + /* Partial ACK arrived. Force fast retransmit. */ + *do_lost = tcp_force_fast_retransmit(sk); + } + return false; +} + +static void tcp_identify_packet_loss(struct sock *sk, int *ack_flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tcp_rtx_queue_empty(sk)) + return; + + if (unlikely(tcp_is_reno(tp))) { + tcp_newreno_mark_lost(sk, *ack_flag & FLAG_SND_UNA_ADVANCED); + } else if (tcp_is_rack(sk)) { + u32 prior_retrans = tp->retrans_out; + + if (tcp_rack_mark_lost(sk)) + *ack_flag &= ~FLAG_SET_XMIT_TIMER; + if (prior_retrans > tp->retrans_out) + *ack_flag |= FLAG_LOST_RETRANS; + } +} + +/* Process an event, which can update packets-in-flight not trivially. + * Main goal of this function is to calculate new estimate for left_out, + * taking into account both packets sitting in receiver's buffer and + * packets lost by network. + * + * Besides that it updates the congestion state when packet loss or ECN + * is detected. But it does not reduce the cwnd, it is done by the + * congestion control later. + * + * It does _not_ decide what to send, it is made in function + * tcp_xmit_retransmit_queue(). + */ +static void tcp_fastretrans_alert(struct sock *sk, const u32 prior_snd_una, + int num_dupack, int *ack_flag, int *rexmit) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + int fast_rexmit = 0, flag = *ack_flag; + bool ece_ack = flag & FLAG_ECE; + bool do_lost = num_dupack || ((flag & FLAG_DATA_SACKED) && + tcp_force_fast_retransmit(sk)); + + if (!tp->packets_out && tp->sacked_out) + tp->sacked_out = 0; + + /* Now state machine starts. + * A. ECE, hence prohibit cwnd undoing, the reduction is required. */ + if (ece_ack) + tp->prior_ssthresh = 0; + + /* B. In all the states check for reneging SACKs. */ + if (tcp_check_sack_reneging(sk, ack_flag)) + return; + + /* C. Check consistency of the current state. */ + tcp_verify_left_out(tp); + + /* D. Check state exit conditions. State can be terminated + * when high_seq is ACKed. */ + if (icsk->icsk_ca_state == TCP_CA_Open) { + WARN_ON(tp->retrans_out != 0 && !tp->syn_data); + tp->retrans_stamp = 0; + } else if (!before(tp->snd_una, tp->high_seq)) { + switch (icsk->icsk_ca_state) { + case TCP_CA_CWR: + /* CWR is to be held something *above* high_seq + * is ACKed for CWR bit to reach receiver. */ + if (tp->snd_una != tp->high_seq) { + tcp_end_cwnd_reduction(sk); + tcp_set_ca_state(sk, TCP_CA_Open); + } + break; + + case TCP_CA_Recovery: + if (tcp_is_reno(tp)) + tcp_reset_reno_sack(tp); + if (tcp_try_undo_recovery(sk)) + return; + tcp_end_cwnd_reduction(sk); + break; + } + } + + /* E. Process state. 
*/ + switch (icsk->icsk_ca_state) { + case TCP_CA_Recovery: + if (!(flag & FLAG_SND_UNA_ADVANCED)) { + if (tcp_is_reno(tp)) + tcp_add_reno_sack(sk, num_dupack, ece_ack); + } else if (tcp_try_undo_partial(sk, prior_snd_una, &do_lost)) + return; + + if (tcp_try_undo_dsack(sk)) + tcp_try_keep_open(sk); + + tcp_identify_packet_loss(sk, ack_flag); + if (icsk->icsk_ca_state != TCP_CA_Recovery) { + if (!tcp_time_to_recover(sk, flag)) + return; + /* Undo reverts the recovery state. If loss is evident, + * starts a new recovery (e.g. reordering then loss); + */ + tcp_enter_recovery(sk, ece_ack); + } + break; + case TCP_CA_Loss: + tcp_process_loss(sk, flag, num_dupack, rexmit); + tcp_identify_packet_loss(sk, ack_flag); + if (!(icsk->icsk_ca_state == TCP_CA_Open || + (*ack_flag & FLAG_LOST_RETRANS))) + return; + /* Change state if cwnd is undone or retransmits are lost */ + fallthrough; + default: + if (tcp_is_reno(tp)) { + if (flag & FLAG_SND_UNA_ADVANCED) + tcp_reset_reno_sack(tp); + tcp_add_reno_sack(sk, num_dupack, ece_ack); + } + + if (icsk->icsk_ca_state <= TCP_CA_Disorder) + tcp_try_undo_dsack(sk); + + tcp_identify_packet_loss(sk, ack_flag); + if (!tcp_time_to_recover(sk, flag)) { + tcp_try_to_open(sk, flag); + return; + } + + /* MTU probe failure: don't reduce cwnd */ + if (icsk->icsk_ca_state < TCP_CA_CWR && + icsk->icsk_mtup.probe_size && + tp->snd_una == tp->mtu_probe.probe_seq_start) { + tcp_mtup_probe_failed(sk); + /* Restores the reduction we did in tcp_mtup_probe() */ + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + tcp_simple_retransmit(sk); + return; + } + + /* Otherwise enter Recovery state */ + tcp_enter_recovery(sk, ece_ack); + fast_rexmit = 1; + } + + if (!tcp_is_rack(sk) && do_lost) + tcp_update_scoreboard(sk, fast_rexmit); + *rexmit = REXMIT_LOST; +} + +static void tcp_update_rtt_min(struct sock *sk, u32 rtt_us, const int flag) +{ + u32 wlen = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_rtt_wlen) * HZ; + struct tcp_sock *tp = tcp_sk(sk); + + if ((flag & FLAG_ACK_MAYBE_DELAYED) && rtt_us > tcp_min_rtt(tp)) { + /* If the remote keeps returning delayed ACKs, eventually + * the min filter would pick it up and overestimate the + * prop. delay when it expires. Skip suspected delayed ACKs. + */ + return; + } + minmax_running_min(&tp->rtt_min, wlen, tcp_jiffies32, + rtt_us ? : jiffies_to_usecs(1)); +} + +static bool tcp_ack_update_rtt(struct sock *sk, const int flag, + long seq_rtt_us, long sack_rtt_us, + long ca_rtt_us, struct rate_sample *rs) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + /* Prefer RTT measured from ACK's timing to TS-ECR. This is because + * broken middle-boxes or peers may corrupt TS-ECR fields. But + * Karn's algorithm forbids taking RTT if some retransmitted data + * is acked (RFC6298). + */ + if (seq_rtt_us < 0) + seq_rtt_us = sack_rtt_us; + + /* RTTM Rule: A TSecr value received in a segment is used to + * update the averaged RTT measurement only if the segment + * acknowledges some new data, i.e., only if it advances the + * left edge of the send window. + * See draft-ietf-tcplw-high-performance-00, section 3.3. 
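+ *
+ * (Editorial aside, not upstream code: the conversion performed just below
+ * turns "ticks of the TCP timestamp clock" into microseconds, with a floor
+ * of one tick so a same-tick echo still yields a non-zero RTT. A standalone
+ * sketch, assuming a 1000 Hz timestamp clock and hypothetical names:
+ *
+ *	#include <stdint.h>
+ *
+ *	#define TS_HZ		1000u		// assumed timestamp clock rate
+ *	#define USEC_PER_TICK	(1000000u / TS_HZ)
+ *
+ *	// Returns the RTT estimate in microseconds, or -1 if the delta is
+ *	// implausibly large (e.g. a corrupted or very stale echo).
+ *	static int64_t rtt_from_tsecr(uint32_t now_ts, uint32_t rcv_tsecr)
+ *	{
+ *		uint32_t delta = now_ts - rcv_tsecr;	// wrap-safe on u32
+ *
+ *		if (delta >= INT32_MAX / USEC_PER_TICK)
+ *			return -1;
+ *		if (!delta)
+ *			delta = 1;
+ *		return (int64_t)delta * USEC_PER_TICK;
+ *	}
+ *
+ * The real code only takes this path when the ACK acknowledged new data and
+ * no better ACK- or SACK-based sample is available.)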
+ */ + if (seq_rtt_us < 0 && tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr && + flag & FLAG_ACKED) { + u32 delta = tcp_time_stamp(tp) - tp->rx_opt.rcv_tsecr; + + if (likely(delta < INT_MAX / (USEC_PER_SEC / TCP_TS_HZ))) { + if (!delta) + delta = 1; + seq_rtt_us = delta * (USEC_PER_SEC / TCP_TS_HZ); + ca_rtt_us = seq_rtt_us; + } + } + rs->rtt_us = ca_rtt_us; /* RTT of last (S)ACKed packet (or -1) */ + if (seq_rtt_us < 0) + return false; + + /* ca_rtt_us >= 0 is counting on the invariant that ca_rtt_us is + * always taken together with ACK, SACK, or TS-opts. Any negative + * values will be skipped with the seq_rtt_us < 0 check above. + */ + tcp_update_rtt_min(sk, ca_rtt_us, flag); + tcp_rtt_estimator(sk, seq_rtt_us); + tcp_set_rto(sk); + + /* RFC6298: only reset backoff on valid RTT measurement. */ + inet_csk(sk)->icsk_backoff = 0; + return true; +} + +/* Compute time elapsed between (last) SYNACK and the ACK completing 3WHS. */ +void tcp_synack_rtt_meas(struct sock *sk, struct request_sock *req) +{ + struct rate_sample rs; + long rtt_us = -1L; + + if (req && !req->num_retrans && tcp_rsk(req)->snt_synack) + rtt_us = tcp_stamp_us_delta(tcp_clock_us(), tcp_rsk(req)->snt_synack); + + tcp_ack_update_rtt(sk, FLAG_SYN_ACKED, rtt_us, -1L, rtt_us, &rs); +} + + +static void tcp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + icsk->icsk_ca_ops->cong_avoid(sk, ack, acked); + tcp_sk(sk)->snd_cwnd_stamp = tcp_jiffies32; +} + +/* Restart timer after forward progress on connection. + * RFC2988 recommends to restart timer to now+rto. + */ +void tcp_rearm_rto(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + /* If the retrans timer is currently being used by Fast Open + * for SYN-ACK retrans purpose, stay put. + */ + if (rcu_access_pointer(tp->fastopen_rsk)) + return; + + if (!tp->packets_out) { + inet_csk_clear_xmit_timer(sk, ICSK_TIME_RETRANS); + } else { + u32 rto = inet_csk(sk)->icsk_rto; + /* Offset the time elapsed after installing regular RTO */ + if (icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + s64 delta_us = tcp_rto_delta_us(sk); + /* delta_us may not be positive if the socket is locked + * when the retrans timer fires and is rescheduled. + */ + rto = usecs_to_jiffies(max_t(int, delta_us, 1)); + } + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, rto, + TCP_RTO_MAX); + } +} + +/* Try to schedule a loss probe; if that doesn't work, then schedule an RTO. */ +static void tcp_set_xmit_timer(struct sock *sk) +{ + if (!tcp_schedule_loss_probe(sk, true)) + tcp_rearm_rto(sk); +} + +/* If we get here, the whole TSO packet has not been acked. 
*/ +static u32 tcp_tso_acked(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 packets_acked; + + BUG_ON(!after(TCP_SKB_CB(skb)->end_seq, tp->snd_una)); + + packets_acked = tcp_skb_pcount(skb); + if (tcp_trim_head(sk, skb, tp->snd_una - TCP_SKB_CB(skb)->seq)) + return 0; + packets_acked -= tcp_skb_pcount(skb); + + if (packets_acked) { + BUG_ON(tcp_skb_pcount(skb) == 0); + BUG_ON(!before(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq)); + } + + return packets_acked; +} + +static void tcp_ack_tstamp(struct sock *sk, struct sk_buff *skb, + const struct sk_buff *ack_skb, u32 prior_snd_una) +{ + const struct skb_shared_info *shinfo; + + /* Avoid cache line misses to get skb_shinfo() and shinfo->tx_flags */ + if (likely(!TCP_SKB_CB(skb)->txstamp_ack)) + return; + + shinfo = skb_shinfo(skb); + if (!before(shinfo->tskey, prior_snd_una) && + before(shinfo->tskey, tcp_sk(sk)->snd_una)) { + tcp_skb_tsorted_save(skb) { + __skb_tstamp_tx(skb, ack_skb, NULL, sk, SCM_TSTAMP_ACK); + } tcp_skb_tsorted_restore(skb); + } +} + +/* Remove acknowledged frames from the retransmission queue. If our packet + * is before the ack sequence we can discard it as it's confirmed to have + * arrived at the other end. + */ +static int tcp_clean_rtx_queue(struct sock *sk, const struct sk_buff *ack_skb, + u32 prior_fack, u32 prior_snd_una, + struct tcp_sacktag_state *sack, bool ece_ack) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + u64 first_ackt, last_ackt; + struct tcp_sock *tp = tcp_sk(sk); + u32 prior_sacked = tp->sacked_out; + u32 reord = tp->snd_nxt; /* lowest acked un-retx un-sacked seq */ + struct sk_buff *skb, *next; + bool fully_acked = true; + long sack_rtt_us = -1L; + long seq_rtt_us = -1L; + long ca_rtt_us = -1L; + u32 pkts_acked = 0; + bool rtt_update; + int flag = 0; + + first_ackt = 0; + + for (skb = skb_rb_first(&sk->tcp_rtx_queue); skb; skb = next) { + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + const u32 start_seq = scb->seq; + u8 sacked = scb->sacked; + u32 acked_pcount; + + /* Determine how many packets and what bytes were acked, tso and else */ + if (after(scb->end_seq, tp->snd_una)) { + if (tcp_skb_pcount(skb) == 1 || + !after(tp->snd_una, scb->seq)) + break; + + acked_pcount = tcp_tso_acked(sk, skb); + if (!acked_pcount) + break; + fully_acked = false; + } else { + acked_pcount = tcp_skb_pcount(skb); + } + + if (unlikely(sacked & TCPCB_RETRANS)) { + if (sacked & TCPCB_SACKED_RETRANS) + tp->retrans_out -= acked_pcount; + flag |= FLAG_RETRANS_DATA_ACKED; + } else if (!(sacked & TCPCB_SACKED_ACKED)) { + last_ackt = tcp_skb_timestamp_us(skb); + WARN_ON_ONCE(last_ackt == 0); + if (!first_ackt) + first_ackt = last_ackt; + + if (before(start_seq, reord)) + reord = start_seq; + if (!after(scb->end_seq, tp->high_seq)) + flag |= FLAG_ORIG_SACK_ACKED; + } + + if (sacked & TCPCB_SACKED_ACKED) { + tp->sacked_out -= acked_pcount; + } else if (tcp_is_sack(tp)) { + tcp_count_delivered(tp, acked_pcount, ece_ack); + if (!tcp_skb_spurious_retrans(tp, skb)) + tcp_rack_advance(tp, sacked, scb->end_seq, + tcp_skb_timestamp_us(skb)); + } + if (sacked & TCPCB_LOST) + tp->lost_out -= acked_pcount; + + tp->packets_out -= acked_pcount; + pkts_acked += acked_pcount; + tcp_rate_skb_delivered(sk, skb, sack->rate); + + /* Initial outgoing SYN's get put onto the write_queue + * just like anything else we transmit. 
It is not + * true data, and if we misinform our callers that + * this ACK acks real data, we will erroneously exit + * connection startup slow start one packet too + * quickly. This is severely frowned upon behavior. + */ + if (likely(!(scb->tcp_flags & TCPHDR_SYN))) { + flag |= FLAG_DATA_ACKED; + } else { + flag |= FLAG_SYN_ACKED; + tp->retrans_stamp = 0; + } + + if (!fully_acked) + break; + + tcp_ack_tstamp(sk, skb, ack_skb, prior_snd_una); + + next = skb_rb_next(skb); + if (unlikely(skb == tp->retransmit_skb_hint)) + tp->retransmit_skb_hint = NULL; + if (unlikely(skb == tp->lost_skb_hint)) + tp->lost_skb_hint = NULL; + tcp_highest_sack_replace(sk, skb, next); + tcp_rtx_queue_unlink_and_free(skb, sk); + } + + if (!skb) + tcp_chrono_stop(sk, TCP_CHRONO_BUSY); + + if (likely(between(tp->snd_up, prior_snd_una, tp->snd_una))) + tp->snd_up = tp->snd_una; + + if (skb) { + tcp_ack_tstamp(sk, skb, ack_skb, prior_snd_una); + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) + flag |= FLAG_SACK_RENEGING; + } + + if (likely(first_ackt) && !(flag & FLAG_RETRANS_DATA_ACKED)) { + seq_rtt_us = tcp_stamp_us_delta(tp->tcp_mstamp, first_ackt); + ca_rtt_us = tcp_stamp_us_delta(tp->tcp_mstamp, last_ackt); + + if (pkts_acked == 1 && fully_acked && !prior_sacked && + (tp->snd_una - prior_snd_una) < tp->mss_cache && + sack->rate->prior_delivered + 1 == tp->delivered && + !(flag & (FLAG_CA_ALERT | FLAG_SYN_ACKED))) { + /* Conservatively mark a delayed ACK. It's typically + * from a lone runt packet over the round trip to + * a receiver w/o out-of-order or CE events. + */ + flag |= FLAG_ACK_MAYBE_DELAYED; + } + } + if (sack->first_sackt) { + sack_rtt_us = tcp_stamp_us_delta(tp->tcp_mstamp, sack->first_sackt); + ca_rtt_us = tcp_stamp_us_delta(tp->tcp_mstamp, sack->last_sackt); + } + rtt_update = tcp_ack_update_rtt(sk, flag, seq_rtt_us, sack_rtt_us, + ca_rtt_us, sack->rate); + + if (flag & FLAG_ACKED) { + flag |= FLAG_SET_XMIT_TIMER; /* set TLP or RTO timer */ + if (unlikely(icsk->icsk_mtup.probe_size && + !after(tp->mtu_probe.probe_seq_end, tp->snd_una))) { + tcp_mtup_probe_success(sk); + } + + if (tcp_is_reno(tp)) { + tcp_remove_reno_sacks(sk, pkts_acked, ece_ack); + + /* If any of the cumulatively ACKed segments was + * retransmitted, non-SACK case cannot confirm that + * progress was due to original transmission due to + * lack of TCPCB_SACKED_ACKED bits even if some of + * the packets may have been never retransmitted. + */ + if (flag & FLAG_RETRANS_DATA_ACKED) + flag &= ~FLAG_ORIG_SACK_ACKED; + } else { + int delta; + + /* Non-retransmitted hole got filled? That's reordering */ + if (before(reord, prior_fack)) + tcp_check_sack_reordering(sk, reord, 0); + + delta = prior_sacked - tp->sacked_out; + tp->lost_cnt_hint -= min(tp->lost_cnt_hint, delta); + } + } else if (skb && rtt_update && sack_rtt_us >= 0 && + sack_rtt_us > tcp_stamp_us_delta(tp->tcp_mstamp, + tcp_skb_timestamp_us(skb))) { + /* Do not re-arm RTO if the sack RTT is measured from data sent + * after when the head was last (re)transmitted. Otherwise the + * timeout may continue to extend in loss recovery. 
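+ *
+ * (Editorial aside, not upstream code: the guard below boils down to a
+ * comparison between the SACK RTT sample and the time the head of the
+ * retransmit queue has already been waiting. A sketch with hypothetical
+ * microsecond timestamps:
+ *
+ *	#include <stdbool.h>
+ *	#include <stdint.h>
+ *
+ *	// Re-arm only when the sacked data was sent no later than the head's
+ *	// last (re)transmission; otherwise each later SACK would keep pushing
+ *	// the RTO further out while the head is still outstanding.
+ *	static bool may_rearm_from_sack(int64_t sack_rtt_us,
+ *					int64_t now_us, int64_t head_xmit_us)
+ *	{
+ *		int64_t head_waiting_us = now_us - head_xmit_us;
+ *
+ *		return sack_rtt_us >= 0 && sack_rtt_us > head_waiting_us;
+ *	}
+ *
+ * The real condition additionally requires that an RTT was actually taken
+ * on this ACK and that the retransmit queue is non-empty.)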
+ */ + flag |= FLAG_SET_XMIT_TIMER; /* set TLP or RTO timer */ + } + + if (icsk->icsk_ca_ops->pkts_acked) { + struct ack_sample sample = { .pkts_acked = pkts_acked, + .rtt_us = sack->rate->rtt_us }; + + sample.in_flight = tp->mss_cache * + (tp->delivered - sack->rate->prior_delivered); + icsk->icsk_ca_ops->pkts_acked(sk, &sample); + } + +#if FASTRETRANS_DEBUG > 0 + WARN_ON((int)tp->sacked_out < 0); + WARN_ON((int)tp->lost_out < 0); + WARN_ON((int)tp->retrans_out < 0); + if (!tp->packets_out && tcp_is_sack(tp)) { + icsk = inet_csk(sk); + if (tp->lost_out) { + pr_debug("Leak l=%u %d\n", + tp->lost_out, icsk->icsk_ca_state); + tp->lost_out = 0; + } + if (tp->sacked_out) { + pr_debug("Leak s=%u %d\n", + tp->sacked_out, icsk->icsk_ca_state); + tp->sacked_out = 0; + } + if (tp->retrans_out) { + pr_debug("Leak r=%u %d\n", + tp->retrans_out, icsk->icsk_ca_state); + tp->retrans_out = 0; + } + } +#endif + return flag; +} + +static void tcp_ack_probe(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct sk_buff *head = tcp_send_head(sk); + const struct tcp_sock *tp = tcp_sk(sk); + + /* Was it a usable window open? */ + if (!head) + return; + if (!after(TCP_SKB_CB(head)->end_seq, tcp_wnd_end(tp))) { + icsk->icsk_backoff = 0; + icsk->icsk_probes_tstamp = 0; + inet_csk_clear_xmit_timer(sk, ICSK_TIME_PROBE0); + /* Socket must be waked up by subsequent tcp_data_snd_check(). + * This function is not for random using! + */ + } else { + unsigned long when = tcp_probe0_when(sk, TCP_RTO_MAX); + + when = tcp_clamp_probe0_to_user_timeout(sk, when); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, when, TCP_RTO_MAX); + } +} + +static inline bool tcp_ack_is_dubious(const struct sock *sk, const int flag) +{ + return !(flag & FLAG_NOT_DUP) || (flag & FLAG_CA_ALERT) || + inet_csk(sk)->icsk_ca_state != TCP_CA_Open; +} + +/* Decide wheather to run the increase function of congestion control. */ +static inline bool tcp_may_raise_cwnd(const struct sock *sk, const int flag) +{ + /* If reordering is high then always grow cwnd whenever data is + * delivered regardless of its ordering. Otherwise stay conservative + * and only grow cwnd on in-order delivery (RFC5681). A stretched ACK w/ + * new SACK or ECE mark may first advance cwnd here and later reduce + * cwnd in tcp_fastretrans_alert() based on more states. + */ + if (tcp_sk(sk)->reordering > + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reordering)) + return flag & FLAG_FORWARD_PROGRESS; + + return flag & FLAG_DATA_ACKED; +} + +/* The "ultimate" congestion control function that aims to replace the rigid + * cwnd increase and decrease control (tcp_cong_avoid,tcp_*cwnd_reduction). + * It's called toward the end of processing an ACK with precise rate + * information. All transmission or retransmission are delayed afterwards. + */ +static void tcp_cong_control(struct sock *sk, u32 ack, u32 acked_sacked, + int flag, const struct rate_sample *rs) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ca_ops->cong_control) { + icsk->icsk_ca_ops->cong_control(sk, rs); + return; + } + + if (tcp_in_cwnd_reduction(sk)) { + /* Reduce cwnd if state mandates */ + tcp_cwnd_reduction(sk, acked_sacked, rs->losses, flag); + } else if (tcp_may_raise_cwnd(sk, flag)) { + /* Advance cwnd if state allows */ + tcp_cong_avoid(sk, ack, acked_sacked); + } + tcp_update_pacing_rate(sk); +} + +/* Check that window update is acceptable. + * The function assumes that snd_una<=ack<=snd_next. 
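+ *
+ * (Editorial aside, not upstream code: the seq/ack tests in this file rely
+ * on wrap-safe 32-bit comparisons; "after(a, b)" is true when a is ahead of
+ * b on the circular sequence space. A minimal sketch of that helper and of
+ * the acceptance test below, using hypothetical names:
+ *
+ *	#include <stdbool.h>
+ *	#include <stdint.h>
+ *
+ *	static bool seq_after(uint32_t a, uint32_t b)
+ *	{
+ *		return (int32_t)(a - b) > 0;	// works across the 2^32 wrap
+ *	}
+ *
+ *	// RFC 793 style test: the ACK moves snd_una forward, or it carries a
+ *	// newer sequence than the last window update, or it is the same
+ *	// segment but advertises a larger window.
+ *	static bool window_update_ok(uint32_t ack, uint32_t ack_seq,
+ *				     uint32_t snd_una, uint32_t snd_wl1,
+ *				     uint32_t snd_wnd, uint32_t nwin)
+ *	{
+ *		return seq_after(ack, snd_una) ||
+ *		       seq_after(ack_seq, snd_wl1) ||
+ *		       (ack_seq == snd_wl1 && nwin > snd_wnd);
+ *	}
+ *
+ * This mirrors tcp_may_update_window() just below; the sketch exists only
+ * to spell out the wrap-safe comparison trick.)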
+ */ +static inline bool tcp_may_update_window(const struct tcp_sock *tp, + const u32 ack, const u32 ack_seq, + const u32 nwin) +{ + return after(ack, tp->snd_una) || + after(ack_seq, tp->snd_wl1) || + (ack_seq == tp->snd_wl1 && nwin > tp->snd_wnd); +} + +/* If we update tp->snd_una, also update tp->bytes_acked */ +static void tcp_snd_una_update(struct tcp_sock *tp, u32 ack) +{ + u32 delta = ack - tp->snd_una; + + sock_owned_by_me((struct sock *)tp); + tp->bytes_acked += delta; + tp->snd_una = ack; +} + +/* If we update tp->rcv_nxt, also update tp->bytes_received */ +static void tcp_rcv_nxt_update(struct tcp_sock *tp, u32 seq) +{ + u32 delta = seq - tp->rcv_nxt; + + sock_owned_by_me((struct sock *)tp); + tp->bytes_received += delta; + WRITE_ONCE(tp->rcv_nxt, seq); +} + +/* Update our send window. + * + * Window update algorithm, described in RFC793/RFC1122 (used in linux-2.2 + * and in FreeBSD. NetBSD's one is even worse.) is wrong. + */ +static int tcp_ack_update_window(struct sock *sk, const struct sk_buff *skb, u32 ack, + u32 ack_seq) +{ + struct tcp_sock *tp = tcp_sk(sk); + int flag = 0; + u32 nwin = ntohs(tcp_hdr(skb)->window); + + if (likely(!tcp_hdr(skb)->syn)) + nwin <<= tp->rx_opt.snd_wscale; + + if (tcp_may_update_window(tp, ack, ack_seq, nwin)) { + flag |= FLAG_WIN_UPDATE; + tcp_update_wl(tp, ack_seq); + + if (tp->snd_wnd != nwin) { + tp->snd_wnd = nwin; + + /* Note, it is the only place, where + * fast path is recovered for sending TCP. + */ + tp->pred_flags = 0; + tcp_fast_path_check(sk); + + if (!tcp_write_queue_empty(sk)) + tcp_slow_start_after_idle_check(sk); + + if (nwin > tp->max_window) { + tp->max_window = nwin; + tcp_sync_mss(sk, inet_csk(sk)->icsk_pmtu_cookie); + } + } + } + + tcp_snd_una_update(tp, ack); + + return flag; +} + +static bool __tcp_oow_rate_limited(struct net *net, int mib_idx, + u32 *last_oow_ack_time) +{ + /* Paired with the WRITE_ONCE() in this function. */ + u32 val = READ_ONCE(*last_oow_ack_time); + + if (val) { + s32 elapsed = (s32)(tcp_jiffies32 - val); + + if (0 <= elapsed && + elapsed < READ_ONCE(net->ipv4.sysctl_tcp_invalid_ratelimit)) { + NET_INC_STATS(net, mib_idx); + return true; /* rate-limited: don't send yet! */ + } + } + + /* Paired with the prior READ_ONCE() and with itself, + * as we might be lockless. + */ + WRITE_ONCE(*last_oow_ack_time, tcp_jiffies32); + + return false; /* not rate-limited: go ahead, send dupack now! */ +} + +/* Return true if we're currently rate-limiting out-of-window ACKs and + * thus shouldn't send a dupack right now. We rate-limit dupacks in + * response to out-of-window SYNs or ACKs to mitigate ACK loops or DoS + * attacks that send repeated SYNs or ACKs for the same connection. To + * do this, we do not send a duplicate SYNACK or ACK if the remote + * endpoint is sending out-of-window SYNs or pure ACKs at a high rate. + */ +bool tcp_oow_rate_limited(struct net *net, const struct sk_buff *skb, + int mib_idx, u32 *last_oow_ack_time) +{ + /* Data packets without SYNs are not likely part of an ACK loop. */ + if ((TCP_SKB_CB(skb)->seq != TCP_SKB_CB(skb)->end_seq) && + !tcp_hdr(skb)->syn) + return false; + + return __tcp_oow_rate_limited(net, mib_idx, last_oow_ack_time); +} + +/* RFC 5961 7 [ACK Throttling] */ +static void tcp_send_challenge_ack(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + u32 count, now, ack_limit; + + /* First check our per-socket dupack rate limit. 
*/ + if (__tcp_oow_rate_limited(net, + LINUX_MIB_TCPACKSKIPPEDCHALLENGE, + &tp->last_oow_ack_time)) + return; + + ack_limit = READ_ONCE(net->ipv4.sysctl_tcp_challenge_ack_limit); + if (ack_limit == INT_MAX) + goto send_ack; + + /* Then check host-wide RFC 5961 rate limit. */ + now = jiffies / HZ; + if (now != READ_ONCE(net->ipv4.tcp_challenge_timestamp)) { + u32 half = (ack_limit + 1) >> 1; + + WRITE_ONCE(net->ipv4.tcp_challenge_timestamp, now); + WRITE_ONCE(net->ipv4.tcp_challenge_count, half + prandom_u32_max(ack_limit)); + } + count = READ_ONCE(net->ipv4.tcp_challenge_count); + if (count > 0) { + WRITE_ONCE(net->ipv4.tcp_challenge_count, count - 1); +send_ack: + NET_INC_STATS(net, LINUX_MIB_TCPCHALLENGEACK); + tcp_send_ack(sk); + } +} + +static void tcp_store_ts_recent(struct tcp_sock *tp) +{ + tp->rx_opt.ts_recent = tp->rx_opt.rcv_tsval; + tp->rx_opt.ts_recent_stamp = ktime_get_seconds(); +} + +static void tcp_replace_ts_recent(struct tcp_sock *tp, u32 seq) +{ + if (tp->rx_opt.saw_tstamp && !after(seq, tp->rcv_wup)) { + /* PAWS bug workaround wrt. ACK frames, the PAWS discard + * extra check below makes sure this can only happen + * for pure ACK frames. -DaveM + * + * Not only, also it occurs for expired timestamps. + */ + + if (tcp_paws_check(&tp->rx_opt, 0)) + tcp_store_ts_recent(tp); + } +} + +/* This routine deals with acks during a TLP episode and ends an episode by + * resetting tlp_high_seq. Ref: TLP algorithm in draft-ietf-tcpm-rack + */ +static void tcp_process_tlp_ack(struct sock *sk, u32 ack, int flag) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (before(ack, tp->tlp_high_seq)) + return; + + if (!tp->tlp_retrans) { + /* TLP of new data has been acknowledged */ + tp->tlp_high_seq = 0; + } else if (flag & FLAG_DSACK_TLP) { + /* This DSACK means original and TLP probe arrived; no loss */ + tp->tlp_high_seq = 0; + } else if (after(ack, tp->tlp_high_seq)) { + /* ACK advances: there was a loss, so reduce cwnd. Reset + * tlp_high_seq in tcp_init_cwnd_reduction() + */ + tcp_init_cwnd_reduction(sk); + tcp_set_ca_state(sk, TCP_CA_CWR); + tcp_end_cwnd_reduction(sk); + tcp_try_keep_open(sk); + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPLOSSPROBERECOVERY); + } else if (!(flag & (FLAG_SND_UNA_ADVANCED | + FLAG_NOT_DUP | FLAG_DATA_SACKED))) { + /* Pure dupack: original and TLP probe arrived; no loss */ + tp->tlp_high_seq = 0; + } +} + +static inline void tcp_in_ack_event(struct sock *sk, u32 flags) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ca_ops->in_ack_event) + icsk->icsk_ca_ops->in_ack_event(sk, flags); +} + +/* Congestion control has updated the cwnd already. So if we're in + * loss recovery then now we do any new sends (for FRTO) or + * retransmits (for CA_Loss or CA_recovery) that make sense. 
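+ *
+ * (Editorial aside, not upstream code: the "rexmit" hint computed by the ACK
+ * processing above takes one of three values, REXMIT_NONE, REXMIT_NEW or
+ * REXMIT_LOST, and this function is the single place that acts on it. A
+ * paraphrase as a sketch, with all names below hypothetical:
+ *
+ *	#include <stdbool.h>
+ *
+ *	enum rexmit_hint { HINT_NONE, HINT_NEW, HINT_LOST };
+ *
+ *	// HINT_NEW corresponds to F-RTO step 2.b: prefer sending never-sent
+ *	// data; push_new_data() is assumed to return true when that push
+ *	// advanced snd_nxt past high_seq, in which case no retransmit is
+ *	// needed. HINT_LOST resends whatever is marked lost.
+ *	static void act_on_hint(enum rexmit_hint hint,
+ *				bool (*push_new_data)(void),
+ *				void (*rexmit_lost)(void))
+ *	{
+ *		if (hint == HINT_NONE)
+ *			return;
+ *		if (hint == HINT_NEW && push_new_data())
+ *			return;
+ *		rexmit_lost();
+ *	}
+ *
+ * The body below follows this shape, with the extra detail that it also
+ * clears tp->frto when it falls back to retransmitting.)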
+ */ +static void tcp_xmit_recovery(struct sock *sk, int rexmit) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (rexmit == REXMIT_NONE || sk->sk_state == TCP_SYN_SENT) + return; + + if (unlikely(rexmit == REXMIT_NEW)) { + __tcp_push_pending_frames(sk, tcp_current_mss(sk), + TCP_NAGLE_OFF); + if (after(tp->snd_nxt, tp->high_seq)) + return; + tp->frto = 0; + } + tcp_xmit_retransmit_queue(sk); +} + +/* Returns the number of packets newly acked or sacked by the current ACK */ +static u32 tcp_newly_delivered(struct sock *sk, u32 prior_delivered, int flag) +{ + const struct net *net = sock_net(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 delivered; + + delivered = tp->delivered - prior_delivered; + NET_ADD_STATS(net, LINUX_MIB_TCPDELIVERED, delivered); + if (flag & FLAG_ECE) + NET_ADD_STATS(net, LINUX_MIB_TCPDELIVEREDCE, delivered); + + return delivered; +} + +/* This routine deals with incoming acks, but not outgoing ones. */ +static int tcp_ack(struct sock *sk, const struct sk_buff *skb, int flag) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_sacktag_state sack_state; + struct rate_sample rs = { .prior_delivered = 0 }; + u32 prior_snd_una = tp->snd_una; + bool is_sack_reneg = tp->is_sack_reneg; + u32 ack_seq = TCP_SKB_CB(skb)->seq; + u32 ack = TCP_SKB_CB(skb)->ack_seq; + int num_dupack = 0; + int prior_packets = tp->packets_out; + u32 delivered = tp->delivered; + u32 lost = tp->lost; + int rexmit = REXMIT_NONE; /* Flag to (re)transmit to recover losses */ + u32 prior_fack; + + sack_state.first_sackt = 0; + sack_state.rate = &rs; + sack_state.sack_delivered = 0; + + /* We very likely will need to access rtx queue. */ + prefetch(sk->tcp_rtx_queue.rb_node); + + /* If the ack is older than previous acks + * then we can probably ignore it. + */ + if (before(ack, prior_snd_una)) { + u32 max_window; + + /* do not accept ACK for bytes we never sent. */ + max_window = min_t(u64, tp->max_window, tp->bytes_acked); + /* RFC 5961 5.2 [Blind Data Injection Attack].[Mitigation] */ + if (before(ack, prior_snd_una - max_window)) { + if (!(flag & FLAG_NO_CHALLENGE_ACK)) + tcp_send_challenge_ack(sk); + return -SKB_DROP_REASON_TCP_TOO_OLD_ACK; + } + goto old_ack; + } + + /* If the ack includes data we haven't sent yet, discard + * this segment (RFC793 Section 3.9). + */ + if (after(ack, tp->snd_nxt)) + return -SKB_DROP_REASON_TCP_ACK_UNSENT_DATA; + + if (after(ack, prior_snd_una)) { + flag |= FLAG_SND_UNA_ADVANCED; + icsk->icsk_retransmits = 0; + +#if IS_ENABLED(CONFIG_TLS_DEVICE) + if (static_branch_unlikely(&clean_acked_data_enabled.key)) + if (icsk->icsk_clean_acked) + icsk->icsk_clean_acked(sk, ack); +#endif + } + + prior_fack = tcp_is_sack(tp) ? tcp_highest_sack_seq(tp) : tp->snd_una; + rs.prior_in_flight = tcp_packets_in_flight(tp); + + /* ts_recent update must be made after we are sure that the packet + * is in window. + */ + if (flag & FLAG_UPDATE_TS_RECENT) + tcp_replace_ts_recent(tp, TCP_SKB_CB(skb)->seq); + + if ((flag & (FLAG_SLOWPATH | FLAG_SND_UNA_ADVANCED)) == + FLAG_SND_UNA_ADVANCED) { + /* Window is constant, pure forward advance. + * No more checks are required. + * Note, we use the fact that SND.UNA>=SND.WL2. 
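+		 *
+		 * (Editorial aside, not upstream code: the test guarding this
+		 * branch masks two flag bits and compares against one of
+		 * them, i.e. "snd_una advanced AND nothing requested the slow
+		 * path". A tiny sketch with hypothetical flag bits makes the
+		 * idiom explicit:
+		 *
+		 *	#include <stdbool.h>
+		 *
+		 *	#define F_SLOWPATH	0x1
+		 *	#define F_UNA_ADVANCED	0x2
+		 *
+		 *	static bool pure_forward_advance(unsigned int flag)
+		 *	{
+		 *		// true only when F_UNA_ADVANCED is set and
+		 *		// F_SLOWPATH is not
+		 *		return (flag & (F_SLOWPATH | F_UNA_ADVANCED)) ==
+		 *		       F_UNA_ADVANCED;
+		 *	}
+		 *
+		 * If the caller requested full processing, the else branch
+		 * below re-checks the window, SACK blocks and ECE instead.)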
+ */ + tcp_update_wl(tp, ack_seq); + tcp_snd_una_update(tp, ack); + flag |= FLAG_WIN_UPDATE; + + tcp_in_ack_event(sk, CA_ACK_WIN_UPDATE); + + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPHPACKS); + } else { + u32 ack_ev_flags = CA_ACK_SLOWPATH; + + if (ack_seq != TCP_SKB_CB(skb)->end_seq) + flag |= FLAG_DATA; + else + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPPUREACKS); + + flag |= tcp_ack_update_window(sk, skb, ack, ack_seq); + + if (TCP_SKB_CB(skb)->sacked) + flag |= tcp_sacktag_write_queue(sk, skb, prior_snd_una, + &sack_state); + + if (tcp_ecn_rcv_ecn_echo(tp, tcp_hdr(skb))) { + flag |= FLAG_ECE; + ack_ev_flags |= CA_ACK_ECE; + } + + if (sack_state.sack_delivered) + tcp_count_delivered(tp, sack_state.sack_delivered, + flag & FLAG_ECE); + + if (flag & FLAG_WIN_UPDATE) + ack_ev_flags |= CA_ACK_WIN_UPDATE; + + tcp_in_ack_event(sk, ack_ev_flags); + } + + /* This is a deviation from RFC3168 since it states that: + * "When the TCP data sender is ready to set the CWR bit after reducing + * the congestion window, it SHOULD set the CWR bit only on the first + * new data packet that it transmits." + * We accept CWR on pure ACKs to be more robust + * with widely-deployed TCP implementations that do this. + */ + tcp_ecn_accept_cwr(sk, skb); + + /* We passed data and got it acked, remove any soft error + * log. Something worked... + */ + sk->sk_err_soft = 0; + icsk->icsk_probes_out = 0; + tp->rcv_tstamp = tcp_jiffies32; + if (!prior_packets) + goto no_queue; + + /* See if we can take anything off of the retransmit queue. */ + flag |= tcp_clean_rtx_queue(sk, skb, prior_fack, prior_snd_una, + &sack_state, flag & FLAG_ECE); + + tcp_rack_update_reo_wnd(sk, &rs); + + if (tp->tlp_high_seq) + tcp_process_tlp_ack(sk, ack, flag); + + if (tcp_ack_is_dubious(sk, flag)) { + if (!(flag & (FLAG_SND_UNA_ADVANCED | + FLAG_NOT_DUP | FLAG_DSACKING_ACK))) { + num_dupack = 1; + /* Consider if pure acks were aggregated in tcp_add_backlog() */ + if (!(flag & FLAG_DATA)) + num_dupack = max_t(u16, 1, skb_shinfo(skb)->gso_segs); + } + tcp_fastretrans_alert(sk, prior_snd_una, num_dupack, &flag, + &rexmit); + } + + /* If needed, reset TLP/RTO timer when RACK doesn't set. */ + if (flag & FLAG_SET_XMIT_TIMER) + tcp_set_xmit_timer(sk); + + if ((flag & FLAG_FORWARD_PROGRESS) || !(flag & FLAG_NOT_DUP)) + sk_dst_confirm(sk); + + delivered = tcp_newly_delivered(sk, delivered, flag); + lost = tp->lost - lost; /* freshly marked lost */ + rs.is_ack_delayed = !!(flag & FLAG_ACK_MAYBE_DELAYED); + tcp_rate_gen(sk, delivered, lost, is_sack_reneg, sack_state.rate); + tcp_cong_control(sk, ack, delivered, flag, sack_state.rate); + tcp_xmit_recovery(sk, rexmit); + return 1; + +no_queue: + /* If data was DSACKed, see if we can undo a cwnd reduction. */ + if (flag & FLAG_DSACKING_ACK) { + tcp_fastretrans_alert(sk, prior_snd_una, num_dupack, &flag, + &rexmit); + tcp_newly_delivered(sk, delivered, flag); + } + /* If this ack opens up a zero window, clear backoff. It was + * being used to time the probes, and is probably far higher than + * it needs to be for normal retransmission. + */ + tcp_ack_probe(sk); + + if (tp->tlp_high_seq) + tcp_process_tlp_ack(sk, ack, flag); + return 1; + +old_ack: + /* If data was SACKed, tag it and see if we should send more data. + * If data was DSACKed, see if we can undo a cwnd reduction. 
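+	 *
+	 * (Editorial aside, not upstream code: an ACK that lands here was
+	 * older than snd_una but not old enough to be treated as a blind
+	 * data-injection attempt. The classification done earlier in
+	 * tcp_ack() can be sketched as follows, with hypothetical names and
+	 * wrap-safe compares:
+	 *
+	 *	#include <stdint.h>
+	 *
+	 *	enum ack_class { ACK_TOO_OLD, ACK_OLD, ACK_OK, ACK_UNSENT };
+	 *
+	 *	static int seq_before(uint32_t a, uint32_t b)
+	 *	{
+	 *		return (int32_t)(a - b) < 0;
+	 *	}
+	 *
+	 *	static enum ack_class classify_ack(uint32_t ack, uint32_t snd_una,
+	 *					   uint32_t snd_nxt, uint32_t max_window)
+	 *	{
+	 *		if (seq_before(ack, snd_una)) {
+	 *			// RFC 5961 5.2: never accept ACKs for bytes we never sent
+	 *			if (seq_before(ack, snd_una - max_window))
+	 *				return ACK_TOO_OLD;	// challenge ACK path
+	 *			return ACK_OLD;			// this label's path
+	 *		}
+	 *		if (seq_before(snd_nxt, ack))
+	 *			return ACK_UNSENT;		// discard, RFC 793 3.9
+	 *		return ACK_OK;
+	 *	}
+	 *
+	 * The real code bounds max_window by bytes_acked so a freshly opened
+	 * connection does not accept arbitrarily old ACKs.)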
+ */ + if (TCP_SKB_CB(skb)->sacked) { + flag |= tcp_sacktag_write_queue(sk, skb, prior_snd_una, + &sack_state); + tcp_fastretrans_alert(sk, prior_snd_una, num_dupack, &flag, + &rexmit); + tcp_newly_delivered(sk, delivered, flag); + tcp_xmit_recovery(sk, rexmit); + } + + return 0; +} + +static void tcp_parse_fastopen_option(int len, const unsigned char *cookie, + bool syn, struct tcp_fastopen_cookie *foc, + bool exp_opt) +{ + /* Valid only in SYN or SYN-ACK with an even length. */ + if (!foc || !syn || len < 0 || (len & 1)) + return; + + if (len >= TCP_FASTOPEN_COOKIE_MIN && + len <= TCP_FASTOPEN_COOKIE_MAX) + memcpy(foc->val, cookie, len); + else if (len != 0) + len = -1; + foc->len = len; + foc->exp = exp_opt; +} + +static bool smc_parse_options(const struct tcphdr *th, + struct tcp_options_received *opt_rx, + const unsigned char *ptr, + int opsize) +{ +#if IS_ENABLED(CONFIG_SMC) + if (static_branch_unlikely(&tcp_have_smc)) { + if (th->syn && !(opsize & 1) && + opsize >= TCPOLEN_EXP_SMC_BASE && + get_unaligned_be32(ptr) == TCPOPT_SMC_MAGIC) { + opt_rx->smc_ok = 1; + return true; + } + } +#endif + return false; +} + +/* Try to parse the MSS option from the TCP header. Return 0 on failure, clamped + * value on success. + */ +u16 tcp_parse_mss_option(const struct tcphdr *th, u16 user_mss) +{ + const unsigned char *ptr = (const unsigned char *)(th + 1); + int length = (th->doff * 4) - sizeof(struct tcphdr); + u16 mss = 0; + + while (length > 0) { + int opcode = *ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return mss; + case TCPOPT_NOP: /* Ref: RFC 793 section 3.1 */ + length--; + continue; + default: + if (length < 2) + return mss; + opsize = *ptr++; + if (opsize < 2) /* "silly options" */ + return mss; + if (opsize > length) + return mss; /* fail on partial options */ + if (opcode == TCPOPT_MSS && opsize == TCPOLEN_MSS) { + u16 in_mss = get_unaligned_be16(ptr); + + if (in_mss) { + if (user_mss && user_mss < in_mss) + in_mss = user_mss; + mss = in_mss; + } + } + ptr += opsize - 2; + length -= opsize; + } + } + return mss; +} +EXPORT_SYMBOL_GPL(tcp_parse_mss_option); + +/* Look for tcp options. Normally only called on SYN and SYNACK packets. + * But, this can also be called on packets in the established flow when + * the fast version below fails. 
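+ *
+ * (Editorial aside, not upstream code: every parser in this file walks the
+ * option bytes as kind/length/value records, where EOL ends the list and
+ * NOP is a single-byte pad. A minimal standalone skeleton of that walk,
+ * with the per-option handling left as a stub and all names hypothetical:
+ *
+ *	#include <stdint.h>
+ *
+ *	#define OPT_EOL	0
+ *	#define OPT_NOP	1
+ *
+ *	static void walk_tcp_options(const uint8_t *ptr, int length)
+ *	{
+ *		while (length > 0) {
+ *			int opcode = *ptr++;
+ *			int opsize;
+ *
+ *			if (opcode == OPT_EOL)
+ *				return;
+ *			if (opcode == OPT_NOP) {
+ *				length--;
+ *				continue;
+ *			}
+ *			if (length < 2)
+ *				return;
+ *			opsize = *ptr++;
+ *			if (opsize < 2 || opsize > length)
+ *				return;		// silly or truncated option
+ *			// handle_option(opcode, ptr, opsize - 2) would go here
+ *			ptr += opsize - 2;
+ *			length -= opsize;
+ *		}
+ *	}
+ *
+ * The function below follows the same shape and dispatches on opcode.)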
+ */ +void tcp_parse_options(const struct net *net, + const struct sk_buff *skb, + struct tcp_options_received *opt_rx, int estab, + struct tcp_fastopen_cookie *foc) +{ + const unsigned char *ptr; + const struct tcphdr *th = tcp_hdr(skb); + int length = (th->doff * 4) - sizeof(struct tcphdr); + + ptr = (const unsigned char *)(th + 1); + opt_rx->saw_tstamp = 0; + opt_rx->saw_unknown = 0; + + while (length > 0) { + int opcode = *ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return; + case TCPOPT_NOP: /* Ref: RFC 793 section 3.1 */ + length--; + continue; + default: + if (length < 2) + return; + opsize = *ptr++; + if (opsize < 2) /* "silly options" */ + return; + if (opsize > length) + return; /* don't parse partial options */ + switch (opcode) { + case TCPOPT_MSS: + if (opsize == TCPOLEN_MSS && th->syn && !estab) { + u16 in_mss = get_unaligned_be16(ptr); + if (in_mss) { + if (opt_rx->user_mss && + opt_rx->user_mss < in_mss) + in_mss = opt_rx->user_mss; + opt_rx->mss_clamp = in_mss; + } + } + break; + case TCPOPT_WINDOW: + if (opsize == TCPOLEN_WINDOW && th->syn && + !estab && READ_ONCE(net->ipv4.sysctl_tcp_window_scaling)) { + __u8 snd_wscale = *(__u8 *)ptr; + opt_rx->wscale_ok = 1; + if (snd_wscale > TCP_MAX_WSCALE) { + net_info_ratelimited("%s: Illegal window scaling value %d > %u received\n", + __func__, + snd_wscale, + TCP_MAX_WSCALE); + snd_wscale = TCP_MAX_WSCALE; + } + opt_rx->snd_wscale = snd_wscale; + } + break; + case TCPOPT_TIMESTAMP: + if ((opsize == TCPOLEN_TIMESTAMP) && + ((estab && opt_rx->tstamp_ok) || + (!estab && READ_ONCE(net->ipv4.sysctl_tcp_timestamps)))) { + opt_rx->saw_tstamp = 1; + opt_rx->rcv_tsval = get_unaligned_be32(ptr); + opt_rx->rcv_tsecr = get_unaligned_be32(ptr + 4); + } + break; + case TCPOPT_SACK_PERM: + if (opsize == TCPOLEN_SACK_PERM && th->syn && + !estab && READ_ONCE(net->ipv4.sysctl_tcp_sack)) { + opt_rx->sack_ok = TCP_SACK_SEEN; + tcp_sack_reset(opt_rx); + } + break; + + case TCPOPT_SACK: + if ((opsize >= (TCPOLEN_SACK_BASE + TCPOLEN_SACK_PERBLOCK)) && + !((opsize - TCPOLEN_SACK_BASE) % TCPOLEN_SACK_PERBLOCK) && + opt_rx->sack_ok) { + TCP_SKB_CB(skb)->sacked = (ptr - 2) - (unsigned char *)th; + } + break; +#ifdef CONFIG_TCP_MD5SIG + case TCPOPT_MD5SIG: + /* + * The MD5 Hash has already been + * checked (see tcp_v{4,6}_do_rcv()). + */ + break; +#endif + case TCPOPT_FASTOPEN: + tcp_parse_fastopen_option( + opsize - TCPOLEN_FASTOPEN_BASE, + ptr, th->syn, foc, false); + break; + + case TCPOPT_EXP: + /* Fast Open option shares code 254 using a + * 16 bits magic number. + */ + if (opsize >= TCPOLEN_EXP_FASTOPEN_BASE && + get_unaligned_be16(ptr) == + TCPOPT_FASTOPEN_MAGIC) { + tcp_parse_fastopen_option(opsize - + TCPOLEN_EXP_FASTOPEN_BASE, + ptr + 2, th->syn, foc, true); + break; + } + + if (smc_parse_options(th, opt_rx, ptr, opsize)) + break; + + opt_rx->saw_unknown = 1; + break; + + default: + opt_rx->saw_unknown = 1; + } + ptr += opsize-2; + length -= opsize; + } + } +} +EXPORT_SYMBOL(tcp_parse_options); + +static bool tcp_parse_aligned_timestamp(struct tcp_sock *tp, const struct tcphdr *th) +{ + const __be32 *ptr = (const __be32 *)(th + 1); + + if (*ptr == htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) + | (TCPOPT_TIMESTAMP << 8) | TCPOLEN_TIMESTAMP)) { + tp->rx_opt.saw_tstamp = 1; + ++ptr; + tp->rx_opt.rcv_tsval = ntohl(*ptr); + ++ptr; + if (*ptr) + tp->rx_opt.rcv_tsecr = ntohl(*ptr) - tp->tsoffset; + else + tp->rx_opt.rcv_tsecr = 0; + return true; + } + return false; +} + +/* Fast parse options. This hopes to only see timestamps. 
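+ *
+ * (Editorial aside, not upstream code: "fast" means the helper above only
+ * matches the single 32-bit word a timestamps-only option block starts
+ * with: NOP, NOP, kind 8, length 10. A tiny sketch of building that
+ * constant, with hypothetical names:
+ *
+ *	#include <stdint.h>
+ *
+ *	#define OPT_NOP		 1
+ *	#define OPT_TIMESTAMP	 8
+ *	#define OPTLEN_TIMESTAMP 10
+ *
+ *	// Host-order value of the expected first option word: 0x0101080a.
+ *	static const uint32_t ts_only_word =
+ *		(OPT_NOP << 24) | (OPT_NOP << 16) |
+ *		(OPT_TIMESTAMP << 8) | OPTLEN_TIMESTAMP;
+ *
+ * The helper above compares the first option word against htonl() of this
+ * value, so the check works regardless of host endianness.)
+ *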
+ * If it is wrong it falls back on tcp_parse_options(). + */ +static bool tcp_fast_parse_options(const struct net *net, + const struct sk_buff *skb, + const struct tcphdr *th, struct tcp_sock *tp) +{ + /* In the spirit of fast parsing, compare doff directly to constant + * values. Because equality is used, short doff can be ignored here. + */ + if (th->doff == (sizeof(*th) / 4)) { + tp->rx_opt.saw_tstamp = 0; + return false; + } else if (tp->rx_opt.tstamp_ok && + th->doff == ((sizeof(*th) + TCPOLEN_TSTAMP_ALIGNED) / 4)) { + if (tcp_parse_aligned_timestamp(tp, th)) + return true; + } + + tcp_parse_options(net, skb, &tp->rx_opt, 1, NULL); + if (tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr) + tp->rx_opt.rcv_tsecr -= tp->tsoffset; + + return true; +} + +#ifdef CONFIG_TCP_MD5SIG +/* + * Parse MD5 Signature option + */ +const u8 *tcp_parse_md5sig_option(const struct tcphdr *th) +{ + int length = (th->doff << 2) - sizeof(*th); + const u8 *ptr = (const u8 *)(th + 1); + + /* If not enough data remaining, we can short cut */ + while (length >= TCPOLEN_MD5SIG) { + int opcode = *ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return NULL; + case TCPOPT_NOP: + length--; + continue; + default: + opsize = *ptr++; + if (opsize < 2 || opsize > length) + return NULL; + if (opcode == TCPOPT_MD5SIG) + return opsize == TCPOLEN_MD5SIG ? ptr : NULL; + } + ptr += opsize - 2; + length -= opsize; + } + return NULL; +} +EXPORT_SYMBOL(tcp_parse_md5sig_option); +#endif + +/* Sorry, PAWS as specified is broken wrt. pure-ACKs -DaveM + * + * It is not fatal. If this ACK does _not_ change critical state (seqs, window) + * it can pass through stack. So, the following predicate verifies that + * this segment is not used for anything but congestion avoidance or + * fast retransmit. Moreover, we even are able to eliminate most of such + * second order effects, if we apply some small "replay" window (~RTO) + * to timestamp space. + * + * All these measures still do not guarantee that we reject wrapped ACKs + * on networks with high bandwidth, when sequence space is recycled fastly, + * but it guarantees that such events will be very rare and do not affect + * connection seriously. This doesn't look nice, but alas, PAWS is really + * buggy extension. + * + * [ Later note. Even worse! It is buggy for segments _with_ data. RFC + * states that events when retransmit arrives after original data are rare. + * It is a blatant lie. VJ forgot about fast retransmit! 8)8) It is + * the biggest problem on large power networks even with minor reordering. + * OK, let's give it small replay window. If peer clock is even 1hz, it is safe + * up to bandwidth of 18Gigabit/sec. 8) ] + */ + +static int tcp_disordered_ack(const struct sock *sk, const struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct tcphdr *th = tcp_hdr(skb); + u32 seq = TCP_SKB_CB(skb)->seq; + u32 ack = TCP_SKB_CB(skb)->ack_seq; + + return (/* 1. Pure ACK with correct sequence number. */ + (th->ack && seq == TCP_SKB_CB(skb)->end_seq && seq == tp->rcv_nxt) && + + /* 2. ... and duplicate ACK. */ + ack == tp->snd_una && + + /* 3. ... and does not update window. */ + !tcp_may_update_window(tp, ack, seq, ntohs(th->window) << tp->rx_opt.snd_wscale) && + + /* 4. ... and sits in replay window. 
*/ + (s32)(tp->rx_opt.ts_recent - tp->rx_opt.rcv_tsval) <= (inet_csk(sk)->icsk_rto * 1024) / HZ); +} + +static inline bool tcp_paws_discard(const struct sock *sk, + const struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + return !tcp_paws_check(&tp->rx_opt, TCP_PAWS_WINDOW) && + !tcp_disordered_ack(sk, skb); +} + +/* Check segment sequence number for validity. + * + * Segment controls are considered valid, if the segment + * fits to the window after truncation to the window. Acceptability + * of data (and SYN, FIN, of course) is checked separately. + * See tcp_data_queue(), for example. + * + * Also, controls (RST is main one) are accepted using RCV.WUP instead + * of RCV.NXT. Peer still did not advance his SND.UNA when we + * delayed ACK, so that hisSND.UNA<=ourRCV.WUP. + * (borrowed from freebsd) + */ + +static inline bool tcp_sequence(const struct tcp_sock *tp, u32 seq, u32 end_seq) +{ + return !before(end_seq, tp->rcv_wup) && + !after(seq, tp->rcv_nxt + tcp_receive_window(tp)); +} + +/* When we get a reset we do this. */ +void tcp_reset(struct sock *sk, struct sk_buff *skb) +{ + trace_tcp_receive_reset(sk); + + /* mptcp can't tell us to ignore reset pkts, + * so just ignore the return value of mptcp_incoming_options(). + */ + if (sk_is_mptcp(sk)) + mptcp_incoming_options(sk, skb); + + /* We want the right error as BSD sees it (and indeed as we do). */ + switch (sk->sk_state) { + case TCP_SYN_SENT: + sk->sk_err = ECONNREFUSED; + break; + case TCP_CLOSE_WAIT: + sk->sk_err = EPIPE; + break; + case TCP_CLOSE: + return; + default: + sk->sk_err = ECONNRESET; + } + /* This barrier is coupled with smp_rmb() in tcp_poll() */ + smp_wmb(); + + tcp_write_queue_purge(sk); + tcp_done(sk); + + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); +} + +/* + * Process the FIN bit. This now behaves as it is supposed to work + * and the FIN takes effect when it is validly part of sequence + * space. Not before when we get holes. + * + * If we are ESTABLISHED, a received fin moves us to CLOSE-WAIT + * (and thence onto LAST-ACK and finally, CLOSE, we never enter + * TIME-WAIT) + * + * If we are in FINWAIT-1, a received FIN indicates simultaneous + * close and we go into CLOSING (and later onto TIME-WAIT) + * + * If we are in FINWAIT-2, a received FIN moves us to TIME-WAIT. + */ +void tcp_fin(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + inet_csk_schedule_ack(sk); + + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); + sock_set_flag(sk, SOCK_DONE); + + switch (sk->sk_state) { + case TCP_SYN_RECV: + case TCP_ESTABLISHED: + /* Move to CLOSE_WAIT */ + tcp_set_state(sk, TCP_CLOSE_WAIT); + inet_csk_enter_pingpong_mode(sk); + break; + + case TCP_CLOSE_WAIT: + case TCP_CLOSING: + /* Received a retransmission of the FIN, do + * nothing. + */ + break; + case TCP_LAST_ACK: + /* RFC793: Remain in the LAST-ACK state. */ + break; + + case TCP_FIN_WAIT1: + /* This case occurs when a simultaneous close + * happens, we must ack the received FIN and + * enter the CLOSING state. + */ + tcp_send_ack(sk); + tcp_set_state(sk, TCP_CLOSING); + break; + case TCP_FIN_WAIT2: + /* Received a FIN -- send ACK and enter TIME_WAIT. */ + tcp_send_ack(sk); + tcp_time_wait(sk, TCP_TIME_WAIT, 0); + break; + default: + /* Only TCP_LISTEN and TCP_CLOSE are left, in these + * cases we should never reach this piece of code. + */ + pr_err("%s: Impossible, sk->sk_state=%d\n", + __func__, sk->sk_state); + break; + } + + /* It _is_ possible, that we have something out-of-order _after_ FIN. 
+ * Probably, we should reset in this case. For now drop them. + */ + skb_rbtree_purge(&tp->out_of_order_queue); + if (tcp_is_sack(tp)) + tcp_sack_reset(&tp->rx_opt); + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + + /* Do not send POLL_HUP for half duplex close. */ + if (sk->sk_shutdown == SHUTDOWN_MASK || + sk->sk_state == TCP_CLOSE) + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP); + else + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN); + } +} + +static inline bool tcp_sack_extend(struct tcp_sack_block *sp, u32 seq, + u32 end_seq) +{ + if (!after(seq, sp->end_seq) && !after(sp->start_seq, end_seq)) { + if (before(seq, sp->start_seq)) + sp->start_seq = seq; + if (after(end_seq, sp->end_seq)) + sp->end_seq = end_seq; + return true; + } + return false; +} + +static void tcp_dsack_set(struct sock *sk, u32 seq, u32 end_seq) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tcp_is_sack(tp) && READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_dsack)) { + int mib_idx; + + if (before(seq, tp->rcv_nxt)) + mib_idx = LINUX_MIB_TCPDSACKOLDSENT; + else + mib_idx = LINUX_MIB_TCPDSACKOFOSENT; + + NET_INC_STATS(sock_net(sk), mib_idx); + + tp->rx_opt.dsack = 1; + tp->duplicate_sack[0].start_seq = seq; + tp->duplicate_sack[0].end_seq = end_seq; + } +} + +static void tcp_dsack_extend(struct sock *sk, u32 seq, u32 end_seq) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tp->rx_opt.dsack) + tcp_dsack_set(sk, seq, end_seq); + else + tcp_sack_extend(tp->duplicate_sack, seq, end_seq); +} + +static void tcp_rcv_spurious_retrans(struct sock *sk, const struct sk_buff *skb) +{ + /* When the ACK path fails or drops most ACKs, the sender would + * timeout and spuriously retransmit the same segment repeatedly. + * The receiver remembers and reflects via DSACKs. Leverage the + * DSACK state and change the txhash to re-route speculatively. + */ + if (TCP_SKB_CB(skb)->seq == tcp_sk(sk)->duplicate_sack[0].start_seq && + sk_rethink_txhash(sk)) + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDUPLICATEDATAREHASH); +} + +static void tcp_send_dupack(struct sock *sk, const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq && + before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOST); + tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); + + if (tcp_is_sack(tp) && READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_dsack)) { + u32 end_seq = TCP_SKB_CB(skb)->end_seq; + + tcp_rcv_spurious_retrans(sk, skb); + if (after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt)) + end_seq = tp->rcv_nxt; + tcp_dsack_set(sk, TCP_SKB_CB(skb)->seq, end_seq); + } + } + + tcp_send_ack(sk); +} + +/* These routines update the SACK block as out-of-order packets arrive or + * in-order packets close up the sequence space. + */ +static void tcp_sack_maybe_coalesce(struct tcp_sock *tp) +{ + int this_sack; + struct tcp_sack_block *sp = &tp->selective_acks[0]; + struct tcp_sack_block *swalk = sp + 1; + + /* See if the recent change to the first SACK eats into + * or hits the sequence space of other SACK blocks, if so coalesce. + */ + for (this_sack = 1; this_sack < tp->rx_opt.num_sacks;) { + if (tcp_sack_extend(sp, swalk->start_seq, swalk->end_seq)) { + int i; + + /* Zap SWALK, by moving every further SACK up by one slot. + * Decrease num_sacks. 
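/* A short sketch (not the kernel code itself) of the merge rule that
 * tcp_sack_extend() above applies: a reported range is folded into an
 * existing SACK block only when the two ranges overlap or touch.
 * struct sack_block and sack_extend() are illustrative names.
 */
#include <stdbool.h>
#include <stdint.h>

struct sack_block { uint32_t start_seq, end_seq; };

static bool seq_before(uint32_t a, uint32_t b) { return (int32_t)(a - b) < 0; }
static bool seq_after(uint32_t a, uint32_t b)  { return seq_before(b, a); }

static bool sack_extend(struct sack_block *sp, uint32_t seq, uint32_t end_seq)
{
	/* Disjoint (and not even touching): leave the block alone. */
	if (seq_after(seq, sp->end_seq) || seq_after(sp->start_seq, end_seq))
		return false;

	if (seq_before(seq, sp->start_seq))
		sp->start_seq = seq;
	if (seq_after(end_seq, sp->end_seq))
		sp->end_seq = end_seq;
	return true;
}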
+ */ + tp->rx_opt.num_sacks--; + for (i = this_sack; i < tp->rx_opt.num_sacks; i++) + sp[i] = sp[i + 1]; + continue; + } + this_sack++; + swalk++; + } +} + +void tcp_sack_compress_send_ack(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tp->compressed_ack) + return; + + if (hrtimer_try_to_cancel(&tp->compressed_ack_timer) == 1) + __sock_put(sk); + + /* Since we have to send one ack finally, + * substract one from tp->compressed_ack to keep + * LINUX_MIB_TCPACKCOMPRESSED accurate. + */ + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPACKCOMPRESSED, + tp->compressed_ack - 1); + + tp->compressed_ack = 0; + tcp_send_ack(sk); +} + +/* Reasonable amount of sack blocks included in TCP SACK option + * The max is 4, but this becomes 3 if TCP timestamps are there. + * Given that SACK packets might be lost, be conservative and use 2. + */ +#define TCP_SACK_BLOCKS_EXPECTED 2 + +static void tcp_sack_new_ofo_skb(struct sock *sk, u32 seq, u32 end_seq) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_sack_block *sp = &tp->selective_acks[0]; + int cur_sacks = tp->rx_opt.num_sacks; + int this_sack; + + if (!cur_sacks) + goto new_sack; + + for (this_sack = 0; this_sack < cur_sacks; this_sack++, sp++) { + if (tcp_sack_extend(sp, seq, end_seq)) { + if (this_sack >= TCP_SACK_BLOCKS_EXPECTED) + tcp_sack_compress_send_ack(sk); + /* Rotate this_sack to the first one. */ + for (; this_sack > 0; this_sack--, sp--) + swap(*sp, *(sp - 1)); + if (cur_sacks > 1) + tcp_sack_maybe_coalesce(tp); + return; + } + } + + if (this_sack >= TCP_SACK_BLOCKS_EXPECTED) + tcp_sack_compress_send_ack(sk); + + /* Could not find an adjacent existing SACK, build a new one, + * put it at the front, and shift everyone else down. We + * always know there is at least one SACK present already here. + * + * If the sack array is full, forget about the last one. + */ + if (this_sack >= TCP_NUM_SACKS) { + this_sack--; + tp->rx_opt.num_sacks--; + sp--; + } + for (; this_sack > 0; this_sack--, sp--) + *sp = *(sp - 1); + +new_sack: + /* Build the new head SACK, and we're done. */ + sp->start_seq = seq; + sp->end_seq = end_seq; + tp->rx_opt.num_sacks++; +} + +/* RCV.NXT advances, some SACKs should be eaten. */ + +static void tcp_sack_remove(struct tcp_sock *tp) +{ + struct tcp_sack_block *sp = &tp->selective_acks[0]; + int num_sacks = tp->rx_opt.num_sacks; + int this_sack; + + /* Empty ofo queue, hence, all the SACKs are eaten. Clear. */ + if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) { + tp->rx_opt.num_sacks = 0; + return; + } + + for (this_sack = 0; this_sack < num_sacks;) { + /* Check if the start of the sack is covered by RCV.NXT. */ + if (!before(tp->rcv_nxt, sp->start_seq)) { + int i; + + /* RCV.NXT must cover all the block! */ + WARN_ON(before(tp->rcv_nxt, sp->end_seq)); + + /* Zap this SACK, by moving forward any other SACKS. */ + for (i = this_sack+1; i < num_sacks; i++) + tp->selective_acks[i-1] = tp->selective_acks[i]; + num_sacks--; + continue; + } + this_sack++; + sp++; + } + tp->rx_opt.num_sacks = num_sacks; +} + +/** + * tcp_try_coalesce - try to merge skb to prior one + * @sk: socket + * @to: prior buffer + * @from: buffer to add in queue + * @fragstolen: pointer to boolean + * + * Before queueing skb @from after @to, try to merge them + * to reduce overall memory use and queue lengths, if cost is small. + * Packets in ofo or receive queues can stay a long time. + * Better try to coalesce them right now to avoid future collapses. 
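/* A compact sketch (not the kernel code itself) of the bookkeeping done
 * by tcp_sack_new_ofo_skb() above: a newly reported range either
 * extends an existing block, which is then rotated to slot 0 so the
 * freshest information is advertised first, or becomes a new head
 * block, pushing the others down and discarding the oldest when all
 * slots are taken.  struct sack_state and the helpers are illustrative.
 */
#include <stdbool.h>
#include <stdint.h>

#define MAX_SACK_BLOCKS 4

struct sack_block { uint32_t start_seq, end_seq; };
struct sack_state { struct sack_block blk[MAX_SACK_BLOCKS]; int num; };

/* Merge if the ranges overlap or touch (same rule as the previous sketch). */
static bool sack_extend(struct sack_block *sp, uint32_t seq, uint32_t end_seq)
{
	if ((int32_t)(seq - sp->end_seq) > 0 || (int32_t)(sp->start_seq - end_seq) > 0)
		return false;
	if ((int32_t)(seq - sp->start_seq) < 0)
		sp->start_seq = seq;
	if ((int32_t)(end_seq - sp->end_seq) > 0)
		sp->end_seq = end_seq;
	return true;
}

static void sack_new_block(struct sack_state *s, uint32_t seq, uint32_t end_seq)
{
	int i;

	for (i = 0; i < s->num; i++) {
		if (sack_extend(&s->blk[i], seq, end_seq)) {
			struct sack_block hit = s->blk[i];

			/* Rotate the touched block to the front. */
			for (; i > 0; i--)
				s->blk[i] = s->blk[i - 1];
			s->blk[0] = hit;
			return;
		}
	}

	/* No adjacent block: drop the oldest if full, shift the rest
	 * down and install the new range at the head.
	 */
	if (s->num == MAX_SACK_BLOCKS)
		s->num--;
	for (i = s->num; i > 0; i--)
		s->blk[i] = s->blk[i - 1];
	s->blk[0].start_seq = seq;
	s->blk[0].end_seq = end_seq;
	s->num++;
}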
+ * Returns true if caller should free @from instead of queueing it + */ +static bool tcp_try_coalesce(struct sock *sk, + struct sk_buff *to, + struct sk_buff *from, + bool *fragstolen) +{ + int delta; + + *fragstolen = false; + + /* Its possible this segment overlaps with prior segment in queue */ + if (TCP_SKB_CB(from)->seq != TCP_SKB_CB(to)->end_seq) + return false; + + if (!mptcp_skb_can_collapse(to, from)) + return false; + +#ifdef CONFIG_TLS_DEVICE + if (from->decrypted != to->decrypted) + return false; +#endif + + if (!skb_try_coalesce(to, from, fragstolen, &delta)) + return false; + + atomic_add(delta, &sk->sk_rmem_alloc); + sk_mem_charge(sk, delta); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVCOALESCE); + TCP_SKB_CB(to)->end_seq = TCP_SKB_CB(from)->end_seq; + TCP_SKB_CB(to)->ack_seq = TCP_SKB_CB(from)->ack_seq; + TCP_SKB_CB(to)->tcp_flags |= TCP_SKB_CB(from)->tcp_flags; + + if (TCP_SKB_CB(from)->has_rxtstamp) { + TCP_SKB_CB(to)->has_rxtstamp = true; + to->tstamp = from->tstamp; + skb_hwtstamps(to)->hwtstamp = skb_hwtstamps(from)->hwtstamp; + } + + return true; +} + +static bool tcp_ooo_try_coalesce(struct sock *sk, + struct sk_buff *to, + struct sk_buff *from, + bool *fragstolen) +{ + bool res = tcp_try_coalesce(sk, to, from, fragstolen); + + /* In case tcp_drop_reason() is called later, update to->gso_segs */ + if (res) { + u32 gso_segs = max_t(u16, 1, skb_shinfo(to)->gso_segs) + + max_t(u16, 1, skb_shinfo(from)->gso_segs); + + skb_shinfo(to)->gso_segs = min_t(u32, gso_segs, 0xFFFF); + } + return res; +} + +static void tcp_drop_reason(struct sock *sk, struct sk_buff *skb, + enum skb_drop_reason reason) +{ + sk_drops_add(sk, skb); + kfree_skb_reason(skb, reason); +} + +/* This one checks to see if we can put data from the + * out_of_order queue into the receive_queue. + */ +static void tcp_ofo_queue(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + __u32 dsack_high = tp->rcv_nxt; + bool fin, fragstolen, eaten; + struct sk_buff *skb, *tail; + struct rb_node *p; + + p = rb_first(&tp->out_of_order_queue); + while (p) { + skb = rb_to_skb(p); + if (after(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) + break; + + if (before(TCP_SKB_CB(skb)->seq, dsack_high)) { + __u32 dsack = dsack_high; + if (before(TCP_SKB_CB(skb)->end_seq, dsack_high)) + dsack_high = TCP_SKB_CB(skb)->end_seq; + tcp_dsack_extend(sk, TCP_SKB_CB(skb)->seq, dsack); + } + p = rb_next(p); + rb_erase(&skb->rbnode, &tp->out_of_order_queue); + + if (unlikely(!after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt))) { + tcp_drop_reason(sk, skb, SKB_DROP_REASON_TCP_OFO_DROP); + continue; + } + + tail = skb_peek_tail(&sk->sk_receive_queue); + eaten = tail && tcp_try_coalesce(sk, tail, skb, &fragstolen); + tcp_rcv_nxt_update(tp, TCP_SKB_CB(skb)->end_seq); + fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; + if (!eaten) + __skb_queue_tail(&sk->sk_receive_queue, skb); + else + kfree_skb_partial(skb, fragstolen); + + if (unlikely(fin)) { + tcp_fin(sk); + /* tcp_fin() purges tp->out_of_order_queue, + * so we must end this loop right now. 
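/* A small sketch (not the kernel code itself) of the delivery rule used
 * by tcp_ofo_queue() above, shown over a sorted array rather than the
 * kernel's rbtree: stop at the first segment that would leave a hole,
 * skip segments that are already fully covered, and advance RCV.NXT
 * over everything else.  struct seg and ofo_deliver() are illustrative.
 */
#include <stdint.h>

struct seg { uint32_t seq, end_seq; };

static uint32_t ofo_deliver(const struct seg *q, int n, uint32_t rcv_nxt)
{
	for (int i = 0; i < n; i++) {
		if ((int32_t)(q[i].seq - rcv_nxt) > 0)
			break;		/* gap ahead: keep the rest queued */
		if ((int32_t)(q[i].end_seq - rcv_nxt) <= 0)
			continue;	/* nothing new in this segment: drop */
		/* Deliver q[i]; any overlap below rcv_nxt would be D-SACKed. */
		rcv_nxt = q[i].end_seq;
	}
	return rcv_nxt;
}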
+ */ + break; + } + } +} + +static bool tcp_prune_ofo_queue(struct sock *sk); +static int tcp_prune_queue(struct sock *sk); + +static int tcp_try_rmem_schedule(struct sock *sk, struct sk_buff *skb, + unsigned int size) +{ + if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf || + !sk_rmem_schedule(sk, skb, size)) { + + if (tcp_prune_queue(sk) < 0) + return -1; + + while (!sk_rmem_schedule(sk, skb, size)) { + if (!tcp_prune_ofo_queue(sk)) + return -1; + } + } + return 0; +} + +static void tcp_data_queue_ofo(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct rb_node **p, *parent; + struct sk_buff *skb1; + u32 seq, end_seq; + bool fragstolen; + + tcp_ecn_check_ce(sk, skb); + + if (unlikely(tcp_try_rmem_schedule(sk, skb, skb->truesize))) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFODROP); + sk->sk_data_ready(sk); + tcp_drop_reason(sk, skb, SKB_DROP_REASON_PROTO_MEM); + return; + } + + /* Disable header prediction. */ + tp->pred_flags = 0; + inet_csk_schedule_ack(sk); + + tp->rcv_ooopack += max_t(u16, 1, skb_shinfo(skb)->gso_segs); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOQUEUE); + seq = TCP_SKB_CB(skb)->seq; + end_seq = TCP_SKB_CB(skb)->end_seq; + + p = &tp->out_of_order_queue.rb_node; + if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) { + /* Initial out of order segment, build 1 SACK. */ + if (tcp_is_sack(tp)) { + tp->rx_opt.num_sacks = 1; + tp->selective_acks[0].start_seq = seq; + tp->selective_acks[0].end_seq = end_seq; + } + rb_link_node(&skb->rbnode, NULL, p); + rb_insert_color(&skb->rbnode, &tp->out_of_order_queue); + tp->ooo_last_skb = skb; + goto end; + } + + /* In the typical case, we are adding an skb to the end of the list. + * Use of ooo_last_skb avoids the O(Log(N)) rbtree lookup. + */ + if (tcp_ooo_try_coalesce(sk, tp->ooo_last_skb, + skb, &fragstolen)) { +coalesce_done: + /* For non sack flows, do not grow window to force DUPACK + * and trigger fast retransmit. + */ + if (tcp_is_sack(tp)) + tcp_grow_window(sk, skb, true); + kfree_skb_partial(skb, fragstolen); + skb = NULL; + goto add_sack; + } + /* Can avoid an rbtree lookup if we are adding skb after ooo_last_skb */ + if (!before(seq, TCP_SKB_CB(tp->ooo_last_skb)->end_seq)) { + parent = &tp->ooo_last_skb->rbnode; + p = &parent->rb_right; + goto insert; + } + + /* Find place to insert this segment. Handle overlaps on the way. */ + parent = NULL; + while (*p) { + parent = *p; + skb1 = rb_to_skb(parent); + if (before(seq, TCP_SKB_CB(skb1)->seq)) { + p = &parent->rb_left; + continue; + } + if (before(seq, TCP_SKB_CB(skb1)->end_seq)) { + if (!after(end_seq, TCP_SKB_CB(skb1)->end_seq)) { + /* All the bits are present. Drop. */ + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPOFOMERGE); + tcp_drop_reason(sk, skb, + SKB_DROP_REASON_TCP_OFOMERGE); + skb = NULL; + tcp_dsack_set(sk, seq, end_seq); + goto add_sack; + } + if (after(seq, TCP_SKB_CB(skb1)->seq)) { + /* Partial overlap. */ + tcp_dsack_set(sk, seq, TCP_SKB_CB(skb1)->end_seq); + } else { + /* skb's seq == skb1's seq and skb covers skb1. + * Replace skb1 with skb. + */ + rb_replace_node(&skb1->rbnode, &skb->rbnode, + &tp->out_of_order_queue); + tcp_dsack_extend(sk, + TCP_SKB_CB(skb1)->seq, + TCP_SKB_CB(skb1)->end_seq); + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPOFOMERGE); + tcp_drop_reason(sk, skb1, + SKB_DROP_REASON_TCP_OFOMERGE); + goto merge_right; + } + } else if (tcp_ooo_try_coalesce(sk, skb1, + skb, &fragstolen)) { + goto coalesce_done; + } + p = &parent->rb_right; + } +insert: + /* Insert segment into RB tree. 
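/* A sketch (not the kernel code itself) of how the out-of-order insert
 * path above classifies a new segment [seq, end_seq) against an
 * existing segment [old_seq, old_end) that it starts inside of
 * (old_seq <= seq < old_end).  enum ofo_overlap and classify_overlap()
 * are illustrative names.
 */
#include <stdint.h>

enum ofo_overlap {
	OFO_DUPLICATE,	/* fully covered: drop the new skb, send a D-SACK   */
	OFO_PARTIAL,	/* extends past the old end: keep both, D-SACK head */
	OFO_SUPERSET,	/* same start but wider: replace the old segment    */
};

static enum ofo_overlap classify_overlap(uint32_t seq, uint32_t end_seq,
					 uint32_t old_seq, uint32_t old_end)
{
	if ((int32_t)(end_seq - old_end) <= 0)
		return OFO_DUPLICATE;
	if ((int32_t)(seq - old_seq) > 0)
		return OFO_PARTIAL;
	return OFO_SUPERSET;
}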
*/ + rb_link_node(&skb->rbnode, parent, p); + rb_insert_color(&skb->rbnode, &tp->out_of_order_queue); + +merge_right: + /* Remove other segments covered by skb. */ + while ((skb1 = skb_rb_next(skb)) != NULL) { + if (!after(end_seq, TCP_SKB_CB(skb1)->seq)) + break; + if (before(end_seq, TCP_SKB_CB(skb1)->end_seq)) { + tcp_dsack_extend(sk, TCP_SKB_CB(skb1)->seq, + end_seq); + break; + } + rb_erase(&skb1->rbnode, &tp->out_of_order_queue); + tcp_dsack_extend(sk, TCP_SKB_CB(skb1)->seq, + TCP_SKB_CB(skb1)->end_seq); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPOFOMERGE); + tcp_drop_reason(sk, skb1, SKB_DROP_REASON_TCP_OFOMERGE); + } + /* If there is no skb after us, we are the last_skb ! */ + if (!skb1) + tp->ooo_last_skb = skb; + +add_sack: + if (tcp_is_sack(tp)) + tcp_sack_new_ofo_skb(sk, seq, end_seq); +end: + if (skb) { + /* For non sack flows, do not grow window to force DUPACK + * and trigger fast retransmit. + */ + if (tcp_is_sack(tp)) + tcp_grow_window(sk, skb, false); + skb_condense(skb); + skb_set_owner_r(skb, sk); + } +} + +static int __must_check tcp_queue_rcv(struct sock *sk, struct sk_buff *skb, + bool *fragstolen) +{ + int eaten; + struct sk_buff *tail = skb_peek_tail(&sk->sk_receive_queue); + + eaten = (tail && + tcp_try_coalesce(sk, tail, + skb, fragstolen)) ? 1 : 0; + tcp_rcv_nxt_update(tcp_sk(sk), TCP_SKB_CB(skb)->end_seq); + if (!eaten) { + __skb_queue_tail(&sk->sk_receive_queue, skb); + skb_set_owner_r(skb, sk); + } + return eaten; +} + +int tcp_send_rcvq(struct sock *sk, struct msghdr *msg, size_t size) +{ + struct sk_buff *skb; + int err = -ENOMEM; + int data_len = 0; + bool fragstolen; + + if (size == 0) + return 0; + + if (size > PAGE_SIZE) { + int npages = min_t(size_t, size >> PAGE_SHIFT, MAX_SKB_FRAGS); + + data_len = npages << PAGE_SHIFT; + size = data_len + (size & ~PAGE_MASK); + } + skb = alloc_skb_with_frags(size - data_len, data_len, + PAGE_ALLOC_COSTLY_ORDER, + &err, sk->sk_allocation); + if (!skb) + goto err; + + skb_put(skb, size - data_len); + skb->data_len = data_len; + skb->len = size; + + if (tcp_try_rmem_schedule(sk, skb, skb->truesize)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVQDROP); + goto err_free; + } + + err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, size); + if (err) + goto err_free; + + TCP_SKB_CB(skb)->seq = tcp_sk(sk)->rcv_nxt; + TCP_SKB_CB(skb)->end_seq = TCP_SKB_CB(skb)->seq + size; + TCP_SKB_CB(skb)->ack_seq = tcp_sk(sk)->snd_una - 1; + + if (tcp_queue_rcv(sk, skb, &fragstolen)) { + WARN_ON_ONCE(fragstolen); /* should not happen */ + __kfree_skb(skb); + } + return size; + +err_free: + kfree_skb(skb); +err: + return err; + +} + +void tcp_data_ready(struct sock *sk) +{ + if (tcp_epollin_ready(sk, sk->sk_rcvlowat) || sock_flag(sk, SOCK_DONE)) + sk->sk_data_ready(sk); +} + +static void tcp_data_queue(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + enum skb_drop_reason reason; + bool fragstolen; + int eaten; + + /* If a subflow has been reset, the packet should not continue + * to be processed, drop the packet. + */ + if (sk_is_mptcp(sk) && !mptcp_incoming_options(sk, skb)) { + __kfree_skb(skb); + return; + } + + if (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq) { + __kfree_skb(skb); + return; + } + skb_dst_drop(skb); + __skb_pull(skb, tcp_hdr(skb)->doff * 4); + + reason = SKB_DROP_REASON_NOT_SPECIFIED; + tp->rx_opt.dsack = 0; + + /* Queue data for delivery to the user. + * Packets in sequence go to the receive queue. + * Out of sequence packets to the out_of_order_queue. 
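/* A sketch (not the kernel code itself) of the "merge_right" pass
 * above: after inserting a segment ending at end_seq, successors that
 * it fully covers are removed, while a partially covered successor is
 * kept and only its overlapped head is D-SACKed.  struct seg and
 * covered_successors() are illustrative names.
 */
#include <stdint.h>

struct seg { uint32_t seq, end_seq; };

/* Return how many of the in-order successors should be removed. */
static int covered_successors(const struct seg *next, int n, uint32_t end_seq)
{
	int dropped = 0;

	for (int i = 0; i < n; i++) {
		if ((int32_t)(end_seq - next[i].seq) <= 0)
			break;		/* no overlap at all: stop here */
		if ((int32_t)(end_seq - next[i].end_seq) < 0)
			break;		/* partial overlap: keep it, D-SACK the head */
		dropped++;		/* fully covered: remove and D-SACK it */
	}
	return dropped;
}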
+ */ + if (TCP_SKB_CB(skb)->seq == tp->rcv_nxt) { + if (tcp_receive_window(tp) == 0) { + reason = SKB_DROP_REASON_TCP_ZEROWINDOW; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPZEROWINDOWDROP); + goto out_of_window; + } + + /* Ok. In sequence. In window. */ +queue_and_out: + if (skb_queue_len(&sk->sk_receive_queue) == 0) + sk_forced_mem_schedule(sk, skb->truesize); + else if (tcp_try_rmem_schedule(sk, skb, skb->truesize)) { + reason = SKB_DROP_REASON_PROTO_MEM; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVQDROP); + sk->sk_data_ready(sk); + goto drop; + } + + eaten = tcp_queue_rcv(sk, skb, &fragstolen); + if (skb->len) + tcp_event_data_recv(sk, skb); + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) + tcp_fin(sk); + + if (!RB_EMPTY_ROOT(&tp->out_of_order_queue)) { + tcp_ofo_queue(sk); + + /* RFC5681. 4.2. SHOULD send immediate ACK, when + * gap in queue is filled. + */ + if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) + inet_csk(sk)->icsk_ack.pending |= ICSK_ACK_NOW; + } + + if (tp->rx_opt.num_sacks) + tcp_sack_remove(tp); + + tcp_fast_path_check(sk); + + if (eaten > 0) + kfree_skb_partial(skb, fragstolen); + if (!sock_flag(sk, SOCK_DEAD)) + tcp_data_ready(sk); + return; + } + + if (!after(TCP_SKB_CB(skb)->end_seq, tp->rcv_nxt)) { + tcp_rcv_spurious_retrans(sk, skb); + /* A retransmit, 2nd most common case. Force an immediate ack. */ + reason = SKB_DROP_REASON_TCP_OLD_DATA; + NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOST); + tcp_dsack_set(sk, TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq); + +out_of_window: + tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); + inet_csk_schedule_ack(sk); +drop: + tcp_drop_reason(sk, skb, reason); + return; + } + + /* Out of window. F.e. zero window probe. */ + if (!before(TCP_SKB_CB(skb)->seq, + tp->rcv_nxt + tcp_receive_window(tp))) { + reason = SKB_DROP_REASON_TCP_OVERWINDOW; + goto out_of_window; + } + + if (before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) { + /* Partial packet, seq < rcv_next < end_seq */ + tcp_dsack_set(sk, TCP_SKB_CB(skb)->seq, tp->rcv_nxt); + + /* If window is closed, drop tail of packet. But after + * remembering D-SACK for its head made in previous line. + */ + if (!tcp_receive_window(tp)) { + reason = SKB_DROP_REASON_TCP_ZEROWINDOW; + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPZEROWINDOWDROP); + goto out_of_window; + } + goto queue_and_out; + } + + tcp_data_queue_ofo(sk, skb); +} + +static struct sk_buff *tcp_skb_next(struct sk_buff *skb, struct sk_buff_head *list) +{ + if (list) + return !skb_queue_is_last(list, skb) ? skb->next : NULL; + + return skb_rb_next(skb); +} + +static struct sk_buff *tcp_collapse_one(struct sock *sk, struct sk_buff *skb, + struct sk_buff_head *list, + struct rb_root *root) +{ + struct sk_buff *next = tcp_skb_next(skb, list); + + if (list) + __skb_unlink(skb, list); + else + rb_erase(&skb->rbnode, root); + + __kfree_skb(skb); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPRCVCOLLAPSED); + + return next; +} + +/* Insert skb into rb tree, ordered by TCP_SKB_CB(skb)->seq */ +void tcp_rbtree_insert(struct rb_root *root, struct sk_buff *skb) +{ + struct rb_node **p = &root->rb_node; + struct rb_node *parent = NULL; + struct sk_buff *skb1; + + while (*p) { + parent = *p; + skb1 = rb_to_skb(parent); + if (before(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb1)->seq)) + p = &parent->rb_left; + else + p = &parent->rb_right; + } + rb_link_node(&skb->rbnode, parent, p); + rb_insert_color(&skb->rbnode, root); +} + +/* Collapse contiguous sequence of skbs head..tail with + * sequence numbers start..end. 
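/* A sketch (not the kernel code itself) of the order of checks that
 * tcp_data_queue() above applies to an incoming data segment.  The
 * kernel additionally drops even in-sequence data when the advertised
 * receive window is zero; enum rx_verdict and classify_data() are
 * illustrative names.
 */
#include <stdint.h>

enum rx_verdict {
	RX_IN_SEQUENCE,		/* seq == RCV.NXT: queue for the application   */
	RX_OLD_DATA,		/* entirely below RCV.NXT: D-SACK + dup ACK     */
	RX_BEYOND_WINDOW,	/* starts past the advertised window: drop      */
	RX_PARTIAL,		/* stale head, new tail: D-SACK head and queue  */
	RX_OUT_OF_ORDER,	/* in window but leaves a hole: ofo queue       */
};

static enum rx_verdict classify_data(uint32_t seq, uint32_t end_seq,
				     uint32_t rcv_nxt, uint32_t rcv_wnd)
{
	if (seq == rcv_nxt)
		return RX_IN_SEQUENCE;
	if ((int32_t)(end_seq - rcv_nxt) <= 0)
		return RX_OLD_DATA;
	if ((int32_t)(seq - (rcv_nxt + rcv_wnd)) >= 0)
		return RX_BEYOND_WINDOW;
	if ((int32_t)(seq - rcv_nxt) < 0)
		return RX_PARTIAL;
	return RX_OUT_OF_ORDER;
}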
+ * + * If tail is NULL, this means until the end of the queue. + * + * Segments with FIN/SYN are not collapsed (only because this + * simplifies code) + */ +static void +tcp_collapse(struct sock *sk, struct sk_buff_head *list, struct rb_root *root, + struct sk_buff *head, struct sk_buff *tail, u32 start, u32 end) +{ + struct sk_buff *skb = head, *n; + struct sk_buff_head tmp; + bool end_of_skbs; + + /* First, check that queue is collapsible and find + * the point where collapsing can be useful. + */ +restart: + for (end_of_skbs = true; skb != NULL && skb != tail; skb = n) { + n = tcp_skb_next(skb, list); + + /* No new bits? It is possible on ofo queue. */ + if (!before(start, TCP_SKB_CB(skb)->end_seq)) { + skb = tcp_collapse_one(sk, skb, list, root); + if (!skb) + break; + goto restart; + } + + /* The first skb to collapse is: + * - not SYN/FIN and + * - bloated or contains data before "start" or + * overlaps to the next one and mptcp allow collapsing. + */ + if (!(TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN)) && + (tcp_win_from_space(sk, skb->truesize) > skb->len || + before(TCP_SKB_CB(skb)->seq, start))) { + end_of_skbs = false; + break; + } + + if (n && n != tail && mptcp_skb_can_collapse(skb, n) && + TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(n)->seq) { + end_of_skbs = false; + break; + } + + /* Decided to skip this, advance start seq. */ + start = TCP_SKB_CB(skb)->end_seq; + } + if (end_of_skbs || + (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN))) + return; + + __skb_queue_head_init(&tmp); + + while (before(start, end)) { + int copy = min_t(int, SKB_MAX_ORDER(0, 0), end - start); + struct sk_buff *nskb; + + nskb = alloc_skb(copy, GFP_ATOMIC); + if (!nskb) + break; + + memcpy(nskb->cb, skb->cb, sizeof(skb->cb)); +#ifdef CONFIG_TLS_DEVICE + nskb->decrypted = skb->decrypted; +#endif + TCP_SKB_CB(nskb)->seq = TCP_SKB_CB(nskb)->end_seq = start; + if (list) + __skb_queue_before(list, skb, nskb); + else + __skb_queue_tail(&tmp, nskb); /* defer rbtree insertion */ + skb_set_owner_r(nskb, sk); + mptcp_skb_ext_move(nskb, skb); + + /* Copy data, releasing collapsed skbs. */ + while (copy > 0) { + int offset = start - TCP_SKB_CB(skb)->seq; + int size = TCP_SKB_CB(skb)->end_seq - start; + + BUG_ON(offset < 0); + if (size > 0) { + size = min(copy, size); + if (skb_copy_bits(skb, offset, skb_put(nskb, size), size)) + BUG(); + TCP_SKB_CB(nskb)->end_seq += size; + copy -= size; + start += size; + } + if (!before(start, TCP_SKB_CB(skb)->end_seq)) { + skb = tcp_collapse_one(sk, skb, list, root); + if (!skb || + skb == tail || + !mptcp_skb_can_collapse(nskb, skb) || + (TCP_SKB_CB(skb)->tcp_flags & (TCPHDR_SYN | TCPHDR_FIN))) + goto end; +#ifdef CONFIG_TLS_DEVICE + if (skb->decrypted != nskb->decrypted) + goto end; +#endif + } + } + } +end: + skb_queue_walk_safe(&tmp, skb, n) + tcp_rbtree_insert(root, skb); +} + +/* Collapse ofo queue. Algorithm: select contiguous sequence of skbs + * and tcp_collapse() them until all the queue is collapsed. 
+ */ +static void tcp_collapse_ofo_queue(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 range_truesize, sum_tiny = 0; + struct sk_buff *skb, *head; + u32 start, end; + + skb = skb_rb_first(&tp->out_of_order_queue); +new_range: + if (!skb) { + tp->ooo_last_skb = skb_rb_last(&tp->out_of_order_queue); + return; + } + start = TCP_SKB_CB(skb)->seq; + end = TCP_SKB_CB(skb)->end_seq; + range_truesize = skb->truesize; + + for (head = skb;;) { + skb = skb_rb_next(skb); + + /* Range is terminated when we see a gap or when + * we are at the queue end. + */ + if (!skb || + after(TCP_SKB_CB(skb)->seq, end) || + before(TCP_SKB_CB(skb)->end_seq, start)) { + /* Do not attempt collapsing tiny skbs */ + if (range_truesize != head->truesize || + end - start >= SKB_WITH_OVERHEAD(PAGE_SIZE)) { + tcp_collapse(sk, NULL, &tp->out_of_order_queue, + head, skb, start, end); + } else { + sum_tiny += range_truesize; + if (sum_tiny > sk->sk_rcvbuf >> 3) + return; + } + goto new_range; + } + + range_truesize += skb->truesize; + if (unlikely(before(TCP_SKB_CB(skb)->seq, start))) + start = TCP_SKB_CB(skb)->seq; + if (after(TCP_SKB_CB(skb)->end_seq, end)) + end = TCP_SKB_CB(skb)->end_seq; + } +} + +/* + * Clean the out-of-order queue to make room. + * We drop high sequences packets to : + * 1) Let a chance for holes to be filled. + * 2) not add too big latencies if thousands of packets sit there. + * (But if application shrinks SO_RCVBUF, we could still end up + * freeing whole queue here) + * 3) Drop at least 12.5 % of sk_rcvbuf to avoid malicious attacks. + * + * Return true if queue has shrunk. + */ +static bool tcp_prune_ofo_queue(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct rb_node *node, *prev; + int goal; + + if (RB_EMPTY_ROOT(&tp->out_of_order_queue)) + return false; + + NET_INC_STATS(sock_net(sk), LINUX_MIB_OFOPRUNED); + goal = sk->sk_rcvbuf >> 3; + node = &tp->ooo_last_skb->rbnode; + do { + prev = rb_prev(node); + rb_erase(node, &tp->out_of_order_queue); + goal -= rb_to_skb(node)->truesize; + tcp_drop_reason(sk, rb_to_skb(node), + SKB_DROP_REASON_TCP_OFO_QUEUE_PRUNE); + if (!prev || goal <= 0) { + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf && + !tcp_under_memory_pressure(sk)) + break; + goal = sk->sk_rcvbuf >> 3; + } + node = prev; + } while (node); + tp->ooo_last_skb = rb_to_skb(prev); + + /* Reset SACK state. A conforming SACK implementation will + * do the same at a timeout based retransmit. When a connection + * is in a sad state like this, we care only about integrity + * of the connection not performance. + */ + if (tp->rx_opt.sack_ok) + tcp_sack_reset(&tp->rx_opt); + return true; +} + +/* Reduce allocated memory if we can, trying to get + * the socket within its memory limits again. + * + * Return less than zero if we should start dropping frames + * until the socket owning process reads some of the data + * to stabilize the situation. 
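/* A sketch (not the kernel code itself) of the tail-first policy used
 * by tcp_prune_ofo_queue() above: drop the highest-sequence buffers
 * first, since they are the least likely to fill the current hole, and
 * free at least an eighth of the receive buffer before re-checking the
 * limit.  prune_ofo_tail() and its parameters are illustrative, and the
 * caller is assumed to keep rmem_alloc consistent with sizes[].
 */

/* sizes[] holds per-segment truesize in queue order (lowest seq first). */
static int prune_ofo_tail(const unsigned int *sizes, int n,
			  unsigned int rcvbuf, unsigned int *rmem_alloc)
{
	int goal = (int)(rcvbuf >> 3);	/* free at least 12.5% per pass */
	int dropped = 0;

	for (int i = n - 1; i >= 0; i--) {
		goal -= (int)sizes[i];
		*rmem_alloc -= sizes[i];
		dropped++;
		if (goal <= 0) {
			if (*rmem_alloc <= rcvbuf)
				break;	/* back under the limit: stop */
			goal = (int)(rcvbuf >> 3);
		}
	}
	return dropped;			/* number of segments freed */
}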
+ */ +static int tcp_prune_queue(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + NET_INC_STATS(sock_net(sk), LINUX_MIB_PRUNECALLED); + + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) + tcp_clamp_window(sk); + else if (tcp_under_memory_pressure(sk)) + tcp_adjust_rcv_ssthresh(sk); + + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + return 0; + + tcp_collapse_ofo_queue(sk); + if (!skb_queue_empty(&sk->sk_receive_queue)) + tcp_collapse(sk, &sk->sk_receive_queue, NULL, + skb_peek(&sk->sk_receive_queue), + NULL, + tp->copied_seq, tp->rcv_nxt); + + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + return 0; + + /* Collapsing did not help, destructive actions follow. + * This must not ever occur. */ + + tcp_prune_ofo_queue(sk); + + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + return 0; + + /* If we are really being abused, tell the caller to silently + * drop receive data on the floor. It will get retransmitted + * and hopefully then we'll have sufficient space. + */ + NET_INC_STATS(sock_net(sk), LINUX_MIB_RCVPRUNED); + + /* Massive buffer overcommit. */ + tp->pred_flags = 0; + return -1; +} + +static bool tcp_should_expand_sndbuf(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + /* If the user specified a specific send buffer setting, do + * not modify it. + */ + if (sk->sk_userlocks & SOCK_SNDBUF_LOCK) + return false; + + /* If we are under global TCP memory pressure, do not expand. */ + if (tcp_under_memory_pressure(sk)) { + int unused_mem = sk_unused_reserved_mem(sk); + + /* Adjust sndbuf according to reserved mem. But make sure + * it never goes below SOCK_MIN_SNDBUF. + * See sk_stream_moderate_sndbuf() for more details. + */ + if (unused_mem > SOCK_MIN_SNDBUF) + WRITE_ONCE(sk->sk_sndbuf, unused_mem); + + return false; + } + + /* If we are under soft global TCP memory pressure, do not expand. */ + if (sk_memory_allocated(sk) >= sk_prot_mem_limits(sk, 0)) + return false; + + /* If we filled the congestion window, do not expand. */ + if (tcp_packets_in_flight(tp) >= tcp_snd_cwnd(tp)) + return false; + + return true; +} + +static void tcp_new_space(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tcp_should_expand_sndbuf(sk)) { + tcp_sndbuf_expand(sk); + tp->snd_cwnd_stamp = tcp_jiffies32; + } + + INDIRECT_CALL_1(sk->sk_write_space, sk_stream_write_space, sk); +} + +/* Caller made space either from: + * 1) Freeing skbs in rtx queues (after tp->snd_una has advanced) + * 2) Sent skbs from output queue (and thus advancing tp->snd_nxt) + * + * We might be able to generate EPOLLOUT to the application if: + * 1) Space consumed in output/rtx queues is below sk->sk_sndbuf/2 + * 2) notsent amount (tp->write_seq - tp->snd_nxt) became + * small enough that tcp_stream_memory_free() decides it + * is time to generate EPOLLOUT. + */ +void tcp_check_space(struct sock *sk) +{ + /* pairs with tcp_poll() */ + smp_mb(); + if (sk->sk_socket && + test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) { + tcp_new_space(sk); + if (!test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) + tcp_chrono_stop(sk, TCP_CHRONO_SNDBUF_LIMITED); + } +} + +static inline void tcp_data_snd_check(struct sock *sk) +{ + tcp_push_pending_frames(sk); + tcp_check_space(sk); +} + +/* + * Check if sending an ack is needed. + */ +static void __tcp_ack_snd_check(struct sock *sk, int ofo_possible) +{ + struct tcp_sock *tp = tcp_sk(sk); + unsigned long rtt, delay; + + /* More than one full frame received... 
*/ + if (((tp->rcv_nxt - tp->rcv_wup) > inet_csk(sk)->icsk_ack.rcv_mss && + /* ... and right edge of window advances far enough. + * (tcp_recvmsg() will send ACK otherwise). + * If application uses SO_RCVLOWAT, we want send ack now if + * we have not received enough bytes to satisfy the condition. + */ + (tp->rcv_nxt - tp->copied_seq < sk->sk_rcvlowat || + __tcp_select_window(sk) >= tp->rcv_wnd)) || + /* We ACK each frame or... */ + tcp_in_quickack_mode(sk) || + /* Protocol state mandates a one-time immediate ACK */ + inet_csk(sk)->icsk_ack.pending & ICSK_ACK_NOW) { +send_now: + tcp_send_ack(sk); + return; + } + + if (!ofo_possible || RB_EMPTY_ROOT(&tp->out_of_order_queue)) { + tcp_send_delayed_ack(sk); + return; + } + + if (!tcp_is_sack(tp) || + tp->compressed_ack >= READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_comp_sack_nr)) + goto send_now; + + if (tp->compressed_ack_rcv_nxt != tp->rcv_nxt) { + tp->compressed_ack_rcv_nxt = tp->rcv_nxt; + tp->dup_ack_counter = 0; + } + if (tp->dup_ack_counter < TCP_FASTRETRANS_THRESH) { + tp->dup_ack_counter++; + goto send_now; + } + tp->compressed_ack++; + if (hrtimer_is_queued(&tp->compressed_ack_timer)) + return; + + /* compress ack timer : 5 % of rtt, but no more than tcp_comp_sack_delay_ns */ + + rtt = tp->rcv_rtt_est.rtt_us; + if (tp->srtt_us && tp->srtt_us < rtt) + rtt = tp->srtt_us; + + delay = min_t(unsigned long, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_comp_sack_delay_ns), + rtt * (NSEC_PER_USEC >> 3)/20); + sock_hold(sk); + hrtimer_start_range_ns(&tp->compressed_ack_timer, ns_to_ktime(delay), + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_comp_sack_slack_ns), + HRTIMER_MODE_REL_PINNED_SOFT); +} + +static inline void tcp_ack_snd_check(struct sock *sk) +{ + if (!inet_csk_ack_scheduled(sk)) { + /* We sent a data segment already. */ + return; + } + __tcp_ack_snd_check(sk, 1); +} + +/* + * This routine is only called when we have urgent data + * signaled. Its the 'slow' part of tcp_urg. It could be + * moved inline now as tcp_urg is only called from one + * place. We handle URGent data wrong. We have to - as + * BSD still doesn't use the correction from RFC961. + * For 1003.1g we should support a new option TCP_STDURG to permit + * either form (or just set the sysctl tcp_stdurg). + */ + +static void tcp_check_urg(struct sock *sk, const struct tcphdr *th) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 ptr = ntohs(th->urg_ptr); + + if (ptr && !READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_stdurg)) + ptr--; + ptr += ntohl(th->seq); + + /* Ignore urgent data that we've already seen and read. */ + if (after(tp->copied_seq, ptr)) + return; + + /* Do not replay urg ptr. + * + * NOTE: interesting situation not covered by specs. + * Misbehaving sender may send urg ptr, pointing to segment, + * which we already have in ofo queue. We are not able to fetch + * such data and will stay in TCP_URG_NOTYET until will be eaten + * by recvmsg(). Seems, we are not obliged to handle such wicked + * situations. But it is worth to think about possibility of some + * DoSes using some hypothetical application level deadlock. + */ + if (before(ptr, tp->rcv_nxt)) + return; + + /* Do we already have a newer (or duplicate) urgent pointer? */ + if (tp->urg_data && !after(ptr, tp->urg_seq)) + return; + + /* Tell the world about our new urgent pointer. */ + sk_send_sigurg(sk); + + /* We may be adding urgent data when the last byte read was + * urgent. To do this requires some care. 
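/* A sketch (not the kernel code itself) of the delay chosen above for
 * the SACK compression timer: roughly 5% of the RTT estimate, clamped
 * by a configurable ceiling.  Both RTT arguments are in the kernel's
 * usec << 3 fixed-point form; comp_sack_delay_ns() is an illustrative
 * name.
 */
#include <stdint.h>

#define NSEC_PER_USEC 1000ULL

static uint64_t comp_sack_delay_ns(uint32_t rcv_rtt_us8, uint32_t srtt_us8,
				   uint64_t max_delay_ns)
{
	uint64_t rtt = rcv_rtt_us8;
	uint64_t delay;

	if (srtt_us8 && srtt_us8 < rtt)
		rtt = srtt_us8;			/* prefer the smaller estimate */

	/* rtt is usec << 3, so rtt * (1000 >> 3) / 20 is 5% of RTT in ns. */
	delay = rtt * (NSEC_PER_USEC >> 3) / 20;

	return delay < max_delay_ns ? delay : max_delay_ns;
}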
We cannot just ignore + * tp->copied_seq since we would read the last urgent byte again + * as data, nor can we alter copied_seq until this data arrives + * or we break the semantics of SIOCATMARK (and thus sockatmark()) + * + * NOTE. Double Dutch. Rendering to plain English: author of comment + * above did something sort of send("A", MSG_OOB); send("B", MSG_OOB); + * and expect that both A and B disappear from stream. This is _wrong_. + * Though this happens in BSD with high probability, this is occasional. + * Any application relying on this is buggy. Note also, that fix "works" + * only in this artificial test. Insert some normal data between A and B and we will + * decline of BSD again. Verdict: it is better to remove to trap + * buggy users. + */ + if (tp->urg_seq == tp->copied_seq && tp->urg_data && + !sock_flag(sk, SOCK_URGINLINE) && tp->copied_seq != tp->rcv_nxt) { + struct sk_buff *skb = skb_peek(&sk->sk_receive_queue); + tp->copied_seq++; + if (skb && !before(tp->copied_seq, TCP_SKB_CB(skb)->end_seq)) { + __skb_unlink(skb, &sk->sk_receive_queue); + __kfree_skb(skb); + } + } + + WRITE_ONCE(tp->urg_data, TCP_URG_NOTYET); + WRITE_ONCE(tp->urg_seq, ptr); + + /* Disable header prediction. */ + tp->pred_flags = 0; +} + +/* This is the 'fast' part of urgent handling. */ +static void tcp_urg(struct sock *sk, struct sk_buff *skb, const struct tcphdr *th) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* Check if we get a new urgent pointer - normally not. */ + if (unlikely(th->urg)) + tcp_check_urg(sk, th); + + /* Do we wait for any urgent data? - normally not... */ + if (unlikely(tp->urg_data == TCP_URG_NOTYET)) { + u32 ptr = tp->urg_seq - ntohl(th->seq) + (th->doff * 4) - + th->syn; + + /* Is the urgent pointer pointing into this packet? */ + if (ptr < skb->len) { + u8 tmp; + if (skb_copy_bits(skb, ptr, &tmp, 1)) + BUG(); + WRITE_ONCE(tp->urg_data, TCP_URG_VALID | tmp); + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + } + } +} + +/* Accept RST for rcv_nxt - 1 after a FIN. + * When tcp connections are abruptly terminated from Mac OSX (via ^C), a + * FIN is sent followed by a RST packet. The RST is sent with the same + * sequence number as the FIN, and thus according to RFC 5961 a challenge + * ACK should be sent. However, Mac OSX rate limits replies to challenge + * ACKs on the closed socket. In addition middleboxes can drop either the + * challenge ACK or a subsequent RST. + */ +static bool tcp_reset_check(const struct sock *sk, const struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + return unlikely(TCP_SKB_CB(skb)->seq == (tp->rcv_nxt - 1) && + (1 << sk->sk_state) & (TCPF_CLOSE_WAIT | TCPF_LAST_ACK | + TCPF_CLOSING)); +} + +/* Does PAWS and seqno based validation of an incoming segment, flags will + * play significant role here. + */ +static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb, + const struct tcphdr *th, int syn_inerr) +{ + struct tcp_sock *tp = tcp_sk(sk); + SKB_DR(reason); + + /* RFC1323: H1. Apply PAWS check first. */ + if (tcp_fast_parse_options(sock_net(sk), skb, th, tp) && + tp->rx_opt.saw_tstamp && + tcp_paws_discard(sk, skb)) { + if (!th->rst) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); + if (!tcp_oow_rate_limited(sock_net(sk), skb, + LINUX_MIB_TCPACKSKIPPEDPAWS, + &tp->last_oow_ack_time)) + tcp_send_dupack(sk, skb); + SKB_DR_SET(reason, TCP_RFC7323_PAWS); + goto discard; + } + /* Reset is accepted even if it did not pass PAWS. 
*/ + } + + /* Step 1: check sequence number */ + if (!tcp_sequence(tp, TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq)) { + /* RFC793, page 37: "In all states except SYN-SENT, all reset + * (RST) segments are validated by checking their SEQ-fields." + * And page 69: "If an incoming segment is not acceptable, + * an acknowledgment should be sent in reply (unless the RST + * bit is set, if so drop the segment and return)". + */ + if (!th->rst) { + if (th->syn) + goto syn_challenge; + if (!tcp_oow_rate_limited(sock_net(sk), skb, + LINUX_MIB_TCPACKSKIPPEDSEQ, + &tp->last_oow_ack_time)) + tcp_send_dupack(sk, skb); + } else if (tcp_reset_check(sk, skb)) { + goto reset; + } + SKB_DR_SET(reason, TCP_INVALID_SEQUENCE); + goto discard; + } + + /* Step 2: check RST bit */ + if (th->rst) { + /* RFC 5961 3.2 (extend to match against (RCV.NXT - 1) after a + * FIN and SACK too if available): + * If seq num matches RCV.NXT or (RCV.NXT - 1) after a FIN, or + * the right-most SACK block, + * then + * RESET the connection + * else + * Send a challenge ACK + */ + if (TCP_SKB_CB(skb)->seq == tp->rcv_nxt || + tcp_reset_check(sk, skb)) + goto reset; + + if (tcp_is_sack(tp) && tp->rx_opt.num_sacks > 0) { + struct tcp_sack_block *sp = &tp->selective_acks[0]; + int max_sack = sp[0].end_seq; + int this_sack; + + for (this_sack = 1; this_sack < tp->rx_opt.num_sacks; + ++this_sack) { + max_sack = after(sp[this_sack].end_seq, + max_sack) ? + sp[this_sack].end_seq : max_sack; + } + + if (TCP_SKB_CB(skb)->seq == max_sack) + goto reset; + } + + /* Disable TFO if RST is out-of-order + * and no data has been received + * for current active TFO socket + */ + if (tp->syn_fastopen && !tp->data_segs_in && + sk->sk_state == TCP_ESTABLISHED) + tcp_fastopen_active_disable(sk); + tcp_send_challenge_ack(sk); + SKB_DR_SET(reason, TCP_RESET); + goto discard; + } + + /* step 3: check security and precedence [ignored] */ + + /* step 4: Check for a SYN + * RFC 5961 4.2 : Send a challenge ack + */ + if (th->syn) { +syn_challenge: + if (syn_inerr) + TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNCHALLENGE); + tcp_send_challenge_ack(sk); + SKB_DR_SET(reason, TCP_INVALID_SYN); + goto discard; + } + + bpf_skops_parse_hdr(sk, skb); + + return true; + +discard: + tcp_drop_reason(sk, skb, reason); + return false; + +reset: + tcp_reset(sk, skb); + __kfree_skb(skb); + return false; +} + +/* + * TCP receive function for the ESTABLISHED state. + * + * It is split into a fast path and a slow path. The fast path is + * disabled when: + * - A zero window was announced from us - zero window probing + * is only handled properly in the slow path. + * - Out of order segments arrived. + * - Urgent data is expected. + * - There is no buffer space left + * - Unexpected TCP flags/window values/header lengths are received + * (detected by checking the TCP header against pred_flags) + * - Data is sent in both directions. Fast path only supports pure senders + * or pure receivers (this means either the sequence number or the ack + * value must stay constant) + * - Unexpected TCP option. + * + * When these conditions are not satisfied it drops into a standard + * receive procedure patterned after RFC793 to handle all cases. + * The first three cases are guaranteed by proper pred_flags setting, + * the rest is checked inline. Fast processing is turned on in + * tcp_data_queue when everything is OK. 
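/* A sketch (not the kernel code itself) of the RFC 5961 style RST
 * acceptance applied above: tear the connection down only on an exact
 * match against RCV.NXT or the right-most edge we have SACKed,
 * otherwise answer with a rate-limited challenge ACK.  The kernel also
 * accepts RCV.NXT - 1 right after a FIN (tcp_reset_check()).
 * rst_is_exact() and its inputs are illustrative names.
 */
#include <stdbool.h>
#include <stdint.h>

struct sack_block { uint32_t start_seq, end_seq; };

static bool rst_is_exact(uint32_t seq, uint32_t rcv_nxt,
			 const struct sack_block *sacks, int num_sacks)
{
	uint32_t max_edge;

	if (seq == rcv_nxt)
		return true;		/* in-sequence RST: reset */

	if (!num_sacks)
		return false;		/* anything else: challenge ACK */

	/* Find the right-most edge reported in our SACK blocks. */
	max_edge = sacks[0].end_seq;
	for (int i = 1; i < num_sacks; i++)
		if ((int32_t)(sacks[i].end_seq - max_edge) > 0)
			max_edge = sacks[i].end_seq;

	return seq == max_edge;
}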
+ */ +void tcp_rcv_established(struct sock *sk, struct sk_buff *skb) +{ + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; + const struct tcphdr *th = (const struct tcphdr *)skb->data; + struct tcp_sock *tp = tcp_sk(sk); + unsigned int len = skb->len; + + /* TCP congestion window tracking */ + trace_tcp_probe(sk, skb); + + tcp_mstamp_refresh(tp); + if (unlikely(!rcu_access_pointer(sk->sk_rx_dst))) + inet_csk(sk)->icsk_af_ops->sk_rx_dst_set(sk, skb); + /* + * Header prediction. + * The code loosely follows the one in the famous + * "30 instruction TCP receive" Van Jacobson mail. + * + * Van's trick is to deposit buffers into socket queue + * on a device interrupt, to call tcp_recv function + * on the receive process context and checksum and copy + * the buffer to user space. smart... + * + * Our current scheme is not silly either but we take the + * extra cost of the net_bh soft interrupt processing... + * We do checksum and copy also but from device to kernel. + */ + + tp->rx_opt.saw_tstamp = 0; + + /* pred_flags is 0xS?10 << 16 + snd_wnd + * if header_prediction is to be made + * 'S' will always be tp->tcp_header_len >> 2 + * '?' will be 0 for the fast path, otherwise pred_flags is 0 to + * turn it off (when there are holes in the receive + * space for instance) + * PSH flag is ignored. + */ + + if ((tcp_flag_word(th) & TCP_HP_BITS) == tp->pred_flags && + TCP_SKB_CB(skb)->seq == tp->rcv_nxt && + !after(TCP_SKB_CB(skb)->ack_seq, tp->snd_nxt)) { + int tcp_header_len = tp->tcp_header_len; + + /* Timestamp header prediction: tcp_header_len + * is automatically equal to th->doff*4 due to pred_flags + * match. + */ + + /* Check timestamp */ + if (tcp_header_len == sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED) { + /* No? Slow path! */ + if (!tcp_parse_aligned_timestamp(tp, th)) + goto slow_path; + + /* If PAWS failed, check it more carefully in slow path */ + if ((s32)(tp->rx_opt.rcv_tsval - tp->rx_opt.ts_recent) < 0) + goto slow_path; + + /* DO NOT update ts_recent here, if checksum fails + * and timestamp was corrupted part, it will result + * in a hung connection since we will drop all + * future packets due to the PAWS test. + */ + } + + if (len <= tcp_header_len) { + /* Bulk data transfer: sender */ + if (len == tcp_header_len) { + /* Predicted packet is in window by definition. + * seq == rcv_nxt and rcv_wup <= rcv_nxt. + * Hence, check seq<=rcv_wup reduces to: + */ + if (tcp_header_len == + (sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED) && + tp->rcv_nxt == tp->rcv_wup) + tcp_store_ts_recent(tp); + + /* We know that such packets are checksummed + * on entry. + */ + tcp_ack(sk, skb, 0); + __kfree_skb(skb); + tcp_data_snd_check(sk); + /* When receiving pure ack in fast path, update + * last ts ecr directly instead of calling + * tcp_rcv_rtt_measure_ts() + */ + tp->rcv_rtt_last_tsecr = tp->rx_opt.rcv_tsecr; + return; + } else { /* Header too small */ + reason = SKB_DROP_REASON_PKT_TOO_SMALL; + TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + goto discard; + } + } else { + int eaten = 0; + bool fragstolen = false; + + if (tcp_checksum_complete(skb)) + goto csum_error; + + if ((int)skb->truesize > sk->sk_forward_alloc) + goto step5; + + /* Predicted packet is in window by definition. + * seq == rcv_nxt and rcv_wup <= rcv_nxt. 
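/* A sketch (not the kernel code itself) of what the header-prediction
 * value compared above encodes.  Viewing the fourth 32-bit word of the
 * TCP header in host order as [doff:4][reserved:4][flags:8][window:16],
 * the predicted value is "expected header length, ACK set, expected
 * on-wire send window"; the receive path masks out PSH and the
 * reserved bits before comparing, and clearing the prediction forces
 * the slow path.  The names below are illustrative.
 */
#include <stdint.h>

#define TCP_ACK_BIT_IN_WORD 0x00100000u	/* ACK flag within that word */

static uint32_t build_pred_flags(unsigned int tcp_header_len, uint16_t snd_wnd)
{
	/* doff = tcp_header_len / 4 sits in the top nibble, hence
	 * (tcp_header_len >> 2) << 28 == tcp_header_len << 26.
	 */
	return ((uint32_t)tcp_header_len << 26) | TCP_ACK_BIT_IN_WORD | snd_wnd;
}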
+ * Hence, check seq<=rcv_wup reduces to: + */ + if (tcp_header_len == + (sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED) && + tp->rcv_nxt == tp->rcv_wup) + tcp_store_ts_recent(tp); + + tcp_rcv_rtt_measure_ts(sk, skb); + + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPHPHITS); + + /* Bulk data transfer: receiver */ + skb_dst_drop(skb); + __skb_pull(skb, tcp_header_len); + eaten = tcp_queue_rcv(sk, skb, &fragstolen); + + tcp_event_data_recv(sk, skb); + + if (TCP_SKB_CB(skb)->ack_seq != tp->snd_una) { + /* Well, only one small jumplet in fast path... */ + tcp_ack(sk, skb, FLAG_DATA); + tcp_data_snd_check(sk); + if (!inet_csk_ack_scheduled(sk)) + goto no_ack; + } else { + tcp_update_wl(tp, TCP_SKB_CB(skb)->seq); + } + + __tcp_ack_snd_check(sk, 0); +no_ack: + if (eaten) + kfree_skb_partial(skb, fragstolen); + tcp_data_ready(sk); + return; + } + } + +slow_path: + if (len < (th->doff << 2) || tcp_checksum_complete(skb)) + goto csum_error; + + if (!th->ack && !th->rst && !th->syn) { + reason = SKB_DROP_REASON_TCP_FLAGS; + goto discard; + } + + /* + * Standard slow path. + */ + + if (!tcp_validate_incoming(sk, skb, th, 1)) + return; + +step5: + reason = tcp_ack(sk, skb, FLAG_SLOWPATH | FLAG_UPDATE_TS_RECENT); + if ((int)reason < 0) { + reason = -reason; + goto discard; + } + tcp_rcv_rtt_measure_ts(sk, skb); + + /* Process urgent data. */ + tcp_urg(sk, skb, th); + + /* step 7: process the segment text */ + tcp_data_queue(sk, skb); + + tcp_data_snd_check(sk); + tcp_ack_snd_check(sk); + return; + +csum_error: + reason = SKB_DROP_REASON_TCP_CSUM; + trace_tcp_bad_csum(skb); + TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); + TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + +discard: + tcp_drop_reason(sk, skb, reason); +} +EXPORT_SYMBOL(tcp_rcv_established); + +void tcp_init_transfer(struct sock *sk, int bpf_op, struct sk_buff *skb) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + tcp_mtup_init(sk); + icsk->icsk_af_ops->rebuild_header(sk); + tcp_init_metrics(sk); + + /* Initialize the congestion window to start the transfer. + * Cut cwnd down to 1 per RFC5681 if SYN or SYN-ACK has been + * retransmitted. In light of RFC6298 more aggressive 1sec + * initRTO, we only reset cwnd when more than 1 SYN/SYN-ACK + * retransmission has occurred. + */ + if (tp->total_retrans > 1 && tp->undo_marker) + tcp_snd_cwnd_set(tp, 1); + else + tcp_snd_cwnd_set(tp, tcp_init_cwnd(tp, __sk_dst_get(sk))); + tp->snd_cwnd_stamp = tcp_jiffies32; + + bpf_skops_established(sk, bpf_op, skb); + /* Initialize congestion control unless BPF initialized it already: */ + if (!icsk->icsk_ca_initialized) + tcp_init_congestion_control(sk); + tcp_init_buffer_space(sk); +} + +void tcp_finish_connect(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_set_state(sk, TCP_ESTABLISHED); + icsk->icsk_ack.lrcvtime = tcp_jiffies32; + + if (skb) { + icsk->icsk_af_ops->sk_rx_dst_set(sk, skb); + security_inet_conn_established(sk, skb); + sk_mark_napi_id(sk, skb); + } + + tcp_init_transfer(sk, BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB, skb); + + /* Prevent spurious tcp_cwnd_restart() on first data + * packet. 
+ */ + tp->lsndtime = tcp_jiffies32; + + if (sock_flag(sk, SOCK_KEEPOPEN)) + inet_csk_reset_keepalive_timer(sk, keepalive_time_when(tp)); + + if (!tp->rx_opt.snd_wscale) + __tcp_fast_path_on(tp, tp->snd_wnd); + else + tp->pred_flags = 0; +} + +static bool tcp_rcv_fastopen_synack(struct sock *sk, struct sk_buff *synack, + struct tcp_fastopen_cookie *cookie) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *data = tp->syn_data ? tcp_rtx_queue_head(sk) : NULL; + u16 mss = tp->rx_opt.mss_clamp, try_exp = 0; + bool syn_drop = false; + + if (mss == tp->rx_opt.user_mss) { + struct tcp_options_received opt; + + /* Get original SYNACK MSS value if user MSS sets mss_clamp */ + tcp_clear_options(&opt); + opt.user_mss = opt.mss_clamp = 0; + tcp_parse_options(sock_net(sk), synack, &opt, 0, NULL); + mss = opt.mss_clamp; + } + + if (!tp->syn_fastopen) { + /* Ignore an unsolicited cookie */ + cookie->len = -1; + } else if (tp->total_retrans) { + /* SYN timed out and the SYN-ACK neither has a cookie nor + * acknowledges data. Presumably the remote received only + * the retransmitted (regular) SYNs: either the original + * SYN-data or the corresponding SYN-ACK was dropped. + */ + syn_drop = (cookie->len < 0 && data); + } else if (cookie->len < 0 && !tp->syn_data) { + /* We requested a cookie but didn't get it. If we did not use + * the (old) exp opt format then try so next time (try_exp=1). + * Otherwise we go back to use the RFC7413 opt (try_exp=2). + */ + try_exp = tp->syn_fastopen_exp ? 2 : 1; + } + + tcp_fastopen_cache_set(sk, mss, cookie, syn_drop, try_exp); + + if (data) { /* Retransmit unacked data in SYN */ + if (tp->total_retrans) + tp->fastopen_client_fail = TFO_SYN_RETRANSMITTED; + else + tp->fastopen_client_fail = TFO_DATA_NOT_ACKED; + skb_rbtree_walk_from(data) + tcp_mark_skb_lost(sk, data); + tcp_xmit_retransmit_queue(sk); + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPFASTOPENACTIVEFAIL); + return true; + } + tp->syn_data_acked = tp->syn_data; + if (tp->syn_data_acked) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPFASTOPENACTIVE); + /* SYN-data is counted as two separate packets in tcp_ack() */ + if (tp->delivered > 1) + --tp->delivered; + } + + tcp_fastopen_add_skb(sk, synack); + + return false; +} + +static void smc_check_reset_syn(struct tcp_sock *tp) +{ +#if IS_ENABLED(CONFIG_SMC) + if (static_branch_unlikely(&tcp_have_smc)) { + if (tp->syn_smc && !tp->rx_opt.smc_ok) + tp->syn_smc = 0; + } +#endif +} + +static void tcp_try_undo_spurious_syn(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 syn_stamp; + + /* undo_marker is set when SYN or SYNACK times out. The timeout is + * spurious if the ACK's timestamp option echo value matches the + * original SYN timestamp. 
+ */ + syn_stamp = tp->retrans_stamp; + if (tp->undo_marker && syn_stamp && tp->rx_opt.saw_tstamp && + syn_stamp == tp->rx_opt.rcv_tsecr) + tp->undo_marker = 0; +} + +static int tcp_rcv_synsent_state_process(struct sock *sk, struct sk_buff *skb, + const struct tcphdr *th) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_fastopen_cookie foc = { .len = -1 }; + int saved_clamp = tp->rx_opt.mss_clamp; + bool fastopen_fail; + SKB_DR(reason); + + tcp_parse_options(sock_net(sk), skb, &tp->rx_opt, 0, &foc); + if (tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr) + tp->rx_opt.rcv_tsecr -= tp->tsoffset; + + if (th->ack) { + /* rfc793: + * "If the state is SYN-SENT then + * first check the ACK bit + * If the ACK bit is set + * If SEG.ACK =< ISS, or SEG.ACK > SND.NXT, send + * a reset (unless the RST bit is set, if so drop + * the segment and return)" + */ + if (!after(TCP_SKB_CB(skb)->ack_seq, tp->snd_una) || + after(TCP_SKB_CB(skb)->ack_seq, tp->snd_nxt)) { + /* Previous FIN/ACK or RST/ACK might be ignored. */ + if (icsk->icsk_retransmits == 0) + inet_csk_reset_xmit_timer(sk, + ICSK_TIME_RETRANS, + TCP_TIMEOUT_MIN, TCP_RTO_MAX); + goto reset_and_undo; + } + + if (tp->rx_opt.saw_tstamp && tp->rx_opt.rcv_tsecr && + !between(tp->rx_opt.rcv_tsecr, tp->retrans_stamp, + tcp_time_stamp(tp))) { + NET_INC_STATS(sock_net(sk), + LINUX_MIB_PAWSACTIVEREJECTED); + goto reset_and_undo; + } + + /* Now ACK is acceptable. + * + * "If the RST bit is set + * If the ACK was acceptable then signal the user "error: + * connection reset", drop the segment, enter CLOSED state, + * delete TCB, and return." + */ + + if (th->rst) { + tcp_reset(sk, skb); +consume: + __kfree_skb(skb); + return 0; + } + + /* rfc793: + * "fifth, if neither of the SYN or RST bits is set then + * drop the segment and return." + * + * See note below! + * --ANK(990513) + */ + if (!th->syn) { + SKB_DR_SET(reason, TCP_FLAGS); + goto discard_and_undo; + } + /* rfc793: + * "If the SYN bit is on ... + * are acceptable then ... + * (our SYN has been ACKed), change the connection + * state to ESTABLISHED..." + */ + + tcp_ecn_rcv_synack(tp, th); + + tcp_init_wl(tp, TCP_SKB_CB(skb)->seq); + tcp_try_undo_spurious_syn(sk); + tcp_ack(sk, skb, FLAG_SLOWPATH); + + /* Ok.. it's good. Set up sequence numbers and + * move to established. + */ + WRITE_ONCE(tp->rcv_nxt, TCP_SKB_CB(skb)->seq + 1); + tp->rcv_wup = TCP_SKB_CB(skb)->seq + 1; + + /* RFC1323: The window in SYN & SYN/ACK segments is + * never scaled. + */ + tp->snd_wnd = ntohs(th->window); + + if (!tp->rx_opt.wscale_ok) { + tp->rx_opt.snd_wscale = tp->rx_opt.rcv_wscale = 0; + tp->window_clamp = min(tp->window_clamp, 65535U); + } + + if (tp->rx_opt.saw_tstamp) { + tp->rx_opt.tstamp_ok = 1; + tp->tcp_header_len = + sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED; + tp->advmss -= TCPOLEN_TSTAMP_ALIGNED; + tcp_store_ts_recent(tp); + } else { + tp->tcp_header_len = sizeof(struct tcphdr); + } + + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); + tcp_initialize_rcv_mss(sk); + + /* Remember, tcp_poll() does not lock socket! + * Change state from SYN-SENT only after copied_seq + * is initialized. 
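/* A sketch (not the kernel code itself) of the RFC 793 acceptability
 * test applied above to an ACK arriving in SYN-SENT: the peer must
 * acknowledge something we actually sent, i.e. ISS < SEG.ACK <= SND.NXT,
 * otherwise the segment earns a reset.  synsent_ack_acceptable() is an
 * illustrative name; comparisons are wrap-safe.
 */
#include <stdbool.h>
#include <stdint.h>

static bool synsent_ack_acceptable(uint32_t seg_ack, uint32_t iss, uint32_t snd_nxt)
{
	return (int32_t)(seg_ack - iss) > 0 &&		/* SEG.ACK > ISS      */
	       (int32_t)(seg_ack - snd_nxt) <= 0;	/* SEG.ACK <= SND.NXT */
}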
*/ + WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); + + smc_check_reset_syn(tp); + + smp_mb(); + + tcp_finish_connect(sk, skb); + + fastopen_fail = (tp->syn_fastopen || tp->syn_data) && + tcp_rcv_fastopen_synack(sk, skb, &foc); + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sk_wake_async(sk, SOCK_WAKE_IO, POLL_OUT); + } + if (fastopen_fail) + return -1; + if (sk->sk_write_pending || + icsk->icsk_accept_queue.rskq_defer_accept || + inet_csk_in_pingpong_mode(sk)) { + /* Save one ACK. Data will be ready after + * several ticks, if write_pending is set. + * + * It may be deleted, but with this feature tcpdumps + * look so _wonderfully_ clever, that I was not able + * to stand against the temptation 8) --ANK + */ + inet_csk_schedule_ack(sk); + tcp_enter_quickack_mode(sk, TCP_MAX_QUICKACKS); + inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, + TCP_DELACK_MAX, TCP_RTO_MAX); + goto consume; + } + tcp_send_ack(sk); + return -1; + } + + /* No ACK in the segment */ + + if (th->rst) { + /* rfc793: + * "If the RST bit is set + * + * Otherwise (no ACK) drop the segment and return." + */ + SKB_DR_SET(reason, TCP_RESET); + goto discard_and_undo; + } + + /* PAWS check. */ + if (tp->rx_opt.ts_recent_stamp && tp->rx_opt.saw_tstamp && + tcp_paws_reject(&tp->rx_opt, 0)) { + SKB_DR_SET(reason, TCP_RFC7323_PAWS); + goto discard_and_undo; + } + if (th->syn) { + /* We see SYN without ACK. It is attempt of + * simultaneous connect with crossed SYNs. + * Particularly, it can be connect to self. + */ + tcp_set_state(sk, TCP_SYN_RECV); + + if (tp->rx_opt.saw_tstamp) { + tp->rx_opt.tstamp_ok = 1; + tcp_store_ts_recent(tp); + tp->tcp_header_len = + sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED; + } else { + tp->tcp_header_len = sizeof(struct tcphdr); + } + + WRITE_ONCE(tp->rcv_nxt, TCP_SKB_CB(skb)->seq + 1); + WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); + tp->rcv_wup = TCP_SKB_CB(skb)->seq + 1; + + /* RFC1323: The window in SYN & SYN/ACK segments is + * never scaled. + */ + tp->snd_wnd = ntohs(th->window); + tp->snd_wl1 = TCP_SKB_CB(skb)->seq; + tp->max_window = tp->snd_wnd; + + tcp_ecn_rcv_syn(tp, th); + + tcp_mtup_init(sk); + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); + tcp_initialize_rcv_mss(sk); + + tcp_send_synack(sk); +#if 0 + /* Note, we could accept data and URG from this segment. + * There are no obstacles to make this (except that we must + * either change tcp_recvmsg() to prevent it from returning data + * before 3WHS completes per RFC793, or employ TCP Fast Open). + * + * However, if we ignore data in ACKless segments sometimes, + * we have no reasons to accept it sometimes. + * Also, seems the code doing it in step6 of tcp_rcv_state_process + * is not flawless. So, discard packet for sanity. + * Uncomment this return to process the data. + */ + return -1; +#else + goto consume; +#endif + } + /* "fifth, if neither of the SYN or RST bits is set then + * drop the segment and return." + */ + +discard_and_undo: + tcp_clear_options(&tp->rx_opt); + tp->rx_opt.mss_clamp = saved_clamp; + tcp_drop_reason(sk, skb, reason); + return 0; + +reset_and_undo: + tcp_clear_options(&tp->rx_opt); + tp->rx_opt.mss_clamp = saved_clamp; + return 1; +} + +static void tcp_rcv_synrecv_state_fastopen(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct request_sock *req; + + /* If we are still handling the SYNACK RTO, see if timestamp ECR allows + * undo. If peer SACKs triggered fast recovery, we can't undo here. 
+ */ + if (inet_csk(sk)->icsk_ca_state == TCP_CA_Loss && !tp->packets_out) + tcp_try_undo_recovery(sk); + + /* Reset rtx states to prevent spurious retransmits_timed_out() */ + tp->retrans_stamp = 0; + inet_csk(sk)->icsk_retransmits = 0; + + /* Once we leave TCP_SYN_RECV or TCP_FIN_WAIT_1, + * we no longer need req so release it. + */ + req = rcu_dereference_protected(tp->fastopen_rsk, + lockdep_sock_is_held(sk)); + reqsk_fastopen_remove(sk, req, false); + + /* Re-arm the timer because data may have been sent out. + * This is similar to the regular data transmission case + * when new data has just been ack'ed. + * + * (TFO) - we could try to be more aggressive and + * retransmitting any data sooner based on when they + * are sent out. + */ + tcp_rearm_rto(sk); +} + +/* + * This function implements the receiving procedure of RFC 793 for + * all states except ESTABLISHED and TIME_WAIT. + * It's called from both tcp_v4_rcv and tcp_v6_rcv and should be + * address independent. + */ + +int tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcphdr *th = tcp_hdr(skb); + struct request_sock *req; + int queued = 0; + bool acceptable; + SKB_DR(reason); + + switch (sk->sk_state) { + case TCP_CLOSE: + SKB_DR_SET(reason, TCP_CLOSE); + goto discard; + + case TCP_LISTEN: + if (th->ack) + return 1; + + if (th->rst) { + SKB_DR_SET(reason, TCP_RESET); + goto discard; + } + if (th->syn) { + if (th->fin) { + SKB_DR_SET(reason, TCP_FLAGS); + goto discard; + } + /* It is possible that we process SYN packets from backlog, + * so we need to make sure to disable BH and RCU right there. + */ + rcu_read_lock(); + local_bh_disable(); + acceptable = icsk->icsk_af_ops->conn_request(sk, skb) >= 0; + local_bh_enable(); + rcu_read_unlock(); + + if (!acceptable) + return 1; + consume_skb(skb); + return 0; + } + SKB_DR_SET(reason, TCP_FLAGS); + goto discard; + + case TCP_SYN_SENT: + tp->rx_opt.saw_tstamp = 0; + tcp_mstamp_refresh(tp); + queued = tcp_rcv_synsent_state_process(sk, skb, th); + if (queued >= 0) + return queued; + + /* Do step6 onward by hand. 
*/ + tcp_urg(sk, skb, th); + __kfree_skb(skb); + tcp_data_snd_check(sk); + return 0; + } + + tcp_mstamp_refresh(tp); + tp->rx_opt.saw_tstamp = 0; + req = rcu_dereference_protected(tp->fastopen_rsk, + lockdep_sock_is_held(sk)); + if (req) { + bool req_stolen; + + WARN_ON_ONCE(sk->sk_state != TCP_SYN_RECV && + sk->sk_state != TCP_FIN_WAIT1); + + if (!tcp_check_req(sk, skb, req, true, &req_stolen)) { + SKB_DR_SET(reason, TCP_FASTOPEN); + goto discard; + } + } + + if (!th->ack && !th->rst && !th->syn) { + SKB_DR_SET(reason, TCP_FLAGS); + goto discard; + } + if (!tcp_validate_incoming(sk, skb, th, 0)) + return 0; + + /* step 5: check the ACK field */ + acceptable = tcp_ack(sk, skb, FLAG_SLOWPATH | + FLAG_UPDATE_TS_RECENT | + FLAG_NO_CHALLENGE_ACK) > 0; + + if (!acceptable) { + if (sk->sk_state == TCP_SYN_RECV) + return 1; /* send one RST */ + tcp_send_challenge_ack(sk); + SKB_DR_SET(reason, TCP_OLD_ACK); + goto discard; + } + switch (sk->sk_state) { + case TCP_SYN_RECV: + tp->delivered++; /* SYN-ACK delivery isn't tracked in tcp_ack */ + if (!tp->srtt_us) + tcp_synack_rtt_meas(sk, req); + + if (req) { + tcp_rcv_synrecv_state_fastopen(sk); + } else { + tcp_try_undo_spurious_syn(sk); + tp->retrans_stamp = 0; + tcp_init_transfer(sk, BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB, + skb); + WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); + } + smp_mb(); + tcp_set_state(sk, TCP_ESTABLISHED); + sk->sk_state_change(sk); + + /* Note, that this wakeup is only for marginal crossed SYN case. + * Passively open sockets are not waked up, because + * sk->sk_sleep == NULL and sk->sk_socket == NULL. + */ + if (sk->sk_socket) + sk_wake_async(sk, SOCK_WAKE_IO, POLL_OUT); + + tp->snd_una = TCP_SKB_CB(skb)->ack_seq; + tp->snd_wnd = ntohs(th->window) << tp->rx_opt.snd_wscale; + tcp_init_wl(tp, TCP_SKB_CB(skb)->seq); + + if (tp->rx_opt.tstamp_ok) + tp->advmss -= TCPOLEN_TSTAMP_ALIGNED; + + if (!inet_csk(sk)->icsk_ca_ops->cong_control) + tcp_update_pacing_rate(sk); + + /* Prevent spurious tcp_cwnd_restart() on first data packet */ + tp->lsndtime = tcp_jiffies32; + + tcp_initialize_rcv_mss(sk); + tcp_fast_path_on(tp); + break; + + case TCP_FIN_WAIT1: { + int tmo; + + if (req) + tcp_rcv_synrecv_state_fastopen(sk); + + if (tp->snd_una != tp->write_seq) + break; + + tcp_set_state(sk, TCP_FIN_WAIT2); + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | SEND_SHUTDOWN); + + sk_dst_confirm(sk); + + if (!sock_flag(sk, SOCK_DEAD)) { + /* Wake up lingering close() */ + sk->sk_state_change(sk); + break; + } + + if (tp->linger2 < 0) { + tcp_done(sk); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); + return 1; + } + if (TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq && + after(TCP_SKB_CB(skb)->end_seq - th->fin, tp->rcv_nxt)) { + /* Receive out of order FIN after close() */ + if (tp->syn_fastopen && th->fin) + tcp_fastopen_active_disable(sk); + tcp_done(sk); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); + return 1; + } + + tmo = tcp_fin_time(sk); + if (tmo > TCP_TIMEWAIT_LEN) { + inet_csk_reset_keepalive_timer(sk, tmo - TCP_TIMEWAIT_LEN); + } else if (th->fin || sock_owned_by_user(sk)) { + /* Bad case. We could lose such FIN otherwise. + * It is not a big problem, but it looks confusing + * and not so rare event. We still can lose it now, + * if it spins in bh_lock_sock(), but it is really + * marginal case. 
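+				 *
+				 * (Illustrative aside: tp->linger2, checked above, is
+				 * set with the TCP_LINGER2 socket option and defaults
+				 * to the net.ipv4.tcp_fin_timeout sysctl; e.g.
+				 *
+				 *	int secs = 10;
+				 *	setsockopt(fd, IPPROTO_TCP, TCP_LINGER2, &secs, sizeof(secs));
+				 *
+				 * bounds how long this socket may sit in FIN_WAIT2
+				 * here, and a negative value makes the branch above
+				 * close it straight away.)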
+ */ + inet_csk_reset_keepalive_timer(sk, tmo); + } else { + tcp_time_wait(sk, TCP_FIN_WAIT2, tmo); + goto consume; + } + break; + } + + case TCP_CLOSING: + if (tp->snd_una == tp->write_seq) { + tcp_time_wait(sk, TCP_TIME_WAIT, 0); + goto consume; + } + break; + + case TCP_LAST_ACK: + if (tp->snd_una == tp->write_seq) { + tcp_update_metrics(sk); + tcp_done(sk); + goto consume; + } + break; + } + + /* step 6: check the URG bit */ + tcp_urg(sk, skb, th); + + /* step 7: process the segment text */ + switch (sk->sk_state) { + case TCP_CLOSE_WAIT: + case TCP_CLOSING: + case TCP_LAST_ACK: + if (!before(TCP_SKB_CB(skb)->seq, tp->rcv_nxt)) { + /* If a subflow has been reset, the packet should not + * continue to be processed, drop the packet. + */ + if (sk_is_mptcp(sk) && !mptcp_incoming_options(sk, skb)) + goto discard; + break; + } + fallthrough; + case TCP_FIN_WAIT1: + case TCP_FIN_WAIT2: + /* RFC 793 says to queue data in these states, + * RFC 1122 says we MUST send a reset. + * BSD 4.4 also does reset. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN) { + if (TCP_SKB_CB(skb)->end_seq != TCP_SKB_CB(skb)->seq && + after(TCP_SKB_CB(skb)->end_seq - th->fin, tp->rcv_nxt)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONDATA); + tcp_reset(sk, skb); + return 1; + } + } + fallthrough; + case TCP_ESTABLISHED: + tcp_data_queue(sk, skb); + queued = 1; + break; + } + + /* tcp_data could move socket to TIME-WAIT */ + if (sk->sk_state != TCP_CLOSE) { + tcp_data_snd_check(sk); + tcp_ack_snd_check(sk); + } + + if (!queued) { +discard: + tcp_drop_reason(sk, skb, reason); + } + return 0; + +consume: + __kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(tcp_rcv_state_process); + +static inline void pr_drop_req(struct request_sock *req, __u16 port, int family) +{ + struct inet_request_sock *ireq = inet_rsk(req); + + if (family == AF_INET) + net_dbg_ratelimited("drop open request from %pI4/%u\n", + &ireq->ir_rmt_addr, port); +#if IS_ENABLED(CONFIG_IPV6) + else if (family == AF_INET6) + net_dbg_ratelimited("drop open request from %pI6/%u\n", + &ireq->ir_v6_rmt_addr, port); +#endif +} + +/* RFC3168 : 6.1.1 SYN packets must not have ECT/ECN bits set + * + * If we receive a SYN packet with these bits set, it means a + * network is playing bad games with TOS bits. In order to + * avoid possible false congestion notifications, we disable + * TCP ECN negotiation. + * + * Exception: tcp_ca wants ECN. This is required for DCTCP + * congestion control: Linux DCTCP asserts ECT on all packets, + * including SYN, which is most optimal solution; however, + * others, such as FreeBSD do not. + * + * Exception: At least one of the reserved bits of the TCP header (th->res1) is + * set, indicating the use of a future TCP extension (such as AccECN). See + * RFC8311 §4.3 which updates RFC3168 to allow the development of such + * extensions. 
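+ *
+ * (For reference: the sysctl_tcp_ecn value consulted below is the
+ * net.ipv4.tcp_ecn sysctl - 0 disables negotiation, 1 requests ECN on
+ * outgoing connections and accepts it on incoming ones, 2 (the usual
+ * default) only accepts it when the peer requests it, which is the
+ * listener-side case handled here.)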
+ */ +static void tcp_ecn_create_request(struct request_sock *req, + const struct sk_buff *skb, + const struct sock *listen_sk, + const struct dst_entry *dst) +{ + const struct tcphdr *th = tcp_hdr(skb); + const struct net *net = sock_net(listen_sk); + bool th_ecn = th->ece && th->cwr; + bool ect, ecn_ok; + u32 ecn_ok_dst; + + if (!th_ecn) + return; + + ect = !INET_ECN_is_not_ect(TCP_SKB_CB(skb)->ip_dsfield); + ecn_ok_dst = dst_feature(dst, DST_FEATURE_ECN_MASK); + ecn_ok = READ_ONCE(net->ipv4.sysctl_tcp_ecn) || ecn_ok_dst; + + if (((!ect || th->res1) && ecn_ok) || tcp_ca_needs_ecn(listen_sk) || + (ecn_ok_dst & DST_FEATURE_ECN_CA) || + tcp_bpf_ca_needs_ecn((struct sock *)req)) + inet_rsk(req)->ecn_ok = 1; +} + +static void tcp_openreq_init(struct request_sock *req, + const struct tcp_options_received *rx_opt, + struct sk_buff *skb, const struct sock *sk) +{ + struct inet_request_sock *ireq = inet_rsk(req); + + req->rsk_rcv_wnd = 0; /* So that tcp_send_synack() knows! */ + tcp_rsk(req)->rcv_isn = TCP_SKB_CB(skb)->seq; + tcp_rsk(req)->rcv_nxt = TCP_SKB_CB(skb)->seq + 1; + tcp_rsk(req)->snt_synack = 0; + tcp_rsk(req)->last_oow_ack_time = 0; + req->mss = rx_opt->mss_clamp; + req->ts_recent = rx_opt->saw_tstamp ? rx_opt->rcv_tsval : 0; + ireq->tstamp_ok = rx_opt->tstamp_ok; + ireq->sack_ok = rx_opt->sack_ok; + ireq->snd_wscale = rx_opt->snd_wscale; + ireq->wscale_ok = rx_opt->wscale_ok; + ireq->acked = 0; + ireq->ecn_ok = 0; + ireq->ir_rmt_port = tcp_hdr(skb)->source; + ireq->ir_num = ntohs(tcp_hdr(skb)->dest); + ireq->ir_mark = inet_request_mark(sk, skb); +#if IS_ENABLED(CONFIG_SMC) + ireq->smc_ok = rx_opt->smc_ok && !(tcp_sk(sk)->smc_hs_congested && + tcp_sk(sk)->smc_hs_congested(sk)); +#endif +} + +struct request_sock *inet_reqsk_alloc(const struct request_sock_ops *ops, + struct sock *sk_listener, + bool attach_listener) +{ + struct request_sock *req = reqsk_alloc(ops, sk_listener, + attach_listener); + + if (req) { + struct inet_request_sock *ireq = inet_rsk(req); + + ireq->ireq_opt = NULL; +#if IS_ENABLED(CONFIG_IPV6) + ireq->pktopts = NULL; +#endif + atomic64_set(&ireq->ir_cookie, 0); + ireq->ireq_state = TCP_NEW_SYN_RECV; + write_pnet(&ireq->ireq_net, sock_net(sk_listener)); + ireq->ireq_family = sk_listener->sk_family; + req->timeout = TCP_TIMEOUT_INIT; + } + + return req; +} +EXPORT_SYMBOL(inet_reqsk_alloc); + +/* + * Return true if a syncookie should be sent + */ +static bool tcp_syn_flood_action(const struct sock *sk, const char *proto) +{ + struct request_sock_queue *queue = &inet_csk(sk)->icsk_accept_queue; + const char *msg = "Dropping request"; + struct net *net = sock_net(sk); + bool want_cookie = false; + u8 syncookies; + + syncookies = READ_ONCE(net->ipv4.sysctl_tcp_syncookies); + +#ifdef CONFIG_SYN_COOKIES + if (syncookies) { + msg = "Sending cookies"; + want_cookie = true; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPREQQFULLDOCOOKIES); + } else +#endif + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPREQQFULLDROP); + + if (!queue->synflood_warned && syncookies != 2 && + xchg(&queue->synflood_warned, 1) == 0) + net_info_ratelimited("%s: Possible SYN flooding on port %d. %s. Check SNMP counters.\n", + proto, sk->sk_num, msg); + + return want_cookie; +} + +static void tcp_reqsk_record_syn(const struct sock *sk, + struct request_sock *req, + const struct sk_buff *skb) +{ + if (tcp_sk(sk)->save_syn) { + u32 len = skb_network_header_len(skb) + tcp_hdrlen(skb); + struct saved_syn *saved_syn; + u32 mac_hdrlen; + void *base; + + if (tcp_sk(sk)->save_syn == 2) { /* Save full header. 
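+		 * (save_syn == 2 corresponds to setsockopt(TCP_SAVE_SYN) with a
+		 * value of 2; the saved headers can later be read back with
+		 * getsockopt(TCP_SAVED_SYN). Illustrative aside only.)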
*/ + base = skb_mac_header(skb); + mac_hdrlen = skb_mac_header_len(skb); + len += mac_hdrlen; + } else { + base = skb_network_header(skb); + mac_hdrlen = 0; + } + + saved_syn = kmalloc(struct_size(saved_syn, data, len), + GFP_ATOMIC); + if (saved_syn) { + saved_syn->mac_hdrlen = mac_hdrlen; + saved_syn->network_hdrlen = skb_network_header_len(skb); + saved_syn->tcp_hdrlen = tcp_hdrlen(skb); + memcpy(saved_syn->data, base, len); + req->saved_syn = saved_syn; + } + } +} + +/* If a SYN cookie is required and supported, returns a clamped MSS value to be + * used for SYN cookie generation. + */ +u16 tcp_get_syncookie_mss(struct request_sock_ops *rsk_ops, + const struct tcp_request_sock_ops *af_ops, + struct sock *sk, struct tcphdr *th) +{ + struct tcp_sock *tp = tcp_sk(sk); + u16 mss; + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies) != 2 && + !inet_csk_reqsk_queue_is_full(sk)) + return 0; + + if (!tcp_syn_flood_action(sk, rsk_ops->slab_name)) + return 0; + + if (sk_acceptq_is_full(sk)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); + return 0; + } + + mss = tcp_parse_mss_option(th, tp->rx_opt.user_mss); + if (!mss) + mss = af_ops->mss_clamp; + + return mss; +} +EXPORT_SYMBOL_GPL(tcp_get_syncookie_mss); + +int tcp_conn_request(struct request_sock_ops *rsk_ops, + const struct tcp_request_sock_ops *af_ops, + struct sock *sk, struct sk_buff *skb) +{ + struct tcp_fastopen_cookie foc = { .len = -1 }; + __u32 isn = TCP_SKB_CB(skb)->tcp_tw_isn; + struct tcp_options_received tmp_opt; + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct sock *fastopen_sk = NULL; + struct request_sock *req; + bool want_cookie = false; + struct dst_entry *dst; + struct flowi fl; + u8 syncookies; + + syncookies = READ_ONCE(net->ipv4.sysctl_tcp_syncookies); + + /* TW buckets are converted to open requests without + * limitations, they conserve resources and peer is + * evidently real one. + */ + if ((syncookies == 2 || inet_csk_reqsk_queue_is_full(sk)) && !isn) { + want_cookie = tcp_syn_flood_action(sk, rsk_ops->slab_name); + if (!want_cookie) + goto drop; + } + + if (sk_acceptq_is_full(sk)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); + goto drop; + } + + req = inet_reqsk_alloc(rsk_ops, sk, !want_cookie); + if (!req) + goto drop; + + req->syncookie = want_cookie; + tcp_rsk(req)->af_specific = af_ops; + tcp_rsk(req)->ts_off = 0; +#if IS_ENABLED(CONFIG_MPTCP) + tcp_rsk(req)->is_mptcp = 0; +#endif + + tcp_clear_options(&tmp_opt); + tmp_opt.mss_clamp = af_ops->mss_clamp; + tmp_opt.user_mss = tp->rx_opt.user_mss; + tcp_parse_options(sock_net(sk), skb, &tmp_opt, 0, + want_cookie ? NULL : &foc); + + if (want_cookie && !tmp_opt.saw_tstamp) + tcp_clear_options(&tmp_opt); + + if (IS_ENABLED(CONFIG_SMC) && want_cookie) + tmp_opt.smc_ok = 0; + + tmp_opt.tstamp_ok = tmp_opt.saw_tstamp; + tcp_openreq_init(req, &tmp_opt, skb, sk); + inet_rsk(req)->no_srccheck = inet_sk(sk)->transparent; + + /* Note: tcp_v6_init_req() might override ir_iif for link locals */ + inet_rsk(req)->ir_iif = inet_request_bound_dev_if(sk, skb); + + dst = af_ops->route_req(sk, skb, &fl, req); + if (!dst) + goto drop_and_free; + + if (tmp_opt.tstamp_ok) + tcp_rsk(req)->ts_off = af_ops->init_ts_off(net, skb); + + if (!want_cookie && !isn) { + int max_syn_backlog = READ_ONCE(net->ipv4.sysctl_max_syn_backlog); + + /* Kill the following clause, if you dislike this way. 
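+		 * (Descriptive aside: max_syn_backlog is the
+		 * net.ipv4.tcp_max_syn_backlog sysctl, and tcp_peer_is_proven()
+		 * consults the TCP metrics cache, e.g. "ip tcp_metrics show".)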
*/ + if (!syncookies && + (max_syn_backlog - inet_csk_reqsk_queue_len(sk) < + (max_syn_backlog >> 2)) && + !tcp_peer_is_proven(req, dst)) { + /* Without syncookies last quarter of + * backlog is filled with destinations, + * proven to be alive. + * It means that we continue to communicate + * to destinations, already remembered + * to the moment of synflood. + */ + pr_drop_req(req, ntohs(tcp_hdr(skb)->source), + rsk_ops->family); + goto drop_and_release; + } + + isn = af_ops->init_seq(skb); + } + + tcp_ecn_create_request(req, skb, sk, dst); + + if (want_cookie) { + isn = cookie_init_sequence(af_ops, sk, skb, &req->mss); + if (!tmp_opt.tstamp_ok) + inet_rsk(req)->ecn_ok = 0; + } + + tcp_rsk(req)->snt_isn = isn; + tcp_rsk(req)->txhash = net_tx_rndhash(); + tcp_rsk(req)->syn_tos = TCP_SKB_CB(skb)->ip_dsfield; + tcp_openreq_init_rwin(req, sk, dst); + sk_rx_queue_set(req_to_sk(req), skb); + if (!want_cookie) { + tcp_reqsk_record_syn(sk, req, skb); + fastopen_sk = tcp_try_fastopen(sk, skb, req, &foc, dst); + } + if (fastopen_sk) { + af_ops->send_synack(fastopen_sk, dst, &fl, req, + &foc, TCP_SYNACK_FASTOPEN, skb); + /* Add the child socket directly into the accept queue */ + if (!inet_csk_reqsk_queue_add(sk, req, fastopen_sk)) { + reqsk_fastopen_remove(fastopen_sk, req, false); + bh_unlock_sock(fastopen_sk); + sock_put(fastopen_sk); + goto drop_and_free; + } + sk->sk_data_ready(sk); + bh_unlock_sock(fastopen_sk); + sock_put(fastopen_sk); + } else { + tcp_rsk(req)->tfo_listener = false; + if (!want_cookie) { + req->timeout = tcp_timeout_init((struct sock *)req); + inet_csk_reqsk_queue_hash_add(sk, req, req->timeout); + } + af_ops->send_synack(sk, dst, &fl, req, &foc, + !want_cookie ? TCP_SYNACK_NORMAL : + TCP_SYNACK_COOKIE, + skb); + if (want_cookie) { + reqsk_free(req); + return 0; + } + } + reqsk_put(req); + return 0; + +drop_and_release: + dst_release(dst); +drop_and_free: + __reqsk_free(req); +drop: + tcp_listendrop(sk); + return 0; +} +EXPORT_SYMBOL(tcp_conn_request); diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c new file mode 100644 index 000000000..be2c807ee --- /dev/null +++ b/net/ipv4/tcp_ipv4.c @@ -0,0 +1,3349 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * IPv4 specific functions + * + * code split from: + * linux/ipv4/tcp.c + * linux/ipv4/tcp_input.c + * linux/ipv4/tcp_output.c + * + * See tcp.c for author information + */ + +/* + * Changes: + * David S. Miller : New socket lookup architecture. + * This code is dedicated to John Dyson. + * David S. Miller : Change semantics of established hash, + * half is devoted to TIME_WAIT sockets + * and the rest go in the other half. + * Andi Kleen : Add support for syncookies and fixed + * some bugs: ip options weren't passed to + * the TCP layer, missed a check for an + * ACK bit. + * Andi Kleen : Implemented fast path mtu discovery. + * Fixed many serious bugs in the + * request_sock handling and moved + * most of it into the af independent code. + * Added tail drop and some other bugfixes. + * Added new listen semantics. + * Mike McLagan : Routing by source + * Juan Jose Ciarlante: ip_dynaddr bits + * Andi Kleen: various fixes. + * Vitaly E. Lavrov : Transparent proxy revived after year + * coma. + * Andi Kleen : Fix new listen. 
+ * Andi Kleen : Fix accept error reporting. + * YOSHIFUJI Hideaki @USAGI and: Support IPV6_V6ONLY socket option, which + * Alexey Kuznetsov allow both IPv4 and IPv6 sockets to bind + * a single port at the same time. + */ + +#define pr_fmt(fmt) "TCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#ifdef CONFIG_TCP_MD5SIG +static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, + __be32 daddr, __be32 saddr, const struct tcphdr *th); +#endif + +struct inet_hashinfo tcp_hashinfo; +EXPORT_SYMBOL(tcp_hashinfo); + +static DEFINE_PER_CPU(struct sock *, ipv4_tcp_sk); + +static u32 tcp_v4_init_seq(const struct sk_buff *skb) +{ + return secure_tcp_seq(ip_hdr(skb)->daddr, + ip_hdr(skb)->saddr, + tcp_hdr(skb)->dest, + tcp_hdr(skb)->source); +} + +static u32 tcp_v4_init_ts_off(const struct net *net, const struct sk_buff *skb) +{ + return secure_tcp_ts_off(net, ip_hdr(skb)->daddr, ip_hdr(skb)->saddr); +} + +int tcp_twsk_unique(struct sock *sk, struct sock *sktw, void *twp) +{ + int reuse = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tw_reuse); + const struct inet_timewait_sock *tw = inet_twsk(sktw); + const struct tcp_timewait_sock *tcptw = tcp_twsk(sktw); + struct tcp_sock *tp = tcp_sk(sk); + + if (reuse == 2) { + /* Still does not detect *everything* that goes through + * lo, since we require a loopback src or dst address + * or direct binding to 'lo' interface. + */ + bool loopback = false; + if (tw->tw_bound_dev_if == LOOPBACK_IFINDEX) + loopback = true; +#if IS_ENABLED(CONFIG_IPV6) + if (tw->tw_family == AF_INET6) { + if (ipv6_addr_loopback(&tw->tw_v6_daddr) || + ipv6_addr_v4mapped_loopback(&tw->tw_v6_daddr) || + ipv6_addr_loopback(&tw->tw_v6_rcv_saddr) || + ipv6_addr_v4mapped_loopback(&tw->tw_v6_rcv_saddr)) + loopback = true; + } else +#endif + { + if (ipv4_is_loopback(tw->tw_daddr) || + ipv4_is_loopback(tw->tw_rcv_saddr)) + loopback = true; + } + if (!loopback) + reuse = 0; + } + + /* With PAWS, it is safe from the viewpoint + of data integrity. Even without PAWS it is safe provided sequence + spaces do not overlap i.e. at data rates <= 80Mbit/sec. + + Actually, the idea is close to VJ's one, only timestamp cache is + held not per host, but per port pair and TW bucket is used as state + holder. + + If TW bucket has been already destroyed we fall back to VJ's scheme + and use initial timestamp retrieved from peer table. + */ + if (tcptw->tw_ts_recent_stamp && + (!twp || (reuse && time_after32(ktime_get_seconds(), + tcptw->tw_ts_recent_stamp)))) { + /* In case of repair and re-using TIME-WAIT sockets we still + * want to be sure that it is safe as above but honor the + * sequence numbers and time stamps set as part of the repair + * process. + * + * Without this check re-using a TIME-WAIT socket with TCP + * repair would accumulate a -1 on the repair assigned + * sequence number. The first time it is reused the sequence + * is -1, the second time -2, etc. This fixes that issue + * without appearing to create any others. 
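+	 *
+	 * (Illustrative aside: "reuse" above follows the net.ipv4.tcp_tw_reuse
+	 * sysctl - 0 off, 1 on, 2 only for loopback traffic, matching the
+	 * loopback check earlier in this function - while tp->repair is set
+	 * through the TCP_REPAIR socket option, e.g. by checkpoint/restore
+	 * tools.)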
+ */ + if (likely(!tp->repair)) { + u32 seq = tcptw->tw_snd_nxt + 65535 + 2; + + if (!seq) + seq = 1; + WRITE_ONCE(tp->write_seq, seq); + tp->rx_opt.ts_recent = tcptw->tw_ts_recent; + tp->rx_opt.ts_recent_stamp = tcptw->tw_ts_recent_stamp; + } + sock_hold(sktw); + return 1; + } + + return 0; +} +EXPORT_SYMBOL_GPL(tcp_twsk_unique); + +static int tcp_v4_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from tcp_v4_connect() and intended to + * prevent BPF program called below from accessing bytes that are out + * of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + sock_owned_by_me(sk); + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT(sk, uaddr); +} + +/* This will initiate an outgoing connection. */ +int tcp_v4_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct sockaddr_in *usin = (struct sockaddr_in *)uaddr; + struct inet_timewait_death_row *tcp_death_row; + struct inet_sock *inet = inet_sk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct ip_options_rcu *inet_opt; + struct net *net = sock_net(sk); + __be16 orig_sport, orig_dport; + __be32 daddr, nexthop; + struct flowi4 *fl4; + struct rtable *rt; + int err; + + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + if (usin->sin_family != AF_INET) + return -EAFNOSUPPORT; + + nexthop = daddr = usin->sin_addr.s_addr; + inet_opt = rcu_dereference_protected(inet->inet_opt, + lockdep_sock_is_held(sk)); + if (inet_opt && inet_opt->opt.srr) { + if (!daddr) + return -EINVAL; + nexthop = inet_opt->opt.faddr; + } + + orig_sport = inet->inet_sport; + orig_dport = usin->sin_port; + fl4 = &inet->cork.fl.u.ip4; + rt = ip_route_connect(fl4, nexthop, inet->inet_saddr, + sk->sk_bound_dev_if, IPPROTO_TCP, orig_sport, + orig_dport, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + if (err == -ENETUNREACH) + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + return err; + } + + if (rt->rt_flags & (RTCF_MULTICAST | RTCF_BROADCAST)) { + ip_rt_put(rt); + return -ENETUNREACH; + } + + if (!inet_opt || !inet_opt->opt.srr) + daddr = fl4->daddr; + + tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; + + if (!inet->inet_saddr) { + err = inet_bhash2_update_saddr(sk, &fl4->saddr, AF_INET); + if (err) { + ip_rt_put(rt); + return err; + } + } else { + sk_rcv_saddr_set(sk, inet->inet_saddr); + } + + if (tp->rx_opt.ts_recent_stamp && inet->inet_daddr != daddr) { + /* Reset inherited state */ + tp->rx_opt.ts_recent = 0; + tp->rx_opt.ts_recent_stamp = 0; + if (likely(!tp->repair)) + WRITE_ONCE(tp->write_seq, 0); + } + + inet->inet_dport = usin->sin_port; + sk_daddr_set(sk, daddr); + + inet_csk(sk)->icsk_ext_hdr_len = 0; + if (inet_opt) + inet_csk(sk)->icsk_ext_hdr_len = inet_opt->opt.optlen; + + tp->rx_opt.mss_clamp = TCP_MSS_DEFAULT; + + /* Socket identity is still unknown (sport may be zero). + * However we set state to SYN-SENT and not releasing socket + * lock select source port, enter ourselves into the hash tables and + * complete initialization after this. + */ + tcp_set_state(sk, TCP_SYN_SENT); + err = inet_hash_connect(tcp_death_row, sk); + if (err) + goto failure; + + sk_set_txhash(sk); + + rt = ip_route_newports(fl4, rt, orig_sport, orig_dport, + inet->inet_sport, inet->inet_dport, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + goto failure; + } + /* OK, now commit destination to socket. 
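+	 * (For orientation only: this function is the kernel side of
+	 * connect(2) on an IPv4 TCP socket, i.e. roughly
+	 *
+	 *	int fd = socket(AF_INET, SOCK_STREAM, 0);
+	 *	connect(fd, (struct sockaddr *)&sin, sizeof(sin));
+	 *
+	 * with &sin a caller-filled struct sockaddr_in; the SYN itself is
+	 * sent by tcp_connect() a few lines below.)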
*/ + sk->sk_gso_type = SKB_GSO_TCPV4; + sk_setup_caps(sk, &rt->dst); + rt = NULL; + + if (likely(!tp->repair)) { + if (!tp->write_seq) + WRITE_ONCE(tp->write_seq, + secure_tcp_seq(inet->inet_saddr, + inet->inet_daddr, + inet->inet_sport, + usin->sin_port)); + WRITE_ONCE(tp->tsoffset, + secure_tcp_ts_off(net, inet->inet_saddr, + inet->inet_daddr)); + } + + atomic_set(&inet->inet_id, get_random_u16()); + + if (tcp_fastopen_defer_connect(sk, &err)) + return err; + if (err) + goto failure; + + err = tcp_connect(sk); + + if (err) + goto failure; + + return 0; + +failure: + /* + * This unhashes the socket and releases the local port, + * if necessary. + */ + tcp_set_state(sk, TCP_CLOSE); + inet_bhash2_reset_saddr(sk); + ip_rt_put(rt); + sk->sk_route_caps = 0; + inet->inet_dport = 0; + return err; +} +EXPORT_SYMBOL(tcp_v4_connect); + +/* + * This routine reacts to ICMP_FRAG_NEEDED mtu indications as defined in RFC1191. + * It can be called through tcp_release_cb() if socket was owned by user + * at the time tcp_v4_err() was called to handle ICMP message. + */ +void tcp_v4_mtu_reduced(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct dst_entry *dst; + u32 mtu; + + if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) + return; + mtu = READ_ONCE(tcp_sk(sk)->mtu_info); + dst = inet_csk_update_pmtu(sk, mtu); + if (!dst) + return; + + /* Something is about to be wrong... Remember soft error + * for the case, if this connection will not able to recover. + */ + if (mtu < dst_mtu(dst) && ip_dont_fragment(sk, dst)) + sk->sk_err_soft = EMSGSIZE; + + mtu = dst_mtu(dst); + + if (inet->pmtudisc != IP_PMTUDISC_DONT && + ip_sk_accept_pmtu(sk) && + inet_csk(sk)->icsk_pmtu_cookie > mtu) { + tcp_sync_mss(sk, mtu); + + /* Resend the TCP packet because it's + * clear that the old packet has been + * dropped. This is the new "fast" path mtu + * discovery. + */ + tcp_simple_retransmit(sk); + } /* else let the usual retransmit timer handle it */ +} +EXPORT_SYMBOL(tcp_v4_mtu_reduced); + +static void do_redirect(struct sk_buff *skb, struct sock *sk) +{ + struct dst_entry *dst = __sk_dst_check(sk, 0); + + if (dst) + dst->ops->redirect(dst, sk, skb); +} + + +/* handle ICMP messages on TCP_NEW_SYN_RECV request sockets */ +void tcp_req_err(struct sock *sk, u32 seq, bool abort) +{ + struct request_sock *req = inet_reqsk(sk); + struct net *net = sock_net(sk); + + /* ICMPs are not backlogged, hence we cannot get + * an established socket here. + */ + if (seq != tcp_rsk(req)->snt_isn) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + } else if (abort) { + /* + * Still in SYN_RECV, just remove it silently. + * There is no good way to pass the error to the newly + * created socket, and POSIX does not want network + * errors returned from accept(). + */ + inet_csk_reqsk_queue_drop(req->rsk_listener, req); + tcp_listendrop(req->rsk_listener); + } + reqsk_put(req); +} +EXPORT_SYMBOL(tcp_req_err); + +/* TCP-LD (RFC 6069) logic */ +void tcp_ld_RTO_revert(struct sock *sk, u32 seq) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + s32 remaining; + u32 delta_us; + + if (sock_owned_by_user(sk)) + return; + + if (seq != tp->snd_una || !icsk->icsk_retransmits || + !icsk->icsk_backoff) + return; + + skb = tcp_rtx_queue_head(sk); + if (WARN_ON_ONCE(!skb)) + return; + + icsk->icsk_backoff--; + icsk->icsk_rto = tp->srtt_us ? 
__tcp_set_rto(tp) : TCP_TIMEOUT_INIT; + icsk->icsk_rto = inet_csk_rto_backoff(icsk, TCP_RTO_MAX); + + tcp_mstamp_refresh(tp); + delta_us = (u32)(tp->tcp_mstamp - tcp_skb_timestamp_us(skb)); + remaining = icsk->icsk_rto - usecs_to_jiffies(delta_us); + + if (remaining > 0) { + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + remaining, TCP_RTO_MAX); + } else { + /* RTO revert clocked out retransmission. + * Will retransmit now. + */ + tcp_retransmit_timer(sk); + } +} +EXPORT_SYMBOL(tcp_ld_RTO_revert); + +/* + * This routine is called by the ICMP module when it gets some + * sort of error condition. If err < 0 then the socket should + * be closed and the error returned to the user. If err > 0 + * it's just the icmp type << 8 | icmp code. After adjustment + * header points to the first 8 bytes of the tcp header. We need + * to find the appropriate port. + * + * The locking strategy used here is very "optimistic". When + * someone else accesses the socket the ICMP is just dropped + * and for some paths there is no check at all. + * A more general error queue to queue errors for later handling + * is probably better. + * + */ + +int tcp_v4_err(struct sk_buff *skb, u32 info) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct tcphdr *th = (struct tcphdr *)(skb->data + (iph->ihl << 2)); + struct tcp_sock *tp; + struct inet_sock *inet; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + struct sock *sk; + struct request_sock *fastopen; + u32 seq, snd_una; + int err; + struct net *net = dev_net(skb->dev); + + sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, + iph->daddr, th->dest, iph->saddr, + ntohs(th->source), inet_iif(skb), 0); + if (!sk) { + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return -ENOENT; + } + if (sk->sk_state == TCP_TIME_WAIT) { + inet_twsk_put(inet_twsk(sk)); + return 0; + } + seq = ntohl(th->seq); + if (sk->sk_state == TCP_NEW_SYN_RECV) { + tcp_req_err(sk, seq, type == ICMP_PARAMETERPROB || + type == ICMP_TIME_EXCEEDED || + (type == ICMP_DEST_UNREACH && + (code == ICMP_NET_UNREACH || + code == ICMP_HOST_UNREACH))); + return 0; + } + + bh_lock_sock(sk); + /* If too many ICMPs get dropped on busy + * servers this needs to be solved differently. + * We do take care of PMTU discovery (RFC1191) special case : + * we can receive locally generated ICMP messages while socket is held. + */ + if (sock_owned_by_user(sk)) { + if (!(type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED)) + __NET_INC_STATS(net, LINUX_MIB_LOCKDROPPEDICMPS); + } + if (sk->sk_state == TCP_CLOSE) + goto out; + + if (static_branch_unlikely(&ip4_min_ttl)) { + /* min_ttl can be changed concurrently from do_ip_setsockopt() */ + if (unlikely(iph->ttl < READ_ONCE(inet_sk(sk)->min_ttl))) { + __NET_INC_STATS(net, LINUX_MIB_TCPMINTTLDROP); + goto out; + } + } + + tp = tcp_sk(sk); + /* XXX (TFO) - tp->snd_una should be ISN (tcp_create_openreq_child() */ + fastopen = rcu_dereference(tp->fastopen_rsk); + snd_una = fastopen ? tcp_rsk(fastopen)->snt_isn : tp->snd_una; + if (sk->sk_state != TCP_LISTEN && + !between(seq, snd_una, tp->snd_nxt)) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + goto out; + } + + switch (type) { + case ICMP_REDIRECT: + if (!sock_owned_by_user(sk)) + do_redirect(skb, sk); + goto out; + case ICMP_SOURCE_QUENCH: + /* Just silently ignore these. 
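+		 * (ICMP Source Quench was formally deprecated by RFC 6633.)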
*/ + goto out; + case ICMP_PARAMETERPROB: + err = EPROTO; + break; + case ICMP_DEST_UNREACH: + if (code > NR_ICMP_UNREACH) + goto out; + + if (code == ICMP_FRAG_NEEDED) { /* PMTU discovery (RFC1191) */ + /* We are not interested in TCP_LISTEN and open_requests + * (SYN-ACKs send out by Linux are always <576bytes so + * they should go through unfragmented). + */ + if (sk->sk_state == TCP_LISTEN) + goto out; + + WRITE_ONCE(tp->mtu_info, info); + if (!sock_owned_by_user(sk)) { + tcp_v4_mtu_reduced(sk); + } else { + if (!test_and_set_bit(TCP_MTU_REDUCED_DEFERRED, &sk->sk_tsq_flags)) + sock_hold(sk); + } + goto out; + } + + err = icmp_err_convert[code].errno; + /* check if this ICMP message allows revert of backoff. + * (see RFC 6069) + */ + if (!fastopen && + (code == ICMP_NET_UNREACH || code == ICMP_HOST_UNREACH)) + tcp_ld_RTO_revert(sk, seq); + break; + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + default: + goto out; + } + + switch (sk->sk_state) { + case TCP_SYN_SENT: + case TCP_SYN_RECV: + /* Only in fast or simultaneous open. If a fast open socket is + * already accepted it is treated as a connected one below. + */ + if (fastopen && !fastopen->sk) + break; + + ip_icmp_error(sk, skb, err, th->dest, info, (u8 *)th); + + if (!sock_owned_by_user(sk)) { + sk->sk_err = err; + + sk_error_report(sk); + + tcp_done(sk); + } else { + sk->sk_err_soft = err; + } + goto out; + } + + /* If we've already connected we will keep trying + * until we time out, or the user gives up. + * + * rfc1122 4.2.3.9 allows to consider as hard errors + * only PROTO_UNREACH and PORT_UNREACH (well, FRAG_FAILED too, + * but it is obsoleted by pmtu discovery). + * + * Note, that in modern internet, where routing is unreliable + * and in each dark corner broken firewalls sit, sending random + * errors ordered by their masters even this two messages finally lose + * their original sense (even Linux sends invalid PORT_UNREACHs) + * + * Now we are in compliance with RFCs. + * --ANK (980905) + */ + + inet = inet_sk(sk); + if (!sock_owned_by_user(sk) && inet->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else { /* Only an error on timeout */ + sk->sk_err_soft = err; + } + +out: + bh_unlock_sock(sk); + sock_put(sk); + return 0; +} + +void __tcp_v4_send_check(struct sk_buff *skb, __be32 saddr, __be32 daddr) +{ + struct tcphdr *th = tcp_hdr(skb); + + th->check = ~tcp_v4_check(skb->len, saddr, daddr, 0); + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); +} + +/* This routine computes an IPv4 TCP checksum. */ +void tcp_v4_send_check(struct sock *sk, struct sk_buff *skb) +{ + const struct inet_sock *inet = inet_sk(sk); + + __tcp_v4_send_check(skb, inet->inet_saddr, inet->inet_daddr); +} +EXPORT_SYMBOL(tcp_v4_send_check); + +/* + * This routine will send an RST to the other tcp. + * + * Someone asks: why I NEVER use socket parameters (TOS, TTL etc.) + * for reset. + * Answer: if a packet caused RST, it is not for a socket + * existing in our system, if it is matched to a socket, + * it is just duplicate segment or bug in other side's TCP. + * So that we build reply only basing on parameters + * arrived with segment. + * Exception: precedence violation. We do not implement it in any case. 
+ */ + +#ifdef CONFIG_TCP_MD5SIG +#define OPTION_BYTES TCPOLEN_MD5SIG_ALIGNED +#else +#define OPTION_BYTES sizeof(__be32) +#endif + +static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) +{ + const struct tcphdr *th = tcp_hdr(skb); + struct { + struct tcphdr th; + __be32 opt[OPTION_BYTES / sizeof(__be32)]; + } rep; + struct ip_reply_arg arg; +#ifdef CONFIG_TCP_MD5SIG + struct tcp_md5sig_key *key = NULL; + const __u8 *hash_location = NULL; + unsigned char newhash[16]; + int genhash; + struct sock *sk1 = NULL; +#endif + u64 transmit_time = 0; + struct sock *ctl_sk; + struct net *net; + u32 txhash = 0; + + /* Never send a reset in response to a reset. */ + if (th->rst) + return; + + /* If sk not NULL, it means we did a successful lookup and incoming + * route had to be correct. prequeue might have dropped our dst. + */ + if (!sk && skb_rtable(skb)->rt_type != RTN_LOCAL) + return; + + /* Swap the send and the receive. */ + memset(&rep, 0, sizeof(rep)); + rep.th.dest = th->source; + rep.th.source = th->dest; + rep.th.doff = sizeof(struct tcphdr) / 4; + rep.th.rst = 1; + + if (th->ack) { + rep.th.seq = th->ack_seq; + } else { + rep.th.ack = 1; + rep.th.ack_seq = htonl(ntohl(th->seq) + th->syn + th->fin + + skb->len - (th->doff << 2)); + } + + memset(&arg, 0, sizeof(arg)); + arg.iov[0].iov_base = (unsigned char *)&rep; + arg.iov[0].iov_len = sizeof(rep.th); + + net = sk ? sock_net(sk) : dev_net(skb_dst(skb)->dev); +#ifdef CONFIG_TCP_MD5SIG + rcu_read_lock(); + hash_location = tcp_parse_md5sig_option(th); + if (sk && sk_fullsock(sk)) { + const union tcp_md5_addr *addr; + int l3index; + + /* sdif set, means packet ingressed via a device + * in an L3 domain and inet_iif is set to it. + */ + l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; + key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); + } else if (hash_location) { + const union tcp_md5_addr *addr; + int sdif = tcp_v4_sdif(skb); + int dif = inet_iif(skb); + int l3index; + + /* + * active side is lost. Try to find listening socket through + * source port, and then find md5 key through listening socket. + * we are not loose security here: + * Incoming packet is checked with md5 hash with finding key, + * no RST generated if md5 hash doesn't match. + */ + sk1 = __inet_lookup_listener(net, net->ipv4.tcp_death_row.hashinfo, + NULL, 0, ip_hdr(skb)->saddr, + th->source, ip_hdr(skb)->daddr, + ntohs(th->source), dif, sdif); + /* don't send rst if it can't find key */ + if (!sk1) + goto out; + + /* sdif set, means packet ingressed via a device + * in an L3 domain and dif is set to it. + */ + l3index = sdif ? 
dif : 0; + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; + key = tcp_md5_do_lookup(sk1, l3index, addr, AF_INET); + if (!key) + goto out; + + + genhash = tcp_v4_md5_hash_skb(newhash, key, NULL, skb); + if (genhash || memcmp(hash_location, newhash, 16) != 0) + goto out; + + } + + if (key) { + rep.opt[0] = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_MD5SIG << 8) | + TCPOLEN_MD5SIG); + /* Update length and the length the header thinks exists */ + arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED; + rep.th.doff = arg.iov[0].iov_len / 4; + + tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[1], + key, ip_hdr(skb)->saddr, + ip_hdr(skb)->daddr, &rep.th); + } +#endif + /* Can't co-exist with TCPMD5, hence check rep.opt[0] */ + if (rep.opt[0] == 0) { + __be32 mrst = mptcp_reset_option(skb); + + if (mrst) { + rep.opt[0] = mrst; + arg.iov[0].iov_len += sizeof(mrst); + rep.th.doff = arg.iov[0].iov_len / 4; + } + } + + arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr, + ip_hdr(skb)->saddr, /* XXX */ + arg.iov[0].iov_len, IPPROTO_TCP, 0); + arg.csumoffset = offsetof(struct tcphdr, check) / 2; + arg.flags = (sk && inet_sk_transparent(sk)) ? IP_REPLY_ARG_NOSRCCHECK : 0; + + /* When socket is gone, all binding information is lost. + * routing might fail in this case. No choice here, if we choose to force + * input interface, we will misroute in case of asymmetric route. + */ + if (sk) { + arg.bound_dev_if = sk->sk_bound_dev_if; + if (sk_fullsock(sk)) + trace_tcp_send_reset(sk, skb); + } + + BUILD_BUG_ON(offsetof(struct sock, sk_bound_dev_if) != + offsetof(struct inet_timewait_sock, tw_bound_dev_if)); + + arg.tos = ip_hdr(skb)->tos; + arg.uid = sock_net_uid(net, sk && sk_fullsock(sk) ? sk : NULL); + local_bh_disable(); + ctl_sk = this_cpu_read(ipv4_tcp_sk); + sock_net_set(ctl_sk, net); + if (sk) { + ctl_sk->sk_mark = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_mark : sk->sk_mark; + ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_priority : sk->sk_priority; + transmit_time = tcp_transmit_time(sk); + xfrm_sk_clone_policy(ctl_sk, sk); + txhash = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_txhash : sk->sk_txhash; + } else { + ctl_sk->sk_mark = 0; + ctl_sk->sk_priority = 0; + } + ip_send_unicast_reply(ctl_sk, + skb, &TCP_SKB_CB(skb)->header.h4.opt, + ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, + &arg, arg.iov[0].iov_len, + transmit_time, txhash); + + xfrm_sk_free_policy(ctl_sk); + sock_net_set(ctl_sk, &init_net); + __TCP_INC_STATS(net, TCP_MIB_OUTSEGS); + __TCP_INC_STATS(net, TCP_MIB_OUTRSTS); + local_bh_enable(); + +#ifdef CONFIG_TCP_MD5SIG +out: + rcu_read_unlock(); +#endif +} + +/* The code following below sending ACKs in SYN-RECV and TIME-WAIT states + outside socket context is ugly, certainly. What can I do? 
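+   (Descriptive aside: as with tcp_v4_send_reset() above, tcp_v4_send_ack()
+   below builds these replies on the per-cpu ipv4_tcp_sk control socket,
+   since no full socket is available in these states.)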
+ */ + +static void tcp_v4_send_ack(const struct sock *sk, + struct sk_buff *skb, u32 seq, u32 ack, + u32 win, u32 tsval, u32 tsecr, int oif, + struct tcp_md5sig_key *key, + int reply_flags, u8 tos, u32 txhash) +{ + const struct tcphdr *th = tcp_hdr(skb); + struct { + struct tcphdr th; + __be32 opt[(TCPOLEN_TSTAMP_ALIGNED >> 2) +#ifdef CONFIG_TCP_MD5SIG + + (TCPOLEN_MD5SIG_ALIGNED >> 2) +#endif + ]; + } rep; + struct net *net = sock_net(sk); + struct ip_reply_arg arg; + struct sock *ctl_sk; + u64 transmit_time; + + memset(&rep.th, 0, sizeof(struct tcphdr)); + memset(&arg, 0, sizeof(arg)); + + arg.iov[0].iov_base = (unsigned char *)&rep; + arg.iov[0].iov_len = sizeof(rep.th); + if (tsecr) { + rep.opt[0] = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) | + (TCPOPT_TIMESTAMP << 8) | + TCPOLEN_TIMESTAMP); + rep.opt[1] = htonl(tsval); + rep.opt[2] = htonl(tsecr); + arg.iov[0].iov_len += TCPOLEN_TSTAMP_ALIGNED; + } + + /* Swap the send and the receive. */ + rep.th.dest = th->source; + rep.th.source = th->dest; + rep.th.doff = arg.iov[0].iov_len / 4; + rep.th.seq = htonl(seq); + rep.th.ack_seq = htonl(ack); + rep.th.ack = 1; + rep.th.window = htons(win); + +#ifdef CONFIG_TCP_MD5SIG + if (key) { + int offset = (tsecr) ? 3 : 0; + + rep.opt[offset++] = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_MD5SIG << 8) | + TCPOLEN_MD5SIG); + arg.iov[0].iov_len += TCPOLEN_MD5SIG_ALIGNED; + rep.th.doff = arg.iov[0].iov_len/4; + + tcp_v4_md5_hash_hdr((__u8 *) &rep.opt[offset], + key, ip_hdr(skb)->saddr, + ip_hdr(skb)->daddr, &rep.th); + } +#endif + arg.flags = reply_flags; + arg.csum = csum_tcpudp_nofold(ip_hdr(skb)->daddr, + ip_hdr(skb)->saddr, /* XXX */ + arg.iov[0].iov_len, IPPROTO_TCP, 0); + arg.csumoffset = offsetof(struct tcphdr, check) / 2; + if (oif) + arg.bound_dev_if = oif; + arg.tos = tos; + arg.uid = sock_net_uid(net, sk_fullsock(sk) ? sk : NULL); + local_bh_disable(); + ctl_sk = this_cpu_read(ipv4_tcp_sk); + sock_net_set(ctl_sk, net); + ctl_sk->sk_mark = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_mark : READ_ONCE(sk->sk_mark); + ctl_sk->sk_priority = (sk->sk_state == TCP_TIME_WAIT) ? + inet_twsk(sk)->tw_priority : READ_ONCE(sk->sk_priority); + transmit_time = tcp_transmit_time(sk); + ip_send_unicast_reply(ctl_sk, + skb, &TCP_SKB_CB(skb)->header.h4.opt, + ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, + &arg, arg.iov[0].iov_len, + transmit_time, txhash); + + sock_net_set(ctl_sk, &init_net); + __TCP_INC_STATS(net, TCP_MIB_OUTSEGS); + local_bh_enable(); +} + +static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb) +{ + struct inet_timewait_sock *tw = inet_twsk(sk); + struct tcp_timewait_sock *tcptw = tcp_twsk(sk); + + tcp_v4_send_ack(sk, skb, + tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt, + tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale, + tcp_time_stamp_raw() + tcptw->tw_ts_offset, + tcptw->tw_ts_recent, + tw->tw_bound_dev_if, + tcp_twsk_md5_key(tcptw), + tw->tw_transparent ? IP_REPLY_ARG_NOSRCCHECK : 0, + tw->tw_tos, + tw->tw_txhash + ); + + inet_twsk_put(tw); +} + +static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req) +{ + const union tcp_md5_addr *addr; + int l3index; + + /* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV + * sk->sk_state == TCP_SYN_RECV -> for Fast Open. + */ + u32 seq = (sk->sk_state == TCP_LISTEN) ? 
tcp_rsk(req)->snt_isn + 1 : + tcp_sk(sk)->snd_nxt; + + /* RFC 7323 2.3 + * The window field (SEG.WND) of every outgoing segment, with the + * exception of segments, MUST be right-shifted by + * Rcv.Wind.Shift bits: + */ + addr = (union tcp_md5_addr *)&ip_hdr(skb)->saddr; + l3index = tcp_v4_sdif(skb) ? inet_iif(skb) : 0; + tcp_v4_send_ack(sk, skb, seq, + tcp_rsk(req)->rcv_nxt, + req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale, + tcp_time_stamp_raw() + tcp_rsk(req)->ts_off, + READ_ONCE(req->ts_recent), + 0, + tcp_md5_do_lookup(sk, l3index, addr, AF_INET), + inet_rsk(req)->no_srccheck ? IP_REPLY_ARG_NOSRCCHECK : 0, + ip_hdr(skb)->tos, + READ_ONCE(tcp_rsk(req)->txhash)); +} + +/* + * Send a SYN-ACK after having received a SYN. + * This still operates on a request_sock only, not on a big + * socket. + */ +static int tcp_v4_send_synack(const struct sock *sk, struct dst_entry *dst, + struct flowi *fl, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + const struct inet_request_sock *ireq = inet_rsk(req); + struct flowi4 fl4; + int err = -1; + struct sk_buff *skb; + u8 tos; + + /* First, grab a route. */ + if (!dst && (dst = inet_csk_route_req(sk, &fl4, req)) == NULL) + return -1; + + skb = tcp_make_synack(sk, dst, req, foc, synack_type, syn_skb); + + if (skb) { + __tcp_v4_send_check(skb, ireq->ir_loc_addr, ireq->ir_rmt_addr); + + tos = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos) ? + (tcp_rsk(req)->syn_tos & ~INET_ECN_MASK) | + (inet_sk(sk)->tos & INET_ECN_MASK) : + inet_sk(sk)->tos; + + if (!INET_ECN_is_capable(tos) && + tcp_bpf_ca_needs_ecn((struct sock *)req)) + tos |= INET_ECN_ECT_0; + + rcu_read_lock(); + err = ip_build_and_send_pkt(skb, sk, ireq->ir_loc_addr, + ireq->ir_rmt_addr, + rcu_dereference(ireq->ireq_opt), + tos); + rcu_read_unlock(); + err = net_xmit_eval(err); + } + + return err; +} + +/* + * IPv4 request_sock destructor. + */ +static void tcp_v4_reqsk_destructor(struct request_sock *req) +{ + kfree(rcu_dereference_protected(inet_rsk(req)->ireq_opt, 1)); +} + +#ifdef CONFIG_TCP_MD5SIG +/* + * RFC2385 MD5 checksumming requires a mapping of + * IP address->MD5 Key. + * We need to maintain these in the sk structure. + */ + +DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); +EXPORT_SYMBOL(tcp_md5_needed); + +static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new) +{ + if (!old) + return true; + + /* l3index always overrides non-l3index */ + if (old->l3index && new->l3index == 0) + return false; + if (old->l3index == 0 && new->l3index) + return true; + + return old->prefixlen < new->prefixlen; +} + +/* Find the Key structure for an address. 
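+ * (Best match wins: a key bound to the packet's L3 domain is preferred
+ * over an unbound one, then the longest matching prefix, as encoded in
+ * better_md5_match() above.)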
*/ +struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, + const union tcp_md5_addr *addr, + int family) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + const struct tcp_md5sig_info *md5sig; + __be32 mask; + struct tcp_md5sig_key *best_match = NULL; + bool match; + + /* caller either holds rcu_read_lock() or socket lock */ + md5sig = rcu_dereference_check(tp->md5sig_info, + lockdep_sock_is_held(sk)); + if (!md5sig) + return NULL; + + hlist_for_each_entry_rcu(key, &md5sig->head, node, + lockdep_sock_is_held(sk)) { + if (key->family != family) + continue; + if (key->flags & TCP_MD5SIG_FLAG_IFINDEX && key->l3index != l3index) + continue; + if (family == AF_INET) { + mask = inet_make_mask(key->prefixlen); + match = (key->addr.a4.s_addr & mask) == + (addr->a4.s_addr & mask); +#if IS_ENABLED(CONFIG_IPV6) + } else if (family == AF_INET6) { + match = ipv6_prefix_equal(&key->addr.a6, &addr->a6, + key->prefixlen); +#endif + } else { + match = false; + } + + if (match && better_md5_match(best_match, key)) + best_match = key; + } + return best_match; +} +EXPORT_SYMBOL(__tcp_md5_do_lookup); + +static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk, + const union tcp_md5_addr *addr, + int family, u8 prefixlen, + int l3index, u8 flags) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + unsigned int size = sizeof(struct in_addr); + const struct tcp_md5sig_info *md5sig; + + /* caller either holds rcu_read_lock() or socket lock */ + md5sig = rcu_dereference_check(tp->md5sig_info, + lockdep_sock_is_held(sk)); + if (!md5sig) + return NULL; +#if IS_ENABLED(CONFIG_IPV6) + if (family == AF_INET6) + size = sizeof(struct in6_addr); +#endif + hlist_for_each_entry_rcu(key, &md5sig->head, node, + lockdep_sock_is_held(sk)) { + if (key->family != family) + continue; + if ((key->flags & TCP_MD5SIG_FLAG_IFINDEX) != (flags & TCP_MD5SIG_FLAG_IFINDEX)) + continue; + if (key->l3index != l3index) + continue; + if (!memcmp(&key->addr, addr, size) && + key->prefixlen == prefixlen) + return key; + } + return NULL; +} + +struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk, + const struct sock *addr_sk) +{ + const union tcp_md5_addr *addr; + int l3index; + + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), + addr_sk->sk_bound_dev_if); + addr = (const union tcp_md5_addr *)&addr_sk->sk_daddr; + return tcp_md5_do_lookup(sk, l3index, addr, AF_INET); +} +EXPORT_SYMBOL(tcp_v4_md5_lookup); + +/* This can be called on a newly created socket, from other files */ +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, + int family, u8 prefixlen, int l3index, u8 flags, + const u8 *newkey, u8 newkeylen, gfp_t gfp) +{ + /* Add Key to the list */ + struct tcp_md5sig_key *key; + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_info *md5sig; + + key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags); + if (key) { + /* Pre-existing entry - just update that one. + * Note that the key might be used concurrently. + * data_race() is telling kcsan that we do not care of + * key mismatches, since changing MD5 key on live flows + * can lead to packet drops. + */ + data_race(memcpy(key->key, newkey, newkeylen)); + + /* Pairs with READ_ONCE() in tcp_md5_hash_key(). + * Also note that a reader could catch new key->keylen value + * but old key->key[], this is the reason we use __GFP_ZERO + * at sock_kmalloc() time below these lines. 
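+		 *
+		 * (Illustrative user-space sketch, not taken from this file:
+		 * keys are installed with setsockopt(TCP_MD5SIG), roughly
+		 *
+		 *	struct tcp_md5sig md5 = { .tcpm_keylen = 6 };
+		 *	memcpy(&md5.tcpm_addr, &peer_sin, sizeof(peer_sin));
+		 *	memcpy(md5.tcpm_key, "secret", 6);
+		 *	setsockopt(fd, IPPROTO_TCP, TCP_MD5SIG, &md5, sizeof(md5));
+		 *
+		 * where peer_sin is the peer's struct sockaddr_in; the
+		 * TCP_MD5SIG_EXT variant adds the prefix and ifindex handling
+		 * done by tcp_v4_parse_md5_keys() further below.)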
+ */ + WRITE_ONCE(key->keylen, newkeylen); + + return 0; + } + + md5sig = rcu_dereference_protected(tp->md5sig_info, + lockdep_sock_is_held(sk)); + if (!md5sig) { + md5sig = kmalloc(sizeof(*md5sig), gfp); + if (!md5sig) + return -ENOMEM; + + sk_gso_disable(sk); + INIT_HLIST_HEAD(&md5sig->head); + rcu_assign_pointer(tp->md5sig_info, md5sig); + } + + key = sock_kmalloc(sk, sizeof(*key), gfp | __GFP_ZERO); + if (!key) + return -ENOMEM; + if (!tcp_alloc_md5sig_pool()) { + sock_kfree_s(sk, key, sizeof(*key)); + return -ENOMEM; + } + + memcpy(key->key, newkey, newkeylen); + key->keylen = newkeylen; + key->family = family; + key->prefixlen = prefixlen; + key->l3index = l3index; + key->flags = flags; + memcpy(&key->addr, addr, + (IS_ENABLED(CONFIG_IPV6) && family == AF_INET6) ? sizeof(struct in6_addr) : + sizeof(struct in_addr)); + hlist_add_head_rcu(&key->node, &md5sig->head); + return 0; +} +EXPORT_SYMBOL(tcp_md5_do_add); + +int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, + u8 prefixlen, int l3index, u8 flags) +{ + struct tcp_md5sig_key *key; + + key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags); + if (!key) + return -ENOENT; + hlist_del_rcu(&key->node); + atomic_sub(sizeof(*key), &sk->sk_omem_alloc); + kfree_rcu(key, rcu); + return 0; +} +EXPORT_SYMBOL(tcp_md5_do_del); + +static void tcp_clear_md5_list(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *key; + struct hlist_node *n; + struct tcp_md5sig_info *md5sig; + + md5sig = rcu_dereference_protected(tp->md5sig_info, 1); + + hlist_for_each_entry_safe(key, n, &md5sig->head, node) { + hlist_del_rcu(&key->node); + atomic_sub(sizeof(*key), &sk->sk_omem_alloc); + kfree_rcu(key, rcu); + } +} + +static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct tcp_md5sig cmd; + struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.tcpm_addr; + const union tcp_md5_addr *addr; + u8 prefixlen = 32; + int l3index = 0; + u8 flags; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + if (copy_from_sockptr(&cmd, optval, sizeof(cmd))) + return -EFAULT; + + if (sin->sin_family != AF_INET) + return -EINVAL; + + flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX; + + if (optname == TCP_MD5SIG_EXT && + cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) { + prefixlen = cmd.tcpm_prefixlen; + if (prefixlen > 32) + return -EINVAL; + } + + if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex && + cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) { + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), cmd.tcpm_ifindex); + if (dev && netif_is_l3_master(dev)) + l3index = dev->ifindex; + + rcu_read_unlock(); + + /* ok to reference set/not set outside of rcu; + * right now device MUST be an L3 master + */ + if (!dev || !l3index) + return -EINVAL; + } + + addr = (union tcp_md5_addr *)&sin->sin_addr.s_addr; + + if (!cmd.tcpm_keylen) + return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index, flags); + + if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN) + return -EINVAL; + + return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags, + cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); +} + +static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, + __be32 daddr, __be32 saddr, + const struct tcphdr *th, int nbytes) +{ + struct tcp4_pseudohdr *bp; + struct scatterlist sg; + struct tcphdr *_th; + + bp = hp->scratch; + bp->saddr = saddr; + bp->daddr = daddr; + bp->pad = 0; + bp->protocol = IPPROTO_TCP; + bp->len = cpu_to_be16(nbytes); + 
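+	/* (Descriptive aside: bp is the RFC 2385 IPv4 pseudo-header; the TCP
+	 * header is copied in next with its checksum zeroed, so the digest
+	 * covers pseudo-header, TCP header and, in tcp_v4_md5_hash_skb(),
+	 * the payload and key hashed by the caller.)
+	 */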
+ _th = (struct tcphdr *)(bp + 1); + memcpy(_th, th, sizeof(*th)); + _th->check = 0; + + sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th)); + ahash_request_set_crypt(hp->md5_req, &sg, NULL, + sizeof(*bp) + sizeof(*th)); + return crypto_ahash_update(hp->md5_req); +} + +static int tcp_v4_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, + __be32 daddr, __be32 saddr, const struct tcphdr *th) +{ + struct tcp_md5sig_pool *hp; + struct ahash_request *req; + + hp = tcp_get_md5sig_pool(); + if (!hp) + goto clear_hash_noput; + req = hp->md5_req; + + if (crypto_ahash_init(req)) + goto clear_hash; + if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2)) + goto clear_hash; + if (tcp_md5_hash_key(hp, key)) + goto clear_hash; + ahash_request_set_crypt(req, NULL, md5_hash, 0); + if (crypto_ahash_final(req)) + goto clear_hash; + + tcp_put_md5sig_pool(); + return 0; + +clear_hash: + tcp_put_md5sig_pool(); +clear_hash_noput: + memset(md5_hash, 0, 16); + return 1; +} + +int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key, + const struct sock *sk, + const struct sk_buff *skb) +{ + struct tcp_md5sig_pool *hp; + struct ahash_request *req; + const struct tcphdr *th = tcp_hdr(skb); + __be32 saddr, daddr; + + if (sk) { /* valid for establish/request sockets */ + saddr = sk->sk_rcv_saddr; + daddr = sk->sk_daddr; + } else { + const struct iphdr *iph = ip_hdr(skb); + saddr = iph->saddr; + daddr = iph->daddr; + } + + hp = tcp_get_md5sig_pool(); + if (!hp) + goto clear_hash_noput; + req = hp->md5_req; + + if (crypto_ahash_init(req)) + goto clear_hash; + + if (tcp_v4_md5_hash_headers(hp, daddr, saddr, th, skb->len)) + goto clear_hash; + if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2)) + goto clear_hash; + if (tcp_md5_hash_key(hp, key)) + goto clear_hash; + ahash_request_set_crypt(req, NULL, md5_hash, 0); + if (crypto_ahash_final(req)) + goto clear_hash; + + tcp_put_md5sig_pool(); + return 0; + +clear_hash: + tcp_put_md5sig_pool(); +clear_hash_noput: + memset(md5_hash, 0, 16); + return 1; +} +EXPORT_SYMBOL(tcp_v4_md5_hash_skb); + +#endif + +static void tcp_v4_init_req(struct request_sock *req, + const struct sock *sk_listener, + struct sk_buff *skb) +{ + struct inet_request_sock *ireq = inet_rsk(req); + struct net *net = sock_net(sk_listener); + + sk_rcv_saddr_set(req_to_sk(req), ip_hdr(skb)->daddr); + sk_daddr_set(req_to_sk(req), ip_hdr(skb)->saddr); + RCU_INIT_POINTER(ireq->ireq_opt, tcp_v4_save_options(net, skb)); +} + +static struct dst_entry *tcp_v4_route_req(const struct sock *sk, + struct sk_buff *skb, + struct flowi *fl, + struct request_sock *req) +{ + tcp_v4_init_req(req, sk, skb); + + if (security_inet_conn_request(sk, skb, req)) + return NULL; + + return inet_csk_route_req(sk, &fl->u.ip4, req); +} + +struct request_sock_ops tcp_request_sock_ops __read_mostly = { + .family = PF_INET, + .obj_size = sizeof(struct tcp_request_sock), + .rtx_syn_ack = tcp_rtx_synack, + .send_ack = tcp_v4_reqsk_send_ack, + .destructor = tcp_v4_reqsk_destructor, + .send_reset = tcp_v4_send_reset, + .syn_ack_timeout = tcp_syn_ack_timeout, +}; + +const struct tcp_request_sock_ops tcp_request_sock_ipv4_ops = { + .mss_clamp = TCP_MSS_DEFAULT, +#ifdef CONFIG_TCP_MD5SIG + .req_md5_lookup = tcp_v4_md5_lookup, + .calc_md5_hash = tcp_v4_md5_hash_skb, +#endif +#ifdef CONFIG_SYN_COOKIES + .cookie_init_seq = cookie_v4_init_sequence, +#endif + .route_req = tcp_v4_route_req, + .init_seq = tcp_v4_init_seq, + .init_ts_off = tcp_v4_init_ts_off, + .send_synack = tcp_v4_send_synack, +}; + +int 
tcp_v4_conn_request(struct sock *sk, struct sk_buff *skb) +{ + /* Never answer to SYNs send to broadcast or multicast */ + if (skb_rtable(skb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) + goto drop; + + return tcp_conn_request(&tcp_request_sock_ops, + &tcp_request_sock_ipv4_ops, sk, skb); + +drop: + tcp_listendrop(sk); + return 0; +} +EXPORT_SYMBOL(tcp_v4_conn_request); + + +/* + * The three way handshake has completed - we got a valid synack - + * now create the new socket. + */ +struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req) +{ + struct inet_request_sock *ireq; + bool found_dup_sk = false; + struct inet_sock *newinet; + struct tcp_sock *newtp; + struct sock *newsk; +#ifdef CONFIG_TCP_MD5SIG + const union tcp_md5_addr *addr; + struct tcp_md5sig_key *key; + int l3index; +#endif + struct ip_options_rcu *inet_opt; + + if (sk_acceptq_is_full(sk)) + goto exit_overflow; + + newsk = tcp_create_openreq_child(sk, req, skb); + if (!newsk) + goto exit_nonewsk; + + newsk->sk_gso_type = SKB_GSO_TCPV4; + inet_sk_rx_dst_set(newsk, skb); + + newtp = tcp_sk(newsk); + newinet = inet_sk(newsk); + ireq = inet_rsk(req); + sk_daddr_set(newsk, ireq->ir_rmt_addr); + sk_rcv_saddr_set(newsk, ireq->ir_loc_addr); + newsk->sk_bound_dev_if = ireq->ir_iif; + newinet->inet_saddr = ireq->ir_loc_addr; + inet_opt = rcu_dereference(ireq->ireq_opt); + RCU_INIT_POINTER(newinet->inet_opt, inet_opt); + newinet->mc_index = inet_iif(skb); + newinet->mc_ttl = ip_hdr(skb)->ttl; + newinet->rcv_tos = ip_hdr(skb)->tos; + inet_csk(newsk)->icsk_ext_hdr_len = 0; + if (inet_opt) + inet_csk(newsk)->icsk_ext_hdr_len = inet_opt->opt.optlen; + atomic_set(&newinet->inet_id, get_random_u16()); + + /* Set ToS of the new socket based upon the value of incoming SYN. + * ECT bits are set later in tcp_init_transfer(). + */ + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos)) + newinet->tos = tcp_rsk(req)->syn_tos & ~INET_ECN_MASK; + + if (!dst) { + dst = inet_csk_route_child_sock(sk, newsk, req); + if (!dst) + goto put_and_exit; + } else { + /* syncookie case : see end of cookie_v4_check() */ + } + sk_setup_caps(newsk, dst); + + tcp_ca_openreq_child(newsk, dst); + + tcp_sync_mss(newsk, dst_mtu(dst)); + newtp->advmss = tcp_mss_clamp(tcp_sk(sk), dst_metric_advmss(dst)); + + tcp_initialize_rcv_mss(newsk); + +#ifdef CONFIG_TCP_MD5SIG + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif); + /* Copy over the MD5 key from the original socket */ + addr = (union tcp_md5_addr *)&newinet->inet_daddr; + key = tcp_md5_do_lookup(sk, l3index, addr, AF_INET); + if (key) { + /* + * We're using one, so create a matching key + * on the newsk structure. If we fail to get + * memory, then we end up not copying the key + * across. Shucks. 
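+		 * (The copy below clones the key with a host /32 prefix and
+		 * GFP_ATOMIC, as this path can run from softirq receive
+		 * context.)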
+ */ + tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags, + key->key, key->keylen, GFP_ATOMIC); + sk_gso_disable(newsk); + } +#endif + + if (__inet_inherit_port(sk, newsk) < 0) + goto put_and_exit; + *own_req = inet_ehash_nolisten(newsk, req_to_sk(req_unhash), + &found_dup_sk); + if (likely(*own_req)) { + tcp_move_syn(newtp, req); + ireq->ireq_opt = NULL; + } else { + newinet->inet_opt = NULL; + + if (!req_unhash && found_dup_sk) { + /* This code path should only be executed in the + * syncookie case only + */ + bh_unlock_sock(newsk); + sock_put(newsk); + newsk = NULL; + } + } + return newsk; + +exit_overflow: + NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); +exit_nonewsk: + dst_release(dst); +exit: + tcp_listendrop(sk); + return NULL; +put_and_exit: + newinet->inet_opt = NULL; + inet_csk_prepare_forced_close(newsk); + tcp_done(newsk); + goto exit; +} +EXPORT_SYMBOL(tcp_v4_syn_recv_sock); + +static struct sock *tcp_v4_cookie_check(struct sock *sk, struct sk_buff *skb) +{ +#ifdef CONFIG_SYN_COOKIES + const struct tcphdr *th = tcp_hdr(skb); + + if (!th->syn) + sk = cookie_v4_check(sk, skb); +#endif + return sk; +} + +u16 tcp_v4_get_syncookie(struct sock *sk, struct iphdr *iph, + struct tcphdr *th, u32 *cookie) +{ + u16 mss = 0; +#ifdef CONFIG_SYN_COOKIES + mss = tcp_get_syncookie_mss(&tcp_request_sock_ops, + &tcp_request_sock_ipv4_ops, sk, th); + if (mss) { + *cookie = __cookie_v4_init_sequence(iph, th, &mss); + tcp_synq_overflow(sk); + } +#endif + return mss; +} + +INDIRECT_CALLABLE_DECLARE(struct dst_entry *ipv4_dst_check(struct dst_entry *, + u32)); +/* The socket must have it's spinlock held when we get + * here, unless it is a TCP_LISTEN socket. + * + * We have a potential double-lock case here, so even when + * doing backlog processing we use the BH locking scheme. + * This is because we cannot sleep with the original spinlock + * held. + */ +int tcp_v4_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + enum skb_drop_reason reason; + struct sock *rsk; + + if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */ + struct dst_entry *dst; + + dst = rcu_dereference_protected(sk->sk_rx_dst, + lockdep_sock_is_held(sk)); + + sock_rps_save_rxhash(sk, skb); + sk_mark_napi_id(sk, skb); + if (dst) { + if (sk->sk_rx_dst_ifindex != skb->skb_iif || + !INDIRECT_CALL_1(dst->ops->check, ipv4_dst_check, + dst, 0)) { + RCU_INIT_POINTER(sk->sk_rx_dst, NULL); + dst_release(dst); + } + } + tcp_rcv_established(sk, skb); + return 0; + } + + reason = SKB_DROP_REASON_NOT_SPECIFIED; + if (tcp_checksum_complete(skb)) + goto csum_err; + + if (sk->sk_state == TCP_LISTEN) { + struct sock *nsk = tcp_v4_cookie_check(sk, skb); + + if (!nsk) + goto discard; + if (nsk != sk) { + if (tcp_child_process(sk, nsk, skb)) { + rsk = nsk; + goto reset; + } + return 0; + } + } else + sock_rps_save_rxhash(sk, skb); + + if (tcp_rcv_state_process(sk, skb)) { + rsk = sk; + goto reset; + } + return 0; + +reset: + tcp_v4_send_reset(rsk, skb); +discard: + kfree_skb_reason(skb, reason); + /* Be careful here. If this function gets more complicated and + * gcc suffers from register pressure on the x86, sk (in %ebx) + * might be destroyed here. This current version compiles correctly, + * but you have been warned. 
+ */ + return 0; + +csum_err: + reason = SKB_DROP_REASON_TCP_CSUM; + trace_tcp_bad_csum(skb); + TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); + TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + goto discard; +} +EXPORT_SYMBOL(tcp_v4_do_rcv); + +int tcp_v4_early_demux(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + const struct iphdr *iph; + const struct tcphdr *th; + struct sock *sk; + + if (skb->pkt_type != PACKET_HOST) + return 0; + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct tcphdr))) + return 0; + + iph = ip_hdr(skb); + th = tcp_hdr(skb); + + if (th->doff < sizeof(struct tcphdr) / 4) + return 0; + + sk = __inet_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, + iph->saddr, th->source, + iph->daddr, ntohs(th->dest), + skb->skb_iif, inet_sdif(skb)); + if (sk) { + skb->sk = sk; + skb->destructor = sock_edemux; + if (sk_fullsock(sk)) { + struct dst_entry *dst = rcu_dereference(sk->sk_rx_dst); + + if (dst) + dst = dst_check(dst, 0); + if (dst && + sk->sk_rx_dst_ifindex == skb->skb_iif) + skb_dst_set_noref(skb, dst); + } + } + return 0; +} + +bool tcp_add_backlog(struct sock *sk, struct sk_buff *skb, + enum skb_drop_reason *reason) +{ + u32 limit, tail_gso_size, tail_gso_segs; + struct skb_shared_info *shinfo; + const struct tcphdr *th; + struct tcphdr *thtail; + struct sk_buff *tail; + unsigned int hdrlen; + bool fragstolen; + u32 gso_segs; + u32 gso_size; + int delta; + + /* In case all data was pulled from skb frags (in __pskb_pull_tail()), + * we can fix skb->truesize to its real value to avoid future drops. + * This is valid because skb is not yet charged to the socket. + * It has been noticed pure SACK packets were sometimes dropped + * (if cooked by drivers without copybreak feature). + */ + skb_condense(skb); + + skb_dst_drop(skb); + + if (unlikely(tcp_checksum_complete(skb))) { + bh_unlock_sock(sk); + trace_tcp_bad_csum(skb); + *reason = SKB_DROP_REASON_TCP_CSUM; + __TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); + __TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + return true; + } + + /* Attempt coalescing to last skb in backlog, even if we are + * above the limits. + * This is okay because skb capacity is limited to MAX_SKB_FRAGS. 
+ */ + th = (const struct tcphdr *)skb->data; + hdrlen = th->doff * 4; + + tail = sk->sk_backlog.tail; + if (!tail) + goto no_coalesce; + thtail = (struct tcphdr *)tail->data; + + if (TCP_SKB_CB(tail)->end_seq != TCP_SKB_CB(skb)->seq || + TCP_SKB_CB(tail)->ip_dsfield != TCP_SKB_CB(skb)->ip_dsfield || + ((TCP_SKB_CB(tail)->tcp_flags | + TCP_SKB_CB(skb)->tcp_flags) & (TCPHDR_SYN | TCPHDR_RST | TCPHDR_URG)) || + !((TCP_SKB_CB(tail)->tcp_flags & + TCP_SKB_CB(skb)->tcp_flags) & TCPHDR_ACK) || + ((TCP_SKB_CB(tail)->tcp_flags ^ + TCP_SKB_CB(skb)->tcp_flags) & (TCPHDR_ECE | TCPHDR_CWR)) || +#ifdef CONFIG_TLS_DEVICE + tail->decrypted != skb->decrypted || +#endif + !mptcp_skb_can_collapse(tail, skb) || + thtail->doff != th->doff || + memcmp(thtail + 1, th + 1, hdrlen - sizeof(*th))) + goto no_coalesce; + + __skb_pull(skb, hdrlen); + + shinfo = skb_shinfo(skb); + gso_size = shinfo->gso_size ?: skb->len; + gso_segs = shinfo->gso_segs ?: 1; + + shinfo = skb_shinfo(tail); + tail_gso_size = shinfo->gso_size ?: (tail->len - hdrlen); + tail_gso_segs = shinfo->gso_segs ?: 1; + + if (skb_try_coalesce(tail, skb, &fragstolen, &delta)) { + TCP_SKB_CB(tail)->end_seq = TCP_SKB_CB(skb)->end_seq; + + if (likely(!before(TCP_SKB_CB(skb)->ack_seq, TCP_SKB_CB(tail)->ack_seq))) { + TCP_SKB_CB(tail)->ack_seq = TCP_SKB_CB(skb)->ack_seq; + thtail->window = th->window; + } + + /* We have to update both TCP_SKB_CB(tail)->tcp_flags and + * thtail->fin, so that the fast path in tcp_rcv_established() + * is not entered if we append a packet with a FIN. + * SYN, RST, URG are not present. + * ACK is set on both packets. + * PSH : we do not really care in TCP stack, + * at least for 'GRO' packets. + */ + thtail->fin |= th->fin; + TCP_SKB_CB(tail)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags; + + if (TCP_SKB_CB(skb)->has_rxtstamp) { + TCP_SKB_CB(tail)->has_rxtstamp = true; + tail->tstamp = skb->tstamp; + skb_hwtstamps(tail)->hwtstamp = skb_hwtstamps(skb)->hwtstamp; + } + + /* Not as strict as GRO. We only need to carry mss max value */ + shinfo->gso_size = max(gso_size, tail_gso_size); + shinfo->gso_segs = min_t(u32, gso_segs + tail_gso_segs, 0xFFFF); + + sk->sk_backlog.len += delta; + __NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPBACKLOGCOALESCE); + kfree_skb_partial(skb, fragstolen); + return false; + } + __skb_push(skb, hdrlen); + +no_coalesce: + limit = (u32)READ_ONCE(sk->sk_rcvbuf) + (u32)(READ_ONCE(sk->sk_sndbuf) >> 1); + + /* Only socket owner can try to collapse/prune rx queues + * to reduce memory overhead, so add a little headroom here. + * Few sockets backlog are possibly concurrently non empty. + */ + limit += 64 * 1024; + + if (unlikely(sk_add_backlog(sk, skb, limit))) { + bh_unlock_sock(sk); + *reason = SKB_DROP_REASON_SOCKET_BACKLOG; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPBACKLOGDROP); + return true; + } + return false; +} +EXPORT_SYMBOL(tcp_add_backlog); + +int tcp_filter(struct sock *sk, struct sk_buff *skb) +{ + struct tcphdr *th = (struct tcphdr *)skb->data; + + return sk_filter_trim_cap(sk, skb, th->doff * 4); +} +EXPORT_SYMBOL(tcp_filter); + +static void tcp_v4_restore_cb(struct sk_buff *skb) +{ + memmove(IPCB(skb), &TCP_SKB_CB(skb)->header.h4, + sizeof(struct inet_skb_parm)); +} + +static void tcp_v4_fill_cb(struct sk_buff *skb, const struct iphdr *iph, + const struct tcphdr *th) +{ + /* This is tricky : We move IPCB at its correct location into TCP_SKB_CB() + * barrier() makes sure compiler wont play fool^Waliasing games. 
+ */ + memmove(&TCP_SKB_CB(skb)->header.h4, IPCB(skb), + sizeof(struct inet_skb_parm)); + barrier(); + + TCP_SKB_CB(skb)->seq = ntohl(th->seq); + TCP_SKB_CB(skb)->end_seq = (TCP_SKB_CB(skb)->seq + th->syn + th->fin + + skb->len - th->doff * 4); + TCP_SKB_CB(skb)->ack_seq = ntohl(th->ack_seq); + TCP_SKB_CB(skb)->tcp_flags = tcp_flag_byte(th); + TCP_SKB_CB(skb)->tcp_tw_isn = 0; + TCP_SKB_CB(skb)->ip_dsfield = ipv4_get_dsfield(iph); + TCP_SKB_CB(skb)->sacked = 0; + TCP_SKB_CB(skb)->has_rxtstamp = + skb->tstamp || skb_hwtstamps(skb)->hwtstamp; +} + +/* + * From tcp_input.c + */ + +int tcp_v4_rcv(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + enum skb_drop_reason drop_reason; + int sdif = inet_sdif(skb); + int dif = inet_iif(skb); + const struct iphdr *iph; + const struct tcphdr *th; + bool refcounted; + struct sock *sk; + int ret; + + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + if (skb->pkt_type != PACKET_HOST) + goto discard_it; + + /* Count it even if it's bad */ + __TCP_INC_STATS(net, TCP_MIB_INSEGS); + + if (!pskb_may_pull(skb, sizeof(struct tcphdr))) + goto discard_it; + + th = (const struct tcphdr *)skb->data; + + if (unlikely(th->doff < sizeof(struct tcphdr) / 4)) { + drop_reason = SKB_DROP_REASON_PKT_TOO_SMALL; + goto bad_packet; + } + if (!pskb_may_pull(skb, th->doff * 4)) + goto discard_it; + + /* An explanation is required here, I think. + * Packet length and doff are validated by header prediction, + * provided case of th->doff==0 is eliminated. + * So, we defer the checks. */ + + if (skb_checksum_init(skb, IPPROTO_TCP, inet_compute_pseudo)) + goto csum_error; + + th = (const struct tcphdr *)skb->data; + iph = ip_hdr(skb); +lookup: + sk = __inet_lookup_skb(net->ipv4.tcp_death_row.hashinfo, + skb, __tcp_hdrlen(th), th->source, + th->dest, sdif, &refcounted); + if (!sk) + goto no_tcp_socket; + +process: + if (sk->sk_state == TCP_TIME_WAIT) + goto do_time_wait; + + if (sk->sk_state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + bool req_stolen = false; + struct sock *nsk; + + sk = req->rsk_listener; + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + else + drop_reason = tcp_inbound_md5_hash(sk, skb, + &iph->saddr, &iph->daddr, + AF_INET, dif, sdif); + if (unlikely(drop_reason)) { + sk_drops_add(sk, skb); + reqsk_put(req); + goto discard_it; + } + if (tcp_checksum_complete(skb)) { + reqsk_put(req); + goto csum_error; + } + if (unlikely(sk->sk_state != TCP_LISTEN)) { + nsk = reuseport_migrate_sock(sk, req_to_sk(req), skb); + if (!nsk) { + inet_csk_reqsk_queue_drop_and_put(sk, req); + goto lookup; + } + sk = nsk; + /* reuseport_migrate_sock() has already held one sk_refcnt + * before returning. + */ + } else { + /* We own a reference on the listener, increase it again + * as we might lose it too soon. + */ + sock_hold(sk); + } + refcounted = true; + nsk = NULL; + if (!tcp_filter(sk, skb)) { + th = (const struct tcphdr *)skb->data; + iph = ip_hdr(skb); + tcp_v4_fill_cb(skb, iph, th); + nsk = tcp_check_req(sk, skb, req, false, &req_stolen); + } else { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + } + if (!nsk) { + reqsk_put(req); + if (req_stolen) { + /* Another cpu got exclusive access to req + * and created a full blown socket. + * Try to feed this packet to this socket + * instead of discarding it. 
+ */ + tcp_v4_restore_cb(skb); + sock_put(sk); + goto lookup; + } + goto discard_and_relse; + } + nf_reset_ct(skb); + if (nsk == sk) { + reqsk_put(req); + tcp_v4_restore_cb(skb); + } else if (tcp_child_process(sk, nsk, skb)) { + tcp_v4_send_reset(nsk, skb); + goto discard_and_relse; + } else { + sock_put(sk); + return 0; + } + } + + if (static_branch_unlikely(&ip4_min_ttl)) { + /* min_ttl can be changed concurrently from do_ip_setsockopt() */ + if (unlikely(iph->ttl < READ_ONCE(inet_sk(sk)->min_ttl))) { + __NET_INC_STATS(net, LINUX_MIB_TCPMINTTLDROP); + goto discard_and_relse; + } + } + + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + goto discard_and_relse; + } + + drop_reason = tcp_inbound_md5_hash(sk, skb, &iph->saddr, + &iph->daddr, AF_INET, dif, sdif); + if (drop_reason) + goto discard_and_relse; + + nf_reset_ct(skb); + + if (tcp_filter(sk, skb)) { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + goto discard_and_relse; + } + th = (const struct tcphdr *)skb->data; + iph = ip_hdr(skb); + tcp_v4_fill_cb(skb, iph, th); + + skb->dev = NULL; + + if (sk->sk_state == TCP_LISTEN) { + ret = tcp_v4_do_rcv(sk, skb); + goto put_and_return; + } + + sk_incoming_cpu_update(sk); + + bh_lock_sock_nested(sk); + tcp_segs_in(tcp_sk(sk), skb); + ret = 0; + if (!sock_owned_by_user(sk)) { + ret = tcp_v4_do_rcv(sk, skb); + } else { + if (tcp_add_backlog(sk, skb, &drop_reason)) + goto discard_and_relse; + } + bh_unlock_sock(sk); + +put_and_return: + if (refcounted) + sock_put(sk); + + return ret; + +no_tcp_socket: + drop_reason = SKB_DROP_REASON_NO_SOCKET; + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto discard_it; + + tcp_v4_fill_cb(skb, iph, th); + + if (tcp_checksum_complete(skb)) { +csum_error: + drop_reason = SKB_DROP_REASON_TCP_CSUM; + trace_tcp_bad_csum(skb); + __TCP_INC_STATS(net, TCP_MIB_CSUMERRORS); +bad_packet: + __TCP_INC_STATS(net, TCP_MIB_INERRS); + } else { + tcp_v4_send_reset(NULL, skb); + } + +discard_it: + SKB_DR_OR(drop_reason, NOT_SPECIFIED); + /* Discard frame. 
*/ + kfree_skb_reason(skb, drop_reason); + return 0; + +discard_and_relse: + sk_drops_add(sk, skb); + if (refcounted) + sock_put(sk); + goto discard_it; + +do_time_wait: + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + inet_twsk_put(inet_twsk(sk)); + goto discard_it; + } + + tcp_v4_fill_cb(skb, iph, th); + + if (tcp_checksum_complete(skb)) { + inet_twsk_put(inet_twsk(sk)); + goto csum_error; + } + switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) { + case TCP_TW_SYN: { + struct sock *sk2 = inet_lookup_listener(net, + net->ipv4.tcp_death_row.hashinfo, + skb, __tcp_hdrlen(th), + iph->saddr, th->source, + iph->daddr, th->dest, + inet_iif(skb), + sdif); + if (sk2) { + inet_twsk_deschedule_put(inet_twsk(sk)); + sk = sk2; + tcp_v4_restore_cb(skb); + refcounted = false; + goto process; + } + } + /* to ACK */ + fallthrough; + case TCP_TW_ACK: + tcp_v4_timewait_ack(sk, skb); + break; + case TCP_TW_RST: + tcp_v4_send_reset(sk, skb); + inet_twsk_deschedule_put(inet_twsk(sk)); + goto discard_it; + case TCP_TW_SUCCESS:; + } + goto discard_it; +} + +static struct timewait_sock_ops tcp_timewait_sock_ops = { + .twsk_obj_size = sizeof(struct tcp_timewait_sock), + .twsk_unique = tcp_twsk_unique, + .twsk_destructor= tcp_twsk_destructor, +}; + +void inet_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + + if (dst && dst_hold_safe(dst)) { + rcu_assign_pointer(sk->sk_rx_dst, dst); + sk->sk_rx_dst_ifindex = skb->skb_iif; + } +} +EXPORT_SYMBOL(inet_sk_rx_dst_set); + +const struct inet_connection_sock_af_ops ipv4_specific = { + .queue_xmit = ip_queue_xmit, + .send_check = tcp_v4_send_check, + .rebuild_header = inet_sk_rebuild_header, + .sk_rx_dst_set = inet_sk_rx_dst_set, + .conn_request = tcp_v4_conn_request, + .syn_recv_sock = tcp_v4_syn_recv_sock, + .net_header_len = sizeof(struct iphdr), + .setsockopt = ip_setsockopt, + .getsockopt = ip_getsockopt, + .addr2sockaddr = inet_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct sockaddr_in), + .mtu_reduced = tcp_v4_mtu_reduced, +}; +EXPORT_SYMBOL(ipv4_specific); + +#ifdef CONFIG_TCP_MD5SIG +static const struct tcp_sock_af_ops tcp_sock_ipv4_specific = { + .md5_lookup = tcp_v4_md5_lookup, + .calc_md5_hash = tcp_v4_md5_hash_skb, + .md5_parse = tcp_v4_parse_md5_keys, +}; +#endif + +/* NOTE: A lot of things set to zero explicitly by call to + * sk_alloc() so need not be done here. + */ +static int tcp_v4_init_sock(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_init_sock(sk); + + icsk->icsk_af_ops = &ipv4_specific; + +#ifdef CONFIG_TCP_MD5SIG + tcp_sk(sk)->af_specific = &tcp_sock_ipv4_specific; +#endif + + return 0; +} + +void tcp_v4_destroy_sock(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + trace_tcp_destroy_sock(sk); + + tcp_clear_xmit_timers(sk); + + tcp_cleanup_congestion_control(sk); + + tcp_cleanup_ulp(sk); + + /* Cleanup up the write buffer. */ + tcp_write_queue_purge(sk); + + /* Check if we want to disable active TFO */ + tcp_fastopen_active_disable_ofo_check(sk); + + /* Cleans up our, hopefully empty, out_of_order_queue. */ + skb_rbtree_purge(&tp->out_of_order_queue); + +#ifdef CONFIG_TCP_MD5SIG + /* Clean up the MD5 key list, if any */ + if (tp->md5sig_info) { + tcp_clear_md5_list(sk); + kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu); + tp->md5sig_info = NULL; + } +#endif + + /* Clean up a referenced TCP bind bucket. 
*/ + if (inet_csk(sk)->icsk_bind_hash) + inet_put_port(sk); + + BUG_ON(rcu_access_pointer(tp->fastopen_rsk)); + + /* If socket is aborted during connect operation */ + tcp_free_fastopen_req(tp); + tcp_fastopen_destroy_cipher(sk); + tcp_saved_syn_free(tp); + + sk_sockets_allocated_dec(sk); +} +EXPORT_SYMBOL(tcp_v4_destroy_sock); + +#ifdef CONFIG_PROC_FS +/* Proc filesystem TCP sock list dumping. */ + +static unsigned short seq_file_family(const struct seq_file *seq); + +static bool seq_sk_match(struct seq_file *seq, const struct sock *sk) +{ + unsigned short family = seq_file_family(seq); + + /* AF_UNSPEC is used as a match all */ + return ((family == AF_UNSPEC || family == sk->sk_family) && + net_eq(sock_net(sk), seq_file_net(seq))); +} + +/* Find a non empty bucket (starting from st->bucket) + * and return the first sk from it. + */ +static void *listening_get_first(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct tcp_iter_state *st = seq->private; + + st->offset = 0; + for (; st->bucket <= hinfo->lhash2_mask; st->bucket++) { + struct inet_listen_hashbucket *ilb2; + struct hlist_nulls_node *node; + struct sock *sk; + + ilb2 = &hinfo->lhash2[st->bucket]; + if (hlist_nulls_empty(&ilb2->nulls_head)) + continue; + + spin_lock(&ilb2->lock); + sk_nulls_for_each(sk, node, &ilb2->nulls_head) { + if (seq_sk_match(seq, sk)) + return sk; + } + spin_unlock(&ilb2->lock); + } + + return NULL; +} + +/* Find the next sk of "cur" within the same bucket (i.e. st->bucket). + * If "cur" is the last one in the st->bucket, + * call listening_get_first() to return the first sk of the next + * non empty bucket. + */ +static void *listening_get_next(struct seq_file *seq, void *cur) +{ + struct tcp_iter_state *st = seq->private; + struct inet_listen_hashbucket *ilb2; + struct hlist_nulls_node *node; + struct inet_hashinfo *hinfo; + struct sock *sk = cur; + + ++st->num; + ++st->offset; + + sk = sk_nulls_next(sk); + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) + return sk; + } + + hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + ilb2 = &hinfo->lhash2[st->bucket]; + spin_unlock(&ilb2->lock); + ++st->bucket; + return listening_get_first(seq); +} + +static void *listening_get_idx(struct seq_file *seq, loff_t *pos) +{ + struct tcp_iter_state *st = seq->private; + void *rc; + + st->bucket = 0; + st->offset = 0; + rc = listening_get_first(seq); + + while (rc && *pos) { + rc = listening_get_next(seq, rc); + --*pos; + } + return rc; +} + +static inline bool empty_bucket(struct inet_hashinfo *hinfo, + const struct tcp_iter_state *st) +{ + return hlist_nulls_empty(&hinfo->ehash[st->bucket].chain); +} + +/* + * Get first established socket starting from bucket given in st->bucket. + * If st->bucket is zero, the very first socket in the hash is returned. 
+ */ +static void *established_get_first(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct tcp_iter_state *st = seq->private; + + st->offset = 0; + for (; st->bucket <= hinfo->ehash_mask; ++st->bucket) { + struct sock *sk; + struct hlist_nulls_node *node; + spinlock_t *lock = inet_ehash_lockp(hinfo, st->bucket); + + /* Lockless fast path for the common case of empty buckets */ + if (empty_bucket(hinfo, st)) + continue; + + spin_lock_bh(lock); + sk_nulls_for_each(sk, node, &hinfo->ehash[st->bucket].chain) { + if (seq_sk_match(seq, sk)) + return sk; + } + spin_unlock_bh(lock); + } + + return NULL; +} + +static void *established_get_next(struct seq_file *seq, void *cur) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct tcp_iter_state *st = seq->private; + struct hlist_nulls_node *node; + struct sock *sk = cur; + + ++st->num; + ++st->offset; + + sk = sk_nulls_next(sk); + + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) + return sk; + } + + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); + ++st->bucket; + return established_get_first(seq); +} + +static void *established_get_idx(struct seq_file *seq, loff_t pos) +{ + struct tcp_iter_state *st = seq->private; + void *rc; + + st->bucket = 0; + rc = established_get_first(seq); + + while (rc && pos) { + rc = established_get_next(seq, rc); + --pos; + } + return rc; +} + +static void *tcp_get_idx(struct seq_file *seq, loff_t pos) +{ + void *rc; + struct tcp_iter_state *st = seq->private; + + st->state = TCP_SEQ_STATE_LISTENING; + rc = listening_get_idx(seq, &pos); + + if (!rc) { + st->state = TCP_SEQ_STATE_ESTABLISHED; + rc = established_get_idx(seq, pos); + } + + return rc; +} + +static void *tcp_seek_last_pos(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct tcp_iter_state *st = seq->private; + int bucket = st->bucket; + int offset = st->offset; + int orig_num = st->num; + void *rc = NULL; + + switch (st->state) { + case TCP_SEQ_STATE_LISTENING: + if (st->bucket > hinfo->lhash2_mask) + break; + st->state = TCP_SEQ_STATE_LISTENING; + rc = listening_get_first(seq); + while (offset-- && rc && bucket == st->bucket) + rc = listening_get_next(seq, rc); + if (rc) + break; + st->bucket = 0; + st->state = TCP_SEQ_STATE_ESTABLISHED; + fallthrough; + case TCP_SEQ_STATE_ESTABLISHED: + if (st->bucket > hinfo->ehash_mask) + break; + rc = established_get_first(seq); + while (offset-- && rc && bucket == st->bucket) + rc = established_get_next(seq, rc); + } + + st->num = orig_num; + + return rc; +} + +void *tcp_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct tcp_iter_state *st = seq->private; + void *rc; + + if (*pos && *pos == st->last_pos) { + rc = tcp_seek_last_pos(seq); + if (rc) + goto out; + } + + st->state = TCP_SEQ_STATE_LISTENING; + st->num = 0; + st->bucket = 0; + st->offset = 0; + rc = *pos ? 
tcp_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; + +out: + st->last_pos = *pos; + return rc; +} +EXPORT_SYMBOL(tcp_seq_start); + +void *tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct tcp_iter_state *st = seq->private; + void *rc = NULL; + + if (v == SEQ_START_TOKEN) { + rc = tcp_get_idx(seq, 0); + goto out; + } + + switch (st->state) { + case TCP_SEQ_STATE_LISTENING: + rc = listening_get_next(seq, v); + if (!rc) { + st->state = TCP_SEQ_STATE_ESTABLISHED; + st->bucket = 0; + st->offset = 0; + rc = established_get_first(seq); + } + break; + case TCP_SEQ_STATE_ESTABLISHED: + rc = established_get_next(seq, v); + break; + } +out: + ++*pos; + st->last_pos = *pos; + return rc; +} +EXPORT_SYMBOL(tcp_seq_next); + +void tcp_seq_stop(struct seq_file *seq, void *v) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct tcp_iter_state *st = seq->private; + + switch (st->state) { + case TCP_SEQ_STATE_LISTENING: + if (v != SEQ_START_TOKEN) + spin_unlock(&hinfo->lhash2[st->bucket].lock); + break; + case TCP_SEQ_STATE_ESTABLISHED: + if (v) + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); + break; + } +} +EXPORT_SYMBOL(tcp_seq_stop); + +static void get_openreq4(const struct request_sock *req, + struct seq_file *f, int i) +{ + const struct inet_request_sock *ireq = inet_rsk(req); + long delta = req->rsk_timer.expires - jiffies; + + seq_printf(f, "%4d: %08X:%04X %08X:%04X" + " %02X %08X:%08X %02X:%08lX %08X %5u %8d %u %d %pK", + i, + ireq->ir_loc_addr, + ireq->ir_num, + ireq->ir_rmt_addr, + ntohs(ireq->ir_rmt_port), + TCP_SYN_RECV, + 0, 0, /* could print option size, but that is af dependent. */ + 1, /* timers active (only the expire timer) */ + jiffies_delta_to_clock_t(delta), + req->num_timeout, + from_kuid_munged(seq_user_ns(f), + sock_i_uid(req->rsk_listener)), + 0, /* non standard timer */ + 0, /* open_requests have no inode */ + 0, + req); +} + +static void get_tcp4_sock(struct sock *sk, struct seq_file *f, int i) +{ + int timer_active; + unsigned long timer_expires; + const struct tcp_sock *tp = tcp_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + const struct inet_sock *inet = inet_sk(sk); + const struct fastopen_queue *fastopenq = &icsk->icsk_accept_queue.fastopenq; + __be32 dest = inet->inet_daddr; + __be32 src = inet->inet_rcv_saddr; + __u16 destp = ntohs(inet->inet_dport); + __u16 srcp = ntohs(inet->inet_sport); + int rx_queue; + int state; + + if (icsk->icsk_pending == ICSK_TIME_RETRANS || + icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + timer_active = 1; + timer_expires = icsk->icsk_timeout; + } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) { + timer_active = 4; + timer_expires = icsk->icsk_timeout; + } else if (timer_pending(&sk->sk_timer)) { + timer_active = 2; + timer_expires = sk->sk_timer.expires; + } else { + timer_active = 0; + timer_expires = jiffies; + } + + state = inet_sk_state_load(sk); + if (state == TCP_LISTEN) + rx_queue = READ_ONCE(sk->sk_ack_backlog); + else + /* Because we don't lock the socket, + * we might find a transient negative value. 
+ */ + rx_queue = max_t(int, READ_ONCE(tp->rcv_nxt) - + READ_ONCE(tp->copied_seq), 0); + + seq_printf(f, "%4d: %08X:%04X %08X:%04X %02X %08X:%08X %02X:%08lX " + "%08X %5u %8d %lu %d %pK %lu %lu %u %u %d", + i, src, srcp, dest, destp, state, + READ_ONCE(tp->write_seq) - tp->snd_una, + rx_queue, + timer_active, + jiffies_delta_to_clock_t(timer_expires - jiffies), + icsk->icsk_retransmits, + from_kuid_munged(seq_user_ns(f), sock_i_uid(sk)), + icsk->icsk_probes_out, + sock_i_ino(sk), + refcount_read(&sk->sk_refcnt), sk, + jiffies_to_clock_t(icsk->icsk_rto), + jiffies_to_clock_t(icsk->icsk_ack.ato), + (icsk->icsk_ack.quick << 1) | inet_csk_in_pingpong_mode(sk), + tcp_snd_cwnd(tp), + state == TCP_LISTEN ? + fastopenq->max_qlen : + (tcp_in_initial_slowstart(tp) ? -1 : tp->snd_ssthresh)); +} + +static void get_timewait4_sock(const struct inet_timewait_sock *tw, + struct seq_file *f, int i) +{ + long delta = tw->tw_timer.expires - jiffies; + __be32 dest, src; + __u16 destp, srcp; + + dest = tw->tw_daddr; + src = tw->tw_rcv_saddr; + destp = ntohs(tw->tw_dport); + srcp = ntohs(tw->tw_sport); + + seq_printf(f, "%4d: %08X:%04X %08X:%04X" + " %02X %08X:%08X %02X:%08lX %08X %5d %8d %d %d %pK", + i, src, srcp, dest, destp, tw->tw_substate, 0, 0, + 3, jiffies_delta_to_clock_t(delta), 0, 0, 0, 0, + refcount_read(&tw->tw_refcnt), tw); +} + +#define TMPSZ 150 + +static int tcp4_seq_show(struct seq_file *seq, void *v) +{ + struct tcp_iter_state *st; + struct sock *sk = v; + + seq_setwidth(seq, TMPSZ - 1); + if (v == SEQ_START_TOKEN) { + seq_puts(seq, " sl local_address rem_address st tx_queue " + "rx_queue tr tm->when retrnsmt uid timeout " + "inode"); + goto out; + } + st = seq->private; + + if (sk->sk_state == TCP_TIME_WAIT) + get_timewait4_sock(v, seq, st->num); + else if (sk->sk_state == TCP_NEW_SYN_RECV) + get_openreq4(v, seq, st->num); + else + get_tcp4_sock(v, seq, st->num); +out: + seq_pad(seq, '\n'); + return 0; +} + +#ifdef CONFIG_BPF_SYSCALL +struct bpf_tcp_iter_state { + struct tcp_iter_state state; + unsigned int cur_sk; + unsigned int end_sk; + unsigned int max_sk; + struct sock **batch; + bool st_bucket_done; +}; + +struct bpf_iter__tcp { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct sock_common *, sk_common); + uid_t uid __aligned(8); +}; + +static int tcp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, + struct sock_common *sk_common, uid_t uid) +{ + struct bpf_iter__tcp ctx; + + meta->seq_num--; /* skip SEQ_START_TOKEN */ + ctx.meta = meta; + ctx.sk_common = sk_common; + ctx.uid = uid; + return bpf_iter_run_prog(prog, &ctx); +} + +static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter) +{ + while (iter->cur_sk < iter->end_sk) + sock_gen_put(iter->batch[iter->cur_sk++]); +} + +static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, + unsigned int new_batch_sz) +{ + struct sock **new_batch; + + new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, + GFP_USER | __GFP_NOWARN); + if (!new_batch) + return -ENOMEM; + + bpf_iter_tcp_put_batch(iter); + kvfree(iter->batch); + iter->batch = new_batch; + iter->max_sk = new_batch_sz; + + return 0; +} + +static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct hlist_nulls_node *node; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + 
iter->batch[iter->end_sk++] = start_sk; + + sk = sk_nulls_next(start_sk); + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock(&hinfo->lhash2[st->bucket].lock); + + return expected; +} + +static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, + struct sock *start_sk) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct hlist_nulls_node *node; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + sk = sk_nulls_next(start_sk); + sk_nulls_for_each_from(sk, node) { + if (seq_sk_match(seq, sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + expected++; + } + } + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); + + return expected; +} + +static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) +{ + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + unsigned int expected; + bool resized = false; + struct sock *sk; + + /* The st->bucket is done. Directly advance to the next + * bucket instead of having the tcp_seek_last_pos() to skip + * one by one in the current bucket and eventually find out + * it has to advance to the next bucket. + */ + if (iter->st_bucket_done) { + st->offset = 0; + st->bucket++; + if (st->state == TCP_SEQ_STATE_LISTENING && + st->bucket > hinfo->lhash2_mask) { + st->state = TCP_SEQ_STATE_ESTABLISHED; + st->bucket = 0; + } + } + +again: + /* Get a new batch */ + iter->cur_sk = 0; + iter->end_sk = 0; + iter->st_bucket_done = false; + + sk = tcp_seek_last_pos(seq); + if (!sk) + return NULL; /* Done */ + + if (st->state == TCP_SEQ_STATE_LISTENING) + expected = bpf_iter_tcp_listening_batch(seq, sk); + else + expected = bpf_iter_tcp_established_batch(seq, sk); + + if (iter->end_sk == expected) { + iter->st_bucket_done = true; + return sk; + } + + if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) { + resized = true; + goto again; + } + + return sk; +} + +static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) +{ + /* bpf iter does not support lseek, so it always + * continue from where it was stop()-ped. + */ + if (*pos) + return bpf_iter_tcp_batch(seq); + + return SEQ_START_TOKEN; +} + +static void *bpf_iter_tcp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct tcp_iter_state *st = &iter->state; + struct sock *sk; + + /* Whenever seq_next() is called, the iter->cur_sk is + * done with seq_show(), so advance to the next sk in + * the batch. + */ + if (iter->cur_sk < iter->end_sk) { + /* Keeping st->num consistent in tcp_iter_state. + * bpf_iter_tcp does not use st->num. + * meta.seq_num is used instead. + */ + st->num++; + /* Move st->offset to the next sk in the bucket such that + * the future start() will resume at st->offset in + * st->bucket. See tcp_seek_last_pos(). + */ + st->offset++; + sock_gen_put(iter->batch[iter->cur_sk++]); + } + + if (iter->cur_sk < iter->end_sk) + sk = iter->batch[iter->cur_sk]; + else + sk = bpf_iter_tcp_batch(seq); + + ++*pos; + /* Keeping st->last_pos consistent in tcp_iter_state. 
+ * bpf iter does not do lseek, so st->last_pos always equals to *pos. + */ + st->last_pos = *pos; + return sk; +} + +static int bpf_iter_tcp_seq_show(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + struct sock *sk = v; + uid_t uid; + int ret; + + if (v == SEQ_START_TOKEN) + return 0; + + if (sk_fullsock(sk)) + lock_sock(sk); + + if (unlikely(sk_unhashed(sk))) { + ret = SEQ_SKIP; + goto unlock; + } + + if (sk->sk_state == TCP_TIME_WAIT) { + uid = 0; + } else if (sk->sk_state == TCP_NEW_SYN_RECV) { + const struct request_sock *req = v; + + uid = from_kuid_munged(seq_user_ns(seq), + sock_i_uid(req->rsk_listener)); + } else { + uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)); + } + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, false); + ret = tcp_prog_seq_show(prog, &meta, v, uid); + +unlock: + if (sk_fullsock(sk)) + release_sock(sk); + return ret; + +} + +static void bpf_iter_tcp_seq_stop(struct seq_file *seq, void *v) +{ + struct bpf_tcp_iter_state *iter = seq->private; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + if (!v) { + meta.seq = seq; + prog = bpf_iter_get_info(&meta, true); + if (prog) + (void)tcp_prog_seq_show(prog, &meta, v, 0); + } + + if (iter->cur_sk < iter->end_sk) { + bpf_iter_tcp_put_batch(iter); + iter->st_bucket_done = false; + } +} + +static const struct seq_operations bpf_iter_tcp_seq_ops = { + .show = bpf_iter_tcp_seq_show, + .start = bpf_iter_tcp_seq_start, + .next = bpf_iter_tcp_seq_next, + .stop = bpf_iter_tcp_seq_stop, +}; +#endif +static unsigned short seq_file_family(const struct seq_file *seq) +{ + const struct tcp_seq_afinfo *afinfo; + +#ifdef CONFIG_BPF_SYSCALL + /* Iterated from bpf_iter. Let the bpf prog to filter instead. */ + if (seq->op == &bpf_iter_tcp_seq_ops) + return AF_UNSPEC; +#endif + + /* Iterated from proc fs */ + afinfo = pde_data(file_inode(seq->file)); + return afinfo->family; +} + +static const struct seq_operations tcp4_seq_ops = { + .show = tcp4_seq_show, + .start = tcp_seq_start, + .next = tcp_seq_next, + .stop = tcp_seq_stop, +}; + +static struct tcp_seq_afinfo tcp4_seq_afinfo = { + .family = AF_INET, +}; + +static int __net_init tcp4_proc_init_net(struct net *net) +{ + if (!proc_create_net_data("tcp", 0444, net->proc_net, &tcp4_seq_ops, + sizeof(struct tcp_iter_state), &tcp4_seq_afinfo)) + return -ENOMEM; + return 0; +} + +static void __net_exit tcp4_proc_exit_net(struct net *net) +{ + remove_proc_entry("tcp", net->proc_net); +} + +static struct pernet_operations tcp4_net_ops = { + .init = tcp4_proc_init_net, + .exit = tcp4_proc_exit_net, +}; + +int __init tcp4_proc_init(void) +{ + return register_pernet_subsys(&tcp4_net_ops); +} + +void tcp4_proc_exit(void) +{ + unregister_pernet_subsys(&tcp4_net_ops); +} +#endif /* CONFIG_PROC_FS */ + +/* @wake is one when sk_stream_write_space() calls us. + * This sends EPOLLOUT only if notsent_bytes is half the limit. + * This mimics the strategy used in sock_def_write_space(). 
+ */ +bool tcp_stream_memory_free(const struct sock *sk, int wake) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u32 notsent_bytes = READ_ONCE(tp->write_seq) - + READ_ONCE(tp->snd_nxt); + + return (notsent_bytes << wake) < tcp_notsent_lowat(tp); +} +EXPORT_SYMBOL(tcp_stream_memory_free); + +struct proto tcp_prot = { + .name = "TCP", + .owner = THIS_MODULE, + .close = tcp_close, + .pre_connect = tcp_v4_pre_connect, + .connect = tcp_v4_connect, + .disconnect = tcp_disconnect, + .accept = inet_csk_accept, + .ioctl = tcp_ioctl, + .init = tcp_v4_init_sock, + .destroy = tcp_v4_destroy_sock, + .shutdown = tcp_shutdown, + .setsockopt = tcp_setsockopt, + .getsockopt = tcp_getsockopt, + .bpf_bypass_getsockopt = tcp_bpf_bypass_getsockopt, + .keepalive = tcp_set_keepalive, + .recvmsg = tcp_recvmsg, + .sendmsg = tcp_sendmsg, + .splice_eof = tcp_splice_eof, + .sendpage = tcp_sendpage, + .backlog_rcv = tcp_v4_do_rcv, + .release_cb = tcp_release_cb, + .hash = inet_hash, + .unhash = inet_unhash, + .get_port = inet_csk_get_port, + .put_port = inet_put_port, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = tcp_bpf_update_proto, +#endif + .enter_memory_pressure = tcp_enter_memory_pressure, + .leave_memory_pressure = tcp_leave_memory_pressure, + .stream_memory_free = tcp_stream_memory_free, + .sockets_allocated = &tcp_sockets_allocated, + .orphan_count = &tcp_orphan_count, + + .memory_allocated = &tcp_memory_allocated, + .per_cpu_fw_alloc = &tcp_memory_per_cpu_fw_alloc, + + .memory_pressure = &tcp_memory_pressure, + .sysctl_mem = sysctl_tcp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_tcp_rmem), + .max_header = MAX_TCP_HEADER, + .obj_size = sizeof(struct tcp_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, + .twsk_prot = &tcp_timewait_sock_ops, + .rsk_prot = &tcp_request_sock_ops, + .h.hashinfo = NULL, + .no_autobind = true, + .diag_destroy = tcp_abort, +}; +EXPORT_SYMBOL(tcp_prot); + +static void __net_exit tcp_sk_exit(struct net *net) +{ + if (net->ipv4.tcp_congestion_control) + bpf_module_put(net->ipv4.tcp_congestion_control, + net->ipv4.tcp_congestion_control->owner); +} + +static void __net_init tcp_set_hashinfo(struct net *net) +{ + struct inet_hashinfo *hinfo; + unsigned int ehash_entries; + struct net *old_net; + + if (net_eq(net, &init_net)) + goto fallback; + + old_net = current->nsproxy->net_ns; + ehash_entries = READ_ONCE(old_net->ipv4.sysctl_tcp_child_ehash_entries); + if (!ehash_entries) + goto fallback; + + ehash_entries = roundup_pow_of_two(ehash_entries); + hinfo = inet_pernet_hashinfo_alloc(&tcp_hashinfo, ehash_entries); + if (!hinfo) { + pr_warn("Failed to allocate TCP ehash (entries: %u) " + "for a netns, fallback to the global one\n", + ehash_entries); +fallback: + hinfo = &tcp_hashinfo; + ehash_entries = tcp_hashinfo.ehash_mask + 1; + } + + net->ipv4.tcp_death_row.hashinfo = hinfo; + net->ipv4.tcp_death_row.sysctl_max_tw_buckets = ehash_entries / 2; + net->ipv4.sysctl_max_syn_backlog = max(128U, ehash_entries / 128); +} + +static int __net_init tcp_sk_init(struct net *net) +{ + net->ipv4.sysctl_tcp_ecn = 2; + net->ipv4.sysctl_tcp_ecn_fallback = 1; + + net->ipv4.sysctl_tcp_base_mss = TCP_BASE_MSS; + net->ipv4.sysctl_tcp_min_snd_mss = TCP_MIN_SND_MSS; + net->ipv4.sysctl_tcp_probe_threshold = TCP_PROBE_THRESHOLD; + net->ipv4.sysctl_tcp_probe_interval = TCP_PROBE_INTERVAL; + net->ipv4.sysctl_tcp_mtu_probe_floor = TCP_MIN_SND_MSS; + + net->ipv4.sysctl_tcp_keepalive_time = TCP_KEEPALIVE_TIME; + 
net->ipv4.sysctl_tcp_keepalive_probes = TCP_KEEPALIVE_PROBES; + net->ipv4.sysctl_tcp_keepalive_intvl = TCP_KEEPALIVE_INTVL; + + net->ipv4.sysctl_tcp_syn_retries = TCP_SYN_RETRIES; + net->ipv4.sysctl_tcp_synack_retries = TCP_SYNACK_RETRIES; + net->ipv4.sysctl_tcp_syncookies = 1; + net->ipv4.sysctl_tcp_reordering = TCP_FASTRETRANS_THRESH; + net->ipv4.sysctl_tcp_retries1 = TCP_RETR1; + net->ipv4.sysctl_tcp_retries2 = TCP_RETR2; + net->ipv4.sysctl_tcp_orphan_retries = 0; + net->ipv4.sysctl_tcp_fin_timeout = TCP_FIN_TIMEOUT; + net->ipv4.sysctl_tcp_notsent_lowat = UINT_MAX; + net->ipv4.sysctl_tcp_tw_reuse = 2; + net->ipv4.sysctl_tcp_no_ssthresh_metrics_save = 1; + + refcount_set(&net->ipv4.tcp_death_row.tw_refcount, 1); + tcp_set_hashinfo(net); + + net->ipv4.sysctl_tcp_sack = 1; + net->ipv4.sysctl_tcp_window_scaling = 1; + net->ipv4.sysctl_tcp_timestamps = 1; + net->ipv4.sysctl_tcp_early_retrans = 3; + net->ipv4.sysctl_tcp_recovery = TCP_RACK_LOSS_DETECTION; + net->ipv4.sysctl_tcp_slow_start_after_idle = 1; /* By default, RFC2861 behavior. */ + net->ipv4.sysctl_tcp_retrans_collapse = 1; + net->ipv4.sysctl_tcp_max_reordering = 300; + net->ipv4.sysctl_tcp_dsack = 1; + net->ipv4.sysctl_tcp_app_win = 31; + net->ipv4.sysctl_tcp_adv_win_scale = 1; + net->ipv4.sysctl_tcp_frto = 2; + net->ipv4.sysctl_tcp_moderate_rcvbuf = 1; + /* This limits the percentage of the congestion window which we + * will allow a single TSO frame to consume. Building TSO frames + * which are too large can cause TCP streams to be bursty. + */ + net->ipv4.sysctl_tcp_tso_win_divisor = 3; + /* Default TSQ limit of 16 TSO segments */ + net->ipv4.sysctl_tcp_limit_output_bytes = 16 * 65536; + + /* rfc5961 challenge ack rate limiting, per net-ns, disabled by default. */ + net->ipv4.sysctl_tcp_challenge_ack_limit = INT_MAX; + + net->ipv4.sysctl_tcp_min_tso_segs = 2; + net->ipv4.sysctl_tcp_tso_rtt_log = 9; /* 2^9 = 512 usec */ + net->ipv4.sysctl_tcp_min_rtt_wlen = 300; + net->ipv4.sysctl_tcp_autocorking = 1; + net->ipv4.sysctl_tcp_invalid_ratelimit = HZ/2; + net->ipv4.sysctl_tcp_pacing_ss_ratio = 200; + net->ipv4.sysctl_tcp_pacing_ca_ratio = 120; + if (net != &init_net) { + memcpy(net->ipv4.sysctl_tcp_rmem, + init_net.ipv4.sysctl_tcp_rmem, + sizeof(init_net.ipv4.sysctl_tcp_rmem)); + memcpy(net->ipv4.sysctl_tcp_wmem, + init_net.ipv4.sysctl_tcp_wmem, + sizeof(init_net.ipv4.sysctl_tcp_wmem)); + } + net->ipv4.sysctl_tcp_comp_sack_delay_ns = NSEC_PER_MSEC; + net->ipv4.sysctl_tcp_comp_sack_slack_ns = 100 * NSEC_PER_USEC; + net->ipv4.sysctl_tcp_comp_sack_nr = 44; + net->ipv4.sysctl_tcp_fastopen = TFO_CLIENT_ENABLE; + net->ipv4.sysctl_tcp_fastopen_blackhole_timeout = 0; + atomic_set(&net->ipv4.tfo_active_disable_times, 0); + + /* Reno is always built in */ + if (!net_eq(net, &init_net) && + bpf_try_module_get(init_net.ipv4.tcp_congestion_control, + init_net.ipv4.tcp_congestion_control->owner)) + net->ipv4.tcp_congestion_control = init_net.ipv4.tcp_congestion_control; + else + net->ipv4.tcp_congestion_control = &tcp_reno; + + net->ipv4.sysctl_tcp_shrink_window = 0; + + return 0; +} + +static void __net_exit tcp_sk_exit_batch(struct list_head *net_exit_list) +{ + struct net *net; + + tcp_twsk_purge(net_exit_list, AF_INET); + + list_for_each_entry(net, net_exit_list, exit_list) { + inet_pernet_hashinfo_free(net->ipv4.tcp_death_row.hashinfo); + WARN_ON_ONCE(!refcount_dec_and_test(&net->ipv4.tcp_death_row.tw_refcount)); + tcp_fastopen_ctx_destroy(net); + } +} + +static struct pernet_operations __net_initdata tcp_sk_ops = { + .init = tcp_sk_init, + 
.exit = tcp_sk_exit, + .exit_batch = tcp_sk_exit_batch, +}; + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) +DEFINE_BPF_ITER_FUNC(tcp, struct bpf_iter_meta *meta, + struct sock_common *sk_common, uid_t uid) + +#define INIT_BATCH_SZ 16 + +static int bpf_iter_init_tcp(void *priv_data, struct bpf_iter_aux_info *aux) +{ + struct bpf_tcp_iter_state *iter = priv_data; + int err; + + err = bpf_iter_init_seq_net(priv_data, aux); + if (err) + return err; + + err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ); + if (err) { + bpf_iter_fini_seq_net(priv_data); + return err; + } + + return 0; +} + +static void bpf_iter_fini_tcp(void *priv_data) +{ + struct bpf_tcp_iter_state *iter = priv_data; + + bpf_iter_fini_seq_net(priv_data); + kvfree(iter->batch); +} + +static const struct bpf_iter_seq_info tcp_seq_info = { + .seq_ops = &bpf_iter_tcp_seq_ops, + .init_seq_private = bpf_iter_init_tcp, + .fini_seq_private = bpf_iter_fini_tcp, + .seq_priv_size = sizeof(struct bpf_tcp_iter_state), +}; + +static const struct bpf_func_proto * +bpf_iter_tcp_get_func_proto(enum bpf_func_id func_id, + const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_setsockopt: + return &bpf_sk_setsockopt_proto; + case BPF_FUNC_getsockopt: + return &bpf_sk_getsockopt_proto; + default: + return NULL; + } +} + +static struct bpf_iter_reg tcp_reg_info = { + .target = "tcp", + .ctx_arg_info_size = 1, + .ctx_arg_info = { + { offsetof(struct bpf_iter__tcp, sk_common), + PTR_TO_BTF_ID_OR_NULL }, + }, + .get_func_proto = bpf_iter_tcp_get_func_proto, + .seq_info = &tcp_seq_info, +}; + +static void __init bpf_iter_register(void) +{ + tcp_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_SOCK_COMMON]; + if (bpf_iter_reg_target(&tcp_reg_info)) + pr_warn("Warning: could not register bpf iterator tcp\n"); +} + +#endif + +void __init tcp_v4_init(void) +{ + int cpu, res; + + for_each_possible_cpu(cpu) { + struct sock *sk; + + res = inet_ctl_sock_create(&sk, PF_INET, SOCK_RAW, + IPPROTO_TCP, &init_net); + if (res) + panic("Failed to create the TCP control socket.\n"); + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + + /* Please enforce IP_DF and IPID==0 for RST and + * ACK sent in SYN-RECV and TIME-WAIT state. + */ + inet_sk(sk)->pmtudisc = IP_PMTUDISC_DO; + + per_cpu(ipv4_tcp_sk, cpu) = sk; + } + if (register_pernet_subsys(&tcp_sk_ops)) + panic("Failed to create the TCP control socket.\n"); + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + bpf_iter_register(); +#endif +} diff --git a/net/ipv4/tcp_lp.c b/net/ipv4/tcp_lp.c new file mode 100644 index 000000000..ae3678097 --- /dev/null +++ b/net/ipv4/tcp_lp.c @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP Low Priority (TCP-LP) + * + * TCP Low Priority is a distributed algorithm whose goal is to utilize only + * the excess network bandwidth as compared to the ``fair share`` of + * bandwidth as targeted by TCP. + * + * As of 2.6.13, Linux supports pluggable congestion control algorithms. + * Due to the limitation of the API, we take the following changes from + * the original TCP-LP implementation: + * o We use newReno in most core CA handling. Only add some checking + * within cong_avoid. + * o Error correcting in remote HZ, therefore remote HZ will be keeped + * on checking and updating. + * o Handling calculation of One-Way-Delay (OWD) within rtt_sample, since + * OWD have a similar meaning as RTT. Also correct the buggy formular. 
+ * o Handle reaction for Early Congestion Indication (ECI) within + * pkts_acked, as mentioned within pseudo code. + * o OWD is handled in relative format, where local time stamp will in + * tcp_time_stamp format. + * + * Original Author: + * Aleksandar Kuzmanovic + * Available from: + * http://www.ece.rice.edu/~akuzma/Doc/akuzma/TCP-LP.pdf + * Original implementation for 2.4.19: + * http://www-ece.rice.edu/networks/TCP-LP/ + * + * 2.6.x module Authors: + * Wong Hoi Sing, Edison + * Hung Hing Lun, Mike + * SourceForge project page: + * http://tcp-lp-mod.sourceforge.net/ + */ + +#include +#include + +/* resolution of owd */ +#define LP_RESOL TCP_TS_HZ + +/** + * enum tcp_lp_state + * @LP_VALID_RHZ: is remote HZ valid? + * @LP_VALID_OWD: is OWD valid? + * @LP_WITHIN_THR: are we within threshold? + * @LP_WITHIN_INF: are we within inference? + * + * TCP-LP's state flags. + * We create this set of state flag mainly for debugging. + */ +enum tcp_lp_state { + LP_VALID_RHZ = (1 << 0), + LP_VALID_OWD = (1 << 1), + LP_WITHIN_THR = (1 << 3), + LP_WITHIN_INF = (1 << 4), +}; + +/** + * struct lp + * @flag: TCP-LP state flag + * @sowd: smoothed OWD << 3 + * @owd_min: min OWD + * @owd_max: max OWD + * @owd_max_rsv: reserved max owd + * @remote_hz: estimated remote HZ + * @remote_ref_time: remote reference time + * @local_ref_time: local reference time + * @last_drop: time for last active drop + * @inference: current inference + * + * TCP-LP's private struct. + * We get the idea from original TCP-LP implementation where only left those we + * found are really useful. + */ +struct lp { + u32 flag; + u32 sowd; + u32 owd_min; + u32 owd_max; + u32 owd_max_rsv; + u32 remote_hz; + u32 remote_ref_time; + u32 local_ref_time; + u32 last_drop; + u32 inference; +}; + +/** + * tcp_lp_init + * @sk: socket to initialize congestion control algorithm for + * + * Init all required variables. + * Clone the handling from Vegas module implementation. + */ +static void tcp_lp_init(struct sock *sk) +{ + struct lp *lp = inet_csk_ca(sk); + + lp->flag = 0; + lp->sowd = 0; + lp->owd_min = 0xffffffff; + lp->owd_max = 0; + lp->owd_max_rsv = 0; + lp->remote_hz = 0; + lp->remote_ref_time = 0; + lp->local_ref_time = 0; + lp->last_drop = 0; + lp->inference = 0; +} + +/** + * tcp_lp_cong_avoid + * @sk: socket to avoid congesting + * + * Implementation of cong_avoid. + * Will only call newReno CA when away from inference. + * From TCP-LP's paper, this will be handled in additive increasement. + */ +static void tcp_lp_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct lp *lp = inet_csk_ca(sk); + + if (!(lp->flag & LP_WITHIN_INF)) + tcp_reno_cong_avoid(sk, ack, acked); +} + +/** + * tcp_lp_remote_hz_estimator + * @sk: socket which needs an estimate for the remote HZs + * + * Estimate remote HZ. + * We keep on updating the estimated value, where original TCP-LP + * implementation only guest it for once and use forever. + */ +static u32 tcp_lp_remote_hz_estimator(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct lp *lp = inet_csk_ca(sk); + s64 rhz = lp->remote_hz << 6; /* remote HZ << 6 */ + s64 m = 0; + + /* not yet record reference time + * go away!! record it before come back!! */ + if (lp->remote_ref_time == 0 || lp->local_ref_time == 0) + goto out; + + /* we can't calc remote HZ with no different!! 
*/ + if (tp->rx_opt.rcv_tsval == lp->remote_ref_time || + tp->rx_opt.rcv_tsecr == lp->local_ref_time) + goto out; + + m = TCP_TS_HZ * + (tp->rx_opt.rcv_tsval - lp->remote_ref_time) / + (tp->rx_opt.rcv_tsecr - lp->local_ref_time); + if (m < 0) + m = -m; + + if (rhz > 0) { + m -= rhz >> 6; /* m is now error in remote HZ est */ + rhz += m; /* 63/64 old + 1/64 new */ + } else + rhz = m << 6; + + out: + /* record time for successful remote HZ calc */ + if ((rhz >> 6) > 0) + lp->flag |= LP_VALID_RHZ; + else + lp->flag &= ~LP_VALID_RHZ; + + /* record reference time stamp */ + lp->remote_ref_time = tp->rx_opt.rcv_tsval; + lp->local_ref_time = tp->rx_opt.rcv_tsecr; + + return rhz >> 6; +} + +/** + * tcp_lp_owd_calculator + * @sk: socket to calculate one way delay for + * + * Calculate one way delay (in relative format). + * Original implement OWD as minus of remote time difference to local time + * difference directly. As this time difference just simply equal to RTT, when + * the network status is stable, remote RTT will equal to local RTT, and result + * OWD into zero. + * It seems to be a bug and so we fixed it. + */ +static u32 tcp_lp_owd_calculator(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct lp *lp = inet_csk_ca(sk); + s64 owd = 0; + + lp->remote_hz = tcp_lp_remote_hz_estimator(sk); + + if (lp->flag & LP_VALID_RHZ) { + owd = + tp->rx_opt.rcv_tsval * (LP_RESOL / lp->remote_hz) - + tp->rx_opt.rcv_tsecr * (LP_RESOL / TCP_TS_HZ); + if (owd < 0) + owd = -owd; + } + + if (owd > 0) + lp->flag |= LP_VALID_OWD; + else + lp->flag &= ~LP_VALID_OWD; + + return owd; +} + +/** + * tcp_lp_rtt_sample + * @sk: socket to add a rtt sample to + * @rtt: round trip time, which is ignored! + * + * Implementation or rtt_sample. + * Will take the following action, + * 1. calc OWD, + * 2. record the min/max OWD, + * 3. calc smoothed OWD (SOWD). + * Most ideas come from the original TCP-LP implementation. + */ +static void tcp_lp_rtt_sample(struct sock *sk, u32 rtt) +{ + struct lp *lp = inet_csk_ca(sk); + s64 mowd = tcp_lp_owd_calculator(sk); + + /* sorry that we don't have valid data */ + if (!(lp->flag & LP_VALID_RHZ) || !(lp->flag & LP_VALID_OWD)) + return; + + /* record the next min owd */ + if (mowd < lp->owd_min) + lp->owd_min = mowd; + + /* always forget the max of the max + * we just set owd_max as one below it */ + if (mowd > lp->owd_max) { + if (mowd > lp->owd_max_rsv) { + if (lp->owd_max_rsv == 0) + lp->owd_max = mowd; + else + lp->owd_max = lp->owd_max_rsv; + lp->owd_max_rsv = mowd; + } else + lp->owd_max = mowd; + } + + /* calc for smoothed owd */ + if (lp->sowd != 0) { + mowd -= lp->sowd >> 3; /* m is now error in owd est */ + lp->sowd += mowd; /* owd = 7/8 owd + 1/8 new */ + } else + lp->sowd = mowd << 3; /* take the measured time be owd */ +} + +/** + * tcp_lp_pkts_acked + * @sk: socket requiring congestion avoidance calculations + * + * Implementation of pkts_acked. + * Deal with active drop under Early Congestion Indication. + * Only drop to half and 1 will be handle, because we hope to use back + * newReno in increase case. 
+ * We work it out by following the idea from TCP-LP's paper directly + */ +static void tcp_lp_pkts_acked(struct sock *sk, const struct ack_sample *sample) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct lp *lp = inet_csk_ca(sk); + u32 now = tcp_time_stamp(tp); + u32 delta; + + if (sample->rtt_us > 0) + tcp_lp_rtt_sample(sk, sample->rtt_us); + + /* calc inference */ + delta = now - tp->rx_opt.rcv_tsecr; + if ((s32)delta > 0) + lp->inference = 3 * delta; + + /* test if within inference */ + if (lp->last_drop && (now - lp->last_drop < lp->inference)) + lp->flag |= LP_WITHIN_INF; + else + lp->flag &= ~LP_WITHIN_INF; + + /* test if within threshold */ + if (lp->sowd >> 3 < + lp->owd_min + 15 * (lp->owd_max - lp->owd_min) / 100) + lp->flag |= LP_WITHIN_THR; + else + lp->flag &= ~LP_WITHIN_THR; + + pr_debug("TCP-LP: %05o|%5u|%5u|%15u|%15u|%15u\n", lp->flag, + tcp_snd_cwnd(tp), lp->remote_hz, lp->owd_min, lp->owd_max, + lp->sowd >> 3); + + if (lp->flag & LP_WITHIN_THR) + return; + + /* FIXME: try to reset owd_min and owd_max here + * so decrease the chance the min/max is no longer suitable + * and will usually within threshold when within inference */ + lp->owd_min = lp->sowd >> 3; + lp->owd_max = lp->sowd >> 2; + lp->owd_max_rsv = lp->sowd >> 2; + + /* happened within inference + * drop snd_cwnd into 1 */ + if (lp->flag & LP_WITHIN_INF) + tcp_snd_cwnd_set(tp, 1U); + + /* happened after inference + * cut snd_cwnd into half */ + else + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp) >> 1U, 1U)); + + /* record this drop time */ + lp->last_drop = now; +} + +static struct tcp_congestion_ops tcp_lp __read_mostly = { + .init = tcp_lp_init, + .ssthresh = tcp_reno_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_lp_cong_avoid, + .pkts_acked = tcp_lp_pkts_acked, + + .owner = THIS_MODULE, + .name = "lp" +}; + +static int __init tcp_lp_register(void) +{ + BUILD_BUG_ON(sizeof(struct lp) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_lp); +} + +static void __exit tcp_lp_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_lp); +} + +module_init(tcp_lp_register); +module_exit(tcp_lp_unregister); + +MODULE_AUTHOR("Wong Hoi Sing Edison, Hung Hing Lun Mike"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Low Priority"); diff --git a/net/ipv4/tcp_metrics.c b/net/ipv4/tcp_metrics.c new file mode 100644 index 000000000..a7364ff8b --- /dev/null +++ b/net/ipv4/tcp_metrics.c @@ -0,0 +1,1055 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *saddr, + const struct inetpeer_addr *daddr, + struct net *net, unsigned int hash); + +struct tcp_fastopen_metrics { + u16 mss; + u16 syn_loss:10, /* Recurring Fast Open SYN losses */ + try_exp:2; /* Request w/ exp. 
option (once) */ + unsigned long last_syn_loss; /* Last Fast Open SYN loss */ + struct tcp_fastopen_cookie cookie; +}; + +/* TCP_METRIC_MAX includes 2 extra fields for userspace compatibility + * Kernel only stores RTT and RTTVAR in usec resolution + */ +#define TCP_METRIC_MAX_KERNEL (TCP_METRIC_MAX - 2) + +struct tcp_metrics_block { + struct tcp_metrics_block __rcu *tcpm_next; + struct net *tcpm_net; + struct inetpeer_addr tcpm_saddr; + struct inetpeer_addr tcpm_daddr; + unsigned long tcpm_stamp; + u32 tcpm_lock; + u32 tcpm_vals[TCP_METRIC_MAX_KERNEL + 1]; + struct tcp_fastopen_metrics tcpm_fastopen; + + struct rcu_head rcu_head; +}; + +static inline struct net *tm_net(const struct tcp_metrics_block *tm) +{ + /* Paired with the WRITE_ONCE() in tcpm_new() */ + return READ_ONCE(tm->tcpm_net); +} + +static bool tcp_metric_locked(struct tcp_metrics_block *tm, + enum tcp_metric_index idx) +{ + /* Paired with WRITE_ONCE() in tcpm_suck_dst() */ + return READ_ONCE(tm->tcpm_lock) & (1 << idx); +} + +static u32 tcp_metric_get(const struct tcp_metrics_block *tm, + enum tcp_metric_index idx) +{ + /* Paired with WRITE_ONCE() in tcp_metric_set() */ + return READ_ONCE(tm->tcpm_vals[idx]); +} + +static void tcp_metric_set(struct tcp_metrics_block *tm, + enum tcp_metric_index idx, + u32 val) +{ + /* Paired with READ_ONCE() in tcp_metric_get() */ + WRITE_ONCE(tm->tcpm_vals[idx], val); +} + +static bool addr_same(const struct inetpeer_addr *a, + const struct inetpeer_addr *b) +{ + return (a->family == b->family) && !inetpeer_addr_cmp(a, b); +} + +struct tcpm_hash_bucket { + struct tcp_metrics_block __rcu *chain; +}; + +static struct tcpm_hash_bucket *tcp_metrics_hash __read_mostly; +static unsigned int tcp_metrics_hash_log __read_mostly; + +static DEFINE_SPINLOCK(tcp_metrics_lock); +static DEFINE_SEQLOCK(fastopen_seqlock); + +static void tcpm_suck_dst(struct tcp_metrics_block *tm, + const struct dst_entry *dst, + bool fastopen_clear) +{ + u32 msval; + u32 val; + + WRITE_ONCE(tm->tcpm_stamp, jiffies); + + val = 0; + if (dst_metric_locked(dst, RTAX_RTT)) + val |= 1 << TCP_METRIC_RTT; + if (dst_metric_locked(dst, RTAX_RTTVAR)) + val |= 1 << TCP_METRIC_RTTVAR; + if (dst_metric_locked(dst, RTAX_SSTHRESH)) + val |= 1 << TCP_METRIC_SSTHRESH; + if (dst_metric_locked(dst, RTAX_CWND)) + val |= 1 << TCP_METRIC_CWND; + if (dst_metric_locked(dst, RTAX_REORDERING)) + val |= 1 << TCP_METRIC_REORDERING; + /* Paired with READ_ONCE() in tcp_metric_locked() */ + WRITE_ONCE(tm->tcpm_lock, val); + + msval = dst_metric_raw(dst, RTAX_RTT); + tcp_metric_set(tm, TCP_METRIC_RTT, msval * USEC_PER_MSEC); + + msval = dst_metric_raw(dst, RTAX_RTTVAR); + tcp_metric_set(tm, TCP_METRIC_RTTVAR, msval * USEC_PER_MSEC); + tcp_metric_set(tm, TCP_METRIC_SSTHRESH, + dst_metric_raw(dst, RTAX_SSTHRESH)); + tcp_metric_set(tm, TCP_METRIC_CWND, + dst_metric_raw(dst, RTAX_CWND)); + tcp_metric_set(tm, TCP_METRIC_REORDERING, + dst_metric_raw(dst, RTAX_REORDERING)); + if (fastopen_clear) { + write_seqlock(&fastopen_seqlock); + tm->tcpm_fastopen.mss = 0; + tm->tcpm_fastopen.syn_loss = 0; + tm->tcpm_fastopen.try_exp = 0; + tm->tcpm_fastopen.cookie.exp = false; + tm->tcpm_fastopen.cookie.len = 0; + write_sequnlock(&fastopen_seqlock); + } +} + +#define TCP_METRICS_TIMEOUT (60 * 60 * HZ) + +static void tcpm_check_stamp(struct tcp_metrics_block *tm, + const struct dst_entry *dst) +{ + unsigned long limit; + + if (!tm) + return; + limit = READ_ONCE(tm->tcpm_stamp) + TCP_METRICS_TIMEOUT; + if (unlikely(time_after(jiffies, limit))) + tcpm_suck_dst(tm, dst, 
false); +} + +#define TCP_METRICS_RECLAIM_DEPTH 5 +#define TCP_METRICS_RECLAIM_PTR (struct tcp_metrics_block *) 0x1UL + +#define deref_locked(p) \ + rcu_dereference_protected(p, lockdep_is_held(&tcp_metrics_lock)) + +static struct tcp_metrics_block *tcpm_new(struct dst_entry *dst, + struct inetpeer_addr *saddr, + struct inetpeer_addr *daddr, + unsigned int hash) +{ + struct tcp_metrics_block *tm; + struct net *net; + bool reclaim = false; + + spin_lock_bh(&tcp_metrics_lock); + net = dev_net(dst->dev); + + /* While waiting for the spin-lock the cache might have been populated + * with this entry and so we have to check again. + */ + tm = __tcp_get_metrics(saddr, daddr, net, hash); + if (tm == TCP_METRICS_RECLAIM_PTR) { + reclaim = true; + tm = NULL; + } + if (tm) { + tcpm_check_stamp(tm, dst); + goto out_unlock; + } + + if (unlikely(reclaim)) { + struct tcp_metrics_block *oldest; + + oldest = deref_locked(tcp_metrics_hash[hash].chain); + for (tm = deref_locked(oldest->tcpm_next); tm; + tm = deref_locked(tm->tcpm_next)) { + if (time_before(READ_ONCE(tm->tcpm_stamp), + READ_ONCE(oldest->tcpm_stamp))) + oldest = tm; + } + tm = oldest; + } else { + tm = kzalloc(sizeof(*tm), GFP_ATOMIC); + if (!tm) + goto out_unlock; + } + /* Paired with the READ_ONCE() in tm_net() */ + WRITE_ONCE(tm->tcpm_net, net); + + tm->tcpm_saddr = *saddr; + tm->tcpm_daddr = *daddr; + + tcpm_suck_dst(tm, dst, reclaim); + + if (likely(!reclaim)) { + tm->tcpm_next = tcp_metrics_hash[hash].chain; + rcu_assign_pointer(tcp_metrics_hash[hash].chain, tm); + } + +out_unlock: + spin_unlock_bh(&tcp_metrics_lock); + return tm; +} + +static struct tcp_metrics_block *tcp_get_encode(struct tcp_metrics_block *tm, int depth) +{ + if (tm) + return tm; + if (depth > TCP_METRICS_RECLAIM_DEPTH) + return TCP_METRICS_RECLAIM_PTR; + return NULL; +} + +static struct tcp_metrics_block *__tcp_get_metrics(const struct inetpeer_addr *saddr, + const struct inetpeer_addr *daddr, + struct net *net, unsigned int hash) +{ + struct tcp_metrics_block *tm; + int depth = 0; + + for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm; + tm = rcu_dereference(tm->tcpm_next)) { + if (addr_same(&tm->tcpm_saddr, saddr) && + addr_same(&tm->tcpm_daddr, daddr) && + net_eq(tm_net(tm), net)) + break; + depth++; + } + return tcp_get_encode(tm, depth); +} + +static struct tcp_metrics_block *__tcp_get_metrics_req(struct request_sock *req, + struct dst_entry *dst) +{ + struct tcp_metrics_block *tm; + struct inetpeer_addr saddr, daddr; + unsigned int hash; + struct net *net; + + saddr.family = req->rsk_ops->family; + daddr.family = req->rsk_ops->family; + switch (daddr.family) { + case AF_INET: + inetpeer_set_addr_v4(&saddr, inet_rsk(req)->ir_loc_addr); + inetpeer_set_addr_v4(&daddr, inet_rsk(req)->ir_rmt_addr); + hash = ipv4_addr_hash(inet_rsk(req)->ir_rmt_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + inetpeer_set_addr_v6(&saddr, &inet_rsk(req)->ir_v6_loc_addr); + inetpeer_set_addr_v6(&daddr, &inet_rsk(req)->ir_v6_rmt_addr); + hash = ipv6_addr_hash(&inet_rsk(req)->ir_v6_rmt_addr); + break; +#endif + default: + return NULL; + } + + net = dev_net(dst->dev); + hash ^= net_hash_mix(net); + hash = hash_32(hash, tcp_metrics_hash_log); + + for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm; + tm = rcu_dereference(tm->tcpm_next)) { + if (addr_same(&tm->tcpm_saddr, &saddr) && + addr_same(&tm->tcpm_daddr, &daddr) && + net_eq(tm_net(tm), net)) + break; + } + tcpm_check_stamp(tm, dst); + return tm; +} + +static struct tcp_metrics_block 
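[Illustrative aside, not part of the patch] __tcp_get_metrics() above encodes a lookup miss two ways: NULL while the chain is still short (just allocate), or TCP_METRICS_RECLAIM_PTR once more than TCP_METRICS_RECLAIM_DEPTH entries were walked, which tells tcpm_new() to recycle the oldest block instead of growing the bucket. A sketch of that encoding with stand-in types:

#include <stdio.h>

#define RECLAIM_DEPTH	5
#define RECLAIM_PTR	((void *)0x1UL)

static void *encode_miss(void *found, int depth)
{
	if (found)
		return found;
	return depth > RECLAIM_DEPTH ? RECLAIM_PTR : NULL;
}

int main(void)
{
	printf("%p\n", encode_miss(NULL, 3));	/* NULL: allocate a new block */
	printf("%p\n", encode_miss(NULL, 7));	/* 0x1:  recycle the oldest */
	return 0;
}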
*tcp_get_metrics(struct sock *sk, + struct dst_entry *dst, + bool create) +{ + struct tcp_metrics_block *tm; + struct inetpeer_addr saddr, daddr; + unsigned int hash; + struct net *net; + + if (sk->sk_family == AF_INET) { + inetpeer_set_addr_v4(&saddr, inet_sk(sk)->inet_saddr); + inetpeer_set_addr_v4(&daddr, inet_sk(sk)->inet_daddr); + hash = ipv4_addr_hash(inet_sk(sk)->inet_daddr); + } +#if IS_ENABLED(CONFIG_IPV6) + else if (sk->sk_family == AF_INET6) { + if (ipv6_addr_v4mapped(&sk->sk_v6_daddr)) { + inetpeer_set_addr_v4(&saddr, inet_sk(sk)->inet_saddr); + inetpeer_set_addr_v4(&daddr, inet_sk(sk)->inet_daddr); + hash = ipv4_addr_hash(inet_sk(sk)->inet_daddr); + } else { + inetpeer_set_addr_v6(&saddr, &sk->sk_v6_rcv_saddr); + inetpeer_set_addr_v6(&daddr, &sk->sk_v6_daddr); + hash = ipv6_addr_hash(&sk->sk_v6_daddr); + } + } +#endif + else + return NULL; + + net = dev_net(dst->dev); + hash ^= net_hash_mix(net); + hash = hash_32(hash, tcp_metrics_hash_log); + + tm = __tcp_get_metrics(&saddr, &daddr, net, hash); + if (tm == TCP_METRICS_RECLAIM_PTR) + tm = NULL; + if (!tm && create) + tm = tcpm_new(dst, &saddr, &daddr, hash); + else + tcpm_check_stamp(tm, dst); + + return tm; +} + +/* Save metrics learned by this TCP session. This function is called + * only, when TCP finishes successfully i.e. when it enters TIME-WAIT + * or goes from LAST-ACK to CLOSE. + */ +void tcp_update_metrics(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct dst_entry *dst = __sk_dst_get(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct tcp_metrics_block *tm; + unsigned long rtt; + u32 val; + int m; + + sk_dst_confirm(sk); + if (READ_ONCE(net->ipv4.sysctl_tcp_nometrics_save) || !dst) + return; + + rcu_read_lock(); + if (icsk->icsk_backoff || !tp->srtt_us) { + /* This session failed to estimate rtt. Why? + * Probably, no packets returned in time. Reset our + * results. + */ + tm = tcp_get_metrics(sk, dst, false); + if (tm && !tcp_metric_locked(tm, TCP_METRIC_RTT)) + tcp_metric_set(tm, TCP_METRIC_RTT, 0); + goto out_unlock; + } else + tm = tcp_get_metrics(sk, dst, true); + + if (!tm) + goto out_unlock; + + rtt = tcp_metric_get(tm, TCP_METRIC_RTT); + m = rtt - tp->srtt_us; + + /* If newly calculated rtt larger than stored one, store new + * one. Otherwise, use EWMA. Remember, rtt overestimation is + * always better than underestimation. + */ + if (!tcp_metric_locked(tm, TCP_METRIC_RTT)) { + if (m <= 0) + rtt = tp->srtt_us; + else + rtt -= (m >> 3); + tcp_metric_set(tm, TCP_METRIC_RTT, rtt); + } + + if (!tcp_metric_locked(tm, TCP_METRIC_RTTVAR)) { + unsigned long var; + + if (m < 0) + m = -m; + + /* Scale deviation to rttvar fixed point */ + m >>= 1; + if (m < tp->mdev_us) + m = tp->mdev_us; + + var = tcp_metric_get(tm, TCP_METRIC_RTTVAR); + if (m >= var) + var = m; + else + var -= (var - m) >> 2; + + tcp_metric_set(tm, TCP_METRIC_RTTVAR, var); + } + + if (tcp_in_initial_slowstart(tp)) { + /* Slow start still did not finish. 
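[Illustrative aside, not part of the patch] The cached-RTT update in tcp_update_metrics() above deliberately favours overestimation: a larger measured srtt replaces the cached value outright, while a smaller one only drags the cache down by 1/8 of the difference. A small sketch with sample values in usec:

#include <stdio.h>
#include <stdint.h>

static uint32_t update_cached_rtt(uint32_t cached_us, uint32_t srtt_us)
{
	int32_t m = (int32_t)(cached_us - srtt_us);

	if (m <= 0)
		return srtt_us;			/* new estimate larger: take it */
	return cached_us - (m >> 3);		/* smaller: decay by 1/8 of the gap */
}

int main(void)
{
	printf("%u\n", update_cached_rtt(100000, 120000));	/* 120000 */
	printf("%u\n", update_cached_rtt(100000,  60000));	/*  95000 */
	return 0;
}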
*/ + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && + !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) { + val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH); + if (val && (tcp_snd_cwnd(tp) >> 1) > val) + tcp_metric_set(tm, TCP_METRIC_SSTHRESH, + tcp_snd_cwnd(tp) >> 1); + } + if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) { + val = tcp_metric_get(tm, TCP_METRIC_CWND); + if (tcp_snd_cwnd(tp) > val) + tcp_metric_set(tm, TCP_METRIC_CWND, + tcp_snd_cwnd(tp)); + } + } else if (!tcp_in_slow_start(tp) && + icsk->icsk_ca_state == TCP_CA_Open) { + /* Cong. avoidance phase, cwnd is reliable. */ + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && + !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) + tcp_metric_set(tm, TCP_METRIC_SSTHRESH, + max(tcp_snd_cwnd(tp) >> 1, tp->snd_ssthresh)); + if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) { + val = tcp_metric_get(tm, TCP_METRIC_CWND); + tcp_metric_set(tm, TCP_METRIC_CWND, (val + tcp_snd_cwnd(tp)) >> 1); + } + } else { + /* Else slow start did not finish, cwnd is non-sense, + * ssthresh may be also invalid. + */ + if (!tcp_metric_locked(tm, TCP_METRIC_CWND)) { + val = tcp_metric_get(tm, TCP_METRIC_CWND); + tcp_metric_set(tm, TCP_METRIC_CWND, + (val + tp->snd_ssthresh) >> 1); + } + if (!READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) && + !tcp_metric_locked(tm, TCP_METRIC_SSTHRESH)) { + val = tcp_metric_get(tm, TCP_METRIC_SSTHRESH); + if (val && tp->snd_ssthresh > val) + tcp_metric_set(tm, TCP_METRIC_SSTHRESH, + tp->snd_ssthresh); + } + if (!tcp_metric_locked(tm, TCP_METRIC_REORDERING)) { + val = tcp_metric_get(tm, TCP_METRIC_REORDERING); + if (val < tp->reordering && + tp->reordering != + READ_ONCE(net->ipv4.sysctl_tcp_reordering)) + tcp_metric_set(tm, TCP_METRIC_REORDERING, + tp->reordering); + } + } + WRITE_ONCE(tm->tcpm_stamp, jiffies); +out_unlock: + rcu_read_unlock(); +} + +/* Initialize metrics on socket. */ + +void tcp_init_metrics(struct sock *sk) +{ + struct dst_entry *dst = __sk_dst_get(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct tcp_metrics_block *tm; + u32 val, crtt = 0; /* cached RTT scaled by 8 */ + + sk_dst_confirm(sk); + /* ssthresh may have been reduced unnecessarily during. + * 3WHS. Restore it back to its initial default. + */ + tp->snd_ssthresh = TCP_INFINITE_SSTHRESH; + if (!dst) + goto reset; + + rcu_read_lock(); + tm = tcp_get_metrics(sk, dst, false); + if (!tm) { + rcu_read_unlock(); + goto reset; + } + + if (tcp_metric_locked(tm, TCP_METRIC_CWND)) + tp->snd_cwnd_clamp = tcp_metric_get(tm, TCP_METRIC_CWND); + + val = READ_ONCE(net->ipv4.sysctl_tcp_no_ssthresh_metrics_save) ? + 0 : tcp_metric_get(tm, TCP_METRIC_SSTHRESH); + if (val) { + tp->snd_ssthresh = val; + if (tp->snd_ssthresh > tp->snd_cwnd_clamp) + tp->snd_ssthresh = tp->snd_cwnd_clamp; + } + val = tcp_metric_get(tm, TCP_METRIC_REORDERING); + if (val && tp->reordering != val) + tp->reordering = val; + + crtt = tcp_metric_get(tm, TCP_METRIC_RTT); + rcu_read_unlock(); +reset: + /* The initial RTT measurement from the SYN/SYN-ACK is not ideal + * to seed the RTO for later data packets because SYN packets are + * small. Use the per-dst cached values to seed the RTO but keep + * the RTT estimator variables intact (e.g., srtt, mdev, rttvar). + * Later the RTO will be updated immediately upon obtaining the first + * data RTT sample (tcp_rtt_estimator()). Hence the cached RTT only + * influences the first RTO but not later RTT estimation. 
+ * + * But if RTT is not available from the SYN (due to retransmits or + * syn cookies) or the cache, force a conservative 3secs timeout. + * + * A bit of theory. RTT is time passed after "normal" sized packet + * is sent until it is ACKed. In normal circumstances sending small + * packets force peer to delay ACKs and calculation is correct too. + * The algorithm is adaptive and, provided we follow specs, it + * NEVER underestimate RTT. BUT! If peer tries to make some clever + * tricks sort of "quick acks" for time long enough to decrease RTT + * to low value, and then abruptly stops to do it and starts to delay + * ACKs, wait for troubles. + */ + if (crtt > tp->srtt_us) { + /* Set RTO like tcp_rtt_estimator(), but from cached RTT. */ + crtt /= 8 * USEC_PER_SEC / HZ; + inet_csk(sk)->icsk_rto = crtt + max(2 * crtt, tcp_rto_min(sk)); + } else if (tp->srtt_us == 0) { + /* RFC6298: 5.7 We've failed to get a valid RTT sample from + * 3WHS. This is most likely due to retransmission, + * including spurious one. Reset the RTO back to 3secs + * from the more aggressive 1sec to avoid more spurious + * retransmission. + */ + tp->rttvar_us = jiffies_to_usecs(TCP_TIMEOUT_FALLBACK); + tp->mdev_us = tp->mdev_max_us = tp->rttvar_us; + + inet_csk(sk)->icsk_rto = TCP_TIMEOUT_FALLBACK; + } +} + +bool tcp_peer_is_proven(struct request_sock *req, struct dst_entry *dst) +{ + struct tcp_metrics_block *tm; + bool ret; + + if (!dst) + return false; + + rcu_read_lock(); + tm = __tcp_get_metrics_req(req, dst); + if (tm && tcp_metric_get(tm, TCP_METRIC_RTT)) + ret = true; + else + ret = false; + rcu_read_unlock(); + + return ret; +} + +void tcp_fastopen_cache_get(struct sock *sk, u16 *mss, + struct tcp_fastopen_cookie *cookie) +{ + struct tcp_metrics_block *tm; + + rcu_read_lock(); + tm = tcp_get_metrics(sk, __sk_dst_get(sk), false); + if (tm) { + struct tcp_fastopen_metrics *tfom = &tm->tcpm_fastopen; + unsigned int seq; + + do { + seq = read_seqbegin(&fastopen_seqlock); + if (tfom->mss) + *mss = tfom->mss; + *cookie = tfom->cookie; + if (cookie->len <= 0 && tfom->try_exp == 1) + cookie->exp = true; + } while (read_seqretry(&fastopen_seqlock, seq)); + } + rcu_read_unlock(); +} + +void tcp_fastopen_cache_set(struct sock *sk, u16 mss, + struct tcp_fastopen_cookie *cookie, bool syn_lost, + u16 try_exp) +{ + struct dst_entry *dst = __sk_dst_get(sk); + struct tcp_metrics_block *tm; + + if (!dst) + return; + rcu_read_lock(); + tm = tcp_get_metrics(sk, dst, true); + if (tm) { + struct tcp_fastopen_metrics *tfom = &tm->tcpm_fastopen; + + write_seqlock_bh(&fastopen_seqlock); + if (mss) + tfom->mss = mss; + if (cookie && cookie->len > 0) + tfom->cookie = *cookie; + else if (try_exp > tfom->try_exp && + tfom->cookie.len <= 0 && !tfom->cookie.exp) + tfom->try_exp = try_exp; + if (syn_lost) { + ++tfom->syn_loss; + tfom->last_syn_loss = jiffies; + } else + tfom->syn_loss = 0; + write_sequnlock_bh(&fastopen_seqlock); + } + rcu_read_unlock(); +} + +static struct genl_family tcp_metrics_nl_family; + +static const struct nla_policy tcp_metrics_nl_policy[TCP_METRICS_ATTR_MAX + 1] = { + [TCP_METRICS_ATTR_ADDR_IPV4] = { .type = NLA_U32, }, + [TCP_METRICS_ATTR_ADDR_IPV6] = { .type = NLA_BINARY, + .len = sizeof(struct in6_addr), }, + /* Following attributes are not received for GET/DEL, + * we keep them for reference + */ +#if 0 + [TCP_METRICS_ATTR_AGE] = { .type = NLA_MSECS, }, + [TCP_METRICS_ATTR_TW_TSVAL] = { .type = NLA_U32, }, + [TCP_METRICS_ATTR_TW_TS_STAMP] = { .type = NLA_S32, }, + [TCP_METRICS_ATTR_VALS] = { .type = NLA_NESTED, 
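[Illustrative aside, not part of the patch] The reset path in tcp_init_metrics() above seeds the first RTO from the cached, 8-scaled RTT: converted to jiffies it becomes rtt + max(2*rtt, rto_min). The sketch below assumes HZ=1000 and the default TCP_RTO_MIN of HZ/5 (a per-route rto_min can differ):

#include <stdio.h>
#include <stdint.h>

#define HZ		1000
#define USEC_PER_SEC	1000000UL
#define RTO_MIN		(HZ / 5)		/* 200 ms, default TCP_RTO_MIN */

static uint32_t seed_rto(uint32_t crtt)		/* cached srtt << 3, in usec */
{
	uint32_t rtt = crtt / (8 * USEC_PER_SEC / HZ);	/* -> jiffies */
	uint32_t var = 2 * rtt > RTO_MIN ? 2 * rtt : RTO_MIN;

	return rtt + var;
}

int main(void)
{
	/* cached srtt of 50 ms (scaled by 8, in usec) -> 250 ms initial RTO */
	printf("%u jiffies\n", seed_rto(50 * 8 * 1000));
	return 0;
}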
}, + [TCP_METRICS_ATTR_FOPEN_MSS] = { .type = NLA_U16, }, + [TCP_METRICS_ATTR_FOPEN_SYN_DROPS] = { .type = NLA_U16, }, + [TCP_METRICS_ATTR_FOPEN_SYN_DROP_TS] = { .type = NLA_MSECS, }, + [TCP_METRICS_ATTR_FOPEN_COOKIE] = { .type = NLA_BINARY, + .len = TCP_FASTOPEN_COOKIE_MAX, }, +#endif +}; + +/* Add attributes, caller cancels its header on failure */ +static int tcp_metrics_fill_info(struct sk_buff *msg, + struct tcp_metrics_block *tm) +{ + struct nlattr *nest; + int i; + + switch (tm->tcpm_daddr.family) { + case AF_INET: + if (nla_put_in_addr(msg, TCP_METRICS_ATTR_ADDR_IPV4, + inetpeer_get_addr_v4(&tm->tcpm_daddr)) < 0) + goto nla_put_failure; + if (nla_put_in_addr(msg, TCP_METRICS_ATTR_SADDR_IPV4, + inetpeer_get_addr_v4(&tm->tcpm_saddr)) < 0) + goto nla_put_failure; + break; + case AF_INET6: + if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_ADDR_IPV6, + inetpeer_get_addr_v6(&tm->tcpm_daddr)) < 0) + goto nla_put_failure; + if (nla_put_in6_addr(msg, TCP_METRICS_ATTR_SADDR_IPV6, + inetpeer_get_addr_v6(&tm->tcpm_saddr)) < 0) + goto nla_put_failure; + break; + default: + return -EAFNOSUPPORT; + } + + if (nla_put_msecs(msg, TCP_METRICS_ATTR_AGE, + jiffies - READ_ONCE(tm->tcpm_stamp), + TCP_METRICS_ATTR_PAD) < 0) + goto nla_put_failure; + + { + int n = 0; + + nest = nla_nest_start_noflag(msg, TCP_METRICS_ATTR_VALS); + if (!nest) + goto nla_put_failure; + for (i = 0; i < TCP_METRIC_MAX_KERNEL + 1; i++) { + u32 val = tcp_metric_get(tm, i); + + if (!val) + continue; + if (i == TCP_METRIC_RTT) { + if (nla_put_u32(msg, TCP_METRIC_RTT_US + 1, + val) < 0) + goto nla_put_failure; + n++; + val = max(val / 1000, 1U); + } + if (i == TCP_METRIC_RTTVAR) { + if (nla_put_u32(msg, TCP_METRIC_RTTVAR_US + 1, + val) < 0) + goto nla_put_failure; + n++; + val = max(val / 1000, 1U); + } + if (nla_put_u32(msg, i + 1, val) < 0) + goto nla_put_failure; + n++; + } + if (n) + nla_nest_end(msg, nest); + else + nla_nest_cancel(msg, nest); + } + + { + struct tcp_fastopen_metrics tfom_copy[1], *tfom; + unsigned int seq; + + do { + seq = read_seqbegin(&fastopen_seqlock); + tfom_copy[0] = tm->tcpm_fastopen; + } while (read_seqretry(&fastopen_seqlock, seq)); + + tfom = tfom_copy; + if (tfom->mss && + nla_put_u16(msg, TCP_METRICS_ATTR_FOPEN_MSS, + tfom->mss) < 0) + goto nla_put_failure; + if (tfom->syn_loss && + (nla_put_u16(msg, TCP_METRICS_ATTR_FOPEN_SYN_DROPS, + tfom->syn_loss) < 0 || + nla_put_msecs(msg, TCP_METRICS_ATTR_FOPEN_SYN_DROP_TS, + jiffies - tfom->last_syn_loss, + TCP_METRICS_ATTR_PAD) < 0)) + goto nla_put_failure; + if (tfom->cookie.len > 0 && + nla_put(msg, TCP_METRICS_ATTR_FOPEN_COOKIE, + tfom->cookie.len, tfom->cookie.val) < 0) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int tcp_metrics_dump_info(struct sk_buff *skb, + struct netlink_callback *cb, + struct tcp_metrics_block *tm) +{ + void *hdr; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &tcp_metrics_nl_family, NLM_F_MULTI, + TCP_METRICS_CMD_GET); + if (!hdr) + return -EMSGSIZE; + + if (tcp_metrics_fill_info(skb, tm) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int tcp_metrics_nl_dump(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + unsigned int max_rows = 1U << tcp_metrics_hash_log; + unsigned int row, s_row = cb->args[0]; + int s_col = cb->args[1], col = s_col; + + for (row = s_row; row < max_rows; row++, s_col = 0) { + struct 
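[Illustrative aside, not part of the patch] tcp_metrics_fill_info() above exports RTT and RTTVAR twice: once in usec under the *_US attributes and once under the legacy msec-valued attributes, clamped to at least 1 ms so a sub-millisecond RTT is not dumped as zero. A tiny sketch of that conversion:

#include <stdio.h>
#include <stdint.h>

static uint32_t legacy_ms(uint32_t us)
{
	uint32_t ms = us / 1000;

	return ms ? ms : 1u;	/* max(val / 1000, 1U) */
}

int main(void)
{
	printf("%u\n", legacy_ms(250000));	/* 250 ms */
	printf("%u\n", legacy_ms(400));		/*   1 ms, never 0 */
	return 0;
}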
tcp_metrics_block *tm; + struct tcpm_hash_bucket *hb = tcp_metrics_hash + row; + + rcu_read_lock(); + for (col = 0, tm = rcu_dereference(hb->chain); tm; + tm = rcu_dereference(tm->tcpm_next), col++) { + if (!net_eq(tm_net(tm), net)) + continue; + if (col < s_col) + continue; + if (tcp_metrics_dump_info(skb, cb, tm) < 0) { + rcu_read_unlock(); + goto done; + } + } + rcu_read_unlock(); + } + +done: + cb->args[0] = row; + cb->args[1] = col; + return skb->len; +} + +static int __parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr, + unsigned int *hash, int optional, int v4, int v6) +{ + struct nlattr *a; + + a = info->attrs[v4]; + if (a) { + inetpeer_set_addr_v4(addr, nla_get_in_addr(a)); + if (hash) + *hash = ipv4_addr_hash(inetpeer_get_addr_v4(addr)); + return 0; + } + a = info->attrs[v6]; + if (a) { + struct in6_addr in6; + + if (nla_len(a) != sizeof(struct in6_addr)) + return -EINVAL; + in6 = nla_get_in6_addr(a); + inetpeer_set_addr_v6(addr, &in6); + if (hash) + *hash = ipv6_addr_hash(inetpeer_get_addr_v6(addr)); + return 0; + } + return optional ? 1 : -EAFNOSUPPORT; +} + +static int parse_nl_addr(struct genl_info *info, struct inetpeer_addr *addr, + unsigned int *hash, int optional) +{ + return __parse_nl_addr(info, addr, hash, optional, + TCP_METRICS_ATTR_ADDR_IPV4, + TCP_METRICS_ATTR_ADDR_IPV6); +} + +static int parse_nl_saddr(struct genl_info *info, struct inetpeer_addr *addr) +{ + return __parse_nl_addr(info, addr, NULL, 0, + TCP_METRICS_ATTR_SADDR_IPV4, + TCP_METRICS_ATTR_SADDR_IPV6); +} + +static int tcp_metrics_nl_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct tcp_metrics_block *tm; + struct inetpeer_addr saddr, daddr; + unsigned int hash; + struct sk_buff *msg; + struct net *net = genl_info_net(info); + void *reply; + int ret; + bool src = true; + + ret = parse_nl_addr(info, &daddr, &hash, 0); + if (ret < 0) + return ret; + + ret = parse_nl_saddr(info, &saddr); + if (ret < 0) + src = false; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + reply = genlmsg_put_reply(msg, info, &tcp_metrics_nl_family, 0, + info->genlhdr->cmd); + if (!reply) + goto nla_put_failure; + + hash ^= net_hash_mix(net); + hash = hash_32(hash, tcp_metrics_hash_log); + ret = -ESRCH; + rcu_read_lock(); + for (tm = rcu_dereference(tcp_metrics_hash[hash].chain); tm; + tm = rcu_dereference(tm->tcpm_next)) { + if (addr_same(&tm->tcpm_daddr, &daddr) && + (!src || addr_same(&tm->tcpm_saddr, &saddr)) && + net_eq(tm_net(tm), net)) { + ret = tcp_metrics_fill_info(msg, tm); + break; + } + } + rcu_read_unlock(); + if (ret < 0) + goto out_free; + + genlmsg_end(msg, reply); + return genlmsg_reply(msg, info); + +nla_put_failure: + ret = -EMSGSIZE; + +out_free: + nlmsg_free(msg); + return ret; +} + +static void tcp_metrics_flush_all(struct net *net) +{ + unsigned int max_rows = 1U << tcp_metrics_hash_log; + struct tcpm_hash_bucket *hb = tcp_metrics_hash; + struct tcp_metrics_block *tm; + unsigned int row; + + for (row = 0; row < max_rows; row++, hb++) { + struct tcp_metrics_block __rcu **pp; + bool match; + + spin_lock_bh(&tcp_metrics_lock); + pp = &hb->chain; + for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) { + match = net ? 
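[Illustrative aside, not part of the patch] The lookups above pick a bucket by hashing the destination address, mixing in a per-netns salt and folding to tcp_metrics_hash_log bits. The sketch below restates hash_32() with the kernel's 32-bit golden-ratio multiplier; the address hash and salt are stand-in values:

#include <stdio.h>
#include <stdint.h>

#define GOLDEN_RATIO_32	0x61C88647u

static uint32_t hash_32(uint32_t val, unsigned int bits)
{
	return (val * GOLDEN_RATIO_32) >> (32 - bits);
}

int main(void)
{
	uint32_t addr_hash = 0xc0a80001u;	/* stand-in for ipv4_addr_hash(daddr) */
	uint32_t net_salt  = 0x12345678u;	/* stand-in for net_hash_mix(net) */
	unsigned int log = 10;			/* 1024 buckets */

	printf("bucket %u of %u\n", hash_32(addr_hash ^ net_salt, log), 1u << log);
	return 0;
}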
net_eq(tm_net(tm), net) : + !refcount_read(&tm_net(tm)->ns.count); + if (match) { + rcu_assign_pointer(*pp, tm->tcpm_next); + kfree_rcu(tm, rcu_head); + } else { + pp = &tm->tcpm_next; + } + } + spin_unlock_bh(&tcp_metrics_lock); + } +} + +static int tcp_metrics_nl_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + struct tcpm_hash_bucket *hb; + struct tcp_metrics_block *tm; + struct tcp_metrics_block __rcu **pp; + struct inetpeer_addr saddr, daddr; + unsigned int hash; + struct net *net = genl_info_net(info); + int ret; + bool src = true, found = false; + + ret = parse_nl_addr(info, &daddr, &hash, 1); + if (ret < 0) + return ret; + if (ret > 0) { + tcp_metrics_flush_all(net); + return 0; + } + ret = parse_nl_saddr(info, &saddr); + if (ret < 0) + src = false; + + hash ^= net_hash_mix(net); + hash = hash_32(hash, tcp_metrics_hash_log); + hb = tcp_metrics_hash + hash; + pp = &hb->chain; + spin_lock_bh(&tcp_metrics_lock); + for (tm = deref_locked(*pp); tm; tm = deref_locked(*pp)) { + if (addr_same(&tm->tcpm_daddr, &daddr) && + (!src || addr_same(&tm->tcpm_saddr, &saddr)) && + net_eq(tm_net(tm), net)) { + rcu_assign_pointer(*pp, tm->tcpm_next); + kfree_rcu(tm, rcu_head); + found = true; + } else { + pp = &tm->tcpm_next; + } + } + spin_unlock_bh(&tcp_metrics_lock); + if (!found) + return -ESRCH; + return 0; +} + +static const struct genl_small_ops tcp_metrics_nl_ops[] = { + { + .cmd = TCP_METRICS_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tcp_metrics_nl_cmd_get, + .dumpit = tcp_metrics_nl_dump, + }, + { + .cmd = TCP_METRICS_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tcp_metrics_nl_cmd_del, + .flags = GENL_ADMIN_PERM, + }, +}; + +static struct genl_family tcp_metrics_nl_family __ro_after_init = { + .hdrsize = 0, + .name = TCP_METRICS_GENL_NAME, + .version = TCP_METRICS_GENL_VERSION, + .maxattr = TCP_METRICS_ATTR_MAX, + .policy = tcp_metrics_nl_policy, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = tcp_metrics_nl_ops, + .n_small_ops = ARRAY_SIZE(tcp_metrics_nl_ops), + .resv_start_op = TCP_METRICS_CMD_DEL + 1, +}; + +static unsigned int tcpmhash_entries; +static int __init set_tcpmhash_entries(char *str) +{ + ssize_t ret; + + if (!str) + return 0; + + ret = kstrtouint(str, 0, &tcpmhash_entries); + if (ret) + return 0; + + return 1; +} +__setup("tcpmhash_entries=", set_tcpmhash_entries); + +static int __net_init tcp_net_metrics_init(struct net *net) +{ + size_t size; + unsigned int slots; + + if (!net_eq(net, &init_net)) + return 0; + + slots = tcpmhash_entries; + if (!slots) { + if (totalram_pages() >= 128 * 1024) + slots = 16 * 1024; + else + slots = 8 * 1024; + } + + tcp_metrics_hash_log = order_base_2(slots); + size = sizeof(struct tcpm_hash_bucket) << tcp_metrics_hash_log; + + tcp_metrics_hash = kvzalloc(size, GFP_KERNEL); + if (!tcp_metrics_hash) + return -ENOMEM; + + return 0; +} + +static void __net_exit tcp_net_metrics_exit_batch(struct list_head *net_exit_list) +{ + tcp_metrics_flush_all(NULL); +} + +static __net_initdata struct pernet_operations tcp_net_metrics_ops = { + .init = tcp_net_metrics_init, + .exit_batch = tcp_net_metrics_exit_batch, +}; + +void __init tcp_metrics_init(void) +{ + int ret; + + ret = register_pernet_subsys(&tcp_net_metrics_ops); + if (ret < 0) + panic("Could not allocate the tcp_metrics hash table\n"); + + ret = genl_register_family(&tcp_metrics_nl_family); + if (ret < 0) + panic("Could not register tcp_metrics generic netlink\n"); +} diff --git 
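[Illustrative aside, not part of the patch] tcp_net_metrics_init() above sizes the table once, for init_net only: 16K slots on machines with at least 128K pages (512 MB with 4 KB pages), 8K otherwise, unless overridden by tcpmhash_entries=, and the slot count is rounded to a power of two whose log is stored in tcp_metrics_hash_log. A sketch of that sizing, with order_base_2() restated locally:

#include <stdio.h>

static unsigned int order_base_2(unsigned int n)	/* ceil(log2(n)) */
{
	unsigned int order = 0;

	while ((1u << order) < n)
		order++;
	return order;
}

int main(void)
{
	unsigned long totalram_pages = 256 * 1024;	/* e.g. 1 GB of 4 KB pages */
	unsigned int slots = totalram_pages >= 128 * 1024 ? 16 * 1024 : 8 * 1024;
	unsigned int log = order_base_2(slots);

	printf("hash_log=%u buckets=%u\n", log, 1u << log);	/* 14, 16384 */
	return 0;
}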
a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c new file mode 100644 index 000000000..42844d20d --- /dev/null +++ b/net/ipv4/tcp_minisocks.c @@ -0,0 +1,880 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Corey Minyard + * Florian La Roche, + * Charles Hedrick, + * Linus Torvalds, + * Alan Cox, + * Matthew Dillon, + * Arnt Gulbrandsen, + * Jorge Cwik, + */ + +#include +#include +#include + +static bool tcp_in_window(u32 seq, u32 end_seq, u32 s_win, u32 e_win) +{ + if (seq == s_win) + return true; + if (after(end_seq, s_win) && before(seq, e_win)) + return true; + return seq == e_win && seq == end_seq; +} + +static enum tcp_tw_status +tcp_timewait_check_oow_rate_limit(struct inet_timewait_sock *tw, + const struct sk_buff *skb, int mib_idx) +{ + struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); + + if (!tcp_oow_rate_limited(twsk_net(tw), skb, mib_idx, + &tcptw->tw_last_oow_ack_time)) { + /* Send ACK. Note, we do not put the bucket, + * it will be released by caller. + */ + return TCP_TW_ACK; + } + + /* We are rate-limiting, so just release the tw sock and drop skb. */ + inet_twsk_put(tw); + return TCP_TW_SUCCESS; +} + +/* + * * Main purpose of TIME-WAIT state is to close connection gracefully, + * when one of ends sits in LAST-ACK or CLOSING retransmitting FIN + * (and, probably, tail of data) and one or more our ACKs are lost. + * * What is TIME-WAIT timeout? It is associated with maximal packet + * lifetime in the internet, which results in wrong conclusion, that + * it is set to catch "old duplicate segments" wandering out of their path. + * It is not quite correct. This timeout is calculated so that it exceeds + * maximal retransmission timeout enough to allow to lose one (or more) + * segments sent by peer and our ACKs. This time may be calculated from RTO. + * * When TIME-WAIT socket receives RST, it means that another end + * finally closed and we are allowed to kill TIME-WAIT too. + * * Second purpose of TIME-WAIT is catching old duplicate segments. + * Well, certainly it is pure paranoia, but if we load TIME-WAIT + * with this semantics, we MUST NOT kill TIME-WAIT state with RSTs. + * * If we invented some more clever way to catch duplicates + * (f.e. based on PAWS), we could truncate TIME-WAIT to several RTOs. + * + * The algorithm below is based on FORMAL INTERPRETATION of RFCs. + * When you compare it to RFCs, please, read section SEGMENT ARRIVES + * from the very beginning. + * + * NOTE. With recycling (and later with fin-wait-2) TW bucket + * is _not_ stateless. It means, that strictly speaking we must + * spinlock it. I do not want! Well, probability of misbehaviour + * is ridiculously low and, seems, we could use some mb() tricks + * to avoid misread sequence numbers, states etc. 
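[Illustrative aside, not part of the patch] tcp_in_window() above relies on the usual mod-2^32 sequence arithmetic, so the test also holds across sequence-number wraparound. A standalone sketch mirroring it, with before()/after() restated locally:

#include <stdio.h>
#include <stdint.h>

static int before(uint32_t a, uint32_t b) { return (int32_t)(a - b) < 0; }
static int after(uint32_t a, uint32_t b)  { return before(b, a); }

static int in_window(uint32_t seq, uint32_t end_seq, uint32_t s_win, uint32_t e_win)
{
	if (seq == s_win)
		return 1;
	if (after(end_seq, s_win) && before(seq, e_win))
		return 1;
	return seq == e_win && seq == end_seq;
}

int main(void)
{
	/* segment [4294967200, 100) overlaps window [4294967290, 500) across the wrap */
	printf("%d\n", in_window(4294967200u, 100u, 4294967290u, 500u));	/* 1 */
	printf("%d\n", in_window(600u, 700u, 4294967290u, 500u));		/* 0 */
	return 0;
}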
--ANK + * + * We don't need to initialize tmp_out.sack_ok as we don't use the results + */ +enum tcp_tw_status +tcp_timewait_state_process(struct inet_timewait_sock *tw, struct sk_buff *skb, + const struct tcphdr *th) +{ + struct tcp_options_received tmp_opt; + struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); + bool paws_reject = false; + + tmp_opt.saw_tstamp = 0; + if (th->doff > (sizeof(*th) >> 2) && tcptw->tw_ts_recent_stamp) { + tcp_parse_options(twsk_net(tw), skb, &tmp_opt, 0, NULL); + + if (tmp_opt.saw_tstamp) { + if (tmp_opt.rcv_tsecr) + tmp_opt.rcv_tsecr -= tcptw->tw_ts_offset; + tmp_opt.ts_recent = tcptw->tw_ts_recent; + tmp_opt.ts_recent_stamp = tcptw->tw_ts_recent_stamp; + paws_reject = tcp_paws_reject(&tmp_opt, th->rst); + } + } + + if (tw->tw_substate == TCP_FIN_WAIT2) { + /* Just repeat all the checks of tcp_rcv_state_process() */ + + /* Out of window, send ACK */ + if (paws_reject || + !tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, + tcptw->tw_rcv_nxt, + tcptw->tw_rcv_nxt + tcptw->tw_rcv_wnd)) + return tcp_timewait_check_oow_rate_limit( + tw, skb, LINUX_MIB_TCPACKSKIPPEDFINWAIT2); + + if (th->rst) + goto kill; + + if (th->syn && !before(TCP_SKB_CB(skb)->seq, tcptw->tw_rcv_nxt)) + return TCP_TW_RST; + + /* Dup ACK? */ + if (!th->ack || + !after(TCP_SKB_CB(skb)->end_seq, tcptw->tw_rcv_nxt) || + TCP_SKB_CB(skb)->end_seq == TCP_SKB_CB(skb)->seq) { + inet_twsk_put(tw); + return TCP_TW_SUCCESS; + } + + /* New data or FIN. If new data arrive after half-duplex close, + * reset. + */ + if (!th->fin || + TCP_SKB_CB(skb)->end_seq != tcptw->tw_rcv_nxt + 1) + return TCP_TW_RST; + + /* FIN arrived, enter true time-wait state. */ + tw->tw_substate = TCP_TIME_WAIT; + tcptw->tw_rcv_nxt = TCP_SKB_CB(skb)->end_seq; + if (tmp_opt.saw_tstamp) { + tcptw->tw_ts_recent_stamp = ktime_get_seconds(); + tcptw->tw_ts_recent = tmp_opt.rcv_tsval; + } + + inet_twsk_reschedule(tw, TCP_TIMEWAIT_LEN); + return TCP_TW_ACK; + } + + /* + * Now real TIME-WAIT state. + * + * RFC 1122: + * "When a connection is [...] on TIME-WAIT state [...] + * [a TCP] MAY accept a new SYN from the remote TCP to + * reopen the connection directly, if it: + * + * (1) assigns its initial sequence number for the new + * connection to be larger than the largest sequence + * number it used on the previous connection incarnation, + * and + * + * (2) returns to TIME-WAIT state if the SYN turns out + * to be an old duplicate". + */ + + if (!paws_reject && + (TCP_SKB_CB(skb)->seq == tcptw->tw_rcv_nxt && + (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq || th->rst))) { + /* In window segment, it may be only reset or bare ack. */ + + if (th->rst) { + /* This is TIME_WAIT assassination, in two flavors. + * Oh well... nobody has a sufficient solution to this + * protocol bug yet. + */ + if (!READ_ONCE(twsk_net(tw)->ipv4.sysctl_tcp_rfc1337)) { +kill: + inet_twsk_deschedule_put(tw); + return TCP_TW_SUCCESS; + } + } else { + inet_twsk_reschedule(tw, TCP_TIMEWAIT_LEN); + } + + if (tmp_opt.saw_tstamp) { + tcptw->tw_ts_recent = tmp_opt.rcv_tsval; + tcptw->tw_ts_recent_stamp = ktime_get_seconds(); + } + + inet_twsk_put(tw); + return TCP_TW_SUCCESS; + } + + /* Out of window segment. + + All the segments are ACKed immediately. + + The only exception is new SYN. We accept it, if it is + not old duplicate and we are not in danger to be killed + by delayed old duplicates. RFC check is that it has + newer sequence number works at rates <40Mbit/sec. 
+ However, if paws works, it is reliable AND even more, + we even may relax silly seq space cutoff. + + RED-PEN: we violate main RFC requirement, if this SYN will appear + old duplicate (i.e. we receive RST in reply to SYN-ACK), + we must return socket to time-wait state. It is not good, + but not fatal yet. + */ + + if (th->syn && !th->rst && !th->ack && !paws_reject && + (after(TCP_SKB_CB(skb)->seq, tcptw->tw_rcv_nxt) || + (tmp_opt.saw_tstamp && + (s32)(tcptw->tw_ts_recent - tmp_opt.rcv_tsval) < 0))) { + u32 isn = tcptw->tw_snd_nxt + 65535 + 2; + if (isn == 0) + isn++; + TCP_SKB_CB(skb)->tcp_tw_isn = isn; + return TCP_TW_SYN; + } + + if (paws_reject) + __NET_INC_STATS(twsk_net(tw), LINUX_MIB_PAWSESTABREJECTED); + + if (!th->rst) { + /* In this case we must reset the TIMEWAIT timer. + * + * If it is ACKless SYN it may be both old duplicate + * and new good SYN with random sequence number ack) + inet_twsk_reschedule(tw, TCP_TIMEWAIT_LEN); + + return tcp_timewait_check_oow_rate_limit( + tw, skb, LINUX_MIB_TCPACKSKIPPEDTIMEWAIT); + } + inet_twsk_put(tw); + return TCP_TW_SUCCESS; +} +EXPORT_SYMBOL(tcp_timewait_state_process); + +/* + * Move a socket to time-wait or dead fin-wait-2 state. + */ +void tcp_time_wait(struct sock *sk, int state, int timeo) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct inet_timewait_sock *tw; + + tw = inet_twsk_alloc(sk, &net->ipv4.tcp_death_row, state); + + if (tw) { + struct tcp_timewait_sock *tcptw = tcp_twsk((struct sock *)tw); + const int rto = (icsk->icsk_rto << 2) - (icsk->icsk_rto >> 1); + struct inet_sock *inet = inet_sk(sk); + + tw->tw_transparent = inet->transparent; + tw->tw_mark = sk->sk_mark; + tw->tw_priority = sk->sk_priority; + tw->tw_rcv_wscale = tp->rx_opt.rcv_wscale; + tcptw->tw_rcv_nxt = tp->rcv_nxt; + tcptw->tw_snd_nxt = tp->snd_nxt; + tcptw->tw_rcv_wnd = tcp_receive_window(tp); + tcptw->tw_ts_recent = tp->rx_opt.ts_recent; + tcptw->tw_ts_recent_stamp = tp->rx_opt.ts_recent_stamp; + tcptw->tw_ts_offset = tp->tsoffset; + tcptw->tw_last_oow_ack_time = 0; + tcptw->tw_tx_delay = tp->tcp_tx_delay; +#if IS_ENABLED(CONFIG_IPV6) + if (tw->tw_family == PF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + + tw->tw_v6_daddr = sk->sk_v6_daddr; + tw->tw_v6_rcv_saddr = sk->sk_v6_rcv_saddr; + tw->tw_tclass = np->tclass; + tw->tw_flowlabel = be32_to_cpu(np->flow_label & IPV6_FLOWLABEL_MASK); + tw->tw_txhash = sk->sk_txhash; + tw->tw_ipv6only = sk->sk_ipv6only; + } +#endif + +#ifdef CONFIG_TCP_MD5SIG + /* + * The timewait bucket does not have the key DB from the + * sock structure. We just make a quick copy of the + * md5 key being used (if indeed we are using one) + * so the timewait ack generating code has the key. + */ + do { + tcptw->tw_md5_key = NULL; + if (static_branch_unlikely(&tcp_md5_needed)) { + struct tcp_md5sig_key *key; + + key = tp->af_specific->md5_lookup(sk, sk); + if (key) { + tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); + BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); + } + } + } while (0); +#endif + + /* Get the TIME_WAIT timeout firing. */ + if (timeo < rto) + timeo = rto; + + if (state == TCP_TIME_WAIT) + timeo = TCP_TIMEWAIT_LEN; + + /* tw_timer is pinned, so we need to make sure BH are disabled + * in following section, otherwise timer handler could run before + * we complete the initialization. + */ + local_bh_disable(); + inet_twsk_schedule(tw, timeo); + /* Linkage updates. 
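[Illustrative aside, not part of the patch] When a new SYN is allowed to reuse a TIME-WAIT pair above, the suggested ISN is placed 65537 beyond the old incarnation's snd_nxt, so it is larger than anything the previous connection used, and 0 is bumped to 1 because a zero tcp_tw_isn reads as "no TIME-WAIT reuse" downstream. A sketch of that arithmetic:

#include <stdio.h>
#include <stdint.h>

static uint32_t tw_reuse_isn(uint32_t tw_snd_nxt)
{
	uint32_t isn = tw_snd_nxt + 65535 + 2;

	return isn ? isn : 1;
}

int main(void)
{
	printf("%u\n", tw_reuse_isn(1000));		/* 66537 */
	printf("%u\n", tw_reuse_isn(4294901759u));	/* wraps to 0, bumped to 1 */
	return 0;
}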
+ * Note that access to tw after this point is illegal. + */ + inet_twsk_hashdance(tw, sk, net->ipv4.tcp_death_row.hashinfo); + local_bh_enable(); + } else { + /* Sorry, if we're out of memory, just CLOSE this + * socket up. We've got bigger problems than + * non-graceful socket closings. + */ + NET_INC_STATS(net, LINUX_MIB_TCPTIMEWAITOVERFLOW); + } + + tcp_update_metrics(sk); + tcp_done(sk); +} +EXPORT_SYMBOL(tcp_time_wait); + +void tcp_twsk_destructor(struct sock *sk) +{ +#ifdef CONFIG_TCP_MD5SIG + if (static_branch_unlikely(&tcp_md5_needed)) { + struct tcp_timewait_sock *twsk = tcp_twsk(sk); + + if (twsk->tw_md5_key) + kfree_rcu(twsk->tw_md5_key, rcu); + } +#endif +} +EXPORT_SYMBOL_GPL(tcp_twsk_destructor); + +void tcp_twsk_purge(struct list_head *net_exit_list, int family) +{ + bool purged_once = false; + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) { + if (net->ipv4.tcp_death_row.hashinfo->pernet) { + /* Even if tw_refcount == 1, we must clean up kernel reqsk */ + inet_twsk_purge(net->ipv4.tcp_death_row.hashinfo, family); + } else if (!purged_once) { + /* The last refcount is decremented in tcp_sk_exit_batch() */ + if (refcount_read(&net->ipv4.tcp_death_row.tw_refcount) == 1) + continue; + + inet_twsk_purge(&tcp_hashinfo, family); + purged_once = true; + } + } +} +EXPORT_SYMBOL_GPL(tcp_twsk_purge); + +/* Warning : This function is called without sk_listener being locked. + * Be sure to read socket fields once, as their value could change under us. + */ +void tcp_openreq_init_rwin(struct request_sock *req, + const struct sock *sk_listener, + const struct dst_entry *dst) +{ + struct inet_request_sock *ireq = inet_rsk(req); + const struct tcp_sock *tp = tcp_sk(sk_listener); + int full_space = tcp_full_space(sk_listener); + u32 window_clamp; + __u8 rcv_wscale; + u32 rcv_wnd; + int mss; + + mss = tcp_mss_clamp(tp, dst_metric_advmss(dst)); + window_clamp = READ_ONCE(tp->window_clamp); + /* Set this up on the first call only */ + req->rsk_window_clamp = window_clamp ? : dst_metric(dst, RTAX_WINDOW); + + /* limit the window selection if the user enforce a smaller rx buffer */ + if (sk_listener->sk_userlocks & SOCK_RCVBUF_LOCK && + (req->rsk_window_clamp > full_space || req->rsk_window_clamp == 0)) + req->rsk_window_clamp = full_space; + + rcv_wnd = tcp_rwnd_init_bpf((struct sock *)req); + if (rcv_wnd == 0) + rcv_wnd = dst_metric(dst, RTAX_INITRWND); + else if (full_space < rcv_wnd * mss) + full_space = rcv_wnd * mss; + + /* tcp_full_space because it is guaranteed to be the first packet */ + tcp_select_initial_window(sk_listener, full_space, + mss - (ireq->tstamp_ok ? TCPOLEN_TSTAMP_ALIGNED : 0), + &req->rsk_rcv_wnd, + &req->rsk_window_clamp, + ireq->wscale_ok, + &rcv_wscale, + rcv_wnd); + ireq->rcv_wscale = rcv_wscale; +} +EXPORT_SYMBOL(tcp_openreq_init_rwin); + +static void tcp_ecn_openreq_child(struct tcp_sock *tp, + const struct request_sock *req) +{ + tp->ecn_flags = inet_rsk(req)->ecn_ok ? 
TCP_ECN_OK : 0; +} + +void tcp_ca_openreq_child(struct sock *sk, const struct dst_entry *dst) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + u32 ca_key = dst_metric(dst, RTAX_CC_ALGO); + bool ca_got_dst = false; + + if (ca_key != TCP_CA_UNSPEC) { + const struct tcp_congestion_ops *ca; + + rcu_read_lock(); + ca = tcp_ca_find_key(ca_key); + if (likely(ca && bpf_try_module_get(ca, ca->owner))) { + icsk->icsk_ca_dst_locked = tcp_ca_dst_locked(dst); + icsk->icsk_ca_ops = ca; + ca_got_dst = true; + } + rcu_read_unlock(); + } + + /* If no valid choice made yet, assign current system default ca. */ + if (!ca_got_dst && + (!icsk->icsk_ca_setsockopt || + !bpf_try_module_get(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner))) + tcp_assign_congestion_control(sk); + + tcp_set_ca_state(sk, TCP_CA_Open); +} +EXPORT_SYMBOL_GPL(tcp_ca_openreq_child); + +static void smc_check_reset_syn_req(struct tcp_sock *oldtp, + struct request_sock *req, + struct tcp_sock *newtp) +{ +#if IS_ENABLED(CONFIG_SMC) + struct inet_request_sock *ireq; + + if (static_branch_unlikely(&tcp_have_smc)) { + ireq = inet_rsk(req); + if (oldtp->syn_smc && !ireq->smc_ok) + newtp->syn_smc = 0; + } +#endif +} + +/* This is not only more efficient than what we used to do, it eliminates + * a lot of code duplication between IPv4/IPv6 SYN recv processing. -DaveM + * + * Actually, we could lots of memory writes here. tp of listening + * socket contains all necessary default parameters. + */ +struct sock *tcp_create_openreq_child(const struct sock *sk, + struct request_sock *req, + struct sk_buff *skb) +{ + struct sock *newsk = inet_csk_clone_lock(sk, req, GFP_ATOMIC); + const struct inet_request_sock *ireq = inet_rsk(req); + struct tcp_request_sock *treq = tcp_rsk(req); + struct inet_connection_sock *newicsk; + struct tcp_sock *oldtp, *newtp; + u32 seq; + + if (!newsk) + return NULL; + + newicsk = inet_csk(newsk); + newtp = tcp_sk(newsk); + oldtp = tcp_sk(sk); + + smc_check_reset_syn_req(oldtp, req, newtp); + + /* Now setup tcp_sock */ + newtp->pred_flags = 0; + + seq = treq->rcv_isn + 1; + newtp->rcv_wup = seq; + WRITE_ONCE(newtp->copied_seq, seq); + WRITE_ONCE(newtp->rcv_nxt, seq); + newtp->segs_in = 1; + + seq = treq->snt_isn + 1; + newtp->snd_sml = newtp->snd_una = seq; + WRITE_ONCE(newtp->snd_nxt, seq); + newtp->snd_up = seq; + + INIT_LIST_HEAD(&newtp->tsq_node); + INIT_LIST_HEAD(&newtp->tsorted_sent_queue); + + tcp_init_wl(newtp, treq->rcv_isn); + + minmax_reset(&newtp->rtt_min, tcp_jiffies32, ~0U); + newicsk->icsk_ack.lrcvtime = tcp_jiffies32; + + newtp->lsndtime = tcp_jiffies32; + newsk->sk_txhash = READ_ONCE(treq->txhash); + newtp->total_retrans = req->num_retrans; + + tcp_init_xmit_timers(newsk); + WRITE_ONCE(newtp->write_seq, newtp->pushed_seq = treq->snt_isn + 1); + + if (sock_flag(newsk, SOCK_KEEPOPEN)) + inet_csk_reset_keepalive_timer(newsk, + keepalive_time_when(newtp)); + + newtp->rx_opt.tstamp_ok = ireq->tstamp_ok; + newtp->rx_opt.sack_ok = ireq->sack_ok; + newtp->window_clamp = req->rsk_window_clamp; + newtp->rcv_ssthresh = req->rsk_rcv_wnd; + newtp->rcv_wnd = req->rsk_rcv_wnd; + newtp->rx_opt.wscale_ok = ireq->wscale_ok; + if (newtp->rx_opt.wscale_ok) { + newtp->rx_opt.snd_wscale = ireq->snd_wscale; + newtp->rx_opt.rcv_wscale = ireq->rcv_wscale; + } else { + newtp->rx_opt.snd_wscale = newtp->rx_opt.rcv_wscale = 0; + newtp->window_clamp = min(newtp->window_clamp, 65535U); + } + newtp->snd_wnd = ntohs(tcp_hdr(skb)->window) << newtp->rx_opt.snd_wscale; + newtp->max_window = newtp->snd_wnd; + + if (newtp->rx_opt.tstamp_ok) { 
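[Illustrative aside, not part of the patch] The child-socket window setup in tcp_create_openreq_child() above caps window_clamp at 65535 when window scaling was not negotiated, and otherwise shifts the raw 16-bit SYN window by the peer's scale to get snd_wnd. A sketch with made-up values:

#include <stdio.h>
#include <stdint.h>

int main(void)
{
	uint32_t clamp = 1048576;	/* inherited window_clamp, 1 MB */
	uint16_t syn_window = 29200;	/* raw 16-bit window field from the SYN */

	/* no window-scale option: the clamp falls back to 64K-1 */
	uint32_t no_ws_clamp = clamp < 65535u ? clamp : 65535u;

	/* window scale of 7 negotiated: effective send window */
	uint32_t snd_wnd = (uint32_t)syn_window << 7;

	printf("clamp=%u snd_wnd=%u\n", no_ws_clamp, snd_wnd);	/* 65535 3737600 */
	return 0;
}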
+ newtp->rx_opt.ts_recent = READ_ONCE(req->ts_recent); + newtp->rx_opt.ts_recent_stamp = ktime_get_seconds(); + newtp->tcp_header_len = sizeof(struct tcphdr) + TCPOLEN_TSTAMP_ALIGNED; + } else { + newtp->rx_opt.ts_recent_stamp = 0; + newtp->tcp_header_len = sizeof(struct tcphdr); + } + if (req->num_timeout) { + newtp->undo_marker = treq->snt_isn; + newtp->retrans_stamp = div_u64(treq->snt_synack, + USEC_PER_SEC / TCP_TS_HZ); + } + newtp->tsoffset = treq->ts_off; +#ifdef CONFIG_TCP_MD5SIG + newtp->md5sig_info = NULL; /*XXX*/ + if (treq->af_specific->req_md5_lookup(sk, req_to_sk(req))) + newtp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED; +#endif + if (skb->len >= TCP_MSS_DEFAULT + newtp->tcp_header_len) + newicsk->icsk_ack.last_seg_size = skb->len - newtp->tcp_header_len; + newtp->rx_opt.mss_clamp = req->mss; + tcp_ecn_openreq_child(newtp, req); + newtp->fastopen_req = NULL; + RCU_INIT_POINTER(newtp->fastopen_rsk, NULL); + + newtp->bpf_chg_cc_inprogress = 0; + tcp_bpf_clone(sk, newsk); + + __TCP_INC_STATS(sock_net(sk), TCP_MIB_PASSIVEOPENS); + + return newsk; +} +EXPORT_SYMBOL(tcp_create_openreq_child); + +/* + * Process an incoming packet for SYN_RECV sockets represented as a + * request_sock. Normally sk is the listener socket but for TFO it + * points to the child socket. + * + * XXX (TFO) - The current impl contains a special check for ack + * validation and inside tcp_v4_reqsk_send_ack(). Can we do better? + * + * We don't need to initialize tmp_opt.sack_ok as we don't use the results + * + * Note: If @fastopen is true, this can be called from process context. + * Otherwise, this is from BH context. + */ + +struct sock *tcp_check_req(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + bool fastopen, bool *req_stolen) +{ + struct tcp_options_received tmp_opt; + struct sock *child; + const struct tcphdr *th = tcp_hdr(skb); + __be32 flg = tcp_flag_word(th) & (TCP_FLAG_RST|TCP_FLAG_SYN|TCP_FLAG_ACK); + bool paws_reject = false; + bool own_req; + + tmp_opt.saw_tstamp = 0; + if (th->doff > (sizeof(struct tcphdr)>>2)) { + tcp_parse_options(sock_net(sk), skb, &tmp_opt, 0, NULL); + + if (tmp_opt.saw_tstamp) { + tmp_opt.ts_recent = READ_ONCE(req->ts_recent); + if (tmp_opt.rcv_tsecr) + tmp_opt.rcv_tsecr -= tcp_rsk(req)->ts_off; + /* We do not store true stamp, but it is not required, + * it can be estimated (approximately) + * from another data. + */ + tmp_opt.ts_recent_stamp = ktime_get_seconds() - reqsk_timeout(req, TCP_RTO_MAX) / HZ; + paws_reject = tcp_paws_reject(&tmp_opt, th->rst); + } + } + + /* Check for pure retransmitted SYN. */ + if (TCP_SKB_CB(skb)->seq == tcp_rsk(req)->rcv_isn && + flg == TCP_FLAG_SYN && + !paws_reject) { + /* + * RFC793 draws (Incorrectly! It was fixed in RFC1122) + * this case on figure 6 and figure 8, but formal + * protocol description says NOTHING. + * To be more exact, it says that we should send ACK, + * because this segment (at least, if it has no data) + * is out of window. + * + * CONCLUSION: RFC793 (even with RFC1122) DOES NOT + * describe SYN-RECV state. All the description + * is wrong, we cannot believe to it and should + * rely only on common sense and implementation + * experience. + * + * Enforce "SYN-ACK" according to figure 8, figure 6 + * of RFC793, fixed by RFC1122. + * + * Note that even if there is new data in the SYN packet + * they will be thrown away too. + * + * Reset timer after retransmitting SYNACK, similar to + * the idea of fast retransmit in recovery. 
+ */ + if (!tcp_oow_rate_limited(sock_net(sk), skb, + LINUX_MIB_TCPACKSKIPPEDSYNRECV, + &tcp_rsk(req)->last_oow_ack_time) && + + !inet_rtx_syn_ack(sk, req)) { + unsigned long expires = jiffies; + + expires += reqsk_timeout(req, TCP_RTO_MAX); + if (!fastopen) + mod_timer_pending(&req->rsk_timer, expires); + else + req->rsk_timer.expires = expires; + } + return NULL; + } + + /* Further reproduces section "SEGMENT ARRIVES" + for state SYN-RECEIVED of RFC793. + It is broken, however, it does not work only + when SYNs are crossed. + + You would think that SYN crossing is impossible here, since + we should have a SYN_SENT socket (from connect()) on our end, + but this is not true if the crossed SYNs were sent to both + ends by a malicious third party. We must defend against this, + and to do that we first verify the ACK (as per RFC793, page + 36) and reset if it is invalid. Is this a true full defense? + To convince ourselves, let us consider a way in which the ACK + test can still pass in this 'malicious crossed SYNs' case. + Malicious sender sends identical SYNs (and thus identical sequence + numbers) to both A and B: + + A: gets SYN, seq=7 + B: gets SYN, seq=7 + + By our good fortune, both A and B select the same initial + send sequence number of seven :-) + + A: sends SYN|ACK, seq=7, ack_seq=8 + B: sends SYN|ACK, seq=7, ack_seq=8 + + So we are now A eating this SYN|ACK, ACK test passes. So + does sequence test, SYN is truncated, and thus we consider + it a bare ACK. + + If icsk->icsk_accept_queue.rskq_defer_accept, we silently drop this + bare ACK. Otherwise, we create an established connection. Both + ends (listening sockets) accept the new incoming connection and try + to talk to each other. 8-) + + Note: This case is both harmless, and rare. Possibility is about the + same as us discovering intelligent life on another plant tomorrow. + + But generally, we should (RFC lies!) to accept ACK + from SYNACK both here and in tcp_rcv_state_process(). + tcp_rcv_state_process() does not, hence, we do not too. + + Note that the case is absolutely generic: + we cannot optimize anything here without + violating protocol. All the checks must be made + before attempt to create socket. + */ + + /* RFC793 page 36: "If the connection is in any non-synchronized state ... + * and the incoming segment acknowledges something not yet + * sent (the segment carries an unacceptable ACK) ... + * a reset is sent." + * + * Invalid ACK: reset will be sent by listening socket. + * Note that the ACK validity check for a Fast Open socket is done + * elsewhere and is checked directly against the child socket rather + * than req because user data may have been sent out. + */ + if ((flg & TCP_FLAG_ACK) && !fastopen && + (TCP_SKB_CB(skb)->ack_seq != + tcp_rsk(req)->snt_isn + 1)) + return sk; + + /* Also, it would be not so bad idea to check rcv_tsecr, which + * is essentially ACK extension and too early or too late values + * should cause reset in unsynchronized states. + */ + + /* RFC793: "first check sequence number". */ + + if (paws_reject || !tcp_in_window(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(skb)->end_seq, + tcp_rsk(req)->rcv_nxt, tcp_rsk(req)->rcv_nxt + req->rsk_rcv_wnd)) { + /* Out of window: send ACK and drop. 
*/ + if (!(flg & TCP_FLAG_RST) && + !tcp_oow_rate_limited(sock_net(sk), skb, + LINUX_MIB_TCPACKSKIPPEDSYNRECV, + &tcp_rsk(req)->last_oow_ack_time)) + req->rsk_ops->send_ack(sk, skb, req); + if (paws_reject) + NET_INC_STATS(sock_net(sk), LINUX_MIB_PAWSESTABREJECTED); + return NULL; + } + + /* In sequence, PAWS is OK. */ + + /* TODO: We probably should defer ts_recent change once + * we take ownership of @req. + */ + if (tmp_opt.saw_tstamp && !after(TCP_SKB_CB(skb)->seq, tcp_rsk(req)->rcv_nxt)) + WRITE_ONCE(req->ts_recent, tmp_opt.rcv_tsval); + + if (TCP_SKB_CB(skb)->seq == tcp_rsk(req)->rcv_isn) { + /* Truncate SYN, it is out of window starting + at tcp_rsk(req)->rcv_isn + 1. */ + flg &= ~TCP_FLAG_SYN; + } + + /* RFC793: "second check the RST bit" and + * "fourth, check the SYN bit" + */ + if (flg & (TCP_FLAG_RST|TCP_FLAG_SYN)) { + TCP_INC_STATS(sock_net(sk), TCP_MIB_ATTEMPTFAILS); + goto embryonic_reset; + } + + /* ACK sequence verified above, just make sure ACK is + * set. If ACK not set, just silently drop the packet. + * + * XXX (TFO) - if we ever allow "data after SYN", the + * following check needs to be removed. + */ + if (!(flg & TCP_FLAG_ACK)) + return NULL; + + /* For Fast Open no more processing is needed (sk is the + * child socket). + */ + if (fastopen) + return sk; + + /* While TCP_DEFER_ACCEPT is active, drop bare ACK. */ + if (req->num_timeout < inet_csk(sk)->icsk_accept_queue.rskq_defer_accept && + TCP_SKB_CB(skb)->end_seq == tcp_rsk(req)->rcv_isn + 1) { + inet_rsk(req)->acked = 1; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPDEFERACCEPTDROP); + return NULL; + } + + /* OK, ACK is valid, create big socket and + * feed this segment to it. It will repeat all + * the tests. THIS SEGMENT MUST MOVE SOCKET TO + * ESTABLISHED STATE. If it will be dropped after + * socket is created, wait for troubles. + */ + child = inet_csk(sk)->icsk_af_ops->syn_recv_sock(sk, skb, req, NULL, + req, &own_req); + if (!child) + goto listen_overflow; + + if (own_req && rsk_drop_req(req)) { + reqsk_queue_removed(&inet_csk(req->rsk_listener)->icsk_accept_queue, req); + inet_csk_reqsk_queue_drop_and_put(req->rsk_listener, req); + return child; + } + + sock_rps_save_rxhash(child, skb); + tcp_synack_rtt_meas(child, req); + *req_stolen = !own_req; + return inet_csk_complete_hashdance(sk, child, req, own_req); + +listen_overflow: + if (sk != req->rsk_listener) + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPMIGRATEREQFAILURE); + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_abort_on_overflow)) { + inet_rsk(req)->acked = 1; + return NULL; + } + +embryonic_reset: + if (!(flg & TCP_FLAG_RST)) { + /* Received a bad SYN pkt - for TFO We try not to reset + * the local connection unless it's really necessary to + * avoid becoming vulnerable to outside attack aiming at + * resetting legit local connections. + */ + req->rsk_ops->send_reset(sk, skb); + } else if (fastopen) { /* received a valid RST pkt */ + reqsk_fastopen_remove(sk, req, true); + tcp_reset(sk, skb); + } + if (!fastopen) { + bool unlinked = inet_csk_reqsk_queue_drop(sk, req); + + if (unlinked) + __NET_INC_STATS(sock_net(sk), LINUX_MIB_EMBRYONICRSTS); + *req_stolen = !unlinked; + } + return NULL; +} +EXPORT_SYMBOL(tcp_check_req); + +/* + * Queue segment on the new socket if the new socket is active, + * otherwise we just shortcircuit this and continue with + * the new socket. + * + * For the vast majority of cases child->sk_state will be TCP_SYN_RECV + * when entering. 
But other states are possible due to a race condition + * where after __inet_lookup_established() fails but before the listener + * locked is obtained, other packets cause the same connection to + * be created. + */ + +int tcp_child_process(struct sock *parent, struct sock *child, + struct sk_buff *skb) + __releases(&((child)->sk_lock.slock)) +{ + int ret = 0; + int state = child->sk_state; + + /* record sk_napi_id and sk_rx_queue_mapping of child. */ + sk_mark_napi_id_set(child, skb); + + tcp_segs_in(tcp_sk(child), skb); + if (!sock_owned_by_user(child)) { + ret = tcp_rcv_state_process(child, skb); + /* Wakeup parent, send SIGIO */ + if (state == TCP_SYN_RECV && child->sk_state != state) + parent->sk_data_ready(parent); + } else { + /* Alas, it is possible again, because we do lookup + * in main socket hash table and lock on listening + * socket does not protect us more. + */ + __sk_add_backlog(child, skb); + } + + bh_unlock_sock(child); + sock_put(child); + return ret; +} +EXPORT_SYMBOL(tcp_child_process); diff --git a/net/ipv4/tcp_nv.c b/net/ipv4/tcp_nv.c new file mode 100644 index 000000000..a60662f4b --- /dev/null +++ b/net/ipv4/tcp_nv.c @@ -0,0 +1,501 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP NV: TCP with Congestion Avoidance + * + * TCP-NV is a successor of TCP-Vegas that has been developed to + * deal with the issues that occur in modern networks. + * Like TCP-Vegas, TCP-NV supports true congestion avoidance, + * the ability to detect congestion before packet losses occur. + * When congestion (queue buildup) starts to occur, TCP-NV + * predicts what the cwnd size should be for the current + * throughput and it reduces the cwnd proportionally to + * the difference between the current cwnd and the predicted cwnd. + * + * NV is only recommeneded for traffic within a data center, and when + * all the flows are NV (at least those within the data center). This + * is due to the inherent unfairness between flows using losses to + * detect congestion (congestion control) and those that use queue + * buildup to detect congestion (congestion avoidance). + * + * Note: High NIC coalescence values may lower the performance of NV + * due to the increased noise in RTT values. In particular, we have + * seen issues with rx-frames values greater than 8. + * + * TODO: + * 1) Add mechanism to deal with reverse congestion. + */ + +#include +#include +#include +#include + +/* TCP NV parameters + * + * nv_pad Max number of queued packets allowed in network + * nv_pad_buffer Do not grow cwnd if this closed to nv_pad + * nv_reset_period How often (in) seconds)to reset min_rtt + * nv_min_cwnd Don't decrease cwnd below this if there are no losses + * nv_cong_dec_mult Decrease cwnd by X% (30%) of congestion when detected + * nv_ssthresh_factor On congestion set ssthresh to this * / 8 + * nv_rtt_factor RTT averaging factor + * nv_loss_dec_factor Decrease cwnd to this (80%) when losses occur + * nv_dec_eval_min_calls Wait this many RTT measurements before dec cwnd + * nv_inc_eval_min_calls Wait this many RTT measurements before inc cwnd + * nv_ssthresh_eval_min_calls Wait this many RTT measurements before stopping + * slow-start due to congestion + * nv_stop_rtt_cnt Only grow cwnd for this many RTTs after non-congestion + * nv_rtt_min_cnt Wait these many RTTs before making congesion decision + * nv_cwnd_growth_rate_neg + * nv_cwnd_growth_rate_pos + * How quickly to double growth rate (not rate) of cwnd when not + * congested. 
One value (nv_cwnd_growth_rate_neg) for when + * rate < 1 pkt/RTT (after losses). The other (nv_cwnd_growth_rate_pos) + * otherwise. + */ + +static int nv_pad __read_mostly = 10; +static int nv_pad_buffer __read_mostly = 2; +static int nv_reset_period __read_mostly = 5; /* in seconds */ +static int nv_min_cwnd __read_mostly = 2; +static int nv_cong_dec_mult __read_mostly = 30 * 128 / 100; /* = 30% */ +static int nv_ssthresh_factor __read_mostly = 8; /* = 1 */ +static int nv_rtt_factor __read_mostly = 128; /* = 1/2*old + 1/2*new */ +static int nv_loss_dec_factor __read_mostly = 819; /* => 80% */ +static int nv_cwnd_growth_rate_neg __read_mostly = 8; +static int nv_cwnd_growth_rate_pos __read_mostly; /* 0 => fixed like Reno */ +static int nv_dec_eval_min_calls __read_mostly = 60; +static int nv_inc_eval_min_calls __read_mostly = 20; +static int nv_ssthresh_eval_min_calls __read_mostly = 30; +static int nv_stop_rtt_cnt __read_mostly = 10; +static int nv_rtt_min_cnt __read_mostly = 2; + +module_param(nv_pad, int, 0644); +MODULE_PARM_DESC(nv_pad, "max queued packets allowed in network"); +module_param(nv_reset_period, int, 0644); +MODULE_PARM_DESC(nv_reset_period, "nv_min_rtt reset period (secs)"); +module_param(nv_min_cwnd, int, 0644); +MODULE_PARM_DESC(nv_min_cwnd, "NV will not decrease cwnd below this value" + " without losses"); + +/* TCP NV Parameters */ +struct tcpnv { + unsigned long nv_min_rtt_reset_jiffies; /* when to switch to + * nv_min_rtt_new */ + s8 cwnd_growth_factor; /* Current cwnd growth factor, + * < 0 => less than 1 packet/RTT */ + u8 available8; + u16 available16; + u8 nv_allow_cwnd_growth:1, /* whether cwnd can grow */ + nv_reset:1, /* whether to reset values */ + nv_catchup:1; /* whether we are growing because + * of temporary cwnd decrease */ + u8 nv_eval_call_cnt; /* call count since last eval */ + u8 nv_min_cwnd; /* nv won't make a ca decision if cwnd is + * smaller than this. It may grow to handle + * TSO, LRO and interrupt coalescence because + * with these a small cwnd cannot saturate + * the link. Note that this is different from + * the file local nv_min_cwnd */ + u8 nv_rtt_cnt; /* RTTs without making ca decision */; + u32 nv_last_rtt; /* last rtt */ + u32 nv_min_rtt; /* active min rtt. Used to determine slope */ + u32 nv_min_rtt_new; /* min rtt for future use */ + u32 nv_base_rtt; /* If non-zero it represents the threshold for + * congestion */ + u32 nv_lower_bound_rtt; /* Used in conjunction with nv_base_rtt. It is + * set to 80% of nv_base_rtt. It helps reduce + * unfairness between flows */ + u32 nv_rtt_max_rate; /* max rate seen during current RTT */ + u32 nv_rtt_start_seq; /* current RTT ends when packet arrives + * acking beyond nv_rtt_start_seq */ + u32 nv_last_snd_una; /* Previous value of tp->snd_una. 
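[Illustrative aside, not part of the patch] nv_get_bounded_rtt() above clamps each RTT sample into [lower_bound, base] when a base RTT was provided (for example by a BPF sock_ops program), the lower bound being roughly 80% of the base via the 205/256 factor set in tcpnv_init(). A sketch with usec sample values:

#include <stdio.h>
#include <stdint.h>

static uint32_t bounded_rtt(uint32_t val, uint32_t base, uint32_t lower)
{
	if (lower > 0 && val < lower)
		return lower;
	if (base > 0 && val > base)
		return base;
	return val;
}

int main(void)
{
	uint32_t base = 100;			/* e.g. from BPF_SOCK_OPS_BASE_RTT, usec */
	uint32_t lower = (base * 205) >> 8;	/* 80: ~80% of base */

	printf("%u %u %u\n", bounded_rtt(60, base, lower),
	       bounded_rtt(90, base, lower), bounded_rtt(150, base, lower));
	/* prints: 80 90 100 */
	return 0;
}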
It is + * used to determine bytes acked since last + * call to bictcp_acked */ + u32 nv_no_cong_cnt; /* Consecutive no congestion decisions */ +}; + +#define NV_INIT_RTT U32_MAX +#define NV_MIN_CWND 4 +#define NV_MIN_CWND_GROW 2 +#define NV_TSO_CWND_BOUND 80 + +static inline void tcpnv_reset(struct tcpnv *ca, struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + ca->nv_reset = 0; + ca->nv_no_cong_cnt = 0; + ca->nv_rtt_cnt = 0; + ca->nv_last_rtt = 0; + ca->nv_rtt_max_rate = 0; + ca->nv_rtt_start_seq = tp->snd_una; + ca->nv_eval_call_cnt = 0; + ca->nv_last_snd_una = tp->snd_una; +} + +static void tcpnv_init(struct sock *sk) +{ + struct tcpnv *ca = inet_csk_ca(sk); + int base_rtt; + + tcpnv_reset(ca, sk); + + /* See if base_rtt is available from socket_ops bpf program. + * It is meant to be used in environments, such as communication + * within a datacenter, where we have reasonable estimates of + * RTTs + */ + base_rtt = tcp_call_bpf(sk, BPF_SOCK_OPS_BASE_RTT, 0, NULL); + if (base_rtt > 0) { + ca->nv_base_rtt = base_rtt; + ca->nv_lower_bound_rtt = (base_rtt * 205) >> 8; /* 80% */ + } else { + ca->nv_base_rtt = 0; + ca->nv_lower_bound_rtt = 0; + } + + ca->nv_allow_cwnd_growth = 1; + ca->nv_min_rtt_reset_jiffies = jiffies + 2 * HZ; + ca->nv_min_rtt = NV_INIT_RTT; + ca->nv_min_rtt_new = NV_INIT_RTT; + ca->nv_min_cwnd = NV_MIN_CWND; + ca->nv_catchup = 0; + ca->cwnd_growth_factor = 0; +} + +/* If provided, apply upper (base_rtt) and lower (lower_bound_rtt) + * bounds to RTT. + */ +inline u32 nv_get_bounded_rtt(struct tcpnv *ca, u32 val) +{ + if (ca->nv_lower_bound_rtt > 0 && val < ca->nv_lower_bound_rtt) + return ca->nv_lower_bound_rtt; + else if (ca->nv_base_rtt > 0 && val > ca->nv_base_rtt) + return ca->nv_base_rtt; + else + return val; +} + +static void tcpnv_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcpnv *ca = inet_csk_ca(sk); + u32 cnt; + + if (!tcp_is_cwnd_limited(sk)) + return; + + /* Only grow cwnd if NV has not detected congestion */ + if (!ca->nv_allow_cwnd_growth) + return; + + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + return; + } + + if (ca->cwnd_growth_factor < 0) { + cnt = tcp_snd_cwnd(tp) << -ca->cwnd_growth_factor; + tcp_cong_avoid_ai(tp, cnt, acked); + } else { + cnt = max(4U, tcp_snd_cwnd(tp) >> ca->cwnd_growth_factor); + tcp_cong_avoid_ai(tp, cnt, acked); + } +} + +static u32 tcpnv_recalc_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + return max((tcp_snd_cwnd(tp) * nv_loss_dec_factor) >> 10, 2U); +} + +static void tcpnv_state(struct sock *sk, u8 new_state) +{ + struct tcpnv *ca = inet_csk_ca(sk); + + if (new_state == TCP_CA_Open && ca->nv_reset) { + tcpnv_reset(ca, sk); + } else if (new_state == TCP_CA_Loss || new_state == TCP_CA_CWR || + new_state == TCP_CA_Recovery) { + ca->nv_reset = 1; + ca->nv_allow_cwnd_growth = 0; + if (new_state == TCP_CA_Loss) { + /* Reset cwnd growth factor to Reno value */ + if (ca->cwnd_growth_factor > 0) + ca->cwnd_growth_factor = 0; + /* Decrease growth rate if allowed */ + if (nv_cwnd_growth_rate_neg > 0 && + ca->cwnd_growth_factor > -8) + ca->cwnd_growth_factor--; + } + } +} + +/* Do congestion avoidance calculations for TCP-NV + */ +static void tcpnv_acked(struct sock *sk, const struct ack_sample *sample) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct tcpnv *ca = inet_csk_ca(sk); + unsigned long now = jiffies; + u64 rate64; + u32 rate, max_win, cwnd_by_slope; 
+ u32 avg_rtt; + u32 bytes_acked = 0; + + /* Some calls are for duplicates without timetamps */ + if (sample->rtt_us < 0) + return; + + /* If not in TCP_CA_Open or TCP_CA_Disorder states, skip. */ + if (icsk->icsk_ca_state != TCP_CA_Open && + icsk->icsk_ca_state != TCP_CA_Disorder) + return; + + /* Stop cwnd growth if we were in catch up mode */ + if (ca->nv_catchup && tcp_snd_cwnd(tp) >= nv_min_cwnd) { + ca->nv_catchup = 0; + ca->nv_allow_cwnd_growth = 0; + } + + bytes_acked = tp->snd_una - ca->nv_last_snd_una; + ca->nv_last_snd_una = tp->snd_una; + + if (sample->in_flight == 0) + return; + + /* Calculate moving average of RTT */ + if (nv_rtt_factor > 0) { + if (ca->nv_last_rtt > 0) { + avg_rtt = (((u64)sample->rtt_us) * nv_rtt_factor + + ((u64)ca->nv_last_rtt) + * (256 - nv_rtt_factor)) >> 8; + } else { + avg_rtt = sample->rtt_us; + ca->nv_min_rtt = avg_rtt << 1; + } + ca->nv_last_rtt = avg_rtt; + } else { + avg_rtt = sample->rtt_us; + } + + /* rate in 100's bits per second */ + rate64 = ((u64)sample->in_flight) * 80000; + do_div(rate64, avg_rtt ?: 1); + rate = (u32)rate64; + + /* Remember the maximum rate seen during this RTT + * Note: It may be more than one RTT. This function should be + * called at least nv_dec_eval_min_calls times. + */ + if (ca->nv_rtt_max_rate < rate) + ca->nv_rtt_max_rate = rate; + + /* We have valid information, increment counter */ + if (ca->nv_eval_call_cnt < 255) + ca->nv_eval_call_cnt++; + + /* Apply bounds to rtt. Only used to update min_rtt */ + avg_rtt = nv_get_bounded_rtt(ca, avg_rtt); + + /* update min rtt if necessary */ + if (avg_rtt < ca->nv_min_rtt) + ca->nv_min_rtt = avg_rtt; + + /* update future min_rtt if necessary */ + if (avg_rtt < ca->nv_min_rtt_new) + ca->nv_min_rtt_new = avg_rtt; + + /* nv_min_rtt is updated with the minimum (possibley averaged) rtt + * seen in the last sysctl_tcp_nv_reset_period seconds (i.e. a + * warm reset). This new nv_min_rtt will be continued to be updated + * and be used for another sysctl_tcp_nv_reset_period seconds, + * when it will be updated again. + * In practice we introduce some randomness, so the actual period used + * is chosen randomly from the range: + * [sysctl_tcp_nv_reset_period*3/4, sysctl_tcp_nv_reset_period*5/4) + */ + if (time_after_eq(now, ca->nv_min_rtt_reset_jiffies)) { + unsigned char rand; + + ca->nv_min_rtt = ca->nv_min_rtt_new; + ca->nv_min_rtt_new = NV_INIT_RTT; + get_random_bytes(&rand, 1); + ca->nv_min_rtt_reset_jiffies = + now + ((nv_reset_period * (384 + rand) * HZ) >> 9); + /* Every so often we decrease ca->nv_min_cwnd in case previous + * value is no longer accurate. + */ + ca->nv_min_cwnd = max(ca->nv_min_cwnd / 2, NV_MIN_CWND); + } + + /* Once per RTT check if we need to do congestion avoidance */ + if (before(ca->nv_rtt_start_seq, tp->snd_una)) { + ca->nv_rtt_start_seq = tp->snd_nxt; + if (ca->nv_rtt_cnt < 0xff) + /* Increase counter for RTTs without CA decision */ + ca->nv_rtt_cnt++; + + /* If this function is only called once within an RTT + * the cwnd is probably too small (in some cases due to + * tso, lro or interrupt coalescence), so we increase + * ca->nv_min_cwnd. 
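+ * As a rough, illustrative example (numbers assumed here, not taken from
+ * this code): with a 1448-byte mss_cache and nv_min_cwnd at its initial
+ * NV_MIN_CWND value of 4, a lone measurement acking at least
+ * (4 - 1) * 1448 = 4344 bytes grows nv_min_cwnd to
+ * min(4 + NV_MIN_CWND_GROW, NV_TSO_CWND_BOUND + 1) = 6, never
+ * exceeding 81.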
+ */ + if (ca->nv_eval_call_cnt == 1 && + bytes_acked >= (ca->nv_min_cwnd - 1) * tp->mss_cache && + ca->nv_min_cwnd < (NV_TSO_CWND_BOUND + 1)) { + ca->nv_min_cwnd = min(ca->nv_min_cwnd + + NV_MIN_CWND_GROW, + NV_TSO_CWND_BOUND + 1); + ca->nv_rtt_start_seq = tp->snd_nxt + + ca->nv_min_cwnd * tp->mss_cache; + ca->nv_eval_call_cnt = 0; + ca->nv_allow_cwnd_growth = 1; + return; + } + + /* Find the ideal cwnd for current rate from slope + * slope = 80000.0 * mss / nv_min_rtt + * cwnd_by_slope = nv_rtt_max_rate / slope + */ + cwnd_by_slope = (u32) + div64_u64(((u64)ca->nv_rtt_max_rate) * ca->nv_min_rtt, + 80000ULL * tp->mss_cache); + max_win = cwnd_by_slope + nv_pad; + + /* If cwnd > max_win, decrease cwnd + * if cwnd < max_win, grow cwnd + * else leave the same + */ + if (tcp_snd_cwnd(tp) > max_win) { + /* there is congestion, check that it is ok + * to make a CA decision + * 1. We should have at least nv_dec_eval_min_calls + * data points before making a CA decision + * 2. We only make a congesion decision after + * nv_rtt_min_cnt RTTs + */ + if (ca->nv_rtt_cnt < nv_rtt_min_cnt) { + return; + } else if (tp->snd_ssthresh == TCP_INFINITE_SSTHRESH) { + if (ca->nv_eval_call_cnt < + nv_ssthresh_eval_min_calls) + return; + /* otherwise we will decrease cwnd */ + } else if (ca->nv_eval_call_cnt < + nv_dec_eval_min_calls) { + if (ca->nv_allow_cwnd_growth && + ca->nv_rtt_cnt > nv_stop_rtt_cnt) + ca->nv_allow_cwnd_growth = 0; + return; + } + + /* We have enough data to determine we are congested */ + ca->nv_allow_cwnd_growth = 0; + tp->snd_ssthresh = + (nv_ssthresh_factor * max_win) >> 3; + if (tcp_snd_cwnd(tp) - max_win > 2) { + /* gap > 2, we do exponential cwnd decrease */ + int dec; + + dec = max(2U, ((tcp_snd_cwnd(tp) - max_win) * + nv_cong_dec_mult) >> 7); + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - dec); + } else if (nv_cong_dec_mult > 0) { + tcp_snd_cwnd_set(tp, max_win); + } + if (ca->cwnd_growth_factor > 0) + ca->cwnd_growth_factor = 0; + ca->nv_no_cong_cnt = 0; + } else if (tcp_snd_cwnd(tp) <= max_win - nv_pad_buffer) { + /* There is no congestion, grow cwnd if allowed*/ + if (ca->nv_eval_call_cnt < nv_inc_eval_min_calls) + return; + + ca->nv_allow_cwnd_growth = 1; + ca->nv_no_cong_cnt++; + if (ca->cwnd_growth_factor < 0 && + nv_cwnd_growth_rate_neg > 0 && + ca->nv_no_cong_cnt > nv_cwnd_growth_rate_neg) { + ca->cwnd_growth_factor++; + ca->nv_no_cong_cnt = 0; + } else if (ca->cwnd_growth_factor >= 0 && + nv_cwnd_growth_rate_pos > 0 && + ca->nv_no_cong_cnt > + nv_cwnd_growth_rate_pos) { + ca->cwnd_growth_factor++; + ca->nv_no_cong_cnt = 0; + } + } else { + /* cwnd is in-between, so do nothing */ + return; + } + + /* update state */ + ca->nv_eval_call_cnt = 0; + ca->nv_rtt_cnt = 0; + ca->nv_rtt_max_rate = 0; + + /* Don't want to make cwnd < nv_min_cwnd + * (it wasn't before, if it is now is because nv + * decreased it). 
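+ * For a feel of the decrease path above (illustrative numbers only,
+ * using the module defaults nv_ssthresh_factor = 8 and
+ * nv_cong_dec_mult = 38): if cwnd_by_slope + nv_pad gives max_win = 30
+ * while cwnd is 50, ssthresh becomes (8 * 30) >> 3 = 30 and cwnd drops
+ * by max(2, (50 - 30) * 38 >> 7) = 5 packets. The check below then
+ * keeps cwnd from ever ending up under nv_min_cwnd.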
+ */ + if (tcp_snd_cwnd(tp) < nv_min_cwnd) + tcp_snd_cwnd_set(tp, nv_min_cwnd); + } +} + +/* Extract info for Tcp socket info provided via netlink */ +static size_t tcpnv_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + const struct tcpnv *ca = inet_csk_ca(sk); + + if (ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + info->vegas.tcpv_enabled = 1; + info->vegas.tcpv_rttcnt = ca->nv_rtt_cnt; + info->vegas.tcpv_rtt = ca->nv_last_rtt; + info->vegas.tcpv_minrtt = ca->nv_min_rtt; + + *attr = INET_DIAG_VEGASINFO; + return sizeof(struct tcpvegas_info); + } + return 0; +} + +static struct tcp_congestion_ops tcpnv __read_mostly = { + .init = tcpnv_init, + .ssthresh = tcpnv_recalc_ssthresh, + .cong_avoid = tcpnv_cong_avoid, + .set_state = tcpnv_state, + .undo_cwnd = tcp_reno_undo_cwnd, + .pkts_acked = tcpnv_acked, + .get_info = tcpnv_get_info, + + .owner = THIS_MODULE, + .name = "nv", +}; + +static int __init tcpnv_register(void) +{ + BUILD_BUG_ON(sizeof(struct tcpnv) > ICSK_CA_PRIV_SIZE); + + return tcp_register_congestion_control(&tcpnv); +} + +static void __exit tcpnv_unregister(void) +{ + tcp_unregister_congestion_control(&tcpnv); +} + +module_init(tcpnv_register); +module_exit(tcpnv_unregister); + +MODULE_AUTHOR("Lawrence Brakmo"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP NV"); +MODULE_VERSION("1.0"); diff --git a/net/ipv4/tcp_offload.c b/net/ipv4/tcp_offload.c new file mode 100644 index 000000000..4851211aa --- /dev/null +++ b/net/ipv4/tcp_offload.c @@ -0,0 +1,358 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV4 GSO/GRO offload support + * Linux INET implementation + * + * TCPv4 GSO/GRO support + */ + +#include +#include +#include +#include +#include + +static void tcp_gso_tstamp(struct sk_buff *skb, unsigned int ts_seq, + unsigned int seq, unsigned int mss) +{ + while (skb) { + if (before(ts_seq, seq + mss)) { + skb_shinfo(skb)->tx_flags |= SKBTX_SW_TSTAMP; + skb_shinfo(skb)->tskey = ts_seq; + return; + } + + skb = skb->next; + seq += mss; + } +} + +static struct sk_buff *tcp4_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4)) + return ERR_PTR(-EINVAL); + + if (!pskb_may_pull(skb, sizeof(struct tcphdr))) + return ERR_PTR(-EINVAL); + + if (unlikely(skb->ip_summed != CHECKSUM_PARTIAL)) { + const struct iphdr *iph = ip_hdr(skb); + struct tcphdr *th = tcp_hdr(skb); + + /* Set up checksum pseudo header, usually expect stack to + * have done this already. + */ + + th->check = 0; + skb->ip_summed = CHECKSUM_PARTIAL; + __tcp_v4_send_check(skb, iph->saddr, iph->daddr); + } + + return tcp_gso_segment(skb, features); +} + +struct sk_buff *tcp_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + unsigned int sum_truesize = 0; + struct tcphdr *th; + unsigned int thlen; + unsigned int seq; + unsigned int oldlen; + unsigned int mss; + struct sk_buff *gso_skb = skb; + __sum16 newcheck; + bool ooo_okay, copy_destructor; + __wsum delta; + + th = tcp_hdr(skb); + thlen = th->doff * 4; + if (thlen < sizeof(*th)) + goto out; + + if (!pskb_may_pull(skb, thlen)) + goto out; + + oldlen = ~skb->len; + __skb_pull(skb, thlen); + + mss = skb_shinfo(skb)->gso_size; + if (unlikely(skb->len <= mss)) + goto out; + + if (skb_gso_ok(skb, features | NETIF_F_GSO_ROBUST)) { + /* Packet is from an untrusted source, reset gso_segs. 
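+	 * (Illustrative numbers: a 7300-byte payload with a gso_size of
+	 * 1448 is accounted as DIV_ROUND_UP(7300, 1448) = 6 segments.)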
*/ + + skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(skb->len, mss); + + segs = NULL; + goto out; + } + + copy_destructor = gso_skb->destructor == tcp_wfree; + ooo_okay = gso_skb->ooo_okay; + /* All segments but the first should have ooo_okay cleared */ + skb->ooo_okay = 0; + + segs = skb_segment(skb, features); + if (IS_ERR(segs)) + goto out; + + /* Only first segment might have ooo_okay set */ + segs->ooo_okay = ooo_okay; + + /* GSO partial and frag_list segmentation only requires splitting + * the frame into an MSS multiple and possibly a remainder, both + * cases return a GSO skb. So update the mss now. + */ + if (skb_is_gso(segs)) + mss *= skb_shinfo(segs)->gso_segs; + + delta = (__force __wsum)htonl(oldlen + thlen + mss); + + skb = segs; + th = tcp_hdr(skb); + seq = ntohl(th->seq); + + if (unlikely(skb_shinfo(gso_skb)->tx_flags & SKBTX_SW_TSTAMP)) + tcp_gso_tstamp(segs, skb_shinfo(gso_skb)->tskey, seq, mss); + + newcheck = ~csum_fold(csum_add(csum_unfold(th->check), delta)); + + while (skb->next) { + th->fin = th->psh = 0; + th->check = newcheck; + + if (skb->ip_summed == CHECKSUM_PARTIAL) + gso_reset_checksum(skb, ~th->check); + else + th->check = gso_make_checksum(skb, ~th->check); + + seq += mss; + if (copy_destructor) { + skb->destructor = gso_skb->destructor; + skb->sk = gso_skb->sk; + sum_truesize += skb->truesize; + } + skb = skb->next; + th = tcp_hdr(skb); + + th->seq = htonl(seq); + th->cwr = 0; + } + + /* Following permits TCP Small Queues to work well with GSO : + * The callback to TCP stack will be called at the time last frag + * is freed at TX completion, and not right now when gso_skb + * is freed by GSO engine + */ + if (copy_destructor) { + int delta; + + swap(gso_skb->sk, skb->sk); + swap(gso_skb->destructor, skb->destructor); + sum_truesize += skb->truesize; + delta = sum_truesize - gso_skb->truesize; + /* In some pathological cases, delta can be negative. 
+ * We need to either use refcount_add() or refcount_sub_and_test() + */ + if (likely(delta >= 0)) + refcount_add(delta, &skb->sk->sk_wmem_alloc); + else + WARN_ON_ONCE(refcount_sub_and_test(-delta, &skb->sk->sk_wmem_alloc)); + } + + delta = (__force __wsum)htonl(oldlen + + (skb_tail_pointer(skb) - + skb_transport_header(skb)) + + skb->data_len); + th->check = ~csum_fold(csum_add(csum_unfold(th->check), delta)); + if (skb->ip_summed == CHECKSUM_PARTIAL) + gso_reset_checksum(skb, ~th->check); + else + th->check = gso_make_checksum(skb, ~th->check); +out: + return segs; +} + +struct sk_buff *tcp_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + struct sk_buff *pp = NULL; + struct sk_buff *p; + struct tcphdr *th; + struct tcphdr *th2; + unsigned int len; + unsigned int thlen; + __be32 flags; + unsigned int mss = 1; + unsigned int hlen; + unsigned int off; + int flush = 1; + int i; + + off = skb_gro_offset(skb); + hlen = off + sizeof(*th); + th = skb_gro_header(skb, hlen, off); + if (unlikely(!th)) + goto out; + + thlen = th->doff * 4; + if (thlen < sizeof(*th)) + goto out; + + hlen = off + thlen; + if (skb_gro_header_hard(skb, hlen)) { + th = skb_gro_header_slow(skb, hlen, off); + if (unlikely(!th)) + goto out; + } + + skb_gro_pull(skb, thlen); + + len = skb_gro_len(skb); + flags = tcp_flag_word(th); + + list_for_each_entry(p, head, list) { + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + th2 = tcp_hdr(p); + + if (*(u32 *)&th->source ^ *(u32 *)&th2->source) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + + goto found; + } + p = NULL; + goto out_check_final; + +found: + /* Include the IP ID check below from the inner most IP hdr */ + flush = NAPI_GRO_CB(p)->flush; + flush |= (__force int)(flags & TCP_FLAG_CWR); + flush |= (__force int)((flags ^ tcp_flag_word(th2)) & + ~(TCP_FLAG_CWR | TCP_FLAG_FIN | TCP_FLAG_PSH)); + flush |= (__force int)(th->ack_seq ^ th2->ack_seq); + for (i = sizeof(*th); i < thlen; i += 4) + flush |= *(u32 *)((u8 *)th + i) ^ + *(u32 *)((u8 *)th2 + i); + + /* When we receive our second frame we can made a decision on if we + * continue this flow as an atomic flow with a fixed ID or if we use + * an incrementing ID. + */ + if (NAPI_GRO_CB(p)->flush_id != 1 || + NAPI_GRO_CB(p)->count != 1 || + !NAPI_GRO_CB(p)->is_atomic) + flush |= NAPI_GRO_CB(p)->flush_id; + else + NAPI_GRO_CB(p)->is_atomic = false; + + mss = skb_shinfo(p)->gso_size; + + /* If skb is a GRO packet, make sure its gso_size matches prior packet mss. + * If it is a single frame, do not aggregate it if its length + * is bigger than our mss. + */ + if (unlikely(skb_is_gso(skb))) + flush |= (mss != skb_shinfo(skb)->gso_size); + else + flush |= (len - 1) >= mss; + + flush |= (ntohl(th2->seq) + skb_gro_len(p)) ^ ntohl(th->seq); +#ifdef CONFIG_TLS_DEVICE + flush |= p->decrypted ^ skb->decrypted; +#endif + + if (flush || skb_gro_receive(p, skb)) { + mss = 1; + goto out_check_final; + } + + tcp_flag_word(th2) |= flags & (TCP_FLAG_FIN | TCP_FLAG_PSH); + +out_check_final: + /* Force a flush if last segment is smaller than mss. 
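+ * (E.g., assuming an mss of 1448: a trailing 200-byte segment sets
+ * flush, so the previously merged packet is handed up rather than
+ * held back waiting for more frames.)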
*/ + if (unlikely(skb_is_gso(skb))) + flush = len != NAPI_GRO_CB(skb)->count * skb_shinfo(skb)->gso_size; + else + flush = len < mss; + + flush |= (__force int)(flags & (TCP_FLAG_URG | TCP_FLAG_PSH | + TCP_FLAG_RST | TCP_FLAG_SYN | + TCP_FLAG_FIN)); + + if (p && (!NAPI_GRO_CB(skb)->same_flow || flush)) + pp = p; + +out: + NAPI_GRO_CB(skb)->flush |= (flush != 0); + + return pp; +} + +int tcp_gro_complete(struct sk_buff *skb) +{ + struct tcphdr *th = tcp_hdr(skb); + + skb->csum_start = (unsigned char *)th - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); + skb->ip_summed = CHECKSUM_PARTIAL; + + skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + + if (th->cwr) + skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; + + if (skb->encapsulation) + skb->inner_transport_header = skb->transport_header; + + return 0; +} +EXPORT_SYMBOL(tcp_gro_complete); + +INDIRECT_CALLABLE_SCOPE +struct sk_buff *tcp4_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + /* Don't bother verifying checksum if we're going to flush anyway. */ + if (!NAPI_GRO_CB(skb)->flush && + skb_gro_checksum_validate(skb, IPPROTO_TCP, + inet_gro_compute_pseudo)) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + return tcp_gro_receive(head, skb); +} + +INDIRECT_CALLABLE_SCOPE int tcp4_gro_complete(struct sk_buff *skb, int thoff) +{ + const struct iphdr *iph = ip_hdr(skb); + struct tcphdr *th = tcp_hdr(skb); + + th->check = ~tcp_v4_check(skb->len - thoff, iph->saddr, + iph->daddr, 0); + skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV4; + + if (NAPI_GRO_CB(skb)->is_atomic) + skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_FIXEDID; + + return tcp_gro_complete(skb); +} + +static const struct net_offload tcpv4_offload = { + .callbacks = { + .gso_segment = tcp4_gso_segment, + .gro_receive = tcp4_gro_receive, + .gro_complete = tcp4_gro_complete, + }, +}; + +int __init tcpv4_offload_init(void) +{ + return inet_add_offload(&tcpv4_offload, IPPROTO_TCP); +} diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c new file mode 100644 index 000000000..67087da45 --- /dev/null +++ b/net/ipv4/tcp_output.c @@ -0,0 +1,4199 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Corey Minyard + * Florian La Roche, + * Charles Hedrick, + * Linus Torvalds, + * Alan Cox, + * Matthew Dillon, + * Arnt Gulbrandsen, + * Jorge Cwik, + */ + +/* + * Changes: Pedro Roque : Retransmit queue handled by TCP. + * : Fragmentation on mtu decrease + * : Segment collapse on retransmit + * : AF independence + * + * Linus Torvalds : send_delayed_ack + * David S. Miller : Charge memory using the right skb + * during syn/ack processing. + * David S. Miller : Output engine completely rewritten. + * Andrea Arcangeli: SYNACK carry ts_recent in tsecr. + * Cacophonix Gaul : draft-minshall-nagle-01 + * J Hadi Salim : ECN support + * + */ + +#define pr_fmt(fmt) "TCP: " fmt + +#include +#include + +#include +#include +#include +#include + +#include + +/* Refresh clocks of a TCP socket, + * ensuring monotically increasing values. 
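+ * (For instance, a raw clock value of 5,000,123,456 ns is cached in
+ * tp->tcp_clock_cache as-is and mirrored in tp->tcp_mstamp as
+ * 5,000,123 us via the divide below.)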
+ */ +void tcp_mstamp_refresh(struct tcp_sock *tp) +{ + u64 val = tcp_clock_ns(); + + tp->tcp_clock_cache = val; + tp->tcp_mstamp = div_u64(val, NSEC_PER_USEC); +} + +static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, + int push_one, gfp_t gfp); + +/* Account for new data that has been sent to the network. */ +static void tcp_event_new_data_sent(struct sock *sk, struct sk_buff *skb) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + unsigned int prior_packets = tp->packets_out; + + WRITE_ONCE(tp->snd_nxt, TCP_SKB_CB(skb)->end_seq); + + __skb_unlink(skb, &sk->sk_write_queue); + tcp_rbtree_insert(&sk->tcp_rtx_queue, skb); + + if (tp->highest_sack == NULL) + tp->highest_sack = skb; + + tp->packets_out += tcp_skb_pcount(skb); + if (!prior_packets || icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) + tcp_rearm_rto(sk); + + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPORIGDATASENT, + tcp_skb_pcount(skb)); + tcp_check_space(sk); +} + +/* SND.NXT, if window was not shrunk or the amount of shrunk was less than one + * window scaling factor due to loss of precision. + * If window has been shrunk, what should we make? It is not clear at all. + * Using SND.UNA we will fail to open window, SND.NXT is out of window. :-( + * Anything in between SND.UNA...SND.UNA+SND.WND also can be already + * invalid. OK, let's make this for now: + */ +static inline __u32 tcp_acceptable_seq(const struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + if (!before(tcp_wnd_end(tp), tp->snd_nxt) || + (tp->rx_opt.wscale_ok && + ((tp->snd_nxt - tcp_wnd_end(tp)) < (1 << tp->rx_opt.rcv_wscale)))) + return tp->snd_nxt; + else + return tcp_wnd_end(tp); +} + +/* Calculate mss to advertise in SYN segment. + * RFC1122, RFC1063, draft-ietf-tcpimpl-pmtud-01 state that: + * + * 1. It is independent of path mtu. + * 2. Ideally, it is maximal possible segment size i.e. 65535-40. + * 3. For IPv4 it is reasonable to calculate it from maximal MTU of + * attached devices, because some buggy hosts are confused by + * large MSS. + * 4. We do not make 3, we advertise MSS, calculated from first + * hop device mtu, but allow to raise it to ip_rt_min_advmss. + * This may be overridden via information stored in routing table. + * 5. Value 65535 for MSS is valid in IPv6 and means "as large as possible, + * probably even Jumbo". + */ +static __u16 tcp_advertise_mss(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + const struct dst_entry *dst = __sk_dst_get(sk); + int mss = tp->advmss; + + if (dst) { + unsigned int metric = dst_metric_advmss(dst); + + if (metric < mss) { + mss = metric; + tp->advmss = mss; + } + } + + return (__u16)mss; +} + +/* RFC2861. Reset CWND after idle period longer RTO to "restart window". + * This is the first part of cwnd validation mechanism. + */ +void tcp_cwnd_restart(struct sock *sk, s32 delta) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 restart_cwnd = tcp_init_cwnd(tp, __sk_dst_get(sk)); + u32 cwnd = tcp_snd_cwnd(tp); + + tcp_ca_event(sk, CA_EVENT_CWND_RESTART); + + tp->snd_ssthresh = tcp_current_ssthresh(sk); + restart_cwnd = min(restart_cwnd, cwnd); + + while ((delta -= inet_csk(sk)->icsk_rto) > 0 && cwnd > restart_cwnd) + cwnd >>= 1; + tcp_snd_cwnd_set(tp, max(cwnd, restart_cwnd)); + tp->snd_cwnd_stamp = tcp_jiffies32; + tp->snd_cwnd_used = 0; +} + +/* Congestion state accounting after a packet has been sent. 
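+ * (Illustration, assuming icsk_ack.ato sits at its usual ~40 ms floor:
+ * data written 10 ms after the last segment arrived flips the socket
+ * into pingpong (interactive) mode below.)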
*/ +static void tcp_event_data_sent(struct tcp_sock *tp, + struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + const u32 now = tcp_jiffies32; + + if (tcp_packets_in_flight(tp) == 0) + tcp_ca_event(sk, CA_EVENT_TX_START); + + tp->lsndtime = now; + + /* If it is a reply for ato after last received + * packet, enter pingpong mode. + */ + if ((u32)(now - icsk->icsk_ack.lrcvtime) < icsk->icsk_ack.ato) + inet_csk_enter_pingpong_mode(sk); +} + +/* Account for an ACK we sent. */ +static inline void tcp_event_ack_sent(struct sock *sk, u32 rcv_nxt) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (unlikely(tp->compressed_ack)) { + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPACKCOMPRESSED, + tp->compressed_ack); + tp->compressed_ack = 0; + if (hrtimer_try_to_cancel(&tp->compressed_ack_timer) == 1) + __sock_put(sk); + } + + if (unlikely(rcv_nxt != tp->rcv_nxt)) + return; /* Special ACK sent by DCTCP to reflect ECN */ + tcp_dec_quickack_mode(sk); + inet_csk_clear_xmit_timer(sk, ICSK_TIME_DACK); +} + +/* Determine a window scaling and initial window to offer. + * Based on the assumption that the given amount of space + * will be offered. Store the results in the tp structure. + * NOTE: for smooth operation initial space offering should + * be a multiple of mss if possible. We assume here that mss >= 1. + * This MUST be enforced by all callers. + */ +void tcp_select_initial_window(const struct sock *sk, int __space, __u32 mss, + __u32 *rcv_wnd, __u32 *window_clamp, + int wscale_ok, __u8 *rcv_wscale, + __u32 init_rcv_wnd) +{ + unsigned int space = (__space < 0 ? 0 : __space); + + /* If no clamp set the clamp to the max possible scaled window */ + if (*window_clamp == 0) + (*window_clamp) = (U16_MAX << TCP_MAX_WSCALE); + space = min(*window_clamp, space); + + /* Quantize space offering to a multiple of mss if possible. */ + if (space > mss) + space = rounddown(space, mss); + + /* NOTE: offering an initial window larger than 32767 + * will break some buggy TCP stacks. If the admin tells us + * it is likely we could be speaking with such a buggy stack + * we will truncate our initial window offering to 32K-1 + * unless the remote has sent us a window scaling option, + * which we interpret as a sign the remote TCP is not + * misinterpreting the window field as a signed quantity. + */ + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_workaround_signed_windows)) + (*rcv_wnd) = min(space, MAX_TCP_WINDOW); + else + (*rcv_wnd) = min_t(u32, space, U16_MAX); + + if (init_rcv_wnd) + *rcv_wnd = min(*rcv_wnd, init_rcv_wnd * mss); + + *rcv_wscale = 0; + if (wscale_ok) { + /* Set window scaling on max possible window */ + space = max_t(u32, space, READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])); + space = max_t(u32, space, READ_ONCE(sysctl_rmem_max)); + space = min_t(u32, space, *window_clamp); + *rcv_wscale = clamp_t(int, ilog2(space) - 15, + 0, TCP_MAX_WSCALE); + } + /* Set the clamp no higher than max representable value */ + (*window_clamp) = min_t(__u32, U16_MAX << (*rcv_wscale), *window_clamp); +} +EXPORT_SYMBOL(tcp_select_initial_window); + +/* Chose a new window to advertise, update state in tcp_sock for the + * socket, and return result with RFC1323 scaling applied. The return + * value can be stuffed directly into th->window for an outgoing + * frame. 
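+ * (Illustrative numbers: with rcv_wscale = 7, a computed window of
+ * 1,048,576 bytes is advertised as 1048576 >> 7 = 8192 in the 16-bit
+ * window field, and the peer scales it back up on its side.)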
+ */ +static u16 tcp_select_window(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 old_win = tp->rcv_wnd; + u32 cur_win = tcp_receive_window(tp); + u32 new_win = __tcp_select_window(sk); + struct net *net = sock_net(sk); + + if (new_win < cur_win) { + /* Danger Will Robinson! + * Don't update rcv_wup/rcv_wnd here or else + * we will not be able to advertise a zero + * window in time. --DaveM + * + * Relax Will Robinson. + */ + if (!READ_ONCE(net->ipv4.sysctl_tcp_shrink_window) || !tp->rx_opt.rcv_wscale) { + /* Never shrink the offered window */ + if (new_win == 0) + NET_INC_STATS(net, LINUX_MIB_TCPWANTZEROWINDOWADV); + new_win = ALIGN(cur_win, 1 << tp->rx_opt.rcv_wscale); + } + } + + tp->rcv_wnd = new_win; + tp->rcv_wup = tp->rcv_nxt; + + /* Make sure we do not exceed the maximum possible + * scaled window. + */ + if (!tp->rx_opt.rcv_wscale && + READ_ONCE(net->ipv4.sysctl_tcp_workaround_signed_windows)) + new_win = min(new_win, MAX_TCP_WINDOW); + else + new_win = min(new_win, (65535U << tp->rx_opt.rcv_wscale)); + + /* RFC1323 scaling applied */ + new_win >>= tp->rx_opt.rcv_wscale; + + /* If we advertise zero window, disable fast path. */ + if (new_win == 0) { + tp->pred_flags = 0; + if (old_win) + NET_INC_STATS(net, LINUX_MIB_TCPTOZEROWINDOWADV); + } else if (old_win == 0) { + NET_INC_STATS(net, LINUX_MIB_TCPFROMZEROWINDOWADV); + } + + return new_win; +} + +/* Packet ECN state for a SYN-ACK */ +static void tcp_ecn_send_synack(struct sock *sk, struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_CWR; + if (!(tp->ecn_flags & TCP_ECN_OK)) + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_ECE; + else if (tcp_ca_needs_ecn(sk) || + tcp_bpf_ca_needs_ecn(sk)) + INET_ECN_xmit(sk); +} + +/* Packet ECN state for a SYN. */ +static void tcp_ecn_send_syn(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + bool bpf_needs_ecn = tcp_bpf_ca_needs_ecn(sk); + bool use_ecn = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn) == 1 || + tcp_ca_needs_ecn(sk) || bpf_needs_ecn; + + if (!use_ecn) { + const struct dst_entry *dst = __sk_dst_get(sk); + + if (dst && dst_feature(dst, RTAX_FEATURE_ECN)) + use_ecn = true; + } + + tp->ecn_flags = 0; + + if (use_ecn) { + TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_ECE | TCPHDR_CWR; + tp->ecn_flags = TCP_ECN_OK; + if (tcp_ca_needs_ecn(sk) || bpf_needs_ecn) + INET_ECN_xmit(sk); + } +} + +static void tcp_ecn_clear_syn(struct sock *sk, struct sk_buff *skb) +{ + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_ecn_fallback)) + /* tp->ecn_flags are cleared at a later point in time when + * SYN ACK is ultimatively being received. + */ + TCP_SKB_CB(skb)->tcp_flags &= ~(TCPHDR_ECE | TCPHDR_CWR); +} + +static void +tcp_ecn_make_synack(const struct request_sock *req, struct tcphdr *th) +{ + if (inet_rsk(req)->ecn_ok) + th->ece = 1; +} + +/* Set up ECN state for a packet on a ESTABLISHED socket that is about to + * be sent. + */ +static void tcp_ecn_send(struct sock *sk, struct sk_buff *skb, + struct tcphdr *th, int tcp_header_len) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->ecn_flags & TCP_ECN_OK) { + /* Not-retransmitted data segment: set ECT and inject CWR. 
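+ * (In other words: once TCP_ECN_QUEUE_CWR is pending, the next new,
+ * non-retransmitted data segment carries CWR exactly once and the flag
+ * is cleared, so following segments do not repeat it.)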
*/ + if (skb->len != tcp_header_len && + !before(TCP_SKB_CB(skb)->seq, tp->snd_nxt)) { + INET_ECN_xmit(sk); + if (tp->ecn_flags & TCP_ECN_QUEUE_CWR) { + tp->ecn_flags &= ~TCP_ECN_QUEUE_CWR; + th->cwr = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN; + } + } else if (!tcp_ca_needs_ecn(sk)) { + /* ACK or retransmitted segment: clear ECT|CE */ + INET_ECN_dontxmit(sk); + } + if (tp->ecn_flags & TCP_ECN_DEMAND_CWR) + th->ece = 1; + } +} + +/* Constructs common control bits of non-data skb. If SYN/FIN is present, + * auto increment end seqno. + */ +static void tcp_init_nondata_skb(struct sk_buff *skb, u32 seq, u8 flags) +{ + skb->ip_summed = CHECKSUM_PARTIAL; + + TCP_SKB_CB(skb)->tcp_flags = flags; + + tcp_skb_pcount_set(skb, 1); + + TCP_SKB_CB(skb)->seq = seq; + if (flags & (TCPHDR_SYN | TCPHDR_FIN)) + seq++; + TCP_SKB_CB(skb)->end_seq = seq; +} + +static inline bool tcp_urg_mode(const struct tcp_sock *tp) +{ + return tp->snd_una != tp->snd_up; +} + +#define OPTION_SACK_ADVERTISE BIT(0) +#define OPTION_TS BIT(1) +#define OPTION_MD5 BIT(2) +#define OPTION_WSCALE BIT(3) +#define OPTION_FAST_OPEN_COOKIE BIT(8) +#define OPTION_SMC BIT(9) +#define OPTION_MPTCP BIT(10) + +static void smc_options_write(__be32 *ptr, u16 *options) +{ +#if IS_ENABLED(CONFIG_SMC) + if (static_branch_unlikely(&tcp_have_smc)) { + if (unlikely(OPTION_SMC & *options)) { + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_EXP << 8) | + (TCPOLEN_EXP_SMC_BASE)); + *ptr++ = htonl(TCPOPT_SMC_MAGIC); + } + } +#endif +} + +struct tcp_out_options { + u16 options; /* bit field of OPTION_* */ + u16 mss; /* 0 to disable */ + u8 ws; /* window scale, 0 to disable */ + u8 num_sack_blocks; /* number of SACK blocks to include */ + u8 hash_size; /* bytes in hash_location */ + u8 bpf_opt_len; /* length of BPF hdr option */ + __u8 *hash_location; /* temporary pointer, overloaded */ + __u32 tsval, tsecr; /* need to include OPTION_TS */ + struct tcp_fastopen_cookie *fastopen_cookie; /* Fast open cookie */ + struct mptcp_out_options mptcp; +}; + +static void mptcp_options_write(struct tcphdr *th, __be32 *ptr, + struct tcp_sock *tp, + struct tcp_out_options *opts) +{ +#if IS_ENABLED(CONFIG_MPTCP) + if (unlikely(OPTION_MPTCP & opts->options)) + mptcp_write_options(th, ptr, tp, &opts->mptcp); +#endif +} + +#ifdef CONFIG_CGROUP_BPF +static int bpf_skops_write_hdr_opt_arg0(struct sk_buff *skb, + enum tcp_synack_type synack_type) +{ + if (unlikely(!skb)) + return BPF_WRITE_HDR_TCP_CURRENT_MSS; + + if (unlikely(synack_type == TCP_SYNACK_COOKIE)) + return BPF_WRITE_HDR_TCP_SYNACK_COOKIE; + + return 0; +} + +/* req, syn_skb and synack_type are used when writing synack */ +static void bpf_skops_hdr_opt_len(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct sk_buff *syn_skb, + enum tcp_synack_type synack_type, + struct tcp_out_options *opts, + unsigned int *remaining) +{ + struct bpf_sock_ops_kern sock_ops; + int err; + + if (likely(!BPF_SOCK_OPS_TEST_FLAG(tcp_sk(sk), + BPF_SOCK_OPS_WRITE_HDR_OPT_CB_FLAG)) || + !*remaining) + return; + + /* *remaining has already been aligned to 4 bytes, so *remaining >= 4 */ + + /* init sock_ops */ + memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); + + sock_ops.op = BPF_SOCK_OPS_HDR_OPT_LEN_CB; + + if (req) { + /* The listen "sk" cannot be passed here because + * it is not locked. It would not make too much + * sense to do bpf_setsockopt(listen_sk) based + * on individual connection request also. 
+ * + * Thus, "req" is passed here and the cgroup-bpf-progs + * of the listen "sk" will be run. + * + * "req" is also used here for fastopen even the "sk" here is + * a fullsock "child" sk. It is to keep the behavior + * consistent between fastopen and non-fastopen on + * the bpf programming side. + */ + sock_ops.sk = (struct sock *)req; + sock_ops.syn_skb = syn_skb; + } else { + sock_owned_by_me(sk); + + sock_ops.is_fullsock = 1; + sock_ops.sk = sk; + } + + sock_ops.args[0] = bpf_skops_write_hdr_opt_arg0(skb, synack_type); + sock_ops.remaining_opt_len = *remaining; + /* tcp_current_mss() does not pass a skb */ + if (skb) + bpf_skops_init_skb(&sock_ops, skb, 0); + + err = BPF_CGROUP_RUN_PROG_SOCK_OPS_SK(&sock_ops, sk); + + if (err || sock_ops.remaining_opt_len == *remaining) + return; + + opts->bpf_opt_len = *remaining - sock_ops.remaining_opt_len; + /* round up to 4 bytes */ + opts->bpf_opt_len = (opts->bpf_opt_len + 3) & ~3; + + *remaining -= opts->bpf_opt_len; +} + +static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct sk_buff *syn_skb, + enum tcp_synack_type synack_type, + struct tcp_out_options *opts) +{ + u8 first_opt_off, nr_written, max_opt_len = opts->bpf_opt_len; + struct bpf_sock_ops_kern sock_ops; + int err; + + if (likely(!max_opt_len)) + return; + + memset(&sock_ops, 0, offsetof(struct bpf_sock_ops_kern, temp)); + + sock_ops.op = BPF_SOCK_OPS_WRITE_HDR_OPT_CB; + + if (req) { + sock_ops.sk = (struct sock *)req; + sock_ops.syn_skb = syn_skb; + } else { + sock_owned_by_me(sk); + + sock_ops.is_fullsock = 1; + sock_ops.sk = sk; + } + + sock_ops.args[0] = bpf_skops_write_hdr_opt_arg0(skb, synack_type); + sock_ops.remaining_opt_len = max_opt_len; + first_opt_off = tcp_hdrlen(skb) - max_opt_len; + bpf_skops_init_skb(&sock_ops, skb, first_opt_off); + + err = BPF_CGROUP_RUN_PROG_SOCK_OPS_SK(&sock_ops, sk); + + if (err) + nr_written = 0; + else + nr_written = max_opt_len - sock_ops.remaining_opt_len; + + if (nr_written < max_opt_len) + memset(skb->data + first_opt_off + nr_written, TCPOPT_NOP, + max_opt_len - nr_written); +} +#else +static void bpf_skops_hdr_opt_len(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct sk_buff *syn_skb, + enum tcp_synack_type synack_type, + struct tcp_out_options *opts, + unsigned int *remaining) +{ +} + +static void bpf_skops_write_hdr_opt(struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct sk_buff *syn_skb, + enum tcp_synack_type synack_type, + struct tcp_out_options *opts) +{ +} +#endif + +/* Write previously computed TCP options to the packet. + * + * Beware: Something in the Internet is very sensitive to the ordering of + * TCP options, we learned this through the hard way, so be careful here. + * Luckily we can at least blame others for their non-compliance but from + * inter-operability perspective it seems that we're somewhat stuck with + * the ordering which we have been using if we want to keep working with + * those broken things (not that it currently hurts anybody as there isn't + * particular reason why the ordering would need to be changed). + * + * At least SACK_PERM as the first option is known to lead to a disaster + * (but it may well be that other scenarios fail similarly). 
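+ * For reference, the order emitted below is: MD5, MSS, TS (folded into
+ * one word with SACK_PERM when both are present), standalone SACK_PERM,
+ * WSCALE, SACK blocks, Fast Open, SMC, then MPTCP. As a small worked
+ * example of the folded word: with the standard option kinds/lengths
+ * (SACK_PERM = 4/2, TIMESTAMP = 8/10) the combined 32-bit word is
+ * 0x0402080a.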
+ */ +static void tcp_options_write(struct tcphdr *th, struct tcp_sock *tp, + struct tcp_out_options *opts) +{ + __be32 *ptr = (__be32 *)(th + 1); + u16 options = opts->options; /* mungable copy */ + + if (unlikely(OPTION_MD5 & options)) { + *ptr++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) | + (TCPOPT_MD5SIG << 8) | TCPOLEN_MD5SIG); + /* overload cookie hash location */ + opts->hash_location = (__u8 *)ptr; + ptr += 4; + } + + if (unlikely(opts->mss)) { + *ptr++ = htonl((TCPOPT_MSS << 24) | + (TCPOLEN_MSS << 16) | + opts->mss); + } + + if (likely(OPTION_TS & options)) { + if (unlikely(OPTION_SACK_ADVERTISE & options)) { + *ptr++ = htonl((TCPOPT_SACK_PERM << 24) | + (TCPOLEN_SACK_PERM << 16) | + (TCPOPT_TIMESTAMP << 8) | + TCPOLEN_TIMESTAMP); + options &= ~OPTION_SACK_ADVERTISE; + } else { + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_TIMESTAMP << 8) | + TCPOLEN_TIMESTAMP); + } + *ptr++ = htonl(opts->tsval); + *ptr++ = htonl(opts->tsecr); + } + + if (unlikely(OPTION_SACK_ADVERTISE & options)) { + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_SACK_PERM << 8) | + TCPOLEN_SACK_PERM); + } + + if (unlikely(OPTION_WSCALE & options)) { + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_WINDOW << 16) | + (TCPOLEN_WINDOW << 8) | + opts->ws); + } + + if (unlikely(opts->num_sack_blocks)) { + struct tcp_sack_block *sp = tp->rx_opt.dsack ? + tp->duplicate_sack : tp->selective_acks; + int this_sack; + + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_SACK << 8) | + (TCPOLEN_SACK_BASE + (opts->num_sack_blocks * + TCPOLEN_SACK_PERBLOCK))); + + for (this_sack = 0; this_sack < opts->num_sack_blocks; + ++this_sack) { + *ptr++ = htonl(sp[this_sack].start_seq); + *ptr++ = htonl(sp[this_sack].end_seq); + } + + tp->rx_opt.dsack = 0; + } + + if (unlikely(OPTION_FAST_OPEN_COOKIE & options)) { + struct tcp_fastopen_cookie *foc = opts->fastopen_cookie; + u8 *p = (u8 *)ptr; + u32 len; /* Fast Open option length */ + + if (foc->exp) { + len = TCPOLEN_EXP_FASTOPEN_BASE + foc->len; + *ptr = htonl((TCPOPT_EXP << 24) | (len << 16) | + TCPOPT_FASTOPEN_MAGIC); + p += TCPOLEN_EXP_FASTOPEN_BASE; + } else { + len = TCPOLEN_FASTOPEN_BASE + foc->len; + *p++ = TCPOPT_FASTOPEN; + *p++ = len; + } + + memcpy(p, foc->val, foc->len); + if ((len & 3) == 2) { + p[foc->len] = TCPOPT_NOP; + p[foc->len + 1] = TCPOPT_NOP; + } + ptr += (len + 3) >> 2; + } + + smc_options_write(ptr, &options); + + mptcp_options_write(th, ptr, tp, opts); +} + +static void smc_set_option(const struct tcp_sock *tp, + struct tcp_out_options *opts, + unsigned int *remaining) +{ +#if IS_ENABLED(CONFIG_SMC) + if (static_branch_unlikely(&tcp_have_smc)) { + if (tp->syn_smc) { + if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; + } + } + } +#endif +} + +static void smc_set_option_cond(const struct tcp_sock *tp, + const struct inet_request_sock *ireq, + struct tcp_out_options *opts, + unsigned int *remaining) +{ +#if IS_ENABLED(CONFIG_SMC) + if (static_branch_unlikely(&tcp_have_smc)) { + if (tp->syn_smc && ireq->smc_ok) { + if (*remaining >= TCPOLEN_EXP_SMC_BASE_ALIGNED) { + opts->options |= OPTION_SMC; + *remaining -= TCPOLEN_EXP_SMC_BASE_ALIGNED; + } + } + } +#endif +} + +static void mptcp_set_option_cond(const struct request_sock *req, + struct tcp_out_options *opts, + unsigned int *remaining) +{ + if (rsk_is_mptcp(req)) { + unsigned int size; + + if (mptcp_synack_options(req, &size, &opts->mptcp)) { + if (*remaining >= 
size) { + opts->options |= OPTION_MPTCP; + *remaining -= size; + } + } + } +} + +/* Compute TCP options for SYN packets. This is not the final + * network wire format yet. + */ +static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb, + struct tcp_out_options *opts, + struct tcp_md5sig_key **md5) +{ + struct tcp_sock *tp = tcp_sk(sk); + unsigned int remaining = MAX_TCP_OPTION_SPACE; + struct tcp_fastopen_request *fastopen = tp->fastopen_req; + + *md5 = NULL; +#ifdef CONFIG_TCP_MD5SIG + if (static_branch_unlikely(&tcp_md5_needed) && + rcu_access_pointer(tp->md5sig_info)) { + *md5 = tp->af_specific->md5_lookup(sk, sk); + if (*md5) { + opts->options |= OPTION_MD5; + remaining -= TCPOLEN_MD5SIG_ALIGNED; + } + } +#endif + + /* We always get an MSS option. The option bytes which will be seen in + * normal data packets should timestamps be used, must be in the MSS + * advertised. But we subtract them from tp->mss_cache so that + * calculations in tcp_sendmsg are simpler etc. So account for this + * fact here if necessary. If we don't do this correctly, as a + * receiver we won't recognize data packets as being full sized when we + * should, and thus we won't abide by the delayed ACK rules correctly. + * SACKs don't matter, we never delay an ACK when we have any of those + * going out. */ + opts->mss = tcp_advertise_mss(sk); + remaining -= TCPOLEN_MSS_ALIGNED; + + if (likely(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_timestamps) && !*md5)) { + opts->options |= OPTION_TS; + opts->tsval = tcp_skb_timestamp(skb) + tp->tsoffset; + opts->tsecr = tp->rx_opt.ts_recent; + remaining -= TCPOLEN_TSTAMP_ALIGNED; + } + if (likely(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_window_scaling))) { + opts->ws = tp->rx_opt.rcv_wscale; + opts->options |= OPTION_WSCALE; + remaining -= TCPOLEN_WSCALE_ALIGNED; + } + if (likely(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_sack))) { + opts->options |= OPTION_SACK_ADVERTISE; + if (unlikely(!(OPTION_TS & opts->options))) + remaining -= TCPOLEN_SACKPERM_ALIGNED; + } + + if (fastopen && fastopen->cookie.len >= 0) { + u32 need = fastopen->cookie.len; + + need += fastopen->cookie.exp ? TCPOLEN_EXP_FASTOPEN_BASE : + TCPOLEN_FASTOPEN_BASE; + need = (need + 3) & ~3U; /* Align to 32 bits */ + if (remaining >= need) { + opts->options |= OPTION_FAST_OPEN_COOKIE; + opts->fastopen_cookie = &fastopen->cookie; + remaining -= need; + tp->syn_fastopen = 1; + tp->syn_fastopen_exp = fastopen->cookie.exp ? 1 : 0; + } + } + + smc_set_option(tp, opts, &remaining); + + if (sk_is_mptcp(sk)) { + unsigned int size; + + if (mptcp_syn_options(sk, skb, &size, &opts->mptcp)) { + opts->options |= OPTION_MPTCP; + remaining -= size; + } + } + + bpf_skops_hdr_opt_len(sk, skb, NULL, NULL, 0, opts, &remaining); + + return MAX_TCP_OPTION_SPACE - remaining; +} + +/* Set up TCP options for SYN-ACKs. */ +static unsigned int tcp_synack_options(const struct sock *sk, + struct request_sock *req, + unsigned int mss, struct sk_buff *skb, + struct tcp_out_options *opts, + const struct tcp_md5sig_key *md5, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + struct inet_request_sock *ireq = inet_rsk(req); + unsigned int remaining = MAX_TCP_OPTION_SPACE; + +#ifdef CONFIG_TCP_MD5SIG + if (md5) { + opts->options |= OPTION_MD5; + remaining -= TCPOLEN_MD5SIG_ALIGNED; + + /* We can't fit any SACK blocks in a packet with MD5 + TS + * options. 
There was discussion about disabling SACK + * rather than TS in order to fit in better with old, + * buggy kernels, but that was deemed to be unnecessary. + */ + if (synack_type != TCP_SYNACK_COOKIE) + ireq->tstamp_ok &= !ireq->sack_ok; + } +#endif + + /* We always send an MSS option. */ + opts->mss = mss; + remaining -= TCPOLEN_MSS_ALIGNED; + + if (likely(ireq->wscale_ok)) { + opts->ws = ireq->rcv_wscale; + opts->options |= OPTION_WSCALE; + remaining -= TCPOLEN_WSCALE_ALIGNED; + } + if (likely(ireq->tstamp_ok)) { + opts->options |= OPTION_TS; + opts->tsval = tcp_skb_timestamp(skb) + tcp_rsk(req)->ts_off; + opts->tsecr = READ_ONCE(req->ts_recent); + remaining -= TCPOLEN_TSTAMP_ALIGNED; + } + if (likely(ireq->sack_ok)) { + opts->options |= OPTION_SACK_ADVERTISE; + if (unlikely(!ireq->tstamp_ok)) + remaining -= TCPOLEN_SACKPERM_ALIGNED; + } + if (foc != NULL && foc->len >= 0) { + u32 need = foc->len; + + need += foc->exp ? TCPOLEN_EXP_FASTOPEN_BASE : + TCPOLEN_FASTOPEN_BASE; + need = (need + 3) & ~3U; /* Align to 32 bits */ + if (remaining >= need) { + opts->options |= OPTION_FAST_OPEN_COOKIE; + opts->fastopen_cookie = foc; + remaining -= need; + } + } + + mptcp_set_option_cond(req, opts, &remaining); + + smc_set_option_cond(tcp_sk(sk), ireq, opts, &remaining); + + bpf_skops_hdr_opt_len((struct sock *)sk, skb, req, syn_skb, + synack_type, opts, &remaining); + + return MAX_TCP_OPTION_SPACE - remaining; +} + +/* Compute TCP options for ESTABLISHED sockets. This is not the + * final wire format yet. + */ +static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb, + struct tcp_out_options *opts, + struct tcp_md5sig_key **md5) +{ + struct tcp_sock *tp = tcp_sk(sk); + unsigned int size = 0; + unsigned int eff_sacks; + + opts->options = 0; + + *md5 = NULL; +#ifdef CONFIG_TCP_MD5SIG + if (static_branch_unlikely(&tcp_md5_needed) && + rcu_access_pointer(tp->md5sig_info)) { + *md5 = tp->af_specific->md5_lookup(sk, sk); + if (*md5) { + opts->options |= OPTION_MD5; + size += TCPOLEN_MD5SIG_ALIGNED; + } + } +#endif + + if (likely(tp->rx_opt.tstamp_ok)) { + opts->options |= OPTION_TS; + opts->tsval = skb ? tcp_skb_timestamp(skb) + tp->tsoffset : 0; + opts->tsecr = tp->rx_opt.ts_recent; + size += TCPOLEN_TSTAMP_ALIGNED; + } + + /* MPTCP options have precedence over SACK for the limited TCP + * option space because a MPTCP connection would be forced to + * fall back to regular TCP if a required multipath option is + * missing. SACK still gets a chance to use whatever space is + * left. 
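+ * Rough arithmetic (sizes assumed for illustration): of the 40-byte
+ * option space, timestamps take 12, leaving 28. An MPTCP DSS option
+ * of, say, 20 bytes leaves 8, which is below the 4 + 8 needed for even
+ * one SACK block; without MPTCP the same packet could have carried
+ * (28 - 4) / 8 = 3 blocks.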
+ */ + if (sk_is_mptcp(sk)) { + unsigned int remaining = MAX_TCP_OPTION_SPACE - size; + unsigned int opt_size = 0; + + if (mptcp_established_options(sk, skb, &opt_size, remaining, + &opts->mptcp)) { + opts->options |= OPTION_MPTCP; + size += opt_size; + } + } + + eff_sacks = tp->rx_opt.num_sacks + tp->rx_opt.dsack; + if (unlikely(eff_sacks)) { + const unsigned int remaining = MAX_TCP_OPTION_SPACE - size; + if (unlikely(remaining < TCPOLEN_SACK_BASE_ALIGNED + + TCPOLEN_SACK_PERBLOCK)) + return size; + + opts->num_sack_blocks = + min_t(unsigned int, eff_sacks, + (remaining - TCPOLEN_SACK_BASE_ALIGNED) / + TCPOLEN_SACK_PERBLOCK); + + size += TCPOLEN_SACK_BASE_ALIGNED + + opts->num_sack_blocks * TCPOLEN_SACK_PERBLOCK; + } + + if (unlikely(BPF_SOCK_OPS_TEST_FLAG(tp, + BPF_SOCK_OPS_WRITE_HDR_OPT_CB_FLAG))) { + unsigned int remaining = MAX_TCP_OPTION_SPACE - size; + + bpf_skops_hdr_opt_len(sk, skb, NULL, NULL, 0, opts, &remaining); + + size = MAX_TCP_OPTION_SPACE - remaining; + } + + return size; +} + + +/* TCP SMALL QUEUES (TSQ) + * + * TSQ goal is to keep small amount of skbs per tcp flow in tx queues (qdisc+dev) + * to reduce RTT and bufferbloat. + * We do this using a special skb destructor (tcp_wfree). + * + * Its important tcp_wfree() can be replaced by sock_wfree() in the event skb + * needs to be reallocated in a driver. + * The invariant being skb->truesize subtracted from sk->sk_wmem_alloc + * + * Since transmit from skb destructor is forbidden, we use a tasklet + * to process all sockets that eventually need to send more skbs. + * We use one tasklet per cpu, with its own queue of sockets. + */ +struct tsq_tasklet { + struct tasklet_struct tasklet; + struct list_head head; /* queue of tcp sockets */ +}; +static DEFINE_PER_CPU(struct tsq_tasklet, tsq_tasklet); + +static void tcp_tsq_write(struct sock *sk) +{ + if ((1 << sk->sk_state) & + (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_CLOSING | + TCPF_CLOSE_WAIT | TCPF_LAST_ACK)) { + struct tcp_sock *tp = tcp_sk(sk); + + if (tp->lost_out > tp->retrans_out && + tcp_snd_cwnd(tp) > tcp_packets_in_flight(tp)) { + tcp_mstamp_refresh(tp); + tcp_xmit_retransmit_queue(sk); + } + + tcp_write_xmit(sk, tcp_current_mss(sk), tp->nonagle, + 0, GFP_ATOMIC); + } +} + +static void tcp_tsq_handler(struct sock *sk) +{ + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) + tcp_tsq_write(sk); + else if (!test_and_set_bit(TCP_TSQ_DEFERRED, &sk->sk_tsq_flags)) + sock_hold(sk); + bh_unlock_sock(sk); +} +/* + * One tasklet per cpu tries to send more skbs. 
+ * We run in tasklet context but need to disable irqs when + * transferring tsq->head because tcp_wfree() might + * interrupt us (non NAPI drivers) + */ +static void tcp_tasklet_func(struct tasklet_struct *t) +{ + struct tsq_tasklet *tsq = from_tasklet(tsq, t, tasklet); + LIST_HEAD(list); + unsigned long flags; + struct list_head *q, *n; + struct tcp_sock *tp; + struct sock *sk; + + local_irq_save(flags); + list_splice_init(&tsq->head, &list); + local_irq_restore(flags); + + list_for_each_safe(q, n, &list) { + tp = list_entry(q, struct tcp_sock, tsq_node); + list_del(&tp->tsq_node); + + sk = (struct sock *)tp; + smp_mb__before_atomic(); + clear_bit(TSQ_QUEUED, &sk->sk_tsq_flags); + + tcp_tsq_handler(sk); + sk_free(sk); + } +} + +#define TCP_DEFERRED_ALL (TCPF_TSQ_DEFERRED | \ + TCPF_WRITE_TIMER_DEFERRED | \ + TCPF_DELACK_TIMER_DEFERRED | \ + TCPF_MTU_REDUCED_DEFERRED) +/** + * tcp_release_cb - tcp release_sock() callback + * @sk: socket + * + * called from release_sock() to perform protocol dependent + * actions before socket release. + */ +void tcp_release_cb(struct sock *sk) +{ + unsigned long flags, nflags; + + /* perform an atomic operation only if at least one flag is set */ + do { + flags = sk->sk_tsq_flags; + if (!(flags & TCP_DEFERRED_ALL)) + return; + nflags = flags & ~TCP_DEFERRED_ALL; + } while (cmpxchg(&sk->sk_tsq_flags, flags, nflags) != flags); + + if (flags & TCPF_TSQ_DEFERRED) { + tcp_tsq_write(sk); + __sock_put(sk); + } + /* Here begins the tricky part : + * We are called from release_sock() with : + * 1) BH disabled + * 2) sk_lock.slock spinlock held + * 3) socket owned by us (sk->sk_lock.owned == 1) + * + * But following code is meant to be called from BH handlers, + * so we should keep BH disabled, but early release socket ownership + */ + sock_release_ownership(sk); + + if (flags & TCPF_WRITE_TIMER_DEFERRED) { + tcp_write_timer_handler(sk); + __sock_put(sk); + } + if (flags & TCPF_DELACK_TIMER_DEFERRED) { + tcp_delack_timer_handler(sk); + __sock_put(sk); + } + if (flags & TCPF_MTU_REDUCED_DEFERRED) { + inet_csk(sk)->icsk_af_ops->mtu_reduced(sk); + __sock_put(sk); + } +} +EXPORT_SYMBOL(tcp_release_cb); + +void __init tcp_tasklet_init(void) +{ + int i; + + for_each_possible_cpu(i) { + struct tsq_tasklet *tsq = &per_cpu(tsq_tasklet, i); + + INIT_LIST_HEAD(&tsq->head); + tasklet_setup(&tsq->tasklet, tcp_tasklet_func); + } +} + +/* + * Write buffer destructor automatically called from kfree_skb. + * We can't xmit new skbs from this context, as we might already + * hold qdisc lock. + */ +void tcp_wfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + struct tcp_sock *tp = tcp_sk(sk); + unsigned long flags, nval, oval; + + /* Keep one reference on sk_wmem_alloc. + * Will be released by sk_free() from here or tcp_tasklet_func() + */ + WARN_ON(refcount_sub_and_test(skb->truesize - 1, &sk->sk_wmem_alloc)); + + /* If this softirq is serviced by ksoftirqd, we are likely under stress. + * Wait until our queues (qdisc + devices) are drained. 
+ * This gives : + * - less callbacks to tcp_write_xmit(), reducing stress (batches) + * - chance for incoming ACK (processed by another cpu maybe) + * to migrate this flow (skb->ooo_okay will be eventually set) + */ + if (refcount_read(&sk->sk_wmem_alloc) >= SKB_TRUESIZE(1) && this_cpu_ksoftirqd() == current) + goto out; + + for (oval = READ_ONCE(sk->sk_tsq_flags);; oval = nval) { + struct tsq_tasklet *tsq; + bool empty; + + if (!(oval & TSQF_THROTTLED) || (oval & TSQF_QUEUED)) + goto out; + + nval = (oval & ~TSQF_THROTTLED) | TSQF_QUEUED; + nval = cmpxchg(&sk->sk_tsq_flags, oval, nval); + if (nval != oval) + continue; + + /* queue this socket to tasklet queue */ + local_irq_save(flags); + tsq = this_cpu_ptr(&tsq_tasklet); + empty = list_empty(&tsq->head); + list_add(&tp->tsq_node, &tsq->head); + if (empty) + tasklet_schedule(&tsq->tasklet); + local_irq_restore(flags); + return; + } +out: + sk_free(sk); +} + +/* Note: Called under soft irq. + * We can call TCP stack right away, unless socket is owned by user. + */ +enum hrtimer_restart tcp_pace_kick(struct hrtimer *timer) +{ + struct tcp_sock *tp = container_of(timer, struct tcp_sock, pacing_timer); + struct sock *sk = (struct sock *)tp; + + tcp_tsq_handler(sk); + sock_put(sk); + + return HRTIMER_NORESTART; +} + +static void tcp_update_skb_after_send(struct sock *sk, struct sk_buff *skb, + u64 prior_wstamp) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (sk->sk_pacing_status != SK_PACING_NONE) { + unsigned long rate = sk->sk_pacing_rate; + + /* Original sch_fq does not pace first 10 MSS + * Note that tp->data_segs_out overflows after 2^32 packets, + * this is a minor annoyance. + */ + if (rate != ~0UL && rate && tp->data_segs_out >= 10) { + u64 len_ns = div64_ul((u64)skb->len * NSEC_PER_SEC, rate); + u64 credit = tp->tcp_wstamp_ns - prior_wstamp; + + /* take into account OS jitter */ + len_ns -= min_t(u64, len_ns / 2, credit); + tp->tcp_wstamp_ns += len_ns; + } + } + list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); +} + +INDIRECT_CALLABLE_DECLARE(int ip_queue_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl)); +INDIRECT_CALLABLE_DECLARE(int inet6_csk_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl)); +INDIRECT_CALLABLE_DECLARE(void tcp_v4_send_check(struct sock *sk, struct sk_buff *skb)); + +/* This routine actually transmits TCP packets queued in by + * tcp_do_sendmsg(). This is used by both the initial + * transmission and possible later retransmissions. + * All SKB's seen here are completely headerless. It is our + * job to build the TCP header, and pass the packet down to + * IP so it can do the same plus pass the packet off to the + * device. + * + * We are working here with either a clone of the original + * SKB, or a fresh unique copy made by the retransmit engine. 
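+ * When clone_it is set, the code below normally uses skb_clone(); only
+ * if the skb is already cloned does it fall back to pskb_copy(), which
+ * duplicates the header area so a retransmission never scribbles on
+ * header bytes shared with another clone.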
+ */ +static int __tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, + int clone_it, gfp_t gfp_mask, u32 rcv_nxt) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct inet_sock *inet; + struct tcp_sock *tp; + struct tcp_skb_cb *tcb; + struct tcp_out_options opts; + unsigned int tcp_options_size, tcp_header_size; + struct sk_buff *oskb = NULL; + struct tcp_md5sig_key *md5; + struct tcphdr *th; + u64 prior_wstamp; + int err; + + BUG_ON(!skb || !tcp_skb_pcount(skb)); + tp = tcp_sk(sk); + prior_wstamp = tp->tcp_wstamp_ns; + tp->tcp_wstamp_ns = max(tp->tcp_wstamp_ns, tp->tcp_clock_cache); + skb_set_delivery_time(skb, tp->tcp_wstamp_ns, true); + if (clone_it) { + oskb = skb; + + tcp_skb_tsorted_save(oskb) { + if (unlikely(skb_cloned(oskb))) + skb = pskb_copy(oskb, gfp_mask); + else + skb = skb_clone(oskb, gfp_mask); + } tcp_skb_tsorted_restore(oskb); + + if (unlikely(!skb)) + return -ENOBUFS; + /* retransmit skbs might have a non zero value in skb->dev + * because skb->dev is aliased with skb->rbnode.rb_left + */ + skb->dev = NULL; + } + + inet = inet_sk(sk); + tcb = TCP_SKB_CB(skb); + memset(&opts, 0, sizeof(opts)); + + if (unlikely(tcb->tcp_flags & TCPHDR_SYN)) { + tcp_options_size = tcp_syn_options(sk, skb, &opts, &md5); + } else { + tcp_options_size = tcp_established_options(sk, skb, &opts, + &md5); + /* Force a PSH flag on all (GSO) packets to expedite GRO flush + * at receiver : This slightly improve GRO performance. + * Note that we do not force the PSH flag for non GSO packets, + * because they might be sent under high congestion events, + * and in this case it is better to delay the delivery of 1-MSS + * packets and thus the corresponding ACK packet that would + * release the following packet. + */ + if (tcp_skb_pcount(skb) > 1) + tcb->tcp_flags |= TCPHDR_PSH; + } + tcp_header_size = tcp_options_size + sizeof(struct tcphdr); + + /* if no packet is in qdisc/device queue, then allow XPS to select + * another queue. We can be called from tcp_tsq_handler() + * which holds one reference to sk. + * + * TODO: Ideally, in-flight pure ACK packets should not matter here. + * One way to get this would be to set skb->truesize = 2 on them. + */ + skb->ooo_okay = sk_wmem_alloc_get(sk) < SKB_TRUESIZE(1); + + /* If we had to use memory reserve to allocate this skb, + * this might cause drops if packet is looped back : + * Other socket might not have SOCK_MEMALLOC. + * Packets not looped back do not care about pfmemalloc. + */ + skb->pfmemalloc = 0; + + skb_push(skb, tcp_header_size); + skb_reset_transport_header(skb); + + skb_orphan(skb); + skb->sk = sk; + skb->destructor = skb_is_tcp_pure_ack(skb) ? __sock_wfree : tcp_wfree; + refcount_add(skb->truesize, &sk->sk_wmem_alloc); + + skb_set_dst_pending_confirm(skb, READ_ONCE(sk->sk_dst_pending_confirm)); + + /* Build TCP header and checksum it. 
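+	 * (Illustrative sizing: a 20-byte base header plus 12 bytes of
+	 * timestamp option gives tcp_header_size = 32, encoded as a data
+	 * offset of 32 >> 2 = 8 in the combined doff/flags word written
+	 * below.)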
*/ + th = (struct tcphdr *)skb->data; + th->source = inet->inet_sport; + th->dest = inet->inet_dport; + th->seq = htonl(tcb->seq); + th->ack_seq = htonl(rcv_nxt); + *(((__be16 *)th) + 6) = htons(((tcp_header_size >> 2) << 12) | + tcb->tcp_flags); + + th->check = 0; + th->urg_ptr = 0; + + /* The urg_mode check is necessary during a below snd_una win probe */ + if (unlikely(tcp_urg_mode(tp) && before(tcb->seq, tp->snd_up))) { + if (before(tp->snd_up, tcb->seq + 0x10000)) { + th->urg_ptr = htons(tp->snd_up - tcb->seq); + th->urg = 1; + } else if (after(tcb->seq + 0xFFFF, tp->snd_nxt)) { + th->urg_ptr = htons(0xFFFF); + th->urg = 1; + } + } + + skb_shinfo(skb)->gso_type = sk->sk_gso_type; + if (likely(!(tcb->tcp_flags & TCPHDR_SYN))) { + th->window = htons(tcp_select_window(sk)); + tcp_ecn_send(sk, skb, th, tcp_header_size); + } else { + /* RFC1323: The window in SYN & SYN/ACK segments + * is never scaled. + */ + th->window = htons(min(tp->rcv_wnd, 65535U)); + } + + tcp_options_write(th, tp, &opts); + +#ifdef CONFIG_TCP_MD5SIG + /* Calculate the MD5 hash, as we have all we need now */ + if (md5) { + sk_gso_disable(sk); + tp->af_specific->calc_md5_hash(opts.hash_location, + md5, sk, skb); + } +#endif + + /* BPF prog is the last one writing header option */ + bpf_skops_write_hdr_opt(sk, skb, NULL, NULL, 0, &opts); + + INDIRECT_CALL_INET(icsk->icsk_af_ops->send_check, + tcp_v6_send_check, tcp_v4_send_check, + sk, skb); + + if (likely(tcb->tcp_flags & TCPHDR_ACK)) + tcp_event_ack_sent(sk, rcv_nxt); + + if (skb->len != tcp_header_size) { + tcp_event_data_sent(tp, sk); + tp->data_segs_out += tcp_skb_pcount(skb); + tp->bytes_sent += skb->len - tcp_header_size; + } + + if (after(tcb->end_seq, tp->snd_nxt) || tcb->seq == tcb->end_seq) + TCP_ADD_STATS(sock_net(sk), TCP_MIB_OUTSEGS, + tcp_skb_pcount(skb)); + + tp->segs_out += tcp_skb_pcount(skb); + skb_set_hash_from_sk(skb, sk); + /* OK, its time to fill skb_shinfo(skb)->gso_{segs|size} */ + skb_shinfo(skb)->gso_segs = tcp_skb_pcount(skb); + skb_shinfo(skb)->gso_size = tcp_skb_mss(skb); + + /* Leave earliest departure time in skb->tstamp (skb->skb_mstamp_ns) */ + + /* Cleanup our debris for IP stacks */ + memset(skb->cb, 0, max(sizeof(struct inet_skb_parm), + sizeof(struct inet6_skb_parm))); + + tcp_add_tx_delay(skb, tp); + + err = INDIRECT_CALL_INET(icsk->icsk_af_ops->queue_xmit, + inet6_csk_xmit, ip_queue_xmit, + sk, skb, &inet->cork.fl); + + if (unlikely(err > 0)) { + tcp_enter_cwr(sk); + err = net_xmit_eval(err); + } + if (!err && oskb) { + tcp_update_skb_after_send(sk, oskb, prior_wstamp); + tcp_rate_skb_sent(sk, oskb); + } + return err; +} + +static int tcp_transmit_skb(struct sock *sk, struct sk_buff *skb, int clone_it, + gfp_t gfp_mask) +{ + return __tcp_transmit_skb(sk, skb, clone_it, gfp_mask, + tcp_sk(sk)->rcv_nxt); +} + +/* This routine just queues the buffer for sending. + * + * NOTE: probe0 timer is not checked, do not forget tcp_push_pending_frames, + * otherwise socket can stall. + */ +static void tcp_queue_skb(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* Advance write_seq and place onto the write_queue. */ + WRITE_ONCE(tp->write_seq, TCP_SKB_CB(skb)->end_seq); + __skb_header_release(skb); + tcp_add_write_queue_tail(sk, skb); + sk_wmem_queued_add(sk, skb->truesize); + sk_mem_charge(sk, skb->truesize); +} + +/* Initialize TSO segments for a packet. 
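+ * For illustration: a 4344-byte skb with mss_now = 1448 gets
+ * pcount = DIV_ROUND_UP(4344, 1448) = 3 and tcp_gso_size = 1448,
+ * while anything <= mss_now stays a single non-TSO segment.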
*/ +static void tcp_set_skb_tso_segs(struct sk_buff *skb, unsigned int mss_now) +{ + if (skb->len <= mss_now) { + /* Avoid the costly divide in the normal + * non-TSO case. + */ + tcp_skb_pcount_set(skb, 1); + TCP_SKB_CB(skb)->tcp_gso_size = 0; + } else { + tcp_skb_pcount_set(skb, DIV_ROUND_UP(skb->len, mss_now)); + TCP_SKB_CB(skb)->tcp_gso_size = mss_now; + } +} + +/* Pcount in the middle of the write queue got changed, we need to do various + * tweaks to fix counters + */ +static void tcp_adjust_pcount(struct sock *sk, const struct sk_buff *skb, int decr) +{ + struct tcp_sock *tp = tcp_sk(sk); + + tp->packets_out -= decr; + + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) + tp->sacked_out -= decr; + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_RETRANS) + tp->retrans_out -= decr; + if (TCP_SKB_CB(skb)->sacked & TCPCB_LOST) + tp->lost_out -= decr; + + /* Reno case is special. Sigh... */ + if (tcp_is_reno(tp) && decr > 0) + tp->sacked_out -= min_t(u32, tp->sacked_out, decr); + + if (tp->lost_skb_hint && + before(TCP_SKB_CB(skb)->seq, TCP_SKB_CB(tp->lost_skb_hint)->seq) && + (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED)) + tp->lost_cnt_hint -= decr; + + tcp_verify_left_out(tp); +} + +static bool tcp_has_tx_tstamp(const struct sk_buff *skb) +{ + return TCP_SKB_CB(skb)->txstamp_ack || + (skb_shinfo(skb)->tx_flags & SKBTX_ANY_TSTAMP); +} + +static void tcp_fragment_tstamp(struct sk_buff *skb, struct sk_buff *skb2) +{ + struct skb_shared_info *shinfo = skb_shinfo(skb); + + if (unlikely(tcp_has_tx_tstamp(skb)) && + !before(shinfo->tskey, TCP_SKB_CB(skb2)->seq)) { + struct skb_shared_info *shinfo2 = skb_shinfo(skb2); + u8 tsflags = shinfo->tx_flags & SKBTX_ANY_TSTAMP; + + shinfo->tx_flags &= ~tsflags; + shinfo2->tx_flags |= tsflags; + swap(shinfo->tskey, shinfo2->tskey); + TCP_SKB_CB(skb2)->txstamp_ack = TCP_SKB_CB(skb)->txstamp_ack; + TCP_SKB_CB(skb)->txstamp_ack = 0; + } +} + +static void tcp_skb_fragment_eor(struct sk_buff *skb, struct sk_buff *skb2) +{ + TCP_SKB_CB(skb2)->eor = TCP_SKB_CB(skb)->eor; + TCP_SKB_CB(skb)->eor = 0; +} + +/* Insert buff after skb on the write or rtx queue of sk. */ +static void tcp_insert_write_queue_after(struct sk_buff *skb, + struct sk_buff *buff, + struct sock *sk, + enum tcp_queue tcp_queue) +{ + if (tcp_queue == TCP_FRAG_IN_WRITE_QUEUE) + __skb_queue_after(&sk->sk_write_queue, skb, buff); + else + tcp_rbtree_insert(&sk->tcp_rtx_queue, buff); +} + +/* Function to create two new TCP segments. Shrinks the given segment + * to the specified size and appends a new segment with the rest of the + * packet to the list. This won't be called frequently, I hope. + * Remember, these are still headerless SKBs at this point. + */ +int tcp_fragment(struct sock *sk, enum tcp_queue tcp_queue, + struct sk_buff *skb, u32 len, + unsigned int mss_now, gfp_t gfp) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *buff; + int nsize, old_factor; + long limit; + int nlen; + u8 flags; + + if (WARN_ON(len > skb->len)) + return -EINVAL; + + nsize = skb_headlen(skb) - len; + if (nsize < 0) + nsize = 0; + + /* tcp_sendmsg() can overshoot sk_wmem_queued by one full size skb. + * We need some allowance to not penalize applications setting small + * SO_SNDBUF values. + * Also allow first and last skb in retransmit queue to be split. 
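+ * (Since the check below compares sk_wmem_queued >> 1 against the limit,
+ * a split elsewhere in the rtx queue is refused only once queued memory
+ * exceeds roughly twice sk_sndbuf plus four maximal TSO skbs.)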
+ */ + limit = sk->sk_sndbuf + 2 * SKB_TRUESIZE(GSO_LEGACY_MAX_SIZE); + if (unlikely((sk->sk_wmem_queued >> 1) > limit && + tcp_queue != TCP_FRAG_IN_WRITE_QUEUE && + skb != tcp_rtx_queue_head(sk) && + skb != tcp_rtx_queue_tail(sk))) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPWQUEUETOOBIG); + return -ENOMEM; + } + + if (skb_unclone_keeptruesize(skb, gfp)) + return -ENOMEM; + + /* Get a new skb... force flag on. */ + buff = tcp_stream_alloc_skb(sk, nsize, gfp, true); + if (!buff) + return -ENOMEM; /* We'll just try again later. */ + skb_copy_decrypted(buff, skb); + mptcp_skb_ext_copy(buff, skb); + + sk_wmem_queued_add(sk, buff->truesize); + sk_mem_charge(sk, buff->truesize); + nlen = skb->len - len - nsize; + buff->truesize += nlen; + skb->truesize -= nlen; + + /* Correct the sequence numbers. */ + TCP_SKB_CB(buff)->seq = TCP_SKB_CB(skb)->seq + len; + TCP_SKB_CB(buff)->end_seq = TCP_SKB_CB(skb)->end_seq; + TCP_SKB_CB(skb)->end_seq = TCP_SKB_CB(buff)->seq; + + /* PSH and FIN should only be set in the second packet. */ + flags = TCP_SKB_CB(skb)->tcp_flags; + TCP_SKB_CB(skb)->tcp_flags = flags & ~(TCPHDR_FIN | TCPHDR_PSH); + TCP_SKB_CB(buff)->tcp_flags = flags; + TCP_SKB_CB(buff)->sacked = TCP_SKB_CB(skb)->sacked; + tcp_skb_fragment_eor(skb, buff); + + skb_split(skb, buff, len); + + skb_set_delivery_time(buff, skb->tstamp, true); + tcp_fragment_tstamp(skb, buff); + + old_factor = tcp_skb_pcount(skb); + + /* Fix up tso_factor for both original and new SKB. */ + tcp_set_skb_tso_segs(skb, mss_now); + tcp_set_skb_tso_segs(buff, mss_now); + + /* Update delivered info for the new segment */ + TCP_SKB_CB(buff)->tx = TCP_SKB_CB(skb)->tx; + + /* If this packet has been sent out already, we must + * adjust the various packet counters. + */ + if (!before(tp->snd_nxt, TCP_SKB_CB(buff)->end_seq)) { + int diff = old_factor - tcp_skb_pcount(skb) - + tcp_skb_pcount(buff); + + if (diff) + tcp_adjust_pcount(sk, skb, diff); + } + + /* Link BUFF into the send queue. */ + __skb_header_release(buff); + tcp_insert_write_queue_after(skb, buff, sk, tcp_queue); + if (tcp_queue == TCP_FRAG_IN_RTX_QUEUE) + list_add(&buff->tcp_tsorted_anchor, &skb->tcp_tsorted_anchor); + + return 0; +} + +/* This is similar to __pskb_pull_tail(). The difference is that pulled + * data is not copied, but immediately discarded. + */ +static int __pskb_trim_head(struct sk_buff *skb, int len) +{ + struct skb_shared_info *shinfo; + int i, k, eat; + + eat = min_t(int, len, skb_headlen(skb)); + if (eat) { + __skb_pull(skb, eat); + len -= eat; + if (!len) + return 0; + } + eat = len; + k = 0; + shinfo = skb_shinfo(skb); + for (i = 0; i < shinfo->nr_frags; i++) { + int size = skb_frag_size(&shinfo->frags[i]); + + if (size <= eat) { + skb_frag_unref(skb, i); + eat -= size; + } else { + shinfo->frags[k] = shinfo->frags[i]; + if (eat) { + skb_frag_off_add(&shinfo->frags[k], eat); + skb_frag_size_sub(&shinfo->frags[k], eat); + eat = 0; + } + k++; + } + } + shinfo->nr_frags = k; + + skb->data_len -= len; + skb->len = skb->data_len; + return len; +} + +/* Remove acked data from a packet in the transmit queue. 
*/ +int tcp_trim_head(struct sock *sk, struct sk_buff *skb, u32 len) +{ + u32 delta_truesize; + + if (skb_unclone_keeptruesize(skb, GFP_ATOMIC)) + return -ENOMEM; + + delta_truesize = __pskb_trim_head(skb, len); + + TCP_SKB_CB(skb)->seq += len; + + if (delta_truesize) { + skb->truesize -= delta_truesize; + sk_wmem_queued_add(sk, -delta_truesize); + if (!skb_zcopy_pure(skb)) + sk_mem_uncharge(sk, delta_truesize); + } + + /* Any change of skb->len requires recalculation of tso factor. */ + if (tcp_skb_pcount(skb) > 1) + tcp_set_skb_tso_segs(skb, tcp_skb_mss(skb)); + + return 0; +} + +/* Calculate MSS not accounting any TCP options. */ +static inline int __tcp_mtu_to_mss(struct sock *sk, int pmtu) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + int mss_now; + + /* Calculate base mss without TCP options: + It is MMS_S - sizeof(tcphdr) of rfc1122 + */ + mss_now = pmtu - icsk->icsk_af_ops->net_header_len - sizeof(struct tcphdr); + + /* IPv6 adds a frag_hdr in case RTAX_FEATURE_ALLFRAG is set */ + if (icsk->icsk_af_ops->net_frag_header_len) { + const struct dst_entry *dst = __sk_dst_get(sk); + + if (dst && dst_allfrag(dst)) + mss_now -= icsk->icsk_af_ops->net_frag_header_len; + } + + /* Clamp it (mss_clamp does not include tcp options) */ + if (mss_now > tp->rx_opt.mss_clamp) + mss_now = tp->rx_opt.mss_clamp; + + /* Now subtract optional transport overhead */ + mss_now -= icsk->icsk_ext_hdr_len; + + /* Then reserve room for full set of TCP options and 8 bytes of data */ + mss_now = max(mss_now, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_snd_mss)); + return mss_now; +} + +/* Calculate MSS. Not accounting for SACKs here. */ +int tcp_mtu_to_mss(struct sock *sk, int pmtu) +{ + /* Subtract TCP options size, not including SACKs */ + return __tcp_mtu_to_mss(sk, pmtu) - + (tcp_sk(sk)->tcp_header_len - sizeof(struct tcphdr)); +} +EXPORT_SYMBOL(tcp_mtu_to_mss); + +/* Inverse of above */ +int tcp_mss_to_mtu(struct sock *sk, int mss) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct inet_connection_sock *icsk = inet_csk(sk); + int mtu; + + mtu = mss + + tp->tcp_header_len + + icsk->icsk_ext_hdr_len + + icsk->icsk_af_ops->net_header_len; + + /* IPv6 adds a frag_hdr in case RTAX_FEATURE_ALLFRAG is set */ + if (icsk->icsk_af_ops->net_frag_header_len) { + const struct dst_entry *dst = __sk_dst_get(sk); + + if (dst && dst_allfrag(dst)) + mtu += icsk->icsk_af_ops->net_frag_header_len; + } + return mtu; +} +EXPORT_SYMBOL(tcp_mss_to_mtu); + +/* MTU probing init per socket */ +void tcp_mtup_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + struct net *net = sock_net(sk); + + icsk->icsk_mtup.enabled = READ_ONCE(net->ipv4.sysctl_tcp_mtu_probing) > 1; + icsk->icsk_mtup.search_high = tp->rx_opt.mss_clamp + sizeof(struct tcphdr) + + icsk->icsk_af_ops->net_header_len; + icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, READ_ONCE(net->ipv4.sysctl_tcp_base_mss)); + icsk->icsk_mtup.probe_size = 0; + if (icsk->icsk_mtup.enabled) + icsk->icsk_mtup.probe_timestamp = tcp_jiffies32; +} +EXPORT_SYMBOL(tcp_mtup_init); + +/* This function synchronize snd mss to current pmtu/exthdr set. + + tp->rx_opt.user_mss is mss set by user by TCP_MAXSEG. It does NOT counts + for TCP options, but includes only bare TCP header. + + tp->rx_opt.mss_clamp is mss negotiated at connection setup. + It is minimum of user_mss and mss received with SYN. + It also does not include TCP options. 
+ + inet_csk(sk)->icsk_pmtu_cookie is last pmtu, seen by this function. + + tp->mss_cache is current effective sending mss, including + all tcp options except for SACKs. It is evaluated, + taking into account current pmtu, but never exceeds + tp->rx_opt.mss_clamp. + + NOTE1. rfc1122 clearly states that advertised MSS + DOES NOT include either tcp or ip options. + + NOTE2. inet_csk(sk)->icsk_pmtu_cookie and tp->mss_cache + are READ ONLY outside this function. --ANK (980731) + */ +unsigned int tcp_sync_mss(struct sock *sk, u32 pmtu) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + int mss_now; + + if (icsk->icsk_mtup.search_high > pmtu) + icsk->icsk_mtup.search_high = pmtu; + + mss_now = tcp_mtu_to_mss(sk, pmtu); + mss_now = tcp_bound_to_half_wnd(tp, mss_now); + + /* And store cached results */ + icsk->icsk_pmtu_cookie = pmtu; + if (icsk->icsk_mtup.enabled) + mss_now = min(mss_now, tcp_mtu_to_mss(sk, icsk->icsk_mtup.search_low)); + tp->mss_cache = mss_now; + + return mss_now; +} +EXPORT_SYMBOL(tcp_sync_mss); + +/* Compute the current effective MSS, taking SACKs and IP options, + * and even PMTU discovery events into account. + */ +unsigned int tcp_current_mss(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct dst_entry *dst = __sk_dst_get(sk); + u32 mss_now; + unsigned int header_len; + struct tcp_out_options opts; + struct tcp_md5sig_key *md5; + + mss_now = tp->mss_cache; + + if (dst) { + u32 mtu = dst_mtu(dst); + if (mtu != inet_csk(sk)->icsk_pmtu_cookie) + mss_now = tcp_sync_mss(sk, mtu); + } + + header_len = tcp_established_options(sk, NULL, &opts, &md5) + + sizeof(struct tcphdr); + /* The mss_cache is sized based on tp->tcp_header_len, which assumes + * some common options. If this is an odd packet (because we have SACK + * blocks etc) then our calculated header_len will be different, and + * we have to adjust mss_now correspondingly */ + if (header_len != tp->tcp_header_len) { + int delta = (int) header_len - tp->tcp_header_len; + mss_now -= delta; + } + + return mss_now; +} + +/* RFC2861, slow part. Adjust cwnd, after it was not full during one rto. + * As additional protections, we do not touch cwnd in retransmission phases, + * and if application hit its sndbuf limit recently. + */ +static void tcp_cwnd_application_limited(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (inet_csk(sk)->icsk_ca_state == TCP_CA_Open && + sk->sk_socket && !test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) { + /* Limited by application or receiver window. */ + u32 init_win = tcp_init_cwnd(tp, __sk_dst_get(sk)); + u32 win_used = max(tp->snd_cwnd_used, init_win); + if (win_used < tcp_snd_cwnd(tp)) { + tp->snd_ssthresh = tcp_current_ssthresh(sk); + tcp_snd_cwnd_set(tp, (tcp_snd_cwnd(tp) + win_used) >> 1); + } + tp->snd_cwnd_used = 0; + } + tp->snd_cwnd_stamp = tcp_jiffies32; +} + +static void tcp_cwnd_validate(struct sock *sk, bool is_cwnd_limited) +{ + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + struct tcp_sock *tp = tcp_sk(sk); + + /* Track the strongest available signal of the degree to which the cwnd + * is fully utilized. If cwnd-limited then remember that fact for the + * current window. If not cwnd-limited then track the maximum number of + * outstanding packets in the current window. (If cwnd-limited then we + * chose to not update tp->max_packets_out to avoid an extra else + * clause with no functional impact.) 
+ */ + if (!before(tp->snd_una, tp->cwnd_usage_seq) || + is_cwnd_limited || + (!tp->is_cwnd_limited && + tp->packets_out > tp->max_packets_out)) { + tp->is_cwnd_limited = is_cwnd_limited; + tp->max_packets_out = tp->packets_out; + tp->cwnd_usage_seq = tp->snd_nxt; + } + + if (tcp_is_cwnd_limited(sk)) { + /* Network is feed fully. */ + tp->snd_cwnd_used = 0; + tp->snd_cwnd_stamp = tcp_jiffies32; + } else { + /* Network starves. */ + if (tp->packets_out > tp->snd_cwnd_used) + tp->snd_cwnd_used = tp->packets_out; + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_slow_start_after_idle) && + (s32)(tcp_jiffies32 - tp->snd_cwnd_stamp) >= inet_csk(sk)->icsk_rto && + !ca_ops->cong_control) + tcp_cwnd_application_limited(sk); + + /* The following conditions together indicate the starvation + * is caused by insufficient sender buffer: + * 1) just sent some data (see tcp_write_xmit) + * 2) not cwnd limited (this else condition) + * 3) no more data to send (tcp_write_queue_empty()) + * 4) application is hitting buffer limit (SOCK_NOSPACE) + */ + if (tcp_write_queue_empty(sk) && sk->sk_socket && + test_bit(SOCK_NOSPACE, &sk->sk_socket->flags) && + (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) + tcp_chrono_start(sk, TCP_CHRONO_SNDBUF_LIMITED); + } +} + +/* Minshall's variant of the Nagle send check. */ +static bool tcp_minshall_check(const struct tcp_sock *tp) +{ + return after(tp->snd_sml, tp->snd_una) && + !after(tp->snd_sml, tp->snd_nxt); +} + +/* Update snd_sml if this skb is under mss + * Note that a TSO packet might end with a sub-mss segment + * The test is really : + * if ((skb->len % mss) != 0) + * tp->snd_sml = TCP_SKB_CB(skb)->end_seq; + * But we can avoid doing the divide again given we already have + * skb_pcount = skb->len / mss_now + */ +static void tcp_minshall_update(struct tcp_sock *tp, unsigned int mss_now, + const struct sk_buff *skb) +{ + if (skb->len < tcp_skb_pcount(skb) * mss_now) + tp->snd_sml = TCP_SKB_CB(skb)->end_seq; +} + +/* Return false, if packet can be sent now without violation Nagle's rules: + * 1. It is full sized. (provided by caller in %partial bool) + * 2. Or it contains FIN. (already checked by caller) + * 3. Or TCP_CORK is not set, and TCP_NODELAY is set. + * 4. Or TCP_CORK is not set, and all sent packets are ACKed. + * With Minshall's modification: all sent small packets are ACKed. + */ +static bool tcp_nagle_check(bool partial, const struct tcp_sock *tp, + int nonagle) +{ + return partial && + ((nonagle & TCP_NAGLE_CORK) || + (!nonagle && tp->packets_out && tcp_minshall_check(tp))); +} + +/* Return how many segs we'd like on a TSO packet, + * depending on current pacing rate, and how close the peer is. + * + * Rationale is: + * - For close peers, we rather send bigger packets to reduce + * cpu costs, because occasional losses will be repaired fast. + * - For long distance/rtt flows, we would like to get ACK clocking + * with 1 ACK per ms. + * + * Use min_rtt to help adapt TSO burst size, with smaller min_rtt resulting + * in bigger TSO bursts. We we cut the RTT-based allowance in half + * for every 2^9 usec (aka 512 us) of RTT, so that the RTT-based allowance + * is below 1500 bytes after 6 * ~500 usec = 3ms. 
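+ *
+ * For illustration, assuming the default tcp_tso_rtt_log of 9 and a 64KB
+ * sk_gso_max_size: a 2 ms min_rtt gives r = 2000 >> 9 = 3 and adds
+ * 65536 >> 3 = 8192 bytes to the pacing-derived budget, while a ~3.1 ms
+ * min_rtt gives r = 6 and only 1024 extra bytes.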
+ */ +static u32 tcp_tso_autosize(const struct sock *sk, unsigned int mss_now, + int min_tso_segs) +{ + unsigned long bytes; + u32 r; + + bytes = sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift); + + r = tcp_min_rtt(tcp_sk(sk)) >> READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tso_rtt_log); + if (r < BITS_PER_TYPE(sk->sk_gso_max_size)) + bytes += sk->sk_gso_max_size >> r; + + bytes = min_t(unsigned long, bytes, sk->sk_gso_max_size); + + return max_t(u32, bytes / mss_now, min_tso_segs); +} + +/* Return the number of segments we want in the skb we are transmitting. + * See if congestion control module wants to decide; otherwise, autosize. + */ +static u32 tcp_tso_segs(struct sock *sk, unsigned int mss_now) +{ + const struct tcp_congestion_ops *ca_ops = inet_csk(sk)->icsk_ca_ops; + u32 min_tso, tso_segs; + + min_tso = ca_ops->min_tso_segs ? + ca_ops->min_tso_segs(sk) : + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_min_tso_segs); + + tso_segs = tcp_tso_autosize(sk, mss_now, min_tso); + return min_t(u32, tso_segs, sk->sk_gso_max_segs); +} + +/* Returns the portion of skb which can be sent right away */ +static unsigned int tcp_mss_split_point(const struct sock *sk, + const struct sk_buff *skb, + unsigned int mss_now, + unsigned int max_segs, + int nonagle) +{ + const struct tcp_sock *tp = tcp_sk(sk); + u32 partial, needed, window, max_len; + + window = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; + max_len = mss_now * max_segs; + + if (likely(max_len <= window && skb != tcp_write_queue_tail(sk))) + return max_len; + + needed = min(skb->len, window); + + if (max_len <= needed) + return max_len; + + partial = needed % mss_now; + /* If last segment is not a full MSS, check if Nagle rules allow us + * to include this last segment in this skb. + * Otherwise, we'll split the skb at last MSS boundary + */ + if (tcp_nagle_check(partial != 0, tp, nonagle)) + return needed - partial; + + return needed; +} + +/* Can at least one segment of SKB be sent right now, according to the + * congestion window rules? If so, return how many segments are allowed. + */ +static inline unsigned int tcp_cwnd_test(const struct tcp_sock *tp, + const struct sk_buff *skb) +{ + u32 in_flight, cwnd, halfcwnd; + + /* Don't be strict about the congestion window for the final FIN. */ + if ((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) && + tcp_skb_pcount(skb) == 1) + return 1; + + in_flight = tcp_packets_in_flight(tp); + cwnd = tcp_snd_cwnd(tp); + if (in_flight >= cwnd) + return 0; + + /* For better scheduling, ensure we have at least + * 2 GSO packets in flight. + */ + halfcwnd = max(cwnd >> 1, 1U); + return min(halfcwnd, cwnd - in_flight); +} + +/* Initialize TSO state of a skb. + * This must be invoked the first time we consider transmitting + * SKB onto the wire. + */ +static int tcp_init_tso_segs(struct sk_buff *skb, unsigned int mss_now) +{ + int tso_segs = tcp_skb_pcount(skb); + + if (!tso_segs || (tso_segs > 1 && tcp_skb_mss(skb) != mss_now)) { + tcp_set_skb_tso_segs(skb, mss_now); + tso_segs = tcp_skb_pcount(skb); + } + return tso_segs; +} + + +/* Return true if the Nagle test allows this packet to be + * sent now. + */ +static inline bool tcp_nagle_test(const struct tcp_sock *tp, const struct sk_buff *skb, + unsigned int cur_mss, int nonagle) +{ + /* Nagle rule does not apply to frames, which sit in the middle of the + * write_queue (they have no chances to get new data). + * + * This is implemented in the callers, where they modify the 'nonagle' + * argument based upon the location of SKB in the send queue. 
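+ * (tcp_write_xmit() passes TCP_NAGLE_PUSH for any skb that is not the
+ * tail of the write queue, so only the last, possibly partial, skb is
+ * ever held back by Nagle.)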
+ */ + if (nonagle & TCP_NAGLE_PUSH) + return true; + + /* Don't use the nagle rule for urgent data (or for the final FIN). */ + if (tcp_urg_mode(tp) || (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)) + return true; + + if (!tcp_nagle_check(skb->len < cur_mss, tp, nonagle)) + return true; + + return false; +} + +/* Does at least the first segment of SKB fit into the send window? */ +static bool tcp_snd_wnd_test(const struct tcp_sock *tp, + const struct sk_buff *skb, + unsigned int cur_mss) +{ + u32 end_seq = TCP_SKB_CB(skb)->end_seq; + + if (skb->len > cur_mss) + end_seq = TCP_SKB_CB(skb)->seq + cur_mss; + + return !after(end_seq, tcp_wnd_end(tp)); +} + +/* Trim TSO SKB to LEN bytes, put the remaining data into a new packet + * which is put after SKB on the list. It is very much like + * tcp_fragment() except that it may make several kinds of assumptions + * in order to speed up the splitting operation. In particular, we + * know that all the data is in scatter-gather pages, and that the + * packet has never been sent out before (and thus is not cloned). + */ +static int tso_fragment(struct sock *sk, struct sk_buff *skb, unsigned int len, + unsigned int mss_now, gfp_t gfp) +{ + int nlen = skb->len - len; + struct sk_buff *buff; + u8 flags; + + /* All of a TSO frame must be composed of paged data. */ + if (skb->len != skb->data_len) + return tcp_fragment(sk, TCP_FRAG_IN_WRITE_QUEUE, + skb, len, mss_now, gfp); + + buff = tcp_stream_alloc_skb(sk, 0, gfp, true); + if (unlikely(!buff)) + return -ENOMEM; + skb_copy_decrypted(buff, skb); + mptcp_skb_ext_copy(buff, skb); + + sk_wmem_queued_add(sk, buff->truesize); + sk_mem_charge(sk, buff->truesize); + buff->truesize += nlen; + skb->truesize -= nlen; + + /* Correct the sequence numbers. */ + TCP_SKB_CB(buff)->seq = TCP_SKB_CB(skb)->seq + len; + TCP_SKB_CB(buff)->end_seq = TCP_SKB_CB(skb)->end_seq; + TCP_SKB_CB(skb)->end_seq = TCP_SKB_CB(buff)->seq; + + /* PSH and FIN should only be set in the second packet. */ + flags = TCP_SKB_CB(skb)->tcp_flags; + TCP_SKB_CB(skb)->tcp_flags = flags & ~(TCPHDR_FIN | TCPHDR_PSH); + TCP_SKB_CB(buff)->tcp_flags = flags; + + tcp_skb_fragment_eor(skb, buff); + + skb_split(skb, buff, len); + tcp_fragment_tstamp(skb, buff); + + /* Fix up tso_factor for both original and new SKB. */ + tcp_set_skb_tso_segs(skb, mss_now); + tcp_set_skb_tso_segs(buff, mss_now); + + /* Link BUFF into the send queue. */ + __skb_header_release(buff); + tcp_insert_write_queue_after(skb, buff, sk, TCP_FRAG_IN_WRITE_QUEUE); + + return 0; +} + +/* Try to defer sending, if possible, in order to minimize the amount + * of TSO splitting we do. View it as a kind of TSO Nagle test. + * + * This algorithm is from John Heffner. + */ +static bool tcp_tso_should_defer(struct sock *sk, struct sk_buff *skb, + bool *is_cwnd_limited, + bool *is_rwnd_limited, + u32 max_segs) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + u32 send_win, cong_win, limit, in_flight; + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *head; + int win_divisor; + s64 delta; + + if (icsk->icsk_ca_state >= TCP_CA_Recovery) + goto send_now; + + /* Avoid bursty behavior by allowing defer + * only if the last write was recent (1 ms). + * Note that tp->tcp_wstamp_ns can be in the future if we have + * packets waiting in a qdisc or device for EDT delivery. 
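+ * (The check below therefore allows deferral only while tcp_clock_cache
+ * is less than one millisecond past tcp_wstamp_ns; otherwise we send
+ * right away.)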
+ */ + delta = tp->tcp_clock_cache - tp->tcp_wstamp_ns - NSEC_PER_MSEC; + if (delta > 0) + goto send_now; + + in_flight = tcp_packets_in_flight(tp); + + BUG_ON(tcp_skb_pcount(skb) <= 1); + BUG_ON(tcp_snd_cwnd(tp) <= in_flight); + + send_win = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; + + /* From in_flight test above, we know that cwnd > in_flight. */ + cong_win = (tcp_snd_cwnd(tp) - in_flight) * tp->mss_cache; + + limit = min(send_win, cong_win); + + /* If a full-sized TSO skb can be sent, do it. */ + if (limit >= max_segs * tp->mss_cache) + goto send_now; + + /* Middle in queue won't get any more data, full sendable already? */ + if ((skb != tcp_write_queue_tail(sk)) && (limit >= skb->len)) + goto send_now; + + win_divisor = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_tso_win_divisor); + if (win_divisor) { + u32 chunk = min(tp->snd_wnd, tcp_snd_cwnd(tp) * tp->mss_cache); + + /* If at least some fraction of a window is available, + * just use it. + */ + chunk /= win_divisor; + if (limit >= chunk) + goto send_now; + } else { + /* Different approach, try not to defer past a single + * ACK. Receiver should ACK every other full sized + * frame, so if we have space for more than 3 frames + * then send now. + */ + if (limit > tcp_max_tso_deferred_mss(tp) * tp->mss_cache) + goto send_now; + } + + /* TODO : use tsorted_sent_queue ? */ + head = tcp_rtx_queue_head(sk); + if (!head) + goto send_now; + delta = tp->tcp_clock_cache - head->tstamp; + /* If next ACK is likely to come too late (half srtt), do not defer */ + if ((s64)(delta - (u64)NSEC_PER_USEC * (tp->srtt_us >> 4)) < 0) + goto send_now; + + /* Ok, it looks like it is advisable to defer. + * Three cases are tracked : + * 1) We are cwnd-limited + * 2) We are rwnd-limited + * 3) We are application limited. + */ + if (cong_win < send_win) { + if (cong_win <= skb->len) { + *is_cwnd_limited = true; + return true; + } + } else { + if (send_win <= skb->len) { + *is_rwnd_limited = true; + return true; + } + } + + /* If this packet won't get more data, do not wait. */ + if ((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN) || + TCP_SKB_CB(skb)->eor) + goto send_now; + + return true; + +send_now: + return false; +} + +static inline void tcp_mtu_check_reprobe(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + u32 interval; + s32 delta; + + interval = READ_ONCE(net->ipv4.sysctl_tcp_probe_interval); + delta = tcp_jiffies32 - icsk->icsk_mtup.probe_timestamp; + if (unlikely(delta >= interval * HZ)) { + int mss = tcp_current_mss(sk); + + /* Update current search range */ + icsk->icsk_mtup.probe_size = 0; + icsk->icsk_mtup.search_high = tp->rx_opt.mss_clamp + + sizeof(struct tcphdr) + + icsk->icsk_af_ops->net_header_len; + icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, mss); + + /* Update probe time stamp */ + icsk->icsk_mtup.probe_timestamp = tcp_jiffies32; + } +} + +static bool tcp_can_coalesce_send_queue_head(struct sock *sk, int len) +{ + struct sk_buff *skb, *next; + + skb = tcp_send_head(sk); + tcp_for_write_queue_from_safe(skb, next, sk) { + if (len <= skb->len) + break; + + if (unlikely(TCP_SKB_CB(skb)->eor) || + tcp_has_tx_tstamp(skb) || + !skb_pure_zcopy_same(skb, next)) + return false; + + len -= skb->len; + } + + return true; +} + +/* Create a new MTU probe if we are ready. + * MTU probe is regularly attempting to increase the path MTU by + * deliberately sending larger packets. This discovers routing + * changes resulting in larger path MTUs. 
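+ * (The probe size is the MSS for the MTU midway between
+ * icsk_mtup.search_low and search_high; e.g. with hypothetical values
+ * search_low = 1024 and search_high = 1500, the probe targets an MTU
+ * of 1262.)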
+ * + * Returns 0 if we should wait to probe (no cwnd available), + * 1 if a probe was sent, + * -1 otherwise + */ +static int tcp_mtu_probe(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb, *nskb, *next; + struct net *net = sock_net(sk); + int probe_size; + int size_needed; + int copy, len; + int mss_now; + int interval; + + /* Not currently probing/verifying, + * not in recovery, + * have enough cwnd, and + * not SACKing (the variable headers throw things off) + */ + if (likely(!icsk->icsk_mtup.enabled || + icsk->icsk_mtup.probe_size || + inet_csk(sk)->icsk_ca_state != TCP_CA_Open || + tcp_snd_cwnd(tp) < 11 || + tp->rx_opt.num_sacks || tp->rx_opt.dsack)) + return -1; + + /* Use binary search for probe_size between tcp_mss_base, + * and current mss_clamp. if (search_high - search_low) + * smaller than a threshold, backoff from probing. + */ + mss_now = tcp_current_mss(sk); + probe_size = tcp_mtu_to_mss(sk, (icsk->icsk_mtup.search_high + + icsk->icsk_mtup.search_low) >> 1); + size_needed = probe_size + (tp->reordering + 1) * tp->mss_cache; + interval = icsk->icsk_mtup.search_high - icsk->icsk_mtup.search_low; + /* When misfortune happens, we are reprobing actively, + * and then reprobe timer has expired. We stick with current + * probing process by not resetting search range to its orignal. + */ + if (probe_size > tcp_mtu_to_mss(sk, icsk->icsk_mtup.search_high) || + interval < READ_ONCE(net->ipv4.sysctl_tcp_probe_threshold)) { + /* Check whether enough time has elaplased for + * another round of probing. + */ + tcp_mtu_check_reprobe(sk); + return -1; + } + + /* Have enough data in the send queue to probe? */ + if (tp->write_seq - tp->snd_nxt < size_needed) + return -1; + + if (tp->snd_wnd < size_needed) + return -1; + if (after(tp->snd_nxt + size_needed, tcp_wnd_end(tp))) + return 0; + + /* Do we need to wait to drain cwnd? With none in flight, don't stall */ + if (tcp_packets_in_flight(tp) + 2 > tcp_snd_cwnd(tp)) { + if (!tcp_packets_in_flight(tp)) + return -1; + else + return 0; + } + + if (!tcp_can_coalesce_send_queue_head(sk, probe_size)) + return -1; + + /* We're allowed to probe. Build it now. */ + nskb = tcp_stream_alloc_skb(sk, probe_size, GFP_ATOMIC, false); + if (!nskb) + return -1; + sk_wmem_queued_add(sk, nskb->truesize); + sk_mem_charge(sk, nskb->truesize); + + skb = tcp_send_head(sk); + skb_copy_decrypted(nskb, skb); + mptcp_skb_ext_copy(nskb, skb); + + TCP_SKB_CB(nskb)->seq = TCP_SKB_CB(skb)->seq; + TCP_SKB_CB(nskb)->end_seq = TCP_SKB_CB(skb)->seq + probe_size; + TCP_SKB_CB(nskb)->tcp_flags = TCPHDR_ACK; + + tcp_insert_write_queue_before(nskb, skb, sk); + tcp_highest_sack_replace(sk, skb, nskb); + + len = 0; + tcp_for_write_queue_from_safe(skb, next, sk) { + copy = min_t(int, skb->len, probe_size - len); + skb_copy_bits(skb, 0, skb_put(nskb, copy), copy); + + if (skb->len <= copy) { + /* We've eaten all the data from this skb. + * Throw it away. */ + TCP_SKB_CB(nskb)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags; + /* If this is the last SKB we copy and eor is set + * we need to propagate it to the new skb. 
+ */ + TCP_SKB_CB(nskb)->eor = TCP_SKB_CB(skb)->eor; + tcp_skb_collapse_tstamp(nskb, skb); + tcp_unlink_write_queue(skb, sk); + tcp_wmem_free_skb(sk, skb); + } else { + TCP_SKB_CB(nskb)->tcp_flags |= TCP_SKB_CB(skb)->tcp_flags & + ~(TCPHDR_FIN|TCPHDR_PSH); + if (!skb_shinfo(skb)->nr_frags) { + skb_pull(skb, copy); + } else { + __pskb_trim_head(skb, copy); + tcp_set_skb_tso_segs(skb, mss_now); + } + TCP_SKB_CB(skb)->seq += copy; + } + + len += copy; + + if (len >= probe_size) + break; + } + tcp_init_tso_segs(nskb, nskb->len); + + /* We're ready to send. If this fails, the probe will + * be resegmented into mss-sized pieces by tcp_write_xmit(). + */ + if (!tcp_transmit_skb(sk, nskb, 1, GFP_ATOMIC)) { + /* Decrement cwnd here because we are sending + * effectively two packets. */ + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - 1); + tcp_event_new_data_sent(sk, nskb); + + icsk->icsk_mtup.probe_size = tcp_mss_to_mtu(sk, nskb->len); + tp->mtu_probe.probe_seq_start = TCP_SKB_CB(nskb)->seq; + tp->mtu_probe.probe_seq_end = TCP_SKB_CB(nskb)->end_seq; + + return 1; + } + + return -1; +} + +static bool tcp_pacing_check(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tcp_needs_internal_pacing(sk)) + return false; + + if (tp->tcp_wstamp_ns <= tp->tcp_clock_cache) + return false; + + if (!hrtimer_is_queued(&tp->pacing_timer)) { + hrtimer_start(&tp->pacing_timer, + ns_to_ktime(tp->tcp_wstamp_ns), + HRTIMER_MODE_ABS_PINNED_SOFT); + sock_hold(sk); + } + return true; +} + +static bool tcp_rtx_queue_empty_or_single_skb(const struct sock *sk) +{ + const struct rb_node *node = sk->tcp_rtx_queue.rb_node; + + /* No skb in the rtx queue. */ + if (!node) + return true; + + /* Only one skb in rtx queue. */ + return !node->rb_left && !node->rb_right; +} + +/* TCP Small Queues : + * Control number of packets in qdisc/devices to two packets / or ~1 ms. + * (These limits are doubled for retransmits) + * This allows for : + * - better RTT estimation and ACK scheduling + * - faster recovery + * - high rates + * Alas, some drivers / subsystems require a fair amount + * of queued bytes to ensure line rate. + * One example is wifi aggregation (802.11 AMPDU) + */ +static bool tcp_small_queue_check(struct sock *sk, const struct sk_buff *skb, + unsigned int factor) +{ + unsigned long limit; + + limit = max_t(unsigned long, + 2 * skb->truesize, + sk->sk_pacing_rate >> READ_ONCE(sk->sk_pacing_shift)); + if (sk->sk_pacing_status == SK_PACING_NONE) + limit = min_t(unsigned long, limit, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_limit_output_bytes)); + limit <<= factor; + + if (static_branch_unlikely(&tcp_tx_delay_enabled) && + tcp_sk(sk)->tcp_tx_delay) { + u64 extra_bytes = (u64)sk->sk_pacing_rate * tcp_sk(sk)->tcp_tx_delay; + + /* TSQ is based on skb truesize sum (sk_wmem_alloc), so we + * approximate our needs assuming an ~100% skb->truesize overhead. + * USEC_PER_SEC is approximated by 2^20. + * do_div(extra_bytes, USEC_PER_SEC/2) is replaced by a right shift. + */ + extra_bytes >>= (20 - 1); + limit += extra_bytes; + } + if (refcount_read(&sk->sk_wmem_alloc) > limit) { + /* Always send skb if rtx queue is empty or has one skb. + * No need to wait for TX completion to call us back, + * after softirq/tasklet schedule. + * This helps when TX completions are delayed too much. + */ + if (tcp_rtx_queue_empty_or_single_skb(sk)) + return false; + + set_bit(TSQ_THROTTLED, &sk->sk_tsq_flags); + /* It is possible TX completion already happened + * before we set TSQ_THROTTLED, so we must + * test again the condition. 
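+ * (Presumably the smp_mb__after_atomic() below pairs with ordering on
+ * the TX completion side: either this re-read of sk_wmem_alloc observes
+ * the freed memory, or the completion path observes TSQ_THROTTLED and
+ * reschedules us.)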
+ */ + smp_mb__after_atomic(); + if (refcount_read(&sk->sk_wmem_alloc) > limit) + return true; + } + return false; +} + +static void tcp_chrono_set(struct tcp_sock *tp, const enum tcp_chrono new) +{ + const u32 now = tcp_jiffies32; + enum tcp_chrono old = tp->chrono_type; + + if (old > TCP_CHRONO_UNSPEC) + tp->chrono_stat[old - 1] += now - tp->chrono_start; + tp->chrono_start = now; + tp->chrono_type = new; +} + +void tcp_chrono_start(struct sock *sk, const enum tcp_chrono type) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* If there are multiple conditions worthy of tracking in a + * chronograph then the highest priority enum takes precedence + * over the other conditions. So that if something "more interesting" + * starts happening, stop the previous chrono and start a new one. + */ + if (type > tp->chrono_type) + tcp_chrono_set(tp, type); +} + +void tcp_chrono_stop(struct sock *sk, const enum tcp_chrono type) +{ + struct tcp_sock *tp = tcp_sk(sk); + + + /* There are multiple conditions worthy of tracking in a + * chronograph, so that the highest priority enum takes + * precedence over the other conditions (see tcp_chrono_start). + * If a condition stops, we only stop chrono tracking if + * it's the "most interesting" or current chrono we are + * tracking and starts busy chrono if we have pending data. + */ + if (tcp_rtx_and_write_queues_empty(sk)) + tcp_chrono_set(tp, TCP_CHRONO_UNSPEC); + else if (type == tp->chrono_type) + tcp_chrono_set(tp, TCP_CHRONO_BUSY); +} + +/* This routine writes packets to the network. It advances the + * send_head. This happens as incoming acks open up the remote + * window for us. + * + * LARGESEND note: !tcp_urg_mode is overkill, only frames between + * snd_up-64k-mss .. snd_up cannot be large. However, taking into + * account rare use of URG, this is not a big flaw. + * + * Send at most one packet when push_one > 0. Temporarily ignore + * cwnd limit to force at most one packet out when push_one == 2. + + * Returns true, if no segments are in flight and we have queued segments, + * but cannot send anything now because of SWS or another problem. + */ +static bool tcp_write_xmit(struct sock *sk, unsigned int mss_now, int nonagle, + int push_one, gfp_t gfp) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + unsigned int tso_segs, sent_pkts; + int cwnd_quota; + int result; + bool is_cwnd_limited = false, is_rwnd_limited = false; + u32 max_segs; + + sent_pkts = 0; + + tcp_mstamp_refresh(tp); + if (!push_one) { + /* Do MTU probing. */ + result = tcp_mtu_probe(sk); + if (!result) { + return false; + } else if (result > 0) { + sent_pkts = 1; + } + } + + max_segs = tcp_tso_segs(sk, mss_now); + while ((skb = tcp_send_head(sk))) { + unsigned int limit; + + if (unlikely(tp->repair) && tp->repair_queue == TCP_SEND_QUEUE) { + /* "skb_mstamp_ns" is used as a start point for the retransmit timer */ + tp->tcp_wstamp_ns = tp->tcp_clock_cache; + skb_set_delivery_time(skb, tp->tcp_wstamp_ns, true); + list_move_tail(&skb->tcp_tsorted_anchor, &tp->tsorted_sent_queue); + tcp_init_tso_segs(skb, mss_now); + goto repair; /* Skip network transmission */ + } + + if (tcp_pacing_check(sk)) + break; + + tso_segs = tcp_init_tso_segs(skb, mss_now); + BUG_ON(!tso_segs); + + cwnd_quota = tcp_cwnd_test(tp, skb); + if (!cwnd_quota) { + if (push_one == 2) + /* Force out a loss probe pkt. 
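+ * (push_one == 2 is only used by tcp_send_loss_probe(), which needs
+ * exactly one TLP segment out even when cwnd is exhausted.)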
*/ + cwnd_quota = 1; + else + break; + } + + if (unlikely(!tcp_snd_wnd_test(tp, skb, mss_now))) { + is_rwnd_limited = true; + break; + } + + if (tso_segs == 1) { + if (unlikely(!tcp_nagle_test(tp, skb, mss_now, + (tcp_skb_is_last(sk, skb) ? + nonagle : TCP_NAGLE_PUSH)))) + break; + } else { + if (!push_one && + tcp_tso_should_defer(sk, skb, &is_cwnd_limited, + &is_rwnd_limited, max_segs)) + break; + } + + limit = mss_now; + if (tso_segs > 1 && !tcp_urg_mode(tp)) + limit = tcp_mss_split_point(sk, skb, mss_now, + min_t(unsigned int, + cwnd_quota, + max_segs), + nonagle); + + if (skb->len > limit && + unlikely(tso_fragment(sk, skb, limit, mss_now, gfp))) + break; + + if (tcp_small_queue_check(sk, skb, 0)) + break; + + /* Argh, we hit an empty skb(), presumably a thread + * is sleeping in sendmsg()/sk_stream_wait_memory(). + * We do not want to send a pure-ack packet and have + * a strange looking rtx queue with empty packet(s). + */ + if (TCP_SKB_CB(skb)->end_seq == TCP_SKB_CB(skb)->seq) + break; + + if (unlikely(tcp_transmit_skb(sk, skb, 1, gfp))) + break; + +repair: + /* Advance the send_head. This one is sent out. + * This call will increment packets_out. + */ + tcp_event_new_data_sent(sk, skb); + + tcp_minshall_update(tp, mss_now, skb); + sent_pkts += tcp_skb_pcount(skb); + + if (push_one) + break; + } + + if (is_rwnd_limited) + tcp_chrono_start(sk, TCP_CHRONO_RWND_LIMITED); + else + tcp_chrono_stop(sk, TCP_CHRONO_RWND_LIMITED); + + is_cwnd_limited |= (tcp_packets_in_flight(tp) >= tcp_snd_cwnd(tp)); + if (likely(sent_pkts || is_cwnd_limited)) + tcp_cwnd_validate(sk, is_cwnd_limited); + + if (likely(sent_pkts)) { + if (tcp_in_cwnd_reduction(sk)) + tp->prr_out += sent_pkts; + + /* Send one loss probe per tail loss episode. */ + if (push_one != 2) + tcp_schedule_loss_probe(sk, false); + return false; + } + return !tp->packets_out && !tcp_write_queue_empty(sk); +} + +bool tcp_schedule_loss_probe(struct sock *sk, bool advancing_rto) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 timeout, timeout_us, rto_delta_us; + int early_retrans; + + /* Don't do any loss probe on a Fast Open connection before 3WHS + * finishes. + */ + if (rcu_access_pointer(tp->fastopen_rsk)) + return false; + + early_retrans = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_early_retrans); + /* Schedule a loss probe in 2*RTT for SACK capable connections + * not in loss recovery, that are either limited by cwnd or application. + */ + if ((early_retrans != 3 && early_retrans != 4) || + !tp->packets_out || !tcp_is_sack(tp) || + (icsk->icsk_ca_state != TCP_CA_Open && + icsk->icsk_ca_state != TCP_CA_CWR)) + return false; + + /* Probe timeout is 2*rtt. Add minimum RTO to account + * for delayed ack when there's one outstanding packet. If no RTT + * sample is available then probe after TCP_TIMEOUT_INIT. + */ + if (tp->srtt_us) { + timeout_us = tp->srtt_us >> 2; + if (tp->packets_out == 1) + timeout_us += tcp_rto_min_us(sk); + else + timeout_us += TCP_TIMEOUT_MIN_US; + timeout = usecs_to_jiffies(timeout_us); + } else { + timeout = TCP_TIMEOUT_INIT; + } + + /* If the RTO formula yields an earlier time, then use that time. */ + rto_delta_us = advancing_rto ? + jiffies_to_usecs(inet_csk(sk)->icsk_rto) : + tcp_rto_delta_us(sk); /* How far in future is RTO? 
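+ * (For reference: tp->srtt_us stores 8 * the smoothed RTT in usec, so
+ * timeout_us = tp->srtt_us >> 2 above is the advertised 2*rtt probe
+ * timeout.)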
*/ + if (rto_delta_us > 0) + timeout = min_t(u32, timeout, usecs_to_jiffies(rto_delta_us)); + + tcp_reset_xmit_timer(sk, ICSK_TIME_LOSS_PROBE, timeout, TCP_RTO_MAX); + return true; +} + +/* Thanks to skb fast clones, we can detect if a prior transmit of + * a packet is still in a qdisc or driver queue. + * In this case, there is very little point doing a retransmit ! + */ +static bool skb_still_in_host_queue(struct sock *sk, + const struct sk_buff *skb) +{ + if (unlikely(skb_fclone_busy(sk, skb))) { + set_bit(TSQ_THROTTLED, &sk->sk_tsq_flags); + smp_mb__after_atomic(); + if (skb_fclone_busy(sk, skb)) { + NET_INC_STATS(sock_net(sk), + LINUX_MIB_TCPSPURIOUS_RTX_HOSTQUEUES); + return true; + } + } + return false; +} + +/* When probe timeout (PTO) fires, try send a new segment if possible, else + * retransmit the last segment. + */ +void tcp_send_loss_probe(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + int pcount; + int mss = tcp_current_mss(sk); + + /* At most one outstanding TLP */ + if (tp->tlp_high_seq) + goto rearm_timer; + + tp->tlp_retrans = 0; + skb = tcp_send_head(sk); + if (skb && tcp_snd_wnd_test(tp, skb, mss)) { + pcount = tp->packets_out; + tcp_write_xmit(sk, mss, TCP_NAGLE_OFF, 2, GFP_ATOMIC); + if (tp->packets_out > pcount) + goto probe_sent; + goto rearm_timer; + } + skb = skb_rb_last(&sk->tcp_rtx_queue); + if (unlikely(!skb)) { + WARN_ONCE(tp->packets_out, + "invalid inflight: %u state %u cwnd %u mss %d\n", + tp->packets_out, sk->sk_state, tcp_snd_cwnd(tp), mss); + inet_csk(sk)->icsk_pending = 0; + return; + } + + if (skb_still_in_host_queue(sk, skb)) + goto rearm_timer; + + pcount = tcp_skb_pcount(skb); + if (WARN_ON(!pcount)) + goto rearm_timer; + + if ((pcount > 1) && (skb->len > (pcount - 1) * mss)) { + if (unlikely(tcp_fragment(sk, TCP_FRAG_IN_RTX_QUEUE, skb, + (pcount - 1) * mss, mss, + GFP_ATOMIC))) + goto rearm_timer; + skb = skb_rb_next(skb); + } + + if (WARN_ON(!skb || !tcp_skb_pcount(skb))) + goto rearm_timer; + + if (__tcp_retransmit_skb(sk, skb, 1)) + goto rearm_timer; + + tp->tlp_retrans = 1; + +probe_sent: + /* Record snd_nxt for loss detection. */ + tp->tlp_high_seq = tp->snd_nxt; + + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPLOSSPROBES); + /* Reset s.t. tcp_rearm_rto will restart timer from now */ + inet_csk(sk)->icsk_pending = 0; +rearm_timer: + tcp_rearm_rto(sk); +} + +/* Push out any pending frames which were held back due to + * TCP_CORK or attempt at coalescing tiny packets. + * The socket must be locked by the caller. + */ +void __tcp_push_pending_frames(struct sock *sk, unsigned int cur_mss, + int nonagle) +{ + /* If we are closed, the bytes will have to remain here. + * In time closedown will finish, we empty the write queue and + * all will be happy. + */ + if (unlikely(sk->sk_state == TCP_CLOSE)) + return; + + if (tcp_write_xmit(sk, cur_mss, nonagle, 0, + sk_gfp_mask(sk, GFP_ATOMIC))) + tcp_check_probe_timer(sk); +} + +/* Send _single_ skb sitting at the send head. This function requires + * true push pending frames to setup probe timer etc. + */ +void tcp_push_one(struct sock *sk, unsigned int mss_now) +{ + struct sk_buff *skb = tcp_send_head(sk); + + BUG_ON(!skb || skb->len < mss_now); + + tcp_write_xmit(sk, mss_now, TCP_NAGLE_PUSH, 1, sk->sk_allocation); +} + +/* This function returns the amount that we can raise the + * usable window based on the following constraints + * + * 1. The window can never be shrunk once it is offered (RFC 793) + * 2. 
We limit memory per socket + * + * RFC 1122: + * "the suggested [SWS] avoidance algorithm for the receiver is to keep + * RECV.NEXT + RCV.WIN fixed until: + * RCV.BUFF - RCV.USER - RCV.WINDOW >= min(1/2 RCV.BUFF, MSS)" + * + * i.e. don't raise the right edge of the window until you can raise + * it at least MSS bytes. + * + * Unfortunately, the recommended algorithm breaks header prediction, + * since header prediction assumes th->window stays fixed. + * + * Strictly speaking, keeping th->window fixed violates the receiver + * side SWS prevention criteria. The problem is that under this rule + * a stream of single byte packets will cause the right side of the + * window to always advance by a single byte. + * + * Of course, if the sender implements sender side SWS prevention + * then this will not be a problem. + * + * BSD seems to make the following compromise: + * + * If the free space is less than the 1/4 of the maximum + * space available and the free space is less than 1/2 mss, + * then set the window to 0. + * [ Actually, bsd uses MSS and 1/4 of maximal _window_ ] + * Otherwise, just prevent the window from shrinking + * and from being larger than the largest representable value. + * + * This prevents incremental opening of the window in the regime + * where TCP is limited by the speed of the reader side taking + * data out of the TCP receive queue. It does nothing about + * those cases where the window is constrained on the sender side + * because the pipeline is full. + * + * BSD also seems to "accidentally" limit itself to windows that are a + * multiple of MSS, at least until the free space gets quite small. + * This would appear to be a side effect of the mbuf implementation. + * Combining these two algorithms results in the observed behavior + * of having a fixed window size at almost all times. + * + * Below we obtain similar behavior by forcing the offered window to + * a multiple of the mss when it is feasible to do so. + * + * Note, we don't "adjust" for TIMESTAMP or SACK option bytes. + * Regular options like TIMESTAMP are taken into account. + */ +u32 __tcp_select_window(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + /* MSS for the peer's data. Previous versions used mss_clamp + * here. I don't know if the value based on our guesses + * of peer's MSS is better for the performance. It's more correct + * but may be worse for the performance because of rcv_mss + * fluctuations. --SAW 1998/11/1 + */ + int mss = icsk->icsk_ack.rcv_mss; + int free_space = tcp_space(sk); + int allowed_space = tcp_full_space(sk); + int full_space, window; + + if (sk_is_mptcp(sk)) + mptcp_space(sk, &free_space, &allowed_space); + + full_space = min_t(int, tp->window_clamp, allowed_space); + + if (unlikely(mss > full_space)) { + mss = full_space; + if (mss <= 0) + return 0; + } + + /* Only allow window shrink if the sysctl is enabled and we have + * a non-zero scaling factor in effect. + */ + if (READ_ONCE(net->ipv4.sysctl_tcp_shrink_window) && tp->rx_opt.rcv_wscale) + goto shrink_window_allowed; + + /* do not allow window to shrink */ + + if (free_space < (full_space >> 1)) { + icsk->icsk_ack.quick = 0; + + if (tcp_under_memory_pressure(sk)) + tcp_adjust_rcv_ssthresh(sk); + + /* free_space might become our new window, make sure we don't + * increase it due to wscale. 
+ */
+ free_space = round_down(free_space, 1 << tp->rx_opt.rcv_wscale);
+
+ /* if free space is less than mss estimate, or is below 1/16th
+ * of the maximum allowed, try to move to zero-window, else
+ * tcp_clamp_window() will grow rcv buf up to tcp_rmem[2], and
+ * new incoming data is dropped due to memory limits.
+ * With large window, mss test triggers way too late in order
+ * to announce zero window in time before rmem limit kicks in.
+ */
+ if (free_space < (allowed_space >> 4) || free_space < mss)
+ return 0;
+ }
+
+ if (free_space > tp->rcv_ssthresh)
+ free_space = tp->rcv_ssthresh;
+
+ /* Don't do rounding if we are using window scaling, since the
+ * scaled window will not line up with the MSS boundary anyway.
+ */
+ if (tp->rx_opt.rcv_wscale) {
+ window = free_space;
+
+ /* Advertise enough space so that it won't get scaled away.
+ * Import case: prevent zero window announcement if
+ * 1<<rcv_wscale > mss.
+ */
+ window = ALIGN(window, (1 << tp->rx_opt.rcv_wscale));
+ } else {
+ window = tp->rcv_wnd;
+ /* Get the largest window that is a nice multiple of mss.
+ * Window clamp already applied above.
+ * If our current window offering is within 1 mss of the
+ * free space we just keep it. This prevents the divide
+ * and multiply from happening most of the time.
+ * We also don't do any window rounding when the free space
+ * is too small.
+ */
+ if (window <= free_space - mss || window > free_space)
+ window = rounddown(free_space, mss);
+ else if (mss == full_space &&
+ free_space > window + (full_space >> 1))
+ window = free_space;
+ }
+
+ return window;
+
+shrink_window_allowed:
+ /* new window should always be an exact multiple of scaling factor */
+ free_space = round_down(free_space, 1 << tp->rx_opt.rcv_wscale);
+
+ if (free_space < (full_space >> 1)) {
+ icsk->icsk_ack.quick = 0;
+
+ if (tcp_under_memory_pressure(sk))
+ tcp_adjust_rcv_ssthresh(sk);
+
+ /* if free space is too low, return a zero window */
+ if (free_space < (allowed_space >> 4) || free_space < mss ||
+ free_space < (1 << tp->rx_opt.rcv_wscale))
+ return 0;
+ }
+
+ if (free_space > tp->rcv_ssthresh) {
+ free_space = tp->rcv_ssthresh;
+ /* new window should always be an exact multiple of scaling factor
+ *
+ * For this case, we ALIGN "up" (increase free_space) because
+ * we know free_space is not zero here, it has been reduced from
+ * the memory-based limit, and rcv_ssthresh is not a hard limit
+ * (unlike sk_rcvbuf).
+ */
+ free_space = ALIGN(free_space, (1 << tp->rx_opt.rcv_wscale));
+ }
+
+ return free_space;
+}
+
+void tcp_skb_collapse_tstamp(struct sk_buff *skb,
+ const struct sk_buff *next_skb)
+{
+ if (unlikely(tcp_has_tx_tstamp(next_skb))) {
+ const struct skb_shared_info *next_shinfo =
+ skb_shinfo(next_skb);
+ struct skb_shared_info *shinfo = skb_shinfo(skb);
+
+ shinfo->tx_flags |= next_shinfo->tx_flags & SKBTX_ANY_TSTAMP;
+ shinfo->tskey = next_shinfo->tskey;
+ TCP_SKB_CB(skb)->txstamp_ack |=
+ TCP_SKB_CB(next_skb)->txstamp_ack;
+ }
+}
+
+/* Collapses two adjacent SKB's during retransmission. */
+static bool tcp_collapse_retrans(struct sock *sk, struct sk_buff *skb)
+{
+ struct tcp_sock *tp = tcp_sk(sk);
+ struct sk_buff *next_skb = skb_rb_next(skb);
+ int next_skb_size;
+
+ next_skb_size = next_skb->len;
+
+ BUG_ON(tcp_skb_pcount(skb) != 1 || tcp_skb_pcount(next_skb) != 1);
+
+ if (next_skb_size && !tcp_skb_shift(skb, next_skb, 1, next_skb_size))
+ return false;
+
+ tcp_highest_sack_replace(sk, next_skb, skb);
+
+ /* Update sequence range on original skb.
*/ + TCP_SKB_CB(skb)->end_seq = TCP_SKB_CB(next_skb)->end_seq; + + /* Merge over control information. This moves PSH/FIN etc. over */ + TCP_SKB_CB(skb)->tcp_flags |= TCP_SKB_CB(next_skb)->tcp_flags; + + /* All done, get rid of second SKB and account for it so + * packet counting does not break. + */ + TCP_SKB_CB(skb)->sacked |= TCP_SKB_CB(next_skb)->sacked & TCPCB_EVER_RETRANS; + TCP_SKB_CB(skb)->eor = TCP_SKB_CB(next_skb)->eor; + + /* changed transmit queue under us so clear hints */ + tcp_clear_retrans_hints_partial(tp); + if (next_skb == tp->retransmit_skb_hint) + tp->retransmit_skb_hint = skb; + + tcp_adjust_pcount(sk, next_skb, tcp_skb_pcount(next_skb)); + + tcp_skb_collapse_tstamp(skb, next_skb); + + tcp_rtx_queue_unlink_and_free(next_skb, sk); + return true; +} + +/* Check if coalescing SKBs is legal. */ +static bool tcp_can_collapse(const struct sock *sk, const struct sk_buff *skb) +{ + if (tcp_skb_pcount(skb) > 1) + return false; + if (skb_cloned(skb)) + return false; + /* Some heuristics for collapsing over SACK'd could be invented */ + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_ACKED) + return false; + + return true; +} + +/* Collapse packets in the retransmit queue to make to create + * less packets on the wire. This is only done on retransmission. + */ +static void tcp_retrans_try_collapse(struct sock *sk, struct sk_buff *to, + int space) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb = to, *tmp; + bool first = true; + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_retrans_collapse)) + return; + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN) + return; + + skb_rbtree_walk_from_safe(skb, tmp) { + if (!tcp_can_collapse(sk, skb)) + break; + + if (!tcp_skb_can_collapse(to, skb)) + break; + + space -= skb->len; + + if (first) { + first = false; + continue; + } + + if (space < 0) + break; + + if (after(TCP_SKB_CB(skb)->end_seq, tcp_wnd_end(tp))) + break; + + if (!tcp_collapse_retrans(sk, to)) + break; + } +} + +/* This retransmits one SKB. Policy decisions and retransmit queue + * state updates are done by the caller. Returns non-zero if an + * error occurred which prevented the send. + */ +int __tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + unsigned int cur_mss; + int diff, len, err; + int avail_wnd; + + /* Inconclusive MTU probe */ + if (icsk->icsk_mtup.probe_size) + icsk->icsk_mtup.probe_size = 0; + + if (skb_still_in_host_queue(sk, skb)) + return -EBUSY; + +start: + if (before(TCP_SKB_CB(skb)->seq, tp->snd_una)) { + if (unlikely(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_SYN; + TCP_SKB_CB(skb)->seq++; + goto start; + } + if (unlikely(before(TCP_SKB_CB(skb)->end_seq, tp->snd_una))) { + WARN_ON_ONCE(1); + return -EINVAL; + } + if (tcp_trim_head(sk, skb, tp->snd_una - TCP_SKB_CB(skb)->seq)) + return -ENOMEM; + } + + if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk)) + return -EHOSTUNREACH; /* Routing failure or similar. */ + + cur_mss = tcp_current_mss(sk); + avail_wnd = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; + + /* If receiver has shrunk his window, and skb is out of + * new window, do not retransmit it. The exception is the + * case, when window is shrunk to zero. In this case + * our retransmit of one segment serves as a zero window probe. 
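+ * (Hence, when avail_wnd <= 0 below, only the segment starting at
+ * snd_una is allowed through, and we pretend one MSS of window is
+ * available so that it can serve as the probe.)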
+ */ + if (avail_wnd <= 0) { + if (TCP_SKB_CB(skb)->seq != tp->snd_una) + return -EAGAIN; + avail_wnd = cur_mss; + } + + len = cur_mss * segs; + if (len > avail_wnd) { + len = rounddown(avail_wnd, cur_mss); + if (!len) + len = avail_wnd; + } + if (skb->len > len) { + if (tcp_fragment(sk, TCP_FRAG_IN_RTX_QUEUE, skb, len, + cur_mss, GFP_ATOMIC)) + return -ENOMEM; /* We'll try again later. */ + } else { + if (skb_unclone_keeptruesize(skb, GFP_ATOMIC)) + return -ENOMEM; + + diff = tcp_skb_pcount(skb); + tcp_set_skb_tso_segs(skb, cur_mss); + diff -= tcp_skb_pcount(skb); + if (diff) + tcp_adjust_pcount(sk, skb, diff); + avail_wnd = min_t(int, avail_wnd, cur_mss); + if (skb->len < avail_wnd) + tcp_retrans_try_collapse(sk, skb, avail_wnd); + } + + /* RFC3168, section 6.1.1.1. ECN fallback */ + if ((TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN_ECN) == TCPHDR_SYN_ECN) + tcp_ecn_clear_syn(sk, skb); + + /* Update global and local TCP statistics. */ + segs = tcp_skb_pcount(skb); + TCP_ADD_STATS(sock_net(sk), TCP_MIB_RETRANSSEGS, segs); + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN) + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNRETRANS); + tp->total_retrans += segs; + tp->bytes_retrans += skb->len; + + /* make sure skb->data is aligned on arches that require it + * and check if ack-trimming & collapsing extended the headroom + * beyond what csum_start can cover. + */ + if (unlikely((NET_IP_ALIGN && ((unsigned long)skb->data & 3)) || + skb_headroom(skb) >= 0xFFFF)) { + struct sk_buff *nskb; + + tcp_skb_tsorted_save(skb) { + nskb = __pskb_copy(skb, MAX_TCP_HEADER, GFP_ATOMIC); + if (nskb) { + nskb->dev = NULL; + err = tcp_transmit_skb(sk, nskb, 0, GFP_ATOMIC); + } else { + err = -ENOBUFS; + } + } tcp_skb_tsorted_restore(skb); + + if (!err) { + tcp_update_skb_after_send(sk, skb, tp->tcp_wstamp_ns); + tcp_rate_skb_sent(sk, skb); + } + } else { + err = tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC); + } + + /* To avoid taking spuriously low RTT samples based on a timestamp + * for a transmit that never happened, always mark EVER_RETRANS + */ + TCP_SKB_CB(skb)->sacked |= TCPCB_EVER_RETRANS; + + if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RETRANS_CB_FLAG)) + tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RETRANS_CB, + TCP_SKB_CB(skb)->seq, segs, err); + + if (likely(!err)) { + trace_tcp_retransmit_skb(sk, skb); + } else if (err != -EBUSY) { + NET_ADD_STATS(sock_net(sk), LINUX_MIB_TCPRETRANSFAIL, segs); + } + return err; +} + +int tcp_retransmit_skb(struct sock *sk, struct sk_buff *skb, int segs) +{ + struct tcp_sock *tp = tcp_sk(sk); + int err = __tcp_retransmit_skb(sk, skb, segs); + + if (err == 0) { +#if FASTRETRANS_DEBUG > 0 + if (TCP_SKB_CB(skb)->sacked & TCPCB_SACKED_RETRANS) { + net_dbg_ratelimited("retrans_out leaked\n"); + } +#endif + TCP_SKB_CB(skb)->sacked |= TCPCB_RETRANS; + tp->retrans_out += tcp_skb_pcount(skb); + } + + /* Save stamp of the first (attempted) retransmit. */ + if (!tp->retrans_stamp) + tp->retrans_stamp = tcp_skb_timestamp(skb); + + if (tp->undo_retrans < 0) + tp->undo_retrans = 0; + tp->undo_retrans += tcp_skb_pcount(skb); + return err; +} + +/* This gets called after a retransmit timeout, and the initially + * retransmitted data is acknowledged. It tries to continue + * resending the rest of the retransmit queue, until either + * we've sent it all or the congestion window limit is reached. 
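/* Illustrative userspace sketch, not part of the kernel sources: it models only
 * the length clamping done above in __tcp_retransmit_skb(), where cur_mss * segs
 * is bounded by the space left in the receiver window, rounded down to whole
 * segments, with a sub-MSS fallback so a nearly closed window can still be
 * probed. Function and variable names below are illustrative, not kernel ones.
 */
#include <stdio.h>

static int retrans_len(int cur_mss, int segs, int avail_wnd)
{
	int len = cur_mss * segs;

	if (len > avail_wnd) {
		len = avail_wnd - (avail_wnd % cur_mss); /* rounddown(avail_wnd, cur_mss) */
		if (!len)
			len = avail_wnd;                 /* less than one MSS left: still send */
	}
	return len;
}

int main(void)
{
	printf("%d\n", retrans_len(1460, 3, 4000)); /* 2920: two full segments fit */
	printf("%d\n", retrans_len(1460, 1, 700));  /* 700: sub-MSS zero-window probe */
	return 0;
}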
+ */ +void tcp_xmit_retransmit_queue(struct sock *sk) +{ + const struct inet_connection_sock *icsk = inet_csk(sk); + struct sk_buff *skb, *rtx_head, *hole = NULL; + struct tcp_sock *tp = tcp_sk(sk); + bool rearm_timer = false; + u32 max_segs; + int mib_idx; + + if (!tp->packets_out) + return; + + rtx_head = tcp_rtx_queue_head(sk); + skb = tp->retransmit_skb_hint ?: rtx_head; + max_segs = tcp_tso_segs(sk, tcp_current_mss(sk)); + skb_rbtree_walk_from(skb) { + __u8 sacked; + int segs; + + if (tcp_pacing_check(sk)) + break; + + /* we could do better than to assign each time */ + if (!hole) + tp->retransmit_skb_hint = skb; + + segs = tcp_snd_cwnd(tp) - tcp_packets_in_flight(tp); + if (segs <= 0) + break; + sacked = TCP_SKB_CB(skb)->sacked; + /* In case tcp_shift_skb_data() have aggregated large skbs, + * we need to make sure not sending too bigs TSO packets + */ + segs = min_t(int, segs, max_segs); + + if (tp->retrans_out >= tp->lost_out) { + break; + } else if (!(sacked & TCPCB_LOST)) { + if (!hole && !(sacked & (TCPCB_SACKED_RETRANS|TCPCB_SACKED_ACKED))) + hole = skb; + continue; + + } else { + if (icsk->icsk_ca_state != TCP_CA_Loss) + mib_idx = LINUX_MIB_TCPFASTRETRANS; + else + mib_idx = LINUX_MIB_TCPSLOWSTARTRETRANS; + } + + if (sacked & (TCPCB_SACKED_ACKED|TCPCB_SACKED_RETRANS)) + continue; + + if (tcp_small_queue_check(sk, skb, 1)) + break; + + if (tcp_retransmit_skb(sk, skb, segs)) + break; + + NET_ADD_STATS(sock_net(sk), mib_idx, tcp_skb_pcount(skb)); + + if (tcp_in_cwnd_reduction(sk)) + tp->prr_out += tcp_skb_pcount(skb); + + if (skb == rtx_head && + icsk->icsk_pending != ICSK_TIME_REO_TIMEOUT) + rearm_timer = true; + + } + if (rearm_timer) + tcp_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + inet_csk(sk)->icsk_rto, + TCP_RTO_MAX); +} + +/* We allow to exceed memory limits for FIN packets to expedite + * connection tear down and (memory) recovery. + * Otherwise tcp_send_fin() could be tempted to either delay FIN + * or even be forced to close flow without any FIN. + * In general, we want to allow one skb per socket to avoid hangs + * with edge trigger epoll() + */ +void sk_forced_mem_schedule(struct sock *sk, int size) +{ + int delta, amt; + + delta = size - sk->sk_forward_alloc; + if (delta <= 0) + return; + amt = sk_mem_pages(delta); + sk_forward_alloc_add(sk, amt << PAGE_SHIFT); + sk_memory_allocated_add(sk, amt); + + if (mem_cgroup_sockets_enabled && sk->sk_memcg) + mem_cgroup_charge_skmem(sk->sk_memcg, amt, + gfp_memcg_charge() | __GFP_NOFAIL); +} + +/* Send a FIN. The caller locks the socket for us. + * We should try to send a FIN packet really hard, but eventually give up. + */ +void tcp_send_fin(struct sock *sk) +{ + struct sk_buff *skb, *tskb, *tail = tcp_write_queue_tail(sk); + struct tcp_sock *tp = tcp_sk(sk); + + /* Optimization, tack on the FIN if we have one skb in write queue and + * this skb was not yet sent, or we are under memory pressure. + * Note: in the latter case, FIN packet will be sent after a timeout, + * as TCP stack thinks it has already been transmitted. + */ + tskb = tail; + if (!tskb && tcp_under_memory_pressure(sk)) + tskb = skb_rb_last(&sk->tcp_rtx_queue); + + if (tskb) { + TCP_SKB_CB(tskb)->tcp_flags |= TCPHDR_FIN; + TCP_SKB_CB(tskb)->end_seq++; + tp->write_seq++; + if (!tail) { + /* This means tskb was already sent. + * Pretend we included the FIN on previous transmit. + * We need to set tp->snd_nxt to the value it would have + * if FIN had been sent. This is because retransmit path + * does not change tp->snd_nxt. 
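/* Illustrative userspace sketch, not part of the kernel sources: a hedged model
 * of the accounting in sk_forced_mem_schedule() above, which charges only the
 * shortfall between the skb's truesize and the socket's existing reserve,
 * rounded up to whole pages. A 4 KiB page is assumed here for the example;
 * the kernel uses PAGE_SHIFT/PAGE_SIZE and sk_mem_pages().
 */
#include <stdio.h>

#define PAGE_SHIFT_EXAMPLE 12                       /* assumed 4 KiB pages */
#define PAGE_SIZE_EXAMPLE  (1 << PAGE_SHIFT_EXAMPLE)

static long forced_charge(long size, long forward_alloc)
{
	long delta = size - forward_alloc;
	long pages;

	if (delta <= 0)
		return 0;                           /* already covered by the reserve */
	pages = (delta + PAGE_SIZE_EXAMPLE - 1) >> PAGE_SHIFT_EXAMPLE;
	return pages << PAGE_SHIFT_EXAMPLE;         /* bytes newly accounted */
}

int main(void)
{
	printf("%ld\n", forced_charge(2048, 512));  /* 4096: one page covers the 1536-byte gap */
	printf("%ld\n", forced_charge(1024, 4096)); /* 0: the FIN skb fits in the existing reserve */
	return 0;
}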
+ */ + WRITE_ONCE(tp->snd_nxt, tp->snd_nxt + 1); + return; + } + } else { + skb = alloc_skb_fclone(MAX_TCP_HEADER, sk->sk_allocation); + if (unlikely(!skb)) + return; + + INIT_LIST_HEAD(&skb->tcp_tsorted_anchor); + skb_reserve(skb, MAX_TCP_HEADER); + sk_forced_mem_schedule(sk, skb->truesize); + /* FIN eats a sequence byte, write_seq advanced by tcp_queue_skb(). */ + tcp_init_nondata_skb(skb, tp->write_seq, + TCPHDR_ACK | TCPHDR_FIN); + tcp_queue_skb(sk, skb); + } + __tcp_push_pending_frames(sk, tcp_current_mss(sk), TCP_NAGLE_OFF); +} + +/* We get here when a process closes a file descriptor (either due to + * an explicit close() or as a byproduct of exit()'ing) and there + * was unread data in the receive queue. This behavior is recommended + * by RFC 2525, section 2.17. -DaveM + */ +void tcp_send_active_reset(struct sock *sk, gfp_t priority) +{ + struct sk_buff *skb; + + TCP_INC_STATS(sock_net(sk), TCP_MIB_OUTRSTS); + + /* NOTE: No TCP options attached and we never retransmit this. */ + skb = alloc_skb(MAX_TCP_HEADER, priority); + if (!skb) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTFAILED); + return; + } + + /* Reserve space for headers and prepare control bits. */ + skb_reserve(skb, MAX_TCP_HEADER); + tcp_init_nondata_skb(skb, tcp_acceptable_seq(sk), + TCPHDR_ACK | TCPHDR_RST); + tcp_mstamp_refresh(tcp_sk(sk)); + /* Send it off. */ + if (tcp_transmit_skb(sk, skb, 0, priority)) + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTFAILED); + + /* skb of trace_tcp_send_reset() keeps the skb that caused RST, + * skb here is different to the troublesome skb, so use NULL + */ + trace_tcp_send_reset(sk, NULL); +} + +/* Send a crossed SYN-ACK during socket establishment. + * WARNING: This routine must only be called when we have already sent + * a SYN packet that crossed the incoming SYN that caused this routine + * to get called. If this assumption fails then the initial rcv_wnd + * and rcv_wscale values will not be correct. + */ +int tcp_send_synack(struct sock *sk) +{ + struct sk_buff *skb; + + skb = tcp_rtx_queue_head(sk); + if (!skb || !(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + pr_err("%s: wrong queue state\n", __func__); + return -EFAULT; + } + if (!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACK)) { + if (skb_cloned(skb)) { + struct sk_buff *nskb; + + tcp_skb_tsorted_save(skb) { + nskb = skb_copy(skb, GFP_ATOMIC); + } tcp_skb_tsorted_restore(skb); + if (!nskb) + return -ENOMEM; + INIT_LIST_HEAD(&nskb->tcp_tsorted_anchor); + tcp_highest_sack_replace(sk, skb, nskb); + tcp_rtx_queue_unlink_and_free(skb, sk); + __skb_header_release(nskb); + tcp_rbtree_insert(&sk->tcp_rtx_queue, nskb); + sk_wmem_queued_add(sk, nskb->truesize); + sk_mem_charge(sk, nskb->truesize); + skb = nskb; + } + + TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_ACK; + tcp_ecn_send_synack(sk, skb); + } + return tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC); +} + +/** + * tcp_make_synack - Allocate one skb and build a SYNACK packet. + * @sk: listener socket + * @dst: dst entry attached to the SYNACK. It is consumed and caller + * should not use it again. + * @req: request_sock pointer + * @foc: cookie for tcp fast open + * @synack_type: Type of synack to prepare + * @syn_skb: SYN packet just received. It could be NULL for rtx case. 
+ */ +struct sk_buff *tcp_make_synack(const struct sock *sk, struct dst_entry *dst, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + struct inet_request_sock *ireq = inet_rsk(req); + const struct tcp_sock *tp = tcp_sk(sk); + struct tcp_md5sig_key *md5 = NULL; + struct tcp_out_options opts; + struct sk_buff *skb; + int tcp_header_size; + struct tcphdr *th; + int mss; + u64 now; + + skb = alloc_skb(MAX_TCP_HEADER, GFP_ATOMIC); + if (unlikely(!skb)) { + dst_release(dst); + return NULL; + } + /* Reserve space for headers. */ + skb_reserve(skb, MAX_TCP_HEADER); + + switch (synack_type) { + case TCP_SYNACK_NORMAL: + skb_set_owner_w(skb, req_to_sk(req)); + break; + case TCP_SYNACK_COOKIE: + /* Under synflood, we do not attach skb to a socket, + * to avoid false sharing. + */ + break; + case TCP_SYNACK_FASTOPEN: + /* sk is a const pointer, because we want to express multiple + * cpu might call us concurrently. + * sk->sk_wmem_alloc in an atomic, we can promote to rw. + */ + skb_set_owner_w(skb, (struct sock *)sk); + break; + } + skb_dst_set(skb, dst); + + mss = tcp_mss_clamp(tp, dst_metric_advmss(dst)); + + memset(&opts, 0, sizeof(opts)); + now = tcp_clock_ns(); +#ifdef CONFIG_SYN_COOKIES + if (unlikely(synack_type == TCP_SYNACK_COOKIE && ireq->tstamp_ok)) + skb_set_delivery_time(skb, cookie_init_timestamp(req, now), + true); + else +#endif + { + skb_set_delivery_time(skb, now, true); + if (!tcp_rsk(req)->snt_synack) /* Timestamp first SYNACK */ + tcp_rsk(req)->snt_synack = tcp_skb_timestamp_us(skb); + } + +#ifdef CONFIG_TCP_MD5SIG + rcu_read_lock(); + md5 = tcp_rsk(req)->af_specific->req_md5_lookup(sk, req_to_sk(req)); +#endif + skb_set_hash(skb, READ_ONCE(tcp_rsk(req)->txhash), PKT_HASH_TYPE_L4); + /* bpf program will be interested in the tcp_flags */ + TCP_SKB_CB(skb)->tcp_flags = TCPHDR_SYN | TCPHDR_ACK; + tcp_header_size = tcp_synack_options(sk, req, mss, skb, &opts, md5, + foc, synack_type, + syn_skb) + sizeof(*th); + + skb_push(skb, tcp_header_size); + skb_reset_transport_header(skb); + + th = (struct tcphdr *)skb->data; + memset(th, 0, sizeof(struct tcphdr)); + th->syn = 1; + th->ack = 1; + tcp_ecn_make_synack(req, th); + th->source = htons(ireq->ir_num); + th->dest = ireq->ir_rmt_port; + skb->mark = ireq->ir_mark; + skb->ip_summed = CHECKSUM_PARTIAL; + th->seq = htonl(tcp_rsk(req)->snt_isn); + /* XXX data is queued and acked as is. No buffer/window check */ + th->ack_seq = htonl(tcp_rsk(req)->rcv_nxt); + + /* RFC1323: The window in SYN & SYN/ACK segments is never scaled. 
*/ + th->window = htons(min(req->rsk_rcv_wnd, 65535U)); + tcp_options_write(th, NULL, &opts); + th->doff = (tcp_header_size >> 2); + TCP_INC_STATS(sock_net(sk), TCP_MIB_OUTSEGS); + +#ifdef CONFIG_TCP_MD5SIG + /* Okay, we have all we need - do the md5 hash if needed */ + if (md5) + tcp_rsk(req)->af_specific->calc_md5_hash(opts.hash_location, + md5, req_to_sk(req), skb); + rcu_read_unlock(); +#endif + + bpf_skops_write_hdr_opt((struct sock *)sk, skb, req, syn_skb, + synack_type, &opts); + + skb_set_delivery_time(skb, now, true); + tcp_add_tx_delay(skb, tp); + + return skb; +} +EXPORT_SYMBOL(tcp_make_synack); + +static void tcp_ca_dst_init(struct sock *sk, const struct dst_entry *dst) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + const struct tcp_congestion_ops *ca; + u32 ca_key = dst_metric(dst, RTAX_CC_ALGO); + + if (ca_key == TCP_CA_UNSPEC) + return; + + rcu_read_lock(); + ca = tcp_ca_find_key(ca_key); + if (likely(ca && bpf_try_module_get(ca, ca->owner))) { + bpf_module_put(icsk->icsk_ca_ops, icsk->icsk_ca_ops->owner); + icsk->icsk_ca_dst_locked = tcp_ca_dst_locked(dst); + icsk->icsk_ca_ops = ca; + } + rcu_read_unlock(); +} + +/* Do all connect socket setups that can be done AF independent. */ +static void tcp_connect_init(struct sock *sk) +{ + const struct dst_entry *dst = __sk_dst_get(sk); + struct tcp_sock *tp = tcp_sk(sk); + __u8 rcv_wscale; + u32 rcv_wnd; + + /* We'll fix this up when we get a response from the other end. + * See tcp_input.c:tcp_rcv_state_process case TCP_SYN_SENT. + */ + tp->tcp_header_len = sizeof(struct tcphdr); + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_timestamps)) + tp->tcp_header_len += TCPOLEN_TSTAMP_ALIGNED; + +#ifdef CONFIG_TCP_MD5SIG + if (tp->af_specific->md5_lookup(sk, sk)) + tp->tcp_header_len += TCPOLEN_MD5SIG_ALIGNED; +#endif + + /* If user gave his TCP_MAXSEG, record it to clamp */ + if (tp->rx_opt.user_mss) + tp->rx_opt.mss_clamp = tp->rx_opt.user_mss; + tp->max_window = 0; + tcp_mtup_init(sk); + tcp_sync_mss(sk, dst_mtu(dst)); + + tcp_ca_dst_init(sk, dst); + + if (!tp->window_clamp) + tp->window_clamp = dst_metric(dst, RTAX_WINDOW); + tp->advmss = tcp_mss_clamp(tp, dst_metric_advmss(dst)); + + tcp_initialize_rcv_mss(sk); + + /* limit the window selection if the user enforce a smaller rx buffer */ + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK && + (tp->window_clamp > tcp_full_space(sk) || tp->window_clamp == 0)) + tp->window_clamp = tcp_full_space(sk); + + rcv_wnd = tcp_rwnd_init_bpf(sk); + if (rcv_wnd == 0) + rcv_wnd = dst_metric(dst, RTAX_INITRWND); + + tcp_select_initial_window(sk, tcp_full_space(sk), + tp->advmss - (tp->rx_opt.ts_recent_stamp ? 
tp->tcp_header_len - sizeof(struct tcphdr) : 0), + &tp->rcv_wnd, + &tp->window_clamp, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_window_scaling), + &rcv_wscale, + rcv_wnd); + + tp->rx_opt.rcv_wscale = rcv_wscale; + tp->rcv_ssthresh = tp->rcv_wnd; + + sk->sk_err = 0; + sock_reset_flag(sk, SOCK_DONE); + tp->snd_wnd = 0; + tcp_init_wl(tp, 0); + tcp_write_queue_purge(sk); + tp->snd_una = tp->write_seq; + tp->snd_sml = tp->write_seq; + tp->snd_up = tp->write_seq; + WRITE_ONCE(tp->snd_nxt, tp->write_seq); + + if (likely(!tp->repair)) + tp->rcv_nxt = 0; + else + tp->rcv_tstamp = tcp_jiffies32; + tp->rcv_wup = tp->rcv_nxt; + WRITE_ONCE(tp->copied_seq, tp->rcv_nxt); + + inet_csk(sk)->icsk_rto = tcp_timeout_init(sk); + inet_csk(sk)->icsk_retransmits = 0; + tcp_clear_retrans(tp); +} + +static void tcp_connect_queue_skb(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_skb_cb *tcb = TCP_SKB_CB(skb); + + tcb->end_seq += skb->len; + __skb_header_release(skb); + sk_wmem_queued_add(sk, skb->truesize); + sk_mem_charge(sk, skb->truesize); + WRITE_ONCE(tp->write_seq, tcb->end_seq); + tp->packets_out += tcp_skb_pcount(skb); +} + +/* Build and send a SYN with data and (cached) Fast Open cookie. However, + * queue a data-only packet after the regular SYN, such that regular SYNs + * are retransmitted on timeouts. Also if the remote SYN-ACK acknowledges + * only the SYN sequence, the data are retransmitted in the first ACK. + * If cookie is not cached or other error occurs, falls back to send a + * regular SYN with Fast Open cookie request option. + */ +static int tcp_send_syn_data(struct sock *sk, struct sk_buff *syn) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_fastopen_request *fo = tp->fastopen_req; + int space, err = 0; + struct sk_buff *syn_data; + + tp->rx_opt.mss_clamp = tp->advmss; /* If MSS is not cached */ + if (!tcp_fastopen_cookie_check(sk, &tp->rx_opt.mss_clamp, &fo->cookie)) + goto fallback; + + /* MSS for SYN-data is based on cached MSS and bounded by PMTU and + * user-MSS. Reserve maximum option space for middleboxes that add + * private TCP options. 
The cost is reduced data space in SYN :( + */ + tp->rx_opt.mss_clamp = tcp_mss_clamp(tp, tp->rx_opt.mss_clamp); + /* Sync mss_cache after updating the mss_clamp */ + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); + + space = __tcp_mtu_to_mss(sk, icsk->icsk_pmtu_cookie) - + MAX_TCP_OPTION_SPACE; + + space = min_t(size_t, space, fo->size); + + /* limit to order-0 allocations */ + space = min_t(size_t, space, SKB_MAX_HEAD(MAX_TCP_HEADER)); + + syn_data = tcp_stream_alloc_skb(sk, space, sk->sk_allocation, false); + if (!syn_data) + goto fallback; + memcpy(syn_data->cb, syn->cb, sizeof(syn->cb)); + if (space) { + int copied = copy_from_iter(skb_put(syn_data, space), space, + &fo->data->msg_iter); + if (unlikely(!copied)) { + tcp_skb_tsorted_anchor_cleanup(syn_data); + kfree_skb(syn_data); + goto fallback; + } + if (copied != space) { + skb_trim(syn_data, copied); + space = copied; + } + skb_zcopy_set(syn_data, fo->uarg, NULL); + } + /* No more data pending in inet_wait_for_connect() */ + if (space == fo->size) + fo->data = NULL; + fo->copied = space; + + tcp_connect_queue_skb(sk, syn_data); + if (syn_data->len) + tcp_chrono_start(sk, TCP_CHRONO_BUSY); + + err = tcp_transmit_skb(sk, syn_data, 1, sk->sk_allocation); + + skb_set_delivery_time(syn, syn_data->skb_mstamp_ns, true); + + /* Now full SYN+DATA was cloned and sent (or not), + * remove the SYN from the original skb (syn_data) + * we keep in write queue in case of a retransmit, as we + * also have the SYN packet (with no data) in the same queue. + */ + TCP_SKB_CB(syn_data)->seq++; + TCP_SKB_CB(syn_data)->tcp_flags = TCPHDR_ACK | TCPHDR_PSH; + if (!err) { + tp->syn_data = (fo->copied > 0); + tcp_rbtree_insert(&sk->tcp_rtx_queue, syn_data); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPORIGDATASENT); + goto done; + } + + /* data was not sent, put it in write_queue */ + __skb_queue_tail(&sk->sk_write_queue, syn_data); + tp->packets_out -= tcp_skb_pcount(syn_data); + +fallback: + /* Send a regular SYN with Fast Open cookie request option */ + if (fo->cookie.len > 0) + fo->cookie.len = 0; + err = tcp_transmit_skb(sk, syn, 1, sk->sk_allocation); + if (err) + tp->syn_fastopen = 0; +done: + fo->cookie.len = -1; /* Exclude Fast Open option for SYN retries */ + return err; +} + +/* Build a SYN and send it off. */ +int tcp_connect(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *buff; + int err; + + tcp_call_bpf(sk, BPF_SOCK_OPS_TCP_CONNECT_CB, 0, NULL); + + if (inet_csk(sk)->icsk_af_ops->rebuild_header(sk)) + return -EHOSTUNREACH; /* Routing failure or similar. */ + + tcp_connect_init(sk); + + if (unlikely(tp->repair)) { + tcp_finish_connect(sk, NULL); + return 0; + } + + buff = tcp_stream_alloc_skb(sk, 0, sk->sk_allocation, true); + if (unlikely(!buff)) + return -ENOBUFS; + + tcp_init_nondata_skb(buff, tp->write_seq++, TCPHDR_SYN); + tcp_mstamp_refresh(tp); + tp->retrans_stamp = tcp_time_stamp(tp); + tcp_connect_queue_skb(sk, buff); + tcp_ecn_send_syn(sk, buff); + tcp_rbtree_insert(&sk->tcp_rtx_queue, buff); + + /* Send off SYN; include data in Fast Open. */ + err = tp->fastopen_req ? tcp_send_syn_data(sk, buff) : + tcp_transmit_skb(sk, buff, 1, sk->sk_allocation); + if (err == -ECONNREFUSED) + return err; + + /* We change tp->snd_nxt after the tcp_transmit_skb() call + * in order to make this packet get counted in tcpOutSegs. 
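/* Illustrative userspace sketch, not part of the kernel sources: it models the
 * SYN-data budget computed earlier in tcp_send_syn_data(), i.e. the clamped MSS
 * minus the full option space, further capped by how much data the application
 * queued and by an order-0 skb head. The two constants are rough stand-ins for
 * MAX_TCP_OPTION_SPACE and SKB_MAX_HEAD(MAX_TCP_HEADER), chosen for the example.
 */
#include <stdio.h>

#define OPTION_SPACE_EXAMPLE 40    /* stand-in for MAX_TCP_OPTION_SPACE */
#define ORDER0_HEAD_EXAMPLE  3700  /* rough stand-in for SKB_MAX_HEAD() */

static int min_int(int a, int b) { return a < b ? a : b; }

static int syn_data_space(int mss, int queued)
{
	int space = mss - OPTION_SPACE_EXAMPLE;      /* leave room for middlebox options */

	space = min_int(space, queued);              /* no more than the app provided */
	return min_int(space, ORDER0_HEAD_EXAMPLE);  /* keep the skb head order-0 */
}

int main(void)
{
	printf("%d\n", syn_data_space(1460, 10000)); /* 1420 bytes of data ride on the SYN */
	printf("%d\n", syn_data_space(1460, 512));   /* 512: limited by the request size */
	return 0;
}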
+ */ + WRITE_ONCE(tp->snd_nxt, tp->write_seq); + tp->pushed_seq = tp->write_seq; + buff = tcp_send_head(sk); + if (unlikely(buff)) { + WRITE_ONCE(tp->snd_nxt, TCP_SKB_CB(buff)->seq); + tp->pushed_seq = TCP_SKB_CB(buff)->seq; + } + TCP_INC_STATS(sock_net(sk), TCP_MIB_ACTIVEOPENS); + + /* Timer for repeating the SYN until an answer. */ + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + inet_csk(sk)->icsk_rto, TCP_RTO_MAX); + return 0; +} +EXPORT_SYMBOL(tcp_connect); + +/* Send out a delayed ack, the caller does the policy checking + * to see if we should even be here. See tcp_input.c:tcp_ack_snd_check() + * for details. + */ +void tcp_send_delayed_ack(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + int ato = icsk->icsk_ack.ato; + unsigned long timeout; + + if (ato > TCP_DELACK_MIN) { + const struct tcp_sock *tp = tcp_sk(sk); + int max_ato = HZ / 2; + + if (inet_csk_in_pingpong_mode(sk) || + (icsk->icsk_ack.pending & ICSK_ACK_PUSHED)) + max_ato = TCP_DELACK_MAX; + + /* Slow path, intersegment interval is "high". */ + + /* If some rtt estimate is known, use it to bound delayed ack. + * Do not use inet_csk(sk)->icsk_rto here, use results of rtt measurements + * directly. + */ + if (tp->srtt_us) { + int rtt = max_t(int, usecs_to_jiffies(tp->srtt_us >> 3), + TCP_DELACK_MIN); + + if (rtt < max_ato) + max_ato = rtt; + } + + ato = min(ato, max_ato); + } + + ato = min_t(u32, ato, inet_csk(sk)->icsk_delack_max); + + /* Stay within the limit we were given */ + timeout = jiffies + ato; + + /* Use new timeout only if there wasn't a older one earlier. */ + if (icsk->icsk_ack.pending & ICSK_ACK_TIMER) { + /* If delack timer is about to expire, send ACK now. */ + if (time_before_eq(icsk->icsk_ack.timeout, jiffies + (ato >> 2))) { + tcp_send_ack(sk); + return; + } + + if (!time_before(timeout, icsk->icsk_ack.timeout)) + timeout = icsk->icsk_ack.timeout; + } + icsk->icsk_ack.pending |= ICSK_ACK_SCHED | ICSK_ACK_TIMER; + icsk->icsk_ack.timeout = timeout; + sk_reset_timer(sk, &icsk->icsk_delack_timer, timeout); +} + +/* This routine sends an ack and also updates the window. */ +void __tcp_send_ack(struct sock *sk, u32 rcv_nxt) +{ + struct sk_buff *buff; + + /* If we have been reset, we may not send again. */ + if (sk->sk_state == TCP_CLOSE) + return; + + /* We are not putting this on the write queue, so + * tcp_transmit_skb() will set the ownership to this + * sock. + */ + buff = alloc_skb(MAX_TCP_HEADER, + sk_gfp_mask(sk, GFP_ATOMIC | __GFP_NOWARN)); + if (unlikely(!buff)) { + struct inet_connection_sock *icsk = inet_csk(sk); + unsigned long delay; + + delay = TCP_DELACK_MAX << icsk->icsk_ack.retry; + if (delay < TCP_RTO_MAX) + icsk->icsk_ack.retry++; + inet_csk_schedule_ack(sk); + icsk->icsk_ack.ato = TCP_ATO_MIN; + inet_csk_reset_xmit_timer(sk, ICSK_TIME_DACK, delay, TCP_RTO_MAX); + return; + } + + /* Reserve space for headers and prepare control bits. */ + skb_reserve(buff, MAX_TCP_HEADER); + tcp_init_nondata_skb(buff, tcp_acceptable_seq(sk), TCPHDR_ACK); + + /* We do not want pure acks influencing TCP Small Queues or fq/pacing + * too much. + * SKB_TRUESIZE(max(1 .. 66, MAX_TCP_HEADER)) is unfortunately ~784 + */ + skb_set_tcp_pure_ack(buff); + + /* Send it off, this clears delayed acks for us. */ + __tcp_transmit_skb(sk, buff, 0, (__force gfp_t)0, rcv_nxt); +} +EXPORT_SYMBOL_GPL(__tcp_send_ack); + +void tcp_send_ack(struct sock *sk) +{ + __tcp_send_ack(sk, tcp_sk(sk)->rcv_nxt); +} + +/* This routine sends a packet with an out of date sequence + * number. 
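/* Illustrative userspace sketch, not part of the kernel sources: it mirrors the
 * bounding done above in tcp_send_delayed_ack(), where the delayed-ACK timeout
 * is capped at half a second (or TCP_DELACK_MAX in pingpong mode) and, when an
 * RTT estimate exists, at roughly one smoothed RTT. Values are in milliseconds
 * and the constants assume HZ=1000; the kernel additionally honours
 * icsk_delack_max, which this sketch omits.
 */
#include <stdio.h>

#define DELACK_MIN_MS 40   /* stand-in for TCP_DELACK_MIN */
#define DELACK_MAX_MS 200  /* stand-in for TCP_DELACK_MAX */

static int delack_timeout_ms(int ato_ms, int srtt_ms, int pingpong)
{
	if (ato_ms > DELACK_MIN_MS) {
		int max_ato = pingpong ? DELACK_MAX_MS : 500; /* HZ / 2 */

		if (srtt_ms) {
			int rtt = srtt_ms > DELACK_MIN_MS ? srtt_ms : DELACK_MIN_MS;

			if (rtt < max_ato)
				max_ato = rtt;   /* never wait much past one RTT */
		}
		if (ato_ms > max_ato)
			ato_ms = max_ato;
	}
	return ato_ms;
}

int main(void)
{
	printf("%d\n", delack_timeout_ms(300, 80, 0)); /* 80: clamped to the RTT */
	printf("%d\n", delack_timeout_ms(20, 80, 0));  /* 20: already short enough */
	return 0;
}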
It assumes the other end will try to ack it. + * + * Question: what should we make while urgent mode? + * 4.4BSD forces sending single byte of data. We cannot send + * out of window data, because we have SND.NXT==SND.MAX... + * + * Current solution: to send TWO zero-length segments in urgent mode: + * one is with SEG.SEQ=SND.UNA to deliver urgent pointer, another is + * out-of-date with SND.UNA-1 to probe window. + */ +static int tcp_xmit_probe_skb(struct sock *sk, int urgent, int mib) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + + /* We don't queue it, tcp_transmit_skb() sets ownership. */ + skb = alloc_skb(MAX_TCP_HEADER, + sk_gfp_mask(sk, GFP_ATOMIC | __GFP_NOWARN)); + if (!skb) + return -1; + + /* Reserve space for headers and set control bits. */ + skb_reserve(skb, MAX_TCP_HEADER); + /* Use a previous sequence. This should cause the other + * end to send an ack. Don't queue or clone SKB, just + * send it. + */ + tcp_init_nondata_skb(skb, tp->snd_una - !urgent, TCPHDR_ACK); + NET_INC_STATS(sock_net(sk), mib); + return tcp_transmit_skb(sk, skb, 0, (__force gfp_t)0); +} + +/* Called from setsockopt( ... TCP_REPAIR ) */ +void tcp_send_window_probe(struct sock *sk) +{ + if (sk->sk_state == TCP_ESTABLISHED) { + tcp_sk(sk)->snd_wl1 = tcp_sk(sk)->rcv_nxt - 1; + tcp_mstamp_refresh(tcp_sk(sk)); + tcp_xmit_probe_skb(sk, 0, LINUX_MIB_TCPWINPROBE); + } +} + +/* Initiate keepalive or window probe from timer. */ +int tcp_write_wakeup(struct sock *sk, int mib) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb; + + if (sk->sk_state == TCP_CLOSE) + return -1; + + skb = tcp_send_head(sk); + if (skb && before(TCP_SKB_CB(skb)->seq, tcp_wnd_end(tp))) { + int err; + unsigned int mss = tcp_current_mss(sk); + unsigned int seg_size = tcp_wnd_end(tp) - TCP_SKB_CB(skb)->seq; + + if (before(tp->pushed_seq, TCP_SKB_CB(skb)->end_seq)) + tp->pushed_seq = TCP_SKB_CB(skb)->end_seq; + + /* We are probing the opening of a window + * but the window size is != 0 + * must have been a result SWS avoidance ( sender ) + */ + if (seg_size < TCP_SKB_CB(skb)->end_seq - TCP_SKB_CB(skb)->seq || + skb->len > mss) { + seg_size = min(seg_size, mss); + TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_PSH; + if (tcp_fragment(sk, TCP_FRAG_IN_WRITE_QUEUE, + skb, seg_size, mss, GFP_ATOMIC)) + return -1; + } else if (!tcp_skb_pcount(skb)) + tcp_set_skb_tso_segs(skb, mss); + + TCP_SKB_CB(skb)->tcp_flags |= TCPHDR_PSH; + err = tcp_transmit_skb(sk, skb, 1, GFP_ATOMIC); + if (!err) + tcp_event_new_data_sent(sk, skb); + return err; + } else { + if (between(tp->snd_up, tp->snd_una + 1, tp->snd_una + 0xFFFF)) + tcp_xmit_probe_skb(sk, 1, mib); + return tcp_xmit_probe_skb(sk, 0, mib); + } +} + +/* A window probe timeout has occurred. If window is not closed send + * a partial packet else a zero probe. + */ +void tcp_send_probe0(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + unsigned long timeout; + int err; + + err = tcp_write_wakeup(sk, LINUX_MIB_TCPWINPROBE); + + if (tp->packets_out || tcp_write_queue_empty(sk)) { + /* Cancel probe timer, if it is not required. 
*/ + icsk->icsk_probes_out = 0; + icsk->icsk_backoff = 0; + icsk->icsk_probes_tstamp = 0; + return; + } + + icsk->icsk_probes_out++; + if (err <= 0) { + if (icsk->icsk_backoff < READ_ONCE(net->ipv4.sysctl_tcp_retries2)) + icsk->icsk_backoff++; + timeout = tcp_probe0_when(sk, TCP_RTO_MAX); + } else { + /* If packet was not sent due to local congestion, + * Let senders fight for local resources conservatively. + */ + timeout = TCP_RESOURCE_PROBE_INTERVAL; + } + + timeout = tcp_clamp_probe0_to_user_timeout(sk, timeout); + tcp_reset_xmit_timer(sk, ICSK_TIME_PROBE0, timeout, TCP_RTO_MAX); +} + +int tcp_rtx_synack(const struct sock *sk, struct request_sock *req) +{ + const struct tcp_request_sock_ops *af_ops = tcp_rsk(req)->af_specific; + struct flowi fl; + int res; + + /* Paired with WRITE_ONCE() in sock_setsockopt() */ + if (READ_ONCE(sk->sk_txrehash) == SOCK_TXREHASH_ENABLED) + WRITE_ONCE(tcp_rsk(req)->txhash, net_tx_rndhash()); + res = af_ops->send_synack(sk, NULL, &fl, req, NULL, TCP_SYNACK_NORMAL, + NULL); + if (!res) { + TCP_INC_STATS(sock_net(sk), TCP_MIB_RETRANSSEGS); + NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPSYNRETRANS); + if (unlikely(tcp_passive_fastopen(sk))) + tcp_sk(sk)->total_retrans++; + trace_tcp_retransmit_synack(sk, req); + } + return res; +} +EXPORT_SYMBOL(tcp_rtx_synack); diff --git a/net/ipv4/tcp_rate.c b/net/ipv4/tcp_rate.c new file mode 100644 index 000000000..a8f6d9d06 --- /dev/null +++ b/net/ipv4/tcp_rate.c @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include + +/* The bandwidth estimator estimates the rate at which the network + * can currently deliver outbound data packets for this flow. At a high + * level, it operates by taking a delivery rate sample for each ACK. + * + * A rate sample records the rate at which the network delivered packets + * for this flow, calculated over the time interval between the transmission + * of a data packet and the acknowledgment of that packet. + * + * Specifically, over the interval between each transmit and corresponding ACK, + * the estimator generates a delivery rate sample. Typically it uses the rate + * at which packets were acknowledged. However, the approach of using only the + * acknowledgment rate faces a challenge under the prevalent ACK decimation or + * compression: packets can temporarily appear to be delivered much quicker + * than the bottleneck rate. Since it is physically impossible to do that in a + * sustained fashion, when the estimator notices that the ACK rate is faster + * than the transmit rate, it uses the latter: + * + * send_rate = #pkts_delivered/(last_snd_time - first_snd_time) + * ack_rate = #pkts_delivered/(last_ack_time - first_ack_time) + * bw = min(send_rate, ack_rate) + * + * Notice the estimator essentially estimates the goodput, not always the + * network bottleneck link rate when the sending or receiving is limited by + * other factors like applications or receiver window limits. The estimator + * deliberately avoids using the inter-packet spacing approach because that + * approach requires a large number of samples and sophisticated filtering. + * + * TCP flows can often be application-limited in request/response workloads. + * The estimator marks a bandwidth sample as application-limited if there + * was some moment during the sampled window of packets when there was no data + * ready to send in the write queue. + */ + +/* Snapshot the current delivery information in the skb, to generate + * a rate sample later when the skb is (s)acked in tcp_rate_skb_delivered(). 
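/* Illustrative userspace sketch, not part of the kernel sources: the tcp_rate.c
 * header comment above defines the delivery rate sample as
 * bw = min(send_rate, ack_rate), which tcp_rate_gen() later realizes by
 * dividing delivered packets by the longer of the send and ACK phases. The
 * numbers below show why the ACK rate alone cannot be trusted under ACK
 * compression.
 */
#include <stdio.h>

/* Rate in packets per second over a phase measured in microseconds. */
static double phase_rate(unsigned int delivered, unsigned int interval_us)
{
	return interval_us ? delivered * 1e6 / interval_us : 0.0;
}

int main(void)
{
	unsigned int delivered = 10;
	unsigned int snd_us = 20000; /* 20 ms send phase */
	unsigned int ack_us = 5000;  /* 5 ms ACK phase: looks 4x faster due to ACK compression */
	double send_rate = phase_rate(delivered, snd_us);
	double ack_rate = phase_rate(delivered, ack_us);
	double bw = send_rate < ack_rate ? send_rate : ack_rate;

	printf("send %.0f pkt/s, ack %.0f pkt/s, sample %.0f pkt/s\n",
	       send_rate, ack_rate, bw); /* sample is 500 pkt/s, the send rate */
	return 0;
}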
+ */ +void tcp_rate_skb_sent(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_sock *tp = tcp_sk(sk); + + /* In general we need to start delivery rate samples from the + * time we received the most recent ACK, to ensure we include + * the full time the network needs to deliver all in-flight + * packets. If there are no packets in flight yet, then we + * know that any ACKs after now indicate that the network was + * able to deliver those packets completely in the sampling + * interval between now and the next ACK. + * + * Note that we use packets_out instead of tcp_packets_in_flight(tp) + * because the latter is a guess based on RTO and loss-marking + * heuristics. We don't want spurious RTOs or loss markings to cause + * a spuriously small time interval, causing a spuriously high + * bandwidth estimate. + */ + if (!tp->packets_out) { + u64 tstamp_us = tcp_skb_timestamp_us(skb); + + tp->first_tx_mstamp = tstamp_us; + tp->delivered_mstamp = tstamp_us; + } + + TCP_SKB_CB(skb)->tx.first_tx_mstamp = tp->first_tx_mstamp; + TCP_SKB_CB(skb)->tx.delivered_mstamp = tp->delivered_mstamp; + TCP_SKB_CB(skb)->tx.delivered = tp->delivered; + TCP_SKB_CB(skb)->tx.delivered_ce = tp->delivered_ce; + TCP_SKB_CB(skb)->tx.is_app_limited = tp->app_limited ? 1 : 0; +} + +/* When an skb is sacked or acked, we fill in the rate sample with the (prior) + * delivery information when the skb was last transmitted. + * + * If an ACK (s)acks multiple skbs (e.g., stretched-acks), this function is + * called multiple times. We favor the information from the most recently + * sent skb, i.e., the skb with the most recently sent time and the highest + * sequence. + */ +void tcp_rate_skb_delivered(struct sock *sk, struct sk_buff *skb, + struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + u64 tx_tstamp; + + if (!scb->tx.delivered_mstamp) + return; + + tx_tstamp = tcp_skb_timestamp_us(skb); + if (!rs->prior_delivered || + tcp_skb_sent_after(tx_tstamp, tp->first_tx_mstamp, + scb->end_seq, rs->last_end_seq)) { + rs->prior_delivered_ce = scb->tx.delivered_ce; + rs->prior_delivered = scb->tx.delivered; + rs->prior_mstamp = scb->tx.delivered_mstamp; + rs->is_app_limited = scb->tx.is_app_limited; + rs->is_retrans = scb->sacked & TCPCB_RETRANS; + rs->last_end_seq = scb->end_seq; + + /* Record send time of most recently ACKed packet: */ + tp->first_tx_mstamp = tx_tstamp; + /* Find the duration of the "send phase" of this window: */ + rs->interval_us = tcp_stamp_us_delta(tp->first_tx_mstamp, + scb->tx.first_tx_mstamp); + + } + /* Mark off the skb delivered once it's sacked to avoid being + * used again when it's cumulatively acked. For acked packets + * we don't need to reset since it'll be freed soon. + */ + if (scb->sacked & TCPCB_SACKED_ACKED) + scb->tx.delivered_mstamp = 0; +} + +/* Update the connection delivery information and generate a rate sample. */ +void tcp_rate_gen(struct sock *sk, u32 delivered, u32 lost, + bool is_sack_reneg, struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 snd_us, ack_us; + + /* Clear app limited if bubble is acked and gone. */ + if (tp->app_limited && after(tp->delivered, tp->app_limited)) + tp->app_limited = 0; + + /* TODO: there are multiple places throughout tcp_ack() to get + * current time. Refactor the code using a new "tcp_acktag_state" + * to carry current time, flags, stats like "tcp_sacktag_state". 
+ */ + if (delivered) + tp->delivered_mstamp = tp->tcp_mstamp; + + rs->acked_sacked = delivered; /* freshly ACKed or SACKed */ + rs->losses = lost; /* freshly marked lost */ + /* Return an invalid sample if no timing information is available or + * in recovery from loss with SACK reneging. Rate samples taken during + * a SACK reneging event may overestimate bw by including packets that + * were SACKed before the reneg. + */ + if (!rs->prior_mstamp || is_sack_reneg) { + rs->delivered = -1; + rs->interval_us = -1; + return; + } + rs->delivered = tp->delivered - rs->prior_delivered; + + rs->delivered_ce = tp->delivered_ce - rs->prior_delivered_ce; + /* delivered_ce occupies less than 32 bits in the skb control block */ + rs->delivered_ce &= TCPCB_DELIVERED_CE_MASK; + + /* Model sending data and receiving ACKs as separate pipeline phases + * for a window. Usually the ACK phase is longer, but with ACK + * compression the send phase can be longer. To be safe we use the + * longer phase. + */ + snd_us = rs->interval_us; /* send phase */ + ack_us = tcp_stamp_us_delta(tp->tcp_mstamp, + rs->prior_mstamp); /* ack phase */ + rs->interval_us = max(snd_us, ack_us); + + /* Record both segment send and ack receive intervals */ + rs->snd_interval_us = snd_us; + rs->rcv_interval_us = ack_us; + + /* Normally we expect interval_us >= min-rtt. + * Note that rate may still be over-estimated when a spuriously + * retransmistted skb was first (s)acked because "interval_us" + * is under-estimated (up to an RTT). However continuously + * measuring the delivery rate during loss recovery is crucial + * for connections suffer heavy or prolonged losses. + */ + if (unlikely(rs->interval_us < tcp_min_rtt(tp))) { + if (!rs->is_retrans) + pr_debug("tcp rate: %ld %d %u %u %u\n", + rs->interval_us, rs->delivered, + inet_csk(sk)->icsk_ca_state, + tp->rx_opt.sack_ok, tcp_min_rtt(tp)); + rs->interval_us = -1; + return; + } + + /* Record the last non-app-limited or the highest app-limited bw */ + if (!rs->is_app_limited || + ((u64)rs->delivered * tp->rate_interval_us >= + (u64)tp->rate_delivered * rs->interval_us)) { + tp->rate_delivered = rs->delivered; + tp->rate_interval_us = rs->interval_us; + tp->rate_app_limited = rs->is_app_limited; + } +} + +/* If a gap is detected between sends, mark the socket application-limited. */ +void tcp_rate_check_app_limited(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (/* We have less than one packet to send. */ + tp->write_seq - tp->snd_nxt < tp->mss_cache && + /* Nothing in sending host's qdisc queues or NIC tx queue. */ + sk_wmem_alloc_get(sk) < SKB_TRUESIZE(1) && + /* We are not limited by CWND. */ + tcp_packets_in_flight(tp) < tcp_snd_cwnd(tp) && + /* All lost packets have been retransmitted. */ + tp->lost_out <= tp->retrans_out) + tp->app_limited = + (tp->delivered + tcp_packets_in_flight(tp)) ? : 1; +} +EXPORT_SYMBOL_GPL(tcp_rate_check_app_limited); diff --git a/net/ipv4/tcp_recovery.c b/net/ipv4/tcp_recovery.c new file mode 100644 index 000000000..c08579369 --- /dev/null +++ b/net/ipv4/tcp_recovery.c @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include + +static u32 tcp_rack_reo_wnd(const struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tp->reord_seen) { + /* If reordering has not been observed, be aggressive during + * the recovery or starting the recovery by DUPACK threshold. 
+ */ + if (inet_csk(sk)->icsk_ca_state >= TCP_CA_Recovery) + return 0; + + if (tp->sacked_out >= tp->reordering && + !(READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_recovery) & + TCP_RACK_NO_DUPTHRESH)) + return 0; + } + + /* To be more reordering resilient, allow min_rtt/4 settling delay. + * Use min_rtt instead of the smoothed RTT because reordering is + * often a path property and less related to queuing or delayed ACKs. + * Upon receiving DSACKs, linearly increase the window up to the + * smoothed RTT. + */ + return min((tcp_min_rtt(tp) >> 2) * tp->rack.reo_wnd_steps, + tp->srtt_us >> 3); +} + +s32 tcp_rack_skb_timeout(struct tcp_sock *tp, struct sk_buff *skb, u32 reo_wnd) +{ + return tp->rack.rtt_us + reo_wnd - + tcp_stamp_us_delta(tp->tcp_mstamp, tcp_skb_timestamp_us(skb)); +} + +/* RACK loss detection (IETF draft draft-ietf-tcpm-rack-01): + * + * Marks a packet lost, if some packet sent later has been (s)acked. + * The underlying idea is similar to the traditional dupthresh and FACK + * but they look at different metrics: + * + * dupthresh: 3 OOO packets delivered (packet count) + * FACK: sequence delta to highest sacked sequence (sequence space) + * RACK: sent time delta to the latest delivered packet (time domain) + * + * The advantage of RACK is it applies to both original and retransmitted + * packet and therefore is robust against tail losses. Another advantage + * is being more resilient to reordering by simply allowing some + * "settling delay", instead of tweaking the dupthresh. + * + * When tcp_rack_detect_loss() detects some packets are lost and we + * are not already in the CA_Recovery state, either tcp_rack_reo_timeout() + * or tcp_time_to_recover()'s "Trick#1: the loss is proven" code path will + * make us enter the CA_Recovery state. + */ +static void tcp_rack_detect_loss(struct sock *sk, u32 *reo_timeout) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct sk_buff *skb, *n; + u32 reo_wnd; + + *reo_timeout = 0; + reo_wnd = tcp_rack_reo_wnd(sk); + list_for_each_entry_safe(skb, n, &tp->tsorted_sent_queue, + tcp_tsorted_anchor) { + struct tcp_skb_cb *scb = TCP_SKB_CB(skb); + s32 remaining; + + /* Skip ones marked lost but not yet retransmitted */ + if ((scb->sacked & TCPCB_LOST) && + !(scb->sacked & TCPCB_SACKED_RETRANS)) + continue; + + if (!tcp_skb_sent_after(tp->rack.mstamp, + tcp_skb_timestamp_us(skb), + tp->rack.end_seq, scb->end_seq)) + break; + + /* A packet is lost if it has not been s/acked beyond + * the recent RTT plus the reordering window. 
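/* Illustrative userspace sketch, not part of the kernel sources: it puts the two
 * helpers above together. The RACK reordering window is
 * min((min_rtt / 4) * reo_wnd_steps, srtt), and a packet is declared lost once
 * it has been outstanding longer than the most recent RACK RTT plus that
 * window; a positive remainder instead arms the reordering timer.
 */
#include <stdio.h>

static long rack_reo_wnd_us(long min_rtt_us, long srtt_us, int steps)
{
	long wnd = (min_rtt_us >> 2) * steps;

	return wnd < srtt_us ? wnd : srtt_us;
}

static long rack_skb_timeout_us(long rack_rtt_us, long reo_wnd_us, long elapsed_us)
{
	/* <= 0 means "already waited long enough; mark the skb lost now". */
	return rack_rtt_us + reo_wnd_us - elapsed_us;
}

int main(void)
{
	long reo = rack_reo_wnd_us(20000, 25000, 1);    /* 5000 us settling delay */

	printf("reo_wnd %ld us\n", reo);
	printf("remaining %ld us\n",
	       rack_skb_timeout_us(21000, reo, 30000)); /* -4000: mark lost */
	printf("remaining %ld us\n",
	       rack_skb_timeout_us(21000, reo, 24000)); /* 2000: arm the reo timer */
	return 0;
}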
+ */ + remaining = tcp_rack_skb_timeout(tp, skb, reo_wnd); + if (remaining <= 0) { + tcp_mark_skb_lost(sk, skb); + list_del_init(&skb->tcp_tsorted_anchor); + } else { + /* Record maximum wait time */ + *reo_timeout = max_t(u32, *reo_timeout, remaining); + } + } +} + +bool tcp_rack_mark_lost(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 timeout; + + if (!tp->rack.advanced) + return false; + + /* Reset the advanced flag to avoid unnecessary queue scanning */ + tp->rack.advanced = 0; + tcp_rack_detect_loss(sk, &timeout); + if (timeout) { + timeout = usecs_to_jiffies(timeout + TCP_TIMEOUT_MIN_US); + inet_csk_reset_xmit_timer(sk, ICSK_TIME_REO_TIMEOUT, + timeout, inet_csk(sk)->icsk_rto); + } + return !!timeout; +} + +/* Record the most recently (re)sent time among the (s)acked packets + * This is "Step 3: Advance RACK.xmit_time and update RACK.RTT" from + * draft-cheng-tcpm-rack-00.txt + */ +void tcp_rack_advance(struct tcp_sock *tp, u8 sacked, u32 end_seq, + u64 xmit_time) +{ + u32 rtt_us; + + rtt_us = tcp_stamp_us_delta(tp->tcp_mstamp, xmit_time); + if (rtt_us < tcp_min_rtt(tp) && (sacked & TCPCB_RETRANS)) { + /* If the sacked packet was retransmitted, it's ambiguous + * whether the retransmission or the original (or the prior + * retransmission) was sacked. + * + * If the original is lost, there is no ambiguity. Otherwise + * we assume the original can be delayed up to aRTT + min_rtt. + * the aRTT term is bounded by the fast recovery or timeout, + * so it's at least one RTT (i.e., retransmission is at least + * an RTT later). + */ + return; + } + tp->rack.advanced = 1; + tp->rack.rtt_us = rtt_us; + if (tcp_skb_sent_after(xmit_time, tp->rack.mstamp, + end_seq, tp->rack.end_seq)) { + tp->rack.mstamp = xmit_time; + tp->rack.end_seq = end_seq; + } +} + +/* We have waited long enough to accommodate reordering. Mark the expired + * packets lost and retransmit them. + */ +void tcp_rack_reo_timeout(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + u32 timeout, prior_inflight; + u32 lost = tp->lost; + + prior_inflight = tcp_packets_in_flight(tp); + tcp_rack_detect_loss(sk, &timeout); + if (prior_inflight != tcp_packets_in_flight(tp)) { + if (inet_csk(sk)->icsk_ca_state != TCP_CA_Recovery) { + tcp_enter_recovery(sk, false); + if (!inet_csk(sk)->icsk_ca_ops->cong_control) + tcp_cwnd_reduction(sk, 1, tp->lost - lost, 0); + } + tcp_xmit_retransmit_queue(sk); + } + if (inet_csk(sk)->icsk_pending != ICSK_TIME_RETRANS) + tcp_rearm_rto(sk); +} + +/* Updates the RACK's reo_wnd based on DSACK and no. of recoveries. + * + * If a DSACK is received that seems like it may have been due to reordering + * triggering fast recovery, increment reo_wnd by min_rtt/4 (upper bounded + * by srtt), since there is possibility that spurious retransmission was + * due to reordering delay longer than reo_wnd. + * + * Persist the current reo_wnd value for TCP_RACK_RECOVERY_THRESH (16) + * no. of successful recoveries (accounts for full DSACK-based loss + * recovery undo). After that, reset it to default (min_rtt/4). + * + * At max, reo_wnd is incremented only once per rtt. So that the new + * DSACK on which we are reacting, is due to the spurious retx (approx) + * after the reo_wnd has been updated last time. + * + * reo_wnd is tracked in terms of steps (of min_rtt/4), rather than + * absolute value to account for change in rtt. 
+ */ +void tcp_rack_update_reo_wnd(struct sock *sk, struct rate_sample *rs) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if ((READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_recovery) & + TCP_RACK_STATIC_REO_WND) || + !rs->prior_delivered) + return; + + /* Disregard DSACK if a rtt has not passed since we adjusted reo_wnd */ + if (before(rs->prior_delivered, tp->rack.last_delivered)) + tp->rack.dsack_seen = 0; + + /* Adjust the reo_wnd if update is pending */ + if (tp->rack.dsack_seen) { + tp->rack.reo_wnd_steps = min_t(u32, 0xFF, + tp->rack.reo_wnd_steps + 1); + tp->rack.dsack_seen = 0; + tp->rack.last_delivered = tp->delivered; + tp->rack.reo_wnd_persist = TCP_RACK_RECOVERY_THRESH; + } else if (!tp->rack.reo_wnd_persist) { + tp->rack.reo_wnd_steps = 1; + } +} + +/* RFC6582 NewReno recovery for non-SACK connection. It simply retransmits + * the next unacked packet upon receiving + * a) three or more DUPACKs to start the fast recovery + * b) an ACK acknowledging new data during the fast recovery. + */ +void tcp_newreno_mark_lost(struct sock *sk, bool snd_una_advanced) +{ + const u8 state = inet_csk(sk)->icsk_ca_state; + struct tcp_sock *tp = tcp_sk(sk); + + if ((state < TCP_CA_Recovery && tp->sacked_out >= tp->reordering) || + (state == TCP_CA_Recovery && snd_una_advanced)) { + struct sk_buff *skb = tcp_rtx_queue_head(sk); + u32 mss; + + if (TCP_SKB_CB(skb)->sacked & TCPCB_LOST) + return; + + mss = tcp_skb_mss(skb); + if (tcp_skb_pcount(skb) > 1 && skb->len > mss) + tcp_fragment(sk, TCP_FRAG_IN_RTX_QUEUE, skb, + mss, mss, GFP_ATOMIC); + + tcp_mark_skb_lost(sk, skb); + } +} diff --git a/net/ipv4/tcp_scalable.c b/net/ipv4/tcp_scalable.c new file mode 100644 index 000000000..862b96248 --- /dev/null +++ b/net/ipv4/tcp_scalable.c @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Tom Kelly's Scalable TCP + * + * See http://www.deneholme.net/tom/scalable/ + * + * John Heffner + */ + +#include +#include + +/* These factors derived from the recommended values in the aer: + * .01 and 7/8. + */ +#define TCP_SCALABLE_AI_CNT 100U +#define TCP_SCALABLE_MD_SCALE 3 + +static void tcp_scalable_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + return; + } + tcp_cong_avoid_ai(tp, min(tcp_snd_cwnd(tp), TCP_SCALABLE_AI_CNT), + acked); +} + +static u32 tcp_scalable_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + + return max(tcp_snd_cwnd(tp) - (tcp_snd_cwnd(tp)>>TCP_SCALABLE_MD_SCALE), 2U); +} + +static struct tcp_congestion_ops tcp_scalable __read_mostly = { + .ssthresh = tcp_scalable_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_scalable_cong_avoid, + + .owner = THIS_MODULE, + .name = "scalable", +}; + +static int __init tcp_scalable_register(void) +{ + return tcp_register_congestion_control(&tcp_scalable); +} + +static void __exit tcp_scalable_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_scalable); +} + +module_init(tcp_scalable_register); +module_exit(tcp_scalable_unregister); + +MODULE_AUTHOR("John Heffner"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Scalable TCP"); diff --git a/net/ipv4/tcp_timer.c b/net/ipv4/tcp_timer.c new file mode 100644 index 000000000..44b49f7d1 --- /dev/null +++ b/net/ipv4/tcp_timer.c @@ -0,0 +1,819 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. 
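/* Illustrative userspace sketch, not part of the kernel sources: it restates the
 * two Scalable TCP rules from tcp_scalable.c earlier in this hunk. On loss the
 * window is cut by one eighth (the 7/8 factor), and in congestion avoidance one
 * segment is added per min(cwnd, 100) ACKed segments (the 0.01 factor).
 */
#include <stdio.h>

static unsigned int scalable_ssthresh(unsigned int cwnd)
{
	unsigned int ssthresh = cwnd - (cwnd >> 3); /* keep 7/8 of cwnd */

	return ssthresh > 2 ? ssthresh : 2;
}

static unsigned int scalable_ai_cnt(unsigned int cwnd)
{
	return cwnd < 100 ? cwnd : 100; /* ACKs needed per +1 segment of growth */
}

int main(void)
{
	printf("ssthresh(1000) = %u\n", scalable_ssthresh(1000)); /* 875 */
	printf("ai_cnt(1000)   = %u\n", scalable_ai_cnt(1000));   /* 100 */
	return 0;
}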
INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Implementation of the Transmission Control Protocol(TCP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Mark Evans, + * Corey Minyard + * Florian La Roche, + * Charles Hedrick, + * Linus Torvalds, + * Alan Cox, + * Matthew Dillon, + * Arnt Gulbrandsen, + * Jorge Cwik, + */ + +#include +#include +#include + +static u32 tcp_clamp_rto_to_user_timeout(const struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + u32 elapsed, start_ts; + s32 remaining; + + start_ts = tcp_sk(sk)->retrans_stamp; + if (!icsk->icsk_user_timeout) + return icsk->icsk_rto; + elapsed = tcp_time_stamp(tcp_sk(sk)) - start_ts; + remaining = icsk->icsk_user_timeout - elapsed; + if (remaining <= 0) + return 1; /* user timeout has passed; fire ASAP */ + + return min_t(u32, icsk->icsk_rto, msecs_to_jiffies(remaining)); +} + +u32 tcp_clamp_probe0_to_user_timeout(const struct sock *sk, u32 when) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + u32 remaining; + s32 elapsed; + + if (!icsk->icsk_user_timeout || !icsk->icsk_probes_tstamp) + return when; + + elapsed = tcp_jiffies32 - icsk->icsk_probes_tstamp; + if (unlikely(elapsed < 0)) + elapsed = 0; + remaining = msecs_to_jiffies(icsk->icsk_user_timeout) - elapsed; + remaining = max_t(u32, remaining, TCP_TIMEOUT_MIN); + + return min_t(u32, remaining, when); +} + +/** + * tcp_write_err() - close socket and save error info + * @sk: The socket the error has appeared on. + * + * Returns: Nothing (void) + */ + +static void tcp_write_err(struct sock *sk) +{ + sk->sk_err = sk->sk_err_soft ? : ETIMEDOUT; + sk_error_report(sk); + + tcp_write_queue_purge(sk); + tcp_done(sk); + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONTIMEOUT); +} + +/** + * tcp_out_of_resources() - Close socket if out of resources + * @sk: pointer to current socket + * @do_reset: send a last packet with reset flag + * + * Do not allow orphaned sockets to eat all our resources. + * This is direct violation of TCP specs, but it is required + * to prevent DoS attacks. It is called when a retransmission timeout + * or zero probe timeout occurs on orphaned socket. + * + * Also close if our net namespace is exiting; in that case there is no + * hope of ever communicating again since all netns interfaces are already + * down (or about to be down), and we need to release our dst references, + * which have been moved to the netns loopback interface, so the namespace + * can finish exiting. This condition is only possible if we are a kernel + * socket, as those do not hold references to the namespace. + * + * Criteria is still not confirmed experimentally and may change. + * We kill the socket, if: + * 1. If number of orphaned sockets exceeds an administratively configured + * limit. + * 2. If we have strong memory pressure. + * 3. If our net namespace is exiting. + */ +static int tcp_out_of_resources(struct sock *sk, bool do_reset) +{ + struct tcp_sock *tp = tcp_sk(sk); + int shift = 0; + + /* If peer does not open window for long time, or did not transmit + * anything for long time, penalize it. */ + if ((s32)(tcp_jiffies32 - tp->lsndtime) > 2*TCP_RTO_MAX || !do_reset) + shift++; + + /* If some dubious ICMP arrived, penalize even more. */ + if (sk->sk_err_soft) + shift++; + + if (tcp_check_oom(sk, shift)) { + /* Catch exceptional cases, when connection requires reset. + * 1. Last segment was sent recently. 
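/* Illustrative userspace sketch, not part of the kernel sources: both clamp
 * helpers earlier in this hunk follow the same pattern. They compute how much
 * of the user-supplied TCP_USER_TIMEOUT budget remains and never arm a timer
 * beyond it, firing almost immediately once the budget is spent. Values here
 * are in milliseconds for readability; the kernel mixes jiffies and msecs.
 */
#include <stdio.h>

static unsigned int clamp_to_user_timeout_ms(unsigned int rto_ms,
					     unsigned int user_timeout_ms,
					     unsigned int elapsed_ms)
{
	int remaining;

	if (!user_timeout_ms)
		return rto_ms;           /* no user timeout configured */
	remaining = (int)user_timeout_ms - (int)elapsed_ms;
	if (remaining <= 0)
		return 1;                /* budget spent: fire ASAP */
	return rto_ms < (unsigned int)remaining ? rto_ms : (unsigned int)remaining;
}

int main(void)
{
	printf("%u\n", clamp_to_user_timeout_ms(800, 5000, 4700)); /* 300: clamped */
	printf("%u\n", clamp_to_user_timeout_ms(800, 5000, 1000)); /* 800: unchanged */
	return 0;
}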
*/ + if ((s32)(tcp_jiffies32 - tp->lsndtime) <= TCP_TIMEWAIT_LEN || + /* 2. Window is closed. */ + (!tp->snd_wnd && !tp->packets_out)) + do_reset = true; + if (do_reset) + tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_done(sk); + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPABORTONMEMORY); + return 1; + } + + if (!check_net(sock_net(sk))) { + /* Not possible to send reset; just close */ + tcp_done(sk); + return 1; + } + + return 0; +} + +/** + * tcp_orphan_retries() - Returns maximal number of retries on an orphaned socket + * @sk: Pointer to the current socket. + * @alive: bool, socket alive state + */ +static int tcp_orphan_retries(struct sock *sk, bool alive) +{ + int retries = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_orphan_retries); /* May be zero. */ + + /* We know from an ICMP that something is wrong. */ + if (sk->sk_err_soft && !alive) + retries = 0; + + /* However, if socket sent something recently, select some safe + * number of retries. 8 corresponds to >100 seconds with minimal + * RTO of 200msec. */ + if (retries == 0 && alive) + retries = 8; + return retries; +} + +static void tcp_mtu_probing(struct inet_connection_sock *icsk, struct sock *sk) +{ + const struct net *net = sock_net(sk); + int mss; + + /* Black hole detection */ + if (!READ_ONCE(net->ipv4.sysctl_tcp_mtu_probing)) + return; + + if (!icsk->icsk_mtup.enabled) { + icsk->icsk_mtup.enabled = 1; + icsk->icsk_mtup.probe_timestamp = tcp_jiffies32; + } else { + mss = tcp_mtu_to_mss(sk, icsk->icsk_mtup.search_low) >> 1; + mss = min(READ_ONCE(net->ipv4.sysctl_tcp_base_mss), mss); + mss = max(mss, READ_ONCE(net->ipv4.sysctl_tcp_mtu_probe_floor)); + mss = max(mss, READ_ONCE(net->ipv4.sysctl_tcp_min_snd_mss)); + icsk->icsk_mtup.search_low = tcp_mss_to_mtu(sk, mss); + } + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); +} + +static unsigned int tcp_model_timeout(struct sock *sk, + unsigned int boundary, + unsigned int rto_base) +{ + unsigned int linear_backoff_thresh, timeout; + + linear_backoff_thresh = ilog2(TCP_RTO_MAX / rto_base); + if (boundary <= linear_backoff_thresh) + timeout = ((2 << boundary) - 1) * rto_base; + else + timeout = ((2 << linear_backoff_thresh) - 1) * rto_base + + (boundary - linear_backoff_thresh) * TCP_RTO_MAX; + return jiffies_to_msecs(timeout); +} +/** + * retransmits_timed_out() - returns true if this connection has timed out + * @sk: The current socket + * @boundary: max number of retransmissions + * @timeout: A custom timeout value. + * If set to 0 the default timeout is calculated and used. + * Using TCP_RTO_MIN and the number of unsuccessful retransmits. + * + * The default "timeout" value this function can calculate and use + * is equivalent to the timeout of a TCP Connection + * after "boundary" unsuccessful, exponentially backed-off + * retransmissions with an initial RTO of TCP_RTO_MIN. + */ +static bool retransmits_timed_out(struct sock *sk, + unsigned int boundary, + unsigned int timeout) +{ + unsigned int start_ts; + + if (!inet_csk(sk)->icsk_retransmits) + return false; + + start_ts = tcp_sk(sk)->retrans_stamp; + if (likely(timeout == 0)) { + unsigned int rto_base = TCP_RTO_MIN; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) + rto_base = tcp_timeout_init(sk); + timeout = tcp_model_timeout(sk, boundary, rto_base); + } + + return (s32)(tcp_time_stamp(tcp_sk(sk)) - start_ts - timeout) >= 0; +} + +/* A write timeout has occurred. Process the after effects. 
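/* Illustrative userspace sketch, not part of the kernel sources: it reproduces
 * the arithmetic of tcp_model_timeout() above. While the RTO can still double,
 * the total wait is ((2 << boundary) - 1) * rto_base; once the RTO saturates at
 * TCP_RTO_MAX, each further retry adds a flat TCP_RTO_MAX. The 200 ms initial
 * RTO and 120 s cap below are the conventional defaults, assumed for the
 * example.
 */
#include <stdio.h>

#define RTO_BASE_MS 200UL    /* stand-in for TCP_RTO_MIN */
#define RTO_MAX_MS  120000UL /* stand-in for TCP_RTO_MAX */

static unsigned long model_timeout_ms(unsigned int boundary)
{
	unsigned int thresh = 0;
	unsigned long ratio = RTO_MAX_MS / RTO_BASE_MS;

	while (ratio >>= 1)
		thresh++;        /* ilog2(TCP_RTO_MAX / rto_base) */

	if (boundary <= thresh)
		return ((2UL << boundary) - 1) * RTO_BASE_MS;
	return ((2UL << thresh) - 1) * RTO_BASE_MS +
	       (boundary - thresh) * RTO_MAX_MS;
}

int main(void)
{
	printf("%lu ms\n", model_timeout_ms(6));  /* 25400 ms: backoff not yet saturated */
	printf("%lu ms\n", model_timeout_ms(15)); /* 924600 ms, i.e. ~15.4 minutes */
	return 0;
}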
*/ +static int tcp_write_timeout(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + bool expired = false, do_reset; + int retry_until; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) { + if (icsk->icsk_retransmits) + __dst_negative_advice(sk); + retry_until = icsk->icsk_syn_retries ? : + READ_ONCE(net->ipv4.sysctl_tcp_syn_retries); + expired = icsk->icsk_retransmits >= retry_until; + } else { + if (retransmits_timed_out(sk, READ_ONCE(net->ipv4.sysctl_tcp_retries1), 0)) { + /* Black hole detection */ + tcp_mtu_probing(icsk, sk); + + __dst_negative_advice(sk); + } + + retry_until = READ_ONCE(net->ipv4.sysctl_tcp_retries2); + if (sock_flag(sk, SOCK_DEAD)) { + const bool alive = icsk->icsk_rto < TCP_RTO_MAX; + + retry_until = tcp_orphan_retries(sk, alive); + do_reset = alive || + !retransmits_timed_out(sk, retry_until, 0); + + if (tcp_out_of_resources(sk, do_reset)) + return 1; + } + } + if (!expired) + expired = retransmits_timed_out(sk, retry_until, + icsk->icsk_user_timeout); + tcp_fastopen_active_detect_blackhole(sk, expired); + + if (BPF_SOCK_OPS_TEST_FLAG(tp, BPF_SOCK_OPS_RTO_CB_FLAG)) + tcp_call_bpf_3arg(sk, BPF_SOCK_OPS_RTO_CB, + icsk->icsk_retransmits, + icsk->icsk_rto, (int)expired); + + if (expired) { + /* Has it gone just too far? */ + tcp_write_err(sk); + return 1; + } + + if (sk_rethink_txhash(sk)) { + tp->timeout_rehash++; + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPTIMEOUTREHASH); + } + + return 0; +} + +/* Called with BH disabled */ +void tcp_delack_timer_handler(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + + if ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) + return; + + /* Handling the sack compression case */ + if (tp->compressed_ack) { + tcp_mstamp_refresh(tp); + tcp_sack_compress_send_ack(sk); + return; + } + + if (!(icsk->icsk_ack.pending & ICSK_ACK_TIMER)) + return; + + if (time_after(icsk->icsk_ack.timeout, jiffies)) { + sk_reset_timer(sk, &icsk->icsk_delack_timer, icsk->icsk_ack.timeout); + return; + } + icsk->icsk_ack.pending &= ~ICSK_ACK_TIMER; + + if (inet_csk_ack_scheduled(sk)) { + if (!inet_csk_in_pingpong_mode(sk)) { + /* Delayed ACK missed: inflate ATO. */ + icsk->icsk_ack.ato = min(icsk->icsk_ack.ato << 1, icsk->icsk_rto); + } else { + /* Delayed ACK missed: leave pingpong mode and + * deflate ATO. + */ + inet_csk_exit_pingpong_mode(sk); + icsk->icsk_ack.ato = TCP_ATO_MIN; + } + tcp_mstamp_refresh(tp); + tcp_send_ack(sk); + __NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKS); + } +} + + +/** + * tcp_delack_timer() - The TCP delayed ACK timeout handler + * @t: Pointer to the timer. (gets casted to struct sock *) + * + * This function gets (indirectly) called when the kernel timer for a TCP packet + * of this socket expires. Calls tcp_delack_timer_handler() to do the actual work. 
+ * + * Returns: Nothing (void) + */ +static void tcp_delack_timer(struct timer_list *t) +{ + struct inet_connection_sock *icsk = + from_timer(icsk, t, icsk_delack_timer); + struct sock *sk = &icsk->icsk_inet.sk; + + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + tcp_delack_timer_handler(sk); + } else { + __NET_INC_STATS(sock_net(sk), LINUX_MIB_DELAYEDACKLOCKED); + /* deleguate our work to tcp_release_cb() */ + if (!test_and_set_bit(TCP_DELACK_TIMER_DEFERRED, &sk->sk_tsq_flags)) + sock_hold(sk); + } + bh_unlock_sock(sk); + sock_put(sk); +} + +static void tcp_probe_timer(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct sk_buff *skb = tcp_send_head(sk); + struct tcp_sock *tp = tcp_sk(sk); + int max_probes; + + if (tp->packets_out || !skb) { + icsk->icsk_probes_out = 0; + icsk->icsk_probes_tstamp = 0; + return; + } + + /* RFC 1122 4.2.2.17 requires the sender to stay open indefinitely as + * long as the receiver continues to respond probes. We support this by + * default and reset icsk_probes_out with incoming ACKs. But if the + * socket is orphaned or the user specifies TCP_USER_TIMEOUT, we + * kill the socket when the retry count and the time exceeds the + * corresponding system limit. We also implement similar policy when + * we use RTO to probe window in tcp_retransmit_timer(). + */ + if (!icsk->icsk_probes_tstamp) + icsk->icsk_probes_tstamp = tcp_jiffies32; + else if (icsk->icsk_user_timeout && + (s32)(tcp_jiffies32 - icsk->icsk_probes_tstamp) >= + msecs_to_jiffies(icsk->icsk_user_timeout)) + goto abort; + + max_probes = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_retries2); + if (sock_flag(sk, SOCK_DEAD)) { + const bool alive = inet_csk_rto_backoff(icsk, TCP_RTO_MAX) < TCP_RTO_MAX; + + max_probes = tcp_orphan_retries(sk, alive); + if (!alive && icsk->icsk_backoff >= max_probes) + goto abort; + if (tcp_out_of_resources(sk, true)) + return; + } + + if (icsk->icsk_probes_out >= max_probes) { +abort: tcp_write_err(sk); + } else { + /* Only send another probe if we didn't close things up. */ + tcp_send_probe0(sk); + } +} + +/* + * Timer for Fast Open socket to retransmit SYNACK. Note that the + * sk here is the child socket, not the parent (listener) socket. + */ +static void tcp_fastopen_synack_timer(struct sock *sk, struct request_sock *req) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + int max_retries; + + req->rsk_ops->syn_ack_timeout(req); + + /* add one more retry for fastopen */ + max_retries = icsk->icsk_syn_retries ? : + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_synack_retries) + 1; + + if (req->num_timeout >= max_retries) { + tcp_write_err(sk); + return; + } + /* Lower cwnd after certain SYNACK timeout like tcp_init_transfer() */ + if (icsk->icsk_retransmits == 1) + tcp_enter_loss(sk); + /* XXX (TFO) - Unlike regular SYN-ACK retransmit, we ignore error + * returned from rtx_syn_ack() to make it more persistent like + * regular retransmit because if the child socket has been accepted + * it's not good to give up too easily. 
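The tcp_probe_timer() comment above mentions TCP_USER_TIMEOUT, the per-socket knob that replaces the count-based limits with a wall-clock bound. A minimal user-space sketch of setting it (illustrative only, not part of this hunk; it assumes a libc whose netinet/tcp.h exposes TCP_USER_TIMEOUT, and the 30 s value is arbitrary):

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <sys/socket.h>

int main(void)
{
        int fd = socket(AF_INET, SOCK_STREAM, 0);
        unsigned int timeout_ms = 30 * 1000;    /* give up after 30 s without progress */

        if (fd < 0) {
                perror("socket");
                return 1;
        }
        /* Value is in milliseconds; 0 restores the retransmission-count policy. */
        if (setsockopt(fd, IPPROTO_TCP, TCP_USER_TIMEOUT,
                       &timeout_ms, sizeof(timeout_ms)) < 0)
                perror("setsockopt(TCP_USER_TIMEOUT)");
        return 0;
}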
+ */ + inet_rtx_syn_ack(sk, req); + req->num_timeout++; + icsk->icsk_retransmits++; + if (!tp->retrans_stamp) + tp->retrans_stamp = tcp_time_stamp(tp); + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + req->timeout << req->num_timeout, TCP_RTO_MAX); +} + +static bool tcp_rtx_probe0_timed_out(const struct sock *sk, + const struct sk_buff *skb) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const int timeout = TCP_RTO_MAX * 2; + u32 rcv_delta, rtx_delta; + + rcv_delta = inet_csk(sk)->icsk_timeout - tp->rcv_tstamp; + if (rcv_delta <= timeout) + return false; + + rtx_delta = (u32)msecs_to_jiffies(tcp_time_stamp(tp) - + (tp->retrans_stamp ?: tcp_skb_timestamp(skb))); + + return rtx_delta > timeout; +} + +/** + * tcp_retransmit_timer() - The TCP retransmit timeout handler + * @sk: Pointer to the current socket. + * + * This function gets called when the kernel timer for a TCP packet + * of this socket expires. + * + * It handles retransmission, timer adjustment and other necessary measures. + * + * Returns: Nothing (void) + */ +void tcp_retransmit_timer(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + struct request_sock *req; + struct sk_buff *skb; + + req = rcu_dereference_protected(tp->fastopen_rsk, + lockdep_sock_is_held(sk)); + if (req) { + WARN_ON_ONCE(sk->sk_state != TCP_SYN_RECV && + sk->sk_state != TCP_FIN_WAIT1); + tcp_fastopen_synack_timer(sk, req); + /* Before we receive ACK to our SYN-ACK don't retransmit + * anything else (e.g., data or FIN segments). + */ + return; + } + + if (!tp->packets_out) + return; + + skb = tcp_rtx_queue_head(sk); + if (WARN_ON_ONCE(!skb)) + return; + + tp->tlp_high_seq = 0; + + if (!tp->snd_wnd && !sock_flag(sk, SOCK_DEAD) && + !((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV))) { + /* Receiver dastardly shrinks window. Our retransmits + * become zero probes, but we should not timeout this + * connection. If the socket is an orphan, time it out, + * we cannot allow such beasts to hang infinitely. 
+ */ + struct inet_sock *inet = inet_sk(sk); + if (sk->sk_family == AF_INET) { + net_dbg_ratelimited("Peer %pI4:%u/%u unexpectedly shrunk window %u:%u (repaired)\n", + &inet->inet_daddr, + ntohs(inet->inet_dport), + inet->inet_num, + tp->snd_una, tp->snd_nxt); + } +#if IS_ENABLED(CONFIG_IPV6) + else if (sk->sk_family == AF_INET6) { + net_dbg_ratelimited("Peer %pI6:%u/%u unexpectedly shrunk window %u:%u (repaired)\n", + &sk->sk_v6_daddr, + ntohs(inet->inet_dport), + inet->inet_num, + tp->snd_una, tp->snd_nxt); + } +#endif + if (tcp_rtx_probe0_timed_out(sk, skb)) { + tcp_write_err(sk); + goto out; + } + tcp_enter_loss(sk); + tcp_retransmit_skb(sk, skb, 1); + __sk_dst_reset(sk); + goto out_reset_timer; + } + + __NET_INC_STATS(sock_net(sk), LINUX_MIB_TCPTIMEOUTS); + if (tcp_write_timeout(sk)) + goto out; + + if (icsk->icsk_retransmits == 0) { + int mib_idx = 0; + + if (icsk->icsk_ca_state == TCP_CA_Recovery) { + if (tcp_is_sack(tp)) + mib_idx = LINUX_MIB_TCPSACKRECOVERYFAIL; + else + mib_idx = LINUX_MIB_TCPRENORECOVERYFAIL; + } else if (icsk->icsk_ca_state == TCP_CA_Loss) { + mib_idx = LINUX_MIB_TCPLOSSFAILURES; + } else if ((icsk->icsk_ca_state == TCP_CA_Disorder) || + tp->sacked_out) { + if (tcp_is_sack(tp)) + mib_idx = LINUX_MIB_TCPSACKFAILURES; + else + mib_idx = LINUX_MIB_TCPRENOFAILURES; + } + if (mib_idx) + __NET_INC_STATS(sock_net(sk), mib_idx); + } + + tcp_enter_loss(sk); + + icsk->icsk_retransmits++; + if (tcp_retransmit_skb(sk, tcp_rtx_queue_head(sk), 1) > 0) { + /* Retransmission failed because of local congestion, + * Let senders fight for local resources conservatively. + */ + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + TCP_RESOURCE_PROBE_INTERVAL, + TCP_RTO_MAX); + goto out; + } + + /* Increase the timeout each time we retransmit. Note that + * we do not increase the rtt estimate. rto is initialized + * from rtt, but increases here. Jacobson (SIGCOMM 88) suggests + * that doubling rto each time is the least we can get away with. + * In KA9Q, Karn uses this for the first few times, and then + * goes to quadratic. netBSD doubles, but only goes up to *64, + * and clamps at 1 to 64 sec afterwards. Note that 120 sec is + * defined in the protocol as the maximum possible RTT. I guess + * we'll have to use something other than TCP to talk to the + * University of Mars. + * + * PAWS allows us longer timeouts and large windows, so once + * implemented ftp to mars will work nicely. We will have to fix + * the 120 second clamps though! + */ + icsk->icsk_backoff++; + +out_reset_timer: + /* If stream is thin, use linear timeouts. Since 'icsk_backoff' is + * used to reset timer, set to 0. Recalculate 'icsk_rto' as this + * might be increased if the stream oscillates between thin and thick, + * thus the old value might already be too high compared to the value + * set by 'tcp_set_rto' in tcp_input.c which resets the rto without + * backoff. 
Limit to TCP_THIN_LINEAR_RETRIES before initiating + * exponential backoff behaviour to avoid continue hammering + * linear-timeout retransmissions into a black hole + */ + if (sk->sk_state == TCP_ESTABLISHED && + (tp->thin_lto || READ_ONCE(net->ipv4.sysctl_tcp_thin_linear_timeouts)) && + tcp_stream_is_thin(tp) && + icsk->icsk_retransmits <= TCP_THIN_LINEAR_RETRIES) { + icsk->icsk_backoff = 0; + icsk->icsk_rto = clamp(__tcp_set_rto(tp), + tcp_rto_min(sk), + TCP_RTO_MAX); + } else { + /* Use normal (exponential) backoff */ + icsk->icsk_rto = min(icsk->icsk_rto << 1, TCP_RTO_MAX); + } + inet_csk_reset_xmit_timer(sk, ICSK_TIME_RETRANS, + tcp_clamp_rto_to_user_timeout(sk), TCP_RTO_MAX); + if (retransmits_timed_out(sk, READ_ONCE(net->ipv4.sysctl_tcp_retries1) + 1, 0)) + __sk_dst_reset(sk); + +out:; +} + +/* Called with bottom-half processing disabled. + Called by tcp_write_timer() */ +void tcp_write_timer_handler(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + int event; + + if (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) || + !icsk->icsk_pending) + return; + + if (time_after(icsk->icsk_timeout, jiffies)) { + sk_reset_timer(sk, &icsk->icsk_retransmit_timer, icsk->icsk_timeout); + return; + } + + tcp_mstamp_refresh(tcp_sk(sk)); + event = icsk->icsk_pending; + + switch (event) { + case ICSK_TIME_REO_TIMEOUT: + tcp_rack_reo_timeout(sk); + break; + case ICSK_TIME_LOSS_PROBE: + tcp_send_loss_probe(sk); + break; + case ICSK_TIME_RETRANS: + icsk->icsk_pending = 0; + tcp_retransmit_timer(sk); + break; + case ICSK_TIME_PROBE0: + icsk->icsk_pending = 0; + tcp_probe_timer(sk); + break; + } +} + +static void tcp_write_timer(struct timer_list *t) +{ + struct inet_connection_sock *icsk = + from_timer(icsk, t, icsk_retransmit_timer); + struct sock *sk = &icsk->icsk_inet.sk; + + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + tcp_write_timer_handler(sk); + } else { + /* delegate our work to tcp_release_cb() */ + if (!test_and_set_bit(TCP_WRITE_TIMER_DEFERRED, &sk->sk_tsq_flags)) + sock_hold(sk); + } + bh_unlock_sock(sk); + sock_put(sk); +} + +void tcp_syn_ack_timeout(const struct request_sock *req) +{ + struct net *net = read_pnet(&inet_rsk(req)->ireq_net); + + __NET_INC_STATS(net, LINUX_MIB_TCPTIMEOUTS); +} +EXPORT_SYMBOL(tcp_syn_ack_timeout); + +void tcp_set_keepalive(struct sock *sk, int val) +{ + if ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) + return; + + if (val && !sock_flag(sk, SOCK_KEEPOPEN)) + inet_csk_reset_keepalive_timer(sk, keepalive_time_when(tcp_sk(sk))); + else if (!val) + inet_csk_delete_keepalive_timer(sk); +} +EXPORT_SYMBOL_GPL(tcp_set_keepalive); + + +static void tcp_keepalive_timer (struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + struct inet_connection_sock *icsk = inet_csk(sk); + struct tcp_sock *tp = tcp_sk(sk); + u32 elapsed; + + /* Only process if socket is not in use. */ + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + /* Try again later. */ + inet_csk_reset_keepalive_timer (sk, HZ/20); + goto out; + } + + if (sk->sk_state == TCP_LISTEN) { + pr_err("Hmm... 
keepalive on a LISTEN ???\n"); + goto out; + } + + tcp_mstamp_refresh(tp); + if (sk->sk_state == TCP_FIN_WAIT2 && sock_flag(sk, SOCK_DEAD)) { + if (tp->linger2 >= 0) { + const int tmo = tcp_fin_time(sk) - TCP_TIMEWAIT_LEN; + + if (tmo > 0) { + tcp_time_wait(sk, TCP_FIN_WAIT2, tmo); + goto out; + } + } + tcp_send_active_reset(sk, GFP_ATOMIC); + goto death; + } + + if (!sock_flag(sk, SOCK_KEEPOPEN) || + ((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_SYN_SENT))) + goto out; + + elapsed = keepalive_time_when(tp); + + /* It is alive without keepalive 8) */ + if (tp->packets_out || !tcp_write_queue_empty(sk)) + goto resched; + + elapsed = keepalive_time_elapsed(tp); + + if (elapsed >= keepalive_time_when(tp)) { + /* If the TCP_USER_TIMEOUT option is enabled, use that + * to determine when to timeout instead. + */ + if ((icsk->icsk_user_timeout != 0 && + elapsed >= msecs_to_jiffies(icsk->icsk_user_timeout) && + icsk->icsk_probes_out > 0) || + (icsk->icsk_user_timeout == 0 && + icsk->icsk_probes_out >= keepalive_probes(tp))) { + tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_write_err(sk); + goto out; + } + if (tcp_write_wakeup(sk, LINUX_MIB_TCPKEEPALIVE) <= 0) { + icsk->icsk_probes_out++; + elapsed = keepalive_intvl_when(tp); + } else { + /* If keepalive was lost due to local congestion, + * try harder. + */ + elapsed = TCP_RESOURCE_PROBE_INTERVAL; + } + } else { + /* It is tp->rcv_tstamp + keepalive_time_when(tp) */ + elapsed = keepalive_time_when(tp) - elapsed; + } + +resched: + inet_csk_reset_keepalive_timer (sk, elapsed); + goto out; + +death: + tcp_done(sk); + +out: + bh_unlock_sock(sk); + sock_put(sk); +} + +static enum hrtimer_restart tcp_compressed_ack_kick(struct hrtimer *timer) +{ + struct tcp_sock *tp = container_of(timer, struct tcp_sock, compressed_ack_timer); + struct sock *sk = (struct sock *)tp; + + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + if (tp->compressed_ack) { + /* Since we have to send one ack finally, + * subtract one from tp->compressed_ack to keep + * LINUX_MIB_TCPACKCOMPRESSED accurate. + */ + tp->compressed_ack--; + tcp_send_ack(sk); + } + } else { + if (!test_and_set_bit(TCP_DELACK_TIMER_DEFERRED, + &sk->sk_tsq_flags)) + sock_hold(sk); + } + bh_unlock_sock(sk); + + sock_put(sk); + + return HRTIMER_NORESTART; +} + +void tcp_init_xmit_timers(struct sock *sk) +{ + inet_csk_init_xmit_timers(sk, &tcp_write_timer, &tcp_delack_timer, + &tcp_keepalive_timer); + hrtimer_init(&tcp_sk(sk)->pacing_timer, CLOCK_MONOTONIC, + HRTIMER_MODE_ABS_PINNED_SOFT); + tcp_sk(sk)->pacing_timer.function = tcp_pace_kick; + + hrtimer_init(&tcp_sk(sk)->compressed_ack_timer, CLOCK_MONOTONIC, + HRTIMER_MODE_REL_PINNED_SOFT); + tcp_sk(sk)->compressed_ack_timer.function = tcp_compressed_ack_kick; +} diff --git a/net/ipv4/tcp_ulp.c b/net/ipv4/tcp_ulp.c new file mode 100644 index 000000000..2aa442128 --- /dev/null +++ b/net/ipv4/tcp_ulp.c @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Pluggable TCP upper layer protocol support. + * + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * + */ + +#include +#include +#include +#include +#include +#include + +static DEFINE_SPINLOCK(tcp_ulp_list_lock); +static LIST_HEAD(tcp_ulp_list); + +/* Simple linear search, don't expect many entries! 
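The header of tcp_ulp.c above introduces a by-name registry of upper layer protocols; the linear search just below is what resolves the name that user space passes in. A minimal socket-side usage sketch (illustrative only, not part of this hunk; it assumes a libc whose netinet/tcp.h defines TCP_ULP and uses the in-tree "tls" ULP purely as an example name):

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <sys/socket.h>

int main(void)
{
        int fd = socket(AF_INET, SOCK_STREAM, 0);

        if (fd < 0) {
                perror("socket");
                return 1;
        }
        /* Ask the kernel to attach the "tls" ULP. A name with no registered
         * handler may trigger request_module("tcp-ulp-<name>") and then fail
         * with ENOENT; kTLS itself also wants an established connection, so
         * this call is expected to fail on an unconnected socket - the point
         * here is only the shape of the API that tcp_set_ulp() serves.
         */
        if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")) < 0)
                perror("setsockopt(TCP_ULP)");
        return 0;
}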
*/ +static struct tcp_ulp_ops *tcp_ulp_find(const char *name) +{ + struct tcp_ulp_ops *e; + + list_for_each_entry_rcu(e, &tcp_ulp_list, list, + lockdep_is_held(&tcp_ulp_list_lock)) { + if (strcmp(e->name, name) == 0) + return e; + } + + return NULL; +} + +static const struct tcp_ulp_ops *__tcp_ulp_find_autoload(const char *name) +{ + const struct tcp_ulp_ops *ulp = NULL; + + rcu_read_lock(); + ulp = tcp_ulp_find(name); + +#ifdef CONFIG_MODULES + if (!ulp && capable(CAP_NET_ADMIN)) { + rcu_read_unlock(); + request_module("tcp-ulp-%s", name); + rcu_read_lock(); + ulp = tcp_ulp_find(name); + } +#endif + if (!ulp || !try_module_get(ulp->owner)) + ulp = NULL; + + rcu_read_unlock(); + return ulp; +} + +/* Attach new upper layer protocol to the list + * of available protocols. + */ +int tcp_register_ulp(struct tcp_ulp_ops *ulp) +{ + int ret = 0; + + spin_lock(&tcp_ulp_list_lock); + if (tcp_ulp_find(ulp->name)) + ret = -EEXIST; + else + list_add_tail_rcu(&ulp->list, &tcp_ulp_list); + spin_unlock(&tcp_ulp_list_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(tcp_register_ulp); + +void tcp_unregister_ulp(struct tcp_ulp_ops *ulp) +{ + spin_lock(&tcp_ulp_list_lock); + list_del_rcu(&ulp->list); + spin_unlock(&tcp_ulp_list_lock); + + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(tcp_unregister_ulp); + +/* Build string with list of available upper layer protocl values */ +void tcp_get_available_ulp(char *buf, size_t maxlen) +{ + struct tcp_ulp_ops *ulp_ops; + size_t offs = 0; + + *buf = '\0'; + rcu_read_lock(); + list_for_each_entry_rcu(ulp_ops, &tcp_ulp_list, list) { + offs += snprintf(buf + offs, maxlen - offs, + "%s%s", + offs == 0 ? "" : " ", ulp_ops->name); + + if (WARN_ON_ONCE(offs >= maxlen)) + break; + } + rcu_read_unlock(); +} + +void tcp_update_ulp(struct sock *sk, struct proto *proto, + void (*write_space)(struct sock *sk)) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + if (icsk->icsk_ulp_ops->update) + icsk->icsk_ulp_ops->update(sk, proto, write_space); +} + +void tcp_cleanup_ulp(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + /* No sock_owned_by_me() check here as at the time the + * stack calls this function, the socket is dead and + * about to be destroyed. + */ + if (!icsk->icsk_ulp_ops) + return; + + if (icsk->icsk_ulp_ops->release) + icsk->icsk_ulp_ops->release(sk); + module_put(icsk->icsk_ulp_ops->owner); + + icsk->icsk_ulp_ops = NULL; +} + +static int __tcp_set_ulp(struct sock *sk, const struct tcp_ulp_ops *ulp_ops) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + int err; + + err = -EEXIST; + if (icsk->icsk_ulp_ops) + goto out_err; + + if (sk->sk_socket) + clear_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + + err = -ENOTCONN; + if (!ulp_ops->clone && sk->sk_state == TCP_LISTEN) + goto out_err; + + err = ulp_ops->init(sk); + if (err) + goto out_err; + + icsk->icsk_ulp_ops = ulp_ops; + return 0; +out_err: + module_put(ulp_ops->owner); + return err; +} + +int tcp_set_ulp(struct sock *sk, const char *name) +{ + const struct tcp_ulp_ops *ulp_ops; + + sock_owned_by_me(sk); + + ulp_ops = __tcp_ulp_find_autoload(name); + if (!ulp_ops) + return -ENOENT; + + return __tcp_set_ulp(sk, ulp_ops); +} diff --git a/net/ipv4/tcp_vegas.c b/net/ipv4/tcp_vegas.c new file mode 100644 index 000000000..786848ad3 --- /dev/null +++ b/net/ipv4/tcp_vegas.c @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP Vegas congestion control + * + * This is based on the congestion detection/avoidance scheme described in + * Lawrence S. Brakmo and Larry L. 
Peterson. + * "TCP Vegas: End to end congestion avoidance on a global internet." + * IEEE Journal on Selected Areas in Communication, 13(8):1465--1480, + * October 1995. Available from: + * ftp://ftp.cs.arizona.edu/xkernel/Papers/jsac.ps + * + * See http://www.cs.arizona.edu/xkernel/ for their implementation. + * The main aspects that distinguish this implementation from the + * Arizona Vegas implementation are: + * o We do not change the loss detection or recovery mechanisms of + * Linux in any way. Linux already recovers from losses quite well, + * using fine-grained timers, NewReno, and FACK. + * o To avoid the performance penalty imposed by increasing cwnd + * only every-other RTT during slow start, we increase during + * every RTT during slow start, just like Reno. + * o Largely to allow continuous cwnd growth during slow start, + * we use the rate at which ACKs come back as the "actual" + * rate, rather than the rate at which data is sent. + * o To speed convergence to the right rate, we set the cwnd + * to achieve the right ("actual") rate when we exit slow start. + * o To filter out the noise caused by delayed ACKs, we use the + * minimum RTT sample observed during the last RTT to calculate + * the actual rate. + * o When the sender re-starts from idle, it waits until it has + * received ACKs for an entire flight of new data before making + * a cwnd adjustment decision. The original Vegas implementation + * assumed senders never went idle. + */ + +#include +#include +#include +#include + +#include + +#include "tcp_vegas.h" + +static int alpha = 2; +static int beta = 4; +static int gamma = 1; + +module_param(alpha, int, 0644); +MODULE_PARM_DESC(alpha, "lower bound of packets in network"); +module_param(beta, int, 0644); +MODULE_PARM_DESC(beta, "upper bound of packets in network"); +module_param(gamma, int, 0644); +MODULE_PARM_DESC(gamma, "limit on increase (scale by 2)"); + +/* There are several situations when we must "re-start" Vegas: + * + * o when a connection is established + * o after an RTO + * o after fast recovery + * o when we send a packet and there is no outstanding + * unacknowledged data (restarting an idle connection) + * + * In these circumstances we cannot do a Vegas calculation at the + * end of the first RTT, because any calculation we do is using + * stale info -- both the saved cwnd and congestion feedback are + * stale. + * + * Instead we must wait until the completion of an RTT during + * which we actually receive ACKs. + */ +static void vegas_enable(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct vegas *vegas = inet_csk_ca(sk); + + /* Begin taking Vegas samples next time we send something. */ + vegas->doing_vegas_now = 1; + + /* Set the beginning of the next send window. */ + vegas->beg_snd_nxt = tp->snd_nxt; + + vegas->cntRTT = 0; + vegas->minRTT = 0x7fffffff; +} + +/* Stop taking Vegas samples for now. */ +static inline void vegas_disable(struct sock *sk) +{ + struct vegas *vegas = inet_csk_ca(sk); + + vegas->doing_vegas_now = 0; +} + +void tcp_vegas_init(struct sock *sk) +{ + struct vegas *vegas = inet_csk_ca(sk); + + vegas->baseRTT = 0x7fffffff; + vegas_enable(sk); +} +EXPORT_SYMBOL_GPL(tcp_vegas_init); + +/* Do RTT sampling needed for Vegas. 
+ * Basically we: + * o min-filter RTT samples from within an RTT to get the current + * propagation delay + queuing delay (we are min-filtering to try to + * avoid the effects of delayed ACKs) + * o min-filter RTT samples from a much longer window (forever for now) + * to find the propagation delay (baseRTT) + */ +void tcp_vegas_pkts_acked(struct sock *sk, const struct ack_sample *sample) +{ + struct vegas *vegas = inet_csk_ca(sk); + u32 vrtt; + + if (sample->rtt_us < 0) + return; + + /* Never allow zero rtt or baseRTT */ + vrtt = sample->rtt_us + 1; + + /* Filter to find propagation delay: */ + if (vrtt < vegas->baseRTT) + vegas->baseRTT = vrtt; + + /* Find the min RTT during the last RTT to find + * the current prop. delay + queuing delay: + */ + vegas->minRTT = min(vegas->minRTT, vrtt); + vegas->cntRTT++; +} +EXPORT_SYMBOL_GPL(tcp_vegas_pkts_acked); + +void tcp_vegas_state(struct sock *sk, u8 ca_state) +{ + if (ca_state == TCP_CA_Open) + vegas_enable(sk); + else + vegas_disable(sk); +} +EXPORT_SYMBOL_GPL(tcp_vegas_state); + +/* + * If the connection is idle and we are restarting, + * then we don't want to do any Vegas calculations + * until we get fresh RTT samples. So when we + * restart, we reset our Vegas state to a clean + * slate. After we get acks for this flight of + * packets, _then_ we can make Vegas calculations + * again. + */ +void tcp_vegas_cwnd_event(struct sock *sk, enum tcp_ca_event event) +{ + if (event == CA_EVENT_CWND_RESTART || + event == CA_EVENT_TX_START) + tcp_vegas_init(sk); +} +EXPORT_SYMBOL_GPL(tcp_vegas_cwnd_event); + +static inline u32 tcp_vegas_ssthresh(struct tcp_sock *tp) +{ + return min(tp->snd_ssthresh, tcp_snd_cwnd(tp)); +} + +static void tcp_vegas_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct vegas *vegas = inet_csk_ca(sk); + + if (!vegas->doing_vegas_now) { + tcp_reno_cong_avoid(sk, ack, acked); + return; + } + + if (after(ack, vegas->beg_snd_nxt)) { + /* Do the Vegas once-per-RTT cwnd adjustment. */ + + /* Save the extent of the current window so we can use this + * at the end of the next RTT. + */ + vegas->beg_snd_nxt = tp->snd_nxt; + + /* We do the Vegas calculations only if we got enough RTT + * samples that we can be reasonably sure that we got + * at least one RTT sample that wasn't from a delayed ACK. + * If we only had 2 samples total, + * then that means we're getting only 1 ACK per RTT, which + * means they're almost certainly delayed ACKs. + * If we have 3 samples, we should be OK. + */ + + if (vegas->cntRTT <= 2) { + /* We don't have enough RTT samples to do the Vegas + * calculation, so we'll behave like Reno. + */ + tcp_reno_cong_avoid(sk, ack, acked); + } else { + u32 rtt, diff; + u64 target_cwnd; + + /* We have enough RTT samples, so, using the Vegas + * algorithm, we determine if we should increase or + * decrease cwnd, and by how much. + */ + + /* Pluck out the RTT we are using for the Vegas + * calculations. This is the min RTT seen during the + * last RTT. Taking the min filters out the effects + * of delayed ACKs, at the cost of noticing congestion + * a bit later. + */ + rtt = vegas->minRTT; + + /* Calculate the cwnd we should have, if we weren't + * going too fast. + * + * This is: + * (actual rate in segments) * baseRTT + */ + target_cwnd = (u64)tcp_snd_cwnd(tp) * vegas->baseRTT; + do_div(target_cwnd, rtt); + + /* Calculate the difference between the window we had, + * and the window we would like to have. This quantity + * is the "Diff" from the Arizona Vegas papers. 
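The "Diff" described above is the estimated number of segments queued in the network: the gap between the window actually in use and the window that the base RTT would support. A stand-alone sketch of the same arithmetic (illustrative only, not part of this hunk; alpha and beta mirror the module defaults):

#include <stdio.h>

/* Module defaults from tcp_vegas.c. */
#define VEGAS_ALPHA     2
#define VEGAS_BETA      4

/* Estimated segments queued in the network: cwnd * (rtt - baseRTT) / baseRTT. */
static unsigned int vegas_diff(unsigned int cwnd, unsigned int rtt_us,
                               unsigned int base_rtt_us)
{
        return cwnd * (rtt_us - base_rtt_us) / base_rtt_us;
}

int main(void)
{
        /* cwnd = 20 segments, baseRTT = 100 ms, min RTT this window = 110 ms:
         * diff = 20 * 10000 / 100000 = 2 segments queued, which lies inside
         * [alpha, beta], so congestion avoidance would leave cwnd alone.
         */
        unsigned int diff = vegas_diff(20, 110000, 100000);

        if (diff > VEGAS_BETA)
                puts("too fast: decrease cwnd by 1");
        else if (diff < VEGAS_ALPHA)
                puts("too slow: increase cwnd by 1");
        else
                puts("on target: keep cwnd");
        return 0;
}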
+ */ + diff = tcp_snd_cwnd(tp) * (rtt-vegas->baseRTT) / vegas->baseRTT; + + if (diff > gamma && tcp_in_slow_start(tp)) { + /* Going too fast. Time to slow down + * and switch to congestion avoidance. + */ + + /* Set cwnd to match the actual rate + * exactly: + * cwnd = (actual rate) * baseRTT + * Then we add 1 because the integer + * truncation robs us of full link + * utilization. + */ + tcp_snd_cwnd_set(tp, min(tcp_snd_cwnd(tp), + (u32)target_cwnd + 1)); + tp->snd_ssthresh = tcp_vegas_ssthresh(tp); + + } else if (tcp_in_slow_start(tp)) { + /* Slow start. */ + tcp_slow_start(tp, acked); + } else { + /* Congestion avoidance. */ + + /* Figure out where we would like cwnd + * to be. + */ + if (diff > beta) { + /* The old window was too fast, so + * we slow down. + */ + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - 1); + tp->snd_ssthresh + = tcp_vegas_ssthresh(tp); + } else if (diff < alpha) { + /* We don't have enough extra packets + * in the network, so speed up. + */ + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + } else { + /* Sending just as fast as we + * should be. + */ + } + } + + if (tcp_snd_cwnd(tp) < 2) + tcp_snd_cwnd_set(tp, 2); + else if (tcp_snd_cwnd(tp) > tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tp->snd_cwnd_clamp); + + tp->snd_ssthresh = tcp_current_ssthresh(sk); + } + + /* Wipe the slate clean for the next RTT. */ + vegas->cntRTT = 0; + vegas->minRTT = 0x7fffffff; + } + /* Use normal slow start */ + else if (tcp_in_slow_start(tp)) + tcp_slow_start(tp, acked); +} + +/* Extract info for Tcp socket info provided via netlink. */ +size_t tcp_vegas_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + const struct vegas *ca = inet_csk_ca(sk); + + if (ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + info->vegas.tcpv_enabled = ca->doing_vegas_now; + info->vegas.tcpv_rttcnt = ca->cntRTT; + info->vegas.tcpv_rtt = ca->baseRTT; + info->vegas.tcpv_minrtt = ca->minRTT; + + *attr = INET_DIAG_VEGASINFO; + return sizeof(struct tcpvegas_info); + } + return 0; +} +EXPORT_SYMBOL_GPL(tcp_vegas_get_info); + +static struct tcp_congestion_ops tcp_vegas __read_mostly = { + .init = tcp_vegas_init, + .ssthresh = tcp_reno_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_vegas_cong_avoid, + .pkts_acked = tcp_vegas_pkts_acked, + .set_state = tcp_vegas_state, + .cwnd_event = tcp_vegas_cwnd_event, + .get_info = tcp_vegas_get_info, + + .owner = THIS_MODULE, + .name = "vegas", +}; + +static int __init tcp_vegas_register(void) +{ + BUILD_BUG_ON(sizeof(struct vegas) > ICSK_CA_PRIV_SIZE); + tcp_register_congestion_control(&tcp_vegas); + return 0; +} + +static void __exit tcp_vegas_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_vegas); +} + +module_init(tcp_vegas_register); +module_exit(tcp_vegas_unregister); + +MODULE_AUTHOR("Stephen Hemminger"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Vegas"); diff --git a/net/ipv4/tcp_vegas.h b/net/ipv4/tcp_vegas.h new file mode 100644 index 000000000..4f24d0e37 --- /dev/null +++ b/net/ipv4/tcp_vegas.h @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * TCP Vegas congestion control interface + */ +#ifndef __TCP_VEGAS_H +#define __TCP_VEGAS_H 1 + +/* Vegas variables */ +struct vegas { + u32 beg_snd_nxt; /* right edge during last RTT */ + u32 beg_snd_una; /* left edge during last RTT */ + u32 beg_snd_cwnd; /* saves the size of the cwnd */ + u8 doing_vegas_now;/* if true, do vegas for this RTT */ + u16 cntRTT; /* # of RTTs measured within last RTT */ + u32 minRTT; /* min of RTTs measured within last RTT (in usec) 
*/ + u32 baseRTT; /* the min of all Vegas RTT measurements seen (in usec) */ +}; + +void tcp_vegas_init(struct sock *sk); +void tcp_vegas_state(struct sock *sk, u8 ca_state); +void tcp_vegas_pkts_acked(struct sock *sk, const struct ack_sample *sample); +void tcp_vegas_cwnd_event(struct sock *sk, enum tcp_ca_event event); +size_t tcp_vegas_get_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info); + +#endif /* __TCP_VEGAS_H */ diff --git a/net/ipv4/tcp_veno.c b/net/ipv4/tcp_veno.c new file mode 100644 index 000000000..366ff6f21 --- /dev/null +++ b/net/ipv4/tcp_veno.c @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP Veno congestion control + * + * This is based on the congestion detection/avoidance scheme described in + * C. P. Fu, S. C. Liew. + * "TCP Veno: TCP Enhancement for Transmission over Wireless Access Networks." + * IEEE Journal on Selected Areas in Communication, + * Feb. 2003. + * See https://www.ie.cuhk.edu.hk/fileadmin/staff_upload/soung/Journal/J3.pdf + */ + +#include +#include +#include +#include + +#include + +/* Default values of the Veno variables, in fixed-point representation + * with V_PARAM_SHIFT bits to the right of the binary point. + */ +#define V_PARAM_SHIFT 1 +static const int beta = 3 << V_PARAM_SHIFT; + +/* Veno variables */ +struct veno { + u8 doing_veno_now; /* if true, do veno for this rtt */ + u16 cntrtt; /* # of rtts measured within last rtt */ + u32 minrtt; /* min of rtts measured within last rtt (in usec) */ + u32 basertt; /* the min of all Veno rtt measurements seen (in usec) */ + u32 inc; /* decide whether to increase cwnd */ + u32 diff; /* calculate the diff rate */ +}; + +/* There are several situations when we must "re-start" Veno: + * + * o when a connection is established + * o after an RTO + * o after fast recovery + * o when we send a packet and there is no outstanding + * unacknowledged data (restarting an idle connection) + * + */ +static inline void veno_enable(struct sock *sk) +{ + struct veno *veno = inet_csk_ca(sk); + + /* turn on Veno */ + veno->doing_veno_now = 1; + + veno->minrtt = 0x7fffffff; +} + +static inline void veno_disable(struct sock *sk) +{ + struct veno *veno = inet_csk_ca(sk); + + /* turn off Veno */ + veno->doing_veno_now = 0; +} + +static void tcp_veno_init(struct sock *sk) +{ + struct veno *veno = inet_csk_ca(sk); + + veno->basertt = 0x7fffffff; + veno->inc = 1; + veno_enable(sk); +} + +/* Do rtt sampling needed for Veno. */ +static void tcp_veno_pkts_acked(struct sock *sk, + const struct ack_sample *sample) +{ + struct veno *veno = inet_csk_ca(sk); + u32 vrtt; + + if (sample->rtt_us < 0) + return; + + /* Never allow zero rtt or baseRTT */ + vrtt = sample->rtt_us + 1; + + /* Filter to find propagation delay: */ + if (vrtt < veno->basertt) + veno->basertt = vrtt; + + /* Find the min rtt during the last rtt to find + * the current prop. delay + queuing delay: + */ + veno->minrtt = min(veno->minrtt, vrtt); + veno->cntrtt++; +} + +static void tcp_veno_state(struct sock *sk, u8 ca_state) +{ + if (ca_state == TCP_CA_Open) + veno_enable(sk); + else + veno_disable(sk); +} + +/* + * If the connection is idle and we are restarting, + * then we don't want to do any Veno calculations + * until we get fresh rtt samples. So when we + * restart, we reset our Veno state to a clean + * state. After we get acks for this flight of + * packets, _then_ we can make Veno calculations + * again. 
+ */ +static void tcp_veno_cwnd_event(struct sock *sk, enum tcp_ca_event event) +{ + if (event == CA_EVENT_CWND_RESTART || event == CA_EVENT_TX_START) + tcp_veno_init(sk); +} + +static void tcp_veno_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct veno *veno = inet_csk_ca(sk); + + if (!veno->doing_veno_now) { + tcp_reno_cong_avoid(sk, ack, acked); + return; + } + + /* limited by applications */ + if (!tcp_is_cwnd_limited(sk)) + return; + + /* We do the Veno calculations only if we got enough rtt samples */ + if (veno->cntrtt <= 2) { + /* We don't have enough rtt samples to do the Veno + * calculation, so we'll behave like Reno. + */ + tcp_reno_cong_avoid(sk, ack, acked); + } else { + u64 target_cwnd; + u32 rtt; + + /* We have enough rtt samples, so, using the Veno + * algorithm, we determine the state of the network. + */ + + rtt = veno->minrtt; + + target_cwnd = (u64)tcp_snd_cwnd(tp) * veno->basertt; + target_cwnd <<= V_PARAM_SHIFT; + do_div(target_cwnd, rtt); + + veno->diff = (tcp_snd_cwnd(tp) << V_PARAM_SHIFT) - target_cwnd; + + if (tcp_in_slow_start(tp)) { + /* Slow start. */ + acked = tcp_slow_start(tp, acked); + if (!acked) + goto done; + } + + /* Congestion avoidance. */ + if (veno->diff < beta) { + /* In the "non-congestive state", increase cwnd + * every rtt. + */ + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); + } else { + /* In the "congestive state", increase cwnd + * every other rtt. + */ + if (tp->snd_cwnd_cnt >= tcp_snd_cwnd(tp)) { + if (veno->inc && + tcp_snd_cwnd(tp) < tp->snd_cwnd_clamp) { + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) + 1); + veno->inc = 0; + } else + veno->inc = 1; + tp->snd_cwnd_cnt = 0; + } else + tp->snd_cwnd_cnt += acked; + } +done: + if (tcp_snd_cwnd(tp) < 2) + tcp_snd_cwnd_set(tp, 2); + else if (tcp_snd_cwnd(tp) > tp->snd_cwnd_clamp) + tcp_snd_cwnd_set(tp, tp->snd_cwnd_clamp); + } + /* Wipe the slate clean for the next rtt. 
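A stand-alone sketch of the Veno backlog estimate used in both branches above and in the ssthresh choice that follows (illustrative only, not part of this hunk): diff carries V_PARAM_SHIFT = 1 fractional bit, so beta = 3 << V_PARAM_SHIFT stands for three backlogged segments.

#include <stdio.h>

#define V_PARAM_SHIFT   1
static const unsigned int beta = 3 << V_PARAM_SHIFT;    /* 3 segments, fixed point */

/* (cwnd << V_PARAM_SHIFT) - (cwnd * basertt << V_PARAM_SHIFT) / rtt. */
static unsigned long long veno_diff(unsigned long long cwnd,
                                    unsigned int rtt_us, unsigned int basertt_us)
{
        unsigned long long target = (cwnd * basertt_us << V_PARAM_SHIFT) / rtt_us;

        return (cwnd << V_PARAM_SHIFT) - target;
}

int main(void)
{
        /* cwnd = 32, baseRTT = 50 ms, min RTT this window = 60 ms: roughly five
         * segments are estimated to be backlogged, which exceeds beta (3), so
         * Veno is in the "congestive state": cwnd grows only every other RTT
         * and a loss cuts it by 1/2 rather than 1/5.
         */
        unsigned long long diff = veno_diff(32, 60000, 50000);

        /* Prints "diff = 11 (fixed point, ~5.5 segments) -> congestive". */
        printf("diff = %llu (fixed point, ~%llu.%llu segments) -> %s\n",
               diff, diff >> V_PARAM_SHIFT, (diff & 1) * 5,
               diff < beta ? "non-congestive" : "congestive");
        return 0;
}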
*/ + /* veno->cntrtt = 0; */ + veno->minrtt = 0x7fffffff; +} + +/* Veno MD phase */ +static u32 tcp_veno_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct veno *veno = inet_csk_ca(sk); + + if (veno->diff < beta) + /* in "non-congestive state", cut cwnd by 1/5 */ + return max(tcp_snd_cwnd(tp) * 4 / 5, 2U); + else + /* in "congestive state", cut cwnd by 1/2 */ + return max(tcp_snd_cwnd(tp) >> 1U, 2U); +} + +static struct tcp_congestion_ops tcp_veno __read_mostly = { + .init = tcp_veno_init, + .ssthresh = tcp_veno_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_veno_cong_avoid, + .pkts_acked = tcp_veno_pkts_acked, + .set_state = tcp_veno_state, + .cwnd_event = tcp_veno_cwnd_event, + + .owner = THIS_MODULE, + .name = "veno", +}; + +static int __init tcp_veno_register(void) +{ + BUILD_BUG_ON(sizeof(struct veno) > ICSK_CA_PRIV_SIZE); + tcp_register_congestion_control(&tcp_veno); + return 0; +} + +static void __exit tcp_veno_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_veno); +} + +module_init(tcp_veno_register); +module_exit(tcp_veno_unregister); + +MODULE_AUTHOR("Bin Zhou, Cheng Peng Fu"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Veno"); diff --git a/net/ipv4/tcp_westwood.c b/net/ipv4/tcp_westwood.c new file mode 100644 index 000000000..c6e97141e --- /dev/null +++ b/net/ipv4/tcp_westwood.c @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TCP Westwood+: end-to-end bandwidth estimation for TCP + * + * Angelo Dell'Aera: author of the first version of TCP Westwood+ in Linux 2.4 + * + * Support at http://c3lab.poliba.it/index.php/Westwood + * Main references in literature: + * + * - Mascolo S, Casetti, M. Gerla et al. + * "TCP Westwood: bandwidth estimation for TCP" Proc. ACM Mobicom 2001 + * + * - A. Grieco, s. Mascolo + * "Performance evaluation of New Reno, Vegas, Westwood+ TCP" ACM Computer + * Comm. Review, 2004 + * + * - A. Dell'Aera, L. Grieco, S. Mascolo. + * "Linux 2.4 Implementation of Westwood+ TCP with Rate-Halving : + * A Performance Evaluation Over the Internet" (ICC 2004), Paris, June 2004 + * + * Westwood+ employs end-to-end bandwidth measurement to set cwnd and + * ssthresh after packet loss. The probing phase is as the original Reno. + */ + +#include +#include +#include +#include +#include + +/* TCP Westwood structure */ +struct westwood { + u32 bw_ns_est; /* first bandwidth estimation..not too smoothed 8) */ + u32 bw_est; /* bandwidth estimate */ + u32 rtt_win_sx; /* here starts a new evaluation... */ + u32 bk; + u32 snd_una; /* used for evaluating the number of acked bytes */ + u32 cumul_ack; + u32 accounted; + u32 rtt; + u32 rtt_min; /* minimum observed RTT */ + u8 first_ack; /* flag which infers that this is the first ack */ + u8 reset_rtt_min; /* Reset RTT min to next RTT sample*/ +}; + +/* TCP Westwood functions and constants */ +#define TCP_WESTWOOD_RTT_MIN (HZ/20) /* 50ms */ +#define TCP_WESTWOOD_INIT_RTT (20*HZ) /* maybe too conservative?! */ + +/* + * @tcp_westwood_create + * This function initializes fields used in TCP Westwood+, + * it is called after the initial SYN, so the sequence numbers + * are correct but new passive connections we have no + * information about RTTmin at this time so we simply set it to + * TCP_WESTWOOD_INIT_RTT. This value was chosen to be too conservative + * since in this way we're sure it will be updated in a consistent + * way as soon as possible. It will reasonably happen within the first + * RTT period of the connection lifetime. 
+ */ +static void tcp_westwood_init(struct sock *sk) +{ + struct westwood *w = inet_csk_ca(sk); + + w->bk = 0; + w->bw_ns_est = 0; + w->bw_est = 0; + w->accounted = 0; + w->cumul_ack = 0; + w->reset_rtt_min = 1; + w->rtt_min = w->rtt = TCP_WESTWOOD_INIT_RTT; + w->rtt_win_sx = tcp_jiffies32; + w->snd_una = tcp_sk(sk)->snd_una; + w->first_ack = 1; +} + +/* + * @westwood_do_filter + * Low-pass filter. Implemented using constant coefficients. + */ +static inline u32 westwood_do_filter(u32 a, u32 b) +{ + return ((7 * a) + b) >> 3; +} + +static void westwood_filter(struct westwood *w, u32 delta) +{ + /* If the filter is empty fill it with the first sample of bandwidth */ + if (w->bw_ns_est == 0 && w->bw_est == 0) { + w->bw_ns_est = w->bk / delta; + w->bw_est = w->bw_ns_est; + } else { + w->bw_ns_est = westwood_do_filter(w->bw_ns_est, w->bk / delta); + w->bw_est = westwood_do_filter(w->bw_est, w->bw_ns_est); + } +} + +/* + * @westwood_pkts_acked + * Called after processing group of packets. + * but all westwood needs is the last sample of srtt. + */ +static void tcp_westwood_pkts_acked(struct sock *sk, + const struct ack_sample *sample) +{ + struct westwood *w = inet_csk_ca(sk); + + if (sample->rtt_us > 0) + w->rtt = usecs_to_jiffies(sample->rtt_us); +} + +/* + * @westwood_update_window + * It updates RTT evaluation window if it is the right moment to do + * it. If so it calls filter for evaluating bandwidth. + */ +static void westwood_update_window(struct sock *sk) +{ + struct westwood *w = inet_csk_ca(sk); + s32 delta = tcp_jiffies32 - w->rtt_win_sx; + + /* Initialize w->snd_una with the first acked sequence number in order + * to fix mismatch between tp->snd_una and w->snd_una for the first + * bandwidth sample + */ + if (w->first_ack) { + w->snd_una = tcp_sk(sk)->snd_una; + w->first_ack = 0; + } + + /* + * See if a RTT-window has passed. + * Be careful since if RTT is less than + * 50ms we don't filter but we continue 'building the sample'. + * This minimum limit was chosen since an estimation on small + * time intervals is better to avoid... + * Obviously on a LAN we reasonably will always have + * right_bound = left_bound + WESTWOOD_RTT_MIN + */ + if (w->rtt && delta > max_t(u32, w->rtt, TCP_WESTWOOD_RTT_MIN)) { + westwood_filter(w, delta); + + w->bk = 0; + w->rtt_win_sx = tcp_jiffies32; + } +} + +static inline void update_rtt_min(struct westwood *w) +{ + if (w->reset_rtt_min) { + w->rtt_min = w->rtt; + w->reset_rtt_min = 0; + } else + w->rtt_min = min(w->rtt, w->rtt_min); +} + +/* + * @westwood_fast_bw + * It is called when we are in fast path. In particular it is called when + * header prediction is successful. In such case in fact update is + * straight forward and doesn't need any particular care. + */ +static inline void westwood_fast_bw(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct westwood *w = inet_csk_ca(sk); + + westwood_update_window(sk); + + w->bk += tp->snd_una - w->snd_una; + w->snd_una = tp->snd_una; + update_rtt_min(w); +} + +/* + * @westwood_acked_count + * This function evaluates cumul_ack for evaluating bk in case of + * delayed or partial acks. + */ +static inline u32 westwood_acked_count(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct westwood *w = inet_csk_ca(sk); + + w->cumul_ack = tp->snd_una - w->snd_una; + + /* If cumul_ack is 0 this is a dupack since it's not moving + * tp->snd_una. 
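westwood_do_filter() above is a constant-coefficient low-pass filter: each new estimate keeps 7/8 of the old value and admits 1/8 of the fresh sample, which is what smooths the per-window bandwidth samples into bw_est. A tiny stand-alone illustration (not part of this hunk):

#include <stdio.h>

/* Same constant-coefficient low-pass filter: new = (7 * old + sample) / 8. */
static unsigned int westwood_do_filter(unsigned int a, unsigned int b)
{
        return ((7 * a) + b) >> 3;
}

int main(void)
{
        /* Feed a constant sample of 160 "units" of bandwidth into a filter
         * that starts at 80: the estimate converges geometrically toward the
         * sample (90, 98, 105, ...), which is what damps one-off ACK bursts.
         */
        unsigned int est = 80;
        int i;

        for (i = 0; i < 8; i++) {
                est = westwood_do_filter(est, 160);
                printf("%u ", est);
        }
        printf("\n");
        return 0;
}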
+ */ + if (!w->cumul_ack) { + w->accounted += tp->mss_cache; + w->cumul_ack = tp->mss_cache; + } + + if (w->cumul_ack > tp->mss_cache) { + /* Partial or delayed ack */ + if (w->accounted >= w->cumul_ack) { + w->accounted -= w->cumul_ack; + w->cumul_ack = tp->mss_cache; + } else { + w->cumul_ack -= w->accounted; + w->accounted = 0; + } + } + + w->snd_una = tp->snd_una; + + return w->cumul_ack; +} + +/* + * TCP Westwood + * Here limit is evaluated as Bw estimation*RTTmin (for obtaining it + * in packets we use mss_cache). Rttmin is guaranteed to be >= 2 + * so avoids ever returning 0. + */ +static u32 tcp_westwood_bw_rttmin(const struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + const struct westwood *w = inet_csk_ca(sk); + + return max_t(u32, (w->bw_est * w->rtt_min) / tp->mss_cache, 2); +} + +static void tcp_westwood_ack(struct sock *sk, u32 ack_flags) +{ + if (ack_flags & CA_ACK_SLOWPATH) { + struct westwood *w = inet_csk_ca(sk); + + westwood_update_window(sk); + w->bk += westwood_acked_count(sk); + + update_rtt_min(w); + return; + } + + westwood_fast_bw(sk); +} + +static void tcp_westwood_event(struct sock *sk, enum tcp_ca_event event) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct westwood *w = inet_csk_ca(sk); + + switch (event) { + case CA_EVENT_COMPLETE_CWR: + tp->snd_ssthresh = tcp_westwood_bw_rttmin(sk); + tcp_snd_cwnd_set(tp, tp->snd_ssthresh); + break; + case CA_EVENT_LOSS: + tp->snd_ssthresh = tcp_westwood_bw_rttmin(sk); + /* Update RTT_min when next ack arrives */ + w->reset_rtt_min = 1; + break; + default: + /* don't care */ + break; + } +} + +/* Extract info for Tcp socket info provided via netlink. */ +static size_t tcp_westwood_info(struct sock *sk, u32 ext, int *attr, + union tcp_cc_info *info) +{ + const struct westwood *ca = inet_csk_ca(sk); + + if (ext & (1 << (INET_DIAG_VEGASINFO - 1))) { + info->vegas.tcpv_enabled = 1; + info->vegas.tcpv_rttcnt = 0; + info->vegas.tcpv_rtt = jiffies_to_usecs(ca->rtt); + info->vegas.tcpv_minrtt = jiffies_to_usecs(ca->rtt_min); + + *attr = INET_DIAG_VEGASINFO; + return sizeof(struct tcpvegas_info); + } + return 0; +} + +static struct tcp_congestion_ops tcp_westwood __read_mostly = { + .init = tcp_westwood_init, + .ssthresh = tcp_reno_ssthresh, + .cong_avoid = tcp_reno_cong_avoid, + .undo_cwnd = tcp_reno_undo_cwnd, + .cwnd_event = tcp_westwood_event, + .in_ack_event = tcp_westwood_ack, + .get_info = tcp_westwood_info, + .pkts_acked = tcp_westwood_pkts_acked, + + .owner = THIS_MODULE, + .name = "westwood" +}; + +static int __init tcp_westwood_register(void) +{ + BUILD_BUG_ON(sizeof(struct westwood) > ICSK_CA_PRIV_SIZE); + return tcp_register_congestion_control(&tcp_westwood); +} + +static void __exit tcp_westwood_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_westwood); +} + +module_init(tcp_westwood_register); +module_exit(tcp_westwood_unregister); + +MODULE_AUTHOR("Stephen Hemminger, Angelo Dell'Aera"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("TCP Westwood+"); diff --git a/net/ipv4/tcp_yeah.c b/net/ipv4/tcp_yeah.c new file mode 100644 index 000000000..18b07ff5d --- /dev/null +++ b/net/ipv4/tcp_yeah.c @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * YeAH TCP + * + * For further details look at: + * https://web.archive.org/web/20080316215752/http://wil.cs.caltech.edu/pfldnet2007/paper/YeAH_TCP.pdf + * + */ +#include +#include +#include +#include + +#include + +#include "tcp_vegas.h" + +#define TCP_YEAH_ALPHA 80 /* number of packets queued at the bottleneck */ +#define TCP_YEAH_GAMMA 1 
/* fraction of queue to be removed per rtt */ +#define TCP_YEAH_DELTA 3 /* log minimum fraction of cwnd to be removed on loss */ +#define TCP_YEAH_EPSILON 1 /* log maximum fraction to be removed on early decongestion */ +#define TCP_YEAH_PHY 8 /* maximum delta from base */ +#define TCP_YEAH_RHO 16 /* minimum number of consecutive rtt to consider competition on loss */ +#define TCP_YEAH_ZETA 50 /* minimum number of state switches to reset reno_count */ + +#define TCP_SCALABLE_AI_CNT 100U + +/* YeAH variables */ +struct yeah { + struct vegas vegas; /* must be first */ + + /* YeAH */ + u32 lastQ; + u32 doing_reno_now; + + u32 reno_count; + u32 fast_count; +}; + +static void tcp_yeah_init(struct sock *sk) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct yeah *yeah = inet_csk_ca(sk); + + tcp_vegas_init(sk); + + yeah->doing_reno_now = 0; + yeah->lastQ = 0; + + yeah->reno_count = 2; + + /* Ensure the MD arithmetic works. This is somewhat pedantic, + * since I don't think we will see a cwnd this large. :) */ + tp->snd_cwnd_clamp = min_t(u32, tp->snd_cwnd_clamp, 0xffffffff/128); +} + +static void tcp_yeah_cong_avoid(struct sock *sk, u32 ack, u32 acked) +{ + struct tcp_sock *tp = tcp_sk(sk); + struct yeah *yeah = inet_csk_ca(sk); + + if (!tcp_is_cwnd_limited(sk)) + return; + + if (tcp_in_slow_start(tp)) { + acked = tcp_slow_start(tp, acked); + if (!acked) + goto do_vegas; + } + + if (!yeah->doing_reno_now) { + /* Scalable */ + tcp_cong_avoid_ai(tp, min(tcp_snd_cwnd(tp), TCP_SCALABLE_AI_CNT), + acked); + } else { + /* Reno */ + tcp_cong_avoid_ai(tp, tcp_snd_cwnd(tp), acked); + } + + /* The key players are v_vegas.beg_snd_una and v_beg_snd_nxt. + * + * These are so named because they represent the approximate values + * of snd_una and snd_nxt at the beginning of the current RTT. More + * precisely, they represent the amount of data sent during the RTT. + * At the end of the RTT, when we receive an ACK for v_beg_snd_nxt, + * we will calculate that (v_beg_snd_nxt - v_vegas.beg_snd_una) outstanding + * bytes of data have been ACKed during the course of the RTT, giving + * an "actual" rate of: + * + * (v_beg_snd_nxt - v_vegas.beg_snd_una) / (rtt duration) + * + * Unfortunately, v_vegas.beg_snd_una is not exactly equal to snd_una, + * because delayed ACKs can cover more than one segment, so they + * don't line up yeahly with the boundaries of RTTs. + * + * Another unfortunate fact of life is that delayed ACKs delay the + * advance of the left edge of our send window, so that the number + * of bytes we send in an RTT is often less than our cwnd will allow. + * So we keep track of our cwnd separately, in v_beg_snd_cwnd. + */ +do_vegas: + if (after(ack, yeah->vegas.beg_snd_nxt)) { + /* We do the Vegas calculations only if we got enough RTT + * samples that we can be reasonably sure that we got + * at least one RTT sample that wasn't from a delayed ACK. + * If we only had 2 samples total, + * then that means we're getting only 1 ACK per RTT, which + * means they're almost certainly delayed ACKs. + * If we have 3 samples, we should be OK. + */ + + if (yeah->vegas.cntRTT > 2) { + u32 rtt, queue; + u64 bw; + + /* We have enough RTT samples, so, using the Vegas + * algorithm, we determine if we should increase or + * decrease cwnd, and by how much. + */ + + /* Pluck out the RTT we are using for the Vegas + * calculations. This is the min RTT seen during the + * last RTT. Taking the min filters out the effects + * of delayed ACKs, at the cost of noticing congestion + * a bit later. 
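The code that follows turns the Vegas-style measurement into a queue estimate, queue = cwnd * (rtt - baseRTT) / rtt segments, and compares it against TCP_YEAH_ALPHA and the baseRTT / TCP_YEAH_PHY threshold to decide whether to decongest before any loss. A stand-alone sketch of that computation (illustrative only, not part of this hunk):

#include <stdio.h>

#define TCP_YEAH_ALPHA  80      /* segments tolerated in the bottleneck queue */
#define TCP_YEAH_PHY    8       /* maximum delta from base, as baseRTT / PHY */

/* Segments estimated to sit in the queue: cwnd * (rtt - baseRTT) / rtt. */
static unsigned int yeah_queue(unsigned long long cwnd, unsigned int rtt_us,
                               unsigned int base_rtt_us)
{
        return (unsigned int)(cwnd * (rtt_us - base_rtt_us) / rtt_us);
}

int main(void)
{
        /* cwnd = 1000 segments, baseRTT = 20 ms, min RTT this window = 23 ms:
         * queue = 1000 * 3 / 23 = ~130 segments > ALPHA, and the extra 3 ms
         * also exceeds baseRTT / PHY = 2.5 ms, so YeAH would shrink cwnd by
         * queue / GAMMA (capped at cwnd >> EPSILON) before any loss occurs.
         */
        unsigned int queue = yeah_queue(1000, 23000, 20000);

        printf("queue estimate: %u segments (alpha = %d)\n", queue, TCP_YEAH_ALPHA);
        return 0;
}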
+ */ + rtt = yeah->vegas.minRTT; + + /* Compute excess number of packets above bandwidth + * Avoid doing full 64 bit divide. + */ + bw = tcp_snd_cwnd(tp); + bw *= rtt - yeah->vegas.baseRTT; + do_div(bw, rtt); + queue = bw; + + if (queue > TCP_YEAH_ALPHA || + rtt - yeah->vegas.baseRTT > (yeah->vegas.baseRTT / TCP_YEAH_PHY)) { + if (queue > TCP_YEAH_ALPHA && + tcp_snd_cwnd(tp) > yeah->reno_count) { + u32 reduction = min(queue / TCP_YEAH_GAMMA , + tcp_snd_cwnd(tp) >> TCP_YEAH_EPSILON); + + tcp_snd_cwnd_set(tp, tcp_snd_cwnd(tp) - reduction); + + tcp_snd_cwnd_set(tp, max(tcp_snd_cwnd(tp), + yeah->reno_count)); + + tp->snd_ssthresh = tcp_snd_cwnd(tp); + } + + if (yeah->reno_count <= 2) + yeah->reno_count = max(tcp_snd_cwnd(tp)>>1, 2U); + else + yeah->reno_count++; + + yeah->doing_reno_now = min(yeah->doing_reno_now + 1, + 0xffffffU); + } else { + yeah->fast_count++; + + if (yeah->fast_count > TCP_YEAH_ZETA) { + yeah->reno_count = 2; + yeah->fast_count = 0; + } + + yeah->doing_reno_now = 0; + } + + yeah->lastQ = queue; + } + + /* Save the extent of the current window so we can use this + * at the end of the next RTT. + */ + yeah->vegas.beg_snd_una = yeah->vegas.beg_snd_nxt; + yeah->vegas.beg_snd_nxt = tp->snd_nxt; + yeah->vegas.beg_snd_cwnd = tcp_snd_cwnd(tp); + + /* Wipe the slate clean for the next RTT. */ + yeah->vegas.cntRTT = 0; + yeah->vegas.minRTT = 0x7fffffff; + } +} + +static u32 tcp_yeah_ssthresh(struct sock *sk) +{ + const struct tcp_sock *tp = tcp_sk(sk); + struct yeah *yeah = inet_csk_ca(sk); + u32 reduction; + + if (yeah->doing_reno_now < TCP_YEAH_RHO) { + reduction = yeah->lastQ; + + reduction = min(reduction, max(tcp_snd_cwnd(tp)>>1, 2U)); + + reduction = max(reduction, tcp_snd_cwnd(tp) >> TCP_YEAH_DELTA); + } else + reduction = max(tcp_snd_cwnd(tp)>>1, 2U); + + yeah->fast_count = 0; + yeah->reno_count = max(yeah->reno_count>>1, 2U); + + return max_t(int, tcp_snd_cwnd(tp) - reduction, 2); +} + +static struct tcp_congestion_ops tcp_yeah __read_mostly = { + .init = tcp_yeah_init, + .ssthresh = tcp_yeah_ssthresh, + .undo_cwnd = tcp_reno_undo_cwnd, + .cong_avoid = tcp_yeah_cong_avoid, + .set_state = tcp_vegas_state, + .cwnd_event = tcp_vegas_cwnd_event, + .get_info = tcp_vegas_get_info, + .pkts_acked = tcp_vegas_pkts_acked, + + .owner = THIS_MODULE, + .name = "yeah", +}; + +static int __init tcp_yeah_register(void) +{ + BUILD_BUG_ON(sizeof(struct yeah) > ICSK_CA_PRIV_SIZE); + tcp_register_congestion_control(&tcp_yeah); + return 0; +} + +static void __exit tcp_yeah_unregister(void) +{ + tcp_unregister_congestion_control(&tcp_yeah); +} + +module_init(tcp_yeah_register); +module_exit(tcp_yeah_unregister); + +MODULE_AUTHOR("Angelo P. Castellani"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("YeAH TCP"); diff --git a/net/ipv4/tunnel4.c b/net/ipv4/tunnel4.c new file mode 100644 index 000000000..5048c47c7 --- /dev/null +++ b/net/ipv4/tunnel4.c @@ -0,0 +1,297 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* tunnel4.c: Generic IP tunnel transformer. + * + * Copyright (C) 2003 David S. Miller (davem@redhat.com) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct xfrm_tunnel __rcu *tunnel4_handlers __read_mostly; +static struct xfrm_tunnel __rcu *tunnel64_handlers __read_mostly; +static struct xfrm_tunnel __rcu *tunnelmpls4_handlers __read_mostly; +static DEFINE_MUTEX(tunnel4_mutex); + +static inline struct xfrm_tunnel __rcu **fam_handlers(unsigned short family) +{ + return (family == AF_INET) ? 
&tunnel4_handlers : + (family == AF_INET6) ? &tunnel64_handlers : + &tunnelmpls4_handlers; +} + +int xfrm4_tunnel_register(struct xfrm_tunnel *handler, unsigned short family) +{ + struct xfrm_tunnel __rcu **pprev; + struct xfrm_tunnel *t; + + int ret = -EEXIST; + int priority = handler->priority; + + mutex_lock(&tunnel4_mutex); + + for (pprev = fam_handlers(family); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&tunnel4_mutex))) != NULL; + pprev = &t->next) { + if (t->priority > priority) + break; + if (t->priority == priority) + goto err; + } + + handler->next = *pprev; + rcu_assign_pointer(*pprev, handler); + + ret = 0; + +err: + mutex_unlock(&tunnel4_mutex); + + return ret; +} +EXPORT_SYMBOL(xfrm4_tunnel_register); + +int xfrm4_tunnel_deregister(struct xfrm_tunnel *handler, unsigned short family) +{ + struct xfrm_tunnel __rcu **pprev; + struct xfrm_tunnel *t; + int ret = -ENOENT; + + mutex_lock(&tunnel4_mutex); + + for (pprev = fam_handlers(family); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&tunnel4_mutex))) != NULL; + pprev = &t->next) { + if (t == handler) { + *pprev = handler->next; + ret = 0; + break; + } + } + + mutex_unlock(&tunnel4_mutex); + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(xfrm4_tunnel_deregister); + +#define for_each_tunnel_rcu(head, handler) \ + for (handler = rcu_dereference(head); \ + handler != NULL; \ + handler = rcu_dereference(handler->next)) \ + +static int tunnel4_rcv(struct sk_buff *skb) +{ + struct xfrm_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto drop; + + for_each_tunnel_rcu(tunnel4_handlers, handler) + if (!handler->handler(skb)) + return 0; + + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} + +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) +static int tunnel4_rcv_cb(struct sk_buff *skb, u8 proto, int err) +{ + struct xfrm_tunnel __rcu *head; + struct xfrm_tunnel *handler; + int ret; + + head = (proto == IPPROTO_IPIP) ? 
tunnel4_handlers : tunnel64_handlers; + + for_each_tunnel_rcu(head, handler) { + if (handler->cb_handler) { + ret = handler->cb_handler(skb, err); + if (ret <= 0) + return ret; + } + } + + return 0; +} + +static const struct xfrm_input_afinfo tunnel4_input_afinfo = { + .family = AF_INET, + .is_ipip = true, + .callback = tunnel4_rcv_cb, +}; +#endif + +#if IS_ENABLED(CONFIG_IPV6) +static int tunnel64_rcv(struct sk_buff *skb) +{ + struct xfrm_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto drop; + + for_each_tunnel_rcu(tunnel64_handlers, handler) + if (!handler->handler(skb)) + return 0; + + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} +#endif + +#if IS_ENABLED(CONFIG_MPLS) +static int tunnelmpls4_rcv(struct sk_buff *skb) +{ + struct xfrm_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct mpls_label))) + goto drop; + + for_each_tunnel_rcu(tunnelmpls4_handlers, handler) + if (!handler->handler(skb)) + return 0; + + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} +#endif + +static int tunnel4_err(struct sk_buff *skb, u32 info) +{ + struct xfrm_tunnel *handler; + + for_each_tunnel_rcu(tunnel4_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int tunnel64_err(struct sk_buff *skb, u32 info) +{ + struct xfrm_tunnel *handler; + + for_each_tunnel_rcu(tunnel64_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} +#endif + +#if IS_ENABLED(CONFIG_MPLS) +static int tunnelmpls4_err(struct sk_buff *skb, u32 info) +{ + struct xfrm_tunnel *handler; + + for_each_tunnel_rcu(tunnelmpls4_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} +#endif + +static const struct net_protocol tunnel4_protocol = { + .handler = tunnel4_rcv, + .err_handler = tunnel4_err, + .no_policy = 1, +}; + +#if IS_ENABLED(CONFIG_IPV6) +static const struct net_protocol tunnel64_protocol = { + .handler = tunnel64_rcv, + .err_handler = tunnel64_err, + .no_policy = 1, +}; +#endif + +#if IS_ENABLED(CONFIG_MPLS) +static const struct net_protocol tunnelmpls4_protocol = { + .handler = tunnelmpls4_rcv, + .err_handler = tunnelmpls4_err, + .no_policy = 1, +}; +#endif + +static int __init tunnel4_init(void) +{ + if (inet_add_protocol(&tunnel4_protocol, IPPROTO_IPIP)) + goto err; +#if IS_ENABLED(CONFIG_IPV6) + if (inet_add_protocol(&tunnel64_protocol, IPPROTO_IPV6)) { + inet_del_protocol(&tunnel4_protocol, IPPROTO_IPIP); + goto err; + } +#endif +#if IS_ENABLED(CONFIG_MPLS) + if (inet_add_protocol(&tunnelmpls4_protocol, IPPROTO_MPLS)) { + inet_del_protocol(&tunnel4_protocol, IPPROTO_IPIP); +#if IS_ENABLED(CONFIG_IPV6) + inet_del_protocol(&tunnel64_protocol, IPPROTO_IPV6); +#endif + goto err; + } +#endif +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) + if (xfrm_input_register_afinfo(&tunnel4_input_afinfo)) { + inet_del_protocol(&tunnel4_protocol, IPPROTO_IPIP); +#if IS_ENABLED(CONFIG_IPV6) + inet_del_protocol(&tunnel64_protocol, IPPROTO_IPV6); +#endif +#if IS_ENABLED(CONFIG_MPLS) + inet_del_protocol(&tunnelmpls4_protocol, IPPROTO_MPLS); +#endif + goto err; + } +#endif + return 0; + +err: + pr_err("%s: can't add protocol\n", __func__); + return -EAGAIN; +} + +static void __exit tunnel4_fini(void) +{ +#if IS_ENABLED(CONFIG_INET_XFRM_TUNNEL) + if (xfrm_input_unregister_afinfo(&tunnel4_input_afinfo)) + pr_err("tunnel4 close: can't remove input afinfo\n"); 
+#endif +#if IS_ENABLED(CONFIG_MPLS) + if (inet_del_protocol(&tunnelmpls4_protocol, IPPROTO_MPLS)) + pr_err("tunnelmpls4 close: can't remove protocol\n"); +#endif +#if IS_ENABLED(CONFIG_IPV6) + if (inet_del_protocol(&tunnel64_protocol, IPPROTO_IPV6)) + pr_err("tunnel64 close: can't remove protocol\n"); +#endif + if (inet_del_protocol(&tunnel4_protocol, IPPROTO_IPIP)) + pr_err("tunnel4 close: can't remove protocol\n"); +} + +module_init(tunnel4_init); +module_exit(tunnel4_fini); +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c new file mode 100644 index 000000000..11c0e1c66 --- /dev/null +++ b/net/ipv4/udp.c @@ -0,0 +1,3350 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * The User Datagram Protocol (UDP). + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Arnt Gulbrandsen, + * Alan Cox, + * Hirokazu Takahashi, + * + * Fixes: + * Alan Cox : verify_area() calls + * Alan Cox : stopped close while in use off icmp + * messages. Not a fix but a botch that + * for udp at least is 'valid'. + * Alan Cox : Fixed icmp handling properly + * Alan Cox : Correct error for oversized datagrams + * Alan Cox : Tidied select() semantics. + * Alan Cox : udp_err() fixed properly, also now + * select and read wake correctly on errors + * Alan Cox : udp_send verify_area moved to avoid mem leak + * Alan Cox : UDP can count its memory + * Alan Cox : send to an unknown connection causes + * an ECONNREFUSED off the icmp, but + * does NOT close. + * Alan Cox : Switched to new sk_buff handlers. No more backlog! + * Alan Cox : Using generic datagram code. Even smaller and the PEEK + * bug no longer crashes it. + * Fred Van Kempen : Net2e support for sk->broadcast. + * Alan Cox : Uses skb_free_datagram + * Alan Cox : Added get/set sockopt support. + * Alan Cox : Broadcasting without option set returns EACCES. + * Alan Cox : No wakeup calls. Instead we now use the callbacks. + * Alan Cox : Use ip_tos and ip_ttl + * Alan Cox : SNMP Mibs + * Alan Cox : MSG_DONTROUTE, and 0.0.0.0 support. + * Matt Dillon : UDP length checks. + * Alan Cox : Smarter af_inet used properly. + * Alan Cox : Use new kernel side addressing. + * Alan Cox : Incorrect return on truncated datagram receive. + * Arnt Gulbrandsen : New udp_send and stuff + * Alan Cox : Cache last socket + * Alan Cox : Route cache + * Jon Peatfield : Minor efficiency fix to sendto(). + * Mike Shaver : RFC1122 checks. + * Alan Cox : Nonblocking error fix. + * Willy Konynenberg : Transparent proxying support. + * Mike McLagan : Routing by source + * David S. Miller : New socket lookup architecture. + * Last socket cache retained as it + * does have a high hit rate. + * Olaf Kirch : Don't linearise iovec on sendmsg. + * Andi Kleen : Some cleanups, cache destination entry + * for connect. + * Vitaly E. Lavrov : Transparent proxy revived after year coma. + * Melvin Smith : Check msg_name not msg_namelen in sendto(), + * return ENOTCONN for unconnected sockets (POSIX) + * Janos Farkas : don't deliver multi/broadcasts to a different + * bound-to-device socket + * Hirokazu Takahashi : HW checksumming for outgoing UDP + * datagrams. + * Hirokazu Takahashi : sendfile() on UDP works now. + * Arnaldo C. 
Melo : convert /proc/net/udp to seq_file + * YOSHIFUJI Hideaki @USAGI and: Support IPV6_V6ONLY socket option, which + * Alexey Kuznetsov: allow both IPv4 and IPv6 sockets to bind + * a single port at the same time. + * Derek Atkins : Add Encapulation Support + * James Chapman : Add L2TP encapsulation type. + */ + +#define pr_fmt(fmt) "UDP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "udp_impl.h" +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif + +struct udp_table udp_table __read_mostly; +EXPORT_SYMBOL(udp_table); + +long sysctl_udp_mem[3] __read_mostly; +EXPORT_SYMBOL(sysctl_udp_mem); + +atomic_long_t udp_memory_allocated ____cacheline_aligned_in_smp; +EXPORT_SYMBOL(udp_memory_allocated); +DEFINE_PER_CPU(int, udp_memory_per_cpu_fw_alloc); +EXPORT_PER_CPU_SYMBOL_GPL(udp_memory_per_cpu_fw_alloc); + +#define MAX_UDP_PORTS 65536 +#define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN) + +static int udp_lib_lport_inuse(struct net *net, __u16 num, + const struct udp_hslot *hslot, + unsigned long *bitmap, + struct sock *sk, unsigned int log) +{ + struct sock *sk2; + kuid_t uid = sock_i_uid(sk); + + sk_for_each(sk2, &hslot->head) { + if (net_eq(sock_net(sk2), net) && + sk2 != sk && + (bitmap || udp_sk(sk2)->udp_port_hash == num) && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || + sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + inet_rcv_saddr_equal(sk, sk2, true)) { + if (sk2->sk_reuseport && sk->sk_reuseport && + !rcu_access_pointer(sk->sk_reuseport_cb) && + uid_eq(uid, sock_i_uid(sk2))) { + if (!bitmap) + return 0; + } else { + if (!bitmap) + return 1; + __set_bit(udp_sk(sk2)->udp_port_hash >> log, + bitmap); + } + } + } + return 0; +} + +/* + * Note: we still hold spinlock of primary hash chain, so no other writer + * can insert/delete a socket with local_port == num + */ +static int udp_lib_lport_inuse2(struct net *net, __u16 num, + struct udp_hslot *hslot2, + struct sock *sk) +{ + struct sock *sk2; + kuid_t uid = sock_i_uid(sk); + int res = 0; + + spin_lock(&hslot2->lock); + udp_portaddr_for_each_entry(sk2, &hslot2->head) { + if (net_eq(sock_net(sk2), net) && + sk2 != sk && + (udp_sk(sk2)->udp_port_hash == num) && + (!sk2->sk_reuse || !sk->sk_reuse) && + (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if || + sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + inet_rcv_saddr_equal(sk, sk2, true)) { + if (sk2->sk_reuseport && sk->sk_reuseport && + !rcu_access_pointer(sk->sk_reuseport_cb) && + uid_eq(uid, sock_i_uid(sk2))) { + res = 0; + } else { + res = 1; + } + break; + } + } + spin_unlock(&hslot2->lock); + return res; +} + +static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot) +{ + struct net *net = sock_net(sk); + kuid_t uid = sock_i_uid(sk); + struct sock *sk2; + + sk_for_each(sk2, &hslot->head) { + if (net_eq(sock_net(sk2), net) && + sk2 != sk && + sk2->sk_family == sk->sk_family && + ipv6_only_sock(sk2) == ipv6_only_sock(sk) && + (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) && + (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) && + sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && + inet_rcv_saddr_equal(sk, sk2, false)) { + return reuseport_add_sock(sk, sk2, + 
inet_rcv_saddr_any(sk)); + } + } + + return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); +} + +/** + * udp_lib_get_port - UDP/-Lite port lookup for IPv4 and IPv6 + * + * @sk: socket struct in question + * @snum: port number to look up + * @hash2_nulladdr: AF-dependent hash value in secondary hash chains, + * with NULL address + */ +int udp_lib_get_port(struct sock *sk, unsigned short snum, + unsigned int hash2_nulladdr) +{ + struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_hslot *hslot, *hslot2; + struct net *net = sock_net(sk); + int error = -EADDRINUSE; + + if (!snum) { + DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN); + unsigned short first, last; + int low, high, remaining; + unsigned int rand; + + inet_sk_get_local_port_range(sk, &low, &high); + remaining = (high - low) + 1; + + rand = get_random_u32(); + first = reciprocal_scale(rand, remaining) + low; + /* + * force rand to be an odd multiple of UDP_HTABLE_SIZE + */ + rand = (rand | 1) * (udptable->mask + 1); + last = first + udptable->mask + 1; + do { + hslot = udp_hashslot(udptable, net, first); + bitmap_zero(bitmap, PORTS_PER_CHAIN); + spin_lock_bh(&hslot->lock); + udp_lib_lport_inuse(net, snum, hslot, bitmap, sk, + udptable->log); + + snum = first; + /* + * Iterate on all possible values of snum for this hash. + * Using steps of an odd multiple of UDP_HTABLE_SIZE + * give us randomization and full range coverage. + */ + do { + if (low <= snum && snum <= high && + !test_bit(snum >> udptable->log, bitmap) && + !inet_is_local_reserved_port(net, snum)) + goto found; + snum += rand; + } while (snum != first); + spin_unlock_bh(&hslot->lock); + cond_resched(); + } while (++first != last); + goto fail; + } else { + hslot = udp_hashslot(udptable, net, snum); + spin_lock_bh(&hslot->lock); + if (hslot->count > 10) { + int exist; + unsigned int slot2 = udp_sk(sk)->udp_portaddr_hash ^ snum; + + slot2 &= udptable->mask; + hash2_nulladdr &= udptable->mask; + + hslot2 = udp_hashslot2(udptable, slot2); + if (hslot->count < hslot2->count) + goto scan_primary_hash; + + exist = udp_lib_lport_inuse2(net, snum, hslot2, sk); + if (!exist && (hash2_nulladdr != slot2)) { + hslot2 = udp_hashslot2(udptable, hash2_nulladdr); + exist = udp_lib_lport_inuse2(net, snum, hslot2, + sk); + } + if (exist) + goto fail_unlock; + else + goto found; + } +scan_primary_hash: + if (udp_lib_lport_inuse(net, snum, hslot, NULL, sk, 0)) + goto fail_unlock; + } +found: + inet_sk(sk)->inet_num = snum; + udp_sk(sk)->udp_port_hash = snum; + udp_sk(sk)->udp_portaddr_hash ^= snum; + if (sk_unhashed(sk)) { + if (sk->sk_reuseport && + udp_reuseport_add_sock(sk, hslot)) { + inet_sk(sk)->inet_num = 0; + udp_sk(sk)->udp_port_hash = 0; + udp_sk(sk)->udp_portaddr_hash ^= snum; + goto fail_unlock; + } + + sk_add_node_rcu(sk, &hslot->head); + hslot->count++; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); + spin_lock(&hslot2->lock); + if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && + sk->sk_family == AF_INET6) + hlist_add_tail_rcu(&udp_sk(sk)->udp_portaddr_node, + &hslot2->head); + else + hlist_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, + &hslot2->head); + hslot2->count++; + spin_unlock(&hslot2->lock); + } + sock_set_flag(sk, SOCK_RCU_FREE); + error = 0; +fail_unlock: + spin_unlock_bh(&hslot->lock); +fail: + return error; +} +EXPORT_SYMBOL(udp_lib_get_port); + +int udp_v4_get_port(struct sock *sk, unsigned short snum) +{ + unsigned int hash2_nulladdr = + ipv4_portaddr_hash(sock_net(sk), 
htonl(INADDR_ANY), snum); + unsigned int hash2_partial = + ipv4_portaddr_hash(sock_net(sk), inet_sk(sk)->inet_rcv_saddr, 0); + + /* precompute partial secondary hash */ + udp_sk(sk)->udp_portaddr_hash = hash2_partial; + return udp_lib_get_port(sk, snum, hash2_nulladdr); +} + +static int compute_score(struct sock *sk, struct net *net, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned short hnum, + int dif, int sdif) +{ + int score; + struct inet_sock *inet; + bool dev_match; + + if (!net_eq(sock_net(sk), net) || + udp_sk(sk)->udp_port_hash != hnum || + ipv6_only_sock(sk)) + return -1; + + if (sk->sk_rcv_saddr != daddr) + return -1; + + score = (sk->sk_family == PF_INET) ? 2 : 1; + + inet = inet_sk(sk); + if (inet->inet_daddr) { + if (inet->inet_daddr != saddr) + return -1; + score += 4; + } + + if (inet->inet_dport) { + if (inet->inet_dport != sport) + return -1; + score += 4; + } + + dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, + dif, sdif); + if (!dev_match) + return -1; + if (sk->sk_bound_dev_if) + score += 4; + + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + score++; + return score; +} + +static u32 udp_ehashfn(const struct net *net, const __be32 laddr, + const __u16 lport, const __be32 faddr, + const __be16 fport) +{ + static u32 udp_ehash_secret __read_mostly; + + net_get_random_once(&udp_ehash_secret, sizeof(udp_ehash_secret)); + + return __inet_ehashfn(laddr, lport, faddr, fport, + udp_ehash_secret + net_hash_mix(net)); +} + +static struct sock *lookup_reuseport(struct net *net, struct sock *sk, + struct sk_buff *skb, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned short hnum) +{ + struct sock *reuse_sk = NULL; + u32 hash; + + if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) { + hash = udp_ehashfn(net, daddr, hnum, saddr, sport); + reuse_sk = reuseport_select_sock(sk, hash, skb, + sizeof(struct udphdr)); + } + return reuse_sk; +} + +/* called with rcu_read_lock() */ +static struct sock *udp4_lib_lookup2(struct net *net, + __be32 saddr, __be16 sport, + __be32 daddr, unsigned int hnum, + int dif, int sdif, + struct udp_hslot *hslot2, + struct sk_buff *skb) +{ + struct sock *sk, *result; + int score, badness; + + result = NULL; + badness = 0; + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { + score = compute_score(sk, net, saddr, sport, + daddr, hnum, dif, sdif); + if (score > badness) { + badness = score; + result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); + if (!result) { + result = sk; + continue; + } + + /* Fall back to scoring if group has connections */ + if (!reuseport_has_conns(sk)) + return result; + + /* Reuseport logic returned an error, keep original score. */ + if (IS_ERR(result)) + continue; + + badness = compute_score(result, net, saddr, sport, + daddr, hnum, dif, sdif); + + } + } + return result; +} + +static struct sock *udp4_lookup_run_bpf(struct net *net, + struct udp_table *udptable, + struct sk_buff *skb, + __be32 saddr, __be16 sport, + __be32 daddr, u16 hnum, const int dif) +{ + struct sock *sk, *reuse_sk; + bool no_reuseport; + + if (udptable != &udp_table) + return NULL; /* only UDP is supported */ + + no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_UDP, saddr, sport, + daddr, hnum, dif, &sk); + if (no_reuseport || IS_ERR_OR_NULL(sk)) + return sk; + + reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); + if (reuse_sk) + sk = reuse_sk; + return sk; +} + +/* UDP is nearly always wildcards out the wazoo, it makes no sense to try + * harder than this. 
-DaveM + */ +struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr, + __be16 sport, __be32 daddr, __be16 dport, int dif, + int sdif, struct udp_table *udptable, struct sk_buff *skb) +{ + unsigned short hnum = ntohs(dport); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + struct sock *result, *sk; + + hash2 = ipv4_portaddr_hash(net, daddr, hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + + /* Lookup connected or non-wildcard socket */ + result = udp4_lib_lookup2(net, saddr, sport, + daddr, hnum, dif, sdif, + hslot2, skb); + if (!IS_ERR_OR_NULL(result) && result->sk_state == TCP_ESTABLISHED) + goto done; + + /* Lookup redirect from BPF */ + if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { + sk = udp4_lookup_run_bpf(net, udptable, skb, + saddr, sport, daddr, hnum, dif); + if (sk) { + result = sk; + goto done; + } + } + + /* Got non-wildcard socket or error on first lookup */ + if (result) + goto done; + + /* Lookup wildcard sockets */ + hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + + result = udp4_lib_lookup2(net, saddr, sport, + htonl(INADDR_ANY), hnum, dif, sdif, + hslot2, skb); +done: + if (IS_ERR(result)) + return NULL; + return result; +} +EXPORT_SYMBOL_GPL(__udp4_lib_lookup); + +static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb, + __be16 sport, __be16 dport, + struct udp_table *udptable) +{ + const struct iphdr *iph = ip_hdr(skb); + + return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, + iph->daddr, dport, inet_iif(skb), + inet_sdif(skb), udptable, skb); +} + +struct sock *udp4_lib_lookup_skb(const struct sk_buff *skb, + __be16 sport, __be16 dport) +{ + const struct iphdr *iph = ip_hdr(skb); + + return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, + iph->daddr, dport, inet_iif(skb), + inet_sdif(skb), &udp_table, NULL); +} + +/* Must be called under rcu_read_lock(). + * Does increment socket refcount. + */ +#if IS_ENABLED(CONFIG_NF_TPROXY_IPV4) || IS_ENABLED(CONFIG_NF_SOCKET_IPV4) +struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport, + __be32 daddr, __be16 dport, int dif) +{ + struct sock *sk; + + sk = __udp4_lib_lookup(net, saddr, sport, daddr, dport, + dif, 0, &udp_table, NULL); + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + return sk; +} +EXPORT_SYMBOL_GPL(udp4_lib_lookup); +#endif + +static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif, int sdif, unsigned short hnum) +{ + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net) || + udp_sk(sk)->udp_port_hash != hnum || + (inet->inet_daddr && inet->inet_daddr != rmt_addr) || + (inet->inet_dport != rmt_port && inet->inet_dport) || + (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) || + ipv6_only_sock(sk) || + !udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) + return false; + if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif, sdif)) + return false; + return true; +} + +DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key); +void udp_encap_enable(void) +{ + static_branch_inc(&udp_encap_needed_key); +} +EXPORT_SYMBOL(udp_encap_enable); + +void udp_encap_disable(void) +{ + static_branch_dec(&udp_encap_needed_key); +} +EXPORT_SYMBOL(udp_encap_disable); + +/* Handler for tunnels with arbitrary destination ports: no socket lookup, go + * through error handlers in encapsulations looking for a match. 
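+ *
+ * A minimal sketch of the registration this probe loop relies on, using
+ * hypothetical names for illustration only:
+ *
+ *	static const struct ip_tunnel_encap_ops foo_encap_ops = {
+ *		.err_handler	= foo_err_handler,
+ *	};
+ *	ip_tunnel_encap_add_ops(&foo_encap_ops, TUNNEL_ENCAP_FOU);
+ *
+ * Each registered ->err_handler() returns 0 when it recognises the ICMP
+ * error, which ends the loop below.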
+ */ +static int __udp4_lib_err_encap_no_sk(struct sk_buff *skb, u32 info) +{ + int i; + + for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) { + int (*handler)(struct sk_buff *skb, u32 info); + const struct ip_tunnel_encap_ops *encap; + + encap = rcu_dereference(iptun_encaps[i]); + if (!encap) + continue; + handler = encap->err_handler; + if (handler && !handler(skb, info)) + return 0; + } + + return -ENOENT; +} + +/* Try to match ICMP errors to UDP tunnels by looking up a socket without + * reversing source and destination port: this will match tunnels that force the + * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that + * lwtunnels might actually break this assumption by being configured with + * different destination ports on endpoints, in this case we won't be able to + * trace ICMP messages back to them. + * + * If this doesn't match any socket, probe tunnels with arbitrary destination + * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port + * we've sent packets to won't necessarily match the local destination port. + * + * Then ask the tunnel implementation to match the error against a valid + * association. + * + * Return an error if we can't find a match, the socket if we need further + * processing, zero otherwise. + */ +static struct sock *__udp4_lib_err_encap(struct net *net, + const struct iphdr *iph, + struct udphdr *uh, + struct udp_table *udptable, + struct sock *sk, + struct sk_buff *skb, u32 info) +{ + int (*lookup)(struct sock *sk, struct sk_buff *skb); + int network_offset, transport_offset; + struct udp_sock *up; + + network_offset = skb_network_offset(skb); + transport_offset = skb_transport_offset(skb); + + /* Network header needs to point to the outer IPv4 header inside ICMP */ + skb_reset_network_header(skb); + + /* Transport header needs to point to the UDP header */ + skb_set_transport_header(skb, iph->ihl << 2); + + if (sk) { + up = udp_sk(sk); + + lookup = READ_ONCE(up->encap_err_lookup); + if (lookup && lookup(sk, skb)) + sk = NULL; + + goto out; + } + + sk = __udp4_lib_lookup(net, iph->daddr, uh->source, + iph->saddr, uh->dest, skb->dev->ifindex, 0, + udptable, NULL); + if (sk) { + up = udp_sk(sk); + + lookup = READ_ONCE(up->encap_err_lookup); + if (!lookup || lookup(sk, skb)) + sk = NULL; + } + +out: + if (!sk) + sk = ERR_PTR(__udp4_lib_err_encap_no_sk(skb, info)); + + skb_set_transport_header(skb, transport_offset); + skb_set_network_header(skb, network_offset); + + return sk; +} + +/* + * This routine is called by the ICMP module when it gets some + * sort of error condition. If err < 0 then the socket should + * be closed and the error returned to the user. If err > 0 + * it's just the icmp type << 8 | icmp code. + * Header points to the ip header of the error packet. We move + * on past this. Then (as it used to claim before adjustment) + * header points to the first 8 bytes of the udp header. We need + * to find the appropriate port. 
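+ *
+ * As a sketch, the ICMP payload at skb->data is laid out as:
+ *
+ *	inner IPv4 header	(iph->ihl << 2 bytes, options included)
+ *	first 8 bytes of UDP	(source, dest, len, check)
+ *
+ * which is why uh below is read at skb->data + (iph->ihl << 2).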
+ */ + +int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable) +{ + struct inet_sock *inet; + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct udphdr *uh = (struct udphdr *)(skb->data+(iph->ihl<<2)); + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + bool tunnel = false; + struct sock *sk; + int harderr; + int err; + struct net *net = dev_net(skb->dev); + + sk = __udp4_lib_lookup(net, iph->daddr, uh->dest, + iph->saddr, uh->source, skb->dev->ifindex, + inet_sdif(skb), udptable, NULL); + + if (!sk || READ_ONCE(udp_sk(sk)->encap_type)) { + /* No socket for error: try tunnels before discarding */ + if (static_branch_unlikely(&udp_encap_needed_key)) { + sk = __udp4_lib_err_encap(net, iph, uh, udptable, sk, skb, + info); + if (!sk) + return 0; + } else + sk = ERR_PTR(-ENOENT); + + if (IS_ERR(sk)) { + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return PTR_ERR(sk); + } + + tunnel = true; + } + + err = 0; + harderr = 0; + inet = inet_sk(sk); + + switch (type) { + default: + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + case ICMP_SOURCE_QUENCH: + goto out; + case ICMP_PARAMETERPROB: + err = EPROTO; + harderr = 1; + break; + case ICMP_DEST_UNREACH: + if (code == ICMP_FRAG_NEEDED) { /* Path MTU discovery */ + ipv4_sk_update_pmtu(skb, sk, info); + if (inet->pmtudisc != IP_PMTUDISC_DONT) { + err = EMSGSIZE; + harderr = 1; + break; + } + goto out; + } + err = EHOSTUNREACH; + if (code <= NR_ICMP_UNREACH) { + harderr = icmp_err_convert[code].fatal; + err = icmp_err_convert[code].errno; + } + break; + case ICMP_REDIRECT: + ipv4_sk_redirect(skb, sk); + goto out; + } + + /* + * RFC1122: OK. Passes ICMP errors back to application, as per + * 4.1.3.3. + */ + if (tunnel) { + /* ...not for tunnels though: we don't have a sending socket */ + if (udp_sk(sk)->encap_err_rcv) + udp_sk(sk)->encap_err_rcv(sk, skb, iph->ihl << 2); + goto out; + } + if (!inet->recverr) { + if (!harderr || sk->sk_state != TCP_ESTABLISHED) + goto out; + } else + ip_icmp_error(sk, skb, err, uh->dest, info, (u8 *)(uh+1)); + + sk->sk_err = err; + sk_error_report(sk); +out: + return 0; +} + +int udp_err(struct sk_buff *skb, u32 info) +{ + return __udp4_lib_err(skb, info, &udp_table); +} + +/* + * Throw away all pending data and cancel the corking. Socket is locked. + */ +void udp_flush_pending_frames(struct sock *sk) +{ + struct udp_sock *up = udp_sk(sk); + + if (up->pending) { + up->len = 0; + WRITE_ONCE(up->pending, 0); + ip_flush_pending_frames(sk); + } +} +EXPORT_SYMBOL(udp_flush_pending_frames); + +/** + * udp4_hwcsum - handle outgoing HW checksumming + * @skb: sk_buff containing the filled-in UDP header + * (checksum field must be zeroed out) + * @src: source IP address + * @dst: destination IP address + */ +void udp4_hwcsum(struct sk_buff *skb, __be32 src, __be32 dst) +{ + struct udphdr *uh = udp_hdr(skb); + int offset = skb_transport_offset(skb); + int len = skb->len - offset; + int hlen = len; + __wsum csum = 0; + + if (!skb_has_frag_list(skb)) { + /* + * Only one fragment on the socket. 
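+ * The checksum can then be finished via CHECKSUM_PARTIAL: uh->check is
+ * pre-loaded with the complemented pseudo-header sum, and
+ * csum_start/csum_offset tell the offloading device where to start
+ * summing and where to store the result.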
+ */ + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + uh->check = ~csum_tcpudp_magic(src, dst, len, + IPPROTO_UDP, 0); + } else { + struct sk_buff *frags; + + /* + * HW-checksum won't work as there are two or more + * fragments on the socket so that all csums of sk_buffs + * should be together + */ + skb_walk_frags(skb, frags) { + csum = csum_add(csum, frags->csum); + hlen -= frags->len; + } + + csum = skb_checksum(skb, offset, hlen, csum); + skb->ip_summed = CHECKSUM_NONE; + + uh->check = csum_tcpudp_magic(src, dst, len, IPPROTO_UDP, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } +} +EXPORT_SYMBOL_GPL(udp4_hwcsum); + +/* Function to set UDP checksum for an IPv4 UDP packet. This is intended + * for the simple case like when setting the checksum for a UDP tunnel. + */ +void udp_set_csum(bool nocheck, struct sk_buff *skb, + __be32 saddr, __be32 daddr, int len) +{ + struct udphdr *uh = udp_hdr(skb); + + if (nocheck) { + uh->check = 0; + } else if (skb_is_gso(skb)) { + uh->check = ~udp_v4_check(len, saddr, daddr, 0); + } else if (skb->ip_summed == CHECKSUM_PARTIAL) { + uh->check = 0; + uh->check = udp_v4_check(len, saddr, daddr, lco_csum(skb)); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } else { + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + uh->check = ~udp_v4_check(len, saddr, daddr, 0); + } +} +EXPORT_SYMBOL(udp_set_csum); + +static int udp_send_skb(struct sk_buff *skb, struct flowi4 *fl4, + struct inet_cork *cork) +{ + struct sock *sk = skb->sk; + struct inet_sock *inet = inet_sk(sk); + struct udphdr *uh; + int err; + int is_udplite = IS_UDPLITE(sk); + int offset = skb_transport_offset(skb); + int len = skb->len - offset; + int datalen = len - sizeof(*uh); + __wsum csum = 0; + + /* + * Create a UDP header + */ + uh = udp_hdr(skb); + uh->source = inet->inet_sport; + uh->dest = fl4->fl4_dport; + uh->len = htons(len); + uh->check = 0; + + if (cork->gso_size) { + const int hlen = skb_network_header_len(skb) + + sizeof(struct udphdr); + + if (hlen + cork->gso_size > cork->fragsize) { + kfree_skb(skb); + return -EINVAL; + } + if (datalen > cork->gso_size * UDP_MAX_SEGMENTS) { + kfree_skb(skb); + return -EINVAL; + } + if (sk->sk_no_check_tx) { + kfree_skb(skb); + return -EINVAL; + } + if (skb->ip_summed != CHECKSUM_PARTIAL || is_udplite || + dst_xfrm(skb_dst(skb))) { + kfree_skb(skb); + return -EIO; + } + + if (datalen > cork->gso_size) { + skb_shinfo(skb)->gso_size = cork->gso_size; + skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4; + skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen, + cork->gso_size); + } + goto csum_partial; + } + + if (is_udplite) /* UDP-Lite */ + csum = udplite_csum(skb); + + else if (sk->sk_no_check_tx) { /* UDP csum off */ + + skb->ip_summed = CHECKSUM_NONE; + goto send; + + } else if (skb->ip_summed == CHECKSUM_PARTIAL) { /* UDP hardware csum */ +csum_partial: + + udp4_hwcsum(skb, fl4->saddr, fl4->daddr); + goto send; + + } else + csum = udp_csum(skb); + + /* add protocol-dependent pseudo-header */ + uh->check = csum_tcpudp_magic(fl4->saddr, fl4->daddr, len, + sk->sk_protocol, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + +send: + err = ip_send_skb(sock_net(sk), skb); + if (err) { + if (err == -ENOBUFS && !inet->recverr) { + UDP_INC_STATS(sock_net(sk), + UDP_MIB_SNDBUFERRORS, is_udplite); + err = 0; + } + } else + UDP_INC_STATS(sock_net(sk), + UDP_MIB_OUTDATAGRAMS, 
is_udplite); + return err; +} + +/* + * Push out all pending data as one UDP datagram. Socket is locked. + */ +int udp_push_pending_frames(struct sock *sk) +{ + struct udp_sock *up = udp_sk(sk); + struct inet_sock *inet = inet_sk(sk); + struct flowi4 *fl4 = &inet->cork.fl.u.ip4; + struct sk_buff *skb; + int err = 0; + + skb = ip_finish_skb(sk, fl4); + if (!skb) + goto out; + + err = udp_send_skb(skb, fl4, &inet->cork.base); + +out: + up->len = 0; + WRITE_ONCE(up->pending, 0); + return err; +} +EXPORT_SYMBOL(udp_push_pending_frames); + +static int __udp_cmsg_send(struct cmsghdr *cmsg, u16 *gso_size) +{ + switch (cmsg->cmsg_type) { + case UDP_SEGMENT: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(__u16))) + return -EINVAL; + *gso_size = *(__u16 *)CMSG_DATA(cmsg); + return 0; + default: + return -EINVAL; + } +} + +int udp_cmsg_send(struct sock *sk, struct msghdr *msg, u16 *gso_size) +{ + struct cmsghdr *cmsg; + bool need_ip = false; + int err; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_UDP) { + need_ip = true; + continue; + } + + err = __udp_cmsg_send(cmsg, gso_size); + if (err) + return err; + } + + return need_ip; +} +EXPORT_SYMBOL_GPL(udp_cmsg_send); + +int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct inet_sock *inet = inet_sk(sk); + struct udp_sock *up = udp_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_in *, usin, msg->msg_name); + struct flowi4 fl4_stack; + struct flowi4 *fl4; + int ulen = len; + struct ipcm_cookie ipc; + struct rtable *rt = NULL; + int free = 0; + int connected = 0; + __be32 daddr, faddr, saddr; + __be16 dport; + u8 tos; + int err, is_udplite = IS_UDPLITE(sk); + int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE; + int (*getfrag)(void *, char *, int, int, int, struct sk_buff *); + struct sk_buff *skb; + struct ip_options_data opt_copy; + + if (len > 0xFFFF) + return -EMSGSIZE; + + /* + * Check the flags. + */ + + if (msg->msg_flags & MSG_OOB) /* Mirror BSD error message compatibility */ + return -EOPNOTSUPP; + + getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag; + + fl4 = &inet->cork.fl.u.ip4; + if (READ_ONCE(up->pending)) { + /* + * There are pending frames. + * The socket lock must be held while it's corked. + */ + lock_sock(sk); + if (likely(up->pending)) { + if (unlikely(up->pending != AF_INET)) { + release_sock(sk); + return -EINVAL; + } + goto do_append_data; + } + release_sock(sk); + } + ulen += sizeof(struct udphdr); + + /* + * Get and verify the address. + */ + if (usin) { + if (msg->msg_namelen < sizeof(*usin)) + return -EINVAL; + if (usin->sin_family != AF_INET) { + if (usin->sin_family != AF_UNSPEC) + return -EAFNOSUPPORT; + } + + daddr = usin->sin_addr.s_addr; + dport = usin->sin_port; + if (dport == 0) + return -EINVAL; + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + daddr = inet->inet_daddr; + dport = inet->inet_dport; + /* Open fast path for connected socket. + Route will not be used, if at least one option is set. 
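+ * (i.e. the socket was connect()ed: daddr and dport are taken from
+ * the socket itself, and the route cached on the socket can be
+ * reused further down via sk_dst_check(), unless control messages
+ * or IP options clear 'connected' again.)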
+ */ + connected = 1; + } + + ipcm_init_sk(&ipc, inet); + ipc.gso_size = READ_ONCE(up->gso_size); + + if (msg->msg_controllen) { + err = udp_cmsg_send(sk, msg, &ipc.gso_size); + if (err > 0) + err = ip_cmsg_send(sk, msg, &ipc, + sk->sk_family == AF_INET6); + if (unlikely(err < 0)) { + kfree(ipc.opt); + return err; + } + if (ipc.opt) + free = 1; + connected = 0; + } + if (!ipc.opt) { + struct ip_options_rcu *inet_opt; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt) { + memcpy(&opt_copy, inet_opt, + sizeof(*inet_opt) + inet_opt->opt.optlen); + ipc.opt = &opt_copy.opt; + } + rcu_read_unlock(); + } + + if (cgroup_bpf_enabled(CGROUP_UDP4_SENDMSG) && !connected) { + err = BPF_CGROUP_RUN_PROG_UDP4_SENDMSG_LOCK(sk, + (struct sockaddr *)usin, &ipc.addr); + if (err) + goto out_free; + if (usin) { + if (usin->sin_port == 0) { + /* BPF program set invalid port. Reject it. */ + err = -EINVAL; + goto out_free; + } + daddr = usin->sin_addr.s_addr; + dport = usin->sin_port; + } + } + + saddr = ipc.addr; + ipc.addr = faddr = daddr; + + if (ipc.opt && ipc.opt->opt.srr) { + if (!daddr) { + err = -EINVAL; + goto out_free; + } + faddr = ipc.opt->opt.faddr; + connected = 0; + } + tos = get_rttos(&ipc, inet); + if (sock_flag(sk, SOCK_LOCALROUTE) || + (msg->msg_flags & MSG_DONTROUTE) || + (ipc.opt && ipc.opt->opt.is_strictroute)) { + tos |= RTO_ONLINK; + connected = 0; + } + + if (ipv4_is_multicast(daddr)) { + if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif)) + ipc.oif = inet->mc_index; + if (!saddr) + saddr = inet->mc_addr; + connected = 0; + } else if (!ipc.oif) { + ipc.oif = inet->uc_index; + } else if (ipv4_is_lbcast(daddr) && inet->uc_index) { + /* oif is set, packet is to local broadcast and + * uc_index is set. oif is most likely set + * by sk_bound_dev_if. If uc_index != oif check if the + * oif is an L3 master and uc_index is an L3 slave. + * If so, we want to allow the send using the uc_index. + */ + if (ipc.oif != inet->uc_index && + ipc.oif == l3mdev_master_ifindex_by_index(sock_net(sk), + inet->uc_index)) { + ipc.oif = inet->uc_index; + } + } + + if (connected) + rt = (struct rtable *)sk_dst_check(sk, 0); + + if (!rt) { + struct net *net = sock_net(sk); + __u8 flow_flags = inet_sk_flowi_flags(sk); + + fl4 = &fl4_stack; + + flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, tos, + RT_SCOPE_UNIVERSE, sk->sk_protocol, + flow_flags, + faddr, saddr, dport, inet->inet_sport, + sk->sk_uid); + + security_sk_classify_flow(sk, flowi4_to_flowi_common(fl4)); + rt = ip_route_output_flow(net, fl4, sk); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + if (err == -ENETUNREACH) + IP_INC_STATS(net, IPSTATS_MIB_OUTNOROUTES); + goto out; + } + + err = -EACCES; + if ((rt->rt_flags & RTCF_BROADCAST) && + !sock_flag(sk, SOCK_BROADCAST)) + goto out; + if (connected) + sk_dst_set(sk, dst_clone(&rt->dst)); + } + + if (msg->msg_flags&MSG_CONFIRM) + goto do_confirm; +back_from_confirm: + + saddr = fl4->saddr; + if (!ipc.addr) + daddr = ipc.addr = fl4->daddr; + + /* Lockless fast path for the non-corking case. */ + if (!corkreq) { + struct inet_cork cork; + + skb = ip_make_skb(sk, fl4, getfrag, msg, ulen, + sizeof(struct udphdr), &ipc, &rt, + &cork, msg->msg_flags); + err = PTR_ERR(skb); + if (!IS_ERR_OR_NULL(skb)) + err = udp_send_skb(skb, fl4, &cork); + goto out; + } + + lock_sock(sk); + if (unlikely(up->pending)) { + /* The socket is already corked while preparing it. */ + /* ... which is an evident application bug. 
--ANK */ + release_sock(sk); + + net_dbg_ratelimited("socket already corked\n"); + err = -EINVAL; + goto out; + } + /* + * Now cork the socket to pend data. + */ + fl4 = &inet->cork.fl.u.ip4; + fl4->daddr = daddr; + fl4->saddr = saddr; + fl4->fl4_dport = dport; + fl4->fl4_sport = inet->inet_sport; + WRITE_ONCE(up->pending, AF_INET); + +do_append_data: + up->len += ulen; + err = ip_append_data(sk, fl4, getfrag, msg, ulen, + sizeof(struct udphdr), &ipc, &rt, + corkreq ? msg->msg_flags|MSG_MORE : msg->msg_flags); + if (err) + udp_flush_pending_frames(sk); + else if (!corkreq) + err = udp_push_pending_frames(sk); + else if (unlikely(skb_queue_empty(&sk->sk_write_queue))) + WRITE_ONCE(up->pending, 0); + release_sock(sk); + +out: + ip_rt_put(rt); +out_free: + if (free) + kfree(ipc.opt); + if (!err) + return len; + /* + * ENOBUFS = no kernel mem, SOCK_NOSPACE = no sndbuf space. Reporting + * ENOBUFS might not be good (it's not tunable per se), but otherwise + * we don't have a good statistic (IpOutDiscards but it can be too many + * things). We could add another new stat but at least for now that + * seems like overkill. + */ + if (err == -ENOBUFS || test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) { + UDP_INC_STATS(sock_net(sk), + UDP_MIB_SNDBUFERRORS, is_udplite); + } + return err; + +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(&rt->dst, &fl4->daddr); + if (!(msg->msg_flags&MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto out; +} +EXPORT_SYMBOL(udp_sendmsg); + +void udp_splice_eof(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct udp_sock *up = udp_sk(sk); + + if (!READ_ONCE(up->pending) || udp_test_bit(CORK, sk)) + return; + + lock_sock(sk); + if (up->pending && !udp_test_bit(CORK, sk)) + udp_push_pending_frames(sk); + release_sock(sk); +} +EXPORT_SYMBOL_GPL(udp_splice_eof); + +int udp_sendpage(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct bio_vec bvec; + struct msghdr msg = { .msg_flags = flags | MSG_SPLICE_PAGES }; + + if (flags & MSG_SENDPAGE_NOTLAST) + msg.msg_flags |= MSG_MORE; + + bvec_set_page(&bvec, page, size, offset); + iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); + return udp_sendmsg(sk, &msg, size); +} + +#define UDP_SKB_IS_STATELESS 0x80000000 + +/* all head states (dst, sk, nf conntrack) except skb extensions are + * cleared by udp_rcv(). + * + * We need to preserve secpath, if present, to eventually process + * IP_CMSG_PASSSEC at recvmsg() time. + * + * Other extensions can be cleared. + */ +static bool udp_try_make_stateless(struct sk_buff *skb) +{ + if (!skb_has_extensions(skb)) + return true; + + if (!secpath_exists(skb)) { + skb_ext_reset(skb); + return true; + } + + return false; +} + +static void udp_set_dev_scratch(struct sk_buff *skb) +{ + struct udp_dev_scratch *scratch = udp_skb_scratch(skb); + + BUILD_BUG_ON(sizeof(struct udp_dev_scratch) > sizeof(long)); + scratch->_tsize_state = skb->truesize; +#if BITS_PER_LONG == 64 + scratch->len = skb->len; + scratch->csum_unnecessary = !!skb_csum_unnecessary(skb); + scratch->is_linear = !skb_is_nonlinear(skb); +#endif + if (udp_try_make_stateless(skb)) + scratch->_tsize_state |= UDP_SKB_IS_STATELESS; +} + +static void udp_skb_csum_unnecessary_set(struct sk_buff *skb) +{ + /* We come here after udp_lib_checksum_complete() returned 0. + * This means that __skb_checksum_complete() might have + * set skb->csum_valid to 1. + * On 64bit platforms, we can set csum_unnecessary + * to true, but only if the skb is not shared. 
+ */ +#if BITS_PER_LONG == 64 + if (!skb_shared(skb)) + udp_skb_scratch(skb)->csum_unnecessary = true; +#endif +} + +static int udp_skb_truesize(struct sk_buff *skb) +{ + return udp_skb_scratch(skb)->_tsize_state & ~UDP_SKB_IS_STATELESS; +} + +static bool udp_skb_has_head_state(struct sk_buff *skb) +{ + return !(udp_skb_scratch(skb)->_tsize_state & UDP_SKB_IS_STATELESS); +} + +/* fully reclaim rmem/fwd memory allocated for skb */ +static void udp_rmem_release(struct sock *sk, int size, int partial, + bool rx_queue_lock_held) +{ + struct udp_sock *up = udp_sk(sk); + struct sk_buff_head *sk_queue; + int amt; + + if (likely(partial)) { + up->forward_deficit += size; + size = up->forward_deficit; + if (size < (sk->sk_rcvbuf >> 2) && + !skb_queue_empty(&up->reader_queue)) + return; + } else { + size += up->forward_deficit; + } + up->forward_deficit = 0; + + /* acquire the sk_receive_queue for fwd allocated memory scheduling, + * if the called don't held it already + */ + sk_queue = &sk->sk_receive_queue; + if (!rx_queue_lock_held) + spin_lock(&sk_queue->lock); + + + sk_forward_alloc_add(sk, size); + amt = (sk->sk_forward_alloc - partial) & ~(PAGE_SIZE - 1); + sk_forward_alloc_add(sk, -amt); + + if (amt) + __sk_mem_reduce_allocated(sk, amt >> PAGE_SHIFT); + + atomic_sub(size, &sk->sk_rmem_alloc); + + /* this can save us from acquiring the rx queue lock on next receive */ + skb_queue_splice_tail_init(sk_queue, &up->reader_queue); + + if (!rx_queue_lock_held) + spin_unlock(&sk_queue->lock); +} + +/* Note: called with reader_queue.lock held. + * Instead of using skb->truesize here, find a copy of it in skb->dev_scratch + * This avoids a cache line miss while receive_queue lock is held. + * Look at __udp_enqueue_schedule_skb() to find where this copy is done. + */ +void udp_skb_destructor(struct sock *sk, struct sk_buff *skb) +{ + prefetch(&skb->data); + udp_rmem_release(sk, udp_skb_truesize(skb), 1, false); +} +EXPORT_SYMBOL(udp_skb_destructor); + +/* as above, but the caller held the rx queue lock, too */ +static void udp_skb_dtor_locked(struct sock *sk, struct sk_buff *skb) +{ + prefetch(&skb->data); + udp_rmem_release(sk, udp_skb_truesize(skb), 1, true); +} + +/* Idea of busylocks is to let producers grab an extra spinlock + * to relieve pressure on the receive_queue spinlock shared by consumer. + * Under flood, this means that only one producer can be in line + * trying to acquire the receive_queue spinlock. 
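+ * A producer picks its busylock by hashing the socket pointer: for
+ * example, with udp_busylocks_log == 3 there would be 8 locks and
+ * hash_ptr(sk, 3) would select one of them (see busylock_acquire()
+ * below).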
+ * These busylock can be allocated on a per cpu manner, instead of a + * per socket one (that would consume a cache line per socket) + */ +static int udp_busylocks_log __read_mostly; +static spinlock_t *udp_busylocks __read_mostly; + +static spinlock_t *busylock_acquire(void *ptr) +{ + spinlock_t *busy; + + busy = udp_busylocks + hash_ptr(ptr, udp_busylocks_log); + spin_lock(busy); + return busy; +} + +static void busylock_release(spinlock_t *busy) +{ + if (busy) + spin_unlock(busy); +} + +int __udp_enqueue_schedule_skb(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff_head *list = &sk->sk_receive_queue; + int rmem, delta, amt, err = -ENOMEM; + spinlock_t *busy = NULL; + int size; + + /* try to avoid the costly atomic add/sub pair when the receive + * queue is full; always allow at least a packet + */ + rmem = atomic_read(&sk->sk_rmem_alloc); + if (rmem > sk->sk_rcvbuf) + goto drop; + + /* Under mem pressure, it might be helpful to help udp_recvmsg() + * having linear skbs : + * - Reduce memory overhead and thus increase receive queue capacity + * - Less cache line misses at copyout() time + * - Less work at consume_skb() (less alien page frag freeing) + */ + if (rmem > (sk->sk_rcvbuf >> 1)) { + skb_condense(skb); + + busy = busylock_acquire(sk); + } + size = skb->truesize; + udp_set_dev_scratch(skb); + + /* we drop only if the receive buf is full and the receive + * queue contains some other skb + */ + rmem = atomic_add_return(size, &sk->sk_rmem_alloc); + if (rmem > (size + (unsigned int)sk->sk_rcvbuf)) + goto uncharge_drop; + + spin_lock(&list->lock); + if (size >= sk->sk_forward_alloc) { + amt = sk_mem_pages(size); + delta = amt << PAGE_SHIFT; + if (!__sk_mem_raise_allocated(sk, delta, amt, SK_MEM_RECV)) { + err = -ENOBUFS; + spin_unlock(&list->lock); + goto uncharge_drop; + } + + sk->sk_forward_alloc += delta; + } + + sk_forward_alloc_add(sk, -size); + + /* no need to setup a destructor, we will explicitly release the + * forward allocated memory on dequeue + */ + sock_skb_set_dropcount(sk, skb); + + __skb_queue_tail(list, skb); + spin_unlock(&list->lock); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + + busylock_release(busy); + return 0; + +uncharge_drop: + atomic_sub(skb->truesize, &sk->sk_rmem_alloc); + +drop: + atomic_inc(&sk->sk_drops); + busylock_release(busy); + return err; +} +EXPORT_SYMBOL_GPL(__udp_enqueue_schedule_skb); + +void udp_destruct_common(struct sock *sk) +{ + /* reclaim completely the forward allocated memory */ + struct udp_sock *up = udp_sk(sk); + unsigned int total = 0; + struct sk_buff *skb; + + skb_queue_splice_tail_init(&sk->sk_receive_queue, &up->reader_queue); + while ((skb = __skb_dequeue(&up->reader_queue)) != NULL) { + total += skb->truesize; + kfree_skb(skb); + } + udp_rmem_release(sk, total, 0, true); +} +EXPORT_SYMBOL_GPL(udp_destruct_common); + +static void udp_destruct_sock(struct sock *sk) +{ + udp_destruct_common(sk); + inet_sock_destruct(sk); +} + +int udp_init_sock(struct sock *sk) +{ + skb_queue_head_init(&udp_sk(sk)->reader_queue); + sk->sk_destruct = udp_destruct_sock; + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + return 0; +} + +void skb_consume_udp(struct sock *sk, struct sk_buff *skb, int len) +{ + if (unlikely(READ_ONCE(sk->sk_peek_off) >= 0)) { + bool slow = lock_sock_fast(sk); + + sk_peek_offset_bwd(sk, len); + unlock_sock_fast(sk, slow); + } + + if (!skb_unref(skb)) + return; + + /* In the more common cases we cleared the head states previously, + * see __udp_queue_rcv_skb(). 
+ */ + if (unlikely(udp_skb_has_head_state(skb))) + skb_release_head_state(skb); + __consume_stateless_skb(skb); +} +EXPORT_SYMBOL_GPL(skb_consume_udp); + +static struct sk_buff *__first_packet_length(struct sock *sk, + struct sk_buff_head *rcvq, + int *total) +{ + struct sk_buff *skb; + + while ((skb = skb_peek(rcvq)) != NULL) { + if (udp_lib_checksum_complete(skb)) { + __UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, + IS_UDPLITE(sk)); + __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, + IS_UDPLITE(sk)); + atomic_inc(&sk->sk_drops); + __skb_unlink(skb, rcvq); + *total += skb->truesize; + kfree_skb(skb); + } else { + udp_skb_csum_unnecessary_set(skb); + break; + } + } + return skb; +} + +/** + * first_packet_length - return length of first packet in receive queue + * @sk: socket + * + * Drops all bad checksum frames, until a valid one is found. + * Returns the length of found skb, or -1 if none is found. + */ +static int first_packet_length(struct sock *sk) +{ + struct sk_buff_head *rcvq = &udp_sk(sk)->reader_queue; + struct sk_buff_head *sk_queue = &sk->sk_receive_queue; + struct sk_buff *skb; + int total = 0; + int res; + + spin_lock_bh(&rcvq->lock); + skb = __first_packet_length(sk, rcvq, &total); + if (!skb && !skb_queue_empty_lockless(sk_queue)) { + spin_lock(&sk_queue->lock); + skb_queue_splice_tail_init(sk_queue, rcvq); + spin_unlock(&sk_queue->lock); + + skb = __first_packet_length(sk, rcvq, &total); + } + res = skb ? skb->len : -1; + if (total) + udp_rmem_release(sk, total, 1, false); + spin_unlock_bh(&rcvq->lock); + return res; +} + +/* + * IOCTL requests applicable to the UDP protocol + */ + +int udp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: + { + int amount = sk_wmem_alloc_get(sk); + + return put_user(amount, (int __user *)arg); + } + + case SIOCINQ: + { + int amount = max_t(int, 0, first_packet_length(sk)); + + return put_user(amount, (int __user *)arg); + } + + default: + return -ENOIOCTLCMD; + } + + return 0; +} +EXPORT_SYMBOL(udp_ioctl); + +struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags, + int *off, int *err) +{ + struct sk_buff_head *sk_queue = &sk->sk_receive_queue; + struct sk_buff_head *queue; + struct sk_buff *last; + long timeo; + int error; + + queue = &udp_sk(sk)->reader_queue; + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + do { + struct sk_buff *skb; + + error = sock_error(sk); + if (error) + break; + + error = -EAGAIN; + do { + spin_lock_bh(&queue->lock); + skb = __skb_try_recv_from_queue(sk, queue, flags, off, + err, &last); + if (skb) { + if (!(flags & MSG_PEEK)) + udp_skb_destructor(sk, skb); + spin_unlock_bh(&queue->lock); + return skb; + } + + if (skb_queue_empty_lockless(sk_queue)) { + spin_unlock_bh(&queue->lock); + goto busy_check; + } + + /* refill the reader queue and walk it again + * keep both queues locked to avoid re-acquiring + * the sk_receive_queue lock if fwd memory scheduling + * is needed. 
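+ * Splicing sk_receive_queue onto the reader queue in one go means the
+ * producer-visible lock is taken at most once per batch of packets;
+ * the consumer then keeps working on its private reader_queue.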
+ */ + spin_lock(&sk_queue->lock); + skb_queue_splice_tail_init(sk_queue, queue); + + skb = __skb_try_recv_from_queue(sk, queue, flags, off, + err, &last); + if (skb && !(flags & MSG_PEEK)) + udp_skb_dtor_locked(sk, skb); + spin_unlock(&sk_queue->lock); + spin_unlock_bh(&queue->lock); + if (skb) + return skb; + +busy_check: + if (!sk_can_busy_loop(sk)) + break; + + sk_busy_loop(sk, flags & MSG_DONTWAIT); + } while (!skb_queue_empty_lockless(sk_queue)); + + /* sk_queue is empty, reader_queue may contain peeked packets */ + } while (timeo && + !__skb_wait_for_more_packets(sk, &sk->sk_receive_queue, + &error, &timeo, + (struct sk_buff *)sk_queue)); + + *err = error; + return NULL; +} +EXPORT_SYMBOL(__skb_recv_udp); + +int udp_read_skb(struct sock *sk, skb_read_actor_t recv_actor) +{ + struct sk_buff *skb; + int err; + +try_again: + skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); + if (!skb) + return err; + + if (udp_lib_checksum_complete(skb)) { + int is_udplite = IS_UDPLITE(sk); + struct net *net = sock_net(sk); + + __UDP_INC_STATS(net, UDP_MIB_CSUMERRORS, is_udplite); + __UDP_INC_STATS(net, UDP_MIB_INERRORS, is_udplite); + atomic_inc(&sk->sk_drops); + kfree_skb(skb); + goto try_again; + } + + WARN_ON_ONCE(!skb_set_owner_sk_safe(skb, sk)); + return recv_actor(sk, skb); +} +EXPORT_SYMBOL(udp_read_skb); + +/* + * This should be easy, if there is something there we + * return it, otherwise we block. + */ + +int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + struct sk_buff *skb; + unsigned int ulen, copied; + int off, err, peeking = flags & MSG_PEEK; + int is_udplite = IS_UDPLITE(sk); + bool checksum_valid = false; + + if (flags & MSG_ERRQUEUE) + return ip_recv_error(sk, msg, len, addr_len); + +try_again: + off = sk_peek_offset(sk, flags); + skb = __skb_recv_udp(sk, flags, &off, &err); + if (!skb) + return err; + + ulen = udp_skb_len(skb); + copied = len; + if (copied > ulen - off) + copied = ulen - off; + else if (copied < ulen) + msg->msg_flags |= MSG_TRUNC; + + /* + * If checksum is needed at all, try to do it while copying the + * data. If the data is truncated, or if we only want a partial + * coverage checksum (UDP-Lite), do it before the copy. + */ + + if (copied < ulen || peeking || + (is_udplite && UDP_SKB_CB(skb)->partial_cov)) { + checksum_valid = udp_skb_csum_unnecessary(skb) || + !__udp_lib_checksum_complete(skb); + if (!checksum_valid) + goto csum_copy_err; + } + + if (checksum_valid || udp_skb_csum_unnecessary(skb)) { + if (udp_skb_is_linear(skb)) + err = copy_linear_skb(skb, copied, off, &msg->msg_iter); + else + err = skb_copy_datagram_msg(skb, off, msg, copied); + } else { + err = skb_copy_and_csum_datagram_msg(skb, off, msg); + + if (err == -EINVAL) + goto csum_copy_err; + } + + if (unlikely(err)) { + if (!peeking) { + atomic_inc(&sk->sk_drops); + UDP_INC_STATS(sock_net(sk), + UDP_MIB_INERRORS, is_udplite); + } + kfree_skb(skb); + return err; + } + + if (!peeking) + UDP_INC_STATS(sock_net(sk), + UDP_MIB_INDATAGRAMS, is_udplite); + + sock_recv_cmsgs(msg, sk, skb); + + /* Copy the address. 
*/ + if (sin) { + sin->sin_family = AF_INET; + sin->sin_port = udp_hdr(skb)->source; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + *addr_len = sizeof(*sin); + + BPF_CGROUP_RUN_PROG_UDP4_RECVMSG_LOCK(sk, + (struct sockaddr *)sin); + } + + if (udp_test_bit(GRO_ENABLED, sk)) + udp_cmsg_recv(msg, sk, skb); + + if (inet->cmsg_flags) + ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off); + + err = copied; + if (flags & MSG_TRUNC) + err = ulen; + + skb_consume_udp(sk, skb, peeking ? -err : err); + return err; + +csum_copy_err: + if (!__sk_queue_drop_skb(sk, &udp_sk(sk)->reader_queue, skb, flags, + udp_skb_destructor)) { + UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite); + UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); + } + kfree_skb(skb); + + /* starting over for a new packet, but check if we need to yield */ + cond_resched(); + msg->msg_flags &= ~MSG_TRUNC; + goto try_again; +} + +int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + /* This check is replicated from __ip4_datagram_connect() and + * intended to prevent BPF program called below from accessing bytes + * that are out of the bound specified by user in addr_len. + */ + if (addr_len < sizeof(struct sockaddr_in)) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET4_CONNECT_LOCK(sk, uaddr); +} +EXPORT_SYMBOL(udp_pre_connect); + +int __udp_disconnect(struct sock *sk, int flags) +{ + struct inet_sock *inet = inet_sk(sk); + /* + * 1003.1g - break association. + */ + + sk->sk_state = TCP_CLOSE; + inet->inet_daddr = 0; + inet->inet_dport = 0; + sock_rps_reset_rxhash(sk); + sk->sk_bound_dev_if = 0; + if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) { + inet_reset_saddr(sk); + if (sk->sk_prot->rehash && + (sk->sk_userlocks & SOCK_BINDPORT_LOCK)) + sk->sk_prot->rehash(sk); + } + + if (!(sk->sk_userlocks & SOCK_BINDPORT_LOCK)) { + sk->sk_prot->unhash(sk); + inet->inet_sport = 0; + } + sk_dst_reset(sk); + return 0; +} +EXPORT_SYMBOL(__udp_disconnect); + +int udp_disconnect(struct sock *sk, int flags) +{ + lock_sock(sk); + __udp_disconnect(sk, flags); + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(udp_disconnect); + +void udp_lib_unhash(struct sock *sk) +{ + if (sk_hashed(sk)) { + struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_hslot *hslot, *hslot2; + + hslot = udp_hashslot(udptable, sock_net(sk), + udp_sk(sk)->udp_port_hash); + hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); + + spin_lock_bh(&hslot->lock); + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_detach_sock(sk); + if (sk_del_node_init_rcu(sk)) { + hslot->count--; + inet_sk(sk)->inet_num = 0; + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + + spin_lock(&hslot2->lock); + hlist_del_init_rcu(&udp_sk(sk)->udp_portaddr_node); + hslot2->count--; + spin_unlock(&hslot2->lock); + } + spin_unlock_bh(&hslot->lock); + } +} +EXPORT_SYMBOL(udp_lib_unhash); + +/* + * inet_rcv_saddr was changed, we must rehash secondary hash + */ +void udp_lib_rehash(struct sock *sk, u16 newhash) +{ + if (sk_hashed(sk)) { + struct udp_table *udptable = sk->sk_prot->h.udp_table; + struct udp_hslot *hslot, *hslot2, *nhslot2; + + hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash); + nhslot2 = udp_hashslot2(udptable, newhash); + udp_sk(sk)->udp_portaddr_hash = newhash; + + if (hslot2 != nhslot2 || + rcu_access_pointer(sk->sk_reuseport_cb)) { + hslot = udp_hashslot(udptable, sock_net(sk), + udp_sk(sk)->udp_port_hash); + /* we must lock primary 
chain too */ + spin_lock_bh(&hslot->lock); + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_detach_sock(sk); + + if (hslot2 != nhslot2) { + spin_lock(&hslot2->lock); + hlist_del_init_rcu(&udp_sk(sk)->udp_portaddr_node); + hslot2->count--; + spin_unlock(&hslot2->lock); + + spin_lock(&nhslot2->lock); + hlist_add_head_rcu(&udp_sk(sk)->udp_portaddr_node, + &nhslot2->head); + nhslot2->count++; + spin_unlock(&nhslot2->lock); + } + + spin_unlock_bh(&hslot->lock); + } + } +} +EXPORT_SYMBOL(udp_lib_rehash); + +void udp_v4_rehash(struct sock *sk) +{ + u16 new_hash = ipv4_portaddr_hash(sock_net(sk), + inet_sk(sk)->inet_rcv_saddr, + inet_sk(sk)->inet_num); + udp_lib_rehash(sk, new_hash); +} + +static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int rc; + + if (inet_sk(sk)->inet_daddr) { + sock_rps_save_rxhash(sk, skb); + sk_mark_napi_id(sk, skb); + sk_incoming_cpu_update(sk); + } else { + sk_mark_napi_id_once(sk, skb); + } + + rc = __udp_enqueue_schedule_skb(sk, skb); + if (rc < 0) { + int is_udplite = IS_UDPLITE(sk); + int drop_reason; + + /* Note that an ENOMEM error is charged twice */ + if (rc == -ENOMEM) { + UDP_INC_STATS(sock_net(sk), UDP_MIB_RCVBUFERRORS, + is_udplite); + drop_reason = SKB_DROP_REASON_SOCKET_RCVBUFF; + } else { + UDP_INC_STATS(sock_net(sk), UDP_MIB_MEMERRORS, + is_udplite); + drop_reason = SKB_DROP_REASON_PROTO_MEM; + } + UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); + kfree_skb_reason(skb, drop_reason); + trace_udp_fail_queue_rcv_skb(rc, sk); + return -1; + } + + return 0; +} + +/* returns: + * -1: error + * 0: success + * >0: "udp encap" protocol resubmission + * + * Note that in the success and error cases, the skb is assumed to + * have either been requeued or freed. + */ +static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) +{ + int drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + struct udp_sock *up = udp_sk(sk); + int is_udplite = IS_UDPLITE(sk); + + /* + * Charge it to the socket, dropping if the queue is full. + */ + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop; + } + nf_reset_ct(skb); + + if (static_branch_unlikely(&udp_encap_needed_key) && + READ_ONCE(up->encap_type)) { + int (*encap_rcv)(struct sock *sk, struct sk_buff *skb); + + /* + * This is an encapsulation socket so pass the skb to + * the socket's udp_encap_rcv() hook. Otherwise, just + * fall through and pass this up the UDP socket. + * up->encap_rcv() returns the following value: + * =0 if skb was successfully passed to the encap + * handler or was discarded by it. + * >0 if skb should be passed on to UDP. + * <0 if skb should be resubmitted as proto -N + */ + + /* if we're overly short, let UDP handle it */ + encap_rcv = READ_ONCE(up->encap_rcv); + if (encap_rcv) { + int ret; + + /* Verify checksum before giving to encap */ + if (udp_lib_checksum_complete(skb)) + goto csum_error; + + ret = encap_rcv(sk, skb); + if (ret <= 0) { + __UDP_INC_STATS(sock_net(sk), + UDP_MIB_INDATAGRAMS, + is_udplite); + return -ret; + } + } + + /* FALLTHROUGH -- it's a UDP Packet */ + } + + /* + * UDP-Lite specific tests, ignored on UDP sockets + */ + if ((up->pcflag & UDPLITE_RECV_CC) && UDP_SKB_CB(skb)->partial_cov) { + + /* + * MIB statistics other than incrementing the error count are + * disabled for the following two types of errors: these depend + * on the application settings, not on the functioning of the + * protocol stack as such. 
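+ *
+ * (up->pcrlen below is the minimum coverage the application requested
+ * via the UDPLITE_RECV_CSCOV socket option; UDPLITE_RECV_CC in
+ * up->pcflag records that such a request was made.)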
+ * + * RFC 3828 here recommends (sec 3.3): "There should also be a + * way ... to ... at least let the receiving application block + * delivery of packets with coverage values less than a value + * provided by the application." + */ + if (up->pcrlen == 0) { /* full coverage was set */ + net_dbg_ratelimited("UDPLite: partial coverage %d while full coverage %d requested\n", + UDP_SKB_CB(skb)->cscov, skb->len); + goto drop; + } + /* The next case involves violating the min. coverage requested + * by the receiver. This is subtle: if receiver wants x and x is + * greater than the buffersize/MTU then receiver will complain + * that it wants x while sender emits packets of smaller size y. + * Therefore the above ...()->partial_cov statement is essential. + */ + if (UDP_SKB_CB(skb)->cscov < up->pcrlen) { + net_dbg_ratelimited("UDPLite: coverage %d too small, need min %d\n", + UDP_SKB_CB(skb)->cscov, up->pcrlen); + goto drop; + } + } + + prefetch(&sk->sk_rmem_alloc); + if (rcu_access_pointer(sk->sk_filter) && + udp_lib_checksum_complete(skb)) + goto csum_error; + + if (sk_filter_trim_cap(sk, skb, sizeof(struct udphdr))) { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + goto drop; + } + + udp_csum_pull_header(skb); + + ipv4_pktinfo_prepare(sk, skb); + return __udp_queue_rcv_skb(sk, skb); + +csum_error: + drop_reason = SKB_DROP_REASON_UDP_CSUM; + __UDP_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite); +drop: + __UDP_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); + atomic_inc(&sk->sk_drops); + kfree_skb_reason(skb, drop_reason); + return -1; +} + +static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *next, *segs; + int ret; + + if (likely(!udp_unexpected_gso(sk, skb))) + return udp_queue_rcv_one_skb(sk, skb); + + BUILD_BUG_ON(sizeof(struct udp_skb_cb) > SKB_GSO_CB_OFFSET); + __skb_push(skb, -skb_mac_offset(skb)); + segs = udp_rcv_segment(sk, skb, true); + skb_list_walk_safe(segs, skb, next) { + __skb_pull(skb, skb_transport_offset(skb)); + + udp_post_segment_fix_csum(skb); + ret = udp_queue_rcv_one_skb(sk, skb); + if (ret > 0) + ip_protocol_deliver_rcu(dev_net(skb->dev), skb, ret); + } + return 0; +} + +/* For TCP sockets, sk_rx_dst is protected by socket lock + * For UDP, we use xchg() to guard against concurrent changes. + */ +bool udp_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) +{ + struct dst_entry *old; + + if (dst_hold_safe(dst)) { + old = xchg((__force struct dst_entry **)&sk->sk_rx_dst, dst); + dst_release(old); + return old != dst; + } + return false; +} +EXPORT_SYMBOL(udp_sk_rx_dst_set); + +/* + * Multicasts and broadcasts go to each listener. + * + * Note: called only from the BH handler context. 
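+ *
+ * Every matching socket after the first one receives its own clone of
+ * the skb; the original skb is queued to the first match at the end,
+ * or freed if no socket matched.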
+ */ +static int __udp4_lib_mcast_deliver(struct net *net, struct sk_buff *skb, + struct udphdr *uh, + __be32 saddr, __be32 daddr, + struct udp_table *udptable, + int proto) +{ + struct sock *sk, *first = NULL; + unsigned short hnum = ntohs(uh->dest); + struct udp_hslot *hslot = udp_hashslot(udptable, net, hnum); + unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10); + unsigned int offset = offsetof(typeof(*sk), sk_node); + int dif = skb->dev->ifindex; + int sdif = inet_sdif(skb); + struct hlist_node *node; + struct sk_buff *nskb; + + if (use_hash2) { + hash2_any = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum) & + udptable->mask; + hash2 = ipv4_portaddr_hash(net, daddr, hnum) & udptable->mask; +start_lookup: + hslot = &udptable->hash2[hash2]; + offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node); + } + + sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) { + if (!__udp_is_mcast_sock(net, sk, uh->dest, daddr, + uh->source, saddr, dif, sdif, hnum)) + continue; + + if (!first) { + first = sk; + continue; + } + nskb = skb_clone(skb, GFP_ATOMIC); + + if (unlikely(!nskb)) { + atomic_inc(&sk->sk_drops); + __UDP_INC_STATS(net, UDP_MIB_RCVBUFERRORS, + IS_UDPLITE(sk)); + __UDP_INC_STATS(net, UDP_MIB_INERRORS, + IS_UDPLITE(sk)); + continue; + } + if (udp_queue_rcv_skb(sk, nskb) > 0) + consume_skb(nskb); + } + + /* Also lookup *:port if we are using hash2 and haven't done so yet. */ + if (use_hash2 && hash2 != hash2_any) { + hash2 = hash2_any; + goto start_lookup; + } + + if (first) { + if (udp_queue_rcv_skb(first, skb) > 0) + consume_skb(skb); + } else { + kfree_skb(skb); + __UDP_INC_STATS(net, UDP_MIB_IGNOREDMULTI, + proto == IPPROTO_UDPLITE); + } + return 0; +} + +/* Initialize UDP checksum. If exited with zero value (success), + * CHECKSUM_UNNECESSARY means, that no more checks are required. + * Otherwise, csum completion requires checksumming packet body, + * including udp header and folding it to skb->csum. + */ +static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh, + int proto) +{ + int err; + + UDP_SKB_CB(skb)->partial_cov = 0; + UDP_SKB_CB(skb)->cscov = skb->len; + + if (proto == IPPROTO_UDPLITE) { + err = udplite_checksum_init(skb, uh); + if (err) + return err; + + if (UDP_SKB_CB(skb)->partial_cov) { + skb->csum = inet_compute_pseudo(skb, proto); + return 0; + } + } + + /* Note, we are only interested in != 0 or == 0, thus the + * force to int. + */ + err = (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, + inet_compute_pseudo); + if (err) + return err; + + if (skb->ip_summed == CHECKSUM_COMPLETE && !skb->csum_valid) { + /* If SW calculated the value, we know it's bad */ + if (skb->csum_complete_sw) + return 1; + + /* HW says the value is bad. Let's validate that. + * skb->csum is no longer the full packet checksum, + * so don't treat it as such. 
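+ * skb_checksum_complete_unset() below drops CHECKSUM_COMPLETE, so
+ * the checksum will be recomputed in software later if it is
+ * actually needed.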
+ */ + skb_checksum_complete_unset(skb); + } + + return 0; +} + +/* wrapper for udp_queue_rcv_skb tacking care of csum conversion and + * return code conversion for ip layer consumption + */ +static int udp_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb, + struct udphdr *uh) +{ + int ret; + + if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk)) + skb_checksum_try_convert(skb, IPPROTO_UDP, inet_compute_pseudo); + + ret = udp_queue_rcv_skb(sk, skb); + + /* a return value > 0 means to resubmit the input, but + * it wants the return to be -protocol, or 0 + */ + if (ret > 0) + return -ret; + return 0; +} + +/* + * All we need to do is get the socket, and then do a checksum. + */ + +int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, + int proto) +{ + struct sock *sk; + struct udphdr *uh; + unsigned short ulen; + struct rtable *rt = skb_rtable(skb); + __be32 saddr, daddr; + struct net *net = dev_net(skb->dev); + bool refcounted; + int drop_reason; + + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + + /* + * Validate the packet. + */ + if (!pskb_may_pull(skb, sizeof(struct udphdr))) + goto drop; /* No space for header. */ + + uh = udp_hdr(skb); + ulen = ntohs(uh->len); + saddr = ip_hdr(skb)->saddr; + daddr = ip_hdr(skb)->daddr; + + if (ulen > skb->len) + goto short_packet; + + if (proto == IPPROTO_UDP) { + /* UDP validates ulen. */ + if (ulen < sizeof(*uh) || pskb_trim_rcsum(skb, ulen)) + goto short_packet; + uh = udp_hdr(skb); + } + + if (udp4_csum_init(skb, uh, proto)) + goto csum_error; + + sk = skb_steal_sock(skb, &refcounted); + if (sk) { + struct dst_entry *dst = skb_dst(skb); + int ret; + + if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst)) + udp_sk_rx_dst_set(sk, dst); + + ret = udp_unicast_rcv_skb(sk, skb, uh); + if (refcounted) + sock_put(sk); + return ret; + } + + if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST)) + return __udp4_lib_mcast_deliver(net, skb, uh, + saddr, daddr, udptable, proto); + + sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable); + if (sk) + return udp_unicast_rcv_skb(sk, skb, uh); + + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; + nf_reset_ct(skb); + + /* No socket. Drop packet silently, if checksum is wrong */ + if (udp_lib_checksum_complete(skb)) + goto csum_error; + + drop_reason = SKB_DROP_REASON_NO_SOCKET; + __UDP_INC_STATS(net, UDP_MIB_NOPORTS, proto == IPPROTO_UDPLITE); + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + + /* + * Hmm. We got an UDP packet to a port to which we + * don't wanna listen. Ignore it. + */ + kfree_skb_reason(skb, drop_reason); + return 0; + +short_packet: + drop_reason = SKB_DROP_REASON_PKT_TOO_SMALL; + net_dbg_ratelimited("UDP%s: short packet: From %pI4:%u %d/%d to %pI4:%u\n", + proto == IPPROTO_UDPLITE ? "Lite" : "", + &saddr, ntohs(uh->source), + ulen, skb->len, + &daddr, ntohs(uh->dest)); + goto drop; + +csum_error: + /* + * RFC1122: OK. Discards the bad packet silently (as far as + * the network is concerned, anyway) as per 4.1.3.4 (MUST). + */ + drop_reason = SKB_DROP_REASON_UDP_CSUM; + net_dbg_ratelimited("UDP%s: bad checksum. From %pI4:%u to %pI4:%u ulen %d\n", + proto == IPPROTO_UDPLITE ? "Lite" : "", + &saddr, ntohs(uh->source), &daddr, ntohs(uh->dest), + ulen); + __UDP_INC_STATS(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE); +drop: + __UDP_INC_STATS(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE); + kfree_skb_reason(skb, drop_reason); + return 0; +} + +/* We can only early demux multicast if there is a single matching socket. 
+ * If more than one socket found returns NULL + */ +static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif, int sdif) +{ + unsigned short hnum = ntohs(loc_port); + struct sock *sk, *result; + struct udp_hslot *hslot; + unsigned int slot; + + slot = udp_hashfn(net, hnum, udp_table.mask); + hslot = &udp_table.hash[slot]; + + /* Do not bother scanning a too big list */ + if (hslot->count > 10) + return NULL; + + result = NULL; + sk_for_each_rcu(sk, &hslot->head) { + if (__udp_is_mcast_sock(net, sk, loc_port, loc_addr, + rmt_port, rmt_addr, dif, sdif, hnum)) { + if (result) + return NULL; + result = sk; + } + } + + return result; +} + +/* For unicast we should only early demux connected sockets or we can + * break forwarding setups. The chains here can be long so only check + * if the first socket is an exact match and if not move on. + */ +static struct sock *__udp4_lib_demux_lookup(struct net *net, + __be16 loc_port, __be32 loc_addr, + __be16 rmt_port, __be32 rmt_addr, + int dif, int sdif) +{ + INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr); + unsigned short hnum = ntohs(loc_port); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + __portpair ports; + struct sock *sk; + + hash2 = ipv4_portaddr_hash(net, loc_addr, hnum); + slot2 = hash2 & udp_table.mask; + hslot2 = &udp_table.hash2[slot2]; + ports = INET_COMBINED_PORTS(rmt_port, hnum); + + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { + if (inet_match(net, sk, acookie, ports, dif, sdif)) + return sk; + /* Only check first socket in chain */ + break; + } + return NULL; +} + +int udp_v4_early_demux(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct in_device *in_dev = NULL; + const struct iphdr *iph; + const struct udphdr *uh; + struct sock *sk = NULL; + struct dst_entry *dst; + int dif = skb->dev->ifindex; + int sdif = inet_sdif(skb); + int ours; + + /* validate the packet */ + if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr))) + return 0; + + iph = ip_hdr(skb); + uh = udp_hdr(skb); + + if (skb->pkt_type == PACKET_MULTICAST) { + in_dev = __in_dev_get_rcu(skb->dev); + + if (!in_dev) + return 0; + + ours = ip_check_mc_rcu(in_dev, iph->daddr, iph->saddr, + iph->protocol); + if (!ours) + return 0; + + sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr, + uh->source, iph->saddr, + dif, sdif); + } else if (skb->pkt_type == PACKET_HOST) { + sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr, + uh->source, iph->saddr, dif, sdif); + } + + if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt)) + return 0; + + skb->sk = sk; + skb->destructor = sock_efree; + dst = rcu_dereference(sk->sk_rx_dst); + + if (dst) + dst = dst_check(dst, 0); + if (dst) { + u32 itag = 0; + + /* set noref for now. 
+ * any place which wants to hold dst has to call + * dst_hold_safe() + */ + skb_dst_set_noref(skb, dst); + + /* for unconnected multicast sockets we need to validate + * the source on each packet + */ + if (!inet_sk(sk)->inet_daddr && in_dev) + return ip_mc_validate_source(skb, iph->daddr, + iph->saddr, + iph->tos & IPTOS_RT_MASK, + skb->dev, in_dev, &itag); + } + return 0; +} + +int udp_rcv(struct sk_buff *skb) +{ + return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP); +} + +void udp_destroy_sock(struct sock *sk) +{ + struct udp_sock *up = udp_sk(sk); + bool slow = lock_sock_fast(sk); + + /* protects from races with udp_abort() */ + sock_set_flag(sk, SOCK_DEAD); + udp_flush_pending_frames(sk); + unlock_sock_fast(sk, slow); + if (static_branch_unlikely(&udp_encap_needed_key)) { + if (up->encap_type) { + void (*encap_destroy)(struct sock *sk); + encap_destroy = READ_ONCE(up->encap_destroy); + if (encap_destroy) + encap_destroy(sk); + } + if (udp_test_bit(ENCAP_ENABLED, sk)) + static_branch_dec(&udp_encap_needed_key); + } +} + +/* + * Socket option code for UDP + */ +int udp_lib_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen, + int (*push_pending_frames)(struct sock *)) +{ + struct udp_sock *up = udp_sk(sk); + int val, valbool; + int err = 0; + int is_udplite = IS_UDPLITE(sk); + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + valbool = val ? 1 : 0; + + switch (optname) { + case UDP_CORK: + if (val != 0) { + udp_set_bit(CORK, sk); + } else { + udp_clear_bit(CORK, sk); + lock_sock(sk); + push_pending_frames(sk); + release_sock(sk); + } + break; + + case UDP_ENCAP: + switch (val) { + case 0: +#ifdef CONFIG_XFRM + case UDP_ENCAP_ESPINUDP: + case UDP_ENCAP_ESPINUDP_NON_IKE: +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + WRITE_ONCE(up->encap_rcv, + ipv6_stub->xfrm6_udp_encap_rcv); + else +#endif + WRITE_ONCE(up->encap_rcv, + xfrm4_udp_encap_rcv); +#endif + fallthrough; + case UDP_ENCAP_L2TPINUDP: + WRITE_ONCE(up->encap_type, val); + udp_tunnel_encap_enable(sk); + break; + default: + err = -ENOPROTOOPT; + break; + } + break; + + case UDP_NO_CHECK6_TX: + udp_set_no_check6_tx(sk, valbool); + break; + + case UDP_NO_CHECK6_RX: + udp_set_no_check6_rx(sk, valbool); + break; + + case UDP_SEGMENT: + if (val < 0 || val > USHRT_MAX) + return -EINVAL; + WRITE_ONCE(up->gso_size, val); + break; + + case UDP_GRO: + + /* when enabling GRO, accept the related GSO packet type */ + if (valbool) + udp_tunnel_encap_enable(sk); + udp_assign_bit(GRO_ENABLED, sk, valbool); + udp_assign_bit(ACCEPT_L4, sk, valbool); + break; + + /* + * UDP-Lite's partial checksum coverage (RFC 3828). + */ + /* The sender sets actual checksum coverage length via this option. + * The case coverage > packet length is handled by send module. */ + case UDPLITE_SEND_CSCOV: + if (!is_udplite) /* Disable the option on UDP sockets */ + return -ENOPROTOOPT; + if (val != 0 && val < 8) /* Illegal coverage: use default (8) */ + val = 8; + else if (val > USHRT_MAX) + val = USHRT_MAX; + up->pcslen = val; + up->pcflag |= UDPLITE_SEND_CC; + break; + + /* The receiver specifies a minimum checksum coverage value. To make + * sense, this should be set to at least 8 (as done below). If zero is + * used, this again means full checksum coverage. */ + case UDPLITE_RECV_CSCOV: + if (!is_udplite) /* Disable the option on UDP sockets */ + return -ENOPROTOOPT; + if (val != 0 && val < 8) /* Avoid silly minimal values. 
*/ + val = 8; + else if (val > USHRT_MAX) + val = USHRT_MAX; + up->pcrlen = val; + up->pcflag |= UDPLITE_RECV_CC; + break; + + default: + err = -ENOPROTOOPT; + break; + } + + return err; +} +EXPORT_SYMBOL(udp_lib_setsockopt); + +int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + if (level == SOL_UDP || level == SOL_UDPLITE) + return udp_lib_setsockopt(sk, level, optname, + optval, optlen, + udp_push_pending_frames); + return ip_setsockopt(sk, level, optname, optval, optlen); +} + +int udp_lib_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct udp_sock *up = udp_sk(sk); + int val, len; + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(int)); + + if (len < 0) + return -EINVAL; + + switch (optname) { + case UDP_CORK: + val = udp_test_bit(CORK, sk); + break; + + case UDP_ENCAP: + val = READ_ONCE(up->encap_type); + break; + + case UDP_NO_CHECK6_TX: + val = udp_get_no_check6_tx(sk); + break; + + case UDP_NO_CHECK6_RX: + val = udp_get_no_check6_rx(sk); + break; + + case UDP_SEGMENT: + val = READ_ONCE(up->gso_size); + break; + + case UDP_GRO: + val = udp_test_bit(GRO_ENABLED, sk); + break; + + /* The following two cannot be changed on UDP sockets, the return is + * always 0 (which corresponds to the full checksum coverage of UDP). */ + case UDPLITE_SEND_CSCOV: + val = up->pcslen; + break; + + case UDPLITE_RECV_CSCOV: + val = up->pcrlen; + break; + + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} +EXPORT_SYMBOL(udp_lib_getsockopt); + +int udp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + if (level == SOL_UDP || level == SOL_UDPLITE) + return udp_lib_getsockopt(sk, level, optname, optval, optlen); + return ip_getsockopt(sk, level, optname, optval, optlen); +} + +/** + * udp_poll - wait for a UDP event. + * @file: - file struct + * @sock: - socket + * @wait: - poll table + * + * This is same as datagram poll, except for the special case of + * blocking sockets. If application is using a blocking fd + * and a packet with checksum error is in the queue; + * then it could get return from select indicating data available + * but then block when reading it. Add special case code + * to work around these arguably broken applications. 
+ */ +__poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + __poll_t mask = datagram_poll(file, sock, wait); + struct sock *sk = sock->sk; + + if (!skb_queue_empty_lockless(&udp_sk(sk)->reader_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* Check for false positives due to checksum errors */ + if ((mask & EPOLLRDNORM) && !(file->f_flags & O_NONBLOCK) && + !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1) + mask &= ~(EPOLLIN | EPOLLRDNORM); + + /* psock ingress_msg queue should not contain any bad checksum frames */ + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + return mask; + +} +EXPORT_SYMBOL(udp_poll); + +int udp_abort(struct sock *sk, int err) +{ + lock_sock(sk); + + /* udp{v6}_destroy_sock() sets it under the sk lock, avoid racing + * with close() + */ + if (sock_flag(sk, SOCK_DEAD)) + goto out; + + sk->sk_err = err; + sk_error_report(sk); + __udp_disconnect(sk, 0); + +out: + release_sock(sk); + + return 0; +} +EXPORT_SYMBOL_GPL(udp_abort); + +struct proto udp_prot = { + .name = "UDP", + .owner = THIS_MODULE, + .close = udp_lib_close, + .pre_connect = udp_pre_connect, + .connect = ip4_datagram_connect, + .disconnect = udp_disconnect, + .ioctl = udp_ioctl, + .init = udp_init_sock, + .destroy = udp_destroy_sock, + .setsockopt = udp_setsockopt, + .getsockopt = udp_getsockopt, + .sendmsg = udp_sendmsg, + .recvmsg = udp_recvmsg, + .splice_eof = udp_splice_eof, + .sendpage = udp_sendpage, + .release_cb = ip4_datagram_release_cb, + .hash = udp_lib_hash, + .unhash = udp_lib_unhash, + .rehash = udp_v4_rehash, + .get_port = udp_v4_get_port, + .put_port = udp_lib_unhash, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = udp_bpf_update_proto, +#endif + .memory_allocated = &udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + + .sysctl_mem = sysctl_udp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), + .obj_size = sizeof(struct udp_sock), + .h.udp_table = &udp_table, + .diag_destroy = udp_abort, +}; +EXPORT_SYMBOL(udp_prot); + +/* ------------------------------------------------------------------------ */ +#ifdef CONFIG_PROC_FS + +static struct sock *udp_get_first(struct seq_file *seq, int start) +{ + struct udp_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct udp_seq_afinfo *afinfo; + struct sock *sk; + + if (state->bpf_seq_afinfo) + afinfo = state->bpf_seq_afinfo; + else + afinfo = pde_data(file_inode(seq->file)); + + for (state->bucket = start; state->bucket <= afinfo->udp_table->mask; + ++state->bucket) { + struct udp_hslot *hslot = &afinfo->udp_table->hash[state->bucket]; + + if (hlist_empty(&hslot->head)) + continue; + + spin_lock_bh(&hslot->lock); + sk_for_each(sk, &hslot->head) { + if (!net_eq(sock_net(sk), net)) + continue; + if (afinfo->family == AF_UNSPEC || + sk->sk_family == afinfo->family) + goto found; + } + spin_unlock_bh(&hslot->lock); + } + sk = NULL; +found: + return sk; +} + +static struct sock *udp_get_next(struct seq_file *seq, struct sock *sk) +{ + struct udp_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct udp_seq_afinfo *afinfo; + + if (state->bpf_seq_afinfo) + afinfo = state->bpf_seq_afinfo; + else + afinfo = pde_data(file_inode(seq->file)); + + do { + sk = sk_next(sk); + } while (sk && (!net_eq(sock_net(sk), net) || + (afinfo->family != AF_UNSPEC && + sk->sk_family != afinfo->family))); + + if (!sk) { + if (state->bucket <= 
afinfo->udp_table->mask) + spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); + return udp_get_first(seq, state->bucket + 1); + } + return sk; +} + +static struct sock *udp_get_idx(struct seq_file *seq, loff_t pos) +{ + struct sock *sk = udp_get_first(seq, 0); + + if (sk) + while (pos && (sk = udp_get_next(seq, sk)) != NULL) + --pos; + return pos ? NULL : sk; +} + +void *udp_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct udp_iter_state *state = seq->private; + state->bucket = MAX_UDP_PORTS; + + return *pos ? udp_get_idx(seq, *pos-1) : SEQ_START_TOKEN; +} +EXPORT_SYMBOL(udp_seq_start); + +void *udp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock *sk; + + if (v == SEQ_START_TOKEN) + sk = udp_get_idx(seq, 0); + else + sk = udp_get_next(seq, v); + + ++*pos; + return sk; +} +EXPORT_SYMBOL(udp_seq_next); + +void udp_seq_stop(struct seq_file *seq, void *v) +{ + struct udp_iter_state *state = seq->private; + struct udp_seq_afinfo *afinfo; + + if (state->bpf_seq_afinfo) + afinfo = state->bpf_seq_afinfo; + else + afinfo = pde_data(file_inode(seq->file)); + + if (state->bucket <= afinfo->udp_table->mask) + spin_unlock_bh(&afinfo->udp_table->hash[state->bucket].lock); +} +EXPORT_SYMBOL(udp_seq_stop); + +/* ------------------------------------------------------------------------ */ +static void udp4_format_sock(struct sock *sp, struct seq_file *f, + int bucket) +{ + struct inet_sock *inet = inet_sk(sp); + __be32 dest = inet->inet_daddr; + __be32 src = inet->inet_rcv_saddr; + __u16 destp = ntohs(inet->inet_dport); + __u16 srcp = ntohs(inet->inet_sport); + + seq_printf(f, "%5d: %08X:%04X %08X:%04X" + " %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u", + bucket, src, srcp, dest, destp, sp->sk_state, + sk_wmem_alloc_get(sp), + udp_rqueue_get(sp), + 0, 0L, 0, + from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)), + 0, sock_i_ino(sp), + refcount_read(&sp->sk_refcnt), sp, + atomic_read(&sp->sk_drops)); +} + +int udp4_seq_show(struct seq_file *seq, void *v) +{ + seq_setwidth(seq, 127); + if (v == SEQ_START_TOKEN) + seq_puts(seq, " sl local_address rem_address st tx_queue " + "rx_queue tr tm->when retrnsmt uid timeout " + "inode ref pointer drops"); + else { + struct udp_iter_state *state = seq->private; + + udp4_format_sock(v, seq, state->bucket); + } + seq_pad(seq, '\n'); + return 0; +} + +#ifdef CONFIG_BPF_SYSCALL +struct bpf_iter__udp { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct udp_sock *, udp_sk); + uid_t uid __aligned(8); + int bucket __aligned(8); +}; + +static int udp_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, + struct udp_sock *udp_sk, uid_t uid, int bucket) +{ + struct bpf_iter__udp ctx; + + meta->seq_num--; /* skip SEQ_START_TOKEN */ + ctx.meta = meta; + ctx.udp_sk = udp_sk; + ctx.uid = uid; + ctx.bucket = bucket; + return bpf_iter_run_prog(prog, &ctx); +} + +static int bpf_iter_udp_seq_show(struct seq_file *seq, void *v) +{ + struct udp_iter_state *state = seq->private; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + struct sock *sk = v; + uid_t uid; + + if (v == SEQ_START_TOKEN) + return 0; + + uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)); + meta.seq = seq; + prog = bpf_iter_get_info(&meta, false); + return udp_prog_seq_show(prog, &meta, v, uid, state->bucket); +} + +static void bpf_iter_udp_seq_stop(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + if (!v) { + meta.seq = seq; + prog = bpf_iter_get_info(&meta, true); + if (prog) + 
(void)udp_prog_seq_show(prog, &meta, v, 0, 0); + } + + udp_seq_stop(seq, v); +} + +static const struct seq_operations bpf_iter_udp_seq_ops = { + .start = udp_seq_start, + .next = udp_seq_next, + .stop = bpf_iter_udp_seq_stop, + .show = bpf_iter_udp_seq_show, +}; +#endif + +const struct seq_operations udp_seq_ops = { + .start = udp_seq_start, + .next = udp_seq_next, + .stop = udp_seq_stop, + .show = udp4_seq_show, +}; +EXPORT_SYMBOL(udp_seq_ops); + +static struct udp_seq_afinfo udp4_seq_afinfo = { + .family = AF_INET, + .udp_table = &udp_table, +}; + +static int __net_init udp4_proc_init_net(struct net *net) +{ + if (!proc_create_net_data("udp", 0444, net->proc_net, &udp_seq_ops, + sizeof(struct udp_iter_state), &udp4_seq_afinfo)) + return -ENOMEM; + return 0; +} + +static void __net_exit udp4_proc_exit_net(struct net *net) +{ + remove_proc_entry("udp", net->proc_net); +} + +static struct pernet_operations udp4_net_ops = { + .init = udp4_proc_init_net, + .exit = udp4_proc_exit_net, +}; + +int __init udp4_proc_init(void) +{ + return register_pernet_subsys(&udp4_net_ops); +} + +void udp4_proc_exit(void) +{ + unregister_pernet_subsys(&udp4_net_ops); +} +#endif /* CONFIG_PROC_FS */ + +static __initdata unsigned long uhash_entries; +static int __init set_uhash_entries(char *str) +{ + ssize_t ret; + + if (!str) + return 0; + + ret = kstrtoul(str, 0, &uhash_entries); + if (ret) + return 0; + + if (uhash_entries && uhash_entries < UDP_HTABLE_SIZE_MIN) + uhash_entries = UDP_HTABLE_SIZE_MIN; + return 1; +} +__setup("uhash_entries=", set_uhash_entries); + +void __init udp_table_init(struct udp_table *table, const char *name) +{ + unsigned int i; + + table->hash = alloc_large_system_hash(name, + 2 * sizeof(struct udp_hslot), + uhash_entries, + 21, /* one slot per 2 MB */ + 0, + &table->log, + &table->mask, + UDP_HTABLE_SIZE_MIN, + 64 * 1024); + + table->hash2 = table->hash + (table->mask + 1); + for (i = 0; i <= table->mask; i++) { + INIT_HLIST_HEAD(&table->hash[i].head); + table->hash[i].count = 0; + spin_lock_init(&table->hash[i].lock); + } + for (i = 0; i <= table->mask; i++) { + INIT_HLIST_HEAD(&table->hash2[i].head); + table->hash2[i].count = 0; + spin_lock_init(&table->hash2[i].lock); + } +} + +u32 udp_flow_hashrnd(void) +{ + static u32 hashrnd __read_mostly; + + net_get_random_once(&hashrnd, sizeof(hashrnd)); + + return hashrnd; +} +EXPORT_SYMBOL(udp_flow_hashrnd); + +static int __net_init udp_sysctl_init(struct net *net) +{ + net->ipv4.sysctl_udp_rmem_min = PAGE_SIZE; + net->ipv4.sysctl_udp_wmem_min = PAGE_SIZE; + +#ifdef CONFIG_NET_L3_MASTER_DEV + net->ipv4.sysctl_udp_l3mdev_accept = 0; +#endif + + return 0; +} + +static struct pernet_operations __net_initdata udp_sysctl_ops = { + .init = udp_sysctl_init, +}; + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) +DEFINE_BPF_ITER_FUNC(udp, struct bpf_iter_meta *meta, + struct udp_sock *udp_sk, uid_t uid, int bucket) + +static int bpf_iter_init_udp(void *priv_data, struct bpf_iter_aux_info *aux) +{ + struct udp_iter_state *st = priv_data; + struct udp_seq_afinfo *afinfo; + int ret; + + afinfo = kmalloc(sizeof(*afinfo), GFP_USER | __GFP_NOWARN); + if (!afinfo) + return -ENOMEM; + + afinfo->family = AF_UNSPEC; + afinfo->udp_table = &udp_table; + st->bpf_seq_afinfo = afinfo; + ret = bpf_iter_init_seq_net(priv_data, aux); + if (ret) + kfree(afinfo); + return ret; +} + +static void bpf_iter_fini_udp(void *priv_data) +{ + struct udp_iter_state *st = priv_data; + + kfree(st->bpf_seq_afinfo); + bpf_iter_fini_seq_net(priv_data); +} + +static 
const struct bpf_iter_seq_info udp_seq_info = { + .seq_ops = &bpf_iter_udp_seq_ops, + .init_seq_private = bpf_iter_init_udp, + .fini_seq_private = bpf_iter_fini_udp, + .seq_priv_size = sizeof(struct udp_iter_state), +}; + +static struct bpf_iter_reg udp_reg_info = { + .target = "udp", + .ctx_arg_info_size = 1, + .ctx_arg_info = { + { offsetof(struct bpf_iter__udp, udp_sk), + PTR_TO_BTF_ID_OR_NULL }, + }, + .seq_info = &udp_seq_info, +}; + +static void __init bpf_iter_register(void) +{ + udp_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_UDP]; + if (bpf_iter_reg_target(&udp_reg_info)) + pr_warn("Warning: could not register bpf iterator udp\n"); +} +#endif + +void __init udp_init(void) +{ + unsigned long limit; + unsigned int i; + + udp_table_init(&udp_table, "UDP"); + limit = nr_free_buffer_pages() / 8; + limit = max(limit, 128UL); + sysctl_udp_mem[0] = limit / 4 * 3; + sysctl_udp_mem[1] = limit; + sysctl_udp_mem[2] = sysctl_udp_mem[0] * 2; + + /* 16 spinlocks per cpu */ + udp_busylocks_log = ilog2(nr_cpu_ids) + 4; + udp_busylocks = kmalloc(sizeof(spinlock_t) << udp_busylocks_log, + GFP_KERNEL); + if (!udp_busylocks) + panic("UDP: failed to alloc udp_busylocks\n"); + for (i = 0; i < (1U << udp_busylocks_log); i++) + spin_lock_init(udp_busylocks + i); + + if (register_pernet_subsys(&udp_sysctl_ops)) + panic("UDP: failed to init sysctl parameters.\n"); + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + bpf_iter_register(); +#endif +} diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c new file mode 100644 index 000000000..0735d820e --- /dev/null +++ b/net/ipv4/udp_bpf.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */ + +#include +#include +#include +#include + +#include "udp_impl.h" + +static struct proto *udpv6_prot_saved __read_mostly; + +static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return udpv6_prot_saved->recvmsg(sk, msg, len, flags, addr_len); +#endif + return udp_prot.recvmsg(sk, msg, len, flags, addr_len); +} + +static bool udp_sk_has_data(struct sock *sk) +{ + return !skb_queue_empty(&udp_sk(sk)->reader_queue) || + !skb_queue_empty(&sk->sk_receive_queue); +} + +static bool psock_has_data(struct sk_psock *psock) +{ + return !skb_queue_empty(&psock->ingress_skb) || + !sk_psock_queue_empty(psock); +} + +#define udp_msg_has_data(__sk, __psock) \ + ({ udp_sk_has_data(__sk) || psock_has_data(__psock); }) + +static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock, + long timeo) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret = 0; + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 1; + + if (!timeo) + return ret; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + ret = udp_msg_has_data(sk, psock); + if (!ret) { + wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + ret = udp_msg_has_data(sk, psock); + } + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return ret; +} + +static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct sk_psock *psock; + int copied, ret; + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + if (!len) + return 0; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return sk_udp_recvmsg(sk, msg, len, flags, addr_len); + + if (!psock_has_data(psock)) { + ret = 
sk_udp_recvmsg(sk, msg, len, flags, addr_len); + goto out; + } + +msg_bytes_ready: + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + if (!copied) { + long timeo; + int data; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + data = udp_msg_wait_data(sk, psock, timeo); + if (data) { + if (psock_has_data(psock)) + goto msg_bytes_ready; + ret = sk_udp_recvmsg(sk, msg, len, flags, addr_len); + goto out; + } + copied = -EAGAIN; + } + ret = copied; +out: + sk_psock_put(sk, psock); + return ret; +} + +enum { + UDP_BPF_IPV4, + UDP_BPF_IPV6, + UDP_BPF_NUM_PROTS, +}; + +static DEFINE_SPINLOCK(udpv6_prot_lock); +static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS]; + +static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) +{ + *prot = *base; + prot->close = sock_map_close; + prot->recvmsg = udp_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; +} + +static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) +{ + if (unlikely(ops != smp_load_acquire(&udpv6_prot_saved))) { + spin_lock_bh(&udpv6_prot_lock); + if (likely(ops != udpv6_prot_saved)) { + udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops); + smp_store_release(&udpv6_prot_saved, ops); + } + spin_unlock_bh(&udpv6_prot_lock); + } +} + +static int __init udp_bpf_v4_build_proto(void) +{ + udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV4], &udp_prot); + return 0; +} +late_initcall(udp_bpf_v4_build_proto); + +int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) +{ + int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6; + + if (restore) { + sk->sk_write_space = psock->saved_write_space; + sock_replace_proto(sk, psock->sk_proto); + return 0; + } + + if (sk->sk_family == AF_INET6) + udp_bpf_check_v6_needs_rebuild(psock->sk_proto); + + sock_replace_proto(sk, &udp_bpf_prots[family]); + return 0; +} +EXPORT_SYMBOL_GPL(udp_bpf_update_proto); diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c new file mode 100644 index 000000000..1ed8c4d78 --- /dev/null +++ b/net/ipv4/udp_diag.c @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * udp_diag.c Module for monitoring UDP transport protocols sockets. 
+ * + * Authors: Pavel Emelyanov, + */ + + +#include +#include +#include +#include +#include +#include + +static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *req, + struct nlattr *bc, bool net_admin) +{ + if (!inet_diag_bc_sk(bc, sk)) + return 0; + + return inet_sk_diag_fill(sk, NULL, skb, cb, req, NLM_F_MULTI, + net_admin); +} + +static int udp_dump_one(struct udp_table *tbl, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + struct sk_buff *in_skb = cb->skb; + int err; + struct sock *sk = NULL; + struct sk_buff *rep; + struct net *net = sock_net(in_skb->sk); + + rcu_read_lock(); + if (req->sdiag_family == AF_INET) + /* src and dst are swapped for historical reasons */ + sk = __udp4_lib_lookup(net, + req->id.idiag_src[0], req->id.idiag_sport, + req->id.idiag_dst[0], req->id.idiag_dport, + req->id.idiag_if, 0, tbl, NULL); +#if IS_ENABLED(CONFIG_IPV6) + else if (req->sdiag_family == AF_INET6) + sk = __udp6_lib_lookup(net, + (struct in6_addr *)req->id.idiag_src, + req->id.idiag_sport, + (struct in6_addr *)req->id.idiag_dst, + req->id.idiag_dport, + req->id.idiag_if, 0, tbl, NULL); +#endif + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + rcu_read_unlock(); + err = -ENOENT; + if (!sk) + goto out_nosk; + + err = sock_diag_check_cookie(sk, req->id.idiag_cookie); + if (err) + goto out; + + err = -ENOMEM; + rep = nlmsg_new(nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + 64, + GFP_KERNEL); + if (!rep) + goto out; + + err = inet_sk_diag_fill(sk, NULL, rep, cb, req, 0, + netlink_net_capable(in_skb, CAP_NET_ADMIN)); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + kfree_skb(rep); + goto out; + } + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + +out: + if (sk) + sock_put(sk); +out_nosk: + return err; +} + +static void udp_dump(struct udp_table *table, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + struct net *net = sock_net(skb->sk); + struct inet_diag_dump_data *cb_data; + int num, s_num, slot, s_slot; + struct nlattr *bc; + + cb_data = cb->data; + bc = cb_data->inet_diag_nla_bc; + s_slot = cb->args[0]; + num = s_num = cb->args[1]; + + for (slot = s_slot; slot <= table->mask; s_num = 0, slot++) { + struct udp_hslot *hslot = &table->hash[slot]; + struct sock *sk; + + num = 0; + + if (hlist_empty(&hslot->head)) + continue; + + spin_lock_bh(&hslot->lock); + sk_for_each(sk, &hslot->head) { + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) + goto next; + if (!(r->idiag_states & (1 << sk->sk_state))) + goto next; + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next; + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next; + if (r->id.idiag_dport != inet->inet_dport && + r->id.idiag_dport) + goto next; + + if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) { + spin_unlock_bh(&hslot->lock); + goto done; + } +next: + num++; + } + spin_unlock_bh(&hslot->lock); + } +done: + cb->args[0] = slot; + cb->args[1] = num; +} + +static void udp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + udp_dump(&udp_table, skb, cb, r); +} + +static int udp_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + 
return udp_dump_one(&udp_table, cb, req); +} + +static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *info) +{ + r->idiag_rqueue = udp_rqueue_get(sk); + r->idiag_wqueue = sk_wmem_alloc_get(sk); +} + +#ifdef CONFIG_INET_DIAG_DESTROY +static int __udp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req, + struct udp_table *tbl) +{ + struct net *net = sock_net(in_skb->sk); + struct sock *sk; + int err; + + rcu_read_lock(); + + if (req->sdiag_family == AF_INET) + sk = __udp4_lib_lookup(net, + req->id.idiag_dst[0], req->id.idiag_dport, + req->id.idiag_src[0], req->id.idiag_sport, + req->id.idiag_if, 0, tbl, NULL); +#if IS_ENABLED(CONFIG_IPV6) + else if (req->sdiag_family == AF_INET6) { + if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && + ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src)) + sk = __udp4_lib_lookup(net, + req->id.idiag_dst[3], req->id.idiag_dport, + req->id.idiag_src[3], req->id.idiag_sport, + req->id.idiag_if, 0, tbl, NULL); + + else + sk = __udp6_lib_lookup(net, + (struct in6_addr *)req->id.idiag_dst, + req->id.idiag_dport, + (struct in6_addr *)req->id.idiag_src, + req->id.idiag_sport, + req->id.idiag_if, 0, tbl, NULL); + } +#endif + else { + rcu_read_unlock(); + return -EINVAL; + } + + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + + rcu_read_unlock(); + + if (!sk) + return -ENOENT; + + if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { + sock_put(sk); + return -ENOENT; + } + + err = sock_diag_destroy(sk, ECONNABORTED); + + sock_put(sk); + + return err; +} + +static int udp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + return __udp_diag_destroy(in_skb, req, &udp_table); +} + +static int udplite_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + return __udp_diag_destroy(in_skb, req, &udplite_table); +} + +#endif + +static const struct inet_diag_handler udp_diag_handler = { + .dump = udp_diag_dump, + .dump_one = udp_diag_dump_one, + .idiag_get_info = udp_diag_get_info, + .idiag_type = IPPROTO_UDP, + .idiag_info_size = 0, +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = udp_diag_destroy, +#endif +}; + +static void udplite_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + udp_dump(&udplite_table, skb, cb, r); +} + +static int udplite_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + return udp_dump_one(&udplite_table, cb, req); +} + +static const struct inet_diag_handler udplite_diag_handler = { + .dump = udplite_diag_dump, + .dump_one = udplite_diag_dump_one, + .idiag_get_info = udp_diag_get_info, + .idiag_type = IPPROTO_UDPLITE, + .idiag_info_size = 0, +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = udplite_diag_destroy, +#endif +}; + +static int __init udp_diag_init(void) +{ + int err; + + err = inet_diag_register(&udp_diag_handler); + if (err) + goto out; + err = inet_diag_register(&udplite_diag_handler); + if (err) + goto out_lite; +out: + return err; +out_lite: + inet_diag_unregister(&udp_diag_handler); + goto out; +} + +static void __exit udp_diag_exit(void) +{ + inet_diag_unregister(&udplite_diag_handler); + inet_diag_unregister(&udp_diag_handler); +} + +module_init(udp_diag_init); +module_exit(udp_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-17 /* AF_INET - IPPROTO_UDP */); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-136 /* AF_INET - IPPROTO_UDPLITE */); diff 
--git a/net/ipv4/udp_impl.h b/net/ipv4/udp_impl.h new file mode 100644 index 000000000..4ba7a88a1 --- /dev/null +++ b/net/ipv4/udp_impl.h @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _UDP4_IMPL_H +#define _UDP4_IMPL_H +#include +#include +#include +#include + +int __udp4_lib_rcv(struct sk_buff *, struct udp_table *, int); +int __udp4_lib_err(struct sk_buff *, u32, struct udp_table *); + +int udp_v4_get_port(struct sock *sk, unsigned short snum); +void udp_v4_rehash(struct sock *sk); + +int udp_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen); +int udp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen); + +int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len); +int udp_sendpage(struct sock *sk, struct page *page, int offset, size_t size, + int flags); +void udp_destroy_sock(struct sock *sk); + +#ifdef CONFIG_PROC_FS +int udp4_seq_show(struct seq_file *seq, void *v); +#endif +#endif /* _UDP4_IMPL_H */ diff --git a/net/ipv4/udp_offload.c b/net/ipv4/udp_offload.c new file mode 100644 index 000000000..8096576fd --- /dev/null +++ b/net/ipv4/udp_offload.c @@ -0,0 +1,739 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV4 GSO/GRO offload support + * Linux INET implementation + * + * UDPv4 GSO support + */ + +#include +#include +#include +#include +#include + +static struct sk_buff *__skb_udp_tunnel_segment(struct sk_buff *skb, + netdev_features_t features, + struct sk_buff *(*gso_inner_segment)(struct sk_buff *skb, + netdev_features_t features), + __be16 new_protocol, bool is_ipv6) +{ + int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb); + bool remcsum, need_csum, offload_csum, gso_partial; + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct udphdr *uh = udp_hdr(skb); + u16 mac_offset = skb->mac_header; + __be16 protocol = skb->protocol; + u16 mac_len = skb->mac_len; + int udp_offset, outer_hlen; + __wsum partial; + bool need_ipsec; + + if (unlikely(!pskb_may_pull(skb, tnl_hlen))) + goto out; + + /* Adjust partial header checksum to negate old length. + * We cannot rely on the value contained in uh->len as it is + * possible that the actual value exceeds the boundaries of the + * 16 bit length field due to the header being added outside of an + * IP or IPv6 frame that was already limited to 64K - 1. + */ + if (skb_shinfo(skb)->gso_type & SKB_GSO_PARTIAL) + partial = (__force __wsum)uh->len; + else + partial = (__force __wsum)htonl(skb->len); + partial = csum_sub(csum_unfold(uh->check), partial); + + /* setup inner skb. */ + skb->encapsulation = 0; + SKB_GSO_CB(skb)->encap_level = 0; + __skb_pull(skb, tnl_hlen); + skb_reset_mac_header(skb); + skb_set_network_header(skb, skb_inner_network_offset(skb)); + skb_set_transport_header(skb, skb_inner_transport_offset(skb)); + skb->mac_len = skb_inner_network_offset(skb); + skb->protocol = new_protocol; + + need_csum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_UDP_TUNNEL_CSUM); + skb->encap_hdr_csum = need_csum; + + remcsum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_TUNNEL_REMCSUM); + skb->remcsum_offload = remcsum; + + need_ipsec = skb_dst(skb) && dst_xfrm(skb_dst(skb)); + /* Try to offload checksum if possible */ + offload_csum = !!(need_csum && + !need_ipsec && + (skb->dev->features & + (is_ipv6 ? 
(NETIF_F_HW_CSUM | NETIF_F_IPV6_CSUM) : + (NETIF_F_HW_CSUM | NETIF_F_IP_CSUM)))); + + features &= skb->dev->hw_enc_features; + if (need_csum) + features &= ~NETIF_F_SCTP_CRC; + + /* The only checksum offload we care about from here on out is the + * outer one so strip the existing checksum feature flags and + * instead set the flag based on our outer checksum offload value. + */ + if (remcsum) { + features &= ~NETIF_F_CSUM_MASK; + if (!need_csum || offload_csum) + features |= NETIF_F_HW_CSUM; + } + + /* segment inner packet. */ + segs = gso_inner_segment(skb, features); + if (IS_ERR_OR_NULL(segs)) { + skb_gso_error_unwind(skb, protocol, tnl_hlen, mac_offset, + mac_len); + goto out; + } + + gso_partial = !!(skb_shinfo(segs)->gso_type & SKB_GSO_PARTIAL); + + outer_hlen = skb_tnl_header_len(skb); + udp_offset = outer_hlen - tnl_hlen; + skb = segs; + do { + unsigned int len; + + if (remcsum) + skb->ip_summed = CHECKSUM_NONE; + + /* Set up inner headers if we are offloading inner checksum */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + skb_reset_inner_headers(skb); + skb->encapsulation = 1; + } + + skb->mac_len = mac_len; + skb->protocol = protocol; + + __skb_push(skb, outer_hlen); + skb_reset_mac_header(skb); + skb_set_network_header(skb, mac_len); + skb_set_transport_header(skb, udp_offset); + len = skb->len - udp_offset; + uh = udp_hdr(skb); + + /* If we are only performing partial GSO the inner header + * will be using a length value equal to only one MSS sized + * segment instead of the entire frame. + */ + if (gso_partial && skb_is_gso(skb)) { + uh->len = htons(skb_shinfo(skb)->gso_size + + SKB_GSO_CB(skb)->data_offset + + skb->head - (unsigned char *)uh); + } else { + uh->len = htons(len); + } + + if (!need_csum) + continue; + + uh->check = ~csum_fold(csum_add(partial, + (__force __wsum)htonl(len))); + + if (skb->encapsulation || !offload_csum) { + uh->check = gso_make_checksum(skb, ~uh->check); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } else { + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + } + } while ((skb = skb->next)); +out: + return segs; +} + +struct sk_buff *skb_udp_tunnel_segment(struct sk_buff *skb, + netdev_features_t features, + bool is_ipv6) +{ + const struct net_offload __rcu **offloads; + __be16 protocol = skb->protocol; + const struct net_offload *ops; + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct sk_buff *(*gso_inner_segment)(struct sk_buff *skb, + netdev_features_t features); + + rcu_read_lock(); + + switch (skb->inner_protocol_type) { + case ENCAP_TYPE_ETHER: + protocol = skb->inner_protocol; + gso_inner_segment = skb_mac_gso_segment; + break; + case ENCAP_TYPE_IPPROTO: + offloads = is_ipv6 ? 
inet6_offloads : inet_offloads; + ops = rcu_dereference(offloads[skb->inner_ipproto]); + if (!ops || !ops->callbacks.gso_segment) + goto out_unlock; + gso_inner_segment = ops->callbacks.gso_segment; + break; + default: + goto out_unlock; + } + + segs = __skb_udp_tunnel_segment(skb, features, gso_inner_segment, + protocol, is_ipv6); + +out_unlock: + rcu_read_unlock(); + + return segs; +} +EXPORT_SYMBOL(skb_udp_tunnel_segment); + +static void __udpv4_gso_segment_csum(struct sk_buff *seg, + __be32 *oldip, __be32 *newip, + __be16 *oldport, __be16 *newport) +{ + struct udphdr *uh; + struct iphdr *iph; + + if (*oldip == *newip && *oldport == *newport) + return; + + uh = udp_hdr(seg); + iph = ip_hdr(seg); + + if (uh->check) { + inet_proto_csum_replace4(&uh->check, seg, *oldip, *newip, + true); + inet_proto_csum_replace2(&uh->check, seg, *oldport, *newport, + false); + if (!uh->check) + uh->check = CSUM_MANGLED_0; + } + *oldport = *newport; + + csum_replace4(&iph->check, *oldip, *newip); + *oldip = *newip; +} + +static struct sk_buff *__udpv4_gso_segment_list_csum(struct sk_buff *segs) +{ + struct sk_buff *seg; + struct udphdr *uh, *uh2; + struct iphdr *iph, *iph2; + + seg = segs; + uh = udp_hdr(seg); + iph = ip_hdr(seg); + + if ((udp_hdr(seg)->dest == udp_hdr(seg->next)->dest) && + (udp_hdr(seg)->source == udp_hdr(seg->next)->source) && + (ip_hdr(seg)->daddr == ip_hdr(seg->next)->daddr) && + (ip_hdr(seg)->saddr == ip_hdr(seg->next)->saddr)) + return segs; + + while ((seg = seg->next)) { + uh2 = udp_hdr(seg); + iph2 = ip_hdr(seg); + + __udpv4_gso_segment_csum(seg, + &iph2->saddr, &iph->saddr, + &uh2->source, &uh->source); + __udpv4_gso_segment_csum(seg, + &iph2->daddr, &iph->daddr, + &uh2->dest, &uh->dest); + } + + return segs; +} + +static struct sk_buff *__udp_gso_segment_list(struct sk_buff *skb, + netdev_features_t features, + bool is_ipv6) +{ + unsigned int mss = skb_shinfo(skb)->gso_size; + + skb = skb_segment_list(skb, features, skb_mac_header_len(skb)); + if (IS_ERR(skb)) + return skb; + + udp_hdr(skb)->len = htons(sizeof(struct udphdr) + mss); + + return is_ipv6 ? skb : __udpv4_gso_segment_list_csum(skb); +} + +struct sk_buff *__udp_gso_segment(struct sk_buff *gso_skb, + netdev_features_t features, bool is_ipv6) +{ + struct sock *sk = gso_skb->sk; + unsigned int sum_truesize = 0; + struct sk_buff *segs, *seg; + struct udphdr *uh; + unsigned int mss; + bool copy_dtor; + __sum16 check; + __be16 newlen; + + if (skb_shinfo(gso_skb)->gso_type & SKB_GSO_FRAGLIST) + return __udp_gso_segment_list(gso_skb, features, is_ipv6); + + mss = skb_shinfo(gso_skb)->gso_size; + if (gso_skb->len <= sizeof(*uh) + mss) + return ERR_PTR(-EINVAL); + + skb_pull(gso_skb, sizeof(*uh)); + + /* clear destructor to avoid skb_segment assigning it to tail */ + copy_dtor = gso_skb->destructor == sock_wfree; + if (copy_dtor) + gso_skb->destructor = NULL; + + segs = skb_segment(gso_skb, features); + if (IS_ERR_OR_NULL(segs)) { + if (copy_dtor) + gso_skb->destructor = sock_wfree; + return segs; + } + + /* GSO partial and frag_list segmentation only requires splitting + * the frame into an MSS multiple and possibly a remainder, both + * cases return a GSO skb. So update the mss now. 
+ */ + if (skb_is_gso(segs)) + mss *= skb_shinfo(segs)->gso_segs; + + seg = segs; + uh = udp_hdr(seg); + + /* preserve TX timestamp flags and TS key for first segment */ + skb_shinfo(seg)->tskey = skb_shinfo(gso_skb)->tskey; + skb_shinfo(seg)->tx_flags |= + (skb_shinfo(gso_skb)->tx_flags & SKBTX_ANY_TSTAMP); + + /* compute checksum adjustment based on old length versus new */ + newlen = htons(sizeof(*uh) + mss); + check = csum16_add(csum16_sub(uh->check, uh->len), newlen); + + for (;;) { + if (copy_dtor) { + seg->destructor = sock_wfree; + seg->sk = sk; + sum_truesize += seg->truesize; + } + + if (!seg->next) + break; + + uh->len = newlen; + uh->check = check; + + if (seg->ip_summed == CHECKSUM_PARTIAL) + gso_reset_checksum(seg, ~check); + else + uh->check = gso_make_checksum(seg, ~check) ? : + CSUM_MANGLED_0; + + seg = seg->next; + uh = udp_hdr(seg); + } + + /* last packet can be partial gso_size, account for that in checksum */ + newlen = htons(skb_tail_pointer(seg) - skb_transport_header(seg) + + seg->data_len); + check = csum16_add(csum16_sub(uh->check, uh->len), newlen); + + uh->len = newlen; + uh->check = check; + + if (seg->ip_summed == CHECKSUM_PARTIAL) + gso_reset_checksum(seg, ~check); + else + uh->check = gso_make_checksum(seg, ~check) ? : CSUM_MANGLED_0; + + /* update refcount for the packet */ + if (copy_dtor) { + int delta = sum_truesize - gso_skb->truesize; + + /* In some pathological cases, delta can be negative. + * We need to either use refcount_add() or refcount_sub_and_test() + */ + if (likely(delta >= 0)) + refcount_add(delta, &sk->sk_wmem_alloc); + else + WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc)); + } + return segs; +} +EXPORT_SYMBOL_GPL(__udp_gso_segment); + +static struct sk_buff *udp4_ufo_fragment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + unsigned int mss; + __wsum csum; + struct udphdr *uh; + struct iphdr *iph; + + if (skb->encapsulation && + (skb_shinfo(skb)->gso_type & + (SKB_GSO_UDP_TUNNEL|SKB_GSO_UDP_TUNNEL_CSUM))) { + segs = skb_udp_tunnel_segment(skb, features, false); + goto out; + } + + if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_UDP | SKB_GSO_UDP_L4))) + goto out; + + if (!pskb_may_pull(skb, sizeof(struct udphdr))) + goto out; + + if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) + return __udp_gso_segment(skb, features, false); + + mss = skb_shinfo(skb)->gso_size; + if (unlikely(skb->len <= mss)) + goto out; + + /* Do software UFO. Complete and fill in the UDP checksum as + * HW cannot do checksum of UDP packets sent as multiple + * IP fragments. + */ + + uh = udp_hdr(skb); + iph = ip_hdr(skb); + + uh->check = 0; + csum = skb_checksum(skb, 0, skb->len, 0); + uh->check = udp_v4_check(skb->len, iph->saddr, iph->daddr, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + + skb->ip_summed = CHECKSUM_UNNECESSARY; + + /* If there is no outer header we can fake a checksum offload + * due to the fact that we have already done the checksum in + * software prior to segmenting the frame. + */ + if (!skb->encap_hdr_csum) + features |= NETIF_F_HW_CSUM; + + /* Fragment the skb. 
IP headers of the fragments are updated in + * inet_gso_segment() + */ + segs = skb_segment(skb, features); +out: + return segs; +} + +static int skb_gro_receive_list(struct sk_buff *p, struct sk_buff *skb) +{ + if (unlikely(p->len + skb->len >= 65536)) + return -E2BIG; + + if (NAPI_GRO_CB(p)->last == p) + skb_shinfo(p)->frag_list = skb; + else + NAPI_GRO_CB(p)->last->next = skb; + + skb_pull(skb, skb_gro_offset(skb)); + + NAPI_GRO_CB(p)->last = skb; + NAPI_GRO_CB(p)->count++; + p->data_len += skb->len; + + /* sk owenrship - if any - completely transferred to the aggregated packet */ + skb->destructor = NULL; + p->truesize += skb->truesize; + p->len += skb->len; + + NAPI_GRO_CB(skb)->same_flow = 1; + + return 0; +} + + +#define UDP_GRO_CNT_MAX 64 +static struct sk_buff *udp_gro_receive_segment(struct list_head *head, + struct sk_buff *skb) +{ + struct udphdr *uh = udp_gro_udphdr(skb); + struct sk_buff *pp = NULL; + struct udphdr *uh2; + struct sk_buff *p; + unsigned int ulen; + int ret = 0; + + /* requires non zero csum, for symmetry with GSO */ + if (!uh->check) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + /* Do not deal with padded or malicious packets, sorry ! */ + ulen = ntohs(uh->len); + if (ulen <= sizeof(*uh) || ulen != skb_gro_len(skb)) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + /* pull encapsulating udp header */ + skb_gro_pull(skb, sizeof(struct udphdr)); + + list_for_each_entry(p, head, list) { + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + uh2 = udp_hdr(p); + + /* Match ports only, as csum is always non zero */ + if ((*(u32 *)&uh->source != *(u32 *)&uh2->source)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + + if (NAPI_GRO_CB(skb)->is_flist != NAPI_GRO_CB(p)->is_flist) { + NAPI_GRO_CB(skb)->flush = 1; + return p; + } + + /* Terminate the flow on len mismatch or if it grow "too much". + * Under small packet flood GRO count could elsewhere grow a lot + * leading to excessive truesize values. + * On len mismatch merge the first packet shorter than gso_size, + * otherwise complete the GRO packet. + */ + if (ulen > ntohs(uh2->len)) { + pp = p; + } else { + if (NAPI_GRO_CB(skb)->is_flist) { + if (!pskb_may_pull(skb, skb_gro_offset(skb))) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + if ((skb->ip_summed != p->ip_summed) || + (skb->csum_level != p->csum_level)) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + ret = skb_gro_receive_list(p, skb); + } else { + skb_gro_postpull_rcsum(skb, uh, + sizeof(struct udphdr)); + + ret = skb_gro_receive(p, skb); + } + } + + if (ret || ulen != ntohs(uh2->len) || + NAPI_GRO_CB(p)->count >= UDP_GRO_CNT_MAX) + pp = p; + + return pp; + } + + /* mismatch, but we never need to flush */ + return NULL; +} + +struct sk_buff *udp_gro_receive(struct list_head *head, struct sk_buff *skb, + struct udphdr *uh, struct sock *sk) +{ + struct sk_buff *pp = NULL; + struct sk_buff *p; + struct udphdr *uh2; + unsigned int off = skb_gro_offset(skb); + int flush = 1; + + /* we can do L4 aggregation only if the packet can't land in a tunnel + * otherwise we could corrupt the inner stream + */ + NAPI_GRO_CB(skb)->is_flist = 0; + if (!sk || !udp_sk(sk)->gro_receive) { + if (skb->dev->features & NETIF_F_GRO_FRAGLIST) + NAPI_GRO_CB(skb)->is_flist = sk ? 
!udp_test_bit(GRO_ENABLED, sk) : 1; + + if ((!sk && (skb->dev->features & NETIF_F_GRO_UDP_FWD)) || + (sk && udp_test_bit(GRO_ENABLED, sk)) || NAPI_GRO_CB(skb)->is_flist) + return call_gro_receive(udp_gro_receive_segment, head, skb); + + /* no GRO, be sure flush the current packet */ + goto out; + } + + if (NAPI_GRO_CB(skb)->encap_mark || + (uh->check && skb->ip_summed != CHECKSUM_PARTIAL && + NAPI_GRO_CB(skb)->csum_cnt == 0 && + !NAPI_GRO_CB(skb)->csum_valid)) + goto out; + + /* mark that this skb passed once through the tunnel gro layer */ + NAPI_GRO_CB(skb)->encap_mark = 1; + + flush = 0; + + list_for_each_entry(p, head, list) { + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + uh2 = (struct udphdr *)(p->data + off); + + /* Match ports and either checksums are either both zero + * or nonzero. + */ + if ((*(u32 *)&uh->source != *(u32 *)&uh2->source) || + (!uh->check ^ !uh2->check)) { + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + } + + skb_gro_pull(skb, sizeof(struct udphdr)); /* pull encapsulating udp header */ + skb_gro_postpull_rcsum(skb, uh, sizeof(struct udphdr)); + pp = call_gro_receive_sk(udp_sk(sk)->gro_receive, sk, head, skb); + +out: + skb_gro_flush_final(skb, pp, flush); + return pp; +} +EXPORT_SYMBOL(udp_gro_receive); + +static struct sock *udp4_gro_lookup_skb(struct sk_buff *skb, __be16 sport, + __be16 dport) +{ + const struct iphdr *iph = skb_gro_network_header(skb); + + return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport, + iph->daddr, dport, inet_iif(skb), + inet_sdif(skb), &udp_table, NULL); +} + +INDIRECT_CALLABLE_SCOPE +struct sk_buff *udp4_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + struct udphdr *uh = udp_gro_udphdr(skb); + struct sock *sk = NULL; + struct sk_buff *pp; + + if (unlikely(!uh)) + goto flush; + + /* Don't bother verifying checksum if we're going to flush anyway. */ + if (NAPI_GRO_CB(skb)->flush) + goto skip; + + if (skb_gro_checksum_validate_zero_check(skb, IPPROTO_UDP, uh->check, + inet_gro_compute_pseudo)) + goto flush; + else if (uh->check) + skb_gro_checksum_try_convert(skb, IPPROTO_UDP, + inet_gro_compute_pseudo); +skip: + NAPI_GRO_CB(skb)->is_ipv6 = 0; + + if (static_branch_unlikely(&udp_encap_needed_key)) + sk = udp4_gro_lookup_skb(skb, uh->source, uh->dest); + + pp = udp_gro_receive(head, skb, uh, sk); + return pp; + +flush: + NAPI_GRO_CB(skb)->flush = 1; + return NULL; +} + +static int udp_gro_complete_segment(struct sk_buff *skb) +{ + struct udphdr *uh = udp_hdr(skb); + + skb->csum_start = (unsigned char *)uh - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + skb->ip_summed = CHECKSUM_PARTIAL; + + skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + skb_shinfo(skb)->gso_type |= SKB_GSO_UDP_L4; + + if (skb->encapsulation) + skb->inner_transport_header = skb->transport_header; + + return 0; +} + +int udp_gro_complete(struct sk_buff *skb, int nhoff, + udp_lookup_t lookup) +{ + __be16 newlen = htons(skb->len - nhoff); + struct udphdr *uh = (struct udphdr *)(skb->data + nhoff); + struct sock *sk; + int err; + + uh->len = newlen; + + sk = INDIRECT_CALL_INET(lookup, udp6_lib_lookup_skb, + udp4_lib_lookup_skb, skb, uh->source, uh->dest); + if (sk && udp_sk(sk)->gro_complete) { + skb_shinfo(skb)->gso_type = uh->check ? 
SKB_GSO_UDP_TUNNEL_CSUM + : SKB_GSO_UDP_TUNNEL; + + /* clear the encap mark, so that inner frag_list gro_complete + * can take place + */ + NAPI_GRO_CB(skb)->encap_mark = 0; + + /* Set encapsulation before calling into inner gro_complete() + * functions to make them set up the inner offsets. + */ + skb->encapsulation = 1; + err = udp_sk(sk)->gro_complete(sk, skb, + nhoff + sizeof(struct udphdr)); + } else { + err = udp_gro_complete_segment(skb); + } + + if (skb->remcsum_offload) + skb_shinfo(skb)->gso_type |= SKB_GSO_TUNNEL_REMCSUM; + + return err; +} +EXPORT_SYMBOL(udp_gro_complete); + +INDIRECT_CALLABLE_SCOPE int udp4_gro_complete(struct sk_buff *skb, int nhoff) +{ + const struct iphdr *iph = ip_hdr(skb); + struct udphdr *uh = (struct udphdr *)(skb->data + nhoff); + + /* do fraglist only if there is no outer UDP encap (or we already processed it) */ + if (NAPI_GRO_CB(skb)->is_flist && !NAPI_GRO_CB(skb)->encap_mark) { + uh->len = htons(skb->len - nhoff); + + skb_shinfo(skb)->gso_type |= (SKB_GSO_FRAGLIST|SKB_GSO_UDP_L4); + skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + + if (skb->ip_summed == CHECKSUM_UNNECESSARY) { + if (skb->csum_level < SKB_MAX_CSUM_LEVEL) + skb->csum_level++; + } else { + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->csum_level = 0; + } + + return 0; + } + + if (uh->check) + uh->check = ~udp_v4_check(skb->len - nhoff, iph->saddr, + iph->daddr, 0); + + return udp_gro_complete(skb, nhoff, udp4_lib_lookup_skb); +} + +static const struct net_offload udpv4_offload = { + .callbacks = { + .gso_segment = udp4_ufo_fragment, + .gro_receive = udp4_gro_receive, + .gro_complete = udp4_gro_complete, + }, +}; + +int __init udpv4_offload_init(void) +{ + return inet_add_offload(&udpv4_offload, IPPROTO_UDP); +} diff --git a/net/ipv4/udp_tunnel_core.c b/net/ipv4/udp_tunnel_core.c new file mode 100644 index 000000000..732e21b75 --- /dev/null +++ b/net/ipv4/udp_tunnel_core.c @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include + +int udp_sock_create4(struct net *net, struct udp_port_cfg *cfg, + struct socket **sockp) +{ + int err; + struct socket *sock = NULL; + struct sockaddr_in udp_addr; + + err = sock_create_kern(net, AF_INET, SOCK_DGRAM, 0, &sock); + if (err < 0) + goto error; + + if (cfg->bind_ifindex) { + err = sock_bindtoindex(sock->sk, cfg->bind_ifindex, true); + if (err < 0) + goto error; + } + + udp_addr.sin_family = AF_INET; + udp_addr.sin_addr = cfg->local_ip; + udp_addr.sin_port = cfg->local_udp_port; + err = kernel_bind(sock, (struct sockaddr *)&udp_addr, + sizeof(udp_addr)); + if (err < 0) + goto error; + + if (cfg->peer_udp_port) { + udp_addr.sin_family = AF_INET; + udp_addr.sin_addr = cfg->peer_ip; + udp_addr.sin_port = cfg->peer_udp_port; + err = kernel_connect(sock, (struct sockaddr *)&udp_addr, + sizeof(udp_addr), 0); + if (err < 0) + goto error; + } + + sock->sk->sk_no_check_tx = !cfg->use_udp_checksums; + + *sockp = sock; + return 0; + +error: + if (sock) { + kernel_sock_shutdown(sock, SHUT_RDWR); + sock_release(sock); + } + *sockp = NULL; + return err; +} +EXPORT_SYMBOL(udp_sock_create4); + +void setup_udp_tunnel_sock(struct net *net, struct socket *sock, + struct udp_tunnel_sock_cfg *cfg) +{ + struct sock *sk = sock->sk; + + /* Disable multicast loopback */ + inet_sk(sk)->mc_loop = 0; + + /* Enable CHECKSUM_UNNECESSARY to CHECKSUM_COMPLETE conversion */ + inet_inc_convert_csum(sk); + + rcu_assign_sk_user_data(sk, cfg->sk_user_data); + + udp_sk(sk)->encap_type = cfg->encap_type; + 
udp_sk(sk)->encap_rcv = cfg->encap_rcv; + udp_sk(sk)->encap_err_rcv = cfg->encap_err_rcv; + udp_sk(sk)->encap_err_lookup = cfg->encap_err_lookup; + udp_sk(sk)->encap_destroy = cfg->encap_destroy; + udp_sk(sk)->gro_receive = cfg->gro_receive; + udp_sk(sk)->gro_complete = cfg->gro_complete; + + udp_tunnel_encap_enable(sk); +} +EXPORT_SYMBOL_GPL(setup_udp_tunnel_sock); + +void udp_tunnel_push_rx_port(struct net_device *dev, struct socket *sock, + unsigned short type) +{ + struct sock *sk = sock->sk; + struct udp_tunnel_info ti; + + ti.type = type; + ti.sa_family = sk->sk_family; + ti.port = inet_sk(sk)->inet_sport; + + udp_tunnel_nic_add_port(dev, &ti); +} +EXPORT_SYMBOL_GPL(udp_tunnel_push_rx_port); + +void udp_tunnel_drop_rx_port(struct net_device *dev, struct socket *sock, + unsigned short type) +{ + struct sock *sk = sock->sk; + struct udp_tunnel_info ti; + + ti.type = type; + ti.sa_family = sk->sk_family; + ti.port = inet_sk(sk)->inet_sport; + + udp_tunnel_nic_del_port(dev, &ti); +} +EXPORT_SYMBOL_GPL(udp_tunnel_drop_rx_port); + +/* Notify netdevs that UDP port started listening */ +void udp_tunnel_notify_add_rx_port(struct socket *sock, unsigned short type) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct udp_tunnel_info ti; + struct net_device *dev; + + ti.type = type; + ti.sa_family = sk->sk_family; + ti.port = inet_sk(sk)->inet_sport; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + udp_tunnel_nic_add_port(dev, &ti); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(udp_tunnel_notify_add_rx_port); + +/* Notify netdevs that UDP port is no more listening */ +void udp_tunnel_notify_del_rx_port(struct socket *sock, unsigned short type) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct udp_tunnel_info ti; + struct net_device *dev; + + ti.type = type; + ti.sa_family = sk->sk_family; + ti.port = inet_sk(sk)->inet_sport; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + udp_tunnel_nic_del_port(dev, &ti); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(udp_tunnel_notify_del_rx_port); + +void udp_tunnel_xmit_skb(struct rtable *rt, struct sock *sk, struct sk_buff *skb, + __be32 src, __be32 dst, __u8 tos, __u8 ttl, + __be16 df, __be16 src_port, __be16 dst_port, + bool xnet, bool nocheck) +{ + struct udphdr *uh; + + __skb_push(skb, sizeof(*uh)); + skb_reset_transport_header(skb); + uh = udp_hdr(skb); + + uh->dest = dst_port; + uh->source = src_port; + uh->len = htons(skb->len); + + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + + udp_set_csum(nocheck, skb, src, dst, skb->len); + + iptunnel_xmit(sk, rt, skb, src, dst, IPPROTO_UDP, tos, ttl, df, xnet); +} +EXPORT_SYMBOL_GPL(udp_tunnel_xmit_skb); + +void udp_tunnel_sock_release(struct socket *sock) +{ + rcu_assign_sk_user_data(sock->sk, NULL); + synchronize_rcu(); + kernel_sock_shutdown(sock, SHUT_RDWR); + sock_release(sock); +} +EXPORT_SYMBOL_GPL(udp_tunnel_sock_release); + +struct metadata_dst *udp_tun_rx_dst(struct sk_buff *skb, unsigned short family, + __be16 flags, __be64 tunnel_id, int md_size) +{ + struct metadata_dst *tun_dst; + struct ip_tunnel_info *info; + + if (family == AF_INET) + tun_dst = ip_tun_rx_dst(skb, flags, tunnel_id, md_size); + else + tun_dst = ipv6_tun_rx_dst(skb, flags, tunnel_id, md_size); + if (!tun_dst) + return NULL; + + info = &tun_dst->u.tun_info; + info->key.tp_src = udp_hdr(skb)->source; + info->key.tp_dst = udp_hdr(skb)->dest; + if (udp_hdr(skb)->check) + info->key.tun_flags |= TUNNEL_CSUM; + return tun_dst; +} 
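[Editor's illustration, not part of the patch] The udp_tunnel_core helpers introduced here (udp_sock_create4(), setup_udp_tunnel_sock(), udp_tunnel_xmit_skb(), udp_tunnel_sock_release()) are the API that encapsulation drivers build on. Below is a minimal sketch of a consumer, written against only the struct udp_port_cfg and struct udp_tunnel_sock_cfg fields visible in this file; the my_* names and the choice of encap_type value are illustrative assumptions, not taken from this patch.

/* Illustrative sketch only: open an IPv4 encapsulation socket and wire an
 * encap_rcv() callback, in the style of in-tree tunnel drivers.
 */
#include <linux/err.h>
#include <linux/types.h>
#include <linux/skbuff.h>
#include <net/sock.h>
#include <net/udp_tunnel.h>

/* encap_rcv() convention: return 0 when the skb was consumed here,
 * > 0 to hand the packet back to normal UDP processing.
 */
static int my_encap_rcv(struct sock *sk, struct sk_buff *skb)
{
	void *priv = rcu_dereference_sk_user_data(sk);

	/* A real driver would decapsulate and forward here; the sketch
	 * just drops the packet.
	 */
	(void)priv;
	kfree_skb(skb);
	return 0;
}

static struct socket *my_tunnel_socket_open(struct net *net, __be16 port,
					    void *priv)
{
	struct udp_port_cfg udp_conf = { };
	struct udp_tunnel_sock_cfg tunnel_cfg = { };
	struct socket *sock;
	int err;

	udp_conf.local_udp_port = port;		/* local_ip left as any */
	udp_conf.use_udp_checksums = true;

	err = udp_sock_create4(net, &udp_conf, &sock);
	if (err)
		return ERR_PTR(err);

	tunnel_cfg.sk_user_data = priv;		/* published via RCU above */
	tunnel_cfg.encap_type = 1;		/* non-zero enables encap_rcv() */
	tunnel_cfg.encap_rcv = my_encap_rcv;
	setup_udp_tunnel_sock(net, sock, &tunnel_cfg);

	return sock;	/* released later with udp_tunnel_sock_release() */
}

A driver that wants hardware offload of its port would additionally call udp_tunnel_notify_add_rx_port() (or udp_tunnel_push_rx_port() for a single netdev), exactly as the helpers below do when feeding udp_tunnel_nic.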
+EXPORT_SYMBOL_GPL(udp_tun_rx_dst); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/udp_tunnel_nic.c b/net/ipv4/udp_tunnel_nic.c new file mode 100644 index 000000000..bc3a043a5 --- /dev/null +++ b/net/ipv4/udp_tunnel_nic.c @@ -0,0 +1,973 @@ +// SPDX-License-Identifier: GPL-2.0-only +// Copyright (c) 2020 Facebook Inc. + +#include +#include +#include +#include +#include +#include +#include + +enum udp_tunnel_nic_table_entry_flags { + UDP_TUNNEL_NIC_ENTRY_ADD = BIT(0), + UDP_TUNNEL_NIC_ENTRY_DEL = BIT(1), + UDP_TUNNEL_NIC_ENTRY_OP_FAIL = BIT(2), + UDP_TUNNEL_NIC_ENTRY_FROZEN = BIT(3), +}; + +struct udp_tunnel_nic_table_entry { + __be16 port; + u8 type; + u8 flags; + u16 use_cnt; +#define UDP_TUNNEL_NIC_USE_CNT_MAX U16_MAX + u8 hw_priv; +}; + +/** + * struct udp_tunnel_nic - UDP tunnel port offload state + * @work: async work for talking to hardware from process context + * @dev: netdev pointer + * @need_sync: at least one port start changed + * @need_replay: space was freed, we need a replay of all ports + * @work_pending: @work is currently scheduled + * @n_tables: number of tables under @entries + * @missed: bitmap of tables which overflown + * @entries: table of tables of ports currently offloaded + */ +struct udp_tunnel_nic { + struct work_struct work; + + struct net_device *dev; + + u8 need_sync:1; + u8 need_replay:1; + u8 work_pending:1; + + unsigned int n_tables; + unsigned long missed; + struct udp_tunnel_nic_table_entry **entries; +}; + +/* We ensure all work structs are done using driver state, but not the code. + * We need a workqueue we can flush before module gets removed. + */ +static struct workqueue_struct *udp_tunnel_nic_workqueue; + +static const char *udp_tunnel_nic_tunnel_type_name(unsigned int type) +{ + switch (type) { + case UDP_TUNNEL_TYPE_VXLAN: + return "vxlan"; + case UDP_TUNNEL_TYPE_GENEVE: + return "geneve"; + case UDP_TUNNEL_TYPE_VXLAN_GPE: + return "vxlan-gpe"; + default: + return "unknown"; + } +} + +static bool +udp_tunnel_nic_entry_is_free(struct udp_tunnel_nic_table_entry *entry) +{ + return entry->use_cnt == 0 && !entry->flags; +} + +static bool +udp_tunnel_nic_entry_is_present(struct udp_tunnel_nic_table_entry *entry) +{ + return entry->use_cnt && !(entry->flags & ~UDP_TUNNEL_NIC_ENTRY_FROZEN); +} + +static bool +udp_tunnel_nic_entry_is_frozen(struct udp_tunnel_nic_table_entry *entry) +{ + return entry->flags & UDP_TUNNEL_NIC_ENTRY_FROZEN; +} + +static void +udp_tunnel_nic_entry_freeze_used(struct udp_tunnel_nic_table_entry *entry) +{ + if (!udp_tunnel_nic_entry_is_free(entry)) + entry->flags |= UDP_TUNNEL_NIC_ENTRY_FROZEN; +} + +static void +udp_tunnel_nic_entry_unfreeze(struct udp_tunnel_nic_table_entry *entry) +{ + entry->flags &= ~UDP_TUNNEL_NIC_ENTRY_FROZEN; +} + +static bool +udp_tunnel_nic_entry_is_queued(struct udp_tunnel_nic_table_entry *entry) +{ + return entry->flags & (UDP_TUNNEL_NIC_ENTRY_ADD | + UDP_TUNNEL_NIC_ENTRY_DEL); +} + +static void +udp_tunnel_nic_entry_queue(struct udp_tunnel_nic *utn, + struct udp_tunnel_nic_table_entry *entry, + unsigned int flag) +{ + entry->flags |= flag; + utn->need_sync = 1; +} + +static void +udp_tunnel_nic_ti_from_entry(struct udp_tunnel_nic_table_entry *entry, + struct udp_tunnel_info *ti) +{ + memset(ti, 0, sizeof(*ti)); + ti->port = entry->port; + ti->type = entry->type; + ti->hw_priv = entry->hw_priv; +} + +static bool +udp_tunnel_nic_is_empty(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + unsigned int i, j; + + for (i = 0; i 
< utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) + if (!udp_tunnel_nic_entry_is_free(&utn->entries[i][j])) + return false; + return true; +} + +static bool +udp_tunnel_nic_should_replay(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_table_info *table; + unsigned int i, j; + + if (!utn->missed) + return false; + + for (i = 0; i < utn->n_tables; i++) { + table = &dev->udp_tunnel_nic_info->tables[i]; + if (!test_bit(i, &utn->missed)) + continue; + + for (j = 0; j < table->n_entries; j++) + if (udp_tunnel_nic_entry_is_free(&utn->entries[i][j])) + return true; + } + + return false; +} + +static void +__udp_tunnel_nic_get_port(struct net_device *dev, unsigned int table, + unsigned int idx, struct udp_tunnel_info *ti) +{ + struct udp_tunnel_nic_table_entry *entry; + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + entry = &utn->entries[table][idx]; + + if (entry->use_cnt) + udp_tunnel_nic_ti_from_entry(entry, ti); +} + +static void +__udp_tunnel_nic_set_port_priv(struct net_device *dev, unsigned int table, + unsigned int idx, u8 priv) +{ + dev->udp_tunnel_nic->entries[table][idx].hw_priv = priv; +} + +static void +udp_tunnel_nic_entry_update_done(struct udp_tunnel_nic_table_entry *entry, + int err) +{ + bool dodgy = entry->flags & UDP_TUNNEL_NIC_ENTRY_OP_FAIL; + + WARN_ON_ONCE(entry->flags & UDP_TUNNEL_NIC_ENTRY_ADD && + entry->flags & UDP_TUNNEL_NIC_ENTRY_DEL); + + if (entry->flags & UDP_TUNNEL_NIC_ENTRY_ADD && + (!err || (err == -EEXIST && dodgy))) + entry->flags &= ~UDP_TUNNEL_NIC_ENTRY_ADD; + + if (entry->flags & UDP_TUNNEL_NIC_ENTRY_DEL && + (!err || (err == -ENOENT && dodgy))) + entry->flags &= ~UDP_TUNNEL_NIC_ENTRY_DEL; + + if (!err) + entry->flags &= ~UDP_TUNNEL_NIC_ENTRY_OP_FAIL; + else + entry->flags |= UDP_TUNNEL_NIC_ENTRY_OP_FAIL; +} + +static void +udp_tunnel_nic_device_sync_one(struct net_device *dev, + struct udp_tunnel_nic *utn, + unsigned int table, unsigned int idx) +{ + struct udp_tunnel_nic_table_entry *entry; + struct udp_tunnel_info ti; + int err; + + entry = &utn->entries[table][idx]; + if (!udp_tunnel_nic_entry_is_queued(entry)) + return; + + udp_tunnel_nic_ti_from_entry(entry, &ti); + if (entry->flags & UDP_TUNNEL_NIC_ENTRY_ADD) + err = dev->udp_tunnel_nic_info->set_port(dev, table, idx, &ti); + else + err = dev->udp_tunnel_nic_info->unset_port(dev, table, idx, + &ti); + udp_tunnel_nic_entry_update_done(entry, err); + + if (err) + netdev_warn(dev, + "UDP tunnel port sync failed port %d type %s: %d\n", + be16_to_cpu(entry->port), + udp_tunnel_nic_tunnel_type_name(entry->type), + err); +} + +static void +udp_tunnel_nic_device_sync_by_port(struct net_device *dev, + struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + unsigned int i, j; + + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) + udp_tunnel_nic_device_sync_one(dev, utn, i, j); +} + +static void +udp_tunnel_nic_device_sync_by_table(struct net_device *dev, + struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + unsigned int i, j; + int err; + + for (i = 0; i < utn->n_tables; i++) { + /* Find something that needs sync in this table */ + for (j = 0; j < info->tables[i].n_entries; j++) + if (udp_tunnel_nic_entry_is_queued(&utn->entries[i][j])) + break; + if (j == info->tables[i].n_entries) + continue; + + err = info->sync_table(dev, i); + if (err) + netdev_warn(dev, "UDP tunnel port sync failed for table %d: %d\n", + 
i, err); + + for (j = 0; j < info->tables[i].n_entries; j++) { + struct udp_tunnel_nic_table_entry *entry; + + entry = &utn->entries[i][j]; + if (udp_tunnel_nic_entry_is_queued(entry)) + udp_tunnel_nic_entry_update_done(entry, err); + } + } +} + +static void +__udp_tunnel_nic_device_sync(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + if (!utn->need_sync) + return; + + if (dev->udp_tunnel_nic_info->sync_table) + udp_tunnel_nic_device_sync_by_table(dev, utn); + else + udp_tunnel_nic_device_sync_by_port(dev, utn); + + utn->need_sync = 0; + /* Can't replay directly here, in case we come from the tunnel driver's + * notification - trying to replay may deadlock inside tunnel driver. + */ + utn->need_replay = udp_tunnel_nic_should_replay(dev, utn); +} + +static void +udp_tunnel_nic_device_sync(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + bool may_sleep; + + if (!utn->need_sync) + return; + + /* Drivers which sleep in the callback need to update from + * the workqueue, if we come from the tunnel driver's notification. + */ + may_sleep = info->flags & UDP_TUNNEL_NIC_INFO_MAY_SLEEP; + if (!may_sleep) + __udp_tunnel_nic_device_sync(dev, utn); + if (may_sleep || utn->need_replay) { + queue_work(udp_tunnel_nic_workqueue, &utn->work); + utn->work_pending = 1; + } +} + +static bool +udp_tunnel_nic_table_is_capable(const struct udp_tunnel_nic_table_info *table, + struct udp_tunnel_info *ti) +{ + return table->tunnel_types & ti->type; +} + +static bool +udp_tunnel_nic_is_capable(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + unsigned int i; + + /* Special case IPv4-only NICs */ + if (info->flags & UDP_TUNNEL_NIC_INFO_IPV4_ONLY && + ti->sa_family != AF_INET) + return false; + + for (i = 0; i < utn->n_tables; i++) + if (udp_tunnel_nic_table_is_capable(&info->tables[i], ti)) + return true; + return false; +} + +static int +udp_tunnel_nic_has_collision(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic_table_entry *entry; + unsigned int i, j; + + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) { + entry = &utn->entries[i][j]; + + if (!udp_tunnel_nic_entry_is_free(entry) && + entry->port == ti->port && + entry->type != ti->type) { + __set_bit(i, &utn->missed); + return true; + } + } + return false; +} + +static void +udp_tunnel_nic_entry_adj(struct udp_tunnel_nic *utn, + unsigned int table, unsigned int idx, int use_cnt_adj) +{ + struct udp_tunnel_nic_table_entry *entry = &utn->entries[table][idx]; + bool dodgy = entry->flags & UDP_TUNNEL_NIC_ENTRY_OP_FAIL; + unsigned int from, to; + + WARN_ON(entry->use_cnt + (u32)use_cnt_adj > U16_MAX); + + /* If not going from used to unused or vice versa - all done. + * For dodgy entries make sure we try to sync again (queue the entry). + */ + entry->use_cnt += use_cnt_adj; + if (!dodgy && !entry->use_cnt == !(entry->use_cnt - use_cnt_adj)) + return; + + /* Cancel the op before it was sent to the device, if possible, + * otherwise we'd need to take special care to issue commands + * in the same order the ports arrived. 
+ */ + if (use_cnt_adj < 0) { + from = UDP_TUNNEL_NIC_ENTRY_ADD; + to = UDP_TUNNEL_NIC_ENTRY_DEL; + } else { + from = UDP_TUNNEL_NIC_ENTRY_DEL; + to = UDP_TUNNEL_NIC_ENTRY_ADD; + } + + if (entry->flags & from) { + entry->flags &= ~from; + if (!dodgy) + return; + } + + udp_tunnel_nic_entry_queue(utn, entry, to); +} + +static bool +udp_tunnel_nic_entry_try_adj(struct udp_tunnel_nic *utn, + unsigned int table, unsigned int idx, + struct udp_tunnel_info *ti, int use_cnt_adj) +{ + struct udp_tunnel_nic_table_entry *entry = &utn->entries[table][idx]; + + if (udp_tunnel_nic_entry_is_free(entry) || + entry->port != ti->port || + entry->type != ti->type) + return false; + + if (udp_tunnel_nic_entry_is_frozen(entry)) + return true; + + udp_tunnel_nic_entry_adj(utn, table, idx, use_cnt_adj); + return true; +} + +/* Try to find existing matching entry and adjust its use count, instead of + * adding a new one. Returns true if entry was found. In case of delete the + * entry may have gotten removed in the process, in which case it will be + * queued for removal. + */ +static bool +udp_tunnel_nic_try_existing(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti, int use_cnt_adj) +{ + const struct udp_tunnel_nic_table_info *table; + unsigned int i, j; + + for (i = 0; i < utn->n_tables; i++) { + table = &dev->udp_tunnel_nic_info->tables[i]; + if (!udp_tunnel_nic_table_is_capable(table, ti)) + continue; + + for (j = 0; j < table->n_entries; j++) + if (udp_tunnel_nic_entry_try_adj(utn, i, j, ti, + use_cnt_adj)) + return true; + } + + return false; +} + +static bool +udp_tunnel_nic_add_existing(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti) +{ + return udp_tunnel_nic_try_existing(dev, utn, ti, +1); +} + +static bool +udp_tunnel_nic_del_existing(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti) +{ + return udp_tunnel_nic_try_existing(dev, utn, ti, -1); +} + +static bool +udp_tunnel_nic_add_new(struct net_device *dev, struct udp_tunnel_nic *utn, + struct udp_tunnel_info *ti) +{ + const struct udp_tunnel_nic_table_info *table; + unsigned int i, j; + + for (i = 0; i < utn->n_tables; i++) { + table = &dev->udp_tunnel_nic_info->tables[i]; + if (!udp_tunnel_nic_table_is_capable(table, ti)) + continue; + + for (j = 0; j < table->n_entries; j++) { + struct udp_tunnel_nic_table_entry *entry; + + entry = &utn->entries[i][j]; + if (!udp_tunnel_nic_entry_is_free(entry)) + continue; + + entry->port = ti->port; + entry->type = ti->type; + entry->use_cnt = 1; + udp_tunnel_nic_entry_queue(utn, entry, + UDP_TUNNEL_NIC_ENTRY_ADD); + return true; + } + + /* The different table may still fit this port in, but there + * are no devices currently which have multiple tables accepting + * the same tunnel type, and false positives are okay. 
+ */ + __set_bit(i, &utn->missed); + } + + return false; +} + +static void +__udp_tunnel_nic_add_port(struct net_device *dev, struct udp_tunnel_info *ti) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + if (!utn) + return; + if (!netif_running(dev) && info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY) + return; + if (info->flags & UDP_TUNNEL_NIC_INFO_STATIC_IANA_VXLAN && + ti->port == htons(IANA_VXLAN_UDP_PORT)) { + if (ti->type != UDP_TUNNEL_TYPE_VXLAN) + netdev_warn(dev, "device assumes port 4789 will be used by vxlan tunnels\n"); + return; + } + + if (!udp_tunnel_nic_is_capable(dev, utn, ti)) + return; + + /* It may happen that a tunnel of one type is removed and different + * tunnel type tries to reuse its port before the device was informed. + * Rely on utn->missed to re-add this port later. + */ + if (udp_tunnel_nic_has_collision(dev, utn, ti)) + return; + + if (!udp_tunnel_nic_add_existing(dev, utn, ti)) + udp_tunnel_nic_add_new(dev, utn, ti); + + udp_tunnel_nic_device_sync(dev, utn); +} + +static void +__udp_tunnel_nic_del_port(struct net_device *dev, struct udp_tunnel_info *ti) +{ + struct udp_tunnel_nic *utn; + + utn = dev->udp_tunnel_nic; + if (!utn) + return; + + if (!udp_tunnel_nic_is_capable(dev, utn, ti)) + return; + + udp_tunnel_nic_del_existing(dev, utn, ti); + + udp_tunnel_nic_device_sync(dev, utn); +} + +static void __udp_tunnel_nic_reset_ntf(struct net_device *dev) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic *utn; + unsigned int i, j; + + ASSERT_RTNL(); + + utn = dev->udp_tunnel_nic; + if (!utn) + return; + + utn->need_sync = false; + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) { + struct udp_tunnel_nic_table_entry *entry; + + entry = &utn->entries[i][j]; + + entry->flags &= ~(UDP_TUNNEL_NIC_ENTRY_DEL | + UDP_TUNNEL_NIC_ENTRY_OP_FAIL); + /* We don't release rtnl across ops */ + WARN_ON(entry->flags & UDP_TUNNEL_NIC_ENTRY_FROZEN); + if (!entry->use_cnt) + continue; + + udp_tunnel_nic_entry_queue(utn, entry, + UDP_TUNNEL_NIC_ENTRY_ADD); + } + + __udp_tunnel_nic_device_sync(dev, utn); +} + +static size_t +__udp_tunnel_nic_dump_size(struct net_device *dev, unsigned int table) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic *utn; + unsigned int j; + size_t size; + + utn = dev->udp_tunnel_nic; + if (!utn) + return 0; + + size = 0; + for (j = 0; j < info->tables[table].n_entries; j++) { + if (!udp_tunnel_nic_entry_is_present(&utn->entries[table][j])) + continue; + + size += nla_total_size(0) + /* _TABLE_ENTRY */ + nla_total_size(sizeof(__be16)) + /* _ENTRY_PORT */ + nla_total_size(sizeof(u32)); /* _ENTRY_TYPE */ + } + + return size; +} + +static int +__udp_tunnel_nic_dump_write(struct net_device *dev, unsigned int table, + struct sk_buff *skb) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic *utn; + struct nlattr *nest; + unsigned int j; + + utn = dev->udp_tunnel_nic; + if (!utn) + return 0; + + for (j = 0; j < info->tables[table].n_entries; j++) { + if (!udp_tunnel_nic_entry_is_present(&utn->entries[table][j])) + continue; + + nest = nla_nest_start(skb, ETHTOOL_A_TUNNEL_UDP_TABLE_ENTRY); + + if (nla_put_be16(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_PORT, + utn->entries[table][j].port) || + nla_put_u32(skb, ETHTOOL_A_TUNNEL_UDP_ENTRY_TYPE, + ilog2(utn->entries[table][j].type))) + goto err_cancel; + + nla_nest_end(skb, 
nest); + } + + return 0; + +err_cancel: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static const struct udp_tunnel_nic_ops __udp_tunnel_nic_ops = { + .get_port = __udp_tunnel_nic_get_port, + .set_port_priv = __udp_tunnel_nic_set_port_priv, + .add_port = __udp_tunnel_nic_add_port, + .del_port = __udp_tunnel_nic_del_port, + .reset_ntf = __udp_tunnel_nic_reset_ntf, + .dump_size = __udp_tunnel_nic_dump_size, + .dump_write = __udp_tunnel_nic_dump_write, +}; + +static void +udp_tunnel_nic_flush(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + unsigned int i, j; + + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) { + int adj_cnt = -utn->entries[i][j].use_cnt; + + if (adj_cnt) + udp_tunnel_nic_entry_adj(utn, i, j, adj_cnt); + } + + __udp_tunnel_nic_device_sync(dev, utn); + + for (i = 0; i < utn->n_tables; i++) + memset(utn->entries[i], 0, array_size(info->tables[i].n_entries, + sizeof(**utn->entries))); + WARN_ON(utn->need_sync); + utn->need_replay = 0; +} + +static void +udp_tunnel_nic_replay(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic_shared_node *node; + unsigned int i, j; + + /* Freeze all the ports we are already tracking so that the replay + * does not double up the refcount. + */ + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) + udp_tunnel_nic_entry_freeze_used(&utn->entries[i][j]); + utn->missed = 0; + utn->need_replay = 0; + + if (!info->shared) { + udp_tunnel_get_rx_info(dev); + } else { + list_for_each_entry(node, &info->shared->devices, list) + udp_tunnel_get_rx_info(node->dev); + } + + for (i = 0; i < utn->n_tables; i++) + for (j = 0; j < info->tables[i].n_entries; j++) + udp_tunnel_nic_entry_unfreeze(&utn->entries[i][j]); +} + +static void udp_tunnel_nic_device_sync_work(struct work_struct *work) +{ + struct udp_tunnel_nic *utn = + container_of(work, struct udp_tunnel_nic, work); + + rtnl_lock(); + utn->work_pending = 0; + __udp_tunnel_nic_device_sync(utn->dev, utn); + + if (utn->need_replay) + udp_tunnel_nic_replay(utn->dev, utn); + rtnl_unlock(); +} + +static struct udp_tunnel_nic * +udp_tunnel_nic_alloc(const struct udp_tunnel_nic_info *info, + unsigned int n_tables) +{ + struct udp_tunnel_nic *utn; + unsigned int i; + + utn = kzalloc(sizeof(*utn), GFP_KERNEL); + if (!utn) + return NULL; + utn->n_tables = n_tables; + INIT_WORK(&utn->work, udp_tunnel_nic_device_sync_work); + + utn->entries = kmalloc_array(n_tables, sizeof(void *), GFP_KERNEL); + if (!utn->entries) + goto err_free_utn; + + for (i = 0; i < n_tables; i++) { + utn->entries[i] = kcalloc(info->tables[i].n_entries, + sizeof(*utn->entries[i]), GFP_KERNEL); + if (!utn->entries[i]) + goto err_free_prev_entries; + } + + return utn; + +err_free_prev_entries: + while (i--) + kfree(utn->entries[i]); + kfree(utn->entries); +err_free_utn: + kfree(utn); + return NULL; +} + +static void udp_tunnel_nic_free(struct udp_tunnel_nic *utn) +{ + unsigned int i; + + for (i = 0; i < utn->n_tables; i++) + kfree(utn->entries[i]); + kfree(utn->entries); + kfree(utn); +} + +static int udp_tunnel_nic_register(struct net_device *dev) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + struct udp_tunnel_nic_shared_node *node = NULL; + struct udp_tunnel_nic *utn; + unsigned int n_tables, i; + + BUILD_BUG_ON(sizeof(utn->missed) * BITS_PER_BYTE < + 
UDP_TUNNEL_NIC_MAX_TABLES); + /* Expect use count of at most 2 (IPv4, IPv6) per device */ + BUILD_BUG_ON(UDP_TUNNEL_NIC_USE_CNT_MAX < + UDP_TUNNEL_NIC_MAX_SHARING_DEVICES * 2); + + /* Check that the driver info is sane */ + if (WARN_ON(!info->set_port != !info->unset_port) || + WARN_ON(!info->set_port == !info->sync_table) || + WARN_ON(!info->tables[0].n_entries)) + return -EINVAL; + + if (WARN_ON(info->shared && + info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY)) + return -EINVAL; + + n_tables = 1; + for (i = 1; i < UDP_TUNNEL_NIC_MAX_TABLES; i++) { + if (!info->tables[i].n_entries) + continue; + + n_tables++; + if (WARN_ON(!info->tables[i - 1].n_entries)) + return -EINVAL; + } + + /* Create UDP tunnel state structures */ + if (info->shared) { + node = kzalloc(sizeof(*node), GFP_KERNEL); + if (!node) + return -ENOMEM; + + node->dev = dev; + } + + if (info->shared && info->shared->udp_tunnel_nic_info) { + utn = info->shared->udp_tunnel_nic_info; + } else { + utn = udp_tunnel_nic_alloc(info, n_tables); + if (!utn) { + kfree(node); + return -ENOMEM; + } + } + + if (info->shared) { + if (!info->shared->udp_tunnel_nic_info) { + INIT_LIST_HEAD(&info->shared->devices); + info->shared->udp_tunnel_nic_info = utn; + } + + list_add_tail(&node->list, &info->shared->devices); + } + + utn->dev = dev; + dev_hold(dev); + dev->udp_tunnel_nic = utn; + + if (!(info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY)) + udp_tunnel_get_rx_info(dev); + + return 0; +} + +static void +udp_tunnel_nic_unregister(struct net_device *dev, struct udp_tunnel_nic *utn) +{ + const struct udp_tunnel_nic_info *info = dev->udp_tunnel_nic_info; + + /* For a shared table remove this dev from the list of sharing devices + * and if there are other devices just detach. + */ + if (info->shared) { + struct udp_tunnel_nic_shared_node *node, *first; + + list_for_each_entry(node, &info->shared->devices, list) + if (node->dev == dev) + break; + if (list_entry_is_head(node, &info->shared->devices, list)) + return; + + list_del(&node->list); + kfree(node); + + first = list_first_entry_or_null(&info->shared->devices, + typeof(*first), list); + if (first) { + udp_tunnel_drop_rx_info(dev); + utn->dev = first->dev; + goto release_dev; + } + + info->shared->udp_tunnel_nic_info = NULL; + } + + /* Flush before we check work, so we don't waste time adding entries + * from the work which we will boot immediately. + */ + udp_tunnel_nic_flush(dev, utn); + + /* Wait for the work to be done using the state, netdev core will + * retry unregister until we give up our reference on this device. 
+ */ + if (utn->work_pending) + return; + + udp_tunnel_nic_free(utn); +release_dev: + dev->udp_tunnel_nic = NULL; + dev_put(dev); +} + +static int +udp_tunnel_nic_netdevice_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + const struct udp_tunnel_nic_info *info; + struct udp_tunnel_nic *utn; + + info = dev->udp_tunnel_nic_info; + if (!info) + return NOTIFY_DONE; + + if (event == NETDEV_REGISTER) { + int err; + + err = udp_tunnel_nic_register(dev); + if (err) + netdev_WARN(dev, "failed to register for UDP tunnel offloads: %d", err); + return notifier_from_errno(err); + } + /* All other events will need the udp_tunnel_nic state */ + utn = dev->udp_tunnel_nic; + if (!utn) + return NOTIFY_DONE; + + if (event == NETDEV_UNREGISTER) { + udp_tunnel_nic_unregister(dev, utn); + return NOTIFY_OK; + } + + /* All other events only matter if NIC has to be programmed open */ + if (!(info->flags & UDP_TUNNEL_NIC_INFO_OPEN_ONLY)) + return NOTIFY_DONE; + + if (event == NETDEV_UP) { + WARN_ON(!udp_tunnel_nic_is_empty(dev, utn)); + udp_tunnel_get_rx_info(dev); + return NOTIFY_OK; + } + if (event == NETDEV_GOING_DOWN) { + udp_tunnel_nic_flush(dev, utn); + return NOTIFY_OK; + } + + return NOTIFY_DONE; +} + +static struct notifier_block udp_tunnel_nic_notifier_block __read_mostly = { + .notifier_call = udp_tunnel_nic_netdevice_event, +}; + +static int __init udp_tunnel_nic_init_module(void) +{ + int err; + + udp_tunnel_nic_workqueue = alloc_ordered_workqueue("udp_tunnel_nic", 0); + if (!udp_tunnel_nic_workqueue) + return -ENOMEM; + + rtnl_lock(); + udp_tunnel_nic_ops = &__udp_tunnel_nic_ops; + rtnl_unlock(); + + err = register_netdevice_notifier(&udp_tunnel_nic_notifier_block); + if (err) + goto err_unset_ops; + + return 0; + +err_unset_ops: + rtnl_lock(); + udp_tunnel_nic_ops = NULL; + rtnl_unlock(); + destroy_workqueue(udp_tunnel_nic_workqueue); + return err; +} +late_initcall(udp_tunnel_nic_init_module); + +static void __exit udp_tunnel_nic_cleanup_module(void) +{ + unregister_netdevice_notifier(&udp_tunnel_nic_notifier_block); + + rtnl_lock(); + udp_tunnel_nic_ops = NULL; + rtnl_unlock(); + + destroy_workqueue(udp_tunnel_nic_workqueue); +} +module_exit(udp_tunnel_nic_cleanup_module); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv4/udp_tunnel_stub.c b/net/ipv4/udp_tunnel_stub.c new file mode 100644 index 000000000..c4b2888f5 --- /dev/null +++ b/net/ipv4/udp_tunnel_stub.c @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: GPL-2.0-only +// Copyright (c) 2020 Facebook Inc. + +#include + +const struct udp_tunnel_nic_ops *udp_tunnel_nic_ops; +EXPORT_SYMBOL_GPL(udp_tunnel_nic_ops); diff --git a/net/ipv4/udplite.c b/net/ipv4/udplite.c new file mode 100644 index 000000000..56d94d23b --- /dev/null +++ b/net/ipv4/udplite.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * UDPLITE An implementation of the UDP-Lite protocol (RFC 3828). 
+ * + * Authors: Gerrit Renker + * + * Changes: + * Fixes: + */ + +#define pr_fmt(fmt) "UDPLite: " fmt + +#include +#include +#include "udp_impl.h" + +struct udp_table udplite_table __read_mostly; +EXPORT_SYMBOL(udplite_table); + +/* Designate sk as UDP-Lite socket */ +static int udplite_sk_init(struct sock *sk) +{ + udp_init_sock(sk); + udp_sk(sk)->pcflag = UDPLITE_BIT; + return 0; +} + +static int udplite_rcv(struct sk_buff *skb) +{ + return __udp4_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE); +} + +static int udplite_err(struct sk_buff *skb, u32 info) +{ + return __udp4_lib_err(skb, info, &udplite_table); +} + +static const struct net_protocol udplite_protocol = { + .handler = udplite_rcv, + .err_handler = udplite_err, + .no_policy = 1, +}; + +struct proto udplite_prot = { + .name = "UDP-Lite", + .owner = THIS_MODULE, + .close = udp_lib_close, + .connect = ip4_datagram_connect, + .disconnect = udp_disconnect, + .ioctl = udp_ioctl, + .init = udplite_sk_init, + .destroy = udp_destroy_sock, + .setsockopt = udp_setsockopt, + .getsockopt = udp_getsockopt, + .sendmsg = udp_sendmsg, + .recvmsg = udp_recvmsg, + .sendpage = udp_sendpage, + .hash = udp_lib_hash, + .unhash = udp_lib_unhash, + .rehash = udp_v4_rehash, + .get_port = udp_v4_get_port, + + .memory_allocated = &udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + + .sysctl_mem = sysctl_udp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), + .obj_size = sizeof(struct udp_sock), + .h.udp_table = &udplite_table, +}; +EXPORT_SYMBOL(udplite_prot); + +static struct inet_protosw udplite4_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_UDPLITE, + .prot = &udplite_prot, + .ops = &inet_dgram_ops, + .flags = INET_PROTOSW_PERMANENT, +}; + +#ifdef CONFIG_PROC_FS +static struct udp_seq_afinfo udplite4_seq_afinfo = { + .family = AF_INET, + .udp_table = &udplite_table, +}; + +static int __net_init udplite4_proc_init_net(struct net *net) +{ + if (!proc_create_net_data("udplite", 0444, net->proc_net, &udp_seq_ops, + sizeof(struct udp_iter_state), &udplite4_seq_afinfo)) + return -ENOMEM; + return 0; +} + +static void __net_exit udplite4_proc_exit_net(struct net *net) +{ + remove_proc_entry("udplite", net->proc_net); +} + +static struct pernet_operations udplite4_net_ops = { + .init = udplite4_proc_init_net, + .exit = udplite4_proc_exit_net, +}; + +static __init int udplite4_proc_init(void) +{ + return register_pernet_subsys(&udplite4_net_ops); +} +#else +static inline int udplite4_proc_init(void) +{ + return 0; +} +#endif + +void __init udplite4_register(void) +{ + udp_table_init(&udplite_table, "UDP-Lite"); + if (proto_register(&udplite_prot, 1)) + goto out_register_err; + + if (inet_add_protocol(&udplite_protocol, IPPROTO_UDPLITE) < 0) + goto out_unregister_proto; + + inet_register_protosw(&udplite4_protosw); + + if (udplite4_proc_init()) + pr_err("%s: Cannot register /proc!\n", __func__); + return; + +out_unregister_proto: + proto_unregister(&udplite_prot); +out_register_err: + pr_crit("%s: Cannot add UDP-Lite protocol\n", __func__); +} diff --git a/net/ipv4/xfrm4_input.c b/net/ipv4/xfrm4_input.c new file mode 100644 index 000000000..183f6dc37 --- /dev/null +++ b/net/ipv4/xfrm4_input.c @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm4_input.c + * + * Changes: + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific portion + * Derek Atkins + * Add Encapsulation support + * + */ + +#include +#include 
+#include +#include +#include +#include +#include + +static int xfrm4_rcv_encap_finish2(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + return dst_input(skb); +} + +static inline int xfrm4_rcv_encap_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + if (!skb_dst(skb)) { + const struct iphdr *iph = ip_hdr(skb); + + if (ip_route_input_noref(skb, iph->daddr, iph->saddr, + iph->tos, skb->dev)) + goto drop; + } + + if (xfrm_trans_queue(skb, xfrm4_rcv_encap_finish2)) + goto drop; + + return 0; +drop: + kfree_skb(skb); + return NET_RX_DROP; +} + +int xfrm4_transport_finish(struct sk_buff *skb, int async) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + struct iphdr *iph = ip_hdr(skb); + + iph->protocol = XFRM_MODE_SKB_CB(skb)->protocol; + +#ifndef CONFIG_NETFILTER + if (!async) + return -iph->protocol; +#endif + + __skb_push(skb, skb->data - skb_network_header(skb)); + iph->tot_len = htons(skb->len); + ip_send_check(iph); + + if (xo && (xo->flags & XFRM_GRO)) { + skb_mac_header_rebuild(skb); + skb_reset_transport_header(skb); + return 0; + } + + NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, + dev_net(skb->dev), NULL, skb, skb->dev, NULL, + xfrm4_rcv_encap_finish); + return 0; +} + +/* If it's a keepalive packet, then just eat it. + * If it's an encapsulated packet, then pass it to the + * IPsec xfrm input. + * Returns 0 if skb passed to xfrm or was dropped. + * Returns >0 if skb should be passed to UDP. + * Returns <0 if skb should be resubmitted (-ret is protocol) + */ +int xfrm4_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct udp_sock *up = udp_sk(sk); + struct udphdr *uh; + struct iphdr *iph; + int iphlen, len; + __u8 *udpdata; + __be32 *udpdata32; + u16 encap_type; + + encap_type = READ_ONCE(up->encap_type); + /* if this is not encapsulated socket, then just return now */ + if (!encap_type) + return 1; + + /* If this is a paged skb, make sure we pull up + * whatever data we need to look at. */ + len = skb->len - sizeof(struct udphdr); + if (!pskb_may_pull(skb, sizeof(struct udphdr) + min(len, 8))) + return 1; + + /* Now we can get the pointers */ + uh = udp_hdr(skb); + udpdata = (__u8 *)uh + sizeof(struct udphdr); + udpdata32 = (__be32 *)udpdata; + + switch (encap_type) { + default: + case UDP_ENCAP_ESPINUDP: + /* Check if this is a keepalive packet. If so, eat it. */ + if (len == 1 && udpdata[0] == 0xff) { + goto drop; + } else if (len > sizeof(struct ip_esp_hdr) && udpdata32[0] != 0) { + /* ESP Packet without Non-ESP header */ + len = sizeof(struct udphdr); + } else + /* Must be an IKE packet.. pass it through */ + return 1; + break; + case UDP_ENCAP_ESPINUDP_NON_IKE: + /* Check if this is a keepalive packet. If so, eat it. */ + if (len == 1 && udpdata[0] == 0xff) { + goto drop; + } else if (len > 2 * sizeof(u32) + sizeof(struct ip_esp_hdr) && + udpdata32[0] == 0 && udpdata32[1] == 0) { + + /* ESP Packet with Non-IKE marker */ + len = sizeof(struct udphdr) + 2 * sizeof(u32); + } else + /* Must be an IKE packet.. pass it through */ + return 1; + break; + } + + /* At this point we are sure that this is an ESPinUDP packet, + * so we need to remove 'len' bytes from the packet (the UDP + * header and optional ESP marker bytes) and then modify the + * protocol to ESP, and then call into the transform receiver. + */ + if (skb_unclone(skb, GFP_ATOMIC)) + goto drop; + + /* Now we can update and verify the packet length... 
*/ + iph = ip_hdr(skb); + iphlen = iph->ihl << 2; + iph->tot_len = htons(ntohs(iph->tot_len) - len); + if (skb->len < iphlen + len) { + /* packet is too small!?! */ + goto drop; + } + + /* pull the data buffer up to the ESP header and set the + * transport header to point to ESP. Keep UDP on the stack + * for later. + */ + __skb_pull(skb, len); + skb_reset_transport_header(skb); + + /* process ESP */ + return xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, encap_type); + +drop: + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(xfrm4_udp_encap_rcv); + +int xfrm4_rcv(struct sk_buff *skb) +{ + return xfrm4_rcv_spi(skb, ip_hdr(skb)->protocol, 0); +} +EXPORT_SYMBOL(xfrm4_rcv); diff --git a/net/ipv4/xfrm4_output.c b/net/ipv4/xfrm4_output.c new file mode 100644 index 000000000..3cff51ba7 --- /dev/null +++ b/net/ipv4/xfrm4_output.c @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm4_output.c - Common IPsec encapsulation code for IPv4. + * Copyright (c) 2004 Herbert Xu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int __xfrm4_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ +#ifdef CONFIG_NETFILTER + struct xfrm_state *x = skb_dst(skb)->xfrm; + + if (!x) { + IPCB(skb)->flags |= IPSKB_REROUTED; + return dst_output(net, sk, skb); + } +#endif + + return xfrm_output(sk, skb); +} + +int xfrm4_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, + net, sk, skb, skb->dev, skb_dst(skb)->dev, + __xfrm4_output, + !(IPCB(skb)->flags & IPSKB_REROUTED)); +} + +void xfrm4_local_error(struct sk_buff *skb, u32 mtu) +{ + struct iphdr *hdr; + + hdr = skb->encapsulation ? inner_ip_hdr(skb) : ip_hdr(skb); + ip_local_error(skb->sk, EMSGSIZE, hdr->daddr, + inet_sk(skb->sk)->inet_dport, mtu); +} diff --git a/net/ipv4/xfrm4_policy.c b/net/ipv4/xfrm4_policy.c new file mode 100644 index 000000000..3d0dfa6cf --- /dev/null +++ b/net/ipv4/xfrm4_policy.c @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm4_policy.c + * + * Changes: + * Kazunori MIYAZAWA @USAGI + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific portion + * + */ + +#include +#include +#include +#include +#include +#include +#include + +static struct dst_entry *__xfrm4_dst_lookup(struct net *net, struct flowi4 *fl4, + int tos, int oif, + const xfrm_address_t *saddr, + const xfrm_address_t *daddr, + u32 mark) +{ + struct rtable *rt; + + memset(fl4, 0, sizeof(*fl4)); + fl4->daddr = daddr->a4; + fl4->flowi4_tos = tos; + fl4->flowi4_l3mdev = l3mdev_master_ifindex_by_index(net, oif); + fl4->flowi4_mark = mark; + if (saddr) + fl4->saddr = saddr->a4; + + rt = __ip_route_output_key(net, fl4); + if (!IS_ERR(rt)) + return &rt->dst; + + return ERR_CAST(rt); +} + +static struct dst_entry *xfrm4_dst_lookup(struct net *net, int tos, int oif, + const xfrm_address_t *saddr, + const xfrm_address_t *daddr, + u32 mark) +{ + struct flowi4 fl4; + + return __xfrm4_dst_lookup(net, &fl4, tos, oif, saddr, daddr, mark); +} + +static int xfrm4_get_saddr(struct net *net, int oif, + xfrm_address_t *saddr, xfrm_address_t *daddr, + u32 mark) +{ + struct dst_entry *dst; + struct flowi4 fl4; + + dst = __xfrm4_dst_lookup(net, &fl4, 0, oif, NULL, daddr, mark); + if (IS_ERR(dst)) + return -EHOSTUNREACH; + + saddr->a4 = fl4.saddr; + dst_release(dst); + return 0; +} + +static int xfrm4_fill_dst(struct xfrm_dst *xdst, struct net_device *dev, + const struct flowi *fl) +{ + struct rtable *rt = (struct rtable *)xdst->route; + 
const struct flowi4 *fl4 = &fl->u.ip4; + + xdst->u.rt.rt_iif = fl4->flowi4_iif; + + xdst->u.dst.dev = dev; + netdev_hold(dev, &xdst->u.dst.dev_tracker, GFP_ATOMIC); + + /* Sheit... I remember I did this right. Apparently, + * it was magically lost, so this code needs audit */ + xdst->u.rt.rt_is_input = rt->rt_is_input; + xdst->u.rt.rt_flags = rt->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST | + RTCF_LOCAL); + xdst->u.rt.rt_type = rt->rt_type; + xdst->u.rt.rt_uses_gateway = rt->rt_uses_gateway; + xdst->u.rt.rt_gw_family = rt->rt_gw_family; + if (rt->rt_gw_family == AF_INET) + xdst->u.rt.rt_gw4 = rt->rt_gw4; + else if (rt->rt_gw_family == AF_INET6) + xdst->u.rt.rt_gw6 = rt->rt_gw6; + xdst->u.rt.rt_pmtu = rt->rt_pmtu; + xdst->u.rt.rt_mtu_locked = rt->rt_mtu_locked; + INIT_LIST_HEAD(&xdst->u.rt.rt_uncached); + rt_add_uncached_list(&xdst->u.rt); + + return 0; +} + +static void xfrm4_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + struct dst_entry *path = xdst->route; + + path->ops->update_pmtu(path, sk, skb, mtu, confirm_neigh); +} + +static void xfrm4_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + struct dst_entry *path = xdst->route; + + path->ops->redirect(path, sk, skb); +} + +static void xfrm4_dst_destroy(struct dst_entry *dst) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + + dst_destroy_metrics_generic(dst); + if (xdst->u.rt.rt_uncached_list) + rt_del_uncached_list(&xdst->u.rt); + xfrm_dst_destroy(xdst); +} + +static void xfrm4_dst_ifdown(struct dst_entry *dst, struct net_device *dev, + int unregister) +{ + if (!unregister) + return; + + xfrm_dst_ifdown(dst, dev); +} + +static struct dst_ops xfrm4_dst_ops_template = { + .family = AF_INET, + .update_pmtu = xfrm4_update_pmtu, + .redirect = xfrm4_redirect, + .cow_metrics = dst_cow_metrics_generic, + .destroy = xfrm4_dst_destroy, + .ifdown = xfrm4_dst_ifdown, + .local_out = __ip_local_out, + .gc_thresh = 32768, +}; + +static const struct xfrm_policy_afinfo xfrm4_policy_afinfo = { + .dst_ops = &xfrm4_dst_ops_template, + .dst_lookup = xfrm4_dst_lookup, + .get_saddr = xfrm4_get_saddr, + .fill_dst = xfrm4_fill_dst, + .blackhole_route = ipv4_blackhole_route, +}; + +#ifdef CONFIG_SYSCTL +static struct ctl_table xfrm4_policy_table[] = { + { + .procname = "xfrm4_gc_thresh", + .data = &init_net.xfrm.xfrm4_dst_ops.gc_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +static __net_init int xfrm4_net_sysctl_init(struct net *net) +{ + struct ctl_table *table; + struct ctl_table_header *hdr; + + table = xfrm4_policy_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(xfrm4_policy_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + table[0].data = &net->xfrm.xfrm4_dst_ops.gc_thresh; + } + + hdr = register_net_sysctl(net, "net/ipv4", table); + if (!hdr) + goto err_reg; + + net->ipv4.xfrm4_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static __net_exit void xfrm4_net_sysctl_exit(struct net *net) +{ + struct ctl_table *table; + + if (!net->ipv4.xfrm4_hdr) + return; + + table = net->ipv4.xfrm4_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv4.xfrm4_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} +#else /* CONFIG_SYSCTL */ +static inline int xfrm4_net_sysctl_init(struct net *net) +{ + return 0; +} + 
+static inline void xfrm4_net_sysctl_exit(struct net *net) +{ +} +#endif + +static int __net_init xfrm4_net_init(struct net *net) +{ + int ret; + + memcpy(&net->xfrm.xfrm4_dst_ops, &xfrm4_dst_ops_template, + sizeof(xfrm4_dst_ops_template)); + ret = dst_entries_init(&net->xfrm.xfrm4_dst_ops); + if (ret) + return ret; + + ret = xfrm4_net_sysctl_init(net); + if (ret) + dst_entries_destroy(&net->xfrm.xfrm4_dst_ops); + + return ret; +} + +static void __net_exit xfrm4_net_exit(struct net *net) +{ + xfrm4_net_sysctl_exit(net); + dst_entries_destroy(&net->xfrm.xfrm4_dst_ops); +} + +static struct pernet_operations __net_initdata xfrm4_net_ops = { + .init = xfrm4_net_init, + .exit = xfrm4_net_exit, +}; + +static void __init xfrm4_policy_init(void) +{ + xfrm_policy_register_afinfo(&xfrm4_policy_afinfo, AF_INET); +} + +void __init xfrm4_init(void) +{ + xfrm4_state_init(); + xfrm4_policy_init(); + xfrm4_protocol_init(); + register_pernet_subsys(&xfrm4_net_ops); +} diff --git a/net/ipv4/xfrm4_protocol.c b/net/ipv4/xfrm4_protocol.c new file mode 100644 index 000000000..b146ce88c --- /dev/null +++ b/net/ipv4/xfrm4_protocol.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* xfrm4_protocol.c - Generic xfrm protocol multiplexer. + * + * Copyright (C) 2013 secunet Security Networks AG + * + * Author: + * Steffen Klassert + * + * Based on: + * net/ipv4/tunnel4.c + */ + +#include +#include +#include +#include +#include +#include +#include + +static struct xfrm4_protocol __rcu *esp4_handlers __read_mostly; +static struct xfrm4_protocol __rcu *ah4_handlers __read_mostly; +static struct xfrm4_protocol __rcu *ipcomp4_handlers __read_mostly; +static DEFINE_MUTEX(xfrm4_protocol_mutex); + +static inline struct xfrm4_protocol __rcu **proto_handlers(u8 protocol) +{ + switch (protocol) { + case IPPROTO_ESP: + return &esp4_handlers; + case IPPROTO_AH: + return &ah4_handlers; + case IPPROTO_COMP: + return &ipcomp4_handlers; + } + + return NULL; +} + +#define for_each_protocol_rcu(head, handler) \ + for (handler = rcu_dereference(head); \ + handler != NULL; \ + handler = rcu_dereference(handler->next)) \ + +static int xfrm4_rcv_cb(struct sk_buff *skb, u8 protocol, int err) +{ + int ret; + struct xfrm4_protocol *handler; + struct xfrm4_protocol __rcu **head = proto_handlers(protocol); + + if (!head) + return 0; + + for_each_protocol_rcu(*head, handler) + if ((ret = handler->cb_handler(skb, err)) <= 0) + return ret; + + return 0; +} + +int xfrm4_rcv_encap(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type) +{ + int ret; + struct xfrm4_protocol *handler; + struct xfrm4_protocol __rcu **head = proto_handlers(nexthdr); + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + XFRM_SPI_SKB_CB(skb)->family = AF_INET; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + + if (!head) + goto out; + + if (!skb_dst(skb)) { + const struct iphdr *iph = ip_hdr(skb); + + if (ip_route_input_noref(skb, iph->daddr, iph->saddr, + iph->tos, skb->dev)) + goto drop; + } + + for_each_protocol_rcu(*head, handler) + if ((ret = handler->input_handler(skb, nexthdr, spi, encap_type)) != -EINVAL) + return ret; + +out: + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(xfrm4_rcv_encap); + +static int xfrm4_esp_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm4_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + + for_each_protocol_rcu(esp4_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + 
icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm4_esp_err(struct sk_buff *skb, u32 info) +{ + struct xfrm4_protocol *handler; + + for_each_protocol_rcu(esp4_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} + +static int xfrm4_ah_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm4_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + + for_each_protocol_rcu(ah4_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm4_ah_err(struct sk_buff *skb, u32 info) +{ + struct xfrm4_protocol *handler; + + for_each_protocol_rcu(ah4_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} + +static int xfrm4_ipcomp_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm4_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + + for_each_protocol_rcu(ipcomp4_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm4_ipcomp_err(struct sk_buff *skb, u32 info) +{ + struct xfrm4_protocol *handler; + + for_each_protocol_rcu(ipcomp4_handlers, handler) + if (!handler->err_handler(skb, info)) + return 0; + + return -ENOENT; +} + +static const struct net_protocol esp4_protocol = { + .handler = xfrm4_esp_rcv, + .err_handler = xfrm4_esp_err, + .no_policy = 1, +}; + +static const struct net_protocol ah4_protocol = { + .handler = xfrm4_ah_rcv, + .err_handler = xfrm4_ah_err, + .no_policy = 1, +}; + +static const struct net_protocol ipcomp4_protocol = { + .handler = xfrm4_ipcomp_rcv, + .err_handler = xfrm4_ipcomp_err, + .no_policy = 1, +}; + +static const struct xfrm_input_afinfo xfrm4_input_afinfo = { + .family = AF_INET, + .callback = xfrm4_rcv_cb, +}; + +static inline const struct net_protocol *netproto(unsigned char protocol) +{ + switch (protocol) { + case IPPROTO_ESP: + return &esp4_protocol; + case IPPROTO_AH: + return &ah4_protocol; + case IPPROTO_COMP: + return &ipcomp4_protocol; + } + + return NULL; +} + +int xfrm4_protocol_register(struct xfrm4_protocol *handler, + unsigned char protocol) +{ + struct xfrm4_protocol __rcu **pprev; + struct xfrm4_protocol *t; + bool add_netproto = false; + int ret = -EEXIST; + int priority = handler->priority; + + if (!proto_handlers(protocol) || !netproto(protocol)) + return -EINVAL; + + mutex_lock(&xfrm4_protocol_mutex); + + if (!rcu_dereference_protected(*proto_handlers(protocol), + lockdep_is_held(&xfrm4_protocol_mutex))) + add_netproto = true; + + for (pprev = proto_handlers(protocol); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&xfrm4_protocol_mutex))) != NULL; + pprev = &t->next) { + if (t->priority < priority) + break; + if (t->priority == priority) + goto err; + } + + handler->next = *pprev; + rcu_assign_pointer(*pprev, handler); + + ret = 0; + +err: + mutex_unlock(&xfrm4_protocol_mutex); + + if (add_netproto) { + if (inet_add_protocol(netproto(protocol), protocol)) { + pr_err("%s: can't add protocol\n", __func__); + ret = -EAGAIN; + } + } + + return ret; +} +EXPORT_SYMBOL(xfrm4_protocol_register); + +int xfrm4_protocol_deregister(struct xfrm4_protocol *handler, + unsigned char protocol) +{ + struct xfrm4_protocol __rcu **pprev; + struct xfrm4_protocol *t; + int ret = -ENOENT; + + if (!proto_handlers(protocol) || 
!netproto(protocol)) + return -EINVAL; + + mutex_lock(&xfrm4_protocol_mutex); + + for (pprev = proto_handlers(protocol); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&xfrm4_protocol_mutex))) != NULL; + pprev = &t->next) { + if (t == handler) { + *pprev = handler->next; + ret = 0; + break; + } + } + + if (!rcu_dereference_protected(*proto_handlers(protocol), + lockdep_is_held(&xfrm4_protocol_mutex))) { + if (inet_del_protocol(netproto(protocol), protocol) < 0) { + pr_err("%s: can't remove protocol\n", __func__); + ret = -EAGAIN; + } + } + + mutex_unlock(&xfrm4_protocol_mutex); + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(xfrm4_protocol_deregister); + +void __init xfrm4_protocol_init(void) +{ + xfrm_input_register_afinfo(&xfrm4_input_afinfo); +} diff --git a/net/ipv4/xfrm4_state.c b/net/ipv4/xfrm4_state.c new file mode 100644 index 000000000..87d4db591 --- /dev/null +++ b/net/ipv4/xfrm4_state.c @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm4_state.c + * + * Changes: + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific portion + * + */ + +#include + +static struct xfrm_state_afinfo xfrm4_state_afinfo = { + .family = AF_INET, + .proto = IPPROTO_IPIP, + .output = xfrm4_output, + .transport_finish = xfrm4_transport_finish, + .local_error = xfrm4_local_error, +}; + +void __init xfrm4_state_init(void) +{ + xfrm_state_register_afinfo(&xfrm4_state_afinfo); +} diff --git a/net/ipv4/xfrm4_tunnel.c b/net/ipv4/xfrm4_tunnel.c new file mode 100644 index 000000000..8489fa106 --- /dev/null +++ b/net/ipv4/xfrm4_tunnel.c @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* xfrm4_tunnel.c: Generic IP tunnel transformer. + * + * Copyright (C) 2003 David S. Miller (davem@redhat.com) + */ + +#define pr_fmt(fmt) "IPsec: " fmt + +#include +#include +#include +#include + +static int ipip_output(struct xfrm_state *x, struct sk_buff *skb) +{ + skb_push(skb, -skb_network_offset(skb)); + return 0; +} + +static int ipip_xfrm_rcv(struct xfrm_state *x, struct sk_buff *skb) +{ + return ip_hdr(skb)->protocol; +} + +static int ipip_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + if (x->props.mode != XFRM_MODE_TUNNEL) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel can only be used with tunnel mode"); + return -EINVAL; + } + + if (x->encap) { + NL_SET_ERR_MSG(extack, "IPv4 tunnel is not compatible with encapsulation"); + return -EINVAL; + } + + x->props.header_len = sizeof(struct iphdr); + + return 0; +} + +static void ipip_destroy(struct xfrm_state *x) +{ +} + +static const struct xfrm_type ipip_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_IPIP, + .init_state = ipip_init_state, + .destructor = ipip_destroy, + .input = ipip_xfrm_rcv, + .output = ipip_output +}; + +static int xfrm_tunnel_rcv(struct sk_buff *skb) +{ + return xfrm4_rcv_spi(skb, IPPROTO_IPIP, ip_hdr(skb)->saddr); +} + +static int xfrm_tunnel_err(struct sk_buff *skb, u32 info) +{ + return -ENOENT; +} + +static struct xfrm_tunnel xfrm_tunnel_handler __read_mostly = { + .handler = xfrm_tunnel_rcv, + .err_handler = xfrm_tunnel_err, + .priority = 4, +}; + +#if IS_ENABLED(CONFIG_IPV6) +static struct xfrm_tunnel xfrm64_tunnel_handler __read_mostly = { + .handler = xfrm_tunnel_rcv, + .err_handler = xfrm_tunnel_err, + .priority = 3, +}; +#endif + +static int __init ipip_init(void) +{ + if (xfrm_register_type(&ipip_type, AF_INET) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + + if (xfrm4_tunnel_register(&xfrm_tunnel_handler, AF_INET)) { + pr_info("%s: can't add 
xfrm handler for AF_INET\n", __func__); + xfrm_unregister_type(&ipip_type, AF_INET); + return -EAGAIN; + } +#if IS_ENABLED(CONFIG_IPV6) + if (xfrm4_tunnel_register(&xfrm64_tunnel_handler, AF_INET6)) { + pr_info("%s: can't add xfrm handler for AF_INET6\n", __func__); + xfrm4_tunnel_deregister(&xfrm_tunnel_handler, AF_INET); + xfrm_unregister_type(&ipip_type, AF_INET); + return -EAGAIN; + } +#endif + return 0; +} + +static void __exit ipip_fini(void) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (xfrm4_tunnel_deregister(&xfrm64_tunnel_handler, AF_INET6)) + pr_info("%s: can't remove xfrm handler for AF_INET6\n", + __func__); +#endif + if (xfrm4_tunnel_deregister(&xfrm_tunnel_handler, AF_INET)) + pr_info("%s: can't remove xfrm handler for AF_INET\n", + __func__); + xfrm_unregister_type(&ipip_type, AF_INET); +} + +module_init(ipip_init); +module_exit(ipip_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_IPIP); diff --git a/net/ipv6/Kconfig b/net/ipv6/Kconfig new file mode 100644 index 000000000..658bfed1d --- /dev/null +++ b/net/ipv6/Kconfig @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IPv6 configuration +# + +# IPv6 as module will cause a CRASH if you try to unload it +menuconfig IPV6 + tristate "The IPv6 protocol" + default y + select CRYPTO_LIB_SHA1 + help + Support for IP version 6 (IPv6). + + For general information about IPv6, see + . + For specific information about IPv6 under Linux, see + Documentation/networking/ipv6.rst and read the HOWTO at + + + To compile this protocol support as a module, choose M here: the + module will be called ipv6. + +if IPV6 + +config IPV6_ROUTER_PREF + bool "IPv6: Router Preference (RFC 4191) support" + help + Router Preference is an optional extension to the Router + Advertisement message which improves the ability of hosts + to pick an appropriate router, especially when the hosts + are placed in a multi-homed network. + + If unsure, say N. + +config IPV6_ROUTE_INFO + bool "IPv6: Route Information (RFC 4191) support" + depends on IPV6_ROUTER_PREF + help + Support of Route Information. + + If unsure, say N. + +config IPV6_OPTIMISTIC_DAD + bool "IPv6: Enable RFC 4429 Optimistic DAD" + help + Support for optimistic Duplicate Address Detection. It allows for + autoconfigured addresses to be used more quickly. + + If unsure, say N. + +config INET6_AH + tristate "IPv6: AH transformation" + select XFRM_AH + help + Support for IPsec AH (Authentication Header). + + AH can be used with various authentication algorithms. Besides + enabling AH support itself, this option enables the generic + implementations of the algorithms that RFC 8221 lists as MUST be + implemented. If you need any other algorithms, you'll need to enable + them in the crypto API. You should also enable accelerated + implementations of any needed algorithms when available. + + If unsure, say Y. + +config INET6_ESP + tristate "IPv6: ESP transformation" + select XFRM_ESP + help + Support for IPsec ESP (Encapsulating Security Payload). + + ESP can be used with various encryption and authentication algorithms. + Besides enabling ESP support itself, this option enables the generic + implementations of the algorithms that RFC 8221 lists as MUST be + implemented. If you need any other algorithms, you'll need to enable + them in the crypto API. You should also enable accelerated + implementations of any needed algorithms when available. + + If unsure, say Y. 
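[Editor's illustration, not part of the patch] The AH/ESP options above select transformation modules that attach to the generic multiplexers added earlier in this patch; on the IPv4 side that is the priority-ordered handler list in net/ipv4/xfrm4_protocol.c. The sketch below shows how such a module might register, using only the struct xfrm4_protocol members the multiplexer dereferences (handler, input_handler, cb_handler, err_handler, priority); all my_esp_* names are hypothetical.

/* Illustrative sketch only: register a handler with the xfrm4_protocol
 * multiplexer defined above.
 */
#include <linux/module.h>
#include <linux/skbuff.h>
#include <net/xfrm.h>

static int my_esp_rcv(struct sk_buff *skb)
{
	/* Returning -EINVAL lets the next handler on the priority-ordered
	 * list have a look (see for_each_protocol_rcu() above).
	 */
	return -EINVAL;
}

static int my_esp_input(struct sk_buff *skb, int nexthdr, __be32 spi,
			int encap_type)
{
	return -EINVAL;
}

static int my_esp_cb(struct sk_buff *skb, int err)
{
	return 0;	/* a value <= 0 stops the xfrm4_rcv_cb() walk */
}

static int my_esp_err(struct sk_buff *skb, u32 info)
{
	return -ENOENT;	/* non-zero: not handled, try the next handler */
}

static struct xfrm4_protocol my_esp_protocol __read_mostly = {
	.handler	= my_esp_rcv,
	.input_handler	= my_esp_input,
	.cb_handler	= my_esp_cb,
	.err_handler	= my_esp_err,
	.priority	= 0,
};

static int __init my_esp_init(void)
{
	/* xfrm4_protocol_register() returns -EEXIST if another handler
	 * already claims this priority for IPPROTO_ESP.
	 */
	return xfrm4_protocol_register(&my_esp_protocol, IPPROTO_ESP);
}

static void __exit my_esp_exit(void)
{
	xfrm4_protocol_deregister(&my_esp_protocol, IPPROTO_ESP);
}

module_init(my_esp_init);
module_exit(my_esp_exit);
MODULE_LICENSE("GPL");

The in-tree ESP and AH implementations hook in through this same interface; their exact callbacks and priorities differ from the sketch.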
+ +config INET6_ESP_OFFLOAD + tristate "IPv6: ESP transformation offload" + depends on INET6_ESP + select XFRM_OFFLOAD + default n + help + Support for ESP transformation offload. This makes sense + only if this system really does IPsec and want to do it + with high throughput. A typical desktop system does not + need it, even if it does IPsec. + + If unsure, say N. + +config INET6_ESPINTCP + bool "IPv6: ESP in TCP encapsulation (RFC 8229)" + depends on XFRM && INET6_ESP + select STREAM_PARSER + select NET_SOCK_MSG + select XFRM_ESPINTCP + help + Support for RFC 8229 encapsulation of ESP and IKE over + TCP/IPv6 sockets. + + If unsure, say N. + +config INET6_IPCOMP + tristate "IPv6: IPComp transformation" + select INET6_XFRM_TUNNEL + select XFRM_IPCOMP + help + Support for IP Payload Compression Protocol (IPComp) (RFC3173), + typically needed for IPsec. + + If unsure, say Y. + +config IPV6_MIP6 + tristate "IPv6: Mobility" + select XFRM + help + Support for IPv6 Mobility described in RFC 3775. + + If unsure, say N. + +config IPV6_ILA + tristate "IPv6: Identifier Locator Addressing (ILA)" + depends on NETFILTER + select DST_CACHE + select LWTUNNEL + help + Support for IPv6 Identifier Locator Addressing (ILA). + + ILA is a mechanism to do network virtualization without + encapsulation. The basic concept of ILA is that we split an + IPv6 address into a 64 bit locator and 64 bit identifier. The + identifier is the identity of an entity in communication + ("who") and the locator expresses the location of the + entity ("where"). + + ILA can be configured using the "encap ila" option with + "ip -6 route" command. ILA is described in + https://tools.ietf.org/html/draft-herbert-nvo3-ila-00. + + If unsure, say N. + +config INET6_XFRM_TUNNEL + tristate + select INET6_TUNNEL + default n + +config INET6_TUNNEL + tristate + default n + +config IPV6_VTI +tristate "Virtual (secure) IPv6: tunneling" + select IPV6_TUNNEL + select NET_IP_TUNNEL + select XFRM + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. This can be used with xfrm mode tunnel to give + the notion of a secure tunnel for IPSEC and then use routing protocol + on top. + +config IPV6_SIT + tristate "IPv6: IPv6-in-IPv4 tunnel (SIT driver)" + select INET_TUNNEL + select NET_IP_TUNNEL + select IPV6_NDISC_NODETYPE + default y + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. This driver implements encapsulation of IPv6 + into IPv4 packets. This is useful if you want to connect two IPv6 + networks over an IPv4-only path. + + Saying M here will produce a module called sit. If unsure, say Y. + +config IPV6_SIT_6RD + bool "IPv6: IPv6 Rapid Deployment (6RD)" + depends on IPV6_SIT + default n + help + IPv6 Rapid Deployment (6rd; draft-ietf-softwire-ipv6-6rd) builds upon + mechanisms of 6to4 (RFC3056) to enable a service provider to rapidly + deploy IPv6 unicast service to IPv4 sites to which it provides + customer premise equipment. Like 6to4, it utilizes stateless IPv6 in + IPv4 encapsulation in order to transit IPv4-only network + infrastructure. Unlike 6to4, a 6rd service provider uses an IPv6 + prefix of its own in place of the fixed 6to4 prefix. 
+ + With this option enabled, the SIT driver offers 6rd functionality by + providing additional ioctl API to configure the IPv6 Prefix for in + stead of static 2002::/16 for 6to4. + + If unsure, say N. + +config IPV6_NDISC_NODETYPE + bool + +config IPV6_TUNNEL + tristate "IPv6: IP-in-IPv6 tunnel (RFC2473)" + select INET6_TUNNEL + select DST_CACHE + select GRO_CELLS + help + Support for IPv6-in-IPv6 and IPv4-in-IPv6 tunnels described in + RFC 2473. + + If unsure, say N. + +config IPV6_GRE + tristate "IPv6: GRE tunnel" + select IPV6_TUNNEL + select NET_IP_TUNNEL + depends on NET_IPGRE_DEMUX + help + Tunneling means encapsulating data of one protocol type within + another protocol and sending it over a channel that understands the + encapsulating protocol. This particular tunneling driver implements + GRE (Generic Routing Encapsulation) and at this time allows + encapsulating of IPv4 or IPv6 over existing IPv6 infrastructure. + This driver is useful if the other endpoint is a Cisco router: Cisco + likes GRE much better than the other Linux tunneling driver ("IP + tunneling" above). In addition, GRE allows multicast redistribution + through the tunnel. + + Saying M here will produce a module called ip6_gre. If unsure, say N. + +config IPV6_FOU + tristate + default NET_FOU && IPV6 + +config IPV6_FOU_TUNNEL + tristate + default NET_FOU_IP_TUNNELS && IPV6_FOU + select IPV6_TUNNEL + +config IPV6_MULTIPLE_TABLES + bool "IPv6: Multiple Routing Tables" + select FIB_RULES + help + Support multiple routing tables. + +config IPV6_SUBTREES + bool "IPv6: source address based routing" + depends on IPV6_MULTIPLE_TABLES + help + Enable routing by source address or prefix. + + The destination address is still the primary routing key, so mixing + normal and source prefix specific routes in the same routing table + may sometimes lead to unintended routing behavior. This can be + avoided by defining different routing tables for the normal and + source prefix specific routes. + + If unsure, say N. + +config IPV6_MROUTE + bool "IPv6: multicast routing" + depends on IPV6 + select IP_MROUTE_COMMON + help + Support for IPv6 multicast forwarding. + If unsure, say N. + +config IPV6_MROUTE_MULTIPLE_TABLES + bool "IPv6: multicast policy routing" + depends on IPV6_MROUTE + select FIB_RULES + help + Normally, a multicast router runs a userspace daemon and decides + what to do with a multicast packet based on the source and + destination addresses. If you say Y here, the multicast router + will also be able to take interfaces and packet marks into + account and run multiple instances of userspace daemons + simultaneously, each one handling a single table. + + If unsure, say N. + +config IPV6_PIMSM_V2 + bool "IPv6: PIM-SM version 2 support" + depends on IPV6_MROUTE + help + Support for IPv6 PIM multicast routing protocol PIM-SMv2. + If unsure, say N. + +config IPV6_SEG6_LWTUNNEL + bool "IPv6: Segment Routing Header encapsulation support" + depends on IPV6 + select LWTUNNEL + select DST_CACHE + select IPV6_MULTIPLE_TABLES + help + Support for encapsulation of packets within an outer IPv6 + header and a Segment Routing Header using the lightweight + tunnels mechanism. Also enable support for advanced local + processing of SRv6 packets based on their active segment. + + If unsure, say N. 
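As a further illustration (again, not part of the patch itself), a .config excerpt that turns on the SRv6 encapsulation support described above; because IPV6_SEG6_LWTUNNEL is a bool and selects LWTUNNEL, DST_CACHE and IPV6_MULTIPLE_TABLES, those options need not be set by hand:

# illustrative .config fragment, not included in this patch
CONFIG_IPV6=y
CONFIG_IPV6_SEG6_LWTUNNEL=y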
+ +config IPV6_SEG6_HMAC + bool "IPv6: Segment Routing HMAC support" + depends on IPV6 + select CRYPTO + select CRYPTO_HMAC + select CRYPTO_SHA1 + select CRYPTO_SHA256 + help + Support for HMAC signature generation and verification + of SR-enabled packets. + + If unsure, say N. + +config IPV6_SEG6_BPF + def_bool y + depends on IPV6_SEG6_LWTUNNEL + depends on IPV6 = y + +config IPV6_RPL_LWTUNNEL + bool "IPv6: RPL Source Routing Header support" + depends on IPV6 + select LWTUNNEL + help + Support for RFC6554 RPL Source Routing Header using the lightweight + tunnels mechanism. + + If unsure, say N. + +config IPV6_IOAM6_LWTUNNEL + bool "IPv6: IOAM Pre-allocated Trace insertion support" + depends on IPV6 + select LWTUNNEL + select DST_CACHE + help + Support for the insertion of IOAM Pre-allocated Trace + Header using the lightweight tunnels mechanism. + + If unsure, say N. + +endif # IPV6 diff --git a/net/ipv6/Makefile b/net/ipv6/Makefile new file mode 100644 index 000000000..3036a45e8 --- /dev/null +++ b/net/ipv6/Makefile @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux TCP/IP (INET6) layer. +# + +obj-$(CONFIG_IPV6) += ipv6.o + +ipv6-y := af_inet6.o anycast.o ip6_output.o ip6_input.o addrconf.o \ + addrlabel.o \ + route.o ip6_fib.o ipv6_sockglue.o ndisc.o udp.o udplite.o \ + raw.o icmp.o mcast.o reassembly.o tcp_ipv6.o ping.o \ + exthdrs.o datagram.o ip6_flowlabel.o inet6_connection_sock.o \ + udp_offload.o seg6.o fib6_notifier.o rpl.o ioam6.o + +ipv6-$(CONFIG_SYSCTL) += sysctl_net_ipv6.o +ipv6-$(CONFIG_IPV6_MROUTE) += ip6mr.o + +ipv6-$(CONFIG_XFRM) += xfrm6_policy.o xfrm6_state.o xfrm6_input.o \ + xfrm6_output.o xfrm6_protocol.o +ipv6-$(CONFIG_NETFILTER) += netfilter.o +ipv6-$(CONFIG_IPV6_MULTIPLE_TABLES) += fib6_rules.o +ipv6-$(CONFIG_PROC_FS) += proc.o +ipv6-$(CONFIG_SYN_COOKIES) += syncookies.o +ipv6-$(CONFIG_NETLABEL) += calipso.o +ipv6-$(CONFIG_IPV6_SEG6_LWTUNNEL) += seg6_iptunnel.o seg6_local.o +ipv6-$(CONFIG_IPV6_SEG6_HMAC) += seg6_hmac.o +ipv6-$(CONFIG_IPV6_RPL_LWTUNNEL) += rpl_iptunnel.o +ipv6-$(CONFIG_IPV6_IOAM6_LWTUNNEL) += ioam6_iptunnel.o + +obj-$(CONFIG_INET6_AH) += ah6.o +obj-$(CONFIG_INET6_ESP) += esp6.o +obj-$(CONFIG_INET6_ESP_OFFLOAD) += esp6_offload.o +obj-$(CONFIG_INET6_IPCOMP) += ipcomp6.o +obj-$(CONFIG_INET6_XFRM_TUNNEL) += xfrm6_tunnel.o +obj-$(CONFIG_INET6_TUNNEL) += tunnel6.o +obj-$(CONFIG_IPV6_MIP6) += mip6.o +obj-$(CONFIG_IPV6_ILA) += ila/ +obj-$(CONFIG_NETFILTER) += netfilter/ + +obj-$(CONFIG_IPV6_VTI) += ip6_vti.o +obj-$(CONFIG_IPV6_SIT) += sit.o +obj-$(CONFIG_IPV6_TUNNEL) += ip6_tunnel.o +obj-$(CONFIG_IPV6_GRE) += ip6_gre.o +obj-$(CONFIG_IPV6_FOU) += fou6.o + +obj-y += addrconf_core.o exthdrs_core.o ip6_checksum.o ip6_icmp.o +obj-$(CONFIG_INET) += output_core.o protocol.o \ + ip6_offload.o tcpv6_offload.o exthdrs_offload.o + +obj-$(subst m,y,$(CONFIG_IPV6)) += inet6_hashtables.o + +ifneq ($(CONFIG_IPV6),) +obj-$(CONFIG_NET_UDP_TUNNEL) += ip6_udp_tunnel.o +obj-y += mcast_snoop.o +endif diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c new file mode 100644 index 000000000..b8dc20fe7 --- /dev/null +++ b/net/ipv6/addrconf.c @@ -0,0 +1,7402 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 Address [auto]configuration + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * Alexey Kuznetsov + */ + +/* + * Changes: + * + * Janos Farkas : delete timer on ifdown + * + * Andi Kleen : kill double kfree on module + * unload. + * Maciej W. Rozycki : FDDI support + * sekiya@USAGI : Don't send too many RS + * packets. 
+ * yoshfuji@USAGI : Fixed interval between DAD + * packets. + * YOSHIFUJI Hideaki @USAGI : improved accuracy of + * address validation timer. + * YOSHIFUJI Hideaki @USAGI : Privacy Extensions (RFC3041) + * support. + * Yuji SEKIYA @USAGI : Don't assign a same IPv6 + * address on a same interface. + * YOSHIFUJI Hideaki @USAGI : ARCnet support + * YOSHIFUJI Hideaki @USAGI : convert /proc/net/if_inet6 to + * seq_file. + * YOSHIFUJI Hideaki @USAGI : improved source address + * selection; consider scope, + * status etc. + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define INFINITY_LIFE_TIME 0xFFFFFFFF + +#define IPV6_MAX_STRLEN \ + sizeof("ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255") + +static inline u32 cstamp_delta(unsigned long cstamp) +{ + return (cstamp - INITIAL_JIFFIES) * 100UL / HZ; +} + +static inline s32 rfc3315_s14_backoff_init(s32 irt) +{ + /* multiply 'initial retransmission time' by 0.9 .. 1.1 */ + u64 tmp = (900000 + prandom_u32_max(200001)) * (u64)irt; + do_div(tmp, 1000000); + return (s32)tmp; +} + +static inline s32 rfc3315_s14_backoff_update(s32 rt, s32 mrt) +{ + /* multiply 'retransmission timeout' by 1.9 .. 2.1 */ + u64 tmp = (1900000 + prandom_u32_max(200001)) * (u64)rt; + do_div(tmp, 1000000); + if ((s32)tmp > mrt) { + /* multiply 'maximum retransmission time' by 0.9 .. 
1.1 */ + tmp = (900000 + prandom_u32_max(200001)) * (u64)mrt; + do_div(tmp, 1000000); + } + return (s32)tmp; +} + +#ifdef CONFIG_SYSCTL +static int addrconf_sysctl_register(struct inet6_dev *idev); +static void addrconf_sysctl_unregister(struct inet6_dev *idev); +#else +static inline int addrconf_sysctl_register(struct inet6_dev *idev) +{ + return 0; +} + +static inline void addrconf_sysctl_unregister(struct inet6_dev *idev) +{ +} +#endif + +static void ipv6_gen_rnd_iid(struct in6_addr *addr); + +static int ipv6_generate_eui64(u8 *eui, struct net_device *dev); +static int ipv6_count_addresses(const struct inet6_dev *idev); +static int ipv6_generate_stable_address(struct in6_addr *addr, + u8 dad_count, + const struct inet6_dev *idev); + +#define IN6_ADDR_HSIZE_SHIFT 8 +#define IN6_ADDR_HSIZE (1 << IN6_ADDR_HSIZE_SHIFT) + +static void addrconf_verify(struct net *net); +static void addrconf_verify_rtnl(struct net *net); + +static struct workqueue_struct *addrconf_wq; + +static void addrconf_join_anycast(struct inet6_ifaddr *ifp); +static void addrconf_leave_anycast(struct inet6_ifaddr *ifp); + +static void addrconf_type_change(struct net_device *dev, + unsigned long event); +static int addrconf_ifdown(struct net_device *dev, bool unregister); + +static struct fib6_info *addrconf_get_prefix_route(const struct in6_addr *pfx, + int plen, + const struct net_device *dev, + u32 flags, u32 noflags, + bool no_gw); + +static void addrconf_dad_start(struct inet6_ifaddr *ifp); +static void addrconf_dad_work(struct work_struct *w); +static void addrconf_dad_completed(struct inet6_ifaddr *ifp, bool bump_id, + bool send_na); +static void addrconf_dad_run(struct inet6_dev *idev, bool restart); +static void addrconf_rs_timer(struct timer_list *t); +static void __ipv6_ifa_notify(int event, struct inet6_ifaddr *ifa); +static void ipv6_ifa_notify(int event, struct inet6_ifaddr *ifa); + +static void inet6_prefix_notify(int event, struct inet6_dev *idev, + struct prefix_info *pinfo); + +static struct ipv6_devconf ipv6_devconf __read_mostly = { + .forwarding = 0, + .hop_limit = IPV6_DEFAULT_HOPLIMIT, + .mtu6 = IPV6_MIN_MTU, + .accept_ra = 1, + .accept_redirects = 1, + .autoconf = 1, + .force_mld_version = 0, + .mldv1_unsolicited_report_interval = 10 * HZ, + .mldv2_unsolicited_report_interval = HZ, + .dad_transmits = 1, + .rtr_solicits = MAX_RTR_SOLICITATIONS, + .rtr_solicit_interval = RTR_SOLICITATION_INTERVAL, + .rtr_solicit_max_interval = RTR_SOLICITATION_MAX_INTERVAL, + .rtr_solicit_delay = MAX_RTR_SOLICITATION_DELAY, + .use_tempaddr = 0, + .temp_valid_lft = TEMP_VALID_LIFETIME, + .temp_prefered_lft = TEMP_PREFERRED_LIFETIME, + .regen_max_retry = REGEN_MAX_RETRY, + .max_desync_factor = MAX_DESYNC_FACTOR, + .max_addresses = IPV6_MAX_ADDRESSES, + .accept_ra_defrtr = 1, + .ra_defrtr_metric = IP6_RT_PRIO_USER, + .accept_ra_from_local = 0, + .accept_ra_min_hop_limit= 1, + .accept_ra_min_lft = 0, + .accept_ra_pinfo = 1, +#ifdef CONFIG_IPV6_ROUTER_PREF + .accept_ra_rtr_pref = 1, + .rtr_probe_interval = 60 * HZ, +#ifdef CONFIG_IPV6_ROUTE_INFO + .accept_ra_rt_info_min_plen = 0, + .accept_ra_rt_info_max_plen = 0, +#endif +#endif + .proxy_ndp = 0, + .accept_source_route = 0, /* we do not accept RH0 by default. 
*/ + .disable_ipv6 = 0, + .accept_dad = 0, + .suppress_frag_ndisc = 1, + .accept_ra_mtu = 1, + .stable_secret = { + .initialized = false, + }, + .use_oif_addrs_only = 0, + .ignore_routes_with_linkdown = 0, + .keep_addr_on_down = 0, + .seg6_enabled = 0, +#ifdef CONFIG_IPV6_SEG6_HMAC + .seg6_require_hmac = 0, +#endif + .enhanced_dad = 1, + .addr_gen_mode = IN6_ADDR_GEN_MODE_EUI64, + .disable_policy = 0, + .rpl_seg_enabled = 0, + .ioam6_enabled = 0, + .ioam6_id = IOAM6_DEFAULT_IF_ID, + .ioam6_id_wide = IOAM6_DEFAULT_IF_ID_WIDE, + .ndisc_evict_nocarrier = 1, +}; + +static struct ipv6_devconf ipv6_devconf_dflt __read_mostly = { + .forwarding = 0, + .hop_limit = IPV6_DEFAULT_HOPLIMIT, + .mtu6 = IPV6_MIN_MTU, + .accept_ra = 1, + .accept_redirects = 1, + .autoconf = 1, + .force_mld_version = 0, + .mldv1_unsolicited_report_interval = 10 * HZ, + .mldv2_unsolicited_report_interval = HZ, + .dad_transmits = 1, + .rtr_solicits = MAX_RTR_SOLICITATIONS, + .rtr_solicit_interval = RTR_SOLICITATION_INTERVAL, + .rtr_solicit_max_interval = RTR_SOLICITATION_MAX_INTERVAL, + .rtr_solicit_delay = MAX_RTR_SOLICITATION_DELAY, + .use_tempaddr = 0, + .temp_valid_lft = TEMP_VALID_LIFETIME, + .temp_prefered_lft = TEMP_PREFERRED_LIFETIME, + .regen_max_retry = REGEN_MAX_RETRY, + .max_desync_factor = MAX_DESYNC_FACTOR, + .max_addresses = IPV6_MAX_ADDRESSES, + .accept_ra_defrtr = 1, + .ra_defrtr_metric = IP6_RT_PRIO_USER, + .accept_ra_from_local = 0, + .accept_ra_min_hop_limit= 1, + .accept_ra_min_lft = 0, + .accept_ra_pinfo = 1, +#ifdef CONFIG_IPV6_ROUTER_PREF + .accept_ra_rtr_pref = 1, + .rtr_probe_interval = 60 * HZ, +#ifdef CONFIG_IPV6_ROUTE_INFO + .accept_ra_rt_info_min_plen = 0, + .accept_ra_rt_info_max_plen = 0, +#endif +#endif + .proxy_ndp = 0, + .accept_source_route = 0, /* we do not accept RH0 by default. 
*/ + .disable_ipv6 = 0, + .accept_dad = 1, + .suppress_frag_ndisc = 1, + .accept_ra_mtu = 1, + .stable_secret = { + .initialized = false, + }, + .use_oif_addrs_only = 0, + .ignore_routes_with_linkdown = 0, + .keep_addr_on_down = 0, + .seg6_enabled = 0, +#ifdef CONFIG_IPV6_SEG6_HMAC + .seg6_require_hmac = 0, +#endif + .enhanced_dad = 1, + .addr_gen_mode = IN6_ADDR_GEN_MODE_EUI64, + .disable_policy = 0, + .rpl_seg_enabled = 0, + .ioam6_enabled = 0, + .ioam6_id = IOAM6_DEFAULT_IF_ID, + .ioam6_id_wide = IOAM6_DEFAULT_IF_ID_WIDE, + .ndisc_evict_nocarrier = 1, +}; + +/* Check if link is ready: is it up and is a valid qdisc available */ +static inline bool addrconf_link_ready(const struct net_device *dev) +{ + return netif_oper_up(dev) && !qdisc_tx_is_noop(dev); +} + +static void addrconf_del_rs_timer(struct inet6_dev *idev) +{ + if (del_timer(&idev->rs_timer)) + __in6_dev_put(idev); +} + +static void addrconf_del_dad_work(struct inet6_ifaddr *ifp) +{ + if (cancel_delayed_work(&ifp->dad_work)) + __in6_ifa_put(ifp); +} + +static void addrconf_mod_rs_timer(struct inet6_dev *idev, + unsigned long when) +{ + if (!mod_timer(&idev->rs_timer, jiffies + when)) + in6_dev_hold(idev); +} + +static void addrconf_mod_dad_work(struct inet6_ifaddr *ifp, + unsigned long delay) +{ + in6_ifa_hold(ifp); + if (mod_delayed_work(addrconf_wq, &ifp->dad_work, delay)) + in6_ifa_put(ifp); +} + +static int snmp6_alloc_dev(struct inet6_dev *idev) +{ + int i; + + idev->stats.ipv6 = alloc_percpu_gfp(struct ipstats_mib, GFP_KERNEL_ACCOUNT); + if (!idev->stats.ipv6) + goto err_ip; + + for_each_possible_cpu(i) { + struct ipstats_mib *addrconf_stats; + addrconf_stats = per_cpu_ptr(idev->stats.ipv6, i); + u64_stats_init(&addrconf_stats->syncp); + } + + + idev->stats.icmpv6dev = kzalloc(sizeof(struct icmpv6_mib_device), + GFP_KERNEL); + if (!idev->stats.icmpv6dev) + goto err_icmp; + idev->stats.icmpv6msgdev = kzalloc(sizeof(struct icmpv6msg_mib_device), + GFP_KERNEL_ACCOUNT); + if (!idev->stats.icmpv6msgdev) + goto err_icmpmsg; + + return 0; + +err_icmpmsg: + kfree(idev->stats.icmpv6dev); +err_icmp: + free_percpu(idev->stats.ipv6); +err_ip: + return -ENOMEM; +} + +static struct inet6_dev *ipv6_add_dev(struct net_device *dev) +{ + struct inet6_dev *ndev; + int err = -ENOMEM; + + ASSERT_RTNL(); + + if (dev->mtu < IPV6_MIN_MTU && dev != blackhole_netdev) + return ERR_PTR(-EINVAL); + + ndev = kzalloc(sizeof(*ndev), GFP_KERNEL_ACCOUNT); + if (!ndev) + return ERR_PTR(err); + + rwlock_init(&ndev->lock); + ndev->dev = dev; + INIT_LIST_HEAD(&ndev->addr_list); + timer_setup(&ndev->rs_timer, addrconf_rs_timer, 0); + memcpy(&ndev->cnf, dev_net(dev)->ipv6.devconf_dflt, sizeof(ndev->cnf)); + + if (ndev->cnf.stable_secret.initialized) + ndev->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_STABLE_PRIVACY; + + ndev->cnf.mtu6 = dev->mtu; + ndev->ra_mtu = 0; + ndev->nd_parms = neigh_parms_alloc(dev, &nd_tbl); + if (!ndev->nd_parms) { + kfree(ndev); + return ERR_PTR(err); + } + if (ndev->cnf.forwarding) + dev_disable_lro(dev); + /* We refer to the device */ + netdev_hold(dev, &ndev->dev_tracker, GFP_KERNEL); + + if (snmp6_alloc_dev(ndev) < 0) { + netdev_dbg(dev, "%s: cannot allocate memory for statistics\n", + __func__); + neigh_parms_release(&nd_tbl, ndev->nd_parms); + netdev_put(dev, &ndev->dev_tracker); + kfree(ndev); + return ERR_PTR(err); + } + + if (dev != blackhole_netdev) { + if (snmp6_register_dev(ndev) < 0) { + netdev_dbg(dev, "%s: cannot create /proc/net/dev_snmp6/%s\n", + __func__, dev->name); + goto err_release; + } + } + /* One reference from 
device. */ + refcount_set(&ndev->refcnt, 1); + + if (dev->flags & (IFF_NOARP | IFF_LOOPBACK)) + ndev->cnf.accept_dad = -1; + +#if IS_ENABLED(CONFIG_IPV6_SIT) + if (dev->type == ARPHRD_SIT && (dev->priv_flags & IFF_ISATAP)) { + pr_info("%s: Disabled Multicast RS\n", dev->name); + ndev->cnf.rtr_solicits = 0; + } +#endif + + INIT_LIST_HEAD(&ndev->tempaddr_list); + ndev->desync_factor = U32_MAX; + if ((dev->flags&IFF_LOOPBACK) || + dev->type == ARPHRD_TUNNEL || + dev->type == ARPHRD_TUNNEL6 || + dev->type == ARPHRD_SIT || + dev->type == ARPHRD_NONE) { + ndev->cnf.use_tempaddr = -1; + } + + ndev->token = in6addr_any; + + if (netif_running(dev) && addrconf_link_ready(dev)) + ndev->if_flags |= IF_READY; + + ipv6_mc_init_dev(ndev); + ndev->tstamp = jiffies; + if (dev != blackhole_netdev) { + err = addrconf_sysctl_register(ndev); + if (err) { + ipv6_mc_destroy_dev(ndev); + snmp6_unregister_dev(ndev); + goto err_release; + } + } + /* protected by rtnl_lock */ + rcu_assign_pointer(dev->ip6_ptr, ndev); + + if (dev != blackhole_netdev) { + /* Join interface-local all-node multicast group */ + ipv6_dev_mc_inc(dev, &in6addr_interfacelocal_allnodes); + + /* Join all-node multicast group */ + ipv6_dev_mc_inc(dev, &in6addr_linklocal_allnodes); + + /* Join all-router multicast group if forwarding is set */ + if (ndev->cnf.forwarding && (dev->flags & IFF_MULTICAST)) + ipv6_dev_mc_inc(dev, &in6addr_linklocal_allrouters); + } + return ndev; + +err_release: + neigh_parms_release(&nd_tbl, ndev->nd_parms); + ndev->dead = 1; + in6_dev_finish_destroy(ndev); + return ERR_PTR(err); +} + +static struct inet6_dev *ipv6_find_idev(struct net_device *dev) +{ + struct inet6_dev *idev; + + ASSERT_RTNL(); + + idev = __in6_dev_get(dev); + if (!idev) { + idev = ipv6_add_dev(dev); + if (IS_ERR(idev)) + return idev; + } + + if (dev->flags&IFF_UP) + ipv6_mc_up(idev); + return idev; +} + +static int inet6_netconf_msgsize_devconf(int type) +{ + int size = NLMSG_ALIGN(sizeof(struct netconfmsg)) + + nla_total_size(4); /* NETCONFA_IFINDEX */ + bool all = false; + + if (type == NETCONFA_ALL) + all = true; + + if (all || type == NETCONFA_FORWARDING) + size += nla_total_size(4); +#ifdef CONFIG_IPV6_MROUTE + if (all || type == NETCONFA_MC_FORWARDING) + size += nla_total_size(4); +#endif + if (all || type == NETCONFA_PROXY_NEIGH) + size += nla_total_size(4); + + if (all || type == NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN) + size += nla_total_size(4); + + return size; +} + +static int inet6_netconf_fill_devconf(struct sk_buff *skb, int ifindex, + struct ipv6_devconf *devconf, u32 portid, + u32 seq, int event, unsigned int flags, + int type) +{ + struct nlmsghdr *nlh; + struct netconfmsg *ncm; + bool all = false; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct netconfmsg), + flags); + if (!nlh) + return -EMSGSIZE; + + if (type == NETCONFA_ALL) + all = true; + + ncm = nlmsg_data(nlh); + ncm->ncm_family = AF_INET6; + + if (nla_put_s32(skb, NETCONFA_IFINDEX, ifindex) < 0) + goto nla_put_failure; + + if (!devconf) + goto out; + + if ((all || type == NETCONFA_FORWARDING) && + nla_put_s32(skb, NETCONFA_FORWARDING, devconf->forwarding) < 0) + goto nla_put_failure; +#ifdef CONFIG_IPV6_MROUTE + if ((all || type == NETCONFA_MC_FORWARDING) && + nla_put_s32(skb, NETCONFA_MC_FORWARDING, + atomic_read(&devconf->mc_forwarding)) < 0) + goto nla_put_failure; +#endif + if ((all || type == NETCONFA_PROXY_NEIGH) && + nla_put_s32(skb, NETCONFA_PROXY_NEIGH, devconf->proxy_ndp) < 0) + goto nla_put_failure; + + if ((all || type == 
NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN) && + nla_put_s32(skb, NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + devconf->ignore_routes_with_linkdown) < 0) + goto nla_put_failure; + +out: + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +void inet6_netconf_notify_devconf(struct net *net, int event, int type, + int ifindex, struct ipv6_devconf *devconf) +{ + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(inet6_netconf_msgsize_devconf(type), GFP_KERNEL); + if (!skb) + goto errout; + + err = inet6_netconf_fill_devconf(skb, ifindex, devconf, 0, 0, + event, 0, type); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet6_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_NETCONF, NULL, GFP_KERNEL); + return; +errout: + rtnl_set_sk_err(net, RTNLGRP_IPV6_NETCONF, err); +} + +static const struct nla_policy devconf_ipv6_policy[NETCONFA_MAX+1] = { + [NETCONFA_IFINDEX] = { .len = sizeof(int) }, + [NETCONFA_FORWARDING] = { .len = sizeof(int) }, + [NETCONFA_PROXY_NEIGH] = { .len = sizeof(int) }, + [NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN] = { .len = sizeof(int) }, +}; + +static int inet6_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_ipv6_policy, extack); + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_ipv6_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + +static int inet6_netconf_get_devconf(struct sk_buff *in_skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[NETCONFA_MAX+1]; + struct inet6_dev *in6_dev = NULL; + struct net_device *dev = NULL; + struct sk_buff *skb; + struct ipv6_devconf *devconf; + int ifindex; + int err; + + err = inet6_netconf_valid_get_req(in_skb, nlh, tb, extack); + if (err < 0) + return err; + + if (!tb[NETCONFA_IFINDEX]) + return -EINVAL; + + err = -EINVAL; + ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); + switch (ifindex) { + case NETCONFA_IFINDEX_ALL: + devconf = net->ipv6.devconf_all; + break; + case NETCONFA_IFINDEX_DEFAULT: + devconf = net->ipv6.devconf_dflt; + break; + default: + dev = dev_get_by_index(net, ifindex); + if (!dev) + return -EINVAL; + in6_dev = in6_dev_get(dev); + if (!in6_dev) + goto errout; + devconf = &in6_dev->cnf; + break; + } + + err = -ENOBUFS; + skb = nlmsg_new(inet6_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL); + if (!skb) + goto errout; + + err = inet6_netconf_fill_devconf(skb, ifindex, devconf, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, RTM_NEWNETCONF, 0, + NETCONFA_ALL); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet6_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +errout: + if (in6_dev) + in6_dev_put(in6_dev); + dev_put(dev); + 
return err; +} + +static int inet6_netconf_dump_devconf(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + int h, s_h; + int idx, s_idx; + struct net_device *dev; + struct inet6_dev *idev; + struct hlist_head *head; + + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + + s_h = cb->args[0]; + s_idx = idx = cb->args[1]; + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + rcu_read_lock(); + cb->seq = atomic_read(&net->ipv6.dev_addr_genid) ^ + net->dev_base_seq; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + idev = __in6_dev_get(dev); + if (!idev) + goto cont; + + if (inet6_netconf_fill_devconf(skb, dev->ifindex, + &idev->cnf, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, + NLM_F_MULTI, + NETCONFA_ALL) < 0) { + rcu_read_unlock(); + goto done; + } + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + rcu_read_unlock(); + } + if (h == NETDEV_HASHENTRIES) { + if (inet6_netconf_fill_devconf(skb, NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL) < 0) + goto done; + else + h++; + } + if (h == NETDEV_HASHENTRIES + 1) { + if (inet6_netconf_fill_devconf(skb, NETCONFA_IFINDEX_DEFAULT, + net->ipv6.devconf_dflt, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, NLM_F_MULTI, + NETCONFA_ALL) < 0) + goto done; + else + h++; + } +done: + cb->args[0] = h; + cb->args[1] = idx; + + return skb->len; +} + +#ifdef CONFIG_SYSCTL +static void dev_forward_change(struct inet6_dev *idev) +{ + struct net_device *dev; + struct inet6_ifaddr *ifa; + LIST_HEAD(tmp_addr_list); + + if (!idev) + return; + dev = idev->dev; + if (idev->cnf.forwarding) + dev_disable_lro(dev); + if (dev->flags & IFF_MULTICAST) { + if (idev->cnf.forwarding) { + ipv6_dev_mc_inc(dev, &in6addr_linklocal_allrouters); + ipv6_dev_mc_inc(dev, &in6addr_interfacelocal_allrouters); + ipv6_dev_mc_inc(dev, &in6addr_sitelocal_allrouters); + } else { + ipv6_dev_mc_dec(dev, &in6addr_linklocal_allrouters); + ipv6_dev_mc_dec(dev, &in6addr_interfacelocal_allrouters); + ipv6_dev_mc_dec(dev, &in6addr_sitelocal_allrouters); + } + } + + read_lock_bh(&idev->lock); + list_for_each_entry(ifa, &idev->addr_list, if_list) { + if (ifa->flags&IFA_F_TENTATIVE) + continue; + list_add_tail(&ifa->if_list_aux, &tmp_addr_list); + } + read_unlock_bh(&idev->lock); + + while (!list_empty(&tmp_addr_list)) { + ifa = list_first_entry(&tmp_addr_list, + struct inet6_ifaddr, if_list_aux); + list_del(&ifa->if_list_aux); + if (idev->cnf.forwarding) + addrconf_join_anycast(ifa); + else + addrconf_leave_anycast(ifa); + } + + inet6_netconf_notify_devconf(dev_net(dev), RTM_NEWNETCONF, + NETCONFA_FORWARDING, + dev->ifindex, &idev->cnf); +} + + +static void addrconf_forward_change(struct net *net, __s32 newf) +{ + struct net_device *dev; + struct inet6_dev *idev; + + for_each_netdev(net, dev) { + idev = __in6_dev_get(dev); + if (idev) { + int changed = (!idev->cnf.forwarding) ^ (!newf); + idev->cnf.forwarding = newf; + if (changed) + 
dev_forward_change(idev); + } + } +} + +static int addrconf_fixup_forwarding(struct ctl_table *table, int *p, int newf) +{ + struct net *net; + int old; + + if (!rtnl_trylock()) + return restart_syscall(); + + net = (struct net *)table->extra2; + old = *p; + *p = newf; + + if (p == &net->ipv6.devconf_dflt->forwarding) { + if ((!newf) ^ (!old)) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_DEFAULT, + net->ipv6.devconf_dflt); + rtnl_unlock(); + return 0; + } + + if (p == &net->ipv6.devconf_all->forwarding) { + int old_dflt = net->ipv6.devconf_dflt->forwarding; + + net->ipv6.devconf_dflt->forwarding = newf; + if ((!newf) ^ (!old_dflt)) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_DEFAULT, + net->ipv6.devconf_dflt); + + addrconf_forward_change(net, newf); + if ((!newf) ^ (!old)) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all); + } else if ((!newf) ^ (!old)) + dev_forward_change((struct inet6_dev *)table->extra1); + rtnl_unlock(); + + if (newf) + rt6_purge_dflt_routers(net); + return 1; +} + +static void addrconf_linkdown_change(struct net *net, __s32 newf) +{ + struct net_device *dev; + struct inet6_dev *idev; + + for_each_netdev(net, dev) { + idev = __in6_dev_get(dev); + if (idev) { + int changed = (!idev->cnf.ignore_routes_with_linkdown) ^ (!newf); + + idev->cnf.ignore_routes_with_linkdown = newf; + if (changed) + inet6_netconf_notify_devconf(dev_net(dev), + RTM_NEWNETCONF, + NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + dev->ifindex, + &idev->cnf); + } + } +} + +static int addrconf_fixup_linkdown(struct ctl_table *table, int *p, int newf) +{ + struct net *net; + int old; + + if (!rtnl_trylock()) + return restart_syscall(); + + net = (struct net *)table->extra2; + old = *p; + *p = newf; + + if (p == &net->ipv6.devconf_dflt->ignore_routes_with_linkdown) { + if ((!newf) ^ (!old)) + inet6_netconf_notify_devconf(net, + RTM_NEWNETCONF, + NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + NETCONFA_IFINDEX_DEFAULT, + net->ipv6.devconf_dflt); + rtnl_unlock(); + return 0; + } + + if (p == &net->ipv6.devconf_all->ignore_routes_with_linkdown) { + net->ipv6.devconf_dflt->ignore_routes_with_linkdown = newf; + addrconf_linkdown_change(net, newf); + if ((!newf) ^ (!old)) + inet6_netconf_notify_devconf(net, + RTM_NEWNETCONF, + NETCONFA_IGNORE_ROUTES_WITH_LINKDOWN, + NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all); + } + rtnl_unlock(); + + return 1; +} + +#endif + +/* Nobody refers to this ifaddr, destroy it */ +void inet6_ifa_finish_destroy(struct inet6_ifaddr *ifp) +{ + WARN_ON(!hlist_unhashed(&ifp->addr_lst)); + +#ifdef NET_REFCNT_DEBUG + pr_debug("%s\n", __func__); +#endif + + in6_dev_put(ifp->idev); + + if (cancel_delayed_work(&ifp->dad_work)) + pr_notice("delayed DAD work was pending while freeing ifa=%p\n", + ifp); + + if (ifp->state != INET6_IFADDR_STATE_DEAD) { + pr_warn("Freeing alive inet6 address %p\n", ifp); + return; + } + + kfree_rcu(ifp, rcu); +} + +static void +ipv6_link_dev_addr(struct inet6_dev *idev, struct inet6_ifaddr *ifp) +{ + struct list_head *p; + int ifp_scope = ipv6_addr_src_scope(&ifp->addr); + + /* + * Each device address list is sorted in order of scope - + * global before linklocal. 
+ */ + list_for_each(p, &idev->addr_list) { + struct inet6_ifaddr *ifa + = list_entry(p, struct inet6_ifaddr, if_list); + if (ifp_scope >= ipv6_addr_src_scope(&ifa->addr)) + break; + } + + list_add_tail_rcu(&ifp->if_list, p); +} + +static u32 inet6_addr_hash(const struct net *net, const struct in6_addr *addr) +{ + u32 val = ipv6_addr_hash(addr) ^ net_hash_mix(net); + + return hash_32(val, IN6_ADDR_HSIZE_SHIFT); +} + +static bool ipv6_chk_same_addr(struct net *net, const struct in6_addr *addr, + struct net_device *dev, unsigned int hash) +{ + struct inet6_ifaddr *ifp; + + hlist_for_each_entry(ifp, &net->ipv6.inet6_addr_lst[hash], addr_lst) { + if (ipv6_addr_equal(&ifp->addr, addr)) { + if (!dev || ifp->idev->dev == dev) + return true; + } + } + return false; +} + +static int ipv6_add_addr_hash(struct net_device *dev, struct inet6_ifaddr *ifa) +{ + struct net *net = dev_net(dev); + unsigned int hash = inet6_addr_hash(net, &ifa->addr); + int err = 0; + + spin_lock_bh(&net->ipv6.addrconf_hash_lock); + + /* Ignore adding duplicate addresses on an interface */ + if (ipv6_chk_same_addr(net, &ifa->addr, dev, hash)) { + netdev_dbg(dev, "ipv6_add_addr: already assigned\n"); + err = -EEXIST; + } else { + hlist_add_head_rcu(&ifa->addr_lst, &net->ipv6.inet6_addr_lst[hash]); + } + + spin_unlock_bh(&net->ipv6.addrconf_hash_lock); + + return err; +} + +/* On success it returns ifp with increased reference count */ + +static struct inet6_ifaddr * +ipv6_add_addr(struct inet6_dev *idev, struct ifa6_config *cfg, + bool can_block, struct netlink_ext_ack *extack) +{ + gfp_t gfp_flags = can_block ? GFP_KERNEL : GFP_ATOMIC; + int addr_type = ipv6_addr_type(cfg->pfx); + struct net *net = dev_net(idev->dev); + struct inet6_ifaddr *ifa = NULL; + struct fib6_info *f6i = NULL; + int err = 0; + + if (addr_type == IPV6_ADDR_ANY || + (addr_type & IPV6_ADDR_MULTICAST && + !(cfg->ifa_flags & IFA_F_MCAUTOJOIN)) || + (!(idev->dev->flags & IFF_LOOPBACK) && + !netif_is_l3_master(idev->dev) && + addr_type & IPV6_ADDR_LOOPBACK)) + return ERR_PTR(-EADDRNOTAVAIL); + + if (idev->dead) { + err = -ENODEV; /*XXX*/ + goto out; + } + + if (idev->cnf.disable_ipv6) { + err = -EACCES; + goto out; + } + + /* validator notifier needs to be blocking; + * do not call in atomic context + */ + if (can_block) { + struct in6_validator_info i6vi = { + .i6vi_addr = *cfg->pfx, + .i6vi_dev = idev, + .extack = extack, + }; + + err = inet6addr_validator_notifier_call_chain(NETDEV_UP, &i6vi); + err = notifier_to_errno(err); + if (err < 0) + goto out; + } + + ifa = kzalloc(sizeof(*ifa), gfp_flags | __GFP_ACCOUNT); + if (!ifa) { + err = -ENOBUFS; + goto out; + } + + f6i = addrconf_f6i_alloc(net, idev, cfg->pfx, false, gfp_flags); + if (IS_ERR(f6i)) { + err = PTR_ERR(f6i); + f6i = NULL; + goto out; + } + + neigh_parms_data_state_setall(idev->nd_parms); + + ifa->addr = *cfg->pfx; + if (cfg->peer_pfx) + ifa->peer_addr = *cfg->peer_pfx; + + spin_lock_init(&ifa->lock); + INIT_DELAYED_WORK(&ifa->dad_work, addrconf_dad_work); + INIT_HLIST_NODE(&ifa->addr_lst); + ifa->scope = cfg->scope; + ifa->prefix_len = cfg->plen; + ifa->rt_priority = cfg->rt_priority; + ifa->flags = cfg->ifa_flags; + ifa->ifa_proto = cfg->ifa_proto; + /* No need to add the TENTATIVE flag for addresses with NODAD */ + if (!(cfg->ifa_flags & IFA_F_NODAD)) + ifa->flags |= IFA_F_TENTATIVE; + ifa->valid_lft = cfg->valid_lft; + ifa->prefered_lft = cfg->preferred_lft; + ifa->cstamp = ifa->tstamp = jiffies; + ifa->tokenized = false; + + ifa->rt = f6i; + + ifa->idev = idev; + in6_dev_hold(idev); + + 
/* For caller */ + refcount_set(&ifa->refcnt, 1); + + rcu_read_lock(); + + err = ipv6_add_addr_hash(idev->dev, ifa); + if (err < 0) { + rcu_read_unlock(); + goto out; + } + + write_lock_bh(&idev->lock); + + /* Add to inet6_dev unicast addr list. */ + ipv6_link_dev_addr(idev, ifa); + + if (ifa->flags&IFA_F_TEMPORARY) { + list_add(&ifa->tmp_list, &idev->tempaddr_list); + in6_ifa_hold(ifa); + } + + in6_ifa_hold(ifa); + write_unlock_bh(&idev->lock); + + rcu_read_unlock(); + + inet6addr_notifier_call_chain(NETDEV_UP, ifa); +out: + if (unlikely(err < 0)) { + fib6_info_release(f6i); + + if (ifa) { + if (ifa->idev) + in6_dev_put(ifa->idev); + kfree(ifa); + } + ifa = ERR_PTR(err); + } + + return ifa; +} + +enum cleanup_prefix_rt_t { + CLEANUP_PREFIX_RT_NOP, /* no cleanup action for prefix route */ + CLEANUP_PREFIX_RT_DEL, /* delete the prefix route */ + CLEANUP_PREFIX_RT_EXPIRE, /* update the lifetime of the prefix route */ +}; + +/* + * Check, whether the prefix for ifp would still need a prefix route + * after deleting ifp. The function returns one of the CLEANUP_PREFIX_RT_* + * constants. + * + * 1) we don't purge prefix if address was not permanent. + * prefix is managed by its own lifetime. + * 2) we also don't purge, if the address was IFA_F_NOPREFIXROUTE. + * 3) if there are no addresses, delete prefix. + * 4) if there are still other permanent address(es), + * corresponding prefix is still permanent. + * 5) if there are still other addresses with IFA_F_NOPREFIXROUTE, + * don't purge the prefix, assume user space is managing it. + * 6) otherwise, update prefix lifetime to the + * longest valid lifetime among the corresponding + * addresses on the device. + * Note: subsequent RA will update lifetime. + **/ +static enum cleanup_prefix_rt_t +check_cleanup_prefix_route(struct inet6_ifaddr *ifp, unsigned long *expires) +{ + struct inet6_ifaddr *ifa; + struct inet6_dev *idev = ifp->idev; + unsigned long lifetime; + enum cleanup_prefix_rt_t action = CLEANUP_PREFIX_RT_DEL; + + *expires = jiffies; + + list_for_each_entry(ifa, &idev->addr_list, if_list) { + if (ifa == ifp) + continue; + if (ifa->prefix_len != ifp->prefix_len || + !ipv6_prefix_equal(&ifa->addr, &ifp->addr, + ifp->prefix_len)) + continue; + if (ifa->flags & (IFA_F_PERMANENT | IFA_F_NOPREFIXROUTE)) + return CLEANUP_PREFIX_RT_NOP; + + action = CLEANUP_PREFIX_RT_EXPIRE; + + spin_lock(&ifa->lock); + + lifetime = addrconf_timeout_fixup(ifa->valid_lft, HZ); + /* + * Note: Because this address is + * not permanent, lifetime < + * LONG_MAX / HZ here. + */ + if (time_before(*expires, ifa->tstamp + lifetime * HZ)) + *expires = ifa->tstamp + lifetime * HZ; + spin_unlock(&ifa->lock); + } + + return action; +} + +static void +cleanup_prefix_route(struct inet6_ifaddr *ifp, unsigned long expires, + bool del_rt, bool del_peer) +{ + struct fib6_info *f6i; + + f6i = addrconf_get_prefix_route(del_peer ? 
&ifp->peer_addr : &ifp->addr, + ifp->prefix_len, + ifp->idev->dev, 0, RTF_DEFAULT, true); + if (f6i) { + if (del_rt) + ip6_del_rt(dev_net(ifp->idev->dev), f6i, false); + else { + if (!(f6i->fib6_flags & RTF_EXPIRES)) + fib6_set_expires(f6i, expires); + fib6_info_release(f6i); + } + } +} + + +/* This function wants to get referenced ifp and releases it before return */ + +static void ipv6_del_addr(struct inet6_ifaddr *ifp) +{ + enum cleanup_prefix_rt_t action = CLEANUP_PREFIX_RT_NOP; + struct net *net = dev_net(ifp->idev->dev); + unsigned long expires; + int state; + + ASSERT_RTNL(); + + spin_lock_bh(&ifp->lock); + state = ifp->state; + ifp->state = INET6_IFADDR_STATE_DEAD; + spin_unlock_bh(&ifp->lock); + + if (state == INET6_IFADDR_STATE_DEAD) + goto out; + + spin_lock_bh(&net->ipv6.addrconf_hash_lock); + hlist_del_init_rcu(&ifp->addr_lst); + spin_unlock_bh(&net->ipv6.addrconf_hash_lock); + + write_lock_bh(&ifp->idev->lock); + + if (ifp->flags&IFA_F_TEMPORARY) { + list_del(&ifp->tmp_list); + if (ifp->ifpub) { + in6_ifa_put(ifp->ifpub); + ifp->ifpub = NULL; + } + __in6_ifa_put(ifp); + } + + if (ifp->flags & IFA_F_PERMANENT && !(ifp->flags & IFA_F_NOPREFIXROUTE)) + action = check_cleanup_prefix_route(ifp, &expires); + + list_del_rcu(&ifp->if_list); + __in6_ifa_put(ifp); + + write_unlock_bh(&ifp->idev->lock); + + addrconf_del_dad_work(ifp); + + ipv6_ifa_notify(RTM_DELADDR, ifp); + + inet6addr_notifier_call_chain(NETDEV_DOWN, ifp); + + if (action != CLEANUP_PREFIX_RT_NOP) { + cleanup_prefix_route(ifp, expires, + action == CLEANUP_PREFIX_RT_DEL, false); + } + + /* clean up prefsrc entries */ + rt6_remove_prefsrc(ifp); +out: + in6_ifa_put(ifp); +} + +static int ipv6_create_tempaddr(struct inet6_ifaddr *ifp, bool block) +{ + struct inet6_dev *idev = ifp->idev; + unsigned long tmp_tstamp, age; + unsigned long regen_advance; + unsigned long now = jiffies; + s32 cnf_temp_preferred_lft; + struct inet6_ifaddr *ift; + struct ifa6_config cfg; + long max_desync_factor; + struct in6_addr addr; + int ret = 0; + + write_lock_bh(&idev->lock); + +retry: + in6_dev_hold(idev); + if (idev->cnf.use_tempaddr <= 0) { + write_unlock_bh(&idev->lock); + pr_info("%s: use_tempaddr is disabled\n", __func__); + in6_dev_put(idev); + ret = -1; + goto out; + } + spin_lock_bh(&ifp->lock); + if (ifp->regen_count++ >= idev->cnf.regen_max_retry) { + idev->cnf.use_tempaddr = -1; /*XXX*/ + spin_unlock_bh(&ifp->lock); + write_unlock_bh(&idev->lock); + pr_warn("%s: regeneration time exceeded - disabled temporary address support\n", + __func__); + in6_dev_put(idev); + ret = -1; + goto out; + } + in6_ifa_hold(ifp); + memcpy(addr.s6_addr, ifp->addr.s6_addr, 8); + ipv6_gen_rnd_iid(&addr); + + age = (now - ifp->tstamp) / HZ; + + regen_advance = idev->cnf.regen_max_retry * + idev->cnf.dad_transmits * + max(NEIGH_VAR(idev->nd_parms, RETRANS_TIME), HZ/100) / HZ; + + /* recalculate max_desync_factor each time and update + * idev->desync_factor if it's larger + */ + cnf_temp_preferred_lft = READ_ONCE(idev->cnf.temp_prefered_lft); + max_desync_factor = min_t(long, + idev->cnf.max_desync_factor, + cnf_temp_preferred_lft - regen_advance); + + if (unlikely(idev->desync_factor > max_desync_factor)) { + if (max_desync_factor > 0) { + get_random_bytes(&idev->desync_factor, + sizeof(idev->desync_factor)); + idev->desync_factor %= max_desync_factor; + } else { + idev->desync_factor = 0; + } + } + + memset(&cfg, 0, sizeof(cfg)); + cfg.valid_lft = min_t(__u32, ifp->valid_lft, + idev->cnf.temp_valid_lft + age); + cfg.preferred_lft = 
cnf_temp_preferred_lft + age - idev->desync_factor; + cfg.preferred_lft = min_t(__u32, ifp->prefered_lft, cfg.preferred_lft); + + cfg.plen = ifp->prefix_len; + tmp_tstamp = ifp->tstamp; + spin_unlock_bh(&ifp->lock); + + write_unlock_bh(&idev->lock); + + /* A temporary address is created only if this calculated Preferred + * Lifetime is greater than REGEN_ADVANCE time units. In particular, + * an implementation must not create a temporary address with a zero + * Preferred Lifetime. + * Use age calculation as in addrconf_verify to avoid unnecessary + * temporary addresses being generated. + */ + age = (now - tmp_tstamp + ADDRCONF_TIMER_FUZZ_MINUS) / HZ; + if (cfg.preferred_lft <= regen_advance + age) { + in6_ifa_put(ifp); + in6_dev_put(idev); + ret = -1; + goto out; + } + + cfg.ifa_flags = IFA_F_TEMPORARY; + /* set in addrconf_prefix_rcv() */ + if (ifp->flags & IFA_F_OPTIMISTIC) + cfg.ifa_flags |= IFA_F_OPTIMISTIC; + + cfg.pfx = &addr; + cfg.scope = ipv6_addr_scope(cfg.pfx); + + ift = ipv6_add_addr(idev, &cfg, block, NULL); + if (IS_ERR(ift)) { + in6_ifa_put(ifp); + in6_dev_put(idev); + pr_info("%s: retry temporary address regeneration\n", __func__); + write_lock_bh(&idev->lock); + goto retry; + } + + spin_lock_bh(&ift->lock); + ift->ifpub = ifp; + ift->cstamp = now; + ift->tstamp = tmp_tstamp; + spin_unlock_bh(&ift->lock); + + addrconf_dad_start(ift); + in6_ifa_put(ift); + in6_dev_put(idev); +out: + return ret; +} + +/* + * Choose an appropriate source address (RFC3484) + */ +enum { + IPV6_SADDR_RULE_INIT = 0, + IPV6_SADDR_RULE_LOCAL, + IPV6_SADDR_RULE_SCOPE, + IPV6_SADDR_RULE_PREFERRED, +#ifdef CONFIG_IPV6_MIP6 + IPV6_SADDR_RULE_HOA, +#endif + IPV6_SADDR_RULE_OIF, + IPV6_SADDR_RULE_LABEL, + IPV6_SADDR_RULE_PRIVACY, + IPV6_SADDR_RULE_ORCHID, + IPV6_SADDR_RULE_PREFIX, +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + IPV6_SADDR_RULE_NOT_OPTIMISTIC, +#endif + IPV6_SADDR_RULE_MAX +}; + +struct ipv6_saddr_score { + int rule; + int addr_type; + struct inet6_ifaddr *ifa; + DECLARE_BITMAP(scorebits, IPV6_SADDR_RULE_MAX); + int scopedist; + int matchlen; +}; + +struct ipv6_saddr_dst { + const struct in6_addr *addr; + int ifindex; + int scope; + int label; + unsigned int prefs; +}; + +static inline int ipv6_saddr_preferred(int type) +{ + if (type & (IPV6_ADDR_MAPPED|IPV6_ADDR_COMPATv4|IPV6_ADDR_LOOPBACK)) + return 1; + return 0; +} + +static bool ipv6_use_optimistic_addr(struct net *net, + struct inet6_dev *idev) +{ +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + if (!idev) + return false; + if (!net->ipv6.devconf_all->optimistic_dad && !idev->cnf.optimistic_dad) + return false; + if (!net->ipv6.devconf_all->use_optimistic && !idev->cnf.use_optimistic) + return false; + + return true; +#else + return false; +#endif +} + +static bool ipv6_allow_optimistic_dad(struct net *net, + struct inet6_dev *idev) +{ +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + if (!idev) + return false; + if (!net->ipv6.devconf_all->optimistic_dad && !idev->cnf.optimistic_dad) + return false; + + return true; +#else + return false; +#endif +} + +static int ipv6_get_saddr_eval(struct net *net, + struct ipv6_saddr_score *score, + struct ipv6_saddr_dst *dst, + int i) +{ + int ret; + + if (i <= score->rule) { + switch (i) { + case IPV6_SADDR_RULE_SCOPE: + ret = score->scopedist; + break; + case IPV6_SADDR_RULE_PREFIX: + ret = score->matchlen; + break; + default: + ret = !!test_bit(i, score->scorebits); + } + goto out; + } + + switch (i) { + case IPV6_SADDR_RULE_INIT: + /* Rule 0: remember if hiscore is not ready yet */ + ret = !!score->ifa; + break; + case 
IPV6_SADDR_RULE_LOCAL: + /* Rule 1: Prefer same address */ + ret = ipv6_addr_equal(&score->ifa->addr, dst->addr); + break; + case IPV6_SADDR_RULE_SCOPE: + /* Rule 2: Prefer appropriate scope + * + * ret + * ^ + * -1 | d 15 + * ---+--+-+---> scope + * | + * | d is scope of the destination. + * B-d | \ + * | \ <- smaller scope is better if + * B-15 | \ if scope is enough for destination. + * | ret = B - scope (-1 <= scope >= d <= 15). + * d-C-1 | / + * |/ <- greater is better + * -C / if scope is not enough for destination. + * /| ret = scope - C (-1 <= d < scope <= 15). + * + * d - C - 1 < B -15 (for all -1 <= d <= 15). + * C > d + 14 - B >= 15 + 14 - B = 29 - B. + * Assume B = 0 and we get C > 29. + */ + ret = __ipv6_addr_src_scope(score->addr_type); + if (ret >= dst->scope) + ret = -ret; + else + ret -= 128; /* 30 is enough */ + score->scopedist = ret; + break; + case IPV6_SADDR_RULE_PREFERRED: + { + /* Rule 3: Avoid deprecated and optimistic addresses */ + u8 avoid = IFA_F_DEPRECATED; + + if (!ipv6_use_optimistic_addr(net, score->ifa->idev)) + avoid |= IFA_F_OPTIMISTIC; + ret = ipv6_saddr_preferred(score->addr_type) || + !(score->ifa->flags & avoid); + break; + } +#ifdef CONFIG_IPV6_MIP6 + case IPV6_SADDR_RULE_HOA: + { + /* Rule 4: Prefer home address */ + int prefhome = !(dst->prefs & IPV6_PREFER_SRC_COA); + ret = !(score->ifa->flags & IFA_F_HOMEADDRESS) ^ prefhome; + break; + } +#endif + case IPV6_SADDR_RULE_OIF: + /* Rule 5: Prefer outgoing interface */ + ret = (!dst->ifindex || + dst->ifindex == score->ifa->idev->dev->ifindex); + break; + case IPV6_SADDR_RULE_LABEL: + /* Rule 6: Prefer matching label */ + ret = ipv6_addr_label(net, + &score->ifa->addr, score->addr_type, + score->ifa->idev->dev->ifindex) == dst->label; + break; + case IPV6_SADDR_RULE_PRIVACY: + { + /* Rule 7: Prefer public address + * Note: prefer temporary address if use_tempaddr >= 2 + */ + int preftmp = dst->prefs & (IPV6_PREFER_SRC_PUBLIC|IPV6_PREFER_SRC_TMP) ? + !!(dst->prefs & IPV6_PREFER_SRC_TMP) : + score->ifa->idev->cnf.use_tempaddr >= 2; + ret = (!(score->ifa->flags & IFA_F_TEMPORARY)) ^ preftmp; + break; + } + case IPV6_SADDR_RULE_ORCHID: + /* Rule 8-: Prefer ORCHID vs ORCHID or + * non-ORCHID vs non-ORCHID + */ + ret = !(ipv6_addr_orchid(&score->ifa->addr) ^ + ipv6_addr_orchid(dst->addr)); + break; + case IPV6_SADDR_RULE_PREFIX: + /* Rule 8: Use longest matching prefix */ + ret = ipv6_addr_diff(&score->ifa->addr, dst->addr); + if (ret > score->ifa->prefix_len) + ret = score->ifa->prefix_len; + score->matchlen = ret; + break; +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + case IPV6_SADDR_RULE_NOT_OPTIMISTIC: + /* Optimistic addresses still have lower precedence than other + * preferred addresses. + */ + ret = !(score->ifa->flags & IFA_F_OPTIMISTIC); + break; +#endif + default: + ret = 0; + } + + if (ret) + __set_bit(i, score->scorebits); + score->rule = i; +out: + return ret; +} + +static int __ipv6_dev_get_saddr(struct net *net, + struct ipv6_saddr_dst *dst, + struct inet6_dev *idev, + struct ipv6_saddr_score *scores, + int hiscore_idx) +{ + struct ipv6_saddr_score *score = &scores[1 - hiscore_idx], *hiscore = &scores[hiscore_idx]; + + list_for_each_entry_rcu(score->ifa, &idev->addr_list, if_list) { + int i; + + /* + * - Tentative Address (RFC2462 section 5.4) + * - A tentative address is not considered + * "assigned to an interface" in the traditional + * sense, unless it is also flagged as optimistic. 
+ * - Candidate Source Address (section 4) + * - In any case, anycast addresses, multicast + * addresses, and the unspecified address MUST + * NOT be included in a candidate set. + */ + if ((score->ifa->flags & IFA_F_TENTATIVE) && + (!(score->ifa->flags & IFA_F_OPTIMISTIC))) + continue; + + score->addr_type = __ipv6_addr_type(&score->ifa->addr); + + if (unlikely(score->addr_type == IPV6_ADDR_ANY || + score->addr_type & IPV6_ADDR_MULTICAST)) { + net_dbg_ratelimited("ADDRCONF: unspecified / multicast address assigned as unicast address on %s", + idev->dev->name); + continue; + } + + score->rule = -1; + bitmap_zero(score->scorebits, IPV6_SADDR_RULE_MAX); + + for (i = 0; i < IPV6_SADDR_RULE_MAX; i++) { + int minihiscore, miniscore; + + minihiscore = ipv6_get_saddr_eval(net, hiscore, dst, i); + miniscore = ipv6_get_saddr_eval(net, score, dst, i); + + if (minihiscore > miniscore) { + if (i == IPV6_SADDR_RULE_SCOPE && + score->scopedist > 0) { + /* + * special case: + * each remaining entry + * has too small (not enough) + * scope, because ifa entries + * are sorted by their scope + * values. + */ + goto out; + } + break; + } else if (minihiscore < miniscore) { + swap(hiscore, score); + hiscore_idx = 1 - hiscore_idx; + + /* restore our iterator */ + score->ifa = hiscore->ifa; + + break; + } + } + } +out: + return hiscore_idx; +} + +static int ipv6_get_saddr_master(struct net *net, + const struct net_device *dst_dev, + const struct net_device *master, + struct ipv6_saddr_dst *dst, + struct ipv6_saddr_score *scores, + int hiscore_idx) +{ + struct inet6_dev *idev; + + idev = __in6_dev_get(dst_dev); + if (idev) + hiscore_idx = __ipv6_dev_get_saddr(net, dst, idev, + scores, hiscore_idx); + + idev = __in6_dev_get(master); + if (idev) + hiscore_idx = __ipv6_dev_get_saddr(net, dst, idev, + scores, hiscore_idx); + + return hiscore_idx; +} + +int ipv6_dev_get_saddr(struct net *net, const struct net_device *dst_dev, + const struct in6_addr *daddr, unsigned int prefs, + struct in6_addr *saddr) +{ + struct ipv6_saddr_score scores[2], *hiscore; + struct ipv6_saddr_dst dst; + struct inet6_dev *idev; + struct net_device *dev; + int dst_type; + bool use_oif_addr = false; + int hiscore_idx = 0; + int ret = 0; + + dst_type = __ipv6_addr_type(daddr); + dst.addr = daddr; + dst.ifindex = dst_dev ? dst_dev->ifindex : 0; + dst.scope = __ipv6_addr_src_scope(dst_type); + dst.label = ipv6_addr_label(net, daddr, dst_type, dst.ifindex); + dst.prefs = prefs; + + scores[hiscore_idx].rule = -1; + scores[hiscore_idx].ifa = NULL; + + rcu_read_lock(); + + /* Candidate Source Address (section 4) + * - multicast and link-local destination address, + * the set of candidate source address MUST only + * include addresses assigned to interfaces + * belonging to the same link as the outgoing + * interface. + * (- For site-local destination addresses, the + * set of candidate source addresses MUST only + * include addresses assigned to interfaces + * belonging to the same site as the outgoing + * interface.) + * - "It is RECOMMENDED that the candidate source addresses + * be the set of unicast addresses assigned to the + * interface that will be used to send to the destination + * (the 'outgoing' interface)." 
(RFC 6724) + */ + if (dst_dev) { + idev = __in6_dev_get(dst_dev); + if ((dst_type & IPV6_ADDR_MULTICAST) || + dst.scope <= IPV6_ADDR_SCOPE_LINKLOCAL || + (idev && idev->cnf.use_oif_addrs_only)) { + use_oif_addr = true; + } + } + + if (use_oif_addr) { + if (idev) + hiscore_idx = __ipv6_dev_get_saddr(net, &dst, idev, scores, hiscore_idx); + } else { + const struct net_device *master; + int master_idx = 0; + + /* if dst_dev exists and is enslaved to an L3 device, then + * prefer addresses from dst_dev and then the master over + * any other enslaved devices in the L3 domain. + */ + master = l3mdev_master_dev_rcu(dst_dev); + if (master) { + master_idx = master->ifindex; + + hiscore_idx = ipv6_get_saddr_master(net, dst_dev, + master, &dst, + scores, hiscore_idx); + + if (scores[hiscore_idx].ifa) + goto out; + } + + for_each_netdev_rcu(net, dev) { + /* only consider addresses on devices in the + * same L3 domain + */ + if (l3mdev_master_ifindex_rcu(dev) != master_idx) + continue; + idev = __in6_dev_get(dev); + if (!idev) + continue; + hiscore_idx = __ipv6_dev_get_saddr(net, &dst, idev, scores, hiscore_idx); + } + } + +out: + hiscore = &scores[hiscore_idx]; + if (!hiscore->ifa) + ret = -EADDRNOTAVAIL; + else + *saddr = hiscore->ifa->addr; + + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL(ipv6_dev_get_saddr); + +static int __ipv6_get_lladdr(struct inet6_dev *idev, struct in6_addr *addr, + u32 banned_flags) +{ + struct inet6_ifaddr *ifp; + int err = -EADDRNOTAVAIL; + + list_for_each_entry_reverse(ifp, &idev->addr_list, if_list) { + if (ifp->scope > IFA_LINK) + break; + if (ifp->scope == IFA_LINK && + !(ifp->flags & banned_flags)) { + *addr = ifp->addr; + err = 0; + break; + } + } + return err; +} + +int ipv6_get_lladdr(struct net_device *dev, struct in6_addr *addr, + u32 banned_flags) +{ + struct inet6_dev *idev; + int err = -EADDRNOTAVAIL; + + rcu_read_lock(); + idev = __in6_dev_get(dev); + if (idev) { + read_lock_bh(&idev->lock); + err = __ipv6_get_lladdr(idev, addr, banned_flags); + read_unlock_bh(&idev->lock); + } + rcu_read_unlock(); + return err; +} + +static int ipv6_count_addresses(const struct inet6_dev *idev) +{ + const struct inet6_ifaddr *ifp; + int cnt = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(ifp, &idev->addr_list, if_list) + cnt++; + rcu_read_unlock(); + return cnt; +} + +int ipv6_chk_addr(struct net *net, const struct in6_addr *addr, + const struct net_device *dev, int strict) +{ + return ipv6_chk_addr_and_flags(net, addr, dev, !dev, + strict, IFA_F_TENTATIVE); +} +EXPORT_SYMBOL(ipv6_chk_addr); + +/* device argument is used to find the L3 domain of interest. If + * skip_dev_check is set, then the ifp device is not checked against + * the passed in dev argument. So the 2 cases for addresses checks are: + * 1. does the address exist in the L3 domain that dev is part of + * (skip_dev_check = true), or + * + * 2. 
does the address exist on the specific device + * (skip_dev_check = false) + */ +static struct net_device * +__ipv6_chk_addr_and_flags(struct net *net, const struct in6_addr *addr, + const struct net_device *dev, bool skip_dev_check, + int strict, u32 banned_flags) +{ + unsigned int hash = inet6_addr_hash(net, addr); + struct net_device *l3mdev, *ndev; + struct inet6_ifaddr *ifp; + u32 ifp_flags; + + rcu_read_lock(); + + l3mdev = l3mdev_master_dev_rcu(dev); + if (skip_dev_check) + dev = NULL; + + hlist_for_each_entry_rcu(ifp, &net->ipv6.inet6_addr_lst[hash], addr_lst) { + ndev = ifp->idev->dev; + + if (l3mdev_master_dev_rcu(ndev) != l3mdev) + continue; + + /* Decouple optimistic from tentative for evaluation here. + * Ban optimistic addresses explicitly, when required. + */ + ifp_flags = (ifp->flags&IFA_F_OPTIMISTIC) + ? (ifp->flags&~IFA_F_TENTATIVE) + : ifp->flags; + if (ipv6_addr_equal(&ifp->addr, addr) && + !(ifp_flags&banned_flags) && + (!dev || ndev == dev || + !(ifp->scope&(IFA_LINK|IFA_HOST) || strict))) { + rcu_read_unlock(); + return ndev; + } + } + + rcu_read_unlock(); + return NULL; +} + +int ipv6_chk_addr_and_flags(struct net *net, const struct in6_addr *addr, + const struct net_device *dev, bool skip_dev_check, + int strict, u32 banned_flags) +{ + return __ipv6_chk_addr_and_flags(net, addr, dev, skip_dev_check, + strict, banned_flags) ? 1 : 0; +} +EXPORT_SYMBOL(ipv6_chk_addr_and_flags); + + +/* Compares an address/prefix_len with addresses on device @dev. + * If one is found it returns true. + */ +bool ipv6_chk_custom_prefix(const struct in6_addr *addr, + const unsigned int prefix_len, struct net_device *dev) +{ + const struct inet6_ifaddr *ifa; + const struct inet6_dev *idev; + bool ret = false; + + rcu_read_lock(); + idev = __in6_dev_get(dev); + if (idev) { + list_for_each_entry_rcu(ifa, &idev->addr_list, if_list) { + ret = ipv6_prefix_equal(addr, &ifa->addr, prefix_len); + if (ret) + break; + } + } + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL(ipv6_chk_custom_prefix); + +int ipv6_chk_prefix(const struct in6_addr *addr, struct net_device *dev) +{ + const struct inet6_ifaddr *ifa; + const struct inet6_dev *idev; + int onlink; + + onlink = 0; + rcu_read_lock(); + idev = __in6_dev_get(dev); + if (idev) { + list_for_each_entry_rcu(ifa, &idev->addr_list, if_list) { + onlink = ipv6_prefix_equal(addr, &ifa->addr, + ifa->prefix_len); + if (onlink) + break; + } + } + rcu_read_unlock(); + return onlink; +} +EXPORT_SYMBOL(ipv6_chk_prefix); + +/** + * ipv6_dev_find - find the first device with a given source address. + * @net: the net namespace + * @addr: the source address + * @dev: used to find the L3 domain of interest + * + * The caller should be protected by RCU, or RTNL. 
+ */ +struct net_device *ipv6_dev_find(struct net *net, const struct in6_addr *addr, + struct net_device *dev) +{ + return __ipv6_chk_addr_and_flags(net, addr, dev, !dev, 1, + IFA_F_TENTATIVE); +} +EXPORT_SYMBOL(ipv6_dev_find); + +struct inet6_ifaddr *ipv6_get_ifaddr(struct net *net, const struct in6_addr *addr, + struct net_device *dev, int strict) +{ + unsigned int hash = inet6_addr_hash(net, addr); + struct inet6_ifaddr *ifp, *result = NULL; + + rcu_read_lock(); + hlist_for_each_entry_rcu(ifp, &net->ipv6.inet6_addr_lst[hash], addr_lst) { + if (ipv6_addr_equal(&ifp->addr, addr)) { + if (!dev || ifp->idev->dev == dev || + !(ifp->scope&(IFA_LINK|IFA_HOST) || strict)) { + result = ifp; + in6_ifa_hold(ifp); + break; + } + } + } + rcu_read_unlock(); + + return result; +} + +/* Gets referenced address, destroys ifaddr */ + +static void addrconf_dad_stop(struct inet6_ifaddr *ifp, int dad_failed) +{ + if (dad_failed) + ifp->flags |= IFA_F_DADFAILED; + + if (ifp->flags&IFA_F_TEMPORARY) { + struct inet6_ifaddr *ifpub; + spin_lock_bh(&ifp->lock); + ifpub = ifp->ifpub; + if (ifpub) { + in6_ifa_hold(ifpub); + spin_unlock_bh(&ifp->lock); + ipv6_create_tempaddr(ifpub, true); + in6_ifa_put(ifpub); + } else { + spin_unlock_bh(&ifp->lock); + } + ipv6_del_addr(ifp); + } else if (ifp->flags&IFA_F_PERMANENT || !dad_failed) { + spin_lock_bh(&ifp->lock); + addrconf_del_dad_work(ifp); + ifp->flags |= IFA_F_TENTATIVE; + if (dad_failed) + ifp->flags &= ~IFA_F_OPTIMISTIC; + spin_unlock_bh(&ifp->lock); + if (dad_failed) + ipv6_ifa_notify(0, ifp); + in6_ifa_put(ifp); + } else { + ipv6_del_addr(ifp); + } +} + +static int addrconf_dad_end(struct inet6_ifaddr *ifp) +{ + int err = -ENOENT; + + spin_lock_bh(&ifp->lock); + if (ifp->state == INET6_IFADDR_STATE_DAD) { + ifp->state = INET6_IFADDR_STATE_POSTDAD; + err = 0; + } + spin_unlock_bh(&ifp->lock); + + return err; +} + +void addrconf_dad_failure(struct sk_buff *skb, struct inet6_ifaddr *ifp) +{ + struct inet6_dev *idev = ifp->idev; + struct net *net = dev_net(idev->dev); + + if (addrconf_dad_end(ifp)) { + in6_ifa_put(ifp); + return; + } + + net_info_ratelimited("%s: IPv6 duplicate address %pI6c used by %pM detected!\n", + ifp->idev->dev->name, &ifp->addr, eth_hdr(skb)->h_source); + + spin_lock_bh(&ifp->lock); + + if (ifp->flags & IFA_F_STABLE_PRIVACY) { + struct in6_addr new_addr; + struct inet6_ifaddr *ifp2; + int retries = ifp->stable_privacy_retry + 1; + struct ifa6_config cfg = { + .pfx = &new_addr, + .plen = ifp->prefix_len, + .ifa_flags = ifp->flags, + .valid_lft = ifp->valid_lft, + .preferred_lft = ifp->prefered_lft, + .scope = ifp->scope, + }; + + if (retries > net->ipv6.sysctl.idgen_retries) { + net_info_ratelimited("%s: privacy stable address generation failed because of DAD conflicts!\n", + ifp->idev->dev->name); + goto errdad; + } + + new_addr = ifp->addr; + if (ipv6_generate_stable_address(&new_addr, retries, + idev)) + goto errdad; + + spin_unlock_bh(&ifp->lock); + + if (idev->cnf.max_addresses && + ipv6_count_addresses(idev) >= + idev->cnf.max_addresses) + goto lock_errdad; + + net_info_ratelimited("%s: generating new stable privacy address because of DAD conflict\n", + ifp->idev->dev->name); + + ifp2 = ipv6_add_addr(idev, &cfg, false, NULL); + if (IS_ERR(ifp2)) + goto lock_errdad; + + spin_lock_bh(&ifp2->lock); + ifp2->stable_privacy_retry = retries; + ifp2->state = INET6_IFADDR_STATE_PREDAD; + spin_unlock_bh(&ifp2->lock); + + addrconf_mod_dad_work(ifp2, net->ipv6.sysctl.idgen_delay); + in6_ifa_put(ifp2); +lock_errdad: + spin_lock_bh(&ifp->lock); + } + 
+errdad: + /* transition from _POSTDAD to _ERRDAD */ + ifp->state = INET6_IFADDR_STATE_ERRDAD; + spin_unlock_bh(&ifp->lock); + + addrconf_mod_dad_work(ifp, 0); + in6_ifa_put(ifp); +} + +/* Join to solicited addr multicast group. + * caller must hold RTNL */ +void addrconf_join_solict(struct net_device *dev, const struct in6_addr *addr) +{ + struct in6_addr maddr; + + if (dev->flags&(IFF_LOOPBACK|IFF_NOARP)) + return; + + addrconf_addr_solict_mult(addr, &maddr); + ipv6_dev_mc_inc(dev, &maddr); +} + +/* caller must hold RTNL */ +void addrconf_leave_solict(struct inet6_dev *idev, const struct in6_addr *addr) +{ + struct in6_addr maddr; + + if (idev->dev->flags&(IFF_LOOPBACK|IFF_NOARP)) + return; + + addrconf_addr_solict_mult(addr, &maddr); + __ipv6_dev_mc_dec(idev, &maddr); +} + +/* caller must hold RTNL */ +static void addrconf_join_anycast(struct inet6_ifaddr *ifp) +{ + struct in6_addr addr; + + if (ifp->prefix_len >= 127) /* RFC 6164 */ + return; + ipv6_addr_prefix(&addr, &ifp->addr, ifp->prefix_len); + if (ipv6_addr_any(&addr)) + return; + __ipv6_dev_ac_inc(ifp->idev, &addr); +} + +/* caller must hold RTNL */ +static void addrconf_leave_anycast(struct inet6_ifaddr *ifp) +{ + struct in6_addr addr; + + if (ifp->prefix_len >= 127) /* RFC 6164 */ + return; + ipv6_addr_prefix(&addr, &ifp->addr, ifp->prefix_len); + if (ipv6_addr_any(&addr)) + return; + __ipv6_dev_ac_dec(ifp->idev, &addr); +} + +static int addrconf_ifid_6lowpan(u8 *eui, struct net_device *dev) +{ + switch (dev->addr_len) { + case ETH_ALEN: + memcpy(eui, dev->dev_addr, 3); + eui[3] = 0xFF; + eui[4] = 0xFE; + memcpy(eui + 5, dev->dev_addr + 3, 3); + break; + case EUI64_ADDR_LEN: + memcpy(eui, dev->dev_addr, EUI64_ADDR_LEN); + eui[0] ^= 2; + break; + default: + return -1; + } + + return 0; +} + +static int addrconf_ifid_ieee1394(u8 *eui, struct net_device *dev) +{ + const union fwnet_hwaddr *ha; + + if (dev->addr_len != FWNET_ALEN) + return -1; + + ha = (const union fwnet_hwaddr *)dev->dev_addr; + + memcpy(eui, &ha->uc.uniq_id, sizeof(ha->uc.uniq_id)); + eui[0] ^= 2; + return 0; +} + +static int addrconf_ifid_arcnet(u8 *eui, struct net_device *dev) +{ + /* XXX: inherit EUI-64 from other interface -- yoshfuji */ + if (dev->addr_len != ARCNET_ALEN) + return -1; + memset(eui, 0, 7); + eui[7] = *(u8 *)dev->dev_addr; + return 0; +} + +static int addrconf_ifid_infiniband(u8 *eui, struct net_device *dev) +{ + if (dev->addr_len != INFINIBAND_ALEN) + return -1; + memcpy(eui, dev->dev_addr + 12, 8); + eui[0] |= 2; + return 0; +} + +static int __ipv6_isatap_ifid(u8 *eui, __be32 addr) +{ + if (addr == 0) + return -1; + eui[0] = (ipv4_is_zeronet(addr) || ipv4_is_private_10(addr) || + ipv4_is_loopback(addr) || ipv4_is_linklocal_169(addr) || + ipv4_is_private_172(addr) || ipv4_is_test_192(addr) || + ipv4_is_anycast_6to4(addr) || ipv4_is_private_192(addr) || + ipv4_is_test_198(addr) || ipv4_is_multicast(addr) || + ipv4_is_lbcast(addr)) ? 
0x00 : 0x02; + eui[1] = 0; + eui[2] = 0x5E; + eui[3] = 0xFE; + memcpy(eui + 4, &addr, 4); + return 0; +} + +static int addrconf_ifid_sit(u8 *eui, struct net_device *dev) +{ + if (dev->priv_flags & IFF_ISATAP) + return __ipv6_isatap_ifid(eui, *(__be32 *)dev->dev_addr); + return -1; +} + +static int addrconf_ifid_gre(u8 *eui, struct net_device *dev) +{ + return __ipv6_isatap_ifid(eui, *(__be32 *)dev->dev_addr); +} + +static int addrconf_ifid_ip6tnl(u8 *eui, struct net_device *dev) +{ + memcpy(eui, dev->perm_addr, 3); + memcpy(eui + 5, dev->perm_addr + 3, 3); + eui[3] = 0xFF; + eui[4] = 0xFE; + eui[0] ^= 2; + return 0; +} + +static int ipv6_generate_eui64(u8 *eui, struct net_device *dev) +{ + switch (dev->type) { + case ARPHRD_ETHER: + case ARPHRD_FDDI: + return addrconf_ifid_eui48(eui, dev); + case ARPHRD_ARCNET: + return addrconf_ifid_arcnet(eui, dev); + case ARPHRD_INFINIBAND: + return addrconf_ifid_infiniband(eui, dev); + case ARPHRD_SIT: + return addrconf_ifid_sit(eui, dev); + case ARPHRD_IPGRE: + case ARPHRD_TUNNEL: + return addrconf_ifid_gre(eui, dev); + case ARPHRD_6LOWPAN: + return addrconf_ifid_6lowpan(eui, dev); + case ARPHRD_IEEE1394: + return addrconf_ifid_ieee1394(eui, dev); + case ARPHRD_TUNNEL6: + case ARPHRD_IP6GRE: + case ARPHRD_RAWIP: + return addrconf_ifid_ip6tnl(eui, dev); + } + return -1; +} + +static int ipv6_inherit_eui64(u8 *eui, struct inet6_dev *idev) +{ + int err = -1; + struct inet6_ifaddr *ifp; + + read_lock_bh(&idev->lock); + list_for_each_entry_reverse(ifp, &idev->addr_list, if_list) { + if (ifp->scope > IFA_LINK) + break; + if (ifp->scope == IFA_LINK && !(ifp->flags&IFA_F_TENTATIVE)) { + memcpy(eui, ifp->addr.s6_addr+8, 8); + err = 0; + break; + } + } + read_unlock_bh(&idev->lock); + return err; +} + +/* Generation of a randomized Interface Identifier + * draft-ietf-6man-rfc4941bis, Section 3.3.1 + */ + +static void ipv6_gen_rnd_iid(struct in6_addr *addr) +{ +regen: + get_random_bytes(&addr->s6_addr[8], 8); + + /* , Section 3.3.1: + * check if generated address is not inappropriate: + * + * - Reserved IPv6 Interface Identifiers + * - XXX: already assigned to an address on the device + */ + + /* Subnet-router anycast: 0000:0000:0000:0000 */ + if (!(addr->s6_addr32[2] | addr->s6_addr32[3])) + goto regen; + + /* IANA Ethernet block: 0200:5EFF:FE00:0000-0200:5EFF:FE00:5212 + * Proxy Mobile IPv6: 0200:5EFF:FE00:5213 + * IANA Ethernet block: 0200:5EFF:FE00:5214-0200:5EFF:FEFF:FFFF + */ + if (ntohl(addr->s6_addr32[2]) == 0x02005eff && + (ntohl(addr->s6_addr32[3]) & 0Xff000000) == 0xfe000000) + goto regen; + + /* Reserved subnet anycast addresses */ + if (ntohl(addr->s6_addr32[2]) == 0xfdffffff && + ntohl(addr->s6_addr32[3]) >= 0Xffffff80) + goto regen; +} + +/* + * Add prefix route. + */ + +static void +addrconf_prefix_route(struct in6_addr *pfx, int plen, u32 metric, + struct net_device *dev, unsigned long expires, + u32 flags, gfp_t gfp_flags) +{ + struct fib6_config cfg = { + .fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_PREFIX, + .fc_metric = metric ? : IP6_RT_PRIO_ADDRCONF, + .fc_ifindex = dev->ifindex, + .fc_expires = expires, + .fc_dst_len = plen, + .fc_flags = RTF_UP | flags, + .fc_nlinfo.nl_net = dev_net(dev), + .fc_protocol = RTPROT_KERNEL, + .fc_type = RTN_UNICAST, + }; + + cfg.fc_dst = *pfx; + + /* Prevent useless cloning on PtP SIT. + This thing is done here expecting that the whole + class of non-broadcast devices need not cloning. 
+ */ +#if IS_ENABLED(CONFIG_IPV6_SIT) + if (dev->type == ARPHRD_SIT && (dev->flags & IFF_POINTOPOINT)) + cfg.fc_flags |= RTF_NONEXTHOP; +#endif + + ip6_route_add(&cfg, gfp_flags, NULL); +} + + +static struct fib6_info *addrconf_get_prefix_route(const struct in6_addr *pfx, + int plen, + const struct net_device *dev, + u32 flags, u32 noflags, + bool no_gw) +{ + struct fib6_node *fn; + struct fib6_info *rt = NULL; + struct fib6_table *table; + u32 tb_id = l3mdev_fib_table(dev) ? : RT6_TABLE_PREFIX; + + table = fib6_get_table(dev_net(dev), tb_id); + if (!table) + return NULL; + + rcu_read_lock(); + fn = fib6_locate(&table->tb6_root, pfx, plen, NULL, 0, true); + if (!fn) + goto out; + + for_each_fib6_node_rt_rcu(fn) { + /* prefix routes only use builtin fib6_nh */ + if (rt->nh) + continue; + + if (rt->fib6_nh->fib_nh_dev->ifindex != dev->ifindex) + continue; + if (no_gw && rt->fib6_nh->fib_nh_gw_family) + continue; + if ((rt->fib6_flags & flags) != flags) + continue; + if ((rt->fib6_flags & noflags) != 0) + continue; + if (!fib6_info_hold_safe(rt)) + continue; + break; + } +out: + rcu_read_unlock(); + return rt; +} + + +/* Create "default" multicast route to the interface */ + +static void addrconf_add_mroute(struct net_device *dev) +{ + struct fib6_config cfg = { + .fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_LOCAL, + .fc_metric = IP6_RT_PRIO_ADDRCONF, + .fc_ifindex = dev->ifindex, + .fc_dst_len = 8, + .fc_flags = RTF_UP, + .fc_type = RTN_MULTICAST, + .fc_nlinfo.nl_net = dev_net(dev), + .fc_protocol = RTPROT_KERNEL, + }; + + ipv6_addr_set(&cfg.fc_dst, htonl(0xFF000000), 0, 0, 0); + + ip6_route_add(&cfg, GFP_KERNEL, NULL); +} + +static struct inet6_dev *addrconf_add_dev(struct net_device *dev) +{ + struct inet6_dev *idev; + + ASSERT_RTNL(); + + idev = ipv6_find_idev(dev); + if (IS_ERR(idev)) + return idev; + + if (idev->cnf.disable_ipv6) + return ERR_PTR(-EACCES); + + /* Add default multicast route */ + if (!(dev->flags & IFF_LOOPBACK) && !netif_is_l3_master(dev)) + addrconf_add_mroute(dev); + + return idev; +} + +static void manage_tempaddrs(struct inet6_dev *idev, + struct inet6_ifaddr *ifp, + __u32 valid_lft, __u32 prefered_lft, + bool create, unsigned long now) +{ + u32 flags; + struct inet6_ifaddr *ift; + + read_lock_bh(&idev->lock); + /* update all temporary addresses in the list */ + list_for_each_entry(ift, &idev->tempaddr_list, tmp_list) { + int age, max_valid, max_prefered; + + if (ifp != ift->ifpub) + continue; + + /* RFC 4941 section 3.3: + * If a received option will extend the lifetime of a public + * address, the lifetimes of temporary addresses should + * be extended, subject to the overall constraint that no + * temporary addresses should ever remain "valid" or "preferred" + * for a time longer than (TEMP_VALID_LIFETIME) or + * (TEMP_PREFERRED_LIFETIME - DESYNC_FACTOR), respectively. 
+ */ + age = (now - ift->cstamp) / HZ; + max_valid = idev->cnf.temp_valid_lft - age; + if (max_valid < 0) + max_valid = 0; + + max_prefered = idev->cnf.temp_prefered_lft - + idev->desync_factor - age; + if (max_prefered < 0) + max_prefered = 0; + + if (valid_lft > max_valid) + valid_lft = max_valid; + + if (prefered_lft > max_prefered) + prefered_lft = max_prefered; + + spin_lock(&ift->lock); + flags = ift->flags; + ift->valid_lft = valid_lft; + ift->prefered_lft = prefered_lft; + ift->tstamp = now; + if (prefered_lft > 0) + ift->flags &= ~IFA_F_DEPRECATED; + + spin_unlock(&ift->lock); + if (!(flags&IFA_F_TENTATIVE)) + ipv6_ifa_notify(0, ift); + } + + /* Also create a temporary address if it's enabled but no temporary + * address currently exists. + * However, we get called with valid_lft == 0, prefered_lft == 0, create == false + * as part of cleanup (ie. deleting the mngtmpaddr). + * We don't want that to result in creating a new temporary ip address. + */ + if (list_empty(&idev->tempaddr_list) && (valid_lft || prefered_lft)) + create = true; + + if (create && idev->cnf.use_tempaddr > 0) { + /* When a new public address is created as described + * in [ADDRCONF], also create a new temporary address. + */ + read_unlock_bh(&idev->lock); + ipv6_create_tempaddr(ifp, false); + } else { + read_unlock_bh(&idev->lock); + } +} + +static bool is_addr_mode_generate_stable(struct inet6_dev *idev) +{ + return idev->cnf.addr_gen_mode == IN6_ADDR_GEN_MODE_STABLE_PRIVACY || + idev->cnf.addr_gen_mode == IN6_ADDR_GEN_MODE_RANDOM; +} + +int addrconf_prefix_rcv_add_addr(struct net *net, struct net_device *dev, + const struct prefix_info *pinfo, + struct inet6_dev *in6_dev, + const struct in6_addr *addr, int addr_type, + u32 addr_flags, bool sllao, bool tokenized, + __u32 valid_lft, u32 prefered_lft) +{ + struct inet6_ifaddr *ifp = ipv6_get_ifaddr(net, addr, dev, 1); + int create = 0, update_lft = 0; + + if (!ifp && valid_lft) { + int max_addresses = in6_dev->cnf.max_addresses; + struct ifa6_config cfg = { + .pfx = addr, + .plen = pinfo->prefix_len, + .ifa_flags = addr_flags, + .valid_lft = valid_lft, + .preferred_lft = prefered_lft, + .scope = addr_type & IPV6_ADDR_SCOPE_MASK, + .ifa_proto = IFAPROT_KERNEL_RA + }; + +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + if ((net->ipv6.devconf_all->optimistic_dad || + in6_dev->cnf.optimistic_dad) && + !net->ipv6.devconf_all->forwarding && sllao) + cfg.ifa_flags |= IFA_F_OPTIMISTIC; +#endif + + /* Do not allow to create too much of autoconfigured + * addresses; this would be too easy way to crash kernel. 
+ */ + if (!max_addresses || + ipv6_count_addresses(in6_dev) < max_addresses) + ifp = ipv6_add_addr(in6_dev, &cfg, false, NULL); + + if (IS_ERR_OR_NULL(ifp)) + return -1; + + create = 1; + spin_lock_bh(&ifp->lock); + ifp->flags |= IFA_F_MANAGETEMPADDR; + ifp->cstamp = jiffies; + ifp->tokenized = tokenized; + spin_unlock_bh(&ifp->lock); + addrconf_dad_start(ifp); + } + + if (ifp) { + u32 flags; + unsigned long now; + u32 stored_lft; + + /* update lifetime (RFC2462 5.5.3 e) */ + spin_lock_bh(&ifp->lock); + now = jiffies; + if (ifp->valid_lft > (now - ifp->tstamp) / HZ) + stored_lft = ifp->valid_lft - (now - ifp->tstamp) / HZ; + else + stored_lft = 0; + if (!create && stored_lft) { + const u32 minimum_lft = min_t(u32, + stored_lft, MIN_VALID_LIFETIME); + valid_lft = max(valid_lft, minimum_lft); + + /* RFC4862 Section 5.5.3e: + * "Note that the preferred lifetime of the + * corresponding address is always reset to + * the Preferred Lifetime in the received + * Prefix Information option, regardless of + * whether the valid lifetime is also reset or + * ignored." + * + * So we should always update prefered_lft here. + */ + update_lft = 1; + } + + if (update_lft) { + ifp->valid_lft = valid_lft; + ifp->prefered_lft = prefered_lft; + ifp->tstamp = now; + flags = ifp->flags; + ifp->flags &= ~IFA_F_DEPRECATED; + spin_unlock_bh(&ifp->lock); + + if (!(flags&IFA_F_TENTATIVE)) + ipv6_ifa_notify(0, ifp); + } else + spin_unlock_bh(&ifp->lock); + + manage_tempaddrs(in6_dev, ifp, valid_lft, prefered_lft, + create, now); + + in6_ifa_put(ifp); + addrconf_verify(net); + } + + return 0; +} +EXPORT_SYMBOL_GPL(addrconf_prefix_rcv_add_addr); + +void addrconf_prefix_rcv(struct net_device *dev, u8 *opt, int len, bool sllao) +{ + struct prefix_info *pinfo; + __u32 valid_lft; + __u32 prefered_lft; + int addr_type, err; + u32 addr_flags = 0; + struct inet6_dev *in6_dev; + struct net *net = dev_net(dev); + + pinfo = (struct prefix_info *) opt; + + if (len < sizeof(struct prefix_info)) { + netdev_dbg(dev, "addrconf: prefix option too short\n"); + return; + } + + /* + * Validation checks ([ADDRCONF], page 19) + */ + + addr_type = ipv6_addr_type(&pinfo->prefix); + + if (addr_type & (IPV6_ADDR_MULTICAST|IPV6_ADDR_LINKLOCAL)) + return; + + valid_lft = ntohl(pinfo->valid); + prefered_lft = ntohl(pinfo->prefered); + + if (prefered_lft > valid_lft) { + net_warn_ratelimited("addrconf: prefix option has invalid lifetime\n"); + return; + } + + in6_dev = in6_dev_get(dev); + + if (!in6_dev) { + net_dbg_ratelimited("addrconf: device %s not configured\n", + dev->name); + return; + } + + if (valid_lft != 0 && valid_lft < in6_dev->cnf.accept_ra_min_lft) + goto put; + + /* + * Two things going on here: + * 1) Add routes for on-link prefixes + * 2) Configure prefixes with the auto flag set + */ + + if (pinfo->onlink) { + struct fib6_info *rt; + unsigned long rt_expires; + + /* Avoid arithmetic overflow. Really, we could + * save rt_expires in seconds, likely valid_lft, + * but it would require division in fib gc, that it + * not good. 
+ */ + if (HZ > USER_HZ) + rt_expires = addrconf_timeout_fixup(valid_lft, HZ); + else + rt_expires = addrconf_timeout_fixup(valid_lft, USER_HZ); + + if (addrconf_finite_timeout(rt_expires)) + rt_expires *= HZ; + + rt = addrconf_get_prefix_route(&pinfo->prefix, + pinfo->prefix_len, + dev, + RTF_ADDRCONF | RTF_PREFIX_RT, + RTF_DEFAULT, true); + + if (rt) { + /* Autoconf prefix route */ + if (valid_lft == 0) { + ip6_del_rt(net, rt, false); + rt = NULL; + } else if (addrconf_finite_timeout(rt_expires)) { + /* not infinity */ + fib6_set_expires(rt, jiffies + rt_expires); + } else { + fib6_clean_expires(rt); + } + } else if (valid_lft) { + clock_t expires = 0; + int flags = RTF_ADDRCONF | RTF_PREFIX_RT; + if (addrconf_finite_timeout(rt_expires)) { + /* not infinity */ + flags |= RTF_EXPIRES; + expires = jiffies_to_clock_t(rt_expires); + } + addrconf_prefix_route(&pinfo->prefix, pinfo->prefix_len, + 0, dev, expires, flags, + GFP_ATOMIC); + } + fib6_info_release(rt); + } + + /* Try to figure out our local address for this prefix */ + + if (pinfo->autoconf && in6_dev->cnf.autoconf) { + struct in6_addr addr; + bool tokenized = false, dev_addr_generated = false; + + if (pinfo->prefix_len == 64) { + memcpy(&addr, &pinfo->prefix, 8); + + if (!ipv6_addr_any(&in6_dev->token)) { + read_lock_bh(&in6_dev->lock); + memcpy(addr.s6_addr + 8, + in6_dev->token.s6_addr + 8, 8); + read_unlock_bh(&in6_dev->lock); + tokenized = true; + } else if (is_addr_mode_generate_stable(in6_dev) && + !ipv6_generate_stable_address(&addr, 0, + in6_dev)) { + addr_flags |= IFA_F_STABLE_PRIVACY; + goto ok; + } else if (ipv6_generate_eui64(addr.s6_addr + 8, dev) && + ipv6_inherit_eui64(addr.s6_addr + 8, in6_dev)) { + goto put; + } else { + dev_addr_generated = true; + } + goto ok; + } + net_dbg_ratelimited("IPv6 addrconf: prefix with wrong length %d\n", + pinfo->prefix_len); + goto put; + +ok: + err = addrconf_prefix_rcv_add_addr(net, dev, pinfo, in6_dev, + &addr, addr_type, + addr_flags, sllao, + tokenized, valid_lft, + prefered_lft); + if (err) + goto put; + + /* Ignore error case here because previous prefix add addr was + * successful which will be notified. + */ + ndisc_ops_prefix_rcv_add_addr(net, dev, pinfo, in6_dev, &addr, + addr_type, addr_flags, sllao, + tokenized, valid_lft, + prefered_lft, + dev_addr_generated); + } + inet6_prefix_notify(RTM_NEWPREFIX, in6_dev, pinfo); +put: + in6_dev_put(in6_dev); +} + +static int addrconf_set_sit_dstaddr(struct net *net, struct net_device *dev, + struct in6_ifreq *ireq) +{ + struct ip_tunnel_parm p = { }; + int err; + + if (!(ipv6_addr_type(&ireq->ifr6_addr) & IPV6_ADDR_COMPATv4)) + return -EADDRNOTAVAIL; + + p.iph.daddr = ireq->ifr6_addr.s6_addr32[3]; + p.iph.version = 4; + p.iph.ihl = 5; + p.iph.protocol = IPPROTO_IPV6; + p.iph.ttl = 64; + + if (!dev->netdev_ops->ndo_tunnel_ctl) + return -EOPNOTSUPP; + err = dev->netdev_ops->ndo_tunnel_ctl(dev, &p, SIOCADDTUNNEL); + if (err) + return err; + + dev = __dev_get_by_name(net, p.name); + if (!dev) + return -ENOBUFS; + return dev_open(dev, NULL); +} + +/* + * Set destination address. + * Special case for SIT interfaces where we create a new "virtual" + * device. 
+ */ +int addrconf_set_dstaddr(struct net *net, void __user *arg) +{ + struct net_device *dev; + struct in6_ifreq ireq; + int err = -ENODEV; + + if (!IS_ENABLED(CONFIG_IPV6_SIT)) + return -ENODEV; + if (copy_from_user(&ireq, arg, sizeof(struct in6_ifreq))) + return -EFAULT; + + rtnl_lock(); + dev = __dev_get_by_index(net, ireq.ifr6_ifindex); + if (dev && dev->type == ARPHRD_SIT) + err = addrconf_set_sit_dstaddr(net, dev, &ireq); + rtnl_unlock(); + return err; +} + +static int ipv6_mc_config(struct sock *sk, bool join, + const struct in6_addr *addr, int ifindex) +{ + int ret; + + ASSERT_RTNL(); + + lock_sock(sk); + if (join) + ret = ipv6_sock_mc_join(sk, ifindex, addr); + else + ret = ipv6_sock_mc_drop(sk, ifindex, addr); + release_sock(sk); + + return ret; +} + +/* + * Manual configuration of address on an interface + */ +static int inet6_addr_add(struct net *net, int ifindex, + struct ifa6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct inet6_ifaddr *ifp; + struct inet6_dev *idev; + struct net_device *dev; + unsigned long timeout; + clock_t expires; + u32 flags; + + ASSERT_RTNL(); + + if (cfg->plen > 128) + return -EINVAL; + + /* check the lifetime */ + if (!cfg->valid_lft || cfg->preferred_lft > cfg->valid_lft) + return -EINVAL; + + if (cfg->ifa_flags & IFA_F_MANAGETEMPADDR && cfg->plen != 64) + return -EINVAL; + + dev = __dev_get_by_index(net, ifindex); + if (!dev) + return -ENODEV; + + idev = addrconf_add_dev(dev); + if (IS_ERR(idev)) + return PTR_ERR(idev); + + if (cfg->ifa_flags & IFA_F_MCAUTOJOIN) { + int ret = ipv6_mc_config(net->ipv6.mc_autojoin_sk, + true, cfg->pfx, ifindex); + + if (ret < 0) + return ret; + } + + cfg->scope = ipv6_addr_scope(cfg->pfx); + + timeout = addrconf_timeout_fixup(cfg->valid_lft, HZ); + if (addrconf_finite_timeout(timeout)) { + expires = jiffies_to_clock_t(timeout * HZ); + cfg->valid_lft = timeout; + flags = RTF_EXPIRES; + } else { + expires = 0; + flags = 0; + cfg->ifa_flags |= IFA_F_PERMANENT; + } + + timeout = addrconf_timeout_fixup(cfg->preferred_lft, HZ); + if (addrconf_finite_timeout(timeout)) { + if (timeout == 0) + cfg->ifa_flags |= IFA_F_DEPRECATED; + cfg->preferred_lft = timeout; + } + + ifp = ipv6_add_addr(idev, cfg, true, extack); + if (!IS_ERR(ifp)) { + if (!(cfg->ifa_flags & IFA_F_NOPREFIXROUTE)) { + addrconf_prefix_route(&ifp->addr, ifp->prefix_len, + ifp->rt_priority, dev, expires, + flags, GFP_KERNEL); + } + + /* Send a netlink notification if DAD is enabled and + * optimistic flag is not set + */ + if (!(ifp->flags & (IFA_F_OPTIMISTIC | IFA_F_NODAD))) + ipv6_ifa_notify(0, ifp); + /* + * Note that section 3.1 of RFC 4429 indicates + * that the Optimistic flag should not be set for + * manually configured addresses + */ + addrconf_dad_start(ifp); + if (cfg->ifa_flags & IFA_F_MANAGETEMPADDR) + manage_tempaddrs(idev, ifp, cfg->valid_lft, + cfg->preferred_lft, true, jiffies); + in6_ifa_put(ifp); + addrconf_verify_rtnl(net); + return 0; + } else if (cfg->ifa_flags & IFA_F_MCAUTOJOIN) { + ipv6_mc_config(net->ipv6.mc_autojoin_sk, false, + cfg->pfx, ifindex); + } + + return PTR_ERR(ifp); +} + +static int inet6_addr_del(struct net *net, int ifindex, u32 ifa_flags, + const struct in6_addr *pfx, unsigned int plen) +{ + struct inet6_ifaddr *ifp; + struct inet6_dev *idev; + struct net_device *dev; + + if (plen > 128) + return -EINVAL; + + dev = __dev_get_by_index(net, ifindex); + if (!dev) + return -ENODEV; + + idev = __in6_dev_get(dev); + if (!idev) + return -ENXIO; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifp, 
&idev->addr_list, if_list) { + if (ifp->prefix_len == plen && + ipv6_addr_equal(pfx, &ifp->addr)) { + in6_ifa_hold(ifp); + read_unlock_bh(&idev->lock); + + if (!(ifp->flags & IFA_F_TEMPORARY) && + (ifa_flags & IFA_F_MANAGETEMPADDR)) + manage_tempaddrs(idev, ifp, 0, 0, false, + jiffies); + ipv6_del_addr(ifp); + addrconf_verify_rtnl(net); + if (ipv6_addr_is_multicast(pfx)) { + ipv6_mc_config(net->ipv6.mc_autojoin_sk, + false, pfx, dev->ifindex); + } + return 0; + } + } + read_unlock_bh(&idev->lock); + return -EADDRNOTAVAIL; +} + + +int addrconf_add_ifaddr(struct net *net, void __user *arg) +{ + struct ifa6_config cfg = { + .ifa_flags = IFA_F_PERMANENT, + .preferred_lft = INFINITY_LIFE_TIME, + .valid_lft = INFINITY_LIFE_TIME, + }; + struct in6_ifreq ireq; + int err; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ireq, arg, sizeof(struct in6_ifreq))) + return -EFAULT; + + cfg.pfx = &ireq.ifr6_addr; + cfg.plen = ireq.ifr6_prefixlen; + + rtnl_lock(); + err = inet6_addr_add(net, ireq.ifr6_ifindex, &cfg, NULL); + rtnl_unlock(); + return err; +} + +int addrconf_del_ifaddr(struct net *net, void __user *arg) +{ + struct in6_ifreq ireq; + int err; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (copy_from_user(&ireq, arg, sizeof(struct in6_ifreq))) + return -EFAULT; + + rtnl_lock(); + err = inet6_addr_del(net, ireq.ifr6_ifindex, 0, &ireq.ifr6_addr, + ireq.ifr6_prefixlen); + rtnl_unlock(); + return err; +} + +static void add_addr(struct inet6_dev *idev, const struct in6_addr *addr, + int plen, int scope, u8 proto) +{ + struct inet6_ifaddr *ifp; + struct ifa6_config cfg = { + .pfx = addr, + .plen = plen, + .ifa_flags = IFA_F_PERMANENT, + .valid_lft = INFINITY_LIFE_TIME, + .preferred_lft = INFINITY_LIFE_TIME, + .scope = scope, + .ifa_proto = proto + }; + + ifp = ipv6_add_addr(idev, &cfg, true, NULL); + if (!IS_ERR(ifp)) { + spin_lock_bh(&ifp->lock); + ifp->flags &= ~IFA_F_TENTATIVE; + spin_unlock_bh(&ifp->lock); + rt_genid_bump_ipv6(dev_net(idev->dev)); + ipv6_ifa_notify(RTM_NEWADDR, ifp); + in6_ifa_put(ifp); + } +} + +#if IS_ENABLED(CONFIG_IPV6_SIT) || IS_ENABLED(CONFIG_NET_IPGRE) || IS_ENABLED(CONFIG_IPV6_GRE) +static void add_v4_addrs(struct inet6_dev *idev) +{ + struct in6_addr addr; + struct net_device *dev; + struct net *net = dev_net(idev->dev); + int scope, plen, offset = 0; + u32 pflags = 0; + + ASSERT_RTNL(); + + memset(&addr, 0, sizeof(struct in6_addr)); + /* in case of IP6GRE the dev_addr is an IPv6 and therefore we use only the last 4 bytes */ + if (idev->dev->addr_len == sizeof(struct in6_addr)) + offset = sizeof(struct in6_addr) - 4; + memcpy(&addr.s6_addr32[3], idev->dev->dev_addr + offset, 4); + + if (!(idev->dev->flags & IFF_POINTOPOINT) && idev->dev->type == ARPHRD_SIT) { + scope = IPV6_ADDR_COMPATv4; + plen = 96; + pflags |= RTF_NONEXTHOP; + } else { + if (idev->cnf.addr_gen_mode == IN6_ADDR_GEN_MODE_NONE) + return; + + addr.s6_addr32[0] = htonl(0xfe800000); + scope = IFA_LINK; + plen = 64; + } + + if (addr.s6_addr32[3]) { + add_addr(idev, &addr, plen, scope, IFAPROT_UNSPEC); + addrconf_prefix_route(&addr, plen, 0, idev->dev, 0, pflags, + GFP_KERNEL); + return; + } + + for_each_netdev(net, dev) { + struct in_device *in_dev = __in_dev_get_rtnl(dev); + if (in_dev && (dev->flags & IFF_UP)) { + struct in_ifaddr *ifa; + int flag = scope; + + in_dev_for_each_ifa_rtnl(ifa, in_dev) { + addr.s6_addr32[3] = ifa->ifa_local; + + if (ifa->ifa_scope == RT_SCOPE_LINK) + continue; + if (ifa->ifa_scope >= RT_SCOPE_HOST) { + if 
(idev->dev->flags&IFF_POINTOPOINT) + continue; + flag |= IFA_HOST; + } + + add_addr(idev, &addr, plen, flag, + IFAPROT_UNSPEC); + addrconf_prefix_route(&addr, plen, 0, idev->dev, + 0, pflags, GFP_KERNEL); + } + } + } +} +#endif + +static void init_loopback(struct net_device *dev) +{ + struct inet6_dev *idev; + + /* ::1 */ + + ASSERT_RTNL(); + + idev = ipv6_find_idev(dev); + if (IS_ERR(idev)) { + pr_debug("%s: add_dev failed\n", __func__); + return; + } + + add_addr(idev, &in6addr_loopback, 128, IFA_HOST, IFAPROT_KERNEL_LO); +} + +void addrconf_add_linklocal(struct inet6_dev *idev, + const struct in6_addr *addr, u32 flags) +{ + struct ifa6_config cfg = { + .pfx = addr, + .plen = 64, + .ifa_flags = flags | IFA_F_PERMANENT, + .valid_lft = INFINITY_LIFE_TIME, + .preferred_lft = INFINITY_LIFE_TIME, + .scope = IFA_LINK, + .ifa_proto = IFAPROT_KERNEL_LL + }; + struct inet6_ifaddr *ifp; + +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + if ((dev_net(idev->dev)->ipv6.devconf_all->optimistic_dad || + idev->cnf.optimistic_dad) && + !dev_net(idev->dev)->ipv6.devconf_all->forwarding) + cfg.ifa_flags |= IFA_F_OPTIMISTIC; +#endif + + ifp = ipv6_add_addr(idev, &cfg, true, NULL); + if (!IS_ERR(ifp)) { + addrconf_prefix_route(&ifp->addr, ifp->prefix_len, 0, idev->dev, + 0, 0, GFP_ATOMIC); + addrconf_dad_start(ifp); + in6_ifa_put(ifp); + } +} +EXPORT_SYMBOL_GPL(addrconf_add_linklocal); + +static bool ipv6_reserved_interfaceid(struct in6_addr address) +{ + if ((address.s6_addr32[2] | address.s6_addr32[3]) == 0) + return true; + + if (address.s6_addr32[2] == htonl(0x02005eff) && + ((address.s6_addr32[3] & htonl(0xfe000000)) == htonl(0xfe000000))) + return true; + + if (address.s6_addr32[2] == htonl(0xfdffffff) && + ((address.s6_addr32[3] & htonl(0xffffff80)) == htonl(0xffffff80))) + return true; + + return false; +} + +static int ipv6_generate_stable_address(struct in6_addr *address, + u8 dad_count, + const struct inet6_dev *idev) +{ + static DEFINE_SPINLOCK(lock); + static __u32 digest[SHA1_DIGEST_WORDS]; + static __u32 workspace[SHA1_WORKSPACE_WORDS]; + + static union { + char __data[SHA1_BLOCK_SIZE]; + struct { + struct in6_addr secret; + __be32 prefix[2]; + unsigned char hwaddr[MAX_ADDR_LEN]; + u8 dad_count; + } __packed; + } data; + + struct in6_addr secret; + struct in6_addr temp; + struct net *net = dev_net(idev->dev); + + BUILD_BUG_ON(sizeof(data.__data) != sizeof(data)); + + if (idev->cnf.stable_secret.initialized) + secret = idev->cnf.stable_secret.secret; + else if (net->ipv6.devconf_dflt->stable_secret.initialized) + secret = net->ipv6.devconf_dflt->stable_secret.secret; + else + return -1; + +retry: + spin_lock_bh(&lock); + + sha1_init(digest); + memset(&data, 0, sizeof(data)); + memset(workspace, 0, sizeof(workspace)); + memcpy(data.hwaddr, idev->dev->perm_addr, idev->dev->addr_len); + data.prefix[0] = address->s6_addr32[0]; + data.prefix[1] = address->s6_addr32[1]; + data.secret = secret; + data.dad_count = dad_count; + + sha1_transform(digest, data.__data, workspace); + + temp = *address; + temp.s6_addr32[2] = (__force __be32)digest[0]; + temp.s6_addr32[3] = (__force __be32)digest[1]; + + spin_unlock_bh(&lock); + + if (ipv6_reserved_interfaceid(temp)) { + dad_count++; + if (dad_count > dev_net(idev->dev)->ipv6.sysctl.idgen_retries) + return -1; + goto retry; + } + + *address = temp; + return 0; +} + +static void ipv6_gen_mode_random_init(struct inet6_dev *idev) +{ + struct ipv6_stable_secret *s = &idev->cnf.stable_secret; + + if (s->initialized) + return; + s = &idev->cnf.stable_secret; + 
get_random_bytes(&s->secret, sizeof(s->secret)); + s->initialized = true; +} + +static void addrconf_addr_gen(struct inet6_dev *idev, bool prefix_route) +{ + struct in6_addr addr; + + /* no link local addresses on L3 master devices */ + if (netif_is_l3_master(idev->dev)) + return; + + /* no link local addresses on devices flagged as slaves */ + if (idev->dev->flags & IFF_SLAVE) + return; + + ipv6_addr_set(&addr, htonl(0xFE800000), 0, 0, 0); + + switch (idev->cnf.addr_gen_mode) { + case IN6_ADDR_GEN_MODE_RANDOM: + ipv6_gen_mode_random_init(idev); + fallthrough; + case IN6_ADDR_GEN_MODE_STABLE_PRIVACY: + if (!ipv6_generate_stable_address(&addr, 0, idev)) + addrconf_add_linklocal(idev, &addr, + IFA_F_STABLE_PRIVACY); + else if (prefix_route) + addrconf_prefix_route(&addr, 64, 0, idev->dev, + 0, 0, GFP_KERNEL); + break; + case IN6_ADDR_GEN_MODE_EUI64: + /* addrconf_add_linklocal also adds a prefix_route and we + * only need to care about prefix routes if ipv6_generate_eui64 + * couldn't generate one. + */ + if (ipv6_generate_eui64(addr.s6_addr + 8, idev->dev) == 0) + addrconf_add_linklocal(idev, &addr, 0); + else if (prefix_route) + addrconf_prefix_route(&addr, 64, 0, idev->dev, + 0, 0, GFP_KERNEL); + break; + case IN6_ADDR_GEN_MODE_NONE: + default: + /* will not add any link local address */ + break; + } +} + +static void addrconf_dev_config(struct net_device *dev) +{ + struct inet6_dev *idev; + + ASSERT_RTNL(); + + if ((dev->type != ARPHRD_ETHER) && + (dev->type != ARPHRD_FDDI) && + (dev->type != ARPHRD_ARCNET) && + (dev->type != ARPHRD_INFINIBAND) && + (dev->type != ARPHRD_IEEE1394) && + (dev->type != ARPHRD_TUNNEL6) && + (dev->type != ARPHRD_6LOWPAN) && + (dev->type != ARPHRD_TUNNEL) && + (dev->type != ARPHRD_NONE) && + (dev->type != ARPHRD_RAWIP)) { + /* Alas, we support only Ethernet autoconfiguration. */ + idev = __in6_dev_get(dev); + if (!IS_ERR_OR_NULL(idev) && dev->flags & IFF_UP && + dev->flags & IFF_MULTICAST) + ipv6_mc_up(idev); + return; + } + + idev = addrconf_add_dev(dev); + if (IS_ERR(idev)) + return; + + /* this device type has no EUI support */ + if (dev->type == ARPHRD_NONE && + idev->cnf.addr_gen_mode == IN6_ADDR_GEN_MODE_EUI64) + idev->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_RANDOM; + + addrconf_addr_gen(idev, false); +} + +#if IS_ENABLED(CONFIG_IPV6_SIT) +static void addrconf_sit_config(struct net_device *dev) +{ + struct inet6_dev *idev; + + ASSERT_RTNL(); + + /* + * Configure the tunnel with one of our IPv4 + * addresses... 
we should configure all of + * our v4 addrs in the tunnel + */ + + idev = ipv6_find_idev(dev); + if (IS_ERR(idev)) { + pr_debug("%s: add_dev failed\n", __func__); + return; + } + + if (dev->priv_flags & IFF_ISATAP) { + addrconf_addr_gen(idev, false); + return; + } + + add_v4_addrs(idev); + + if (dev->flags&IFF_POINTOPOINT) + addrconf_add_mroute(dev); +} +#endif + +#if IS_ENABLED(CONFIG_NET_IPGRE) || IS_ENABLED(CONFIG_IPV6_GRE) +static void addrconf_gre_config(struct net_device *dev) +{ + struct inet6_dev *idev; + + ASSERT_RTNL(); + + idev = ipv6_find_idev(dev); + if (IS_ERR(idev)) { + pr_debug("%s: add_dev failed\n", __func__); + return; + } + + if (dev->type == ARPHRD_ETHER) { + addrconf_addr_gen(idev, true); + return; + } + + add_v4_addrs(idev); + + if (dev->flags & IFF_POINTOPOINT) + addrconf_add_mroute(dev); +} +#endif + +static void addrconf_init_auto_addrs(struct net_device *dev) +{ + switch (dev->type) { +#if IS_ENABLED(CONFIG_IPV6_SIT) + case ARPHRD_SIT: + addrconf_sit_config(dev); + break; +#endif +#if IS_ENABLED(CONFIG_NET_IPGRE) || IS_ENABLED(CONFIG_IPV6_GRE) + case ARPHRD_IP6GRE: + case ARPHRD_IPGRE: + addrconf_gre_config(dev); + break; +#endif + case ARPHRD_LOOPBACK: + init_loopback(dev); + break; + + default: + addrconf_dev_config(dev); + break; + } +} + +static int fixup_permanent_addr(struct net *net, + struct inet6_dev *idev, + struct inet6_ifaddr *ifp) +{ + /* !fib6_node means the host route was removed from the + * FIB, for example, if 'lo' device is taken down. In that + * case regenerate the host route. + */ + if (!ifp->rt || !ifp->rt->fib6_node) { + struct fib6_info *f6i, *prev; + + f6i = addrconf_f6i_alloc(net, idev, &ifp->addr, false, + GFP_ATOMIC); + if (IS_ERR(f6i)) + return PTR_ERR(f6i); + + /* ifp->rt can be accessed outside of rtnl */ + spin_lock(&ifp->lock); + prev = ifp->rt; + ifp->rt = f6i; + spin_unlock(&ifp->lock); + + fib6_info_release(prev); + } + + if (!(ifp->flags & IFA_F_NOPREFIXROUTE)) { + addrconf_prefix_route(&ifp->addr, ifp->prefix_len, + ifp->rt_priority, idev->dev, 0, 0, + GFP_ATOMIC); + } + + if (ifp->state == INET6_IFADDR_STATE_PREDAD) + addrconf_dad_start(ifp); + + return 0; +} + +static void addrconf_permanent_addr(struct net *net, struct net_device *dev) +{ + struct inet6_ifaddr *ifp, *tmp; + struct inet6_dev *idev; + + idev = __in6_dev_get(dev); + if (!idev) + return; + + write_lock_bh(&idev->lock); + + list_for_each_entry_safe(ifp, tmp, &idev->addr_list, if_list) { + if ((ifp->flags & IFA_F_PERMANENT) && + fixup_permanent_addr(net, idev, ifp) < 0) { + write_unlock_bh(&idev->lock); + in6_ifa_hold(ifp); + ipv6_del_addr(ifp); + write_lock_bh(&idev->lock); + + net_info_ratelimited("%s: Failed to add prefix route for address %pI6c; dropping\n", + idev->dev->name, &ifp->addr); + } + } + + write_unlock_bh(&idev->lock); +} + +static int addrconf_notify(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netdev_notifier_change_info *change_info; + struct netdev_notifier_changeupper_info *info; + struct inet6_dev *idev = __in6_dev_get(dev); + struct net *net = dev_net(dev); + int run_pending = 0; + int err; + + switch (event) { + case NETDEV_REGISTER: + if (!idev && dev->mtu >= IPV6_MIN_MTU) { + idev = ipv6_add_dev(dev); + if (IS_ERR(idev)) + return notifier_from_errno(PTR_ERR(idev)); + } + break; + + case NETDEV_CHANGEMTU: + /* if MTU under IPV6_MIN_MTU stop IPv6 on this interface. 
*/ + if (dev->mtu < IPV6_MIN_MTU) { + addrconf_ifdown(dev, dev != net->loopback_dev); + break; + } + + if (idev) { + rt6_mtu_change(dev, dev->mtu); + idev->cnf.mtu6 = dev->mtu; + break; + } + + /* allocate new idev */ + idev = ipv6_add_dev(dev); + if (IS_ERR(idev)) + break; + + /* device is still not ready */ + if (!(idev->if_flags & IF_READY)) + break; + + run_pending = 1; + fallthrough; + case NETDEV_UP: + case NETDEV_CHANGE: + if (idev && idev->cnf.disable_ipv6) + break; + + if (dev->flags & IFF_SLAVE) { + if (event == NETDEV_UP && !IS_ERR_OR_NULL(idev) && + dev->flags & IFF_UP && dev->flags & IFF_MULTICAST) + ipv6_mc_up(idev); + break; + } + + if (event == NETDEV_UP) { + /* restore routes for permanent addresses */ + addrconf_permanent_addr(net, dev); + + if (!addrconf_link_ready(dev)) { + /* device is not ready yet. */ + pr_debug("ADDRCONF(NETDEV_UP): %s: link is not ready\n", + dev->name); + break; + } + + if (!idev && dev->mtu >= IPV6_MIN_MTU) + idev = ipv6_add_dev(dev); + + if (!IS_ERR_OR_NULL(idev)) { + idev->if_flags |= IF_READY; + run_pending = 1; + } + } else if (event == NETDEV_CHANGE) { + if (!addrconf_link_ready(dev)) { + /* device is still not ready. */ + rt6_sync_down_dev(dev, event); + break; + } + + if (!IS_ERR_OR_NULL(idev)) { + if (idev->if_flags & IF_READY) { + /* device is already configured - + * but resend MLD reports, we might + * have roamed and need to update + * multicast snooping switches + */ + ipv6_mc_up(idev); + change_info = ptr; + if (change_info->flags_changed & IFF_NOARP) + addrconf_dad_run(idev, true); + rt6_sync_up(dev, RTNH_F_LINKDOWN); + break; + } + idev->if_flags |= IF_READY; + } + + pr_info("ADDRCONF(NETDEV_CHANGE): %s: link becomes ready\n", + dev->name); + + run_pending = 1; + } + + addrconf_init_auto_addrs(dev); + + if (!IS_ERR_OR_NULL(idev)) { + if (run_pending) + addrconf_dad_run(idev, false); + + /* Device has an address by now */ + rt6_sync_up(dev, RTNH_F_DEAD); + + /* + * If the MTU changed during the interface down, + * when the interface up, the changed MTU must be + * reflected in the idev as well as routers. + */ + if (idev->cnf.mtu6 != dev->mtu && + dev->mtu >= IPV6_MIN_MTU) { + rt6_mtu_change(dev, dev->mtu); + idev->cnf.mtu6 = dev->mtu; + } + idev->tstamp = jiffies; + inet6_ifinfo_notify(RTM_NEWLINK, idev); + + /* + * If the changed mtu during down is lower than + * IPV6_MIN_MTU stop IPv6 on this interface. + */ + if (dev->mtu < IPV6_MIN_MTU) + addrconf_ifdown(dev, dev != net->loopback_dev); + } + break; + + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + /* + * Remove all addresses from this interface. 
+ */ + addrconf_ifdown(dev, event != NETDEV_DOWN); + break; + + case NETDEV_CHANGENAME: + if (idev) { + snmp6_unregister_dev(idev); + addrconf_sysctl_unregister(idev); + err = addrconf_sysctl_register(idev); + if (err) + return notifier_from_errno(err); + err = snmp6_register_dev(idev); + if (err) { + addrconf_sysctl_unregister(idev); + return notifier_from_errno(err); + } + } + break; + + case NETDEV_PRE_TYPE_CHANGE: + case NETDEV_POST_TYPE_CHANGE: + if (idev) + addrconf_type_change(dev, event); + break; + + case NETDEV_CHANGEUPPER: + info = ptr; + + /* flush all routes if dev is linked to or unlinked from + * an L3 master device (e.g., VRF) + */ + if (info->upper_dev && netif_is_l3_master(info->upper_dev)) + addrconf_ifdown(dev, false); + } + + return NOTIFY_OK; +} + +/* + * addrconf module should be notified of a device going up + */ +static struct notifier_block ipv6_dev_notf = { + .notifier_call = addrconf_notify, + .priority = ADDRCONF_NOTIFY_PRIORITY, +}; + +static void addrconf_type_change(struct net_device *dev, unsigned long event) +{ + struct inet6_dev *idev; + ASSERT_RTNL(); + + idev = __in6_dev_get(dev); + + if (event == NETDEV_POST_TYPE_CHANGE) + ipv6_mc_remap(idev); + else if (event == NETDEV_PRE_TYPE_CHANGE) + ipv6_mc_unmap(idev); +} + +static bool addr_is_local(const struct in6_addr *addr) +{ + return ipv6_addr_type(addr) & + (IPV6_ADDR_LINKLOCAL | IPV6_ADDR_LOOPBACK); +} + +static int addrconf_ifdown(struct net_device *dev, bool unregister) +{ + unsigned long event = unregister ? NETDEV_UNREGISTER : NETDEV_DOWN; + struct net *net = dev_net(dev); + struct inet6_dev *idev; + struct inet6_ifaddr *ifa; + LIST_HEAD(tmp_addr_list); + bool keep_addr = false; + bool was_ready; + int state, i; + + ASSERT_RTNL(); + + rt6_disable_ip(dev, event); + + idev = __in6_dev_get(dev); + if (!idev) + return -ENODEV; + + /* + * Step 1: remove reference to ipv6 device from parent device. + * Do not dev_put! 
+ */ + if (unregister) { + idev->dead = 1; + + /* protected by rtnl_lock */ + RCU_INIT_POINTER(dev->ip6_ptr, NULL); + + /* Step 1.5: remove snmp6 entry */ + snmp6_unregister_dev(idev); + + } + + /* combine the user config with event to determine if permanent + * addresses are to be removed from address hash table + */ + if (!unregister && !idev->cnf.disable_ipv6) { + /* aggregate the system setting and interface setting */ + int _keep_addr = net->ipv6.devconf_all->keep_addr_on_down; + + if (!_keep_addr) + _keep_addr = idev->cnf.keep_addr_on_down; + + keep_addr = (_keep_addr > 0); + } + + /* Step 2: clear hash table */ + for (i = 0; i < IN6_ADDR_HSIZE; i++) { + struct hlist_head *h = &net->ipv6.inet6_addr_lst[i]; + + spin_lock_bh(&net->ipv6.addrconf_hash_lock); +restart: + hlist_for_each_entry_rcu(ifa, h, addr_lst) { + if (ifa->idev == idev) { + addrconf_del_dad_work(ifa); + /* combined flag + permanent flag decide if + * address is retained on a down event + */ + if (!keep_addr || + !(ifa->flags & IFA_F_PERMANENT) || + addr_is_local(&ifa->addr)) { + hlist_del_init_rcu(&ifa->addr_lst); + goto restart; + } + } + } + spin_unlock_bh(&net->ipv6.addrconf_hash_lock); + } + + write_lock_bh(&idev->lock); + + addrconf_del_rs_timer(idev); + + /* Step 2: clear flags for stateless addrconf, repeated down + * detection + */ + was_ready = idev->if_flags & IF_READY; + if (!unregister) + idev->if_flags &= ~(IF_RS_SENT|IF_RA_RCVD|IF_READY); + + /* Step 3: clear tempaddr list */ + while (!list_empty(&idev->tempaddr_list)) { + ifa = list_first_entry(&idev->tempaddr_list, + struct inet6_ifaddr, tmp_list); + list_del(&ifa->tmp_list); + write_unlock_bh(&idev->lock); + spin_lock_bh(&ifa->lock); + + if (ifa->ifpub) { + in6_ifa_put(ifa->ifpub); + ifa->ifpub = NULL; + } + spin_unlock_bh(&ifa->lock); + in6_ifa_put(ifa); + write_lock_bh(&idev->lock); + } + + list_for_each_entry(ifa, &idev->addr_list, if_list) + list_add_tail(&ifa->if_list_aux, &tmp_addr_list); + write_unlock_bh(&idev->lock); + + while (!list_empty(&tmp_addr_list)) { + struct fib6_info *rt = NULL; + bool keep; + + ifa = list_first_entry(&tmp_addr_list, + struct inet6_ifaddr, if_list_aux); + list_del(&ifa->if_list_aux); + + addrconf_del_dad_work(ifa); + + keep = keep_addr && (ifa->flags & IFA_F_PERMANENT) && + !addr_is_local(&ifa->addr); + + spin_lock_bh(&ifa->lock); + + if (keep) { + /* set state to skip the notifier below */ + state = INET6_IFADDR_STATE_DEAD; + ifa->state = INET6_IFADDR_STATE_PREDAD; + if (!(ifa->flags & IFA_F_NODAD)) + ifa->flags |= IFA_F_TENTATIVE; + + rt = ifa->rt; + ifa->rt = NULL; + } else { + state = ifa->state; + ifa->state = INET6_IFADDR_STATE_DEAD; + } + + spin_unlock_bh(&ifa->lock); + + if (rt) + ip6_del_rt(net, rt, false); + + if (state != INET6_IFADDR_STATE_DEAD) { + __ipv6_ifa_notify(RTM_DELADDR, ifa); + inet6addr_notifier_call_chain(NETDEV_DOWN, ifa); + } else { + if (idev->cnf.forwarding) + addrconf_leave_anycast(ifa); + addrconf_leave_solict(ifa->idev, &ifa->addr); + } + + if (!keep) { + write_lock_bh(&idev->lock); + list_del_rcu(&ifa->if_list); + write_unlock_bh(&idev->lock); + in6_ifa_put(ifa); + } + } + + /* Step 5: Discard anycast and multicast list */ + if (unregister) { + ipv6_ac_destroy_dev(idev); + ipv6_mc_destroy_dev(idev); + } else if (was_ready) { + ipv6_mc_down(idev); + } + + idev->tstamp = jiffies; + idev->ra_mtu = 0; + + /* Last: Shot the device (if unregistered) */ + if (unregister) { + addrconf_sysctl_unregister(idev); + neigh_parms_release(&nd_tbl, idev->nd_parms); + neigh_ifdown(&nd_tbl, dev); + 
in6_dev_put(idev); + } + return 0; +} + +static void addrconf_rs_timer(struct timer_list *t) +{ + struct inet6_dev *idev = from_timer(idev, t, rs_timer); + struct net_device *dev = idev->dev; + struct in6_addr lladdr; + + write_lock(&idev->lock); + if (idev->dead || !(idev->if_flags & IF_READY)) + goto out; + + if (!ipv6_accept_ra(idev)) + goto out; + + /* Announcement received after solicitation was sent */ + if (idev->if_flags & IF_RA_RCVD) + goto out; + + if (idev->rs_probes++ < idev->cnf.rtr_solicits || idev->cnf.rtr_solicits < 0) { + write_unlock(&idev->lock); + if (!ipv6_get_lladdr(dev, &lladdr, IFA_F_TENTATIVE)) + ndisc_send_rs(dev, &lladdr, + &in6addr_linklocal_allrouters); + else + goto put; + + write_lock(&idev->lock); + idev->rs_interval = rfc3315_s14_backoff_update( + idev->rs_interval, idev->cnf.rtr_solicit_max_interval); + /* The wait after the last probe can be shorter */ + addrconf_mod_rs_timer(idev, (idev->rs_probes == + idev->cnf.rtr_solicits) ? + idev->cnf.rtr_solicit_delay : + idev->rs_interval); + } else { + /* + * Note: we do not support deprecated "all on-link" + * assumption any longer. + */ + pr_debug("%s: no IPv6 routers present\n", idev->dev->name); + } + +out: + write_unlock(&idev->lock); +put: + in6_dev_put(idev); +} + +/* + * Duplicate Address Detection + */ +static void addrconf_dad_kick(struct inet6_ifaddr *ifp) +{ + unsigned long rand_num; + struct inet6_dev *idev = ifp->idev; + u64 nonce; + + if (ifp->flags & IFA_F_OPTIMISTIC) + rand_num = 0; + else + rand_num = prandom_u32_max(idev->cnf.rtr_solicit_delay ?: 1); + + nonce = 0; + if (idev->cnf.enhanced_dad || + dev_net(idev->dev)->ipv6.devconf_all->enhanced_dad) { + do + get_random_bytes(&nonce, 6); + while (nonce == 0); + } + ifp->dad_nonce = nonce; + ifp->dad_probes = idev->cnf.dad_transmits; + addrconf_mod_dad_work(ifp, rand_num); +} + +static void addrconf_dad_begin(struct inet6_ifaddr *ifp) +{ + struct inet6_dev *idev = ifp->idev; + struct net_device *dev = idev->dev; + bool bump_id, notify = false; + struct net *net; + + addrconf_join_solict(dev, &ifp->addr); + + read_lock_bh(&idev->lock); + spin_lock(&ifp->lock); + if (ifp->state == INET6_IFADDR_STATE_DEAD) + goto out; + + net = dev_net(dev); + if (dev->flags&(IFF_NOARP|IFF_LOOPBACK) || + (net->ipv6.devconf_all->accept_dad < 1 && + idev->cnf.accept_dad < 1) || + !(ifp->flags&IFA_F_TENTATIVE) || + ifp->flags & IFA_F_NODAD) { + bool send_na = false; + + if (ifp->flags & IFA_F_TENTATIVE && + !(ifp->flags & IFA_F_OPTIMISTIC)) + send_na = true; + bump_id = ifp->flags & IFA_F_TENTATIVE; + ifp->flags &= ~(IFA_F_TENTATIVE|IFA_F_OPTIMISTIC|IFA_F_DADFAILED); + spin_unlock(&ifp->lock); + read_unlock_bh(&idev->lock); + + addrconf_dad_completed(ifp, bump_id, send_na); + return; + } + + if (!(idev->if_flags & IF_READY)) { + spin_unlock(&ifp->lock); + read_unlock_bh(&idev->lock); + /* + * If the device is not ready: + * - keep it tentative if it is a permanent address. + * - otherwise, kill it. + */ + in6_ifa_hold(ifp); + addrconf_dad_stop(ifp, 0); + return; + } + + /* + * Optimistic nodes can start receiving + * Frames right away + */ + if (ifp->flags & IFA_F_OPTIMISTIC) { + ip6_ins_rt(net, ifp->rt); + if (ipv6_use_optimistic_addr(net, idev)) { + /* Because optimistic nodes can use this address, + * notify listeners. If DAD fails, RTM_DELADDR is sent. 
+ */ + notify = true; + } + } + + addrconf_dad_kick(ifp); +out: + spin_unlock(&ifp->lock); + read_unlock_bh(&idev->lock); + if (notify) + ipv6_ifa_notify(RTM_NEWADDR, ifp); +} + +static void addrconf_dad_start(struct inet6_ifaddr *ifp) +{ + bool begin_dad = false; + + spin_lock_bh(&ifp->lock); + if (ifp->state != INET6_IFADDR_STATE_DEAD) { + ifp->state = INET6_IFADDR_STATE_PREDAD; + begin_dad = true; + } + spin_unlock_bh(&ifp->lock); + + if (begin_dad) + addrconf_mod_dad_work(ifp, 0); +} + +static void addrconf_dad_work(struct work_struct *w) +{ + struct inet6_ifaddr *ifp = container_of(to_delayed_work(w), + struct inet6_ifaddr, + dad_work); + struct inet6_dev *idev = ifp->idev; + bool bump_id, disable_ipv6 = false; + struct in6_addr mcaddr; + + enum { + DAD_PROCESS, + DAD_BEGIN, + DAD_ABORT, + } action = DAD_PROCESS; + + rtnl_lock(); + + spin_lock_bh(&ifp->lock); + if (ifp->state == INET6_IFADDR_STATE_PREDAD) { + action = DAD_BEGIN; + ifp->state = INET6_IFADDR_STATE_DAD; + } else if (ifp->state == INET6_IFADDR_STATE_ERRDAD) { + action = DAD_ABORT; + ifp->state = INET6_IFADDR_STATE_POSTDAD; + + if ((dev_net(idev->dev)->ipv6.devconf_all->accept_dad > 1 || + idev->cnf.accept_dad > 1) && + !idev->cnf.disable_ipv6 && + !(ifp->flags & IFA_F_STABLE_PRIVACY)) { + struct in6_addr addr; + + addr.s6_addr32[0] = htonl(0xfe800000); + addr.s6_addr32[1] = 0; + + if (!ipv6_generate_eui64(addr.s6_addr + 8, idev->dev) && + ipv6_addr_equal(&ifp->addr, &addr)) { + /* DAD failed for link-local based on MAC */ + idev->cnf.disable_ipv6 = 1; + + pr_info("%s: IPv6 being disabled!\n", + ifp->idev->dev->name); + disable_ipv6 = true; + } + } + } + spin_unlock_bh(&ifp->lock); + + if (action == DAD_BEGIN) { + addrconf_dad_begin(ifp); + goto out; + } else if (action == DAD_ABORT) { + in6_ifa_hold(ifp); + addrconf_dad_stop(ifp, 1); + if (disable_ipv6) + addrconf_ifdown(idev->dev, false); + goto out; + } + + if (!ifp->dad_probes && addrconf_dad_end(ifp)) + goto out; + + write_lock_bh(&idev->lock); + if (idev->dead || !(idev->if_flags & IF_READY)) { + write_unlock_bh(&idev->lock); + goto out; + } + + spin_lock(&ifp->lock); + if (ifp->state == INET6_IFADDR_STATE_DEAD) { + spin_unlock(&ifp->lock); + write_unlock_bh(&idev->lock); + goto out; + } + + if (ifp->dad_probes == 0) { + bool send_na = false; + + /* + * DAD was successful + */ + + if (ifp->flags & IFA_F_TENTATIVE && + !(ifp->flags & IFA_F_OPTIMISTIC)) + send_na = true; + bump_id = ifp->flags & IFA_F_TENTATIVE; + ifp->flags &= ~(IFA_F_TENTATIVE|IFA_F_OPTIMISTIC|IFA_F_DADFAILED); + spin_unlock(&ifp->lock); + write_unlock_bh(&idev->lock); + + addrconf_dad_completed(ifp, bump_id, send_na); + + goto out; + } + + ifp->dad_probes--; + addrconf_mod_dad_work(ifp, + max(NEIGH_VAR(ifp->idev->nd_parms, RETRANS_TIME), + HZ/100)); + spin_unlock(&ifp->lock); + write_unlock_bh(&idev->lock); + + /* send a neighbour solicitation for our addr */ + addrconf_addr_solict_mult(&ifp->addr, &mcaddr); + ndisc_send_ns(ifp->idev->dev, &ifp->addr, &mcaddr, &in6addr_any, + ifp->dad_nonce); +out: + in6_ifa_put(ifp); + rtnl_unlock(); +} + +/* ifp->idev must be at least read locked */ +static bool ipv6_lonely_lladdr(struct inet6_ifaddr *ifp) +{ + struct inet6_ifaddr *ifpiter; + struct inet6_dev *idev = ifp->idev; + + list_for_each_entry_reverse(ifpiter, &idev->addr_list, if_list) { + if (ifpiter->scope > IFA_LINK) + break; + if (ifp != ifpiter && ifpiter->scope == IFA_LINK && + (ifpiter->flags & (IFA_F_PERMANENT|IFA_F_TENTATIVE| + IFA_F_OPTIMISTIC|IFA_F_DADFAILED)) == + IFA_F_PERMANENT) + return 
false; + } + return true; +} + +static void addrconf_dad_completed(struct inet6_ifaddr *ifp, bool bump_id, + bool send_na) +{ + struct net_device *dev = ifp->idev->dev; + struct in6_addr lladdr; + bool send_rs, send_mld; + + addrconf_del_dad_work(ifp); + + /* + * Configure the address for reception. Now it is valid. + */ + + ipv6_ifa_notify(RTM_NEWADDR, ifp); + + /* If added prefix is link local and we are prepared to process + router advertisements, start sending router solicitations. + */ + + read_lock_bh(&ifp->idev->lock); + send_mld = ifp->scope == IFA_LINK && ipv6_lonely_lladdr(ifp); + send_rs = send_mld && + ipv6_accept_ra(ifp->idev) && + ifp->idev->cnf.rtr_solicits != 0 && + (dev->flags & IFF_LOOPBACK) == 0 && + (dev->type != ARPHRD_TUNNEL); + read_unlock_bh(&ifp->idev->lock); + + /* While dad is in progress mld report's source address is in6_addrany. + * Resend with proper ll now. + */ + if (send_mld) + ipv6_mc_dad_complete(ifp->idev); + + /* send unsolicited NA if enabled */ + if (send_na && + (ifp->idev->cnf.ndisc_notify || + dev_net(dev)->ipv6.devconf_all->ndisc_notify)) { + ndisc_send_na(dev, &in6addr_linklocal_allnodes, &ifp->addr, + /*router=*/ !!ifp->idev->cnf.forwarding, + /*solicited=*/ false, /*override=*/ true, + /*inc_opt=*/ true); + } + + if (send_rs) { + /* + * If a host as already performed a random delay + * [...] as part of DAD [...] there is no need + * to delay again before sending the first RS + */ + if (ipv6_get_lladdr(dev, &lladdr, IFA_F_TENTATIVE)) + return; + ndisc_send_rs(dev, &lladdr, &in6addr_linklocal_allrouters); + + write_lock_bh(&ifp->idev->lock); + spin_lock(&ifp->lock); + ifp->idev->rs_interval = rfc3315_s14_backoff_init( + ifp->idev->cnf.rtr_solicit_interval); + ifp->idev->rs_probes = 1; + ifp->idev->if_flags |= IF_RS_SENT; + addrconf_mod_rs_timer(ifp->idev, ifp->idev->rs_interval); + spin_unlock(&ifp->lock); + write_unlock_bh(&ifp->idev->lock); + } + + if (bump_id) + rt_genid_bump_ipv6(dev_net(dev)); + + /* Make sure that a new temporary address will be created + * before this temporary address becomes deprecated. 
+ */ + if (ifp->flags & IFA_F_TEMPORARY) + addrconf_verify_rtnl(dev_net(dev)); +} + +static void addrconf_dad_run(struct inet6_dev *idev, bool restart) +{ + struct inet6_ifaddr *ifp; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifp, &idev->addr_list, if_list) { + spin_lock(&ifp->lock); + if ((ifp->flags & IFA_F_TENTATIVE && + ifp->state == INET6_IFADDR_STATE_DAD) || restart) { + if (restart) + ifp->state = INET6_IFADDR_STATE_PREDAD; + addrconf_dad_kick(ifp); + } + spin_unlock(&ifp->lock); + } + read_unlock_bh(&idev->lock); +} + +#ifdef CONFIG_PROC_FS +struct if6_iter_state { + struct seq_net_private p; + int bucket; + int offset; +}; + +static struct inet6_ifaddr *if6_get_first(struct seq_file *seq, loff_t pos) +{ + struct if6_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + struct inet6_ifaddr *ifa = NULL; + int p = 0; + + /* initial bucket if pos is 0 */ + if (pos == 0) { + state->bucket = 0; + state->offset = 0; + } + + for (; state->bucket < IN6_ADDR_HSIZE; ++state->bucket) { + hlist_for_each_entry_rcu(ifa, &net->ipv6.inet6_addr_lst[state->bucket], + addr_lst) { + /* sync with offset */ + if (p < state->offset) { + p++; + continue; + } + return ifa; + } + + /* prepare for next bucket */ + state->offset = 0; + p = 0; + } + return NULL; +} + +static struct inet6_ifaddr *if6_get_next(struct seq_file *seq, + struct inet6_ifaddr *ifa) +{ + struct if6_iter_state *state = seq->private; + struct net *net = seq_file_net(seq); + + hlist_for_each_entry_continue_rcu(ifa, addr_lst) { + state->offset++; + return ifa; + } + + state->offset = 0; + while (++state->bucket < IN6_ADDR_HSIZE) { + hlist_for_each_entry_rcu(ifa, + &net->ipv6.inet6_addr_lst[state->bucket], addr_lst) { + return ifa; + } + } + + return NULL; +} + +static void *if6_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + rcu_read_lock(); + return if6_get_first(seq, *pos); +} + +static void *if6_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct inet6_ifaddr *ifa; + + ifa = if6_get_next(seq, v); + ++*pos; + return ifa; +} + +static void if6_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +static int if6_seq_show(struct seq_file *seq, void *v) +{ + struct inet6_ifaddr *ifp = (struct inet6_ifaddr *)v; + seq_printf(seq, "%pi6 %02x %02x %02x %02x %8s\n", + &ifp->addr, + ifp->idev->dev->ifindex, + ifp->prefix_len, + ifp->scope, + (u8) ifp->flags, + ifp->idev->dev->name); + return 0; +} + +static const struct seq_operations if6_seq_ops = { + .start = if6_seq_start, + .next = if6_seq_next, + .show = if6_seq_show, + .stop = if6_seq_stop, +}; + +static int __net_init if6_proc_net_init(struct net *net) +{ + if (!proc_create_net("if_inet6", 0444, net->proc_net, &if6_seq_ops, + sizeof(struct if6_iter_state))) + return -ENOMEM; + return 0; +} + +static void __net_exit if6_proc_net_exit(struct net *net) +{ + remove_proc_entry("if_inet6", net->proc_net); +} + +static struct pernet_operations if6_proc_net_ops = { + .init = if6_proc_net_init, + .exit = if6_proc_net_exit, +}; + +int __init if6_proc_init(void) +{ + return register_pernet_subsys(&if6_proc_net_ops); +} + +void if6_proc_exit(void) +{ + unregister_pernet_subsys(&if6_proc_net_ops); +} +#endif /* CONFIG_PROC_FS */ + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +/* Check if address is a home address configured on any interface. 
*/ +int ipv6_chk_home_addr(struct net *net, const struct in6_addr *addr) +{ + unsigned int hash = inet6_addr_hash(net, addr); + struct inet6_ifaddr *ifp = NULL; + int ret = 0; + + rcu_read_lock(); + hlist_for_each_entry_rcu(ifp, &net->ipv6.inet6_addr_lst[hash], addr_lst) { + if (ipv6_addr_equal(&ifp->addr, addr) && + (ifp->flags & IFA_F_HOMEADDRESS)) { + ret = 1; + break; + } + } + rcu_read_unlock(); + return ret; +} +#endif + +/* RFC6554 has some algorithm to avoid loops in segment routing by + * checking if the segments contains any of a local interface address. + * + * Quote: + * + * To detect loops in the SRH, a router MUST determine if the SRH + * includes multiple addresses assigned to any interface on that router. + * If such addresses appear more than once and are separated by at least + * one address not assigned to that router. + */ +int ipv6_chk_rpl_srh_loop(struct net *net, const struct in6_addr *segs, + unsigned char nsegs) +{ + const struct in6_addr *addr; + int i, ret = 0, found = 0; + struct inet6_ifaddr *ifp; + bool separated = false; + unsigned int hash; + bool hash_found; + + rcu_read_lock(); + for (i = 0; i < nsegs; i++) { + addr = &segs[i]; + hash = inet6_addr_hash(net, addr); + + hash_found = false; + hlist_for_each_entry_rcu(ifp, &net->ipv6.inet6_addr_lst[hash], addr_lst) { + + if (ipv6_addr_equal(&ifp->addr, addr)) { + hash_found = true; + break; + } + } + + if (hash_found) { + if (found > 1 && separated) { + ret = 1; + break; + } + + separated = false; + found++; + } else { + separated = true; + } + } + rcu_read_unlock(); + + return ret; +} + +/* + * Periodic address status verification + */ + +static void addrconf_verify_rtnl(struct net *net) +{ + unsigned long now, next, next_sec, next_sched; + struct inet6_ifaddr *ifp; + int i; + + ASSERT_RTNL(); + + rcu_read_lock_bh(); + now = jiffies; + next = round_jiffies_up(now + ADDR_CHECK_FREQUENCY); + + cancel_delayed_work(&net->ipv6.addr_chk_work); + + for (i = 0; i < IN6_ADDR_HSIZE; i++) { +restart: + hlist_for_each_entry_rcu_bh(ifp, &net->ipv6.inet6_addr_lst[i], addr_lst) { + unsigned long age; + + /* When setting preferred_lft to a value not zero or + * infinity, while valid_lft is infinity + * IFA_F_PERMANENT has a non-infinity life time. + */ + if ((ifp->flags & IFA_F_PERMANENT) && + (ifp->prefered_lft == INFINITY_LIFE_TIME)) + continue; + + spin_lock(&ifp->lock); + /* We try to batch several events at once. */ + age = (now - ifp->tstamp + ADDRCONF_TIMER_FUZZ_MINUS) / HZ; + + if ((ifp->flags&IFA_F_TEMPORARY) && + !(ifp->flags&IFA_F_TENTATIVE) && + ifp->prefered_lft != INFINITY_LIFE_TIME && + !ifp->regen_count && ifp->ifpub) { + /* This is a non-regenerated temporary addr. 
*/ + + unsigned long regen_advance = ifp->idev->cnf.regen_max_retry * + ifp->idev->cnf.dad_transmits * + max(NEIGH_VAR(ifp->idev->nd_parms, RETRANS_TIME), HZ/100) / HZ; + + if (age + regen_advance >= ifp->prefered_lft) { + struct inet6_ifaddr *ifpub = ifp->ifpub; + if (time_before(ifp->tstamp + ifp->prefered_lft * HZ, next)) + next = ifp->tstamp + ifp->prefered_lft * HZ; + + ifp->regen_count++; + in6_ifa_hold(ifp); + in6_ifa_hold(ifpub); + spin_unlock(&ifp->lock); + + spin_lock(&ifpub->lock); + ifpub->regen_count = 0; + spin_unlock(&ifpub->lock); + rcu_read_unlock_bh(); + ipv6_create_tempaddr(ifpub, true); + in6_ifa_put(ifpub); + in6_ifa_put(ifp); + rcu_read_lock_bh(); + goto restart; + } else if (time_before(ifp->tstamp + ifp->prefered_lft * HZ - regen_advance * HZ, next)) + next = ifp->tstamp + ifp->prefered_lft * HZ - regen_advance * HZ; + } + + if (ifp->valid_lft != INFINITY_LIFE_TIME && + age >= ifp->valid_lft) { + spin_unlock(&ifp->lock); + in6_ifa_hold(ifp); + rcu_read_unlock_bh(); + ipv6_del_addr(ifp); + rcu_read_lock_bh(); + goto restart; + } else if (ifp->prefered_lft == INFINITY_LIFE_TIME) { + spin_unlock(&ifp->lock); + continue; + } else if (age >= ifp->prefered_lft) { + /* jiffies - ifp->tstamp > age >= ifp->prefered_lft */ + int deprecate = 0; + + if (!(ifp->flags&IFA_F_DEPRECATED)) { + deprecate = 1; + ifp->flags |= IFA_F_DEPRECATED; + } + + if ((ifp->valid_lft != INFINITY_LIFE_TIME) && + (time_before(ifp->tstamp + ifp->valid_lft * HZ, next))) + next = ifp->tstamp + ifp->valid_lft * HZ; + + spin_unlock(&ifp->lock); + + if (deprecate) { + in6_ifa_hold(ifp); + + ipv6_ifa_notify(0, ifp); + in6_ifa_put(ifp); + goto restart; + } + } else { + /* ifp->prefered_lft <= ifp->valid_lft */ + if (time_before(ifp->tstamp + ifp->prefered_lft * HZ, next)) + next = ifp->tstamp + ifp->prefered_lft * HZ; + spin_unlock(&ifp->lock); + } + } + } + + next_sec = round_jiffies_up(next); + next_sched = next; + + /* If rounded timeout is accurate enough, accept it. */ + if (time_before(next_sec, next + ADDRCONF_TIMER_FUZZ)) + next_sched = next_sec; + + /* And minimum interval is ADDRCONF_TIMER_FUZZ_MAX. 
*/ + if (time_before(next_sched, jiffies + ADDRCONF_TIMER_FUZZ_MAX)) + next_sched = jiffies + ADDRCONF_TIMER_FUZZ_MAX; + + pr_debug("now = %lu, schedule = %lu, rounded schedule = %lu => %lu\n", + now, next, next_sec, next_sched); + mod_delayed_work(addrconf_wq, &net->ipv6.addr_chk_work, next_sched - now); + rcu_read_unlock_bh(); +} + +static void addrconf_verify_work(struct work_struct *w) +{ + struct net *net = container_of(to_delayed_work(w), struct net, + ipv6.addr_chk_work); + + rtnl_lock(); + addrconf_verify_rtnl(net); + rtnl_unlock(); +} + +static void addrconf_verify(struct net *net) +{ + mod_delayed_work(addrconf_wq, &net->ipv6.addr_chk_work, 0); +} + +static struct in6_addr *extract_addr(struct nlattr *addr, struct nlattr *local, + struct in6_addr **peer_pfx) +{ + struct in6_addr *pfx = NULL; + + *peer_pfx = NULL; + + if (addr) + pfx = nla_data(addr); + + if (local) { + if (pfx && nla_memcmp(local, pfx, sizeof(*pfx))) + *peer_pfx = pfx; + pfx = nla_data(local); + } + + return pfx; +} + +static const struct nla_policy ifa_ipv6_policy[IFA_MAX+1] = { + [IFA_ADDRESS] = { .len = sizeof(struct in6_addr) }, + [IFA_LOCAL] = { .len = sizeof(struct in6_addr) }, + [IFA_CACHEINFO] = { .len = sizeof(struct ifa_cacheinfo) }, + [IFA_FLAGS] = { .len = sizeof(u32) }, + [IFA_RT_PRIORITY] = { .len = sizeof(u32) }, + [IFA_TARGET_NETNSID] = { .type = NLA_S32 }, + [IFA_PROTO] = { .type = NLA_U8 }, +}; + +static int +inet6_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifaddrmsg *ifm; + struct nlattr *tb[IFA_MAX+1]; + struct in6_addr *pfx, *peer_pfx; + u32 ifa_flags; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err < 0) + return err; + + ifm = nlmsg_data(nlh); + pfx = extract_addr(tb[IFA_ADDRESS], tb[IFA_LOCAL], &peer_pfx); + if (!pfx) + return -EINVAL; + + ifa_flags = tb[IFA_FLAGS] ? nla_get_u32(tb[IFA_FLAGS]) : ifm->ifa_flags; + + /* We ignore other flags so far. */ + ifa_flags &= IFA_F_MANAGETEMPADDR; + + return inet6_addr_del(net, ifm->ifa_index, ifa_flags, pfx, + ifm->ifa_prefixlen); +} + +static int modify_prefix_route(struct inet6_ifaddr *ifp, + unsigned long expires, u32 flags, + bool modify_peer) +{ + struct fib6_info *f6i; + u32 prio; + + f6i = addrconf_get_prefix_route(modify_peer ? &ifp->peer_addr : &ifp->addr, + ifp->prefix_len, + ifp->idev->dev, 0, RTF_DEFAULT, true); + if (!f6i) + return -ENOENT; + + prio = ifp->rt_priority ? : IP6_RT_PRIO_ADDRCONF; + if (f6i->fib6_metric != prio) { + /* delete old one */ + ip6_del_rt(dev_net(ifp->idev->dev), f6i, false); + + /* add new one */ + addrconf_prefix_route(modify_peer ? 
&ifp->peer_addr : &ifp->addr, + ifp->prefix_len, + ifp->rt_priority, ifp->idev->dev, + expires, flags, GFP_KERNEL); + } else { + if (!expires) + fib6_clean_expires(f6i); + else + fib6_set_expires(f6i, expires); + + fib6_info_release(f6i); + } + + return 0; +} + +static int inet6_addr_modify(struct net *net, struct inet6_ifaddr *ifp, + struct ifa6_config *cfg) +{ + u32 flags; + clock_t expires; + unsigned long timeout; + bool was_managetempaddr; + bool had_prefixroute; + bool new_peer = false; + + ASSERT_RTNL(); + + if (!cfg->valid_lft || cfg->preferred_lft > cfg->valid_lft) + return -EINVAL; + + if (cfg->ifa_flags & IFA_F_MANAGETEMPADDR && + (ifp->flags & IFA_F_TEMPORARY || ifp->prefix_len != 64)) + return -EINVAL; + + if (!(ifp->flags & IFA_F_TENTATIVE) || ifp->flags & IFA_F_DADFAILED) + cfg->ifa_flags &= ~IFA_F_OPTIMISTIC; + + timeout = addrconf_timeout_fixup(cfg->valid_lft, HZ); + if (addrconf_finite_timeout(timeout)) { + expires = jiffies_to_clock_t(timeout * HZ); + cfg->valid_lft = timeout; + flags = RTF_EXPIRES; + } else { + expires = 0; + flags = 0; + cfg->ifa_flags |= IFA_F_PERMANENT; + } + + timeout = addrconf_timeout_fixup(cfg->preferred_lft, HZ); + if (addrconf_finite_timeout(timeout)) { + if (timeout == 0) + cfg->ifa_flags |= IFA_F_DEPRECATED; + cfg->preferred_lft = timeout; + } + + if (cfg->peer_pfx && + memcmp(&ifp->peer_addr, cfg->peer_pfx, sizeof(struct in6_addr))) { + if (!ipv6_addr_any(&ifp->peer_addr)) + cleanup_prefix_route(ifp, expires, true, true); + new_peer = true; + } + + spin_lock_bh(&ifp->lock); + was_managetempaddr = ifp->flags & IFA_F_MANAGETEMPADDR; + had_prefixroute = ifp->flags & IFA_F_PERMANENT && + !(ifp->flags & IFA_F_NOPREFIXROUTE); + ifp->flags &= ~(IFA_F_DEPRECATED | IFA_F_PERMANENT | IFA_F_NODAD | + IFA_F_HOMEADDRESS | IFA_F_MANAGETEMPADDR | + IFA_F_NOPREFIXROUTE); + ifp->flags |= cfg->ifa_flags; + ifp->tstamp = jiffies; + ifp->valid_lft = cfg->valid_lft; + ifp->prefered_lft = cfg->preferred_lft; + ifp->ifa_proto = cfg->ifa_proto; + + if (cfg->rt_priority && cfg->rt_priority != ifp->rt_priority) + ifp->rt_priority = cfg->rt_priority; + + if (new_peer) + ifp->peer_addr = *cfg->peer_pfx; + + spin_unlock_bh(&ifp->lock); + if (!(ifp->flags&IFA_F_TENTATIVE)) + ipv6_ifa_notify(0, ifp); + + if (!(cfg->ifa_flags & IFA_F_NOPREFIXROUTE)) { + int rc = -ENOENT; + + if (had_prefixroute) + rc = modify_prefix_route(ifp, expires, flags, false); + + /* prefix route could have been deleted; if so restore it */ + if (rc == -ENOENT) { + addrconf_prefix_route(&ifp->addr, ifp->prefix_len, + ifp->rt_priority, ifp->idev->dev, + expires, flags, GFP_KERNEL); + } + + if (had_prefixroute && !ipv6_addr_any(&ifp->peer_addr)) + rc = modify_prefix_route(ifp, expires, flags, true); + + if (rc == -ENOENT && !ipv6_addr_any(&ifp->peer_addr)) { + addrconf_prefix_route(&ifp->peer_addr, ifp->prefix_len, + ifp->rt_priority, ifp->idev->dev, + expires, flags, GFP_KERNEL); + } + } else if (had_prefixroute) { + enum cleanup_prefix_rt_t action; + unsigned long rt_expires; + + write_lock_bh(&ifp->idev->lock); + action = check_cleanup_prefix_route(ifp, &rt_expires); + write_unlock_bh(&ifp->idev->lock); + + if (action != CLEANUP_PREFIX_RT_NOP) { + cleanup_prefix_route(ifp, rt_expires, + action == CLEANUP_PREFIX_RT_DEL, false); + } + } + + if (was_managetempaddr || ifp->flags & IFA_F_MANAGETEMPADDR) { + if (was_managetempaddr && + !(ifp->flags & IFA_F_MANAGETEMPADDR)) { + cfg->valid_lft = 0; + cfg->preferred_lft = 0; + } + manage_tempaddrs(ifp->idev, ifp, cfg->valid_lft, + cfg->preferred_lft, 
!was_managetempaddr, + jiffies); + } + + addrconf_verify_rtnl(net); + + return 0; +} + +static int +inet6_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifaddrmsg *ifm; + struct nlattr *tb[IFA_MAX+1]; + struct in6_addr *peer_pfx; + struct inet6_ifaddr *ifa; + struct net_device *dev; + struct inet6_dev *idev; + struct ifa6_config cfg; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err < 0) + return err; + + memset(&cfg, 0, sizeof(cfg)); + + ifm = nlmsg_data(nlh); + cfg.pfx = extract_addr(tb[IFA_ADDRESS], tb[IFA_LOCAL], &peer_pfx); + if (!cfg.pfx) + return -EINVAL; + + cfg.peer_pfx = peer_pfx; + cfg.plen = ifm->ifa_prefixlen; + if (tb[IFA_RT_PRIORITY]) + cfg.rt_priority = nla_get_u32(tb[IFA_RT_PRIORITY]); + + if (tb[IFA_PROTO]) + cfg.ifa_proto = nla_get_u8(tb[IFA_PROTO]); + + cfg.valid_lft = INFINITY_LIFE_TIME; + cfg.preferred_lft = INFINITY_LIFE_TIME; + + if (tb[IFA_CACHEINFO]) { + struct ifa_cacheinfo *ci; + + ci = nla_data(tb[IFA_CACHEINFO]); + cfg.valid_lft = ci->ifa_valid; + cfg.preferred_lft = ci->ifa_prefered; + } + + dev = __dev_get_by_index(net, ifm->ifa_index); + if (!dev) + return -ENODEV; + + if (tb[IFA_FLAGS]) + cfg.ifa_flags = nla_get_u32(tb[IFA_FLAGS]); + else + cfg.ifa_flags = ifm->ifa_flags; + + /* We ignore other flags so far. */ + cfg.ifa_flags &= IFA_F_NODAD | IFA_F_HOMEADDRESS | + IFA_F_MANAGETEMPADDR | IFA_F_NOPREFIXROUTE | + IFA_F_MCAUTOJOIN | IFA_F_OPTIMISTIC; + + idev = ipv6_find_idev(dev); + if (IS_ERR(idev)) + return PTR_ERR(idev); + + if (!ipv6_allow_optimistic_dad(net, idev)) + cfg.ifa_flags &= ~IFA_F_OPTIMISTIC; + + if (cfg.ifa_flags & IFA_F_NODAD && + cfg.ifa_flags & IFA_F_OPTIMISTIC) { + NL_SET_ERR_MSG(extack, "IFA_F_NODAD and IFA_F_OPTIMISTIC are mutually exclusive"); + return -EINVAL; + } + + ifa = ipv6_get_ifaddr(net, cfg.pfx, dev, 1); + if (!ifa) { + /* + * It would be best to check for !NLM_F_CREATE here but + * userspace already relies on not having to provide this. 
+ */ + return inet6_addr_add(net, ifm->ifa_index, &cfg, extack); + } + + if (nlh->nlmsg_flags & NLM_F_EXCL || + !(nlh->nlmsg_flags & NLM_F_REPLACE)) + err = -EEXIST; + else + err = inet6_addr_modify(net, ifa, &cfg); + + in6_ifa_put(ifa); + + return err; +} + +static void put_ifaddrmsg(struct nlmsghdr *nlh, u8 prefixlen, u32 flags, + u8 scope, int ifindex) +{ + struct ifaddrmsg *ifm; + + ifm = nlmsg_data(nlh); + ifm->ifa_family = AF_INET6; + ifm->ifa_prefixlen = prefixlen; + ifm->ifa_flags = flags; + ifm->ifa_scope = scope; + ifm->ifa_index = ifindex; +} + +static int put_cacheinfo(struct sk_buff *skb, unsigned long cstamp, + unsigned long tstamp, u32 preferred, u32 valid) +{ + struct ifa_cacheinfo ci; + + ci.cstamp = cstamp_delta(cstamp); + ci.tstamp = cstamp_delta(tstamp); + ci.ifa_prefered = preferred; + ci.ifa_valid = valid; + + return nla_put(skb, IFA_CACHEINFO, sizeof(ci), &ci); +} + +static inline int rt_scope(int ifa_scope) +{ + if (ifa_scope & IFA_HOST) + return RT_SCOPE_HOST; + else if (ifa_scope & IFA_LINK) + return RT_SCOPE_LINK; + else if (ifa_scope & IFA_SITE) + return RT_SCOPE_SITE; + else + return RT_SCOPE_UNIVERSE; +} + +static inline int inet6_ifaddr_msgsize(void) +{ + return NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(16) /* IFA_LOCAL */ + + nla_total_size(16) /* IFA_ADDRESS */ + + nla_total_size(sizeof(struct ifa_cacheinfo)) + + nla_total_size(4) /* IFA_FLAGS */ + + nla_total_size(1) /* IFA_PROTO */ + + nla_total_size(4) /* IFA_RT_PRIORITY */; +} + +enum addr_type_t { + UNICAST_ADDR, + MULTICAST_ADDR, + ANYCAST_ADDR, +}; + +struct inet6_fill_args { + u32 portid; + u32 seq; + int event; + unsigned int flags; + int netnsid; + int ifindex; + enum addr_type_t type; +}; + +static int inet6_fill_ifaddr(struct sk_buff *skb, struct inet6_ifaddr *ifa, + struct inet6_fill_args *args) +{ + struct nlmsghdr *nlh; + u32 preferred, valid; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->event, + sizeof(struct ifaddrmsg), args->flags); + if (!nlh) + return -EMSGSIZE; + + put_ifaddrmsg(nlh, ifa->prefix_len, ifa->flags, rt_scope(ifa->scope), + ifa->idev->dev->ifindex); + + if (args->netnsid >= 0 && + nla_put_s32(skb, IFA_TARGET_NETNSID, args->netnsid)) + goto error; + + spin_lock_bh(&ifa->lock); + if (!((ifa->flags&IFA_F_PERMANENT) && + (ifa->prefered_lft == INFINITY_LIFE_TIME))) { + preferred = ifa->prefered_lft; + valid = ifa->valid_lft; + if (preferred != INFINITY_LIFE_TIME) { + long tval = (jiffies - ifa->tstamp)/HZ; + if (preferred > tval) + preferred -= tval; + else + preferred = 0; + if (valid != INFINITY_LIFE_TIME) { + if (valid > tval) + valid -= tval; + else + valid = 0; + } + } + } else { + preferred = INFINITY_LIFE_TIME; + valid = INFINITY_LIFE_TIME; + } + spin_unlock_bh(&ifa->lock); + + if (!ipv6_addr_any(&ifa->peer_addr)) { + if (nla_put_in6_addr(skb, IFA_LOCAL, &ifa->addr) < 0 || + nla_put_in6_addr(skb, IFA_ADDRESS, &ifa->peer_addr) < 0) + goto error; + } else + if (nla_put_in6_addr(skb, IFA_ADDRESS, &ifa->addr) < 0) + goto error; + + if (ifa->rt_priority && + nla_put_u32(skb, IFA_RT_PRIORITY, ifa->rt_priority)) + goto error; + + if (put_cacheinfo(skb, ifa->cstamp, ifa->tstamp, preferred, valid) < 0) + goto error; + + if (nla_put_u32(skb, IFA_FLAGS, ifa->flags) < 0) + goto error; + + if (ifa->ifa_proto && + nla_put_u8(skb, IFA_PROTO, ifa->ifa_proto)) + goto error; + + nlmsg_end(skb, nlh); + return 0; + +error: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int inet6_fill_ifmcaddr(struct sk_buff *skb, struct ifmcaddr6 *ifmca, + struct 
inet6_fill_args *args) +{ + struct nlmsghdr *nlh; + u8 scope = RT_SCOPE_UNIVERSE; + int ifindex = ifmca->idev->dev->ifindex; + + if (ipv6_addr_scope(&ifmca->mca_addr) & IFA_SITE) + scope = RT_SCOPE_SITE; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->event, + sizeof(struct ifaddrmsg), args->flags); + if (!nlh) + return -EMSGSIZE; + + if (args->netnsid >= 0 && + nla_put_s32(skb, IFA_TARGET_NETNSID, args->netnsid)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + put_ifaddrmsg(nlh, 128, IFA_F_PERMANENT, scope, ifindex); + if (nla_put_in6_addr(skb, IFA_MULTICAST, &ifmca->mca_addr) < 0 || + put_cacheinfo(skb, ifmca->mca_cstamp, ifmca->mca_tstamp, + INFINITY_LIFE_TIME, INFINITY_LIFE_TIME) < 0) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int inet6_fill_ifacaddr(struct sk_buff *skb, struct ifacaddr6 *ifaca, + struct inet6_fill_args *args) +{ + struct net_device *dev = fib6_info_nh_dev(ifaca->aca_rt); + int ifindex = dev ? dev->ifindex : 1; + struct nlmsghdr *nlh; + u8 scope = RT_SCOPE_UNIVERSE; + + if (ipv6_addr_scope(&ifaca->aca_addr) & IFA_SITE) + scope = RT_SCOPE_SITE; + + nlh = nlmsg_put(skb, args->portid, args->seq, args->event, + sizeof(struct ifaddrmsg), args->flags); + if (!nlh) + return -EMSGSIZE; + + if (args->netnsid >= 0 && + nla_put_s32(skb, IFA_TARGET_NETNSID, args->netnsid)) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + put_ifaddrmsg(nlh, 128, IFA_F_PERMANENT, scope, ifindex); + if (nla_put_in6_addr(skb, IFA_ANYCAST, &ifaca->aca_addr) < 0 || + put_cacheinfo(skb, ifaca->aca_cstamp, ifaca->aca_tstamp, + INFINITY_LIFE_TIME, INFINITY_LIFE_TIME) < 0) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +/* called with rcu_read_lock() */ +static int in6_dump_addrs(struct inet6_dev *idev, struct sk_buff *skb, + struct netlink_callback *cb, int s_ip_idx, + struct inet6_fill_args *fillargs) +{ + struct ifmcaddr6 *ifmca; + struct ifacaddr6 *ifaca; + int ip_idx = 0; + int err = 1; + + read_lock_bh(&idev->lock); + switch (fillargs->type) { + case UNICAST_ADDR: { + struct inet6_ifaddr *ifa; + fillargs->event = RTM_NEWADDR; + + /* unicast address incl. 
temp addr */ + list_for_each_entry(ifa, &idev->addr_list, if_list) { + if (ip_idx < s_ip_idx) + goto next; + err = inet6_fill_ifaddr(skb, ifa, fillargs); + if (err < 0) + break; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +next: + ip_idx++; + } + break; + } + case MULTICAST_ADDR: + read_unlock_bh(&idev->lock); + fillargs->event = RTM_GETMULTICAST; + + /* multicast address */ + for (ifmca = rtnl_dereference(idev->mc_list); + ifmca; + ifmca = rtnl_dereference(ifmca->next), ip_idx++) { + if (ip_idx < s_ip_idx) + continue; + err = inet6_fill_ifmcaddr(skb, ifmca, fillargs); + if (err < 0) + break; + } + read_lock_bh(&idev->lock); + break; + case ANYCAST_ADDR: + fillargs->event = RTM_GETANYCAST; + /* anycast address */ + for (ifaca = idev->ac_list; ifaca; + ifaca = ifaca->aca_next, ip_idx++) { + if (ip_idx < s_ip_idx) + continue; + err = inet6_fill_ifacaddr(skb, ifaca, fillargs); + if (err < 0) + break; + } + break; + default: + break; + } + read_unlock_bh(&idev->lock); + cb->args[2] = ip_idx; + return err; +} + +static int inet6_valid_dump_ifaddr_req(const struct nlmsghdr *nlh, + struct inet6_fill_args *fillargs, + struct net **tgt_net, struct sock *sk, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[IFA_MAX+1]; + struct ifaddrmsg *ifm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for address dump request"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for address dump request"); + return -EINVAL; + } + + fillargs->ifindex = ifm->ifa_index; + if (fillargs->ifindex) { + cb->answer_flags |= NLM_F_DUMP_FILTERED; + fillargs->flags |= NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= IFA_MAX; ++i) { + if (!tb[i]) + continue; + + if (i == IFA_TARGET_NETNSID) { + struct net *net; + + fillargs->netnsid = nla_get_s32(tb[i]); + net = rtnl_get_net_ns_capable(sk, fillargs->netnsid); + if (IS_ERR(net)) { + fillargs->netnsid = -1; + NL_SET_ERR_MSG_MOD(extack, "Invalid target network namespace id"); + return PTR_ERR(net); + } + *tgt_net = net; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} + +static int inet6_dump_addr(struct sk_buff *skb, struct netlink_callback *cb, + enum addr_type_t type) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct inet6_fill_args fillargs = { + .portid = NETLINK_CB(cb->skb).portid, + .seq = cb->nlh->nlmsg_seq, + .flags = NLM_F_MULTI, + .netnsid = -1, + .type = type, + }; + struct net *tgt_net = sock_net(skb->sk); + int idx, s_idx, s_ip_idx; + int h, s_h; + struct net_device *dev; + struct inet6_dev *idev; + struct hlist_head *head; + int err = 0; + + s_h = cb->args[0]; + s_idx = idx = cb->args[1]; + s_ip_idx = cb->args[2]; + + if (cb->strict_check) { + err = inet6_valid_dump_ifaddr_req(nlh, &fillargs, &tgt_net, + skb->sk, cb); + if (err < 0) + goto put_tgt_net; + + err = 0; + if (fillargs.ifindex) { + dev = __dev_get_by_index(tgt_net, fillargs.ifindex); + if (!dev) { + err = -ENODEV; + goto put_tgt_net; + } + idev = __in6_dev_get(dev); + if (idev) { + err = in6_dump_addrs(idev, skb, cb, s_ip_idx, + &fillargs); + if (err > 0) + err = 0; + } + goto put_tgt_net; + } + } + + rcu_read_lock(); + cb->seq = 
atomic_read(&tgt_net->ipv6.dev_addr_genid) ^ tgt_net->dev_base_seq; + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &tgt_net->dev_index_head[h]; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + if (h > s_h || idx > s_idx) + s_ip_idx = 0; + idev = __in6_dev_get(dev); + if (!idev) + goto cont; + + if (in6_dump_addrs(idev, skb, cb, s_ip_idx, + &fillargs) < 0) + goto done; +cont: + idx++; + } + } +done: + rcu_read_unlock(); + cb->args[0] = h; + cb->args[1] = idx; +put_tgt_net: + if (fillargs.netnsid >= 0) + put_net(tgt_net); + + return skb->len ? : err; +} + +static int inet6_dump_ifaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + enum addr_type_t type = UNICAST_ADDR; + + return inet6_dump_addr(skb, cb, type); +} + +static int inet6_dump_ifmcaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + enum addr_type_t type = MULTICAST_ADDR; + + return inet6_dump_addr(skb, cb, type); +} + + +static int inet6_dump_ifacaddr(struct sk_buff *skb, struct netlink_callback *cb) +{ + enum addr_type_t type = ANYCAST_ADDR; + + return inet6_dump_addr(skb, cb, type); +} + +static int inet6_rtm_valid_getaddr_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifaddrmsg *ifm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for get address request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + + ifm = nlmsg_data(nlh); + if (ifm->ifa_prefixlen || ifm->ifa_flags || ifm->ifa_scope) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get address request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_ipv6_policy, extack); + if (err) + return err; + + for (i = 0; i <= IFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFA_TARGET_NETNSID: + case IFA_ADDRESS: + case IFA_LOCAL: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get address request"); + return -EINVAL; + } + } + + return 0; +} + +static int inet6_rtm_getaddr(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *tgt_net = sock_net(in_skb->sk); + struct inet6_fill_args fillargs = { + .portid = NETLINK_CB(in_skb).portid, + .seq = nlh->nlmsg_seq, + .event = RTM_NEWADDR, + .flags = 0, + .netnsid = -1, + }; + struct ifaddrmsg *ifm; + struct nlattr *tb[IFA_MAX+1]; + struct in6_addr *addr = NULL, *peer; + struct net_device *dev = NULL; + struct inet6_ifaddr *ifa; + struct sk_buff *skb; + int err; + + err = inet6_rtm_valid_getaddr_req(in_skb, nlh, tb, extack); + if (err < 0) + return err; + + if (tb[IFA_TARGET_NETNSID]) { + fillargs.netnsid = nla_get_s32(tb[IFA_TARGET_NETNSID]); + + tgt_net = rtnl_get_net_ns_capable(NETLINK_CB(in_skb).sk, + fillargs.netnsid); + if (IS_ERR(tgt_net)) + return PTR_ERR(tgt_net); + } + + addr = extract_addr(tb[IFA_ADDRESS], tb[IFA_LOCAL], &peer); + if (!addr) + return -EINVAL; + + ifm = nlmsg_data(nlh); + if (ifm->ifa_index) + dev = dev_get_by_index(tgt_net, ifm->ifa_index); + + ifa = ipv6_get_ifaddr(tgt_net, addr, dev, 1); + if (!ifa) { + err = -EADDRNOTAVAIL; + goto errout; + } + + skb = nlmsg_new(inet6_ifaddr_msgsize(), GFP_KERNEL); + if (!skb) { + err = -ENOBUFS; + goto errout_ifa; + } + + err = inet6_fill_ifaddr(skb, ifa, &fillargs); + if (err < 0) { + /* 
-EMSGSIZE implies BUG in inet6_ifaddr_msgsize() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout_ifa; + } + err = rtnl_unicast(skb, tgt_net, NETLINK_CB(in_skb).portid); +errout_ifa: + in6_ifa_put(ifa); +errout: + dev_put(dev); + if (fillargs.netnsid >= 0) + put_net(tgt_net); + + return err; +} + +static void inet6_ifa_notify(int event, struct inet6_ifaddr *ifa) +{ + struct sk_buff *skb; + struct net *net = dev_net(ifa->idev->dev); + struct inet6_fill_args fillargs = { + .portid = 0, + .seq = 0, + .event = event, + .flags = 0, + .netnsid = -1, + }; + int err = -ENOBUFS; + + skb = nlmsg_new(inet6_ifaddr_msgsize(), GFP_ATOMIC); + if (!skb) + goto errout; + + err = inet6_fill_ifaddr(skb, ifa, &fillargs); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet6_ifaddr_msgsize() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_IFADDR, NULL, GFP_ATOMIC); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_IFADDR, err); +} + +static inline void ipv6_store_devconf(struct ipv6_devconf *cnf, + __s32 *array, int bytes) +{ + BUG_ON(bytes < (DEVCONF_MAX * 4)); + + memset(array, 0, bytes); + array[DEVCONF_FORWARDING] = cnf->forwarding; + array[DEVCONF_HOPLIMIT] = cnf->hop_limit; + array[DEVCONF_MTU6] = cnf->mtu6; + array[DEVCONF_ACCEPT_RA] = cnf->accept_ra; + array[DEVCONF_ACCEPT_REDIRECTS] = cnf->accept_redirects; + array[DEVCONF_AUTOCONF] = cnf->autoconf; + array[DEVCONF_DAD_TRANSMITS] = cnf->dad_transmits; + array[DEVCONF_RTR_SOLICITS] = cnf->rtr_solicits; + array[DEVCONF_RTR_SOLICIT_INTERVAL] = + jiffies_to_msecs(cnf->rtr_solicit_interval); + array[DEVCONF_RTR_SOLICIT_MAX_INTERVAL] = + jiffies_to_msecs(cnf->rtr_solicit_max_interval); + array[DEVCONF_RTR_SOLICIT_DELAY] = + jiffies_to_msecs(cnf->rtr_solicit_delay); + array[DEVCONF_FORCE_MLD_VERSION] = cnf->force_mld_version; + array[DEVCONF_MLDV1_UNSOLICITED_REPORT_INTERVAL] = + jiffies_to_msecs(cnf->mldv1_unsolicited_report_interval); + array[DEVCONF_MLDV2_UNSOLICITED_REPORT_INTERVAL] = + jiffies_to_msecs(cnf->mldv2_unsolicited_report_interval); + array[DEVCONF_USE_TEMPADDR] = cnf->use_tempaddr; + array[DEVCONF_TEMP_VALID_LFT] = cnf->temp_valid_lft; + array[DEVCONF_TEMP_PREFERED_LFT] = cnf->temp_prefered_lft; + array[DEVCONF_REGEN_MAX_RETRY] = cnf->regen_max_retry; + array[DEVCONF_MAX_DESYNC_FACTOR] = cnf->max_desync_factor; + array[DEVCONF_MAX_ADDRESSES] = cnf->max_addresses; + array[DEVCONF_ACCEPT_RA_DEFRTR] = cnf->accept_ra_defrtr; + array[DEVCONF_RA_DEFRTR_METRIC] = cnf->ra_defrtr_metric; + array[DEVCONF_ACCEPT_RA_MIN_HOP_LIMIT] = cnf->accept_ra_min_hop_limit; + array[DEVCONF_ACCEPT_RA_PINFO] = cnf->accept_ra_pinfo; +#ifdef CONFIG_IPV6_ROUTER_PREF + array[DEVCONF_ACCEPT_RA_RTR_PREF] = cnf->accept_ra_rtr_pref; + array[DEVCONF_RTR_PROBE_INTERVAL] = + jiffies_to_msecs(cnf->rtr_probe_interval); +#ifdef CONFIG_IPV6_ROUTE_INFO + array[DEVCONF_ACCEPT_RA_RT_INFO_MIN_PLEN] = cnf->accept_ra_rt_info_min_plen; + array[DEVCONF_ACCEPT_RA_RT_INFO_MAX_PLEN] = cnf->accept_ra_rt_info_max_plen; +#endif +#endif + array[DEVCONF_PROXY_NDP] = cnf->proxy_ndp; + array[DEVCONF_ACCEPT_SOURCE_ROUTE] = cnf->accept_source_route; +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + array[DEVCONF_OPTIMISTIC_DAD] = cnf->optimistic_dad; + array[DEVCONF_USE_OPTIMISTIC] = cnf->use_optimistic; +#endif +#ifdef CONFIG_IPV6_MROUTE + array[DEVCONF_MC_FORWARDING] = atomic_read(&cnf->mc_forwarding); +#endif + array[DEVCONF_DISABLE_IPV6] = cnf->disable_ipv6; + array[DEVCONF_ACCEPT_DAD] = cnf->accept_dad; + 
array[DEVCONF_FORCE_TLLAO] = cnf->force_tllao; + array[DEVCONF_NDISC_NOTIFY] = cnf->ndisc_notify; + array[DEVCONF_SUPPRESS_FRAG_NDISC] = cnf->suppress_frag_ndisc; + array[DEVCONF_ACCEPT_RA_FROM_LOCAL] = cnf->accept_ra_from_local; + array[DEVCONF_ACCEPT_RA_MTU] = cnf->accept_ra_mtu; + array[DEVCONF_IGNORE_ROUTES_WITH_LINKDOWN] = cnf->ignore_routes_with_linkdown; + /* we omit DEVCONF_STABLE_SECRET for now */ + array[DEVCONF_USE_OIF_ADDRS_ONLY] = cnf->use_oif_addrs_only; + array[DEVCONF_DROP_UNICAST_IN_L2_MULTICAST] = cnf->drop_unicast_in_l2_multicast; + array[DEVCONF_DROP_UNSOLICITED_NA] = cnf->drop_unsolicited_na; + array[DEVCONF_KEEP_ADDR_ON_DOWN] = cnf->keep_addr_on_down; + array[DEVCONF_SEG6_ENABLED] = cnf->seg6_enabled; +#ifdef CONFIG_IPV6_SEG6_HMAC + array[DEVCONF_SEG6_REQUIRE_HMAC] = cnf->seg6_require_hmac; +#endif + array[DEVCONF_ENHANCED_DAD] = cnf->enhanced_dad; + array[DEVCONF_ADDR_GEN_MODE] = cnf->addr_gen_mode; + array[DEVCONF_DISABLE_POLICY] = cnf->disable_policy; + array[DEVCONF_NDISC_TCLASS] = cnf->ndisc_tclass; + array[DEVCONF_RPL_SEG_ENABLED] = cnf->rpl_seg_enabled; + array[DEVCONF_IOAM6_ENABLED] = cnf->ioam6_enabled; + array[DEVCONF_IOAM6_ID] = cnf->ioam6_id; + array[DEVCONF_IOAM6_ID_WIDE] = cnf->ioam6_id_wide; + array[DEVCONF_NDISC_EVICT_NOCARRIER] = cnf->ndisc_evict_nocarrier; + array[DEVCONF_ACCEPT_UNTRACKED_NA] = cnf->accept_untracked_na; + array[DEVCONF_ACCEPT_RA_MIN_LFT] = cnf->accept_ra_min_lft; +} + +static inline size_t inet6_ifla6_size(void) +{ + return nla_total_size(4) /* IFLA_INET6_FLAGS */ + + nla_total_size(sizeof(struct ifla_cacheinfo)) + + nla_total_size(DEVCONF_MAX * 4) /* IFLA_INET6_CONF */ + + nla_total_size(IPSTATS_MIB_MAX * 8) /* IFLA_INET6_STATS */ + + nla_total_size(ICMP6_MIB_MAX * 8) /* IFLA_INET6_ICMP6STATS */ + + nla_total_size(sizeof(struct in6_addr)) /* IFLA_INET6_TOKEN */ + + nla_total_size(1) /* IFLA_INET6_ADDR_GEN_MODE */ + + nla_total_size(4) /* IFLA_INET6_RA_MTU */ + + 0; +} + +static inline size_t inet6_if_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ifinfomsg)) + + nla_total_size(IFNAMSIZ) /* IFLA_IFNAME */ + + nla_total_size(MAX_ADDR_LEN) /* IFLA_ADDRESS */ + + nla_total_size(4) /* IFLA_MTU */ + + nla_total_size(4) /* IFLA_LINK */ + + nla_total_size(1) /* IFLA_OPERSTATE */ + + nla_total_size(inet6_ifla6_size()); /* IFLA_PROTINFO */ +} + +static inline void __snmp6_fill_statsdev(u64 *stats, atomic_long_t *mib, + int bytes) +{ + int i; + int pad = bytes - sizeof(u64) * ICMP6_MIB_MAX; + BUG_ON(pad < 0); + + /* Use put_unaligned() because stats may not be aligned for u64. 
*/ + put_unaligned(ICMP6_MIB_MAX, &stats[0]); + for (i = 1; i < ICMP6_MIB_MAX; i++) + put_unaligned(atomic_long_read(&mib[i]), &stats[i]); + + memset(&stats[ICMP6_MIB_MAX], 0, pad); +} + +static inline void __snmp6_fill_stats64(u64 *stats, void __percpu *mib, + int bytes, size_t syncpoff) +{ + int i, c; + u64 buff[IPSTATS_MIB_MAX]; + int pad = bytes - sizeof(u64) * IPSTATS_MIB_MAX; + + BUG_ON(pad < 0); + + memset(buff, 0, sizeof(buff)); + buff[0] = IPSTATS_MIB_MAX; + + for_each_possible_cpu(c) { + for (i = 1; i < IPSTATS_MIB_MAX; i++) + buff[i] += snmp_get_cpu_field64(mib, c, i, syncpoff); + } + + memcpy(stats, buff, IPSTATS_MIB_MAX * sizeof(u64)); + memset(&stats[IPSTATS_MIB_MAX], 0, pad); +} + +static void snmp6_fill_stats(u64 *stats, struct inet6_dev *idev, int attrtype, + int bytes) +{ + switch (attrtype) { + case IFLA_INET6_STATS: + __snmp6_fill_stats64(stats, idev->stats.ipv6, bytes, + offsetof(struct ipstats_mib, syncp)); + break; + case IFLA_INET6_ICMP6STATS: + __snmp6_fill_statsdev(stats, idev->stats.icmpv6dev->mibs, bytes); + break; + } +} + +static int inet6_fill_ifla6_attrs(struct sk_buff *skb, struct inet6_dev *idev, + u32 ext_filter_mask) +{ + struct nlattr *nla; + struct ifla_cacheinfo ci; + + if (nla_put_u32(skb, IFLA_INET6_FLAGS, idev->if_flags)) + goto nla_put_failure; + ci.max_reasm_len = IPV6_MAXPLEN; + ci.tstamp = cstamp_delta(idev->tstamp); + ci.reachable_time = jiffies_to_msecs(idev->nd_parms->reachable_time); + ci.retrans_time = jiffies_to_msecs(NEIGH_VAR(idev->nd_parms, RETRANS_TIME)); + if (nla_put(skb, IFLA_INET6_CACHEINFO, sizeof(ci), &ci)) + goto nla_put_failure; + nla = nla_reserve(skb, IFLA_INET6_CONF, DEVCONF_MAX * sizeof(s32)); + if (!nla) + goto nla_put_failure; + ipv6_store_devconf(&idev->cnf, nla_data(nla), nla_len(nla)); + + /* XXX - MC not implemented */ + + if (ext_filter_mask & RTEXT_FILTER_SKIP_STATS) + return 0; + + nla = nla_reserve(skb, IFLA_INET6_STATS, IPSTATS_MIB_MAX * sizeof(u64)); + if (!nla) + goto nla_put_failure; + snmp6_fill_stats(nla_data(nla), idev, IFLA_INET6_STATS, nla_len(nla)); + + nla = nla_reserve(skb, IFLA_INET6_ICMP6STATS, ICMP6_MIB_MAX * sizeof(u64)); + if (!nla) + goto nla_put_failure; + snmp6_fill_stats(nla_data(nla), idev, IFLA_INET6_ICMP6STATS, nla_len(nla)); + + nla = nla_reserve(skb, IFLA_INET6_TOKEN, sizeof(struct in6_addr)); + if (!nla) + goto nla_put_failure; + read_lock_bh(&idev->lock); + memcpy(nla_data(nla), idev->token.s6_addr, nla_len(nla)); + read_unlock_bh(&idev->lock); + + if (nla_put_u8(skb, IFLA_INET6_ADDR_GEN_MODE, idev->cnf.addr_gen_mode)) + goto nla_put_failure; + + if (idev->ra_mtu && + nla_put_u32(skb, IFLA_INET6_RA_MTU, idev->ra_mtu)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static size_t inet6_get_link_af_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + if (!__in6_dev_get(dev)) + return 0; + + return inet6_ifla6_size(); +} + +static int inet6_fill_link_af(struct sk_buff *skb, const struct net_device *dev, + u32 ext_filter_mask) +{ + struct inet6_dev *idev = __in6_dev_get(dev); + + if (!idev) + return -ENODATA; + + if (inet6_fill_ifla6_attrs(skb, idev, ext_filter_mask) < 0) + return -EMSGSIZE; + + return 0; +} + +static int inet6_set_iftoken(struct inet6_dev *idev, struct in6_addr *token, + struct netlink_ext_ack *extack) +{ + struct inet6_ifaddr *ifp; + struct net_device *dev = idev->dev; + bool clear_token, update_rs = false; + struct in6_addr ll_addr; + + ASSERT_RTNL(); + + if (!token) + return -EINVAL; + + if (dev->flags & IFF_LOOPBACK) { + 
NL_SET_ERR_MSG_MOD(extack, "Device is loopback"); + return -EINVAL; + } + + if (dev->flags & IFF_NOARP) { + NL_SET_ERR_MSG_MOD(extack, + "Device does not do neighbour discovery"); + return -EINVAL; + } + + if (!ipv6_accept_ra(idev)) { + NL_SET_ERR_MSG_MOD(extack, + "Router advertisement is disabled on device"); + return -EINVAL; + } + + if (idev->cnf.rtr_solicits == 0) { + NL_SET_ERR_MSG(extack, + "Router solicitation is disabled on device"); + return -EINVAL; + } + + write_lock_bh(&idev->lock); + + BUILD_BUG_ON(sizeof(token->s6_addr) != 16); + memcpy(idev->token.s6_addr + 8, token->s6_addr + 8, 8); + + write_unlock_bh(&idev->lock); + + clear_token = ipv6_addr_any(token); + if (clear_token) + goto update_lft; + + if (!idev->dead && (idev->if_flags & IF_READY) && + !ipv6_get_lladdr(dev, &ll_addr, IFA_F_TENTATIVE | + IFA_F_OPTIMISTIC)) { + /* If we're not ready, then normal ifup will take care + * of this. Otherwise, we need to request our rs here. + */ + ndisc_send_rs(dev, &ll_addr, &in6addr_linklocal_allrouters); + update_rs = true; + } + +update_lft: + write_lock_bh(&idev->lock); + + if (update_rs) { + idev->if_flags |= IF_RS_SENT; + idev->rs_interval = rfc3315_s14_backoff_init( + idev->cnf.rtr_solicit_interval); + idev->rs_probes = 1; + addrconf_mod_rs_timer(idev, idev->rs_interval); + } + + /* Well, that's kinda nasty ... */ + list_for_each_entry(ifp, &idev->addr_list, if_list) { + spin_lock(&ifp->lock); + if (ifp->tokenized) { + ifp->valid_lft = 0; + ifp->prefered_lft = 0; + } + spin_unlock(&ifp->lock); + } + + write_unlock_bh(&idev->lock); + inet6_ifinfo_notify(RTM_NEWLINK, idev); + addrconf_verify_rtnl(dev_net(dev)); + return 0; +} + +static const struct nla_policy inet6_af_policy[IFLA_INET6_MAX + 1] = { + [IFLA_INET6_ADDR_GEN_MODE] = { .type = NLA_U8 }, + [IFLA_INET6_TOKEN] = { .len = sizeof(struct in6_addr) }, + [IFLA_INET6_RA_MTU] = { .type = NLA_REJECT, + .reject_message = + "IFLA_INET6_RA_MTU can not be set" }, +}; + +static int check_addr_gen_mode(int mode) +{ + if (mode != IN6_ADDR_GEN_MODE_EUI64 && + mode != IN6_ADDR_GEN_MODE_NONE && + mode != IN6_ADDR_GEN_MODE_STABLE_PRIVACY && + mode != IN6_ADDR_GEN_MODE_RANDOM) + return -EINVAL; + return 1; +} + +static int check_stable_privacy(struct inet6_dev *idev, struct net *net, + int mode) +{ + if (mode == IN6_ADDR_GEN_MODE_STABLE_PRIVACY && + !idev->cnf.stable_secret.initialized && + !net->ipv6.devconf_dflt->stable_secret.initialized) + return -EINVAL; + return 1; +} + +static int inet6_validate_link_af(const struct net_device *dev, + const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_INET6_MAX + 1]; + struct inet6_dev *idev = NULL; + int err; + + if (dev) { + idev = __in6_dev_get(dev); + if (!idev) + return -EAFNOSUPPORT; + } + + err = nla_parse_nested_deprecated(tb, IFLA_INET6_MAX, nla, + inet6_af_policy, extack); + if (err) + return err; + + if (!tb[IFLA_INET6_TOKEN] && !tb[IFLA_INET6_ADDR_GEN_MODE]) + return -EINVAL; + + if (tb[IFLA_INET6_ADDR_GEN_MODE]) { + u8 mode = nla_get_u8(tb[IFLA_INET6_ADDR_GEN_MODE]); + + if (check_addr_gen_mode(mode) < 0) + return -EINVAL; + if (dev && check_stable_privacy(idev, dev_net(dev), mode) < 0) + return -EINVAL; + } + + return 0; +} + +static int inet6_set_link_af(struct net_device *dev, const struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + struct inet6_dev *idev = __in6_dev_get(dev); + struct nlattr *tb[IFLA_INET6_MAX + 1]; + int err; + + if (!idev) + return -EAFNOSUPPORT; + + if (nla_parse_nested_deprecated(tb, IFLA_INET6_MAX, nla, NULL, NULL) 
< 0) + return -EINVAL; + + if (tb[IFLA_INET6_TOKEN]) { + err = inet6_set_iftoken(idev, nla_data(tb[IFLA_INET6_TOKEN]), + extack); + if (err) + return err; + } + + if (tb[IFLA_INET6_ADDR_GEN_MODE]) { + u8 mode = nla_get_u8(tb[IFLA_INET6_ADDR_GEN_MODE]); + + idev->cnf.addr_gen_mode = mode; + } + + return 0; +} + +static int inet6_fill_ifinfo(struct sk_buff *skb, struct inet6_dev *idev, + u32 portid, u32 seq, int event, unsigned int flags) +{ + struct net_device *dev = idev->dev; + struct ifinfomsg *hdr; + struct nlmsghdr *nlh; + void *protoinfo; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ifi_family = AF_INET6; + hdr->__ifi_pad = 0; + hdr->ifi_type = dev->type; + hdr->ifi_index = dev->ifindex; + hdr->ifi_flags = dev_get_flags(dev); + hdr->ifi_change = 0; + + if (nla_put_string(skb, IFLA_IFNAME, dev->name) || + (dev->addr_len && + nla_put(skb, IFLA_ADDRESS, dev->addr_len, dev->dev_addr)) || + nla_put_u32(skb, IFLA_MTU, dev->mtu) || + (dev->ifindex != dev_get_iflink(dev) && + nla_put_u32(skb, IFLA_LINK, dev_get_iflink(dev))) || + nla_put_u8(skb, IFLA_OPERSTATE, + netif_running(dev) ? dev->operstate : IF_OPER_DOWN)) + goto nla_put_failure; + protoinfo = nla_nest_start_noflag(skb, IFLA_PROTINFO); + if (!protoinfo) + goto nla_put_failure; + + if (inet6_fill_ifla6_attrs(skb, idev, 0) < 0) + goto nla_put_failure; + + nla_nest_end(skb, protoinfo); + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int inet6_valid_dump_ifinfo(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifinfomsg *ifm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for link dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header"); + return -EINVAL; + } + + ifm = nlmsg_data(nlh); + if (ifm->__ifi_pad || ifm->ifi_type || ifm->ifi_flags || + ifm->ifi_change || ifm->ifi_index) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for dump request"); + return -EINVAL; + } + + return 0; +} + +static int inet6_dump_ifinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int h, s_h; + int idx = 0, s_idx; + struct net_device *dev; + struct inet6_dev *idev; + struct hlist_head *head; + + /* only requests using strict checking can pass data to + * influence the dump + */ + if (cb->strict_check) { + int err = inet6_valid_dump_ifinfo(cb->nlh, cb->extack); + + if (err < 0) + return err; + } + + s_h = cb->args[0]; + s_idx = cb->args[1]; + + rcu_read_lock(); + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + idev = __in6_dev_get(dev); + if (!idev) + goto cont; + if (inet6_fill_ifinfo(skb, idev, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWLINK, NLM_F_MULTI) < 0) + goto out; +cont: + idx++; + } + } +out: + rcu_read_unlock(); + cb->args[1] = idx; + cb->args[0] = h; + + return skb->len; +} + +void inet6_ifinfo_notify(int event, struct inet6_dev *idev) +{ + struct sk_buff *skb; + struct net *net = dev_net(idev->dev); + int err = -ENOBUFS; + + skb = nlmsg_new(inet6_if_nlmsg_size(), GFP_ATOMIC); + if (!skb) + goto errout; + + err = inet6_fill_ifinfo(skb, idev, 0, 0, event, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in 
inet6_if_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_IFINFO, NULL, GFP_ATOMIC); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_IFINFO, err); +} + +static inline size_t inet6_prefix_nlmsg_size(void) +{ + return NLMSG_ALIGN(sizeof(struct prefixmsg)) + + nla_total_size(sizeof(struct in6_addr)) + + nla_total_size(sizeof(struct prefix_cacheinfo)); +} + +static int inet6_fill_prefix(struct sk_buff *skb, struct inet6_dev *idev, + struct prefix_info *pinfo, u32 portid, u32 seq, + int event, unsigned int flags) +{ + struct prefixmsg *pmsg; + struct nlmsghdr *nlh; + struct prefix_cacheinfo ci; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*pmsg), flags); + if (!nlh) + return -EMSGSIZE; + + pmsg = nlmsg_data(nlh); + pmsg->prefix_family = AF_INET6; + pmsg->prefix_pad1 = 0; + pmsg->prefix_pad2 = 0; + pmsg->prefix_ifindex = idev->dev->ifindex; + pmsg->prefix_len = pinfo->prefix_len; + pmsg->prefix_type = pinfo->type; + pmsg->prefix_pad3 = 0; + pmsg->prefix_flags = pinfo->flags; + + if (nla_put(skb, PREFIX_ADDRESS, sizeof(pinfo->prefix), &pinfo->prefix)) + goto nla_put_failure; + ci.preferred_time = ntohl(pinfo->prefered); + ci.valid_time = ntohl(pinfo->valid); + if (nla_put(skb, PREFIX_CACHEINFO, sizeof(ci), &ci)) + goto nla_put_failure; + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static void inet6_prefix_notify(int event, struct inet6_dev *idev, + struct prefix_info *pinfo) +{ + struct sk_buff *skb; + struct net *net = dev_net(idev->dev); + int err = -ENOBUFS; + + skb = nlmsg_new(inet6_prefix_nlmsg_size(), GFP_ATOMIC); + if (!skb) + goto errout; + + err = inet6_fill_prefix(skb, idev, pinfo, 0, 0, event, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in inet6_prefix_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_PREFIX, NULL, GFP_ATOMIC); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_PREFIX, err); +} + +static void __ipv6_ifa_notify(int event, struct inet6_ifaddr *ifp) +{ + struct net *net = dev_net(ifp->idev->dev); + + if (event) + ASSERT_RTNL(); + + inet6_ifa_notify(event ? : RTM_NEWADDR, ifp); + + switch (event) { + case RTM_NEWADDR: + /* + * If the address was optimistic we inserted the route at the + * start of our DAD process, so we don't need to do it again. + * If the device was taken down in the middle of the DAD + * cycle there is a race where we could get here without a + * host route, so nothing to insert. That will be fixed when + * the device is brought up. 
+ */ + if (ifp->rt && !rcu_access_pointer(ifp->rt->fib6_node)) { + ip6_ins_rt(net, ifp->rt); + } else if (!ifp->rt && (ifp->idev->dev->flags & IFF_UP)) { + pr_warn("BUG: Address %pI6c on device %s is missing its host route.\n", + &ifp->addr, ifp->idev->dev->name); + } + + if (ifp->idev->cnf.forwarding) + addrconf_join_anycast(ifp); + if (!ipv6_addr_any(&ifp->peer_addr)) + addrconf_prefix_route(&ifp->peer_addr, 128, + ifp->rt_priority, ifp->idev->dev, + 0, 0, GFP_ATOMIC); + break; + case RTM_DELADDR: + if (ifp->idev->cnf.forwarding) + addrconf_leave_anycast(ifp); + addrconf_leave_solict(ifp->idev, &ifp->addr); + if (!ipv6_addr_any(&ifp->peer_addr)) { + struct fib6_info *rt; + + rt = addrconf_get_prefix_route(&ifp->peer_addr, 128, + ifp->idev->dev, 0, 0, + false); + if (rt) + ip6_del_rt(net, rt, false); + } + if (ifp->rt) { + ip6_del_rt(net, ifp->rt, false); + ifp->rt = NULL; + } + rt_genid_bump_ipv6(net); + break; + } + atomic_inc(&net->ipv6.dev_addr_genid); +} + +static void ipv6_ifa_notify(int event, struct inet6_ifaddr *ifp) +{ + if (likely(ifp->idev->dead == 0)) + __ipv6_ifa_notify(event, ifp); +} + +#ifdef CONFIG_SYSCTL + +static int addrconf_sysctl_forward(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + loff_t pos = *ppos; + struct ctl_table lctl; + int ret; + + /* + * ctl->data points to idev->cnf.forwarding, we should + * not modify it until we get the rtnl lock. + */ + lctl = *ctl; + lctl.data = &val; + + ret = proc_dointvec(&lctl, write, buffer, lenp, ppos); + + if (write) + ret = addrconf_fixup_forwarding(ctl, valp, val); + if (ret) + *ppos = pos; + return ret; +} + +static int addrconf_sysctl_mtu(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct inet6_dev *idev = ctl->extra1; + int min_mtu = IPV6_MIN_MTU; + struct ctl_table lctl; + + lctl = *ctl; + lctl.extra1 = &min_mtu; + lctl.extra2 = idev ? 
&idev->dev->mtu : NULL; + + return proc_dointvec_minmax(&lctl, write, buffer, lenp, ppos); +} + +static void dev_disable_change(struct inet6_dev *idev) +{ + struct netdev_notifier_info info; + + if (!idev || !idev->dev) + return; + + netdev_notifier_info_init(&info, idev->dev); + if (idev->cnf.disable_ipv6) + addrconf_notify(NULL, NETDEV_DOWN, &info); + else + addrconf_notify(NULL, NETDEV_UP, &info); +} + +static void addrconf_disable_change(struct net *net, __s32 newf) +{ + struct net_device *dev; + struct inet6_dev *idev; + + for_each_netdev(net, dev) { + idev = __in6_dev_get(dev); + if (idev) { + int changed = (!idev->cnf.disable_ipv6) ^ (!newf); + idev->cnf.disable_ipv6 = newf; + if (changed) + dev_disable_change(idev); + } + } +} + +static int addrconf_disable_ipv6(struct ctl_table *table, int *p, int newf) +{ + struct net *net; + int old; + + if (!rtnl_trylock()) + return restart_syscall(); + + net = (struct net *)table->extra2; + old = *p; + *p = newf; + + if (p == &net->ipv6.devconf_dflt->disable_ipv6) { + rtnl_unlock(); + return 0; + } + + if (p == &net->ipv6.devconf_all->disable_ipv6) { + net->ipv6.devconf_dflt->disable_ipv6 = newf; + addrconf_disable_change(net, newf); + } else if ((!newf) ^ (!old)) + dev_disable_change((struct inet6_dev *)table->extra1); + + rtnl_unlock(); + return 0; +} + +static int addrconf_sysctl_disable(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + loff_t pos = *ppos; + struct ctl_table lctl; + int ret; + + /* + * ctl->data points to idev->cnf.disable_ipv6, we should + * not modify it until we get the rtnl lock. + */ + lctl = *ctl; + lctl.data = &val; + + ret = proc_dointvec(&lctl, write, buffer, lenp, ppos); + + if (write) + ret = addrconf_disable_ipv6(ctl, valp, val); + if (ret) + *ppos = pos; + return ret; +} + +static int addrconf_sysctl_proxy_ndp(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int ret; + int old, new; + + old = *valp; + ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + new = *valp; + + if (write && old != new) { + struct net *net = ctl->extra2; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (valp == &net->ipv6.devconf_dflt->proxy_ndp) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_PROXY_NEIGH, + NETCONFA_IFINDEX_DEFAULT, + net->ipv6.devconf_dflt); + else if (valp == &net->ipv6.devconf_all->proxy_ndp) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_PROXY_NEIGH, + NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all); + else { + struct inet6_dev *idev = ctl->extra1; + + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_PROXY_NEIGH, + idev->dev->ifindex, + &idev->cnf); + } + rtnl_unlock(); + } + + return ret; +} + +static int addrconf_sysctl_addr_gen_mode(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + int ret = 0; + u32 new_val; + struct inet6_dev *idev = (struct inet6_dev *)ctl->extra1; + struct net *net = (struct net *)ctl->extra2; + struct ctl_table tmp = { + .data = &new_val, + .maxlen = sizeof(new_val), + .mode = ctl->mode, + }; + + if (!rtnl_trylock()) + return restart_syscall(); + + new_val = *((u32 *)ctl->data); + + ret = proc_douintvec(&tmp, write, buffer, lenp, ppos); + if (ret != 0) + goto out; + + if (write) { + if (check_addr_gen_mode(new_val) < 0) { + ret = -EINVAL; + goto out; + } + + if (idev) { + if (check_stable_privacy(idev, net, new_val) < 0) { + ret = -EINVAL; + goto out; + } + + if 
(idev->cnf.addr_gen_mode != new_val) { + idev->cnf.addr_gen_mode = new_val; + addrconf_init_auto_addrs(idev->dev); + } + } else if (&net->ipv6.devconf_all->addr_gen_mode == ctl->data) { + struct net_device *dev; + + net->ipv6.devconf_dflt->addr_gen_mode = new_val; + for_each_netdev(net, dev) { + idev = __in6_dev_get(dev); + if (idev && + idev->cnf.addr_gen_mode != new_val) { + idev->cnf.addr_gen_mode = new_val; + addrconf_init_auto_addrs(idev->dev); + } + } + } + + *((u32 *)ctl->data) = new_val; + } + +out: + rtnl_unlock(); + + return ret; +} + +static int addrconf_sysctl_stable_secret(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, + loff_t *ppos) +{ + int err; + struct in6_addr addr; + char str[IPV6_MAX_STRLEN]; + struct ctl_table lctl = *ctl; + struct net *net = ctl->extra2; + struct ipv6_stable_secret *secret = ctl->data; + + if (&net->ipv6.devconf_all->stable_secret == ctl->data) + return -EIO; + + lctl.maxlen = IPV6_MAX_STRLEN; + lctl.data = str; + + if (!rtnl_trylock()) + return restart_syscall(); + + if (!write && !secret->initialized) { + err = -EIO; + goto out; + } + + err = snprintf(str, sizeof(str), "%pI6", &secret->secret); + if (err >= sizeof(str)) { + err = -EIO; + goto out; + } + + err = proc_dostring(&lctl, write, buffer, lenp, ppos); + if (err || !write) + goto out; + + if (in6_pton(str, -1, addr.in6_u.u6_addr8, -1, NULL) != 1) { + err = -EIO; + goto out; + } + + secret->initialized = true; + secret->secret = addr; + + if (&net->ipv6.devconf_dflt->stable_secret == ctl->data) { + struct net_device *dev; + + for_each_netdev(net, dev) { + struct inet6_dev *idev = __in6_dev_get(dev); + + if (idev) { + idev->cnf.addr_gen_mode = + IN6_ADDR_GEN_MODE_STABLE_PRIVACY; + } + } + } else { + struct inet6_dev *idev = ctl->extra1; + + idev->cnf.addr_gen_mode = IN6_ADDR_GEN_MODE_STABLE_PRIVACY; + } + +out: + rtnl_unlock(); + + return err; +} + +static +int addrconf_sysctl_ignore_routes_with_linkdown(struct ctl_table *ctl, + int write, void *buffer, + size_t *lenp, + loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + loff_t pos = *ppos; + struct ctl_table lctl; + int ret; + + /* ctl->data points to idev->cnf.ignore_routes_when_linkdown + * we should not modify it until we get the rtnl lock. + */ + lctl = *ctl; + lctl.data = &val; + + ret = proc_dointvec(&lctl, write, buffer, lenp, ppos); + + if (write) + ret = addrconf_fixup_linkdown(ctl, valp, val); + if (ret) + *ppos = pos; + return ret; +} + +static +void addrconf_set_nopolicy(struct rt6_info *rt, int action) +{ + if (rt) { + if (action) + rt->dst.flags |= DST_NOPOLICY; + else + rt->dst.flags &= ~DST_NOPOLICY; + } +} + +static +void addrconf_disable_policy_idev(struct inet6_dev *idev, int val) +{ + struct inet6_ifaddr *ifa; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifa, &idev->addr_list, if_list) { + spin_lock(&ifa->lock); + if (ifa->rt) { + /* host routes only use builtin fib6_nh */ + struct fib6_nh *nh = ifa->rt->fib6_nh; + int cpu; + + rcu_read_lock(); + ifa->rt->dst_nopolicy = val ? 
true : false; + if (nh->rt6i_pcpu) { + for_each_possible_cpu(cpu) { + struct rt6_info **rtp; + + rtp = per_cpu_ptr(nh->rt6i_pcpu, cpu); + addrconf_set_nopolicy(*rtp, val); + } + } + rcu_read_unlock(); + } + spin_unlock(&ifa->lock); + } + read_unlock_bh(&idev->lock); +} + +static +int addrconf_disable_policy(struct ctl_table *ctl, int *valp, int val) +{ + struct inet6_dev *idev; + struct net *net; + + if (!rtnl_trylock()) + return restart_syscall(); + + *valp = val; + + net = (struct net *)ctl->extra2; + if (valp == &net->ipv6.devconf_dflt->disable_policy) { + rtnl_unlock(); + return 0; + } + + if (valp == &net->ipv6.devconf_all->disable_policy) { + struct net_device *dev; + + for_each_netdev(net, dev) { + idev = __in6_dev_get(dev); + if (idev) + addrconf_disable_policy_idev(idev, val); + } + } else { + idev = (struct inet6_dev *)ctl->extra1; + addrconf_disable_policy_idev(idev, val); + } + + rtnl_unlock(); + return 0; +} + +static int addrconf_sysctl_disable_policy(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = ctl->data; + int val = *valp; + loff_t pos = *ppos; + struct ctl_table lctl; + int ret; + + lctl = *ctl; + lctl.data = &val; + ret = proc_dointvec(&lctl, write, buffer, lenp, ppos); + + if (write && (*valp != val)) + ret = addrconf_disable_policy(ctl, valp, val); + + if (ret) + *ppos = pos; + + return ret; +} + +static int minus_one = -1; +static const int two_five_five = 255; +static u32 ioam6_if_id_max = U16_MAX; + +static const struct ctl_table addrconf_sysctl[] = { + { + .procname = "forwarding", + .data = &ipv6_devconf.forwarding, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_forward, + }, + { + .procname = "hop_limit", + .data = &ipv6_devconf.hop_limit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + .extra2 = (void *)&two_five_five, + }, + { + .procname = "mtu", + .data = &ipv6_devconf.mtu6, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_mtu, + }, + { + .procname = "accept_ra", + .data = &ipv6_devconf.accept_ra, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_redirects", + .data = &ipv6_devconf.accept_redirects, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "autoconf", + .data = &ipv6_devconf.autoconf, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "dad_transmits", + .data = &ipv6_devconf.dad_transmits, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "router_solicitations", + .data = &ipv6_devconf.rtr_solicits, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &minus_one, + }, + { + .procname = "router_solicitation_interval", + .data = &ipv6_devconf.rtr_solicit_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "router_solicitation_max_interval", + .data = &ipv6_devconf.rtr_solicit_max_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "router_solicitation_delay", + .data = &ipv6_devconf.rtr_solicit_delay, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "force_mld_version", + .data = &ipv6_devconf.force_mld_version, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = 
proc_dointvec, + }, + { + .procname = "mldv1_unsolicited_report_interval", + .data = + &ipv6_devconf.mldv1_unsolicited_report_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "mldv2_unsolicited_report_interval", + .data = + &ipv6_devconf.mldv2_unsolicited_report_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "use_tempaddr", + .data = &ipv6_devconf.use_tempaddr, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "temp_valid_lft", + .data = &ipv6_devconf.temp_valid_lft, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "temp_prefered_lft", + .data = &ipv6_devconf.temp_prefered_lft, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "regen_max_retry", + .data = &ipv6_devconf.regen_max_retry, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "max_desync_factor", + .data = &ipv6_devconf.max_desync_factor, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "max_addresses", + .data = &ipv6_devconf.max_addresses, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_ra_defrtr", + .data = &ipv6_devconf.accept_ra_defrtr, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ra_defrtr_metric", + .data = &ipv6_devconf.ra_defrtr_metric, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + }, + { + .procname = "accept_ra_min_hop_limit", + .data = &ipv6_devconf.accept_ra_min_hop_limit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_ra_min_lft", + .data = &ipv6_devconf.accept_ra_min_lft, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_ra_pinfo", + .data = &ipv6_devconf.accept_ra_pinfo, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#ifdef CONFIG_IPV6_ROUTER_PREF + { + .procname = "accept_ra_rtr_pref", + .data = &ipv6_devconf.accept_ra_rtr_pref, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "router_probe_interval", + .data = &ipv6_devconf.rtr_probe_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#ifdef CONFIG_IPV6_ROUTE_INFO + { + .procname = "accept_ra_rt_info_min_plen", + .data = &ipv6_devconf.accept_ra_rt_info_min_plen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_ra_rt_info_max_plen", + .data = &ipv6_devconf.accept_ra_rt_info_max_plen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif +#endif + { + .procname = "proxy_ndp", + .data = &ipv6_devconf.proxy_ndp, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_proxy_ndp, + }, + { + .procname = "accept_source_route", + .data = &ipv6_devconf.accept_source_route, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + { + .procname = "optimistic_dad", + .data = &ipv6_devconf.optimistic_dad, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "use_optimistic", + .data = 
&ipv6_devconf.use_optimistic, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif +#ifdef CONFIG_IPV6_MROUTE + { + .procname = "mc_forwarding", + .data = &ipv6_devconf.mc_forwarding, + .maxlen = sizeof(int), + .mode = 0444, + .proc_handler = proc_dointvec, + }, +#endif + { + .procname = "disable_ipv6", + .data = &ipv6_devconf.disable_ipv6, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_disable, + }, + { + .procname = "accept_dad", + .data = &ipv6_devconf.accept_dad, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "force_tllao", + .data = &ipv6_devconf.force_tllao, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "ndisc_notify", + .data = &ipv6_devconf.ndisc_notify, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "suppress_frag_ndisc", + .data = &ipv6_devconf.suppress_frag_ndisc, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "accept_ra_from_local", + .data = &ipv6_devconf.accept_ra_from_local, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "accept_ra_mtu", + .data = &ipv6_devconf.accept_ra_mtu, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "stable_secret", + .data = &ipv6_devconf.stable_secret, + .maxlen = IPV6_MAX_STRLEN, + .mode = 0600, + .proc_handler = addrconf_sysctl_stable_secret, + }, + { + .procname = "use_oif_addrs_only", + .data = &ipv6_devconf.use_oif_addrs_only, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ignore_routes_with_linkdown", + .data = &ipv6_devconf.ignore_routes_with_linkdown, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_ignore_routes_with_linkdown, + }, + { + .procname = "drop_unicast_in_l2_multicast", + .data = &ipv6_devconf.drop_unicast_in_l2_multicast, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "drop_unsolicited_na", + .data = &ipv6_devconf.drop_unsolicited_na, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "keep_addr_on_down", + .data = &ipv6_devconf.keep_addr_on_down, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + + }, + { + .procname = "seg6_enabled", + .data = &ipv6_devconf.seg6_enabled, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#ifdef CONFIG_IPV6_SEG6_HMAC + { + .procname = "seg6_require_hmac", + .data = &ipv6_devconf.seg6_require_hmac, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif + { + .procname = "enhanced_dad", + .data = &ipv6_devconf.enhanced_dad, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "addr_gen_mode", + .data = &ipv6_devconf.addr_gen_mode, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_addr_gen_mode, + }, + { + .procname = "disable_policy", + .data = &ipv6_devconf.disable_policy, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = addrconf_sysctl_disable_policy, + }, + { + .procname = "ndisc_tclass", + .data = &ipv6_devconf.ndisc_tclass, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ZERO, + .extra2 = (void *)&two_five_five, + }, + { + .procname = 
"rpl_seg_enabled", + .data = &ipv6_devconf.rpl_seg_enabled, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ioam6_enabled", + .data = &ipv6_devconf.ioam6_enabled, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = (void *)SYSCTL_ZERO, + .extra2 = (void *)SYSCTL_ONE, + }, + { + .procname = "ioam6_id", + .data = &ipv6_devconf.ioam6_id, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = (void *)SYSCTL_ZERO, + .extra2 = (void *)&ioam6_if_id_max, + }, + { + .procname = "ioam6_id_wide", + .data = &ipv6_devconf.ioam6_id_wide, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec, + }, + { + .procname = "ndisc_evict_nocarrier", + .data = &ipv6_devconf.ndisc_evict_nocarrier, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = (void *)SYSCTL_ZERO, + .extra2 = (void *)SYSCTL_ONE, + }, + { + .procname = "accept_untracked_na", + .data = &ipv6_devconf.accept_untracked_na, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + /* sentinel */ + } +}; + +static int __addrconf_sysctl_register(struct net *net, char *dev_name, + struct inet6_dev *idev, struct ipv6_devconf *p) +{ + int i, ifindex; + struct ctl_table *table; + char path[sizeof("net/ipv6/conf/") + IFNAMSIZ]; + + table = kmemdup(addrconf_sysctl, sizeof(addrconf_sysctl), GFP_KERNEL_ACCOUNT); + if (!table) + goto out; + + for (i = 0; table[i].data; i++) { + table[i].data += (char *)p - (char *)&ipv6_devconf; + /* If one of these is already set, then it is not safe to + * overwrite either of them: this makes proc_dointvec_minmax + * usable. 
+ */ + if (!table[i].extra1 && !table[i].extra2) { + table[i].extra1 = idev; /* embedded; no ref */ + table[i].extra2 = net; + } + } + + snprintf(path, sizeof(path), "net/ipv6/conf/%s", dev_name); + + p->sysctl_header = register_net_sysctl(net, path, table); + if (!p->sysctl_header) + goto free; + + if (!strcmp(dev_name, "all")) + ifindex = NETCONFA_IFINDEX_ALL; + else if (!strcmp(dev_name, "default")) + ifindex = NETCONFA_IFINDEX_DEFAULT; + else + ifindex = idev->dev->ifindex; + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_ALL, + ifindex, p); + return 0; + +free: + kfree(table); +out: + return -ENOBUFS; +} + +static void __addrconf_sysctl_unregister(struct net *net, + struct ipv6_devconf *p, int ifindex) +{ + struct ctl_table *table; + + if (!p->sysctl_header) + return; + + table = p->sysctl_header->ctl_table_arg; + unregister_net_sysctl_table(p->sysctl_header); + p->sysctl_header = NULL; + kfree(table); + + inet6_netconf_notify_devconf(net, RTM_DELNETCONF, 0, ifindex, NULL); +} + +static int addrconf_sysctl_register(struct inet6_dev *idev) +{ + int err; + + if (!sysctl_dev_name_is_allowed(idev->dev->name)) + return -EINVAL; + + err = neigh_sysctl_register(idev->dev, idev->nd_parms, + &ndisc_ifinfo_sysctl_change); + if (err) + return err; + err = __addrconf_sysctl_register(dev_net(idev->dev), idev->dev->name, + idev, &idev->cnf); + if (err) + neigh_sysctl_unregister(idev->nd_parms); + + return err; +} + +static void addrconf_sysctl_unregister(struct inet6_dev *idev) +{ + __addrconf_sysctl_unregister(dev_net(idev->dev), &idev->cnf, + idev->dev->ifindex); + neigh_sysctl_unregister(idev->nd_parms); +} + + +#endif + +static int __net_init addrconf_init_net(struct net *net) +{ + int err = -ENOMEM; + struct ipv6_devconf *all, *dflt; + + spin_lock_init(&net->ipv6.addrconf_hash_lock); + INIT_DEFERRABLE_WORK(&net->ipv6.addr_chk_work, addrconf_verify_work); + net->ipv6.inet6_addr_lst = kcalloc(IN6_ADDR_HSIZE, + sizeof(struct hlist_head), + GFP_KERNEL); + if (!net->ipv6.inet6_addr_lst) + goto err_alloc_addr; + + all = kmemdup(&ipv6_devconf, sizeof(ipv6_devconf), GFP_KERNEL); + if (!all) + goto err_alloc_all; + + dflt = kmemdup(&ipv6_devconf_dflt, sizeof(ipv6_devconf_dflt), GFP_KERNEL); + if (!dflt) + goto err_alloc_dflt; + + if (!net_eq(net, &init_net)) { + switch (net_inherit_devconf()) { + case 1: /* copy from init_net */ + memcpy(all, init_net.ipv6.devconf_all, + sizeof(ipv6_devconf)); + memcpy(dflt, init_net.ipv6.devconf_dflt, + sizeof(ipv6_devconf_dflt)); + break; + case 3: /* copy from the current netns */ + memcpy(all, current->nsproxy->net_ns->ipv6.devconf_all, + sizeof(ipv6_devconf)); + memcpy(dflt, + current->nsproxy->net_ns->ipv6.devconf_dflt, + sizeof(ipv6_devconf_dflt)); + break; + case 0: + case 2: + /* use compiled values */ + break; + } + } + + /* these will be inherited by all namespaces */ + dflt->autoconf = ipv6_defaults.autoconf; + dflt->disable_ipv6 = ipv6_defaults.disable_ipv6; + + dflt->stable_secret.initialized = false; + all->stable_secret.initialized = false; + + net->ipv6.devconf_all = all; + net->ipv6.devconf_dflt = dflt; + +#ifdef CONFIG_SYSCTL + err = __addrconf_sysctl_register(net, "all", NULL, all); + if (err < 0) + goto err_reg_all; + + err = __addrconf_sysctl_register(net, "default", NULL, dflt); + if (err < 0) + goto err_reg_dflt; +#endif + return 0; + +#ifdef CONFIG_SYSCTL +err_reg_dflt: + __addrconf_sysctl_unregister(net, all, NETCONFA_IFINDEX_ALL); +err_reg_all: + kfree(dflt); + net->ipv6.devconf_dflt = NULL; +#endif +err_alloc_dflt: + kfree(all); 
+ net->ipv6.devconf_all = NULL; +err_alloc_all: + kfree(net->ipv6.inet6_addr_lst); +err_alloc_addr: + return err; +} + +static void __net_exit addrconf_exit_net(struct net *net) +{ + int i; + +#ifdef CONFIG_SYSCTL + __addrconf_sysctl_unregister(net, net->ipv6.devconf_dflt, + NETCONFA_IFINDEX_DEFAULT); + __addrconf_sysctl_unregister(net, net->ipv6.devconf_all, + NETCONFA_IFINDEX_ALL); +#endif + kfree(net->ipv6.devconf_dflt); + net->ipv6.devconf_dflt = NULL; + kfree(net->ipv6.devconf_all); + net->ipv6.devconf_all = NULL; + + cancel_delayed_work_sync(&net->ipv6.addr_chk_work); + /* + * Check hash table, then free it. + */ + for (i = 0; i < IN6_ADDR_HSIZE; i++) + WARN_ON_ONCE(!hlist_empty(&net->ipv6.inet6_addr_lst[i])); + + kfree(net->ipv6.inet6_addr_lst); + net->ipv6.inet6_addr_lst = NULL; +} + +static struct pernet_operations addrconf_ops = { + .init = addrconf_init_net, + .exit = addrconf_exit_net, +}; + +static struct rtnl_af_ops inet6_ops __read_mostly = { + .family = AF_INET6, + .fill_link_af = inet6_fill_link_af, + .get_link_af_size = inet6_get_link_af_size, + .validate_link_af = inet6_validate_link_af, + .set_link_af = inet6_set_link_af, +}; + +/* + * Init / cleanup code + */ + +int __init addrconf_init(void) +{ + struct inet6_dev *idev; + int err; + + err = ipv6_addr_label_init(); + if (err < 0) { + pr_crit("%s: cannot initialize default policy table: %d\n", + __func__, err); + goto out; + } + + err = register_pernet_subsys(&addrconf_ops); + if (err < 0) + goto out_addrlabel; + + addrconf_wq = create_workqueue("ipv6_addrconf"); + if (!addrconf_wq) { + err = -ENOMEM; + goto out_nowq; + } + + rtnl_lock(); + idev = ipv6_add_dev(blackhole_netdev); + rtnl_unlock(); + if (IS_ERR(idev)) { + err = PTR_ERR(idev); + goto errlo; + } + + ip6_route_init_special_entries(); + + register_netdevice_notifier(&ipv6_dev_notf); + + addrconf_verify(&init_net); + + rtnl_af_register(&inet6_ops); + + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETLINK, + NULL, inet6_dump_ifinfo, 0); + if (err < 0) + goto errout; + + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_NEWADDR, + inet6_rtm_newaddr, NULL, 0); + if (err < 0) + goto errout; + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_DELADDR, + inet6_rtm_deladdr, NULL, 0); + if (err < 0) + goto errout; + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETADDR, + inet6_rtm_getaddr, inet6_dump_ifaddr, + RTNL_FLAG_DOIT_UNLOCKED); + if (err < 0) + goto errout; + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETMULTICAST, + NULL, inet6_dump_ifmcaddr, 0); + if (err < 0) + goto errout; + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETANYCAST, + NULL, inet6_dump_ifacaddr, 0); + if (err < 0) + goto errout; + err = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETNETCONF, + inet6_netconf_get_devconf, + inet6_netconf_dump_devconf, + RTNL_FLAG_DOIT_UNLOCKED); + if (err < 0) + goto errout; + err = ipv6_addr_label_rtnl_register(); + if (err < 0) + goto errout; + + return 0; +errout: + rtnl_unregister_all(PF_INET6); + rtnl_af_unregister(&inet6_ops); + unregister_netdevice_notifier(&ipv6_dev_notf); +errlo: + destroy_workqueue(addrconf_wq); +out_nowq: + unregister_pernet_subsys(&addrconf_ops); +out_addrlabel: + ipv6_addr_label_cleanup(); +out: + return err; +} + +void addrconf_cleanup(void) +{ + struct net_device *dev; + + unregister_netdevice_notifier(&ipv6_dev_notf); + unregister_pernet_subsys(&addrconf_ops); + ipv6_addr_label_cleanup(); + + rtnl_af_unregister(&inet6_ops); + + rtnl_lock(); + + /* clean dev list */ + 
for_each_netdev(&init_net, dev) { + if (__in6_dev_get(dev) == NULL) + continue; + addrconf_ifdown(dev, true); + } + addrconf_ifdown(init_net.loopback_dev, true); + + rtnl_unlock(); + + destroy_workqueue(addrconf_wq); +} diff --git a/net/ipv6/addrconf_core.c b/net/ipv6/addrconf_core.c new file mode 100644 index 000000000..507a8353a --- /dev/null +++ b/net/ipv6/addrconf_core.c @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPv6 library code, needed by static components when full IPv6 support is + * not configured or static. + */ + +#include +#include +#include +#include +#include + +/* if ipv6 module registers this function is used by xfrm to force all + * sockets to relookup their nodes - this is fairly expensive, be + * careful + */ +void (*__fib6_flush_trees)(struct net *); +EXPORT_SYMBOL(__fib6_flush_trees); + +#define IPV6_ADDR_SCOPE_TYPE(scope) ((scope) << 16) + +static inline unsigned int ipv6_addr_scope2type(unsigned int scope) +{ + switch (scope) { + case IPV6_ADDR_SCOPE_NODELOCAL: + return (IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_NODELOCAL) | + IPV6_ADDR_LOOPBACK); + case IPV6_ADDR_SCOPE_LINKLOCAL: + return (IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_LINKLOCAL) | + IPV6_ADDR_LINKLOCAL); + case IPV6_ADDR_SCOPE_SITELOCAL: + return (IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_SITELOCAL) | + IPV6_ADDR_SITELOCAL); + } + return IPV6_ADDR_SCOPE_TYPE(scope); +} + +int __ipv6_addr_type(const struct in6_addr *addr) +{ + __be32 st; + + st = addr->s6_addr32[0]; + + /* Consider all addresses with the first three bits different of + 000 and 111 as unicasts. + */ + if ((st & htonl(0xE0000000)) != htonl(0x00000000) && + (st & htonl(0xE0000000)) != htonl(0xE0000000)) + return (IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_GLOBAL)); + + if ((st & htonl(0xFF000000)) == htonl(0xFF000000)) { + /* multicast */ + /* addr-select 3.1 */ + return (IPV6_ADDR_MULTICAST | + ipv6_addr_scope2type(IPV6_ADDR_MC_SCOPE(addr))); + } + + if ((st & htonl(0xFFC00000)) == htonl(0xFE800000)) + return (IPV6_ADDR_LINKLOCAL | IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_LINKLOCAL)); /* addr-select 3.1 */ + if ((st & htonl(0xFFC00000)) == htonl(0xFEC00000)) + return (IPV6_ADDR_SITELOCAL | IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_SITELOCAL)); /* addr-select 3.1 */ + if ((st & htonl(0xFE000000)) == htonl(0xFC000000)) + return (IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_GLOBAL)); /* RFC 4193 */ + + if ((addr->s6_addr32[0] | addr->s6_addr32[1]) == 0) { + if (addr->s6_addr32[2] == 0) { + if (addr->s6_addr32[3] == 0) + return IPV6_ADDR_ANY; + + if (addr->s6_addr32[3] == htonl(0x00000001)) + return (IPV6_ADDR_LOOPBACK | IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_LINKLOCAL)); /* addr-select 3.4 */ + + return (IPV6_ADDR_COMPATv4 | IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_GLOBAL)); /* addr-select 3.3 */ + } + + if (addr->s6_addr32[2] == htonl(0x0000ffff)) + return (IPV6_ADDR_MAPPED | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_GLOBAL)); /* addr-select 3.3 */ + } + + return (IPV6_ADDR_UNICAST | + IPV6_ADDR_SCOPE_TYPE(IPV6_ADDR_SCOPE_GLOBAL)); /* addr-select 3.4 */ +} +EXPORT_SYMBOL(__ipv6_addr_type); + +static ATOMIC_NOTIFIER_HEAD(inet6addr_chain); +static BLOCKING_NOTIFIER_HEAD(inet6addr_validator_chain); + +int register_inet6addr_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_register(&inet6addr_chain, nb); +} +EXPORT_SYMBOL(register_inet6addr_notifier); + +int unregister_inet6addr_notifier(struct notifier_block *nb) 
+{ + return atomic_notifier_chain_unregister(&inet6addr_chain, nb); +} +EXPORT_SYMBOL(unregister_inet6addr_notifier); + +int inet6addr_notifier_call_chain(unsigned long val, void *v) +{ + return atomic_notifier_call_chain(&inet6addr_chain, val, v); +} +EXPORT_SYMBOL(inet6addr_notifier_call_chain); + +int register_inet6addr_validator_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_register(&inet6addr_validator_chain, nb); +} +EXPORT_SYMBOL(register_inet6addr_validator_notifier); + +int unregister_inet6addr_validator_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_unregister(&inet6addr_validator_chain, + nb); +} +EXPORT_SYMBOL(unregister_inet6addr_validator_notifier); + +int inet6addr_validator_notifier_call_chain(unsigned long val, void *v) +{ + return blocking_notifier_call_chain(&inet6addr_validator_chain, val, v); +} +EXPORT_SYMBOL(inet6addr_validator_notifier_call_chain); + +static struct dst_entry *eafnosupport_ipv6_dst_lookup_flow(struct net *net, + const struct sock *sk, + struct flowi6 *fl6, + const struct in6_addr *final_dst) +{ + return ERR_PTR(-EAFNOSUPPORT); +} + +static int eafnosupport_ipv6_route_input(struct sk_buff *skb) +{ + return -EAFNOSUPPORT; +} + +static struct fib6_table *eafnosupport_fib6_get_table(struct net *net, u32 id) +{ + return NULL; +} + +static int +eafnosupport_fib6_table_lookup(struct net *net, struct fib6_table *table, + int oif, struct flowi6 *fl6, + struct fib6_result *res, int flags) +{ + return -EAFNOSUPPORT; +} + +static int +eafnosupport_fib6_lookup(struct net *net, int oif, struct flowi6 *fl6, + struct fib6_result *res, int flags) +{ + return -EAFNOSUPPORT; +} + +static void +eafnosupport_fib6_select_path(const struct net *net, struct fib6_result *res, + struct flowi6 *fl6, int oif, bool have_oif_match, + const struct sk_buff *skb, int strict) +{ +} + +static u32 +eafnosupport_ip6_mtu_from_fib6(const struct fib6_result *res, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + return 0; +} + +static int eafnosupport_fib6_nh_init(struct net *net, struct fib6_nh *fib6_nh, + struct fib6_config *cfg, gfp_t gfp_flags, + struct netlink_ext_ack *extack) +{ + NL_SET_ERR_MSG(extack, "IPv6 support not enabled in kernel"); + return -EAFNOSUPPORT; +} + +static int eafnosupport_ip6_del_rt(struct net *net, struct fib6_info *rt, + bool skip_notify) +{ + return -EAFNOSUPPORT; +} + +static int eafnosupport_ipv6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + int (*output)(struct net *, struct sock *, struct sk_buff *)) +{ + kfree_skb(skb); + return -EAFNOSUPPORT; +} + +static struct net_device *eafnosupport_ipv6_dev_find(struct net *net, const struct in6_addr *addr, + struct net_device *dev) +{ + return ERR_PTR(-EAFNOSUPPORT); +} + +const struct ipv6_stub *ipv6_stub __read_mostly = &(struct ipv6_stub) { + .ipv6_dst_lookup_flow = eafnosupport_ipv6_dst_lookup_flow, + .ipv6_route_input = eafnosupport_ipv6_route_input, + .fib6_get_table = eafnosupport_fib6_get_table, + .fib6_table_lookup = eafnosupport_fib6_table_lookup, + .fib6_lookup = eafnosupport_fib6_lookup, + .fib6_select_path = eafnosupport_fib6_select_path, + .ip6_mtu_from_fib6 = eafnosupport_ip6_mtu_from_fib6, + .fib6_nh_init = eafnosupport_fib6_nh_init, + .ip6_del_rt = eafnosupport_ip6_del_rt, + .ipv6_fragment = eafnosupport_ipv6_fragment, + .ipv6_dev_find = eafnosupport_ipv6_dev_find, +}; +EXPORT_SYMBOL_GPL(ipv6_stub); + +/* IPv6 Wildcard Address and Loopback Address defined by RFC2553 */ +const struct in6_addr in6addr_loopback 
= IN6ADDR_LOOPBACK_INIT; +EXPORT_SYMBOL(in6addr_loopback); +const struct in6_addr in6addr_any = IN6ADDR_ANY_INIT; +EXPORT_SYMBOL(in6addr_any); +const struct in6_addr in6addr_linklocal_allnodes = IN6ADDR_LINKLOCAL_ALLNODES_INIT; +EXPORT_SYMBOL(in6addr_linklocal_allnodes); +const struct in6_addr in6addr_linklocal_allrouters = IN6ADDR_LINKLOCAL_ALLROUTERS_INIT; +EXPORT_SYMBOL(in6addr_linklocal_allrouters); +const struct in6_addr in6addr_interfacelocal_allnodes = IN6ADDR_INTERFACELOCAL_ALLNODES_INIT; +EXPORT_SYMBOL(in6addr_interfacelocal_allnodes); +const struct in6_addr in6addr_interfacelocal_allrouters = IN6ADDR_INTERFACELOCAL_ALLROUTERS_INIT; +EXPORT_SYMBOL(in6addr_interfacelocal_allrouters); +const struct in6_addr in6addr_sitelocal_allrouters = IN6ADDR_SITELOCAL_ALLROUTERS_INIT; +EXPORT_SYMBOL(in6addr_sitelocal_allrouters); + +static void snmp6_free_dev(struct inet6_dev *idev) +{ + kfree(idev->stats.icmpv6msgdev); + kfree(idev->stats.icmpv6dev); + free_percpu(idev->stats.ipv6); +} + +static void in6_dev_finish_destroy_rcu(struct rcu_head *head) +{ + struct inet6_dev *idev = container_of(head, struct inet6_dev, rcu); + + snmp6_free_dev(idev); + kfree(idev); +} + +/* Nobody refers to this device, we may destroy it. */ + +void in6_dev_finish_destroy(struct inet6_dev *idev) +{ + struct net_device *dev = idev->dev; + + WARN_ON(!list_empty(&idev->addr_list)); + WARN_ON(rcu_access_pointer(idev->mc_list)); + WARN_ON(timer_pending(&idev->rs_timer)); + +#ifdef NET_REFCNT_DEBUG + pr_debug("%s: %s\n", __func__, dev ? dev->name : "NIL"); +#endif + netdev_put(dev, &idev->dev_tracker); + if (!idev->dead) { + pr_warn("Freeing alive inet6 device %p\n", idev); + return; + } + call_rcu(&idev->rcu, in6_dev_finish_destroy_rcu); +} +EXPORT_SYMBOL(in6_dev_finish_destroy); diff --git a/net/ipv6/addrlabel.c b/net/ipv6/addrlabel.c new file mode 100644 index 000000000..17ac45aa7 --- /dev/null +++ b/net/ipv6/addrlabel.c @@ -0,0 +1,652 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * IPv6 Address Label subsystem + * for the IPv6 "Default" Source Address Selection + * + * Copyright (C)2007 USAGI/WIDE Project + */ +/* + * Author: + * YOSHIFUJI Hideaki @ USAGI/WIDE Project + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if 0 +#define ADDRLABEL(x...) printk(x) +#else +#define ADDRLABEL(x...) do { ; } while (0) +#endif + +/* + * Policy Table + */ +struct ip6addrlbl_entry { + struct in6_addr prefix; + int prefixlen; + int ifindex; + int addrtype; + u32 label; + struct hlist_node list; + struct rcu_head rcu; +}; + +/* + * Default policy table (RFC6724 + extensions) + * + * prefix addr_type label + * ------------------------------------------------------------------------- + * ::1/128 LOOPBACK 0 + * ::/0 N/A 1 + * 2002::/16 N/A 2 + * ::/96 COMPATv4 3 + * ::ffff:0:0/96 V4MAPPED 4 + * fc00::/7 N/A 5 ULA (RFC 4193) + * 2001::/32 N/A 6 Teredo (RFC 4380) + * 2001:10::/28 N/A 7 ORCHID (RFC 4843) + * fec0::/10 N/A 11 Site-local + * (deprecated by RFC3879) + * 3ffe::/16 N/A 12 6bone + * + * Note: 0xffffffff is used if we do not have any policies. + * Note: Labels for ULA and 6to4 are different from labels listed in RFC6724. 
+ */ + +#define IPV6_ADDR_LABEL_DEFAULT 0xffffffffUL + +static const __net_initconst struct ip6addrlbl_init_table +{ + const struct in6_addr *prefix; + int prefixlen; + u32 label; +} ip6addrlbl_init_table[] = { + { /* ::/0 */ + .prefix = &in6addr_any, + .label = 1, + }, { /* fc00::/7 */ + .prefix = &(struct in6_addr){ { { 0xfc } } } , + .prefixlen = 7, + .label = 5, + }, { /* fec0::/10 */ + .prefix = &(struct in6_addr){ { { 0xfe, 0xc0 } } }, + .prefixlen = 10, + .label = 11, + }, { /* 2002::/16 */ + .prefix = &(struct in6_addr){ { { 0x20, 0x02 } } }, + .prefixlen = 16, + .label = 2, + }, { /* 3ffe::/16 */ + .prefix = &(struct in6_addr){ { { 0x3f, 0xfe } } }, + .prefixlen = 16, + .label = 12, + }, { /* 2001::/32 */ + .prefix = &(struct in6_addr){ { { 0x20, 0x01 } } }, + .prefixlen = 32, + .label = 6, + }, { /* 2001:10::/28 */ + .prefix = &(struct in6_addr){ { { 0x20, 0x01, 0x00, 0x10 } } }, + .prefixlen = 28, + .label = 7, + }, { /* ::ffff:0:0 */ + .prefix = &(struct in6_addr){ { { [10] = 0xff, [11] = 0xff } } }, + .prefixlen = 96, + .label = 4, + }, { /* ::/96 */ + .prefix = &in6addr_any, + .prefixlen = 96, + .label = 3, + }, { /* ::1/128 */ + .prefix = &in6addr_loopback, + .prefixlen = 128, + .label = 0, + } +}; + +/* Find label */ +static bool __ip6addrlbl_match(const struct ip6addrlbl_entry *p, + const struct in6_addr *addr, + int addrtype, int ifindex) +{ + if (p->ifindex && p->ifindex != ifindex) + return false; + if (p->addrtype && p->addrtype != addrtype) + return false; + if (!ipv6_prefix_equal(addr, &p->prefix, p->prefixlen)) + return false; + return true; +} + +static struct ip6addrlbl_entry *__ipv6_addr_label(struct net *net, + const struct in6_addr *addr, + int type, int ifindex) +{ + struct ip6addrlbl_entry *p; + + hlist_for_each_entry_rcu(p, &net->ipv6.ip6addrlbl_table.head, list) { + if (__ip6addrlbl_match(p, addr, type, ifindex)) + return p; + } + return NULL; +} + +u32 ipv6_addr_label(struct net *net, + const struct in6_addr *addr, int type, int ifindex) +{ + u32 label; + struct ip6addrlbl_entry *p; + + type &= IPV6_ADDR_MAPPED | IPV6_ADDR_COMPATv4 | IPV6_ADDR_LOOPBACK; + + rcu_read_lock(); + p = __ipv6_addr_label(net, addr, type, ifindex); + label = p ? 
p->label : IPV6_ADDR_LABEL_DEFAULT; + rcu_read_unlock(); + + ADDRLABEL(KERN_DEBUG "%s(addr=%pI6, type=%d, ifindex=%d) => %08x\n", + __func__, addr, type, ifindex, label); + + return label; +} + +/* allocate one entry */ +static struct ip6addrlbl_entry *ip6addrlbl_alloc(const struct in6_addr *prefix, + int prefixlen, int ifindex, + u32 label) +{ + struct ip6addrlbl_entry *newp; + int addrtype; + + ADDRLABEL(KERN_DEBUG "%s(prefix=%pI6, prefixlen=%d, ifindex=%d, label=%u)\n", + __func__, prefix, prefixlen, ifindex, (unsigned int)label); + + addrtype = ipv6_addr_type(prefix) & (IPV6_ADDR_MAPPED | IPV6_ADDR_COMPATv4 | IPV6_ADDR_LOOPBACK); + + switch (addrtype) { + case IPV6_ADDR_MAPPED: + if (prefixlen > 96) + return ERR_PTR(-EINVAL); + if (prefixlen < 96) + addrtype = 0; + break; + case IPV6_ADDR_COMPATv4: + if (prefixlen != 96) + addrtype = 0; + break; + case IPV6_ADDR_LOOPBACK: + if (prefixlen != 128) + addrtype = 0; + break; + } + + newp = kmalloc(sizeof(*newp), GFP_KERNEL); + if (!newp) + return ERR_PTR(-ENOMEM); + + ipv6_addr_prefix(&newp->prefix, prefix, prefixlen); + newp->prefixlen = prefixlen; + newp->ifindex = ifindex; + newp->addrtype = addrtype; + newp->label = label; + INIT_HLIST_NODE(&newp->list); + return newp; +} + +/* add a label */ +static int __ip6addrlbl_add(struct net *net, struct ip6addrlbl_entry *newp, + int replace) +{ + struct ip6addrlbl_entry *last = NULL, *p = NULL; + struct hlist_node *n; + int ret = 0; + + ADDRLABEL(KERN_DEBUG "%s(newp=%p, replace=%d)\n", __func__, newp, + replace); + + hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) { + if (p->prefixlen == newp->prefixlen && + p->ifindex == newp->ifindex && + ipv6_addr_equal(&p->prefix, &newp->prefix)) { + if (!replace) { + ret = -EEXIST; + goto out; + } + hlist_replace_rcu(&p->list, &newp->list); + kfree_rcu(p, rcu); + goto out; + } else if ((p->prefixlen == newp->prefixlen && !p->ifindex) || + (p->prefixlen < newp->prefixlen)) { + hlist_add_before_rcu(&newp->list, &p->list); + goto out; + } + last = p; + } + if (last) + hlist_add_behind_rcu(&newp->list, &last->list); + else + hlist_add_head_rcu(&newp->list, &net->ipv6.ip6addrlbl_table.head); +out: + if (!ret) + net->ipv6.ip6addrlbl_table.seq++; + return ret; +} + +/* add a label */ +static int ip6addrlbl_add(struct net *net, + const struct in6_addr *prefix, int prefixlen, + int ifindex, u32 label, int replace) +{ + struct ip6addrlbl_entry *newp; + int ret = 0; + + ADDRLABEL(KERN_DEBUG "%s(prefix=%pI6, prefixlen=%d, ifindex=%d, label=%u, replace=%d)\n", + __func__, prefix, prefixlen, ifindex, (unsigned int)label, + replace); + + newp = ip6addrlbl_alloc(prefix, prefixlen, ifindex, label); + if (IS_ERR(newp)) + return PTR_ERR(newp); + spin_lock(&net->ipv6.ip6addrlbl_table.lock); + ret = __ip6addrlbl_add(net, newp, replace); + spin_unlock(&net->ipv6.ip6addrlbl_table.lock); + if (ret) + kfree(newp); + return ret; +} + +/* remove a label */ +static int __ip6addrlbl_del(struct net *net, + const struct in6_addr *prefix, int prefixlen, + int ifindex) +{ + struct ip6addrlbl_entry *p = NULL; + struct hlist_node *n; + int ret = -ESRCH; + + ADDRLABEL(KERN_DEBUG "%s(prefix=%pI6, prefixlen=%d, ifindex=%d)\n", + __func__, prefix, prefixlen, ifindex); + + hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) { + if (p->prefixlen == prefixlen && + p->ifindex == ifindex && + ipv6_addr_equal(&p->prefix, prefix)) { + hlist_del_rcu(&p->list); + kfree_rcu(p, rcu); + ret = 0; + break; + } + } + return ret; +} + +static int ip6addrlbl_del(struct 
net *net, + const struct in6_addr *prefix, int prefixlen, + int ifindex) +{ + struct in6_addr prefix_buf; + int ret; + + ADDRLABEL(KERN_DEBUG "%s(prefix=%pI6, prefixlen=%d, ifindex=%d)\n", + __func__, prefix, prefixlen, ifindex); + + ipv6_addr_prefix(&prefix_buf, prefix, prefixlen); + spin_lock(&net->ipv6.ip6addrlbl_table.lock); + ret = __ip6addrlbl_del(net, &prefix_buf, prefixlen, ifindex); + spin_unlock(&net->ipv6.ip6addrlbl_table.lock); + return ret; +} + +/* add default label */ +static int __net_init ip6addrlbl_net_init(struct net *net) +{ + struct ip6addrlbl_entry *p = NULL; + struct hlist_node *n; + int err; + int i; + + ADDRLABEL(KERN_DEBUG "%s\n", __func__); + + spin_lock_init(&net->ipv6.ip6addrlbl_table.lock); + INIT_HLIST_HEAD(&net->ipv6.ip6addrlbl_table.head); + + for (i = 0; i < ARRAY_SIZE(ip6addrlbl_init_table); i++) { + err = ip6addrlbl_add(net, + ip6addrlbl_init_table[i].prefix, + ip6addrlbl_init_table[i].prefixlen, + 0, + ip6addrlbl_init_table[i].label, 0); + if (err) + goto err_ip6addrlbl_add; + } + return 0; + +err_ip6addrlbl_add: + hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) { + hlist_del_rcu(&p->list); + kfree_rcu(p, rcu); + } + return err; +} + +static void __net_exit ip6addrlbl_net_exit(struct net *net) +{ + struct ip6addrlbl_entry *p = NULL; + struct hlist_node *n; + + /* Remove all labels belonging to the exiting net */ + spin_lock(&net->ipv6.ip6addrlbl_table.lock); + hlist_for_each_entry_safe(p, n, &net->ipv6.ip6addrlbl_table.head, list) { + hlist_del_rcu(&p->list); + kfree_rcu(p, rcu); + } + spin_unlock(&net->ipv6.ip6addrlbl_table.lock); +} + +static struct pernet_operations ipv6_addr_label_ops = { + .init = ip6addrlbl_net_init, + .exit = ip6addrlbl_net_exit, +}; + +int __init ipv6_addr_label_init(void) +{ + return register_pernet_subsys(&ipv6_addr_label_ops); +} + +void ipv6_addr_label_cleanup(void) +{ + unregister_pernet_subsys(&ipv6_addr_label_ops); +} + +static const struct nla_policy ifal_policy[IFAL_MAX+1] = { + [IFAL_ADDRESS] = { .len = sizeof(struct in6_addr), }, + [IFAL_LABEL] = { .len = sizeof(u32), }, +}; + +static bool addrlbl_ifindex_exists(struct net *net, int ifindex) +{ + + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, ifindex); + rcu_read_unlock(); + + return dev != NULL; +} + +static int ip6addrlbl_newdel(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct ifaddrlblmsg *ifal; + struct nlattr *tb[IFAL_MAX+1]; + struct in6_addr *pfx; + u32 label; + int err = 0; + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifal), tb, IFAL_MAX, + ifal_policy, extack); + if (err < 0) + return err; + + ifal = nlmsg_data(nlh); + + if (ifal->ifal_family != AF_INET6 || + ifal->ifal_prefixlen > 128) + return -EINVAL; + + if (!tb[IFAL_ADDRESS]) + return -EINVAL; + pfx = nla_data(tb[IFAL_ADDRESS]); + + if (!tb[IFAL_LABEL]) + return -EINVAL; + label = nla_get_u32(tb[IFAL_LABEL]); + if (label == IPV6_ADDR_LABEL_DEFAULT) + return -EINVAL; + + switch (nlh->nlmsg_type) { + case RTM_NEWADDRLABEL: + if (ifal->ifal_index && + !addrlbl_ifindex_exists(net, ifal->ifal_index)) + return -EINVAL; + + err = ip6addrlbl_add(net, pfx, ifal->ifal_prefixlen, + ifal->ifal_index, label, + nlh->nlmsg_flags & NLM_F_REPLACE); + break; + case RTM_DELADDRLABEL: + err = ip6addrlbl_del(net, pfx, ifal->ifal_prefixlen, + ifal->ifal_index); + break; + default: + err = -EOPNOTSUPP; + } + return err; +} + +static void ip6addrlbl_putmsg(struct nlmsghdr *nlh, + int 
prefixlen, int ifindex, u32 lseq) +{ + struct ifaddrlblmsg *ifal = nlmsg_data(nlh); + ifal->ifal_family = AF_INET6; + ifal->__ifal_reserved = 0; + ifal->ifal_prefixlen = prefixlen; + ifal->ifal_flags = 0; + ifal->ifal_index = ifindex; + ifal->ifal_seq = lseq; +}; + +static int ip6addrlbl_fill(struct sk_buff *skb, + struct ip6addrlbl_entry *p, + u32 lseq, + u32 portid, u32 seq, int event, + unsigned int flags) +{ + struct nlmsghdr *nlh = nlmsg_put(skb, portid, seq, event, + sizeof(struct ifaddrlblmsg), flags); + if (!nlh) + return -EMSGSIZE; + + ip6addrlbl_putmsg(nlh, p->prefixlen, p->ifindex, lseq); + + if (nla_put_in6_addr(skb, IFAL_ADDRESS, &p->prefix) < 0 || + nla_put_u32(skb, IFAL_LABEL, p->label) < 0) { + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int ip6addrlbl_valid_dump_req(const struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct ifaddrlblmsg *ifal; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for address label dump request"); + return -EINVAL; + } + + ifal = nlmsg_data(nlh); + if (ifal->__ifal_reserved || ifal->ifal_prefixlen || + ifal->ifal_flags || ifal->ifal_index || ifal->ifal_seq) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for address label dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header for address label dump request"); + return -EINVAL; + } + + return 0; +} + +static int ip6addrlbl_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct ip6addrlbl_entry *p; + int idx = 0, s_idx = cb->args[0]; + int err; + + if (cb->strict_check) { + err = ip6addrlbl_valid_dump_req(nlh, cb->extack); + if (err < 0) + return err; + } + + rcu_read_lock(); + hlist_for_each_entry_rcu(p, &net->ipv6.ip6addrlbl_table.head, list) { + if (idx >= s_idx) { + err = ip6addrlbl_fill(skb, p, + net->ipv6.ip6addrlbl_table.seq, + NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWADDRLABEL, + NLM_F_MULTI); + if (err < 0) + break; + } + idx++; + } + rcu_read_unlock(); + cb->args[0] = idx; + return skb->len; +} + +static inline int ip6addrlbl_msgsize(void) +{ + return NLMSG_ALIGN(sizeof(struct ifaddrlblmsg)) + + nla_total_size(16) /* IFAL_ADDRESS */ + + nla_total_size(4); /* IFAL_LABEL */ +} + +static int ip6addrlbl_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct ifaddrlblmsg *ifal; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ifal))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for addrlabel get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*ifal), tb, + IFAL_MAX, ifal_policy, extack); + + ifal = nlmsg_data(nlh); + if (ifal->__ifal_reserved || ifal->ifal_flags || ifal->ifal_seq) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for addrlabel get request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*ifal), tb, IFAL_MAX, + ifal_policy, extack); + if (err) + return err; + + for (i = 0; i <= IFAL_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case IFAL_ADDRESS: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in addrlabel get request"); + return -EINVAL; + } + } + + return 0; +} + +static int ip6addrlbl_get(struct sk_buff *in_skb, struct nlmsghdr 
*nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct ifaddrlblmsg *ifal; + struct nlattr *tb[IFAL_MAX+1]; + struct in6_addr *addr; + u32 lseq; + int err = 0; + struct ip6addrlbl_entry *p; + struct sk_buff *skb; + + err = ip6addrlbl_valid_get_req(in_skb, nlh, tb, extack); + if (err < 0) + return err; + + ifal = nlmsg_data(nlh); + + if (ifal->ifal_family != AF_INET6 || + ifal->ifal_prefixlen != 128) + return -EINVAL; + + if (ifal->ifal_index && + !addrlbl_ifindex_exists(net, ifal->ifal_index)) + return -EINVAL; + + if (!tb[IFAL_ADDRESS]) + return -EINVAL; + addr = nla_data(tb[IFAL_ADDRESS]); + + skb = nlmsg_new(ip6addrlbl_msgsize(), GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + err = -ESRCH; + + rcu_read_lock(); + p = __ipv6_addr_label(net, addr, ipv6_addr_type(addr), ifal->ifal_index); + lseq = net->ipv6.ip6addrlbl_table.seq; + if (p) + err = ip6addrlbl_fill(skb, p, lseq, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, + RTM_NEWADDRLABEL, 0); + rcu_read_unlock(); + + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + } else { + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); + } + return err; +} + +int __init ipv6_addr_label_rtnl_register(void) +{ + int ret; + + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_NEWADDRLABEL, + ip6addrlbl_newdel, + NULL, RTNL_FLAG_DOIT_UNLOCKED); + if (ret < 0) + return ret; + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_DELADDRLABEL, + ip6addrlbl_newdel, + NULL, RTNL_FLAG_DOIT_UNLOCKED); + if (ret < 0) + return ret; + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETADDRLABEL, + ip6addrlbl_get, + ip6addrlbl_dump, RTNL_FLAG_DOIT_UNLOCKED); + return ret; +} diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c new file mode 100644 index 000000000..0b42eb8c5 --- /dev/null +++ b/net/ipv6/af_inet6.c @@ -0,0 +1,1334 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * PF_INET6 socket protocol family + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Adapted from linux/net/ipv4/af_inet.c + * + * Fixes: + * piggy, Karl Knutson : Socket protocol table + * Hideaki YOSHIFUJI : sin6_scope_id support + * Arnaldo Melo : check proc_net_create return, cleanups + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_IPV6_TUNNEL +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "ip6_offload.h" + +MODULE_AUTHOR("Cast of dozens"); +MODULE_DESCRIPTION("IPv6 protocol stack for Linux"); +MODULE_LICENSE("GPL"); + +/* The inetsw6 table contains everything that inet6_create needs to + * build a new socket. 
+ */ +static struct list_head inetsw6[SOCK_MAX]; +static DEFINE_SPINLOCK(inetsw6_lock); + +struct ipv6_params ipv6_defaults = { + .disable_ipv6 = 0, + .autoconf = 1, +}; + +static int disable_ipv6_mod; + +module_param_named(disable, disable_ipv6_mod, int, 0444); +MODULE_PARM_DESC(disable, "Disable IPv6 module such that it is non-functional"); + +module_param_named(disable_ipv6, ipv6_defaults.disable_ipv6, int, 0444); +MODULE_PARM_DESC(disable_ipv6, "Disable IPv6 on all interfaces"); + +module_param_named(autoconf, ipv6_defaults.autoconf, int, 0444); +MODULE_PARM_DESC(autoconf, "Enable IPv6 address autoconfiguration on all interfaces"); + +bool ipv6_mod_enabled(void) +{ + return disable_ipv6_mod == 0; +} +EXPORT_SYMBOL_GPL(ipv6_mod_enabled); + +static __inline__ struct ipv6_pinfo *inet6_sk_generic(struct sock *sk) +{ + const int offset = sk->sk_prot->obj_size - sizeof(struct ipv6_pinfo); + + return (struct ipv6_pinfo *)(((u8 *)sk) + offset); +} + +void inet6_sock_destruct(struct sock *sk) +{ + inet6_cleanup_sock(sk); + inet_sock_destruct(sk); +} +EXPORT_SYMBOL_GPL(inet6_sock_destruct); + +static int inet6_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct inet_sock *inet; + struct ipv6_pinfo *np; + struct sock *sk; + struct inet_protosw *answer; + struct proto *answer_prot; + unsigned char answer_flags; + int try_loading_module = 0; + int err; + + if (protocol < 0 || protocol >= IPPROTO_MAX) + return -EINVAL; + + /* Look for the requested type/protocol pair. */ +lookup_protocol: + err = -ESOCKTNOSUPPORT; + rcu_read_lock(); + list_for_each_entry_rcu(answer, &inetsw6[sock->type], list) { + + err = 0; + /* Check the non-wild match. */ + if (protocol == answer->protocol) { + if (protocol != IPPROTO_IP) + break; + } else { + /* Check for the two wild cases. */ + if (IPPROTO_IP == protocol) { + protocol = answer->protocol; + break; + } + if (IPPROTO_IP == answer->protocol) + break; + } + err = -EPROTONOSUPPORT; + } + + if (err) { + if (try_loading_module < 2) { + rcu_read_unlock(); + /* + * Be more specific, e.g. net-pf-10-proto-132-type-1 + * (net-pf-PF_INET6-proto-IPPROTO_SCTP-type-SOCK_STREAM) + */ + if (++try_loading_module == 1) + request_module("net-pf-%d-proto-%d-type-%d", + PF_INET6, protocol, sock->type); + /* + * Fall back to generic, e.g. 
net-pf-10-proto-132 + * (net-pf-PF_INET6-proto-IPPROTO_SCTP) + */ + else + request_module("net-pf-%d-proto-%d", + PF_INET6, protocol); + goto lookup_protocol; + } else + goto out_rcu_unlock; + } + + err = -EPERM; + if (sock->type == SOCK_RAW && !kern && + !ns_capable(net->user_ns, CAP_NET_RAW)) + goto out_rcu_unlock; + + sock->ops = answer->ops; + answer_prot = answer->prot; + answer_flags = answer->flags; + rcu_read_unlock(); + + WARN_ON(!answer_prot->slab); + + err = -ENOBUFS; + sk = sk_alloc(net, PF_INET6, GFP_KERNEL, answer_prot, kern); + if (!sk) + goto out; + + sock_init_data(sock, sk); + + err = 0; + if (INET_PROTOSW_REUSE & answer_flags) + sk->sk_reuse = SK_CAN_REUSE; + + if (INET_PROTOSW_ICSK & answer_flags) + inet_init_csk_locks(sk); + + inet = inet_sk(sk); + inet->is_icsk = (INET_PROTOSW_ICSK & answer_flags) != 0; + + if (SOCK_RAW == sock->type) { + inet->inet_num = protocol; + if (IPPROTO_RAW == protocol) + inet->hdrincl = 1; + } + + sk->sk_destruct = inet6_sock_destruct; + sk->sk_family = PF_INET6; + sk->sk_protocol = protocol; + + sk->sk_backlog_rcv = answer->prot->backlog_rcv; + + inet_sk(sk)->pinet6 = np = inet6_sk_generic(sk); + np->hop_limit = -1; + np->mcast_hops = IPV6_DEFAULT_MCASTHOPS; + np->mc_loop = 1; + np->mc_all = 1; + np->pmtudisc = IPV6_PMTUDISC_WANT; + np->repflow = net->ipv6.sysctl.flowlabel_reflect & FLOWLABEL_REFLECT_ESTABLISHED; + sk->sk_ipv6only = net->ipv6.sysctl.bindv6only; + sk->sk_txrehash = READ_ONCE(net->core.sysctl_txrehash); + + /* Init the ipv4 part of the socket since we can have sockets + * using v6 API for ipv4. + */ + inet->uc_ttl = -1; + + inet->mc_loop = 1; + inet->mc_ttl = 1; + inet->mc_index = 0; + RCU_INIT_POINTER(inet->mc_list, NULL); + inet->rcv_tos = 0; + + if (READ_ONCE(net->ipv4.sysctl_ip_no_pmtu_disc)) + inet->pmtudisc = IP_PMTUDISC_DONT; + else + inet->pmtudisc = IP_PMTUDISC_WANT; + /* + * Increment only the relevant sk_prot->socks debug field, this changes + * the previous behaviour of incrementing both the equivalent to + * answer->prot->socks (inet6_sock_nr) and inet_sock_nr. + * + * This allows better debug granularity as we'll know exactly how many + * UDPv6, TCPv6, etc socks were allocated, not the sum of all IPv6 + * transport protocol socks. -acme + */ + sk_refcnt_debug_inc(sk); + + if (inet->inet_num) { + /* It assumes that any protocol which allows + * the user to assign a number at socket + * creation time automatically shares. 
+ */ + inet->inet_sport = htons(inet->inet_num); + err = sk->sk_prot->hash(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } + if (sk->sk_prot->init) { + err = sk->sk_prot->init(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } + + if (!kern) { + err = BPF_CGROUP_RUN_PROG_INET_SOCK(sk); + if (err) { + sk_common_release(sk); + goto out; + } + } +out: + return err; +out_rcu_unlock: + rcu_read_unlock(); + goto out; +} + +static int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len, + u32 flags) +{ + struct sockaddr_in6 *addr = (struct sockaddr_in6 *)uaddr; + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + __be32 v4addr = 0; + unsigned short snum; + bool saved_ipv6only; + int addr_type = 0; + int err = 0; + + if (addr->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + addr_type = ipv6_addr_type(&addr->sin6_addr); + if ((addr_type & IPV6_ADDR_MULTICAST) && sk->sk_type == SOCK_STREAM) + return -EINVAL; + + snum = ntohs(addr->sin6_port); + if (!(flags & BIND_NO_CAP_NET_BIND_SERVICE) && + snum && inet_port_requires_bind_service(net, snum) && + !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) + return -EACCES; + + if (flags & BIND_WITH_LOCK) + lock_sock(sk); + + /* Check these errors (active socket, double bind). */ + if (sk->sk_state != TCP_CLOSE || inet->inet_num) { + err = -EINVAL; + goto out; + } + + /* Check if the address belongs to the host. */ + if (addr_type == IPV6_ADDR_MAPPED) { + struct net_device *dev = NULL; + int chk_addr_ret; + + /* Binding to v4-mapped address on a v6-only socket + * makes no sense + */ + if (ipv6_only_sock(sk)) { + err = -EINVAL; + goto out; + } + + rcu_read_lock(); + if (sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(net, sk->sk_bound_dev_if); + if (!dev) { + err = -ENODEV; + goto out_unlock; + } + } + + /* Reproduce AF_INET checks to make the bindings consistent */ + v4addr = addr->sin6_addr.s6_addr32[3]; + chk_addr_ret = inet_addr_type_dev_table(net, dev, v4addr); + rcu_read_unlock(); + + if (!inet_addr_valid_or_nonlocal(net, inet, v4addr, + chk_addr_ret)) { + err = -EADDRNOTAVAIL; + goto out; + } + } else { + if (addr_type != IPV6_ADDR_ANY) { + struct net_device *dev = NULL; + + rcu_read_lock(); + if (__ipv6_addr_needs_scope_id(addr_type)) { + if (addr_len >= sizeof(struct sockaddr_in6) && + addr->sin6_scope_id) { + /* Override any existing binding, if another one + * is supplied by user. + */ + sk->sk_bound_dev_if = addr->sin6_scope_id; + } + + /* Binding to link-local address requires an interface */ + if (!sk->sk_bound_dev_if) { + err = -EINVAL; + goto out_unlock; + } + } + + if (sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(net, sk->sk_bound_dev_if); + if (!dev) { + err = -ENODEV; + goto out_unlock; + } + } + + /* ipv4 addr of the socket is invalid. Only the + * unspecified and mapped address have a v4 equivalent. + */ + v4addr = LOOPBACK4_IPV6; + if (!(addr_type & IPV6_ADDR_MULTICAST)) { + if (!ipv6_can_nonlocal_bind(net, inet) && + !ipv6_chk_addr(net, &addr->sin6_addr, + dev, 0)) { + err = -EADDRNOTAVAIL; + goto out_unlock; + } + } + rcu_read_unlock(); + } + } + + inet->inet_rcv_saddr = v4addr; + inet->inet_saddr = v4addr; + + sk->sk_v6_rcv_saddr = addr->sin6_addr; + + if (!(addr_type & IPV6_ADDR_MULTICAST)) + np->saddr = addr->sin6_addr; + + saved_ipv6only = sk->sk_ipv6only; + if (addr_type != IPV6_ADDR_ANY && addr_type != IPV6_ADDR_MAPPED) + sk->sk_ipv6only = 1; + + /* Make sure we are allowed to bind here. 
*/ + if (snum || !(inet->bind_address_no_port || + (flags & BIND_FORCE_ADDRESS_NO_PORT))) { + err = sk->sk_prot->get_port(sk, snum); + if (err) { + sk->sk_ipv6only = saved_ipv6only; + inet_reset_saddr(sk); + goto out; + } + if (!(flags & BIND_FROM_BPF)) { + err = BPF_CGROUP_RUN_PROG_INET6_POST_BIND(sk); + if (err) { + sk->sk_ipv6only = saved_ipv6only; + inet_reset_saddr(sk); + if (sk->sk_prot->put_port) + sk->sk_prot->put_port(sk); + goto out; + } + } + } + + if (addr_type != IPV6_ADDR_ANY) + sk->sk_userlocks |= SOCK_BINDADDR_LOCK; + if (snum) + sk->sk_userlocks |= SOCK_BINDPORT_LOCK; + inet->inet_sport = htons(inet->inet_num); + inet->inet_dport = 0; + inet->inet_daddr = 0; +out: + if (flags & BIND_WITH_LOCK) + release_sock(sk); + return err; +out_unlock: + rcu_read_unlock(); + goto out; +} + +/* bind for INET6 API */ +int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + u32 flags = BIND_WITH_LOCK; + const struct proto *prot; + int err = 0; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + /* If the socket has its own bind function then use it. */ + if (prot->bind) + return prot->bind(sk, uaddr, addr_len); + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + /* BPF prog is run before any checks are done so that if the prog + * changes context in a wrong way it will be caught. + */ + err = BPF_CGROUP_RUN_PROG_INET_BIND_LOCK(sk, uaddr, + CGROUP_INET6_BIND, &flags); + if (err) + return err; + + return __inet6_bind(sk, uaddr, addr_len, flags); +} +EXPORT_SYMBOL(inet6_bind); + +int inet6_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (!sk) + return -EINVAL; + + /* Free mc lists */ + ipv6_sock_mc_close(sk); + + /* Free ac lists */ + ipv6_sock_ac_close(sk); + + return inet_release(sock); +} +EXPORT_SYMBOL(inet6_release); + +void inet6_destroy_sock(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct sk_buff *skb; + struct ipv6_txoptions *opt; + + /* Release rx options */ + + skb = xchg(&np->pktoptions, NULL); + kfree_skb(skb); + + skb = xchg(&np->rxpmtu, NULL); + kfree_skb(skb); + + /* Free flowlabels */ + fl6_free_socklist(sk); + + /* Free tx options */ + + opt = xchg((__force struct ipv6_txoptions **)&np->opt, NULL); + if (opt) { + atomic_sub(opt->tot_len, &sk->sk_omem_alloc); + txopt_put(opt); + } +} +EXPORT_SYMBOL_GPL(inet6_destroy_sock); + +void inet6_cleanup_sock(struct sock *sk) +{ + inet6_destroy_sock(sk); +} +EXPORT_SYMBOL_GPL(inet6_cleanup_sock); + +/* + * This does both peername and sockname. 
+ */ +int inet6_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_in6 *sin = (struct sockaddr_in6 *)uaddr; + struct sock *sk = sock->sk; + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + + sin->sin6_family = AF_INET6; + sin->sin6_flowinfo = 0; + sin->sin6_scope_id = 0; + lock_sock(sk); + if (peer) { + if (!inet->inet_dport || + (((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_SYN_SENT)) && + peer == 1)) { + release_sock(sk); + return -ENOTCONN; + } + sin->sin6_port = inet->inet_dport; + sin->sin6_addr = sk->sk_v6_daddr; + if (np->sndflow) + sin->sin6_flowinfo = np->flow_label; + BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + CGROUP_INET6_GETPEERNAME); + } else { + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + sin->sin6_addr = np->saddr; + else + sin->sin6_addr = sk->sk_v6_rcv_saddr; + sin->sin6_port = inet->inet_sport; + BPF_CGROUP_RUN_SA_PROG(sk, (struct sockaddr *)sin, + CGROUP_INET6_GETSOCKNAME); + } + sin->sin6_scope_id = ipv6_iface_scope_id(&sin->sin6_addr, + sk->sk_bound_dev_if); + release_sock(sk); + return sizeof(*sin); +} +EXPORT_SYMBOL(inet6_getname); + +int inet6_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = (void __user *)arg; + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + const struct proto *prot; + + switch (cmd) { + case SIOCADDRT: + case SIOCDELRT: { + struct in6_rtmsg rtmsg; + + if (copy_from_user(&rtmsg, argp, sizeof(rtmsg))) + return -EFAULT; + return ipv6_route_ioctl(net, cmd, &rtmsg); + } + case SIOCSIFADDR: + return addrconf_add_ifaddr(net, argp); + case SIOCDIFADDR: + return addrconf_del_ifaddr(net, argp); + case SIOCSIFDSTADDR: + return addrconf_set_dstaddr(net, argp); + default: + /* IPV6_ADDRFORM can change sk->sk_prot under us. 
*/ + prot = READ_ONCE(sk->sk_prot); + if (!prot->ioctl) + return -ENOIOCTLCMD; + return prot->ioctl(sk, cmd, arg); + } + /*NOTREACHED*/ + return 0; +} +EXPORT_SYMBOL(inet6_ioctl); + +#ifdef CONFIG_COMPAT +struct compat_in6_rtmsg { + struct in6_addr rtmsg_dst; + struct in6_addr rtmsg_src; + struct in6_addr rtmsg_gateway; + u32 rtmsg_type; + u16 rtmsg_dst_len; + u16 rtmsg_src_len; + u32 rtmsg_metric; + u32 rtmsg_info; + u32 rtmsg_flags; + s32 rtmsg_ifindex; +}; + +static int inet6_compat_routing_ioctl(struct sock *sk, unsigned int cmd, + struct compat_in6_rtmsg __user *ur) +{ + struct in6_rtmsg rt; + + if (copy_from_user(&rt.rtmsg_dst, &ur->rtmsg_dst, + 3 * sizeof(struct in6_addr)) || + get_user(rt.rtmsg_type, &ur->rtmsg_type) || + get_user(rt.rtmsg_dst_len, &ur->rtmsg_dst_len) || + get_user(rt.rtmsg_src_len, &ur->rtmsg_src_len) || + get_user(rt.rtmsg_metric, &ur->rtmsg_metric) || + get_user(rt.rtmsg_info, &ur->rtmsg_info) || + get_user(rt.rtmsg_flags, &ur->rtmsg_flags) || + get_user(rt.rtmsg_ifindex, &ur->rtmsg_ifindex)) + return -EFAULT; + + + return ipv6_route_ioctl(sock_net(sk), cmd, &rt); +} + +int inet6_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + struct sock *sk = sock->sk; + + switch (cmd) { + case SIOCADDRT: + case SIOCDELRT: + return inet6_compat_routing_ioctl(sk, cmd, argp); + default: + return -ENOIOCTLCMD; + } +} +EXPORT_SYMBOL_GPL(inet6_compat_ioctl); +#endif /* CONFIG_COMPAT */ + +INDIRECT_CALLABLE_DECLARE(int udpv6_sendmsg(struct sock *, struct msghdr *, + size_t)); +int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size) +{ + struct sock *sk = sock->sk; + const struct proto *prot; + + if (unlikely(inet_send_prepare(sk))) + return -EAGAIN; + + /* IPV6_ADDRFORM can change sk->sk_prot under us. */ + prot = READ_ONCE(sk->sk_prot); + return INDIRECT_CALL_2(prot->sendmsg, tcp_sendmsg, udpv6_sendmsg, + sk, msg, size); +} + +INDIRECT_CALLABLE_DECLARE(int udpv6_recvmsg(struct sock *, struct msghdr *, + size_t, int, int *)); +int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + const struct proto *prot; + int addr_len = 0; + int err; + + if (likely(!(flags & MSG_ERRQUEUE))) + sock_rps_record_flow(sk); + + /* IPV6_ADDRFORM can change sk->sk_prot under us. 
*/ + prot = READ_ONCE(sk->sk_prot); + err = INDIRECT_CALL_2(prot->recvmsg, tcp_recvmsg, udpv6_recvmsg, + sk, msg, size, flags, &addr_len); + if (err >= 0) + msg->msg_namelen = addr_len; + return err; +} + +const struct proto_ops inet6_stream_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = inet_stream_connect, /* ok */ + .socketpair = sock_no_socketpair, /* a do nothing */ + .accept = inet_accept, /* ok */ + .getname = inet6_getname, + .poll = tcp_poll, /* ok */ + .ioctl = inet6_ioctl, /* must change */ + .gettstamp = sock_gettstamp, + .listen = inet_listen, /* ok */ + .shutdown = inet_shutdown, /* ok */ + .setsockopt = sock_common_setsockopt, /* ok */ + .getsockopt = sock_common_getsockopt, /* ok */ + .sendmsg = inet6_sendmsg, /* retpoline's sake */ + .recvmsg = inet6_recvmsg, /* retpoline's sake */ +#ifdef CONFIG_MMU + .mmap = tcp_mmap, +#endif + .splice_eof = inet_splice_eof, + .sendpage = inet_sendpage, + .sendmsg_locked = tcp_sendmsg_locked, + .sendpage_locked = tcp_sendpage_locked, + .splice_read = tcp_splice_read, + .read_sock = tcp_read_sock, + .read_skb = tcp_read_skb, + .peek_len = tcp_peek_len, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif + .set_rcvlowat = tcp_set_rcvlowat, +}; + +const struct proto_ops inet6_dgram_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = inet_dgram_connect, /* ok */ + .socketpair = sock_no_socketpair, /* a do nothing */ + .accept = sock_no_accept, /* a do nothing */ + .getname = inet6_getname, + .poll = udp_poll, /* ok */ + .ioctl = inet6_ioctl, /* must change */ + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, /* ok */ + .shutdown = inet_shutdown, /* ok */ + .setsockopt = sock_common_setsockopt, /* ok */ + .getsockopt = sock_common_getsockopt, /* ok */ + .sendmsg = inet6_sendmsg, /* retpoline's sake */ + .recvmsg = inet6_recvmsg, /* retpoline's sake */ + .read_skb = udp_read_skb, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, + .set_peek_off = sk_set_peek_off, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static const struct net_proto_family inet6_family_ops = { + .family = PF_INET6, + .create = inet6_create, + .owner = THIS_MODULE, +}; + +int inet6_register_protosw(struct inet_protosw *p) +{ + struct list_head *lh; + struct inet_protosw *answer; + struct list_head *last_perm; + int protocol = p->protocol; + int ret; + + spin_lock_bh(&inetsw6_lock); + + ret = -EINVAL; + if (p->type >= SOCK_MAX) + goto out_illegal; + + /* If we are trying to override a permanent protocol, bail. */ + answer = NULL; + ret = -EPERM; + last_perm = &inetsw6[p->type]; + list_for_each(lh, &inetsw6[p->type]) { + answer = list_entry(lh, struct inet_protosw, list); + + /* Check only the non-wild match. */ + if (INET_PROTOSW_PERMANENT & answer->flags) { + if (protocol == answer->protocol) + break; + last_perm = lh; + } + + answer = NULL; + } + if (answer) + goto out_permanent; + + /* Add the new entry after the last permanent entry if any, so that + * the new entry does not override a permanent entry when matched with + * a wild-card protocol. But it is allowed to override any existing + * non-permanent entry. This means that when we remove this entry, the + * system automatically returns to the old behavior. 
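+ * Non-permanent entries therefore always sort behind every permanent
+ * entry of the same socket type and can never shadow one of them.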
+ */ + list_add_rcu(&p->list, last_perm); + ret = 0; +out: + spin_unlock_bh(&inetsw6_lock); + return ret; + +out_permanent: + pr_err("Attempt to override permanent protocol %d\n", protocol); + goto out; + +out_illegal: + pr_err("Ignoring attempt to register invalid socket type %d\n", + p->type); + goto out; +} +EXPORT_SYMBOL(inet6_register_protosw); + +void +inet6_unregister_protosw(struct inet_protosw *p) +{ + if (INET_PROTOSW_PERMANENT & p->flags) { + pr_err("Attempt to unregister permanent protocol %d\n", + p->protocol); + } else { + spin_lock_bh(&inetsw6_lock); + list_del_rcu(&p->list); + spin_unlock_bh(&inetsw6_lock); + + synchronize_net(); + } +} +EXPORT_SYMBOL(inet6_unregister_protosw); + +int inet6_sk_rebuild_header(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct dst_entry *dst; + + dst = __sk_dst_check(sk, np->dst_cookie); + + if (!dst) { + struct inet_sock *inet = inet_sk(sk); + struct in6_addr *final_p, final; + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = sk->sk_protocol; + fl6.daddr = sk->sk_v6_daddr; + fl6.saddr = np->saddr; + fl6.flowlabel = np->flow_label; + fl6.flowi6_oif = sk->sk_bound_dev_if; + fl6.flowi6_mark = sk->sk_mark; + fl6.fl6_dport = inet->inet_dport; + fl6.fl6_sport = inet->inet_sport; + fl6.flowi6_uid = sk->sk_uid; + security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6)); + + rcu_read_lock(); + final_p = fl6_update_dst(&fl6, rcu_dereference(np->opt), + &final); + rcu_read_unlock(); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) { + sk->sk_route_caps = 0; + sk->sk_err_soft = -PTR_ERR(dst); + return PTR_ERR(dst); + } + + ip6_dst_store(sk, dst, NULL, NULL); + } + + return 0; +} +EXPORT_SYMBOL_GPL(inet6_sk_rebuild_header); + +bool ipv6_opt_accepted(const struct sock *sk, const struct sk_buff *skb, + const struct inet6_skb_parm *opt) +{ + const struct ipv6_pinfo *np = inet6_sk(sk); + + if (np->rxopt.all) { + if (((opt->flags & IP6SKB_HOPBYHOP) && + (np->rxopt.bits.hopopts || np->rxopt.bits.ohopopts)) || + (ip6_flowinfo((struct ipv6hdr *) skb_network_header(skb)) && + np->rxopt.bits.rxflow) || + (opt->srcrt && (np->rxopt.bits.srcrt || + np->rxopt.bits.osrcrt)) || + ((opt->dst1 || opt->dst0) && + (np->rxopt.bits.dstopts || np->rxopt.bits.odstopts))) + return true; + } + return false; +} +EXPORT_SYMBOL_GPL(ipv6_opt_accepted); + +static struct packet_type ipv6_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_IPV6), + .func = ipv6_rcv, + .list_func = ipv6_list_rcv, +}; + +static int __init ipv6_packet_init(void) +{ + dev_add_pack(&ipv6_packet_type); + return 0; +} + +static void ipv6_packet_cleanup(void) +{ + dev_remove_pack(&ipv6_packet_type); +} + +static int __net_init ipv6_init_mibs(struct net *net) +{ + int i; + + net->mib.udp_stats_in6 = alloc_percpu(struct udp_mib); + if (!net->mib.udp_stats_in6) + return -ENOMEM; + net->mib.udplite_stats_in6 = alloc_percpu(struct udp_mib); + if (!net->mib.udplite_stats_in6) + goto err_udplite_mib; + net->mib.ipv6_statistics = alloc_percpu(struct ipstats_mib); + if (!net->mib.ipv6_statistics) + goto err_ip_mib; + + for_each_possible_cpu(i) { + struct ipstats_mib *af_inet6_stats; + af_inet6_stats = per_cpu_ptr(net->mib.ipv6_statistics, i); + u64_stats_init(&af_inet6_stats->syncp); + } + + + net->mib.icmpv6_statistics = alloc_percpu(struct icmpv6_mib); + if (!net->mib.icmpv6_statistics) + goto err_icmp_mib; + net->mib.icmpv6msg_statistics = kzalloc(sizeof(struct icmpv6msg_mib), + GFP_KERNEL); + if (!net->mib.icmpv6msg_statistics) + 
goto err_icmpmsg_mib; + return 0; + +err_icmpmsg_mib: + free_percpu(net->mib.icmpv6_statistics); +err_icmp_mib: + free_percpu(net->mib.ipv6_statistics); +err_ip_mib: + free_percpu(net->mib.udplite_stats_in6); +err_udplite_mib: + free_percpu(net->mib.udp_stats_in6); + return -ENOMEM; +} + +static void ipv6_cleanup_mibs(struct net *net) +{ + free_percpu(net->mib.udp_stats_in6); + free_percpu(net->mib.udplite_stats_in6); + free_percpu(net->mib.ipv6_statistics); + free_percpu(net->mib.icmpv6_statistics); + kfree(net->mib.icmpv6msg_statistics); +} + +static int __net_init inet6_net_init(struct net *net) +{ + int err = 0; + + net->ipv6.sysctl.bindv6only = 0; + net->ipv6.sysctl.icmpv6_time = 1*HZ; + net->ipv6.sysctl.icmpv6_echo_ignore_all = 0; + net->ipv6.sysctl.icmpv6_echo_ignore_multicast = 0; + net->ipv6.sysctl.icmpv6_echo_ignore_anycast = 0; + + /* By default, rate limit error messages. + * Except for pmtu discovery, it would break it. + * proc_do_large_bitmap needs pointer to the bitmap. + */ + bitmap_set(net->ipv6.sysctl.icmpv6_ratemask, 0, ICMPV6_ERRMSG_MAX + 1); + bitmap_clear(net->ipv6.sysctl.icmpv6_ratemask, ICMPV6_PKT_TOOBIG, 1); + net->ipv6.sysctl.icmpv6_ratemask_ptr = net->ipv6.sysctl.icmpv6_ratemask; + + net->ipv6.sysctl.flowlabel_consistency = 1; + net->ipv6.sysctl.auto_flowlabels = IP6_DEFAULT_AUTO_FLOW_LABELS; + net->ipv6.sysctl.idgen_retries = 3; + net->ipv6.sysctl.idgen_delay = 1 * HZ; + net->ipv6.sysctl.flowlabel_state_ranges = 0; + net->ipv6.sysctl.max_dst_opts_cnt = IP6_DEFAULT_MAX_DST_OPTS_CNT; + net->ipv6.sysctl.max_hbh_opts_cnt = IP6_DEFAULT_MAX_HBH_OPTS_CNT; + net->ipv6.sysctl.max_dst_opts_len = IP6_DEFAULT_MAX_DST_OPTS_LEN; + net->ipv6.sysctl.max_hbh_opts_len = IP6_DEFAULT_MAX_HBH_OPTS_LEN; + net->ipv6.sysctl.fib_notify_on_flag_change = 0; + atomic_set(&net->ipv6.fib6_sernum, 1); + + net->ipv6.sysctl.ioam6_id = IOAM6_DEFAULT_ID; + net->ipv6.sysctl.ioam6_id_wide = IOAM6_DEFAULT_ID_WIDE; + + err = ipv6_init_mibs(net); + if (err) + return err; +#ifdef CONFIG_PROC_FS + err = udp6_proc_init(net); + if (err) + goto out; + err = tcp6_proc_init(net); + if (err) + goto proc_tcp6_fail; + err = ac6_proc_init(net); + if (err) + goto proc_ac6_fail; +#endif + return err; + +#ifdef CONFIG_PROC_FS +proc_ac6_fail: + tcp6_proc_exit(net); +proc_tcp6_fail: + udp6_proc_exit(net); +out: + ipv6_cleanup_mibs(net); + return err; +#endif +} + +static void __net_exit inet6_net_exit(struct net *net) +{ +#ifdef CONFIG_PROC_FS + udp6_proc_exit(net); + tcp6_proc_exit(net); + ac6_proc_exit(net); +#endif + ipv6_cleanup_mibs(net); +} + +static struct pernet_operations inet6_net_ops = { + .init = inet6_net_init, + .exit = inet6_net_exit, +}; + +static int ipv6_route_input(struct sk_buff *skb) +{ + ip6_route_input(skb); + return skb_dst(skb)->error; +} + +static const struct ipv6_stub ipv6_stub_impl = { + .ipv6_sock_mc_join = ipv6_sock_mc_join, + .ipv6_sock_mc_drop = ipv6_sock_mc_drop, + .ipv6_dst_lookup_flow = ip6_dst_lookup_flow, + .ipv6_route_input = ipv6_route_input, + .fib6_get_table = fib6_get_table, + .fib6_table_lookup = fib6_table_lookup, + .fib6_lookup = fib6_lookup, + .fib6_select_path = fib6_select_path, + .ip6_mtu_from_fib6 = ip6_mtu_from_fib6, + .fib6_nh_init = fib6_nh_init, + .fib6_nh_release = fib6_nh_release, + .fib6_nh_release_dsts = fib6_nh_release_dsts, + .fib6_update_sernum = fib6_update_sernum_stub, + .fib6_rt_update = fib6_rt_update, + .ip6_del_rt = ip6_del_rt, + .udpv6_encap_enable = udpv6_encap_enable, + .ndisc_send_na = ndisc_send_na, +#if IS_ENABLED(CONFIG_XFRM) + 
.xfrm6_local_rxpmtu = xfrm6_local_rxpmtu, + .xfrm6_udp_encap_rcv = xfrm6_udp_encap_rcv, + .xfrm6_rcv_encap = xfrm6_rcv_encap, +#endif + .nd_tbl = &nd_tbl, + .ipv6_fragment = ip6_fragment, + .ipv6_dev_find = ipv6_dev_find, +}; + +static const struct ipv6_bpf_stub ipv6_bpf_stub_impl = { + .inet6_bind = __inet6_bind, + .udp6_lib_lookup = __udp6_lib_lookup, + .ipv6_setsockopt = do_ipv6_setsockopt, + .ipv6_getsockopt = do_ipv6_getsockopt, +}; + +static int __init inet6_init(void) +{ + struct list_head *r; + int err = 0; + + sock_skb_cb_check_size(sizeof(struct inet6_skb_parm)); + + /* Register the socket-side information for inet6_create. */ + for (r = &inetsw6[0]; r < &inetsw6[SOCK_MAX]; ++r) + INIT_LIST_HEAD(r); + + raw_hashinfo_init(&raw_v6_hashinfo); + + if (disable_ipv6_mod) { + pr_info("Loaded, but administratively disabled, reboot required to enable\n"); + goto out; + } + + err = proto_register(&tcpv6_prot, 1); + if (err) + goto out; + + err = proto_register(&udpv6_prot, 1); + if (err) + goto out_unregister_tcp_proto; + + err = proto_register(&udplitev6_prot, 1); + if (err) + goto out_unregister_udp_proto; + + err = proto_register(&rawv6_prot, 1); + if (err) + goto out_unregister_udplite_proto; + + err = proto_register(&pingv6_prot, 1); + if (err) + goto out_unregister_raw_proto; + + /* We MUST register RAW sockets before we create the ICMP6, + * IGMP6, or NDISC control sockets. + */ + err = rawv6_init(); + if (err) + goto out_unregister_ping_proto; + + /* Register the family here so that the init calls below will + * be able to create sockets. (?? is this dangerous ??) + */ + err = sock_register(&inet6_family_ops); + if (err) + goto out_sock_register_fail; + + /* + * ipngwg API draft makes clear that the correct semantics + * for TCP and UDP is to consider one TCP and UDP instance + * in a host available by both INET and INET6 APIs and + * able to communicate via both network protocols. + */ + + err = register_pernet_subsys(&inet6_net_ops); + if (err) + goto register_pernet_fail; + err = ip6_mr_init(); + if (err) + goto ipmr_fail; + err = icmpv6_init(); + if (err) + goto icmp_fail; + err = ndisc_init(); + if (err) + goto ndisc_fail; + err = igmp6_init(); + if (err) + goto igmp_fail; + + err = ipv6_netfilter_init(); + if (err) + goto netfilter_fail; + /* Create /proc/foo6 entries. */ +#ifdef CONFIG_PROC_FS + err = -ENOMEM; + if (raw6_proc_init()) + goto proc_raw6_fail; + if (udplite6_proc_init()) + goto proc_udplite6_fail; + if (ipv6_misc_proc_init()) + goto proc_misc6_fail; + if (if6_proc_init()) + goto proc_if6_fail; +#endif + err = ip6_route_init(); + if (err) + goto ip6_route_fail; + err = ndisc_late_init(); + if (err) + goto ndisc_late_fail; + err = ip6_flowlabel_init(); + if (err) + goto ip6_flowlabel_fail; + err = ipv6_anycast_init(); + if (err) + goto ipv6_anycast_fail; + err = addrconf_init(); + if (err) + goto addrconf_fail; + + /* Init v6 extension headers. */ + err = ipv6_exthdrs_init(); + if (err) + goto ipv6_exthdrs_fail; + + err = ipv6_frag_init(); + if (err) + goto ipv6_frag_fail; + + /* Init v6 transport protocols. 
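+ * (UDP and UDP-Lite first, then the UDP GRO/GSO offload, then TCP;
+ * the IPv6 packet handler is only registered after that.)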
*/ + err = udpv6_init(); + if (err) + goto udpv6_fail; + + err = udplitev6_init(); + if (err) + goto udplitev6_fail; + + err = udpv6_offload_init(); + if (err) + goto udpv6_offload_fail; + + err = tcpv6_init(); + if (err) + goto tcpv6_fail; + + err = ipv6_packet_init(); + if (err) + goto ipv6_packet_fail; + + err = pingv6_init(); + if (err) + goto pingv6_fail; + + err = calipso_init(); + if (err) + goto calipso_fail; + + err = seg6_init(); + if (err) + goto seg6_fail; + + err = rpl_init(); + if (err) + goto rpl_fail; + + err = ioam6_init(); + if (err) + goto ioam6_fail; + + err = igmp6_late_init(); + if (err) + goto igmp6_late_err; + +#ifdef CONFIG_SYSCTL + err = ipv6_sysctl_register(); + if (err) + goto sysctl_fail; +#endif + + /* ensure that ipv6 stubs are visible only after ipv6 is ready */ + wmb(); + ipv6_stub = &ipv6_stub_impl; + ipv6_bpf_stub = &ipv6_bpf_stub_impl; +out: + return err; + +#ifdef CONFIG_SYSCTL +sysctl_fail: + igmp6_late_cleanup(); +#endif +igmp6_late_err: + ioam6_exit(); +ioam6_fail: + rpl_exit(); +rpl_fail: + seg6_exit(); +seg6_fail: + calipso_exit(); +calipso_fail: + pingv6_exit(); +pingv6_fail: + ipv6_packet_cleanup(); +ipv6_packet_fail: + tcpv6_exit(); +tcpv6_fail: + udpv6_offload_exit(); +udpv6_offload_fail: + udplitev6_exit(); +udplitev6_fail: + udpv6_exit(); +udpv6_fail: + ipv6_frag_exit(); +ipv6_frag_fail: + ipv6_exthdrs_exit(); +ipv6_exthdrs_fail: + addrconf_cleanup(); +addrconf_fail: + ipv6_anycast_cleanup(); +ipv6_anycast_fail: + ip6_flowlabel_cleanup(); +ip6_flowlabel_fail: + ndisc_late_cleanup(); +ndisc_late_fail: + ip6_route_cleanup(); +ip6_route_fail: +#ifdef CONFIG_PROC_FS + if6_proc_exit(); +proc_if6_fail: + ipv6_misc_proc_exit(); +proc_misc6_fail: + udplite6_proc_exit(); +proc_udplite6_fail: + raw6_proc_exit(); +proc_raw6_fail: +#endif + ipv6_netfilter_fini(); +netfilter_fail: + igmp6_cleanup(); +igmp_fail: + ndisc_cleanup(); +ndisc_fail: + icmpv6_cleanup(); +icmp_fail: + ip6_mr_cleanup(); +ipmr_fail: + unregister_pernet_subsys(&inet6_net_ops); +register_pernet_fail: + sock_unregister(PF_INET6); + rtnl_unregister_all(PF_INET6); +out_sock_register_fail: + rawv6_exit(); +out_unregister_ping_proto: + proto_unregister(&pingv6_prot); +out_unregister_raw_proto: + proto_unregister(&rawv6_prot); +out_unregister_udplite_proto: + proto_unregister(&udplitev6_prot); +out_unregister_udp_proto: + proto_unregister(&udpv6_prot); +out_unregister_tcp_proto: + proto_unregister(&tcpv6_prot); + goto out; +} +module_init(inet6_init); + +MODULE_ALIAS_NETPROTO(PF_INET6); diff --git a/net/ipv6/ah6.c b/net/ipv6/ah6.c new file mode 100644 index 000000000..5228d2716 --- /dev/null +++ b/net/ipv6/ah6.c @@ -0,0 +1,807 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C)2002 USAGI/WIDE Project + * + * Authors + * + * Mitsuru KANDA @USAGI : IPv6 Support + * Kazunori MIYAZAWA @USAGI : + * Kunihiro Ishiguro + * + * This file is derived from net/ipv4/ah.c. 
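+ *
+ * It provides the IPsec Authentication Header (AH) transform for IPv6.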
+ */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define IPV6HDR_BASELEN 8 + +struct tmp_ext { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + struct in6_addr saddr; +#endif + struct in6_addr daddr; + char hdrs[]; +}; + +struct ah_skb_cb { + struct xfrm_skb_cb xfrm; + void *tmp; +}; + +#define AH_SKB_CB(__skb) ((struct ah_skb_cb *)&((__skb)->cb[0])) + +static void *ah_alloc_tmp(struct crypto_ahash *ahash, int nfrags, + unsigned int size) +{ + unsigned int len; + + len = size + crypto_ahash_digestsize(ahash) + + (crypto_ahash_alignmask(ahash) & + ~(crypto_tfm_ctx_alignment() - 1)); + + len = ALIGN(len, crypto_tfm_ctx_alignment()); + + len += sizeof(struct ahash_request) + crypto_ahash_reqsize(ahash); + len = ALIGN(len, __alignof__(struct scatterlist)); + + len += sizeof(struct scatterlist) * nfrags; + + return kmalloc(len, GFP_ATOMIC); +} + +static inline struct tmp_ext *ah_tmp_ext(void *base) +{ + return base + IPV6HDR_BASELEN; +} + +static inline u8 *ah_tmp_auth(u8 *tmp, unsigned int offset) +{ + return tmp + offset; +} + +static inline u8 *ah_tmp_icv(struct crypto_ahash *ahash, void *tmp, + unsigned int offset) +{ + return PTR_ALIGN((u8 *)tmp + offset, crypto_ahash_alignmask(ahash) + 1); +} + +static inline struct ahash_request *ah_tmp_req(struct crypto_ahash *ahash, + u8 *icv) +{ + struct ahash_request *req; + + req = (void *)PTR_ALIGN(icv + crypto_ahash_digestsize(ahash), + crypto_tfm_ctx_alignment()); + + ahash_request_set_tfm(req, ahash); + + return req; +} + +static inline struct scatterlist *ah_req_sg(struct crypto_ahash *ahash, + struct ahash_request *req) +{ + return (void *)ALIGN((unsigned long)(req + 1) + + crypto_ahash_reqsize(ahash), + __alignof__(struct scatterlist)); +} + +static bool zero_out_mutable_opts(struct ipv6_opt_hdr *opthdr) +{ + u8 *opt = (u8 *)opthdr; + int len = ipv6_optlen(opthdr); + int off = 0; + int optlen = 0; + + off += 2; + len -= 2; + + while (len > 0) { + + switch (opt[off]) { + + case IPV6_TLV_PAD1: + optlen = 1; + break; + default: + if (len < 2) + goto bad; + optlen = opt[off+1]+2; + if (len < optlen) + goto bad; + if (opt[off] & 0x20) + memset(&opt[off+2], 0, opt[off+1]); + break; + } + + off += optlen; + len -= optlen; + } + if (len == 0) + return true; + +bad: + return false; +} + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +/** + * ipv6_rearrange_destopt - rearrange IPv6 destination options header + * @iph: IPv6 header + * @destopt: destionation options header + */ +static void ipv6_rearrange_destopt(struct ipv6hdr *iph, struct ipv6_opt_hdr *destopt) +{ + u8 *opt = (u8 *)destopt; + int len = ipv6_optlen(destopt); + int off = 0; + int optlen = 0; + + off += 2; + len -= 2; + + while (len > 0) { + + switch (opt[off]) { + + case IPV6_TLV_PAD1: + optlen = 1; + break; + default: + if (len < 2) + goto bad; + optlen = opt[off+1]+2; + if (len < optlen) + goto bad; + + /* Rearrange the source address in @iph and the + * addresses in home address option for final source. + * See 11.3.2 of RFC 3775 for details. 
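+ * The Home Address option and the source address are swapped so the
+ * ICV is computed over the packet as it will appear after the final
+ * rearrangement at the receiver.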
+ */ + if (opt[off] == IPV6_TLV_HAO) { + struct ipv6_destopt_hao *hao; + + hao = (struct ipv6_destopt_hao *)&opt[off]; + if (hao->length != sizeof(hao->addr)) { + net_warn_ratelimited("destopt hao: invalid header length: %u\n", + hao->length); + goto bad; + } + swap(hao->addr, iph->saddr); + } + break; + } + + off += optlen; + len -= optlen; + } + /* Note: ok if len == 0 */ +bad: + return; +} +#else +static void ipv6_rearrange_destopt(struct ipv6hdr *iph, struct ipv6_opt_hdr *destopt) {} +#endif + +/** + * ipv6_rearrange_rthdr - rearrange IPv6 routing header + * @iph: IPv6 header + * @rthdr: routing header + * + * Rearrange the destination address in @iph and the addresses in @rthdr + * so that they appear in the order they will at the final destination. + * See Appendix A2 of RFC 2402 for details. + */ +static void ipv6_rearrange_rthdr(struct ipv6hdr *iph, struct ipv6_rt_hdr *rthdr) +{ + int segments, segments_left; + struct in6_addr *addrs; + struct in6_addr final_addr; + + segments_left = rthdr->segments_left; + if (segments_left == 0) + return; + rthdr->segments_left = 0; + + /* The value of rthdr->hdrlen has been verified either by the system + * call if it is locally generated, or by ipv6_rthdr_rcv() for incoming + * packets. So we can assume that it is even and that segments is + * greater than or equal to segments_left. + * + * For the same reason we can assume that this option is of type 0. + */ + segments = rthdr->hdrlen >> 1; + + addrs = ((struct rt0_hdr *)rthdr)->addr; + final_addr = addrs[segments - 1]; + + addrs += segments - segments_left; + memmove(addrs + 1, addrs, (segments_left - 1) * sizeof(*addrs)); + + addrs[0] = iph->daddr; + iph->daddr = final_addr; +} + +static int ipv6_clear_mutable_options(struct ipv6hdr *iph, int len, int dir) +{ + union { + struct ipv6hdr *iph; + struct ipv6_opt_hdr *opth; + struct ipv6_rt_hdr *rth; + char *raw; + } exthdr = { .iph = iph }; + char *end = exthdr.raw + len; + int nexthdr = iph->nexthdr; + + exthdr.iph++; + + while (exthdr.raw < end) { + switch (nexthdr) { + case NEXTHDR_DEST: + if (dir == XFRM_POLICY_OUT) + ipv6_rearrange_destopt(iph, exthdr.opth); + fallthrough; + case NEXTHDR_HOP: + if (!zero_out_mutable_opts(exthdr.opth)) { + net_dbg_ratelimited("overrun %sopts\n", + nexthdr == NEXTHDR_HOP ? 
+ "hop" : "dest"); + return -EINVAL; + } + break; + + case NEXTHDR_ROUTING: + ipv6_rearrange_rthdr(iph, exthdr.rth); + break; + + default: + return 0; + } + + nexthdr = exthdr.opth->nexthdr; + exthdr.raw += ipv6_optlen(exthdr.opth); + } + + return 0; +} + +static void ah6_output_done(struct crypto_async_request *base, int err) +{ + int extlen; + u8 *iph_base; + u8 *icv; + struct sk_buff *skb = base->data; + struct xfrm_state *x = skb_dst(skb)->xfrm; + struct ah_data *ahp = x->data; + struct ipv6hdr *top_iph = ipv6_hdr(skb); + struct ip_auth_hdr *ah = ip_auth_hdr(skb); + struct tmp_ext *iph_ext; + + extlen = skb_network_header_len(skb) - sizeof(struct ipv6hdr); + if (extlen) + extlen += sizeof(*iph_ext); + + iph_base = AH_SKB_CB(skb)->tmp; + iph_ext = ah_tmp_ext(iph_base); + icv = ah_tmp_icv(ahp->ahash, iph_ext, extlen); + + memcpy(ah->auth_data, icv, ahp->icv_trunc_len); + memcpy(top_iph, iph_base, IPV6HDR_BASELEN); + + if (extlen) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + memcpy(&top_iph->saddr, iph_ext, extlen); +#else + memcpy(&top_iph->daddr, iph_ext, extlen); +#endif + } + + kfree(AH_SKB_CB(skb)->tmp); + xfrm_output_resume(skb->sk, skb, err); +} + +static int ah6_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + int nfrags; + int extlen; + u8 *iph_base; + u8 *icv; + u8 nexthdr; + struct sk_buff *trailer; + struct crypto_ahash *ahash; + struct ahash_request *req; + struct scatterlist *sg; + struct ipv6hdr *top_iph; + struct ip_auth_hdr *ah; + struct ah_data *ahp; + struct tmp_ext *iph_ext; + int seqhi_len = 0; + __be32 *seqhi; + int sglists = 0; + struct scatterlist *seqhisg; + + ahp = x->data; + ahash = ahp->ahash; + + err = skb_cow_data(skb, 0, &trailer); + if (err < 0) + goto out; + nfrags = err; + + skb_push(skb, -skb_network_offset(skb)); + extlen = skb_network_header_len(skb) - sizeof(struct ipv6hdr); + if (extlen) + extlen += sizeof(*iph_ext); + + if (x->props.flags & XFRM_STATE_ESN) { + sglists = 1; + seqhi_len = sizeof(*seqhi); + } + err = -ENOMEM; + iph_base = ah_alloc_tmp(ahash, nfrags + sglists, IPV6HDR_BASELEN + + extlen + seqhi_len); + if (!iph_base) + goto out; + + iph_ext = ah_tmp_ext(iph_base); + seqhi = (__be32 *)((char *)iph_ext + extlen); + icv = ah_tmp_icv(ahash, seqhi, seqhi_len); + req = ah_tmp_req(ahash, icv); + sg = ah_req_sg(ahash, req); + seqhisg = sg + nfrags; + + ah = ip_auth_hdr(skb); + memset(ah->auth_data, 0, ahp->icv_trunc_len); + + top_iph = ipv6_hdr(skb); + top_iph->payload_len = htons(skb->len - sizeof(*top_iph)); + + nexthdr = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_AH; + + /* When there are no extension headers, we only need to save the first + * 8 bytes of the base IP header. 
+ */ + memcpy(iph_base, top_iph, IPV6HDR_BASELEN); + + if (extlen) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + memcpy(iph_ext, &top_iph->saddr, extlen); +#else + memcpy(iph_ext, &top_iph->daddr, extlen); +#endif + err = ipv6_clear_mutable_options(top_iph, + extlen - sizeof(*iph_ext) + + sizeof(*top_iph), + XFRM_POLICY_OUT); + if (err) + goto out_free; + } + + ah->nexthdr = nexthdr; + + top_iph->priority = 0; + top_iph->flow_lbl[0] = 0; + top_iph->flow_lbl[1] = 0; + top_iph->flow_lbl[2] = 0; + top_iph->hop_limit = 0; + + ah->hdrlen = (XFRM_ALIGN8(sizeof(*ah) + ahp->icv_trunc_len) >> 2) - 2; + + ah->reserved = 0; + ah->spi = x->id.spi; + ah->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + + sg_init_table(sg, nfrags + sglists); + err = skb_to_sgvec_nomark(skb, sg, 0, skb->len); + if (unlikely(err < 0)) + goto out_free; + + if (x->props.flags & XFRM_STATE_ESN) { + /* Attach seqhi sg right after packet payload */ + *seqhi = htonl(XFRM_SKB_CB(skb)->seq.output.hi); + sg_set_buf(seqhisg, seqhi, seqhi_len); + } + ahash_request_set_crypt(req, sg, icv, skb->len + seqhi_len); + ahash_request_set_callback(req, 0, ah6_output_done, skb); + + AH_SKB_CB(skb)->tmp = iph_base; + + err = crypto_ahash_digest(req); + if (err) { + if (err == -EINPROGRESS) + goto out; + + if (err == -ENOSPC) + err = NET_XMIT_DROP; + goto out_free; + } + + memcpy(ah->auth_data, icv, ahp->icv_trunc_len); + memcpy(top_iph, iph_base, IPV6HDR_BASELEN); + + if (extlen) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + memcpy(&top_iph->saddr, iph_ext, extlen); +#else + memcpy(&top_iph->daddr, iph_ext, extlen); +#endif + } + +out_free: + kfree(iph_base); +out: + return err; +} + +static void ah6_input_done(struct crypto_async_request *base, int err) +{ + u8 *auth_data; + u8 *icv; + u8 *work_iph; + struct sk_buff *skb = base->data; + struct xfrm_state *x = xfrm_input_state(skb); + struct ah_data *ahp = x->data; + struct ip_auth_hdr *ah = ip_auth_hdr(skb); + int hdr_len = skb_network_header_len(skb); + int ah_hlen = ipv6_authlen(ah); + + if (err) + goto out; + + work_iph = AH_SKB_CB(skb)->tmp; + auth_data = ah_tmp_auth(work_iph, hdr_len); + icv = ah_tmp_icv(ahp->ahash, auth_data, ahp->icv_trunc_len); + + err = crypto_memneq(icv, auth_data, ahp->icv_trunc_len) ? -EBADMSG : 0; + if (err) + goto out; + + err = ah->nexthdr; + + skb->network_header += ah_hlen; + memcpy(skb_network_header(skb), work_iph, hdr_len); + __skb_pull(skb, ah_hlen + hdr_len); + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -hdr_len); +out: + kfree(AH_SKB_CB(skb)->tmp); + xfrm_input_resume(skb, err); +} + + + +static int ah6_input(struct xfrm_state *x, struct sk_buff *skb) +{ + /* + * Before process AH + * [IPv6][Ext1][Ext2][AH][Dest][Payload] + * |<-------------->| hdr_len + * + * To erase AH: + * Keeping copy of cleared headers. After AH processing, + * Moving the pointer of skb->network_header by using skb_pull as long + * as AH header length. Then copy back the copy as long as hdr_len + * If destination header following AH exists, copy it into after [Ext2]. + * + * |<>|[IPv6][Ext1][Ext2][Dest][Payload] + * There is offset of AH before IPv6 header after the process. 
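+ * In short: the original headers are saved, the in-place copy has its
+ * mutable fields cleared for ICV verification, and the saved headers are
+ * copied back right in front of the payload while the AH header itself
+ * is pulled.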
+ */ + + u8 *auth_data; + u8 *icv; + u8 *work_iph; + struct sk_buff *trailer; + struct crypto_ahash *ahash; + struct ahash_request *req; + struct scatterlist *sg; + struct ip_auth_hdr *ah; + struct ipv6hdr *ip6h; + struct ah_data *ahp; + u16 hdr_len; + u16 ah_hlen; + int nexthdr; + int nfrags; + int err = -ENOMEM; + int seqhi_len = 0; + __be32 *seqhi; + int sglists = 0; + struct scatterlist *seqhisg; + + if (!pskb_may_pull(skb, sizeof(struct ip_auth_hdr))) + goto out; + + /* We are going to _remove_ AH header to keep sockets happy, + * so... Later this can change. */ + if (skb_unclone(skb, GFP_ATOMIC)) + goto out; + + skb->ip_summed = CHECKSUM_NONE; + + hdr_len = skb_network_header_len(skb); + ah = (struct ip_auth_hdr *)skb->data; + ahp = x->data; + ahash = ahp->ahash; + + nexthdr = ah->nexthdr; + ah_hlen = ipv6_authlen(ah); + + if (ah_hlen != XFRM_ALIGN8(sizeof(*ah) + ahp->icv_full_len) && + ah_hlen != XFRM_ALIGN8(sizeof(*ah) + ahp->icv_trunc_len)) + goto out; + + if (!pskb_may_pull(skb, ah_hlen)) + goto out; + + err = skb_cow_data(skb, 0, &trailer); + if (err < 0) + goto out; + nfrags = err; + + ah = (struct ip_auth_hdr *)skb->data; + ip6h = ipv6_hdr(skb); + + skb_push(skb, hdr_len); + + if (x->props.flags & XFRM_STATE_ESN) { + sglists = 1; + seqhi_len = sizeof(*seqhi); + } + + work_iph = ah_alloc_tmp(ahash, nfrags + sglists, hdr_len + + ahp->icv_trunc_len + seqhi_len); + if (!work_iph) { + err = -ENOMEM; + goto out; + } + + auth_data = ah_tmp_auth((u8 *)work_iph, hdr_len); + seqhi = (__be32 *)(auth_data + ahp->icv_trunc_len); + icv = ah_tmp_icv(ahash, seqhi, seqhi_len); + req = ah_tmp_req(ahash, icv); + sg = ah_req_sg(ahash, req); + seqhisg = sg + nfrags; + + memcpy(work_iph, ip6h, hdr_len); + memcpy(auth_data, ah->auth_data, ahp->icv_trunc_len); + memset(ah->auth_data, 0, ahp->icv_trunc_len); + + err = ipv6_clear_mutable_options(ip6h, hdr_len, XFRM_POLICY_IN); + if (err) + goto out_free; + + ip6h->priority = 0; + ip6h->flow_lbl[0] = 0; + ip6h->flow_lbl[1] = 0; + ip6h->flow_lbl[2] = 0; + ip6h->hop_limit = 0; + + sg_init_table(sg, nfrags + sglists); + err = skb_to_sgvec_nomark(skb, sg, 0, skb->len); + if (unlikely(err < 0)) + goto out_free; + + if (x->props.flags & XFRM_STATE_ESN) { + /* Attach seqhi sg right after packet payload */ + *seqhi = XFRM_SKB_CB(skb)->seq.input.hi; + sg_set_buf(seqhisg, seqhi, seqhi_len); + } + + ahash_request_set_crypt(req, sg, icv, skb->len + seqhi_len); + ahash_request_set_callback(req, 0, ah6_input_done, skb); + + AH_SKB_CB(skb)->tmp = work_iph; + + err = crypto_ahash_digest(req); + if (err) { + if (err == -EINPROGRESS) + goto out; + + goto out_free; + } + + err = crypto_memneq(icv, auth_data, ahp->icv_trunc_len) ? 
-EBADMSG : 0; + if (err) + goto out_free; + + skb->network_header += ah_hlen; + memcpy(skb_network_header(skb), work_iph, hdr_len); + __skb_pull(skb, ah_hlen + hdr_len); + + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -hdr_len); + + err = nexthdr; + +out_free: + kfree(work_iph); +out: + return err; +} + +static int ah6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + struct ipv6hdr *iph = (struct ipv6hdr *)skb->data; + struct ip_auth_hdr *ah = (struct ip_auth_hdr *)(skb->data+offset); + struct xfrm_state *x; + + if (type != ICMPV6_PKT_TOOBIG && + type != NDISC_REDIRECT) + return 0; + + x = xfrm_state_lookup(net, skb->mark, (xfrm_address_t *)&iph->daddr, ah->spi, IPPROTO_AH, AF_INET6); + if (!x) + return 0; + + if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + else + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + xfrm_state_put(x); + + return 0; +} + +static int ah6_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + struct ah_data *ahp = NULL; + struct xfrm_algo_desc *aalg_desc; + struct crypto_ahash *ahash; + + if (!x->aalg) { + NL_SET_ERR_MSG(extack, "AH requires a state with an AUTH algorithm"); + goto error; + } + + if (x->encap) { + NL_SET_ERR_MSG(extack, "AH is not compatible with encapsulation"); + goto error; + } + + ahp = kzalloc(sizeof(*ahp), GFP_KERNEL); + if (!ahp) + return -ENOMEM; + + ahash = crypto_alloc_ahash(x->aalg->alg_name, 0, 0); + if (IS_ERR(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + ahp->ahash = ahash; + if (crypto_ahash_setkey(ahash, x->aalg->alg_key, + (x->aalg->alg_key_len + 7) / 8)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + /* + * Lookup the algorithm description maintained by xfrm_algo, + * verify crypto transform properties, and store information + * we need for AH processing. This lookup cannot fail here + * after a successful crypto_alloc_hash(). 
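+ * The full ICV length is taken from that description, the truncated
+ * length from the value supplied with the state.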
+ */ + aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0); + BUG_ON(!aalg_desc); + + if (aalg_desc->uinfo.auth.icv_fullbits/8 != + crypto_ahash_digestsize(ahash)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + ahp->icv_full_len = aalg_desc->uinfo.auth.icv_fullbits/8; + ahp->icv_trunc_len = x->aalg->alg_trunc_len/8; + + x->props.header_len = XFRM_ALIGN8(sizeof(struct ip_auth_hdr) + + ahp->icv_trunc_len); + switch (x->props.mode) { + case XFRM_MODE_BEET: + case XFRM_MODE_TRANSPORT: + break; + case XFRM_MODE_TUNNEL: + x->props.header_len += sizeof(struct ipv6hdr); + break; + default: + NL_SET_ERR_MSG(extack, "Invalid mode requested for AH, must be one of TRANSPORT, TUNNEL, BEET"); + goto error; + } + x->data = ahp; + + return 0; + +error: + if (ahp) { + crypto_free_ahash(ahp->ahash); + kfree(ahp); + } + return -EINVAL; +} + +static void ah6_destroy(struct xfrm_state *x) +{ + struct ah_data *ahp = x->data; + + if (!ahp) + return; + + crypto_free_ahash(ahp->ahash); + kfree(ahp); +} + +static int ah6_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type ah6_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_AH, + .flags = XFRM_TYPE_REPLAY_PROT, + .init_state = ah6_init_state, + .destructor = ah6_destroy, + .input = ah6_input, + .output = ah6_output, +}; + +static struct xfrm6_protocol ah6_protocol = { + .handler = xfrm6_rcv, + .input_handler = xfrm_input, + .cb_handler = ah6_rcv_cb, + .err_handler = ah6_err, + .priority = 0, +}; + +static int __init ah6_init(void) +{ + if (xfrm_register_type(&ah6_type, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + + if (xfrm6_protocol_register(&ah6_protocol, IPPROTO_AH) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&ah6_type, AF_INET6); + return -EAGAIN; + } + + return 0; +} + +static void __exit ah6_fini(void) +{ + if (xfrm6_protocol_deregister(&ah6_protocol, IPPROTO_AH) < 0) + pr_info("%s: can't remove protocol\n", __func__); + + xfrm_unregister_type(&ah6_type, AF_INET6); +} + +module_init(ah6_init); +module_exit(ah6_fini); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_AH); diff --git a/net/ipv6/anycast.c b/net/ipv6/anycast.c new file mode 100644 index 000000000..dacdea7fc --- /dev/null +++ b/net/ipv6/anycast.c @@ -0,0 +1,619 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Anycast support for IPv6 + * Linux INET6 implementation + * + * Authors: + * David L Stevens (dlstevens@us.ibm.com) + * + * based heavily on net/ipv6/mcast.c + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#define IN6_ADDR_HSIZE_SHIFT 8 +#define IN6_ADDR_HSIZE BIT(IN6_ADDR_HSIZE_SHIFT) +/* anycast address hash table + */ +static struct hlist_head inet6_acaddr_lst[IN6_ADDR_HSIZE]; +static DEFINE_SPINLOCK(acaddr_hash_lock); + +static int ipv6_dev_ac_dec(struct net_device *dev, const struct in6_addr *addr); + +static u32 inet6_acaddr_hash(struct net *net, const struct in6_addr *addr) +{ + u32 val = ipv6_addr_hash(addr) ^ net_hash_mix(net); + + return hash_32(val, IN6_ADDR_HSIZE_SHIFT); +} + +/* + * socket join an anycast group + */ + +int ipv6_sock_ac_join(struct sock *sk, int ifindex, const struct in6_addr *addr) +{ + struct ipv6_pinfo *np = 
inet6_sk(sk); + struct net_device *dev = NULL; + struct inet6_dev *idev; + struct ipv6_ac_socklist *pac; + struct net *net = sock_net(sk); + int ishost = !net->ipv6.devconf_all->forwarding; + int err = 0; + + ASSERT_RTNL(); + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (ipv6_addr_is_multicast(addr)) + return -EINVAL; + + if (ifindex) + dev = __dev_get_by_index(net, ifindex); + + if (ipv6_chk_addr_and_flags(net, addr, dev, true, 0, IFA_F_TENTATIVE)) + return -EINVAL; + + pac = sock_kmalloc(sk, sizeof(struct ipv6_ac_socklist), GFP_KERNEL); + if (!pac) + return -ENOMEM; + pac->acl_next = NULL; + pac->acl_addr = *addr; + + if (ifindex == 0) { + struct rt6_info *rt; + + rt = rt6_lookup(net, addr, NULL, 0, NULL, 0); + if (rt) { + dev = rt->dst.dev; + ip6_rt_put(rt); + } else if (ishost) { + err = -EADDRNOTAVAIL; + goto error; + } else { + /* router, no matching interface: just pick one */ + dev = __dev_get_by_flags(net, IFF_UP, + IFF_UP | IFF_LOOPBACK); + } + } + + if (!dev) { + err = -ENODEV; + goto error; + } + + idev = __in6_dev_get(dev); + if (!idev) { + if (ifindex) + err = -ENODEV; + else + err = -EADDRNOTAVAIL; + goto error; + } + /* reset ishost, now that we have a specific device */ + ishost = !idev->cnf.forwarding; + + pac->acl_ifindex = dev->ifindex; + + /* XXX + * For hosts, allow link-local or matching prefix anycasts. + * This obviates the need for propagating anycast routes while + * still allowing some non-router anycast participation. + */ + if (!ipv6_chk_prefix(addr, dev)) { + if (ishost) + err = -EADDRNOTAVAIL; + if (err) + goto error; + } + + err = __ipv6_dev_ac_inc(idev, addr); + if (!err) { + pac->acl_next = np->ipv6_ac_list; + np->ipv6_ac_list = pac; + pac = NULL; + } + +error: + if (pac) + sock_kfree_s(sk, pac, sizeof(*pac)); + return err; +} + +/* + * socket leave an anycast group + */ +int ipv6_sock_ac_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct net_device *dev; + struct ipv6_ac_socklist *pac, *prev_pac; + struct net *net = sock_net(sk); + + ASSERT_RTNL(); + + prev_pac = NULL; + for (pac = np->ipv6_ac_list; pac; pac = pac->acl_next) { + if ((ifindex == 0 || pac->acl_ifindex == ifindex) && + ipv6_addr_equal(&pac->acl_addr, addr)) + break; + prev_pac = pac; + } + if (!pac) + return -ENOENT; + if (prev_pac) + prev_pac->acl_next = pac->acl_next; + else + np->ipv6_ac_list = pac->acl_next; + + dev = __dev_get_by_index(net, pac->acl_ifindex); + if (dev) + ipv6_dev_ac_dec(dev, &pac->acl_addr); + + sock_kfree_s(sk, pac, sizeof(*pac)); + return 0; +} + +void __ipv6_sock_ac_close(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct net_device *dev = NULL; + struct ipv6_ac_socklist *pac; + struct net *net = sock_net(sk); + int prev_index; + + ASSERT_RTNL(); + pac = np->ipv6_ac_list; + np->ipv6_ac_list = NULL; + + prev_index = 0; + while (pac) { + struct ipv6_ac_socklist *next = pac->acl_next; + + if (pac->acl_ifindex != prev_index) { + dev = __dev_get_by_index(net, pac->acl_ifindex); + prev_index = pac->acl_ifindex; + } + if (dev) + ipv6_dev_ac_dec(dev, &pac->acl_addr); + sock_kfree_s(sk, pac, sizeof(*pac)); + pac = next; + } +} + +void ipv6_sock_ac_close(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + + if (!np->ipv6_ac_list) + return; + rtnl_lock(); + __ipv6_sock_ac_close(sk); + rtnl_unlock(); +} + +static void ipv6_add_acaddr_hash(struct net *net, struct ifacaddr6 *aca) +{ + unsigned int hash = inet6_acaddr_hash(net, &aca->aca_addr); + + 
spin_lock(&acaddr_hash_lock); + hlist_add_head_rcu(&aca->aca_addr_lst, &inet6_acaddr_lst[hash]); + spin_unlock(&acaddr_hash_lock); +} + +static void ipv6_del_acaddr_hash(struct ifacaddr6 *aca) +{ + spin_lock(&acaddr_hash_lock); + hlist_del_init_rcu(&aca->aca_addr_lst); + spin_unlock(&acaddr_hash_lock); +} + +static void aca_get(struct ifacaddr6 *aca) +{ + refcount_inc(&aca->aca_refcnt); +} + +static void aca_free_rcu(struct rcu_head *h) +{ + struct ifacaddr6 *aca = container_of(h, struct ifacaddr6, rcu); + + fib6_info_release(aca->aca_rt); + kfree(aca); +} + +static void aca_put(struct ifacaddr6 *ac) +{ + if (refcount_dec_and_test(&ac->aca_refcnt)) { + call_rcu(&ac->rcu, aca_free_rcu); + } +} + +static struct ifacaddr6 *aca_alloc(struct fib6_info *f6i, + const struct in6_addr *addr) +{ + struct ifacaddr6 *aca; + + aca = kzalloc(sizeof(*aca), GFP_ATOMIC); + if (!aca) + return NULL; + + aca->aca_addr = *addr; + fib6_info_hold(f6i); + aca->aca_rt = f6i; + INIT_HLIST_NODE(&aca->aca_addr_lst); + aca->aca_users = 1; + /* aca_tstamp should be updated upon changes */ + aca->aca_cstamp = aca->aca_tstamp = jiffies; + refcount_set(&aca->aca_refcnt, 1); + + return aca; +} + +/* + * device anycast group inc (add if not found) + */ +int __ipv6_dev_ac_inc(struct inet6_dev *idev, const struct in6_addr *addr) +{ + struct ifacaddr6 *aca; + struct fib6_info *f6i; + struct net *net; + int err; + + ASSERT_RTNL(); + + write_lock_bh(&idev->lock); + if (idev->dead) { + err = -ENODEV; + goto out; + } + + for (aca = idev->ac_list; aca; aca = aca->aca_next) { + if (ipv6_addr_equal(&aca->aca_addr, addr)) { + aca->aca_users++; + err = 0; + goto out; + } + } + + net = dev_net(idev->dev); + f6i = addrconf_f6i_alloc(net, idev, addr, true, GFP_ATOMIC); + if (IS_ERR(f6i)) { + err = PTR_ERR(f6i); + goto out; + } + aca = aca_alloc(f6i, addr); + if (!aca) { + fib6_info_release(f6i); + err = -ENOMEM; + goto out; + } + + aca->aca_next = idev->ac_list; + idev->ac_list = aca; + + /* Hold this for addrconf_join_solict() below before we unlock, + * it is already exposed via idev->ac_list. 
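+ * The extra reference is dropped with aca_put() once the anycast route
+ * has been inserted and the solicited-node multicast group joined.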
+ */ + aca_get(aca); + write_unlock_bh(&idev->lock); + + ipv6_add_acaddr_hash(net, aca); + + ip6_ins_rt(net, f6i); + + addrconf_join_solict(idev->dev, &aca->aca_addr); + + aca_put(aca); + return 0; +out: + write_unlock_bh(&idev->lock); + return err; +} + +/* + * device anycast group decrement + */ +int __ipv6_dev_ac_dec(struct inet6_dev *idev, const struct in6_addr *addr) +{ + struct ifacaddr6 *aca, *prev_aca; + + ASSERT_RTNL(); + + write_lock_bh(&idev->lock); + prev_aca = NULL; + for (aca = idev->ac_list; aca; aca = aca->aca_next) { + if (ipv6_addr_equal(&aca->aca_addr, addr)) + break; + prev_aca = aca; + } + if (!aca) { + write_unlock_bh(&idev->lock); + return -ENOENT; + } + if (--aca->aca_users > 0) { + write_unlock_bh(&idev->lock); + return 0; + } + if (prev_aca) + prev_aca->aca_next = aca->aca_next; + else + idev->ac_list = aca->aca_next; + write_unlock_bh(&idev->lock); + ipv6_del_acaddr_hash(aca); + addrconf_leave_solict(idev, &aca->aca_addr); + + ip6_del_rt(dev_net(idev->dev), aca->aca_rt, false); + + aca_put(aca); + return 0; +} + +/* called with rtnl_lock() */ +static int ipv6_dev_ac_dec(struct net_device *dev, const struct in6_addr *addr) +{ + struct inet6_dev *idev = __in6_dev_get(dev); + + if (!idev) + return -ENODEV; + return __ipv6_dev_ac_dec(idev, addr); +} + +void ipv6_ac_destroy_dev(struct inet6_dev *idev) +{ + struct ifacaddr6 *aca; + + write_lock_bh(&idev->lock); + while ((aca = idev->ac_list) != NULL) { + idev->ac_list = aca->aca_next; + write_unlock_bh(&idev->lock); + + ipv6_del_acaddr_hash(aca); + + addrconf_leave_solict(idev, &aca->aca_addr); + + ip6_del_rt(dev_net(idev->dev), aca->aca_rt, false); + + aca_put(aca); + + write_lock_bh(&idev->lock); + } + write_unlock_bh(&idev->lock); +} + +/* + * check if the interface has this anycast address + * called with rcu_read_lock() + */ +static bool ipv6_chk_acast_dev(struct net_device *dev, const struct in6_addr *addr) +{ + struct inet6_dev *idev; + struct ifacaddr6 *aca; + + idev = __in6_dev_get(dev); + if (idev) { + read_lock_bh(&idev->lock); + for (aca = idev->ac_list; aca; aca = aca->aca_next) + if (ipv6_addr_equal(&aca->aca_addr, addr)) + break; + read_unlock_bh(&idev->lock); + return aca != NULL; + } + return false; +} + +/* + * check if given interface (or any, if dev==0) has this anycast address + */ +bool ipv6_chk_acast_addr(struct net *net, struct net_device *dev, + const struct in6_addr *addr) +{ + struct net_device *nh_dev; + struct ifacaddr6 *aca; + bool found = false; + + rcu_read_lock(); + if (dev) + found = ipv6_chk_acast_dev(dev, addr); + else { + unsigned int hash = inet6_acaddr_hash(net, addr); + + hlist_for_each_entry_rcu(aca, &inet6_acaddr_lst[hash], + aca_addr_lst) { + nh_dev = fib6_info_nh_dev(aca->aca_rt); + if (!nh_dev || !net_eq(dev_net(nh_dev), net)) + continue; + if (ipv6_addr_equal(&aca->aca_addr, addr)) { + found = true; + break; + } + } + } + rcu_read_unlock(); + return found; +} + +/* check if this anycast address is link-local on given interface or + * is global + */ +bool ipv6_chk_acast_addr_src(struct net *net, struct net_device *dev, + const struct in6_addr *addr) +{ + return ipv6_chk_acast_addr(net, + (ipv6_addr_type(addr) & IPV6_ADDR_LINKLOCAL ? 
+ dev : NULL), + addr); +} + +#ifdef CONFIG_PROC_FS +struct ac6_iter_state { + struct seq_net_private p; + struct net_device *dev; + struct inet6_dev *idev; +}; + +#define ac6_seq_private(seq) ((struct ac6_iter_state *)(seq)->private) + +static inline struct ifacaddr6 *ac6_get_first(struct seq_file *seq) +{ + struct ifacaddr6 *im = NULL; + struct ac6_iter_state *state = ac6_seq_private(seq); + struct net *net = seq_file_net(seq); + + state->idev = NULL; + for_each_netdev_rcu(net, state->dev) { + struct inet6_dev *idev; + idev = __in6_dev_get(state->dev); + if (!idev) + continue; + read_lock_bh(&idev->lock); + im = idev->ac_list; + if (im) { + state->idev = idev; + break; + } + read_unlock_bh(&idev->lock); + } + return im; +} + +static struct ifacaddr6 *ac6_get_next(struct seq_file *seq, struct ifacaddr6 *im) +{ + struct ac6_iter_state *state = ac6_seq_private(seq); + + im = im->aca_next; + while (!im) { + if (likely(state->idev != NULL)) + read_unlock_bh(&state->idev->lock); + + state->dev = next_net_device_rcu(state->dev); + if (!state->dev) { + state->idev = NULL; + break; + } + state->idev = __in6_dev_get(state->dev); + if (!state->idev) + continue; + read_lock_bh(&state->idev->lock); + im = state->idev->ac_list; + } + return im; +} + +static struct ifacaddr6 *ac6_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ifacaddr6 *im = ac6_get_first(seq); + if (im) + while (pos && (im = ac6_get_next(seq, im)) != NULL) + --pos; + return pos ? NULL : im; +} + +static void *ac6_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return ac6_get_idx(seq, *pos); +} + +static void *ac6_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ifacaddr6 *im = ac6_get_next(seq, v); + + ++*pos; + return im; +} + +static void ac6_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + struct ac6_iter_state *state = ac6_seq_private(seq); + + if (likely(state->idev != NULL)) { + read_unlock_bh(&state->idev->lock); + state->idev = NULL; + } + rcu_read_unlock(); +} + +static int ac6_seq_show(struct seq_file *seq, void *v) +{ + struct ifacaddr6 *im = (struct ifacaddr6 *)v; + struct ac6_iter_state *state = ac6_seq_private(seq); + + seq_printf(seq, "%-4d %-15s %pi6 %5d\n", + state->dev->ifindex, state->dev->name, + &im->aca_addr, im->aca_users); + return 0; +} + +static const struct seq_operations ac6_seq_ops = { + .start = ac6_seq_start, + .next = ac6_seq_next, + .stop = ac6_seq_stop, + .show = ac6_seq_show, +}; + +int __net_init ac6_proc_init(struct net *net) +{ + if (!proc_create_net("anycast6", 0444, net->proc_net, &ac6_seq_ops, + sizeof(struct ac6_iter_state))) + return -ENOMEM; + + return 0; +} + +void ac6_proc_exit(struct net *net) +{ + remove_proc_entry("anycast6", net->proc_net); +} +#endif + +/* Init / cleanup code + */ +int __init ipv6_anycast_init(void) +{ + int i; + + for (i = 0; i < IN6_ADDR_HSIZE; i++) + INIT_HLIST_HEAD(&inet6_acaddr_lst[i]); + return 0; +} + +void ipv6_anycast_cleanup(void) +{ + int i; + + spin_lock(&acaddr_hash_lock); + for (i = 0; i < IN6_ADDR_HSIZE; i++) + WARN_ON(!hlist_empty(&inet6_acaddr_lst[i])); + spin_unlock(&acaddr_hash_lock); +} diff --git a/net/ipv6/calipso.c b/net/ipv6/calipso.c new file mode 100644 index 000000000..1578ed9e9 --- /dev/null +++ b/net/ipv6/calipso.c @@ -0,0 +1,1459 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * CALIPSO - Common Architecture Label IPv6 Security Option + * + * This is an implementation of the CALIPSO protocol as specified in + * RFC 5570. 
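+ * CALIPSO labels IPv6 packets with a DOI and sensitivity information
+ * carried in a hop-by-hop option.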
+ * + * Authors: Paul Moore + * Huw Davies + */ + +/* (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + * (c) Copyright Huw Davies , 2015 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Maximium size of the calipso option including + * the two-byte TLV header. + */ +#define CALIPSO_OPT_LEN_MAX (2 + 252) + +/* Size of the minimum calipso option including + * the two-byte TLV header. + */ +#define CALIPSO_HDR_LEN (2 + 8) + +/* Maximium size of the calipso option including + * the two-byte TLV header and upto 3 bytes of + * leading pad and 7 bytes of trailing pad. + */ +#define CALIPSO_OPT_LEN_MAX_WITH_PAD (3 + CALIPSO_OPT_LEN_MAX + 7) + + /* Maximium size of u32 aligned buffer required to hold calipso + * option. Max of 3 initial pad bytes starting from buffer + 3. + * i.e. the worst case is when the previous tlv finishes on 4n + 3. + */ +#define CALIPSO_MAX_BUFFER (6 + CALIPSO_OPT_LEN_MAX) + +/* List of available DOI definitions */ +static DEFINE_SPINLOCK(calipso_doi_list_lock); +static LIST_HEAD(calipso_doi_list); + +/* Label mapping cache */ +int calipso_cache_enabled = 1; +int calipso_cache_bucketsize = 10; +#define CALIPSO_CACHE_BUCKETBITS 7 +#define CALIPSO_CACHE_BUCKETS BIT(CALIPSO_CACHE_BUCKETBITS) +#define CALIPSO_CACHE_REORDERLIMIT 10 +struct calipso_map_cache_bkt { + spinlock_t lock; + u32 size; + struct list_head list; +}; + +struct calipso_map_cache_entry { + u32 hash; + unsigned char *key; + size_t key_len; + + struct netlbl_lsm_cache *lsm_data; + + u32 activity; + struct list_head list; +}; + +static struct calipso_map_cache_bkt *calipso_cache; + +static void calipso_cache_invalidate(void); +static void calipso_doi_putdef(struct calipso_doi *doi_def); + +/* Label Mapping Cache Functions + */ + +/** + * calipso_cache_entry_free - Frees a cache entry + * @entry: the entry to free + * + * Description: + * This function frees the memory associated with a cache entry including the + * LSM cache data if there are no longer any users, i.e. reference count == 0. + * + */ +static void calipso_cache_entry_free(struct calipso_map_cache_entry *entry) +{ + if (entry->lsm_data) + netlbl_secattr_cache_free(entry->lsm_data); + kfree(entry->key); + kfree(entry); +} + +/** + * calipso_map_cache_hash - Hashing function for the CALIPSO cache + * @key: the hash key + * @key_len: the length of the key in bytes + * + * Description: + * The CALIPSO tag hashing function. Returns a 32-bit hash value. + * + */ +static u32 calipso_map_cache_hash(const unsigned char *key, u32 key_len) +{ + return jhash(key, key_len, 0); +} + +/** + * calipso_cache_init - Initialize the CALIPSO cache + * + * Description: + * Initializes the CALIPSO label mapping cache, this function should be called + * before any of the other functions defined in this file. Returns zero on + * success, negative values on error. 
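+ * The cache is an array of CALIPSO_CACHE_BUCKETS spinlock-protected
+ * buckets, each holding at most calipso_cache_bucketsize entries.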
+ * + */ +static int __init calipso_cache_init(void) +{ + u32 iter; + + calipso_cache = kcalloc(CALIPSO_CACHE_BUCKETS, + sizeof(struct calipso_map_cache_bkt), + GFP_KERNEL); + if (!calipso_cache) + return -ENOMEM; + + for (iter = 0; iter < CALIPSO_CACHE_BUCKETS; iter++) { + spin_lock_init(&calipso_cache[iter].lock); + calipso_cache[iter].size = 0; + INIT_LIST_HEAD(&calipso_cache[iter].list); + } + + return 0; +} + +/** + * calipso_cache_invalidate - Invalidates the current CALIPSO cache + * + * Description: + * Invalidates and frees any entries in the CALIPSO cache. Returns zero on + * success and negative values on failure. + * + */ +static void calipso_cache_invalidate(void) +{ + struct calipso_map_cache_entry *entry, *tmp_entry; + u32 iter; + + for (iter = 0; iter < CALIPSO_CACHE_BUCKETS; iter++) { + spin_lock_bh(&calipso_cache[iter].lock); + list_for_each_entry_safe(entry, + tmp_entry, + &calipso_cache[iter].list, list) { + list_del(&entry->list); + calipso_cache_entry_free(entry); + } + calipso_cache[iter].size = 0; + spin_unlock_bh(&calipso_cache[iter].lock); + } +} + +/** + * calipso_cache_check - Check the CALIPSO cache for a label mapping + * @key: the buffer to check + * @key_len: buffer length in bytes + * @secattr: the security attribute struct to use + * + * Description: + * This function checks the cache to see if a label mapping already exists for + * the given key. If there is a match then the cache is adjusted and the + * @secattr struct is populated with the correct LSM security attributes. The + * cache is adjusted in the following manner if the entry is not already the + * first in the cache bucket: + * + * 1. The cache entry's activity counter is incremented + * 2. The previous (higher ranking) entry's activity counter is decremented + * 3. If the difference between the two activity counters is geater than + * CALIPSO_CACHE_REORDERLIMIT the two entries are swapped + * + * Returns zero on success, -ENOENT for a cache miss, and other negative values + * on error. 
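+ * For example, with CALIPSO_CACHE_REORDERLIMIT defined above as 10, an
+ * entry must build up an activity lead of more than 10 over its
+ * predecessor before the two are swapped.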
+ * + */ +static int calipso_cache_check(const unsigned char *key, + u32 key_len, + struct netlbl_lsm_secattr *secattr) +{ + u32 bkt; + struct calipso_map_cache_entry *entry; + struct calipso_map_cache_entry *prev_entry = NULL; + u32 hash; + + if (!calipso_cache_enabled) + return -ENOENT; + + hash = calipso_map_cache_hash(key, key_len); + bkt = hash & (CALIPSO_CACHE_BUCKETS - 1); + spin_lock_bh(&calipso_cache[bkt].lock); + list_for_each_entry(entry, &calipso_cache[bkt].list, list) { + if (entry->hash == hash && + entry->key_len == key_len && + memcmp(entry->key, key, key_len) == 0) { + entry->activity += 1; + refcount_inc(&entry->lsm_data->refcount); + secattr->cache = entry->lsm_data; + secattr->flags |= NETLBL_SECATTR_CACHE; + secattr->type = NETLBL_NLTYPE_CALIPSO; + if (!prev_entry) { + spin_unlock_bh(&calipso_cache[bkt].lock); + return 0; + } + + if (prev_entry->activity > 0) + prev_entry->activity -= 1; + if (entry->activity > prev_entry->activity && + entry->activity - prev_entry->activity > + CALIPSO_CACHE_REORDERLIMIT) { + __list_del(entry->list.prev, entry->list.next); + __list_add(&entry->list, + prev_entry->list.prev, + &prev_entry->list); + } + + spin_unlock_bh(&calipso_cache[bkt].lock); + return 0; + } + prev_entry = entry; + } + spin_unlock_bh(&calipso_cache[bkt].lock); + + return -ENOENT; +} + +/** + * calipso_cache_add - Add an entry to the CALIPSO cache + * @calipso_ptr: the CALIPSO option + * @secattr: the packet's security attributes + * + * Description: + * Add a new entry into the CALIPSO label mapping cache. Add the new entry to + * head of the cache bucket's list, if the cache bucket is out of room remove + * the last entry in the list first. It is important to note that there is + * currently no checking for duplicate keys. Returns zero on success, + * negative values on failure. The key stored starts at calipso_ptr + 2, + * i.e. the type and length bytes are not stored, this corresponds to + * calipso_ptr[1] bytes of data. 
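+ * (calipso_ptr[0] holds the option type and calipso_ptr[1] the option
+ * length, following the usual IPv6 TLV option layout.)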
+ * + */ +static int calipso_cache_add(const unsigned char *calipso_ptr, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -EPERM; + u32 bkt; + struct calipso_map_cache_entry *entry = NULL; + struct calipso_map_cache_entry *old_entry = NULL; + u32 calipso_ptr_len; + + if (!calipso_cache_enabled || calipso_cache_bucketsize <= 0) + return 0; + + calipso_ptr_len = calipso_ptr[1]; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + return -ENOMEM; + entry->key = kmemdup(calipso_ptr + 2, calipso_ptr_len, GFP_ATOMIC); + if (!entry->key) { + ret_val = -ENOMEM; + goto cache_add_failure; + } + entry->key_len = calipso_ptr_len; + entry->hash = calipso_map_cache_hash(calipso_ptr, calipso_ptr_len); + refcount_inc(&secattr->cache->refcount); + entry->lsm_data = secattr->cache; + + bkt = entry->hash & (CALIPSO_CACHE_BUCKETS - 1); + spin_lock_bh(&calipso_cache[bkt].lock); + if (calipso_cache[bkt].size < calipso_cache_bucketsize) { + list_add(&entry->list, &calipso_cache[bkt].list); + calipso_cache[bkt].size += 1; + } else { + old_entry = list_entry(calipso_cache[bkt].list.prev, + struct calipso_map_cache_entry, list); + list_del(&old_entry->list); + list_add(&entry->list, &calipso_cache[bkt].list); + calipso_cache_entry_free(old_entry); + } + spin_unlock_bh(&calipso_cache[bkt].lock); + + return 0; + +cache_add_failure: + if (entry) + calipso_cache_entry_free(entry); + return ret_val; +} + +/* DOI List Functions + */ + +/** + * calipso_doi_search - Searches for a DOI definition + * @doi: the DOI to search for + * + * Description: + * Search the DOI definition list for a DOI definition with a DOI value that + * matches @doi. The caller is responsible for calling rcu_read_[un]lock(). + * Returns a pointer to the DOI definition on success and NULL on failure. + */ +static struct calipso_doi *calipso_doi_search(u32 doi) +{ + struct calipso_doi *iter; + + list_for_each_entry_rcu(iter, &calipso_doi_list, list) + if (iter->doi == doi && refcount_read(&iter->refcount)) + return iter; + return NULL; +} + +/** + * calipso_doi_add - Add a new DOI to the CALIPSO protocol engine + * @doi_def: the DOI structure + * @audit_info: NetLabel audit information + * + * Description: + * The caller defines a new DOI for use by the CALIPSO engine and calls this + * function to add it to the list of acceptable domains. The caller must + * ensure that the mapping table specified in @doi_def->map meets all of the + * requirements of the mapping type (see calipso.h for details). Returns + * zero on success and non-zero on failure. + * + */ +static int calipso_doi_add(struct calipso_doi *doi_def, + struct netlbl_audit *audit_info) +{ + int ret_val = -EINVAL; + u32 doi; + u32 doi_type; + struct audit_buffer *audit_buf; + + doi = doi_def->doi; + doi_type = doi_def->type; + + if (doi_def->doi == CALIPSO_DOI_UNKNOWN) + goto doi_add_return; + + refcount_set(&doi_def->refcount, 1); + + spin_lock(&calipso_doi_list_lock); + if (calipso_doi_search(doi_def->doi)) { + spin_unlock(&calipso_doi_list_lock); + ret_val = -EEXIST; + goto doi_add_return; + } + list_add_tail_rcu(&doi_def->list, &calipso_doi_list); + spin_unlock(&calipso_doi_list_lock); + ret_val = 0; + +doi_add_return: + audit_buf = netlbl_audit_start(AUDIT_MAC_CALIPSO_ADD, audit_info); + if (audit_buf) { + const char *type_str; + + switch (doi_type) { + case CALIPSO_MAP_PASS: + type_str = "pass"; + break; + default: + type_str = "(unknown)"; + } + audit_log_format(audit_buf, + " calipso_doi=%u calipso_type=%s res=%u", + doi, type_str, ret_val == 0 ? 
1 : 0); + audit_log_end(audit_buf); + } + + return ret_val; +} + +/** + * calipso_doi_free - Frees a DOI definition + * @doi_def: the DOI definition + * + * Description: + * This function frees all of the memory associated with a DOI definition. + * + */ +static void calipso_doi_free(struct calipso_doi *doi_def) +{ + kfree(doi_def); +} + +/** + * calipso_doi_free_rcu - Frees a DOI definition via the RCU pointer + * @entry: the entry's RCU field + * + * Description: + * This function is designed to be used as a callback to the call_rcu() + * function so that the memory allocated to the DOI definition can be released + * safely. + * + */ +static void calipso_doi_free_rcu(struct rcu_head *entry) +{ + struct calipso_doi *doi_def; + + doi_def = container_of(entry, struct calipso_doi, rcu); + calipso_doi_free(doi_def); +} + +/** + * calipso_doi_remove - Remove an existing DOI from the CALIPSO protocol engine + * @doi: the DOI value + * @audit_info: NetLabel audit information + * + * Description: + * Removes a DOI definition from the CALIPSO engine. The NetLabel routines will + * be called to release their own LSM domain mappings as well as our own + * domain list. Returns zero on success and negative values on failure. + * + */ +static int calipso_doi_remove(u32 doi, struct netlbl_audit *audit_info) +{ + int ret_val; + struct calipso_doi *doi_def; + struct audit_buffer *audit_buf; + + spin_lock(&calipso_doi_list_lock); + doi_def = calipso_doi_search(doi); + if (!doi_def) { + spin_unlock(&calipso_doi_list_lock); + ret_val = -ENOENT; + goto doi_remove_return; + } + list_del_rcu(&doi_def->list); + spin_unlock(&calipso_doi_list_lock); + + calipso_doi_putdef(doi_def); + ret_val = 0; + +doi_remove_return: + audit_buf = netlbl_audit_start(AUDIT_MAC_CALIPSO_DEL, audit_info); + if (audit_buf) { + audit_log_format(audit_buf, + " calipso_doi=%u res=%u", + doi, ret_val == 0 ? 1 : 0); + audit_log_end(audit_buf); + } + + return ret_val; +} + +/** + * calipso_doi_getdef - Returns a reference to a valid DOI definition + * @doi: the DOI value + * + * Description: + * Searches for a valid DOI definition and if one is found it is returned to + * the caller. Otherwise NULL is returned. The caller must ensure that + * calipso_doi_putdef() is called when the caller is done. + * + */ +static struct calipso_doi *calipso_doi_getdef(u32 doi) +{ + struct calipso_doi *doi_def; + + rcu_read_lock(); + doi_def = calipso_doi_search(doi); + if (!doi_def) + goto doi_getdef_return; + if (!refcount_inc_not_zero(&doi_def->refcount)) + doi_def = NULL; + +doi_getdef_return: + rcu_read_unlock(); + return doi_def; +} + +/** + * calipso_doi_putdef - Releases a reference for the given DOI definition + * @doi_def: the DOI definition + * + * Description: + * Releases a DOI definition reference obtained from calipso_doi_getdef(). + * + */ +static void calipso_doi_putdef(struct calipso_doi *doi_def) +{ + if (!doi_def) + return; + + if (!refcount_dec_and_test(&doi_def->refcount)) + return; + + calipso_cache_invalidate(); + call_rcu(&doi_def->rcu, calipso_doi_free_rcu); +} + +/** + * calipso_doi_walk - Iterate through the DOI definitions + * @skip_cnt: skip past this number of DOI definitions, updated + * @callback: callback for each DOI definition + * @cb_arg: argument for the callback function + * + * Description: + * Iterate over the DOI definition list, skipping the first @skip_cnt entries. + * For each entry call @callback, if @callback returns a negative value stop + * 'walking' through the list and return. 
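calipso_doi_getdef() and calipso_doi_putdef() follow the usual take-a-reference-only-while-live / free-on-last-put pattern built on refcount_inc_not_zero() and refcount_dec_and_test(). A standalone userspace sketch of that pattern with C11 atomics (the demo_* names are hypothetical; the kernel additionally defers the free through call_rcu() and invalidates the label cache, which this sketch does not reproduce):

#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

struct demo_doi {
    unsigned int doi;
    atomic_uint refcount;
};

/* Counterpart of refcount_inc_not_zero(): only take a reference while
 * the object is still live (refcount > 0), so a lookup cannot revive an
 * object whose last reference is being dropped concurrently. */
static bool demo_get(struct demo_doi *d)
{
    unsigned int old = atomic_load(&d->refcount);

    while (old != 0) {
        if (atomic_compare_exchange_weak(&d->refcount, &old, old + 1))
            return true;
    }
    return false;
}

/* Counterpart of refcount_dec_and_test() plus the free: whoever drops
 * the last reference releases the object. */
static void demo_put(struct demo_doi *d)
{
    if (atomic_fetch_sub(&d->refcount, 1) == 1) {
        printf("freeing DOI %u\n", d->doi);
        free(d);
    }
}

int main(void)
{
    struct demo_doi *d = malloc(sizeof(*d));

    d->doi = 3;
    atomic_init(&d->refcount, 1);   /* reference held by the DOI list */

    if (demo_get(d))                /* a lookup takes its own reference */
        demo_put(d);                /* ...and drops it when done */

    demo_put(d);                    /* list removal drops the initial one */
    return 0;
}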
Updates the value in @skip_cnt upon + * return. Returns zero on success, negative values on failure. + * + */ +static int calipso_doi_walk(u32 *skip_cnt, + int (*callback)(struct calipso_doi *doi_def, + void *arg), + void *cb_arg) +{ + int ret_val = -ENOENT; + u32 doi_cnt = 0; + struct calipso_doi *iter_doi; + + rcu_read_lock(); + list_for_each_entry_rcu(iter_doi, &calipso_doi_list, list) + if (refcount_read(&iter_doi->refcount) > 0) { + if (doi_cnt++ < *skip_cnt) + continue; + ret_val = callback(iter_doi, cb_arg); + if (ret_val < 0) { + doi_cnt--; + goto doi_walk_return; + } + } + +doi_walk_return: + rcu_read_unlock(); + *skip_cnt = doi_cnt; + return ret_val; +} + +/** + * calipso_validate - Validate a CALIPSO option + * @skb: the packet + * @option: the start of the option + * + * Description: + * This routine is called to validate a CALIPSO option. + * If the option is valid then %true is returned, otherwise + * %false is returned. + * + * The caller should have already checked that the length of the + * option (including the TLV header) is >= 10 and that the catmap + * length is consistent with the option length. + * + * We leave checks on the level and categories to the socket layer. + */ +bool calipso_validate(const struct sk_buff *skb, const unsigned char *option) +{ + struct calipso_doi *doi_def; + bool ret_val; + u16 crc, len = option[1] + 2; + static const u8 zero[2]; + + /* The original CRC runs over the option including the TLV header + * with the CRC-16 field (at offset 8) zeroed out. */ + crc = crc_ccitt(0xffff, option, 8); + crc = crc_ccitt(crc, zero, sizeof(zero)); + if (len > 10) + crc = crc_ccitt(crc, option + 10, len - 10); + crc = ~crc; + if (option[8] != (crc & 0xff) || option[9] != ((crc >> 8) & 0xff)) + return false; + + rcu_read_lock(); + doi_def = calipso_doi_search(get_unaligned_be32(option + 2)); + ret_val = !!doi_def; + rcu_read_unlock(); + + return ret_val; +} + +/** + * calipso_map_cat_hton - Perform a category mapping from host to network + * @doi_def: the DOI definition + * @secattr: the security attributes + * @net_cat: the zero'd out category bitmap in network/CALIPSO format + * @net_cat_len: the length of the CALIPSO bitmap in bytes + * + * Description: + * Perform a label mapping to translate a local MLS category bitmap to the + * correct CALIPSO bitmap using the given DOI definition. Returns the minimum + * size in bytes of the network bitmap on success, negative values otherwise. + * + */ +static int calipso_map_cat_hton(const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr, + unsigned char *net_cat, + u32 net_cat_len) +{ + int spot = -1; + u32 net_spot_max = 0; + u32 net_clen_bits = net_cat_len * 8; + + for (;;) { + spot = netlbl_catmap_walk(secattr->attr.mls.cat, + spot + 1); + if (spot < 0) + break; + if (spot >= net_clen_bits) + return -ENOSPC; + netlbl_bitmap_setbit(net_cat, spot, 1); + + if (spot > net_spot_max) + net_spot_max = spot; + } + + return (net_spot_max / 32 + 1) * 4; +} + +/** + * calipso_map_cat_ntoh - Perform a category mapping from network to host + * @doi_def: the DOI definition + * @net_cat: the category bitmap in network/CALIPSO format + * @net_cat_len: the length of the CALIPSO bitmap in bytes + * @secattr: the security attributes + * + * Description: + * Perform a label mapping to translate a CALIPSO bitmap to the correct local + * MLS category bitmap using the given DOI definition. Returns zero on + * success, negative values on failure. 
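The checksum handling shared by calipso_validate() above and calipso_genopt() further down is self-contained enough to reproduce in isolation. A standalone userspace sketch, assuming a bit-reflected CRC-16 with polynomial 0x8408 (the algorithm the kernel's crc_ccitt() implements with a lookup table) and a minimal 10-byte CALIPSO TLV with an empty compartment bitmap:

#include <stdint.h>
#include <stdio.h>

/* Bit-by-bit equivalent of the kernel's table-driven crc_ccitt()
 * (reflected polynomial 0x8408, initial value supplied by the caller). */
static uint16_t crc_ccitt_buf(uint16_t crc, const uint8_t *buf, size_t len)
{
    for (size_t i = 0; i < len; i++) {
        crc ^= buf[i];
        for (int bit = 0; bit < 8; bit++)
            crc = (crc & 1) ? (crc >> 1) ^ 0x8408 : crc >> 1;
    }
    return crc;
}

/* Stamp the checksum the way calipso_genopt() does: CRC over the whole
 * TLV with the checksum field zeroed, complemented, low byte first. */
static void calipso_demo_stamp(uint8_t *opt, size_t len)
{
    uint16_t crc;

    opt[8] = 0;
    opt[9] = 0;
    crc = ~crc_ccitt_buf(0xffff, opt, len);
    opt[8] = crc & 0xff;
    opt[9] = (crc >> 8) & 0xff;
}

/* Re-check it the way calipso_validate() does. */
static int calipso_demo_valid(const uint8_t *opt, size_t len)
{
    static const uint8_t zero[2];
    uint16_t crc;

    crc = crc_ccitt_buf(0xffff, opt, 8);
    crc = crc_ccitt_buf(crc, zero, sizeof(zero));
    if (len > 10)
        crc = crc_ccitt_buf(crc, opt + 10, len - 10);
    crc = ~crc;
    return opt[8] == (crc & 0xff) && opt[9] == ((crc >> 8) & 0xff);
}

int main(void)
{
    /* Minimal CALIPSO TLV: type 7, length 8, DOI 1, no compartments,
     * sensitivity level 3, checksum filled in below. */
    uint8_t opt[10] = { 7, 8, 0, 0, 0, 1, 0, 3, 0, 0 };

    calipso_demo_stamp(opt, sizeof(opt));
    printf("checksum bytes: %02x %02x, valid: %d\n",
           opt[8], opt[9], calipso_demo_valid(opt, sizeof(opt)));
    return 0;
}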
+ * + */ +static int calipso_map_cat_ntoh(const struct calipso_doi *doi_def, + const unsigned char *net_cat, + u32 net_cat_len, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + int spot = -1; + u32 net_clen_bits = net_cat_len * 8; + + for (;;) { + spot = netlbl_bitmap_walk(net_cat, + net_clen_bits, + spot + 1, + 1); + if (spot < 0) { + if (spot == -2) + return -EFAULT; + return 0; + } + + ret_val = netlbl_catmap_setbit(&secattr->attr.mls.cat, + spot, + GFP_ATOMIC); + if (ret_val != 0) + return ret_val; + } + + return -EINVAL; +} + +/** + * calipso_pad_write - Writes pad bytes in TLV format + * @buf: the buffer + * @offset: offset from start of buffer to write padding + * @count: number of pad bytes to write + * + * Description: + * Write @count bytes of TLV padding into @buffer starting at offset @offset. + * @count should be less than 8 - see RFC 4942. + * + */ +static int calipso_pad_write(unsigned char *buf, unsigned int offset, + unsigned int count) +{ + if (WARN_ON_ONCE(count >= 8)) + return -EINVAL; + + switch (count) { + case 0: + break; + case 1: + buf[offset] = IPV6_TLV_PAD1; + break; + default: + buf[offset] = IPV6_TLV_PADN; + buf[offset + 1] = count - 2; + if (count > 2) + memset(buf + offset + 2, 0, count - 2); + break; + } + return 0; +} + +/** + * calipso_genopt - Generate a CALIPSO option + * @buf: the option buffer + * @start: offset from which to write + * @buf_len: the size of opt_buf + * @doi_def: the CALIPSO DOI to use + * @secattr: the security attributes + * + * Description: + * Generate a CALIPSO option using the DOI definition and security attributes + * passed to the function. This also generates upto three bytes of leading + * padding that ensures that the option is 4n + 2 aligned. It returns the + * number of bytes written (including any initial padding). + */ +static int calipso_genopt(unsigned char *buf, u32 start, u32 buf_len, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + u32 len, pad; + u16 crc; + static const unsigned char padding[4] = {2, 1, 0, 3}; + unsigned char *calipso; + + /* CALIPSO has 4n + 2 alignment */ + pad = padding[start & 3]; + if (buf_len <= start + pad + CALIPSO_HDR_LEN) + return -ENOSPC; + + if ((secattr->flags & NETLBL_SECATTR_MLS_LVL) == 0) + return -EPERM; + + len = CALIPSO_HDR_LEN; + + if (secattr->flags & NETLBL_SECATTR_MLS_CAT) { + ret_val = calipso_map_cat_hton(doi_def, + secattr, + buf + start + pad + len, + buf_len - start - pad - len); + if (ret_val < 0) + return ret_val; + len += ret_val; + } + + calipso_pad_write(buf, start, pad); + calipso = buf + start + pad; + + calipso[0] = IPV6_TLV_CALIPSO; + calipso[1] = len - 2; + *(__be32 *)(calipso + 2) = htonl(doi_def->doi); + calipso[6] = (len - CALIPSO_HDR_LEN) / 4; + calipso[7] = secattr->attr.mls.lvl; + crc = ~crc_ccitt(0xffff, calipso, len); + calipso[8] = crc & 0xff; + calipso[9] = (crc >> 8) & 0xff; + return pad + len; +} + +/* Hop-by-hop hdr helper functions + */ + +/** + * calipso_opt_update - Replaces socket's hop options with a new set + * @sk: the socket + * @hop: new hop options + * + * Description: + * Replaces @sk's hop options with @hop. @hop may be NULL to leave + * the socket with no hop options. 
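The 4n + 2 alignment rule in calipso_genopt() and the Pad1/PadN encoding in calipso_pad_write() are easiest to see with concrete offsets. A standalone userspace sketch (the buffer and DEMO_* names are illustrative; the pad lookup table is the same {2, 1, 0, 3} used above):

#include <stdio.h>
#include <string.h>

#define DEMO_PAD1 0     /* IPV6_TLV_PAD1 */
#define DEMO_PADN 1     /* IPV6_TLV_PADN */

/* Same table as calipso_genopt(): for a write offset with these low two
 * bits, this many pad bytes land the option on a 4n + 2 boundary. */
static const unsigned char pad_for_offset[4] = { 2, 1, 0, 3 };

/* Mirror of calipso_pad_write(): 0 pads -> nothing, 1 pad -> Pad1,
 * otherwise a PadN TLV whose length byte excludes the 2-byte header. */
static void demo_pad_write(unsigned char *buf, unsigned int offset,
                           unsigned int count)
{
    switch (count) {
    case 0:
        break;
    case 1:
        buf[offset] = DEMO_PAD1;
        break;
    default:
        buf[offset] = DEMO_PADN;
        buf[offset + 1] = count - 2;
        memset(buf + offset + 2, 0, count - 2);
        break;
    }
}

int main(void)
{
    unsigned char buf[16];

    for (unsigned int start = 0; start < 4; start++) {
        unsigned int pad = pad_for_offset[start & 3];

        memset(buf, 0xaa, sizeof(buf));
        demo_pad_write(buf, start, pad);
        printf("start %u: %u pad byte(s), option begins at %u (mod 4 = %u)\n",
               start, pad, start + pad, (start + pad) % 4);
    }
    return 0;
}

Every line of output ends in "mod 4 = 2", which is the alignment the CALIPSO option requires.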
+ * + */ +static int calipso_opt_update(struct sock *sk, struct ipv6_opt_hdr *hop) +{ + struct ipv6_txoptions *old = txopt_get(inet6_sk(sk)), *txopts; + + txopts = ipv6_renew_options(sk, old, IPV6_HOPOPTS, hop); + txopt_put(old); + if (IS_ERR(txopts)) + return PTR_ERR(txopts); + + txopts = ipv6_update_options(sk, txopts); + if (txopts) { + atomic_sub(txopts->tot_len, &sk->sk_omem_alloc); + txopt_put(txopts); + } + + return 0; +} + +/** + * calipso_tlv_len - Returns the length of the TLV + * @opt: the option header + * @offset: offset of the TLV within the header + * + * Description: + * Returns the length of the TLV option at offset @offset within + * the option header @opt. Checks that the entire TLV fits inside + * the option header, returns a negative value if this is not the case. + */ +static int calipso_tlv_len(struct ipv6_opt_hdr *opt, unsigned int offset) +{ + unsigned char *tlv = (unsigned char *)opt; + unsigned int opt_len = ipv6_optlen(opt), tlv_len; + + if (offset < sizeof(*opt) || offset >= opt_len) + return -EINVAL; + if (tlv[offset] == IPV6_TLV_PAD1) + return 1; + if (offset + 1 >= opt_len) + return -EINVAL; + tlv_len = tlv[offset + 1] + 2; + if (offset + tlv_len > opt_len) + return -EINVAL; + return tlv_len; +} + +/** + * calipso_opt_find - Finds the CALIPSO option in an IPv6 hop options header + * @hop: the hop options header + * @start: on return holds the offset of any leading padding + * @end: on return holds the offset of the first non-pad TLV after CALIPSO + * + * Description: + * Finds the space occupied by a CALIPSO option (including any leading and + * trailing padding). + * + * If a CALIPSO option exists set @start and @end to the + * offsets within @hop of the start of padding before the first + * CALIPSO option and the end of padding after the first CALIPSO + * option. In this case the function returns 0. + * + * In the absence of a CALIPSO option, @start and @end will be + * set to the start and end of any trailing padding in the header. + * This is useful when appending a new option, as the caller may want + * to overwrite some of this padding. In this case the function will + * return -ENOENT. + */ +static int calipso_opt_find(struct ipv6_opt_hdr *hop, unsigned int *start, + unsigned int *end) +{ + int ret_val = -ENOENT, tlv_len; + unsigned int opt_len, offset, offset_s = 0, offset_e = 0; + unsigned char *opt = (unsigned char *)hop; + + opt_len = ipv6_optlen(hop); + offset = sizeof(*hop); + + while (offset < opt_len) { + tlv_len = calipso_tlv_len(hop, offset); + if (tlv_len < 0) + return tlv_len; + + switch (opt[offset]) { + case IPV6_TLV_PAD1: + case IPV6_TLV_PADN: + if (offset_e) + offset_e = offset; + break; + case IPV6_TLV_CALIPSO: + ret_val = 0; + offset_e = offset; + break; + default: + if (offset_e == 0) + offset_s = offset; + else + goto out; + } + offset += tlv_len; + } + +out: + if (offset_s) + *start = offset_s + calipso_tlv_len(hop, offset_s); + else + *start = sizeof(*hop); + if (offset_e) + *end = offset_e + calipso_tlv_len(hop, offset_e); + else + *end = opt_len; + + return ret_val; +} + +/** + * calipso_opt_insert - Inserts a CALIPSO option into an IPv6 hop opt hdr + * @hop: the original hop options header + * @doi_def: the CALIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Creates a new hop options header based on @hop with a + * CALIPSO option added to it. 
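calipso_tlv_len() and calipso_opt_find() walk the hop-by-hop header one TLV at a time, treating Pad1 as a single byte and everything else as type/length/value with bounds checks. A standalone userspace sketch of that walk over a hand-built 16-byte header (the buffer contents are made up; option codes 0, 1 and 7 are the standard Pad1, PadN and CALIPSO types):

#include <stdio.h>

#define TLV_PAD1    0
#define TLV_PADN    1
#define TLV_CALIPSO 7

/* Length of the TLV starting at @offset, with the same bounds checks as
 * calipso_tlv_len(): Pad1 is a single byte, everything else carries a
 * length byte that must fit inside the header. Returns -1 on malformed
 * input (the kernel returns -EINVAL). */
static int demo_tlv_len(const unsigned char *hdr, unsigned int hdr_len,
                        unsigned int offset)
{
    unsigned int tlv_len;

    if (offset < 2 || offset >= hdr_len)
        return -1;
    if (hdr[offset] == TLV_PAD1)
        return 1;
    if (offset + 1 >= hdr_len)
        return -1;
    tlv_len = hdr[offset + 1] + 2;
    if (offset + tlv_len > hdr_len)
        return -1;
    return tlv_len;
}

int main(void)
{
    /* Hop-by-hop header: next header 59 (none), hdrlen in 8-byte units
     * minus one, then a PadN, a CALIPSO option with an empty bitmap
     * (checksum bytes left zero here) and a trailing Pad1. */
    unsigned char hop[16] = {
        59, 1,                                  /* 16 bytes total */
        TLV_PADN, 1, 0,                         /* 3 bytes of padding */
        TLV_CALIPSO, 8, 0, 0, 0, 1, 0, 3, 0, 0, /* 10-byte CALIPSO TLV */
        TLV_PAD1,                               /* 1 trailing pad byte */
    };
    unsigned int hdr_len = (hop[1] + 1) * 8;    /* same as ipv6_optlen() */
    unsigned int offset = 2;                    /* skip the 2-byte header */

    while (offset < hdr_len) {
        int len = demo_tlv_len(hop, hdr_len, offset);

        if (len < 0) {
            printf("malformed TLV at offset %u\n", offset);
            return 1;
        }
        printf("offset %2u: type %u, %d byte(s)%s\n", offset, hop[offset],
               len, hop[offset] == TLV_CALIPSO ? "  <- CALIPSO" : "");
        offset += len;
    }
    return 0;
}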
If @hop already contains a CALIPSO + * option this is overwritten, otherwise the new option is appended + * after any existing options. If @hop is NULL then the new header + * will contain just the CALIPSO option and any needed padding. + * + */ +static struct ipv6_opt_hdr * +calipso_opt_insert(struct ipv6_opt_hdr *hop, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + unsigned int start, end, buf_len, pad, hop_len; + struct ipv6_opt_hdr *new; + int ret_val; + + if (hop) { + hop_len = ipv6_optlen(hop); + ret_val = calipso_opt_find(hop, &start, &end); + if (ret_val && ret_val != -ENOENT) + return ERR_PTR(ret_val); + } else { + hop_len = 0; + start = sizeof(*hop); + end = 0; + } + + buf_len = hop_len + start - end + CALIPSO_OPT_LEN_MAX_WITH_PAD; + new = kzalloc(buf_len, GFP_ATOMIC); + if (!new) + return ERR_PTR(-ENOMEM); + + if (start > sizeof(*hop)) + memcpy(new, hop, start); + ret_val = calipso_genopt((unsigned char *)new, start, buf_len, doi_def, + secattr); + if (ret_val < 0) { + kfree(new); + return ERR_PTR(ret_val); + } + + buf_len = start + ret_val; + /* At this point buf_len aligns to 4n, so (buf_len & 4) pads to 8n */ + pad = ((buf_len & 4) + (end & 7)) & 7; + calipso_pad_write((unsigned char *)new, buf_len, pad); + buf_len += pad; + + if (end != hop_len) { + memcpy((char *)new + buf_len, (char *)hop + end, hop_len - end); + buf_len += hop_len - end; + } + new->nexthdr = 0; + new->hdrlen = buf_len / 8 - 1; + + return new; +} + +/** + * calipso_opt_del - Removes the CALIPSO option from an option header + * @hop: the original header + * @new: the new header + * + * Description: + * Creates a new header based on @hop without any CALIPSO option. If @hop + * doesn't contain a CALIPSO option it returns -ENOENT. If @hop contains + * no other non-padding options, it returns zero with @new set to NULL. + * Otherwise it returns zero, creates a new header without the CALIPSO + * option (and removing as much padding as possible) and returns with + * @new set to that header. + * + */ +static int calipso_opt_del(struct ipv6_opt_hdr *hop, + struct ipv6_opt_hdr **new) +{ + int ret_val; + unsigned int start, end, delta, pad, hop_len; + + ret_val = calipso_opt_find(hop, &start, &end); + if (ret_val) + return ret_val; + + hop_len = ipv6_optlen(hop); + if (start == sizeof(*hop) && end == hop_len) { + /* There's no other option in the header so return NULL */ + *new = NULL; + return 0; + } + + delta = (end - start) & ~7; + *new = kzalloc(hop_len - delta, GFP_ATOMIC); + if (!*new) + return -ENOMEM; + + memcpy(*new, hop, start); + (*new)->hdrlen -= delta / 8; + pad = (end - start) & 7; + calipso_pad_write((unsigned char *)*new, start, pad); + if (end != hop_len) + memcpy((char *)*new + start + pad, (char *)hop + end, + hop_len - end); + + return 0; +} + +/** + * calipso_opt_getattr - Get the security attributes from a memory block + * @calipso: the CALIPSO option + * @secattr: the security attributes + * + * Description: + * Inspect @calipso and return the security attributes in @secattr. + * Returns zero on success and negative values on failure. 
+ * + */ +static int calipso_opt_getattr(const unsigned char *calipso, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + u32 doi, len = calipso[1], cat_len = calipso[6] * 4; + struct calipso_doi *doi_def; + + if (cat_len + 8 > len) + return -EINVAL; + + if (calipso_cache_check(calipso + 2, calipso[1], secattr) == 0) + return 0; + + doi = get_unaligned_be32(calipso + 2); + rcu_read_lock(); + doi_def = calipso_doi_search(doi); + if (!doi_def) + goto getattr_return; + + secattr->attr.mls.lvl = calipso[7]; + secattr->flags |= NETLBL_SECATTR_MLS_LVL; + + if (cat_len) { + ret_val = calipso_map_cat_ntoh(doi_def, + calipso + 10, + cat_len, + secattr); + if (ret_val != 0) { + netlbl_catmap_free(secattr->attr.mls.cat); + goto getattr_return; + } + + if (secattr->attr.mls.cat) + secattr->flags |= NETLBL_SECATTR_MLS_CAT; + } + + secattr->type = NETLBL_NLTYPE_CALIPSO; + +getattr_return: + rcu_read_unlock(); + return ret_val; +} + +/* sock functions. + */ + +/** + * calipso_sock_getattr - Get the security attributes from a sock + * @sk: the sock + * @secattr: the security attributes + * + * Description: + * Query @sk to see if there is a CALIPSO option attached to the sock and if + * there is return the CALIPSO security attributes in @secattr. This function + * requires that @sk be locked, or privately held, but it does not do any + * locking itself. Returns zero on success and negative values on failure. + * + */ +static int calipso_sock_getattr(struct sock *sk, + struct netlbl_lsm_secattr *secattr) +{ + struct ipv6_opt_hdr *hop; + int opt_len, len, ret_val = -ENOMSG, offset; + unsigned char *opt; + struct ipv6_txoptions *txopts = txopt_get(inet6_sk(sk)); + + if (!txopts || !txopts->hopopt) + goto done; + + hop = txopts->hopopt; + opt = (unsigned char *)hop; + opt_len = ipv6_optlen(hop); + offset = sizeof(*hop); + while (offset < opt_len) { + len = calipso_tlv_len(hop, offset); + if (len < 0) { + ret_val = len; + goto done; + } + switch (opt[offset]) { + case IPV6_TLV_CALIPSO: + if (len < CALIPSO_HDR_LEN) + ret_val = -EINVAL; + else + ret_val = calipso_opt_getattr(&opt[offset], + secattr); + goto done; + default: + offset += len; + break; + } + } +done: + txopt_put(txopts); + return ret_val; +} + +/** + * calipso_sock_setattr - Add a CALIPSO option to a socket + * @sk: the socket + * @doi_def: the CALIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CALIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. This function requires + * exclusive access to @sk, which means it either needs to be in the + * process of being created or locked. Returns zero on success and negative + * values on failure. + * + */ +static int calipso_sock_setattr(struct sock *sk, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct ipv6_opt_hdr *old, *new; + struct ipv6_txoptions *txopts = txopt_get(inet6_sk(sk)); + + old = NULL; + if (txopts) + old = txopts->hopopt; + + new = calipso_opt_insert(old, doi_def, secattr); + txopt_put(txopts); + if (IS_ERR(new)) + return PTR_ERR(new); + + ret_val = calipso_opt_update(sk, new); + + kfree(new); + return ret_val; +} + +/** + * calipso_sock_delattr - Delete the CALIPSO option from a socket + * @sk: the socket + * + * Description: + * Removes the CALIPSO option from a socket, if present. 
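calipso_opt_getattr() decodes its fields directly from the option bytes: a big-endian DOI at offset 2, the compartment-bitmap length in 4-byte words at offset 6 and the sensitivity level at offset 7. A standalone userspace sketch of that decoding plus the same length sanity check (the option bytes are illustrative):

#include <stdint.h>
#include <stdio.h>

/* Open-coded equivalent of get_unaligned_be32() for the demo. */
static uint32_t be32_at(const unsigned char *p)
{
    return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) |
           ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}

int main(void)
{
    /* CALIPSO TLV: type 7, length 8, DOI 1, no compartments, level 3,
     * checksum bytes left zero for brevity. */
    const unsigned char calipso[10] = { 7, 8, 0, 0, 0, 1, 0, 3, 0, 0 };
    unsigned int len = calipso[1];              /* length byte, here 8 */
    unsigned int cat_len = calipso[6] * 4;      /* compartment bitmap bytes */

    /* Same sanity check as calipso_opt_getattr(): the DOI, bitmap-length,
     * level and checksum fields (8 bytes) plus the bitmap itself must fit
     * inside the advertised length. */
    if (cat_len + 8 > len) {
        printf("malformed option\n");
        return 1;
    }

    printf("DOI %u, sensitivity level %u, %u byte(s) of compartments\n",
           (unsigned int)be32_at(calipso + 2), calipso[7], cat_len);
    return 0;
}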
+ * + */ +static void calipso_sock_delattr(struct sock *sk) +{ + struct ipv6_opt_hdr *new_hop; + struct ipv6_txoptions *txopts = txopt_get(inet6_sk(sk)); + + if (!txopts || !txopts->hopopt) + goto done; + + if (calipso_opt_del(txopts->hopopt, &new_hop)) + goto done; + + calipso_opt_update(sk, new_hop); + kfree(new_hop); + +done: + txopt_put(txopts); +} + +/* request sock functions. + */ + +/** + * calipso_req_setattr - Add a CALIPSO option to a connection request socket + * @req: the connection request socket + * @doi_def: the CALIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CALIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. Returns zero on success and + * negative values on failure. + * + */ +static int calipso_req_setattr(struct request_sock *req, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + struct ipv6_txoptions *txopts; + struct inet_request_sock *req_inet = inet_rsk(req); + struct ipv6_opt_hdr *old, *new; + struct sock *sk = sk_to_full_sk(req_to_sk(req)); + + if (req_inet->ipv6_opt && req_inet->ipv6_opt->hopopt) + old = req_inet->ipv6_opt->hopopt; + else + old = NULL; + + new = calipso_opt_insert(old, doi_def, secattr); + if (IS_ERR(new)) + return PTR_ERR(new); + + txopts = ipv6_renew_options(sk, req_inet->ipv6_opt, IPV6_HOPOPTS, new); + + kfree(new); + + if (IS_ERR(txopts)) + return PTR_ERR(txopts); + + txopts = xchg(&req_inet->ipv6_opt, txopts); + if (txopts) { + atomic_sub(txopts->tot_len, &sk->sk_omem_alloc); + txopt_put(txopts); + } + + return 0; +} + +/** + * calipso_req_delattr - Delete the CALIPSO option from a request socket + * @req: the request socket + * + * Description: + * Removes the CALIPSO option from a request socket, if present. + * + */ +static void calipso_req_delattr(struct request_sock *req) +{ + struct inet_request_sock *req_inet = inet_rsk(req); + struct ipv6_opt_hdr *new; + struct ipv6_txoptions *txopts; + struct sock *sk = sk_to_full_sk(req_to_sk(req)); + + if (!req_inet->ipv6_opt || !req_inet->ipv6_opt->hopopt) + return; + + if (calipso_opt_del(req_inet->ipv6_opt->hopopt, &new)) + return; /* Nothing to do */ + + txopts = ipv6_renew_options(sk, req_inet->ipv6_opt, IPV6_HOPOPTS, new); + + if (!IS_ERR(txopts)) { + txopts = xchg(&req_inet->ipv6_opt, txopts); + if (txopts) { + atomic_sub(txopts->tot_len, &sk->sk_omem_alloc); + txopt_put(txopts); + } + } + kfree(new); +} + +/* skbuff functions. + */ + +/** + * calipso_skbuff_optptr - Find the CALIPSO option in the packet + * @skb: the packet + * + * Description: + * Parse the packet's IP header looking for a CALIPSO option. Returns a pointer + * to the start of the CALIPSO option on success, NULL if one if not found. + * + */ +static unsigned char *calipso_skbuff_optptr(const struct sk_buff *skb) +{ + const struct ipv6hdr *ip6_hdr = ipv6_hdr(skb); + int offset; + + if (ip6_hdr->nexthdr != NEXTHDR_HOP) + return NULL; + + offset = ipv6_find_tlv(skb, sizeof(*ip6_hdr), IPV6_TLV_CALIPSO); + if (offset >= 0) + return (unsigned char *)ip6_hdr + offset; + + return NULL; +} + +/** + * calipso_skbuff_setattr - Set the CALIPSO option on a packet + * @skb: the packet + * @doi_def: the CALIPSO DOI to use + * @secattr: the security attributes + * + * Description: + * Set the CALIPSO option on the given packet based on the security attributes. + * Returns a pointer to the IP header on success and NULL on failure. 
+ * + */ +static int calipso_skbuff_setattr(struct sk_buff *skb, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct ipv6hdr *ip6_hdr; + struct ipv6_opt_hdr *hop; + unsigned char buf[CALIPSO_MAX_BUFFER]; + int len_delta, new_end, pad, payload; + unsigned int start, end; + + ip6_hdr = ipv6_hdr(skb); + if (ip6_hdr->nexthdr == NEXTHDR_HOP) { + hop = (struct ipv6_opt_hdr *)(ip6_hdr + 1); + ret_val = calipso_opt_find(hop, &start, &end); + if (ret_val && ret_val != -ENOENT) + return ret_val; + } else { + start = 0; + end = 0; + } + + memset(buf, 0, sizeof(buf)); + ret_val = calipso_genopt(buf, start & 3, sizeof(buf), doi_def, secattr); + if (ret_val < 0) + return ret_val; + + new_end = start + ret_val; + /* At this point new_end aligns to 4n, so (new_end & 4) pads to 8n */ + pad = ((new_end & 4) + (end & 7)) & 7; + len_delta = new_end - (int)end + pad; + ret_val = skb_cow(skb, skb_headroom(skb) + len_delta); + if (ret_val < 0) + return ret_val; + + ip6_hdr = ipv6_hdr(skb); /* Reset as skb_cow() may have moved it */ + + if (len_delta) { + if (len_delta > 0) + skb_push(skb, len_delta); + else + skb_pull(skb, -len_delta); + memmove((char *)ip6_hdr - len_delta, ip6_hdr, + sizeof(*ip6_hdr) + start); + skb_reset_network_header(skb); + ip6_hdr = ipv6_hdr(skb); + payload = ntohs(ip6_hdr->payload_len); + ip6_hdr->payload_len = htons(payload + len_delta); + } + + hop = (struct ipv6_opt_hdr *)(ip6_hdr + 1); + if (start == 0) { + struct ipv6_opt_hdr *new_hop = (struct ipv6_opt_hdr *)buf; + + new_hop->nexthdr = ip6_hdr->nexthdr; + new_hop->hdrlen = len_delta / 8 - 1; + ip6_hdr->nexthdr = NEXTHDR_HOP; + } else { + hop->hdrlen += len_delta / 8; + } + memcpy((char *)hop + start, buf + (start & 3), new_end - start); + calipso_pad_write((unsigned char *)hop, new_end, pad); + + return 0; +} + +/** + * calipso_skbuff_delattr - Delete any CALIPSO options from a packet + * @skb: the packet + * + * Description: + * Removes any and all CALIPSO options from the given packet. Returns zero on + * success, negative values on failure. + * + */ +static int calipso_skbuff_delattr(struct sk_buff *skb) +{ + int ret_val; + struct ipv6hdr *ip6_hdr; + struct ipv6_opt_hdr *old_hop; + u32 old_hop_len, start = 0, end = 0, delta, size, pad; + + if (!calipso_skbuff_optptr(skb)) + return 0; + + /* since we are changing the packet we should make a copy */ + ret_val = skb_cow(skb, skb_headroom(skb)); + if (ret_val < 0) + return ret_val; + + ip6_hdr = ipv6_hdr(skb); + old_hop = (struct ipv6_opt_hdr *)(ip6_hdr + 1); + old_hop_len = ipv6_optlen(old_hop); + + ret_val = calipso_opt_find(old_hop, &start, &end); + if (ret_val) + return ret_val; + + if (start == sizeof(*old_hop) && end == old_hop_len) { + /* There's no other option in the header so we delete + * the whole thing. 
*/ + delta = old_hop_len; + size = sizeof(*ip6_hdr); + ip6_hdr->nexthdr = old_hop->nexthdr; + } else { + delta = (end - start) & ~7; + if (delta) + old_hop->hdrlen -= delta / 8; + pad = (end - start) & 7; + size = sizeof(*ip6_hdr) + start + pad; + calipso_pad_write((unsigned char *)old_hop, start, pad); + } + + if (delta) { + skb_pull(skb, delta); + memmove((char *)ip6_hdr + delta, ip6_hdr, size); + skb_reset_network_header(skb); + } + + return 0; +} + +static const struct netlbl_calipso_ops ops = { + .doi_add = calipso_doi_add, + .doi_free = calipso_doi_free, + .doi_remove = calipso_doi_remove, + .doi_getdef = calipso_doi_getdef, + .doi_putdef = calipso_doi_putdef, + .doi_walk = calipso_doi_walk, + .sock_getattr = calipso_sock_getattr, + .sock_setattr = calipso_sock_setattr, + .sock_delattr = calipso_sock_delattr, + .req_setattr = calipso_req_setattr, + .req_delattr = calipso_req_delattr, + .opt_getattr = calipso_opt_getattr, + .skbuff_optptr = calipso_skbuff_optptr, + .skbuff_setattr = calipso_skbuff_setattr, + .skbuff_delattr = calipso_skbuff_delattr, + .cache_invalidate = calipso_cache_invalidate, + .cache_add = calipso_cache_add +}; + +/** + * calipso_init - Initialize the CALIPSO module + * + * Description: + * Initialize the CALIPSO module and prepare it for use. Returns zero on + * success and negative values on failure. + * + */ +int __init calipso_init(void) +{ + int ret_val; + + ret_val = calipso_cache_init(); + if (!ret_val) + netlbl_calipso_ops_register(&ops); + return ret_val; +} + +void calipso_exit(void) +{ + netlbl_calipso_ops_register(NULL); + calipso_cache_invalidate(); + kfree(calipso_cache); +} diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c new file mode 100644 index 000000000..e70ace403 --- /dev/null +++ b/net/ipv6/datagram.c @@ -0,0 +1,1071 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * common UDP/RAW code + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static bool ipv6_mapped_addr_any(const struct in6_addr *a) +{ + return ipv6_addr_v4mapped(a) && (a->s6_addr32[3] == 0); +} + +static void ip6_datagram_flow_key_init(struct flowi6 *fl6, struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + int oif = sk->sk_bound_dev_if; + + memset(fl6, 0, sizeof(*fl6)); + fl6->flowi6_proto = sk->sk_protocol; + fl6->daddr = sk->sk_v6_daddr; + fl6->saddr = np->saddr; + fl6->flowi6_mark = sk->sk_mark; + fl6->fl6_dport = inet->inet_dport; + fl6->fl6_sport = inet->inet_sport; + fl6->flowlabel = ip6_make_flowinfo(np->tclass, np->flow_label); + fl6->flowi6_uid = sk->sk_uid; + + if (!oif) + oif = np->sticky_pktinfo.ipi6_ifindex; + + if (!oif) { + if (ipv6_addr_is_multicast(&fl6->daddr)) + oif = np->mcast_oif; + else + oif = np->ucast_oif; + } + + fl6->flowi6_oif = oif; + security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6)); +} + +int ip6_datagram_dst_update(struct sock *sk, bool fix_sk_saddr) +{ + struct ip6_flowlabel *flowlabel = NULL; + struct in6_addr *final_p, final; + struct ipv6_txoptions *opt; + struct dst_entry *dst; + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct flowi6 fl6; + int err = 0; + + if (np->sndflow && (np->flow_label & IPV6_FLOWLABEL_MASK)) { + flowlabel = fl6_sock_lookup(sk, np->flow_label); + if 
(IS_ERR(flowlabel)) + return -EINVAL; + } + ip6_datagram_flow_key_init(&fl6, sk); + + rcu_read_lock(); + opt = flowlabel ? flowlabel->opt : rcu_dereference(np->opt); + final_p = fl6_update_dst(&fl6, opt, &final); + rcu_read_unlock(); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto out; + } + + if (fix_sk_saddr) { + if (ipv6_addr_any(&np->saddr)) + np->saddr = fl6.saddr; + + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) { + sk->sk_v6_rcv_saddr = fl6.saddr; + inet->inet_rcv_saddr = LOOPBACK4_IPV6; + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); + } + } + + ip6_sk_dst_store_flow(sk, dst, &fl6); + +out: + fl6_sock_release(flowlabel); + return err; +} + +void ip6_datagram_release_cb(struct sock *sk) +{ + struct dst_entry *dst; + + if (ipv6_addr_v4mapped(&sk->sk_v6_daddr)) + return; + + rcu_read_lock(); + dst = __sk_dst_get(sk); + if (!dst || !dst->obsolete || + dst->ops->check(dst, inet6_sk(sk)->dst_cookie)) { + rcu_read_unlock(); + return; + } + rcu_read_unlock(); + + ip6_datagram_dst_update(sk, false); +} +EXPORT_SYMBOL_GPL(ip6_datagram_release_cb); + +int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + struct sockaddr_in6 *usin = (struct sockaddr_in6 *) uaddr; + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct in6_addr *daddr, old_daddr; + __be32 fl6_flowlabel = 0; + __be32 old_fl6_flowlabel; + __be16 old_dport; + int addr_type; + int err; + + if (usin->sin6_family == AF_INET) { + if (ipv6_only_sock(sk)) + return -EAFNOSUPPORT; + err = __ip4_datagram_connect(sk, uaddr, addr_len); + goto ipv4_connected; + } + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (usin->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + if (np->sndflow) + fl6_flowlabel = usin->sin6_flowinfo & IPV6_FLOWINFO_MASK; + + if (ipv6_addr_any(&usin->sin6_addr)) { + /* + * connect to self + */ + if (ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + ipv6_addr_set_v4mapped(htonl(INADDR_LOOPBACK), + &usin->sin6_addr); + else + usin->sin6_addr = in6addr_loopback; + } + + addr_type = ipv6_addr_type(&usin->sin6_addr); + + daddr = &usin->sin6_addr; + + if (addr_type & IPV6_ADDR_MAPPED) { + struct sockaddr_in sin; + + if (ipv6_only_sock(sk)) { + err = -ENETUNREACH; + goto out; + } + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = daddr->s6_addr32[3]; + sin.sin_port = usin->sin6_port; + + err = __ip4_datagram_connect(sk, + (struct sockaddr *) &sin, + sizeof(sin)); + +ipv4_connected: + if (err) + goto out; + + ipv6_addr_set_v4mapped(inet->inet_daddr, &sk->sk_v6_daddr); + + if (ipv6_addr_any(&np->saddr) || + ipv6_mapped_addr_any(&np->saddr)) + ipv6_addr_set_v4mapped(inet->inet_saddr, &np->saddr); + + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr) || + ipv6_mapped_addr_any(&sk->sk_v6_rcv_saddr)) { + ipv6_addr_set_v4mapped(inet->inet_rcv_saddr, + &sk->sk_v6_rcv_saddr); + if (sk->sk_prot->rehash) + sk->sk_prot->rehash(sk); + } + + goto out; + } + + if (__ipv6_addr_needs_scope_id(addr_type)) { + if (addr_len >= sizeof(struct sockaddr_in6) && + usin->sin6_scope_id) { + if (!sk_dev_equal_l3scope(sk, usin->sin6_scope_id)) { + err = -EINVAL; + goto out; + } + WRITE_ONCE(sk->sk_bound_dev_if, usin->sin6_scope_id); + } + + if (!sk->sk_bound_dev_if && (addr_type & IPV6_ADDR_MULTICAST)) + WRITE_ONCE(sk->sk_bound_dev_if, np->mcast_oif); + + /* Connect to link-local address requires an interface */ + if (!sk->sk_bound_dev_if) { + err = -EINVAL; + goto out; + } + } + + /* save the current peer information 
before updating it */ + old_daddr = sk->sk_v6_daddr; + old_fl6_flowlabel = np->flow_label; + old_dport = inet->inet_dport; + + sk->sk_v6_daddr = *daddr; + np->flow_label = fl6_flowlabel; + inet->inet_dport = usin->sin6_port; + + /* + * Check for a route to destination an obtain the + * destination cache for it. + */ + + err = ip6_datagram_dst_update(sk, true); + if (err) { + /* Restore the socket peer info, to keep it consistent with + * the old socket state + */ + sk->sk_v6_daddr = old_daddr; + np->flow_label = old_fl6_flowlabel; + inet->inet_dport = old_dport; + goto out; + } + + reuseport_has_conns_set(sk); + sk->sk_state = TCP_ESTABLISHED; + sk_set_txhash(sk); +out: + return err; +} +EXPORT_SYMBOL_GPL(__ip6_datagram_connect); + +int ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + int res; + + lock_sock(sk); + res = __ip6_datagram_connect(sk, uaddr, addr_len); + release_sock(sk); + return res; +} +EXPORT_SYMBOL_GPL(ip6_datagram_connect); + +int ip6_datagram_connect_v6_only(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, uaddr); + if (sin6->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + return ip6_datagram_connect(sk, uaddr, addr_len); +} +EXPORT_SYMBOL_GPL(ip6_datagram_connect_v6_only); + +static void ipv6_icmp_error_rfc4884(const struct sk_buff *skb, + struct sock_ee_data_rfc4884 *out) +{ + switch (icmp6_hdr(skb)->icmp6_type) { + case ICMPV6_TIME_EXCEED: + case ICMPV6_DEST_UNREACH: + ip_icmp_error_rfc4884(skb, out, sizeof(struct icmp6hdr), + icmp6_hdr(skb)->icmp6_datagram_len * 8); + } +} + +void ipv6_icmp_error(struct sock *sk, struct sk_buff *skb, int err, + __be16 port, u32 info, u8 *payload) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct icmp6hdr *icmph = icmp6_hdr(skb); + struct sock_exterr_skb *serr; + + if (!np->recverr) + return; + + skb = skb_clone(skb, GFP_ATOMIC); + if (!skb) + return; + + skb->protocol = htons(ETH_P_IPV6); + + serr = SKB_EXT_ERR(skb); + serr->ee.ee_errno = err; + serr->ee.ee_origin = SO_EE_ORIGIN_ICMP6; + serr->ee.ee_type = icmph->icmp6_type; + serr->ee.ee_code = icmph->icmp6_code; + serr->ee.ee_pad = 0; + serr->ee.ee_info = info; + serr->ee.ee_data = 0; + serr->addr_offset = (u8 *)&(((struct ipv6hdr *)(icmph + 1))->daddr) - + skb_network_header(skb); + serr->port = port; + + __skb_pull(skb, payload - skb->data); + + if (inet6_sk(sk)->recverr_rfc4884) + ipv6_icmp_error_rfc4884(skb, &serr->ee.ee_rfc4884); + + skb_reset_transport_header(skb); + + if (sock_queue_err_skb(sk, skb)) + kfree_skb(skb); +} + +void ipv6_local_error(struct sock *sk, int err, struct flowi6 *fl6, u32 info) +{ + const struct ipv6_pinfo *np = inet6_sk(sk); + struct sock_exterr_skb *serr; + struct ipv6hdr *iph; + struct sk_buff *skb; + + if (!np->recverr) + return; + + skb = alloc_skb(sizeof(struct ipv6hdr), GFP_ATOMIC); + if (!skb) + return; + + skb->protocol = htons(ETH_P_IPV6); + + skb_put(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + iph = ipv6_hdr(skb); + iph->daddr = fl6->daddr; + ip6_flow_hdr(iph, 0, 0); + + serr = SKB_EXT_ERR(skb); + serr->ee.ee_errno = err; + serr->ee.ee_origin = SO_EE_ORIGIN_LOCAL; + serr->ee.ee_type = 0; + serr->ee.ee_code = 0; + serr->ee.ee_pad = 0; + serr->ee.ee_info = info; + serr->ee.ee_data = 0; + serr->addr_offset = (u8 *)&iph->daddr - skb_network_header(skb); + serr->port = fl6->fl6_dport; + + __skb_pull(skb, skb_tail_pointer(skb) - skb->data); + skb_reset_transport_header(skb); + + if (sock_queue_err_skb(sk, skb)) + kfree_skb(skb); +} 
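The errors queued by ipv6_icmp_error() and ipv6_local_error() are consumed by applications that enable IPV6_RECVERR and read the socket error queue, which is what ipv6_recv_error() below services. A minimal userspace sketch of that consumer side, assuming a UDP/IPv6 socket that has already sent traffic (error handling trimmed):

#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <linux/errqueue.h>     /* struct sock_extended_err */

/* Drain one message from the error queue of @fd and print the extended
 * error the kernel attached to it. */
static void read_one_error(int fd)
{
    char cbuf[512], payload[1024];
    struct sockaddr_in6 target;         /* original destination of the packet */
    struct iovec iov = { .iov_base = payload, .iov_len = sizeof(payload) };
    struct msghdr msg = {
        .msg_name = &target, .msg_namelen = sizeof(target),
        .msg_iov = &iov, .msg_iovlen = 1,
        .msg_control = cbuf, .msg_controllen = sizeof(cbuf),
    };
    struct cmsghdr *cmsg;

    if (recvmsg(fd, &msg, MSG_ERRQUEUE) < 0) {
        perror("recvmsg(MSG_ERRQUEUE)");
        return;
    }

    for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
        /* IPPROTO_IPV6 has the same value as SOL_IPV6. */
        if (cmsg->cmsg_level == IPPROTO_IPV6 &&
            cmsg->cmsg_type == IPV6_RECVERR) {
            struct sock_extended_err ee;

            memcpy(&ee, CMSG_DATA(cmsg), sizeof(ee));
            printf("errno %u origin %u type %u code %u info %u\n",
                   ee.ee_errno, ee.ee_origin, ee.ee_type,
                   ee.ee_code, ee.ee_info);
        }
    }
}

int main(void)
{
    int fd = socket(AF_INET6, SOCK_DGRAM, 0);
    int on = 1;

    /* Ask the kernel to queue ICMPv6 and local errors for this socket. */
    setsockopt(fd, IPPROTO_IPV6, IPV6_RECVERR, &on, sizeof(on));

    /* ...send traffic that may fail, poll for POLLERR, then: */
    read_one_error(fd);
    return 0;
}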
+ +void ipv6_local_rxpmtu(struct sock *sk, struct flowi6 *fl6, u32 mtu) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6hdr *iph; + struct sk_buff *skb; + struct ip6_mtuinfo *mtu_info; + + if (!np->rxopt.bits.rxpmtu) + return; + + skb = alloc_skb(sizeof(struct ipv6hdr), GFP_ATOMIC); + if (!skb) + return; + + skb_put(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + iph = ipv6_hdr(skb); + iph->daddr = fl6->daddr; + + mtu_info = IP6CBMTU(skb); + + mtu_info->ip6m_mtu = mtu; + mtu_info->ip6m_addr.sin6_family = AF_INET6; + mtu_info->ip6m_addr.sin6_port = 0; + mtu_info->ip6m_addr.sin6_flowinfo = 0; + mtu_info->ip6m_addr.sin6_scope_id = fl6->flowi6_oif; + mtu_info->ip6m_addr.sin6_addr = ipv6_hdr(skb)->daddr; + + __skb_pull(skb, skb_tail_pointer(skb) - skb->data); + skb_reset_transport_header(skb); + + skb = xchg(&np->rxpmtu, skb); + kfree_skb(skb); +} + +/* For some errors we have valid addr_offset even with zero payload and + * zero port. Also, addr_offset should be supported if port is set. + */ +static inline bool ipv6_datagram_support_addr(struct sock_exterr_skb *serr) +{ + return serr->ee.ee_origin == SO_EE_ORIGIN_ICMP6 || + serr->ee.ee_origin == SO_EE_ORIGIN_ICMP || + serr->ee.ee_origin == SO_EE_ORIGIN_LOCAL || serr->port; +} + +/* IPv6 supports cmsg on all origins aside from SO_EE_ORIGIN_LOCAL. + * + * At one point, excluding local errors was a quick test to identify icmp/icmp6 + * errors. This is no longer true, but the test remained, so the v6 stack, + * unlike v4, also honors cmsg requests on all wifi and timestamp errors. + */ +static bool ip6_datagram_support_cmsg(struct sk_buff *skb, + struct sock_exterr_skb *serr) +{ + if (serr->ee.ee_origin == SO_EE_ORIGIN_ICMP || + serr->ee.ee_origin == SO_EE_ORIGIN_ICMP6) + return true; + + if (serr->ee.ee_origin == SO_EE_ORIGIN_LOCAL) + return false; + + if (!IP6CB(skb)->iif) + return false; + + return true; +} + +/* + * Handle MSG_ERRQUEUE + */ +int ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, int *addr_len) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct sock_exterr_skb *serr; + struct sk_buff *skb; + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin, msg->msg_name); + struct { + struct sock_extended_err ee; + struct sockaddr_in6 offender; + } errhdr; + int err; + int copied; + + err = -EAGAIN; + skb = sock_dequeue_err_skb(sk); + if (!skb) + goto out; + + copied = skb->len; + if (copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + sock_recv_timestamp(msg, sk, skb); + + serr = SKB_EXT_ERR(skb); + + if (sin && ipv6_datagram_support_addr(serr)) { + const unsigned char *nh = skb_network_header(skb); + sin->sin6_family = AF_INET6; + sin->sin6_flowinfo = 0; + sin->sin6_port = serr->port; + if (skb->protocol == htons(ETH_P_IPV6)) { + const struct ipv6hdr *ip6h = container_of((struct in6_addr *)(nh + serr->addr_offset), + struct ipv6hdr, daddr); + sin->sin6_addr = ip6h->daddr; + if (np->sndflow) + sin->sin6_flowinfo = ip6_flowinfo(ip6h); + sin->sin6_scope_id = + ipv6_iface_scope_id(&sin->sin6_addr, + IP6CB(skb)->iif); + } else { + ipv6_addr_set_v4mapped(*(__be32 *)(nh + serr->addr_offset), + &sin->sin6_addr); + sin->sin6_scope_id = 0; + } + *addr_len = sizeof(*sin); + } + + memcpy(&errhdr.ee, &serr->ee, sizeof(struct sock_extended_err)); + sin = &errhdr.offender; + memset(sin, 0, sizeof(*sin)); + + if (ip6_datagram_support_cmsg(skb, serr)) { + sin->sin6_family = AF_INET6; + if 
(np->rxopt.all) + ip6_datagram_recv_common_ctl(sk, msg, skb); + if (skb->protocol == htons(ETH_P_IPV6)) { + sin->sin6_addr = ipv6_hdr(skb)->saddr; + if (np->rxopt.all) + ip6_datagram_recv_specific_ctl(sk, msg, skb); + sin->sin6_scope_id = + ipv6_iface_scope_id(&sin->sin6_addr, + IP6CB(skb)->iif); + } else { + ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, + &sin->sin6_addr); + if (inet_sk(sk)->cmsg_flags) + ip_cmsg_recv(msg, skb); + } + } + + put_cmsg(msg, SOL_IPV6, IPV6_RECVERR, sizeof(errhdr), &errhdr); + + /* Now we could try to dump offended packet options */ + + msg->msg_flags |= MSG_ERRQUEUE; + err = copied; + + consume_skb(skb); +out: + return err; +} +EXPORT_SYMBOL_GPL(ipv6_recv_error); + +/* + * Handle IPV6_RECVPATHMTU + */ +int ipv6_recv_rxpmtu(struct sock *sk, struct msghdr *msg, int len, + int *addr_len) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct sk_buff *skb; + struct ip6_mtuinfo mtu_info; + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin, msg->msg_name); + int err; + int copied; + + err = -EAGAIN; + skb = xchg(&np->rxpmtu, NULL); + if (!skb) + goto out; + + copied = skb->len; + if (copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto out_free_skb; + + sock_recv_timestamp(msg, sk, skb); + + memcpy(&mtu_info, IP6CBMTU(skb), sizeof(mtu_info)); + + if (sin) { + sin->sin6_family = AF_INET6; + sin->sin6_flowinfo = 0; + sin->sin6_port = 0; + sin->sin6_scope_id = mtu_info.ip6m_addr.sin6_scope_id; + sin->sin6_addr = mtu_info.ip6m_addr.sin6_addr; + *addr_len = sizeof(*sin); + } + + put_cmsg(msg, SOL_IPV6, IPV6_PATHMTU, sizeof(mtu_info), &mtu_info); + + err = copied; + +out_free_skb: + kfree_skb(skb); +out: + return err; +} + + +void ip6_datagram_recv_common_ctl(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + bool is_ipv6 = skb->protocol == htons(ETH_P_IPV6); + + if (np->rxopt.bits.rxinfo) { + struct in6_pktinfo src_info; + + if (is_ipv6) { + src_info.ipi6_ifindex = IP6CB(skb)->iif; + src_info.ipi6_addr = ipv6_hdr(skb)->daddr; + } else { + src_info.ipi6_ifindex = + PKTINFO_SKB_CB(skb)->ipi_ifindex; + ipv6_addr_set_v4mapped(ip_hdr(skb)->daddr, + &src_info.ipi6_addr); + } + + if (src_info.ipi6_ifindex >= 0) + put_cmsg(msg, SOL_IPV6, IPV6_PKTINFO, + sizeof(src_info), &src_info); + } +} + +void ip6_datagram_recv_specific_ctl(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct inet6_skb_parm *opt = IP6CB(skb); + unsigned char *nh = skb_network_header(skb); + + if (np->rxopt.bits.rxhlim) { + int hlim = ipv6_hdr(skb)->hop_limit; + put_cmsg(msg, SOL_IPV6, IPV6_HOPLIMIT, sizeof(hlim), &hlim); + } + + if (np->rxopt.bits.rxtclass) { + int tclass = ipv6_get_dsfield(ipv6_hdr(skb)); + put_cmsg(msg, SOL_IPV6, IPV6_TCLASS, sizeof(tclass), &tclass); + } + + if (np->rxopt.bits.rxflow) { + __be32 flowinfo = ip6_flowinfo((struct ipv6hdr *)nh); + if (flowinfo) + put_cmsg(msg, SOL_IPV6, IPV6_FLOWINFO, sizeof(flowinfo), &flowinfo); + } + + /* HbH is allowed only once */ + if (np->rxopt.bits.hopopts && (opt->flags & IP6SKB_HOPBYHOP)) { + u8 *ptr = nh + sizeof(struct ipv6hdr); + put_cmsg(msg, SOL_IPV6, IPV6_HOPOPTS, (ptr[1]+1)<<3, ptr); + } + + if (opt->lastopt && + (np->rxopt.bits.dstopts || np->rxopt.bits.srcrt)) { + /* + * Silly enough, but we need to reparse in order to + * report extension headers (except for HbH) + * in order. 
+ * + * Also note that IPV6_RECVRTHDRDSTOPTS is NOT + * (and WILL NOT be) defined because + * IPV6_RECVDSTOPTS is more generic. --yoshfuji + */ + unsigned int off = sizeof(struct ipv6hdr); + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + + while (off <= opt->lastopt) { + unsigned int len; + u8 *ptr = nh + off; + + switch (nexthdr) { + case IPPROTO_DSTOPTS: + nexthdr = ptr[0]; + len = (ptr[1] + 1) << 3; + if (np->rxopt.bits.dstopts) + put_cmsg(msg, SOL_IPV6, IPV6_DSTOPTS, len, ptr); + break; + case IPPROTO_ROUTING: + nexthdr = ptr[0]; + len = (ptr[1] + 1) << 3; + if (np->rxopt.bits.srcrt) + put_cmsg(msg, SOL_IPV6, IPV6_RTHDR, len, ptr); + break; + case IPPROTO_AH: + nexthdr = ptr[0]; + len = (ptr[1] + 2) << 2; + break; + default: + nexthdr = ptr[0]; + len = (ptr[1] + 1) << 3; + break; + } + + off += len; + } + } + + /* socket options in old style */ + if (np->rxopt.bits.rxoinfo) { + struct in6_pktinfo src_info; + + src_info.ipi6_ifindex = opt->iif; + src_info.ipi6_addr = ipv6_hdr(skb)->daddr; + put_cmsg(msg, SOL_IPV6, IPV6_2292PKTINFO, sizeof(src_info), &src_info); + } + if (np->rxopt.bits.rxohlim) { + int hlim = ipv6_hdr(skb)->hop_limit; + put_cmsg(msg, SOL_IPV6, IPV6_2292HOPLIMIT, sizeof(hlim), &hlim); + } + if (np->rxopt.bits.ohopopts && (opt->flags & IP6SKB_HOPBYHOP)) { + u8 *ptr = nh + sizeof(struct ipv6hdr); + put_cmsg(msg, SOL_IPV6, IPV6_2292HOPOPTS, (ptr[1]+1)<<3, ptr); + } + if (np->rxopt.bits.odstopts && opt->dst0) { + u8 *ptr = nh + opt->dst0; + put_cmsg(msg, SOL_IPV6, IPV6_2292DSTOPTS, (ptr[1]+1)<<3, ptr); + } + if (np->rxopt.bits.osrcrt && opt->srcrt) { + struct ipv6_rt_hdr *rthdr = (struct ipv6_rt_hdr *)(nh + opt->srcrt); + put_cmsg(msg, SOL_IPV6, IPV6_2292RTHDR, (rthdr->hdrlen+1) << 3, rthdr); + } + if (np->rxopt.bits.odstopts && opt->dst1) { + u8 *ptr = nh + opt->dst1; + put_cmsg(msg, SOL_IPV6, IPV6_2292DSTOPTS, (ptr[1]+1)<<3, ptr); + } + if (np->rxopt.bits.rxorigdstaddr) { + struct sockaddr_in6 sin6; + __be16 _ports[2], *ports; + + ports = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_ports), &_ports); + if (ports) { + /* All current transport protocols have the port numbers in the + * first four bytes of the transport header and this function is + * written with this assumption in mind. 
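ip6_datagram_recv_common_ctl() and ip6_datagram_recv_specific_ctl() are the producer side of the IPV6_PKTINFO, IPV6_HOPLIMIT and IPV6_TCLASS ancillary data; the consumer is any application that turns on the matching IPV6_RECV* socket options and walks the control buffer. A minimal userspace sketch of that consumer, assuming a UDP/IPv6 socket bound to an arbitrary port (error handling trimmed; _GNU_SOURCE is needed for struct in6_pktinfo with glibc):

#define _GNU_SOURCE             /* struct in6_pktinfo */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

int main(void)
{
    int fd = socket(AF_INET6, SOCK_DGRAM, 0);
    int on = 1;
    struct sockaddr_in6 addr = {
        .sin6_family = AF_INET6,
        .sin6_port = htons(9999),
        .sin6_addr = IN6ADDR_ANY_INIT,
    };
    char data[2048], cbuf[256], addrstr[INET6_ADDRSTRLEN];
    struct iovec iov = { .iov_base = data, .iov_len = sizeof(data) };
    struct msghdr msg = {
        .msg_iov = &iov, .msg_iovlen = 1,
        .msg_control = cbuf, .msg_controllen = sizeof(cbuf),
    };
    struct cmsghdr *cmsg;

    /* These options are what make the kernel attach the cmsgs produced
     * by ip6_datagram_recv_*_ctl() to every received datagram. */
    setsockopt(fd, IPPROTO_IPV6, IPV6_RECVPKTINFO, &on, sizeof(on));
    setsockopt(fd, IPPROTO_IPV6, IPV6_RECVHOPLIMIT, &on, sizeof(on));
    setsockopt(fd, IPPROTO_IPV6, IPV6_RECVTCLASS, &on, sizeof(on));
    bind(fd, (struct sockaddr *)&addr, sizeof(addr));

    if (recvmsg(fd, &msg, 0) < 0)
        return 1;

    for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
        if (cmsg->cmsg_level != IPPROTO_IPV6)
            continue;
        if (cmsg->cmsg_type == IPV6_PKTINFO) {
            struct in6_pktinfo info;

            memcpy(&info, CMSG_DATA(cmsg), sizeof(info));
            inet_ntop(AF_INET6, &info.ipi6_addr, addrstr, sizeof(addrstr));
            printf("delivered to %s on ifindex %d\n",
                   addrstr, info.ipi6_ifindex);
        } else if (cmsg->cmsg_type == IPV6_HOPLIMIT ||
                   cmsg->cmsg_type == IPV6_TCLASS) {
            int val;

            memcpy(&val, CMSG_DATA(cmsg), sizeof(val));
            printf("%s = %d\n", cmsg->cmsg_type == IPV6_HOPLIMIT ?
                   "hop limit" : "traffic class", val);
        }
    }
    return 0;
}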
+ */ + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = ipv6_hdr(skb)->daddr; + sin6.sin6_port = ports[1]; + sin6.sin6_flowinfo = 0; + sin6.sin6_scope_id = + ipv6_iface_scope_id(&ipv6_hdr(skb)->daddr, + opt->iif); + + put_cmsg(msg, SOL_IPV6, IPV6_ORIGDSTADDR, sizeof(sin6), &sin6); + } + } + if (np->rxopt.bits.recvfragsize && opt->frag_max_size) { + int val = opt->frag_max_size; + + put_cmsg(msg, SOL_IPV6, IPV6_RECVFRAGSIZE, sizeof(val), &val); + } +} + +void ip6_datagram_recv_ctl(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ + ip6_datagram_recv_common_ctl(sk, msg, skb); + ip6_datagram_recv_specific_ctl(sk, msg, skb); +} +EXPORT_SYMBOL_GPL(ip6_datagram_recv_ctl); + +int ip6_datagram_send_ctl(struct net *net, struct sock *sk, + struct msghdr *msg, struct flowi6 *fl6, + struct ipcm6_cookie *ipc6) +{ + struct in6_pktinfo *src_info; + struct cmsghdr *cmsg; + struct ipv6_rt_hdr *rthdr; + struct ipv6_opt_hdr *hdr; + struct ipv6_txoptions *opt = ipc6->opt; + int len; + int err = 0; + + for_each_cmsghdr(cmsg, msg) { + int addr_type; + + if (!CMSG_OK(msg, cmsg)) { + err = -EINVAL; + goto exit_f; + } + + if (cmsg->cmsg_level == SOL_SOCKET) { + err = __sock_cmsg_send(sk, msg, cmsg, &ipc6->sockc); + if (err) + return err; + continue; + } + + if (cmsg->cmsg_level != SOL_IPV6) + continue; + + switch (cmsg->cmsg_type) { + case IPV6_PKTINFO: + case IPV6_2292PKTINFO: + { + struct net_device *dev = NULL; + int src_idx; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct in6_pktinfo))) { + err = -EINVAL; + goto exit_f; + } + + src_info = (struct in6_pktinfo *)CMSG_DATA(cmsg); + src_idx = src_info->ipi6_ifindex; + + if (src_idx) { + if (fl6->flowi6_oif && + src_idx != fl6->flowi6_oif && + (READ_ONCE(sk->sk_bound_dev_if) != fl6->flowi6_oif || + !sk_dev_equal_l3scope(sk, src_idx))) + return -EINVAL; + fl6->flowi6_oif = src_idx; + } + + addr_type = __ipv6_addr_type(&src_info->ipi6_addr); + + rcu_read_lock(); + if (fl6->flowi6_oif) { + dev = dev_get_by_index_rcu(net, fl6->flowi6_oif); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } + } else if (addr_type & IPV6_ADDR_LINKLOCAL) { + rcu_read_unlock(); + return -EINVAL; + } + + if (addr_type != IPV6_ADDR_ANY) { + int strict = __ipv6_addr_src_scope(addr_type) <= IPV6_ADDR_SCOPE_LINKLOCAL; + if (!ipv6_can_nonlocal_bind(net, inet_sk(sk)) && + !ipv6_chk_addr_and_flags(net, &src_info->ipi6_addr, + dev, !strict, 0, + IFA_F_TENTATIVE) && + !ipv6_chk_acast_addr_src(net, dev, + &src_info->ipi6_addr)) + err = -EINVAL; + else + fl6->saddr = src_info->ipi6_addr; + } + + rcu_read_unlock(); + + if (err) + goto exit_f; + + break; + } + + case IPV6_FLOWINFO: + if (cmsg->cmsg_len < CMSG_LEN(4)) { + err = -EINVAL; + goto exit_f; + } + + if (fl6->flowlabel&IPV6_FLOWINFO_MASK) { + if ((fl6->flowlabel^*(__be32 *)CMSG_DATA(cmsg))&~IPV6_FLOWINFO_MASK) { + err = -EINVAL; + goto exit_f; + } + } + fl6->flowlabel = IPV6_FLOWINFO_MASK & *(__be32 *)CMSG_DATA(cmsg); + break; + + case IPV6_2292HOPOPTS: + case IPV6_HOPOPTS: + if (opt->hopopt || cmsg->cmsg_len < CMSG_LEN(sizeof(struct ipv6_opt_hdr))) { + err = -EINVAL; + goto exit_f; + } + + hdr = (struct ipv6_opt_hdr *)CMSG_DATA(cmsg); + len = ((hdr->hdrlen + 1) << 3); + if (cmsg->cmsg_len < CMSG_LEN(len)) { + err = -EINVAL; + goto exit_f; + } + if (!ns_capable(net->user_ns, CAP_NET_RAW)) { + err = -EPERM; + goto exit_f; + } + opt->opt_nflen += len; + opt->hopopt = hdr; + break; + + case IPV6_2292DSTOPTS: + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct ipv6_opt_hdr))) { + err = -EINVAL; + goto exit_f; + } + + hdr = (struct 
ipv6_opt_hdr *)CMSG_DATA(cmsg); + len = ((hdr->hdrlen + 1) << 3); + if (cmsg->cmsg_len < CMSG_LEN(len)) { + err = -EINVAL; + goto exit_f; + } + if (!ns_capable(net->user_ns, CAP_NET_RAW)) { + err = -EPERM; + goto exit_f; + } + if (opt->dst1opt) { + err = -EINVAL; + goto exit_f; + } + opt->opt_flen += len; + opt->dst1opt = hdr; + break; + + case IPV6_DSTOPTS: + case IPV6_RTHDRDSTOPTS: + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct ipv6_opt_hdr))) { + err = -EINVAL; + goto exit_f; + } + + hdr = (struct ipv6_opt_hdr *)CMSG_DATA(cmsg); + len = ((hdr->hdrlen + 1) << 3); + if (cmsg->cmsg_len < CMSG_LEN(len)) { + err = -EINVAL; + goto exit_f; + } + if (!ns_capable(net->user_ns, CAP_NET_RAW)) { + err = -EPERM; + goto exit_f; + } + if (cmsg->cmsg_type == IPV6_DSTOPTS) { + opt->opt_flen += len; + opt->dst1opt = hdr; + } else { + opt->opt_nflen += len; + opt->dst0opt = hdr; + } + break; + + case IPV6_2292RTHDR: + case IPV6_RTHDR: + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct ipv6_rt_hdr))) { + err = -EINVAL; + goto exit_f; + } + + rthdr = (struct ipv6_rt_hdr *)CMSG_DATA(cmsg); + + switch (rthdr->type) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_SRCRT_TYPE_2: + if (rthdr->hdrlen != 2 || + rthdr->segments_left != 1) { + err = -EINVAL; + goto exit_f; + } + break; +#endif + default: + err = -EINVAL; + goto exit_f; + } + + len = ((rthdr->hdrlen + 1) << 3); + + if (cmsg->cmsg_len < CMSG_LEN(len)) { + err = -EINVAL; + goto exit_f; + } + + /* segments left must also match */ + if ((rthdr->hdrlen >> 1) != rthdr->segments_left) { + err = -EINVAL; + goto exit_f; + } + + opt->opt_nflen += len; + opt->srcrt = rthdr; + + if (cmsg->cmsg_type == IPV6_2292RTHDR && opt->dst1opt) { + int dsthdrlen = ((opt->dst1opt->hdrlen+1)<<3); + + opt->opt_nflen += dsthdrlen; + opt->dst0opt = opt->dst1opt; + opt->dst1opt = NULL; + opt->opt_flen -= dsthdrlen; + } + + break; + + case IPV6_2292HOPLIMIT: + case IPV6_HOPLIMIT: + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) { + err = -EINVAL; + goto exit_f; + } + + ipc6->hlimit = *(int *)CMSG_DATA(cmsg); + if (ipc6->hlimit < -1 || ipc6->hlimit > 0xff) { + err = -EINVAL; + goto exit_f; + } + + break; + + case IPV6_TCLASS: + { + int tc; + + err = -EINVAL; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + goto exit_f; + + tc = *(int *)CMSG_DATA(cmsg); + if (tc < -1 || tc > 0xff) + goto exit_f; + + err = 0; + ipc6->tclass = tc; + + break; + } + + case IPV6_DONTFRAG: + { + int df; + + err = -EINVAL; + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + goto exit_f; + + df = *(int *)CMSG_DATA(cmsg); + if (df < 0 || df > 1) + goto exit_f; + + err = 0; + ipc6->dontfrag = df; + + break; + } + default: + net_dbg_ratelimited("invalid cmsg type: %d\n", + cmsg->cmsg_type); + err = -EINVAL; + goto exit_f; + } + } + +exit_f: + return err; +} +EXPORT_SYMBOL_GPL(ip6_datagram_send_ctl); + +void __ip6_dgram_sock_seq_show(struct seq_file *seq, struct sock *sp, + __u16 srcp, __u16 destp, int rqueue, int bucket) +{ + const struct in6_addr *dest, *src; + + dest = &sp->sk_v6_daddr; + src = &sp->sk_v6_rcv_saddr; + seq_printf(seq, + "%5d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X " + "%02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u\n", + bucket, + src->s6_addr32[0], src->s6_addr32[1], + src->s6_addr32[2], src->s6_addr32[3], srcp, + dest->s6_addr32[0], dest->s6_addr32[1], + dest->s6_addr32[2], dest->s6_addr32[3], destp, + sp->sk_state, + sk_wmem_alloc_get(sp), + rqueue, + 0, 0L, 0, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sp)), + 0, + sock_i_ino(sp), + refcount_read(&sp->sk_refcnt), sp, + 
atomic_read(&sp->sk_drops)); +} diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c new file mode 100644 index 000000000..c2dcb5c61 --- /dev/null +++ b/net/ipv6/esp6.c @@ -0,0 +1,1305 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C)2002 USAGI/WIDE Project + * + * Authors + * + * Mitsuru KANDA @USAGI : IPv6 Support + * Kazunori MIYAZAWA @USAGI : + * Kunihiro Ishiguro + * + * This file is derived from net/ipv4/esp.c + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct esp_skb_cb { + struct xfrm_skb_cb xfrm; + void *tmp; +}; + +struct esp_output_extra { + __be32 seqhi; + u32 esphoff; +}; + +#define ESP_SKB_CB(__skb) ((struct esp_skb_cb *)&((__skb)->cb[0])) + +/* + * Allocate an AEAD request structure with extra space for SG and IV. + * + * For alignment considerations the upper 32 bits of the sequence number are + * placed at the front, if present. Followed by the IV, the request and finally + * the SG list. + * + * TODO: Use spare space in skb for this where possible. + */ +static void *esp_alloc_tmp(struct crypto_aead *aead, int nfrags, int seqihlen) +{ + unsigned int len; + + len = seqihlen; + + len += crypto_aead_ivsize(aead); + + if (len) { + len += crypto_aead_alignmask(aead) & + ~(crypto_tfm_ctx_alignment() - 1); + len = ALIGN(len, crypto_tfm_ctx_alignment()); + } + + len += sizeof(struct aead_request) + crypto_aead_reqsize(aead); + len = ALIGN(len, __alignof__(struct scatterlist)); + + len += sizeof(struct scatterlist) * nfrags; + + return kmalloc(len, GFP_ATOMIC); +} + +static inline void *esp_tmp_extra(void *tmp) +{ + return PTR_ALIGN(tmp, __alignof__(struct esp_output_extra)); +} + +static inline u8 *esp_tmp_iv(struct crypto_aead *aead, void *tmp, int seqhilen) +{ + return crypto_aead_ivsize(aead) ? + PTR_ALIGN((u8 *)tmp + seqhilen, + crypto_aead_alignmask(aead) + 1) : tmp + seqhilen; +} + +static inline struct aead_request *esp_tmp_req(struct crypto_aead *aead, u8 *iv) +{ + struct aead_request *req; + + req = (void *)PTR_ALIGN(iv + crypto_aead_ivsize(aead), + crypto_tfm_ctx_alignment()); + aead_request_set_tfm(req, aead); + return req; +} + +static inline struct scatterlist *esp_req_sg(struct crypto_aead *aead, + struct aead_request *req) +{ + return (void *)ALIGN((unsigned long)(req + 1) + + crypto_aead_reqsize(aead), + __alignof__(struct scatterlist)); +} + +static void esp_ssg_unref(struct xfrm_state *x, void *tmp) +{ + struct crypto_aead *aead = x->data; + int extralen = 0; + u8 *iv; + struct aead_request *req; + struct scatterlist *sg; + + if (x->props.flags & XFRM_STATE_ESN) + extralen += sizeof(struct esp_output_extra); + + iv = esp_tmp_iv(aead, tmp, extralen); + req = esp_tmp_req(aead, iv); + + /* Unref skb_frag_pages in the src scatterlist if necessary. + * Skip the first sg which comes from skb->data. 
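[Illustration, not part of the patch.] The esp_alloc_tmp(), esp_tmp_extra(), esp_tmp_iv(), esp_tmp_req() and esp_req_sg() helpers above all carve regions out of a single scratch allocation. A rough sketch of that layout and of how the callers are expected to chain the helpers (alignment padding omitted, fragment only):

/*
 *   tmp -> +---------------------------------+
 *          | ESN high bits / output extra    |  esp_tmp_extra(tmp)
 *          +---------------------------------+
 *          | IV  (crypto_aead_ivsize bytes)  |  esp_tmp_iv(aead, tmp, extralen)
 *          +---------------------------------+
 *          | struct aead_request + reqsize   |  esp_tmp_req(aead, iv)
 *          +---------------------------------+
 *          | struct scatterlist[nfrags]      |  esp_req_sg(aead, req)
 *          +---------------------------------+
 */
void *tmp = esp_alloc_tmp(aead, nfrags, extralen);   /* one allocation ...        */
struct esp_output_extra *extra = esp_tmp_extra(tmp); /* ... shared by all of      */
u8 *iv = esp_tmp_iv(aead, tmp, extralen);            /* these derived pointers    */
struct aead_request *req = esp_tmp_req(aead, iv);
struct scatterlist *sg = esp_req_sg(aead, req);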
+ */ + if (req->src != req->dst) + for (sg = sg_next(req->src); sg; sg = sg_next(sg)) + put_page(sg_page(sg)); +} + +#ifdef CONFIG_INET6_ESPINTCP +struct esp_tcp_sk { + struct sock *sk; + struct rcu_head rcu; +}; + +static void esp_free_tcp_sk(struct rcu_head *head) +{ + struct esp_tcp_sk *esk = container_of(head, struct esp_tcp_sk, rcu); + + sock_put(esk->sk); + kfree(esk); +} + +static struct sock *esp6_find_tcp_sk(struct xfrm_state *x) +{ + struct xfrm_encap_tmpl *encap = x->encap; + struct net *net = xs_net(x); + struct esp_tcp_sk *esk; + __be16 sport, dport; + struct sock *nsk; + struct sock *sk; + + sk = rcu_dereference(x->encap_sk); + if (sk && sk->sk_state == TCP_ESTABLISHED) + return sk; + + spin_lock_bh(&x->lock); + sport = encap->encap_sport; + dport = encap->encap_dport; + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (sk && sk == nsk) { + esk = kmalloc(sizeof(*esk), GFP_ATOMIC); + if (!esk) { + spin_unlock_bh(&x->lock); + return ERR_PTR(-ENOMEM); + } + RCU_INIT_POINTER(x->encap_sk, NULL); + esk->sk = sk; + call_rcu(&esk->rcu, esp_free_tcp_sk); + } + spin_unlock_bh(&x->lock); + + sk = __inet6_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, &x->id.daddr.in6, + dport, &x->props.saddr.in6, ntohs(sport), 0, 0); + if (!sk) + return ERR_PTR(-ENOENT); + + if (!tcp_is_ulp_esp(sk)) { + sock_put(sk); + return ERR_PTR(-EINVAL); + } + + spin_lock_bh(&x->lock); + nsk = rcu_dereference_protected(x->encap_sk, + lockdep_is_held(&x->lock)); + if (encap->encap_sport != sport || + encap->encap_dport != dport) { + sock_put(sk); + sk = nsk ?: ERR_PTR(-EREMCHG); + } else if (sk == nsk) { + sock_put(sk); + } else { + rcu_assign_pointer(x->encap_sk, sk); + } + spin_unlock_bh(&x->lock); + + return sk; +} + +static int esp_output_tcp_finish(struct xfrm_state *x, struct sk_buff *skb) +{ + struct sock *sk; + int err; + + rcu_read_lock(); + + sk = esp6_find_tcp_sk(x); + err = PTR_ERR_OR_ZERO(sk); + if (err) + goto out; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) + err = espintcp_queue_out(sk, skb); + else + err = espintcp_push_skb(sk, skb); + bh_unlock_sock(sk); + +out: + rcu_read_unlock(); + return err; +} + +static int esp_output_tcp_encap_cb(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct xfrm_state *x = dst->xfrm; + + return esp_output_tcp_finish(x, skb); +} + +static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + + local_bh_disable(); + err = xfrm_trans_queue_net(xs_net(x), skb, esp_output_tcp_encap_cb); + local_bh_enable(); + + /* EINPROGRESS just happens to do the right thing. It + * actually means that the skb has been consumed and + * isn't coming back. 
+ */ + return err ?: -EINPROGRESS; +} +#else +static int esp_output_tail_tcp(struct xfrm_state *x, struct sk_buff *skb) +{ + kfree_skb(skb); + + return -EOPNOTSUPP; +} +#endif + +static void esp_output_encap_csum(struct sk_buff *skb) +{ + /* UDP encap with IPv6 requires a valid checksum */ + if (*skb_mac_header(skb) == IPPROTO_UDP) { + struct udphdr *uh = udp_hdr(skb); + struct ipv6hdr *ip6h = ipv6_hdr(skb); + int len = ntohs(uh->len); + unsigned int offset = skb_transport_offset(skb); + __wsum csum = skb_checksum(skb, offset, skb->len - offset, 0); + + uh->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, + len, IPPROTO_UDP, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } +} + +static void esp_output_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + struct xfrm_offload *xo = xfrm_offload(skb); + void *tmp; + struct xfrm_state *x; + + if (xo && (xo->flags & XFRM_DEV_RESUME)) { + struct sec_path *sp = skb_sec_path(skb); + + x = sp->xvec[sp->len - 1]; + } else { + x = skb_dst(skb)->xfrm; + } + + tmp = ESP_SKB_CB(skb)->tmp; + esp_ssg_unref(x, tmp); + kfree(tmp); + + esp_output_encap_csum(skb); + + if (xo && (xo->flags & XFRM_DEV_RESUME)) { + if (err) { + XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR); + kfree_skb(skb); + return; + } + + skb_push(skb, skb->data - skb_mac_header(skb)); + secpath_reset(skb); + xfrm_dev_resume(skb); + } else { + if (!err && + x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) + esp_output_tail_tcp(x, skb); + else + xfrm_output_resume(skb->sk, skb, err); + } +} + +/* Move ESP header back into place. */ +static void esp_restore_header(struct sk_buff *skb, unsigned int offset) +{ + struct ip_esp_hdr *esph = (void *)(skb->data + offset); + void *tmp = ESP_SKB_CB(skb)->tmp; + __be32 *seqhi = esp_tmp_extra(tmp); + + esph->seq_no = esph->spi; + esph->spi = *seqhi; +} + +static void esp_output_restore_header(struct sk_buff *skb) +{ + void *tmp = ESP_SKB_CB(skb)->tmp; + struct esp_output_extra *extra = esp_tmp_extra(tmp); + + esp_restore_header(skb, skb_transport_offset(skb) + extra->esphoff - + sizeof(__be32)); +} + +static struct ip_esp_hdr *esp_output_set_esn(struct sk_buff *skb, + struct xfrm_state *x, + struct ip_esp_hdr *esph, + struct esp_output_extra *extra) +{ + /* For ESN we move the header forward by 4 bytes to + * accommodate the high bits. We will move it back after + * encryption. 
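[Illustration, not part of the patch.] Roughly speaking, the esp_output_set_esn()/esp_restore_header() pair rearranges the header in place so that the AEAD authenticates the 64-bit extended sequence number even though only its low 32 bits travel on the wire; a sketch of the two layouts:

/*
 *   on the wire:                 [ SPI ][ seq_lo ][ IV ][ payload ... ]
 *   while the request runs:  [ SPI ][ seq_hi ][ seq_lo ][ IV ][ payload ... ]
 *                             ^ the header temporarily starts 4 bytes earlier,
 *                               so the associated data seen by the AEAD is
 *                               SPI || seq_hi || seq_lo
 *
 * The 4 bytes overwritten in front of the real header are stashed in
 * esp_output_extra.seqhi and written back by esp_restore_header() once the
 * crypto request completes (see esp_output_done_esn()).
 */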
+ */ + if ((x->props.flags & XFRM_STATE_ESN)) { + __u32 seqhi; + struct xfrm_offload *xo = xfrm_offload(skb); + + if (xo) + seqhi = xo->seq.hi; + else + seqhi = XFRM_SKB_CB(skb)->seq.output.hi; + + extra->esphoff = (unsigned char *)esph - + skb_transport_header(skb); + esph = (struct ip_esp_hdr *)((unsigned char *)esph - 4); + extra->seqhi = esph->spi; + esph->seq_no = htonl(seqhi); + } + + esph->spi = x->id.spi; + + return esph; +} + +static void esp_output_done_esn(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + esp_output_restore_header(skb); + esp_output_done(base, err); +} + +static struct ip_esp_hdr *esp6_output_udp_encap(struct sk_buff *skb, + int encap_type, + struct esp_info *esp, + __be16 sport, + __be16 dport) +{ + struct udphdr *uh; + __be32 *udpdata32; + unsigned int len; + + len = skb->len + esp->tailen - skb_transport_offset(skb); + if (len > U16_MAX) + return ERR_PTR(-EMSGSIZE); + + uh = (struct udphdr *)esp->esph; + uh->source = sport; + uh->dest = dport; + uh->len = htons(len); + uh->check = 0; + + *skb_mac_header(skb) = IPPROTO_UDP; + + if (encap_type == UDP_ENCAP_ESPINUDP_NON_IKE) { + udpdata32 = (__be32 *)(uh + 1); + udpdata32[0] = udpdata32[1] = 0; + return (struct ip_esp_hdr *)(udpdata32 + 2); + } + + return (struct ip_esp_hdr *)(uh + 1); +} + +#ifdef CONFIG_INET6_ESPINTCP +static struct ip_esp_hdr *esp6_output_tcp_encap(struct xfrm_state *x, + struct sk_buff *skb, + struct esp_info *esp) +{ + __be16 *lenp = (void *)esp->esph; + struct ip_esp_hdr *esph; + unsigned int len; + struct sock *sk; + + len = skb->len + esp->tailen - skb_transport_offset(skb); + if (len > IP_MAX_MTU) + return ERR_PTR(-EMSGSIZE); + + rcu_read_lock(); + sk = esp6_find_tcp_sk(x); + rcu_read_unlock(); + + if (IS_ERR(sk)) + return ERR_CAST(sk); + + *lenp = htons(len); + esph = (struct ip_esp_hdr *)(lenp + 1); + + return esph; +} +#else +static struct ip_esp_hdr *esp6_output_tcp_encap(struct xfrm_state *x, + struct sk_buff *skb, + struct esp_info *esp) +{ + return ERR_PTR(-EOPNOTSUPP); +} +#endif + +static int esp6_output_encap(struct xfrm_state *x, struct sk_buff *skb, + struct esp_info *esp) +{ + struct xfrm_encap_tmpl *encap = x->encap; + struct ip_esp_hdr *esph; + __be16 sport, dport; + int encap_type; + + spin_lock_bh(&x->lock); + sport = encap->encap_sport; + dport = encap->encap_dport; + encap_type = encap->encap_type; + spin_unlock_bh(&x->lock); + + switch (encap_type) { + default: + case UDP_ENCAP_ESPINUDP: + case UDP_ENCAP_ESPINUDP_NON_IKE: + esph = esp6_output_udp_encap(skb, encap_type, esp, sport, dport); + break; + case TCP_ENCAP_ESPINTCP: + esph = esp6_output_tcp_encap(x, skb, esp); + break; + } + + if (IS_ERR(esph)) + return PTR_ERR(esph); + + esp->esph = esph; + + return 0; +} + +int esp6_output_head(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp) +{ + u8 *tail; + int nfrags; + int esph_offset; + struct page *page; + struct sk_buff *trailer; + int tailen = esp->tailen; + + if (x->encap) { + int err = esp6_output_encap(x, skb, esp); + + if (err < 0) + return err; + } + + if (ALIGN(tailen, L1_CACHE_BYTES) > PAGE_SIZE || + ALIGN(skb->data_len, L1_CACHE_BYTES) > PAGE_SIZE) + goto cow; + + if (!skb_cloned(skb)) { + if (tailen <= skb_tailroom(skb)) { + nfrags = 1; + trailer = skb; + tail = skb_tail_pointer(trailer); + + goto skip_cow; + } else if ((skb_shinfo(skb)->nr_frags < MAX_SKB_FRAGS) + && !skb_has_frag_list(skb)) { + int allocsize; + struct sock *sk = skb->sk; + struct page_frag *pfrag = &x->xfrag; + + esp->inplace = false; + 
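[Illustration, not part of the patch.] For reference, the encapsulation variants dispatched by esp6_output_encap() place the following headers in front of the ESP header, based on esp6_output_udp_encap() and esp6_output_tcp_encap() above:

/*
 *   UDP_ENCAP_ESPINUDP:           [ udphdr ][ ESP header ... ]
 *   UDP_ENCAP_ESPINUDP_NON_IKE:   [ udphdr ][ 8 zero bytes ][ ESP header ... ]
 *   TCP_ENCAP_ESPINTCP:           [ 2-byte length ][ ESP header ... ]
 *
 * For the two UDP variants *skb_mac_header(skb) is switched to IPPROTO_UDP so
 * that esp_output_encap_csum() can later compute the UDP checksum, which is
 * mandatory over IPv6.  For ESP-in-TCP only the length prefix is written here;
 * the outer TCP header comes from the espintcp socket on the transmit path.
 */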
+ allocsize = ALIGN(tailen, L1_CACHE_BYTES); + + spin_lock_bh(&x->lock); + + if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) { + spin_unlock_bh(&x->lock); + goto cow; + } + + page = pfrag->page; + get_page(page); + + tail = page_address(page) + pfrag->offset; + + esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto); + + nfrags = skb_shinfo(skb)->nr_frags; + + __skb_fill_page_desc(skb, nfrags, page, pfrag->offset, + tailen); + skb_shinfo(skb)->nr_frags = ++nfrags; + + pfrag->offset = pfrag->offset + allocsize; + + spin_unlock_bh(&x->lock); + + nfrags++; + + skb->len += tailen; + skb->data_len += tailen; + skb->truesize += tailen; + if (sk && sk_fullsock(sk)) + refcount_add(tailen, &sk->sk_wmem_alloc); + + goto out; + } + } + +cow: + esph_offset = (unsigned char *)esp->esph - skb_transport_header(skb); + + nfrags = skb_cow_data(skb, tailen, &trailer); + if (nfrags < 0) + goto out; + tail = skb_tail_pointer(trailer); + esp->esph = (struct ip_esp_hdr *)(skb_transport_header(skb) + esph_offset); + +skip_cow: + esp_output_fill_trailer(tail, esp->tfclen, esp->plen, esp->proto); + pskb_put(skb, trailer, tailen); + +out: + return nfrags; +} +EXPORT_SYMBOL_GPL(esp6_output_head); + +int esp6_output_tail(struct xfrm_state *x, struct sk_buff *skb, struct esp_info *esp) +{ + u8 *iv; + int alen; + void *tmp; + int ivlen; + int assoclen; + int extralen; + struct page *page; + struct ip_esp_hdr *esph; + struct aead_request *req; + struct crypto_aead *aead; + struct scatterlist *sg, *dsg; + struct esp_output_extra *extra; + int err = -ENOMEM; + + assoclen = sizeof(struct ip_esp_hdr); + extralen = 0; + + if (x->props.flags & XFRM_STATE_ESN) { + extralen += sizeof(*extra); + assoclen += sizeof(__be32); + } + + aead = x->data; + alen = crypto_aead_authsize(aead); + ivlen = crypto_aead_ivsize(aead); + + tmp = esp_alloc_tmp(aead, esp->nfrags + 2, extralen); + if (!tmp) + goto error; + + extra = esp_tmp_extra(tmp); + iv = esp_tmp_iv(aead, tmp, extralen); + req = esp_tmp_req(aead, iv); + sg = esp_req_sg(aead, req); + + if (esp->inplace) + dsg = sg; + else + dsg = &sg[esp->nfrags]; + + esph = esp_output_set_esn(skb, x, esp->esph, extra); + esp->esph = esph; + + sg_init_table(sg, esp->nfrags); + err = skb_to_sgvec(skb, sg, + (unsigned char *)esph - skb->data, + assoclen + ivlen + esp->clen + alen); + if (unlikely(err < 0)) + goto error_free; + + if (!esp->inplace) { + int allocsize; + struct page_frag *pfrag = &x->xfrag; + + allocsize = ALIGN(skb->data_len, L1_CACHE_BYTES); + + spin_lock_bh(&x->lock); + if (unlikely(!skb_page_frag_refill(allocsize, pfrag, GFP_ATOMIC))) { + spin_unlock_bh(&x->lock); + goto error_free; + } + + skb_shinfo(skb)->nr_frags = 1; + + page = pfrag->page; + get_page(page); + /* replace page frags in skb with new page */ + __skb_fill_page_desc(skb, 0, page, pfrag->offset, skb->data_len); + pfrag->offset = pfrag->offset + allocsize; + spin_unlock_bh(&x->lock); + + sg_init_table(dsg, skb_shinfo(skb)->nr_frags + 1); + err = skb_to_sgvec(skb, dsg, + (unsigned char *)esph - skb->data, + assoclen + ivlen + esp->clen + alen); + if (unlikely(err < 0)) + goto error_free; + } + + if ((x->props.flags & XFRM_STATE_ESN)) + aead_request_set_callback(req, 0, esp_output_done_esn, skb); + else + aead_request_set_callback(req, 0, esp_output_done, skb); + + aead_request_set_crypt(req, sg, dsg, ivlen + esp->clen, iv); + aead_request_set_ad(req, assoclen); + + memset(iv, 0, ivlen); + memcpy(iv + ivlen - min(ivlen, 8), (u8 *)&esp->seqno + 8 - min(ivlen, 8), + min(ivlen, 8)); + + 
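[Illustration, not part of the patch.] The tailen bytes that esp6_output_head() makes room for hold the ESP trailer written by esp_output_fill_trailer() (a helper defined outside this file). A rough picture, plus a worked example with made-up numbers, assuming the default self-describing padding:

/*
 *   [ payload ][ TFC pad ][ 1 2 3 ... ][ pad length ][ next header ][ ICV ]
 *                tfclen     plen - 2      1 byte        1 byte        alen
 *
 * Worked example (illustrative numbers): a 53-byte payload, no TFC padding,
 * a 16-byte cipher block and a 16-byte ICV give
 *
 *   clen   = ALIGN(53 + 2, 16) = 64  // padded length incl. the 2 trailer bytes
 *   plen   = 64 - 53           = 11  // 9 pad bytes + pad-length + next-header
 *   tailen = 0 + 11 + 16       = 27  // what esp6_output_head() must append
 */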
ESP_SKB_CB(skb)->tmp = tmp; + err = crypto_aead_encrypt(req); + + switch (err) { + case -EINPROGRESS: + goto error; + + case -ENOSPC: + err = NET_XMIT_DROP; + break; + + case 0: + if ((x->props.flags & XFRM_STATE_ESN)) + esp_output_restore_header(skb); + esp_output_encap_csum(skb); + } + + if (sg != dsg) + esp_ssg_unref(x, tmp); + + if (!err && x->encap && x->encap->encap_type == TCP_ENCAP_ESPINTCP) + err = esp_output_tail_tcp(x, skb); + +error_free: + kfree(tmp); +error: + return err; +} +EXPORT_SYMBOL_GPL(esp6_output_tail); + +static int esp6_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int alen; + int blksize; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + struct esp_info esp; + + esp.inplace = true; + + esp.proto = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_ESP; + + /* skb is pure payload to encrypt */ + + aead = x->data; + alen = crypto_aead_authsize(aead); + + esp.tfclen = 0; + if (x->tfcpad) { + struct xfrm_dst *dst = (struct xfrm_dst *)skb_dst(skb); + u32 padto; + + padto = min(x->tfcpad, xfrm_state_mtu(x, dst->child_mtu_cached)); + if (skb->len < padto) + esp.tfclen = padto - skb->len; + } + blksize = ALIGN(crypto_aead_blocksize(aead), 4); + esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize); + esp.plen = esp.clen - skb->len - esp.tfclen; + esp.tailen = esp.tfclen + esp.plen + alen; + + esp.esph = ip_esp_hdr(skb); + + esp.nfrags = esp6_output_head(x, skb, &esp); + if (esp.nfrags < 0) + return esp.nfrags; + + esph = esp.esph; + esph->spi = x->id.spi; + + esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + esp.seqno = cpu_to_be64(XFRM_SKB_CB(skb)->seq.output.low + + ((u64)XFRM_SKB_CB(skb)->seq.output.hi << 32)); + + skb_push(skb, -skb_network_offset(skb)); + + return esp6_output_tail(x, skb, &esp); +} + +static inline int esp_remove_trailer(struct sk_buff *skb) +{ + struct xfrm_state *x = xfrm_input_state(skb); + struct crypto_aead *aead = x->data; + int alen, hlen, elen; + int padlen, trimlen; + __wsum csumdiff; + u8 nexthdr[2]; + int ret; + + alen = crypto_aead_authsize(aead); + hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead); + elen = skb->len - hlen; + + ret = skb_copy_bits(skb, skb->len - alen - 2, nexthdr, 2); + BUG_ON(ret); + + ret = -EINVAL; + padlen = nexthdr[0]; + if (padlen + 2 + alen >= elen) { + net_dbg_ratelimited("ipsec esp packet is garbage padlen=%d, elen=%d\n", + padlen + 2, elen - alen); + goto out; + } + + trimlen = alen + padlen + 2; + if (skb->ip_summed == CHECKSUM_COMPLETE) { + csumdiff = skb_checksum(skb, skb->len - trimlen, trimlen, 0); + skb->csum = csum_block_sub(skb->csum, csumdiff, + skb->len - trimlen); + } + ret = pskb_trim(skb, skb->len - trimlen); + if (unlikely(ret)) + return ret; + + ret = nexthdr[1]; + +out: + return ret; +} + +int esp6_input_done2(struct sk_buff *skb, int err) +{ + struct xfrm_state *x = xfrm_input_state(skb); + struct xfrm_offload *xo = xfrm_offload(skb); + struct crypto_aead *aead = x->data; + int hlen = sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead); + int hdr_len = skb_network_header_len(skb); + + if (!xo || !(xo->flags & CRYPTO_DONE)) + kfree(ESP_SKB_CB(skb)->tmp); + + if (unlikely(err)) + goto out; + + err = esp_remove_trailer(skb); + if (unlikely(err < 0)) + goto out; + + if (x->encap) { + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + int offset = skb_network_offset(skb) + sizeof(*ip6h); + struct xfrm_encap_tmpl *encap = x->encap; + u8 nexthdr = ip6h->nexthdr; + __be16 frag_off, source; + struct udphdr *uh; + struct tcphdr *th; + + offset = ipv6_skip_exthdr(skb, 
offset, &nexthdr, &frag_off); + if (offset == -1) { + err = -EINVAL; + goto out; + } + + uh = (void *)(skb->data + offset); + th = (void *)(skb->data + offset); + hdr_len += offset; + + switch (x->encap->encap_type) { + case TCP_ENCAP_ESPINTCP: + source = th->source; + break; + case UDP_ENCAP_ESPINUDP: + case UDP_ENCAP_ESPINUDP_NON_IKE: + source = uh->source; + break; + default: + WARN_ON_ONCE(1); + err = -EINVAL; + goto out; + } + + /* + * 1) if the NAT-T peer's IP or port changed then + * advertize the change to the keying daemon. + * This is an inbound SA, so just compare + * SRC ports. + */ + if (!ipv6_addr_equal(&ip6h->saddr, &x->props.saddr.in6) || + source != encap->encap_sport) { + xfrm_address_t ipaddr; + + memcpy(&ipaddr.a6, &ip6h->saddr.s6_addr, sizeof(ipaddr.a6)); + km_new_mapping(x, &ipaddr, source); + + /* XXX: perhaps add an extra + * policy check here, to see + * if we should allow or + * reject a packet from a + * different source + * address/port. + */ + } + + /* + * 2) ignore UDP/TCP checksums in case + * of NAT-T in Transport Mode, or + * perform other post-processing fixes + * as per draft-ietf-ipsec-udp-encaps-06, + * section 3.1.2 + */ + if (x->props.mode == XFRM_MODE_TRANSPORT) + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + skb_postpull_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + skb_pull_rcsum(skb, hlen); + if (x->props.mode == XFRM_MODE_TUNNEL) + skb_reset_transport_header(skb); + else + skb_set_transport_header(skb, -hdr_len); + + /* RFC4303: Drop dummy packets without any error */ + if (err == IPPROTO_NONE) + err = -EINVAL; + +out: + return err; +} +EXPORT_SYMBOL_GPL(esp6_input_done2); + +static void esp_input_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + xfrm_input_resume(skb, esp6_input_done2(skb, err)); +} + +static void esp_input_restore_header(struct sk_buff *skb) +{ + esp_restore_header(skb, 0); + __skb_pull(skb, 4); +} + +static void esp_input_set_header(struct sk_buff *skb, __be32 *seqhi) +{ + struct xfrm_state *x = xfrm_input_state(skb); + + /* For ESN we move the header forward by 4 bytes to + * accommodate the high bits. We will move it back after + * decryption. 
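[Illustration, not part of the patch.] As on the output side, only the low 32 bits of an extended sequence number are carried in the packet; esp_input_set_header() temporarily splices the locally tracked high bits into the header so the ICV check covers them, and esp_input_restore_header() undoes it. Conceptually, with hypothetical host-order variables:

/* The receiver reconstructs the full 64-bit ESN from its replay state: */
u64 full_seq = ((u64)seq_hi << 32) | seq_lo;  /* seq_hi is never on the wire */

/* For the ICV computation only, the header is presented to the AEAD as
 * SPI || seq_hi || seq_lo, mirroring the output-side shuffle above.
 */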
+ */ + if ((x->props.flags & XFRM_STATE_ESN)) { + struct ip_esp_hdr *esph = skb_push(skb, 4); + + *seqhi = esph->spi; + esph->spi = esph->seq_no; + esph->seq_no = XFRM_SKB_CB(skb)->seq.input.hi; + } +} + +static void esp_input_done_esn(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + + esp_input_restore_header(skb); + esp_input_done(base, err); +} + +static int esp6_input(struct xfrm_state *x, struct sk_buff *skb) +{ + struct crypto_aead *aead = x->data; + struct aead_request *req; + struct sk_buff *trailer; + int ivlen = crypto_aead_ivsize(aead); + int elen = skb->len - sizeof(struct ip_esp_hdr) - ivlen; + int nfrags; + int assoclen; + int seqhilen; + int ret = 0; + void *tmp; + __be32 *seqhi; + u8 *iv; + struct scatterlist *sg; + + if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + ivlen)) { + ret = -EINVAL; + goto out; + } + + if (elen <= 0) { + ret = -EINVAL; + goto out; + } + + assoclen = sizeof(struct ip_esp_hdr); + seqhilen = 0; + + if (x->props.flags & XFRM_STATE_ESN) { + seqhilen += sizeof(__be32); + assoclen += seqhilen; + } + + if (!skb_cloned(skb)) { + if (!skb_is_nonlinear(skb)) { + nfrags = 1; + + goto skip_cow; + } else if (!skb_has_frag_list(skb)) { + nfrags = skb_shinfo(skb)->nr_frags; + nfrags++; + + goto skip_cow; + } + } + + nfrags = skb_cow_data(skb, 0, &trailer); + if (nfrags < 0) { + ret = -EINVAL; + goto out; + } + +skip_cow: + ret = -ENOMEM; + tmp = esp_alloc_tmp(aead, nfrags, seqhilen); + if (!tmp) + goto out; + + ESP_SKB_CB(skb)->tmp = tmp; + seqhi = esp_tmp_extra(tmp); + iv = esp_tmp_iv(aead, tmp, seqhilen); + req = esp_tmp_req(aead, iv); + sg = esp_req_sg(aead, req); + + esp_input_set_header(skb, seqhi); + + sg_init_table(sg, nfrags); + ret = skb_to_sgvec(skb, sg, 0, skb->len); + if (unlikely(ret < 0)) { + kfree(tmp); + goto out; + } + + skb->ip_summed = CHECKSUM_NONE; + + if ((x->props.flags & XFRM_STATE_ESN)) + aead_request_set_callback(req, 0, esp_input_done_esn, skb); + else + aead_request_set_callback(req, 0, esp_input_done, skb); + + aead_request_set_crypt(req, sg, sg, elen + ivlen, iv); + aead_request_set_ad(req, assoclen); + + ret = crypto_aead_decrypt(req); + if (ret == -EINPROGRESS) + goto out; + + if ((x->props.flags & XFRM_STATE_ESN)) + esp_input_restore_header(skb); + + ret = esp6_input_done2(skb, ret); + +out: + return ret; +} + +static int esp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + const struct ipv6hdr *iph = (const struct ipv6hdr *)skb->data; + struct ip_esp_hdr *esph = (struct ip_esp_hdr *)(skb->data + offset); + struct xfrm_state *x; + + if (type != ICMPV6_PKT_TOOBIG && + type != NDISC_REDIRECT) + return 0; + + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + esph->spi, IPPROTO_ESP, AF_INET6); + if (!x) + return 0; + + if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + else + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + xfrm_state_put(x); + + return 0; +} + +static void esp6_destroy(struct xfrm_state *x) +{ + struct crypto_aead *aead = x->data; + + if (!aead) + return; + + crypto_free_aead(aead); +} + +static int esp_init_aead(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + char aead_name[CRYPTO_MAX_ALG_NAME]; + struct crypto_aead *aead; + int err; + + if (snprintf(aead_name, CRYPTO_MAX_ALG_NAME, "%s(%s)", + x->geniv, x->aead->alg_name) >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, 
"Algorithm name is too long"); + return -ENAMETOOLONG; + } + + aead = crypto_alloc_aead(aead_name, 0, 0); + err = PTR_ERR(aead); + if (IS_ERR(aead)) + goto error; + + x->data = aead; + + err = crypto_aead_setkey(aead, x->aead->alg_key, + (x->aead->alg_key_len + 7) / 8); + if (err) + goto error; + + err = crypto_aead_setauthsize(aead, x->aead->alg_icv_len / 8); + if (err) + goto error; + + return 0; + +error: + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + return err; +} + +static int esp_init_authenc(struct xfrm_state *x, + struct netlink_ext_ack *extack) +{ + struct crypto_aead *aead; + struct crypto_authenc_key_param *param; + struct rtattr *rta; + char *key; + char *p; + char authenc_name[CRYPTO_MAX_ALG_NAME]; + unsigned int keylen; + int err; + + err = -ENAMETOOLONG; + + if ((x->props.flags & XFRM_STATE_ESN)) { + if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, + "%s%sauthencesn(%s,%s)%s", + x->geniv ?: "", x->geniv ? "(" : "", + x->aalg ? x->aalg->alg_name : "digest_null", + x->ealg->alg_name, + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + goto error; + } + } else { + if (snprintf(authenc_name, CRYPTO_MAX_ALG_NAME, + "%s%sauthenc(%s,%s)%s", + x->geniv ?: "", x->geniv ? "(" : "", + x->aalg ? x->aalg->alg_name : "digest_null", + x->ealg->alg_name, + x->geniv ? ")" : "") >= CRYPTO_MAX_ALG_NAME) { + NL_SET_ERR_MSG(extack, "Algorithm name is too long"); + goto error; + } + } + + aead = crypto_alloc_aead(authenc_name, 0, 0); + err = PTR_ERR(aead); + if (IS_ERR(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto error; + } + + x->data = aead; + + keylen = (x->aalg ? (x->aalg->alg_key_len + 7) / 8 : 0) + + (x->ealg->alg_key_len + 7) / 8 + RTA_SPACE(sizeof(*param)); + err = -ENOMEM; + key = kmalloc(keylen, GFP_KERNEL); + if (!key) + goto error; + + p = key; + rta = (void *)p; + rta->rta_type = CRYPTO_AUTHENC_KEYA_PARAM; + rta->rta_len = RTA_LENGTH(sizeof(*param)); + param = RTA_DATA(rta); + p += RTA_SPACE(sizeof(*param)); + + if (x->aalg) { + struct xfrm_algo_desc *aalg_desc; + + memcpy(p, x->aalg->alg_key, (x->aalg->alg_key_len + 7) / 8); + p += (x->aalg->alg_key_len + 7) / 8; + + aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0); + BUG_ON(!aalg_desc); + + err = -EINVAL; + if (aalg_desc->uinfo.auth.icv_fullbits / 8 != + crypto_aead_authsize(aead)) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto free_key; + } + + err = crypto_aead_setauthsize( + aead, x->aalg->alg_trunc_len / 8); + if (err) { + NL_SET_ERR_MSG(extack, "Kernel was unable to initialize cryptographic operations"); + goto free_key; + } + } + + param->enckeylen = cpu_to_be32((x->ealg->alg_key_len + 7) / 8); + memcpy(p, x->ealg->alg_key, (x->ealg->alg_key_len + 7) / 8); + + err = crypto_aead_setkey(aead, key, keylen); + +free_key: + kfree(key); + +error: + return err; +} + +static int esp6_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + struct crypto_aead *aead; + u32 align; + int err; + + x->data = NULL; + + if (x->aead) { + err = esp_init_aead(x, extack); + } else if (x->ealg) { + err = esp_init_authenc(x, extack); + } else { + NL_SET_ERR_MSG(extack, "ESP: AEAD or CRYPT must be provided"); + err = -EINVAL; + } + + if (err) + goto error; + + aead = x->data; + + x->props.header_len = sizeof(struct ip_esp_hdr) + + crypto_aead_ivsize(aead); + switch (x->props.mode) { + case XFRM_MODE_BEET: + if (x->sel.family != 
AF_INET6) + x->props.header_len += IPV4_BEET_PHMAXLEN + + (sizeof(struct ipv6hdr) - sizeof(struct iphdr)); + break; + default: + case XFRM_MODE_TRANSPORT: + break; + case XFRM_MODE_TUNNEL: + x->props.header_len += sizeof(struct ipv6hdr); + break; + } + + if (x->encap) { + struct xfrm_encap_tmpl *encap = x->encap; + + switch (encap->encap_type) { + default: + NL_SET_ERR_MSG(extack, "Unsupported encapsulation type for ESP"); + err = -EINVAL; + goto error; + case UDP_ENCAP_ESPINUDP: + x->props.header_len += sizeof(struct udphdr); + break; + case UDP_ENCAP_ESPINUDP_NON_IKE: + x->props.header_len += sizeof(struct udphdr) + 2 * sizeof(u32); + break; +#ifdef CONFIG_INET6_ESPINTCP + case TCP_ENCAP_ESPINTCP: + /* only the length field, TCP encap is done by + * the socket + */ + x->props.header_len += 2; + break; +#endif + } + } + + align = ALIGN(crypto_aead_blocksize(aead), 4); + x->props.trailer_len = align + 1 + crypto_aead_authsize(aead); + +error: + return err; +} + +static int esp6_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type esp6_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_ESP, + .flags = XFRM_TYPE_REPLAY_PROT, + .init_state = esp6_init_state, + .destructor = esp6_destroy, + .input = esp6_input, + .output = esp6_output, +}; + +static struct xfrm6_protocol esp6_protocol = { + .handler = xfrm6_rcv, + .input_handler = xfrm_input, + .cb_handler = esp6_rcv_cb, + .err_handler = esp6_err, + .priority = 0, +}; + +static int __init esp6_init(void) +{ + if (xfrm_register_type(&esp6_type, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + if (xfrm6_protocol_register(&esp6_protocol, IPPROTO_ESP) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&esp6_type, AF_INET6); + return -EAGAIN; + } + + return 0; +} + +static void __exit esp6_fini(void) +{ + if (xfrm6_protocol_deregister(&esp6_protocol, IPPROTO_ESP) < 0) + pr_info("%s: can't remove protocol\n", __func__); + xfrm_unregister_type(&esp6_type, AF_INET6); +} + +module_init(esp6_init); +module_exit(esp6_fini); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_ESP); diff --git a/net/ipv6/esp6_offload.c b/net/ipv6/esp6_offload.c new file mode 100644 index 000000000..fc6a5be73 --- /dev/null +++ b/net/ipv6/esp6_offload.c @@ -0,0 +1,420 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPV6 GSO/GRO offload support + * Linux INET implementation + * + * Copyright (C) 2016 secunet Security Networks AG + * Author: Steffen Klassert + * + * ESP GRO support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static __u16 esp6_nexthdr_esp_offset(struct ipv6hdr *ipv6_hdr, int nhlen) +{ + int off = sizeof(struct ipv6hdr); + struct ipv6_opt_hdr *exthdr; + + if (likely(ipv6_hdr->nexthdr == NEXTHDR_ESP)) + return offsetof(struct ipv6hdr, nexthdr); + + while (off < nhlen) { + exthdr = (void *)ipv6_hdr + off; + if (exthdr->nexthdr == NEXTHDR_ESP) + return off; + + off += ipv6_optlen(exthdr); + } + + return 0; +} + +static struct sk_buff *esp6_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + int offset = skb_gro_offset(skb); + struct xfrm_offload *xo; + struct xfrm_state *x; + __be32 seq; + __be32 spi; + int nhoff; + int err; + + if (!pskb_pull(skb, offset)) + return NULL; + + if ((err = xfrm_parse_spi(skb, IPPROTO_ESP, &spi, &seq)) != 0) + goto out; + + xo = xfrm_offload(skb); + if 
(!xo || !(xo->flags & CRYPTO_DONE)) { + struct sec_path *sp = secpath_set(skb); + + if (!sp) + goto out; + + if (sp->len == XFRM_MAX_DEPTH) + goto out_reset; + + x = xfrm_state_lookup(dev_net(skb->dev), skb->mark, + (xfrm_address_t *)&ipv6_hdr(skb)->daddr, + spi, IPPROTO_ESP, AF_INET6); + if (!x) + goto out_reset; + + skb->mark = xfrm_smark_get(skb->mark, x); + + sp->xvec[sp->len++] = x; + sp->olen++; + + xo = xfrm_offload(skb); + if (!xo) + goto out_reset; + } + + xo->flags |= XFRM_GRO; + + nhoff = esp6_nexthdr_esp_offset(ipv6_hdr(skb), offset); + if (!nhoff) + goto out; + + IP6CB(skb)->nhoff = nhoff; + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + XFRM_SPI_SKB_CB(skb)->family = AF_INET6; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct ipv6hdr, daddr); + XFRM_SPI_SKB_CB(skb)->seq = seq; + + /* We don't need to handle errors from xfrm_input, it does all + * the error handling and frees the resources on error. */ + xfrm_input(skb, IPPROTO_ESP, spi, -2); + + return ERR_PTR(-EINPROGRESS); +out_reset: + secpath_reset(skb); +out: + skb_push(skb, offset); + NAPI_GRO_CB(skb)->same_flow = 0; + NAPI_GRO_CB(skb)->flush = 1; + + return NULL; +} + +static void esp6_gso_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ip_esp_hdr *esph; + struct ipv6hdr *iph = ipv6_hdr(skb); + struct xfrm_offload *xo = xfrm_offload(skb); + u8 proto = iph->nexthdr; + + skb_push(skb, -skb_network_offset(skb)); + + if (x->outer_mode.encap == XFRM_MODE_TRANSPORT) { + __be16 frag; + + ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &proto, &frag); + } + + esph = ip_esp_hdr(skb); + *skb_mac_header(skb) = IPPROTO_ESP; + + esph->spi = x->id.spi; + esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq.output.low); + + xo->proto = proto; +} + +static struct sk_buff *xfrm6_tunnel_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + __be16 type = x->inner_mode.family == AF_INET ? 
htons(ETH_P_IP) + : htons(ETH_P_IPV6); + + return skb_eth_gso_segment(skb, features, type); +} + +static struct sk_buff *xfrm6_transport_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + const struct net_offload *ops; + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct xfrm_offload *xo = xfrm_offload(skb); + + skb->transport_header += x->props.header_len; + ops = rcu_dereference(inet6_offloads[xo->proto]); + if (likely(ops && ops->callbacks.gso_segment)) + segs = ops->callbacks.gso_segment(skb, features); + + return segs; +} + +static struct sk_buff *xfrm6_beet_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + struct sk_buff *segs = ERR_PTR(-EINVAL); + const struct net_offload *ops; + u8 proto = xo->proto; + + skb->transport_header += x->props.header_len; + + if (x->sel.family != AF_INET6) { + skb->transport_header -= + (sizeof(struct ipv6hdr) - sizeof(struct iphdr)); + + if (proto == IPPROTO_BEETPH) { + struct ip_beet_phdr *ph = + (struct ip_beet_phdr *)skb->data; + + skb->transport_header += ph->hdrlen * 8; + proto = ph->nexthdr; + } else { + skb->transport_header -= IPV4_BEET_PHMAXLEN; + } + + if (proto == IPPROTO_TCP) + skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV6; + } else { + __be16 frag; + + skb->transport_header += + ipv6_skip_exthdr(skb, 0, &proto, &frag); + } + + if (proto == IPPROTO_IPIP) + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP6; + + __skb_pull(skb, skb_transport_offset(skb)); + ops = rcu_dereference(inet6_offloads[proto]); + if (likely(ops && ops->callbacks.gso_segment)) + segs = ops->callbacks.gso_segment(skb, features); + + return segs; +} + +static struct sk_buff *xfrm6_outer_mode_gso_segment(struct xfrm_state *x, + struct sk_buff *skb, + netdev_features_t features) +{ + switch (x->outer_mode.encap) { + case XFRM_MODE_TUNNEL: + return xfrm6_tunnel_gso_segment(x, skb, features); + case XFRM_MODE_TRANSPORT: + return xfrm6_transport_gso_segment(x, skb, features); + case XFRM_MODE_BEET: + return xfrm6_beet_gso_segment(x, skb, features); + } + + return ERR_PTR(-EOPNOTSUPP); +} + +static struct sk_buff *esp6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct xfrm_state *x; + struct ip_esp_hdr *esph; + struct crypto_aead *aead; + netdev_features_t esp_features = features; + struct xfrm_offload *xo = xfrm_offload(skb); + struct sec_path *sp; + + if (!xo) + return ERR_PTR(-EINVAL); + + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_ESP)) + return ERR_PTR(-EINVAL); + + sp = skb_sec_path(skb); + x = sp->xvec[sp->len - 1]; + aead = x->data; + esph = ip_esp_hdr(skb); + + if (esph->spi != x->id.spi) + return ERR_PTR(-EINVAL); + + if (!pskb_may_pull(skb, sizeof(*esph) + crypto_aead_ivsize(aead))) + return ERR_PTR(-EINVAL); + + __skb_pull(skb, sizeof(*esph) + crypto_aead_ivsize(aead)); + + skb->encap_hdr_csum = 1; + + if (!(features & NETIF_F_HW_ESP) || x->xso.dev != skb->dev) + esp_features = features & ~(NETIF_F_SG | NETIF_F_CSUM_MASK | + NETIF_F_SCTP_CRC); + else if (!(features & NETIF_F_HW_ESP_TX_CSUM)) + esp_features = features & ~(NETIF_F_CSUM_MASK | + NETIF_F_SCTP_CRC); + + xo->flags |= XFRM_GSO_SEGMENT; + + return xfrm6_outer_mode_gso_segment(x, skb, esp_features); +} + +static int esp6_input_tail(struct xfrm_state *x, struct sk_buff *skb) +{ + struct crypto_aead *aead = x->data; + struct xfrm_offload *xo = xfrm_offload(skb); + + if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr) + crypto_aead_ivsize(aead))) + return -EINVAL; + + if 
(!(xo->flags & CRYPTO_DONE)) + skb->ip_summed = CHECKSUM_NONE; + + return esp6_input_done2(skb, 0); +} + +static int esp6_xmit(struct xfrm_state *x, struct sk_buff *skb, netdev_features_t features) +{ + int len; + int err; + int alen; + int blksize; + struct xfrm_offload *xo; + struct crypto_aead *aead; + struct esp_info esp; + bool hw_offload = true; + __u32 seq; + + esp.inplace = true; + + xo = xfrm_offload(skb); + + if (!xo) + return -EINVAL; + + if (!(features & NETIF_F_HW_ESP) || x->xso.dev != skb->dev) { + xo->flags |= CRYPTO_FALLBACK; + hw_offload = false; + } + + esp.proto = xo->proto; + + /* skb is pure payload to encrypt */ + + aead = x->data; + alen = crypto_aead_authsize(aead); + + esp.tfclen = 0; + /* XXX: Add support for tfc padding here. */ + + blksize = ALIGN(crypto_aead_blocksize(aead), 4); + esp.clen = ALIGN(skb->len + 2 + esp.tfclen, blksize); + esp.plen = esp.clen - skb->len - esp.tfclen; + esp.tailen = esp.tfclen + esp.plen + alen; + + if (!hw_offload || !skb_is_gso(skb)) { + esp.nfrags = esp6_output_head(x, skb, &esp); + if (esp.nfrags < 0) + return esp.nfrags; + } + + seq = xo->seq.low; + + esp.esph = ip_esp_hdr(skb); + esp.esph->spi = x->id.spi; + + skb_push(skb, -skb_network_offset(skb)); + + if (xo->flags & XFRM_GSO_SEGMENT) { + esp.esph->seq_no = htonl(seq); + + if (!skb_is_gso(skb)) + xo->seq.low++; + else + xo->seq.low += skb_shinfo(skb)->gso_segs; + } + + if (xo->seq.low < seq) + xo->seq.hi++; + + esp.seqno = cpu_to_be64(xo->seq.low + ((u64)xo->seq.hi << 32)); + + len = skb->len - sizeof(struct ipv6hdr); + if (len > IPV6_MAXPLEN) + len = 0; + + ipv6_hdr(skb)->payload_len = htons(len); + + if (hw_offload) { + if (!skb_ext_add(skb, SKB_EXT_SEC_PATH)) + return -ENOMEM; + + xo = xfrm_offload(skb); + if (!xo) + return -EINVAL; + + xo->flags |= XFRM_XMIT; + return 0; + } + + err = esp6_output_tail(x, skb, &esp); + if (err) + return err; + + secpath_reset(skb); + + if (skb_needs_linearize(skb, skb->dev->features) && + __skb_linearize(skb)) + return -ENOMEM; + return 0; +} + +static const struct net_offload esp6_offload = { + .callbacks = { + .gro_receive = esp6_gro_receive, + .gso_segment = esp6_gso_segment, + }, +}; + +static const struct xfrm_type_offload esp6_type_offload = { + .owner = THIS_MODULE, + .proto = IPPROTO_ESP, + .input_tail = esp6_input_tail, + .xmit = esp6_xmit, + .encap = esp6_gso_encap, +}; + +static int __init esp6_offload_init(void) +{ + if (xfrm_register_type_offload(&esp6_type_offload, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type offload\n", __func__); + return -EAGAIN; + } + + return inet6_add_offload(&esp6_offload, IPPROTO_ESP); +} + +static void __exit esp6_offload_exit(void) +{ + xfrm_unregister_type_offload(&esp6_type_offload, AF_INET6); + inet6_del_offload(&esp6_offload, IPPROTO_ESP); +} + +module_init(esp6_offload_init); +module_exit(esp6_offload_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Steffen Klassert "); +MODULE_ALIAS_XFRM_OFFLOAD_TYPE(AF_INET6, XFRM_PROTO_ESP); +MODULE_DESCRIPTION("IPV6 GSO/GRO offload support"); diff --git a/net/ipv6/exthdrs.c b/net/ipv6/exthdrs.c new file mode 100644 index 000000000..5fa0e3730 --- /dev/null +++ b/net/ipv6/exthdrs.c @@ -0,0 +1,1401 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Extension Header handling for IPv6 + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * Andi Kleen + * Alexey Kuznetsov + */ + +/* Changes: + * yoshfuji : ensure not to overrun while parsing + * tlv options. + * Mitsuru KANDA @USAGI and: Remove ipv6_parse_exthdrs(). 
+ * YOSHIFUJI Hideaki @USAGI Register inbound extension header + * handlers as inet6_protocol{}. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6_MIP6) +#include +#endif +#include +#include +#ifdef CONFIG_IPV6_SEG6_HMAC +#include +#endif +#include +#include +#include +#include + +#include + +/********************* + Generic functions + *********************/ + +/* An unknown option is detected, decide what to do */ + +static bool ip6_tlvopt_unknown(struct sk_buff *skb, int optoff, + bool disallow_unknowns) +{ + if (disallow_unknowns) { + /* If unknown TLVs are disallowed by configuration + * then always silently drop packet. Note this also + * means no ICMP parameter problem is sent which + * could be a good property to mitigate a reflection DOS + * attack. + */ + + goto drop; + } + + switch ((skb_network_header(skb)[optoff] & 0xC0) >> 6) { + case 0: /* ignore */ + return true; + + case 1: /* drop packet */ + break; + + case 3: /* Send ICMP if not a multicast address and drop packet */ + /* Actually, it is redundant check. icmp_send + will recheck in any case. + */ + if (ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr)) + break; + fallthrough; + case 2: /* send ICMP PARM PROB regardless and drop packet */ + icmpv6_param_prob_reason(skb, ICMPV6_UNK_OPTION, optoff, + SKB_DROP_REASON_UNHANDLED_PROTO); + return false; + } + +drop: + kfree_skb_reason(skb, SKB_DROP_REASON_UNHANDLED_PROTO); + return false; +} + +static bool ipv6_hop_ra(struct sk_buff *skb, int optoff); +static bool ipv6_hop_ioam(struct sk_buff *skb, int optoff); +static bool ipv6_hop_jumbo(struct sk_buff *skb, int optoff); +static bool ipv6_hop_calipso(struct sk_buff *skb, int optoff); +#if IS_ENABLED(CONFIG_IPV6_MIP6) +static bool ipv6_dest_hao(struct sk_buff *skb, int optoff); +#endif + +/* Parse tlv encoded option header (hop-by-hop or destination) */ + +static bool ip6_parse_tlv(bool hopbyhop, + struct sk_buff *skb, + int max_count) +{ + int len = (skb_transport_header(skb)[1] + 1) << 3; + const unsigned char *nh = skb_network_header(skb); + int off = skb_network_header_len(skb); + bool disallow_unknowns = false; + int tlv_count = 0; + int padlen = 0; + + if (unlikely(max_count < 0)) { + disallow_unknowns = true; + max_count = -max_count; + } + + if (skb_transport_offset(skb) + len > skb_headlen(skb)) + goto bad; + + off += 2; + len -= 2; + + while (len > 0) { + int optlen, i; + + if (nh[off] == IPV6_TLV_PAD1) { + padlen++; + if (padlen > 7) + goto bad; + off++; + len--; + continue; + } + if (len < 2) + goto bad; + optlen = nh[off + 1] + 2; + if (optlen > len) + goto bad; + + if (nh[off] == IPV6_TLV_PADN) { + /* RFC 2460 states that the purpose of PadN is + * to align the containing header to multiples + * of 8. 7 is therefore the highest valid value. + * See also RFC 4942, Section 2.1.9.5. + */ + padlen += optlen; + if (padlen > 7) + goto bad; + /* RFC 4942 recommends receiving hosts to + * actively check PadN payload to contain + * only zeroes. 
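[Illustration, not part of the patch.] The switch in ip6_tlvopt_unknown() above follows the rule from RFC 8200, section 4.2: the two high-order bits of an unrecognised option's type byte tell the receiver what to do with the packet. A standalone sketch of that mapping (hypothetical helper, kernel-style u8 from <linux/types.h> assumed):

enum tlv_action {
	TLV_SKIP,                  /* 00: skip the option, keep parsing          */
	TLV_DISCARD,               /* 01: silently discard the packet            */
	TLV_DISCARD_SEND_ICMP,     /* 10: discard, always send ICMP Param Prob   */
	TLV_DISCARD_ICMP_UNICAST,  /* 11: discard, ICMP only if dst not mcast    */
};

static enum tlv_action tlv_unknown_action(u8 opt_type)
{
	switch ((opt_type & 0xC0) >> 6) {
	case 0:  return TLV_SKIP;
	case 1:  return TLV_DISCARD;
	case 2:  return TLV_DISCARD_SEND_ICMP;
	default: return TLV_DISCARD_ICMP_UNICAST;
	}
}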
+ */ + for (i = 2; i < optlen; i++) { + if (nh[off + i] != 0) + goto bad; + } + } else { + tlv_count++; + if (tlv_count > max_count) + goto bad; + + if (hopbyhop) { + switch (nh[off]) { + case IPV6_TLV_ROUTERALERT: + if (!ipv6_hop_ra(skb, off)) + return false; + break; + case IPV6_TLV_IOAM: + if (!ipv6_hop_ioam(skb, off)) + return false; + break; + case IPV6_TLV_JUMBO: + if (!ipv6_hop_jumbo(skb, off)) + return false; + break; + case IPV6_TLV_CALIPSO: + if (!ipv6_hop_calipso(skb, off)) + return false; + break; + default: + if (!ip6_tlvopt_unknown(skb, off, + disallow_unknowns)) + return false; + break; + } + } else { + switch (nh[off]) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_TLV_HAO: + if (!ipv6_dest_hao(skb, off)) + return false; + break; +#endif + default: + if (!ip6_tlvopt_unknown(skb, off, + disallow_unknowns)) + return false; + break; + } + } + padlen = 0; + } + off += optlen; + len -= optlen; + } + + if (len == 0) + return true; +bad: + kfree_skb_reason(skb, SKB_DROP_REASON_IP_INHDR); + return false; +} + +/***************************** + Destination options header. + *****************************/ + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +static bool ipv6_dest_hao(struct sk_buff *skb, int optoff) +{ + struct ipv6_destopt_hao *hao; + struct inet6_skb_parm *opt = IP6CB(skb); + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + SKB_DR(reason); + int ret; + + if (opt->dsthao) { + net_dbg_ratelimited("hao duplicated\n"); + goto discard; + } + opt->dsthao = opt->dst1; + opt->dst1 = 0; + + hao = (struct ipv6_destopt_hao *)(skb_network_header(skb) + optoff); + + if (hao->length != 16) { + net_dbg_ratelimited("hao invalid option length = %d\n", + hao->length); + SKB_DR_SET(reason, IP_INHDR); + goto discard; + } + + if (!(ipv6_addr_type(&hao->addr) & IPV6_ADDR_UNICAST)) { + net_dbg_ratelimited("hao is not an unicast addr: %pI6\n", + &hao->addr); + SKB_DR_SET(reason, INVALID_PROTO); + goto discard; + } + + ret = xfrm6_input_addr(skb, (xfrm_address_t *)&ipv6h->daddr, + (xfrm_address_t *)&hao->addr, IPPROTO_DSTOPTS); + if (unlikely(ret < 0)) { + SKB_DR_SET(reason, XFRM_POLICY); + goto discard; + } + + if (skb_cloned(skb)) { + if (pskb_expand_head(skb, 0, 0, GFP_ATOMIC)) + goto discard; + + /* update all variable using below by copied skbuff */ + hao = (struct ipv6_destopt_hao *)(skb_network_header(skb) + + optoff); + ipv6h = ipv6_hdr(skb); + } + + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = CHECKSUM_NONE; + + swap(ipv6h->saddr, hao->addr); + + if (skb->tstamp == 0) + __net_timestamp(skb); + + return true; + + discard: + kfree_skb_reason(skb, reason); + return false; +} +#endif + +static int ipv6_destopt_rcv(struct sk_buff *skb) +{ + struct inet6_dev *idev = __in6_dev_get(skb->dev); + struct inet6_skb_parm *opt = IP6CB(skb); +#if IS_ENABLED(CONFIG_IPV6_MIP6) + __u16 dstbuf; +#endif + struct dst_entry *dst = skb_dst(skb); + struct net *net = dev_net(skb->dev); + int extlen; + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + 8) || + !pskb_may_pull(skb, (skb_transport_offset(skb) + + ((skb_transport_header(skb)[1] + 1) << 3)))) { + __IP6_INC_STATS(dev_net(dst->dev), idev, + IPSTATS_MIB_INHDRERRORS); +fail_and_free: + kfree_skb(skb); + return -1; + } + + extlen = (skb_transport_header(skb)[1] + 1) << 3; + if (extlen > net->ipv6.sysctl.max_dst_opts_len) + goto fail_and_free; + + opt->lastopt = opt->dst1 = skb_network_header_len(skb); +#if IS_ENABLED(CONFIG_IPV6_MIP6) + dstbuf = opt->dst1; +#endif + + if (ip6_parse_tlv(false, skb, net->ipv6.sysctl.max_dst_opts_cnt)) { + 
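[Illustration, not part of the patch.] A short worked example of the length rules used throughout this file (variable names are hypothetical): the Hdr Ext Len byte counts 8-octet units beyond the first 8 octets, and a TLV's length byte counts only its data.

/* Extension header whose second byte (Hdr Ext Len) is 1: */
int extlen = (1 + 1) << 3;   /* 16 bytes in total                         */

/* A TLV whose length byte is 4: */
int optlen = 4 + 2;          /* type byte + length byte + 4 data bytes    */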
skb->transport_header += extlen; + opt = IP6CB(skb); +#if IS_ENABLED(CONFIG_IPV6_MIP6) + opt->nhoff = dstbuf; +#else + opt->nhoff = opt->dst1; +#endif + return 1; + } + + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + return -1; +} + +static void seg6_update_csum(struct sk_buff *skb) +{ + struct ipv6_sr_hdr *hdr; + struct in6_addr *addr; + __be32 from, to; + + /* srh is at transport offset and seg_left is already decremented + * but daddr is not yet updated with next segment + */ + + hdr = (struct ipv6_sr_hdr *)skb_transport_header(skb); + addr = hdr->segments + hdr->segments_left; + + hdr->segments_left++; + from = *(__be32 *)hdr; + + hdr->segments_left--; + to = *(__be32 *)hdr; + + /* update skb csum with diff resulting from seg_left decrement */ + + update_csum_diff4(skb, from, to); + + /* compute csum diff between current and next segment and update */ + + update_csum_diff16(skb, (__be32 *)(&ipv6_hdr(skb)->daddr), + (__be32 *)addr); +} + +static int ipv6_srh_rcv(struct sk_buff *skb) +{ + struct inet6_skb_parm *opt = IP6CB(skb); + struct net *net = dev_net(skb->dev); + struct ipv6_sr_hdr *hdr; + struct inet6_dev *idev; + struct in6_addr *addr; + int accept_seg6; + + hdr = (struct ipv6_sr_hdr *)skb_transport_header(skb); + + idev = __in6_dev_get(skb->dev); + + accept_seg6 = net->ipv6.devconf_all->seg6_enabled; + if (accept_seg6 > idev->cnf.seg6_enabled) + accept_seg6 = idev->cnf.seg6_enabled; + + if (!accept_seg6) { + kfree_skb(skb); + return -1; + } + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (!seg6_hmac_validate_skb(skb)) { + kfree_skb(skb); + return -1; + } +#endif + +looped_back: + if (hdr->segments_left == 0) { + if (hdr->nexthdr == NEXTHDR_IPV6 || hdr->nexthdr == NEXTHDR_IPV4) { + int offset = (hdr->hdrlen + 1) << 3; + + skb_postpull_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + + if (!pskb_pull(skb, offset)) { + kfree_skb(skb); + return -1; + } + skb_postpull_rcsum(skb, skb_transport_header(skb), + offset); + + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb->encapsulation = 0; + if (hdr->nexthdr == NEXTHDR_IPV4) + skb->protocol = htons(ETH_P_IP); + __skb_tunnel_rx(skb, skb->dev, net); + + netif_rx(skb); + return -1; + } + + opt->srcrt = skb_network_header_len(skb); + opt->lastopt = opt->srcrt; + skb->transport_header += (hdr->hdrlen + 1) << 3; + opt->nhoff = (&hdr->nexthdr) - skb_network_header(skb); + + return 1; + } + + if (hdr->segments_left >= (hdr->hdrlen >> 1)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, + ((&hdr->segments_left) - + skb_network_header(skb))); + return -1; + } + + if (skb_cloned(skb)) { + if (pskb_expand_head(skb, 0, 0, GFP_ATOMIC)) { + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTDISCARDS); + kfree_skb(skb); + return -1; + } + } + + hdr = (struct ipv6_sr_hdr *)skb_transport_header(skb); + + hdr->segments_left--; + addr = hdr->segments + hdr->segments_left; + + skb_push(skb, sizeof(struct ipv6hdr)); + + if (skb->ip_summed == CHECKSUM_COMPLETE) + seg6_update_csum(skb); + + ipv6_hdr(skb)->daddr = *addr; + + skb_dst_drop(skb); + + ip6_route_input(skb); + + if (skb_dst(skb)->error) { + dst_input(skb); + return -1; + } + + if (skb_dst(skb)->dev->flags & IFF_LOOPBACK) { + if (ipv6_hdr(skb)->hop_limit <= 1) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_send(skb, ICMPV6_TIME_EXCEED, + ICMPV6_EXC_HOPLIMIT, 0); + kfree_skb(skb); + return -1; + } + ipv6_hdr(skb)->hop_limit--; + + skb_pull(skb, sizeof(struct ipv6hdr)); + goto 
looped_back; + } + + dst_input(skb); + + return -1; +} + +static int ipv6_rpl_srh_rcv(struct sk_buff *skb) +{ + struct ipv6_rpl_sr_hdr *hdr, *ohdr, *chdr; + struct inet6_skb_parm *opt = IP6CB(skb); + struct net *net = dev_net(skb->dev); + struct inet6_dev *idev; + struct ipv6hdr *oldhdr; + unsigned char *buf; + int accept_rpl_seg; + int i, err; + u64 n = 0; + u32 r; + + idev = __in6_dev_get(skb->dev); + + accept_rpl_seg = net->ipv6.devconf_all->rpl_seg_enabled; + if (accept_rpl_seg > idev->cnf.rpl_seg_enabled) + accept_rpl_seg = idev->cnf.rpl_seg_enabled; + + if (!accept_rpl_seg) { + kfree_skb(skb); + return -1; + } + +looped_back: + hdr = (struct ipv6_rpl_sr_hdr *)skb_transport_header(skb); + + if (hdr->segments_left == 0) { + if (hdr->nexthdr == NEXTHDR_IPV6) { + int offset = (hdr->hdrlen + 1) << 3; + + skb_postpull_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + + if (!pskb_pull(skb, offset)) { + kfree_skb(skb); + return -1; + } + skb_postpull_rcsum(skb, skb_transport_header(skb), + offset); + + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb->encapsulation = 0; + + __skb_tunnel_rx(skb, skb->dev, net); + + netif_rx(skb); + return -1; + } + + opt->srcrt = skb_network_header_len(skb); + opt->lastopt = opt->srcrt; + skb->transport_header += (hdr->hdrlen + 1) << 3; + opt->nhoff = (&hdr->nexthdr) - skb_network_header(skb); + + return 1; + } + + if (!pskb_may_pull(skb, sizeof(*hdr))) { + kfree_skb(skb); + return -1; + } + + n = (hdr->hdrlen << 3) - hdr->pad - (16 - hdr->cmpre); + r = do_div(n, (16 - hdr->cmpri)); + /* checks if calculation was without remainder and n fits into + * unsigned char which is segments_left field. Should not be + * higher than that. + */ + if (r || (n + 1) > 255) { + kfree_skb(skb); + return -1; + } + + if (hdr->segments_left > n + 1) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, + ((&hdr->segments_left) - + skb_network_header(skb))); + return -1; + } + + if (!pskb_may_pull(skb, ipv6_rpl_srh_size(n, hdr->cmpri, + hdr->cmpre))) { + kfree_skb(skb); + return -1; + } + + hdr->segments_left--; + i = n - hdr->segments_left; + + buf = kcalloc(struct_size(hdr, segments.addr, n + 2), 2, GFP_ATOMIC); + if (unlikely(!buf)) { + kfree_skb(skb); + return -1; + } + + ohdr = (struct ipv6_rpl_sr_hdr *)buf; + ipv6_rpl_srh_decompress(ohdr, hdr, &ipv6_hdr(skb)->daddr, n); + chdr = (struct ipv6_rpl_sr_hdr *)(buf + ((ohdr->hdrlen + 1) << 3)); + + if ((ipv6_addr_type(&ipv6_hdr(skb)->daddr) & IPV6_ADDR_MULTICAST) || + (ipv6_addr_type(&ohdr->rpl_segaddr[i]) & IPV6_ADDR_MULTICAST)) { + kfree_skb(skb); + kfree(buf); + return -1; + } + + err = ipv6_chk_rpl_srh_loop(net, ohdr->rpl_segaddr, n + 1); + if (err) { + icmpv6_send(skb, ICMPV6_PARAMPROB, 0, 0); + kfree_skb(skb); + kfree(buf); + return -1; + } + + swap(ipv6_hdr(skb)->daddr, ohdr->rpl_segaddr[i]); + + ipv6_rpl_srh_compress(chdr, ohdr, &ipv6_hdr(skb)->daddr, n); + + oldhdr = ipv6_hdr(skb); + + skb_pull(skb, ((hdr->hdrlen + 1) << 3)); + skb_postpull_rcsum(skb, oldhdr, + sizeof(struct ipv6hdr) + ((hdr->hdrlen + 1) << 3)); + if (unlikely(!hdr->segments_left)) { + if (pskb_expand_head(skb, sizeof(struct ipv6hdr) + ((chdr->hdrlen + 1) << 3), 0, + GFP_ATOMIC)) { + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_OUTDISCARDS); + kfree_skb(skb); + kfree(buf); + return -1; + } + + oldhdr = ipv6_hdr(skb); + } + skb_push(skb, ((chdr->hdrlen + 1) << 3) + sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + 
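[Illustration, not part of the patch.] The CmprI/CmprE arithmetic above decompresses an RPL source routing header (RFC 6554) in which a common prefix may be elided from the intermediate and final segments. A worked example with illustrative field values:

/*
 *   hdrlen = 3, pad = 0, cmpri = cmpre = 8
 *
 *   address area        = hdrlen * 8 - pad  = 24 bytes
 *   final segment uses    16 - cmpre        =  8 bytes
 *   middle segments use   16 - cmpri        =  8 bytes each
 *
 *   n = (24 - 8) / 8 = 2 compressed middle segments, so the header carries
 *   n + 1 = 3 segments in total; the do_div() remainder check and the
 *   "segments_left > n + 1" test above reject anything inconsistent with that.
 */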
skb_mac_header_rebuild(skb); + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + memmove(ipv6_hdr(skb), oldhdr, sizeof(struct ipv6hdr)); + memcpy(skb_transport_header(skb), chdr, (chdr->hdrlen + 1) << 3); + + ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + skb_postpush_rcsum(skb, ipv6_hdr(skb), + sizeof(struct ipv6hdr) + ((chdr->hdrlen + 1) << 3)); + + kfree(buf); + + skb_dst_drop(skb); + + ip6_route_input(skb); + + if (skb_dst(skb)->error) { + dst_input(skb); + return -1; + } + + if (skb_dst(skb)->dev->flags & IFF_LOOPBACK) { + if (ipv6_hdr(skb)->hop_limit <= 1) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_send(skb, ICMPV6_TIME_EXCEED, + ICMPV6_EXC_HOPLIMIT, 0); + kfree_skb(skb); + return -1; + } + ipv6_hdr(skb)->hop_limit--; + + skb_pull(skb, sizeof(struct ipv6hdr)); + goto looped_back; + } + + dst_input(skb); + + return -1; +} + +/******************************** + Routing header. + ********************************/ + +/* called with rcu_read_lock() */ +static int ipv6_rthdr_rcv(struct sk_buff *skb) +{ + struct inet6_dev *idev = __in6_dev_get(skb->dev); + struct inet6_skb_parm *opt = IP6CB(skb); + struct in6_addr *addr = NULL; + struct in6_addr daddr; + int n, i; + struct ipv6_rt_hdr *hdr; + struct rt0_hdr *rthdr; + struct net *net = dev_net(skb->dev); + int accept_source_route = net->ipv6.devconf_all->accept_source_route; + + if (idev && accept_source_route > idev->cnf.accept_source_route) + accept_source_route = idev->cnf.accept_source_route; + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + 8) || + !pskb_may_pull(skb, (skb_transport_offset(skb) + + ((skb_transport_header(skb)[1] + 1) << 3)))) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + kfree_skb(skb); + return -1; + } + + hdr = (struct ipv6_rt_hdr *)skb_transport_header(skb); + + if (ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr) || + skb->pkt_type != PACKET_HOST) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + kfree_skb(skb); + return -1; + } + + switch (hdr->type) { + case IPV6_SRCRT_TYPE_4: + /* segment routing */ + return ipv6_srh_rcv(skb); + case IPV6_SRCRT_TYPE_3: + /* rpl segment routing */ + return ipv6_rpl_srh_rcv(skb); + default: + break; + } + +looped_back: + if (hdr->segments_left == 0) { + switch (hdr->type) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_SRCRT_TYPE_2: + /* Silently discard type 2 header unless it was + * processed by own + */ + if (!addr) { + __IP6_INC_STATS(net, idev, + IPSTATS_MIB_INADDRERRORS); + kfree_skb(skb); + return -1; + } + break; +#endif + default: + break; + } + + opt->lastopt = opt->srcrt = skb_network_header_len(skb); + skb->transport_header += (hdr->hdrlen + 1) << 3; + opt->dst0 = opt->dst1; + opt->dst1 = 0; + opt->nhoff = (&hdr->nexthdr) - skb_network_header(skb); + return 1; + } + + switch (hdr->type) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_SRCRT_TYPE_2: + if (accept_source_route < 0) + goto unknown_rh; + /* Silently discard invalid RTH type 2 */ + if (hdr->hdrlen != 2 || hdr->segments_left != 1) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + kfree_skb(skb); + return -1; + } + break; +#endif + default: + goto unknown_rh; + } + + /* + * This is the routing header forwarding algorithm from + * RFC 2460, page 16. 
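[Illustration, not part of the patch.] For readers without RFC 2460 at hand, the forwarding steps performed below can be summarised roughly as follows (a paraphrase, not the normative text):

/*
 *   n = hdrlen / 2;                  // number of addresses in the header
 *   if (segments_left > n)
 *           send ICMP Parameter Problem and stop;
 *   i = n - --segments_left;         // 1-based index of the next address
 *   swap(addr[i], IPv6 destination); // visit that address next
 *   re-route the packet; if the new route points back at ourselves
 *   (loopback device), decrement hop_limit and repeat.
 */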
+ */ + + n = hdr->hdrlen >> 1; + + if (hdr->segments_left > n) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, + ((&hdr->segments_left) - + skb_network_header(skb))); + return -1; + } + + /* We are about to mangle packet header. Be careful! + Do not damage packets queued somewhere. + */ + if (skb_cloned(skb)) { + /* the copy is a forwarded packet */ + if (pskb_expand_head(skb, 0, 0, GFP_ATOMIC)) { + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTDISCARDS); + kfree_skb(skb); + return -1; + } + hdr = (struct ipv6_rt_hdr *)skb_transport_header(skb); + } + + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = CHECKSUM_NONE; + + i = n - --hdr->segments_left; + + rthdr = (struct rt0_hdr *) hdr; + addr = rthdr->addr; + addr += i - 1; + + switch (hdr->type) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_SRCRT_TYPE_2: + if (xfrm6_input_addr(skb, (xfrm_address_t *)addr, + (xfrm_address_t *)&ipv6_hdr(skb)->saddr, + IPPROTO_ROUTING) < 0) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + kfree_skb(skb); + return -1; + } + if (!ipv6_chk_home_addr(dev_net(skb_dst(skb)->dev), addr)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + kfree_skb(skb); + return -1; + } + break; +#endif + default: + break; + } + + if (ipv6_addr_is_multicast(addr)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + kfree_skb(skb); + return -1; + } + + daddr = *addr; + *addr = ipv6_hdr(skb)->daddr; + ipv6_hdr(skb)->daddr = daddr; + + skb_dst_drop(skb); + ip6_route_input(skb); + if (skb_dst(skb)->error) { + skb_push(skb, skb->data - skb_network_header(skb)); + dst_input(skb); + return -1; + } + + if (skb_dst(skb)->dev->flags&IFF_LOOPBACK) { + if (ipv6_hdr(skb)->hop_limit <= 1) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_send(skb, ICMPV6_TIME_EXCEED, ICMPV6_EXC_HOPLIMIT, + 0); + kfree_skb(skb); + return -1; + } + ipv6_hdr(skb)->hop_limit--; + goto looped_back; + } + + skb_push(skb, skb->data - skb_network_header(skb)); + dst_input(skb); + return -1; + +unknown_rh: + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, + (&hdr->type) - skb_network_header(skb)); + return -1; +} + +static const struct inet6_protocol rthdr_protocol = { + .handler = ipv6_rthdr_rcv, + .flags = INET6_PROTO_NOPOLICY, +}; + +static const struct inet6_protocol destopt_protocol = { + .handler = ipv6_destopt_rcv, + .flags = INET6_PROTO_NOPOLICY, +}; + +static const struct inet6_protocol nodata_protocol = { + .handler = dst_discard, + .flags = INET6_PROTO_NOPOLICY, +}; + +int __init ipv6_exthdrs_init(void) +{ + int ret; + + ret = inet6_add_protocol(&rthdr_protocol, IPPROTO_ROUTING); + if (ret) + goto out; + + ret = inet6_add_protocol(&destopt_protocol, IPPROTO_DSTOPTS); + if (ret) + goto out_rthdr; + + ret = inet6_add_protocol(&nodata_protocol, IPPROTO_NONE); + if (ret) + goto out_destopt; + +out: + return ret; +out_destopt: + inet6_del_protocol(&destopt_protocol, IPPROTO_DSTOPTS); +out_rthdr: + inet6_del_protocol(&rthdr_protocol, IPPROTO_ROUTING); + goto out; +}; + +void ipv6_exthdrs_exit(void) +{ + inet6_del_protocol(&nodata_protocol, IPPROTO_NONE); + inet6_del_protocol(&destopt_protocol, IPPROTO_DSTOPTS); + inet6_del_protocol(&rthdr_protocol, IPPROTO_ROUTING); +} + +/********************************** + Hop-by-hop options. + **********************************/ + +/* + * Note: we cannot rely on skb_dst(skb) before we assign it in ip6_route_input(). 
+ */ +static inline struct net *ipv6_skb_net(struct sk_buff *skb) +{ + return skb_dst(skb) ? dev_net(skb_dst(skb)->dev) : dev_net(skb->dev); +} + +/* Router Alert as of RFC 2711 */ + +static bool ipv6_hop_ra(struct sk_buff *skb, int optoff) +{ + const unsigned char *nh = skb_network_header(skb); + + if (nh[optoff + 1] == 2) { + IP6CB(skb)->flags |= IP6SKB_ROUTERALERT; + memcpy(&IP6CB(skb)->ra, nh + optoff + 2, sizeof(IP6CB(skb)->ra)); + return true; + } + net_dbg_ratelimited("ipv6_hop_ra: wrong RA length %d\n", + nh[optoff + 1]); + kfree_skb_reason(skb, SKB_DROP_REASON_IP_INHDR); + return false; +} + +/* IOAM */ + +static bool ipv6_hop_ioam(struct sk_buff *skb, int optoff) +{ + struct ioam6_trace_hdr *trace; + struct ioam6_namespace *ns; + struct ioam6_hdr *hdr; + + /* Bad alignment (must be 4n-aligned) */ + if (optoff & 3) + goto drop; + + /* Ignore if IOAM is not enabled on ingress */ + if (!__in6_dev_get(skb->dev)->cnf.ioam6_enabled) + goto ignore; + + /* Truncated Option header */ + hdr = (struct ioam6_hdr *)(skb_network_header(skb) + optoff); + if (hdr->opt_len < 2) + goto drop; + + switch (hdr->type) { + case IOAM6_TYPE_PREALLOC: + /* Truncated Pre-allocated Trace header */ + if (hdr->opt_len < 2 + sizeof(*trace)) + goto drop; + + /* Malformed Pre-allocated Trace header */ + trace = (struct ioam6_trace_hdr *)((u8 *)hdr + sizeof(*hdr)); + if (hdr->opt_len < 2 + sizeof(*trace) + trace->remlen * 4) + goto drop; + + /* Ignore if the IOAM namespace is unknown */ + ns = ioam6_namespace(ipv6_skb_net(skb), trace->namespace_id); + if (!ns) + goto ignore; + + if (!skb_valid_dst(skb)) + ip6_route_input(skb); + + ioam6_fill_trace_data(skb, ns, trace, true); + break; + default: + break; + } + +ignore: + return true; + +drop: + kfree_skb_reason(skb, SKB_DROP_REASON_IP_INHDR); + return false; +} + +/* Jumbo payload */ + +static bool ipv6_hop_jumbo(struct sk_buff *skb, int optoff) +{ + const unsigned char *nh = skb_network_header(skb); + SKB_DR(reason); + u32 pkt_len; + + if (nh[optoff + 1] != 4 || (optoff & 3) != 2) { + net_dbg_ratelimited("ipv6_hop_jumbo: wrong jumbo opt length/alignment %d\n", + nh[optoff+1]); + SKB_DR_SET(reason, IP_INHDR); + goto drop; + } + + pkt_len = ntohl(*(__be32 *)(nh + optoff + 2)); + if (pkt_len <= IPV6_MAXPLEN) { + icmpv6_param_prob_reason(skb, ICMPV6_HDR_FIELD, optoff + 2, + SKB_DROP_REASON_IP_INHDR); + return false; + } + if (ipv6_hdr(skb)->payload_len) { + icmpv6_param_prob_reason(skb, ICMPV6_HDR_FIELD, optoff, + SKB_DROP_REASON_IP_INHDR); + return false; + } + + if (pkt_len > skb->len - sizeof(struct ipv6hdr)) { + SKB_DR_SET(reason, PKT_TOO_SMALL); + goto drop; + } + + if (pskb_trim_rcsum(skb, pkt_len + sizeof(struct ipv6hdr))) + goto drop; + + IP6CB(skb)->flags |= IP6SKB_JUMBOGRAM; + return true; + +drop: + kfree_skb_reason(skb, reason); + return false; +} + +/* CALIPSO RFC 5570 */ + +static bool ipv6_hop_calipso(struct sk_buff *skb, int optoff) +{ + const unsigned char *nh = skb_network_header(skb); + + if (nh[optoff + 1] < 8) + goto drop; + + if (nh[optoff + 6] * 4 + 8 > nh[optoff + 1]) + goto drop; + + if (!calipso_validate(skb, nh + optoff)) + goto drop; + + return true; + +drop: + kfree_skb_reason(skb, SKB_DROP_REASON_IP_INHDR); + return false; +} + +int ipv6_parse_hopopts(struct sk_buff *skb) +{ + struct inet6_skb_parm *opt = IP6CB(skb); + struct net *net = dev_net(skb->dev); + int extlen; + + /* + * skb_network_header(skb) is equal to skb->data, and + * skb_network_header_len(skb) is always equal to + * sizeof(struct ipv6hdr) by definition of + * hop-by-hop 
options. + */ + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr) + 8) || + !pskb_may_pull(skb, (sizeof(struct ipv6hdr) + + ((skb_transport_header(skb)[1] + 1) << 3)))) { +fail_and_free: + kfree_skb(skb); + return -1; + } + + extlen = (skb_transport_header(skb)[1] + 1) << 3; + if (extlen > net->ipv6.sysctl.max_hbh_opts_len) + goto fail_and_free; + + opt->flags |= IP6SKB_HOPBYHOP; + if (ip6_parse_tlv(true, skb, net->ipv6.sysctl.max_hbh_opts_cnt)) { + skb->transport_header += extlen; + opt = IP6CB(skb); + opt->nhoff = sizeof(struct ipv6hdr); + return 1; + } + return -1; +} + +/* + * Creating outbound headers. + * + * "build" functions work when skb is filled from head to tail (datagram) + * "push" functions work when headers are added from tail to head (tcp) + * + * In both cases we assume, that caller reserved enough room + * for headers. + */ + +static void ipv6_push_rthdr0(struct sk_buff *skb, u8 *proto, + struct ipv6_rt_hdr *opt, + struct in6_addr **addr_p, struct in6_addr *saddr) +{ + struct rt0_hdr *phdr, *ihdr; + int hops; + + ihdr = (struct rt0_hdr *) opt; + + phdr = skb_push(skb, (ihdr->rt_hdr.hdrlen + 1) << 3); + memcpy(phdr, ihdr, sizeof(struct rt0_hdr)); + + hops = ihdr->rt_hdr.hdrlen >> 1; + + if (hops > 1) + memcpy(phdr->addr, ihdr->addr + 1, + (hops - 1) * sizeof(struct in6_addr)); + + phdr->addr[hops - 1] = **addr_p; + *addr_p = ihdr->addr; + + phdr->rt_hdr.nexthdr = *proto; + *proto = NEXTHDR_ROUTING; +} + +static void ipv6_push_rthdr4(struct sk_buff *skb, u8 *proto, + struct ipv6_rt_hdr *opt, + struct in6_addr **addr_p, struct in6_addr *saddr) +{ + struct ipv6_sr_hdr *sr_phdr, *sr_ihdr; + int plen, hops; + + sr_ihdr = (struct ipv6_sr_hdr *)opt; + plen = (sr_ihdr->hdrlen + 1) << 3; + + sr_phdr = skb_push(skb, plen); + memcpy(sr_phdr, sr_ihdr, sizeof(struct ipv6_sr_hdr)); + + hops = sr_ihdr->first_segment + 1; + memcpy(sr_phdr->segments + 1, sr_ihdr->segments + 1, + (hops - 1) * sizeof(struct in6_addr)); + + sr_phdr->segments[0] = **addr_p; + *addr_p = &sr_ihdr->segments[sr_ihdr->segments_left]; + + if (sr_ihdr->hdrlen > hops * 2) { + int tlvs_offset, tlvs_length; + + tlvs_offset = (1 + hops * 2) << 3; + tlvs_length = (sr_ihdr->hdrlen - hops * 2) << 3; + memcpy((char *)sr_phdr + tlvs_offset, + (char *)sr_ihdr + tlvs_offset, tlvs_length); + } + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (sr_has_hmac(sr_phdr)) { + struct net *net = NULL; + + if (skb->dev) + net = dev_net(skb->dev); + else if (skb->sk) + net = sock_net(skb->sk); + + WARN_ON(!net); + + if (net) + seg6_push_hmac(net, saddr, sr_phdr); + } +#endif + + sr_phdr->nexthdr = *proto; + *proto = NEXTHDR_ROUTING; +} + +static void ipv6_push_rthdr(struct sk_buff *skb, u8 *proto, + struct ipv6_rt_hdr *opt, + struct in6_addr **addr_p, struct in6_addr *saddr) +{ + switch (opt->type) { + case IPV6_SRCRT_TYPE_0: + case IPV6_SRCRT_STRICT: + case IPV6_SRCRT_TYPE_2: + ipv6_push_rthdr0(skb, proto, opt, addr_p, saddr); + break; + case IPV6_SRCRT_TYPE_4: + ipv6_push_rthdr4(skb, proto, opt, addr_p, saddr); + break; + default: + break; + } +} + +static void ipv6_push_exthdr(struct sk_buff *skb, u8 *proto, u8 type, struct ipv6_opt_hdr *opt) +{ + struct ipv6_opt_hdr *h = skb_push(skb, ipv6_optlen(opt)); + + memcpy(h, opt, ipv6_optlen(opt)); + h->nexthdr = *proto; + *proto = type; +} + +void ipv6_push_nfrag_opts(struct sk_buff *skb, struct ipv6_txoptions *opt, + u8 *proto, + struct in6_addr **daddr, struct in6_addr *saddr) +{ + if (opt->srcrt) { + ipv6_push_rthdr(skb, proto, opt->srcrt, daddr, saddr); + /* + * IPV6_RTHDRDSTOPTS is ignored + * unless 
IPV6_RTHDR is set (RFC3542). + */ + if (opt->dst0opt) + ipv6_push_exthdr(skb, proto, NEXTHDR_DEST, opt->dst0opt); + } + if (opt->hopopt) + ipv6_push_exthdr(skb, proto, NEXTHDR_HOP, opt->hopopt); +} + +void ipv6_push_frag_opts(struct sk_buff *skb, struct ipv6_txoptions *opt, u8 *proto) +{ + if (opt->dst1opt) + ipv6_push_exthdr(skb, proto, NEXTHDR_DEST, opt->dst1opt); +} +EXPORT_SYMBOL(ipv6_push_frag_opts); + +struct ipv6_txoptions * +ipv6_dup_options(struct sock *sk, struct ipv6_txoptions *opt) +{ + struct ipv6_txoptions *opt2; + + opt2 = sock_kmalloc(sk, opt->tot_len, GFP_ATOMIC); + if (opt2) { + long dif = (char *)opt2 - (char *)opt; + memcpy(opt2, opt, opt->tot_len); + if (opt2->hopopt) + *((char **)&opt2->hopopt) += dif; + if (opt2->dst0opt) + *((char **)&opt2->dst0opt) += dif; + if (opt2->dst1opt) + *((char **)&opt2->dst1opt) += dif; + if (opt2->srcrt) + *((char **)&opt2->srcrt) += dif; + refcount_set(&opt2->refcnt, 1); + } + return opt2; +} +EXPORT_SYMBOL_GPL(ipv6_dup_options); + +static void ipv6_renew_option(int renewtype, + struct ipv6_opt_hdr **dest, + struct ipv6_opt_hdr *old, + struct ipv6_opt_hdr *new, + int newtype, char **p) +{ + struct ipv6_opt_hdr *src; + + src = (renewtype == newtype ? new : old); + if (!src) + return; + + memcpy(*p, src, ipv6_optlen(src)); + *dest = (struct ipv6_opt_hdr *)*p; + *p += CMSG_ALIGN(ipv6_optlen(*dest)); +} + +/** + * ipv6_renew_options - replace a specific ext hdr with a new one. + * + * @sk: sock from which to allocate memory + * @opt: original options + * @newtype: option type to replace in @opt + * @newopt: new option of type @newtype to replace (user-mem) + * + * Returns a new set of options which is a copy of @opt with the + * option type @newtype replaced with @newopt. + * + * @opt may be NULL, in which case a new set of options is returned + * containing just @newopt. + * + * @newopt may be NULL, in which case the specified option type is + * not copied into the new set of options. + * + * The new set of options is allocated from the socket option memory + * buffer of @sk. + */ +struct ipv6_txoptions * +ipv6_renew_options(struct sock *sk, struct ipv6_txoptions *opt, + int newtype, struct ipv6_opt_hdr *newopt) +{ + int tot_len = 0; + char *p; + struct ipv6_txoptions *opt2; + + if (opt) { + if (newtype != IPV6_HOPOPTS && opt->hopopt) + tot_len += CMSG_ALIGN(ipv6_optlen(opt->hopopt)); + if (newtype != IPV6_RTHDRDSTOPTS && opt->dst0opt) + tot_len += CMSG_ALIGN(ipv6_optlen(opt->dst0opt)); + if (newtype != IPV6_RTHDR && opt->srcrt) + tot_len += CMSG_ALIGN(ipv6_optlen(opt->srcrt)); + if (newtype != IPV6_DSTOPTS && opt->dst1opt) + tot_len += CMSG_ALIGN(ipv6_optlen(opt->dst1opt)); + } + + if (newopt) + tot_len += CMSG_ALIGN(ipv6_optlen(newopt)); + + if (!tot_len) + return NULL; + + tot_len += sizeof(*opt2); + opt2 = sock_kmalloc(sk, tot_len, GFP_ATOMIC); + if (!opt2) + return ERR_PTR(-ENOBUFS); + + memset(opt2, 0, tot_len); + refcount_set(&opt2->refcnt, 1); + opt2->tot_len = tot_len; + p = (char *)(opt2 + 1); + + ipv6_renew_option(IPV6_HOPOPTS, &opt2->hopopt, + (opt ? opt->hopopt : NULL), + newopt, newtype, &p); + ipv6_renew_option(IPV6_RTHDRDSTOPTS, &opt2->dst0opt, + (opt ? opt->dst0opt : NULL), + newopt, newtype, &p); + ipv6_renew_option(IPV6_RTHDR, + (struct ipv6_opt_hdr **)&opt2->srcrt, + (opt ? (struct ipv6_opt_hdr *)opt->srcrt : NULL), + newopt, newtype, &p); + ipv6_renew_option(IPV6_DSTOPTS, &opt2->dst1opt, + (opt ? opt->dst1opt : NULL), + newopt, newtype, &p); + + opt2->opt_nflen = (opt2->hopopt ? 
ipv6_optlen(opt2->hopopt) : 0) + + (opt2->dst0opt ? ipv6_optlen(opt2->dst0opt) : 0) + + (opt2->srcrt ? ipv6_optlen(opt2->srcrt) : 0); + opt2->opt_flen = (opt2->dst1opt ? ipv6_optlen(opt2->dst1opt) : 0); + + return opt2; +} + +struct ipv6_txoptions *__ipv6_fixup_options(struct ipv6_txoptions *opt_space, + struct ipv6_txoptions *opt) +{ + /* + * ignore the dest before srcrt unless srcrt is being included. + * --yoshfuji + */ + if (opt->dst0opt && !opt->srcrt) { + if (opt_space != opt) { + memcpy(opt_space, opt, sizeof(*opt_space)); + opt = opt_space; + } + opt->opt_nflen -= ipv6_optlen(opt->dst0opt); + opt->dst0opt = NULL; + } + + return opt; +} +EXPORT_SYMBOL_GPL(__ipv6_fixup_options); + +/** + * fl6_update_dst - update flowi destination address with info given + * by srcrt option, if any. + * + * @fl6: flowi6 for which daddr is to be updated + * @opt: struct ipv6_txoptions in which to look for srcrt opt + * @orig: copy of original daddr address if modified + * + * Returns NULL if no txoptions or no srcrt, otherwise returns orig + * and initial value of fl6->daddr set in orig + */ +struct in6_addr *fl6_update_dst(struct flowi6 *fl6, + const struct ipv6_txoptions *opt, + struct in6_addr *orig) +{ + if (!opt || !opt->srcrt) + return NULL; + + *orig = fl6->daddr; + + switch (opt->srcrt->type) { + case IPV6_SRCRT_TYPE_0: + case IPV6_SRCRT_STRICT: + case IPV6_SRCRT_TYPE_2: + fl6->daddr = *((struct rt0_hdr *)opt->srcrt)->addr; + break; + case IPV6_SRCRT_TYPE_4: + { + struct ipv6_sr_hdr *srh = (struct ipv6_sr_hdr *)opt->srcrt; + + fl6->daddr = srh->segments[srh->segments_left]; + break; + } + default: + return NULL; + } + + return orig; +} +EXPORT_SYMBOL_GPL(fl6_update_dst); diff --git a/net/ipv6/exthdrs_core.c b/net/ipv6/exthdrs_core.c new file mode 100644 index 000000000..49e31e4ae --- /dev/null +++ b/net/ipv6/exthdrs_core.c @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPv6 library code, needed by static components when full IPv6 support is + * not configured or static. + */ +#include +#include + +/* + * find out if nexthdr is a well-known extension header or a protocol + */ + +bool ipv6_ext_hdr(u8 nexthdr) +{ + /* + * find out if nexthdr is an extension header or a protocol + */ + return (nexthdr == NEXTHDR_HOP) || + (nexthdr == NEXTHDR_ROUTING) || + (nexthdr == NEXTHDR_FRAGMENT) || + (nexthdr == NEXTHDR_AUTH) || + (nexthdr == NEXTHDR_NONE) || + (nexthdr == NEXTHDR_DEST); +} +EXPORT_SYMBOL(ipv6_ext_hdr); + +/* + * Skip any extension headers. This is used by the ICMP module. + * + * Note that strictly speaking this conflicts with RFC 2460 4.0: + * ...The contents and semantics of each extension header determine whether + * or not to proceed to the next header. Therefore, extension headers must + * be processed strictly in the order they appear in the packet; a + * receiver must not, for example, scan through a packet looking for a + * particular kind of extension header and process that header prior to + * processing all preceding ones. + * + * We do exactly this. This is a protocol bug. We can't decide after a + * seeing an unknown discard-with-error flavour TLV option if it's a + * ICMP error message or not (errors should never be send in reply to + * ICMP error messages). + * + * But I see no other way to do this. This might need to be reexamined + * when Linux implements ESP (and maybe AUTH) headers. + * --AK + * + * This function parses (probably truncated) exthdr set "hdr". 
+ * "nexthdrp" initially points to some place, + * where type of the first header can be found. + * + * It skips all well-known exthdrs, and returns pointer to the start + * of unparsable area i.e. the first header with unknown type. + * If it is not NULL *nexthdr is updated by type/protocol of this header. + * + * NOTES: - if packet terminated with NEXTHDR_NONE it returns NULL. + * - it may return pointer pointing beyond end of packet, + * if the last recognized header is truncated in the middle. + * - if packet is truncated, so that all parsed headers are skipped, + * it returns NULL. + * - First fragment header is skipped, not-first ones + * are considered as unparsable. + * - Reports the offset field of the final fragment header so it is + * possible to tell whether this is a first fragment, later fragment, + * or not fragmented. + * - ESP is unparsable for now and considered like + * normal payload protocol. + * - Note also special handling of AUTH header. Thanks to IPsec wizards. + * + * --ANK (980726) + */ + +int ipv6_skip_exthdr(const struct sk_buff *skb, int start, u8 *nexthdrp, + __be16 *frag_offp) +{ + u8 nexthdr = *nexthdrp; + + *frag_offp = 0; + + while (ipv6_ext_hdr(nexthdr)) { + struct ipv6_opt_hdr _hdr, *hp; + int hdrlen; + + if (nexthdr == NEXTHDR_NONE) + return -1; + hp = skb_header_pointer(skb, start, sizeof(_hdr), &_hdr); + if (!hp) + return -1; + if (nexthdr == NEXTHDR_FRAGMENT) { + __be16 _frag_off, *fp; + fp = skb_header_pointer(skb, + start+offsetof(struct frag_hdr, + frag_off), + sizeof(_frag_off), + &_frag_off); + if (!fp) + return -1; + + *frag_offp = *fp; + if (ntohs(*frag_offp) & ~0x7) + break; + hdrlen = 8; + } else if (nexthdr == NEXTHDR_AUTH) + hdrlen = ipv6_authlen(hp); + else + hdrlen = ipv6_optlen(hp); + + nexthdr = hp->nexthdr; + start += hdrlen; + } + + *nexthdrp = nexthdr; + return start; +} +EXPORT_SYMBOL(ipv6_skip_exthdr); + +int ipv6_find_tlv(const struct sk_buff *skb, int offset, int type) +{ + const unsigned char *nh = skb_network_header(skb); + int packet_len = skb_tail_pointer(skb) - skb_network_header(skb); + struct ipv6_opt_hdr *hdr; + int len; + + if (offset + 2 > packet_len) + goto bad; + hdr = (struct ipv6_opt_hdr *)(nh + offset); + len = ((hdr->hdrlen + 1) << 3); + + if (offset + len > packet_len) + goto bad; + + offset += 2; + len -= 2; + + while (len > 0) { + int opttype = nh[offset]; + int optlen; + + if (opttype == type) + return offset; + + switch (opttype) { + case IPV6_TLV_PAD1: + optlen = 1; + break; + default: + if (len < 2) + goto bad; + optlen = nh[offset + 1] + 2; + if (optlen > len) + goto bad; + break; + } + offset += optlen; + len -= optlen; + } + /* not_found */ + bad: + return -1; +} +EXPORT_SYMBOL_GPL(ipv6_find_tlv); + +/* + * find the offset to specified header or the protocol number of last header + * if target < 0. "last header" is transport protocol header, ESP, or + * "No next header". + * + * Note that *offset is used as input/output parameter, and if it is not zero, + * then it must be a valid offset to an inner IPv6 header. This can be used + * to explore inner IPv6 header, eg. ICMPv6 error messages. + * + * If target header is found, its offset is set in *offset and return protocol + * number. Otherwise, return -1. + * + * If the first fragment doesn't contain the final protocol header or + * NEXTHDR_NONE it is considered invalid. + * + * Note that non-1st fragment is special case that "the protocol number + * of last header" is "next header" field in Fragment header. 
In this case, + * *offset is meaningless and fragment offset is stored in *fragoff if fragoff + * isn't NULL. + * + * if flags is not NULL and it's a fragment, then the frag flag + * IP6_FH_F_FRAG will be set. If it's an AH header, the + * IP6_FH_F_AUTH flag is set and target < 0, then this function will + * stop at the AH header. If IP6_FH_F_SKIP_RH flag was passed, then this + * function will skip all those routing headers, where segements_left was 0. + */ +int ipv6_find_hdr(const struct sk_buff *skb, unsigned int *offset, + int target, unsigned short *fragoff, int *flags) +{ + unsigned int start = skb_network_offset(skb) + sizeof(struct ipv6hdr); + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + bool found; + + if (fragoff) + *fragoff = 0; + + if (*offset) { + struct ipv6hdr _ip6, *ip6; + + ip6 = skb_header_pointer(skb, *offset, sizeof(_ip6), &_ip6); + if (!ip6 || (ip6->version != 6)) + return -EBADMSG; + start = *offset + sizeof(struct ipv6hdr); + nexthdr = ip6->nexthdr; + } + + do { + struct ipv6_opt_hdr _hdr, *hp; + unsigned int hdrlen; + found = (nexthdr == target); + + if ((!ipv6_ext_hdr(nexthdr)) || nexthdr == NEXTHDR_NONE) { + if (target < 0 || found) + break; + return -ENOENT; + } + + hp = skb_header_pointer(skb, start, sizeof(_hdr), &_hdr); + if (!hp) + return -EBADMSG; + + if (nexthdr == NEXTHDR_ROUTING) { + struct ipv6_rt_hdr _rh, *rh; + + rh = skb_header_pointer(skb, start, sizeof(_rh), + &_rh); + if (!rh) + return -EBADMSG; + + if (flags && (*flags & IP6_FH_F_SKIP_RH) && + rh->segments_left == 0) + found = false; + } + + if (nexthdr == NEXTHDR_FRAGMENT) { + unsigned short _frag_off; + __be16 *fp; + + if (flags) /* Indicate that this is a fragment */ + *flags |= IP6_FH_F_FRAG; + fp = skb_header_pointer(skb, + start+offsetof(struct frag_hdr, + frag_off), + sizeof(_frag_off), + &_frag_off); + if (!fp) + return -EBADMSG; + + _frag_off = ntohs(*fp) & ~0x7; + if (_frag_off) { + if (target < 0 && + ((!ipv6_ext_hdr(hp->nexthdr)) || + hp->nexthdr == NEXTHDR_NONE)) { + if (fragoff) + *fragoff = _frag_off; + return hp->nexthdr; + } + if (!found) + return -ENOENT; + if (fragoff) + *fragoff = _frag_off; + break; + } + hdrlen = 8; + } else if (nexthdr == NEXTHDR_AUTH) { + if (flags && (*flags & IP6_FH_F_AUTH) && (target < 0)) + break; + hdrlen = ipv6_authlen(hp); + } else + hdrlen = ipv6_optlen(hp); + + if (!found) { + nexthdr = hp->nexthdr; + start += hdrlen; + } + } while (!found); + + *offset = start; + return nexthdr; +} +EXPORT_SYMBOL(ipv6_find_hdr); diff --git a/net/ipv6/exthdrs_offload.c b/net/ipv6/exthdrs_offload.c new file mode 100644 index 000000000..06750d65d --- /dev/null +++ b/net/ipv6/exthdrs_offload.c @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV6 GSO/GRO offload support + * Linux INET6 implementation + * + * IPV6 Extension Header GSO/GRO support + */ +#include +#include "ip6_offload.h" + +static const struct net_offload rthdr_offload = { + .flags = INET6_PROTO_GSO_EXTHDR, +}; + +static const struct net_offload dstopt_offload = { + .flags = INET6_PROTO_GSO_EXTHDR, +}; + +int __init ipv6_exthdrs_offload_init(void) +{ + int ret; + + ret = inet6_add_offload(&rthdr_offload, IPPROTO_ROUTING); + if (ret) + goto out; + + ret = inet6_add_offload(&dstopt_offload, IPPROTO_DSTOPTS); + if (ret) + goto out_rt; + +out: + return ret; + +out_rt: + inet6_del_offload(&rthdr_offload, IPPROTO_ROUTING); + goto out; +} diff --git a/net/ipv6/fib6_notifier.c b/net/ipv6/fib6_notifier.c new file mode 100644 index 000000000..f87ae33e1 --- /dev/null +++ b/net/ipv6/fib6_notifier.c 
@@ -0,0 +1,64 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +int call_fib6_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + info->family = AF_INET6; + return call_fib_notifier(nb, event_type, info); +} + +int call_fib6_notifiers(struct net *net, enum fib_event_type event_type, + struct fib_notifier_info *info) +{ + info->family = AF_INET6; + return call_fib_notifiers(net, event_type, info); +} + +static unsigned int fib6_seq_read(struct net *net) +{ + return fib6_tables_seq_read(net) + fib6_rules_seq_read(net); +} + +static int fib6_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + int err; + + err = fib6_rules_dump(net, nb, extack); + if (err) + return err; + + return fib6_tables_dump(net, nb, extack); +} + +static const struct fib_notifier_ops fib6_notifier_ops_template = { + .family = AF_INET6, + .fib_seq_read = fib6_seq_read, + .fib_dump = fib6_dump, + .owner = THIS_MODULE, +}; + +int __net_init fib6_notifier_init(struct net *net) +{ + struct fib_notifier_ops *ops; + + ops = fib_notifier_ops_register(&fib6_notifier_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + net->ipv6.notifier_ops = ops; + + return 0; +} + +void __net_exit fib6_notifier_exit(struct net *net) +{ + fib_notifier_ops_unregister(net->ipv6.notifier_ops); +} diff --git a/net/ipv6/fib6_rules.c b/net/ipv6/fib6_rules.c new file mode 100644 index 000000000..7c2003833 --- /dev/null +++ b/net/ipv6/fib6_rules.c @@ -0,0 +1,522 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/ipv6/fib6_rules.c IPv6 Routing Policy Rules + * + * Copyright (C)2003-2006 Helsinki University of Technology + * Copyright (C)2003-2006 USAGI/WIDE Project + * + * Authors + * Thomas Graf + * Ville Nuorvala + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +struct fib6_rule { + struct fib_rule common; + struct rt6key src; + struct rt6key dst; + dscp_t dscp; +}; + +static bool fib6_rule_matchall(const struct fib_rule *rule) +{ + struct fib6_rule *r = container_of(rule, struct fib6_rule, common); + + if (r->dst.plen || r->src.plen || r->dscp) + return false; + return fib_rule_matchall(rule); +} + +bool fib6_rule_default(const struct fib_rule *rule) +{ + if (!fib6_rule_matchall(rule) || rule->action != FR_ACT_TO_TBL || + rule->l3mdev) + return false; + if (rule->table != RT6_TABLE_LOCAL && rule->table != RT6_TABLE_MAIN) + return false; + return true; +} +EXPORT_SYMBOL_GPL(fib6_rule_default); + +int fib6_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return fib_rules_dump(net, nb, AF_INET6, extack); +} + +unsigned int fib6_rules_seq_read(struct net *net) +{ + return fib_rules_seq_read(net, AF_INET6); +} + +/* called with rcu lock held; no reference taken on fib6_info */ +int fib6_lookup(struct net *net, int oif, struct flowi6 *fl6, + struct fib6_result *res, int flags) +{ + int err; + + if (net->ipv6.fib6_has_custom_rules) { + struct fib_lookup_arg arg = { + .lookup_ptr = fib6_table_lookup, + .lookup_data = &oif, + .result = res, + .flags = FIB_LOOKUP_NOREF, + }; + + l3mdev_update_flow(net, flowi6_to_flowi(fl6)); + + err = fib_rules_lookup(net->ipv6.fib6_rules_ops, + flowi6_to_flowi(fl6), flags, &arg); + } else { + err = fib6_table_lookup(net, net->ipv6.fib6_local_tbl, oif, + fl6, res, flags); + if (err || res->f6i == net->ipv6.fib6_null_entry) + err = fib6_table_lookup(net, net->ipv6.fib6_main_tbl, 
+ oif, fl6, res, flags); + } + + return err; +} + +struct dst_entry *fib6_rule_lookup(struct net *net, struct flowi6 *fl6, + const struct sk_buff *skb, + int flags, pol_lookup_t lookup) +{ + if (net->ipv6.fib6_has_custom_rules) { + struct fib6_result res = {}; + struct fib_lookup_arg arg = { + .lookup_ptr = lookup, + .lookup_data = skb, + .result = &res, + .flags = FIB_LOOKUP_NOREF, + }; + + /* update flow if oif or iif point to device enslaved to l3mdev */ + l3mdev_update_flow(net, flowi6_to_flowi(fl6)); + + fib_rules_lookup(net->ipv6.fib6_rules_ops, + flowi6_to_flowi(fl6), flags, &arg); + + if (res.rt6) + return &res.rt6->dst; + } else { + struct rt6_info *rt; + + rt = pol_lookup_func(lookup, + net, net->ipv6.fib6_local_tbl, fl6, skb, flags); + if (rt != net->ipv6.ip6_null_entry && rt->dst.error != -EAGAIN) + return &rt->dst; + ip6_rt_put_flags(rt, flags); + rt = pol_lookup_func(lookup, + net, net->ipv6.fib6_main_tbl, fl6, skb, flags); + if (rt->dst.error != -EAGAIN) + return &rt->dst; + ip6_rt_put_flags(rt, flags); + } + + if (!(flags & RT6_LOOKUP_F_DST_NOREF)) + dst_hold(&net->ipv6.ip6_null_entry->dst); + return &net->ipv6.ip6_null_entry->dst; +} + +static int fib6_rule_saddr(struct net *net, struct fib_rule *rule, int flags, + struct flowi6 *flp6, const struct net_device *dev) +{ + struct fib6_rule *r = (struct fib6_rule *)rule; + + /* If we need to find a source address for this traffic, + * we check the result if it meets requirement of the rule. + */ + if ((rule->flags & FIB_RULE_FIND_SADDR) && + r->src.plen && !(flags & RT6_LOOKUP_F_HAS_SADDR)) { + struct in6_addr saddr; + + if (ipv6_dev_get_saddr(net, dev, &flp6->daddr, + rt6_flags2srcprefs(flags), &saddr)) + return -EAGAIN; + + if (!ipv6_prefix_equal(&saddr, &r->src.addr, r->src.plen)) + return -EAGAIN; + + flp6->saddr = saddr; + } + + return 0; +} + +static int fib6_rule_action_alt(struct fib_rule *rule, struct flowi *flp, + int flags, struct fib_lookup_arg *arg) +{ + struct fib6_result *res = arg->result; + struct flowi6 *flp6 = &flp->u.ip6; + struct net *net = rule->fr_net; + struct fib6_table *table; + int err, *oif; + u32 tb_id; + + switch (rule->action) { + case FR_ACT_TO_TBL: + break; + case FR_ACT_UNREACHABLE: + return -ENETUNREACH; + case FR_ACT_PROHIBIT: + return -EACCES; + case FR_ACT_BLACKHOLE: + default: + return -EINVAL; + } + + tb_id = fib_rule_get_table(rule, arg); + table = fib6_get_table(net, tb_id); + if (!table) + return -EAGAIN; + + oif = (int *)arg->lookup_data; + err = fib6_table_lookup(net, table, *oif, flp6, res, flags); + if (!err && res->f6i != net->ipv6.fib6_null_entry) + err = fib6_rule_saddr(net, rule, flags, flp6, + res->nh->fib_nh_dev); + else + err = -EAGAIN; + + return err; +} + +static int __fib6_rule_action(struct fib_rule *rule, struct flowi *flp, + int flags, struct fib_lookup_arg *arg) +{ + struct fib6_result *res = arg->result; + struct flowi6 *flp6 = &flp->u.ip6; + struct rt6_info *rt = NULL; + struct fib6_table *table; + struct net *net = rule->fr_net; + pol_lookup_t lookup = arg->lookup_ptr; + int err = 0; + u32 tb_id; + + switch (rule->action) { + case FR_ACT_TO_TBL: + break; + case FR_ACT_UNREACHABLE: + err = -ENETUNREACH; + rt = net->ipv6.ip6_null_entry; + goto discard_pkt; + default: + case FR_ACT_BLACKHOLE: + err = -EINVAL; + rt = net->ipv6.ip6_blk_hole_entry; + goto discard_pkt; + case FR_ACT_PROHIBIT: + err = -EACCES; + rt = net->ipv6.ip6_prohibit_entry; + goto discard_pkt; + } + + tb_id = fib_rule_get_table(rule, arg); + table = fib6_get_table(net, tb_id); + if (!table) { + err 
= -EAGAIN; + goto out; + } + + rt = pol_lookup_func(lookup, + net, table, flp6, arg->lookup_data, flags); + if (rt != net->ipv6.ip6_null_entry) { + err = fib6_rule_saddr(net, rule, flags, flp6, + ip6_dst_idev(&rt->dst)->dev); + + if (err == -EAGAIN) + goto again; + + err = rt->dst.error; + if (err != -EAGAIN) + goto out; + } +again: + ip6_rt_put_flags(rt, flags); + err = -EAGAIN; + rt = NULL; + goto out; + +discard_pkt: + if (!(flags & RT6_LOOKUP_F_DST_NOREF)) + dst_hold(&rt->dst); +out: + res->rt6 = rt; + return err; +} + +INDIRECT_CALLABLE_SCOPE int fib6_rule_action(struct fib_rule *rule, + struct flowi *flp, int flags, + struct fib_lookup_arg *arg) +{ + if (arg->lookup_ptr == fib6_table_lookup) + return fib6_rule_action_alt(rule, flp, flags, arg); + + return __fib6_rule_action(rule, flp, flags, arg); +} + +INDIRECT_CALLABLE_SCOPE bool fib6_rule_suppress(struct fib_rule *rule, + int flags, + struct fib_lookup_arg *arg) +{ + struct fib6_result *res = arg->result; + struct rt6_info *rt = res->rt6; + struct net_device *dev = NULL; + + if (!rt) + return false; + + if (rt->rt6i_idev) + dev = rt->rt6i_idev->dev; + + /* do not accept result if the route does + * not meet the required prefix length + */ + if (rt->rt6i_dst.plen <= rule->suppress_prefixlen) + goto suppress_route; + + /* do not accept result if the route uses a device + * belonging to a forbidden interface group + */ + if (rule->suppress_ifgroup != -1 && dev && dev->group == rule->suppress_ifgroup) + goto suppress_route; + + return false; + +suppress_route: + ip6_rt_put_flags(rt, flags); + return true; +} + +INDIRECT_CALLABLE_SCOPE int fib6_rule_match(struct fib_rule *rule, + struct flowi *fl, int flags) +{ + struct fib6_rule *r = (struct fib6_rule *) rule; + struct flowi6 *fl6 = &fl->u.ip6; + + if (r->dst.plen && + !ipv6_prefix_equal(&fl6->daddr, &r->dst.addr, r->dst.plen)) + return 0; + + /* + * If FIB_RULE_FIND_SADDR is set and we do not have a + * source address for the traffic, we defer check for + * source address. 
+ */ + if (r->src.plen) { + if (flags & RT6_LOOKUP_F_HAS_SADDR) { + if (!ipv6_prefix_equal(&fl6->saddr, &r->src.addr, + r->src.plen)) + return 0; + } else if (!(r->common.flags & FIB_RULE_FIND_SADDR)) + return 0; + } + + if (r->dscp && r->dscp != ip6_dscp(fl6->flowlabel)) + return 0; + + if (rule->ip_proto && (rule->ip_proto != fl6->flowi6_proto)) + return 0; + + if (fib_rule_port_range_set(&rule->sport_range) && + !fib_rule_port_inrange(&rule->sport_range, fl6->fl6_sport)) + return 0; + + if (fib_rule_port_range_set(&rule->dport_range) && + !fib_rule_port_inrange(&rule->dport_range, fl6->fl6_dport)) + return 0; + + return 1; +} + +static int fib6_rule_configure(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int err = -EINVAL; + struct net *net = sock_net(skb->sk); + struct fib6_rule *rule6 = (struct fib6_rule *) rule; + + if (!inet_validate_dscp(frh->tos)) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): ECN bits must be 0"); + goto errout; + } + rule6->dscp = inet_dsfield_to_dscp(frh->tos); + + if (rule->action == FR_ACT_TO_TBL && !rule->l3mdev) { + if (rule->table == RT6_TABLE_UNSPEC) { + NL_SET_ERR_MSG(extack, "Invalid table"); + goto errout; + } + + if (fib6_new_table(net, rule->table) == NULL) { + err = -ENOBUFS; + goto errout; + } + } + + if (frh->src_len) + rule6->src.addr = nla_get_in6_addr(tb[FRA_SRC]); + + if (frh->dst_len) + rule6->dst.addr = nla_get_in6_addr(tb[FRA_DST]); + + rule6->src.plen = frh->src_len; + rule6->dst.plen = frh->dst_len; + + if (fib_rule_requires_fldissect(rule)) + net->ipv6.fib6_rules_require_fldissect++; + + net->ipv6.fib6_has_custom_rules = true; + err = 0; +errout: + return err; +} + +static int fib6_rule_delete(struct fib_rule *rule) +{ + struct net *net = rule->fr_net; + + if (net->ipv6.fib6_rules_require_fldissect && + fib_rule_requires_fldissect(rule)) + net->ipv6.fib6_rules_require_fldissect--; + + return 0; +} + +static int fib6_rule_compare(struct fib_rule *rule, struct fib_rule_hdr *frh, + struct nlattr **tb) +{ + struct fib6_rule *rule6 = (struct fib6_rule *) rule; + + if (frh->src_len && (rule6->src.plen != frh->src_len)) + return 0; + + if (frh->dst_len && (rule6->dst.plen != frh->dst_len)) + return 0; + + if (frh->tos && inet_dscp_to_dsfield(rule6->dscp) != frh->tos) + return 0; + + if (frh->src_len && + nla_memcmp(tb[FRA_SRC], &rule6->src.addr, sizeof(struct in6_addr))) + return 0; + + if (frh->dst_len && + nla_memcmp(tb[FRA_DST], &rule6->dst.addr, sizeof(struct in6_addr))) + return 0; + + return 1; +} + +static int fib6_rule_fill(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh) +{ + struct fib6_rule *rule6 = (struct fib6_rule *) rule; + + frh->dst_len = rule6->dst.plen; + frh->src_len = rule6->src.plen; + frh->tos = inet_dscp_to_dsfield(rule6->dscp); + + if ((rule6->dst.plen && + nla_put_in6_addr(skb, FRA_DST, &rule6->dst.addr)) || + (rule6->src.plen && + nla_put_in6_addr(skb, FRA_SRC, &rule6->src.addr))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOBUFS; +} + +static size_t fib6_rule_nlmsg_payload(struct fib_rule *rule) +{ + return nla_total_size(16) /* dst */ + + nla_total_size(16); /* src */ +} + +static const struct fib_rules_ops __net_initconst fib6_rules_ops_template = { + .family = AF_INET6, + .rule_size = sizeof(struct fib6_rule), + .addr_size = sizeof(struct in6_addr), + .action = fib6_rule_action, + .match = fib6_rule_match, + .suppress = fib6_rule_suppress, + .configure = fib6_rule_configure, + 
.delete = fib6_rule_delete, + .compare = fib6_rule_compare, + .fill = fib6_rule_fill, + .nlmsg_payload = fib6_rule_nlmsg_payload, + .nlgroup = RTNLGRP_IPV6_RULE, + .owner = THIS_MODULE, + .fro_net = &init_net, +}; + +static int __net_init fib6_rules_net_init(struct net *net) +{ + struct fib_rules_ops *ops; + int err; + + ops = fib_rules_register(&fib6_rules_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + err = fib_default_rule_add(ops, 0, RT6_TABLE_LOCAL, 0); + if (err) + goto out_fib6_rules_ops; + + err = fib_default_rule_add(ops, 0x7FFE, RT6_TABLE_MAIN, 0); + if (err) + goto out_fib6_rules_ops; + + net->ipv6.fib6_rules_ops = ops; + net->ipv6.fib6_rules_require_fldissect = 0; +out: + return err; + +out_fib6_rules_ops: + fib_rules_unregister(ops); + goto out; +} + +static void __net_exit fib6_rules_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + fib_rules_unregister(net->ipv6.fib6_rules_ops); + cond_resched(); + } + rtnl_unlock(); +} + +static struct pernet_operations fib6_rules_net_ops = { + .init = fib6_rules_net_init, + .exit_batch = fib6_rules_net_exit_batch, +}; + +int __init fib6_rules_init(void) +{ + return register_pernet_subsys(&fib6_rules_net_ops); +} + + +void fib6_rules_cleanup(void) +{ + unregister_pernet_subsys(&fib6_rules_net_ops); +} diff --git a/net/ipv6/fou6.c b/net/ipv6/fou6.c new file mode 100644 index 000000000..430518ae2 --- /dev/null +++ b/net/ipv6/fou6.c @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6_FOU_TUNNEL) + +static void fou6_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e, + struct flowi6 *fl6, u8 *protocol, __be16 sport) +{ + struct udphdr *uh; + + skb_push(skb, sizeof(struct udphdr)); + skb_reset_transport_header(skb); + + uh = udp_hdr(skb); + + uh->dest = e->dport; + uh->source = sport; + uh->len = htons(skb->len); + udp6_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM6), skb, + &fl6->saddr, &fl6->daddr, skb->len); + + *protocol = IPPROTO_UDP; +} + +static int fou6_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, struct flowi6 *fl6) +{ + __be16 sport; + int err; + int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM6 ? + SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL; + + err = __fou_build_header(skb, e, protocol, &sport, type); + if (err) + return err; + + fou6_build_udp(skb, e, fl6, protocol, sport); + + return 0; +} + +static int gue6_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e, + u8 *protocol, struct flowi6 *fl6) +{ + __be16 sport; + int err; + int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM6 ? 
+ SKB_GSO_UDP_TUNNEL_CSUM : SKB_GSO_UDP_TUNNEL; + + err = __gue_build_header(skb, e, protocol, &sport, type); + if (err) + return err; + + fou6_build_udp(skb, e, fl6, protocol, sport); + + return 0; +} + +static int gue6_err_proto_handler(int proto, struct sk_buff *skb, + struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + const struct inet6_protocol *ipprot; + + ipprot = rcu_dereference(inet6_protos[proto]); + if (ipprot && ipprot->err_handler) { + if (!ipprot->err_handler(skb, opt, type, code, offset, info)) + return 0; + } + + return -ENOENT; +} + +static int gue6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + int transport_offset = skb_transport_offset(skb); + struct guehdr *guehdr; + size_t len, optlen; + int ret; + + len = sizeof(struct udphdr) + sizeof(struct guehdr); + if (!pskb_may_pull(skb, transport_offset + len)) + return -EINVAL; + + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + + switch (guehdr->version) { + case 0: /* Full GUE header present */ + break; + case 1: { + /* Direct encasulation of IPv4 or IPv6 */ + skb_set_transport_header(skb, -(int)sizeof(struct icmp6hdr)); + + switch (((struct iphdr *)guehdr)->version) { + case 4: + ret = gue6_err_proto_handler(IPPROTO_IPIP, skb, opt, + type, code, offset, info); + goto out; + case 6: + ret = gue6_err_proto_handler(IPPROTO_IPV6, skb, opt, + type, code, offset, info); + goto out; + default: + ret = -EOPNOTSUPP; + goto out; + } + } + default: /* Undefined version */ + return -EOPNOTSUPP; + } + + if (guehdr->control) + return -ENOENT; + + optlen = guehdr->hlen << 2; + + if (!pskb_may_pull(skb, transport_offset + len + optlen)) + return -EINVAL; + + guehdr = (struct guehdr *)&udp_hdr(skb)[1]; + if (validate_gue_flags(guehdr, optlen)) + return -EINVAL; + + /* Handling exceptions for direct UDP encapsulation in GUE would lead to + * recursion. Besides, this kind of encapsulation can't even be + * configured currently. Discard this. 
+ */ + if (guehdr->proto_ctype == IPPROTO_UDP || + guehdr->proto_ctype == IPPROTO_UDPLITE) + return -EOPNOTSUPP; + + skb_set_transport_header(skb, -(int)sizeof(struct icmp6hdr)); + ret = gue6_err_proto_handler(guehdr->proto_ctype, skb, + opt, type, code, offset, info); + +out: + skb_set_transport_header(skb, transport_offset); + return ret; +} + + +static const struct ip6_tnl_encap_ops fou_ip6tun_ops = { + .encap_hlen = fou_encap_hlen, + .build_header = fou6_build_header, + .err_handler = gue6_err, +}; + +static const struct ip6_tnl_encap_ops gue_ip6tun_ops = { + .encap_hlen = gue_encap_hlen, + .build_header = gue6_build_header, + .err_handler = gue6_err, +}; + +static int ip6_tnl_encap_add_fou_ops(void) +{ + int ret; + + ret = ip6_tnl_encap_add_ops(&fou_ip6tun_ops, TUNNEL_ENCAP_FOU); + if (ret < 0) { + pr_err("can't add fou6 ops\n"); + return ret; + } + + ret = ip6_tnl_encap_add_ops(&gue_ip6tun_ops, TUNNEL_ENCAP_GUE); + if (ret < 0) { + pr_err("can't add gue6 ops\n"); + ip6_tnl_encap_del_ops(&fou_ip6tun_ops, TUNNEL_ENCAP_FOU); + return ret; + } + + return 0; +} + +static void ip6_tnl_encap_del_fou_ops(void) +{ + ip6_tnl_encap_del_ops(&fou_ip6tun_ops, TUNNEL_ENCAP_FOU); + ip6_tnl_encap_del_ops(&gue_ip6tun_ops, TUNNEL_ENCAP_GUE); +} + +#else + +static int ip6_tnl_encap_add_fou_ops(void) +{ + return 0; +} + +static void ip6_tnl_encap_del_fou_ops(void) +{ +} + +#endif + +static int __init fou6_init(void) +{ + int ret; + + ret = ip6_tnl_encap_add_fou_ops(); + + return ret; +} + +static void __exit fou6_fini(void) +{ + ip6_tnl_encap_del_fou_ops(); +} + +module_init(fou6_init); +module_exit(fou6_fini); +MODULE_AUTHOR("Tom Herbert "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Foo over UDP (IPv6)"); diff --git a/net/ipv6/icmp.c b/net/ipv6/icmp.c new file mode 100644 index 000000000..e2af7ab99 --- /dev/null +++ b/net/ipv6/icmp.c @@ -0,0 +1,1208 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Internet Control Message Protocol (ICMPv6) + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on net/ipv4/icmp.c + * + * RFC 1885 + */ + +/* + * Changes: + * + * Andi Kleen : exception handling + * Andi Kleen add rate limits. never reply to a icmp. + * add more length checks and other fixes. + * yoshfuji : ensure to sent parameter problem for + * fragments. + * YOSHIFUJI Hideaki @USAGI: added sysctl for icmp rate limit. 
+ * Randy Dunlap and + * YOSHIFUJI Hideaki @USAGI: Per-interface statistics support + * Kazunori MIYAZAWA @USAGI: change output process to use ip6_append_data + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_SYSCTL +#include +#endif + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static DEFINE_PER_CPU(struct sock *, ipv6_icmp_sk); + +static int icmpv6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + /* icmpv6_notify checks 8 bytes can be pulled, icmp6hdr is 8 bytes */ + struct icmp6hdr *icmp6 = (struct icmp6hdr *) (skb->data + offset); + struct net *net = dev_net(skb->dev); + + if (type == ICMPV6_PKT_TOOBIG) + ip6_update_pmtu(skb, net, info, skb->dev->ifindex, 0, sock_net_uid(net, NULL)); + else if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + + if (!(type & ICMPV6_INFOMSG_MASK)) + if (icmp6->icmp6_type == ICMPV6_ECHO_REQUEST) + ping_err(skb, offset, ntohl(info)); + + return 0; +} + +static int icmpv6_rcv(struct sk_buff *skb); + +static const struct inet6_protocol icmpv6_protocol = { + .handler = icmpv6_rcv, + .err_handler = icmpv6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +/* Called with BH disabled */ +static struct sock *icmpv6_xmit_lock(struct net *net) +{ + struct sock *sk; + + sk = this_cpu_read(ipv6_icmp_sk); + if (unlikely(!spin_trylock(&sk->sk_lock.slock))) { + /* This can happen if the output path (f.e. SIT or + * ip6ip6 tunnel) signals dst_link_failure() for an + * outgoing ICMP6 packet. + */ + return NULL; + } + sock_net_set(sk, net); + return sk; +} + +static void icmpv6_xmit_unlock(struct sock *sk) +{ + sock_net_set(sk, &init_net); + spin_unlock(&sk->sk_lock.slock); +} + +/* + * Figure out, may we reply to this packet with icmp error. + * + * We do not reply, if: + * - it was icmp error message. + * - it is truncated, so that it is known, that protocol is ICMPV6 + * (i.e. in the middle of some exthdr) + * + * --ANK (980726) + */ + +static bool is_ineligible(const struct sk_buff *skb) +{ + int ptr = (u8 *)(ipv6_hdr(skb) + 1) - skb->data; + int len = skb->len - ptr; + __u8 nexthdr = ipv6_hdr(skb)->nexthdr; + __be16 frag_off; + + if (len < 0) + return true; + + ptr = ipv6_skip_exthdr(skb, ptr, &nexthdr, &frag_off); + if (ptr < 0) + return false; + if (nexthdr == IPPROTO_ICMPV6) { + u8 _type, *tp; + tp = skb_header_pointer(skb, + ptr+offsetof(struct icmp6hdr, icmp6_type), + sizeof(_type), &_type); + + /* Based on RFC 8200, Section 4.5 Fragment Header, return + * false if this is a fragment packet with no icmp header info. + */ + if (!tp && frag_off != 0) + return false; + else if (!tp || !(*tp & ICMPV6_INFOMSG_MASK)) + return true; + } + return false; +} + +static bool icmpv6_mask_allow(struct net *net, int type) +{ + if (type > ICMPV6_MSG_MAX) + return true; + + /* Limit if icmp type is set in ratemask. 
*/ + if (!test_bit(type, net->ipv6.sysctl.icmpv6_ratemask)) + return true; + + return false; +} + +static bool icmpv6_global_allow(struct net *net, int type) +{ + if (icmpv6_mask_allow(net, type)) + return true; + + if (icmp_global_allow()) + return true; + + return false; +} + +/* + * Check the ICMP output rate limit + */ +static bool icmpv6_xrlim_allow(struct sock *sk, u8 type, + struct flowi6 *fl6) +{ + struct net *net = sock_net(sk); + struct dst_entry *dst; + bool res = false; + + if (icmpv6_mask_allow(net, type)) + return true; + + /* + * Look up the output route. + * XXX: perhaps the expire for routing entries cloned by + * this lookup should be more aggressive (not longer than timeout). + */ + dst = ip6_route_output(net, sk, fl6); + if (dst->error) { + IP6_INC_STATS(net, ip6_dst_idev(dst), + IPSTATS_MIB_OUTNOROUTES); + } else if (dst->dev && (dst->dev->flags&IFF_LOOPBACK)) { + res = true; + } else { + struct rt6_info *rt = (struct rt6_info *)dst; + int tmo = net->ipv6.sysctl.icmpv6_time; + struct inet_peer *peer; + + /* Give more bandwidth to wider prefixes. */ + if (rt->rt6i_dst.plen < 128) + tmo >>= ((128 - rt->rt6i_dst.plen)>>5); + + peer = inet_getpeer_v6(net->ipv6.peers, &fl6->daddr, 1); + res = inet_peer_xrlim_allow(peer, tmo); + if (peer) + inet_putpeer(peer); + } + dst_release(dst); + return res; +} + +static bool icmpv6_rt_has_prefsrc(struct sock *sk, u8 type, + struct flowi6 *fl6) +{ + struct net *net = sock_net(sk); + struct dst_entry *dst; + bool res = false; + + dst = ip6_route_output(net, sk, fl6); + if (!dst->error) { + struct rt6_info *rt = (struct rt6_info *)dst; + struct in6_addr prefsrc; + + rt6_get_prefsrc(rt, &prefsrc); + res = !ipv6_addr_any(&prefsrc); + } + dst_release(dst); + return res; +} + +/* + * an inline helper for the "simple" if statement below + * checks if parameter problem report is caused by an + * unrecognized IPv6 option that has the Option Type + * highest-order two bits set to 10 + */ + +static bool opt_unrec(struct sk_buff *skb, __u32 offset) +{ + u8 _optval, *op; + + offset += skb_network_offset(skb); + op = skb_header_pointer(skb, offset, sizeof(_optval), &_optval); + if (!op) + return true; + return (*op & 0xC0) == 0x80; +} + +void icmpv6_push_pending_frames(struct sock *sk, struct flowi6 *fl6, + struct icmp6hdr *thdr, int len) +{ + struct sk_buff *skb; + struct icmp6hdr *icmp6h; + + skb = skb_peek(&sk->sk_write_queue); + if (!skb) + return; + + icmp6h = icmp6_hdr(skb); + memcpy(icmp6h, thdr, sizeof(struct icmp6hdr)); + icmp6h->icmp6_cksum = 0; + + if (skb_queue_len(&sk->sk_write_queue) == 1) { + skb->csum = csum_partial(icmp6h, + sizeof(struct icmp6hdr), skb->csum); + icmp6h->icmp6_cksum = csum_ipv6_magic(&fl6->saddr, + &fl6->daddr, + len, fl6->flowi6_proto, + skb->csum); + } else { + __wsum tmp_csum = 0; + + skb_queue_walk(&sk->sk_write_queue, skb) { + tmp_csum = csum_add(tmp_csum, skb->csum); + } + + tmp_csum = csum_partial(icmp6h, + sizeof(struct icmp6hdr), tmp_csum); + icmp6h->icmp6_cksum = csum_ipv6_magic(&fl6->saddr, + &fl6->daddr, + len, fl6->flowi6_proto, + tmp_csum); + } + ip6_push_pending_frames(sk); +} + +struct icmpv6_msg { + struct sk_buff *skb; + int offset; + uint8_t type; +}; + +static int icmpv6_getfrag(void *from, char *to, int offset, int len, int odd, struct sk_buff *skb) +{ + struct icmpv6_msg *msg = (struct icmpv6_msg *) from; + struct sk_buff *org_skb = msg->skb; + __wsum csum; + + csum = skb_copy_and_csum_bits(org_skb, msg->offset + offset, + to, len); + skb->csum = csum_block_add(skb->csum, csum, odd); + if 
(!(msg->type & ICMPV6_INFOMSG_MASK)) + nf_ct_attach(skb, org_skb); + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +static void mip6_addr_swap(struct sk_buff *skb, const struct inet6_skb_parm *opt) +{ + struct ipv6hdr *iph = ipv6_hdr(skb); + struct ipv6_destopt_hao *hao; + struct in6_addr tmp; + int off; + + if (opt->dsthao) { + off = ipv6_find_tlv(skb, opt->dsthao, IPV6_TLV_HAO); + if (likely(off >= 0)) { + hao = (struct ipv6_destopt_hao *) + (skb_network_header(skb) + off); + tmp = iph->saddr; + iph->saddr = hao->addr; + hao->addr = tmp; + } + } +} +#else +static inline void mip6_addr_swap(struct sk_buff *skb, const struct inet6_skb_parm *opt) {} +#endif + +static struct dst_entry *icmpv6_route_lookup(struct net *net, + struct sk_buff *skb, + struct sock *sk, + struct flowi6 *fl6) +{ + struct dst_entry *dst, *dst2; + struct flowi6 fl2; + int err; + + err = ip6_dst_lookup(net, sk, &dst, fl6); + if (err) + return ERR_PTR(err); + + /* + * We won't send icmp if the destination is known + * anycast. + */ + if (ipv6_anycast_destination(dst, &fl6->daddr)) { + net_dbg_ratelimited("icmp6_send: acast source\n"); + dst_release(dst); + return ERR_PTR(-EINVAL); + } + + /* No need to clone since we're just using its address. */ + dst2 = dst; + + dst = xfrm_lookup(net, dst, flowi6_to_flowi(fl6), sk, 0); + if (!IS_ERR(dst)) { + if (dst != dst2) + return dst; + } else { + if (PTR_ERR(dst) == -EPERM) + dst = NULL; + else + return dst; + } + + err = xfrm_decode_session_reverse(skb, flowi6_to_flowi(&fl2), AF_INET6); + if (err) + goto relookup_failed; + + err = ip6_dst_lookup(net, sk, &dst2, &fl2); + if (err) + goto relookup_failed; + + dst2 = xfrm_lookup(net, dst2, flowi6_to_flowi(&fl2), sk, XFRM_LOOKUP_ICMP); + if (!IS_ERR(dst2)) { + dst_release(dst); + dst = dst2; + } else { + err = PTR_ERR(dst2); + if (err == -EPERM) { + dst_release(dst); + return dst2; + } else + goto relookup_failed; + } + +relookup_failed: + if (dst) + return dst; + return ERR_PTR(err); +} + +static struct net_device *icmp6_dev(const struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + + /* for local traffic to local address, skb dev is the loopback + * device. Check if there is a dst attached to the skb and if so + * get the real device index. Same is needed for replies to a link + * local address on a device enslaved to an L3 master device + */ + if (unlikely(dev->ifindex == LOOPBACK_IFINDEX || netif_is_l3_master(skb->dev))) { + const struct rt6_info *rt6 = skb_rt6_info(skb); + + /* The destination could be an external IP in Ext Hdr (SRv6, RPL, etc.), + * and ip6_null_entry could be set to skb if no route is found. 
+ */ + if (rt6 && rt6->rt6i_idev) + dev = rt6->rt6i_idev->dev; + } + + return dev; +} + +static int icmp6_iif(const struct sk_buff *skb) +{ + return icmp6_dev(skb)->ifindex; +} + +/* + * Send an ICMP message in response to a packet in error + */ +void icmp6_send(struct sk_buff *skb, u8 type, u8 code, __u32 info, + const struct in6_addr *force_saddr, + const struct inet6_skb_parm *parm) +{ + struct inet6_dev *idev = NULL; + struct ipv6hdr *hdr = ipv6_hdr(skb); + struct sock *sk; + struct net *net; + struct ipv6_pinfo *np; + const struct in6_addr *saddr = NULL; + struct dst_entry *dst; + struct icmp6hdr tmp_hdr; + struct flowi6 fl6; + struct icmpv6_msg msg; + struct ipcm6_cookie ipc6; + int iif = 0; + int addr_type = 0; + int len; + u32 mark; + + if ((u8 *)hdr < skb->head || + (skb_network_header(skb) + sizeof(*hdr)) > skb_tail_pointer(skb)) + return; + + if (!skb->dev) + return; + net = dev_net(skb->dev); + mark = IP6_REPLY_MARK(net, skb->mark); + /* + * Make sure we respect the rules + * i.e. RFC 1885 2.4(e) + * Rule (e.1) is enforced by not using icmp6_send + * in any code that processes icmp errors. + */ + addr_type = ipv6_addr_type(&hdr->daddr); + + if (ipv6_chk_addr(net, &hdr->daddr, skb->dev, 0) || + ipv6_chk_acast_addr_src(net, skb->dev, &hdr->daddr)) + saddr = &hdr->daddr; + + /* + * Dest addr check + */ + + if (addr_type & IPV6_ADDR_MULTICAST || skb->pkt_type != PACKET_HOST) { + if (type != ICMPV6_PKT_TOOBIG && + !(type == ICMPV6_PARAMPROB && + code == ICMPV6_UNK_OPTION && + (opt_unrec(skb, info)))) + return; + + saddr = NULL; + } + + addr_type = ipv6_addr_type(&hdr->saddr); + + /* + * Source addr check + */ + + if (__ipv6_addr_needs_scope_id(addr_type)) { + iif = icmp6_iif(skb); + } else { + /* + * The source device is used for looking up which routing table + * to use for sending an ICMP error. + */ + iif = l3mdev_master_ifindex(skb->dev); + } + + /* + * Must not send error if the source does not uniquely + * identify a single node (RFC2463 Section 2.4). + * We check unspecified / multicast addresses here, + * and anycast addresses will be checked later. + */ + if ((addr_type == IPV6_ADDR_ANY) || (addr_type & IPV6_ADDR_MULTICAST)) { + net_dbg_ratelimited("icmp6_send: addr_any/mcast source [%pI6c > %pI6c]\n", + &hdr->saddr, &hdr->daddr); + return; + } + + /* + * Never answer to a ICMP packet. 
+ */ + if (is_ineligible(skb)) { + net_dbg_ratelimited("icmp6_send: no reply to icmp error [%pI6c > %pI6c]\n", + &hdr->saddr, &hdr->daddr); + return; + } + + /* Needed by both icmp_global_allow and icmpv6_xmit_lock */ + local_bh_disable(); + + /* Check global sysctl_icmp_msgs_per_sec ratelimit */ + if (!(skb->dev->flags & IFF_LOOPBACK) && !icmpv6_global_allow(net, type)) + goto out_bh_enable; + + mip6_addr_swap(skb, parm); + + sk = icmpv6_xmit_lock(net); + if (!sk) + goto out_bh_enable; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_ICMPV6; + fl6.daddr = hdr->saddr; + if (force_saddr) + saddr = force_saddr; + if (saddr) { + fl6.saddr = *saddr; + } else if (!icmpv6_rt_has_prefsrc(sk, type, &fl6)) { + /* select a more meaningful saddr from input if */ + struct net_device *in_netdev; + + in_netdev = dev_get_by_index(net, parm->iif); + if (in_netdev) { + ipv6_dev_get_saddr(net, in_netdev, &fl6.daddr, + inet6_sk(sk)->srcprefs, + &fl6.saddr); + dev_put(in_netdev); + } + } + fl6.flowi6_mark = mark; + fl6.flowi6_oif = iif; + fl6.fl6_icmp_type = type; + fl6.fl6_icmp_code = code; + fl6.flowi6_uid = sock_net_uid(net, NULL); + fl6.mp_hash = rt6_multipath_hash(net, &fl6, skb, NULL); + security_skb_classify_flow(skb, flowi6_to_flowi_common(&fl6)); + + np = inet6_sk(sk); + + if (!icmpv6_xrlim_allow(sk, type, &fl6)) + goto out; + + tmp_hdr.icmp6_type = type; + tmp_hdr.icmp6_code = code; + tmp_hdr.icmp6_cksum = 0; + tmp_hdr.icmp6_pointer = htonl(info); + + if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) + fl6.flowi6_oif = np->mcast_oif; + else if (!fl6.flowi6_oif) + fl6.flowi6_oif = np->ucast_oif; + + ipcm6_init_sk(&ipc6, np); + ipc6.sockc.mark = mark; + fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel); + + dst = icmpv6_route_lookup(net, skb, sk, &fl6); + if (IS_ERR(dst)) + goto out; + + ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst); + + msg.skb = skb; + msg.offset = skb_network_offset(skb); + msg.type = type; + + len = skb->len - msg.offset; + len = min_t(unsigned int, len, IPV6_MIN_MTU - sizeof(struct ipv6hdr) - sizeof(struct icmp6hdr)); + if (len < 0) { + net_dbg_ratelimited("icmp: len problem [%pI6c > %pI6c]\n", + &hdr->saddr, &hdr->daddr); + goto out_dst_release; + } + + rcu_read_lock(); + idev = __in6_dev_get(skb->dev); + + if (ip6_append_data(sk, icmpv6_getfrag, &msg, + len + sizeof(struct icmp6hdr), + sizeof(struct icmp6hdr), + &ipc6, &fl6, (struct rt6_info *)dst, + MSG_DONTWAIT)) { + ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTERRORS); + ip6_flush_pending_frames(sk); + } else { + icmpv6_push_pending_frames(sk, &fl6, &tmp_hdr, + len + sizeof(struct icmp6hdr)); + } + rcu_read_unlock(); +out_dst_release: + dst_release(dst); +out: + icmpv6_xmit_unlock(sk); +out_bh_enable: + local_bh_enable(); +} +EXPORT_SYMBOL(icmp6_send); + +/* Slightly more convenient version of icmp6_send with drop reasons. 
+ */ +void icmpv6_param_prob_reason(struct sk_buff *skb, u8 code, int pos, + enum skb_drop_reason reason) +{ + icmp6_send(skb, ICMPV6_PARAMPROB, code, pos, NULL, IP6CB(skb)); + kfree_skb_reason(skb, reason); +} + +/* Generate icmpv6 with type/code ICMPV6_DEST_UNREACH/ICMPV6_ADDR_UNREACH + * if sufficient data bytes are available + * @nhs is the size of the tunnel header(s) : + * Either an IPv4 header for SIT encap + * an IPv4 header + GRE header for GRE encap + */ +int ip6_err_gen_icmpv6_unreach(struct sk_buff *skb, int nhs, int type, + unsigned int data_len) +{ + struct in6_addr temp_saddr; + struct rt6_info *rt; + struct sk_buff *skb2; + u32 info = 0; + + if (!pskb_may_pull(skb, nhs + sizeof(struct ipv6hdr) + 8)) + return 1; + + /* RFC 4884 (partial) support for ICMP extensions */ + if (data_len < 128 || (data_len & 7) || skb->len < data_len) + data_len = 0; + + skb2 = data_len ? skb_copy(skb, GFP_ATOMIC) : skb_clone(skb, GFP_ATOMIC); + + if (!skb2) + return 1; + + skb_dst_drop(skb2); + skb_pull(skb2, nhs); + skb_reset_network_header(skb2); + + rt = rt6_lookup(dev_net(skb->dev), &ipv6_hdr(skb2)->saddr, NULL, 0, + skb, 0); + + if (rt && rt->dst.dev) + skb2->dev = rt->dst.dev; + + ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, &temp_saddr); + + if (data_len) { + /* RFC 4884 (partial) support : + * insert 0 padding at the end, before the extensions + */ + __skb_push(skb2, nhs); + skb_reset_network_header(skb2); + memmove(skb2->data, skb2->data + nhs, data_len - nhs); + memset(skb2->data + data_len - nhs, 0, nhs); + /* RFC 4884 4.5 : Length is measured in 64-bit words, + * and stored in reserved[0] + */ + info = (data_len/8) << 24; + } + if (type == ICMP_TIME_EXCEEDED) + icmp6_send(skb2, ICMPV6_TIME_EXCEED, ICMPV6_EXC_HOPLIMIT, + info, &temp_saddr, IP6CB(skb2)); + else + icmp6_send(skb2, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, + info, &temp_saddr, IP6CB(skb2)); + if (rt) + ip6_rt_put(rt); + + kfree_skb(skb2); + + return 0; +} +EXPORT_SYMBOL(ip6_err_gen_icmpv6_unreach); + +static void icmpv6_echo_reply(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct sock *sk; + struct inet6_dev *idev; + struct ipv6_pinfo *np; + const struct in6_addr *saddr = NULL; + struct icmp6hdr *icmph = icmp6_hdr(skb); + struct icmp6hdr tmp_hdr; + struct flowi6 fl6; + struct icmpv6_msg msg; + struct dst_entry *dst; + struct ipcm6_cookie ipc6; + u32 mark = IP6_REPLY_MARK(net, skb->mark); + bool acast; + u8 type; + + if (ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr) && + net->ipv6.sysctl.icmpv6_echo_ignore_multicast) + return; + + saddr = &ipv6_hdr(skb)->daddr; + + acast = ipv6_anycast_destination(skb_dst(skb), saddr); + if (acast && net->ipv6.sysctl.icmpv6_echo_ignore_anycast) + return; + + if (!ipv6_unicast_destination(skb) && + !(net->ipv6.sysctl.anycast_src_echo_reply && acast)) + saddr = NULL; + + if (icmph->icmp6_type == ICMPV6_EXT_ECHO_REQUEST) + type = ICMPV6_EXT_ECHO_REPLY; + else + type = ICMPV6_ECHO_REPLY; + + memcpy(&tmp_hdr, icmph, sizeof(tmp_hdr)); + tmp_hdr.icmp6_type = type; + + memset(&fl6, 0, sizeof(fl6)); + if (net->ipv6.sysctl.flowlabel_reflect & FLOWLABEL_REFLECT_ICMPV6_ECHO_REPLIES) + fl6.flowlabel = ip6_flowlabel(ipv6_hdr(skb)); + + fl6.flowi6_proto = IPPROTO_ICMPV6; + fl6.daddr = ipv6_hdr(skb)->saddr; + if (saddr) + fl6.saddr = *saddr; + fl6.flowi6_oif = icmp6_iif(skb); + fl6.fl6_icmp_type = type; + fl6.flowi6_mark = mark; + fl6.flowi6_uid = sock_net_uid(net, NULL); + security_skb_classify_flow(skb, flowi6_to_flowi_common(&fl6)); + + local_bh_disable(); + sk = 
icmpv6_xmit_lock(net); + if (!sk) + goto out_bh_enable; + np = inet6_sk(sk); + + if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) + fl6.flowi6_oif = np->mcast_oif; + else if (!fl6.flowi6_oif) + fl6.flowi6_oif = np->ucast_oif; + + if (ip6_dst_lookup(net, sk, &dst, &fl6)) + goto out; + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), sk, 0); + if (IS_ERR(dst)) + goto out; + + /* Check the ratelimit */ + if ((!(skb->dev->flags & IFF_LOOPBACK) && !icmpv6_global_allow(net, ICMPV6_ECHO_REPLY)) || + !icmpv6_xrlim_allow(sk, ICMPV6_ECHO_REPLY, &fl6)) + goto out_dst_release; + + idev = __in6_dev_get(skb->dev); + + msg.skb = skb; + msg.offset = 0; + msg.type = type; + + ipcm6_init_sk(&ipc6, np); + ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst); + ipc6.tclass = ipv6_get_dsfield(ipv6_hdr(skb)); + ipc6.sockc.mark = mark; + + if (icmph->icmp6_type == ICMPV6_EXT_ECHO_REQUEST) + if (!icmp_build_probe(skb, (struct icmphdr *)&tmp_hdr)) + goto out_dst_release; + + if (ip6_append_data(sk, icmpv6_getfrag, &msg, + skb->len + sizeof(struct icmp6hdr), + sizeof(struct icmp6hdr), &ipc6, &fl6, + (struct rt6_info *)dst, MSG_DONTWAIT)) { + __ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTERRORS); + ip6_flush_pending_frames(sk); + } else { + icmpv6_push_pending_frames(sk, &fl6, &tmp_hdr, + skb->len + sizeof(struct icmp6hdr)); + } +out_dst_release: + dst_release(dst); +out: + icmpv6_xmit_unlock(sk); +out_bh_enable: + local_bh_enable(); +} + +void icmpv6_notify(struct sk_buff *skb, u8 type, u8 code, __be32 info) +{ + struct inet6_skb_parm *opt = IP6CB(skb); + const struct inet6_protocol *ipprot; + int inner_offset; + __be16 frag_off; + u8 nexthdr; + struct net *net = dev_net(skb->dev); + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto out; + + seg6_icmp_srh(skb, opt); + + nexthdr = ((struct ipv6hdr *)skb->data)->nexthdr; + if (ipv6_ext_hdr(nexthdr)) { + /* now skip over extension headers */ + inner_offset = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + if (inner_offset < 0) + goto out; + } else { + inner_offset = sizeof(struct ipv6hdr); + } + + /* Checkin header including 8 bytes of inner protocol header. */ + if (!pskb_may_pull(skb, inner_offset+8)) + goto out; + + /* BUGGG_FUTURE: we should try to parse exthdrs in this packet. + Without this we will not able f.e. to make source routed + pmtu discovery. + Corresponding argument (opt) to notifiers is already added. 
+ --ANK (980726) + */ + + ipprot = rcu_dereference(inet6_protos[nexthdr]); + if (ipprot && ipprot->err_handler) + ipprot->err_handler(skb, opt, type, code, inner_offset, info); + + raw6_icmp_error(skb, nexthdr, type, code, inner_offset, info); + return; + +out: + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), ICMP6_MIB_INERRORS); +} + +/* + * Handle icmp messages + */ + +static int icmpv6_rcv(struct sk_buff *skb) +{ + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; + struct net *net = dev_net(skb->dev); + struct net_device *dev = icmp6_dev(skb); + struct inet6_dev *idev = __in6_dev_get(dev); + const struct in6_addr *saddr, *daddr; + struct icmp6hdr *hdr; + u8 type; + + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { + struct sec_path *sp = skb_sec_path(skb); + int nh; + + if (!(sp && sp->xvec[sp->len - 1]->props.flags & + XFRM_STATE_ICMP)) { + reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop_no_count; + } + + if (!pskb_may_pull(skb, sizeof(*hdr) + sizeof(struct ipv6hdr))) + goto drop_no_count; + + nh = skb_network_offset(skb); + skb_set_network_header(skb, sizeof(*hdr)); + + if (!xfrm6_policy_check_reverse(NULL, XFRM_POLICY_IN, + skb)) { + reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop_no_count; + } + + skb_set_network_header(skb, nh); + } + + __ICMP6_INC_STATS(dev_net(dev), idev, ICMP6_MIB_INMSGS); + + saddr = &ipv6_hdr(skb)->saddr; + daddr = &ipv6_hdr(skb)->daddr; + + if (skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo)) { + net_dbg_ratelimited("ICMPv6 checksum failed [%pI6c > %pI6c]\n", + saddr, daddr); + goto csum_error; + } + + if (!pskb_pull(skb, sizeof(*hdr))) + goto discard_it; + + hdr = icmp6_hdr(skb); + + type = hdr->icmp6_type; + + ICMP6MSGIN_INC_STATS(dev_net(dev), idev, type); + + switch (type) { + case ICMPV6_ECHO_REQUEST: + if (!net->ipv6.sysctl.icmpv6_echo_ignore_all) + icmpv6_echo_reply(skb); + break; + case ICMPV6_EXT_ECHO_REQUEST: + if (!net->ipv6.sysctl.icmpv6_echo_ignore_all && + READ_ONCE(net->ipv4.sysctl_icmp_echo_enable_probe)) + icmpv6_echo_reply(skb); + break; + + case ICMPV6_ECHO_REPLY: + reason = ping_rcv(skb); + break; + + case ICMPV6_EXT_ECHO_REPLY: + reason = ping_rcv(skb); + break; + + case ICMPV6_PKT_TOOBIG: + /* BUGGG_FUTURE: if packet contains rthdr, we cannot update + standard destination cache. Seems, only "advanced" + destination cache will allow to solve this problem + --ANK (980726) + */ + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto discard_it; + hdr = icmp6_hdr(skb); + + /* to notify */ + fallthrough; + case ICMPV6_DEST_UNREACH: + case ICMPV6_TIME_EXCEED: + case ICMPV6_PARAMPROB: + icmpv6_notify(skb, type, hdr->icmp6_code, hdr->icmp6_mtu); + break; + + case NDISC_ROUTER_SOLICITATION: + case NDISC_ROUTER_ADVERTISEMENT: + case NDISC_NEIGHBOUR_SOLICITATION: + case NDISC_NEIGHBOUR_ADVERTISEMENT: + case NDISC_REDIRECT: + ndisc_rcv(skb); + break; + + case ICMPV6_MGM_QUERY: + igmp6_event_query(skb); + return 0; + + case ICMPV6_MGM_REPORT: + igmp6_event_report(skb); + return 0; + + case ICMPV6_MGM_REDUCTION: + case ICMPV6_NI_QUERY: + case ICMPV6_NI_REPLY: + case ICMPV6_MLD2_REPORT: + case ICMPV6_DHAAD_REQUEST: + case ICMPV6_DHAAD_REPLY: + case ICMPV6_MOBILE_PREFIX_SOL: + case ICMPV6_MOBILE_PREFIX_ADV: + break; + + default: + /* informational */ + if (type & ICMPV6_INFOMSG_MASK) + break; + + net_dbg_ratelimited("icmpv6: msg of unknown type [%pI6c > %pI6c]\n", + saddr, daddr); + + /* + * error of unknown type. 
+ * must pass to upper level + */ + + icmpv6_notify(skb, type, hdr->icmp6_code, hdr->icmp6_mtu); + } + + /* until the v6 path can be better sorted assume failure and + * preserve the status quo behaviour for the rest of the paths to here + */ + if (reason) + kfree_skb_reason(skb, reason); + else + consume_skb(skb); + + return 0; + +csum_error: + reason = SKB_DROP_REASON_ICMP_CSUM; + __ICMP6_INC_STATS(dev_net(dev), idev, ICMP6_MIB_CSUMERRORS); +discard_it: + __ICMP6_INC_STATS(dev_net(dev), idev, ICMP6_MIB_INERRORS); +drop_no_count: + kfree_skb_reason(skb, reason); + return 0; +} + +void icmpv6_flow_init(struct sock *sk, struct flowi6 *fl6, + u8 type, + const struct in6_addr *saddr, + const struct in6_addr *daddr, + int oif) +{ + memset(fl6, 0, sizeof(*fl6)); + fl6->saddr = *saddr; + fl6->daddr = *daddr; + fl6->flowi6_proto = IPPROTO_ICMPV6; + fl6->fl6_icmp_type = type; + fl6->fl6_icmp_code = 0; + fl6->flowi6_oif = oif; + security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6)); +} + +int __init icmpv6_init(void) +{ + struct sock *sk; + int err, i; + + for_each_possible_cpu(i) { + err = inet_ctl_sock_create(&sk, PF_INET6, + SOCK_RAW, IPPROTO_ICMPV6, &init_net); + if (err < 0) { + pr_err("Failed to initialize the ICMP6 control socket (err %d)\n", + err); + return err; + } + + per_cpu(ipv6_icmp_sk, i) = sk; + + /* Enough space for 2 64K ICMP packets, including + * sk_buff struct overhead. + */ + sk->sk_sndbuf = 2 * SKB_TRUESIZE(64 * 1024); + } + + err = -EAGAIN; + if (inet6_add_protocol(&icmpv6_protocol, IPPROTO_ICMPV6) < 0) + goto fail; + + err = inet6_register_icmp_sender(icmp6_send); + if (err) + goto sender_reg_err; + return 0; + +sender_reg_err: + inet6_del_protocol(&icmpv6_protocol, IPPROTO_ICMPV6); +fail: + pr_err("Failed to register ICMP6 protocol\n"); + return err; +} + +void icmpv6_cleanup(void) +{ + inet6_unregister_icmp_sender(icmp6_send); + inet6_del_protocol(&icmpv6_protocol, IPPROTO_ICMPV6); +} + + +static const struct icmp6_err { + int err; + int fatal; +} tab_unreach[] = { + { /* NOROUTE */ + .err = ENETUNREACH, + .fatal = 0, + }, + { /* ADM_PROHIBITED */ + .err = EACCES, + .fatal = 1, + }, + { /* Was NOT_NEIGHBOUR, now reserved */ + .err = EHOSTUNREACH, + .fatal = 0, + }, + { /* ADDR_UNREACH */ + .err = EHOSTUNREACH, + .fatal = 0, + }, + { /* PORT_UNREACH */ + .err = ECONNREFUSED, + .fatal = 1, + }, + { /* POLICY_FAIL */ + .err = EACCES, + .fatal = 1, + }, + { /* REJECT_ROUTE */ + .err = EACCES, + .fatal = 1, + }, +}; + +int icmpv6_err_convert(u8 type, u8 code, int *err) +{ + int fatal = 0; + + *err = EPROTO; + + switch (type) { + case ICMPV6_DEST_UNREACH: + fatal = 1; + if (code < ARRAY_SIZE(tab_unreach)) { + *err = tab_unreach[code].err; + fatal = tab_unreach[code].fatal; + } + break; + + case ICMPV6_PKT_TOOBIG: + *err = EMSGSIZE; + break; + + case ICMPV6_PARAMPROB: + *err = EPROTO; + fatal = 1; + break; + + case ICMPV6_TIME_EXCEED: + *err = EHOSTUNREACH; + break; + } + + return fatal; +} +EXPORT_SYMBOL(icmpv6_err_convert); + +#ifdef CONFIG_SYSCTL +static struct ctl_table ipv6_icmp_table_template[] = { + { + .procname = "ratelimit", + .data = &init_net.ipv6.sysctl.icmpv6_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "echo_ignore_all", + .data = &init_net.ipv6.sysctl.icmpv6_echo_ignore_all, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "echo_ignore_multicast", + .data = &init_net.ipv6.sysctl.icmpv6_echo_ignore_multicast, + .maxlen = sizeof(u8), + 
.mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "echo_ignore_anycast", + .data = &init_net.ipv6.sysctl.icmpv6_echo_ignore_anycast, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ratemask", + .data = &init_net.ipv6.sysctl.icmpv6_ratemask_ptr, + .maxlen = ICMPV6_MSG_MAX + 1, + .mode = 0644, + .proc_handler = proc_do_large_bitmap, + }, + { }, +}; + +struct ctl_table * __net_init ipv6_icmp_sysctl_init(struct net *net) +{ + struct ctl_table *table; + + table = kmemdup(ipv6_icmp_table_template, + sizeof(ipv6_icmp_table_template), + GFP_KERNEL); + + if (table) { + table[0].data = &net->ipv6.sysctl.icmpv6_time; + table[1].data = &net->ipv6.sysctl.icmpv6_echo_ignore_all; + table[2].data = &net->ipv6.sysctl.icmpv6_echo_ignore_multicast; + table[3].data = &net->ipv6.sysctl.icmpv6_echo_ignore_anycast; + table[4].data = &net->ipv6.sysctl.icmpv6_ratemask_ptr; + } + return table; +} +#endif diff --git a/net/ipv6/ila/Makefile b/net/ipv6/ila/Makefile new file mode 100644 index 000000000..1bc88ed7e --- /dev/null +++ b/net/ipv6/ila/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for ILA module +# + +obj-$(CONFIG_IPV6_ILA) += ila.o + +ila-objs := ila_main.o ila_common.o ila_lwt.o ila_xlat.o diff --git a/net/ipv6/ila/ila.h b/net/ipv6/ila/ila.h new file mode 100644 index 000000000..ad5f6f6ba --- /dev/null +++ b/net/ipv6/ila/ila.h @@ -0,0 +1,125 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (c) 2015 Tom Herbert + */ + +#ifndef __ILA_H +#define __ILA_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ila_locator { + union { + __u8 v8[8]; + __be16 v16[4]; + __be32 v32[2]; + __be64 v64; + }; +}; + +struct ila_identifier { + union { + struct { +#if defined(__LITTLE_ENDIAN_BITFIELD) + u8 __space:4; + u8 csum_neutral:1; + u8 type:3; +#elif defined(__BIG_ENDIAN_BITFIELD) + u8 type:3; + u8 csum_neutral:1; + u8 __space:4; +#else +#error "Adjust your defines" +#endif + u8 __space2[7]; + }; + __u8 v8[8]; + __be16 v16[4]; + __be32 v32[2]; + __be64 v64; + }; +}; + +#define CSUM_NEUTRAL_FLAG htonl(0x10000000) + +struct ila_addr { + union { + struct in6_addr addr; + struct { + struct ila_locator loc; + struct ila_identifier ident; + }; + }; +}; + +static inline struct ila_addr *ila_a2i(struct in6_addr *addr) +{ + return (struct ila_addr *)addr; +} + +struct ila_params { + struct ila_locator locator; + struct ila_locator locator_match; + __wsum csum_diff; + u8 csum_mode; + u8 ident_type; +}; + +static inline __wsum compute_csum_diff8(const __be32 *from, const __be32 *to) +{ + __be32 diff[] = { + ~from[0], ~from[1], to[0], to[1], + }; + + return csum_partial(diff, sizeof(diff), 0); +} + +static inline bool ila_csum_neutral_set(struct ila_identifier ident) +{ + return !!(ident.csum_neutral); +} + +void ila_update_ipv6_locator(struct sk_buff *skb, struct ila_params *p, + bool set_csum_neutral); + +void ila_init_saved_csum(struct ila_params *p); + +struct ila_net { + struct { + struct rhashtable rhash_table; + spinlock_t *locks; /* Bucket locks for entry manipulation */ + unsigned int locks_mask; + bool hooks_registered; + } xlat; +}; + +int ila_lwt_init(void); +void ila_lwt_fini(void); + +int ila_xlat_init_net(struct net *net); +void ila_xlat_exit_net(struct net *net); + +int ila_xlat_nl_cmd_add_mapping(struct sk_buff *skb, struct genl_info *info); +int ila_xlat_nl_cmd_del_mapping(struct sk_buff *skb, struct 
genl_info *info); +int ila_xlat_nl_cmd_get_mapping(struct sk_buff *skb, struct genl_info *info); +int ila_xlat_nl_cmd_flush(struct sk_buff *skb, struct genl_info *info); +int ila_xlat_nl_dump_start(struct netlink_callback *cb); +int ila_xlat_nl_dump_done(struct netlink_callback *cb); +int ila_xlat_nl_dump(struct sk_buff *skb, struct netlink_callback *cb); + +extern unsigned int ila_net_id; + +extern struct genl_family ila_nl_family; + +#endif /* __ILA_H */ diff --git a/net/ipv6/ila/ila_common.c b/net/ipv6/ila/ila_common.c new file mode 100644 index 000000000..95e914691 --- /dev/null +++ b/net/ipv6/ila/ila_common.c @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ila.h" + +void ila_init_saved_csum(struct ila_params *p) +{ + if (!p->locator_match.v64) + return; + + p->csum_diff = compute_csum_diff8( + (__be32 *)&p->locator, + (__be32 *)&p->locator_match); +} + +static __wsum get_csum_diff_iaddr(struct ila_addr *iaddr, struct ila_params *p) +{ + if (p->locator_match.v64) + return p->csum_diff; + else + return compute_csum_diff8((__be32 *)&p->locator, + (__be32 *)&iaddr->loc); +} + +static __wsum get_csum_diff(struct ipv6hdr *ip6h, struct ila_params *p) +{ + return get_csum_diff_iaddr(ila_a2i(&ip6h->daddr), p); +} + +static void ila_csum_do_neutral_fmt(struct ila_addr *iaddr, + struct ila_params *p) +{ + __sum16 *adjust = (__force __sum16 *)&iaddr->ident.v16[3]; + __wsum diff, fval; + + diff = get_csum_diff_iaddr(iaddr, p); + + fval = (__force __wsum)(ila_csum_neutral_set(iaddr->ident) ? + CSUM_NEUTRAL_FLAG : ~CSUM_NEUTRAL_FLAG); + + diff = csum_add(diff, fval); + + *adjust = ~csum_fold(csum_add(diff, csum_unfold(*adjust))); + + /* Flip the csum-neutral bit. Either we are doing a SIR->ILA + * translation with ILA_CSUM_NEUTRAL_MAP as the csum_method + * and the C-bit is not set, or we are doing an ILA-SIR + * tranlsation and the C-bit is set. 
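+ * In both cases the bit is simply toggled below so that it reflects the translated form of the address.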
+ */ + iaddr->ident.csum_neutral ^= 1; +} + +static void ila_csum_do_neutral_nofmt(struct ila_addr *iaddr, + struct ila_params *p) +{ + __sum16 *adjust = (__force __sum16 *)&iaddr->ident.v16[3]; + __wsum diff; + + diff = get_csum_diff_iaddr(iaddr, p); + + *adjust = ~csum_fold(csum_add(diff, csum_unfold(*adjust))); +} + +static void ila_csum_adjust_transport(struct sk_buff *skb, + struct ila_params *p) +{ + size_t nhoff = sizeof(struct ipv6hdr); + struct ipv6hdr *ip6h = ipv6_hdr(skb); + __wsum diff; + + switch (ip6h->nexthdr) { + case NEXTHDR_TCP: + if (likely(pskb_may_pull(skb, nhoff + sizeof(struct tcphdr)))) { + struct tcphdr *th = (struct tcphdr *) + (skb_network_header(skb) + nhoff); + + diff = get_csum_diff(ip6h, p); + inet_proto_csum_replace_by_diff(&th->check, skb, + diff, true); + } + break; + case NEXTHDR_UDP: + if (likely(pskb_may_pull(skb, nhoff + sizeof(struct udphdr)))) { + struct udphdr *uh = (struct udphdr *) + (skb_network_header(skb) + nhoff); + + if (uh->check || skb->ip_summed == CHECKSUM_PARTIAL) { + diff = get_csum_diff(ip6h, p); + inet_proto_csum_replace_by_diff(&uh->check, skb, + diff, true); + if (!uh->check) + uh->check = CSUM_MANGLED_0; + } + } + break; + case NEXTHDR_ICMP: + if (likely(pskb_may_pull(skb, + nhoff + sizeof(struct icmp6hdr)))) { + struct icmp6hdr *ih = (struct icmp6hdr *) + (skb_network_header(skb) + nhoff); + + diff = get_csum_diff(ip6h, p); + inet_proto_csum_replace_by_diff(&ih->icmp6_cksum, skb, + diff, true); + } + break; + } +} + +void ila_update_ipv6_locator(struct sk_buff *skb, struct ila_params *p, + bool sir2ila) +{ + struct ipv6hdr *ip6h = ipv6_hdr(skb); + struct ila_addr *iaddr = ila_a2i(&ip6h->daddr); + + switch (p->csum_mode) { + case ILA_CSUM_ADJUST_TRANSPORT: + ila_csum_adjust_transport(skb, p); + break; + case ILA_CSUM_NEUTRAL_MAP: + if (sir2ila) { + if (WARN_ON(ila_csum_neutral_set(iaddr->ident))) { + /* Checksum flag should never be + * set in a formatted SIR address. + */ + break; + } + } else if (!ila_csum_neutral_set(iaddr->ident)) { + /* ILA to SIR translation and C-bit isn't + * set so we're good. 
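+ * No checksum adjustment is needed, so the neutral-mapping step is skipped.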
+ */ + break; + } + ila_csum_do_neutral_fmt(iaddr, p); + break; + case ILA_CSUM_NEUTRAL_MAP_AUTO: + ila_csum_do_neutral_nofmt(iaddr, p); + break; + case ILA_CSUM_NO_ACTION: + break; + } + + /* Now change destination address */ + iaddr->loc = p->locator; +} diff --git a/net/ipv6/ila/ila_lwt.c b/net/ipv6/ila/ila_lwt.c new file mode 100644 index 000000000..8c1ce7895 --- /dev/null +++ b/net/ipv6/ila/ila_lwt.c @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ila.h" + +struct ila_lwt { + struct ila_params p; + struct dst_cache dst_cache; + u32 connected : 1; + u32 lwt_output : 1; +}; + +static inline struct ila_lwt *ila_lwt_lwtunnel( + struct lwtunnel_state *lwt) +{ + return (struct ila_lwt *)lwt->data; +} + +static inline struct ila_params *ila_params_lwtunnel( + struct lwtunnel_state *lwt) +{ + return &ila_lwt_lwtunnel(lwt)->p; +} + +static int ila_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct rt6_info *rt = (struct rt6_info *)orig_dst; + struct ila_lwt *ilwt = ila_lwt_lwtunnel(orig_dst->lwtstate); + struct dst_entry *dst; + int err = -EINVAL; + + if (skb->protocol != htons(ETH_P_IPV6)) + goto drop; + + if (ilwt->lwt_output) + ila_update_ipv6_locator(skb, + ila_params_lwtunnel(orig_dst->lwtstate), + true); + + if (rt->rt6i_flags & (RTF_GATEWAY | RTF_CACHE)) { + /* Already have a next hop address in route, no need for + * dest cache route. + */ + return orig_dst->lwtstate->orig_output(net, sk, skb); + } + + dst = dst_cache_get(&ilwt->dst_cache); + if (unlikely(!dst)) { + struct ipv6hdr *ip6h = ipv6_hdr(skb); + struct flowi6 fl6; + + /* Lookup a route for the new destination. Take into + * account that the base route may already have a gateway. 
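+ * For connected (host) routes the result is stored in the dst_cache so later packets can skip this lookup.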
+ */ + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_oif = orig_dst->dev->ifindex; + fl6.flowi6_iif = LOOPBACK_IFINDEX; + fl6.daddr = *rt6_nexthop((struct rt6_info *)orig_dst, + &ip6h->daddr); + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + err = -EHOSTUNREACH; + dst_release(dst); + goto drop; + } + + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), NULL, 0); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto drop; + } + + if (ilwt->connected) + dst_cache_set_ip6(&ilwt->dst_cache, dst, &fl6.saddr); + } + + skb_dst_set(skb, dst); + return dst_output(net, sk, skb); + +drop: + kfree_skb(skb); + return err; +} + +static int ila_input(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct ila_lwt *ilwt = ila_lwt_lwtunnel(dst->lwtstate); + + if (skb->protocol != htons(ETH_P_IPV6)) + goto drop; + + if (!ilwt->lwt_output) + ila_update_ipv6_locator(skb, + ila_params_lwtunnel(dst->lwtstate), + false); + + return dst->lwtstate->orig_input(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +static const struct nla_policy ila_nl_policy[ILA_ATTR_MAX + 1] = { + [ILA_ATTR_LOCATOR] = { .type = NLA_U64, }, + [ILA_ATTR_CSUM_MODE] = { .type = NLA_U8, }, + [ILA_ATTR_IDENT_TYPE] = { .type = NLA_U8, }, + [ILA_ATTR_HOOK_TYPE] = { .type = NLA_U8, }, +}; + +static int ila_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct ila_lwt *ilwt; + struct ila_params *p; + struct nlattr *tb[ILA_ATTR_MAX + 1]; + struct lwtunnel_state *newts; + const struct fib6_config *cfg6 = cfg; + struct ila_addr *iaddr; + u8 ident_type = ILA_ATYPE_USE_FORMAT; + u8 hook_type = ILA_HOOK_ROUTE_OUTPUT; + u8 csum_mode = ILA_CSUM_NO_ACTION; + bool lwt_output = true; + u8 eff_ident_type; + int ret; + + if (family != AF_INET6) + return -EINVAL; + + ret = nla_parse_nested_deprecated(tb, ILA_ATTR_MAX, nla, + ila_nl_policy, extack); + if (ret < 0) + return ret; + + if (!tb[ILA_ATTR_LOCATOR]) + return -EINVAL; + + iaddr = (struct ila_addr *)&cfg6->fc_dst; + + if (tb[ILA_ATTR_IDENT_TYPE]) + ident_type = nla_get_u8(tb[ILA_ATTR_IDENT_TYPE]); + + if (ident_type == ILA_ATYPE_USE_FORMAT) { + /* Infer identifier type from type field in formatted + * identifier. + */ + + if (cfg6->fc_dst_len < 8 * sizeof(struct ila_locator) + 3) { + /* Need to have full locator and at least type field + * included in destination + */ + return -EINVAL; + } + + eff_ident_type = iaddr->ident.type; + } else { + eff_ident_type = ident_type; + } + + switch (eff_ident_type) { + case ILA_ATYPE_IID: + /* Don't allow ILA for IID type */ + return -EINVAL; + case ILA_ATYPE_LUID: + break; + case ILA_ATYPE_VIRT_V4: + case ILA_ATYPE_VIRT_UNI_V6: + case ILA_ATYPE_VIRT_MULTI_V6: + case ILA_ATYPE_NONLOCAL_ADDR: + /* These ILA formats are not supported yet. */ + default: + return -EINVAL; + } + + if (tb[ILA_ATTR_HOOK_TYPE]) + hook_type = nla_get_u8(tb[ILA_ATTR_HOOK_TYPE]); + + switch (hook_type) { + case ILA_HOOK_ROUTE_OUTPUT: + lwt_output = true; + break; + case ILA_HOOK_ROUTE_INPUT: + lwt_output = false; + break; + default: + return -EINVAL; + } + + if (tb[ILA_ATTR_CSUM_MODE]) + csum_mode = nla_get_u8(tb[ILA_ATTR_CSUM_MODE]); + + if (csum_mode == ILA_CSUM_NEUTRAL_MAP && + ila_csum_neutral_set(iaddr->ident)) { + /* Don't allow translation if checksum neutral bit is + * configured and it's set in the SIR address. 
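+ * The C-bit is set only as part of translation; a formatted SIR address should never arrive with it already set.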
+ */ + return -EINVAL; + } + + newts = lwtunnel_state_alloc(sizeof(*ilwt)); + if (!newts) + return -ENOMEM; + + ilwt = ila_lwt_lwtunnel(newts); + ret = dst_cache_init(&ilwt->dst_cache, GFP_ATOMIC); + if (ret) { + kfree(newts); + return ret; + } + + ilwt->lwt_output = !!lwt_output; + + p = ila_params_lwtunnel(newts); + + p->csum_mode = csum_mode; + p->ident_type = ident_type; + p->locator.v64 = (__force __be64)nla_get_u64(tb[ILA_ATTR_LOCATOR]); + + /* Precompute checksum difference for translation since we + * know both the old locator and the new one. + */ + p->locator_match = iaddr->loc; + + ila_init_saved_csum(p); + + newts->type = LWTUNNEL_ENCAP_ILA; + newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT | + LWTUNNEL_STATE_INPUT_REDIRECT; + + if (cfg6->fc_dst_len == 8 * sizeof(struct in6_addr)) + ilwt->connected = 1; + + *ts = newts; + + return 0; +} + +static void ila_destroy_state(struct lwtunnel_state *lwt) +{ + dst_cache_destroy(&ila_lwt_lwtunnel(lwt)->dst_cache); +} + +static int ila_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct ila_params *p = ila_params_lwtunnel(lwtstate); + struct ila_lwt *ilwt = ila_lwt_lwtunnel(lwtstate); + + if (nla_put_u64_64bit(skb, ILA_ATTR_LOCATOR, (__force u64)p->locator.v64, + ILA_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u8(skb, ILA_ATTR_CSUM_MODE, (__force u8)p->csum_mode)) + goto nla_put_failure; + + if (nla_put_u8(skb, ILA_ATTR_IDENT_TYPE, (__force u8)p->ident_type)) + goto nla_put_failure; + + if (nla_put_u8(skb, ILA_ATTR_HOOK_TYPE, + ilwt->lwt_output ? ILA_HOOK_ROUTE_OUTPUT : + ILA_HOOK_ROUTE_INPUT)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int ila_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + return nla_total_size_64bit(sizeof(u64)) + /* ILA_ATTR_LOCATOR */ + nla_total_size(sizeof(u8)) + /* ILA_ATTR_CSUM_MODE */ + nla_total_size(sizeof(u8)) + /* ILA_ATTR_IDENT_TYPE */ + nla_total_size(sizeof(u8)) + /* ILA_ATTR_HOOK_TYPE */ + 0; +} + +static int ila_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct ila_params *a_p = ila_params_lwtunnel(a); + struct ila_params *b_p = ila_params_lwtunnel(b); + + return (a_p->locator.v64 != b_p->locator.v64); +} + +static const struct lwtunnel_encap_ops ila_encap_ops = { + .build_state = ila_build_state, + .destroy_state = ila_destroy_state, + .output = ila_output, + .input = ila_input, + .fill_encap = ila_fill_encap_info, + .get_encap_size = ila_encap_nlsize, + .cmp_encap = ila_encap_cmp, + .owner = THIS_MODULE, +}; + +int ila_lwt_init(void) +{ + return lwtunnel_encap_add_ops(&ila_encap_ops, LWTUNNEL_ENCAP_ILA); +} + +void ila_lwt_fini(void) +{ + lwtunnel_encap_del_ops(&ila_encap_ops, LWTUNNEL_ENCAP_ILA); +} diff --git a/net/ipv6/ila/ila_main.c b/net/ipv6/ila/ila_main.c new file mode 100644 index 000000000..3faf62530 --- /dev/null +++ b/net/ipv6/ila/ila_main.c @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include "ila.h" + +static const struct nla_policy ila_nl_policy[ILA_ATTR_MAX + 1] = { + [ILA_ATTR_LOCATOR] = { .type = NLA_U64, }, + [ILA_ATTR_LOCATOR_MATCH] = { .type = NLA_U64, }, + [ILA_ATTR_IFINDEX] = { .type = NLA_U32, }, + [ILA_ATTR_CSUM_MODE] = { .type = NLA_U8, }, + [ILA_ATTR_IDENT_TYPE] = { .type = NLA_U8, }, +}; + +static const struct genl_ops ila_nl_ops[] = { + { + .cmd = ILA_CMD_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ila_xlat_nl_cmd_add_mapping, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = 
ILA_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ila_xlat_nl_cmd_del_mapping, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = ILA_CMD_FLUSH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ila_xlat_nl_cmd_flush, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = ILA_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ila_xlat_nl_cmd_get_mapping, + .start = ila_xlat_nl_dump_start, + .dumpit = ila_xlat_nl_dump, + .done = ila_xlat_nl_dump_done, + }, +}; + +unsigned int ila_net_id; + +struct genl_family ila_nl_family __ro_after_init = { + .hdrsize = 0, + .name = ILA_GENL_NAME, + .version = ILA_GENL_VERSION, + .maxattr = ILA_ATTR_MAX, + .policy = ila_nl_policy, + .netnsok = true, + .parallel_ops = true, + .module = THIS_MODULE, + .ops = ila_nl_ops, + .n_ops = ARRAY_SIZE(ila_nl_ops), + .resv_start_op = ILA_CMD_FLUSH + 1, +}; + +static __net_init int ila_init_net(struct net *net) +{ + int err; + + err = ila_xlat_init_net(net); + if (err) + goto ila_xlat_init_fail; + + return 0; + +ila_xlat_init_fail: + return err; +} + +static __net_exit void ila_exit_net(struct net *net) +{ + ila_xlat_exit_net(net); +} + +static struct pernet_operations ila_net_ops = { + .init = ila_init_net, + .exit = ila_exit_net, + .id = &ila_net_id, + .size = sizeof(struct ila_net), +}; + +static int __init ila_init(void) +{ + int ret; + + ret = register_pernet_device(&ila_net_ops); + if (ret) + goto register_device_fail; + + ret = genl_register_family(&ila_nl_family); + if (ret) + goto register_family_fail; + + ret = ila_lwt_init(); + if (ret) + goto fail_lwt; + + return 0; + +fail_lwt: + genl_unregister_family(&ila_nl_family); +register_family_fail: + unregister_pernet_device(&ila_net_ops); +register_device_fail: + return ret; +} + +static void __exit ila_fini(void) +{ + ila_lwt_fini(); + genl_unregister_family(&ila_nl_family); + unregister_pernet_device(&ila_net_ops); +} + +module_init(ila_init); +module_exit(ila_fini); +MODULE_AUTHOR("Tom Herbert "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IPv6: Identifier Locator Addressing (ILA)"); diff --git a/net/ipv6/ila/ila_xlat.c b/net/ipv6/ila/ila_xlat.c new file mode 100644 index 000000000..bee45dfeb --- /dev/null +++ b/net/ipv6/ila/ila_xlat.c @@ -0,0 +1,660 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ila.h" + +struct ila_xlat_params { + struct ila_params ip; + int ifindex; +}; + +struct ila_map { + struct ila_xlat_params xp; + struct rhash_head node; + struct ila_map __rcu *next; + struct rcu_head rcu; +}; + +#define MAX_LOCKS 1024 +#define LOCKS_PER_CPU 10 + +static int alloc_ila_locks(struct ila_net *ilan) +{ + return alloc_bucket_spinlocks(&ilan->xlat.locks, &ilan->xlat.locks_mask, + MAX_LOCKS, LOCKS_PER_CPU, + GFP_KERNEL); +} + +static u32 hashrnd __read_mostly; +static __always_inline void __ila_hash_secret_init(void) +{ + net_get_random_once(&hashrnd, sizeof(hashrnd)); +} + +static inline u32 ila_locator_hash(struct ila_locator loc) +{ + u32 *v = (u32 *)loc.v32; + + __ila_hash_secret_init(); + return jhash_2words(v[0], v[1], hashrnd); +} + +static inline spinlock_t *ila_get_lock(struct ila_net *ilan, + struct ila_locator loc) +{ + return &ilan->xlat.locks[ila_locator_hash(loc) & ilan->xlat.locks_mask]; +} + +static inline int ila_cmp_wildcards(struct ila_map *ila, + struct ila_addr *iaddr, int ifindex) +{ + return (ila->xp.ifindex && ila->xp.ifindex != ifindex); +} + +static 
inline int ila_cmp_params(struct ila_map *ila, + struct ila_xlat_params *xp) +{ + return (ila->xp.ifindex != xp->ifindex); +} + +static int ila_cmpfn(struct rhashtable_compare_arg *arg, + const void *obj) +{ + const struct ila_map *ila = obj; + + return (ila->xp.ip.locator_match.v64 != *(__be64 *)arg->key); +} + +static inline int ila_order(struct ila_map *ila) +{ + int score = 0; + + if (ila->xp.ifindex) + score += 1 << 1; + + return score; +} + +static const struct rhashtable_params rht_params = { + .nelem_hint = 1024, + .head_offset = offsetof(struct ila_map, node), + .key_offset = offsetof(struct ila_map, xp.ip.locator_match), + .key_len = sizeof(u64), /* identifier */ + .max_size = 1048576, + .min_size = 256, + .automatic_shrinking = true, + .obj_cmpfn = ila_cmpfn, +}; + +static int parse_nl_config(struct genl_info *info, + struct ila_xlat_params *xp) +{ + memset(xp, 0, sizeof(*xp)); + + if (info->attrs[ILA_ATTR_LOCATOR]) + xp->ip.locator.v64 = (__force __be64)nla_get_u64( + info->attrs[ILA_ATTR_LOCATOR]); + + if (info->attrs[ILA_ATTR_LOCATOR_MATCH]) + xp->ip.locator_match.v64 = (__force __be64)nla_get_u64( + info->attrs[ILA_ATTR_LOCATOR_MATCH]); + + if (info->attrs[ILA_ATTR_CSUM_MODE]) + xp->ip.csum_mode = nla_get_u8(info->attrs[ILA_ATTR_CSUM_MODE]); + else + xp->ip.csum_mode = ILA_CSUM_NO_ACTION; + + if (info->attrs[ILA_ATTR_IDENT_TYPE]) + xp->ip.ident_type = nla_get_u8( + info->attrs[ILA_ATTR_IDENT_TYPE]); + else + xp->ip.ident_type = ILA_ATYPE_USE_FORMAT; + + if (info->attrs[ILA_ATTR_IFINDEX]) + xp->ifindex = nla_get_s32(info->attrs[ILA_ATTR_IFINDEX]); + + return 0; +} + +/* Must be called with rcu readlock */ +static inline struct ila_map *ila_lookup_wildcards(struct ila_addr *iaddr, + int ifindex, + struct ila_net *ilan) +{ + struct ila_map *ila; + + ila = rhashtable_lookup_fast(&ilan->xlat.rhash_table, &iaddr->loc, + rht_params); + while (ila) { + if (!ila_cmp_wildcards(ila, iaddr, ifindex)) + return ila; + ila = rcu_access_pointer(ila->next); + } + + return NULL; +} + +/* Must be called with rcu readlock */ +static inline struct ila_map *ila_lookup_by_params(struct ila_xlat_params *xp, + struct ila_net *ilan) +{ + struct ila_map *ila; + + ila = rhashtable_lookup_fast(&ilan->xlat.rhash_table, + &xp->ip.locator_match, + rht_params); + while (ila) { + if (!ila_cmp_params(ila, xp)) + return ila; + ila = rcu_access_pointer(ila->next); + } + + return NULL; +} + +static inline void ila_release(struct ila_map *ila) +{ + kfree_rcu(ila, rcu); +} + +static void ila_free_node(struct ila_map *ila) +{ + struct ila_map *next; + + /* Assume rcu_readlock held */ + while (ila) { + next = rcu_access_pointer(ila->next); + ila_release(ila); + ila = next; + } +} + +static void ila_free_cb(void *ptr, void *arg) +{ + ila_free_node((struct ila_map *)ptr); +} + +static int ila_xlat_addr(struct sk_buff *skb, bool sir2ila); + +static unsigned int +ila_nf_input(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + ila_xlat_addr(skb, false); + return NF_ACCEPT; +} + +static const struct nf_hook_ops ila_nf_hook_ops[] = { + { + .hook = ila_nf_input, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = -1, + }, +}; + +static int ila_add_mapping(struct net *net, struct ila_xlat_params *xp) +{ + struct ila_net *ilan = net_generic(net, ila_net_id); + struct ila_map *ila, *head; + spinlock_t *lock = ila_get_lock(ilan, xp->ip.locator_match); + int err = 0, order; + + if (!ilan->xlat.hooks_registered) { + /* We defer registering net hooks in the namespace until the + * first 
mapping is added. + */ + err = nf_register_net_hooks(net, ila_nf_hook_ops, + ARRAY_SIZE(ila_nf_hook_ops)); + if (err) + return err; + + ilan->xlat.hooks_registered = true; + } + + ila = kzalloc(sizeof(*ila), GFP_KERNEL); + if (!ila) + return -ENOMEM; + + ila_init_saved_csum(&xp->ip); + + ila->xp = *xp; + + order = ila_order(ila); + + spin_lock(lock); + + head = rhashtable_lookup_fast(&ilan->xlat.rhash_table, + &xp->ip.locator_match, + rht_params); + if (!head) { + /* New entry for the rhash_table */ + err = rhashtable_lookup_insert_fast(&ilan->xlat.rhash_table, + &ila->node, rht_params); + } else { + struct ila_map *tila = head, *prev = NULL; + + do { + if (!ila_cmp_params(tila, xp)) { + err = -EEXIST; + goto out; + } + + if (order > ila_order(tila)) + break; + + prev = tila; + tila = rcu_dereference_protected(tila->next, + lockdep_is_held(lock)); + } while (tila); + + if (prev) { + /* Insert in sub list of head */ + RCU_INIT_POINTER(ila->next, tila); + rcu_assign_pointer(prev->next, ila); + } else { + /* Make this ila new head */ + RCU_INIT_POINTER(ila->next, head); + err = rhashtable_replace_fast(&ilan->xlat.rhash_table, + &head->node, + &ila->node, rht_params); + if (err) + goto out; + } + } + +out: + spin_unlock(lock); + + if (err) + kfree(ila); + + return err; +} + +static int ila_del_mapping(struct net *net, struct ila_xlat_params *xp) +{ + struct ila_net *ilan = net_generic(net, ila_net_id); + struct ila_map *ila, *head, *prev; + spinlock_t *lock = ila_get_lock(ilan, xp->ip.locator_match); + int err = -ENOENT; + + spin_lock(lock); + + head = rhashtable_lookup_fast(&ilan->xlat.rhash_table, + &xp->ip.locator_match, rht_params); + ila = head; + + prev = NULL; + + while (ila) { + if (ila_cmp_params(ila, xp)) { + prev = ila; + ila = rcu_dereference_protected(ila->next, + lockdep_is_held(lock)); + continue; + } + + err = 0; + + if (prev) { + /* Not head, just delete from list */ + rcu_assign_pointer(prev->next, ila->next); + } else { + /* It is the head. If there is something in the + * sublist we need to make a new head. 
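+ * Otherwise the entry was the only one for this locator and can simply be removed from the table.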
+ */ + head = rcu_dereference_protected(ila->next, + lockdep_is_held(lock)); + if (head) { + /* Put first entry in the sublist into the + * table + */ + err = rhashtable_replace_fast( + &ilan->xlat.rhash_table, &ila->node, + &head->node, rht_params); + if (err) + goto out; + } else { + /* Entry no longer used */ + err = rhashtable_remove_fast( + &ilan->xlat.rhash_table, + &ila->node, rht_params); + } + } + + ila_release(ila); + + break; + } + +out: + spin_unlock(lock); + + return err; +} + +int ila_xlat_nl_cmd_add_mapping(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct ila_xlat_params p; + int err; + + err = parse_nl_config(info, &p); + if (err) + return err; + + return ila_add_mapping(net, &p); +} + +int ila_xlat_nl_cmd_del_mapping(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct ila_xlat_params xp; + int err; + + err = parse_nl_config(info, &xp); + if (err) + return err; + + ila_del_mapping(net, &xp); + + return 0; +} + +static inline spinlock_t *lock_from_ila_map(struct ila_net *ilan, + struct ila_map *ila) +{ + return ila_get_lock(ilan, ila->xp.ip.locator_match); +} + +int ila_xlat_nl_cmd_flush(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct ila_net *ilan = net_generic(net, ila_net_id); + struct rhashtable_iter iter; + struct ila_map *ila; + spinlock_t *lock; + int ret = 0; + + rhashtable_walk_enter(&ilan->xlat.rhash_table, &iter); + rhashtable_walk_start(&iter); + + for (;;) { + ila = rhashtable_walk_next(&iter); + + if (IS_ERR(ila)) { + if (PTR_ERR(ila) == -EAGAIN) + continue; + ret = PTR_ERR(ila); + goto done; + } else if (!ila) { + break; + } + + lock = lock_from_ila_map(ilan, ila); + + spin_lock(lock); + + ret = rhashtable_remove_fast(&ilan->xlat.rhash_table, + &ila->node, rht_params); + if (!ret) + ila_free_node(ila); + + spin_unlock(lock); + + if (ret) + break; + } + +done: + rhashtable_walk_stop(&iter); + rhashtable_walk_exit(&iter); + return ret; +} + +static int ila_fill_info(struct ila_map *ila, struct sk_buff *msg) +{ + if (nla_put_u64_64bit(msg, ILA_ATTR_LOCATOR, + (__force u64)ila->xp.ip.locator.v64, + ILA_ATTR_PAD) || + nla_put_u64_64bit(msg, ILA_ATTR_LOCATOR_MATCH, + (__force u64)ila->xp.ip.locator_match.v64, + ILA_ATTR_PAD) || + nla_put_s32(msg, ILA_ATTR_IFINDEX, ila->xp.ifindex) || + nla_put_u8(msg, ILA_ATTR_CSUM_MODE, ila->xp.ip.csum_mode) || + nla_put_u8(msg, ILA_ATTR_IDENT_TYPE, ila->xp.ip.ident_type)) + return -1; + + return 0; +} + +static int ila_dump_info(struct ila_map *ila, + u32 portid, u32 seq, u32 flags, + struct sk_buff *skb, u8 cmd) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ila_nl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + if (ila_fill_info(ila, skb) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +int ila_xlat_nl_cmd_get_mapping(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct ila_net *ilan = net_generic(net, ila_net_id); + struct sk_buff *msg; + struct ila_xlat_params xp; + struct ila_map *ila; + int ret; + + ret = parse_nl_config(info, &xp); + if (ret) + return ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + rcu_read_lock(); + + ret = -ESRCH; + ila = ila_lookup_by_params(&xp, ilan); + if (ila) { + ret = ila_dump_info(ila, + info->snd_portid, + info->snd_seq, 0, msg, + info->genlhdr->cmd); + } + + 
rcu_read_unlock(); + + if (ret < 0) + goto out_free; + + return genlmsg_reply(msg, info); + +out_free: + nlmsg_free(msg); + return ret; +} + +struct ila_dump_iter { + struct rhashtable_iter rhiter; + int skip; +}; + +int ila_xlat_nl_dump_start(struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct ila_net *ilan = net_generic(net, ila_net_id); + struct ila_dump_iter *iter; + + iter = kmalloc(sizeof(*iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + + rhashtable_walk_enter(&ilan->xlat.rhash_table, &iter->rhiter); + + iter->skip = 0; + cb->args[0] = (long)iter; + + return 0; +} + +int ila_xlat_nl_dump_done(struct netlink_callback *cb) +{ + struct ila_dump_iter *iter = (struct ila_dump_iter *)cb->args[0]; + + rhashtable_walk_exit(&iter->rhiter); + + kfree(iter); + + return 0; +} + +int ila_xlat_nl_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ila_dump_iter *iter = (struct ila_dump_iter *)cb->args[0]; + struct rhashtable_iter *rhiter = &iter->rhiter; + int skip = iter->skip; + struct ila_map *ila; + int ret; + + rhashtable_walk_start(rhiter); + + /* Get first entry */ + ila = rhashtable_walk_peek(rhiter); + + if (ila && !IS_ERR(ila) && skip) { + /* Skip over visited entries */ + + while (ila && skip) { + /* Skip over any ila entries in this list that we + * have already dumped. + */ + ila = rcu_access_pointer(ila->next); + skip--; + } + } + + skip = 0; + + for (;;) { + if (IS_ERR(ila)) { + ret = PTR_ERR(ila); + if (ret == -EAGAIN) { + /* Table has changed and iter has reset. Return + * -EAGAIN to the application even if we have + * written data to the skb. The application + * needs to deal with this. + */ + + goto out_ret; + } else { + break; + } + } else if (!ila) { + ret = 0; + break; + } + + while (ila) { + ret = ila_dump_info(ila, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + skb, ILA_CMD_GET); + if (ret) + goto out; + + skip++; + ila = rcu_access_pointer(ila->next); + } + + skip = 0; + ila = rhashtable_walk_next(rhiter); + } + +out: + iter->skip = skip; + ret = (skb->len ? : ret); + +out_ret: + rhashtable_walk_stop(rhiter); + return ret; +} + +int ila_xlat_init_net(struct net *net) +{ + struct ila_net *ilan = net_generic(net, ila_net_id); + int err; + + err = alloc_ila_locks(ilan); + if (err) + return err; + + err = rhashtable_init(&ilan->xlat.rhash_table, &rht_params); + if (err) { + free_bucket_spinlocks(ilan->xlat.locks); + return err; + } + + return 0; +} + +void ila_xlat_exit_net(struct net *net) +{ + struct ila_net *ilan = net_generic(net, ila_net_id); + + rhashtable_free_and_destroy(&ilan->xlat.rhash_table, ila_free_cb, NULL); + + free_bucket_spinlocks(ilan->xlat.locks); + + if (ilan->xlat.hooks_registered) + nf_unregister_net_hooks(net, ila_nf_hook_ops, + ARRAY_SIZE(ila_nf_hook_ops)); +} + +static int ila_xlat_addr(struct sk_buff *skb, bool sir2ila) +{ + struct ila_map *ila; + struct ipv6hdr *ip6h = ipv6_hdr(skb); + struct net *net = dev_net(skb->dev); + struct ila_net *ilan = net_generic(net, ila_net_id); + struct ila_addr *iaddr = ila_a2i(&ip6h->daddr); + + /* Assumes skb contains a valid IPv6 header that is pulled */ + + /* No check here that ILA type in the mapping matches what is in the + * address. We assume that whatever sender gaves us can be translated. + * The checksum mode however is relevant. 
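+ * It is honoured by ila_update_ipv6_locator() once a matching mapping is found.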
+ */ + + rcu_read_lock(); + + ila = ila_lookup_wildcards(iaddr, skb->dev->ifindex, ilan); + if (ila) + ila_update_ipv6_locator(skb, &ila->xp.ip, sir2ila); + + rcu_read_unlock(); + + return 0; +} diff --git a/net/ipv6/inet6_connection_sock.c b/net/ipv6/inet6_connection_sock.c new file mode 100644 index 000000000..5a9f4d722 --- /dev/null +++ b/net/ipv6/inet6_connection_sock.c @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Support for INET6 connection oriented protocols. + * + * Authors: See the TCPv6 sources + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +struct dst_entry *inet6_csk_route_req(const struct sock *sk, + struct flowi6 *fl6, + const struct request_sock *req, + u8 proto) +{ + struct inet_request_sock *ireq = inet_rsk(req); + const struct ipv6_pinfo *np = inet6_sk(sk); + struct in6_addr *final_p, final; + struct dst_entry *dst; + + memset(fl6, 0, sizeof(*fl6)); + fl6->flowi6_proto = proto; + fl6->daddr = ireq->ir_v6_rmt_addr; + rcu_read_lock(); + final_p = fl6_update_dst(fl6, rcu_dereference(np->opt), &final); + rcu_read_unlock(); + fl6->saddr = ireq->ir_v6_loc_addr; + fl6->flowi6_oif = ireq->ir_iif; + fl6->flowi6_mark = ireq->ir_mark; + fl6->fl6_dport = ireq->ir_rmt_port; + fl6->fl6_sport = htons(ireq->ir_num); + fl6->flowi6_uid = sk->sk_uid; + security_req_classify_flow(req, flowi6_to_flowi_common(fl6)); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_p); + if (IS_ERR(dst)) + return NULL; + + return dst; +} +EXPORT_SYMBOL(inet6_csk_route_req); + +void inet6_csk_addr2sockaddr(struct sock *sk, struct sockaddr *uaddr) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *) uaddr; + + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = sk->sk_v6_daddr; + sin6->sin6_port = inet_sk(sk)->inet_dport; + /* We do not store received flowlabel for TCP */ + sin6->sin6_flowinfo = 0; + sin6->sin6_scope_id = ipv6_iface_scope_id(&sin6->sin6_addr, + sk->sk_bound_dev_if); +} +EXPORT_SYMBOL_GPL(inet6_csk_addr2sockaddr); + +static inline +struct dst_entry *__inet6_csk_dst_check(struct sock *sk, u32 cookie) +{ + return __sk_dst_check(sk, cookie); +} + +static struct dst_entry *inet6_csk_route_socket(struct sock *sk, + struct flowi6 *fl6) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct in6_addr *final_p, final; + struct dst_entry *dst; + + memset(fl6, 0, sizeof(*fl6)); + fl6->flowi6_proto = sk->sk_protocol; + fl6->daddr = sk->sk_v6_daddr; + fl6->saddr = np->saddr; + fl6->flowlabel = np->flow_label; + IP6_ECN_flow_xmit(sk, fl6->flowlabel); + fl6->flowi6_oif = sk->sk_bound_dev_if; + fl6->flowi6_mark = sk->sk_mark; + fl6->fl6_sport = inet->inet_sport; + fl6->fl6_dport = inet->inet_dport; + fl6->flowi6_uid = sk->sk_uid; + security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6)); + + rcu_read_lock(); + final_p = fl6_update_dst(fl6, rcu_dereference(np->opt), &final); + rcu_read_unlock(); + + dst = __inet6_csk_dst_check(sk, np->dst_cookie); + if (!dst) { + dst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_p); + + if (!IS_ERR(dst)) + ip6_dst_store(sk, dst, NULL, NULL); + } + return dst; +} + +int inet6_csk_xmit(struct sock *sk, struct sk_buff *skb, struct flowi *fl_unused) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct flowi6 fl6; 
+ struct dst_entry *dst; + int res; + + dst = inet6_csk_route_socket(sk, &fl6); + if (IS_ERR(dst)) { + sk->sk_err_soft = -PTR_ERR(dst); + sk->sk_route_caps = 0; + kfree_skb(skb); + return PTR_ERR(dst); + } + + rcu_read_lock(); + skb_dst_set_noref(skb, dst); + + /* Restore final destination back after routing done */ + fl6.daddr = sk->sk_v6_daddr; + + res = ip6_xmit(sk, skb, &fl6, sk->sk_mark, rcu_dereference(np->opt), + np->tclass, sk->sk_priority); + rcu_read_unlock(); + return res; +} +EXPORT_SYMBOL_GPL(inet6_csk_xmit); + +struct dst_entry *inet6_csk_update_pmtu(struct sock *sk, u32 mtu) +{ + struct flowi6 fl6; + struct dst_entry *dst = inet6_csk_route_socket(sk, &fl6); + + if (IS_ERR(dst)) + return NULL; + dst->ops->update_pmtu(dst, sk, NULL, mtu, true); + + dst = inet6_csk_route_socket(sk, &fl6); + return IS_ERR(dst) ? NULL : dst; +} +EXPORT_SYMBOL_GPL(inet6_csk_update_pmtu); diff --git a/net/ipv6/inet6_hashtables.c b/net/ipv6/inet6_hashtables.c new file mode 100644 index 000000000..b64b49012 --- /dev/null +++ b/net/ipv6/inet6_hashtables.c @@ -0,0 +1,338 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * Generic INET6 transport hashtables + * + * Authors: Lotsa people, from code originally in tcp, generalised here + * by Arnaldo Carvalho de Melo + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +u32 inet6_ehashfn(const struct net *net, + const struct in6_addr *laddr, const u16 lport, + const struct in6_addr *faddr, const __be16 fport) +{ + static u32 inet6_ehash_secret __read_mostly; + static u32 ipv6_hash_secret __read_mostly; + + u32 lhash, fhash; + + net_get_random_once(&inet6_ehash_secret, sizeof(inet6_ehash_secret)); + net_get_random_once(&ipv6_hash_secret, sizeof(ipv6_hash_secret)); + + lhash = (__force u32)laddr->s6_addr32[3]; + fhash = __ipv6_addr_jhash(faddr, ipv6_hash_secret); + + return __inet6_ehashfn(lhash, lport, fhash, fport, + inet6_ehash_secret + net_hash_mix(net)); +} + +/* + * Sockets in TCP_CLOSE state are _always_ taken out of the hash, so + * we need not check it for TCP lookups anymore, thanks Alexey. -DaveM + * + * The sockhash lock must be held as a reader here. + */ +struct sock *__inet6_lookup_established(struct net *net, + struct inet_hashinfo *hashinfo, + const struct in6_addr *saddr, + const __be16 sport, + const struct in6_addr *daddr, + const u16 hnum, + const int dif, const int sdif) +{ + struct sock *sk; + const struct hlist_nulls_node *node; + const __portpair ports = INET_COMBINED_PORTS(sport, hnum); + /* Optimize here for direct hit, only listening connections can + * have wildcards anyways. 
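+ * The established hash is therefore computed over the full 4-tuple of addresses and ports.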
+ */ + unsigned int hash = inet6_ehashfn(net, daddr, hnum, saddr, sport); + unsigned int slot = hash & hashinfo->ehash_mask; + struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; + + +begin: + sk_nulls_for_each_rcu(sk, node, &head->chain) { + if (sk->sk_hash != hash) + continue; + if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif)) + continue; + if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) + goto out; + + if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) { + sock_gen_put(sk); + goto begin; + } + goto found; + } + if (get_nulls_value(node) != slot) + goto begin; +out: + sk = NULL; +found: + return sk; +} +EXPORT_SYMBOL(__inet6_lookup_established); + +static inline int compute_score(struct sock *sk, struct net *net, + const unsigned short hnum, + const struct in6_addr *daddr, + const int dif, const int sdif) +{ + int score = -1; + + if (net_eq(sock_net(sk), net) && inet_sk(sk)->inet_num == hnum && + sk->sk_family == PF_INET6) { + if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr)) + return -1; + + if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) + return -1; + + score = sk->sk_bound_dev_if ? 2 : 1; + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + score++; + } + return score; +} + +static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk, + struct sk_buff *skb, int doff, + const struct in6_addr *saddr, + __be16 sport, + const struct in6_addr *daddr, + unsigned short hnum) +{ + struct sock *reuse_sk = NULL; + u32 phash; + + if (sk->sk_reuseport) { + phash = inet6_ehashfn(net, daddr, hnum, saddr, sport); + reuse_sk = reuseport_select_sock(sk, phash, skb, doff); + } + return reuse_sk; +} + +/* called with rcu_read_lock() */ +static struct sock *inet6_lhash2_lookup(struct net *net, + struct inet_listen_hashbucket *ilb2, + struct sk_buff *skb, int doff, + const struct in6_addr *saddr, + const __be16 sport, const struct in6_addr *daddr, + const unsigned short hnum, const int dif, const int sdif) +{ + struct sock *sk, *result = NULL; + struct hlist_nulls_node *node; + int score, hiscore = 0; + + sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) { + score = compute_score(sk, net, hnum, daddr, dif, sdif); + if (score > hiscore) { + result = lookup_reuseport(net, sk, skb, doff, + saddr, sport, daddr, hnum); + if (result) + return result; + + result = sk; + hiscore = score; + } + } + + return result; +} + +static inline struct sock *inet6_lookup_run_bpf(struct net *net, + struct inet_hashinfo *hashinfo, + struct sk_buff *skb, int doff, + const struct in6_addr *saddr, + const __be16 sport, + const struct in6_addr *daddr, + const u16 hnum, const int dif) +{ + struct sock *sk, *reuse_sk; + bool no_reuseport; + + if (hashinfo != net->ipv4.tcp_death_row.hashinfo) + return NULL; /* only TCP is supported */ + + no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_TCP, saddr, sport, + daddr, hnum, dif, &sk); + if (no_reuseport || IS_ERR_OR_NULL(sk)) + return sk; + + reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum); + if (reuse_sk) + sk = reuse_sk; + return sk; +} + +struct sock *inet6_lookup_listener(struct net *net, + struct inet_hashinfo *hashinfo, + struct sk_buff *skb, int doff, + const struct in6_addr *saddr, + const __be16 sport, const struct in6_addr *daddr, + const unsigned short hnum, const int dif, const int sdif) +{ + struct inet_listen_hashbucket *ilb2; + struct sock *result = NULL; + unsigned int hash2; + + /* Lookup redirect from BPF */ + if 
(static_branch_unlikely(&bpf_sk_lookup_enabled)) { + result = inet6_lookup_run_bpf(net, hashinfo, skb, doff, + saddr, sport, daddr, hnum, dif); + if (result) + goto done; + } + + hash2 = ipv6_portaddr_hash(net, daddr, hnum); + ilb2 = inet_lhash2_bucket(hashinfo, hash2); + + result = inet6_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, daddr, hnum, + dif, sdif); + if (result) + goto done; + + /* Lookup lhash2 with in6addr_any */ + hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum); + ilb2 = inet_lhash2_bucket(hashinfo, hash2); + + result = inet6_lhash2_lookup(net, ilb2, skb, doff, + saddr, sport, &in6addr_any, hnum, + dif, sdif); +done: + if (IS_ERR(result)) + return NULL; + return result; +} +EXPORT_SYMBOL_GPL(inet6_lookup_listener); + +struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo, + struct sk_buff *skb, int doff, + const struct in6_addr *saddr, const __be16 sport, + const struct in6_addr *daddr, const __be16 dport, + const int dif) +{ + struct sock *sk; + bool refcounted; + + sk = __inet6_lookup(net, hashinfo, skb, doff, saddr, sport, daddr, + ntohs(dport), dif, 0, &refcounted); + if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + return sk; +} +EXPORT_SYMBOL_GPL(inet6_lookup); + +static int __inet6_check_established(struct inet_timewait_death_row *death_row, + struct sock *sk, const __u16 lport, + struct inet_timewait_sock **twp) +{ + struct inet_hashinfo *hinfo = death_row->hashinfo; + struct inet_sock *inet = inet_sk(sk); + const struct in6_addr *daddr = &sk->sk_v6_rcv_saddr; + const struct in6_addr *saddr = &sk->sk_v6_daddr; + const int dif = sk->sk_bound_dev_if; + struct net *net = sock_net(sk); + const int sdif = l3mdev_master_ifindex_by_index(net, dif); + const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); + const unsigned int hash = inet6_ehashfn(net, daddr, lport, saddr, + inet->inet_dport); + struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); + spinlock_t *lock = inet_ehash_lockp(hinfo, hash); + struct sock *sk2; + const struct hlist_nulls_node *node; + struct inet_timewait_sock *tw = NULL; + + spin_lock(lock); + + sk_nulls_for_each(sk2, node, &head->chain) { + if (sk2->sk_hash != hash) + continue; + + if (likely(inet6_match(net, sk2, saddr, daddr, ports, + dif, sdif))) { + if (sk2->sk_state == TCP_TIME_WAIT) { + tw = inet_twsk(sk2); + if (twsk_unique(sk, sk2, twp)) + break; + } + goto not_unique; + } + } + + /* Must record num and sport now. Otherwise we will see + * in hash table socket with a funny identity. + */ + inet->inet_num = lport; + inet->inet_sport = htons(lport); + sk->sk_hash = hash; + WARN_ON(!sk_unhashed(sk)); + __sk_nulls_add_node_rcu(sk, &head->chain); + if (tw) { + sk_nulls_del_node_init_rcu((struct sock *)tw); + __NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED); + } + spin_unlock(lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + if (twp) { + *twp = tw; + } else if (tw) { + /* Silly. Should hash-dance instead... 
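+ * For now the old timewait socket is simply descheduled and released.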
*/ + inet_twsk_deschedule_put(tw); + } + return 0; + +not_unique: + spin_unlock(lock); + return -EADDRNOTAVAIL; +} + +static u64 inet6_sk_port_offset(const struct sock *sk) +{ + const struct inet_sock *inet = inet_sk(sk); + + return secure_ipv6_port_ephemeral(sk->sk_v6_rcv_saddr.s6_addr32, + sk->sk_v6_daddr.s6_addr32, + inet->inet_dport); +} + +int inet6_hash_connect(struct inet_timewait_death_row *death_row, + struct sock *sk) +{ + u64 port_offset = 0; + + if (!inet_sk(sk)->inet_num) + port_offset = inet6_sk_port_offset(sk); + return __inet_hash_connect(death_row, sk, port_offset, + __inet6_check_established); +} +EXPORT_SYMBOL_GPL(inet6_hash_connect); + +int inet6_hash(struct sock *sk) +{ + int err = 0; + + if (sk->sk_state != TCP_CLOSE) + err = __inet_hash(sk, NULL); + + return err; +} +EXPORT_SYMBOL_GPL(inet6_hash); diff --git a/net/ipv6/ioam6.c b/net/ipv6/ioam6.c new file mode 100644 index 000000000..571f0e4d9 --- /dev/null +++ b/net/ipv6/ioam6.c @@ -0,0 +1,979 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * IPv6 IOAM implementation + * + * Author: + * Justin Iurman + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static void ioam6_ns_release(struct ioam6_namespace *ns) +{ + kfree_rcu(ns, rcu); +} + +static void ioam6_sc_release(struct ioam6_schema *sc) +{ + kfree_rcu(sc, rcu); +} + +static void ioam6_free_ns(void *ptr, void *arg) +{ + struct ioam6_namespace *ns = (struct ioam6_namespace *)ptr; + + if (ns) + ioam6_ns_release(ns); +} + +static void ioam6_free_sc(void *ptr, void *arg) +{ + struct ioam6_schema *sc = (struct ioam6_schema *)ptr; + + if (sc) + ioam6_sc_release(sc); +} + +static int ioam6_ns_cmpfn(struct rhashtable_compare_arg *arg, const void *obj) +{ + const struct ioam6_namespace *ns = obj; + + return (ns->id != *(__be16 *)arg->key); +} + +static int ioam6_sc_cmpfn(struct rhashtable_compare_arg *arg, const void *obj) +{ + const struct ioam6_schema *sc = obj; + + return (sc->id != *(u32 *)arg->key); +} + +static const struct rhashtable_params rht_ns_params = { + .key_len = sizeof(__be16), + .key_offset = offsetof(struct ioam6_namespace, id), + .head_offset = offsetof(struct ioam6_namespace, head), + .automatic_shrinking = true, + .obj_cmpfn = ioam6_ns_cmpfn, +}; + +static const struct rhashtable_params rht_sc_params = { + .key_len = sizeof(u32), + .key_offset = offsetof(struct ioam6_schema, id), + .head_offset = offsetof(struct ioam6_schema, head), + .automatic_shrinking = true, + .obj_cmpfn = ioam6_sc_cmpfn, +}; + +static struct genl_family ioam6_genl_family; + +static const struct nla_policy ioam6_genl_policy_addns[] = { + [IOAM6_ATTR_NS_ID] = { .type = NLA_U16 }, + [IOAM6_ATTR_NS_DATA] = { .type = NLA_U32 }, + [IOAM6_ATTR_NS_DATA_WIDE] = { .type = NLA_U64 }, +}; + +static const struct nla_policy ioam6_genl_policy_delns[] = { + [IOAM6_ATTR_NS_ID] = { .type = NLA_U16 }, +}; + +static const struct nla_policy ioam6_genl_policy_addsc[] = { + [IOAM6_ATTR_SC_ID] = { .type = NLA_U32 }, + [IOAM6_ATTR_SC_DATA] = { .type = NLA_BINARY, + .len = IOAM6_MAX_SCHEMA_DATA_LEN }, +}; + +static const struct nla_policy ioam6_genl_policy_delsc[] = { + [IOAM6_ATTR_SC_ID] = { .type = NLA_U32 }, +}; + +static const struct nla_policy ioam6_genl_policy_ns_sc[] = { + [IOAM6_ATTR_NS_ID] = { .type = NLA_U16 }, + [IOAM6_ATTR_SC_ID] = { .type = NLA_U32 }, + [IOAM6_ATTR_SC_NONE] = { .type = NLA_FLAG }, +}; + +static int ioam6_genl_addns(struct sk_buff *skb, struct genl_info *info) +{ + struct ioam6_pernet_data *nsdata; 
+ struct ioam6_namespace *ns; + u64 data64; + u32 data32; + __be16 id; + int err; + + if (!info->attrs[IOAM6_ATTR_NS_ID]) + return -EINVAL; + + id = cpu_to_be16(nla_get_u16(info->attrs[IOAM6_ATTR_NS_ID])); + nsdata = ioam6_pernet(genl_info_net(info)); + + mutex_lock(&nsdata->lock); + + ns = rhashtable_lookup_fast(&nsdata->namespaces, &id, rht_ns_params); + if (ns) { + err = -EEXIST; + goto out_unlock; + } + + ns = kzalloc(sizeof(*ns), GFP_KERNEL); + if (!ns) { + err = -ENOMEM; + goto out_unlock; + } + + ns->id = id; + + if (!info->attrs[IOAM6_ATTR_NS_DATA]) + data32 = IOAM6_U32_UNAVAILABLE; + else + data32 = nla_get_u32(info->attrs[IOAM6_ATTR_NS_DATA]); + + if (!info->attrs[IOAM6_ATTR_NS_DATA_WIDE]) + data64 = IOAM6_U64_UNAVAILABLE; + else + data64 = nla_get_u64(info->attrs[IOAM6_ATTR_NS_DATA_WIDE]); + + ns->data = cpu_to_be32(data32); + ns->data_wide = cpu_to_be64(data64); + + err = rhashtable_lookup_insert_fast(&nsdata->namespaces, &ns->head, + rht_ns_params); + if (err) + kfree(ns); + +out_unlock: + mutex_unlock(&nsdata->lock); + return err; +} + +static int ioam6_genl_delns(struct sk_buff *skb, struct genl_info *info) +{ + struct ioam6_pernet_data *nsdata; + struct ioam6_namespace *ns; + struct ioam6_schema *sc; + __be16 id; + int err; + + if (!info->attrs[IOAM6_ATTR_NS_ID]) + return -EINVAL; + + id = cpu_to_be16(nla_get_u16(info->attrs[IOAM6_ATTR_NS_ID])); + nsdata = ioam6_pernet(genl_info_net(info)); + + mutex_lock(&nsdata->lock); + + ns = rhashtable_lookup_fast(&nsdata->namespaces, &id, rht_ns_params); + if (!ns) { + err = -ENOENT; + goto out_unlock; + } + + sc = rcu_dereference_protected(ns->schema, + lockdep_is_held(&nsdata->lock)); + + err = rhashtable_remove_fast(&nsdata->namespaces, &ns->head, + rht_ns_params); + if (err) + goto out_unlock; + + if (sc) + rcu_assign_pointer(sc->ns, NULL); + + ioam6_ns_release(ns); + +out_unlock: + mutex_unlock(&nsdata->lock); + return err; +} + +static int __ioam6_genl_dumpns_element(struct ioam6_namespace *ns, + u32 portid, + u32 seq, + u32 flags, + struct sk_buff *skb, + u8 cmd) +{ + struct ioam6_schema *sc; + u64 data64; + u32 data32; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ioam6_genl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + data32 = be32_to_cpu(ns->data); + data64 = be64_to_cpu(ns->data_wide); + + if (nla_put_u16(skb, IOAM6_ATTR_NS_ID, be16_to_cpu(ns->id)) || + (data32 != IOAM6_U32_UNAVAILABLE && + nla_put_u32(skb, IOAM6_ATTR_NS_DATA, data32)) || + (data64 != IOAM6_U64_UNAVAILABLE && + nla_put_u64_64bit(skb, IOAM6_ATTR_NS_DATA_WIDE, + data64, IOAM6_ATTR_PAD))) + goto nla_put_failure; + + rcu_read_lock(); + + sc = rcu_dereference(ns->schema); + if (sc && nla_put_u32(skb, IOAM6_ATTR_SC_ID, sc->id)) { + rcu_read_unlock(); + goto nla_put_failure; + } + + rcu_read_unlock(); + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ioam6_genl_dumpns_start(struct netlink_callback *cb) +{ + struct ioam6_pernet_data *nsdata = ioam6_pernet(sock_net(cb->skb->sk)); + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + + if (!iter) { + iter = kmalloc(sizeof(*iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + + cb->args[0] = (long)iter; + } + + rhashtable_walk_enter(&nsdata->namespaces, iter); + + return 0; +} + +static int ioam6_genl_dumpns_done(struct netlink_callback *cb) +{ + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + + rhashtable_walk_exit(iter); + kfree(iter); + + return 0; +} + +static int 
ioam6_genl_dumpns(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rhashtable_iter *iter; + struct ioam6_namespace *ns; + int err; + + iter = (struct rhashtable_iter *)cb->args[0]; + rhashtable_walk_start(iter); + + for (;;) { + ns = rhashtable_walk_next(iter); + + if (IS_ERR(ns)) { + if (PTR_ERR(ns) == -EAGAIN) + continue; + err = PTR_ERR(ns); + goto done; + } else if (!ns) { + break; + } + + err = __ioam6_genl_dumpns_element(ns, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + skb, + IOAM6_CMD_DUMP_NAMESPACES); + if (err) + goto done; + } + + err = skb->len; + +done: + rhashtable_walk_stop(iter); + return err; +} + +static int ioam6_genl_addsc(struct sk_buff *skb, struct genl_info *info) +{ + struct ioam6_pernet_data *nsdata; + int len, len_aligned, err; + struct ioam6_schema *sc; + u32 id; + + if (!info->attrs[IOAM6_ATTR_SC_ID] || !info->attrs[IOAM6_ATTR_SC_DATA]) + return -EINVAL; + + id = nla_get_u32(info->attrs[IOAM6_ATTR_SC_ID]); + nsdata = ioam6_pernet(genl_info_net(info)); + + mutex_lock(&nsdata->lock); + + sc = rhashtable_lookup_fast(&nsdata->schemas, &id, rht_sc_params); + if (sc) { + err = -EEXIST; + goto out_unlock; + } + + len = nla_len(info->attrs[IOAM6_ATTR_SC_DATA]); + len_aligned = ALIGN(len, 4); + + sc = kzalloc(sizeof(*sc) + len_aligned, GFP_KERNEL); + if (!sc) { + err = -ENOMEM; + goto out_unlock; + } + + sc->id = id; + sc->len = len_aligned; + sc->hdr = cpu_to_be32(sc->id | ((u8)(sc->len / 4) << 24)); + nla_memcpy(sc->data, info->attrs[IOAM6_ATTR_SC_DATA], len); + + err = rhashtable_lookup_insert_fast(&nsdata->schemas, &sc->head, + rht_sc_params); + if (err) + goto free_sc; + +out_unlock: + mutex_unlock(&nsdata->lock); + return err; +free_sc: + kfree(sc); + goto out_unlock; +} + +static int ioam6_genl_delsc(struct sk_buff *skb, struct genl_info *info) +{ + struct ioam6_pernet_data *nsdata; + struct ioam6_namespace *ns; + struct ioam6_schema *sc; + int err; + u32 id; + + if (!info->attrs[IOAM6_ATTR_SC_ID]) + return -EINVAL; + + id = nla_get_u32(info->attrs[IOAM6_ATTR_SC_ID]); + nsdata = ioam6_pernet(genl_info_net(info)); + + mutex_lock(&nsdata->lock); + + sc = rhashtable_lookup_fast(&nsdata->schemas, &id, rht_sc_params); + if (!sc) { + err = -ENOENT; + goto out_unlock; + } + + ns = rcu_dereference_protected(sc->ns, lockdep_is_held(&nsdata->lock)); + + err = rhashtable_remove_fast(&nsdata->schemas, &sc->head, + rht_sc_params); + if (err) + goto out_unlock; + + if (ns) + rcu_assign_pointer(ns->schema, NULL); + + ioam6_sc_release(sc); + +out_unlock: + mutex_unlock(&nsdata->lock); + return err; +} + +static int __ioam6_genl_dumpsc_element(struct ioam6_schema *sc, + u32 portid, u32 seq, u32 flags, + struct sk_buff *skb, u8 cmd) +{ + struct ioam6_namespace *ns; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &ioam6_genl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + if (nla_put_u32(skb, IOAM6_ATTR_SC_ID, sc->id) || + nla_put(skb, IOAM6_ATTR_SC_DATA, sc->len, sc->data)) + goto nla_put_failure; + + rcu_read_lock(); + + ns = rcu_dereference(sc->ns); + if (ns && nla_put_u16(skb, IOAM6_ATTR_NS_ID, be16_to_cpu(ns->id))) { + rcu_read_unlock(); + goto nla_put_failure; + } + + rcu_read_unlock(); + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ioam6_genl_dumpsc_start(struct netlink_callback *cb) +{ + struct ioam6_pernet_data *nsdata = ioam6_pernet(sock_net(cb->skb->sk)); + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + + if (!iter) { + iter 
= kmalloc(sizeof(*iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + + cb->args[0] = (long)iter; + } + + rhashtable_walk_enter(&nsdata->schemas, iter); + + return 0; +} + +static int ioam6_genl_dumpsc_done(struct netlink_callback *cb) +{ + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + + rhashtable_walk_exit(iter); + kfree(iter); + + return 0; +} + +static int ioam6_genl_dumpsc(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rhashtable_iter *iter; + struct ioam6_schema *sc; + int err; + + iter = (struct rhashtable_iter *)cb->args[0]; + rhashtable_walk_start(iter); + + for (;;) { + sc = rhashtable_walk_next(iter); + + if (IS_ERR(sc)) { + if (PTR_ERR(sc) == -EAGAIN) + continue; + err = PTR_ERR(sc); + goto done; + } else if (!sc) { + break; + } + + err = __ioam6_genl_dumpsc_element(sc, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + skb, + IOAM6_CMD_DUMP_SCHEMAS); + if (err) + goto done; + } + + err = skb->len; + +done: + rhashtable_walk_stop(iter); + return err; +} + +static int ioam6_genl_ns_set_schema(struct sk_buff *skb, struct genl_info *info) +{ + struct ioam6_namespace *ns, *ns_ref; + struct ioam6_schema *sc, *sc_ref; + struct ioam6_pernet_data *nsdata; + __be16 ns_id; + u32 sc_id; + int err; + + if (!info->attrs[IOAM6_ATTR_NS_ID] || + (!info->attrs[IOAM6_ATTR_SC_ID] && + !info->attrs[IOAM6_ATTR_SC_NONE])) + return -EINVAL; + + ns_id = cpu_to_be16(nla_get_u16(info->attrs[IOAM6_ATTR_NS_ID])); + nsdata = ioam6_pernet(genl_info_net(info)); + + mutex_lock(&nsdata->lock); + + ns = rhashtable_lookup_fast(&nsdata->namespaces, &ns_id, rht_ns_params); + if (!ns) { + err = -ENOENT; + goto out_unlock; + } + + if (info->attrs[IOAM6_ATTR_SC_NONE]) { + sc = NULL; + } else { + sc_id = nla_get_u32(info->attrs[IOAM6_ATTR_SC_ID]); + sc = rhashtable_lookup_fast(&nsdata->schemas, &sc_id, + rht_sc_params); + if (!sc) { + err = -ENOENT; + goto out_unlock; + } + } + + sc_ref = rcu_dereference_protected(ns->schema, + lockdep_is_held(&nsdata->lock)); + if (sc_ref) + rcu_assign_pointer(sc_ref->ns, NULL); + rcu_assign_pointer(ns->schema, sc); + + if (sc) { + ns_ref = rcu_dereference_protected(sc->ns, + lockdep_is_held(&nsdata->lock)); + if (ns_ref) + rcu_assign_pointer(ns_ref->schema, NULL); + rcu_assign_pointer(sc->ns, ns); + } + + err = 0; + +out_unlock: + mutex_unlock(&nsdata->lock); + return err; +} + +static const struct genl_ops ioam6_genl_ops[] = { + { + .cmd = IOAM6_CMD_ADD_NAMESPACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ioam6_genl_addns, + .flags = GENL_ADMIN_PERM, + .policy = ioam6_genl_policy_addns, + .maxattr = ARRAY_SIZE(ioam6_genl_policy_addns) - 1, + }, + { + .cmd = IOAM6_CMD_DEL_NAMESPACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ioam6_genl_delns, + .flags = GENL_ADMIN_PERM, + .policy = ioam6_genl_policy_delns, + .maxattr = ARRAY_SIZE(ioam6_genl_policy_delns) - 1, + }, + { + .cmd = IOAM6_CMD_DUMP_NAMESPACES, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .start = ioam6_genl_dumpns_start, + .dumpit = ioam6_genl_dumpns, + .done = ioam6_genl_dumpns_done, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = IOAM6_CMD_ADD_SCHEMA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ioam6_genl_addsc, + .flags = GENL_ADMIN_PERM, + .policy = ioam6_genl_policy_addsc, + .maxattr = ARRAY_SIZE(ioam6_genl_policy_addsc) - 1, + }, + { + .cmd = IOAM6_CMD_DEL_SCHEMA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + 
.doit = ioam6_genl_delsc, + .flags = GENL_ADMIN_PERM, + .policy = ioam6_genl_policy_delsc, + .maxattr = ARRAY_SIZE(ioam6_genl_policy_delsc) - 1, + }, + { + .cmd = IOAM6_CMD_DUMP_SCHEMAS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .start = ioam6_genl_dumpsc_start, + .dumpit = ioam6_genl_dumpsc, + .done = ioam6_genl_dumpsc_done, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = IOAM6_CMD_NS_SET_SCHEMA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ioam6_genl_ns_set_schema, + .flags = GENL_ADMIN_PERM, + .policy = ioam6_genl_policy_ns_sc, + .maxattr = ARRAY_SIZE(ioam6_genl_policy_ns_sc) - 1, + }, +}; + +static struct genl_family ioam6_genl_family __ro_after_init = { + .name = IOAM6_GENL_NAME, + .version = IOAM6_GENL_VERSION, + .netnsok = true, + .parallel_ops = true, + .ops = ioam6_genl_ops, + .n_ops = ARRAY_SIZE(ioam6_genl_ops), + .resv_start_op = IOAM6_CMD_NS_SET_SCHEMA + 1, + .module = THIS_MODULE, +}; + +struct ioam6_namespace *ioam6_namespace(struct net *net, __be16 id) +{ + struct ioam6_pernet_data *nsdata = ioam6_pernet(net); + + return rhashtable_lookup_fast(&nsdata->namespaces, &id, rht_ns_params); +} + +static void __ioam6_fill_trace_data(struct sk_buff *skb, + struct ioam6_namespace *ns, + struct ioam6_trace_hdr *trace, + struct ioam6_schema *sc, + u8 sclen, bool is_input) +{ + struct timespec64 ts; + ktime_t tstamp; + u64 raw64; + u32 raw32; + u16 raw16; + u8 *data; + u8 byte; + + data = trace->data + trace->remlen * 4 - trace->nodelen * 4 - sclen * 4; + + /* hop_lim and node_id */ + if (trace->type.bit0) { + byte = ipv6_hdr(skb)->hop_limit; + if (is_input) + byte--; + + raw32 = dev_net(skb_dst(skb)->dev)->ipv6.sysctl.ioam6_id; + + *(__be32 *)data = cpu_to_be32((byte << 24) | raw32); + data += sizeof(__be32); + } + + /* ingress_if_id and egress_if_id */ + if (trace->type.bit1) { + if (!skb->dev) + raw16 = IOAM6_U16_UNAVAILABLE; + else + raw16 = (__force u16)__in6_dev_get(skb->dev)->cnf.ioam6_id; + + *(__be16 *)data = cpu_to_be16(raw16); + data += sizeof(__be16); + + if (skb_dst(skb)->dev->flags & IFF_LOOPBACK) + raw16 = IOAM6_U16_UNAVAILABLE; + else + raw16 = (__force u16)__in6_dev_get(skb_dst(skb)->dev)->cnf.ioam6_id; + + *(__be16 *)data = cpu_to_be16(raw16); + data += sizeof(__be16); + } + + /* timestamp seconds */ + if (trace->type.bit2) { + if (!skb->dev) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + } else { + tstamp = skb_tstamp_cond(skb, true); + ts = ktime_to_timespec64(tstamp); + + *(__be32 *)data = cpu_to_be32((u32)ts.tv_sec); + } + data += sizeof(__be32); + } + + /* timestamp subseconds */ + if (trace->type.bit3) { + if (!skb->dev) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + } else { + if (!trace->type.bit2) { + tstamp = skb_tstamp_cond(skb, true); + ts = ktime_to_timespec64(tstamp); + } + + *(__be32 *)data = cpu_to_be32((u32)(ts.tv_nsec / NSEC_PER_USEC)); + } + data += sizeof(__be32); + } + + /* transit delay */ + if (trace->type.bit4) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* namespace data */ + if (trace->type.bit5) { + *(__be32 *)data = ns->data; + data += sizeof(__be32); + } + + /* queue depth */ + if (trace->type.bit6) { + struct netdev_queue *queue; + struct Qdisc *qdisc; + __u32 qlen, backlog; + + if (skb_dst(skb)->dev->flags & IFF_LOOPBACK) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + } else { + queue = skb_get_tx_queue(skb_dst(skb)->dev, skb); + qdisc = rcu_dereference(queue->qdisc); + 
qdisc_qstats_qlen_backlog(qdisc, &qlen, &backlog); + + *(__be32 *)data = cpu_to_be32(backlog); + } + data += sizeof(__be32); + } + + /* checksum complement */ + if (trace->type.bit7) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* hop_lim and node_id (wide) */ + if (trace->type.bit8) { + byte = ipv6_hdr(skb)->hop_limit; + if (is_input) + byte--; + + raw64 = dev_net(skb_dst(skb)->dev)->ipv6.sysctl.ioam6_id_wide; + + *(__be64 *)data = cpu_to_be64(((u64)byte << 56) | raw64); + data += sizeof(__be64); + } + + /* ingress_if_id and egress_if_id (wide) */ + if (trace->type.bit9) { + if (!skb->dev) + raw32 = IOAM6_U32_UNAVAILABLE; + else + raw32 = __in6_dev_get(skb->dev)->cnf.ioam6_id_wide; + + *(__be32 *)data = cpu_to_be32(raw32); + data += sizeof(__be32); + + if (skb_dst(skb)->dev->flags & IFF_LOOPBACK) + raw32 = IOAM6_U32_UNAVAILABLE; + else + raw32 = __in6_dev_get(skb_dst(skb)->dev)->cnf.ioam6_id_wide; + + *(__be32 *)data = cpu_to_be32(raw32); + data += sizeof(__be32); + } + + /* namespace data (wide) */ + if (trace->type.bit10) { + *(__be64 *)data = ns->data_wide; + data += sizeof(__be64); + } + + /* buffer occupancy */ + if (trace->type.bit11) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit12 undefined: filled with empty value */ + if (trace->type.bit12) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit13 undefined: filled with empty value */ + if (trace->type.bit13) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit14 undefined: filled with empty value */ + if (trace->type.bit14) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit15 undefined: filled with empty value */ + if (trace->type.bit15) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit16 undefined: filled with empty value */ + if (trace->type.bit16) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit17 undefined: filled with empty value */ + if (trace->type.bit17) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit18 undefined: filled with empty value */ + if (trace->type.bit18) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit19 undefined: filled with empty value */ + if (trace->type.bit19) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit20 undefined: filled with empty value */ + if (trace->type.bit20) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* bit21 undefined: filled with empty value */ + if (trace->type.bit21) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE); + data += sizeof(__be32); + } + + /* opaque state snapshot */ + if (trace->type.bit22) { + if (!sc) { + *(__be32 *)data = cpu_to_be32(IOAM6_U32_UNAVAILABLE >> 8); + } else { + *(__be32 *)data = sc->hdr; + data += sizeof(__be32); + + memcpy(data, sc->data, sc->len); + } + } +} + +/* called with rcu_read_lock() */ +void ioam6_fill_trace_data(struct sk_buff *skb, + struct ioam6_namespace *ns, + struct ioam6_trace_hdr *trace, + bool is_input) +{ + struct ioam6_schema *sc; + u8 sclen = 0; + + /* Skip if Overflow flag is set + */ + if (trace->overflow) + return; + + /* NodeLen does not include Opaque State Snapshot length. 
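
A small user-space illustration of the space accounting used above: RemainingLen and NodeLen are counted in 4-octet words, each node writes its data just below the still-free area (data + remlen*4 - nodelen*4 - sclen*4), and the Overflow flag is set once the free area can no longer hold one node's worth of data. The sizes below are made up:

    #include <stdio.h>

    /* All lengths are in 4-octet (32-bit word) units, as in ioam6_trace_hdr. */
    struct toy_trace {
            unsigned int remlen;   /* words still free in the pre-allocated area */
            unsigned int nodelen;  /* words one node writes (without snapshot)   */
            int overflow;
    };

    /* Returns the word offset where the current node starts writing,
     * or -1 once the trace overflows. */
    static int node_write_offset(struct toy_trace *t, unsigned int sclen)
    {
            if (!t->remlen || t->remlen < t->nodelen + sclen) {
                    t->overflow = 1;
                    return -1;
            }
            /* mirrors: data + remlen*4 - nodelen*4 - sclen*4 */
            int off = t->remlen - t->nodelen - sclen;

            t->remlen -= t->nodelen + sclen;
            return off;
    }

    int main(void)
    {
            struct toy_trace t = { .remlen = 10, .nodelen = 3, .overflow = 0 };
            unsigned int sclen = 0; /* no opaque state snapshot in this example */

            for (int hop = 1; hop <= 4; hop++) {
                    int off = node_write_offset(&t, sclen);

                    if (off < 0) {
                            printf("hop %d: overflow, nothing written\n", hop);
                            break;
                    }
                    printf("hop %d: writes words [%d..%d), remlen now %u\n",
                           hop, off, off + (int)(t.nodelen + sclen), t.remlen);
            }
            return 0;
    }

With remlen=10 and nodelen=3 the first three hops write at word offsets 7, 4 and 1, and the fourth hop sets the Overflow flag, matching the check performed a few lines further down.
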
We need to + * take it into account if the corresponding bit is set (bit 22) and + * if the current IOAM namespace has an active schema attached to it + */ + sc = rcu_dereference(ns->schema); + if (trace->type.bit22) { + sclen = sizeof_field(struct ioam6_schema, hdr) / 4; + + if (sc) + sclen += sc->len / 4; + } + + /* If there is no space remaining, we set the Overflow flag and we + * skip without filling the trace + */ + if (!trace->remlen || trace->remlen < trace->nodelen + sclen) { + trace->overflow = 1; + return; + } + + __ioam6_fill_trace_data(skb, ns, trace, sc, sclen, is_input); + trace->remlen -= trace->nodelen + sclen; +} + +static int __net_init ioam6_net_init(struct net *net) +{ + struct ioam6_pernet_data *nsdata; + int err = -ENOMEM; + + nsdata = kzalloc(sizeof(*nsdata), GFP_KERNEL); + if (!nsdata) + goto out; + + mutex_init(&nsdata->lock); + net->ipv6.ioam6_data = nsdata; + + err = rhashtable_init(&nsdata->namespaces, &rht_ns_params); + if (err) + goto free_nsdata; + + err = rhashtable_init(&nsdata->schemas, &rht_sc_params); + if (err) + goto free_rht_ns; + +out: + return err; +free_rht_ns: + rhashtable_destroy(&nsdata->namespaces); +free_nsdata: + kfree(nsdata); + net->ipv6.ioam6_data = NULL; + goto out; +} + +static void __net_exit ioam6_net_exit(struct net *net) +{ + struct ioam6_pernet_data *nsdata = ioam6_pernet(net); + + rhashtable_free_and_destroy(&nsdata->namespaces, ioam6_free_ns, NULL); + rhashtable_free_and_destroy(&nsdata->schemas, ioam6_free_sc, NULL); + + kfree(nsdata); +} + +static struct pernet_operations ioam6_net_ops = { + .init = ioam6_net_init, + .exit = ioam6_net_exit, +}; + +int __init ioam6_init(void) +{ + int err = register_pernet_subsys(&ioam6_net_ops); + if (err) + goto out; + + err = genl_register_family(&ioam6_genl_family); + if (err) + goto out_unregister_pernet_subsys; + +#ifdef CONFIG_IPV6_IOAM6_LWTUNNEL + err = ioam6_iptunnel_init(); + if (err) + goto out_unregister_genl; +#endif + + pr_info("In-situ OAM (IOAM) with IPv6\n"); + +out: + return err; +#ifdef CONFIG_IPV6_IOAM6_LWTUNNEL +out_unregister_genl: + genl_unregister_family(&ioam6_genl_family); +#endif +out_unregister_pernet_subsys: + unregister_pernet_subsys(&ioam6_net_ops); + goto out; +} + +void ioam6_exit(void) +{ +#ifdef CONFIG_IPV6_IOAM6_LWTUNNEL + ioam6_iptunnel_exit(); +#endif + genl_unregister_family(&ioam6_genl_family); + unregister_pernet_subsys(&ioam6_net_ops); +} diff --git a/net/ipv6/ioam6_iptunnel.c b/net/ipv6/ioam6_iptunnel.c new file mode 100644 index 000000000..f6f5b83dd --- /dev/null +++ b/net/ipv6/ioam6_iptunnel.c @@ -0,0 +1,477 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * IPv6 IOAM Lightweight Tunnel implementation + * + * Author: + * Justin Iurman + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define IOAM6_MASK_SHORT_FIELDS 0xff100000 +#define IOAM6_MASK_WIDE_FIELDS 0xe00000 + +struct ioam6_lwt_encap { + struct ipv6_hopopt_hdr eh; + u8 pad[2]; /* 2-octet padding for 4n-alignment */ + struct ioam6_hdr ioamh; + struct ioam6_trace_hdr traceh; +} __packed; + +struct ioam6_lwt_freq { + u32 k; + u32 n; +}; + +struct ioam6_lwt { + struct dst_cache cache; + struct ioam6_lwt_freq freq; + atomic_t pkt_cnt; + u8 mode; + struct in6_addr tundst; + struct ioam6_lwt_encap tuninfo; +}; + +static struct netlink_range_validation freq_range = { + .min = IOAM6_IPTUNNEL_FREQ_MIN, + .max = IOAM6_IPTUNNEL_FREQ_MAX, +}; + +static struct ioam6_lwt *ioam6_lwt_state(struct 
lwtunnel_state *lwt) +{ + return (struct ioam6_lwt *)lwt->data; +} + +static struct ioam6_lwt_encap *ioam6_lwt_info(struct lwtunnel_state *lwt) +{ + return &ioam6_lwt_state(lwt)->tuninfo; +} + +static struct ioam6_trace_hdr *ioam6_lwt_trace(struct lwtunnel_state *lwt) +{ + return &(ioam6_lwt_state(lwt)->tuninfo.traceh); +} + +static const struct nla_policy ioam6_iptunnel_policy[IOAM6_IPTUNNEL_MAX + 1] = { + [IOAM6_IPTUNNEL_FREQ_K] = NLA_POLICY_FULL_RANGE(NLA_U32, &freq_range), + [IOAM6_IPTUNNEL_FREQ_N] = NLA_POLICY_FULL_RANGE(NLA_U32, &freq_range), + [IOAM6_IPTUNNEL_MODE] = NLA_POLICY_RANGE(NLA_U8, + IOAM6_IPTUNNEL_MODE_MIN, + IOAM6_IPTUNNEL_MODE_MAX), + [IOAM6_IPTUNNEL_DST] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [IOAM6_IPTUNNEL_TRACE] = NLA_POLICY_EXACT_LEN(sizeof(struct ioam6_trace_hdr)), +}; + +static bool ioam6_validate_trace_hdr(struct ioam6_trace_hdr *trace) +{ + u32 fields; + + if (!trace->type_be32 || !trace->remlen || + trace->remlen > IOAM6_TRACE_DATA_SIZE_MAX / 4 || + trace->type.bit12 | trace->type.bit13 | trace->type.bit14 | + trace->type.bit15 | trace->type.bit16 | trace->type.bit17 | + trace->type.bit18 | trace->type.bit19 | trace->type.bit20 | + trace->type.bit21) + return false; + + trace->nodelen = 0; + fields = be32_to_cpu(trace->type_be32); + + trace->nodelen += hweight32(fields & IOAM6_MASK_SHORT_FIELDS) + * (sizeof(__be32) / 4); + trace->nodelen += hweight32(fields & IOAM6_MASK_WIDE_FIELDS) + * (sizeof(__be64) / 4); + + return true; +} + +static int ioam6_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IOAM6_IPTUNNEL_MAX + 1]; + struct ioam6_lwt_encap *tuninfo; + struct ioam6_trace_hdr *trace; + struct lwtunnel_state *lwt; + struct ioam6_lwt *ilwt; + int len_aligned, err; + u32 freq_k, freq_n; + u8 mode; + + if (family != AF_INET6) + return -EINVAL; + + err = nla_parse_nested(tb, IOAM6_IPTUNNEL_MAX, nla, + ioam6_iptunnel_policy, extack); + if (err < 0) + return err; + + if ((!tb[IOAM6_IPTUNNEL_FREQ_K] && tb[IOAM6_IPTUNNEL_FREQ_N]) || + (tb[IOAM6_IPTUNNEL_FREQ_K] && !tb[IOAM6_IPTUNNEL_FREQ_N])) { + NL_SET_ERR_MSG(extack, "freq: missing parameter"); + return -EINVAL; + } else if (!tb[IOAM6_IPTUNNEL_FREQ_K] && !tb[IOAM6_IPTUNNEL_FREQ_N]) { + freq_k = IOAM6_IPTUNNEL_FREQ_MIN; + freq_n = IOAM6_IPTUNNEL_FREQ_MIN; + } else { + freq_k = nla_get_u32(tb[IOAM6_IPTUNNEL_FREQ_K]); + freq_n = nla_get_u32(tb[IOAM6_IPTUNNEL_FREQ_N]); + + if (freq_k > freq_n) { + NL_SET_ERR_MSG(extack, "freq: k > n is forbidden"); + return -EINVAL; + } + } + + if (!tb[IOAM6_IPTUNNEL_MODE]) + mode = IOAM6_IPTUNNEL_MODE_INLINE; + else + mode = nla_get_u8(tb[IOAM6_IPTUNNEL_MODE]); + + if (!tb[IOAM6_IPTUNNEL_DST] && mode != IOAM6_IPTUNNEL_MODE_INLINE) { + NL_SET_ERR_MSG(extack, "this mode needs a tunnel destination"); + return -EINVAL; + } + + if (!tb[IOAM6_IPTUNNEL_TRACE]) { + NL_SET_ERR_MSG(extack, "missing trace"); + return -EINVAL; + } + + trace = nla_data(tb[IOAM6_IPTUNNEL_TRACE]); + if (!ioam6_validate_trace_hdr(trace)) { + NL_SET_ERR_MSG_ATTR(extack, tb[IOAM6_IPTUNNEL_TRACE], + "invalid trace validation"); + return -EINVAL; + } + + len_aligned = ALIGN(trace->remlen * 4, 8); + lwt = lwtunnel_state_alloc(sizeof(*ilwt) + len_aligned); + if (!lwt) + return -ENOMEM; + + ilwt = ioam6_lwt_state(lwt); + err = dst_cache_init(&ilwt->cache, GFP_ATOMIC); + if (err) { + kfree(lwt); + return err; + } + + atomic_set(&ilwt->pkt_cnt, 0); + ilwt->freq.k = freq_k; + ilwt->freq.n 
= freq_n; + + ilwt->mode = mode; + if (tb[IOAM6_IPTUNNEL_DST]) + ilwt->tundst = nla_get_in6_addr(tb[IOAM6_IPTUNNEL_DST]); + + tuninfo = ioam6_lwt_info(lwt); + tuninfo->eh.hdrlen = ((sizeof(*tuninfo) + len_aligned) >> 3) - 1; + tuninfo->pad[0] = IPV6_TLV_PADN; + tuninfo->ioamh.type = IOAM6_TYPE_PREALLOC; + tuninfo->ioamh.opt_type = IPV6_TLV_IOAM; + tuninfo->ioamh.opt_len = sizeof(tuninfo->ioamh) - 2 + sizeof(*trace) + + trace->remlen * 4; + + memcpy(&tuninfo->traceh, trace, sizeof(*trace)); + + if (len_aligned - trace->remlen * 4) { + tuninfo->traceh.data[trace->remlen * 4] = IPV6_TLV_PADN; + tuninfo->traceh.data[trace->remlen * 4 + 1] = 2; + } + + lwt->type = LWTUNNEL_ENCAP_IOAM6; + lwt->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT; + + *ts = lwt; + + return 0; +} + +static int ioam6_do_fill(struct net *net, struct sk_buff *skb) +{ + struct ioam6_trace_hdr *trace; + struct ioam6_namespace *ns; + + trace = (struct ioam6_trace_hdr *)(skb_transport_header(skb) + + sizeof(struct ipv6_hopopt_hdr) + 2 + + sizeof(struct ioam6_hdr)); + + ns = ioam6_namespace(net, trace->namespace_id); + if (ns) + ioam6_fill_trace_data(skb, ns, trace, false); + + return 0; +} + +static int ioam6_do_inline(struct net *net, struct sk_buff *skb, + struct ioam6_lwt_encap *tuninfo) +{ + struct ipv6hdr *oldhdr, *hdr; + int hdrlen, err; + + hdrlen = (tuninfo->eh.hdrlen + 1) << 3; + + err = skb_cow_head(skb, hdrlen + skb->mac_len); + if (unlikely(err)) + return err; + + oldhdr = ipv6_hdr(skb); + skb_pull(skb, sizeof(*oldhdr)); + skb_postpull_rcsum(skb, skb_network_header(skb), sizeof(*oldhdr)); + + skb_push(skb, sizeof(*oldhdr) + hdrlen); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + + hdr = ipv6_hdr(skb); + memmove(hdr, oldhdr, sizeof(*oldhdr)); + tuninfo->eh.nexthdr = hdr->nexthdr; + + skb_set_transport_header(skb, sizeof(*hdr)); + skb_postpush_rcsum(skb, hdr, sizeof(*hdr) + hdrlen); + + memcpy(skb_transport_header(skb), (u8 *)tuninfo, hdrlen); + + hdr->nexthdr = NEXTHDR_HOP; + hdr->payload_len = cpu_to_be16(skb->len - sizeof(*hdr)); + + return ioam6_do_fill(net, skb); +} + +static int ioam6_do_encap(struct net *net, struct sk_buff *skb, + struct ioam6_lwt_encap *tuninfo, + struct in6_addr *tundst) +{ + struct dst_entry *dst = skb_dst(skb); + struct ipv6hdr *hdr, *inner_hdr; + int hdrlen, len, err; + + hdrlen = (tuninfo->eh.hdrlen + 1) << 3; + len = sizeof(*hdr) + hdrlen; + + err = skb_cow_head(skb, len + skb->mac_len); + if (unlikely(err)) + return err; + + inner_hdr = ipv6_hdr(skb); + + skb_push(skb, len); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + skb_set_transport_header(skb, sizeof(*hdr)); + + tuninfo->eh.nexthdr = NEXTHDR_IPV6; + memcpy(skb_transport_header(skb), (u8 *)tuninfo, hdrlen); + + hdr = ipv6_hdr(skb); + memcpy(hdr, inner_hdr, sizeof(*hdr)); + + hdr->nexthdr = NEXTHDR_HOP; + hdr->payload_len = cpu_to_be16(skb->len - sizeof(*hdr)); + hdr->daddr = *tundst; + ipv6_dev_get_saddr(net, dst->dev, &hdr->daddr, + IPV6_PREFER_SRC_PUBLIC, &hdr->saddr); + + skb_postpush_rcsum(skb, hdr, len); + + return ioam6_do_fill(net, skb); +} + +static int ioam6_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct in6_addr orig_daddr; + struct ioam6_lwt *ilwt; + int err = -EINVAL; + u32 pkt_cnt; + + if (skb->protocol != htons(ETH_P_IPV6)) + goto drop; + + ilwt = ioam6_lwt_state(dst->lwtstate); + + /* Check for insertion frequency (i.e., "k over n" insertions) */ + pkt_cnt = atomic_fetch_inc(&ilwt->pkt_cnt); + if (pkt_cnt % 
ilwt->freq.n >= ilwt->freq.k) + goto out; + + orig_daddr = ipv6_hdr(skb)->daddr; + + switch (ilwt->mode) { + case IOAM6_IPTUNNEL_MODE_INLINE: +do_inline: + /* Direct insertion - if there is no Hop-by-Hop yet */ + if (ipv6_hdr(skb)->nexthdr == NEXTHDR_HOP) + goto out; + + err = ioam6_do_inline(net, skb, &ilwt->tuninfo); + if (unlikely(err)) + goto drop; + + break; + case IOAM6_IPTUNNEL_MODE_ENCAP: +do_encap: + /* Encapsulation (ip6ip6) */ + err = ioam6_do_encap(net, skb, &ilwt->tuninfo, &ilwt->tundst); + if (unlikely(err)) + goto drop; + + break; + case IOAM6_IPTUNNEL_MODE_AUTO: + /* Automatic (RFC8200 compliant): + * - local packets -> INLINE mode + * - in-transit packets -> ENCAP mode + */ + if (!skb->dev) + goto do_inline; + + goto do_encap; + default: + goto drop; + } + + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + goto drop; + + if (!ipv6_addr_equal(&orig_daddr, &ipv6_hdr(skb)->daddr)) { + preempt_disable(); + dst = dst_cache_get(&ilwt->cache); + preempt_enable(); + + if (unlikely(!dst)) { + struct ipv6hdr *hdr = ipv6_hdr(skb); + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + fl6.daddr = hdr->daddr; + fl6.saddr = hdr->saddr; + fl6.flowlabel = ip6_flowinfo(hdr); + fl6.flowi6_mark = skb->mark; + fl6.flowi6_proto = hdr->nexthdr; + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + err = dst->error; + dst_release(dst); + goto drop; + } + + preempt_disable(); + dst_cache_set_ip6(&ilwt->cache, dst, &fl6.saddr); + preempt_enable(); + } + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + return dst_output(net, sk, skb); + } +out: + return dst->lwtstate->orig_output(net, sk, skb); +drop: + kfree_skb(skb); + return err; +} + +static void ioam6_destroy_state(struct lwtunnel_state *lwt) +{ + dst_cache_destroy(&ioam6_lwt_state(lwt)->cache); +} + +static int ioam6_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct ioam6_lwt *ilwt = ioam6_lwt_state(lwtstate); + int err; + + err = nla_put_u32(skb, IOAM6_IPTUNNEL_FREQ_K, ilwt->freq.k); + if (err) + goto ret; + + err = nla_put_u32(skb, IOAM6_IPTUNNEL_FREQ_N, ilwt->freq.n); + if (err) + goto ret; + + err = nla_put_u8(skb, IOAM6_IPTUNNEL_MODE, ilwt->mode); + if (err) + goto ret; + + if (ilwt->mode != IOAM6_IPTUNNEL_MODE_INLINE) { + err = nla_put_in6_addr(skb, IOAM6_IPTUNNEL_DST, &ilwt->tundst); + if (err) + goto ret; + } + + err = nla_put(skb, IOAM6_IPTUNNEL_TRACE, sizeof(ilwt->tuninfo.traceh), + &ilwt->tuninfo.traceh); +ret: + return err; +} + +static int ioam6_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + struct ioam6_lwt *ilwt = ioam6_lwt_state(lwtstate); + int nlsize; + + nlsize = nla_total_size(sizeof(ilwt->freq.k)) + + nla_total_size(sizeof(ilwt->freq.n)) + + nla_total_size(sizeof(ilwt->mode)) + + nla_total_size(sizeof(ilwt->tuninfo.traceh)); + + if (ilwt->mode != IOAM6_IPTUNNEL_MODE_INLINE) + nlsize += nla_total_size(sizeof(ilwt->tundst)); + + return nlsize; +} + +static int ioam6_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct ioam6_trace_hdr *trace_a = ioam6_lwt_trace(a); + struct ioam6_trace_hdr *trace_b = ioam6_lwt_trace(b); + struct ioam6_lwt *ilwt_a = ioam6_lwt_state(a); + struct ioam6_lwt *ilwt_b = ioam6_lwt_state(b); + + return (ilwt_a->freq.k != ilwt_b->freq.k || + ilwt_a->freq.n != ilwt_b->freq.n || + ilwt_a->mode != ilwt_b->mode || + (ilwt_a->mode != IOAM6_IPTUNNEL_MODE_INLINE && + !ipv6_addr_equal(&ilwt_a->tundst, &ilwt_b->tundst)) || + trace_a->namespace_id != trace_b->namespace_id); +} + +static const struct 
lwtunnel_encap_ops ioam6_iptun_ops = { + .build_state = ioam6_build_state, + .destroy_state = ioam6_destroy_state, + .output = ioam6_output, + .fill_encap = ioam6_fill_encap_info, + .get_encap_size = ioam6_encap_nlsize, + .cmp_encap = ioam6_encap_cmp, + .owner = THIS_MODULE, +}; + +int __init ioam6_iptunnel_init(void) +{ + return lwtunnel_encap_add_ops(&ioam6_iptun_ops, LWTUNNEL_ENCAP_IOAM6); +} + +void ioam6_iptunnel_exit(void) +{ + lwtunnel_encap_del_ops(&ioam6_iptun_ops, LWTUNNEL_ENCAP_IOAM6); +} diff --git a/net/ipv6/ip6_checksum.c b/net/ipv6/ip6_checksum.c new file mode 100644 index 000000000..377717045 --- /dev/null +++ b/net/ipv6/ip6_checksum.c @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +#ifndef _HAVE_ARCH_IPV6_CSUM +__sum16 csum_ipv6_magic(const struct in6_addr *saddr, + const struct in6_addr *daddr, + __u32 len, __u8 proto, __wsum csum) +{ + + int carry; + __u32 ulen; + __u32 uproto; + __u32 sum = (__force u32)csum; + + sum += (__force u32)saddr->s6_addr32[0]; + carry = (sum < (__force u32)saddr->s6_addr32[0]); + sum += carry; + + sum += (__force u32)saddr->s6_addr32[1]; + carry = (sum < (__force u32)saddr->s6_addr32[1]); + sum += carry; + + sum += (__force u32)saddr->s6_addr32[2]; + carry = (sum < (__force u32)saddr->s6_addr32[2]); + sum += carry; + + sum += (__force u32)saddr->s6_addr32[3]; + carry = (sum < (__force u32)saddr->s6_addr32[3]); + sum += carry; + + sum += (__force u32)daddr->s6_addr32[0]; + carry = (sum < (__force u32)daddr->s6_addr32[0]); + sum += carry; + + sum += (__force u32)daddr->s6_addr32[1]; + carry = (sum < (__force u32)daddr->s6_addr32[1]); + sum += carry; + + sum += (__force u32)daddr->s6_addr32[2]; + carry = (sum < (__force u32)daddr->s6_addr32[2]); + sum += carry; + + sum += (__force u32)daddr->s6_addr32[3]; + carry = (sum < (__force u32)daddr->s6_addr32[3]); + sum += carry; + + ulen = (__force u32)htonl((__u32) len); + sum += ulen; + carry = (sum < ulen); + sum += carry; + + uproto = (__force u32)htonl(proto); + sum += uproto; + carry = (sum < uproto); + sum += carry; + + return csum_fold((__force __wsum)sum); +} +EXPORT_SYMBOL(csum_ipv6_magic); +#endif + +int udp6_csum_init(struct sk_buff *skb, struct udphdr *uh, int proto) +{ + int err; + + UDP_SKB_CB(skb)->partial_cov = 0; + UDP_SKB_CB(skb)->cscov = skb->len; + + if (proto == IPPROTO_UDPLITE) { + err = udplite_checksum_init(skb, uh); + if (err) + return err; + + if (UDP_SKB_CB(skb)->partial_cov) { + skb->csum = ip6_compute_pseudo(skb, proto); + return 0; + } + } + + /* To support RFC 6936 (allow zero checksum in UDP/IPV6 for tunnels) + * we accept a checksum of zero here. When we find the socket + * for the UDP packet we'll check if that socket allows zero checksum + * for IPv6 (set by socket option). + * + * Note, we are only interested in != 0 or == 0, thus the + * force to int. + */ + err = (__force int)skb_checksum_init_zero_check(skb, proto, uh->check, + ip6_compute_pseudo); + if (err) + return err; + + if (skb->ip_summed == CHECKSUM_COMPLETE && !skb->csum_valid) { + /* If SW calculated the value, we know it's bad */ + if (skb->csum_complete_sw) + return 1; + + /* HW says the value is bad. Let's validate that. + * skb->csum is no longer the full packet checksum, + * so don't treat is as such. + */ + skb_checksum_complete_unset(skb); + } + + return 0; +} +EXPORT_SYMBOL(udp6_csum_init); + +/* Function to set UDP checksum for an IPv6 UDP packet. This is intended + * for the simple case like when setting the checksum for a UDP tunnel. 
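
A simplified user-space rendering of the pattern used by csum_ipv6_magic() above: accumulate 32-bit words with an end-around carry, then fold the result to 16 bits and complement it. fold16() here is the customary 32-to-16 fold, not a copy of the kernel's csum_fold():

    #include <stdint.h>
    #include <stdio.h>

    /* End-around-carry addition of one 32-bit word, as done repeatedly in
     * csum_ipv6_magic(): if the addition wraps, add the carry back in. */
    static uint32_t add32_with_carry(uint32_t sum, uint32_t x)
    {
            sum += x;
            sum += (sum < x);       /* 1 exactly when the addition overflowed */
            return sum;
    }

    /* Fold a 32-bit accumulator into the final 16-bit ones'-complement sum. */
    static uint16_t fold16(uint32_t sum)
    {
            sum = (sum & 0xffff) + (sum >> 16);
            sum += sum >> 16;
            return (uint16_t)~sum;
    }

    int main(void)
    {
            /* Made-up pseudo-header words, chosen to exercise the carry path. */
            uint32_t words[] = { 0xffffffff, 0x00000001, 0x12345678 };
            uint32_t sum = 0;

            for (unsigned int i = 0; i < sizeof(words) / sizeof(words[0]); i++)
                    sum = add32_with_carry(sum, words[i]);

            printf("checksum: 0x%04x\n", fold16(sum));
            return 0;
    }

Byte-order handling of the length and protocol fields is omitted here; the point is only the carry propagation that the kernel spells out word by word.
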
+ */ +void udp6_set_csum(bool nocheck, struct sk_buff *skb, + const struct in6_addr *saddr, + const struct in6_addr *daddr, int len) +{ + struct udphdr *uh = udp_hdr(skb); + + if (nocheck) + uh->check = 0; + else if (skb_is_gso(skb)) + uh->check = ~udp_v6_check(len, saddr, daddr, 0); + else if (skb->ip_summed == CHECKSUM_PARTIAL) { + uh->check = 0; + uh->check = udp_v6_check(len, saddr, daddr, lco_csum(skb)); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } else { + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + uh->check = ~udp_v6_check(len, saddr, daddr, 0); + } +} +EXPORT_SYMBOL(udp6_set_csum); diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c new file mode 100644 index 000000000..1840735e9 --- /dev/null +++ b/net/ipv6/ip6_fib.c @@ -0,0 +1,2718 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux INET6 implementation + * Forwarding Information Database + * + * Authors: + * Pedro Roque + * + * Changes: + * Yuji SEKIYA @USAGI: Support default route on router node; + * remove ip6_null_entry from the top of + * routing table. + * Ville Nuorvala: Fixed routing subtrees. + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +static struct kmem_cache *fib6_node_kmem __read_mostly; + +struct fib6_cleaner { + struct fib6_walker w; + struct net *net; + int (*func)(struct fib6_info *, void *arg); + int sernum; + void *arg; + bool skip_notify; +}; + +#ifdef CONFIG_IPV6_SUBTREES +#define FWS_INIT FWS_S +#else +#define FWS_INIT FWS_L +#endif + +static struct fib6_info *fib6_find_prefix(struct net *net, + struct fib6_table *table, + struct fib6_node *fn); +static struct fib6_node *fib6_repair_tree(struct net *net, + struct fib6_table *table, + struct fib6_node *fn); +static int fib6_walk(struct net *net, struct fib6_walker *w); +static int fib6_walk_continue(struct fib6_walker *w); + +/* + * A routing update causes an increase of the serial number on the + * affected subtree. This allows for cached routes to be asynchronously + * tested when modifications are made to the destination cache as a + * result of redirects, path MTU changes, etc. + */ + +static void fib6_gc_timer_cb(struct timer_list *t); + +#define FOR_WALKERS(net, w) \ + list_for_each_entry(w, &(net)->ipv6.fib6_walkers, lh) + +static void fib6_walker_link(struct net *net, struct fib6_walker *w) +{ + write_lock_bh(&net->ipv6.fib6_walker_lock); + list_add(&w->lh, &net->ipv6.fib6_walkers); + write_unlock_bh(&net->ipv6.fib6_walker_lock); +} + +static void fib6_walker_unlink(struct net *net, struct fib6_walker *w) +{ + write_lock_bh(&net->ipv6.fib6_walker_lock); + list_del(&w->lh); + write_unlock_bh(&net->ipv6.fib6_walker_lock); +} + +static int fib6_new_sernum(struct net *net) +{ + int new, old; + + do { + old = atomic_read(&net->ipv6.fib6_sernum); + new = old < INT_MAX ? old + 1 : 1; + } while (atomic_cmpxchg(&net->ipv6.fib6_sernum, + old, new) != old); + return new; +} + +enum { + FIB6_NO_SERNUM_CHANGE = 0, +}; + +void fib6_update_sernum(struct net *net, struct fib6_info *f6i) +{ + struct fib6_node *fn; + + fn = rcu_dereference_protected(f6i->fib6_node, + lockdep_is_held(&f6i->fib6_table->tb6_lock)); + if (fn) + WRITE_ONCE(fn->fn_sernum, fib6_new_sernum(net)); +} + +/* + * Auxiliary address test functions for the radix tree. 
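
For the serial-number handling earlier in this file, a C11 user-space sketch of the wrap-around in fib6_new_sernum(): the counter goes from INT_MAX back to 1, so that 0 stays reserved as FIB6_NO_SERNUM_CHANGE:

    #include <limits.h>
    #include <stdatomic.h>
    #include <stdio.h>

    /* Bump a serial number the way fib6_new_sernum() does. */
    static int bump_sernum(atomic_int *sernum)
    {
            int old, next;

            do {
                    old = atomic_load(sernum);
                    next = old < INT_MAX ? old + 1 : 1;
            } while (!atomic_compare_exchange_weak(sernum, &old, next));

            return next;
    }

    int main(void)
    {
            atomic_int s;

            atomic_init(&s, INT_MAX - 1);
            printf("%d\n", bump_sernum(&s));        /* INT_MAX */
            printf("%d\n", bump_sernum(&s));        /* wraps to 1, never 0 */
            return 0;
    }
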
+ * + * These assume a 32bit processor (although it will work on + * 64bit processors) + */ + +/* + * test bit + */ +#if defined(__LITTLE_ENDIAN) +# define BITOP_BE32_SWIZZLE (0x1F & ~7) +#else +# define BITOP_BE32_SWIZZLE 0 +#endif + +static __be32 addr_bit_set(const void *token, int fn_bit) +{ + const __be32 *addr = token; + /* + * Here, + * 1 << ((~fn_bit ^ BITOP_BE32_SWIZZLE) & 0x1f) + * is optimized version of + * htonl(1 << ((~fn_bit)&0x1F)) + * See include/asm-generic/bitops/le.h. + */ + return (__force __be32)(1 << ((~fn_bit ^ BITOP_BE32_SWIZZLE) & 0x1f)) & + addr[fn_bit >> 5]; +} + +struct fib6_info *fib6_info_alloc(gfp_t gfp_flags, bool with_fib6_nh) +{ + struct fib6_info *f6i; + size_t sz = sizeof(*f6i); + + if (with_fib6_nh) + sz += sizeof(struct fib6_nh); + + f6i = kzalloc(sz, gfp_flags); + if (!f6i) + return NULL; + + /* fib6_siblings is a union with nh_list, so this initializes both */ + INIT_LIST_HEAD(&f6i->fib6_siblings); + refcount_set(&f6i->fib6_ref, 1); + + return f6i; +} + +void fib6_info_destroy_rcu(struct rcu_head *head) +{ + struct fib6_info *f6i = container_of(head, struct fib6_info, rcu); + + WARN_ON(f6i->fib6_node); + + if (f6i->nh) + nexthop_put(f6i->nh); + else + fib6_nh_release(f6i->fib6_nh); + + ip_fib_metrics_put(f6i->fib6_metrics); + kfree(f6i); +} +EXPORT_SYMBOL_GPL(fib6_info_destroy_rcu); + +static struct fib6_node *node_alloc(struct net *net) +{ + struct fib6_node *fn; + + fn = kmem_cache_zalloc(fib6_node_kmem, GFP_ATOMIC); + if (fn) + net->ipv6.rt6_stats->fib_nodes++; + + return fn; +} + +static void node_free_immediate(struct net *net, struct fib6_node *fn) +{ + kmem_cache_free(fib6_node_kmem, fn); + net->ipv6.rt6_stats->fib_nodes--; +} + +static void node_free_rcu(struct rcu_head *head) +{ + struct fib6_node *fn = container_of(head, struct fib6_node, rcu); + + kmem_cache_free(fib6_node_kmem, fn); +} + +static void node_free(struct net *net, struct fib6_node *fn) +{ + call_rcu(&fn->rcu, node_free_rcu); + net->ipv6.rt6_stats->fib_nodes--; +} + +static void fib6_free_table(struct fib6_table *table) +{ + inetpeer_invalidate_tree(&table->tb6_peers); + kfree(table); +} + +static void fib6_link_table(struct net *net, struct fib6_table *tb) +{ + unsigned int h; + + /* + * Initialize table lock at a single place to give lockdep a key, + * tables aren't visible prior to being linked to the list. + */ + spin_lock_init(&tb->tb6_lock); + h = tb->tb6_id & (FIB6_TABLE_HASHSZ - 1); + + /* + * No protection necessary, this is the only list mutatation + * operation, tables never disappear once they exist. 
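
A plain user-space version of the bit test performed by addr_bit_set() above, written in the unoptimized htonl() form mentioned in its comment; the helper name addr_bit() is illustrative:

    #include <stdio.h>
    #include <string.h>
    #include <stdint.h>
    #include <arpa/inet.h>
    #include <netinet/in.h>

    /* Test bit 'fn_bit' of an IPv6 address, counting from the most
     * significant bit, with the address viewed as four 32-bit words in
     * network byte order (the view addr_bit_set() uses for the radix tree). */
    static int addr_bit(const struct in6_addr *addr, int fn_bit)
    {
            uint32_t word;

            memcpy(&word, addr->s6_addr + 4 * (fn_bit >> 5), sizeof(word));
            return (word & htonl(1u << (31 - (fn_bit & 0x1f)))) != 0;
    }

    int main(void)
    {
            struct in6_addr a;

            inet_pton(AF_INET6, "2001:db8::", &a);

            for (int bit = 0; bit < 16; bit++)
                    printf("%d", addr_bit(&a, bit));
            printf("\n");   /* prints 0010000000000001, i.e. 0x2001 */
            return 0;
    }
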
+ */ + hlist_add_head_rcu(&tb->tb6_hlist, &net->ipv6.fib_table_hash[h]); +} + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + +static struct fib6_table *fib6_alloc_table(struct net *net, u32 id) +{ + struct fib6_table *table; + + table = kzalloc(sizeof(*table), GFP_ATOMIC); + if (table) { + table->tb6_id = id; + rcu_assign_pointer(table->tb6_root.leaf, + net->ipv6.fib6_null_entry); + table->tb6_root.fn_flags = RTN_ROOT | RTN_TL_ROOT | RTN_RTINFO; + inet_peer_base_init(&table->tb6_peers); + } + + return table; +} + +struct fib6_table *fib6_new_table(struct net *net, u32 id) +{ + struct fib6_table *tb; + + if (id == 0) + id = RT6_TABLE_MAIN; + tb = fib6_get_table(net, id); + if (tb) + return tb; + + tb = fib6_alloc_table(net, id); + if (tb) + fib6_link_table(net, tb); + + return tb; +} +EXPORT_SYMBOL_GPL(fib6_new_table); + +struct fib6_table *fib6_get_table(struct net *net, u32 id) +{ + struct fib6_table *tb; + struct hlist_head *head; + unsigned int h; + + if (id == 0) + id = RT6_TABLE_MAIN; + h = id & (FIB6_TABLE_HASHSZ - 1); + rcu_read_lock(); + head = &net->ipv6.fib_table_hash[h]; + hlist_for_each_entry_rcu(tb, head, tb6_hlist) { + if (tb->tb6_id == id) { + rcu_read_unlock(); + return tb; + } + } + rcu_read_unlock(); + + return NULL; +} +EXPORT_SYMBOL_GPL(fib6_get_table); + +static void __net_init fib6_tables_init(struct net *net) +{ + fib6_link_table(net, net->ipv6.fib6_main_tbl); + fib6_link_table(net, net->ipv6.fib6_local_tbl); +} +#else + +struct fib6_table *fib6_new_table(struct net *net, u32 id) +{ + return fib6_get_table(net, id); +} + +struct fib6_table *fib6_get_table(struct net *net, u32 id) +{ + return net->ipv6.fib6_main_tbl; +} + +struct dst_entry *fib6_rule_lookup(struct net *net, struct flowi6 *fl6, + const struct sk_buff *skb, + int flags, pol_lookup_t lookup) +{ + struct rt6_info *rt; + + rt = pol_lookup_func(lookup, + net, net->ipv6.fib6_main_tbl, fl6, skb, flags); + if (rt->dst.error == -EAGAIN) { + ip6_rt_put_flags(rt, flags); + rt = net->ipv6.ip6_null_entry; + if (!(flags & RT6_LOOKUP_F_DST_NOREF)) + dst_hold(&rt->dst); + } + + return &rt->dst; +} + +/* called with rcu lock held; no reference taken on fib6_info */ +int fib6_lookup(struct net *net, int oif, struct flowi6 *fl6, + struct fib6_result *res, int flags) +{ + return fib6_table_lookup(net, net->ipv6.fib6_main_tbl, oif, fl6, + res, flags); +} + +static void __net_init fib6_tables_init(struct net *net) +{ + fib6_link_table(net, net->ipv6.fib6_main_tbl); +} + +#endif + +unsigned int fib6_tables_seq_read(struct net *net) +{ + unsigned int h, fib_seq = 0; + + rcu_read_lock(); + for (h = 0; h < FIB6_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv6.fib_table_hash[h]; + struct fib6_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb6_hlist) + fib_seq += tb->fib_seq; + } + rcu_read_unlock(); + + return fib_seq; +} + +static int call_fib6_entry_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib6_info *rt, + struct netlink_ext_ack *extack) +{ + struct fib6_entry_notifier_info info = { + .info.extack = extack, + .rt = rt, + }; + + return call_fib6_notifier(nb, event_type, &info.info); +} + +static int call_fib6_multipath_entry_notifier(struct notifier_block *nb, + enum fib_event_type event_type, + struct fib6_info *rt, + unsigned int nsiblings, + struct netlink_ext_ack *extack) +{ + struct fib6_entry_notifier_info info = { + .info.extack = extack, + .rt = rt, + .nsiblings = nsiblings, + }; + + return call_fib6_notifier(nb, event_type, &info.info); +} + +int 
call_fib6_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct fib6_info *rt, + struct netlink_ext_ack *extack) +{ + struct fib6_entry_notifier_info info = { + .info.extack = extack, + .rt = rt, + }; + + rt->fib6_table->fib_seq++; + return call_fib6_notifiers(net, event_type, &info.info); +} + +int call_fib6_multipath_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct fib6_info *rt, + unsigned int nsiblings, + struct netlink_ext_ack *extack) +{ + struct fib6_entry_notifier_info info = { + .info.extack = extack, + .rt = rt, + .nsiblings = nsiblings, + }; + + rt->fib6_table->fib_seq++; + return call_fib6_notifiers(net, event_type, &info.info); +} + +int call_fib6_entry_notifiers_replace(struct net *net, struct fib6_info *rt) +{ + struct fib6_entry_notifier_info info = { + .rt = rt, + .nsiblings = rt->fib6_nsiblings, + }; + + rt->fib6_table->fib_seq++; + return call_fib6_notifiers(net, FIB_EVENT_ENTRY_REPLACE, &info.info); +} + +struct fib6_dump_arg { + struct net *net; + struct notifier_block *nb; + struct netlink_ext_ack *extack; +}; + +static int fib6_rt_dump(struct fib6_info *rt, struct fib6_dump_arg *arg) +{ + enum fib_event_type fib_event = FIB_EVENT_ENTRY_REPLACE; + int err; + + if (!rt || rt == arg->net->ipv6.fib6_null_entry) + return 0; + + if (rt->fib6_nsiblings) + err = call_fib6_multipath_entry_notifier(arg->nb, fib_event, + rt, + rt->fib6_nsiblings, + arg->extack); + else + err = call_fib6_entry_notifier(arg->nb, fib_event, rt, + arg->extack); + + return err; +} + +static int fib6_node_dump(struct fib6_walker *w) +{ + int err; + + err = fib6_rt_dump(w->leaf, w->args); + w->leaf = NULL; + return err; +} + +static int fib6_table_dump(struct net *net, struct fib6_table *tb, + struct fib6_walker *w) +{ + int err; + + w->root = &tb->tb6_root; + spin_lock_bh(&tb->tb6_lock); + err = fib6_walk(net, w); + spin_unlock_bh(&tb->tb6_lock); + return err; +} + +/* Called with rcu_read_lock() */ +int fib6_tables_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + struct fib6_dump_arg arg; + struct fib6_walker *w; + unsigned int h; + int err = 0; + + w = kzalloc(sizeof(*w), GFP_ATOMIC); + if (!w) + return -ENOMEM; + + w->func = fib6_node_dump; + arg.net = net; + arg.nb = nb; + arg.extack = extack; + w->args = &arg; + + for (h = 0; h < FIB6_TABLE_HASHSZ; h++) { + struct hlist_head *head = &net->ipv6.fib_table_hash[h]; + struct fib6_table *tb; + + hlist_for_each_entry_rcu(tb, head, tb6_hlist) { + err = fib6_table_dump(net, tb, w); + if (err) + goto out; + } + } + +out: + kfree(w); + + /* The tree traversal function should never return a positive value. */ + return err > 0 ? -EINVAL : err; +} + +static int fib6_dump_node(struct fib6_walker *w) +{ + int res; + struct fib6_info *rt; + + for_each_fib6_walker_rt(w) { + res = rt6_dump_route(rt, w->args, w->skip_in_node); + if (res >= 0) { + /* Frame is full, suspend walking */ + w->leaf = rt; + + /* We'll restart from this node, so if some routes were + * already dumped, skip them next time. + */ + w->skip_in_node += res; + + return 1; + } + w->skip_in_node = 0; + + /* Multipath routes are dumped in one route with the + * RTA_MULTIPATH attribute. 
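
A toy version of the suspend-and-resume bookkeeping in fib6_dump_node() above: when the netlink frame fills up, the walker remembers how many routes of the current node were already emitted and continues from there on the next pass. The fixed frame_slots budget stands in for the real skb space check, and skip_in_node is advanced one item at a time rather than by the callback's return value:

    #include <stdio.h>

    struct dump_state {
            int skip_in_node;       /* items of this node already dumped */
    };

    /* Returns 1 if the frame filled up (come back later), 0 when done. */
    static int dump_node(const int *items, int n_items,
                         struct dump_state *st, int frame_slots)
    {
            int used = 0;

            for (int i = st->skip_in_node; i < n_items; i++) {
                    if (used == frame_slots) {
                            st->skip_in_node = i;   /* resume point */
                            return 1;
                    }
                    printf("item %d\n", items[i]);
                    used++;
            }
            st->skip_in_node = 0;
            return 0;
    }

    int main(void)
    {
            const int routes[] = { 10, 20, 30, 40, 50 };
            struct dump_state st = { 0 };
            int round = 0;

            while (dump_node(routes, 5, &st, 2))    /* 2 slots per "frame" */
                    printf("-- frame %d full, suspending --\n", round++);
            return 0;
    }
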
Jump 'rt' to point to the + * last sibling of this route (no need to dump the + * sibling routes again) + */ + if (rt->fib6_nsiblings) + rt = list_last_entry(&rt->fib6_siblings, + struct fib6_info, + fib6_siblings); + } + w->leaf = NULL; + return 0; +} + +static void fib6_dump_end(struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct fib6_walker *w = (void *)cb->args[2]; + + if (w) { + if (cb->args[4]) { + cb->args[4] = 0; + fib6_walker_unlink(net, w); + } + cb->args[2] = 0; + kfree(w); + } + cb->done = (void *)cb->args[3]; + cb->args[1] = 3; +} + +static int fib6_dump_done(struct netlink_callback *cb) +{ + fib6_dump_end(cb); + return cb->done ? cb->done(cb) : 0; +} + +static int fib6_dump_table(struct fib6_table *table, struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct fib6_walker *w; + int res; + + w = (void *)cb->args[2]; + w->root = &table->tb6_root; + + if (cb->args[4] == 0) { + w->count = 0; + w->skip = 0; + w->skip_in_node = 0; + + spin_lock_bh(&table->tb6_lock); + res = fib6_walk(net, w); + spin_unlock_bh(&table->tb6_lock); + if (res > 0) { + cb->args[4] = 1; + cb->args[5] = READ_ONCE(w->root->fn_sernum); + } + } else { + int sernum = READ_ONCE(w->root->fn_sernum); + if (cb->args[5] != sernum) { + /* Begin at the root if the tree changed */ + cb->args[5] = sernum; + w->state = FWS_INIT; + w->node = w->root; + w->skip = w->count; + w->skip_in_node = 0; + } else + w->skip = 0; + + spin_lock_bh(&table->tb6_lock); + res = fib6_walk_continue(w); + spin_unlock_bh(&table->tb6_lock); + if (res <= 0) { + fib6_walker_unlink(net, w); + cb->args[4] = 0; + } + } + + return res; +} + +static int inet6_dump_fib(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rt6_rtnl_dump_arg arg = { .filter.dump_exceptions = true, + .filter.dump_routes = true }; + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + unsigned int h, s_h; + unsigned int e = 0, s_e; + struct fib6_walker *w; + struct fib6_table *tb; + struct hlist_head *head; + int res = 0; + + if (cb->strict_check) { + int err; + + err = ip_valid_fib_dump_req(net, nlh, &arg.filter, cb); + if (err < 0) + return err; + } else if (nlmsg_len(nlh) >= sizeof(struct rtmsg)) { + struct rtmsg *rtm = nlmsg_data(nlh); + + if (rtm->rtm_flags & RTM_F_PREFIX) + arg.filter.flags = RTM_F_PREFIX; + } + + w = (void *)cb->args[2]; + if (!w) { + /* New dump: + * + * 1. hook callback destructor. + */ + cb->args[3] = (long)cb->done; + cb->done = fib6_dump_done; + + /* + * 2. allocate and initialize walker. 
+ */ + w = kzalloc(sizeof(*w), GFP_ATOMIC); + if (!w) + return -ENOMEM; + w->func = fib6_dump_node; + cb->args[2] = (long)w; + } + + arg.skb = skb; + arg.cb = cb; + arg.net = net; + w->args = &arg; + + if (arg.filter.table_id) { + tb = fib6_get_table(net, arg.filter.table_id); + if (!tb) { + if (rtnl_msg_family(cb->nlh) != PF_INET6) + goto out; + + NL_SET_ERR_MSG_MOD(cb->extack, "FIB table does not exist"); + return -ENOENT; + } + + if (!cb->args[0]) { + res = fib6_dump_table(tb, skb, cb); + if (!res) + cb->args[0] = 1; + } + goto out; + } + + s_h = cb->args[0]; + s_e = cb->args[1]; + + rcu_read_lock(); + for (h = s_h; h < FIB6_TABLE_HASHSZ; h++, s_e = 0) { + e = 0; + head = &net->ipv6.fib_table_hash[h]; + hlist_for_each_entry_rcu(tb, head, tb6_hlist) { + if (e < s_e) + goto next; + res = fib6_dump_table(tb, skb, cb); + if (res != 0) + goto out_unlock; +next: + e++; + } + } +out_unlock: + rcu_read_unlock(); + cb->args[1] = e; + cb->args[0] = h; +out: + res = res < 0 ? res : skb->len; + if (res <= 0) + fib6_dump_end(cb); + return res; +} + +void fib6_metric_set(struct fib6_info *f6i, int metric, u32 val) +{ + if (!f6i) + return; + + if (f6i->fib6_metrics == &dst_default_metrics) { + struct dst_metrics *p = kzalloc(sizeof(*p), GFP_ATOMIC); + + if (!p) + return; + + refcount_set(&p->refcnt, 1); + f6i->fib6_metrics = p; + } + + f6i->fib6_metrics->metrics[metric - 1] = val; +} + +/* + * Routing Table + * + * return the appropriate node for a routing tree "add" operation + * by either creating and inserting or by returning an existing + * node. + */ + +static struct fib6_node *fib6_add_1(struct net *net, + struct fib6_table *table, + struct fib6_node *root, + struct in6_addr *addr, int plen, + int offset, int allow_create, + int replace_required, + struct netlink_ext_ack *extack) +{ + struct fib6_node *fn, *in, *ln; + struct fib6_node *pn = NULL; + struct rt6key *key; + int bit; + __be32 dir = 0; + + RT6_TRACE("fib6_add_1\n"); + + /* insert node in tree */ + + fn = root; + + do { + struct fib6_info *leaf = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&table->tb6_lock)); + key = (struct rt6key *)((u8 *)leaf + offset); + + /* + * Prefix match + */ + if (plen < fn->fn_bit || + !ipv6_prefix_equal(&key->addr, addr, fn->fn_bit)) { + if (!allow_create) { + if (replace_required) { + NL_SET_ERR_MSG(extack, + "Can not replace route - no match found"); + pr_warn("Can't replace route, no match found\n"); + return ERR_PTR(-ENOENT); + } + pr_warn("NLM_F_CREATE should be set when creating new route\n"); + } + goto insert_above; + } + + /* + * Exact match ? + */ + + if (plen == fn->fn_bit) { + /* clean up an intermediate node */ + if (!(fn->fn_flags & RTN_RTINFO)) { + RCU_INIT_POINTER(fn->leaf, NULL); + fib6_info_release(leaf); + /* remove null_entry in the root node */ + } else if (fn->fn_flags & RTN_TL_ROOT && + rcu_access_pointer(fn->leaf) == + net->ipv6.fib6_null_entry) { + RCU_INIT_POINTER(fn->leaf, NULL); + } + + return fn; + } + + /* + * We have more bits to go + */ + + /* Try to walk down on tree. */ + dir = addr_bit_set(addr, fn->fn_bit); + pn = fn; + fn = dir ? + rcu_dereference_protected(fn->right, + lockdep_is_held(&table->tb6_lock)) : + rcu_dereference_protected(fn->left, + lockdep_is_held(&table->tb6_lock)); + } while (fn); + + if (!allow_create) { + /* We should not create new node because + * NLM_F_REPLACE was specified without NLM_F_CREATE + * I assume it is safe to require NLM_F_CREATE when + * REPLACE flag is used! 
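
A user-space sketch of the prefix comparison that drives the descent near the top of fib6_add_1() (the ipv6_prefix_equal() check): compare whole bytes first, then mask the remaining bits. prefix_equal() here is illustrative, not the kernel helper:

    #include <stdio.h>
    #include <string.h>
    #include <stdint.h>
    #include <arpa/inet.h>
    #include <netinet/in.h>

    /* Do the first 'plen' bits of two IPv6 addresses match? */
    static int prefix_equal(const struct in6_addr *a, const struct in6_addr *b,
                            int plen)
    {
            int full_bytes = plen / 8;
            int rem_bits = plen % 8;

            if (memcmp(a->s6_addr, b->s6_addr, full_bytes))
                    return 0;
            if (rem_bits) {
                    uint8_t mask = (uint8_t)(0xff << (8 - rem_bits));

                    if ((a->s6_addr[full_bytes] ^ b->s6_addr[full_bytes]) & mask)
                            return 0;
            }
            return 1;
    }

    int main(void)
    {
            struct in6_addr a, b;

            inet_pton(AF_INET6, "2001:db8:1::1", &a);
            inet_pton(AF_INET6, "2001:db8:2::1", &b);

            printf("/32 equal: %d\n", prefix_equal(&a, &b, 32));    /* 1 */
            printf("/48 equal: %d\n", prefix_equal(&a, &b, 48));    /* 0 */
            return 0;
    }
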
Later we may want to remove the + * check for replace_required, because according + * to netlink specification, NLM_F_CREATE + * MUST be specified if new route is created. + * That would keep IPv6 consistent with IPv4 + */ + if (replace_required) { + NL_SET_ERR_MSG(extack, + "Can not replace route - no match found"); + pr_warn("Can't replace route, no match found\n"); + return ERR_PTR(-ENOENT); + } + pr_warn("NLM_F_CREATE should be set when creating new route\n"); + } + /* + * We walked to the bottom of tree. + * Create new leaf node without children. + */ + + ln = node_alloc(net); + + if (!ln) + return ERR_PTR(-ENOMEM); + ln->fn_bit = plen; + RCU_INIT_POINTER(ln->parent, pn); + + if (dir) + rcu_assign_pointer(pn->right, ln); + else + rcu_assign_pointer(pn->left, ln); + + return ln; + + +insert_above: + /* + * split since we don't have a common prefix anymore or + * we have a less significant route. + * we've to insert an intermediate node on the list + * this new node will point to the one we need to create + * and the current + */ + + pn = rcu_dereference_protected(fn->parent, + lockdep_is_held(&table->tb6_lock)); + + /* find 1st bit in difference between the 2 addrs. + + See comment in __ipv6_addr_diff: bit may be an invalid value, + but if it is >= plen, the value is ignored in any case. + */ + + bit = __ipv6_addr_diff(addr, &key->addr, sizeof(*addr)); + + /* + * (intermediate)[in] + * / \ + * (new leaf node)[ln] (old node)[fn] + */ + if (plen > bit) { + in = node_alloc(net); + ln = node_alloc(net); + + if (!in || !ln) { + if (in) + node_free_immediate(net, in); + if (ln) + node_free_immediate(net, ln); + return ERR_PTR(-ENOMEM); + } + + /* + * new intermediate node. + * RTN_RTINFO will + * be off since that an address that chooses one of + * the branches would not match less specific routes + * in the other branch + */ + + in->fn_bit = bit; + + RCU_INIT_POINTER(in->parent, pn); + in->leaf = fn->leaf; + fib6_info_hold(rcu_dereference_protected(in->leaf, + lockdep_is_held(&table->tb6_lock))); + + /* update parent pointer */ + if (dir) + rcu_assign_pointer(pn->right, in); + else + rcu_assign_pointer(pn->left, in); + + ln->fn_bit = plen; + + RCU_INIT_POINTER(ln->parent, in); + rcu_assign_pointer(fn->parent, in); + + if (addr_bit_set(addr, bit)) { + rcu_assign_pointer(in->right, ln); + rcu_assign_pointer(in->left, fn); + } else { + rcu_assign_pointer(in->left, ln); + rcu_assign_pointer(in->right, fn); + } + } else { /* plen <= bit */ + + /* + * (new leaf node)[ln] + * / \ + * (old node)[fn] NULL + */ + + ln = node_alloc(net); + + if (!ln) + return ERR_PTR(-ENOMEM); + + ln->fn_bit = plen; + + RCU_INIT_POINTER(ln->parent, pn); + + if (addr_bit_set(&key->addr, plen)) + RCU_INIT_POINTER(ln->right, fn); + else + RCU_INIT_POINTER(ln->left, fn); + + rcu_assign_pointer(fn->parent, ln); + + if (dir) + rcu_assign_pointer(pn->right, ln); + else + rcu_assign_pointer(pn->left, ln); + } + return ln; +} + +static void __fib6_drop_pcpu_from(struct fib6_nh *fib6_nh, + const struct fib6_info *match, + const struct fib6_table *table) +{ + int cpu; + + if (!fib6_nh->rt6i_pcpu) + return; + + /* release the reference to this fib entry from + * all of its cached pcpu routes + */ + for_each_possible_cpu(cpu) { + struct rt6_info **ppcpu_rt; + struct rt6_info *pcpu_rt; + + ppcpu_rt = per_cpu_ptr(fib6_nh->rt6i_pcpu, cpu); + pcpu_rt = *ppcpu_rt; + + /* only dropping the 'from' reference if the cached route + * is using 'match'. 
The cached pcpu_rt->from only changes + * from a fib6_info to NULL (ip6_dst_destroy); it can never + * change from one fib6_info reference to another + */ + if (pcpu_rt && rcu_access_pointer(pcpu_rt->from) == match) { + struct fib6_info *from; + + from = xchg((__force struct fib6_info **)&pcpu_rt->from, NULL); + fib6_info_release(from); + } + } +} + +struct fib6_nh_pcpu_arg { + struct fib6_info *from; + const struct fib6_table *table; +}; + +static int fib6_nh_drop_pcpu_from(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_pcpu_arg *arg = _arg; + + __fib6_drop_pcpu_from(nh, arg->from, arg->table); + return 0; +} + +static void fib6_drop_pcpu_from(struct fib6_info *f6i, + const struct fib6_table *table) +{ + /* Make sure rt6_make_pcpu_route() wont add other percpu routes + * while we are cleaning them here. + */ + f6i->fib6_destroying = 1; + mb(); /* paired with the cmpxchg() in rt6_make_pcpu_route() */ + + if (f6i->nh) { + struct fib6_nh_pcpu_arg arg = { + .from = f6i, + .table = table + }; + + nexthop_for_each_fib6_nh(f6i->nh, fib6_nh_drop_pcpu_from, + &arg); + } else { + struct fib6_nh *fib6_nh; + + fib6_nh = f6i->fib6_nh; + __fib6_drop_pcpu_from(fib6_nh, f6i, table); + } +} + +static void fib6_purge_rt(struct fib6_info *rt, struct fib6_node *fn, + struct net *net) +{ + struct fib6_table *table = rt->fib6_table; + + /* Flush all cached dst in exception table */ + rt6_flush_exceptions(rt); + fib6_drop_pcpu_from(rt, table); + + if (rt->nh && !list_empty(&rt->nh_list)) + list_del_init(&rt->nh_list); + + if (refcount_read(&rt->fib6_ref) != 1) { + /* This route is used as dummy address holder in some split + * nodes. It is not leaked, but it still holds other resources, + * which must be released in time. So, scan ascendant nodes + * and replace dummy references to this route with references + * to still alive ones. + */ + while (fn) { + struct fib6_info *leaf = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&table->tb6_lock)); + struct fib6_info *new_leaf; + if (!(fn->fn_flags & RTN_RTINFO) && leaf == rt) { + new_leaf = fib6_find_prefix(net, table, fn); + fib6_info_hold(new_leaf); + + rcu_assign_pointer(fn->leaf, new_leaf); + fib6_info_release(rt); + } + fn = rcu_dereference_protected(fn->parent, + lockdep_is_held(&table->tb6_lock)); + } + } +} + +/* + * Insert routing information in a node. 
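+ * Entries within a node are kept on a list sorted by metric; an exact duplicate (same nexthop and metric) returns -EEXIST, while NLM_F_REPLACE swaps out a matching entry instead of appending.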
+ */ + +static int fib6_add_rt2node(struct fib6_node *fn, struct fib6_info *rt, + struct nl_info *info, + struct netlink_ext_ack *extack) +{ + struct fib6_info *leaf = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + struct fib6_info *iter = NULL; + struct fib6_info __rcu **ins; + struct fib6_info __rcu **fallback_ins = NULL; + int replace = (info->nlh && + (info->nlh->nlmsg_flags & NLM_F_REPLACE)); + int add = (!info->nlh || + (info->nlh->nlmsg_flags & NLM_F_CREATE)); + int found = 0; + bool rt_can_ecmp = rt6_qualify_for_ecmp(rt); + bool notify_sibling_rt = false; + u16 nlflags = NLM_F_EXCL; + int err; + + if (info->nlh && (info->nlh->nlmsg_flags & NLM_F_APPEND)) + nlflags |= NLM_F_APPEND; + + ins = &fn->leaf; + + for (iter = leaf; iter; + iter = rcu_dereference_protected(iter->fib6_next, + lockdep_is_held(&rt->fib6_table->tb6_lock))) { + /* + * Search for duplicates + */ + + if (iter->fib6_metric == rt->fib6_metric) { + /* + * Same priority level + */ + if (info->nlh && + (info->nlh->nlmsg_flags & NLM_F_EXCL)) + return -EEXIST; + + nlflags &= ~NLM_F_EXCL; + if (replace) { + if (rt_can_ecmp == rt6_qualify_for_ecmp(iter)) { + found++; + break; + } + fallback_ins = fallback_ins ?: ins; + goto next_iter; + } + + if (rt6_duplicate_nexthop(iter, rt)) { + if (rt->fib6_nsiblings) + rt->fib6_nsiblings = 0; + if (!(iter->fib6_flags & RTF_EXPIRES)) + return -EEXIST; + if (!(rt->fib6_flags & RTF_EXPIRES)) + fib6_clean_expires(iter); + else + fib6_set_expires(iter, rt->expires); + + if (rt->fib6_pmtu) + fib6_metric_set(iter, RTAX_MTU, + rt->fib6_pmtu); + return -EEXIST; + } + /* If we have the same destination and the same metric, + * but not the same gateway, then the route we try to + * add is sibling to this route, increment our counter + * of siblings, and later we will add our route to the + * list. + * Only static routes (which don't have flag + * RTF_EXPIRES) are used for ECMPv6. + * + * To avoid long list, we only had siblings if the + * route have a gateway. + */ + if (rt_can_ecmp && + rt6_qualify_for_ecmp(iter)) + rt->fib6_nsiblings++; + } + + if (iter->fib6_metric > rt->fib6_metric) + break; + +next_iter: + ins = &iter->fib6_next; + } + + if (fallback_ins && !found) { + /* No matching route with same ecmp-able-ness found, replace + * first matching route + */ + ins = fallback_ins; + iter = rcu_dereference_protected(*ins, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + found++; + } + + /* Reset round-robin state, if necessary */ + if (ins == &fn->leaf) + fn->rr_ptr = NULL; + + /* Link this route to others same route. */ + if (rt->fib6_nsiblings) { + unsigned int fib6_nsiblings; + struct fib6_info *sibling, *temp_sibling; + + /* Find the first route that have the same metric */ + sibling = leaf; + notify_sibling_rt = true; + while (sibling) { + if (sibling->fib6_metric == rt->fib6_metric && + rt6_qualify_for_ecmp(sibling)) { + list_add_tail(&rt->fib6_siblings, + &sibling->fib6_siblings); + break; + } + sibling = rcu_dereference_protected(sibling->fib6_next, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + notify_sibling_rt = false; + } + /* For each sibling in the list, increment the counter of + * siblings. BUG() if counters does not match, list of siblings + * is broken! 
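+ * (every sibling on the list must end up with the same fib6_nsiblings value as the route being added)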
+ */ + fib6_nsiblings = 0; + list_for_each_entry_safe(sibling, temp_sibling, + &rt->fib6_siblings, fib6_siblings) { + sibling->fib6_nsiblings++; + BUG_ON(sibling->fib6_nsiblings != rt->fib6_nsiblings); + fib6_nsiblings++; + } + BUG_ON(fib6_nsiblings != rt->fib6_nsiblings); + rt6_multipath_rebalance(temp_sibling); + } + + /* + * insert node + */ + if (!replace) { + if (!add) + pr_warn("NLM_F_CREATE should be set when creating new route\n"); + +add: + nlflags |= NLM_F_CREATE; + + /* The route should only be notified if it is the first + * route in the node or if it is added as a sibling + * route to the first route in the node. + */ + if (!info->skip_notify_kernel && + (notify_sibling_rt || ins == &fn->leaf)) { + enum fib_event_type fib_event; + + if (notify_sibling_rt) + fib_event = FIB_EVENT_ENTRY_APPEND; + else + fib_event = FIB_EVENT_ENTRY_REPLACE; + err = call_fib6_entry_notifiers(info->nl_net, + fib_event, rt, + extack); + if (err) { + struct fib6_info *sibling, *next_sibling; + + /* If the route has siblings, then it first + * needs to be unlinked from them. + */ + if (!rt->fib6_nsiblings) + return err; + + list_for_each_entry_safe(sibling, next_sibling, + &rt->fib6_siblings, + fib6_siblings) + sibling->fib6_nsiblings--; + rt->fib6_nsiblings = 0; + list_del_init(&rt->fib6_siblings); + rt6_multipath_rebalance(next_sibling); + return err; + } + } + + rcu_assign_pointer(rt->fib6_next, iter); + fib6_info_hold(rt); + rcu_assign_pointer(rt->fib6_node, fn); + rcu_assign_pointer(*ins, rt); + if (!info->skip_notify) + inet6_rt_notify(RTM_NEWROUTE, rt, info, nlflags); + info->nl_net->ipv6.rt6_stats->fib_rt_entries++; + + if (!(fn->fn_flags & RTN_RTINFO)) { + info->nl_net->ipv6.rt6_stats->fib_route_nodes++; + fn->fn_flags |= RTN_RTINFO; + } + + } else { + int nsiblings; + + if (!found) { + if (add) + goto add; + pr_warn("NLM_F_REPLACE set, but no existing node found!\n"); + return -ENOENT; + } + + if (!info->skip_notify_kernel && ins == &fn->leaf) { + err = call_fib6_entry_notifiers(info->nl_net, + FIB_EVENT_ENTRY_REPLACE, + rt, extack); + if (err) + return err; + } + + fib6_info_hold(rt); + rcu_assign_pointer(rt->fib6_node, fn); + rt->fib6_next = iter->fib6_next; + rcu_assign_pointer(*ins, rt); + if (!info->skip_notify) + inet6_rt_notify(RTM_NEWROUTE, rt, info, NLM_F_REPLACE); + if (!(fn->fn_flags & RTN_RTINFO)) { + info->nl_net->ipv6.rt6_stats->fib_route_nodes++; + fn->fn_flags |= RTN_RTINFO; + } + nsiblings = iter->fib6_nsiblings; + iter->fib6_node = NULL; + fib6_purge_rt(iter, fn, info->nl_net); + if (rcu_access_pointer(fn->rr_ptr) == iter) + fn->rr_ptr = NULL; + fib6_info_release(iter); + + if (nsiblings) { + /* Replacing an ECMP route, remove all siblings */ + ins = &rt->fib6_next; + iter = rcu_dereference_protected(*ins, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + while (iter) { + if (iter->fib6_metric > rt->fib6_metric) + break; + if (rt6_qualify_for_ecmp(iter)) { + *ins = iter->fib6_next; + iter->fib6_node = NULL; + fib6_purge_rt(iter, fn, info->nl_net); + if (rcu_access_pointer(fn->rr_ptr) == iter) + fn->rr_ptr = NULL; + fib6_info_release(iter); + nsiblings--; + info->nl_net->ipv6.rt6_stats->fib_rt_entries--; + } else { + ins = &iter->fib6_next; + } + iter = rcu_dereference_protected(*ins, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + } + WARN_ON(nsiblings != 0); + } + } + + return 0; +} + +static void fib6_start_gc(struct net *net, struct fib6_info *rt) +{ + if (!timer_pending(&net->ipv6.ip6_fib_timer) && + (rt->fib6_flags & RTF_EXPIRES)) + mod_timer(&net->ipv6.ip6_fib_timer, 
+ jiffies + net->ipv6.sysctl.ip6_rt_gc_interval); +} + +void fib6_force_start_gc(struct net *net) +{ + if (!timer_pending(&net->ipv6.ip6_fib_timer)) + mod_timer(&net->ipv6.ip6_fib_timer, + jiffies + net->ipv6.sysctl.ip6_rt_gc_interval); +} + +static void __fib6_update_sernum_upto_root(struct fib6_info *rt, + int sernum) +{ + struct fib6_node *fn = rcu_dereference_protected(rt->fib6_node, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + + /* paired with smp_rmb() in fib6_get_cookie_safe() */ + smp_wmb(); + while (fn) { + WRITE_ONCE(fn->fn_sernum, sernum); + fn = rcu_dereference_protected(fn->parent, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + } +} + +void fib6_update_sernum_upto_root(struct net *net, struct fib6_info *rt) +{ + __fib6_update_sernum_upto_root(rt, fib6_new_sernum(net)); +} + +/* allow ipv4 to update sernum via ipv6_stub */ +void fib6_update_sernum_stub(struct net *net, struct fib6_info *f6i) +{ + spin_lock_bh(&f6i->fib6_table->tb6_lock); + fib6_update_sernum_upto_root(net, f6i); + spin_unlock_bh(&f6i->fib6_table->tb6_lock); +} + +/* + * Add routing information to the routing tree. + * / + * with source addr info in sub-trees + * Need to own table->tb6_lock + */ + +int fib6_add(struct fib6_node *root, struct fib6_info *rt, + struct nl_info *info, struct netlink_ext_ack *extack) +{ + struct fib6_table *table = rt->fib6_table; + struct fib6_node *fn, *pn = NULL; + int err = -ENOMEM; + int allow_create = 1; + int replace_required = 0; + + if (info->nlh) { + if (!(info->nlh->nlmsg_flags & NLM_F_CREATE)) + allow_create = 0; + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) + replace_required = 1; + } + if (!allow_create && !replace_required) + pr_warn("RTM_NEWROUTE with no NLM_F_CREATE or NLM_F_REPLACE\n"); + + fn = fib6_add_1(info->nl_net, table, root, + &rt->fib6_dst.addr, rt->fib6_dst.plen, + offsetof(struct fib6_info, fib6_dst), allow_create, + replace_required, extack); + if (IS_ERR(fn)) { + err = PTR_ERR(fn); + fn = NULL; + goto out; + } + + pn = fn; + +#ifdef CONFIG_IPV6_SUBTREES + if (rt->fib6_src.plen) { + struct fib6_node *sn; + + if (!rcu_access_pointer(fn->subtree)) { + struct fib6_node *sfn; + + /* + * Create subtree. + * + * fn[main tree] + * | + * sfn[subtree root] + * \ + * sn[new leaf node] + */ + + /* Create subtree root node */ + sfn = node_alloc(info->nl_net); + if (!sfn) + goto failure; + + fib6_info_hold(info->nl_net->ipv6.fib6_null_entry); + rcu_assign_pointer(sfn->leaf, + info->nl_net->ipv6.fib6_null_entry); + sfn->fn_flags = RTN_ROOT; + + /* Now add the first leaf node to new subtree */ + + sn = fib6_add_1(info->nl_net, table, sfn, + &rt->fib6_src.addr, rt->fib6_src.plen, + offsetof(struct fib6_info, fib6_src), + allow_create, replace_required, extack); + + if (IS_ERR(sn)) { + /* If it is failed, discard just allocated + root, and then (in failure) stale node + in main tree. 
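+ The stale leaf in the main tree is dealt with by the 'failure' path further down.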
+ */ + node_free_immediate(info->nl_net, sfn); + err = PTR_ERR(sn); + goto failure; + } + + /* Now link new subtree to main tree */ + rcu_assign_pointer(sfn->parent, fn); + rcu_assign_pointer(fn->subtree, sfn); + } else { + sn = fib6_add_1(info->nl_net, table, FIB6_SUBTREE(fn), + &rt->fib6_src.addr, rt->fib6_src.plen, + offsetof(struct fib6_info, fib6_src), + allow_create, replace_required, extack); + + if (IS_ERR(sn)) { + err = PTR_ERR(sn); + goto failure; + } + } + + if (!rcu_access_pointer(fn->leaf)) { + if (fn->fn_flags & RTN_TL_ROOT) { + /* put back null_entry for root node */ + rcu_assign_pointer(fn->leaf, + info->nl_net->ipv6.fib6_null_entry); + } else { + fib6_info_hold(rt); + rcu_assign_pointer(fn->leaf, rt); + } + } + fn = sn; + } +#endif + + err = fib6_add_rt2node(fn, rt, info, extack); + if (!err) { + if (rt->nh) + list_add(&rt->nh_list, &rt->nh->f6i_list); + __fib6_update_sernum_upto_root(rt, fib6_new_sernum(info->nl_net)); + fib6_start_gc(info->nl_net, rt); + } + +out: + if (err) { +#ifdef CONFIG_IPV6_SUBTREES + /* + * If fib6_add_1 has cleared the old leaf pointer in the + * super-tree leaf node we have to find a new one for it. + */ + if (pn != fn) { + struct fib6_info *pn_leaf = + rcu_dereference_protected(pn->leaf, + lockdep_is_held(&table->tb6_lock)); + if (pn_leaf == rt) { + pn_leaf = NULL; + RCU_INIT_POINTER(pn->leaf, NULL); + fib6_info_release(rt); + } + if (!pn_leaf && !(pn->fn_flags & RTN_RTINFO)) { + pn_leaf = fib6_find_prefix(info->nl_net, table, + pn); + if (!pn_leaf) + pn_leaf = + info->nl_net->ipv6.fib6_null_entry; + fib6_info_hold(pn_leaf); + rcu_assign_pointer(pn->leaf, pn_leaf); + } + } +#endif + goto failure; + } else if (fib6_requires_src(rt)) { + fib6_routes_require_src_inc(info->nl_net); + } + return err; + +failure: + /* fn->leaf could be NULL and fib6_repair_tree() needs to be called if: + * 1. fn is an intermediate node and we failed to add the new + * route to it in both subtree creation failure and fib6_add_rt2node() + * failure case. + * 2. fn is the root node in the table and we fail to add the first + * default route to it. + */ + if (fn && + (!(fn->fn_flags & (RTN_RTINFO|RTN_ROOT)) || + (fn->fn_flags & RTN_TL_ROOT && + !rcu_access_pointer(fn->leaf)))) + fib6_repair_tree(info->nl_net, table, fn); + return err; +} + +/* + * Routing tree lookup + * + */ + +struct lookup_args { + int offset; /* key offset on fib6_info */ + const struct in6_addr *addr; /* search key */ +}; + +static struct fib6_node *fib6_node_lookup_1(struct fib6_node *root, + struct lookup_args *args) +{ + struct fib6_node *fn; + __be32 dir; + + if (unlikely(args->offset == 0)) + return NULL; + + /* + * Descend on a tree + */ + + fn = root; + + for (;;) { + struct fib6_node *next; + + dir = addr_bit_set(args->addr, fn->fn_bit); + + next = dir ? 
rcu_dereference(fn->right) : + rcu_dereference(fn->left); + + if (next) { + fn = next; + continue; + } + break; + } + + while (fn) { + struct fib6_node *subtree = FIB6_SUBTREE(fn); + + if (subtree || fn->fn_flags & RTN_RTINFO) { + struct fib6_info *leaf = rcu_dereference(fn->leaf); + struct rt6key *key; + + if (!leaf) + goto backtrack; + + key = (struct rt6key *) ((u8 *)leaf + args->offset); + + if (ipv6_prefix_equal(&key->addr, args->addr, key->plen)) { +#ifdef CONFIG_IPV6_SUBTREES + if (subtree) { + struct fib6_node *sfn; + sfn = fib6_node_lookup_1(subtree, + args + 1); + if (!sfn) + goto backtrack; + fn = sfn; + } +#endif + if (fn->fn_flags & RTN_RTINFO) + return fn; + } + } +backtrack: + if (fn->fn_flags & RTN_ROOT) + break; + + fn = rcu_dereference(fn->parent); + } + + return NULL; +} + +/* called with rcu_read_lock() held + */ +struct fib6_node *fib6_node_lookup(struct fib6_node *root, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + struct fib6_node *fn; + struct lookup_args args[] = { + { + .offset = offsetof(struct fib6_info, fib6_dst), + .addr = daddr, + }, +#ifdef CONFIG_IPV6_SUBTREES + { + .offset = offsetof(struct fib6_info, fib6_src), + .addr = saddr, + }, +#endif + { + .offset = 0, /* sentinel */ + } + }; + + fn = fib6_node_lookup_1(root, daddr ? args : args + 1); + if (!fn || fn->fn_flags & RTN_TL_ROOT) + fn = root; + + return fn; +} + +/* + * Get node with specified destination prefix (and source prefix, + * if subtrees are used) + * exact_match == true means we try to find fn with exact match of + * the passed in prefix addr + * exact_match == false means we try to find fn with longest prefix + * match of the passed in prefix addr. This is useful for finding fn + * for cached route as it will be stored in the exception table under + * the node with longest prefix length. 
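+ * With exact_match == false the deepest RTN_RTINFO node visited on the way down is returned.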
+ */ + + +static struct fib6_node *fib6_locate_1(struct fib6_node *root, + const struct in6_addr *addr, + int plen, int offset, + bool exact_match) +{ + struct fib6_node *fn, *prev = NULL; + + for (fn = root; fn ; ) { + struct fib6_info *leaf = rcu_dereference(fn->leaf); + struct rt6key *key; + + /* This node is being deleted */ + if (!leaf) { + if (plen <= fn->fn_bit) + goto out; + else + goto next; + } + + key = (struct rt6key *)((u8 *)leaf + offset); + + /* + * Prefix match + */ + if (plen < fn->fn_bit || + !ipv6_prefix_equal(&key->addr, addr, fn->fn_bit)) + goto out; + + if (plen == fn->fn_bit) + return fn; + + if (fn->fn_flags & RTN_RTINFO) + prev = fn; + +next: + /* + * We have more bits to go + */ + if (addr_bit_set(addr, fn->fn_bit)) + fn = rcu_dereference(fn->right); + else + fn = rcu_dereference(fn->left); + } +out: + if (exact_match) + return NULL; + else + return prev; +} + +struct fib6_node *fib6_locate(struct fib6_node *root, + const struct in6_addr *daddr, int dst_len, + const struct in6_addr *saddr, int src_len, + bool exact_match) +{ + struct fib6_node *fn; + + fn = fib6_locate_1(root, daddr, dst_len, + offsetof(struct fib6_info, fib6_dst), + exact_match); + +#ifdef CONFIG_IPV6_SUBTREES + if (src_len) { + WARN_ON(saddr == NULL); + if (fn) { + struct fib6_node *subtree = FIB6_SUBTREE(fn); + + if (subtree) { + fn = fib6_locate_1(subtree, saddr, src_len, + offsetof(struct fib6_info, fib6_src), + exact_match); + } + } + } +#endif + + if (fn && fn->fn_flags & RTN_RTINFO) + return fn; + + return NULL; +} + + +/* + * Deletion + * + */ + +static struct fib6_info *fib6_find_prefix(struct net *net, + struct fib6_table *table, + struct fib6_node *fn) +{ + struct fib6_node *child_left, *child_right; + + if (fn->fn_flags & RTN_ROOT) + return net->ipv6.fib6_null_entry; + + while (fn) { + child_left = rcu_dereference_protected(fn->left, + lockdep_is_held(&table->tb6_lock)); + child_right = rcu_dereference_protected(fn->right, + lockdep_is_held(&table->tb6_lock)); + if (child_left) + return rcu_dereference_protected(child_left->leaf, + lockdep_is_held(&table->tb6_lock)); + if (child_right) + return rcu_dereference_protected(child_right->leaf, + lockdep_is_held(&table->tb6_lock)); + + fn = FIB6_SUBTREE(fn); + } + return NULL; +} + +/* + * Called to trim the tree of intermediate nodes when possible. "fn" + * is the node we want to try and remove. + * Need to own table->tb6_lock + */ + +static struct fib6_node *fib6_repair_tree(struct net *net, + struct fib6_table *table, + struct fib6_node *fn) +{ + int children; + int nstate; + struct fib6_node *child; + struct fib6_walker *w; + int iter = 0; + + /* Set fn->leaf to null_entry for root node. 
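A table root is never removed; it simply gets its null_entry leaf back.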
*/ + if (fn->fn_flags & RTN_TL_ROOT) { + rcu_assign_pointer(fn->leaf, net->ipv6.fib6_null_entry); + return fn; + } + + for (;;) { + struct fib6_node *fn_r = rcu_dereference_protected(fn->right, + lockdep_is_held(&table->tb6_lock)); + struct fib6_node *fn_l = rcu_dereference_protected(fn->left, + lockdep_is_held(&table->tb6_lock)); + struct fib6_node *pn = rcu_dereference_protected(fn->parent, + lockdep_is_held(&table->tb6_lock)); + struct fib6_node *pn_r = rcu_dereference_protected(pn->right, + lockdep_is_held(&table->tb6_lock)); + struct fib6_node *pn_l = rcu_dereference_protected(pn->left, + lockdep_is_held(&table->tb6_lock)); + struct fib6_info *fn_leaf = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&table->tb6_lock)); + struct fib6_info *pn_leaf = rcu_dereference_protected(pn->leaf, + lockdep_is_held(&table->tb6_lock)); + struct fib6_info *new_fn_leaf; + + RT6_TRACE("fixing tree: plen=%d iter=%d\n", fn->fn_bit, iter); + iter++; + + WARN_ON(fn->fn_flags & RTN_RTINFO); + WARN_ON(fn->fn_flags & RTN_TL_ROOT); + WARN_ON(fn_leaf); + + children = 0; + child = NULL; + if (fn_r) { + child = fn_r; + children |= 1; + } + if (fn_l) { + child = fn_l; + children |= 2; + } + + if (children == 3 || FIB6_SUBTREE(fn) +#ifdef CONFIG_IPV6_SUBTREES + /* Subtree root (i.e. fn) may have one child */ + || (children && fn->fn_flags & RTN_ROOT) +#endif + ) { + new_fn_leaf = fib6_find_prefix(net, table, fn); +#if RT6_DEBUG >= 2 + if (!new_fn_leaf) { + WARN_ON(!new_fn_leaf); + new_fn_leaf = net->ipv6.fib6_null_entry; + } +#endif + fib6_info_hold(new_fn_leaf); + rcu_assign_pointer(fn->leaf, new_fn_leaf); + return pn; + } + +#ifdef CONFIG_IPV6_SUBTREES + if (FIB6_SUBTREE(pn) == fn) { + WARN_ON(!(fn->fn_flags & RTN_ROOT)); + RCU_INIT_POINTER(pn->subtree, NULL); + nstate = FWS_L; + } else { + WARN_ON(fn->fn_flags & RTN_ROOT); +#endif + if (pn_r == fn) + rcu_assign_pointer(pn->right, child); + else if (pn_l == fn) + rcu_assign_pointer(pn->left, child); +#if RT6_DEBUG >= 2 + else + WARN_ON(1); +#endif + if (child) + rcu_assign_pointer(child->parent, pn); + nstate = FWS_R; +#ifdef CONFIG_IPV6_SUBTREES + } +#endif + + read_lock(&net->ipv6.fib6_walker_lock); + FOR_WALKERS(net, w) { + if (!child) { + if (w->node == fn) { + RT6_TRACE("W %p adjusted by delnode 1, s=%d/%d\n", w, w->state, nstate); + w->node = pn; + w->state = nstate; + } + } else { + if (w->node == fn) { + w->node = child; + if (children&2) { + RT6_TRACE("W %p adjusted by delnode 2, s=%d\n", w, w->state); + w->state = w->state >= FWS_R ? FWS_U : FWS_INIT; + } else { + RT6_TRACE("W %p adjusted by delnode 2, s=%d\n", w, w->state); + w->state = w->state >= FWS_C ? FWS_U : FWS_INIT; + } + } + } + } + read_unlock(&net->ipv6.fib6_walker_lock); + + node_free(net, fn); + if (pn->fn_flags & RTN_RTINFO || FIB6_SUBTREE(pn)) + return pn; + + RCU_INIT_POINTER(pn->leaf, NULL); + fib6_info_release(pn_leaf); + fn = pn; + } +} + +static void fib6_del_route(struct fib6_table *table, struct fib6_node *fn, + struct fib6_info __rcu **rtp, struct nl_info *info) +{ + struct fib6_info *leaf, *replace_rt = NULL; + struct fib6_walker *w; + struct fib6_info *rt = rcu_dereference_protected(*rtp, + lockdep_is_held(&table->tb6_lock)); + struct net *net = info->nl_net; + bool notify_del = false; + + RT6_TRACE("fib6_del_route\n"); + + /* If the deleted route is the first in the node and it is not part of + * a multipath route, then we need to replace it with the next route + * in the node, if exists. 
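+ * If no such route exists, a delete notification is emitted instead of a replace.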
+ */ + leaf = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&table->tb6_lock)); + if (leaf == rt && !rt->fib6_nsiblings) { + if (rcu_access_pointer(rt->fib6_next)) + replace_rt = rcu_dereference_protected(rt->fib6_next, + lockdep_is_held(&table->tb6_lock)); + else + notify_del = true; + } + + /* Unlink it */ + *rtp = rt->fib6_next; + rt->fib6_node = NULL; + net->ipv6.rt6_stats->fib_rt_entries--; + net->ipv6.rt6_stats->fib_discarded_routes++; + + /* Reset round-robin state, if necessary */ + if (rcu_access_pointer(fn->rr_ptr) == rt) + fn->rr_ptr = NULL; + + /* Remove this entry from other siblings */ + if (rt->fib6_nsiblings) { + struct fib6_info *sibling, *next_sibling; + + /* The route is deleted from a multipath route. If this + * multipath route is the first route in the node, then we need + * to emit a delete notification. Otherwise, we need to skip + * the notification. + */ + if (rt->fib6_metric == leaf->fib6_metric && + rt6_qualify_for_ecmp(leaf)) + notify_del = true; + list_for_each_entry_safe(sibling, next_sibling, + &rt->fib6_siblings, fib6_siblings) + sibling->fib6_nsiblings--; + rt->fib6_nsiblings = 0; + list_del_init(&rt->fib6_siblings); + rt6_multipath_rebalance(next_sibling); + } + + /* Adjust walkers */ + read_lock(&net->ipv6.fib6_walker_lock); + FOR_WALKERS(net, w) { + if (w->state == FWS_C && w->leaf == rt) { + RT6_TRACE("walker %p adjusted by delroute\n", w); + w->leaf = rcu_dereference_protected(rt->fib6_next, + lockdep_is_held(&table->tb6_lock)); + if (!w->leaf) + w->state = FWS_U; + } + } + read_unlock(&net->ipv6.fib6_walker_lock); + + /* If it was last route, call fib6_repair_tree() to: + * 1. For root node, put back null_entry as how the table was created. + * 2. For other nodes, expunge its radix tree node. + */ + if (!rcu_access_pointer(fn->leaf)) { + if (!(fn->fn_flags & RTN_TL_ROOT)) { + fn->fn_flags &= ~RTN_RTINFO; + net->ipv6.rt6_stats->fib_route_nodes--; + } + fn = fib6_repair_tree(net, table, fn); + } + + fib6_purge_rt(rt, fn, net); + + if (!info->skip_notify_kernel) { + if (notify_del) + call_fib6_entry_notifiers(net, FIB_EVENT_ENTRY_DEL, + rt, NULL); + else if (replace_rt) + call_fib6_entry_notifiers_replace(net, replace_rt); + } + if (!info->skip_notify) + inet6_rt_notify(RTM_DELROUTE, rt, info, 0); + + fib6_info_release(rt); +} + +/* Need to own table->tb6_lock */ +int fib6_del(struct fib6_info *rt, struct nl_info *info) +{ + struct net *net = info->nl_net; + struct fib6_info __rcu **rtp; + struct fib6_info __rcu **rtp_next; + struct fib6_table *table; + struct fib6_node *fn; + + if (rt == net->ipv6.fib6_null_entry) + return -ENOENT; + + table = rt->fib6_table; + fn = rcu_dereference_protected(rt->fib6_node, + lockdep_is_held(&table->tb6_lock)); + if (!fn) + return -ENOENT; + + WARN_ON(!(fn->fn_flags & RTN_RTINFO)); + + /* + * Walk the leaf entries looking for ourself + */ + + for (rtp = &fn->leaf; *rtp; rtp = rtp_next) { + struct fib6_info *cur = rcu_dereference_protected(*rtp, + lockdep_is_held(&table->tb6_lock)); + if (rt == cur) { + if (fib6_requires_src(cur)) + fib6_routes_require_src_dec(info->nl_net); + fib6_del_route(table, fn, rtp, info); + return 0; + } + rtp_next = &cur->fib6_next; + } + return -ENOENT; +} + +/* + * Tree traversal function. + * + * Certainly, it is not interrupt safe. + * However, it is internally reenterable wrt itself and fib6_add/fib6_del. + * It means, that we can modify tree during walking + * and use this function for garbage collection, clone pruning, + * cleaning tree when a device goes down etc. etc. 
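+ * (in this file the walker is driven by the netlink dump, fib6_clean_tree() and the /proc/net/ipv6_route iterator)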
+ * + * It guarantees that every node will be traversed, + * and that it will be traversed only once. + * + * Callback function w->func may return: + * 0 -> continue walking. + * positive value -> walking is suspended (used by tree dumps, + * and probably by gc, if it will be split to several slices) + * negative value -> terminate walking. + * + * The function itself returns: + * 0 -> walk is complete. + * >0 -> walk is incomplete (i.e. suspended) + * <0 -> walk is terminated by an error. + * + * This function is called with tb6_lock held. + */ + +static int fib6_walk_continue(struct fib6_walker *w) +{ + struct fib6_node *fn, *pn, *left, *right; + + /* w->root should always be table->tb6_root */ + WARN_ON_ONCE(!(w->root->fn_flags & RTN_TL_ROOT)); + + for (;;) { + fn = w->node; + if (!fn) + return 0; + + switch (w->state) { +#ifdef CONFIG_IPV6_SUBTREES + case FWS_S: + if (FIB6_SUBTREE(fn)) { + w->node = FIB6_SUBTREE(fn); + continue; + } + w->state = FWS_L; + fallthrough; +#endif + case FWS_L: + left = rcu_dereference_protected(fn->left, 1); + if (left) { + w->node = left; + w->state = FWS_INIT; + continue; + } + w->state = FWS_R; + fallthrough; + case FWS_R: + right = rcu_dereference_protected(fn->right, 1); + if (right) { + w->node = right; + w->state = FWS_INIT; + continue; + } + w->state = FWS_C; + w->leaf = rcu_dereference_protected(fn->leaf, 1); + fallthrough; + case FWS_C: + if (w->leaf && fn->fn_flags & RTN_RTINFO) { + int err; + + if (w->skip) { + w->skip--; + goto skip; + } + + err = w->func(w); + if (err) + return err; + + w->count++; + continue; + } +skip: + w->state = FWS_U; + fallthrough; + case FWS_U: + if (fn == w->root) + return 0; + pn = rcu_dereference_protected(fn->parent, 1); + left = rcu_dereference_protected(pn->left, 1); + right = rcu_dereference_protected(pn->right, 1); + w->node = pn; +#ifdef CONFIG_IPV6_SUBTREES + if (FIB6_SUBTREE(pn) == fn) { + WARN_ON(!(fn->fn_flags & RTN_ROOT)); + w->state = FWS_L; + continue; + } +#endif + if (left == fn) { + w->state = FWS_R; + continue; + } + if (right == fn) { + w->state = FWS_C; + w->leaf = rcu_dereference_protected(w->node->leaf, 1); + continue; + } +#if RT6_DEBUG >= 2 + WARN_ON(1); +#endif + } + } +} + +static int fib6_walk(struct net *net, struct fib6_walker *w) +{ + int res; + + w->state = FWS_INIT; + w->node = w->root; + + fib6_walker_link(net, w); + res = fib6_walk_continue(w); + if (res <= 0) + fib6_walker_unlink(net, w); + return res; +} + +static int fib6_clean_node(struct fib6_walker *w) +{ + int res; + struct fib6_info *rt; + struct fib6_cleaner *c = container_of(w, struct fib6_cleaner, w); + struct nl_info info = { + .nl_net = c->net, + .skip_notify = c->skip_notify, + }; + + if (c->sernum != FIB6_NO_SERNUM_CHANGE && + READ_ONCE(w->node->fn_sernum) != c->sernum) + WRITE_ONCE(w->node->fn_sernum, c->sernum); + + if (!c->func) { + WARN_ON_ONCE(c->sernum == FIB6_NO_SERNUM_CHANGE); + w->leaf = NULL; + return 0; + } + + for_each_fib6_walker_rt(w) { + res = c->func(rt, c->arg); + if (res == -1) { + w->leaf = rt; + res = fib6_del(rt, &info); + if (res) { +#if RT6_DEBUG >= 2 + pr_debug("%s: del failed: rt=%p@%p err=%d\n", + __func__, rt, + rcu_access_pointer(rt->fib6_node), + res); +#endif + continue; + } + return 0; + } else if (res == -2) { + if (WARN_ON(!rt->fib6_nsiblings)) + continue; + rt = list_last_entry(&rt->fib6_siblings, + struct fib6_info, fib6_siblings); + continue; + } + WARN_ON(res != 0); + } + w->leaf = rt; + return 0; +} + +/* + * Convenient frontend to tree walker. + * + * func is called on each route. 
+ * It may return -2 -> skip multipath route. + * -1 -> delete this route. + * 0 -> continue walking + */ + +static void fib6_clean_tree(struct net *net, struct fib6_node *root, + int (*func)(struct fib6_info *, void *arg), + int sernum, void *arg, bool skip_notify) +{ + struct fib6_cleaner c; + + c.w.root = root; + c.w.func = fib6_clean_node; + c.w.count = 0; + c.w.skip = 0; + c.w.skip_in_node = 0; + c.func = func; + c.sernum = sernum; + c.arg = arg; + c.net = net; + c.skip_notify = skip_notify; + + fib6_walk(net, &c.w); +} + +static void __fib6_clean_all(struct net *net, + int (*func)(struct fib6_info *, void *), + int sernum, void *arg, bool skip_notify) +{ + struct fib6_table *table; + struct hlist_head *head; + unsigned int h; + + rcu_read_lock(); + for (h = 0; h < FIB6_TABLE_HASHSZ; h++) { + head = &net->ipv6.fib_table_hash[h]; + hlist_for_each_entry_rcu(table, head, tb6_hlist) { + spin_lock_bh(&table->tb6_lock); + fib6_clean_tree(net, &table->tb6_root, + func, sernum, arg, skip_notify); + spin_unlock_bh(&table->tb6_lock); + } + } + rcu_read_unlock(); +} + +void fib6_clean_all(struct net *net, int (*func)(struct fib6_info *, void *), + void *arg) +{ + __fib6_clean_all(net, func, FIB6_NO_SERNUM_CHANGE, arg, false); +} + +void fib6_clean_all_skip_notify(struct net *net, + int (*func)(struct fib6_info *, void *), + void *arg) +{ + __fib6_clean_all(net, func, FIB6_NO_SERNUM_CHANGE, arg, true); +} + +static void fib6_flush_trees(struct net *net) +{ + int new_sernum = fib6_new_sernum(net); + + __fib6_clean_all(net, NULL, new_sernum, NULL, false); +} + +/* + * Garbage collection + */ + +static int fib6_age(struct fib6_info *rt, void *arg) +{ + struct fib6_gc_args *gc_args = arg; + unsigned long now = jiffies; + + /* + * check addrconf expiration here. + * Routes are expired even if they are in use. + */ + + if (rt->fib6_flags & RTF_EXPIRES && rt->expires) { + if (time_after(now, rt->expires)) { + RT6_TRACE("expiring %p\n", rt); + return -1; + } + gc_args->more++; + } + + /* Also age clones in the exception table. + * Note, that clones are aged out + * only if they are not in use now. + */ + rt6_age_exceptions(rt, gc_args, now); + + return 0; +} + +void fib6_run_gc(unsigned long expires, struct net *net, bool force) +{ + struct fib6_gc_args gc_args; + unsigned long now; + + if (force) { + spin_lock_bh(&net->ipv6.fib6_gc_lock); + } else if (!spin_trylock_bh(&net->ipv6.fib6_gc_lock)) { + mod_timer(&net->ipv6.ip6_fib_timer, jiffies + HZ); + return; + } + gc_args.timeout = expires ? 
(int)expires : + net->ipv6.sysctl.ip6_rt_gc_interval; + gc_args.more = 0; + + fib6_clean_all(net, fib6_age, &gc_args); + now = jiffies; + net->ipv6.ip6_rt_last_gc = now; + + if (gc_args.more) + mod_timer(&net->ipv6.ip6_fib_timer, + round_jiffies(now + + net->ipv6.sysctl.ip6_rt_gc_interval)); + else + del_timer(&net->ipv6.ip6_fib_timer); + spin_unlock_bh(&net->ipv6.fib6_gc_lock); +} + +static void fib6_gc_timer_cb(struct timer_list *t) +{ + struct net *arg = from_timer(arg, t, ipv6.ip6_fib_timer); + + fib6_run_gc(0, arg, true); +} + +static int __net_init fib6_net_init(struct net *net) +{ + size_t size = sizeof(struct hlist_head) * FIB6_TABLE_HASHSZ; + int err; + + err = fib6_notifier_init(net); + if (err) + return err; + + /* Default to 3-tuple */ + net->ipv6.sysctl.multipath_hash_fields = + FIB_MULTIPATH_HASH_FIELD_DEFAULT_MASK; + + spin_lock_init(&net->ipv6.fib6_gc_lock); + rwlock_init(&net->ipv6.fib6_walker_lock); + INIT_LIST_HEAD(&net->ipv6.fib6_walkers); + timer_setup(&net->ipv6.ip6_fib_timer, fib6_gc_timer_cb, 0); + + net->ipv6.rt6_stats = kzalloc(sizeof(*net->ipv6.rt6_stats), GFP_KERNEL); + if (!net->ipv6.rt6_stats) + goto out_notifier; + + /* Avoid false sharing : Use at least a full cache line */ + size = max_t(size_t, size, L1_CACHE_BYTES); + + net->ipv6.fib_table_hash = kzalloc(size, GFP_KERNEL); + if (!net->ipv6.fib_table_hash) + goto out_rt6_stats; + + net->ipv6.fib6_main_tbl = kzalloc(sizeof(*net->ipv6.fib6_main_tbl), + GFP_KERNEL); + if (!net->ipv6.fib6_main_tbl) + goto out_fib_table_hash; + + net->ipv6.fib6_main_tbl->tb6_id = RT6_TABLE_MAIN; + rcu_assign_pointer(net->ipv6.fib6_main_tbl->tb6_root.leaf, + net->ipv6.fib6_null_entry); + net->ipv6.fib6_main_tbl->tb6_root.fn_flags = + RTN_ROOT | RTN_TL_ROOT | RTN_RTINFO; + inet_peer_base_init(&net->ipv6.fib6_main_tbl->tb6_peers); + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + net->ipv6.fib6_local_tbl = kzalloc(sizeof(*net->ipv6.fib6_local_tbl), + GFP_KERNEL); + if (!net->ipv6.fib6_local_tbl) + goto out_fib6_main_tbl; + net->ipv6.fib6_local_tbl->tb6_id = RT6_TABLE_LOCAL; + rcu_assign_pointer(net->ipv6.fib6_local_tbl->tb6_root.leaf, + net->ipv6.fib6_null_entry); + net->ipv6.fib6_local_tbl->tb6_root.fn_flags = + RTN_ROOT | RTN_TL_ROOT | RTN_RTINFO; + inet_peer_base_init(&net->ipv6.fib6_local_tbl->tb6_peers); +#endif + fib6_tables_init(net); + + return 0; + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES +out_fib6_main_tbl: + kfree(net->ipv6.fib6_main_tbl); +#endif +out_fib_table_hash: + kfree(net->ipv6.fib_table_hash); +out_rt6_stats: + kfree(net->ipv6.rt6_stats); +out_notifier: + fib6_notifier_exit(net); + return -ENOMEM; +} + +static void fib6_net_exit(struct net *net) +{ + unsigned int i; + + del_timer_sync(&net->ipv6.ip6_fib_timer); + + for (i = 0; i < FIB6_TABLE_HASHSZ; i++) { + struct hlist_head *head = &net->ipv6.fib_table_hash[i]; + struct hlist_node *tmp; + struct fib6_table *tb; + + hlist_for_each_entry_safe(tb, tmp, head, tb6_hlist) { + hlist_del(&tb->tb6_hlist); + fib6_free_table(tb); + } + } + + kfree(net->ipv6.fib_table_hash); + kfree(net->ipv6.rt6_stats); + fib6_notifier_exit(net); +} + +static struct pernet_operations fib6_net_ops = { + .init = fib6_net_init, + .exit = fib6_net_exit, +}; + +int __init fib6_init(void) +{ + int ret = -ENOMEM; + + fib6_node_kmem = kmem_cache_create("fib6_nodes", + sizeof(struct fib6_node), 0, + SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, + NULL); + if (!fib6_node_kmem) + goto out; + + ret = register_pernet_subsys(&fib6_net_ops); + if (ret) + goto out_kmem_cache_create; + + ret = rtnl_register_module(THIS_MODULE, 
PF_INET6, RTM_GETROUTE, NULL, + inet6_dump_fib, 0); + if (ret) + goto out_unregister_subsys; + + __fib6_flush_trees = fib6_flush_trees; +out: + return ret; + +out_unregister_subsys: + unregister_pernet_subsys(&fib6_net_ops); +out_kmem_cache_create: + kmem_cache_destroy(fib6_node_kmem); + goto out; +} + +void fib6_gc_cleanup(void) +{ + unregister_pernet_subsys(&fib6_net_ops); + kmem_cache_destroy(fib6_node_kmem); +} + +#ifdef CONFIG_PROC_FS +static int ipv6_route_native_seq_show(struct seq_file *seq, void *v) +{ + struct fib6_info *rt = v; + struct ipv6_route_iter *iter = seq->private; + struct fib6_nh *fib6_nh = rt->fib6_nh; + unsigned int flags = rt->fib6_flags; + const struct net_device *dev; + + if (rt->nh) + fib6_nh = nexthop_fib6_nh(rt->nh); + + seq_printf(seq, "%pi6 %02x ", &rt->fib6_dst.addr, rt->fib6_dst.plen); + +#ifdef CONFIG_IPV6_SUBTREES + seq_printf(seq, "%pi6 %02x ", &rt->fib6_src.addr, rt->fib6_src.plen); +#else + seq_puts(seq, "00000000000000000000000000000000 00 "); +#endif + if (fib6_nh->fib_nh_gw_family) { + flags |= RTF_GATEWAY; + seq_printf(seq, "%pi6", &fib6_nh->fib_nh_gw6); + } else { + seq_puts(seq, "00000000000000000000000000000000"); + } + + dev = fib6_nh->fib_nh_dev; + seq_printf(seq, " %08x %08x %08x %08x %8s\n", + rt->fib6_metric, refcount_read(&rt->fib6_ref), 0, + flags, dev ? dev->name : ""); + iter->w.leaf = NULL; + return 0; +} + +static int ipv6_route_yield(struct fib6_walker *w) +{ + struct ipv6_route_iter *iter = w->args; + + if (!iter->skip) + return 1; + + do { + iter->w.leaf = rcu_dereference_protected( + iter->w.leaf->fib6_next, + lockdep_is_held(&iter->tbl->tb6_lock)); + iter->skip--; + if (!iter->skip && iter->w.leaf) + return 1; + } while (iter->w.leaf); + + return 0; +} + +static void ipv6_route_seq_setup_walk(struct ipv6_route_iter *iter, + struct net *net) +{ + memset(&iter->w, 0, sizeof(iter->w)); + iter->w.func = ipv6_route_yield; + iter->w.root = &iter->tbl->tb6_root; + iter->w.state = FWS_INIT; + iter->w.node = iter->w.root; + iter->w.args = iter; + iter->sernum = READ_ONCE(iter->w.root->fn_sernum); + INIT_LIST_HEAD(&iter->w.lh); + fib6_walker_link(net, &iter->w); +} + +static struct fib6_table *ipv6_route_seq_next_table(struct fib6_table *tbl, + struct net *net) +{ + unsigned int h; + struct hlist_node *node; + + if (tbl) { + h = (tbl->tb6_id & (FIB6_TABLE_HASHSZ - 1)) + 1; + node = rcu_dereference(hlist_next_rcu(&tbl->tb6_hlist)); + } else { + h = 0; + node = NULL; + } + + while (!node && h < FIB6_TABLE_HASHSZ) { + node = rcu_dereference( + hlist_first_rcu(&net->ipv6.fib_table_hash[h++])); + } + return hlist_entry_safe(node, struct fib6_table, tb6_hlist); +} + +static void ipv6_route_check_sernum(struct ipv6_route_iter *iter) +{ + int sernum = READ_ONCE(iter->w.root->fn_sernum); + + if (iter->sernum != sernum) { + iter->sernum = sernum; + iter->w.state = FWS_INIT; + iter->w.node = iter->w.root; + WARN_ON(iter->w.skip); + iter->w.skip = iter->w.count; + } +} + +static void *ipv6_route_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + int r; + struct fib6_info *n; + struct net *net = seq_file_net(seq); + struct ipv6_route_iter *iter = seq->private; + + ++(*pos); + if (!v) + goto iter_table; + + n = rcu_dereference(((struct fib6_info *)v)->fib6_next); + if (n) + return n; + +iter_table: + ipv6_route_check_sernum(iter); + spin_lock_bh(&iter->tbl->tb6_lock); + r = fib6_walk_continue(&iter->w); + spin_unlock_bh(&iter->tbl->tb6_lock); + if (r > 0) { + return iter->w.leaf; + } else if (r < 0) { + fib6_walker_unlink(net, &iter->w); + 
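/* fib6_walk_continue() returned an error; stop the iteration */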
return NULL; + } + fib6_walker_unlink(net, &iter->w); + + iter->tbl = ipv6_route_seq_next_table(iter->tbl, net); + if (!iter->tbl) + return NULL; + + ipv6_route_seq_setup_walk(iter, net); + goto iter_table; +} + +static void *ipv6_route_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct net *net = seq_file_net(seq); + struct ipv6_route_iter *iter = seq->private; + + rcu_read_lock(); + iter->tbl = ipv6_route_seq_next_table(NULL, net); + iter->skip = *pos; + + if (iter->tbl) { + loff_t p = 0; + + ipv6_route_seq_setup_walk(iter, net); + return ipv6_route_seq_next(seq, NULL, &p); + } else { + return NULL; + } +} + +static bool ipv6_route_iter_active(struct ipv6_route_iter *iter) +{ + struct fib6_walker *w = &iter->w; + return w->node && !(w->state == FWS_U && w->node == w->root); +} + +static void ipv6_route_native_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + struct net *net = seq_file_net(seq); + struct ipv6_route_iter *iter = seq->private; + + if (ipv6_route_iter_active(iter)) + fib6_walker_unlink(net, &iter->w); + + rcu_read_unlock(); +} + +#if IS_BUILTIN(CONFIG_IPV6) && defined(CONFIG_BPF_SYSCALL) +static int ipv6_route_prog_seq_show(struct bpf_prog *prog, + struct bpf_iter_meta *meta, + void *v) +{ + struct bpf_iter__ipv6_route ctx; + + ctx.meta = meta; + ctx.rt = v; + return bpf_iter_run_prog(prog, &ctx); +} + +static int ipv6_route_seq_show(struct seq_file *seq, void *v) +{ + struct ipv6_route_iter *iter = seq->private; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + int ret; + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, false); + if (!prog) + return ipv6_route_native_seq_show(seq, v); + + ret = ipv6_route_prog_seq_show(prog, &meta, v); + iter->w.leaf = NULL; + + return ret; +} + +static void ipv6_route_seq_stop(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + if (!v) { + meta.seq = seq; + prog = bpf_iter_get_info(&meta, true); + if (prog) + (void)ipv6_route_prog_seq_show(prog, &meta, v); + } + + ipv6_route_native_seq_stop(seq, v); +} +#else +static int ipv6_route_seq_show(struct seq_file *seq, void *v) +{ + return ipv6_route_native_seq_show(seq, v); +} + +static void ipv6_route_seq_stop(struct seq_file *seq, void *v) +{ + ipv6_route_native_seq_stop(seq, v); +} +#endif + +const struct seq_operations ipv6_route_seq_ops = { + .start = ipv6_route_seq_start, + .next = ipv6_route_seq_next, + .stop = ipv6_route_seq_stop, + .show = ipv6_route_seq_show +}; +#endif /* CONFIG_PROC_FS */ diff --git a/net/ipv6/ip6_flowlabel.c b/net/ipv6/ip6_flowlabel.c new file mode 100644 index 000000000..18481eb76 --- /dev/null +++ b/net/ipv6/ip6_flowlabel.c @@ -0,0 +1,909 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip6_flowlabel.c IPv6 flowlabel manager. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include + +#define FL_MIN_LINGER 6 /* Minimal linger. It is set to 6sec specified + in old IPv6 RFC. Well, it was reasonable value. 
+ */ +#define FL_MAX_LINGER 150 /* Maximal linger timeout */ + +/* FL hash table */ + +#define FL_MAX_PER_SOCK 32 +#define FL_MAX_SIZE 4096 +#define FL_HASH_MASK 255 +#define FL_HASH(l) (ntohl(l)&FL_HASH_MASK) + +static atomic_t fl_size = ATOMIC_INIT(0); +static struct ip6_flowlabel __rcu *fl_ht[FL_HASH_MASK+1]; + +static void ip6_fl_gc(struct timer_list *unused); +static DEFINE_TIMER(ip6_fl_gc_timer, ip6_fl_gc); + +/* FL hash table lock: it protects only of GC */ + +static DEFINE_SPINLOCK(ip6_fl_lock); + +/* Big socket sock */ + +static DEFINE_SPINLOCK(ip6_sk_fl_lock); + +DEFINE_STATIC_KEY_DEFERRED_FALSE(ipv6_flowlabel_exclusive, HZ); +EXPORT_SYMBOL(ipv6_flowlabel_exclusive); + +#define for_each_fl_rcu(hash, fl) \ + for (fl = rcu_dereference_bh(fl_ht[(hash)]); \ + fl != NULL; \ + fl = rcu_dereference_bh(fl->next)) +#define for_each_fl_continue_rcu(fl) \ + for (fl = rcu_dereference_bh(fl->next); \ + fl != NULL; \ + fl = rcu_dereference_bh(fl->next)) + +#define for_each_sk_fl_rcu(np, sfl) \ + for (sfl = rcu_dereference_bh(np->ipv6_fl_list); \ + sfl != NULL; \ + sfl = rcu_dereference_bh(sfl->next)) + +static inline struct ip6_flowlabel *__fl_lookup(struct net *net, __be32 label) +{ + struct ip6_flowlabel *fl; + + for_each_fl_rcu(FL_HASH(label), fl) { + if (fl->label == label && net_eq(fl->fl_net, net)) + return fl; + } + return NULL; +} + +static struct ip6_flowlabel *fl_lookup(struct net *net, __be32 label) +{ + struct ip6_flowlabel *fl; + + rcu_read_lock_bh(); + fl = __fl_lookup(net, label); + if (fl && !atomic_inc_not_zero(&fl->users)) + fl = NULL; + rcu_read_unlock_bh(); + return fl; +} + +static bool fl_shared_exclusive(struct ip6_flowlabel *fl) +{ + return fl->share == IPV6_FL_S_EXCL || + fl->share == IPV6_FL_S_PROCESS || + fl->share == IPV6_FL_S_USER; +} + +static void fl_free_rcu(struct rcu_head *head) +{ + struct ip6_flowlabel *fl = container_of(head, struct ip6_flowlabel, rcu); + + if (fl->share == IPV6_FL_S_PROCESS) + put_pid(fl->owner.pid); + kfree(fl->opt); + kfree(fl); +} + + +static void fl_free(struct ip6_flowlabel *fl) +{ + if (!fl) + return; + + if (fl_shared_exclusive(fl) || fl->opt) + static_branch_slow_dec_deferred(&ipv6_flowlabel_exclusive); + + call_rcu(&fl->rcu, fl_free_rcu); +} + +static void fl_release(struct ip6_flowlabel *fl) +{ + spin_lock_bh(&ip6_fl_lock); + + fl->lastuse = jiffies; + if (atomic_dec_and_test(&fl->users)) { + unsigned long ttd = fl->lastuse + fl->linger; + if (time_after(ttd, fl->expires)) + fl->expires = ttd; + ttd = fl->expires; + if (fl->opt && fl->share == IPV6_FL_S_EXCL) { + struct ipv6_txoptions *opt = fl->opt; + fl->opt = NULL; + kfree(opt); + } + if (!timer_pending(&ip6_fl_gc_timer) || + time_after(ip6_fl_gc_timer.expires, ttd)) + mod_timer(&ip6_fl_gc_timer, ttd); + } + spin_unlock_bh(&ip6_fl_lock); +} + +static void ip6_fl_gc(struct timer_list *unused) +{ + int i; + unsigned long now = jiffies; + unsigned long sched = 0; + + spin_lock(&ip6_fl_lock); + + for (i = 0; i <= FL_HASH_MASK; i++) { + struct ip6_flowlabel *fl; + struct ip6_flowlabel __rcu **flp; + + flp = &fl_ht[i]; + while ((fl = rcu_dereference_protected(*flp, + lockdep_is_held(&ip6_fl_lock))) != NULL) { + if (atomic_read(&fl->users) == 0) { + unsigned long ttd = fl->lastuse + fl->linger; + if (time_after(ttd, fl->expires)) + fl->expires = ttd; + ttd = fl->expires; + if (time_after_eq(now, ttd)) { + *flp = fl->next; + fl_free(fl); + atomic_dec(&fl_size); + continue; + } + if (!sched || time_before(ttd, sched)) + sched = ttd; + } + flp = &fl->next; + } + } + if (!sched && 
atomic_read(&fl_size)) + sched = now + FL_MAX_LINGER; + if (sched) { + mod_timer(&ip6_fl_gc_timer, sched); + } + spin_unlock(&ip6_fl_lock); +} + +static void __net_exit ip6_fl_purge(struct net *net) +{ + int i; + + spin_lock_bh(&ip6_fl_lock); + for (i = 0; i <= FL_HASH_MASK; i++) { + struct ip6_flowlabel *fl; + struct ip6_flowlabel __rcu **flp; + + flp = &fl_ht[i]; + while ((fl = rcu_dereference_protected(*flp, + lockdep_is_held(&ip6_fl_lock))) != NULL) { + if (net_eq(fl->fl_net, net) && + atomic_read(&fl->users) == 0) { + *flp = fl->next; + fl_free(fl); + atomic_dec(&fl_size); + continue; + } + flp = &fl->next; + } + } + spin_unlock_bh(&ip6_fl_lock); +} + +static struct ip6_flowlabel *fl_intern(struct net *net, + struct ip6_flowlabel *fl, __be32 label) +{ + struct ip6_flowlabel *lfl; + + fl->label = label & IPV6_FLOWLABEL_MASK; + + spin_lock_bh(&ip6_fl_lock); + if (label == 0) { + for (;;) { + fl->label = htonl(get_random_u32())&IPV6_FLOWLABEL_MASK; + if (fl->label) { + lfl = __fl_lookup(net, fl->label); + if (!lfl) + break; + } + } + } else { + /* + * we dropper the ip6_fl_lock, so this entry could reappear + * and we need to recheck with it. + * + * OTOH no need to search the active socket first, like it is + * done in ipv6_flowlabel_opt - sock is locked, so new entry + * with the same label can only appear on another sock + */ + lfl = __fl_lookup(net, fl->label); + if (lfl) { + atomic_inc(&lfl->users); + spin_unlock_bh(&ip6_fl_lock); + return lfl; + } + } + + fl->lastuse = jiffies; + fl->next = fl_ht[FL_HASH(fl->label)]; + rcu_assign_pointer(fl_ht[FL_HASH(fl->label)], fl); + atomic_inc(&fl_size); + spin_unlock_bh(&ip6_fl_lock); + return NULL; +} + + + +/* Socket flowlabel lists */ + +struct ip6_flowlabel *__fl6_sock_lookup(struct sock *sk, __be32 label) +{ + struct ipv6_fl_socklist *sfl; + struct ipv6_pinfo *np = inet6_sk(sk); + + label &= IPV6_FLOWLABEL_MASK; + + rcu_read_lock_bh(); + for_each_sk_fl_rcu(np, sfl) { + struct ip6_flowlabel *fl = sfl->fl; + + if (fl->label == label && atomic_inc_not_zero(&fl->users)) { + fl->lastuse = jiffies; + rcu_read_unlock_bh(); + return fl; + } + } + rcu_read_unlock_bh(); + return NULL; +} +EXPORT_SYMBOL_GPL(__fl6_sock_lookup); + +void fl6_free_socklist(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_fl_socklist *sfl; + + if (!rcu_access_pointer(np->ipv6_fl_list)) + return; + + spin_lock_bh(&ip6_sk_fl_lock); + while ((sfl = rcu_dereference_protected(np->ipv6_fl_list, + lockdep_is_held(&ip6_sk_fl_lock))) != NULL) { + np->ipv6_fl_list = sfl->next; + spin_unlock_bh(&ip6_sk_fl_lock); + + fl_release(sfl->fl); + kfree_rcu(sfl, rcu); + + spin_lock_bh(&ip6_sk_fl_lock); + } + spin_unlock_bh(&ip6_sk_fl_lock); +} + +/* Service routines */ + + +/* + It is the only difficult place. flowlabel enforces equal headers + before and including routing header, however user may supply options + following rthdr. 
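+ Hence the merge below takes hopopt/dst0opt/srcrt (opt_nflen) from the flow label and dst1opt/opt_flen from the caller-supplied options.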
+ */ + +struct ipv6_txoptions *fl6_merge_options(struct ipv6_txoptions *opt_space, + struct ip6_flowlabel *fl, + struct ipv6_txoptions *fopt) +{ + struct ipv6_txoptions *fl_opt = fl->opt; + + if (!fopt || fopt->opt_flen == 0) + return fl_opt; + + if (fl_opt) { + opt_space->hopopt = fl_opt->hopopt; + opt_space->dst0opt = fl_opt->dst0opt; + opt_space->srcrt = fl_opt->srcrt; + opt_space->opt_nflen = fl_opt->opt_nflen; + } else { + if (fopt->opt_nflen == 0) + return fopt; + opt_space->hopopt = NULL; + opt_space->dst0opt = NULL; + opt_space->srcrt = NULL; + opt_space->opt_nflen = 0; + } + opt_space->dst1opt = fopt->dst1opt; + opt_space->opt_flen = fopt->opt_flen; + opt_space->tot_len = fopt->tot_len; + return opt_space; +} +EXPORT_SYMBOL_GPL(fl6_merge_options); + +static unsigned long check_linger(unsigned long ttl) +{ + if (ttl < FL_MIN_LINGER) + return FL_MIN_LINGER*HZ; + if (ttl > FL_MAX_LINGER && !capable(CAP_NET_ADMIN)) + return 0; + return ttl*HZ; +} + +static int fl6_renew(struct ip6_flowlabel *fl, unsigned long linger, unsigned long expires) +{ + linger = check_linger(linger); + if (!linger) + return -EPERM; + expires = check_linger(expires); + if (!expires) + return -EPERM; + + spin_lock_bh(&ip6_fl_lock); + fl->lastuse = jiffies; + if (time_before(fl->linger, linger)) + fl->linger = linger; + if (time_before(expires, fl->linger)) + expires = fl->linger; + if (time_before(fl->expires, fl->lastuse + expires)) + fl->expires = fl->lastuse + expires; + spin_unlock_bh(&ip6_fl_lock); + + return 0; +} + +static struct ip6_flowlabel * +fl_create(struct net *net, struct sock *sk, struct in6_flowlabel_req *freq, + sockptr_t optval, int optlen, int *err_p) +{ + struct ip6_flowlabel *fl = NULL; + int olen; + int addr_type; + int err; + + olen = optlen - CMSG_ALIGN(sizeof(*freq)); + err = -EINVAL; + if (olen > 64 * 1024) + goto done; + + err = -ENOMEM; + fl = kzalloc(sizeof(*fl), GFP_KERNEL); + if (!fl) + goto done; + + if (olen > 0) { + struct msghdr msg; + struct flowi6 flowi6; + struct ipcm6_cookie ipc6; + + err = -ENOMEM; + fl->opt = kmalloc(sizeof(*fl->opt) + olen, GFP_KERNEL); + if (!fl->opt) + goto done; + + memset(fl->opt, 0, sizeof(*fl->opt)); + fl->opt->tot_len = sizeof(*fl->opt) + olen; + err = -EFAULT; + if (copy_from_sockptr_offset(fl->opt + 1, optval, + CMSG_ALIGN(sizeof(*freq)), olen)) + goto done; + + msg.msg_controllen = olen; + msg.msg_control = (void *)(fl->opt+1); + memset(&flowi6, 0, sizeof(flowi6)); + + ipc6.opt = fl->opt; + err = ip6_datagram_send_ctl(net, sk, &msg, &flowi6, &ipc6); + if (err) + goto done; + err = -EINVAL; + if (fl->opt->opt_flen) + goto done; + if (fl->opt->opt_nflen == 0) { + kfree(fl->opt); + fl->opt = NULL; + } + } + + fl->fl_net = net; + fl->expires = jiffies; + err = fl6_renew(fl, freq->flr_linger, freq->flr_expires); + if (err) + goto done; + fl->share = freq->flr_share; + addr_type = ipv6_addr_type(&freq->flr_dst); + if ((addr_type & IPV6_ADDR_MAPPED) || + addr_type == IPV6_ADDR_ANY) { + err = -EINVAL; + goto done; + } + fl->dst = freq->flr_dst; + atomic_set(&fl->users, 1); + switch (fl->share) { + case IPV6_FL_S_EXCL: + case IPV6_FL_S_ANY: + break; + case IPV6_FL_S_PROCESS: + fl->owner.pid = get_task_pid(current, PIDTYPE_PID); + break; + case IPV6_FL_S_USER: + fl->owner.uid = current_euid(); + break; + default: + err = -EINVAL; + goto done; + } + if (fl_shared_exclusive(fl) || fl->opt) { + WRITE_ONCE(sock_net(sk)->ipv6.flowlabel_has_excl, 1); + static_branch_deferred_inc(&ipv6_flowlabel_exclusive); + } + return fl; + +done: + if (fl) { + 
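/* error path: release the partially built flowlabel */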
kfree(fl->opt); + kfree(fl); + } + *err_p = err; + return NULL; +} + +static int mem_check(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_fl_socklist *sfl; + int room = FL_MAX_SIZE - atomic_read(&fl_size); + int count = 0; + + if (room > FL_MAX_SIZE - FL_MAX_PER_SOCK) + return 0; + + rcu_read_lock_bh(); + for_each_sk_fl_rcu(np, sfl) + count++; + rcu_read_unlock_bh(); + + if (room <= 0 || + ((count >= FL_MAX_PER_SOCK || + (count > 0 && room < FL_MAX_SIZE/2) || room < FL_MAX_SIZE/4) && + !capable(CAP_NET_ADMIN))) + return -ENOBUFS; + + return 0; +} + +static inline void fl_link(struct ipv6_pinfo *np, struct ipv6_fl_socklist *sfl, + struct ip6_flowlabel *fl) +{ + spin_lock_bh(&ip6_sk_fl_lock); + sfl->fl = fl; + sfl->next = np->ipv6_fl_list; + rcu_assign_pointer(np->ipv6_fl_list, sfl); + spin_unlock_bh(&ip6_sk_fl_lock); +} + +int ipv6_flowlabel_opt_get(struct sock *sk, struct in6_flowlabel_req *freq, + int flags) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_fl_socklist *sfl; + + if (flags & IPV6_FL_F_REMOTE) { + freq->flr_label = np->rcv_flowinfo & IPV6_FLOWLABEL_MASK; + return 0; + } + + if (np->repflow) { + freq->flr_label = np->flow_label; + return 0; + } + + rcu_read_lock_bh(); + + for_each_sk_fl_rcu(np, sfl) { + if (sfl->fl->label == (np->flow_label & IPV6_FLOWLABEL_MASK)) { + spin_lock_bh(&ip6_fl_lock); + freq->flr_label = sfl->fl->label; + freq->flr_dst = sfl->fl->dst; + freq->flr_share = sfl->fl->share; + freq->flr_expires = (sfl->fl->expires - jiffies) / HZ; + freq->flr_linger = sfl->fl->linger / HZ; + + spin_unlock_bh(&ip6_fl_lock); + rcu_read_unlock_bh(); + return 0; + } + } + rcu_read_unlock_bh(); + + return -ENOENT; +} + +#define socklist_dereference(__sflp) \ + rcu_dereference_protected(__sflp, lockdep_is_held(&ip6_sk_fl_lock)) + +static int ipv6_flowlabel_put(struct sock *sk, struct in6_flowlabel_req *freq) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_fl_socklist __rcu **sflp; + struct ipv6_fl_socklist *sfl; + + if (freq->flr_flags & IPV6_FL_F_REFLECT) { + if (sk->sk_protocol != IPPROTO_TCP) + return -ENOPROTOOPT; + if (!np->repflow) + return -ESRCH; + np->flow_label = 0; + np->repflow = 0; + return 0; + } + + spin_lock_bh(&ip6_sk_fl_lock); + for (sflp = &np->ipv6_fl_list; + (sfl = socklist_dereference(*sflp)) != NULL; + sflp = &sfl->next) { + if (sfl->fl->label == freq->flr_label) + goto found; + } + spin_unlock_bh(&ip6_sk_fl_lock); + return -ESRCH; +found: + if (freq->flr_label == (np->flow_label & IPV6_FLOWLABEL_MASK)) + np->flow_label &= ~IPV6_FLOWLABEL_MASK; + *sflp = sfl->next; + spin_unlock_bh(&ip6_sk_fl_lock); + fl_release(sfl->fl); + kfree_rcu(sfl, rcu); + return 0; +} + +static int ipv6_flowlabel_renew(struct sock *sk, struct in6_flowlabel_req *freq) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + struct ipv6_fl_socklist *sfl; + int err; + + rcu_read_lock_bh(); + for_each_sk_fl_rcu(np, sfl) { + if (sfl->fl->label == freq->flr_label) { + err = fl6_renew(sfl->fl, freq->flr_linger, + freq->flr_expires); + rcu_read_unlock_bh(); + return err; + } + } + rcu_read_unlock_bh(); + + if (freq->flr_share == IPV6_FL_S_NONE && + ns_capable(net->user_ns, CAP_NET_ADMIN)) { + struct ip6_flowlabel *fl = fl_lookup(net, freq->flr_label); + + if (fl) { + err = fl6_renew(fl, freq->flr_linger, + freq->flr_expires); + fl_release(fl); + return err; + } + } + return -ESRCH; +} + +static int ipv6_flowlabel_get(struct sock *sk, struct in6_flowlabel_req *freq, + sockptr_t optval, int optlen) +{ + struct 
ipv6_fl_socklist *sfl, *sfl1 = NULL; + struct ip6_flowlabel *fl, *fl1 = NULL; + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + int err; + + if (freq->flr_flags & IPV6_FL_F_REFLECT) { + if (net->ipv6.sysctl.flowlabel_consistency) { + net_info_ratelimited("Can not set IPV6_FL_F_REFLECT if flowlabel_consistency sysctl is enable\n"); + return -EPERM; + } + + if (sk->sk_protocol != IPPROTO_TCP) + return -ENOPROTOOPT; + np->repflow = 1; + return 0; + } + + if (freq->flr_label & ~IPV6_FLOWLABEL_MASK) + return -EINVAL; + if (net->ipv6.sysctl.flowlabel_state_ranges && + (freq->flr_label & IPV6_FLOWLABEL_STATELESS_FLAG)) + return -ERANGE; + + fl = fl_create(net, sk, freq, optval, optlen, &err); + if (!fl) + return err; + + sfl1 = kmalloc(sizeof(*sfl1), GFP_KERNEL); + + if (freq->flr_label) { + err = -EEXIST; + rcu_read_lock_bh(); + for_each_sk_fl_rcu(np, sfl) { + if (sfl->fl->label == freq->flr_label) { + if (freq->flr_flags & IPV6_FL_F_EXCL) { + rcu_read_unlock_bh(); + goto done; + } + fl1 = sfl->fl; + if (!atomic_inc_not_zero(&fl1->users)) + fl1 = NULL; + break; + } + } + rcu_read_unlock_bh(); + + if (!fl1) + fl1 = fl_lookup(net, freq->flr_label); + if (fl1) { +recheck: + err = -EEXIST; + if (freq->flr_flags&IPV6_FL_F_EXCL) + goto release; + err = -EPERM; + if (fl1->share == IPV6_FL_S_EXCL || + fl1->share != fl->share || + ((fl1->share == IPV6_FL_S_PROCESS) && + (fl1->owner.pid != fl->owner.pid)) || + ((fl1->share == IPV6_FL_S_USER) && + !uid_eq(fl1->owner.uid, fl->owner.uid))) + goto release; + + err = -ENOMEM; + if (!sfl1) + goto release; + if (fl->linger > fl1->linger) + fl1->linger = fl->linger; + if ((long)(fl->expires - fl1->expires) > 0) + fl1->expires = fl->expires; + fl_link(np, sfl1, fl1); + fl_free(fl); + return 0; + +release: + fl_release(fl1); + goto done; + } + } + err = -ENOENT; + if (!(freq->flr_flags & IPV6_FL_F_CREATE)) + goto done; + + err = -ENOMEM; + if (!sfl1) + goto done; + + err = mem_check(sk); + if (err != 0) + goto done; + + fl1 = fl_intern(net, fl, freq->flr_label); + if (fl1) + goto recheck; + + if (!freq->flr_label) { + size_t offset = offsetof(struct in6_flowlabel_req, flr_label); + + if (copy_to_sockptr_offset(optval, offset, &fl->label, + sizeof(fl->label))) { + /* Intentionally ignore fault. 
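The IPV6_FL_A_GET branch above is what backs the IPV6_FLOWLABEL_MGR socket option: user space asks for an existing label or passes 0 to have one allocated, and the chosen label is written back into the flr_label field of the very buffer handed to setsockopt(). A minimal user-space sketch of such a request, for illustration only; the helper name request_flowlabel() is invented here, and the IPV6_FLOWLABEL_MGR fallback define covers libc header sets that do not expose it:

#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <linux/in6.h>          /* struct in6_flowlabel_req, IPV6_FL_* */

#ifndef IPV6_FLOWLABEL_MGR      /* not exposed by every libc header set */
#define IPV6_FLOWLABEL_MGR 32
#endif

/* Ask the kernel for an exclusive flow label towards dst and attach it
 * to sk; the allocated label comes back in freq.flr_label. */
int request_flowlabel(int sk, const struct in6_addr *dst)
{
	struct in6_flowlabel_req freq;

	memset(&freq, 0, sizeof(freq));
	freq.flr_dst = *dst;
	freq.flr_label = 0;                 /* 0: let the kernel pick one */
	freq.flr_action = IPV6_FL_A_GET;
	freq.flr_flags = IPV6_FL_F_CREATE;
	freq.flr_share = IPV6_FL_S_EXCL;
	freq.flr_linger = 60;               /* seconds, clamped by check_linger() */
	freq.flr_expires = 60;

	if (setsockopt(sk, IPPROTO_IPV6, IPV6_FLOWLABEL_MGR,
		       &freq, sizeof(freq)) < 0) {
		perror("IPV6_FLOWLABEL_MGR");
		return -1;
	}
	printf("allocated flow label %05x\n", (unsigned int)ntohl(freq.flr_label));
	return 0;
}

int main(void)
{
	struct in6_addr dst = IN6ADDR_LOOPBACK_INIT;
	int sk = socket(AF_INET6, SOCK_DGRAM, 0);

	if (sk < 0) {
		perror("socket");
		return 1;
	}
	return request_flowlabel(sk, &dst) ? 1 : 0;
}

To actually stamp outgoing packets with the label, the application would additionally enable IPV6_FLOWINFO_SEND and pass the label in sin6_flowinfo; that half of the interface lives in the send paths, not in this file.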
*/ + } + } + + fl_link(np, sfl1, fl); + return 0; +done: + fl_free(fl); + kfree(sfl1); + return err; +} + +int ipv6_flowlabel_opt(struct sock *sk, sockptr_t optval, int optlen) +{ + struct in6_flowlabel_req freq; + + if (optlen < sizeof(freq)) + return -EINVAL; + if (copy_from_sockptr(&freq, optval, sizeof(freq))) + return -EFAULT; + + switch (freq.flr_action) { + case IPV6_FL_A_PUT: + return ipv6_flowlabel_put(sk, &freq); + case IPV6_FL_A_RENEW: + return ipv6_flowlabel_renew(sk, &freq); + case IPV6_FL_A_GET: + return ipv6_flowlabel_get(sk, &freq, optval, optlen); + default: + return -EINVAL; + } +} + +#ifdef CONFIG_PROC_FS + +struct ip6fl_iter_state { + struct seq_net_private p; + struct pid_namespace *pid_ns; + int bucket; +}; + +#define ip6fl_seq_private(seq) ((struct ip6fl_iter_state *)(seq)->private) + +static struct ip6_flowlabel *ip6fl_get_first(struct seq_file *seq) +{ + struct ip6_flowlabel *fl = NULL; + struct ip6fl_iter_state *state = ip6fl_seq_private(seq); + struct net *net = seq_file_net(seq); + + for (state->bucket = 0; state->bucket <= FL_HASH_MASK; ++state->bucket) { + for_each_fl_rcu(state->bucket, fl) { + if (net_eq(fl->fl_net, net)) + goto out; + } + } + fl = NULL; +out: + return fl; +} + +static struct ip6_flowlabel *ip6fl_get_next(struct seq_file *seq, struct ip6_flowlabel *fl) +{ + struct ip6fl_iter_state *state = ip6fl_seq_private(seq); + struct net *net = seq_file_net(seq); + + for_each_fl_continue_rcu(fl) { + if (net_eq(fl->fl_net, net)) + goto out; + } + +try_again: + if (++state->bucket <= FL_HASH_MASK) { + for_each_fl_rcu(state->bucket, fl) { + if (net_eq(fl->fl_net, net)) + goto out; + } + goto try_again; + } + fl = NULL; + +out: + return fl; +} + +static struct ip6_flowlabel *ip6fl_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ip6_flowlabel *fl = ip6fl_get_first(seq); + if (fl) + while (pos && (fl = ip6fl_get_next(seq, fl)) != NULL) + --pos; + return pos ? NULL : fl; +} + +static void *ip6fl_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct ip6fl_iter_state *state = ip6fl_seq_private(seq); + + state->pid_ns = proc_pid_ns(file_inode(seq->file)->i_sb); + + rcu_read_lock_bh(); + return *pos ? ip6fl_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *ip6fl_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip6_flowlabel *fl; + + if (v == SEQ_START_TOKEN) + fl = ip6fl_get_first(seq); + else + fl = ip6fl_get_next(seq, v); + ++*pos; + return fl; +} + +static void ip6fl_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock_bh(); +} + +static int ip6fl_seq_show(struct seq_file *seq, void *v) +{ + struct ip6fl_iter_state *state = ip6fl_seq_private(seq); + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Label S Owner Users Linger Expires Dst Opt\n"); + } else { + struct ip6_flowlabel *fl = v; + seq_printf(seq, + "%05X %-1d %-6d %-6d %-6ld %-8ld %pi6 %-4d\n", + (unsigned int)ntohl(fl->label), + fl->share, + ((fl->share == IPV6_FL_S_PROCESS) ? + pid_nr_ns(fl->owner.pid, state->pid_ns) : + ((fl->share == IPV6_FL_S_USER) ? + from_kuid_munged(seq_user_ns(seq), fl->owner.uid) : + 0)), + atomic_read(&fl->users), + fl->linger/HZ, + (long)(fl->expires - jiffies)/HZ, + &fl->dst, + fl->opt ? 
fl->opt->opt_nflen : 0); + } + return 0; +} + +static const struct seq_operations ip6fl_seq_ops = { + .start = ip6fl_seq_start, + .next = ip6fl_seq_next, + .stop = ip6fl_seq_stop, + .show = ip6fl_seq_show, +}; + +static int __net_init ip6_flowlabel_proc_init(struct net *net) +{ + if (!proc_create_net("ip6_flowlabel", 0444, net->proc_net, + &ip6fl_seq_ops, sizeof(struct ip6fl_iter_state))) + return -ENOMEM; + return 0; +} + +static void __net_exit ip6_flowlabel_proc_fini(struct net *net) +{ + remove_proc_entry("ip6_flowlabel", net->proc_net); +} +#else +static inline int ip6_flowlabel_proc_init(struct net *net) +{ + return 0; +} +static inline void ip6_flowlabel_proc_fini(struct net *net) +{ +} +#endif + +static void __net_exit ip6_flowlabel_net_exit(struct net *net) +{ + ip6_fl_purge(net); + ip6_flowlabel_proc_fini(net); +} + +static struct pernet_operations ip6_flowlabel_net_ops = { + .init = ip6_flowlabel_proc_init, + .exit = ip6_flowlabel_net_exit, +}; + +int ip6_flowlabel_init(void) +{ + return register_pernet_subsys(&ip6_flowlabel_net_ops); +} + +void ip6_flowlabel_cleanup(void) +{ + static_key_deferred_flush(&ipv6_flowlabel_exclusive); + del_timer(&ip6_fl_gc_timer); + unregister_pernet_subsys(&ip6_flowlabel_net_ops); +} diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c new file mode 100644 index 000000000..d3fba7d8d --- /dev/null +++ b/net/ipv6/ip6_gre.c @@ -0,0 +1,2436 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * GRE over IPv6 protocol decoder. + * + * Authors: Dmitry Kozlov (xeb@mail.ru) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + + +static bool log_ecn_error = true; +module_param(log_ecn_error, bool, 0644); +MODULE_PARM_DESC(log_ecn_error, "Log packets received with corrupted ECN"); + +#define IP6_GRE_HASH_SIZE_SHIFT 5 +#define IP6_GRE_HASH_SIZE (1 << IP6_GRE_HASH_SIZE_SHIFT) + +static unsigned int ip6gre_net_id __read_mostly; +struct ip6gre_net { + struct ip6_tnl __rcu *tunnels[4][IP6_GRE_HASH_SIZE]; + + struct ip6_tnl __rcu *collect_md_tun; + struct ip6_tnl __rcu *collect_md_tun_erspan; + struct net_device *fb_tunnel_dev; +}; + +static struct rtnl_link_ops ip6gre_link_ops __read_mostly; +static struct rtnl_link_ops ip6gre_tap_ops __read_mostly; +static struct rtnl_link_ops ip6erspan_tap_ops __read_mostly; +static int ip6gre_tunnel_init(struct net_device *dev); +static void ip6gre_tunnel_setup(struct net_device *dev); +static void ip6gre_tunnel_link(struct ip6gre_net *ign, struct ip6_tnl *t); +static void ip6gre_tnl_link_config(struct ip6_tnl *t, int set_mtu); +static void ip6erspan_tnl_link_config(struct ip6_tnl *t, int set_mtu); + +/* Tunnel hash table */ + +/* + 4 hash tables: + + 3: (remote,local) + 2: (remote,*) + 1: (*,local) + 0: (*,*) + + We require exact key match i.e. if a key is present in packet + it will match only tunnel with the same key; if it is not present, + it will match only keyless tunnel. + + All keysless packets, if not matched configured keyless tunnels + will match fallback tunnel. 
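The table index therefore encodes which of (remote, local) are wildcards, so receive lookup can fall back from the most specific table to the least specific one before trying the collect_md and fallback devices. A stand-alone sketch of the bucket selection, modelled on the HASH_KEY() macro and __ip6gre_bucket() that follow (toy user-space code under those assumptions; the address hash and the multicast special case are omitted):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define GRE6_HASH_SIZE_SHIFT	5
#define GRE6_HASH_SIZE		(1 << GRE6_HASH_SIZE_SHIFT)

/* Toy stand-in for HASH_KEY(): fold the GRE key into a 5-bit bucket. */
static unsigned int toy_hash_key(uint32_t key)
{
	return (key ^ (key >> 4)) & (GRE6_HASH_SIZE - 1);
}

/* Table index by specificity: 0 (*,*), 1 (*,local), 2 (remote,*),
 * 3 (remote,local), mirroring __ip6gre_bucket(). */
static unsigned int toy_bucket_prio(bool have_local, bool have_remote)
{
	unsigned int prio = 0;

	if (have_local)
		prio |= 1;
	if (have_remote)
		prio |= 2;
	return prio;
}

int main(void)
{
	printf("wildcard, keyless tunnel -> table %u, bucket %u\n",
	       toy_bucket_prio(false, false), toy_hash_key(0));
	printf("(remote,local) tunnel    -> table %u, bucket %u\n",
	       toy_bucket_prio(true, true), toy_hash_key(0x12345678));
	return 0;
}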
+ */ + +#define HASH_KEY(key) (((__force u32)key^((__force u32)key>>4))&(IP6_GRE_HASH_SIZE - 1)) +static u32 HASH_ADDR(const struct in6_addr *addr) +{ + u32 hash = ipv6_addr_hash(addr); + + return hash_32(hash, IP6_GRE_HASH_SIZE_SHIFT); +} + +#define tunnels_r_l tunnels[3] +#define tunnels_r tunnels[2] +#define tunnels_l tunnels[1] +#define tunnels_wc tunnels[0] + +/* Given src, dst and key, find appropriate for input tunnel. */ + +static struct ip6_tnl *ip6gre_tunnel_lookup(struct net_device *dev, + const struct in6_addr *remote, const struct in6_addr *local, + __be32 key, __be16 gre_proto) +{ + struct net *net = dev_net(dev); + int link = dev->ifindex; + unsigned int h0 = HASH_ADDR(remote); + unsigned int h1 = HASH_KEY(key); + struct ip6_tnl *t, *cand = NULL; + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + int dev_type = (gre_proto == htons(ETH_P_TEB) || + gre_proto == htons(ETH_P_ERSPAN) || + gre_proto == htons(ETH_P_ERSPAN2)) ? + ARPHRD_ETHER : ARPHRD_IP6GRE; + int score, cand_score = 4; + struct net_device *ndev; + + for_each_ip_tunnel_rcu(t, ign->tunnels_r_l[h0 ^ h1]) { + if (!ipv6_addr_equal(local, &t->parms.laddr) || + !ipv6_addr_equal(remote, &t->parms.raddr) || + key != t->parms.i_key || + !(t->dev->flags & IFF_UP)) + continue; + + if (t->dev->type != ARPHRD_IP6GRE && + t->dev->type != dev_type) + continue; + + score = 0; + if (t->parms.link != link) + score |= 1; + if (t->dev->type != dev_type) + score |= 2; + if (score == 0) + return t; + + if (score < cand_score) { + cand = t; + cand_score = score; + } + } + + for_each_ip_tunnel_rcu(t, ign->tunnels_r[h0 ^ h1]) { + if (!ipv6_addr_equal(remote, &t->parms.raddr) || + key != t->parms.i_key || + !(t->dev->flags & IFF_UP)) + continue; + + if (t->dev->type != ARPHRD_IP6GRE && + t->dev->type != dev_type) + continue; + + score = 0; + if (t->parms.link != link) + score |= 1; + if (t->dev->type != dev_type) + score |= 2; + if (score == 0) + return t; + + if (score < cand_score) { + cand = t; + cand_score = score; + } + } + + for_each_ip_tunnel_rcu(t, ign->tunnels_l[h1]) { + if ((!ipv6_addr_equal(local, &t->parms.laddr) && + (!ipv6_addr_equal(local, &t->parms.raddr) || + !ipv6_addr_is_multicast(local))) || + key != t->parms.i_key || + !(t->dev->flags & IFF_UP)) + continue; + + if (t->dev->type != ARPHRD_IP6GRE && + t->dev->type != dev_type) + continue; + + score = 0; + if (t->parms.link != link) + score |= 1; + if (t->dev->type != dev_type) + score |= 2; + if (score == 0) + return t; + + if (score < cand_score) { + cand = t; + cand_score = score; + } + } + + for_each_ip_tunnel_rcu(t, ign->tunnels_wc[h1]) { + if (t->parms.i_key != key || + !(t->dev->flags & IFF_UP)) + continue; + + if (t->dev->type != ARPHRD_IP6GRE && + t->dev->type != dev_type) + continue; + + score = 0; + if (t->parms.link != link) + score |= 1; + if (t->dev->type != dev_type) + score |= 2; + if (score == 0) + return t; + + if (score < cand_score) { + cand = t; + cand_score = score; + } + } + + if (cand) + return cand; + + if (gre_proto == htons(ETH_P_ERSPAN) || + gre_proto == htons(ETH_P_ERSPAN2)) + t = rcu_dereference(ign->collect_md_tun_erspan); + else + t = rcu_dereference(ign->collect_md_tun); + + if (t && t->dev->flags & IFF_UP) + return t; + + ndev = READ_ONCE(ign->fb_tunnel_dev); + if (ndev && ndev->flags & IFF_UP) + return netdev_priv(ndev); + + return NULL; +} + +static struct ip6_tnl __rcu **__ip6gre_bucket(struct ip6gre_net *ign, + const struct __ip6_tnl_parm *p) +{ + const struct in6_addr *remote = &p->raddr; + const struct in6_addr *local = 
&p->laddr; + unsigned int h = HASH_KEY(p->i_key); + int prio = 0; + + if (!ipv6_addr_any(local)) + prio |= 1; + if (!ipv6_addr_any(remote) && !ipv6_addr_is_multicast(remote)) { + prio |= 2; + h ^= HASH_ADDR(remote); + } + + return &ign->tunnels[prio][h]; +} + +static void ip6gre_tunnel_link_md(struct ip6gre_net *ign, struct ip6_tnl *t) +{ + if (t->parms.collect_md) + rcu_assign_pointer(ign->collect_md_tun, t); +} + +static void ip6erspan_tunnel_link_md(struct ip6gre_net *ign, struct ip6_tnl *t) +{ + if (t->parms.collect_md) + rcu_assign_pointer(ign->collect_md_tun_erspan, t); +} + +static void ip6gre_tunnel_unlink_md(struct ip6gre_net *ign, struct ip6_tnl *t) +{ + if (t->parms.collect_md) + rcu_assign_pointer(ign->collect_md_tun, NULL); +} + +static void ip6erspan_tunnel_unlink_md(struct ip6gre_net *ign, + struct ip6_tnl *t) +{ + if (t->parms.collect_md) + rcu_assign_pointer(ign->collect_md_tun_erspan, NULL); +} + +static inline struct ip6_tnl __rcu **ip6gre_bucket(struct ip6gre_net *ign, + const struct ip6_tnl *t) +{ + return __ip6gre_bucket(ign, &t->parms); +} + +static void ip6gre_tunnel_link(struct ip6gre_net *ign, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp = ip6gre_bucket(ign, t); + + rcu_assign_pointer(t->next, rtnl_dereference(*tp)); + rcu_assign_pointer(*tp, t); +} + +static void ip6gre_tunnel_unlink(struct ip6gre_net *ign, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp; + struct ip6_tnl *iter; + + for (tp = ip6gre_bucket(ign, t); + (iter = rtnl_dereference(*tp)) != NULL; + tp = &iter->next) { + if (t == iter) { + rcu_assign_pointer(*tp, t->next); + break; + } + } +} + +static struct ip6_tnl *ip6gre_tunnel_find(struct net *net, + const struct __ip6_tnl_parm *parms, + int type) +{ + const struct in6_addr *remote = &parms->raddr; + const struct in6_addr *local = &parms->laddr; + __be32 key = parms->i_key; + int link = parms->link; + struct ip6_tnl *t; + struct ip6_tnl __rcu **tp; + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + + for (tp = __ip6gre_bucket(ign, parms); + (t = rtnl_dereference(*tp)) != NULL; + tp = &t->next) + if (ipv6_addr_equal(local, &t->parms.laddr) && + ipv6_addr_equal(remote, &t->parms.raddr) && + key == t->parms.i_key && + link == t->parms.link && + type == t->dev->type) + break; + + return t; +} + +static struct ip6_tnl *ip6gre_tunnel_locate(struct net *net, + const struct __ip6_tnl_parm *parms, int create) +{ + struct ip6_tnl *t, *nt; + struct net_device *dev; + char name[IFNAMSIZ]; + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + + t = ip6gre_tunnel_find(net, parms, ARPHRD_IP6GRE); + if (t && create) + return NULL; + if (t || !create) + return t; + + if (parms->name[0]) { + if (!dev_valid_name(parms->name)) + return NULL; + strscpy(name, parms->name, IFNAMSIZ); + } else { + strcpy(name, "ip6gre%d"); + } + dev = alloc_netdev(sizeof(*t), name, NET_NAME_UNKNOWN, + ip6gre_tunnel_setup); + if (!dev) + return NULL; + + dev_net_set(dev, net); + + nt = netdev_priv(dev); + nt->parms = *parms; + dev->rtnl_link_ops = &ip6gre_link_ops; + + nt->dev = dev; + nt->net = dev_net(dev); + + if (register_netdevice(dev) < 0) + goto failed_free; + + ip6gre_tnl_link_config(nt, 1); + ip6gre_tunnel_link(ign, nt); + return nt; + +failed_free: + free_netdev(dev); + return NULL; +} + +static void ip6erspan_tunnel_uninit(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ip6gre_net *ign = net_generic(t->net, ip6gre_net_id); + + ip6erspan_tunnel_unlink_md(ign, t); + ip6gre_tunnel_unlink(ign, t); + dst_cache_reset(&t->dst_cache); 
+ netdev_put(dev, &t->dev_tracker); +} + +static void ip6gre_tunnel_uninit(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ip6gre_net *ign = net_generic(t->net, ip6gre_net_id); + + ip6gre_tunnel_unlink_md(ign, t); + ip6gre_tunnel_unlink(ign, t); + if (ign->fb_tunnel_dev == dev) + WRITE_ONCE(ign->fb_tunnel_dev, NULL); + dst_cache_reset(&t->dst_cache); + netdev_put(dev, &t->dev_tracker); +} + + +static int ip6gre_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + const struct ipv6hdr *ipv6h; + struct tnl_ptk_info tpi; + struct ip6_tnl *t; + + if (gre_parse_header(skb, &tpi, NULL, htons(ETH_P_IPV6), + offset) < 0) + return -EINVAL; + + ipv6h = (const struct ipv6hdr *)skb->data; + t = ip6gre_tunnel_lookup(skb->dev, &ipv6h->daddr, &ipv6h->saddr, + tpi.key, tpi.proto); + if (!t) + return -ENOENT; + + switch (type) { + case ICMPV6_DEST_UNREACH: + net_dbg_ratelimited("%s: Path to destination invalid or inactive!\n", + t->parms.name); + if (code != ICMPV6_PORT_UNREACH) + break; + return 0; + case ICMPV6_TIME_EXCEED: + if (code == ICMPV6_EXC_HOPLIMIT) { + net_dbg_ratelimited("%s: Too small hop limit or routing loop in tunnel!\n", + t->parms.name); + break; + } + return 0; + case ICMPV6_PARAMPROB: { + struct ipv6_tlv_tnl_enc_lim *tel; + __u32 teli; + + teli = 0; + if (code == ICMPV6_HDR_FIELD) + teli = ip6_tnl_parse_tlv_enc_lim(skb, skb->data); + + if (teli && teli == be32_to_cpu(info) - 2) { + tel = (struct ipv6_tlv_tnl_enc_lim *) &skb->data[teli]; + if (tel->encap_limit == 0) { + net_dbg_ratelimited("%s: Too small encapsulation limit or routing loop in tunnel!\n", + t->parms.name); + } + } else { + net_dbg_ratelimited("%s: Recipient unable to parse tunneled packet!\n", + t->parms.name); + } + return 0; + } + case ICMPV6_PKT_TOOBIG: + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + return 0; + case NDISC_REDIRECT: + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + return 0; + } + + if (time_before(jiffies, t->err_time + IP6TUNNEL_ERR_TIMEO)) + t->err_count++; + else + t->err_count = 1; + t->err_time = jiffies; + + return 0; +} + +static int ip6gre_rcv(struct sk_buff *skb, const struct tnl_ptk_info *tpi) +{ + const struct ipv6hdr *ipv6h; + struct ip6_tnl *tunnel; + + ipv6h = ipv6_hdr(skb); + tunnel = ip6gre_tunnel_lookup(skb->dev, + &ipv6h->saddr, &ipv6h->daddr, tpi->key, + tpi->proto); + if (tunnel) { + if (tunnel->parms.collect_md) { + struct metadata_dst *tun_dst; + __be64 tun_id; + __be16 flags; + + flags = tpi->flags; + tun_id = key32_to_tunnel_id(tpi->key); + + tun_dst = ipv6_tun_rx_dst(skb, flags, tun_id, 0); + if (!tun_dst) + return PACKET_REJECT; + + ip6_tnl_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error); + } else { + ip6_tnl_rcv(tunnel, skb, tpi, NULL, log_ecn_error); + } + + return PACKET_RCVD; + } + + return PACKET_REJECT; +} + +static int ip6erspan_rcv(struct sk_buff *skb, + struct tnl_ptk_info *tpi, + int gre_hdr_len) +{ + struct erspan_base_hdr *ershdr; + const struct ipv6hdr *ipv6h; + struct erspan_md2 *md2; + struct ip6_tnl *tunnel; + u8 ver; + + ipv6h = ipv6_hdr(skb); + ershdr = (struct erspan_base_hdr *)skb->data; + ver = ershdr->ver; + + tunnel = ip6gre_tunnel_lookup(skb->dev, + &ipv6h->saddr, &ipv6h->daddr, tpi->key, + tpi->proto); + if (tunnel) { + int len = erspan_hdr_len(ver); + + if (unlikely(!pskb_may_pull(skb, len))) + return PACKET_REJECT; + + if (__iptunnel_pull_header(skb, len, + htons(ETH_P_TEB), + false, 
false) < 0) + return PACKET_REJECT; + + if (tunnel->parms.collect_md) { + struct erspan_metadata *pkt_md, *md; + struct metadata_dst *tun_dst; + struct ip_tunnel_info *info; + unsigned char *gh; + __be64 tun_id; + __be16 flags; + + tpi->flags |= TUNNEL_KEY; + flags = tpi->flags; + tun_id = key32_to_tunnel_id(tpi->key); + + tun_dst = ipv6_tun_rx_dst(skb, flags, tun_id, + sizeof(*md)); + if (!tun_dst) + return PACKET_REJECT; + + /* skb can be uncloned in __iptunnel_pull_header, so + * old pkt_md is no longer valid and we need to reset + * it + */ + gh = skb_network_header(skb) + + skb_network_header_len(skb); + pkt_md = (struct erspan_metadata *)(gh + gre_hdr_len + + sizeof(*ershdr)); + info = &tun_dst->u.tun_info; + md = ip_tunnel_info_opts(info); + md->version = ver; + md2 = &md->u.md2; + memcpy(md2, pkt_md, ver == 1 ? ERSPAN_V1_MDSIZE : + ERSPAN_V2_MDSIZE); + info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + info->options_len = sizeof(*md); + + ip6_tnl_rcv(tunnel, skb, tpi, tun_dst, log_ecn_error); + + } else { + ip6_tnl_rcv(tunnel, skb, tpi, NULL, log_ecn_error); + } + + return PACKET_RCVD; + } + + return PACKET_REJECT; +} + +static int gre_rcv(struct sk_buff *skb) +{ + struct tnl_ptk_info tpi; + bool csum_err = false; + int hdr_len; + + hdr_len = gre_parse_header(skb, &tpi, &csum_err, htons(ETH_P_IPV6), 0); + if (hdr_len < 0) + goto drop; + + if (iptunnel_pull_header(skb, hdr_len, tpi.proto, false)) + goto drop; + + if (unlikely(tpi.proto == htons(ETH_P_ERSPAN) || + tpi.proto == htons(ETH_P_ERSPAN2))) { + if (ip6erspan_rcv(skb, &tpi, hdr_len) == PACKET_RCVD) + return 0; + goto out; + } + + if (ip6gre_rcv(skb, &tpi) == PACKET_RCVD) + return 0; + +out: + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); +drop: + kfree_skb(skb); + return 0; +} + +static int gre_handle_offloads(struct sk_buff *skb, bool csum) +{ + return iptunnel_handle_offloads(skb, + csum ? 
SKB_GSO_GRE_CSUM : SKB_GSO_GRE); +} + +static void prepare_ip6gre_xmit_ipv4(struct sk_buff *skb, + struct net_device *dev, + struct flowi6 *fl6, __u8 *dsfield, + int *encap_limit) +{ + const struct iphdr *iph = ip_hdr(skb); + struct ip6_tnl *t = netdev_priv(dev); + + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + *encap_limit = t->parms.encap_limit; + + memcpy(fl6, &t->fl.u.ip6, sizeof(*fl6)); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) + *dsfield = ipv4_get_dsfield(iph); + else + *dsfield = ip6_tclass(t->parms.flowinfo); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) + fl6->flowi6_mark = skb->mark; + else + fl6->flowi6_mark = t->parms.fwmark; + + fl6->flowi6_uid = sock_net_uid(dev_net(dev), NULL); +} + +static int prepare_ip6gre_xmit_ipv6(struct sk_buff *skb, + struct net_device *dev, + struct flowi6 *fl6, __u8 *dsfield, + int *encap_limit) +{ + struct ipv6hdr *ipv6h; + struct ip6_tnl *t = netdev_priv(dev); + __u16 offset; + + offset = ip6_tnl_parse_tlv_enc_lim(skb, skb_network_header(skb)); + /* ip6_tnl_parse_tlv_enc_lim() might have reallocated skb->head */ + ipv6h = ipv6_hdr(skb); + + if (offset > 0) { + struct ipv6_tlv_tnl_enc_lim *tel; + + tel = (struct ipv6_tlv_tnl_enc_lim *)&skb_network_header(skb)[offset]; + if (tel->encap_limit == 0) { + icmpv6_ndo_send(skb, ICMPV6_PARAMPROB, + ICMPV6_HDR_FIELD, offset + 2); + return -1; + } + *encap_limit = tel->encap_limit - 1; + } else if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) { + *encap_limit = t->parms.encap_limit; + } + + memcpy(fl6, &t->fl.u.ip6, sizeof(*fl6)); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) + *dsfield = ipv6_get_dsfield(ipv6h); + else + *dsfield = ip6_tclass(t->parms.flowinfo); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FLOWLABEL) + fl6->flowlabel |= ip6_flowlabel(ipv6h); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) + fl6->flowi6_mark = skb->mark; + else + fl6->flowi6_mark = t->parms.fwmark; + + fl6->flowi6_uid = sock_net_uid(dev_net(dev), NULL); + + return 0; +} + +static int prepare_ip6gre_xmit_other(struct sk_buff *skb, + struct net_device *dev, + struct flowi6 *fl6, __u8 *dsfield, + int *encap_limit) +{ + struct ip6_tnl *t = netdev_priv(dev); + + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + *encap_limit = t->parms.encap_limit; + + memcpy(fl6, &t->fl.u.ip6, sizeof(*fl6)); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) + *dsfield = 0; + else + *dsfield = ip6_tclass(t->parms.flowinfo); + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) + fl6->flowi6_mark = skb->mark; + else + fl6->flowi6_mark = t->parms.fwmark; + + fl6->flowi6_uid = sock_net_uid(dev_net(dev), NULL); + + return 0; +} + +static struct ip_tunnel_info *skb_tunnel_info_txcheck(struct sk_buff *skb) +{ + struct ip_tunnel_info *tun_info; + + tun_info = skb_tunnel_info(skb); + if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX))) + return ERR_PTR(-EINVAL); + + return tun_info; +} + +static netdev_tx_t __gre6_xmit(struct sk_buff *skb, + struct net_device *dev, __u8 dsfield, + struct flowi6 *fl6, int encap_limit, + __u32 *pmtu, __be16 proto) +{ + struct ip6_tnl *tunnel = netdev_priv(dev); + __be16 protocol; + __be16 flags; + + if (dev->type == ARPHRD_ETHER) + IPCB(skb)->flags = 0; + + if (dev->header_ops && dev->type == ARPHRD_IP6GRE) + fl6->daddr = ((struct ipv6hdr *)skb->data)->daddr; + else + fl6->daddr = tunnel->parms.raddr; + + /* Push GRE header. */ + protocol = (dev->type == ARPHRD_ETHER) ? 
htons(ETH_P_TEB) : proto; + + if (tunnel->parms.collect_md) { + struct ip_tunnel_info *tun_info; + const struct ip_tunnel_key *key; + int tun_hlen; + + tun_info = skb_tunnel_info_txcheck(skb); + if (IS_ERR(tun_info) || + unlikely(ip_tunnel_info_af(tun_info) != AF_INET6)) + return -EINVAL; + + key = &tun_info->key; + memset(fl6, 0, sizeof(*fl6)); + fl6->flowi6_proto = IPPROTO_GRE; + fl6->daddr = key->u.ipv6.dst; + fl6->flowlabel = key->label; + fl6->flowi6_uid = sock_net_uid(dev_net(dev), NULL); + fl6->fl6_gre_key = tunnel_id_to_key32(key->tun_id); + + dsfield = key->tos; + flags = key->tun_flags & + (TUNNEL_CSUM | TUNNEL_KEY | TUNNEL_SEQ); + tun_hlen = gre_calc_hlen(flags); + + if (skb_cow_head(skb, dev->needed_headroom ?: tun_hlen + tunnel->encap_hlen)) + return -ENOMEM; + + gre_build_header(skb, tun_hlen, + flags, protocol, + tunnel_id_to_key32(tun_info->key.tun_id), + (flags & TUNNEL_SEQ) ? htonl(atomic_fetch_inc(&tunnel->o_seqno)) + : 0); + + } else { + if (skb_cow_head(skb, dev->needed_headroom ?: tunnel->hlen)) + return -ENOMEM; + + flags = tunnel->parms.o_flags; + + gre_build_header(skb, tunnel->tun_hlen, flags, + protocol, tunnel->parms.o_key, + (flags & TUNNEL_SEQ) ? htonl(atomic_fetch_inc(&tunnel->o_seqno)) + : 0); + } + + return ip6_tnl_xmit(skb, dev, dsfield, fl6, encap_limit, pmtu, + NEXTHDR_GRE); +} + +static inline int ip6gre_xmit_ipv4(struct sk_buff *skb, struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + int encap_limit = -1; + struct flowi6 fl6; + __u8 dsfield = 0; + __u32 mtu; + int err; + + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + + if (!t->parms.collect_md) + prepare_ip6gre_xmit_ipv4(skb, dev, &fl6, + &dsfield, &encap_limit); + + err = gre_handle_offloads(skb, !!(t->parms.o_flags & TUNNEL_CSUM)); + if (err) + return -1; + + err = __gre6_xmit(skb, dev, dsfield, &fl6, encap_limit, &mtu, + skb->protocol); + if (err != 0) { + /* XXX: send ICMP error even if DF is not set. */ + if (err == -EMSGSIZE) + icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + return -1; + } + + return 0; +} + +static inline int ip6gre_xmit_ipv6(struct sk_buff *skb, struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + int encap_limit = -1; + struct flowi6 fl6; + __u8 dsfield = 0; + __u32 mtu; + int err; + + if (ipv6_addr_equal(&t->parms.raddr, &ipv6h->saddr)) + return -1; + + if (!t->parms.collect_md && + prepare_ip6gre_xmit_ipv6(skb, dev, &fl6, &dsfield, &encap_limit)) + return -1; + + if (gre_handle_offloads(skb, !!(t->parms.o_flags & TUNNEL_CSUM))) + return -1; + + err = __gre6_xmit(skb, dev, dsfield, &fl6, encap_limit, + &mtu, skb->protocol); + if (err != 0) { + if (err == -EMSGSIZE) + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + return -1; + } + + return 0; +} + +/** + * ip6gre_tnl_addr_conflict - compare packet addresses to tunnel's own + * @t: the outgoing tunnel device + * @hdr: IPv6 header from the incoming packet + * + * Description: + * Avoid trivial tunneling loop by checking that tunnel exit-point + * doesn't match source of incoming packet. 
+ * + * Return: + * 1 if conflict, + * 0 else + **/ + +static inline bool ip6gre_tnl_addr_conflict(const struct ip6_tnl *t, + const struct ipv6hdr *hdr) +{ + return ipv6_addr_equal(&t->parms.raddr, &hdr->saddr); +} + +static int ip6gre_xmit_other(struct sk_buff *skb, struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + int encap_limit = -1; + struct flowi6 fl6; + __u8 dsfield = 0; + __u32 mtu; + int err; + + if (!t->parms.collect_md && + prepare_ip6gre_xmit_other(skb, dev, &fl6, &dsfield, &encap_limit)) + return -1; + + err = gre_handle_offloads(skb, !!(t->parms.o_flags & TUNNEL_CSUM)); + if (err) + return err; + err = __gre6_xmit(skb, dev, dsfield, &fl6, encap_limit, &mtu, skb->protocol); + + return err; +} + +static netdev_tx_t ip6gre_tunnel_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net_device_stats *stats = &t->dev->stats; + __be16 payload_protocol; + int ret; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + if (!ip6_tnl_xmit_ctl(t, &t->parms.laddr, &t->parms.raddr)) + goto tx_err; + + payload_protocol = skb_protocol(skb, true); + switch (payload_protocol) { + case htons(ETH_P_IP): + ret = ip6gre_xmit_ipv4(skb, dev); + break; + case htons(ETH_P_IPV6): + ret = ip6gre_xmit_ipv6(skb, dev); + break; + default: + ret = ip6gre_xmit_other(skb, dev); + break; + } + + if (ret < 0) + goto tx_err; + + return NETDEV_TX_OK; + +tx_err: + if (!t->parms.collect_md || !IS_ERR(skb_tunnel_info_txcheck(skb))) + stats->tx_errors++; + stats->tx_dropped++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static netdev_tx_t ip6erspan_tunnel_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel_info *tun_info = NULL; + struct ip6_tnl *t = netdev_priv(dev); + struct dst_entry *dst = skb_dst(skb); + struct net_device_stats *stats; + bool truncate = false; + int encap_limit = -1; + __u8 dsfield = false; + struct flowi6 fl6; + int err = -EINVAL; + __be16 proto; + __u32 mtu; + int nhoff; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + if (!ip6_tnl_xmit_ctl(t, &t->parms.laddr, &t->parms.raddr)) + goto tx_err; + + if (gre_handle_offloads(skb, false)) + goto tx_err; + + if (skb->len > dev->mtu + dev->hard_header_len) { + if (pskb_trim(skb, dev->mtu + dev->hard_header_len)) + goto tx_err; + truncate = true; + } + + nhoff = skb_network_offset(skb); + if (skb->protocol == htons(ETH_P_IP) && + (ntohs(ip_hdr(skb)->tot_len) > skb->len - nhoff)) + truncate = true; + + if (skb->protocol == htons(ETH_P_IPV6)) { + int thoff; + + if (skb_transport_header_was_set(skb)) + thoff = skb_transport_offset(skb); + else + thoff = nhoff + sizeof(struct ipv6hdr); + if (ntohs(ipv6_hdr(skb)->payload_len) > skb->len - thoff) + truncate = true; + } + + if (skb_cow_head(skb, dev->needed_headroom ?: t->hlen)) + goto tx_err; + + t->parms.o_flags &= ~TUNNEL_KEY; + IPCB(skb)->flags = 0; + + /* For collect_md mode, derive fl6 from the tunnel key, + * for native mode, call prepare_ip6gre_xmit_{ipv4,ipv6}. 
+ */ + if (t->parms.collect_md) { + const struct ip_tunnel_key *key; + struct erspan_metadata *md; + __be32 tun_id; + + tun_info = skb_tunnel_info_txcheck(skb); + if (IS_ERR(tun_info) || + unlikely(ip_tunnel_info_af(tun_info) != AF_INET6)) + goto tx_err; + + key = &tun_info->key; + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_GRE; + fl6.daddr = key->u.ipv6.dst; + fl6.flowlabel = key->label; + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); + fl6.fl6_gre_key = tunnel_id_to_key32(key->tun_id); + + dsfield = key->tos; + if (!(tun_info->key.tun_flags & TUNNEL_ERSPAN_OPT)) + goto tx_err; + if (tun_info->options_len < sizeof(*md)) + goto tx_err; + md = ip_tunnel_info_opts(tun_info); + + tun_id = tunnel_id_to_key32(key->tun_id); + if (md->version == 1) { + erspan_build_header(skb, + ntohl(tun_id), + ntohl(md->u.index), truncate, + false); + proto = htons(ETH_P_ERSPAN); + } else if (md->version == 2) { + erspan_build_header_v2(skb, + ntohl(tun_id), + md->u.md2.dir, + get_hwid(&md->u.md2), + truncate, false); + proto = htons(ETH_P_ERSPAN2); + } else { + goto tx_err; + } + } else { + switch (skb->protocol) { + case htons(ETH_P_IP): + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + prepare_ip6gre_xmit_ipv4(skb, dev, &fl6, + &dsfield, &encap_limit); + break; + case htons(ETH_P_IPV6): + if (ipv6_addr_equal(&t->parms.raddr, &ipv6_hdr(skb)->saddr)) + goto tx_err; + if (prepare_ip6gre_xmit_ipv6(skb, dev, &fl6, + &dsfield, &encap_limit)) + goto tx_err; + break; + default: + memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6)); + break; + } + + if (t->parms.erspan_ver == 1) { + erspan_build_header(skb, ntohl(t->parms.o_key), + t->parms.index, + truncate, false); + proto = htons(ETH_P_ERSPAN); + } else if (t->parms.erspan_ver == 2) { + erspan_build_header_v2(skb, ntohl(t->parms.o_key), + t->parms.dir, + t->parms.hwid, + truncate, false); + proto = htons(ETH_P_ERSPAN2); + } else { + goto tx_err; + } + + fl6.daddr = t->parms.raddr; + } + + /* Push GRE header. */ + gre_build_header(skb, 8, TUNNEL_SEQ, proto, 0, htonl(atomic_fetch_inc(&t->o_seqno))); + + /* TooBig packet may have updated dst->dev's mtu */ + if (!t->parms.collect_md && dst && dst_mtu(dst) > dst->dev->mtu) + dst->ops->update_pmtu(dst, NULL, skb, dst->dev->mtu, false); + + err = ip6_tnl_xmit(skb, dev, dsfield, &fl6, encap_limit, &mtu, + NEXTHDR_GRE); + if (err != 0) { + /* XXX: send ICMP error even if DF is not set. 
*/ + if (err == -EMSGSIZE) { + if (skb->protocol == htons(ETH_P_IP)) + icmp_ndo_send(skb, ICMP_DEST_UNREACH, + ICMP_FRAG_NEEDED, htonl(mtu)); + else + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + } + + goto tx_err; + } + return NETDEV_TX_OK; + +tx_err: + stats = &t->dev->stats; + if (!IS_ERR(tun_info)) + stats->tx_errors++; + stats->tx_dropped++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static void ip6gre_tnl_link_config_common(struct ip6_tnl *t) +{ + struct net_device *dev = t->dev; + struct __ip6_tnl_parm *p = &t->parms; + struct flowi6 *fl6 = &t->fl.u.ip6; + + if (dev->type != ARPHRD_ETHER) { + __dev_addr_set(dev, &p->laddr, sizeof(struct in6_addr)); + memcpy(dev->broadcast, &p->raddr, sizeof(struct in6_addr)); + } + + /* Set up flowi template */ + fl6->saddr = p->laddr; + fl6->daddr = p->raddr; + fl6->flowi6_oif = p->link; + fl6->flowlabel = 0; + fl6->flowi6_proto = IPPROTO_GRE; + fl6->fl6_gre_key = t->parms.o_key; + + if (!(p->flags&IP6_TNL_F_USE_ORIG_TCLASS)) + fl6->flowlabel |= IPV6_TCLASS_MASK & p->flowinfo; + if (!(p->flags&IP6_TNL_F_USE_ORIG_FLOWLABEL)) + fl6->flowlabel |= IPV6_FLOWLABEL_MASK & p->flowinfo; + + p->flags &= ~(IP6_TNL_F_CAP_XMIT|IP6_TNL_F_CAP_RCV|IP6_TNL_F_CAP_PER_PACKET); + p->flags |= ip6_tnl_get_cap(t, &p->laddr, &p->raddr); + + if (p->flags&IP6_TNL_F_CAP_XMIT && + p->flags&IP6_TNL_F_CAP_RCV && dev->type != ARPHRD_ETHER) + dev->flags |= IFF_POINTOPOINT; + else + dev->flags &= ~IFF_POINTOPOINT; +} + +static void ip6gre_tnl_link_config_route(struct ip6_tnl *t, int set_mtu, + int t_hlen) +{ + const struct __ip6_tnl_parm *p = &t->parms; + struct net_device *dev = t->dev; + + if (p->flags & IP6_TNL_F_CAP_XMIT) { + int strict = (ipv6_addr_type(&p->raddr) & + (IPV6_ADDR_MULTICAST|IPV6_ADDR_LINKLOCAL)); + + struct rt6_info *rt = rt6_lookup(t->net, + &p->raddr, &p->laddr, + p->link, NULL, strict); + + if (!rt) + return; + + if (rt->dst.dev) { + unsigned short dst_len = rt->dst.dev->hard_header_len + + t_hlen; + + if (t->dev->header_ops) + dev->hard_header_len = dst_len; + else + dev->needed_headroom = dst_len; + + if (set_mtu) { + int mtu = rt->dst.dev->mtu - t_hlen; + + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + mtu -= 8; + if (dev->type == ARPHRD_ETHER) + mtu -= ETH_HLEN; + + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + WRITE_ONCE(dev->mtu, mtu); + } + } + ip6_rt_put(rt); + } +} + +static int ip6gre_calc_hlen(struct ip6_tnl *tunnel) +{ + int t_hlen; + + tunnel->tun_hlen = gre_calc_hlen(tunnel->parms.o_flags); + tunnel->hlen = tunnel->tun_hlen + tunnel->encap_hlen; + + t_hlen = tunnel->hlen + sizeof(struct ipv6hdr); + + if (tunnel->dev->header_ops) + tunnel->dev->hard_header_len = LL_MAX_HEADER + t_hlen; + else + tunnel->dev->needed_headroom = LL_MAX_HEADER + t_hlen; + + return t_hlen; +} + +static void ip6gre_tnl_link_config(struct ip6_tnl *t, int set_mtu) +{ + ip6gre_tnl_link_config_common(t); + ip6gre_tnl_link_config_route(t, set_mtu, ip6gre_calc_hlen(t)); +} + +static void ip6gre_tnl_copy_tnl_parm(struct ip6_tnl *t, + const struct __ip6_tnl_parm *p) +{ + t->parms.laddr = p->laddr; + t->parms.raddr = p->raddr; + t->parms.flags = p->flags; + t->parms.hop_limit = p->hop_limit; + t->parms.encap_limit = p->encap_limit; + t->parms.flowinfo = p->flowinfo; + t->parms.link = p->link; + t->parms.proto = p->proto; + t->parms.i_key = p->i_key; + t->parms.o_key = p->o_key; + t->parms.i_flags = p->i_flags; + t->parms.o_flags = p->o_flags; + t->parms.fwmark = p->fwmark; + t->parms.erspan_ver = p->erspan_ver; + t->parms.index = p->index; + t->parms.dir = 
p->dir; + t->parms.hwid = p->hwid; + dst_cache_reset(&t->dst_cache); +} + +static int ip6gre_tnl_change(struct ip6_tnl *t, const struct __ip6_tnl_parm *p, + int set_mtu) +{ + ip6gre_tnl_copy_tnl_parm(t, p); + ip6gre_tnl_link_config(t, set_mtu); + return 0; +} + +static void ip6gre_tnl_parm_from_user(struct __ip6_tnl_parm *p, + const struct ip6_tnl_parm2 *u) +{ + p->laddr = u->laddr; + p->raddr = u->raddr; + p->flags = u->flags; + p->hop_limit = u->hop_limit; + p->encap_limit = u->encap_limit; + p->flowinfo = u->flowinfo; + p->link = u->link; + p->i_key = u->i_key; + p->o_key = u->o_key; + p->i_flags = gre_flags_to_tnl_flags(u->i_flags); + p->o_flags = gre_flags_to_tnl_flags(u->o_flags); + memcpy(p->name, u->name, sizeof(u->name)); +} + +static void ip6gre_tnl_parm_to_user(struct ip6_tnl_parm2 *u, + const struct __ip6_tnl_parm *p) +{ + u->proto = IPPROTO_GRE; + u->laddr = p->laddr; + u->raddr = p->raddr; + u->flags = p->flags; + u->hop_limit = p->hop_limit; + u->encap_limit = p->encap_limit; + u->flowinfo = p->flowinfo; + u->link = p->link; + u->i_key = p->i_key; + u->o_key = p->o_key; + u->i_flags = gre_tnl_flags_to_gre_flags(p->i_flags); + u->o_flags = gre_tnl_flags_to_gre_flags(p->o_flags); + memcpy(u->name, p->name, sizeof(u->name)); +} + +static int ip6gre_tunnel_siocdevprivate(struct net_device *dev, + struct ifreq *ifr, void __user *data, + int cmd) +{ + int err = 0; + struct ip6_tnl_parm2 p; + struct __ip6_tnl_parm p1; + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = t->net; + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + + memset(&p1, 0, sizeof(p1)); + + switch (cmd) { + case SIOCGETTUNNEL: + if (dev == ign->fb_tunnel_dev) { + if (copy_from_user(&p, data, sizeof(p))) { + err = -EFAULT; + break; + } + ip6gre_tnl_parm_from_user(&p1, &p); + t = ip6gre_tunnel_locate(net, &p1, 0); + if (!t) + t = netdev_priv(dev); + } + memset(&p, 0, sizeof(p)); + ip6gre_tnl_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + break; + + case SIOCADDTUNNEL: + case SIOCCHGTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto done; + + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + goto done; + + err = -EINVAL; + if ((p.i_flags|p.o_flags)&(GRE_VERSION|GRE_ROUTING)) + goto done; + + if (!(p.i_flags&GRE_KEY)) + p.i_key = 0; + if (!(p.o_flags&GRE_KEY)) + p.o_key = 0; + + ip6gre_tnl_parm_from_user(&p1, &p); + t = ip6gre_tunnel_locate(net, &p1, cmd == SIOCADDTUNNEL); + + if (dev != ign->fb_tunnel_dev && cmd == SIOCCHGTUNNEL) { + if (t) { + if (t->dev != dev) { + err = -EEXIST; + break; + } + } else { + t = netdev_priv(dev); + + ip6gre_tunnel_unlink(ign, t); + synchronize_net(); + ip6gre_tnl_change(t, &p1, 1); + ip6gre_tunnel_link(ign, t); + netdev_state_change(dev); + } + } + + if (t) { + err = 0; + + memset(&p, 0, sizeof(p)); + ip6gre_tnl_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + } else + err = (cmd == SIOCADDTUNNEL ? 
-ENOBUFS : -ENOENT); + break; + + case SIOCDELTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + goto done; + + if (dev == ign->fb_tunnel_dev) { + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + goto done; + err = -ENOENT; + ip6gre_tnl_parm_from_user(&p1, &p); + t = ip6gre_tunnel_locate(net, &p1, 0); + if (!t) + goto done; + err = -EPERM; + if (t == netdev_priv(ign->fb_tunnel_dev)) + goto done; + dev = t->dev; + } + unregister_netdevice(dev); + err = 0; + break; + + default: + err = -EINVAL; + } + +done: + return err; +} + +static int ip6gre_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ipv6hdr *ipv6h; + __be16 *p; + + ipv6h = skb_push(skb, t->hlen + sizeof(*ipv6h)); + ip6_flow_hdr(ipv6h, 0, ip6_make_flowlabel(dev_net(dev), skb, + t->fl.u.ip6.flowlabel, + true, &t->fl.u.ip6)); + ipv6h->hop_limit = t->parms.hop_limit; + ipv6h->nexthdr = NEXTHDR_GRE; + ipv6h->saddr = t->parms.laddr; + ipv6h->daddr = t->parms.raddr; + + p = (__be16 *)(ipv6h + 1); + p[0] = t->parms.o_flags; + p[1] = htons(type); + + /* + * Set the source hardware address. + */ + + if (saddr) + memcpy(&ipv6h->saddr, saddr, sizeof(struct in6_addr)); + if (daddr) + memcpy(&ipv6h->daddr, daddr, sizeof(struct in6_addr)); + if (!ipv6_addr_any(&ipv6h->daddr)) + return t->hlen; + + return -t->hlen; +} + +static const struct header_ops ip6gre_header_ops = { + .create = ip6gre_header, +}; + +static const struct net_device_ops ip6gre_netdev_ops = { + .ndo_init = ip6gre_tunnel_init, + .ndo_uninit = ip6gre_tunnel_uninit, + .ndo_start_xmit = ip6gre_tunnel_xmit, + .ndo_siocdevprivate = ip6gre_tunnel_siocdevprivate, + .ndo_change_mtu = ip6_tnl_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip6_tnl_get_iflink, +}; + +static void ip6gre_dev_free(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + + gro_cells_destroy(&t->gro_cells); + dst_cache_destroy(&t->dst_cache); + free_percpu(dev->tstats); +} + +static void ip6gre_tunnel_setup(struct net_device *dev) +{ + dev->netdev_ops = &ip6gre_netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ip6gre_dev_free; + + dev->type = ARPHRD_IP6GRE; + + dev->flags |= IFF_NOARP; + dev->addr_len = sizeof(struct in6_addr); + netif_keep_dst(dev); + /* This perm addr will be used as interface identifier by IPv6 */ + dev->addr_assign_type = NET_ADDR_RANDOM; + eth_random_addr(dev->perm_addr); +} + +#define GRE6_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_HIGHDMA | \ + NETIF_F_HW_CSUM) + +static void ip6gre_tnl_init_features(struct net_device *dev) +{ + struct ip6_tnl *nt = netdev_priv(dev); + __be16 flags; + + dev->features |= GRE6_FEATURES | NETIF_F_LLTX; + dev->hw_features |= GRE6_FEATURES; + + flags = nt->parms.o_flags; + + /* TCP offload with GRE SEQ is not supported, nor can we support 2 + * levels of outer headers requiring an update. 
+ */ + if (flags & TUNNEL_SEQ) + return; + if (flags & TUNNEL_CSUM && nt->encap.type != TUNNEL_ENCAP_NONE) + return; + + dev->features |= NETIF_F_GSO_SOFTWARE; + dev->hw_features |= NETIF_F_GSO_SOFTWARE; +} + +static int ip6gre_tunnel_init_common(struct net_device *dev) +{ + struct ip6_tnl *tunnel; + int ret; + int t_hlen; + + tunnel = netdev_priv(dev); + + tunnel->dev = dev; + tunnel->net = dev_net(dev); + strcpy(tunnel->parms.name, dev->name); + + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + ret = dst_cache_init(&tunnel->dst_cache, GFP_KERNEL); + if (ret) + goto cleanup_alloc_pcpu_stats; + + ret = gro_cells_init(&tunnel->gro_cells, dev); + if (ret) + goto cleanup_dst_cache_init; + + t_hlen = ip6gre_calc_hlen(tunnel); + dev->mtu = ETH_DATA_LEN - t_hlen; + if (dev->type == ARPHRD_ETHER) + dev->mtu -= ETH_HLEN; + if (!(tunnel->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + dev->mtu -= 8; + + if (tunnel->parms.collect_md) { + netif_keep_dst(dev); + } + ip6gre_tnl_init_features(dev); + + netdev_hold(dev, &tunnel->dev_tracker, GFP_KERNEL); + return 0; + +cleanup_dst_cache_init: + dst_cache_destroy(&tunnel->dst_cache); +cleanup_alloc_pcpu_stats: + free_percpu(dev->tstats); + dev->tstats = NULL; + return ret; +} + +static int ip6gre_tunnel_init(struct net_device *dev) +{ + struct ip6_tnl *tunnel; + int ret; + + ret = ip6gre_tunnel_init_common(dev); + if (ret) + return ret; + + tunnel = netdev_priv(dev); + + if (tunnel->parms.collect_md) + return 0; + + __dev_addr_set(dev, &tunnel->parms.laddr, sizeof(struct in6_addr)); + memcpy(dev->broadcast, &tunnel->parms.raddr, sizeof(struct in6_addr)); + + if (ipv6_addr_any(&tunnel->parms.raddr)) + dev->header_ops = &ip6gre_header_ops; + + return 0; +} + +static void ip6gre_fb_tunnel_init(struct net_device *dev) +{ + struct ip6_tnl *tunnel = netdev_priv(dev); + + tunnel->dev = dev; + tunnel->net = dev_net(dev); + strcpy(tunnel->parms.name, dev->name); + + tunnel->hlen = sizeof(struct ipv6hdr) + 4; +} + +static struct inet6_protocol ip6gre_protocol __read_mostly = { + .handler = gre_rcv, + .err_handler = ip6gre_err, + .flags = INET6_PROTO_FINAL, +}; + +static void ip6gre_destroy_tunnels(struct net *net, struct list_head *head) +{ + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + struct net_device *dev, *aux; + int prio; + + for_each_netdev_safe(net, dev, aux) + if (dev->rtnl_link_ops == &ip6gre_link_ops || + dev->rtnl_link_ops == &ip6gre_tap_ops || + dev->rtnl_link_ops == &ip6erspan_tap_ops) + unregister_netdevice_queue(dev, head); + + for (prio = 0; prio < 4; prio++) { + int h; + for (h = 0; h < IP6_GRE_HASH_SIZE; h++) { + struct ip6_tnl *t; + + t = rtnl_dereference(ign->tunnels[prio][h]); + + while (t) { + /* If dev is in the same netns, it has already + * been added to the list by the previous loop. + */ + if (!net_eq(dev_net(t->dev), net)) + unregister_netdevice_queue(t->dev, + head); + t = rtnl_dereference(t->next); + } + } + } +} + +static int __net_init ip6gre_init_net(struct net *net) +{ + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + struct net_device *ndev; + int err; + + if (!net_has_fallback_tunnels(net)) + return 0; + ndev = alloc_netdev(sizeof(struct ip6_tnl), "ip6gre0", + NET_NAME_UNKNOWN, ip6gre_tunnel_setup); + if (!ndev) { + err = -ENOMEM; + goto err_alloc_dev; + } + ign->fb_tunnel_dev = ndev; + dev_net_set(ign->fb_tunnel_dev, net); + /* FB netdevice is special: we have one, and only one per netns. + * Allowing to move it to another netns is clearly unsafe. 
+ */ + ign->fb_tunnel_dev->features |= NETIF_F_NETNS_LOCAL; + + + ip6gre_fb_tunnel_init(ign->fb_tunnel_dev); + ign->fb_tunnel_dev->rtnl_link_ops = &ip6gre_link_ops; + + err = register_netdev(ign->fb_tunnel_dev); + if (err) + goto err_reg_dev; + + rcu_assign_pointer(ign->tunnels_wc[0], + netdev_priv(ign->fb_tunnel_dev)); + return 0; + +err_reg_dev: + free_netdev(ndev); +err_alloc_dev: + return err; +} + +static void __net_exit ip6gre_exit_batch_net(struct list_head *net_list) +{ + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ip6gre_destroy_tunnels(net, &list); + unregister_netdevice_many(&list); + rtnl_unlock(); +} + +static struct pernet_operations ip6gre_net_ops = { + .init = ip6gre_init_net, + .exit_batch = ip6gre_exit_batch_net, + .id = &ip6gre_net_id, + .size = sizeof(struct ip6gre_net), +}; + +static int ip6gre_tunnel_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + __be16 flags; + + if (!data) + return 0; + + flags = 0; + if (data[IFLA_GRE_IFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_IFLAGS]); + if (data[IFLA_GRE_OFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_OFLAGS]); + if (flags & (GRE_VERSION|GRE_ROUTING)) + return -EINVAL; + + return 0; +} + +static int ip6gre_tap_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct in6_addr daddr; + + if (tb[IFLA_ADDRESS]) { + if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN) + return -EINVAL; + if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS]))) + return -EADDRNOTAVAIL; + } + + if (!data) + goto out; + + if (data[IFLA_GRE_REMOTE]) { + daddr = nla_get_in6_addr(data[IFLA_GRE_REMOTE]); + if (ipv6_addr_any(&daddr)) + return -EINVAL; + } + +out: + return ip6gre_tunnel_validate(tb, data, extack); +} + +static int ip6erspan_tap_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + __be16 flags = 0; + int ret, ver = 0; + + if (!data) + return 0; + + ret = ip6gre_tap_validate(tb, data, extack); + if (ret) + return ret; + + /* ERSPAN should only have GRE sequence and key flag */ + if (data[IFLA_GRE_OFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_OFLAGS]); + if (data[IFLA_GRE_IFLAGS]) + flags |= nla_get_be16(data[IFLA_GRE_IFLAGS]); + if (!data[IFLA_GRE_COLLECT_METADATA] && + flags != (GRE_SEQ | GRE_KEY)) + return -EINVAL; + + /* ERSPAN Session ID only has 10-bit. Since we reuse + * 32-bit key field as ID, check it's range. 
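Put differently, the GRE key doubles as the ERSPAN session ID, so only its low 10 bits may be set, and a version 1 index must fit in 20 bits. A stand-alone sketch of those range checks, with the masks spelled out as 0x3ff and 0xfffff to mirror ID_MASK and INDEX_MASK (helper names are illustrative only):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define ERSPAN_ID_MASK    0x3ffu    /* 10-bit session ID */
#define ERSPAN_INDEX_MASK 0xfffffu  /* 20-bit version 1 index */

/* Mirror of the netlink range checks: a key or index is only valid if it
 * fits the ERSPAN field it is reused for. */
static bool erspan_session_id_ok(uint32_t key)
{
	return (key & ~ERSPAN_ID_MASK) == 0;
}

static bool erspan_v1_index_ok(uint32_t index)
{
	return (index & ~ERSPAN_INDEX_MASK) == 0;
}

int main(void)
{
	printf("key 1023 ok: %d, key 1024 ok: %d\n",
	       erspan_session_id_ok(1023), erspan_session_id_ok(1024));
	printf("index 0xfffff ok: %d, index 0x100000 ok: %d\n",
	       erspan_v1_index_ok(0xfffff), erspan_v1_index_ok(0x100000));
	return 0;
}

In practice this means a key value above 1023 cannot be configured on an ip6erspan device.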
+ */ + if (data[IFLA_GRE_IKEY] && + (ntohl(nla_get_be32(data[IFLA_GRE_IKEY])) & ~ID_MASK)) + return -EINVAL; + + if (data[IFLA_GRE_OKEY] && + (ntohl(nla_get_be32(data[IFLA_GRE_OKEY])) & ~ID_MASK)) + return -EINVAL; + + if (data[IFLA_GRE_ERSPAN_VER]) { + ver = nla_get_u8(data[IFLA_GRE_ERSPAN_VER]); + if (ver != 1 && ver != 2) + return -EINVAL; + } + + if (ver == 1) { + if (data[IFLA_GRE_ERSPAN_INDEX]) { + u32 index = nla_get_u32(data[IFLA_GRE_ERSPAN_INDEX]); + + if (index & ~INDEX_MASK) + return -EINVAL; + } + } else if (ver == 2) { + if (data[IFLA_GRE_ERSPAN_DIR]) { + u16 dir = nla_get_u8(data[IFLA_GRE_ERSPAN_DIR]); + + if (dir & ~(DIR_MASK >> DIR_OFFSET)) + return -EINVAL; + } + + if (data[IFLA_GRE_ERSPAN_HWID]) { + u16 hwid = nla_get_u16(data[IFLA_GRE_ERSPAN_HWID]); + + if (hwid & ~(HWID_MASK >> HWID_OFFSET)) + return -EINVAL; + } + } + + return 0; +} + +static void ip6erspan_set_version(struct nlattr *data[], + struct __ip6_tnl_parm *parms) +{ + if (!data) + return; + + parms->erspan_ver = 1; + if (data[IFLA_GRE_ERSPAN_VER]) + parms->erspan_ver = nla_get_u8(data[IFLA_GRE_ERSPAN_VER]); + + if (parms->erspan_ver == 1) { + if (data[IFLA_GRE_ERSPAN_INDEX]) + parms->index = nla_get_u32(data[IFLA_GRE_ERSPAN_INDEX]); + } else if (parms->erspan_ver == 2) { + if (data[IFLA_GRE_ERSPAN_DIR]) + parms->dir = nla_get_u8(data[IFLA_GRE_ERSPAN_DIR]); + if (data[IFLA_GRE_ERSPAN_HWID]) + parms->hwid = nla_get_u16(data[IFLA_GRE_ERSPAN_HWID]); + } +} + +static void ip6gre_netlink_parms(struct nlattr *data[], + struct __ip6_tnl_parm *parms) +{ + memset(parms, 0, sizeof(*parms)); + + if (!data) + return; + + if (data[IFLA_GRE_LINK]) + parms->link = nla_get_u32(data[IFLA_GRE_LINK]); + + if (data[IFLA_GRE_IFLAGS]) + parms->i_flags = gre_flags_to_tnl_flags( + nla_get_be16(data[IFLA_GRE_IFLAGS])); + + if (data[IFLA_GRE_OFLAGS]) + parms->o_flags = gre_flags_to_tnl_flags( + nla_get_be16(data[IFLA_GRE_OFLAGS])); + + if (data[IFLA_GRE_IKEY]) + parms->i_key = nla_get_be32(data[IFLA_GRE_IKEY]); + + if (data[IFLA_GRE_OKEY]) + parms->o_key = nla_get_be32(data[IFLA_GRE_OKEY]); + + if (data[IFLA_GRE_LOCAL]) + parms->laddr = nla_get_in6_addr(data[IFLA_GRE_LOCAL]); + + if (data[IFLA_GRE_REMOTE]) + parms->raddr = nla_get_in6_addr(data[IFLA_GRE_REMOTE]); + + if (data[IFLA_GRE_TTL]) + parms->hop_limit = nla_get_u8(data[IFLA_GRE_TTL]); + + if (data[IFLA_GRE_ENCAP_LIMIT]) + parms->encap_limit = nla_get_u8(data[IFLA_GRE_ENCAP_LIMIT]); + + if (data[IFLA_GRE_FLOWINFO]) + parms->flowinfo = nla_get_be32(data[IFLA_GRE_FLOWINFO]); + + if (data[IFLA_GRE_FLAGS]) + parms->flags = nla_get_u32(data[IFLA_GRE_FLAGS]); + + if (data[IFLA_GRE_FWMARK]) + parms->fwmark = nla_get_u32(data[IFLA_GRE_FWMARK]); + + if (data[IFLA_GRE_COLLECT_METADATA]) + parms->collect_md = true; +} + +static int ip6gre_tap_init(struct net_device *dev) +{ + int ret; + + ret = ip6gre_tunnel_init_common(dev); + if (ret) + return ret; + + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + + return 0; +} + +static const struct net_device_ops ip6gre_tap_netdev_ops = { + .ndo_init = ip6gre_tap_init, + .ndo_uninit = ip6gre_tunnel_uninit, + .ndo_start_xmit = ip6gre_tunnel_xmit, + .ndo_set_mac_address = eth_mac_addr, + .ndo_validate_addr = eth_validate_addr, + .ndo_change_mtu = ip6_tnl_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip6_tnl_get_iflink, +}; + +static int ip6erspan_calc_hlen(struct ip6_tnl *tunnel) +{ + int t_hlen; + + tunnel->tun_hlen = 8; + tunnel->hlen = tunnel->tun_hlen + tunnel->encap_hlen + + erspan_hdr_len(tunnel->parms.erspan_ver); + + t_hlen 
= tunnel->hlen + sizeof(struct ipv6hdr); + tunnel->dev->needed_headroom = LL_MAX_HEADER + t_hlen; + return t_hlen; +} + +static int ip6erspan_tap_init(struct net_device *dev) +{ + struct ip6_tnl *tunnel; + int t_hlen; + int ret; + + tunnel = netdev_priv(dev); + + tunnel->dev = dev; + tunnel->net = dev_net(dev); + strcpy(tunnel->parms.name, dev->name); + + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + ret = dst_cache_init(&tunnel->dst_cache, GFP_KERNEL); + if (ret) + goto cleanup_alloc_pcpu_stats; + + ret = gro_cells_init(&tunnel->gro_cells, dev); + if (ret) + goto cleanup_dst_cache_init; + + t_hlen = ip6erspan_calc_hlen(tunnel); + dev->mtu = ETH_DATA_LEN - t_hlen; + if (dev->type == ARPHRD_ETHER) + dev->mtu -= ETH_HLEN; + if (!(tunnel->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + dev->mtu -= 8; + + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + ip6erspan_tnl_link_config(tunnel, 1); + + netdev_hold(dev, &tunnel->dev_tracker, GFP_KERNEL); + return 0; + +cleanup_dst_cache_init: + dst_cache_destroy(&tunnel->dst_cache); +cleanup_alloc_pcpu_stats: + free_percpu(dev->tstats); + dev->tstats = NULL; + return ret; +} + +static const struct net_device_ops ip6erspan_netdev_ops = { + .ndo_init = ip6erspan_tap_init, + .ndo_uninit = ip6erspan_tunnel_uninit, + .ndo_start_xmit = ip6erspan_tunnel_xmit, + .ndo_set_mac_address = eth_mac_addr, + .ndo_validate_addr = eth_validate_addr, + .ndo_change_mtu = ip6_tnl_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip6_tnl_get_iflink, +}; + +static void ip6gre_tap_setup(struct net_device *dev) +{ + + ether_setup(dev); + + dev->max_mtu = 0; + dev->netdev_ops = &ip6gre_tap_netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ip6gre_dev_free; + + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + netif_keep_dst(dev); +} + +static bool ip6gre_netlink_encap_parms(struct nlattr *data[], + struct ip_tunnel_encap *ipencap) +{ + bool ret = false; + + memset(ipencap, 0, sizeof(*ipencap)); + + if (!data) + return ret; + + if (data[IFLA_GRE_ENCAP_TYPE]) { + ret = true; + ipencap->type = nla_get_u16(data[IFLA_GRE_ENCAP_TYPE]); + } + + if (data[IFLA_GRE_ENCAP_FLAGS]) { + ret = true; + ipencap->flags = nla_get_u16(data[IFLA_GRE_ENCAP_FLAGS]); + } + + if (data[IFLA_GRE_ENCAP_SPORT]) { + ret = true; + ipencap->sport = nla_get_be16(data[IFLA_GRE_ENCAP_SPORT]); + } + + if (data[IFLA_GRE_ENCAP_DPORT]) { + ret = true; + ipencap->dport = nla_get_be16(data[IFLA_GRE_ENCAP_DPORT]); + } + + return ret; +} + +static int ip6gre_newlink_common(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *nt; + struct ip_tunnel_encap ipencap; + int err; + + nt = netdev_priv(dev); + + if (ip6gre_netlink_encap_parms(data, &ipencap)) { + int err = ip6_tnl_encap_setup(nt, &ipencap); + + if (err < 0) + return err; + } + + if (dev->type == ARPHRD_ETHER && !tb[IFLA_ADDRESS]) + eth_hw_addr_random(dev); + + nt->dev = dev; + nt->net = dev_net(dev); + + err = register_netdevice(dev); + if (err) + goto out; + + if (tb[IFLA_MTU]) + ip6_tnl_change_mtu(dev, nla_get_u32(tb[IFLA_MTU])); + +out: + return err; +} + +static int ip6gre_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *nt = netdev_priv(dev); + struct net *net = dev_net(dev); + struct ip6gre_net *ign; + int err; + + 
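For reference, the default MTU falls straight out of this header-length arithmetic: a plain ip6gre device adds a 4-byte GRE base header inside a 40-byte IPv6 header, and unless IP6_TNL_F_IGN_ENCAP_LIMIT is set a further 8 bytes are reserved for the destination-options header carrying the tunnel encapsulation limit. A small worked example, plain user-space arithmetic only, assuming no GRE key/csum/seq bits and no UDP (FOU/GUE) encapsulation:

#include <stdio.h>

/* Worked example of the default-MTU arithmetic for a freshly created
 * tunnel (no GRE options, no extra encapsulation). */
int main(void)
{
	int eth_data_len = 1500;   /* ETH_DATA_LEN */
	int gre_hlen     = 4;      /* base GRE header, no options */
	int ipv6_hlen    = 40;     /* sizeof(struct ipv6hdr) */
	int encap_limit  = 8;      /* dst-opt hdr with tunnel encap limit */
	int eth_hlen     = 14;     /* ETH_HLEN, ip6gretap only */

	int t_hlen = gre_hlen + ipv6_hlen;

	printf("ip6gre    default mtu: %d\n",
	       eth_data_len - t_hlen - encap_limit);             /* 1448 */
	printf("ip6gretap default mtu: %d\n",
	       eth_data_len - t_hlen - encap_limit - eth_hlen);  /* 1434 */
	return 0;
}

That is why a freshly created ip6gre device comes up with an MTU of 1448 and an ip6gretap device with 1434.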
ip6gre_netlink_parms(data, &nt->parms); + ign = net_generic(net, ip6gre_net_id); + + if (nt->parms.collect_md) { + if (rtnl_dereference(ign->collect_md_tun)) + return -EEXIST; + } else { + if (ip6gre_tunnel_find(net, &nt->parms, dev->type)) + return -EEXIST; + } + + err = ip6gre_newlink_common(src_net, dev, tb, data, extack); + if (!err) { + ip6gre_tnl_link_config(nt, !tb[IFLA_MTU]); + ip6gre_tunnel_link_md(ign, nt); + ip6gre_tunnel_link(net_generic(net, ip6gre_net_id), nt); + } + return err; +} + +static struct ip6_tnl * +ip6gre_changelink_common(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], struct __ip6_tnl_parm *p_p, + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *t, *nt = netdev_priv(dev); + struct net *net = nt->net; + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + struct ip_tunnel_encap ipencap; + + if (dev == ign->fb_tunnel_dev) + return ERR_PTR(-EINVAL); + + if (ip6gre_netlink_encap_parms(data, &ipencap)) { + int err = ip6_tnl_encap_setup(nt, &ipencap); + + if (err < 0) + return ERR_PTR(err); + } + + ip6gre_netlink_parms(data, p_p); + + t = ip6gre_tunnel_locate(net, p_p, 0); + + if (t) { + if (t->dev != dev) + return ERR_PTR(-EEXIST); + } else { + t = nt; + } + + return t; +} + +static int ip6gre_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ip6gre_net *ign = net_generic(t->net, ip6gre_net_id); + struct __ip6_tnl_parm p; + + t = ip6gre_changelink_common(dev, tb, data, &p, extack); + if (IS_ERR(t)) + return PTR_ERR(t); + + ip6gre_tunnel_unlink_md(ign, t); + ip6gre_tunnel_unlink(ign, t); + ip6gre_tnl_change(t, &p, !tb[IFLA_MTU]); + ip6gre_tunnel_link_md(ign, t); + ip6gre_tunnel_link(ign, t); + return 0; +} + +static void ip6gre_dellink(struct net_device *dev, struct list_head *head) +{ + struct net *net = dev_net(dev); + struct ip6gre_net *ign = net_generic(net, ip6gre_net_id); + + if (dev != ign->fb_tunnel_dev) + unregister_netdevice_queue(dev, head); +} + +static size_t ip6gre_get_size(const struct net_device *dev) +{ + return + /* IFLA_GRE_LINK */ + nla_total_size(4) + + /* IFLA_GRE_IFLAGS */ + nla_total_size(2) + + /* IFLA_GRE_OFLAGS */ + nla_total_size(2) + + /* IFLA_GRE_IKEY */ + nla_total_size(4) + + /* IFLA_GRE_OKEY */ + nla_total_size(4) + + /* IFLA_GRE_LOCAL */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_GRE_REMOTE */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_GRE_TTL */ + nla_total_size(1) + + /* IFLA_GRE_ENCAP_LIMIT */ + nla_total_size(1) + + /* IFLA_GRE_FLOWINFO */ + nla_total_size(4) + + /* IFLA_GRE_FLAGS */ + nla_total_size(4) + + /* IFLA_GRE_ENCAP_TYPE */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_FLAGS */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_SPORT */ + nla_total_size(2) + + /* IFLA_GRE_ENCAP_DPORT */ + nla_total_size(2) + + /* IFLA_GRE_COLLECT_METADATA */ + nla_total_size(0) + + /* IFLA_GRE_FWMARK */ + nla_total_size(4) + + /* IFLA_GRE_ERSPAN_INDEX */ + nla_total_size(4) + + 0; +} + +static int ip6gre_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct __ip6_tnl_parm *p = &t->parms; + __be16 o_flags = p->o_flags; + + if (p->erspan_ver == 1 || p->erspan_ver == 2) { + if (!p->collect_md) + o_flags |= TUNNEL_KEY; + + if (nla_put_u8(skb, IFLA_GRE_ERSPAN_VER, p->erspan_ver)) + goto nla_put_failure; + + if (p->erspan_ver == 1) { + if (nla_put_u32(skb, IFLA_GRE_ERSPAN_INDEX, p->index)) + goto nla_put_failure; + } else { + if 
(nla_put_u8(skb, IFLA_GRE_ERSPAN_DIR, p->dir)) + goto nla_put_failure; + if (nla_put_u16(skb, IFLA_GRE_ERSPAN_HWID, p->hwid)) + goto nla_put_failure; + } + } + + if (nla_put_u32(skb, IFLA_GRE_LINK, p->link) || + nla_put_be16(skb, IFLA_GRE_IFLAGS, + gre_tnl_flags_to_gre_flags(p->i_flags)) || + nla_put_be16(skb, IFLA_GRE_OFLAGS, + gre_tnl_flags_to_gre_flags(o_flags)) || + nla_put_be32(skb, IFLA_GRE_IKEY, p->i_key) || + nla_put_be32(skb, IFLA_GRE_OKEY, p->o_key) || + nla_put_in6_addr(skb, IFLA_GRE_LOCAL, &p->laddr) || + nla_put_in6_addr(skb, IFLA_GRE_REMOTE, &p->raddr) || + nla_put_u8(skb, IFLA_GRE_TTL, p->hop_limit) || + nla_put_u8(skb, IFLA_GRE_ENCAP_LIMIT, p->encap_limit) || + nla_put_be32(skb, IFLA_GRE_FLOWINFO, p->flowinfo) || + nla_put_u32(skb, IFLA_GRE_FLAGS, p->flags) || + nla_put_u32(skb, IFLA_GRE_FWMARK, p->fwmark)) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_GRE_ENCAP_TYPE, + t->encap.type) || + nla_put_be16(skb, IFLA_GRE_ENCAP_SPORT, + t->encap.sport) || + nla_put_be16(skb, IFLA_GRE_ENCAP_DPORT, + t->encap.dport) || + nla_put_u16(skb, IFLA_GRE_ENCAP_FLAGS, + t->encap.flags)) + goto nla_put_failure; + + if (p->collect_md) { + if (nla_put_flag(skb, IFLA_GRE_COLLECT_METADATA)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static const struct nla_policy ip6gre_policy[IFLA_GRE_MAX + 1] = { + [IFLA_GRE_LINK] = { .type = NLA_U32 }, + [IFLA_GRE_IFLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_OFLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_IKEY] = { .type = NLA_U32 }, + [IFLA_GRE_OKEY] = { .type = NLA_U32 }, + [IFLA_GRE_LOCAL] = { .len = sizeof_field(struct ipv6hdr, saddr) }, + [IFLA_GRE_REMOTE] = { .len = sizeof_field(struct ipv6hdr, daddr) }, + [IFLA_GRE_TTL] = { .type = NLA_U8 }, + [IFLA_GRE_ENCAP_LIMIT] = { .type = NLA_U8 }, + [IFLA_GRE_FLOWINFO] = { .type = NLA_U32 }, + [IFLA_GRE_FLAGS] = { .type = NLA_U32 }, + [IFLA_GRE_ENCAP_TYPE] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_FLAGS] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_SPORT] = { .type = NLA_U16 }, + [IFLA_GRE_ENCAP_DPORT] = { .type = NLA_U16 }, + [IFLA_GRE_COLLECT_METADATA] = { .type = NLA_FLAG }, + [IFLA_GRE_FWMARK] = { .type = NLA_U32 }, + [IFLA_GRE_ERSPAN_INDEX] = { .type = NLA_U32 }, + [IFLA_GRE_ERSPAN_VER] = { .type = NLA_U8 }, + [IFLA_GRE_ERSPAN_DIR] = { .type = NLA_U8 }, + [IFLA_GRE_ERSPAN_HWID] = { .type = NLA_U16 }, +}; + +static void ip6erspan_tap_setup(struct net_device *dev) +{ + ether_setup(dev); + + dev->max_mtu = 0; + dev->netdev_ops = &ip6erspan_netdev_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ip6gre_dev_free; + + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + netif_keep_dst(dev); +} + +static int ip6erspan_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *nt = netdev_priv(dev); + struct net *net = dev_net(dev); + struct ip6gre_net *ign; + int err; + + ip6gre_netlink_parms(data, &nt->parms); + ip6erspan_set_version(data, &nt->parms); + ign = net_generic(net, ip6gre_net_id); + + if (nt->parms.collect_md) { + if (rtnl_dereference(ign->collect_md_tun_erspan)) + return -EEXIST; + } else { + if (ip6gre_tunnel_find(net, &nt->parms, dev->type)) + return -EEXIST; + } + + err = ip6gre_newlink_common(src_net, dev, tb, data, extack); + if (!err) { + ip6erspan_tnl_link_config(nt, !tb[IFLA_MTU]); + ip6erspan_tunnel_link_md(ign, nt); + ip6gre_tunnel_link(net_generic(net, ip6gre_net_id), nt); + } + return err; +} + +static void 
ip6erspan_tnl_link_config(struct ip6_tnl *t, int set_mtu) +{ + ip6gre_tnl_link_config_common(t); + ip6gre_tnl_link_config_route(t, set_mtu, ip6erspan_calc_hlen(t)); +} + +static int ip6erspan_tnl_change(struct ip6_tnl *t, + const struct __ip6_tnl_parm *p, int set_mtu) +{ + ip6gre_tnl_copy_tnl_parm(t, p); + ip6erspan_tnl_link_config(t, set_mtu); + return 0; +} + +static int ip6erspan_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6gre_net *ign = net_generic(dev_net(dev), ip6gre_net_id); + struct __ip6_tnl_parm p; + struct ip6_tnl *t; + + t = ip6gre_changelink_common(dev, tb, data, &p, extack); + if (IS_ERR(t)) + return PTR_ERR(t); + + ip6erspan_set_version(data, &p); + ip6gre_tunnel_unlink_md(ign, t); + ip6gre_tunnel_unlink(ign, t); + ip6erspan_tnl_change(t, &p, !tb[IFLA_MTU]); + ip6erspan_tunnel_link_md(ign, t); + ip6gre_tunnel_link(ign, t); + return 0; +} + +static struct rtnl_link_ops ip6gre_link_ops __read_mostly = { + .kind = "ip6gre", + .maxtype = IFLA_GRE_MAX, + .policy = ip6gre_policy, + .priv_size = sizeof(struct ip6_tnl), + .setup = ip6gre_tunnel_setup, + .validate = ip6gre_tunnel_validate, + .newlink = ip6gre_newlink, + .changelink = ip6gre_changelink, + .dellink = ip6gre_dellink, + .get_size = ip6gre_get_size, + .fill_info = ip6gre_fill_info, + .get_link_net = ip6_tnl_get_link_net, +}; + +static struct rtnl_link_ops ip6gre_tap_ops __read_mostly = { + .kind = "ip6gretap", + .maxtype = IFLA_GRE_MAX, + .policy = ip6gre_policy, + .priv_size = sizeof(struct ip6_tnl), + .setup = ip6gre_tap_setup, + .validate = ip6gre_tap_validate, + .newlink = ip6gre_newlink, + .changelink = ip6gre_changelink, + .get_size = ip6gre_get_size, + .fill_info = ip6gre_fill_info, + .get_link_net = ip6_tnl_get_link_net, +}; + +static struct rtnl_link_ops ip6erspan_tap_ops __read_mostly = { + .kind = "ip6erspan", + .maxtype = IFLA_GRE_MAX, + .policy = ip6gre_policy, + .priv_size = sizeof(struct ip6_tnl), + .setup = ip6erspan_tap_setup, + .validate = ip6erspan_tap_validate, + .newlink = ip6erspan_newlink, + .changelink = ip6erspan_changelink, + .get_size = ip6gre_get_size, + .fill_info = ip6gre_fill_info, + .get_link_net = ip6_tnl_get_link_net, +}; + +/* + * And now the modules code and kernel interface. 
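[editor's note] The ip6gre_get_size() callback earlier in this hunk adds up one nla_total_size() per attribute that fill_info may emit. As a rough aid, here is a small stand-alone sketch of that arithmetic, assuming only the standard netlink layout (a 4-byte attribute header and payloads padded to a 4-byte boundary); the payload list mirrors the comments in ip6gre_get_size() and is illustrative, not part of the patch.

/* Hypothetical user-space sketch of the nla_total_size() arithmetic used by
 * ip6gre_get_size(): each attribute costs a 4-byte header plus its payload,
 * rounded up to a 4-byte boundary.  Assumption: standard netlink alignment.
 */
#include <stddef.h>
#include <stdio.h>

#define NLA_HDRLEN   4
#define NLA_ALIGN(x) (((x) + 3u) & ~(size_t)3)

static size_t nla_total_size(size_t payload)
{
	return NLA_ALIGN(NLA_HDRLEN + payload);
}

int main(void)
{
	/* Payload sizes in the order of the comments in ip6gre_get_size():
	 * u32 link, u16 iflags/oflags, u32 ikey/okey, two in6_addr, u8 ttl and
	 * encap limit, u32 flowinfo/flags, four u16 encap fields, zero-length
	 * collect-metadata flag, u32 fwmark and erspan index.
	 */
	const size_t payloads[] = { 4, 2, 2, 4, 4, 16, 16, 1, 1, 4, 4,
				    2, 2, 2, 2, 0, 4, 4 };
	size_t i, total = 0;

	for (i = 0; i < sizeof(payloads) / sizeof(payloads[0]); i++)
		total += nla_total_size(payloads[i]);

	printf("worst-case attribute space: %zu bytes\n", total);
	return 0;
}

The total printed should match what ip6gre_get_size() reserves for the message, ignoring attributes added in later kernel versions.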
+ */ + +static int __init ip6gre_init(void) +{ + int err; + + pr_info("GRE over IPv6 tunneling driver\n"); + + err = register_pernet_device(&ip6gre_net_ops); + if (err < 0) + return err; + + err = inet6_add_protocol(&ip6gre_protocol, IPPROTO_GRE); + if (err < 0) { + pr_info("%s: can't add protocol\n", __func__); + goto add_proto_failed; + } + + err = rtnl_link_register(&ip6gre_link_ops); + if (err < 0) + goto rtnl_link_failed; + + err = rtnl_link_register(&ip6gre_tap_ops); + if (err < 0) + goto tap_ops_failed; + + err = rtnl_link_register(&ip6erspan_tap_ops); + if (err < 0) + goto erspan_link_failed; + +out: + return err; + +erspan_link_failed: + rtnl_link_unregister(&ip6gre_tap_ops); +tap_ops_failed: + rtnl_link_unregister(&ip6gre_link_ops); +rtnl_link_failed: + inet6_del_protocol(&ip6gre_protocol, IPPROTO_GRE); +add_proto_failed: + unregister_pernet_device(&ip6gre_net_ops); + goto out; +} + +static void __exit ip6gre_fini(void) +{ + rtnl_link_unregister(&ip6gre_tap_ops); + rtnl_link_unregister(&ip6gre_link_ops); + rtnl_link_unregister(&ip6erspan_tap_ops); + inet6_del_protocol(&ip6gre_protocol, IPPROTO_GRE); + unregister_pernet_device(&ip6gre_net_ops); +} + +module_init(ip6gre_init); +module_exit(ip6gre_fini); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("D. Kozlov (xeb@mail.ru)"); +MODULE_DESCRIPTION("GRE over IPv6 tunneling device"); +MODULE_ALIAS_RTNL_LINK("ip6gre"); +MODULE_ALIAS_RTNL_LINK("ip6gretap"); +MODULE_ALIAS_RTNL_LINK("ip6erspan"); +MODULE_ALIAS_NETDEV("ip6gre0"); diff --git a/net/ipv6/ip6_icmp.c b/net/ipv6/ip6_icmp.c new file mode 100644 index 000000000..9e3574880 --- /dev/null +++ b/net/ipv6/ip6_icmp.c @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include + +#include + +#if IS_ENABLED(CONFIG_IPV6) + +#if !IS_BUILTIN(CONFIG_IPV6) + +static ip6_icmp_send_t __rcu *ip6_icmp_send; + +int inet6_register_icmp_sender(ip6_icmp_send_t *fn) +{ + return (cmpxchg((ip6_icmp_send_t **)&ip6_icmp_send, NULL, fn) == NULL) ? + 0 : -EBUSY; +} +EXPORT_SYMBOL(inet6_register_icmp_sender); + +int inet6_unregister_icmp_sender(ip6_icmp_send_t *fn) +{ + int ret; + + ret = (cmpxchg((ip6_icmp_send_t **)&ip6_icmp_send, fn, NULL) == fn) ? 
+ 0 : -EINVAL; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(inet6_unregister_icmp_sender); + +void __icmpv6_send(struct sk_buff *skb, u8 type, u8 code, __u32 info, + const struct inet6_skb_parm *parm) +{ + ip6_icmp_send_t *send; + + rcu_read_lock(); + send = rcu_dereference(ip6_icmp_send); + if (send) + send(skb, type, code, info, NULL, parm); + rcu_read_unlock(); +} +EXPORT_SYMBOL(__icmpv6_send); +#endif + +#if IS_ENABLED(CONFIG_NF_NAT) +#include +void icmpv6_ndo_send(struct sk_buff *skb_in, u8 type, u8 code, __u32 info) +{ + struct inet6_skb_parm parm = { 0 }; + struct sk_buff *cloned_skb = NULL; + enum ip_conntrack_info ctinfo; + struct in6_addr orig_ip; + struct nf_conn *ct; + + ct = nf_ct_get(skb_in, &ctinfo); + if (!ct || !(ct->status & IPS_SRC_NAT)) { + __icmpv6_send(skb_in, type, code, info, &parm); + return; + } + + if (skb_shared(skb_in)) + skb_in = cloned_skb = skb_clone(skb_in, GFP_ATOMIC); + + if (unlikely(!skb_in || skb_network_header(skb_in) < skb_in->head || + (skb_network_header(skb_in) + sizeof(struct ipv6hdr)) > + skb_tail_pointer(skb_in) || skb_ensure_writable(skb_in, + skb_network_offset(skb_in) + sizeof(struct ipv6hdr)))) + goto out; + + orig_ip = ipv6_hdr(skb_in)->saddr; + ipv6_hdr(skb_in)->saddr = ct->tuplehash[0].tuple.src.u3.in6; + __icmpv6_send(skb_in, type, code, info, &parm); + ipv6_hdr(skb_in)->saddr = orig_ip; +out: + consume_skb(cloned_skb); +} +EXPORT_SYMBOL(icmpv6_ndo_send); +#endif +#endif diff --git a/net/ipv6/ip6_input.c b/net/ipv6/ip6_input.c new file mode 100644 index 000000000..b83788145 --- /dev/null +++ b/net/ipv6/ip6_input.c @@ -0,0 +1,593 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 input + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * Ian P. Morris + * + * Based in linux/net/ipv4/ip_input.c + */ +/* Changes + * + * Mitsuru KANDA @USAGI and + * YOSHIFUJI Hideaki @USAGI: Remove ipv6_parse_exthdrs(). 
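[editor's note] inet6_register_icmp_sender() above installs the real icmp6_send implementation with a one-shot cmpxchg() from NULL, inet6_unregister_icmp_sender() swaps it back, and __icmpv6_send() only calls the hook if the RCU-protected pointer is non-NULL. The following stand-alone sketch mimics that register/unregister/call pattern in user space, with C11 atomics standing in for cmpxchg() and RCU; all names in it are invented for illustration.

/* Hypothetical user-space analogue of the ip6_icmp_send hook: registration is
 * a compare-and-swap from NULL, unregistration swaps back to NULL, and callers
 * load the pointer and invoke it only if set.  This is a sketch of the
 * pattern, not kernel code.
 */
#include <stdatomic.h>
#include <stdio.h>

typedef void (*icmp_send_fn)(int type, int code);

static _Atomic(icmp_send_fn) icmp_send_hook;

static int register_sender(icmp_send_fn fn)
{
	icmp_send_fn expected = NULL;

	return atomic_compare_exchange_strong(&icmp_send_hook, &expected, fn)
		? 0 : -1;		/* the kernel returns -EBUSY here */
}

static int unregister_sender(icmp_send_fn fn)
{
	icmp_send_fn expected = fn;

	return atomic_compare_exchange_strong(&icmp_send_hook, &expected, NULL)
		? 0 : -1;		/* the kernel returns -EINVAL here */
}

static void send_icmp(int type, int code)
{
	icmp_send_fn fn = atomic_load(&icmp_send_hook);

	if (fn)				/* kernel: rcu_dereference() + call */
		fn(type, code);
}

static void demo_sender(int type, int code)
{
	printf("would send ICMPv6 type %d code %d\n", type, code);
}

int main(void)
{
	send_icmp(1, 0);		/* no sender yet: silently dropped */
	register_sender(demo_sender);
	send_icmp(1, 0);		/* dispatched to demo_sender */
	unregister_sender(demo_sender);
	return 0;
}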
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void ip6_rcv_finish_core(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + if (READ_ONCE(net->ipv4.sysctl_ip_early_demux) && + !skb_dst(skb) && !skb->sk) { + switch (ipv6_hdr(skb)->nexthdr) { + case IPPROTO_TCP: + if (READ_ONCE(net->ipv4.sysctl_tcp_early_demux)) + tcp_v6_early_demux(skb); + break; + case IPPROTO_UDP: + if (READ_ONCE(net->ipv4.sysctl_udp_early_demux)) + udp_v6_early_demux(skb); + break; + } + } + + if (!skb_valid_dst(skb)) + ip6_route_input(skb); +} + +int ip6_rcv_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + /* if ingress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip6_rcv(skb); + if (!skb) + return NET_RX_SUCCESS; + ip6_rcv_finish_core(net, sk, skb); + + return dst_input(skb); +} + +static void ip6_sublist_rcv_finish(struct list_head *head) +{ + struct sk_buff *skb, *next; + + list_for_each_entry_safe(skb, next, head, list) { + skb_list_del_init(skb); + dst_input(skb); + } +} + +static bool ip6_can_use_hint(const struct sk_buff *skb, + const struct sk_buff *hint) +{ + return hint && !skb_dst(skb) && + ipv6_addr_equal(&ipv6_hdr(hint)->daddr, &ipv6_hdr(skb)->daddr); +} + +static struct sk_buff *ip6_extract_route_hint(const struct net *net, + struct sk_buff *skb) +{ + if (fib6_routes_require_src(net) || fib6_has_custom_rules(net) || + IP6CB(skb)->flags & IP6SKB_MULTIPATH) + return NULL; + + return skb; +} + +static void ip6_list_rcv_finish(struct net *net, struct sock *sk, + struct list_head *head) +{ + struct sk_buff *skb, *next, *hint = NULL; + struct dst_entry *curr_dst = NULL; + struct list_head sublist; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + struct dst_entry *dst; + + skb_list_del_init(skb); + /* if ingress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip6_rcv(skb); + if (!skb) + continue; + + if (ip6_can_use_hint(skb, hint)) + skb_dst_copy(skb, hint); + else + ip6_rcv_finish_core(net, sk, skb); + dst = skb_dst(skb); + if (curr_dst != dst) { + hint = ip6_extract_route_hint(net, skb); + + /* dispatch old sublist */ + if (!list_empty(&sublist)) + ip6_sublist_rcv_finish(&sublist); + /* start new sublist */ + INIT_LIST_HEAD(&sublist); + curr_dst = dst; + } + list_add_tail(&skb->list, &sublist); + } + /* dispatch final sublist */ + ip6_sublist_rcv_finish(&sublist); +} + +static struct sk_buff *ip6_rcv_core(struct sk_buff *skb, struct net_device *dev, + struct net *net) +{ + enum skb_drop_reason reason; + const struct ipv6hdr *hdr; + u32 pkt_len; + struct inet6_dev *idev; + + if (skb->pkt_type == PACKET_OTHERHOST) { + dev_core_stats_rx_otherhost_dropped_inc(skb->dev); + kfree_skb_reason(skb, SKB_DROP_REASON_OTHERHOST); + return NULL; + } + + rcu_read_lock(); + + idev = __in6_dev_get(skb->dev); + + __IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_IN, skb->len); + + SKB_DR_SET(reason, NOT_SPECIFIED); + if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL || + !idev || unlikely(idev->cnf.disable_ipv6)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); + if (idev && unlikely(idev->cnf.disable_ipv6)) + SKB_DR_SET(reason, IPV6DISABLED); + goto drop; + } + + memset(IP6CB(skb), 0, sizeof(struct 
inet6_skb_parm)); + + /* + * Store incoming device index. When the packet will + * be queued, we cannot refer to skb->dev anymore. + * + * BTW, when we send a packet for our own local address on a + * non-loopback interface (e.g. ethX), it is being delivered + * via the loopback interface (lo) here; skb->dev = loopback_dev. + * It, however, should be considered as if it is being + * arrived via the sending interface (ethX), because of the + * nature of scoping architecture. --yoshfuji + */ + IP6CB(skb)->iif = skb_valid_dst(skb) ? ip6_dst_idev(skb_dst(skb))->dev->ifindex : dev->ifindex; + + if (unlikely(!pskb_may_pull(skb, sizeof(*hdr)))) + goto err; + + hdr = ipv6_hdr(skb); + + if (hdr->version != 6) { + SKB_DR_SET(reason, UNHANDLED_PROTO); + goto err; + } + + __IP6_ADD_STATS(net, idev, + IPSTATS_MIB_NOECTPKTS + + (ipv6_get_dsfield(hdr) & INET_ECN_MASK), + max_t(unsigned short, 1, skb_shinfo(skb)->gso_segs)); + /* + * RFC4291 2.5.3 + * The loopback address must not be used as the source address in IPv6 + * packets that are sent outside of a single node. [..] + * A packet received on an interface with a destination address + * of loopback must be dropped. + */ + if ((ipv6_addr_loopback(&hdr->saddr) || + ipv6_addr_loopback(&hdr->daddr)) && + !(dev->flags & IFF_LOOPBACK) && + !netif_is_l3_master(dev)) + goto err; + + /* RFC4291 Errata ID: 3480 + * Interface-Local scope spans only a single interface on a + * node and is useful only for loopback transmission of + * multicast. Packets with interface-local scope received + * from another node must be discarded. + */ + if (!(skb->pkt_type == PACKET_LOOPBACK || + dev->flags & IFF_LOOPBACK) && + ipv6_addr_is_multicast(&hdr->daddr) && + IPV6_ADDR_MC_SCOPE(&hdr->daddr) == 1) + goto err; + + /* If enabled, drop unicast packets that were encapsulated in link-layer + * multicast or broadcast to protected against the so-called "hole-196" + * attack in 802.11 wireless. + */ + if (!ipv6_addr_is_multicast(&hdr->daddr) && + (skb->pkt_type == PACKET_BROADCAST || + skb->pkt_type == PACKET_MULTICAST) && + idev->cnf.drop_unicast_in_l2_multicast) { + SKB_DR_SET(reason, UNICAST_IN_L2_MULTICAST); + goto err; + } + + /* RFC4291 2.7 + * Nodes must not originate a packet to a multicast address whose scope + * field contains the reserved value 0; if such a packet is received, it + * must be silently dropped. + */ + if (ipv6_addr_is_multicast(&hdr->daddr) && + IPV6_ADDR_MC_SCOPE(&hdr->daddr) == 0) + goto err; + + /* + * RFC4291 2.7 + * Multicast addresses must not be used as source addresses in IPv6 + * packets or appear in any Routing header. + */ + if (ipv6_addr_is_multicast(&hdr->saddr)) + goto err; + + skb->transport_header = skb->network_header + sizeof(*hdr); + IP6CB(skb)->nhoff = offsetof(struct ipv6hdr, nexthdr); + + pkt_len = ntohs(hdr->payload_len); + + /* pkt_len may be zero if Jumbo payload option is present */ + if (pkt_len || hdr->nexthdr != NEXTHDR_HOP) { + if (pkt_len + sizeof(struct ipv6hdr) > skb->len) { + __IP6_INC_STATS(net, + idev, IPSTATS_MIB_INTRUNCATEDPKTS); + SKB_DR_SET(reason, PKT_TOO_SMALL); + goto drop; + } + if (pskb_trim_rcsum(skb, pkt_len + sizeof(struct ipv6hdr))) + goto err; + hdr = ipv6_hdr(skb); + } + + if (hdr->nexthdr == NEXTHDR_HOP) { + if (ipv6_parse_hopopts(skb) < 0) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + rcu_read_unlock(); + return NULL; + } + } + + rcu_read_unlock(); + + /* Must drop socket now because of tproxy. 
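[editor's note] Two of the checks in ip6_rcv_core() above key off the multicast scope field, which is the low four bits of the second address byte (the X in ff0X::/8): scope 0 is reserved and scope 1 is interface-local, and packets with either scope arriving from another node are dropped. A small user-space sketch of that extraction, assuming the usual struct in6_addr layout; it deliberately omits the loopback-device exception the kernel applies.

/* Sketch of the multicast-scope tests from ip6_rcv_core(): for an IPv6
 * multicast address ff0X::, X (the low nibble of the second byte) is the
 * scope.  Scope 0 (reserved) and scope 1 (interface-local) coming from
 * another node are discarded.
 */
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdbool.h>
#include <stdio.h>

static bool is_multicast(const struct in6_addr *a)
{
	return a->s6_addr[0] == 0xff;
}

static unsigned int mc_scope(const struct in6_addr *a)
{
	return a->s6_addr[1] & 0x0f;
}

static bool drop_on_receive(const struct in6_addr *daddr)
{
	if (!is_multicast(daddr))
		return false;
	/* Reserved scope 0 or interface-local scope 1 from the wire. */
	return mc_scope(daddr) == 0 || mc_scope(daddr) == 1;
}

int main(void)
{
	const char *samples[] = { "ff01::1", "ff02::1", "ff05::2", "2001:db8::1" };
	struct in6_addr a;

	for (size_t i = 0; i < sizeof(samples) / sizeof(samples[0]); i++) {
		inet_pton(AF_INET6, samples[i], &a);
		printf("%-12s -> %s\n", samples[i],
		       drop_on_receive(&a) ? "drop" : "accept");
	}
	return 0;
}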
*/ + if (!skb_sk_is_prefetched(skb)) + skb_orphan(skb); + + return skb; +err: + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + SKB_DR_OR(reason, IP_INHDR); +drop: + rcu_read_unlock(); + kfree_skb_reason(skb, reason); + return NULL; +} + +int ipv6_rcv(struct sk_buff *skb, struct net_device *dev, struct packet_type *pt, struct net_device *orig_dev) +{ + struct net *net = dev_net(skb->dev); + + skb = ip6_rcv_core(skb, dev, net); + if (skb == NULL) + return NET_RX_DROP; + return NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, + net, NULL, skb, dev, NULL, + ip6_rcv_finish); +} + +static void ip6_sublist_rcv(struct list_head *head, struct net_device *dev, + struct net *net) +{ + NF_HOOK_LIST(NFPROTO_IPV6, NF_INET_PRE_ROUTING, net, NULL, + head, dev, NULL, ip6_rcv_finish); + ip6_list_rcv_finish(net, NULL, head); +} + +/* Receive a list of IPv6 packets */ +void ipv6_list_rcv(struct list_head *head, struct packet_type *pt, + struct net_device *orig_dev) +{ + struct net_device *curr_dev = NULL; + struct net *curr_net = NULL; + struct sk_buff *skb, *next; + struct list_head sublist; + + INIT_LIST_HEAD(&sublist); + list_for_each_entry_safe(skb, next, head, list) { + struct net_device *dev = skb->dev; + struct net *net = dev_net(dev); + + skb_list_del_init(skb); + skb = ip6_rcv_core(skb, dev, net); + if (skb == NULL) + continue; + + if (curr_dev != dev || curr_net != net) { + /* dispatch old sublist */ + if (!list_empty(&sublist)) + ip6_sublist_rcv(&sublist, curr_dev, curr_net); + /* start new sublist */ + INIT_LIST_HEAD(&sublist); + curr_dev = dev; + curr_net = net; + } + list_add_tail(&skb->list, &sublist); + } + /* dispatch final sublist */ + if (!list_empty(&sublist)) + ip6_sublist_rcv(&sublist, curr_dev, curr_net); +} + +INDIRECT_CALLABLE_DECLARE(int tcp_v6_rcv(struct sk_buff *)); + +/* + * Deliver the packet to the host + */ +void ip6_protocol_deliver_rcu(struct net *net, struct sk_buff *skb, int nexthdr, + bool have_final) +{ + const struct inet6_protocol *ipprot; + struct inet6_dev *idev; + unsigned int nhoff; + SKB_DR(reason); + bool raw; + + /* + * Parse extension headers + */ + +resubmit: + idev = ip6_dst_idev(skb_dst(skb)); + nhoff = IP6CB(skb)->nhoff; + if (!have_final) { + if (!pskb_pull(skb, skb_transport_offset(skb))) + goto discard; + nexthdr = skb_network_header(skb)[nhoff]; + } + +resubmit_final: + raw = raw6_local_deliver(skb, nexthdr); + ipprot = rcu_dereference(inet6_protos[nexthdr]); + if (ipprot) { + int ret; + + if (have_final) { + if (!(ipprot->flags & INET6_PROTO_FINAL)) { + /* Once we've seen a final protocol don't + * allow encapsulation on any non-final + * ones. This allows foo in UDP encapsulation + * to work. + */ + goto discard; + } + } else if (ipprot->flags & INET6_PROTO_FINAL) { + const struct ipv6hdr *hdr; + int sdif = inet6_sdif(skb); + struct net_device *dev; + + /* Only do this once for first final protocol */ + have_final = true; + + + skb_postpull_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + hdr = ipv6_hdr(skb); + + /* skb->dev passed may be master dev for vrfs. 
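[editor's note] ipv6_list_rcv() and ip6_list_rcv_finish() above share one batching idiom: walk the list, and whenever the grouping key changes (device and netns in one case, the cached dst in the other), flush the sublist collected so far and start a new one. A generic, stand-alone sketch of that idiom follows; the item and key types are invented and only the control flow mirrors the kernel.

/* Stand-alone sketch of the sublist-dispatch idiom used by ipv6_list_rcv():
 * consecutive items sharing a key are collected and handed off as one batch.
 */
#include <stdio.h>

struct item {
	int key;	/* stands in for (dev, net) or the route/dst */
	int payload;
};

static void dispatch_sublist(const struct item *batch, int n, int key)
{
	printf("batch of %d item(s) for key %d\n", n, key);
	(void)batch;
}

static void list_rcv(const struct item *items, int count)
{
	int start = 0;

	for (int i = 0; i < count; i++) {
		if (i > start && items[i].key != items[start].key) {
			/* key changed: dispatch old sublist, start a new one */
			dispatch_sublist(&items[start], i - start,
					 items[start].key);
			start = i;
		}
	}
	if (count > start)	/* dispatch the final sublist */
		dispatch_sublist(&items[start], count - start,
				 items[start].key);
}

int main(void)
{
	const struct item pkts[] = {
		{ 1, 10 }, { 1, 11 }, { 2, 20 }, { 2, 21 }, { 2, 22 }, { 1, 30 },
	};

	list_rcv(pkts, sizeof(pkts) / sizeof(pkts[0]));
	return 0;
}

Batching this way lets each downstream hop (netfilter hooks, dst_input) run once per group instead of once per packet, which is the point of the list-receive path.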
*/ + if (sdif) { + dev = dev_get_by_index_rcu(net, sdif); + if (!dev) + goto discard; + } else { + dev = skb->dev; + } + + if (ipv6_addr_is_multicast(&hdr->daddr) && + !ipv6_chk_mcast_addr(dev, &hdr->daddr, + &hdr->saddr) && + !ipv6_is_mld(skb, nexthdr, skb_network_header_len(skb))) { + SKB_DR_SET(reason, IP_INADDRERRORS); + goto discard; + } + } + if (!(ipprot->flags & INET6_PROTO_NOPOLICY)) { + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { + SKB_DR_SET(reason, XFRM_POLICY); + goto discard; + } + nf_reset_ct(skb); + } + + ret = INDIRECT_CALL_2(ipprot->handler, tcp_v6_rcv, udpv6_rcv, + skb); + if (ret > 0) { + if (ipprot->flags & INET6_PROTO_FINAL) { + /* Not an extension header, most likely UDP + * encapsulation. Use return value as nexthdr + * protocol not nhoff (which presumably is + * not set by handler). + */ + nexthdr = ret; + goto resubmit_final; + } else { + goto resubmit; + } + } else if (ret == 0) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDELIVERS); + } + } else { + if (!raw) { + if (xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { + __IP6_INC_STATS(net, idev, + IPSTATS_MIB_INUNKNOWNPROTOS); + icmpv6_send(skb, ICMPV6_PARAMPROB, + ICMPV6_UNK_NEXTHDR, nhoff); + SKB_DR_SET(reason, IP_NOPROTO); + } else { + SKB_DR_SET(reason, XFRM_POLICY); + } + kfree_skb_reason(skb, reason); + } else { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDELIVERS); + consume_skb(skb); + } + } + return; + +discard: + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); + kfree_skb_reason(skb, reason); +} + +static int ip6_input_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb_clear_delivery_time(skb); + rcu_read_lock(); + ip6_protocol_deliver_rcu(net, skb, 0, false); + rcu_read_unlock(); + + return 0; +} + + +int ip6_input(struct sk_buff *skb) +{ + return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_IN, + dev_net(skb->dev), NULL, skb, skb->dev, NULL, + ip6_input_finish); +} +EXPORT_SYMBOL_GPL(ip6_input); + +int ip6_mc_input(struct sk_buff *skb) +{ + int sdif = inet6_sdif(skb); + const struct ipv6hdr *hdr; + struct net_device *dev; + bool deliver; + + __IP6_UPD_PO_STATS(dev_net(skb_dst(skb)->dev), + __in6_dev_get_safely(skb->dev), IPSTATS_MIB_INMCAST, + skb->len); + + /* skb->dev passed may be master dev for vrfs. */ + if (sdif) { + rcu_read_lock(); + dev = dev_get_by_index_rcu(dev_net(skb->dev), sdif); + if (!dev) { + rcu_read_unlock(); + kfree_skb(skb); + return -ENODEV; + } + } else { + dev = skb->dev; + } + + hdr = ipv6_hdr(skb); + deliver = ipv6_chk_mcast_addr(dev, &hdr->daddr, NULL); + if (sdif) + rcu_read_unlock(); + +#ifdef CONFIG_IPV6_MROUTE + /* + * IPv6 multicast router mode is now supported ;) + */ + if (atomic_read(&dev_net(skb->dev)->ipv6.devconf_all->mc_forwarding) && + !(ipv6_addr_type(&hdr->daddr) & + (IPV6_ADDR_LOOPBACK|IPV6_ADDR_LINKLOCAL)) && + likely(!(IP6CB(skb)->flags & IP6SKB_FORWARDED))) { + /* + * Okay, we try to forward - split and duplicate + * packets. + */ + struct sk_buff *skb2; + struct inet6_skb_parm *opt = IP6CB(skb); + + /* Check for MLD */ + if (unlikely(opt->flags & IP6SKB_ROUTERALERT)) { + /* Check if this is a mld message */ + u8 nexthdr = hdr->nexthdr; + __be16 frag_off; + int offset; + + /* Check if the value of Router Alert + * is for MLD (0x0000). 
+ */ + if (opt->ra == htons(IPV6_OPT_ROUTERALERT_MLD)) { + deliver = false; + + if (!ipv6_ext_hdr(nexthdr)) { + /* BUG */ + goto out; + } + offset = ipv6_skip_exthdr(skb, sizeof(*hdr), + &nexthdr, &frag_off); + if (offset < 0) + goto out; + + if (ipv6_is_mld(skb, nexthdr, offset)) + deliver = true; + + goto out; + } + /* unknown RA - process it normally */ + } + + if (deliver) + skb2 = skb_clone(skb, GFP_ATOMIC); + else { + skb2 = skb; + skb = NULL; + } + + if (skb2) { + ip6_mr_input(skb2); + } + } +out: +#endif + if (likely(deliver)) + ip6_input(skb); + else { + /* discard */ + kfree_skb(skb); + } + + return 0; +} diff --git a/net/ipv6/ip6_offload.c b/net/ipv6/ip6_offload.c new file mode 100644 index 000000000..3ee345672 --- /dev/null +++ b/net/ipv6/ip6_offload.c @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV6 GSO/GRO offload support + * Linux INET6 implementation + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ip6_offload.h" + +/* All GRO functions are always builtin, except UDP over ipv6, which lays in + * ipv6 module, as it depends on UDPv6 lookup function, so we need special care + * when ipv6 is built as a module + */ +#if IS_BUILTIN(CONFIG_IPV6) +#define INDIRECT_CALL_L4(f, f2, f1, ...) INDIRECT_CALL_2(f, f2, f1, __VA_ARGS__) +#else +#define INDIRECT_CALL_L4(f, f2, f1, ...) INDIRECT_CALL_1(f, f2, __VA_ARGS__) +#endif + +#define indirect_call_gro_receive_l4(f2, f1, cb, head, skb) \ +({ \ + unlikely(gro_recursion_inc_test(skb)) ? \ + NAPI_GRO_CB(skb)->flush |= 1, NULL : \ + INDIRECT_CALL_L4(cb, f2, f1, head, skb); \ +}) + +static int ipv6_gso_pull_exthdrs(struct sk_buff *skb, int proto) +{ + const struct net_offload *ops = NULL; + + for (;;) { + struct ipv6_opt_hdr *opth; + int len; + + if (proto != NEXTHDR_HOP) { + ops = rcu_dereference(inet6_offloads[proto]); + + if (unlikely(!ops)) + break; + + if (!(ops->flags & INET6_PROTO_GSO_EXTHDR)) + break; + } + + if (unlikely(!pskb_may_pull(skb, 8))) + break; + + opth = (void *)skb->data; + len = ipv6_optlen(opth); + + if (unlikely(!pskb_may_pull(skb, len))) + break; + + opth = (void *)skb->data; + proto = opth->nexthdr; + __skb_pull(skb, len); + } + + return proto; +} + +static struct sk_buff *ipv6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct ipv6hdr *ipv6h; + const struct net_offload *ops; + int proto, nexthdr; + struct frag_hdr *fptr; + unsigned int payload_len; + u8 *prevhdr; + int offset = 0; + bool encap, udpfrag; + int nhoff; + bool gso_partial; + + skb_reset_network_header(skb); + nexthdr = ipv6_has_hopopt_jumbo(skb); + if (nexthdr) { + const int hophdr_len = sizeof(struct hop_jumbo_hdr); + int err; + + err = skb_cow_head(skb, 0); + if (err < 0) + return ERR_PTR(err); + + /* remove the HBH header. 
+ * Layout: [Ethernet header][IPv6 header][HBH][TCP header] + */ + memmove(skb_mac_header(skb) + hophdr_len, + skb_mac_header(skb), + ETH_HLEN + sizeof(struct ipv6hdr)); + skb->data += hophdr_len; + skb->len -= hophdr_len; + skb->network_header += hophdr_len; + skb->mac_header += hophdr_len; + ipv6h = (struct ipv6hdr *)skb->data; + ipv6h->nexthdr = nexthdr; + } + nhoff = skb_network_header(skb) - skb_mac_header(skb); + if (unlikely(!pskb_may_pull(skb, sizeof(*ipv6h)))) + goto out; + + encap = SKB_GSO_CB(skb)->encap_level > 0; + if (encap) + features &= skb->dev->hw_enc_features; + SKB_GSO_CB(skb)->encap_level += sizeof(*ipv6h); + + ipv6h = ipv6_hdr(skb); + __skb_pull(skb, sizeof(*ipv6h)); + segs = ERR_PTR(-EPROTONOSUPPORT); + + proto = ipv6_gso_pull_exthdrs(skb, ipv6h->nexthdr); + + if (skb->encapsulation && + skb_shinfo(skb)->gso_type & (SKB_GSO_IPXIP4 | SKB_GSO_IPXIP6)) + udpfrag = proto == IPPROTO_UDP && encap && + (skb_shinfo(skb)->gso_type & SKB_GSO_UDP); + else + udpfrag = proto == IPPROTO_UDP && !skb->encapsulation && + (skb_shinfo(skb)->gso_type & SKB_GSO_UDP); + + ops = rcu_dereference(inet6_offloads[proto]); + if (likely(ops && ops->callbacks.gso_segment)) { + skb_reset_transport_header(skb); + segs = ops->callbacks.gso_segment(skb, features); + if (!segs) + skb->network_header = skb_mac_header(skb) + nhoff - skb->head; + } + + if (IS_ERR_OR_NULL(segs)) + goto out; + + gso_partial = !!(skb_shinfo(segs)->gso_type & SKB_GSO_PARTIAL); + + for (skb = segs; skb; skb = skb->next) { + ipv6h = (struct ipv6hdr *)(skb_mac_header(skb) + nhoff); + if (gso_partial && skb_is_gso(skb)) + payload_len = skb_shinfo(skb)->gso_size + + SKB_GSO_CB(skb)->data_offset + + skb->head - (unsigned char *)(ipv6h + 1); + else + payload_len = skb->len - nhoff - sizeof(*ipv6h); + ipv6h->payload_len = htons(payload_len); + skb->network_header = (u8 *)ipv6h - skb->head; + skb_reset_mac_len(skb); + + if (udpfrag) { + int err = ip6_find_1stfragopt(skb, &prevhdr); + if (err < 0) { + kfree_skb_list(segs); + return ERR_PTR(err); + } + fptr = (struct frag_hdr *)((u8 *)ipv6h + err); + fptr->frag_off = htons(offset); + if (skb->next) + fptr->frag_off |= htons(IP6_MF); + offset += (ntohs(ipv6h->payload_len) - + sizeof(struct frag_hdr)); + } + if (encap) + skb_reset_inner_headers(skb); + } + +out: + return segs; +} + +/* Return the total length of all the extension hdrs, following the same + * logic in ipv6_gso_pull_exthdrs() when parsing ext-hdrs. 
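[editor's note] Both ipv6_gso_pull_exthdrs() above and ipv6_exthdrs_len() just below rely on the fixed extension-header encoding: every option header begins with a next-header byte and a length byte, and occupies (hdrlen + 1) * 8 bytes, which is what ipv6_optlen() computes. A short sketch of that walk over a flat buffer; the sample header chain is fabricated and the stop condition is simplified to a fixed list instead of the offload-flag check the kernel uses.

/* Sketch of the extension-header walk behind ipv6_gso_pull_exthdrs() and
 * ipv6_exthdrs_len(): each extension header is (hdrlen + 1) * 8 bytes long.
 */
#include <stdint.h>
#include <stdio.h>

#define NEXTHDR_HOP   0	/* hop-by-hop options */
#define NEXTHDR_DEST 60	/* destination options */
#define NEXTHDR_TCP   6	/* upper-layer protocol, stop walking */

struct opt_hdr {
	uint8_t nexthdr;
	uint8_t hdrlen;	/* length in 8-byte units, not counting the first 8 */
};

static unsigned int exthdr_len(const uint8_t *buf, uint8_t proto, uint8_t *upper)
{
	unsigned int off = 0;

	while (proto == NEXTHDR_HOP || proto == NEXTHDR_DEST) {
		const struct opt_hdr *oh = (const struct opt_hdr *)(buf + off);

		proto = oh->nexthdr;
		off += (oh->hdrlen + 1u) * 8u;	/* ipv6_optlen() */
	}
	*upper = proto;
	return off;
}

int main(void)
{
	/* 8-byte hop-by-hop header followed by a 16-byte destination header */
	uint8_t chain[24] = { 0 };
	uint8_t upper;

	chain[0] = NEXTHDR_DEST; chain[1] = 0;	/* HBH: next = dst opts, 8 bytes */
	chain[8] = NEXTHDR_TCP;  chain[9] = 1;	/* DST: next = TCP, 16 bytes */

	printf("extension headers: %u bytes, upper layer %u\n",
	       exthdr_len(chain, NEXTHDR_HOP, &upper), upper);
	return 0;
}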
+ */ +static int ipv6_exthdrs_len(struct ipv6hdr *iph, + const struct net_offload **opps) +{ + struct ipv6_opt_hdr *opth = (void *)iph; + int len = 0, proto, optlen = sizeof(*iph); + + proto = iph->nexthdr; + for (;;) { + if (proto != NEXTHDR_HOP) { + *opps = rcu_dereference(inet6_offloads[proto]); + if (unlikely(!(*opps))) + break; + if (!((*opps)->flags & INET6_PROTO_GSO_EXTHDR)) + break; + } + opth = (void *)opth + optlen; + optlen = ipv6_optlen(opth); + len += optlen; + proto = opth->nexthdr; + } + return len; +} + +INDIRECT_CALLABLE_SCOPE struct sk_buff *ipv6_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + const struct net_offload *ops; + struct sk_buff *pp = NULL; + struct sk_buff *p; + struct ipv6hdr *iph; + unsigned int nlen; + unsigned int hlen; + unsigned int off; + u16 flush = 1; + int proto; + + off = skb_gro_offset(skb); + hlen = off + sizeof(*iph); + iph = skb_gro_header(skb, hlen, off); + if (unlikely(!iph)) + goto out; + + skb_set_network_header(skb, off); + skb_gro_pull(skb, sizeof(*iph)); + skb_set_transport_header(skb, skb_gro_offset(skb)); + + flush += ntohs(iph->payload_len) != skb_gro_len(skb); + + proto = iph->nexthdr; + ops = rcu_dereference(inet6_offloads[proto]); + if (!ops || !ops->callbacks.gro_receive) { + pskb_pull(skb, skb_gro_offset(skb)); + skb_gro_frag0_invalidate(skb); + proto = ipv6_gso_pull_exthdrs(skb, proto); + skb_gro_pull(skb, -skb_transport_offset(skb)); + skb_reset_transport_header(skb); + __skb_push(skb, skb_gro_offset(skb)); + + ops = rcu_dereference(inet6_offloads[proto]); + if (!ops || !ops->callbacks.gro_receive) + goto out; + + iph = ipv6_hdr(skb); + } + + NAPI_GRO_CB(skb)->proto = proto; + + flush--; + nlen = skb_network_header_len(skb); + + list_for_each_entry(p, head, list) { + const struct ipv6hdr *iph2; + __be32 first_word; /* */ + + if (!NAPI_GRO_CB(p)->same_flow) + continue; + + iph2 = (struct ipv6hdr *)(p->data + off); + first_word = *(__be32 *)iph ^ *(__be32 *)iph2; + + /* All fields must match except length and Traffic Class. + * XXX skbs on the gro_list have all been parsed and pulled + * already so we don't need to compare nlen + * (nlen != (sizeof(*iph2) + ipv6_exthdrs_len(iph2, &ops))) + * memcmp() alone below is sufficient, right? + */ + if ((first_word & htonl(0xF00FFFFF)) || + !ipv6_addr_equal(&iph->saddr, &iph2->saddr) || + !ipv6_addr_equal(&iph->daddr, &iph2->daddr) || + iph->nexthdr != iph2->nexthdr) { +not_same_flow: + NAPI_GRO_CB(p)->same_flow = 0; + continue; + } + if (unlikely(nlen > sizeof(struct ipv6hdr))) { + if (memcmp(iph + 1, iph2 + 1, + nlen - sizeof(struct ipv6hdr))) + goto not_same_flow; + } + /* flush if Traffic Class fields are different */ + NAPI_GRO_CB(p)->flush |= !!((first_word & htonl(0x0FF00000)) | + (__force __be32)(iph->hop_limit ^ iph2->hop_limit)); + NAPI_GRO_CB(p)->flush |= flush; + + /* If the previous IP ID value was based on an atomic + * datagram we can overwrite the value and ignore it. 
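[editor's note] The same-flow test in ipv6_gro_receive() above XORs the first 32 bits of the two IPv6 headers and then masks the result: htonl(0xF00FFFFF) keeps the version and flow label, which must match for the packets to belong to one flow, while htonl(0x0FF00000) isolates the traffic class, where a mismatch only forces a flush. A sketch of that field layout, assuming the RFC 8200 ordering of version (4 bits), traffic class (8 bits) and flow label (20 bits); the sample values are invented.

/* Sketch of the first-word masks in ipv6_gro_receive().  The leading 32 bits
 * of an IPv6 header are version:4 | traffic class:8 | flow label:20, so after
 * XORing two headers, 0xF00FFFFF covers version + flow label and 0x0FF00000
 * covers the traffic class.
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

static uint32_t first_word(unsigned int version, unsigned int tclass,
			   uint32_t flowlabel)
{
	return htonl((version << 28) | (tclass << 20) | (flowlabel & 0xFFFFF));
}

int main(void)
{
	uint32_t a = first_word(6, 0x00, 0x12345);
	uint32_t b = first_word(6, 0x2e, 0x12345);	/* different DSCP only */
	uint32_t c = first_word(6, 0x00, 0x54321);	/* different flow label */

	printf("a^b & F00FFFFF = %08x (same flow, but flush)\n",
	       ntohl((a ^ b) & htonl(0xF00FFFFF)));
	printf("a^b & 0FF00000 = %08x (traffic class differs)\n",
	       ntohl((a ^ b) & htonl(0x0FF00000)));
	printf("a^c & F00FFFFF = %08x (not the same flow)\n",
	       ntohl((a ^ c) & htonl(0xF00FFFFF)));
	return 0;
}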
+ */ + if (NAPI_GRO_CB(skb)->is_atomic) + NAPI_GRO_CB(p)->flush_id = 0; + } + + NAPI_GRO_CB(skb)->is_atomic = true; + NAPI_GRO_CB(skb)->flush |= flush; + + skb_gro_postpull_rcsum(skb, iph, nlen); + + pp = indirect_call_gro_receive_l4(tcp6_gro_receive, udp6_gro_receive, + ops->callbacks.gro_receive, head, skb); + +out: + skb_gro_flush_final(skb, pp, flush); + + return pp; +} + +static struct sk_buff *sit_ip6ip6_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + /* Common GRO receive for SIT and IP6IP6 */ + + if (NAPI_GRO_CB(skb)->encap_mark) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + NAPI_GRO_CB(skb)->encap_mark = 1; + + return ipv6_gro_receive(head, skb); +} + +static struct sk_buff *ip4ip6_gro_receive(struct list_head *head, + struct sk_buff *skb) +{ + /* Common GRO receive for SIT and IP6IP6 */ + + if (NAPI_GRO_CB(skb)->encap_mark) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + NAPI_GRO_CB(skb)->encap_mark = 1; + + return inet_gro_receive(head, skb); +} + +INDIRECT_CALLABLE_SCOPE int ipv6_gro_complete(struct sk_buff *skb, int nhoff) +{ + const struct net_offload *ops; + struct ipv6hdr *iph; + int err = -ENOSYS; + u32 payload_len; + + if (skb->encapsulation) { + skb_set_inner_protocol(skb, cpu_to_be16(ETH_P_IPV6)); + skb_set_inner_network_header(skb, nhoff); + } + + payload_len = skb->len - nhoff - sizeof(*iph); + if (unlikely(payload_len > IPV6_MAXPLEN)) { + struct hop_jumbo_hdr *hop_jumbo; + int hoplen = sizeof(*hop_jumbo); + + /* Move network header left */ + memmove(skb_mac_header(skb) - hoplen, skb_mac_header(skb), + skb->transport_header - skb->mac_header); + skb->data -= hoplen; + skb->len += hoplen; + skb->mac_header -= hoplen; + skb->network_header -= hoplen; + iph = (struct ipv6hdr *)(skb->data + nhoff); + hop_jumbo = (struct hop_jumbo_hdr *)(iph + 1); + + /* Build hop-by-hop options */ + hop_jumbo->nexthdr = iph->nexthdr; + hop_jumbo->hdrlen = 0; + hop_jumbo->tlv_type = IPV6_TLV_JUMBO; + hop_jumbo->tlv_len = 4; + hop_jumbo->jumbo_payload_len = htonl(payload_len + hoplen); + + iph->nexthdr = NEXTHDR_HOP; + iph->payload_len = 0; + } else { + iph = (struct ipv6hdr *)(skb->data + nhoff); + iph->payload_len = htons(payload_len); + } + + nhoff += sizeof(*iph) + ipv6_exthdrs_len(iph, &ops); + if (WARN_ON(!ops || !ops->callbacks.gro_complete)) + goto out; + + err = INDIRECT_CALL_L4(ops->callbacks.gro_complete, tcp6_gro_complete, + udp6_gro_complete, skb, nhoff); + +out: + return err; +} + +static int sit_gro_complete(struct sk_buff *skb, int nhoff) +{ + skb->encapsulation = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP4; + return ipv6_gro_complete(skb, nhoff); +} + +static int ip6ip6_gro_complete(struct sk_buff *skb, int nhoff) +{ + skb->encapsulation = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP6; + return ipv6_gro_complete(skb, nhoff); +} + +static int ip4ip6_gro_complete(struct sk_buff *skb, int nhoff) +{ + skb->encapsulation = 1; + skb_shinfo(skb)->gso_type |= SKB_GSO_IPXIP6; + return inet_gro_complete(skb, nhoff); +} + +static struct packet_offload ipv6_packet_offload __read_mostly = { + .type = cpu_to_be16(ETH_P_IPV6), + .callbacks = { + .gso_segment = ipv6_gso_segment, + .gro_receive = ipv6_gro_receive, + .gro_complete = ipv6_gro_complete, + }, +}; + +static struct sk_buff *sit_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP4)) + return ERR_PTR(-EINVAL); + + return ipv6_gso_segment(skb, features); +} + +static struct sk_buff *ip4ip6_gso_segment(struct sk_buff *skb, + 
netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP6)) + return ERR_PTR(-EINVAL); + + return inet_gso_segment(skb, features); +} + +static struct sk_buff *ip6ip6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_IPXIP6)) + return ERR_PTR(-EINVAL); + + return ipv6_gso_segment(skb, features); +} + +static const struct net_offload sit_offload = { + .callbacks = { + .gso_segment = sit_gso_segment, + .gro_receive = sit_ip6ip6_gro_receive, + .gro_complete = sit_gro_complete, + }, +}; + +static const struct net_offload ip4ip6_offload = { + .callbacks = { + .gso_segment = ip4ip6_gso_segment, + .gro_receive = ip4ip6_gro_receive, + .gro_complete = ip4ip6_gro_complete, + }, +}; + +static const struct net_offload ip6ip6_offload = { + .callbacks = { + .gso_segment = ip6ip6_gso_segment, + .gro_receive = sit_ip6ip6_gro_receive, + .gro_complete = ip6ip6_gro_complete, + }, +}; +static int __init ipv6_offload_init(void) +{ + + if (tcpv6_offload_init() < 0) + pr_crit("%s: Cannot add TCP protocol offload\n", __func__); + if (ipv6_exthdrs_offload_init() < 0) + pr_crit("%s: Cannot add EXTHDRS protocol offload\n", __func__); + + dev_add_offload(&ipv6_packet_offload); + + inet_add_offload(&sit_offload, IPPROTO_IPV6); + inet6_add_offload(&ip6ip6_offload, IPPROTO_IPV6); + inet6_add_offload(&ip4ip6_offload, IPPROTO_IPIP); + + return 0; +} + +fs_initcall(ipv6_offload_init); diff --git a/net/ipv6/ip6_offload.h b/net/ipv6/ip6_offload.h new file mode 100644 index 000000000..e76898760 --- /dev/null +++ b/net/ipv6/ip6_offload.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * IPV6 GSO/GRO offload support + * Linux INET6 implementation + */ + +#ifndef __ip6_offload_h +#define __ip6_offload_h + +int ipv6_exthdrs_offload_init(void); +int udpv6_offload_init(void); +int udpv6_offload_exit(void); +int tcpv6_offload_init(void); + +#endif diff --git a/net/ipv6/ip6_output.c b/net/ipv6/ip6_output.c new file mode 100644 index 000000000..e9ae084d0 --- /dev/null +++ b/net/ipv6/ip6_output.c @@ -0,0 +1,2084 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 output functions + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on linux/net/ipv4/ip_output.c + * + * Changes: + * A.N.Kuznetsov : airthmetics in fragmentation. + * extension headers are implemented. + * route changes now work. + * ip6_forward does not confuse sniffers. + * etc. + * + * H. von Brand : Added missing #include + * Imran Patel : frag id should be in NBO + * Kazunori MIYAZAWA @USAGI + * : add ip6_append_data and related functions + * for datagram xmit + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int ip6_finish_output2(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct net_device *dev = dst->dev; + struct inet6_dev *idev = ip6_dst_idev(dst); + unsigned int hh_len = LL_RESERVED_SPACE(dev); + const struct in6_addr *daddr, *nexthop; + struct ipv6hdr *hdr; + struct neighbour *neigh; + int ret; + + /* Be paranoid, rather than too clever. 
*/ + if (unlikely(hh_len > skb_headroom(skb)) && dev->header_ops) { + skb = skb_expand_head(skb, hh_len); + if (!skb) { + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTDISCARDS); + return -ENOMEM; + } + } + + hdr = ipv6_hdr(skb); + daddr = &hdr->daddr; + if (ipv6_addr_is_multicast(daddr)) { + if (!(dev->flags & IFF_LOOPBACK) && sk_mc_loop(sk) && + ((mroute6_is_socket(net, skb) && + !(IP6CB(skb)->flags & IP6SKB_FORWARDED)) || + ipv6_chk_mcast_addr(dev, daddr, &hdr->saddr))) { + struct sk_buff *newskb = skb_clone(skb, GFP_ATOMIC); + + /* Do not check for IFF_ALLMULTI; multicast routing + is not supported in any case. + */ + if (newskb) + NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, + net, sk, newskb, NULL, newskb->dev, + dev_loopback_xmit); + + if (hdr->hop_limit == 0) { + IP6_INC_STATS(net, idev, + IPSTATS_MIB_OUTDISCARDS); + kfree_skb(skb); + return 0; + } + } + + IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_OUTMCAST, skb->len); + if (IPV6_ADDR_MC_SCOPE(daddr) <= IPV6_ADDR_SCOPE_NODELOCAL && + !(dev->flags & IFF_LOOPBACK)) { + kfree_skb(skb); + return 0; + } + } + + if (lwtunnel_xmit_redirect(dst->lwtstate)) { + int res = lwtunnel_xmit(skb); + + if (res != LWTUNNEL_XMIT_CONTINUE) + return res; + } + + rcu_read_lock(); + nexthop = rt6_nexthop((struct rt6_info *)dst, daddr); + neigh = __ipv6_neigh_lookup_noref(dev, nexthop); + + if (unlikely(IS_ERR_OR_NULL(neigh))) { + if (unlikely(!neigh)) + neigh = __neigh_create(&nd_tbl, nexthop, dev, false); + if (IS_ERR(neigh)) { + rcu_read_unlock(); + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTNOROUTES); + kfree_skb_reason(skb, SKB_DROP_REASON_NEIGH_CREATEFAIL); + return -EINVAL; + } + } + sock_confirm_neigh(skb, neigh); + ret = neigh_output(neigh, skb, false); + rcu_read_unlock(); + return ret; +} + +static int +ip6_finish_output_gso_slowpath_drop(struct net *net, struct sock *sk, + struct sk_buff *skb, unsigned int mtu) +{ + struct sk_buff *segs, *nskb; + netdev_features_t features; + int ret = 0; + + /* Please see corresponding comment in ip_finish_output_gso + * describing the cases where GSO segment length exceeds the + * egress MTU. + */ + features = netif_skb_features(skb); + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + if (IS_ERR_OR_NULL(segs)) { + kfree_skb(skb); + return -ENOMEM; + } + + consume_skb(skb); + + skb_list_walk_safe(segs, segs, nskb) { + int err; + + skb_mark_not_on_list(segs); + /* Last GSO segment can be smaller than gso_size (and MTU). + * Adding a fragment header would produce an "atomic fragment", + * which is considered harmful (RFC-8021). Avoid that. + */ + err = segs->len > mtu ? 
+ ip6_fragment(net, sk, segs, ip6_finish_output2) : + ip6_finish_output2(net, sk, segs); + if (err && ret == 0) + ret = err; + } + + return ret; +} + +static int __ip6_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + unsigned int mtu; + +#if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM) + /* Policy lookup after SNAT yielded a new policy */ + if (skb_dst(skb)->xfrm) { + IP6CB(skb)->flags |= IP6SKB_REROUTED; + return dst_output(net, sk, skb); + } +#endif + + mtu = ip6_skb_dst_mtu(skb); + if (skb_is_gso(skb) && + !(IP6CB(skb)->flags & IP6SKB_FAKEJUMBO) && + !skb_gso_validate_network_len(skb, mtu)) + return ip6_finish_output_gso_slowpath_drop(net, sk, skb, mtu); + + if ((skb->len > mtu && !skb_is_gso(skb)) || + dst_allfrag(skb_dst(skb)) || + (IP6CB(skb)->frag_max_size && skb->len > IP6CB(skb)->frag_max_size)) + return ip6_fragment(net, sk, skb, ip6_finish_output2); + else + return ip6_finish_output2(net, sk, skb); +} + +static int ip6_finish_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + int ret; + + ret = BPF_CGROUP_RUN_PROG_INET_EGRESS(sk, skb); + switch (ret) { + case NET_XMIT_SUCCESS: + case NET_XMIT_CN: + return __ip6_finish_output(net, sk, skb) ? : ret; + default: + kfree_skb_reason(skb, SKB_DROP_REASON_BPF_CGROUP_EGRESS); + return ret; + } +} + +int ip6_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb_dst(skb)->dev, *indev = skb->dev; + struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb)); + + skb->protocol = htons(ETH_P_IPV6); + skb->dev = dev; + + if (unlikely(idev->cnf.disable_ipv6)) { + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTDISCARDS); + kfree_skb_reason(skb, SKB_DROP_REASON_IPV6DISABLED); + return 0; + } + + return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, + net, sk, skb, indev, dev, + ip6_finish_output, + !(IP6CB(skb)->flags & IP6SKB_REROUTED)); +} +EXPORT_SYMBOL(ip6_output); + +bool ip6_autoflowlabel(struct net *net, const struct ipv6_pinfo *np) +{ + if (!np->autoflowlabel_set) + return ip6_default_np_autolabel(net); + else + return np->autoflowlabel; +} + +/* + * xmit an sk_buff (used by TCP, SCTP and DCCP) + * Note : socket lock is not held for SYNACK packets, but might be modified + * by calls to skb_set_owner_w() and ipv6_local_error(), + * which are using proper atomic operations or spinlocks. 
+ */ +int ip6_xmit(const struct sock *sk, struct sk_buff *skb, struct flowi6 *fl6, + __u32 mark, struct ipv6_txoptions *opt, int tclass, u32 priority) +{ + struct net *net = sock_net(sk); + const struct ipv6_pinfo *np = inet6_sk(sk); + struct in6_addr *first_hop = &fl6->daddr; + struct dst_entry *dst = skb_dst(skb); + struct net_device *dev = dst->dev; + struct inet6_dev *idev = ip6_dst_idev(dst); + struct hop_jumbo_hdr *hop_jumbo; + int hoplen = sizeof(*hop_jumbo); + unsigned int head_room; + struct ipv6hdr *hdr; + u8 proto = fl6->flowi6_proto; + int seg_len = skb->len; + int hlimit = -1; + u32 mtu; + + head_room = sizeof(struct ipv6hdr) + hoplen + LL_RESERVED_SPACE(dev); + if (opt) + head_room += opt->opt_nflen + opt->opt_flen; + + if (unlikely(head_room > skb_headroom(skb))) { + skb = skb_expand_head(skb, head_room); + if (!skb) { + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTDISCARDS); + return -ENOBUFS; + } + } + + if (opt) { + seg_len += opt->opt_nflen + opt->opt_flen; + + if (opt->opt_flen) + ipv6_push_frag_opts(skb, opt, &proto); + + if (opt->opt_nflen) + ipv6_push_nfrag_opts(skb, opt, &proto, &first_hop, + &fl6->saddr); + } + + if (unlikely(seg_len > IPV6_MAXPLEN)) { + hop_jumbo = skb_push(skb, hoplen); + + hop_jumbo->nexthdr = proto; + hop_jumbo->hdrlen = 0; + hop_jumbo->tlv_type = IPV6_TLV_JUMBO; + hop_jumbo->tlv_len = 4; + hop_jumbo->jumbo_payload_len = htonl(seg_len + hoplen); + + proto = IPPROTO_HOPOPTS; + seg_len = 0; + IP6CB(skb)->flags |= IP6SKB_FAKEJUMBO; + } + + skb_push(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + hdr = ipv6_hdr(skb); + + /* + * Fill in the IPv6 header + */ + if (np) + hlimit = np->hop_limit; + if (hlimit < 0) + hlimit = ip6_dst_hoplimit(dst); + + ip6_flow_hdr(hdr, tclass, ip6_make_flowlabel(net, skb, fl6->flowlabel, + ip6_autoflowlabel(net, np), fl6)); + + hdr->payload_len = htons(seg_len); + hdr->nexthdr = proto; + hdr->hop_limit = hlimit; + + hdr->saddr = fl6->saddr; + hdr->daddr = *first_hop; + + skb->protocol = htons(ETH_P_IPV6); + skb->priority = priority; + skb->mark = mark; + + mtu = dst_mtu(dst); + if ((skb->len <= mtu) || skb->ignore_df || skb_is_gso(skb)) { + IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_OUT, skb->len); + + /* if egress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip6_out((struct sock *)sk, skb); + if (unlikely(!skb)) + return 0; + + /* hooks should never assume socket lock is held. 
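[editor's note] When seg_len exceeds IPV6_MAXPLEN, ip6_xmit() above inserts a hop-by-hop jumbo option and sets payload_len to 0; ipv6_gro_complete() earlier in this patch does the same when a coalesced packet grows past 64 KiB. A sketch of how that 8-byte option is laid out, using the field values visible in the code (hdrlen 0, tlv_len 4, length including the option header) and assuming the RFC 2675 option type 0xC2 for IPV6_TLV_JUMBO.

/* Sketch of the hop-by-hop jumbogram option built by ip6_xmit() and
 * ipv6_gro_complete(): an 8-byte hop-by-hop header whose only TLV carries the
 * real payload length, while the fixed IPv6 header's payload_len stays 0.
 * Assumption: IPV6_TLV_JUMBO is the RFC 2675 option type 0xC2 (194).
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct hop_jumbo {
	uint8_t  nexthdr;		/* protocol that follows the HBH header */
	uint8_t  hdrlen;		/* 0: this HBH header is 8 bytes long */
	uint8_t  tlv_type;		/* 0xC2, jumbo payload option */
	uint8_t  tlv_len;		/* 4: option data is a 32-bit length */
	uint32_t jumbo_payload_len;	/* big endian, includes the HBH header */
} __attribute__((packed));

static void build_hop_jumbo(struct hop_jumbo *h, uint8_t proto, uint32_t payload)
{
	h->nexthdr = proto;
	h->hdrlen = 0;
	h->tlv_type = 0xC2;
	h->tlv_len = 4;
	h->jumbo_payload_len = htonl(payload + sizeof(*h));
}

int main(void)
{
	struct hop_jumbo h;
	uint8_t raw[8];

	build_hop_jumbo(&h, 6 /* TCP */, 100000);
	memcpy(raw, &h, sizeof(raw));

	for (size_t i = 0; i < sizeof(raw); i++)
		printf("%02x ", raw[i]);
	printf("\n");
	return 0;
}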
+ * we promote our socket to non const + */ + return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + net, (struct sock *)sk, skb, NULL, dev, + dst_output); + } + + skb->dev = dev; + /* ipv6_local_error() does not require socket lock, + * we promote our socket to non const + */ + ipv6_local_error((struct sock *)sk, EMSGSIZE, fl6, mtu); + + IP6_INC_STATS(net, idev, IPSTATS_MIB_FRAGFAILS); + kfree_skb(skb); + return -EMSGSIZE; +} +EXPORT_SYMBOL(ip6_xmit); + +static int ip6_call_ra_chain(struct sk_buff *skb, int sel) +{ + struct ip6_ra_chain *ra; + struct sock *last = NULL; + + read_lock(&ip6_ra_lock); + for (ra = ip6_ra_chain; ra; ra = ra->next) { + struct sock *sk = ra->sk; + if (sk && ra->sel == sel && + (!sk->sk_bound_dev_if || + sk->sk_bound_dev_if == skb->dev->ifindex)) { + struct ipv6_pinfo *np = inet6_sk(sk); + + if (np && np->rtalert_isolate && + !net_eq(sock_net(sk), dev_net(skb->dev))) { + continue; + } + if (last) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) + rawv6_rcv(last, skb2); + } + last = sk; + } + } + + if (last) { + rawv6_rcv(last, skb); + read_unlock(&ip6_ra_lock); + return 1; + } + read_unlock(&ip6_ra_lock); + return 0; +} + +static int ip6_forward_proxy_check(struct sk_buff *skb) +{ + struct ipv6hdr *hdr = ipv6_hdr(skb); + u8 nexthdr = hdr->nexthdr; + __be16 frag_off; + int offset; + + if (ipv6_ext_hdr(nexthdr)) { + offset = ipv6_skip_exthdr(skb, sizeof(*hdr), &nexthdr, &frag_off); + if (offset < 0) + return 0; + } else + offset = sizeof(struct ipv6hdr); + + if (nexthdr == IPPROTO_ICMPV6) { + struct icmp6hdr *icmp6; + + if (!pskb_may_pull(skb, (skb_network_header(skb) + + offset + 1 - skb->data))) + return 0; + + icmp6 = (struct icmp6hdr *)(skb_network_header(skb) + offset); + + switch (icmp6->icmp6_type) { + case NDISC_ROUTER_SOLICITATION: + case NDISC_ROUTER_ADVERTISEMENT: + case NDISC_NEIGHBOUR_SOLICITATION: + case NDISC_NEIGHBOUR_ADVERTISEMENT: + case NDISC_REDIRECT: + /* For reaction involving unicast neighbor discovery + * message destined to the proxied address, pass it to + * input function. + */ + return 1; + default: + break; + } + } + + /* + * The proxying router can't forward traffic sent to a link-local + * address, so signal the sender and discard the packet. This + * behavior is clarified by the MIPv6 specification. 
+ */ + if (ipv6_addr_type(&hdr->daddr) & IPV6_ADDR_LINKLOCAL) { + dst_link_failure(skb); + return -1; + } + + return 0; +} + +static inline int ip6_forward_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + + __IP6_INC_STATS(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTFORWDATAGRAMS); + __IP6_ADD_STATS(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTOCTETS, skb->len); + +#ifdef CONFIG_NET_SWITCHDEV + if (skb->offload_l3_fwd_mark) { + consume_skb(skb); + return 0; + } +#endif + + skb_clear_tstamp(skb); + return dst_output(net, sk, skb); +} + +static bool ip6_pkt_too_big(const struct sk_buff *skb, unsigned int mtu) +{ + if (skb->len <= mtu) + return false; + + /* ipv6 conntrack defrag sets max_frag_size + ignore_df */ + if (IP6CB(skb)->frag_max_size && IP6CB(skb)->frag_max_size > mtu) + return true; + + if (skb->ignore_df) + return false; + + if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) + return false; + + return true; +} + +int ip6_forward(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct ipv6hdr *hdr = ipv6_hdr(skb); + struct inet6_skb_parm *opt = IP6CB(skb); + struct net *net = dev_net(dst->dev); + struct inet6_dev *idev; + SKB_DR(reason); + u32 mtu; + + idev = __in6_dev_get_safely(dev_get_by_index_rcu(net, IP6CB(skb)->iif)); + if (net->ipv6.devconf_all->forwarding == 0) + goto error; + + if (skb->pkt_type != PACKET_HOST) + goto drop; + + if (unlikely(skb->sk)) + goto drop; + + if (skb_warn_if_lro(skb)) + goto drop; + + if (!net->ipv6.devconf_all->disable_policy && + (!idev || !idev->cnf.disable_policy) && + !xfrm6_policy_check(NULL, XFRM_POLICY_FWD, skb)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); + goto drop; + } + + skb_forward_csum(skb); + + /* + * We DO NOT make any processing on + * RA packets, pushing them to user level AS IS + * without ane WARRANTY that application will be able + * to interpret them. The reason is that we + * cannot make anything clever here. + * + * We are not end-node, so that if packet contains + * AH/ESP, we cannot make anything. + * Defragmentation also would be mistake, RA packets + * cannot be fragmented, because there is no warranty + * that different fragments will go along one path. --ANK + */ + if (unlikely(opt->flags & IP6SKB_ROUTERALERT)) { + if (ip6_call_ra_chain(skb, ntohs(opt->ra))) + return 0; + } + + /* + * check and decrement ttl + */ + if (hdr->hop_limit <= 1) { + icmpv6_send(skb, ICMPV6_TIME_EXCEED, ICMPV6_EXC_HOPLIMIT, 0); + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + + kfree_skb_reason(skb, SKB_DROP_REASON_IP_INHDR); + return -ETIMEDOUT; + } + + /* XXX: idev->cnf.proxy_ndp? */ + if (net->ipv6.devconf_all->proxy_ndp && + pneigh_lookup(&nd_tbl, net, &hdr->daddr, skb->dev, 0)) { + int proxied = ip6_forward_proxy_check(skb); + if (proxied > 0) { + /* It's tempting to decrease the hop limit + * here by 1, as we do at the end of the + * function too. + * + * But that would be incorrect, as proxying is + * not forwarding. The ip6_input function + * will handle this packet locally, and it + * depends on the hop limit being unchanged. + * + * One example is the NDP hop limit, that + * always has to stay 255, but other would be + * similar checks around RA packets, where the + * user can even change the desired limit. 
+ */ + return ip6_input(skb); + } else if (proxied < 0) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); + goto drop; + } + } + + if (!xfrm6_route_forward(skb)) { + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INDISCARDS); + SKB_DR_SET(reason, XFRM_POLICY); + goto drop; + } + dst = skb_dst(skb); + + /* IPv6 specs say nothing about it, but it is clear that we cannot + send redirects to source routed frames. + We don't send redirects to frames decapsulated from IPsec. + */ + if (IP6CB(skb)->iif == dst->dev->ifindex && + opt->srcrt == 0 && !skb_sec_path(skb)) { + struct in6_addr *target = NULL; + struct inet_peer *peer; + struct rt6_info *rt; + + /* + * incoming and outgoing devices are the same + * send a redirect. + */ + + rt = (struct rt6_info *) dst; + if (rt->rt6i_flags & RTF_GATEWAY) + target = &rt->rt6i_gateway; + else + target = &hdr->daddr; + + peer = inet_getpeer_v6(net->ipv6.peers, &hdr->daddr, 1); + + /* Limit redirects both by destination (here) + and by source (inside ndisc_send_redirect) + */ + if (inet_peer_xrlim_allow(peer, 1*HZ)) + ndisc_send_redirect(skb, target); + if (peer) + inet_putpeer(peer); + } else { + int addrtype = ipv6_addr_type(&hdr->saddr); + + /* This check is security critical. */ + if (addrtype == IPV6_ADDR_ANY || + addrtype & (IPV6_ADDR_MULTICAST | IPV6_ADDR_LOOPBACK)) + goto error; + if (addrtype & IPV6_ADDR_LINKLOCAL) { + icmpv6_send(skb, ICMPV6_DEST_UNREACH, + ICMPV6_NOT_NEIGHBOUR, 0); + goto error; + } + } + + mtu = ip6_dst_mtu_maybe_forward(dst, true); + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + + if (ip6_pkt_too_big(skb, mtu)) { + /* Again, force OUTPUT device used as source address */ + skb->dev = dst->dev; + icmpv6_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INTOOBIGERRORS); + __IP6_INC_STATS(net, ip6_dst_idev(dst), + IPSTATS_MIB_FRAGFAILS); + kfree_skb_reason(skb, SKB_DROP_REASON_PKT_TOO_BIG); + return -EMSGSIZE; + } + + if (skb_cow(skb, dst->dev->hard_header_len)) { + __IP6_INC_STATS(net, ip6_dst_idev(dst), + IPSTATS_MIB_OUTDISCARDS); + goto drop; + } + + hdr = ipv6_hdr(skb); + + /* Mangling hops number delayed to point after skb COW */ + + hdr->hop_limit--; + + return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, + net, NULL, skb, skb->dev, dst->dev, + ip6_forward_finish); + +error: + __IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + SKB_DR_SET(reason, IP_INADDRERRORS); +drop: + kfree_skb_reason(skb, reason); + return -EINVAL; +} + +static void ip6_copy_metadata(struct sk_buff *to, struct sk_buff *from) +{ + to->pkt_type = from->pkt_type; + to->priority = from->priority; + to->protocol = from->protocol; + skb_dst_drop(to); + skb_dst_set(to, dst_clone(skb_dst(from))); + to->dev = from->dev; + to->mark = from->mark; + + skb_copy_hash(to, from); + +#ifdef CONFIG_NET_SCHED + to->tc_index = from->tc_index; +#endif + nf_copy(to, from); + skb_ext_copy(to, from); + skb_copy_secmark(to, from); +} + +int ip6_fraglist_init(struct sk_buff *skb, unsigned int hlen, u8 *prevhdr, + u8 nexthdr, __be32 frag_id, + struct ip6_fraglist_iter *iter) +{ + unsigned int first_len; + struct frag_hdr *fh; + + /* BUILD HEADER */ + *prevhdr = NEXTHDR_FRAGMENT; + iter->tmp_hdr = kmemdup(skb_network_header(skb), hlen, GFP_ATOMIC); + if (!iter->tmp_hdr) + return -ENOMEM; + + iter->frag = skb_shinfo(skb)->frag_list; + skb_frag_list_init(skb); + + iter->offset = 0; + iter->hlen = hlen; + iter->frag_id = frag_id; + iter->nexthdr = nexthdr; + + __skb_pull(skb, hlen); + fh = __skb_push(skb, sizeof(struct frag_hdr)); + __skb_push(skb, hlen); + 
skb_reset_network_header(skb); + memcpy(skb_network_header(skb), iter->tmp_hdr, hlen); + + fh->nexthdr = nexthdr; + fh->reserved = 0; + fh->frag_off = htons(IP6_MF); + fh->identification = frag_id; + + first_len = skb_pagelen(skb); + skb->data_len = first_len - skb_headlen(skb); + skb->len = first_len; + ipv6_hdr(skb)->payload_len = htons(first_len - sizeof(struct ipv6hdr)); + + return 0; +} +EXPORT_SYMBOL(ip6_fraglist_init); + +void ip6_fraglist_prepare(struct sk_buff *skb, + struct ip6_fraglist_iter *iter) +{ + struct sk_buff *frag = iter->frag; + unsigned int hlen = iter->hlen; + struct frag_hdr *fh; + + frag->ip_summed = CHECKSUM_NONE; + skb_reset_transport_header(frag); + fh = __skb_push(frag, sizeof(struct frag_hdr)); + __skb_push(frag, hlen); + skb_reset_network_header(frag); + memcpy(skb_network_header(frag), iter->tmp_hdr, hlen); + iter->offset += skb->len - hlen - sizeof(struct frag_hdr); + fh->nexthdr = iter->nexthdr; + fh->reserved = 0; + fh->frag_off = htons(iter->offset); + if (frag->next) + fh->frag_off |= htons(IP6_MF); + fh->identification = iter->frag_id; + ipv6_hdr(frag)->payload_len = htons(frag->len - sizeof(struct ipv6hdr)); + ip6_copy_metadata(frag, skb); +} +EXPORT_SYMBOL(ip6_fraglist_prepare); + +void ip6_frag_init(struct sk_buff *skb, unsigned int hlen, unsigned int mtu, + unsigned short needed_tailroom, int hdr_room, u8 *prevhdr, + u8 nexthdr, __be32 frag_id, struct ip6_frag_state *state) +{ + state->prevhdr = prevhdr; + state->nexthdr = nexthdr; + state->frag_id = frag_id; + + state->hlen = hlen; + state->mtu = mtu; + + state->left = skb->len - hlen; /* Space per frame */ + state->ptr = hlen; /* Where to start from */ + + state->hroom = hdr_room; + state->troom = needed_tailroom; + + state->offset = 0; +} +EXPORT_SYMBOL(ip6_frag_init); + +struct sk_buff *ip6_frag_next(struct sk_buff *skb, struct ip6_frag_state *state) +{ + u8 *prevhdr = state->prevhdr, *fragnexthdr_offset; + struct sk_buff *frag; + struct frag_hdr *fh; + unsigned int len; + + len = state->left; + /* IF: it doesn't fit, use 'mtu' - the data space left */ + if (len > state->mtu) + len = state->mtu; + /* IF: we are not sending up to and including the packet end + then align the next start on an eight byte boundary */ + if (len < state->left) + len &= ~7; + + /* Allocate buffer */ + frag = alloc_skb(len + state->hlen + sizeof(struct frag_hdr) + + state->hroom + state->troom, GFP_ATOMIC); + if (!frag) + return ERR_PTR(-ENOMEM); + + /* + * Set up data on packet + */ + + ip6_copy_metadata(frag, skb); + skb_reserve(frag, state->hroom); + skb_put(frag, len + state->hlen + sizeof(struct frag_hdr)); + skb_reset_network_header(frag); + fh = (struct frag_hdr *)(skb_network_header(frag) + state->hlen); + frag->transport_header = (frag->network_header + state->hlen + + sizeof(struct frag_hdr)); + + /* + * Charge the memory for the fragment to any owner + * it might possess + */ + if (skb->sk) + skb_set_owner_w(frag, skb->sk); + + /* + * Copy the packet header into the new buffer. + */ + skb_copy_from_linear_data(skb, skb_network_header(frag), state->hlen); + + fragnexthdr_offset = skb_network_header(frag); + fragnexthdr_offset += prevhdr - skb_network_header(skb); + *fragnexthdr_offset = NEXTHDR_FRAGMENT; + + /* + * Build fragment header. + */ + fh->nexthdr = state->nexthdr; + fh->reserved = 0; + fh->identification = state->frag_id; + + /* + * Copy a block of the IP datagram. 
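[editor's note] ip6_frag_next() above caps each fragment at the per-fragment budget and rounds every non-final fragment's payload down to a multiple of 8 (len &= ~7); ip6_fragment() below derives that budget by subtracting the unfragmentable header length and the 8-byte fragment header from the path MTU. A small sketch of the arithmetic for a made-up packet and MTU:

/* Sketch of the fragment sizing done by ip6_fragment()/ip6_frag_next():
 * usable payload per fragment = (mtu - unfragmentable hlen - 8-byte fragment
 * header), and only the last fragment may carry an unaligned tail.  The
 * packet size and MTU below are example numbers.
 */
#include <stdio.h>

int main(void)
{
	const unsigned int mtu = 1280;	/* route MTU */
	const unsigned int hlen = 40;	/* IPv6 header, no extension headers */
	const unsigned int frag_hdr = 8;
	unsigned int left = 3000;	/* payload bytes past the IPv6 header */
	unsigned int offset = 0;
	unsigned int per_frag = mtu - hlen - frag_hdr;	/* 1232 here */

	while (left > 0) {
		unsigned int len = left > per_frag ? per_frag : left;

		if (len < left)
			len &= ~7u;	/* non-final fragments are 8-aligned */

		printf("fragment: offset %u, %u payload bytes, MF=%d\n",
		       offset, len, left - len > 0);
		offset += len;
		left -= len;
	}
	return 0;
}

With these numbers the packet splits into two 1232-byte fragments and a 536-byte tail, which is the shape the fast path in ip6_fragment() tries to preserve when a frag list already has the right geometry.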
+ */ + BUG_ON(skb_copy_bits(skb, state->ptr, skb_transport_header(frag), + len)); + state->left -= len; + + fh->frag_off = htons(state->offset); + if (state->left > 0) + fh->frag_off |= htons(IP6_MF); + ipv6_hdr(frag)->payload_len = htons(frag->len - sizeof(struct ipv6hdr)); + + state->ptr += len; + state->offset += len; + + return frag; +} +EXPORT_SYMBOL(ip6_frag_next); + +int ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + int (*output)(struct net *, struct sock *, struct sk_buff *)) +{ + struct sk_buff *frag; + struct rt6_info *rt = (struct rt6_info *)skb_dst(skb); + struct ipv6_pinfo *np = skb->sk && !dev_recursion_level() ? + inet6_sk(skb->sk) : NULL; + bool mono_delivery_time = skb->mono_delivery_time; + struct ip6_frag_state state; + unsigned int mtu, hlen, nexthdr_offset; + ktime_t tstamp = skb->tstamp; + int hroom, err = 0; + __be32 frag_id; + u8 *prevhdr, nexthdr = 0; + + err = ip6_find_1stfragopt(skb, &prevhdr); + if (err < 0) + goto fail; + hlen = err; + nexthdr = *prevhdr; + nexthdr_offset = prevhdr - skb_network_header(skb); + + mtu = ip6_skb_dst_mtu(skb); + + /* We must not fragment if the socket is set to force MTU discovery + * or if the skb it not generated by a local socket. + */ + if (unlikely(!skb->ignore_df && skb->len > mtu)) + goto fail_toobig; + + if (IP6CB(skb)->frag_max_size) { + if (IP6CB(skb)->frag_max_size > mtu) + goto fail_toobig; + + /* don't send fragments larger than what we received */ + mtu = IP6CB(skb)->frag_max_size; + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + } + + if (np && np->frag_size < mtu) { + if (np->frag_size) + mtu = np->frag_size; + } + if (mtu < hlen + sizeof(struct frag_hdr) + 8) + goto fail_toobig; + mtu -= hlen + sizeof(struct frag_hdr); + + frag_id = ipv6_select_ident(net, &ipv6_hdr(skb)->daddr, + &ipv6_hdr(skb)->saddr); + + if (skb->ip_summed == CHECKSUM_PARTIAL && + (err = skb_checksum_help(skb))) + goto fail; + + prevhdr = skb_network_header(skb) + nexthdr_offset; + hroom = LL_RESERVED_SPACE(rt->dst.dev); + if (skb_has_frag_list(skb)) { + unsigned int first_len = skb_pagelen(skb); + struct ip6_fraglist_iter iter; + struct sk_buff *frag2; + + if (first_len - hlen > mtu || + ((first_len - hlen) & 7) || + skb_cloned(skb) || + skb_headroom(skb) < (hroom + sizeof(struct frag_hdr))) + goto slow_path; + + skb_walk_frags(skb, frag) { + /* Correct geometry. */ + if (frag->len > mtu || + ((frag->len & 7) && frag->next) || + skb_headroom(frag) < (hlen + hroom + sizeof(struct frag_hdr))) + goto slow_path_clean; + + /* Partially cloned skb? */ + if (skb_shared(frag)) + goto slow_path_clean; + + BUG_ON(frag->sk); + if (skb->sk) { + frag->sk = skb->sk; + frag->destructor = sock_wfree; + } + skb->truesize -= frag->truesize; + } + + err = ip6_fraglist_init(skb, hlen, prevhdr, nexthdr, frag_id, + &iter); + if (err < 0) + goto fail; + + /* We prevent @rt from being freed. */ + rcu_read_lock(); + + for (;;) { + /* Prepare header of the next frame, + * before previous one went down. 
*/ + if (iter.frag) + ip6_fraglist_prepare(skb, &iter); + + skb_set_delivery_time(skb, tstamp, mono_delivery_time); + err = output(net, sk, skb); + if (!err) + IP6_INC_STATS(net, ip6_dst_idev(&rt->dst), + IPSTATS_MIB_FRAGCREATES); + + if (err || !iter.frag) + break; + + skb = ip6_fraglist_next(&iter); + } + + kfree(iter.tmp_hdr); + + if (err == 0) { + IP6_INC_STATS(net, ip6_dst_idev(&rt->dst), + IPSTATS_MIB_FRAGOKS); + rcu_read_unlock(); + return 0; + } + + kfree_skb_list(iter.frag); + + IP6_INC_STATS(net, ip6_dst_idev(&rt->dst), + IPSTATS_MIB_FRAGFAILS); + rcu_read_unlock(); + return err; + +slow_path_clean: + skb_walk_frags(skb, frag2) { + if (frag2 == frag) + break; + frag2->sk = NULL; + frag2->destructor = NULL; + skb->truesize += frag2->truesize; + } + } + +slow_path: + /* + * Fragment the datagram. + */ + + ip6_frag_init(skb, hlen, mtu, rt->dst.dev->needed_tailroom, + LL_RESERVED_SPACE(rt->dst.dev), prevhdr, nexthdr, frag_id, + &state); + + /* + * Keep copying data until we run out. + */ + + while (state.left > 0) { + frag = ip6_frag_next(skb, &state); + if (IS_ERR(frag)) { + err = PTR_ERR(frag); + goto fail; + } + + /* + * Put this fragment into the sending queue. + */ + skb_set_delivery_time(frag, tstamp, mono_delivery_time); + err = output(net, sk, frag); + if (err) + goto fail; + + IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_FRAGCREATES); + } + IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_FRAGOKS); + consume_skb(skb); + return err; + +fail_toobig: + if (skb->sk && dst_allfrag(skb_dst(skb))) + sk_gso_disable(skb->sk); + + icmpv6_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + err = -EMSGSIZE; + +fail: + IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_FRAGFAILS); + kfree_skb(skb); + return err; +} + +static inline int ip6_rt_check(const struct rt6key *rt_key, + const struct in6_addr *fl_addr, + const struct in6_addr *addr_cache) +{ + return (rt_key->plen != 128 || !ipv6_addr_equal(fl_addr, &rt_key->addr)) && + (!addr_cache || !ipv6_addr_equal(fl_addr, addr_cache)); +} + +static struct dst_entry *ip6_sk_dst_check(struct sock *sk, + struct dst_entry *dst, + const struct flowi6 *fl6) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct rt6_info *rt; + + if (!dst) + goto out; + + if (dst->ops->family != AF_INET6) { + dst_release(dst); + return NULL; + } + + rt = (struct rt6_info *)dst; + /* Yes, checking route validity in not connected + * case is not very simple. Take into account, + * that we do not support routing by source, TOS, + * and MSG_DONTROUTE --ANK (980726) + * + * 1. ip6_rt_check(): If route was host route, + * check that cached destination is current. + * If it is network route, we still may + * check its validity using saved pointer + * to the last used address: daddr_cache. + * We do not want to save whole address now, + * (because main consumer of this service + * is tcp, which has not this problem), + * so that the last trick works only on connected + * sockets. + * 2. oif also should be the same. 
+ */ + if (ip6_rt_check(&rt->rt6i_dst, &fl6->daddr, np->daddr_cache) || +#ifdef CONFIG_IPV6_SUBTREES + ip6_rt_check(&rt->rt6i_src, &fl6->saddr, np->saddr_cache) || +#endif + (fl6->flowi6_oif && fl6->flowi6_oif != dst->dev->ifindex)) { + dst_release(dst); + dst = NULL; + } + +out: + return dst; +} + +static int ip6_dst_lookup_tail(struct net *net, const struct sock *sk, + struct dst_entry **dst, struct flowi6 *fl6) +{ +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + struct neighbour *n; + struct rt6_info *rt; +#endif + int err; + int flags = 0; + + /* The correct way to handle this would be to do + * ip6_route_get_saddr, and then ip6_route_output; however, + * the route-specific preferred source forces the + * ip6_route_output call _before_ ip6_route_get_saddr. + * + * In source specific routing (no src=any default route), + * ip6_route_output will fail given src=any saddr, though, so + * that's why we try it again later. + */ + if (ipv6_addr_any(&fl6->saddr)) { + struct fib6_info *from; + struct rt6_info *rt; + + *dst = ip6_route_output(net, sk, fl6); + rt = (*dst)->error ? NULL : (struct rt6_info *)*dst; + + rcu_read_lock(); + from = rt ? rcu_dereference(rt->from) : NULL; + err = ip6_route_get_saddr(net, from, &fl6->daddr, + sk ? inet6_sk(sk)->srcprefs : 0, + &fl6->saddr); + rcu_read_unlock(); + + if (err) + goto out_err_release; + + /* If we had an erroneous initial result, pretend it + * never existed and let the SA-enabled version take + * over. + */ + if ((*dst)->error) { + dst_release(*dst); + *dst = NULL; + } + + if (fl6->flowi6_oif) + flags |= RT6_LOOKUP_F_IFACE; + } + + if (!*dst) + *dst = ip6_route_output_flags(net, sk, fl6, flags); + + err = (*dst)->error; + if (err) + goto out_err_release; + +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + /* + * Here if the dst entry we've looked up + * has a neighbour entry that is in the INCOMPLETE + * state and the src address from the flow is + * marked as OPTIMISTIC, we release the found + * dst entry and replace it instead with the + * dst entry of the nexthop router + */ + rt = (struct rt6_info *) *dst; + rcu_read_lock(); + n = __ipv6_neigh_lookup_noref(rt->dst.dev, + rt6_nexthop(rt, &fl6->daddr)); + err = n && !(READ_ONCE(n->nud_state) & NUD_VALID) ? -EINVAL : 0; + rcu_read_unlock(); + + if (err) { + struct inet6_ifaddr *ifp; + struct flowi6 fl_gw6; + int redirect; + + ifp = ipv6_get_ifaddr(net, &fl6->saddr, + (*dst)->dev, 1); + + redirect = (ifp && ifp->flags & IFA_F_OPTIMISTIC); + if (ifp) + in6_ifa_put(ifp); + + if (redirect) { + /* + * We need to get the dst entry for the + * default router instead + */ + dst_release(*dst); + memcpy(&fl_gw6, fl6, sizeof(struct flowi6)); + memset(&fl_gw6.daddr, 0, sizeof(struct in6_addr)); + *dst = ip6_route_output(net, sk, &fl_gw6); + err = (*dst)->error; + if (err) + goto out_err_release; + } + } +#endif + if (ipv6_addr_v4mapped(&fl6->saddr) && + !(ipv6_addr_v4mapped(&fl6->daddr) || ipv6_addr_any(&fl6->daddr))) { + err = -EAFNOSUPPORT; + goto out_err_release; + } + + return 0; + +out_err_release: + dst_release(*dst); + *dst = NULL; + + if (err == -ENETUNREACH) + IP6_INC_STATS(net, NULL, IPSTATS_MIB_OUTNOROUTES); + return err; +} + +/** + * ip6_dst_lookup - perform route lookup on flow + * @net: Network namespace to perform lookup in + * @sk: socket which provides route info + * @dst: pointer to dst_entry * for result + * @fl6: flow to lookup + * + * This function performs a route lookup on the given flow. + * + * It returns zero on success, or a standard errno code on error. 
+ */ +int ip6_dst_lookup(struct net *net, struct sock *sk, struct dst_entry **dst, + struct flowi6 *fl6) +{ + *dst = NULL; + return ip6_dst_lookup_tail(net, sk, dst, fl6); +} +EXPORT_SYMBOL_GPL(ip6_dst_lookup); + +/** + * ip6_dst_lookup_flow - perform route lookup on flow with ipsec + * @net: Network namespace to perform lookup in + * @sk: socket which provides route info + * @fl6: flow to lookup + * @final_dst: final destination address for ipsec lookup + * + * This function performs a route lookup on the given flow. + * + * It returns a valid dst pointer on success, or a pointer encoded + * error code. + */ +struct dst_entry *ip6_dst_lookup_flow(struct net *net, const struct sock *sk, struct flowi6 *fl6, + const struct in6_addr *final_dst) +{ + struct dst_entry *dst = NULL; + int err; + + err = ip6_dst_lookup_tail(net, sk, &dst, fl6); + if (err) + return ERR_PTR(err); + if (final_dst) + fl6->daddr = *final_dst; + + return xfrm_lookup_route(net, dst, flowi6_to_flowi(fl6), sk, 0); +} +EXPORT_SYMBOL_GPL(ip6_dst_lookup_flow); + +/** + * ip6_sk_dst_lookup_flow - perform socket cached route lookup on flow + * @sk: socket which provides the dst cache and route info + * @fl6: flow to lookup + * @final_dst: final destination address for ipsec lookup + * @connected: whether @sk is connected or not + * + * This function performs a route lookup on the given flow with the + * possibility of using the cached route in the socket if it is valid. + * It will take the socket dst lock when operating on the dst cache. + * As a result, this function can only be used in process context. + * + * In addition, for a connected socket, cache the dst in the socket + * if the current cache is not valid. + * + * It returns a valid dst pointer on success, or a pointer encoded + * error code. + */ +struct dst_entry *ip6_sk_dst_lookup_flow(struct sock *sk, struct flowi6 *fl6, + const struct in6_addr *final_dst, + bool connected) +{ + struct dst_entry *dst = sk_dst_check(sk, inet6_sk(sk)->dst_cookie); + + dst = ip6_sk_dst_check(sk, dst, fl6); + if (dst) + return dst; + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_dst); + if (connected && !IS_ERR(dst)) + ip6_sk_dst_store_flow(sk, dst_clone(dst), fl6); + + return dst; +} +EXPORT_SYMBOL_GPL(ip6_sk_dst_lookup_flow); + +/** + * ip6_dst_lookup_tunnel - perform route lookup on tunnel + * @skb: Packet for which lookup is done + * @dev: Tunnel device + * @net: Network namespace of tunnel device + * @sock: Socket which provides route info + * @saddr: Memory to store the src ip address + * @info: Tunnel information + * @protocol: IP protocol + * @use_cache: Flag to enable cache usage + * This function performs a route lookup on a tunnel + * + * It returns a valid dst pointer and stores src address to be used in + * tunnel in param saddr on success, else a pointer encoded error code. 
+ */ + +struct dst_entry *ip6_dst_lookup_tunnel(struct sk_buff *skb, + struct net_device *dev, + struct net *net, + struct socket *sock, + struct in6_addr *saddr, + const struct ip_tunnel_info *info, + u8 protocol, + bool use_cache) +{ + struct dst_entry *dst = NULL; +#ifdef CONFIG_DST_CACHE + struct dst_cache *dst_cache; +#endif + struct flowi6 fl6; + __u8 prio; + +#ifdef CONFIG_DST_CACHE + dst_cache = (struct dst_cache *)&info->dst_cache; + if (use_cache) { + dst = dst_cache_get_ip6(dst_cache, saddr); + if (dst) + return dst; + } +#endif + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_mark = skb->mark; + fl6.flowi6_proto = protocol; + fl6.daddr = info->key.u.ipv6.dst; + fl6.saddr = info->key.u.ipv6.src; + prio = info->key.tos; + fl6.flowlabel = ip6_make_flowinfo(prio, info->key.label); + + dst = ipv6_stub->ipv6_dst_lookup_flow(net, sock->sk, &fl6, + NULL); + if (IS_ERR(dst)) { + netdev_dbg(dev, "no route to %pI6\n", &fl6.daddr); + return ERR_PTR(-ENETUNREACH); + } + if (dst->dev == dev) { /* is this necessary? */ + netdev_dbg(dev, "circular route to %pI6\n", &fl6.daddr); + dst_release(dst); + return ERR_PTR(-ELOOP); + } +#ifdef CONFIG_DST_CACHE + if (use_cache) + dst_cache_set_ip6(dst_cache, dst, &fl6.saddr); +#endif + *saddr = fl6.saddr; + return dst; +} +EXPORT_SYMBOL_GPL(ip6_dst_lookup_tunnel); + +static inline struct ipv6_opt_hdr *ip6_opt_dup(struct ipv6_opt_hdr *src, + gfp_t gfp) +{ + return src ? kmemdup(src, (src->hdrlen + 1) * 8, gfp) : NULL; +} + +static inline struct ipv6_rt_hdr *ip6_rthdr_dup(struct ipv6_rt_hdr *src, + gfp_t gfp) +{ + return src ? kmemdup(src, (src->hdrlen + 1) * 8, gfp) : NULL; +} + +static void ip6_append_data_mtu(unsigned int *mtu, + int *maxfraglen, + unsigned int fragheaderlen, + struct sk_buff *skb, + struct rt6_info *rt, + unsigned int orig_mtu) +{ + if (!(rt->dst.flags & DST_XFRM_TUNNEL)) { + if (!skb) { + /* first fragment, reserve header_len */ + *mtu = orig_mtu - rt->dst.header_len; + + } else { + /* + * this fragment is not first, the headers + * space is regarded as data space. + */ + *mtu = orig_mtu; + } + *maxfraglen = ((*mtu - fragheaderlen) & ~7) + + fragheaderlen - sizeof(struct frag_hdr); + } +} + +static int ip6_setup_cork(struct sock *sk, struct inet_cork_full *cork, + struct inet6_cork *v6_cork, struct ipcm6_cookie *ipc6, + struct rt6_info *rt) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + unsigned int mtu; + struct ipv6_txoptions *nopt, *opt = ipc6->opt; + + /* callers pass dst together with a reference, set it first so + * ip6_cork_release() can put it down even in case of an error. 
+ */ + cork->base.dst = &rt->dst; + + /* + * setup for corking + */ + if (opt) { + if (WARN_ON(v6_cork->opt)) + return -EINVAL; + + nopt = v6_cork->opt = kzalloc(sizeof(*opt), sk->sk_allocation); + if (unlikely(!nopt)) + return -ENOBUFS; + + nopt->tot_len = sizeof(*opt); + nopt->opt_flen = opt->opt_flen; + nopt->opt_nflen = opt->opt_nflen; + + nopt->dst0opt = ip6_opt_dup(opt->dst0opt, sk->sk_allocation); + if (opt->dst0opt && !nopt->dst0opt) + return -ENOBUFS; + + nopt->dst1opt = ip6_opt_dup(opt->dst1opt, sk->sk_allocation); + if (opt->dst1opt && !nopt->dst1opt) + return -ENOBUFS; + + nopt->hopopt = ip6_opt_dup(opt->hopopt, sk->sk_allocation); + if (opt->hopopt && !nopt->hopopt) + return -ENOBUFS; + + nopt->srcrt = ip6_rthdr_dup(opt->srcrt, sk->sk_allocation); + if (opt->srcrt && !nopt->srcrt) + return -ENOBUFS; + + /* need source address above miyazawa*/ + } + v6_cork->hop_limit = ipc6->hlimit; + v6_cork->tclass = ipc6->tclass; + if (rt->dst.flags & DST_XFRM_TUNNEL) + mtu = np->pmtudisc >= IPV6_PMTUDISC_PROBE ? + READ_ONCE(rt->dst.dev->mtu) : dst_mtu(&rt->dst); + else + mtu = np->pmtudisc >= IPV6_PMTUDISC_PROBE ? + READ_ONCE(rt->dst.dev->mtu) : dst_mtu(xfrm_dst_path(&rt->dst)); + if (np->frag_size < mtu) { + if (np->frag_size) + mtu = np->frag_size; + } + cork->base.fragsize = mtu; + cork->base.gso_size = ipc6->gso_size; + cork->base.tx_flags = 0; + cork->base.mark = ipc6->sockc.mark; + sock_tx_timestamp(sk, ipc6->sockc.tsflags, &cork->base.tx_flags); + + if (dst_allfrag(xfrm_dst_path(&rt->dst))) + cork->base.flags |= IPCORK_ALLFRAG; + cork->base.length = 0; + + cork->base.transmit_time = ipc6->sockc.transmit_time; + + return 0; +} + +static int __ip6_append_data(struct sock *sk, + struct sk_buff_head *queue, + struct inet_cork_full *cork_full, + struct inet6_cork *v6_cork, + struct page_frag *pfrag, + int getfrag(void *from, char *to, int offset, + int len, int odd, struct sk_buff *skb), + void *from, size_t length, int transhdrlen, + unsigned int flags, struct ipcm6_cookie *ipc6) +{ + struct sk_buff *skb, *skb_prev = NULL; + struct inet_cork *cork = &cork_full->base; + struct flowi6 *fl6 = &cork_full->fl.u.ip6; + unsigned int maxfraglen, fragheaderlen, mtu, orig_mtu, pmtu; + struct ubuf_info *uarg = NULL; + int exthdrlen = 0; + int dst_exthdrlen = 0; + int hh_len; + int copy; + int err; + int offset = 0; + bool zc = false; + u32 tskey = 0; + struct rt6_info *rt = (struct rt6_info *)cork->dst; + struct ipv6_txoptions *opt = v6_cork->opt; + int csummode = CHECKSUM_NONE; + unsigned int maxnonfragsize, headersize; + unsigned int wmem_alloc_delta = 0; + bool paged, extra_uref = false; + + skb = skb_peek_tail(queue); + if (!skb) { + exthdrlen = opt ? opt->opt_flen : 0; + dst_exthdrlen = rt->dst.header_len - rt->rt6i_nfheader_len; + } + + paged = !!cork->gso_size; + mtu = cork->gso_size ? IP6_MAX_MTU : cork->fragsize; + orig_mtu = mtu; + + if (cork->tx_flags & SKBTX_ANY_TSTAMP && + READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_OPT_ID) + tskey = atomic_inc_return(&sk->sk_tskey) - 1; + + hh_len = LL_RESERVED_SPACE(rt->dst.dev); + + fragheaderlen = sizeof(struct ipv6hdr) + rt->rt6i_nfheader_len + + (opt ? opt->opt_nflen : 0); + + headersize = sizeof(struct ipv6hdr) + + (opt ? opt->opt_flen + opt->opt_nflen : 0) + + (dst_allfrag(&rt->dst) ? 
+ sizeof(struct frag_hdr) : 0) + + rt->rt6i_nfheader_len; + + if (mtu <= fragheaderlen || + ((mtu - fragheaderlen) & ~7) + fragheaderlen <= sizeof(struct frag_hdr)) + goto emsgsize; + + maxfraglen = ((mtu - fragheaderlen) & ~7) + fragheaderlen - + sizeof(struct frag_hdr); + + /* as per RFC 7112 section 5, the entire IPv6 Header Chain must fit + * the first fragment + */ + if (headersize + transhdrlen > mtu) + goto emsgsize; + + if (cork->length + length > mtu - headersize && ipc6->dontfrag && + (sk->sk_protocol == IPPROTO_UDP || + sk->sk_protocol == IPPROTO_ICMPV6 || + sk->sk_protocol == IPPROTO_RAW)) { + ipv6_local_rxpmtu(sk, fl6, mtu - headersize + + sizeof(struct ipv6hdr)); + goto emsgsize; + } + + if (ip6_sk_ignore_df(sk)) + maxnonfragsize = sizeof(struct ipv6hdr) + IPV6_MAXPLEN; + else + maxnonfragsize = mtu; + + if (cork->length + length > maxnonfragsize - headersize) { +emsgsize: + pmtu = max_t(int, mtu - headersize + sizeof(struct ipv6hdr), 0); + ipv6_local_error(sk, EMSGSIZE, fl6, pmtu); + return -EMSGSIZE; + } + + /* CHECKSUM_PARTIAL only with no extension headers and when + * we are not going to fragment + */ + if (transhdrlen && sk->sk_protocol == IPPROTO_UDP && + headersize == sizeof(struct ipv6hdr) && + length <= mtu - headersize && + (!(flags & MSG_MORE) || cork->gso_size) && + rt->dst.dev->features & (NETIF_F_IPV6_CSUM | NETIF_F_HW_CSUM)) + csummode = CHECKSUM_PARTIAL; + + if ((flags & MSG_ZEROCOPY) && length) { + struct msghdr *msg = from; + + if (getfrag == ip_generic_getfrag && msg->msg_ubuf) { + if (skb_zcopy(skb) && msg->msg_ubuf != skb_zcopy(skb)) + return -EINVAL; + + /* Leave uarg NULL if can't zerocopy, callers should + * be able to handle it. + */ + if ((rt->dst.dev->features & NETIF_F_SG) && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + uarg = msg->msg_ubuf; + } + } else if (sock_flag(sk, SOCK_ZEROCOPY)) { + uarg = msg_zerocopy_realloc(sk, length, skb_zcopy(skb)); + if (!uarg) + return -ENOBUFS; + extra_uref = !skb_zcopy(skb); /* only ref on new uarg */ + if (rt->dst.dev->features & NETIF_F_SG && + csummode == CHECKSUM_PARTIAL) { + paged = true; + zc = true; + } else { + uarg_to_msgzc(uarg)->zerocopy = 0; + skb_zcopy_set(skb, uarg, &extra_uref); + } + } + } + + /* + * Let's try using as much space as possible. + * Use MTU if total length of the message fits into the MTU. + * Otherwise, we need to reserve fragment header and + * fragment alignment (= 8-15 octects, in total). + * + * Note that we may need to "move" the data from the tail + * of the buffer to the new fragment when we split + * the message. + * + * FIXME: It may be fragmented into multiple chunks + * at once if non-fragmentable extension headers + * are too large. + * --yoshfuji + */ + + cork->length += length; + if (!skb) + goto alloc_new_skb; + + while (length > 0) { + /* Check if the remaining data fits into current packet. */ + copy = (cork->length <= mtu && !(cork->flags & IPCORK_ALLFRAG) ? 
mtu : maxfraglen) - skb->len; + if (copy < length) + copy = maxfraglen - skb->len; + + if (copy <= 0) { + char *data; + unsigned int datalen; + unsigned int fraglen; + unsigned int fraggap; + unsigned int alloclen, alloc_extra; + unsigned int pagedlen; +alloc_new_skb: + /* There's no room in the current skb */ + if (skb) + fraggap = skb->len - maxfraglen; + else + fraggap = 0; + /* update mtu and maxfraglen if necessary */ + if (!skb || !skb_prev) + ip6_append_data_mtu(&mtu, &maxfraglen, + fragheaderlen, skb, rt, + orig_mtu); + + skb_prev = skb; + + /* + * If remaining data exceeds the mtu, + * we know we need more fragment(s). + */ + datalen = length + fraggap; + + if (datalen > (cork->length <= mtu && !(cork->flags & IPCORK_ALLFRAG) ? mtu : maxfraglen) - fragheaderlen) + datalen = maxfraglen - fragheaderlen - rt->dst.trailer_len; + fraglen = datalen + fragheaderlen; + pagedlen = 0; + + alloc_extra = hh_len; + alloc_extra += dst_exthdrlen; + alloc_extra += rt->dst.trailer_len; + + /* We just reserve space for fragment header. + * Note: this may be overallocation if the message + * (without MSG_MORE) fits into the MTU. + */ + alloc_extra += sizeof(struct frag_hdr); + + if ((flags & MSG_MORE) && + !(rt->dst.dev->features&NETIF_F_SG)) + alloclen = mtu; + else if (!paged && + (fraglen + alloc_extra < SKB_MAX_ALLOC || + !(rt->dst.dev->features & NETIF_F_SG))) + alloclen = fraglen; + else { + alloclen = fragheaderlen + transhdrlen; + pagedlen = datalen - transhdrlen; + } + alloclen += alloc_extra; + + if (datalen != length + fraggap) { + /* + * this is not the last fragment, the trailer + * space is regarded as data space. + */ + datalen += rt->dst.trailer_len; + } + + fraglen = datalen + fragheaderlen; + + copy = datalen - transhdrlen - fraggap - pagedlen; + if (copy < 0) { + err = -EINVAL; + goto error; + } + if (transhdrlen) { + skb = sock_alloc_send_skb(sk, alloclen, + (flags & MSG_DONTWAIT), &err); + } else { + skb = NULL; + if (refcount_read(&sk->sk_wmem_alloc) + wmem_alloc_delta <= + 2 * sk->sk_sndbuf) + skb = alloc_skb(alloclen, + sk->sk_allocation); + if (unlikely(!skb)) + err = -ENOBUFS; + } + if (!skb) + goto error; + /* + * Fill in the control structures + */ + skb->protocol = htons(ETH_P_IPV6); + skb->ip_summed = csummode; + skb->csum = 0; + /* reserve for fragmentation and ipsec header */ + skb_reserve(skb, hh_len + sizeof(struct frag_hdr) + + dst_exthdrlen); + + /* + * Find where to start putting bytes + */ + data = skb_put(skb, fraglen - pagedlen); + skb_set_network_header(skb, exthdrlen); + data += fragheaderlen; + skb->transport_header = (skb->network_header + + fragheaderlen); + if (fraggap) { + skb->csum = skb_copy_and_csum_bits( + skb_prev, maxfraglen, + data + transhdrlen, fraggap); + skb_prev->csum = csum_sub(skb_prev->csum, + skb->csum); + data += fraggap; + pskb_trim_unique(skb_prev, maxfraglen); + } + if (copy > 0 && + getfrag(from, data + transhdrlen, offset, + copy, fraggap, skb) < 0) { + err = -EFAULT; + kfree_skb(skb); + goto error; + } + + offset += copy; + length -= copy + transhdrlen; + transhdrlen = 0; + exthdrlen = 0; + dst_exthdrlen = 0; + + /* Only the initial fragment is time stamped */ + skb_shinfo(skb)->tx_flags = cork->tx_flags; + cork->tx_flags = 0; + skb_shinfo(skb)->tskey = tskey; + tskey = 0; + skb_zcopy_set(skb, uarg, &extra_uref); + + if ((flags & MSG_CONFIRM) && !skb_prev) + skb_set_dst_pending_confirm(skb, 1); + + /* + * Put the packet on the pending queue + */ + if (!skb->destructor) { + skb->destructor = sock_wfree; + skb->sk = sk; + 
wmem_alloc_delta += skb->truesize; + } + __skb_queue_tail(queue, skb); + continue; + } + + if (copy > length) + copy = length; + + if (!(rt->dst.dev->features&NETIF_F_SG) && + skb_tailroom(skb) >= copy) { + unsigned int off; + + off = skb->len; + if (getfrag(from, skb_put(skb, copy), + offset, copy, off, skb) < 0) { + __skb_trim(skb, off); + err = -EFAULT; + goto error; + } + } else if (!zc) { + int i = skb_shinfo(skb)->nr_frags; + + err = -ENOMEM; + if (!sk_page_frag_refill(sk, pfrag)) + goto error; + + skb_zcopy_downgrade_managed(skb); + if (!skb_can_coalesce(skb, i, pfrag->page, + pfrag->offset)) { + err = -EMSGSIZE; + if (i == MAX_SKB_FRAGS) + goto error; + + __skb_fill_page_desc(skb, i, pfrag->page, + pfrag->offset, 0); + skb_shinfo(skb)->nr_frags = ++i; + get_page(pfrag->page); + } + copy = min_t(int, copy, pfrag->size - pfrag->offset); + if (getfrag(from, + page_address(pfrag->page) + pfrag->offset, + offset, copy, skb->len, skb) < 0) + goto error_efault; + + pfrag->offset += copy; + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + skb->len += copy; + skb->data_len += copy; + skb->truesize += copy; + wmem_alloc_delta += copy; + } else { + err = skb_zerocopy_iter_dgram(skb, from, copy); + if (err < 0) + goto error; + } + offset += copy; + length -= copy; + } + + if (wmem_alloc_delta) + refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); + return 0; + +error_efault: + err = -EFAULT; +error: + net_zcopy_put_abort(uarg, extra_uref); + cork->length -= length; + IP6_INC_STATS(sock_net(sk), rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); + refcount_add(wmem_alloc_delta, &sk->sk_wmem_alloc); + return err; +} + +int ip6_append_data(struct sock *sk, + int getfrag(void *from, char *to, int offset, int len, + int odd, struct sk_buff *skb), + void *from, size_t length, int transhdrlen, + struct ipcm6_cookie *ipc6, struct flowi6 *fl6, + struct rt6_info *rt, unsigned int flags) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + int exthdrlen; + int err; + + if (flags&MSG_PROBE) + return 0; + if (skb_queue_empty(&sk->sk_write_queue)) { + /* + * setup for corking + */ + dst_hold(&rt->dst); + err = ip6_setup_cork(sk, &inet->cork, &np->cork, + ipc6, rt); + if (err) + return err; + + inet->cork.fl.u.ip6 = *fl6; + exthdrlen = (ipc6->opt ? 
ipc6->opt->opt_flen : 0); + length += exthdrlen; + transhdrlen += exthdrlen; + } else { + transhdrlen = 0; + } + + return __ip6_append_data(sk, &sk->sk_write_queue, &inet->cork, + &np->cork, sk_page_frag(sk), getfrag, + from, length, transhdrlen, flags, ipc6); +} +EXPORT_SYMBOL_GPL(ip6_append_data); + +static void ip6_cork_steal_dst(struct sk_buff *skb, struct inet_cork_full *cork) +{ + struct dst_entry *dst = cork->base.dst; + + cork->base.dst = NULL; + cork->base.flags &= ~IPCORK_ALLFRAG; + skb_dst_set(skb, dst); +} + +static void ip6_cork_release(struct inet_cork_full *cork, + struct inet6_cork *v6_cork) +{ + if (v6_cork->opt) { + struct ipv6_txoptions *opt = v6_cork->opt; + + kfree(opt->dst0opt); + kfree(opt->dst1opt); + kfree(opt->hopopt); + kfree(opt->srcrt); + kfree(opt); + v6_cork->opt = NULL; + } + + if (cork->base.dst) { + dst_release(cork->base.dst); + cork->base.dst = NULL; + cork->base.flags &= ~IPCORK_ALLFRAG; + } +} + +struct sk_buff *__ip6_make_skb(struct sock *sk, + struct sk_buff_head *queue, + struct inet_cork_full *cork, + struct inet6_cork *v6_cork) +{ + struct sk_buff *skb, *tmp_skb; + struct sk_buff **tail_skb; + struct in6_addr *final_dst; + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + struct ipv6hdr *hdr; + struct ipv6_txoptions *opt = v6_cork->opt; + struct rt6_info *rt = (struct rt6_info *)cork->base.dst; + struct flowi6 *fl6 = &cork->fl.u.ip6; + unsigned char proto = fl6->flowi6_proto; + + skb = __skb_dequeue(queue); + if (!skb) + goto out; + tail_skb = &(skb_shinfo(skb)->frag_list); + + /* move skb->data to ip header from ext header */ + if (skb->data < skb_network_header(skb)) + __skb_pull(skb, skb_network_offset(skb)); + while ((tmp_skb = __skb_dequeue(queue)) != NULL) { + __skb_pull(tmp_skb, skb_network_header_len(skb)); + *tail_skb = tmp_skb; + tail_skb = &(tmp_skb->next); + skb->len += tmp_skb->len; + skb->data_len += tmp_skb->len; + skb->truesize += tmp_skb->truesize; + tmp_skb->destructor = NULL; + tmp_skb->sk = NULL; + } + + /* Allow local fragmentation. 
*/ + skb->ignore_df = ip6_sk_ignore_df(sk); + __skb_pull(skb, skb_network_header_len(skb)); + + final_dst = &fl6->daddr; + if (opt && opt->opt_flen) + ipv6_push_frag_opts(skb, opt, &proto); + if (opt && opt->opt_nflen) + ipv6_push_nfrag_opts(skb, opt, &proto, &final_dst, &fl6->saddr); + + skb_push(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + hdr = ipv6_hdr(skb); + + ip6_flow_hdr(hdr, v6_cork->tclass, + ip6_make_flowlabel(net, skb, fl6->flowlabel, + ip6_autoflowlabel(net, np), fl6)); + hdr->hop_limit = v6_cork->hop_limit; + hdr->nexthdr = proto; + hdr->saddr = fl6->saddr; + hdr->daddr = *final_dst; + + skb->priority = sk->sk_priority; + skb->mark = cork->base.mark; + skb->tstamp = cork->base.transmit_time; + + ip6_cork_steal_dst(skb, cork); + IP6_UPD_PO_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUT, skb->len); + if (proto == IPPROTO_ICMPV6) { + struct inet6_dev *idev = ip6_dst_idev(skb_dst(skb)); + u8 icmp6_type; + + if (sk->sk_socket->type == SOCK_RAW && !inet_sk(sk)->hdrincl) + icmp6_type = fl6->fl6_icmp_type; + else + icmp6_type = icmp6_hdr(skb)->icmp6_type; + ICMP6MSGOUT_INC_STATS(net, idev, icmp6_type); + ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS); + } + + ip6_cork_release(cork, v6_cork); +out: + return skb; +} + +int ip6_send_skb(struct sk_buff *skb) +{ + struct net *net = sock_net(skb->sk); + struct rt6_info *rt = (struct rt6_info *)skb_dst(skb); + int err; + + err = ip6_local_out(net, skb->sk, skb); + if (err) { + if (err > 0) + err = net_xmit_errno(err); + if (err) + IP6_INC_STATS(net, rt->rt6i_idev, + IPSTATS_MIB_OUTDISCARDS); + } + + return err; +} + +int ip6_push_pending_frames(struct sock *sk) +{ + struct sk_buff *skb; + + skb = ip6_finish_skb(sk); + if (!skb) + return 0; + + return ip6_send_skb(skb); +} +EXPORT_SYMBOL_GPL(ip6_push_pending_frames); + +static void __ip6_flush_pending_frames(struct sock *sk, + struct sk_buff_head *queue, + struct inet_cork_full *cork, + struct inet6_cork *v6_cork) +{ + struct sk_buff *skb; + + while ((skb = __skb_dequeue_tail(queue)) != NULL) { + if (skb_dst(skb)) + IP6_INC_STATS(sock_net(sk), ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTDISCARDS); + kfree_skb(skb); + } + + ip6_cork_release(cork, v6_cork); +} + +void ip6_flush_pending_frames(struct sock *sk) +{ + __ip6_flush_pending_frames(sk, &sk->sk_write_queue, + &inet_sk(sk)->cork, &inet6_sk(sk)->cork); +} +EXPORT_SYMBOL_GPL(ip6_flush_pending_frames); + +struct sk_buff *ip6_make_skb(struct sock *sk, + int getfrag(void *from, char *to, int offset, + int len, int odd, struct sk_buff *skb), + void *from, size_t length, int transhdrlen, + struct ipcm6_cookie *ipc6, struct rt6_info *rt, + unsigned int flags, struct inet_cork_full *cork) +{ + struct inet6_cork v6_cork; + struct sk_buff_head queue; + int exthdrlen = (ipc6->opt ? 
ipc6->opt->opt_flen : 0);
+	int err;
+
+	if (flags & MSG_PROBE) {
+		dst_release(&rt->dst);
+		return NULL;
+	}
+
+	__skb_queue_head_init(&queue);
+
+	cork->base.flags = 0;
+	cork->base.addr = 0;
+	cork->base.opt = NULL;
+	v6_cork.opt = NULL;
+	err = ip6_setup_cork(sk, cork, &v6_cork, ipc6, rt);
+	if (err) {
+		ip6_cork_release(cork, &v6_cork);
+		return ERR_PTR(err);
+	}
+	if (ipc6->dontfrag < 0)
+		ipc6->dontfrag = inet6_sk(sk)->dontfrag;
+
+	err = __ip6_append_data(sk, &queue, cork, &v6_cork,
+				&current->task_frag, getfrag, from,
+				length + exthdrlen, transhdrlen + exthdrlen,
+				flags, ipc6);
+	if (err) {
+		__ip6_flush_pending_frames(sk, &queue, cork, &v6_cork);
+		return ERR_PTR(err);
+	}
+
+	return __ip6_make_skb(sk, &queue, cork, &v6_cork);
+}
diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c
new file mode 100644
index 000000000..9125e92d9
--- /dev/null
+++ b/net/ipv6/ip6_tunnel.c
@@ -0,0 +1,2367 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ *	IPv6 tunneling device
+ *	Linux INET6 implementation
+ *
+ *	Authors:
+ *	Ville Nuorvala
+ *	Yasuyuki Kozakai
+ *
+ *	Based on:
+ *	linux/net/ipv6/sit.c and linux/net/ipv4/ipip.c
+ *
+ *	RFC 2473
+ */
+
+#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+MODULE_AUTHOR("Ville Nuorvala");
+MODULE_DESCRIPTION("IPv6 tunneling device");
+MODULE_LICENSE("GPL");
+MODULE_ALIAS_RTNL_LINK("ip6tnl");
+MODULE_ALIAS_NETDEV("ip6tnl0");
+
+#define IP6_TUNNEL_HASH_SIZE_SHIFT 5
+#define IP6_TUNNEL_HASH_SIZE (1 << IP6_TUNNEL_HASH_SIZE_SHIFT)
+
+static bool log_ecn_error = true;
+module_param(log_ecn_error, bool, 0644);
+MODULE_PARM_DESC(log_ecn_error, "Log packets received with corrupted ECN");
+
+static u32 HASH(const struct in6_addr *addr1, const struct in6_addr *addr2)
+{
+	u32 hash = ipv6_addr_hash(addr1) ^ ipv6_addr_hash(addr2);
+
+	return hash_32(hash, IP6_TUNNEL_HASH_SIZE_SHIFT);
+}
+
+static int ip6_tnl_dev_init(struct net_device *dev);
+static void ip6_tnl_dev_setup(struct net_device *dev);
+static struct rtnl_link_ops ip6_link_ops __read_mostly;
+
+static unsigned int ip6_tnl_net_id __read_mostly;
+struct ip6_tnl_net {
+	/* the IPv6 tunnel fallback device */
+	struct net_device *fb_tnl_dev;
+	/* lists for storing tunnels in use */
+	struct ip6_tnl __rcu *tnls_r_l[IP6_TUNNEL_HASH_SIZE];
+	struct ip6_tnl __rcu *tnls_wc[1];
+	struct ip6_tnl __rcu **tnls[2];
+	struct ip6_tnl __rcu *collect_md_tun;
+};
+
+static inline int ip6_tnl_mpls_supported(void)
+{
+	return IS_ENABLED(CONFIG_MPLS);
+}
+
+#define for_each_ip6_tunnel_rcu(start) \
+	for (t = rcu_dereference(start); t; t = rcu_dereference(t->next))
+
+/**
+ * ip6_tnl_lookup - fetch tunnel matching the end-point addresses
+ * @net: network namespace
+ * @link: ifindex of underlying interface
+ * @remote: the address of the tunnel exit-point
+ * @local: the address of the tunnel entry-point
+ *
+ * Return:
+ *   tunnel matching given end-points if found,
+ *   else fallback tunnel if its device is up,
+ *   else %NULL
+ **/
+
+static struct ip6_tnl *
+ip6_tnl_lookup(struct net *net, int link,
+	       const struct in6_addr *remote, const struct in6_addr *local)
+{
+	unsigned int hash = HASH(remote, local);
+	struct ip6_tnl *t, *cand = NULL;
+	struct
ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + struct in6_addr any; + + for_each_ip6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (!ipv6_addr_equal(local, &t->parms.laddr) || + !ipv6_addr_equal(remote, &t->parms.raddr) || + !(t->dev->flags & IFF_UP)) + continue; + + if (link == t->parms.link) + return t; + else + cand = t; + } + + memset(&any, 0, sizeof(any)); + hash = HASH(&any, local); + for_each_ip6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (!ipv6_addr_equal(local, &t->parms.laddr) || + !ipv6_addr_any(&t->parms.raddr) || + !(t->dev->flags & IFF_UP)) + continue; + + if (link == t->parms.link) + return t; + else if (!cand) + cand = t; + } + + hash = HASH(remote, &any); + for_each_ip6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (!ipv6_addr_equal(remote, &t->parms.raddr) || + !ipv6_addr_any(&t->parms.laddr) || + !(t->dev->flags & IFF_UP)) + continue; + + if (link == t->parms.link) + return t; + else if (!cand) + cand = t; + } + + if (cand) + return cand; + + t = rcu_dereference(ip6n->collect_md_tun); + if (t && t->dev->flags & IFF_UP) + return t; + + t = rcu_dereference(ip6n->tnls_wc[0]); + if (t && (t->dev->flags & IFF_UP)) + return t; + + return NULL; +} + +/** + * ip6_tnl_bucket - get head of list matching given tunnel parameters + * @ip6n: the private data for ip6_vti in the netns + * @p: parameters containing tunnel end-points + * + * Description: + * ip6_tnl_bucket() returns the head of the list matching the + * &struct in6_addr entries laddr and raddr in @p. + * + * Return: head of IPv6 tunnel list + **/ + +static struct ip6_tnl __rcu ** +ip6_tnl_bucket(struct ip6_tnl_net *ip6n, const struct __ip6_tnl_parm *p) +{ + const struct in6_addr *remote = &p->raddr; + const struct in6_addr *local = &p->laddr; + unsigned int h = 0; + int prio = 0; + + if (!ipv6_addr_any(remote) || !ipv6_addr_any(local)) { + prio = 1; + h = HASH(remote, local); + } + return &ip6n->tnls[prio][h]; +} + +/** + * ip6_tnl_link - add tunnel to hash table + * @ip6n: the private data for ip6_vti in the netns + * @t: tunnel to be added + **/ + +static void +ip6_tnl_link(struct ip6_tnl_net *ip6n, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp = ip6_tnl_bucket(ip6n, &t->parms); + + if (t->parms.collect_md) + rcu_assign_pointer(ip6n->collect_md_tun, t); + rcu_assign_pointer(t->next , rtnl_dereference(*tp)); + rcu_assign_pointer(*tp, t); +} + +/** + * ip6_tnl_unlink - remove tunnel from hash table + * @ip6n: the private data for ip6_vti in the netns + * @t: tunnel to be removed + **/ + +static void +ip6_tnl_unlink(struct ip6_tnl_net *ip6n, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp; + struct ip6_tnl *iter; + + if (t->parms.collect_md) + rcu_assign_pointer(ip6n->collect_md_tun, NULL); + + for (tp = ip6_tnl_bucket(ip6n, &t->parms); + (iter = rtnl_dereference(*tp)) != NULL; + tp = &iter->next) { + if (t == iter) { + rcu_assign_pointer(*tp, t->next); + break; + } + } +} + +static void ip6_dev_free(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + + gro_cells_destroy(&t->gro_cells); + dst_cache_destroy(&t->dst_cache); + free_percpu(dev->tstats); +} + +static int ip6_tnl_create2(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = dev_net(dev); + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + int err; + + dev->rtnl_link_ops = &ip6_link_ops; + err = register_netdevice(dev); + if (err < 0) + goto out; + + strcpy(t->parms.name, dev->name); + + ip6_tnl_link(ip6n, t); + return 0; + +out: + return err; +} + +/** + * ip6_tnl_create - create a new tunnel + * 
@net: network namespace + * @p: tunnel parameters + * + * Description: + * Create tunnel matching given parameters. + * + * Return: + * created tunnel or error pointer + **/ + +static struct ip6_tnl *ip6_tnl_create(struct net *net, struct __ip6_tnl_parm *p) +{ + struct net_device *dev; + struct ip6_tnl *t; + char name[IFNAMSIZ]; + int err = -E2BIG; + + if (p->name[0]) { + if (!dev_valid_name(p->name)) + goto failed; + strscpy(name, p->name, IFNAMSIZ); + } else { + sprintf(name, "ip6tnl%%d"); + } + err = -ENOMEM; + dev = alloc_netdev(sizeof(*t), name, NET_NAME_UNKNOWN, + ip6_tnl_dev_setup); + if (!dev) + goto failed; + + dev_net_set(dev, net); + + t = netdev_priv(dev); + t->parms = *p; + t->net = dev_net(dev); + err = ip6_tnl_create2(dev); + if (err < 0) + goto failed_free; + + return t; + +failed_free: + free_netdev(dev); +failed: + return ERR_PTR(err); +} + +/** + * ip6_tnl_locate - find or create tunnel matching given parameters + * @net: network namespace + * @p: tunnel parameters + * @create: != 0 if allowed to create new tunnel if no match found + * + * Description: + * ip6_tnl_locate() first tries to locate an existing tunnel + * based on @parms. If this is unsuccessful, but @create is set a new + * tunnel device is created and registered for use. + * + * Return: + * matching tunnel or error pointer + **/ + +static struct ip6_tnl *ip6_tnl_locate(struct net *net, + struct __ip6_tnl_parm *p, int create) +{ + const struct in6_addr *remote = &p->raddr; + const struct in6_addr *local = &p->laddr; + struct ip6_tnl __rcu **tp; + struct ip6_tnl *t; + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + + for (tp = ip6_tnl_bucket(ip6n, p); + (t = rtnl_dereference(*tp)) != NULL; + tp = &t->next) { + if (ipv6_addr_equal(local, &t->parms.laddr) && + ipv6_addr_equal(remote, &t->parms.raddr) && + p->link == t->parms.link) { + if (create) + return ERR_PTR(-EEXIST); + + return t; + } + } + if (!create) + return ERR_PTR(-ENODEV); + return ip6_tnl_create(net, p); +} + +/** + * ip6_tnl_dev_uninit - tunnel device uninitializer + * @dev: the device to be destroyed + * + * Description: + * ip6_tnl_dev_uninit() removes tunnel from its list + **/ + +static void +ip6_tnl_dev_uninit(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = t->net; + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + + if (dev == ip6n->fb_tnl_dev) + RCU_INIT_POINTER(ip6n->tnls_wc[0], NULL); + else + ip6_tnl_unlink(ip6n, t); + dst_cache_reset(&t->dst_cache); + netdev_put(dev, &t->dev_tracker); +} + +/** + * ip6_tnl_parse_tlv_enc_lim - handle encapsulation limit option + * @skb: received socket buffer + * @raw: the ICMPv6 error message data + * + * Return: + * 0 if none was found, + * else index to encapsulation limit + **/ + +__u16 ip6_tnl_parse_tlv_enc_lim(struct sk_buff *skb, __u8 *raw) +{ + const struct ipv6hdr *ipv6h = (const struct ipv6hdr *)raw; + unsigned int nhoff = raw - skb->data; + unsigned int off = nhoff + sizeof(*ipv6h); + u8 nexthdr = ipv6h->nexthdr; + + while (ipv6_ext_hdr(nexthdr) && nexthdr != NEXTHDR_NONE) { + struct ipv6_opt_hdr *hdr; + u16 optlen; + + if (!pskb_may_pull(skb, off + sizeof(*hdr))) + break; + + hdr = (struct ipv6_opt_hdr *)(skb->data + off); + if (nexthdr == NEXTHDR_FRAGMENT) { + optlen = 8; + } else if (nexthdr == NEXTHDR_AUTH) { + optlen = ipv6_authlen(hdr); + } else { + optlen = ipv6_optlen(hdr); + } + + if (!pskb_may_pull(skb, off + optlen)) + break; + + hdr = (struct ipv6_opt_hdr *)(skb->data + off); + if (nexthdr == NEXTHDR_FRAGMENT) { + 
struct frag_hdr *frag_hdr = (struct frag_hdr *)hdr; + + if (frag_hdr->frag_off) + break; + } + if (nexthdr == NEXTHDR_DEST) { + u16 i = 2; + + while (1) { + struct ipv6_tlv_tnl_enc_lim *tel; + + /* No more room for encapsulation limit */ + if (i + sizeof(*tel) > optlen) + break; + + tel = (struct ipv6_tlv_tnl_enc_lim *)(skb->data + off + i); + /* return index of option if found and valid */ + if (tel->type == IPV6_TLV_TNL_ENCAP_LIMIT && + tel->length == 1) + return i + off - nhoff; + /* else jump to next option */ + if (tel->type) + i += tel->length + 2; + else + i++; + } + } + nexthdr = hdr->nexthdr; + off += optlen; + } + return 0; +} +EXPORT_SYMBOL(ip6_tnl_parse_tlv_enc_lim); + +/* ip6_tnl_err() should handle errors in the tunnel according to the + * specifications in RFC 2473. + */ +static int +ip6_tnl_err(struct sk_buff *skb, __u8 ipproto, struct inet6_skb_parm *opt, + u8 *type, u8 *code, int *msg, __u32 *info, int offset) +{ + const struct ipv6hdr *ipv6h = (const struct ipv6hdr *)skb->data; + struct net *net = dev_net(skb->dev); + u8 rel_type = ICMPV6_DEST_UNREACH; + u8 rel_code = ICMPV6_ADDR_UNREACH; + __u32 rel_info = 0; + struct ip6_tnl *t; + int err = -ENOENT; + int rel_msg = 0; + u8 tproto; + __u16 len; + + /* If the packet doesn't contain the original IPv6 header we are + in trouble since we might need the source address for further + processing of the error. */ + + rcu_read_lock(); + t = ip6_tnl_lookup(dev_net(skb->dev), skb->dev->ifindex, &ipv6h->daddr, &ipv6h->saddr); + if (!t) + goto out; + + tproto = READ_ONCE(t->parms.proto); + if (tproto != ipproto && tproto != 0) + goto out; + + err = 0; + + switch (*type) { + case ICMPV6_DEST_UNREACH: + net_dbg_ratelimited("%s: Path to destination invalid or inactive!\n", + t->parms.name); + rel_msg = 1; + break; + case ICMPV6_TIME_EXCEED: + if ((*code) == ICMPV6_EXC_HOPLIMIT) { + net_dbg_ratelimited("%s: Too small hop limit or routing loop in tunnel!\n", + t->parms.name); + rel_msg = 1; + } + break; + case ICMPV6_PARAMPROB: { + struct ipv6_tlv_tnl_enc_lim *tel; + __u32 teli; + + teli = 0; + if ((*code) == ICMPV6_HDR_FIELD) + teli = ip6_tnl_parse_tlv_enc_lim(skb, skb->data); + + if (teli && teli == *info - 2) { + tel = (struct ipv6_tlv_tnl_enc_lim *) &skb->data[teli]; + if (tel->encap_limit == 0) { + net_dbg_ratelimited("%s: Too small encapsulation limit or routing loop in tunnel!\n", + t->parms.name); + rel_msg = 1; + } + } else { + net_dbg_ratelimited("%s: Recipient unable to parse tunneled packet!\n", + t->parms.name); + } + break; + } + case ICMPV6_PKT_TOOBIG: { + __u32 mtu; + + ip6_update_pmtu(skb, net, htonl(*info), 0, 0, + sock_net_uid(net, NULL)); + mtu = *info - offset; + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + len = sizeof(*ipv6h) + ntohs(ipv6h->payload_len); + if (len > mtu) { + rel_type = ICMPV6_PKT_TOOBIG; + rel_code = 0; + rel_info = mtu; + rel_msg = 1; + } + break; + } + case NDISC_REDIRECT: + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + break; + } + + *type = rel_type; + *code = rel_code; + *info = rel_info; + *msg = rel_msg; + +out: + rcu_read_unlock(); + return err; +} + +static int +ip4ip6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + __u32 rel_info = ntohl(info); + const struct iphdr *eiph; + struct sk_buff *skb2; + int err, rel_msg = 0; + u8 rel_type = type; + u8 rel_code = code; + struct rtable *rt; + struct flowi4 fl4; + + err = ip6_tnl_err(skb, IPPROTO_IPIP, opt, &rel_type, &rel_code, + &rel_msg, &rel_info, offset); + 
if (err < 0) + return err; + + if (rel_msg == 0) + return 0; + + switch (rel_type) { + case ICMPV6_DEST_UNREACH: + if (rel_code != ICMPV6_ADDR_UNREACH) + return 0; + rel_type = ICMP_DEST_UNREACH; + rel_code = ICMP_HOST_UNREACH; + break; + case ICMPV6_PKT_TOOBIG: + if (rel_code != 0) + return 0; + rel_type = ICMP_DEST_UNREACH; + rel_code = ICMP_FRAG_NEEDED; + break; + default: + return 0; + } + + if (!pskb_may_pull(skb, offset + sizeof(struct iphdr))) + return 0; + + skb2 = skb_clone(skb, GFP_ATOMIC); + if (!skb2) + return 0; + + skb_dst_drop(skb2); + + skb_pull(skb2, offset); + skb_reset_network_header(skb2); + eiph = ip_hdr(skb2); + + /* Try to guess incoming interface */ + rt = ip_route_output_ports(dev_net(skb->dev), &fl4, NULL, eiph->saddr, + 0, 0, 0, IPPROTO_IPIP, RT_TOS(eiph->tos), 0); + if (IS_ERR(rt)) + goto out; + + skb2->dev = rt->dst.dev; + ip_rt_put(rt); + + /* route "incoming" packet */ + if (rt->rt_flags & RTCF_LOCAL) { + rt = ip_route_output_ports(dev_net(skb->dev), &fl4, NULL, + eiph->daddr, eiph->saddr, 0, 0, + IPPROTO_IPIP, RT_TOS(eiph->tos), 0); + if (IS_ERR(rt) || rt->dst.dev->type != ARPHRD_TUNNEL6) { + if (!IS_ERR(rt)) + ip_rt_put(rt); + goto out; + } + skb_dst_set(skb2, &rt->dst); + } else { + if (ip_route_input(skb2, eiph->daddr, eiph->saddr, eiph->tos, + skb2->dev) || + skb_dst(skb2)->dev->type != ARPHRD_TUNNEL6) + goto out; + } + + /* change mtu on this route */ + if (rel_type == ICMP_DEST_UNREACH && rel_code == ICMP_FRAG_NEEDED) { + if (rel_info > dst_mtu(skb_dst(skb2))) + goto out; + + skb_dst_update_pmtu_no_confirm(skb2, rel_info); + } + + icmp_send(skb2, rel_type, rel_code, htonl(rel_info)); + +out: + kfree_skb(skb2); + return 0; +} + +static int +ip6ip6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + __u32 rel_info = ntohl(info); + int err, rel_msg = 0; + u8 rel_type = type; + u8 rel_code = code; + + err = ip6_tnl_err(skb, IPPROTO_IPV6, opt, &rel_type, &rel_code, + &rel_msg, &rel_info, offset); + if (err < 0) + return err; + + if (rel_msg && pskb_may_pull(skb, offset + sizeof(struct ipv6hdr))) { + struct rt6_info *rt; + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + + if (!skb2) + return 0; + + skb_dst_drop(skb2); + skb_pull(skb2, offset); + skb_reset_network_header(skb2); + + /* Try to guess incoming interface */ + rt = rt6_lookup(dev_net(skb->dev), &ipv6_hdr(skb2)->saddr, + NULL, 0, skb2, 0); + + if (rt && rt->dst.dev) + skb2->dev = rt->dst.dev; + + icmpv6_send(skb2, rel_type, rel_code, rel_info); + + ip6_rt_put(rt); + + kfree_skb(skb2); + } + + return 0; +} + +static int +mplsip6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + __u32 rel_info = ntohl(info); + int err, rel_msg = 0; + u8 rel_type = type; + u8 rel_code = code; + + err = ip6_tnl_err(skb, IPPROTO_MPLS, opt, &rel_type, &rel_code, + &rel_msg, &rel_info, offset); + return err; +} + +static int ip4ip6_dscp_ecn_decapsulate(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb) +{ + __u8 dsfield = ipv6_get_dsfield(ipv6h) & ~INET_ECN_MASK; + + if (t->parms.flags & IP6_TNL_F_RCV_DSCP_COPY) + ipv4_change_dsfield(ip_hdr(skb), INET_ECN_MASK, dsfield); + + return IP6_ECN_decapsulate(ipv6h, skb); +} + +static int ip6ip6_dscp_ecn_decapsulate(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb) +{ + if (t->parms.flags & IP6_TNL_F_RCV_DSCP_COPY) + ipv6_copy_dscp(ipv6_get_dsfield(ipv6h), ipv6_hdr(skb)); + + return IP6_ECN_decapsulate(ipv6h, skb); +} 
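The two decapsulation helpers above both treat the IPv6 traffic-class octet as a 6-bit DSCP field plus a 2-bit ECN field: when IP6_TNL_F_RCV_DSCP_COPY is set they copy only the DSCP bits of the outer header into the inner packet, and IP6_ECN_decapsulate reconciles the ECN bits separately. As a stand-alone illustration (not part of this patch), the user-space sketch below mirrors that bit split; the ECN_MASK constant and helper names are local to the example and merely assumed to correspond to the kernel's INET_ECN_MASK value of 0x3.

/* Illustration only -- not from the kernel tree. */
#include <stdio.h>
#include <stdint.h>

#define ECN_MASK 0x03	/* assumed mirror of INET_ECN_MASK: low two bits */

/* Upper six bits of the traffic class carry the DSCP codepoint. */
static uint8_t dscp_bits(uint8_t tclass)
{
	return (uint8_t)(tclass & ~ECN_MASK);
}

/* Lower two bits carry the ECN codepoint (Not-ECT, ECT(1), ECT(0), CE). */
static uint8_t ecn_bits(uint8_t tclass)
{
	return (uint8_t)(tclass & ECN_MASK);
}

int main(void)
{
	/* Example outer traffic class: DSCP EF shifted into the dsfield
	 * (0xb8) combined with ECT(1).
	 */
	uint8_t outer = 0xb8 | 0x01;

	printf("outer=0x%02x dscp=0x%02x ecn=0x%02x\n",
	       outer, dscp_bits(outer), ecn_bits(outer));
	return 0;
}

Running the sketch prints dscp=0xb8 and ecn=0x01, which is exactly the partition the DSCP-copy and ECN-decapsulation paths operate on independently.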
+ +static inline int mplsip6_dscp_ecn_decapsulate(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb) +{ + /* ECN is not supported in AF_MPLS */ + return 0; +} + +__u32 ip6_tnl_get_cap(struct ip6_tnl *t, + const struct in6_addr *laddr, + const struct in6_addr *raddr) +{ + struct __ip6_tnl_parm *p = &t->parms; + int ltype = ipv6_addr_type(laddr); + int rtype = ipv6_addr_type(raddr); + __u32 flags = 0; + + if (ltype == IPV6_ADDR_ANY || rtype == IPV6_ADDR_ANY) { + flags = IP6_TNL_F_CAP_PER_PACKET; + } else if (ltype & (IPV6_ADDR_UNICAST|IPV6_ADDR_MULTICAST) && + rtype & (IPV6_ADDR_UNICAST|IPV6_ADDR_MULTICAST) && + !((ltype|rtype) & IPV6_ADDR_LOOPBACK) && + (!((ltype|rtype) & IPV6_ADDR_LINKLOCAL) || p->link)) { + if (ltype&IPV6_ADDR_UNICAST) + flags |= IP6_TNL_F_CAP_XMIT; + if (rtype&IPV6_ADDR_UNICAST) + flags |= IP6_TNL_F_CAP_RCV; + } + return flags; +} +EXPORT_SYMBOL(ip6_tnl_get_cap); + +/* called with rcu_read_lock() */ +int ip6_tnl_rcv_ctl(struct ip6_tnl *t, + const struct in6_addr *laddr, + const struct in6_addr *raddr) +{ + struct __ip6_tnl_parm *p = &t->parms; + int ret = 0; + struct net *net = t->net; + + if ((p->flags & IP6_TNL_F_CAP_RCV) || + ((p->flags & IP6_TNL_F_CAP_PER_PACKET) && + (ip6_tnl_get_cap(t, laddr, raddr) & IP6_TNL_F_CAP_RCV))) { + struct net_device *ldev = NULL; + + if (p->link) + ldev = dev_get_by_index_rcu(net, p->link); + + if ((ipv6_addr_is_multicast(laddr) || + likely(ipv6_chk_addr_and_flags(net, laddr, ldev, false, + 0, IFA_F_TENTATIVE))) && + ((p->flags & IP6_TNL_F_ALLOW_LOCAL_REMOTE) || + likely(!ipv6_chk_addr_and_flags(net, raddr, ldev, true, + 0, IFA_F_TENTATIVE)))) + ret = 1; + } + return ret; +} +EXPORT_SYMBOL_GPL(ip6_tnl_rcv_ctl); + +static int __ip6_tnl_rcv(struct ip6_tnl *tunnel, struct sk_buff *skb, + const struct tnl_ptk_info *tpi, + struct metadata_dst *tun_dst, + int (*dscp_ecn_decapsulate)(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb), + bool log_ecn_err) +{ + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + int err; + + if ((!(tpi->flags & TUNNEL_CSUM) && + (tunnel->parms.i_flags & TUNNEL_CSUM)) || + ((tpi->flags & TUNNEL_CSUM) && + !(tunnel->parms.i_flags & TUNNEL_CSUM))) { + tunnel->dev->stats.rx_crc_errors++; + tunnel->dev->stats.rx_errors++; + goto drop; + } + + if (tunnel->parms.i_flags & TUNNEL_SEQ) { + if (!(tpi->flags & TUNNEL_SEQ) || + (tunnel->i_seqno && + (s32)(ntohl(tpi->seq) - tunnel->i_seqno) < 0)) { + tunnel->dev->stats.rx_fifo_errors++; + tunnel->dev->stats.rx_errors++; + goto drop; + } + tunnel->i_seqno = ntohl(tpi->seq) + 1; + } + + skb->protocol = tpi->proto; + + /* Warning: All skb pointers will be invalidated! 
*/ + if (tunnel->dev->type == ARPHRD_ETHER) { + if (!pskb_may_pull(skb, ETH_HLEN)) { + tunnel->dev->stats.rx_length_errors++; + tunnel->dev->stats.rx_errors++; + goto drop; + } + + ipv6h = ipv6_hdr(skb); + skb->protocol = eth_type_trans(skb, tunnel->dev); + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN); + } else { + skb->dev = tunnel->dev; + skb_reset_mac_header(skb); + } + + skb_reset_network_header(skb); + memset(skb->cb, 0, sizeof(struct inet6_skb_parm)); + + __skb_tunnel_rx(skb, tunnel->dev, tunnel->net); + + err = dscp_ecn_decapsulate(tunnel, ipv6h, skb); + if (unlikely(err)) { + if (log_ecn_err) + net_info_ratelimited("non-ECT from %pI6 with DS=%#x\n", + &ipv6h->saddr, + ipv6_get_dsfield(ipv6h)); + if (err > 1) { + ++tunnel->dev->stats.rx_frame_errors; + ++tunnel->dev->stats.rx_errors; + goto drop; + } + } + + dev_sw_netstats_rx_add(tunnel->dev, skb->len); + + skb_scrub_packet(skb, !net_eq(tunnel->net, dev_net(tunnel->dev))); + + if (tun_dst) + skb_dst_set(skb, (struct dst_entry *)tun_dst); + + gro_cells_receive(&tunnel->gro_cells, skb); + return 0; + +drop: + if (tun_dst) + dst_release((struct dst_entry *)tun_dst); + kfree_skb(skb); + return 0; +} + +int ip6_tnl_rcv(struct ip6_tnl *t, struct sk_buff *skb, + const struct tnl_ptk_info *tpi, + struct metadata_dst *tun_dst, + bool log_ecn_err) +{ + int (*dscp_ecn_decapsulate)(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb); + + dscp_ecn_decapsulate = ip6ip6_dscp_ecn_decapsulate; + if (tpi->proto == htons(ETH_P_IP)) + dscp_ecn_decapsulate = ip4ip6_dscp_ecn_decapsulate; + + return __ip6_tnl_rcv(t, skb, tpi, tun_dst, dscp_ecn_decapsulate, + log_ecn_err); +} +EXPORT_SYMBOL(ip6_tnl_rcv); + +static const struct tnl_ptk_info tpi_v6 = { + /* no tunnel info required for ipxip6. */ + .proto = htons(ETH_P_IPV6), +}; + +static const struct tnl_ptk_info tpi_v4 = { + /* no tunnel info required for ipxip6. */ + .proto = htons(ETH_P_IP), +}; + +static const struct tnl_ptk_info tpi_mpls = { + /* no tunnel info required for mplsip6. 
*/ + .proto = htons(ETH_P_MPLS_UC), +}; + +static int ipxip6_rcv(struct sk_buff *skb, u8 ipproto, + const struct tnl_ptk_info *tpi, + int (*dscp_ecn_decapsulate)(const struct ip6_tnl *t, + const struct ipv6hdr *ipv6h, + struct sk_buff *skb)) +{ + struct ip6_tnl *t; + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + struct metadata_dst *tun_dst = NULL; + int ret = -1; + + rcu_read_lock(); + t = ip6_tnl_lookup(dev_net(skb->dev), skb->dev->ifindex, &ipv6h->saddr, &ipv6h->daddr); + + if (t) { + u8 tproto = READ_ONCE(t->parms.proto); + + if (tproto != ipproto && tproto != 0) + goto drop; + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; + ipv6h = ipv6_hdr(skb); + if (!ip6_tnl_rcv_ctl(t, &ipv6h->daddr, &ipv6h->saddr)) + goto drop; + if (iptunnel_pull_header(skb, 0, tpi->proto, false)) + goto drop; + if (t->parms.collect_md) { + tun_dst = ipv6_tun_rx_dst(skb, 0, 0, 0); + if (!tun_dst) + goto drop; + } + ret = __ip6_tnl_rcv(t, skb, tpi, tun_dst, dscp_ecn_decapsulate, + log_ecn_error); + } + + rcu_read_unlock(); + + return ret; + +drop: + rcu_read_unlock(); + kfree_skb(skb); + return 0; +} + +static int ip4ip6_rcv(struct sk_buff *skb) +{ + return ipxip6_rcv(skb, IPPROTO_IPIP, &tpi_v4, + ip4ip6_dscp_ecn_decapsulate); +} + +static int ip6ip6_rcv(struct sk_buff *skb) +{ + return ipxip6_rcv(skb, IPPROTO_IPV6, &tpi_v6, + ip6ip6_dscp_ecn_decapsulate); +} + +static int mplsip6_rcv(struct sk_buff *skb) +{ + return ipxip6_rcv(skb, IPPROTO_MPLS, &tpi_mpls, + mplsip6_dscp_ecn_decapsulate); +} + +struct ipv6_tel_txoption { + struct ipv6_txoptions ops; + __u8 dst_opt[8]; +}; + +static void init_tel_txopt(struct ipv6_tel_txoption *opt, __u8 encap_limit) +{ + memset(opt, 0, sizeof(struct ipv6_tel_txoption)); + + opt->dst_opt[2] = IPV6_TLV_TNL_ENCAP_LIMIT; + opt->dst_opt[3] = 1; + opt->dst_opt[4] = encap_limit; + opt->dst_opt[5] = IPV6_TLV_PADN; + opt->dst_opt[6] = 1; + + opt->ops.dst1opt = (struct ipv6_opt_hdr *) opt->dst_opt; + opt->ops.opt_nflen = 8; +} + +/** + * ip6_tnl_addr_conflict - compare packet addresses to tunnel's own + * @t: the outgoing tunnel device + * @hdr: IPv6 header from the incoming packet + * + * Description: + * Avoid trivial tunneling loop by checking that tunnel exit-point + * doesn't match source of incoming packet. + * + * Return: + * 1 if conflict, + * 0 else + **/ + +static inline bool +ip6_tnl_addr_conflict(const struct ip6_tnl *t, const struct ipv6hdr *hdr) +{ + return ipv6_addr_equal(&t->parms.raddr, &hdr->saddr); +} + +int ip6_tnl_xmit_ctl(struct ip6_tnl *t, + const struct in6_addr *laddr, + const struct in6_addr *raddr) +{ + struct __ip6_tnl_parm *p = &t->parms; + int ret = 0; + struct net *net = t->net; + + if (t->parms.collect_md) + return 1; + + if ((p->flags & IP6_TNL_F_CAP_XMIT) || + ((p->flags & IP6_TNL_F_CAP_PER_PACKET) && + (ip6_tnl_get_cap(t, laddr, raddr) & IP6_TNL_F_CAP_XMIT))) { + struct net_device *ldev = NULL; + + rcu_read_lock(); + if (p->link) + ldev = dev_get_by_index_rcu(net, p->link); + + if (unlikely(!ipv6_chk_addr_and_flags(net, laddr, ldev, false, + 0, IFA_F_TENTATIVE))) + pr_warn_ratelimited("%s xmit: Local address not yet configured!\n", + p->name); + else if (!(p->flags & IP6_TNL_F_ALLOW_LOCAL_REMOTE) && + !ipv6_addr_is_multicast(raddr) && + unlikely(ipv6_chk_addr_and_flags(net, raddr, ldev, + true, 0, IFA_F_TENTATIVE))) + pr_warn_ratelimited("%s xmit: Routing loop! 
Remote address found on this node!\n", + p->name); + else + ret = 1; + rcu_read_unlock(); + } + return ret; +} +EXPORT_SYMBOL_GPL(ip6_tnl_xmit_ctl); + +/** + * ip6_tnl_xmit - encapsulate packet and send + * @skb: the outgoing socket buffer + * @dev: the outgoing tunnel device + * @dsfield: dscp code for outer header + * @fl6: flow of tunneled packet + * @encap_limit: encapsulation limit + * @pmtu: Path MTU is stored if packet is too big + * @proto: next header value + * + * Description: + * Build new header and do some sanity checks on the packet before sending + * it. + * + * Return: + * 0 on success + * -1 fail + * %-EMSGSIZE message too big. return mtu in this case. + **/ + +int ip6_tnl_xmit(struct sk_buff *skb, struct net_device *dev, __u8 dsfield, + struct flowi6 *fl6, int encap_limit, __u32 *pmtu, + __u8 proto) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = t->net; + struct net_device_stats *stats = &t->dev->stats; + struct ipv6hdr *ipv6h; + struct ipv6_tel_txoption opt; + struct dst_entry *dst = NULL, *ndst = NULL; + struct net_device *tdev; + int mtu; + unsigned int eth_hlen = t->dev->type == ARPHRD_ETHER ? ETH_HLEN : 0; + unsigned int psh_hlen = sizeof(struct ipv6hdr) + t->encap_hlen; + unsigned int max_headroom = psh_hlen; + __be16 payload_protocol; + bool use_cache = false; + u8 hop_limit; + int err = -1; + + payload_protocol = skb_protocol(skb, true); + + if (t->parms.collect_md) { + hop_limit = skb_tunnel_info(skb)->key.ttl; + goto route_lookup; + } else { + hop_limit = t->parms.hop_limit; + } + + /* NBMA tunnel */ + if (ipv6_addr_any(&t->parms.raddr)) { + if (payload_protocol == htons(ETH_P_IPV6)) { + struct in6_addr *addr6; + struct neighbour *neigh; + int addr_type; + + if (!skb_dst(skb)) + goto tx_err_link_failure; + + neigh = dst_neigh_lookup(skb_dst(skb), + &ipv6_hdr(skb)->daddr); + if (!neigh) + goto tx_err_link_failure; + + addr6 = (struct in6_addr *)&neigh->primary_key; + addr_type = ipv6_addr_type(addr6); + + if (addr_type == IPV6_ADDR_ANY) + addr6 = &ipv6_hdr(skb)->daddr; + + memcpy(&fl6->daddr, addr6, sizeof(fl6->daddr)); + neigh_release(neigh); + } else if (payload_protocol == htons(ETH_P_IP)) { + const struct rtable *rt = skb_rtable(skb); + + if (!rt) + goto tx_err_link_failure; + + if (rt->rt_gw_family == AF_INET6) + memcpy(&fl6->daddr, &rt->rt_gw6, sizeof(fl6->daddr)); + } + } else if (t->parms.proto != 0 && !(t->parms.flags & + (IP6_TNL_F_USE_ORIG_TCLASS | + IP6_TNL_F_USE_ORIG_FWMARK))) { + /* enable the cache only if neither the outer protocol nor the + * routing decision depends on the current inner header value + */ + use_cache = true; + } + + if (use_cache) + dst = dst_cache_get(&t->dst_cache); + + if (!ip6_tnl_xmit_ctl(t, &fl6->saddr, &fl6->daddr)) + goto tx_err_link_failure; + + if (!dst) { +route_lookup: + /* add dsfield to flowlabel for route lookup */ + fl6->flowlabel = ip6_make_flowinfo(dsfield, fl6->flowlabel); + + dst = ip6_route_output(net, NULL, fl6); + + if (dst->error) + goto tx_err_link_failure; + dst = xfrm_lookup(net, dst, flowi6_to_flowi(fl6), NULL, 0); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + goto tx_err_link_failure; + } + if (t->parms.collect_md && ipv6_addr_any(&fl6->saddr) && + ipv6_dev_get_saddr(net, ip6_dst_idev(dst)->dev, + &fl6->daddr, 0, &fl6->saddr)) + goto tx_err_link_failure; + ndst = dst; + } + + tdev = dst->dev; + + if (tdev == dev) { + stats->collisions++; + net_warn_ratelimited("%s: Local routing loop detected!\n", + t->parms.name); + goto tx_err_dst_release; + } + mtu = dst_mtu(dst) - 
eth_hlen - psh_hlen - t->tun_hlen; + if (encap_limit >= 0) { + max_headroom += 8; + mtu -= 8; + } + mtu = max(mtu, skb->protocol == htons(ETH_P_IPV6) ? + IPV6_MIN_MTU : IPV4_MIN_MTU); + + skb_dst_update_pmtu_no_confirm(skb, mtu); + if (skb->len - t->tun_hlen - eth_hlen > mtu && !skb_is_gso(skb)) { + *pmtu = mtu; + err = -EMSGSIZE; + goto tx_err_dst_release; + } + + if (t->err_count > 0) { + if (time_before(jiffies, + t->err_time + IP6TUNNEL_ERR_TIMEO)) { + t->err_count--; + + dst_link_failure(skb); + } else { + t->err_count = 0; + } + } + + skb_scrub_packet(skb, !net_eq(t->net, dev_net(dev))); + + /* + * Okay, now see if we can stuff it in the buffer as-is. + */ + max_headroom += LL_RESERVED_SPACE(tdev); + + if (skb_headroom(skb) < max_headroom || skb_shared(skb) || + (skb_cloned(skb) && !skb_clone_writable(skb, 0))) { + struct sk_buff *new_skb; + + new_skb = skb_realloc_headroom(skb, max_headroom); + if (!new_skb) + goto tx_err_dst_release; + + if (skb->sk) + skb_set_owner_w(new_skb, skb->sk); + consume_skb(skb); + skb = new_skb; + } + + if (t->parms.collect_md) { + if (t->encap.type != TUNNEL_ENCAP_NONE) + goto tx_err_dst_release; + } else { + if (use_cache && ndst) + dst_cache_set_ip6(&t->dst_cache, ndst, &fl6->saddr); + } + skb_dst_set(skb, dst); + + if (hop_limit == 0) { + if (payload_protocol == htons(ETH_P_IP)) + hop_limit = ip_hdr(skb)->ttl; + else if (payload_protocol == htons(ETH_P_IPV6)) + hop_limit = ipv6_hdr(skb)->hop_limit; + else + hop_limit = ip6_dst_hoplimit(dst); + } + + /* Calculate max headroom for all the headers and adjust + * needed_headroom if necessary. + */ + max_headroom = LL_RESERVED_SPACE(dst->dev) + sizeof(struct ipv6hdr) + + dst->header_len + t->hlen; + if (max_headroom > READ_ONCE(dev->needed_headroom)) + WRITE_ONCE(dev->needed_headroom, max_headroom); + + err = ip6_tnl_encap(skb, t, &proto, fl6); + if (err) + return err; + + if (encap_limit >= 0) { + init_tel_txopt(&opt, encap_limit); + ipv6_push_frag_opts(skb, &opt.ops, &proto); + } + + skb_push(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + ipv6h = ipv6_hdr(skb); + ip6_flow_hdr(ipv6h, dsfield, + ip6_make_flowlabel(net, skb, fl6->flowlabel, true, fl6)); + ipv6h->hop_limit = hop_limit; + ipv6h->nexthdr = proto; + ipv6h->saddr = fl6->saddr; + ipv6h->daddr = fl6->daddr; + ip6tunnel_xmit(NULL, skb, dev); + return 0; +tx_err_link_failure: + stats->tx_carrier_errors++; + dst_link_failure(skb); +tx_err_dst_release: + dst_release(dst); + return err; +} +EXPORT_SYMBOL(ip6_tnl_xmit); + +static inline int +ipxip6_tnl_xmit(struct sk_buff *skb, struct net_device *dev, + u8 protocol) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct ipv6hdr *ipv6h; + const struct iphdr *iph; + int encap_limit = -1; + __u16 offset; + struct flowi6 fl6; + __u8 dsfield, orig_dsfield; + __u32 mtu; + u8 tproto; + int err; + + tproto = READ_ONCE(t->parms.proto); + if (tproto != protocol && tproto != 0) + return -1; + + if (t->parms.collect_md) { + struct ip_tunnel_info *tun_info; + const struct ip_tunnel_key *key; + + tun_info = skb_tunnel_info(skb); + if (unlikely(!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX) || + ip_tunnel_info_af(tun_info) != AF_INET6)) + return -1; + key = &tun_info->key; + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = protocol; + fl6.saddr = key->u.ipv6.src; + fl6.daddr = key->u.ipv6.dst; + fl6.flowlabel = key->label; + dsfield = key->tos; + switch (protocol) { + case IPPROTO_IPIP: + iph = ip_hdr(skb); + orig_dsfield = ipv4_get_dsfield(iph); + break; + case IPPROTO_IPV6: + ipv6h = 
ipv6_hdr(skb); + orig_dsfield = ipv6_get_dsfield(ipv6h); + break; + default: + orig_dsfield = dsfield; + break; + } + } else { + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + encap_limit = t->parms.encap_limit; + if (protocol == IPPROTO_IPV6) { + offset = ip6_tnl_parse_tlv_enc_lim(skb, + skb_network_header(skb)); + /* ip6_tnl_parse_tlv_enc_lim() might have + * reallocated skb->head + */ + if (offset > 0) { + struct ipv6_tlv_tnl_enc_lim *tel; + + tel = (void *)&skb_network_header(skb)[offset]; + if (tel->encap_limit == 0) { + icmpv6_ndo_send(skb, ICMPV6_PARAMPROB, + ICMPV6_HDR_FIELD, offset + 2); + return -1; + } + encap_limit = tel->encap_limit - 1; + } + } + + memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6)); + fl6.flowi6_proto = protocol; + + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) + fl6.flowi6_mark = skb->mark; + else + fl6.flowi6_mark = t->parms.fwmark; + switch (protocol) { + case IPPROTO_IPIP: + iph = ip_hdr(skb); + orig_dsfield = ipv4_get_dsfield(iph); + if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) + dsfield = orig_dsfield; + else + dsfield = ip6_tclass(t->parms.flowinfo); + break; + case IPPROTO_IPV6: + ipv6h = ipv6_hdr(skb); + orig_dsfield = ipv6_get_dsfield(ipv6h); + if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) + dsfield = orig_dsfield; + else + dsfield = ip6_tclass(t->parms.flowinfo); + if (t->parms.flags & IP6_TNL_F_USE_ORIG_FLOWLABEL) + fl6.flowlabel |= ip6_flowlabel(ipv6h); + break; + default: + orig_dsfield = dsfield = ip6_tclass(t->parms.flowinfo); + break; + } + } + + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); + dsfield = INET_ECN_encapsulate(dsfield, orig_dsfield); + + if (iptunnel_handle_offloads(skb, SKB_GSO_IPXIP6)) + return -1; + + skb_set_inner_ipproto(skb, protocol); + + err = ip6_tnl_xmit(skb, dev, dsfield, &fl6, encap_limit, &mtu, + protocol); + if (err != 0) { + /* XXX: send ICMP error even if DF is not set. 
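+ * The tunnel header shrinks the effective path MTU, so (as with
+ * ipip/sit) an ICMP_FRAG_NEEDED resp. ICMPV6_PKT_TOOBIG carrying the
+ * mtu returned by ip6_tnl_xmit() is emitted below even for IPv4
+ * packets that do not have DF set.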
*/ + if (err == -EMSGSIZE) + switch (protocol) { + case IPPROTO_IPIP: + icmp_ndo_send(skb, ICMP_DEST_UNREACH, + ICMP_FRAG_NEEDED, htonl(mtu)); + break; + case IPPROTO_IPV6: + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + break; + default: + break; + } + return -1; + } + + return 0; +} + +static netdev_tx_t +ip6_tnl_start_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net_device_stats *stats = &t->dev->stats; + u8 ipproto; + int ret; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + switch (skb->protocol) { + case htons(ETH_P_IP): + ipproto = IPPROTO_IPIP; + break; + case htons(ETH_P_IPV6): + if (ip6_tnl_addr_conflict(t, ipv6_hdr(skb))) + goto tx_err; + ipproto = IPPROTO_IPV6; + break; + case htons(ETH_P_MPLS_UC): + ipproto = IPPROTO_MPLS; + break; + default: + goto tx_err; + } + + ret = ipxip6_tnl_xmit(skb, dev, ipproto); + if (ret < 0) + goto tx_err; + + return NETDEV_TX_OK; + +tx_err: + stats->tx_errors++; + stats->tx_dropped++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static void ip6_tnl_link_config(struct ip6_tnl *t) +{ + struct net_device *dev = t->dev; + struct net_device *tdev = NULL; + struct __ip6_tnl_parm *p = &t->parms; + struct flowi6 *fl6 = &t->fl.u.ip6; + int t_hlen; + int mtu; + + __dev_addr_set(dev, &p->laddr, sizeof(struct in6_addr)); + memcpy(dev->broadcast, &p->raddr, sizeof(struct in6_addr)); + + /* Set up flowi template */ + fl6->saddr = p->laddr; + fl6->daddr = p->raddr; + fl6->flowi6_oif = p->link; + fl6->flowlabel = 0; + + if (!(p->flags&IP6_TNL_F_USE_ORIG_TCLASS)) + fl6->flowlabel |= IPV6_TCLASS_MASK & p->flowinfo; + if (!(p->flags&IP6_TNL_F_USE_ORIG_FLOWLABEL)) + fl6->flowlabel |= IPV6_FLOWLABEL_MASK & p->flowinfo; + + p->flags &= ~(IP6_TNL_F_CAP_XMIT|IP6_TNL_F_CAP_RCV|IP6_TNL_F_CAP_PER_PACKET); + p->flags |= ip6_tnl_get_cap(t, &p->laddr, &p->raddr); + + if (p->flags&IP6_TNL_F_CAP_XMIT && p->flags&IP6_TNL_F_CAP_RCV) + dev->flags |= IFF_POINTOPOINT; + else + dev->flags &= ~IFF_POINTOPOINT; + + t->tun_hlen = 0; + t->hlen = t->encap_hlen + t->tun_hlen; + t_hlen = t->hlen + sizeof(struct ipv6hdr); + + if (p->flags & IP6_TNL_F_CAP_XMIT) { + int strict = (ipv6_addr_type(&p->raddr) & + (IPV6_ADDR_MULTICAST|IPV6_ADDR_LINKLOCAL)); + + struct rt6_info *rt = rt6_lookup(t->net, + &p->raddr, &p->laddr, + p->link, NULL, strict); + if (rt) { + tdev = rt->dst.dev; + ip6_rt_put(rt); + } + + if (!tdev && p->link) + tdev = __dev_get_by_index(t->net, p->link); + + if (tdev) { + dev->hard_header_len = tdev->hard_header_len + t_hlen; + mtu = min_t(unsigned int, tdev->mtu, IP6_MAX_MTU); + + mtu = mtu - t_hlen; + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + mtu -= 8; + + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + WRITE_ONCE(dev->mtu, mtu); + } + } +} + +/** + * ip6_tnl_change - update the tunnel parameters + * @t: tunnel to be changed + * @p: tunnel configuration parameters + * + * Description: + * ip6_tnl_change() updates the tunnel parameters + **/ + +static void +ip6_tnl_change(struct ip6_tnl *t, const struct __ip6_tnl_parm *p) +{ + t->parms.laddr = p->laddr; + t->parms.raddr = p->raddr; + t->parms.flags = p->flags; + t->parms.hop_limit = p->hop_limit; + t->parms.encap_limit = p->encap_limit; + t->parms.flowinfo = p->flowinfo; + t->parms.link = p->link; + t->parms.proto = p->proto; + t->parms.fwmark = p->fwmark; + dst_cache_reset(&t->dst_cache); + ip6_tnl_link_config(t); +} + +static void ip6_tnl_update(struct ip6_tnl *t, struct __ip6_tnl_parm *p) +{ + struct net *net = t->net; + struct ip6_tnl_net *ip6n 
= net_generic(net, ip6_tnl_net_id); + + ip6_tnl_unlink(ip6n, t); + synchronize_net(); + ip6_tnl_change(t, p); + ip6_tnl_link(ip6n, t); + netdev_state_change(t->dev); +} + +static void ip6_tnl0_update(struct ip6_tnl *t, struct __ip6_tnl_parm *p) +{ + /* for default tnl0 device allow to change only the proto */ + t->parms.proto = p->proto; + netdev_state_change(t->dev); +} + +static void +ip6_tnl_parm_from_user(struct __ip6_tnl_parm *p, const struct ip6_tnl_parm *u) +{ + p->laddr = u->laddr; + p->raddr = u->raddr; + p->flags = u->flags; + p->hop_limit = u->hop_limit; + p->encap_limit = u->encap_limit; + p->flowinfo = u->flowinfo; + p->link = u->link; + p->proto = u->proto; + memcpy(p->name, u->name, sizeof(u->name)); +} + +static void +ip6_tnl_parm_to_user(struct ip6_tnl_parm *u, const struct __ip6_tnl_parm *p) +{ + u->laddr = p->laddr; + u->raddr = p->raddr; + u->flags = p->flags; + u->hop_limit = p->hop_limit; + u->encap_limit = p->encap_limit; + u->flowinfo = p->flowinfo; + u->link = p->link; + u->proto = p->proto; + memcpy(u->name, p->name, sizeof(u->name)); +} + +/** + * ip6_tnl_siocdevprivate - configure ipv6 tunnels from userspace + * @dev: virtual device associated with tunnel + * @ifr: unused + * @data: parameters passed from userspace + * @cmd: command to be performed + * + * Description: + * ip6_tnl_ioctl() is used for managing IPv6 tunnels + * from userspace. + * + * The possible commands are the following: + * %SIOCGETTUNNEL: get tunnel parameters for device + * %SIOCADDTUNNEL: add tunnel matching given tunnel parameters + * %SIOCCHGTUNNEL: change tunnel parameters to those given + * %SIOCDELTUNNEL: delete tunnel + * + * The fallback device "ip6tnl0", created during module + * initialization, can be used for creating other tunnel devices. 
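+ *
+ * A rough userspace sketch (illustrative only, not part of this file;
+ * needs <net/if.h>, <sys/ioctl.h> and <linux/ip6_tunnel.h>): iproute2-style
+ * callers pass a struct ip6_tnl_parm through ifr_ifru.ifru_data on any
+ * open AF_INET6 datagram socket fd, e.g.
+ *
+ *	struct ip6_tnl_parm p = { .proto = IPPROTO_IPV6 };
+ *	struct ifreq ifr;
+ *
+ *	memset(&ifr, 0, sizeof(ifr));
+ *	strncpy(ifr.ifr_name, "ip6tnl0", IFNAMSIZ);
+ *	ifr.ifr_ifru.ifru_data = (void *)&p;
+ *	ioctl(fd, SIOCGETTUNNEL, &ifr);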
+ * + * Return: + * 0 on success, + * %-EFAULT if unable to copy data to or from userspace, + * %-EPERM if current process hasn't %CAP_NET_ADMIN set + * %-EINVAL if passed tunnel parameters are invalid, + * %-EEXIST if changing a tunnel's parameters would cause a conflict + * %-ENODEV if attempting to change or delete a nonexisting device + **/ + +static int +ip6_tnl_siocdevprivate(struct net_device *dev, struct ifreq *ifr, + void __user *data, int cmd) +{ + int err = 0; + struct ip6_tnl_parm p; + struct __ip6_tnl_parm p1; + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = t->net; + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + + memset(&p1, 0, sizeof(p1)); + + switch (cmd) { + case SIOCGETTUNNEL: + if (dev == ip6n->fb_tnl_dev) { + if (copy_from_user(&p, data, sizeof(p))) { + err = -EFAULT; + break; + } + ip6_tnl_parm_from_user(&p1, &p); + t = ip6_tnl_locate(net, &p1, 0); + if (IS_ERR(t)) + t = netdev_priv(dev); + } else { + memset(&p, 0, sizeof(p)); + } + ip6_tnl_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + break; + case SIOCADDTUNNEL: + case SIOCCHGTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + break; + err = -EINVAL; + if (p.proto != IPPROTO_IPV6 && p.proto != IPPROTO_IPIP && + p.proto != 0) + break; + ip6_tnl_parm_from_user(&p1, &p); + t = ip6_tnl_locate(net, &p1, cmd == SIOCADDTUNNEL); + if (cmd == SIOCCHGTUNNEL) { + if (!IS_ERR(t)) { + if (t->dev != dev) { + err = -EEXIST; + break; + } + } else + t = netdev_priv(dev); + if (dev == ip6n->fb_tnl_dev) + ip6_tnl0_update(t, &p1); + else + ip6_tnl_update(t, &p1); + } + if (!IS_ERR(t)) { + err = 0; + ip6_tnl_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + + } else { + err = PTR_ERR(t); + } + break; + case SIOCDELTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + + if (dev == ip6n->fb_tnl_dev) { + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + break; + err = -ENOENT; + ip6_tnl_parm_from_user(&p1, &p); + t = ip6_tnl_locate(net, &p1, 0); + if (IS_ERR(t)) + break; + err = -EPERM; + if (t->dev == ip6n->fb_tnl_dev) + break; + dev = t->dev; + } + err = 0; + unregister_netdevice(dev); + break; + default: + err = -EINVAL; + } + return err; +} + +/** + * ip6_tnl_change_mtu - change mtu manually for tunnel device + * @dev: virtual device associated with tunnel + * @new_mtu: the new mtu + * + * Return: + * 0 on success, + * %-EINVAL if mtu too small + **/ + +int ip6_tnl_change_mtu(struct net_device *dev, int new_mtu) +{ + struct ip6_tnl *tnl = netdev_priv(dev); + + if (tnl->parms.proto == IPPROTO_IPV6) { + if (new_mtu < IPV6_MIN_MTU) + return -EINVAL; + } else { + if (new_mtu < ETH_MIN_MTU) + return -EINVAL; + } + if (tnl->parms.proto == IPPROTO_IPV6 || tnl->parms.proto == 0) { + if (new_mtu > IP6_MAX_MTU - dev->hard_header_len) + return -EINVAL; + } else { + if (new_mtu > IP_MAX_MTU - dev->hard_header_len) + return -EINVAL; + } + dev->mtu = new_mtu; + return 0; +} +EXPORT_SYMBOL(ip6_tnl_change_mtu); + +int ip6_tnl_get_iflink(const struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + + return t->parms.link; +} +EXPORT_SYMBOL(ip6_tnl_get_iflink); + +int ip6_tnl_encap_add_ops(const struct ip6_tnl_encap_ops *ops, + unsigned int num) +{ + if (num >= MAX_IPTUN_ENCAP_OPS) + return -ERANGE; + + return !cmpxchg((const struct ip6_tnl_encap_ops **) + &ip6tun_encaps[num], + NULL, ops) ? 
0 : -1; +} +EXPORT_SYMBOL(ip6_tnl_encap_add_ops); + +int ip6_tnl_encap_del_ops(const struct ip6_tnl_encap_ops *ops, + unsigned int num) +{ + int ret; + + if (num >= MAX_IPTUN_ENCAP_OPS) + return -ERANGE; + + ret = (cmpxchg((const struct ip6_tnl_encap_ops **) + &ip6tun_encaps[num], + ops, NULL) == ops) ? 0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(ip6_tnl_encap_del_ops); + +int ip6_tnl_encap_setup(struct ip6_tnl *t, + struct ip_tunnel_encap *ipencap) +{ + int hlen; + + memset(&t->encap, 0, sizeof(t->encap)); + + hlen = ip6_encap_hlen(ipencap); + if (hlen < 0) + return hlen; + + t->encap.type = ipencap->type; + t->encap.sport = ipencap->sport; + t->encap.dport = ipencap->dport; + t->encap.flags = ipencap->flags; + + t->encap_hlen = hlen; + t->hlen = t->encap_hlen + t->tun_hlen; + + return 0; +} +EXPORT_SYMBOL_GPL(ip6_tnl_encap_setup); + +static const struct net_device_ops ip6_tnl_netdev_ops = { + .ndo_init = ip6_tnl_dev_init, + .ndo_uninit = ip6_tnl_dev_uninit, + .ndo_start_xmit = ip6_tnl_start_xmit, + .ndo_siocdevprivate = ip6_tnl_siocdevprivate, + .ndo_change_mtu = ip6_tnl_change_mtu, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip6_tnl_get_iflink, +}; + +#define IPXIPX_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_HIGHDMA | \ + NETIF_F_GSO_SOFTWARE | \ + NETIF_F_HW_CSUM) + +/** + * ip6_tnl_dev_setup - setup virtual tunnel device + * @dev: virtual device associated with tunnel + * + * Description: + * Initialize function pointers and device parameters + **/ + +static void ip6_tnl_dev_setup(struct net_device *dev) +{ + dev->netdev_ops = &ip6_tnl_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ip6_dev_free; + + dev->type = ARPHRD_TUNNEL6; + dev->flags |= IFF_NOARP; + dev->addr_len = sizeof(struct in6_addr); + dev->features |= NETIF_F_LLTX; + netif_keep_dst(dev); + + dev->features |= IPXIPX_FEATURES; + dev->hw_features |= IPXIPX_FEATURES; + + /* This perm addr will be used as interface identifier by IPv6 */ + dev->addr_assign_type = NET_ADDR_RANDOM; + eth_random_addr(dev->perm_addr); +} + + +/** + * ip6_tnl_dev_init_gen - general initializer for all tunnel devices + * @dev: virtual device associated with tunnel + **/ + +static inline int +ip6_tnl_dev_init_gen(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + int ret; + int t_hlen; + + t->dev = dev; + t->net = dev_net(dev); + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + ret = dst_cache_init(&t->dst_cache, GFP_KERNEL); + if (ret) + goto free_stats; + + ret = gro_cells_init(&t->gro_cells, dev); + if (ret) + goto destroy_dst; + + t->tun_hlen = 0; + t->hlen = t->encap_hlen + t->tun_hlen; + t_hlen = t->hlen + sizeof(struct ipv6hdr); + + dev->type = ARPHRD_TUNNEL6; + dev->hard_header_len = LL_MAX_HEADER + t_hlen; + dev->mtu = ETH_DATA_LEN - t_hlen; + if (!(t->parms.flags & IP6_TNL_F_IGN_ENCAP_LIMIT)) + dev->mtu -= 8; + dev->min_mtu = ETH_MIN_MTU; + dev->max_mtu = IP6_MAX_MTU - dev->hard_header_len; + + netdev_hold(dev, &t->dev_tracker, GFP_KERNEL); + return 0; + +destroy_dst: + dst_cache_destroy(&t->dst_cache); +free_stats: + free_percpu(dev->tstats); + dev->tstats = NULL; + + return ret; +} + +/** + * ip6_tnl_dev_init - initializer for all non fallback tunnel devices + * @dev: virtual device associated with tunnel + **/ + +static int ip6_tnl_dev_init(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + int err = ip6_tnl_dev_init_gen(dev); + + 
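+	/* ip6_tnl_dev_init_gen() allocated dev->tstats, the dst cache and
+	 * the GRO cells, set the MTU bounds and took a tracked reference on
+	 * the device; check its result before applying the per-tunnel link
+	 * configuration below.
+	 */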
if (err) + return err; + ip6_tnl_link_config(t); + if (t->parms.collect_md) + netif_keep_dst(dev); + return 0; +} + +/** + * ip6_fb_tnl_dev_init - initializer for fallback tunnel device + * @dev: fallback device + * + * Return: 0 + **/ + +static int __net_init ip6_fb_tnl_dev_init(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = dev_net(dev); + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + + t->parms.proto = IPPROTO_IPV6; + + rcu_assign_pointer(ip6n->tnls_wc[0], t); + return 0; +} + +static int ip6_tnl_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + u8 proto; + + if (!data || !data[IFLA_IPTUN_PROTO]) + return 0; + + proto = nla_get_u8(data[IFLA_IPTUN_PROTO]); + if (proto != IPPROTO_IPV6 && + proto != IPPROTO_IPIP && + proto != 0) + return -EINVAL; + + return 0; +} + +static void ip6_tnl_netlink_parms(struct nlattr *data[], + struct __ip6_tnl_parm *parms) +{ + memset(parms, 0, sizeof(*parms)); + + if (!data) + return; + + if (data[IFLA_IPTUN_LINK]) + parms->link = nla_get_u32(data[IFLA_IPTUN_LINK]); + + if (data[IFLA_IPTUN_LOCAL]) + parms->laddr = nla_get_in6_addr(data[IFLA_IPTUN_LOCAL]); + + if (data[IFLA_IPTUN_REMOTE]) + parms->raddr = nla_get_in6_addr(data[IFLA_IPTUN_REMOTE]); + + if (data[IFLA_IPTUN_TTL]) + parms->hop_limit = nla_get_u8(data[IFLA_IPTUN_TTL]); + + if (data[IFLA_IPTUN_ENCAP_LIMIT]) + parms->encap_limit = nla_get_u8(data[IFLA_IPTUN_ENCAP_LIMIT]); + + if (data[IFLA_IPTUN_FLOWINFO]) + parms->flowinfo = nla_get_be32(data[IFLA_IPTUN_FLOWINFO]); + + if (data[IFLA_IPTUN_FLAGS]) + parms->flags = nla_get_u32(data[IFLA_IPTUN_FLAGS]); + + if (data[IFLA_IPTUN_PROTO]) + parms->proto = nla_get_u8(data[IFLA_IPTUN_PROTO]); + + if (data[IFLA_IPTUN_COLLECT_METADATA]) + parms->collect_md = true; + + if (data[IFLA_IPTUN_FWMARK]) + parms->fwmark = nla_get_u32(data[IFLA_IPTUN_FWMARK]); +} + +static int ip6_tnl_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net *net = dev_net(dev); + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + struct ip_tunnel_encap ipencap; + struct ip6_tnl *nt, *t; + int err; + + nt = netdev_priv(dev); + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + err = ip6_tnl_encap_setup(nt, &ipencap); + if (err < 0) + return err; + } + + ip6_tnl_netlink_parms(data, &nt->parms); + + if (nt->parms.collect_md) { + if (rtnl_dereference(ip6n->collect_md_tun)) + return -EEXIST; + } else { + t = ip6_tnl_locate(net, &nt->parms, 0); + if (!IS_ERR(t)) + return -EEXIST; + } + + err = ip6_tnl_create2(dev); + if (!err && tb[IFLA_MTU]) + ip6_tnl_change_mtu(dev, nla_get_u32(tb[IFLA_MTU])); + + return err; +} + +static int ip6_tnl_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct __ip6_tnl_parm p; + struct net *net = t->net; + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + struct ip_tunnel_encap ipencap; + + if (dev == ip6n->fb_tnl_dev) + return -EINVAL; + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + int err = ip6_tnl_encap_setup(t, &ipencap); + + if (err < 0) + return err; + } + ip6_tnl_netlink_parms(data, &p); + if (p.collect_md) + return -EINVAL; + + t = ip6_tnl_locate(net, &p, 0); + if (!IS_ERR(t)) { + if (t->dev != dev) + return -EEXIST; + } else + t = netdev_priv(dev); + + ip6_tnl_update(t, &p); + return 0; +} + +static void 
ip6_tnl_dellink(struct net_device *dev, struct list_head *head) +{ + struct net *net = dev_net(dev); + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + + if (dev != ip6n->fb_tnl_dev) + unregister_netdevice_queue(dev, head); +} + +static size_t ip6_tnl_get_size(const struct net_device *dev) +{ + return + /* IFLA_IPTUN_LINK */ + nla_total_size(4) + + /* IFLA_IPTUN_LOCAL */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_IPTUN_REMOTE */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_IPTUN_TTL */ + nla_total_size(1) + + /* IFLA_IPTUN_ENCAP_LIMIT */ + nla_total_size(1) + + /* IFLA_IPTUN_FLOWINFO */ + nla_total_size(4) + + /* IFLA_IPTUN_FLAGS */ + nla_total_size(4) + + /* IFLA_IPTUN_PROTO */ + nla_total_size(1) + + /* IFLA_IPTUN_ENCAP_TYPE */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_FLAGS */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_SPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_DPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_COLLECT_METADATA */ + nla_total_size(0) + + /* IFLA_IPTUN_FWMARK */ + nla_total_size(4) + + 0; +} + +static int ip6_tnl_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip6_tnl *tunnel = netdev_priv(dev); + struct __ip6_tnl_parm *parm = &tunnel->parms; + + if (nla_put_u32(skb, IFLA_IPTUN_LINK, parm->link) || + nla_put_in6_addr(skb, IFLA_IPTUN_LOCAL, &parm->laddr) || + nla_put_in6_addr(skb, IFLA_IPTUN_REMOTE, &parm->raddr) || + nla_put_u8(skb, IFLA_IPTUN_TTL, parm->hop_limit) || + nla_put_u8(skb, IFLA_IPTUN_ENCAP_LIMIT, parm->encap_limit) || + nla_put_be32(skb, IFLA_IPTUN_FLOWINFO, parm->flowinfo) || + nla_put_u32(skb, IFLA_IPTUN_FLAGS, parm->flags) || + nla_put_u8(skb, IFLA_IPTUN_PROTO, parm->proto) || + nla_put_u32(skb, IFLA_IPTUN_FWMARK, parm->fwmark)) + goto nla_put_failure; + + if (nla_put_u16(skb, IFLA_IPTUN_ENCAP_TYPE, tunnel->encap.type) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_SPORT, tunnel->encap.sport) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_DPORT, tunnel->encap.dport) || + nla_put_u16(skb, IFLA_IPTUN_ENCAP_FLAGS, tunnel->encap.flags)) + goto nla_put_failure; + + if (parm->collect_md) + if (nla_put_flag(skb, IFLA_IPTUN_COLLECT_METADATA)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +struct net *ip6_tnl_get_link_net(const struct net_device *dev) +{ + struct ip6_tnl *tunnel = netdev_priv(dev); + + return tunnel->net; +} +EXPORT_SYMBOL(ip6_tnl_get_link_net); + +static const struct nla_policy ip6_tnl_policy[IFLA_IPTUN_MAX + 1] = { + [IFLA_IPTUN_LINK] = { .type = NLA_U32 }, + [IFLA_IPTUN_LOCAL] = { .len = sizeof(struct in6_addr) }, + [IFLA_IPTUN_REMOTE] = { .len = sizeof(struct in6_addr) }, + [IFLA_IPTUN_TTL] = { .type = NLA_U8 }, + [IFLA_IPTUN_ENCAP_LIMIT] = { .type = NLA_U8 }, + [IFLA_IPTUN_FLOWINFO] = { .type = NLA_U32 }, + [IFLA_IPTUN_FLAGS] = { .type = NLA_U32 }, + [IFLA_IPTUN_PROTO] = { .type = NLA_U8 }, + [IFLA_IPTUN_ENCAP_TYPE] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_FLAGS] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_SPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_DPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_COLLECT_METADATA] = { .type = NLA_FLAG }, + [IFLA_IPTUN_FWMARK] = { .type = NLA_U32 }, +}; + +static struct rtnl_link_ops ip6_link_ops __read_mostly = { + .kind = "ip6tnl", + .maxtype = IFLA_IPTUN_MAX, + .policy = ip6_tnl_policy, + .priv_size = sizeof(struct ip6_tnl), + .setup = ip6_tnl_dev_setup, + .validate = ip6_tnl_validate, + .newlink = ip6_tnl_newlink, + .changelink = ip6_tnl_changelink, + .dellink = ip6_tnl_dellink, + .get_size = ip6_tnl_get_size, + 
.fill_info = ip6_tnl_fill_info, + .get_link_net = ip6_tnl_get_link_net, +}; + +static struct xfrm6_tunnel ip4ip6_handler __read_mostly = { + .handler = ip4ip6_rcv, + .err_handler = ip4ip6_err, + .priority = 1, +}; + +static struct xfrm6_tunnel ip6ip6_handler __read_mostly = { + .handler = ip6ip6_rcv, + .err_handler = ip6ip6_err, + .priority = 1, +}; + +static struct xfrm6_tunnel mplsip6_handler __read_mostly = { + .handler = mplsip6_rcv, + .err_handler = mplsip6_err, + .priority = 1, +}; + +static void __net_exit ip6_tnl_destroy_tunnels(struct net *net, struct list_head *list) +{ + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + struct net_device *dev, *aux; + int h; + struct ip6_tnl *t; + + for_each_netdev_safe(net, dev, aux) + if (dev->rtnl_link_ops == &ip6_link_ops) + unregister_netdevice_queue(dev, list); + + for (h = 0; h < IP6_TUNNEL_HASH_SIZE; h++) { + t = rtnl_dereference(ip6n->tnls_r_l[h]); + while (t) { + /* If dev is in the same netns, it has already + * been added to the list by the previous loop. + */ + if (!net_eq(dev_net(t->dev), net)) + unregister_netdevice_queue(t->dev, list); + t = rtnl_dereference(t->next); + } + } + + t = rtnl_dereference(ip6n->tnls_wc[0]); + while (t) { + /* If dev is in the same netns, it has already + * been added to the list by the previous loop. + */ + if (!net_eq(dev_net(t->dev), net)) + unregister_netdevice_queue(t->dev, list); + t = rtnl_dereference(t->next); + } +} + +static int __net_init ip6_tnl_init_net(struct net *net) +{ + struct ip6_tnl_net *ip6n = net_generic(net, ip6_tnl_net_id); + struct ip6_tnl *t = NULL; + int err; + + ip6n->tnls[0] = ip6n->tnls_wc; + ip6n->tnls[1] = ip6n->tnls_r_l; + + if (!net_has_fallback_tunnels(net)) + return 0; + err = -ENOMEM; + ip6n->fb_tnl_dev = alloc_netdev(sizeof(struct ip6_tnl), "ip6tnl0", + NET_NAME_UNKNOWN, ip6_tnl_dev_setup); + + if (!ip6n->fb_tnl_dev) + goto err_alloc_dev; + dev_net_set(ip6n->fb_tnl_dev, net); + ip6n->fb_tnl_dev->rtnl_link_ops = &ip6_link_ops; + /* FB netdevice is special: we have one, and only one per netns. + * Allowing to move it to another netns is clearly unsafe. 
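+ * NETIF_F_NETNS_LOCAL (set just below) makes the core reject any
+ * attempt to move the device to a different network namespace.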
+ */ + ip6n->fb_tnl_dev->features |= NETIF_F_NETNS_LOCAL; + + err = ip6_fb_tnl_dev_init(ip6n->fb_tnl_dev); + if (err < 0) + goto err_register; + + err = register_netdev(ip6n->fb_tnl_dev); + if (err < 0) + goto err_register; + + t = netdev_priv(ip6n->fb_tnl_dev); + + strcpy(t->parms.name, ip6n->fb_tnl_dev->name); + return 0; + +err_register: + free_netdev(ip6n->fb_tnl_dev); +err_alloc_dev: + return err; +} + +static void __net_exit ip6_tnl_exit_batch_net(struct list_head *net_list) +{ + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ip6_tnl_destroy_tunnels(net, &list); + unregister_netdevice_many(&list); + rtnl_unlock(); +} + +static struct pernet_operations ip6_tnl_net_ops = { + .init = ip6_tnl_init_net, + .exit_batch = ip6_tnl_exit_batch_net, + .id = &ip6_tnl_net_id, + .size = sizeof(struct ip6_tnl_net), +}; + +/** + * ip6_tunnel_init - register protocol and reserve needed resources + * + * Return: 0 on success + **/ + +static int __init ip6_tunnel_init(void) +{ + int err; + + if (!ipv6_mod_enabled()) + return -EOPNOTSUPP; + + err = register_pernet_device(&ip6_tnl_net_ops); + if (err < 0) + goto out_pernet; + + err = xfrm6_tunnel_register(&ip4ip6_handler, AF_INET); + if (err < 0) { + pr_err("%s: can't register ip4ip6\n", __func__); + goto out_ip4ip6; + } + + err = xfrm6_tunnel_register(&ip6ip6_handler, AF_INET6); + if (err < 0) { + pr_err("%s: can't register ip6ip6\n", __func__); + goto out_ip6ip6; + } + + if (ip6_tnl_mpls_supported()) { + err = xfrm6_tunnel_register(&mplsip6_handler, AF_MPLS); + if (err < 0) { + pr_err("%s: can't register mplsip6\n", __func__); + goto out_mplsip6; + } + } + + err = rtnl_link_register(&ip6_link_ops); + if (err < 0) + goto rtnl_link_failed; + + return 0; + +rtnl_link_failed: + if (ip6_tnl_mpls_supported()) + xfrm6_tunnel_deregister(&mplsip6_handler, AF_MPLS); +out_mplsip6: + xfrm6_tunnel_deregister(&ip6ip6_handler, AF_INET6); +out_ip6ip6: + xfrm6_tunnel_deregister(&ip4ip6_handler, AF_INET); +out_ip4ip6: + unregister_pernet_device(&ip6_tnl_net_ops); +out_pernet: + return err; +} + +/** + * ip6_tunnel_cleanup - free resources and unregister protocol + **/ + +static void __exit ip6_tunnel_cleanup(void) +{ + rtnl_link_unregister(&ip6_link_ops); + if (xfrm6_tunnel_deregister(&ip4ip6_handler, AF_INET)) + pr_info("%s: can't deregister ip4ip6\n", __func__); + + if (xfrm6_tunnel_deregister(&ip6ip6_handler, AF_INET6)) + pr_info("%s: can't deregister ip6ip6\n", __func__); + + if (ip6_tnl_mpls_supported() && + xfrm6_tunnel_deregister(&mplsip6_handler, AF_MPLS)) + pr_info("%s: can't deregister mplsip6\n", __func__); + unregister_pernet_device(&ip6_tnl_net_ops); +} + +module_init(ip6_tunnel_init); +module_exit(ip6_tunnel_cleanup); diff --git a/net/ipv6/ip6_udp_tunnel.c b/net/ipv6/ip6_udp_tunnel.c new file mode 100644 index 000000000..cdc4d4ee2 --- /dev/null +++ b/net/ipv6/ip6_udp_tunnel.c @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int udp_sock_create6(struct net *net, struct udp_port_cfg *cfg, + struct socket **sockp) +{ + struct sockaddr_in6 udp6_addr = {}; + int err; + struct socket *sock = NULL; + + err = sock_create_kern(net, AF_INET6, SOCK_DGRAM, 0, &sock); + if (err < 0) + goto error; + + if (cfg->ipv6_v6only) { + err = ip6_sock_set_v6only(sock->sk); + if (err < 0) + goto error; + } + if (cfg->bind_ifindex) { + err = sock_bindtoindex(sock->sk, cfg->bind_ifindex, 
true); + if (err < 0) + goto error; + } + + udp6_addr.sin6_family = AF_INET6; + memcpy(&udp6_addr.sin6_addr, &cfg->local_ip6, + sizeof(udp6_addr.sin6_addr)); + udp6_addr.sin6_port = cfg->local_udp_port; + err = kernel_bind(sock, (struct sockaddr *)&udp6_addr, + sizeof(udp6_addr)); + if (err < 0) + goto error; + + if (cfg->peer_udp_port) { + memset(&udp6_addr, 0, sizeof(udp6_addr)); + udp6_addr.sin6_family = AF_INET6; + memcpy(&udp6_addr.sin6_addr, &cfg->peer_ip6, + sizeof(udp6_addr.sin6_addr)); + udp6_addr.sin6_port = cfg->peer_udp_port; + err = kernel_connect(sock, + (struct sockaddr *)&udp6_addr, + sizeof(udp6_addr), 0); + } + if (err < 0) + goto error; + + udp_set_no_check6_tx(sock->sk, !cfg->use_udp6_tx_checksums); + udp_set_no_check6_rx(sock->sk, !cfg->use_udp6_rx_checksums); + + *sockp = sock; + return 0; + +error: + if (sock) { + kernel_sock_shutdown(sock, SHUT_RDWR); + sock_release(sock); + } + *sockp = NULL; + return err; +} +EXPORT_SYMBOL_GPL(udp_sock_create6); + +int udp_tunnel6_xmit_skb(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, + struct net_device *dev, struct in6_addr *saddr, + struct in6_addr *daddr, + __u8 prio, __u8 ttl, __be32 label, + __be16 src_port, __be16 dst_port, bool nocheck) +{ + struct udphdr *uh; + struct ipv6hdr *ip6h; + + __skb_push(skb, sizeof(*uh)); + skb_reset_transport_header(skb); + uh = udp_hdr(skb); + + uh->dest = dst_port; + uh->source = src_port; + + uh->len = htons(skb->len); + + skb_dst_set(skb, dst); + + udp6_set_csum(nocheck, skb, saddr, daddr, skb->len); + + __skb_push(skb, sizeof(*ip6h)); + skb_reset_network_header(skb); + ip6h = ipv6_hdr(skb); + ip6_flow_hdr(ip6h, prio, label); + ip6h->payload_len = htons(skb->len); + ip6h->nexthdr = IPPROTO_UDP; + ip6h->hop_limit = ttl; + ip6h->daddr = *daddr; + ip6h->saddr = *saddr; + + ip6tunnel_xmit(sk, skb, dev); + return 0; +} +EXPORT_SYMBOL_GPL(udp_tunnel6_xmit_skb); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/ip6_vti.c b/net/ipv6/ip6_vti.c new file mode 100644 index 000000000..cb71463bb --- /dev/null +++ b/net/ipv6/ip6_vti.c @@ -0,0 +1,1335 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 virtual tunneling interface + * + * Copyright (C) 2013 secunet Security Networks AG + * + * Author: + * Steffen Klassert + * + * Based on: + * net/ipv6/ip6_tunnel.c + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define IP6_VTI_HASH_SIZE_SHIFT 5 +#define IP6_VTI_HASH_SIZE (1 << IP6_VTI_HASH_SIZE_SHIFT) + +static u32 HASH(const struct in6_addr *addr1, const struct in6_addr *addr2) +{ + u32 hash = ipv6_addr_hash(addr1) ^ ipv6_addr_hash(addr2); + + return hash_32(hash, IP6_VTI_HASH_SIZE_SHIFT); +} + +static int vti6_dev_init(struct net_device *dev); +static void vti6_dev_setup(struct net_device *dev); +static struct rtnl_link_ops vti6_link_ops __read_mostly; + +static unsigned int vti6_net_id __read_mostly; +struct vti6_net { + /* the vti6 tunnel fallback device */ + struct net_device *fb_tnl_dev; + /* lists for storing tunnels in use */ + struct ip6_tnl __rcu *tnls_r_l[IP6_VTI_HASH_SIZE]; + struct ip6_tnl __rcu *tnls_wc[1]; + struct ip6_tnl __rcu **tnls[2]; +}; + +#define for_each_vti6_tunnel_rcu(start) \ + for (t = rcu_dereference(start); t; t = rcu_dereference(t->next)) + +/** + * 
vti6_tnl_lookup - fetch tunnel matching the end-point addresses + * @net: network namespace + * @remote: the address of the tunnel exit-point + * @local: the address of the tunnel entry-point + * + * Return: + * tunnel matching given end-points if found, + * else fallback tunnel if its device is up, + * else %NULL + **/ +static struct ip6_tnl * +vti6_tnl_lookup(struct net *net, const struct in6_addr *remote, + const struct in6_addr *local) +{ + unsigned int hash = HASH(remote, local); + struct ip6_tnl *t; + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + struct in6_addr any; + + for_each_vti6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (ipv6_addr_equal(local, &t->parms.laddr) && + ipv6_addr_equal(remote, &t->parms.raddr) && + (t->dev->flags & IFF_UP)) + return t; + } + + memset(&any, 0, sizeof(any)); + hash = HASH(&any, local); + for_each_vti6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (ipv6_addr_equal(local, &t->parms.laddr) && + (t->dev->flags & IFF_UP)) + return t; + } + + hash = HASH(remote, &any); + for_each_vti6_tunnel_rcu(ip6n->tnls_r_l[hash]) { + if (ipv6_addr_equal(remote, &t->parms.raddr) && + (t->dev->flags & IFF_UP)) + return t; + } + + t = rcu_dereference(ip6n->tnls_wc[0]); + if (t && (t->dev->flags & IFF_UP)) + return t; + + return NULL; +} + +/** + * vti6_tnl_bucket - get head of list matching given tunnel parameters + * @ip6n: the private data for ip6_vti in the netns + * @p: parameters containing tunnel end-points + * + * Description: + * vti6_tnl_bucket() returns the head of the list matching the + * &struct in6_addr entries laddr and raddr in @p. + * + * Return: head of IPv6 tunnel list + **/ +static struct ip6_tnl __rcu ** +vti6_tnl_bucket(struct vti6_net *ip6n, const struct __ip6_tnl_parm *p) +{ + const struct in6_addr *remote = &p->raddr; + const struct in6_addr *local = &p->laddr; + unsigned int h = 0; + int prio = 0; + + if (!ipv6_addr_any(remote) || !ipv6_addr_any(local)) { + prio = 1; + h = HASH(remote, local); + } + return &ip6n->tnls[prio][h]; +} + +static void +vti6_tnl_link(struct vti6_net *ip6n, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp = vti6_tnl_bucket(ip6n, &t->parms); + + rcu_assign_pointer(t->next, rtnl_dereference(*tp)); + rcu_assign_pointer(*tp, t); +} + +static void +vti6_tnl_unlink(struct vti6_net *ip6n, struct ip6_tnl *t) +{ + struct ip6_tnl __rcu **tp; + struct ip6_tnl *iter; + + for (tp = vti6_tnl_bucket(ip6n, &t->parms); + (iter = rtnl_dereference(*tp)) != NULL; + tp = &iter->next) { + if (t == iter) { + rcu_assign_pointer(*tp, t->next); + break; + } + } +} + +static void vti6_dev_free(struct net_device *dev) +{ + free_percpu(dev->tstats); +} + +static int vti6_tnl_create2(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = dev_net(dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + int err; + + dev->rtnl_link_ops = &vti6_link_ops; + err = register_netdevice(dev); + if (err < 0) + goto out; + + strcpy(t->parms.name, dev->name); + + vti6_tnl_link(ip6n, t); + + return 0; + +out: + return err; +} + +static struct ip6_tnl *vti6_tnl_create(struct net *net, struct __ip6_tnl_parm *p) +{ + struct net_device *dev; + struct ip6_tnl *t; + char name[IFNAMSIZ]; + int err; + + if (p->name[0]) { + if (!dev_valid_name(p->name)) + goto failed; + strscpy(name, p->name, IFNAMSIZ); + } else { + sprintf(name, "ip6_vti%%d"); + } + + dev = alloc_netdev(sizeof(*t), name, NET_NAME_UNKNOWN, vti6_dev_setup); + if (!dev) + goto failed; + + dev_net_set(dev, net); + + t = netdev_priv(dev); + t->parms = *p; + t->net = 
dev_net(dev); + + err = vti6_tnl_create2(dev); + if (err < 0) + goto failed_free; + + return t; + +failed_free: + free_netdev(dev); +failed: + return NULL; +} + +/** + * vti6_locate - find or create tunnel matching given parameters + * @net: network namespace + * @p: tunnel parameters + * @create: != 0 if allowed to create new tunnel if no match found + * + * Description: + * vti6_locate() first tries to locate an existing tunnel + * based on @parms. If this is unsuccessful, but @create is set a new + * tunnel device is created and registered for use. + * + * Return: + * matching tunnel or NULL + **/ +static struct ip6_tnl *vti6_locate(struct net *net, struct __ip6_tnl_parm *p, + int create) +{ + const struct in6_addr *remote = &p->raddr; + const struct in6_addr *local = &p->laddr; + struct ip6_tnl __rcu **tp; + struct ip6_tnl *t; + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + + for (tp = vti6_tnl_bucket(ip6n, p); + (t = rtnl_dereference(*tp)) != NULL; + tp = &t->next) { + if (ipv6_addr_equal(local, &t->parms.laddr) && + ipv6_addr_equal(remote, &t->parms.raddr)) { + if (create) + return NULL; + + return t; + } + } + if (!create) + return NULL; + return vti6_tnl_create(net, p); +} + +/** + * vti6_dev_uninit - tunnel device uninitializer + * @dev: the device to be destroyed + * + * Description: + * vti6_dev_uninit() removes tunnel from its list + **/ +static void vti6_dev_uninit(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct vti6_net *ip6n = net_generic(t->net, vti6_net_id); + + if (dev == ip6n->fb_tnl_dev) + RCU_INIT_POINTER(ip6n->tnls_wc[0], NULL); + else + vti6_tnl_unlink(ip6n, t); + netdev_put(dev, &t->dev_tracker); +} + +static int vti6_input_proto(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type) +{ + struct ip6_tnl *t; + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + + rcu_read_lock(); + t = vti6_tnl_lookup(dev_net(skb->dev), &ipv6h->saddr, &ipv6h->daddr); + if (t) { + if (t->parms.proto != IPPROTO_IPV6 && t->parms.proto != 0) { + rcu_read_unlock(); + goto discard; + } + + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { + rcu_read_unlock(); + goto discard; + } + + ipv6h = ipv6_hdr(skb); + if (!ip6_tnl_rcv_ctl(t, &ipv6h->daddr, &ipv6h->saddr)) { + t->dev->stats.rx_dropped++; + rcu_read_unlock(); + goto discard; + } + + rcu_read_unlock(); + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = t; + XFRM_SPI_SKB_CB(skb)->family = AF_INET6; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct ipv6hdr, daddr); + return xfrm_input(skb, nexthdr, spi, encap_type); + } + rcu_read_unlock(); + return -EINVAL; +discard: + kfree_skb(skb); + return 0; +} + +static int vti6_rcv(struct sk_buff *skb) +{ + int nexthdr = skb_network_header(skb)[IP6CB(skb)->nhoff]; + + return vti6_input_proto(skb, nexthdr, 0, 0); +} + +static int vti6_rcv_cb(struct sk_buff *skb, int err) +{ + unsigned short family; + struct net_device *dev; + struct xfrm_state *x; + const struct xfrm_mode *inner_mode; + struct ip6_tnl *t = XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6; + u32 orig_mark = skb->mark; + int ret; + + if (!t) + return 1; + + dev = t->dev; + + if (err) { + dev->stats.rx_errors++; + dev->stats.rx_dropped++; + + return 0; + } + + x = xfrm_input_state(skb); + + inner_mode = &x->inner_mode; + + if (x->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol); + if (inner_mode == NULL) { + XFRM_INC_STATS(dev_net(skb->dev), + LINUX_MIB_XFRMINSTATEMODEERROR); + return -EINVAL; + } + } + + family = inner_mode->family; + + skb->mark = 
be32_to_cpu(t->parms.i_key); + ret = xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, family); + skb->mark = orig_mark; + + if (!ret) + return -EPERM; + + skb_scrub_packet(skb, !net_eq(t->net, dev_net(skb->dev))); + skb->dev = dev; + dev_sw_netstats_rx_add(dev, skb->len); + + return 0; +} + +/** + * vti6_addr_conflict - compare packet addresses to tunnel's own + * @t: the outgoing tunnel device + * @hdr: IPv6 header from the incoming packet + * + * Description: + * Avoid trivial tunneling loop by checking that tunnel exit-point + * doesn't match source of incoming packet. + * + * Return: + * 1 if conflict, + * 0 else + **/ +static inline bool +vti6_addr_conflict(const struct ip6_tnl *t, const struct ipv6hdr *hdr) +{ + return ipv6_addr_equal(&t->parms.raddr, &hdr->saddr); +} + +static bool vti6_state_check(const struct xfrm_state *x, + const struct in6_addr *dst, + const struct in6_addr *src) +{ + xfrm_address_t *daddr = (xfrm_address_t *)dst; + xfrm_address_t *saddr = (xfrm_address_t *)src; + + /* if there is no transform then this tunnel is not functional. + * Or if the xfrm is not mode tunnel. + */ + if (!x || x->props.mode != XFRM_MODE_TUNNEL || + x->props.family != AF_INET6) + return false; + + if (ipv6_addr_any(dst)) + return xfrm_addr_equal(saddr, &x->props.saddr, AF_INET6); + + if (!xfrm_state_addr_check(x, daddr, saddr, AF_INET6)) + return false; + + return true; +} + +/** + * vti6_xmit - send a packet + * @skb: the outgoing socket buffer + * @dev: the outgoing tunnel device + * @fl: the flow informations for the xfrm_lookup + **/ +static int +vti6_xmit(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net_device_stats *stats = &t->dev->stats; + struct dst_entry *dst = skb_dst(skb); + struct net_device *tdev; + struct xfrm_state *x; + int pkt_len = skb->len; + int err = -1; + int mtu; + + if (!dst) { + switch (skb->protocol) { + case htons(ETH_P_IP): { + struct rtable *rt; + + fl->u.ip4.flowi4_oif = dev->ifindex; + fl->u.ip4.flowi4_flags |= FLOWI_FLAG_ANYSRC; + rt = __ip_route_output_key(dev_net(dev), &fl->u.ip4); + if (IS_ERR(rt)) + goto tx_err_link_failure; + dst = &rt->dst; + skb_dst_set(skb, dst); + break; + } + case htons(ETH_P_IPV6): + fl->u.ip6.flowi6_oif = dev->ifindex; + fl->u.ip6.flowi6_flags |= FLOWI_FLAG_ANYSRC; + dst = ip6_route_output(dev_net(dev), NULL, &fl->u.ip6); + if (dst->error) { + dst_release(dst); + dst = NULL; + goto tx_err_link_failure; + } + skb_dst_set(skb, dst); + break; + default: + goto tx_err_link_failure; + } + } + + dst_hold(dst); + dst = xfrm_lookup_route(t->net, dst, fl, NULL, 0); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + goto tx_err_link_failure; + } + + if (dst->flags & DST_XFRM_QUEUE) + goto xmit; + + x = dst->xfrm; + if (!vti6_state_check(x, &t->parms.raddr, &t->parms.laddr)) + goto tx_err_link_failure; + + if (!ip6_tnl_xmit_ctl(t, (const struct in6_addr *)&x->props.saddr, + (const struct in6_addr *)&x->id.daddr)) + goto tx_err_link_failure; + + tdev = dst->dev; + + if (tdev == dev) { + stats->collisions++; + net_warn_ratelimited("%s: Local routing loop detected!\n", + t->parms.name); + goto tx_err_dst_release; + } + + mtu = dst_mtu(dst); + if (skb->len > mtu) { + skb_dst_update_pmtu_no_confirm(skb, mtu); + + if (skb->protocol == htons(ETH_P_IPV6)) { + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + } else { + if (!(ip_hdr(skb)->frag_off & htons(IP_DF))) + goto xmit; + icmp_ndo_send(skb, ICMP_DEST_UNREACH, 
ICMP_FRAG_NEEDED, + htonl(mtu)); + } + + err = -EMSGSIZE; + goto tx_err_dst_release; + } + +xmit: + skb_scrub_packet(skb, !net_eq(t->net, dev_net(dev))); + skb_dst_set(skb, dst); + skb->dev = skb_dst(skb)->dev; + + err = dst_output(t->net, skb->sk, skb); + if (net_xmit_eval(err) == 0) + err = pkt_len; + iptunnel_xmit_stats(dev, err); + + return 0; +tx_err_link_failure: + stats->tx_carrier_errors++; + dst_link_failure(skb); +tx_err_dst_release: + dst_release(dst); + return err; +} + +static netdev_tx_t +vti6_tnl_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net_device_stats *stats = &t->dev->stats; + struct flowi fl; + int ret; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + memset(&fl, 0, sizeof(fl)); + + switch (skb->protocol) { + case htons(ETH_P_IPV6): + if ((t->parms.proto != IPPROTO_IPV6 && t->parms.proto != 0) || + vti6_addr_conflict(t, ipv6_hdr(skb))) + goto tx_err; + + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + xfrm_decode_session(skb, &fl, AF_INET6); + break; + case htons(ETH_P_IP): + memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + xfrm_decode_session(skb, &fl, AF_INET); + break; + default: + goto tx_err; + } + + /* override mark with tunnel output key */ + fl.flowi_mark = be32_to_cpu(t->parms.o_key); + + ret = vti6_xmit(skb, dev, &fl); + if (ret < 0) + goto tx_err; + + return NETDEV_TX_OK; + +tx_err: + stats->tx_errors++; + stats->tx_dropped++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int vti6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + __be32 spi; + __u32 mark; + struct xfrm_state *x; + struct ip6_tnl *t; + struct ip_esp_hdr *esph; + struct ip_auth_hdr *ah; + struct ip_comp_hdr *ipch; + struct net *net = dev_net(skb->dev); + const struct ipv6hdr *iph = (const struct ipv6hdr *)skb->data; + int protocol = iph->nexthdr; + + t = vti6_tnl_lookup(dev_net(skb->dev), &iph->daddr, &iph->saddr); + if (!t) + return -1; + + mark = be32_to_cpu(t->parms.o_key); + + switch (protocol) { + case IPPROTO_ESP: + esph = (struct ip_esp_hdr *)(skb->data + offset); + spi = esph->spi; + break; + case IPPROTO_AH: + ah = (struct ip_auth_hdr *)(skb->data + offset); + spi = ah->spi; + break; + case IPPROTO_COMP: + ipch = (struct ip_comp_hdr *)(skb->data + offset); + spi = htonl(ntohs(ipch->cpi)); + break; + default: + return 0; + } + + if (type != ICMPV6_PKT_TOOBIG && + type != NDISC_REDIRECT) + return 0; + + x = xfrm_state_lookup(net, mark, (const xfrm_address_t *)&iph->daddr, + spi, protocol, AF_INET6); + if (!x) + return 0; + + if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + else + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + xfrm_state_put(x); + + return 0; +} + +static void vti6_link_config(struct ip6_tnl *t, bool keep_mtu) +{ + struct net_device *dev = t->dev; + struct __ip6_tnl_parm *p = &t->parms; + struct net_device *tdev = NULL; + int mtu; + + __dev_addr_set(dev, &p->laddr, sizeof(struct in6_addr)); + memcpy(dev->broadcast, &p->raddr, sizeof(struct in6_addr)); + + p->flags &= ~(IP6_TNL_F_CAP_XMIT | IP6_TNL_F_CAP_RCV | + IP6_TNL_F_CAP_PER_PACKET); + p->flags |= ip6_tnl_get_cap(t, &p->laddr, &p->raddr); + + if (p->flags & IP6_TNL_F_CAP_XMIT && p->flags & IP6_TNL_F_CAP_RCV) + dev->flags |= IFF_POINTOPOINT; + else + dev->flags &= ~IFF_POINTOPOINT; + + if (keep_mtu && dev->mtu) { + dev->mtu = clamp(dev->mtu, dev->min_mtu, dev->max_mtu); + return; + } + + if (p->flags & IP6_TNL_F_CAP_XMIT) { + int 
strict = (ipv6_addr_type(&p->raddr) & + (IPV6_ADDR_MULTICAST | IPV6_ADDR_LINKLOCAL)); + struct rt6_info *rt = rt6_lookup(t->net, + &p->raddr, &p->laddr, + p->link, NULL, strict); + + if (rt) + tdev = rt->dst.dev; + ip6_rt_put(rt); + } + + if (!tdev && p->link) + tdev = __dev_get_by_index(t->net, p->link); + + if (tdev) + mtu = tdev->mtu - sizeof(struct ipv6hdr); + else + mtu = ETH_DATA_LEN - LL_MAX_HEADER - sizeof(struct ipv6hdr); + + dev->mtu = max_t(int, mtu, IPV4_MIN_MTU); +} + +/** + * vti6_tnl_change - update the tunnel parameters + * @t: tunnel to be changed + * @p: tunnel configuration parameters + * @keep_mtu: MTU was set from userspace, don't re-compute it + * + * Description: + * vti6_tnl_change() updates the tunnel parameters + **/ +static int +vti6_tnl_change(struct ip6_tnl *t, const struct __ip6_tnl_parm *p, + bool keep_mtu) +{ + t->parms.laddr = p->laddr; + t->parms.raddr = p->raddr; + t->parms.link = p->link; + t->parms.i_key = p->i_key; + t->parms.o_key = p->o_key; + t->parms.proto = p->proto; + t->parms.fwmark = p->fwmark; + dst_cache_reset(&t->dst_cache); + vti6_link_config(t, keep_mtu); + return 0; +} + +static int vti6_update(struct ip6_tnl *t, struct __ip6_tnl_parm *p, + bool keep_mtu) +{ + struct net *net = dev_net(t->dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + int err; + + vti6_tnl_unlink(ip6n, t); + synchronize_net(); + err = vti6_tnl_change(t, p, keep_mtu); + vti6_tnl_link(ip6n, t); + netdev_state_change(t->dev); + return err; +} + +static void +vti6_parm_from_user(struct __ip6_tnl_parm *p, const struct ip6_tnl_parm2 *u) +{ + p->laddr = u->laddr; + p->raddr = u->raddr; + p->link = u->link; + p->i_key = u->i_key; + p->o_key = u->o_key; + p->proto = u->proto; + + memcpy(p->name, u->name, sizeof(u->name)); +} + +static void +vti6_parm_to_user(struct ip6_tnl_parm2 *u, const struct __ip6_tnl_parm *p) +{ + u->laddr = p->laddr; + u->raddr = p->raddr; + u->link = p->link; + u->i_key = p->i_key; + u->o_key = p->o_key; + if (u->i_key) + u->i_flags |= GRE_KEY; + if (u->o_key) + u->o_flags |= GRE_KEY; + u->proto = p->proto; + + memcpy(u->name, p->name, sizeof(u->name)); +} + +/** + * vti6_siocdevprivate - configure vti6 tunnels from userspace + * @dev: virtual device associated with tunnel + * @ifr: unused + * @data: parameters passed from userspace + * @cmd: command to be performed + * + * Description: + * vti6_siocdevprivate() is used for managing vti6 tunnels + * from userspace. + * + * The possible commands are the following: + * %SIOCGETTUNNEL: get tunnel parameters for device + * %SIOCADDTUNNEL: add tunnel matching given tunnel parameters + * %SIOCCHGTUNNEL: change tunnel parameters to those given + * %SIOCDELTUNNEL: delete tunnel + * + * The fallback device "ip6_vti0", created during module + * initialization, can be used for creating other tunnel devices. 
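+ *
+ * Note that the request structure here is struct ip6_tnl_parm2 and that
+ * its i_key/o_key values never go on the wire: vti6_rcv_cb() copies
+ * i_key into skb->mark for the inbound xfrm policy lookup, while
+ * vti6_tnl_xmit() and vti6_err() use o_key as the flow mark on output.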
+ * + * Return: + * 0 on success, + * %-EFAULT if unable to copy data to or from userspace, + * %-EPERM if current process hasn't %CAP_NET_ADMIN set + * %-EINVAL if passed tunnel parameters are invalid, + * %-EEXIST if changing a tunnel's parameters would cause a conflict + * %-ENODEV if attempting to change or delete a nonexisting device + **/ +static int +vti6_siocdevprivate(struct net_device *dev, struct ifreq *ifr, void __user *data, int cmd) +{ + int err = 0; + struct ip6_tnl_parm2 p; + struct __ip6_tnl_parm p1; + struct ip6_tnl *t = NULL; + struct net *net = dev_net(dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + + memset(&p1, 0, sizeof(p1)); + + switch (cmd) { + case SIOCGETTUNNEL: + if (dev == ip6n->fb_tnl_dev) { + if (copy_from_user(&p, data, sizeof(p))) { + err = -EFAULT; + break; + } + vti6_parm_from_user(&p1, &p); + t = vti6_locate(net, &p1, 0); + } else { + memset(&p, 0, sizeof(p)); + } + if (!t) + t = netdev_priv(dev); + vti6_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + break; + case SIOCADDTUNNEL: + case SIOCCHGTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + break; + err = -EINVAL; + if (p.proto != IPPROTO_IPV6 && p.proto != 0) + break; + vti6_parm_from_user(&p1, &p); + t = vti6_locate(net, &p1, cmd == SIOCADDTUNNEL); + if (dev != ip6n->fb_tnl_dev && cmd == SIOCCHGTUNNEL) { + if (t) { + if (t->dev != dev) { + err = -EEXIST; + break; + } + } else + t = netdev_priv(dev); + + err = vti6_update(t, &p1, false); + } + if (t) { + err = 0; + vti6_parm_to_user(&p, &t->parms); + if (copy_to_user(data, &p, sizeof(p))) + err = -EFAULT; + + } else + err = (cmd == SIOCADDTUNNEL ? -ENOBUFS : -ENOENT); + break; + case SIOCDELTUNNEL: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + + if (dev == ip6n->fb_tnl_dev) { + err = -EFAULT; + if (copy_from_user(&p, data, sizeof(p))) + break; + err = -ENOENT; + vti6_parm_from_user(&p1, &p); + t = vti6_locate(net, &p1, 0); + if (!t) + break; + err = -EPERM; + if (t->dev == ip6n->fb_tnl_dev) + break; + dev = t->dev; + } + err = 0; + unregister_netdevice(dev); + break; + default: + err = -EINVAL; + } + return err; +} + +static const struct net_device_ops vti6_netdev_ops = { + .ndo_init = vti6_dev_init, + .ndo_uninit = vti6_dev_uninit, + .ndo_start_xmit = vti6_tnl_xmit, + .ndo_siocdevprivate = vti6_siocdevprivate, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip6_tnl_get_iflink, +}; + +/** + * vti6_dev_setup - setup virtual tunnel device + * @dev: virtual device associated with tunnel + * + * Description: + * Initialize function pointers and device parameters + **/ +static void vti6_dev_setup(struct net_device *dev) +{ + dev->netdev_ops = &vti6_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = vti6_dev_free; + + dev->type = ARPHRD_TUNNEL6; + dev->min_mtu = IPV4_MIN_MTU; + dev->max_mtu = IP_MAX_MTU - sizeof(struct ipv6hdr); + dev->flags |= IFF_NOARP; + dev->addr_len = sizeof(struct in6_addr); + netif_keep_dst(dev); + /* This perm addr will be used as interface identifier by IPv6 */ + dev->addr_assign_type = NET_ADDR_RANDOM; + eth_random_addr(dev->perm_addr); +} + +/** + * vti6_dev_init_gen - general initializer for all tunnel devices + * @dev: virtual device associated with tunnel + **/ +static inline int vti6_dev_init_gen(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + + t->dev = 
dev; + t->net = dev_net(dev); + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + netdev_hold(dev, &t->dev_tracker, GFP_KERNEL); + return 0; +} + +/** + * vti6_dev_init - initializer for all non fallback tunnel devices + * @dev: virtual device associated with tunnel + **/ +static int vti6_dev_init(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + int err = vti6_dev_init_gen(dev); + + if (err) + return err; + vti6_link_config(t, true); + return 0; +} + +/** + * vti6_fb_tnl_dev_init - initializer for fallback tunnel device + * @dev: fallback device + * + * Return: 0 + **/ +static int __net_init vti6_fb_tnl_dev_init(struct net_device *dev) +{ + struct ip6_tnl *t = netdev_priv(dev); + struct net *net = dev_net(dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + + t->parms.proto = IPPROTO_IPV6; + + rcu_assign_pointer(ip6n->tnls_wc[0], t); + return 0; +} + +static int vti6_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + return 0; +} + +static void vti6_netlink_parms(struct nlattr *data[], + struct __ip6_tnl_parm *parms) +{ + memset(parms, 0, sizeof(*parms)); + + if (!data) + return; + + if (data[IFLA_VTI_LINK]) + parms->link = nla_get_u32(data[IFLA_VTI_LINK]); + + if (data[IFLA_VTI_LOCAL]) + parms->laddr = nla_get_in6_addr(data[IFLA_VTI_LOCAL]); + + if (data[IFLA_VTI_REMOTE]) + parms->raddr = nla_get_in6_addr(data[IFLA_VTI_REMOTE]); + + if (data[IFLA_VTI_IKEY]) + parms->i_key = nla_get_be32(data[IFLA_VTI_IKEY]); + + if (data[IFLA_VTI_OKEY]) + parms->o_key = nla_get_be32(data[IFLA_VTI_OKEY]); + + if (data[IFLA_VTI_FWMARK]) + parms->fwmark = nla_get_u32(data[IFLA_VTI_FWMARK]); +} + +static int vti6_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net *net = dev_net(dev); + struct ip6_tnl *nt; + + nt = netdev_priv(dev); + vti6_netlink_parms(data, &nt->parms); + + nt->parms.proto = IPPROTO_IPV6; + + if (vti6_locate(net, &nt->parms, 0)) + return -EEXIST; + + return vti6_tnl_create2(dev); +} + +static void vti6_dellink(struct net_device *dev, struct list_head *head) +{ + struct net *net = dev_net(dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + + if (dev != ip6n->fb_tnl_dev) + unregister_netdevice_queue(dev, head); +} + +static int vti6_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip6_tnl *t; + struct __ip6_tnl_parm p; + struct net *net = dev_net(dev); + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + + if (dev == ip6n->fb_tnl_dev) + return -EINVAL; + + vti6_netlink_parms(data, &p); + + t = vti6_locate(net, &p, 0); + + if (t) { + if (t->dev != dev) + return -EEXIST; + } else + t = netdev_priv(dev); + + return vti6_update(t, &p, tb && tb[IFLA_MTU]); +} + +static size_t vti6_get_size(const struct net_device *dev) +{ + return + /* IFLA_VTI_LINK */ + nla_total_size(4) + + /* IFLA_VTI_LOCAL */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_VTI_REMOTE */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_VTI_IKEY */ + nla_total_size(4) + + /* IFLA_VTI_OKEY */ + nla_total_size(4) + + /* IFLA_VTI_FWMARK */ + nla_total_size(4) + + 0; +} + +static int vti6_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip6_tnl *tunnel = netdev_priv(dev); + struct __ip6_tnl_parm *parm = &tunnel->parms; + + if (nla_put_u32(skb, IFLA_VTI_LINK, parm->link) || + 
nla_put_in6_addr(skb, IFLA_VTI_LOCAL, &parm->laddr) || + nla_put_in6_addr(skb, IFLA_VTI_REMOTE, &parm->raddr) || + nla_put_be32(skb, IFLA_VTI_IKEY, parm->i_key) || + nla_put_be32(skb, IFLA_VTI_OKEY, parm->o_key) || + nla_put_u32(skb, IFLA_VTI_FWMARK, parm->fwmark)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static const struct nla_policy vti6_policy[IFLA_VTI_MAX + 1] = { + [IFLA_VTI_LINK] = { .type = NLA_U32 }, + [IFLA_VTI_LOCAL] = { .len = sizeof(struct in6_addr) }, + [IFLA_VTI_REMOTE] = { .len = sizeof(struct in6_addr) }, + [IFLA_VTI_IKEY] = { .type = NLA_U32 }, + [IFLA_VTI_OKEY] = { .type = NLA_U32 }, + [IFLA_VTI_FWMARK] = { .type = NLA_U32 }, +}; + +static struct rtnl_link_ops vti6_link_ops __read_mostly = { + .kind = "vti6", + .maxtype = IFLA_VTI_MAX, + .policy = vti6_policy, + .priv_size = sizeof(struct ip6_tnl), + .setup = vti6_dev_setup, + .validate = vti6_validate, + .newlink = vti6_newlink, + .dellink = vti6_dellink, + .changelink = vti6_changelink, + .get_size = vti6_get_size, + .fill_info = vti6_fill_info, + .get_link_net = ip6_tnl_get_link_net, +}; + +static void __net_exit vti6_destroy_tunnels(struct vti6_net *ip6n, + struct list_head *list) +{ + int h; + struct ip6_tnl *t; + + for (h = 0; h < IP6_VTI_HASH_SIZE; h++) { + t = rtnl_dereference(ip6n->tnls_r_l[h]); + while (t) { + unregister_netdevice_queue(t->dev, list); + t = rtnl_dereference(t->next); + } + } + + t = rtnl_dereference(ip6n->tnls_wc[0]); + if (t) + unregister_netdevice_queue(t->dev, list); +} + +static int __net_init vti6_init_net(struct net *net) +{ + struct vti6_net *ip6n = net_generic(net, vti6_net_id); + struct ip6_tnl *t = NULL; + int err; + + ip6n->tnls[0] = ip6n->tnls_wc; + ip6n->tnls[1] = ip6n->tnls_r_l; + + if (!net_has_fallback_tunnels(net)) + return 0; + err = -ENOMEM; + ip6n->fb_tnl_dev = alloc_netdev(sizeof(struct ip6_tnl), "ip6_vti0", + NET_NAME_UNKNOWN, vti6_dev_setup); + + if (!ip6n->fb_tnl_dev) + goto err_alloc_dev; + dev_net_set(ip6n->fb_tnl_dev, net); + ip6n->fb_tnl_dev->rtnl_link_ops = &vti6_link_ops; + + err = vti6_fb_tnl_dev_init(ip6n->fb_tnl_dev); + if (err < 0) + goto err_register; + + err = register_netdev(ip6n->fb_tnl_dev); + if (err < 0) + goto err_register; + + t = netdev_priv(ip6n->fb_tnl_dev); + + strcpy(t->parms.name, ip6n->fb_tnl_dev->name); + return 0; + +err_register: + free_netdev(ip6n->fb_tnl_dev); +err_alloc_dev: + return err; +} + +static void __net_exit vti6_exit_batch_net(struct list_head *net_list) +{ + struct vti6_net *ip6n; + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + ip6n = net_generic(net, vti6_net_id); + vti6_destroy_tunnels(ip6n, &list); + } + unregister_netdevice_many(&list); + rtnl_unlock(); +} + +static struct pernet_operations vti6_net_ops = { + .init = vti6_init_net, + .exit_batch = vti6_exit_batch_net, + .id = &vti6_net_id, + .size = sizeof(struct vti6_net), +}; + +static struct xfrm6_protocol vti_esp6_protocol __read_mostly = { + .handler = vti6_rcv, + .input_handler = vti6_input_proto, + .cb_handler = vti6_rcv_cb, + .err_handler = vti6_err, + .priority = 100, +}; + +static struct xfrm6_protocol vti_ah6_protocol __read_mostly = { + .handler = vti6_rcv, + .input_handler = vti6_input_proto, + .cb_handler = vti6_rcv_cb, + .err_handler = vti6_err, + .priority = 100, +}; + +static struct xfrm6_protocol vti_ipcomp6_protocol __read_mostly = { + .handler = vti6_rcv, + .input_handler = vti6_input_proto, + .cb_handler = vti6_rcv_cb, + .err_handler = vti6_err, + 
.priority = 100, +}; + +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) +static int vti6_rcv_tunnel(struct sk_buff *skb) +{ + const xfrm_address_t *saddr; + __be32 spi; + + saddr = (const xfrm_address_t *)&ipv6_hdr(skb)->saddr; + spi = xfrm6_tunnel_spi_lookup(dev_net(skb->dev), saddr); + + return vti6_input_proto(skb, IPPROTO_IPV6, spi, 0); +} + +static struct xfrm6_tunnel vti_ipv6_handler __read_mostly = { + .handler = vti6_rcv_tunnel, + .cb_handler = vti6_rcv_cb, + .err_handler = vti6_err, + .priority = 0, +}; + +static struct xfrm6_tunnel vti_ip6ip_handler __read_mostly = { + .handler = vti6_rcv_tunnel, + .cb_handler = vti6_rcv_cb, + .err_handler = vti6_err, + .priority = 0, +}; +#endif + +/** + * vti6_tunnel_init - register protocol and reserve needed resources + * + * Return: 0 on success + **/ +static int __init vti6_tunnel_init(void) +{ + const char *msg; + int err; + + msg = "tunnel device"; + err = register_pernet_device(&vti6_net_ops); + if (err < 0) + goto pernet_dev_failed; + + msg = "tunnel protocols"; + err = xfrm6_protocol_register(&vti_esp6_protocol, IPPROTO_ESP); + if (err < 0) + goto xfrm_proto_esp_failed; + err = xfrm6_protocol_register(&vti_ah6_protocol, IPPROTO_AH); + if (err < 0) + goto xfrm_proto_ah_failed; + err = xfrm6_protocol_register(&vti_ipcomp6_protocol, IPPROTO_COMP); + if (err < 0) + goto xfrm_proto_comp_failed; +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) + msg = "ipv6 tunnel"; + err = xfrm6_tunnel_register(&vti_ipv6_handler, AF_INET6); + if (err < 0) + goto vti_tunnel_ipv6_failed; + err = xfrm6_tunnel_register(&vti_ip6ip_handler, AF_INET); + if (err < 0) + goto vti_tunnel_ip6ip_failed; +#endif + + msg = "netlink interface"; + err = rtnl_link_register(&vti6_link_ops); + if (err < 0) + goto rtnl_link_failed; + + return 0; + +rtnl_link_failed: +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) + err = xfrm6_tunnel_deregister(&vti_ip6ip_handler, AF_INET); +vti_tunnel_ip6ip_failed: + err = xfrm6_tunnel_deregister(&vti_ipv6_handler, AF_INET6); +vti_tunnel_ipv6_failed: +#endif + xfrm6_protocol_deregister(&vti_ipcomp6_protocol, IPPROTO_COMP); +xfrm_proto_comp_failed: + xfrm6_protocol_deregister(&vti_ah6_protocol, IPPROTO_AH); +xfrm_proto_ah_failed: + xfrm6_protocol_deregister(&vti_esp6_protocol, IPPROTO_ESP); +xfrm_proto_esp_failed: + unregister_pernet_device(&vti6_net_ops); +pernet_dev_failed: + pr_err("vti6 init: failed to register %s\n", msg); + return err; +} + +/** + * vti6_tunnel_cleanup - free resources and unregister protocol + **/ +static void __exit vti6_tunnel_cleanup(void) +{ + rtnl_link_unregister(&vti6_link_ops); +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) + xfrm6_tunnel_deregister(&vti_ip6ip_handler, AF_INET); + xfrm6_tunnel_deregister(&vti_ipv6_handler, AF_INET6); +#endif + xfrm6_protocol_deregister(&vti_ipcomp6_protocol, IPPROTO_COMP); + xfrm6_protocol_deregister(&vti_ah6_protocol, IPPROTO_AH); + xfrm6_protocol_deregister(&vti_esp6_protocol, IPPROTO_ESP); + unregister_pernet_device(&vti6_net_ops); +} + +module_init(vti6_tunnel_init); +module_exit(vti6_tunnel_cleanup); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("vti6"); +MODULE_ALIAS_NETDEV("ip6_vti0"); +MODULE_AUTHOR("Steffen Klassert"); +MODULE_DESCRIPTION("IPv6 virtual tunnel interface"); diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c new file mode 100644 index 000000000..27fb54799 --- /dev/null +++ b/net/ipv6/ip6mr.c @@ -0,0 +1,2636 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux IPv6 multicast routing support for BSD pim6sd + * Based on net/ipv4/ipmr.c. 
+ * + * (c) 2004 Mickael Hoerdt, + * LSIIT Laboratory, Strasbourg, France + * (c) 2004 Jean-Philippe Andriot, + * 6WIND, Paris, France + * Copyright (C)2007,2008 USAGI/WIDE Project + * YOSHIFUJI Hideaki + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct ip6mr_rule { + struct fib_rule common; +}; + +struct ip6mr_result { + struct mr_table *mrt; +}; + +/* Big lock, protecting vif table, mrt cache and mroute socket state. + Note that the changes are semaphored via rtnl_lock. + */ + +static DEFINE_SPINLOCK(mrt_lock); + +static struct net_device *vif_dev_read(const struct vif_device *vif) +{ + return rcu_dereference(vif->dev); +} + +/* Multicast router control variables */ + +/* Special spinlock for queue of unresolved entries */ +static DEFINE_SPINLOCK(mfc_unres_lock); + +/* We return to original Alan's scheme. Hash table of resolved + entries is changed only in process context and protected + with weak lock mrt_lock. Queue of unresolved entries is protected + with strong spinlock mfc_unres_lock. + + In this case data path is free of exclusive locks at all. + */ + +static struct kmem_cache *mrt_cachep __read_mostly; + +static struct mr_table *ip6mr_new_table(struct net *net, u32 id); +static void ip6mr_free_table(struct mr_table *mrt); + +static void ip6_mr_forward(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc6_cache *cache); +static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt, + mifi_t mifi, int assert); +static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc, + int cmd); +static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt); +static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack); +static int ip6mr_rtm_dumproute(struct sk_buff *skb, + struct netlink_callback *cb); +static void mroute_clean_tables(struct mr_table *mrt, int flags); +static void ipmr_expire_process(struct timer_list *t); + +#ifdef CONFIG_IPV6_MROUTE_MULTIPLE_TABLES +#define ip6mr_for_each_table(mrt, net) \ + list_for_each_entry_rcu(mrt, &net->ipv6.mr6_tables, list, \ + lockdep_rtnl_is_held() || \ + list_empty(&net->ipv6.mr6_tables)) + +static struct mr_table *ip6mr_mr_table_iter(struct net *net, + struct mr_table *mrt) +{ + struct mr_table *ret; + + if (!mrt) + ret = list_entry_rcu(net->ipv6.mr6_tables.next, + struct mr_table, list); + else + ret = list_entry_rcu(mrt->list.next, + struct mr_table, list); + + if (&ret->list == &net->ipv6.mr6_tables) + return NULL; + return ret; +} + +static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + ip6mr_for_each_table(mrt, net) { + if (mrt->id == id) + return mrt; + } + return NULL; +} + +static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, + struct mr_table **mrt) +{ + int err; + struct ip6mr_result res; + struct fib_lookup_arg arg = { + .result = &res, + .flags = FIB_LOOKUP_NOREF, + }; + + /* update flow if oif or iif point to device enslaved to l3mdev */ + l3mdev_update_flow(net, flowi6_to_flowi(flp6)); + + err = fib_rules_lookup(net->ipv6.mr6_rules_ops, + flowi6_to_flowi(flp6), 0, &arg); + if (err < 0) + return err; + *mrt = 
res.mrt; + return 0; +} + +static int ip6mr_rule_action(struct fib_rule *rule, struct flowi *flp, + int flags, struct fib_lookup_arg *arg) +{ + struct ip6mr_result *res = arg->result; + struct mr_table *mrt; + + switch (rule->action) { + case FR_ACT_TO_TBL: + break; + case FR_ACT_UNREACHABLE: + return -ENETUNREACH; + case FR_ACT_PROHIBIT: + return -EACCES; + case FR_ACT_BLACKHOLE: + default: + return -EINVAL; + } + + arg->table = fib_rule_get_table(rule, arg); + + mrt = ip6mr_get_table(rule->fr_net, arg->table); + if (!mrt) + return -EAGAIN; + res->mrt = mrt; + return 0; +} + +static int ip6mr_rule_match(struct fib_rule *rule, struct flowi *flp, int flags) +{ + return 1; +} + +static int ip6mr_rule_configure(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static int ip6mr_rule_compare(struct fib_rule *rule, struct fib_rule_hdr *frh, + struct nlattr **tb) +{ + return 1; +} + +static int ip6mr_rule_fill(struct fib_rule *rule, struct sk_buff *skb, + struct fib_rule_hdr *frh) +{ + frh->dst_len = 0; + frh->src_len = 0; + frh->tos = 0; + return 0; +} + +static const struct fib_rules_ops __net_initconst ip6mr_rules_ops_template = { + .family = RTNL_FAMILY_IP6MR, + .rule_size = sizeof(struct ip6mr_rule), + .addr_size = sizeof(struct in6_addr), + .action = ip6mr_rule_action, + .match = ip6mr_rule_match, + .configure = ip6mr_rule_configure, + .compare = ip6mr_rule_compare, + .fill = ip6mr_rule_fill, + .nlgroup = RTNLGRP_IPV6_RULE, + .owner = THIS_MODULE, +}; + +static int __net_init ip6mr_rules_init(struct net *net) +{ + struct fib_rules_ops *ops; + struct mr_table *mrt; + int err; + + ops = fib_rules_register(&ip6mr_rules_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + INIT_LIST_HEAD(&net->ipv6.mr6_tables); + + mrt = ip6mr_new_table(net, RT6_TABLE_DFLT); + if (IS_ERR(mrt)) { + err = PTR_ERR(mrt); + goto err1; + } + + err = fib_default_rule_add(ops, 0x7fff, RT6_TABLE_DFLT, 0); + if (err < 0) + goto err2; + + net->ipv6.mr6_rules_ops = ops; + return 0; + +err2: + rtnl_lock(); + ip6mr_free_table(mrt); + rtnl_unlock(); +err1: + fib_rules_unregister(ops); + return err; +} + +static void __net_exit ip6mr_rules_exit(struct net *net) +{ + struct mr_table *mrt, *next; + + ASSERT_RTNL(); + list_for_each_entry_safe(mrt, next, &net->ipv6.mr6_tables, list) { + list_del(&mrt->list); + ip6mr_free_table(mrt); + } + fib_rules_unregister(net->ipv6.mr6_rules_ops); +} + +static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return fib_rules_dump(net, nb, RTNL_FAMILY_IP6MR, extack); +} + +static unsigned int ip6mr_rules_seq_read(struct net *net) +{ + return fib_rules_seq_read(net, RTNL_FAMILY_IP6MR); +} + +bool ip6mr_rule_default(const struct fib_rule *rule) +{ + return fib_rule_matchall(rule) && rule->action == FR_ACT_TO_TBL && + rule->table == RT6_TABLE_DFLT && !rule->l3mdev; +} +EXPORT_SYMBOL(ip6mr_rule_default); +#else +#define ip6mr_for_each_table(mrt, net) \ + for (mrt = net->ipv6.mrt6; mrt; mrt = NULL) + +static struct mr_table *ip6mr_mr_table_iter(struct net *net, + struct mr_table *mrt) +{ + if (!mrt) + return net->ipv6.mrt6; + return NULL; +} + +static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +{ + return net->ipv6.mrt6; +} + +static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, + struct mr_table **mrt) +{ + *mrt = net->ipv6.mrt6; + return 0; +} + +static int __net_init ip6mr_rules_init(struct net *net) +{ + 
struct mr_table *mrt; + + mrt = ip6mr_new_table(net, RT6_TABLE_DFLT); + if (IS_ERR(mrt)) + return PTR_ERR(mrt); + net->ipv6.mrt6 = mrt; + return 0; +} + +static void __net_exit ip6mr_rules_exit(struct net *net) +{ + ASSERT_RTNL(); + ip6mr_free_table(net->ipv6.mrt6); + net->ipv6.mrt6 = NULL; +} + +static int ip6mr_rules_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return 0; +} + +static unsigned int ip6mr_rules_seq_read(struct net *net) +{ + return 0; +} +#endif + +static int ip6mr_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct mfc6_cache_cmp_arg *cmparg = arg->key; + struct mfc6_cache *c = (struct mfc6_cache *)ptr; + + return !ipv6_addr_equal(&c->mf6c_mcastgrp, &cmparg->mf6c_mcastgrp) || + !ipv6_addr_equal(&c->mf6c_origin, &cmparg->mf6c_origin); +} + +static const struct rhashtable_params ip6mr_rht_params = { + .head_offset = offsetof(struct mr_mfc, mnode), + .key_offset = offsetof(struct mfc6_cache, cmparg), + .key_len = sizeof(struct mfc6_cache_cmp_arg), + .nelem_hint = 3, + .obj_cmpfn = ip6mr_hash_cmp, + .automatic_shrinking = true, +}; + +static void ip6mr_new_table_set(struct mr_table *mrt, + struct net *net) +{ +#ifdef CONFIG_IPV6_MROUTE_MULTIPLE_TABLES + list_add_tail_rcu(&mrt->list, &net->ipv6.mr6_tables); +#endif +} + +static struct mfc6_cache_cmp_arg ip6mr_mr_table_ops_cmparg_any = { + .mf6c_origin = IN6ADDR_ANY_INIT, + .mf6c_mcastgrp = IN6ADDR_ANY_INIT, +}; + +static struct mr_table_ops ip6mr_mr_table_ops = { + .rht_params = &ip6mr_rht_params, + .cmparg_any = &ip6mr_mr_table_ops_cmparg_any, +}; + +static struct mr_table *ip6mr_new_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + mrt = ip6mr_get_table(net, id); + if (mrt) + return mrt; + + return mr_table_alloc(net, id, &ip6mr_mr_table_ops, + ipmr_expire_process, ip6mr_new_table_set); +} + +static void ip6mr_free_table(struct mr_table *mrt) +{ + del_timer_sync(&mrt->ipmr_expire_timer); + mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC | + MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC); + rhltable_destroy(&mrt->mfc_hash); + kfree(mrt); +} + +#ifdef CONFIG_PROC_FS +/* The /proc interfaces to multicast routing + * /proc/ip6_mr_cache /proc/ip6_mr_vif + */ + +static void *ip6mr_vif_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct mr_vif_iter *iter = seq->private; + struct net *net = seq_file_net(seq); + struct mr_table *mrt; + + mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) + return ERR_PTR(-ENOENT); + + iter->mrt = mrt; + + rcu_read_lock(); + return mr_vif_seq_start(seq, pos); +} + +static void ip6mr_vif_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int ip6mr_vif_seq_show(struct seq_file *seq, void *v) +{ + struct mr_vif_iter *iter = seq->private; + struct mr_table *mrt = iter->mrt; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Interface BytesIn PktsIn BytesOut PktsOut Flags\n"); + } else { + const struct vif_device *vif = v; + const struct net_device *vif_dev; + const char *name; + + vif_dev = vif_dev_read(vif); + name = vif_dev ? 
vif_dev->name : "none"; + + seq_printf(seq, + "%2td %-10s %8ld %7ld %8ld %7ld %05X\n", + vif - mrt->vif_table, + name, vif->bytes_in, vif->pkt_in, + vif->bytes_out, vif->pkt_out, + vif->flags); + } + return 0; +} + +static const struct seq_operations ip6mr_vif_seq_ops = { + .start = ip6mr_vif_seq_start, + .next = mr_vif_seq_next, + .stop = ip6mr_vif_seq_stop, + .show = ip6mr_vif_seq_show, +}; + +static void *ipmr_mfc_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + struct mr_table *mrt; + + mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) + return ERR_PTR(-ENOENT); + + return mr_mfc_seq_start(seq, pos, mrt, &mfc_unres_lock); +} + +static int ipmr_mfc_seq_show(struct seq_file *seq, void *v) +{ + int n; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Group " + "Origin " + "Iif Pkts Bytes Wrong Oifs\n"); + } else { + const struct mfc6_cache *mfc = v; + const struct mr_mfc_iter *it = seq->private; + struct mr_table *mrt = it->mrt; + + seq_printf(seq, "%pI6 %pI6 %-3hd", + &mfc->mf6c_mcastgrp, &mfc->mf6c_origin, + mfc->_c.mfc_parent); + + if (it->cache != &mrt->mfc_unres_queue) { + seq_printf(seq, " %8lu %8lu %8lu", + mfc->_c.mfc_un.res.pkt, + mfc->_c.mfc_un.res.bytes, + mfc->_c.mfc_un.res.wrong_if); + for (n = mfc->_c.mfc_un.res.minvif; + n < mfc->_c.mfc_un.res.maxvif; n++) { + if (VIF_EXISTS(mrt, n) && + mfc->_c.mfc_un.res.ttls[n] < 255) + seq_printf(seq, + " %2d:%-3d", n, + mfc->_c.mfc_un.res.ttls[n]); + } + } else { + /* unresolved mfc_caches don't contain + * pkt, bytes and wrong_if values + */ + seq_printf(seq, " %8lu %8lu %8lu", 0ul, 0ul, 0ul); + } + seq_putc(seq, '\n'); + } + return 0; +} + +static const struct seq_operations ipmr_mfc_seq_ops = { + .start = ipmr_mfc_seq_start, + .next = mr_mfc_seq_next, + .stop = mr_mfc_seq_stop, + .show = ipmr_mfc_seq_show, +}; +#endif + +#ifdef CONFIG_IPV6_PIMSM_V2 + +static int pim6_rcv(struct sk_buff *skb) +{ + struct pimreghdr *pim; + struct ipv6hdr *encap; + struct net_device *reg_dev = NULL; + struct net *net = dev_net(skb->dev); + struct mr_table *mrt; + struct flowi6 fl6 = { + .flowi6_iif = skb->dev->ifindex, + .flowi6_mark = skb->mark, + }; + int reg_vif_num; + + if (!pskb_may_pull(skb, sizeof(*pim) + sizeof(*encap))) + goto drop; + + pim = (struct pimreghdr *)skb_transport_header(skb); + if (pim->type != ((PIM_VERSION << 4) | PIM_TYPE_REGISTER) || + (pim->flags & PIM_NULL_REGISTER) || + (csum_ipv6_magic(&ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr, + sizeof(*pim), IPPROTO_PIM, + csum_partial((void *)pim, sizeof(*pim), 0)) && + csum_fold(skb_checksum(skb, 0, skb->len, 0)))) + goto drop; + + /* check if the inner packet is destined to mcast group */ + encap = (struct ipv6hdr *)(skb_transport_header(skb) + + sizeof(*pim)); + + if (!ipv6_addr_is_multicast(&encap->daddr) || + encap->payload_len == 0 || + ntohs(encap->payload_len) + sizeof(*pim) > skb->len) + goto drop; + + if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0) + goto drop; + + /* Pairs with WRITE_ONCE() in mif6_add()/mif6_delete() */ + reg_vif_num = READ_ONCE(mrt->mroute_reg_vif_num); + if (reg_vif_num >= 0) + reg_dev = vif_dev_read(&mrt->vif_table[reg_vif_num]); + + if (!reg_dev) + goto drop; + + skb->mac_header = skb->network_header; + skb_pull(skb, (u8 *)encap - skb->data); + skb_reset_network_header(skb); + skb->protocol = htons(ETH_P_IPV6); + skb->ip_summed = CHECKSUM_NONE; + + skb_tunnel_rx(skb, reg_dev, dev_net(reg_dev)); + + netif_rx(skb); + + return 0; + drop: + kfree_skb(skb); + return 0; +} + +static const struct inet6_protocol 
pim6_protocol = { + .handler = pim6_rcv, +}; + +/* Service routines creating virtual interfaces: PIMREG */ + +static netdev_tx_t reg_vif_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct net *net = dev_net(dev); + struct mr_table *mrt; + struct flowi6 fl6 = { + .flowi6_oif = dev->ifindex, + .flowi6_iif = skb->skb_iif ? : LOOPBACK_IFINDEX, + .flowi6_mark = skb->mark, + }; + + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0) + goto tx_err; + + dev->stats.tx_bytes += skb->len; + dev->stats.tx_packets++; + rcu_read_lock(); + ip6mr_cache_report(mrt, skb, READ_ONCE(mrt->mroute_reg_vif_num), + MRT6MSG_WHOLEPKT); + rcu_read_unlock(); + kfree_skb(skb); + return NETDEV_TX_OK; + +tx_err: + dev->stats.tx_errors++; + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int reg_vif_get_iflink(const struct net_device *dev) +{ + return 0; +} + +static const struct net_device_ops reg_vif_netdev_ops = { + .ndo_start_xmit = reg_vif_xmit, + .ndo_get_iflink = reg_vif_get_iflink, +}; + +static void reg_vif_setup(struct net_device *dev) +{ + dev->type = ARPHRD_PIMREG; + dev->mtu = 1500 - sizeof(struct ipv6hdr) - 8; + dev->flags = IFF_NOARP; + dev->netdev_ops = &reg_vif_netdev_ops; + dev->needs_free_netdev = true; + dev->features |= NETIF_F_NETNS_LOCAL; +} + +static struct net_device *ip6mr_reg_vif(struct net *net, struct mr_table *mrt) +{ + struct net_device *dev; + char name[IFNAMSIZ]; + + if (mrt->id == RT6_TABLE_DFLT) + sprintf(name, "pim6reg"); + else + sprintf(name, "pim6reg%u", mrt->id); + + dev = alloc_netdev(0, name, NET_NAME_UNKNOWN, reg_vif_setup); + if (!dev) + return NULL; + + dev_net_set(dev, net); + + if (register_netdevice(dev)) { + free_netdev(dev); + return NULL; + } + + if (dev_open(dev, NULL)) + goto failure; + + dev_hold(dev); + return dev; + +failure: + unregister_netdevice(dev); + return NULL; +} +#endif + +static int call_ip6mr_vif_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct vif_device *vif, + struct net_device *vif_dev, + mifi_t vif_index, u32 tb_id) +{ + return mr_call_vif_notifiers(net, RTNL_FAMILY_IP6MR, event_type, + vif, vif_dev, vif_index, tb_id, + &net->ipv6.ipmr_seq); +} + +static int call_ip6mr_mfc_entry_notifiers(struct net *net, + enum fib_event_type event_type, + struct mfc6_cache *mfc, u32 tb_id) +{ + return mr_call_mfc_notifiers(net, RTNL_FAMILY_IP6MR, event_type, + &mfc->_c, tb_id, &net->ipv6.ipmr_seq); +} + +/* Delete a VIF entry */ +static int mif6_delete(struct mr_table *mrt, int vifi, int notify, + struct list_head *head) +{ + struct vif_device *v; + struct net_device *dev; + struct inet6_dev *in6_dev; + + if (vifi < 0 || vifi >= mrt->maxvif) + return -EADDRNOTAVAIL; + + v = &mrt->vif_table[vifi]; + + dev = rtnl_dereference(v->dev); + if (!dev) + return -EADDRNOTAVAIL; + + call_ip6mr_vif_entry_notifiers(read_pnet(&mrt->net), + FIB_EVENT_VIF_DEL, v, dev, + vifi, mrt->id); + spin_lock(&mrt_lock); + RCU_INIT_POINTER(v->dev, NULL); + +#ifdef CONFIG_IPV6_PIMSM_V2 + if (vifi == mrt->mroute_reg_vif_num) { + /* Pairs with READ_ONCE() in ip6mr_cache_report() and reg_vif_xmit() */ + WRITE_ONCE(mrt->mroute_reg_vif_num, -1); + } +#endif + + if (vifi + 1 == mrt->maxvif) { + int tmp; + for (tmp = vifi - 1; tmp >= 0; tmp--) { + if (VIF_EXISTS(mrt, tmp)) + break; + } + WRITE_ONCE(mrt->maxvif, tmp + 1); + } + + spin_unlock(&mrt_lock); + + dev_set_allmulti(dev, -1); + + in6_dev = __in6_dev_get(dev); + if (in6_dev) { + atomic_dec(&in6_dev->cnf.mc_forwarding); + inet6_netconf_notify_devconf(dev_net(dev), 
RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + dev->ifindex, &in6_dev->cnf); + } + + if ((v->flags & MIFF_REGISTER) && !notify) + unregister_netdevice_queue(dev, head); + + netdev_put(dev, &v->dev_tracker); + return 0; +} + +static inline void ip6mr_cache_free_rcu(struct rcu_head *head) +{ + struct mr_mfc *c = container_of(head, struct mr_mfc, rcu); + + kmem_cache_free(mrt_cachep, (struct mfc6_cache *)c); +} + +static inline void ip6mr_cache_free(struct mfc6_cache *c) +{ + call_rcu(&c->_c.rcu, ip6mr_cache_free_rcu); +} + +/* Destroy an unresolved cache entry, killing queued skbs + and reporting error to netlink readers. + */ + +static void ip6mr_destroy_unres(struct mr_table *mrt, struct mfc6_cache *c) +{ + struct net *net = read_pnet(&mrt->net); + struct sk_buff *skb; + + atomic_dec(&mrt->cache_resolve_queue_len); + + while ((skb = skb_dequeue(&c->_c.mfc_un.unres.unresolved)) != NULL) { + if (ipv6_hdr(skb)->version == 0) { + struct nlmsghdr *nlh = skb_pull(skb, + sizeof(struct ipv6hdr)); + nlh->nlmsg_type = NLMSG_ERROR; + nlh->nlmsg_len = nlmsg_msg_size(sizeof(struct nlmsgerr)); + skb_trim(skb, nlh->nlmsg_len); + ((struct nlmsgerr *)nlmsg_data(nlh))->error = -ETIMEDOUT; + rtnl_unicast(skb, net, NETLINK_CB(skb).portid); + } else + kfree_skb(skb); + } + + ip6mr_cache_free(c); +} + + +/* Timer process for all the unresolved queue. */ + +static void ipmr_do_expire_process(struct mr_table *mrt) +{ + unsigned long now = jiffies; + unsigned long expires = 10 * HZ; + struct mr_mfc *c, *next; + + list_for_each_entry_safe(c, next, &mrt->mfc_unres_queue, list) { + if (time_after(c->mfc_un.unres.expires, now)) { + /* not yet... */ + unsigned long interval = c->mfc_un.unres.expires - now; + if (interval < expires) + expires = interval; + continue; + } + + list_del(&c->list); + mr6_netlink_event(mrt, (struct mfc6_cache *)c, RTM_DELROUTE); + ip6mr_destroy_unres(mrt, (struct mfc6_cache *)c); + } + + if (!list_empty(&mrt->mfc_unres_queue)) + mod_timer(&mrt->ipmr_expire_timer, jiffies + expires); +} + +static void ipmr_expire_process(struct timer_list *t) +{ + struct mr_table *mrt = from_timer(mrt, t, ipmr_expire_timer); + + if (!spin_trylock(&mfc_unres_lock)) { + mod_timer(&mrt->ipmr_expire_timer, jiffies + 1); + return; + } + + if (!list_empty(&mrt->mfc_unres_queue)) + ipmr_do_expire_process(mrt); + + spin_unlock(&mfc_unres_lock); +} + +/* Fill oifs list. It is called under locked mrt_lock. */ + +static void ip6mr_update_thresholds(struct mr_table *mrt, + struct mr_mfc *cache, + unsigned char *ttls) +{ + int vifi; + + cache->mfc_un.res.minvif = MAXMIFS; + cache->mfc_un.res.maxvif = 0; + memset(cache->mfc_un.res.ttls, 255, MAXMIFS); + + for (vifi = 0; vifi < mrt->maxvif; vifi++) { + if (VIF_EXISTS(mrt, vifi) && + ttls[vifi] && ttls[vifi] < 255) { + cache->mfc_un.res.ttls[vifi] = ttls[vifi]; + if (cache->mfc_un.res.minvif > vifi) + cache->mfc_un.res.minvif = vifi; + if (cache->mfc_un.res.maxvif <= vifi) + cache->mfc_un.res.maxvif = vifi + 1; + } + } + cache->mfc_un.res.lastuse = jiffies; +} + +static int mif6_add(struct net *net, struct mr_table *mrt, + struct mif6ctl *vifc, int mrtsock) +{ + int vifi = vifc->mif6c_mifi; + struct vif_device *v = &mrt->vif_table[vifi]; + struct net_device *dev; + struct inet6_dev *in6_dev; + int err; + + /* Is vif busy ? 
*/ + if (VIF_EXISTS(mrt, vifi)) + return -EADDRINUSE; + + switch (vifc->mif6c_flags) { +#ifdef CONFIG_IPV6_PIMSM_V2 + case MIFF_REGISTER: + /* + * Special Purpose VIF in PIM + * All the packets will be sent to the daemon + */ + if (mrt->mroute_reg_vif_num >= 0) + return -EADDRINUSE; + dev = ip6mr_reg_vif(net, mrt); + if (!dev) + return -ENOBUFS; + err = dev_set_allmulti(dev, 1); + if (err) { + unregister_netdevice(dev); + dev_put(dev); + return err; + } + break; +#endif + case 0: + dev = dev_get_by_index(net, vifc->mif6c_pifi); + if (!dev) + return -EADDRNOTAVAIL; + err = dev_set_allmulti(dev, 1); + if (err) { + dev_put(dev); + return err; + } + break; + default: + return -EINVAL; + } + + in6_dev = __in6_dev_get(dev); + if (in6_dev) { + atomic_inc(&in6_dev->cnf.mc_forwarding); + inet6_netconf_notify_devconf(dev_net(dev), RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + dev->ifindex, &in6_dev->cnf); + } + + /* Fill in the VIF structures */ + vif_device_init(v, dev, vifc->vifc_rate_limit, vifc->vifc_threshold, + vifc->mif6c_flags | (!mrtsock ? VIFF_STATIC : 0), + MIFF_REGISTER); + + /* And finish update writing critical data */ + spin_lock(&mrt_lock); + rcu_assign_pointer(v->dev, dev); + netdev_tracker_alloc(dev, &v->dev_tracker, GFP_ATOMIC); +#ifdef CONFIG_IPV6_PIMSM_V2 + if (v->flags & MIFF_REGISTER) + WRITE_ONCE(mrt->mroute_reg_vif_num, vifi); +#endif + if (vifi + 1 > mrt->maxvif) + WRITE_ONCE(mrt->maxvif, vifi + 1); + spin_unlock(&mrt_lock); + call_ip6mr_vif_entry_notifiers(net, FIB_EVENT_VIF_ADD, + v, dev, vifi, mrt->id); + return 0; +} + +static struct mfc6_cache *ip6mr_cache_find(struct mr_table *mrt, + const struct in6_addr *origin, + const struct in6_addr *mcastgrp) +{ + struct mfc6_cache_cmp_arg arg = { + .mf6c_origin = *origin, + .mf6c_mcastgrp = *mcastgrp, + }; + + return mr_mfc_find(mrt, &arg); +} + +/* Look for a (*,G) entry */ +static struct mfc6_cache *ip6mr_cache_find_any(struct mr_table *mrt, + struct in6_addr *mcastgrp, + mifi_t mifi) +{ + struct mfc6_cache_cmp_arg arg = { + .mf6c_origin = in6addr_any, + .mf6c_mcastgrp = *mcastgrp, + }; + + if (ipv6_addr_any(mcastgrp)) + return mr_mfc_find_any_parent(mrt, mifi); + return mr_mfc_find_any(mrt, mifi, &arg); +} + +/* Look for a (S,G,iif) entry if parent != -1 */ +static struct mfc6_cache * +ip6mr_cache_find_parent(struct mr_table *mrt, + const struct in6_addr *origin, + const struct in6_addr *mcastgrp, + int parent) +{ + struct mfc6_cache_cmp_arg arg = { + .mf6c_origin = *origin, + .mf6c_mcastgrp = *mcastgrp, + }; + + return mr_mfc_find_parent(mrt, &arg, parent); +} + +/* Allocate a multicast cache entry */ +static struct mfc6_cache *ip6mr_cache_alloc(void) +{ + struct mfc6_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_KERNEL); + if (!c) + return NULL; + c->_c.mfc_un.res.last_assert = jiffies - MFC_ASSERT_THRESH - 1; + c->_c.mfc_un.res.minvif = MAXMIFS; + c->_c.free = ip6mr_cache_free_rcu; + refcount_set(&c->_c.mfc_un.res.refcount, 1); + return c; +} + +static struct mfc6_cache *ip6mr_cache_alloc_unres(void) +{ + struct mfc6_cache *c = kmem_cache_zalloc(mrt_cachep, GFP_ATOMIC); + if (!c) + return NULL; + skb_queue_head_init(&c->_c.mfc_un.unres.unresolved); + c->_c.mfc_un.unres.expires = jiffies + 10 * HZ; + return c; +} + +/* + * A cache entry has gone into a resolved state from queued + */ + +static void ip6mr_cache_resolve(struct net *net, struct mr_table *mrt, + struct mfc6_cache *uc, struct mfc6_cache *c) +{ + struct sk_buff *skb; + + /* + * Play the pending entries through our router + */ + + while ((skb = 
__skb_dequeue(&uc->_c.mfc_un.unres.unresolved))) { + if (ipv6_hdr(skb)->version == 0) { + struct nlmsghdr *nlh = skb_pull(skb, + sizeof(struct ipv6hdr)); + + if (mr_fill_mroute(mrt, skb, &c->_c, + nlmsg_data(nlh)) > 0) { + nlh->nlmsg_len = skb_tail_pointer(skb) - (u8 *)nlh; + } else { + nlh->nlmsg_type = NLMSG_ERROR; + nlh->nlmsg_len = nlmsg_msg_size(sizeof(struct nlmsgerr)); + skb_trim(skb, nlh->nlmsg_len); + ((struct nlmsgerr *)nlmsg_data(nlh))->error = -EMSGSIZE; + } + rtnl_unicast(skb, net, NETLINK_CB(skb).portid); + } else { + rcu_read_lock(); + ip6_mr_forward(net, mrt, skb->dev, skb, c); + rcu_read_unlock(); + } + } +} + +/* + * Bounce a cache query up to pim6sd and netlink. + * + * Called under rcu_read_lock() + */ + +static int ip6mr_cache_report(const struct mr_table *mrt, struct sk_buff *pkt, + mifi_t mifi, int assert) +{ + struct sock *mroute6_sk; + struct sk_buff *skb; + struct mrt6msg *msg; + int ret; + +#ifdef CONFIG_IPV6_PIMSM_V2 + if (assert == MRT6MSG_WHOLEPKT || assert == MRT6MSG_WRMIFWHOLE) + skb = skb_realloc_headroom(pkt, -skb_network_offset(pkt) + +sizeof(*msg)); + else +#endif + skb = alloc_skb(sizeof(struct ipv6hdr) + sizeof(*msg), GFP_ATOMIC); + + if (!skb) + return -ENOBUFS; + + /* I suppose that internal messages + * do not require checksums */ + + skb->ip_summed = CHECKSUM_UNNECESSARY; + +#ifdef CONFIG_IPV6_PIMSM_V2 + if (assert == MRT6MSG_WHOLEPKT || assert == MRT6MSG_WRMIFWHOLE) { + /* Ugly, but we have no choice with this interface. + Duplicate old header, fix length etc. + And all this only to mangle msg->im6_msgtype and + to set msg->im6_mbz to "mbz" :-) + */ + __skb_pull(skb, skb_network_offset(pkt)); + + skb_push(skb, sizeof(*msg)); + skb_reset_transport_header(skb); + msg = (struct mrt6msg *)skb_transport_header(skb); + msg->im6_mbz = 0; + msg->im6_msgtype = assert; + if (assert == MRT6MSG_WRMIFWHOLE) + msg->im6_mif = mifi; + else + msg->im6_mif = READ_ONCE(mrt->mroute_reg_vif_num); + msg->im6_pad = 0; + msg->im6_src = ipv6_hdr(pkt)->saddr; + msg->im6_dst = ipv6_hdr(pkt)->daddr; + + skb->ip_summed = CHECKSUM_UNNECESSARY; + } else +#endif + { + /* + * Copy the IP header + */ + + skb_put(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + skb_copy_to_linear_data(skb, ipv6_hdr(pkt), sizeof(struct ipv6hdr)); + + /* + * Add our header + */ + skb_put(skb, sizeof(*msg)); + skb_reset_transport_header(skb); + msg = (struct mrt6msg *)skb_transport_header(skb); + + msg->im6_mbz = 0; + msg->im6_msgtype = assert; + msg->im6_mif = mifi; + msg->im6_pad = 0; + msg->im6_src = ipv6_hdr(pkt)->saddr; + msg->im6_dst = ipv6_hdr(pkt)->daddr; + + skb_dst_set(skb, dst_clone(skb_dst(pkt))); + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + mroute6_sk = rcu_dereference(mrt->mroute_sk); + if (!mroute6_sk) { + kfree_skb(skb); + return -EINVAL; + } + + mrt6msg_netlink_event(mrt, skb); + + /* Deliver to user space multicast routing algorithms */ + ret = sock_queue_rcv_skb(mroute6_sk, skb); + + if (ret < 0) { + net_warn_ratelimited("mroute6: pending queue full, dropping entries\n"); + kfree_skb(skb); + } + + return ret; +} + +/* Queue a packet for resolution. It gets locked cache entry! 
*/ +static int ip6mr_cache_unresolved(struct mr_table *mrt, mifi_t mifi, + struct sk_buff *skb, struct net_device *dev) +{ + struct mfc6_cache *c; + bool found = false; + int err; + + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry(c, &mrt->mfc_unres_queue, _c.list) { + if (ipv6_addr_equal(&c->mf6c_mcastgrp, &ipv6_hdr(skb)->daddr) && + ipv6_addr_equal(&c->mf6c_origin, &ipv6_hdr(skb)->saddr)) { + found = true; + break; + } + } + + if (!found) { + /* + * Create a new entry if allowable + */ + + c = ip6mr_cache_alloc_unres(); + if (!c) { + spin_unlock_bh(&mfc_unres_lock); + + kfree_skb(skb); + return -ENOBUFS; + } + + /* Fill in the new cache entry */ + c->_c.mfc_parent = -1; + c->mf6c_origin = ipv6_hdr(skb)->saddr; + c->mf6c_mcastgrp = ipv6_hdr(skb)->daddr; + + /* + * Reflect first query at pim6sd + */ + err = ip6mr_cache_report(mrt, skb, mifi, MRT6MSG_NOCACHE); + if (err < 0) { + /* If the report failed throw the cache entry + out - Brad Parker + */ + spin_unlock_bh(&mfc_unres_lock); + + ip6mr_cache_free(c); + kfree_skb(skb); + return err; + } + + atomic_inc(&mrt->cache_resolve_queue_len); + list_add(&c->_c.list, &mrt->mfc_unres_queue); + mr6_netlink_event(mrt, c, RTM_NEWROUTE); + + ipmr_do_expire_process(mrt); + } + + /* See if we can append the packet */ + if (c->_c.mfc_un.unres.unresolved.qlen > 3) { + kfree_skb(skb); + err = -ENOBUFS; + } else { + if (dev) { + skb->dev = dev; + skb->skb_iif = dev->ifindex; + } + skb_queue_tail(&c->_c.mfc_un.unres.unresolved, skb); + err = 0; + } + + spin_unlock_bh(&mfc_unres_lock); + return err; +} + +/* + * MFC6 cache manipulation by user space + */ + +static int ip6mr_mfc_delete(struct mr_table *mrt, struct mf6cctl *mfc, + int parent) +{ + struct mfc6_cache *c; + + /* The entries are added/deleted only under RTNL */ + rcu_read_lock(); + c = ip6mr_cache_find_parent(mrt, &mfc->mf6cc_origin.sin6_addr, + &mfc->mf6cc_mcastgrp.sin6_addr, parent); + rcu_read_unlock(); + if (!c) + return -ENOENT; + rhltable_remove(&mrt->mfc_hash, &c->_c.mnode, ip6mr_rht_params); + list_del_rcu(&c->_c.list); + + call_ip6mr_mfc_entry_notifiers(read_pnet(&mrt->net), + FIB_EVENT_ENTRY_DEL, c, mrt->id); + mr6_netlink_event(mrt, c, RTM_DELROUTE); + mr_cache_put(&c->_c); + return 0; +} + +static int ip6mr_device_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct mr_table *mrt; + struct vif_device *v; + int ct; + + if (event != NETDEV_UNREGISTER) + return NOTIFY_DONE; + + ip6mr_for_each_table(mrt, net) { + v = &mrt->vif_table[0]; + for (ct = 0; ct < mrt->maxvif; ct++, v++) { + if (rcu_access_pointer(v->dev) == dev) + mif6_delete(mrt, ct, 1, NULL); + } + } + + return NOTIFY_DONE; +} + +static unsigned int ip6mr_seq_read(struct net *net) +{ + ASSERT_RTNL(); + + return net->ipv6.ipmr_seq + ip6mr_rules_seq_read(net); +} + +static int ip6mr_dump(struct net *net, struct notifier_block *nb, + struct netlink_ext_ack *extack) +{ + return mr_dump(net, nb, RTNL_FAMILY_IP6MR, ip6mr_rules_dump, + ip6mr_mr_table_iter, extack); +} + +static struct notifier_block ip6_mr_notifier = { + .notifier_call = ip6mr_device_event +}; + +static const struct fib_notifier_ops ip6mr_notifier_ops_template = { + .family = RTNL_FAMILY_IP6MR, + .fib_seq_read = ip6mr_seq_read, + .fib_dump = ip6mr_dump, + .owner = THIS_MODULE, +}; + +static int __net_init ip6mr_notifier_init(struct net *net) +{ + struct fib_notifier_ops *ops; + + net->ipv6.ipmr_seq = 0; + + ops = 
fib_notifier_ops_register(&ip6mr_notifier_ops_template, net); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + net->ipv6.ip6mr_notifier_ops = ops; + + return 0; +} + +static void __net_exit ip6mr_notifier_exit(struct net *net) +{ + fib_notifier_ops_unregister(net->ipv6.ip6mr_notifier_ops); + net->ipv6.ip6mr_notifier_ops = NULL; +} + +/* Setup for IP multicast routing */ +static int __net_init ip6mr_net_init(struct net *net) +{ + int err; + + err = ip6mr_notifier_init(net); + if (err) + return err; + + err = ip6mr_rules_init(net); + if (err < 0) + goto ip6mr_rules_fail; + +#ifdef CONFIG_PROC_FS + err = -ENOMEM; + if (!proc_create_net("ip6_mr_vif", 0, net->proc_net, &ip6mr_vif_seq_ops, + sizeof(struct mr_vif_iter))) + goto proc_vif_fail; + if (!proc_create_net("ip6_mr_cache", 0, net->proc_net, &ipmr_mfc_seq_ops, + sizeof(struct mr_mfc_iter))) + goto proc_cache_fail; +#endif + + return 0; + +#ifdef CONFIG_PROC_FS +proc_cache_fail: + remove_proc_entry("ip6_mr_vif", net->proc_net); +proc_vif_fail: + rtnl_lock(); + ip6mr_rules_exit(net); + rtnl_unlock(); +#endif +ip6mr_rules_fail: + ip6mr_notifier_exit(net); + return err; +} + +static void __net_exit ip6mr_net_exit(struct net *net) +{ +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip6_mr_cache", net->proc_net); + remove_proc_entry("ip6_mr_vif", net->proc_net); +#endif + ip6mr_notifier_exit(net); +} + +static void __net_exit ip6mr_net_exit_batch(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + ip6mr_rules_exit(net); + rtnl_unlock(); +} + +static struct pernet_operations ip6mr_net_ops = { + .init = ip6mr_net_init, + .exit = ip6mr_net_exit, + .exit_batch = ip6mr_net_exit_batch, +}; + +int __init ip6_mr_init(void) +{ + int err; + + mrt_cachep = kmem_cache_create("ip6_mrt_cache", + sizeof(struct mfc6_cache), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!mrt_cachep) + return -ENOMEM; + + err = register_pernet_subsys(&ip6mr_net_ops); + if (err) + goto reg_pernet_fail; + + err = register_netdevice_notifier(&ip6_mr_notifier); + if (err) + goto reg_notif_fail; +#ifdef CONFIG_IPV6_PIMSM_V2 + if (inet6_add_protocol(&pim6_protocol, IPPROTO_PIM) < 0) { + pr_err("%s: can't add PIM protocol\n", __func__); + err = -EAGAIN; + goto add_proto_fail; + } +#endif + err = rtnl_register_module(THIS_MODULE, RTNL_FAMILY_IP6MR, RTM_GETROUTE, + ip6mr_rtm_getroute, ip6mr_rtm_dumproute, 0); + if (err == 0) + return 0; + +#ifdef CONFIG_IPV6_PIMSM_V2 + inet6_del_protocol(&pim6_protocol, IPPROTO_PIM); +add_proto_fail: + unregister_netdevice_notifier(&ip6_mr_notifier); +#endif +reg_notif_fail: + unregister_pernet_subsys(&ip6mr_net_ops); +reg_pernet_fail: + kmem_cache_destroy(mrt_cachep); + return err; +} + +void ip6_mr_cleanup(void) +{ + rtnl_unregister(RTNL_FAMILY_IP6MR, RTM_GETROUTE); +#ifdef CONFIG_IPV6_PIMSM_V2 + inet6_del_protocol(&pim6_protocol, IPPROTO_PIM); +#endif + unregister_netdevice_notifier(&ip6_mr_notifier); + unregister_pernet_subsys(&ip6mr_net_ops); + kmem_cache_destroy(mrt_cachep); +} + +static int ip6mr_mfc_add(struct net *net, struct mr_table *mrt, + struct mf6cctl *mfc, int mrtsock, int parent) +{ + unsigned char ttls[MAXMIFS]; + struct mfc6_cache *uc, *c; + struct mr_mfc *_uc; + bool found; + int i, err; + + if (mfc->mf6cc_parent >= MAXMIFS) + return -ENFILE; + + memset(ttls, 255, MAXMIFS); + for (i = 0; i < MAXMIFS; i++) { + if (IF_ISSET(i, &mfc->mf6cc_ifset)) + ttls[i] = 1; + } + + /* The entries are added/deleted only under RTNL */ + rcu_read_lock(); + c = ip6mr_cache_find_parent(mrt, 
&mfc->mf6cc_origin.sin6_addr, + &mfc->mf6cc_mcastgrp.sin6_addr, parent); + rcu_read_unlock(); + if (c) { + spin_lock(&mrt_lock); + c->_c.mfc_parent = mfc->mf6cc_parent; + ip6mr_update_thresholds(mrt, &c->_c, ttls); + if (!mrtsock) + c->_c.mfc_flags |= MFC_STATIC; + spin_unlock(&mrt_lock); + call_ip6mr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_REPLACE, + c, mrt->id); + mr6_netlink_event(mrt, c, RTM_NEWROUTE); + return 0; + } + + if (!ipv6_addr_any(&mfc->mf6cc_mcastgrp.sin6_addr) && + !ipv6_addr_is_multicast(&mfc->mf6cc_mcastgrp.sin6_addr)) + return -EINVAL; + + c = ip6mr_cache_alloc(); + if (!c) + return -ENOMEM; + + c->mf6c_origin = mfc->mf6cc_origin.sin6_addr; + c->mf6c_mcastgrp = mfc->mf6cc_mcastgrp.sin6_addr; + c->_c.mfc_parent = mfc->mf6cc_parent; + ip6mr_update_thresholds(mrt, &c->_c, ttls); + if (!mrtsock) + c->_c.mfc_flags |= MFC_STATIC; + + err = rhltable_insert_key(&mrt->mfc_hash, &c->cmparg, &c->_c.mnode, + ip6mr_rht_params); + if (err) { + pr_err("ip6mr: rhtable insert error %d\n", err); + ip6mr_cache_free(c); + return err; + } + list_add_tail_rcu(&c->_c.list, &mrt->mfc_cache_list); + + /* Check to see if we resolved a queued list. If so we + * need to send on the frames and tidy up. + */ + found = false; + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry(_uc, &mrt->mfc_unres_queue, list) { + uc = (struct mfc6_cache *)_uc; + if (ipv6_addr_equal(&uc->mf6c_origin, &c->mf6c_origin) && + ipv6_addr_equal(&uc->mf6c_mcastgrp, &c->mf6c_mcastgrp)) { + list_del(&_uc->list); + atomic_dec(&mrt->cache_resolve_queue_len); + found = true; + break; + } + } + if (list_empty(&mrt->mfc_unres_queue)) + del_timer(&mrt->ipmr_expire_timer); + spin_unlock_bh(&mfc_unres_lock); + + if (found) { + ip6mr_cache_resolve(net, mrt, uc, c); + ip6mr_cache_free(uc); + } + call_ip6mr_mfc_entry_notifiers(net, FIB_EVENT_ENTRY_ADD, + c, mrt->id); + mr6_netlink_event(mrt, c, RTM_NEWROUTE); + return 0; +} + +/* + * Close the multicast socket, and clear the vif tables etc + */ + +static void mroute_clean_tables(struct mr_table *mrt, int flags) +{ + struct mr_mfc *c, *tmp; + LIST_HEAD(list); + int i; + + /* Shut down all active vif entries */ + if (flags & (MRT6_FLUSH_MIFS | MRT6_FLUSH_MIFS_STATIC)) { + for (i = 0; i < mrt->maxvif; i++) { + if (((mrt->vif_table[i].flags & VIFF_STATIC) && + !(flags & MRT6_FLUSH_MIFS_STATIC)) || + (!(mrt->vif_table[i].flags & VIFF_STATIC) && !(flags & MRT6_FLUSH_MIFS))) + continue; + mif6_delete(mrt, i, 0, &list); + } + unregister_netdevice_many(&list); + } + + /* Wipe the cache */ + if (flags & (MRT6_FLUSH_MFC | MRT6_FLUSH_MFC_STATIC)) { + list_for_each_entry_safe(c, tmp, &mrt->mfc_cache_list, list) { + if (((c->mfc_flags & MFC_STATIC) && !(flags & MRT6_FLUSH_MFC_STATIC)) || + (!(c->mfc_flags & MFC_STATIC) && !(flags & MRT6_FLUSH_MFC))) + continue; + rhltable_remove(&mrt->mfc_hash, &c->mnode, ip6mr_rht_params); + list_del_rcu(&c->list); + call_ip6mr_mfc_entry_notifiers(read_pnet(&mrt->net), + FIB_EVENT_ENTRY_DEL, + (struct mfc6_cache *)c, mrt->id); + mr6_netlink_event(mrt, (struct mfc6_cache *)c, RTM_DELROUTE); + mr_cache_put(c); + } + } + + if (flags & MRT6_FLUSH_MFC) { + if (atomic_read(&mrt->cache_resolve_queue_len) != 0) { + spin_lock_bh(&mfc_unres_lock); + list_for_each_entry_safe(c, tmp, &mrt->mfc_unres_queue, list) { + list_del(&c->list); + mr6_netlink_event(mrt, (struct mfc6_cache *)c, + RTM_DELROUTE); + ip6mr_destroy_unres(mrt, (struct mfc6_cache *)c); + } + spin_unlock_bh(&mfc_unres_lock); + } + } +} + +static int ip6mr_sk_init(struct mr_table *mrt, struct sock *sk) +{ + int 
err = 0; + struct net *net = sock_net(sk); + + rtnl_lock(); + spin_lock(&mrt_lock); + if (rtnl_dereference(mrt->mroute_sk)) { + err = -EADDRINUSE; + } else { + rcu_assign_pointer(mrt->mroute_sk, sk); + sock_set_flag(sk, SOCK_RCU_FREE); + atomic_inc(&net->ipv6.devconf_all->mc_forwarding); + } + spin_unlock(&mrt_lock); + + if (!err) + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all); + rtnl_unlock(); + + return err; +} + +int ip6mr_sk_done(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct ipv6_devconf *devconf; + struct mr_table *mrt; + int err = -EACCES; + + if (sk->sk_type != SOCK_RAW || + inet_sk(sk)->inet_num != IPPROTO_ICMPV6) + return err; + + devconf = net->ipv6.devconf_all; + if (!devconf || !atomic_read(&devconf->mc_forwarding)) + return err; + + rtnl_lock(); + ip6mr_for_each_table(mrt, net) { + if (sk == rtnl_dereference(mrt->mroute_sk)) { + spin_lock(&mrt_lock); + RCU_INIT_POINTER(mrt->mroute_sk, NULL); + /* Note that mroute_sk had SOCK_RCU_FREE set, + * so the RCU grace period before sk freeing + * is guaranteed by sk_destruct() + */ + atomic_dec(&devconf->mc_forwarding); + spin_unlock(&mrt_lock); + inet6_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_MC_FORWARDING, + NETCONFA_IFINDEX_ALL, + net->ipv6.devconf_all); + + mroute_clean_tables(mrt, MRT6_FLUSH_MIFS | MRT6_FLUSH_MFC); + err = 0; + break; + } + } + rtnl_unlock(); + + return err; +} + +bool mroute6_is_socket(struct net *net, struct sk_buff *skb) +{ + struct mr_table *mrt; + struct flowi6 fl6 = { + .flowi6_iif = skb->skb_iif ? : LOOPBACK_IFINDEX, + .flowi6_oif = skb->dev->ifindex, + .flowi6_mark = skb->mark, + }; + + if (ip6mr_fib_lookup(net, &fl6, &mrt) < 0) + return NULL; + + return rcu_access_pointer(mrt->mroute_sk); +} +EXPORT_SYMBOL(mroute6_is_socket); + +/* + * Socket options and virtual interface manipulation. The whole + * virtual interface system is a complete heap, but unfortunately + * that's how BSD mrouted happens to think. Maybe one day with a proper + * MOSPF/PIM router set up we can clean this up. + */ + +int ip6_mroute_setsockopt(struct sock *sk, int optname, sockptr_t optval, + unsigned int optlen) +{ + int ret, parent = 0; + struct mif6ctl vif; + struct mf6cctl mfc; + mifi_t mifi; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + if (sk->sk_type != SOCK_RAW || + inet_sk(sk)->inet_num != IPPROTO_ICMPV6) + return -EOPNOTSUPP; + + mrt = ip6mr_get_table(net, raw6_sk(sk)->ip6mr_table ? : RT6_TABLE_DFLT); + if (!mrt) + return -ENOENT; + + if (optname != MRT6_INIT) { + if (sk != rcu_access_pointer(mrt->mroute_sk) && + !ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EACCES; + } + + switch (optname) { + case MRT6_INIT: + if (optlen < sizeof(int)) + return -EINVAL; + + return ip6mr_sk_init(mrt, sk); + + case MRT6_DONE: + return ip6mr_sk_done(sk); + + case MRT6_ADD_MIF: + if (optlen < sizeof(vif)) + return -EINVAL; + if (copy_from_sockptr(&vif, optval, sizeof(vif))) + return -EFAULT; + if (vif.mif6c_mifi >= MAXMIFS) + return -ENFILE; + rtnl_lock(); + ret = mif6_add(net, mrt, &vif, + sk == rtnl_dereference(mrt->mroute_sk)); + rtnl_unlock(); + return ret; + + case MRT6_DEL_MIF: + if (optlen < sizeof(mifi_t)) + return -EINVAL; + if (copy_from_sockptr(&mifi, optval, sizeof(mifi_t))) + return -EFAULT; + rtnl_lock(); + ret = mif6_delete(mrt, mifi, 0, NULL); + rtnl_unlock(); + return ret; + + /* + * Manipulate the forwarding caches. These live + * in a sort of kernel/user symbiosis. 
+ */ + case MRT6_ADD_MFC: + case MRT6_DEL_MFC: + parent = -1; + fallthrough; + case MRT6_ADD_MFC_PROXY: + case MRT6_DEL_MFC_PROXY: + if (optlen < sizeof(mfc)) + return -EINVAL; + if (copy_from_sockptr(&mfc, optval, sizeof(mfc))) + return -EFAULT; + if (parent == 0) + parent = mfc.mf6cc_parent; + rtnl_lock(); + if (optname == MRT6_DEL_MFC || optname == MRT6_DEL_MFC_PROXY) + ret = ip6mr_mfc_delete(mrt, &mfc, parent); + else + ret = ip6mr_mfc_add(net, mrt, &mfc, + sk == + rtnl_dereference(mrt->mroute_sk), + parent); + rtnl_unlock(); + return ret; + + case MRT6_FLUSH: + { + int flags; + + if (optlen != sizeof(flags)) + return -EINVAL; + if (copy_from_sockptr(&flags, optval, sizeof(flags))) + return -EFAULT; + rtnl_lock(); + mroute_clean_tables(mrt, flags); + rtnl_unlock(); + return 0; + } + + /* + * Control PIM assert (to activate pim will activate assert) + */ + case MRT6_ASSERT: + { + int v; + + if (optlen != sizeof(v)) + return -EINVAL; + if (copy_from_sockptr(&v, optval, sizeof(v))) + return -EFAULT; + mrt->mroute_do_assert = v; + return 0; + } + +#ifdef CONFIG_IPV6_PIMSM_V2 + case MRT6_PIM: + { + bool do_wrmifwhole; + int v; + + if (optlen != sizeof(v)) + return -EINVAL; + if (copy_from_sockptr(&v, optval, sizeof(v))) + return -EFAULT; + + do_wrmifwhole = (v == MRT6MSG_WRMIFWHOLE); + v = !!v; + rtnl_lock(); + ret = 0; + if (v != mrt->mroute_do_pim) { + mrt->mroute_do_pim = v; + mrt->mroute_do_assert = v; + mrt->mroute_do_wrvifwhole = do_wrmifwhole; + } + rtnl_unlock(); + return ret; + } + +#endif +#ifdef CONFIG_IPV6_MROUTE_MULTIPLE_TABLES + case MRT6_TABLE: + { + u32 v; + + if (optlen != sizeof(u32)) + return -EINVAL; + if (copy_from_sockptr(&v, optval, sizeof(v))) + return -EFAULT; + /* "pim6reg%u" should not exceed 16 bytes (IFNAMSIZ) */ + if (v != RT_TABLE_DEFAULT && v >= 100000000) + return -EINVAL; + if (sk == rcu_access_pointer(mrt->mroute_sk)) + return -EBUSY; + + rtnl_lock(); + ret = 0; + mrt = ip6mr_new_table(net, v); + if (IS_ERR(mrt)) + ret = PTR_ERR(mrt); + else + raw6_sk(sk)->ip6mr_table = v; + rtnl_unlock(); + return ret; + } +#endif + /* + * Spurious command, or MRT6_VERSION which you cannot + * set. + */ + default: + return -ENOPROTOOPT; + } +} + +/* + * Getsock opt support for the multicast routing system. + */ + +int ip6_mroute_getsockopt(struct sock *sk, int optname, sockptr_t optval, + sockptr_t optlen) +{ + int olr; + int val; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + if (sk->sk_type != SOCK_RAW || + inet_sk(sk)->inet_num != IPPROTO_ICMPV6) + return -EOPNOTSUPP; + + mrt = ip6mr_get_table(net, raw6_sk(sk)->ip6mr_table ? : RT6_TABLE_DFLT); + if (!mrt) + return -ENOENT; + + switch (optname) { + case MRT6_VERSION: + val = 0x0305; + break; +#ifdef CONFIG_IPV6_PIMSM_V2 + case MRT6_PIM: + val = mrt->mroute_do_pim; + break; +#endif + case MRT6_ASSERT: + val = mrt->mroute_do_assert; + break; + default: + return -ENOPROTOOPT; + } + + if (copy_from_sockptr(&olr, optlen, sizeof(int))) + return -EFAULT; + + olr = min_t(int, olr, sizeof(int)); + if (olr < 0) + return -EINVAL; + + if (copy_to_sockptr(optlen, &olr, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &val, olr)) + return -EFAULT; + return 0; +} + +/* + * The IP multicast ioctl support routines. + */ + +int ip6mr_ioctl(struct sock *sk, int cmd, void __user *arg) +{ + struct sioc_sg_req6 sr; + struct sioc_mif_req6 vr; + struct vif_device *vif; + struct mfc6_cache *c; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + mrt = ip6mr_get_table(net, raw6_sk(sk)->ip6mr_table ? 
: RT6_TABLE_DFLT); + if (!mrt) + return -ENOENT; + + switch (cmd) { + case SIOCGETMIFCNT_IN6: + if (copy_from_user(&vr, arg, sizeof(vr))) + return -EFAULT; + if (vr.mifi >= mrt->maxvif) + return -EINVAL; + vr.mifi = array_index_nospec(vr.mifi, mrt->maxvif); + rcu_read_lock(); + vif = &mrt->vif_table[vr.mifi]; + if (VIF_EXISTS(mrt, vr.mifi)) { + vr.icount = READ_ONCE(vif->pkt_in); + vr.ocount = READ_ONCE(vif->pkt_out); + vr.ibytes = READ_ONCE(vif->bytes_in); + vr.obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); + + if (copy_to_user(arg, &vr, sizeof(vr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + case SIOCGETSGCNT_IN6: + if (copy_from_user(&sr, arg, sizeof(sr))) + return -EFAULT; + + rcu_read_lock(); + c = ip6mr_cache_find(mrt, &sr.src.sin6_addr, &sr.grp.sin6_addr); + if (c) { + sr.pktcnt = c->_c.mfc_un.res.pkt; + sr.bytecnt = c->_c.mfc_un.res.bytes; + sr.wrong_if = c->_c.mfc_un.res.wrong_if; + rcu_read_unlock(); + + if (copy_to_user(arg, &sr, sizeof(sr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + default: + return -ENOIOCTLCMD; + } +} + +#ifdef CONFIG_COMPAT +struct compat_sioc_sg_req6 { + struct sockaddr_in6 src; + struct sockaddr_in6 grp; + compat_ulong_t pktcnt; + compat_ulong_t bytecnt; + compat_ulong_t wrong_if; +}; + +struct compat_sioc_mif_req6 { + mifi_t mifi; + compat_ulong_t icount; + compat_ulong_t ocount; + compat_ulong_t ibytes; + compat_ulong_t obytes; +}; + +int ip6mr_compat_ioctl(struct sock *sk, unsigned int cmd, void __user *arg) +{ + struct compat_sioc_sg_req6 sr; + struct compat_sioc_mif_req6 vr; + struct vif_device *vif; + struct mfc6_cache *c; + struct net *net = sock_net(sk); + struct mr_table *mrt; + + mrt = ip6mr_get_table(net, raw6_sk(sk)->ip6mr_table ? 
: RT6_TABLE_DFLT); + if (!mrt) + return -ENOENT; + + switch (cmd) { + case SIOCGETMIFCNT_IN6: + if (copy_from_user(&vr, arg, sizeof(vr))) + return -EFAULT; + if (vr.mifi >= mrt->maxvif) + return -EINVAL; + vr.mifi = array_index_nospec(vr.mifi, mrt->maxvif); + rcu_read_lock(); + vif = &mrt->vif_table[vr.mifi]; + if (VIF_EXISTS(mrt, vr.mifi)) { + vr.icount = READ_ONCE(vif->pkt_in); + vr.ocount = READ_ONCE(vif->pkt_out); + vr.ibytes = READ_ONCE(vif->bytes_in); + vr.obytes = READ_ONCE(vif->bytes_out); + rcu_read_unlock(); + + if (copy_to_user(arg, &vr, sizeof(vr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + case SIOCGETSGCNT_IN6: + if (copy_from_user(&sr, arg, sizeof(sr))) + return -EFAULT; + + rcu_read_lock(); + c = ip6mr_cache_find(mrt, &sr.src.sin6_addr, &sr.grp.sin6_addr); + if (c) { + sr.pktcnt = c->_c.mfc_un.res.pkt; + sr.bytecnt = c->_c.mfc_un.res.bytes; + sr.wrong_if = c->_c.mfc_un.res.wrong_if; + rcu_read_unlock(); + + if (copy_to_user(arg, &sr, sizeof(sr))) + return -EFAULT; + return 0; + } + rcu_read_unlock(); + return -EADDRNOTAVAIL; + default: + return -ENOIOCTLCMD; + } +} +#endif + +static inline int ip6mr_forward2_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTFORWDATAGRAMS); + IP6_ADD_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_OUTOCTETS, skb->len); + return dst_output(net, sk, skb); +} + +/* + * Processing handlers for ip6mr_forward + */ + +static int ip6mr_forward2(struct net *net, struct mr_table *mrt, + struct sk_buff *skb, int vifi) +{ + struct vif_device *vif = &mrt->vif_table[vifi]; + struct net_device *vif_dev; + struct ipv6hdr *ipv6h; + struct dst_entry *dst; + struct flowi6 fl6; + + vif_dev = vif_dev_read(vif); + if (!vif_dev) + goto out_free; + +#ifdef CONFIG_IPV6_PIMSM_V2 + if (vif->flags & MIFF_REGISTER) { + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); + vif_dev->stats.tx_bytes += skb->len; + vif_dev->stats.tx_packets++; + ip6mr_cache_report(mrt, skb, vifi, MRT6MSG_WHOLEPKT); + goto out_free; + } +#endif + + ipv6h = ipv6_hdr(skb); + + fl6 = (struct flowi6) { + .flowi6_oif = vif->link, + .daddr = ipv6h->daddr, + }; + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + dst_release(dst); + goto out_free; + } + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + /* + * RFC1584 teaches, that DVMRP/PIM router must deliver packets locally + * not only before forwarding, but after forwarding on all output + * interfaces. It is clear, if mrouter runs a multicasting + * program, it should receive packets not depending to what interface + * program is joined. + * If we will not make it, the program will have to join on all + * interfaces. On the other hand, multihoming host (or router, but + * not mrouter) cannot join to more than one interface - it will + * result in receiving multiple packets. + */ + skb->dev = vif_dev; + WRITE_ONCE(vif->pkt_out, vif->pkt_out + 1); + WRITE_ONCE(vif->bytes_out, vif->bytes_out + skb->len); + + /* We are about to write */ + /* XXX: extension headers? 
*/ + if (skb_cow(skb, sizeof(*ipv6h) + LL_RESERVED_SPACE(vif_dev))) + goto out_free; + + ipv6h = ipv6_hdr(skb); + ipv6h->hop_limit--; + + IP6CB(skb)->flags |= IP6SKB_FORWARDED; + + return NF_HOOK(NFPROTO_IPV6, NF_INET_FORWARD, + net, NULL, skb, skb->dev, vif_dev, + ip6mr_forward2_finish); + +out_free: + kfree_skb(skb); + return 0; +} + +/* Called with rcu_read_lock() */ +static int ip6mr_find_vif(struct mr_table *mrt, struct net_device *dev) +{ + int ct; + + /* Pairs with WRITE_ONCE() in mif6_delete()/mif6_add() */ + for (ct = READ_ONCE(mrt->maxvif) - 1; ct >= 0; ct--) { + if (rcu_access_pointer(mrt->vif_table[ct].dev) == dev) + break; + } + return ct; +} + +/* Called under rcu_read_lock() */ +static void ip6_mr_forward(struct net *net, struct mr_table *mrt, + struct net_device *dev, struct sk_buff *skb, + struct mfc6_cache *c) +{ + int psend = -1; + int vif, ct; + int true_vifi = ip6mr_find_vif(mrt, dev); + + vif = c->_c.mfc_parent; + c->_c.mfc_un.res.pkt++; + c->_c.mfc_un.res.bytes += skb->len; + c->_c.mfc_un.res.lastuse = jiffies; + + if (ipv6_addr_any(&c->mf6c_origin) && true_vifi >= 0) { + struct mfc6_cache *cache_proxy; + + /* For an (*,G) entry, we only check that the incoming + * interface is part of the static tree. + */ + cache_proxy = mr_mfc_find_any_parent(mrt, vif); + if (cache_proxy && + cache_proxy->_c.mfc_un.res.ttls[true_vifi] < 255) + goto forward; + } + + /* + * Wrong interface: drop packet and (maybe) send PIM assert. + */ + if (rcu_access_pointer(mrt->vif_table[vif].dev) != dev) { + c->_c.mfc_un.res.wrong_if++; + + if (true_vifi >= 0 && mrt->mroute_do_assert && + /* pimsm uses asserts, when switching from RPT to SPT, + so that we cannot check that packet arrived on an oif. + It is bad, but otherwise we would need to move pretty + large chunk of pimd to kernel. Ough... --ANK + */ + (mrt->mroute_do_pim || + c->_c.mfc_un.res.ttls[true_vifi] < 255) && + time_after(jiffies, + c->_c.mfc_un.res.last_assert + + MFC_ASSERT_THRESH)) { + c->_c.mfc_un.res.last_assert = jiffies; + ip6mr_cache_report(mrt, skb, true_vifi, MRT6MSG_WRONGMIF); + if (mrt->mroute_do_wrvifwhole) + ip6mr_cache_report(mrt, skb, true_vifi, + MRT6MSG_WRMIFWHOLE); + } + goto dont_forward; + } + +forward: + WRITE_ONCE(mrt->vif_table[vif].pkt_in, + mrt->vif_table[vif].pkt_in + 1); + WRITE_ONCE(mrt->vif_table[vif].bytes_in, + mrt->vif_table[vif].bytes_in + skb->len); + + /* + * Forward the frame + */ + if (ipv6_addr_any(&c->mf6c_origin) && + ipv6_addr_any(&c->mf6c_mcastgrp)) { + if (true_vifi >= 0 && + true_vifi != c->_c.mfc_parent && + ipv6_hdr(skb)->hop_limit > + c->_c.mfc_un.res.ttls[c->_c.mfc_parent]) { + /* It's an (*,*) entry and the packet is not coming from + * the upstream: forward the packet to the upstream + * only. 
+ */ + psend = c->_c.mfc_parent; + goto last_forward; + } + goto dont_forward; + } + for (ct = c->_c.mfc_un.res.maxvif - 1; + ct >= c->_c.mfc_un.res.minvif; ct--) { + /* For (*,G) entry, don't forward to the incoming interface */ + if ((!ipv6_addr_any(&c->mf6c_origin) || ct != true_vifi) && + ipv6_hdr(skb)->hop_limit > c->_c.mfc_un.res.ttls[ct]) { + if (psend != -1) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) + ip6mr_forward2(net, mrt, skb2, psend); + } + psend = ct; + } + } +last_forward: + if (psend != -1) { + ip6mr_forward2(net, mrt, skb, psend); + return; + } + +dont_forward: + kfree_skb(skb); +} + + +/* + * Multicast packets for forwarding arrive here + */ + +int ip6_mr_input(struct sk_buff *skb) +{ + struct mfc6_cache *cache; + struct net *net = dev_net(skb->dev); + struct mr_table *mrt; + struct flowi6 fl6 = { + .flowi6_iif = skb->dev->ifindex, + .flowi6_mark = skb->mark, + }; + int err; + struct net_device *dev; + + /* skb->dev passed in is the master dev for vrfs. + * Get the proper interface that does have a vif associated with it. + */ + dev = skb->dev; + if (netif_is_l3_master(skb->dev)) { + dev = dev_get_by_index_rcu(net, IPCB(skb)->iif); + if (!dev) { + kfree_skb(skb); + return -ENODEV; + } + } + + err = ip6mr_fib_lookup(net, &fl6, &mrt); + if (err < 0) { + kfree_skb(skb); + return err; + } + + cache = ip6mr_cache_find(mrt, + &ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr); + if (!cache) { + int vif = ip6mr_find_vif(mrt, dev); + + if (vif >= 0) + cache = ip6mr_cache_find_any(mrt, + &ipv6_hdr(skb)->daddr, + vif); + } + + /* + * No usable cache entry + */ + if (!cache) { + int vif; + + vif = ip6mr_find_vif(mrt, dev); + if (vif >= 0) { + int err = ip6mr_cache_unresolved(mrt, vif, skb, dev); + + return err; + } + kfree_skb(skb); + return -ENODEV; + } + + ip6_mr_forward(net, mrt, dev, skb, cache); + + return 0; +} + +int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm, + u32 portid) +{ + int err; + struct mr_table *mrt; + struct mfc6_cache *cache; + struct rt6_info *rt = (struct rt6_info *)skb_dst(skb); + + mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) + return -ENOENT; + + rcu_read_lock(); + cache = ip6mr_cache_find(mrt, &rt->rt6i_src.addr, &rt->rt6i_dst.addr); + if (!cache && skb->dev) { + int vif = ip6mr_find_vif(mrt, skb->dev); + + if (vif >= 0) + cache = ip6mr_cache_find_any(mrt, &rt->rt6i_dst.addr, + vif); + } + + if (!cache) { + struct sk_buff *skb2; + struct ipv6hdr *iph; + struct net_device *dev; + int vif; + + dev = skb->dev; + if (!dev || (vif = ip6mr_find_vif(mrt, dev)) < 0) { + rcu_read_unlock(); + return -ENODEV; + } + + /* really correct? 
*/ + skb2 = alloc_skb(sizeof(struct ipv6hdr), GFP_ATOMIC); + if (!skb2) { + rcu_read_unlock(); + return -ENOMEM; + } + + NETLINK_CB(skb2).portid = portid; + skb_reset_transport_header(skb2); + + skb_put(skb2, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb2); + + iph = ipv6_hdr(skb2); + iph->version = 0; + iph->priority = 0; + iph->flow_lbl[0] = 0; + iph->flow_lbl[1] = 0; + iph->flow_lbl[2] = 0; + iph->payload_len = 0; + iph->nexthdr = IPPROTO_NONE; + iph->hop_limit = 0; + iph->saddr = rt->rt6i_src.addr; + iph->daddr = rt->rt6i_dst.addr; + + err = ip6mr_cache_unresolved(mrt, vif, skb2, dev); + rcu_read_unlock(); + + return err; + } + + err = mr_fill_mroute(mrt, skb, &cache->_c, rtm); + rcu_read_unlock(); + return err; +} + +static int ip6mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mfc6_cache *c, int cmd, + int flags) +{ + struct nlmsghdr *nlh; + struct rtmsg *rtm; + int err; + + nlh = nlmsg_put(skb, portid, seq, cmd, sizeof(*rtm), flags); + if (!nlh) + return -EMSGSIZE; + + rtm = nlmsg_data(nlh); + rtm->rtm_family = RTNL_FAMILY_IP6MR; + rtm->rtm_dst_len = 128; + rtm->rtm_src_len = 128; + rtm->rtm_tos = 0; + rtm->rtm_table = mrt->id; + if (nla_put_u32(skb, RTA_TABLE, mrt->id)) + goto nla_put_failure; + rtm->rtm_type = RTN_MULTICAST; + rtm->rtm_scope = RT_SCOPE_UNIVERSE; + if (c->_c.mfc_flags & MFC_STATIC) + rtm->rtm_protocol = RTPROT_STATIC; + else + rtm->rtm_protocol = RTPROT_MROUTED; + rtm->rtm_flags = 0; + + if (nla_put_in6_addr(skb, RTA_SRC, &c->mf6c_origin) || + nla_put_in6_addr(skb, RTA_DST, &c->mf6c_mcastgrp)) + goto nla_put_failure; + err = mr_fill_mroute(mrt, skb, &c->_c, rtm); + /* do not break the dump if cache is unresolved */ + if (err < 0 && err != -ENOENT) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int _ip6mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb, + u32 portid, u32 seq, struct mr_mfc *c, + int cmd, int flags) +{ + return ip6mr_fill_mroute(mrt, skb, portid, seq, (struct mfc6_cache *)c, + cmd, flags); +} + +static int mr6_msgsize(bool unresolved, int maxvif) +{ + size_t len = + NLMSG_ALIGN(sizeof(struct rtmsg)) + + nla_total_size(4) /* RTA_TABLE */ + + nla_total_size(sizeof(struct in6_addr)) /* RTA_SRC */ + + nla_total_size(sizeof(struct in6_addr)) /* RTA_DST */ + ; + + if (!unresolved) + len = len + + nla_total_size(4) /* RTA_IIF */ + + nla_total_size(0) /* RTA_MULTIPATH */ + + maxvif * NLA_ALIGN(sizeof(struct rtnexthop)) + /* RTA_MFC_STATS */ + + nla_total_size_64bit(sizeof(struct rta_mfc_stats)) + ; + + return len; +} + +static void mr6_netlink_event(struct mr_table *mrt, struct mfc6_cache *mfc, + int cmd) +{ + struct net *net = read_pnet(&mrt->net); + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(mr6_msgsize(mfc->_c.mfc_parent >= MAXMIFS, mrt->maxvif), + GFP_ATOMIC); + if (!skb) + goto errout; + + err = ip6mr_fill_mroute(mrt, skb, 0, 0, mfc, cmd, 0); + if (err < 0) + goto errout; + + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_MROUTE, NULL, GFP_ATOMIC); + return; + +errout: + kfree_skb(skb); + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_MROUTE, err); +} + +static size_t mrt6msg_netlink_msgsize(size_t payloadlen) +{ + size_t len = + NLMSG_ALIGN(sizeof(struct rtgenmsg)) + + nla_total_size(1) /* IP6MRA_CREPORT_MSGTYPE */ + + nla_total_size(4) /* IP6MRA_CREPORT_MIF_ID */ + /* IP6MRA_CREPORT_SRC_ADDR */ + + nla_total_size(sizeof(struct in6_addr)) + /* IP6MRA_CREPORT_DST_ADDR */ + + 
nla_total_size(sizeof(struct in6_addr)) + /* IP6MRA_CREPORT_PKT */ + + nla_total_size(payloadlen) + ; + + return len; +} + +static void mrt6msg_netlink_event(const struct mr_table *mrt, struct sk_buff *pkt) +{ + struct net *net = read_pnet(&mrt->net); + struct nlmsghdr *nlh; + struct rtgenmsg *rtgenm; + struct mrt6msg *msg; + struct sk_buff *skb; + struct nlattr *nla; + int payloadlen; + + payloadlen = pkt->len - sizeof(struct mrt6msg); + msg = (struct mrt6msg *)skb_transport_header(pkt); + + skb = nlmsg_new(mrt6msg_netlink_msgsize(payloadlen), GFP_ATOMIC); + if (!skb) + goto errout; + + nlh = nlmsg_put(skb, 0, 0, RTM_NEWCACHEREPORT, + sizeof(struct rtgenmsg), 0); + if (!nlh) + goto errout; + rtgenm = nlmsg_data(nlh); + rtgenm->rtgen_family = RTNL_FAMILY_IP6MR; + if (nla_put_u8(skb, IP6MRA_CREPORT_MSGTYPE, msg->im6_msgtype) || + nla_put_u32(skb, IP6MRA_CREPORT_MIF_ID, msg->im6_mif) || + nla_put_in6_addr(skb, IP6MRA_CREPORT_SRC_ADDR, + &msg->im6_src) || + nla_put_in6_addr(skb, IP6MRA_CREPORT_DST_ADDR, + &msg->im6_dst)) + goto nla_put_failure; + + nla = nla_reserve(skb, IP6MRA_CREPORT_PKT, payloadlen); + if (!nla || skb_copy_bits(pkt, sizeof(struct mrt6msg), + nla_data(nla), payloadlen)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_MROUTE_R, NULL, GFP_ATOMIC); + return; + +nla_put_failure: + nlmsg_cancel(skb, nlh); +errout: + kfree_skb(skb); + rtnl_set_sk_err(net, RTNLGRP_IPV6_MROUTE_R, -ENOBUFS); +} + +static const struct nla_policy ip6mr_getroute_policy[RTA_MAX + 1] = { + [RTA_SRC] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [RTA_DST] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [RTA_TABLE] = { .type = NLA_U32 }, +}; + +static int ip6mr_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int err; + + err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, ip6mr_getroute_policy, + extack); + if (err) + return err; + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 128) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 128) || + rtm->rtm_tos || rtm->rtm_table || rtm->rtm_protocol || + rtm->rtm_scope || rtm->rtm_type || rtm->rtm_flags) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid values in header for multicast route get request"); + return -EINVAL; + } + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG_MOD(extack, "rtm_src_len and rtm_dst_len must be 128 for IPv6"); + return -EINVAL; + } + + return 0; +} + +static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct in6_addr src = {}, grp = {}; + struct nlattr *tb[RTA_MAX + 1]; + struct mfc6_cache *cache; + struct mr_table *mrt; + struct sk_buff *skb; + u32 tableid; + int err; + + err = ip6mr_rtm_valid_getroute_req(in_skb, nlh, tb, extack); + if (err < 0) + return err; + + if (tb[RTA_SRC]) + src = nla_get_in6_addr(tb[RTA_SRC]); + if (tb[RTA_DST]) + grp = nla_get_in6_addr(tb[RTA_DST]); + tableid = tb[RTA_TABLE] ? 
nla_get_u32(tb[RTA_TABLE]) : 0; + + mrt = ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT); + if (!mrt) { + NL_SET_ERR_MSG_MOD(extack, "MR table does not exist"); + return -ENOENT; + } + + /* entries are added/deleted only under RTNL */ + rcu_read_lock(); + cache = ip6mr_cache_find(mrt, &src, &grp); + rcu_read_unlock(); + if (!cache) { + NL_SET_ERR_MSG_MOD(extack, "MR cache entry not found"); + return -ENOENT; + } + + skb = nlmsg_new(mr6_msgsize(false, mrt->maxvif), GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + err = ip6mr_fill_mroute(mrt, skb, NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, cache, RTM_NEWROUTE, 0); + if (err < 0) { + kfree_skb(skb); + return err; + } + + return rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +} + +static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct fib_dump_filter filter = {}; + int err; + + if (cb->strict_check) { + err = ip_valid_fib_dump_req(sock_net(skb->sk), nlh, + &filter, cb); + if (err < 0) + return err; + } + + if (filter.table_id) { + struct mr_table *mrt; + + mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id); + if (!mrt) { + if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IP6MR) + return skb->len; + + NL_SET_ERR_MSG_MOD(cb->extack, "MR table does not exist"); + return -ENOENT; + } + err = mr_table_dump(mrt, skb, cb, _ip6mr_fill_mroute, + &mfc_unres_lock, &filter); + return skb->len ? : err; + } + + return mr_rtm_dumproute(skb, cb, ip6mr_mr_table_iter, + _ip6mr_fill_mroute, &mfc_unres_lock, &filter); +} diff --git a/net/ipv6/ipcomp6.c b/net/ipv6/ipcomp6.c new file mode 100644 index 000000000..72d4858de --- /dev/null +++ b/net/ipv6/ipcomp6.c @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IP Payload Compression Protocol (IPComp) for IPv6 - RFC3173 + * + * Copyright (C)2003 USAGI/WIDE Project + * + * Author Mitsuru KANDA + */ +/* + * [Memo] + * + * Outbound: + * The compression of IP datagram MUST be done before AH/ESP processing, + * fragmentation, and the addition of Hop-by-Hop/Routing header. + * + * Inbound: + * The decompression of IP datagram MUST be done after the reassembly, + * AH/ESP processing. 
+ */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int ipcomp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + __be32 spi; + const struct ipv6hdr *iph = (const struct ipv6hdr *)skb->data; + struct ip_comp_hdr *ipcomph = + (struct ip_comp_hdr *)(skb->data + offset); + struct xfrm_state *x; + + if (type != ICMPV6_PKT_TOOBIG && + type != NDISC_REDIRECT) + return 0; + + spi = htonl(ntohs(ipcomph->cpi)); + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + spi, IPPROTO_COMP, AF_INET6); + if (!x) + return 0; + + if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + else + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + xfrm_state_put(x); + + return 0; +} + +static struct xfrm_state *ipcomp6_tunnel_create(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + struct xfrm_state *t = NULL; + + t = xfrm_state_alloc(net); + if (!t) + goto out; + + t->id.proto = IPPROTO_IPV6; + t->id.spi = xfrm6_tunnel_alloc_spi(net, (xfrm_address_t *)&x->props.saddr); + if (!t->id.spi) + goto error; + + memcpy(t->id.daddr.a6, x->id.daddr.a6, sizeof(struct in6_addr)); + memcpy(&t->sel, &x->sel, sizeof(t->sel)); + t->props.family = AF_INET6; + t->props.mode = x->props.mode; + memcpy(t->props.saddr.a6, x->props.saddr.a6, sizeof(struct in6_addr)); + memcpy(&t->mark, &x->mark, sizeof(t->mark)); + t->if_id = x->if_id; + + if (xfrm_init_state(t)) + goto error; + + atomic_set(&t->tunnel_users, 1); + +out: + return t; + +error: + t->km.state = XFRM_STATE_DEAD; + xfrm_state_put(t); + t = NULL; + goto out; +} + +static int ipcomp6_tunnel_attach(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + int err = 0; + struct xfrm_state *t = NULL; + __be32 spi; + u32 mark = x->mark.m & x->mark.v; + + spi = xfrm6_tunnel_spi_lookup(net, (xfrm_address_t *)&x->props.saddr); + if (spi) + t = xfrm_state_lookup(net, mark, (xfrm_address_t *)&x->id.daddr, + spi, IPPROTO_IPV6, AF_INET6); + if (!t) { + t = ipcomp6_tunnel_create(x); + if (!t) { + err = -EINVAL; + goto out; + } + xfrm_state_insert(t); + xfrm_state_hold(t); + } + x->tunnel = t; + atomic_inc(&t->tunnel_users); + +out: + return err; +} + +static int ipcomp6_init_state(struct xfrm_state *x, + struct netlink_ext_ack *extack) +{ + int err = -EINVAL; + + x->props.header_len = 0; + switch (x->props.mode) { + case XFRM_MODE_TRANSPORT: + break; + case XFRM_MODE_TUNNEL: + x->props.header_len += sizeof(struct ipv6hdr); + break; + default: + NL_SET_ERR_MSG(extack, "Unsupported XFRM mode for IPcomp"); + goto out; + } + + err = ipcomp_init_state(x, extack); + if (err) + goto out; + + if (x->props.mode == XFRM_MODE_TUNNEL) { + err = ipcomp6_tunnel_attach(x); + if (err) { + NL_SET_ERR_MSG(extack, "Kernel error: failed to initialize the associated state"); + goto out; + } + } + + err = 0; +out: + return err; +} + +static int ipcomp6_rcv_cb(struct sk_buff *skb, int err) +{ + return 0; +} + +static const struct xfrm_type ipcomp6_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_COMP, + .init_state = ipcomp6_init_state, + .destructor = ipcomp_destroy, + .input = ipcomp_input, + .output = ipcomp_output, +}; + +static struct xfrm6_protocol ipcomp6_protocol = { + .handler = xfrm6_rcv, + .input_handler = 
xfrm_input, + .cb_handler = ipcomp6_rcv_cb, + .err_handler = ipcomp6_err, + .priority = 0, +}; + +static int __init ipcomp6_init(void) +{ + if (xfrm_register_type(&ipcomp6_type, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type\n", __func__); + return -EAGAIN; + } + if (xfrm6_protocol_register(&ipcomp6_protocol, IPPROTO_COMP) < 0) { + pr_info("%s: can't add protocol\n", __func__); + xfrm_unregister_type(&ipcomp6_type, AF_INET6); + return -EAGAIN; + } + return 0; +} + +static void __exit ipcomp6_fini(void) +{ + if (xfrm6_protocol_deregister(&ipcomp6_protocol, IPPROTO_COMP) < 0) + pr_info("%s: can't remove protocol\n", __func__); + xfrm_unregister_type(&ipcomp6_type, AF_INET6); +} + +module_init(ipcomp6_init); +module_exit(ipcomp6_fini); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) for IPv6 - RFC3173"); +MODULE_AUTHOR("Mitsuru KANDA "); + +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_COMP); diff --git a/net/ipv6/ipv6_sockglue.c b/net/ipv6/ipv6_sockglue.c new file mode 100644 index 000000000..532f4478c --- /dev/null +++ b/net/ipv6/ipv6_sockglue.c @@ -0,0 +1,1520 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 BSD socket options interface + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on linux/net/ipv4/ip_sockglue.c + * + * FIXME: Make the setsockopt code POSIX compliant: That is + * + * o Truncate getsockopt returns + * o Return an optlen of the truncated length if need be + * + * Changes: + * David L Stevens : + * - added multicast source filtering API for MLDv2 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +struct ip6_ra_chain *ip6_ra_chain; +DEFINE_RWLOCK(ip6_ra_lock); + +DEFINE_STATIC_KEY_FALSE(ip6_min_hopcount); + +int ip6_ra_control(struct sock *sk, int sel) +{ + struct ip6_ra_chain *ra, *new_ra, **rap; + + /* RA packet may be delivered ONLY to IPPROTO_RAW socket */ + if (sk->sk_type != SOCK_RAW || inet_sk(sk)->inet_num != IPPROTO_RAW) + return -ENOPROTOOPT; + + new_ra = (sel >= 0) ? 
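ip6_ra_control() here backs the IPV6_ROUTER_ALERT socket option: a daemon registers a (socket, value) pair on ip6_ra_chain and then receives Router Alert packets. A hypothetical sketch of the userspace registration (illustrative only; as the check above requires, the socket must be SOCK_RAW with protocol IPPROTO_RAW):

#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>

/* Illustrative only: ask for IPv6 Router Alert packets (e.g. MLD).
 * Passing a negative value to IPV6_ROUTER_ALERT unregisters again. */
static int enable_router_alert(void)
{
	int s = socket(AF_INET6, SOCK_RAW, IPPROTO_RAW);
	int sel = 0;

	if (s < 0)
		return -1;
	if (setsockopt(s, IPPROTO_IPV6, IPV6_ROUTER_ALERT,
		       &sel, sizeof(sel)) < 0) {
		close(s);
		return -1;
	}
	return s;
}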
kmalloc(sizeof(*new_ra), GFP_KERNEL) : NULL; + if (sel >= 0 && !new_ra) + return -ENOMEM; + + write_lock_bh(&ip6_ra_lock); + for (rap = &ip6_ra_chain; (ra = *rap) != NULL; rap = &ra->next) { + if (ra->sk == sk) { + if (sel >= 0) { + write_unlock_bh(&ip6_ra_lock); + kfree(new_ra); + return -EADDRINUSE; + } + + *rap = ra->next; + write_unlock_bh(&ip6_ra_lock); + + sock_put(sk); + kfree(ra); + return 0; + } + } + if (!new_ra) { + write_unlock_bh(&ip6_ra_lock); + return -ENOBUFS; + } + new_ra->sk = sk; + new_ra->sel = sel; + new_ra->next = ra; + *rap = new_ra; + sock_hold(sk); + write_unlock_bh(&ip6_ra_lock); + return 0; +} + +struct ipv6_txoptions *ipv6_update_options(struct sock *sk, + struct ipv6_txoptions *opt) +{ + if (inet_sk(sk)->is_icsk) { + if (opt && + !((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) && + inet_sk(sk)->inet_daddr != LOOPBACK4_IPV6) { + struct inet_connection_sock *icsk = inet_csk(sk); + icsk->icsk_ext_hdr_len = opt->opt_flen + opt->opt_nflen; + icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie); + } + } + opt = xchg((__force struct ipv6_txoptions **)&inet6_sk(sk)->opt, + opt); + sk_dst_reset(sk); + + return opt; +} + +static bool setsockopt_needs_rtnl(int optname) +{ + switch (optname) { + case IPV6_ADDRFORM: + case IPV6_ADD_MEMBERSHIP: + case IPV6_DROP_MEMBERSHIP: + case IPV6_JOIN_ANYCAST: + case IPV6_LEAVE_ANYCAST: + case MCAST_JOIN_GROUP: + case MCAST_LEAVE_GROUP: + case MCAST_JOIN_SOURCE_GROUP: + case MCAST_LEAVE_SOURCE_GROUP: + case MCAST_BLOCK_SOURCE: + case MCAST_UNBLOCK_SOURCE: + case MCAST_MSFILTER: + return true; + } + return false; +} + +static int copy_group_source_from_sockptr(struct group_source_req *greqs, + sockptr_t optval, int optlen) +{ + if (in_compat_syscall()) { + struct compat_group_source_req gr32; + + if (optlen < sizeof(gr32)) + return -EINVAL; + if (copy_from_sockptr(&gr32, optval, sizeof(gr32))) + return -EFAULT; + greqs->gsr_interface = gr32.gsr_interface; + greqs->gsr_group = gr32.gsr_group; + greqs->gsr_source = gr32.gsr_source; + } else { + if (optlen < sizeof(*greqs)) + return -EINVAL; + if (copy_from_sockptr(greqs, optval, sizeof(*greqs))) + return -EFAULT; + } + + return 0; +} + +static int do_ipv6_mcast_group_source(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct group_source_req greqs; + int omode, add; + int ret; + + ret = copy_group_source_from_sockptr(&greqs, optval, optlen); + if (ret) + return ret; + + if (greqs.gsr_group.ss_family != AF_INET6 || + greqs.gsr_source.ss_family != AF_INET6) + return -EADDRNOTAVAIL; + + if (optname == MCAST_BLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 1; + } else if (optname == MCAST_UNBLOCK_SOURCE) { + omode = MCAST_EXCLUDE; + add = 0; + } else if (optname == MCAST_JOIN_SOURCE_GROUP) { + struct sockaddr_in6 *psin6; + int retv; + + psin6 = (struct sockaddr_in6 *)&greqs.gsr_group; + retv = ipv6_sock_mc_join_ssm(sk, greqs.gsr_interface, + &psin6->sin6_addr, + MCAST_INCLUDE); + /* prior join w/ different source is ok */ + if (retv && retv != -EADDRINUSE) + return retv; + omode = MCAST_INCLUDE; + add = 1; + } else /* MCAST_LEAVE_SOURCE_GROUP */ { + omode = MCAST_INCLUDE; + add = 0; + } + return ip6_mc_source(add, omode, sk, &greqs); +} + +static int ipv6_set_mcast_msfilter(struct sock *sk, sockptr_t optval, + int optlen) +{ + struct group_filter *gsf; + int ret; + + if (optlen < GROUP_FILTER_SIZE(0)) + return -EINVAL; + if (optlen > READ_ONCE(sysctl_optmem_max)) + return -ENOBUFS; + + gsf = memdup_sockptr(optval, optlen); + if (IS_ERR(gsf)) + return PTR_ERR(gsf); + + /* 
numsrc >= (4G-140)/128 overflow in 32 bits */ + ret = -ENOBUFS; + if (gsf->gf_numsrc >= 0x1ffffffU || + gsf->gf_numsrc > sysctl_mld_max_msf) + goto out_free_gsf; + + ret = -EINVAL; + if (GROUP_FILTER_SIZE(gsf->gf_numsrc) > optlen) + goto out_free_gsf; + + ret = ip6_mc_msfilter(sk, gsf, gsf->gf_slist_flex); +out_free_gsf: + kfree(gsf); + return ret; +} + +static int compat_ipv6_set_mcast_msfilter(struct sock *sk, sockptr_t optval, + int optlen) +{ + const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); + struct compat_group_filter *gf32; + void *p; + int ret; + int n; + + if (optlen < size0) + return -EINVAL; + if (optlen > READ_ONCE(sysctl_optmem_max) - 4) + return -ENOBUFS; + + p = kmalloc(optlen + 4, GFP_KERNEL); + if (!p) + return -ENOMEM; + + gf32 = p + 4; /* we want ->gf_group and ->gf_slist_flex aligned */ + ret = -EFAULT; + if (copy_from_sockptr(gf32, optval, optlen)) + goto out_free_p; + + /* numsrc >= (4G-140)/128 overflow in 32 bits */ + ret = -ENOBUFS; + n = gf32->gf_numsrc; + if (n >= 0x1ffffffU || n > sysctl_mld_max_msf) + goto out_free_p; + + ret = -EINVAL; + if (offsetof(struct compat_group_filter, gf_slist_flex[n]) > optlen) + goto out_free_p; + + ret = ip6_mc_msfilter(sk, &(struct group_filter){ + .gf_interface = gf32->gf_interface, + .gf_group = gf32->gf_group, + .gf_fmode = gf32->gf_fmode, + .gf_numsrc = gf32->gf_numsrc}, gf32->gf_slist_flex); + +out_free_p: + kfree(p); + return ret; +} + +static int ipv6_mcast_join_leave(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct sockaddr_in6 *psin6; + struct group_req greq; + + if (optlen < sizeof(greq)) + return -EINVAL; + if (copy_from_sockptr(&greq, optval, sizeof(greq))) + return -EFAULT; + + if (greq.gr_group.ss_family != AF_INET6) + return -EADDRNOTAVAIL; + psin6 = (struct sockaddr_in6 *)&greq.gr_group; + if (optname == MCAST_JOIN_GROUP) + return ipv6_sock_mc_join(sk, greq.gr_interface, + &psin6->sin6_addr); + return ipv6_sock_mc_drop(sk, greq.gr_interface, &psin6->sin6_addr); +} + +static int compat_ipv6_mcast_join_leave(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct compat_group_req gr32; + struct sockaddr_in6 *psin6; + + if (optlen < sizeof(gr32)) + return -EINVAL; + if (copy_from_sockptr(&gr32, optval, sizeof(gr32))) + return -EFAULT; + + if (gr32.gr_group.ss_family != AF_INET6) + return -EADDRNOTAVAIL; + psin6 = (struct sockaddr_in6 *)&gr32.gr_group; + if (optname == MCAST_JOIN_GROUP) + return ipv6_sock_mc_join(sk, gr32.gr_interface, + &psin6->sin6_addr); + return ipv6_sock_mc_drop(sk, gr32.gr_interface, &psin6->sin6_addr); +} + +static int ipv6_set_opt_hdr(struct sock *sk, int optname, sockptr_t optval, + int optlen) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_opt_hdr *new = NULL; + struct net *net = sock_net(sk); + struct ipv6_txoptions *opt; + int err; + + /* hop-by-hop / destination options are privileged option */ + if (optname != IPV6_RTHDR && !sockopt_ns_capable(net->user_ns, CAP_NET_RAW)) + return -EPERM; + + /* remove any sticky options header with a zero option + * length, per RFC3542. 
+ */ + if (optlen > 0) { + if (sockptr_is_null(optval)) + return -EINVAL; + if (optlen < sizeof(struct ipv6_opt_hdr) || + optlen & 0x7 || + optlen > 8 * 255) + return -EINVAL; + + new = memdup_sockptr(optval, optlen); + if (IS_ERR(new)) + return PTR_ERR(new); + if (unlikely(ipv6_optlen(new) > optlen)) { + kfree(new); + return -EINVAL; + } + } + + opt = rcu_dereference_protected(np->opt, lockdep_sock_is_held(sk)); + opt = ipv6_renew_options(sk, opt, optname, new); + kfree(new); + if (IS_ERR(opt)) + return PTR_ERR(opt); + + /* routing header option needs extra check */ + err = -EINVAL; + if (optname == IPV6_RTHDR && opt && opt->srcrt) { + struct ipv6_rt_hdr *rthdr = opt->srcrt; + switch (rthdr->type) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPV6_SRCRT_TYPE_2: + if (rthdr->hdrlen != 2 || rthdr->segments_left != 1) + goto sticky_done; + break; +#endif + case IPV6_SRCRT_TYPE_4: + { + struct ipv6_sr_hdr *srh = + (struct ipv6_sr_hdr *)opt->srcrt; + + if (!seg6_validate_srh(srh, optlen, false)) + goto sticky_done; + break; + } + default: + goto sticky_done; + } + } + + err = 0; + opt = ipv6_update_options(sk, opt); +sticky_done: + if (opt) { + atomic_sub(opt->tot_len, &sk->sk_omem_alloc); + txopt_put(opt); + } + return err; +} + +int do_ipv6_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + int val, valbool; + int retv = -ENOPROTOOPT; + bool needs_rtnl = setsockopt_needs_rtnl(optname); + + if (sockptr_is_null(optval)) + val = 0; + else { + if (optlen >= sizeof(int)) { + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + } else + val = 0; + } + + valbool = (val != 0); + + if (ip6_mroute_opt(optname)) + return ip6_mroute_setsockopt(sk, optname, optval, optlen); + + if (needs_rtnl) + rtnl_lock(); + sockopt_lock_sock(sk); + + /* Another thread has converted the socket into IPv4 with + * IPV6_ADDRFORM concurrently. + */ + if (unlikely(sk->sk_family != AF_INET6)) + goto unlock; + + switch (optname) { + + case IPV6_ADDRFORM: + if (optlen < sizeof(int)) + goto e_inval; + if (val == PF_INET) { + if (sk->sk_type == SOCK_RAW) + break; + + if (sk->sk_protocol == IPPROTO_UDP || + sk->sk_protocol == IPPROTO_UDPLITE) { + struct udp_sock *up = udp_sk(sk); + if (up->pending == AF_INET6) { + retv = -EBUSY; + break; + } + } else if (sk->sk_protocol == IPPROTO_TCP) { + if (sk->sk_prot != &tcpv6_prot) { + retv = -EBUSY; + break; + } + } else { + break; + } + + if (sk->sk_state != TCP_ESTABLISHED) { + retv = -ENOTCONN; + break; + } + + if (ipv6_only_sock(sk) || + !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) { + retv = -EADDRNOTAVAIL; + break; + } + + __ipv6_sock_mc_close(sk); + __ipv6_sock_ac_close(sk); + + /* + * Sock is moving from IPv6 to IPv4 (sk_prot), so + * remove it from the refcnt debug socks count in the + * original family... 
+ */ + sk_refcnt_debug_dec(sk); + + if (sk->sk_protocol == IPPROTO_TCP) { + struct inet_connection_sock *icsk = inet_csk(sk); + + sock_prot_inuse_add(net, sk->sk_prot, -1); + sock_prot_inuse_add(net, &tcp_prot, 1); + + /* Paired with READ_ONCE(sk->sk_prot) in inet6_stream_ops */ + WRITE_ONCE(sk->sk_prot, &tcp_prot); + /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */ + WRITE_ONCE(icsk->icsk_af_ops, &ipv4_specific); + sk->sk_socket->ops = &inet_stream_ops; + sk->sk_family = PF_INET; + tcp_sync_mss(sk, icsk->icsk_pmtu_cookie); + } else { + struct proto *prot = &udp_prot; + + if (sk->sk_protocol == IPPROTO_UDPLITE) + prot = &udplite_prot; + + sock_prot_inuse_add(net, sk->sk_prot, -1); + sock_prot_inuse_add(net, prot, 1); + + /* Paired with READ_ONCE(sk->sk_prot) in inet6_dgram_ops */ + WRITE_ONCE(sk->sk_prot, prot); + sk->sk_socket->ops = &inet_dgram_ops; + sk->sk_family = PF_INET; + } + + /* Disable all options not to allocate memory anymore, + * but there is still a race. See the lockless path + * in udpv6_sendmsg() and ipv6_local_rxpmtu(). + */ + np->rxopt.all = 0; + + inet6_cleanup_sock(sk); + + /* + * ... and add it to the refcnt debug socks count + * in the new family. -acme + */ + sk_refcnt_debug_inc(sk); + module_put(THIS_MODULE); + retv = 0; + break; + } + goto e_inval; + + case IPV6_V6ONLY: + if (optlen < sizeof(int) || + inet_sk(sk)->inet_num) + goto e_inval; + sk->sk_ipv6only = valbool; + retv = 0; + break; + + case IPV6_RECVPKTINFO: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxinfo = valbool; + retv = 0; + break; + + case IPV6_2292PKTINFO: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxoinfo = valbool; + retv = 0; + break; + + case IPV6_RECVHOPLIMIT: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxhlim = valbool; + retv = 0; + break; + + case IPV6_2292HOPLIMIT: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxohlim = valbool; + retv = 0; + break; + + case IPV6_RECVRTHDR: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.srcrt = valbool; + retv = 0; + break; + + case IPV6_2292RTHDR: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.osrcrt = valbool; + retv = 0; + break; + + case IPV6_RECVHOPOPTS: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.hopopts = valbool; + retv = 0; + break; + + case IPV6_2292HOPOPTS: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.ohopopts = valbool; + retv = 0; + break; + + case IPV6_RECVDSTOPTS: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.dstopts = valbool; + retv = 0; + break; + + case IPV6_2292DSTOPTS: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.odstopts = valbool; + retv = 0; + break; + + case IPV6_TCLASS: + if (optlen < sizeof(int)) + goto e_inval; + if (val < -1 || val > 0xff) + goto e_inval; + /* RFC 3542, 6.5: default traffic class of 0x0 */ + if (val == -1) + val = 0; + if (sk->sk_type == SOCK_STREAM) { + val &= ~INET_ECN_MASK; + val |= np->tclass & INET_ECN_MASK; + } + if (np->tclass != val) { + np->tclass = val; + sk_dst_reset(sk); + } + retv = 0; + break; + + case IPV6_RECVTCLASS: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxtclass = valbool; + retv = 0; + break; + + case IPV6_FLOWINFO: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxflow = valbool; + retv = 0; + break; + + case IPV6_RECVPATHMTU: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxpmtu = valbool; + retv = 0; + break; + + case IPV6_TRANSPARENT: + if (valbool && 
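The IPV6_ADDRFORM handling above demotes a connected, v4-mapped IPv6 socket to a plain AF_INET socket (swapping sk_prot, icsk_af_ops and the socket ops). A hypothetical userspace sketch (illustrative only; the code above insists on an established connection to a ::ffff:a.b.c.d peer on a non-IPV6_V6ONLY socket):

#include <sys/socket.h>
#include <netinet/in.h>

/* Illustrative only: convert an IPv6 TCP socket that is connected to a
 * v4-mapped peer into an IPv4 socket, as implemented above. */
static int downgrade_to_ipv4(int connected_tcp_fd)
{
	int af = AF_INET;

	return setsockopt(connected_tcp_fd, IPPROTO_IPV6, IPV6_ADDRFORM,
			  &af, sizeof(af));
}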
!sockopt_ns_capable(net->user_ns, CAP_NET_RAW) && + !sockopt_ns_capable(net->user_ns, CAP_NET_ADMIN)) { + retv = -EPERM; + break; + } + if (optlen < sizeof(int)) + goto e_inval; + /* we don't have a separate transparent bit for IPV6 we use the one in the IPv4 socket */ + inet_sk(sk)->transparent = valbool; + retv = 0; + break; + + case IPV6_FREEBIND: + if (optlen < sizeof(int)) + goto e_inval; + /* we also don't have a separate freebind bit for IPV6 */ + inet_sk(sk)->freebind = valbool; + retv = 0; + break; + + case IPV6_RECVORIGDSTADDR: + if (optlen < sizeof(int)) + goto e_inval; + np->rxopt.bits.rxorigdstaddr = valbool; + retv = 0; + break; + + case IPV6_HOPOPTS: + case IPV6_RTHDRDSTOPTS: + case IPV6_RTHDR: + case IPV6_DSTOPTS: + retv = ipv6_set_opt_hdr(sk, optname, optval, optlen); + break; + + case IPV6_PKTINFO: + { + struct in6_pktinfo pkt; + + if (optlen == 0) + goto e_inval; + else if (optlen < sizeof(struct in6_pktinfo) || + sockptr_is_null(optval)) + goto e_inval; + + if (copy_from_sockptr(&pkt, optval, sizeof(pkt))) { + retv = -EFAULT; + break; + } + if (!sk_dev_equal_l3scope(sk, pkt.ipi6_ifindex)) + goto e_inval; + + np->sticky_pktinfo.ipi6_ifindex = pkt.ipi6_ifindex; + np->sticky_pktinfo.ipi6_addr = pkt.ipi6_addr; + retv = 0; + break; + } + + case IPV6_2292PKTOPTIONS: + { + struct ipv6_txoptions *opt = NULL; + struct msghdr msg; + struct flowi6 fl6; + struct ipcm6_cookie ipc6; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_oif = sk->sk_bound_dev_if; + fl6.flowi6_mark = sk->sk_mark; + + if (optlen == 0) + goto update; + + /* 1K is probably excessive + * 1K is surely not enough, 2K per standard header is 16K. + */ + retv = -EINVAL; + if (optlen > 64*1024) + break; + + opt = sock_kmalloc(sk, sizeof(*opt) + optlen, GFP_KERNEL); + retv = -ENOBUFS; + if (!opt) + break; + + memset(opt, 0, sizeof(*opt)); + refcount_set(&opt->refcnt, 1); + opt->tot_len = sizeof(*opt) + optlen; + retv = -EFAULT; + if (copy_from_sockptr(opt + 1, optval, optlen)) + goto done; + + msg.msg_controllen = optlen; + msg.msg_control = (void *)(opt+1); + ipc6.opt = opt; + + retv = ip6_datagram_send_ctl(net, sk, &msg, &fl6, &ipc6); + if (retv) + goto done; +update: + retv = 0; + opt = ipv6_update_options(sk, opt); +done: + if (opt) { + atomic_sub(opt->tot_len, &sk->sk_omem_alloc); + txopt_put(opt); + } + break; + } + case IPV6_UNICAST_HOPS: + if (optlen < sizeof(int)) + goto e_inval; + if (val > 255 || val < -1) + goto e_inval; + np->hop_limit = val; + retv = 0; + break; + + case IPV6_MULTICAST_HOPS: + if (sk->sk_type == SOCK_STREAM) + break; + if (optlen < sizeof(int)) + goto e_inval; + if (val > 255 || val < -1) + goto e_inval; + np->mcast_hops = (val == -1 ? 
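The sticky IPV6_PKTINFO case above stores a default source address and interface in np->sticky_pktinfo after validating the ifindex with sk_dev_equal_l3scope(). A hypothetical userspace sketch (illustrative only; glibc exposes struct in6_pktinfo under _GNU_SOURCE):

#define _GNU_SOURCE		/* struct in6_pktinfo in <netinet/in.h> */
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <net/if.h>

/* Illustrative only: pin the source address/interface used for
 * unconnected sends on a UDPv6 socket. */
static int pin_source(int udp6_fd, const struct in6_addr *saddr,
		      const char *ifname)
{
	struct in6_pktinfo pi;

	memset(&pi, 0, sizeof(pi));
	pi.ipi6_addr = *saddr;
	pi.ipi6_ifindex = if_nametoindex(ifname);
	return setsockopt(udp6_fd, IPPROTO_IPV6, IPV6_PKTINFO,
			  &pi, sizeof(pi));
}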
IPV6_DEFAULT_MCASTHOPS : val); + retv = 0; + break; + + case IPV6_MULTICAST_LOOP: + if (optlen < sizeof(int)) + goto e_inval; + if (val != valbool) + goto e_inval; + np->mc_loop = valbool; + retv = 0; + break; + + case IPV6_UNICAST_IF: + { + struct net_device *dev = NULL; + int ifindex; + + if (optlen != sizeof(int)) + goto e_inval; + + ifindex = (__force int)ntohl((__force __be32)val); + if (ifindex == 0) { + np->ucast_oif = 0; + retv = 0; + break; + } + + dev = dev_get_by_index(net, ifindex); + retv = -EADDRNOTAVAIL; + if (!dev) + break; + dev_put(dev); + + retv = -EINVAL; + if (sk->sk_bound_dev_if) + break; + + np->ucast_oif = ifindex; + retv = 0; + break; + } + + case IPV6_MULTICAST_IF: + if (sk->sk_type == SOCK_STREAM) + break; + if (optlen < sizeof(int)) + goto e_inval; + + if (val) { + struct net_device *dev; + int midx; + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, val); + if (!dev) { + rcu_read_unlock(); + retv = -ENODEV; + break; + } + midx = l3mdev_master_ifindex_rcu(dev); + + rcu_read_unlock(); + + if (sk->sk_bound_dev_if && + sk->sk_bound_dev_if != val && + (!midx || midx != sk->sk_bound_dev_if)) + goto e_inval; + } + np->mcast_oif = val; + retv = 0; + break; + case IPV6_ADD_MEMBERSHIP: + case IPV6_DROP_MEMBERSHIP: + { + struct ipv6_mreq mreq; + + if (optlen < sizeof(struct ipv6_mreq)) + goto e_inval; + + retv = -EPROTO; + if (inet_sk(sk)->is_icsk) + break; + + retv = -EFAULT; + if (copy_from_sockptr(&mreq, optval, sizeof(struct ipv6_mreq))) + break; + + if (optname == IPV6_ADD_MEMBERSHIP) + retv = ipv6_sock_mc_join(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr); + else + retv = ipv6_sock_mc_drop(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_multiaddr); + break; + } + case IPV6_JOIN_ANYCAST: + case IPV6_LEAVE_ANYCAST: + { + struct ipv6_mreq mreq; + + if (optlen < sizeof(struct ipv6_mreq)) + goto e_inval; + + retv = -EFAULT; + if (copy_from_sockptr(&mreq, optval, sizeof(struct ipv6_mreq))) + break; + + if (optname == IPV6_JOIN_ANYCAST) + retv = ipv6_sock_ac_join(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_acaddr); + else + retv = ipv6_sock_ac_drop(sk, mreq.ipv6mr_ifindex, &mreq.ipv6mr_acaddr); + break; + } + case IPV6_MULTICAST_ALL: + if (optlen < sizeof(int)) + goto e_inval; + np->mc_all = valbool; + retv = 0; + break; + + case MCAST_JOIN_GROUP: + case MCAST_LEAVE_GROUP: + if (in_compat_syscall()) + retv = compat_ipv6_mcast_join_leave(sk, optname, optval, + optlen); + else + retv = ipv6_mcast_join_leave(sk, optname, optval, + optlen); + break; + case MCAST_JOIN_SOURCE_GROUP: + case MCAST_LEAVE_SOURCE_GROUP: + case MCAST_BLOCK_SOURCE: + case MCAST_UNBLOCK_SOURCE: + retv = do_ipv6_mcast_group_source(sk, optname, optval, optlen); + break; + case MCAST_MSFILTER: + if (in_compat_syscall()) + retv = compat_ipv6_set_mcast_msfilter(sk, optval, + optlen); + else + retv = ipv6_set_mcast_msfilter(sk, optval, optlen); + break; + case IPV6_ROUTER_ALERT: + if (optlen < sizeof(int)) + goto e_inval; + retv = ip6_ra_control(sk, val); + break; + case IPV6_ROUTER_ALERT_ISOLATE: + if (optlen < sizeof(int)) + goto e_inval; + np->rtalert_isolate = valbool; + retv = 0; + break; + case IPV6_MTU_DISCOVER: + if (optlen < sizeof(int)) + goto e_inval; + if (val < IPV6_PMTUDISC_DONT || val > IPV6_PMTUDISC_OMIT) + goto e_inval; + np->pmtudisc = val; + retv = 0; + break; + case IPV6_MTU: + if (optlen < sizeof(int)) + goto e_inval; + if (val && val < IPV6_MIN_MTU) + goto e_inval; + np->frag_size = val; + retv = 0; + break; + case IPV6_RECVERR: + if (optlen < sizeof(int)) + goto e_inval; + np->recverr = 
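IPV6_ADD_MEMBERSHIP above copies a struct ipv6_mreq and hands it to ipv6_sock_mc_join(). The classic any-source join from userspace, as a minimal sketch (illustrative only; note that glibc names the interface field ipv6mr_interface, while the kernel UAPI struct calls the same slot ipv6mr_ifindex):

#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <net/if.h>

/* Illustrative only: join an any-source multicast group on one interface. */
static int join_group(int udp6_fd, const char *group, const char *ifname)
{
	struct ipv6_mreq mreq;

	memset(&mreq, 0, sizeof(mreq));
	inet_pton(AF_INET6, group, &mreq.ipv6mr_multiaddr);
	mreq.ipv6mr_interface = if_nametoindex(ifname);
	return setsockopt(udp6_fd, IPPROTO_IPV6, IPV6_ADD_MEMBERSHIP,
			  &mreq, sizeof(mreq));
}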
valbool; + if (!val) + skb_queue_purge(&sk->sk_error_queue); + retv = 0; + break; + case IPV6_FLOWINFO_SEND: + if (optlen < sizeof(int)) + goto e_inval; + np->sndflow = valbool; + retv = 0; + break; + case IPV6_FLOWLABEL_MGR: + retv = ipv6_flowlabel_opt(sk, optval, optlen); + break; + case IPV6_IPSEC_POLICY: + case IPV6_XFRM_POLICY: + retv = -EPERM; + if (!sockopt_ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + retv = xfrm_user_policy(sk, optname, optval, optlen); + break; + + case IPV6_ADDR_PREFERENCES: + if (optlen < sizeof(int)) + goto e_inval; + retv = __ip6_sock_set_addr_preferences(sk, val); + break; + case IPV6_MINHOPCOUNT: + if (optlen < sizeof(int)) + goto e_inval; + if (val < 0 || val > 255) + goto e_inval; + + if (val) + static_branch_enable(&ip6_min_hopcount); + + /* tcp_v6_err() and tcp_v6_rcv() might read min_hopcount + * while we are changing it. + */ + WRITE_ONCE(np->min_hopcount, val); + retv = 0; + break; + case IPV6_DONTFRAG: + np->dontfrag = valbool; + retv = 0; + break; + case IPV6_AUTOFLOWLABEL: + np->autoflowlabel = valbool; + np->autoflowlabel_set = 1; + retv = 0; + break; + case IPV6_RECVFRAGSIZE: + np->rxopt.bits.recvfragsize = valbool; + retv = 0; + break; + case IPV6_RECVERR_RFC4884: + if (optlen < sizeof(int)) + goto e_inval; + if (val < 0 || val > 1) + goto e_inval; + np->recverr_rfc4884 = valbool; + retv = 0; + break; + } + +unlock: + sockopt_release_sock(sk); + if (needs_rtnl) + rtnl_unlock(); + + return retv; + +e_inval: + sockopt_release_sock(sk); + if (needs_rtnl) + rtnl_unlock(); + return -EINVAL; +} + +int ipv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + int err; + + if (level == SOL_IP && sk->sk_type != SOCK_RAW) + return udp_prot.setsockopt(sk, level, optname, optval, optlen); + + if (level != SOL_IPV6) + return -ENOPROTOOPT; + + err = do_ipv6_setsockopt(sk, level, optname, optval, optlen); +#ifdef CONFIG_NETFILTER + /* we need to exclude all possible ENOPROTOOPTs except default case */ + if (err == -ENOPROTOOPT && optname != IPV6_IPSEC_POLICY && + optname != IPV6_XFRM_POLICY) + err = nf_setsockopt(sk, PF_INET6, optname, optval, optlen); +#endif + return err; +} +EXPORT_SYMBOL(ipv6_setsockopt); + +static int ipv6_getsockopt_sticky(struct sock *sk, struct ipv6_txoptions *opt, + int optname, sockptr_t optval, int len) +{ + struct ipv6_opt_hdr *hdr; + + if (!opt) + return 0; + + switch (optname) { + case IPV6_HOPOPTS: + hdr = opt->hopopt; + break; + case IPV6_RTHDRDSTOPTS: + hdr = opt->dst0opt; + break; + case IPV6_RTHDR: + hdr = (struct ipv6_opt_hdr *)opt->srcrt; + break; + case IPV6_DSTOPTS: + hdr = opt->dst1opt; + break; + default: + return -EINVAL; /* should not happen */ + } + + if (!hdr) + return 0; + + len = min_t(unsigned int, len, ipv6_optlen(hdr)); + if (copy_to_sockptr(optval, hdr, len)) + return -EFAULT; + return len; +} + +static int ipv6_get_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) +{ + const int size0 = offsetof(struct group_filter, gf_slist_flex); + struct group_filter gsf; + int num; + int err; + + if (len < size0) + return -EINVAL; + if (copy_from_sockptr(&gsf, optval, size0)) + return -EFAULT; + if (gsf.gf_group.ss_family != AF_INET6) + return -EADDRNOTAVAIL; + num = gsf.gf_numsrc; + sockopt_lock_sock(sk); + err = ip6_mc_msfget(sk, &gsf, optval, size0); + if (!err) { + if (num > gsf.gf_numsrc) + num = gsf.gf_numsrc; + len = GROUP_FILTER_SIZE(num); + if (copy_to_sockptr(optlen, &len, sizeof(int)) || + copy_to_sockptr(optval, &gsf, size0)) + err = 
-EFAULT; + } + sockopt_release_sock(sk); + return err; +} + +static int compat_ipv6_get_msfilter(struct sock *sk, sockptr_t optval, + sockptr_t optlen, int len) +{ + const int size0 = offsetof(struct compat_group_filter, gf_slist_flex); + struct compat_group_filter gf32; + struct group_filter gf; + int err; + int num; + + if (len < size0) + return -EINVAL; + + if (copy_from_sockptr(&gf32, optval, size0)) + return -EFAULT; + gf.gf_interface = gf32.gf_interface; + gf.gf_fmode = gf32.gf_fmode; + num = gf.gf_numsrc = gf32.gf_numsrc; + gf.gf_group = gf32.gf_group; + + if (gf.gf_group.ss_family != AF_INET6) + return -EADDRNOTAVAIL; + + sockopt_lock_sock(sk); + err = ip6_mc_msfget(sk, &gf, optval, size0); + sockopt_release_sock(sk); + if (err) + return err; + if (num > gf.gf_numsrc) + num = gf.gf_numsrc; + len = GROUP_FILTER_SIZE(num) - (sizeof(gf)-sizeof(gf32)); + if (copy_to_sockptr(optlen, &len, sizeof(int)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_fmode), + &gf.gf_fmode, sizeof(gf32.gf_fmode)) || + copy_to_sockptr_offset(optval, offsetof(struct compat_group_filter, gf_numsrc), + &gf.gf_numsrc, sizeof(gf32.gf_numsrc))) + return -EFAULT; + return 0; +} + +int do_ipv6_getsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, sockptr_t optlen) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + int len; + int val; + + if (ip6_mroute_opt(optname)) + return ip6_mroute_getsockopt(sk, optname, optval, optlen); + + if (copy_from_sockptr(&len, optlen, sizeof(int))) + return -EFAULT; + switch (optname) { + case IPV6_ADDRFORM: + if (sk->sk_protocol != IPPROTO_UDP && + sk->sk_protocol != IPPROTO_UDPLITE && + sk->sk_protocol != IPPROTO_TCP) + return -ENOPROTOOPT; + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + val = sk->sk_family; + break; + case MCAST_MSFILTER: + if (in_compat_syscall()) + return compat_ipv6_get_msfilter(sk, optval, optlen, len); + return ipv6_get_msfilter(sk, optval, optlen, len); + case IPV6_2292PKTOPTIONS: + { + struct msghdr msg; + struct sk_buff *skb; + + if (sk->sk_type != SOCK_STREAM) + return -ENOPROTOOPT; + + if (optval.is_kernel) { + msg.msg_control_is_user = false; + msg.msg_control = optval.kernel; + } else { + msg.msg_control_is_user = true; + msg.msg_control_user = optval.user; + } + msg.msg_controllen = len; + msg.msg_flags = 0; + + sockopt_lock_sock(sk); + skb = np->pktoptions; + if (skb) + ip6_datagram_recv_ctl(sk, &msg, skb); + sockopt_release_sock(sk); + if (!skb) { + if (np->rxopt.bits.rxinfo) { + struct in6_pktinfo src_info; + src_info.ipi6_ifindex = np->mcast_oif ? np->mcast_oif : + np->sticky_pktinfo.ipi6_ifindex; + src_info.ipi6_addr = np->mcast_oif ? sk->sk_v6_daddr : np->sticky_pktinfo.ipi6_addr; + put_cmsg(&msg, SOL_IPV6, IPV6_PKTINFO, sizeof(src_info), &src_info); + } + if (np->rxopt.bits.rxhlim) { + int hlim = np->mcast_hops; + put_cmsg(&msg, SOL_IPV6, IPV6_HOPLIMIT, sizeof(hlim), &hlim); + } + if (np->rxopt.bits.rxtclass) { + int tclass = (int)ip6_tclass(np->rcv_flowinfo); + + put_cmsg(&msg, SOL_IPV6, IPV6_TCLASS, sizeof(tclass), &tclass); + } + if (np->rxopt.bits.rxoinfo) { + struct in6_pktinfo src_info; + src_info.ipi6_ifindex = np->mcast_oif ? np->mcast_oif : + np->sticky_pktinfo.ipi6_ifindex; + src_info.ipi6_addr = np->mcast_oif ? 
sk->sk_v6_daddr : + np->sticky_pktinfo.ipi6_addr; + put_cmsg(&msg, SOL_IPV6, IPV6_2292PKTINFO, sizeof(src_info), &src_info); + } + if (np->rxopt.bits.rxohlim) { + int hlim = np->mcast_hops; + put_cmsg(&msg, SOL_IPV6, IPV6_2292HOPLIMIT, sizeof(hlim), &hlim); + } + if (np->rxopt.bits.rxflow) { + __be32 flowinfo = np->rcv_flowinfo; + + put_cmsg(&msg, SOL_IPV6, IPV6_FLOWINFO, sizeof(flowinfo), &flowinfo); + } + } + len -= msg.msg_controllen; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } + case IPV6_MTU: + { + struct dst_entry *dst; + + val = 0; + rcu_read_lock(); + dst = __sk_dst_get(sk); + if (dst) + val = dst_mtu(dst); + rcu_read_unlock(); + if (!val) + return -ENOTCONN; + break; + } + + case IPV6_V6ONLY: + val = sk->sk_ipv6only; + break; + + case IPV6_RECVPKTINFO: + val = np->rxopt.bits.rxinfo; + break; + + case IPV6_2292PKTINFO: + val = np->rxopt.bits.rxoinfo; + break; + + case IPV6_RECVHOPLIMIT: + val = np->rxopt.bits.rxhlim; + break; + + case IPV6_2292HOPLIMIT: + val = np->rxopt.bits.rxohlim; + break; + + case IPV6_RECVRTHDR: + val = np->rxopt.bits.srcrt; + break; + + case IPV6_2292RTHDR: + val = np->rxopt.bits.osrcrt; + break; + + case IPV6_HOPOPTS: + case IPV6_RTHDRDSTOPTS: + case IPV6_RTHDR: + case IPV6_DSTOPTS: + { + struct ipv6_txoptions *opt; + + sockopt_lock_sock(sk); + opt = rcu_dereference_protected(np->opt, + lockdep_sock_is_held(sk)); + len = ipv6_getsockopt_sticky(sk, opt, optname, optval, len); + sockopt_release_sock(sk); + /* check if ipv6_getsockopt_sticky() returns err code */ + if (len < 0) + return len; + return copy_to_sockptr(optlen, &len, sizeof(int)); + } + + case IPV6_RECVHOPOPTS: + val = np->rxopt.bits.hopopts; + break; + + case IPV6_2292HOPOPTS: + val = np->rxopt.bits.ohopopts; + break; + + case IPV6_RECVDSTOPTS: + val = np->rxopt.bits.dstopts; + break; + + case IPV6_2292DSTOPTS: + val = np->rxopt.bits.odstopts; + break; + + case IPV6_TCLASS: + val = np->tclass; + break; + + case IPV6_RECVTCLASS: + val = np->rxopt.bits.rxtclass; + break; + + case IPV6_FLOWINFO: + val = np->rxopt.bits.rxflow; + break; + + case IPV6_RECVPATHMTU: + val = np->rxopt.bits.rxpmtu; + break; + + case IPV6_PATHMTU: + { + struct dst_entry *dst; + struct ip6_mtuinfo mtuinfo; + + if (len < sizeof(mtuinfo)) + return -EINVAL; + + len = sizeof(mtuinfo); + memset(&mtuinfo, 0, sizeof(mtuinfo)); + + rcu_read_lock(); + dst = __sk_dst_get(sk); + if (dst) + mtuinfo.ip6m_mtu = dst_mtu(dst); + rcu_read_unlock(); + if (!mtuinfo.ip6m_mtu) + return -ENOTCONN; + + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &mtuinfo, len)) + return -EFAULT; + + return 0; + } + + case IPV6_TRANSPARENT: + val = inet_sk(sk)->transparent; + break; + + case IPV6_FREEBIND: + val = inet_sk(sk)->freebind; + break; + + case IPV6_RECVORIGDSTADDR: + val = np->rxopt.bits.rxorigdstaddr; + break; + + case IPV6_UNICAST_HOPS: + case IPV6_MULTICAST_HOPS: + { + struct dst_entry *dst; + + if (optname == IPV6_UNICAST_HOPS) + val = np->hop_limit; + else + val = np->mcast_hops; + + if (val < 0) { + rcu_read_lock(); + dst = __sk_dst_get(sk); + if (dst) + val = ip6_dst_hoplimit(dst); + rcu_read_unlock(); + } + + if (val < 0) + val = sock_net(sk)->ipv6.devconf_all->hop_limit; + break; + } + + case IPV6_MULTICAST_LOOP: + val = np->mc_loop; + break; + + case IPV6_MULTICAST_IF: + val = np->mcast_oif; + break; + + case IPV6_MULTICAST_ALL: + val = np->mc_all; + break; + + case IPV6_UNICAST_IF: + val = (__force int)htonl((__u32) np->ucast_oif); + break; + + case IPV6_MTU_DISCOVER: + val = 
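The IPV6_PATHMTU case above reports the MTU cached on the socket's dst entry and returns -ENOTCONN when there is none. A minimal userspace sketch (illustrative only; struct ip6_mtuinfo is the RFC 3542 structure from <netinet/in.h>):

#include <stdio.h>
#include <sys/socket.h>
#include <netinet/in.h>

/* Illustrative only: read the cached path MTU of a connected socket. */
static void print_path_mtu(int connected_fd)
{
	struct ip6_mtuinfo mi;
	socklen_t len = sizeof(mi);

	if (getsockopt(connected_fd, IPPROTO_IPV6, IPV6_PATHMTU,
		       &mi, &len) == 0)
		printf("path mtu: %u\n", mi.ip6m_mtu);
}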
np->pmtudisc; + break; + + case IPV6_RECVERR: + val = np->recverr; + break; + + case IPV6_FLOWINFO_SEND: + val = np->sndflow; + break; + + case IPV6_FLOWLABEL_MGR: + { + struct in6_flowlabel_req freq; + int flags; + + if (len < sizeof(freq)) + return -EINVAL; + + if (copy_from_sockptr(&freq, optval, sizeof(freq))) + return -EFAULT; + + if (freq.flr_action != IPV6_FL_A_GET) + return -EINVAL; + + len = sizeof(freq); + flags = freq.flr_flags; + + memset(&freq, 0, sizeof(freq)); + + val = ipv6_flowlabel_opt_get(sk, &freq, flags); + if (val < 0) + return val; + + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &freq, len)) + return -EFAULT; + + return 0; + } + + case IPV6_ADDR_PREFERENCES: + val = 0; + + if (np->srcprefs & IPV6_PREFER_SRC_TMP) + val |= IPV6_PREFER_SRC_TMP; + else if (np->srcprefs & IPV6_PREFER_SRC_PUBLIC) + val |= IPV6_PREFER_SRC_PUBLIC; + else { + /* XXX: should we return system default? */ + val |= IPV6_PREFER_SRC_PUBTMP_DEFAULT; + } + + if (np->srcprefs & IPV6_PREFER_SRC_COA) + val |= IPV6_PREFER_SRC_COA; + else + val |= IPV6_PREFER_SRC_HOME; + break; + + case IPV6_MINHOPCOUNT: + val = np->min_hopcount; + break; + + case IPV6_DONTFRAG: + val = np->dontfrag; + break; + + case IPV6_AUTOFLOWLABEL: + val = ip6_autoflowlabel(sock_net(sk), np); + break; + + case IPV6_RECVFRAGSIZE: + val = np->rxopt.bits.recvfragsize; + break; + + case IPV6_ROUTER_ALERT_ISOLATE: + val = np->rtalert_isolate; + break; + + case IPV6_RECVERR_RFC4884: + val = np->recverr_rfc4884; + break; + + default: + return -ENOPROTOOPT; + } + len = min_t(unsigned int, sizeof(int), len); + if (copy_to_sockptr(optlen, &len, sizeof(int))) + return -EFAULT; + if (copy_to_sockptr(optval, &val, len)) + return -EFAULT; + return 0; +} + +int ipv6_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + int err; + + if (level == SOL_IP && sk->sk_type != SOCK_RAW) + return udp_prot.getsockopt(sk, level, optname, optval, optlen); + + if (level != SOL_IPV6) + return -ENOPROTOOPT; + + err = do_ipv6_getsockopt(sk, level, optname, + USER_SOCKPTR(optval), USER_SOCKPTR(optlen)); +#ifdef CONFIG_NETFILTER + /* we need to exclude all possible ENOPROTOOPTs except default case */ + if (err == -ENOPROTOOPT && optname != IPV6_2292PKTOPTIONS) { + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + err = nf_getsockopt(sk, PF_INET6, optname, optval, &len); + if (err >= 0) + err = put_user(len, optlen); + } +#endif + return err; +} +EXPORT_SYMBOL(ipv6_getsockopt); diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c new file mode 100644 index 000000000..566f3b7b9 --- /dev/null +++ b/net/ipv6/mcast.c @@ -0,0 +1,3212 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Multicast support for IPv6 + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on linux/ipv4/igmp.c and linux/ipv4/ip_sockglue.c + */ + +/* Changes: + * + * yoshfuji : fix format of router-alert option + * YOSHIFUJI Hideaki @USAGI: + * Fixed source address for MLD message based on + * . + * YOSHIFUJI Hideaki @USAGI: + * - Ignore Queries for invalid addresses. + * - MLD for link-local addresses. 
+ * David L Stevens : + * - MLDv2 support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +/* Ensure that we have struct in6_addr aligned on 32bit word. */ +static int __mld2_query_bugs[] __attribute__((__unused__)) = { + BUILD_BUG_ON_ZERO(offsetof(struct mld2_query, mld2q_srcs) % 4), + BUILD_BUG_ON_ZERO(offsetof(struct mld2_report, mld2r_grec) % 4), + BUILD_BUG_ON_ZERO(offsetof(struct mld2_grec, grec_mca) % 4) +}; + +static struct workqueue_struct *mld_wq; +static struct in6_addr mld2_all_mcr = MLD2_ALL_MCR_INIT; + +static void igmp6_join_group(struct ifmcaddr6 *ma); +static void igmp6_leave_group(struct ifmcaddr6 *ma); +static void mld_mca_work(struct work_struct *work); + +static void mld_ifc_event(struct inet6_dev *idev); +static bool mld_in_v1_mode(const struct inet6_dev *idev); +static int sf_setstate(struct ifmcaddr6 *pmc); +static void sf_markstate(struct ifmcaddr6 *pmc); +static void ip6_mc_clear_src(struct ifmcaddr6 *pmc); +static int ip6_mc_del_src(struct inet6_dev *idev, const struct in6_addr *pmca, + int sfmode, int sfcount, const struct in6_addr *psfsrc, + int delta); +static int ip6_mc_add_src(struct inet6_dev *idev, const struct in6_addr *pmca, + int sfmode, int sfcount, const struct in6_addr *psfsrc, + int delta); +static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml, + struct inet6_dev *idev); +static int __ipv6_dev_mc_inc(struct net_device *dev, + const struct in6_addr *addr, unsigned int mode); + +#define MLD_QRV_DEFAULT 2 +/* RFC3810, 9.2. Query Interval */ +#define MLD_QI_DEFAULT (125 * HZ) +/* RFC3810, 9.3. 
Query Response Interval */ +#define MLD_QRI_DEFAULT (10 * HZ) + +/* RFC3810, 8.1 Query Version Distinctions */ +#define MLD_V1_QUERY_LEN 24 +#define MLD_V2_QUERY_LEN_MIN 28 + +#define IPV6_MLD_MAX_MSF 64 + +int sysctl_mld_max_msf __read_mostly = IPV6_MLD_MAX_MSF; +int sysctl_mld_qrv __read_mostly = MLD_QRV_DEFAULT; + +/* + * socket join on multicast group + */ +#define mc_dereference(e, idev) \ + rcu_dereference_protected(e, lockdep_is_held(&(idev)->mc_lock)) + +#define sock_dereference(e, sk) \ + rcu_dereference_protected(e, lockdep_sock_is_held(sk)) + +#define for_each_pmc_socklock(np, sk, pmc) \ + for (pmc = sock_dereference((np)->ipv6_mc_list, sk); \ + pmc; \ + pmc = sock_dereference(pmc->next, sk)) + +#define for_each_pmc_rcu(np, pmc) \ + for (pmc = rcu_dereference((np)->ipv6_mc_list); \ + pmc; \ + pmc = rcu_dereference(pmc->next)) + +#define for_each_psf_mclock(mc, psf) \ + for (psf = mc_dereference((mc)->mca_sources, mc->idev); \ + psf; \ + psf = mc_dereference(psf->sf_next, mc->idev)) + +#define for_each_psf_rcu(mc, psf) \ + for (psf = rcu_dereference((mc)->mca_sources); \ + psf; \ + psf = rcu_dereference(psf->sf_next)) + +#define for_each_psf_tomb(mc, psf) \ + for (psf = mc_dereference((mc)->mca_tomb, mc->idev); \ + psf; \ + psf = mc_dereference(psf->sf_next, mc->idev)) + +#define for_each_mc_mclock(idev, mc) \ + for (mc = mc_dereference((idev)->mc_list, idev); \ + mc; \ + mc = mc_dereference(mc->next, idev)) + +#define for_each_mc_rcu(idev, mc) \ + for (mc = rcu_dereference((idev)->mc_list); \ + mc; \ + mc = rcu_dereference(mc->next)) + +#define for_each_mc_tomb(idev, mc) \ + for (mc = mc_dereference((idev)->mc_tomb, idev); \ + mc; \ + mc = mc_dereference(mc->next, idev)) + +static int unsolicited_report_interval(struct inet6_dev *idev) +{ + int iv; + + if (mld_in_v1_mode(idev)) + iv = idev->cnf.mldv1_unsolicited_report_interval; + else + iv = idev->cnf.mldv2_unsolicited_report_interval; + + return iv > 0 ? 
iv : 1; +} + +static int __ipv6_sock_mc_join(struct sock *sk, int ifindex, + const struct in6_addr *addr, unsigned int mode) +{ + struct net_device *dev = NULL; + struct ipv6_mc_socklist *mc_lst; + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + int err; + + ASSERT_RTNL(); + + if (!ipv6_addr_is_multicast(addr)) + return -EINVAL; + + for_each_pmc_socklock(np, sk, mc_lst) { + if ((ifindex == 0 || mc_lst->ifindex == ifindex) && + ipv6_addr_equal(&mc_lst->addr, addr)) + return -EADDRINUSE; + } + + mc_lst = sock_kmalloc(sk, sizeof(struct ipv6_mc_socklist), GFP_KERNEL); + + if (!mc_lst) + return -ENOMEM; + + mc_lst->next = NULL; + mc_lst->addr = *addr; + + if (ifindex == 0) { + struct rt6_info *rt; + rt = rt6_lookup(net, addr, NULL, 0, NULL, 0); + if (rt) { + dev = rt->dst.dev; + ip6_rt_put(rt); + } + } else + dev = __dev_get_by_index(net, ifindex); + + if (!dev) { + sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); + return -ENODEV; + } + + mc_lst->ifindex = dev->ifindex; + mc_lst->sfmode = mode; + RCU_INIT_POINTER(mc_lst->sflist, NULL); + + /* + * now add/increase the group membership on the device + */ + + err = __ipv6_dev_mc_inc(dev, addr, mode); + + if (err) { + sock_kfree_s(sk, mc_lst, sizeof(*mc_lst)); + return err; + } + + mc_lst->next = np->ipv6_mc_list; + rcu_assign_pointer(np->ipv6_mc_list, mc_lst); + + return 0; +} + +int ipv6_sock_mc_join(struct sock *sk, int ifindex, const struct in6_addr *addr) +{ + return __ipv6_sock_mc_join(sk, ifindex, addr, MCAST_EXCLUDE); +} +EXPORT_SYMBOL(ipv6_sock_mc_join); + +int ipv6_sock_mc_join_ssm(struct sock *sk, int ifindex, + const struct in6_addr *addr, unsigned int mode) +{ + return __ipv6_sock_mc_join(sk, ifindex, addr, mode); +} + +/* + * socket leave on multicast group + */ +int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_mc_socklist *mc_lst; + struct ipv6_mc_socklist __rcu **lnk; + struct net *net = sock_net(sk); + + ASSERT_RTNL(); + + if (!ipv6_addr_is_multicast(addr)) + return -EINVAL; + + for (lnk = &np->ipv6_mc_list; + (mc_lst = sock_dereference(*lnk, sk)) != NULL; + lnk = &mc_lst->next) { + if ((ifindex == 0 || mc_lst->ifindex == ifindex) && + ipv6_addr_equal(&mc_lst->addr, addr)) { + struct net_device *dev; + + *lnk = mc_lst->next; + + dev = __dev_get_by_index(net, mc_lst->ifindex); + if (dev) { + struct inet6_dev *idev = __in6_dev_get(dev); + + ip6_mc_leave_src(sk, mc_lst, idev); + if (idev) + __ipv6_dev_mc_dec(idev, &mc_lst->addr); + } else { + ip6_mc_leave_src(sk, mc_lst, NULL); + } + + atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc); + kfree_rcu(mc_lst, rcu); + return 0; + } + } + + return -EADDRNOTAVAIL; +} +EXPORT_SYMBOL(ipv6_sock_mc_drop); + +static struct inet6_dev *ip6_mc_find_dev_rtnl(struct net *net, + const struct in6_addr *group, + int ifindex) +{ + struct net_device *dev = NULL; + struct inet6_dev *idev = NULL; + + if (ifindex == 0) { + struct rt6_info *rt = rt6_lookup(net, group, NULL, 0, NULL, 0); + + if (rt) { + dev = rt->dst.dev; + ip6_rt_put(rt); + } + } else { + dev = __dev_get_by_index(net, ifindex); + } + + if (!dev) + return NULL; + idev = __in6_dev_get(dev); + if (!idev) + return NULL; + if (idev->dead) + return NULL; + return idev; +} + +void __ipv6_sock_mc_close(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_mc_socklist *mc_lst; + struct net *net = sock_net(sk); + + ASSERT_RTNL(); + + while ((mc_lst = sock_dereference(np->ipv6_mc_list, sk)) != NULL) { + struct net_device *dev; + 
+ np->ipv6_mc_list = mc_lst->next; + + dev = __dev_get_by_index(net, mc_lst->ifindex); + if (dev) { + struct inet6_dev *idev = __in6_dev_get(dev); + + ip6_mc_leave_src(sk, mc_lst, idev); + if (idev) + __ipv6_dev_mc_dec(idev, &mc_lst->addr); + } else { + ip6_mc_leave_src(sk, mc_lst, NULL); + } + + atomic_sub(sizeof(*mc_lst), &sk->sk_omem_alloc); + kfree_rcu(mc_lst, rcu); + } +} + +void ipv6_sock_mc_close(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + + if (!rcu_access_pointer(np->ipv6_mc_list)) + return; + + rtnl_lock(); + lock_sock(sk); + __ipv6_sock_mc_close(sk); + release_sock(sk); + rtnl_unlock(); +} + +int ip6_mc_source(int add, int omode, struct sock *sk, + struct group_source_req *pgsr) +{ + struct in6_addr *source, *group; + struct ipv6_mc_socklist *pmc; + struct inet6_dev *idev; + struct ipv6_pinfo *inet6 = inet6_sk(sk); + struct ip6_sf_socklist *psl; + struct net *net = sock_net(sk); + int i, j, rv; + int leavegroup = 0; + int err; + + source = &((struct sockaddr_in6 *)&pgsr->gsr_source)->sin6_addr; + group = &((struct sockaddr_in6 *)&pgsr->gsr_group)->sin6_addr; + + if (!ipv6_addr_is_multicast(group)) + return -EINVAL; + + idev = ip6_mc_find_dev_rtnl(net, group, pgsr->gsr_interface); + if (!idev) + return -ENODEV; + + err = -EADDRNOTAVAIL; + + mutex_lock(&idev->mc_lock); + for_each_pmc_socklock(inet6, sk, pmc) { + if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface) + continue; + if (ipv6_addr_equal(&pmc->addr, group)) + break; + } + if (!pmc) { /* must have a prior join */ + err = -EINVAL; + goto done; + } + /* if a source filter was set, must be the same mode as before */ + if (rcu_access_pointer(pmc->sflist)) { + if (pmc->sfmode != omode) { + err = -EINVAL; + goto done; + } + } else if (pmc->sfmode != omode) { + /* allow mode switches for empty-set filters */ + ip6_mc_add_src(idev, group, omode, 0, NULL, 0); + ip6_mc_del_src(idev, group, pmc->sfmode, 0, NULL, 0); + pmc->sfmode = omode; + } + + psl = sock_dereference(pmc->sflist, sk); + if (!add) { + if (!psl) + goto done; /* err = -EADDRNOTAVAIL */ + rv = !0; + for (i = 0; i < psl->sl_count; i++) { + rv = !ipv6_addr_equal(&psl->sl_addr[i], source); + if (rv == 0) + break; + } + if (rv) /* source not found */ + goto done; /* err = -EADDRNOTAVAIL */ + + /* special case - (INCLUDE, empty) == LEAVE_GROUP */ + if (psl->sl_count == 1 && omode == MCAST_INCLUDE) { + leavegroup = 1; + goto done; + } + + /* update the interface filter */ + ip6_mc_del_src(idev, group, omode, 1, source, 1); + + for (j = i+1; j < psl->sl_count; j++) + psl->sl_addr[j-1] = psl->sl_addr[j]; + psl->sl_count--; + err = 0; + goto done; + } + /* else, add a new source to the filter */ + + if (psl && psl->sl_count >= sysctl_mld_max_msf) { + err = -ENOBUFS; + goto done; + } + if (!psl || psl->sl_count == psl->sl_max) { + struct ip6_sf_socklist *newpsl; + int count = IP6_SFBLOCK; + + if (psl) + count += psl->sl_max; + newpsl = sock_kmalloc(sk, struct_size(newpsl, sl_addr, count), + GFP_KERNEL); + if (!newpsl) { + err = -ENOBUFS; + goto done; + } + newpsl->sl_max = count; + newpsl->sl_count = count - IP6_SFBLOCK; + if (psl) { + for (i = 0; i < psl->sl_count; i++) + newpsl->sl_addr[i] = psl->sl_addr[i]; + atomic_sub(struct_size(psl, sl_addr, psl->sl_max), + &sk->sk_omem_alloc); + } + rcu_assign_pointer(pmc->sflist, newpsl); + kfree_rcu(psl, rcu); + psl = newpsl; + } + rv = 1; /* > 0 for insert logic below if sl_count is 0 */ + for (i = 0; i < psl->sl_count; i++) { + rv = !ipv6_addr_equal(&psl->sl_addr[i], source); + if (rv == 0) /* There is 
an error in the address. */ + goto done; + } + for (j = psl->sl_count-1; j >= i; j--) + psl->sl_addr[j+1] = psl->sl_addr[j]; + psl->sl_addr[i] = *source; + psl->sl_count++; + err = 0; + /* update the interface list */ + ip6_mc_add_src(idev, group, omode, 1, source, 1); +done: + mutex_unlock(&idev->mc_lock); + if (leavegroup) + err = ipv6_sock_mc_drop(sk, pgsr->gsr_interface, group); + return err; +} + +int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf, + struct sockaddr_storage *list) +{ + const struct in6_addr *group; + struct ipv6_mc_socklist *pmc; + struct inet6_dev *idev; + struct ipv6_pinfo *inet6 = inet6_sk(sk); + struct ip6_sf_socklist *newpsl, *psl; + struct net *net = sock_net(sk); + int leavegroup = 0; + int i, err; + + group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr; + + if (!ipv6_addr_is_multicast(group)) + return -EINVAL; + if (gsf->gf_fmode != MCAST_INCLUDE && + gsf->gf_fmode != MCAST_EXCLUDE) + return -EINVAL; + + idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface); + if (!idev) + return -ENODEV; + + err = 0; + + if (gsf->gf_fmode == MCAST_INCLUDE && gsf->gf_numsrc == 0) { + leavegroup = 1; + goto done; + } + + for_each_pmc_socklock(inet6, sk, pmc) { + if (pmc->ifindex != gsf->gf_interface) + continue; + if (ipv6_addr_equal(&pmc->addr, group)) + break; + } + if (!pmc) { /* must have a prior join */ + err = -EINVAL; + goto done; + } + if (gsf->gf_numsrc) { + newpsl = sock_kmalloc(sk, struct_size(newpsl, sl_addr, + gsf->gf_numsrc), + GFP_KERNEL); + if (!newpsl) { + err = -ENOBUFS; + goto done; + } + newpsl->sl_max = newpsl->sl_count = gsf->gf_numsrc; + for (i = 0; i < newpsl->sl_count; ++i, ++list) { + struct sockaddr_in6 *psin6; + + psin6 = (struct sockaddr_in6 *)list; + newpsl->sl_addr[i] = psin6->sin6_addr; + } + mutex_lock(&idev->mc_lock); + err = ip6_mc_add_src(idev, group, gsf->gf_fmode, + newpsl->sl_count, newpsl->sl_addr, 0); + if (err) { + mutex_unlock(&idev->mc_lock); + sock_kfree_s(sk, newpsl, struct_size(newpsl, sl_addr, + newpsl->sl_max)); + goto done; + } + mutex_unlock(&idev->mc_lock); + } else { + newpsl = NULL; + mutex_lock(&idev->mc_lock); + ip6_mc_add_src(idev, group, gsf->gf_fmode, 0, NULL, 0); + mutex_unlock(&idev->mc_lock); + } + + mutex_lock(&idev->mc_lock); + psl = sock_dereference(pmc->sflist, sk); + if (psl) { + ip6_mc_del_src(idev, group, pmc->sfmode, + psl->sl_count, psl->sl_addr, 0); + atomic_sub(struct_size(psl, sl_addr, psl->sl_max), + &sk->sk_omem_alloc); + } else { + ip6_mc_del_src(idev, group, pmc->sfmode, 0, NULL, 0); + } + rcu_assign_pointer(pmc->sflist, newpsl); + mutex_unlock(&idev->mc_lock); + kfree_rcu(psl, rcu); + pmc->sfmode = gsf->gf_fmode; + err = 0; +done: + if (leavegroup) + err = ipv6_sock_mc_drop(sk, gsf->gf_interface, group); + return err; +} + +int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf, + sockptr_t optval, size_t ss_offset) +{ + struct ipv6_pinfo *inet6 = inet6_sk(sk); + const struct in6_addr *group; + struct ipv6_mc_socklist *pmc; + struct ip6_sf_socklist *psl; + int i, count, copycount; + + group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr; + + if (!ipv6_addr_is_multicast(group)) + return -EINVAL; + + /* changes to the ipv6_mc_list require the socket lock and + * rtnl lock. We have the socket lock, so reading the list is safe. 
+ */ + + for_each_pmc_socklock(inet6, sk, pmc) { + if (pmc->ifindex != gsf->gf_interface) + continue; + if (ipv6_addr_equal(group, &pmc->addr)) + break; + } + if (!pmc) /* must have a prior join */ + return -EADDRNOTAVAIL; + + gsf->gf_fmode = pmc->sfmode; + psl = sock_dereference(pmc->sflist, sk); + count = psl ? psl->sl_count : 0; + + copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc; + gsf->gf_numsrc = count; + for (i = 0; i < copycount; i++) { + struct sockaddr_in6 *psin6; + struct sockaddr_storage ss; + + psin6 = (struct sockaddr_in6 *)&ss; + memset(&ss, 0, sizeof(ss)); + psin6->sin6_family = AF_INET6; + psin6->sin6_addr = psl->sl_addr[i]; + if (copy_to_sockptr_offset(optval, ss_offset, &ss, sizeof(ss))) + return -EFAULT; + ss_offset += sizeof(ss); + } + return 0; +} + +bool inet6_mc_check(struct sock *sk, const struct in6_addr *mc_addr, + const struct in6_addr *src_addr) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_mc_socklist *mc; + struct ip6_sf_socklist *psl; + bool rv = true; + + rcu_read_lock(); + for_each_pmc_rcu(np, mc) { + if (ipv6_addr_equal(&mc->addr, mc_addr)) + break; + } + if (!mc) { + rcu_read_unlock(); + return np->mc_all; + } + psl = rcu_dereference(mc->sflist); + if (!psl) { + rv = mc->sfmode == MCAST_EXCLUDE; + } else { + int i; + + for (i = 0; i < psl->sl_count; i++) { + if (ipv6_addr_equal(&psl->sl_addr[i], src_addr)) + break; + } + if (mc->sfmode == MCAST_INCLUDE && i >= psl->sl_count) + rv = false; + if (mc->sfmode == MCAST_EXCLUDE && i < psl->sl_count) + rv = false; + } + rcu_read_unlock(); + + return rv; +} + +/* called with mc_lock */ +static void igmp6_group_added(struct ifmcaddr6 *mc) +{ + struct net_device *dev = mc->idev->dev; + char buf[MAX_ADDR_LEN]; + + if (IPV6_ADDR_MC_SCOPE(&mc->mca_addr) < + IPV6_ADDR_SCOPE_LINKLOCAL) + return; + + if (!(mc->mca_flags&MAF_LOADED)) { + mc->mca_flags |= MAF_LOADED; + if (ndisc_mc_map(&mc->mca_addr, buf, dev, 0) == 0) + dev_mc_add(dev, buf); + } + + if (!(dev->flags & IFF_UP) || (mc->mca_flags & MAF_NOREPORT)) + return; + + if (mld_in_v1_mode(mc->idev)) { + igmp6_join_group(mc); + return; + } + /* else v2 */ + + /* Based on RFC3810 6.1, for newly added INCLUDE SSM, we + * should not send filter-mode change record as the mode + * should be from IN() to IN(A). + */ + if (mc->mca_sfmode == MCAST_EXCLUDE) + mc->mca_crcount = mc->idev->mc_qrv; + + mld_ifc_event(mc->idev); +} + +/* called with mc_lock */ +static void igmp6_group_dropped(struct ifmcaddr6 *mc) +{ + struct net_device *dev = mc->idev->dev; + char buf[MAX_ADDR_LEN]; + + if (IPV6_ADDR_MC_SCOPE(&mc->mca_addr) < + IPV6_ADDR_SCOPE_LINKLOCAL) + return; + + if (mc->mca_flags&MAF_LOADED) { + mc->mca_flags &= ~MAF_LOADED; + if (ndisc_mc_map(&mc->mca_addr, buf, dev, 0) == 0) + dev_mc_del(dev, buf); + } + + if (mc->mca_flags & MAF_NOREPORT) + return; + + if (!mc->idev->dead) + igmp6_leave_group(mc); + + if (cancel_delayed_work(&mc->mca_work)) + refcount_dec(&mc->mca_refcnt); +} + +/* + * deleted ifmcaddr6 manipulation + * called with mc_lock + */ +static void mld_add_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) +{ + struct ifmcaddr6 *pmc; + + /* this is an "ifmcaddr6" for convenience; only the fields below + * are actually used. In particular, the refcnt and users are not + * used for management of the delete list. Using the same structure + * for deleted items allows change reports to use common code with + * non-deleted or query-response MCA's. 
+ */ + pmc = kzalloc(sizeof(*pmc), GFP_KERNEL); + if (!pmc) + return; + + pmc->idev = im->idev; + in6_dev_hold(idev); + pmc->mca_addr = im->mca_addr; + pmc->mca_crcount = idev->mc_qrv; + pmc->mca_sfmode = im->mca_sfmode; + if (pmc->mca_sfmode == MCAST_INCLUDE) { + struct ip6_sf_list *psf; + + rcu_assign_pointer(pmc->mca_tomb, + mc_dereference(im->mca_tomb, idev)); + rcu_assign_pointer(pmc->mca_sources, + mc_dereference(im->mca_sources, idev)); + RCU_INIT_POINTER(im->mca_tomb, NULL); + RCU_INIT_POINTER(im->mca_sources, NULL); + + for_each_psf_mclock(pmc, psf) + psf->sf_crcount = pmc->mca_crcount; + } + + rcu_assign_pointer(pmc->next, idev->mc_tomb); + rcu_assign_pointer(idev->mc_tomb, pmc); +} + +/* called with mc_lock */ +static void mld_del_delrec(struct inet6_dev *idev, struct ifmcaddr6 *im) +{ + struct ip6_sf_list *psf, *sources, *tomb; + struct in6_addr *pmca = &im->mca_addr; + struct ifmcaddr6 *pmc, *pmc_prev; + + pmc_prev = NULL; + for_each_mc_tomb(idev, pmc) { + if (ipv6_addr_equal(&pmc->mca_addr, pmca)) + break; + pmc_prev = pmc; + } + if (pmc) { + if (pmc_prev) + rcu_assign_pointer(pmc_prev->next, pmc->next); + else + rcu_assign_pointer(idev->mc_tomb, pmc->next); + } + + if (pmc) { + im->idev = pmc->idev; + if (im->mca_sfmode == MCAST_INCLUDE) { + tomb = rcu_replace_pointer(im->mca_tomb, + mc_dereference(pmc->mca_tomb, pmc->idev), + lockdep_is_held(&im->idev->mc_lock)); + rcu_assign_pointer(pmc->mca_tomb, tomb); + + sources = rcu_replace_pointer(im->mca_sources, + mc_dereference(pmc->mca_sources, pmc->idev), + lockdep_is_held(&im->idev->mc_lock)); + rcu_assign_pointer(pmc->mca_sources, sources); + for_each_psf_mclock(im, psf) + psf->sf_crcount = idev->mc_qrv; + } else { + im->mca_crcount = idev->mc_qrv; + } + in6_dev_put(pmc->idev); + ip6_mc_clear_src(pmc); + kfree_rcu(pmc, rcu); + } +} + +/* called with mc_lock */ +static void mld_clear_delrec(struct inet6_dev *idev) +{ + struct ifmcaddr6 *pmc, *nextpmc; + + pmc = mc_dereference(idev->mc_tomb, idev); + RCU_INIT_POINTER(idev->mc_tomb, NULL); + + for (; pmc; pmc = nextpmc) { + nextpmc = mc_dereference(pmc->next, idev); + ip6_mc_clear_src(pmc); + in6_dev_put(pmc->idev); + kfree_rcu(pmc, rcu); + } + + /* clear dead sources, too */ + for_each_mc_mclock(idev, pmc) { + struct ip6_sf_list *psf, *psf_next; + + psf = mc_dereference(pmc->mca_tomb, idev); + RCU_INIT_POINTER(pmc->mca_tomb, NULL); + for (; psf; psf = psf_next) { + psf_next = mc_dereference(psf->sf_next, idev); + kfree_rcu(psf, rcu); + } + } +} + +static void mld_clear_query(struct inet6_dev *idev) +{ + struct sk_buff *skb; + + spin_lock_bh(&idev->mc_query_lock); + while ((skb = __skb_dequeue(&idev->mc_query_queue))) + kfree_skb(skb); + spin_unlock_bh(&idev->mc_query_lock); +} + +static void mld_clear_report(struct inet6_dev *idev) +{ + struct sk_buff *skb; + + spin_lock_bh(&idev->mc_report_lock); + while ((skb = __skb_dequeue(&idev->mc_report_queue))) + kfree_skb(skb); + spin_unlock_bh(&idev->mc_report_lock); +} + +static void mca_get(struct ifmcaddr6 *mc) +{ + refcount_inc(&mc->mca_refcnt); +} + +static void ma_put(struct ifmcaddr6 *mc) +{ + if (refcount_dec_and_test(&mc->mca_refcnt)) { + in6_dev_put(mc->idev); + kfree_rcu(mc, rcu); + } +} + +/* called with mc_lock */ +static struct ifmcaddr6 *mca_alloc(struct inet6_dev *idev, + const struct in6_addr *addr, + unsigned int mode) +{ + struct ifmcaddr6 *mc; + + mc = kzalloc(sizeof(*mc), GFP_KERNEL); + if (!mc) + return NULL; + + INIT_DELAYED_WORK(&mc->mca_work, mld_mca_work); + + mc->mca_addr = *addr; + mc->idev = idev; /* 
reference taken by caller */ + mc->mca_users = 1; + /* mca_stamp should be updated upon changes */ + mc->mca_cstamp = mc->mca_tstamp = jiffies; + refcount_set(&mc->mca_refcnt, 1); + + mc->mca_sfmode = mode; + mc->mca_sfcount[mode] = 1; + + if (ipv6_addr_is_ll_all_nodes(&mc->mca_addr) || + IPV6_ADDR_MC_SCOPE(&mc->mca_addr) < IPV6_ADDR_SCOPE_LINKLOCAL) + mc->mca_flags |= MAF_NOREPORT; + + return mc; +} + +/* + * device multicast group inc (add if not found) + */ +static int __ipv6_dev_mc_inc(struct net_device *dev, + const struct in6_addr *addr, unsigned int mode) +{ + struct ifmcaddr6 *mc; + struct inet6_dev *idev; + + ASSERT_RTNL(); + + /* we need to take a reference on idev */ + idev = in6_dev_get(dev); + + if (!idev) + return -EINVAL; + + if (idev->dead) { + in6_dev_put(idev); + return -ENODEV; + } + + mutex_lock(&idev->mc_lock); + for_each_mc_mclock(idev, mc) { + if (ipv6_addr_equal(&mc->mca_addr, addr)) { + mc->mca_users++; + ip6_mc_add_src(idev, &mc->mca_addr, mode, 0, NULL, 0); + mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); + return 0; + } + } + + mc = mca_alloc(idev, addr, mode); + if (!mc) { + mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); + return -ENOMEM; + } + + rcu_assign_pointer(mc->next, idev->mc_list); + rcu_assign_pointer(idev->mc_list, mc); + + mca_get(mc); + + mld_del_delrec(idev, mc); + igmp6_group_added(mc); + mutex_unlock(&idev->mc_lock); + ma_put(mc); + return 0; +} + +int ipv6_dev_mc_inc(struct net_device *dev, const struct in6_addr *addr) +{ + return __ipv6_dev_mc_inc(dev, addr, MCAST_EXCLUDE); +} +EXPORT_SYMBOL(ipv6_dev_mc_inc); + +/* + * device multicast group del + */ +int __ipv6_dev_mc_dec(struct inet6_dev *idev, const struct in6_addr *addr) +{ + struct ifmcaddr6 *ma, __rcu **map; + + ASSERT_RTNL(); + + mutex_lock(&idev->mc_lock); + for (map = &idev->mc_list; + (ma = mc_dereference(*map, idev)); + map = &ma->next) { + if (ipv6_addr_equal(&ma->mca_addr, addr)) { + if (--ma->mca_users == 0) { + *map = ma->next; + + igmp6_group_dropped(ma); + ip6_mc_clear_src(ma); + mutex_unlock(&idev->mc_lock); + + ma_put(ma); + return 0; + } + mutex_unlock(&idev->mc_lock); + return 0; + } + } + + mutex_unlock(&idev->mc_lock); + return -ENOENT; +} + +int ipv6_dev_mc_dec(struct net_device *dev, const struct in6_addr *addr) +{ + struct inet6_dev *idev; + int err; + + ASSERT_RTNL(); + + idev = __in6_dev_get(dev); + if (!idev) + err = -ENODEV; + else + err = __ipv6_dev_mc_dec(idev, addr); + + return err; +} +EXPORT_SYMBOL(ipv6_dev_mc_dec); + +/* + * check if the interface/address pair is valid + */ +bool ipv6_chk_mcast_addr(struct net_device *dev, const struct in6_addr *group, + const struct in6_addr *src_addr) +{ + struct inet6_dev *idev; + struct ifmcaddr6 *mc; + bool rv = false; + + rcu_read_lock(); + idev = __in6_dev_get(dev); + if (idev) { + for_each_mc_rcu(idev, mc) { + if (ipv6_addr_equal(&mc->mca_addr, group)) + break; + } + if (mc) { + if (src_addr && !ipv6_addr_any(src_addr)) { + struct ip6_sf_list *psf; + + for_each_psf_rcu(mc, psf) { + if (ipv6_addr_equal(&psf->sf_addr, src_addr)) + break; + } + if (psf) + rv = psf->sf_count[MCAST_INCLUDE] || + psf->sf_count[MCAST_EXCLUDE] != + mc->mca_sfcount[MCAST_EXCLUDE]; + else + rv = mc->mca_sfcount[MCAST_EXCLUDE] != 0; + } else + rv = true; /* don't filter unspecified source */ + } + } + rcu_read_unlock(); + return rv; +} + +/* called with mc_lock */ +static void mld_gq_start_work(struct inet6_dev *idev) +{ + unsigned long tv = prandom_u32_max(idev->mc_maxdelay); + + idev->mc_gq_running = 1; + if (!mod_delayed_work(mld_wq, 
&idev->mc_gq_work, tv + 2)) + in6_dev_hold(idev); +} + +/* called with mc_lock */ +static void mld_gq_stop_work(struct inet6_dev *idev) +{ + idev->mc_gq_running = 0; + if (cancel_delayed_work(&idev->mc_gq_work)) + __in6_dev_put(idev); +} + +/* called with mc_lock */ +static void mld_ifc_start_work(struct inet6_dev *idev, unsigned long delay) +{ + unsigned long tv = prandom_u32_max(delay); + + if (!mod_delayed_work(mld_wq, &idev->mc_ifc_work, tv + 2)) + in6_dev_hold(idev); +} + +/* called with mc_lock */ +static void mld_ifc_stop_work(struct inet6_dev *idev) +{ + idev->mc_ifc_count = 0; + if (cancel_delayed_work(&idev->mc_ifc_work)) + __in6_dev_put(idev); +} + +/* called with mc_lock */ +static void mld_dad_start_work(struct inet6_dev *idev, unsigned long delay) +{ + unsigned long tv = prandom_u32_max(delay); + + if (!mod_delayed_work(mld_wq, &idev->mc_dad_work, tv + 2)) + in6_dev_hold(idev); +} + +static void mld_dad_stop_work(struct inet6_dev *idev) +{ + if (cancel_delayed_work(&idev->mc_dad_work)) + __in6_dev_put(idev); +} + +static void mld_query_stop_work(struct inet6_dev *idev) +{ + spin_lock_bh(&idev->mc_query_lock); + if (cancel_delayed_work(&idev->mc_query_work)) + __in6_dev_put(idev); + spin_unlock_bh(&idev->mc_query_lock); +} + +static void mld_report_stop_work(struct inet6_dev *idev) +{ + if (cancel_delayed_work_sync(&idev->mc_report_work)) + __in6_dev_put(idev); +} + +/* + * IGMP handling (alias multicast ICMPv6 messages) + * called with mc_lock + */ +static void igmp6_group_queried(struct ifmcaddr6 *ma, unsigned long resptime) +{ + unsigned long delay = resptime; + + /* Do not start work for these addresses */ + if (ipv6_addr_is_ll_all_nodes(&ma->mca_addr) || + IPV6_ADDR_MC_SCOPE(&ma->mca_addr) < IPV6_ADDR_SCOPE_LINKLOCAL) + return; + + if (cancel_delayed_work(&ma->mca_work)) { + refcount_dec(&ma->mca_refcnt); + delay = ma->mca_work.timer.expires - jiffies; + } + + if (delay >= resptime) + delay = prandom_u32_max(resptime); + + if (!mod_delayed_work(mld_wq, &ma->mca_work, delay)) + refcount_inc(&ma->mca_refcnt); + ma->mca_flags |= MAF_TIMER_RUNNING; +} + +/* mark EXCLUDE-mode sources + * called with mc_lock + */ +static bool mld_xmarksources(struct ifmcaddr6 *pmc, int nsrcs, + const struct in6_addr *srcs) +{ + struct ip6_sf_list *psf; + int i, scount; + + scount = 0; + for_each_psf_mclock(pmc, psf) { + if (scount == nsrcs) + break; + for (i = 0; i < nsrcs; i++) { + /* skip inactive filters */ + if (psf->sf_count[MCAST_INCLUDE] || + pmc->mca_sfcount[MCAST_EXCLUDE] != + psf->sf_count[MCAST_EXCLUDE]) + break; + if (ipv6_addr_equal(&srcs[i], &psf->sf_addr)) { + scount++; + break; + } + } + } + pmc->mca_flags &= ~MAF_GSQUERY; + if (scount == nsrcs) /* all sources excluded */ + return false; + return true; +} + +/* called with mc_lock */ +static bool mld_marksources(struct ifmcaddr6 *pmc, int nsrcs, + const struct in6_addr *srcs) +{ + struct ip6_sf_list *psf; + int i, scount; + + if (pmc->mca_sfmode == MCAST_EXCLUDE) + return mld_xmarksources(pmc, nsrcs, srcs); + + /* mark INCLUDE-mode sources */ + + scount = 0; + for_each_psf_mclock(pmc, psf) { + if (scount == nsrcs) + break; + for (i = 0; i < nsrcs; i++) { + if (ipv6_addr_equal(&srcs[i], &psf->sf_addr)) { + psf->sf_gsresp = 1; + scount++; + break; + } + } + } + if (!scount) { + pmc->mca_flags &= ~MAF_GSQUERY; + return false; + } + pmc->mca_flags |= MAF_GSQUERY; + return true; +} + +static int mld_force_mld_version(const struct inet6_dev *idev) +{ + /* Normally, both are 0 here. 
If enforcement to a particular is + * being used, individual device enforcement will have a lower + * precedence over 'all' device (.../conf/all/force_mld_version). + */ + + if (dev_net(idev->dev)->ipv6.devconf_all->force_mld_version != 0) + return dev_net(idev->dev)->ipv6.devconf_all->force_mld_version; + else + return idev->cnf.force_mld_version; +} + +static bool mld_in_v2_mode_only(const struct inet6_dev *idev) +{ + return mld_force_mld_version(idev) == 2; +} + +static bool mld_in_v1_mode_only(const struct inet6_dev *idev) +{ + return mld_force_mld_version(idev) == 1; +} + +static bool mld_in_v1_mode(const struct inet6_dev *idev) +{ + if (mld_in_v2_mode_only(idev)) + return false; + if (mld_in_v1_mode_only(idev)) + return true; + if (idev->mc_v1_seen && time_before(jiffies, idev->mc_v1_seen)) + return true; + + return false; +} + +static void mld_set_v1_mode(struct inet6_dev *idev) +{ + /* RFC3810, relevant sections: + * - 9.1. Robustness Variable + * - 9.2. Query Interval + * - 9.3. Query Response Interval + * - 9.12. Older Version Querier Present Timeout + */ + unsigned long switchback; + + switchback = (idev->mc_qrv * idev->mc_qi) + idev->mc_qri; + + idev->mc_v1_seen = jiffies + switchback; +} + +static void mld_update_qrv(struct inet6_dev *idev, + const struct mld2_query *mlh2) +{ + /* RFC3810, relevant sections: + * - 5.1.8. QRV (Querier's Robustness Variable) + * - 9.1. Robustness Variable + */ + + /* The value of the Robustness Variable MUST NOT be zero, + * and SHOULD NOT be one. Catch this here if we ever run + * into such a case in future. + */ + const int min_qrv = min(MLD_QRV_DEFAULT, sysctl_mld_qrv); + WARN_ON(idev->mc_qrv == 0); + + if (mlh2->mld2q_qrv > 0) + idev->mc_qrv = mlh2->mld2q_qrv; + + if (unlikely(idev->mc_qrv < min_qrv)) { + net_warn_ratelimited("IPv6: MLD: clamping QRV from %u to %u!\n", + idev->mc_qrv, min_qrv); + idev->mc_qrv = min_qrv; + } +} + +static void mld_update_qi(struct inet6_dev *idev, + const struct mld2_query *mlh2) +{ + /* RFC3810, relevant sections: + * - 5.1.9. QQIC (Querier's Query Interval Code) + * - 9.2. Query Interval + * - 9.12. Older Version Querier Present Timeout + * (the [Query Interval] in the last Query received) + */ + unsigned long mc_qqi; + + if (mlh2->mld2q_qqic < 128) { + mc_qqi = mlh2->mld2q_qqic; + } else { + unsigned long mc_man, mc_exp; + + mc_exp = MLDV2_QQIC_EXP(mlh2->mld2q_qqic); + mc_man = MLDV2_QQIC_MAN(mlh2->mld2q_qqic); + + mc_qqi = (mc_man | 0x10) << (mc_exp + 3); + } + + idev->mc_qi = mc_qqi * HZ; +} + +static void mld_update_qri(struct inet6_dev *idev, + const struct mld2_query *mlh2) +{ + /* RFC3810, relevant sections: + * - 5.1.3. Maximum Response Code + * - 9.3. Query Response Interval + */ + idev->mc_qri = msecs_to_jiffies(mldv2_mrc(mlh2)); +} + +static int mld_process_v1(struct inet6_dev *idev, struct mld_msg *mld, + unsigned long *max_delay, bool v1_query) +{ + unsigned long mldv1_md; + + /* Ignore v1 queries */ + if (mld_in_v2_mode_only(idev)) + return -EINVAL; + + mldv1_md = ntohs(mld->mld_maxdelay); + + /* When in MLDv1 fallback and a MLDv2 router start-up being + * unaware of current MLDv1 operation, the MRC == MRD mapping + * only works when the exponential algorithm is not being + * used (as MLDv1 is unaware of such things). + * + * According to the RFC author, the MLDv2 implementations + * he's aware of all use a MRC < 32768 on start up queries. + * + * Thus, should we *ever* encounter something else larger + * than that, just assume the maximum possible within our + * reach. 
+ */ + if (!v1_query) + mldv1_md = min(mldv1_md, MLDV1_MRD_MAX_COMPAT); + + *max_delay = max(msecs_to_jiffies(mldv1_md), 1UL); + + /* MLDv1 router present: we need to go into v1 mode *only* + * when an MLDv1 query is received as per section 9.12. of + * RFC3810! And we know from RFC2710 section 3.7 that MLDv1 + * queries MUST be of exactly 24 octets. + */ + if (v1_query) + mld_set_v1_mode(idev); + + /* cancel MLDv2 report work */ + mld_gq_stop_work(idev); + /* cancel the interface change work */ + mld_ifc_stop_work(idev); + /* clear deleted report items */ + mld_clear_delrec(idev); + + return 0; +} + +static void mld_process_v2(struct inet6_dev *idev, struct mld2_query *mld, + unsigned long *max_delay) +{ + *max_delay = max(msecs_to_jiffies(mldv2_mrc(mld)), 1UL); + + mld_update_qrv(idev, mld); + mld_update_qi(idev, mld); + mld_update_qri(idev, mld); + + idev->mc_maxdelay = *max_delay; + + return; +} + +/* called with rcu_read_lock() */ +void igmp6_event_query(struct sk_buff *skb) +{ + struct inet6_dev *idev = __in6_dev_get(skb->dev); + + if (!idev || idev->dead) + goto out; + + spin_lock_bh(&idev->mc_query_lock); + if (skb_queue_len(&idev->mc_query_queue) < MLD_MAX_SKBS) { + __skb_queue_tail(&idev->mc_query_queue, skb); + if (!mod_delayed_work(mld_wq, &idev->mc_query_work, 0)) + in6_dev_hold(idev); + skb = NULL; + } + spin_unlock_bh(&idev->mc_query_lock); +out: + kfree_skb(skb); +} + +static void __mld_query_work(struct sk_buff *skb) +{ + struct mld2_query *mlh2 = NULL; + const struct in6_addr *group; + unsigned long max_delay; + struct inet6_dev *idev; + struct ifmcaddr6 *ma; + struct mld_msg *mld; + int group_type; + int mark = 0; + int len, err; + + if (!pskb_may_pull(skb, sizeof(struct in6_addr))) + goto kfree_skb; + + /* compute payload length excluding extension headers */ + len = ntohs(ipv6_hdr(skb)->payload_len) + sizeof(struct ipv6hdr); + len -= skb_network_header_len(skb); + + /* RFC3810 6.2 + * Upon reception of an MLD message that contains a Query, the node + * checks if the source address of the message is a valid link-local + * address, if the Hop Limit is set to 1, and if the Router Alert + * option is present in the Hop-By-Hop Options header of the IPv6 + * packet. If any of these checks fails, the packet is dropped. 
+ */ + if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL) || + ipv6_hdr(skb)->hop_limit != 1 || + !(IP6CB(skb)->flags & IP6SKB_ROUTERALERT) || + IP6CB(skb)->ra != htons(IPV6_OPT_ROUTERALERT_MLD)) + goto kfree_skb; + + idev = in6_dev_get(skb->dev); + if (!idev) + goto kfree_skb; + + mld = (struct mld_msg *)icmp6_hdr(skb); + group = &mld->mld_mca; + group_type = ipv6_addr_type(group); + + if (group_type != IPV6_ADDR_ANY && + !(group_type&IPV6_ADDR_MULTICAST)) + goto out; + + if (len < MLD_V1_QUERY_LEN) { + goto out; + } else if (len == MLD_V1_QUERY_LEN || mld_in_v1_mode(idev)) { + err = mld_process_v1(idev, mld, &max_delay, + len == MLD_V1_QUERY_LEN); + if (err < 0) + goto out; + } else if (len >= MLD_V2_QUERY_LEN_MIN) { + int srcs_offset = sizeof(struct mld2_query) - + sizeof(struct icmp6hdr); + + if (!pskb_may_pull(skb, srcs_offset)) + goto out; + + mlh2 = (struct mld2_query *)skb_transport_header(skb); + + mld_process_v2(idev, mlh2, &max_delay); + + if (group_type == IPV6_ADDR_ANY) { /* general query */ + if (mlh2->mld2q_nsrcs) + goto out; /* no sources allowed */ + + mld_gq_start_work(idev); + goto out; + } + /* mark sources to include, if group & source-specific */ + if (mlh2->mld2q_nsrcs != 0) { + if (!pskb_may_pull(skb, srcs_offset + + ntohs(mlh2->mld2q_nsrcs) * sizeof(struct in6_addr))) + goto out; + + mlh2 = (struct mld2_query *)skb_transport_header(skb); + mark = 1; + } + } else { + goto out; + } + + if (group_type == IPV6_ADDR_ANY) { + for_each_mc_mclock(idev, ma) { + igmp6_group_queried(ma, max_delay); + } + } else { + for_each_mc_mclock(idev, ma) { + if (!ipv6_addr_equal(group, &ma->mca_addr)) + continue; + if (ma->mca_flags & MAF_TIMER_RUNNING) { + /* gsquery <- gsquery && mark */ + if (!mark) + ma->mca_flags &= ~MAF_GSQUERY; + } else { + /* gsquery <- mark */ + if (mark) + ma->mca_flags |= MAF_GSQUERY; + else + ma->mca_flags &= ~MAF_GSQUERY; + } + if (!(ma->mca_flags & MAF_GSQUERY) || + mld_marksources(ma, ntohs(mlh2->mld2q_nsrcs), mlh2->mld2q_srcs)) + igmp6_group_queried(ma, max_delay); + break; + } + } + +out: + in6_dev_put(idev); +kfree_skb: + consume_skb(skb); +} + +static void mld_query_work(struct work_struct *work) +{ + struct inet6_dev *idev = container_of(to_delayed_work(work), + struct inet6_dev, + mc_query_work); + struct sk_buff_head q; + struct sk_buff *skb; + bool rework = false; + int cnt = 0; + + skb_queue_head_init(&q); + + spin_lock_bh(&idev->mc_query_lock); + while ((skb = __skb_dequeue(&idev->mc_query_queue))) { + __skb_queue_tail(&q, skb); + + if (++cnt >= MLD_MAX_QUEUE) { + rework = true; + break; + } + } + spin_unlock_bh(&idev->mc_query_lock); + + mutex_lock(&idev->mc_lock); + while ((skb = __skb_dequeue(&q))) + __mld_query_work(skb); + mutex_unlock(&idev->mc_lock); + + if (rework && queue_delayed_work(mld_wq, &idev->mc_query_work, 0)) + return; + + in6_dev_put(idev); +} + +/* called with rcu_read_lock() */ +void igmp6_event_report(struct sk_buff *skb) +{ + struct inet6_dev *idev = __in6_dev_get(skb->dev); + + if (!idev || idev->dead) + goto out; + + spin_lock_bh(&idev->mc_report_lock); + if (skb_queue_len(&idev->mc_report_queue) < MLD_MAX_SKBS) { + __skb_queue_tail(&idev->mc_report_queue, skb); + if (!mod_delayed_work(mld_wq, &idev->mc_report_work, 0)) + in6_dev_hold(idev); + skb = NULL; + } + spin_unlock_bh(&idev->mc_report_lock); +out: + kfree_skb(skb); +} + +static void __mld_report_work(struct sk_buff *skb) +{ + struct inet6_dev *idev; + struct ifmcaddr6 *ma; + struct mld_msg *mld; + int addr_type; + + /* Our own report looped back. 
Ignore it. */ + if (skb->pkt_type == PACKET_LOOPBACK) + goto kfree_skb; + + /* send our report if the MC router may not have heard this report */ + if (skb->pkt_type != PACKET_MULTICAST && + skb->pkt_type != PACKET_BROADCAST) + goto kfree_skb; + + if (!pskb_may_pull(skb, sizeof(*mld) - sizeof(struct icmp6hdr))) + goto kfree_skb; + + mld = (struct mld_msg *)icmp6_hdr(skb); + + /* Drop reports with not link local source */ + addr_type = ipv6_addr_type(&ipv6_hdr(skb)->saddr); + if (addr_type != IPV6_ADDR_ANY && + !(addr_type&IPV6_ADDR_LINKLOCAL)) + goto kfree_skb; + + idev = in6_dev_get(skb->dev); + if (!idev) + goto kfree_skb; + + /* + * Cancel the work for this group + */ + + for_each_mc_mclock(idev, ma) { + if (ipv6_addr_equal(&ma->mca_addr, &mld->mld_mca)) { + if (cancel_delayed_work(&ma->mca_work)) + refcount_dec(&ma->mca_refcnt); + ma->mca_flags &= ~(MAF_LAST_REPORTER | + MAF_TIMER_RUNNING); + break; + } + } + + in6_dev_put(idev); +kfree_skb: + consume_skb(skb); +} + +static void mld_report_work(struct work_struct *work) +{ + struct inet6_dev *idev = container_of(to_delayed_work(work), + struct inet6_dev, + mc_report_work); + struct sk_buff_head q; + struct sk_buff *skb; + bool rework = false; + int cnt = 0; + + skb_queue_head_init(&q); + spin_lock_bh(&idev->mc_report_lock); + while ((skb = __skb_dequeue(&idev->mc_report_queue))) { + __skb_queue_tail(&q, skb); + + if (++cnt >= MLD_MAX_QUEUE) { + rework = true; + break; + } + } + spin_unlock_bh(&idev->mc_report_lock); + + mutex_lock(&idev->mc_lock); + while ((skb = __skb_dequeue(&q))) + __mld_report_work(skb); + mutex_unlock(&idev->mc_lock); + + if (rework && queue_delayed_work(mld_wq, &idev->mc_report_work, 0)) + return; + + in6_dev_put(idev); +} + +static bool is_in(struct ifmcaddr6 *pmc, struct ip6_sf_list *psf, int type, + int gdeleted, int sdeleted) +{ + switch (type) { + case MLD2_MODE_IS_INCLUDE: + case MLD2_MODE_IS_EXCLUDE: + if (gdeleted || sdeleted) + return false; + if (!((pmc->mca_flags & MAF_GSQUERY) && !psf->sf_gsresp)) { + if (pmc->mca_sfmode == MCAST_INCLUDE) + return true; + /* don't include if this source is excluded + * in all filters + */ + if (psf->sf_count[MCAST_INCLUDE]) + return type == MLD2_MODE_IS_INCLUDE; + return pmc->mca_sfcount[MCAST_EXCLUDE] == + psf->sf_count[MCAST_EXCLUDE]; + } + return false; + case MLD2_CHANGE_TO_INCLUDE: + if (gdeleted || sdeleted) + return false; + return psf->sf_count[MCAST_INCLUDE] != 0; + case MLD2_CHANGE_TO_EXCLUDE: + if (gdeleted || sdeleted) + return false; + if (pmc->mca_sfcount[MCAST_EXCLUDE] == 0 || + psf->sf_count[MCAST_INCLUDE]) + return false; + return pmc->mca_sfcount[MCAST_EXCLUDE] == + psf->sf_count[MCAST_EXCLUDE]; + case MLD2_ALLOW_NEW_SOURCES: + if (gdeleted || !psf->sf_crcount) + return false; + return (pmc->mca_sfmode == MCAST_INCLUDE) ^ sdeleted; + case MLD2_BLOCK_OLD_SOURCES: + if (pmc->mca_sfmode == MCAST_INCLUDE) + return gdeleted || (psf->sf_crcount && sdeleted); + return psf->sf_crcount && !gdeleted && !sdeleted; + } + return false; +} + +static int +mld_scount(struct ifmcaddr6 *pmc, int type, int gdeleted, int sdeleted) +{ + struct ip6_sf_list *psf; + int scount = 0; + + for_each_psf_mclock(pmc, psf) { + if (!is_in(pmc, psf, type, gdeleted, sdeleted)) + continue; + scount++; + } + return scount; +} + +static void ip6_mc_hdr(struct sock *sk, struct sk_buff *skb, + struct net_device *dev, + const struct in6_addr *saddr, + const struct in6_addr *daddr, + int proto, int len) +{ + struct ipv6hdr *hdr; + + skb->protocol = htons(ETH_P_IPV6); + skb->dev = dev; + + 
skb_reset_network_header(skb); + skb_put(skb, sizeof(struct ipv6hdr)); + hdr = ipv6_hdr(skb); + + ip6_flow_hdr(hdr, 0, 0); + + hdr->payload_len = htons(len); + hdr->nexthdr = proto; + hdr->hop_limit = inet6_sk(sk)->hop_limit; + + hdr->saddr = *saddr; + hdr->daddr = *daddr; +} + +static struct sk_buff *mld_newpack(struct inet6_dev *idev, unsigned int mtu) +{ + u8 ra[8] = { IPPROTO_ICMPV6, 0, IPV6_TLV_ROUTERALERT, + 2, 0, 0, IPV6_TLV_PADN, 0 }; + struct net_device *dev = idev->dev; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + struct net *net = dev_net(dev); + const struct in6_addr *saddr; + struct in6_addr addr_buf; + struct mld2_report *pmr; + struct sk_buff *skb; + unsigned int size; + struct sock *sk; + int err; + + sk = net->ipv6.igmp_sk; + /* we assume size > sizeof(ra) here + * Also try to not allocate high-order pages for big MTU + */ + size = min_t(int, mtu, PAGE_SIZE / 2) + hlen + tlen; + skb = sock_alloc_send_skb(sk, size, 1, &err); + if (!skb) + return NULL; + + skb->priority = TC_PRIO_CONTROL; + skb_reserve(skb, hlen); + skb_tailroom_reserve(skb, mtu, tlen); + + if (ipv6_get_lladdr(dev, &addr_buf, IFA_F_TENTATIVE)) { + /* : + * use unspecified address as the source address + * when a valid link-local address is not available. + */ + saddr = &in6addr_any; + } else + saddr = &addr_buf; + + ip6_mc_hdr(sk, skb, dev, saddr, &mld2_all_mcr, NEXTHDR_HOP, 0); + + skb_put_data(skb, ra, sizeof(ra)); + + skb_set_transport_header(skb, skb_tail_pointer(skb) - skb->data); + skb_put(skb, sizeof(*pmr)); + pmr = (struct mld2_report *)skb_transport_header(skb); + pmr->mld2r_type = ICMPV6_MLD2_REPORT; + pmr->mld2r_resv1 = 0; + pmr->mld2r_cksum = 0; + pmr->mld2r_resv2 = 0; + pmr->mld2r_ngrec = 0; + return skb; +} + +static void mld_sendpack(struct sk_buff *skb) +{ + struct ipv6hdr *pip6 = ipv6_hdr(skb); + struct mld2_report *pmr = + (struct mld2_report *)skb_transport_header(skb); + int payload_len, mldlen; + struct inet6_dev *idev; + struct net *net = dev_net(skb->dev); + int err; + struct flowi6 fl6; + struct dst_entry *dst; + + rcu_read_lock(); + idev = __in6_dev_get(skb->dev); + IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_OUT, skb->len); + + payload_len = (skb_tail_pointer(skb) - skb_network_header(skb)) - + sizeof(*pip6); + mldlen = skb_tail_pointer(skb) - skb_transport_header(skb); + pip6->payload_len = htons(payload_len); + + pmr->mld2r_cksum = csum_ipv6_magic(&pip6->saddr, &pip6->daddr, mldlen, + IPPROTO_ICMPV6, + csum_partial(skb_transport_header(skb), + mldlen, 0)); + + icmpv6_flow_init(net->ipv6.igmp_sk, &fl6, ICMPV6_MLD2_REPORT, + &ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr, + skb->dev->ifindex); + dst = icmp6_dst_alloc(skb->dev, &fl6); + + err = 0; + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + } + skb_dst_set(skb, dst); + if (err) + goto err_out; + + err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + net, net->ipv6.igmp_sk, skb, NULL, skb->dev, + dst_output); +out: + if (!err) { + ICMP6MSGOUT_INC_STATS(net, idev, ICMPV6_MLD2_REPORT); + ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS); + } else { + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTDISCARDS); + } + + rcu_read_unlock(); + return; + +err_out: + kfree_skb(skb); + goto out; +} + +static int grec_size(struct ifmcaddr6 *pmc, int type, int gdel, int sdel) +{ + return sizeof(struct mld2_grec) + 16 * mld_scount(pmc,type,gdel,sdel); +} + +static struct sk_buff *add_grhead(struct sk_buff *skb, struct ifmcaddr6 *pmc, + int type, struct mld2_grec **ppgr, unsigned int mtu) +{ + struct mld2_report *pmr; + struct 
mld2_grec *pgr; + + if (!skb) { + skb = mld_newpack(pmc->idev, mtu); + if (!skb) + return NULL; + } + pgr = skb_put(skb, sizeof(struct mld2_grec)); + pgr->grec_type = type; + pgr->grec_auxwords = 0; + pgr->grec_nsrcs = 0; + pgr->grec_mca = pmc->mca_addr; /* structure copy */ + pmr = (struct mld2_report *)skb_transport_header(skb); + pmr->mld2r_ngrec = htons(ntohs(pmr->mld2r_ngrec)+1); + *ppgr = pgr; + return skb; +} + +#define AVAILABLE(skb) ((skb) ? skb_availroom(skb) : 0) + +/* called with mc_lock */ +static struct sk_buff *add_grec(struct sk_buff *skb, struct ifmcaddr6 *pmc, + int type, int gdeleted, int sdeleted, + int crsend) +{ + struct ip6_sf_list *psf, *psf_prev, *psf_next; + int scount, stotal, first, isquery, truncate; + struct ip6_sf_list __rcu **psf_list; + struct inet6_dev *idev = pmc->idev; + struct net_device *dev = idev->dev; + struct mld2_grec *pgr = NULL; + struct mld2_report *pmr; + unsigned int mtu; + + if (pmc->mca_flags & MAF_NOREPORT) + return skb; + + mtu = READ_ONCE(dev->mtu); + if (mtu < IPV6_MIN_MTU) + return skb; + + isquery = type == MLD2_MODE_IS_INCLUDE || + type == MLD2_MODE_IS_EXCLUDE; + truncate = type == MLD2_MODE_IS_EXCLUDE || + type == MLD2_CHANGE_TO_EXCLUDE; + + stotal = scount = 0; + + psf_list = sdeleted ? &pmc->mca_tomb : &pmc->mca_sources; + + if (!rcu_access_pointer(*psf_list)) + goto empty_source; + + pmr = skb ? (struct mld2_report *)skb_transport_header(skb) : NULL; + + /* EX and TO_EX get a fresh packet, if needed */ + if (truncate) { + if (pmr && pmr->mld2r_ngrec && + AVAILABLE(skb) < grec_size(pmc, type, gdeleted, sdeleted)) { + if (skb) + mld_sendpack(skb); + skb = mld_newpack(idev, mtu); + } + } + first = 1; + psf_prev = NULL; + for (psf = mc_dereference(*psf_list, idev); + psf; + psf = psf_next) { + struct in6_addr *psrc; + + psf_next = mc_dereference(psf->sf_next, idev); + + if (!is_in(pmc, psf, type, gdeleted, sdeleted) && !crsend) { + psf_prev = psf; + continue; + } + + /* Based on RFC3810 6.1. Should not send source-list change + * records when there is a filter mode change. 
+ */ + if (((gdeleted && pmc->mca_sfmode == MCAST_EXCLUDE) || + (!gdeleted && pmc->mca_crcount)) && + (type == MLD2_ALLOW_NEW_SOURCES || + type == MLD2_BLOCK_OLD_SOURCES) && psf->sf_crcount) + goto decrease_sf_crcount; + + /* clear marks on query responses */ + if (isquery) + psf->sf_gsresp = 0; + + if (AVAILABLE(skb) < sizeof(*psrc) + + first*sizeof(struct mld2_grec)) { + if (truncate && !first) + break; /* truncate these */ + if (pgr) + pgr->grec_nsrcs = htons(scount); + if (skb) + mld_sendpack(skb); + skb = mld_newpack(idev, mtu); + first = 1; + scount = 0; + } + if (first) { + skb = add_grhead(skb, pmc, type, &pgr, mtu); + first = 0; + } + if (!skb) + return NULL; + psrc = skb_put(skb, sizeof(*psrc)); + *psrc = psf->sf_addr; + scount++; stotal++; + if ((type == MLD2_ALLOW_NEW_SOURCES || + type == MLD2_BLOCK_OLD_SOURCES) && psf->sf_crcount) { +decrease_sf_crcount: + psf->sf_crcount--; + if ((sdeleted || gdeleted) && psf->sf_crcount == 0) { + if (psf_prev) + rcu_assign_pointer(psf_prev->sf_next, + mc_dereference(psf->sf_next, idev)); + else + rcu_assign_pointer(*psf_list, + mc_dereference(psf->sf_next, idev)); + kfree_rcu(psf, rcu); + continue; + } + } + psf_prev = psf; + } + +empty_source: + if (!stotal) { + if (type == MLD2_ALLOW_NEW_SOURCES || + type == MLD2_BLOCK_OLD_SOURCES) + return skb; + if (pmc->mca_crcount || isquery || crsend) { + /* make sure we have room for group header */ + if (skb && AVAILABLE(skb) < sizeof(struct mld2_grec)) { + mld_sendpack(skb); + skb = NULL; /* add_grhead will get a new one */ + } + skb = add_grhead(skb, pmc, type, &pgr, mtu); + } + } + if (pgr) + pgr->grec_nsrcs = htons(scount); + + if (isquery) + pmc->mca_flags &= ~MAF_GSQUERY; /* clear query state */ + return skb; +} + +/* called with mc_lock */ +static void mld_send_report(struct inet6_dev *idev, struct ifmcaddr6 *pmc) +{ + struct sk_buff *skb = NULL; + int type; + + if (!pmc) { + for_each_mc_mclock(idev, pmc) { + if (pmc->mca_flags & MAF_NOREPORT) + continue; + if (pmc->mca_sfcount[MCAST_EXCLUDE]) + type = MLD2_MODE_IS_EXCLUDE; + else + type = MLD2_MODE_IS_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0, 0); + } + } else { + if (pmc->mca_sfcount[MCAST_EXCLUDE]) + type = MLD2_MODE_IS_EXCLUDE; + else + type = MLD2_MODE_IS_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0, 0); + } + if (skb) + mld_sendpack(skb); +} + +/* + * remove zero-count source records from a source filter list + * called with mc_lock + */ +static void mld_clear_zeros(struct ip6_sf_list __rcu **ppsf, struct inet6_dev *idev) +{ + struct ip6_sf_list *psf_prev, *psf_next, *psf; + + psf_prev = NULL; + for (psf = mc_dereference(*ppsf, idev); + psf; + psf = psf_next) { + psf_next = mc_dereference(psf->sf_next, idev); + if (psf->sf_crcount == 0) { + if (psf_prev) + rcu_assign_pointer(psf_prev->sf_next, + mc_dereference(psf->sf_next, idev)); + else + rcu_assign_pointer(*ppsf, + mc_dereference(psf->sf_next, idev)); + kfree_rcu(psf, rcu); + } else { + psf_prev = psf; + } + } +} + +/* called with mc_lock */ +static void mld_send_cr(struct inet6_dev *idev) +{ + struct ifmcaddr6 *pmc, *pmc_prev, *pmc_next; + struct sk_buff *skb = NULL; + int type, dtype; + + /* deleted MCA's */ + pmc_prev = NULL; + for (pmc = mc_dereference(idev->mc_tomb, idev); + pmc; + pmc = pmc_next) { + pmc_next = mc_dereference(pmc->next, idev); + if (pmc->mca_sfmode == MCAST_INCLUDE) { + type = MLD2_BLOCK_OLD_SOURCES; + dtype = MLD2_BLOCK_OLD_SOURCES; + skb = add_grec(skb, pmc, type, 1, 0, 0); + skb = add_grec(skb, pmc, dtype, 1, 1, 0); + } + if (pmc->mca_crcount) { + if 
(pmc->mca_sfmode == MCAST_EXCLUDE) { + type = MLD2_CHANGE_TO_INCLUDE; + skb = add_grec(skb, pmc, type, 1, 0, 0); + } + pmc->mca_crcount--; + if (pmc->mca_crcount == 0) { + mld_clear_zeros(&pmc->mca_tomb, idev); + mld_clear_zeros(&pmc->mca_sources, idev); + } + } + if (pmc->mca_crcount == 0 && + !rcu_access_pointer(pmc->mca_tomb) && + !rcu_access_pointer(pmc->mca_sources)) { + if (pmc_prev) + rcu_assign_pointer(pmc_prev->next, pmc_next); + else + rcu_assign_pointer(idev->mc_tomb, pmc_next); + in6_dev_put(pmc->idev); + kfree_rcu(pmc, rcu); + } else + pmc_prev = pmc; + } + + /* change recs */ + for_each_mc_mclock(idev, pmc) { + if (pmc->mca_sfcount[MCAST_EXCLUDE]) { + type = MLD2_BLOCK_OLD_SOURCES; + dtype = MLD2_ALLOW_NEW_SOURCES; + } else { + type = MLD2_ALLOW_NEW_SOURCES; + dtype = MLD2_BLOCK_OLD_SOURCES; + } + skb = add_grec(skb, pmc, type, 0, 0, 0); + skb = add_grec(skb, pmc, dtype, 0, 1, 0); /* deleted sources */ + + /* filter mode changes */ + if (pmc->mca_crcount) { + if (pmc->mca_sfmode == MCAST_EXCLUDE) + type = MLD2_CHANGE_TO_EXCLUDE; + else + type = MLD2_CHANGE_TO_INCLUDE; + skb = add_grec(skb, pmc, type, 0, 0, 0); + pmc->mca_crcount--; + } + } + if (!skb) + return; + (void) mld_sendpack(skb); +} + +static void igmp6_send(struct in6_addr *addr, struct net_device *dev, int type) +{ + struct net *net = dev_net(dev); + struct sock *sk = net->ipv6.igmp_sk; + struct inet6_dev *idev; + struct sk_buff *skb; + struct mld_msg *hdr; + const struct in6_addr *snd_addr, *saddr; + struct in6_addr addr_buf; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + int err, len, payload_len, full_len; + u8 ra[8] = { IPPROTO_ICMPV6, 0, + IPV6_TLV_ROUTERALERT, 2, 0, 0, + IPV6_TLV_PADN, 0 }; + struct flowi6 fl6; + struct dst_entry *dst; + + if (type == ICMPV6_MGM_REDUCTION) + snd_addr = &in6addr_linklocal_allrouters; + else + snd_addr = addr; + + len = sizeof(struct icmp6hdr) + sizeof(struct in6_addr); + payload_len = len + sizeof(ra); + full_len = sizeof(struct ipv6hdr) + payload_len; + + rcu_read_lock(); + IP6_UPD_PO_STATS(net, __in6_dev_get(dev), + IPSTATS_MIB_OUT, full_len); + rcu_read_unlock(); + + skb = sock_alloc_send_skb(sk, hlen + tlen + full_len, 1, &err); + + if (!skb) { + rcu_read_lock(); + IP6_INC_STATS(net, __in6_dev_get(dev), + IPSTATS_MIB_OUTDISCARDS); + rcu_read_unlock(); + return; + } + skb->priority = TC_PRIO_CONTROL; + skb_reserve(skb, hlen); + + if (ipv6_get_lladdr(dev, &addr_buf, IFA_F_TENTATIVE)) { + /* : + * use unspecified address as the source address + * when a valid link-local address is not available. 
+ */ + saddr = &in6addr_any; + } else + saddr = &addr_buf; + + ip6_mc_hdr(sk, skb, dev, saddr, snd_addr, NEXTHDR_HOP, payload_len); + + skb_put_data(skb, ra, sizeof(ra)); + + hdr = skb_put_zero(skb, sizeof(struct mld_msg)); + hdr->mld_type = type; + hdr->mld_mca = *addr; + + hdr->mld_cksum = csum_ipv6_magic(saddr, snd_addr, len, + IPPROTO_ICMPV6, + csum_partial(hdr, len, 0)); + + rcu_read_lock(); + idev = __in6_dev_get(skb->dev); + + icmpv6_flow_init(sk, &fl6, type, + &ipv6_hdr(skb)->saddr, &ipv6_hdr(skb)->daddr, + skb->dev->ifindex); + dst = icmp6_dst_alloc(skb->dev, &fl6); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto err_out; + } + + skb_dst_set(skb, dst); + err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + net, sk, skb, NULL, skb->dev, + dst_output); +out: + if (!err) { + ICMP6MSGOUT_INC_STATS(net, idev, type); + ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS); + } else + IP6_INC_STATS(net, idev, IPSTATS_MIB_OUTDISCARDS); + + rcu_read_unlock(); + return; + +err_out: + kfree_skb(skb); + goto out; +} + +/* called with mc_lock */ +static void mld_send_initial_cr(struct inet6_dev *idev) +{ + struct sk_buff *skb; + struct ifmcaddr6 *pmc; + int type; + + if (mld_in_v1_mode(idev)) + return; + + skb = NULL; + for_each_mc_mclock(idev, pmc) { + if (pmc->mca_sfcount[MCAST_EXCLUDE]) + type = MLD2_CHANGE_TO_EXCLUDE; + else + type = MLD2_ALLOW_NEW_SOURCES; + skb = add_grec(skb, pmc, type, 0, 0, 1); + } + if (skb) + mld_sendpack(skb); +} + +void ipv6_mc_dad_complete(struct inet6_dev *idev) +{ + mutex_lock(&idev->mc_lock); + idev->mc_dad_count = idev->mc_qrv; + if (idev->mc_dad_count) { + mld_send_initial_cr(idev); + idev->mc_dad_count--; + if (idev->mc_dad_count) + mld_dad_start_work(idev, + unsolicited_report_interval(idev)); + } + mutex_unlock(&idev->mc_lock); +} + +static void mld_dad_work(struct work_struct *work) +{ + struct inet6_dev *idev = container_of(to_delayed_work(work), + struct inet6_dev, + mc_dad_work); + mutex_lock(&idev->mc_lock); + mld_send_initial_cr(idev); + if (idev->mc_dad_count) { + idev->mc_dad_count--; + if (idev->mc_dad_count) + mld_dad_start_work(idev, + unsolicited_report_interval(idev)); + } + mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); +} + +/* called with mc_lock */ +static int ip6_mc_del1_src(struct ifmcaddr6 *pmc, int sfmode, + const struct in6_addr *psfsrc) +{ + struct ip6_sf_list *psf, *psf_prev; + int rv = 0; + + psf_prev = NULL; + for_each_psf_mclock(pmc, psf) { + if (ipv6_addr_equal(&psf->sf_addr, psfsrc)) + break; + psf_prev = psf; + } + if (!psf || psf->sf_count[sfmode] == 0) { + /* source filter not found, or count wrong => bug */ + return -ESRCH; + } + psf->sf_count[sfmode]--; + if (!psf->sf_count[MCAST_INCLUDE] && !psf->sf_count[MCAST_EXCLUDE]) { + struct inet6_dev *idev = pmc->idev; + + /* no more filters for this source */ + if (psf_prev) + rcu_assign_pointer(psf_prev->sf_next, + mc_dereference(psf->sf_next, idev)); + else + rcu_assign_pointer(pmc->mca_sources, + mc_dereference(psf->sf_next, idev)); + + if (psf->sf_oldin && !(pmc->mca_flags & MAF_NOREPORT) && + !mld_in_v1_mode(idev)) { + psf->sf_crcount = idev->mc_qrv; + rcu_assign_pointer(psf->sf_next, + mc_dereference(pmc->mca_tomb, idev)); + rcu_assign_pointer(pmc->mca_tomb, psf); + rv = 1; + } else { + kfree_rcu(psf, rcu); + } + } + return rv; +} + +/* called with mc_lock */ +static int ip6_mc_del_src(struct inet6_dev *idev, const struct in6_addr *pmca, + int sfmode, int sfcount, const struct in6_addr *psfsrc, + int delta) +{ + struct ifmcaddr6 *pmc; + int changerec = 0; + int i, err; + + if 
(!idev) + return -ENODEV; + + for_each_mc_mclock(idev, pmc) { + if (ipv6_addr_equal(pmca, &pmc->mca_addr)) + break; + } + if (!pmc) + return -ESRCH; + + sf_markstate(pmc); + if (!delta) { + if (!pmc->mca_sfcount[sfmode]) + return -EINVAL; + + pmc->mca_sfcount[sfmode]--; + } + err = 0; + for (i = 0; i < sfcount; i++) { + int rv = ip6_mc_del1_src(pmc, sfmode, &psfsrc[i]); + + changerec |= rv > 0; + if (!err && rv < 0) + err = rv; + } + if (pmc->mca_sfmode == MCAST_EXCLUDE && + pmc->mca_sfcount[MCAST_EXCLUDE] == 0 && + pmc->mca_sfcount[MCAST_INCLUDE]) { + struct ip6_sf_list *psf; + + /* filter mode change */ + pmc->mca_sfmode = MCAST_INCLUDE; + pmc->mca_crcount = idev->mc_qrv; + idev->mc_ifc_count = pmc->mca_crcount; + for_each_psf_mclock(pmc, psf) + psf->sf_crcount = 0; + mld_ifc_event(pmc->idev); + } else if (sf_setstate(pmc) || changerec) { + mld_ifc_event(pmc->idev); + } + + return err; +} + +/* + * Add multicast single-source filter to the interface list + * called with mc_lock + */ +static int ip6_mc_add1_src(struct ifmcaddr6 *pmc, int sfmode, + const struct in6_addr *psfsrc) +{ + struct ip6_sf_list *psf, *psf_prev; + + psf_prev = NULL; + for_each_psf_mclock(pmc, psf) { + if (ipv6_addr_equal(&psf->sf_addr, psfsrc)) + break; + psf_prev = psf; + } + if (!psf) { + psf = kzalloc(sizeof(*psf), GFP_KERNEL); + if (!psf) + return -ENOBUFS; + + psf->sf_addr = *psfsrc; + if (psf_prev) { + rcu_assign_pointer(psf_prev->sf_next, psf); + } else { + rcu_assign_pointer(pmc->mca_sources, psf); + } + } + psf->sf_count[sfmode]++; + return 0; +} + +/* called with mc_lock */ +static void sf_markstate(struct ifmcaddr6 *pmc) +{ + struct ip6_sf_list *psf; + int mca_xcount = pmc->mca_sfcount[MCAST_EXCLUDE]; + + for_each_psf_mclock(pmc, psf) { + if (pmc->mca_sfcount[MCAST_EXCLUDE]) { + psf->sf_oldin = mca_xcount == + psf->sf_count[MCAST_EXCLUDE] && + !psf->sf_count[MCAST_INCLUDE]; + } else { + psf->sf_oldin = psf->sf_count[MCAST_INCLUDE] != 0; + } + } +} + +/* called with mc_lock */ +static int sf_setstate(struct ifmcaddr6 *pmc) +{ + struct ip6_sf_list *psf, *dpsf; + int mca_xcount = pmc->mca_sfcount[MCAST_EXCLUDE]; + int qrv = pmc->idev->mc_qrv; + int new_in, rv; + + rv = 0; + for_each_psf_mclock(pmc, psf) { + if (pmc->mca_sfcount[MCAST_EXCLUDE]) { + new_in = mca_xcount == psf->sf_count[MCAST_EXCLUDE] && + !psf->sf_count[MCAST_INCLUDE]; + } else + new_in = psf->sf_count[MCAST_INCLUDE] != 0; + if (new_in) { + if (!psf->sf_oldin) { + struct ip6_sf_list *prev = NULL; + + for_each_psf_tomb(pmc, dpsf) { + if (ipv6_addr_equal(&dpsf->sf_addr, + &psf->sf_addr)) + break; + prev = dpsf; + } + if (dpsf) { + if (prev) + rcu_assign_pointer(prev->sf_next, + mc_dereference(dpsf->sf_next, + pmc->idev)); + else + rcu_assign_pointer(pmc->mca_tomb, + mc_dereference(dpsf->sf_next, + pmc->idev)); + kfree_rcu(dpsf, rcu); + } + psf->sf_crcount = qrv; + rv++; + } + } else if (psf->sf_oldin) { + psf->sf_crcount = 0; + /* + * add or update "delete" records if an active filter + * is now inactive + */ + + for_each_psf_tomb(pmc, dpsf) + if (ipv6_addr_equal(&dpsf->sf_addr, + &psf->sf_addr)) + break; + if (!dpsf) { + dpsf = kmalloc(sizeof(*dpsf), GFP_KERNEL); + if (!dpsf) + continue; + *dpsf = *psf; + rcu_assign_pointer(dpsf->sf_next, + mc_dereference(pmc->mca_tomb, pmc->idev)); + rcu_assign_pointer(pmc->mca_tomb, dpsf); + } + dpsf->sf_crcount = qrv; + rv++; + } + } + return rv; +} + +/* + * Add multicast source filter list to the interface list + * called with mc_lock + */ +static int ip6_mc_add_src(struct inet6_dev *idev, const struct 
in6_addr *pmca, + int sfmode, int sfcount, const struct in6_addr *psfsrc, + int delta) +{ + struct ifmcaddr6 *pmc; + int isexclude; + int i, err; + + if (!idev) + return -ENODEV; + + for_each_mc_mclock(idev, pmc) { + if (ipv6_addr_equal(pmca, &pmc->mca_addr)) + break; + } + if (!pmc) + return -ESRCH; + + sf_markstate(pmc); + isexclude = pmc->mca_sfmode == MCAST_EXCLUDE; + if (!delta) + pmc->mca_sfcount[sfmode]++; + err = 0; + for (i = 0; i < sfcount; i++) { + err = ip6_mc_add1_src(pmc, sfmode, &psfsrc[i]); + if (err) + break; + } + if (err) { + int j; + + if (!delta) + pmc->mca_sfcount[sfmode]--; + for (j = 0; j < i; j++) + ip6_mc_del1_src(pmc, sfmode, &psfsrc[j]); + } else if (isexclude != (pmc->mca_sfcount[MCAST_EXCLUDE] != 0)) { + struct ip6_sf_list *psf; + + /* filter mode change */ + if (pmc->mca_sfcount[MCAST_EXCLUDE]) + pmc->mca_sfmode = MCAST_EXCLUDE; + else if (pmc->mca_sfcount[MCAST_INCLUDE]) + pmc->mca_sfmode = MCAST_INCLUDE; + /* else no filters; keep old mode for reports */ + + pmc->mca_crcount = idev->mc_qrv; + idev->mc_ifc_count = pmc->mca_crcount; + for_each_psf_mclock(pmc, psf) + psf->sf_crcount = 0; + mld_ifc_event(idev); + } else if (sf_setstate(pmc)) { + mld_ifc_event(idev); + } + return err; +} + +/* called with mc_lock */ +static void ip6_mc_clear_src(struct ifmcaddr6 *pmc) +{ + struct ip6_sf_list *psf, *nextpsf; + + for (psf = mc_dereference(pmc->mca_tomb, pmc->idev); + psf; + psf = nextpsf) { + nextpsf = mc_dereference(psf->sf_next, pmc->idev); + kfree_rcu(psf, rcu); + } + RCU_INIT_POINTER(pmc->mca_tomb, NULL); + for (psf = mc_dereference(pmc->mca_sources, pmc->idev); + psf; + psf = nextpsf) { + nextpsf = mc_dereference(psf->sf_next, pmc->idev); + kfree_rcu(psf, rcu); + } + RCU_INIT_POINTER(pmc->mca_sources, NULL); + pmc->mca_sfmode = MCAST_EXCLUDE; + pmc->mca_sfcount[MCAST_INCLUDE] = 0; + pmc->mca_sfcount[MCAST_EXCLUDE] = 1; +} + +/* called with mc_lock */ +static void igmp6_join_group(struct ifmcaddr6 *ma) +{ + unsigned long delay; + + if (ma->mca_flags & MAF_NOREPORT) + return; + + igmp6_send(&ma->mca_addr, ma->idev->dev, ICMPV6_MGM_REPORT); + + delay = prandom_u32_max(unsolicited_report_interval(ma->idev)); + + if (cancel_delayed_work(&ma->mca_work)) { + refcount_dec(&ma->mca_refcnt); + delay = ma->mca_work.timer.expires - jiffies; + } + + if (!mod_delayed_work(mld_wq, &ma->mca_work, delay)) + refcount_inc(&ma->mca_refcnt); + ma->mca_flags |= MAF_TIMER_RUNNING | MAF_LAST_REPORTER; +} + +static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml, + struct inet6_dev *idev) +{ + struct ip6_sf_socklist *psl; + int err; + + psl = sock_dereference(iml->sflist, sk); + + if (idev) + mutex_lock(&idev->mc_lock); + + if (!psl) { + /* any-source empty exclude case */ + err = ip6_mc_del_src(idev, &iml->addr, iml->sfmode, 0, NULL, 0); + } else { + err = ip6_mc_del_src(idev, &iml->addr, iml->sfmode, + psl->sl_count, psl->sl_addr, 0); + RCU_INIT_POINTER(iml->sflist, NULL); + atomic_sub(struct_size(psl, sl_addr, psl->sl_max), + &sk->sk_omem_alloc); + kfree_rcu(psl, rcu); + } + + if (idev) + mutex_unlock(&idev->mc_lock); + + return err; +} + +/* called with mc_lock */ +static void igmp6_leave_group(struct ifmcaddr6 *ma) +{ + if (mld_in_v1_mode(ma->idev)) { + if (ma->mca_flags & MAF_LAST_REPORTER) { + igmp6_send(&ma->mca_addr, ma->idev->dev, + ICMPV6_MGM_REDUCTION); + } + } else { + mld_add_delrec(ma->idev, ma); + mld_ifc_event(ma->idev); + } +} + +static void mld_gq_work(struct work_struct *work) +{ + struct inet6_dev *idev = container_of(to_delayed_work(work), + 
struct inet6_dev, + mc_gq_work); + + mutex_lock(&idev->mc_lock); + mld_send_report(idev, NULL); + idev->mc_gq_running = 0; + mutex_unlock(&idev->mc_lock); + + in6_dev_put(idev); +} + +static void mld_ifc_work(struct work_struct *work) +{ + struct inet6_dev *idev = container_of(to_delayed_work(work), + struct inet6_dev, + mc_ifc_work); + + mutex_lock(&idev->mc_lock); + mld_send_cr(idev); + + if (idev->mc_ifc_count) { + idev->mc_ifc_count--; + if (idev->mc_ifc_count) + mld_ifc_start_work(idev, + unsolicited_report_interval(idev)); + } + mutex_unlock(&idev->mc_lock); + in6_dev_put(idev); +} + +/* called with mc_lock */ +static void mld_ifc_event(struct inet6_dev *idev) +{ + if (mld_in_v1_mode(idev)) + return; + + idev->mc_ifc_count = idev->mc_qrv; + mld_ifc_start_work(idev, 1); +} + +static void mld_mca_work(struct work_struct *work) +{ + struct ifmcaddr6 *ma = container_of(to_delayed_work(work), + struct ifmcaddr6, mca_work); + + mutex_lock(&ma->idev->mc_lock); + if (mld_in_v1_mode(ma->idev)) + igmp6_send(&ma->mca_addr, ma->idev->dev, ICMPV6_MGM_REPORT); + else + mld_send_report(ma->idev, ma); + ma->mca_flags |= MAF_LAST_REPORTER; + ma->mca_flags &= ~MAF_TIMER_RUNNING; + mutex_unlock(&ma->idev->mc_lock); + + ma_put(ma); +} + +/* Device changing type */ + +void ipv6_mc_unmap(struct inet6_dev *idev) +{ + struct ifmcaddr6 *i; + + /* Install multicast list, except for all-nodes (already installed) */ + + mutex_lock(&idev->mc_lock); + for_each_mc_mclock(idev, i) + igmp6_group_dropped(i); + mutex_unlock(&idev->mc_lock); +} + +void ipv6_mc_remap(struct inet6_dev *idev) +{ + ipv6_mc_up(idev); +} + +/* Device going down */ +void ipv6_mc_down(struct inet6_dev *idev) +{ + struct ifmcaddr6 *i; + + mutex_lock(&idev->mc_lock); + /* Withdraw multicast list */ + for_each_mc_mclock(idev, i) + igmp6_group_dropped(i); + mutex_unlock(&idev->mc_lock); + + /* Should stop work after group drop. or we will + * start work again in mld_ifc_event() + */ + synchronize_net(); + mld_query_stop_work(idev); + mld_report_stop_work(idev); + + mutex_lock(&idev->mc_lock); + mld_ifc_stop_work(idev); + mld_gq_stop_work(idev); + mutex_unlock(&idev->mc_lock); + + mld_dad_stop_work(idev); +} + +static void ipv6_mc_reset(struct inet6_dev *idev) +{ + idev->mc_qrv = sysctl_mld_qrv; + idev->mc_qi = MLD_QI_DEFAULT; + idev->mc_qri = MLD_QRI_DEFAULT; + idev->mc_v1_seen = 0; + idev->mc_maxdelay = unsolicited_report_interval(idev); +} + +/* Device going up */ + +void ipv6_mc_up(struct inet6_dev *idev) +{ + struct ifmcaddr6 *i; + + /* Install multicast list, except for all-nodes (already installed) */ + + ipv6_mc_reset(idev); + mutex_lock(&idev->mc_lock); + for_each_mc_mclock(idev, i) { + mld_del_delrec(idev, i); + igmp6_group_added(i); + } + mutex_unlock(&idev->mc_lock); +} + +/* IPv6 device initialization. */ + +void ipv6_mc_init_dev(struct inet6_dev *idev) +{ + idev->mc_gq_running = 0; + INIT_DELAYED_WORK(&idev->mc_gq_work, mld_gq_work); + RCU_INIT_POINTER(idev->mc_tomb, NULL); + idev->mc_ifc_count = 0; + INIT_DELAYED_WORK(&idev->mc_ifc_work, mld_ifc_work); + INIT_DELAYED_WORK(&idev->mc_dad_work, mld_dad_work); + INIT_DELAYED_WORK(&idev->mc_query_work, mld_query_work); + INIT_DELAYED_WORK(&idev->mc_report_work, mld_report_work); + skb_queue_head_init(&idev->mc_query_queue); + skb_queue_head_init(&idev->mc_report_queue); + spin_lock_init(&idev->mc_query_lock); + spin_lock_init(&idev->mc_report_lock); + mutex_init(&idev->mc_lock); + ipv6_mc_reset(idev); +} + +/* + * Device is about to be destroyed: clean up. 
+ */ + +void ipv6_mc_destroy_dev(struct inet6_dev *idev) +{ + struct ifmcaddr6 *i; + + /* Deactivate works */ + ipv6_mc_down(idev); + mutex_lock(&idev->mc_lock); + mld_clear_delrec(idev); + mutex_unlock(&idev->mc_lock); + mld_clear_query(idev); + mld_clear_report(idev); + + /* Delete all-nodes address. */ + /* We cannot call ipv6_dev_mc_dec() directly, our caller in + * addrconf.c has NULL'd out dev->ip6_ptr so in6_dev_get() will + * fail. + */ + __ipv6_dev_mc_dec(idev, &in6addr_linklocal_allnodes); + + if (idev->cnf.forwarding) + __ipv6_dev_mc_dec(idev, &in6addr_linklocal_allrouters); + + mutex_lock(&idev->mc_lock); + while ((i = mc_dereference(idev->mc_list, idev))) { + rcu_assign_pointer(idev->mc_list, mc_dereference(i->next, idev)); + + ip6_mc_clear_src(i); + ma_put(i); + } + mutex_unlock(&idev->mc_lock); +} + +static void ipv6_mc_rejoin_groups(struct inet6_dev *idev) +{ + struct ifmcaddr6 *pmc; + + ASSERT_RTNL(); + + mutex_lock(&idev->mc_lock); + if (mld_in_v1_mode(idev)) { + for_each_mc_mclock(idev, pmc) + igmp6_join_group(pmc); + } else { + mld_send_report(idev, NULL); + } + mutex_unlock(&idev->mc_lock); +} + +static int ipv6_mc_netdev_event(struct notifier_block *this, + unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct inet6_dev *idev = __in6_dev_get(dev); + + switch (event) { + case NETDEV_RESEND_IGMP: + if (idev) + ipv6_mc_rejoin_groups(idev); + break; + default: + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block igmp6_netdev_notifier = { + .notifier_call = ipv6_mc_netdev_event, +}; + +#ifdef CONFIG_PROC_FS +struct igmp6_mc_iter_state { + struct seq_net_private p; + struct net_device *dev; + struct inet6_dev *idev; +}; + +#define igmp6_mc_seq_private(seq) ((struct igmp6_mc_iter_state *)(seq)->private) + +static inline struct ifmcaddr6 *igmp6_mc_get_first(struct seq_file *seq) +{ + struct ifmcaddr6 *im = NULL; + struct igmp6_mc_iter_state *state = igmp6_mc_seq_private(seq); + struct net *net = seq_file_net(seq); + + state->idev = NULL; + for_each_netdev_rcu(net, state->dev) { + struct inet6_dev *idev; + idev = __in6_dev_get(state->dev); + if (!idev) + continue; + + im = rcu_dereference(idev->mc_list); + if (im) { + state->idev = idev; + break; + } + } + return im; +} + +static struct ifmcaddr6 *igmp6_mc_get_next(struct seq_file *seq, struct ifmcaddr6 *im) +{ + struct igmp6_mc_iter_state *state = igmp6_mc_seq_private(seq); + + im = rcu_dereference(im->next); + while (!im) { + state->dev = next_net_device_rcu(state->dev); + if (!state->dev) { + state->idev = NULL; + break; + } + state->idev = __in6_dev_get(state->dev); + if (!state->idev) + continue; + im = rcu_dereference(state->idev->mc_list); + } + return im; +} + +static struct ifmcaddr6 *igmp6_mc_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ifmcaddr6 *im = igmp6_mc_get_first(seq); + if (im) + while (pos && (im = igmp6_mc_get_next(seq, im)) != NULL) + --pos; + return pos ? 
NULL : im; +} + +static void *igmp6_mc_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return igmp6_mc_get_idx(seq, *pos); +} + +static void *igmp6_mc_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ifmcaddr6 *im = igmp6_mc_get_next(seq, v); + + ++*pos; + return im; +} + +static void igmp6_mc_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + struct igmp6_mc_iter_state *state = igmp6_mc_seq_private(seq); + + if (likely(state->idev)) + state->idev = NULL; + state->dev = NULL; + rcu_read_unlock(); +} + +static int igmp6_mc_seq_show(struct seq_file *seq, void *v) +{ + struct ifmcaddr6 *im = (struct ifmcaddr6 *)v; + struct igmp6_mc_iter_state *state = igmp6_mc_seq_private(seq); + + seq_printf(seq, + "%-4d %-15s %pi6 %5d %08X %ld\n", + state->dev->ifindex, state->dev->name, + &im->mca_addr, + im->mca_users, im->mca_flags, + (im->mca_flags & MAF_TIMER_RUNNING) ? + jiffies_to_clock_t(im->mca_work.timer.expires - jiffies) : 0); + return 0; +} + +static const struct seq_operations igmp6_mc_seq_ops = { + .start = igmp6_mc_seq_start, + .next = igmp6_mc_seq_next, + .stop = igmp6_mc_seq_stop, + .show = igmp6_mc_seq_show, +}; + +struct igmp6_mcf_iter_state { + struct seq_net_private p; + struct net_device *dev; + struct inet6_dev *idev; + struct ifmcaddr6 *im; +}; + +#define igmp6_mcf_seq_private(seq) ((struct igmp6_mcf_iter_state *)(seq)->private) + +static inline struct ip6_sf_list *igmp6_mcf_get_first(struct seq_file *seq) +{ + struct ip6_sf_list *psf = NULL; + struct ifmcaddr6 *im = NULL; + struct igmp6_mcf_iter_state *state = igmp6_mcf_seq_private(seq); + struct net *net = seq_file_net(seq); + + state->idev = NULL; + state->im = NULL; + for_each_netdev_rcu(net, state->dev) { + struct inet6_dev *idev; + idev = __in6_dev_get(state->dev); + if (unlikely(idev == NULL)) + continue; + + im = rcu_dereference(idev->mc_list); + if (likely(im)) { + psf = rcu_dereference(im->mca_sources); + if (likely(psf)) { + state->im = im; + state->idev = idev; + break; + } + } + } + return psf; +} + +static struct ip6_sf_list *igmp6_mcf_get_next(struct seq_file *seq, struct ip6_sf_list *psf) +{ + struct igmp6_mcf_iter_state *state = igmp6_mcf_seq_private(seq); + + psf = rcu_dereference(psf->sf_next); + while (!psf) { + state->im = rcu_dereference(state->im->next); + while (!state->im) { + state->dev = next_net_device_rcu(state->dev); + if (!state->dev) { + state->idev = NULL; + goto out; + } + state->idev = __in6_dev_get(state->dev); + if (!state->idev) + continue; + state->im = rcu_dereference(state->idev->mc_list); + } + if (!state->im) + break; + psf = rcu_dereference(state->im->mca_sources); + } +out: + return psf; +} + +static struct ip6_sf_list *igmp6_mcf_get_idx(struct seq_file *seq, loff_t pos) +{ + struct ip6_sf_list *psf = igmp6_mcf_get_first(seq); + if (psf) + while (pos && (psf = igmp6_mcf_get_next(seq, psf)) != NULL) + --pos; + return pos ? NULL : psf; +} + +static void *igmp6_mcf_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return *pos ? 
igmp6_mcf_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *igmp6_mcf_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip6_sf_list *psf; + if (v == SEQ_START_TOKEN) + psf = igmp6_mcf_get_first(seq); + else + psf = igmp6_mcf_get_next(seq, v); + ++*pos; + return psf; +} + +static void igmp6_mcf_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + struct igmp6_mcf_iter_state *state = igmp6_mcf_seq_private(seq); + + if (likely(state->im)) + state->im = NULL; + if (likely(state->idev)) + state->idev = NULL; + + state->dev = NULL; + rcu_read_unlock(); +} + +static int igmp6_mcf_seq_show(struct seq_file *seq, void *v) +{ + struct ip6_sf_list *psf = (struct ip6_sf_list *)v; + struct igmp6_mcf_iter_state *state = igmp6_mcf_seq_private(seq); + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Idx Device Multicast Address Source Address INC EXC\n"); + } else { + seq_printf(seq, + "%3d %6.6s %pi6 %pi6 %6lu %6lu\n", + state->dev->ifindex, state->dev->name, + &state->im->mca_addr, + &psf->sf_addr, + psf->sf_count[MCAST_INCLUDE], + psf->sf_count[MCAST_EXCLUDE]); + } + return 0; +} + +static const struct seq_operations igmp6_mcf_seq_ops = { + .start = igmp6_mcf_seq_start, + .next = igmp6_mcf_seq_next, + .stop = igmp6_mcf_seq_stop, + .show = igmp6_mcf_seq_show, +}; + +static int __net_init igmp6_proc_init(struct net *net) +{ + int err; + + err = -ENOMEM; + if (!proc_create_net("igmp6", 0444, net->proc_net, &igmp6_mc_seq_ops, + sizeof(struct igmp6_mc_iter_state))) + goto out; + if (!proc_create_net("mcfilter6", 0444, net->proc_net, + &igmp6_mcf_seq_ops, + sizeof(struct igmp6_mcf_iter_state))) + goto out_proc_net_igmp6; + + err = 0; +out: + return err; + +out_proc_net_igmp6: + remove_proc_entry("igmp6", net->proc_net); + goto out; +} + +static void __net_exit igmp6_proc_exit(struct net *net) +{ + remove_proc_entry("mcfilter6", net->proc_net); + remove_proc_entry("igmp6", net->proc_net); +} +#else +static inline int igmp6_proc_init(struct net *net) +{ + return 0; +} +static inline void igmp6_proc_exit(struct net *net) +{ +} +#endif + +static int __net_init igmp6_net_init(struct net *net) +{ + int err; + + err = inet_ctl_sock_create(&net->ipv6.igmp_sk, PF_INET6, + SOCK_RAW, IPPROTO_ICMPV6, net); + if (err < 0) { + pr_err("Failed to initialize the IGMP6 control socket (err %d)\n", + err); + goto out; + } + + inet6_sk(net->ipv6.igmp_sk)->hop_limit = 1; + net->ipv6.igmp_sk->sk_allocation = GFP_KERNEL; + + err = inet_ctl_sock_create(&net->ipv6.mc_autojoin_sk, PF_INET6, + SOCK_RAW, IPPROTO_ICMPV6, net); + if (err < 0) { + pr_err("Failed to initialize the IGMP6 autojoin socket (err %d)\n", + err); + goto out_sock_create; + } + + err = igmp6_proc_init(net); + if (err) + goto out_sock_create_autojoin; + + return 0; + +out_sock_create_autojoin: + inet_ctl_sock_destroy(net->ipv6.mc_autojoin_sk); +out_sock_create: + inet_ctl_sock_destroy(net->ipv6.igmp_sk); +out: + return err; +} + +static void __net_exit igmp6_net_exit(struct net *net) +{ + inet_ctl_sock_destroy(net->ipv6.igmp_sk); + inet_ctl_sock_destroy(net->ipv6.mc_autojoin_sk); + igmp6_proc_exit(net); +} + +static struct pernet_operations igmp6_net_ops = { + .init = igmp6_net_init, + .exit = igmp6_net_exit, +}; + +int __init igmp6_init(void) +{ + int err; + + err = register_pernet_subsys(&igmp6_net_ops); + if (err) + return err; + + mld_wq = create_workqueue("mld"); + if (!mld_wq) { + unregister_pernet_subsys(&igmp6_net_ops); + return -ENOMEM; + } + + return err; +} + +int __init igmp6_late_init(void) +{ + return 
register_netdevice_notifier(&igmp6_netdev_notifier); +} + +void igmp6_cleanup(void) +{ + unregister_pernet_subsys(&igmp6_net_ops); + destroy_workqueue(mld_wq); +} + +void igmp6_late_cleanup(void) +{ + unregister_netdevice_notifier(&igmp6_netdev_notifier); +} diff --git a/net/ipv6/mcast_snoop.c b/net/ipv6/mcast_snoop.c new file mode 100644 index 000000000..04d5fcdfa --- /dev/null +++ b/net/ipv6/mcast_snoop.c @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2010: YOSHIFUJI Hideaki + * Copyright (C) 2015: Linus Lüssing + * + * Based on the MLD support added to br_multicast.c by YOSHIFUJI Hideaki. + */ + +#include +#include +#include +#include +#include + +static int ipv6_mc_check_ip6hdr(struct sk_buff *skb) +{ + const struct ipv6hdr *ip6h; + unsigned int len; + unsigned int offset = skb_network_offset(skb) + sizeof(*ip6h); + + if (!pskb_may_pull(skb, offset)) + return -EINVAL; + + ip6h = ipv6_hdr(skb); + + if (ip6h->version != 6) + return -EINVAL; + + len = offset + ntohs(ip6h->payload_len); + if (skb->len < len || len <= offset) + return -EINVAL; + + skb_set_transport_header(skb, offset); + + return 0; +} + +static int ipv6_mc_check_exthdrs(struct sk_buff *skb) +{ + const struct ipv6hdr *ip6h; + int offset; + u8 nexthdr; + __be16 frag_off; + + ip6h = ipv6_hdr(skb); + + if (ip6h->nexthdr != IPPROTO_HOPOPTS) + return -ENOMSG; + + nexthdr = ip6h->nexthdr; + offset = skb_network_offset(skb) + sizeof(*ip6h); + offset = ipv6_skip_exthdr(skb, offset, &nexthdr, &frag_off); + + if (offset < 0) + return -EINVAL; + + if (nexthdr != IPPROTO_ICMPV6) + return -ENOMSG; + + skb_set_transport_header(skb, offset); + + return 0; +} + +static int ipv6_mc_check_mld_reportv2(struct sk_buff *skb) +{ + unsigned int len = skb_transport_offset(skb); + + len += sizeof(struct mld2_report); + + return ipv6_mc_may_pull(skb, len) ? 0 : -EINVAL; +} + +static int ipv6_mc_check_mld_query(struct sk_buff *skb) +{ + unsigned int transport_len = ipv6_transport_len(skb); + struct mld_msg *mld; + unsigned int len; + + /* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */ + if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) + return -EINVAL; + + /* MLDv1? */ + if (transport_len != sizeof(struct mld_msg)) { + /* or MLDv2? 
*/ + if (transport_len < sizeof(struct mld2_query)) + return -EINVAL; + + len = skb_transport_offset(skb) + sizeof(struct mld2_query); + if (!ipv6_mc_may_pull(skb, len)) + return -EINVAL; + } + + mld = (struct mld_msg *)skb_transport_header(skb); + + /* RFC2710+RFC3810 (MLDv1+MLDv2) require the multicast link layer + * all-nodes destination address (ff02::1) for general queries + */ + if (ipv6_addr_any(&mld->mld_mca) && + !ipv6_addr_is_ll_all_nodes(&ipv6_hdr(skb)->daddr)) + return -EINVAL; + + return 0; +} + +static int ipv6_mc_check_mld_msg(struct sk_buff *skb) +{ + unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); + struct mld_msg *mld; + + if (!ipv6_mc_may_pull(skb, len)) + return -ENODATA; + + mld = (struct mld_msg *)skb_transport_header(skb); + + switch (mld->mld_type) { + case ICMPV6_MGM_REDUCTION: + case ICMPV6_MGM_REPORT: + return 0; + case ICMPV6_MLD2_REPORT: + return ipv6_mc_check_mld_reportv2(skb); + case ICMPV6_MGM_QUERY: + return ipv6_mc_check_mld_query(skb); + default: + return -ENODATA; + } +} + +static inline __sum16 ipv6_mc_validate_checksum(struct sk_buff *skb) +{ + return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo); +} + +static int ipv6_mc_check_icmpv6(struct sk_buff *skb) +{ + unsigned int len = skb_transport_offset(skb) + sizeof(struct icmp6hdr); + unsigned int transport_len = ipv6_transport_len(skb); + struct sk_buff *skb_chk; + + if (!ipv6_mc_may_pull(skb, len)) + return -EINVAL; + + skb_chk = skb_checksum_trimmed(skb, transport_len, + ipv6_mc_validate_checksum); + if (!skb_chk) + return -EINVAL; + + if (skb_chk != skb) + kfree_skb(skb_chk); + + return 0; +} + +/** + * ipv6_mc_check_mld - checks whether this is a sane MLD packet + * @skb: the skb to validate + * + * Checks whether an IPv6 packet is a valid MLD packet. If so sets + * skb transport header accordingly and returns zero. + * + * -EINVAL: A broken packet was detected, i.e. it violates some internet + * standard + * -ENOMSG: IP header validation succeeded but it is not an ICMPv6 packet + * with a hop-by-hop option. + * -ENODATA: IP+ICMPv6 header with hop-by-hop option validation succeeded + * but it is not an MLD packet. + * -ENOMEM: A memory allocation failure happened. + * + * Caller needs to set the skb network header and free any returned skb if it + * differs from the provided skb. 
+ */ +int ipv6_mc_check_mld(struct sk_buff *skb) +{ + int ret; + + ret = ipv6_mc_check_ip6hdr(skb); + if (ret < 0) + return ret; + + ret = ipv6_mc_check_exthdrs(skb); + if (ret < 0) + return ret; + + ret = ipv6_mc_check_icmpv6(skb); + if (ret < 0) + return ret; + + return ipv6_mc_check_mld_msg(skb); +} +EXPORT_SYMBOL(ipv6_mc_check_mld); diff --git a/net/ipv6/mip6.c b/net/ipv6/mip6.c new file mode 100644 index 000000000..83d2a8be2 --- /dev/null +++ b/net/ipv6/mip6.c @@ -0,0 +1,410 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C)2003-2006 Helsinki University of Technology + * Copyright (C)2003-2006 USAGI/WIDE Project + */ +/* + * Authors: + * Noriaki TAKAMIYA @USAGI + * Masahide NAKAMURA @USAGI + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static inline unsigned int calc_padlen(unsigned int len, unsigned int n) +{ + return (n - len + 16) & 0x7; +} + +static inline void *mip6_padn(__u8 *data, __u8 padlen) +{ + if (!data) + return NULL; + if (padlen == 1) { + data[0] = IPV6_TLV_PAD1; + } else if (padlen > 1) { + data[0] = IPV6_TLV_PADN; + data[1] = padlen - 2; + if (padlen > 2) + memset(data+2, 0, data[1]); + } + return data + padlen; +} + +static inline void mip6_param_prob(struct sk_buff *skb, u8 code, int pos) +{ + icmpv6_send(skb, ICMPV6_PARAMPROB, code, pos); +} + +static int mip6_mh_len(int type) +{ + int len = 0; + + switch (type) { + case IP6_MH_TYPE_BRR: + len = 0; + break; + case IP6_MH_TYPE_HOTI: + case IP6_MH_TYPE_COTI: + case IP6_MH_TYPE_BU: + case IP6_MH_TYPE_BACK: + len = 1; + break; + case IP6_MH_TYPE_HOT: + case IP6_MH_TYPE_COT: + case IP6_MH_TYPE_BERROR: + len = 2; + break; + } + return len; +} + +static int mip6_mh_filter(struct sock *sk, struct sk_buff *skb) +{ + struct ip6_mh _hdr; + const struct ip6_mh *mh; + + mh = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_hdr), &_hdr); + if (!mh) + return -1; + + if (((mh->ip6mh_hdrlen + 1) << 3) > skb->len) + return -1; + + if (mh->ip6mh_hdrlen < mip6_mh_len(mh->ip6mh_type)) { + net_dbg_ratelimited("mip6: MH message too short: %d vs >=%d\n", + mh->ip6mh_hdrlen, + mip6_mh_len(mh->ip6mh_type)); + mip6_param_prob(skb, 0, offsetof(struct ip6_mh, ip6mh_hdrlen) + + skb_network_header_len(skb)); + return -1; + } + + if (mh->ip6mh_proto != IPPROTO_NONE) { + net_dbg_ratelimited("mip6: MH invalid payload proto = %d\n", + mh->ip6mh_proto); + mip6_param_prob(skb, 0, offsetof(struct ip6_mh, ip6mh_proto) + + skb_network_header_len(skb)); + return -1; + } + + return 0; +} + +struct mip6_report_rate_limiter { + spinlock_t lock; + ktime_t stamp; + int iif; + struct in6_addr src; + struct in6_addr dst; +}; + +static struct mip6_report_rate_limiter mip6_report_rl = { + .lock = __SPIN_LOCK_UNLOCKED(mip6_report_rl.lock) +}; + +static int mip6_destopt_input(struct xfrm_state *x, struct sk_buff *skb) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct ipv6_destopt_hdr *destopt = (struct ipv6_destopt_hdr *)skb->data; + int err = destopt->nexthdr; + + spin_lock(&x->lock); + if (!ipv6_addr_equal(&iph->saddr, (struct in6_addr *)x->coaddr) && + !ipv6_addr_any((struct in6_addr *)x->coaddr)) + err = -ENOENT; + spin_unlock(&x->lock); + + return err; +} + +/* Destination Option Header is inserted. + * IP Header's src address is replaced with Home Address Option in + * Destination Option Header. 
+ */ +static int mip6_destopt_output(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipv6hdr *iph; + struct ipv6_destopt_hdr *dstopt; + struct ipv6_destopt_hao *hao; + u8 nexthdr; + int len; + + skb_push(skb, -skb_network_offset(skb)); + iph = ipv6_hdr(skb); + + nexthdr = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_DSTOPTS; + + dstopt = (struct ipv6_destopt_hdr *)skb_transport_header(skb); + dstopt->nexthdr = nexthdr; + + hao = mip6_padn((char *)(dstopt + 1), + calc_padlen(sizeof(*dstopt), 6)); + + hao->type = IPV6_TLV_HAO; + BUILD_BUG_ON(sizeof(*hao) != 18); + hao->length = sizeof(*hao) - 2; + + len = ((char *)hao - (char *)dstopt) + sizeof(*hao); + + memcpy(&hao->addr, &iph->saddr, sizeof(hao->addr)); + spin_lock_bh(&x->lock); + memcpy(&iph->saddr, x->coaddr, sizeof(iph->saddr)); + spin_unlock_bh(&x->lock); + + WARN_ON(len != x->props.header_len); + dstopt->hdrlen = (x->props.header_len >> 3) - 1; + + return 0; +} + +static inline int mip6_report_rl_allow(ktime_t stamp, + const struct in6_addr *dst, + const struct in6_addr *src, int iif) +{ + int allow = 0; + + spin_lock_bh(&mip6_report_rl.lock); + if (mip6_report_rl.stamp != stamp || + mip6_report_rl.iif != iif || + !ipv6_addr_equal(&mip6_report_rl.src, src) || + !ipv6_addr_equal(&mip6_report_rl.dst, dst)) { + mip6_report_rl.stamp = stamp; + mip6_report_rl.iif = iif; + mip6_report_rl.src = *src; + mip6_report_rl.dst = *dst; + allow = 1; + } + spin_unlock_bh(&mip6_report_rl.lock); + return allow; +} + +static int mip6_destopt_reject(struct xfrm_state *x, struct sk_buff *skb, + const struct flowi *fl) +{ + struct net *net = xs_net(x); + struct inet6_skb_parm *opt = (struct inet6_skb_parm *)skb->cb; + const struct flowi6 *fl6 = &fl->u.ip6; + struct ipv6_destopt_hao *hao = NULL; + struct xfrm_selector sel; + int offset; + ktime_t stamp; + int err = 0; + + if (unlikely(fl6->flowi6_proto == IPPROTO_MH && + fl6->fl6_mh_type <= IP6_MH_TYPE_MAX)) + goto out; + + if (likely(opt->dsthao)) { + offset = ipv6_find_tlv(skb, opt->dsthao, IPV6_TLV_HAO); + if (likely(offset >= 0)) + hao = (struct ipv6_destopt_hao *) + (skb_network_header(skb) + offset); + } + + stamp = skb_get_ktime(skb); + + if (!mip6_report_rl_allow(stamp, &ipv6_hdr(skb)->daddr, + hao ? &hao->addr : &ipv6_hdr(skb)->saddr, + opt->iif)) + goto out; + + memset(&sel, 0, sizeof(sel)); + memcpy(&sel.daddr, (xfrm_address_t *)&ipv6_hdr(skb)->daddr, + sizeof(sel.daddr)); + sel.prefixlen_d = 128; + memcpy(&sel.saddr, (xfrm_address_t *)&ipv6_hdr(skb)->saddr, + sizeof(sel.saddr)); + sel.prefixlen_s = 128; + sel.family = AF_INET6; + sel.proto = fl6->flowi6_proto; + sel.dport = xfrm_flowi_dport(fl, &fl6->uli); + if (sel.dport) + sel.dport_mask = htons(~0); + sel.sport = xfrm_flowi_sport(fl, &fl6->uli); + if (sel.sport) + sel.sport_mask = htons(~0); + sel.ifindex = fl6->flowi6_oif; + + err = km_report(net, IPPROTO_DSTOPTS, &sel, + (hao ? 
(xfrm_address_t *)&hao->addr : NULL)); + + out: + return err; +} + +static int mip6_destopt_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + if (x->id.spi) { + NL_SET_ERR_MSG(extack, "SPI must be 0"); + return -EINVAL; + } + if (x->props.mode != XFRM_MODE_ROUTEOPTIMIZATION) { + NL_SET_ERR_MSG(extack, "XFRM mode must be XFRM_MODE_ROUTEOPTIMIZATION"); + return -EINVAL; + } + + x->props.header_len = sizeof(struct ipv6_destopt_hdr) + + calc_padlen(sizeof(struct ipv6_destopt_hdr), 6) + + sizeof(struct ipv6_destopt_hao); + WARN_ON(x->props.header_len != 24); + + return 0; +} + +/* + * Do nothing about destroying since it has no specific operation for + * destination options header unlike IPsec protocols. + */ +static void mip6_destopt_destroy(struct xfrm_state *x) +{ +} + +static const struct xfrm_type mip6_destopt_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_DSTOPTS, + .flags = XFRM_TYPE_NON_FRAGMENT | XFRM_TYPE_LOCAL_COADDR, + .init_state = mip6_destopt_init_state, + .destructor = mip6_destopt_destroy, + .input = mip6_destopt_input, + .output = mip6_destopt_output, + .reject = mip6_destopt_reject, +}; + +static int mip6_rthdr_input(struct xfrm_state *x, struct sk_buff *skb) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct rt2_hdr *rt2 = (struct rt2_hdr *)skb->data; + int err = rt2->rt_hdr.nexthdr; + + spin_lock(&x->lock); + if (!ipv6_addr_equal(&iph->daddr, (struct in6_addr *)x->coaddr) && + !ipv6_addr_any((struct in6_addr *)x->coaddr)) + err = -ENOENT; + spin_unlock(&x->lock); + + return err; +} + +/* Routing Header type 2 is inserted. + * IP Header's dst address is replaced with Routing Header's Home Address. + */ +static int mip6_rthdr_output(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipv6hdr *iph; + struct rt2_hdr *rt2; + u8 nexthdr; + + skb_push(skb, -skb_network_offset(skb)); + iph = ipv6_hdr(skb); + + nexthdr = *skb_mac_header(skb); + *skb_mac_header(skb) = IPPROTO_ROUTING; + + rt2 = (struct rt2_hdr *)skb_transport_header(skb); + rt2->rt_hdr.nexthdr = nexthdr; + rt2->rt_hdr.hdrlen = (x->props.header_len >> 3) - 1; + rt2->rt_hdr.type = IPV6_SRCRT_TYPE_2; + rt2->rt_hdr.segments_left = 1; + memset(&rt2->reserved, 0, sizeof(rt2->reserved)); + + WARN_ON(rt2->rt_hdr.hdrlen != 2); + + memcpy(&rt2->addr, &iph->daddr, sizeof(rt2->addr)); + spin_lock_bh(&x->lock); + memcpy(&iph->daddr, x->coaddr, sizeof(iph->daddr)); + spin_unlock_bh(&x->lock); + + return 0; +} + +static int mip6_rthdr_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + if (x->id.spi) { + NL_SET_ERR_MSG(extack, "SPI must be 0"); + return -EINVAL; + } + if (x->props.mode != XFRM_MODE_ROUTEOPTIMIZATION) { + NL_SET_ERR_MSG(extack, "XFRM mode must be XFRM_MODE_ROUTEOPTIMIZATION"); + return -EINVAL; + } + + x->props.header_len = sizeof(struct rt2_hdr); + + return 0; +} + +/* + * Do nothing about destroying since it has no specific operation for routing + * header type 2 unlike IPsec protocols. 
+ */ +static void mip6_rthdr_destroy(struct xfrm_state *x) +{ +} + +static const struct xfrm_type mip6_rthdr_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_ROUTING, + .flags = XFRM_TYPE_NON_FRAGMENT | XFRM_TYPE_REMOTE_COADDR, + .init_state = mip6_rthdr_init_state, + .destructor = mip6_rthdr_destroy, + .input = mip6_rthdr_input, + .output = mip6_rthdr_output, +}; + +static int __init mip6_init(void) +{ + pr_info("Mobile IPv6\n"); + + if (xfrm_register_type(&mip6_destopt_type, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type(destopt)\n", __func__); + goto mip6_destopt_xfrm_fail; + } + if (xfrm_register_type(&mip6_rthdr_type, AF_INET6) < 0) { + pr_info("%s: can't add xfrm type(rthdr)\n", __func__); + goto mip6_rthdr_xfrm_fail; + } + if (rawv6_mh_filter_register(mip6_mh_filter) < 0) { + pr_info("%s: can't add rawv6 mh filter\n", __func__); + goto mip6_rawv6_mh_fail; + } + + + return 0; + + mip6_rawv6_mh_fail: + xfrm_unregister_type(&mip6_rthdr_type, AF_INET6); + mip6_rthdr_xfrm_fail: + xfrm_unregister_type(&mip6_destopt_type, AF_INET6); + mip6_destopt_xfrm_fail: + return -EAGAIN; +} + +static void __exit mip6_fini(void) +{ + if (rawv6_mh_filter_unregister(mip6_mh_filter) < 0) + pr_info("%s: can't remove rawv6 mh filter\n", __func__); + xfrm_unregister_type(&mip6_rthdr_type, AF_INET6); + xfrm_unregister_type(&mip6_destopt_type, AF_INET6); +} + +module_init(mip6_init); +module_exit(mip6_fini); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_DSTOPTS); +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_ROUTING); diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c new file mode 100644 index 000000000..8c5a99fe6 --- /dev/null +++ b/net/ipv6/ndisc.c @@ -0,0 +1,2061 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Neighbour Discovery for IPv6 + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * Mike Shaver + */ + +/* + * Changes: + * + * Alexey I. Froloff : RFC6106 (DNSSL) support + * Pierre Ynard : export userland ND options + * through netlink (RDNSS support) + * Lars Fenneberg : fixed MTU setting on receipt + * of an RA. + * Janos Farkas : kmalloc failure checks + * Alexey Kuznetsov : state machine reworked + * and moved to net/core. 
+ * Pekka Savola : RFC2461 validation + * YOSHIFUJI Hideaki @USAGI : Verify ND options properly + */ + +#define pr_fmt(fmt) "ICMPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include + +static u32 ndisc_hash(const void *pkey, + const struct net_device *dev, + __u32 *hash_rnd); +static bool ndisc_key_eq(const struct neighbour *neigh, const void *pkey); +static bool ndisc_allow_add(const struct net_device *dev, + struct netlink_ext_ack *extack); +static int ndisc_constructor(struct neighbour *neigh); +static void ndisc_solicit(struct neighbour *neigh, struct sk_buff *skb); +static void ndisc_error_report(struct neighbour *neigh, struct sk_buff *skb); +static int pndisc_constructor(struct pneigh_entry *n); +static void pndisc_destructor(struct pneigh_entry *n); +static void pndisc_redo(struct sk_buff *skb); +static int ndisc_is_multicast(const void *pkey); + +static const struct neigh_ops ndisc_generic_ops = { + .family = AF_INET6, + .solicit = ndisc_solicit, + .error_report = ndisc_error_report, + .output = neigh_resolve_output, + .connected_output = neigh_connected_output, +}; + +static const struct neigh_ops ndisc_hh_ops = { + .family = AF_INET6, + .solicit = ndisc_solicit, + .error_report = ndisc_error_report, + .output = neigh_resolve_output, + .connected_output = neigh_resolve_output, +}; + + +static const struct neigh_ops ndisc_direct_ops = { + .family = AF_INET6, + .output = neigh_direct_output, + .connected_output = neigh_direct_output, +}; + +struct neigh_table nd_tbl = { + .family = AF_INET6, + .key_len = sizeof(struct in6_addr), + .protocol = cpu_to_be16(ETH_P_IPV6), + .hash = ndisc_hash, + .key_eq = ndisc_key_eq, + .constructor = ndisc_constructor, + .pconstructor = pndisc_constructor, + .pdestructor = pndisc_destructor, + .proxy_redo = pndisc_redo, + .is_multicast = ndisc_is_multicast, + .allow_add = ndisc_allow_add, + .id = "ndisc_cache", + .parms = { + .tbl = &nd_tbl, + .reachable_time = ND_REACHABLE_TIME, + .data = { + [NEIGH_VAR_MCAST_PROBES] = 3, + [NEIGH_VAR_UCAST_PROBES] = 3, + [NEIGH_VAR_RETRANS_TIME] = ND_RETRANS_TIMER, + [NEIGH_VAR_BASE_REACHABLE_TIME] = ND_REACHABLE_TIME, + [NEIGH_VAR_DELAY_PROBE_TIME] = 5 * HZ, + [NEIGH_VAR_INTERVAL_PROBE_TIME_MS] = 5 * HZ, + [NEIGH_VAR_GC_STALETIME] = 60 * HZ, + [NEIGH_VAR_QUEUE_LEN_BYTES] = SK_WMEM_MAX, + [NEIGH_VAR_PROXY_QLEN] = 64, + [NEIGH_VAR_ANYCAST_DELAY] = 1 * HZ, + [NEIGH_VAR_PROXY_DELAY] = (8 * HZ) / 10, + }, + }, + .gc_interval = 30 * HZ, + .gc_thresh1 = 128, + .gc_thresh2 = 512, + .gc_thresh3 = 1024, +}; +EXPORT_SYMBOL_GPL(nd_tbl); + +void __ndisc_fill_addr_option(struct sk_buff *skb, int type, const void *data, + int data_len, int pad) +{ + int space = __ndisc_opt_addr_space(data_len, pad); + u8 *opt = skb_put(skb, space); + + opt[0] = type; + opt[1] = space>>3; + + memset(opt + 2, 0, pad); + opt += pad; + space -= pad; + + memcpy(opt+2, data, data_len); + data_len += 2; + opt += data_len; + space -= data_len; + if (space > 0) + memset(opt, 0, space); +} +EXPORT_SYMBOL_GPL(__ndisc_fill_addr_option); + +static inline void ndisc_fill_addr_option(struct sk_buff *skb, int type, + const void *data, u8 icmp6_type) +{ + __ndisc_fill_addr_option(skb, type, data, skb->dev->addr_len, + 
ndisc_addr_option_pad(skb->dev->type)); + ndisc_ops_fill_addr_option(skb->dev, skb, icmp6_type); +} + +static inline void ndisc_fill_redirect_addr_option(struct sk_buff *skb, + void *ha, + const u8 *ops_data) +{ + ndisc_fill_addr_option(skb, ND_OPT_TARGET_LL_ADDR, ha, NDISC_REDIRECT); + ndisc_ops_fill_redirect_addr_option(skb->dev, skb, ops_data); +} + +static struct nd_opt_hdr *ndisc_next_option(struct nd_opt_hdr *cur, + struct nd_opt_hdr *end) +{ + int type; + if (!cur || !end || cur >= end) + return NULL; + type = cur->nd_opt_type; + do { + cur = ((void *)cur) + (cur->nd_opt_len << 3); + } while (cur < end && cur->nd_opt_type != type); + return cur <= end && cur->nd_opt_type == type ? cur : NULL; +} + +static inline int ndisc_is_useropt(const struct net_device *dev, + struct nd_opt_hdr *opt) +{ + return opt->nd_opt_type == ND_OPT_PREFIX_INFO || + opt->nd_opt_type == ND_OPT_RDNSS || + opt->nd_opt_type == ND_OPT_DNSSL || + opt->nd_opt_type == ND_OPT_CAPTIVE_PORTAL || + opt->nd_opt_type == ND_OPT_PREF64 || + ndisc_ops_is_useropt(dev, opt->nd_opt_type); +} + +static struct nd_opt_hdr *ndisc_next_useropt(const struct net_device *dev, + struct nd_opt_hdr *cur, + struct nd_opt_hdr *end) +{ + if (!cur || !end || cur >= end) + return NULL; + do { + cur = ((void *)cur) + (cur->nd_opt_len << 3); + } while (cur < end && !ndisc_is_useropt(dev, cur)); + return cur <= end && ndisc_is_useropt(dev, cur) ? cur : NULL; +} + +struct ndisc_options *ndisc_parse_options(const struct net_device *dev, + u8 *opt, int opt_len, + struct ndisc_options *ndopts) +{ + struct nd_opt_hdr *nd_opt = (struct nd_opt_hdr *)opt; + + if (!nd_opt || opt_len < 0 || !ndopts) + return NULL; + memset(ndopts, 0, sizeof(*ndopts)); + while (opt_len) { + int l; + if (opt_len < sizeof(struct nd_opt_hdr)) + return NULL; + l = nd_opt->nd_opt_len << 3; + if (opt_len < l || l == 0) + return NULL; + if (ndisc_ops_parse_options(dev, nd_opt, ndopts)) + goto next_opt; + switch (nd_opt->nd_opt_type) { + case ND_OPT_SOURCE_LL_ADDR: + case ND_OPT_TARGET_LL_ADDR: + case ND_OPT_MTU: + case ND_OPT_NONCE: + case ND_OPT_REDIRECT_HDR: + if (ndopts->nd_opt_array[nd_opt->nd_opt_type]) { + ND_PRINTK(2, warn, + "%s: duplicated ND6 option found: type=%d\n", + __func__, nd_opt->nd_opt_type); + } else { + ndopts->nd_opt_array[nd_opt->nd_opt_type] = nd_opt; + } + break; + case ND_OPT_PREFIX_INFO: + ndopts->nd_opts_pi_end = nd_opt; + if (!ndopts->nd_opt_array[nd_opt->nd_opt_type]) + ndopts->nd_opt_array[nd_opt->nd_opt_type] = nd_opt; + break; +#ifdef CONFIG_IPV6_ROUTE_INFO + case ND_OPT_ROUTE_INFO: + ndopts->nd_opts_ri_end = nd_opt; + if (!ndopts->nd_opts_ri) + ndopts->nd_opts_ri = nd_opt; + break; +#endif + default: + if (ndisc_is_useropt(dev, nd_opt)) { + ndopts->nd_useropts_end = nd_opt; + if (!ndopts->nd_useropts) + ndopts->nd_useropts = nd_opt; + } else { + /* + * Unknown options must be silently ignored, + * to accommodate future extension to the + * protocol. + */ + ND_PRINTK(2, notice, + "%s: ignored unsupported option; type=%d, len=%d\n", + __func__, + nd_opt->nd_opt_type, + nd_opt->nd_opt_len); + } + } +next_opt: + opt_len -= l; + nd_opt = ((void *)nd_opt) + l; + } + return ndopts; +} + +int ndisc_mc_map(const struct in6_addr *addr, char *buf, struct net_device *dev, int dir) +{ + switch (dev->type) { + case ARPHRD_ETHER: + case ARPHRD_IEEE802: /* Not sure. Check it later. 
--ANK */ + case ARPHRD_FDDI: + ipv6_eth_mc_map(addr, buf); + return 0; + case ARPHRD_ARCNET: + ipv6_arcnet_mc_map(addr, buf); + return 0; + case ARPHRD_INFINIBAND: + ipv6_ib_mc_map(addr, dev->broadcast, buf); + return 0; + case ARPHRD_IPGRE: + return ipv6_ipgre_mc_map(addr, dev->broadcast, buf); + default: + if (dir) { + memcpy(buf, dev->broadcast, dev->addr_len); + return 0; + } + } + return -EINVAL; +} +EXPORT_SYMBOL(ndisc_mc_map); + +static u32 ndisc_hash(const void *pkey, + const struct net_device *dev, + __u32 *hash_rnd) +{ + return ndisc_hashfn(pkey, dev, hash_rnd); +} + +static bool ndisc_key_eq(const struct neighbour *n, const void *pkey) +{ + return neigh_key_eq128(n, pkey); +} + +static int ndisc_constructor(struct neighbour *neigh) +{ + struct in6_addr *addr = (struct in6_addr *)&neigh->primary_key; + struct net_device *dev = neigh->dev; + struct inet6_dev *in6_dev; + struct neigh_parms *parms; + bool is_multicast = ipv6_addr_is_multicast(addr); + + in6_dev = in6_dev_get(dev); + if (!in6_dev) { + return -EINVAL; + } + + parms = in6_dev->nd_parms; + __neigh_parms_put(neigh->parms); + neigh->parms = neigh_parms_clone(parms); + + neigh->type = is_multicast ? RTN_MULTICAST : RTN_UNICAST; + if (!dev->header_ops) { + neigh->nud_state = NUD_NOARP; + neigh->ops = &ndisc_direct_ops; + neigh->output = neigh_direct_output; + } else { + if (is_multicast) { + neigh->nud_state = NUD_NOARP; + ndisc_mc_map(addr, neigh->ha, dev, 1); + } else if (dev->flags&(IFF_NOARP|IFF_LOOPBACK)) { + neigh->nud_state = NUD_NOARP; + memcpy(neigh->ha, dev->dev_addr, dev->addr_len); + if (dev->flags&IFF_LOOPBACK) + neigh->type = RTN_LOCAL; + } else if (dev->flags&IFF_POINTOPOINT) { + neigh->nud_state = NUD_NOARP; + memcpy(neigh->ha, dev->broadcast, dev->addr_len); + } + if (dev->header_ops->cache) + neigh->ops = &ndisc_hh_ops; + else + neigh->ops = &ndisc_generic_ops; + if (neigh->nud_state&NUD_VALID) + neigh->output = neigh->ops->connected_output; + else + neigh->output = neigh->ops->output; + } + in6_dev_put(in6_dev); + return 0; +} + +static int pndisc_constructor(struct pneigh_entry *n) +{ + struct in6_addr *addr = (struct in6_addr *)&n->key; + struct in6_addr maddr; + struct net_device *dev = n->dev; + + if (!dev || !__in6_dev_get(dev)) + return -EINVAL; + addrconf_addr_solict_mult(addr, &maddr); + ipv6_dev_mc_inc(dev, &maddr); + return 0; +} + +static void pndisc_destructor(struct pneigh_entry *n) +{ + struct in6_addr *addr = (struct in6_addr *)&n->key; + struct in6_addr maddr; + struct net_device *dev = n->dev; + + if (!dev || !__in6_dev_get(dev)) + return; + addrconf_addr_solict_mult(addr, &maddr); + ipv6_dev_mc_dec(dev, &maddr); +} + +/* called with rtnl held */ +static bool ndisc_allow_add(const struct net_device *dev, + struct netlink_ext_ack *extack) +{ + struct inet6_dev *idev = __in6_dev_get(dev); + + if (!idev || idev->cnf.disable_ipv6) { + NL_SET_ERR_MSG(extack, "IPv6 is disabled on this device"); + return false; + } + + return true; +} + +static struct sk_buff *ndisc_alloc_skb(struct net_device *dev, + int len) +{ + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + struct sock *sk = dev_net(dev)->ipv6.ndisc_sk; + struct sk_buff *skb; + + skb = alloc_skb(hlen + sizeof(struct ipv6hdr) + len + tlen, GFP_ATOMIC); + if (!skb) { + ND_PRINTK(0, err, "ndisc: %s failed to allocate an skb\n", + __func__); + return NULL; + } + + skb->protocol = htons(ETH_P_IPV6); + skb->dev = dev; + + skb_reserve(skb, hlen + sizeof(struct ipv6hdr)); + skb_reset_transport_header(skb); + + /* Manually 
assign socket ownership as we avoid calling + * sock_alloc_send_pskb() to bypass wmem buffer limits + */ + skb_set_owner_w(skb, sk); + + return skb; +} + +static void ip6_nd_hdr(struct sk_buff *skb, + const struct in6_addr *saddr, + const struct in6_addr *daddr, + int hop_limit, int len) +{ + struct ipv6hdr *hdr; + struct inet6_dev *idev; + unsigned tclass; + + rcu_read_lock(); + idev = __in6_dev_get(skb->dev); + tclass = idev ? idev->cnf.ndisc_tclass : 0; + rcu_read_unlock(); + + skb_push(skb, sizeof(*hdr)); + skb_reset_network_header(skb); + hdr = ipv6_hdr(skb); + + ip6_flow_hdr(hdr, tclass, 0); + + hdr->payload_len = htons(len); + hdr->nexthdr = IPPROTO_ICMPV6; + hdr->hop_limit = hop_limit; + + hdr->saddr = *saddr; + hdr->daddr = *daddr; +} + +void ndisc_send_skb(struct sk_buff *skb, const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + struct dst_entry *dst = skb_dst(skb); + struct net *net = dev_net(skb->dev); + struct sock *sk = net->ipv6.ndisc_sk; + struct inet6_dev *idev; + int err; + struct icmp6hdr *icmp6h = icmp6_hdr(skb); + u8 type; + + type = icmp6h->icmp6_type; + + if (!dst) { + struct flowi6 fl6; + int oif = skb->dev->ifindex; + + icmpv6_flow_init(sk, &fl6, type, saddr, daddr, oif); + dst = icmp6_dst_alloc(skb->dev, &fl6); + if (IS_ERR(dst)) { + kfree_skb(skb); + return; + } + + skb_dst_set(skb, dst); + } + + icmp6h->icmp6_cksum = csum_ipv6_magic(saddr, daddr, skb->len, + IPPROTO_ICMPV6, + csum_partial(icmp6h, + skb->len, 0)); + + ip6_nd_hdr(skb, saddr, daddr, inet6_sk(sk)->hop_limit, skb->len); + + rcu_read_lock(); + idev = __in6_dev_get(dst->dev); + IP6_UPD_PO_STATS(net, idev, IPSTATS_MIB_OUT, skb->len); + + err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + net, sk, skb, NULL, dst->dev, + dst_output); + if (!err) { + ICMP6MSGOUT_INC_STATS(net, idev, type); + ICMP6_INC_STATS(net, idev, ICMP6_MIB_OUTMSGS); + } + + rcu_read_unlock(); +} +EXPORT_SYMBOL(ndisc_send_skb); + +void ndisc_send_na(struct net_device *dev, const struct in6_addr *daddr, + const struct in6_addr *solicited_addr, + bool router, bool solicited, bool override, bool inc_opt) +{ + struct sk_buff *skb; + struct in6_addr tmpaddr; + struct inet6_ifaddr *ifp; + const struct in6_addr *src_addr; + struct nd_msg *msg; + int optlen = 0; + + /* for anycast or proxy, solicited_addr != src_addr */ + ifp = ipv6_get_ifaddr(dev_net(dev), solicited_addr, dev, 1); + if (ifp) { + src_addr = solicited_addr; + if (ifp->flags & IFA_F_OPTIMISTIC) + override = false; + inc_opt |= ifp->idev->cnf.force_tllao; + in6_ifa_put(ifp); + } else { + if (ipv6_dev_get_saddr(dev_net(dev), dev, daddr, + inet6_sk(dev_net(dev)->ipv6.ndisc_sk)->srcprefs, + &tmpaddr)) + return; + src_addr = &tmpaddr; + } + + if (!dev->addr_len) + inc_opt = false; + if (inc_opt) + optlen += ndisc_opt_addr_space(dev, + NDISC_NEIGHBOUR_ADVERTISEMENT); + + skb = ndisc_alloc_skb(dev, sizeof(*msg) + optlen); + if (!skb) + return; + + msg = skb_put(skb, sizeof(*msg)); + *msg = (struct nd_msg) { + .icmph = { + .icmp6_type = NDISC_NEIGHBOUR_ADVERTISEMENT, + .icmp6_router = router, + .icmp6_solicited = solicited, + .icmp6_override = override, + }, + .target = *solicited_addr, + }; + + if (inc_opt) + ndisc_fill_addr_option(skb, ND_OPT_TARGET_LL_ADDR, + dev->dev_addr, + NDISC_NEIGHBOUR_ADVERTISEMENT); + + ndisc_send_skb(skb, daddr, src_addr); +} + +static void ndisc_send_unsol_na(struct net_device *dev) +{ + struct inet6_dev *idev; + struct inet6_ifaddr *ifa; + + idev = in6_dev_get(dev); + if (!idev) + return; + + read_lock_bh(&idev->lock); + 
list_for_each_entry(ifa, &idev->addr_list, if_list) { + /* skip tentative addresses until dad completes */ + if (ifa->flags & IFA_F_TENTATIVE && + !(ifa->flags & IFA_F_OPTIMISTIC)) + continue; + + ndisc_send_na(dev, &in6addr_linklocal_allnodes, &ifa->addr, + /*router=*/ !!idev->cnf.forwarding, + /*solicited=*/ false, /*override=*/ true, + /*inc_opt=*/ true); + } + read_unlock_bh(&idev->lock); + + in6_dev_put(idev); +} + +struct sk_buff *ndisc_ns_create(struct net_device *dev, const struct in6_addr *solicit, + const struct in6_addr *saddr, u64 nonce) +{ + int inc_opt = dev->addr_len; + struct sk_buff *skb; + struct nd_msg *msg; + int optlen = 0; + + if (!saddr) + return NULL; + + if (ipv6_addr_any(saddr)) + inc_opt = false; + if (inc_opt) + optlen += ndisc_opt_addr_space(dev, + NDISC_NEIGHBOUR_SOLICITATION); + if (nonce != 0) + optlen += 8; + + skb = ndisc_alloc_skb(dev, sizeof(*msg) + optlen); + if (!skb) + return NULL; + + msg = skb_put(skb, sizeof(*msg)); + *msg = (struct nd_msg) { + .icmph = { + .icmp6_type = NDISC_NEIGHBOUR_SOLICITATION, + }, + .target = *solicit, + }; + + if (inc_opt) + ndisc_fill_addr_option(skb, ND_OPT_SOURCE_LL_ADDR, + dev->dev_addr, + NDISC_NEIGHBOUR_SOLICITATION); + if (nonce != 0) { + u8 *opt = skb_put(skb, 8); + + opt[0] = ND_OPT_NONCE; + opt[1] = 8 >> 3; + memcpy(opt + 2, &nonce, 6); + } + + return skb; +} +EXPORT_SYMBOL(ndisc_ns_create); + +void ndisc_send_ns(struct net_device *dev, const struct in6_addr *solicit, + const struct in6_addr *daddr, const struct in6_addr *saddr, + u64 nonce) +{ + struct in6_addr addr_buf; + struct sk_buff *skb; + + if (!saddr) { + if (ipv6_get_lladdr(dev, &addr_buf, + (IFA_F_TENTATIVE | IFA_F_OPTIMISTIC))) + return; + saddr = &addr_buf; + } + + skb = ndisc_ns_create(dev, solicit, saddr, nonce); + + if (skb) + ndisc_send_skb(skb, daddr, saddr); +} + +void ndisc_send_rs(struct net_device *dev, const struct in6_addr *saddr, + const struct in6_addr *daddr) +{ + struct sk_buff *skb; + struct rs_msg *msg; + int send_sllao = dev->addr_len; + int optlen = 0; + +#ifdef CONFIG_IPV6_OPTIMISTIC_DAD + /* + * According to section 2.2 of RFC 4429, we must not + * send router solicitations with a sllao from + * optimistic addresses, but we may send the solicitation + * if we don't include the sllao. So here we check + * if our address is optimistic, and if so, we + * suppress the inclusion of the sllao. 
+ */ + if (send_sllao) { + struct inet6_ifaddr *ifp = ipv6_get_ifaddr(dev_net(dev), saddr, + dev, 1); + if (ifp) { + if (ifp->flags & IFA_F_OPTIMISTIC) { + send_sllao = 0; + } + in6_ifa_put(ifp); + } else { + send_sllao = 0; + } + } +#endif + if (send_sllao) + optlen += ndisc_opt_addr_space(dev, NDISC_ROUTER_SOLICITATION); + + skb = ndisc_alloc_skb(dev, sizeof(*msg) + optlen); + if (!skb) + return; + + msg = skb_put(skb, sizeof(*msg)); + *msg = (struct rs_msg) { + .icmph = { + .icmp6_type = NDISC_ROUTER_SOLICITATION, + }, + }; + + if (send_sllao) + ndisc_fill_addr_option(skb, ND_OPT_SOURCE_LL_ADDR, + dev->dev_addr, + NDISC_ROUTER_SOLICITATION); + + ndisc_send_skb(skb, daddr, saddr); +} + + +static void ndisc_error_report(struct neighbour *neigh, struct sk_buff *skb) +{ + /* + * "The sender MUST return an ICMP + * destination unreachable" + */ + dst_link_failure(skb); + kfree_skb(skb); +} + +/* Called with locked neigh: either read or both */ + +static void ndisc_solicit(struct neighbour *neigh, struct sk_buff *skb) +{ + struct in6_addr *saddr = NULL; + struct in6_addr mcaddr; + struct net_device *dev = neigh->dev; + struct in6_addr *target = (struct in6_addr *)&neigh->primary_key; + int probes = atomic_read(&neigh->probes); + + if (skb && ipv6_chk_addr_and_flags(dev_net(dev), &ipv6_hdr(skb)->saddr, + dev, false, 1, + IFA_F_TENTATIVE|IFA_F_OPTIMISTIC)) + saddr = &ipv6_hdr(skb)->saddr; + probes -= NEIGH_VAR(neigh->parms, UCAST_PROBES); + if (probes < 0) { + if (!(READ_ONCE(neigh->nud_state) & NUD_VALID)) { + ND_PRINTK(1, dbg, + "%s: trying to ucast probe in NUD_INVALID: %pI6\n", + __func__, target); + } + ndisc_send_ns(dev, target, target, saddr, 0); + } else if ((probes -= NEIGH_VAR(neigh->parms, APP_PROBES)) < 0) { + neigh_app_ns(neigh); + } else { + addrconf_addr_solict_mult(target, &mcaddr); + ndisc_send_ns(dev, target, &mcaddr, saddr, 0); + } +} + +static int pndisc_is_router(const void *pkey, + struct net_device *dev) +{ + struct pneigh_entry *n; + int ret = -1; + + read_lock_bh(&nd_tbl.lock); + n = __pneigh_lookup(&nd_tbl, dev_net(dev), pkey, dev); + if (n) + ret = !!(n->flags & NTF_ROUTER); + read_unlock_bh(&nd_tbl.lock); + + return ret; +} + +void ndisc_update(const struct net_device *dev, struct neighbour *neigh, + const u8 *lladdr, u8 new, u32 flags, u8 icmp6_type, + struct ndisc_options *ndopts) +{ + neigh_update(neigh, lladdr, new, flags, 0); + /* report ndisc ops about neighbour update */ + ndisc_ops_update(dev, neigh, flags, icmp6_type, ndopts); +} + +static void ndisc_recv_ns(struct sk_buff *skb) +{ + struct nd_msg *msg = (struct nd_msg *)skb_transport_header(skb); + const struct in6_addr *saddr = &ipv6_hdr(skb)->saddr; + const struct in6_addr *daddr = &ipv6_hdr(skb)->daddr; + u8 *lladdr = NULL; + u32 ndoptlen = skb_tail_pointer(skb) - (skb_transport_header(skb) + + offsetof(struct nd_msg, opt)); + struct ndisc_options ndopts; + struct net_device *dev = skb->dev; + struct inet6_ifaddr *ifp; + struct inet6_dev *idev = NULL; + struct neighbour *neigh; + int dad = ipv6_addr_any(saddr); + bool inc; + int is_router = -1; + u64 nonce = 0; + + if (skb->len < sizeof(struct nd_msg)) { + ND_PRINTK(2, warn, "NS: packet too short\n"); + return; + } + + if (ipv6_addr_is_multicast(&msg->target)) { + ND_PRINTK(2, warn, "NS: multicast target address\n"); + return; + } + + /* + * RFC2461 7.1.1: + * DAD has to be destined for solicited node multicast address. 
+ */ + if (dad && !ipv6_addr_is_solict_mult(daddr)) { + ND_PRINTK(2, warn, "NS: bad DAD packet (wrong destination)\n"); + return; + } + + if (!ndisc_parse_options(dev, msg->opt, ndoptlen, &ndopts)) { + ND_PRINTK(2, warn, "NS: invalid ND options\n"); + return; + } + + if (ndopts.nd_opts_src_lladdr) { + lladdr = ndisc_opt_addr_data(ndopts.nd_opts_src_lladdr, dev); + if (!lladdr) { + ND_PRINTK(2, warn, + "NS: invalid link-layer address length\n"); + return; + } + + /* RFC2461 7.1.1: + * If the IP source address is the unspecified address, + * there MUST NOT be source link-layer address option + * in the message. + */ + if (dad) { + ND_PRINTK(2, warn, + "NS: bad DAD packet (link-layer address option)\n"); + return; + } + } + if (ndopts.nd_opts_nonce && ndopts.nd_opts_nonce->nd_opt_len == 1) + memcpy(&nonce, (u8 *)(ndopts.nd_opts_nonce + 1), 6); + + inc = ipv6_addr_is_multicast(daddr); + + ifp = ipv6_get_ifaddr(dev_net(dev), &msg->target, dev, 1); + if (ifp) { +have_ifp: + if (ifp->flags & (IFA_F_TENTATIVE|IFA_F_OPTIMISTIC)) { + if (dad) { + if (nonce != 0 && ifp->dad_nonce == nonce) { + u8 *np = (u8 *)&nonce; + /* Matching nonce if looped back */ + ND_PRINTK(2, notice, + "%s: IPv6 DAD loopback for address %pI6c nonce %pM ignored\n", + ifp->idev->dev->name, + &ifp->addr, np); + goto out; + } + /* + * We are colliding with another node + * who is doing DAD + * so fail our DAD process + */ + addrconf_dad_failure(skb, ifp); + return; + } else { + /* + * This is not a dad solicitation. + * If we are an optimistic node, + * we should respond. + * Otherwise, we should ignore it. + */ + if (!(ifp->flags & IFA_F_OPTIMISTIC)) + goto out; + } + } + + idev = ifp->idev; + } else { + struct net *net = dev_net(dev); + + /* perhaps an address on the master device */ + if (netif_is_l3_slave(dev)) { + struct net_device *mdev; + + mdev = netdev_master_upper_dev_get_rcu(dev); + if (mdev) { + ifp = ipv6_get_ifaddr(net, &msg->target, mdev, 1); + if (ifp) + goto have_ifp; + } + } + + idev = in6_dev_get(dev); + if (!idev) { + /* XXX: count this drop? */ + return; + } + + if (ipv6_chk_acast_addr(net, dev, &msg->target) || + (idev->cnf.forwarding && + (net->ipv6.devconf_all->proxy_ndp || idev->cnf.proxy_ndp) && + (is_router = pndisc_is_router(&msg->target, dev)) >= 0)) { + if (!(NEIGH_CB(skb)->flags & LOCALLY_ENQUEUED) && + skb->pkt_type != PACKET_HOST && + inc && + NEIGH_VAR(idev->nd_parms, PROXY_DELAY) != 0) { + /* + * for anycast or proxy, + * sender should delay its response + * by a random time between 0 and + * MAX_ANYCAST_DELAY_TIME seconds. 
+ * (RFC2461) -- yoshfuji + */ + struct sk_buff *n = skb_clone(skb, GFP_ATOMIC); + if (n) + pneigh_enqueue(&nd_tbl, idev->nd_parms, n); + goto out; + } + } else + goto out; + } + + if (is_router < 0) + is_router = idev->cnf.forwarding; + + if (dad) { + ndisc_send_na(dev, &in6addr_linklocal_allnodes, &msg->target, + !!is_router, false, (ifp != NULL), true); + goto out; + } + + if (inc) + NEIGH_CACHE_STAT_INC(&nd_tbl, rcv_probes_mcast); + else + NEIGH_CACHE_STAT_INC(&nd_tbl, rcv_probes_ucast); + + /* + * update / create cache entry + * for the source address + */ + neigh = __neigh_lookup(&nd_tbl, saddr, dev, + !inc || lladdr || !dev->addr_len); + if (neigh) + ndisc_update(dev, neigh, lladdr, NUD_STALE, + NEIGH_UPDATE_F_WEAK_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE, + NDISC_NEIGHBOUR_SOLICITATION, &ndopts); + if (neigh || !dev->header_ops) { + ndisc_send_na(dev, saddr, &msg->target, !!is_router, + true, (ifp != NULL && inc), inc); + if (neigh) + neigh_release(neigh); + } + +out: + if (ifp) + in6_ifa_put(ifp); + else + in6_dev_put(idev); +} + +static int accept_untracked_na(struct net_device *dev, struct in6_addr *saddr) +{ + struct inet6_dev *idev = __in6_dev_get(dev); + + switch (idev->cnf.accept_untracked_na) { + case 0: /* Don't accept untracked na (absent in neighbor cache) */ + return 0; + case 1: /* Create new entries from na if currently untracked */ + return 1; + case 2: /* Create new entries from untracked na only if saddr is in the + * same subnet as an address configured on the interface that + * received the na + */ + return !!ipv6_chk_prefix(saddr, dev); + default: + return 0; + } +} + +static void ndisc_recv_na(struct sk_buff *skb) +{ + struct nd_msg *msg = (struct nd_msg *)skb_transport_header(skb); + struct in6_addr *saddr = &ipv6_hdr(skb)->saddr; + const struct in6_addr *daddr = &ipv6_hdr(skb)->daddr; + u8 *lladdr = NULL; + u32 ndoptlen = skb_tail_pointer(skb) - (skb_transport_header(skb) + + offsetof(struct nd_msg, opt)); + struct ndisc_options ndopts; + struct net_device *dev = skb->dev; + struct inet6_dev *idev = __in6_dev_get(dev); + struct inet6_ifaddr *ifp; + struct neighbour *neigh; + u8 new_state; + + if (skb->len < sizeof(struct nd_msg)) { + ND_PRINTK(2, warn, "NA: packet too short\n"); + return; + } + + if (ipv6_addr_is_multicast(&msg->target)) { + ND_PRINTK(2, warn, "NA: target address is multicast\n"); + return; + } + + if (ipv6_addr_is_multicast(daddr) && + msg->icmph.icmp6_solicited) { + ND_PRINTK(2, warn, "NA: solicited NA is multicasted\n"); + return; + } + + /* For some 802.11 wireless deployments (and possibly other networks), + * there will be a NA proxy and unsolicitd packets are attacks + * and thus should not be accepted. + * drop_unsolicited_na takes precedence over accept_untracked_na + */ + if (!msg->icmph.icmp6_solicited && idev && + idev->cnf.drop_unsolicited_na) + return; + + if (!ndisc_parse_options(dev, msg->opt, ndoptlen, &ndopts)) { + ND_PRINTK(2, warn, "NS: invalid ND option\n"); + return; + } + if (ndopts.nd_opts_tgt_lladdr) { + lladdr = ndisc_opt_addr_data(ndopts.nd_opts_tgt_lladdr, dev); + if (!lladdr) { + ND_PRINTK(2, warn, + "NA: invalid link-layer address length\n"); + return; + } + } + ifp = ipv6_get_ifaddr(dev_net(dev), &msg->target, dev, 1); + if (ifp) { + if (skb->pkt_type != PACKET_LOOPBACK + && (ifp->flags & IFA_F_TENTATIVE)) { + addrconf_dad_failure(skb, ifp); + return; + } + /* What should we make now? The advertisement + is invalid, but ndisc specs say nothing + about it. 
It could be misconfiguration, or + an smart proxy agent tries to help us :-) + + We should not print the error if NA has been + received from loopback - it is just our own + unsolicited advertisement. + */ + if (skb->pkt_type != PACKET_LOOPBACK) + ND_PRINTK(1, warn, + "NA: %pM advertised our address %pI6c on %s!\n", + eth_hdr(skb)->h_source, &ifp->addr, ifp->idev->dev->name); + in6_ifa_put(ifp); + return; + } + + neigh = neigh_lookup(&nd_tbl, &msg->target, dev); + + /* RFC 9131 updates original Neighbour Discovery RFC 4861. + * NAs with Target LL Address option without a corresponding + * entry in the neighbour cache can now create a STALE neighbour + * cache entry on routers. + * + * entry accept fwding solicited behaviour + * ------- ------ ------ --------- ---------------------- + * present X X 0 Set state to STALE + * present X X 1 Set state to REACHABLE + * absent 0 X X Do nothing + * absent 1 0 X Do nothing + * absent 1 1 X Add a new STALE entry + * + * Note that we don't do a (daddr == all-routers-mcast) check. + */ + new_state = msg->icmph.icmp6_solicited ? NUD_REACHABLE : NUD_STALE; + if (!neigh && lladdr && idev && idev->cnf.forwarding) { + if (accept_untracked_na(dev, saddr)) { + neigh = neigh_create(&nd_tbl, &msg->target, dev); + new_state = NUD_STALE; + } + } + + if (neigh && !IS_ERR(neigh)) { + u8 old_flags = neigh->flags; + struct net *net = dev_net(dev); + + if (READ_ONCE(neigh->nud_state) & NUD_FAILED) + goto out; + + /* + * Don't update the neighbor cache entry on a proxy NA from + * ourselves because either the proxied node is off link or it + * has already sent a NA to us. + */ + if (lladdr && !memcmp(lladdr, dev->dev_addr, dev->addr_len) && + net->ipv6.devconf_all->forwarding && net->ipv6.devconf_all->proxy_ndp && + pneigh_lookup(&nd_tbl, net, &msg->target, dev, 0)) { + /* XXX: idev->cnf.proxy_ndp */ + goto out; + } + + ndisc_update(dev, neigh, lladdr, + new_state, + NEIGH_UPDATE_F_WEAK_OVERRIDE| + (msg->icmph.icmp6_override ? NEIGH_UPDATE_F_OVERRIDE : 0)| + NEIGH_UPDATE_F_OVERRIDE_ISROUTER| + (msg->icmph.icmp6_router ? NEIGH_UPDATE_F_ISROUTER : 0), + NDISC_NEIGHBOUR_ADVERTISEMENT, &ndopts); + + if ((old_flags & ~neigh->flags) & NTF_ROUTER) { + /* + * Change: router to host + */ + rt6_clean_tohost(dev_net(dev), saddr); + } + +out: + neigh_release(neigh); + } +} + +static void ndisc_recv_rs(struct sk_buff *skb) +{ + struct rs_msg *rs_msg = (struct rs_msg *)skb_transport_header(skb); + unsigned long ndoptlen = skb->len - sizeof(*rs_msg); + struct neighbour *neigh; + struct inet6_dev *idev; + const struct in6_addr *saddr = &ipv6_hdr(skb)->saddr; + struct ndisc_options ndopts; + u8 *lladdr = NULL; + + if (skb->len < sizeof(*rs_msg)) + return; + + idev = __in6_dev_get(skb->dev); + if (!idev) { + ND_PRINTK(1, err, "RS: can't find in6 device\n"); + return; + } + + /* Don't accept RS if we're not in router mode */ + if (!idev->cnf.forwarding) + goto out; + + /* + * Don't update NCE if src = ::; + * this implies that the source node has no ip address assigned yet. 
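
The RFC 9131 behaviour table above collapses into a small decision function: an existing entry is refreshed to REACHABLE or STALE depending on the Solicited flag, while a missing entry is only created (always as STALE) on a forwarding interface whose accept_untracked_na policy allows it. A userspace model of that table, with illustrative names rather than the kernel's NUD_* constants:

#include <stdbool.h>
#include <stdio.h>

enum nd_state { ND_NONE, ND_STALE, ND_REACHABLE };

static enum nd_state na_resulting_state(bool entry_present, bool accept,
                                        bool forwarding, bool solicited)
{
        if (entry_present)
                return solicited ? ND_REACHABLE : ND_STALE;
        if (accept && forwarding)
                return ND_STALE;        /* new entry per RFC 9131    */
        return ND_NONE;                 /* untracked NA is ignored   */
}

int main(void)
{
        printf("%d\n", na_resulting_state(true,  false, false, true));  /* 2 */
        printf("%d\n", na_resulting_state(false, true,  true,  true));  /* 1 */
        printf("%d\n", na_resulting_state(false, true,  false, true));  /* 0 */
        return 0;
}
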
+ */ + if (ipv6_addr_any(saddr)) + goto out; + + /* Parse ND options */ + if (!ndisc_parse_options(skb->dev, rs_msg->opt, ndoptlen, &ndopts)) { + ND_PRINTK(2, notice, "NS: invalid ND option, ignored\n"); + goto out; + } + + if (ndopts.nd_opts_src_lladdr) { + lladdr = ndisc_opt_addr_data(ndopts.nd_opts_src_lladdr, + skb->dev); + if (!lladdr) + goto out; + } + + neigh = __neigh_lookup(&nd_tbl, saddr, skb->dev, 1); + if (neigh) { + ndisc_update(skb->dev, neigh, lladdr, NUD_STALE, + NEIGH_UPDATE_F_WEAK_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE_ISROUTER, + NDISC_ROUTER_SOLICITATION, &ndopts); + neigh_release(neigh); + } +out: + return; +} + +static void ndisc_ra_useropt(struct sk_buff *ra, struct nd_opt_hdr *opt) +{ + struct icmp6hdr *icmp6h = (struct icmp6hdr *)skb_transport_header(ra); + struct sk_buff *skb; + struct nlmsghdr *nlh; + struct nduseroptmsg *ndmsg; + struct net *net = dev_net(ra->dev); + int err; + int base_size = NLMSG_ALIGN(sizeof(struct nduseroptmsg) + + (opt->nd_opt_len << 3)); + size_t msg_size = base_size + nla_total_size(sizeof(struct in6_addr)); + + skb = nlmsg_new(msg_size, GFP_ATOMIC); + if (!skb) { + err = -ENOBUFS; + goto errout; + } + + nlh = nlmsg_put(skb, 0, 0, RTM_NEWNDUSEROPT, base_size, 0); + if (!nlh) { + goto nla_put_failure; + } + + ndmsg = nlmsg_data(nlh); + ndmsg->nduseropt_family = AF_INET6; + ndmsg->nduseropt_ifindex = ra->dev->ifindex; + ndmsg->nduseropt_icmp_type = icmp6h->icmp6_type; + ndmsg->nduseropt_icmp_code = icmp6h->icmp6_code; + ndmsg->nduseropt_opts_len = opt->nd_opt_len << 3; + + memcpy(ndmsg + 1, opt, opt->nd_opt_len << 3); + + if (nla_put_in6_addr(skb, NDUSEROPT_SRCADDR, &ipv6_hdr(ra)->saddr)) + goto nla_put_failure; + nlmsg_end(skb, nlh); + + rtnl_notify(skb, net, 0, RTNLGRP_ND_USEROPT, NULL, GFP_ATOMIC); + return; + +nla_put_failure: + nlmsg_free(skb); + err = -EMSGSIZE; +errout: + rtnl_set_sk_err(net, RTNLGRP_ND_USEROPT, err); +} + +static void ndisc_router_discovery(struct sk_buff *skb) +{ + struct ra_msg *ra_msg = (struct ra_msg *)skb_transport_header(skb); + struct neighbour *neigh = NULL; + struct inet6_dev *in6_dev; + struct fib6_info *rt = NULL; + u32 defrtr_usr_metric; + struct net *net; + int lifetime; + struct ndisc_options ndopts; + int optlen; + unsigned int pref = 0; + __u32 old_if_flags; + bool send_ifinfo_notify = false; + + __u8 *opt = (__u8 *)(ra_msg + 1); + + optlen = (skb_tail_pointer(skb) - skb_transport_header(skb)) - + sizeof(struct ra_msg); + + ND_PRINTK(2, info, + "RA: %s, dev: %s\n", + __func__, skb->dev->name); + if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) { + ND_PRINTK(2, warn, "RA: source address is not link-local\n"); + return; + } + if (optlen < 0) { + ND_PRINTK(2, warn, "RA: packet too short\n"); + return; + } + +#ifdef CONFIG_IPV6_NDISC_NODETYPE + if (skb->ndisc_nodetype == NDISC_NODETYPE_HOST) { + ND_PRINTK(2, warn, "RA: from host or unauthorized router\n"); + return; + } +#endif + + /* + * set the RA_RECV flag in the interface + */ + + in6_dev = __in6_dev_get(skb->dev); + if (!in6_dev) { + ND_PRINTK(0, err, "RA: can't find inet6 device for %s\n", + skb->dev->name); + return; + } + + if (!ndisc_parse_options(skb->dev, opt, optlen, &ndopts)) { + ND_PRINTK(2, warn, "RA: invalid ND options\n"); + return; + } + + if (!ipv6_accept_ra(in6_dev)) { + ND_PRINTK(2, info, + "RA: %s, did not accept ra for dev: %s\n", + __func__, skb->dev->name); + goto skip_linkparms; + } + +#ifdef CONFIG_IPV6_NDISC_NODETYPE + /* skip link-specific parameters from interior routers */ + if 
(skb->ndisc_nodetype == NDISC_NODETYPE_NODEFAULT) { + ND_PRINTK(2, info, + "RA: %s, nodetype is NODEFAULT, dev: %s\n", + __func__, skb->dev->name); + goto skip_linkparms; + } +#endif + + if (in6_dev->if_flags & IF_RS_SENT) { + /* + * flag that an RA was received after an RS was sent + * out on this interface. + */ + in6_dev->if_flags |= IF_RA_RCVD; + } + + /* + * Remember the managed/otherconf flags from most recently + * received RA message (RFC 2462) -- yoshfuji + */ + old_if_flags = in6_dev->if_flags; + in6_dev->if_flags = (in6_dev->if_flags & ~(IF_RA_MANAGED | + IF_RA_OTHERCONF)) | + (ra_msg->icmph.icmp6_addrconf_managed ? + IF_RA_MANAGED : 0) | + (ra_msg->icmph.icmp6_addrconf_other ? + IF_RA_OTHERCONF : 0); + + if (old_if_flags != in6_dev->if_flags) + send_ifinfo_notify = true; + + if (!in6_dev->cnf.accept_ra_defrtr) { + ND_PRINTK(2, info, + "RA: %s, defrtr is false for dev: %s\n", + __func__, skb->dev->name); + goto skip_defrtr; + } + + lifetime = ntohs(ra_msg->icmph.icmp6_rt_lifetime); + if (lifetime != 0 && lifetime < in6_dev->cnf.accept_ra_min_lft) { + ND_PRINTK(2, info, + "RA: router lifetime (%ds) is too short: %s\n", + lifetime, skb->dev->name); + goto skip_defrtr; + } + + /* Do not accept RA with source-addr found on local machine unless + * accept_ra_from_local is set to true. + */ + net = dev_net(in6_dev->dev); + if (!in6_dev->cnf.accept_ra_from_local && + ipv6_chk_addr(net, &ipv6_hdr(skb)->saddr, in6_dev->dev, 0)) { + ND_PRINTK(2, info, + "RA from local address detected on dev: %s: default router ignored\n", + skb->dev->name); + goto skip_defrtr; + } + +#ifdef CONFIG_IPV6_ROUTER_PREF + pref = ra_msg->icmph.icmp6_router_pref; + /* 10b is handled as if it were 00b (medium) */ + if (pref == ICMPV6_ROUTER_PREF_INVALID || + !in6_dev->cnf.accept_ra_rtr_pref) + pref = ICMPV6_ROUTER_PREF_MEDIUM; +#endif + /* routes added from RAs do not use nexthop objects */ + rt = rt6_get_dflt_router(net, &ipv6_hdr(skb)->saddr, skb->dev); + if (rt) { + neigh = ip6_neigh_lookup(&rt->fib6_nh->fib_nh_gw6, + rt->fib6_nh->fib_nh_dev, NULL, + &ipv6_hdr(skb)->saddr); + if (!neigh) { + ND_PRINTK(0, err, + "RA: %s got default router without neighbour\n", + __func__); + fib6_info_release(rt); + return; + } + } + /* Set default route metric as specified by user */ + defrtr_usr_metric = in6_dev->cnf.ra_defrtr_metric; + /* delete the route if lifetime is 0 or if metric needs change */ + if (rt && (lifetime == 0 || rt->fib6_metric != defrtr_usr_metric)) { + ip6_del_rt(net, rt, false); + rt = NULL; + } + + ND_PRINTK(3, info, "RA: rt: %p lifetime: %d, metric: %d, for dev: %s\n", + rt, lifetime, defrtr_usr_metric, skb->dev->name); + if (!rt && lifetime) { + ND_PRINTK(3, info, "RA: adding default router\n"); + + if (neigh) + neigh_release(neigh); + + rt = rt6_add_dflt_router(net, &ipv6_hdr(skb)->saddr, + skb->dev, pref, defrtr_usr_metric); + if (!rt) { + ND_PRINTK(0, err, + "RA: %s failed to add default route\n", + __func__); + return; + } + + neigh = ip6_neigh_lookup(&rt->fib6_nh->fib_nh_gw6, + rt->fib6_nh->fib_nh_dev, NULL, + &ipv6_hdr(skb)->saddr); + if (!neigh) { + ND_PRINTK(0, err, + "RA: %s got default router without neighbour\n", + __func__); + fib6_info_release(rt); + return; + } + neigh->flags |= NTF_ROUTER; + } else if (rt && IPV6_EXTRACT_PREF(rt->fib6_flags) != pref) { + struct nl_info nlinfo = { + .nl_net = net, + }; + rt->fib6_flags = (rt->fib6_flags & ~RTF_PREF_MASK) | RTF_PREF(pref); + inet6_rt_notify(RTM_NEWROUTE, rt, &nlinfo, NLM_F_REPLACE); + } + + if (rt) + fib6_set_expires(rt, jiffies + (HZ * 
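
Router Preference is a two-bit field where 01 is high, 00 medium and 11 low; the reserved pattern 10 must be treated as medium, and the whole mechanism can be disabled per interface with accept_ra_rtr_pref, which is what the preference handling above does. A small sketch of that normalisation (constant names mirror the ICMPv6 wire encoding, not the kernel headers):

#include <stdio.h>

enum rtr_pref {                 /* two-bit on-the-wire encoding */
        PREF_MEDIUM  = 0x0,
        PREF_HIGH    = 0x1,
        PREF_INVALID = 0x2,     /* reserved pattern 10b          */
        PREF_LOW     = 0x3,
};

static enum rtr_pref normalize_rtr_pref(enum rtr_pref pref, int accept_ra_rtr_pref)
{
        if (pref == PREF_INVALID || !accept_ra_rtr_pref)
                return PREF_MEDIUM;
        return pref;
}

int main(void)
{
        printf("%d\n", normalize_rtr_pref(PREF_INVALID, 1));  /* 0: medium */
        printf("%d\n", normalize_rtr_pref(PREF_HIGH, 0));     /* 0: medium */
        printf("%d\n", normalize_rtr_pref(PREF_LOW, 1));      /* 3: low    */
        return 0;
}
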
lifetime)); + if (in6_dev->cnf.accept_ra_min_hop_limit < 256 && + ra_msg->icmph.icmp6_hop_limit) { + if (in6_dev->cnf.accept_ra_min_hop_limit <= ra_msg->icmph.icmp6_hop_limit) { + in6_dev->cnf.hop_limit = ra_msg->icmph.icmp6_hop_limit; + fib6_metric_set(rt, RTAX_HOPLIMIT, + ra_msg->icmph.icmp6_hop_limit); + } else { + ND_PRINTK(2, warn, "RA: Got route advertisement with lower hop_limit than minimum\n"); + } + } + +skip_defrtr: + + /* + * Update Reachable Time and Retrans Timer + */ + + if (in6_dev->nd_parms) { + unsigned long rtime = ntohl(ra_msg->retrans_timer); + + if (rtime && rtime/1000 < MAX_SCHEDULE_TIMEOUT/HZ) { + rtime = (rtime*HZ)/1000; + if (rtime < HZ/100) + rtime = HZ/100; + NEIGH_VAR_SET(in6_dev->nd_parms, RETRANS_TIME, rtime); + in6_dev->tstamp = jiffies; + send_ifinfo_notify = true; + } + + rtime = ntohl(ra_msg->reachable_time); + if (rtime && rtime/1000 < MAX_SCHEDULE_TIMEOUT/(3*HZ)) { + rtime = (rtime*HZ)/1000; + + if (rtime < HZ/10) + rtime = HZ/10; + + if (rtime != NEIGH_VAR(in6_dev->nd_parms, BASE_REACHABLE_TIME)) { + NEIGH_VAR_SET(in6_dev->nd_parms, + BASE_REACHABLE_TIME, rtime); + NEIGH_VAR_SET(in6_dev->nd_parms, + GC_STALETIME, 3 * rtime); + in6_dev->nd_parms->reachable_time = neigh_rand_reach_time(rtime); + in6_dev->tstamp = jiffies; + send_ifinfo_notify = true; + } + } + } + +skip_linkparms: + + /* + * Process options. + */ + + if (!neigh) + neigh = __neigh_lookup(&nd_tbl, &ipv6_hdr(skb)->saddr, + skb->dev, 1); + if (neigh) { + u8 *lladdr = NULL; + if (ndopts.nd_opts_src_lladdr) { + lladdr = ndisc_opt_addr_data(ndopts.nd_opts_src_lladdr, + skb->dev); + if (!lladdr) { + ND_PRINTK(2, warn, + "RA: invalid link-layer address length\n"); + goto out; + } + } + ndisc_update(skb->dev, neigh, lladdr, NUD_STALE, + NEIGH_UPDATE_F_WEAK_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE_ISROUTER| + NEIGH_UPDATE_F_ISROUTER, + NDISC_ROUTER_ADVERTISEMENT, &ndopts); + } + + if (!ipv6_accept_ra(in6_dev)) { + ND_PRINTK(2, info, + "RA: %s, accept_ra is false for dev: %s\n", + __func__, skb->dev->name); + goto out; + } + +#ifdef CONFIG_IPV6_ROUTE_INFO + if (!in6_dev->cnf.accept_ra_from_local && + ipv6_chk_addr(dev_net(in6_dev->dev), &ipv6_hdr(skb)->saddr, + in6_dev->dev, 0)) { + ND_PRINTK(2, info, + "RA from local address detected on dev: %s: router info ignored.\n", + skb->dev->name); + goto skip_routeinfo; + } + + if (in6_dev->cnf.accept_ra_rtr_pref && ndopts.nd_opts_ri) { + struct nd_opt_hdr *p; + for (p = ndopts.nd_opts_ri; + p; + p = ndisc_next_option(p, ndopts.nd_opts_ri_end)) { + struct route_info *ri = (struct route_info *)p; +#ifdef CONFIG_IPV6_NDISC_NODETYPE + if (skb->ndisc_nodetype == NDISC_NODETYPE_NODEFAULT && + ri->prefix_len == 0) + continue; +#endif + if (ri->prefix_len == 0 && + !in6_dev->cnf.accept_ra_defrtr) + continue; + if (ri->lifetime != 0 && + ntohl(ri->lifetime) < in6_dev->cnf.accept_ra_min_lft) + continue; + if (ri->prefix_len < in6_dev->cnf.accept_ra_rt_info_min_plen) + continue; + if (ri->prefix_len > in6_dev->cnf.accept_ra_rt_info_max_plen) + continue; + rt6_route_rcv(skb->dev, (u8 *)p, (p->nd_opt_len) << 3, + &ipv6_hdr(skb)->saddr); + } + } + +skip_routeinfo: +#endif + +#ifdef CONFIG_IPV6_NDISC_NODETYPE + /* skip link-specific ndopts from interior routers */ + if (skb->ndisc_nodetype == NDISC_NODETYPE_NODEFAULT) { + ND_PRINTK(2, info, + "RA: %s, nodetype is NODEFAULT (interior routes), dev: %s\n", + __func__, skb->dev->name); + goto out; + } +#endif + + if (in6_dev->cnf.accept_ra_pinfo && ndopts.nd_opts_pi) { + struct nd_opt_hdr *p; + for (p = 
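
The Retrans Timer and Reachable Time fields above arrive in milliseconds and are converted to jiffies with a floor of 10 ms and 100 ms respectively (the upper bound against MAX_SCHEDULE_TIMEOUT is omitted here). A sketch of that conversion, assuming HZ = 100 purely for illustration:

#include <stdio.h>

#define HZ 100UL        /* illustrative; the real value depends on the kernel config */

static unsigned long ra_ms_to_jiffies(unsigned long ms, unsigned long floor_jiffies)
{
        unsigned long j = (ms * HZ) / 1000;

        return j < floor_jiffies ? floor_jiffies : j;
}

int main(void)
{
        /* Retrans Timer uses a 10 ms floor, Reachable Time a 100 ms floor. */
        printf("retrans   5ms -> %lu jiffies\n", ra_ms_to_jiffies(5, HZ / 100));
        printf("reachable 30s -> %lu jiffies\n", ra_ms_to_jiffies(30000, HZ / 10));
        return 0;
}
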
ndopts.nd_opts_pi; + p; + p = ndisc_next_option(p, ndopts.nd_opts_pi_end)) { + addrconf_prefix_rcv(skb->dev, (u8 *)p, + (p->nd_opt_len) << 3, + ndopts.nd_opts_src_lladdr != NULL); + } + } + + if (ndopts.nd_opts_mtu && in6_dev->cnf.accept_ra_mtu) { + __be32 n; + u32 mtu; + + memcpy(&n, ((u8 *)(ndopts.nd_opts_mtu+1))+2, sizeof(mtu)); + mtu = ntohl(n); + + if (in6_dev->ra_mtu != mtu) { + in6_dev->ra_mtu = mtu; + send_ifinfo_notify = true; + } + + if (mtu < IPV6_MIN_MTU || mtu > skb->dev->mtu) { + ND_PRINTK(2, warn, "RA: invalid mtu: %d\n", mtu); + } else if (in6_dev->cnf.mtu6 != mtu) { + in6_dev->cnf.mtu6 = mtu; + fib6_metric_set(rt, RTAX_MTU, mtu); + rt6_mtu_change(skb->dev, mtu); + } + } + + if (ndopts.nd_useropts) { + struct nd_opt_hdr *p; + for (p = ndopts.nd_useropts; + p; + p = ndisc_next_useropt(skb->dev, p, + ndopts.nd_useropts_end)) { + ndisc_ra_useropt(skb, p); + } + } + + if (ndopts.nd_opts_tgt_lladdr || ndopts.nd_opts_rh) { + ND_PRINTK(2, warn, "RA: invalid RA options\n"); + } +out: + /* Send a notify if RA changed managed/otherconf flags or + * timer settings or ra_mtu value + */ + if (send_ifinfo_notify) + inet6_ifinfo_notify(RTM_NEWLINK, in6_dev); + + fib6_info_release(rt); + if (neigh) + neigh_release(neigh); +} + +static void ndisc_redirect_rcv(struct sk_buff *skb) +{ + u8 *hdr; + struct ndisc_options ndopts; + struct rd_msg *msg = (struct rd_msg *)skb_transport_header(skb); + u32 ndoptlen = skb_tail_pointer(skb) - (skb_transport_header(skb) + + offsetof(struct rd_msg, opt)); + +#ifdef CONFIG_IPV6_NDISC_NODETYPE + switch (skb->ndisc_nodetype) { + case NDISC_NODETYPE_HOST: + case NDISC_NODETYPE_NODEFAULT: + ND_PRINTK(2, warn, + "Redirect: from host or unauthorized router\n"); + return; + } +#endif + + if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) { + ND_PRINTK(2, warn, + "Redirect: source address is not link-local\n"); + return; + } + + if (!ndisc_parse_options(skb->dev, msg->opt, ndoptlen, &ndopts)) + return; + + if (!ndopts.nd_opts_rh) { + ip6_redirect_no_header(skb, dev_net(skb->dev), + skb->dev->ifindex); + return; + } + + hdr = (u8 *)ndopts.nd_opts_rh; + hdr += 8; + if (!pskb_pull(skb, hdr - skb_transport_header(skb))) + return; + + icmpv6_notify(skb, NDISC_REDIRECT, 0, 0); +} + +static void ndisc_fill_redirect_hdr_option(struct sk_buff *skb, + struct sk_buff *orig_skb, + int rd_len) +{ + u8 *opt = skb_put(skb, rd_len); + + memset(opt, 0, 8); + *(opt++) = ND_OPT_REDIRECT_HDR; + *(opt++) = (rd_len >> 3); + opt += 6; + + skb_copy_bits(orig_skb, skb_network_offset(orig_skb), opt, + rd_len - 8); +} + +void ndisc_send_redirect(struct sk_buff *skb, const struct in6_addr *target) +{ + struct net_device *dev = skb->dev; + struct net *net = dev_net(dev); + struct sock *sk = net->ipv6.ndisc_sk; + int optlen = 0; + struct inet_peer *peer; + struct sk_buff *buff; + struct rd_msg *msg; + struct in6_addr saddr_buf; + struct rt6_info *rt; + struct dst_entry *dst; + struct flowi6 fl6; + int rd_len; + u8 ha_buf[MAX_ADDR_LEN], *ha = NULL, + ops_data_buf[NDISC_OPS_REDIRECT_DATA_SPACE], *ops_data = NULL; + bool ret; + + if (netif_is_l3_master(skb->dev)) { + dev = __dev_get_by_index(dev_net(skb->dev), IPCB(skb)->iif); + if (!dev) + return; + } + + if (ipv6_get_lladdr(dev, &saddr_buf, IFA_F_TENTATIVE)) { + ND_PRINTK(2, warn, "Redirect: no link-local address on %s\n", + dev->name); + return; + } + + if (!ipv6_addr_equal(&ipv6_hdr(skb)->daddr, target) && + ipv6_addr_type(target) != (IPV6_ADDR_UNICAST|IPV6_ADDR_LINKLOCAL)) { + ND_PRINTK(2, warn, + "Redirect: target address 
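
The MTU option handling above only adopts the advertised value when it lies between IPV6_MIN_MTU (1280) and the link MTU of the receiving device; anything else is logged and ignored. A minimal sketch of that bounds check (illustrative, not the kernel helper):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define IPV6_MIN_MTU 1280u

static bool ra_mtu_acceptable(uint32_t advertised, uint32_t link_mtu)
{
        return advertised >= IPV6_MIN_MTU && advertised <= link_mtu;
}

int main(void)
{
        printf("%d\n", ra_mtu_acceptable(1500, 1500));  /* 1: adopted          */
        printf("%d\n", ra_mtu_acceptable(1000, 1500));  /* 0: below minimum    */
        printf("%d\n", ra_mtu_acceptable(9000, 1500));  /* 0: above link MTU   */
        return 0;
}
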
is not link-local unicast\n"); + return; + } + + icmpv6_flow_init(sk, &fl6, NDISC_REDIRECT, + &saddr_buf, &ipv6_hdr(skb)->saddr, dev->ifindex); + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + dst_release(dst); + return; + } + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), NULL, 0); + if (IS_ERR(dst)) + return; + + rt = (struct rt6_info *) dst; + + if (rt->rt6i_flags & RTF_GATEWAY) { + ND_PRINTK(2, warn, + "Redirect: destination is not a neighbour\n"); + goto release; + } + peer = inet_getpeer_v6(net->ipv6.peers, &ipv6_hdr(skb)->saddr, 1); + ret = inet_peer_xrlim_allow(peer, 1*HZ); + if (peer) + inet_putpeer(peer); + if (!ret) + goto release; + + if (dev->addr_len) { + struct neighbour *neigh = dst_neigh_lookup(skb_dst(skb), target); + if (!neigh) { + ND_PRINTK(2, warn, + "Redirect: no neigh for target address\n"); + goto release; + } + + read_lock_bh(&neigh->lock); + if (neigh->nud_state & NUD_VALID) { + memcpy(ha_buf, neigh->ha, dev->addr_len); + read_unlock_bh(&neigh->lock); + ha = ha_buf; + optlen += ndisc_redirect_opt_addr_space(dev, neigh, + ops_data_buf, + &ops_data); + } else + read_unlock_bh(&neigh->lock); + + neigh_release(neigh); + } + + rd_len = min_t(unsigned int, + IPV6_MIN_MTU - sizeof(struct ipv6hdr) - sizeof(*msg) - optlen, + skb->len + 8); + rd_len &= ~0x7; + optlen += rd_len; + + buff = ndisc_alloc_skb(dev, sizeof(*msg) + optlen); + if (!buff) + goto release; + + msg = skb_put(buff, sizeof(*msg)); + *msg = (struct rd_msg) { + .icmph = { + .icmp6_type = NDISC_REDIRECT, + }, + .target = *target, + .dest = ipv6_hdr(skb)->daddr, + }; + + /* + * include target_address option + */ + + if (ha) + ndisc_fill_redirect_addr_option(buff, ha, ops_data); + + /* + * build redirect option and copy skb over to the new packet. + */ + + if (rd_len) + ndisc_fill_redirect_hdr_option(buff, skb, rd_len); + + skb_dst_set(buff, dst); + ndisc_send_skb(buff, &ipv6_hdr(skb)->saddr, &saddr_buf); + return; + +release: + dst_release(dst); +} + +static void pndisc_redo(struct sk_buff *skb) +{ + ndisc_recv_ns(skb); + kfree_skb(skb); +} + +static int ndisc_is_multicast(const void *pkey) +{ + return ipv6_addr_is_multicast((struct in6_addr *)pkey); +} + +static bool ndisc_suppress_frag_ndisc(struct sk_buff *skb) +{ + struct inet6_dev *idev = __in6_dev_get(skb->dev); + + if (!idev) + return true; + if (IP6CB(skb)->flags & IP6SKB_FRAGMENTED && + idev->cnf.suppress_frag_ndisc) { + net_warn_ratelimited("Received fragmented ndisc packet. 
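
The Redirected Header option above is sized so the whole redirect still fits in an unfragmented 1280-byte packet: from IPV6_MIN_MTU subtract the IPv6 header (40 bytes), the redirect message body (40 bytes) and any target link-layer option, cap the result at the original packet length plus the 8-byte option header, and round down to a multiple of 8. A sketch of that calculation, with the sizes hard-coded for illustration:

#include <stdio.h>

#define IPV6_MIN_MTU 1280u
#define IPV6_HDR_LEN   40u      /* fixed IPv6 header                    */
#define RD_MSG_LEN     40u      /* ICMPv6 header plus target and dest   */

static unsigned int redirect_hdr_opt_len(unsigned int orig_pkt_len,
                                         unsigned int lladdr_opt_len)
{
        unsigned int room = IPV6_MIN_MTU - IPV6_HDR_LEN - RD_MSG_LEN - lladdr_opt_len;
        unsigned int want = orig_pkt_len + 8;   /* option type/len header + data */
        unsigned int rd_len = want < room ? want : room;

        return rd_len & ~0x7u;  /* ND options come in units of 8 octets */
}

int main(void)
{
        printf("%u\n", redirect_hdr_opt_len(100, 8));    /* small packet: 108 -> 104 */
        printf("%u\n", redirect_hdr_opt_len(2000, 8));   /* clamped to the room left */
        return 0;
}
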
Carefully consider disabling suppress_frag_ndisc.\n"); + return true; + } + return false; +} + +int ndisc_rcv(struct sk_buff *skb) +{ + struct nd_msg *msg; + + if (ndisc_suppress_frag_ndisc(skb)) + return 0; + + if (skb_linearize(skb)) + return 0; + + msg = (struct nd_msg *)skb_transport_header(skb); + + __skb_push(skb, skb->data - skb_transport_header(skb)); + + if (ipv6_hdr(skb)->hop_limit != 255) { + ND_PRINTK(2, warn, "NDISC: invalid hop-limit: %d\n", + ipv6_hdr(skb)->hop_limit); + return 0; + } + + if (msg->icmph.icmp6_code != 0) { + ND_PRINTK(2, warn, "NDISC: invalid ICMPv6 code: %d\n", + msg->icmph.icmp6_code); + return 0; + } + + switch (msg->icmph.icmp6_type) { + case NDISC_NEIGHBOUR_SOLICITATION: + memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb)); + ndisc_recv_ns(skb); + break; + + case NDISC_NEIGHBOUR_ADVERTISEMENT: + ndisc_recv_na(skb); + break; + + case NDISC_ROUTER_SOLICITATION: + ndisc_recv_rs(skb); + break; + + case NDISC_ROUTER_ADVERTISEMENT: + ndisc_router_discovery(skb); + break; + + case NDISC_REDIRECT: + ndisc_redirect_rcv(skb); + break; + } + + return 0; +} + +static int ndisc_netdev_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netdev_notifier_change_info *change_info; + struct net *net = dev_net(dev); + struct inet6_dev *idev; + bool evict_nocarrier; + + switch (event) { + case NETDEV_CHANGEADDR: + neigh_changeaddr(&nd_tbl, dev); + fib6_run_gc(0, net, false); + fallthrough; + case NETDEV_UP: + idev = in6_dev_get(dev); + if (!idev) + break; + if (idev->cnf.ndisc_notify || + net->ipv6.devconf_all->ndisc_notify) + ndisc_send_unsol_na(dev); + in6_dev_put(idev); + break; + case NETDEV_CHANGE: + idev = in6_dev_get(dev); + if (!idev) + evict_nocarrier = true; + else { + evict_nocarrier = idev->cnf.ndisc_evict_nocarrier && + net->ipv6.devconf_all->ndisc_evict_nocarrier; + in6_dev_put(idev); + } + + change_info = ptr; + if (change_info->flags_changed & IFF_NOARP) + neigh_changeaddr(&nd_tbl, dev); + if (evict_nocarrier && !netif_carrier_ok(dev)) + neigh_carrier_down(&nd_tbl, dev); + break; + case NETDEV_DOWN: + neigh_ifdown(&nd_tbl, dev); + fib6_run_gc(0, net, false); + break; + case NETDEV_NOTIFY_PEERS: + ndisc_send_unsol_na(dev); + break; + default: + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block ndisc_netdev_notifier = { + .notifier_call = ndisc_netdev_event, + .priority = ADDRCONF_NOTIFY_PRIORITY - 5, +}; + +#ifdef CONFIG_SYSCTL +static void ndisc_warn_deprecated_sysctl(struct ctl_table *ctl, + const char *func, const char *dev_name) +{ + static char warncomm[TASK_COMM_LEN]; + static int warned; + if (strcmp(warncomm, current->comm) && warned < 5) { + strcpy(warncomm, current->comm); + pr_warn("process `%s' is using deprecated sysctl (%s) net.ipv6.neigh.%s.%s - use net.ipv6.neigh.%s.%s_ms instead\n", + warncomm, func, + dev_name, ctl->procname, + dev_name, ctl->procname); + warned++; + } +} + +int ndisc_ifinfo_sysctl_change(struct ctl_table *ctl, int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + struct net_device *dev = ctl->extra1; + struct inet6_dev *idev; + int ret; + + if ((strcmp(ctl->procname, "retrans_time") == 0) || + (strcmp(ctl->procname, "base_reachable_time") == 0)) + ndisc_warn_deprecated_sysctl(ctl, "syscall", dev ? 
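
ndisc_rcv() above accepts a message only if the IPv6 hop limit is exactly 255 (so it cannot have crossed a router) and the ICMPv6 code is 0, then dispatches on the five ND message types. A compact userspace model of that gate and dispatch, using the type values from RFC 4861:

#include <stdio.h>

enum {                          /* ICMPv6 types, RFC 4861 */
        ND_RS       = 133,
        ND_RA       = 134,
        ND_NS       = 135,
        ND_NA       = 136,
        ND_REDIRECT = 137,
};

static const char *ndisc_dispatch(unsigned int hop_limit, unsigned int code,
                                  unsigned int type)
{
        if (hop_limit != 255)
                return "drop: hop limit not 255";
        if (code != 0)
                return "drop: non-zero ICMPv6 code";

        switch (type) {
        case ND_RS:       return "router solicitation";
        case ND_RA:       return "router advertisement";
        case ND_NS:       return "neighbour solicitation";
        case ND_NA:       return "neighbour advertisement";
        case ND_REDIRECT: return "redirect";
        default:          return "not an ND message";
        }
}

int main(void)
{
        printf("%s\n", ndisc_dispatch(255, 0, ND_NS));
        printf("%s\n", ndisc_dispatch(64, 0, ND_NS));
        return 0;
}
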
dev->name : "default"); + + if (strcmp(ctl->procname, "retrans_time") == 0) + ret = neigh_proc_dointvec(ctl, write, buffer, lenp, ppos); + + else if (strcmp(ctl->procname, "base_reachable_time") == 0) + ret = neigh_proc_dointvec_jiffies(ctl, write, + buffer, lenp, ppos); + + else if ((strcmp(ctl->procname, "retrans_time_ms") == 0) || + (strcmp(ctl->procname, "base_reachable_time_ms") == 0)) + ret = neigh_proc_dointvec_ms_jiffies(ctl, write, + buffer, lenp, ppos); + else + ret = -1; + + if (write && ret == 0 && dev && (idev = in6_dev_get(dev)) != NULL) { + if (ctl->data == &NEIGH_VAR(idev->nd_parms, BASE_REACHABLE_TIME)) + idev->nd_parms->reachable_time = + neigh_rand_reach_time(NEIGH_VAR(idev->nd_parms, BASE_REACHABLE_TIME)); + idev->tstamp = jiffies; + inet6_ifinfo_notify(RTM_NEWLINK, idev); + in6_dev_put(idev); + } + return ret; +} + + +#endif + +static int __net_init ndisc_net_init(struct net *net) +{ + struct ipv6_pinfo *np; + struct sock *sk; + int err; + + err = inet_ctl_sock_create(&sk, PF_INET6, + SOCK_RAW, IPPROTO_ICMPV6, net); + if (err < 0) { + ND_PRINTK(0, err, + "NDISC: Failed to initialize the control socket (err %d)\n", + err); + return err; + } + + net->ipv6.ndisc_sk = sk; + + np = inet6_sk(sk); + np->hop_limit = 255; + /* Do not loopback ndisc messages */ + np->mc_loop = 0; + + return 0; +} + +static void __net_exit ndisc_net_exit(struct net *net) +{ + inet_ctl_sock_destroy(net->ipv6.ndisc_sk); +} + +static struct pernet_operations ndisc_net_ops = { + .init = ndisc_net_init, + .exit = ndisc_net_exit, +}; + +int __init ndisc_init(void) +{ + int err; + + err = register_pernet_subsys(&ndisc_net_ops); + if (err) + return err; + /* + * Initialize the neighbour table + */ + neigh_table_init(NEIGH_ND_TABLE, &nd_tbl); + +#ifdef CONFIG_SYSCTL + err = neigh_sysctl_register(NULL, &nd_tbl.parms, + ndisc_ifinfo_sysctl_change); + if (err) + goto out_unregister_pernet; +out: +#endif + return err; + +#ifdef CONFIG_SYSCTL +out_unregister_pernet: + unregister_pernet_subsys(&ndisc_net_ops); + goto out; +#endif +} + +int __init ndisc_late_init(void) +{ + return register_netdevice_notifier(&ndisc_netdev_notifier); +} + +void ndisc_late_cleanup(void) +{ + unregister_netdevice_notifier(&ndisc_netdev_notifier); +} + +void ndisc_cleanup(void) +{ +#ifdef CONFIG_SYSCTL + neigh_sysctl_unregister(&nd_tbl.parms); +#endif + neigh_table_clear(NEIGH_ND_TABLE, &nd_tbl); + unregister_pernet_subsys(&ndisc_net_ops); +} diff --git a/net/ipv6/netfilter.c b/net/ipv6/netfilter.c new file mode 100644 index 000000000..857713d7a --- /dev/null +++ b/net/ipv6/netfilter.c @@ -0,0 +1,273 @@ +/* + * IPv6 specific functions of netfilter core + * + * Rusty Russell (C) 2000 -- This code is GPL. 
+ * Patrick McHardy (C) 2006-2012 + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../bridge/br_private.h" + +int ip6_route_me_harder(struct net *net, struct sock *sk_partial, struct sk_buff *skb) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct sock *sk = sk_to_full_sk(sk_partial); + struct net_device *dev = skb_dst(skb)->dev; + struct flow_keys flkeys; + unsigned int hh_len; + struct dst_entry *dst; + int strict = (ipv6_addr_type(&iph->daddr) & + (IPV6_ADDR_MULTICAST | IPV6_ADDR_LINKLOCAL)); + struct flowi6 fl6 = { + .flowi6_l3mdev = l3mdev_master_ifindex(dev), + .flowi6_mark = skb->mark, + .flowi6_uid = sock_net_uid(net, sk), + .daddr = iph->daddr, + .saddr = iph->saddr, + }; + int err; + + if (sk && sk->sk_bound_dev_if) + fl6.flowi6_oif = sk->sk_bound_dev_if; + else if (strict) + fl6.flowi6_oif = dev->ifindex; + + fib6_rules_early_flow_dissect(net, skb, &fl6, &flkeys); + dst = ip6_route_output(net, sk, &fl6); + err = dst->error; + if (err) { + IP6_INC_STATS(net, ip6_dst_idev(dst), IPSTATS_MIB_OUTNOROUTES); + net_dbg_ratelimited("ip6_route_me_harder: No more route\n"); + dst_release(dst); + return err; + } + + /* Drop old route. */ + skb_dst_drop(skb); + + skb_dst_set(skb, dst); + +#ifdef CONFIG_XFRM + if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) && + xfrm_decode_session(skb, flowi6_to_flowi(&fl6), AF_INET6) == 0) { + skb_dst_set(skb, NULL); + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), sk, 0); + if (IS_ERR(dst)) + return PTR_ERR(dst); + skb_dst_set(skb, dst); + } +#endif + + /* Change in oif may mean change in hh_len. */ + hh_len = skb_dst(skb)->dev->hard_header_len; + if (skb_headroom(skb) < hh_len && + pskb_expand_head(skb, HH_DATA_ALIGN(hh_len - skb_headroom(skb)), + 0, GFP_ATOMIC)) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL(ip6_route_me_harder); + +static int nf_ip6_reroute(struct sk_buff *skb, + const struct nf_queue_entry *entry) +{ + struct ip6_rt_info *rt_info = nf_queue_entry_reroute(entry); + + if (entry->state.hook == NF_INET_LOCAL_OUT) { + const struct ipv6hdr *iph = ipv6_hdr(skb); + if (!ipv6_addr_equal(&iph->daddr, &rt_info->daddr) || + !ipv6_addr_equal(&iph->saddr, &rt_info->saddr) || + skb->mark != rt_info->mark) + return ip6_route_me_harder(entry->state.net, entry->state.sk, skb); + } + return 0; +} + +int __nf_ip6_route(struct net *net, struct dst_entry **dst, + struct flowi *fl, bool strict) +{ + static const struct ipv6_pinfo fake_pinfo; + static const struct inet_sock fake_sk = { + /* makes ip6_route_output set RT6_LOOKUP_F_IFACE: */ + .sk.sk_bound_dev_if = 1, + .pinet6 = (struct ipv6_pinfo *) &fake_pinfo, + }; + const void *sk = strict ? 
&fake_sk : NULL; + struct dst_entry *result; + int err; + + result = ip6_route_output(net, sk, &fl->u.ip6); + err = result->error; + if (err) + dst_release(result); + else + *dst = result; + return err; +} +EXPORT_SYMBOL_GPL(__nf_ip6_route); + +int br_ip6_fragment(struct net *net, struct sock *sk, struct sk_buff *skb, + struct nf_bridge_frag_data *data, + int (*output)(struct net *, struct sock *sk, + const struct nf_bridge_frag_data *data, + struct sk_buff *)) +{ + int frag_max_size = BR_INPUT_SKB_CB(skb)->frag_max_size; + bool mono_delivery_time = skb->mono_delivery_time; + ktime_t tstamp = skb->tstamp; + struct ip6_frag_state state; + u8 *prevhdr, nexthdr = 0; + unsigned int mtu, hlen; + int hroom, err = 0; + __be32 frag_id; + + err = ip6_find_1stfragopt(skb, &prevhdr); + if (err < 0) + goto blackhole; + hlen = err; + nexthdr = *prevhdr; + + mtu = skb->dev->mtu; + if (frag_max_size > mtu || + frag_max_size < IPV6_MIN_MTU) + goto blackhole; + + mtu = frag_max_size; + if (mtu < hlen + sizeof(struct frag_hdr) + 8) + goto blackhole; + mtu -= hlen + sizeof(struct frag_hdr); + + frag_id = ipv6_select_ident(net, &ipv6_hdr(skb)->daddr, + &ipv6_hdr(skb)->saddr); + + if (skb->ip_summed == CHECKSUM_PARTIAL && + (err = skb_checksum_help(skb))) + goto blackhole; + + hroom = LL_RESERVED_SPACE(skb->dev); + if (skb_has_frag_list(skb)) { + unsigned int first_len = skb_pagelen(skb); + struct ip6_fraglist_iter iter; + struct sk_buff *frag2; + + if (first_len - hlen > mtu || + skb_headroom(skb) < (hroom + sizeof(struct frag_hdr))) + goto blackhole; + + if (skb_cloned(skb)) + goto slow_path; + + skb_walk_frags(skb, frag2) { + if (frag2->len > mtu || + skb_headroom(frag2) < (hlen + hroom + sizeof(struct frag_hdr))) + goto blackhole; + + /* Partially cloned skb? */ + if (skb_shared(frag2)) + goto slow_path; + } + + err = ip6_fraglist_init(skb, hlen, prevhdr, nexthdr, frag_id, + &iter); + if (err < 0) + goto blackhole; + + for (;;) { + /* Prepare header of the next frame, + * before previous one went down. + */ + if (iter.frag) + ip6_fraglist_prepare(skb, &iter); + + skb_set_delivery_time(skb, tstamp, mono_delivery_time); + err = output(net, sk, data, skb); + if (err || !iter.frag) + break; + + skb = ip6_fraglist_next(&iter); + } + + kfree(iter.tmp_hdr); + if (!err) + return 0; + + kfree_skb_list(iter.frag); + return err; + } +slow_path: + /* This is a linearized skbuff, the original geometry is lost for us. + * This may also be a clone skbuff, we could preserve the geometry for + * the copies but probably not worth the effort. 
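
br_ip6_fragment() above derives the per-fragment payload from the effective MTU by subtracting the unfragmentable part (hlen) and the 8-byte Fragment header, after checking that at least 8 bytes of payload remain. A sketch of that arithmetic (illustrative names, not the kernel helpers):

#include <stdbool.h>
#include <stdio.h>

#define FRAG_HDR_LEN 8u

static bool frag_payload_room(unsigned int mtu, unsigned int hlen,
                              unsigned int *payload)
{
        if (mtu < hlen + FRAG_HDR_LEN + 8)
                return false;                     /* too small: blackhole    */
        *payload = mtu - hlen - FRAG_HDR_LEN;     /* bytes left per fragment */
        return true;
}

int main(void)
{
        unsigned int payload;

        if (frag_payload_room(1280, 40, &payload))
                printf("payload per fragment: %u\n", payload);   /* 1232 */
        return 0;
}
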
+ */ + ip6_frag_init(skb, hlen, mtu, skb->dev->needed_tailroom, + LL_RESERVED_SPACE(skb->dev), prevhdr, nexthdr, frag_id, + &state); + + while (state.left > 0) { + struct sk_buff *skb2; + + skb2 = ip6_frag_next(skb, &state); + if (IS_ERR(skb2)) { + err = PTR_ERR(skb2); + goto blackhole; + } + + skb_set_delivery_time(skb2, tstamp, mono_delivery_time); + err = output(net, sk, data, skb2); + if (err) + goto blackhole; + } + consume_skb(skb); + return err; + +blackhole: + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL_GPL(br_ip6_fragment); + +static const struct nf_ipv6_ops ipv6ops = { +#if IS_MODULE(CONFIG_IPV6) + .chk_addr = ipv6_chk_addr, + .route_me_harder = ip6_route_me_harder, + .dev_get_saddr = ipv6_dev_get_saddr, + .route = __nf_ip6_route, +#if IS_ENABLED(CONFIG_SYN_COOKIES) + .cookie_init_sequence = __cookie_v6_init_sequence, + .cookie_v6_check = __cookie_v6_check, +#endif +#endif + .route_input = ip6_route_input, + .fragment = ip6_fragment, + .reroute = nf_ip6_reroute, +#if IS_MODULE(CONFIG_IPV6) + .br_fragment = br_ip6_fragment, +#endif +}; + +int __init ipv6_netfilter_init(void) +{ + RCU_INIT_POINTER(nf_ipv6_ops, &ipv6ops); + return 0; +} + +/* This can be called from inet6_init() on errors, so it cannot + * be marked __exit. -DaveM + */ +void ipv6_netfilter_fini(void) +{ + RCU_INIT_POINTER(nf_ipv6_ops, NULL); +} diff --git a/net/ipv6/netfilter/Kconfig b/net/ipv6/netfilter/Kconfig new file mode 100644 index 000000000..0ba62f486 --- /dev/null +++ b/net/ipv6/netfilter/Kconfig @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IP netfilter configuration +# + +menu "IPv6: Netfilter Configuration" + depends on INET && IPV6 && NETFILTER + +config NF_SOCKET_IPV6 + tristate "IPv6 socket lookup support" + help + This option enables the IPv6 socket lookup infrastructure. This + is used by the {ip6,nf}tables socket match. + +config NF_TPROXY_IPV6 + tristate "IPv6 tproxy support" + +if NF_TABLES + +config NF_TABLES_IPV6 + bool "IPv6 nf_tables support" + help + This option enables the IPv6 support for nf_tables. + +if NF_TABLES_IPV6 + +config NFT_REJECT_IPV6 + select NF_REJECT_IPV6 + default NFT_REJECT + tristate + +config NFT_DUP_IPV6 + tristate "IPv6 nf_tables packet duplication support" + depends on !NF_CONNTRACK || NF_CONNTRACK + select NF_DUP_IPV6 + help + This module enables IPv6 packet duplication support for nf_tables. + +config NFT_FIB_IPV6 + tristate "nf_tables fib / ipv6 route lookup support" + select NFT_FIB + help + This module enables IPv6 FIB lookups, e.g. for reverse path filtering. + It also allows query of the FIB for the route type, e.g. local, unicast, + multicast or blackhole. + +endif # NF_TABLES_IPV6 +endif # NF_TABLES + +config NF_DUP_IPV6 + tristate "Netfilter IPv6 packet duplication to alternate destination" + depends on !NF_CONNTRACK || NF_CONNTRACK + help + This option enables the nf_dup_ipv6 core, which duplicates an IPv6 + packet to be rerouted to another destination. + +config NF_REJECT_IPV6 + tristate "IPv6 packet rejection" + default m if NETFILTER_ADVANCED=n + +config NF_LOG_IPV6 + tristate "IPv6 packet logging" + default m if NETFILTER_ADVANCED=n + select NF_LOG_SYSLOG + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects CONFIG_NF_LOG_SYSLOG. + +config IP6_NF_IPTABLES + tristate "IP6 tables support (required for filtering)" + depends on INET && IPV6 + select NETFILTER_XTABLES + default m if NETFILTER_ADVANCED=n + help + ip6tables is a general, extensible packet identification framework. 
+ Currently only the packet filtering and packet mangling subsystem + for IPv6 use this, but connection tracking is going to follow. + Say 'Y' or 'M' here if you want to use either of those. + + To compile it as a module, choose M here. If unsure, say N. + +if IP6_NF_IPTABLES + +# The simple matches. +config IP6_NF_MATCH_AH + tristate '"ah" match support' + depends on NETFILTER_ADVANCED + help + This module allows one to match AH packets. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_EUI64 + tristate '"eui64" address check' + depends on NETFILTER_ADVANCED + help + This module performs checking on the IPv6 source address + Compares the last 64 bits with the EUI64 (delivered + from the MAC address) address + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_FRAG + tristate '"frag" Fragmentation header match support' + depends on NETFILTER_ADVANCED + help + frag matching allows you to match packets based on the fragmentation + header of the packet. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_OPTS + tristate '"hbh" hop-by-hop and "dst" opts header match support' + depends on NETFILTER_ADVANCED + help + This allows one to match packets based on the hop-by-hop + and destination options headers of a packet. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_HL + tristate '"hl" hoplimit match support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_MATCH_HL + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_MATCH_HL. + +config IP6_NF_MATCH_IPV6HEADER + tristate '"ipv6header" IPv6 Extension Headers Match' + default m if NETFILTER_ADVANCED=n + help + This module allows one to match packets based upon + the ipv6 extension headers. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_MH + tristate '"mh" match support' + depends on NETFILTER_ADVANCED + help + This module allows one to match MH packets. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_RPFILTER + tristate '"rpfilter" reverse path filter match support' + depends on NETFILTER_ADVANCED + depends on IP6_NF_MANGLE || IP6_NF_RAW + help + This option allows you to match packets whose replies would + go out via the interface the packet came in. + + To compile it as a module, choose M here. If unsure, say N. + The module will be called ip6t_rpfilter. + +config IP6_NF_MATCH_RT + tristate '"rt" Routing header match support' + depends on NETFILTER_ADVANCED + help + rt matching allows you to match packets based on the routing + header of the packet. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MATCH_SRH + tristate '"srh" Segment Routing header match support' + depends on NETFILTER_ADVANCED + help + srh matching allows you to match packets based on the segment + routing header of the packet. + + To compile it as a module, choose M here. If unsure, say N. + +# The targets +config IP6_NF_TARGET_HL + tristate '"HL" hoplimit target support' + depends on NETFILTER_ADVANCED && IP6_NF_MANGLE + select NETFILTER_XT_TARGET_HL + help + This is a backwards-compatible option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_TARGET_HL. 
+ +config IP6_NF_FILTER + tristate "Packet filtering" + default m if NETFILTER_ADVANCED=n + help + Packet filtering defines a table `filter', which has a series of + rules for simple packet filtering at local input, forwarding and + local output. See the man page for iptables(8). + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_TARGET_REJECT + tristate "REJECT target support" + depends on IP6_NF_FILTER + select NF_REJECT_IPV6 + default m if NETFILTER_ADVANCED=n + help + The REJECT target allows a filtering rule to specify that an ICMPv6 + error should be issued in response to an incoming packet, rather + than silently being dropped. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_TARGET_SYNPROXY + tristate "SYNPROXY target support" + depends on NF_CONNTRACK && NETFILTER_ADVANCED + select NETFILTER_SYNPROXY + select SYN_COOKIES + help + The SYNPROXY target allows you to intercept TCP connections and + establish them using syncookies before they are passed on to the + server. This allows to avoid conntrack and server resource usage + during SYN-flood attacks. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_MANGLE + tristate "Packet mangling" + default m if NETFILTER_ADVANCED=n + help + This option adds a `mangle' table to iptables: see the man page for + iptables(8). This table is used for various packet alterations + which can effect how the packet is routed. + + To compile it as a module, choose M here. If unsure, say N. + +config IP6_NF_RAW + tristate 'raw table support (required for TRACE)' + help + This option adds a `raw' table to ip6tables. This table is the very + first in the netfilter framework and hooks in at the PREROUTING + and OUTPUT chains. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +# security table for MAC policy +config IP6_NF_SECURITY + tristate "Security table" + depends on SECURITY + depends on NETFILTER_ADVANCED + help + This option adds a `security' table to iptables, for use + with Mandatory Access Control (MAC) policy. + + If unsure, say N. + +config IP6_NF_NAT + tristate "ip6tables NAT support" + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NF_NAT + select NETFILTER_XT_NAT + help + This enables the `nat' table in ip6tables. This allows masquerading, + port forwarding and other forms of full Network Address Port + Translation. + + To compile it as a module, choose M here. If unsure, say N. + +if IP6_NF_NAT + +config IP6_NF_TARGET_MASQUERADE + tristate "MASQUERADE target support" + select NETFILTER_XT_TARGET_MASQUERADE + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects NETFILTER_XT_TARGET_MASQUERADE. + +config IP6_NF_TARGET_NPT + tristate "NPT (Network Prefix translation) target support" + help + This option adds the `SNPT' and `DNPT' target, which perform + stateless IPv6-to-IPv6 Network Prefix Translation per RFC 6296. + + To compile it as a module, choose M here. If unsure, say N. + +endif # IP6_NF_NAT + +endif # IP6_NF_IPTABLES +endmenu + +config NF_DEFRAG_IPV6 + tristate diff --git a/net/ipv6/netfilter/Makefile b/net/ipv6/netfilter/Makefile new file mode 100644 index 000000000..b8d6dc9ae --- /dev/null +++ b/net/ipv6/netfilter/Makefile @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the netfilter modules on top of IPv6. +# + +# Link order matters here. 
+obj-$(CONFIG_IP6_NF_IPTABLES) += ip6_tables.o +obj-$(CONFIG_IP6_NF_FILTER) += ip6table_filter.o +obj-$(CONFIG_IP6_NF_MANGLE) += ip6table_mangle.o +obj-$(CONFIG_IP6_NF_RAW) += ip6table_raw.o +obj-$(CONFIG_IP6_NF_SECURITY) += ip6table_security.o +obj-$(CONFIG_IP6_NF_NAT) += ip6table_nat.o + +# defrag +nf_defrag_ipv6-y := nf_defrag_ipv6_hooks.o nf_conntrack_reasm.o +obj-$(CONFIG_NF_DEFRAG_IPV6) += nf_defrag_ipv6.o + +obj-$(CONFIG_NF_SOCKET_IPV6) += nf_socket_ipv6.o +obj-$(CONFIG_NF_TPROXY_IPV6) += nf_tproxy_ipv6.o + +# reject +obj-$(CONFIG_NF_REJECT_IPV6) += nf_reject_ipv6.o + +obj-$(CONFIG_NF_DUP_IPV6) += nf_dup_ipv6.o + +# nf_tables +obj-$(CONFIG_NFT_REJECT_IPV6) += nft_reject_ipv6.o +obj-$(CONFIG_NFT_DUP_IPV6) += nft_dup_ipv6.o +obj-$(CONFIG_NFT_FIB_IPV6) += nft_fib_ipv6.o + +# matches +obj-$(CONFIG_IP6_NF_MATCH_AH) += ip6t_ah.o +obj-$(CONFIG_IP6_NF_MATCH_EUI64) += ip6t_eui64.o +obj-$(CONFIG_IP6_NF_MATCH_FRAG) += ip6t_frag.o +obj-$(CONFIG_IP6_NF_MATCH_IPV6HEADER) += ip6t_ipv6header.o +obj-$(CONFIG_IP6_NF_MATCH_MH) += ip6t_mh.o +obj-$(CONFIG_IP6_NF_MATCH_OPTS) += ip6t_hbh.o +obj-$(CONFIG_IP6_NF_MATCH_RPFILTER) += ip6t_rpfilter.o +obj-$(CONFIG_IP6_NF_MATCH_RT) += ip6t_rt.o +obj-$(CONFIG_IP6_NF_MATCH_SRH) += ip6t_srh.o + +# targets +obj-$(CONFIG_IP6_NF_TARGET_NPT) += ip6t_NPT.o +obj-$(CONFIG_IP6_NF_TARGET_REJECT) += ip6t_REJECT.o +obj-$(CONFIG_IP6_NF_TARGET_SYNPROXY) += ip6t_SYNPROXY.o diff --git a/net/ipv6/netfilter/ip6_tables.c b/net/ipv6/netfilter/ip6_tables.c new file mode 100644 index 000000000..0ce0ed17c --- /dev/null +++ b/net/ipv6/netfilter/ip6_tables.c @@ -0,0 +1,1960 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Packet matching code. + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2005 Netfilter Core Team + * Copyright (c) 2006-2010 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include "../../netfilter/xt_repldata.h" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("IPv6 packet filter"); +MODULE_ALIAS("ip6t_icmp6"); + +void *ip6t_alloc_initial_table(const struct xt_table *info) +{ + return xt_alloc_initial_table(ip6t, IP6T); +} +EXPORT_SYMBOL_GPL(ip6t_alloc_initial_table); + +/* Returns whether matches rule or not. */ +/* Performance critical - called for every packet */ +static inline bool +ip6_packet_match(const struct sk_buff *skb, + const char *indev, + const char *outdev, + const struct ip6t_ip6 *ip6info, + unsigned int *protoff, + u16 *fragoff, bool *hotdrop) +{ + unsigned long ret; + const struct ipv6hdr *ipv6 = ipv6_hdr(skb); + + if (NF_INVF(ip6info, IP6T_INV_SRCIP, + ipv6_masked_addr_cmp(&ipv6->saddr, &ip6info->smsk, + &ip6info->src)) || + NF_INVF(ip6info, IP6T_INV_DSTIP, + ipv6_masked_addr_cmp(&ipv6->daddr, &ip6info->dmsk, + &ip6info->dst))) + return false; + + ret = ifname_compare_aligned(indev, ip6info->iniface, ip6info->iniface_mask); + + if (NF_INVF(ip6info, IP6T_INV_VIA_IN, ret != 0)) + return false; + + ret = ifname_compare_aligned(outdev, ip6info->outiface, ip6info->outiface_mask); + + if (NF_INVF(ip6info, IP6T_INV_VIA_OUT, ret != 0)) + return false; + +/* ... might want to do something with class and flowlabel here ... 
*/ + + /* look for the desired protocol header */ + if (ip6info->flags & IP6T_F_PROTO) { + int protohdr; + unsigned short _frag_off; + + protohdr = ipv6_find_hdr(skb, protoff, -1, &_frag_off, NULL); + if (protohdr < 0) { + if (_frag_off == 0) + *hotdrop = true; + return false; + } + *fragoff = _frag_off; + + if (ip6info->proto == protohdr) { + if (ip6info->invflags & IP6T_INV_PROTO) + return false; + + return true; + } + + /* We need match for the '-p all', too! */ + if ((ip6info->proto != 0) && + !(ip6info->invflags & IP6T_INV_PROTO)) + return false; + } + return true; +} + +/* should be ip6 safe */ +static bool +ip6_checkentry(const struct ip6t_ip6 *ipv6) +{ + if (ipv6->flags & ~IP6T_F_MASK) + return false; + if (ipv6->invflags & ~IP6T_INV_MASK) + return false; + + return true; +} + +static unsigned int +ip6t_error(struct sk_buff *skb, const struct xt_action_param *par) +{ + net_info_ratelimited("error: `%s'\n", (const char *)par->targinfo); + + return NF_DROP; +} + +static inline struct ip6t_entry * +get_entry(const void *base, unsigned int offset) +{ + return (struct ip6t_entry *)(base + offset); +} + +/* All zeroes == unconditional rule. */ +/* Mildly perf critical (only if packet tracing is on) */ +static inline bool unconditional(const struct ip6t_entry *e) +{ + static const struct ip6t_ip6 uncond; + + return e->target_offset == sizeof(struct ip6t_entry) && + memcmp(&e->ipv6, &uncond, sizeof(uncond)) == 0; +} + +static inline const struct xt_entry_target * +ip6t_get_target_c(const struct ip6t_entry *e) +{ + return ip6t_get_target((struct ip6t_entry *)e); +} + +#if IS_ENABLED(CONFIG_NETFILTER_XT_TARGET_TRACE) +/* This cries for unification! */ +static const char *const hooknames[] = { + [NF_INET_PRE_ROUTING] = "PREROUTING", + [NF_INET_LOCAL_IN] = "INPUT", + [NF_INET_FORWARD] = "FORWARD", + [NF_INET_LOCAL_OUT] = "OUTPUT", + [NF_INET_POST_ROUTING] = "POSTROUTING", +}; + +enum nf_ip_trace_comments { + NF_IP6_TRACE_COMMENT_RULE, + NF_IP6_TRACE_COMMENT_RETURN, + NF_IP6_TRACE_COMMENT_POLICY, +}; + +static const char *const comments[] = { + [NF_IP6_TRACE_COMMENT_RULE] = "rule", + [NF_IP6_TRACE_COMMENT_RETURN] = "return", + [NF_IP6_TRACE_COMMENT_POLICY] = "policy", +}; + +static const struct nf_loginfo trace_loginfo = { + .type = NF_LOG_TYPE_LOG, + .u = { + .log = { + .level = LOGLEVEL_WARNING, + .logflags = NF_LOG_DEFAULT_MASK, + }, + }, +}; + +/* Mildly perf critical (only if packet tracing is on) */ +static inline int +get_chainname_rulenum(const struct ip6t_entry *s, const struct ip6t_entry *e, + const char *hookname, const char **chainname, + const char **comment, unsigned int *rulenum) +{ + const struct xt_standard_target *t = (void *)ip6t_get_target_c(s); + + if (strcmp(t->target.u.kernel.target->name, XT_ERROR_TARGET) == 0) { + /* Head of user chain: ERROR target with chainname */ + *chainname = t->target.data; + (*rulenum) = 0; + } else if (s == e) { + (*rulenum)++; + + if (unconditional(s) && + strcmp(t->target.u.kernel.target->name, + XT_STANDARD_TARGET) == 0 && + t->verdict < 0) { + /* Tail of chains: STANDARD target (return/policy) */ + *comment = *chainname == hookname + ? 
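
The address and interface checks in ip6_packet_match() above all follow the same invert-flag pattern (the NF_INVF() macro): a field passes when the raw comparison result XORed with its inversion flag is false, so "-s addr" and "! -s addr" share one code path. A tiny sketch of that pattern:

#include <stdbool.h>
#include <stdio.h>

/* mismatch: the raw test failed (e.g. source address differs from the rule).
 * invert:   the rule carries the corresponding IP6T_INV_* flag ("!").
 */
static bool field_matches(bool mismatch, bool invert)
{
        return !(mismatch ^ invert);
}

int main(void)
{
        printf("%d\n", field_matches(false, false));  /* 1: plain match        */
        printf("%d\n", field_matches(true,  true));   /* 1: negated rule match */
        printf("%d\n", field_matches(true,  false));  /* 0: no match           */
        return 0;
}
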
comments[NF_IP6_TRACE_COMMENT_POLICY] + : comments[NF_IP6_TRACE_COMMENT_RETURN]; + } + return 1; + } else + (*rulenum)++; + + return 0; +} + +static void trace_packet(struct net *net, + const struct sk_buff *skb, + unsigned int hook, + const struct net_device *in, + const struct net_device *out, + const char *tablename, + const struct xt_table_info *private, + const struct ip6t_entry *e) +{ + const struct ip6t_entry *root; + const char *hookname, *chainname, *comment; + const struct ip6t_entry *iter; + unsigned int rulenum = 0; + + root = get_entry(private->entries, private->hook_entry[hook]); + + hookname = chainname = hooknames[hook]; + comment = comments[NF_IP6_TRACE_COMMENT_RULE]; + + xt_entry_foreach(iter, root, private->size - private->hook_entry[hook]) + if (get_chainname_rulenum(iter, e, hookname, + &chainname, &comment, &rulenum) != 0) + break; + + nf_log_trace(net, AF_INET6, hook, skb, in, out, &trace_loginfo, + "TRACE: %s:%s:%s:%u ", + tablename, chainname, comment, rulenum); +} +#endif + +static inline struct ip6t_entry * +ip6t_next_entry(const struct ip6t_entry *entry) +{ + return (void *)entry + entry->next_offset; +} + +/* Returns one of the generic firewall policies, like NF_ACCEPT. */ +unsigned int +ip6t_do_table(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct xt_table *table = priv; + unsigned int hook = state->hook; + static const char nulldevname[IFNAMSIZ] __attribute__((aligned(sizeof(long)))); + /* Initializing verdict to NF_DROP keeps gcc happy. */ + unsigned int verdict = NF_DROP; + const char *indev, *outdev; + const void *table_base; + struct ip6t_entry *e, **jumpstack; + unsigned int stackidx, cpu; + const struct xt_table_info *private; + struct xt_action_param acpar; + unsigned int addend; + + /* Initialization */ + stackidx = 0; + indev = state->in ? state->in->name : nulldevname; + outdev = state->out ? state->out->name : nulldevname; + /* We handle fragments by dealing with the first fragment as + * if it was a normal packet. All other fragments are treated + * normally, except that they will NEVER match rules that ask + * things we don't know, ie. tcp syn flag or ports). If the + * rule is also a fragment-specific rule, non-fragments won't + * match it. */ + acpar.fragoff = 0; + acpar.hotdrop = false; + acpar.state = state; + + WARN_ON(!(table->valid_hooks & (1 << hook))); + + local_bh_disable(); + addend = xt_write_recseq_begin(); + private = READ_ONCE(table->private); /* Address dependency. */ + cpu = smp_processor_id(); + table_base = private->entries; + jumpstack = (struct ip6t_entry **)private->jumpstack[cpu]; + + /* Switch to alternate jumpstack if we're being invoked via TEE. + * TEE issues XT_CONTINUE verdict on original skb so we must not + * clobber the jumpstack. + * + * For recursion via REJECT or SYNPROXY the stack will be clobbered + * but it is no problem since absolute verdict is issued by these. 
+ */ + if (static_key_false(&xt_tee_enabled)) + jumpstack += private->stacksize * __this_cpu_read(nf_skb_duplicated); + + e = get_entry(table_base, private->hook_entry[hook]); + + do { + const struct xt_entry_target *t; + const struct xt_entry_match *ematch; + struct xt_counters *counter; + + WARN_ON(!e); + acpar.thoff = 0; + if (!ip6_packet_match(skb, indev, outdev, &e->ipv6, + &acpar.thoff, &acpar.fragoff, &acpar.hotdrop)) { + no_match: + e = ip6t_next_entry(e); + continue; + } + + xt_ematch_foreach(ematch, e) { + acpar.match = ematch->u.kernel.match; + acpar.matchinfo = ematch->data; + if (!acpar.match->match(skb, &acpar)) + goto no_match; + } + + counter = xt_get_this_cpu_counter(&e->counters); + ADD_COUNTER(*counter, skb->len, 1); + + t = ip6t_get_target_c(e); + WARN_ON(!t->u.kernel.target); + +#if IS_ENABLED(CONFIG_NETFILTER_XT_TARGET_TRACE) + /* The packet is traced: log it */ + if (unlikely(skb->nf_trace)) + trace_packet(state->net, skb, hook, state->in, + state->out, table->name, private, e); +#endif + /* Standard target? */ + if (!t->u.kernel.target->target) { + int v; + + v = ((struct xt_standard_target *)t)->verdict; + if (v < 0) { + /* Pop from stack? */ + if (v != XT_RETURN) { + verdict = (unsigned int)(-v) - 1; + break; + } + if (stackidx == 0) + e = get_entry(table_base, + private->underflow[hook]); + else + e = ip6t_next_entry(jumpstack[--stackidx]); + continue; + } + if (table_base + v != ip6t_next_entry(e) && + !(e->ipv6.flags & IP6T_F_GOTO)) { + if (unlikely(stackidx >= private->stacksize)) { + verdict = NF_DROP; + break; + } + jumpstack[stackidx++] = e; + } + + e = get_entry(table_base, v); + continue; + } + + acpar.target = t->u.kernel.target; + acpar.targinfo = t->data; + + verdict = t->u.kernel.target->target(skb, &acpar); + if (verdict == XT_CONTINUE) + e = ip6t_next_entry(e); + else + /* Verdict */ + break; + } while (!acpar.hotdrop); + + xt_write_recseq_end(addend); + local_bh_enable(); + + if (acpar.hotdrop) + return NF_DROP; + else return verdict; +} + +/* Figures out from what hook each rule can be called: returns 0 if + there are loops. Puts hook bitmask in comefrom. */ +static int +mark_source_chains(const struct xt_table_info *newinfo, + unsigned int valid_hooks, void *entry0, + unsigned int *offsets) +{ + unsigned int hook; + + /* No recursion; use packet counter to save back ptrs (reset + to 0 as we leave), and comefrom to save source hook bitmask */ + for (hook = 0; hook < NF_INET_NUMHOOKS; hook++) { + unsigned int pos = newinfo->hook_entry[hook]; + struct ip6t_entry *e = entry0 + pos; + + if (!(valid_hooks & (1 << hook))) + continue; + + /* Set initial back pointer. */ + e->counters.pcnt = pos; + + for (;;) { + const struct xt_standard_target *t + = (void *)ip6t_get_target_c(e); + int visited = e->comefrom & (1 << hook); + + if (e->comefrom & (1 << NF_INET_NUMHOOKS)) + return 0; + + e->comefrom |= ((1 << hook) | (1 << NF_INET_NUMHOOKS)); + + /* Unconditional return/END. */ + if ((unconditional(e) && + (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0) && + t->verdict < 0) || visited) { + unsigned int oldpos, size; + + /* Return: backtrack through the last + big jump. */ + do { + e->comefrom ^= (1<counters.pcnt; + e->counters.pcnt = 0; + + /* We're at the start. 
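
The standard-target handling in ip6t_do_table() above relies on how verdicts are encoded: non-negative values are byte offsets to jump to, XT_RETURN pops the rule stack, and absolute verdicts such as NF_ACCEPT or NF_DROP are stored as negative numbers and recovered with (unsigned)(-v) - 1. A small sketch of that encoding, with the constants redefined locally for illustration:

#include <stdio.h>

#define NF_DROP_V   0u
#define NF_ACCEPT_V 1u

static int encode_verdict(unsigned int verdict)
{
        return -(int)verdict - 1;               /* NF_ACCEPT -> -2, NF_DROP -> -1 */
}

static unsigned int decode_verdict(int stored)
{
        return (unsigned int)(-stored) - 1;     /* -2 -> NF_ACCEPT, -1 -> NF_DROP */
}

int main(void)
{
        int stored = encode_verdict(NF_ACCEPT_V);

        printf("stored=%d decoded=%u\n", stored, decode_verdict(stored));
        return 0;
}
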
*/ + if (pos == oldpos) + goto next; + + e = entry0 + pos; + } while (oldpos == pos + e->next_offset); + + /* Move along one */ + size = e->next_offset; + e = entry0 + pos + size; + if (pos + size >= newinfo->size) + return 0; + e->counters.pcnt = pos; + pos += size; + } else { + int newpos = t->verdict; + + if (strcmp(t->target.u.user.name, + XT_STANDARD_TARGET) == 0 && + newpos >= 0) { + /* This a jump; chase it. */ + if (!xt_find_jump_offset(offsets, newpos, + newinfo->number)) + return 0; + } else { + /* ... this is a fallthru */ + newpos = pos + e->next_offset; + if (newpos >= newinfo->size) + return 0; + } + e = entry0 + newpos; + e->counters.pcnt = pos; + pos = newpos; + } + } +next: ; + } + return 1; +} + +static void cleanup_match(struct xt_entry_match *m, struct net *net) +{ + struct xt_mtdtor_param par; + + par.net = net; + par.match = m->u.kernel.match; + par.matchinfo = m->data; + par.family = NFPROTO_IPV6; + if (par.match->destroy != NULL) + par.match->destroy(&par); + module_put(par.match->me); +} + +static int check_match(struct xt_entry_match *m, struct xt_mtchk_param *par) +{ + const struct ip6t_ip6 *ipv6 = par->entryinfo; + + par->match = m->u.kernel.match; + par->matchinfo = m->data; + + return xt_check_match(par, m->u.match_size - sizeof(*m), + ipv6->proto, ipv6->invflags & IP6T_INV_PROTO); +} + +static int +find_check_match(struct xt_entry_match *m, struct xt_mtchk_param *par) +{ + struct xt_match *match; + int ret; + + match = xt_request_find_match(NFPROTO_IPV6, m->u.user.name, + m->u.user.revision); + if (IS_ERR(match)) + return PTR_ERR(match); + + m->u.kernel.match = match; + + ret = check_match(m, par); + if (ret) + goto err; + + return 0; +err: + module_put(m->u.kernel.match->me); + return ret; +} + +static int check_target(struct ip6t_entry *e, struct net *net, const char *name) +{ + struct xt_entry_target *t = ip6t_get_target(e); + struct xt_tgchk_param par = { + .net = net, + .table = name, + .entryinfo = e, + .target = t->u.kernel.target, + .targinfo = t->data, + .hook_mask = e->comefrom, + .family = NFPROTO_IPV6, + }; + + return xt_check_target(&par, t->u.target_size - sizeof(*t), + e->ipv6.proto, + e->ipv6.invflags & IP6T_INV_PROTO); +} + +static int +find_check_entry(struct ip6t_entry *e, struct net *net, const char *name, + unsigned int size, + struct xt_percpu_counter_alloc_state *alloc_state) +{ + struct xt_entry_target *t; + struct xt_target *target; + int ret; + unsigned int j; + struct xt_mtchk_param mtpar; + struct xt_entry_match *ematch; + + if (!xt_percpu_counter_alloc(alloc_state, &e->counters)) + return -ENOMEM; + + j = 0; + memset(&mtpar, 0, sizeof(mtpar)); + mtpar.net = net; + mtpar.table = name; + mtpar.entryinfo = &e->ipv6; + mtpar.hook_mask = e->comefrom; + mtpar.family = NFPROTO_IPV6; + xt_ematch_foreach(ematch, e) { + ret = find_check_match(ematch, &mtpar); + if (ret != 0) + goto cleanup_matches; + ++j; + } + + t = ip6t_get_target(e); + target = xt_request_find_target(NFPROTO_IPV6, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto cleanup_matches; + } + t->u.kernel.target = target; + + ret = check_target(e, net, name); + if (ret) + goto err; + return 0; + err: + module_put(t->u.kernel.target->me); + cleanup_matches: + xt_ematch_foreach(ematch, e) { + if (j-- == 0) + break; + cleanup_match(ematch, net); + } + + xt_percpu_counter_free(&e->counters); + + return ret; +} + +static bool check_underflow(const struct ip6t_entry *e) +{ + const struct xt_entry_target *t; + unsigned int verdict; + + if 
(!unconditional(e)) + return false; + t = ip6t_get_target_c(e); + if (strcmp(t->u.user.name, XT_STANDARD_TARGET) != 0) + return false; + verdict = ((struct xt_standard_target *)t)->verdict; + verdict = -verdict - 1; + return verdict == NF_DROP || verdict == NF_ACCEPT; +} + +static int +check_entry_size_and_hooks(struct ip6t_entry *e, + struct xt_table_info *newinfo, + const unsigned char *base, + const unsigned char *limit, + const unsigned int *hook_entries, + const unsigned int *underflows, + unsigned int valid_hooks) +{ + unsigned int h; + int err; + + if ((unsigned long)e % __alignof__(struct ip6t_entry) != 0 || + (unsigned char *)e + sizeof(struct ip6t_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset + < sizeof(struct ip6t_entry) + sizeof(struct xt_entry_target)) + return -EINVAL; + + if (!ip6_checkentry(&e->ipv6)) + return -EINVAL; + + err = xt_check_entry_offsets(e, e->elems, e->target_offset, + e->next_offset); + if (err) + return err; + + /* Check hooks & underflows */ + for (h = 0; h < NF_INET_NUMHOOKS; h++) { + if (!(valid_hooks & (1 << h))) + continue; + if ((unsigned char *)e - base == hook_entries[h]) + newinfo->hook_entry[h] = hook_entries[h]; + if ((unsigned char *)e - base == underflows[h]) { + if (!check_underflow(e)) + return -EINVAL; + + newinfo->underflow[h] = underflows[h]; + } + } + + /* Clear counters and comefrom */ + e->counters = ((struct xt_counters) { 0, 0 }); + e->comefrom = 0; + return 0; +} + +static void cleanup_entry(struct ip6t_entry *e, struct net *net) +{ + struct xt_tgdtor_param par; + struct xt_entry_target *t; + struct xt_entry_match *ematch; + + /* Cleanup all matches */ + xt_ematch_foreach(ematch, e) + cleanup_match(ematch, net); + t = ip6t_get_target(e); + + par.net = net; + par.target = t->u.kernel.target; + par.targinfo = t->data; + par.family = NFPROTO_IPV6; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); + xt_percpu_counter_free(&e->counters); +} + +/* Checks and translates the user-supplied table segment (held in + newinfo) */ +static int +translate_table(struct net *net, struct xt_table_info *newinfo, void *entry0, + const struct ip6t_replace *repl) +{ + struct xt_percpu_counter_alloc_state alloc_state = { 0 }; + struct ip6t_entry *iter; + unsigned int *offsets; + unsigned int i; + int ret = 0; + + newinfo->size = repl->size; + newinfo->number = repl->num_entries; + + /* Init all hooks to impossible value. */ + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + newinfo->hook_entry[i] = 0xFFFFFFFF; + newinfo->underflow[i] = 0xFFFFFFFF; + } + + offsets = xt_alloc_entry_offsets(newinfo->number); + if (!offsets) + return -ENOMEM; + i = 0; + /* Walk through entries, checking offsets. 
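The walk that follows relies on the flat ruleset layout: each variable-sized entry carries target_offset (where its target record starts) and next_offset (its total size), and check_entry_size_and_hooks() above rejects anything whose offsets leave the blob. A simplified userspace model of that offset-driven iteration, using a made-up two-field header rather than the real struct:

/* Illustrative sketch, not the real layout: offset-driven iteration over a
 * flat blob of variable-sized rules, as xt_entry_foreach() does above. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct toy_entry {			/* stand-in for struct ip6t_entry */
	uint16_t target_offset;		/* where this rule's target begins */
	uint16_t next_offset;		/* total size: entry + matches + target */
};

static void walk(const unsigned char *blob, unsigned int size)
{
	unsigned int off = 0;

	while (off + sizeof(struct toy_entry) <= size) {
		struct toy_entry e;

		memcpy(&e, blob + off, sizeof(e));
		/* mirror check_entry_size_and_hooks(): offsets must stay in bounds */
		if (e.next_offset < sizeof(e) || off + e.next_offset > size) {
			printf("bad entry at offset %u\n", off);
			return;
		}
		printf("entry at %u: target at +%u, next at %u\n",
		       off, e.target_offset, off + e.next_offset);
		off += e.next_offset;
	}
}

int main(void)
{
	unsigned char blob[24] = { 0 };
	struct toy_entry e1 = { .target_offset = 8, .next_offset = 16 };
	struct toy_entry e2 = { .target_offset = 4, .next_offset = 8 };

	memcpy(blob, &e1, sizeof(e1));
	memcpy(blob + 16, &e2, sizeof(e2));
	walk(blob, sizeof(blob));
	return 0;
}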
*/ + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = check_entry_size_and_hooks(iter, newinfo, entry0, + entry0 + repl->size, + repl->hook_entry, + repl->underflow, + repl->valid_hooks); + if (ret != 0) + goto out_free; + if (i < repl->num_entries) + offsets[i] = (void *)iter - entry0; + ++i; + if (strcmp(ip6t_get_target(iter)->u.user.name, + XT_ERROR_TARGET) == 0) + ++newinfo->stacksize; + } + + ret = -EINVAL; + if (i != repl->num_entries) + goto out_free; + + ret = xt_check_table_hooks(newinfo, repl->valid_hooks); + if (ret) + goto out_free; + + if (!mark_source_chains(newinfo, repl->valid_hooks, entry0, offsets)) { + ret = -ELOOP; + goto out_free; + } + kvfree(offsets); + + /* Finally, each sanity check must pass */ + i = 0; + xt_entry_foreach(iter, entry0, newinfo->size) { + ret = find_check_entry(iter, net, repl->name, repl->size, + &alloc_state); + if (ret != 0) + break; + ++i; + } + + if (ret != 0) { + xt_entry_foreach(iter, entry0, newinfo->size) { + if (i-- == 0) + break; + cleanup_entry(iter, net); + } + return ret; + } + + return ret; + out_free: + kvfree(offsets); + return ret; +} + +static void +get_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct ip6t_entry *iter; + unsigned int cpu; + unsigned int i; + + for_each_possible_cpu(cpu) { + seqcount_t *s = &per_cpu(xt_recseq, cpu); + + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + struct xt_counters *tmp; + u64 bcnt, pcnt; + unsigned int start; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + do { + start = read_seqcount_begin(s); + bcnt = tmp->bcnt; + pcnt = tmp->pcnt; + } while (read_seqcount_retry(s, start)); + + ADD_COUNTER(counters[i], bcnt, pcnt); + ++i; + cond_resched(); + } + } +} + +static void get_old_counters(const struct xt_table_info *t, + struct xt_counters counters[]) +{ + struct ip6t_entry *iter; + unsigned int cpu, i; + + for_each_possible_cpu(cpu) { + i = 0; + xt_entry_foreach(iter, t->entries, t->size) { + const struct xt_counters *tmp; + + tmp = xt_get_per_cpu_counter(&iter->counters, cpu); + ADD_COUNTER(counters[i], tmp->bcnt, tmp->pcnt); + ++i; + } + cond_resched(); + } +} + +static struct xt_counters *alloc_counters(const struct xt_table *table) +{ + unsigned int countersize; + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + + /* We need atomic snapshot of counters: rest doesn't change + (other than comefrom, which userspace doesn't care + about). */ + countersize = sizeof(struct xt_counters) * private->number; + counters = vzalloc(countersize); + + if (counters == NULL) + return ERR_PTR(-ENOMEM); + + get_counters(private, counters); + + return counters; +} + +static int +copy_entries_to_user(unsigned int total_size, + const struct xt_table *table, + void __user *userptr) +{ + unsigned int off, num; + const struct ip6t_entry *e; + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + int ret = 0; + const void *loc_cpu_entry; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + loc_cpu_entry = private->entries; + + /* FIXME: use iterator macros --RR */ + /* ... 
then go back and fix counters and names */ + for (off = 0, num = 0; off < total_size; off += e->next_offset, num++){ + unsigned int i; + const struct xt_entry_match *m; + const struct xt_entry_target *t; + + e = loc_cpu_entry + off; + if (copy_to_user(userptr + off, e, sizeof(*e))) { + ret = -EFAULT; + goto free_counters; + } + if (copy_to_user(userptr + off + + offsetof(struct ip6t_entry, counters), + &counters[num], + sizeof(counters[num])) != 0) { + ret = -EFAULT; + goto free_counters; + } + + for (i = sizeof(struct ip6t_entry); + i < e->target_offset; + i += m->u.match_size) { + m = (void *)e + i; + + if (xt_match_to_user(m, userptr + off + i)) { + ret = -EFAULT; + goto free_counters; + } + } + + t = ip6t_get_target_c(e); + if (xt_target_to_user(t, userptr + off + e->target_offset)) { + ret = -EFAULT; + goto free_counters; + } + } + + free_counters: + vfree(counters); + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +static void compat_standard_from_user(void *dst, const void *src) +{ + int v = *(compat_int_t *)src; + + if (v > 0) + v += xt_compat_calc_jump(AF_INET6, v); + memcpy(dst, &v, sizeof(v)); +} + +static int compat_standard_to_user(void __user *dst, const void *src) +{ + compat_int_t cv = *(int *)src; + + if (cv > 0) + cv -= xt_compat_calc_jump(AF_INET6, cv); + return copy_to_user(dst, &cv, sizeof(cv)) ? -EFAULT : 0; +} + +static int compat_calc_entry(const struct ip6t_entry *e, + const struct xt_table_info *info, + const void *base, struct xt_table_info *newinfo) +{ + const struct xt_entry_match *ematch; + const struct xt_entry_target *t; + unsigned int entry_offset; + int off, i, ret; + + off = sizeof(struct ip6t_entry) - sizeof(struct compat_ip6t_entry); + entry_offset = (void *)e - base; + xt_ematch_foreach(ematch, e) + off += xt_compat_match_offset(ematch->u.kernel.match); + t = ip6t_get_target_c(e); + off += xt_compat_target_offset(t->u.kernel.target); + newinfo->size -= off; + ret = xt_compat_add_offset(AF_INET6, entry_offset, off); + if (ret) + return ret; + + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + if (info->hook_entry[i] && + (e < (struct ip6t_entry *)(base + info->hook_entry[i]))) + newinfo->hook_entry[i] -= off; + if (info->underflow[i] && + (e < (struct ip6t_entry *)(base + info->underflow[i]))) + newinfo->underflow[i] -= off; + } + return 0; +} + +static int compat_table_info(const struct xt_table_info *info, + struct xt_table_info *newinfo) +{ + struct ip6t_entry *iter; + const void *loc_cpu_entry; + int ret; + + if (!newinfo || !info) + return -EINVAL; + + /* we dont care about newinfo->entries */ + memcpy(newinfo, info, offsetof(struct xt_table_info, entries)); + newinfo->initial_entries = 0; + loc_cpu_entry = info->entries; + ret = xt_compat_init_offsets(AF_INET6, info->number); + if (ret) + return ret; + xt_entry_foreach(iter, loc_cpu_entry, info->size) { + ret = compat_calc_entry(iter, info, loc_cpu_entry, newinfo); + if (ret != 0) + return ret; + } + return 0; +} +#endif + +static int get_info(struct net *net, void __user *user, const int *len) +{ + char name[XT_TABLE_MAXNAMELEN]; + struct xt_table *t; + int ret; + + if (*len != sizeof(struct ip6t_getinfo)) + return -EINVAL; + + if (copy_from_user(name, user, sizeof(name)) != 0) + return -EFAULT; + + name[XT_TABLE_MAXNAMELEN-1] = '\0'; +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_lock(AF_INET6); +#endif + t = xt_request_find_table_lock(net, AF_INET6, name); + if (!IS_ERR(t)) { + struct ip6t_getinfo info; + const struct xt_table_info *private = t->private; 
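As an aside on alloc_counters()/get_counters() used by copy_entries_to_user() above: readers snapshot the per-CPU byte and packet counters under a sequence counter and retry whenever a writer was active, rather than locking writers out. A toy single-threaded model of that retry pattern (the kernel uses the per-CPU xt_recseq seqcount; this is only a sketch):

/* Illustrative single-threaded model, not from this patch. */
#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

struct counters { uint64_t bcnt, pcnt; };

static atomic_uint seq;			/* even: idle, odd: write in progress */
static struct counters cnt;

static void writer_account(uint64_t bytes)
{
	atomic_fetch_add_explicit(&seq, 1, memory_order_acquire);	/* -> odd */
	cnt.bcnt += bytes;
	cnt.pcnt += 1;
	atomic_fetch_add_explicit(&seq, 1, memory_order_release);	/* -> even */
}

static struct counters reader_snapshot(void)
{
	struct counters snap;
	unsigned int start;

	do {
		start = atomic_load_explicit(&seq, memory_order_acquire);
		snap = cnt;
	} while ((start & 1) ||
		 atomic_load_explicit(&seq, memory_order_acquire) != start);

	return snap;
}

int main(void)
{
	struct counters c;

	writer_account(1500);
	writer_account(40);
	c = reader_snapshot();
	printf("pcnt=%llu bcnt=%llu\n",
	       (unsigned long long)c.pcnt, (unsigned long long)c.bcnt);
	return 0;
}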
+#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + struct xt_table_info tmp; + + if (in_compat_syscall()) { + ret = compat_table_info(private, &tmp); + xt_compat_flush_offsets(AF_INET6); + private = &tmp; + } +#endif + memset(&info, 0, sizeof(info)); + info.valid_hooks = t->valid_hooks; + memcpy(info.hook_entry, private->hook_entry, + sizeof(info.hook_entry)); + memcpy(info.underflow, private->underflow, + sizeof(info.underflow)); + info.num_entries = private->number; + info.size = private->size; + strcpy(info.name, name); + + if (copy_to_user(user, &info, *len) != 0) + ret = -EFAULT; + else + ret = 0; + + xt_table_unlock(t); + module_put(t->me); + } else + ret = PTR_ERR(t); +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + xt_compat_unlock(AF_INET6); +#endif + return ret; +} + +static int +get_entries(struct net *net, struct ip6t_get_entries __user *uptr, + const int *len) +{ + int ret; + struct ip6t_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + if (*len != sizeof(struct ip6t_get_entries) + get.size) + return -EINVAL; + + get.name[sizeof(get.name) - 1] = '\0'; + + t = xt_find_table_lock(net, AF_INET6, get.name); + if (!IS_ERR(t)) { + struct xt_table_info *private = t->private; + if (get.size == private->size) + ret = copy_entries_to_user(private->size, + t, uptr->entrytable); + else + ret = -EAGAIN; + + module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + return ret; +} + +static int +__do_replace(struct net *net, const char *name, unsigned int valid_hooks, + struct xt_table_info *newinfo, unsigned int num_counters, + void __user *counters_ptr) +{ + int ret; + struct xt_table *t; + struct xt_table_info *oldinfo; + struct xt_counters *counters; + struct ip6t_entry *iter; + + counters = xt_counters_alloc(num_counters); + if (!counters) { + ret = -ENOMEM; + goto out; + } + + t = xt_request_find_table_lock(net, AF_INET6, name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free_newinfo_counters_untrans; + } + + /* You lied! 
*/ + if (valid_hooks != t->valid_hooks) { + ret = -EINVAL; + goto put_module; + } + + oldinfo = xt_replace_table(t, num_counters, newinfo, &ret); + if (!oldinfo) + goto put_module; + + /* Update module usage count based on number of rules */ + if ((oldinfo->number > oldinfo->initial_entries) || + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + if ((oldinfo->number > oldinfo->initial_entries) && + (newinfo->number <= oldinfo->initial_entries)) + module_put(t->me); + + xt_table_unlock(t); + + get_old_counters(oldinfo, counters); + + /* Decrease module usage counts and free resource */ + xt_entry_foreach(iter, oldinfo->entries, oldinfo->size) + cleanup_entry(iter, net); + + xt_free_table_info(oldinfo); + if (copy_to_user(counters_ptr, counters, + sizeof(struct xt_counters) * num_counters) != 0) { + /* Silent error, can't fail, new table is already in place */ + net_warn_ratelimited("ip6tables: counters copy to user failed while replacing table\n"); + } + vfree(counters); + return 0; + + put_module: + module_put(t->me); + xt_table_unlock(t); + free_newinfo_counters_untrans: + vfree(counters); + out: + return ret; +} + +static int +do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct ip6t_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct ip6t_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_table(net, newinfo, loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, tmp.counters); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +static int +do_add_counters(struct net *net, sockptr_t arg, unsigned int len) +{ + unsigned int i; + struct xt_counters_info tmp; + struct xt_counters *paddc; + struct xt_table *t; + const struct xt_table_info *private; + int ret = 0; + struct ip6t_entry *iter; + unsigned int addend; + + paddc = xt_copy_counters(arg, len, &tmp); + if (IS_ERR(paddc)) + return PTR_ERR(paddc); + t = xt_find_table_lock(net, AF_INET6, tmp.name); + if (IS_ERR(t)) { + ret = PTR_ERR(t); + goto free; + } + + local_bh_disable(); + private = t->private; + if (private->number != tmp.num_counters) { + ret = -EINVAL; + goto unlock_up_free; + } + + i = 0; + addend = xt_write_recseq_begin(); + xt_entry_foreach(iter, private->entries, private->size) { + struct xt_counters *tmp; + + tmp = xt_get_this_cpu_counter(&iter->counters); + ADD_COUNTER(*tmp, paddc[i].bcnt, paddc[i].pcnt); + ++i; + } + xt_write_recseq_end(addend); + unlock_up_free: + local_bh_enable(); + xt_table_unlock(t); + module_put(t->me); + free: + vfree(paddc); + + return ret; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_ip6t_replace { + char name[XT_TABLE_MAXNAMELEN]; + u32 valid_hooks; + u32 num_entries; + u32 size; + u32 hook_entry[NF_INET_NUMHOOKS]; + u32 underflow[NF_INET_NUMHOOKS]; + u32 
num_counters; + compat_uptr_t counters; /* struct xt_counters * */ + struct compat_ip6t_entry entries[]; +}; + +static int +compat_copy_entry_to_user(struct ip6t_entry *e, void __user **dstptr, + unsigned int *size, struct xt_counters *counters, + unsigned int i) +{ + struct xt_entry_target *t; + struct compat_ip6t_entry __user *ce; + u_int16_t target_offset, next_offset; + compat_uint_t origsize; + const struct xt_entry_match *ematch; + int ret = 0; + + origsize = *size; + ce = *dstptr; + if (copy_to_user(ce, e, sizeof(struct ip6t_entry)) != 0 || + copy_to_user(&ce->counters, &counters[i], + sizeof(counters[i])) != 0) + return -EFAULT; + + *dstptr += sizeof(struct compat_ip6t_entry); + *size -= sizeof(struct ip6t_entry) - sizeof(struct compat_ip6t_entry); + + xt_ematch_foreach(ematch, e) { + ret = xt_compat_match_to_user(ematch, dstptr, size); + if (ret != 0) + return ret; + } + target_offset = e->target_offset - (origsize - *size); + t = ip6t_get_target(e); + ret = xt_compat_target_to_user(t, dstptr, size); + if (ret) + return ret; + next_offset = e->next_offset - (origsize - *size); + if (put_user(target_offset, &ce->target_offset) != 0 || + put_user(next_offset, &ce->next_offset) != 0) + return -EFAULT; + return 0; +} + +static int +compat_find_calc_match(struct xt_entry_match *m, + const struct ip6t_ip6 *ipv6, + int *size) +{ + struct xt_match *match; + + match = xt_request_find_match(NFPROTO_IPV6, m->u.user.name, + m->u.user.revision); + if (IS_ERR(match)) + return PTR_ERR(match); + + m->u.kernel.match = match; + *size += xt_compat_match_offset(match); + return 0; +} + +static void compat_release_entry(struct compat_ip6t_entry *e) +{ + struct xt_entry_target *t; + struct xt_entry_match *ematch; + + /* Cleanup all matches */ + xt_ematch_foreach(ematch, e) + module_put(ematch->u.kernel.match->me); + t = compat_ip6t_get_target(e); + module_put(t->u.kernel.target->me); +} + +static int +check_compat_entry_size_and_hooks(struct compat_ip6t_entry *e, + struct xt_table_info *newinfo, + unsigned int *size, + const unsigned char *base, + const unsigned char *limit) +{ + struct xt_entry_match *ematch; + struct xt_entry_target *t; + struct xt_target *target; + unsigned int entry_offset; + unsigned int j; + int ret, off; + + if ((unsigned long)e % __alignof__(struct compat_ip6t_entry) != 0 || + (unsigned char *)e + sizeof(struct compat_ip6t_entry) >= limit || + (unsigned char *)e + e->next_offset > limit) + return -EINVAL; + + if (e->next_offset < sizeof(struct compat_ip6t_entry) + + sizeof(struct compat_xt_entry_target)) + return -EINVAL; + + if (!ip6_checkentry(&e->ipv6)) + return -EINVAL; + + ret = xt_compat_check_entry_offsets(e, e->elems, + e->target_offset, e->next_offset); + if (ret) + return ret; + + off = sizeof(struct ip6t_entry) - sizeof(struct compat_ip6t_entry); + entry_offset = (void *)e - (void *)base; + j = 0; + xt_ematch_foreach(ematch, e) { + ret = compat_find_calc_match(ematch, &e->ipv6, &off); + if (ret != 0) + goto release_matches; + ++j; + } + + t = compat_ip6t_get_target(e); + target = xt_request_find_target(NFPROTO_IPV6, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) { + ret = PTR_ERR(target); + goto release_matches; + } + t->u.kernel.target = target; + + off += xt_compat_target_offset(target); + *size += off; + ret = xt_compat_add_offset(AF_INET6, entry_offset, off); + if (ret) + goto out; + + return 0; + +out: + module_put(t->u.kernel.target->me); +release_matches: + xt_ematch_foreach(ematch, e) { + if (j-- == 0) + break; + 
module_put(ematch->u.kernel.match->me); + } + return ret; +} + +static void +compat_copy_entry_from_user(struct compat_ip6t_entry *e, void **dstptr, + unsigned int *size, + struct xt_table_info *newinfo, unsigned char *base) +{ + struct xt_entry_target *t; + struct ip6t_entry *de; + unsigned int origsize; + int h; + struct xt_entry_match *ematch; + + origsize = *size; + de = *dstptr; + memcpy(de, e, sizeof(struct ip6t_entry)); + memcpy(&de->counters, &e->counters, sizeof(e->counters)); + + *dstptr += sizeof(struct ip6t_entry); + *size += sizeof(struct ip6t_entry) - sizeof(struct compat_ip6t_entry); + + xt_ematch_foreach(ematch, e) + xt_compat_match_from_user(ematch, dstptr, size); + + de->target_offset = e->target_offset - (origsize - *size); + t = compat_ip6t_get_target(e); + xt_compat_target_from_user(t, dstptr, size); + + de->next_offset = e->next_offset - (origsize - *size); + for (h = 0; h < NF_INET_NUMHOOKS; h++) { + if ((unsigned char *)de - base < newinfo->hook_entry[h]) + newinfo->hook_entry[h] -= origsize - *size; + if ((unsigned char *)de - base < newinfo->underflow[h]) + newinfo->underflow[h] -= origsize - *size; + } +} + +static int +translate_compat_table(struct net *net, + struct xt_table_info **pinfo, + void **pentry0, + const struct compat_ip6t_replace *compatr) +{ + unsigned int i, j; + struct xt_table_info *newinfo, *info; + void *pos, *entry0, *entry1; + struct compat_ip6t_entry *iter0; + struct ip6t_replace repl; + unsigned int size; + int ret; + + info = *pinfo; + entry0 = *pentry0; + size = compatr->size; + info->number = compatr->num_entries; + + j = 0; + xt_compat_lock(AF_INET6); + ret = xt_compat_init_offsets(AF_INET6, compatr->num_entries); + if (ret) + goto out_unlock; + /* Walk through entries, checking offsets. */ + xt_entry_foreach(iter0, entry0, compatr->size) { + ret = check_compat_entry_size_and_hooks(iter0, info, &size, + entry0, + entry0 + compatr->size); + if (ret != 0) + goto out_unlock; + ++j; + } + + ret = -EINVAL; + if (j != compatr->num_entries) + goto out_unlock; + + ret = -ENOMEM; + newinfo = xt_alloc_table_info(size); + if (!newinfo) + goto out_unlock; + + memset(newinfo->entries, 0, size); + + newinfo->number = compatr->num_entries; + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + newinfo->hook_entry[i] = compatr->hook_entry[i]; + newinfo->underflow[i] = compatr->underflow[i]; + } + entry1 = newinfo->entries; + pos = entry1; + size = compatr->size; + xt_entry_foreach(iter0, entry0, compatr->size) + compat_copy_entry_from_user(iter0, &pos, &size, + newinfo, entry1); + + /* all module references in entry0 are now gone. 
*/ + xt_compat_flush_offsets(AF_INET6); + xt_compat_unlock(AF_INET6); + + memcpy(&repl, compatr, sizeof(*compatr)); + + for (i = 0; i < NF_INET_NUMHOOKS; i++) { + repl.hook_entry[i] = newinfo->hook_entry[i]; + repl.underflow[i] = newinfo->underflow[i]; + } + + repl.num_counters = 0; + repl.counters = NULL; + repl.size = newinfo->size; + ret = translate_table(net, newinfo, entry1, &repl); + if (ret) + goto free_newinfo; + + *pinfo = newinfo; + *pentry0 = entry1; + xt_free_table_info(info); + return 0; + +free_newinfo: + xt_free_table_info(newinfo); + return ret; +out_unlock: + xt_compat_flush_offsets(AF_INET6); + xt_compat_unlock(AF_INET6); + xt_entry_foreach(iter0, entry0, compatr->size) { + if (j-- == 0) + break; + compat_release_entry(iter0); + } + return ret; +} + +static int +compat_do_replace(struct net *net, sockptr_t arg, unsigned int len) +{ + int ret; + struct compat_ip6t_replace tmp; + struct xt_table_info *newinfo; + void *loc_cpu_entry; + struct ip6t_entry *iter; + + if (copy_from_sockptr(&tmp, arg, sizeof(tmp)) != 0) + return -EFAULT; + + /* overflow check */ + if (tmp.num_counters >= INT_MAX / sizeof(struct xt_counters)) + return -ENOMEM; + if (tmp.num_counters == 0) + return -EINVAL; + + tmp.name[sizeof(tmp.name)-1] = 0; + + newinfo = xt_alloc_table_info(tmp.size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + if (copy_from_sockptr_offset(loc_cpu_entry, arg, sizeof(tmp), + tmp.size) != 0) { + ret = -EFAULT; + goto free_newinfo; + } + + ret = translate_compat_table(net, &newinfo, &loc_cpu_entry, &tmp); + if (ret != 0) + goto free_newinfo; + + ret = __do_replace(net, tmp.name, tmp.valid_hooks, newinfo, + tmp.num_counters, compat_ptr(tmp.counters)); + if (ret) + goto free_newinfo_untrans; + return 0; + + free_newinfo_untrans: + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + free_newinfo: + xt_free_table_info(newinfo); + return ret; +} + +struct compat_ip6t_get_entries { + char name[XT_TABLE_MAXNAMELEN]; + compat_uint_t size; + struct compat_ip6t_entry entrytable[]; +}; + +static int +compat_copy_entries_to_user(unsigned int total_size, struct xt_table *table, + void __user *userptr) +{ + struct xt_counters *counters; + const struct xt_table_info *private = table->private; + void __user *pos; + unsigned int size; + int ret = 0; + unsigned int i = 0; + struct ip6t_entry *iter; + + counters = alloc_counters(table); + if (IS_ERR(counters)) + return PTR_ERR(counters); + + pos = userptr; + size = total_size; + xt_entry_foreach(iter, private->entries, total_size) { + ret = compat_copy_entry_to_user(iter, &pos, + &size, counters, i++); + if (ret != 0) + break; + } + + vfree(counters); + return ret; +} + +static int +compat_get_entries(struct net *net, struct compat_ip6t_get_entries __user *uptr, + int *len) +{ + int ret; + struct compat_ip6t_get_entries get; + struct xt_table *t; + + if (*len < sizeof(get)) + return -EINVAL; + + if (copy_from_user(&get, uptr, sizeof(get)) != 0) + return -EFAULT; + + if (*len != sizeof(struct compat_ip6t_get_entries) + get.size) + return -EINVAL; + + get.name[sizeof(get.name) - 1] = '\0'; + + xt_compat_lock(AF_INET6); + t = xt_find_table_lock(net, AF_INET6, get.name); + if (!IS_ERR(t)) { + const struct xt_table_info *private = t->private; + struct xt_table_info info; + ret = compat_table_info(private, &info); + if (!ret && get.size == info.size) + ret = compat_copy_entries_to_user(private->size, + t, uptr->entrytable); + else if (!ret) + ret = -EAGAIN; + + xt_compat_flush_offsets(AF_INET6); 
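The compat paths above shuttle rules between the 64-bit kernel layout and the 32-bit userspace layout; because every entry shrinks by some per-entry delta, each stored offset has to be shifted by the accumulated difference, which is what xt_compat_add_offset()/xt_compat_calc_jump() book-keep. A toy calculation with made-up sizes (not the real structure sizes):

/* Illustrative sketch: cumulative offset shift between native and compat. */
#include <stdio.h>

int main(void)
{
	/* native sizes of three rules and their compat deltas (hypothetical) */
	unsigned int native_size[3] = { 176, 208, 176 };
	unsigned int delta[3]       = {   4,  12,   4 };
	unsigned int native_off = 0, compat_off = 0;

	for (int i = 0; i < 3; i++) {
		printf("rule %d: native offset %4u -> compat offset %4u\n",
		       i, native_off, compat_off);
		native_off += native_size[i];
		compat_off += native_size[i] - delta[i];
	}
	printf("end of blob: native %u -> compat %u\n", native_off, compat_off);
	return 0;
}

Jump targets and hook entry points are translated the same way: an offset that lies behind k shrunken entries moves down by the sum of those k deltas.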
+ module_put(t->me); + xt_table_unlock(t); + } else + ret = PTR_ERR(t); + + xt_compat_unlock(AF_INET6); + return ret; +} +#endif + +static int +do_ip6t_set_ctl(struct sock *sk, int cmd, sockptr_t arg, unsigned int len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case IP6T_SO_SET_REPLACE: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_do_replace(sock_net(sk), arg, len); + else +#endif + ret = do_replace(sock_net(sk), arg, len); + break; + + case IP6T_SO_SET_ADD_COUNTERS: + ret = do_add_counters(sock_net(sk), arg, len); + break; + + default: + ret = -EINVAL; + } + + return ret; +} + +static int +do_ip6t_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) +{ + int ret; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + switch (cmd) { + case IP6T_SO_GET_INFO: + ret = get_info(sock_net(sk), user, len); + break; + + case IP6T_SO_GET_ENTRIES: +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) + ret = compat_get_entries(sock_net(sk), user, len); + else +#endif + ret = get_entries(sock_net(sk), user, len); + break; + + case IP6T_SO_GET_REVISION_MATCH: + case IP6T_SO_GET_REVISION_TARGET: { + struct xt_get_revision rev; + int target; + + if (*len != sizeof(rev)) { + ret = -EINVAL; + break; + } + if (copy_from_user(&rev, user, sizeof(rev)) != 0) { + ret = -EFAULT; + break; + } + rev.name[sizeof(rev.name)-1] = 0; + + if (cmd == IP6T_SO_GET_REVISION_TARGET) + target = 1; + else + target = 0; + + try_then_request_module(xt_find_revision(AF_INET6, rev.name, + rev.revision, + target, &ret), + "ip6t_%s", rev.name); + break; + } + + default: + ret = -EINVAL; + } + + return ret; +} + +static void __ip6t_unregister_table(struct net *net, struct xt_table *table) +{ + struct xt_table_info *private; + void *loc_cpu_entry; + struct module *table_owner = table->me; + struct ip6t_entry *iter; + + private = xt_unregister_table(table); + + /* Decrease module usage counts and free resources */ + loc_cpu_entry = private->entries; + xt_entry_foreach(iter, loc_cpu_entry, private->size) + cleanup_entry(iter, net); + if (private->number > private->initial_entries) + module_put(table_owner); + xt_free_table_info(private); +} + +int ip6t_register_table(struct net *net, const struct xt_table *table, + const struct ip6t_replace *repl, + const struct nf_hook_ops *template_ops) +{ + struct nf_hook_ops *ops; + unsigned int num_ops; + int ret, i; + struct xt_table_info *newinfo; + struct xt_table_info bootstrap = {0}; + void *loc_cpu_entry; + struct xt_table *new_table; + + newinfo = xt_alloc_table_info(repl->size); + if (!newinfo) + return -ENOMEM; + + loc_cpu_entry = newinfo->entries; + memcpy(loc_cpu_entry, repl->entries, repl->size); + + ret = translate_table(net, newinfo, loc_cpu_entry, repl); + if (ret != 0) { + xt_free_table_info(newinfo); + return ret; + } + + new_table = xt_register_table(net, table, &bootstrap, newinfo); + if (IS_ERR(new_table)) { + struct ip6t_entry *iter; + + xt_entry_foreach(iter, loc_cpu_entry, newinfo->size) + cleanup_entry(iter, net); + xt_free_table_info(newinfo); + return PTR_ERR(new_table); + } + + if (!template_ops) + return 0; + + num_ops = hweight32(table->valid_hooks); + if (num_ops == 0) { + ret = -EINVAL; + goto out_free; + } + + ops = kmemdup(template_ops, sizeof(*ops) * num_ops, GFP_KERNEL); + if (!ops) { + ret = -ENOMEM; + goto out_free; + } + + for (i = 0; i < num_ops; i++) + ops[i].priv = new_table; + + new_table->ops = ops; + + 
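Just below, one nf_hook_ops is registered per bit set in the table's valid_hooks mask (num_ops = hweight32(...)), each copy pointing its priv at the new table. A toy version of that clone-and-bind step, using a hypothetical mask and GCC's __builtin_popcount() in place of hweight32():

/* Illustrative sketch, not from this patch: one ops slot per valid hook bit. */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

struct toy_hook_ops {			/* stand-in for struct nf_hook_ops */
	unsigned int hooknum;
	void *priv;
};

int main(void)
{
	/* hypothetical mask: LOCAL_IN (1), FORWARD (2), LOCAL_OUT (3) */
	unsigned int valid_hooks = (1u << 1) | (1u << 2) | (1u << 3);
	unsigned int num_ops = (unsigned int)__builtin_popcount(valid_hooks);
	struct toy_hook_ops tmpl[3] = { { 1, NULL }, { 2, NULL }, { 3, NULL } };
	struct toy_hook_ops *ops;
	void *table = (void *)0x1;	/* pretend table pointer */

	/* mirror the kmemdup() of the template, then bind each copy to the table */
	ops = malloc(sizeof(*ops) * num_ops);
	if (!ops)
		return 1;
	memcpy(ops, tmpl, sizeof(*ops) * num_ops);
	for (unsigned int i = 0; i < num_ops; i++)
		ops[i].priv = table;

	for (unsigned int i = 0; i < num_ops; i++)
		printf("ops[%u]: hook %u priv %p\n", i, ops[i].hooknum, ops[i].priv);
	free(ops);
	return 0;
}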
ret = nf_register_net_hooks(net, ops, num_ops); + if (ret != 0) + goto out_free; + + return ret; + +out_free: + __ip6t_unregister_table(net, new_table); + return ret; +} + +void ip6t_unregister_table_pre_exit(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_IPV6, name); + + if (table) + nf_unregister_net_hooks(net, table->ops, hweight32(table->valid_hooks)); +} + +void ip6t_unregister_table_exit(struct net *net, const char *name) +{ + struct xt_table *table = xt_find_table(net, NFPROTO_IPV6, name); + + if (table) + __ip6t_unregister_table(net, table); +} + +/* Returns 1 if the type and code is matched by the range, 0 otherwise */ +static inline bool +icmp6_type_code_match(u_int8_t test_type, u_int8_t min_code, u_int8_t max_code, + u_int8_t type, u_int8_t code, + bool invert) +{ + return (type == test_type && code >= min_code && code <= max_code) + ^ invert; +} + +static bool +icmp6_match(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct icmp6hdr *ic; + struct icmp6hdr _icmph; + const struct ip6t_icmp *icmpinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + ic = skb_header_pointer(skb, par->thoff, sizeof(_icmph), &_icmph); + if (ic == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + par->hotdrop = true; + return false; + } + + return icmp6_type_code_match(icmpinfo->type, + icmpinfo->code[0], + icmpinfo->code[1], + ic->icmp6_type, ic->icmp6_code, + !!(icmpinfo->invflags&IP6T_ICMP_INV)); +} + +/* Called when user tries to insert an entry of this type. */ +static int icmp6_checkentry(const struct xt_mtchk_param *par) +{ + const struct ip6t_icmp *icmpinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + return (icmpinfo->invflags & ~IP6T_ICMP_INV) ? -EINVAL : 0; +} + +/* The built-in targets: standard (NULL) and error. 
*/ +static struct xt_target ip6t_builtin_tg[] __read_mostly = { + { + .name = XT_STANDARD_TARGET, + .targetsize = sizeof(int), + .family = NFPROTO_IPV6, +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(compat_int_t), + .compat_from_user = compat_standard_from_user, + .compat_to_user = compat_standard_to_user, +#endif + }, + { + .name = XT_ERROR_TARGET, + .target = ip6t_error, + .targetsize = XT_FUNCTION_MAXNAMELEN, + .family = NFPROTO_IPV6, + }, +}; + +static struct nf_sockopt_ops ip6t_sockopts = { + .pf = PF_INET6, + .set_optmin = IP6T_BASE_CTL, + .set_optmax = IP6T_SO_SET_MAX+1, + .set = do_ip6t_set_ctl, + .get_optmin = IP6T_BASE_CTL, + .get_optmax = IP6T_SO_GET_MAX+1, + .get = do_ip6t_get_ctl, + .owner = THIS_MODULE, +}; + +static struct xt_match ip6t_builtin_mt[] __read_mostly = { + { + .name = "icmp6", + .match = icmp6_match, + .matchsize = sizeof(struct ip6t_icmp), + .checkentry = icmp6_checkentry, + .proto = IPPROTO_ICMPV6, + .family = NFPROTO_IPV6, + .me = THIS_MODULE, + }, +}; + +static int __net_init ip6_tables_net_init(struct net *net) +{ + return xt_proto_init(net, NFPROTO_IPV6); +} + +static void __net_exit ip6_tables_net_exit(struct net *net) +{ + xt_proto_fini(net, NFPROTO_IPV6); +} + +static struct pernet_operations ip6_tables_net_ops = { + .init = ip6_tables_net_init, + .exit = ip6_tables_net_exit, +}; + +static int __init ip6_tables_init(void) +{ + int ret; + + ret = register_pernet_subsys(&ip6_tables_net_ops); + if (ret < 0) + goto err1; + + /* No one else will be downing sem now, so we won't sleep */ + ret = xt_register_targets(ip6t_builtin_tg, ARRAY_SIZE(ip6t_builtin_tg)); + if (ret < 0) + goto err2; + ret = xt_register_matches(ip6t_builtin_mt, ARRAY_SIZE(ip6t_builtin_mt)); + if (ret < 0) + goto err4; + + /* Register setsockopt */ + ret = nf_register_sockopt(&ip6t_sockopts); + if (ret < 0) + goto err5; + + return 0; + +err5: + xt_unregister_matches(ip6t_builtin_mt, ARRAY_SIZE(ip6t_builtin_mt)); +err4: + xt_unregister_targets(ip6t_builtin_tg, ARRAY_SIZE(ip6t_builtin_tg)); +err2: + unregister_pernet_subsys(&ip6_tables_net_ops); +err1: + return ret; +} + +static void __exit ip6_tables_fini(void) +{ + nf_unregister_sockopt(&ip6t_sockopts); + + xt_unregister_matches(ip6t_builtin_mt, ARRAY_SIZE(ip6t_builtin_mt)); + xt_unregister_targets(ip6t_builtin_tg, ARRAY_SIZE(ip6t_builtin_tg)); + unregister_pernet_subsys(&ip6_tables_net_ops); +} + +EXPORT_SYMBOL(ip6t_register_table); +EXPORT_SYMBOL(ip6t_unregister_table_pre_exit); +EXPORT_SYMBOL(ip6t_unregister_table_exit); +EXPORT_SYMBOL(ip6t_do_table); + +module_init(ip6_tables_init); +module_exit(ip6_tables_fini); diff --git a/net/ipv6/netfilter/ip6t_NPT.c b/net/ipv6/netfilter/ip6t_NPT.c new file mode 100644 index 000000000..787c74aa8 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_NPT.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2011, 2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +static int ip6t_npt_checkentry(const struct xt_tgchk_param *par) +{ + struct ip6t_npt_tginfo *npt = par->targinfo; + struct in6_addr pfx; + __wsum src_sum, dst_sum; + + if (npt->src_pfx_len > 64 || npt->dst_pfx_len > 64) + return -EINVAL; + + /* Ensure that LSB of prefix is zero */ + ipv6_addr_prefix(&pfx, &npt->src_pfx.in6, npt->src_pfx_len); + if (!ipv6_addr_equal(&pfx, &npt->src_pfx.in6)) + return -EINVAL; + ipv6_addr_prefix(&pfx, &npt->dst_pfx.in6, npt->dst_pfx_len); + if (!ipv6_addr_equal(&pfx, &npt->dst_pfx.in6)) + return -EINVAL; + + 
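The checkentry routine just below derives a 16-bit checksum-neutral adjustment from the two prefixes, and ip6t_npt_map_pfx() later folds it into one word of the rewritten address so that transport checksums survive the translation (RFC 6296). A standalone sketch of the arithmetic, simplified to a single changed 16-bit word, host byte order, and hand-rolled folding instead of csum_partial()/csum_fold():

/* Illustrative sketch, not from this patch: checksum-neutral prefix rewrite. */
#include <stdint.h>
#include <stdio.h>

static uint16_t ones_sum16(const uint16_t *w, int n)
{
	uint32_t sum = 0;

	while (n--)
		sum += *w++;
	while (sum >> 16)			/* fold carries back in */
		sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)sum;
}

int main(void)
{
	/* eight 16-bit words of an IPv6 address (hypothetical, host order) */
	uint16_t addr[8] = { 0x2001, 0x0db8, 0x0001, 0x0000,
			     0x0000, 0x0000, 0x0000, 0x0001 };
	uint16_t before = ones_sum16(addr, 8);
	uint16_t old_word = addr[2];
	uint32_t adj, fixed;

	/* "translate" the /48 prefix: 2001:db8:1:: -> 2001:db8:6:: */
	addr[2] = 0x0006;

	/* store the ones'-complement difference in word 3 (the spare word) */
	adj = (uint32_t)old_word + (uint16_t)~addr[2];
	while (adj >> 16)
		adj = (adj & 0xffff) + (adj >> 16);
	fixed = (uint32_t)addr[3] + adj;
	while (fixed >> 16)
		fixed = (fixed & 0xffff) + (fixed >> 16);
	addr[3] = (uint16_t)fixed;

	printf("sum before %04x, after %04x\n", before, ones_sum16(addr, 8));
	return 0;
}

The kernel additionally maps a resulting 0xffff (CSUM_MANGLED_0) to zero, as seen in ip6t_npt_map_pfx() above.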
src_sum = csum_partial(&npt->src_pfx.in6, sizeof(npt->src_pfx.in6), 0); + dst_sum = csum_partial(&npt->dst_pfx.in6, sizeof(npt->dst_pfx.in6), 0); + + npt->adjustment = ~csum_fold(csum_sub(src_sum, dst_sum)); + return 0; +} + +static bool ip6t_npt_map_pfx(const struct ip6t_npt_tginfo *npt, + struct in6_addr *addr) +{ + unsigned int pfx_len; + unsigned int i, idx; + __be32 mask; + __sum16 sum; + + pfx_len = max(npt->src_pfx_len, npt->dst_pfx_len); + for (i = 0; i < pfx_len; i += 32) { + if (pfx_len - i >= 32) + mask = 0; + else + mask = htonl((1 << (i - pfx_len + 32)) - 1); + + idx = i / 32; + addr->s6_addr32[idx] &= mask; + addr->s6_addr32[idx] |= ~mask & npt->dst_pfx.in6.s6_addr32[idx]; + } + + if (pfx_len <= 48) + idx = 3; + else { + for (idx = 4; idx < ARRAY_SIZE(addr->s6_addr16); idx++) { + if ((__force __sum16)addr->s6_addr16[idx] != + CSUM_MANGLED_0) + break; + } + if (idx == ARRAY_SIZE(addr->s6_addr16)) + return false; + } + + sum = ~csum_fold(csum_add(csum_unfold((__force __sum16)addr->s6_addr16[idx]), + csum_unfold(npt->adjustment))); + if (sum == CSUM_MANGLED_0) + sum = 0; + *(__force __sum16 *)&addr->s6_addr16[idx] = sum; + + return true; +} + +static struct ipv6hdr *icmpv6_bounced_ipv6hdr(struct sk_buff *skb, + struct ipv6hdr *_bounced_hdr) +{ + if (ipv6_hdr(skb)->nexthdr != IPPROTO_ICMPV6) + return NULL; + + if (!icmpv6_is_err(icmp6_hdr(skb)->icmp6_type)) + return NULL; + + return skb_header_pointer(skb, + skb_transport_offset(skb) + sizeof(struct icmp6hdr), + sizeof(struct ipv6hdr), + _bounced_hdr); +} + +static unsigned int +ip6t_snpt_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ip6t_npt_tginfo *npt = par->targinfo; + struct ipv6hdr _bounced_hdr; + struct ipv6hdr *bounced_hdr; + struct in6_addr bounced_pfx; + + if (!ip6t_npt_map_pfx(npt, &ipv6_hdr(skb)->saddr)) { + icmpv6_send(skb, ICMPV6_PARAMPROB, ICMPV6_HDR_FIELD, + offsetof(struct ipv6hdr, saddr)); + return NF_DROP; + } + + /* rewrite dst addr of bounced packet which was sent to dst range */ + bounced_hdr = icmpv6_bounced_ipv6hdr(skb, &_bounced_hdr); + if (bounced_hdr) { + ipv6_addr_prefix(&bounced_pfx, &bounced_hdr->daddr, npt->src_pfx_len); + if (ipv6_addr_cmp(&bounced_pfx, &npt->src_pfx.in6) == 0) + ip6t_npt_map_pfx(npt, &bounced_hdr->daddr); + } + + return XT_CONTINUE; +} + +static unsigned int +ip6t_dnpt_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ip6t_npt_tginfo *npt = par->targinfo; + struct ipv6hdr _bounced_hdr; + struct ipv6hdr *bounced_hdr; + struct in6_addr bounced_pfx; + + if (!ip6t_npt_map_pfx(npt, &ipv6_hdr(skb)->daddr)) { + icmpv6_send(skb, ICMPV6_PARAMPROB, ICMPV6_HDR_FIELD, + offsetof(struct ipv6hdr, daddr)); + return NF_DROP; + } + + /* rewrite src addr of bounced packet which was sent from dst range */ + bounced_hdr = icmpv6_bounced_ipv6hdr(skb, &_bounced_hdr); + if (bounced_hdr) { + ipv6_addr_prefix(&bounced_pfx, &bounced_hdr->saddr, npt->src_pfx_len); + if (ipv6_addr_cmp(&bounced_pfx, &npt->src_pfx.in6) == 0) + ip6t_npt_map_pfx(npt, &bounced_hdr->saddr); + } + + return XT_CONTINUE; +} + +static struct xt_target ip6t_npt_target_reg[] __read_mostly = { + { + .name = "SNPT", + .table = "mangle", + .target = ip6t_snpt_tg, + .targetsize = sizeof(struct ip6t_npt_tginfo), + .usersize = offsetof(struct ip6t_npt_tginfo, adjustment), + .checkentry = ip6t_npt_checkentry, + .family = NFPROTO_IPV6, + .hooks = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_POST_ROUTING), + .me = THIS_MODULE, + }, + { + .name = "DNPT", + .table = "mangle", + .target = 
ip6t_dnpt_tg, + .targetsize = sizeof(struct ip6t_npt_tginfo), + .usersize = offsetof(struct ip6t_npt_tginfo, adjustment), + .checkentry = ip6t_npt_checkentry, + .family = NFPROTO_IPV6, + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, +}; + +static int __init ip6t_npt_init(void) +{ + return xt_register_targets(ip6t_npt_target_reg, + ARRAY_SIZE(ip6t_npt_target_reg)); +} + +static void __exit ip6t_npt_exit(void) +{ + xt_unregister_targets(ip6t_npt_target_reg, + ARRAY_SIZE(ip6t_npt_target_reg)); +} + +module_init(ip6t_npt_init); +module_exit(ip6t_npt_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IPv6-to-IPv6 Network Prefix Translation (RFC 6296)"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS("ip6t_SNPT"); +MODULE_ALIAS("ip6t_DNPT"); diff --git a/net/ipv6/netfilter/ip6t_REJECT.c b/net/ipv6/netfilter/ip6t_REJECT.c new file mode 100644 index 000000000..a35019d2e --- /dev/null +++ b/net/ipv6/netfilter/ip6t_REJECT.c @@ -0,0 +1,121 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IP6 tables REJECT target module + * Linux INET6 implementation + * + * Copyright (C)2003 USAGI/WIDE Project + * + * Authors: + * Yasuyuki Kozakai + * + * Copyright (c) 2005-2007 Patrick McHardy + * + * Based on net/ipv4/netfilter/ipt_REJECT.c + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +MODULE_AUTHOR("Yasuyuki KOZAKAI "); +MODULE_DESCRIPTION("Xtables: packet \"rejection\" target for IPv6"); +MODULE_LICENSE("GPL"); + +static unsigned int +reject_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ip6t_reject_info *reject = par->targinfo; + struct net *net = xt_net(par); + + switch (reject->with) { + case IP6T_ICMP6_NO_ROUTE: + nf_send_unreach6(net, skb, ICMPV6_NOROUTE, xt_hooknum(par)); + break; + case IP6T_ICMP6_ADM_PROHIBITED: + nf_send_unreach6(net, skb, ICMPV6_ADM_PROHIBITED, + xt_hooknum(par)); + break; + case IP6T_ICMP6_NOT_NEIGHBOUR: + nf_send_unreach6(net, skb, ICMPV6_NOT_NEIGHBOUR, + xt_hooknum(par)); + break; + case IP6T_ICMP6_ADDR_UNREACH: + nf_send_unreach6(net, skb, ICMPV6_ADDR_UNREACH, + xt_hooknum(par)); + break; + case IP6T_ICMP6_PORT_UNREACH: + nf_send_unreach6(net, skb, ICMPV6_PORT_UNREACH, + xt_hooknum(par)); + break; + case IP6T_ICMP6_ECHOREPLY: + /* Do nothing */ + break; + case IP6T_TCP_RESET: + nf_send_reset6(net, par->state->sk, skb, xt_hooknum(par)); + break; + case IP6T_ICMP6_POLICY_FAIL: + nf_send_unreach6(net, skb, ICMPV6_POLICY_FAIL, xt_hooknum(par)); + break; + case IP6T_ICMP6_REJECT_ROUTE: + nf_send_unreach6(net, skb, ICMPV6_REJECT_ROUTE, + xt_hooknum(par)); + break; + } + + return NF_DROP; +} + +static int reject_tg6_check(const struct xt_tgchk_param *par) +{ + const struct ip6t_reject_info *rejinfo = par->targinfo; + const struct ip6t_entry *e = par->entryinfo; + + if (rejinfo->with == IP6T_ICMP6_ECHOREPLY) { + pr_info_ratelimited("ECHOREPLY is not supported\n"); + return -EINVAL; + } else if (rejinfo->with == IP6T_TCP_RESET) { + /* Must specify that it's a TCP packet */ + if (!(e->ipv6.flags & IP6T_F_PROTO) || + e->ipv6.proto != IPPROTO_TCP || + (e->ipv6.invflags & XT_INV_PROTO)) { + pr_info_ratelimited("TCP_RESET illegal for non-tcp\n"); + return -EINVAL; + } + } + return 0; +} + +static struct xt_target reject_tg6_reg __read_mostly = { + .name = "REJECT", + .family = NFPROTO_IPV6, + .target = reject_tg6, + .targetsize = sizeof(struct ip6t_reject_info), + .table = "filter", + 
.hooks = (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT), + .checkentry = reject_tg6_check, + .me = THIS_MODULE +}; + +static int __init reject_tg6_init(void) +{ + return xt_register_target(&reject_tg6_reg); +} + +static void __exit reject_tg6_exit(void) +{ + xt_unregister_target(&reject_tg6_reg); +} + +module_init(reject_tg6_init); +module_exit(reject_tg6_exit); diff --git a/net/ipv6/netfilter/ip6t_SYNPROXY.c b/net/ipv6/netfilter/ip6t_SYNPROXY.c new file mode 100644 index 000000000..d51d0c3e5 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_SYNPROXY.c @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2013 Patrick McHardy + */ + +#include +#include +#include + +#include + +static unsigned int +synproxy_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_synproxy_info *info = par->targinfo; + struct net *net = xt_net(par); + struct synproxy_net *snet = synproxy_pernet(net); + struct synproxy_options opts = {}; + struct tcphdr *th, _th; + + if (nf_ip6_checksum(skb, xt_hooknum(par), par->thoff, IPPROTO_TCP)) + return NF_DROP; + + th = skb_header_pointer(skb, par->thoff, sizeof(_th), &_th); + if (th == NULL) + return NF_DROP; + + if (!synproxy_parse_options(skb, par->thoff, th, &opts)) + return NF_DROP; + + if (th->syn && !(th->ack || th->fin || th->rst)) { + /* Initial SYN from client */ + this_cpu_inc(snet->stats->syn_received); + + if (th->ece && th->cwr) + opts.options |= XT_SYNPROXY_OPT_ECN; + + opts.options &= info->options; + opts.mss_encode = opts.mss_option; + opts.mss_option = info->mss; + if (opts.options & XT_SYNPROXY_OPT_TIMESTAMP) + synproxy_init_timestamp_cookie(info, &opts); + else + opts.options &= ~(XT_SYNPROXY_OPT_WSCALE | + XT_SYNPROXY_OPT_SACK_PERM | + XT_SYNPROXY_OPT_ECN); + + synproxy_send_client_synack_ipv6(net, skb, th, &opts); + consume_skb(skb); + return NF_STOLEN; + + } else if (th->ack && !(th->fin || th->rst || th->syn)) { + /* ACK from client */ + if (synproxy_recv_client_ack_ipv6(net, skb, th, &opts, + ntohl(th->seq))) { + consume_skb(skb); + return NF_STOLEN; + } else { + return NF_DROP; + } + } + + return XT_CONTINUE; +} + +static int synproxy_tg6_check(const struct xt_tgchk_param *par) +{ + struct synproxy_net *snet = synproxy_pernet(par->net); + const struct ip6t_entry *e = par->entryinfo; + int err; + + if (!(e->ipv6.flags & IP6T_F_PROTO) || + e->ipv6.proto != IPPROTO_TCP || + e->ipv6.invflags & XT_INV_PROTO) + return -EINVAL; + + err = nf_ct_netns_get(par->net, par->family); + if (err) + return err; + + err = nf_synproxy_ipv6_init(snet, par->net); + if (err) { + nf_ct_netns_put(par->net, par->family); + return err; + } + + return err; +} + +static void synproxy_tg6_destroy(const struct xt_tgdtor_param *par) +{ + struct synproxy_net *snet = synproxy_pernet(par->net); + + nf_synproxy_ipv6_fini(snet, par->net); + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_target synproxy_tg6_reg __read_mostly = { + .name = "SYNPROXY", + .family = NFPROTO_IPV6, + .hooks = (1 << NF_INET_LOCAL_IN) | (1 << NF_INET_FORWARD), + .target = synproxy_tg6, + .targetsize = sizeof(struct xt_synproxy_info), + .checkentry = synproxy_tg6_check, + .destroy = synproxy_tg6_destroy, + .me = THIS_MODULE, +}; + +static int __init synproxy_tg6_init(void) +{ + return xt_register_target(&synproxy_tg6_reg); +} + +static void __exit synproxy_tg6_exit(void) +{ + xt_unregister_target(&synproxy_tg6_reg); +} + +module_init(synproxy_tg6_init); +module_exit(synproxy_tg6_exit); + +MODULE_LICENSE("GPL"); 
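For orientation, synproxy_tg6() above only intercepts two flag patterns and hands everything else back with XT_CONTINUE. A minimal sketch of that classification, with plain booleans standing in for the real struct tcphdr bitfields:

/* Illustrative sketch, not from this patch: the two TCP flag patterns used. */
#include <stdbool.h>
#include <stdio.h>

struct tcp_flags {
	bool syn, ack, fin, rst;
};

/* initial SYN from the client: SYN set, everything else clear */
static bool is_initial_syn(struct tcp_flags f)
{
	return f.syn && !(f.ack || f.fin || f.rst);
}

/* bare ACK completing the client side of the handshake */
static bool is_bare_ack(struct tcp_flags f)
{
	return f.ack && !(f.fin || f.rst || f.syn);
}

int main(void)
{
	struct tcp_flags syn    = { .syn = true };
	struct tcp_flags synack = { .syn = true, .ack = true };
	struct tcp_flags ack    = { .ack = true };

	printf("SYN:     initial=%d bare_ack=%d\n", is_initial_syn(syn), is_bare_ack(syn));
	printf("SYN+ACK: initial=%d bare_ack=%d\n", is_initial_syn(synack), is_bare_ack(synack));
	printf("ACK:     initial=%d bare_ack=%d\n", is_initial_syn(ack), is_bare_ack(ack));
	return 0;
}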
+MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Intercept IPv6 TCP connections and establish them using syncookies"); diff --git a/net/ipv6/netfilter/ip6t_ah.c b/net/ipv6/netfilter/ip6t_ah.c new file mode 100644 index 000000000..70da2f2ce --- /dev/null +++ b/net/ipv6/netfilter/ip6t_ah.c @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match AH parameters. */ + +/* (C) 2001-2002 Andras Kis-Szabo + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 IPsec-AH match"); +MODULE_AUTHOR("Andras Kis-Szabo "); + +/* Returns 1 if the spi is matched by the range, 0 otherwise */ +static inline bool +spi_match(u_int32_t min, u_int32_t max, u_int32_t spi, bool invert) +{ + bool r; + + pr_debug("spi_match:%c 0x%x <= 0x%x <= 0x%x\n", + invert ? '!' : ' ', min, spi, max); + r = (spi >= min && spi <= max) ^ invert; + pr_debug(" result %s\n", r ? "PASS" : "FAILED"); + return r; +} + +static bool ah_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ip_auth_hdr _ah; + const struct ip_auth_hdr *ah; + const struct ip6t_ah *ahinfo = par->matchinfo; + unsigned int ptr = 0; + unsigned int hdrlen = 0; + int err; + + err = ipv6_find_hdr(skb, &ptr, NEXTHDR_AUTH, NULL, NULL); + if (err < 0) { + if (err != -ENOENT) + par->hotdrop = true; + return false; + } + + ah = skb_header_pointer(skb, ptr, sizeof(_ah), &_ah); + if (ah == NULL) { + par->hotdrop = true; + return false; + } + + hdrlen = ipv6_authlen(ah); + + pr_debug("IPv6 AH LEN %u %u ", hdrlen, ah->hdrlen); + pr_debug("RES %04X ", ah->reserved); + pr_debug("SPI %u %08X\n", ntohl(ah->spi), ntohl(ah->spi)); + + pr_debug("IPv6 AH spi %02X ", + spi_match(ahinfo->spis[0], ahinfo->spis[1], + ntohl(ah->spi), + !!(ahinfo->invflags & IP6T_AH_INV_SPI))); + pr_debug("len %02X %04X %02X ", + ahinfo->hdrlen, hdrlen, + (!ahinfo->hdrlen || + (ahinfo->hdrlen == hdrlen) ^ + !!(ahinfo->invflags & IP6T_AH_INV_LEN))); + pr_debug("res %02X %04X %02X\n", + ahinfo->hdrres, ah->reserved, + !(ahinfo->hdrres && ah->reserved)); + + return spi_match(ahinfo->spis[0], ahinfo->spis[1], + ntohl(ah->spi), + !!(ahinfo->invflags & IP6T_AH_INV_SPI)) && + (!ahinfo->hdrlen || + (ahinfo->hdrlen == hdrlen) ^ + !!(ahinfo->invflags & IP6T_AH_INV_LEN)) && + !(ahinfo->hdrres && ah->reserved); +} + +static int ah_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_ah *ahinfo = par->matchinfo; + + if (ahinfo->invflags & ~IP6T_AH_INV_MASK) { + pr_debug("unknown flags %X\n", ahinfo->invflags); + return -EINVAL; + } + return 0; +} + +static struct xt_match ah_mt6_reg __read_mostly = { + .name = "ah", + .family = NFPROTO_IPV6, + .match = ah_mt6, + .matchsize = sizeof(struct ip6t_ah), + .checkentry = ah_mt6_check, + .me = THIS_MODULE, +}; + +static int __init ah_mt6_init(void) +{ + return xt_register_match(&ah_mt6_reg); +} + +static void __exit ah_mt6_exit(void) +{ + xt_unregister_match(&ah_mt6_reg); +} + +module_init(ah_mt6_init); +module_exit(ah_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_eui64.c b/net/ipv6/netfilter/ip6t_eui64.c new file mode 100644 index 000000000..d704f7ed3 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_eui64.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match EUI64 address parameters. 
*/ + +/* (C) 2001-2002 Andras Kis-Szabo + */ + +#include +#include +#include +#include + +#include +#include + +MODULE_DESCRIPTION("Xtables: IPv6 EUI64 address match"); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Andras Kis-Szabo "); + +static bool +eui64_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + unsigned char eui64[8]; + + if (!(skb_mac_header(skb) >= skb->head && + skb_mac_header(skb) + ETH_HLEN <= skb->data) && + par->fragoff != 0) { + par->hotdrop = true; + return false; + } + + memset(eui64, 0, sizeof(eui64)); + + if (eth_hdr(skb)->h_proto == htons(ETH_P_IPV6)) { + if (ipv6_hdr(skb)->version == 0x6) { + memcpy(eui64, eth_hdr(skb)->h_source, 3); + memcpy(eui64 + 5, eth_hdr(skb)->h_source + 3, 3); + eui64[3] = 0xff; + eui64[4] = 0xfe; + eui64[0] ^= 0x02; + + if (!memcmp(ipv6_hdr(skb)->saddr.s6_addr + 8, eui64, + sizeof(eui64))) + return true; + } + } + + return false; +} + +static struct xt_match eui64_mt6_reg __read_mostly = { + .name = "eui64", + .family = NFPROTO_IPV6, + .match = eui64_mt6, + .matchsize = sizeof(int), + .hooks = (1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD), + .me = THIS_MODULE, +}; + +static int __init eui64_mt6_init(void) +{ + return xt_register_match(&eui64_mt6_reg); +} + +static void __exit eui64_mt6_exit(void) +{ + xt_unregister_match(&eui64_mt6_reg); +} + +module_init(eui64_mt6_init); +module_exit(eui64_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_frag.c b/net/ipv6/netfilter/ip6t_frag.c new file mode 100644 index 000000000..3aad64393 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_frag.c @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match FRAG parameters. */ + +/* (C) 2001-2002 Andras Kis-Szabo + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 fragment match"); +MODULE_AUTHOR("Andras Kis-Szabo "); + +/* Returns 1 if the id is matched by the range, 0 otherwise */ +static inline bool +id_match(u_int32_t min, u_int32_t max, u_int32_t id, bool invert) +{ + bool r; + pr_debug("id_match:%c 0x%x <= 0x%x <= 0x%x\n", invert ? '!' : ' ', + min, id, max); + r = (id >= min && id <= max) ^ invert; + pr_debug(" result %s\n", r ? 
"PASS" : "FAILED"); + return r; +} + +static bool +frag_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct frag_hdr _frag; + const struct frag_hdr *fh; + const struct ip6t_frag *fraginfo = par->matchinfo; + unsigned int ptr = 0; + int err; + + err = ipv6_find_hdr(skb, &ptr, NEXTHDR_FRAGMENT, NULL, NULL); + if (err < 0) { + if (err != -ENOENT) + par->hotdrop = true; + return false; + } + + fh = skb_header_pointer(skb, ptr, sizeof(_frag), &_frag); + if (fh == NULL) { + par->hotdrop = true; + return false; + } + + pr_debug("INFO %04X ", fh->frag_off); + pr_debug("OFFSET %04X ", ntohs(fh->frag_off) & ~0x7); + pr_debug("RES %02X %04X", fh->reserved, ntohs(fh->frag_off) & 0x6); + pr_debug("MF %04X ", fh->frag_off & htons(IP6_MF)); + pr_debug("ID %u %08X\n", ntohl(fh->identification), + ntohl(fh->identification)); + + pr_debug("IPv6 FRAG id %02X ", + id_match(fraginfo->ids[0], fraginfo->ids[1], + ntohl(fh->identification), + !!(fraginfo->invflags & IP6T_FRAG_INV_IDS))); + pr_debug("res %02X %02X%04X %02X ", + fraginfo->flags & IP6T_FRAG_RES, fh->reserved, + ntohs(fh->frag_off) & 0x6, + !((fraginfo->flags & IP6T_FRAG_RES) && + (fh->reserved || (ntohs(fh->frag_off) & 0x06)))); + pr_debug("first %02X %02X %02X ", + fraginfo->flags & IP6T_FRAG_FST, + ntohs(fh->frag_off) & ~0x7, + !((fraginfo->flags & IP6T_FRAG_FST) && + (ntohs(fh->frag_off) & ~0x7))); + pr_debug("mf %02X %02X %02X ", + fraginfo->flags & IP6T_FRAG_MF, + ntohs(fh->frag_off) & IP6_MF, + !((fraginfo->flags & IP6T_FRAG_MF) && + !((ntohs(fh->frag_off) & IP6_MF)))); + pr_debug("last %02X %02X %02X\n", + fraginfo->flags & IP6T_FRAG_NMF, + ntohs(fh->frag_off) & IP6_MF, + !((fraginfo->flags & IP6T_FRAG_NMF) && + (ntohs(fh->frag_off) & IP6_MF))); + + return id_match(fraginfo->ids[0], fraginfo->ids[1], + ntohl(fh->identification), + !!(fraginfo->invflags & IP6T_FRAG_INV_IDS)) && + !((fraginfo->flags & IP6T_FRAG_RES) && + (fh->reserved || (ntohs(fh->frag_off) & 0x6))) && + !((fraginfo->flags & IP6T_FRAG_FST) && + (ntohs(fh->frag_off) & ~0x7)) && + !((fraginfo->flags & IP6T_FRAG_MF) && + !(ntohs(fh->frag_off) & IP6_MF)) && + !((fraginfo->flags & IP6T_FRAG_NMF) && + (ntohs(fh->frag_off) & IP6_MF)); +} + +static int frag_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_frag *fraginfo = par->matchinfo; + + if (fraginfo->invflags & ~IP6T_FRAG_INV_MASK) { + pr_debug("unknown flags %X\n", fraginfo->invflags); + return -EINVAL; + } + return 0; +} + +static struct xt_match frag_mt6_reg __read_mostly = { + .name = "frag", + .family = NFPROTO_IPV6, + .match = frag_mt6, + .matchsize = sizeof(struct ip6t_frag), + .checkentry = frag_mt6_check, + .me = THIS_MODULE, +}; + +static int __init frag_mt6_init(void) +{ + return xt_register_match(&frag_mt6_reg); +} + +static void __exit frag_mt6_exit(void) +{ + xt_unregister_match(&frag_mt6_reg); +} + +module_init(frag_mt6_init); +module_exit(frag_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_hbh.c b/net/ipv6/netfilter/ip6t_hbh.c new file mode 100644 index 000000000..e7a3fb935 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_hbh.c @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match Hop-by-Hop and Destination parameters. 
*/ + +/* (C) 2001-2002 Andras Kis-Szabo + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 Hop-By-Hop and Destination Header match"); +MODULE_AUTHOR("Andras Kis-Szabo "); +MODULE_ALIAS("ip6t_dst"); + +/* + * (Type & 0xC0) >> 6 + * 0 -> ignorable + * 1 -> must drop the packet + * 2 -> send ICMP PARM PROB regardless and drop packet + * 3 -> Send ICMP if not a multicast address and drop packet + * (Type & 0x20) >> 5 + * 0 -> invariant + * 1 -> can change the routing + * (Type & 0x1F) Type + * 0 -> Pad1 (only 1 byte!) + * 1 -> PadN LENGTH info (total length = length + 2) + * C0 | 2 -> JUMBO 4 x x x x ( xxxx > 64k ) + * 5 -> RTALERT 2 x x + */ + +static struct xt_match hbh_mt6_reg[] __read_mostly; + +static bool +hbh_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ipv6_opt_hdr _optsh; + const struct ipv6_opt_hdr *oh; + const struct ip6t_opts *optinfo = par->matchinfo; + unsigned int temp; + unsigned int ptr = 0; + unsigned int hdrlen = 0; + bool ret = false; + u8 _opttype; + u8 _optlen; + const u_int8_t *tp = NULL; + const u_int8_t *lp = NULL; + unsigned int optlen; + int err; + + err = ipv6_find_hdr(skb, &ptr, + (par->match == &hbh_mt6_reg[0]) ? + NEXTHDR_HOP : NEXTHDR_DEST, NULL, NULL); + if (err < 0) { + if (err != -ENOENT) + par->hotdrop = true; + return false; + } + + oh = skb_header_pointer(skb, ptr, sizeof(_optsh), &_optsh); + if (oh == NULL) { + par->hotdrop = true; + return false; + } + + hdrlen = ipv6_optlen(oh); + if (skb->len - ptr < hdrlen) { + /* Packet smaller than it's length field */ + return false; + } + + pr_debug("IPv6 OPTS LEN %u %u ", hdrlen, oh->hdrlen); + + pr_debug("len %02X %04X %02X ", + optinfo->hdrlen, hdrlen, + (!(optinfo->flags & IP6T_OPTS_LEN) || + ((optinfo->hdrlen == hdrlen) ^ + !!(optinfo->invflags & IP6T_OPTS_INV_LEN)))); + + ret = (!(optinfo->flags & IP6T_OPTS_LEN) || + ((optinfo->hdrlen == hdrlen) ^ + !!(optinfo->invflags & IP6T_OPTS_INV_LEN))); + + ptr += 2; + hdrlen -= 2; + if (!(optinfo->flags & IP6T_OPTS_OPTS)) { + return ret; + } else { + pr_debug("Strict "); + pr_debug("#%d ", optinfo->optsnr); + for (temp = 0; temp < optinfo->optsnr; temp++) { + /* type field exists ? */ + if (hdrlen < 1) + break; + tp = skb_header_pointer(skb, ptr, sizeof(_opttype), + &_opttype); + if (tp == NULL) + break; + + /* Type check */ + if (*tp != (optinfo->opts[temp] & 0xFF00) >> 8) { + pr_debug("Tbad %02X %02X\n", *tp, + (optinfo->opts[temp] & 0xFF00) >> 8); + return false; + } else { + pr_debug("Tok "); + } + /* Length check */ + if (*tp) { + u16 spec_len; + + /* length field exists ? 
*/ + if (hdrlen < 2) + break; + lp = skb_header_pointer(skb, ptr + 1, + sizeof(_optlen), + &_optlen); + if (lp == NULL) + break; + spec_len = optinfo->opts[temp] & 0x00FF; + + if (spec_len != 0x00FF && spec_len != *lp) { + pr_debug("Lbad %02X %04X\n", *lp, + spec_len); + return false; + } + pr_debug("Lok "); + optlen = *lp + 2; + } else { + pr_debug("Pad1\n"); + optlen = 1; + } + + /* Step to the next */ + pr_debug("len%04X\n", optlen); + + if ((ptr > skb->len - optlen || hdrlen < optlen) && + temp < optinfo->optsnr - 1) { + pr_debug("new pointer is too large!\n"); + break; + } + ptr += optlen; + hdrlen -= optlen; + } + if (temp == optinfo->optsnr) + return ret; + else + return false; + } + + return false; +} + +static int hbh_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_opts *optsinfo = par->matchinfo; + + if (optsinfo->invflags & ~IP6T_OPTS_INV_MASK) { + pr_debug("unknown flags %X\n", optsinfo->invflags); + return -EINVAL; + } + + if (optsinfo->flags & IP6T_OPTS_NSTRICT) { + pr_debug("Not strict - not implemented"); + return -EINVAL; + } + + return 0; +} + +static struct xt_match hbh_mt6_reg[] __read_mostly = { + { + /* Note, hbh_mt6 relies on the order of hbh_mt6_reg */ + .name = "hbh", + .family = NFPROTO_IPV6, + .match = hbh_mt6, + .matchsize = sizeof(struct ip6t_opts), + .checkentry = hbh_mt6_check, + .me = THIS_MODULE, + }, + { + .name = "dst", + .family = NFPROTO_IPV6, + .match = hbh_mt6, + .matchsize = sizeof(struct ip6t_opts), + .checkentry = hbh_mt6_check, + .me = THIS_MODULE, + }, +}; + +static int __init hbh_mt6_init(void) +{ + return xt_register_matches(hbh_mt6_reg, ARRAY_SIZE(hbh_mt6_reg)); +} + +static void __exit hbh_mt6_exit(void) +{ + xt_unregister_matches(hbh_mt6_reg, ARRAY_SIZE(hbh_mt6_reg)); +} + +module_init(hbh_mt6_init); +module_exit(hbh_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_ipv6header.c b/net/ipv6/netfilter/ip6t_ipv6header.c new file mode 100644 index 000000000..c52ff929c --- /dev/null +++ b/net/ipv6/netfilter/ip6t_ipv6header.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* ipv6header match - matches IPv6 packets based + on whether they contain certain headers */ + +/* Original idea: Brad Chapman + * Rewritten by: Andras Kis-Szabo */ + +/* (C) 2001-2002 Andras Kis-Szabo + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 header types match"); +MODULE_AUTHOR("Andras Kis-Szabo "); + +static bool +ipv6header_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ip6t_ipv6header_info *info = par->matchinfo; + unsigned int temp; + int len; + u8 nexthdr; + unsigned int ptr; + + /* Make sure this isn't an evil packet */ + + /* type of the 1st exthdr */ + nexthdr = ipv6_hdr(skb)->nexthdr; + /* pointer to the 1st exthdr */ + ptr = sizeof(struct ipv6hdr); + /* available length */ + len = skb->len - ptr; + temp = 0; + + while (nf_ip6_ext_hdr(nexthdr)) { + const struct ipv6_opt_hdr *hp; + struct ipv6_opt_hdr _hdr; + int hdrlen; + + /* No more exthdr -> evaluate */ + if (nexthdr == NEXTHDR_NONE) { + temp |= MASK_NONE; + break; + } + /* Is there enough space for the next ext header? 
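Both the hbh/dst option parser above and the ipv6header walk that continues below advance through the packet by computing each extension header's size from its own length byte, with per-type formulas. A compact sketch of just that sizing rule, assuming the standard RFC 8200 / RFC 4302 encodings (NEXTHDR_* values as in the kernel headers):

/* Illustrative sketch, not from this patch: sizing IPv6 extension headers. */
#include <stdint.h>
#include <stdio.h>

#define NEXTHDR_HOP	  0
#define NEXTHDR_FRAGMENT 44
#define NEXTHDR_AUTH	 51

/* 'hdrlen' is the second byte of the extension header */
static unsigned int ext_hdr_size(uint8_t nexthdr, uint8_t hdrlen)
{
	if (nexthdr == NEXTHDR_FRAGMENT)
		return 8;			/* fragment header is fixed size */
	if (nexthdr == NEXTHDR_AUTH)
		return (hdrlen + 2) * 4;	/* AH length counts 32-bit words minus 2 */
	return (hdrlen + 1) * 8;		/* others count 8-octet units past the first */
}

int main(void)
{
	printf("hop-by-hop, hdrlen=0 -> %u bytes\n", ext_hdr_size(NEXTHDR_HOP, 0));
	printf("fragment            -> %u bytes\n", ext_hdr_size(NEXTHDR_FRAGMENT, 0));
	printf("AH, hdrlen=4        -> %u bytes\n", ext_hdr_size(NEXTHDR_AUTH, 4));
	return 0;
}

These are the same three cases the ipv6header loop below handles via ipv6_authlen()/ipv6_optlen() and the fixed fragment-header size.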
*/ + if (len < (int)sizeof(struct ipv6_opt_hdr)) + return false; + /* ESP -> evaluate */ + if (nexthdr == NEXTHDR_ESP) { + temp |= MASK_ESP; + break; + } + + hp = skb_header_pointer(skb, ptr, sizeof(_hdr), &_hdr); + if (!hp) { + par->hotdrop = true; + return false; + } + + /* Calculate the header length */ + if (nexthdr == NEXTHDR_FRAGMENT) + hdrlen = 8; + else if (nexthdr == NEXTHDR_AUTH) + hdrlen = ipv6_authlen(hp); + else + hdrlen = ipv6_optlen(hp); + + /* set the flag */ + switch (nexthdr) { + case NEXTHDR_HOP: + temp |= MASK_HOPOPTS; + break; + case NEXTHDR_ROUTING: + temp |= MASK_ROUTING; + break; + case NEXTHDR_FRAGMENT: + temp |= MASK_FRAGMENT; + break; + case NEXTHDR_AUTH: + temp |= MASK_AH; + break; + case NEXTHDR_DEST: + temp |= MASK_DSTOPTS; + break; + default: + return false; + } + + nexthdr = hp->nexthdr; + len -= hdrlen; + ptr += hdrlen; + if (ptr > skb->len) + break; + } + + if (nexthdr != NEXTHDR_NONE && nexthdr != NEXTHDR_ESP) + temp |= MASK_PROTO; + + if (info->modeflag) + return !((temp ^ info->matchflags ^ info->invflags) + & info->matchflags); + else { + if (info->invflags) + return temp != info->matchflags; + else + return temp == info->matchflags; + } +} + +static int ipv6header_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_ipv6header_info *info = par->matchinfo; + + /* invflags is 0 or 0xff in hard mode */ + if ((!info->modeflag) && info->invflags != 0x00 && + info->invflags != 0xFF) + return -EINVAL; + + return 0; +} + +static struct xt_match ipv6header_mt6_reg __read_mostly = { + .name = "ipv6header", + .family = NFPROTO_IPV6, + .match = ipv6header_mt6, + .matchsize = sizeof(struct ip6t_ipv6header_info), + .checkentry = ipv6header_mt6_check, + .destroy = NULL, + .me = THIS_MODULE, +}; + +static int __init ipv6header_mt6_init(void) +{ + return xt_register_match(&ipv6header_mt6_reg); +} + +static void __exit ipv6header_mt6_exit(void) +{ + xt_unregister_match(&ipv6header_mt6_reg); +} + +module_init(ipv6header_mt6_init); +module_exit(ipv6header_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_mh.c b/net/ipv6/netfilter/ip6t_mh.c new file mode 100644 index 000000000..fd492b69a --- /dev/null +++ b/net/ipv6/netfilter/ip6t_mh.c @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C)2006 USAGI/WIDE Project + * + * Author: + * Masahide NAKAMURA @USAGI + * + * Based on net/netfilter/xt_tcpudp.c + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include + +MODULE_DESCRIPTION("Xtables: IPv6 Mobility Header match"); +MODULE_LICENSE("GPL"); + +/* Returns 1 if the type is matched by the range, 0 otherwise */ +static inline bool +type_match(u_int8_t min, u_int8_t max, u_int8_t type, bool invert) +{ + return (type >= min && type <= max) ^ invert; +} + +static bool mh_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ip6_mh _mh; + const struct ip6_mh *mh; + const struct ip6t_mh *mhinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + mh = skb_header_pointer(skb, par->thoff, sizeof(_mh), &_mh); + if (mh == NULL) { + /* We've been asked to examine this packet, and we + can't. Hence, no choice but to drop. 
*/ + pr_debug("Dropping evil MH tinygram.\n"); + par->hotdrop = true; + return false; + } + + if (mh->ip6mh_proto != IPPROTO_NONE) { + pr_debug("Dropping invalid MH Payload Proto: %u\n", + mh->ip6mh_proto); + par->hotdrop = true; + return false; + } + + return type_match(mhinfo->types[0], mhinfo->types[1], mh->ip6mh_type, + !!(mhinfo->invflags & IP6T_MH_INV_TYPE)); +} + +static int mh_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_mh *mhinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + return (mhinfo->invflags & ~IP6T_MH_INV_MASK) ? -EINVAL : 0; +} + +static struct xt_match mh_mt6_reg __read_mostly = { + .name = "mh", + .family = NFPROTO_IPV6, + .checkentry = mh_mt6_check, + .match = mh_mt6, + .matchsize = sizeof(struct ip6t_mh), + .proto = IPPROTO_MH, + .me = THIS_MODULE, +}; + +static int __init mh_mt6_init(void) +{ + return xt_register_match(&mh_mt6_reg); +} + +static void __exit mh_mt6_exit(void) +{ + xt_unregister_match(&mh_mt6_reg); +} + +module_init(mh_mt6_init); +module_exit(mh_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_rpfilter.c b/net/ipv6/netfilter/ip6t_rpfilter.c new file mode 100644 index 000000000..67c87a88c --- /dev/null +++ b/net/ipv6/netfilter/ip6t_rpfilter.c @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2011 Florian Westphal + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("Xtables: IPv6 reverse path filter match"); + +static bool rpfilter_addr_unicast(const struct in6_addr *addr) +{ + int addr_type = ipv6_addr_type(addr); + return addr_type & IPV6_ADDR_UNICAST; +} + +static bool rpfilter_addr_linklocal(const struct in6_addr *addr) +{ + int addr_type = ipv6_addr_type(addr); + return addr_type & IPV6_ADDR_LINKLOCAL; +} + +static bool rpfilter_lookup_reverse6(struct net *net, const struct sk_buff *skb, + const struct net_device *dev, u8 flags) +{ + struct rt6_info *rt; + struct ipv6hdr *iph = ipv6_hdr(skb); + bool ret = false; + struct flowi6 fl6 = { + .flowi6_iif = LOOPBACK_IFINDEX, + .flowi6_l3mdev = l3mdev_master_ifindex_rcu(dev), + .flowlabel = (* (__be32 *) iph) & IPV6_FLOWINFO_MASK, + .flowi6_proto = iph->nexthdr, + .flowi6_uid = sock_net_uid(net, NULL), + .daddr = iph->saddr, + }; + int lookup_flags; + + if (rpfilter_addr_unicast(&iph->daddr)) { + memcpy(&fl6.saddr, &iph->daddr, sizeof(struct in6_addr)); + lookup_flags = RT6_LOOKUP_F_HAS_SADDR; + } else { + lookup_flags = 0; + } + + fl6.flowi6_mark = flags & XT_RPFILTER_VALID_MARK ? 
skb->mark : 0; + + if (rpfilter_addr_linklocal(&iph->saddr)) { + lookup_flags |= RT6_LOOKUP_F_IFACE; + fl6.flowi6_oif = dev->ifindex; + } else if ((flags & XT_RPFILTER_LOOSE) == 0) + fl6.flowi6_oif = dev->ifindex; + + rt = (void *)ip6_route_lookup(net, &fl6, skb, lookup_flags); + if (rt->dst.error) + goto out; + + if (rt->rt6i_flags & (RTF_REJECT|RTF_ANYCAST)) + goto out; + + if (rt->rt6i_flags & RTF_LOCAL) { + ret = flags & XT_RPFILTER_ACCEPT_LOCAL; + goto out; + } + + if (rt->rt6i_idev->dev == dev || + l3mdev_master_ifindex_rcu(rt->rt6i_idev->dev) == dev->ifindex || + (flags & XT_RPFILTER_LOOSE)) + ret = true; + out: + ip6_rt_put(rt); + return ret; +} + +static bool +rpfilter_is_loopback(const struct sk_buff *skb, const struct net_device *in) +{ + return skb->pkt_type == PACKET_LOOPBACK || in->flags & IFF_LOOPBACK; +} + +static bool rpfilter_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_rpfilter_info *info = par->matchinfo; + int saddrtype; + struct ipv6hdr *iph; + bool invert = info->flags & XT_RPFILTER_INVERT; + + if (rpfilter_is_loopback(skb, xt_in(par))) + return true ^ invert; + + iph = ipv6_hdr(skb); + saddrtype = ipv6_addr_type(&iph->saddr); + if (unlikely(saddrtype == IPV6_ADDR_ANY)) + return true ^ invert; /* not routable: forward path will drop it */ + + return rpfilter_lookup_reverse6(xt_net(par), skb, xt_in(par), + info->flags) ^ invert; +} + +static int rpfilter_check(const struct xt_mtchk_param *par) +{ + const struct xt_rpfilter_info *info = par->matchinfo; + unsigned int options = ~XT_RPFILTER_OPTION_MASK; + + if (info->flags & options) { + pr_info_ratelimited("unknown options\n"); + return -EINVAL; + } + + if (strcmp(par->table, "mangle") != 0 && + strcmp(par->table, "raw") != 0) { + pr_info_ratelimited("only valid in \'raw\' or \'mangle\' table, not \'%s\'\n", + par->table); + return -EINVAL; + } + + return 0; +} + +static struct xt_match rpfilter_mt_reg __read_mostly = { + .name = "rpfilter", + .family = NFPROTO_IPV6, + .checkentry = rpfilter_check, + .match = rpfilter_mt, + .matchsize = sizeof(struct xt_rpfilter_info), + .hooks = (1 << NF_INET_PRE_ROUTING), + .me = THIS_MODULE +}; + +static int __init rpfilter_mt_init(void) +{ + return xt_register_match(&rpfilter_mt_reg); +} + +static void __exit rpfilter_mt_exit(void) +{ + xt_unregister_match(&rpfilter_mt_reg); +} + +module_init(rpfilter_mt_init); +module_exit(rpfilter_mt_exit); diff --git a/net/ipv6/netfilter/ip6t_rt.c b/net/ipv6/netfilter/ip6t_rt.c new file mode 100644 index 000000000..4ad8b2032 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_rt.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match ROUTING parameters. 
*/ + +/* (C) 2001-2002 Andras Kis-Szabo + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 Routing Header match"); +MODULE_AUTHOR("Andras Kis-Szabo "); + +/* Returns 1 if the id is matched by the range, 0 otherwise */ +static inline bool +segsleft_match(u_int32_t min, u_int32_t max, u_int32_t id, bool invert) +{ + return (id >= min && id <= max) ^ invert; +} + +static bool rt_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ipv6_rt_hdr _route; + const struct ipv6_rt_hdr *rh; + const struct ip6t_rt *rtinfo = par->matchinfo; + unsigned int temp; + unsigned int ptr = 0; + unsigned int hdrlen = 0; + bool ret = false; + struct in6_addr _addr; + const struct in6_addr *ap; + int err; + + err = ipv6_find_hdr(skb, &ptr, NEXTHDR_ROUTING, NULL, NULL); + if (err < 0) { + if (err != -ENOENT) + par->hotdrop = true; + return false; + } + + rh = skb_header_pointer(skb, ptr, sizeof(_route), &_route); + if (rh == NULL) { + par->hotdrop = true; + return false; + } + + hdrlen = ipv6_optlen(rh); + if (skb->len - ptr < hdrlen) { + /* Pcket smaller than its length field */ + return false; + } + + ret = (segsleft_match(rtinfo->segsleft[0], rtinfo->segsleft[1], + rh->segments_left, + !!(rtinfo->invflags & IP6T_RT_INV_SGS))) && + (!(rtinfo->flags & IP6T_RT_LEN) || + ((rtinfo->hdrlen == hdrlen) ^ + !!(rtinfo->invflags & IP6T_RT_INV_LEN))) && + (!(rtinfo->flags & IP6T_RT_TYP) || + ((rtinfo->rt_type == rh->type) ^ + !!(rtinfo->invflags & IP6T_RT_INV_TYP))); + + if (ret && (rtinfo->flags & IP6T_RT_RES)) { + const u_int32_t *rp; + u_int32_t _reserved; + rp = skb_header_pointer(skb, + ptr + offsetof(struct rt0_hdr, + reserved), + sizeof(_reserved), + &_reserved); + if (!rp) { + par->hotdrop = true; + return false; + } + + ret = (*rp == 0); + } + + if (!(rtinfo->flags & IP6T_RT_FST)) { + return ret; + } else if (rtinfo->flags & IP6T_RT_FST_NSTRICT) { + if (rtinfo->addrnr > (unsigned int)((hdrlen - 8) / 16)) { + return false; + } else { + unsigned int i = 0; + + for (temp = 0; + temp < (unsigned int)((hdrlen - 8) / 16); + temp++) { + ap = skb_header_pointer(skb, + ptr + + sizeof(struct rt0_hdr) + + temp * sizeof(_addr), + sizeof(_addr), + &_addr); + + if (ap == NULL) { + par->hotdrop = true; + return false; + } + + if (ipv6_addr_equal(ap, &rtinfo->addrs[i])) + i++; + if (i == rtinfo->addrnr) + break; + } + if (i == rtinfo->addrnr) + return ret; + else + return false; + } + } else { + if (rtinfo->addrnr > (unsigned int)((hdrlen - 8) / 16)) { + return false; + } else { + for (temp = 0; temp < rtinfo->addrnr; temp++) { + ap = skb_header_pointer(skb, + ptr + + sizeof(struct rt0_hdr) + + temp * sizeof(_addr), + sizeof(_addr), + &_addr); + if (ap == NULL) { + par->hotdrop = true; + return false; + } + + if (!ipv6_addr_equal(ap, &rtinfo->addrs[temp])) + break; + } + if (temp == rtinfo->addrnr && + temp == (unsigned int)((hdrlen - 8) / 16)) + return ret; + else + return false; + } + } + + return false; +} + +static int rt_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_rt *rtinfo = par->matchinfo; + + if (rtinfo->invflags & ~IP6T_RT_INV_MASK) { + pr_debug("unknown flags %X\n", rtinfo->invflags); + return -EINVAL; + } + if ((rtinfo->flags & (IP6T_RT_RES | IP6T_RT_FST_MASK)) && + (!(rtinfo->flags & IP6T_RT_TYP) || + (rtinfo->rt_type != 0) || + (rtinfo->invflags & IP6T_RT_INV_TYP))) { + pr_debug("`--rt-type 0' required before 
`--rt-0-*'"); + return -EINVAL; + } + + return 0; +} + +static struct xt_match rt_mt6_reg __read_mostly = { + .name = "rt", + .family = NFPROTO_IPV6, + .match = rt_mt6, + .matchsize = sizeof(struct ip6t_rt), + .checkentry = rt_mt6_check, + .me = THIS_MODULE, +}; + +static int __init rt_mt6_init(void) +{ + return xt_register_match(&rt_mt6_reg); +} + +static void __exit rt_mt6_exit(void) +{ + xt_unregister_match(&rt_mt6_reg); +} + +module_init(rt_mt6_init); +module_exit(rt_mt6_exit); diff --git a/net/ipv6/netfilter/ip6t_srh.c b/net/ipv6/netfilter/ip6t_srh.c new file mode 100644 index 000000000..db0fd64d8 --- /dev/null +++ b/net/ipv6/netfilter/ip6t_srh.c @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Kernel module to match Segment Routing Header (SRH) parameters. */ + +/* Author: + * Ahmed Abdelsalam + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +/* Test a struct->mt_invflags and a boolean for inequality */ +#define NF_SRH_INVF(ptr, flag, boolean) \ + ((boolean) ^ !!((ptr)->mt_invflags & (flag))) + +static bool srh_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ip6t_srh *srhinfo = par->matchinfo; + struct ipv6_sr_hdr *srh; + struct ipv6_sr_hdr _srh; + int hdrlen, srhoff = 0; + + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return false; + srh = skb_header_pointer(skb, srhoff, sizeof(_srh), &_srh); + if (!srh) + return false; + + hdrlen = ipv6_optlen(srh); + if (skb->len - srhoff < hdrlen) + return false; + + if (srh->type != IPV6_SRCRT_TYPE_4) + return false; + + if (srh->segments_left > srh->first_segment) + return false; + + /* Next Header matching */ + if (srhinfo->mt_flags & IP6T_SRH_NEXTHDR) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_NEXTHDR, + !(srh->nexthdr == srhinfo->next_hdr))) + return false; + + /* Header Extension Length matching */ + if (srhinfo->mt_flags & IP6T_SRH_LEN_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_EQ, + !(srh->hdrlen == srhinfo->hdr_len))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_LEN_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_GT, + !(srh->hdrlen > srhinfo->hdr_len))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_LEN_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_LT, + !(srh->hdrlen < srhinfo->hdr_len))) + return false; + + /* Segments Left matching */ + if (srhinfo->mt_flags & IP6T_SRH_SEGS_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_EQ, + !(srh->segments_left == srhinfo->segs_left))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_SEGS_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_GT, + !(srh->segments_left > srhinfo->segs_left))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_SEGS_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_LT, + !(srh->segments_left < srhinfo->segs_left))) + return false; + + /** + * Last Entry matching + * Last_Entry field was introduced in revision 6 of the SRH draft. 
+ * It was called First_Segment in the previous revision + */ + if (srhinfo->mt_flags & IP6T_SRH_LAST_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_EQ, + !(srh->first_segment == srhinfo->last_entry))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_LAST_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_GT, + !(srh->first_segment > srhinfo->last_entry))) + return false; + + if (srhinfo->mt_flags & IP6T_SRH_LAST_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_LT, + !(srh->first_segment < srhinfo->last_entry))) + return false; + + /** + * Tag matchig + * Tag field was introduced in revision 6 of the SRH draft. + */ + if (srhinfo->mt_flags & IP6T_SRH_TAG) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_TAG, + !(srh->tag == srhinfo->tag))) + return false; + return true; +} + +static bool srh1_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + int hdrlen, psidoff, nsidoff, lsidoff, srhoff = 0; + const struct ip6t_srh1 *srhinfo = par->matchinfo; + struct in6_addr *psid, *nsid, *lsid; + struct in6_addr _psid, _nsid, _lsid; + struct ipv6_sr_hdr *srh; + struct ipv6_sr_hdr _srh; + + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, NULL) < 0) + return false; + srh = skb_header_pointer(skb, srhoff, sizeof(_srh), &_srh); + if (!srh) + return false; + + hdrlen = ipv6_optlen(srh); + if (skb->len - srhoff < hdrlen) + return false; + + if (srh->type != IPV6_SRCRT_TYPE_4) + return false; + + if (srh->segments_left > srh->first_segment) + return false; + + /* Next Header matching */ + if (srhinfo->mt_flags & IP6T_SRH_NEXTHDR) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_NEXTHDR, + !(srh->nexthdr == srhinfo->next_hdr))) + return false; + + /* Header Extension Length matching */ + if (srhinfo->mt_flags & IP6T_SRH_LEN_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_EQ, + !(srh->hdrlen == srhinfo->hdr_len))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_LEN_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_GT, + !(srh->hdrlen > srhinfo->hdr_len))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_LEN_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LEN_LT, + !(srh->hdrlen < srhinfo->hdr_len))) + return false; + + /* Segments Left matching */ + if (srhinfo->mt_flags & IP6T_SRH_SEGS_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_EQ, + !(srh->segments_left == srhinfo->segs_left))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_SEGS_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_GT, + !(srh->segments_left > srhinfo->segs_left))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_SEGS_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_SEGS_LT, + !(srh->segments_left < srhinfo->segs_left))) + return false; + + /** + * Last Entry matching + * Last_Entry field was introduced in revision 6 of the SRH draft. 
+ * It was called First_Segment in the previous revision + */ + if (srhinfo->mt_flags & IP6T_SRH_LAST_EQ) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_EQ, + !(srh->first_segment == srhinfo->last_entry))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_LAST_GT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_GT, + !(srh->first_segment > srhinfo->last_entry))) + return false; + if (srhinfo->mt_flags & IP6T_SRH_LAST_LT) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LAST_LT, + !(srh->first_segment < srhinfo->last_entry))) + return false; + + /** + * Tag matchig + * Tag field was introduced in revision 6 of the SRH draft + */ + if (srhinfo->mt_flags & IP6T_SRH_TAG) + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_TAG, + !(srh->tag == srhinfo->tag))) + return false; + + /* Previous SID matching */ + if (srhinfo->mt_flags & IP6T_SRH_PSID) { + if (srh->segments_left == srh->first_segment) + return false; + psidoff = srhoff + sizeof(struct ipv6_sr_hdr) + + ((srh->segments_left + 1) * sizeof(struct in6_addr)); + psid = skb_header_pointer(skb, psidoff, sizeof(_psid), &_psid); + if (!psid) + return false; + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_PSID, + ipv6_masked_addr_cmp(psid, &srhinfo->psid_msk, + &srhinfo->psid_addr))) + return false; + } + + /* Next SID matching */ + if (srhinfo->mt_flags & IP6T_SRH_NSID) { + if (srh->segments_left == 0) + return false; + nsidoff = srhoff + sizeof(struct ipv6_sr_hdr) + + ((srh->segments_left - 1) * sizeof(struct in6_addr)); + nsid = skb_header_pointer(skb, nsidoff, sizeof(_nsid), &_nsid); + if (!nsid) + return false; + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_NSID, + ipv6_masked_addr_cmp(nsid, &srhinfo->nsid_msk, + &srhinfo->nsid_addr))) + return false; + } + + /* Last SID matching */ + if (srhinfo->mt_flags & IP6T_SRH_LSID) { + lsidoff = srhoff + sizeof(struct ipv6_sr_hdr); + lsid = skb_header_pointer(skb, lsidoff, sizeof(_lsid), &_lsid); + if (!lsid) + return false; + if (NF_SRH_INVF(srhinfo, IP6T_SRH_INV_LSID, + ipv6_masked_addr_cmp(lsid, &srhinfo->lsid_msk, + &srhinfo->lsid_addr))) + return false; + } + return true; +} + +static int srh_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_srh *srhinfo = par->matchinfo; + + if (srhinfo->mt_flags & ~IP6T_SRH_MASK) { + pr_info_ratelimited("unknown srh match flags %X\n", + srhinfo->mt_flags); + return -EINVAL; + } + + if (srhinfo->mt_invflags & ~IP6T_SRH_INV_MASK) { + pr_info_ratelimited("unknown srh invflags %X\n", + srhinfo->mt_invflags); + return -EINVAL; + } + + return 0; +} + +static int srh1_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_srh1 *srhinfo = par->matchinfo; + + if (srhinfo->mt_flags & ~IP6T_SRH_MASK) { + pr_info_ratelimited("unknown srh match flags %X\n", + srhinfo->mt_flags); + return -EINVAL; + } + + if (srhinfo->mt_invflags & ~IP6T_SRH_INV_MASK) { + pr_info_ratelimited("unknown srh invflags %X\n", + srhinfo->mt_invflags); + return -EINVAL; + } + + return 0; +} + +static struct xt_match srh_mt6_reg[] __read_mostly = { + { + .name = "srh", + .revision = 0, + .family = NFPROTO_IPV6, + .match = srh_mt6, + .matchsize = sizeof(struct ip6t_srh), + .checkentry = srh_mt6_check, + .me = THIS_MODULE, + }, + { + .name = "srh", + .revision = 1, + .family = NFPROTO_IPV6, + .match = srh1_mt6, + .matchsize = sizeof(struct ip6t_srh1), + .checkentry = srh1_mt6_check, + .me = THIS_MODULE, + } +}; + +static int __init srh_mt6_init(void) +{ + return xt_register_matches(srh_mt6_reg, ARRAY_SIZE(srh_mt6_reg)); +} + +static void __exit srh_mt6_exit(void) +{ + xt_unregister_matches(srh_mt6_reg, 
ARRAY_SIZE(srh_mt6_reg)); +} + +module_init(srh_mt6_init); +module_exit(srh_mt6_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: IPv6 Segment Routing Header match"); +MODULE_AUTHOR("Ahmed Abdelsalam "); diff --git a/net/ipv6/netfilter/ip6table_filter.c b/net/ipv6/netfilter/ip6table_filter.c new file mode 100644 index 000000000..df785ebda --- /dev/null +++ b/net/ipv6/netfilter/ip6table_filter.c @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is the 1999 rewrite of IP Firewalling, aiming for kernel 2.3.x. + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2004 Netfilter Core Team + */ + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("ip6tables filter table"); + +#define FILTER_VALID_HOOKS ((1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT)) + +static const struct xt_table packet_filter = { + .name = "filter", + .valid_hooks = FILTER_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV6, + .priority = NF_IP6_PRI_FILTER, +}; + +static struct nf_hook_ops *filter_ops __read_mostly; + +/* Default to forward because I got too much mail already. */ +static bool forward = true; +module_param(forward, bool, 0000); + +static int ip6table_filter_table_init(struct net *net) +{ + struct ip6t_replace *repl; + int err; + + repl = ip6t_alloc_initial_table(&packet_filter); + if (repl == NULL) + return -ENOMEM; + /* Entry 1 is the FORWARD hook */ + ((struct ip6t_standard *)repl->entries)[1].target.verdict = + forward ? -NF_ACCEPT - 1 : -NF_DROP - 1; + + err = ip6t_register_table(net, &packet_filter, repl, filter_ops); + kfree(repl); + return err; +} + +static int __net_init ip6table_filter_net_init(struct net *net) +{ + if (!forward) + return ip6table_filter_table_init(net); + + return 0; +} + +static void __net_exit ip6table_filter_net_pre_exit(struct net *net) +{ + ip6t_unregister_table_pre_exit(net, "filter"); +} + +static void __net_exit ip6table_filter_net_exit(struct net *net) +{ + ip6t_unregister_table_exit(net, "filter"); +} + +static struct pernet_operations ip6table_filter_net_ops = { + .init = ip6table_filter_net_init, + .pre_exit = ip6table_filter_net_pre_exit, + .exit = ip6table_filter_net_exit, +}; + +static int __init ip6table_filter_init(void) +{ + int ret = xt_register_template(&packet_filter, + ip6table_filter_table_init); + + if (ret < 0) + return ret; + + filter_ops = xt_hook_ops_alloc(&packet_filter, ip6t_do_table); + if (IS_ERR(filter_ops)) { + xt_unregister_template(&packet_filter); + return PTR_ERR(filter_ops); + } + + ret = register_pernet_subsys(&ip6table_filter_net_ops); + if (ret < 0) { + xt_unregister_template(&packet_filter); + kfree(filter_ops); + return ret; + } + + return ret; +} + +static void __exit ip6table_filter_fini(void) +{ + unregister_pernet_subsys(&ip6table_filter_net_ops); + xt_unregister_template(&packet_filter); + kfree(filter_ops); +} + +module_init(ip6table_filter_init); +module_exit(ip6table_filter_fini); diff --git a/net/ipv6/netfilter/ip6table_mangle.c b/net/ipv6/netfilter/ip6table_mangle.c new file mode 100644 index 000000000..a88b2ce4a --- /dev/null +++ b/net/ipv6/netfilter/ip6table_mangle.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPv6 packet mangling table, a port of the IPv4 mangle table to IPv6 + * + * Copyright (C) 2000-2001 by Harald Welte + * Copyright (C) 2000-2004 Netfilter Core Team + */ +#include +#include +#include +#include + 
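+/* Illustrative note (editorial sketch, not part of the upstream file): the
+ * "mangle" table defined below is attached to all five IPv6 netfilter hooks.
+ * For locally generated packets (NF_INET_LOCAL_OUT), ip6t_mangle_out() further
+ * down snapshots the source/destination address, mark, hop limit and flow
+ * label before the ruleset runs and, unless the packet was dropped or stolen,
+ * calls ip6_route_me_harder() afterwards if any of them changed, so the packet
+ * is re-routed under its rewritten headers.  Assuming the standard ip6tables
+ * CLI, a rule that exercises this path could look like:
+ *
+ *   ip6tables -t mangle -A OUTPUT -p tcp --dport 80 -j MARK --set-mark 0x1
+ *
+ * Rewriting skb->mark this way is one of the conditions that triggers the
+ * re-route in ip6t_mangle_out().
+ */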
+MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("ip6tables mangle table"); + +#define MANGLE_VALID_HOOKS ((1 << NF_INET_PRE_ROUTING) | \ + (1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT) | \ + (1 << NF_INET_POST_ROUTING)) + +static const struct xt_table packet_mangler = { + .name = "mangle", + .valid_hooks = MANGLE_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV6, + .priority = NF_IP6_PRI_MANGLE, +}; + +static unsigned int +ip6t_mangle_out(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) +{ + unsigned int ret; + struct in6_addr saddr, daddr; + u_int8_t hop_limit; + u_int32_t flowlabel, mark; + int err; + + /* save source/dest address, mark, hoplimit, flowlabel, priority, */ + memcpy(&saddr, &ipv6_hdr(skb)->saddr, sizeof(saddr)); + memcpy(&daddr, &ipv6_hdr(skb)->daddr, sizeof(daddr)); + mark = skb->mark; + hop_limit = ipv6_hdr(skb)->hop_limit; + + /* flowlabel and prio (includes version, which shouldn't change either */ + flowlabel = *((u_int32_t *)ipv6_hdr(skb)); + + ret = ip6t_do_table(priv, skb, state); + + if (ret != NF_DROP && ret != NF_STOLEN && + (!ipv6_addr_equal(&ipv6_hdr(skb)->saddr, &saddr) || + !ipv6_addr_equal(&ipv6_hdr(skb)->daddr, &daddr) || + skb->mark != mark || + ipv6_hdr(skb)->hop_limit != hop_limit || + flowlabel != *((u_int32_t *)ipv6_hdr(skb)))) { + err = ip6_route_me_harder(state->net, state->sk, skb); + if (err < 0) + ret = NF_DROP_ERR(err); + } + + return ret; +} + +/* The work comes in here from netfilter.c. */ +static unsigned int +ip6table_mangle_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + if (state->hook == NF_INET_LOCAL_OUT) + return ip6t_mangle_out(priv, skb, state); + return ip6t_do_table(priv, skb, state); +} + +static struct nf_hook_ops *mangle_ops __read_mostly; +static int ip6table_mangle_table_init(struct net *net) +{ + struct ip6t_replace *repl; + int ret; + + repl = ip6t_alloc_initial_table(&packet_mangler); + if (repl == NULL) + return -ENOMEM; + ret = ip6t_register_table(net, &packet_mangler, repl, mangle_ops); + kfree(repl); + return ret; +} + +static void __net_exit ip6table_mangle_net_pre_exit(struct net *net) +{ + ip6t_unregister_table_pre_exit(net, "mangle"); +} + +static void __net_exit ip6table_mangle_net_exit(struct net *net) +{ + ip6t_unregister_table_exit(net, "mangle"); +} + +static struct pernet_operations ip6table_mangle_net_ops = { + .pre_exit = ip6table_mangle_net_pre_exit, + .exit = ip6table_mangle_net_exit, +}; + +static int __init ip6table_mangle_init(void) +{ + int ret = xt_register_template(&packet_mangler, + ip6table_mangle_table_init); + + if (ret < 0) + return ret; + + mangle_ops = xt_hook_ops_alloc(&packet_mangler, ip6table_mangle_hook); + if (IS_ERR(mangle_ops)) { + xt_unregister_template(&packet_mangler); + return PTR_ERR(mangle_ops); + } + + ret = register_pernet_subsys(&ip6table_mangle_net_ops); + if (ret < 0) { + xt_unregister_template(&packet_mangler); + kfree(mangle_ops); + return ret; + } + + return ret; +} + +static void __exit ip6table_mangle_fini(void) +{ + unregister_pernet_subsys(&ip6table_mangle_net_ops); + xt_unregister_template(&packet_mangler); + kfree(mangle_ops); +} + +module_init(ip6table_mangle_init); +module_exit(ip6table_mangle_fini); diff --git a/net/ipv6/netfilter/ip6table_nat.c b/net/ipv6/netfilter/ip6table_nat.c new file mode 100644 index 000000000..bf3cb3a13 --- /dev/null +++ b/net/ipv6/netfilter/ip6table_nat.c @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: 
GPL-2.0-only +/* + * Copyright (c) 2011 Patrick McHardy + * + * Based on Rusty Russell's IPv4 NAT code. Development of IPv6 NAT + * funded by Astaro. + */ + +#include +#include +#include +#include +#include +#include + +#include + +struct ip6table_nat_pernet { + struct nf_hook_ops *nf_nat_ops; +}; + +static unsigned int ip6table_nat_net_id __read_mostly; + +static const struct xt_table nf_nat_ipv6_table = { + .name = "nat", + .valid_hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + .af = NFPROTO_IPV6, +}; + +static const struct nf_hook_ops nf_nat_ipv6_ops[] = { + { + .hook = ip6t_do_table, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP6_PRI_NAT_DST, + }, + { + .hook = ip6t_do_table, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP6_PRI_NAT_SRC, + }, + { + .hook = ip6t_do_table, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_NAT_DST, + }, + { + .hook = ip6t_do_table, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_NAT_SRC, + }, +}; + +static int ip6t_nat_register_lookups(struct net *net) +{ + struct ip6table_nat_pernet *xt_nat_net; + struct nf_hook_ops *ops; + struct xt_table *table; + int i, ret; + + table = xt_find_table(net, NFPROTO_IPV6, "nat"); + if (WARN_ON_ONCE(!table)) + return -ENOENT; + + xt_nat_net = net_generic(net, ip6table_nat_net_id); + ops = kmemdup(nf_nat_ipv6_ops, sizeof(nf_nat_ipv6_ops), GFP_KERNEL); + if (!ops) + return -ENOMEM; + + for (i = 0; i < ARRAY_SIZE(nf_nat_ipv6_ops); i++) { + ops[i].priv = table; + ret = nf_nat_ipv6_register_fn(net, &ops[i]); + if (ret) { + while (i) + nf_nat_ipv6_unregister_fn(net, &ops[--i]); + + kfree(ops); + return ret; + } + } + + xt_nat_net->nf_nat_ops = ops; + return 0; +} + +static void ip6t_nat_unregister_lookups(struct net *net) +{ + struct ip6table_nat_pernet *xt_nat_net = net_generic(net, ip6table_nat_net_id); + struct nf_hook_ops *ops = xt_nat_net->nf_nat_ops; + int i; + + if (!ops) + return; + + for (i = 0; i < ARRAY_SIZE(nf_nat_ipv6_ops); i++) + nf_nat_ipv6_unregister_fn(net, &ops[i]); + + kfree(ops); +} + +static int ip6table_nat_table_init(struct net *net) +{ + struct ip6t_replace *repl; + int ret; + + repl = ip6t_alloc_initial_table(&nf_nat_ipv6_table); + if (repl == NULL) + return -ENOMEM; + ret = ip6t_register_table(net, &nf_nat_ipv6_table, repl, + NULL); + if (ret < 0) { + kfree(repl); + return ret; + } + + ret = ip6t_nat_register_lookups(net); + if (ret < 0) + ip6t_unregister_table_exit(net, "nat"); + + kfree(repl); + return ret; +} + +static void __net_exit ip6table_nat_net_pre_exit(struct net *net) +{ + ip6t_nat_unregister_lookups(net); +} + +static void __net_exit ip6table_nat_net_exit(struct net *net) +{ + ip6t_unregister_table_exit(net, "nat"); +} + +static struct pernet_operations ip6table_nat_net_ops = { + .pre_exit = ip6table_nat_net_pre_exit, + .exit = ip6table_nat_net_exit, + .id = &ip6table_nat_net_id, + .size = sizeof(struct ip6table_nat_pernet), +}; + +static int __init ip6table_nat_init(void) +{ + int ret = xt_register_template(&nf_nat_ipv6_table, + ip6table_nat_table_init); + + if (ret < 0) + return ret; + + ret = register_pernet_subsys(&ip6table_nat_net_ops); + if (ret) + xt_unregister_template(&nf_nat_ipv6_table); + + return ret; +} + +static void __exit ip6table_nat_exit(void) +{ + unregister_pernet_subsys(&ip6table_nat_net_ops); + xt_unregister_template(&nf_nat_ipv6_table); +} + 
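+/* Illustrative usage note (editorial sketch, not part of the upstream file):
+ * the "nat" table above attaches its four hook functions through
+ * nf_nat_ipv6_register_fn() in ip6t_nat_register_lookups(), one per entry of
+ * nf_nat_ipv6_ops[] (PRE_ROUTING and LOCAL_OUT at NF_IP6_PRI_NAT_DST,
+ * POST_ROUTING and LOCAL_IN at NF_IP6_PRI_NAT_SRC).  Assuming the standard
+ * ip6tables CLI, an IPv6 masquerading rule served by the POSTROUTING hook
+ * could be added with:
+ *
+ *   ip6tables -t nat -A POSTROUTING -o eth0 -j MASQUERADE
+ *
+ * The interface name "eth0" here is only a placeholder.
+ */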
+module_init(ip6table_nat_init); +module_exit(ip6table_nat_exit); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/netfilter/ip6table_raw.c b/net/ipv6/netfilter/ip6table_raw.c new file mode 100644 index 000000000..08861d5d1 --- /dev/null +++ b/net/ipv6/netfilter/ip6table_raw.c @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPv6 raw table, a port of the IPv4 raw table to IPv6 + * + * Copyright (C) 2003 Jozsef Kadlecsik + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include + +#define RAW_VALID_HOOKS ((1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_OUT)) + +static bool raw_before_defrag __read_mostly; +MODULE_PARM_DESC(raw_before_defrag, "Enable raw table before defrag"); +module_param(raw_before_defrag, bool, 0000); + +static const struct xt_table packet_raw = { + .name = "raw", + .valid_hooks = RAW_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV6, + .priority = NF_IP6_PRI_RAW, +}; + +static const struct xt_table packet_raw_before_defrag = { + .name = "raw", + .valid_hooks = RAW_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV6, + .priority = NF_IP6_PRI_RAW_BEFORE_DEFRAG, +}; + +static struct nf_hook_ops *rawtable_ops __read_mostly; + +static int ip6table_raw_table_init(struct net *net) +{ + struct ip6t_replace *repl; + const struct xt_table *table = &packet_raw; + int ret; + + if (raw_before_defrag) + table = &packet_raw_before_defrag; + + repl = ip6t_alloc_initial_table(table); + if (repl == NULL) + return -ENOMEM; + ret = ip6t_register_table(net, table, repl, rawtable_ops); + kfree(repl); + return ret; +} + +static void __net_exit ip6table_raw_net_pre_exit(struct net *net) +{ + ip6t_unregister_table_pre_exit(net, "raw"); +} + +static void __net_exit ip6table_raw_net_exit(struct net *net) +{ + ip6t_unregister_table_exit(net, "raw"); +} + +static struct pernet_operations ip6table_raw_net_ops = { + .pre_exit = ip6table_raw_net_pre_exit, + .exit = ip6table_raw_net_exit, +}; + +static int __init ip6table_raw_init(void) +{ + const struct xt_table *table = &packet_raw; + int ret; + + if (raw_before_defrag) { + table = &packet_raw_before_defrag; + pr_info("Enabling raw table before defrag\n"); + } + + ret = xt_register_template(table, ip6table_raw_table_init); + if (ret < 0) + return ret; + + /* Register hooks */ + rawtable_ops = xt_hook_ops_alloc(table, ip6t_do_table); + if (IS_ERR(rawtable_ops)) { + xt_unregister_template(table); + return PTR_ERR(rawtable_ops); + } + + ret = register_pernet_subsys(&ip6table_raw_net_ops); + if (ret < 0) { + kfree(rawtable_ops); + xt_unregister_template(table); + return ret; + } + + return ret; +} + +static void __exit ip6table_raw_fini(void) +{ + unregister_pernet_subsys(&ip6table_raw_net_ops); + xt_unregister_template(&packet_raw); + kfree(rawtable_ops); +} + +module_init(ip6table_raw_init); +module_exit(ip6table_raw_fini); +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/netfilter/ip6table_security.c b/net/ipv6/netfilter/ip6table_security.c new file mode 100644 index 000000000..4df14a9ba --- /dev/null +++ b/net/ipv6/netfilter/ip6table_security.c @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * "security" table for IPv6 + * + * This is for use by Mandatory Access Control (MAC) security models, + * which need to be able to manage security policy in separate context + * to DAC. + * + * Based on iptable_mangle.c + * + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. 
Neuling + * Copyright (C) 2000-2004 Netfilter Core Team netfilter.org> + * Copyright (C) 2008 Red Hat, Inc., James Morris redhat.com> + */ +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Morris redhat.com>"); +MODULE_DESCRIPTION("ip6tables security table, for MAC rules"); + +#define SECURITY_VALID_HOOKS (1 << NF_INET_LOCAL_IN) | \ + (1 << NF_INET_FORWARD) | \ + (1 << NF_INET_LOCAL_OUT) + +static const struct xt_table security_table = { + .name = "security", + .valid_hooks = SECURITY_VALID_HOOKS, + .me = THIS_MODULE, + .af = NFPROTO_IPV6, + .priority = NF_IP6_PRI_SECURITY, +}; + +static struct nf_hook_ops *sectbl_ops __read_mostly; + +static int ip6table_security_table_init(struct net *net) +{ + struct ip6t_replace *repl; + int ret; + + repl = ip6t_alloc_initial_table(&security_table); + if (repl == NULL) + return -ENOMEM; + ret = ip6t_register_table(net, &security_table, repl, sectbl_ops); + kfree(repl); + return ret; +} + +static void __net_exit ip6table_security_net_pre_exit(struct net *net) +{ + ip6t_unregister_table_pre_exit(net, "security"); +} + +static void __net_exit ip6table_security_net_exit(struct net *net) +{ + ip6t_unregister_table_exit(net, "security"); +} + +static struct pernet_operations ip6table_security_net_ops = { + .pre_exit = ip6table_security_net_pre_exit, + .exit = ip6table_security_net_exit, +}; + +static int __init ip6table_security_init(void) +{ + int ret = xt_register_template(&security_table, + ip6table_security_table_init); + + if (ret < 0) + return ret; + + sectbl_ops = xt_hook_ops_alloc(&security_table, ip6t_do_table); + if (IS_ERR(sectbl_ops)) { + xt_unregister_template(&security_table); + return PTR_ERR(sectbl_ops); + } + + ret = register_pernet_subsys(&ip6table_security_net_ops); + if (ret < 0) { + kfree(sectbl_ops); + xt_unregister_template(&security_table); + return ret; + } + + return ret; +} + +static void __exit ip6table_security_fini(void) +{ + unregister_pernet_subsys(&ip6table_security_net_ops); + xt_unregister_template(&security_table); + kfree(sectbl_ops); +} + +module_init(ip6table_security_init); +module_exit(ip6table_security_fini); diff --git a/net/ipv6/netfilter/nf_conntrack_reasm.c b/net/ipv6/netfilter/nf_conntrack_reasm.c new file mode 100644 index 000000000..38db0064d --- /dev/null +++ b/net/ipv6/netfilter/nf_conntrack_reasm.c @@ -0,0 +1,568 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 fragment reassembly for connection tracking + * + * Copyright (C)2004 USAGI/WIDE Project + * + * Author: + * Yasuyuki Kozakai @USAGI + * + * Based on: net/ipv6/reassembly.c + */ + +#define pr_fmt(fmt) "IPv6-nf: " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +static const char nf_frags_cache_name[] = "nf-frags"; + +static unsigned int nf_frag_pernet_id __read_mostly; +static struct inet_frags nf_frags; + +static struct nft_ct_frag6_pernet *nf_frag_pernet(struct net *net) +{ + return net_generic(net, nf_frag_pernet_id); +} + +#ifdef CONFIG_SYSCTL + +static struct ctl_table nf_ct_frag6_sysctl_table[] = { + { + .procname = "nf_conntrack_frag6_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "nf_conntrack_frag6_low_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "nf_conntrack_frag6_high_thresh", + .maxlen = sizeof(unsigned long), + .mode = 
0644, + .proc_handler = proc_doulongvec_minmax, + }, + { } +}; + +static int nf_ct_frag6_sysctl_register(struct net *net) +{ + struct nft_ct_frag6_pernet *nf_frag; + struct ctl_table *table; + struct ctl_table_header *hdr; + + table = nf_ct_frag6_sysctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(nf_ct_frag6_sysctl_table), + GFP_KERNEL); + if (table == NULL) + goto err_alloc; + } + + nf_frag = nf_frag_pernet(net); + + table[0].data = &nf_frag->fqdir->timeout; + table[1].data = &nf_frag->fqdir->low_thresh; + table[1].extra2 = &nf_frag->fqdir->high_thresh; + table[2].data = &nf_frag->fqdir->high_thresh; + table[2].extra1 = &nf_frag->fqdir->low_thresh; + + hdr = register_net_sysctl(net, "net/netfilter", table); + if (hdr == NULL) + goto err_reg; + + nf_frag->nf_frag_frags_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void __net_exit nf_ct_frags6_sysctl_unregister(struct net *net) +{ + struct nft_ct_frag6_pernet *nf_frag = nf_frag_pernet(net); + struct ctl_table *table; + + table = nf_frag->nf_frag_frags_hdr->ctl_table_arg; + unregister_net_sysctl_table(nf_frag->nf_frag_frags_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} + +#else +static int nf_ct_frag6_sysctl_register(struct net *net) +{ + return 0; +} +static void __net_exit nf_ct_frags6_sysctl_unregister(struct net *net) +{ +} +#endif + +static int nf_ct_frag6_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev); + +static inline u8 ip6_frag_ecn(const struct ipv6hdr *ipv6h) +{ + return 1 << (ipv6_get_dsfield(ipv6h) & INET_ECN_MASK); +} + +static void nf_ct_frag6_expire(struct timer_list *t) +{ + struct inet_frag_queue *frag = from_timer(frag, t, timer); + struct frag_queue *fq; + + fq = container_of(frag, struct frag_queue, q); + + ip6frag_expire_frag_queue(fq->q.fqdir->net, fq); +} + +/* Creation primitives. */ +static struct frag_queue *fq_find(struct net *net, __be32 id, u32 user, + const struct ipv6hdr *hdr, int iif) +{ + struct nft_ct_frag6_pernet *nf_frag = nf_frag_pernet(net); + struct frag_v6_compare_key key = { + .id = id, + .saddr = hdr->saddr, + .daddr = hdr->daddr, + .user = user, + .iif = iif, + }; + struct inet_frag_queue *q; + + q = inet_frag_find(nf_frag->fqdir, &key); + if (!q) + return NULL; + + return container_of(q, struct frag_queue, q); +} + + +static int nf_ct_frag6_queue(struct frag_queue *fq, struct sk_buff *skb, + const struct frag_hdr *fhdr, int nhoff) +{ + unsigned int payload_len; + struct net_device *dev; + struct sk_buff *prev; + int offset, end, err; + u8 ecn; + + if (fq->q.flags & INET_FRAG_COMPLETE) { + pr_debug("Already completed\n"); + goto err; + } + + payload_len = ntohs(ipv6_hdr(skb)->payload_len); + + offset = ntohs(fhdr->frag_off) & ~0x7; + end = offset + (payload_len - + ((u8 *)(fhdr + 1) - (u8 *)(ipv6_hdr(skb) + 1))); + + if ((unsigned int)end > IPV6_MAXPLEN) { + pr_debug("offset is too large.\n"); + return -EINVAL; + } + + ecn = ip6_frag_ecn(ipv6_hdr(skb)); + + if (skb->ip_summed == CHECKSUM_COMPLETE) { + const unsigned char *nh = skb_network_header(skb); + skb->csum = csum_sub(skb->csum, + csum_partial(nh, (u8 *)(fhdr + 1) - nh, + 0)); + } + + /* Is this the final fragment? */ + if (!(fhdr->frag_off & htons(IP6_MF))) { + /* If we already have some bits beyond end + * or have different end, the segment is corrupted. 
+ */ + if (end < fq->q.len || + ((fq->q.flags & INET_FRAG_LAST_IN) && end != fq->q.len)) { + pr_debug("already received last fragment\n"); + goto err; + } + fq->q.flags |= INET_FRAG_LAST_IN; + fq->q.len = end; + } else { + /* Check if the fragment is rounded to 8 bytes. + * Required by the RFC. + */ + if (end & 0x7) { + /* RFC2460 says always send parameter problem in + * this case. -DaveM + */ + pr_debug("end of fragment not rounded to 8 bytes.\n"); + inet_frag_kill(&fq->q); + return -EPROTO; + } + if (end > fq->q.len) { + /* Some bits beyond end -> corruption. */ + if (fq->q.flags & INET_FRAG_LAST_IN) { + pr_debug("last packet already reached.\n"); + goto err; + } + fq->q.len = end; + } + } + + if (end == offset) + goto err; + + /* Point into the IP datagram 'data' part. */ + if (!pskb_pull(skb, (u8 *) (fhdr + 1) - skb->data)) { + pr_debug("queue: message is too short.\n"); + goto err; + } + if (pskb_trim_rcsum(skb, end - offset)) { + pr_debug("Can't trim\n"); + goto err; + } + + /* Note : skb->rbnode and skb->dev share the same location. */ + dev = skb->dev; + /* Makes sure compiler wont do silly aliasing games */ + barrier(); + + prev = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) { + if (err == IPFRAG_DUP) { + /* No error for duplicates, pretend they got queued. */ + kfree_skb(skb); + return -EINPROGRESS; + } + goto insert_error; + } + + if (dev) + fq->iif = dev->ifindex; + + fq->q.stamp = skb->tstamp; + fq->q.mono_delivery_time = skb->mono_delivery_time; + fq->q.meat += skb->len; + fq->ecn |= ecn; + if (payload_len > fq->q.max_size) + fq->q.max_size = payload_len; + add_frag_mem_limit(fq->q.fqdir, skb->truesize); + + /* The first fragment. + * nhoffset is obtained from the first fragment, of course. + */ + if (offset == 0) { + fq->nhoffset = nhoff; + fq->q.flags |= INET_FRAG_FIRST_IN; + } + + if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && + fq->q.meat == fq->q.len) { + unsigned long orefdst = skb->_skb_refdst; + + skb->_skb_refdst = 0UL; + err = nf_ct_frag6_reasm(fq, skb, prev, dev); + skb->_skb_refdst = orefdst; + + /* After queue has assumed skb ownership, only 0 or + * -EINPROGRESS must be returned. + */ + return err ? -EINPROGRESS : 0; + } + + skb_dst_drop(skb); + return -EINPROGRESS; + +insert_error: + inet_frag_kill(&fq->q); +err: + skb_dst_drop(skb); + return -EINVAL; +} + +/* + * Check if this packet is complete. + * + * It is called with locked fq, and caller must check that + * queue is eligible for reassembly i.e. it is not COMPLETE, + * the last and the first frames arrived and all the bits are here. + */ +static int nf_ct_frag6_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev) +{ + void *reasm_data; + int payload_len; + u8 ecn; + + inet_frag_kill(&fq->q); + + ecn = ip_frag_ecn_table[fq->ecn]; + if (unlikely(ecn == 0xff)) + goto err; + + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) + goto err; + + payload_len = ((skb->data - skb_network_header(skb)) - + sizeof(struct ipv6hdr) + fq->q.len - + sizeof(struct frag_hdr)); + if (payload_len > IPV6_MAXPLEN) { + net_dbg_ratelimited("nf_ct_frag6_reasm: payload len = %d\n", + payload_len); + goto err; + } + + /* We have to remove fragment header from datagram and to relocate + * header in order to calculate ICV correctly. 
*/ + skb_network_header(skb)[fq->nhoffset] = skb_transport_header(skb)[0]; + memmove(skb->head + sizeof(struct frag_hdr), skb->head, + (skb->data - skb->head) - sizeof(struct frag_hdr)); + skb->mac_header += sizeof(struct frag_hdr); + skb->network_header += sizeof(struct frag_hdr); + + skb_reset_transport_header(skb); + + inet_frag_reasm_finish(&fq->q, skb, reasm_data, false); + + skb->ignore_df = 1; + skb->dev = dev; + ipv6_hdr(skb)->payload_len = htons(payload_len); + ipv6_change_dsfield(ipv6_hdr(skb), 0xff, ecn); + IP6CB(skb)->frag_max_size = sizeof(struct ipv6hdr) + fq->q.max_size; + IP6CB(skb)->flags |= IP6SKB_FRAGMENTED; + + /* Yes, and fold redundant checksum back. 8) */ + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->csum = csum_partial(skb_network_header(skb), + skb_network_header_len(skb), + skb->csum); + + fq->q.rb_fragments = RB_ROOT; + fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; + + return 0; + +err: + inet_frag_kill(&fq->q); + return -EINVAL; +} + +/* + * find the header just before Fragment Header. + * + * if success return 0 and set ... + * (*prevhdrp): the value of "Next Header Field" in the header + * just before Fragment Header. + * (*prevhoff): the offset of "Next Header Field" in the header + * just before Fragment Header. + * (*fhoff) : the offset of Fragment Header. + * + * Based on ipv6_skip_hdr() in net/ipv6/exthdr.c + * + */ +static int +find_prev_fhdr(struct sk_buff *skb, u8 *prevhdrp, int *prevhoff, int *fhoff) +{ + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + const int netoff = skb_network_offset(skb); + u8 prev_nhoff = netoff + offsetof(struct ipv6hdr, nexthdr); + int start = netoff + sizeof(struct ipv6hdr); + int len = skb->len - start; + u8 prevhdr = NEXTHDR_IPV6; + + while (nexthdr != NEXTHDR_FRAGMENT) { + struct ipv6_opt_hdr hdr; + int hdrlen; + + if (!ipv6_ext_hdr(nexthdr)) { + return -1; + } + if (nexthdr == NEXTHDR_NONE) { + pr_debug("next header is none\n"); + return -1; + } + if (len < (int)sizeof(struct ipv6_opt_hdr)) { + pr_debug("too short\n"); + return -1; + } + if (skb_copy_bits(skb, start, &hdr, sizeof(hdr))) + BUG(); + if (nexthdr == NEXTHDR_AUTH) + hdrlen = ipv6_authlen(&hdr); + else + hdrlen = ipv6_optlen(&hdr); + + prevhdr = nexthdr; + prev_nhoff = start; + + nexthdr = hdr.nexthdr; + len -= hdrlen; + start += hdrlen; + } + + if (len < 0) + return -1; + + *prevhdrp = prevhdr; + *prevhoff = prev_nhoff; + *fhoff = start; + + return 0; +} + +int nf_ct_frag6_gather(struct net *net, struct sk_buff *skb, u32 user) +{ + u16 savethdr = skb->transport_header; + u8 nexthdr = NEXTHDR_FRAGMENT; + int fhoff, nhoff, ret; + struct frag_hdr *fhdr; + struct frag_queue *fq; + struct ipv6hdr *hdr; + u8 prevhdr; + + /* Jumbo payload inhibits frag. header */ + if (ipv6_hdr(skb)->payload_len == 0) { + pr_debug("payload len = 0\n"); + return 0; + } + + if (find_prev_fhdr(skb, &prevhdr, &nhoff, &fhoff) < 0) + return 0; + + /* Discard the first fragment if it does not include all headers + * RFC 8200, Section 4.5 + */ + if (ipv6frag_thdr_truncated(skb, fhoff, &nexthdr)) { + pr_debug("Drop incomplete fragment\n"); + return 0; + } + + if (!pskb_may_pull(skb, fhoff + sizeof(*fhdr))) + return -ENOMEM; + + skb_set_transport_header(skb, fhoff); + hdr = ipv6_hdr(skb); + fhdr = (struct frag_hdr *)skb_transport_header(skb); + + skb_orphan(skb); + fq = fq_find(net, fhdr->identification, user, hdr, + skb->dev ? 
skb->dev->ifindex : 0); + if (fq == NULL) { + pr_debug("Can't find and can't create new queue\n"); + return -ENOMEM; + } + + spin_lock_bh(&fq->q.lock); + + ret = nf_ct_frag6_queue(fq, skb, fhdr, nhoff); + if (ret == -EPROTO) { + skb->transport_header = savethdr; + ret = 0; + } + + spin_unlock_bh(&fq->q.lock); + inet_frag_put(&fq->q); + return ret; +} +EXPORT_SYMBOL_GPL(nf_ct_frag6_gather); + +static int nf_ct_net_init(struct net *net) +{ + struct nft_ct_frag6_pernet *nf_frag = nf_frag_pernet(net); + int res; + + res = fqdir_init(&nf_frag->fqdir, &nf_frags, net); + if (res < 0) + return res; + + nf_frag->fqdir->high_thresh = IPV6_FRAG_HIGH_THRESH; + nf_frag->fqdir->low_thresh = IPV6_FRAG_LOW_THRESH; + nf_frag->fqdir->timeout = IPV6_FRAG_TIMEOUT; + + res = nf_ct_frag6_sysctl_register(net); + if (res < 0) + fqdir_exit(nf_frag->fqdir); + return res; +} + +static void nf_ct_net_pre_exit(struct net *net) +{ + struct nft_ct_frag6_pernet *nf_frag = nf_frag_pernet(net); + + fqdir_pre_exit(nf_frag->fqdir); +} + +static void nf_ct_net_exit(struct net *net) +{ + struct nft_ct_frag6_pernet *nf_frag = nf_frag_pernet(net); + + nf_ct_frags6_sysctl_unregister(net); + fqdir_exit(nf_frag->fqdir); +} + +static struct pernet_operations nf_ct_net_ops = { + .init = nf_ct_net_init, + .pre_exit = nf_ct_net_pre_exit, + .exit = nf_ct_net_exit, + .id = &nf_frag_pernet_id, + .size = sizeof(struct nft_ct_frag6_pernet), +}; + +static const struct rhashtable_params nfct_rhash_params = { + .head_offset = offsetof(struct inet_frag_queue, node), + .hashfn = ip6frag_key_hashfn, + .obj_hashfn = ip6frag_obj_hashfn, + .obj_cmpfn = ip6frag_obj_cmpfn, + .automatic_shrinking = true, +}; + +int nf_ct_frag6_init(void) +{ + int ret = 0; + + nf_frags.constructor = ip6frag_init; + nf_frags.destructor = NULL; + nf_frags.qsize = sizeof(struct frag_queue); + nf_frags.frag_expire = nf_ct_frag6_expire; + nf_frags.frags_cache_name = nf_frags_cache_name; + nf_frags.rhash_params = nfct_rhash_params; + ret = inet_frags_init(&nf_frags); + if (ret) + goto out; + ret = register_pernet_subsys(&nf_ct_net_ops); + if (ret) + inet_frags_fini(&nf_frags); + +out: + return ret; +} + +void nf_ct_frag6_cleanup(void) +{ + unregister_pernet_subsys(&nf_ct_net_ops); + inet_frags_fini(&nf_frags); +} diff --git a/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c b/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c new file mode 100644 index 000000000..cb4eb1d2c --- /dev/null +++ b/net/ipv6/netfilter/nf_defrag_ipv6_hooks.c @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#include +#include +#include +#include +#endif +#include +#include + +static DEFINE_MUTEX(defrag6_mutex); + +static enum ip6_defrag_users nf_ct6_defrag_user(unsigned int hooknum, + struct sk_buff *skb) +{ + u16 zone_id = NF_CT_DEFAULT_ZONE_ID; +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (skb_nfct(skb)) { + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + zone_id = nf_ct_zone_id(nf_ct_zone(ct), CTINFO2DIR(ctinfo)); + } +#endif + if (nf_bridge_in_prerouting(skb)) + return IP6_DEFRAG_CONNTRACK_BRIDGE_IN + zone_id; + + if (hooknum == NF_INET_PRE_ROUTING) + return IP6_DEFRAG_CONNTRACK_IN + zone_id; + else + return IP6_DEFRAG_CONNTRACK_OUT + zone_id; +} + +static unsigned int ipv6_defrag(void *priv, + struct sk_buff *skb, + const 
struct nf_hook_state *state) +{ + int err; + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + /* Previously seen (loopback)? */ + if (skb_nfct(skb) && !nf_ct_is_template((struct nf_conn *)skb_nfct(skb))) + return NF_ACCEPT; + + if (skb->_nfct == IP_CT_UNTRACKED) + return NF_ACCEPT; +#endif + + err = nf_ct_frag6_gather(state->net, skb, + nf_ct6_defrag_user(state->hook, skb)); + /* queued */ + if (err == -EINPROGRESS) + return NF_STOLEN; + + return err == 0 ? NF_ACCEPT : NF_DROP; +} + +static const struct nf_hook_ops ipv6_defrag_ops[] = { + { + .hook = ipv6_defrag, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP6_PRI_CONNTRACK_DEFRAG, + }, + { + .hook = ipv6_defrag, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_CONNTRACK_DEFRAG, + }, +}; + +static void __net_exit defrag6_net_exit(struct net *net) +{ + if (net->nf.defrag_ipv6_users) { + nf_unregister_net_hooks(net, ipv6_defrag_ops, + ARRAY_SIZE(ipv6_defrag_ops)); + net->nf.defrag_ipv6_users = 0; + } +} + +static struct pernet_operations defrag6_net_ops = { + .exit = defrag6_net_exit, +}; + +static int __init nf_defrag_init(void) +{ + int ret = 0; + + ret = nf_ct_frag6_init(); + if (ret < 0) { + pr_err("nf_defrag_ipv6: can't initialize frag6.\n"); + return ret; + } + ret = register_pernet_subsys(&defrag6_net_ops); + if (ret < 0) { + pr_err("nf_defrag_ipv6: can't register pernet ops\n"); + goto cleanup_frag6; + } + return ret; + +cleanup_frag6: + nf_ct_frag6_cleanup(); + return ret; + +} + +static void __exit nf_defrag_fini(void) +{ + unregister_pernet_subsys(&defrag6_net_ops); + nf_ct_frag6_cleanup(); +} + +int nf_defrag_ipv6_enable(struct net *net) +{ + int err = 0; + + mutex_lock(&defrag6_mutex); + if (net->nf.defrag_ipv6_users == UINT_MAX) { + err = -EOVERFLOW; + goto out_unlock; + } + + if (net->nf.defrag_ipv6_users) { + net->nf.defrag_ipv6_users++; + goto out_unlock; + } + + err = nf_register_net_hooks(net, ipv6_defrag_ops, + ARRAY_SIZE(ipv6_defrag_ops)); + if (err == 0) + net->nf.defrag_ipv6_users = 1; + + out_unlock: + mutex_unlock(&defrag6_mutex); + return err; +} +EXPORT_SYMBOL_GPL(nf_defrag_ipv6_enable); + +void nf_defrag_ipv6_disable(struct net *net) +{ + mutex_lock(&defrag6_mutex); + if (net->nf.defrag_ipv6_users) { + net->nf.defrag_ipv6_users--; + if (net->nf.defrag_ipv6_users == 0) + nf_unregister_net_hooks(net, ipv6_defrag_ops, + ARRAY_SIZE(ipv6_defrag_ops)); + } + mutex_unlock(&defrag6_mutex); +} +EXPORT_SYMBOL_GPL(nf_defrag_ipv6_disable); + +module_init(nf_defrag_init); +module_exit(nf_defrag_fini); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/netfilter/nf_dup_ipv6.c b/net/ipv6/netfilter/nf_dup_ipv6.c new file mode 100644 index 000000000..a0a2de30b --- /dev/null +++ b/net/ipv6/netfilter/nf_dup_ipv6.c @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2007 by Sebastian Claßen + * (C) 2007-2010 by Jan Engelhardt + * + * Extracted from xt_TEE.c + */ +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +static bool nf_dup_ipv6_route(struct net *net, struct sk_buff *skb, + const struct in6_addr *gw, int oif) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct dst_entry *dst; + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + if (oif != -1) + fl6.flowi6_oif = oif; + + fl6.daddr = *gw; + fl6.flowlabel = (__force __be32)(((iph->flow_lbl[0] & 0xF) << 16) | + (iph->flow_lbl[1] << 8) | iph->flow_lbl[2]); + fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH; + dst = ip6_route_output(net, 
NULL, &fl6); + if (dst->error) { + dst_release(dst); + return false; + } + skb_dst_drop(skb); + skb_dst_set(skb, dst); + skb->dev = dst->dev; + skb->protocol = htons(ETH_P_IPV6); + + return true; +} + +void nf_dup_ipv6(struct net *net, struct sk_buff *skb, unsigned int hooknum, + const struct in6_addr *gw, int oif) +{ + if (this_cpu_read(nf_skb_duplicated)) + return; + skb = pskb_copy(skb, GFP_ATOMIC); + if (skb == NULL) + return; + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + nf_reset_ct(skb); + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); +#endif + if (hooknum == NF_INET_PRE_ROUTING || + hooknum == NF_INET_LOCAL_IN) { + struct ipv6hdr *iph = ipv6_hdr(skb); + --iph->hop_limit; + } + if (nf_dup_ipv6_route(net, skb, gw, oif)) { + __this_cpu_write(nf_skb_duplicated, true); + ip6_local_out(net, skb->sk, skb); + __this_cpu_write(nf_skb_duplicated, false); + } else { + kfree_skb(skb); + } +} +EXPORT_SYMBOL_GPL(nf_dup_ipv6); + +MODULE_AUTHOR("Sebastian Claßen "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("nf_dup_ipv6: IPv6 packet duplication"); +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/netfilter/nf_reject_ipv6.c b/net/ipv6/netfilter/nf_reject_ipv6.c new file mode 100644 index 000000000..433d98bbe --- /dev/null +++ b/net/ipv6/netfilter/nf_reject_ipv6.c @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +static bool nf_reject_v6_csum_ok(struct sk_buff *skb, int hook) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + int thoff; + __be16 fo; + u8 proto = ip6h->nexthdr; + + if (skb_csum_unnecessary(skb)) + return true; + + if (ip6h->payload_len && + pskb_trim_rcsum(skb, ntohs(ip6h->payload_len) + sizeof(*ip6h))) + return false; + + ip6h = ipv6_hdr(skb); + thoff = ipv6_skip_exthdr(skb, ((u8*)(ip6h+1) - skb->data), &proto, &fo); + if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0) + return false; + + if (!nf_reject_verify_csum(skb, thoff, proto)) + return true; + + return nf_ip6_checksum(skb, hook, thoff, proto) == 0; +} + +static int nf_reject_ip6hdr_validate(struct sk_buff *skb) +{ + struct ipv6hdr *hdr; + u32 pkt_len; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + return 0; + + hdr = ipv6_hdr(skb); + if (hdr->version != 6) + return 0; + + pkt_len = ntohs(hdr->payload_len); + if (pkt_len + sizeof(struct ipv6hdr) > skb->len) + return 0; + + return 1; +} + +struct sk_buff *nf_reject_skb_v6_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + struct sk_buff *nskb; + const struct tcphdr *oth; + struct tcphdr _oth; + unsigned int otcplen; + struct ipv6hdr *nip6h; + + if (!nf_reject_ip6hdr_validate(oldskb)) + return NULL; + + oth = nf_reject_ip6_tcphdr_get(oldskb, &_oth, &otcplen, hook); + if (!oth) + return NULL; + + nskb = alloc_skb(sizeof(struct ipv6hdr) + sizeof(struct tcphdr) + + LL_MAX_HEADER, GFP_ATOMIC); + if (!nskb) + return NULL; + + nskb->dev = (struct net_device *)dev; + + skb_reserve(nskb, LL_MAX_HEADER); + nip6h = nf_reject_ip6hdr_put(nskb, oldskb, IPPROTO_TCP, + net->ipv6.devconf_all->hop_limit); + nf_reject_ip6_tcphdr_put(nskb, oldskb, oth, otcplen); + nip6h->payload_len = htons(nskb->len - sizeof(struct ipv6hdr)); + + return nskb; +} +EXPORT_SYMBOL_GPL(nf_reject_skb_v6_tcp_reset); + +struct sk_buff *nf_reject_skb_v6_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; 
+ struct ipv6hdr *nip6h; + struct icmp6hdr *icmp6h; + unsigned int len; + + if (!nf_reject_ip6hdr_validate(oldskb)) + return NULL; + + /* Include "As much of invoking packet as possible without the ICMPv6 + * packet exceeding the minimum IPv6 MTU" in the ICMP payload. + */ + len = min_t(unsigned int, 1220, oldskb->len); + + if (!pskb_may_pull(oldskb, len)) + return NULL; + + if (!nf_reject_v6_csum_ok(oldskb, hook)) + return NULL; + + nskb = alloc_skb(sizeof(struct ipv6hdr) + sizeof(struct icmp6hdr) + + LL_MAX_HEADER + len, GFP_ATOMIC); + if (!nskb) + return NULL; + + nskb->dev = (struct net_device *)dev; + + skb_reserve(nskb, LL_MAX_HEADER); + nip6h = nf_reject_ip6hdr_put(nskb, oldskb, IPPROTO_ICMPV6, + net->ipv6.devconf_all->hop_limit); + + skb_reset_transport_header(nskb); + icmp6h = skb_put_zero(nskb, sizeof(struct icmp6hdr)); + icmp6h->icmp6_type = ICMPV6_DEST_UNREACH; + icmp6h->icmp6_code = code; + + skb_put_data(nskb, skb_network_header(oldskb), len); + nip6h->payload_len = htons(nskb->len - sizeof(struct ipv6hdr)); + + icmp6h->icmp6_cksum = + csum_ipv6_magic(&nip6h->saddr, &nip6h->daddr, + nskb->len - sizeof(struct ipv6hdr), + IPPROTO_ICMPV6, + csum_partial(icmp6h, + nskb->len - sizeof(struct ipv6hdr), + 0)); + + return nskb; +} +EXPORT_SYMBOL_GPL(nf_reject_skb_v6_unreach); + +const struct tcphdr *nf_reject_ip6_tcphdr_get(struct sk_buff *oldskb, + struct tcphdr *otcph, + unsigned int *otcplen, int hook) +{ + const struct ipv6hdr *oip6h = ipv6_hdr(oldskb); + u8 proto; + __be16 frag_off; + int tcphoff; + + proto = oip6h->nexthdr; + tcphoff = ipv6_skip_exthdr(oldskb, ((u8 *)(oip6h + 1) - oldskb->data), + &proto, &frag_off); + + if ((tcphoff < 0) || (tcphoff > oldskb->len)) { + pr_debug("Cannot get TCP header.\n"); + return NULL; + } + + *otcplen = oldskb->len - tcphoff; + + /* IP header checks: fragment, too short. */ + if (proto != IPPROTO_TCP || *otcplen < sizeof(struct tcphdr)) { + pr_debug("proto(%d) != IPPROTO_TCP or too short (len = %d)\n", + proto, *otcplen); + return NULL; + } + + otcph = skb_header_pointer(oldskb, tcphoff, sizeof(struct tcphdr), + otcph); + if (otcph == NULL) + return NULL; + + /* No RST for RST. */ + if (otcph->rst) { + pr_debug("RST is set\n"); + return NULL; + } + + /* Check checksum. 
*/ + if (nf_ip6_checksum(oldskb, hook, tcphoff, IPPROTO_TCP)) { + pr_debug("TCP checksum is invalid\n"); + return NULL; + } + + return otcph; +} +EXPORT_SYMBOL_GPL(nf_reject_ip6_tcphdr_get); + +struct ipv6hdr *nf_reject_ip6hdr_put(struct sk_buff *nskb, + const struct sk_buff *oldskb, + __u8 protocol, int hoplimit) +{ + struct ipv6hdr *ip6h; + const struct ipv6hdr *oip6h = ipv6_hdr(oldskb); +#define DEFAULT_TOS_VALUE 0x0U + const __u8 tclass = DEFAULT_TOS_VALUE; + + skb_put(nskb, sizeof(struct ipv6hdr)); + skb_reset_network_header(nskb); + ip6h = ipv6_hdr(nskb); + ip6_flow_hdr(ip6h, tclass, 0); + ip6h->hop_limit = hoplimit; + ip6h->nexthdr = protocol; + ip6h->saddr = oip6h->daddr; + ip6h->daddr = oip6h->saddr; + + nskb->protocol = htons(ETH_P_IPV6); + + return ip6h; +} +EXPORT_SYMBOL_GPL(nf_reject_ip6hdr_put); + +void nf_reject_ip6_tcphdr_put(struct sk_buff *nskb, + const struct sk_buff *oldskb, + const struct tcphdr *oth, unsigned int otcplen) +{ + struct tcphdr *tcph; + int needs_ack; + + skb_reset_transport_header(nskb); + tcph = skb_put(nskb, sizeof(struct tcphdr)); + /* Truncate to length (no data) */ + tcph->doff = sizeof(struct tcphdr)/4; + tcph->source = oth->dest; + tcph->dest = oth->source; + + if (oth->ack) { + needs_ack = 0; + tcph->seq = oth->ack_seq; + tcph->ack_seq = 0; + } else { + needs_ack = 1; + tcph->ack_seq = htonl(ntohl(oth->seq) + oth->syn + oth->fin + + otcplen - (oth->doff<<2)); + tcph->seq = 0; + } + + /* Reset flags */ + ((u_int8_t *)tcph)[13] = 0; + tcph->rst = 1; + tcph->ack = needs_ack; + tcph->window = 0; + tcph->urg_ptr = 0; + tcph->check = 0; + + /* Adjust TCP checksum */ + tcph->check = csum_ipv6_magic(&ipv6_hdr(nskb)->saddr, + &ipv6_hdr(nskb)->daddr, + sizeof(struct tcphdr), IPPROTO_TCP, + csum_partial(tcph, + sizeof(struct tcphdr), 0)); +} +EXPORT_SYMBOL_GPL(nf_reject_ip6_tcphdr_put); + +static int nf_reject6_fill_skb_dst(struct sk_buff *skb_in) +{ + struct dst_entry *dst = NULL; + struct flowi fl; + + memset(&fl, 0, sizeof(struct flowi)); + fl.u.ip6.daddr = ipv6_hdr(skb_in)->saddr; + nf_ip6_route(dev_net(skb_in->dev), &dst, &fl, false); + if (!dst) + return -1; + + skb_dst_set(skb_in, dst); + return 0; +} + +void nf_send_reset6(struct net *net, struct sock *sk, struct sk_buff *oldskb, + int hook) +{ + struct sk_buff *nskb; + struct tcphdr _otcph; + const struct tcphdr *otcph; + unsigned int otcplen, hh_len; + const struct ipv6hdr *oip6h = ipv6_hdr(oldskb); + struct ipv6hdr *ip6h; + struct dst_entry *dst = NULL; + struct flowi6 fl6; + + if ((!(ipv6_addr_type(&oip6h->saddr) & IPV6_ADDR_UNICAST)) || + (!(ipv6_addr_type(&oip6h->daddr) & IPV6_ADDR_UNICAST))) { + pr_debug("addr is not unicast.\n"); + return; + } + + otcph = nf_reject_ip6_tcphdr_get(oldskb, &_otcph, &otcplen, hook); + if (!otcph) + return; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_TCP; + fl6.saddr = oip6h->daddr; + fl6.daddr = oip6h->saddr; + fl6.fl6_sport = otcph->dest; + fl6.fl6_dport = otcph->source; + + if (hook == NF_INET_PRE_ROUTING || hook == NF_INET_INGRESS) { + nf_ip6_route(net, &dst, flowi6_to_flowi(&fl6), false); + if (!dst) + return; + skb_dst_set(oldskb, dst); + } + + fl6.flowi6_oif = l3mdev_master_ifindex(skb_dst(oldskb)->dev); + fl6.flowi6_mark = IP6_REPLY_MARK(net, oldskb->mark); + security_skb_classify_flow(oldskb, flowi6_to_flowi_common(&fl6)); + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + dst_release(dst); + return; + } + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), NULL, 0); + if (IS_ERR(dst)) + return; + + hh_len = 
(dst->dev->hard_header_len + 15)&~15; + nskb = alloc_skb(hh_len + 15 + dst->header_len + sizeof(struct ipv6hdr) + + sizeof(struct tcphdr) + dst->trailer_len, + GFP_ATOMIC); + + if (!nskb) { + net_dbg_ratelimited("cannot alloc skb\n"); + dst_release(dst); + return; + } + + skb_dst_set(nskb, dst); + + nskb->mark = fl6.flowi6_mark; + + skb_reserve(nskb, hh_len + dst->header_len); + ip6h = nf_reject_ip6hdr_put(nskb, oldskb, IPPROTO_TCP, + ip6_dst_hoplimit(dst)); + nf_reject_ip6_tcphdr_put(nskb, oldskb, otcph, otcplen); + + nf_ct_attach(nskb, oldskb); + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + /* If we use ip6_local_out for bridged traffic, the MAC source on + * the RST will be ours, instead of the destination's. This confuses + * some routers/firewalls, and they drop the packet. So we need to + * build the eth header using the original destination's MAC as the + * source, and send the RST packet directly. + */ + if (nf_bridge_info_exists(oldskb)) { + struct ethhdr *oeth = eth_hdr(oldskb); + struct net_device *br_indev; + + br_indev = nf_bridge_get_physindev(oldskb, net); + if (!br_indev) { + kfree_skb(nskb); + return; + } + + nskb->dev = br_indev; + nskb->protocol = htons(ETH_P_IPV6); + ip6h->payload_len = htons(sizeof(struct tcphdr)); + if (dev_hard_header(nskb, nskb->dev, ntohs(nskb->protocol), + oeth->h_source, oeth->h_dest, nskb->len) < 0) { + kfree_skb(nskb); + return; + } + dev_queue_xmit(nskb); + } else +#endif + ip6_local_out(net, sk, nskb); +} +EXPORT_SYMBOL_GPL(nf_send_reset6); + +static bool reject6_csum_ok(struct sk_buff *skb, int hook) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + int thoff; + __be16 fo; + u8 proto; + + if (skb_csum_unnecessary(skb)) + return true; + + proto = ip6h->nexthdr; + thoff = ipv6_skip_exthdr(skb, ((u8 *)(ip6h + 1) - skb->data), &proto, &fo); + + if (thoff < 0 || thoff >= skb->len || (fo & htons(~0x7)) != 0) + return false; + + if (!nf_reject_verify_csum(skb, thoff, proto)) + return true; + + return nf_ip6_checksum(skb, hook, thoff, proto) == 0; +} + +void nf_send_unreach6(struct net *net, struct sk_buff *skb_in, + unsigned char code, unsigned int hooknum) +{ + if (!reject6_csum_ok(skb_in, hooknum)) + return; + + if (hooknum == NF_INET_LOCAL_OUT && skb_in->dev == NULL) + skb_in->dev = net->loopback_dev; + + if ((hooknum == NF_INET_PRE_ROUTING || hooknum == NF_INET_INGRESS) && + nf_reject6_fill_skb_dst(skb_in) < 0) + return; + + icmpv6_send(skb_in, ICMPV6_DEST_UNREACH, code, 0); +} +EXPORT_SYMBOL_GPL(nf_send_unreach6); + +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/netfilter/nf_socket_ipv6.c b/net/ipv6/netfilter/nf_socket_ipv6.c new file mode 100644 index 000000000..a7690ec62 --- /dev/null +++ b/net/ipv6/netfilter/nf_socket_ipv6.c @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007-2008 BalaBit IT Ltd. 
+ * Author: Krisztian Kovacs + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +static int +extract_icmp6_fields(const struct sk_buff *skb, + unsigned int outside_hdrlen, + int *protocol, + const struct in6_addr **raddr, + const struct in6_addr **laddr, + __be16 *rport, + __be16 *lport, + struct ipv6hdr *ipv6_var) +{ + const struct ipv6hdr *inside_iph; + struct icmp6hdr *icmph, _icmph; + __be16 *ports, _ports[2]; + u8 inside_nexthdr; + __be16 inside_fragoff; + int inside_hdrlen; + + icmph = skb_header_pointer(skb, outside_hdrlen, + sizeof(_icmph), &_icmph); + if (icmph == NULL) + return 1; + + if (icmph->icmp6_type & ICMPV6_INFOMSG_MASK) + return 1; + + inside_iph = skb_header_pointer(skb, outside_hdrlen + sizeof(_icmph), + sizeof(*ipv6_var), ipv6_var); + if (inside_iph == NULL) + return 1; + inside_nexthdr = inside_iph->nexthdr; + + inside_hdrlen = ipv6_skip_exthdr(skb, outside_hdrlen + sizeof(_icmph) + + sizeof(*ipv6_var), + &inside_nexthdr, &inside_fragoff); + if (inside_hdrlen < 0) + return 1; /* hjm: Packet has no/incomplete transport layer headers. */ + + if (inside_nexthdr != IPPROTO_TCP && + inside_nexthdr != IPPROTO_UDP) + return 1; + + ports = skb_header_pointer(skb, inside_hdrlen, + sizeof(_ports), &_ports); + if (ports == NULL) + return 1; + + /* the inside IP packet is the one quoted from our side, thus + * its saddr is the local address */ + *protocol = inside_nexthdr; + *laddr = &inside_iph->saddr; + *lport = ports[0]; + *raddr = &inside_iph->daddr; + *rport = ports[1]; + + return 0; +} + +static struct sock * +nf_socket_get_sock_v6(struct net *net, struct sk_buff *skb, int doff, + const u8 protocol, + const struct in6_addr *saddr, const struct in6_addr *daddr, + const __be16 sport, const __be16 dport, + const struct net_device *in) +{ + switch (protocol) { + case IPPROTO_TCP: + return inet6_lookup(net, net->ipv4.tcp_death_row.hashinfo, + skb, doff, saddr, sport, daddr, dport, + in->ifindex); + case IPPROTO_UDP: + return udp6_lib_lookup(net, saddr, sport, daddr, dport, + in->ifindex); + } + + return NULL; +} + +struct sock *nf_sk_lookup_slow_v6(struct net *net, const struct sk_buff *skb, + const struct net_device *indev) +{ + __be16 dport, sport; + const struct in6_addr *daddr = NULL, *saddr = NULL; + struct ipv6hdr *iph = ipv6_hdr(skb), ipv6_var; + struct sk_buff *data_skb = NULL; + int doff = 0; + int thoff = 0, tproto; + + tproto = ipv6_find_hdr(skb, &thoff, -1, NULL, NULL); + if (tproto < 0) { + pr_debug("unable to find transport header in IPv6 packet, dropping\n"); + return NULL; + } + + if (tproto == IPPROTO_UDP || tproto == IPPROTO_TCP) { + struct tcphdr _hdr; + struct udphdr *hp; + + hp = skb_header_pointer(skb, thoff, tproto == IPPROTO_UDP ? + sizeof(*hp) : sizeof(_hdr), &_hdr); + if (hp == NULL) + return NULL; + + saddr = &iph->saddr; + sport = hp->source; + daddr = &iph->daddr; + dport = hp->dest; + data_skb = (struct sk_buff *)skb; + doff = tproto == IPPROTO_TCP ? 
+ thoff + __tcp_hdrlen((struct tcphdr *)hp) : + thoff + sizeof(*hp); + + } else if (tproto == IPPROTO_ICMPV6) { + if (extract_icmp6_fields(skb, thoff, &tproto, &saddr, &daddr, + &sport, &dport, &ipv6_var)) + return NULL; + } else { + return NULL; + } + + return nf_socket_get_sock_v6(net, data_skb, doff, tproto, saddr, daddr, + sport, dport, indev); +} +EXPORT_SYMBOL_GPL(nf_sk_lookup_slow_v6); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Krisztian Kovacs, Balazs Scheidler"); +MODULE_DESCRIPTION("Netfilter IPv6 socket lookup infrastructure"); diff --git a/net/ipv6/netfilter/nf_tproxy_ipv6.c b/net/ipv6/netfilter/nf_tproxy_ipv6.c new file mode 100644 index 000000000..52f828bb5 --- /dev/null +++ b/net/ipv6/netfilter/nf_tproxy_ipv6.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include + +const struct in6_addr * +nf_tproxy_laddr6(struct sk_buff *skb, const struct in6_addr *user_laddr, + const struct in6_addr *daddr) +{ + struct inet6_dev *indev; + struct inet6_ifaddr *ifa; + struct in6_addr *laddr; + + if (!ipv6_addr_any(user_laddr)) + return user_laddr; + laddr = NULL; + + indev = __in6_dev_get(skb->dev); + if (indev) { + read_lock_bh(&indev->lock); + list_for_each_entry(ifa, &indev->addr_list, if_list) { + if (ifa->flags & (IFA_F_TENTATIVE | IFA_F_DEPRECATED)) + continue; + + laddr = &ifa->addr; + break; + } + read_unlock_bh(&indev->lock); + } + + return laddr ? laddr : daddr; +} +EXPORT_SYMBOL_GPL(nf_tproxy_laddr6); + +struct sock * +nf_tproxy_handle_time_wait6(struct sk_buff *skb, int tproto, int thoff, + struct net *net, + const struct in6_addr *laddr, + const __be16 lport, + struct sock *sk) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct tcphdr _hdr, *hp; + + hp = skb_header_pointer(skb, thoff, sizeof(_hdr), &_hdr); + if (hp == NULL) { + inet_twsk_put(inet_twsk(sk)); + return NULL; + } + + if (hp->syn && !hp->rst && !hp->ack && !hp->fin) { + /* SYN to a TIME_WAIT socket, we'd rather redirect it + * to a listener socket if there's one */ + struct sock *sk2; + + sk2 = nf_tproxy_get_sock_v6(net, skb, thoff, tproto, + &iph->saddr, + nf_tproxy_laddr6(skb, laddr, &iph->daddr), + hp->source, + lport ? 
lport : hp->dest, + skb->dev, NF_TPROXY_LOOKUP_LISTENER); + if (sk2) { + nf_tproxy_twsk_deschedule_put(inet_twsk(sk)); + sk = sk2; + } + } + + return sk; +} +EXPORT_SYMBOL_GPL(nf_tproxy_handle_time_wait6); + +struct sock * +nf_tproxy_get_sock_v6(struct net *net, struct sk_buff *skb, int thoff, + const u8 protocol, + const struct in6_addr *saddr, const struct in6_addr *daddr, + const __be16 sport, const __be16 dport, + const struct net_device *in, + const enum nf_tproxy_lookup_t lookup_type) +{ + struct inet_hashinfo *hinfo = net->ipv4.tcp_death_row.hashinfo; + struct sock *sk; + + switch (protocol) { + case IPPROTO_TCP: { + struct tcphdr _hdr, *hp; + + hp = skb_header_pointer(skb, thoff, + sizeof(struct tcphdr), &_hdr); + if (hp == NULL) + return NULL; + + switch (lookup_type) { + case NF_TPROXY_LOOKUP_LISTENER: + sk = inet6_lookup_listener(net, hinfo, skb, + thoff + __tcp_hdrlen(hp), + saddr, sport, + daddr, ntohs(dport), + in->ifindex, 0); + + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + /* NOTE: we return listeners even if bound to + * 0.0.0.0, those are filtered out in + * xt_socket, since xt_TPROXY needs 0 bound + * listeners too + */ + break; + case NF_TPROXY_LOOKUP_ESTABLISHED: + sk = __inet6_lookup_established(net, hinfo, saddr, sport, daddr, + ntohs(dport), in->ifindex, 0); + break; + default: + BUG(); + } + break; + } + case IPPROTO_UDP: + sk = udp6_lib_lookup(net, saddr, sport, daddr, dport, + in->ifindex); + if (sk) { + int connected = (sk->sk_state == TCP_ESTABLISHED); + int wildcard = ipv6_addr_any(&sk->sk_v6_rcv_saddr); + + /* NOTE: we return listeners even if bound to + * 0.0.0.0, those are filtered out in + * xt_socket, since xt_TPROXY needs 0 bound + * listeners too + */ + if ((lookup_type == NF_TPROXY_LOOKUP_ESTABLISHED && (!connected || wildcard)) || + (lookup_type == NF_TPROXY_LOOKUP_LISTENER && connected)) { + sock_put(sk); + sk = NULL; + } + } + break; + default: + WARN_ON(1); + sk = NULL; + } + + pr_debug("tproxy socket lookup: proto %u %pI6:%u -> %pI6:%u, lookup type: %d, sock %p\n", + protocol, saddr, ntohs(sport), daddr, ntohs(dport), lookup_type, sk); + + return sk; +} +EXPORT_SYMBOL_GPL(nf_tproxy_get_sock_v6); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Balazs Scheidler, Krisztian Kovacs"); +MODULE_DESCRIPTION("Netfilter IPv6 transparent proxy support"); diff --git a/net/ipv6/netfilter/nft_dup_ipv6.c b/net/ipv6/netfilter/nft_dup_ipv6.c new file mode 100644 index 000000000..70a405b40 --- /dev/null +++ b/net/ipv6/netfilter/nft_dup_ipv6.c @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_dup_ipv6 { + u8 sreg_addr; + u8 sreg_dev; +}; + +static void nft_dup_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_dup_ipv6 *priv = nft_expr_priv(expr); + struct in6_addr *gw = (struct in6_addr *)®s->data[priv->sreg_addr]; + int oif = priv->sreg_dev ? 
regs->data[priv->sreg_dev] : -1; + + nf_dup_ipv6(nft_net(pkt), pkt->skb, nft_hook(pkt), gw, oif); +} + +static int nft_dup_ipv6_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_dup_ipv6 *priv = nft_expr_priv(expr); + int err; + + if (tb[NFTA_DUP_SREG_ADDR] == NULL) + return -EINVAL; + + err = nft_parse_register_load(tb[NFTA_DUP_SREG_ADDR], &priv->sreg_addr, + sizeof(struct in6_addr)); + if (err < 0) + return err; + + if (tb[NFTA_DUP_SREG_DEV]) + err = nft_parse_register_load(tb[NFTA_DUP_SREG_DEV], + &priv->sreg_dev, sizeof(int)); + + return err; +} + +static int nft_dup_ipv6_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_dup_ipv6 *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_DUP_SREG_ADDR, priv->sreg_addr)) + goto nla_put_failure; + if (priv->sreg_dev && + nft_dump_register(skb, NFTA_DUP_SREG_DEV, priv->sreg_dev)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_dup_ipv6_type; +static const struct nft_expr_ops nft_dup_ipv6_ops = { + .type = &nft_dup_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_dup_ipv6)), + .eval = nft_dup_ipv6_eval, + .init = nft_dup_ipv6_init, + .dump = nft_dup_ipv6_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nla_policy nft_dup_ipv6_policy[NFTA_DUP_MAX + 1] = { + [NFTA_DUP_SREG_ADDR] = { .type = NLA_U32 }, + [NFTA_DUP_SREG_DEV] = { .type = NLA_U32 }, +}; + +static struct nft_expr_type nft_dup_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "dup", + .ops = &nft_dup_ipv6_ops, + .policy = nft_dup_ipv6_policy, + .maxattr = NFTA_DUP_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_dup_ipv6_module_init(void) +{ + return nft_register_expr(&nft_dup_ipv6_type); +} + +static void __exit nft_dup_ipv6_module_exit(void) +{ + nft_unregister_expr(&nft_dup_ipv6_type); +} + +module_init(nft_dup_ipv6_module_init); +module_exit(nft_dup_ipv6_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "dup"); +MODULE_DESCRIPTION("IPv6 nftables packet duplication support"); diff --git a/net/ipv6/netfilter/nft_fib_ipv6.c b/net/ipv6/netfilter/nft_fib_ipv6.c new file mode 100644 index 000000000..36dc14b34 --- /dev/null +++ b/net/ipv6/netfilter/nft_fib_ipv6.c @@ -0,0 +1,283 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static int get_ifindex(const struct net_device *dev) +{ + return dev ? dev->ifindex : 0; +} + +static int nft_fib6_flowi_init(struct flowi6 *fl6, const struct nft_fib *priv, + const struct nft_pktinfo *pkt, + const struct net_device *dev, + struct ipv6hdr *iph) +{ + int lookup_flags = 0; + + if (priv->flags & NFTA_FIB_F_DADDR) { + fl6->daddr = iph->daddr; + fl6->saddr = iph->saddr; + } else { + if (nft_hook(pkt) == NF_INET_FORWARD && + priv->flags & NFTA_FIB_F_IIF) + fl6->flowi6_iif = nft_out(pkt)->ifindex; + + fl6->daddr = iph->saddr; + fl6->saddr = iph->daddr; + } + + if (ipv6_addr_type(&fl6->daddr) & IPV6_ADDR_LINKLOCAL) { + lookup_flags |= RT6_LOOKUP_F_IFACE; + fl6->flowi6_oif = get_ifindex(dev ? 
dev : pkt->skb->dev); + } else if (priv->flags & NFTA_FIB_F_IIF) { + fl6->flowi6_l3mdev = l3mdev_master_ifindex_rcu(dev); + } + + if (ipv6_addr_type(&fl6->saddr) & IPV6_ADDR_UNICAST) + lookup_flags |= RT6_LOOKUP_F_HAS_SADDR; + + if (priv->flags & NFTA_FIB_F_MARK) + fl6->flowi6_mark = pkt->skb->mark; + + fl6->flowlabel = (*(__be32 *)iph) & IPV6_FLOWINFO_MASK; + + return lookup_flags; +} + +static u32 __nft_fib6_eval_type(const struct nft_fib *priv, + const struct nft_pktinfo *pkt, + struct ipv6hdr *iph) +{ + const struct net_device *dev = NULL; + int route_err, addrtype; + struct rt6_info *rt; + struct flowi6 fl6 = { + .flowi6_iif = LOOPBACK_IFINDEX, + .flowi6_proto = pkt->tprot, + .flowi6_uid = sock_net_uid(nft_net(pkt), NULL), + }; + u32 ret = 0; + + if (priv->flags & NFTA_FIB_F_IIF) + dev = nft_in(pkt); + else if (priv->flags & NFTA_FIB_F_OIF) + dev = nft_out(pkt); + + nft_fib6_flowi_init(&fl6, priv, pkt, dev, iph); + + if (dev && nf_ipv6_chk_addr(nft_net(pkt), &fl6.daddr, dev, true)) + ret = RTN_LOCAL; + + route_err = nf_ip6_route(nft_net(pkt), (struct dst_entry **)&rt, + flowi6_to_flowi(&fl6), false); + if (route_err) + goto err; + + if (rt->rt6i_flags & RTF_REJECT) { + route_err = rt->dst.error; + dst_release(&rt->dst); + goto err; + } + + if (ipv6_anycast_destination((struct dst_entry *)rt, &fl6.daddr)) + ret = RTN_ANYCAST; + else if (!dev && rt->rt6i_flags & RTF_LOCAL) + ret = RTN_LOCAL; + + dst_release(&rt->dst); + + if (ret) + return ret; + + addrtype = ipv6_addr_type(&fl6.daddr); + + if (addrtype & IPV6_ADDR_MULTICAST) + return RTN_MULTICAST; + if (addrtype & IPV6_ADDR_UNICAST) + return RTN_UNICAST; + + return RTN_UNSPEC; + err: + switch (route_err) { + case -EINVAL: + return RTN_BLACKHOLE; + case -EACCES: + return RTN_PROHIBIT; + case -EAGAIN: + return RTN_THROW; + default: + break; + } + + return RTN_UNREACHABLE; +} + +void nft_fib6_eval_type(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + int noff = skb_network_offset(pkt->skb); + u32 *dest = ®s->data[priv->dreg]; + struct ipv6hdr *iph, _iph; + + iph = skb_header_pointer(pkt->skb, noff, sizeof(_iph), &_iph); + if (!iph) { + regs->verdict.code = NFT_BREAK; + return; + } + + *dest = __nft_fib6_eval_type(priv, pkt, iph); +} +EXPORT_SYMBOL_GPL(nft_fib6_eval_type); + +static bool nft_fib_v6_skip_icmpv6(const struct sk_buff *skb, u8 next, const struct ipv6hdr *iph) +{ + if (likely(next != IPPROTO_ICMPV6)) + return false; + + if (ipv6_addr_type(&iph->saddr) != IPV6_ADDR_ANY) + return false; + + return ipv6_addr_type(&iph->daddr) & IPV6_ADDR_LINKLOCAL; +} + +void nft_fib6_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + int noff = skb_network_offset(pkt->skb); + const struct net_device *oif = NULL; + u32 *dest = ®s->data[priv->dreg]; + struct ipv6hdr *iph, _iph; + struct flowi6 fl6 = { + .flowi6_iif = LOOPBACK_IFINDEX, + .flowi6_proto = pkt->tprot, + .flowi6_uid = sock_net_uid(nft_net(pkt), NULL), + }; + struct rt6_info *rt; + int lookup_flags; + + if (priv->flags & NFTA_FIB_F_IIF) + oif = nft_in(pkt); + else if (priv->flags & NFTA_FIB_F_OIF) + oif = nft_out(pkt); + + iph = skb_header_pointer(pkt->skb, noff, sizeof(_iph), &_iph); + if (!iph) { + regs->verdict.code = NFT_BREAK; + return; + } + + lookup_flags = nft_fib6_flowi_init(&fl6, priv, pkt, oif, iph); + + if (nft_hook(pkt) == NF_INET_PRE_ROUTING || + nft_hook(pkt) == NF_INET_INGRESS) { + if 
(nft_fib_is_loopback(pkt->skb, nft_in(pkt)) || + nft_fib_v6_skip_icmpv6(pkt->skb, pkt->tprot, iph)) { + nft_fib_store_result(dest, priv, nft_in(pkt)); + return; + } + } + + *dest = 0; + rt = (void *)ip6_route_lookup(nft_net(pkt), &fl6, pkt->skb, + lookup_flags); + if (rt->dst.error) + goto put_rt_err; + + /* Should not see RTF_LOCAL here */ + if (rt->rt6i_flags & (RTF_REJECT | RTF_ANYCAST | RTF_LOCAL)) + goto put_rt_err; + + if (oif && oif != rt->rt6i_idev->dev && + l3mdev_master_ifindex_rcu(rt->rt6i_idev->dev) != oif->ifindex) + goto put_rt_err; + + nft_fib_store_result(dest, priv, rt->rt6i_idev->dev); + put_rt_err: + ip6_rt_put(rt); +} +EXPORT_SYMBOL_GPL(nft_fib6_eval); + +static struct nft_expr_type nft_fib6_type; + +static const struct nft_expr_ops nft_fib6_type_ops = { + .type = &nft_fib6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib6_eval_type, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static const struct nft_expr_ops nft_fib6_ops = { + .type = &nft_fib6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib6_eval, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static const struct nft_expr_ops * +nft_fib6_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + enum nft_fib_result result; + + if (!tb[NFTA_FIB_RESULT]) + return ERR_PTR(-EINVAL); + + result = ntohl(nla_get_be32(tb[NFTA_FIB_RESULT])); + + switch (result) { + case NFT_FIB_RESULT_OIF: + return &nft_fib6_ops; + case NFT_FIB_RESULT_OIFNAME: + return &nft_fib6_ops; + case NFT_FIB_RESULT_ADDRTYPE: + return &nft_fib6_type_ops; + default: + return ERR_PTR(-EOPNOTSUPP); + } +} + +static struct nft_expr_type nft_fib6_type __read_mostly = { + .name = "fib", + .select_ops = nft_fib6_select_ops, + .policy = nft_fib_policy, + .maxattr = NFTA_FIB_MAX, + .family = NFPROTO_IPV6, + .owner = THIS_MODULE, +}; + +static int __init nft_fib6_module_init(void) +{ + return nft_register_expr(&nft_fib6_type); +} + +static void __exit nft_fib6_module_exit(void) +{ + nft_unregister_expr(&nft_fib6_type); +} +module_init(nft_fib6_module_init); +module_exit(nft_fib6_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_ALIAS_NFT_AF_EXPR(10, "fib"); +MODULE_DESCRIPTION("nftables fib / ipv6 route lookup support"); diff --git a/net/ipv6/netfilter/nft_reject_ipv6.c b/net/ipv6/netfilter/nft_reject_ipv6.c new file mode 100644 index 000000000..5c61294f4 --- /dev/null +++ b/net/ipv6/netfilter/nft_reject_ipv6.c @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2013 Eric Leblond + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void nft_reject_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_reject *priv = nft_expr_priv(expr); + + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nf_send_unreach6(nft_net(pkt), pkt->skb, priv->icmp_code, + nft_hook(pkt)); + break; + case NFT_REJECT_TCP_RST: + nf_send_reset6(nft_net(pkt), nft_sk(pkt), pkt->skb, + nft_hook(pkt)); + break; + default: + break; + } + + regs->verdict.code = NF_DROP; +} + +static struct nft_expr_type nft_reject_ipv6_type; +static const struct nft_expr_ops nft_reject_ipv6_ops = { + .type = 
&nft_reject_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_reject)), + .eval = nft_reject_ipv6_eval, + .init = nft_reject_init, + .dump = nft_reject_dump, + .validate = nft_reject_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_reject_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "reject", + .ops = &nft_reject_ipv6_ops, + .policy = nft_reject_policy, + .maxattr = NFTA_REJECT_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_reject_ipv6_module_init(void) +{ + return nft_register_expr(&nft_reject_ipv6_type); +} + +static void __exit nft_reject_ipv6_module_exit(void) +{ + nft_unregister_expr(&nft_reject_ipv6_type); +} + +module_init(nft_reject_ipv6_module_init); +module_exit(nft_reject_ipv6_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_AF_EXPR(AF_INET6, "reject"); +MODULE_DESCRIPTION("IPv6 packet rejection for nftables"); diff --git a/net/ipv6/output_core.c b/net/ipv6/output_core.c new file mode 100644 index 000000000..2685c3f15 --- /dev/null +++ b/net/ipv6/output_core.c @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IPv6 library code, needed by static components when full IPv6 support is + * not configured or static. These functions are needed by GSO/GRO implementation. + */ +#include +#include +#include +#include +#include +#include +#include + +static u32 __ipv6_select_ident(struct net *net, + const struct in6_addr *dst, + const struct in6_addr *src) +{ + u32 id; + + do { + id = get_random_u32(); + } while (!id); + + return id; +} + +/* This function exists only for tap drivers that must support broken + * clients requesting UFO without specifying an IPv6 fragment ID. + * + * This is similar to ipv6_select_ident() but we use an independent hash + * seed to limit information leakage. + * + * The network header must be set before calling this. 
+ */ +__be32 ipv6_proxy_select_ident(struct net *net, struct sk_buff *skb) +{ + struct in6_addr buf[2]; + struct in6_addr *addrs; + u32 id; + + addrs = skb_header_pointer(skb, + skb_network_offset(skb) + + offsetof(struct ipv6hdr, saddr), + sizeof(buf), buf); + if (!addrs) + return 0; + + id = __ipv6_select_ident(net, &addrs[1], &addrs[0]); + return htonl(id); +} +EXPORT_SYMBOL_GPL(ipv6_proxy_select_ident); + +__be32 ipv6_select_ident(struct net *net, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + u32 id; + + id = __ipv6_select_ident(net, daddr, saddr); + return htonl(id); +} +EXPORT_SYMBOL(ipv6_select_ident); + +int ip6_find_1stfragopt(struct sk_buff *skb, u8 **nexthdr) +{ + unsigned int offset = sizeof(struct ipv6hdr); + unsigned int packet_len = skb_tail_pointer(skb) - + skb_network_header(skb); + int found_rhdr = 0; + *nexthdr = &ipv6_hdr(skb)->nexthdr; + + while (offset <= packet_len) { + struct ipv6_opt_hdr *exthdr; + + switch (**nexthdr) { + + case NEXTHDR_HOP: + break; + case NEXTHDR_ROUTING: + found_rhdr = 1; + break; + case NEXTHDR_DEST: +#if IS_ENABLED(CONFIG_IPV6_MIP6) + if (ipv6_find_tlv(skb, offset, IPV6_TLV_HAO) >= 0) + break; +#endif + if (found_rhdr) + return offset; + break; + default: + return offset; + } + + if (offset + sizeof(struct ipv6_opt_hdr) > packet_len) + return -EINVAL; + + exthdr = (struct ipv6_opt_hdr *)(skb_network_header(skb) + + offset); + offset += ipv6_optlen(exthdr); + if (offset > IPV6_MAXPLEN) + return -EINVAL; + *nexthdr = &exthdr->nexthdr; + } + + return -EINVAL; +} +EXPORT_SYMBOL(ip6_find_1stfragopt); + +#if IS_ENABLED(CONFIG_IPV6) +int ip6_dst_hoplimit(struct dst_entry *dst) +{ + int hoplimit = dst_metric_raw(dst, RTAX_HOPLIMIT); + if (hoplimit == 0) { + struct net_device *dev = dst->dev; + struct inet6_dev *idev; + + rcu_read_lock(); + idev = __in6_dev_get(dev); + if (idev) + hoplimit = idev->cnf.hop_limit; + else + hoplimit = dev_net(dev)->ipv6.devconf_all->hop_limit; + rcu_read_unlock(); + } + return hoplimit; +} +EXPORT_SYMBOL(ip6_dst_hoplimit); +#endif + +int __ip6_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + int len; + + len = skb->len - sizeof(struct ipv6hdr); + if (len > IPV6_MAXPLEN) + len = 0; + ipv6_hdr(skb)->payload_len = htons(len); + IP6CB(skb)->nhoff = offsetof(struct ipv6hdr, nexthdr); + + /* if egress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip6_out(sk, skb); + if (unlikely(!skb)) + return 0; + + skb->protocol = htons(ETH_P_IPV6); + + return nf_hook(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + net, sk, skb, NULL, skb_dst(skb)->dev, + dst_output); +} +EXPORT_SYMBOL_GPL(__ip6_local_out); + +int ip6_local_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + int err; + + err = __ip6_local_out(net, sk, skb); + if (likely(err == 1)) + err = dst_output(net, sk, skb); + + return err; +} +EXPORT_SYMBOL_GPL(ip6_local_out); diff --git a/net/ipv6/ping.c b/net/ipv6/ping.c new file mode 100644 index 000000000..a5d7d1915 --- /dev/null +++ b/net/ipv6/ping.c @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * "Ping" sockets + * + * Based on ipv4/ping.c code. 
+ * + * Authors: Lorenzo Colitti (IPv6 support) + * Vasiliy Kulikov / Openwall (IPv4 implementation, for Linux 2.6), + * Pavel Kankovsky (IPv4 implementation, for Linux 2.4.32) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Compatibility glue so we can support IPv6 when it's compiled as a module */ +static int dummy_ipv6_recv_error(struct sock *sk, struct msghdr *msg, int len, + int *addr_len) +{ + return -EAFNOSUPPORT; +} +static void dummy_ip6_datagram_recv_ctl(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ +} +static int dummy_icmpv6_err_convert(u8 type, u8 code, int *err) +{ + return -EAFNOSUPPORT; +} +static void dummy_ipv6_icmp_error(struct sock *sk, struct sk_buff *skb, int err, + __be16 port, u32 info, u8 *payload) {} +static int dummy_ipv6_chk_addr(struct net *net, const struct in6_addr *addr, + const struct net_device *dev, int strict) +{ + return 0; +} + +static int ping_v6_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from __ip6_datagram_connect() and + * intended to prevent BPF program called below from accessing + * bytes that are out of the bound specified by user in addr_len. + */ + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr); +} + +static int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct icmp6hdr user_icmph; + int addr_type; + struct in6_addr *daddr; + int oif = 0; + struct flowi6 fl6; + int err; + struct dst_entry *dst; + struct rt6_info *rt; + struct pingfakehdr pfh; + struct ipcm6_cookie ipc6; + + err = ping_common_sendmsg(AF_INET6, msg, len, &user_icmph, + sizeof(user_icmph)); + if (err) + return err; + + memset(&fl6, 0, sizeof(fl6)); + + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_in6 *, u, msg->msg_name); + if (msg->msg_namelen < sizeof(*u)) + return -EINVAL; + if (u->sin6_family != AF_INET6) { + return -EAFNOSUPPORT; + } + daddr = &(u->sin6_addr); + if (np->sndflow) + fl6.flowlabel = u->sin6_flowinfo & IPV6_FLOWINFO_MASK; + if (__ipv6_addr_needs_scope_id(ipv6_addr_type(daddr))) + oif = u->sin6_scope_id; + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + daddr = &sk->sk_v6_daddr; + fl6.flowlabel = np->flow_label; + } + + if (!oif) + oif = sk->sk_bound_dev_if; + + if (!oif) + oif = np->sticky_pktinfo.ipi6_ifindex; + + if (!oif && ipv6_addr_is_multicast(daddr)) + oif = np->mcast_oif; + else if (!oif) + oif = np->ucast_oif; + + addr_type = ipv6_addr_type(daddr); + if ((__ipv6_addr_needs_scope_id(addr_type) && !oif) || + (addr_type & IPV6_ADDR_MAPPED) || + (oif && sk->sk_bound_dev_if && oif != sk->sk_bound_dev_if && + l3mdev_master_ifindex_by_index(sock_net(sk), oif) != sk->sk_bound_dev_if)) + return -EINVAL; + + ipcm6_init_sk(&ipc6, np); + ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags); + ipc6.sockc.mark = READ_ONCE(sk->sk_mark); + + fl6.flowi6_oif = oif; + + if (msg->msg_controllen) { + struct ipv6_txoptions opt = {}; + + opt.tot_len = sizeof(opt); + ipc6.opt = &opt; + + err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, &fl6, &ipc6); + if (err < 0) + return err; + + /* Changes to txoptions and flow info are not implemented, yet. + * Drop the options. 
+ */ + ipc6.opt = NULL; + } + + fl6.flowi6_proto = IPPROTO_ICMPV6; + fl6.saddr = np->saddr; + fl6.daddr = *daddr; + fl6.flowi6_mark = ipc6.sockc.mark; + fl6.flowi6_uid = sk->sk_uid; + fl6.fl6_icmp_type = user_icmph.icmp6_type; + fl6.fl6_icmp_code = user_icmph.icmp6_code; + security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6)); + + fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel); + + dst = ip6_sk_dst_lookup_flow(sk, &fl6, daddr, false); + if (IS_ERR(dst)) + return PTR_ERR(dst); + rt = (struct rt6_info *) dst; + + if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) + fl6.flowi6_oif = np->mcast_oif; + else if (!fl6.flowi6_oif) + fl6.flowi6_oif = np->ucast_oif; + + pfh.icmph.type = user_icmph.icmp6_type; + pfh.icmph.code = user_icmph.icmp6_code; + pfh.icmph.checksum = 0; + pfh.icmph.un.echo.id = inet->inet_sport; + pfh.icmph.un.echo.sequence = user_icmph.icmp6_sequence; + pfh.msg = msg; + pfh.wcheck = 0; + pfh.family = AF_INET6; + + if (ipc6.hlimit < 0) + ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst); + + lock_sock(sk); + err = ip6_append_data(sk, ping_getfrag, &pfh, len, + sizeof(struct icmp6hdr), &ipc6, &fl6, rt, + MSG_DONTWAIT); + + if (err) { + ICMP6_INC_STATS(sock_net(sk), rt->rt6i_idev, + ICMP6_MIB_OUTERRORS); + ip6_flush_pending_frames(sk); + } else { + icmpv6_push_pending_frames(sk, &fl6, + (struct icmp6hdr *)&pfh.icmph, len); + } + release_sock(sk); + + dst_release(dst); + + if (err) + return err; + + return len; +} + +struct proto pingv6_prot = { + .name = "PINGv6", + .owner = THIS_MODULE, + .init = ping_init_sock, + .close = ping_close, + .pre_connect = ping_v6_pre_connect, + .connect = ip6_datagram_connect_v6_only, + .disconnect = __udp_disconnect, + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .sendmsg = ping_v6_sendmsg, + .recvmsg = ping_recvmsg, + .bind = ping_bind, + .backlog_rcv = ping_queue_rcv_skb, + .hash = ping_hash, + .unhash = ping_unhash, + .get_port = ping_get_port, + .put_port = ping_unhash, + .obj_size = sizeof(struct raw6_sock), +}; +EXPORT_SYMBOL_GPL(pingv6_prot); + +static struct inet_protosw pingv6_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_ICMPV6, + .prot = &pingv6_prot, + .ops = &inet6_sockraw_ops, + .flags = INET_PROTOSW_REUSE, +}; + +#ifdef CONFIG_PROC_FS +static void *ping_v6_seq_start(struct seq_file *seq, loff_t *pos) +{ + return ping_seq_start(seq, pos, AF_INET6); +} + +static int ping_v6_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_puts(seq, IPV6_SEQ_DGRAM_HEADER); + } else { + int bucket = ((struct ping_iter_state *) seq->private)->bucket; + struct inet_sock *inet = inet_sk(v); + __u16 srcp = ntohs(inet->inet_sport); + __u16 destp = ntohs(inet->inet_dport); + ip6_dgram_sock_seq_show(seq, v, srcp, destp, bucket); + } + return 0; +} + +static const struct seq_operations ping_v6_seq_ops = { + .start = ping_v6_seq_start, + .show = ping_v6_seq_show, + .next = ping_seq_next, + .stop = ping_seq_stop, +}; + +static int __net_init ping_v6_proc_init_net(struct net *net) +{ + if (!proc_create_net("icmp6", 0444, net->proc_net, &ping_v6_seq_ops, + sizeof(struct ping_iter_state))) + return -ENOMEM; + return 0; +} + +static void __net_exit ping_v6_proc_exit_net(struct net *net) +{ + remove_proc_entry("icmp6", net->proc_net); +} + +static struct pernet_operations ping_v6_net_ops = { + .init = ping_v6_proc_init_net, + .exit = ping_v6_proc_exit_net, +}; +#endif + +int __init pingv6_init(void) +{ +#ifdef CONFIG_PROC_FS + int ret = register_pernet_subsys(&ping_v6_net_ops); + 
if (ret) + return ret; +#endif + pingv6_ops.ipv6_recv_error = ipv6_recv_error; + pingv6_ops.ip6_datagram_recv_common_ctl = ip6_datagram_recv_common_ctl; + pingv6_ops.ip6_datagram_recv_specific_ctl = + ip6_datagram_recv_specific_ctl; + pingv6_ops.icmpv6_err_convert = icmpv6_err_convert; + pingv6_ops.ipv6_icmp_error = ipv6_icmp_error; + pingv6_ops.ipv6_chk_addr = ipv6_chk_addr; + return inet6_register_protosw(&pingv6_protosw); +} + +/* This never gets called because it's not possible to unload the ipv6 module, + * but just in case. + */ +void pingv6_exit(void) +{ + pingv6_ops.ipv6_recv_error = dummy_ipv6_recv_error; + pingv6_ops.ip6_datagram_recv_common_ctl = dummy_ip6_datagram_recv_ctl; + pingv6_ops.ip6_datagram_recv_specific_ctl = dummy_ip6_datagram_recv_ctl; + pingv6_ops.icmpv6_err_convert = dummy_icmpv6_err_convert; + pingv6_ops.ipv6_icmp_error = dummy_ipv6_icmp_error; + pingv6_ops.ipv6_chk_addr = dummy_ipv6_chk_addr; +#ifdef CONFIG_PROC_FS + unregister_pernet_subsys(&ping_v6_net_ops); +#endif + inet6_unregister_protosw(&pingv6_protosw); +} diff --git a/net/ipv6/proc.c b/net/ipv6/proc.c new file mode 100644 index 000000000..d6306aa46 --- /dev/null +++ b/net/ipv6/proc.c @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * This file implements the various access functions for the + * PROC file system. This is very similar to the IPv4 version, + * except it reports the sockets in the INET6 address family. + * + * Authors: David S. Miller (davem@caip.rutgers.edu) + * YOSHIFUJI Hideaki + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX4(a, b, c, d) \ + max_t(u32, max_t(u32, a, b), max_t(u32, c, d)) +#define SNMP_MIB_MAX MAX4(UDP_MIB_MAX, TCP_MIB_MAX, \ + IPSTATS_MIB_MAX, ICMP_MIB_MAX) + +static int sockstat6_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = seq->private; + + seq_printf(seq, "TCP6: inuse %d\n", + sock_prot_inuse_get(net, &tcpv6_prot)); + seq_printf(seq, "UDP6: inuse %d\n", + sock_prot_inuse_get(net, &udpv6_prot)); + seq_printf(seq, "UDPLITE6: inuse %d\n", + sock_prot_inuse_get(net, &udplitev6_prot)); + seq_printf(seq, "RAW6: inuse %d\n", + sock_prot_inuse_get(net, &rawv6_prot)); + seq_printf(seq, "FRAG6: inuse %u memory %lu\n", + atomic_read(&net->ipv6.fqdir->rhashtable.nelems), + frag_mem_limit(net->ipv6.fqdir)); + return 0; +} + +static const struct snmp_mib snmp6_ipstats_list[] = { +/* ipv6 mib according to RFC 2465 */ + SNMP_MIB_ITEM("Ip6InReceives", IPSTATS_MIB_INPKTS), + SNMP_MIB_ITEM("Ip6InHdrErrors", IPSTATS_MIB_INHDRERRORS), + SNMP_MIB_ITEM("Ip6InTooBigErrors", IPSTATS_MIB_INTOOBIGERRORS), + SNMP_MIB_ITEM("Ip6InNoRoutes", IPSTATS_MIB_INNOROUTES), + SNMP_MIB_ITEM("Ip6InAddrErrors", IPSTATS_MIB_INADDRERRORS), + SNMP_MIB_ITEM("Ip6InUnknownProtos", IPSTATS_MIB_INUNKNOWNPROTOS), + SNMP_MIB_ITEM("Ip6InTruncatedPkts", IPSTATS_MIB_INTRUNCATEDPKTS), + SNMP_MIB_ITEM("Ip6InDiscards", IPSTATS_MIB_INDISCARDS), + SNMP_MIB_ITEM("Ip6InDelivers", IPSTATS_MIB_INDELIVERS), + SNMP_MIB_ITEM("Ip6OutForwDatagrams", IPSTATS_MIB_OUTFORWDATAGRAMS), + SNMP_MIB_ITEM("Ip6OutRequests", IPSTATS_MIB_OUTPKTS), + SNMP_MIB_ITEM("Ip6OutDiscards", IPSTATS_MIB_OUTDISCARDS), + SNMP_MIB_ITEM("Ip6OutNoRoutes", IPSTATS_MIB_OUTNOROUTES), + SNMP_MIB_ITEM("Ip6ReasmTimeout", 
IPSTATS_MIB_REASMTIMEOUT), + SNMP_MIB_ITEM("Ip6ReasmReqds", IPSTATS_MIB_REASMREQDS), + SNMP_MIB_ITEM("Ip6ReasmOKs", IPSTATS_MIB_REASMOKS), + SNMP_MIB_ITEM("Ip6ReasmFails", IPSTATS_MIB_REASMFAILS), + SNMP_MIB_ITEM("Ip6FragOKs", IPSTATS_MIB_FRAGOKS), + SNMP_MIB_ITEM("Ip6FragFails", IPSTATS_MIB_FRAGFAILS), + SNMP_MIB_ITEM("Ip6FragCreates", IPSTATS_MIB_FRAGCREATES), + SNMP_MIB_ITEM("Ip6InMcastPkts", IPSTATS_MIB_INMCASTPKTS), + SNMP_MIB_ITEM("Ip6OutMcastPkts", IPSTATS_MIB_OUTMCASTPKTS), + SNMP_MIB_ITEM("Ip6InOctets", IPSTATS_MIB_INOCTETS), + SNMP_MIB_ITEM("Ip6OutOctets", IPSTATS_MIB_OUTOCTETS), + SNMP_MIB_ITEM("Ip6InMcastOctets", IPSTATS_MIB_INMCASTOCTETS), + SNMP_MIB_ITEM("Ip6OutMcastOctets", IPSTATS_MIB_OUTMCASTOCTETS), + SNMP_MIB_ITEM("Ip6InBcastOctets", IPSTATS_MIB_INBCASTOCTETS), + SNMP_MIB_ITEM("Ip6OutBcastOctets", IPSTATS_MIB_OUTBCASTOCTETS), + /* IPSTATS_MIB_CSUMERRORS is not relevant in IPv6 (no checksum) */ + SNMP_MIB_ITEM("Ip6InNoECTPkts", IPSTATS_MIB_NOECTPKTS), + SNMP_MIB_ITEM("Ip6InECT1Pkts", IPSTATS_MIB_ECT1PKTS), + SNMP_MIB_ITEM("Ip6InECT0Pkts", IPSTATS_MIB_ECT0PKTS), + SNMP_MIB_ITEM("Ip6InCEPkts", IPSTATS_MIB_CEPKTS), + SNMP_MIB_SENTINEL +}; + +static const struct snmp_mib snmp6_icmp6_list[] = { +/* icmpv6 mib according to RFC 2466 */ + SNMP_MIB_ITEM("Icmp6InMsgs", ICMP6_MIB_INMSGS), + SNMP_MIB_ITEM("Icmp6InErrors", ICMP6_MIB_INERRORS), + SNMP_MIB_ITEM("Icmp6OutMsgs", ICMP6_MIB_OUTMSGS), + SNMP_MIB_ITEM("Icmp6OutErrors", ICMP6_MIB_OUTERRORS), + SNMP_MIB_ITEM("Icmp6InCsumErrors", ICMP6_MIB_CSUMERRORS), + SNMP_MIB_SENTINEL +}; + +/* RFC 4293 v6 ICMPMsgStatsTable; named items for RFC 2466 compatibility */ +static const char *const icmp6type2name[256] = { + [ICMPV6_DEST_UNREACH] = "DestUnreachs", + [ICMPV6_PKT_TOOBIG] = "PktTooBigs", + [ICMPV6_TIME_EXCEED] = "TimeExcds", + [ICMPV6_PARAMPROB] = "ParmProblems", + [ICMPV6_ECHO_REQUEST] = "Echos", + [ICMPV6_ECHO_REPLY] = "EchoReplies", + [ICMPV6_MGM_QUERY] = "GroupMembQueries", + [ICMPV6_MGM_REPORT] = "GroupMembResponses", + [ICMPV6_MGM_REDUCTION] = "GroupMembReductions", + [ICMPV6_MLD2_REPORT] = "MLDv2Reports", + [NDISC_ROUTER_ADVERTISEMENT] = "RouterAdvertisements", + [NDISC_ROUTER_SOLICITATION] = "RouterSolicits", + [NDISC_NEIGHBOUR_ADVERTISEMENT] = "NeighborAdvertisements", + [NDISC_NEIGHBOUR_SOLICITATION] = "NeighborSolicits", + [NDISC_REDIRECT] = "Redirects", +}; + + +static const struct snmp_mib snmp6_udp6_list[] = { + SNMP_MIB_ITEM("Udp6InDatagrams", UDP_MIB_INDATAGRAMS), + SNMP_MIB_ITEM("Udp6NoPorts", UDP_MIB_NOPORTS), + SNMP_MIB_ITEM("Udp6InErrors", UDP_MIB_INERRORS), + SNMP_MIB_ITEM("Udp6OutDatagrams", UDP_MIB_OUTDATAGRAMS), + SNMP_MIB_ITEM("Udp6RcvbufErrors", UDP_MIB_RCVBUFERRORS), + SNMP_MIB_ITEM("Udp6SndbufErrors", UDP_MIB_SNDBUFERRORS), + SNMP_MIB_ITEM("Udp6InCsumErrors", UDP_MIB_CSUMERRORS), + SNMP_MIB_ITEM("Udp6IgnoredMulti", UDP_MIB_IGNOREDMULTI), + SNMP_MIB_ITEM("Udp6MemErrors", UDP_MIB_MEMERRORS), + SNMP_MIB_SENTINEL +}; + +static const struct snmp_mib snmp6_udplite6_list[] = { + SNMP_MIB_ITEM("UdpLite6InDatagrams", UDP_MIB_INDATAGRAMS), + SNMP_MIB_ITEM("UdpLite6NoPorts", UDP_MIB_NOPORTS), + SNMP_MIB_ITEM("UdpLite6InErrors", UDP_MIB_INERRORS), + SNMP_MIB_ITEM("UdpLite6OutDatagrams", UDP_MIB_OUTDATAGRAMS), + SNMP_MIB_ITEM("UdpLite6RcvbufErrors", UDP_MIB_RCVBUFERRORS), + SNMP_MIB_ITEM("UdpLite6SndbufErrors", UDP_MIB_SNDBUFERRORS), + SNMP_MIB_ITEM("UdpLite6InCsumErrors", UDP_MIB_CSUMERRORS), + SNMP_MIB_ITEM("UdpLite6MemErrors", UDP_MIB_MEMERRORS), + SNMP_MIB_SENTINEL +}; + +static void snmp6_seq_show_icmpv6msg(struct 
seq_file *seq, atomic_long_t *smib) +{ + char name[32]; + int i; + + /* print by name -- deprecated items */ + for (i = 0; i < ICMP6MSG_MIB_MAX; i++) { + int icmptype; + const char *p; + + icmptype = i & 0xff; + p = icmp6type2name[icmptype]; + if (!p) /* don't print un-named types here */ + continue; + snprintf(name, sizeof(name), "Icmp6%s%s", + i & 0x100 ? "Out" : "In", p); + seq_printf(seq, "%-32s\t%lu\n", name, + atomic_long_read(smib + i)); + } + + /* print by number (nonzero only) - ICMPMsgStat format */ + for (i = 0; i < ICMP6MSG_MIB_MAX; i++) { + unsigned long val; + + val = atomic_long_read(smib + i); + if (!val) + continue; + snprintf(name, sizeof(name), "Icmp6%sType%u", + i & 0x100 ? "Out" : "In", i & 0xff); + seq_printf(seq, "%-32s\t%lu\n", name, val); + } +} + +/* can be called either with percpu mib (pcpumib != NULL), + * or shared one (smib != NULL) + */ +static void snmp6_seq_show_item(struct seq_file *seq, void __percpu *pcpumib, + atomic_long_t *smib, + const struct snmp_mib *itemlist) +{ + unsigned long buff[SNMP_MIB_MAX]; + int i; + + if (pcpumib) { + memset(buff, 0, sizeof(unsigned long) * SNMP_MIB_MAX); + + snmp_get_cpu_field_batch(buff, itemlist, pcpumib); + for (i = 0; itemlist[i].name; i++) + seq_printf(seq, "%-32s\t%lu\n", + itemlist[i].name, buff[i]); + } else { + for (i = 0; itemlist[i].name; i++) + seq_printf(seq, "%-32s\t%lu\n", itemlist[i].name, + atomic_long_read(smib + itemlist[i].entry)); + } +} + +static void snmp6_seq_show_item64(struct seq_file *seq, void __percpu *mib, + const struct snmp_mib *itemlist, size_t syncpoff) +{ + u64 buff64[SNMP_MIB_MAX]; + int i; + + memset(buff64, 0, sizeof(u64) * SNMP_MIB_MAX); + + snmp_get_cpu_field64_batch(buff64, itemlist, mib, syncpoff); + for (i = 0; itemlist[i].name; i++) + seq_printf(seq, "%-32s\t%llu\n", itemlist[i].name, buff64[i]); +} + +static int snmp6_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = (struct net *)seq->private; + + snmp6_seq_show_item64(seq, net->mib.ipv6_statistics, + snmp6_ipstats_list, offsetof(struct ipstats_mib, syncp)); + snmp6_seq_show_item(seq, net->mib.icmpv6_statistics, + NULL, snmp6_icmp6_list); + snmp6_seq_show_icmpv6msg(seq, net->mib.icmpv6msg_statistics->mibs); + snmp6_seq_show_item(seq, net->mib.udp_stats_in6, + NULL, snmp6_udp6_list); + snmp6_seq_show_item(seq, net->mib.udplite_stats_in6, + NULL, snmp6_udplite6_list); + return 0; +} + +static int snmp6_dev_seq_show(struct seq_file *seq, void *v) +{ + struct inet6_dev *idev = (struct inet6_dev *)seq->private; + + seq_printf(seq, "%-32s\t%u\n", "ifIndex", idev->dev->ifindex); + snmp6_seq_show_item64(seq, idev->stats.ipv6, + snmp6_ipstats_list, offsetof(struct ipstats_mib, syncp)); + snmp6_seq_show_item(seq, NULL, idev->stats.icmpv6dev->mibs, + snmp6_icmp6_list); + snmp6_seq_show_icmpv6msg(seq, idev->stats.icmpv6msgdev->mibs); + return 0; +} + +int snmp6_register_dev(struct inet6_dev *idev) +{ + struct proc_dir_entry *p; + struct net *net; + + if (!idev || !idev->dev) + return -EINVAL; + + net = dev_net(idev->dev); + if (!net->mib.proc_net_devsnmp6) + return -ENOENT; + + p = proc_create_single_data(idev->dev->name, 0444, + net->mib.proc_net_devsnmp6, snmp6_dev_seq_show, idev); + if (!p) + return -ENOMEM; + + idev->stats.proc_dir_entry = p; + return 0; +} + +int snmp6_unregister_dev(struct inet6_dev *idev) +{ + struct net *net = dev_net(idev->dev); + if (!net->mib.proc_net_devsnmp6) + return -ENOENT; + if (!idev->stats.proc_dir_entry) + return -EINVAL; + proc_remove(idev->stats.proc_dir_entry); + 
idev->stats.proc_dir_entry = NULL; + return 0; +} + +static int __net_init ipv6_proc_init_net(struct net *net) +{ + if (!proc_create_net_single("sockstat6", 0444, net->proc_net, + sockstat6_seq_show, NULL)) + return -ENOMEM; + + if (!proc_create_net_single("snmp6", 0444, net->proc_net, + snmp6_seq_show, NULL)) + goto proc_snmp6_fail; + + net->mib.proc_net_devsnmp6 = proc_mkdir("dev_snmp6", net->proc_net); + if (!net->mib.proc_net_devsnmp6) + goto proc_dev_snmp6_fail; + return 0; + +proc_dev_snmp6_fail: + remove_proc_entry("snmp6", net->proc_net); +proc_snmp6_fail: + remove_proc_entry("sockstat6", net->proc_net); + return -ENOMEM; +} + +static void __net_exit ipv6_proc_exit_net(struct net *net) +{ + remove_proc_entry("sockstat6", net->proc_net); + remove_proc_entry("dev_snmp6", net->proc_net); + remove_proc_entry("snmp6", net->proc_net); +} + +static struct pernet_operations ipv6_proc_ops = { + .init = ipv6_proc_init_net, + .exit = ipv6_proc_exit_net, +}; + +int __init ipv6_misc_proc_init(void) +{ + return register_pernet_subsys(&ipv6_proc_ops); +} + +void ipv6_misc_proc_exit(void) +{ + unregister_pernet_subsys(&ipv6_proc_ops); +} diff --git a/net/ipv6/protocol.c b/net/ipv6/protocol.c new file mode 100644 index 000000000..d4b1806ba --- /dev/null +++ b/net/ipv6/protocol.c @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * PF_INET6 protocol dispatch tables. + * + * Authors: Pedro Roque + */ + +/* + * Changes: + * + * Vince Laviano (vince@cs.stanford.edu) 16 May 2001 + * - Removed unused variable 'inet6_protocol_base' + * - Modified inet6_del_protocol() to correctly maintain copy bit. + */ +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) +struct inet6_protocol __rcu *inet6_protos[MAX_INET_PROTOS] __read_mostly; +EXPORT_SYMBOL(inet6_protos); + +int inet6_add_protocol(const struct inet6_protocol *prot, unsigned char protocol) +{ + return !cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol], + NULL, prot) ? 0 : -1; +} +EXPORT_SYMBOL(inet6_add_protocol); + +int inet6_del_protocol(const struct inet6_protocol *prot, unsigned char protocol) +{ + int ret; + + ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol], + prot, NULL) == prot) ? 0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(inet6_del_protocol); +#endif + +const struct net_offload __rcu *inet6_offloads[MAX_INET_PROTOS] __read_mostly; +EXPORT_SYMBOL(inet6_offloads); + +int inet6_add_offload(const struct net_offload *prot, unsigned char protocol) +{ + return !cmpxchg((const struct net_offload **)&inet6_offloads[protocol], + NULL, prot) ? 0 : -1; +} +EXPORT_SYMBOL(inet6_add_offload); + +int inet6_del_offload(const struct net_offload *prot, unsigned char protocol) +{ + int ret; + + ret = (cmpxchg((const struct net_offload **)&inet6_offloads[protocol], + prot, NULL) == prot) ? 
0 : -1; + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(inet6_del_offload); diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c new file mode 100644 index 000000000..dc31752a7 --- /dev/null +++ b/net/ipv6/raw.c @@ -0,0 +1,1319 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * RAW sockets for IPv6 + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Adapted from linux/net/ipv4/raw.c + * + * Fixes: + * Hideaki YOSHIFUJI : sin6_scope_id support + * YOSHIFUJI,H.@USAGI : raw checksum (RFC2292(bis) compliance) + * Kazunori MIYAZAWA @USAGI: change process style to use ip6_append_data + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6_MIP6) +#include +#endif +#include + +#include +#include +#include + +#include +#include +#include + +#define ICMPV6_HDRLEN 4 /* ICMPv6 header, RFC 4443 Section 2.1 */ + +struct raw_hashinfo raw_v6_hashinfo; +EXPORT_SYMBOL_GPL(raw_v6_hashinfo); + +bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num, + const struct in6_addr *loc_addr, + const struct in6_addr *rmt_addr, int dif, int sdif) +{ + if (inet_sk(sk)->inet_num != num || + !net_eq(sock_net(sk), net) || + (!ipv6_addr_any(&sk->sk_v6_daddr) && + !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) || + !raw_sk_bound_dev_eq(net, sk->sk_bound_dev_if, + dif, sdif)) + return false; + + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr) || + ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr) || + (ipv6_addr_is_multicast(loc_addr) && + inet6_mc_check(sk, loc_addr, rmt_addr))) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(raw_v6_match); + +/* + * 0 - deliver + * 1 - block + */ +static int icmpv6_filter(const struct sock *sk, const struct sk_buff *skb) +{ + struct icmp6hdr _hdr; + const struct icmp6hdr *hdr; + + /* We require only the four bytes of the ICMPv6 header, not any + * additional bytes of message body in "struct icmp6hdr". + */ + hdr = skb_header_pointer(skb, skb_transport_offset(skb), + ICMPV6_HDRLEN, &_hdr); + if (hdr) { + const __u32 *data = &raw6_sk(sk)->filter.data[0]; + unsigned int type = hdr->icmp6_type; + + return (data[type >> 5] & (1U << (type & 31))) != 0; + } + return 1; +} + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +typedef int mh_filter_t(struct sock *sock, struct sk_buff *skb); + +static mh_filter_t __rcu *mh_filter __read_mostly; + +int rawv6_mh_filter_register(mh_filter_t filter) +{ + rcu_assign_pointer(mh_filter, filter); + return 0; +} +EXPORT_SYMBOL(rawv6_mh_filter_register); + +int rawv6_mh_filter_unregister(mh_filter_t filter) +{ + RCU_INIT_POINTER(mh_filter, NULL); + synchronize_rcu(); + return 0; +} +EXPORT_SYMBOL(rawv6_mh_filter_unregister); + +#endif + +/* + * demultiplex raw sockets. + * (should consider queueing the skb in the sock receive_queue + * without calling rawv6.c) + * + * Caller owns SKB so we must make clones. 
+ */ +static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr) +{ + struct net *net = dev_net(skb->dev); + const struct in6_addr *saddr; + const struct in6_addr *daddr; + struct hlist_head *hlist; + struct sock *sk; + bool delivered = false; + __u8 hash; + + saddr = &ipv6_hdr(skb)->saddr; + daddr = saddr + 1; + + hash = raw_hashfunc(net, nexthdr); + hlist = &raw_v6_hashinfo.ht[hash]; + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { + int filtered; + + if (!raw_v6_match(net, sk, nexthdr, daddr, saddr, + inet6_iif(skb), inet6_sdif(skb))) + continue; + delivered = true; + switch (nexthdr) { + case IPPROTO_ICMPV6: + filtered = icmpv6_filter(sk, skb); + break; + +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPPROTO_MH: + { + /* XXX: To validate MH only once for each packet, + * this is placed here. It should be after checking + * xfrm policy, however it doesn't. The checking xfrm + * policy is placed in rawv6_rcv() because it is + * required for each socket. + */ + mh_filter_t *filter; + + filter = rcu_dereference(mh_filter); + filtered = filter ? (*filter)(sk, skb) : 0; + break; + } +#endif + default: + filtered = 0; + break; + } + + if (filtered < 0) + break; + if (filtered == 0) { + struct sk_buff *clone = skb_clone(skb, GFP_ATOMIC); + + /* Not releasing hash table! */ + if (clone) + rawv6_rcv(sk, clone); + } + } + rcu_read_unlock(); + return delivered; +} + +bool raw6_local_deliver(struct sk_buff *skb, int nexthdr) +{ + return ipv6_raw_deliver(skb, nexthdr); +} + +/* This cleans up af_inet6 a bit. -DaveM */ +static int rawv6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct sockaddr_in6 *addr = (struct sockaddr_in6 *) uaddr; + __be32 v4addr = 0; + int addr_type; + int err; + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (addr->sin6_family != AF_INET6) + return -EINVAL; + + addr_type = ipv6_addr_type(&addr->sin6_addr); + + /* Raw sockets are IPv6 only */ + if (addr_type == IPV6_ADDR_MAPPED) + return -EADDRNOTAVAIL; + + lock_sock(sk); + + err = -EINVAL; + if (sk->sk_state != TCP_CLOSE) + goto out; + + rcu_read_lock(); + /* Check if the address belongs to the host. */ + if (addr_type != IPV6_ADDR_ANY) { + struct net_device *dev = NULL; + + if (__ipv6_addr_needs_scope_id(addr_type)) { + if (addr_len >= sizeof(struct sockaddr_in6) && + addr->sin6_scope_id) { + /* Override any existing binding, if another + * one is supplied by user. + */ + sk->sk_bound_dev_if = addr->sin6_scope_id; + } + + /* Binding to link-local address requires an interface */ + if (!sk->sk_bound_dev_if) + goto out_unlock; + } + + if (sk->sk_bound_dev_if) { + err = -ENODEV; + dev = dev_get_by_index_rcu(sock_net(sk), + sk->sk_bound_dev_if); + if (!dev) + goto out_unlock; + } + + /* ipv4 addr of the socket is invalid. Only the + * unspecified and mapped address have a v4 equivalent. 
+ */ + v4addr = LOOPBACK4_IPV6; + if (!(addr_type & IPV6_ADDR_MULTICAST) && + !ipv6_can_nonlocal_bind(sock_net(sk), inet)) { + err = -EADDRNOTAVAIL; + if (!ipv6_chk_addr(sock_net(sk), &addr->sin6_addr, + dev, 0)) { + goto out_unlock; + } + } + } + + inet->inet_rcv_saddr = inet->inet_saddr = v4addr; + sk->sk_v6_rcv_saddr = addr->sin6_addr; + if (!(addr_type & IPV6_ADDR_MULTICAST)) + np->saddr = addr->sin6_addr; + err = 0; +out_unlock: + rcu_read_unlock(); +out: + release_sock(sk); + return err; +} + +static void rawv6_err(struct sock *sk, struct sk_buff *skb, + struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + int err; + int harderr; + + /* Report error on raw socket, if: + 1. User requested recverr. + 2. Socket is connected (otherwise the error indication + is useless without recverr and error is hard. + */ + if (!np->recverr && sk->sk_state != TCP_ESTABLISHED) + return; + + harderr = icmpv6_err_convert(type, code, &err); + if (type == ICMPV6_PKT_TOOBIG) { + ip6_sk_update_pmtu(skb, sk, info); + harderr = (np->pmtudisc == IPV6_PMTUDISC_DO); + } + if (type == NDISC_REDIRECT) { + ip6_sk_redirect(skb, sk); + return; + } + if (np->recverr) { + u8 *payload = skb->data; + if (!inet->hdrincl) + payload += offset; + ipv6_icmp_error(sk, skb, err, 0, ntohl(info), payload); + } + + if (np->recverr || harderr) { + sk->sk_err = err; + sk_error_report(sk); + } +} + +void raw6_icmp_error(struct sk_buff *skb, int nexthdr, + u8 type, u8 code, int inner_offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + struct hlist_head *hlist; + struct sock *sk; + int hash; + + hash = raw_hashfunc(net, nexthdr); + hlist = &raw_v6_hashinfo.ht[hash]; + rcu_read_lock(); + sk_for_each_rcu(sk, hlist) { + /* Note: ipv6_hdr(skb) != skb->data */ + const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data; + + if (!raw_v6_match(net, sk, nexthdr, &ip6h->saddr, &ip6h->daddr, + inet6_iif(skb), inet6_iif(skb))) + continue; + rawv6_err(sk, skb, NULL, type, code, inner_offset, info); + } + rcu_read_unlock(); +} + +static inline int rawv6_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + if ((raw6_sk(sk)->checksum || rcu_access_pointer(sk->sk_filter)) && + skb_checksum_complete(skb)) { + atomic_inc(&sk->sk_drops); + kfree_skb(skb); + return NET_RX_DROP; + } + + /* Charge it to the socket. */ + skb_dst_drop(skb); + if (sock_queue_rcv_skb(sk, skb) < 0) { + kfree_skb(skb); + return NET_RX_DROP; + } + + return 0; +} + +/* + * This is next to useless... + * if we demultiplex in network layer we don't need the extra call + * just to queue the skb... 
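rawv6_err() above only queues an ICMPv6 error to the application when IPV6_RECVERR was requested (np->recverr) or the socket is connected; rawv6_recvmsg() below then hands MSG_ERRQUEUE reads to ipv6_recv_error(). A hedged user-space sketch of the consumer side (illustrative only; assumes linux/errqueue.h), after enabling the option once with int on = 1; setsockopt(fd, IPPROTO_IPV6, IPV6_RECVERR, &on, sizeof(on)):

#include <linux/errqueue.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <string.h>
#include <stdio.h>

/* Drain one queued ICMPv6 error from a raw IPv6 socket, if any. */
static void drain_icmp_error(int fd)
{
	char data[512], cbuf[512];
	struct iovec iov = { .iov_base = data, .iov_len = sizeof(data) };
	struct msghdr msg = {
		.msg_iov = &iov, .msg_iovlen = 1,
		.msg_control = cbuf, .msg_controllen = sizeof(cbuf),
	};
	struct cmsghdr *cm;

	if (recvmsg(fd, &msg, MSG_ERRQUEUE | MSG_DONTWAIT) < 0)
		return;

	for (cm = CMSG_FIRSTHDR(&msg); cm; cm = CMSG_NXTHDR(&msg, cm)) {
		if (cm->cmsg_level == SOL_IPV6 && cm->cmsg_type == IPV6_RECVERR) {
			struct sock_extended_err ee;

			memcpy(&ee, CMSG_DATA(cm), sizeof(ee));
			fprintf(stderr, "ICMPv6 error: type %d code %d errno %d\n",
				ee.ee_type, ee.ee_code, (int)ee.ee_errno);
		}
	}
}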
+ * maybe we could have the network decide upon a hint if it + * should call raw_rcv for demultiplexing + */ +int rawv6_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct inet_sock *inet = inet_sk(sk); + struct raw6_sock *rp = raw6_sk(sk); + + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) { + atomic_inc(&sk->sk_drops); + kfree_skb(skb); + return NET_RX_DROP; + } + nf_reset_ct(skb); + + if (!rp->checksum) + skb->ip_summed = CHECKSUM_UNNECESSARY; + + if (skb->ip_summed == CHECKSUM_COMPLETE) { + skb_postpull_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + if (!csum_ipv6_magic(&ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, + skb->len, inet->inet_num, skb->csum)) + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + if (!skb_csum_unnecessary(skb)) + skb->csum = ~csum_unfold(csum_ipv6_magic(&ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, + skb->len, + inet->inet_num, 0)); + + if (inet->hdrincl) { + if (skb_checksum_complete(skb)) { + atomic_inc(&sk->sk_drops); + kfree_skb(skb); + return NET_RX_DROP; + } + } + + rawv6_rcv_skb(sk, skb); + return 0; +} + + +/* + * This should be easy, if there is something there + * we return it, otherwise we block. + */ + +static int rawv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + struct sk_buff *skb; + size_t copied; + int err; + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + if (flags & MSG_ERRQUEUE) + return ipv6_recv_error(sk, msg, len, addr_len); + + if (np->rxpmtu && np->rxopt.bits.rxpmtu) + return ipv6_recv_rxpmtu(sk, msg, len, addr_len); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (copied > len) { + copied = len; + msg->msg_flags |= MSG_TRUNC; + } + + if (skb_csum_unnecessary(skb)) { + err = skb_copy_datagram_msg(skb, 0, msg, copied); + } else if (msg->msg_flags&MSG_TRUNC) { + if (__skb_checksum_complete(skb)) + goto csum_copy_err; + err = skb_copy_datagram_msg(skb, 0, msg, copied); + } else { + err = skb_copy_and_csum_datagram_msg(skb, 0, msg); + if (err == -EINVAL) + goto csum_copy_err; + } + if (err) + goto out_free; + + /* Copy the address. */ + if (sin6) { + sin6->sin6_family = AF_INET6; + sin6->sin6_port = 0; + sin6->sin6_addr = ipv6_hdr(skb)->saddr; + sin6->sin6_flowinfo = 0; + sin6->sin6_scope_id = ipv6_iface_scope_id(&sin6->sin6_addr, + inet6_iif(skb)); + *addr_len = sizeof(*sin6); + } + + sock_recv_cmsgs(msg, sk, skb); + + if (np->rxopt.all) + ip6_datagram_recv_ctl(sk, msg, skb); + + err = copied; + if (flags & MSG_TRUNC) + err = skb->len; + +out_free: + skb_free_datagram(sk, skb); +out: + return err; + +csum_copy_err: + skb_kill_datagram(sk, skb, flags); + + /* Error for blocking case is chosen to masquerade + as some normal condition. + */ + err = (flags&MSG_DONTWAIT) ? -EAGAIN : -EHOSTUNREACH; + goto out; +} + +static int rawv6_push_pending_frames(struct sock *sk, struct flowi6 *fl6, + struct raw6_sock *rp) +{ + struct ipv6_txoptions *opt; + struct sk_buff *skb; + int err = 0; + int offset; + int len; + int total_len; + __wsum tmp_csum; + __sum16 csum; + + if (!rp->checksum) + goto send; + + skb = skb_peek(&sk->sk_write_queue); + if (!skb) + goto out; + + offset = rp->offset; + total_len = inet_sk(sk)->cork.base.length; + opt = inet6_sk(sk)->cork.opt; + total_len -= opt ? 
opt->opt_flen : 0; + + if (offset >= total_len - 1) { + err = -EINVAL; + ip6_flush_pending_frames(sk); + goto out; + } + + /* should be check HW csum miyazawa */ + if (skb_queue_len(&sk->sk_write_queue) == 1) { + /* + * Only one fragment on the socket. + */ + tmp_csum = skb->csum; + } else { + struct sk_buff *csum_skb = NULL; + tmp_csum = 0; + + skb_queue_walk(&sk->sk_write_queue, skb) { + tmp_csum = csum_add(tmp_csum, skb->csum); + + if (csum_skb) + continue; + + len = skb->len - skb_transport_offset(skb); + if (offset >= len) { + offset -= len; + continue; + } + + csum_skb = skb; + } + + skb = csum_skb; + } + + offset += skb_transport_offset(skb); + err = skb_copy_bits(skb, offset, &csum, 2); + if (err < 0) { + ip6_flush_pending_frames(sk); + goto out; + } + + /* in case cksum was not initialized */ + if (unlikely(csum)) + tmp_csum = csum_sub(tmp_csum, csum_unfold(csum)); + + csum = csum_ipv6_magic(&fl6->saddr, &fl6->daddr, + total_len, fl6->flowi6_proto, tmp_csum); + + if (csum == 0 && fl6->flowi6_proto == IPPROTO_UDP) + csum = CSUM_MANGLED_0; + + BUG_ON(skb_store_bits(skb, offset, &csum, 2)); + +send: + err = ip6_push_pending_frames(sk); +out: + return err; +} + +static int rawv6_send_hdrinc(struct sock *sk, struct msghdr *msg, int length, + struct flowi6 *fl6, struct dst_entry **dstp, + unsigned int flags, const struct sockcm_cookie *sockc) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct net *net = sock_net(sk); + struct ipv6hdr *iph; + struct sk_buff *skb; + int err; + struct rt6_info *rt = (struct rt6_info *)*dstp; + int hlen = LL_RESERVED_SPACE(rt->dst.dev); + int tlen = rt->dst.dev->needed_tailroom; + + if (length > rt->dst.dev->mtu) { + ipv6_local_error(sk, EMSGSIZE, fl6, rt->dst.dev->mtu); + return -EMSGSIZE; + } + if (length < sizeof(struct ipv6hdr)) + return -EINVAL; + if (flags&MSG_PROBE) + goto out; + + skb = sock_alloc_send_skb(sk, + length + hlen + tlen + 15, + flags & MSG_DONTWAIT, &err); + if (!skb) + goto error; + skb_reserve(skb, hlen); + + skb->protocol = htons(ETH_P_IPV6); + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = sockc->mark; + skb->tstamp = sockc->transmit_time; + + skb_put(skb, length); + skb_reset_network_header(skb); + iph = ipv6_hdr(skb); + + skb->ip_summed = CHECKSUM_NONE; + + skb_setup_tx_timestamp(skb, sockc->tsflags); + + if (flags & MSG_CONFIRM) + skb_set_dst_pending_confirm(skb, 1); + + skb->transport_header = skb->network_header; + err = memcpy_from_msg(iph, msg, length); + if (err) { + err = -EFAULT; + kfree_skb(skb); + goto error; + } + + skb_dst_set(skb, &rt->dst); + *dstp = NULL; + + /* if egress device is enslaved to an L3 master device pass the + * skb to its handler for processing + */ + skb = l3mdev_ip6_out(sk, skb); + if (unlikely(!skb)) + return 0; + + /* Acquire rcu_read_lock() in case we need to use rt->rt6i_idev + * in the error path. Since skb has been freed, the dst could + * have been queued for deletion. 
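rawv6_push_pending_frames() above folds the per-fragment skb->csum values, subtracts whatever was already stored at rp->offset, and finishes with csum_ipv6_magic(), i.e. the RFC 8200 pseudo-header sum over source, destination, upper-layer length and next header. A self-contained sketch of that arithmetic in plain user-space C (byte-wise big-endian summation to sidestep endianness; the return value is in host order and would be stored into the packet with htons()); the 0 to 0xffff substitution mirrors the CSUM_MANGLED_0 case for UDP:

#include <stdint.h>
#include <stddef.h>
#include <netinet/in.h>

/* One's-complement sum over big-endian 16-bit words. */
static uint32_t sum_be16(uint32_t sum, const uint8_t *p, size_t len)
{
	size_t i;

	for (i = 0; i + 1 < len; i += 2)
		sum += ((uint32_t)p[i] << 8) | p[i + 1];
	if (len & 1)
		sum += (uint32_t)p[len - 1] << 8;
	return sum;
}

/* RFC 8200 section 8.1 upper-layer checksum: pseudo-header + payload. */
static uint16_t ipv6_upper_layer_csum(const struct in6_addr *src,
				      const struct in6_addr *dst,
				      const uint8_t *payload, size_t len,
				      uint8_t nexthdr)
{
	uint32_t sum = 0;
	uint16_t csum;

	sum = sum_be16(sum, src->s6_addr, 16);
	sum = sum_be16(sum, dst->s6_addr, 16);
	sum += (uint32_t)(len >> 16) & 0xffff;	/* upper-layer length, high 16 bits */
	sum += (uint32_t)len & 0xffff;		/* upper-layer length, low 16 bits */
	sum += nexthdr;				/* three zero bytes + next header */
	sum = sum_be16(sum, payload, len);

	while (sum >> 16)			/* fold the carries */
		sum = (sum & 0xffff) + (sum >> 16);
	csum = (uint16_t)~sum;

	/* UDP transmits an all-zero result as 0xffff. */
	return (nexthdr == IPPROTO_UDP && csum == 0) ? 0xffff : csum;
}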
+ */ + rcu_read_lock(); + IP6_UPD_PO_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUT, skb->len); + err = NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb, + NULL, rt->dst.dev, dst_output); + if (err > 0) + err = net_xmit_errno(err); + if (err) { + IP6_INC_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); + rcu_read_unlock(); + goto error_check; + } + rcu_read_unlock(); +out: + return 0; + +error: + IP6_INC_STATS(net, rt->rt6i_idev, IPSTATS_MIB_OUTDISCARDS); +error_check: + if (err == -ENOBUFS && !np->recverr) + err = 0; + return err; +} + +struct raw6_frag_vec { + struct msghdr *msg; + int hlen; + char c[4]; +}; + +static int rawv6_probe_proto_opt(struct raw6_frag_vec *rfv, struct flowi6 *fl6) +{ + int err = 0; + switch (fl6->flowi6_proto) { + case IPPROTO_ICMPV6: + rfv->hlen = 2; + err = memcpy_from_msg(rfv->c, rfv->msg, rfv->hlen); + if (!err) { + fl6->fl6_icmp_type = rfv->c[0]; + fl6->fl6_icmp_code = rfv->c[1]; + } + break; + case IPPROTO_MH: + rfv->hlen = 4; + err = memcpy_from_msg(rfv->c, rfv->msg, rfv->hlen); + if (!err) + fl6->fl6_mh_type = rfv->c[2]; + } + return err; +} + +static int raw6_getfrag(void *from, char *to, int offset, int len, int odd, + struct sk_buff *skb) +{ + struct raw6_frag_vec *rfv = from; + + if (offset < rfv->hlen) { + int copy = min(rfv->hlen - offset, len); + + if (skb->ip_summed == CHECKSUM_PARTIAL) + memcpy(to, rfv->c + offset, copy); + else + skb->csum = csum_block_add( + skb->csum, + csum_partial_copy_nocheck(rfv->c + offset, + to, copy), + odd); + + odd = 0; + offset += copy; + to += copy; + len -= copy; + + if (!len) + return 0; + } + + offset -= rfv->hlen; + + return ip_generic_getfrag(rfv->msg, to, offset, len, odd, skb); +} + +static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct ipv6_txoptions *opt_to_free = NULL; + struct ipv6_txoptions opt_space; + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + struct in6_addr *daddr, *final_p, final; + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct raw6_sock *rp = raw6_sk(sk); + struct ipv6_txoptions *opt = NULL; + struct ip6_flowlabel *flowlabel = NULL; + struct dst_entry *dst = NULL; + struct raw6_frag_vec rfv; + struct flowi6 fl6; + struct ipcm6_cookie ipc6; + int addr_len = msg->msg_namelen; + int hdrincl; + u16 proto; + int err; + + /* Rough check on arithmetic overflow, + better check is made in ip6_append_data(). + */ + if (len > INT_MAX) + return -EMSGSIZE; + + /* Mirror BSD error message compatibility */ + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* hdrincl should be READ_ONCE(inet->hdrincl) + * but READ_ONCE() doesn't work with bit fields. + * Doing this indirectly yields the same result. + */ + hdrincl = inet->hdrincl; + hdrincl = READ_ONCE(hdrincl); + + /* + * Get and verify the address. 
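The hdrincl flag latched just above selects the rawv6_send_hdrinc() path shown earlier, in which the application supplies the complete IPv6 header and the kernel transmits it verbatim. A hedged user-space sketch (illustrative only; assumes a libc that exposes the Linux-specific IPV6_HDRINCL option, whose kernel handler is in do_rawv6_setsockopt() further down):

#include <stdint.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip6.h>
#include <arpa/inet.h>

/* Send a packet whose IPv6 header we build ourselves. */
static ssize_t send_with_own_header(int fd,
				    const struct in6_addr *src,
				    const struct sockaddr_in6 *dst,
				    const void *payload, size_t payload_len,
				    uint8_t nexthdr)
{
	struct {
		struct ip6_hdr hdr;
		unsigned char data[1024];
	} pkt;
	int on = 1;

	if (payload_len > sizeof(pkt.data))
		return -1;
	if (setsockopt(fd, IPPROTO_IPV6, IPV6_HDRINCL, &on, sizeof(on)) < 0)
		return -1;

	memset(&pkt.hdr, 0, sizeof(pkt.hdr));
	pkt.hdr.ip6_flow = htonl(6u << 28);		/* version 6, tclass/flow 0 */
	pkt.hdr.ip6_plen = htons((uint16_t)payload_len);
	pkt.hdr.ip6_nxt  = nexthdr;
	pkt.hdr.ip6_hops = 64;
	pkt.hdr.ip6_src  = *src;	/* sent as-is; the kernel does not rewrite it */
	pkt.hdr.ip6_dst  = dst->sin6_addr;
	memcpy(pkt.data, payload, payload_len);

	return sendto(fd, &pkt, sizeof(pkt.hdr) + payload_len, 0,
		      (const struct sockaddr *)dst, sizeof(*dst));
}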
+ */ + memset(&fl6, 0, sizeof(fl6)); + + fl6.flowi6_mark = READ_ONCE(sk->sk_mark); + fl6.flowi6_uid = sk->sk_uid; + + ipcm6_init(&ipc6); + ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags); + ipc6.sockc.mark = fl6.flowi6_mark; + + if (sin6) { + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (sin6->sin6_family && sin6->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + /* port is the proto value [0..255] carried in nexthdr */ + proto = ntohs(sin6->sin6_port); + + if (!proto) + proto = inet->inet_num; + else if (proto != inet->inet_num && + inet->inet_num != IPPROTO_RAW) + return -EINVAL; + + if (proto > 255) + return -EINVAL; + + daddr = &sin6->sin6_addr; + if (np->sndflow) { + fl6.flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK; + if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) { + flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + } + } + + /* + * Otherwise it will be difficult to maintain + * sk->sk_dst_cache. + */ + if (sk->sk_state == TCP_ESTABLISHED && + ipv6_addr_equal(daddr, &sk->sk_v6_daddr)) + daddr = &sk->sk_v6_daddr; + + if (addr_len >= sizeof(struct sockaddr_in6) && + sin6->sin6_scope_id && + __ipv6_addr_needs_scope_id(__ipv6_addr_type(daddr))) + fl6.flowi6_oif = sin6->sin6_scope_id; + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + + proto = inet->inet_num; + daddr = &sk->sk_v6_daddr; + fl6.flowlabel = np->flow_label; + } + + if (fl6.flowi6_oif == 0) + fl6.flowi6_oif = sk->sk_bound_dev_if; + + if (msg->msg_controllen) { + opt = &opt_space; + memset(opt, 0, sizeof(struct ipv6_txoptions)); + opt->tot_len = sizeof(struct ipv6_txoptions); + ipc6.opt = opt; + + err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, &fl6, &ipc6); + if (err < 0) { + fl6_sock_release(flowlabel); + return err; + } + if ((fl6.flowlabel&IPV6_FLOWLABEL_MASK) && !flowlabel) { + flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + } + if (!(opt->opt_nflen|opt->opt_flen)) + opt = NULL; + } + if (!opt) { + opt = txopt_get(np); + opt_to_free = opt; + } + if (flowlabel) + opt = fl6_merge_options(&opt_space, flowlabel, opt); + opt = ipv6_fixup_options(&opt_space, opt); + + fl6.flowi6_proto = proto; + fl6.flowi6_mark = ipc6.sockc.mark; + + if (!hdrincl) { + rfv.msg = msg; + rfv.hlen = 0; + err = rawv6_probe_proto_opt(&rfv, &fl6); + if (err) + goto out; + } + + if (!ipv6_addr_any(daddr)) + fl6.daddr = *daddr; + else + fl6.daddr.s6_addr[15] = 0x1; /* :: means loopback (BSD'ism) */ + if (ipv6_addr_any(&fl6.saddr) && !ipv6_addr_any(&np->saddr)) + fl6.saddr = np->saddr; + + final_p = fl6_update_dst(&fl6, opt, &final); + + if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) + fl6.flowi6_oif = np->mcast_oif; + else if (!fl6.flowi6_oif) + fl6.flowi6_oif = np->ucast_oif; + security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6)); + + if (hdrincl) + fl6.flowi6_flags |= FLOWI_FLAG_KNOWN_NH; + + if (ipc6.tclass < 0) + ipc6.tclass = np->tclass; + + fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto out; + } + if (ipc6.hlimit < 0) + ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst); + + if (ipc6.dontfrag < 0) + ipc6.dontfrag = np->dontfrag; + + if (msg->msg_flags&MSG_CONFIRM) + goto do_confirm; + +back_from_confirm: + if (hdrincl) + err = rawv6_send_hdrinc(sk, msg, len, &fl6, &dst, + msg->msg_flags, &ipc6.sockc); + else { + ipc6.opt = opt; + lock_sock(sk); + err = ip6_append_data(sk, 
raw6_getfrag, &rfv, + len, 0, &ipc6, &fl6, (struct rt6_info *)dst, + msg->msg_flags); + + if (err) + ip6_flush_pending_frames(sk); + else if (!(msg->msg_flags & MSG_MORE)) + err = rawv6_push_pending_frames(sk, &fl6, rp); + release_sock(sk); + } +done: + dst_release(dst); +out: + fl6_sock_release(flowlabel); + txopt_put(opt_to_free); + return err < 0 ? err : len; +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(dst, &fl6.daddr); + if (!(msg->msg_flags & MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto done; +} + +static int rawv6_seticmpfilter(struct sock *sk, int level, int optname, + sockptr_t optval, int optlen) +{ + switch (optname) { + case ICMPV6_FILTER: + if (optlen > sizeof(struct icmp6_filter)) + optlen = sizeof(struct icmp6_filter); + if (copy_from_sockptr(&raw6_sk(sk)->filter, optval, optlen)) + return -EFAULT; + return 0; + default: + return -ENOPROTOOPT; + } + + return 0; +} + +static int rawv6_geticmpfilter(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + int len; + + switch (optname) { + case ICMPV6_FILTER: + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + if (len > sizeof(struct icmp6_filter)) + len = sizeof(struct icmp6_filter); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &raw6_sk(sk)->filter, len)) + return -EFAULT; + return 0; + default: + return -ENOPROTOOPT; + } + + return 0; +} + + +static int do_rawv6_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct raw6_sock *rp = raw6_sk(sk); + int val; + + if (optlen < sizeof(val)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + switch (optname) { + case IPV6_HDRINCL: + if (sk->sk_type != SOCK_RAW) + return -EINVAL; + inet_sk(sk)->hdrincl = !!val; + return 0; + case IPV6_CHECKSUM: + if (inet_sk(sk)->inet_num == IPPROTO_ICMPV6 && + level == IPPROTO_IPV6) { + /* + * RFC3542 tells that IPV6_CHECKSUM socket + * option in the IPPROTO_IPV6 level is not + * allowed on ICMPv6 sockets. + * If you want to set it, use IPPROTO_RAW + * level IPV6_CHECKSUM socket option + * (Linux extension). + */ + return -EINVAL; + } + + /* You may get strange result with a positive odd offset; + RFC2292bis agrees with me. 
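What rp->offset ends up meaning is visible in rawv6_push_pending_frames() earlier in this file: the computed checksum is written at that byte offset of the outgoing packet, and incoming packets are verified the same way. A hedged user-space sketch for a protocol whose 16-bit checksum sits at byte offset 12 of its header (OSPFv3 is one such protocol; needs CAP_NET_RAW); the offset must be non-negative and even, exactly as enforced just below:

#include <netinet/in.h>
#include <sys/socket.h>
#include <stdio.h>

int main(void)
{
	/* 89 is the OSPF protocol number; any non-ICMPv6 value behaves the same. */
	int fd = socket(AF_INET6, SOCK_RAW, 89);
	int offset = 12;	/* byte offset of the checksum field in the payload */

	if (fd < 0) {
		perror("socket");
		return 1;
	}

	/* Ask the kernel to fill in the checksum on send and check it on
	 * receive; this sets rp->checksum = 1 and rp->offset = 12 below. */
	if (setsockopt(fd, IPPROTO_IPV6, IPV6_CHECKSUM,
		       &offset, sizeof(offset)) < 0) {
		perror("setsockopt(IPV6_CHECKSUM)");
		return 1;
	}
	return 0;
}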
*/ + if (val > 0 && (val&1)) + return -EINVAL; + if (val < 0) { + rp->checksum = 0; + } else { + rp->checksum = 1; + rp->offset = val; + } + + return 0; + + default: + return -ENOPROTOOPT; + } +} + +static int rawv6_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + switch (level) { + case SOL_RAW: + break; + + case SOL_ICMPV6: + if (inet_sk(sk)->inet_num != IPPROTO_ICMPV6) + return -EOPNOTSUPP; + return rawv6_seticmpfilter(sk, level, optname, optval, optlen); + case SOL_IPV6: + if (optname == IPV6_CHECKSUM || + optname == IPV6_HDRINCL) + break; + fallthrough; + default: + return ipv6_setsockopt(sk, level, optname, optval, optlen); + } + + return do_rawv6_setsockopt(sk, level, optname, optval, optlen); +} + +static int do_rawv6_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct raw6_sock *rp = raw6_sk(sk); + int val, len; + + if (get_user(len, optlen)) + return -EFAULT; + + switch (optname) { + case IPV6_HDRINCL: + val = inet_sk(sk)->hdrincl; + break; + case IPV6_CHECKSUM: + /* + * We allow getsockopt() for IPPROTO_IPV6-level + * IPV6_CHECKSUM socket option on ICMPv6 sockets + * since RFC3542 is silent about it. + */ + if (rp->checksum == 0) + val = -1; + else + val = rp->offset; + break; + + default: + return -ENOPROTOOPT; + } + + len = min_t(unsigned int, sizeof(int), len); + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +static int rawv6_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + switch (level) { + case SOL_RAW: + break; + + case SOL_ICMPV6: + if (inet_sk(sk)->inet_num != IPPROTO_ICMPV6) + return -EOPNOTSUPP; + return rawv6_geticmpfilter(sk, level, optname, optval, optlen); + case SOL_IPV6: + if (optname == IPV6_CHECKSUM || + optname == IPV6_HDRINCL) + break; + fallthrough; + default: + return ipv6_getsockopt(sk, level, optname, optval, optlen); + } + + return do_rawv6_getsockopt(sk, level, optname, optval, optlen); +} + +static int rawv6_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: { + int amount = sk_wmem_alloc_get(sk); + + return put_user(amount, (int __user *)arg); + } + case SIOCINQ: { + struct sk_buff *skb; + int amount = 0; + + spin_lock_bh(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + amount = skb->len; + spin_unlock_bh(&sk->sk_receive_queue.lock); + return put_user(amount, (int __user *)arg); + } + + default: +#ifdef CONFIG_IPV6_MROUTE + return ip6mr_ioctl(sk, cmd, (void __user *)arg); +#else + return -ENOIOCTLCMD; +#endif + } +} + +#ifdef CONFIG_COMPAT +static int compat_rawv6_ioctl(struct sock *sk, unsigned int cmd, unsigned long arg) +{ + switch (cmd) { + case SIOCOUTQ: + case SIOCINQ: + return -ENOIOCTLCMD; + default: +#ifdef CONFIG_IPV6_MROUTE + return ip6mr_compat_ioctl(sk, cmd, compat_ptr(arg)); +#else + return -ENOIOCTLCMD; +#endif + } +} +#endif + +static void rawv6_close(struct sock *sk, long timeout) +{ + if (inet_sk(sk)->inet_num == IPPROTO_RAW) + ip6_ra_control(sk, -1); + ip6mr_sk_done(sk); + sk_common_release(sk); +} + +static void raw6_destroy(struct sock *sk) +{ + lock_sock(sk); + ip6_flush_pending_frames(sk); + release_sock(sk); +} + +static int rawv6_init_sk(struct sock *sk) +{ + struct raw6_sock *rp = raw6_sk(sk); + + switch (inet_sk(sk)->inet_num) { + case IPPROTO_ICMPV6: + rp->checksum = 1; + rp->offset = 2; + break; + case IPPROTO_MH: + 
rp->checksum = 1; + rp->offset = 4; + break; + default: + break; + } + return 0; +} + +struct proto rawv6_prot = { + .name = "RAWv6", + .owner = THIS_MODULE, + .close = rawv6_close, + .destroy = raw6_destroy, + .connect = ip6_datagram_connect_v6_only, + .disconnect = __udp_disconnect, + .ioctl = rawv6_ioctl, + .init = rawv6_init_sk, + .setsockopt = rawv6_setsockopt, + .getsockopt = rawv6_getsockopt, + .sendmsg = rawv6_sendmsg, + .recvmsg = rawv6_recvmsg, + .bind = rawv6_bind, + .backlog_rcv = rawv6_rcv_skb, + .hash = raw_hash_sk, + .unhash = raw_unhash_sk, + .obj_size = sizeof(struct raw6_sock), + .useroffset = offsetof(struct raw6_sock, filter), + .usersize = sizeof_field(struct raw6_sock, filter), + .h.raw_hash = &raw_v6_hashinfo, +#ifdef CONFIG_COMPAT + .compat_ioctl = compat_rawv6_ioctl, +#endif + .diag_destroy = raw_abort, +}; + +#ifdef CONFIG_PROC_FS +static int raw6_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_puts(seq, IPV6_SEQ_DGRAM_HEADER); + } else { + struct sock *sp = v; + __u16 srcp = inet_sk(sp)->inet_num; + ip6_dgram_sock_seq_show(seq, v, srcp, 0, + raw_seq_private(seq)->bucket); + } + return 0; +} + +static const struct seq_operations raw6_seq_ops = { + .start = raw_seq_start, + .next = raw_seq_next, + .stop = raw_seq_stop, + .show = raw6_seq_show, +}; + +static int __net_init raw6_init_net(struct net *net) +{ + if (!proc_create_net_data("raw6", 0444, net->proc_net, &raw6_seq_ops, + sizeof(struct raw_iter_state), &raw_v6_hashinfo)) + return -ENOMEM; + + return 0; +} + +static void __net_exit raw6_exit_net(struct net *net) +{ + remove_proc_entry("raw6", net->proc_net); +} + +static struct pernet_operations raw6_net_ops = { + .init = raw6_init_net, + .exit = raw6_exit_net, +}; + +int __init raw6_proc_init(void) +{ + return register_pernet_subsys(&raw6_net_ops); +} + +void raw6_proc_exit(void) +{ + unregister_pernet_subsys(&raw6_net_ops); +} +#endif /* CONFIG_PROC_FS */ + +/* Same as inet6_dgram_ops, sans udp_poll. 
*/ +const struct proto_ops inet6_sockraw_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = inet_dgram_connect, /* ok */ + .socketpair = sock_no_socketpair, /* a do nothing */ + .accept = sock_no_accept, /* a do nothing */ + .getname = inet6_getname, + .poll = datagram_poll, /* ok */ + .ioctl = inet6_ioctl, /* must change */ + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, /* ok */ + .shutdown = inet_shutdown, /* ok */ + .setsockopt = sock_common_setsockopt, /* ok */ + .getsockopt = sock_common_getsockopt, /* ok */ + .sendmsg = inet_sendmsg, /* ok */ + .recvmsg = sock_common_recvmsg, /* ok */ + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static struct inet_protosw rawv6_protosw = { + .type = SOCK_RAW, + .protocol = IPPROTO_IP, /* wild card */ + .prot = &rawv6_prot, + .ops = &inet6_sockraw_ops, + .flags = INET_PROTOSW_REUSE, +}; + +int __init rawv6_init(void) +{ + return inet6_register_protosw(&rawv6_protosw); +} + +void rawv6_exit(void) +{ + inet6_unregister_protosw(&rawv6_protosw); +} diff --git a/net/ipv6/reassembly.c b/net/ipv6/reassembly.c new file mode 100644 index 000000000..ff866f2a8 --- /dev/null +++ b/net/ipv6/reassembly.c @@ -0,0 +1,612 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 fragment reassembly + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on: net/ipv4/ip_fragment.c + */ + +/* + * Fixes: + * Andi Kleen Make it work with multiple hosts. + * More RFC compliance. + * + * Horst von Brand Add missing #include + * Alexey Kuznetsov SMP races, threading, cleanup. + * Patrick McHardy LRU queue of frag heads for evictor. + * Mitsuru KANDA @USAGI Register inet6_protocol{}. + * David Stevens and + * YOSHIFUJI,H. @USAGI Always remove fragment header to + * calculate ICV correctly. 
+ */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char ip6_frag_cache_name[] = "ip6-frags"; + +static u8 ip6_frag_ecn(const struct ipv6hdr *ipv6h) +{ + return 1 << (ipv6_get_dsfield(ipv6h) & INET_ECN_MASK); +} + +static struct inet_frags ip6_frags; + +static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev); + +static void ip6_frag_expire(struct timer_list *t) +{ + struct inet_frag_queue *frag = from_timer(frag, t, timer); + struct frag_queue *fq; + + fq = container_of(frag, struct frag_queue, q); + + ip6frag_expire_frag_queue(fq->q.fqdir->net, fq); +} + +static struct frag_queue * +fq_find(struct net *net, __be32 id, const struct ipv6hdr *hdr, int iif) +{ + struct frag_v6_compare_key key = { + .id = id, + .saddr = hdr->saddr, + .daddr = hdr->daddr, + .user = IP6_DEFRAG_LOCAL_DELIVER, + .iif = iif, + }; + struct inet_frag_queue *q; + + if (!(ipv6_addr_type(&hdr->daddr) & (IPV6_ADDR_MULTICAST | + IPV6_ADDR_LINKLOCAL))) + key.iif = 0; + + q = inet_frag_find(net->ipv6.fqdir, &key); + if (!q) + return NULL; + + return container_of(q, struct frag_queue, q); +} + +static int ip6_frag_queue(struct frag_queue *fq, struct sk_buff *skb, + struct frag_hdr *fhdr, int nhoff, + u32 *prob_offset) +{ + struct net *net = dev_net(skb_dst(skb)->dev); + int offset, end, fragsize; + struct sk_buff *prev_tail; + struct net_device *dev; + int err = -ENOENT; + u8 ecn; + + if (fq->q.flags & INET_FRAG_COMPLETE) + goto err; + + err = -EINVAL; + offset = ntohs(fhdr->frag_off) & ~0x7; + end = offset + (ntohs(ipv6_hdr(skb)->payload_len) - + ((u8 *)(fhdr + 1) - (u8 *)(ipv6_hdr(skb) + 1))); + + if ((unsigned int)end > IPV6_MAXPLEN) { + *prob_offset = (u8 *)&fhdr->frag_off - skb_network_header(skb); + /* note that if prob_offset is set, the skb is freed elsewhere, + * we do not free it here. + */ + return -1; + } + + ecn = ip6_frag_ecn(ipv6_hdr(skb)); + + if (skb->ip_summed == CHECKSUM_COMPLETE) { + const unsigned char *nh = skb_network_header(skb); + skb->csum = csum_sub(skb->csum, + csum_partial(nh, (u8 *)(fhdr + 1) - nh, + 0)); + } + + /* Is this the final fragment? */ + if (!(fhdr->frag_off & htons(IP6_MF))) { + /* If we already have some bits beyond end + * or have different end, the segment is corrupted. + */ + if (end < fq->q.len || + ((fq->q.flags & INET_FRAG_LAST_IN) && end != fq->q.len)) + goto discard_fq; + fq->q.flags |= INET_FRAG_LAST_IN; + fq->q.len = end; + } else { + /* Check if the fragment is rounded to 8 bytes. + * Required by the RFC. + */ + if (end & 0x7) { + /* RFC2460 says always send parameter problem in + * this case. -DaveM + */ + *prob_offset = offsetof(struct ipv6hdr, payload_len); + return -1; + } + if (end > fq->q.len) { + /* Some bits beyond end -> corruption. */ + if (fq->q.flags & INET_FRAG_LAST_IN) + goto discard_fq; + fq->q.len = end; + } + } + + if (end == offset) + goto discard_fq; + + err = -ENOMEM; + /* Point into the IP datagram 'data' part. */ + if (!pskb_pull(skb, (u8 *) (fhdr + 1) - skb->data)) + goto discard_fq; + + err = pskb_trim_rcsum(skb, end - offset); + if (err) + goto discard_fq; + + /* Note : skb->rbnode and skb->dev share the same location. 
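The offset/end arithmetic near the top of ip6_frag_queue() above is straight RFC 8200 bookkeeping: the fragment offset is the upper 13 bits of frag_off in 8-octet units, IP6_MF is the low bit, and end is that offset plus the bytes that follow the fragment header. A small stand-alone sketch of the same computation (user space; the struct is a local mirror of the on-wire header, not the kernel's struct frag_hdr):

#include <stdint.h>
#include <arpa/inet.h>
#include <stdio.h>

/* On-wire IPv6 fragment header, RFC 8200 section 4.5. */
struct frag6_hdr {
	uint8_t  nexthdr;
	uint8_t  reserved;
	uint16_t frag_off;	/* 13-bit offset in 8-octet units, 2 reserved bits, M flag */
	uint32_t identification;
};

#define FRAG6_OFFSET_MASK	0xfff8	/* applied after ntohs() */
#define FRAG6_MF		0x0001

int main(void)
{
	/* Example: a middle fragment starting at byte 1392 with more to follow. */
	struct frag6_hdr fh = {
		.nexthdr = 17,				/* UDP */
		.frag_off = htons(1392 | FRAG6_MF),
		.identification = htonl(0x12345678),
	};
	uint16_t raw = ntohs(fh.frag_off);
	unsigned int offset = raw & FRAG6_OFFSET_MASK;	/* always 8-octet aligned */
	int more_frags = raw & FRAG6_MF;
	unsigned int frag_payload = 1280;		/* bytes after this header */
	unsigned int end = offset + frag_payload;

	printf("offset=%u end=%u MF=%d\n", offset, end, more_frags);
	/* Non-final fragments must end on an 8-octet boundary (end & 0x7 == 0);
	 * ip6_frag_queue() answers a violation with a Parameter Problem. */
	return 0;
}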
*/ + dev = skb->dev; + /* Makes sure compiler wont do silly aliasing games */ + barrier(); + + prev_tail = fq->q.fragments_tail; + err = inet_frag_queue_insert(&fq->q, skb, offset, end); + if (err) + goto insert_error; + + if (dev) + fq->iif = dev->ifindex; + + fq->q.stamp = skb->tstamp; + fq->q.mono_delivery_time = skb->mono_delivery_time; + fq->q.meat += skb->len; + fq->ecn |= ecn; + add_frag_mem_limit(fq->q.fqdir, skb->truesize); + + fragsize = -skb_network_offset(skb) + skb->len; + if (fragsize > fq->q.max_size) + fq->q.max_size = fragsize; + + /* The first fragment. + * nhoffset is obtained from the first fragment, of course. + */ + if (offset == 0) { + fq->nhoffset = nhoff; + fq->q.flags |= INET_FRAG_FIRST_IN; + } + + if (fq->q.flags == (INET_FRAG_FIRST_IN | INET_FRAG_LAST_IN) && + fq->q.meat == fq->q.len) { + unsigned long orefdst = skb->_skb_refdst; + + skb->_skb_refdst = 0UL; + err = ip6_frag_reasm(fq, skb, prev_tail, dev); + skb->_skb_refdst = orefdst; + return err; + } + + skb_dst_drop(skb); + return -EINPROGRESS; + +insert_error: + if (err == IPFRAG_DUP) { + kfree_skb(skb); + return -EINVAL; + } + err = -EINVAL; + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_REASM_OVERLAPS); +discard_fq: + inet_frag_kill(&fq->q); + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), + IPSTATS_MIB_REASMFAILS); +err: + kfree_skb(skb); + return err; +} + +/* + * Check if this packet is complete. + * + * It is called with locked fq, and caller must check that + * queue is eligible for reassembly i.e. it is not COMPLETE, + * the last and the first frames arrived and all the bits are here. + */ +static int ip6_frag_reasm(struct frag_queue *fq, struct sk_buff *skb, + struct sk_buff *prev_tail, struct net_device *dev) +{ + struct net *net = fq->q.fqdir->net; + unsigned int nhoff; + void *reasm_data; + int payload_len; + u8 ecn; + + inet_frag_kill(&fq->q); + + ecn = ip_frag_ecn_table[fq->ecn]; + if (unlikely(ecn == 0xff)) + goto out_fail; + + reasm_data = inet_frag_reasm_prepare(&fq->q, skb, prev_tail); + if (!reasm_data) + goto out_oom; + + payload_len = ((skb->data - skb_network_header(skb)) - + sizeof(struct ipv6hdr) + fq->q.len - + sizeof(struct frag_hdr)); + if (payload_len > IPV6_MAXPLEN) + goto out_oversize; + + /* We have to remove fragment header from datagram and to relocate + * header in order to calculate ICV correctly. */ + nhoff = fq->nhoffset; + skb_network_header(skb)[nhoff] = skb_transport_header(skb)[0]; + memmove(skb->head + sizeof(struct frag_hdr), skb->head, + (skb->data - skb->head) - sizeof(struct frag_hdr)); + if (skb_mac_header_was_set(skb)) + skb->mac_header += sizeof(struct frag_hdr); + skb->network_header += sizeof(struct frag_hdr); + + skb_reset_transport_header(skb); + + inet_frag_reasm_finish(&fq->q, skb, reasm_data, true); + + skb->dev = dev; + ipv6_hdr(skb)->payload_len = htons(payload_len); + ipv6_change_dsfield(ipv6_hdr(skb), 0xff, ecn); + IP6CB(skb)->nhoff = nhoff; + IP6CB(skb)->flags |= IP6SKB_FRAGMENTED; + IP6CB(skb)->frag_max_size = fq->q.max_size; + + /* Yes, and fold redundant checksum back. 
8) */ + skb_postpush_rcsum(skb, skb_network_header(skb), + skb_network_header_len(skb)); + + rcu_read_lock(); + __IP6_INC_STATS(net, __in6_dev_stats_get(dev, skb), IPSTATS_MIB_REASMOKS); + rcu_read_unlock(); + fq->q.rb_fragments = RB_ROOT; + fq->q.fragments_tail = NULL; + fq->q.last_run_head = NULL; + return 1; + +out_oversize: + net_dbg_ratelimited("ip6_frag_reasm: payload len = %d\n", payload_len); + goto out_fail; +out_oom: + net_dbg_ratelimited("ip6_frag_reasm: no memory for reassembly\n"); +out_fail: + rcu_read_lock(); + __IP6_INC_STATS(net, __in6_dev_stats_get(dev, skb), IPSTATS_MIB_REASMFAILS); + rcu_read_unlock(); + inet_frag_kill(&fq->q); + return -1; +} + +static int ipv6_frag_rcv(struct sk_buff *skb) +{ + struct frag_hdr *fhdr; + struct frag_queue *fq; + const struct ipv6hdr *hdr = ipv6_hdr(skb); + struct net *net = dev_net(skb_dst(skb)->dev); + u8 nexthdr; + int iif; + + if (IP6CB(skb)->flags & IP6SKB_FRAGMENTED) + goto fail_hdr; + + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_REASMREQDS); + + /* Jumbo payload inhibits frag. header */ + if (hdr->payload_len == 0) + goto fail_hdr; + + if (!pskb_may_pull(skb, (skb_transport_offset(skb) + + sizeof(struct frag_hdr)))) + goto fail_hdr; + + hdr = ipv6_hdr(skb); + fhdr = (struct frag_hdr *)skb_transport_header(skb); + + if (!(fhdr->frag_off & htons(IP6_OFFSET | IP6_MF))) { + /* It is not a fragmented frame */ + skb->transport_header += sizeof(struct frag_hdr); + __IP6_INC_STATS(net, + ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_REASMOKS); + + IP6CB(skb)->nhoff = (u8 *)fhdr - skb_network_header(skb); + IP6CB(skb)->flags |= IP6SKB_FRAGMENTED; + IP6CB(skb)->frag_max_size = ntohs(hdr->payload_len) + + sizeof(struct ipv6hdr); + return 1; + } + + /* RFC 8200, Section 4.5 Fragment Header: + * If the first fragment does not include all headers through an + * Upper-Layer header, then that fragment should be discarded and + * an ICMP Parameter Problem, Code 3, message should be sent to + * the source of the fragment, with the Pointer field set to zero. + */ + nexthdr = hdr->nexthdr; + if (ipv6frag_thdr_truncated(skb, skb_transport_offset(skb), &nexthdr)) { + __IP6_INC_STATS(net, __in6_dev_get_safely(skb->dev), + IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_INCOMP, 0); + return -1; + } + + iif = skb->dev ? 
skb->dev->ifindex : 0; + fq = fq_find(net, fhdr->identification, hdr, iif); + if (fq) { + u32 prob_offset = 0; + int ret; + + spin_lock(&fq->q.lock); + + fq->iif = iif; + ret = ip6_frag_queue(fq, skb, fhdr, IP6CB(skb)->nhoff, + &prob_offset); + + spin_unlock(&fq->q.lock); + inet_frag_put(&fq->q); + if (prob_offset) { + __IP6_INC_STATS(net, __in6_dev_get_safely(skb->dev), + IPSTATS_MIB_INHDRERRORS); + /* icmpv6_param_prob() calls kfree_skb(skb) */ + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, prob_offset); + } + return ret; + } + + __IP6_INC_STATS(net, ip6_dst_idev(skb_dst(skb)), IPSTATS_MIB_REASMFAILS); + kfree_skb(skb); + return -1; + +fail_hdr: + __IP6_INC_STATS(net, __in6_dev_get_safely(skb->dev), + IPSTATS_MIB_INHDRERRORS); + icmpv6_param_prob(skb, ICMPV6_HDR_FIELD, skb_network_header_len(skb)); + return -1; +} + +static const struct inet6_protocol frag_protocol = { + .handler = ipv6_frag_rcv, + .flags = INET6_PROTO_NOPOLICY, +}; + +#ifdef CONFIG_SYSCTL + +static struct ctl_table ip6_frags_ns_ctl_table[] = { + { + .procname = "ip6frag_high_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "ip6frag_low_thresh", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "ip6frag_time", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; + +/* secret interval has been deprecated */ +static int ip6_frags_secret_interval_unused; +static struct ctl_table ip6_frags_ctl_table[] = { + { + .procname = "ip6frag_secret_interval", + .data = &ip6_frags_secret_interval_unused, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; + +static int __net_init ip6_frags_ns_sysctl_register(struct net *net) +{ + struct ctl_table *table; + struct ctl_table_header *hdr; + + table = ip6_frags_ns_ctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(ip6_frags_ns_ctl_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + } + table[0].data = &net->ipv6.fqdir->high_thresh; + table[0].extra1 = &net->ipv6.fqdir->low_thresh; + table[1].data = &net->ipv6.fqdir->low_thresh; + table[1].extra2 = &net->ipv6.fqdir->high_thresh; + table[2].data = &net->ipv6.fqdir->timeout; + + hdr = register_net_sysctl(net, "net/ipv6", table); + if (!hdr) + goto err_reg; + + net->ipv6.sysctl.frags_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void __net_exit ip6_frags_ns_sysctl_unregister(struct net *net) +{ + struct ctl_table *table; + + table = net->ipv6.sysctl.frags_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv6.sysctl.frags_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} + +static struct ctl_table_header *ip6_ctl_header; + +static int ip6_frags_sysctl_register(void) +{ + ip6_ctl_header = register_net_sysctl(&init_net, "net/ipv6", + ip6_frags_ctl_table); + return ip6_ctl_header == NULL ? 
-ENOMEM : 0; +} + +static void ip6_frags_sysctl_unregister(void) +{ + unregister_net_sysctl_table(ip6_ctl_header); +} +#else +static int ip6_frags_ns_sysctl_register(struct net *net) +{ + return 0; +} + +static void ip6_frags_ns_sysctl_unregister(struct net *net) +{ +} + +static int ip6_frags_sysctl_register(void) +{ + return 0; +} + +static void ip6_frags_sysctl_unregister(void) +{ +} +#endif + +static int __net_init ipv6_frags_init_net(struct net *net) +{ + int res; + + res = fqdir_init(&net->ipv6.fqdir, &ip6_frags, net); + if (res < 0) + return res; + + net->ipv6.fqdir->high_thresh = IPV6_FRAG_HIGH_THRESH; + net->ipv6.fqdir->low_thresh = IPV6_FRAG_LOW_THRESH; + net->ipv6.fqdir->timeout = IPV6_FRAG_TIMEOUT; + + res = ip6_frags_ns_sysctl_register(net); + if (res < 0) + fqdir_exit(net->ipv6.fqdir); + return res; +} + +static void __net_exit ipv6_frags_pre_exit_net(struct net *net) +{ + fqdir_pre_exit(net->ipv6.fqdir); +} + +static void __net_exit ipv6_frags_exit_net(struct net *net) +{ + ip6_frags_ns_sysctl_unregister(net); + fqdir_exit(net->ipv6.fqdir); +} + +static struct pernet_operations ip6_frags_ops = { + .init = ipv6_frags_init_net, + .pre_exit = ipv6_frags_pre_exit_net, + .exit = ipv6_frags_exit_net, +}; + +static const struct rhashtable_params ip6_rhash_params = { + .head_offset = offsetof(struct inet_frag_queue, node), + .hashfn = ip6frag_key_hashfn, + .obj_hashfn = ip6frag_obj_hashfn, + .obj_cmpfn = ip6frag_obj_cmpfn, + .automatic_shrinking = true, +}; + +int __init ipv6_frag_init(void) +{ + int ret; + + ip6_frags.constructor = ip6frag_init; + ip6_frags.destructor = NULL; + ip6_frags.qsize = sizeof(struct frag_queue); + ip6_frags.frag_expire = ip6_frag_expire; + ip6_frags.frags_cache_name = ip6_frag_cache_name; + ip6_frags.rhash_params = ip6_rhash_params; + ret = inet_frags_init(&ip6_frags); + if (ret) + goto out; + + ret = inet6_add_protocol(&frag_protocol, IPPROTO_FRAGMENT); + if (ret) + goto err_protocol; + + ret = ip6_frags_sysctl_register(); + if (ret) + goto err_sysctl; + + ret = register_pernet_subsys(&ip6_frags_ops); + if (ret) + goto err_pernet; + +out: + return ret; + +err_pernet: + ip6_frags_sysctl_unregister(); +err_sysctl: + inet6_del_protocol(&frag_protocol, IPPROTO_FRAGMENT); +err_protocol: + inet_frags_fini(&ip6_frags); + goto out; +} + +void ipv6_frag_exit(void) +{ + ip6_frags_sysctl_unregister(); + unregister_pernet_subsys(&ip6_frags_ops); + inet6_del_protocol(&frag_protocol, IPPROTO_FRAGMENT); + inet_frags_fini(&ip6_frags); +} diff --git a/net/ipv6/route.c b/net/ipv6/route.c new file mode 100644 index 000000000..7f65dc750 --- /dev/null +++ b/net/ipv6/route.c @@ -0,0 +1,6792 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Linux INET6 implementation + * FIB front-end. + * + * Authors: + * Pedro Roque + */ + +/* Changes: + * + * YOSHIFUJI Hideaki @USAGI + * reworked default router selection. + * - respect outgoing interface + * - select from (probably) reachable routers (i.e. + * routers in REACHABLE, STALE, DELAY or PROBE states). + * - always select the same router if it is (probably) + * reachable. otherwise, round-robin the list. + * Ville Nuorvala + * Fixed routing subtrees. 
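The per-netns tables registered above appear as files under /proc/sys/net/ipv6/ (ip6frag_high_thresh, ip6frag_low_thresh, ip6frag_time; ip6frag_secret_interval is kept only for compatibility). A small sketch, assuming procfs is mounted at /proc, that reads the reassembly timeout the same way sysctl net.ipv6.ip6frag_time would:

#include <stdio.h>

int main(void)
{
	/* proc_dointvec_jiffies exposes the jiffies-based timeout in seconds. */
	FILE *f = fopen("/proc/sys/net/ipv6/ip6frag_time", "r");
	long timeout;

	if (!f) {
		perror("fopen");
		return 1;
	}
	if (fscanf(f, "%ld", &timeout) == 1)
		printf("IPv6 reassembly timeout: %ld s\n", timeout);
	fclose(f);
	return 0;
}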
+ */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_SYSCTL +#include +#endif + +static int ip6_rt_type_to_error(u8 fib6_type); + +#define CREATE_TRACE_POINTS +#include +EXPORT_TRACEPOINT_SYMBOL_GPL(fib6_table_lookup); +#undef CREATE_TRACE_POINTS + +enum rt6_nud_state { + RT6_NUD_FAIL_HARD = -3, + RT6_NUD_FAIL_PROBE = -2, + RT6_NUD_FAIL_DO_RR = -1, + RT6_NUD_SUCCEED = 1 +}; + +INDIRECT_CALLABLE_SCOPE +struct dst_entry *ip6_dst_check(struct dst_entry *dst, u32 cookie); +static unsigned int ip6_default_advmss(const struct dst_entry *dst); +INDIRECT_CALLABLE_SCOPE +unsigned int ip6_mtu(const struct dst_entry *dst); +static struct dst_entry *ip6_negative_advice(struct dst_entry *); +static void ip6_dst_destroy(struct dst_entry *); +static void ip6_dst_ifdown(struct dst_entry *, + struct net_device *dev, int how); +static void ip6_dst_gc(struct dst_ops *ops); + +static int ip6_pkt_discard(struct sk_buff *skb); +static int ip6_pkt_discard_out(struct net *net, struct sock *sk, struct sk_buff *skb); +static int ip6_pkt_prohibit(struct sk_buff *skb); +static int ip6_pkt_prohibit_out(struct net *net, struct sock *sk, struct sk_buff *skb); +static void ip6_link_failure(struct sk_buff *skb); +static void ip6_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh); +static void rt6_do_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb); +static int rt6_score_route(const struct fib6_nh *nh, u32 fib6_flags, int oif, + int strict); +static size_t rt6_nlmsg_size(struct fib6_info *f6i); +static int rt6_fill_node(struct net *net, struct sk_buff *skb, + struct fib6_info *rt, struct dst_entry *dst, + struct in6_addr *dest, struct in6_addr *src, + int iif, int type, u32 portid, u32 seq, + unsigned int flags); +static struct rt6_info *rt6_find_cached_rt(const struct fib6_result *res, + const struct in6_addr *daddr, + const struct in6_addr *saddr); + +#ifdef CONFIG_IPV6_ROUTE_INFO +static struct fib6_info *rt6_add_route_info(struct net *net, + const struct in6_addr *prefix, int prefixlen, + const struct in6_addr *gwaddr, + struct net_device *dev, + unsigned int pref); +static struct fib6_info *rt6_get_route_info(struct net *net, + const struct in6_addr *prefix, int prefixlen, + const struct in6_addr *gwaddr, + struct net_device *dev); +#endif + +struct uncached_list { + spinlock_t lock; + struct list_head head; + struct list_head quarantine; +}; + +static DEFINE_PER_CPU_ALIGNED(struct uncached_list, rt6_uncached_list); + +void rt6_uncached_list_add(struct rt6_info *rt) +{ + struct uncached_list *ul = raw_cpu_ptr(&rt6_uncached_list); + + rt->rt6i_uncached_list = ul; + + spin_lock_bh(&ul->lock); + list_add_tail(&rt->rt6i_uncached, &ul->head); + spin_unlock_bh(&ul->lock); +} + +void rt6_uncached_list_del(struct rt6_info *rt) +{ + if (!list_empty(&rt->rt6i_uncached)) { + struct uncached_list *ul = rt->rt6i_uncached_list; + + spin_lock_bh(&ul->lock); + list_del_init(&rt->rt6i_uncached); + spin_unlock_bh(&ul->lock); + } +} + +static void rt6_uncached_list_flush_dev(struct net_device *dev) +{ + int cpu; + + for_each_possible_cpu(cpu) { + 
struct uncached_list *ul = per_cpu_ptr(&rt6_uncached_list, cpu); + struct rt6_info *rt, *safe; + + if (list_empty(&ul->head)) + continue; + + spin_lock_bh(&ul->lock); + list_for_each_entry_safe(rt, safe, &ul->head, rt6i_uncached) { + struct inet6_dev *rt_idev = rt->rt6i_idev; + struct net_device *rt_dev = rt->dst.dev; + bool handled = false; + + if (rt_idev->dev == dev) { + rt->rt6i_idev = in6_dev_get(blackhole_netdev); + in6_dev_put(rt_idev); + handled = true; + } + + if (rt_dev == dev) { + rt->dst.dev = blackhole_netdev; + netdev_ref_replace(rt_dev, blackhole_netdev, + &rt->dst.dev_tracker, + GFP_ATOMIC); + handled = true; + } + if (handled) + list_move(&rt->rt6i_uncached, + &ul->quarantine); + } + spin_unlock_bh(&ul->lock); + } +} + +static inline const void *choose_neigh_daddr(const struct in6_addr *p, + struct sk_buff *skb, + const void *daddr) +{ + if (!ipv6_addr_any(p)) + return (const void *) p; + else if (skb) + return &ipv6_hdr(skb)->daddr; + return daddr; +} + +struct neighbour *ip6_neigh_lookup(const struct in6_addr *gw, + struct net_device *dev, + struct sk_buff *skb, + const void *daddr) +{ + struct neighbour *n; + + daddr = choose_neigh_daddr(gw, skb, daddr); + n = __ipv6_neigh_lookup(dev, daddr); + if (n) + return n; + + n = neigh_create(&nd_tbl, daddr, dev); + return IS_ERR(n) ? NULL : n; +} + +static struct neighbour *ip6_dst_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr) +{ + const struct rt6_info *rt = container_of(dst, struct rt6_info, dst); + + return ip6_neigh_lookup(rt6_nexthop(rt, &in6addr_any), + dst->dev, skb, daddr); +} + +static void ip6_confirm_neigh(const struct dst_entry *dst, const void *daddr) +{ + struct net_device *dev = dst->dev; + struct rt6_info *rt = (struct rt6_info *)dst; + + daddr = choose_neigh_daddr(rt6_nexthop(rt, &in6addr_any), NULL, daddr); + if (!daddr) + return; + if (dev->flags & (IFF_NOARP | IFF_LOOPBACK)) + return; + if (ipv6_addr_is_multicast((const struct in6_addr *)daddr)) + return; + __ipv6_confirm_neigh(dev, daddr); +} + +static struct dst_ops ip6_dst_ops_template = { + .family = AF_INET6, + .gc = ip6_dst_gc, + .gc_thresh = 1024, + .check = ip6_dst_check, + .default_advmss = ip6_default_advmss, + .mtu = ip6_mtu, + .cow_metrics = dst_cow_metrics_generic, + .destroy = ip6_dst_destroy, + .ifdown = ip6_dst_ifdown, + .negative_advice = ip6_negative_advice, + .link_failure = ip6_link_failure, + .update_pmtu = ip6_rt_update_pmtu, + .redirect = rt6_do_redirect, + .local_out = __ip6_local_out, + .neigh_lookup = ip6_dst_neigh_lookup, + .confirm_neigh = ip6_confirm_neigh, +}; + +static struct dst_ops ip6_dst_blackhole_ops = { + .family = AF_INET6, + .default_advmss = ip6_default_advmss, + .neigh_lookup = ip6_dst_neigh_lookup, + .check = ip6_dst_check, + .destroy = ip6_dst_destroy, + .cow_metrics = dst_cow_metrics_generic, + .update_pmtu = dst_blackhole_update_pmtu, + .redirect = dst_blackhole_redirect, + .mtu = dst_blackhole_mtu, +}; + +static const u32 ip6_template_metrics[RTAX_MAX] = { + [RTAX_HOPLIMIT - 1] = 0, +}; + +static const struct fib6_info fib6_null_entry_template = { + .fib6_flags = (RTF_REJECT | RTF_NONEXTHOP), + .fib6_protocol = RTPROT_KERNEL, + .fib6_metric = ~(u32)0, + .fib6_ref = REFCOUNT_INIT(1), + .fib6_type = RTN_UNREACHABLE, + .fib6_metrics = (struct dst_metrics *)&dst_default_metrics, +}; + +static const struct rt6_info ip6_null_entry_template = { + .dst = { + .__refcnt = ATOMIC_INIT(1), + .__use = 1, + .obsolete = DST_OBSOLETE_FORCE_CHK, + .error = -ENETUNREACH, + .input = 
ip6_pkt_discard, + .output = ip6_pkt_discard_out, + }, + .rt6i_flags = (RTF_REJECT | RTF_NONEXTHOP), +}; + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + +static const struct rt6_info ip6_prohibit_entry_template = { + .dst = { + .__refcnt = ATOMIC_INIT(1), + .__use = 1, + .obsolete = DST_OBSOLETE_FORCE_CHK, + .error = -EACCES, + .input = ip6_pkt_prohibit, + .output = ip6_pkt_prohibit_out, + }, + .rt6i_flags = (RTF_REJECT | RTF_NONEXTHOP), +}; + +static const struct rt6_info ip6_blk_hole_entry_template = { + .dst = { + .__refcnt = ATOMIC_INIT(1), + .__use = 1, + .obsolete = DST_OBSOLETE_FORCE_CHK, + .error = -EINVAL, + .input = dst_discard, + .output = dst_discard_out, + }, + .rt6i_flags = (RTF_REJECT | RTF_NONEXTHOP), +}; + +#endif + +static void rt6_info_init(struct rt6_info *rt) +{ + memset_after(rt, 0, dst); + INIT_LIST_HEAD(&rt->rt6i_uncached); +} + +/* allocate dst with ip6_dst_ops */ +struct rt6_info *ip6_dst_alloc(struct net *net, struct net_device *dev, + int flags) +{ + struct rt6_info *rt = dst_alloc(&net->ipv6.ip6_dst_ops, dev, + 1, DST_OBSOLETE_FORCE_CHK, flags); + + if (rt) { + rt6_info_init(rt); + atomic_inc(&net->ipv6.rt6_stats->fib_rt_alloc); + } + + return rt; +} +EXPORT_SYMBOL(ip6_dst_alloc); + +static void ip6_dst_destroy(struct dst_entry *dst) +{ + struct rt6_info *rt = (struct rt6_info *)dst; + struct fib6_info *from; + struct inet6_dev *idev; + + ip_dst_metrics_put(dst); + rt6_uncached_list_del(rt); + + idev = rt->rt6i_idev; + if (idev) { + rt->rt6i_idev = NULL; + in6_dev_put(idev); + } + + from = xchg((__force struct fib6_info **)&rt->from, NULL); + fib6_info_release(from); +} + +static void ip6_dst_ifdown(struct dst_entry *dst, struct net_device *dev, + int how) +{ + struct rt6_info *rt = (struct rt6_info *)dst; + struct inet6_dev *idev = rt->rt6i_idev; + + if (idev && idev->dev != blackhole_netdev) { + struct inet6_dev *blackhole_idev = in6_dev_get(blackhole_netdev); + + if (blackhole_idev) { + rt->rt6i_idev = blackhole_idev; + in6_dev_put(idev); + } + } +} + +static bool __rt6_check_expired(const struct rt6_info *rt) +{ + if (rt->rt6i_flags & RTF_EXPIRES) + return time_after(jiffies, rt->dst.expires); + else + return false; +} + +static bool rt6_check_expired(const struct rt6_info *rt) +{ + struct fib6_info *from; + + from = rcu_dereference(rt->from); + + if (rt->rt6i_flags & RTF_EXPIRES) { + if (time_after(jiffies, rt->dst.expires)) + return true; + } else if (from) { + return rt->dst.obsolete != DST_OBSOLETE_FORCE_CHK || + fib6_check_expired(from); + } + return false; +} + +void fib6_select_path(const struct net *net, struct fib6_result *res, + struct flowi6 *fl6, int oif, bool have_oif_match, + const struct sk_buff *skb, int strict) +{ + struct fib6_info *sibling, *next_sibling; + struct fib6_info *match = res->f6i; + + if (!match->nh && (!match->fib6_nsiblings || have_oif_match)) + goto out; + + if (match->nh && have_oif_match && res->nh) + return; + + if (skb) + IP6CB(skb)->flags |= IP6SKB_MULTIPATH; + + /* We might have already computed the hash for ICMPv6 errors. In such + * case it will always be non-zero. Otherwise now is the time to do it. 
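The sibling walk just below implements hash-threshold multipath: each nexthop owns a slice of the hash space bounded above by fib_nh_upper_bound, and the flow hash selects the first sibling whose bound it does not exceed. A compact stand-alone sketch of that selection rule (the bounds and the hash value are illustrative assumptions; the kernel derives them from nexthop weights and rt6_multipath_hash(), and also applies the rt6_score_route() checks omitted here):

#include <stdint.h>
#include <stdio.h>

struct nh { const char *name; uint32_t upper_bound; };

/* Pick the first nexthop whose upper bound is >= the flow hash,
 * mirroring the fib_nh_upper_bound comparison in fib6_select_path(). */
static const struct nh *select_nexthop(const struct nh *nhs, int n, uint32_t hash)
{
	int i;

	for (i = 0; i < n; i++)
		if (hash <= nhs[i].upper_bound)
			return &nhs[i];
	return &nhs[n - 1];	/* defensive; the bounds should cover the space */
}

int main(void)
{
	/* Two equal-weight paths: the bounds split the hash space in half. */
	struct nh nhs[] = {
		{ "nh0", 0x7fffffff },
		{ "nh1", 0xffffffff },
	};
	uint32_t flow_hash = 0x9e3779b9;	/* stand-in for rt6_multipath_hash() */

	printf("hash %08x -> %s\n", (unsigned)flow_hash,
	       select_nexthop(nhs, 2, flow_hash)->name);
	return 0;
}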
+ */ + if (!fl6->mp_hash && + (!match->nh || nexthop_is_multipath(match->nh))) + fl6->mp_hash = rt6_multipath_hash(net, fl6, skb, NULL); + + if (unlikely(match->nh)) { + nexthop_path_fib6_result(res, fl6->mp_hash); + return; + } + + if (fl6->mp_hash <= atomic_read(&match->fib6_nh->fib_nh_upper_bound)) + goto out; + + list_for_each_entry_safe(sibling, next_sibling, &match->fib6_siblings, + fib6_siblings) { + const struct fib6_nh *nh = sibling->fib6_nh; + int nh_upper_bound; + + nh_upper_bound = atomic_read(&nh->fib_nh_upper_bound); + if (fl6->mp_hash > nh_upper_bound) + continue; + if (rt6_score_route(nh, sibling->fib6_flags, oif, strict) < 0) + break; + match = sibling; + break; + } + +out: + res->f6i = match; + res->nh = match->fib6_nh; +} + +/* + * Route lookup. rcu_read_lock() should be held. + */ + +static bool __rt6_device_match(struct net *net, const struct fib6_nh *nh, + const struct in6_addr *saddr, int oif, int flags) +{ + const struct net_device *dev; + + if (nh->fib_nh_flags & RTNH_F_DEAD) + return false; + + dev = nh->fib_nh_dev; + if (oif) { + if (dev->ifindex == oif) + return true; + } else { + if (ipv6_chk_addr(net, saddr, dev, + flags & RT6_LOOKUP_F_IFACE)) + return true; + } + + return false; +} + +struct fib6_nh_dm_arg { + struct net *net; + const struct in6_addr *saddr; + int oif; + int flags; + struct fib6_nh *nh; +}; + +static int __rt6_nh_dev_match(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_dm_arg *arg = _arg; + + arg->nh = nh; + return __rt6_device_match(arg->net, nh, arg->saddr, arg->oif, + arg->flags); +} + +/* returns fib6_nh from nexthop or NULL */ +static struct fib6_nh *rt6_nh_dev_match(struct net *net, struct nexthop *nh, + struct fib6_result *res, + const struct in6_addr *saddr, + int oif, int flags) +{ + struct fib6_nh_dm_arg arg = { + .net = net, + .saddr = saddr, + .oif = oif, + .flags = flags, + }; + + if (nexthop_is_blackhole(nh)) + return NULL; + + if (nexthop_for_each_fib6_nh(nh, __rt6_nh_dev_match, &arg)) + return arg.nh; + + return NULL; +} + +static void rt6_device_match(struct net *net, struct fib6_result *res, + const struct in6_addr *saddr, int oif, int flags) +{ + struct fib6_info *f6i = res->f6i; + struct fib6_info *spf6i; + struct fib6_nh *nh; + + if (!oif && ipv6_addr_any(saddr)) { + if (unlikely(f6i->nh)) { + nh = nexthop_fib6_nh(f6i->nh); + if (nexthop_is_blackhole(f6i->nh)) + goto out_blackhole; + } else { + nh = f6i->fib6_nh; + } + if (!(nh->fib_nh_flags & RTNH_F_DEAD)) + goto out; + } + + for (spf6i = f6i; spf6i; spf6i = rcu_dereference(spf6i->fib6_next)) { + bool matched = false; + + if (unlikely(spf6i->nh)) { + nh = rt6_nh_dev_match(net, spf6i->nh, res, saddr, + oif, flags); + if (nh) + matched = true; + } else { + nh = spf6i->fib6_nh; + if (__rt6_device_match(net, nh, saddr, oif, flags)) + matched = true; + } + if (matched) { + res->f6i = spf6i; + goto out; + } + } + + if (oif && flags & RT6_LOOKUP_F_IFACE) { + res->f6i = net->ipv6.fib6_null_entry; + nh = res->f6i->fib6_nh; + goto out; + } + + if (unlikely(f6i->nh)) { + nh = nexthop_fib6_nh(f6i->nh); + if (nexthop_is_blackhole(f6i->nh)) + goto out_blackhole; + } else { + nh = f6i->fib6_nh; + } + + if (nh->fib_nh_flags & RTNH_F_DEAD) { + res->f6i = net->ipv6.fib6_null_entry; + nh = res->f6i->fib6_nh; + } +out: + res->nh = nh; + res->fib6_type = res->f6i->fib6_type; + res->fib6_flags = res->f6i->fib6_flags; + return; + +out_blackhole: + res->fib6_flags |= RTF_REJECT; + res->fib6_type = RTN_BLACKHOLE; + res->nh = nh; +} + +#ifdef CONFIG_IPV6_ROUTER_PREF +struct __rt6_probe_work 
{ + struct work_struct work; + struct in6_addr target; + struct net_device *dev; + netdevice_tracker dev_tracker; +}; + +static void rt6_probe_deferred(struct work_struct *w) +{ + struct in6_addr mcaddr; + struct __rt6_probe_work *work = + container_of(w, struct __rt6_probe_work, work); + + addrconf_addr_solict_mult(&work->target, &mcaddr); + ndisc_send_ns(work->dev, &work->target, &mcaddr, NULL, 0); + netdev_put(work->dev, &work->dev_tracker); + kfree(work); +} + +static void rt6_probe(struct fib6_nh *fib6_nh) +{ + struct __rt6_probe_work *work = NULL; + const struct in6_addr *nh_gw; + unsigned long last_probe; + struct neighbour *neigh; + struct net_device *dev; + struct inet6_dev *idev; + + /* + * Okay, this does not seem to be appropriate + * for now, however, we need to check if it + * is really so; aka Router Reachability Probing. + * + * Router Reachability Probe MUST be rate-limited + * to no more than one per minute. + */ + if (!fib6_nh->fib_nh_gw_family) + return; + + nh_gw = &fib6_nh->fib_nh_gw6; + dev = fib6_nh->fib_nh_dev; + rcu_read_lock(); + last_probe = READ_ONCE(fib6_nh->last_probe); + idev = __in6_dev_get(dev); + neigh = __ipv6_neigh_lookup_noref(dev, nh_gw); + if (neigh) { + if (READ_ONCE(neigh->nud_state) & NUD_VALID) + goto out; + + write_lock_bh(&neigh->lock); + if (!(neigh->nud_state & NUD_VALID) && + time_after(jiffies, + neigh->updated + idev->cnf.rtr_probe_interval)) { + work = kmalloc(sizeof(*work), GFP_ATOMIC); + if (work) + __neigh_set_probe_once(neigh); + } + write_unlock_bh(&neigh->lock); + } else if (time_after(jiffies, last_probe + + idev->cnf.rtr_probe_interval)) { + work = kmalloc(sizeof(*work), GFP_ATOMIC); + } + + if (!work || cmpxchg(&fib6_nh->last_probe, + last_probe, jiffies) != last_probe) { + kfree(work); + } else { + INIT_WORK(&work->work, rt6_probe_deferred); + work->target = *nh_gw; + netdev_hold(dev, &work->dev_tracker, GFP_ATOMIC); + work->dev = dev; + schedule_work(&work->work); + } + +out: + rcu_read_unlock(); +} +#else +static inline void rt6_probe(struct fib6_nh *fib6_nh) +{ +} +#endif + +/* + * Default Router Selection (RFC 2461 6.3.6) + */ +static enum rt6_nud_state rt6_check_neigh(const struct fib6_nh *fib6_nh) +{ + enum rt6_nud_state ret = RT6_NUD_FAIL_HARD; + struct neighbour *neigh; + + rcu_read_lock(); + neigh = __ipv6_neigh_lookup_noref(fib6_nh->fib_nh_dev, + &fib6_nh->fib_nh_gw6); + if (neigh) { + u8 nud_state = READ_ONCE(neigh->nud_state); + + if (nud_state & NUD_VALID) + ret = RT6_NUD_SUCCEED; +#ifdef CONFIG_IPV6_ROUTER_PREF + else if (!(nud_state & NUD_FAILED)) + ret = RT6_NUD_SUCCEED; + else + ret = RT6_NUD_FAIL_PROBE; +#endif + } else { + ret = IS_ENABLED(CONFIG_IPV6_ROUTER_PREF) ? 
+ RT6_NUD_SUCCEED : RT6_NUD_FAIL_DO_RR; + } + rcu_read_unlock(); + + return ret; +} + +static int rt6_score_route(const struct fib6_nh *nh, u32 fib6_flags, int oif, + int strict) +{ + int m = 0; + + if (!oif || nh->fib_nh_dev->ifindex == oif) + m = 2; + + if (!m && (strict & RT6_LOOKUP_F_IFACE)) + return RT6_NUD_FAIL_HARD; +#ifdef CONFIG_IPV6_ROUTER_PREF + m |= IPV6_DECODE_PREF(IPV6_EXTRACT_PREF(fib6_flags)) << 2; +#endif + if ((strict & RT6_LOOKUP_F_REACHABLE) && + !(fib6_flags & RTF_NONEXTHOP) && nh->fib_nh_gw_family) { + int n = rt6_check_neigh(nh); + if (n < 0) + return n; + } + return m; +} + +static bool find_match(struct fib6_nh *nh, u32 fib6_flags, + int oif, int strict, int *mpri, bool *do_rr) +{ + bool match_do_rr = false; + bool rc = false; + int m; + + if (nh->fib_nh_flags & RTNH_F_DEAD) + goto out; + + if (ip6_ignore_linkdown(nh->fib_nh_dev) && + nh->fib_nh_flags & RTNH_F_LINKDOWN && + !(strict & RT6_LOOKUP_F_IGNORE_LINKSTATE)) + goto out; + + m = rt6_score_route(nh, fib6_flags, oif, strict); + if (m == RT6_NUD_FAIL_DO_RR) { + match_do_rr = true; + m = 0; /* lowest valid score */ + } else if (m == RT6_NUD_FAIL_HARD) { + goto out; + } + + if (strict & RT6_LOOKUP_F_REACHABLE) + rt6_probe(nh); + + /* note that m can be RT6_NUD_FAIL_PROBE at this point */ + if (m > *mpri) { + *do_rr = match_do_rr; + *mpri = m; + rc = true; + } +out: + return rc; +} + +struct fib6_nh_frl_arg { + u32 flags; + int oif; + int strict; + int *mpri; + bool *do_rr; + struct fib6_nh *nh; +}; + +static int rt6_nh_find_match(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_frl_arg *arg = _arg; + + arg->nh = nh; + return find_match(nh, arg->flags, arg->oif, arg->strict, + arg->mpri, arg->do_rr); +} + +static void __find_rr_leaf(struct fib6_info *f6i_start, + struct fib6_info *nomatch, u32 metric, + struct fib6_result *res, struct fib6_info **cont, + int oif, int strict, bool *do_rr, int *mpri) +{ + struct fib6_info *f6i; + + for (f6i = f6i_start; + f6i && f6i != nomatch; + f6i = rcu_dereference(f6i->fib6_next)) { + bool matched = false; + struct fib6_nh *nh; + + if (cont && f6i->fib6_metric != metric) { + *cont = f6i; + return; + } + + if (fib6_check_expired(f6i)) + continue; + + if (unlikely(f6i->nh)) { + struct fib6_nh_frl_arg arg = { + .flags = f6i->fib6_flags, + .oif = oif, + .strict = strict, + .mpri = mpri, + .do_rr = do_rr + }; + + if (nexthop_is_blackhole(f6i->nh)) { + res->fib6_flags = RTF_REJECT; + res->fib6_type = RTN_BLACKHOLE; + res->f6i = f6i; + res->nh = nexthop_fib6_nh(f6i->nh); + return; + } + if (nexthop_for_each_fib6_nh(f6i->nh, rt6_nh_find_match, + &arg)) { + matched = true; + nh = arg.nh; + } + } else { + nh = f6i->fib6_nh; + if (find_match(nh, f6i->fib6_flags, oif, strict, + mpri, do_rr)) + matched = true; + } + if (matched) { + res->f6i = f6i; + res->nh = nh; + res->fib6_flags = f6i->fib6_flags; + res->fib6_type = f6i->fib6_type; + } + } +} + +static void find_rr_leaf(struct fib6_node *fn, struct fib6_info *leaf, + struct fib6_info *rr_head, int oif, int strict, + bool *do_rr, struct fib6_result *res) +{ + u32 metric = rr_head->fib6_metric; + struct fib6_info *cont = NULL; + int mpri = -1; + + __find_rr_leaf(rr_head, NULL, metric, res, &cont, + oif, strict, do_rr, &mpri); + + __find_rr_leaf(leaf, rr_head, metric, res, &cont, + oif, strict, do_rr, &mpri); + + if (res->f6i || !cont) + return; + + __find_rr_leaf(cont, NULL, metric, res, NULL, + oif, strict, do_rr, &mpri); +} + +static void rt6_select(struct net *net, struct fib6_node *fn, int oif, + struct fib6_result *res, int strict) 
+{ + struct fib6_info *leaf = rcu_dereference(fn->leaf); + struct fib6_info *rt0; + bool do_rr = false; + int key_plen; + + /* make sure this function or its helpers sets f6i */ + res->f6i = NULL; + + if (!leaf || leaf == net->ipv6.fib6_null_entry) + goto out; + + rt0 = rcu_dereference(fn->rr_ptr); + if (!rt0) + rt0 = leaf; + + /* Double check to make sure fn is not an intermediate node + * and fn->leaf does not points to its child's leaf + * (This might happen if all routes under fn are deleted from + * the tree and fib6_repair_tree() is called on the node.) + */ + key_plen = rt0->fib6_dst.plen; +#ifdef CONFIG_IPV6_SUBTREES + if (rt0->fib6_src.plen) + key_plen = rt0->fib6_src.plen; +#endif + if (fn->fn_bit != key_plen) + goto out; + + find_rr_leaf(fn, leaf, rt0, oif, strict, &do_rr, res); + if (do_rr) { + struct fib6_info *next = rcu_dereference(rt0->fib6_next); + + /* no entries matched; do round-robin */ + if (!next || next->fib6_metric != rt0->fib6_metric) + next = leaf; + + if (next != rt0) { + spin_lock_bh(&leaf->fib6_table->tb6_lock); + /* make sure next is not being deleted from the tree */ + if (next->fib6_node) + rcu_assign_pointer(fn->rr_ptr, next); + spin_unlock_bh(&leaf->fib6_table->tb6_lock); + } + } + +out: + if (!res->f6i) { + res->f6i = net->ipv6.fib6_null_entry; + res->nh = res->f6i->fib6_nh; + res->fib6_flags = res->f6i->fib6_flags; + res->fib6_type = res->f6i->fib6_type; + } +} + +static bool rt6_is_gw_or_nonexthop(const struct fib6_result *res) +{ + return (res->f6i->fib6_flags & RTF_NONEXTHOP) || + res->nh->fib_nh_gw_family; +} + +#ifdef CONFIG_IPV6_ROUTE_INFO +int rt6_route_rcv(struct net_device *dev, u8 *opt, int len, + const struct in6_addr *gwaddr) +{ + struct net *net = dev_net(dev); + struct route_info *rinfo = (struct route_info *) opt; + struct in6_addr prefix_buf, *prefix; + unsigned int pref; + unsigned long lifetime; + struct fib6_info *rt; + + if (len < sizeof(struct route_info)) { + return -EINVAL; + } + + /* Sanity check for prefix_len and length */ + if (rinfo->length > 3) { + return -EINVAL; + } else if (rinfo->prefix_len > 128) { + return -EINVAL; + } else if (rinfo->prefix_len > 64) { + if (rinfo->length < 2) { + return -EINVAL; + } + } else if (rinfo->prefix_len > 0) { + if (rinfo->length < 1) { + return -EINVAL; + } + } + + pref = rinfo->route_pref; + if (pref == ICMPV6_ROUTER_PREF_INVALID) + return -EINVAL; + + lifetime = addrconf_timeout_fixup(ntohl(rinfo->lifetime), HZ); + + if (rinfo->length == 3) + prefix = (struct in6_addr *)rinfo->prefix; + else { + /* this function is safe */ + ipv6_addr_prefix(&prefix_buf, + (struct in6_addr *)rinfo->prefix, + rinfo->prefix_len); + prefix = &prefix_buf; + } + + if (rinfo->prefix_len == 0) + rt = rt6_get_dflt_router(net, gwaddr, dev); + else + rt = rt6_get_route_info(net, prefix, rinfo->prefix_len, + gwaddr, dev); + + if (rt && !lifetime) { + ip6_del_rt(net, rt, false); + rt = NULL; + } + + if (!rt && lifetime) + rt = rt6_add_route_info(net, prefix, rinfo->prefix_len, gwaddr, + dev, pref); + else if (rt) + rt->fib6_flags = RTF_ROUTEINFO | + (rt->fib6_flags & ~RTF_PREF_MASK) | RTF_PREF(pref); + + if (rt) { + if (!addrconf_finite_timeout(lifetime)) + fib6_clean_expires(rt); + else + fib6_set_expires(rt, jiffies + HZ * lifetime); + + fib6_info_release(rt); + } + return 0; +} +#endif + +/* + * Misc support functions + */ + +/* called with rcu_lock held */ +static struct net_device *ip6_rt_get_dev_rcu(const struct fib6_result *res) +{ + struct net_device *dev = res->nh->fib_nh_dev; + + if (res->fib6_flags & 
(RTF_LOCAL | RTF_ANYCAST)) { + /* for copies of local routes, dst->dev needs to be the + * device if it is a master device, the master device if + * device is enslaved, and the loopback as the default + */ + if (netif_is_l3_slave(dev) && + !rt6_need_strict(&res->f6i->fib6_dst.addr)) + dev = l3mdev_master_dev_rcu(dev); + else if (!netif_is_l3_master(dev)) + dev = dev_net(dev)->loopback_dev; + /* last case is netif_is_l3_master(dev) is true in which + * case we want dev returned to be dev + */ + } + + return dev; +} + +static const int fib6_prop[RTN_MAX + 1] = { + [RTN_UNSPEC] = 0, + [RTN_UNICAST] = 0, + [RTN_LOCAL] = 0, + [RTN_BROADCAST] = 0, + [RTN_ANYCAST] = 0, + [RTN_MULTICAST] = 0, + [RTN_BLACKHOLE] = -EINVAL, + [RTN_UNREACHABLE] = -EHOSTUNREACH, + [RTN_PROHIBIT] = -EACCES, + [RTN_THROW] = -EAGAIN, + [RTN_NAT] = -EINVAL, + [RTN_XRESOLVE] = -EINVAL, +}; + +static int ip6_rt_type_to_error(u8 fib6_type) +{ + return fib6_prop[fib6_type]; +} + +static unsigned short fib6_info_dst_flags(struct fib6_info *rt) +{ + unsigned short flags = 0; + + if (rt->dst_nocount) + flags |= DST_NOCOUNT; + if (rt->dst_nopolicy) + flags |= DST_NOPOLICY; + + return flags; +} + +static void ip6_rt_init_dst_reject(struct rt6_info *rt, u8 fib6_type) +{ + rt->dst.error = ip6_rt_type_to_error(fib6_type); + + switch (fib6_type) { + case RTN_BLACKHOLE: + rt->dst.output = dst_discard_out; + rt->dst.input = dst_discard; + break; + case RTN_PROHIBIT: + rt->dst.output = ip6_pkt_prohibit_out; + rt->dst.input = ip6_pkt_prohibit; + break; + case RTN_THROW: + case RTN_UNREACHABLE: + default: + rt->dst.output = ip6_pkt_discard_out; + rt->dst.input = ip6_pkt_discard; + break; + } +} + +static void ip6_rt_init_dst(struct rt6_info *rt, const struct fib6_result *res) +{ + struct fib6_info *f6i = res->f6i; + + if (res->fib6_flags & RTF_REJECT) { + ip6_rt_init_dst_reject(rt, res->fib6_type); + return; + } + + rt->dst.error = 0; + rt->dst.output = ip6_output; + + if (res->fib6_type == RTN_LOCAL || res->fib6_type == RTN_ANYCAST) { + rt->dst.input = ip6_input; + } else if (ipv6_addr_type(&f6i->fib6_dst.addr) & IPV6_ADDR_MULTICAST) { + rt->dst.input = ip6_mc_input; + } else { + rt->dst.input = ip6_forward; + } + + if (res->nh->fib_nh_lws) { + rt->dst.lwtstate = lwtstate_get(res->nh->fib_nh_lws); + lwtunnel_set_redirect(&rt->dst); + } + + rt->dst.lastuse = jiffies; +} + +/* Caller must already hold reference to @from */ +static void rt6_set_from(struct rt6_info *rt, struct fib6_info *from) +{ + rt->rt6i_flags &= ~RTF_EXPIRES; + rcu_assign_pointer(rt->from, from); + ip_dst_init_metrics(&rt->dst, from->fib6_metrics); +} + +/* Caller must already hold reference to f6i in result */ +static void ip6_rt_copy_init(struct rt6_info *rt, const struct fib6_result *res) +{ + const struct fib6_nh *nh = res->nh; + const struct net_device *dev = nh->fib_nh_dev; + struct fib6_info *f6i = res->f6i; + + ip6_rt_init_dst(rt, res); + + rt->rt6i_dst = f6i->fib6_dst; + rt->rt6i_idev = dev ? 
in6_dev_get(dev) : NULL; + rt->rt6i_flags = res->fib6_flags; + if (nh->fib_nh_gw_family) { + rt->rt6i_gateway = nh->fib_nh_gw6; + rt->rt6i_flags |= RTF_GATEWAY; + } + rt6_set_from(rt, f6i); +#ifdef CONFIG_IPV6_SUBTREES + rt->rt6i_src = f6i->fib6_src; +#endif +} + +static struct fib6_node* fib6_backtrack(struct fib6_node *fn, + struct in6_addr *saddr) +{ + struct fib6_node *pn, *sn; + while (1) { + if (fn->fn_flags & RTN_TL_ROOT) + return NULL; + pn = rcu_dereference(fn->parent); + sn = FIB6_SUBTREE(pn); + if (sn && sn != fn) + fn = fib6_node_lookup(sn, NULL, saddr); + else + fn = pn; + if (fn->fn_flags & RTN_RTINFO) + return fn; + } +} + +static bool ip6_hold_safe(struct net *net, struct rt6_info **prt) +{ + struct rt6_info *rt = *prt; + + if (dst_hold_safe(&rt->dst)) + return true; + if (net) { + rt = net->ipv6.ip6_null_entry; + dst_hold(&rt->dst); + } else { + rt = NULL; + } + *prt = rt; + return false; +} + +/* called with rcu_lock held */ +static struct rt6_info *ip6_create_rt_rcu(const struct fib6_result *res) +{ + struct net_device *dev = res->nh->fib_nh_dev; + struct fib6_info *f6i = res->f6i; + unsigned short flags; + struct rt6_info *nrt; + + if (!fib6_info_hold_safe(f6i)) + goto fallback; + + flags = fib6_info_dst_flags(f6i); + nrt = ip6_dst_alloc(dev_net(dev), dev, flags); + if (!nrt) { + fib6_info_release(f6i); + goto fallback; + } + + ip6_rt_copy_init(nrt, res); + return nrt; + +fallback: + nrt = dev_net(dev)->ipv6.ip6_null_entry; + dst_hold(&nrt->dst); + return nrt; +} + +INDIRECT_CALLABLE_SCOPE struct rt6_info *ip6_pol_route_lookup(struct net *net, + struct fib6_table *table, + struct flowi6 *fl6, + const struct sk_buff *skb, + int flags) +{ + struct fib6_result res = {}; + struct fib6_node *fn; + struct rt6_info *rt; + + rcu_read_lock(); + fn = fib6_node_lookup(&table->tb6_root, &fl6->daddr, &fl6->saddr); +restart: + res.f6i = rcu_dereference(fn->leaf); + if (!res.f6i) + res.f6i = net->ipv6.fib6_null_entry; + else + rt6_device_match(net, &res, &fl6->saddr, fl6->flowi6_oif, + flags); + + if (res.f6i == net->ipv6.fib6_null_entry) { + fn = fib6_backtrack(fn, &fl6->saddr); + if (fn) + goto restart; + + rt = net->ipv6.ip6_null_entry; + dst_hold(&rt->dst); + goto out; + } else if (res.fib6_flags & RTF_REJECT) { + goto do_create; + } + + fib6_select_path(net, &res, fl6, fl6->flowi6_oif, + fl6->flowi6_oif != 0, skb, flags); + + /* Search through exception table */ + rt = rt6_find_cached_rt(&res, &fl6->daddr, &fl6->saddr); + if (rt) { + if (ip6_hold_safe(net, &rt)) + dst_use_noref(&rt->dst, jiffies); + } else { +do_create: + rt = ip6_create_rt_rcu(&res); + } + +out: + trace_fib6_table_lookup(net, &res, table, fl6); + + rcu_read_unlock(); + + return rt; +} + +struct dst_entry *ip6_route_lookup(struct net *net, struct flowi6 *fl6, + const struct sk_buff *skb, int flags) +{ + return fib6_rule_lookup(net, fl6, skb, flags, ip6_pol_route_lookup); +} +EXPORT_SYMBOL_GPL(ip6_route_lookup); + +struct rt6_info *rt6_lookup(struct net *net, const struct in6_addr *daddr, + const struct in6_addr *saddr, int oif, + const struct sk_buff *skb, int strict) +{ + struct flowi6 fl6 = { + .flowi6_oif = oif, + .daddr = *daddr, + }; + struct dst_entry *dst; + int flags = strict ? 
RT6_LOOKUP_F_IFACE : 0; + + if (saddr) { + memcpy(&fl6.saddr, saddr, sizeof(*saddr)); + flags |= RT6_LOOKUP_F_HAS_SADDR; + } + + dst = fib6_rule_lookup(net, &fl6, skb, flags, ip6_pol_route_lookup); + if (dst->error == 0) + return (struct rt6_info *) dst; + + dst_release(dst); + + return NULL; +} +EXPORT_SYMBOL(rt6_lookup); + +/* ip6_ins_rt is called with FREE table->tb6_lock. + * It takes new route entry, the addition fails by any reason the + * route is released. + * Caller must hold dst before calling it. + */ + +static int __ip6_ins_rt(struct fib6_info *rt, struct nl_info *info, + struct netlink_ext_ack *extack) +{ + int err; + struct fib6_table *table; + + table = rt->fib6_table; + spin_lock_bh(&table->tb6_lock); + err = fib6_add(&table->tb6_root, rt, info, extack); + spin_unlock_bh(&table->tb6_lock); + + return err; +} + +int ip6_ins_rt(struct net *net, struct fib6_info *rt) +{ + struct nl_info info = { .nl_net = net, }; + + return __ip6_ins_rt(rt, &info, NULL); +} + +static struct rt6_info *ip6_rt_cache_alloc(const struct fib6_result *res, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + struct fib6_info *f6i = res->f6i; + struct net_device *dev; + struct rt6_info *rt; + + /* + * Clone the route. + */ + + if (!fib6_info_hold_safe(f6i)) + return NULL; + + dev = ip6_rt_get_dev_rcu(res); + rt = ip6_dst_alloc(dev_net(dev), dev, 0); + if (!rt) { + fib6_info_release(f6i); + return NULL; + } + + ip6_rt_copy_init(rt, res); + rt->rt6i_flags |= RTF_CACHE; + rt->rt6i_dst.addr = *daddr; + rt->rt6i_dst.plen = 128; + + if (!rt6_is_gw_or_nonexthop(res)) { + if (f6i->fib6_dst.plen != 128 && + ipv6_addr_equal(&f6i->fib6_dst.addr, daddr)) + rt->rt6i_flags |= RTF_ANYCAST; +#ifdef CONFIG_IPV6_SUBTREES + if (rt->rt6i_src.plen && saddr) { + rt->rt6i_src.addr = *saddr; + rt->rt6i_src.plen = 128; + } +#endif + } + + return rt; +} + +static struct rt6_info *ip6_rt_pcpu_alloc(const struct fib6_result *res) +{ + struct fib6_info *f6i = res->f6i; + unsigned short flags = fib6_info_dst_flags(f6i); + struct net_device *dev; + struct rt6_info *pcpu_rt; + + if (!fib6_info_hold_safe(f6i)) + return NULL; + + rcu_read_lock(); + dev = ip6_rt_get_dev_rcu(res); + pcpu_rt = ip6_dst_alloc(dev_net(dev), dev, flags | DST_NOCOUNT); + rcu_read_unlock(); + if (!pcpu_rt) { + fib6_info_release(f6i); + return NULL; + } + ip6_rt_copy_init(pcpu_rt, res); + pcpu_rt->rt6i_flags |= RTF_PCPU; + + if (f6i->nh) + pcpu_rt->sernum = rt_genid_ipv6(dev_net(dev)); + + return pcpu_rt; +} + +static bool rt6_is_valid(const struct rt6_info *rt6) +{ + return rt6->sernum == rt_genid_ipv6(dev_net(rt6->dst.dev)); +} + +/* It should be called with rcu_read_lock() acquired */ +static struct rt6_info *rt6_get_pcpu_route(const struct fib6_result *res) +{ + struct rt6_info *pcpu_rt; + + pcpu_rt = this_cpu_read(*res->nh->rt6i_pcpu); + + if (pcpu_rt && pcpu_rt->sernum && !rt6_is_valid(pcpu_rt)) { + struct rt6_info *prev, **p; + + p = this_cpu_ptr(res->nh->rt6i_pcpu); + prev = xchg(p, NULL); + if (prev) { + dst_dev_put(&prev->dst); + dst_release(&prev->dst); + } + + pcpu_rt = NULL; + } + + return pcpu_rt; +} + +static struct rt6_info *rt6_make_pcpu_route(struct net *net, + const struct fib6_result *res) +{ + struct rt6_info *pcpu_rt, *prev, **p; + + pcpu_rt = ip6_rt_pcpu_alloc(res); + if (!pcpu_rt) + return NULL; + + p = this_cpu_ptr(res->nh->rt6i_pcpu); + prev = cmpxchg(p, NULL, pcpu_rt); + BUG_ON(prev); + + if (res->f6i->fib6_destroying) { + struct fib6_info *from; + + from = xchg((__force struct fib6_info **)&pcpu_rt->from, NULL); + 
fib6_info_release(from); + } + + return pcpu_rt; +} + +/* exception hash table implementation + */ +static DEFINE_SPINLOCK(rt6_exception_lock); + +/* Remove rt6_ex from hash table and free the memory + * Caller must hold rt6_exception_lock + */ +static void rt6_remove_exception(struct rt6_exception_bucket *bucket, + struct rt6_exception *rt6_ex) +{ + struct fib6_info *from; + struct net *net; + + if (!bucket || !rt6_ex) + return; + + net = dev_net(rt6_ex->rt6i->dst.dev); + net->ipv6.rt6_stats->fib_rt_cache--; + + /* purge completely the exception to allow releasing the held resources: + * some [sk] cache may keep the dst around for unlimited time + */ + from = xchg((__force struct fib6_info **)&rt6_ex->rt6i->from, NULL); + fib6_info_release(from); + dst_dev_put(&rt6_ex->rt6i->dst); + + hlist_del_rcu(&rt6_ex->hlist); + dst_release(&rt6_ex->rt6i->dst); + kfree_rcu(rt6_ex, rcu); + WARN_ON_ONCE(!bucket->depth); + bucket->depth--; +} + +/* Remove oldest rt6_ex in bucket and free the memory + * Caller must hold rt6_exception_lock + */ +static void rt6_exception_remove_oldest(struct rt6_exception_bucket *bucket) +{ + struct rt6_exception *rt6_ex, *oldest = NULL; + + if (!bucket) + return; + + hlist_for_each_entry(rt6_ex, &bucket->chain, hlist) { + if (!oldest || time_before(rt6_ex->stamp, oldest->stamp)) + oldest = rt6_ex; + } + rt6_remove_exception(bucket, oldest); +} + +static u32 rt6_exception_hash(const struct in6_addr *dst, + const struct in6_addr *src) +{ + static siphash_aligned_key_t rt6_exception_key; + struct { + struct in6_addr dst; + struct in6_addr src; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .dst = *dst, + }; + u64 val; + + net_get_random_once(&rt6_exception_key, sizeof(rt6_exception_key)); + +#ifdef CONFIG_IPV6_SUBTREES + if (src) + combined.src = *src; +#endif + val = siphash(&combined, sizeof(combined), &rt6_exception_key); + + return hash_64(val, FIB6_EXCEPTION_BUCKET_SIZE_SHIFT); +} + +/* Helper function to find the cached rt in the hash table + * and update bucket pointer to point to the bucket for this + * (daddr, saddr) pair + * Caller must hold rt6_exception_lock + */ +static struct rt6_exception * +__rt6_find_exception_spinlock(struct rt6_exception_bucket **bucket, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + struct rt6_exception *rt6_ex; + u32 hval; + + if (!(*bucket) || !daddr) + return NULL; + + hval = rt6_exception_hash(daddr, saddr); + *bucket += hval; + + hlist_for_each_entry(rt6_ex, &(*bucket)->chain, hlist) { + struct rt6_info *rt6 = rt6_ex->rt6i; + bool matched = ipv6_addr_equal(daddr, &rt6->rt6i_dst.addr); + +#ifdef CONFIG_IPV6_SUBTREES + if (matched && saddr) + matched = ipv6_addr_equal(saddr, &rt6->rt6i_src.addr); +#endif + if (matched) + return rt6_ex; + } + return NULL; +} + +/* Helper function to find the cached rt in the hash table + * and update bucket pointer to point to the bucket for this + * (daddr, saddr) pair + * Caller must hold rcu_read_lock() + */ +static struct rt6_exception * +__rt6_find_exception_rcu(struct rt6_exception_bucket **bucket, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + struct rt6_exception *rt6_ex; + u32 hval; + + WARN_ON_ONCE(!rcu_read_lock_held()); + + if (!(*bucket) || !daddr) + return NULL; + + hval = rt6_exception_hash(daddr, saddr); + *bucket += hval; + + hlist_for_each_entry_rcu(rt6_ex, &(*bucket)->chain, hlist) { + struct rt6_info *rt6 = rt6_ex->rt6i; + bool matched = ipv6_addr_equal(daddr, &rt6->rt6i_dst.addr); + +#ifdef CONFIG_IPV6_SUBTREES + if (matched && saddr) + 
matched = ipv6_addr_equal(saddr, &rt6->rt6i_src.addr); +#endif + if (matched) + return rt6_ex; + } + return NULL; +} + +static unsigned int fib6_mtu(const struct fib6_result *res) +{ + const struct fib6_nh *nh = res->nh; + unsigned int mtu; + + if (res->f6i->fib6_pmtu) { + mtu = res->f6i->fib6_pmtu; + } else { + struct net_device *dev = nh->fib_nh_dev; + struct inet6_dev *idev; + + rcu_read_lock(); + idev = __in6_dev_get(dev); + mtu = idev->cnf.mtu6; + rcu_read_unlock(); + } + + mtu = min_t(unsigned int, mtu, IP6_MAX_MTU); + + return mtu - lwtunnel_headroom(nh->fib_nh_lws, mtu); +} + +#define FIB6_EXCEPTION_BUCKET_FLUSHED 0x1UL + +/* used when the flushed bit is not relevant, only access to the bucket + * (ie., all bucket users except rt6_insert_exception); + * + * called under rcu lock; sometimes called with rt6_exception_lock held + */ +static +struct rt6_exception_bucket *fib6_nh_get_excptn_bucket(const struct fib6_nh *nh, + spinlock_t *lock) +{ + struct rt6_exception_bucket *bucket; + + if (lock) + bucket = rcu_dereference_protected(nh->rt6i_exception_bucket, + lockdep_is_held(lock)); + else + bucket = rcu_dereference(nh->rt6i_exception_bucket); + + /* remove bucket flushed bit if set */ + if (bucket) { + unsigned long p = (unsigned long)bucket; + + p &= ~FIB6_EXCEPTION_BUCKET_FLUSHED; + bucket = (struct rt6_exception_bucket *)p; + } + + return bucket; +} + +static bool fib6_nh_excptn_bucket_flushed(struct rt6_exception_bucket *bucket) +{ + unsigned long p = (unsigned long)bucket; + + return !!(p & FIB6_EXCEPTION_BUCKET_FLUSHED); +} + +/* called with rt6_exception_lock held */ +static void fib6_nh_excptn_bucket_set_flushed(struct fib6_nh *nh, + spinlock_t *lock) +{ + struct rt6_exception_bucket *bucket; + unsigned long p; + + bucket = rcu_dereference_protected(nh->rt6i_exception_bucket, + lockdep_is_held(lock)); + + p = (unsigned long)bucket; + p |= FIB6_EXCEPTION_BUCKET_FLUSHED; + bucket = (struct rt6_exception_bucket *)p; + rcu_assign_pointer(nh->rt6i_exception_bucket, bucket); +} + +static int rt6_insert_exception(struct rt6_info *nrt, + const struct fib6_result *res) +{ + struct net *net = dev_net(nrt->dst.dev); + struct rt6_exception_bucket *bucket; + struct fib6_info *f6i = res->f6i; + struct in6_addr *src_key = NULL; + struct rt6_exception *rt6_ex; + struct fib6_nh *nh = res->nh; + int max_depth; + int err = 0; + + spin_lock_bh(&rt6_exception_lock); + + bucket = rcu_dereference_protected(nh->rt6i_exception_bucket, + lockdep_is_held(&rt6_exception_lock)); + if (!bucket) { + bucket = kcalloc(FIB6_EXCEPTION_BUCKET_SIZE, sizeof(*bucket), + GFP_ATOMIC); + if (!bucket) { + err = -ENOMEM; + goto out; + } + rcu_assign_pointer(nh->rt6i_exception_bucket, bucket); + } else if (fib6_nh_excptn_bucket_flushed(bucket)) { + err = -EINVAL; + goto out; + } + +#ifdef CONFIG_IPV6_SUBTREES + /* fib6_src.plen != 0 indicates f6i is in subtree + * and exception table is indexed by a hash of + * both fib6_dst and fib6_src. + * Otherwise, the exception table is indexed by + * a hash of only fib6_dst. + */ + if (f6i->fib6_src.plen) + src_key = &nrt->rt6i_src.addr; +#endif + /* rt6_mtu_change() might lower mtu on f6i. + * Only insert this exception route if its mtu + * is less than f6i's mtu value. 
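+	 * (The check below therefore rejects the clone when its cached
+	 * RTAX_MTU is not smaller than fib6_mtu(res).)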
+ */ + if (dst_metric_raw(&nrt->dst, RTAX_MTU) >= fib6_mtu(res)) { + err = -EINVAL; + goto out; + } + + rt6_ex = __rt6_find_exception_spinlock(&bucket, &nrt->rt6i_dst.addr, + src_key); + if (rt6_ex) + rt6_remove_exception(bucket, rt6_ex); + + rt6_ex = kzalloc(sizeof(*rt6_ex), GFP_ATOMIC); + if (!rt6_ex) { + err = -ENOMEM; + goto out; + } + rt6_ex->rt6i = nrt; + rt6_ex->stamp = jiffies; + hlist_add_head_rcu(&rt6_ex->hlist, &bucket->chain); + bucket->depth++; + net->ipv6.rt6_stats->fib_rt_cache++; + + /* Randomize max depth to avoid some side channels attacks. */ + max_depth = FIB6_MAX_DEPTH + prandom_u32_max(FIB6_MAX_DEPTH); + while (bucket->depth > max_depth) + rt6_exception_remove_oldest(bucket); + +out: + spin_unlock_bh(&rt6_exception_lock); + + /* Update fn->fn_sernum to invalidate all cached dst */ + if (!err) { + spin_lock_bh(&f6i->fib6_table->tb6_lock); + fib6_update_sernum(net, f6i); + spin_unlock_bh(&f6i->fib6_table->tb6_lock); + fib6_force_start_gc(net); + } + + return err; +} + +static void fib6_nh_flush_exceptions(struct fib6_nh *nh, struct fib6_info *from) +{ + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + struct hlist_node *tmp; + int i; + + spin_lock_bh(&rt6_exception_lock); + + bucket = fib6_nh_get_excptn_bucket(nh, &rt6_exception_lock); + if (!bucket) + goto out; + + /* Prevent rt6_insert_exception() to recreate the bucket list */ + if (!from) + fib6_nh_excptn_bucket_set_flushed(nh, &rt6_exception_lock); + + for (i = 0; i < FIB6_EXCEPTION_BUCKET_SIZE; i++) { + hlist_for_each_entry_safe(rt6_ex, tmp, &bucket->chain, hlist) { + if (!from || + rcu_access_pointer(rt6_ex->rt6i->from) == from) + rt6_remove_exception(bucket, rt6_ex); + } + WARN_ON_ONCE(!from && bucket->depth); + bucket++; + } +out: + spin_unlock_bh(&rt6_exception_lock); +} + +static int rt6_nh_flush_exceptions(struct fib6_nh *nh, void *arg) +{ + struct fib6_info *f6i = arg; + + fib6_nh_flush_exceptions(nh, f6i); + + return 0; +} + +void rt6_flush_exceptions(struct fib6_info *f6i) +{ + if (f6i->nh) + nexthop_for_each_fib6_nh(f6i->nh, rt6_nh_flush_exceptions, + f6i); + else + fib6_nh_flush_exceptions(f6i->fib6_nh, f6i); +} + +/* Find cached rt in the hash table inside passed in rt + * Caller has to hold rcu_read_lock() + */ +static struct rt6_info *rt6_find_cached_rt(const struct fib6_result *res, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + const struct in6_addr *src_key = NULL; + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + struct rt6_info *ret = NULL; + +#ifdef CONFIG_IPV6_SUBTREES + /* fib6i_src.plen != 0 indicates f6i is in subtree + * and exception table is indexed by a hash of + * both fib6_dst and fib6_src. + * However, the src addr used to create the hash + * might not be exactly the passed in saddr which + * is a /128 addr from the flow. + * So we need to use f6i->fib6_src to redo lookup + * if the passed in saddr does not find anything. + * (See the logic in ip6_rt_cache_alloc() on how + * rt->rt6i_src is updated.) 
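+	 * (That second attempt is the 'find_ex' retry below, which switches
+	 * src_key to &res->f6i->fib6_src.addr.)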
+ */ + if (res->f6i->fib6_src.plen) + src_key = saddr; +find_ex: +#endif + bucket = fib6_nh_get_excptn_bucket(res->nh, NULL); + rt6_ex = __rt6_find_exception_rcu(&bucket, daddr, src_key); + + if (rt6_ex && !rt6_check_expired(rt6_ex->rt6i)) + ret = rt6_ex->rt6i; + +#ifdef CONFIG_IPV6_SUBTREES + /* Use fib6_src as src_key and redo lookup */ + if (!ret && src_key && src_key != &res->f6i->fib6_src.addr) { + src_key = &res->f6i->fib6_src.addr; + goto find_ex; + } +#endif + + return ret; +} + +/* Remove the passed in cached rt from the hash table that contains it */ +static int fib6_nh_remove_exception(const struct fib6_nh *nh, int plen, + const struct rt6_info *rt) +{ + const struct in6_addr *src_key = NULL; + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + int err; + + if (!rcu_access_pointer(nh->rt6i_exception_bucket)) + return -ENOENT; + + spin_lock_bh(&rt6_exception_lock); + bucket = fib6_nh_get_excptn_bucket(nh, &rt6_exception_lock); + +#ifdef CONFIG_IPV6_SUBTREES + /* rt6i_src.plen != 0 indicates 'from' is in subtree + * and exception table is indexed by a hash of + * both rt6i_dst and rt6i_src. + * Otherwise, the exception table is indexed by + * a hash of only rt6i_dst. + */ + if (plen) + src_key = &rt->rt6i_src.addr; +#endif + rt6_ex = __rt6_find_exception_spinlock(&bucket, + &rt->rt6i_dst.addr, + src_key); + if (rt6_ex) { + rt6_remove_exception(bucket, rt6_ex); + err = 0; + } else { + err = -ENOENT; + } + + spin_unlock_bh(&rt6_exception_lock); + return err; +} + +struct fib6_nh_excptn_arg { + struct rt6_info *rt; + int plen; +}; + +static int rt6_nh_remove_exception_rt(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_excptn_arg *arg = _arg; + int err; + + err = fib6_nh_remove_exception(nh, arg->plen, arg->rt); + if (err == 0) + return 1; + + return 0; +} + +static int rt6_remove_exception_rt(struct rt6_info *rt) +{ + struct fib6_info *from; + + from = rcu_dereference(rt->from); + if (!from || !(rt->rt6i_flags & RTF_CACHE)) + return -EINVAL; + + if (from->nh) { + struct fib6_nh_excptn_arg arg = { + .rt = rt, + .plen = from->fib6_src.plen + }; + int rc; + + /* rc = 1 means an entry was found */ + rc = nexthop_for_each_fib6_nh(from->nh, + rt6_nh_remove_exception_rt, + &arg); + return rc ? 0 : -ENOENT; + } + + return fib6_nh_remove_exception(from->fib6_nh, + from->fib6_src.plen, rt); +} + +/* Find rt6_ex which contains the passed in rt cache and + * refresh its stamp + */ +static void fib6_nh_update_exception(const struct fib6_nh *nh, int plen, + const struct rt6_info *rt) +{ + const struct in6_addr *src_key = NULL; + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + + bucket = fib6_nh_get_excptn_bucket(nh, NULL); +#ifdef CONFIG_IPV6_SUBTREES + /* rt6i_src.plen != 0 indicates 'from' is in subtree + * and exception table is indexed by a hash of + * both rt6i_dst and rt6i_src. + * Otherwise, the exception table is indexed by + * a hash of only rt6i_dst. 
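+	 * (plen here is the owning fib6_info's fib6_src.plen; see the caller
+	 * rt6_update_exception_stamp_rt() below.)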
+ */ + if (plen) + src_key = &rt->rt6i_src.addr; +#endif + rt6_ex = __rt6_find_exception_rcu(&bucket, &rt->rt6i_dst.addr, src_key); + if (rt6_ex) + rt6_ex->stamp = jiffies; +} + +struct fib6_nh_match_arg { + const struct net_device *dev; + const struct in6_addr *gw; + struct fib6_nh *match; +}; + +/* determine if fib6_nh has given device and gateway */ +static int fib6_nh_find_match(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_match_arg *arg = _arg; + + if (arg->dev != nh->fib_nh_dev || + (arg->gw && !nh->fib_nh_gw_family) || + (!arg->gw && nh->fib_nh_gw_family) || + (arg->gw && !ipv6_addr_equal(arg->gw, &nh->fib_nh_gw6))) + return 0; + + arg->match = nh; + + /* found a match, break the loop */ + return 1; +} + +static void rt6_update_exception_stamp_rt(struct rt6_info *rt) +{ + struct fib6_info *from; + struct fib6_nh *fib6_nh; + + rcu_read_lock(); + + from = rcu_dereference(rt->from); + if (!from || !(rt->rt6i_flags & RTF_CACHE)) + goto unlock; + + if (from->nh) { + struct fib6_nh_match_arg arg = { + .dev = rt->dst.dev, + .gw = &rt->rt6i_gateway, + }; + + nexthop_for_each_fib6_nh(from->nh, fib6_nh_find_match, &arg); + + if (!arg.match) + goto unlock; + fib6_nh = arg.match; + } else { + fib6_nh = from->fib6_nh; + } + fib6_nh_update_exception(fib6_nh, from->fib6_src.plen, rt); +unlock: + rcu_read_unlock(); +} + +static bool rt6_mtu_change_route_allowed(struct inet6_dev *idev, + struct rt6_info *rt, int mtu) +{ + /* If the new MTU is lower than the route PMTU, this new MTU will be the + * lowest MTU in the path: always allow updating the route PMTU to + * reflect PMTU decreases. + * + * If the new MTU is higher, and the route PMTU is equal to the local + * MTU, this means the old MTU is the lowest in the path, so allow + * updating it: if other nodes now have lower MTUs, PMTU discovery will + * handle this. + */ + + if (dst_mtu(&rt->dst) >= mtu) + return true; + + if (dst_mtu(&rt->dst) == idev->cnf.mtu6) + return true; + + return false; +} + +static void rt6_exceptions_update_pmtu(struct inet6_dev *idev, + const struct fib6_nh *nh, int mtu) +{ + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + int i; + + bucket = fib6_nh_get_excptn_bucket(nh, &rt6_exception_lock); + if (!bucket) + return; + + for (i = 0; i < FIB6_EXCEPTION_BUCKET_SIZE; i++) { + hlist_for_each_entry(rt6_ex, &bucket->chain, hlist) { + struct rt6_info *entry = rt6_ex->rt6i; + + /* For RTF_CACHE with rt6i_pmtu == 0 (i.e. a redirected + * route), the metrics of its rt->from have already + * been updated. 
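+	 * Only exceptions that carry their own RTAX_MTU metric, and for which
+	 * rt6_mtu_change_route_allowed() permits it, are updated below.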
+ */ + if (dst_metric_raw(&entry->dst, RTAX_MTU) && + rt6_mtu_change_route_allowed(idev, entry, mtu)) + dst_metric_set(&entry->dst, RTAX_MTU, mtu); + } + bucket++; + } +} + +#define RTF_CACHE_GATEWAY (RTF_GATEWAY | RTF_CACHE) + +static void fib6_nh_exceptions_clean_tohost(const struct fib6_nh *nh, + const struct in6_addr *gateway) +{ + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + struct hlist_node *tmp; + int i; + + if (!rcu_access_pointer(nh->rt6i_exception_bucket)) + return; + + spin_lock_bh(&rt6_exception_lock); + bucket = fib6_nh_get_excptn_bucket(nh, &rt6_exception_lock); + if (bucket) { + for (i = 0; i < FIB6_EXCEPTION_BUCKET_SIZE; i++) { + hlist_for_each_entry_safe(rt6_ex, tmp, + &bucket->chain, hlist) { + struct rt6_info *entry = rt6_ex->rt6i; + + if ((entry->rt6i_flags & RTF_CACHE_GATEWAY) == + RTF_CACHE_GATEWAY && + ipv6_addr_equal(gateway, + &entry->rt6i_gateway)) { + rt6_remove_exception(bucket, rt6_ex); + } + } + bucket++; + } + } + + spin_unlock_bh(&rt6_exception_lock); +} + +static void rt6_age_examine_exception(struct rt6_exception_bucket *bucket, + struct rt6_exception *rt6_ex, + struct fib6_gc_args *gc_args, + unsigned long now) +{ + struct rt6_info *rt = rt6_ex->rt6i; + + /* we are pruning and obsoleting aged-out and non gateway exceptions + * even if others have still references to them, so that on next + * dst_check() such references can be dropped. + * EXPIRES exceptions - e.g. pmtu-generated ones are pruned when + * expired, independently from their aging, as per RFC 8201 section 4 + */ + if (!(rt->rt6i_flags & RTF_EXPIRES)) { + if (time_after_eq(now, rt->dst.lastuse + gc_args->timeout)) { + RT6_TRACE("aging clone %p\n", rt); + rt6_remove_exception(bucket, rt6_ex); + return; + } + } else if (time_after(jiffies, rt->dst.expires)) { + RT6_TRACE("purging expired route %p\n", rt); + rt6_remove_exception(bucket, rt6_ex); + return; + } + + if (rt->rt6i_flags & RTF_GATEWAY) { + struct neighbour *neigh; + + neigh = __ipv6_neigh_lookup_noref(rt->dst.dev, &rt->rt6i_gateway); + + if (!(neigh && (neigh->flags & NTF_ROUTER))) { + RT6_TRACE("purging route %p via non-router but gateway\n", + rt); + rt6_remove_exception(bucket, rt6_ex); + return; + } + } + + gc_args->more++; +} + +static void fib6_nh_age_exceptions(const struct fib6_nh *nh, + struct fib6_gc_args *gc_args, + unsigned long now) +{ + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + struct hlist_node *tmp; + int i; + + if (!rcu_access_pointer(nh->rt6i_exception_bucket)) + return; + + rcu_read_lock_bh(); + spin_lock(&rt6_exception_lock); + bucket = fib6_nh_get_excptn_bucket(nh, &rt6_exception_lock); + if (bucket) { + for (i = 0; i < FIB6_EXCEPTION_BUCKET_SIZE; i++) { + hlist_for_each_entry_safe(rt6_ex, tmp, + &bucket->chain, hlist) { + rt6_age_examine_exception(bucket, rt6_ex, + gc_args, now); + } + bucket++; + } + } + spin_unlock(&rt6_exception_lock); + rcu_read_unlock_bh(); +} + +struct fib6_nh_age_excptn_arg { + struct fib6_gc_args *gc_args; + unsigned long now; +}; + +static int rt6_nh_age_exceptions(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_age_excptn_arg *arg = _arg; + + fib6_nh_age_exceptions(nh, arg->gc_args, arg->now); + return 0; +} + +void rt6_age_exceptions(struct fib6_info *f6i, + struct fib6_gc_args *gc_args, + unsigned long now) +{ + if (f6i->nh) { + struct fib6_nh_age_excptn_arg arg = { + .gc_args = gc_args, + .now = now + }; + + nexthop_for_each_fib6_nh(f6i->nh, rt6_nh_age_exceptions, + &arg); + } else { + fib6_nh_age_exceptions(f6i->fib6_nh, 
gc_args, now); + } +} + +/* must be called with rcu lock held */ +int fib6_table_lookup(struct net *net, struct fib6_table *table, int oif, + struct flowi6 *fl6, struct fib6_result *res, int strict) +{ + struct fib6_node *fn, *saved_fn; + + fn = fib6_node_lookup(&table->tb6_root, &fl6->daddr, &fl6->saddr); + saved_fn = fn; + +redo_rt6_select: + rt6_select(net, fn, oif, res, strict); + if (res->f6i == net->ipv6.fib6_null_entry) { + fn = fib6_backtrack(fn, &fl6->saddr); + if (fn) + goto redo_rt6_select; + else if (strict & RT6_LOOKUP_F_REACHABLE) { + /* also consider unreachable route */ + strict &= ~RT6_LOOKUP_F_REACHABLE; + fn = saved_fn; + goto redo_rt6_select; + } + } + + trace_fib6_table_lookup(net, res, table, fl6); + + return 0; +} + +struct rt6_info *ip6_pol_route(struct net *net, struct fib6_table *table, + int oif, struct flowi6 *fl6, + const struct sk_buff *skb, int flags) +{ + struct fib6_result res = {}; + struct rt6_info *rt = NULL; + int strict = 0; + + WARN_ON_ONCE((flags & RT6_LOOKUP_F_DST_NOREF) && + !rcu_read_lock_held()); + + strict |= flags & RT6_LOOKUP_F_IFACE; + strict |= flags & RT6_LOOKUP_F_IGNORE_LINKSTATE; + if (net->ipv6.devconf_all->forwarding == 0) + strict |= RT6_LOOKUP_F_REACHABLE; + + rcu_read_lock(); + + fib6_table_lookup(net, table, oif, fl6, &res, strict); + if (res.f6i == net->ipv6.fib6_null_entry) + goto out; + + fib6_select_path(net, &res, fl6, oif, false, skb, strict); + + /*Search through exception table */ + rt = rt6_find_cached_rt(&res, &fl6->daddr, &fl6->saddr); + if (rt) { + goto out; + } else if (unlikely((fl6->flowi6_flags & FLOWI_FLAG_KNOWN_NH) && + !res.nh->fib_nh_gw_family)) { + /* Create a RTF_CACHE clone which will not be + * owned by the fib6 tree. It is for the special case where + * the daddr in the skb during the neighbor look-up is different + * from the fl6->daddr used to look-up route here. + */ + rt = ip6_rt_cache_alloc(&res, &fl6->daddr, NULL); + + if (rt) { + /* 1 refcnt is taken during ip6_rt_cache_alloc(). + * As rt6_uncached_list_add() does not consume refcnt, + * this refcnt is always returned to the caller even + * if caller sets RT6_LOOKUP_F_DST_NOREF flag. 
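+	 * Hence the early return below skips the conditional ip6_hold_safe()
+	 * done at the 'out' label.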
+ */ + rt6_uncached_list_add(rt); + rcu_read_unlock(); + + return rt; + } + } else { + /* Get a percpu copy */ + local_bh_disable(); + rt = rt6_get_pcpu_route(&res); + + if (!rt) + rt = rt6_make_pcpu_route(net, &res); + + local_bh_enable(); + } +out: + if (!rt) + rt = net->ipv6.ip6_null_entry; + if (!(flags & RT6_LOOKUP_F_DST_NOREF)) + ip6_hold_safe(net, &rt); + rcu_read_unlock(); + + return rt; +} +EXPORT_SYMBOL_GPL(ip6_pol_route); + +INDIRECT_CALLABLE_SCOPE struct rt6_info *ip6_pol_route_input(struct net *net, + struct fib6_table *table, + struct flowi6 *fl6, + const struct sk_buff *skb, + int flags) +{ + return ip6_pol_route(net, table, fl6->flowi6_iif, fl6, skb, flags); +} + +struct dst_entry *ip6_route_input_lookup(struct net *net, + struct net_device *dev, + struct flowi6 *fl6, + const struct sk_buff *skb, + int flags) +{ + if (rt6_need_strict(&fl6->daddr) && dev->type != ARPHRD_PIMREG) + flags |= RT6_LOOKUP_F_IFACE; + + return fib6_rule_lookup(net, fl6, skb, flags, ip6_pol_route_input); +} +EXPORT_SYMBOL_GPL(ip6_route_input_lookup); + +static void ip6_multipath_l3_keys(const struct sk_buff *skb, + struct flow_keys *keys, + struct flow_keys *flkeys) +{ + const struct ipv6hdr *outer_iph = ipv6_hdr(skb); + const struct ipv6hdr *key_iph = outer_iph; + struct flow_keys *_flkeys = flkeys; + const struct ipv6hdr *inner_iph; + const struct icmp6hdr *icmph; + struct ipv6hdr _inner_iph; + struct icmp6hdr _icmph; + + if (likely(outer_iph->nexthdr != IPPROTO_ICMPV6)) + goto out; + + icmph = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_icmph), &_icmph); + if (!icmph) + goto out; + + if (!icmpv6_is_err(icmph->icmp6_type)) + goto out; + + inner_iph = skb_header_pointer(skb, + skb_transport_offset(skb) + sizeof(*icmph), + sizeof(_inner_iph), &_inner_iph); + if (!inner_iph) + goto out; + + key_iph = inner_iph; + _flkeys = NULL; +out: + if (_flkeys) { + keys->addrs.v6addrs.src = _flkeys->addrs.v6addrs.src; + keys->addrs.v6addrs.dst = _flkeys->addrs.v6addrs.dst; + keys->tags.flow_label = _flkeys->tags.flow_label; + keys->basic.ip_proto = _flkeys->basic.ip_proto; + } else { + keys->addrs.v6addrs.src = key_iph->saddr; + keys->addrs.v6addrs.dst = key_iph->daddr; + keys->tags.flow_label = ip6_flowlabel(key_iph); + keys->basic.ip_proto = key_iph->nexthdr; + } +} + +static u32 rt6_multipath_custom_hash_outer(const struct net *net, + const struct sk_buff *skb, + bool *p_has_inner) +{ + u32 hash_fields = ip6_multipath_hash_fields(net); + struct flow_keys keys, hash_keys; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + skb_flow_dissect_flow_keys(skb, &keys, FLOW_DISSECTOR_F_STOP_AT_ENCAP); + + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_IP) + hash_keys.addrs.v6addrs.src = keys.addrs.v6addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_IP) + hash_keys.addrs.v6addrs.dst = keys.addrs.v6addrs.dst; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_IP_PROTO) + hash_keys.basic.ip_proto = keys.basic.ip_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_FLOWLABEL) + hash_keys.tags.flow_label = keys.tags.flow_label; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) + hash_keys.ports.src = keys.ports.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_PORT) + hash_keys.ports.dst = keys.ports.dst; + + *p_has_inner = !!(keys.control.flags & FLOW_DIS_ENCAPSULATION); + return flow_hash_from_keys(&hash_keys); +} + +static u32 
rt6_multipath_custom_hash_inner(const struct net *net, + const struct sk_buff *skb, + bool has_inner) +{ + u32 hash_fields = ip6_multipath_hash_fields(net); + struct flow_keys keys, hash_keys; + + /* We assume the packet carries an encapsulation, but if none was + * encountered during dissection of the outer flow, then there is no + * point in calling the flow dissector again. + */ + if (!has_inner) + return 0; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + skb_flow_dissect_flow_keys(skb, &keys, 0); + + if (!(keys.control.flags & FLOW_DIS_ENCAPSULATION)) + return 0; + + if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_IP) + hash_keys.addrs.v4addrs.src = keys.addrs.v4addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_IP) + hash_keys.addrs.v4addrs.dst = keys.addrs.v4addrs.dst; + } else if (keys.control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_IP) + hash_keys.addrs.v6addrs.src = keys.addrs.v6addrs.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_IP) + hash_keys.addrs.v6addrs.dst = keys.addrs.v6addrs.dst; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_FLOWLABEL) + hash_keys.tags.flow_label = keys.tags.flow_label; + } + + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_IP_PROTO) + hash_keys.basic.ip_proto = keys.basic.ip_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_SRC_PORT) + hash_keys.ports.src = keys.ports.src; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_INNER_DST_PORT) + hash_keys.ports.dst = keys.ports.dst; + + return flow_hash_from_keys(&hash_keys); +} + +static u32 rt6_multipath_custom_hash_skb(const struct net *net, + const struct sk_buff *skb) +{ + u32 mhash, mhash_inner; + bool has_inner = true; + + mhash = rt6_multipath_custom_hash_outer(net, skb, &has_inner); + mhash_inner = rt6_multipath_custom_hash_inner(net, skb, has_inner); + + return jhash_2words(mhash, mhash_inner, 0); +} + +static u32 rt6_multipath_custom_hash_fl6(const struct net *net, + const struct flowi6 *fl6) +{ + u32 hash_fields = ip6_multipath_hash_fields(net); + struct flow_keys hash_keys; + + if (!(hash_fields & FIB_MULTIPATH_HASH_FIELD_OUTER_MASK)) + return 0; + + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_IP) + hash_keys.addrs.v6addrs.src = fl6->saddr; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_IP) + hash_keys.addrs.v6addrs.dst = fl6->daddr; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_IP_PROTO) + hash_keys.basic.ip_proto = fl6->flowi6_proto; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_FLOWLABEL) + hash_keys.tags.flow_label = (__force u32)flowi6_get_flowlabel(fl6); + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_SRC_PORT) + hash_keys.ports.src = fl6->fl6_sport; + if (hash_fields & FIB_MULTIPATH_HASH_FIELD_DST_PORT) + hash_keys.ports.dst = fl6->fl6_dport; + + return flow_hash_from_keys(&hash_keys); +} + +/* if skb is set it will be used and fl6 can be NULL */ +u32 rt6_multipath_hash(const struct net *net, const struct flowi6 *fl6, + const struct sk_buff *skb, struct flow_keys *flkeys) +{ + struct flow_keys hash_keys; + u32 mhash = 0; + + switch (ip6_multipath_hash_policy(net)) { + case 0: + memset(&hash_keys, 0, sizeof(hash_keys)); + 
hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (skb) { + ip6_multipath_l3_keys(skb, &hash_keys, flkeys); + } else { + hash_keys.addrs.v6addrs.src = fl6->saddr; + hash_keys.addrs.v6addrs.dst = fl6->daddr; + hash_keys.tags.flow_label = (__force u32)flowi6_get_flowlabel(fl6); + hash_keys.basic.ip_proto = fl6->flowi6_proto; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 1: + if (skb) { + unsigned int flag = FLOW_DISSECTOR_F_STOP_AT_ENCAP; + struct flow_keys keys; + + /* short-circuit if we already have L4 hash present */ + if (skb->l4_hash) + return skb_get_hash_raw(skb) >> 1; + + memset(&hash_keys, 0, sizeof(hash_keys)); + + if (!flkeys) { + skb_flow_dissect_flow_keys(skb, &keys, flag); + flkeys = &keys; + } + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + hash_keys.addrs.v6addrs.src = flkeys->addrs.v6addrs.src; + hash_keys.addrs.v6addrs.dst = flkeys->addrs.v6addrs.dst; + hash_keys.ports.src = flkeys->ports.src; + hash_keys.ports.dst = flkeys->ports.dst; + hash_keys.basic.ip_proto = flkeys->basic.ip_proto; + } else { + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + hash_keys.addrs.v6addrs.src = fl6->saddr; + hash_keys.addrs.v6addrs.dst = fl6->daddr; + hash_keys.ports.src = fl6->fl6_sport; + hash_keys.ports.dst = fl6->fl6_dport; + hash_keys.basic.ip_proto = fl6->flowi6_proto; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 2: + memset(&hash_keys, 0, sizeof(hash_keys)); + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + if (skb) { + struct flow_keys keys; + + if (!flkeys) { + skb_flow_dissect_flow_keys(skb, &keys, 0); + flkeys = &keys; + } + + /* Inner can be v4 or v6 */ + if (flkeys->control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + hash_keys.addrs.v4addrs.src = flkeys->addrs.v4addrs.src; + hash_keys.addrs.v4addrs.dst = flkeys->addrs.v4addrs.dst; + } else if (flkeys->control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS) { + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + hash_keys.addrs.v6addrs.src = flkeys->addrs.v6addrs.src; + hash_keys.addrs.v6addrs.dst = flkeys->addrs.v6addrs.dst; + hash_keys.tags.flow_label = flkeys->tags.flow_label; + hash_keys.basic.ip_proto = flkeys->basic.ip_proto; + } else { + /* Same as case 0 */ + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + ip6_multipath_l3_keys(skb, &hash_keys, flkeys); + } + } else { + /* Same as case 0 */ + hash_keys.control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + hash_keys.addrs.v6addrs.src = fl6->saddr; + hash_keys.addrs.v6addrs.dst = fl6->daddr; + hash_keys.tags.flow_label = (__force u32)flowi6_get_flowlabel(fl6); + hash_keys.basic.ip_proto = fl6->flowi6_proto; + } + mhash = flow_hash_from_keys(&hash_keys); + break; + case 3: + if (skb) + mhash = rt6_multipath_custom_hash_skb(net, skb); + else + mhash = rt6_multipath_custom_hash_fl6(net, fl6); + break; + } + + return mhash >> 1; +} + +/* Called with rcu held */ +void ip6_route_input(struct sk_buff *skb) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct net *net = dev_net(skb->dev); + int flags = RT6_LOOKUP_F_HAS_SADDR | RT6_LOOKUP_F_DST_NOREF; + struct ip_tunnel_info *tun_info; + struct flowi6 fl6 = { + .flowi6_iif = skb->dev->ifindex, + .daddr = iph->daddr, + .saddr = iph->saddr, + .flowlabel = ip6_flowinfo(iph), + .flowi6_mark = skb->mark, + .flowi6_proto = iph->nexthdr, + }; + struct flow_keys *flkeys = NULL, _flkeys; + + tun_info = 
skb_tunnel_info(skb); + if (tun_info && !(tun_info->mode & IP_TUNNEL_INFO_TX)) + fl6.flowi6_tun_key.tun_id = tun_info->key.tun_id; + + if (fib6_rules_early_flow_dissect(net, skb, &fl6, &_flkeys)) + flkeys = &_flkeys; + + if (unlikely(fl6.flowi6_proto == IPPROTO_ICMPV6)) + fl6.mp_hash = rt6_multipath_hash(net, &fl6, skb, flkeys); + skb_dst_drop(skb); + skb_dst_set_noref(skb, ip6_route_input_lookup(net, skb->dev, + &fl6, skb, flags)); +} + +INDIRECT_CALLABLE_SCOPE struct rt6_info *ip6_pol_route_output(struct net *net, + struct fib6_table *table, + struct flowi6 *fl6, + const struct sk_buff *skb, + int flags) +{ + return ip6_pol_route(net, table, fl6->flowi6_oif, fl6, skb, flags); +} + +struct dst_entry *ip6_route_output_flags_noref(struct net *net, + const struct sock *sk, + struct flowi6 *fl6, int flags) +{ + bool any_src; + + if (ipv6_addr_type(&fl6->daddr) & + (IPV6_ADDR_MULTICAST | IPV6_ADDR_LINKLOCAL)) { + struct dst_entry *dst; + + /* This function does not take refcnt on the dst */ + dst = l3mdev_link_scope_lookup(net, fl6); + if (dst) + return dst; + } + + fl6->flowi6_iif = LOOPBACK_IFINDEX; + + flags |= RT6_LOOKUP_F_DST_NOREF; + any_src = ipv6_addr_any(&fl6->saddr); + if ((sk && sk->sk_bound_dev_if) || rt6_need_strict(&fl6->daddr) || + (fl6->flowi6_oif && any_src)) + flags |= RT6_LOOKUP_F_IFACE; + + if (!any_src) + flags |= RT6_LOOKUP_F_HAS_SADDR; + else if (sk) + flags |= rt6_srcprefs2flags(inet6_sk(sk)->srcprefs); + + return fib6_rule_lookup(net, fl6, NULL, flags, ip6_pol_route_output); +} +EXPORT_SYMBOL_GPL(ip6_route_output_flags_noref); + +struct dst_entry *ip6_route_output_flags(struct net *net, + const struct sock *sk, + struct flowi6 *fl6, + int flags) +{ + struct dst_entry *dst; + struct rt6_info *rt6; + + rcu_read_lock(); + dst = ip6_route_output_flags_noref(net, sk, fl6, flags); + rt6 = (struct rt6_info *)dst; + /* For dst cached in uncached_list, refcnt is already taken. */ + if (list_empty(&rt6->rt6i_uncached) && !dst_hold_safe(dst)) { + dst = &net->ipv6.ip6_null_entry->dst; + dst_hold(dst); + } + rcu_read_unlock(); + + return dst; +} +EXPORT_SYMBOL_GPL(ip6_route_output_flags); + +struct dst_entry *ip6_blackhole_route(struct net *net, struct dst_entry *dst_orig) +{ + struct rt6_info *rt, *ort = (struct rt6_info *) dst_orig; + struct net_device *loopback_dev = net->loopback_dev; + struct dst_entry *new = NULL; + + rt = dst_alloc(&ip6_dst_blackhole_ops, loopback_dev, 1, + DST_OBSOLETE_DEAD, 0); + if (rt) { + rt6_info_init(rt); + atomic_inc(&net->ipv6.rt6_stats->fib_rt_alloc); + + new = &rt->dst; + new->__use = 1; + new->input = dst_discard; + new->output = dst_discard_out; + + dst_copy_metrics(new, &ort->dst); + + rt->rt6i_idev = in6_dev_get(loopback_dev); + rt->rt6i_gateway = ort->rt6i_gateway; + rt->rt6i_flags = ort->rt6i_flags & ~RTF_PCPU; + + memcpy(&rt->rt6i_dst, &ort->rt6i_dst, sizeof(struct rt6key)); +#ifdef CONFIG_IPV6_SUBTREES + memcpy(&rt->rt6i_src, &ort->rt6i_src, sizeof(struct rt6key)); +#endif + } + + dst_release(dst_orig); + return new ? 
new : ERR_PTR(-ENOMEM); +} + +/* + * Destination cache support functions + */ + +static bool fib6_check(struct fib6_info *f6i, u32 cookie) +{ + u32 rt_cookie = 0; + + if (!fib6_get_cookie_safe(f6i, &rt_cookie) || rt_cookie != cookie) + return false; + + if (fib6_check_expired(f6i)) + return false; + + return true; +} + +static struct dst_entry *rt6_check(struct rt6_info *rt, + struct fib6_info *from, + u32 cookie) +{ + u32 rt_cookie = 0; + + if (!from || !fib6_get_cookie_safe(from, &rt_cookie) || + rt_cookie != cookie) + return NULL; + + if (rt6_check_expired(rt)) + return NULL; + + return &rt->dst; +} + +static struct dst_entry *rt6_dst_from_check(struct rt6_info *rt, + struct fib6_info *from, + u32 cookie) +{ + if (!__rt6_check_expired(rt) && + rt->dst.obsolete == DST_OBSOLETE_FORCE_CHK && + fib6_check(from, cookie)) + return &rt->dst; + else + return NULL; +} + +INDIRECT_CALLABLE_SCOPE struct dst_entry *ip6_dst_check(struct dst_entry *dst, + u32 cookie) +{ + struct dst_entry *dst_ret; + struct fib6_info *from; + struct rt6_info *rt; + + rt = container_of(dst, struct rt6_info, dst); + + if (rt->sernum) + return rt6_is_valid(rt) ? dst : NULL; + + rcu_read_lock(); + + /* All IPV6 dsts are created with ->obsolete set to the value + * DST_OBSOLETE_FORCE_CHK which forces validation calls down + * into this function always. + */ + + from = rcu_dereference(rt->from); + + if (from && (rt->rt6i_flags & RTF_PCPU || + unlikely(!list_empty(&rt->rt6i_uncached)))) + dst_ret = rt6_dst_from_check(rt, from, cookie); + else + dst_ret = rt6_check(rt, from, cookie); + + rcu_read_unlock(); + + return dst_ret; +} +EXPORT_INDIRECT_CALLABLE(ip6_dst_check); + +static struct dst_entry *ip6_negative_advice(struct dst_entry *dst) +{ + struct rt6_info *rt = (struct rt6_info *) dst; + + if (rt) { + if (rt->rt6i_flags & RTF_CACHE) { + rcu_read_lock(); + if (rt6_check_expired(rt)) { + rt6_remove_exception_rt(rt); + dst = NULL; + } + rcu_read_unlock(); + } else { + dst_release(dst); + dst = NULL; + } + } + return dst; +} + +static void ip6_link_failure(struct sk_buff *skb) +{ + struct rt6_info *rt; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_ADDR_UNREACH, 0); + + rt = (struct rt6_info *) skb_dst(skb); + if (rt) { + rcu_read_lock(); + if (rt->rt6i_flags & RTF_CACHE) { + rt6_remove_exception_rt(rt); + } else { + struct fib6_info *from; + struct fib6_node *fn; + + from = rcu_dereference(rt->from); + if (from) { + fn = rcu_dereference(from->fib6_node); + if (fn && (rt->rt6i_flags & RTF_DEFAULT)) + WRITE_ONCE(fn->fn_sernum, -1); + } + } + rcu_read_unlock(); + } +} + +static void rt6_update_expires(struct rt6_info *rt0, int timeout) +{ + if (!(rt0->rt6i_flags & RTF_EXPIRES)) { + struct fib6_info *from; + + rcu_read_lock(); + from = rcu_dereference(rt0->from); + if (from) + rt0->dst.expires = from->expires; + rcu_read_unlock(); + } + + dst_set_expires(&rt0->dst, timeout); + rt0->rt6i_flags |= RTF_EXPIRES; +} + +static void rt6_do_update_pmtu(struct rt6_info *rt, u32 mtu) +{ + struct net *net = dev_net(rt->dst.dev); + + dst_metric_set(&rt->dst, RTAX_MTU, mtu); + rt->rt6i_flags |= RTF_MODIFIED; + rt6_update_expires(rt, net->ipv6.sysctl.ip6_rt_mtu_expires); +} + +static bool rt6_cache_allowed_for_pmtu(const struct rt6_info *rt) +{ + return !(rt->rt6i_flags & RTF_CACHE) && + (rt->rt6i_flags & RTF_PCPU || rcu_access_pointer(rt->from)); +} + +static void __ip6_rt_update_pmtu(struct dst_entry *dst, const struct sock *sk, + const struct ipv6hdr *iph, u32 mtu, + bool confirm_neigh) +{ + const struct in6_addr *daddr, *saddr; + 
struct rt6_info *rt6 = (struct rt6_info *)dst; + + /* Note: do *NOT* check dst_metric_locked(dst, RTAX_MTU) + * IPv6 pmtu discovery isn't optional, so 'mtu lock' cannot disable it. + * [see also comment in rt6_mtu_change_route()] + */ + + if (iph) { + daddr = &iph->daddr; + saddr = &iph->saddr; + } else if (sk) { + daddr = &sk->sk_v6_daddr; + saddr = &inet6_sk(sk)->saddr; + } else { + daddr = NULL; + saddr = NULL; + } + + if (confirm_neigh) + dst_confirm_neigh(dst, daddr); + + if (mtu < IPV6_MIN_MTU) + return; + if (mtu >= dst_mtu(dst)) + return; + + if (!rt6_cache_allowed_for_pmtu(rt6)) { + rt6_do_update_pmtu(rt6, mtu); + /* update rt6_ex->stamp for cache */ + if (rt6->rt6i_flags & RTF_CACHE) + rt6_update_exception_stamp_rt(rt6); + } else if (daddr) { + struct fib6_result res = {}; + struct rt6_info *nrt6; + + rcu_read_lock(); + res.f6i = rcu_dereference(rt6->from); + if (!res.f6i) + goto out_unlock; + + res.fib6_flags = res.f6i->fib6_flags; + res.fib6_type = res.f6i->fib6_type; + + if (res.f6i->nh) { + struct fib6_nh_match_arg arg = { + .dev = dst->dev, + .gw = &rt6->rt6i_gateway, + }; + + nexthop_for_each_fib6_nh(res.f6i->nh, + fib6_nh_find_match, &arg); + + /* fib6_info uses a nexthop that does not have fib6_nh + * using the dst->dev + gw. Should be impossible. + */ + if (!arg.match) + goto out_unlock; + + res.nh = arg.match; + } else { + res.nh = res.f6i->fib6_nh; + } + + nrt6 = ip6_rt_cache_alloc(&res, daddr, saddr); + if (nrt6) { + rt6_do_update_pmtu(nrt6, mtu); + if (rt6_insert_exception(nrt6, &res)) + dst_release_immediate(&nrt6->dst); + } +out_unlock: + rcu_read_unlock(); + } +} + +static void ip6_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ + __ip6_rt_update_pmtu(dst, sk, skb ? ipv6_hdr(skb) : NULL, mtu, + confirm_neigh); +} + +void ip6_update_pmtu(struct sk_buff *skb, struct net *net, __be32 mtu, + int oif, u32 mark, kuid_t uid) +{ + const struct ipv6hdr *iph = (struct ipv6hdr *) skb->data; + struct dst_entry *dst; + struct flowi6 fl6 = { + .flowi6_oif = oif, + .flowi6_mark = mark ? mark : IP6_REPLY_MARK(net, skb->mark), + .daddr = iph->daddr, + .saddr = iph->saddr, + .flowlabel = ip6_flowinfo(iph), + .flowi6_uid = uid, + }; + + dst = ip6_route_output(net, NULL, &fl6); + if (!dst->error) + __ip6_rt_update_pmtu(dst, NULL, iph, ntohl(mtu), true); + dst_release(dst); +} +EXPORT_SYMBOL_GPL(ip6_update_pmtu); + +void ip6_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, __be32 mtu) +{ + int oif = sk->sk_bound_dev_if; + struct dst_entry *dst; + + if (!oif && skb->dev) + oif = l3mdev_master_ifindex(skb->dev); + + ip6_update_pmtu(skb, sock_net(sk), mtu, oif, READ_ONCE(sk->sk_mark), + sk->sk_uid); + + dst = __sk_dst_get(sk); + if (!dst || !dst->obsolete || + dst->ops->check(dst, inet6_sk(sk)->dst_cookie)) + return; + + bh_lock_sock(sk); + if (!sock_owned_by_user(sk) && !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) + ip6_datagram_dst_update(sk, false); + bh_unlock_sock(sk); +} +EXPORT_SYMBOL_GPL(ip6_sk_update_pmtu); + +void ip6_sk_dst_store_flow(struct sock *sk, struct dst_entry *dst, + const struct flowi6 *fl6) +{ +#ifdef CONFIG_IPV6_SUBTREES + struct ipv6_pinfo *np = inet6_sk(sk); +#endif + + ip6_dst_store(sk, dst, + ipv6_addr_equal(&fl6->daddr, &sk->sk_v6_daddr) ? + &sk->sk_v6_daddr : NULL, +#ifdef CONFIG_IPV6_SUBTREES + ipv6_addr_equal(&fl6->saddr, &np->saddr) ? 
+ &np->saddr : +#endif + NULL); +} + +static bool ip6_redirect_nh_match(const struct fib6_result *res, + struct flowi6 *fl6, + const struct in6_addr *gw, + struct rt6_info **ret) +{ + const struct fib6_nh *nh = res->nh; + + if (nh->fib_nh_flags & RTNH_F_DEAD || !nh->fib_nh_gw_family || + fl6->flowi6_oif != nh->fib_nh_dev->ifindex) + return false; + + /* rt_cache's gateway might be different from its 'parent' + * in the case of an ip redirect. + * So we keep searching in the exception table if the gateway + * is different. + */ + if (!ipv6_addr_equal(gw, &nh->fib_nh_gw6)) { + struct rt6_info *rt_cache; + + rt_cache = rt6_find_cached_rt(res, &fl6->daddr, &fl6->saddr); + if (rt_cache && + ipv6_addr_equal(gw, &rt_cache->rt6i_gateway)) { + *ret = rt_cache; + return true; + } + return false; + } + return true; +} + +struct fib6_nh_rd_arg { + struct fib6_result *res; + struct flowi6 *fl6; + const struct in6_addr *gw; + struct rt6_info **ret; +}; + +static int fib6_nh_redirect_match(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_rd_arg *arg = _arg; + + arg->res->nh = nh; + return ip6_redirect_nh_match(arg->res, arg->fl6, arg->gw, arg->ret); +} + +/* Handle redirects */ +struct ip6rd_flowi { + struct flowi6 fl6; + struct in6_addr gateway; +}; + +INDIRECT_CALLABLE_SCOPE struct rt6_info *__ip6_route_redirect(struct net *net, + struct fib6_table *table, + struct flowi6 *fl6, + const struct sk_buff *skb, + int flags) +{ + struct ip6rd_flowi *rdfl = (struct ip6rd_flowi *)fl6; + struct rt6_info *ret = NULL; + struct fib6_result res = {}; + struct fib6_nh_rd_arg arg = { + .res = &res, + .fl6 = fl6, + .gw = &rdfl->gateway, + .ret = &ret + }; + struct fib6_info *rt; + struct fib6_node *fn; + + /* Get the "current" route for this destination and + * check if the redirect has come from appropriate router. + * + * RFC 4861 specifies that redirects should only be + * accepted if they come from the nexthop to the target. + * Due to the way the routes are chosen, this notion + * is a bit fuzzy and one might need to check all possible + * routes. 
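+	 * (Hence the loop below walks every route on the matched node and,
+	 * for nexthop objects, every fib6_nh via fib6_nh_redirect_match().)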
+ */ + + rcu_read_lock(); + fn = fib6_node_lookup(&table->tb6_root, &fl6->daddr, &fl6->saddr); +restart: + for_each_fib6_node_rt_rcu(fn) { + res.f6i = rt; + if (fib6_check_expired(rt)) + continue; + if (rt->fib6_flags & RTF_REJECT) + break; + if (unlikely(rt->nh)) { + if (nexthop_is_blackhole(rt->nh)) + continue; + /* on match, res->nh is filled in and potentially ret */ + if (nexthop_for_each_fib6_nh(rt->nh, + fib6_nh_redirect_match, + &arg)) + goto out; + } else { + res.nh = rt->fib6_nh; + if (ip6_redirect_nh_match(&res, fl6, &rdfl->gateway, + &ret)) + goto out; + } + } + + if (!rt) + rt = net->ipv6.fib6_null_entry; + else if (rt->fib6_flags & RTF_REJECT) { + ret = net->ipv6.ip6_null_entry; + goto out; + } + + if (rt == net->ipv6.fib6_null_entry) { + fn = fib6_backtrack(fn, &fl6->saddr); + if (fn) + goto restart; + } + + res.f6i = rt; + res.nh = rt->fib6_nh; +out: + if (ret) { + ip6_hold_safe(net, &ret); + } else { + res.fib6_flags = res.f6i->fib6_flags; + res.fib6_type = res.f6i->fib6_type; + ret = ip6_create_rt_rcu(&res); + } + + rcu_read_unlock(); + + trace_fib6_table_lookup(net, &res, table, fl6); + return ret; +}; + +static struct dst_entry *ip6_route_redirect(struct net *net, + const struct flowi6 *fl6, + const struct sk_buff *skb, + const struct in6_addr *gateway) +{ + int flags = RT6_LOOKUP_F_HAS_SADDR; + struct ip6rd_flowi rdfl; + + rdfl.fl6 = *fl6; + rdfl.gateway = *gateway; + + return fib6_rule_lookup(net, &rdfl.fl6, skb, + flags, __ip6_route_redirect); +} + +void ip6_redirect(struct sk_buff *skb, struct net *net, int oif, u32 mark, + kuid_t uid) +{ + const struct ipv6hdr *iph = (struct ipv6hdr *) skb->data; + struct dst_entry *dst; + struct flowi6 fl6 = { + .flowi6_iif = LOOPBACK_IFINDEX, + .flowi6_oif = oif, + .flowi6_mark = mark, + .daddr = iph->daddr, + .saddr = iph->saddr, + .flowlabel = ip6_flowinfo(iph), + .flowi6_uid = uid, + }; + + dst = ip6_route_redirect(net, &fl6, skb, &ipv6_hdr(skb)->saddr); + rt6_do_redirect(dst, NULL, skb); + dst_release(dst); +} +EXPORT_SYMBOL_GPL(ip6_redirect); + +void ip6_redirect_no_header(struct sk_buff *skb, struct net *net, int oif) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + const struct rd_msg *msg = (struct rd_msg *)icmp6_hdr(skb); + struct dst_entry *dst; + struct flowi6 fl6 = { + .flowi6_iif = LOOPBACK_IFINDEX, + .flowi6_oif = oif, + .daddr = msg->dest, + .saddr = iph->daddr, + .flowi6_uid = sock_net_uid(net, NULL), + }; + + dst = ip6_route_redirect(net, &fl6, skb, &iph->saddr); + rt6_do_redirect(dst, NULL, skb); + dst_release(dst); +} + +void ip6_sk_redirect(struct sk_buff *skb, struct sock *sk) +{ + ip6_redirect(skb, sock_net(sk), sk->sk_bound_dev_if, + READ_ONCE(sk->sk_mark), sk->sk_uid); +} +EXPORT_SYMBOL_GPL(ip6_sk_redirect); + +static unsigned int ip6_default_advmss(const struct dst_entry *dst) +{ + struct net_device *dev = dst->dev; + unsigned int mtu = dst_mtu(dst); + struct net *net = dev_net(dev); + + mtu -= sizeof(struct ipv6hdr) + sizeof(struct tcphdr); + + if (mtu < net->ipv6.sysctl.ip6_rt_min_advmss) + mtu = net->ipv6.sysctl.ip6_rt_min_advmss; + + /* + * Maximal non-jumbo IPv6 payload is IPV6_MAXPLEN and + * corresponding MSS is IPV6_MAXPLEN - tcp_header_size. 
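+ * (with the 20 byte base TCP header that is 65535 - 20 = 65515 bytes).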
+ * IPV6_MAXPLEN is also valid and means: "any MSS, + * rely only on pmtu discovery" + */ + if (mtu > IPV6_MAXPLEN - sizeof(struct tcphdr)) + mtu = IPV6_MAXPLEN; + return mtu; +} + +INDIRECT_CALLABLE_SCOPE unsigned int ip6_mtu(const struct dst_entry *dst) +{ + return ip6_dst_mtu_maybe_forward(dst, false); +} +EXPORT_INDIRECT_CALLABLE(ip6_mtu); + +/* MTU selection: + * 1. mtu on route is locked - use it + * 2. mtu from nexthop exception + * 3. mtu from egress device + * + * based on ip6_dst_mtu_forward and exception logic of + * rt6_find_cached_rt; called with rcu_read_lock + */ +u32 ip6_mtu_from_fib6(const struct fib6_result *res, + const struct in6_addr *daddr, + const struct in6_addr *saddr) +{ + const struct fib6_nh *nh = res->nh; + struct fib6_info *f6i = res->f6i; + struct inet6_dev *idev; + struct rt6_info *rt; + u32 mtu = 0; + + if (unlikely(fib6_metric_locked(f6i, RTAX_MTU))) { + mtu = f6i->fib6_pmtu; + if (mtu) + goto out; + } + + rt = rt6_find_cached_rt(res, daddr, saddr); + if (unlikely(rt)) { + mtu = dst_metric_raw(&rt->dst, RTAX_MTU); + } else { + struct net_device *dev = nh->fib_nh_dev; + + mtu = IPV6_MIN_MTU; + idev = __in6_dev_get(dev); + if (idev && idev->cnf.mtu6 > mtu) + mtu = idev->cnf.mtu6; + } + + mtu = min_t(unsigned int, mtu, IP6_MAX_MTU); +out: + return mtu - lwtunnel_headroom(nh->fib_nh_lws, mtu); +} + +struct dst_entry *icmp6_dst_alloc(struct net_device *dev, + struct flowi6 *fl6) +{ + struct dst_entry *dst; + struct rt6_info *rt; + struct inet6_dev *idev = in6_dev_get(dev); + struct net *net = dev_net(dev); + + if (unlikely(!idev)) + return ERR_PTR(-ENODEV); + + rt = ip6_dst_alloc(net, dev, 0); + if (unlikely(!rt)) { + in6_dev_put(idev); + dst = ERR_PTR(-ENOMEM); + goto out; + } + + rt->dst.input = ip6_input; + rt->dst.output = ip6_output; + rt->rt6i_gateway = fl6->daddr; + rt->rt6i_dst.addr = fl6->daddr; + rt->rt6i_dst.plen = 128; + rt->rt6i_idev = idev; + dst_metric_set(&rt->dst, RTAX_HOPLIMIT, 0); + + /* Add this dst into uncached_list so that rt6_disable_ip() can + * do proper release of the net_device + */ + rt6_uncached_list_add(rt); + + dst = xfrm_lookup(net, &rt->dst, flowi6_to_flowi(fl6), NULL, 0); + +out: + return dst; +} + +static void ip6_dst_gc(struct dst_ops *ops) +{ + struct net *net = container_of(ops, struct net, ipv6.ip6_dst_ops); + int rt_min_interval = net->ipv6.sysctl.ip6_rt_gc_min_interval; + int rt_elasticity = net->ipv6.sysctl.ip6_rt_gc_elasticity; + int rt_gc_timeout = net->ipv6.sysctl.ip6_rt_gc_timeout; + unsigned long rt_last_gc = net->ipv6.ip6_rt_last_gc; + unsigned int val; + int entries; + + entries = dst_entries_get_fast(ops); + if (entries > ops->gc_thresh) + entries = dst_entries_get_slow(ops); + + if (time_after(rt_last_gc + rt_min_interval, jiffies)) + goto out; + + fib6_run_gc(atomic_inc_return(&net->ipv6.ip6_rt_gc_expire), net, true); + entries = dst_entries_get_slow(ops); + if (entries < ops->gc_thresh) + atomic_set(&net->ipv6.ip6_rt_gc_expire, rt_gc_timeout >> 1); +out: + val = atomic_read(&net->ipv6.ip6_rt_gc_expire); + atomic_set(&net->ipv6.ip6_rt_gc_expire, val - (val >> rt_elasticity)); +} + +static int ip6_nh_lookup_table(struct net *net, struct fib6_config *cfg, + const struct in6_addr *gw_addr, u32 tbid, + int flags, struct fib6_result *res) +{ + struct flowi6 fl6 = { + .flowi6_oif = cfg->fc_ifindex, + .daddr = *gw_addr, + .saddr = cfg->fc_prefsrc, + }; + struct fib6_table *table; + int err; + + table = fib6_get_table(net, tbid); + if (!table) + return -EINVAL; + + if (!ipv6_addr_any(&cfg->fc_prefsrc)) + flags |= 
RT6_LOOKUP_F_HAS_SADDR; + + flags |= RT6_LOOKUP_F_IGNORE_LINKSTATE; + + err = fib6_table_lookup(net, table, cfg->fc_ifindex, &fl6, res, flags); + if (!err && res->f6i != net->ipv6.fib6_null_entry) + fib6_select_path(net, res, &fl6, cfg->fc_ifindex, + cfg->fc_ifindex != 0, NULL, flags); + + return err; +} + +static int ip6_route_check_nh_onlink(struct net *net, + struct fib6_config *cfg, + const struct net_device *dev, + struct netlink_ext_ack *extack) +{ + u32 tbid = l3mdev_fib_table_rcu(dev) ? : RT_TABLE_MAIN; + const struct in6_addr *gw_addr = &cfg->fc_gateway; + struct fib6_result res = {}; + int err; + + err = ip6_nh_lookup_table(net, cfg, gw_addr, tbid, 0, &res); + if (!err && !(res.fib6_flags & RTF_REJECT) && + /* ignore match if it is the default route */ + !ipv6_addr_any(&res.f6i->fib6_dst.addr) && + (res.fib6_type != RTN_UNICAST || dev != res.nh->fib_nh_dev)) { + NL_SET_ERR_MSG(extack, + "Nexthop has invalid gateway or device mismatch"); + err = -EINVAL; + } + + return err; +} + +static int ip6_route_check_nh(struct net *net, + struct fib6_config *cfg, + struct net_device **_dev, + struct inet6_dev **idev) +{ + const struct in6_addr *gw_addr = &cfg->fc_gateway; + struct net_device *dev = _dev ? *_dev : NULL; + int flags = RT6_LOOKUP_F_IFACE; + struct fib6_result res = {}; + int err = -EHOSTUNREACH; + + if (cfg->fc_table) { + err = ip6_nh_lookup_table(net, cfg, gw_addr, + cfg->fc_table, flags, &res); + /* gw_addr can not require a gateway or resolve to a reject + * route. If a device is given, it must match the result. + */ + if (err || res.fib6_flags & RTF_REJECT || + res.nh->fib_nh_gw_family || + (dev && dev != res.nh->fib_nh_dev)) + err = -EHOSTUNREACH; + } + + if (err < 0) { + struct flowi6 fl6 = { + .flowi6_oif = cfg->fc_ifindex, + .daddr = *gw_addr, + }; + + err = fib6_lookup(net, cfg->fc_ifindex, &fl6, &res, flags); + if (err || res.fib6_flags & RTF_REJECT || + res.nh->fib_nh_gw_family) + err = -EHOSTUNREACH; + + if (err) + return err; + + fib6_select_path(net, &res, &fl6, cfg->fc_ifindex, + cfg->fc_ifindex != 0, NULL, flags); + } + + err = 0; + if (dev) { + if (dev != res.nh->fib_nh_dev) + err = -EHOSTUNREACH; + } else { + *_dev = dev = res.nh->fib_nh_dev; + dev_hold(dev); + *idev = in6_dev_get(dev); + } + + return err; +} + +static int ip6_validate_gw(struct net *net, struct fib6_config *cfg, + struct net_device **_dev, struct inet6_dev **idev, + struct netlink_ext_ack *extack) +{ + const struct in6_addr *gw_addr = &cfg->fc_gateway; + int gwa_type = ipv6_addr_type(gw_addr); + bool skip_dev = gwa_type & IPV6_ADDR_LINKLOCAL ? false : true; + const struct net_device *dev = *_dev; + bool need_addr_check = !dev; + int err = -EINVAL; + + /* if gw_addr is local we will fail to detect this in case + * address is still TENTATIVE (DAD in progress). rt6_lookup() + * will return already-added prefix route via interface that + * prefix route was assigned to, which might be non-loopback. + */ + if (dev && + ipv6_chk_addr_and_flags(net, gw_addr, dev, skip_dev, 0, 0)) { + NL_SET_ERR_MSG(extack, "Gateway can not be a local address"); + goto out; + } + + if (gwa_type != (IPV6_ADDR_LINKLOCAL | IPV6_ADDR_UNICAST)) { + /* IPv6 strictly inhibits using not link-local + * addresses as nexthop address. + * Otherwise, router will not able to send redirects. + * It is very good, but in some (rare!) circumstances + * (SIT, PtP, NBMA NOARP links) it is handy to allow + * some exceptions. 
--ANK + * We allow IPv4-mapped nexthops to support RFC4798-type + * addressing + */ + if (!(gwa_type & (IPV6_ADDR_UNICAST | IPV6_ADDR_MAPPED))) { + NL_SET_ERR_MSG(extack, "Invalid gateway address"); + goto out; + } + + rcu_read_lock(); + + if (cfg->fc_flags & RTNH_F_ONLINK) + err = ip6_route_check_nh_onlink(net, cfg, dev, extack); + else + err = ip6_route_check_nh(net, cfg, _dev, idev); + + rcu_read_unlock(); + + if (err) + goto out; + } + + /* reload in case device was changed */ + dev = *_dev; + + err = -EINVAL; + if (!dev) { + NL_SET_ERR_MSG(extack, "Egress device not specified"); + goto out; + } else if (dev->flags & IFF_LOOPBACK) { + NL_SET_ERR_MSG(extack, + "Egress device can not be loopback device for this route"); + goto out; + } + + /* if we did not check gw_addr above, do so now that the + * egress device has been resolved. + */ + if (need_addr_check && + ipv6_chk_addr_and_flags(net, gw_addr, dev, skip_dev, 0, 0)) { + NL_SET_ERR_MSG(extack, "Gateway can not be a local address"); + goto out; + } + + err = 0; +out: + return err; +} + +static bool fib6_is_reject(u32 flags, struct net_device *dev, int addr_type) +{ + if ((flags & RTF_REJECT) || + (dev && (dev->flags & IFF_LOOPBACK) && + !(addr_type & IPV6_ADDR_LOOPBACK) && + !(flags & (RTF_ANYCAST | RTF_LOCAL)))) + return true; + + return false; +} + +int fib6_nh_init(struct net *net, struct fib6_nh *fib6_nh, + struct fib6_config *cfg, gfp_t gfp_flags, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = NULL; + struct inet6_dev *idev = NULL; + int addr_type; + int err; + + fib6_nh->fib_nh_family = AF_INET6; +#ifdef CONFIG_IPV6_ROUTER_PREF + fib6_nh->last_probe = jiffies; +#endif + if (cfg->fc_is_fdb) { + fib6_nh->fib_nh_gw6 = cfg->fc_gateway; + fib6_nh->fib_nh_gw_family = AF_INET6; + return 0; + } + + err = -ENODEV; + if (cfg->fc_ifindex) { + dev = dev_get_by_index(net, cfg->fc_ifindex); + if (!dev) + goto out; + idev = in6_dev_get(dev); + if (!idev) + goto out; + } + + if (cfg->fc_flags & RTNH_F_ONLINK) { + if (!dev) { + NL_SET_ERR_MSG(extack, + "Nexthop device required for onlink"); + goto out; + } + + if (!(dev->flags & IFF_UP)) { + NL_SET_ERR_MSG(extack, "Nexthop device is not up"); + err = -ENETDOWN; + goto out; + } + + fib6_nh->fib_nh_flags |= RTNH_F_ONLINK; + } + + fib6_nh->fib_nh_weight = 1; + + /* We cannot add true routes via loopback here, + * they would result in kernel looping; promote them to reject routes + */ + addr_type = ipv6_addr_type(&cfg->fc_dst); + if (fib6_is_reject(cfg->fc_flags, dev, addr_type)) { + /* hold loopback dev/idev if we haven't done so. 
*/ + if (dev != net->loopback_dev) { + if (dev) { + dev_put(dev); + in6_dev_put(idev); + } + dev = net->loopback_dev; + dev_hold(dev); + idev = in6_dev_get(dev); + if (!idev) { + err = -ENODEV; + goto out; + } + } + goto pcpu_alloc; + } + + if (cfg->fc_flags & RTF_GATEWAY) { + err = ip6_validate_gw(net, cfg, &dev, &idev, extack); + if (err) + goto out; + + fib6_nh->fib_nh_gw6 = cfg->fc_gateway; + fib6_nh->fib_nh_gw_family = AF_INET6; + } + + err = -ENODEV; + if (!dev) + goto out; + + if (idev->cnf.disable_ipv6) { + NL_SET_ERR_MSG(extack, "IPv6 is disabled on nexthop device"); + err = -EACCES; + goto out; + } + + if (!(dev->flags & IFF_UP) && !cfg->fc_ignore_dev_down) { + NL_SET_ERR_MSG(extack, "Nexthop device is not up"); + err = -ENETDOWN; + goto out; + } + + if (!(cfg->fc_flags & (RTF_LOCAL | RTF_ANYCAST)) && + !netif_carrier_ok(dev)) + fib6_nh->fib_nh_flags |= RTNH_F_LINKDOWN; + + err = fib_nh_common_init(net, &fib6_nh->nh_common, cfg->fc_encap, + cfg->fc_encap_type, cfg, gfp_flags, extack); + if (err) + goto out; + +pcpu_alloc: + fib6_nh->rt6i_pcpu = alloc_percpu_gfp(struct rt6_info *, gfp_flags); + if (!fib6_nh->rt6i_pcpu) { + err = -ENOMEM; + goto out; + } + + fib6_nh->fib_nh_dev = dev; + netdev_tracker_alloc(dev, &fib6_nh->fib_nh_dev_tracker, gfp_flags); + + fib6_nh->fib_nh_oif = dev->ifindex; + err = 0; +out: + if (idev) + in6_dev_put(idev); + + if (err) { + lwtstate_put(fib6_nh->fib_nh_lws); + fib6_nh->fib_nh_lws = NULL; + dev_put(dev); + } + + return err; +} + +void fib6_nh_release(struct fib6_nh *fib6_nh) +{ + struct rt6_exception_bucket *bucket; + + rcu_read_lock(); + + fib6_nh_flush_exceptions(fib6_nh, NULL); + bucket = fib6_nh_get_excptn_bucket(fib6_nh, NULL); + if (bucket) { + rcu_assign_pointer(fib6_nh->rt6i_exception_bucket, NULL); + kfree(bucket); + } + + rcu_read_unlock(); + + fib6_nh_release_dsts(fib6_nh); + free_percpu(fib6_nh->rt6i_pcpu); + + fib_nh_common_release(&fib6_nh->nh_common); +} + +void fib6_nh_release_dsts(struct fib6_nh *fib6_nh) +{ + int cpu; + + if (!fib6_nh->rt6i_pcpu) + return; + + for_each_possible_cpu(cpu) { + struct rt6_info *pcpu_rt, **ppcpu_rt; + + ppcpu_rt = per_cpu_ptr(fib6_nh->rt6i_pcpu, cpu); + pcpu_rt = xchg(ppcpu_rt, NULL); + if (pcpu_rt) { + dst_dev_put(&pcpu_rt->dst); + dst_release(&pcpu_rt->dst); + } + } +} + +static struct fib6_info *ip6_route_info_create(struct fib6_config *cfg, + gfp_t gfp_flags, + struct netlink_ext_ack *extack) +{ + struct net *net = cfg->fc_nlinfo.nl_net; + struct fib6_info *rt = NULL; + struct nexthop *nh = NULL; + struct fib6_table *table; + struct fib6_nh *fib6_nh; + int err = -EINVAL; + int addr_type; + + /* RTF_PCPU is an internal flag; can not be set by userspace */ + if (cfg->fc_flags & RTF_PCPU) { + NL_SET_ERR_MSG(extack, "Userspace can not set RTF_PCPU"); + goto out; + } + + /* RTF_CACHE is an internal flag; can not be set by userspace */ + if (cfg->fc_flags & RTF_CACHE) { + NL_SET_ERR_MSG(extack, "Userspace can not set RTF_CACHE"); + goto out; + } + + if (cfg->fc_type > RTN_MAX) { + NL_SET_ERR_MSG(extack, "Invalid route type"); + goto out; + } + + if (cfg->fc_dst_len > 128) { + NL_SET_ERR_MSG(extack, "Invalid prefix length"); + goto out; + } + if (cfg->fc_src_len > 128) { + NL_SET_ERR_MSG(extack, "Invalid source address length"); + goto out; + } +#ifndef CONFIG_IPV6_SUBTREES + if (cfg->fc_src_len) { + NL_SET_ERR_MSG(extack, + "Specifying source address requires IPV6_SUBTREES to be enabled"); + goto out; + } +#endif + if (cfg->fc_nh_id) { + nh = nexthop_find_by_id(net, cfg->fc_nh_id); + if (!nh) { + 
NL_SET_ERR_MSG(extack, "Nexthop id does not exist"); + goto out; + } + err = fib6_check_nexthop(nh, cfg, extack); + if (err) + goto out; + } + + err = -ENOBUFS; + if (cfg->fc_nlinfo.nlh && + !(cfg->fc_nlinfo.nlh->nlmsg_flags & NLM_F_CREATE)) { + table = fib6_get_table(net, cfg->fc_table); + if (!table) { + pr_warn("NLM_F_CREATE should be specified when creating new route\n"); + table = fib6_new_table(net, cfg->fc_table); + } + } else { + table = fib6_new_table(net, cfg->fc_table); + } + + if (!table) + goto out; + + err = -ENOMEM; + rt = fib6_info_alloc(gfp_flags, !nh); + if (!rt) + goto out; + + rt->fib6_metrics = ip_fib_metrics_init(net, cfg->fc_mx, cfg->fc_mx_len, + extack); + if (IS_ERR(rt->fib6_metrics)) { + err = PTR_ERR(rt->fib6_metrics); + /* Do not leave garbage there. */ + rt->fib6_metrics = (struct dst_metrics *)&dst_default_metrics; + goto out_free; + } + + if (cfg->fc_flags & RTF_ADDRCONF) + rt->dst_nocount = true; + + if (cfg->fc_flags & RTF_EXPIRES) + fib6_set_expires(rt, jiffies + + clock_t_to_jiffies(cfg->fc_expires)); + else + fib6_clean_expires(rt); + + if (cfg->fc_protocol == RTPROT_UNSPEC) + cfg->fc_protocol = RTPROT_BOOT; + rt->fib6_protocol = cfg->fc_protocol; + + rt->fib6_table = table; + rt->fib6_metric = cfg->fc_metric; + rt->fib6_type = cfg->fc_type ? : RTN_UNICAST; + rt->fib6_flags = cfg->fc_flags & ~RTF_GATEWAY; + + ipv6_addr_prefix(&rt->fib6_dst.addr, &cfg->fc_dst, cfg->fc_dst_len); + rt->fib6_dst.plen = cfg->fc_dst_len; + +#ifdef CONFIG_IPV6_SUBTREES + ipv6_addr_prefix(&rt->fib6_src.addr, &cfg->fc_src, cfg->fc_src_len); + rt->fib6_src.plen = cfg->fc_src_len; +#endif + if (nh) { + if (rt->fib6_src.plen) { + NL_SET_ERR_MSG(extack, "Nexthops can not be used with source routing"); + goto out_free; + } + if (!nexthop_get(nh)) { + NL_SET_ERR_MSG(extack, "Nexthop has been deleted"); + goto out_free; + } + rt->nh = nh; + fib6_nh = nexthop_fib6_nh(rt->nh); + } else { + err = fib6_nh_init(net, rt->fib6_nh, cfg, gfp_flags, extack); + if (err) + goto out; + + fib6_nh = rt->fib6_nh; + + /* We cannot add true routes via loopback here, they would + * result in kernel looping; promote them to reject routes + */ + addr_type = ipv6_addr_type(&cfg->fc_dst); + if (fib6_is_reject(cfg->fc_flags, rt->fib6_nh->fib_nh_dev, + addr_type)) + rt->fib6_flags = RTF_REJECT | RTF_NONEXTHOP; + } + + if (!ipv6_addr_any(&cfg->fc_prefsrc)) { + struct net_device *dev = fib6_nh->fib_nh_dev; + + if (!ipv6_chk_addr(net, &cfg->fc_prefsrc, dev, 0)) { + NL_SET_ERR_MSG(extack, "Invalid source address"); + err = -EINVAL; + goto out; + } + rt->fib6_prefsrc.addr = cfg->fc_prefsrc; + rt->fib6_prefsrc.plen = 128; + } else + rt->fib6_prefsrc.plen = 0; + + return rt; +out: + fib6_info_release(rt); + return ERR_PTR(err); +out_free: + ip_fib_metrics_put(rt->fib6_metrics); + kfree(rt); + return ERR_PTR(err); +} + +int ip6_route_add(struct fib6_config *cfg, gfp_t gfp_flags, + struct netlink_ext_ack *extack) +{ + struct fib6_info *rt; + int err; + + rt = ip6_route_info_create(cfg, gfp_flags, extack); + if (IS_ERR(rt)) + return PTR_ERR(rt); + + err = __ip6_ins_rt(rt, &cfg->fc_nlinfo, extack); + fib6_info_release(rt); + + return err; +} + +static int __ip6_del_rt(struct fib6_info *rt, struct nl_info *info) +{ + struct net *net = info->nl_net; + struct fib6_table *table; + int err; + + if (rt == net->ipv6.fib6_null_entry) { + err = -ENOENT; + goto out; + } + + table = rt->fib6_table; + spin_lock_bh(&table->tb6_lock); + err = fib6_del(rt, info); + spin_unlock_bh(&table->tb6_lock); + +out: + fib6_info_release(rt); + 
return err; +} + +int ip6_del_rt(struct net *net, struct fib6_info *rt, bool skip_notify) +{ + struct nl_info info = { + .nl_net = net, + .skip_notify = skip_notify + }; + + return __ip6_del_rt(rt, &info); +} + +static int __ip6_del_rt_siblings(struct fib6_info *rt, struct fib6_config *cfg) +{ + struct nl_info *info = &cfg->fc_nlinfo; + struct net *net = info->nl_net; + struct sk_buff *skb = NULL; + struct fib6_table *table; + int err = -ENOENT; + + if (rt == net->ipv6.fib6_null_entry) + goto out_put; + table = rt->fib6_table; + spin_lock_bh(&table->tb6_lock); + + if (rt->fib6_nsiblings && cfg->fc_delete_all_nh) { + struct fib6_info *sibling, *next_sibling; + struct fib6_node *fn; + + /* prefer to send a single notification with all hops */ + skb = nlmsg_new(rt6_nlmsg_size(rt), gfp_any()); + if (skb) { + u32 seq = info->nlh ? info->nlh->nlmsg_seq : 0; + + if (rt6_fill_node(net, skb, rt, NULL, + NULL, NULL, 0, RTM_DELROUTE, + info->portid, seq, 0) < 0) { + kfree_skb(skb); + skb = NULL; + } else + info->skip_notify = 1; + } + + /* 'rt' points to the first sibling route. If it is not the + * leaf, then we do not need to send a notification. Otherwise, + * we need to check if the last sibling has a next route or not + * and emit a replace or delete notification, respectively. + */ + info->skip_notify_kernel = 1; + fn = rcu_dereference_protected(rt->fib6_node, + lockdep_is_held(&table->tb6_lock)); + if (rcu_access_pointer(fn->leaf) == rt) { + struct fib6_info *last_sibling, *replace_rt; + + last_sibling = list_last_entry(&rt->fib6_siblings, + struct fib6_info, + fib6_siblings); + replace_rt = rcu_dereference_protected( + last_sibling->fib6_next, + lockdep_is_held(&table->tb6_lock)); + if (replace_rt) + call_fib6_entry_notifiers_replace(net, + replace_rt); + else + call_fib6_multipath_entry_notifiers(net, + FIB_EVENT_ENTRY_DEL, + rt, rt->fib6_nsiblings, + NULL); + } + list_for_each_entry_safe(sibling, next_sibling, + &rt->fib6_siblings, + fib6_siblings) { + err = fib6_del(sibling, info); + if (err) + goto out_unlock; + } + } + + err = fib6_del(rt, info); +out_unlock: + spin_unlock_bh(&table->tb6_lock); +out_put: + fib6_info_release(rt); + + if (skb) { + rtnl_notify(skb, net, info->portid, RTNLGRP_IPV6_ROUTE, + info->nlh, gfp_any()); + } + return err; +} + +static int __ip6_del_cached_rt(struct rt6_info *rt, struct fib6_config *cfg) +{ + int rc = -ESRCH; + + if (cfg->fc_ifindex && rt->dst.dev->ifindex != cfg->fc_ifindex) + goto out; + + if (cfg->fc_flags & RTF_GATEWAY && + !ipv6_addr_equal(&cfg->fc_gateway, &rt->rt6i_gateway)) + goto out; + + rc = rt6_remove_exception_rt(rt); +out: + return rc; +} + +static int ip6_del_cached_rt(struct fib6_config *cfg, struct fib6_info *rt, + struct fib6_nh *nh) +{ + struct fib6_result res = { + .f6i = rt, + .nh = nh, + }; + struct rt6_info *rt_cache; + + rt_cache = rt6_find_cached_rt(&res, &cfg->fc_dst, &cfg->fc_src); + if (rt_cache) + return __ip6_del_cached_rt(rt_cache, cfg); + + return 0; +} + +struct fib6_nh_del_cached_rt_arg { + struct fib6_config *cfg; + struct fib6_info *f6i; +}; + +static int fib6_nh_del_cached_rt(struct fib6_nh *nh, void *_arg) +{ + struct fib6_nh_del_cached_rt_arg *arg = _arg; + int rc; + + rc = ip6_del_cached_rt(arg->cfg, arg->f6i, nh); + return rc != -ESRCH ? 
rc : 0; +} + +static int ip6_del_cached_rt_nh(struct fib6_config *cfg, struct fib6_info *f6i) +{ + struct fib6_nh_del_cached_rt_arg arg = { + .cfg = cfg, + .f6i = f6i + }; + + return nexthop_for_each_fib6_nh(f6i->nh, fib6_nh_del_cached_rt, &arg); +} + +static int ip6_route_del(struct fib6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct fib6_table *table; + struct fib6_info *rt; + struct fib6_node *fn; + int err = -ESRCH; + + table = fib6_get_table(cfg->fc_nlinfo.nl_net, cfg->fc_table); + if (!table) { + NL_SET_ERR_MSG(extack, "FIB table does not exist"); + return err; + } + + rcu_read_lock(); + + fn = fib6_locate(&table->tb6_root, + &cfg->fc_dst, cfg->fc_dst_len, + &cfg->fc_src, cfg->fc_src_len, + !(cfg->fc_flags & RTF_CACHE)); + + if (fn) { + for_each_fib6_node_rt_rcu(fn) { + struct fib6_nh *nh; + + if (rt->nh && cfg->fc_nh_id && + rt->nh->id != cfg->fc_nh_id) + continue; + + if (cfg->fc_flags & RTF_CACHE) { + int rc = 0; + + if (rt->nh) { + rc = ip6_del_cached_rt_nh(cfg, rt); + } else if (cfg->fc_nh_id) { + continue; + } else { + nh = rt->fib6_nh; + rc = ip6_del_cached_rt(cfg, rt, nh); + } + if (rc != -ESRCH) { + rcu_read_unlock(); + return rc; + } + continue; + } + + if (cfg->fc_metric && cfg->fc_metric != rt->fib6_metric) + continue; + if (cfg->fc_protocol && + cfg->fc_protocol != rt->fib6_protocol) + continue; + + if (rt->nh) { + if (!fib6_info_hold_safe(rt)) + continue; + rcu_read_unlock(); + + return __ip6_del_rt(rt, &cfg->fc_nlinfo); + } + if (cfg->fc_nh_id) + continue; + + nh = rt->fib6_nh; + if (cfg->fc_ifindex && + (!nh->fib_nh_dev || + nh->fib_nh_dev->ifindex != cfg->fc_ifindex)) + continue; + if (cfg->fc_flags & RTF_GATEWAY && + !ipv6_addr_equal(&cfg->fc_gateway, &nh->fib_nh_gw6)) + continue; + if (!fib6_info_hold_safe(rt)) + continue; + rcu_read_unlock(); + + /* if gateway was specified only delete the one hop */ + if (cfg->fc_flags & RTF_GATEWAY) + return __ip6_del_rt(rt, &cfg->fc_nlinfo); + + return __ip6_del_rt_siblings(rt, cfg); + } + } + rcu_read_unlock(); + + return err; +} + +static void rt6_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_buff *skb) +{ + struct netevent_redirect netevent; + struct rt6_info *rt, *nrt = NULL; + struct fib6_result res = {}; + struct ndisc_options ndopts; + struct inet6_dev *in6_dev; + struct neighbour *neigh; + struct rd_msg *msg; + int optlen, on_link; + u8 *lladdr; + + optlen = skb_tail_pointer(skb) - skb_transport_header(skb); + optlen -= sizeof(*msg); + + if (optlen < 0) { + net_dbg_ratelimited("rt6_do_redirect: packet too short\n"); + return; + } + + msg = (struct rd_msg *)icmp6_hdr(skb); + + if (ipv6_addr_is_multicast(&msg->dest)) { + net_dbg_ratelimited("rt6_do_redirect: destination address is multicast\n"); + return; + } + + on_link = 0; + if (ipv6_addr_equal(&msg->dest, &msg->target)) { + on_link = 1; + } else if (ipv6_addr_type(&msg->target) != + (IPV6_ADDR_UNICAST|IPV6_ADDR_LINKLOCAL)) { + net_dbg_ratelimited("rt6_do_redirect: target address is not link-local unicast\n"); + return; + } + + in6_dev = __in6_dev_get(skb->dev); + if (!in6_dev) + return; + if (in6_dev->cnf.forwarding || !in6_dev->cnf.accept_redirects) + return; + + /* RFC2461 8.1: + * The IP source address of the Redirect MUST be the same as the current + * first-hop router for the specified ICMP Destination Address. 
+ */ + + if (!ndisc_parse_options(skb->dev, msg->opt, optlen, &ndopts)) { + net_dbg_ratelimited("rt6_redirect: invalid ND options\n"); + return; + } + + lladdr = NULL; + if (ndopts.nd_opts_tgt_lladdr) { + lladdr = ndisc_opt_addr_data(ndopts.nd_opts_tgt_lladdr, + skb->dev); + if (!lladdr) { + net_dbg_ratelimited("rt6_redirect: invalid link-layer address length\n"); + return; + } + } + + rt = (struct rt6_info *) dst; + if (rt->rt6i_flags & RTF_REJECT) { + net_dbg_ratelimited("rt6_redirect: source isn't a valid nexthop for redirect target\n"); + return; + } + + /* Redirect received -> path was valid. + * Look, redirects are sent only in response to data packets, + * so that this nexthop apparently is reachable. --ANK + */ + dst_confirm_neigh(&rt->dst, &ipv6_hdr(skb)->saddr); + + neigh = __neigh_lookup(&nd_tbl, &msg->target, skb->dev, 1); + if (!neigh) + return; + + /* + * We have finally decided to accept it. + */ + + ndisc_update(skb->dev, neigh, lladdr, NUD_STALE, + NEIGH_UPDATE_F_WEAK_OVERRIDE| + NEIGH_UPDATE_F_OVERRIDE| + (on_link ? 0 : (NEIGH_UPDATE_F_OVERRIDE_ISROUTER| + NEIGH_UPDATE_F_ISROUTER)), + NDISC_REDIRECT, &ndopts); + + rcu_read_lock(); + res.f6i = rcu_dereference(rt->from); + if (!res.f6i) + goto out; + + if (res.f6i->nh) { + struct fib6_nh_match_arg arg = { + .dev = dst->dev, + .gw = &rt->rt6i_gateway, + }; + + nexthop_for_each_fib6_nh(res.f6i->nh, + fib6_nh_find_match, &arg); + + /* fib6_info uses a nexthop that does not have fib6_nh + * using the dst->dev. Should be impossible + */ + if (!arg.match) + goto out; + res.nh = arg.match; + } else { + res.nh = res.f6i->fib6_nh; + } + + res.fib6_flags = res.f6i->fib6_flags; + res.fib6_type = res.f6i->fib6_type; + nrt = ip6_rt_cache_alloc(&res, &msg->dest, NULL); + if (!nrt) + goto out; + + nrt->rt6i_flags = RTF_GATEWAY|RTF_UP|RTF_DYNAMIC|RTF_CACHE; + if (on_link) + nrt->rt6i_flags &= ~RTF_GATEWAY; + + nrt->rt6i_gateway = *(struct in6_addr *)neigh->primary_key; + + /* rt6_insert_exception() will take care of duplicated exceptions */ + if (rt6_insert_exception(nrt, &res)) { + dst_release_immediate(&nrt->dst); + goto out; + } + + netevent.old = &rt->dst; + netevent.new = &nrt->dst; + netevent.daddr = &msg->dest; + netevent.neigh = neigh; + call_netevent_notifiers(NETEVENT_REDIRECT, &netevent); + +out: + rcu_read_unlock(); + neigh_release(neigh); +} + +#ifdef CONFIG_IPV6_ROUTE_INFO +static struct fib6_info *rt6_get_route_info(struct net *net, + const struct in6_addr *prefix, int prefixlen, + const struct in6_addr *gwaddr, + struct net_device *dev) +{ + u32 tb_id = l3mdev_fib_table(dev) ? 
: RT6_TABLE_INFO; + int ifindex = dev->ifindex; + struct fib6_node *fn; + struct fib6_info *rt = NULL; + struct fib6_table *table; + + table = fib6_get_table(net, tb_id); + if (!table) + return NULL; + + rcu_read_lock(); + fn = fib6_locate(&table->tb6_root, prefix, prefixlen, NULL, 0, true); + if (!fn) + goto out; + + for_each_fib6_node_rt_rcu(fn) { + /* these routes do not use nexthops */ + if (rt->nh) + continue; + if (rt->fib6_nh->fib_nh_dev->ifindex != ifindex) + continue; + if (!(rt->fib6_flags & RTF_ROUTEINFO) || + !rt->fib6_nh->fib_nh_gw_family) + continue; + if (!ipv6_addr_equal(&rt->fib6_nh->fib_nh_gw6, gwaddr)) + continue; + if (!fib6_info_hold_safe(rt)) + continue; + break; + } +out: + rcu_read_unlock(); + return rt; +} + +static struct fib6_info *rt6_add_route_info(struct net *net, + const struct in6_addr *prefix, int prefixlen, + const struct in6_addr *gwaddr, + struct net_device *dev, + unsigned int pref) +{ + struct fib6_config cfg = { + .fc_metric = IP6_RT_PRIO_USER, + .fc_ifindex = dev->ifindex, + .fc_dst_len = prefixlen, + .fc_flags = RTF_GATEWAY | RTF_ADDRCONF | RTF_ROUTEINFO | + RTF_UP | RTF_PREF(pref), + .fc_protocol = RTPROT_RA, + .fc_type = RTN_UNICAST, + .fc_nlinfo.portid = 0, + .fc_nlinfo.nlh = NULL, + .fc_nlinfo.nl_net = net, + }; + + cfg.fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_INFO; + cfg.fc_dst = *prefix; + cfg.fc_gateway = *gwaddr; + + /* We should treat it as a default route if prefix length is 0. */ + if (!prefixlen) + cfg.fc_flags |= RTF_DEFAULT; + + ip6_route_add(&cfg, GFP_ATOMIC, NULL); + + return rt6_get_route_info(net, prefix, prefixlen, gwaddr, dev); +} +#endif + +struct fib6_info *rt6_get_dflt_router(struct net *net, + const struct in6_addr *addr, + struct net_device *dev) +{ + u32 tb_id = l3mdev_fib_table(dev) ? : RT6_TABLE_DFLT; + struct fib6_info *rt; + struct fib6_table *table; + + table = fib6_get_table(net, tb_id); + if (!table) + return NULL; + + rcu_read_lock(); + for_each_fib6_node_rt_rcu(&table->tb6_root) { + struct fib6_nh *nh; + + /* RA routes do not use nexthops */ + if (rt->nh) + continue; + + nh = rt->fib6_nh; + if (dev == nh->fib_nh_dev && + ((rt->fib6_flags & (RTF_ADDRCONF | RTF_DEFAULT)) == (RTF_ADDRCONF | RTF_DEFAULT)) && + ipv6_addr_equal(&nh->fib_nh_gw6, addr)) + break; + } + if (rt && !fib6_info_hold_safe(rt)) + rt = NULL; + rcu_read_unlock(); + return rt; +} + +struct fib6_info *rt6_add_dflt_router(struct net *net, + const struct in6_addr *gwaddr, + struct net_device *dev, + unsigned int pref, + u32 defrtr_usr_metric) +{ + struct fib6_config cfg = { + .fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_DFLT, + .fc_metric = defrtr_usr_metric, + .fc_ifindex = dev->ifindex, + .fc_flags = RTF_GATEWAY | RTF_ADDRCONF | RTF_DEFAULT | + RTF_UP | RTF_EXPIRES | RTF_PREF(pref), + .fc_protocol = RTPROT_RA, + .fc_type = RTN_UNICAST, + .fc_nlinfo.portid = 0, + .fc_nlinfo.nlh = NULL, + .fc_nlinfo.nl_net = net, + }; + + cfg.fc_gateway = *gwaddr; + + if (!ip6_route_add(&cfg, GFP_ATOMIC, NULL)) { + struct fib6_table *table; + + table = fib6_get_table(dev_net(dev), cfg.fc_table); + if (table) + table->flags |= RT6_TABLE_HAS_DFLT_ROUTER; + } + + return rt6_get_dflt_router(net, gwaddr, dev); +} + +static void __rt6_purge_dflt_routers(struct net *net, + struct fib6_table *table) +{ + struct fib6_info *rt; + +restart: + rcu_read_lock(); + for_each_fib6_node_rt_rcu(&table->tb6_root) { + struct net_device *dev = fib6_info_nh_dev(rt); + struct inet6_dev *idev = dev ? 
__in6_dev_get(dev) : NULL; + + if (rt->fib6_flags & (RTF_DEFAULT | RTF_ADDRCONF) && + (!idev || idev->cnf.accept_ra != 2) && + fib6_info_hold_safe(rt)) { + rcu_read_unlock(); + ip6_del_rt(net, rt, false); + goto restart; + } + } + rcu_read_unlock(); + + table->flags &= ~RT6_TABLE_HAS_DFLT_ROUTER; +} + +void rt6_purge_dflt_routers(struct net *net) +{ + struct fib6_table *table; + struct hlist_head *head; + unsigned int h; + + rcu_read_lock(); + + for (h = 0; h < FIB6_TABLE_HASHSZ; h++) { + head = &net->ipv6.fib_table_hash[h]; + hlist_for_each_entry_rcu(table, head, tb6_hlist) { + if (table->flags & RT6_TABLE_HAS_DFLT_ROUTER) + __rt6_purge_dflt_routers(net, table); + } + } + + rcu_read_unlock(); +} + +static void rtmsg_to_fib6_config(struct net *net, + struct in6_rtmsg *rtmsg, + struct fib6_config *cfg) +{ + *cfg = (struct fib6_config){ + .fc_table = l3mdev_fib_table_by_index(net, rtmsg->rtmsg_ifindex) ? + : RT6_TABLE_MAIN, + .fc_ifindex = rtmsg->rtmsg_ifindex, + .fc_metric = rtmsg->rtmsg_metric ? : IP6_RT_PRIO_USER, + .fc_expires = rtmsg->rtmsg_info, + .fc_dst_len = rtmsg->rtmsg_dst_len, + .fc_src_len = rtmsg->rtmsg_src_len, + .fc_flags = rtmsg->rtmsg_flags, + .fc_type = rtmsg->rtmsg_type, + + .fc_nlinfo.nl_net = net, + + .fc_dst = rtmsg->rtmsg_dst, + .fc_src = rtmsg->rtmsg_src, + .fc_gateway = rtmsg->rtmsg_gateway, + }; +} + +int ipv6_route_ioctl(struct net *net, unsigned int cmd, struct in6_rtmsg *rtmsg) +{ + struct fib6_config cfg; + int err; + + if (cmd != SIOCADDRT && cmd != SIOCDELRT) + return -EINVAL; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + rtmsg_to_fib6_config(net, rtmsg, &cfg); + + rtnl_lock(); + switch (cmd) { + case SIOCADDRT: + err = ip6_route_add(&cfg, GFP_KERNEL, NULL); + break; + case SIOCDELRT: + err = ip6_route_del(&cfg, NULL); + break; + } + rtnl_unlock(); + return err; +} + +/* + * Drop the packet on the floor + */ + +static int ip6_pkt_drop(struct sk_buff *skb, u8 code, int ipstats_mib_noroutes) +{ + struct dst_entry *dst = skb_dst(skb); + struct net *net = dev_net(dst->dev); + struct inet6_dev *idev; + SKB_DR(reason); + int type; + + if (netif_is_l3_master(skb->dev) || + dst->dev == net->loopback_dev) + idev = __in6_dev_get_safely(dev_get_by_index_rcu(net, IP6CB(skb)->iif)); + else + idev = ip6_dst_idev(dst); + + switch (ipstats_mib_noroutes) { + case IPSTATS_MIB_INNOROUTES: + type = ipv6_addr_type(&ipv6_hdr(skb)->daddr); + if (type == IPV6_ADDR_ANY) { + SKB_DR_SET(reason, IP_INADDRERRORS); + IP6_INC_STATS(net, idev, IPSTATS_MIB_INADDRERRORS); + break; + } + SKB_DR_SET(reason, IP_INNOROUTES); + fallthrough; + case IPSTATS_MIB_OUTNOROUTES: + SKB_DR_OR(reason, IP_OUTNOROUTES); + IP6_INC_STATS(net, idev, ipstats_mib_noroutes); + break; + } + + /* Start over by dropping the dst for l3mdev case */ + if (netif_is_l3_master(skb->dev)) + skb_dst_drop(skb); + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, code, 0); + kfree_skb_reason(skb, reason); + return 0; +} + +static int ip6_pkt_discard(struct sk_buff *skb) +{ + return ip6_pkt_drop(skb, ICMPV6_NOROUTE, IPSTATS_MIB_INNOROUTES); +} + +static int ip6_pkt_discard_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb->dev = skb_dst(skb)->dev; + return ip6_pkt_drop(skb, ICMPV6_NOROUTE, IPSTATS_MIB_OUTNOROUTES); +} + +static int ip6_pkt_prohibit(struct sk_buff *skb) +{ + return ip6_pkt_drop(skb, ICMPV6_ADM_PROHIBITED, IPSTATS_MIB_INNOROUTES); +} + +static int ip6_pkt_prohibit_out(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + skb->dev = skb_dst(skb)->dev; + return ip6_pkt_drop(skb, 
ICMPV6_ADM_PROHIBITED, IPSTATS_MIB_OUTNOROUTES); +} + +/* + * Allocate a dst for local (unicast / anycast) address. + */ + +struct fib6_info *addrconf_f6i_alloc(struct net *net, + struct inet6_dev *idev, + const struct in6_addr *addr, + bool anycast, gfp_t gfp_flags) +{ + struct fib6_config cfg = { + .fc_table = l3mdev_fib_table(idev->dev) ? : RT6_TABLE_LOCAL, + .fc_ifindex = idev->dev->ifindex, + .fc_flags = RTF_UP | RTF_NONEXTHOP, + .fc_dst = *addr, + .fc_dst_len = 128, + .fc_protocol = RTPROT_KERNEL, + .fc_nlinfo.nl_net = net, + .fc_ignore_dev_down = true, + }; + struct fib6_info *f6i; + + if (anycast) { + cfg.fc_type = RTN_ANYCAST; + cfg.fc_flags |= RTF_ANYCAST; + } else { + cfg.fc_type = RTN_LOCAL; + cfg.fc_flags |= RTF_LOCAL; + } + + f6i = ip6_route_info_create(&cfg, gfp_flags, NULL); + if (!IS_ERR(f6i)) { + f6i->dst_nocount = true; + + if (!anycast && + (net->ipv6.devconf_all->disable_policy || + idev->cnf.disable_policy)) + f6i->dst_nopolicy = true; + } + + return f6i; +} + +/* remove deleted ip from prefsrc entries */ +struct arg_dev_net_ip { + struct net_device *dev; + struct net *net; + struct in6_addr *addr; +}; + +static int fib6_remove_prefsrc(struct fib6_info *rt, void *arg) +{ + struct net_device *dev = ((struct arg_dev_net_ip *)arg)->dev; + struct net *net = ((struct arg_dev_net_ip *)arg)->net; + struct in6_addr *addr = ((struct arg_dev_net_ip *)arg)->addr; + + if (!rt->nh && + ((void *)rt->fib6_nh->fib_nh_dev == dev || !dev) && + rt != net->ipv6.fib6_null_entry && + ipv6_addr_equal(addr, &rt->fib6_prefsrc.addr)) { + spin_lock_bh(&rt6_exception_lock); + /* remove prefsrc entry */ + rt->fib6_prefsrc.plen = 0; + spin_unlock_bh(&rt6_exception_lock); + } + return 0; +} + +void rt6_remove_prefsrc(struct inet6_ifaddr *ifp) +{ + struct net *net = dev_net(ifp->idev->dev); + struct arg_dev_net_ip adni = { + .dev = ifp->idev->dev, + .net = net, + .addr = &ifp->addr, + }; + fib6_clean_all(net, fib6_remove_prefsrc, &adni); +} + +#define RTF_RA_ROUTER (RTF_ADDRCONF | RTF_DEFAULT) + +/* Remove routers and update dst entries when gateway turn into host. */ +static int fib6_clean_tohost(struct fib6_info *rt, void *arg) +{ + struct in6_addr *gateway = (struct in6_addr *)arg; + struct fib6_nh *nh; + + /* RA routes do not use nexthops */ + if (rt->nh) + return 0; + + nh = rt->fib6_nh; + if (((rt->fib6_flags & RTF_RA_ROUTER) == RTF_RA_ROUTER) && + nh->fib_nh_gw_family && ipv6_addr_equal(gateway, &nh->fib_nh_gw6)) + return -1; + + /* Further clean up cached routes in exception table. + * This is needed because cached route may have a different + * gateway than its 'parent' in the case of an ip redirect. 
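+ * Only cached entries whose gateway matches the address that is
+ * turning into a host are removed; the route itself is kept.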
+ */ + fib6_nh_exceptions_clean_tohost(nh, gateway); + + return 0; +} + +void rt6_clean_tohost(struct net *net, struct in6_addr *gateway) +{ + fib6_clean_all(net, fib6_clean_tohost, gateway); +} + +struct arg_netdev_event { + const struct net_device *dev; + union { + unsigned char nh_flags; + unsigned long event; + }; +}; + +static struct fib6_info *rt6_multipath_first_sibling(const struct fib6_info *rt) +{ + struct fib6_info *iter; + struct fib6_node *fn; + + fn = rcu_dereference_protected(rt->fib6_node, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + iter = rcu_dereference_protected(fn->leaf, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + while (iter) { + if (iter->fib6_metric == rt->fib6_metric && + rt6_qualify_for_ecmp(iter)) + return iter; + iter = rcu_dereference_protected(iter->fib6_next, + lockdep_is_held(&rt->fib6_table->tb6_lock)); + } + + return NULL; +} + +/* only called for fib entries with builtin fib6_nh */ +static bool rt6_is_dead(const struct fib6_info *rt) +{ + if (rt->fib6_nh->fib_nh_flags & RTNH_F_DEAD || + (rt->fib6_nh->fib_nh_flags & RTNH_F_LINKDOWN && + ip6_ignore_linkdown(rt->fib6_nh->fib_nh_dev))) + return true; + + return false; +} + +static int rt6_multipath_total_weight(const struct fib6_info *rt) +{ + struct fib6_info *iter; + int total = 0; + + if (!rt6_is_dead(rt)) + total += rt->fib6_nh->fib_nh_weight; + + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) { + if (!rt6_is_dead(iter)) + total += iter->fib6_nh->fib_nh_weight; + } + + return total; +} + +static void rt6_upper_bound_set(struct fib6_info *rt, int *weight, int total) +{ + int upper_bound = -1; + + if (!rt6_is_dead(rt)) { + *weight += rt->fib6_nh->fib_nh_weight; + upper_bound = DIV_ROUND_CLOSEST_ULL((u64) (*weight) << 31, + total) - 1; + } + atomic_set(&rt->fib6_nh->fib_nh_upper_bound, upper_bound); +} + +static void rt6_multipath_upper_bound_set(struct fib6_info *rt, int total) +{ + struct fib6_info *iter; + int weight = 0; + + rt6_upper_bound_set(rt, &weight, total); + + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) + rt6_upper_bound_set(iter, &weight, total); +} + +void rt6_multipath_rebalance(struct fib6_info *rt) +{ + struct fib6_info *first; + int total; + + /* In case the entire multipath route was marked for flushing, + * then there is no need to rebalance upon the removal of every + * sibling route. + */ + if (!rt->fib6_nsiblings || rt->should_flush) + return; + + /* During lookup routes are evaluated in order, so we need to + * make sure upper bounds are assigned from the first sibling + * onwards. 
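+ * Each live nexthop gets an upper bound proportional to its
+ * cumulative weight, scaled to 2^31 (see rt6_upper_bound_set()),
+ * while dead nexthops keep an upper bound of -1.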
+ */ + first = rt6_multipath_first_sibling(rt); + if (WARN_ON_ONCE(!first)) + return; + + total = rt6_multipath_total_weight(first); + rt6_multipath_upper_bound_set(first, total); +} + +static int fib6_ifup(struct fib6_info *rt, void *p_arg) +{ + const struct arg_netdev_event *arg = p_arg; + struct net *net = dev_net(arg->dev); + + if (rt != net->ipv6.fib6_null_entry && !rt->nh && + rt->fib6_nh->fib_nh_dev == arg->dev) { + rt->fib6_nh->fib_nh_flags &= ~arg->nh_flags; + fib6_update_sernum_upto_root(net, rt); + rt6_multipath_rebalance(rt); + } + + return 0; +} + +void rt6_sync_up(struct net_device *dev, unsigned char nh_flags) +{ + struct arg_netdev_event arg = { + .dev = dev, + { + .nh_flags = nh_flags, + }, + }; + + if (nh_flags & RTNH_F_DEAD && netif_carrier_ok(dev)) + arg.nh_flags |= RTNH_F_LINKDOWN; + + fib6_clean_all(dev_net(dev), fib6_ifup, &arg); +} + +/* only called for fib entries with inline fib6_nh */ +static bool rt6_multipath_uses_dev(const struct fib6_info *rt, + const struct net_device *dev) +{ + struct fib6_info *iter; + + if (rt->fib6_nh->fib_nh_dev == dev) + return true; + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) + if (iter->fib6_nh->fib_nh_dev == dev) + return true; + + return false; +} + +static void rt6_multipath_flush(struct fib6_info *rt) +{ + struct fib6_info *iter; + + rt->should_flush = 1; + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) + iter->should_flush = 1; +} + +static unsigned int rt6_multipath_dead_count(const struct fib6_info *rt, + const struct net_device *down_dev) +{ + struct fib6_info *iter; + unsigned int dead = 0; + + if (rt->fib6_nh->fib_nh_dev == down_dev || + rt->fib6_nh->fib_nh_flags & RTNH_F_DEAD) + dead++; + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) + if (iter->fib6_nh->fib_nh_dev == down_dev || + iter->fib6_nh->fib_nh_flags & RTNH_F_DEAD) + dead++; + + return dead; +} + +static void rt6_multipath_nh_flags_set(struct fib6_info *rt, + const struct net_device *dev, + unsigned char nh_flags) +{ + struct fib6_info *iter; + + if (rt->fib6_nh->fib_nh_dev == dev) + rt->fib6_nh->fib_nh_flags |= nh_flags; + list_for_each_entry(iter, &rt->fib6_siblings, fib6_siblings) + if (iter->fib6_nh->fib_nh_dev == dev) + iter->fib6_nh->fib_nh_flags |= nh_flags; +} + +/* called with write lock held for table with rt */ +static int fib6_ifdown(struct fib6_info *rt, void *p_arg) +{ + const struct arg_netdev_event *arg = p_arg; + const struct net_device *dev = arg->dev; + struct net *net = dev_net(dev); + + if (rt == net->ipv6.fib6_null_entry || rt->nh) + return 0; + + switch (arg->event) { + case NETDEV_UNREGISTER: + return rt->fib6_nh->fib_nh_dev == dev ? -1 : 0; + case NETDEV_DOWN: + if (rt->should_flush) + return -1; + if (!rt->fib6_nsiblings) + return rt->fib6_nh->fib_nh_dev == dev ? 
-1 : 0; + if (rt6_multipath_uses_dev(rt, dev)) { + unsigned int count; + + count = rt6_multipath_dead_count(rt, dev); + if (rt->fib6_nsiblings + 1 == count) { + rt6_multipath_flush(rt); + return -1; + } + rt6_multipath_nh_flags_set(rt, dev, RTNH_F_DEAD | + RTNH_F_LINKDOWN); + fib6_update_sernum(net, rt); + rt6_multipath_rebalance(rt); + } + return -2; + case NETDEV_CHANGE: + if (rt->fib6_nh->fib_nh_dev != dev || + rt->fib6_flags & (RTF_LOCAL | RTF_ANYCAST)) + break; + rt->fib6_nh->fib_nh_flags |= RTNH_F_LINKDOWN; + rt6_multipath_rebalance(rt); + break; + } + + return 0; +} + +void rt6_sync_down_dev(struct net_device *dev, unsigned long event) +{ + struct arg_netdev_event arg = { + .dev = dev, + { + .event = event, + }, + }; + struct net *net = dev_net(dev); + + if (net->ipv6.sysctl.skip_notify_on_dev_down) + fib6_clean_all_skip_notify(net, fib6_ifdown, &arg); + else + fib6_clean_all(net, fib6_ifdown, &arg); +} + +void rt6_disable_ip(struct net_device *dev, unsigned long event) +{ + rt6_sync_down_dev(dev, event); + rt6_uncached_list_flush_dev(dev); + neigh_ifdown(&nd_tbl, dev); +} + +struct rt6_mtu_change_arg { + struct net_device *dev; + unsigned int mtu; + struct fib6_info *f6i; +}; + +static int fib6_nh_mtu_change(struct fib6_nh *nh, void *_arg) +{ + struct rt6_mtu_change_arg *arg = (struct rt6_mtu_change_arg *)_arg; + struct fib6_info *f6i = arg->f6i; + + /* For administrative MTU increase, there is no way to discover + * IPv6 PMTU increase, so PMTU increase should be updated here. + * Since RFC 1981 doesn't include administrative MTU increase + * update PMTU increase is a MUST. (i.e. jumbo frame) + */ + if (nh->fib_nh_dev == arg->dev) { + struct inet6_dev *idev = __in6_dev_get(arg->dev); + u32 mtu = f6i->fib6_pmtu; + + if (mtu >= arg->mtu || + (mtu < arg->mtu && mtu == idev->cnf.mtu6)) + fib6_metric_set(f6i, RTAX_MTU, arg->mtu); + + spin_lock_bh(&rt6_exception_lock); + rt6_exceptions_update_pmtu(idev, nh, arg->mtu); + spin_unlock_bh(&rt6_exception_lock); + } + + return 0; +} + +static int rt6_mtu_change_route(struct fib6_info *f6i, void *p_arg) +{ + struct rt6_mtu_change_arg *arg = (struct rt6_mtu_change_arg *) p_arg; + struct inet6_dev *idev; + + /* In IPv6 pmtu discovery is not optional, + so that RTAX_MTU lock cannot disable it. + We still use this lock to block changes + caused by addrconf/ndisc. 
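+ Routes with a locked RTAX_MTU are therefore skipped below; the rest
+ are handed to fib6_nh_mtu_change(), which updates the stored metric
+ and any cached exceptions for nexthops on the changed device.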
+ */ + + idev = __in6_dev_get(arg->dev); + if (!idev) + return 0; + + if (fib6_metric_locked(f6i, RTAX_MTU)) + return 0; + + arg->f6i = f6i; + if (f6i->nh) { + /* fib6_nh_mtu_change only returns 0, so this is safe */ + return nexthop_for_each_fib6_nh(f6i->nh, fib6_nh_mtu_change, + arg); + } + + return fib6_nh_mtu_change(f6i->fib6_nh, arg); +} + +void rt6_mtu_change(struct net_device *dev, unsigned int mtu) +{ + struct rt6_mtu_change_arg arg = { + .dev = dev, + .mtu = mtu, + }; + + fib6_clean_all(dev_net(dev), rt6_mtu_change_route, &arg); +} + +static const struct nla_policy rtm_ipv6_policy[RTA_MAX+1] = { + [RTA_UNSPEC] = { .strict_start_type = RTA_DPORT + 1 }, + [RTA_GATEWAY] = { .len = sizeof(struct in6_addr) }, + [RTA_PREFSRC] = { .len = sizeof(struct in6_addr) }, + [RTA_OIF] = { .type = NLA_U32 }, + [RTA_IIF] = { .type = NLA_U32 }, + [RTA_PRIORITY] = { .type = NLA_U32 }, + [RTA_METRICS] = { .type = NLA_NESTED }, + [RTA_MULTIPATH] = { .len = sizeof(struct rtnexthop) }, + [RTA_PREF] = { .type = NLA_U8 }, + [RTA_ENCAP_TYPE] = { .type = NLA_U16 }, + [RTA_ENCAP] = { .type = NLA_NESTED }, + [RTA_EXPIRES] = { .type = NLA_U32 }, + [RTA_UID] = { .type = NLA_U32 }, + [RTA_MARK] = { .type = NLA_U32 }, + [RTA_TABLE] = { .type = NLA_U32 }, + [RTA_IP_PROTO] = { .type = NLA_U8 }, + [RTA_SPORT] = { .type = NLA_U16 }, + [RTA_DPORT] = { .type = NLA_U16 }, + [RTA_NH_ID] = { .type = NLA_U32 }, +}; + +static int rtm_to_fib6_config(struct sk_buff *skb, struct nlmsghdr *nlh, + struct fib6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + struct nlattr *tb[RTA_MAX+1]; + unsigned int pref; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv6_policy, extack); + if (err < 0) + goto errout; + + err = -EINVAL; + rtm = nlmsg_data(nlh); + + if (rtm->rtm_tos) { + NL_SET_ERR_MSG(extack, + "Invalid dsfield (tos): option not available for IPv6"); + goto errout; + } + + *cfg = (struct fib6_config){ + .fc_table = rtm->rtm_table, + .fc_dst_len = rtm->rtm_dst_len, + .fc_src_len = rtm->rtm_src_len, + .fc_flags = RTF_UP, + .fc_protocol = rtm->rtm_protocol, + .fc_type = rtm->rtm_type, + + .fc_nlinfo.portid = NETLINK_CB(skb).portid, + .fc_nlinfo.nlh = nlh, + .fc_nlinfo.nl_net = sock_net(skb->sk), + }; + + if (rtm->rtm_type == RTN_UNREACHABLE || + rtm->rtm_type == RTN_BLACKHOLE || + rtm->rtm_type == RTN_PROHIBIT || + rtm->rtm_type == RTN_THROW) + cfg->fc_flags |= RTF_REJECT; + + if (rtm->rtm_type == RTN_LOCAL) + cfg->fc_flags |= RTF_LOCAL; + + if (rtm->rtm_flags & RTM_F_CLONED) + cfg->fc_flags |= RTF_CACHE; + + cfg->fc_flags |= (rtm->rtm_flags & RTNH_F_ONLINK); + + if (tb[RTA_NH_ID]) { + if (tb[RTA_GATEWAY] || tb[RTA_OIF] || + tb[RTA_MULTIPATH] || tb[RTA_ENCAP]) { + NL_SET_ERR_MSG(extack, + "Nexthop specification and nexthop id are mutually exclusive"); + goto errout; + } + cfg->fc_nh_id = nla_get_u32(tb[RTA_NH_ID]); + } + + if (tb[RTA_GATEWAY]) { + cfg->fc_gateway = nla_get_in6_addr(tb[RTA_GATEWAY]); + cfg->fc_flags |= RTF_GATEWAY; + } + if (tb[RTA_VIA]) { + NL_SET_ERR_MSG(extack, "IPv6 does not support RTA_VIA attribute"); + goto errout; + } + + if (tb[RTA_DST]) { + int plen = (rtm->rtm_dst_len + 7) >> 3; + + if (nla_len(tb[RTA_DST]) < plen) + goto errout; + + nla_memcpy(&cfg->fc_dst, tb[RTA_DST], plen); + } + + if (tb[RTA_SRC]) { + int plen = (rtm->rtm_src_len + 7) >> 3; + + if (nla_len(tb[RTA_SRC]) < plen) + goto errout; + + nla_memcpy(&cfg->fc_src, tb[RTA_SRC], plen); + } + + if (tb[RTA_PREFSRC]) + cfg->fc_prefsrc = nla_get_in6_addr(tb[RTA_PREFSRC]); + + if 
(tb[RTA_OIF]) + cfg->fc_ifindex = nla_get_u32(tb[RTA_OIF]); + + if (tb[RTA_PRIORITY]) + cfg->fc_metric = nla_get_u32(tb[RTA_PRIORITY]); + + if (tb[RTA_METRICS]) { + cfg->fc_mx = nla_data(tb[RTA_METRICS]); + cfg->fc_mx_len = nla_len(tb[RTA_METRICS]); + } + + if (tb[RTA_TABLE]) + cfg->fc_table = nla_get_u32(tb[RTA_TABLE]); + + if (tb[RTA_MULTIPATH]) { + cfg->fc_mp = nla_data(tb[RTA_MULTIPATH]); + cfg->fc_mp_len = nla_len(tb[RTA_MULTIPATH]); + + err = lwtunnel_valid_encap_type_attr(cfg->fc_mp, + cfg->fc_mp_len, extack); + if (err < 0) + goto errout; + } + + if (tb[RTA_PREF]) { + pref = nla_get_u8(tb[RTA_PREF]); + if (pref != ICMPV6_ROUTER_PREF_LOW && + pref != ICMPV6_ROUTER_PREF_HIGH) + pref = ICMPV6_ROUTER_PREF_MEDIUM; + cfg->fc_flags |= RTF_PREF(pref); + } + + if (tb[RTA_ENCAP]) + cfg->fc_encap = tb[RTA_ENCAP]; + + if (tb[RTA_ENCAP_TYPE]) { + cfg->fc_encap_type = nla_get_u16(tb[RTA_ENCAP_TYPE]); + + err = lwtunnel_valid_encap_type(cfg->fc_encap_type, extack); + if (err < 0) + goto errout; + } + + if (tb[RTA_EXPIRES]) { + unsigned long timeout = addrconf_timeout_fixup(nla_get_u32(tb[RTA_EXPIRES]), HZ); + + if (addrconf_finite_timeout(timeout)) { + cfg->fc_expires = jiffies_to_clock_t(timeout * HZ); + cfg->fc_flags |= RTF_EXPIRES; + } + } + + err = 0; +errout: + return err; +} + +struct rt6_nh { + struct fib6_info *fib6_info; + struct fib6_config r_cfg; + struct list_head next; +}; + +static int ip6_route_info_append(struct net *net, + struct list_head *rt6_nh_list, + struct fib6_info *rt, + struct fib6_config *r_cfg) +{ + struct rt6_nh *nh; + int err = -EEXIST; + + list_for_each_entry(nh, rt6_nh_list, next) { + /* check if fib6_info already exists */ + if (rt6_duplicate_nexthop(nh->fib6_info, rt)) + return err; + } + + nh = kzalloc(sizeof(*nh), GFP_KERNEL); + if (!nh) + return -ENOMEM; + nh->fib6_info = rt; + memcpy(&nh->r_cfg, r_cfg, sizeof(*r_cfg)); + list_add_tail(&nh->next, rt6_nh_list); + + return 0; +} + +static void ip6_route_mpath_notify(struct fib6_info *rt, + struct fib6_info *rt_last, + struct nl_info *info, + __u16 nlflags) +{ + /* if this is an APPEND route, then rt points to the first route + * inserted and rt_last points to last route inserted. Userspace + * wants a consistent dump of the route which starts at the first + * nexthop. 
Since sibling routes are always added at the end of + * the list, find the first sibling of the last route appended + */ + if ((nlflags & NLM_F_APPEND) && rt_last && rt_last->fib6_nsiblings) { + rt = list_first_entry(&rt_last->fib6_siblings, + struct fib6_info, + fib6_siblings); + } + + if (rt) + inet6_rt_notify(RTM_NEWROUTE, rt, info, nlflags); +} + +static bool ip6_route_mpath_should_notify(const struct fib6_info *rt) +{ + bool rt_can_ecmp = rt6_qualify_for_ecmp(rt); + bool should_notify = false; + struct fib6_info *leaf; + struct fib6_node *fn; + + rcu_read_lock(); + fn = rcu_dereference(rt->fib6_node); + if (!fn) + goto out; + + leaf = rcu_dereference(fn->leaf); + if (!leaf) + goto out; + + if (rt == leaf || + (rt_can_ecmp && rt->fib6_metric == leaf->fib6_metric && + rt6_qualify_for_ecmp(leaf))) + should_notify = true; +out: + rcu_read_unlock(); + + return should_notify; +} + +static int fib6_gw_from_attr(struct in6_addr *gw, struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + if (nla_len(nla) < sizeof(*gw)) { + NL_SET_ERR_MSG(extack, "Invalid IPv6 address in RTA_GATEWAY"); + return -EINVAL; + } + + *gw = nla_get_in6_addr(nla); + + return 0; +} + +static int ip6_route_multipath_add(struct fib6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct fib6_info *rt_notif = NULL, *rt_last = NULL; + struct nl_info *info = &cfg->fc_nlinfo; + struct fib6_config r_cfg; + struct rtnexthop *rtnh; + struct fib6_info *rt; + struct rt6_nh *err_nh; + struct rt6_nh *nh, *nh_safe; + __u16 nlflags; + int remaining; + int attrlen; + int err = 1; + int nhn = 0; + int replace = (cfg->fc_nlinfo.nlh && + (cfg->fc_nlinfo.nlh->nlmsg_flags & NLM_F_REPLACE)); + LIST_HEAD(rt6_nh_list); + + nlflags = replace ? NLM_F_REPLACE : NLM_F_CREATE; + if (info->nlh && info->nlh->nlmsg_flags & NLM_F_APPEND) + nlflags |= NLM_F_APPEND; + + remaining = cfg->fc_mp_len; + rtnh = (struct rtnexthop *)cfg->fc_mp; + + /* Parse a Multipath Entry and build a list (rt6_nh_list) of + * fib6_info structs per nexthop + */ + while (rtnh_ok(rtnh, remaining)) { + memcpy(&r_cfg, cfg, sizeof(*cfg)); + if (rtnh->rtnh_ifindex) + r_cfg.fc_ifindex = rtnh->rtnh_ifindex; + + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + struct nlattr *nla, *attrs = rtnh_attrs(rtnh); + + nla = nla_find(attrs, attrlen, RTA_GATEWAY); + if (nla) { + err = fib6_gw_from_attr(&r_cfg.fc_gateway, nla, + extack); + if (err) + goto cleanup; + + r_cfg.fc_flags |= RTF_GATEWAY; + } + r_cfg.fc_encap = nla_find(attrs, attrlen, RTA_ENCAP); + + /* RTA_ENCAP_TYPE length checked in + * lwtunnel_valid_encap_type_attr + */ + nla = nla_find(attrs, attrlen, RTA_ENCAP_TYPE); + if (nla) + r_cfg.fc_encap_type = nla_get_u16(nla); + } + + r_cfg.fc_flags |= (rtnh->rtnh_flags & RTNH_F_ONLINK); + rt = ip6_route_info_create(&r_cfg, GFP_KERNEL, extack); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + rt = NULL; + goto cleanup; + } + if (!rt6_qualify_for_ecmp(rt)) { + err = -EINVAL; + NL_SET_ERR_MSG(extack, + "Device only routes can not be added for IPv6 using the multipath API."); + fib6_info_release(rt); + goto cleanup; + } + + rt->fib6_nh->fib_nh_weight = rtnh->rtnh_hops + 1; + + err = ip6_route_info_append(info->nl_net, &rt6_nh_list, + rt, &r_cfg); + if (err) { + fib6_info_release(rt); + goto cleanup; + } + + rtnh = rtnh_next(rtnh, &remaining); + } + + if (list_empty(&rt6_nh_list)) { + NL_SET_ERR_MSG(extack, + "Invalid nexthop configuration - no valid nexthops"); + return -EINVAL; + } + + /* for add and replace send one notification with all nexthops. 
+ * Skip the notification in fib6_add_rt2node and send one with + * the full route when done + */ + info->skip_notify = 1; + + /* For add and replace, send one notification with all nexthops. For + * append, send one notification with all appended nexthops. + */ + info->skip_notify_kernel = 1; + + err_nh = NULL; + list_for_each_entry(nh, &rt6_nh_list, next) { + err = __ip6_ins_rt(nh->fib6_info, info, extack); + fib6_info_release(nh->fib6_info); + + if (!err) { + /* save reference to last route successfully inserted */ + rt_last = nh->fib6_info; + + /* save reference to first route for notification */ + if (!rt_notif) + rt_notif = nh->fib6_info; + } + + /* nh->fib6_info is used or freed at this point, reset to NULL*/ + nh->fib6_info = NULL; + if (err) { + if (replace && nhn) + NL_SET_ERR_MSG_MOD(extack, + "multipath route replace failed (check consistency of installed routes)"); + err_nh = nh; + goto add_errout; + } + + /* Because each route is added like a single route we remove + * these flags after the first nexthop: if there is a collision, + * we have already failed to add the first nexthop: + * fib6_add_rt2node() has rejected it; when replacing, old + * nexthops have been replaced by first new, the rest should + * be added to it. + */ + if (cfg->fc_nlinfo.nlh) { + cfg->fc_nlinfo.nlh->nlmsg_flags &= ~(NLM_F_EXCL | + NLM_F_REPLACE); + cfg->fc_nlinfo.nlh->nlmsg_flags |= NLM_F_CREATE; + } + nhn++; + } + + /* An in-kernel notification should only be sent in case the new + * multipath route is added as the first route in the node, or if + * it was appended to it. We pass 'rt_notif' since it is the first + * sibling and might allow us to skip some checks in the replace case. + */ + if (ip6_route_mpath_should_notify(rt_notif)) { + enum fib_event_type fib_event; + + if (rt_notif->fib6_nsiblings != nhn - 1) + fib_event = FIB_EVENT_ENTRY_APPEND; + else + fib_event = FIB_EVENT_ENTRY_REPLACE; + + err = call_fib6_multipath_entry_notifiers(info->nl_net, + fib_event, rt_notif, + nhn - 1, extack); + if (err) { + /* Delete all the siblings that were just added */ + err_nh = NULL; + goto add_errout; + } + } + + /* success ... 
tell user about new route */ + ip6_route_mpath_notify(rt_notif, rt_last, info, nlflags); + goto cleanup; + +add_errout: + /* send notification for routes that were added so that + * the delete notifications sent by ip6_route_del are + * coherent + */ + if (rt_notif) + ip6_route_mpath_notify(rt_notif, rt_last, info, nlflags); + + /* Delete routes that were already added */ + list_for_each_entry(nh, &rt6_nh_list, next) { + if (err_nh == nh) + break; + ip6_route_del(&nh->r_cfg, extack); + } + +cleanup: + list_for_each_entry_safe(nh, nh_safe, &rt6_nh_list, next) { + if (nh->fib6_info) + fib6_info_release(nh->fib6_info); + list_del(&nh->next); + kfree(nh); + } + + return err; +} + +static int ip6_route_multipath_del(struct fib6_config *cfg, + struct netlink_ext_ack *extack) +{ + struct fib6_config r_cfg; + struct rtnexthop *rtnh; + int last_err = 0; + int remaining; + int attrlen; + int err; + + remaining = cfg->fc_mp_len; + rtnh = (struct rtnexthop *)cfg->fc_mp; + + /* Parse a Multipath Entry */ + while (rtnh_ok(rtnh, remaining)) { + memcpy(&r_cfg, cfg, sizeof(*cfg)); + if (rtnh->rtnh_ifindex) + r_cfg.fc_ifindex = rtnh->rtnh_ifindex; + + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + struct nlattr *nla, *attrs = rtnh_attrs(rtnh); + + nla = nla_find(attrs, attrlen, RTA_GATEWAY); + if (nla) { + err = fib6_gw_from_attr(&r_cfg.fc_gateway, nla, + extack); + if (err) { + last_err = err; + goto next_rtnh; + } + + r_cfg.fc_flags |= RTF_GATEWAY; + } + } + err = ip6_route_del(&r_cfg, extack); + if (err) + last_err = err; + +next_rtnh: + rtnh = rtnh_next(rtnh, &remaining); + } + + return last_err; +} + +static int inet6_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct fib6_config cfg; + int err; + + err = rtm_to_fib6_config(skb, nlh, &cfg, extack); + if (err < 0) + return err; + + if (cfg.fc_nh_id && + !nexthop_find_by_id(sock_net(skb->sk), cfg.fc_nh_id)) { + NL_SET_ERR_MSG(extack, "Nexthop id does not exist"); + return -EINVAL; + } + + if (cfg.fc_mp) + return ip6_route_multipath_del(&cfg, extack); + else { + cfg.fc_delete_all_nh = 1; + return ip6_route_del(&cfg, extack); + } +} + +static int inet6_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct fib6_config cfg; + int err; + + err = rtm_to_fib6_config(skb, nlh, &cfg, extack); + if (err < 0) + return err; + + if (cfg.fc_metric == 0) + cfg.fc_metric = IP6_RT_PRIO_USER; + + if (cfg.fc_mp) + return ip6_route_multipath_add(&cfg, extack); + else + return ip6_route_add(&cfg, GFP_KERNEL, extack); +} + +/* add the overhead of this fib6_nh to nexthop_len */ +static int rt6_nh_nlmsg_size(struct fib6_nh *nh, void *arg) +{ + int *nexthop_len = arg; + + *nexthop_len += nla_total_size(0) /* RTA_MULTIPATH */ + + NLA_ALIGN(sizeof(struct rtnexthop)) + + nla_total_size(16); /* RTA_GATEWAY */ + + if (nh->fib_nh_lws) { + /* RTA_ENCAP_TYPE */ + *nexthop_len += lwtunnel_get_encap_size(nh->fib_nh_lws); + /* RTA_ENCAP */ + *nexthop_len += nla_total_size(2); + } + + return 0; +} + +static size_t rt6_nlmsg_size(struct fib6_info *f6i) +{ + int nexthop_len; + + if (f6i->nh) { + nexthop_len = nla_total_size(4); /* RTA_NH_ID */ + nexthop_for_each_fib6_nh(f6i->nh, rt6_nh_nlmsg_size, + &nexthop_len); + } else { + struct fib6_info *sibling, *next_sibling; + struct fib6_nh *nh = f6i->fib6_nh; + + nexthop_len = 0; + if (f6i->fib6_nsiblings) { + rt6_nh_nlmsg_size(nh, &nexthop_len); + + list_for_each_entry_safe(sibling, next_sibling, + &f6i->fib6_siblings, fib6_siblings) { + 
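			/* Worked example (illustrative, assuming 4-byte netlink
			 * attribute alignment): each sizing call here adds
			 * nla_total_size(0) = 4 for the RTA_MULTIPATH nest,
			 * NLA_ALIGN(sizeof(struct rtnexthop)) = 8 and
			 * nla_total_size(16) = 20 for the per-hop RTA_GATEWAY,
			 * so a hypothetical route with two siblings grows the
			 * estimate by 2 * (4 + 8 + 20) = 64 bytes -- a safe
			 * overestimate for the skb later allocated from
			 * rt6_nlmsg_size(), e.g. in inet6_rt_notify().
			 */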
rt6_nh_nlmsg_size(sibling->fib6_nh, &nexthop_len); + } + } + nexthop_len += lwtunnel_get_encap_size(nh->fib_nh_lws); + } + + return NLMSG_ALIGN(sizeof(struct rtmsg)) + + nla_total_size(16) /* RTA_SRC */ + + nla_total_size(16) /* RTA_DST */ + + nla_total_size(16) /* RTA_GATEWAY */ + + nla_total_size(16) /* RTA_PREFSRC */ + + nla_total_size(4) /* RTA_TABLE */ + + nla_total_size(4) /* RTA_IIF */ + + nla_total_size(4) /* RTA_OIF */ + + nla_total_size(4) /* RTA_PRIORITY */ + + RTAX_MAX * nla_total_size(4) /* RTA_METRICS */ + + nla_total_size(sizeof(struct rta_cacheinfo)) + + nla_total_size(TCP_CA_NAME_MAX) /* RTAX_CC_ALGO */ + + nla_total_size(1) /* RTA_PREF */ + + nexthop_len; +} + +static int rt6_fill_node_nexthop(struct sk_buff *skb, struct nexthop *nh, + unsigned char *flags) +{ + if (nexthop_is_multipath(nh)) { + struct nlattr *mp; + + mp = nla_nest_start_noflag(skb, RTA_MULTIPATH); + if (!mp) + goto nla_put_failure; + + if (nexthop_mpath_fill_node(skb, nh, AF_INET6)) + goto nla_put_failure; + + nla_nest_end(skb, mp); + } else { + struct fib6_nh *fib6_nh; + + fib6_nh = nexthop_fib6_nh(nh); + if (fib_nexthop_info(skb, &fib6_nh->nh_common, AF_INET6, + flags, false) < 0) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int rt6_fill_node(struct net *net, struct sk_buff *skb, + struct fib6_info *rt, struct dst_entry *dst, + struct in6_addr *dest, struct in6_addr *src, + int iif, int type, u32 portid, u32 seq, + unsigned int flags) +{ + struct rt6_info *rt6 = (struct rt6_info *)dst; + struct rt6key *rt6_dst, *rt6_src; + u32 *pmetrics, table, rt6_flags; + unsigned char nh_flags = 0; + struct nlmsghdr *nlh; + struct rtmsg *rtm; + long expires = 0; + + nlh = nlmsg_put(skb, portid, seq, type, sizeof(*rtm), flags); + if (!nlh) + return -EMSGSIZE; + + if (rt6) { + rt6_dst = &rt6->rt6i_dst; + rt6_src = &rt6->rt6i_src; + rt6_flags = rt6->rt6i_flags; + } else { + rt6_dst = &rt->fib6_dst; + rt6_src = &rt->fib6_src; + rt6_flags = rt->fib6_flags; + } + + rtm = nlmsg_data(nlh); + rtm->rtm_family = AF_INET6; + rtm->rtm_dst_len = rt6_dst->plen; + rtm->rtm_src_len = rt6_src->plen; + rtm->rtm_tos = 0; + if (rt->fib6_table) + table = rt->fib6_table->tb6_id; + else + table = RT6_TABLE_UNSPEC; + rtm->rtm_table = table < 256 ? 
table : RT_TABLE_COMPAT; + if (nla_put_u32(skb, RTA_TABLE, table)) + goto nla_put_failure; + + rtm->rtm_type = rt->fib6_type; + rtm->rtm_flags = 0; + rtm->rtm_scope = RT_SCOPE_UNIVERSE; + rtm->rtm_protocol = rt->fib6_protocol; + + if (rt6_flags & RTF_CACHE) + rtm->rtm_flags |= RTM_F_CLONED; + + if (dest) { + if (nla_put_in6_addr(skb, RTA_DST, dest)) + goto nla_put_failure; + rtm->rtm_dst_len = 128; + } else if (rtm->rtm_dst_len) + if (nla_put_in6_addr(skb, RTA_DST, &rt6_dst->addr)) + goto nla_put_failure; +#ifdef CONFIG_IPV6_SUBTREES + if (src) { + if (nla_put_in6_addr(skb, RTA_SRC, src)) + goto nla_put_failure; + rtm->rtm_src_len = 128; + } else if (rtm->rtm_src_len && + nla_put_in6_addr(skb, RTA_SRC, &rt6_src->addr)) + goto nla_put_failure; +#endif + if (iif) { +#ifdef CONFIG_IPV6_MROUTE + if (ipv6_addr_is_multicast(&rt6_dst->addr)) { + int err = ip6mr_get_route(net, skb, rtm, portid); + + if (err == 0) + return 0; + if (err < 0) + goto nla_put_failure; + } else +#endif + if (nla_put_u32(skb, RTA_IIF, iif)) + goto nla_put_failure; + } else if (dest) { + struct in6_addr saddr_buf; + if (ip6_route_get_saddr(net, rt, dest, 0, &saddr_buf) == 0 && + nla_put_in6_addr(skb, RTA_PREFSRC, &saddr_buf)) + goto nla_put_failure; + } + + if (rt->fib6_prefsrc.plen) { + struct in6_addr saddr_buf; + saddr_buf = rt->fib6_prefsrc.addr; + if (nla_put_in6_addr(skb, RTA_PREFSRC, &saddr_buf)) + goto nla_put_failure; + } + + pmetrics = dst ? dst_metrics_ptr(dst) : rt->fib6_metrics->metrics; + if (rtnetlink_put_metrics(skb, pmetrics) < 0) + goto nla_put_failure; + + if (nla_put_u32(skb, RTA_PRIORITY, rt->fib6_metric)) + goto nla_put_failure; + + /* For multipath routes, walk the siblings list and add + * each as a nexthop within RTA_MULTIPATH. + */ + if (rt6) { + if (rt6_flags & RTF_GATEWAY && + nla_put_in6_addr(skb, RTA_GATEWAY, &rt6->rt6i_gateway)) + goto nla_put_failure; + + if (dst->dev && nla_put_u32(skb, RTA_OIF, dst->dev->ifindex)) + goto nla_put_failure; + + if (dst->lwtstate && + lwtunnel_fill_encap(skb, dst->lwtstate, RTA_ENCAP, RTA_ENCAP_TYPE) < 0) + goto nla_put_failure; + } else if (rt->fib6_nsiblings) { + struct fib6_info *sibling, *next_sibling; + struct nlattr *mp; + + mp = nla_nest_start_noflag(skb, RTA_MULTIPATH); + if (!mp) + goto nla_put_failure; + + if (fib_add_nexthop(skb, &rt->fib6_nh->nh_common, + rt->fib6_nh->fib_nh_weight, AF_INET6, + 0) < 0) + goto nla_put_failure; + + list_for_each_entry_safe(sibling, next_sibling, + &rt->fib6_siblings, fib6_siblings) { + if (fib_add_nexthop(skb, &sibling->fib6_nh->nh_common, + sibling->fib6_nh->fib_nh_weight, + AF_INET6, 0) < 0) + goto nla_put_failure; + } + + nla_nest_end(skb, mp); + } else if (rt->nh) { + if (nla_put_u32(skb, RTA_NH_ID, rt->nh->id)) + goto nla_put_failure; + + if (nexthop_is_blackhole(rt->nh)) + rtm->rtm_type = RTN_BLACKHOLE; + + if (READ_ONCE(net->ipv4.sysctl_nexthop_compat_mode) && + rt6_fill_node_nexthop(skb, rt->nh, &nh_flags) < 0) + goto nla_put_failure; + + rtm->rtm_flags |= nh_flags; + } else { + if (fib_nexthop_info(skb, &rt->fib6_nh->nh_common, AF_INET6, + &nh_flags, false) < 0) + goto nla_put_failure; + + rtm->rtm_flags |= nh_flags; + } + + if (rt6_flags & RTF_EXPIRES) { + expires = dst ? dst->expires : rt->expires; + expires -= jiffies; + } + + if (!dst) { + if (READ_ONCE(rt->offload)) + rtm->rtm_flags |= RTM_F_OFFLOAD; + if (READ_ONCE(rt->trap)) + rtm->rtm_flags |= RTM_F_TRAP; + if (READ_ONCE(rt->offload_failed)) + rtm->rtm_flags |= RTM_F_OFFLOAD_FAILED; + } + + if (rtnl_put_cacheinfo(skb, dst, 0, expires, dst ? 
dst->error : 0) < 0) + goto nla_put_failure; + + if (nla_put_u8(skb, RTA_PREF, IPV6_EXTRACT_PREF(rt6_flags))) + goto nla_put_failure; + + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int fib6_info_nh_uses_dev(struct fib6_nh *nh, void *arg) +{ + const struct net_device *dev = arg; + + if (nh->fib_nh_dev == dev) + return 1; + + return 0; +} + +static bool fib6_info_uses_dev(const struct fib6_info *f6i, + const struct net_device *dev) +{ + if (f6i->nh) { + struct net_device *_dev = (struct net_device *)dev; + + return !!nexthop_for_each_fib6_nh(f6i->nh, + fib6_info_nh_uses_dev, + _dev); + } + + if (f6i->fib6_nh->fib_nh_dev == dev) + return true; + + if (f6i->fib6_nsiblings) { + struct fib6_info *sibling, *next_sibling; + + list_for_each_entry_safe(sibling, next_sibling, + &f6i->fib6_siblings, fib6_siblings) { + if (sibling->fib6_nh->fib_nh_dev == dev) + return true; + } + } + + return false; +} + +struct fib6_nh_exception_dump_walker { + struct rt6_rtnl_dump_arg *dump; + struct fib6_info *rt; + unsigned int flags; + unsigned int skip; + unsigned int count; +}; + +static int rt6_nh_dump_exceptions(struct fib6_nh *nh, void *arg) +{ + struct fib6_nh_exception_dump_walker *w = arg; + struct rt6_rtnl_dump_arg *dump = w->dump; + struct rt6_exception_bucket *bucket; + struct rt6_exception *rt6_ex; + int i, err; + + bucket = fib6_nh_get_excptn_bucket(nh, NULL); + if (!bucket) + return 0; + + for (i = 0; i < FIB6_EXCEPTION_BUCKET_SIZE; i++) { + hlist_for_each_entry(rt6_ex, &bucket->chain, hlist) { + if (w->skip) { + w->skip--; + continue; + } + + /* Expiration of entries doesn't bump sernum, insertion + * does. Removal is triggered by insertion, so we can + * rely on the fact that if entries change between two + * partial dumps, this node is scanned again completely, + * see rt6_insert_exception() and fib6_dump_table(). + * + * Count expired entries we go through as handled + * entries that we'll skip next time, in case of partial + * node dump. Otherwise, if entries expire meanwhile, + * we'll skip the wrong amount. 
+ */ + if (rt6_check_expired(rt6_ex->rt6i)) { + w->count++; + continue; + } + + err = rt6_fill_node(dump->net, dump->skb, w->rt, + &rt6_ex->rt6i->dst, NULL, NULL, 0, + RTM_NEWROUTE, + NETLINK_CB(dump->cb->skb).portid, + dump->cb->nlh->nlmsg_seq, w->flags); + if (err) + return err; + + w->count++; + } + bucket++; + } + + return 0; +} + +/* Return -1 if done with node, number of handled routes on partial dump */ +int rt6_dump_route(struct fib6_info *rt, void *p_arg, unsigned int skip) +{ + struct rt6_rtnl_dump_arg *arg = (struct rt6_rtnl_dump_arg *) p_arg; + struct fib_dump_filter *filter = &arg->filter; + unsigned int flags = NLM_F_MULTI; + struct net *net = arg->net; + int count = 0; + + if (rt == net->ipv6.fib6_null_entry) + return -1; + + if ((filter->flags & RTM_F_PREFIX) && + !(rt->fib6_flags & RTF_PREFIX_RT)) { + /* success since this is not a prefix route */ + return -1; + } + if (filter->filter_set && + ((filter->rt_type && rt->fib6_type != filter->rt_type) || + (filter->dev && !fib6_info_uses_dev(rt, filter->dev)) || + (filter->protocol && rt->fib6_protocol != filter->protocol))) { + return -1; + } + + if (filter->filter_set || + !filter->dump_routes || !filter->dump_exceptions) { + flags |= NLM_F_DUMP_FILTERED; + } + + if (filter->dump_routes) { + if (skip) { + skip--; + } else { + if (rt6_fill_node(net, arg->skb, rt, NULL, NULL, NULL, + 0, RTM_NEWROUTE, + NETLINK_CB(arg->cb->skb).portid, + arg->cb->nlh->nlmsg_seq, flags)) { + return 0; + } + count++; + } + } + + if (filter->dump_exceptions) { + struct fib6_nh_exception_dump_walker w = { .dump = arg, + .rt = rt, + .flags = flags, + .skip = skip, + .count = 0 }; + int err; + + rcu_read_lock(); + if (rt->nh) { + err = nexthop_for_each_fib6_nh(rt->nh, + rt6_nh_dump_exceptions, + &w); + } else { + err = rt6_nh_dump_exceptions(rt->fib6_nh, &w); + } + rcu_read_unlock(); + + if (err) + return count + w.count; + } + + return -1; +} + +static int inet6_rtm_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for get route request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv6_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_src_len && rtm->rtm_src_len != 128) || + (rtm->rtm_dst_len && rtm->rtm_dst_len != 128) || + rtm->rtm_table || rtm->rtm_protocol || rtm->rtm_scope || + rtm->rtm_type) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request"); + return -EINVAL; + } + if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid flags for get route request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_ipv6_policy, extack); + if (err) + return err; + + if ((tb[RTA_SRC] && !rtm->rtm_src_len) || + (tb[RTA_DST] && !rtm->rtm_dst_len)) { + NL_SET_ERR_MSG_MOD(extack, "rtm_src_len and rtm_dst_len must be 128 for IPv6"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_SRC: + case RTA_DST: + case RTA_IIF: + case RTA_OIF: + case RTA_MARK: + case RTA_UID: + case RTA_SPORT: + case RTA_DPORT: + case RTA_IP_PROTO: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request"); + return -EINVAL; + } + } + + return 0; +} + +static int 
inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[RTA_MAX+1]; + int err, iif = 0, oif = 0; + struct fib6_info *from; + struct dst_entry *dst; + struct rt6_info *rt; + struct sk_buff *skb; + struct rtmsg *rtm; + struct flowi6 fl6 = {}; + bool fibmatch; + + err = inet6_rtm_valid_getroute_req(in_skb, nlh, tb, extack); + if (err < 0) + goto errout; + + err = -EINVAL; + rtm = nlmsg_data(nlh); + fl6.flowlabel = ip6_make_flowinfo(rtm->rtm_tos, 0); + fibmatch = !!(rtm->rtm_flags & RTM_F_FIB_MATCH); + + if (tb[RTA_SRC]) { + if (nla_len(tb[RTA_SRC]) < sizeof(struct in6_addr)) + goto errout; + + fl6.saddr = *(struct in6_addr *)nla_data(tb[RTA_SRC]); + } + + if (tb[RTA_DST]) { + if (nla_len(tb[RTA_DST]) < sizeof(struct in6_addr)) + goto errout; + + fl6.daddr = *(struct in6_addr *)nla_data(tb[RTA_DST]); + } + + if (tb[RTA_IIF]) + iif = nla_get_u32(tb[RTA_IIF]); + + if (tb[RTA_OIF]) + oif = nla_get_u32(tb[RTA_OIF]); + + if (tb[RTA_MARK]) + fl6.flowi6_mark = nla_get_u32(tb[RTA_MARK]); + + if (tb[RTA_UID]) + fl6.flowi6_uid = make_kuid(current_user_ns(), + nla_get_u32(tb[RTA_UID])); + else + fl6.flowi6_uid = iif ? INVALID_UID : current_uid(); + + if (tb[RTA_SPORT]) + fl6.fl6_sport = nla_get_be16(tb[RTA_SPORT]); + + if (tb[RTA_DPORT]) + fl6.fl6_dport = nla_get_be16(tb[RTA_DPORT]); + + if (tb[RTA_IP_PROTO]) { + err = rtm_getroute_parse_ip_proto(tb[RTA_IP_PROTO], + &fl6.flowi6_proto, AF_INET6, + extack); + if (err) + goto errout; + } + + if (iif) { + struct net_device *dev; + int flags = 0; + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, iif); + if (!dev) { + rcu_read_unlock(); + err = -ENODEV; + goto errout; + } + + fl6.flowi6_iif = iif; + + if (!ipv6_addr_any(&fl6.saddr)) + flags |= RT6_LOOKUP_F_HAS_SADDR; + + dst = ip6_route_input_lookup(net, dev, &fl6, NULL, flags); + + rcu_read_unlock(); + } else { + fl6.flowi6_oif = oif; + + dst = ip6_route_output(net, NULL, &fl6); + } + + + rt = container_of(dst, struct rt6_info, dst); + if (rt->dst.error) { + err = rt->dst.error; + ip6_rt_put(rt); + goto errout; + } + + if (rt == net->ipv6.ip6_null_entry) { + err = rt->dst.error; + ip6_rt_put(rt); + goto errout; + } + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) { + ip6_rt_put(rt); + err = -ENOBUFS; + goto errout; + } + + skb_dst_set(skb, &rt->dst); + + rcu_read_lock(); + from = rcu_dereference(rt->from); + if (from) { + if (fibmatch) + err = rt6_fill_node(net, skb, from, NULL, NULL, NULL, + iif, RTM_NEWROUTE, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0); + else + err = rt6_fill_node(net, skb, from, dst, &fl6.daddr, + &fl6.saddr, iif, RTM_NEWROUTE, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0); + } else { + err = -ENETUNREACH; + } + rcu_read_unlock(); + + if (err < 0) { + kfree_skb(skb); + goto errout; + } + + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +errout: + return err; +} + +void inet6_rt_notify(int event, struct fib6_info *rt, struct nl_info *info, + unsigned int nlm_flags) +{ + struct sk_buff *skb; + struct net *net = info->nl_net; + u32 seq; + int err; + + err = -ENOBUFS; + seq = info->nlh ? 
info->nlh->nlmsg_seq : 0; + + skb = nlmsg_new(rt6_nlmsg_size(rt), gfp_any()); + if (!skb) + goto errout; + + err = rt6_fill_node(net, skb, rt, NULL, NULL, NULL, 0, + event, info->portid, seq, nlm_flags); + if (err < 0) { + /* -EMSGSIZE implies BUG in rt6_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, info->portid, RTNLGRP_IPV6_ROUTE, + info->nlh, gfp_any()); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_ROUTE, err); +} + +void fib6_rt_update(struct net *net, struct fib6_info *rt, + struct nl_info *info) +{ + u32 seq = info->nlh ? info->nlh->nlmsg_seq : 0; + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(rt6_nlmsg_size(rt), gfp_any()); + if (!skb) + goto errout; + + err = rt6_fill_node(net, skb, rt, NULL, NULL, NULL, 0, + RTM_NEWROUTE, info->portid, seq, NLM_F_REPLACE); + if (err < 0) { + /* -EMSGSIZE implies BUG in rt6_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, info->portid, RTNLGRP_IPV6_ROUTE, + info->nlh, gfp_any()); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_IPV6_ROUTE, err); +} + +void fib6_info_hw_flags_set(struct net *net, struct fib6_info *f6i, + bool offload, bool trap, bool offload_failed) +{ + struct sk_buff *skb; + int err; + + if (READ_ONCE(f6i->offload) == offload && + READ_ONCE(f6i->trap) == trap && + READ_ONCE(f6i->offload_failed) == offload_failed) + return; + + WRITE_ONCE(f6i->offload, offload); + WRITE_ONCE(f6i->trap, trap); + + /* 2 means send notifications only if offload_failed was changed. */ + if (net->ipv6.sysctl.fib_notify_on_flag_change == 2 && + READ_ONCE(f6i->offload_failed) == offload_failed) + return; + + WRITE_ONCE(f6i->offload_failed, offload_failed); + + if (!rcu_access_pointer(f6i->fib6_node)) + /* The route was removed from the tree, do not send + * notification. + */ + return; + + if (!net->ipv6.sysctl.fib_notify_on_flag_change) + return; + + skb = nlmsg_new(rt6_nlmsg_size(f6i), GFP_KERNEL); + if (!skb) { + err = -ENOBUFS; + goto errout; + } + + err = rt6_fill_node(net, skb, f6i, NULL, NULL, NULL, 0, RTM_NEWROUTE, 0, + 0, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in rt6_nlmsg_size() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_IPV6_ROUTE, NULL, GFP_KERNEL); + return; + +errout: + rtnl_set_sk_err(net, RTNLGRP_IPV6_ROUTE, err); +} +EXPORT_SYMBOL(fib6_info_hw_flags_set); + +static int ip6_route_dev_notify(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + + if (!(dev->flags & IFF_LOOPBACK)) + return NOTIFY_OK; + + if (event == NETDEV_REGISTER) { + net->ipv6.fib6_null_entry->fib6_nh->fib_nh_dev = dev; + net->ipv6.ip6_null_entry->dst.dev = dev; + net->ipv6.ip6_null_entry->rt6i_idev = in6_dev_get(dev); +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + net->ipv6.ip6_prohibit_entry->dst.dev = dev; + net->ipv6.ip6_prohibit_entry->rt6i_idev = in6_dev_get(dev); + net->ipv6.ip6_blk_hole_entry->dst.dev = dev; + net->ipv6.ip6_blk_hole_entry->rt6i_idev = in6_dev_get(dev); +#endif + } else if (event == NETDEV_UNREGISTER && + dev->reg_state != NETREG_UNREGISTERED) { + /* NETDEV_UNREGISTER could be fired for multiple times by + * netdev_wait_allrefs(). Make sure we only call this once. 
+ */ + in6_dev_put_clear(&net->ipv6.ip6_null_entry->rt6i_idev); +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + in6_dev_put_clear(&net->ipv6.ip6_prohibit_entry->rt6i_idev); + in6_dev_put_clear(&net->ipv6.ip6_blk_hole_entry->rt6i_idev); +#endif + } + + return NOTIFY_OK; +} + +/* + * /proc + */ + +#ifdef CONFIG_PROC_FS +static int rt6_stats_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = (struct net *)seq->private; + seq_printf(seq, "%04x %04x %04x %04x %04x %04x %04x\n", + net->ipv6.rt6_stats->fib_nodes, + net->ipv6.rt6_stats->fib_route_nodes, + atomic_read(&net->ipv6.rt6_stats->fib_rt_alloc), + net->ipv6.rt6_stats->fib_rt_entries, + net->ipv6.rt6_stats->fib_rt_cache, + dst_entries_get_slow(&net->ipv6.ip6_dst_ops), + net->ipv6.rt6_stats->fib_discarded_routes); + + return 0; +} +#endif /* CONFIG_PROC_FS */ + +#ifdef CONFIG_SYSCTL + +static int ipv6_sysctl_rtcache_flush(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net; + int delay; + int ret; + if (!write) + return -EINVAL; + + net = (struct net *)ctl->extra1; + delay = net->ipv6.sysctl.flush_delay; + ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + if (ret) + return ret; + + fib6_run_gc(delay <= 0 ? 0 : (unsigned long)delay, net, delay > 0); + return 0; +} + +static struct ctl_table ipv6_route_table_template[] = { + { + .procname = "max_size", + .data = &init_net.ipv6.sysctl.ip6_rt_max_size, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "gc_thresh", + .data = &ip6_dst_ops_template.gc_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "flush", + .data = &init_net.ipv6.sysctl.flush_delay, + .maxlen = sizeof(int), + .mode = 0200, + .proc_handler = ipv6_sysctl_rtcache_flush + }, + { + .procname = "gc_min_interval", + .data = &init_net.ipv6.sysctl.ip6_rt_gc_min_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "gc_timeout", + .data = &init_net.ipv6.sysctl.ip6_rt_gc_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "gc_interval", + .data = &init_net.ipv6.sysctl.ip6_rt_gc_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "gc_elasticity", + .data = &init_net.ipv6.sysctl.ip6_rt_gc_elasticity, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "mtu_expires", + .data = &init_net.ipv6.sysctl.ip6_rt_mtu_expires, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "min_adv_mss", + .data = &init_net.ipv6.sysctl.ip6_rt_min_advmss, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "gc_min_interval_ms", + .data = &init_net.ipv6.sysctl.ip6_rt_gc_min_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_ms_jiffies, + }, + { + .procname = "skip_notify_on_dev_down", + .data = &init_net.ipv6.sysctl.skip_notify_on_dev_down, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { } +}; + +struct ctl_table * __net_init ipv6_route_sysctl_init(struct net *net) +{ + struct ctl_table *table; + + table = kmemdup(ipv6_route_table_template, + sizeof(ipv6_route_table_template), + GFP_KERNEL); + + if (table) { + table[0].data = &net->ipv6.sysctl.ip6_rt_max_size; + 
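The template above is kmemdup()'d per network namespace and each .data pointer is rebound to that namespace's fields, so every netns exposes its own copy of these knobs under /proc/sys/net/ipv6/route/. A minimal standalone read of one of them — illustrative only, not part of this patch, and assuming procfs is mounted at /proc:

#include <stdio.h>

int main(void)
{
	char buf[64];
	/* gc_timeout is one of the entries declared in the template above */
	FILE *f = fopen("/proc/sys/net/ipv6/route/gc_timeout", "r");

	if (!f) {
		perror("gc_timeout");
		return 1;
	}
	if (fgets(buf, sizeof(buf), f))
		printf("net.ipv6.route.gc_timeout = %s", buf);
	fclose(f);
	return 0;
}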
table[1].data = &net->ipv6.ip6_dst_ops.gc_thresh; + table[2].data = &net->ipv6.sysctl.flush_delay; + table[2].extra1 = net; + table[3].data = &net->ipv6.sysctl.ip6_rt_gc_min_interval; + table[4].data = &net->ipv6.sysctl.ip6_rt_gc_timeout; + table[5].data = &net->ipv6.sysctl.ip6_rt_gc_interval; + table[6].data = &net->ipv6.sysctl.ip6_rt_gc_elasticity; + table[7].data = &net->ipv6.sysctl.ip6_rt_mtu_expires; + table[8].data = &net->ipv6.sysctl.ip6_rt_min_advmss; + table[9].data = &net->ipv6.sysctl.ip6_rt_gc_min_interval; + table[10].data = &net->ipv6.sysctl.skip_notify_on_dev_down; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + table[1].procname = NULL; + } + + return table; +} +#endif + +static int __net_init ip6_route_net_init(struct net *net) +{ + int ret = -ENOMEM; + + memcpy(&net->ipv6.ip6_dst_ops, &ip6_dst_ops_template, + sizeof(net->ipv6.ip6_dst_ops)); + + if (dst_entries_init(&net->ipv6.ip6_dst_ops) < 0) + goto out_ip6_dst_ops; + + net->ipv6.fib6_null_entry = fib6_info_alloc(GFP_KERNEL, true); + if (!net->ipv6.fib6_null_entry) + goto out_ip6_dst_entries; + memcpy(net->ipv6.fib6_null_entry, &fib6_null_entry_template, + sizeof(*net->ipv6.fib6_null_entry)); + + net->ipv6.ip6_null_entry = kmemdup(&ip6_null_entry_template, + sizeof(*net->ipv6.ip6_null_entry), + GFP_KERNEL); + if (!net->ipv6.ip6_null_entry) + goto out_fib6_null_entry; + net->ipv6.ip6_null_entry->dst.ops = &net->ipv6.ip6_dst_ops; + dst_init_metrics(&net->ipv6.ip6_null_entry->dst, + ip6_template_metrics, true); + INIT_LIST_HEAD(&net->ipv6.ip6_null_entry->rt6i_uncached); + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + net->ipv6.fib6_has_custom_rules = false; + net->ipv6.ip6_prohibit_entry = kmemdup(&ip6_prohibit_entry_template, + sizeof(*net->ipv6.ip6_prohibit_entry), + GFP_KERNEL); + if (!net->ipv6.ip6_prohibit_entry) + goto out_ip6_null_entry; + net->ipv6.ip6_prohibit_entry->dst.ops = &net->ipv6.ip6_dst_ops; + dst_init_metrics(&net->ipv6.ip6_prohibit_entry->dst, + ip6_template_metrics, true); + INIT_LIST_HEAD(&net->ipv6.ip6_prohibit_entry->rt6i_uncached); + + net->ipv6.ip6_blk_hole_entry = kmemdup(&ip6_blk_hole_entry_template, + sizeof(*net->ipv6.ip6_blk_hole_entry), + GFP_KERNEL); + if (!net->ipv6.ip6_blk_hole_entry) + goto out_ip6_prohibit_entry; + net->ipv6.ip6_blk_hole_entry->dst.ops = &net->ipv6.ip6_dst_ops; + dst_init_metrics(&net->ipv6.ip6_blk_hole_entry->dst, + ip6_template_metrics, true); + INIT_LIST_HEAD(&net->ipv6.ip6_blk_hole_entry->rt6i_uncached); +#ifdef CONFIG_IPV6_SUBTREES + net->ipv6.fib6_routes_require_src = 0; +#endif +#endif + + net->ipv6.sysctl.flush_delay = 0; + net->ipv6.sysctl.ip6_rt_max_size = INT_MAX; + net->ipv6.sysctl.ip6_rt_gc_min_interval = HZ / 2; + net->ipv6.sysctl.ip6_rt_gc_timeout = 60*HZ; + net->ipv6.sysctl.ip6_rt_gc_interval = 30*HZ; + net->ipv6.sysctl.ip6_rt_gc_elasticity = 9; + net->ipv6.sysctl.ip6_rt_mtu_expires = 10*60*HZ; + net->ipv6.sysctl.ip6_rt_min_advmss = IPV6_MIN_MTU - 20 - 40; + net->ipv6.sysctl.skip_notify_on_dev_down = 0; + + atomic_set(&net->ipv6.ip6_rt_gc_expire, 30*HZ); + + ret = 0; +out: + return ret; + +#ifdef CONFIG_IPV6_MULTIPLE_TABLES +out_ip6_prohibit_entry: + kfree(net->ipv6.ip6_prohibit_entry); +out_ip6_null_entry: + kfree(net->ipv6.ip6_null_entry); +#endif +out_fib6_null_entry: + kfree(net->ipv6.fib6_null_entry); +out_ip6_dst_entries: + dst_entries_destroy(&net->ipv6.ip6_dst_ops); +out_ip6_dst_ops: + goto out; +} + +static void __net_exit ip6_route_net_exit(struct net *net) +{ + kfree(net->ipv6.fib6_null_entry); + 
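ip6_route_net_init() above follows the usual allocate-then-unwind shape: each allocation gets a goto label that releases everything set up before it, in reverse order, so a failure midway leaves nothing behind. A minimal standalone sketch of the same pattern — illustrative only, with hypothetical resources rather than kernel objects:

#include <stdio.h>
#include <stdlib.h>

struct ctx { void *a, *b, *c; };

static int ctx_init(struct ctx *ctx)
{
	ctx->a = malloc(16);
	if (!ctx->a)
		goto out;
	ctx->b = malloc(16);
	if (!ctx->b)
		goto out_a;
	ctx->c = malloc(16);
	if (!ctx->c)
		goto out_b;
	return 0;

out_b:			/* unwind in reverse order of setup */
	free(ctx->b);
out_a:
	free(ctx->a);
out:
	return -1;
}

int main(void)
{
	struct ctx ctx;

	printf("ctx_init: %d\n", ctx_init(&ctx));
	return 0;
}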
kfree(net->ipv6.ip6_null_entry); +#ifdef CONFIG_IPV6_MULTIPLE_TABLES + kfree(net->ipv6.ip6_prohibit_entry); + kfree(net->ipv6.ip6_blk_hole_entry); +#endif + dst_entries_destroy(&net->ipv6.ip6_dst_ops); +} + +static int __net_init ip6_route_net_init_late(struct net *net) +{ +#ifdef CONFIG_PROC_FS + if (!proc_create_net("ipv6_route", 0, net->proc_net, + &ipv6_route_seq_ops, + sizeof(struct ipv6_route_iter))) + return -ENOMEM; + + if (!proc_create_net_single("rt6_stats", 0444, net->proc_net, + rt6_stats_seq_show, NULL)) { + remove_proc_entry("ipv6_route", net->proc_net); + return -ENOMEM; + } +#endif + return 0; +} + +static void __net_exit ip6_route_net_exit_late(struct net *net) +{ +#ifdef CONFIG_PROC_FS + remove_proc_entry("ipv6_route", net->proc_net); + remove_proc_entry("rt6_stats", net->proc_net); +#endif +} + +static struct pernet_operations ip6_route_net_ops = { + .init = ip6_route_net_init, + .exit = ip6_route_net_exit, +}; + +static int __net_init ipv6_inetpeer_init(struct net *net) +{ + struct inet_peer_base *bp = kmalloc(sizeof(*bp), GFP_KERNEL); + + if (!bp) + return -ENOMEM; + inet_peer_base_init(bp); + net->ipv6.peers = bp; + return 0; +} + +static void __net_exit ipv6_inetpeer_exit(struct net *net) +{ + struct inet_peer_base *bp = net->ipv6.peers; + + net->ipv6.peers = NULL; + inetpeer_invalidate_tree(bp); + kfree(bp); +} + +static struct pernet_operations ipv6_inetpeer_ops = { + .init = ipv6_inetpeer_init, + .exit = ipv6_inetpeer_exit, +}; + +static struct pernet_operations ip6_route_net_late_ops = { + .init = ip6_route_net_init_late, + .exit = ip6_route_net_exit_late, +}; + +static struct notifier_block ip6_route_dev_notifier = { + .notifier_call = ip6_route_dev_notify, + .priority = ADDRCONF_NOTIFY_PRIORITY - 10, +}; + +void __init ip6_route_init_special_entries(void) +{ + /* Registering of the loopback is done before this portion of code, + * the loopback reference in rt6_info will not be taken, do it + * manually for init_net */ + init_net.ipv6.fib6_null_entry->fib6_nh->fib_nh_dev = init_net.loopback_dev; + init_net.ipv6.ip6_null_entry->dst.dev = init_net.loopback_dev; + init_net.ipv6.ip6_null_entry->rt6i_idev = in6_dev_get(init_net.loopback_dev); + #ifdef CONFIG_IPV6_MULTIPLE_TABLES + init_net.ipv6.ip6_prohibit_entry->dst.dev = init_net.loopback_dev; + init_net.ipv6.ip6_prohibit_entry->rt6i_idev = in6_dev_get(init_net.loopback_dev); + init_net.ipv6.ip6_blk_hole_entry->dst.dev = init_net.loopback_dev; + init_net.ipv6.ip6_blk_hole_entry->rt6i_idev = in6_dev_get(init_net.loopback_dev); + #endif +} + +#if IS_BUILTIN(CONFIG_IPV6) +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) +DEFINE_BPF_ITER_FUNC(ipv6_route, struct bpf_iter_meta *meta, struct fib6_info *rt) + +BTF_ID_LIST(btf_fib6_info_id) +BTF_ID(struct, fib6_info) + +static const struct bpf_iter_seq_info ipv6_route_seq_info = { + .seq_ops = &ipv6_route_seq_ops, + .init_seq_private = bpf_iter_init_seq_net, + .fini_seq_private = bpf_iter_fini_seq_net, + .seq_priv_size = sizeof(struct ipv6_route_iter), +}; + +static struct bpf_iter_reg ipv6_route_reg_info = { + .target = "ipv6_route", + .ctx_arg_info_size = 1, + .ctx_arg_info = { + { offsetof(struct bpf_iter__ipv6_route, rt), + PTR_TO_BTF_ID_OR_NULL }, + }, + .seq_info = &ipv6_route_seq_info, +}; + +static int __init bpf_iter_register(void) +{ + ipv6_route_reg_info.ctx_arg_info[0].btf_id = *btf_fib6_info_id; + return bpf_iter_reg_target(&ipv6_route_reg_info); +} + +static void bpf_iter_unregister(void) +{ + bpf_iter_unreg_target(&ipv6_route_reg_info); +} +#endif 
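The bpf_iter target registered above ("ipv6_route") can be attached to from a BPF iterator program. A hedged sketch of such a program, modelled on the shape of the kernel's own selftests — it assumes a BTF-generated vmlinux.h plus libbpf's bpf_helpers.h/bpf_tracing.h are available; the fib6_info field names mirror the usage earlier in this file:

// SPDX-License-Identifier: GPL-2.0
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>

char _license[] SEC("license") = "GPL";

SEC("iter/ipv6_route")
int dump_ipv6_route(struct bpf_iter__ipv6_route *ctx)
{
	struct seq_file *seq = ctx->meta->seq;
	struct fib6_info *rt = ctx->rt;

	if (!rt)	/* a NULL entry marks the end of the iteration */
		return 0;

	/* one line per FIB entry: destination prefix length and metric */
	BPF_SEQ_PRINTF(seq, "plen %u metric %u\n",
		       rt->fib6_dst.plen, rt->fib6_metric);
	return 0;
}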
+#endif + +int __init ip6_route_init(void) +{ + int ret; + int cpu; + + ret = -ENOMEM; + ip6_dst_ops_template.kmem_cachep = + kmem_cache_create("ip6_dst_cache", sizeof(struct rt6_info), 0, + SLAB_HWCACHE_ALIGN | SLAB_ACCOUNT, NULL); + if (!ip6_dst_ops_template.kmem_cachep) + goto out; + + ret = dst_entries_init(&ip6_dst_blackhole_ops); + if (ret) + goto out_kmem_cache; + + ret = register_pernet_subsys(&ipv6_inetpeer_ops); + if (ret) + goto out_dst_entries; + + ret = register_pernet_subsys(&ip6_route_net_ops); + if (ret) + goto out_register_inetpeer; + + ip6_dst_blackhole_ops.kmem_cachep = ip6_dst_ops_template.kmem_cachep; + + ret = fib6_init(); + if (ret) + goto out_register_subsys; + + ret = xfrm6_init(); + if (ret) + goto out_fib6_init; + + ret = fib6_rules_init(); + if (ret) + goto xfrm6_init; + + ret = register_pernet_subsys(&ip6_route_net_late_ops); + if (ret) + goto fib6_rules_init; + + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_NEWROUTE, + inet6_rtm_newroute, NULL, 0); + if (ret < 0) + goto out_register_late_subsys; + + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_DELROUTE, + inet6_rtm_delroute, NULL, 0); + if (ret < 0) + goto out_register_late_subsys; + + ret = rtnl_register_module(THIS_MODULE, PF_INET6, RTM_GETROUTE, + inet6_rtm_getroute, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + if (ret < 0) + goto out_register_late_subsys; + + ret = register_netdevice_notifier(&ip6_route_dev_notifier); + if (ret) + goto out_register_late_subsys; + +#if IS_BUILTIN(CONFIG_IPV6) +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + ret = bpf_iter_register(); + if (ret) + goto out_register_late_subsys; +#endif +#endif + + for_each_possible_cpu(cpu) { + struct uncached_list *ul = per_cpu_ptr(&rt6_uncached_list, cpu); + + INIT_LIST_HEAD(&ul->head); + INIT_LIST_HEAD(&ul->quarantine); + spin_lock_init(&ul->lock); + } + +out: + return ret; + +out_register_late_subsys: + rtnl_unregister_all(PF_INET6); + unregister_pernet_subsys(&ip6_route_net_late_ops); +fib6_rules_init: + fib6_rules_cleanup(); +xfrm6_init: + xfrm6_fini(); +out_fib6_init: + fib6_gc_cleanup(); +out_register_subsys: + unregister_pernet_subsys(&ip6_route_net_ops); +out_register_inetpeer: + unregister_pernet_subsys(&ipv6_inetpeer_ops); +out_dst_entries: + dst_entries_destroy(&ip6_dst_blackhole_ops); +out_kmem_cache: + kmem_cache_destroy(ip6_dst_ops_template.kmem_cachep); + goto out; +} + +void ip6_route_cleanup(void) +{ +#if IS_BUILTIN(CONFIG_IPV6) +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + bpf_iter_unregister(); +#endif +#endif + unregister_netdevice_notifier(&ip6_route_dev_notifier); + unregister_pernet_subsys(&ip6_route_net_late_ops); + fib6_rules_cleanup(); + xfrm6_fini(); + fib6_gc_cleanup(); + unregister_pernet_subsys(&ipv6_inetpeer_ops); + unregister_pernet_subsys(&ip6_route_net_ops); + dst_entries_destroy(&ip6_dst_blackhole_ops); + kmem_cache_destroy(ip6_dst_ops_template.kmem_cachep); +} diff --git a/net/ipv6/rpl.c b/net/ipv6/rpl.c new file mode 100644 index 000000000..d1876f192 --- /dev/null +++ b/net/ipv6/rpl.c @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Authors: + * (C) 2020 Alexander Aring + */ + +#include +#include + +#define IPV6_PFXTAIL_LEN(x) (sizeof(struct in6_addr) - (x)) +#define IPV6_RPL_BEST_ADDR_COMPRESSION 15 + +static void ipv6_rpl_addr_decompress(struct in6_addr *dst, + const struct in6_addr *daddr, + const void *post, unsigned char pfx) +{ + memcpy(dst, daddr, pfx); + memcpy(&dst->s6_addr[pfx], post, IPV6_PFXTAIL_LEN(pfx)); +} + +static void 
ipv6_rpl_addr_compress(void *dst, const struct in6_addr *addr, + unsigned char pfx) +{ + memcpy(dst, &addr->s6_addr[pfx], IPV6_PFXTAIL_LEN(pfx)); +} + +static void *ipv6_rpl_segdata_pos(const struct ipv6_rpl_sr_hdr *hdr, int i) +{ + return (void *)&hdr->rpl_segdata[i * IPV6_PFXTAIL_LEN(hdr->cmpri)]; +} + +size_t ipv6_rpl_srh_size(unsigned char n, unsigned char cmpri, + unsigned char cmpre) +{ + return sizeof(struct ipv6_rpl_sr_hdr) + (n * IPV6_PFXTAIL_LEN(cmpri)) + + IPV6_PFXTAIL_LEN(cmpre); +} + +void ipv6_rpl_srh_decompress(struct ipv6_rpl_sr_hdr *outhdr, + const struct ipv6_rpl_sr_hdr *inhdr, + const struct in6_addr *daddr, unsigned char n) +{ + int i; + + outhdr->nexthdr = inhdr->nexthdr; + outhdr->hdrlen = (((n + 1) * sizeof(struct in6_addr)) >> 3); + outhdr->pad = 0; + outhdr->type = inhdr->type; + outhdr->segments_left = inhdr->segments_left; + outhdr->cmpri = 0; + outhdr->cmpre = 0; + + for (i = 0; i < n; i++) + ipv6_rpl_addr_decompress(&outhdr->rpl_segaddr[i], daddr, + ipv6_rpl_segdata_pos(inhdr, i), + inhdr->cmpri); + + ipv6_rpl_addr_decompress(&outhdr->rpl_segaddr[n], daddr, + ipv6_rpl_segdata_pos(inhdr, n), + inhdr->cmpre); +} + +static unsigned char ipv6_rpl_srh_calc_cmpri(const struct ipv6_rpl_sr_hdr *inhdr, + const struct in6_addr *daddr, + unsigned char n) +{ + unsigned char plen; + int i; + + for (plen = 0; plen < sizeof(*daddr); plen++) { + for (i = 0; i < n; i++) { + if (daddr->s6_addr[plen] != + inhdr->rpl_segaddr[i].s6_addr[plen]) + return plen; + } + } + + return IPV6_RPL_BEST_ADDR_COMPRESSION; +} + +static unsigned char ipv6_rpl_srh_calc_cmpre(const struct in6_addr *daddr, + const struct in6_addr *last_segment) +{ + unsigned int plen; + + for (plen = 0; plen < sizeof(*daddr); plen++) { + if (daddr->s6_addr[plen] != last_segment->s6_addr[plen]) + return plen; + } + + return IPV6_RPL_BEST_ADDR_COMPRESSION; +} + +void ipv6_rpl_srh_compress(struct ipv6_rpl_sr_hdr *outhdr, + const struct ipv6_rpl_sr_hdr *inhdr, + const struct in6_addr *daddr, unsigned char n) +{ + unsigned char cmpri, cmpre; + size_t seglen; + int i; + + cmpri = ipv6_rpl_srh_calc_cmpri(inhdr, daddr, n); + cmpre = ipv6_rpl_srh_calc_cmpre(daddr, &inhdr->rpl_segaddr[n]); + + outhdr->nexthdr = inhdr->nexthdr; + seglen = (n * IPV6_PFXTAIL_LEN(cmpri)) + IPV6_PFXTAIL_LEN(cmpre); + outhdr->hdrlen = seglen >> 3; + if (seglen & 0x7) { + outhdr->hdrlen++; + outhdr->pad = 8 - (seglen & 0x7); + } else { + outhdr->pad = 0; + } + outhdr->type = inhdr->type; + outhdr->segments_left = inhdr->segments_left; + outhdr->cmpri = cmpri; + outhdr->cmpre = cmpre; + + for (i = 0; i < n; i++) + ipv6_rpl_addr_compress(ipv6_rpl_segdata_pos(outhdr, i), + &inhdr->rpl_segaddr[i], cmpri); + + ipv6_rpl_addr_compress(ipv6_rpl_segdata_pos(outhdr, n), + &inhdr->rpl_segaddr[n], cmpre); +} diff --git a/net/ipv6/rpl_iptunnel.c b/net/ipv6/rpl_iptunnel.c new file mode 100644 index 000000000..ff691d9f4 --- /dev/null +++ b/net/ipv6/rpl_iptunnel.c @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Authors: + * (C) 2020 Alexander Aring + */ + +#include + +#include +#include +#include +#include +#include + +struct rpl_iptunnel_encap { + struct ipv6_rpl_sr_hdr srh[0]; +}; + +struct rpl_lwt { + struct dst_cache cache; + struct rpl_iptunnel_encap tuninfo; +}; + +static inline struct rpl_lwt *rpl_lwt_lwtunnel(struct lwtunnel_state *lwt) +{ + return (struct rpl_lwt *)lwt->data; +} + +static inline struct rpl_iptunnel_encap * +rpl_encap_lwtunnel(struct lwtunnel_state *lwt) +{ + return &rpl_lwt_lwtunnel(lwt)->tuninfo; +} + +static const 
struct nla_policy rpl_iptunnel_policy[RPL_IPTUNNEL_MAX + 1] = { + [RPL_IPTUNNEL_SRH] = { .type = NLA_BINARY }, +}; + +static bool rpl_validate_srh(struct net *net, struct ipv6_rpl_sr_hdr *srh, + size_t seglen) +{ + int err; + + if ((srh->hdrlen << 3) != seglen) + return false; + + /* check at least one segment and seglen fit with segments_left */ + if (!srh->segments_left || + (srh->segments_left * sizeof(struct in6_addr)) != seglen) + return false; + + if (srh->cmpri || srh->cmpre) + return false; + + err = ipv6_chk_rpl_srh_loop(net, srh->rpl_segaddr, + srh->segments_left); + if (err) + return false; + + if (ipv6_addr_type(&srh->rpl_segaddr[srh->segments_left - 1]) & + IPV6_ADDR_MULTICAST) + return false; + + return true; +} + +static int rpl_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[RPL_IPTUNNEL_MAX + 1]; + struct lwtunnel_state *newts; + struct ipv6_rpl_sr_hdr *srh; + struct rpl_lwt *rlwt; + int err, srh_len; + + if (family != AF_INET6) + return -EINVAL; + + err = nla_parse_nested(tb, RPL_IPTUNNEL_MAX, nla, + rpl_iptunnel_policy, extack); + if (err < 0) + return err; + + if (!tb[RPL_IPTUNNEL_SRH]) + return -EINVAL; + + srh = nla_data(tb[RPL_IPTUNNEL_SRH]); + srh_len = nla_len(tb[RPL_IPTUNNEL_SRH]); + + if (srh_len < sizeof(*srh)) + return -EINVAL; + + /* verify that SRH is consistent */ + if (!rpl_validate_srh(net, srh, srh_len - sizeof(*srh))) + return -EINVAL; + + newts = lwtunnel_state_alloc(srh_len + sizeof(*rlwt)); + if (!newts) + return -ENOMEM; + + rlwt = rpl_lwt_lwtunnel(newts); + + err = dst_cache_init(&rlwt->cache, GFP_ATOMIC); + if (err) { + kfree(newts); + return err; + } + + memcpy(&rlwt->tuninfo.srh, srh, srh_len); + + newts->type = LWTUNNEL_ENCAP_RPL; + newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT; + newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT; + + *ts = newts; + + return 0; +} + +static void rpl_destroy_state(struct lwtunnel_state *lwt) +{ + dst_cache_destroy(&rpl_lwt_lwtunnel(lwt)->cache); +} + +static int rpl_do_srh_inline(struct sk_buff *skb, const struct rpl_lwt *rlwt, + const struct ipv6_rpl_sr_hdr *srh) +{ + struct ipv6_rpl_sr_hdr *isrh, *csrh; + const struct ipv6hdr *oldhdr; + struct ipv6hdr *hdr; + unsigned char *buf; + size_t hdrlen; + int err; + + oldhdr = ipv6_hdr(skb); + + buf = kcalloc(struct_size(srh, segments.addr, srh->segments_left), 2, GFP_ATOMIC); + if (!buf) + return -ENOMEM; + + isrh = (struct ipv6_rpl_sr_hdr *)buf; + csrh = (struct ipv6_rpl_sr_hdr *)(buf + ((srh->hdrlen + 1) << 3)); + + memcpy(isrh, srh, sizeof(*isrh)); + memcpy(isrh->rpl_segaddr, &srh->rpl_segaddr[1], + (srh->segments_left - 1) * 16); + isrh->rpl_segaddr[srh->segments_left - 1] = oldhdr->daddr; + + ipv6_rpl_srh_compress(csrh, isrh, &srh->rpl_segaddr[0], + isrh->segments_left - 1); + + hdrlen = ((csrh->hdrlen + 1) << 3); + + err = skb_cow_head(skb, hdrlen + skb->mac_len); + if (unlikely(err)) { + kfree(buf); + return err; + } + + skb_pull(skb, sizeof(struct ipv6hdr)); + skb_postpull_rcsum(skb, skb_network_header(skb), + sizeof(struct ipv6hdr)); + + skb_push(skb, sizeof(struct ipv6hdr) + hdrlen); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + + hdr = ipv6_hdr(skb); + memmove(hdr, oldhdr, sizeof(*hdr)); + isrh = (void *)hdr + sizeof(*hdr); + memcpy(isrh, csrh, hdrlen); + + isrh->nexthdr = hdr->nexthdr; + hdr->nexthdr = NEXTHDR_ROUTING; + hdr->daddr = srh->rpl_segaddr[0]; + + ipv6_hdr(skb)->payload_len = htons(skb->len - 
sizeof(struct ipv6hdr)); + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + skb_postpush_rcsum(skb, hdr, sizeof(struct ipv6hdr) + hdrlen); + + kfree(buf); + + return 0; +} + +static int rpl_do_srh(struct sk_buff *skb, const struct rpl_lwt *rlwt) +{ + struct dst_entry *dst = skb_dst(skb); + struct rpl_iptunnel_encap *tinfo; + + if (skb->protocol != htons(ETH_P_IPV6)) + return -EINVAL; + + tinfo = rpl_encap_lwtunnel(dst->lwtstate); + + return rpl_do_srh_inline(skb, rlwt, tinfo->srh); +} + +static int rpl_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct dst_entry *dst = NULL; + struct rpl_lwt *rlwt; + int err; + + rlwt = rpl_lwt_lwtunnel(orig_dst->lwtstate); + + err = rpl_do_srh(skb, rlwt); + if (unlikely(err)) + goto drop; + + preempt_disable(); + dst = dst_cache_get(&rlwt->cache); + preempt_enable(); + + if (unlikely(!dst)) { + struct ipv6hdr *hdr = ipv6_hdr(skb); + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + fl6.daddr = hdr->daddr; + fl6.saddr = hdr->saddr; + fl6.flowlabel = ip6_flowinfo(hdr); + fl6.flowi6_mark = skb->mark; + fl6.flowi6_proto = hdr->nexthdr; + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + err = dst->error; + dst_release(dst); + goto drop; + } + + preempt_disable(); + dst_cache_set_ip6(&rlwt->cache, dst, &fl6.saddr); + preempt_enable(); + } + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + goto drop; + + return dst_output(net, sk, skb); + +drop: + kfree_skb(skb); + return err; +} + +static int rpl_input(struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct dst_entry *dst = NULL; + struct rpl_lwt *rlwt; + int err; + + rlwt = rpl_lwt_lwtunnel(orig_dst->lwtstate); + + err = rpl_do_srh(skb, rlwt); + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + + preempt_disable(); + dst = dst_cache_get(&rlwt->cache); + preempt_enable(); + + skb_dst_drop(skb); + + if (!dst) { + ip6_route_input(skb); + dst = skb_dst(skb); + if (!dst->error) { + preempt_disable(); + dst_cache_set_ip6(&rlwt->cache, dst, + &ipv6_hdr(skb)->saddr); + preempt_enable(); + } + } else { + skb_dst_set(skb, dst); + } + + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + return err; + + return dst_input(skb); +} + +static int nla_put_rpl_srh(struct sk_buff *skb, int attrtype, + struct rpl_iptunnel_encap *tuninfo) +{ + struct rpl_iptunnel_encap *data; + struct nlattr *nla; + int len; + + len = RPL_IPTUNNEL_SRH_SIZE(tuninfo->srh); + + nla = nla_reserve(skb, attrtype, len); + if (!nla) + return -EMSGSIZE; + + data = nla_data(nla); + memcpy(data, tuninfo->srh, len); + + return 0; +} + +static int rpl_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct rpl_iptunnel_encap *tuninfo = rpl_encap_lwtunnel(lwtstate); + + if (nla_put_rpl_srh(skb, RPL_IPTUNNEL_SRH, tuninfo)) + return -EMSGSIZE; + + return 0; +} + +static int rpl_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + struct rpl_iptunnel_encap *tuninfo = rpl_encap_lwtunnel(lwtstate); + + return nla_total_size(RPL_IPTUNNEL_SRH_SIZE(tuninfo->srh)); +} + +static int rpl_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct rpl_iptunnel_encap *a_hdr = rpl_encap_lwtunnel(a); + struct rpl_iptunnel_encap *b_hdr = rpl_encap_lwtunnel(b); + int len = RPL_IPTUNNEL_SRH_SIZE(a_hdr->srh); + + if (len != RPL_IPTUNNEL_SRH_SIZE(b_hdr->srh)) + return 1; + + return memcmp(a_hdr, b_hdr, len); +} 
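One detail worth pulling out of the compression path used by this tunnel: the routing header's hdrlen field counts 8-octet units beyond the first 8 bytes, so ipv6_rpl_srh_compress() rounds the compressed segment data up and records the slack in the pad field. A standalone arithmetic check of that bookkeeping — illustrative only, with hypothetical segment sizes:

#include <stdio.h>

static void rpl_hdr_size(unsigned int seglen, unsigned int *hdrlen,
			 unsigned int *pad)
{
	*hdrlen = seglen >> 3;
	*pad = 0;
	if (seglen & 0x7) {
		(*hdrlen)++;
		*pad = 8 - (seglen & 0x7);
	}
}

int main(void)
{
	unsigned int hdrlen, pad;

	/* e.g. two 6-byte compressed segments plus a 2-byte final segment */
	rpl_hdr_size(14, &hdrlen, &pad);
	printf("hdrlen=%u pad=%u\n", hdrlen, pad);	/* prints hdrlen=2 pad=2 */
	return 0;
}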
+ +static const struct lwtunnel_encap_ops rpl_ops = { + .build_state = rpl_build_state, + .destroy_state = rpl_destroy_state, + .output = rpl_output, + .input = rpl_input, + .fill_encap = rpl_fill_encap_info, + .get_encap_size = rpl_encap_nlsize, + .cmp_encap = rpl_encap_cmp, + .owner = THIS_MODULE, +}; + +int __init rpl_init(void) +{ + int err; + + err = lwtunnel_encap_add_ops(&rpl_ops, LWTUNNEL_ENCAP_RPL); + if (err) + goto out; + + pr_info("RPL Segment Routing with IPv6\n"); + + return 0; + +out: + return err; +} + +void rpl_exit(void) +{ + lwtunnel_encap_del_ops(&rpl_ops, LWTUNNEL_ENCAP_RPL); +} diff --git a/net/ipv6/seg6.c b/net/ipv6/seg6.c new file mode 100644 index 000000000..29346a6ee --- /dev/null +++ b/net/ipv6/seg6.c @@ -0,0 +1,569 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SR-IPv6 implementation + * + * Author: + * David Lebrun + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#ifdef CONFIG_IPV6_SEG6_HMAC +#include +#endif + +bool seg6_validate_srh(struct ipv6_sr_hdr *srh, int len, bool reduced) +{ + unsigned int tlv_offset; + int max_last_entry; + int trailing; + + if (srh->type != IPV6_SRCRT_TYPE_4) + return false; + + if (((srh->hdrlen + 1) << 3) != len) + return false; + + if (!reduced && srh->segments_left > srh->first_segment) { + return false; + } else { + max_last_entry = (srh->hdrlen / 2) - 1; + + if (srh->first_segment > max_last_entry) + return false; + + if (srh->segments_left > srh->first_segment + 1) + return false; + } + + tlv_offset = sizeof(*srh) + ((srh->first_segment + 1) << 4); + + trailing = len - tlv_offset; + if (trailing < 0) + return false; + + while (trailing) { + struct sr6_tlv *tlv; + unsigned int tlv_len; + + if (trailing < sizeof(*tlv)) + return false; + + tlv = (struct sr6_tlv *)((unsigned char *)srh + tlv_offset); + tlv_len = sizeof(*tlv) + tlv->len; + + trailing -= tlv_len; + if (trailing < 0) + return false; + + tlv_offset += tlv_len; + } + + return true; +} + +struct ipv6_sr_hdr *seg6_get_srh(struct sk_buff *skb, int flags) +{ + struct ipv6_sr_hdr *srh; + int len, srhoff = 0; + + if (ipv6_find_hdr(skb, &srhoff, IPPROTO_ROUTING, NULL, &flags) < 0) + return NULL; + + if (!pskb_may_pull(skb, srhoff + sizeof(*srh))) + return NULL; + + srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + + len = (srh->hdrlen + 1) << 3; + + if (!pskb_may_pull(skb, srhoff + len)) + return NULL; + + /* note that pskb_may_pull may change pointers in header; + * for this reason it is necessary to reload them when needed. + */ + srh = (struct ipv6_sr_hdr *)(skb->data + srhoff); + + if (!seg6_validate_srh(srh, len, true)) + return NULL; + + return srh; +} + +/* Determine if an ICMP invoking packet contains a segment routing + * header. If it does, extract the offset to the true destination + * address, which is in the first segment address. + */ +void seg6_icmp_srh(struct sk_buff *skb, struct inet6_skb_parm *opt) +{ + __u16 network_header = skb->network_header; + struct ipv6_sr_hdr *srh; + + /* Update network header to point to the invoking packet + * inside the ICMP packet, so we can use the seg6_get_srh() + * helper. 
+ */ + skb_reset_network_header(skb); + + srh = seg6_get_srh(skb, 0); + if (!srh) + goto out; + + if (srh->type != IPV6_SRCRT_TYPE_4) + goto out; + + opt->flags |= IP6SKB_SEG6; + opt->srhoff = (unsigned char *)srh - skb->data; + +out: + /* Restore the network header back to the ICMP packet */ + skb->network_header = network_header; +} + +static struct genl_family seg6_genl_family; + +static const struct nla_policy seg6_genl_policy[SEG6_ATTR_MAX + 1] = { + [SEG6_ATTR_DST] = { .type = NLA_BINARY, + .len = sizeof(struct in6_addr) }, + [SEG6_ATTR_DSTLEN] = { .type = NLA_S32, }, + [SEG6_ATTR_HMACKEYID] = { .type = NLA_U32, }, + [SEG6_ATTR_SECRET] = { .type = NLA_BINARY, }, + [SEG6_ATTR_SECRETLEN] = { .type = NLA_U8, }, + [SEG6_ATTR_ALGID] = { .type = NLA_U8, }, + [SEG6_ATTR_HMACINFO] = { .type = NLA_NESTED, }, +}; + +#ifdef CONFIG_IPV6_SEG6_HMAC + +static int seg6_genl_sethmac(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct seg6_pernet_data *sdata; + struct seg6_hmac_info *hinfo; + u32 hmackeyid; + char *secret; + int err = 0; + u8 algid; + u8 slen; + + sdata = seg6_pernet(net); + + if (!info->attrs[SEG6_ATTR_HMACKEYID] || + !info->attrs[SEG6_ATTR_SECRETLEN] || + !info->attrs[SEG6_ATTR_ALGID]) + return -EINVAL; + + hmackeyid = nla_get_u32(info->attrs[SEG6_ATTR_HMACKEYID]); + slen = nla_get_u8(info->attrs[SEG6_ATTR_SECRETLEN]); + algid = nla_get_u8(info->attrs[SEG6_ATTR_ALGID]); + + if (hmackeyid == 0) + return -EINVAL; + + if (slen > SEG6_HMAC_SECRET_LEN) + return -EINVAL; + + mutex_lock(&sdata->lock); + hinfo = seg6_hmac_info_lookup(net, hmackeyid); + + if (!slen) { + err = seg6_hmac_info_del(net, hmackeyid); + + goto out_unlock; + } + + if (!info->attrs[SEG6_ATTR_SECRET]) { + err = -EINVAL; + goto out_unlock; + } + + if (slen > nla_len(info->attrs[SEG6_ATTR_SECRET])) { + err = -EINVAL; + goto out_unlock; + } + + if (hinfo) { + err = seg6_hmac_info_del(net, hmackeyid); + if (err) + goto out_unlock; + } + + secret = (char *)nla_data(info->attrs[SEG6_ATTR_SECRET]); + + hinfo = kzalloc(sizeof(*hinfo), GFP_KERNEL); + if (!hinfo) { + err = -ENOMEM; + goto out_unlock; + } + + memcpy(hinfo->secret, secret, slen); + hinfo->slen = slen; + hinfo->alg_id = algid; + hinfo->hmackeyid = hmackeyid; + + err = seg6_hmac_info_add(net, hmackeyid, hinfo); + if (err) + kfree(hinfo); + +out_unlock: + mutex_unlock(&sdata->lock); + return err; +} + +#else + +static int seg6_genl_sethmac(struct sk_buff *skb, struct genl_info *info) +{ + return -ENOTSUPP; +} + +#endif + +static int seg6_genl_set_tunsrc(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct in6_addr *val, *t_old, *t_new; + struct seg6_pernet_data *sdata; + + sdata = seg6_pernet(net); + + if (!info->attrs[SEG6_ATTR_DST]) + return -EINVAL; + + val = nla_data(info->attrs[SEG6_ATTR_DST]); + t_new = kmemdup(val, sizeof(*val), GFP_KERNEL); + if (!t_new) + return -ENOMEM; + + mutex_lock(&sdata->lock); + + t_old = sdata->tun_src; + rcu_assign_pointer(sdata->tun_src, t_new); + + mutex_unlock(&sdata->lock); + + synchronize_net(); + kfree(t_old); + + return 0; +} + +static int seg6_genl_get_tunsrc(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct in6_addr *tun_src; + struct sk_buff *msg; + void *hdr; + + msg = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &seg6_genl_family, 0, SEG6_CMD_GET_TUNSRC); + if (!hdr) + goto free_msg; + + 
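seg6_genl_sethmac() above bounds the user-declared secret length twice — against the fixed-size secret buffer and against the actual attribute payload — before the memcpy(). A standalone sketch of that check; illustrative only, and the 64-byte buffer size is an assumption standing in for SEG6_HMAC_SECRET_LEN:

#include <string.h>

#define SECRET_BUF_LEN 64	/* assumed stand-in for SEG6_HMAC_SECRET_LEN */

/* Copies the secret only if the declared length fits both the destination
 * buffer and the supplied attribute payload. */
static int store_secret(unsigned char dst[SECRET_BUF_LEN],
			const unsigned char *payload, unsigned int payload_len,
			unsigned int declared_len)
{
	if (!declared_len || declared_len > SECRET_BUF_LEN)
		return -1;	/* declared length out of range */
	if (declared_len > payload_len)
		return -1;	/* declared length exceeds what was supplied */
	memcpy(dst, payload, declared_len);
	return 0;
}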
rcu_read_lock(); + tun_src = rcu_dereference(seg6_pernet(net)->tun_src); + + if (nla_put(msg, SEG6_ATTR_DST, sizeof(struct in6_addr), tun_src)) + goto nla_put_failure; + + rcu_read_unlock(); + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + rcu_read_unlock(); +free_msg: + nlmsg_free(msg); + return -ENOMEM; +} + +#ifdef CONFIG_IPV6_SEG6_HMAC + +static int __seg6_hmac_fill_info(struct seg6_hmac_info *hinfo, + struct sk_buff *msg) +{ + if (nla_put_u32(msg, SEG6_ATTR_HMACKEYID, hinfo->hmackeyid) || + nla_put_u8(msg, SEG6_ATTR_SECRETLEN, hinfo->slen) || + nla_put(msg, SEG6_ATTR_SECRET, hinfo->slen, hinfo->secret) || + nla_put_u8(msg, SEG6_ATTR_ALGID, hinfo->alg_id)) + return -1; + + return 0; +} + +static int __seg6_genl_dumphmac_element(struct seg6_hmac_info *hinfo, + u32 portid, u32 seq, u32 flags, + struct sk_buff *skb, u8 cmd) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &seg6_genl_family, flags, cmd); + if (!hdr) + return -ENOMEM; + + if (__seg6_hmac_fill_info(hinfo, skb) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int seg6_genl_dumphmac_start(struct netlink_callback *cb) +{ + struct net *net = sock_net(cb->skb->sk); + struct seg6_pernet_data *sdata; + struct rhashtable_iter *iter; + + sdata = seg6_pernet(net); + iter = (struct rhashtable_iter *)cb->args[0]; + + if (!iter) { + iter = kmalloc(sizeof(*iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + + cb->args[0] = (long)iter; + } + + rhashtable_walk_enter(&sdata->hmac_infos, iter); + + return 0; +} + +static int seg6_genl_dumphmac_done(struct netlink_callback *cb) +{ + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + + rhashtable_walk_exit(iter); + + kfree(iter); + + return 0; +} + +static int seg6_genl_dumphmac(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct rhashtable_iter *iter = (struct rhashtable_iter *)cb->args[0]; + struct seg6_hmac_info *hinfo; + int ret; + + rhashtable_walk_start(iter); + + for (;;) { + hinfo = rhashtable_walk_next(iter); + + if (IS_ERR(hinfo)) { + if (PTR_ERR(hinfo) == -EAGAIN) + continue; + ret = PTR_ERR(hinfo); + goto done; + } else if (!hinfo) { + break; + } + + ret = __seg6_genl_dumphmac_element(hinfo, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + skb, SEG6_CMD_DUMPHMAC); + if (ret) + goto done; + } + + ret = skb->len; + +done: + rhashtable_walk_stop(iter); + return ret; +} + +#else + +static int seg6_genl_dumphmac_start(struct netlink_callback *cb) +{ + return 0; +} + +static int seg6_genl_dumphmac_done(struct netlink_callback *cb) +{ + return 0; +} + +static int seg6_genl_dumphmac(struct sk_buff *skb, struct netlink_callback *cb) +{ + return -ENOTSUPP; +} + +#endif + +static int __net_init seg6_net_init(struct net *net) +{ + struct seg6_pernet_data *sdata; + + sdata = kzalloc(sizeof(*sdata), GFP_KERNEL); + if (!sdata) + return -ENOMEM; + + mutex_init(&sdata->lock); + + sdata->tun_src = kzalloc(sizeof(*sdata->tun_src), GFP_KERNEL); + if (!sdata->tun_src) { + kfree(sdata); + return -ENOMEM; + } + + net->ipv6.seg6_data = sdata; + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (seg6_hmac_net_init(net)) { + kfree(rcu_dereference_raw(sdata->tun_src)); + kfree(sdata); + return -ENOMEM; + } +#endif + + return 0; +} + +static void __net_exit seg6_net_exit(struct net *net) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + +#ifdef CONFIG_IPV6_SEG6_HMAC + seg6_hmac_net_exit(net); +#endif + + 
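seg6_net_init()/seg6_net_exit() above tie all of this state (tunnel source, HMAC table, lock) to a network namespace via pernet_operations, so namespaces cannot see each other's SRv6 configuration. A small standalone illustration of entering a fresh namespace, for which the per-net init above runs again — requires CAP_SYS_ADMIN, illustrative only:

#define _GNU_SOURCE
#include <sched.h>
#include <stdio.h>

int main(void)
{
	/* a new netns triggers the registered pernet ->init() callbacks */
	if (unshare(CLONE_NEWNET)) {
		perror("unshare");
		return 1;
	}
	puts("new net namespace: independent SRv6 tunsrc/HMAC state");
	return 0;
}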
kfree(rcu_dereference_raw(sdata->tun_src)); + kfree(sdata); +} + +static struct pernet_operations ip6_segments_ops = { + .init = seg6_net_init, + .exit = seg6_net_exit, +}; + +static const struct genl_ops seg6_genl_ops[] = { + { + .cmd = SEG6_CMD_SETHMAC, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = seg6_genl_sethmac, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = SEG6_CMD_DUMPHMAC, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .start = seg6_genl_dumphmac_start, + .dumpit = seg6_genl_dumphmac, + .done = seg6_genl_dumphmac_done, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = SEG6_CMD_SET_TUNSRC, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = seg6_genl_set_tunsrc, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = SEG6_CMD_GET_TUNSRC, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = seg6_genl_get_tunsrc, + .flags = GENL_ADMIN_PERM, + }, +}; + +static struct genl_family seg6_genl_family __ro_after_init = { + .hdrsize = 0, + .name = SEG6_GENL_NAME, + .version = SEG6_GENL_VERSION, + .maxattr = SEG6_ATTR_MAX, + .policy = seg6_genl_policy, + .netnsok = true, + .parallel_ops = true, + .ops = seg6_genl_ops, + .n_ops = ARRAY_SIZE(seg6_genl_ops), + .resv_start_op = SEG6_CMD_GET_TUNSRC + 1, + .module = THIS_MODULE, +}; + +int __init seg6_init(void) +{ + int err; + + err = genl_register_family(&seg6_genl_family); + if (err) + goto out; + + err = register_pernet_subsys(&ip6_segments_ops); + if (err) + goto out_unregister_genl; + +#ifdef CONFIG_IPV6_SEG6_LWTUNNEL + err = seg6_iptunnel_init(); + if (err) + goto out_unregister_pernet; + + err = seg6_local_init(); + if (err) + goto out_unregister_pernet; +#endif + +#ifdef CONFIG_IPV6_SEG6_HMAC + err = seg6_hmac_init(); + if (err) + goto out_unregister_iptun; +#endif + + pr_info("Segment Routing with IPv6\n"); + +out: + return err; +#ifdef CONFIG_IPV6_SEG6_HMAC +out_unregister_iptun: +#ifdef CONFIG_IPV6_SEG6_LWTUNNEL + seg6_local_exit(); + seg6_iptunnel_exit(); +#endif +#endif +#ifdef CONFIG_IPV6_SEG6_LWTUNNEL +out_unregister_pernet: + unregister_pernet_subsys(&ip6_segments_ops); +#endif +out_unregister_genl: + genl_unregister_family(&seg6_genl_family); + goto out; +} + +void seg6_exit(void) +{ +#ifdef CONFIG_IPV6_SEG6_HMAC + seg6_hmac_exit(); +#endif +#ifdef CONFIG_IPV6_SEG6_LWTUNNEL + seg6_iptunnel_exit(); +#endif + unregister_pernet_subsys(&ip6_segments_ops); + genl_unregister_family(&seg6_genl_family); +} diff --git a/net/ipv6/seg6_hmac.c b/net/ipv6/seg6_hmac.c new file mode 100644 index 000000000..d43c50a73 --- /dev/null +++ b/net/ipv6/seg6_hmac.c @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SR-IPv6 implementation -- HMAC functions + * + * Author: + * David Lebrun + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static DEFINE_PER_CPU(char [SEG6_HMAC_RING_SIZE], hmac_ring); + +static int seg6_hmac_cmpfn(struct rhashtable_compare_arg *arg, const void *obj) +{ + const struct seg6_hmac_info *hinfo = obj; + + return (hinfo->hmackeyid != *(__u32 *)arg->key); +} + +static inline void seg6_hinfo_release(struct seg6_hmac_info *hinfo) +{ + kfree_rcu(hinfo, rcu); +} + +static void seg6_free_hi(void *ptr, void *arg) +{ + struct seg6_hmac_info *hinfo = (struct seg6_hmac_info *)ptr; 
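The kfree_rcu() behind seg6_hinfo_release() above pairs with the lockless lookups later in this file: a reader that found an entry under rcu_read_lock() may keep using it until it unlocks, even if the entry is removed concurrently. A kernel-style sketch of such a reader, using only calls that appear in this file — illustrative, not part of the patch:

static u8 seg6_hmac_alg_of(struct net *net, u32 hmackeyid)
{
	struct seg6_hmac_info *hinfo;
	u8 alg_id = 0;

	rcu_read_lock();
	hinfo = seg6_hmac_info_lookup(net, hmackeyid);
	if (hinfo)
		alg_id = hinfo->alg_id;	/* safe: freed only after a grace period */
	rcu_read_unlock();

	return alg_id;
}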
+ + if (hinfo) + seg6_hinfo_release(hinfo); +} + +static const struct rhashtable_params rht_params = { + .head_offset = offsetof(struct seg6_hmac_info, node), + .key_offset = offsetof(struct seg6_hmac_info, hmackeyid), + .key_len = sizeof(u32), + .automatic_shrinking = true, + .obj_cmpfn = seg6_hmac_cmpfn, +}; + +static struct seg6_hmac_algo hmac_algos[] = { + { + .alg_id = SEG6_HMAC_ALGO_SHA1, + .name = "hmac(sha1)", + }, + { + .alg_id = SEG6_HMAC_ALGO_SHA256, + .name = "hmac(sha256)", + }, +}; + +static struct sr6_tlv_hmac *seg6_get_tlv_hmac(struct ipv6_sr_hdr *srh) +{ + struct sr6_tlv_hmac *tlv; + + if (srh->hdrlen < (srh->first_segment + 1) * 2 + 5) + return NULL; + + if (!sr_has_hmac(srh)) + return NULL; + + tlv = (struct sr6_tlv_hmac *) + ((char *)srh + ((srh->hdrlen + 1) << 3) - 40); + + if (tlv->tlvhdr.type != SR6_TLV_HMAC || tlv->tlvhdr.len != 38) + return NULL; + + return tlv; +} + +static struct seg6_hmac_algo *__hmac_get_algo(u8 alg_id) +{ + struct seg6_hmac_algo *algo; + int i, alg_count; + + alg_count = ARRAY_SIZE(hmac_algos); + for (i = 0; i < alg_count; i++) { + algo = &hmac_algos[i]; + if (algo->alg_id == alg_id) + return algo; + } + + return NULL; +} + +static int __do_hmac(struct seg6_hmac_info *hinfo, const char *text, u8 psize, + u8 *output, int outlen) +{ + struct seg6_hmac_algo *algo; + struct crypto_shash *tfm; + struct shash_desc *shash; + int ret, dgsize; + + algo = __hmac_get_algo(hinfo->alg_id); + if (!algo) + return -ENOENT; + + tfm = *this_cpu_ptr(algo->tfms); + + dgsize = crypto_shash_digestsize(tfm); + if (dgsize > outlen) { + pr_debug("sr-ipv6: __do_hmac: digest size too big (%d / %d)\n", + dgsize, outlen); + return -ENOMEM; + } + + ret = crypto_shash_setkey(tfm, hinfo->secret, hinfo->slen); + if (ret < 0) { + pr_debug("sr-ipv6: crypto_shash_setkey failed: err %d\n", ret); + goto failed; + } + + shash = *this_cpu_ptr(algo->shashs); + shash->tfm = tfm; + + ret = crypto_shash_digest(shash, text, psize, output); + if (ret < 0) { + pr_debug("sr-ipv6: crypto_shash_digest failed: err %d\n", ret); + goto failed; + } + + return dgsize; + +failed: + return ret; +} + +int seg6_hmac_compute(struct seg6_hmac_info *hinfo, struct ipv6_sr_hdr *hdr, + struct in6_addr *saddr, u8 *output) +{ + __be32 hmackeyid = cpu_to_be32(hinfo->hmackeyid); + u8 tmp_out[SEG6_HMAC_MAX_DIGESTSIZE]; + int plen, i, dgsize, wrsize; + char *ring, *off; + + /* a 160-byte buffer for digest output allows to store highest known + * hash function (RadioGatun) with up to 1216 bits + */ + + /* saddr(16) + first_seg(1) + flags(1) + keyid(4) + seglist(16n) */ + plen = 16 + 1 + 1 + 4 + (hdr->first_segment + 1) * 16; + + /* this limit allows for 14 segments */ + if (plen >= SEG6_HMAC_RING_SIZE) + return -EMSGSIZE; + + /* Let's build the HMAC text on the ring buffer. The text is composed + * as follows, in order: + * + * 1. Source IPv6 address (128 bits) + * 2. first_segment value (8 bits) + * 3. Flags (8 bits) + * 4. HMAC Key ID (32 bits) + * 5. 
All segments in the segments list (n * 128 bits) + */ + + local_bh_disable(); + ring = this_cpu_ptr(hmac_ring); + off = ring; + + /* source address */ + memcpy(off, saddr, 16); + off += 16; + + /* first_segment value */ + *off++ = hdr->first_segment; + + /* flags */ + *off++ = hdr->flags; + + /* HMAC Key ID */ + memcpy(off, &hmackeyid, 4); + off += 4; + + /* all segments in the list */ + for (i = 0; i < hdr->first_segment + 1; i++) { + memcpy(off, hdr->segments + i, 16); + off += 16; + } + + dgsize = __do_hmac(hinfo, ring, plen, tmp_out, + SEG6_HMAC_MAX_DIGESTSIZE); + local_bh_enable(); + + if (dgsize < 0) + return dgsize; + + wrsize = SEG6_HMAC_FIELD_LEN; + if (wrsize > dgsize) + wrsize = dgsize; + + memset(output, 0, SEG6_HMAC_FIELD_LEN); + memcpy(output, tmp_out, wrsize); + + return 0; +} +EXPORT_SYMBOL(seg6_hmac_compute); + +/* checks if an incoming SR-enabled packet's HMAC status matches + * the incoming policy. + * + * called with rcu_read_lock() + */ +bool seg6_hmac_validate_skb(struct sk_buff *skb) +{ + u8 hmac_output[SEG6_HMAC_FIELD_LEN]; + struct net *net = dev_net(skb->dev); + struct seg6_hmac_info *hinfo; + struct sr6_tlv_hmac *tlv; + struct ipv6_sr_hdr *srh; + struct inet6_dev *idev; + + idev = __in6_dev_get(skb->dev); + + srh = (struct ipv6_sr_hdr *)skb_transport_header(skb); + + tlv = seg6_get_tlv_hmac(srh); + + /* mandatory check but no tlv */ + if (idev->cnf.seg6_require_hmac > 0 && !tlv) + return false; + + /* no check */ + if (idev->cnf.seg6_require_hmac < 0) + return true; + + /* check only if present */ + if (idev->cnf.seg6_require_hmac == 0 && !tlv) + return true; + + /* now, seg6_require_hmac >= 0 && tlv */ + + hinfo = seg6_hmac_info_lookup(net, be32_to_cpu(tlv->hmackeyid)); + if (!hinfo) + return false; + + if (seg6_hmac_compute(hinfo, srh, &ipv6_hdr(skb)->saddr, hmac_output)) + return false; + + if (memcmp(hmac_output, tlv->hmac, SEG6_HMAC_FIELD_LEN) != 0) + return false; + + return true; +} +EXPORT_SYMBOL(seg6_hmac_validate_skb); + +/* called with rcu_read_lock() */ +struct seg6_hmac_info *seg6_hmac_info_lookup(struct net *net, u32 key) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + struct seg6_hmac_info *hinfo; + + hinfo = rhashtable_lookup_fast(&sdata->hmac_infos, &key, rht_params); + + return hinfo; +} +EXPORT_SYMBOL(seg6_hmac_info_lookup); + +int seg6_hmac_info_add(struct net *net, u32 key, struct seg6_hmac_info *hinfo) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + int err; + + err = rhashtable_lookup_insert_fast(&sdata->hmac_infos, &hinfo->node, + rht_params); + + return err; +} +EXPORT_SYMBOL(seg6_hmac_info_add); + +int seg6_hmac_info_del(struct net *net, u32 key) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + struct seg6_hmac_info *hinfo; + int err = -ENOENT; + + hinfo = rhashtable_lookup_fast(&sdata->hmac_infos, &key, rht_params); + if (!hinfo) + goto out; + + err = rhashtable_remove_fast(&sdata->hmac_infos, &hinfo->node, + rht_params); + if (err) + goto out; + + seg6_hinfo_release(hinfo); + +out: + return err; +} +EXPORT_SYMBOL(seg6_hmac_info_del); + +int seg6_push_hmac(struct net *net, struct in6_addr *saddr, + struct ipv6_sr_hdr *srh) +{ + struct seg6_hmac_info *hinfo; + struct sr6_tlv_hmac *tlv; + int err = -ENOENT; + + tlv = seg6_get_tlv_hmac(srh); + if (!tlv) + return -EINVAL; + + rcu_read_lock(); + + hinfo = seg6_hmac_info_lookup(net, be32_to_cpu(tlv->hmackeyid)); + if (!hinfo) + goto out; + + memset(tlv->hmac, 0, SEG6_HMAC_FIELD_LEN); + err = seg6_hmac_compute(hinfo, srh, saddr, tlv->hmac); + +out: + 
rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(seg6_push_hmac); + +static int seg6_hmac_init_algo(void) +{ + struct seg6_hmac_algo *algo; + struct crypto_shash *tfm; + struct shash_desc *shash; + int i, alg_count, cpu; + + alg_count = ARRAY_SIZE(hmac_algos); + + for (i = 0; i < alg_count; i++) { + struct crypto_shash **p_tfm; + int shsize; + + algo = &hmac_algos[i]; + algo->tfms = alloc_percpu(struct crypto_shash *); + if (!algo->tfms) + return -ENOMEM; + + for_each_possible_cpu(cpu) { + tfm = crypto_alloc_shash(algo->name, 0, 0); + if (IS_ERR(tfm)) + return PTR_ERR(tfm); + p_tfm = per_cpu_ptr(algo->tfms, cpu); + *p_tfm = tfm; + } + + p_tfm = raw_cpu_ptr(algo->tfms); + tfm = *p_tfm; + + shsize = sizeof(*shash) + crypto_shash_descsize(tfm); + + algo->shashs = alloc_percpu(struct shash_desc *); + if (!algo->shashs) + return -ENOMEM; + + for_each_possible_cpu(cpu) { + shash = kzalloc_node(shsize, GFP_KERNEL, + cpu_to_node(cpu)); + if (!shash) + return -ENOMEM; + *per_cpu_ptr(algo->shashs, cpu) = shash; + } + } + + return 0; +} + +int __init seg6_hmac_init(void) +{ + return seg6_hmac_init_algo(); +} + +int __net_init seg6_hmac_net_init(struct net *net) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + + return rhashtable_init(&sdata->hmac_infos, &rht_params); +} + +void seg6_hmac_exit(void) +{ + struct seg6_hmac_algo *algo = NULL; + int i, alg_count, cpu; + + alg_count = ARRAY_SIZE(hmac_algos); + for (i = 0; i < alg_count; i++) { + algo = &hmac_algos[i]; + for_each_possible_cpu(cpu) { + struct crypto_shash *tfm; + struct shash_desc *shash; + + shash = *per_cpu_ptr(algo->shashs, cpu); + kfree(shash); + tfm = *per_cpu_ptr(algo->tfms, cpu); + crypto_free_shash(tfm); + } + free_percpu(algo->tfms); + free_percpu(algo->shashs); + } +} +EXPORT_SYMBOL(seg6_hmac_exit); + +void __net_exit seg6_hmac_net_exit(struct net *net) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + + rhashtable_free_and_destroy(&sdata->hmac_infos, seg6_free_hi, NULL); +} +EXPORT_SYMBOL(seg6_hmac_net_exit); diff --git a/net/ipv6/seg6_iptunnel.c b/net/ipv6/seg6_iptunnel.c new file mode 100644 index 000000000..34db88120 --- /dev/null +++ b/net/ipv6/seg6_iptunnel.c @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SR-IPv6 implementation + * + * Author: + * David Lebrun + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_IPV6_SEG6_HMAC +#include +#endif +#include + +static size_t seg6_lwt_headroom(struct seg6_iptunnel_encap *tuninfo) +{ + int head = 0; + + switch (tuninfo->mode) { + case SEG6_IPTUN_MODE_INLINE: + break; + case SEG6_IPTUN_MODE_ENCAP: + case SEG6_IPTUN_MODE_ENCAP_RED: + head = sizeof(struct ipv6hdr); + break; + case SEG6_IPTUN_MODE_L2ENCAP: + case SEG6_IPTUN_MODE_L2ENCAP_RED: + return 0; + } + + return ((tuninfo->srh->hdrlen + 1) << 3) + head; +} + +struct seg6_lwt { + struct dst_cache cache; + struct seg6_iptunnel_encap tuninfo[]; +}; + +static inline struct seg6_lwt *seg6_lwt_lwtunnel(struct lwtunnel_state *lwt) +{ + return (struct seg6_lwt *)lwt->data; +} + +static inline struct seg6_iptunnel_encap * +seg6_encap_lwtunnel(struct lwtunnel_state *lwt) +{ + return seg6_lwt_lwtunnel(lwt)->tuninfo; +} + +static const struct nla_policy seg6_iptunnel_policy[SEG6_IPTUNNEL_MAX + 1] = { + [SEG6_IPTUNNEL_SRH] = { .type = NLA_BINARY }, +}; + +static int nla_put_srh(struct sk_buff *skb, int attrtype, + struct seg6_iptunnel_encap *tuninfo) +{ + struct 
seg6_iptunnel_encap *data; + struct nlattr *nla; + int len; + + len = SEG6_IPTUN_ENCAP_SIZE(tuninfo); + + nla = nla_reserve(skb, attrtype, len); + if (!nla) + return -EMSGSIZE; + + data = nla_data(nla); + memcpy(data, tuninfo, len); + + return 0; +} + +static void set_tun_src(struct net *net, struct net_device *dev, + struct in6_addr *daddr, struct in6_addr *saddr) +{ + struct seg6_pernet_data *sdata = seg6_pernet(net); + struct in6_addr *tun_src; + + rcu_read_lock(); + + tun_src = rcu_dereference(sdata->tun_src); + + if (!ipv6_addr_any(tun_src)) { + memcpy(saddr, tun_src, sizeof(struct in6_addr)); + } else { + ipv6_dev_get_saddr(net, dev, daddr, IPV6_PREFER_SRC_PUBLIC, + saddr); + } + + rcu_read_unlock(); +} + +/* Compute flowlabel for outer IPv6 header */ +static __be32 seg6_make_flowlabel(struct net *net, struct sk_buff *skb, + struct ipv6hdr *inner_hdr) +{ + int do_flowlabel = net->ipv6.sysctl.seg6_flowlabel; + __be32 flowlabel = 0; + u32 hash; + + if (do_flowlabel > 0) { + hash = skb_get_hash(skb); + hash = rol32(hash, 16); + flowlabel = (__force __be32)hash & IPV6_FLOWLABEL_MASK; + } else if (!do_flowlabel && skb->protocol == htons(ETH_P_IPV6)) { + flowlabel = ip6_flowlabel(inner_hdr); + } + return flowlabel; +} + +/* encapsulate an IPv6 packet within an outer IPv6 header with a given SRH */ +int seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh, int proto) +{ + struct dst_entry *dst = skb_dst(skb); + struct net *net = dev_net(dst->dev); + struct ipv6hdr *hdr, *inner_hdr; + struct ipv6_sr_hdr *isrh; + int hdrlen, tot_len, err; + __be32 flowlabel; + + hdrlen = (osrh->hdrlen + 1) << 3; + tot_len = hdrlen + sizeof(*hdr); + + err = skb_cow_head(skb, tot_len + skb->mac_len); + if (unlikely(err)) + return err; + + inner_hdr = ipv6_hdr(skb); + flowlabel = seg6_make_flowlabel(net, skb, inner_hdr); + + skb_push(skb, tot_len); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + hdr = ipv6_hdr(skb); + + /* inherit tc, flowlabel and hlim + * hlim will be decremented in ip6_forward() afterwards and + * decapsulation will overwrite inner hlim with outer hlim + */ + + if (skb->protocol == htons(ETH_P_IPV6)) { + ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)), + flowlabel); + hdr->hop_limit = inner_hdr->hop_limit; + } else { + ip6_flow_hdr(hdr, 0, flowlabel); + hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb)); + + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + + /* the control block has been erased, so we have to set the + * iif once again. + * We read the receiving interface index directly from the + * skb->skb_iif as it is done in the IPv4 receiving path (i.e.: + * ip_rcv_core(...)). 
+ */ + IP6CB(skb)->iif = skb->skb_iif; + } + + hdr->nexthdr = NEXTHDR_ROUTING; + + isrh = (void *)hdr + sizeof(*hdr); + memcpy(isrh, osrh, hdrlen); + + isrh->nexthdr = proto; + + hdr->daddr = isrh->segments[isrh->first_segment]; + set_tun_src(net, dst->dev, &hdr->daddr, &hdr->saddr); + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (sr_has_hmac(isrh)) { + err = seg6_push_hmac(net, &hdr->saddr, isrh); + if (unlikely(err)) + return err; + } +#endif + + hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + + skb_postpush_rcsum(skb, hdr, tot_len); + + return 0; +} +EXPORT_SYMBOL_GPL(seg6_do_srh_encap); + +/* encapsulate an IPv6 packet within an outer IPv6 header with reduced SRH */ +static int seg6_do_srh_encap_red(struct sk_buff *skb, + struct ipv6_sr_hdr *osrh, int proto) +{ + __u8 first_seg = osrh->first_segment; + struct dst_entry *dst = skb_dst(skb); + struct net *net = dev_net(dst->dev); + struct ipv6hdr *hdr, *inner_hdr; + int hdrlen = ipv6_optlen(osrh); + int red_tlv_offset, tlv_offset; + struct ipv6_sr_hdr *isrh; + bool skip_srh = false; + __be32 flowlabel; + int tot_len, err; + int red_hdrlen; + int tlvs_len; + + if (first_seg > 0) { + red_hdrlen = hdrlen - sizeof(struct in6_addr); + } else { + /* NOTE: if tag/flags and/or other TLVs are introduced in the + * seg6_iptunnel infrastructure, they should be considered when + * deciding to skip the SRH. + */ + skip_srh = !sr_has_hmac(osrh); + + red_hdrlen = skip_srh ? 0 : hdrlen; + } + + tot_len = red_hdrlen + sizeof(struct ipv6hdr); + + err = skb_cow_head(skb, tot_len + skb->mac_len); + if (unlikely(err)) + return err; + + inner_hdr = ipv6_hdr(skb); + flowlabel = seg6_make_flowlabel(net, skb, inner_hdr); + + skb_push(skb, tot_len); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + hdr = ipv6_hdr(skb); + + /* based on seg6_do_srh_encap() */ + if (skb->protocol == htons(ETH_P_IPV6)) { + ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)), + flowlabel); + hdr->hop_limit = inner_hdr->hop_limit; + } else { + ip6_flow_hdr(hdr, 0, flowlabel); + hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb)); + + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + IP6CB(skb)->iif = skb->skb_iif; + } + + /* no matter if we have to skip the SRH or not, the first segment + * always comes in the pushed IPv6 header. + */ + hdr->daddr = osrh->segments[first_seg]; + + if (skip_srh) { + hdr->nexthdr = proto; + + set_tun_src(net, dst->dev, &hdr->daddr, &hdr->saddr); + goto out; + } + + /* we cannot skip the SRH, slow path */ + + hdr->nexthdr = NEXTHDR_ROUTING; + isrh = (void *)hdr + sizeof(struct ipv6hdr); + + if (unlikely(!first_seg)) { + /* this is a very rare case; we have only one SID but + * we cannot skip the SRH since we are carrying some + * other info. 
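+		 * Currently the only such info is the HMAC TLV: skip_srh is
+		 * computed above as !sr_has_hmac(osrh), so an SRH protected by
+		 * an HMAC is always pushed in full, even with a single SID.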
+ */ + memcpy(isrh, osrh, hdrlen); + goto srcaddr; + } + + tlv_offset = sizeof(*osrh) + (first_seg + 1) * sizeof(struct in6_addr); + red_tlv_offset = tlv_offset - sizeof(struct in6_addr); + + memcpy(isrh, osrh, red_tlv_offset); + + tlvs_len = hdrlen - tlv_offset; + if (unlikely(tlvs_len > 0)) { + const void *s = (const void *)osrh + tlv_offset; + void *d = (void *)isrh + red_tlv_offset; + + memcpy(d, s, tlvs_len); + } + + --isrh->first_segment; + isrh->hdrlen -= 2; + +srcaddr: + isrh->nexthdr = proto; + set_tun_src(net, dst->dev, &hdr->daddr, &hdr->saddr); + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (unlikely(!skip_srh && sr_has_hmac(isrh))) { + err = seg6_push_hmac(net, &hdr->saddr, isrh); + if (unlikely(err)) + return err; + } +#endif + +out: + hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + + skb_postpush_rcsum(skb, hdr, tot_len); + + return 0; +} + +/* insert an SRH within an IPv6 packet, just after the IPv6 header */ +int seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh) +{ + struct ipv6hdr *hdr, *oldhdr; + struct ipv6_sr_hdr *isrh; + int hdrlen, err; + + hdrlen = (osrh->hdrlen + 1) << 3; + + err = skb_cow_head(skb, hdrlen + skb->mac_len); + if (unlikely(err)) + return err; + + oldhdr = ipv6_hdr(skb); + + skb_pull(skb, sizeof(struct ipv6hdr)); + skb_postpull_rcsum(skb, skb_network_header(skb), + sizeof(struct ipv6hdr)); + + skb_push(skb, sizeof(struct ipv6hdr) + hdrlen); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + + hdr = ipv6_hdr(skb); + + memmove(hdr, oldhdr, sizeof(*hdr)); + + isrh = (void *)hdr + sizeof(*hdr); + memcpy(isrh, osrh, hdrlen); + + isrh->nexthdr = hdr->nexthdr; + hdr->nexthdr = NEXTHDR_ROUTING; + + isrh->segments[0] = hdr->daddr; + hdr->daddr = isrh->segments[isrh->first_segment]; + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (sr_has_hmac(isrh)) { + struct net *net = dev_net(skb_dst(skb)->dev); + + err = seg6_push_hmac(net, &hdr->saddr, isrh); + if (unlikely(err)) + return err; + } +#endif + + hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + + skb_postpush_rcsum(skb, hdr, sizeof(struct ipv6hdr) + hdrlen); + + return 0; +} +EXPORT_SYMBOL_GPL(seg6_do_srh_inline); + +static int seg6_do_srh(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct seg6_iptunnel_encap *tinfo; + int proto, err = 0; + + tinfo = seg6_encap_lwtunnel(dst->lwtstate); + + switch (tinfo->mode) { + case SEG6_IPTUN_MODE_INLINE: + if (skb->protocol != htons(ETH_P_IPV6)) + return -EINVAL; + + err = seg6_do_srh_inline(skb, tinfo->srh); + if (err) + return err; + break; + case SEG6_IPTUN_MODE_ENCAP: + case SEG6_IPTUN_MODE_ENCAP_RED: + err = iptunnel_handle_offloads(skb, SKB_GSO_IPXIP6); + if (err) + return err; + + if (skb->protocol == htons(ETH_P_IPV6)) + proto = IPPROTO_IPV6; + else if (skb->protocol == htons(ETH_P_IP)) + proto = IPPROTO_IPIP; + else + return -EINVAL; + + if (tinfo->mode == SEG6_IPTUN_MODE_ENCAP) + err = seg6_do_srh_encap(skb, tinfo->srh, proto); + else + err = seg6_do_srh_encap_red(skb, tinfo->srh, proto); + + if (err) + return err; + + skb_set_inner_transport_header(skb, skb_transport_offset(skb)); + skb_set_inner_protocol(skb, skb->protocol); + skb->protocol = htons(ETH_P_IPV6); + break; + case SEG6_IPTUN_MODE_L2ENCAP: + case SEG6_IPTUN_MODE_L2ENCAP_RED: + if (!skb_mac_header_was_set(skb)) + return -EINVAL; + + if (pskb_expand_head(skb, skb->mac_len, 0, GFP_ATOMIC) < 0) + return -ENOMEM; + + skb_mac_header_rebuild(skb); + skb_push(skb, skb->mac_len); + + if (tinfo->mode == SEG6_IPTUN_MODE_L2ENCAP) + err = 
seg6_do_srh_encap(skb, tinfo->srh, + IPPROTO_ETHERNET); + else + err = seg6_do_srh_encap_red(skb, tinfo->srh, + IPPROTO_ETHERNET); + + if (err) + return err; + + skb->protocol = htons(ETH_P_IPV6); + break; + } + + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + nf_reset_ct(skb); + + return 0; +} + +static int seg6_input_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + return dst_input(skb); +} + +static int seg6_input_core(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct dst_entry *dst = NULL; + struct seg6_lwt *slwt; + int err; + + err = seg6_do_srh(skb); + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + + slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate); + + preempt_disable(); + dst = dst_cache_get(&slwt->cache); + preempt_enable(); + + skb_dst_drop(skb); + + if (!dst) { + ip6_route_input(skb); + dst = skb_dst(skb); + if (!dst->error) { + preempt_disable(); + dst_cache_set_ip6(&slwt->cache, dst, + &ipv6_hdr(skb)->saddr); + preempt_enable(); + } + } else { + skb_dst_set(skb, dst); + } + + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + return err; + + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, + dev_net(skb->dev), NULL, skb, NULL, + skb_dst(skb)->dev, seg6_input_finish); + + return seg6_input_finish(dev_net(skb->dev), NULL, skb); +} + +static int seg6_input_nf(struct sk_buff *skb) +{ + struct net_device *dev = skb_dst(skb)->dev; + struct net *net = dev_net(skb->dev); + + switch (skb->protocol) { + case htons(ETH_P_IP): + return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, NULL, + skb, NULL, dev, seg6_input_core); + case htons(ETH_P_IPV6): + return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, NULL, + skb, NULL, dev, seg6_input_core); + } + + return -EINVAL; +} + +static int seg6_input(struct sk_buff *skb) +{ + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return seg6_input_nf(skb); + + return seg6_input_core(dev_net(skb->dev), NULL, skb); +} + +static int seg6_output_core(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct dst_entry *dst = NULL; + struct seg6_lwt *slwt; + int err; + + err = seg6_do_srh(skb); + if (unlikely(err)) + goto drop; + + slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate); + + preempt_disable(); + dst = dst_cache_get(&slwt->cache); + preempt_enable(); + + if (unlikely(!dst)) { + struct ipv6hdr *hdr = ipv6_hdr(skb); + struct flowi6 fl6; + + memset(&fl6, 0, sizeof(fl6)); + fl6.daddr = hdr->daddr; + fl6.saddr = hdr->saddr; + fl6.flowlabel = ip6_flowinfo(hdr); + fl6.flowi6_mark = skb->mark; + fl6.flowi6_proto = hdr->nexthdr; + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) { + err = dst->error; + dst_release(dst); + goto drop; + } + + preempt_disable(); + dst_cache_set_ip6(&slwt->cache, dst, &fl6.saddr); + preempt_enable(); + } + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + err = skb_cow_head(skb, LL_RESERVED_SPACE(dst->dev)); + if (unlikely(err)) + goto drop; + + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb, + NULL, skb_dst(skb)->dev, dst_output); + + return dst_output(net, sk, skb); +drop: + kfree_skb(skb); + return err; +} + +static int seg6_output_nf(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct net_device *dev = skb_dst(skb)->dev; + + switch (skb->protocol) { + case htons(ETH_P_IP): + 
return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, sk, skb, + NULL, dev, seg6_output_core); + case htons(ETH_P_IPV6): + return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, sk, skb, + NULL, dev, seg6_output_core); + } + + return -EINVAL; +} + +static int seg6_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return seg6_output_nf(net, sk, skb); + + return seg6_output_core(net, sk, skb); +} + +static int seg6_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[SEG6_IPTUNNEL_MAX + 1]; + struct seg6_iptunnel_encap *tuninfo; + struct lwtunnel_state *newts; + int tuninfo_len, min_size; + struct seg6_lwt *slwt; + int err; + + if (family != AF_INET && family != AF_INET6) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, SEG6_IPTUNNEL_MAX, nla, + seg6_iptunnel_policy, extack); + + if (err < 0) + return err; + + if (!tb[SEG6_IPTUNNEL_SRH]) + return -EINVAL; + + tuninfo = nla_data(tb[SEG6_IPTUNNEL_SRH]); + tuninfo_len = nla_len(tb[SEG6_IPTUNNEL_SRH]); + + /* tuninfo must contain at least the iptunnel encap structure, + * the SRH and one segment + */ + min_size = sizeof(*tuninfo) + sizeof(struct ipv6_sr_hdr) + + sizeof(struct in6_addr); + if (tuninfo_len < min_size) + return -EINVAL; + + switch (tuninfo->mode) { + case SEG6_IPTUN_MODE_INLINE: + if (family != AF_INET6) + return -EINVAL; + + break; + case SEG6_IPTUN_MODE_ENCAP: + break; + case SEG6_IPTUN_MODE_L2ENCAP: + break; + case SEG6_IPTUN_MODE_ENCAP_RED: + break; + case SEG6_IPTUN_MODE_L2ENCAP_RED: + break; + default: + return -EINVAL; + } + + /* verify that SRH is consistent */ + if (!seg6_validate_srh(tuninfo->srh, tuninfo_len - sizeof(*tuninfo), false)) + return -EINVAL; + + newts = lwtunnel_state_alloc(tuninfo_len + sizeof(*slwt)); + if (!newts) + return -ENOMEM; + + slwt = seg6_lwt_lwtunnel(newts); + + err = dst_cache_init(&slwt->cache, GFP_ATOMIC); + if (err) { + kfree(newts); + return err; + } + + memcpy(&slwt->tuninfo, tuninfo, tuninfo_len); + + newts->type = LWTUNNEL_ENCAP_SEG6; + newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT; + + if (tuninfo->mode != SEG6_IPTUN_MODE_L2ENCAP) + newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT; + + newts->headroom = seg6_lwt_headroom(tuninfo); + + *ts = newts; + + return 0; +} + +static void seg6_destroy_state(struct lwtunnel_state *lwt) +{ + dst_cache_destroy(&seg6_lwt_lwtunnel(lwt)->cache); +} + +static int seg6_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate); + + if (nla_put_srh(skb, SEG6_IPTUNNEL_SRH, tuninfo)) + return -EMSGSIZE; + + return 0; +} + +static int seg6_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate); + + return nla_total_size(SEG6_IPTUN_ENCAP_SIZE(tuninfo)); +} + +static int seg6_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct seg6_iptunnel_encap *a_hdr = seg6_encap_lwtunnel(a); + struct seg6_iptunnel_encap *b_hdr = seg6_encap_lwtunnel(b); + int len = SEG6_IPTUN_ENCAP_SIZE(a_hdr); + + if (len != SEG6_IPTUN_ENCAP_SIZE(b_hdr)) + return 1; + + return memcmp(a_hdr, b_hdr, len); +} + +static const struct lwtunnel_encap_ops seg6_iptun_ops = { + .build_state = seg6_build_state, + .destroy_state = seg6_destroy_state, + .output = seg6_output, + .input = seg6_input, + .fill_encap = 
seg6_fill_encap_info, + .get_encap_size = seg6_encap_nlsize, + .cmp_encap = seg6_encap_cmp, + .owner = THIS_MODULE, +}; + +int __init seg6_iptunnel_init(void) +{ + return lwtunnel_encap_add_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6); +} + +void seg6_iptunnel_exit(void) +{ + lwtunnel_encap_del_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6); +} diff --git a/net/ipv6/seg6_local.c b/net/ipv6/seg6_local.c new file mode 100644 index 000000000..8370726ae --- /dev/null +++ b/net/ipv6/seg6_local.c @@ -0,0 +1,2310 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SR-IPv6 implementation + * + * Authors: + * David Lebrun + * eBPF support: Mathieu Xhonneux + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_IPV6_SEG6_HMAC +#include +#endif +#include +#include +#include +#include + +#define SEG6_F_ATTR(i) BIT(i) + +struct seg6_local_lwt; + +/* callbacks used for customizing the creation and destruction of a behavior */ +struct seg6_local_lwtunnel_ops { + int (*build_state)(struct seg6_local_lwt *slwt, const void *cfg, + struct netlink_ext_ack *extack); + void (*destroy_state)(struct seg6_local_lwt *slwt); +}; + +struct seg6_action_desc { + int action; + unsigned long attrs; + + /* The optattrs field is used for specifying all the optional + * attributes supported by a specific behavior. + * It means that if one of these attributes is not provided in the + * netlink message during the behavior creation, no errors will be + * returned to the userspace. + * + * Each attribute can be only of two types (mutually exclusive): + * 1) required or 2) optional. + * Every user MUST obey to this rule! If you set an attribute as + * required the same attribute CANNOT be set as optional and vice + * versa. + */ + unsigned long optattrs; + + int (*input)(struct sk_buff *skb, struct seg6_local_lwt *slwt); + int static_headroom; + + struct seg6_local_lwtunnel_ops slwt_ops; +}; + +struct bpf_lwt_prog { + struct bpf_prog *prog; + char *name; +}; + +/* default length values (expressed in bits) for both Locator-Block and + * Locator-Node Function. + * + * Both SEG6_LOCAL_LCBLOCK_DBITS and SEG6_LOCAL_LCNODE_FN_DBITS *must* be: + * i) greater than 0; + * ii) evenly divisible by 8. In other terms, the lengths of the + * Locator-Block and Locator-Node Function must be byte-aligned (we can + * relax this constraint in the future if really needed). + * + * Moreover, a third condition must hold: + * iii) SEG6_LOCAL_LCBLOCK_DBITS + SEG6_LOCAL_LCNODE_FN_DBITS <= 128. + * + * The correctness of SEG6_LOCAL_LCBLOCK_DBITS and SEG6_LOCAL_LCNODE_FN_DBITS + * values are checked during the kernel compilation. If the compilation stops, + * check the value of these parameters to see if they meet conditions (i), (ii) + * and (iii). + */ +#define SEG6_LOCAL_LCBLOCK_DBITS 32 +#define SEG6_LOCAL_LCNODE_FN_DBITS 16 + +/* The following next_csid_chk_{cntr,lcblock,lcblock_fn}_bits macros can be + * used directly to check whether the lengths (in bits) of Locator-Block and + * Locator-Node Function are valid according to (i), (ii), (iii). 
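+ *
+ * For example, the default values above (SEG6_LOCAL_LCBLOCK_DBITS = 32 and
+ * SEG6_LOCAL_LCNODE_FN_DBITS = 16) satisfy (i), (ii) and (iii) and leave
+ * 128 - (32 + 16) = 80 bits for the DA.Argument consumed by the NEXT-C-SID
+ * flavor handlers below.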
+ */ +#define next_csid_chk_cntr_bits(blen, flen) \ + ((blen) + (flen) > 128) + +#define next_csid_chk_lcblock_bits(blen) \ +({ \ + typeof(blen) __tmp = blen; \ + (!__tmp || __tmp > 120 || (__tmp & 0x07)); \ +}) + +#define next_csid_chk_lcnode_fn_bits(flen) \ + next_csid_chk_lcblock_bits(flen) + +/* Supported Flavor operations are reported in this bitmask */ +#define SEG6_LOCAL_FLV_SUPP_OPS (BIT(SEG6_LOCAL_FLV_OP_NEXT_CSID)) + +struct seg6_flavors_info { + /* Flavor operations */ + __u32 flv_ops; + + /* Locator-Block length, expressed in bits */ + __u8 lcblock_bits; + /* Locator-Node Function length, expressed in bits*/ + __u8 lcnode_func_bits; +}; + +enum seg6_end_dt_mode { + DT_INVALID_MODE = -EINVAL, + DT_LEGACY_MODE = 0, + DT_VRF_MODE = 1, +}; + +struct seg6_end_dt_info { + enum seg6_end_dt_mode mode; + + struct net *net; + /* VRF device associated to the routing table used by the SRv6 + * End.DT4/DT6 behavior for routing IPv4/IPv6 packets. + */ + int vrf_ifindex; + int vrf_table; + + /* tunneled packet family (IPv4 or IPv6). + * Protocol and header length are inferred from family. + */ + u16 family; +}; + +struct pcpu_seg6_local_counters { + u64_stats_t packets; + u64_stats_t bytes; + u64_stats_t errors; + + struct u64_stats_sync syncp; +}; + +/* This struct groups all the SRv6 Behavior counters supported so far. + * + * put_nla_counters() makes use of this data structure to collect all counter + * values after the per-CPU counter evaluation has been performed. + * Finally, each counter value (in seg6_local_counters) is stored in the + * corresponding netlink attribute and sent to user space. + * + * NB: we don't want to expose this structure to user space! + */ +struct seg6_local_counters { + __u64 packets; + __u64 bytes; + __u64 errors; +}; + +#define seg6_local_alloc_pcpu_counters(__gfp) \ + __netdev_alloc_pcpu_stats(struct pcpu_seg6_local_counters, \ + ((__gfp) | __GFP_ZERO)) + +#define SEG6_F_LOCAL_COUNTERS SEG6_F_ATTR(SEG6_LOCAL_COUNTERS) + +struct seg6_local_lwt { + int action; + struct ipv6_sr_hdr *srh; + int table; + struct in_addr nh4; + struct in6_addr nh6; + int iif; + int oif; + struct bpf_lwt_prog bpf; +#ifdef CONFIG_NET_L3_MASTER_DEV + struct seg6_end_dt_info dt_info; +#endif + struct seg6_flavors_info flv_info; + + struct pcpu_seg6_local_counters __percpu *pcpu_counters; + + int headroom; + struct seg6_action_desc *desc; + /* unlike the required attrs, we have to track the optional attributes + * that have been effectively parsed. 
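+	 * For instance, a behavior created with counters enabled records
+	 * SEG6_F_LOCAL_COUNTERS here, which is what
+	 * seg6_lwtunnel_counters_enabled() checks on the input fast path.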
+ */ + unsigned long parsed_optattrs; +}; + +static struct seg6_local_lwt *seg6_local_lwtunnel(struct lwtunnel_state *lwt) +{ + return (struct seg6_local_lwt *)lwt->data; +} + +static struct ipv6_sr_hdr *get_and_validate_srh(struct sk_buff *skb) +{ + struct ipv6_sr_hdr *srh; + + srh = seg6_get_srh(skb, IP6_FH_F_SKIP_RH); + if (!srh) + return NULL; + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (!seg6_hmac_validate_skb(skb)) + return NULL; +#endif + + return srh; +} + +static bool decap_and_validate(struct sk_buff *skb, int proto) +{ + struct ipv6_sr_hdr *srh; + unsigned int off = 0; + + srh = seg6_get_srh(skb, 0); + if (srh && srh->segments_left > 0) + return false; + +#ifdef CONFIG_IPV6_SEG6_HMAC + if (srh && !seg6_hmac_validate_skb(skb)) + return false; +#endif + + if (ipv6_find_hdr(skb, &off, proto, NULL, NULL) < 0) + return false; + + if (!pskb_pull(skb, off)) + return false; + + skb_postpull_rcsum(skb, skb_network_header(skb), off); + + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + if (iptunnel_pull_offloads(skb)) + return false; + + return true; +} + +static void advance_nextseg(struct ipv6_sr_hdr *srh, struct in6_addr *daddr) +{ + struct in6_addr *addr; + + srh->segments_left--; + addr = srh->segments + srh->segments_left; + *daddr = *addr; +} + +static int +seg6_lookup_any_nexthop(struct sk_buff *skb, struct in6_addr *nhaddr, + u32 tbl_id, bool local_delivery) +{ + struct net *net = dev_net(skb->dev); + struct ipv6hdr *hdr = ipv6_hdr(skb); + int flags = RT6_LOOKUP_F_HAS_SADDR; + struct dst_entry *dst = NULL; + struct rt6_info *rt; + struct flowi6 fl6; + int dev_flags = 0; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_iif = skb->dev->ifindex; + fl6.daddr = nhaddr ? *nhaddr : hdr->daddr; + fl6.saddr = hdr->saddr; + fl6.flowlabel = ip6_flowinfo(hdr); + fl6.flowi6_mark = skb->mark; + fl6.flowi6_proto = hdr->nexthdr; + + if (nhaddr) + fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH; + + if (!tbl_id) { + dst = ip6_route_input_lookup(net, skb->dev, &fl6, skb, flags); + } else { + struct fib6_table *table; + + table = fib6_get_table(net, tbl_id); + if (!table) + goto out; + + rt = ip6_pol_route(net, table, 0, &fl6, skb, flags); + dst = &rt->dst; + } + + /* we want to discard traffic destined for local packet processing, + * if @local_delivery is set to false. 
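+	 * This is done by adding IFF_LOOPBACK to the rejected device flags:
+	 * a dst whose device carries one of these flags is released and the
+	 * packet falls back to the blackhole entry installed below.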
+ */ + if (!local_delivery) + dev_flags |= IFF_LOOPBACK; + + if (dst && (dst->dev->flags & dev_flags) && !dst->error) { + dst_release(dst); + dst = NULL; + } + +out: + if (!dst) { + rt = net->ipv6.ip6_blk_hole_entry; + dst = &rt->dst; + dst_hold(dst); + } + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + return dst->error; +} + +int seg6_lookup_nexthop(struct sk_buff *skb, + struct in6_addr *nhaddr, u32 tbl_id) +{ + return seg6_lookup_any_nexthop(skb, nhaddr, tbl_id, false); +} + +static __u8 seg6_flv_lcblock_octects(const struct seg6_flavors_info *finfo) +{ + return finfo->lcblock_bits >> 3; +} + +static __u8 seg6_flv_lcnode_func_octects(const struct seg6_flavors_info *finfo) +{ + return finfo->lcnode_func_bits >> 3; +} + +static bool seg6_next_csid_is_arg_zero(const struct in6_addr *addr, + const struct seg6_flavors_info *finfo) +{ + __u8 fnc_octects = seg6_flv_lcnode_func_octects(finfo); + __u8 blk_octects = seg6_flv_lcblock_octects(finfo); + __u8 arg_octects; + int i; + + arg_octects = 16 - blk_octects - fnc_octects; + for (i = 0; i < arg_octects; ++i) { + if (addr->s6_addr[blk_octects + fnc_octects + i] != 0x00) + return false; + } + + return true; +} + +/* assume that DA.Argument length > 0 */ +static void seg6_next_csid_advance_arg(struct in6_addr *addr, + const struct seg6_flavors_info *finfo) +{ + __u8 fnc_octects = seg6_flv_lcnode_func_octects(finfo); + __u8 blk_octects = seg6_flv_lcblock_octects(finfo); + + /* advance DA.Argument */ + memmove(&addr->s6_addr[blk_octects], + &addr->s6_addr[blk_octects + fnc_octects], + 16 - blk_octects - fnc_octects); + + memset(&addr->s6_addr[16 - fnc_octects], 0x00, fnc_octects); +} + +static int input_action_end_core(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + + srh = get_and_validate_srh(skb); + if (!srh) + goto drop; + + advance_nextseg(srh, &ipv6_hdr(skb)->daddr); + + seg6_lookup_nexthop(skb, NULL, 0); + + return dst_input(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +static int end_next_csid_core(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + const struct seg6_flavors_info *finfo = &slwt->flv_info; + struct in6_addr *daddr = &ipv6_hdr(skb)->daddr; + + if (seg6_next_csid_is_arg_zero(daddr, finfo)) + return input_action_end_core(skb, slwt); + + /* update DA */ + seg6_next_csid_advance_arg(daddr, finfo); + + seg6_lookup_nexthop(skb, NULL, 0); + + return dst_input(skb); +} + +static bool seg6_next_csid_enabled(__u32 fops) +{ + return fops & BIT(SEG6_LOCAL_FLV_OP_NEXT_CSID); +} + +/* regular endpoint function */ +static int input_action_end(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + const struct seg6_flavors_info *finfo = &slwt->flv_info; + + if (seg6_next_csid_enabled(finfo->flv_ops)) + return end_next_csid_core(skb, slwt); + + return input_action_end_core(skb, slwt); +} + +/* regular endpoint, and forward to specified nexthop */ +static int input_action_end_x(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + + srh = get_and_validate_srh(skb); + if (!srh) + goto drop; + + advance_nextseg(srh, &ipv6_hdr(skb)->daddr); + + seg6_lookup_nexthop(skb, &slwt->nh6, 0); + + return dst_input(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +static int input_action_end_t(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + + srh = get_and_validate_srh(skb); + if (!srh) + goto drop; + + advance_nextseg(srh, &ipv6_hdr(skb)->daddr); + + seg6_lookup_nexthop(skb, NULL, slwt->table); + + return dst_input(skb); + +drop: + 
kfree_skb(skb); + return -EINVAL; +} + +/* decapsulate and forward inner L2 frame on specified interface */ +static int input_action_end_dx2(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + struct net *net = dev_net(skb->dev); + struct net_device *odev; + struct ethhdr *eth; + + if (!decap_and_validate(skb, IPPROTO_ETHERNET)) + goto drop; + + if (!pskb_may_pull(skb, ETH_HLEN)) + goto drop; + + skb_reset_mac_header(skb); + eth = (struct ethhdr *)skb->data; + + /* To determine the frame's protocol, we assume it is 802.3. This avoids + * a call to eth_type_trans(), which is not really relevant for our + * use case. + */ + if (!eth_proto_is_802_3(eth->h_proto)) + goto drop; + + odev = dev_get_by_index_rcu(net, slwt->oif); + if (!odev) + goto drop; + + /* As we accept Ethernet frames, make sure the egress device is of + * the correct type. + */ + if (odev->type != ARPHRD_ETHER) + goto drop; + + if (!(odev->flags & IFF_UP) || !netif_carrier_ok(odev)) + goto drop; + + skb_orphan(skb); + + if (skb_warn_if_lro(skb)) + goto drop; + + skb_forward_csum(skb); + + if (skb->len - ETH_HLEN > odev->mtu) + goto drop; + + skb->dev = odev; + skb->protocol = eth->h_proto; + + return dev_queue_xmit(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +static int input_action_end_dx6_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct in6_addr *nhaddr = NULL; + struct seg6_local_lwt *slwt; + + slwt = seg6_local_lwtunnel(orig_dst->lwtstate); + + /* The inner packet is not associated to any local interface, + * so we do not call netif_rx(). + * + * If slwt->nh6 is set to ::, then lookup the nexthop for the + * inner packet's DA. Otherwise, use the specified nexthop. + */ + if (!ipv6_addr_any(&slwt->nh6)) + nhaddr = &slwt->nh6; + + seg6_lookup_nexthop(skb, nhaddr, 0); + + return dst_input(skb); +} + +/* decapsulate and forward to specified nexthop */ +static int input_action_end_dx6(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + /* this function accepts IPv6 encapsulated packets, with either + * an SRH with SL=0, or no SRH. 
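+	 * Both cases are enforced by decap_and_validate(), which rejects any
+	 * SRH whose segments_left is still non-zero and, when HMAC support is
+	 * built in, any SRH whose HMAC TLV fails validation.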
+ */ + + if (!decap_and_validate(skb, IPPROTO_IPV6)) + goto drop; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto drop; + + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + nf_reset_ct(skb); + + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, + dev_net(skb->dev), NULL, skb, NULL, + skb_dst(skb)->dev, input_action_end_dx6_finish); + + return input_action_end_dx6_finish(dev_net(skb->dev), NULL, skb); +drop: + kfree_skb(skb); + return -EINVAL; +} + +static int input_action_end_dx4_finish(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct seg6_local_lwt *slwt; + struct iphdr *iph; + __be32 nhaddr; + int err; + + slwt = seg6_local_lwtunnel(orig_dst->lwtstate); + + iph = ip_hdr(skb); + + nhaddr = slwt->nh4.s_addr ?: iph->daddr; + + skb_dst_drop(skb); + + err = ip_route_input(skb, nhaddr, iph->saddr, 0, skb->dev); + if (err) { + kfree_skb(skb); + return -EINVAL; + } + + return dst_input(skb); +} + +static int input_action_end_dx4(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + if (!decap_and_validate(skb, IPPROTO_IPIP)) + goto drop; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto drop; + + skb->protocol = htons(ETH_P_IP); + skb_set_transport_header(skb, sizeof(struct iphdr)); + nf_reset_ct(skb); + + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, + dev_net(skb->dev), NULL, skb, NULL, + skb_dst(skb)->dev, input_action_end_dx4_finish); + + return input_action_end_dx4_finish(dev_net(skb->dev), NULL, skb); +drop: + kfree_skb(skb); + return -EINVAL; +} + +#ifdef CONFIG_NET_L3_MASTER_DEV +static struct net *fib6_config_get_net(const struct fib6_config *fib6_cfg) +{ + const struct nl_info *nli = &fib6_cfg->fc_nlinfo; + + return nli->nl_net; +} + +static int __seg6_end_dt_vrf_build(struct seg6_local_lwt *slwt, const void *cfg, + u16 family, struct netlink_ext_ack *extack) +{ + struct seg6_end_dt_info *info = &slwt->dt_info; + int vrf_ifindex; + struct net *net; + + net = fib6_config_get_net(cfg); + + /* note that vrf_table was already set by parse_nla_vrftable() */ + vrf_ifindex = l3mdev_ifindex_lookup_by_table_id(L3MDEV_TYPE_VRF, net, + info->vrf_table); + if (vrf_ifindex < 0) { + if (vrf_ifindex == -EPERM) { + NL_SET_ERR_MSG(extack, + "Strict mode for VRF is disabled"); + } else if (vrf_ifindex == -ENODEV) { + NL_SET_ERR_MSG(extack, + "Table has no associated VRF device"); + } else { + pr_debug("seg6local: SRv6 End.DT* creation error=%d\n", + vrf_ifindex); + } + + return vrf_ifindex; + } + + info->net = net; + info->vrf_ifindex = vrf_ifindex; + + info->family = family; + info->mode = DT_VRF_MODE; + + return 0; +} + +/* The SRv6 End.DT4/DT6 behavior extracts the inner (IPv4/IPv6) packet and + * routes the IPv4/IPv6 packet by looking at the configured routing table. + * + * In the SRv6 End.DT4/DT6 use case, we can receive traffic (IPv6+Segment + * Routing Header packets) from several interfaces and the outer IPv6 + * destination address (DA) is used for retrieving the specific instance of the + * End.DT4/DT6 behavior that should process the packets. + * + * However, the inner IPv4/IPv6 packet is not really bound to any receiving + * interface and thus the End.DT4/DT6 sets the VRF (associated with the + * corresponding routing table) as the *receiving* interface. 
+ * In other words, the End.DT4/DT6 processes a packet as if it has been received + * directly by the VRF (and not by one of its slave devices, if any). + * In this way, the VRF interface is used for routing the IPv4/IPv6 packet in + * according to the routing table configured by the End.DT4/DT6 instance. + * + * This design allows you to get some interesting features like: + * 1) the statistics on rx packets; + * 2) the possibility to install a packet sniffer on the receiving interface + * (the VRF one) for looking at the incoming packets; + * 3) the possibility to leverage the netfilter prerouting hook for the inner + * IPv4 packet. + * + * This function returns: + * - the sk_buff* when the VRF rcv handler has processed the packet correctly; + * - NULL when the skb is consumed by the VRF rcv handler; + * - a pointer which encodes a negative error number in case of error. + * Note that in this case, the function takes care of freeing the skb. + */ +static struct sk_buff *end_dt_vrf_rcv(struct sk_buff *skb, u16 family, + struct net_device *dev) +{ + /* based on l3mdev_ip_rcv; we are only interested in the master */ + if (unlikely(!netif_is_l3_master(dev) && !netif_has_l3_rx_handler(dev))) + goto drop; + + if (unlikely(!dev->l3mdev_ops->l3mdev_l3_rcv)) + goto drop; + + /* the decap packet IPv4/IPv6 does not come with any mac header info. + * We must unset the mac header to allow the VRF device to rebuild it, + * just in case there is a sniffer attached on the device. + */ + skb_unset_mac_header(skb); + + skb = dev->l3mdev_ops->l3mdev_l3_rcv(dev, skb, family); + if (!skb) + /* the skb buffer was consumed by the handler */ + return NULL; + + /* when a packet is received by a VRF or by one of its slaves, the + * master device reference is set into the skb. 
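+	 * Hence, once l3mdev_l3_rcv() returns the skb, both skb->dev and
+	 * skb->skb_iif are expected to refer to the VRF master itself; the
+	 * check below drops the packet otherwise.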
+ */ + if (unlikely(skb->dev != dev || skb->skb_iif != dev->ifindex)) + goto drop; + + return skb; + +drop: + kfree_skb(skb); + return ERR_PTR(-EINVAL); +} + +static struct net_device *end_dt_get_vrf_rcu(struct sk_buff *skb, + struct seg6_end_dt_info *info) +{ + int vrf_ifindex = info->vrf_ifindex; + struct net *net = info->net; + + if (unlikely(vrf_ifindex < 0)) + goto error; + + if (unlikely(!net_eq(dev_net(skb->dev), net))) + goto error; + + return dev_get_by_index_rcu(net, vrf_ifindex); + +error: + return NULL; +} + +static struct sk_buff *end_dt_vrf_core(struct sk_buff *skb, + struct seg6_local_lwt *slwt, u16 family) +{ + struct seg6_end_dt_info *info = &slwt->dt_info; + struct net_device *vrf; + __be16 protocol; + int hdrlen; + + vrf = end_dt_get_vrf_rcu(skb, info); + if (unlikely(!vrf)) + goto drop; + + switch (family) { + case AF_INET: + protocol = htons(ETH_P_IP); + hdrlen = sizeof(struct iphdr); + break; + case AF_INET6: + protocol = htons(ETH_P_IPV6); + hdrlen = sizeof(struct ipv6hdr); + break; + case AF_UNSPEC: + fallthrough; + default: + goto drop; + } + + if (unlikely(info->family != AF_UNSPEC && info->family != family)) { + pr_warn_once("seg6local: SRv6 End.DT* family mismatch"); + goto drop; + } + + skb->protocol = protocol; + + skb_dst_drop(skb); + + skb_set_transport_header(skb, hdrlen); + nf_reset_ct(skb); + + return end_dt_vrf_rcv(skb, family, vrf); + +drop: + kfree_skb(skb); + return ERR_PTR(-EINVAL); +} + +static int input_action_end_dt4(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + struct iphdr *iph; + int err; + + if (!decap_and_validate(skb, IPPROTO_IPIP)) + goto drop; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto drop; + + skb = end_dt_vrf_core(skb, slwt, AF_INET); + if (!skb) + /* packet has been processed and consumed by the VRF */ + return 0; + + if (IS_ERR(skb)) + return PTR_ERR(skb); + + iph = ip_hdr(skb); + + err = ip_route_input(skb, iph->daddr, iph->saddr, 0, skb->dev); + if (unlikely(err)) + goto drop; + + return dst_input(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +static int seg6_end_dt4_build(struct seg6_local_lwt *slwt, const void *cfg, + struct netlink_ext_ack *extack) +{ + return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET, extack); +} + +static enum +seg6_end_dt_mode seg6_end_dt6_parse_mode(struct seg6_local_lwt *slwt) +{ + unsigned long parsed_optattrs = slwt->parsed_optattrs; + bool legacy, vrfmode; + + legacy = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE)); + vrfmode = !!(parsed_optattrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE)); + + if (!(legacy ^ vrfmode)) + /* both are absent or present: invalid DT6 mode */ + return DT_INVALID_MODE; + + return legacy ? 
DT_LEGACY_MODE : DT_VRF_MODE; +} + +static enum seg6_end_dt_mode seg6_end_dt6_get_mode(struct seg6_local_lwt *slwt) +{ + struct seg6_end_dt_info *info = &slwt->dt_info; + + return info->mode; +} + +static int seg6_end_dt6_build(struct seg6_local_lwt *slwt, const void *cfg, + struct netlink_ext_ack *extack) +{ + enum seg6_end_dt_mode mode = seg6_end_dt6_parse_mode(slwt); + struct seg6_end_dt_info *info = &slwt->dt_info; + + switch (mode) { + case DT_LEGACY_MODE: + info->mode = DT_LEGACY_MODE; + return 0; + case DT_VRF_MODE: + return __seg6_end_dt_vrf_build(slwt, cfg, AF_INET6, extack); + default: + NL_SET_ERR_MSG(extack, "table or vrftable must be specified"); + return -EINVAL; + } +} +#endif + +static int input_action_end_dt6(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + if (!decap_and_validate(skb, IPPROTO_IPV6)) + goto drop; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto drop; + +#ifdef CONFIG_NET_L3_MASTER_DEV + if (seg6_end_dt6_get_mode(slwt) == DT_LEGACY_MODE) + goto legacy_mode; + + /* DT6_VRF_MODE */ + skb = end_dt_vrf_core(skb, slwt, AF_INET6); + if (!skb) + /* packet has been processed and consumed by the VRF */ + return 0; + + if (IS_ERR(skb)) + return PTR_ERR(skb); + + /* note: this time we do not need to specify the table because the VRF + * takes care of selecting the correct table. + */ + seg6_lookup_any_nexthop(skb, NULL, 0, true); + + return dst_input(skb); + +legacy_mode: +#endif + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + seg6_lookup_any_nexthop(skb, NULL, slwt->table, true); + + return dst_input(skb); + +drop: + kfree_skb(skb); + return -EINVAL; +} + +#ifdef CONFIG_NET_L3_MASTER_DEV +static int seg6_end_dt46_build(struct seg6_local_lwt *slwt, const void *cfg, + struct netlink_ext_ack *extack) +{ + return __seg6_end_dt_vrf_build(slwt, cfg, AF_UNSPEC, extack); +} + +static int input_action_end_dt46(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + unsigned int off = 0; + int nexthdr; + + nexthdr = ipv6_find_hdr(skb, &off, -1, NULL, NULL); + if (unlikely(nexthdr < 0)) + goto drop; + + switch (nexthdr) { + case IPPROTO_IPIP: + return input_action_end_dt4(skb, slwt); + case IPPROTO_IPV6: + return input_action_end_dt6(skb, slwt); + } + +drop: + kfree_skb(skb); + return -EINVAL; +} +#endif + +/* push an SRH on top of the current one */ +static int input_action_end_b6(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + int err = -EINVAL; + + srh = get_and_validate_srh(skb); + if (!srh) + goto drop; + + err = seg6_do_srh_inline(skb, slwt->srh); + if (err) + goto drop; + + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + seg6_lookup_nexthop(skb, NULL, 0); + + return dst_input(skb); + +drop: + kfree_skb(skb); + return err; +} + +/* encapsulate within an outer IPv6 header and a specified SRH */ +static int input_action_end_b6_encap(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + int err = -EINVAL; + + srh = get_and_validate_srh(skb); + if (!srh) + goto drop; + + advance_nextseg(srh, &ipv6_hdr(skb)->daddr); + + skb_reset_inner_headers(skb); + skb->encapsulation = 1; + + err = seg6_do_srh_encap(skb, slwt->srh, IPPROTO_IPV6); + if (err) + goto drop; + + skb_set_transport_header(skb, sizeof(struct ipv6hdr)); + + seg6_lookup_nexthop(skb, NULL, 0); + + return dst_input(skb); + +drop: + kfree_skb(skb); + return err; +} + +DEFINE_PER_CPU(struct seg6_bpf_srh_state, seg6_bpf_srh_states); + +bool seg6_bpf_has_valid_srh(struct sk_buff *skb) +{ + struct seg6_bpf_srh_state 
*srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh = srh_state->srh; + + if (unlikely(srh == NULL)) + return false; + + if (unlikely(!srh_state->valid)) { + if ((srh_state->hdrlen & 7) != 0) + return false; + + srh->hdrlen = (u8)(srh_state->hdrlen >> 3); + if (!seg6_validate_srh(srh, (srh->hdrlen + 1) << 3, true)) + return false; + + srh_state->valid = true; + } + + return true; +} + +static int input_action_end_bpf(struct sk_buff *skb, + struct seg6_local_lwt *slwt) +{ + struct seg6_bpf_srh_state *srh_state = + this_cpu_ptr(&seg6_bpf_srh_states); + struct ipv6_sr_hdr *srh; + int ret; + + srh = get_and_validate_srh(skb); + if (!srh) { + kfree_skb(skb); + return -EINVAL; + } + advance_nextseg(srh, &ipv6_hdr(skb)->daddr); + + /* preempt_disable is needed to protect the per-CPU buffer srh_state, + * which is also accessed by the bpf_lwt_seg6_* helpers + */ + preempt_disable(); + srh_state->srh = srh; + srh_state->hdrlen = srh->hdrlen << 3; + srh_state->valid = true; + + rcu_read_lock(); + bpf_compute_data_pointers(skb); + ret = bpf_prog_run_save_cb(slwt->bpf.prog, skb); + rcu_read_unlock(); + + switch (ret) { + case BPF_OK: + case BPF_REDIRECT: + break; + case BPF_DROP: + goto drop; + default: + pr_warn_once("bpf-seg6local: Illegal return value %u\n", ret); + goto drop; + } + + if (srh_state->srh && !seg6_bpf_has_valid_srh(skb)) + goto drop; + + preempt_enable(); + if (ret != BPF_REDIRECT) + seg6_lookup_nexthop(skb, NULL, 0); + + return dst_input(skb); + +drop: + preempt_enable(); + kfree_skb(skb); + return -EINVAL; +} + +static struct seg6_action_desc seg6_action_table[] = { + { + .action = SEG6_LOCAL_ACTION_END, + .attrs = 0, + .optattrs = SEG6_F_LOCAL_COUNTERS | + SEG6_F_ATTR(SEG6_LOCAL_FLAVORS), + .input = input_action_end, + }, + { + .action = SEG6_LOCAL_ACTION_END_X, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_NH6), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_x, + }, + { + .action = SEG6_LOCAL_ACTION_END_T, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_TABLE), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_t, + }, + { + .action = SEG6_LOCAL_ACTION_END_DX2, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_OIF), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_dx2, + }, + { + .action = SEG6_LOCAL_ACTION_END_DX6, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_NH6), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_dx6, + }, + { + .action = SEG6_LOCAL_ACTION_END_DX4, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_NH4), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_dx4, + }, + { + .action = SEG6_LOCAL_ACTION_END_DT4, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE), + .optattrs = SEG6_F_LOCAL_COUNTERS, +#ifdef CONFIG_NET_L3_MASTER_DEV + .input = input_action_end_dt4, + .slwt_ops = { + .build_state = seg6_end_dt4_build, + }, +#endif + }, + { + .action = SEG6_LOCAL_ACTION_END_DT6, +#ifdef CONFIG_NET_L3_MASTER_DEV + .attrs = 0, + .optattrs = SEG6_F_LOCAL_COUNTERS | + SEG6_F_ATTR(SEG6_LOCAL_TABLE) | + SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE), + .slwt_ops = { + .build_state = seg6_end_dt6_build, + }, +#else + .attrs = SEG6_F_ATTR(SEG6_LOCAL_TABLE), + .optattrs = SEG6_F_LOCAL_COUNTERS, +#endif + .input = input_action_end_dt6, + }, + { + .action = SEG6_LOCAL_ACTION_END_DT46, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE), + .optattrs = SEG6_F_LOCAL_COUNTERS, +#ifdef CONFIG_NET_L3_MASTER_DEV + .input = input_action_end_dt46, + .slwt_ops = { + .build_state = seg6_end_dt46_build, + }, +#endif + }, + { + .action = SEG6_LOCAL_ACTION_END_B6, + .attrs = 
SEG6_F_ATTR(SEG6_LOCAL_SRH), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_b6, + }, + { + .action = SEG6_LOCAL_ACTION_END_B6_ENCAP, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_SRH), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_b6_encap, + .static_headroom = sizeof(struct ipv6hdr), + }, + { + .action = SEG6_LOCAL_ACTION_END_BPF, + .attrs = SEG6_F_ATTR(SEG6_LOCAL_BPF), + .optattrs = SEG6_F_LOCAL_COUNTERS, + .input = input_action_end_bpf, + }, + +}; + +static struct seg6_action_desc *__get_action_desc(int action) +{ + struct seg6_action_desc *desc; + int i, count; + + count = ARRAY_SIZE(seg6_action_table); + for (i = 0; i < count; i++) { + desc = &seg6_action_table[i]; + if (desc->action == action) + return desc; + } + + return NULL; +} + +static bool seg6_lwtunnel_counters_enabled(struct seg6_local_lwt *slwt) +{ + return slwt->parsed_optattrs & SEG6_F_LOCAL_COUNTERS; +} + +static void seg6_local_update_counters(struct seg6_local_lwt *slwt, + unsigned int len, int err) +{ + struct pcpu_seg6_local_counters *pcounters; + + pcounters = this_cpu_ptr(slwt->pcpu_counters); + u64_stats_update_begin(&pcounters->syncp); + + if (likely(!err)) { + u64_stats_inc(&pcounters->packets); + u64_stats_add(&pcounters->bytes, len); + } else { + u64_stats_inc(&pcounters->errors); + } + + u64_stats_update_end(&pcounters->syncp); +} + +static int seg6_local_input_core(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct dst_entry *orig_dst = skb_dst(skb); + struct seg6_action_desc *desc; + struct seg6_local_lwt *slwt; + unsigned int len = skb->len; + int rc; + + slwt = seg6_local_lwtunnel(orig_dst->lwtstate); + desc = slwt->desc; + + rc = desc->input(skb, slwt); + + if (!seg6_lwtunnel_counters_enabled(slwt)) + return rc; + + seg6_local_update_counters(slwt, len, rc); + + return rc; +} + +static int seg6_local_input(struct sk_buff *skb) +{ + if (skb->protocol != htons(ETH_P_IPV6)) { + kfree_skb(skb); + return -EINVAL; + } + + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_IN, + dev_net(skb->dev), NULL, skb, skb->dev, NULL, + seg6_local_input_core); + + return seg6_local_input_core(dev_net(skb->dev), NULL, skb); +} + +static const struct nla_policy seg6_local_policy[SEG6_LOCAL_MAX + 1] = { + [SEG6_LOCAL_ACTION] = { .type = NLA_U32 }, + [SEG6_LOCAL_SRH] = { .type = NLA_BINARY }, + [SEG6_LOCAL_TABLE] = { .type = NLA_U32 }, + [SEG6_LOCAL_VRFTABLE] = { .type = NLA_U32 }, + [SEG6_LOCAL_NH4] = { .type = NLA_BINARY, + .len = sizeof(struct in_addr) }, + [SEG6_LOCAL_NH6] = { .type = NLA_BINARY, + .len = sizeof(struct in6_addr) }, + [SEG6_LOCAL_IIF] = { .type = NLA_U32 }, + [SEG6_LOCAL_OIF] = { .type = NLA_U32 }, + [SEG6_LOCAL_BPF] = { .type = NLA_NESTED }, + [SEG6_LOCAL_COUNTERS] = { .type = NLA_NESTED }, + [SEG6_LOCAL_FLAVORS] = { .type = NLA_NESTED }, +}; + +static int parse_nla_srh(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct ipv6_sr_hdr *srh; + int len; + + srh = nla_data(attrs[SEG6_LOCAL_SRH]); + len = nla_len(attrs[SEG6_LOCAL_SRH]); + + /* SRH must contain at least one segment */ + if (len < sizeof(*srh) + sizeof(struct in6_addr)) + return -EINVAL; + + if (!seg6_validate_srh(srh, len, false)) + return -EINVAL; + + slwt->srh = kmemdup(srh, len, GFP_KERNEL); + if (!slwt->srh) + return -ENOMEM; + + slwt->headroom += len; + + return 0; +} + +static int put_nla_srh(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct ipv6_sr_hdr *srh; + struct nlattr *nla; + int len; 
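+
+	/* The SRH saved at build time is dumped back verbatim; its length is
+	 * recovered from hdrlen, which counts 8-octet units beyond the first
+	 * 8 octets of the header.
+	 */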
+ + srh = slwt->srh; + len = (srh->hdrlen + 1) << 3; + + nla = nla_reserve(skb, SEG6_LOCAL_SRH, len); + if (!nla) + return -EMSGSIZE; + + memcpy(nla_data(nla), srh, len); + + return 0; +} + +static int cmp_nla_srh(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + int len = (a->srh->hdrlen + 1) << 3; + + if (len != ((b->srh->hdrlen + 1) << 3)) + return 1; + + return memcmp(a->srh, b->srh, len); +} + +static void destroy_attr_srh(struct seg6_local_lwt *slwt) +{ + kfree(slwt->srh); +} + +static int parse_nla_table(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + slwt->table = nla_get_u32(attrs[SEG6_LOCAL_TABLE]); + + return 0; +} + +static int put_nla_table(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + if (nla_put_u32(skb, SEG6_LOCAL_TABLE, slwt->table)) + return -EMSGSIZE; + + return 0; +} + +static int cmp_nla_table(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + if (a->table != b->table) + return 1; + + return 0; +} + +static struct +seg6_end_dt_info *seg6_possible_end_dt_info(struct seg6_local_lwt *slwt) +{ +#ifdef CONFIG_NET_L3_MASTER_DEV + return &slwt->dt_info; +#else + return ERR_PTR(-EOPNOTSUPP); +#endif +} + +static int parse_nla_vrftable(struct nlattr **attrs, + struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt); + + if (IS_ERR(info)) + return PTR_ERR(info); + + info->vrf_table = nla_get_u32(attrs[SEG6_LOCAL_VRFTABLE]); + + return 0; +} + +static int put_nla_vrftable(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct seg6_end_dt_info *info = seg6_possible_end_dt_info(slwt); + + if (IS_ERR(info)) + return PTR_ERR(info); + + if (nla_put_u32(skb, SEG6_LOCAL_VRFTABLE, info->vrf_table)) + return -EMSGSIZE; + + return 0; +} + +static int cmp_nla_vrftable(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + struct seg6_end_dt_info *info_a = seg6_possible_end_dt_info(a); + struct seg6_end_dt_info *info_b = seg6_possible_end_dt_info(b); + + if (info_a->vrf_table != info_b->vrf_table) + return 1; + + return 0; +} + +static int parse_nla_nh4(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + memcpy(&slwt->nh4, nla_data(attrs[SEG6_LOCAL_NH4]), + sizeof(struct in_addr)); + + return 0; +} + +static int put_nla_nh4(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct nlattr *nla; + + nla = nla_reserve(skb, SEG6_LOCAL_NH4, sizeof(struct in_addr)); + if (!nla) + return -EMSGSIZE; + + memcpy(nla_data(nla), &slwt->nh4, sizeof(struct in_addr)); + + return 0; +} + +static int cmp_nla_nh4(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + return memcmp(&a->nh4, &b->nh4, sizeof(struct in_addr)); +} + +static int parse_nla_nh6(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + memcpy(&slwt->nh6, nla_data(attrs[SEG6_LOCAL_NH6]), + sizeof(struct in6_addr)); + + return 0; +} + +static int put_nla_nh6(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct nlattr *nla; + + nla = nla_reserve(skb, SEG6_LOCAL_NH6, sizeof(struct in6_addr)); + if (!nla) + return -EMSGSIZE; + + memcpy(nla_data(nla), &slwt->nh6, sizeof(struct in6_addr)); + + return 0; +} + +static int cmp_nla_nh6(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + return memcmp(&a->nh6, &b->nh6, sizeof(struct in6_addr)); +} + +static int parse_nla_iif(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + slwt->iif = 
nla_get_u32(attrs[SEG6_LOCAL_IIF]); + + return 0; +} + +static int put_nla_iif(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + if (nla_put_u32(skb, SEG6_LOCAL_IIF, slwt->iif)) + return -EMSGSIZE; + + return 0; +} + +static int cmp_nla_iif(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + if (a->iif != b->iif) + return 1; + + return 0; +} + +static int parse_nla_oif(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + slwt->oif = nla_get_u32(attrs[SEG6_LOCAL_OIF]); + + return 0; +} + +static int put_nla_oif(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + if (nla_put_u32(skb, SEG6_LOCAL_OIF, slwt->oif)) + return -EMSGSIZE; + + return 0; +} + +static int cmp_nla_oif(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + if (a->oif != b->oif) + return 1; + + return 0; +} + +#define MAX_PROG_NAME 256 +static const struct nla_policy bpf_prog_policy[SEG6_LOCAL_BPF_PROG_MAX + 1] = { + [SEG6_LOCAL_BPF_PROG] = { .type = NLA_U32, }, + [SEG6_LOCAL_BPF_PROG_NAME] = { .type = NLA_NUL_STRING, + .len = MAX_PROG_NAME }, +}; + +static int parse_nla_bpf(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[SEG6_LOCAL_BPF_PROG_MAX + 1]; + struct bpf_prog *p; + int ret; + u32 fd; + + ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_BPF_PROG_MAX, + attrs[SEG6_LOCAL_BPF], + bpf_prog_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[SEG6_LOCAL_BPF_PROG] || !tb[SEG6_LOCAL_BPF_PROG_NAME]) + return -EINVAL; + + slwt->bpf.name = nla_memdup(tb[SEG6_LOCAL_BPF_PROG_NAME], GFP_KERNEL); + if (!slwt->bpf.name) + return -ENOMEM; + + fd = nla_get_u32(tb[SEG6_LOCAL_BPF_PROG]); + p = bpf_prog_get_type(fd, BPF_PROG_TYPE_LWT_SEG6LOCAL); + if (IS_ERR(p)) { + kfree(slwt->bpf.name); + return PTR_ERR(p); + } + + slwt->bpf.prog = p; + return 0; +} + +static int put_nla_bpf(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct nlattr *nest; + + if (!slwt->bpf.prog) + return 0; + + nest = nla_nest_start_noflag(skb, SEG6_LOCAL_BPF); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, SEG6_LOCAL_BPF_PROG, slwt->bpf.prog->aux->id)) + return -EMSGSIZE; + + if (slwt->bpf.name && + nla_put_string(skb, SEG6_LOCAL_BPF_PROG_NAME, slwt->bpf.name)) + return -EMSGSIZE; + + return nla_nest_end(skb, nest); +} + +static int cmp_nla_bpf(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + if (!a->bpf.name && !b->bpf.name) + return 0; + + if (!a->bpf.name || !b->bpf.name) + return 1; + + return strcmp(a->bpf.name, b->bpf.name); +} + +static void destroy_attr_bpf(struct seg6_local_lwt *slwt) +{ + kfree(slwt->bpf.name); + if (slwt->bpf.prog) + bpf_prog_put(slwt->bpf.prog); +} + +static const struct +nla_policy seg6_local_counters_policy[SEG6_LOCAL_CNT_MAX + 1] = { + [SEG6_LOCAL_CNT_PACKETS] = { .type = NLA_U64 }, + [SEG6_LOCAL_CNT_BYTES] = { .type = NLA_U64 }, + [SEG6_LOCAL_CNT_ERRORS] = { .type = NLA_U64 }, +}; + +static int parse_nla_counters(struct nlattr **attrs, + struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct pcpu_seg6_local_counters __percpu *pcounters; + struct nlattr *tb[SEG6_LOCAL_CNT_MAX + 1]; + int ret; + + ret = nla_parse_nested_deprecated(tb, SEG6_LOCAL_CNT_MAX, + attrs[SEG6_LOCAL_COUNTERS], + seg6_local_counters_policy, NULL); + if (ret < 0) + return ret; + + /* basic support for SRv6 Behavior counters requires at least: + * packets, bytes and errors. 
+ */ + if (!tb[SEG6_LOCAL_CNT_PACKETS] || !tb[SEG6_LOCAL_CNT_BYTES] || + !tb[SEG6_LOCAL_CNT_ERRORS]) + return -EINVAL; + + /* counters are always zero initialized */ + pcounters = seg6_local_alloc_pcpu_counters(GFP_KERNEL); + if (!pcounters) + return -ENOMEM; + + slwt->pcpu_counters = pcounters; + + return 0; +} + +static int seg6_local_fill_nla_counters(struct sk_buff *skb, + struct seg6_local_counters *counters) +{ + if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_PACKETS, counters->packets, + SEG6_LOCAL_CNT_PAD)) + return -EMSGSIZE; + + if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_BYTES, counters->bytes, + SEG6_LOCAL_CNT_PAD)) + return -EMSGSIZE; + + if (nla_put_u64_64bit(skb, SEG6_LOCAL_CNT_ERRORS, counters->errors, + SEG6_LOCAL_CNT_PAD)) + return -EMSGSIZE; + + return 0; +} + +static int put_nla_counters(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct seg6_local_counters counters = { 0, 0, 0 }; + struct nlattr *nest; + int rc, i; + + nest = nla_nest_start(skb, SEG6_LOCAL_COUNTERS); + if (!nest) + return -EMSGSIZE; + + for_each_possible_cpu(i) { + struct pcpu_seg6_local_counters *pcounters; + u64 packets, bytes, errors; + unsigned int start; + + pcounters = per_cpu_ptr(slwt->pcpu_counters, i); + do { + start = u64_stats_fetch_begin_irq(&pcounters->syncp); + + packets = u64_stats_read(&pcounters->packets); + bytes = u64_stats_read(&pcounters->bytes); + errors = u64_stats_read(&pcounters->errors); + + } while (u64_stats_fetch_retry_irq(&pcounters->syncp, start)); + + counters.packets += packets; + counters.bytes += bytes; + counters.errors += errors; + } + + rc = seg6_local_fill_nla_counters(skb, &counters); + if (rc < 0) { + nla_nest_cancel(skb, nest); + return rc; + } + + return nla_nest_end(skb, nest); +} + +static int cmp_nla_counters(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + /* a and b are equal if both have pcpu_counters set or not */ + return (!!((unsigned long)a->pcpu_counters)) ^ + (!!((unsigned long)b->pcpu_counters)); +} + +static void destroy_attr_counters(struct seg6_local_lwt *slwt) +{ + free_percpu(slwt->pcpu_counters); +} + +static const +struct nla_policy seg6_local_flavors_policy[SEG6_LOCAL_FLV_MAX + 1] = { + [SEG6_LOCAL_FLV_OPERATION] = { .type = NLA_U32 }, + [SEG6_LOCAL_FLV_LCBLOCK_BITS] = { .type = NLA_U8 }, + [SEG6_LOCAL_FLV_LCNODE_FN_BITS] = { .type = NLA_U8 }, +}; + +/* check whether the lengths of the Locator-Block and Locator-Node Function + * are compatible with the dimension of a C-SID container. + */ +static int seg6_chk_next_csid_cfg(__u8 block_len, __u8 func_len) +{ + /* Locator-Block and Locator-Node Function cannot exceed 128 bits + * (i.e. C-SID container lenghts). + */ + if (next_csid_chk_cntr_bits(block_len, func_len)) + return -EINVAL; + + /* Locator-Block length must be greater than zero and evenly divisible + * by 8. There must be room for a Locator-Node Function, at least. + */ + if (next_csid_chk_lcblock_bits(block_len)) + return -EINVAL; + + /* Locator-Node Function length must be greater than zero and evenly + * divisible by 8. There must be room for the Locator-Block. 
+ */ + if (next_csid_chk_lcnode_fn_bits(func_len)) + return -EINVAL; + + return 0; +} + +static int seg6_parse_nla_next_csid_cfg(struct nlattr **tb, + struct seg6_flavors_info *finfo, + struct netlink_ext_ack *extack) +{ + __u8 func_len = SEG6_LOCAL_LCNODE_FN_DBITS; + __u8 block_len = SEG6_LOCAL_LCBLOCK_DBITS; + int rc; + + if (tb[SEG6_LOCAL_FLV_LCBLOCK_BITS]) + block_len = nla_get_u8(tb[SEG6_LOCAL_FLV_LCBLOCK_BITS]); + + if (tb[SEG6_LOCAL_FLV_LCNODE_FN_BITS]) + func_len = nla_get_u8(tb[SEG6_LOCAL_FLV_LCNODE_FN_BITS]); + + rc = seg6_chk_next_csid_cfg(block_len, func_len); + if (rc < 0) { + NL_SET_ERR_MSG(extack, + "Invalid Locator Block/Node Function lengths"); + return rc; + } + + finfo->lcblock_bits = block_len; + finfo->lcnode_func_bits = func_len; + + return 0; +} + +static int parse_nla_flavors(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct seg6_flavors_info *finfo = &slwt->flv_info; + struct nlattr *tb[SEG6_LOCAL_FLV_MAX + 1]; + unsigned long fops; + int rc; + + rc = nla_parse_nested_deprecated(tb, SEG6_LOCAL_FLV_MAX, + attrs[SEG6_LOCAL_FLAVORS], + seg6_local_flavors_policy, NULL); + if (rc < 0) + return rc; + + /* this attribute MUST always be present since it represents the Flavor + * operation(s) to be carried out. + */ + if (!tb[SEG6_LOCAL_FLV_OPERATION]) + return -EINVAL; + + fops = nla_get_u32(tb[SEG6_LOCAL_FLV_OPERATION]); + if (fops & ~SEG6_LOCAL_FLV_SUPP_OPS) { + NL_SET_ERR_MSG(extack, "Unsupported Flavor operation(s)"); + return -EOPNOTSUPP; + } + + finfo->flv_ops = fops; + + if (seg6_next_csid_enabled(fops)) { + /* Locator-Block and Locator-Node Function lengths can be + * provided by the user space. Otherwise, default values are + * applied. + */ + rc = seg6_parse_nla_next_csid_cfg(tb, finfo, extack); + if (rc < 0) + return rc; + } + + return 0; +} + +static int seg6_fill_nla_next_csid_cfg(struct sk_buff *skb, + struct seg6_flavors_info *finfo) +{ + if (nla_put_u8(skb, SEG6_LOCAL_FLV_LCBLOCK_BITS, finfo->lcblock_bits)) + return -EMSGSIZE; + + if (nla_put_u8(skb, SEG6_LOCAL_FLV_LCNODE_FN_BITS, + finfo->lcnode_func_bits)) + return -EMSGSIZE; + + return 0; +} + +static int put_nla_flavors(struct sk_buff *skb, struct seg6_local_lwt *slwt) +{ + struct seg6_flavors_info *finfo = &slwt->flv_info; + __u32 fops = finfo->flv_ops; + struct nlattr *nest; + int rc; + + nest = nla_nest_start(skb, SEG6_LOCAL_FLAVORS); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, SEG6_LOCAL_FLV_OPERATION, fops)) { + rc = -EMSGSIZE; + goto err; + } + + if (seg6_next_csid_enabled(fops)) { + rc = seg6_fill_nla_next_csid_cfg(skb, finfo); + if (rc < 0) + goto err; + } + + return nla_nest_end(skb, nest); + +err: + nla_nest_cancel(skb, nest); + return rc; +} + +static int seg6_cmp_nla_next_csid_cfg(struct seg6_flavors_info *finfo_a, + struct seg6_flavors_info *finfo_b) +{ + if (finfo_a->lcblock_bits != finfo_b->lcblock_bits) + return 1; + + if (finfo_a->lcnode_func_bits != finfo_b->lcnode_func_bits) + return 1; + + return 0; +} + +static int cmp_nla_flavors(struct seg6_local_lwt *a, struct seg6_local_lwt *b) +{ + struct seg6_flavors_info *finfo_a = &a->flv_info; + struct seg6_flavors_info *finfo_b = &b->flv_info; + + if (finfo_a->flv_ops != finfo_b->flv_ops) + return 1; + + if (seg6_next_csid_enabled(finfo_a->flv_ops)) { + if (seg6_cmp_nla_next_csid_cfg(finfo_a, finfo_b)) + return 1; + } + + return 0; +} + +static int encap_size_flavors(struct seg6_local_lwt *slwt) +{ + struct seg6_flavors_info *finfo = &slwt->flv_info; + int nlsize; + + nlsize = 
nla_total_size(0) + /* nest SEG6_LOCAL_FLAVORS */ + nla_total_size(4); /* SEG6_LOCAL_FLV_OPERATION */ + + if (seg6_next_csid_enabled(finfo->flv_ops)) + nlsize += nla_total_size(1) + /* SEG6_LOCAL_FLV_LCBLOCK_BITS */ + nla_total_size(1); /* SEG6_LOCAL_FLV_LCNODE_FN_BITS */ + + return nlsize; +} + +struct seg6_action_param { + int (*parse)(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack); + int (*put)(struct sk_buff *skb, struct seg6_local_lwt *slwt); + int (*cmp)(struct seg6_local_lwt *a, struct seg6_local_lwt *b); + + /* optional destroy() callback useful for releasing resources which + * have been previously acquired in the corresponding parse() + * function. + */ + void (*destroy)(struct seg6_local_lwt *slwt); +}; + +static struct seg6_action_param seg6_action_params[SEG6_LOCAL_MAX + 1] = { + [SEG6_LOCAL_SRH] = { .parse = parse_nla_srh, + .put = put_nla_srh, + .cmp = cmp_nla_srh, + .destroy = destroy_attr_srh }, + + [SEG6_LOCAL_TABLE] = { .parse = parse_nla_table, + .put = put_nla_table, + .cmp = cmp_nla_table }, + + [SEG6_LOCAL_NH4] = { .parse = parse_nla_nh4, + .put = put_nla_nh4, + .cmp = cmp_nla_nh4 }, + + [SEG6_LOCAL_NH6] = { .parse = parse_nla_nh6, + .put = put_nla_nh6, + .cmp = cmp_nla_nh6 }, + + [SEG6_LOCAL_IIF] = { .parse = parse_nla_iif, + .put = put_nla_iif, + .cmp = cmp_nla_iif }, + + [SEG6_LOCAL_OIF] = { .parse = parse_nla_oif, + .put = put_nla_oif, + .cmp = cmp_nla_oif }, + + [SEG6_LOCAL_BPF] = { .parse = parse_nla_bpf, + .put = put_nla_bpf, + .cmp = cmp_nla_bpf, + .destroy = destroy_attr_bpf }, + + [SEG6_LOCAL_VRFTABLE] = { .parse = parse_nla_vrftable, + .put = put_nla_vrftable, + .cmp = cmp_nla_vrftable }, + + [SEG6_LOCAL_COUNTERS] = { .parse = parse_nla_counters, + .put = put_nla_counters, + .cmp = cmp_nla_counters, + .destroy = destroy_attr_counters }, + + [SEG6_LOCAL_FLAVORS] = { .parse = parse_nla_flavors, + .put = put_nla_flavors, + .cmp = cmp_nla_flavors }, +}; + +/* call the destroy() callback (if available) for each set attribute in + * @parsed_attrs, starting from the first attribute up to the @max_parsed + * (excluded) attribute. + */ +static void __destroy_attrs(unsigned long parsed_attrs, int max_parsed, + struct seg6_local_lwt *slwt) +{ + struct seg6_action_param *param; + int i; + + /* Every required seg6local attribute is identified by an ID which is + * encoded as a flag (i.e: 1 << ID) in the 'attrs' bitmask; + * + * We scan the 'parsed_attrs' bitmask, starting from the first attribute + * up to the @max_parsed (excluded) attribute. + * For each set attribute, we retrieve the corresponding destroy() + * callback. If the callback is not available, then we skip to the next + * attribute; otherwise, we call the destroy() callback. + */ + for (i = SEG6_LOCAL_SRH; i < max_parsed; ++i) { + if (!(parsed_attrs & SEG6_F_ATTR(i))) + continue; + + param = &seg6_action_params[i]; + + if (param->destroy) + param->destroy(slwt); + } +} + +/* release all the resources that may have been acquired during parsing + * operations. 
+ */ +static void destroy_attrs(struct seg6_local_lwt *slwt) +{ + unsigned long attrs = slwt->desc->attrs | slwt->parsed_optattrs; + + __destroy_attrs(attrs, SEG6_LOCAL_MAX + 1, slwt); +} + +static int parse_nla_optional_attrs(struct nlattr **attrs, + struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct seg6_action_desc *desc = slwt->desc; + unsigned long parsed_optattrs = 0; + struct seg6_action_param *param; + int err, i; + + for (i = SEG6_LOCAL_SRH; i < SEG6_LOCAL_MAX + 1; ++i) { + if (!(desc->optattrs & SEG6_F_ATTR(i)) || !attrs[i]) + continue; + + /* once here, the i-th attribute is provided by the + * userspace AND it is identified optional as well. + */ + param = &seg6_action_params[i]; + + err = param->parse(attrs, slwt, extack); + if (err < 0) + goto parse_optattrs_err; + + /* current attribute has been correctly parsed */ + parsed_optattrs |= SEG6_F_ATTR(i); + } + + /* store in the tunnel state all the optional attributed successfully + * parsed. + */ + slwt->parsed_optattrs = parsed_optattrs; + + return 0; + +parse_optattrs_err: + __destroy_attrs(parsed_optattrs, i, slwt); + + return err; +} + +/* call the custom constructor of the behavior during its initialization phase + * and after that all its attributes have been parsed successfully. + */ +static int +seg6_local_lwtunnel_build_state(struct seg6_local_lwt *slwt, const void *cfg, + struct netlink_ext_ack *extack) +{ + struct seg6_action_desc *desc = slwt->desc; + struct seg6_local_lwtunnel_ops *ops; + + ops = &desc->slwt_ops; + if (!ops->build_state) + return 0; + + return ops->build_state(slwt, cfg, extack); +} + +/* call the custom destructor of the behavior which is invoked before the + * tunnel is going to be destroyed. + */ +static void seg6_local_lwtunnel_destroy_state(struct seg6_local_lwt *slwt) +{ + struct seg6_action_desc *desc = slwt->desc; + struct seg6_local_lwtunnel_ops *ops; + + ops = &desc->slwt_ops; + if (!ops->destroy_state) + return; + + ops->destroy_state(slwt); +} + +static int parse_nla_action(struct nlattr **attrs, struct seg6_local_lwt *slwt, + struct netlink_ext_ack *extack) +{ + struct seg6_action_param *param; + struct seg6_action_desc *desc; + unsigned long invalid_attrs; + int i, err; + + desc = __get_action_desc(slwt->action); + if (!desc) + return -EINVAL; + + if (!desc->input) + return -EOPNOTSUPP; + + slwt->desc = desc; + slwt->headroom += desc->static_headroom; + + /* Forcing the desc->optattrs *set* and the desc->attrs *set* to be + * disjoined, this allow us to release acquired resources by optional + * attributes and by required attributes independently from each other + * without any interference. + * In other terms, we are sure that we do not release some the acquired + * resources twice. + * + * Note that if an attribute is configured both as required and as + * optional, it means that the user has messed something up in the + * seg6_action_table. Therefore, this check is required for SRv6 + * behaviors to work properly. 
+ */ + invalid_attrs = desc->attrs & desc->optattrs; + if (invalid_attrs) { + WARN_ONCE(1, + "An attribute cannot be both required AND optional"); + return -EINVAL; + } + + /* parse the required attributes */ + for (i = SEG6_LOCAL_SRH; i < SEG6_LOCAL_MAX + 1; i++) { + if (desc->attrs & SEG6_F_ATTR(i)) { + if (!attrs[i]) + return -EINVAL; + + param = &seg6_action_params[i]; + + err = param->parse(attrs, slwt, extack); + if (err < 0) + goto parse_attrs_err; + } + } + + /* parse the optional attributes, if any */ + err = parse_nla_optional_attrs(attrs, slwt, extack); + if (err < 0) + goto parse_attrs_err; + + return 0; + +parse_attrs_err: + /* release any resource that may have been acquired during the i-1 + * parse() operations. + */ + __destroy_attrs(desc->attrs, i, slwt); + + return err; +} + +static int seg6_local_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[SEG6_LOCAL_MAX + 1]; + struct lwtunnel_state *newts; + struct seg6_local_lwt *slwt; + int err; + + if (family != AF_INET6) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, SEG6_LOCAL_MAX, nla, + seg6_local_policy, extack); + + if (err < 0) + return err; + + if (!tb[SEG6_LOCAL_ACTION]) + return -EINVAL; + + newts = lwtunnel_state_alloc(sizeof(*slwt)); + if (!newts) + return -ENOMEM; + + slwt = seg6_local_lwtunnel(newts); + slwt->action = nla_get_u32(tb[SEG6_LOCAL_ACTION]); + + err = parse_nla_action(tb, slwt, extack); + if (err < 0) + goto out_free; + + err = seg6_local_lwtunnel_build_state(slwt, cfg, extack); + if (err < 0) + goto out_destroy_attrs; + + newts->type = LWTUNNEL_ENCAP_SEG6_LOCAL; + newts->flags = LWTUNNEL_STATE_INPUT_REDIRECT; + newts->headroom = slwt->headroom; + + *ts = newts; + + return 0; + +out_destroy_attrs: + destroy_attrs(slwt); +out_free: + kfree(newts); + return err; +} + +static void seg6_local_destroy_state(struct lwtunnel_state *lwt) +{ + struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt); + + seg6_local_lwtunnel_destroy_state(slwt); + + destroy_attrs(slwt); + + return; +} + +static int seg6_local_fill_encap(struct sk_buff *skb, + struct lwtunnel_state *lwt) +{ + struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt); + struct seg6_action_param *param; + unsigned long attrs; + int i, err; + + if (nla_put_u32(skb, SEG6_LOCAL_ACTION, slwt->action)) + return -EMSGSIZE; + + attrs = slwt->desc->attrs | slwt->parsed_optattrs; + + for (i = SEG6_LOCAL_SRH; i < SEG6_LOCAL_MAX + 1; i++) { + if (attrs & SEG6_F_ATTR(i)) { + param = &seg6_action_params[i]; + err = param->put(skb, slwt); + if (err < 0) + return err; + } + } + + return 0; +} + +static int seg6_local_get_encap_size(struct lwtunnel_state *lwt) +{ + struct seg6_local_lwt *slwt = seg6_local_lwtunnel(lwt); + unsigned long attrs; + int nlsize; + + nlsize = nla_total_size(4); /* action */ + + attrs = slwt->desc->attrs | slwt->parsed_optattrs; + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_SRH)) + nlsize += nla_total_size((slwt->srh->hdrlen + 1) << 3); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_TABLE)) + nlsize += nla_total_size(4); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH4)) + nlsize += nla_total_size(4); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_NH6)) + nlsize += nla_total_size(16); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_IIF)) + nlsize += nla_total_size(4); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_OIF)) + nlsize += nla_total_size(4); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_BPF)) + nlsize += nla_total_size(sizeof(struct nlattr)) + + 
nla_total_size(MAX_PROG_NAME) + + nla_total_size(4); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_VRFTABLE)) + nlsize += nla_total_size(4); + + if (attrs & SEG6_F_LOCAL_COUNTERS) + nlsize += nla_total_size(0) + /* nest SEG6_LOCAL_COUNTERS */ + /* SEG6_LOCAL_CNT_PACKETS */ + nla_total_size_64bit(sizeof(__u64)) + + /* SEG6_LOCAL_CNT_BYTES */ + nla_total_size_64bit(sizeof(__u64)) + + /* SEG6_LOCAL_CNT_ERRORS */ + nla_total_size_64bit(sizeof(__u64)); + + if (attrs & SEG6_F_ATTR(SEG6_LOCAL_FLAVORS)) + nlsize += encap_size_flavors(slwt); + + return nlsize; +} + +static int seg6_local_cmp_encap(struct lwtunnel_state *a, + struct lwtunnel_state *b) +{ + struct seg6_local_lwt *slwt_a, *slwt_b; + struct seg6_action_param *param; + unsigned long attrs_a, attrs_b; + int i; + + slwt_a = seg6_local_lwtunnel(a); + slwt_b = seg6_local_lwtunnel(b); + + if (slwt_a->action != slwt_b->action) + return 1; + + attrs_a = slwt_a->desc->attrs | slwt_a->parsed_optattrs; + attrs_b = slwt_b->desc->attrs | slwt_b->parsed_optattrs; + + if (attrs_a != attrs_b) + return 1; + + for (i = SEG6_LOCAL_SRH; i < SEG6_LOCAL_MAX + 1; i++) { + if (attrs_a & SEG6_F_ATTR(i)) { + param = &seg6_action_params[i]; + if (param->cmp(slwt_a, slwt_b)) + return 1; + } + } + + return 0; +} + +static const struct lwtunnel_encap_ops seg6_local_ops = { + .build_state = seg6_local_build_state, + .destroy_state = seg6_local_destroy_state, + .input = seg6_local_input, + .fill_encap = seg6_local_fill_encap, + .get_encap_size = seg6_local_get_encap_size, + .cmp_encap = seg6_local_cmp_encap, + .owner = THIS_MODULE, +}; + +int __init seg6_local_init(void) +{ + /* If the max total number of defined attributes is reached, then your + * kernel build stops here. + * + * This check is required to avoid arithmetic overflows when processing + * behavior attributes and the maximum number of defined attributes + * exceeds the allowed value. + */ + BUILD_BUG_ON(SEG6_LOCAL_MAX + 1 > BITS_PER_TYPE(unsigned long)); + + /* If the default NEXT-C-SID Locator-Block/Node Function lengths (in + * bits) have been changed with invalid values, kernel build stops + * here. 
+ */ + BUILD_BUG_ON(next_csid_chk_cntr_bits(SEG6_LOCAL_LCBLOCK_DBITS, + SEG6_LOCAL_LCNODE_FN_DBITS)); + BUILD_BUG_ON(next_csid_chk_lcblock_bits(SEG6_LOCAL_LCBLOCK_DBITS)); + BUILD_BUG_ON(next_csid_chk_lcnode_fn_bits(SEG6_LOCAL_LCNODE_FN_DBITS)); + + return lwtunnel_encap_add_ops(&seg6_local_ops, + LWTUNNEL_ENCAP_SEG6_LOCAL); +} + +void seg6_local_exit(void) +{ + lwtunnel_encap_del_ops(&seg6_local_ops, LWTUNNEL_ENCAP_SEG6_LOCAL); +} diff --git a/net/ipv6/sit.c b/net/ipv6/sit.c new file mode 100644 index 000000000..3ffb6a5b1 --- /dev/null +++ b/net/ipv6/sit.c @@ -0,0 +1,1961 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 over IPv4 tunnel device - Simple Internet Transition (SIT) + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * Alexey Kuznetsov + * + * Changes: + * Roger Venning : 6to4 support + * Nate Thompson : 6to4 support + * Fred Templin : isatap support + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + This version of net/ipv6/sit.c is cloned of net/ipv4/ip_gre.c + + For comments look at net/ipv4/ip_gre.c --ANK + */ + +#define IP6_SIT_HASH_SIZE 16 +#define HASH(addr) (((__force u32)addr^((__force u32)addr>>4))&0xF) + +static bool log_ecn_error = true; +module_param(log_ecn_error, bool, 0644); +MODULE_PARM_DESC(log_ecn_error, "Log packets received with corrupted ECN"); + +static int ipip6_tunnel_init(struct net_device *dev); +static void ipip6_tunnel_setup(struct net_device *dev); +static void ipip6_dev_free(struct net_device *dev); +static bool check_6rd(struct ip_tunnel *tunnel, const struct in6_addr *v6dst, + __be32 *v4dst); +static struct rtnl_link_ops sit_link_ops __read_mostly; + +static unsigned int sit_net_id __read_mostly; +struct sit_net { + struct ip_tunnel __rcu *tunnels_r_l[IP6_SIT_HASH_SIZE]; + struct ip_tunnel __rcu *tunnels_r[IP6_SIT_HASH_SIZE]; + struct ip_tunnel __rcu *tunnels_l[IP6_SIT_HASH_SIZE]; + struct ip_tunnel __rcu *tunnels_wc[1]; + struct ip_tunnel __rcu **tunnels[4]; + + struct net_device *fb_tunnel_dev; +}; + +static inline struct sit_net *dev_to_sit_net(struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + + return net_generic(t->net, sit_net_id); +} + +/* + * Must be invoked with rcu_read_lock + */ +static struct ip_tunnel *ipip6_tunnel_lookup(struct net *net, + struct net_device *dev, + __be32 remote, __be32 local, + int sifindex) +{ + unsigned int h0 = HASH(remote); + unsigned int h1 = HASH(local); + struct ip_tunnel *t; + struct sit_net *sitn = net_generic(net, sit_net_id); + int ifindex = dev ? 
dev->ifindex : 0; + + for_each_ip_tunnel_rcu(t, sitn->tunnels_r_l[h0 ^ h1]) { + if (local == t->parms.iph.saddr && + remote == t->parms.iph.daddr && + (!dev || !t->parms.link || ifindex == t->parms.link || + sifindex == t->parms.link) && + (t->dev->flags & IFF_UP)) + return t; + } + for_each_ip_tunnel_rcu(t, sitn->tunnels_r[h0]) { + if (remote == t->parms.iph.daddr && + (!dev || !t->parms.link || ifindex == t->parms.link || + sifindex == t->parms.link) && + (t->dev->flags & IFF_UP)) + return t; + } + for_each_ip_tunnel_rcu(t, sitn->tunnels_l[h1]) { + if (local == t->parms.iph.saddr && + (!dev || !t->parms.link || ifindex == t->parms.link || + sifindex == t->parms.link) && + (t->dev->flags & IFF_UP)) + return t; + } + t = rcu_dereference(sitn->tunnels_wc[0]); + if (t && (t->dev->flags & IFF_UP)) + return t; + return NULL; +} + +static struct ip_tunnel __rcu **__ipip6_bucket(struct sit_net *sitn, + struct ip_tunnel_parm *parms) +{ + __be32 remote = parms->iph.daddr; + __be32 local = parms->iph.saddr; + unsigned int h = 0; + int prio = 0; + + if (remote) { + prio |= 2; + h ^= HASH(remote); + } + if (local) { + prio |= 1; + h ^= HASH(local); + } + return &sitn->tunnels[prio][h]; +} + +static inline struct ip_tunnel __rcu **ipip6_bucket(struct sit_net *sitn, + struct ip_tunnel *t) +{ + return __ipip6_bucket(sitn, &t->parms); +} + +static void ipip6_tunnel_unlink(struct sit_net *sitn, struct ip_tunnel *t) +{ + struct ip_tunnel __rcu **tp; + struct ip_tunnel *iter; + + for (tp = ipip6_bucket(sitn, t); + (iter = rtnl_dereference(*tp)) != NULL; + tp = &iter->next) { + if (t == iter) { + rcu_assign_pointer(*tp, t->next); + break; + } + } +} + +static void ipip6_tunnel_link(struct sit_net *sitn, struct ip_tunnel *t) +{ + struct ip_tunnel __rcu **tp = ipip6_bucket(sitn, t); + + rcu_assign_pointer(t->next, rtnl_dereference(*tp)); + rcu_assign_pointer(*tp, t); +} + +static void ipip6_tunnel_clone_6rd(struct net_device *dev, struct sit_net *sitn) +{ +#ifdef CONFIG_IPV6_SIT_6RD + struct ip_tunnel *t = netdev_priv(dev); + + if (dev == sitn->fb_tunnel_dev || !sitn->fb_tunnel_dev) { + ipv6_addr_set(&t->ip6rd.prefix, htonl(0x20020000), 0, 0, 0); + t->ip6rd.relay_prefix = 0; + t->ip6rd.prefixlen = 16; + t->ip6rd.relay_prefixlen = 0; + } else { + struct ip_tunnel *t0 = netdev_priv(sitn->fb_tunnel_dev); + memcpy(&t->ip6rd, &t0->ip6rd, sizeof(t->ip6rd)); + } +#endif +} + +static int ipip6_tunnel_create(struct net_device *dev) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct net *net = dev_net(dev); + struct sit_net *sitn = net_generic(net, sit_net_id); + int err; + + __dev_addr_set(dev, &t->parms.iph.saddr, 4); + memcpy(dev->broadcast, &t->parms.iph.daddr, 4); + + if ((__force u16)t->parms.i_flags & SIT_ISATAP) + dev->priv_flags |= IFF_ISATAP; + + dev->rtnl_link_ops = &sit_link_ops; + + err = register_netdevice(dev); + if (err < 0) + goto out; + + ipip6_tunnel_clone_6rd(dev, sitn); + + ipip6_tunnel_link(sitn, t); + return 0; + +out: + return err; +} + +static struct ip_tunnel *ipip6_tunnel_locate(struct net *net, + struct ip_tunnel_parm *parms, int create) +{ + __be32 remote = parms->iph.daddr; + __be32 local = parms->iph.saddr; + struct ip_tunnel *t, *nt; + struct ip_tunnel __rcu **tp; + struct net_device *dev; + char name[IFNAMSIZ]; + struct sit_net *sitn = net_generic(net, sit_net_id); + + for (tp = __ipip6_bucket(sitn, parms); + (t = rtnl_dereference(*tp)) != NULL; + tp = &t->next) { + if (local == t->parms.iph.saddr && + remote == t->parms.iph.daddr && + parms->link == t->parms.link) { + if (create) + 
return NULL; + else + return t; + } + } + if (!create) + goto failed; + + if (parms->name[0]) { + if (!dev_valid_name(parms->name)) + goto failed; + strscpy(name, parms->name, IFNAMSIZ); + } else { + strcpy(name, "sit%d"); + } + dev = alloc_netdev(sizeof(*t), name, NET_NAME_UNKNOWN, + ipip6_tunnel_setup); + if (!dev) + return NULL; + + dev_net_set(dev, net); + + nt = netdev_priv(dev); + + nt->parms = *parms; + if (ipip6_tunnel_create(dev) < 0) + goto failed_free; + + if (!parms->name[0]) + strcpy(parms->name, dev->name); + + return nt; + +failed_free: + free_netdev(dev); +failed: + return NULL; +} + +#define for_each_prl_rcu(start) \ + for (prl = rcu_dereference(start); \ + prl; \ + prl = rcu_dereference(prl->next)) + +static struct ip_tunnel_prl_entry * +__ipip6_tunnel_locate_prl(struct ip_tunnel *t, __be32 addr) +{ + struct ip_tunnel_prl_entry *prl; + + for_each_prl_rcu(t->prl) + if (prl->addr == addr) + break; + return prl; + +} + +static int ipip6_tunnel_get_prl(struct net_device *dev, struct ip_tunnel_prl __user *a) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_prl kprl, *kp; + struct ip_tunnel_prl_entry *prl; + unsigned int cmax, c = 0, ca, len; + int ret = 0; + + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) + return -EINVAL; + + if (copy_from_user(&kprl, a, sizeof(kprl))) + return -EFAULT; + cmax = kprl.datalen / sizeof(kprl); + if (cmax > 1 && kprl.addr != htonl(INADDR_ANY)) + cmax = 1; + + /* For simple GET or for root users, + * we try harder to allocate. + */ + kp = (cmax <= 1 || capable(CAP_NET_ADMIN)) ? + kcalloc(cmax, sizeof(*kp), GFP_KERNEL_ACCOUNT | __GFP_NOWARN) : + NULL; + + ca = min(t->prl_count, cmax); + + if (!kp) { + /* We don't try hard to allocate much memory for + * non-root users. + * For root users, retry allocating enough memory for + * the answer. 
+ */ + kp = kcalloc(ca, sizeof(*kp), GFP_ATOMIC | __GFP_ACCOUNT | + __GFP_NOWARN); + if (!kp) { + ret = -ENOMEM; + goto out; + } + } + + rcu_read_lock(); + for_each_prl_rcu(t->prl) { + if (c >= cmax) + break; + if (kprl.addr != htonl(INADDR_ANY) && prl->addr != kprl.addr) + continue; + kp[c].addr = prl->addr; + kp[c].flags = prl->flags; + c++; + if (kprl.addr != htonl(INADDR_ANY)) + break; + } + + rcu_read_unlock(); + + len = sizeof(*kp) * c; + ret = 0; + if ((len && copy_to_user(a + 1, kp, len)) || put_user(len, &a->datalen)) + ret = -EFAULT; + + kfree(kp); +out: + return ret; +} + +static int +ipip6_tunnel_add_prl(struct ip_tunnel *t, struct ip_tunnel_prl *a, int chg) +{ + struct ip_tunnel_prl_entry *p; + int err = 0; + + if (a->addr == htonl(INADDR_ANY)) + return -EINVAL; + + ASSERT_RTNL(); + + for (p = rtnl_dereference(t->prl); p; p = rtnl_dereference(p->next)) { + if (p->addr == a->addr) { + if (chg) { + p->flags = a->flags; + goto out; + } + err = -EEXIST; + goto out; + } + } + + if (chg) { + err = -ENXIO; + goto out; + } + + p = kzalloc(sizeof(struct ip_tunnel_prl_entry), GFP_KERNEL); + if (!p) { + err = -ENOBUFS; + goto out; + } + + p->next = t->prl; + p->addr = a->addr; + p->flags = a->flags; + t->prl_count++; + rcu_assign_pointer(t->prl, p); +out: + return err; +} + +static void prl_list_destroy_rcu(struct rcu_head *head) +{ + struct ip_tunnel_prl_entry *p, *n; + + p = container_of(head, struct ip_tunnel_prl_entry, rcu_head); + do { + n = rcu_dereference_protected(p->next, 1); + kfree(p); + p = n; + } while (p); +} + +static int +ipip6_tunnel_del_prl(struct ip_tunnel *t, struct ip_tunnel_prl *a) +{ + struct ip_tunnel_prl_entry *x; + struct ip_tunnel_prl_entry __rcu **p; + int err = 0; + + ASSERT_RTNL(); + + if (a && a->addr != htonl(INADDR_ANY)) { + for (p = &t->prl; + (x = rtnl_dereference(*p)) != NULL; + p = &x->next) { + if (x->addr == a->addr) { + *p = x->next; + kfree_rcu(x, rcu_head); + t->prl_count--; + goto out; + } + } + err = -ENXIO; + } else { + x = rtnl_dereference(t->prl); + if (x) { + t->prl_count = 0; + call_rcu(&x->rcu_head, prl_list_destroy_rcu); + t->prl = NULL; + } + } +out: + return err; +} + +static int ipip6_tunnel_prl_ctl(struct net_device *dev, + struct ip_tunnel_prl __user *data, int cmd) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_prl prl; + int err; + + if (!ns_capable(t->net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) + return -EINVAL; + + if (copy_from_user(&prl, data, sizeof(prl))) + return -EFAULT; + + switch (cmd) { + case SIOCDELPRL: + err = ipip6_tunnel_del_prl(t, &prl); + break; + case SIOCADDPRL: + case SIOCCHGPRL: + err = ipip6_tunnel_add_prl(t, &prl, cmd == SIOCCHGPRL); + break; + } + dst_cache_reset(&t->dst_cache); + netdev_state_change(dev); + return err; +} + +static int +isatap_chksrc(struct sk_buff *skb, const struct iphdr *iph, struct ip_tunnel *t) +{ + struct ip_tunnel_prl_entry *p; + int ok = 1; + + rcu_read_lock(); + p = __ipip6_tunnel_locate_prl(t, iph->saddr); + if (p) { + if (p->flags & PRL_DEFAULT) + skb->ndisc_nodetype = NDISC_NODETYPE_DEFAULT; + else + skb->ndisc_nodetype = NDISC_NODETYPE_NODEFAULT; + } else { + const struct in6_addr *addr6 = &ipv6_hdr(skb)->saddr; + + if (ipv6_addr_is_isatap(addr6) && + (addr6->s6_addr32[3] == iph->saddr) && + ipv6_chk_prefix(addr6, t->dev)) + skb->ndisc_nodetype = NDISC_NODETYPE_HOST; + else + ok = 0; + } + rcu_read_unlock(); + return ok; +} + +static void ipip6_tunnel_uninit(struct net_device *dev) +{ + struct ip_tunnel 
*tunnel = netdev_priv(dev); + struct sit_net *sitn = net_generic(tunnel->net, sit_net_id); + + if (dev == sitn->fb_tunnel_dev) { + RCU_INIT_POINTER(sitn->tunnels_wc[0], NULL); + } else { + ipip6_tunnel_unlink(sitn, tunnel); + ipip6_tunnel_del_prl(tunnel, NULL); + } + dst_cache_reset(&tunnel->dst_cache); + netdev_put(dev, &tunnel->dev_tracker); +} + +static int ipip6_err(struct sk_buff *skb, u32 info) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + unsigned int data_len = 0; + struct ip_tunnel *t; + int sifindex; + int err; + + switch (type) { + default: + case ICMP_PARAMETERPROB: + return 0; + + case ICMP_DEST_UNREACH: + switch (code) { + case ICMP_SR_FAILED: + /* Impossible event. */ + return 0; + default: + /* All others are translated to HOST_UNREACH. + rfc2003 contains "deep thoughts" about NET_UNREACH, + I believe they are just ether pollution. --ANK + */ + break; + } + break; + case ICMP_TIME_EXCEEDED: + if (code != ICMP_EXC_TTL) + return 0; + data_len = icmp_hdr(skb)->un.reserved[1] * 4; /* RFC 4884 4.1 */ + break; + case ICMP_REDIRECT: + break; + } + + err = -ENOENT; + + sifindex = netif_is_l3_master(skb->dev) ? IPCB(skb)->iif : 0; + t = ipip6_tunnel_lookup(dev_net(skb->dev), skb->dev, + iph->daddr, iph->saddr, sifindex); + if (!t) + goto out; + + if (type == ICMP_DEST_UNREACH && code == ICMP_FRAG_NEEDED) { + ipv4_update_pmtu(skb, dev_net(skb->dev), info, + t->parms.link, iph->protocol); + err = 0; + goto out; + } + if (type == ICMP_REDIRECT) { + ipv4_redirect(skb, dev_net(skb->dev), t->parms.link, + iph->protocol); + err = 0; + goto out; + } + + err = 0; + if (__in6_dev_get(skb->dev) && + !ip6_err_gen_icmpv6_unreach(skb, iph->ihl * 4, type, data_len)) + goto out; + + if (t->parms.iph.daddr == 0) + goto out; + + if (t->parms.iph.ttl == 0 && type == ICMP_TIME_EXCEEDED) + goto out; + + if (time_before(jiffies, t->err_time + IPTUNNEL_ERR_TIMEO)) + t->err_count++; + else + t->err_count = 1; + t->err_time = jiffies; +out: + return err; +} + +static inline bool is_spoofed_6rd(struct ip_tunnel *tunnel, const __be32 v4addr, + const struct in6_addr *v6addr) +{ + __be32 v4embed = 0; + if (check_6rd(tunnel, v6addr, &v4embed) && v4addr != v4embed) + return true; + return false; +} + +/* Checks if an address matches an address on the tunnel interface. + * Used to detect the NAT of proto 41 packets and let them pass spoofing test. + * Long story: + * This function is called after we considered the packet as spoofed + * in is_spoofed_6rd. + * We may have a router that is doing NAT for proto 41 packets + * for an internal station. Destination a.a.a.a/PREFIX:bbbb:bbbb + * will be translated to n.n.n.n/PREFIX:bbbb:bbbb. And is_spoofed_6rd + * function will return true, dropping the packet. + * But, we can still check if is spoofed against the IP + * addresses associated with the interface. 
+ */ +static bool only_dnatted(const struct ip_tunnel *tunnel, + const struct in6_addr *v6dst) +{ + int prefix_len; + +#ifdef CONFIG_IPV6_SIT_6RD + prefix_len = tunnel->ip6rd.prefixlen + 32 + - tunnel->ip6rd.relay_prefixlen; +#else + prefix_len = 48; +#endif + return ipv6_chk_custom_prefix(v6dst, prefix_len, tunnel->dev); +} + +/* Returns true if a packet is spoofed */ +static bool packet_is_spoofed(struct sk_buff *skb, + const struct iphdr *iph, + struct ip_tunnel *tunnel) +{ + const struct ipv6hdr *ipv6h; + + if (tunnel->dev->priv_flags & IFF_ISATAP) { + if (!isatap_chksrc(skb, iph, tunnel)) + return true; + + return false; + } + + if (tunnel->dev->flags & IFF_POINTOPOINT) + return false; + + ipv6h = ipv6_hdr(skb); + + if (unlikely(is_spoofed_6rd(tunnel, iph->saddr, &ipv6h->saddr))) { + net_warn_ratelimited("Src spoofed %pI4/%pI6c -> %pI4/%pI6c\n", + &iph->saddr, &ipv6h->saddr, + &iph->daddr, &ipv6h->daddr); + return true; + } + + if (likely(!is_spoofed_6rd(tunnel, iph->daddr, &ipv6h->daddr))) + return false; + + if (only_dnatted(tunnel, &ipv6h->daddr)) + return false; + + net_warn_ratelimited("Dst spoofed %pI4/%pI6c -> %pI4/%pI6c\n", + &iph->saddr, &ipv6h->saddr, + &iph->daddr, &ipv6h->daddr); + return true; +} + +static int ipip6_rcv(struct sk_buff *skb) +{ + const struct iphdr *iph = ip_hdr(skb); + struct ip_tunnel *tunnel; + int sifindex; + int err; + + sifindex = netif_is_l3_master(skb->dev) ? IPCB(skb)->iif : 0; + tunnel = ipip6_tunnel_lookup(dev_net(skb->dev), skb->dev, + iph->saddr, iph->daddr, sifindex); + if (tunnel) { + if (tunnel->parms.iph.protocol != IPPROTO_IPV6 && + tunnel->parms.iph.protocol != 0) + goto out; + + skb->mac_header = skb->network_header; + skb_reset_network_header(skb); + IPCB(skb)->flags = 0; + skb->dev = tunnel->dev; + + if (packet_is_spoofed(skb, iph, tunnel)) { + DEV_STATS_INC(tunnel->dev, rx_errors); + goto out; + } + + if (iptunnel_pull_header(skb, 0, htons(ETH_P_IPV6), + !net_eq(tunnel->net, dev_net(tunnel->dev)))) + goto out; + + /* skb can be uncloned in iptunnel_pull_header, so + * old iph is no longer valid + */ + iph = (const struct iphdr *)skb_mac_header(skb); + skb_reset_mac_header(skb); + + err = IP_ECN_decapsulate(iph, skb); + if (unlikely(err)) { + if (log_ecn_error) + net_info_ratelimited("non-ECT from %pI4 with TOS=%#x\n", + &iph->saddr, iph->tos); + if (err > 1) { + DEV_STATS_INC(tunnel->dev, rx_frame_errors); + DEV_STATS_INC(tunnel->dev, rx_errors); + goto out; + } + } + + dev_sw_netstats_rx_add(tunnel->dev, skb->len); + + netif_rx(skb); + + return 0; + } + + /* no tunnel matched, let upstream know, ipsec may handle it */ + return 1; +out: + kfree_skb(skb); + return 0; +} + +static const struct tnl_ptk_info ipip_tpi = { + /* no tunnel info required for ipip. */ + .proto = htons(ETH_P_IP), +}; + +#if IS_ENABLED(CONFIG_MPLS) +static const struct tnl_ptk_info mplsip_tpi = { + /* no tunnel info required for mplsip. */ + .proto = htons(ETH_P_MPLS_UC), +}; +#endif + +static int sit_tunnel_rcv(struct sk_buff *skb, u8 ipproto) +{ + const struct iphdr *iph; + struct ip_tunnel *tunnel; + int sifindex; + + sifindex = netif_is_l3_master(skb->dev) ? 
IPCB(skb)->iif : 0; + + iph = ip_hdr(skb); + tunnel = ipip6_tunnel_lookup(dev_net(skb->dev), skb->dev, + iph->saddr, iph->daddr, sifindex); + if (tunnel) { + const struct tnl_ptk_info *tpi; + + if (tunnel->parms.iph.protocol != ipproto && + tunnel->parms.iph.protocol != 0) + goto drop; + + if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto drop; +#if IS_ENABLED(CONFIG_MPLS) + if (ipproto == IPPROTO_MPLS) + tpi = &mplsip_tpi; + else +#endif + tpi = &ipip_tpi; + if (iptunnel_pull_header(skb, 0, tpi->proto, false)) + goto drop; + skb_reset_mac_header(skb); + + return ip_tunnel_rcv(tunnel, skb, tpi, NULL, log_ecn_error); + } + + return 1; + +drop: + kfree_skb(skb); + return 0; +} + +static int ipip_rcv(struct sk_buff *skb) +{ + return sit_tunnel_rcv(skb, IPPROTO_IPIP); +} + +#if IS_ENABLED(CONFIG_MPLS) +static int mplsip_rcv(struct sk_buff *skb) +{ + return sit_tunnel_rcv(skb, IPPROTO_MPLS); +} +#endif + +/* + * If the IPv6 address comes from 6rd / 6to4 (RFC 3056) addr space this function + * stores the embedded IPv4 address in v4dst and returns true. + */ +static bool check_6rd(struct ip_tunnel *tunnel, const struct in6_addr *v6dst, + __be32 *v4dst) +{ +#ifdef CONFIG_IPV6_SIT_6RD + if (ipv6_prefix_equal(v6dst, &tunnel->ip6rd.prefix, + tunnel->ip6rd.prefixlen)) { + unsigned int pbw0, pbi0; + int pbi1; + u32 d; + + pbw0 = tunnel->ip6rd.prefixlen >> 5; + pbi0 = tunnel->ip6rd.prefixlen & 0x1f; + + d = tunnel->ip6rd.relay_prefixlen < 32 ? + (ntohl(v6dst->s6_addr32[pbw0]) << pbi0) >> + tunnel->ip6rd.relay_prefixlen : 0; + + pbi1 = pbi0 - tunnel->ip6rd.relay_prefixlen; + if (pbi1 > 0) + d |= ntohl(v6dst->s6_addr32[pbw0 + 1]) >> + (32 - pbi1); + + *v4dst = tunnel->ip6rd.relay_prefix | htonl(d); + return true; + } +#else + if (v6dst->s6_addr16[0] == htons(0x2002)) { + /* 6to4 v6 addr has 16 bits prefix, 32 v4addr, 16 SLA, ... */ + memcpy(v4dst, &v6dst->s6_addr16[1], 4); + return true; + } +#endif + return false; +} + +static inline __be32 try_6rd(struct ip_tunnel *tunnel, + const struct in6_addr *v6dst) +{ + __be32 dst = 0; + check_6rd(tunnel, v6dst, &dst); + return dst; +} + +/* + * This function assumes it is being called from dev_queue_xmit() + * and that skb is filled properly by that function. 
+ */ + +static netdev_tx_t ipip6_tunnel_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + const struct iphdr *tiph = &tunnel->parms.iph; + const struct ipv6hdr *iph6 = ipv6_hdr(skb); + u8 tos = tunnel->parms.iph.tos; + __be16 df = tiph->frag_off; + struct rtable *rt; /* Route to the other host */ + struct net_device *tdev; /* Device to other host */ + unsigned int max_headroom; /* The extra header space needed */ + __be32 dst = tiph->daddr; + struct flowi4 fl4; + int mtu; + const struct in6_addr *addr6; + int addr_type; + u8 ttl; + u8 protocol = IPPROTO_IPV6; + int t_hlen = tunnel->hlen + sizeof(struct iphdr); + + if (tos == 1) + tos = ipv6_get_dsfield(iph6); + + /* ISATAP (RFC4214) - must come before 6to4 */ + if (dev->priv_flags & IFF_ISATAP) { + struct neighbour *neigh = NULL; + bool do_tx_error = false; + + if (skb_dst(skb)) + neigh = dst_neigh_lookup(skb_dst(skb), &iph6->daddr); + + if (!neigh) { + net_dbg_ratelimited("nexthop == NULL\n"); + goto tx_error; + } + + addr6 = (const struct in6_addr *)&neigh->primary_key; + addr_type = ipv6_addr_type(addr6); + + if ((addr_type & IPV6_ADDR_UNICAST) && + ipv6_addr_is_isatap(addr6)) + dst = addr6->s6_addr32[3]; + else + do_tx_error = true; + + neigh_release(neigh); + if (do_tx_error) + goto tx_error; + } + + if (!dst) + dst = try_6rd(tunnel, &iph6->daddr); + + if (!dst) { + struct neighbour *neigh = NULL; + bool do_tx_error = false; + + if (skb_dst(skb)) + neigh = dst_neigh_lookup(skb_dst(skb), &iph6->daddr); + + if (!neigh) { + net_dbg_ratelimited("nexthop == NULL\n"); + goto tx_error; + } + + addr6 = (const struct in6_addr *)&neigh->primary_key; + addr_type = ipv6_addr_type(addr6); + + if (addr_type == IPV6_ADDR_ANY) { + addr6 = &ipv6_hdr(skb)->daddr; + addr_type = ipv6_addr_type(addr6); + } + + if ((addr_type & IPV6_ADDR_COMPATv4) != 0) + dst = addr6->s6_addr32[3]; + else + do_tx_error = true; + + neigh_release(neigh); + if (do_tx_error) + goto tx_error; + } + + flowi4_init_output(&fl4, tunnel->parms.link, tunnel->fwmark, + RT_TOS(tos), RT_SCOPE_UNIVERSE, IPPROTO_IPV6, + 0, dst, tiph->saddr, 0, 0, + sock_net_uid(tunnel->net, NULL)); + + rt = dst_cache_get_ip4(&tunnel->dst_cache, &fl4.saddr); + if (!rt) { + rt = ip_route_output_flow(tunnel->net, &fl4, NULL); + if (IS_ERR(rt)) { + DEV_STATS_INC(dev, tx_carrier_errors); + goto tx_error_icmp; + } + dst_cache_set_ip4(&tunnel->dst_cache, &rt->dst, fl4.saddr); + } + + if (rt->rt_type != RTN_UNICAST && rt->rt_type != RTN_LOCAL) { + ip_rt_put(rt); + DEV_STATS_INC(dev, tx_carrier_errors); + goto tx_error_icmp; + } + tdev = rt->dst.dev; + + if (tdev == dev) { + ip_rt_put(rt); + DEV_STATS_INC(dev, collisions); + goto tx_error; + } + + if (iptunnel_handle_offloads(skb, SKB_GSO_IPXIP4)) { + ip_rt_put(rt); + goto tx_error; + } + + if (df) { + mtu = dst_mtu(&rt->dst) - t_hlen; + + if (mtu < IPV4_MIN_MTU) { + DEV_STATS_INC(dev, collisions); + ip_rt_put(rt); + goto tx_error; + } + + if (mtu < IPV6_MIN_MTU) { + mtu = IPV6_MIN_MTU; + df = 0; + } + + if (tunnel->parms.iph.daddr) + skb_dst_update_pmtu_no_confirm(skb, mtu); + + if (skb->len > mtu && !skb_is_gso(skb)) { + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + ip_rt_put(rt); + goto tx_error; + } + } + + if (tunnel->err_count > 0) { + if (time_before(jiffies, + tunnel->err_time + IPTUNNEL_ERR_TIMEO)) { + tunnel->err_count--; + dst_link_failure(skb); + } else + tunnel->err_count = 0; + } + + /* + * Okay, now see if we can stuff it in the buffer as-is. 
+ */ + max_headroom = LL_RESERVED_SPACE(tdev) + t_hlen; + + if (skb_headroom(skb) < max_headroom || skb_shared(skb) || + (skb_cloned(skb) && !skb_clone_writable(skb, 0))) { + struct sk_buff *new_skb = skb_realloc_headroom(skb, max_headroom); + if (!new_skb) { + ip_rt_put(rt); + DEV_STATS_INC(dev, tx_dropped); + kfree_skb(skb); + return NETDEV_TX_OK; + } + if (skb->sk) + skb_set_owner_w(new_skb, skb->sk); + dev_kfree_skb(skb); + skb = new_skb; + iph6 = ipv6_hdr(skb); + } + ttl = tiph->ttl; + if (ttl == 0) + ttl = iph6->hop_limit; + tos = INET_ECN_encapsulate(tos, ipv6_get_dsfield(iph6)); + + if (ip_tunnel_encap(skb, tunnel, &protocol, &fl4) < 0) { + ip_rt_put(rt); + goto tx_error; + } + + skb_set_inner_ipproto(skb, IPPROTO_IPV6); + + iptunnel_xmit(NULL, rt, skb, fl4.saddr, fl4.daddr, protocol, tos, ttl, + df, !net_eq(tunnel->net, dev_net(dev))); + return NETDEV_TX_OK; + +tx_error_icmp: + dst_link_failure(skb); +tx_error: + kfree_skb(skb); + DEV_STATS_INC(dev, tx_errors); + return NETDEV_TX_OK; +} + +static netdev_tx_t sit_tunnel_xmit__(struct sk_buff *skb, + struct net_device *dev, u8 ipproto) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + const struct iphdr *tiph = &tunnel->parms.iph; + + if (iptunnel_handle_offloads(skb, SKB_GSO_IPXIP4)) + goto tx_error; + + skb_set_inner_ipproto(skb, ipproto); + + ip_tunnel_xmit(skb, dev, tiph, ipproto); + return NETDEV_TX_OK; +tx_error: + kfree_skb(skb); + DEV_STATS_INC(dev, tx_errors); + return NETDEV_TX_OK; +} + +static netdev_tx_t sit_tunnel_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + if (!pskb_inet_may_pull(skb)) + goto tx_err; + + switch (skb->protocol) { + case htons(ETH_P_IP): + sit_tunnel_xmit__(skb, dev, IPPROTO_IPIP); + break; + case htons(ETH_P_IPV6): + ipip6_tunnel_xmit(skb, dev); + break; +#if IS_ENABLED(CONFIG_MPLS) + case htons(ETH_P_MPLS_UC): + sit_tunnel_xmit__(skb, dev, IPPROTO_MPLS); + break; +#endif + default: + goto tx_err; + } + + return NETDEV_TX_OK; + +tx_err: + DEV_STATS_INC(dev, tx_errors); + kfree_skb(skb); + return NETDEV_TX_OK; + +} + +static void ipip6_tunnel_bind_dev(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + int t_hlen = tunnel->hlen + sizeof(struct iphdr); + struct net_device *tdev = NULL; + int hlen = LL_MAX_HEADER; + const struct iphdr *iph; + struct flowi4 fl4; + + iph = &tunnel->parms.iph; + + if (iph->daddr) { + struct rtable *rt = ip_route_output_ports(tunnel->net, &fl4, + NULL, + iph->daddr, iph->saddr, + 0, 0, + IPPROTO_IPV6, + RT_TOS(iph->tos), + tunnel->parms.link); + + if (!IS_ERR(rt)) { + tdev = rt->dst.dev; + ip_rt_put(rt); + } + dev->flags |= IFF_POINTOPOINT; + } + + if (!tdev && tunnel->parms.link) + tdev = __dev_get_by_index(tunnel->net, tunnel->parms.link); + + if (tdev && !netif_is_l3_master(tdev)) { + int mtu; + + mtu = tdev->mtu - t_hlen; + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + WRITE_ONCE(dev->mtu, mtu); + hlen = tdev->hard_header_len + tdev->needed_headroom; + } + dev->needed_headroom = t_hlen + hlen; +} + +static void ipip6_tunnel_update(struct ip_tunnel *t, struct ip_tunnel_parm *p, + __u32 fwmark) +{ + struct net *net = t->net; + struct sit_net *sitn = net_generic(net, sit_net_id); + + ipip6_tunnel_unlink(sitn, t); + synchronize_net(); + t->parms.iph.saddr = p->iph.saddr; + t->parms.iph.daddr = p->iph.daddr; + __dev_addr_set(t->dev, &p->iph.saddr, 4); + memcpy(t->dev->broadcast, &p->iph.daddr, 4); + ipip6_tunnel_link(sitn, t); + t->parms.iph.ttl = p->iph.ttl; + t->parms.iph.tos = p->iph.tos; + t->parms.iph.frag_off = p->iph.frag_off; + 
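+	/* Re-binding the lower device is only done when the underlay
+	 * link or the tunnel fwmark changes; ipip6_tunnel_bind_dev()
+	 * then re-resolves the route to the remote endpoint and
+	 * recomputes dev->mtu and dev->needed_headroom from the new
+	 * lower device.
+	 */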
if (t->parms.link != p->link || t->fwmark != fwmark) { + t->parms.link = p->link; + t->fwmark = fwmark; + ipip6_tunnel_bind_dev(t->dev); + } + dst_cache_reset(&t->dst_cache); + netdev_state_change(t->dev); +} + +#ifdef CONFIG_IPV6_SIT_6RD +static int ipip6_tunnel_update_6rd(struct ip_tunnel *t, + struct ip_tunnel_6rd *ip6rd) +{ + struct in6_addr prefix; + __be32 relay_prefix; + + if (ip6rd->relay_prefixlen > 32 || + ip6rd->prefixlen + (32 - ip6rd->relay_prefixlen) > 64) + return -EINVAL; + + ipv6_addr_prefix(&prefix, &ip6rd->prefix, ip6rd->prefixlen); + if (!ipv6_addr_equal(&prefix, &ip6rd->prefix)) + return -EINVAL; + if (ip6rd->relay_prefixlen) + relay_prefix = ip6rd->relay_prefix & + htonl(0xffffffffUL << + (32 - ip6rd->relay_prefixlen)); + else + relay_prefix = 0; + if (relay_prefix != ip6rd->relay_prefix) + return -EINVAL; + + t->ip6rd.prefix = prefix; + t->ip6rd.relay_prefix = relay_prefix; + t->ip6rd.prefixlen = ip6rd->prefixlen; + t->ip6rd.relay_prefixlen = ip6rd->relay_prefixlen; + dst_cache_reset(&t->dst_cache); + netdev_state_change(t->dev); + return 0; +} + +static int +ipip6_tunnel_get6rd(struct net_device *dev, struct ip_tunnel_parm __user *data) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_6rd ip6rd; + struct ip_tunnel_parm p; + + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) { + if (copy_from_user(&p, data, sizeof(p))) + return -EFAULT; + t = ipip6_tunnel_locate(t->net, &p, 0); + } + if (!t) + t = netdev_priv(dev); + + ip6rd.prefix = t->ip6rd.prefix; + ip6rd.relay_prefix = t->ip6rd.relay_prefix; + ip6rd.prefixlen = t->ip6rd.prefixlen; + ip6rd.relay_prefixlen = t->ip6rd.relay_prefixlen; + if (copy_to_user(data, &ip6rd, sizeof(ip6rd))) + return -EFAULT; + return 0; +} + +static int +ipip6_tunnel_6rdctl(struct net_device *dev, struct ip_tunnel_6rd __user *data, + int cmd) +{ + struct ip_tunnel *t = netdev_priv(dev); + struct ip_tunnel_6rd ip6rd; + int err; + + if (!ns_capable(t->net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (copy_from_user(&ip6rd, data, sizeof(ip6rd))) + return -EFAULT; + + if (cmd != SIOCDEL6RD) { + err = ipip6_tunnel_update_6rd(t, &ip6rd); + if (err < 0) + return err; + } else + ipip6_tunnel_clone_6rd(dev, dev_to_sit_net(dev)); + return 0; +} + +#endif /* CONFIG_IPV6_SIT_6RD */ + +static bool ipip6_valid_ip_proto(u8 ipproto) +{ + return ipproto == IPPROTO_IPV6 || + ipproto == IPPROTO_IPIP || +#if IS_ENABLED(CONFIG_MPLS) + ipproto == IPPROTO_MPLS || +#endif + ipproto == 0; +} + +static int +__ipip6_tunnel_ioctl_validate(struct net *net, struct ip_tunnel_parm *p) +{ + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!ipip6_valid_ip_proto(p->iph.protocol)) + return -EINVAL; + if (p->iph.version != 4 || + p->iph.ihl != 5 || (p->iph.frag_off & htons(~IP_DF))) + return -EINVAL; + + if (p->iph.ttl) + p->iph.frag_off |= htons(IP_DF); + return 0; +} + +static int +ipip6_tunnel_get(struct net_device *dev, struct ip_tunnel_parm *p) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) + t = ipip6_tunnel_locate(t->net, p, 0); + if (!t) + t = netdev_priv(dev); + memcpy(p, &t->parms, sizeof(*p)); + return 0; +} + +static int +ipip6_tunnel_add(struct net_device *dev, struct ip_tunnel_parm *p) +{ + struct ip_tunnel *t = netdev_priv(dev); + int err; + + err = __ipip6_tunnel_ioctl_validate(t->net, p); + if (err) + return err; + + t = ipip6_tunnel_locate(t->net, p, 1); + if (!t) + return -ENOBUFS; + return 0; +} + +static int +ipip6_tunnel_change(struct net_device *dev, struct 
ip_tunnel_parm *p) +{ + struct ip_tunnel *t = netdev_priv(dev); + int err; + + err = __ipip6_tunnel_ioctl_validate(t->net, p); + if (err) + return err; + + t = ipip6_tunnel_locate(t->net, p, 0); + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) { + if (!t) + return -ENOENT; + } else { + if (t) { + if (t->dev != dev) + return -EEXIST; + } else { + if (((dev->flags & IFF_POINTOPOINT) && !p->iph.daddr) || + (!(dev->flags & IFF_POINTOPOINT) && p->iph.daddr)) + return -EINVAL; + t = netdev_priv(dev); + } + + ipip6_tunnel_update(t, p, t->fwmark); + } + + return 0; +} + +static int +ipip6_tunnel_del(struct net_device *dev, struct ip_tunnel_parm *p) +{ + struct ip_tunnel *t = netdev_priv(dev); + + if (!ns_capable(t->net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (dev == dev_to_sit_net(dev)->fb_tunnel_dev) { + t = ipip6_tunnel_locate(t->net, p, 0); + if (!t) + return -ENOENT; + if (t == netdev_priv(dev_to_sit_net(dev)->fb_tunnel_dev)) + return -EPERM; + dev = t->dev; + } + unregister_netdevice(dev); + return 0; +} + +static int +ipip6_tunnel_ctl(struct net_device *dev, struct ip_tunnel_parm *p, int cmd) +{ + switch (cmd) { + case SIOCGETTUNNEL: + return ipip6_tunnel_get(dev, p); + case SIOCADDTUNNEL: + return ipip6_tunnel_add(dev, p); + case SIOCCHGTUNNEL: + return ipip6_tunnel_change(dev, p); + case SIOCDELTUNNEL: + return ipip6_tunnel_del(dev, p); + default: + return -EINVAL; + } +} + +static int +ipip6_tunnel_siocdevprivate(struct net_device *dev, struct ifreq *ifr, + void __user *data, int cmd) +{ + switch (cmd) { + case SIOCGETTUNNEL: + case SIOCADDTUNNEL: + case SIOCCHGTUNNEL: + case SIOCDELTUNNEL: + return ip_tunnel_siocdevprivate(dev, ifr, data, cmd); + case SIOCGETPRL: + return ipip6_tunnel_get_prl(dev, data); + case SIOCADDPRL: + case SIOCDELPRL: + case SIOCCHGPRL: + return ipip6_tunnel_prl_ctl(dev, data, cmd); +#ifdef CONFIG_IPV6_SIT_6RD + case SIOCGET6RD: + return ipip6_tunnel_get6rd(dev, data); + case SIOCADD6RD: + case SIOCCHG6RD: + case SIOCDEL6RD: + return ipip6_tunnel_6rdctl(dev, data, cmd); +#endif + default: + return -EINVAL; + } +} + +static const struct net_device_ops ipip6_netdev_ops = { + .ndo_init = ipip6_tunnel_init, + .ndo_uninit = ipip6_tunnel_uninit, + .ndo_start_xmit = sit_tunnel_xmit, + .ndo_siocdevprivate = ipip6_tunnel_siocdevprivate, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = ip_tunnel_get_iflink, + .ndo_tunnel_ctl = ipip6_tunnel_ctl, +}; + +static void ipip6_dev_free(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + + dst_cache_destroy(&tunnel->dst_cache); + free_percpu(dev->tstats); +} + +#define SIT_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_HIGHDMA | \ + NETIF_F_GSO_SOFTWARE | \ + NETIF_F_HW_CSUM) + +static void ipip6_tunnel_setup(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + int t_hlen = tunnel->hlen + sizeof(struct iphdr); + + dev->netdev_ops = &ipip6_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ipip6_dev_free; + + dev->type = ARPHRD_SIT; + dev->mtu = ETH_DATA_LEN - t_hlen; + dev->min_mtu = IPV6_MIN_MTU; + dev->max_mtu = IP6_MAX_MTU - t_hlen; + dev->flags = IFF_NOARP; + netif_keep_dst(dev); + dev->addr_len = 4; + dev->features |= NETIF_F_LLTX; + dev->features |= SIT_FEATURES; + dev->hw_features |= SIT_FEATURES; +} + +static int ipip6_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + int err; + + tunnel->dev = dev; + tunnel->net = dev_net(dev); + 
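+	/* Resources acquired below are released in two places: a failed
+	 * dst_cache_init() frees the just-allocated dev->tstats before
+	 * returning, while ipip6_dev_free() releases both at destruction
+	 * time. The netdev_hold() reference taken at the end of this
+	 * function pairs with the netdev_put() in ipip6_tunnel_uninit().
+	 */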
strcpy(tunnel->parms.name, dev->name); + + ipip6_tunnel_bind_dev(dev); + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + err = dst_cache_init(&tunnel->dst_cache, GFP_KERNEL); + if (err) { + free_percpu(dev->tstats); + dev->tstats = NULL; + return err; + } + netdev_hold(dev, &tunnel->dev_tracker, GFP_KERNEL); + return 0; +} + +static void __net_init ipip6_fb_tunnel_init(struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct iphdr *iph = &tunnel->parms.iph; + struct net *net = dev_net(dev); + struct sit_net *sitn = net_generic(net, sit_net_id); + + iph->version = 4; + iph->protocol = IPPROTO_IPV6; + iph->ihl = 5; + iph->ttl = 64; + + rcu_assign_pointer(sitn->tunnels_wc[0], tunnel); +} + +static int ipip6_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + u8 proto; + + if (!data || !data[IFLA_IPTUN_PROTO]) + return 0; + + proto = nla_get_u8(data[IFLA_IPTUN_PROTO]); + if (!ipip6_valid_ip_proto(proto)) + return -EINVAL; + + return 0; +} + +static void ipip6_netlink_parms(struct nlattr *data[], + struct ip_tunnel_parm *parms, + __u32 *fwmark) +{ + memset(parms, 0, sizeof(*parms)); + + parms->iph.version = 4; + parms->iph.protocol = IPPROTO_IPV6; + parms->iph.ihl = 5; + parms->iph.ttl = 64; + + if (!data) + return; + + ip_tunnel_netlink_parms(data, parms); + + if (data[IFLA_IPTUN_FWMARK]) + *fwmark = nla_get_u32(data[IFLA_IPTUN_FWMARK]); +} + +#ifdef CONFIG_IPV6_SIT_6RD +/* This function returns true when 6RD attributes are present in the nl msg */ +static bool ipip6_netlink_6rd_parms(struct nlattr *data[], + struct ip_tunnel_6rd *ip6rd) +{ + bool ret = false; + memset(ip6rd, 0, sizeof(*ip6rd)); + + if (!data) + return ret; + + if (data[IFLA_IPTUN_6RD_PREFIX]) { + ret = true; + ip6rd->prefix = nla_get_in6_addr(data[IFLA_IPTUN_6RD_PREFIX]); + } + + if (data[IFLA_IPTUN_6RD_RELAY_PREFIX]) { + ret = true; + ip6rd->relay_prefix = + nla_get_be32(data[IFLA_IPTUN_6RD_RELAY_PREFIX]); + } + + if (data[IFLA_IPTUN_6RD_PREFIXLEN]) { + ret = true; + ip6rd->prefixlen = nla_get_u16(data[IFLA_IPTUN_6RD_PREFIXLEN]); + } + + if (data[IFLA_IPTUN_6RD_RELAY_PREFIXLEN]) { + ret = true; + ip6rd->relay_prefixlen = + nla_get_u16(data[IFLA_IPTUN_6RD_RELAY_PREFIXLEN]); + } + + return ret; +} +#endif + +static int ipip6_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net *net = dev_net(dev); + struct ip_tunnel *nt; + struct ip_tunnel_encap ipencap; +#ifdef CONFIG_IPV6_SIT_6RD + struct ip_tunnel_6rd ip6rd; +#endif + int err; + + nt = netdev_priv(dev); + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + err = ip_tunnel_encap_setup(nt, &ipencap); + if (err < 0) + return err; + } + + ipip6_netlink_parms(data, &nt->parms, &nt->fwmark); + + if (ipip6_tunnel_locate(net, &nt->parms, 0)) + return -EEXIST; + + err = ipip6_tunnel_create(dev); + if (err < 0) + return err; + + if (tb[IFLA_MTU]) { + u32 mtu = nla_get_u32(tb[IFLA_MTU]); + + if (mtu >= IPV6_MIN_MTU && + mtu <= IP6_MAX_MTU - dev->hard_header_len) + dev->mtu = mtu; + } + +#ifdef CONFIG_IPV6_SIT_6RD + if (ipip6_netlink_6rd_parms(data, &ip6rd)) { + err = ipip6_tunnel_update_6rd(nt, &ip6rd); + if (err < 0) + unregister_netdevice_queue(dev, NULL); + } +#endif + + return err; +} + +static int ipip6_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct ip_tunnel *t = 
netdev_priv(dev); + struct ip_tunnel_parm p; + struct ip_tunnel_encap ipencap; + struct net *net = t->net; + struct sit_net *sitn = net_generic(net, sit_net_id); +#ifdef CONFIG_IPV6_SIT_6RD + struct ip_tunnel_6rd ip6rd; +#endif + __u32 fwmark = t->fwmark; + int err; + + if (dev == sitn->fb_tunnel_dev) + return -EINVAL; + + if (ip_tunnel_netlink_encap_parms(data, &ipencap)) { + err = ip_tunnel_encap_setup(t, &ipencap); + if (err < 0) + return err; + } + + ipip6_netlink_parms(data, &p, &fwmark); + + if (((dev->flags & IFF_POINTOPOINT) && !p.iph.daddr) || + (!(dev->flags & IFF_POINTOPOINT) && p.iph.daddr)) + return -EINVAL; + + t = ipip6_tunnel_locate(net, &p, 0); + + if (t) { + if (t->dev != dev) + return -EEXIST; + } else + t = netdev_priv(dev); + + ipip6_tunnel_update(t, &p, fwmark); + +#ifdef CONFIG_IPV6_SIT_6RD + if (ipip6_netlink_6rd_parms(data, &ip6rd)) + return ipip6_tunnel_update_6rd(t, &ip6rd); +#endif + + return 0; +} + +static size_t ipip6_get_size(const struct net_device *dev) +{ + return + /* IFLA_IPTUN_LINK */ + nla_total_size(4) + + /* IFLA_IPTUN_LOCAL */ + nla_total_size(4) + + /* IFLA_IPTUN_REMOTE */ + nla_total_size(4) + + /* IFLA_IPTUN_TTL */ + nla_total_size(1) + + /* IFLA_IPTUN_TOS */ + nla_total_size(1) + + /* IFLA_IPTUN_PMTUDISC */ + nla_total_size(1) + + /* IFLA_IPTUN_FLAGS */ + nla_total_size(2) + + /* IFLA_IPTUN_PROTO */ + nla_total_size(1) + +#ifdef CONFIG_IPV6_SIT_6RD + /* IFLA_IPTUN_6RD_PREFIX */ + nla_total_size(sizeof(struct in6_addr)) + + /* IFLA_IPTUN_6RD_RELAY_PREFIX */ + nla_total_size(4) + + /* IFLA_IPTUN_6RD_PREFIXLEN */ + nla_total_size(2) + + /* IFLA_IPTUN_6RD_RELAY_PREFIXLEN */ + nla_total_size(2) + +#endif + /* IFLA_IPTUN_ENCAP_TYPE */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_FLAGS */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_SPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_ENCAP_DPORT */ + nla_total_size(2) + + /* IFLA_IPTUN_FWMARK */ + nla_total_size(4) + + 0; +} + +static int ipip6_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct ip_tunnel *tunnel = netdev_priv(dev); + struct ip_tunnel_parm *parm = &tunnel->parms; + + if (nla_put_u32(skb, IFLA_IPTUN_LINK, parm->link) || + nla_put_in_addr(skb, IFLA_IPTUN_LOCAL, parm->iph.saddr) || + nla_put_in_addr(skb, IFLA_IPTUN_REMOTE, parm->iph.daddr) || + nla_put_u8(skb, IFLA_IPTUN_TTL, parm->iph.ttl) || + nla_put_u8(skb, IFLA_IPTUN_TOS, parm->iph.tos) || + nla_put_u8(skb, IFLA_IPTUN_PMTUDISC, + !!(parm->iph.frag_off & htons(IP_DF))) || + nla_put_u8(skb, IFLA_IPTUN_PROTO, parm->iph.protocol) || + nla_put_be16(skb, IFLA_IPTUN_FLAGS, parm->i_flags) || + nla_put_u32(skb, IFLA_IPTUN_FWMARK, tunnel->fwmark)) + goto nla_put_failure; + +#ifdef CONFIG_IPV6_SIT_6RD + if (nla_put_in6_addr(skb, IFLA_IPTUN_6RD_PREFIX, + &tunnel->ip6rd.prefix) || + nla_put_in_addr(skb, IFLA_IPTUN_6RD_RELAY_PREFIX, + tunnel->ip6rd.relay_prefix) || + nla_put_u16(skb, IFLA_IPTUN_6RD_PREFIXLEN, + tunnel->ip6rd.prefixlen) || + nla_put_u16(skb, IFLA_IPTUN_6RD_RELAY_PREFIXLEN, + tunnel->ip6rd.relay_prefixlen)) + goto nla_put_failure; +#endif + + if (nla_put_u16(skb, IFLA_IPTUN_ENCAP_TYPE, + tunnel->encap.type) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_SPORT, + tunnel->encap.sport) || + nla_put_be16(skb, IFLA_IPTUN_ENCAP_DPORT, + tunnel->encap.dport) || + nla_put_u16(skb, IFLA_IPTUN_ENCAP_FLAGS, + tunnel->encap.flags)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static const struct nla_policy ipip6_policy[IFLA_IPTUN_MAX + 1] = { + [IFLA_IPTUN_LINK] = { .type = NLA_U32 }, + 
[IFLA_IPTUN_LOCAL] = { .type = NLA_U32 }, + [IFLA_IPTUN_REMOTE] = { .type = NLA_U32 }, + [IFLA_IPTUN_TTL] = { .type = NLA_U8 }, + [IFLA_IPTUN_TOS] = { .type = NLA_U8 }, + [IFLA_IPTUN_PMTUDISC] = { .type = NLA_U8 }, + [IFLA_IPTUN_FLAGS] = { .type = NLA_U16 }, + [IFLA_IPTUN_PROTO] = { .type = NLA_U8 }, +#ifdef CONFIG_IPV6_SIT_6RD + [IFLA_IPTUN_6RD_PREFIX] = { .len = sizeof(struct in6_addr) }, + [IFLA_IPTUN_6RD_RELAY_PREFIX] = { .type = NLA_U32 }, + [IFLA_IPTUN_6RD_PREFIXLEN] = { .type = NLA_U16 }, + [IFLA_IPTUN_6RD_RELAY_PREFIXLEN] = { .type = NLA_U16 }, +#endif + [IFLA_IPTUN_ENCAP_TYPE] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_FLAGS] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_SPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_ENCAP_DPORT] = { .type = NLA_U16 }, + [IFLA_IPTUN_FWMARK] = { .type = NLA_U32 }, +}; + +static void ipip6_dellink(struct net_device *dev, struct list_head *head) +{ + struct net *net = dev_net(dev); + struct sit_net *sitn = net_generic(net, sit_net_id); + + if (dev != sitn->fb_tunnel_dev) + unregister_netdevice_queue(dev, head); +} + +static struct rtnl_link_ops sit_link_ops __read_mostly = { + .kind = "sit", + .maxtype = IFLA_IPTUN_MAX, + .policy = ipip6_policy, + .priv_size = sizeof(struct ip_tunnel), + .setup = ipip6_tunnel_setup, + .validate = ipip6_validate, + .newlink = ipip6_newlink, + .changelink = ipip6_changelink, + .get_size = ipip6_get_size, + .fill_info = ipip6_fill_info, + .dellink = ipip6_dellink, + .get_link_net = ip_tunnel_get_link_net, +}; + +static struct xfrm_tunnel sit_handler __read_mostly = { + .handler = ipip6_rcv, + .err_handler = ipip6_err, + .priority = 1, +}; + +static struct xfrm_tunnel ipip_handler __read_mostly = { + .handler = ipip_rcv, + .err_handler = ipip6_err, + .priority = 2, +}; + +#if IS_ENABLED(CONFIG_MPLS) +static struct xfrm_tunnel mplsip_handler __read_mostly = { + .handler = mplsip_rcv, + .err_handler = ipip6_err, + .priority = 2, +}; +#endif + +static void __net_exit sit_destroy_tunnels(struct net *net, + struct list_head *head) +{ + struct sit_net *sitn = net_generic(net, sit_net_id); + struct net_device *dev, *aux; + int prio; + + for_each_netdev_safe(net, dev, aux) + if (dev->rtnl_link_ops == &sit_link_ops) + unregister_netdevice_queue(dev, head); + + for (prio = 0; prio < 4; prio++) { + int h; + for (h = 0; h < (prio ? IP6_SIT_HASH_SIZE : 1); h++) { + struct ip_tunnel *t; + + t = rtnl_dereference(sitn->tunnels[prio][h]); + while (t) { + /* If dev is in the same netns, it has already + * been added to the list by the previous loop. + */ + if (!net_eq(dev_net(t->dev), net)) + unregister_netdevice_queue(t->dev, + head); + t = rtnl_dereference(t->next); + } + } + } +} + +static int __net_init sit_init_net(struct net *net) +{ + struct sit_net *sitn = net_generic(net, sit_net_id); + struct ip_tunnel *t; + int err; + + sitn->tunnels[0] = sitn->tunnels_wc; + sitn->tunnels[1] = sitn->tunnels_l; + sitn->tunnels[2] = sitn->tunnels_r; + sitn->tunnels[3] = sitn->tunnels_r_l; + + if (!net_has_fallback_tunnels(net)) + return 0; + + sitn->fb_tunnel_dev = alloc_netdev(sizeof(struct ip_tunnel), "sit0", + NET_NAME_UNKNOWN, + ipip6_tunnel_setup); + if (!sitn->fb_tunnel_dev) { + err = -ENOMEM; + goto err_alloc_dev; + } + dev_net_set(sitn->fb_tunnel_dev, net); + sitn->fb_tunnel_dev->rtnl_link_ops = &sit_link_ops; + /* FB netdevice is special: we have one, and only one per netns. + * Allowing to move it to another netns is clearly unsafe. 
+ */ + sitn->fb_tunnel_dev->features |= NETIF_F_NETNS_LOCAL; + + err = register_netdev(sitn->fb_tunnel_dev); + if (err) + goto err_reg_dev; + + ipip6_tunnel_clone_6rd(sitn->fb_tunnel_dev, sitn); + ipip6_fb_tunnel_init(sitn->fb_tunnel_dev); + + t = netdev_priv(sitn->fb_tunnel_dev); + + strcpy(t->parms.name, sitn->fb_tunnel_dev->name); + return 0; + +err_reg_dev: + free_netdev(sitn->fb_tunnel_dev); +err_alloc_dev: + return err; +} + +static void __net_exit sit_exit_batch_net(struct list_head *net_list) +{ + LIST_HEAD(list); + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) + sit_destroy_tunnels(net, &list); + + unregister_netdevice_many(&list); + rtnl_unlock(); +} + +static struct pernet_operations sit_net_ops = { + .init = sit_init_net, + .exit_batch = sit_exit_batch_net, + .id = &sit_net_id, + .size = sizeof(struct sit_net), +}; + +static void __exit sit_cleanup(void) +{ + rtnl_link_unregister(&sit_link_ops); + xfrm4_tunnel_deregister(&sit_handler, AF_INET6); + xfrm4_tunnel_deregister(&ipip_handler, AF_INET); +#if IS_ENABLED(CONFIG_MPLS) + xfrm4_tunnel_deregister(&mplsip_handler, AF_MPLS); +#endif + + unregister_pernet_device(&sit_net_ops); + rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} + +static int __init sit_init(void) +{ + int err; + + pr_info("IPv6, IPv4 and MPLS over IPv4 tunneling driver\n"); + + err = register_pernet_device(&sit_net_ops); + if (err < 0) + return err; + err = xfrm4_tunnel_register(&sit_handler, AF_INET6); + if (err < 0) { + pr_info("%s: can't register ip6ip4\n", __func__); + goto xfrm_tunnel_failed; + } + err = xfrm4_tunnel_register(&ipip_handler, AF_INET); + if (err < 0) { + pr_info("%s: can't register ip4ip4\n", __func__); + goto xfrm_tunnel4_failed; + } +#if IS_ENABLED(CONFIG_MPLS) + err = xfrm4_tunnel_register(&mplsip_handler, AF_MPLS); + if (err < 0) { + pr_info("%s: can't register mplsip\n", __func__); + goto xfrm_tunnel_mpls_failed; + } +#endif + err = rtnl_link_register(&sit_link_ops); + if (err < 0) + goto rtnl_link_failed; + +out: + return err; + +rtnl_link_failed: +#if IS_ENABLED(CONFIG_MPLS) + xfrm4_tunnel_deregister(&mplsip_handler, AF_MPLS); +xfrm_tunnel_mpls_failed: +#endif + xfrm4_tunnel_deregister(&ipip_handler, AF_INET); +xfrm_tunnel4_failed: + xfrm4_tunnel_deregister(&sit_handler, AF_INET6); +xfrm_tunnel_failed: + unregister_pernet_device(&sit_net_ops); + goto out; +} + +module_init(sit_init); +module_exit(sit_cleanup); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("sit"); +MODULE_ALIAS_NETDEV("sit0"); diff --git a/net/ipv6/syncookies.c b/net/ipv6/syncookies.c new file mode 100644 index 000000000..8698b49df --- /dev/null +++ b/net/ipv6/syncookies.c @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPv6 Syncookies implementation for the Linux kernel + * + * Authors: + * Glenn Griffin + * + * Based on IPv4 implementation by Andi Kleen + * linux/net/ipv4/syncookies.c + */ + +#include +#include +#include +#include +#include +#include +#include + +#define COOKIEBITS 24 /* Upper bits store count */ +#define COOKIEMASK (((__u32)1 << COOKIEBITS) - 1) + +static siphash_aligned_key_t syncookie6_secret[2]; + +/* RFC 2460, Section 8.3: + * [ipv6 tcp] MSS must be computed as the maximum packet size minus 60 [..] + * + * Due to IPV6_MIN_MTU=1280 the lowest possible MSS is 1220, which allows + * using higher values than ipv4 tcp syncookies. + * The other values are chosen based on ethernet (1500 and 9k MTU), plus + * one that accounts for common encap (PPPoe) overhead. Table must be sorted. 
+ */ +static __u16 const msstab[] = { + 1280 - 60, /* IPV6_MIN_MTU - 60 */ + 1480 - 60, + 1500 - 60, + 9000 - 60, +}; + +static u32 cookie_hash(const struct in6_addr *saddr, + const struct in6_addr *daddr, + __be16 sport, __be16 dport, u32 count, int c) +{ + const struct { + struct in6_addr saddr; + struct in6_addr daddr; + u32 count; + __be16 sport; + __be16 dport; + } __aligned(SIPHASH_ALIGNMENT) combined = { + .saddr = *saddr, + .daddr = *daddr, + .count = count, + .sport = sport, + .dport = dport + }; + + net_get_random_once(syncookie6_secret, sizeof(syncookie6_secret)); + return siphash(&combined, offsetofend(typeof(combined), dport), + &syncookie6_secret[c]); +} + +static __u32 secure_tcp_syn_cookie(const struct in6_addr *saddr, + const struct in6_addr *daddr, + __be16 sport, __be16 dport, __u32 sseq, + __u32 data) +{ + u32 count = tcp_cookie_time(); + return (cookie_hash(saddr, daddr, sport, dport, 0, 0) + + sseq + (count << COOKIEBITS) + + ((cookie_hash(saddr, daddr, sport, dport, count, 1) + data) + & COOKIEMASK)); +} + +static __u32 check_tcp_syn_cookie(__u32 cookie, const struct in6_addr *saddr, + const struct in6_addr *daddr, __be16 sport, + __be16 dport, __u32 sseq) +{ + __u32 diff, count = tcp_cookie_time(); + + cookie -= cookie_hash(saddr, daddr, sport, dport, 0, 0) + sseq; + + diff = (count - (cookie >> COOKIEBITS)) & ((__u32) -1 >> COOKIEBITS); + if (diff >= MAX_SYNCOOKIE_AGE) + return (__u32)-1; + + return (cookie - + cookie_hash(saddr, daddr, sport, dport, count - diff, 1)) + & COOKIEMASK; +} + +u32 __cookie_v6_init_sequence(const struct ipv6hdr *iph, + const struct tcphdr *th, __u16 *mssp) +{ + int mssind; + const __u16 mss = *mssp; + + for (mssind = ARRAY_SIZE(msstab) - 1; mssind ; mssind--) + if (mss >= msstab[mssind]) + break; + + *mssp = msstab[mssind]; + + return secure_tcp_syn_cookie(&iph->saddr, &iph->daddr, th->source, + th->dest, ntohl(th->seq), mssind); +} +EXPORT_SYMBOL_GPL(__cookie_v6_init_sequence); + +__u32 cookie_v6_init_sequence(const struct sk_buff *skb, __u16 *mssp) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + const struct tcphdr *th = tcp_hdr(skb); + + return __cookie_v6_init_sequence(iph, th, mssp); +} + +int __cookie_v6_check(const struct ipv6hdr *iph, const struct tcphdr *th, + __u32 cookie) +{ + __u32 seq = ntohl(th->seq) - 1; + __u32 mssind = check_tcp_syn_cookie(cookie, &iph->saddr, &iph->daddr, + th->source, th->dest, seq); + + return mssind < ARRAY_SIZE(msstab) ? 
msstab[mssind] : 0; +} +EXPORT_SYMBOL_GPL(__cookie_v6_check); + +struct sock *cookie_v6_check(struct sock *sk, struct sk_buff *skb) +{ + struct tcp_options_received tcp_opt; + struct inet_request_sock *ireq; + struct tcp_request_sock *treq; + struct ipv6_pinfo *np = inet6_sk(sk); + struct tcp_sock *tp = tcp_sk(sk); + const struct tcphdr *th = tcp_hdr(skb); + __u32 cookie = ntohl(th->ack_seq) - 1; + struct sock *ret = sk; + struct request_sock *req; + int full_space, mss; + struct dst_entry *dst; + __u8 rcv_wscale; + u32 tsoff = 0; + + if (!READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_syncookies) || + !th->ack || th->rst) + goto out; + + if (tcp_synq_no_recent_overflow(sk)) + goto out; + + mss = __cookie_v6_check(ipv6_hdr(skb), th, cookie); + if (mss == 0) { + __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESFAILED); + goto out; + } + + __NET_INC_STATS(sock_net(sk), LINUX_MIB_SYNCOOKIESRECV); + + /* check for timestamp cookie support */ + memset(&tcp_opt, 0, sizeof(tcp_opt)); + tcp_parse_options(sock_net(sk), skb, &tcp_opt, 0, NULL); + + if (tcp_opt.saw_tstamp && tcp_opt.rcv_tsecr) { + tsoff = secure_tcpv6_ts_off(sock_net(sk), + ipv6_hdr(skb)->daddr.s6_addr32, + ipv6_hdr(skb)->saddr.s6_addr32); + tcp_opt.rcv_tsecr -= tsoff; + } + + if (!cookie_timestamp_decode(sock_net(sk), &tcp_opt)) + goto out; + + ret = NULL; + req = cookie_tcp_reqsk_alloc(&tcp6_request_sock_ops, + &tcp_request_sock_ipv6_ops, sk, skb); + if (!req) + goto out; + + ireq = inet_rsk(req); + treq = tcp_rsk(req); + treq->tfo_listener = false; + + req->mss = mss; + ireq->ir_rmt_port = th->source; + ireq->ir_num = ntohs(th->dest); + ireq->ir_v6_rmt_addr = ipv6_hdr(skb)->saddr; + ireq->ir_v6_loc_addr = ipv6_hdr(skb)->daddr; + + if (security_inet_conn_request(sk, skb, req)) + goto out_free; + + if (ipv6_opt_accepted(sk, skb, &TCP_SKB_CB(skb)->header.h6) || + np->rxopt.bits.rxinfo || np->rxopt.bits.rxoinfo || + np->rxopt.bits.rxhlim || np->rxopt.bits.rxohlim) { + refcount_inc(&skb->users); + ireq->pktopts = skb; + } + + ireq->ir_iif = inet_request_bound_dev_if(sk, skb); + /* So that link locals have meaning */ + if (!sk->sk_bound_dev_if && + ipv6_addr_type(&ireq->ir_v6_rmt_addr) & IPV6_ADDR_LINKLOCAL) + ireq->ir_iif = tcp_v6_iif(skb); + + ireq->ir_mark = inet_request_mark(sk, skb); + + req->num_retrans = 0; + ireq->snd_wscale = tcp_opt.snd_wscale; + ireq->sack_ok = tcp_opt.sack_ok; + ireq->wscale_ok = tcp_opt.wscale_ok; + ireq->tstamp_ok = tcp_opt.saw_tstamp; + req->ts_recent = tcp_opt.saw_tstamp ? tcp_opt.rcv_tsval : 0; + treq->snt_synack = 0; + treq->rcv_isn = ntohl(th->seq) - 1; + treq->snt_isn = cookie; + treq->ts_off = 0; + treq->txhash = net_tx_rndhash(); + if (IS_ENABLED(CONFIG_SMC)) + ireq->smc_ok = 0; + + /* + * We need to lookup the dst_entry to get the correct window size. + * This is taken from tcp_v6_syn_recv_sock. Somebody please enlighten + * me if there is a preferred way. 
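+ * The route metrics found here (RTAX_WINDOW, RTAX_INITRWND) seed the + * window clamp and the initial receive window selected just below.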
+ */ + { + struct in6_addr *final_p, final; + struct flowi6 fl6; + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_TCP; + fl6.daddr = ireq->ir_v6_rmt_addr; + final_p = fl6_update_dst(&fl6, rcu_dereference(np->opt), &final); + fl6.saddr = ireq->ir_v6_loc_addr; + fl6.flowi6_oif = ireq->ir_iif; + fl6.flowi6_mark = ireq->ir_mark; + fl6.fl6_dport = ireq->ir_rmt_port; + fl6.fl6_sport = inet_sk(sk)->inet_sport; + fl6.flowi6_uid = sk->sk_uid; + security_req_classify_flow(req, flowi6_to_flowi_common(&fl6)); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); + if (IS_ERR(dst)) + goto out_free; + } + + req->rsk_window_clamp = tp->window_clamp ? :dst_metric(dst, RTAX_WINDOW); + /* limit the window selection if the user enforce a smaller rx buffer */ + full_space = tcp_full_space(sk); + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK && + (req->rsk_window_clamp > full_space || req->rsk_window_clamp == 0)) + req->rsk_window_clamp = full_space; + + tcp_select_initial_window(sk, full_space, req->mss, + &req->rsk_rcv_wnd, &req->rsk_window_clamp, + ireq->wscale_ok, &rcv_wscale, + dst_metric(dst, RTAX_INITRWND)); + + ireq->rcv_wscale = rcv_wscale; + ireq->ecn_ok = cookie_ecn_ok(&tcp_opt, sock_net(sk), dst); + + ret = tcp_get_cookie_sock(sk, skb, req, dst, tsoff); +out: + return ret; +out_free: + reqsk_free(req); + return NULL; +} diff --git a/net/ipv6/sysctl_net_ipv6.c b/net/ipv6/sysctl_net_ipv6.c new file mode 100644 index 000000000..94a0a294c --- /dev/null +++ b/net/ipv6/sysctl_net_ipv6.c @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * sysctl_net_ipv6.c: sysctl interface to net IPV6 subsystem. + * + * Changes: + * YOSHIFUJI Hideaki @USAGI: added icmp sysctl table. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_NETLABEL +#include +#endif +#include + +static int flowlabel_reflect_max = 0x7; +static int auto_flowlabels_max = IP6_AUTO_FLOW_LABEL_MAX; +static u32 rt6_multipath_hash_fields_all_mask = + FIB_MULTIPATH_HASH_FIELD_ALL_MASK; +static u32 ioam6_id_max = IOAM6_DEFAULT_ID; +static u64 ioam6_id_wide_max = IOAM6_DEFAULT_ID_WIDE; + +static int proc_rt6_multipath_hash_policy(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net; + int ret; + + net = container_of(table->data, struct net, + ipv6.sysctl.multipath_hash_policy); + ret = proc_dou8vec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + call_netevent_notifiers(NETEVENT_IPV6_MPATH_HASH_UPDATE, net); + + return ret; +} + +static int +proc_rt6_multipath_hash_fields(struct ctl_table *table, int write, void *buffer, + size_t *lenp, loff_t *ppos) +{ + struct net *net; + int ret; + + net = container_of(table->data, struct net, + ipv6.sysctl.multipath_hash_fields); + ret = proc_douintvec_minmax(table, write, buffer, lenp, ppos); + if (write && ret == 0) + call_netevent_notifiers(NETEVENT_IPV6_MPATH_HASH_UPDATE, net); + + return ret; +} + +static struct ctl_table ipv6_table_template[] = { + { + .procname = "bindv6only", + .data = &init_net.ipv6.sysctl.bindv6only, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "anycast_src_echo_reply", + .data = &init_net.ipv6.sysctl.anycast_src_echo_reply, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "flowlabel_consistency", + .data = &init_net.ipv6.sysctl.flowlabel_consistency, + .maxlen = sizeof(u8), + .mode = 0644, + 
.proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "auto_flowlabels", + .data = &init_net.ipv6.sysctl.auto_flowlabels, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra2 = &auto_flowlabels_max + }, + { + .procname = "fwmark_reflect", + .data = &init_net.ipv6.sysctl.fwmark_reflect, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "idgen_retries", + .data = &init_net.ipv6.sysctl.idgen_retries, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "idgen_delay", + .data = &init_net.ipv6.sysctl.idgen_delay, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "flowlabel_state_ranges", + .data = &init_net.ipv6.sysctl.flowlabel_state_ranges, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "ip_nonlocal_bind", + .data = &init_net.ipv6.sysctl.ip_nonlocal_bind, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + { + .procname = "flowlabel_reflect", + .data = &init_net.ipv6.sysctl.flowlabel_reflect, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &flowlabel_reflect_max, + }, + { + .procname = "max_dst_opts_number", + .data = &init_net.ipv6.sysctl.max_dst_opts_cnt, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "max_hbh_opts_number", + .data = &init_net.ipv6.sysctl.max_hbh_opts_cnt, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "max_dst_opts_length", + .data = &init_net.ipv6.sysctl.max_dst_opts_len, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "max_hbh_length", + .data = &init_net.ipv6.sysctl.max_hbh_opts_len, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "fib_multipath_hash_policy", + .data = &init_net.ipv6.sysctl.multipath_hash_policy, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_rt6_multipath_hash_policy, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_THREE, + }, + { + .procname = "fib_multipath_hash_fields", + .data = &init_net.ipv6.sysctl.multipath_hash_fields, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_rt6_multipath_hash_fields, + .extra1 = SYSCTL_ONE, + .extra2 = &rt6_multipath_hash_fields_all_mask, + }, + { + .procname = "seg6_flowlabel", + .data = &init_net.ipv6.sysctl.seg6_flowlabel, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "fib_notify_on_flag_change", + .data = &init_net.ipv6.sysctl.fib_notify_on_flag_change, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "ioam6_id", + .data = &init_net.ipv6.sysctl.ioam6_id, + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra2 = &ioam6_id_max, + }, + { + .procname = "ioam6_id_wide", + .data = &init_net.ipv6.sysctl.ioam6_id_wide, + .maxlen = sizeof(u64), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + .extra2 = &ioam6_id_wide_max, + }, + { } +}; + +static struct ctl_table ipv6_rotable[] = { + { + .procname = "mld_max_msf", + .data = &sysctl_mld_max_msf, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "mld_qrv", + .data = 
&sysctl_mld_qrv, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE + }, +#ifdef CONFIG_NETLABEL + { + .procname = "calipso_cache_enable", + .data = &calipso_cache_enabled, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "calipso_cache_bucket_size", + .data = &calipso_cache_bucketsize, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif /* CONFIG_NETLABEL */ + { } +}; + +static int __net_init ipv6_sysctl_net_init(struct net *net) +{ + struct ctl_table *ipv6_table; + struct ctl_table *ipv6_route_table; + struct ctl_table *ipv6_icmp_table; + int err, i; + + err = -ENOMEM; + ipv6_table = kmemdup(ipv6_table_template, sizeof(ipv6_table_template), + GFP_KERNEL); + if (!ipv6_table) + goto out; + /* Update the variables to point into the current struct net */ + for (i = 0; i < ARRAY_SIZE(ipv6_table_template) - 1; i++) + ipv6_table[i].data += (void *)net - (void *)&init_net; + + ipv6_route_table = ipv6_route_sysctl_init(net); + if (!ipv6_route_table) + goto out_ipv6_table; + + ipv6_icmp_table = ipv6_icmp_sysctl_init(net); + if (!ipv6_icmp_table) + goto out_ipv6_route_table; + + net->ipv6.sysctl.hdr = register_net_sysctl(net, "net/ipv6", ipv6_table); + if (!net->ipv6.sysctl.hdr) + goto out_ipv6_icmp_table; + + net->ipv6.sysctl.route_hdr = + register_net_sysctl(net, "net/ipv6/route", ipv6_route_table); + if (!net->ipv6.sysctl.route_hdr) + goto out_unregister_ipv6_table; + + net->ipv6.sysctl.icmp_hdr = + register_net_sysctl(net, "net/ipv6/icmp", ipv6_icmp_table); + if (!net->ipv6.sysctl.icmp_hdr) + goto out_unregister_route_table; + + err = 0; +out: + return err; +out_unregister_route_table: + unregister_net_sysctl_table(net->ipv6.sysctl.route_hdr); +out_unregister_ipv6_table: + unregister_net_sysctl_table(net->ipv6.sysctl.hdr); +out_ipv6_icmp_table: + kfree(ipv6_icmp_table); +out_ipv6_route_table: + kfree(ipv6_route_table); +out_ipv6_table: + kfree(ipv6_table); + goto out; +} + +static void __net_exit ipv6_sysctl_net_exit(struct net *net) +{ + struct ctl_table *ipv6_table; + struct ctl_table *ipv6_route_table; + struct ctl_table *ipv6_icmp_table; + + ipv6_table = net->ipv6.sysctl.hdr->ctl_table_arg; + ipv6_route_table = net->ipv6.sysctl.route_hdr->ctl_table_arg; + ipv6_icmp_table = net->ipv6.sysctl.icmp_hdr->ctl_table_arg; + + unregister_net_sysctl_table(net->ipv6.sysctl.icmp_hdr); + unregister_net_sysctl_table(net->ipv6.sysctl.route_hdr); + unregister_net_sysctl_table(net->ipv6.sysctl.hdr); + + kfree(ipv6_table); + kfree(ipv6_route_table); + kfree(ipv6_icmp_table); +} + +static struct pernet_operations ipv6_sysctl_net_ops = { + .init = ipv6_sysctl_net_init, + .exit = ipv6_sysctl_net_exit, +}; + +static struct ctl_table_header *ip6_header; + +int ipv6_sysctl_register(void) +{ + int err = -ENOMEM; + + ip6_header = register_net_sysctl(&init_net, "net/ipv6", ipv6_rotable); + if (!ip6_header) + goto out; + + err = register_pernet_subsys(&ipv6_sysctl_net_ops); + if (err) + goto err_pernet; +out: + return err; + +err_pernet: + unregister_net_sysctl_table(ip6_header); + goto out; +} + +void ipv6_sysctl_unregister(void) +{ + unregister_net_sysctl_table(ip6_header); + unregister_pernet_subsys(&ipv6_sysctl_net_ops); +} diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c new file mode 100644 index 000000000..ba9a22db5 --- /dev/null +++ b/net/ipv6/tcp_ipv6.c @@ -0,0 +1,2271 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * TCP over IPv6 + * Linux INET6 implementation 
+ * + * Authors: + * Pedro Roque + * + * Based on: + * linux/net/ipv4/tcp.c + * linux/net/ipv4/tcp_input.c + * linux/net/ipv4/tcp_output.c + * + * Fixes: + * Hideaki YOSHIFUJI : sin6_scope_id support + * YOSHIFUJI Hideaki @USAGI and: Support IPV6_V6ONLY socket option, which + * Alexey Kuznetsov allow both IPv4 and IPv6 sockets to bind + * a single port at the same time. + * YOSHIFUJI Hideaki @USAGI: convert /proc/net/tcp6 to seq_file. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb); +static void tcp_v6_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req); + +INDIRECT_CALLABLE_SCOPE int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb); + +static const struct inet_connection_sock_af_ops ipv6_mapped; +const struct inet_connection_sock_af_ops ipv6_specific; +#ifdef CONFIG_TCP_MD5SIG +static const struct tcp_sock_af_ops tcp_sock_ipv6_specific; +static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific; +#else +static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk, + const struct in6_addr *addr, + int l3index) +{ + return NULL; +} +#endif + +/* Helper returning the inet6 address from a given tcp socket. + * It can be used in TCP stack instead of inet6_sk(sk). + * This avoids a dereference and allow compiler optimizations. + * It is a specialized version of inet6_sk_generic(). + */ +static struct ipv6_pinfo *tcp_inet6_sk(const struct sock *sk) +{ + unsigned int offset = sizeof(struct tcp6_sock) - sizeof(struct ipv6_pinfo); + + return (struct ipv6_pinfo *)(((u8 *)sk) + offset); +} + +static void inet6_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + + if (dst && dst_hold_safe(dst)) { + const struct rt6_info *rt = (const struct rt6_info *)dst; + + rcu_assign_pointer(sk->sk_rx_dst, dst); + sk->sk_rx_dst_ifindex = skb->skb_iif; + sk->sk_rx_dst_cookie = rt6_get_cookie(rt); + } +} + +static u32 tcp_v6_init_seq(const struct sk_buff *skb) +{ + return secure_tcpv6_seq(ipv6_hdr(skb)->daddr.s6_addr32, + ipv6_hdr(skb)->saddr.s6_addr32, + tcp_hdr(skb)->dest, + tcp_hdr(skb)->source); +} + +static u32 tcp_v6_init_ts_off(const struct net *net, const struct sk_buff *skb) +{ + return secure_tcpv6_ts_off(net, ipv6_hdr(skb)->daddr.s6_addr32, + ipv6_hdr(skb)->saddr.s6_addr32); +} + +static int tcp_v6_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + /* This check is replicated from tcp_v6_connect() and intended to + * prevent BPF program called below from accessing bytes that are out + * of the bound specified by user in addr_len. 
+ */ + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + sock_owned_by_me(sk); + + return BPF_CGROUP_RUN_PROG_INET6_CONNECT(sk, uaddr); +} + +static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + struct sockaddr_in6 *usin = (struct sockaddr_in6 *) uaddr; + struct inet_connection_sock *icsk = inet_csk(sk); + struct in6_addr *saddr = NULL, *final_p, final; + struct inet_timewait_death_row *tcp_death_row; + struct ipv6_pinfo *np = tcp_inet6_sk(sk); + struct inet_sock *inet = inet_sk(sk); + struct tcp_sock *tp = tcp_sk(sk); + struct net *net = sock_net(sk); + struct ipv6_txoptions *opt; + struct dst_entry *dst; + struct flowi6 fl6; + int addr_type; + int err; + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (usin->sin6_family != AF_INET6) + return -EAFNOSUPPORT; + + memset(&fl6, 0, sizeof(fl6)); + + if (np->sndflow) { + fl6.flowlabel = usin->sin6_flowinfo&IPV6_FLOWINFO_MASK; + IP6_ECN_flow_init(fl6.flowlabel); + if (fl6.flowlabel&IPV6_FLOWLABEL_MASK) { + struct ip6_flowlabel *flowlabel; + flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + fl6_sock_release(flowlabel); + } + } + + /* + * connect() to INADDR_ANY means loopback (BSD'ism). + */ + + if (ipv6_addr_any(&usin->sin6_addr)) { + if (ipv6_addr_v4mapped(&sk->sk_v6_rcv_saddr)) + ipv6_addr_set_v4mapped(htonl(INADDR_LOOPBACK), + &usin->sin6_addr); + else + usin->sin6_addr = in6addr_loopback; + } + + addr_type = ipv6_addr_type(&usin->sin6_addr); + + if (addr_type & IPV6_ADDR_MULTICAST) + return -ENETUNREACH; + + if (addr_type&IPV6_ADDR_LINKLOCAL) { + if (addr_len >= sizeof(struct sockaddr_in6) && + usin->sin6_scope_id) { + /* If interface is set while binding, indices + * must coincide. + */ + if (!sk_dev_equal_l3scope(sk, usin->sin6_scope_id)) + return -EINVAL; + + sk->sk_bound_dev_if = usin->sin6_scope_id; + } + + /* Connect to link-local address requires an interface */ + if (!sk->sk_bound_dev_if) + return -EINVAL; + } + + if (tp->rx_opt.ts_recent_stamp && + !ipv6_addr_equal(&sk->sk_v6_daddr, &usin->sin6_addr)) { + tp->rx_opt.ts_recent = 0; + tp->rx_opt.ts_recent_stamp = 0; + WRITE_ONCE(tp->write_seq, 0); + } + + sk->sk_v6_daddr = usin->sin6_addr; + np->flow_label = fl6.flowlabel; + + /* + * TCP over IPv4 + */ + + if (addr_type & IPV6_ADDR_MAPPED) { + u32 exthdrlen = icsk->icsk_ext_hdr_len; + struct sockaddr_in sin; + + if (ipv6_only_sock(sk)) + return -ENETUNREACH; + + sin.sin_family = AF_INET; + sin.sin_port = usin->sin6_port; + sin.sin_addr.s_addr = usin->sin6_addr.s6_addr32[3]; + + /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */ + WRITE_ONCE(icsk->icsk_af_ops, &ipv6_mapped); + if (sk_is_mptcp(sk)) + mptcpv6_handle_mapped(sk, true); + sk->sk_backlog_rcv = tcp_v4_do_rcv; +#ifdef CONFIG_TCP_MD5SIG + tp->af_specific = &tcp_sock_ipv6_mapped_specific; +#endif + + err = tcp_v4_connect(sk, (struct sockaddr *)&sin, sizeof(sin)); + + if (err) { + icsk->icsk_ext_hdr_len = exthdrlen; + /* Paired with READ_ONCE() in tcp_(get|set)sockopt() */ + WRITE_ONCE(icsk->icsk_af_ops, &ipv6_specific); + if (sk_is_mptcp(sk)) + mptcpv6_handle_mapped(sk, false); + sk->sk_backlog_rcv = tcp_v6_do_rcv; +#ifdef CONFIG_TCP_MD5SIG + tp->af_specific = &tcp_sock_ipv6_specific; +#endif + goto failure; + } + np->saddr = sk->sk_v6_rcv_saddr; + + return err; + } + + if (!ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + saddr = &sk->sk_v6_rcv_saddr; + + fl6.flowi6_proto = IPPROTO_TCP; + fl6.daddr = sk->sk_v6_daddr; + fl6.saddr = saddr ? 
*saddr : np->saddr; + fl6.flowlabel = ip6_make_flowinfo(np->tclass, np->flow_label); + fl6.flowi6_oif = sk->sk_bound_dev_if; + fl6.flowi6_mark = sk->sk_mark; + fl6.fl6_dport = usin->sin6_port; + fl6.fl6_sport = inet->inet_sport; + fl6.flowi6_uid = sk->sk_uid; + + opt = rcu_dereference_protected(np->opt, lockdep_sock_is_held(sk)); + final_p = fl6_update_dst(&fl6, opt, &final); + + security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6)); + + dst = ip6_dst_lookup_flow(net, sk, &fl6, final_p); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + goto failure; + } + + tcp_death_row = &sock_net(sk)->ipv4.tcp_death_row; + + if (!saddr) { + saddr = &fl6.saddr; + + err = inet_bhash2_update_saddr(sk, saddr, AF_INET6); + if (err) + goto failure; + } + + /* set the source address */ + np->saddr = *saddr; + inet->inet_rcv_saddr = LOOPBACK4_IPV6; + + sk->sk_gso_type = SKB_GSO_TCPV6; + ip6_dst_store(sk, dst, NULL, NULL); + + icsk->icsk_ext_hdr_len = 0; + if (opt) + icsk->icsk_ext_hdr_len = opt->opt_flen + + opt->opt_nflen; + + tp->rx_opt.mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - sizeof(struct ipv6hdr); + + inet->inet_dport = usin->sin6_port; + + tcp_set_state(sk, TCP_SYN_SENT); + err = inet6_hash_connect(tcp_death_row, sk); + if (err) + goto late_failure; + + sk_set_txhash(sk); + + if (likely(!tp->repair)) { + if (!tp->write_seq) + WRITE_ONCE(tp->write_seq, + secure_tcpv6_seq(np->saddr.s6_addr32, + sk->sk_v6_daddr.s6_addr32, + inet->inet_sport, + inet->inet_dport)); + tp->tsoffset = secure_tcpv6_ts_off(net, np->saddr.s6_addr32, + sk->sk_v6_daddr.s6_addr32); + } + + if (tcp_fastopen_defer_connect(sk, &err)) + return err; + if (err) + goto late_failure; + + err = tcp_connect(sk); + if (err) + goto late_failure; + + return 0; + +late_failure: + tcp_set_state(sk, TCP_CLOSE); + inet_bhash2_reset_saddr(sk); +failure: + inet->inet_dport = 0; + sk->sk_route_caps = 0; + return err; +} + +static void tcp_v6_mtu_reduced(struct sock *sk) +{ + struct dst_entry *dst; + u32 mtu; + + if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) + return; + + mtu = READ_ONCE(tcp_sk(sk)->mtu_info); + + /* Drop requests trying to increase our current mss. + * Check done in __ip6_rt_update_pmtu() is too late. 
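+ * Only a reported MTU that would shrink the MSS below the cached + * mss_cache value is acted upon below.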
+ */ + if (tcp_mtu_to_mss(sk, mtu) >= tcp_sk(sk)->mss_cache) + return; + + dst = inet6_csk_update_pmtu(sk, mtu); + if (!dst) + return; + + if (inet_csk(sk)->icsk_pmtu_cookie > dst_mtu(dst)) { + tcp_sync_mss(sk, dst_mtu(dst)); + tcp_simple_retransmit(sk); + } +} + +static int tcp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + const struct ipv6hdr *hdr = (const struct ipv6hdr *)skb->data; + const struct tcphdr *th = (struct tcphdr *)(skb->data+offset); + struct net *net = dev_net(skb->dev); + struct request_sock *fastopen; + struct ipv6_pinfo *np; + struct tcp_sock *tp; + __u32 seq, snd_una; + struct sock *sk; + bool fatal; + int err; + + sk = __inet6_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, + &hdr->daddr, th->dest, + &hdr->saddr, ntohs(th->source), + skb->dev->ifindex, inet6_sdif(skb)); + + if (!sk) { + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), + ICMP6_MIB_INERRORS); + return -ENOENT; + } + + if (sk->sk_state == TCP_TIME_WAIT) { + inet_twsk_put(inet_twsk(sk)); + return 0; + } + seq = ntohl(th->seq); + fatal = icmpv6_err_convert(type, code, &err); + if (sk->sk_state == TCP_NEW_SYN_RECV) { + tcp_req_err(sk, seq, fatal); + return 0; + } + + bh_lock_sock(sk); + if (sock_owned_by_user(sk) && type != ICMPV6_PKT_TOOBIG) + __NET_INC_STATS(net, LINUX_MIB_LOCKDROPPEDICMPS); + + if (sk->sk_state == TCP_CLOSE) + goto out; + + if (static_branch_unlikely(&ip6_min_hopcount)) { + /* min_hopcount can be changed concurrently from do_ipv6_setsockopt() */ + if (ipv6_hdr(skb)->hop_limit < READ_ONCE(tcp_inet6_sk(sk)->min_hopcount)) { + __NET_INC_STATS(net, LINUX_MIB_TCPMINTTLDROP); + goto out; + } + } + + tp = tcp_sk(sk); + /* XXX (TFO) - tp->snd_una should be ISN (tcp_create_openreq_child() */ + fastopen = rcu_dereference(tp->fastopen_rsk); + snd_una = fastopen ? tcp_rsk(fastopen)->snt_isn : tp->snd_una; + if (sk->sk_state != TCP_LISTEN && + !between(seq, snd_una, tp->snd_nxt)) { + __NET_INC_STATS(net, LINUX_MIB_OUTOFWINDOWICMPS); + goto out; + } + + np = tcp_inet6_sk(sk); + + if (type == NDISC_REDIRECT) { + if (!sock_owned_by_user(sk)) { + struct dst_entry *dst = __sk_dst_check(sk, np->dst_cookie); + + if (dst) + dst->ops->redirect(dst, sk, skb); + } + goto out; + } + + if (type == ICMPV6_PKT_TOOBIG) { + u32 mtu = ntohl(info); + + /* We are not interested in TCP_LISTEN and open_requests + * (SYN-ACKs send out by Linux are always <576bytes so + * they should go through unfragmented). + */ + if (sk->sk_state == TCP_LISTEN) + goto out; + + if (!ip6_sk_accept_pmtu(sk)) + goto out; + + if (mtu < IPV6_MIN_MTU) + goto out; + + WRITE_ONCE(tp->mtu_info, mtu); + + if (!sock_owned_by_user(sk)) + tcp_v6_mtu_reduced(sk); + else if (!test_and_set_bit(TCP_MTU_REDUCED_DEFERRED, + &sk->sk_tsq_flags)) + sock_hold(sk); + goto out; + } + + + /* Might be for an request_sock */ + switch (sk->sk_state) { + case TCP_SYN_SENT: + case TCP_SYN_RECV: + /* Only in fast or simultaneous open. If a fast open socket is + * already accepted it is treated as a connected one below. + */ + if (fastopen && !fastopen->sk) + break; + + ipv6_icmp_error(sk, skb, err, th->dest, ntohl(info), (u8 *)th); + + if (!sock_owned_by_user(sk)) { + sk->sk_err = err; + sk_error_report(sk); /* Wake people up to see the error (see connect in sock.c) */ + + tcp_done(sk); + } else + sk->sk_err_soft = err; + goto out; + case TCP_LISTEN: + break; + default: + /* check if this ICMP message allows revert of backoff. 
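+ * Only ICMPV6_DEST_UNREACH with code ICMPV6_NOROUTE qualifies; it is + * handled by tcp_ld_RTO_revert().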
+ * (see RFC 6069) + */ + if (!fastopen && type == ICMPV6_DEST_UNREACH && + code == ICMPV6_NOROUTE) + tcp_ld_RTO_revert(sk, seq); + } + + if (!sock_owned_by_user(sk) && np->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else + sk->sk_err_soft = err; + +out: + bh_unlock_sock(sk); + sock_put(sk); + return 0; +} + + +static int tcp_v6_send_synack(const struct sock *sk, struct dst_entry *dst, + struct flowi *fl, + struct request_sock *req, + struct tcp_fastopen_cookie *foc, + enum tcp_synack_type synack_type, + struct sk_buff *syn_skb) +{ + struct inet_request_sock *ireq = inet_rsk(req); + struct ipv6_pinfo *np = tcp_inet6_sk(sk); + struct ipv6_txoptions *opt; + struct flowi6 *fl6 = &fl->u.ip6; + struct sk_buff *skb; + int err = -ENOMEM; + u8 tclass; + + /* First, grab a route. */ + if (!dst && (dst = inet6_csk_route_req(sk, fl6, req, + IPPROTO_TCP)) == NULL) + goto done; + + skb = tcp_make_synack(sk, dst, req, foc, synack_type, syn_skb); + + if (skb) { + __tcp_v6_send_check(skb, &ireq->ir_v6_loc_addr, + &ireq->ir_v6_rmt_addr); + + fl6->daddr = ireq->ir_v6_rmt_addr; + if (np->repflow && ireq->pktopts) + fl6->flowlabel = ip6_flowlabel(ipv6_hdr(ireq->pktopts)); + + tclass = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos) ? + (tcp_rsk(req)->syn_tos & ~INET_ECN_MASK) | + (np->tclass & INET_ECN_MASK) : + np->tclass; + + if (!INET_ECN_is_capable(tclass) && + tcp_bpf_ca_needs_ecn((struct sock *)req)) + tclass |= INET_ECN_ECT_0; + + rcu_read_lock(); + opt = ireq->ipv6_opt; + if (!opt) + opt = rcu_dereference(np->opt); + err = ip6_xmit(sk, skb, fl6, skb->mark ? : READ_ONCE(sk->sk_mark), + opt, tclass, sk->sk_priority); + rcu_read_unlock(); + err = net_xmit_eval(err); + } + +done: + return err; +} + + +static void tcp_v6_reqsk_destructor(struct request_sock *req) +{ + kfree(inet_rsk(req)->ipv6_opt); + consume_skb(inet_rsk(req)->pktopts); +} + +#ifdef CONFIG_TCP_MD5SIG +static struct tcp_md5sig_key *tcp_v6_md5_do_lookup(const struct sock *sk, + const struct in6_addr *addr, + int l3index) +{ + return tcp_md5_do_lookup(sk, l3index, + (union tcp_md5_addr *)addr, AF_INET6); +} + +static struct tcp_md5sig_key *tcp_v6_md5_lookup(const struct sock *sk, + const struct sock *addr_sk) +{ + int l3index; + + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), + addr_sk->sk_bound_dev_if); + return tcp_v6_md5_do_lookup(sk, &addr_sk->sk_v6_daddr, + l3index); +} + +static int tcp_v6_parse_md5_keys(struct sock *sk, int optname, + sockptr_t optval, int optlen) +{ + struct tcp_md5sig cmd; + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.tcpm_addr; + int l3index = 0; + u8 prefixlen; + u8 flags; + + if (optlen < sizeof(cmd)) + return -EINVAL; + + if (copy_from_sockptr(&cmd, optval, sizeof(cmd))) + return -EFAULT; + + if (sin6->sin6_family != AF_INET6) + return -EINVAL; + + flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX; + + if (optname == TCP_MD5SIG_EXT && + cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) { + prefixlen = cmd.tcpm_prefixlen; + if (prefixlen > 128 || (ipv6_addr_v4mapped(&sin6->sin6_addr) && + prefixlen > 32)) + return -EINVAL; + } else { + prefixlen = ipv6_addr_v4mapped(&sin6->sin6_addr) ? 
32 : 128; + } + + if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex && + cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) { + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), cmd.tcpm_ifindex); + if (dev && netif_is_l3_master(dev)) + l3index = dev->ifindex; + rcu_read_unlock(); + + /* ok to reference set/not set outside of rcu; + * right now device MUST be an L3 master + */ + if (!dev || !l3index) + return -EINVAL; + } + + if (!cmd.tcpm_keylen) { + if (ipv6_addr_v4mapped(&sin6->sin6_addr)) + return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3], + AF_INET, prefixlen, + l3index, flags); + return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr, + AF_INET6, prefixlen, l3index, flags); + } + + if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN) + return -EINVAL; + + if (ipv6_addr_v4mapped(&sin6->sin6_addr)) + return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3], + AF_INET, prefixlen, l3index, flags, + cmd.tcpm_key, cmd.tcpm_keylen, + GFP_KERNEL); + + return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr, + AF_INET6, prefixlen, l3index, flags, + cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); +} + +static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp, + const struct in6_addr *daddr, + const struct in6_addr *saddr, + const struct tcphdr *th, int nbytes) +{ + struct tcp6_pseudohdr *bp; + struct scatterlist sg; + struct tcphdr *_th; + + bp = hp->scratch; + /* 1. TCP pseudo-header (RFC2460) */ + bp->saddr = *saddr; + bp->daddr = *daddr; + bp->protocol = cpu_to_be32(IPPROTO_TCP); + bp->len = cpu_to_be32(nbytes); + + _th = (struct tcphdr *)(bp + 1); + memcpy(_th, th, sizeof(*th)); + _th->check = 0; + + sg_init_one(&sg, bp, sizeof(*bp) + sizeof(*th)); + ahash_request_set_crypt(hp->md5_req, &sg, NULL, + sizeof(*bp) + sizeof(*th)); + return crypto_ahash_update(hp->md5_req); +} + +static int tcp_v6_md5_hash_hdr(char *md5_hash, const struct tcp_md5sig_key *key, + const struct in6_addr *daddr, struct in6_addr *saddr, + const struct tcphdr *th) +{ + struct tcp_md5sig_pool *hp; + struct ahash_request *req; + + hp = tcp_get_md5sig_pool(); + if (!hp) + goto clear_hash_noput; + req = hp->md5_req; + + if (crypto_ahash_init(req)) + goto clear_hash; + if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, th->doff << 2)) + goto clear_hash; + if (tcp_md5_hash_key(hp, key)) + goto clear_hash; + ahash_request_set_crypt(req, NULL, md5_hash, 0); + if (crypto_ahash_final(req)) + goto clear_hash; + + tcp_put_md5sig_pool(); + return 0; + +clear_hash: + tcp_put_md5sig_pool(); +clear_hash_noput: + memset(md5_hash, 0, 16); + return 1; +} + +static int tcp_v6_md5_hash_skb(char *md5_hash, + const struct tcp_md5sig_key *key, + const struct sock *sk, + const struct sk_buff *skb) +{ + const struct in6_addr *saddr, *daddr; + struct tcp_md5sig_pool *hp; + struct ahash_request *req; + const struct tcphdr *th = tcp_hdr(skb); + + if (sk) { /* valid for establish/request sockets */ + saddr = &sk->sk_v6_rcv_saddr; + daddr = &sk->sk_v6_daddr; + } else { + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + saddr = &ip6h->saddr; + daddr = &ip6h->daddr; + } + + hp = tcp_get_md5sig_pool(); + if (!hp) + goto clear_hash_noput; + req = hp->md5_req; + + if (crypto_ahash_init(req)) + goto clear_hash; + + if (tcp_v6_md5_hash_headers(hp, daddr, saddr, th, skb->len)) + goto clear_hash; + if (tcp_md5_hash_skb_data(hp, skb, th->doff << 2)) + goto clear_hash; + if (tcp_md5_hash_key(hp, key)) + goto clear_hash; + ahash_request_set_crypt(req, NULL, md5_hash, 0); + if 
(crypto_ahash_final(req)) + goto clear_hash; + + tcp_put_md5sig_pool(); + return 0; + +clear_hash: + tcp_put_md5sig_pool(); +clear_hash_noput: + memset(md5_hash, 0, 16); + return 1; +} + +#endif + +static void tcp_v6_init_req(struct request_sock *req, + const struct sock *sk_listener, + struct sk_buff *skb) +{ + bool l3_slave = ipv6_l3mdev_skb(TCP_SKB_CB(skb)->header.h6.flags); + struct inet_request_sock *ireq = inet_rsk(req); + const struct ipv6_pinfo *np = tcp_inet6_sk(sk_listener); + + ireq->ir_v6_rmt_addr = ipv6_hdr(skb)->saddr; + ireq->ir_v6_loc_addr = ipv6_hdr(skb)->daddr; + + /* So that link locals have meaning */ + if ((!sk_listener->sk_bound_dev_if || l3_slave) && + ipv6_addr_type(&ireq->ir_v6_rmt_addr) & IPV6_ADDR_LINKLOCAL) + ireq->ir_iif = tcp_v6_iif(skb); + + if (!TCP_SKB_CB(skb)->tcp_tw_isn && + (ipv6_opt_accepted(sk_listener, skb, &TCP_SKB_CB(skb)->header.h6) || + np->rxopt.bits.rxinfo || + np->rxopt.bits.rxoinfo || np->rxopt.bits.rxhlim || + np->rxopt.bits.rxohlim || np->repflow)) { + refcount_inc(&skb->users); + ireq->pktopts = skb; + } +} + +static struct dst_entry *tcp_v6_route_req(const struct sock *sk, + struct sk_buff *skb, + struct flowi *fl, + struct request_sock *req) +{ + tcp_v6_init_req(req, sk, skb); + + if (security_inet_conn_request(sk, skb, req)) + return NULL; + + return inet6_csk_route_req(sk, &fl->u.ip6, req, IPPROTO_TCP); +} + +struct request_sock_ops tcp6_request_sock_ops __read_mostly = { + .family = AF_INET6, + .obj_size = sizeof(struct tcp6_request_sock), + .rtx_syn_ack = tcp_rtx_synack, + .send_ack = tcp_v6_reqsk_send_ack, + .destructor = tcp_v6_reqsk_destructor, + .send_reset = tcp_v6_send_reset, + .syn_ack_timeout = tcp_syn_ack_timeout, +}; + +const struct tcp_request_sock_ops tcp_request_sock_ipv6_ops = { + .mss_clamp = IPV6_MIN_MTU - sizeof(struct tcphdr) - + sizeof(struct ipv6hdr), +#ifdef CONFIG_TCP_MD5SIG + .req_md5_lookup = tcp_v6_md5_lookup, + .calc_md5_hash = tcp_v6_md5_hash_skb, +#endif +#ifdef CONFIG_SYN_COOKIES + .cookie_init_seq = cookie_v6_init_sequence, +#endif + .route_req = tcp_v6_route_req, + .init_seq = tcp_v6_init_seq, + .init_ts_off = tcp_v6_init_ts_off, + .send_synack = tcp_v6_send_synack, +}; + +static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32 seq, + u32 ack, u32 win, u32 tsval, u32 tsecr, + int oif, struct tcp_md5sig_key *key, int rst, + u8 tclass, __be32 label, u32 priority, u32 txhash) +{ + const struct tcphdr *th = tcp_hdr(skb); + struct tcphdr *t1; + struct sk_buff *buff; + struct flowi6 fl6; + struct net *net = sk ? sock_net(sk) : dev_net(skb_dst(skb)->dev); + struct sock *ctl_sk = net->ipv6.tcp_sk; + unsigned int tot_len = sizeof(struct tcphdr); + __be32 mrst = 0, *topt; + struct dst_entry *dst; + __u32 mark = 0; + + if (tsecr) + tot_len += TCPOLEN_TSTAMP_ALIGNED; +#ifdef CONFIG_TCP_MD5SIG + if (key) + tot_len += TCPOLEN_MD5SIG_ALIGNED; +#endif + +#ifdef CONFIG_MPTCP + if (rst && !key) { + mrst = mptcp_reset_option(skb); + + if (mrst) + tot_len += sizeof(__be32); + } +#endif + + buff = alloc_skb(MAX_TCP_HEADER, GFP_ATOMIC); + if (!buff) + return; + + skb_reserve(buff, MAX_TCP_HEADER); + + t1 = skb_push(buff, tot_len); + skb_reset_transport_header(buff); + + /* Swap the send and the receive. 
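+ * The reply TCP header is built from the incoming one with source and + * destination ports exchanged, so it follows the reverse flow.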
*/ + memset(t1, 0, sizeof(*t1)); + t1->dest = th->source; + t1->source = th->dest; + t1->doff = tot_len / 4; + t1->seq = htonl(seq); + t1->ack_seq = htonl(ack); + t1->ack = !rst || !th->ack; + t1->rst = rst; + t1->window = htons(win); + + topt = (__be32 *)(t1 + 1); + + if (tsecr) { + *topt++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) | + (TCPOPT_TIMESTAMP << 8) | TCPOLEN_TIMESTAMP); + *topt++ = htonl(tsval); + *topt++ = htonl(tsecr); + } + + if (mrst) + *topt++ = mrst; + +#ifdef CONFIG_TCP_MD5SIG + if (key) { + *topt++ = htonl((TCPOPT_NOP << 24) | (TCPOPT_NOP << 16) | + (TCPOPT_MD5SIG << 8) | TCPOLEN_MD5SIG); + tcp_v6_md5_hash_hdr((__u8 *)topt, key, + &ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, t1); + } +#endif + + memset(&fl6, 0, sizeof(fl6)); + fl6.daddr = ipv6_hdr(skb)->saddr; + fl6.saddr = ipv6_hdr(skb)->daddr; + fl6.flowlabel = label; + + buff->ip_summed = CHECKSUM_PARTIAL; + + __tcp_v6_send_check(buff, &fl6.saddr, &fl6.daddr); + + fl6.flowi6_proto = IPPROTO_TCP; + if (rt6_need_strict(&fl6.daddr) && !oif) + fl6.flowi6_oif = tcp_v6_iif(skb); + else { + if (!oif && netif_index_is_l3_master(net, skb->skb_iif)) + oif = skb->skb_iif; + + fl6.flowi6_oif = oif; + } + + if (sk) { + if (sk->sk_state == TCP_TIME_WAIT) + mark = inet_twsk(sk)->tw_mark; + else + mark = READ_ONCE(sk->sk_mark); + skb_set_delivery_time(buff, tcp_transmit_time(sk), true); + } + if (txhash) { + /* autoflowlabel/skb_get_hash_flowi6 rely on buff->hash */ + skb_set_hash(buff, txhash, PKT_HASH_TYPE_L4); + } + fl6.flowi6_mark = IP6_REPLY_MARK(net, skb->mark) ?: mark; + fl6.fl6_dport = t1->dest; + fl6.fl6_sport = t1->source; + fl6.flowi6_uid = sock_net_uid(net, sk && sk_fullsock(sk) ? sk : NULL); + security_skb_classify_flow(skb, flowi6_to_flowi_common(&fl6)); + + /* Pass a socket to ip6_dst_lookup either it is for RST + * Underlying function will use this to retrieve the network + * namespace + */ + if (sk && sk->sk_state != TCP_TIME_WAIT) + dst = ip6_dst_lookup_flow(net, sk, &fl6, NULL); /*sk's xfrm_policy can be referred*/ + else + dst = ip6_dst_lookup_flow(net, ctl_sk, &fl6, NULL); + if (!IS_ERR(dst)) { + skb_dst_set(buff, dst); + ip6_xmit(ctl_sk, buff, &fl6, fl6.flowi6_mark, NULL, + tclass & ~INET_ECN_MASK, priority); + TCP_INC_STATS(net, TCP_MIB_OUTSEGS); + if (rst) + TCP_INC_STATS(net, TCP_MIB_OUTRSTS); + return; + } + + kfree_skb(buff); +} + +static void tcp_v6_send_reset(const struct sock *sk, struct sk_buff *skb) +{ + const struct tcphdr *th = tcp_hdr(skb); + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + u32 seq = 0, ack_seq = 0; + struct tcp_md5sig_key *key = NULL; +#ifdef CONFIG_TCP_MD5SIG + const __u8 *hash_location = NULL; + unsigned char newhash[16]; + int genhash; + struct sock *sk1 = NULL; +#endif + __be32 label = 0; + u32 priority = 0; + struct net *net; + u32 txhash = 0; + int oif = 0; + + if (th->rst) + return; + + /* If sk not NULL, it means we did a successful lookup and incoming + * route had to be correct. prequeue might have dropped our dst. + */ + if (!sk && !ipv6_unicast_destination(skb)) + return; + + net = sk ? sock_net(sk) : dev_net(skb_dst(skb)->dev); +#ifdef CONFIG_TCP_MD5SIG + rcu_read_lock(); + hash_location = tcp_parse_md5sig_option(th); + if (sk && sk_fullsock(sk)) { + int l3index; + + /* sdif set, means packet ingressed via a device + * in an L3 domain and inet_iif is set to it. + */ + l3index = tcp_v6_sdif(skb) ? 
tcp_v6_iif_l3_slave(skb) : 0; + key = tcp_v6_md5_do_lookup(sk, &ipv6h->saddr, l3index); + } else if (hash_location) { + int dif = tcp_v6_iif_l3_slave(skb); + int sdif = tcp_v6_sdif(skb); + int l3index; + + /* + * active side is lost. Try to find listening socket through + * source port, and then find md5 key through listening socket. + * we are not loose security here: + * Incoming packet is checked with md5 hash with finding key, + * no RST generated if md5 hash doesn't match. + */ + sk1 = inet6_lookup_listener(net, net->ipv4.tcp_death_row.hashinfo, + NULL, 0, &ipv6h->saddr, th->source, + &ipv6h->daddr, ntohs(th->source), + dif, sdif); + if (!sk1) + goto out; + + /* sdif set, means packet ingressed via a device + * in an L3 domain and dif is set to it. + */ + l3index = tcp_v6_sdif(skb) ? dif : 0; + + key = tcp_v6_md5_do_lookup(sk1, &ipv6h->saddr, l3index); + if (!key) + goto out; + + genhash = tcp_v6_md5_hash_skb(newhash, key, NULL, skb); + if (genhash || memcmp(hash_location, newhash, 16) != 0) + goto out; + } +#endif + + if (th->ack) + seq = ntohl(th->ack_seq); + else + ack_seq = ntohl(th->seq) + th->syn + th->fin + skb->len - + (th->doff << 2); + + if (sk) { + oif = sk->sk_bound_dev_if; + if (sk_fullsock(sk)) { + const struct ipv6_pinfo *np = tcp_inet6_sk(sk); + + trace_tcp_send_reset(sk, skb); + if (np->repflow) + label = ip6_flowlabel(ipv6h); + priority = sk->sk_priority; + txhash = sk->sk_txhash; + } + if (sk->sk_state == TCP_TIME_WAIT) { + label = cpu_to_be32(inet_twsk(sk)->tw_flowlabel); + priority = inet_twsk(sk)->tw_priority; + txhash = inet_twsk(sk)->tw_txhash; + } + } else { + if (net->ipv6.sysctl.flowlabel_reflect & FLOWLABEL_REFLECT_TCP_RESET) + label = ip6_flowlabel(ipv6h); + } + + tcp_v6_send_response(sk, skb, seq, ack_seq, 0, 0, 0, oif, key, 1, + ipv6_get_dsfield(ipv6h), label, priority, txhash); + +#ifdef CONFIG_TCP_MD5SIG +out: + rcu_read_unlock(); +#endif +} + +static void tcp_v6_send_ack(const struct sock *sk, struct sk_buff *skb, u32 seq, + u32 ack, u32 win, u32 tsval, u32 tsecr, int oif, + struct tcp_md5sig_key *key, u8 tclass, + __be32 label, u32 priority, u32 txhash) +{ + tcp_v6_send_response(sk, skb, seq, ack, win, tsval, tsecr, oif, key, 0, + tclass, label, priority, txhash); +} + +static void tcp_v6_timewait_ack(struct sock *sk, struct sk_buff *skb) +{ + struct inet_timewait_sock *tw = inet_twsk(sk); + struct tcp_timewait_sock *tcptw = tcp_twsk(sk); + + tcp_v6_send_ack(sk, skb, tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt, + tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale, + tcp_time_stamp_raw() + tcptw->tw_ts_offset, + tcptw->tw_ts_recent, tw->tw_bound_dev_if, tcp_twsk_md5_key(tcptw), + tw->tw_tclass, cpu_to_be32(tw->tw_flowlabel), tw->tw_priority, + tw->tw_txhash); + + inet_twsk_put(tw); +} + +static void tcp_v6_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req) +{ + int l3index; + + l3index = tcp_v6_sdif(skb) ? tcp_v6_iif_l3_slave(skb) : 0; + + /* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV + * sk->sk_state == TCP_SYN_RECV -> for Fast Open. + */ + /* RFC 7323 2.3 + * The window field (SEG.WND) of every outgoing segment, with the + * exception of segments, MUST be right-shifted by + * Rcv.Wind.Shift bits: + */ + tcp_v6_send_ack(sk, skb, (sk->sk_state == TCP_LISTEN) ? 
+ tcp_rsk(req)->snt_isn + 1 : tcp_sk(sk)->snd_nxt, + tcp_rsk(req)->rcv_nxt, + req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale, + tcp_time_stamp_raw() + tcp_rsk(req)->ts_off, + READ_ONCE(req->ts_recent), sk->sk_bound_dev_if, + tcp_v6_md5_do_lookup(sk, &ipv6_hdr(skb)->saddr, l3index), + ipv6_get_dsfield(ipv6_hdr(skb)), 0, + READ_ONCE(sk->sk_priority), + READ_ONCE(tcp_rsk(req)->txhash)); +} + + +static struct sock *tcp_v6_cookie_check(struct sock *sk, struct sk_buff *skb) +{ +#ifdef CONFIG_SYN_COOKIES + const struct tcphdr *th = tcp_hdr(skb); + + if (!th->syn) + sk = cookie_v6_check(sk, skb); +#endif + return sk; +} + +u16 tcp_v6_get_syncookie(struct sock *sk, struct ipv6hdr *iph, + struct tcphdr *th, u32 *cookie) +{ + u16 mss = 0; +#ifdef CONFIG_SYN_COOKIES + mss = tcp_get_syncookie_mss(&tcp6_request_sock_ops, + &tcp_request_sock_ipv6_ops, sk, th); + if (mss) { + *cookie = __cookie_v6_init_sequence(iph, th, &mss); + tcp_synq_overflow(sk); + } +#endif + return mss; +} + +static int tcp_v6_conn_request(struct sock *sk, struct sk_buff *skb) +{ + if (skb->protocol == htons(ETH_P_IP)) + return tcp_v4_conn_request(sk, skb); + + if (!ipv6_unicast_destination(skb)) + goto drop; + + if (ipv6_addr_v4mapped(&ipv6_hdr(skb)->saddr)) { + __IP6_INC_STATS(sock_net(sk), NULL, IPSTATS_MIB_INHDRERRORS); + return 0; + } + + return tcp_conn_request(&tcp6_request_sock_ops, + &tcp_request_sock_ipv6_ops, sk, skb); + +drop: + tcp_listendrop(sk); + return 0; /* don't send reset */ +} + +static void tcp_v6_restore_cb(struct sk_buff *skb) +{ + /* We need to move header back to the beginning if xfrm6_policy_check() + * and tcp_v6_fill_cb() are going to be called again. + * ip6_datagram_recv_specific_ctl() also expects IP6CB to be there. + */ + memmove(IP6CB(skb), &TCP_SKB_CB(skb)->header.h6, + sizeof(struct inet6_skb_parm)); +} + +static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req) +{ + struct inet_request_sock *ireq; + struct ipv6_pinfo *newnp; + const struct ipv6_pinfo *np = tcp_inet6_sk(sk); + struct ipv6_txoptions *opt; + struct inet_sock *newinet; + bool found_dup_sk = false; + struct tcp_sock *newtp; + struct sock *newsk; +#ifdef CONFIG_TCP_MD5SIG + struct tcp_md5sig_key *key; + int l3index; +#endif + struct flowi6 fl6; + + if (skb->protocol == htons(ETH_P_IP)) { + /* + * v6 mapped + */ + + newsk = tcp_v4_syn_recv_sock(sk, skb, req, dst, + req_unhash, own_req); + + if (!newsk) + return NULL; + + inet_sk(newsk)->pinet6 = tcp_inet6_sk(newsk); + + newnp = tcp_inet6_sk(newsk); + newtp = tcp_sk(newsk); + + memcpy(newnp, np, sizeof(struct ipv6_pinfo)); + + newnp->saddr = newsk->sk_v6_rcv_saddr; + + inet_csk(newsk)->icsk_af_ops = &ipv6_mapped; + if (sk_is_mptcp(newsk)) + mptcpv6_handle_mapped(newsk, true); + newsk->sk_backlog_rcv = tcp_v4_do_rcv; +#ifdef CONFIG_TCP_MD5SIG + newtp->af_specific = &tcp_sock_ipv6_mapped_specific; +#endif + + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + newnp->ipv6_fl_list = NULL; + newnp->pktoptions = NULL; + newnp->opt = NULL; + newnp->mcast_oif = inet_iif(skb); + newnp->mcast_hops = ip_hdr(skb)->ttl; + newnp->rcv_flowinfo = 0; + if (np->repflow) + newnp->flow_label = 0; + + /* + * No need to charge this sock to the relevant IPv6 refcnt debug socks count + * here, tcp_create_openreq_child now does this for us, see the comment in + * that function for the gory details. -acme + */ + + /* It is tricky place. 
Until this moment IPv4 tcp + worked with IPv6 icsk.icsk_af_ops. + Sync it now. + */ + tcp_sync_mss(newsk, inet_csk(newsk)->icsk_pmtu_cookie); + + return newsk; + } + + ireq = inet_rsk(req); + + if (sk_acceptq_is_full(sk)) + goto out_overflow; + + if (!dst) { + dst = inet6_csk_route_req(sk, &fl6, req, IPPROTO_TCP); + if (!dst) + goto out; + } + + newsk = tcp_create_openreq_child(sk, req, skb); + if (!newsk) + goto out_nonewsk; + + /* + * No need to charge this sock to the relevant IPv6 refcnt debug socks + * count here, tcp_create_openreq_child now does this for us, see the + * comment in that function for the gory details. -acme + */ + + newsk->sk_gso_type = SKB_GSO_TCPV6; + ip6_dst_store(newsk, dst, NULL, NULL); + inet6_sk_rx_dst_set(newsk, skb); + + inet_sk(newsk)->pinet6 = tcp_inet6_sk(newsk); + + newtp = tcp_sk(newsk); + newinet = inet_sk(newsk); + newnp = tcp_inet6_sk(newsk); + + memcpy(newnp, np, sizeof(struct ipv6_pinfo)); + + newsk->sk_v6_daddr = ireq->ir_v6_rmt_addr; + newnp->saddr = ireq->ir_v6_loc_addr; + newsk->sk_v6_rcv_saddr = ireq->ir_v6_loc_addr; + newsk->sk_bound_dev_if = ireq->ir_iif; + + /* Now IPv6 options... + + First: no IPv4 options. + */ + newinet->inet_opt = NULL; + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + newnp->ipv6_fl_list = NULL; + + /* Clone RX bits */ + newnp->rxopt.all = np->rxopt.all; + + newnp->pktoptions = NULL; + newnp->opt = NULL; + newnp->mcast_oif = tcp_v6_iif(skb); + newnp->mcast_hops = ipv6_hdr(skb)->hop_limit; + newnp->rcv_flowinfo = ip6_flowinfo(ipv6_hdr(skb)); + if (np->repflow) + newnp->flow_label = ip6_flowlabel(ipv6_hdr(skb)); + + /* Set ToS of the new socket based upon the value of incoming SYN. + * ECT bits are set later in tcp_init_transfer(). + */ + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_reflect_tos)) + newnp->tclass = tcp_rsk(req)->syn_tos & ~INET_ECN_MASK; + + /* Clone native IPv6 options from listening socket (if any) + + Yes, keeping reference count would be much more clever, + but we make one more one thing there: reattach optmem + to newsk. + */ + opt = ireq->ipv6_opt; + if (!opt) + opt = rcu_dereference(np->opt); + if (opt) { + opt = ipv6_dup_options(newsk, opt); + RCU_INIT_POINTER(newnp->opt, opt); + } + inet_csk(newsk)->icsk_ext_hdr_len = 0; + if (opt) + inet_csk(newsk)->icsk_ext_hdr_len = opt->opt_nflen + + opt->opt_flen; + + tcp_ca_openreq_child(newsk, dst); + + tcp_sync_mss(newsk, dst_mtu(dst)); + newtp->advmss = tcp_mss_clamp(tcp_sk(sk), dst_metric_advmss(dst)); + + tcp_initialize_rcv_mss(newsk); + + newinet->inet_daddr = newinet->inet_saddr = LOOPBACK4_IPV6; + newinet->inet_rcv_saddr = LOOPBACK4_IPV6; + +#ifdef CONFIG_TCP_MD5SIG + l3index = l3mdev_master_ifindex_by_index(sock_net(sk), ireq->ir_iif); + + /* Copy over the MD5 key from the original socket */ + key = tcp_v6_md5_do_lookup(sk, &newsk->sk_v6_daddr, l3index); + if (key) { + /* We're using one, so create a matching key + * on the newsk structure. If we fail to get + * memory, then we end up not copying the key + * across. Shucks. 
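+	 * tcp_md5_do_add() below copies the key material, so the child
+	 * socket keeps its own key, independent of the listener.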
+ */ + tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr, + AF_INET6, 128, l3index, key->flags, key->key, key->keylen, + sk_gfp_mask(sk, GFP_ATOMIC)); + } +#endif + + if (__inet_inherit_port(sk, newsk) < 0) { + inet_csk_prepare_forced_close(newsk); + tcp_done(newsk); + goto out; + } + *own_req = inet_ehash_nolisten(newsk, req_to_sk(req_unhash), + &found_dup_sk); + if (*own_req) { + tcp_move_syn(newtp, req); + + /* Clone pktoptions received with SYN, if we own the req */ + if (ireq->pktopts) { + newnp->pktoptions = skb_clone_and_charge_r(ireq->pktopts, newsk); + consume_skb(ireq->pktopts); + ireq->pktopts = NULL; + if (newnp->pktoptions) + tcp_v6_restore_cb(newnp->pktoptions); + } + } else { + if (!req_unhash && found_dup_sk) { + /* This code path should only be executed in the + * syncookie case only + */ + bh_unlock_sock(newsk); + sock_put(newsk); + newsk = NULL; + } + } + + return newsk; + +out_overflow: + __NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); +out_nonewsk: + dst_release(dst); +out: + tcp_listendrop(sk); + return NULL; +} + +INDIRECT_CALLABLE_DECLARE(struct dst_entry *ipv4_dst_check(struct dst_entry *, + u32)); +/* The socket must have it's spinlock held when we get + * here, unless it is a TCP_LISTEN socket. + * + * We have a potential double-lock case here, so even when + * doing backlog processing we use the BH locking scheme. + * This is because we cannot sleep with the original spinlock + * held. + */ +INDIRECT_CALLABLE_SCOPE +int tcp_v6_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct ipv6_pinfo *np = tcp_inet6_sk(sk); + struct sk_buff *opt_skb = NULL; + enum skb_drop_reason reason; + struct tcp_sock *tp; + + /* Imagine: socket is IPv6. IPv4 packet arrives, + goes to IPv4 receive handler and backlogged. + From backlog it always goes here. Kerboom... + Fortunately, tcp_rcv_established and rcv_established + handle them correctly, but it is not case with + tcp_v6_hnd_req and tcp_v6_send_reset(). --ANK + */ + + if (skb->protocol == htons(ETH_P_IP)) + return tcp_v4_do_rcv(sk, skb); + + /* + * socket locking is here for SMP purposes as backlog rcv + * is currently called with bh processing disabled. + */ + + /* Do Stevens' IPV6_PKTOPTIONS. + + Yes, guys, it is the only place in our code, where we + may make it not affecting IPv4. + The rest of code is protocol independent, + and I do not like idea to uglify IPv4. + + Actually, all the idea behind IPV6_PKTOPTIONS + looks not very well thought. For now we latch + options, received in the last packet, enqueued + by tcp. Feel free to propose better solution. 
+ --ANK (980728) + */ + if (np->rxopt.all) + opt_skb = skb_clone_and_charge_r(skb, sk); + + reason = SKB_DROP_REASON_NOT_SPECIFIED; + if (sk->sk_state == TCP_ESTABLISHED) { /* Fast path */ + struct dst_entry *dst; + + dst = rcu_dereference_protected(sk->sk_rx_dst, + lockdep_sock_is_held(sk)); + + sock_rps_save_rxhash(sk, skb); + sk_mark_napi_id(sk, skb); + if (dst) { + if (sk->sk_rx_dst_ifindex != skb->skb_iif || + INDIRECT_CALL_1(dst->ops->check, ip6_dst_check, + dst, sk->sk_rx_dst_cookie) == NULL) { + RCU_INIT_POINTER(sk->sk_rx_dst, NULL); + dst_release(dst); + } + } + + tcp_rcv_established(sk, skb); + if (opt_skb) + goto ipv6_pktoptions; + return 0; + } + + if (tcp_checksum_complete(skb)) + goto csum_err; + + if (sk->sk_state == TCP_LISTEN) { + struct sock *nsk = tcp_v6_cookie_check(sk, skb); + + if (!nsk) + goto discard; + + if (nsk != sk) { + if (tcp_child_process(sk, nsk, skb)) + goto reset; + if (opt_skb) + __kfree_skb(opt_skb); + return 0; + } + } else + sock_rps_save_rxhash(sk, skb); + + if (tcp_rcv_state_process(sk, skb)) + goto reset; + if (opt_skb) + goto ipv6_pktoptions; + return 0; + +reset: + tcp_v6_send_reset(sk, skb); +discard: + if (opt_skb) + __kfree_skb(opt_skb); + kfree_skb_reason(skb, reason); + return 0; +csum_err: + reason = SKB_DROP_REASON_TCP_CSUM; + trace_tcp_bad_csum(skb); + TCP_INC_STATS(sock_net(sk), TCP_MIB_CSUMERRORS); + TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS); + goto discard; + + +ipv6_pktoptions: + /* Do you ask, what is it? + + 1. skb was enqueued by tcp. + 2. skb is added to tail of read queue, rather than out of order. + 3. socket is not in passive state. + 4. Finally, it really contains options, which user wants to receive. + */ + tp = tcp_sk(sk); + if (TCP_SKB_CB(opt_skb)->end_seq == tp->rcv_nxt && + !((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) { + if (np->rxopt.bits.rxinfo || np->rxopt.bits.rxoinfo) + np->mcast_oif = tcp_v6_iif(opt_skb); + if (np->rxopt.bits.rxhlim || np->rxopt.bits.rxohlim) + np->mcast_hops = ipv6_hdr(opt_skb)->hop_limit; + if (np->rxopt.bits.rxflow || np->rxopt.bits.rxtclass) + np->rcv_flowinfo = ip6_flowinfo(ipv6_hdr(opt_skb)); + if (np->repflow) + np->flow_label = ip6_flowlabel(ipv6_hdr(opt_skb)); + if (ipv6_opt_accepted(sk, opt_skb, &TCP_SKB_CB(opt_skb)->header.h6)) { + tcp_v6_restore_cb(opt_skb); + opt_skb = xchg(&np->pktoptions, opt_skb); + } else { + __kfree_skb(opt_skb); + opt_skb = xchg(&np->pktoptions, NULL); + } + } + + consume_skb(opt_skb); + return 0; +} + +static void tcp_v6_fill_cb(struct sk_buff *skb, const struct ipv6hdr *hdr, + const struct tcphdr *th) +{ + /* This is tricky: we move IP6CB at its correct location into + * TCP_SKB_CB(). It must be done after xfrm6_policy_check(), because + * _decode_session6() uses IP6CB(). + * barrier() makes sure compiler won't play aliasing games. 
+ */ + memmove(&TCP_SKB_CB(skb)->header.h6, IP6CB(skb), + sizeof(struct inet6_skb_parm)); + barrier(); + + TCP_SKB_CB(skb)->seq = ntohl(th->seq); + TCP_SKB_CB(skb)->end_seq = (TCP_SKB_CB(skb)->seq + th->syn + th->fin + + skb->len - th->doff*4); + TCP_SKB_CB(skb)->ack_seq = ntohl(th->ack_seq); + TCP_SKB_CB(skb)->tcp_flags = tcp_flag_byte(th); + TCP_SKB_CB(skb)->tcp_tw_isn = 0; + TCP_SKB_CB(skb)->ip_dsfield = ipv6_get_dsfield(hdr); + TCP_SKB_CB(skb)->sacked = 0; + TCP_SKB_CB(skb)->has_rxtstamp = + skb->tstamp || skb_hwtstamps(skb)->hwtstamp; +} + +INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb) +{ + enum skb_drop_reason drop_reason; + int sdif = inet6_sdif(skb); + int dif = inet6_iif(skb); + const struct tcphdr *th; + const struct ipv6hdr *hdr; + bool refcounted; + struct sock *sk; + int ret; + struct net *net = dev_net(skb->dev); + + drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + if (skb->pkt_type != PACKET_HOST) + goto discard_it; + + /* + * Count it even if it's bad. + */ + __TCP_INC_STATS(net, TCP_MIB_INSEGS); + + if (!pskb_may_pull(skb, sizeof(struct tcphdr))) + goto discard_it; + + th = (const struct tcphdr *)skb->data; + + if (unlikely(th->doff < sizeof(struct tcphdr) / 4)) { + drop_reason = SKB_DROP_REASON_PKT_TOO_SMALL; + goto bad_packet; + } + if (!pskb_may_pull(skb, th->doff*4)) + goto discard_it; + + if (skb_checksum_init(skb, IPPROTO_TCP, ip6_compute_pseudo)) + goto csum_error; + + th = (const struct tcphdr *)skb->data; + hdr = ipv6_hdr(skb); + +lookup: + sk = __inet6_lookup_skb(net->ipv4.tcp_death_row.hashinfo, skb, __tcp_hdrlen(th), + th->source, th->dest, inet6_iif(skb), sdif, + &refcounted); + if (!sk) + goto no_tcp_socket; + +process: + if (sk->sk_state == TCP_TIME_WAIT) + goto do_time_wait; + + if (sk->sk_state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + bool req_stolen = false; + struct sock *nsk; + + sk = req->rsk_listener; + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + else + drop_reason = tcp_inbound_md5_hash(sk, skb, + &hdr->saddr, &hdr->daddr, + AF_INET6, dif, sdif); + if (drop_reason) { + sk_drops_add(sk, skb); + reqsk_put(req); + goto discard_it; + } + if (tcp_checksum_complete(skb)) { + reqsk_put(req); + goto csum_error; + } + if (unlikely(sk->sk_state != TCP_LISTEN)) { + nsk = reuseport_migrate_sock(sk, req_to_sk(req), skb); + if (!nsk) { + inet_csk_reqsk_queue_drop_and_put(sk, req); + goto lookup; + } + sk = nsk; + /* reuseport_migrate_sock() has already held one sk_refcnt + * before returning. + */ + } else { + sock_hold(sk); + } + refcounted = true; + nsk = NULL; + if (!tcp_filter(sk, skb)) { + th = (const struct tcphdr *)skb->data; + hdr = ipv6_hdr(skb); + tcp_v6_fill_cb(skb, hdr, th); + nsk = tcp_check_req(sk, skb, req, false, &req_stolen); + } else { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + } + if (!nsk) { + reqsk_put(req); + if (req_stolen) { + /* Another cpu got exclusive access to req + * and created a full blown socket. + * Try to feed this packet to this socket + * instead of discarding it. 
+ */ + tcp_v6_restore_cb(skb); + sock_put(sk); + goto lookup; + } + goto discard_and_relse; + } + nf_reset_ct(skb); + if (nsk == sk) { + reqsk_put(req); + tcp_v6_restore_cb(skb); + } else if (tcp_child_process(sk, nsk, skb)) { + tcp_v6_send_reset(nsk, skb); + goto discard_and_relse; + } else { + sock_put(sk); + return 0; + } + } + + if (static_branch_unlikely(&ip6_min_hopcount)) { + /* min_hopcount can be changed concurrently from do_ipv6_setsockopt() */ + if (hdr->hop_limit < READ_ONCE(tcp_inet6_sk(sk)->min_hopcount)) { + __NET_INC_STATS(net, LINUX_MIB_TCPMINTTLDROP); + goto discard_and_relse; + } + } + + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + goto discard_and_relse; + } + + drop_reason = tcp_inbound_md5_hash(sk, skb, &hdr->saddr, &hdr->daddr, + AF_INET6, dif, sdif); + if (drop_reason) + goto discard_and_relse; + + nf_reset_ct(skb); + + if (tcp_filter(sk, skb)) { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + goto discard_and_relse; + } + th = (const struct tcphdr *)skb->data; + hdr = ipv6_hdr(skb); + tcp_v6_fill_cb(skb, hdr, th); + + skb->dev = NULL; + + if (sk->sk_state == TCP_LISTEN) { + ret = tcp_v6_do_rcv(sk, skb); + goto put_and_return; + } + + sk_incoming_cpu_update(sk); + + bh_lock_sock_nested(sk); + tcp_segs_in(tcp_sk(sk), skb); + ret = 0; + if (!sock_owned_by_user(sk)) { + ret = tcp_v6_do_rcv(sk, skb); + } else { + if (tcp_add_backlog(sk, skb, &drop_reason)) + goto discard_and_relse; + } + bh_unlock_sock(sk); +put_and_return: + if (refcounted) + sock_put(sk); + return ret ? -1 : 0; + +no_tcp_socket: + drop_reason = SKB_DROP_REASON_NO_SOCKET; + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto discard_it; + + tcp_v6_fill_cb(skb, hdr, th); + + if (tcp_checksum_complete(skb)) { +csum_error: + drop_reason = SKB_DROP_REASON_TCP_CSUM; + trace_tcp_bad_csum(skb); + __TCP_INC_STATS(net, TCP_MIB_CSUMERRORS); +bad_packet: + __TCP_INC_STATS(net, TCP_MIB_INERRS); + } else { + tcp_v6_send_reset(NULL, skb); + } + +discard_it: + SKB_DR_OR(drop_reason, NOT_SPECIFIED); + kfree_skb_reason(skb, drop_reason); + return 0; + +discard_and_relse: + sk_drops_add(sk, skb); + if (refcounted) + sock_put(sk); + goto discard_it; + +do_time_wait: + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + inet_twsk_put(inet_twsk(sk)); + goto discard_it; + } + + tcp_v6_fill_cb(skb, hdr, th); + + if (tcp_checksum_complete(skb)) { + inet_twsk_put(inet_twsk(sk)); + goto csum_error; + } + + switch (tcp_timewait_state_process(inet_twsk(sk), skb, th)) { + case TCP_TW_SYN: + { + struct sock *sk2; + + sk2 = inet6_lookup_listener(net, net->ipv4.tcp_death_row.hashinfo, + skb, __tcp_hdrlen(th), + &ipv6_hdr(skb)->saddr, th->source, + &ipv6_hdr(skb)->daddr, + ntohs(th->dest), + tcp_v6_iif_l3_slave(skb), + sdif); + if (sk2) { + struct inet_timewait_sock *tw = inet_twsk(sk); + inet_twsk_deschedule_put(tw); + sk = sk2; + tcp_v6_restore_cb(skb); + refcounted = false; + goto process; + } + } + /* to ACK */ + fallthrough; + case TCP_TW_ACK: + tcp_v6_timewait_ack(sk, skb); + break; + case TCP_TW_RST: + tcp_v6_send_reset(sk, skb); + inet_twsk_deschedule_put(inet_twsk(sk)); + goto discard_it; + case TCP_TW_SUCCESS: + ; + } + goto discard_it; +} + +void tcp_v6_early_demux(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + const struct ipv6hdr *hdr; + const struct tcphdr *th; + struct sock *sk; + + if (skb->pkt_type != PACKET_HOST) + return; + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + 
sizeof(struct tcphdr))) + return; + + hdr = ipv6_hdr(skb); + th = tcp_hdr(skb); + + if (th->doff < sizeof(struct tcphdr) / 4) + return; + + /* Note : We use inet6_iif() here, not tcp_v6_iif() */ + sk = __inet6_lookup_established(net, net->ipv4.tcp_death_row.hashinfo, + &hdr->saddr, th->source, + &hdr->daddr, ntohs(th->dest), + inet6_iif(skb), inet6_sdif(skb)); + if (sk) { + skb->sk = sk; + skb->destructor = sock_edemux; + if (sk_fullsock(sk)) { + struct dst_entry *dst = rcu_dereference(sk->sk_rx_dst); + + if (dst) + dst = dst_check(dst, sk->sk_rx_dst_cookie); + if (dst && + sk->sk_rx_dst_ifindex == skb->skb_iif) + skb_dst_set_noref(skb, dst); + } + } +} + +static struct timewait_sock_ops tcp6_timewait_sock_ops = { + .twsk_obj_size = sizeof(struct tcp6_timewait_sock), + .twsk_unique = tcp_twsk_unique, + .twsk_destructor = tcp_twsk_destructor, +}; + +INDIRECT_CALLABLE_SCOPE void tcp_v6_send_check(struct sock *sk, struct sk_buff *skb) +{ + __tcp_v6_send_check(skb, &sk->sk_v6_rcv_saddr, &sk->sk_v6_daddr); +} + +const struct inet_connection_sock_af_ops ipv6_specific = { + .queue_xmit = inet6_csk_xmit, + .send_check = tcp_v6_send_check, + .rebuild_header = inet6_sk_rebuild_header, + .sk_rx_dst_set = inet6_sk_rx_dst_set, + .conn_request = tcp_v6_conn_request, + .syn_recv_sock = tcp_v6_syn_recv_sock, + .net_header_len = sizeof(struct ipv6hdr), + .net_frag_header_len = sizeof(struct frag_hdr), + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .addr2sockaddr = inet6_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct sockaddr_in6), + .mtu_reduced = tcp_v6_mtu_reduced, +}; + +#ifdef CONFIG_TCP_MD5SIG +static const struct tcp_sock_af_ops tcp_sock_ipv6_specific = { + .md5_lookup = tcp_v6_md5_lookup, + .calc_md5_hash = tcp_v6_md5_hash_skb, + .md5_parse = tcp_v6_parse_md5_keys, +}; +#endif + +/* + * TCP over IPv4 via INET6 API + */ +static const struct inet_connection_sock_af_ops ipv6_mapped = { + .queue_xmit = ip_queue_xmit, + .send_check = tcp_v4_send_check, + .rebuild_header = inet_sk_rebuild_header, + .sk_rx_dst_set = inet_sk_rx_dst_set, + .conn_request = tcp_v6_conn_request, + .syn_recv_sock = tcp_v6_syn_recv_sock, + .net_header_len = sizeof(struct iphdr), + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .addr2sockaddr = inet6_csk_addr2sockaddr, + .sockaddr_len = sizeof(struct sockaddr_in6), + .mtu_reduced = tcp_v4_mtu_reduced, +}; + +#ifdef CONFIG_TCP_MD5SIG +static const struct tcp_sock_af_ops tcp_sock_ipv6_mapped_specific = { + .md5_lookup = tcp_v4_md5_lookup, + .calc_md5_hash = tcp_v4_md5_hash_skb, + .md5_parse = tcp_v6_parse_md5_keys, +}; +#endif + +/* NOTE: A lot of things set to zero explicitly by call to + * sk_alloc() so need not be done here. + */ +static int tcp_v6_init_sock(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_init_sock(sk); + + icsk->icsk_af_ops = &ipv6_specific; + +#ifdef CONFIG_TCP_MD5SIG + tcp_sk(sk)->af_specific = &tcp_sock_ipv6_specific; +#endif + + return 0; +} + +#ifdef CONFIG_PROC_FS +/* Proc filesystem TCPv6 sock list dumping. 
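+ * The layout matches /proc/net/tcp, with each IPv6 address printed as
+ * four 32-bit hexadecimal words.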
*/ +static void get_openreq6(struct seq_file *seq, + const struct request_sock *req, int i) +{ + long ttd = req->rsk_timer.expires - jiffies; + const struct in6_addr *src = &inet_rsk(req)->ir_v6_loc_addr; + const struct in6_addr *dest = &inet_rsk(req)->ir_v6_rmt_addr; + + if (ttd < 0) + ttd = 0; + + seq_printf(seq, + "%4d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X " + "%02X %08X:%08X %02X:%08lX %08X %5u %8d %d %d %pK\n", + i, + src->s6_addr32[0], src->s6_addr32[1], + src->s6_addr32[2], src->s6_addr32[3], + inet_rsk(req)->ir_num, + dest->s6_addr32[0], dest->s6_addr32[1], + dest->s6_addr32[2], dest->s6_addr32[3], + ntohs(inet_rsk(req)->ir_rmt_port), + TCP_SYN_RECV, + 0, 0, /* could print option size, but that is af dependent. */ + 1, /* timers active (only the expire timer) */ + jiffies_to_clock_t(ttd), + req->num_timeout, + from_kuid_munged(seq_user_ns(seq), + sock_i_uid(req->rsk_listener)), + 0, /* non standard timer */ + 0, /* open_requests have no inode */ + 0, req); +} + +static void get_tcp6_sock(struct seq_file *seq, struct sock *sp, int i) +{ + const struct in6_addr *dest, *src; + __u16 destp, srcp; + int timer_active; + unsigned long timer_expires; + const struct inet_sock *inet = inet_sk(sp); + const struct tcp_sock *tp = tcp_sk(sp); + const struct inet_connection_sock *icsk = inet_csk(sp); + const struct fastopen_queue *fastopenq = &icsk->icsk_accept_queue.fastopenq; + int rx_queue; + int state; + + dest = &sp->sk_v6_daddr; + src = &sp->sk_v6_rcv_saddr; + destp = ntohs(inet->inet_dport); + srcp = ntohs(inet->inet_sport); + + if (icsk->icsk_pending == ICSK_TIME_RETRANS || + icsk->icsk_pending == ICSK_TIME_REO_TIMEOUT || + icsk->icsk_pending == ICSK_TIME_LOSS_PROBE) { + timer_active = 1; + timer_expires = icsk->icsk_timeout; + } else if (icsk->icsk_pending == ICSK_TIME_PROBE0) { + timer_active = 4; + timer_expires = icsk->icsk_timeout; + } else if (timer_pending(&sp->sk_timer)) { + timer_active = 2; + timer_expires = sp->sk_timer.expires; + } else { + timer_active = 0; + timer_expires = jiffies; + } + + state = inet_sk_state_load(sp); + if (state == TCP_LISTEN) + rx_queue = READ_ONCE(sp->sk_ack_backlog); + else + /* Because we don't lock the socket, + * we might find a transient negative value. + */ + rx_queue = max_t(int, READ_ONCE(tp->rcv_nxt) - + READ_ONCE(tp->copied_seq), 0); + + seq_printf(seq, + "%4d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X " + "%02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %lu %lu %u %u %d\n", + i, + src->s6_addr32[0], src->s6_addr32[1], + src->s6_addr32[2], src->s6_addr32[3], srcp, + dest->s6_addr32[0], dest->s6_addr32[1], + dest->s6_addr32[2], dest->s6_addr32[3], destp, + state, + READ_ONCE(tp->write_seq) - tp->snd_una, + rx_queue, + timer_active, + jiffies_delta_to_clock_t(timer_expires - jiffies), + icsk->icsk_retransmits, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sp)), + icsk->icsk_probes_out, + sock_i_ino(sp), + refcount_read(&sp->sk_refcnt), sp, + jiffies_to_clock_t(icsk->icsk_rto), + jiffies_to_clock_t(icsk->icsk_ack.ato), + (icsk->icsk_ack.quick << 1) | inet_csk_in_pingpong_mode(sp), + tcp_snd_cwnd(tp), + state == TCP_LISTEN ? + fastopenq->max_qlen : + (tcp_in_initial_slowstart(tp) ? 
-1 : tp->snd_ssthresh) + ); +} + +static void get_timewait6_sock(struct seq_file *seq, + struct inet_timewait_sock *tw, int i) +{ + long delta = tw->tw_timer.expires - jiffies; + const struct in6_addr *dest, *src; + __u16 destp, srcp; + + dest = &tw->tw_v6_daddr; + src = &tw->tw_v6_rcv_saddr; + destp = ntohs(tw->tw_dport); + srcp = ntohs(tw->tw_sport); + + seq_printf(seq, + "%4d: %08X%08X%08X%08X:%04X %08X%08X%08X%08X:%04X " + "%02X %08X:%08X %02X:%08lX %08X %5d %8d %d %d %pK\n", + i, + src->s6_addr32[0], src->s6_addr32[1], + src->s6_addr32[2], src->s6_addr32[3], srcp, + dest->s6_addr32[0], dest->s6_addr32[1], + dest->s6_addr32[2], dest->s6_addr32[3], destp, + tw->tw_substate, 0, 0, + 3, jiffies_delta_to_clock_t(delta), 0, 0, 0, 0, + refcount_read(&tw->tw_refcnt), tw); +} + +static int tcp6_seq_show(struct seq_file *seq, void *v) +{ + struct tcp_iter_state *st; + struct sock *sk = v; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + " sl " + "local_address " + "remote_address " + "st tx_queue rx_queue tr tm->when retrnsmt" + " uid timeout inode\n"); + goto out; + } + st = seq->private; + + if (sk->sk_state == TCP_TIME_WAIT) + get_timewait6_sock(seq, v, st->num); + else if (sk->sk_state == TCP_NEW_SYN_RECV) + get_openreq6(seq, v, st->num); + else + get_tcp6_sock(seq, v, st->num); +out: + return 0; +} + +static const struct seq_operations tcp6_seq_ops = { + .show = tcp6_seq_show, + .start = tcp_seq_start, + .next = tcp_seq_next, + .stop = tcp_seq_stop, +}; + +static struct tcp_seq_afinfo tcp6_seq_afinfo = { + .family = AF_INET6, +}; + +int __net_init tcp6_proc_init(struct net *net) +{ + if (!proc_create_net_data("tcp6", 0444, net->proc_net, &tcp6_seq_ops, + sizeof(struct tcp_iter_state), &tcp6_seq_afinfo)) + return -ENOMEM; + return 0; +} + +void tcp6_proc_exit(struct net *net) +{ + remove_proc_entry("tcp6", net->proc_net); +} +#endif + +struct proto tcpv6_prot = { + .name = "TCPv6", + .owner = THIS_MODULE, + .close = tcp_close, + .pre_connect = tcp_v6_pre_connect, + .connect = tcp_v6_connect, + .disconnect = tcp_disconnect, + .accept = inet_csk_accept, + .ioctl = tcp_ioctl, + .init = tcp_v6_init_sock, + .destroy = tcp_v4_destroy_sock, + .shutdown = tcp_shutdown, + .setsockopt = tcp_setsockopt, + .getsockopt = tcp_getsockopt, + .bpf_bypass_getsockopt = tcp_bpf_bypass_getsockopt, + .keepalive = tcp_set_keepalive, + .recvmsg = tcp_recvmsg, + .sendmsg = tcp_sendmsg, + .splice_eof = tcp_splice_eof, + .sendpage = tcp_sendpage, + .backlog_rcv = tcp_v6_do_rcv, + .release_cb = tcp_release_cb, + .hash = inet6_hash, + .unhash = inet_unhash, + .get_port = inet_csk_get_port, + .put_port = inet_put_port, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = tcp_bpf_update_proto, +#endif + .enter_memory_pressure = tcp_enter_memory_pressure, + .leave_memory_pressure = tcp_leave_memory_pressure, + .stream_memory_free = tcp_stream_memory_free, + .sockets_allocated = &tcp_sockets_allocated, + + .memory_allocated = &tcp_memory_allocated, + .per_cpu_fw_alloc = &tcp_memory_per_cpu_fw_alloc, + + .memory_pressure = &tcp_memory_pressure, + .orphan_count = &tcp_orphan_count, + .sysctl_mem = sysctl_tcp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_tcp_rmem), + .max_header = MAX_TCP_HEADER, + .obj_size = sizeof(struct tcp6_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, + .twsk_prot = &tcp6_timewait_sock_ops, + .rsk_prot = &tcp6_request_sock_ops, + .h.hashinfo = NULL, + .no_autobind = true, + .diag_destroy = tcp_abort, +}; 
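The tcpv6_prot operations above, together with the tcpv6_protosw entry registered further below, are what an ordinary AF_INET6 stream socket dispatches into. A minimal user-space sketch of that path (the destination uses the 2001:db8::/32 documentation prefix and the port is a placeholder):

	#include <arpa/inet.h>
	#include <netinet/in.h>
	#include <stdio.h>
	#include <string.h>
	#include <sys/socket.h>
	#include <unistd.h>

	int main(void)
	{
		struct sockaddr_in6 dst;
		int fd;

		/* AF_INET6 + SOCK_STREAM resolves to tcpv6_protosw, hence tcpv6_prot */
		fd = socket(AF_INET6, SOCK_STREAM, 0);
		if (fd < 0) {
			perror("socket");
			return 1;
		}

		memset(&dst, 0, sizeof(dst));
		dst.sin6_family = AF_INET6;
		dst.sin6_port = htons(443);
		inet_pton(AF_INET6, "2001:db8::1", &dst.sin6_addr);

		/* connect() reaches tcp_v6_connect() through inet6_stream_ops */
		if (connect(fd, (struct sockaddr *)&dst, sizeof(dst)) < 0)
			perror("connect");

		close(fd);
		return 0;
	}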
+EXPORT_SYMBOL_GPL(tcpv6_prot); + +static const struct inet6_protocol tcpv6_protocol = { + .handler = tcp_v6_rcv, + .err_handler = tcp_v6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +static struct inet_protosw tcpv6_protosw = { + .type = SOCK_STREAM, + .protocol = IPPROTO_TCP, + .prot = &tcpv6_prot, + .ops = &inet6_stream_ops, + .flags = INET_PROTOSW_PERMANENT | + INET_PROTOSW_ICSK, +}; + +static int __net_init tcpv6_net_init(struct net *net) +{ + return inet_ctl_sock_create(&net->ipv6.tcp_sk, PF_INET6, + SOCK_RAW, IPPROTO_TCP, net); +} + +static void __net_exit tcpv6_net_exit(struct net *net) +{ + inet_ctl_sock_destroy(net->ipv6.tcp_sk); +} + +static void __net_exit tcpv6_net_exit_batch(struct list_head *net_exit_list) +{ + tcp_twsk_purge(net_exit_list, AF_INET6); +} + +static struct pernet_operations tcpv6_net_ops = { + .init = tcpv6_net_init, + .exit = tcpv6_net_exit, + .exit_batch = tcpv6_net_exit_batch, +}; + +int __init tcpv6_init(void) +{ + int ret; + + ret = inet6_add_protocol(&tcpv6_protocol, IPPROTO_TCP); + if (ret) + goto out; + + /* register inet6 protocol */ + ret = inet6_register_protosw(&tcpv6_protosw); + if (ret) + goto out_tcpv6_protocol; + + ret = register_pernet_subsys(&tcpv6_net_ops); + if (ret) + goto out_tcpv6_protosw; + + ret = mptcpv6_init(); + if (ret) + goto out_tcpv6_pernet_subsys; + +out: + return ret; + +out_tcpv6_pernet_subsys: + unregister_pernet_subsys(&tcpv6_net_ops); +out_tcpv6_protosw: + inet6_unregister_protosw(&tcpv6_protosw); +out_tcpv6_protocol: + inet6_del_protocol(&tcpv6_protocol, IPPROTO_TCP); + goto out; +} + +void tcpv6_exit(void) +{ + unregister_pernet_subsys(&tcpv6_net_ops); + inet6_unregister_protosw(&tcpv6_protosw); + inet6_del_protocol(&tcpv6_protocol, IPPROTO_TCP); +} diff --git a/net/ipv6/tcpv6_offload.c b/net/ipv6/tcpv6_offload.c new file mode 100644 index 000000000..39db5a226 --- /dev/null +++ b/net/ipv6/tcpv6_offload.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV6 GSO/GRO offload support + * Linux INET6 implementation + * + * TCPv6 GSO/GRO support + */ +#include +#include +#include +#include +#include +#include +#include "ip6_offload.h" + +INDIRECT_CALLABLE_SCOPE +struct sk_buff *tcp6_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + /* Don't bother verifying checksum if we're going to flush anyway. */ + if (!NAPI_GRO_CB(skb)->flush && + skb_gro_checksum_validate(skb, IPPROTO_TCP, + ip6_gro_compute_pseudo)) { + NAPI_GRO_CB(skb)->flush = 1; + return NULL; + } + + return tcp_gro_receive(head, skb); +} + +INDIRECT_CALLABLE_SCOPE int tcp6_gro_complete(struct sk_buff *skb, int thoff) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + struct tcphdr *th = tcp_hdr(skb); + + th->check = ~tcp_v6_check(skb->len - thoff, &iph->saddr, + &iph->daddr, 0); + skb_shinfo(skb)->gso_type |= SKB_GSO_TCPV6; + + return tcp_gro_complete(skb); +} + +static struct sk_buff *tcp6_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct tcphdr *th; + + if (!(skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6)) + return ERR_PTR(-EINVAL); + + if (!pskb_may_pull(skb, sizeof(*th))) + return ERR_PTR(-EINVAL); + + if (unlikely(skb->ip_summed != CHECKSUM_PARTIAL)) { + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + struct tcphdr *th = tcp_hdr(skb); + + /* Set up pseudo header, usually expect stack to have done + * this. 
+ */ + + th->check = 0; + skb->ip_summed = CHECKSUM_PARTIAL; + __tcp_v6_send_check(skb, &ipv6h->saddr, &ipv6h->daddr); + } + + return tcp_gso_segment(skb, features); +} +static const struct net_offload tcpv6_offload = { + .callbacks = { + .gso_segment = tcp6_gso_segment, + .gro_receive = tcp6_gro_receive, + .gro_complete = tcp6_gro_complete, + }, +}; + +int __init tcpv6_offload_init(void) +{ + return inet6_add_offload(&tcpv6_offload, IPPROTO_TCP); +} diff --git a/net/ipv6/tunnel6.c b/net/ipv6/tunnel6.c new file mode 100644 index 000000000..00e8d8b1c --- /dev/null +++ b/net/ipv6/tunnel6.c @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C)2003,2004 USAGI/WIDE Project + * + * Authors Mitsuru KANDA + * YOSHIFUJI Hideaki + */ + +#define pr_fmt(fmt) "IPv6: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct xfrm6_tunnel __rcu *tunnel6_handlers __read_mostly; +static struct xfrm6_tunnel __rcu *tunnel46_handlers __read_mostly; +static struct xfrm6_tunnel __rcu *tunnelmpls6_handlers __read_mostly; +static DEFINE_MUTEX(tunnel6_mutex); + +static inline int xfrm6_tunnel_mpls_supported(void) +{ + return IS_ENABLED(CONFIG_MPLS); +} + +int xfrm6_tunnel_register(struct xfrm6_tunnel *handler, unsigned short family) +{ + struct xfrm6_tunnel __rcu **pprev; + struct xfrm6_tunnel *t; + int ret = -EEXIST; + int priority = handler->priority; + + mutex_lock(&tunnel6_mutex); + + switch (family) { + case AF_INET6: + pprev = &tunnel6_handlers; + break; + case AF_INET: + pprev = &tunnel46_handlers; + break; + case AF_MPLS: + pprev = &tunnelmpls6_handlers; + break; + default: + goto err; + } + + for (; (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&tunnel6_mutex))) != NULL; + pprev = &t->next) { + if (t->priority > priority) + break; + if (t->priority == priority) + goto err; + } + + handler->next = *pprev; + rcu_assign_pointer(*pprev, handler); + + ret = 0; + +err: + mutex_unlock(&tunnel6_mutex); + + return ret; +} +EXPORT_SYMBOL(xfrm6_tunnel_register); + +int xfrm6_tunnel_deregister(struct xfrm6_tunnel *handler, unsigned short family) +{ + struct xfrm6_tunnel __rcu **pprev; + struct xfrm6_tunnel *t; + int ret = -ENOENT; + + mutex_lock(&tunnel6_mutex); + + switch (family) { + case AF_INET6: + pprev = &tunnel6_handlers; + break; + case AF_INET: + pprev = &tunnel46_handlers; + break; + case AF_MPLS: + pprev = &tunnelmpls6_handlers; + break; + default: + goto err; + } + + for (; (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&tunnel6_mutex))) != NULL; + pprev = &t->next) { + if (t == handler) { + *pprev = handler->next; + ret = 0; + break; + } + } + +err: + mutex_unlock(&tunnel6_mutex); + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(xfrm6_tunnel_deregister); + +#define for_each_tunnel_rcu(head, handler) \ + for (handler = rcu_dereference(head); \ + handler != NULL; \ + handler = rcu_dereference(handler->next)) \ + +static int tunnelmpls6_rcv(struct sk_buff *skb) +{ + struct xfrm6_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto drop; + + for_each_tunnel_rcu(tunnelmpls6_handlers, handler) + if (!handler->handler(skb)) + return 0; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} + +static int tunnel6_rcv(struct sk_buff *skb) +{ + struct xfrm6_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto drop; + + for_each_tunnel_rcu(tunnel6_handlers, handler) + if 
(!handler->handler(skb)) + return 0; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} + +#if IS_ENABLED(CONFIG_INET6_XFRM_TUNNEL) +static int tunnel6_rcv_cb(struct sk_buff *skb, u8 proto, int err) +{ + struct xfrm6_tunnel __rcu *head; + struct xfrm6_tunnel *handler; + int ret; + + head = (proto == IPPROTO_IPV6) ? tunnel6_handlers : tunnel46_handlers; + + for_each_tunnel_rcu(head, handler) { + if (handler->cb_handler) { + ret = handler->cb_handler(skb, err); + if (ret <= 0) + return ret; + } + } + + return 0; +} + +static const struct xfrm_input_afinfo tunnel6_input_afinfo = { + .family = AF_INET6, + .is_ipip = true, + .callback = tunnel6_rcv_cb, +}; +#endif + +static int tunnel46_rcv(struct sk_buff *skb) +{ + struct xfrm6_tunnel *handler; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto drop; + + for_each_tunnel_rcu(tunnel46_handlers, handler) + if (!handler->handler(skb)) + return 0; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} + +static int tunnel6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct xfrm6_tunnel *handler; + + for_each_tunnel_rcu(tunnel6_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static int tunnel46_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct xfrm6_tunnel *handler; + + for_each_tunnel_rcu(tunnel46_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static int tunnelmpls6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct xfrm6_tunnel *handler; + + for_each_tunnel_rcu(tunnelmpls6_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static const struct inet6_protocol tunnel6_protocol = { + .handler = tunnel6_rcv, + .err_handler = tunnel6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +static const struct inet6_protocol tunnel46_protocol = { + .handler = tunnel46_rcv, + .err_handler = tunnel46_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +static const struct inet6_protocol tunnelmpls6_protocol = { + .handler = tunnelmpls6_rcv, + .err_handler = tunnelmpls6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +static int __init tunnel6_init(void) +{ + if (inet6_add_protocol(&tunnel6_protocol, IPPROTO_IPV6)) { + pr_err("%s: can't add protocol\n", __func__); + return -EAGAIN; + } + if (inet6_add_protocol(&tunnel46_protocol, IPPROTO_IPIP)) { + pr_err("%s: can't add protocol\n", __func__); + inet6_del_protocol(&tunnel6_protocol, IPPROTO_IPV6); + return -EAGAIN; + } + if (xfrm6_tunnel_mpls_supported() && + inet6_add_protocol(&tunnelmpls6_protocol, IPPROTO_MPLS)) { + pr_err("%s: can't add protocol\n", __func__); + inet6_del_protocol(&tunnel6_protocol, IPPROTO_IPV6); + inet6_del_protocol(&tunnel46_protocol, IPPROTO_IPIP); + return -EAGAIN; + } +#if IS_ENABLED(CONFIG_INET6_XFRM_TUNNEL) + if (xfrm_input_register_afinfo(&tunnel6_input_afinfo)) { + pr_err("%s: can't add input afinfo\n", __func__); + inet6_del_protocol(&tunnel6_protocol, IPPROTO_IPV6); + inet6_del_protocol(&tunnel46_protocol, IPPROTO_IPIP); + if (xfrm6_tunnel_mpls_supported()) + inet6_del_protocol(&tunnelmpls6_protocol, IPPROTO_MPLS); + return -EAGAIN; + } 
+#endif + return 0; +} + +static void __exit tunnel6_fini(void) +{ +#if IS_ENABLED(CONFIG_INET6_XFRM_TUNNEL) + if (xfrm_input_unregister_afinfo(&tunnel6_input_afinfo)) + pr_err("%s: can't remove input afinfo\n", __func__); +#endif + if (inet6_del_protocol(&tunnel46_protocol, IPPROTO_IPIP)) + pr_err("%s: can't remove protocol\n", __func__); + if (inet6_del_protocol(&tunnel6_protocol, IPPROTO_IPV6)) + pr_err("%s: can't remove protocol\n", __func__); + if (xfrm6_tunnel_mpls_supported() && + inet6_del_protocol(&tunnelmpls6_protocol, IPPROTO_MPLS)) + pr_err("%s: can't remove protocol\n", __func__); +} + +module_init(tunnel6_init); +module_exit(tunnel6_fini); +MODULE_LICENSE("GPL"); diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c new file mode 100644 index 000000000..c2c02dea6 --- /dev/null +++ b/net/ipv6/udp.c @@ -0,0 +1,1839 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * UDP over IPv6 + * Linux INET6 implementation + * + * Authors: + * Pedro Roque + * + * Based on linux/ipv4/udp.c + * + * Fixes: + * Hideaki YOSHIFUJI : sin6_scope_id support + * YOSHIFUJI Hideaki @USAGI and: Support IPV6_V6ONLY socket option, which + * Alexey Kuznetsov allow both IPv4 and IPv6 sockets to bind + * a single port at the same time. + * Kazunori MIYAZAWA @USAGI: change process style to use ip6_append_data + * YOSHIFUJI Hideaki @USAGI: convert /proc/net/udp6 to seq_file. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include "udp_impl.h" + +static void udpv6_destruct_sock(struct sock *sk) +{ + udp_destruct_common(sk); + inet6_sock_destruct(sk); +} + +int udpv6_init_sock(struct sock *sk) +{ + skb_queue_head_init(&udp_sk(sk)->reader_queue); + sk->sk_destruct = udpv6_destruct_sock; + set_bit(SOCK_SUPPORT_ZC, &sk->sk_socket->flags); + return 0; +} + +static u32 udp6_ehashfn(const struct net *net, + const struct in6_addr *laddr, + const u16 lport, + const struct in6_addr *faddr, + const __be16 fport) +{ + static u32 udp6_ehash_secret __read_mostly; + static u32 udp_ipv6_hash_secret __read_mostly; + + u32 lhash, fhash; + + net_get_random_once(&udp6_ehash_secret, + sizeof(udp6_ehash_secret)); + net_get_random_once(&udp_ipv6_hash_secret, + sizeof(udp_ipv6_hash_secret)); + + lhash = (__force u32)laddr->s6_addr32[3]; + fhash = __ipv6_addr_jhash(faddr, udp_ipv6_hash_secret); + + return __inet6_ehashfn(lhash, lport, fhash, fport, + udp6_ehash_secret + net_hash_mix(net)); +} + +int udp_v6_get_port(struct sock *sk, unsigned short snum) +{ + unsigned int hash2_nulladdr = + ipv6_portaddr_hash(sock_net(sk), &in6addr_any, snum); + unsigned int hash2_partial = + ipv6_portaddr_hash(sock_net(sk), &sk->sk_v6_rcv_saddr, 0); + + /* precompute partial secondary hash */ + udp_sk(sk)->udp_portaddr_hash = hash2_partial; + return udp_lib_get_port(sk, snum, hash2_nulladdr); +} + +void udp_v6_rehash(struct sock *sk) +{ + u16 new_hash = ipv6_portaddr_hash(sock_net(sk), + &sk->sk_v6_rcv_saddr, + inet_sk(sk)->inet_num); + + udp_lib_rehash(sk, new_hash); +} + +static int compute_score(struct sock *sk, struct net *net, + const struct in6_addr *saddr, __be16 sport, + const struct in6_addr *daddr, unsigned short hnum, + int dif, int sdif) +{ + int bound_dev_if, score; + struct inet_sock *inet; + bool dev_match; + + if 
(!net_eq(sock_net(sk), net) || + udp_sk(sk)->udp_port_hash != hnum || + sk->sk_family != PF_INET6) + return -1; + + if (!ipv6_addr_equal(&sk->sk_v6_rcv_saddr, daddr)) + return -1; + + score = 0; + inet = inet_sk(sk); + + if (inet->inet_dport) { + if (inet->inet_dport != sport) + return -1; + score++; + } + + if (!ipv6_addr_any(&sk->sk_v6_daddr)) { + if (!ipv6_addr_equal(&sk->sk_v6_daddr, saddr)) + return -1; + score++; + } + + bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + dev_match = udp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif); + if (!dev_match) + return -1; + if (bound_dev_if) + score++; + + if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) + score++; + + return score; +} + +static struct sock *lookup_reuseport(struct net *net, struct sock *sk, + struct sk_buff *skb, + const struct in6_addr *saddr, + __be16 sport, + const struct in6_addr *daddr, + unsigned int hnum) +{ + struct sock *reuse_sk = NULL; + u32 hash; + + if (sk->sk_reuseport && sk->sk_state != TCP_ESTABLISHED) { + hash = udp6_ehashfn(net, daddr, hnum, saddr, sport); + reuse_sk = reuseport_select_sock(sk, hash, skb, + sizeof(struct udphdr)); + } + return reuse_sk; +} + +/* called with rcu_read_lock() */ +static struct sock *udp6_lib_lookup2(struct net *net, + const struct in6_addr *saddr, __be16 sport, + const struct in6_addr *daddr, unsigned int hnum, + int dif, int sdif, struct udp_hslot *hslot2, + struct sk_buff *skb) +{ + struct sock *sk, *result; + int score, badness; + + result = NULL; + badness = -1; + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { + score = compute_score(sk, net, saddr, sport, + daddr, hnum, dif, sdif); + if (score > badness) { + badness = score; + result = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); + if (!result) { + result = sk; + continue; + } + + /* Fall back to scoring if group has connections */ + if (!reuseport_has_conns(sk)) + return result; + + /* Reuseport logic returned an error, keep original score. 
*/ + if (IS_ERR(result)) + continue; + + badness = compute_score(sk, net, saddr, sport, + daddr, hnum, dif, sdif); + } + } + return result; +} + +static inline struct sock *udp6_lookup_run_bpf(struct net *net, + struct udp_table *udptable, + struct sk_buff *skb, + const struct in6_addr *saddr, + __be16 sport, + const struct in6_addr *daddr, + u16 hnum, const int dif) +{ + struct sock *sk, *reuse_sk; + bool no_reuseport; + + if (udptable != &udp_table) + return NULL; /* only UDP is supported */ + + no_reuseport = bpf_sk_lookup_run_v6(net, IPPROTO_UDP, saddr, sport, + daddr, hnum, dif, &sk); + if (no_reuseport || IS_ERR_OR_NULL(sk)) + return sk; + + reuse_sk = lookup_reuseport(net, sk, skb, saddr, sport, daddr, hnum); + if (reuse_sk) + sk = reuse_sk; + return sk; +} + +/* rcu_read_lock() must be held */ +struct sock *__udp6_lib_lookup(struct net *net, + const struct in6_addr *saddr, __be16 sport, + const struct in6_addr *daddr, __be16 dport, + int dif, int sdif, struct udp_table *udptable, + struct sk_buff *skb) +{ + unsigned short hnum = ntohs(dport); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + struct sock *result, *sk; + + hash2 = ipv6_portaddr_hash(net, daddr, hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + + /* Lookup connected or non-wildcard sockets */ + result = udp6_lib_lookup2(net, saddr, sport, + daddr, hnum, dif, sdif, + hslot2, skb); + if (!IS_ERR_OR_NULL(result) && result->sk_state == TCP_ESTABLISHED) + goto done; + + /* Lookup redirect from BPF */ + if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { + sk = udp6_lookup_run_bpf(net, udptable, skb, + saddr, sport, daddr, hnum, dif); + if (sk) { + result = sk; + goto done; + } + } + + /* Got non-wildcard socket or error on first lookup */ + if (result) + goto done; + + /* Lookup wildcard sockets */ + hash2 = ipv6_portaddr_hash(net, &in6addr_any, hnum); + slot2 = hash2 & udptable->mask; + hslot2 = &udptable->hash2[slot2]; + + result = udp6_lib_lookup2(net, saddr, sport, + &in6addr_any, hnum, dif, sdif, + hslot2, skb); +done: + if (IS_ERR(result)) + return NULL; + return result; +} +EXPORT_SYMBOL_GPL(__udp6_lib_lookup); + +static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb, + __be16 sport, __be16 dport, + struct udp_table *udptable) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + + return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, + &iph->daddr, dport, inet6_iif(skb), + inet6_sdif(skb), udptable, skb); +} + +struct sock *udp6_lib_lookup_skb(const struct sk_buff *skb, + __be16 sport, __be16 dport) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + + return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, + &iph->daddr, dport, inet6_iif(skb), + inet6_sdif(skb), &udp_table, NULL); +} + +/* Must be called under rcu_read_lock(). + * Does increment socket refcount. 
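+ * The caller must release that reference with sock_put().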
+ */ +#if IS_ENABLED(CONFIG_NF_TPROXY_IPV6) || IS_ENABLED(CONFIG_NF_SOCKET_IPV6) +struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be16 sport, + const struct in6_addr *daddr, __be16 dport, int dif) +{ + struct sock *sk; + + sk = __udp6_lib_lookup(net, saddr, sport, daddr, dport, + dif, 0, &udp_table, NULL); + if (sk && !refcount_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + return sk; +} +EXPORT_SYMBOL_GPL(udp6_lib_lookup); +#endif + +/* do not use the scratch area len for jumbogram: their length execeeds the + * scratch area space; note that the IP6CB flags is still in the first + * cacheline, so checking for jumbograms is cheap + */ +static int udp6_skb_len(struct sk_buff *skb) +{ + return unlikely(inet6_is_jumbogram(skb)) ? skb->len : udp_skb_len(skb); +} + +/* + * This should be easy, if there is something there we + * return it, otherwise we block. + */ + +int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct inet_sock *inet = inet_sk(sk); + struct sk_buff *skb; + unsigned int ulen, copied; + int off, err, peeking = flags & MSG_PEEK; + int is_udplite = IS_UDPLITE(sk); + struct udp_mib __percpu *mib; + bool checksum_valid = false; + int is_udp4; + + if (flags & MSG_ERRQUEUE) + return ipv6_recv_error(sk, msg, len, addr_len); + + if (np->rxpmtu && np->rxopt.bits.rxpmtu) + return ipv6_recv_rxpmtu(sk, msg, len, addr_len); + +try_again: + off = sk_peek_offset(sk, flags); + skb = __skb_recv_udp(sk, flags, &off, &err); + if (!skb) + return err; + + ulen = udp6_skb_len(skb); + copied = len; + if (copied > ulen - off) + copied = ulen - off; + else if (copied < ulen) + msg->msg_flags |= MSG_TRUNC; + + is_udp4 = (skb->protocol == htons(ETH_P_IP)); + mib = __UDPX_MIB(sk, is_udp4); + + /* + * If checksum is needed at all, try to do it while copying the + * data. If the data is truncated, or if we only want a partial + * coverage checksum (UDP-Lite), do it before the copy. + */ + + if (copied < ulen || peeking || + (is_udplite && UDP_SKB_CB(skb)->partial_cov)) { + checksum_valid = udp_skb_csum_unnecessary(skb) || + !__udp_lib_checksum_complete(skb); + if (!checksum_valid) + goto csum_copy_err; + } + + if (checksum_valid || udp_skb_csum_unnecessary(skb)) { + if (udp_skb_is_linear(skb)) + err = copy_linear_skb(skb, copied, off, &msg->msg_iter); + else + err = skb_copy_datagram_msg(skb, off, msg, copied); + } else { + err = skb_copy_and_csum_datagram_msg(skb, off, msg); + if (err == -EINVAL) + goto csum_copy_err; + } + if (unlikely(err)) { + if (!peeking) { + atomic_inc(&sk->sk_drops); + SNMP_INC_STATS(mib, UDP_MIB_INERRORS); + } + kfree_skb(skb); + return err; + } + if (!peeking) + SNMP_INC_STATS(mib, UDP_MIB_INDATAGRAMS); + + sock_recv_cmsgs(msg, sk, skb); + + /* Copy the address. 
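+	 * For v4-mapped traffic the IPv4 source is presented as a
+	 * ::ffff:a.b.c.d address with a zero scope id.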
*/ + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + sin6->sin6_family = AF_INET6; + sin6->sin6_port = udp_hdr(skb)->source; + sin6->sin6_flowinfo = 0; + + if (is_udp4) { + ipv6_addr_set_v4mapped(ip_hdr(skb)->saddr, + &sin6->sin6_addr); + sin6->sin6_scope_id = 0; + } else { + sin6->sin6_addr = ipv6_hdr(skb)->saddr; + sin6->sin6_scope_id = + ipv6_iface_scope_id(&sin6->sin6_addr, + inet6_iif(skb)); + } + *addr_len = sizeof(*sin6); + + BPF_CGROUP_RUN_PROG_UDP6_RECVMSG_LOCK(sk, + (struct sockaddr *)sin6); + } + + if (udp_test_bit(GRO_ENABLED, sk)) + udp_cmsg_recv(msg, sk, skb); + + if (np->rxopt.all) + ip6_datagram_recv_common_ctl(sk, msg, skb); + + if (is_udp4) { + if (inet->cmsg_flags) + ip_cmsg_recv_offset(msg, sk, skb, + sizeof(struct udphdr), off); + } else { + if (np->rxopt.all) + ip6_datagram_recv_specific_ctl(sk, msg, skb); + } + + err = copied; + if (flags & MSG_TRUNC) + err = ulen; + + skb_consume_udp(sk, skb, peeking ? -err : err); + return err; + +csum_copy_err: + if (!__sk_queue_drop_skb(sk, &udp_sk(sk)->reader_queue, skb, flags, + udp_skb_destructor)) { + SNMP_INC_STATS(mib, UDP_MIB_CSUMERRORS); + SNMP_INC_STATS(mib, UDP_MIB_INERRORS); + } + kfree_skb(skb); + + /* starting over for a new packet, but check if we need to yield */ + cond_resched(); + msg->msg_flags &= ~MSG_TRUNC; + goto try_again; +} + +DEFINE_STATIC_KEY_FALSE(udpv6_encap_needed_key); +void udpv6_encap_enable(void) +{ + static_branch_inc(&udpv6_encap_needed_key); +} +EXPORT_SYMBOL(udpv6_encap_enable); + +/* Handler for tunnels with arbitrary destination ports: no socket lookup, go + * through error handlers in encapsulations looking for a match. + */ +static int __udp6_lib_err_encap_no_sk(struct sk_buff *skb, + struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + int i; + + for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) { + int (*handler)(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info); + const struct ip6_tnl_encap_ops *encap; + + encap = rcu_dereference(ip6tun_encaps[i]); + if (!encap) + continue; + handler = encap->err_handler; + if (handler && !handler(skb, opt, type, code, offset, info)) + return 0; + } + + return -ENOENT; +} + +/* Try to match ICMP errors to UDP tunnels by looking up a socket without + * reversing source and destination port: this will match tunnels that force the + * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that + * lwtunnels might actually break this assumption by being configured with + * different destination ports on endpoints, in this case we won't be able to + * trace ICMP messages back to them. + * + * If this doesn't match any socket, probe tunnels with arbitrary destination + * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port + * we've sent packets to won't necessarily match the local destination port. + * + * Then ask the tunnel implementation to match the error against a valid + * association. + * + * Return an error if we can't find a match, the socket if we need further + * processing, zero otherwise. 
+ */ +static struct sock *__udp6_lib_err_encap(struct net *net, + const struct ipv6hdr *hdr, int offset, + struct udphdr *uh, + struct udp_table *udptable, + struct sock *sk, + struct sk_buff *skb, + struct inet6_skb_parm *opt, + u8 type, u8 code, __be32 info) +{ + int (*lookup)(struct sock *sk, struct sk_buff *skb); + int network_offset, transport_offset; + struct udp_sock *up; + + network_offset = skb_network_offset(skb); + transport_offset = skb_transport_offset(skb); + + /* Network header needs to point to the outer IPv6 header inside ICMP */ + skb_reset_network_header(skb); + + /* Transport header needs to point to the UDP header */ + skb_set_transport_header(skb, offset); + + if (sk) { + up = udp_sk(sk); + + lookup = READ_ONCE(up->encap_err_lookup); + if (lookup && lookup(sk, skb)) + sk = NULL; + + goto out; + } + + sk = __udp6_lib_lookup(net, &hdr->daddr, uh->source, + &hdr->saddr, uh->dest, + inet6_iif(skb), 0, udptable, skb); + if (sk) { + up = udp_sk(sk); + + lookup = READ_ONCE(up->encap_err_lookup); + if (!lookup || lookup(sk, skb)) + sk = NULL; + } + +out: + if (!sk) { + sk = ERR_PTR(__udp6_lib_err_encap_no_sk(skb, opt, type, code, + offset, info)); + } + + skb_set_transport_header(skb, transport_offset); + skb_set_network_header(skb, network_offset); + + return sk; +} + +int __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info, + struct udp_table *udptable) +{ + struct ipv6_pinfo *np; + const struct ipv6hdr *hdr = (const struct ipv6hdr *)skb->data; + const struct in6_addr *saddr = &hdr->saddr; + const struct in6_addr *daddr = seg6_get_daddr(skb, opt) ? : &hdr->daddr; + struct udphdr *uh = (struct udphdr *)(skb->data+offset); + bool tunnel = false; + struct sock *sk; + int harderr; + int err; + struct net *net = dev_net(skb->dev); + + sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source, + inet6_iif(skb), inet6_sdif(skb), udptable, NULL); + + if (!sk || READ_ONCE(udp_sk(sk)->encap_type)) { + /* No socket for error: try tunnels before discarding */ + if (static_branch_unlikely(&udpv6_encap_needed_key)) { + sk = __udp6_lib_err_encap(net, hdr, offset, uh, + udptable, sk, skb, + opt, type, code, info); + if (!sk) + return 0; + } else + sk = ERR_PTR(-ENOENT); + + if (IS_ERR(sk)) { + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), + ICMP6_MIB_INERRORS); + return PTR_ERR(sk); + } + + tunnel = true; + } + + harderr = icmpv6_err_convert(type, code, &err); + np = inet6_sk(sk); + + if (type == ICMPV6_PKT_TOOBIG) { + if (!ip6_sk_accept_pmtu(sk)) + goto out; + ip6_sk_update_pmtu(skb, sk, info); + if (np->pmtudisc != IPV6_PMTUDISC_DONT) + harderr = 1; + } + if (type == NDISC_REDIRECT) { + if (tunnel) { + ip6_redirect(skb, sock_net(sk), inet6_iif(skb), + READ_ONCE(sk->sk_mark), sk->sk_uid); + } else { + ip6_sk_redirect(skb, sk); + } + goto out; + } + + /* Tunnels don't have an application socket: don't pass errors back */ + if (tunnel) { + if (udp_sk(sk)->encap_err_rcv) + udp_sk(sk)->encap_err_rcv(sk, skb, offset); + goto out; + } + + if (!np->recverr) { + if (!harderr || sk->sk_state != TCP_ESTABLISHED) + goto out; + } else { + ipv6_icmp_error(sk, skb, err, uh->dest, ntohl(info), (u8 *)(uh+1)); + } + + sk->sk_err = err; + sk_error_report(sk); +out: + return 0; +} + +static int __udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int rc; + + if (!ipv6_addr_any(&sk->sk_v6_daddr)) { + sock_rps_save_rxhash(sk, skb); + sk_mark_napi_id(sk, skb); + sk_incoming_cpu_update(sk); + } else { + sk_mark_napi_id_once(sk, skb); + } 
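+	/* Charge the skb to the receive queue; failures are accounted as
+	 * RCVBUFERRORS or MEMERRORS below.
+	 */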
+ + rc = __udp_enqueue_schedule_skb(sk, skb); + if (rc < 0) { + int is_udplite = IS_UDPLITE(sk); + enum skb_drop_reason drop_reason; + + /* Note that an ENOMEM error is charged twice */ + if (rc == -ENOMEM) { + UDP6_INC_STATS(sock_net(sk), + UDP_MIB_RCVBUFERRORS, is_udplite); + drop_reason = SKB_DROP_REASON_SOCKET_RCVBUFF; + } else { + UDP6_INC_STATS(sock_net(sk), + UDP_MIB_MEMERRORS, is_udplite); + drop_reason = SKB_DROP_REASON_PROTO_MEM; + } + UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); + kfree_skb_reason(skb, drop_reason); + return -1; + } + + return 0; +} + +static __inline__ int udpv6_err(struct sk_buff *skb, + struct inet6_skb_parm *opt, u8 type, + u8 code, int offset, __be32 info) +{ + return __udp6_lib_err(skb, opt, type, code, offset, info, &udp_table); +} + +static int udpv6_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb) +{ + enum skb_drop_reason drop_reason = SKB_DROP_REASON_NOT_SPECIFIED; + struct udp_sock *up = udp_sk(sk); + int is_udplite = IS_UDPLITE(sk); + + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) { + drop_reason = SKB_DROP_REASON_XFRM_POLICY; + goto drop; + } + nf_reset_ct(skb); + + if (static_branch_unlikely(&udpv6_encap_needed_key) && + READ_ONCE(up->encap_type)) { + int (*encap_rcv)(struct sock *sk, struct sk_buff *skb); + + /* + * This is an encapsulation socket so pass the skb to + * the socket's udp_encap_rcv() hook. Otherwise, just + * fall through and pass this up the UDP socket. + * up->encap_rcv() returns the following value: + * =0 if skb was successfully passed to the encap + * handler or was discarded by it. + * >0 if skb should be passed on to UDP. + * <0 if skb should be resubmitted as proto -N + */ + + /* if we're overly short, let UDP handle it */ + encap_rcv = READ_ONCE(up->encap_rcv); + if (encap_rcv) { + int ret; + + /* Verify checksum before giving to encap */ + if (udp_lib_checksum_complete(skb)) + goto csum_error; + + ret = encap_rcv(sk, skb); + if (ret <= 0) { + __UDP6_INC_STATS(sock_net(sk), + UDP_MIB_INDATAGRAMS, + is_udplite); + return -ret; + } + } + + /* FALLTHROUGH -- it's a UDP Packet */ + } + + /* + * UDP-Lite specific tests, ignored on UDP sockets (see net/ipv4/udp.c). 
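+	 * UDP-Lite (RFC 3828) may checksum only a prefix of the payload;
+	 * packets whose coverage (cscov) is below the receiver's minimum
+	 * (pcrlen) are dropped.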
+ */ + if ((up->pcflag & UDPLITE_RECV_CC) && UDP_SKB_CB(skb)->partial_cov) { + + if (up->pcrlen == 0) { /* full coverage was set */ + net_dbg_ratelimited("UDPLITE6: partial coverage %d while full coverage %d requested\n", + UDP_SKB_CB(skb)->cscov, skb->len); + goto drop; + } + if (UDP_SKB_CB(skb)->cscov < up->pcrlen) { + net_dbg_ratelimited("UDPLITE6: coverage %d too small, need min %d\n", + UDP_SKB_CB(skb)->cscov, up->pcrlen); + goto drop; + } + } + + prefetch(&sk->sk_rmem_alloc); + if (rcu_access_pointer(sk->sk_filter) && + udp_lib_checksum_complete(skb)) + goto csum_error; + + if (sk_filter_trim_cap(sk, skb, sizeof(struct udphdr))) { + drop_reason = SKB_DROP_REASON_SOCKET_FILTER; + goto drop; + } + + udp_csum_pull_header(skb); + + skb_dst_drop(skb); + + return __udpv6_queue_rcv_skb(sk, skb); + +csum_error: + drop_reason = SKB_DROP_REASON_UDP_CSUM; + __UDP6_INC_STATS(sock_net(sk), UDP_MIB_CSUMERRORS, is_udplite); +drop: + __UDP6_INC_STATS(sock_net(sk), UDP_MIB_INERRORS, is_udplite); + atomic_inc(&sk->sk_drops); + kfree_skb_reason(skb, drop_reason); + return -1; +} + +static int udpv6_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *next, *segs; + int ret; + + if (likely(!udp_unexpected_gso(sk, skb))) + return udpv6_queue_rcv_one_skb(sk, skb); + + __skb_push(skb, -skb_mac_offset(skb)); + segs = udp_rcv_segment(sk, skb, false); + skb_list_walk_safe(segs, skb, next) { + __skb_pull(skb, skb_transport_offset(skb)); + + udp_post_segment_fix_csum(skb); + ret = udpv6_queue_rcv_one_skb(sk, skb); + if (ret > 0) + ip6_protocol_deliver_rcu(dev_net(skb->dev), skb, ret, + true); + } + return 0; +} + +static bool __udp_v6_is_mcast_sock(struct net *net, struct sock *sk, + __be16 loc_port, const struct in6_addr *loc_addr, + __be16 rmt_port, const struct in6_addr *rmt_addr, + int dif, int sdif, unsigned short hnum) +{ + struct inet_sock *inet = inet_sk(sk); + + if (!net_eq(sock_net(sk), net)) + return false; + + if (udp_sk(sk)->udp_port_hash != hnum || + sk->sk_family != PF_INET6 || + (inet->inet_dport && inet->inet_dport != rmt_port) || + (!ipv6_addr_any(&sk->sk_v6_daddr) && + !ipv6_addr_equal(&sk->sk_v6_daddr, rmt_addr)) || + !udp_sk_bound_dev_eq(net, READ_ONCE(sk->sk_bound_dev_if), dif, sdif) || + (!ipv6_addr_any(&sk->sk_v6_rcv_saddr) && + !ipv6_addr_equal(&sk->sk_v6_rcv_saddr, loc_addr))) + return false; + if (!inet6_mc_check(sk, loc_addr, rmt_addr)) + return false; + return true; +} + +static void udp6_csum_zero_error(struct sk_buff *skb) +{ + /* RFC 2460 section 8.1 says that we SHOULD log + * this error. Well, it is reasonable. + */ + net_dbg_ratelimited("IPv6: udp checksum is 0 for [%pI6c]:%u->[%pI6c]:%u\n", + &ipv6_hdr(skb)->saddr, ntohs(udp_hdr(skb)->source), + &ipv6_hdr(skb)->daddr, ntohs(udp_hdr(skb)->dest)); +} + +/* + * Note: called only from the BH handler context, + * so we don't need to lock the hashes. 
+ */ +static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb, + const struct in6_addr *saddr, const struct in6_addr *daddr, + struct udp_table *udptable, int proto) +{ + struct sock *sk, *first = NULL; + const struct udphdr *uh = udp_hdr(skb); + unsigned short hnum = ntohs(uh->dest); + struct udp_hslot *hslot = udp_hashslot(udptable, net, hnum); + unsigned int offset = offsetof(typeof(*sk), sk_node); + unsigned int hash2 = 0, hash2_any = 0, use_hash2 = (hslot->count > 10); + int dif = inet6_iif(skb); + int sdif = inet6_sdif(skb); + struct hlist_node *node; + struct sk_buff *nskb; + + if (use_hash2) { + hash2_any = ipv6_portaddr_hash(net, &in6addr_any, hnum) & + udptable->mask; + hash2 = ipv6_portaddr_hash(net, daddr, hnum) & udptable->mask; +start_lookup: + hslot = &udptable->hash2[hash2]; + offset = offsetof(typeof(*sk), __sk_common.skc_portaddr_node); + } + + sk_for_each_entry_offset_rcu(sk, node, &hslot->head, offset) { + if (!__udp_v6_is_mcast_sock(net, sk, uh->dest, daddr, + uh->source, saddr, dif, sdif, + hnum)) + continue; + /* If zero checksum and no_check is not on for + * the socket then skip it. + */ + if (!uh->check && !udp_get_no_check6_rx(sk)) + continue; + if (!first) { + first = sk; + continue; + } + nskb = skb_clone(skb, GFP_ATOMIC); + if (unlikely(!nskb)) { + atomic_inc(&sk->sk_drops); + __UDP6_INC_STATS(net, UDP_MIB_RCVBUFERRORS, + IS_UDPLITE(sk)); + __UDP6_INC_STATS(net, UDP_MIB_INERRORS, + IS_UDPLITE(sk)); + continue; + } + + if (udpv6_queue_rcv_skb(sk, nskb) > 0) + consume_skb(nskb); + } + + /* Also lookup *:port if we are using hash2 and haven't done so yet. */ + if (use_hash2 && hash2 != hash2_any) { + hash2 = hash2_any; + goto start_lookup; + } + + if (first) { + if (udpv6_queue_rcv_skb(first, skb) > 0) + consume_skb(skb); + } else { + kfree_skb(skb); + __UDP6_INC_STATS(net, UDP_MIB_IGNOREDMULTI, + proto == IPPROTO_UDPLITE); + } + return 0; +} + +static void udp6_sk_rx_dst_set(struct sock *sk, struct dst_entry *dst) +{ + if (udp_sk_rx_dst_set(sk, dst)) { + const struct rt6_info *rt = (const struct rt6_info *)dst; + + sk->sk_rx_dst_cookie = rt6_get_cookie(rt); + } +} + +/* wrapper for udp_queue_rcv_skb tacking care of csum conversion and + * return code conversion for ip layer consumption + */ +static int udp6_unicast_rcv_skb(struct sock *sk, struct sk_buff *skb, + struct udphdr *uh) +{ + int ret; + + if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk)) + skb_checksum_try_convert(skb, IPPROTO_UDP, ip6_compute_pseudo); + + ret = udpv6_queue_rcv_skb(sk, skb); + + /* a return value > 0 means to resubmit the input */ + if (ret > 0) + return ret; + return 0; +} + +int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable, + int proto) +{ + enum skb_drop_reason reason = SKB_DROP_REASON_NOT_SPECIFIED; + const struct in6_addr *saddr, *daddr; + struct net *net = dev_net(skb->dev); + struct udphdr *uh; + struct sock *sk; + bool refcounted; + u32 ulen = 0; + + if (!pskb_may_pull(skb, sizeof(struct udphdr))) + goto discard; + + saddr = &ipv6_hdr(skb)->saddr; + daddr = &ipv6_hdr(skb)->daddr; + uh = udp_hdr(skb); + + ulen = ntohs(uh->len); + if (ulen > skb->len) + goto short_packet; + + if (proto == IPPROTO_UDP) { + /* UDP validates ulen. 
*/ + + /* Check for jumbo payload */ + if (ulen == 0) + ulen = skb->len; + + if (ulen < sizeof(*uh)) + goto short_packet; + + if (ulen < skb->len) { + if (pskb_trim_rcsum(skb, ulen)) + goto short_packet; + saddr = &ipv6_hdr(skb)->saddr; + daddr = &ipv6_hdr(skb)->daddr; + uh = udp_hdr(skb); + } + } + + if (udp6_csum_init(skb, uh, proto)) + goto csum_error; + + /* Check if the socket is already available, e.g. due to early demux */ + sk = skb_steal_sock(skb, &refcounted); + if (sk) { + struct dst_entry *dst = skb_dst(skb); + int ret; + + if (unlikely(rcu_dereference(sk->sk_rx_dst) != dst)) + udp6_sk_rx_dst_set(sk, dst); + + if (!uh->check && !udp_get_no_check6_rx(sk)) { + if (refcounted) + sock_put(sk); + goto report_csum_error; + } + + ret = udp6_unicast_rcv_skb(sk, skb, uh); + if (refcounted) + sock_put(sk); + return ret; + } + + /* + * Multicast receive code + */ + if (ipv6_addr_is_multicast(daddr)) + return __udp6_lib_mcast_deliver(net, skb, + saddr, daddr, udptable, proto); + + /* Unicast */ + sk = __udp6_lib_lookup_skb(skb, uh->source, uh->dest, udptable); + if (sk) { + if (!uh->check && !udp_get_no_check6_rx(sk)) + goto report_csum_error; + return udp6_unicast_rcv_skb(sk, skb, uh); + } + + reason = SKB_DROP_REASON_NO_SOCKET; + + if (!uh->check) + goto report_csum_error; + + if (!xfrm6_policy_check(NULL, XFRM_POLICY_IN, skb)) + goto discard; + nf_reset_ct(skb); + + if (udp_lib_checksum_complete(skb)) + goto csum_error; + + __UDP6_INC_STATS(net, UDP_MIB_NOPORTS, proto == IPPROTO_UDPLITE); + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + + kfree_skb_reason(skb, reason); + return 0; + +short_packet: + if (reason == SKB_DROP_REASON_NOT_SPECIFIED) + reason = SKB_DROP_REASON_PKT_TOO_SMALL; + net_dbg_ratelimited("UDP%sv6: short packet: From [%pI6c]:%u %d/%d to [%pI6c]:%u\n", + proto == IPPROTO_UDPLITE ? 
"-Lite" : "", + saddr, ntohs(uh->source), + ulen, skb->len, + daddr, ntohs(uh->dest)); + goto discard; + +report_csum_error: + udp6_csum_zero_error(skb); +csum_error: + if (reason == SKB_DROP_REASON_NOT_SPECIFIED) + reason = SKB_DROP_REASON_UDP_CSUM; + __UDP6_INC_STATS(net, UDP_MIB_CSUMERRORS, proto == IPPROTO_UDPLITE); +discard: + __UDP6_INC_STATS(net, UDP_MIB_INERRORS, proto == IPPROTO_UDPLITE); + kfree_skb_reason(skb, reason); + return 0; +} + + +static struct sock *__udp6_lib_demux_lookup(struct net *net, + __be16 loc_port, const struct in6_addr *loc_addr, + __be16 rmt_port, const struct in6_addr *rmt_addr, + int dif, int sdif) +{ + unsigned short hnum = ntohs(loc_port); + unsigned int hash2, slot2; + struct udp_hslot *hslot2; + __portpair ports; + struct sock *sk; + + hash2 = ipv6_portaddr_hash(net, loc_addr, hnum); + slot2 = hash2 & udp_table.mask; + hslot2 = &udp_table.hash2[slot2]; + ports = INET_COMBINED_PORTS(rmt_port, hnum); + + udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) { + if (sk->sk_state == TCP_ESTABLISHED && + inet6_match(net, sk, rmt_addr, loc_addr, ports, dif, sdif)) + return sk; + /* Only check first socket in chain */ + break; + } + return NULL; +} + +void udp_v6_early_demux(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + const struct udphdr *uh; + struct sock *sk; + struct dst_entry *dst; + int dif = skb->dev->ifindex; + int sdif = inet6_sdif(skb); + + if (!pskb_may_pull(skb, skb_transport_offset(skb) + + sizeof(struct udphdr))) + return; + + uh = udp_hdr(skb); + + if (skb->pkt_type == PACKET_HOST) + sk = __udp6_lib_demux_lookup(net, uh->dest, + &ipv6_hdr(skb)->daddr, + uh->source, &ipv6_hdr(skb)->saddr, + dif, sdif); + else + return; + + if (!sk || !refcount_inc_not_zero(&sk->sk_refcnt)) + return; + + skb->sk = sk; + skb->destructor = sock_efree; + dst = rcu_dereference(sk->sk_rx_dst); + + if (dst) + dst = dst_check(dst, sk->sk_rx_dst_cookie); + if (dst) { + /* set noref for now. + * any place which wants to hold dst has to call + * dst_hold_safe() + */ + skb_dst_set_noref(skb, dst); + } +} + +INDIRECT_CALLABLE_SCOPE int udpv6_rcv(struct sk_buff *skb) +{ + return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP); +} + +/* + * Throw away all pending data and cancel the corking. Socket is locked. + */ +static void udp_v6_flush_pending_frames(struct sock *sk) +{ + struct udp_sock *up = udp_sk(sk); + + if (up->pending == AF_INET) + udp_flush_pending_frames(sk); + else if (up->pending) { + up->len = 0; + WRITE_ONCE(up->pending, 0); + ip6_flush_pending_frames(sk); + } +} + +static int udpv6_pre_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + if (addr_len < offsetofend(struct sockaddr, sa_family)) + return -EINVAL; + /* The following checks are replicated from __ip6_datagram_connect() + * and intended to prevent BPF program called below from accessing + * bytes that are out of the bound specified by user in addr_len. 
+ */ + if (uaddr->sa_family == AF_INET) { + if (ipv6_only_sock(sk)) + return -EAFNOSUPPORT; + return udp_pre_connect(sk, uaddr, addr_len); + } + + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + return BPF_CGROUP_RUN_PROG_INET6_CONNECT_LOCK(sk, uaddr); +} + +/** + * udp6_hwcsum_outgoing - handle outgoing HW checksumming + * @sk: socket we are sending on + * @skb: sk_buff containing the filled-in UDP header + * (checksum field must be zeroed out) + * @saddr: source address + * @daddr: destination address + * @len: length of packet + */ +static void udp6_hwcsum_outgoing(struct sock *sk, struct sk_buff *skb, + const struct in6_addr *saddr, + const struct in6_addr *daddr, int len) +{ + unsigned int offset; + struct udphdr *uh = udp_hdr(skb); + struct sk_buff *frags = skb_shinfo(skb)->frag_list; + __wsum csum = 0; + + if (!frags) { + /* Only one fragment on the socket. */ + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct udphdr, check); + uh->check = ~csum_ipv6_magic(saddr, daddr, len, IPPROTO_UDP, 0); + } else { + /* + * HW-checksum won't work as there are two or more + * fragments on the socket so that all csums of sk_buffs + * should be together + */ + offset = skb_transport_offset(skb); + skb->csum = skb_checksum(skb, offset, skb->len - offset, 0); + csum = skb->csum; + + skb->ip_summed = CHECKSUM_NONE; + + do { + csum = csum_add(csum, frags->csum); + } while ((frags = frags->next)); + + uh->check = csum_ipv6_magic(saddr, daddr, len, IPPROTO_UDP, + csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + } +} + +/* + * Sending + */ + +static int udp_v6_send_skb(struct sk_buff *skb, struct flowi6 *fl6, + struct inet_cork *cork) +{ + struct sock *sk = skb->sk; + struct udphdr *uh; + int err = 0; + int is_udplite = IS_UDPLITE(sk); + __wsum csum = 0; + int offset = skb_transport_offset(skb); + int len = skb->len - offset; + int datalen = len - sizeof(*uh); + + /* + * Create a UDP header + */ + uh = udp_hdr(skb); + uh->source = fl6->fl6_sport; + uh->dest = fl6->fl6_dport; + uh->len = htons(len); + uh->check = 0; + + if (cork->gso_size) { + const int hlen = skb_network_header_len(skb) + + sizeof(struct udphdr); + + if (hlen + cork->gso_size > cork->fragsize) { + kfree_skb(skb); + return -EINVAL; + } + if (datalen > cork->gso_size * UDP_MAX_SEGMENTS) { + kfree_skb(skb); + return -EINVAL; + } + if (udp_get_no_check6_tx(sk)) { + kfree_skb(skb); + return -EINVAL; + } + if (skb->ip_summed != CHECKSUM_PARTIAL || is_udplite || + dst_xfrm(skb_dst(skb))) { + kfree_skb(skb); + return -EIO; + } + + if (datalen > cork->gso_size) { + skb_shinfo(skb)->gso_size = cork->gso_size; + skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4; + skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen, + cork->gso_size); + } + goto csum_partial; + } + + if (is_udplite) + csum = udplite_csum(skb); + else if (udp_get_no_check6_tx(sk)) { /* UDP csum disabled */ + skb->ip_summed = CHECKSUM_NONE; + goto send; + } else if (skb->ip_summed == CHECKSUM_PARTIAL) { /* UDP hardware csum */ +csum_partial: + udp6_hwcsum_outgoing(sk, skb, &fl6->saddr, &fl6->daddr, len); + goto send; + } else + csum = udp_csum(skb); + + /* add protocol-dependent pseudo-header */ + uh->check = csum_ipv6_magic(&fl6->saddr, &fl6->daddr, + len, fl6->flowi6_proto, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + +send: + err = ip6_send_skb(skb); + if (err) { + if (err == -ENOBUFS && !inet6_sk(sk)->recverr) { + UDP6_INC_STATS(sock_net(sk), + UDP_MIB_SNDBUFERRORS, is_udplite); + err = 0; + } + } else { + 
UDP6_INC_STATS(sock_net(sk), + UDP_MIB_OUTDATAGRAMS, is_udplite); + } + return err; +} + +static int udp_v6_push_pending_frames(struct sock *sk) +{ + struct sk_buff *skb; + struct udp_sock *up = udp_sk(sk); + int err = 0; + + if (up->pending == AF_INET) + return udp_push_pending_frames(sk); + + skb = ip6_finish_skb(sk); + if (!skb) + goto out; + + err = udp_v6_send_skb(skb, &inet_sk(sk)->cork.fl.u.ip6, + &inet_sk(sk)->cork.base); +out: + up->len = 0; + WRITE_ONCE(up->pending, 0); + return err; +} + +int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct ipv6_txoptions opt_space; + struct udp_sock *up = udp_sk(sk); + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + struct in6_addr *daddr, *final_p, final; + struct ipv6_txoptions *opt = NULL; + struct ipv6_txoptions *opt_to_free = NULL; + struct ip6_flowlabel *flowlabel = NULL; + struct inet_cork_full cork; + struct flowi6 *fl6 = &cork.fl.u.ip6; + struct dst_entry *dst; + struct ipcm6_cookie ipc6; + int addr_len = msg->msg_namelen; + bool connected = false; + int ulen = len; + int corkreq = udp_test_bit(CORK, sk) || msg->msg_flags & MSG_MORE; + int err; + int is_udplite = IS_UDPLITE(sk); + int (*getfrag)(void *, char *, int, int, int, struct sk_buff *); + + ipcm6_init(&ipc6); + ipc6.gso_size = READ_ONCE(up->gso_size); + ipc6.sockc.tsflags = READ_ONCE(sk->sk_tsflags); + ipc6.sockc.mark = READ_ONCE(sk->sk_mark); + + /* destination address check */ + if (sin6) { + if (addr_len < offsetof(struct sockaddr, sa_data)) + return -EINVAL; + + switch (sin6->sin6_family) { + case AF_INET6: + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + daddr = &sin6->sin6_addr; + if (ipv6_addr_any(daddr) && + ipv6_addr_v4mapped(&np->saddr)) + ipv6_addr_set_v4mapped(htonl(INADDR_LOOPBACK), + daddr); + break; + case AF_INET: + goto do_udp_sendmsg; + case AF_UNSPEC: + msg->msg_name = sin6 = NULL; + msg->msg_namelen = addr_len = 0; + daddr = NULL; + break; + default: + return -EINVAL; + } + } else if (!READ_ONCE(up->pending)) { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + daddr = &sk->sk_v6_daddr; + } else + daddr = NULL; + + if (daddr) { + if (ipv6_addr_v4mapped(daddr)) { + struct sockaddr_in sin; + sin.sin_family = AF_INET; + sin.sin_port = sin6 ? sin6->sin6_port : inet->inet_dport; + sin.sin_addr.s_addr = daddr->s6_addr32[3]; + msg->msg_name = &sin; + msg->msg_namelen = sizeof(sin); +do_udp_sendmsg: + err = ipv6_only_sock(sk) ? + -ENETUNREACH : udp_sendmsg(sk, msg, len); + msg->msg_name = sin6; + msg->msg_namelen = addr_len; + return err; + } + } + + /* Rough check on arithmetic overflow, + better check is made in ip6_append_data(). + */ + if (len > INT_MAX - sizeof(struct udphdr)) + return -EMSGSIZE; + + getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag; + if (READ_ONCE(up->pending)) { + if (READ_ONCE(up->pending) == AF_INET) + return udp_sendmsg(sk, msg, len); + /* + * There are pending frames. + * The socket lock must be held while it's corked. 
+ */ + lock_sock(sk); + if (likely(up->pending)) { + if (unlikely(up->pending != AF_INET6)) { + release_sock(sk); + return -EAFNOSUPPORT; + } + dst = NULL; + goto do_append_data; + } + release_sock(sk); + } + ulen += sizeof(struct udphdr); + + memset(fl6, 0, sizeof(*fl6)); + + if (sin6) { + if (sin6->sin6_port == 0) + return -EINVAL; + + fl6->fl6_dport = sin6->sin6_port; + daddr = &sin6->sin6_addr; + + if (np->sndflow) { + fl6->flowlabel = sin6->sin6_flowinfo&IPV6_FLOWINFO_MASK; + if (fl6->flowlabel & IPV6_FLOWLABEL_MASK) { + flowlabel = fl6_sock_lookup(sk, fl6->flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + } + } + + /* + * Otherwise it will be difficult to maintain + * sk->sk_dst_cache. + */ + if (sk->sk_state == TCP_ESTABLISHED && + ipv6_addr_equal(daddr, &sk->sk_v6_daddr)) + daddr = &sk->sk_v6_daddr; + + if (addr_len >= sizeof(struct sockaddr_in6) && + sin6->sin6_scope_id && + __ipv6_addr_needs_scope_id(__ipv6_addr_type(daddr))) + fl6->flowi6_oif = sin6->sin6_scope_id; + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -EDESTADDRREQ; + + fl6->fl6_dport = inet->inet_dport; + daddr = &sk->sk_v6_daddr; + fl6->flowlabel = np->flow_label; + connected = true; + } + + if (!fl6->flowi6_oif) + fl6->flowi6_oif = READ_ONCE(sk->sk_bound_dev_if); + + if (!fl6->flowi6_oif) + fl6->flowi6_oif = np->sticky_pktinfo.ipi6_ifindex; + + fl6->flowi6_uid = sk->sk_uid; + + if (msg->msg_controllen) { + opt = &opt_space; + memset(opt, 0, sizeof(struct ipv6_txoptions)); + opt->tot_len = sizeof(*opt); + ipc6.opt = opt; + + err = udp_cmsg_send(sk, msg, &ipc6.gso_size); + if (err > 0) + err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, fl6, + &ipc6); + if (err < 0) { + fl6_sock_release(flowlabel); + return err; + } + if ((fl6->flowlabel&IPV6_FLOWLABEL_MASK) && !flowlabel) { + flowlabel = fl6_sock_lookup(sk, fl6->flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + } + if (!(opt->opt_nflen|opt->opt_flen)) + opt = NULL; + connected = false; + } + if (!opt) { + opt = txopt_get(np); + opt_to_free = opt; + } + if (flowlabel) + opt = fl6_merge_options(&opt_space, flowlabel, opt); + opt = ipv6_fixup_options(&opt_space, opt); + ipc6.opt = opt; + + fl6->flowi6_proto = sk->sk_protocol; + fl6->flowi6_mark = ipc6.sockc.mark; + fl6->daddr = *daddr; + if (ipv6_addr_any(&fl6->saddr) && !ipv6_addr_any(&np->saddr)) + fl6->saddr = np->saddr; + fl6->fl6_sport = inet->inet_sport; + + if (cgroup_bpf_enabled(CGROUP_UDP6_SENDMSG) && !connected) { + err = BPF_CGROUP_RUN_PROG_UDP6_SENDMSG_LOCK(sk, + (struct sockaddr *)sin6, + &fl6->saddr); + if (err) + goto out_no_dst; + if (sin6) { + if (ipv6_addr_v4mapped(&sin6->sin6_addr)) { + /* BPF program rewrote IPv6-only by IPv4-mapped + * IPv6. It's currently unsupported. + */ + err = -ENOTSUPP; + goto out_no_dst; + } + if (sin6->sin6_port == 0) { + /* BPF program set invalid port. Reject it. 
*/ + err = -EINVAL; + goto out_no_dst; + } + fl6->fl6_dport = sin6->sin6_port; + fl6->daddr = sin6->sin6_addr; + } + } + + if (ipv6_addr_any(&fl6->daddr)) + fl6->daddr.s6_addr[15] = 0x1; /* :: means loopback (BSD'ism) */ + + final_p = fl6_update_dst(fl6, opt, &final); + if (final_p) + connected = false; + + if (!fl6->flowi6_oif && ipv6_addr_is_multicast(&fl6->daddr)) { + fl6->flowi6_oif = np->mcast_oif; + connected = false; + } else if (!fl6->flowi6_oif) + fl6->flowi6_oif = np->ucast_oif; + + security_sk_classify_flow(sk, flowi6_to_flowi_common(fl6)); + + if (ipc6.tclass < 0) + ipc6.tclass = np->tclass; + + fl6->flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6->flowlabel); + + dst = ip6_sk_dst_lookup_flow(sk, fl6, final_p, connected); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + goto out; + } + + if (ipc6.hlimit < 0) + ipc6.hlimit = ip6_sk_dst_hoplimit(np, fl6, dst); + + if (msg->msg_flags&MSG_CONFIRM) + goto do_confirm; +back_from_confirm: + + /* Lockless fast path for the non-corking case */ + if (!corkreq) { + struct sk_buff *skb; + + skb = ip6_make_skb(sk, getfrag, msg, ulen, + sizeof(struct udphdr), &ipc6, + (struct rt6_info *)dst, + msg->msg_flags, &cork); + err = PTR_ERR(skb); + if (!IS_ERR_OR_NULL(skb)) + err = udp_v6_send_skb(skb, fl6, &cork.base); + /* ip6_make_skb steals dst reference */ + goto out_no_dst; + } + + lock_sock(sk); + if (unlikely(up->pending)) { + /* The socket is already corked while preparing it. */ + /* ... which is an evident application bug. --ANK */ + release_sock(sk); + + net_dbg_ratelimited("udp cork app bug 2\n"); + err = -EINVAL; + goto out; + } + + WRITE_ONCE(up->pending, AF_INET6); + +do_append_data: + if (ipc6.dontfrag < 0) + ipc6.dontfrag = np->dontfrag; + up->len += ulen; + err = ip6_append_data(sk, getfrag, msg, ulen, sizeof(struct udphdr), + &ipc6, fl6, (struct rt6_info *)dst, + corkreq ? msg->msg_flags|MSG_MORE : msg->msg_flags); + if (err) + udp_v6_flush_pending_frames(sk); + else if (!corkreq) + err = udp_v6_push_pending_frames(sk); + else if (unlikely(skb_queue_empty(&sk->sk_write_queue))) + WRITE_ONCE(up->pending, 0); + + if (err > 0) + err = np->recverr ? net_xmit_errno(err) : 0; + release_sock(sk); + +out: + dst_release(dst); +out_no_dst: + fl6_sock_release(flowlabel); + txopt_put(opt_to_free); + if (!err) + return len; + /* + * ENOBUFS = no kernel mem, SOCK_NOSPACE = no sndbuf space. Reporting + * ENOBUFS might not be good (it's not tunable per se), but otherwise + * we don't have a good statistic (IpOutDiscards but it can be too many + * things). We could add another new stat but at least for now that + * seems like overkill. 
+ */ + if (err == -ENOBUFS || test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) { + UDP6_INC_STATS(sock_net(sk), + UDP_MIB_SNDBUFERRORS, is_udplite); + } + return err; + +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(dst, &fl6->daddr); + if (!(msg->msg_flags&MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto out; +} + +static void udpv6_splice_eof(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct udp_sock *up = udp_sk(sk); + + if (!READ_ONCE(up->pending) || udp_test_bit(CORK, sk)) + return; + + lock_sock(sk); + if (up->pending && !udp_test_bit(CORK, sk)) + udp_v6_push_pending_frames(sk); + release_sock(sk); +} + +void udpv6_destroy_sock(struct sock *sk) +{ + struct udp_sock *up = udp_sk(sk); + lock_sock(sk); + + /* protects from races with udp_abort() */ + sock_set_flag(sk, SOCK_DEAD); + udp_v6_flush_pending_frames(sk); + release_sock(sk); + + if (static_branch_unlikely(&udpv6_encap_needed_key)) { + if (up->encap_type) { + void (*encap_destroy)(struct sock *sk); + encap_destroy = READ_ONCE(up->encap_destroy); + if (encap_destroy) + encap_destroy(sk); + } + if (udp_test_bit(ENCAP_ENABLED, sk)) { + static_branch_dec(&udpv6_encap_needed_key); + udp_encap_disable(); + } + } +} + +/* + * Socket option code for UDP + */ +int udpv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + if (level == SOL_UDP || level == SOL_UDPLITE) + return udp_lib_setsockopt(sk, level, optname, + optval, optlen, + udp_v6_push_pending_frames); + return ipv6_setsockopt(sk, level, optname, optval, optlen); +} + +int udpv6_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + if (level == SOL_UDP || level == SOL_UDPLITE) + return udp_lib_getsockopt(sk, level, optname, optval, optlen); + return ipv6_getsockopt(sk, level, optname, optval, optlen); +} + +static const struct inet6_protocol udpv6_protocol = { + .handler = udpv6_rcv, + .err_handler = udpv6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +/* ------------------------------------------------------------------------ */ +#ifdef CONFIG_PROC_FS +int udp6_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_puts(seq, IPV6_SEQ_DGRAM_HEADER); + } else { + int bucket = ((struct udp_iter_state *)seq->private)->bucket; + struct inet_sock *inet = inet_sk(v); + __u16 srcp = ntohs(inet->inet_sport); + __u16 destp = ntohs(inet->inet_dport); + __ip6_dgram_sock_seq_show(seq, v, srcp, destp, + udp_rqueue_get(v), bucket); + } + return 0; +} + +const struct seq_operations udp6_seq_ops = { + .start = udp_seq_start, + .next = udp_seq_next, + .stop = udp_seq_stop, + .show = udp6_seq_show, +}; +EXPORT_SYMBOL(udp6_seq_ops); + +static struct udp_seq_afinfo udp6_seq_afinfo = { + .family = AF_INET6, + .udp_table = &udp_table, +}; + +int __net_init udp6_proc_init(struct net *net) +{ + if (!proc_create_net_data("udp6", 0444, net->proc_net, &udp6_seq_ops, + sizeof(struct udp_iter_state), &udp6_seq_afinfo)) + return -ENOMEM; + return 0; +} + +void udp6_proc_exit(struct net *net) +{ + remove_proc_entry("udp6", net->proc_net); +} +#endif /* CONFIG_PROC_FS */ + +/* ------------------------------------------------------------------------ */ + +struct proto udpv6_prot = { + .name = "UDPv6", + .owner = THIS_MODULE, + .close = udp_lib_close, + .pre_connect = udpv6_pre_connect, + .connect = ip6_datagram_connect, + .disconnect = udp_disconnect, + .ioctl = udp_ioctl, + .init = udpv6_init_sock, + .destroy = udpv6_destroy_sock, + 
.setsockopt = udpv6_setsockopt, + .getsockopt = udpv6_getsockopt, + .sendmsg = udpv6_sendmsg, + .recvmsg = udpv6_recvmsg, + .splice_eof = udpv6_splice_eof, + .release_cb = ip6_datagram_release_cb, + .hash = udp_lib_hash, + .unhash = udp_lib_unhash, + .rehash = udp_v6_rehash, + .get_port = udp_v6_get_port, + .put_port = udp_lib_unhash, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = udp_bpf_update_proto, +#endif + + .memory_allocated = &udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + + .sysctl_mem = sysctl_udp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), + .obj_size = sizeof(struct udp6_sock), + .h.udp_table = &udp_table, + .diag_destroy = udp_abort, +}; + +static struct inet_protosw udpv6_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_UDP, + .prot = &udpv6_prot, + .ops = &inet6_dgram_ops, + .flags = INET_PROTOSW_PERMANENT, +}; + +int __init udpv6_init(void) +{ + int ret; + + ret = inet6_add_protocol(&udpv6_protocol, IPPROTO_UDP); + if (ret) + goto out; + + ret = inet6_register_protosw(&udpv6_protosw); + if (ret) + goto out_udpv6_protocol; +out: + return ret; + +out_udpv6_protocol: + inet6_del_protocol(&udpv6_protocol, IPPROTO_UDP); + goto out; +} + +void udpv6_exit(void) +{ + inet6_unregister_protosw(&udpv6_protosw); + inet6_del_protocol(&udpv6_protocol, IPPROTO_UDP); +} diff --git a/net/ipv6/udp_impl.h b/net/ipv6/udp_impl.h new file mode 100644 index 000000000..0590f5663 --- /dev/null +++ b/net/ipv6/udp_impl.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _UDP6_IMPL_H +#define _UDP6_IMPL_H +#include +#include +#include +#include +#include +#include + +int __udp6_lib_rcv(struct sk_buff *, struct udp_table *, int); +int __udp6_lib_err(struct sk_buff *, struct inet6_skb_parm *, u8, u8, int, + __be32, struct udp_table *); + +int udpv6_init_sock(struct sock *sk); +int udp_v6_get_port(struct sock *sk, unsigned short snum); +void udp_v6_rehash(struct sock *sk); + +int udpv6_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen); +int udpv6_setsockopt(struct sock *sk, int level, int optname, sockptr_t optval, + unsigned int optlen); +int udpv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len); +int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags, + int *addr_len); +void udpv6_destroy_sock(struct sock *sk); + +#ifdef CONFIG_PROC_FS +int udp6_seq_show(struct seq_file *seq, void *v); +#endif +#endif /* _UDP6_IMPL_H */ diff --git a/net/ipv6/udp_offload.c b/net/ipv6/udp_offload.c new file mode 100644 index 000000000..7720d04ed --- /dev/null +++ b/net/ipv6/udp_offload.c @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPV6 GSO/GRO offload support + * Linux INET6 implementation + * + * UDPv6 GSO support + */ +#include +#include +#include +#include +#include +#include +#include +#include "ip6_offload.h" +#include + +static struct sk_buff *udp6_ufo_fragment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + unsigned int mss; + unsigned int unfrag_ip6hlen, unfrag_len; + struct frag_hdr *fptr; + u8 *packet_start, *prevhdr; + u8 nexthdr; + u8 frag_hdr_sz = sizeof(struct frag_hdr); + __wsum csum; + int tnl_hlen; + int err; + + if (skb->encapsulation && skb_shinfo(skb)->gso_type & + (SKB_GSO_UDP_TUNNEL|SKB_GSO_UDP_TUNNEL_CSUM)) + segs = skb_udp_tunnel_segment(skb, features, true); + else { + const struct 
ipv6hdr *ipv6h; + struct udphdr *uh; + + if (!(skb_shinfo(skb)->gso_type & (SKB_GSO_UDP | SKB_GSO_UDP_L4))) + goto out; + + if (!pskb_may_pull(skb, sizeof(struct udphdr))) + goto out; + + if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP_L4) + return __udp_gso_segment(skb, features, true); + + mss = skb_shinfo(skb)->gso_size; + if (unlikely(skb->len <= mss)) + goto out; + + /* Do software UFO. Complete and fill in the UDP checksum as HW cannot + * do checksum of UDP packets sent as multiple IP fragments. + */ + + uh = udp_hdr(skb); + ipv6h = ipv6_hdr(skb); + + uh->check = 0; + csum = skb_checksum(skb, 0, skb->len, 0); + uh->check = udp_v6_check(skb->len, &ipv6h->saddr, + &ipv6h->daddr, csum); + if (uh->check == 0) + uh->check = CSUM_MANGLED_0; + + skb->ip_summed = CHECKSUM_UNNECESSARY; + + /* If there is no outer header we can fake a checksum offload + * due to the fact that we have already done the checksum in + * software prior to segmenting the frame. + */ + if (!skb->encap_hdr_csum) + features |= NETIF_F_HW_CSUM; + + /* Check if there is enough headroom to insert fragment header. */ + tnl_hlen = skb_tnl_header_len(skb); + if (skb->mac_header < (tnl_hlen + frag_hdr_sz)) { + if (gso_pskb_expand_head(skb, tnl_hlen + frag_hdr_sz)) + goto out; + } + + /* Find the unfragmentable header and shift it left by frag_hdr_sz + * bytes to insert fragment header. + */ + err = ip6_find_1stfragopt(skb, &prevhdr); + if (err < 0) + return ERR_PTR(err); + unfrag_ip6hlen = err; + nexthdr = *prevhdr; + *prevhdr = NEXTHDR_FRAGMENT; + unfrag_len = (skb_network_header(skb) - skb_mac_header(skb)) + + unfrag_ip6hlen + tnl_hlen; + packet_start = (u8 *) skb->head + SKB_GSO_CB(skb)->mac_offset; + memmove(packet_start-frag_hdr_sz, packet_start, unfrag_len); + + SKB_GSO_CB(skb)->mac_offset -= frag_hdr_sz; + skb->mac_header -= frag_hdr_sz; + skb->network_header -= frag_hdr_sz; + + fptr = (struct frag_hdr *)(skb_network_header(skb) + unfrag_ip6hlen); + fptr->nexthdr = nexthdr; + fptr->reserved = 0; + fptr->identification = ipv6_proxy_select_ident(dev_net(skb->dev), skb); + + /* Fragment the skb. ipv6 header and the remaining fields of the + * fragment header are updated in ipv6_gso_segment() + */ + segs = skb_segment(skb, features); + } + +out: + return segs; +} + +static struct sock *udp6_gro_lookup_skb(struct sk_buff *skb, __be16 sport, + __be16 dport) +{ + const struct ipv6hdr *iph = skb_gro_network_header(skb); + + return __udp6_lib_lookup(dev_net(skb->dev), &iph->saddr, sport, + &iph->daddr, dport, inet6_iif(skb), + inet6_sdif(skb), &udp_table, NULL); +} + +INDIRECT_CALLABLE_SCOPE +struct sk_buff *udp6_gro_receive(struct list_head *head, struct sk_buff *skb) +{ + struct udphdr *uh = udp_gro_udphdr(skb); + struct sock *sk = NULL; + struct sk_buff *pp; + + if (unlikely(!uh)) + goto flush; + + /* Don't bother verifying checksum if we're going to flush anyway. 
*/ + if (NAPI_GRO_CB(skb)->flush) + goto skip; + + if (skb_gro_checksum_validate_zero_check(skb, IPPROTO_UDP, uh->check, + ip6_gro_compute_pseudo)) + goto flush; + else if (uh->check) + skb_gro_checksum_try_convert(skb, IPPROTO_UDP, + ip6_gro_compute_pseudo); + +skip: + NAPI_GRO_CB(skb)->is_ipv6 = 1; + + if (static_branch_unlikely(&udpv6_encap_needed_key)) + sk = udp6_gro_lookup_skb(skb, uh->source, uh->dest); + + pp = udp_gro_receive(head, skb, uh, sk); + return pp; + +flush: + NAPI_GRO_CB(skb)->flush = 1; + return NULL; +} + +INDIRECT_CALLABLE_SCOPE int udp6_gro_complete(struct sk_buff *skb, int nhoff) +{ + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + struct udphdr *uh = (struct udphdr *)(skb->data + nhoff); + + /* do fraglist only if there is no outer UDP encap (or we already processed it) */ + if (NAPI_GRO_CB(skb)->is_flist && !NAPI_GRO_CB(skb)->encap_mark) { + uh->len = htons(skb->len - nhoff); + + skb_shinfo(skb)->gso_type |= (SKB_GSO_FRAGLIST|SKB_GSO_UDP_L4); + skb_shinfo(skb)->gso_segs = NAPI_GRO_CB(skb)->count; + + if (skb->ip_summed == CHECKSUM_UNNECESSARY) { + if (skb->csum_level < SKB_MAX_CSUM_LEVEL) + skb->csum_level++; + } else { + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->csum_level = 0; + } + + return 0; + } + + if (uh->check) + uh->check = ~udp_v6_check(skb->len - nhoff, &ipv6h->saddr, + &ipv6h->daddr, 0); + + return udp_gro_complete(skb, nhoff, udp6_lib_lookup_skb); +} + +static const struct net_offload udpv6_offload = { + .callbacks = { + .gso_segment = udp6_ufo_fragment, + .gro_receive = udp6_gro_receive, + .gro_complete = udp6_gro_complete, + }, +}; + +int udpv6_offload_init(void) +{ + return inet6_add_offload(&udpv6_offload, IPPROTO_UDP); +} + +int udpv6_offload_exit(void) +{ + return inet6_del_offload(&udpv6_offload, IPPROTO_UDP); +} diff --git a/net/ipv6/udplite.c b/net/ipv6/udplite.c new file mode 100644 index 000000000..3bab0cc13 --- /dev/null +++ b/net/ipv6/udplite.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * UDPLITEv6 An implementation of the UDP-Lite protocol over IPv6. 
+ * See also net/ipv4/udplite.c + * + * Authors: Gerrit Renker + * + * Changes: + * Fixes: + */ +#include +#include +#include "udp_impl.h" + +static int udplitev6_sk_init(struct sock *sk) +{ + udpv6_init_sock(sk); + udp_sk(sk)->pcflag = UDPLITE_BIT; + return 0; +} + +static int udplitev6_rcv(struct sk_buff *skb) +{ + return __udp6_lib_rcv(skb, &udplite_table, IPPROTO_UDPLITE); +} + +static int udplitev6_err(struct sk_buff *skb, + struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + return __udp6_lib_err(skb, opt, type, code, offset, info, + &udplite_table); +} + +static const struct inet6_protocol udplitev6_protocol = { + .handler = udplitev6_rcv, + .err_handler = udplitev6_err, + .flags = INET6_PROTO_NOPOLICY|INET6_PROTO_FINAL, +}; + +struct proto udplitev6_prot = { + .name = "UDPLITEv6", + .owner = THIS_MODULE, + .close = udp_lib_close, + .connect = ip6_datagram_connect, + .disconnect = udp_disconnect, + .ioctl = udp_ioctl, + .init = udplitev6_sk_init, + .destroy = udpv6_destroy_sock, + .setsockopt = udpv6_setsockopt, + .getsockopt = udpv6_getsockopt, + .sendmsg = udpv6_sendmsg, + .recvmsg = udpv6_recvmsg, + .hash = udp_lib_hash, + .unhash = udp_lib_unhash, + .rehash = udp_v6_rehash, + .get_port = udp_v6_get_port, + + .memory_allocated = &udp_memory_allocated, + .per_cpu_fw_alloc = &udp_memory_per_cpu_fw_alloc, + + .sysctl_mem = sysctl_udp_mem, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_udp_rmem_min), + .obj_size = sizeof(struct udp6_sock), + .h.udp_table = &udplite_table, +}; + +static struct inet_protosw udplite6_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_UDPLITE, + .prot = &udplitev6_prot, + .ops = &inet6_dgram_ops, + .flags = INET_PROTOSW_PERMANENT, +}; + +int __init udplitev6_init(void) +{ + int ret; + + ret = inet6_add_protocol(&udplitev6_protocol, IPPROTO_UDPLITE); + if (ret) + goto out; + + ret = inet6_register_protosw(&udplite6_protosw); + if (ret) + goto out_udplitev6_protocol; +out: + return ret; + +out_udplitev6_protocol: + inet6_del_protocol(&udplitev6_protocol, IPPROTO_UDPLITE); + goto out; +} + +void udplitev6_exit(void) +{ + inet6_unregister_protosw(&udplite6_protosw); + inet6_del_protocol(&udplitev6_protocol, IPPROTO_UDPLITE); +} + +#ifdef CONFIG_PROC_FS +static struct udp_seq_afinfo udplite6_seq_afinfo = { + .family = AF_INET6, + .udp_table = &udplite_table, +}; + +static int __net_init udplite6_proc_init_net(struct net *net) +{ + if (!proc_create_net_data("udplite6", 0444, net->proc_net, + &udp6_seq_ops, sizeof(struct udp_iter_state), + &udplite6_seq_afinfo)) + return -ENOMEM; + return 0; +} + +static void __net_exit udplite6_proc_exit_net(struct net *net) +{ + remove_proc_entry("udplite6", net->proc_net); +} + +static struct pernet_operations udplite6_net_ops = { + .init = udplite6_proc_init_net, + .exit = udplite6_proc_exit_net, +}; + +int __init udplite6_proc_init(void) +{ + return register_pernet_subsys(&udplite6_net_ops); +} + +void udplite6_proc_exit(void) +{ + unregister_pernet_subsys(&udplite6_net_ops); +} +#endif diff --git a/net/ipv6/xfrm6_input.c b/net/ipv6/xfrm6_input.c new file mode 100644 index 000000000..415638724 --- /dev/null +++ b/net/ipv6/xfrm6_input.c @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm6_input.c: based on net/ipv4/xfrm4_input.c + * + * Authors: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * YOSHIFUJI Hideaki @USAGI + * IPv6 support + */ + +#include +#include 
+#include +#include +#include +#include + +int xfrm6_rcv_spi(struct sk_buff *skb, int nexthdr, __be32 spi, + struct ip6_tnl *t) +{ + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = t; + XFRM_SPI_SKB_CB(skb)->family = AF_INET6; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct ipv6hdr, daddr); + return xfrm_input(skb, nexthdr, spi, 0); +} +EXPORT_SYMBOL(xfrm6_rcv_spi); + +static int xfrm6_transport_finish2(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + if (xfrm_trans_queue(skb, ip6_rcv_finish)) { + kfree_skb(skb); + return NET_RX_DROP; + } + + return 0; +} + +int xfrm6_transport_finish(struct sk_buff *skb, int async) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + int nhlen = skb->data - skb_network_header(skb); + + skb_network_header(skb)[IP6CB(skb)->nhoff] = + XFRM_MODE_SKB_CB(skb)->protocol; + +#ifndef CONFIG_NETFILTER + if (!async) + return 1; +#endif + + __skb_push(skb, nhlen); + ipv6_hdr(skb)->payload_len = htons(skb->len - sizeof(struct ipv6hdr)); + skb_postpush_rcsum(skb, skb_network_header(skb), nhlen); + + if (xo && (xo->flags & XFRM_GRO)) { + skb_mac_header_rebuild(skb); + skb_reset_transport_header(skb); + return 0; + } + + NF_HOOK(NFPROTO_IPV6, NF_INET_PRE_ROUTING, + dev_net(skb->dev), NULL, skb, skb->dev, NULL, + xfrm6_transport_finish2); + return 0; +} + +/* If it's a keepalive packet, then just eat it. + * If it's an encapsulated packet, then pass it to the + * IPsec xfrm input. + * Returns 0 if skb passed to xfrm or was dropped. + * Returns >0 if skb should be passed to UDP. + * Returns <0 if skb should be resubmitted (-ret is protocol) + */ +int xfrm6_udp_encap_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct udp_sock *up = udp_sk(sk); + struct udphdr *uh; + struct ipv6hdr *ip6h; + int len; + int ip6hlen = sizeof(struct ipv6hdr); + __u8 *udpdata; + __be32 *udpdata32; + u16 encap_type; + + if (skb->protocol == htons(ETH_P_IP)) + return xfrm4_udp_encap_rcv(sk, skb); + + encap_type = READ_ONCE(up->encap_type); + /* if this is not encapsulated socket, then just return now */ + if (!encap_type) + return 1; + + /* If this is a paged skb, make sure we pull up + * whatever data we need to look at. */ + len = skb->len - sizeof(struct udphdr); + if (!pskb_may_pull(skb, sizeof(struct udphdr) + min(len, 8))) + return 1; + + /* Now we can get the pointers */ + uh = udp_hdr(skb); + udpdata = (__u8 *)uh + sizeof(struct udphdr); + udpdata32 = (__be32 *)udpdata; + + switch (encap_type) { + default: + case UDP_ENCAP_ESPINUDP: + /* Check if this is a keepalive packet. If so, eat it. */ + if (len == 1 && udpdata[0] == 0xff) { + goto drop; + } else if (len > sizeof(struct ip_esp_hdr) && udpdata32[0] != 0) { + /* ESP Packet without Non-ESP header */ + len = sizeof(struct udphdr); + } else + /* Must be an IKE packet.. pass it through */ + return 1; + break; + case UDP_ENCAP_ESPINUDP_NON_IKE: + /* Check if this is a keepalive packet. If so, eat it. */ + if (len == 1 && udpdata[0] == 0xff) { + goto drop; + } else if (len > 2 * sizeof(u32) + sizeof(struct ip_esp_hdr) && + udpdata32[0] == 0 && udpdata32[1] == 0) { + + /* ESP Packet with Non-IKE marker */ + len = sizeof(struct udphdr) + 2 * sizeof(u32); + } else + /* Must be an IKE packet.. pass it through */ + return 1; + break; + } + + /* At this point we are sure that this is an ESPinUDP packet, + * so we need to remove 'len' bytes from the packet (the UDP + * header and optional ESP marker bytes) and then modify the + * protocol to ESP, and then call into the transform receiver. 
+ */ + if (skb_unclone(skb, GFP_ATOMIC)) + goto drop; + + /* Now we can update and verify the packet length... */ + ip6h = ipv6_hdr(skb); + ip6h->payload_len = htons(ntohs(ip6h->payload_len) - len); + if (skb->len < ip6hlen + len) { + /* packet is too small!?! */ + goto drop; + } + + /* pull the data buffer up to the ESP header and set the + * transport header to point to ESP. Keep UDP on the stack + * for later. + */ + __skb_pull(skb, len); + skb_reset_transport_header(skb); + + /* process ESP */ + return xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, encap_type); + +drop: + kfree_skb(skb); + return 0; +} + +int xfrm6_rcv_tnl(struct sk_buff *skb, struct ip6_tnl *t) +{ + return xfrm6_rcv_spi(skb, skb_network_header(skb)[IP6CB(skb)->nhoff], + 0, t); +} +EXPORT_SYMBOL(xfrm6_rcv_tnl); + +int xfrm6_rcv(struct sk_buff *skb) +{ + return xfrm6_rcv_tnl(skb, NULL); +} +EXPORT_SYMBOL(xfrm6_rcv); +int xfrm6_input_addr(struct sk_buff *skb, xfrm_address_t *daddr, + xfrm_address_t *saddr, u8 proto) +{ + struct net *net = dev_net(skb->dev); + struct xfrm_state *x = NULL; + struct sec_path *sp; + int i = 0; + + sp = secpath_set(skb); + if (!sp) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINERROR); + goto drop; + } + + if (1 + sp->len == XFRM_MAX_DEPTH) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR); + goto drop; + } + + for (i = 0; i < 3; i++) { + xfrm_address_t *dst, *src; + + switch (i) { + case 0: + dst = daddr; + src = saddr; + break; + case 1: + /* lookup state with wild-card source address */ + dst = daddr; + src = (xfrm_address_t *)&in6addr_any; + break; + default: + /* lookup state with wild-card addresses */ + dst = (xfrm_address_t *)&in6addr_any; + src = (xfrm_address_t *)&in6addr_any; + break; + } + + x = xfrm_state_lookup_byaddr(net, skb->mark, dst, src, proto, AF_INET6); + if (!x) + continue; + + spin_lock(&x->lock); + + if ((!i || (x->props.flags & XFRM_STATE_WILDRECV)) && + likely(x->km.state == XFRM_STATE_VALID) && + !xfrm_state_check_expire(x)) { + spin_unlock(&x->lock); + if (x->type->input(x, skb) > 0) { + /* found a valid state */ + break; + } + } else + spin_unlock(&x->lock); + + xfrm_state_put(x); + x = NULL; + } + + if (!x) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOSTATES); + xfrm_audit_state_notfound_simple(skb, AF_INET6); + goto drop; + } + + sp->xvec[sp->len++] = x; + + spin_lock(&x->lock); + + x->curlft.bytes += skb->len; + x->curlft.packets++; + + spin_unlock(&x->lock); + + return 1; + +drop: + return -1; +} +EXPORT_SYMBOL(xfrm6_input_addr); diff --git a/net/ipv6/xfrm6_output.c b/net/ipv6/xfrm6_output.c new file mode 100644 index 000000000..ad0790464 --- /dev/null +++ b/net/ipv6/xfrm6_output.c @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm6_output.c - Common IPsec encapsulation code for IPv6. + * Copyright (C) 2002 USAGI/WIDE Project + * Copyright (c) 2004 Herbert Xu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void xfrm6_local_rxpmtu(struct sk_buff *skb, u32 mtu) +{ + struct flowi6 fl6; + struct sock *sk = skb->sk; + + fl6.flowi6_oif = sk->sk_bound_dev_if; + fl6.daddr = ipv6_hdr(skb)->daddr; + + ipv6_local_rxpmtu(sk, &fl6, mtu); +} + +void xfrm6_local_error(struct sk_buff *skb, u32 mtu) +{ + struct flowi6 fl6; + const struct ipv6hdr *hdr; + struct sock *sk = skb->sk; + + hdr = skb->encapsulation ? 
inner_ipv6_hdr(skb) : ipv6_hdr(skb); + fl6.fl6_dport = inet_sk(sk)->inet_dport; + fl6.daddr = hdr->daddr; + + ipv6_local_error(sk, EMSGSIZE, &fl6, mtu); +} + +static int __xfrm6_output_finish(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + return xfrm_output(sk, skb); +} + +static int xfrm6_noneed_fragment(struct sk_buff *skb) +{ + struct frag_hdr *fh; + u8 prevhdr = ipv6_hdr(skb)->nexthdr; + + if (prevhdr != NEXTHDR_FRAGMENT) + return 0; + fh = (struct frag_hdr *)(skb->data + sizeof(struct ipv6hdr)); + if (fh->nexthdr == NEXTHDR_ESP || fh->nexthdr == NEXTHDR_AUTH) + return 1; + return 0; +} + +static int __xfrm6_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct xfrm_state *x = dst->xfrm; + unsigned int mtu; + bool toobig; + +#ifdef CONFIG_NETFILTER + if (!x) { + IP6CB(skb)->flags |= IP6SKB_REROUTED; + return dst_output(net, sk, skb); + } +#endif + + if (x->props.mode != XFRM_MODE_TUNNEL) + goto skip_frag; + + if (skb->protocol == htons(ETH_P_IPV6)) + mtu = ip6_skb_dst_mtu(skb); + else + mtu = dst_mtu(skb_dst(skb)); + + toobig = skb->len > mtu && !skb_is_gso(skb); + + if (toobig && xfrm6_local_dontfrag(skb->sk)) { + xfrm6_local_rxpmtu(skb, mtu); + kfree_skb(skb); + return -EMSGSIZE; + } else if (toobig && xfrm6_noneed_fragment(skb)) { + skb->ignore_df = 1; + goto skip_frag; + } else if (!skb->ignore_df && toobig && skb->sk) { + xfrm_local_error(skb, mtu); + kfree_skb(skb); + return -EMSGSIZE; + } + + if (toobig || dst_allfrag(skb_dst(skb))) + return ip6_fragment(net, sk, skb, + __xfrm6_output_finish); + +skip_frag: + return xfrm_output(sk, skb); +} + +int xfrm6_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + return NF_HOOK_COND(NFPROTO_IPV6, NF_INET_POST_ROUTING, + net, sk, skb, skb->dev, skb_dst(skb)->dev, + __xfrm6_output, + !(IP6CB(skb)->flags & IP6SKB_REROUTED)); +} diff --git a/net/ipv6/xfrm6_policy.c b/net/ipv6/xfrm6_policy.c new file mode 100644 index 000000000..f0053087d --- /dev/null +++ b/net/ipv6/xfrm6_policy.c @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm6_policy.c: based on xfrm4_policy.c + * + * Authors: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * IPv6 support + * YOSHIFUJI Hideaki + * Split up af-specific portion + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct dst_entry *xfrm6_dst_lookup(struct net *net, int tos, int oif, + const xfrm_address_t *saddr, + const xfrm_address_t *daddr, + u32 mark) +{ + struct flowi6 fl6; + struct dst_entry *dst; + int err; + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_l3mdev = l3mdev_master_ifindex_by_index(net, oif); + fl6.flowi6_mark = mark; + memcpy(&fl6.daddr, daddr, sizeof(fl6.daddr)); + if (saddr) + memcpy(&fl6.saddr, saddr, sizeof(fl6.saddr)); + + dst = ip6_route_output(net, NULL, &fl6); + + err = dst->error; + if (dst->error) { + dst_release(dst); + dst = ERR_PTR(err); + } + + return dst; +} + +static int xfrm6_get_saddr(struct net *net, int oif, + xfrm_address_t *saddr, xfrm_address_t *daddr, + u32 mark) +{ + struct dst_entry *dst; + struct net_device *dev; + + dst = xfrm6_dst_lookup(net, 0, oif, NULL, daddr, mark); + if (IS_ERR(dst)) + return -EHOSTUNREACH; + + dev = ip6_dst_idev(dst)->dev; + ipv6_dev_get_saddr(dev_net(dev), dev, &daddr->in6, 0, &saddr->in6); + dst_release(dst); + return 0; +} + +static int xfrm6_fill_dst(struct xfrm_dst *xdst, struct net_device *dev, + const struct flowi *fl) +{ + struct 
rt6_info *rt = (struct rt6_info *)xdst->route; + + xdst->u.dst.dev = dev; + netdev_hold(dev, &xdst->u.dst.dev_tracker, GFP_ATOMIC); + + xdst->u.rt6.rt6i_idev = in6_dev_get(dev); + if (!xdst->u.rt6.rt6i_idev) { + netdev_put(dev, &xdst->u.dst.dev_tracker); + return -ENODEV; + } + + /* Sheit... I remember I did this right. Apparently, + * it was magically lost, so this code needs audit */ + xdst->u.rt6.rt6i_flags = rt->rt6i_flags & (RTF_ANYCAST | + RTF_LOCAL); + xdst->route_cookie = rt6_get_cookie(rt); + xdst->u.rt6.rt6i_gateway = rt->rt6i_gateway; + xdst->u.rt6.rt6i_dst = rt->rt6i_dst; + xdst->u.rt6.rt6i_src = rt->rt6i_src; + INIT_LIST_HEAD(&xdst->u.rt6.rt6i_uncached); + rt6_uncached_list_add(&xdst->u.rt6); + + return 0; +} + +static void xfrm6_update_pmtu(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb, u32 mtu, + bool confirm_neigh) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + struct dst_entry *path = xdst->route; + + path->ops->update_pmtu(path, sk, skb, mtu, confirm_neigh); +} + +static void xfrm6_redirect(struct dst_entry *dst, struct sock *sk, + struct sk_buff *skb) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + struct dst_entry *path = xdst->route; + + path->ops->redirect(path, sk, skb); +} + +static void xfrm6_dst_destroy(struct dst_entry *dst) +{ + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + + dst_destroy_metrics_generic(dst); + if (xdst->u.rt6.rt6i_uncached_list) + rt6_uncached_list_del(&xdst->u.rt6); + if (likely(xdst->u.rt6.rt6i_idev)) + in6_dev_put(xdst->u.rt6.rt6i_idev); + xfrm_dst_destroy(xdst); +} + +static void xfrm6_dst_ifdown(struct dst_entry *dst, struct net_device *dev, + int unregister) +{ + struct xfrm_dst *xdst; + + if (!unregister) + return; + + xdst = (struct xfrm_dst *)dst; + if (xdst->u.rt6.rt6i_idev->dev == dev) { + struct inet6_dev *loopback_idev = + in6_dev_get(dev_net(dev)->loopback_dev); + + do { + in6_dev_put(xdst->u.rt6.rt6i_idev); + xdst->u.rt6.rt6i_idev = loopback_idev; + in6_dev_hold(loopback_idev); + xdst = (struct xfrm_dst *)xfrm_dst_child(&xdst->u.dst); + } while (xdst->u.dst.xfrm); + + __in6_dev_put(loopback_idev); + } + + xfrm_dst_ifdown(dst, dev); +} + +static struct dst_ops xfrm6_dst_ops_template = { + .family = AF_INET6, + .update_pmtu = xfrm6_update_pmtu, + .redirect = xfrm6_redirect, + .cow_metrics = dst_cow_metrics_generic, + .destroy = xfrm6_dst_destroy, + .ifdown = xfrm6_dst_ifdown, + .local_out = __ip6_local_out, + .gc_thresh = 32768, +}; + +static const struct xfrm_policy_afinfo xfrm6_policy_afinfo = { + .dst_ops = &xfrm6_dst_ops_template, + .dst_lookup = xfrm6_dst_lookup, + .get_saddr = xfrm6_get_saddr, + .fill_dst = xfrm6_fill_dst, + .blackhole_route = ip6_blackhole_route, +}; + +static int __init xfrm6_policy_init(void) +{ + return xfrm_policy_register_afinfo(&xfrm6_policy_afinfo, AF_INET6); +} + +static void xfrm6_policy_fini(void) +{ + xfrm_policy_unregister_afinfo(&xfrm6_policy_afinfo); +} + +#ifdef CONFIG_SYSCTL +static struct ctl_table xfrm6_policy_table[] = { + { + .procname = "xfrm6_gc_thresh", + .data = &init_net.xfrm.xfrm6_dst_ops.gc_thresh, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +static int __net_init xfrm6_net_sysctl_init(struct net *net) +{ + struct ctl_table *table; + struct ctl_table_header *hdr; + + table = xfrm6_policy_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(xfrm6_policy_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + table[0].data = &net->xfrm.xfrm6_dst_ops.gc_thresh; + } + + hdr = 
register_net_sysctl(net, "net/ipv6", table); + if (!hdr) + goto err_reg; + + net->ipv6.sysctl.xfrm6_hdr = hdr; + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void __net_exit xfrm6_net_sysctl_exit(struct net *net) +{ + struct ctl_table *table; + + if (!net->ipv6.sysctl.xfrm6_hdr) + return; + + table = net->ipv6.sysctl.xfrm6_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->ipv6.sysctl.xfrm6_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} +#else /* CONFIG_SYSCTL */ +static inline int xfrm6_net_sysctl_init(struct net *net) +{ + return 0; +} + +static inline void xfrm6_net_sysctl_exit(struct net *net) +{ +} +#endif + +static int __net_init xfrm6_net_init(struct net *net) +{ + int ret; + + memcpy(&net->xfrm.xfrm6_dst_ops, &xfrm6_dst_ops_template, + sizeof(xfrm6_dst_ops_template)); + ret = dst_entries_init(&net->xfrm.xfrm6_dst_ops); + if (ret) + return ret; + + ret = xfrm6_net_sysctl_init(net); + if (ret) + dst_entries_destroy(&net->xfrm.xfrm6_dst_ops); + + return ret; +} + +static void __net_exit xfrm6_net_exit(struct net *net) +{ + xfrm6_net_sysctl_exit(net); + dst_entries_destroy(&net->xfrm.xfrm6_dst_ops); +} + +static struct pernet_operations xfrm6_net_ops = { + .init = xfrm6_net_init, + .exit = xfrm6_net_exit, +}; + +int __init xfrm6_init(void) +{ + int ret; + + ret = xfrm6_policy_init(); + if (ret) + goto out; + ret = xfrm6_state_init(); + if (ret) + goto out_policy; + + ret = xfrm6_protocol_init(); + if (ret) + goto out_state; + + ret = register_pernet_subsys(&xfrm6_net_ops); + if (ret) + goto out_protocol; +out: + return ret; +out_protocol: + xfrm6_protocol_fini(); +out_state: + xfrm6_state_fini(); +out_policy: + xfrm6_policy_fini(); + goto out; +} + +void xfrm6_fini(void) +{ + unregister_pernet_subsys(&xfrm6_net_ops); + xfrm6_protocol_fini(); + xfrm6_policy_fini(); + xfrm6_state_fini(); +} diff --git a/net/ipv6/xfrm6_protocol.c b/net/ipv6/xfrm6_protocol.c new file mode 100644 index 000000000..ea2f805d3 --- /dev/null +++ b/net/ipv6/xfrm6_protocol.c @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* xfrm6_protocol.c - Generic xfrm protocol multiplexer for ipv6. 
+ * + * Copyright (C) 2013 secunet Security Networks AG + * + * Author: + * Steffen Klassert + * + * Based on: + * net/ipv4/xfrm4_protocol.c + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +static struct xfrm6_protocol __rcu *esp6_handlers __read_mostly; +static struct xfrm6_protocol __rcu *ah6_handlers __read_mostly; +static struct xfrm6_protocol __rcu *ipcomp6_handlers __read_mostly; +static DEFINE_MUTEX(xfrm6_protocol_mutex); + +static inline struct xfrm6_protocol __rcu **proto_handlers(u8 protocol) +{ + switch (protocol) { + case IPPROTO_ESP: + return &esp6_handlers; + case IPPROTO_AH: + return &ah6_handlers; + case IPPROTO_COMP: + return &ipcomp6_handlers; + } + + return NULL; +} + +#define for_each_protocol_rcu(head, handler) \ + for (handler = rcu_dereference(head); \ + handler != NULL; \ + handler = rcu_dereference(handler->next)) \ + +static int xfrm6_rcv_cb(struct sk_buff *skb, u8 protocol, int err) +{ + int ret; + struct xfrm6_protocol *handler; + struct xfrm6_protocol __rcu **head = proto_handlers(protocol); + + if (!head) + return 0; + + for_each_protocol_rcu(*proto_handlers(protocol), handler) + if ((ret = handler->cb_handler(skb, err)) <= 0) + return ret; + + return 0; +} + +int xfrm6_rcv_encap(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type) +{ + int ret; + struct xfrm6_protocol *handler; + struct xfrm6_protocol __rcu **head = proto_handlers(nexthdr); + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + XFRM_SPI_SKB_CB(skb)->family = AF_INET6; + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct ipv6hdr, daddr); + + if (!head) + goto out; + + if (!skb_dst(skb)) { + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + int flags = RT6_LOOKUP_F_HAS_SADDR; + struct dst_entry *dst; + struct flowi6 fl6 = { + .flowi6_iif = skb->dev->ifindex, + .daddr = ip6h->daddr, + .saddr = ip6h->saddr, + .flowlabel = ip6_flowinfo(ip6h), + .flowi6_mark = skb->mark, + .flowi6_proto = ip6h->nexthdr, + }; + + dst = ip6_route_input_lookup(dev_net(skb->dev), skb->dev, &fl6, + skb, flags); + if (dst->error) + goto drop; + skb_dst_set(skb, dst); + } + + for_each_protocol_rcu(*head, handler) + if ((ret = handler->input_handler(skb, nexthdr, spi, encap_type)) != -EINVAL) + return ret; + +out: + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + +drop: + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(xfrm6_rcv_encap); + +static int xfrm6_esp_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm6_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + + for_each_protocol_rcu(esp6_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm6_esp_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct xfrm6_protocol *handler; + + for_each_protocol_rcu(esp6_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static int xfrm6_ah_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm6_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + + for_each_protocol_rcu(ah6_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm6_ah_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + 
struct xfrm6_protocol *handler; + + for_each_protocol_rcu(ah6_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static int xfrm6_ipcomp_rcv(struct sk_buff *skb) +{ + int ret; + struct xfrm6_protocol *handler; + + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + + for_each_protocol_rcu(ipcomp6_handlers, handler) + if ((ret = handler->handler(skb)) != -EINVAL) + return ret; + + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + + kfree_skb(skb); + return 0; +} + +static int xfrm6_ipcomp_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct xfrm6_protocol *handler; + + for_each_protocol_rcu(ipcomp6_handlers, handler) + if (!handler->err_handler(skb, opt, type, code, offset, info)) + return 0; + + return -ENOENT; +} + +static const struct inet6_protocol esp6_protocol = { + .handler = xfrm6_esp_rcv, + .err_handler = xfrm6_esp_err, + .flags = INET6_PROTO_NOPOLICY, +}; + +static const struct inet6_protocol ah6_protocol = { + .handler = xfrm6_ah_rcv, + .err_handler = xfrm6_ah_err, + .flags = INET6_PROTO_NOPOLICY, +}; + +static const struct inet6_protocol ipcomp6_protocol = { + .handler = xfrm6_ipcomp_rcv, + .err_handler = xfrm6_ipcomp_err, + .flags = INET6_PROTO_NOPOLICY, +}; + +static const struct xfrm_input_afinfo xfrm6_input_afinfo = { + .family = AF_INET6, + .callback = xfrm6_rcv_cb, +}; + +static inline const struct inet6_protocol *netproto(unsigned char protocol) +{ + switch (protocol) { + case IPPROTO_ESP: + return &esp6_protocol; + case IPPROTO_AH: + return &ah6_protocol; + case IPPROTO_COMP: + return &ipcomp6_protocol; + } + + return NULL; +} + +int xfrm6_protocol_register(struct xfrm6_protocol *handler, + unsigned char protocol) +{ + struct xfrm6_protocol __rcu **pprev; + struct xfrm6_protocol *t; + bool add_netproto = false; + int ret = -EEXIST; + int priority = handler->priority; + + if (!proto_handlers(protocol) || !netproto(protocol)) + return -EINVAL; + + mutex_lock(&xfrm6_protocol_mutex); + + if (!rcu_dereference_protected(*proto_handlers(protocol), + lockdep_is_held(&xfrm6_protocol_mutex))) + add_netproto = true; + + for (pprev = proto_handlers(protocol); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&xfrm6_protocol_mutex))) != NULL; + pprev = &t->next) { + if (t->priority < priority) + break; + if (t->priority == priority) + goto err; + } + + handler->next = *pprev; + rcu_assign_pointer(*pprev, handler); + + ret = 0; + +err: + mutex_unlock(&xfrm6_protocol_mutex); + + if (add_netproto) { + if (inet6_add_protocol(netproto(protocol), protocol)) { + pr_err("%s: can't add protocol\n", __func__); + ret = -EAGAIN; + } + } + + return ret; +} +EXPORT_SYMBOL(xfrm6_protocol_register); + +int xfrm6_protocol_deregister(struct xfrm6_protocol *handler, + unsigned char protocol) +{ + struct xfrm6_protocol __rcu **pprev; + struct xfrm6_protocol *t; + int ret = -ENOENT; + + if (!proto_handlers(protocol) || !netproto(protocol)) + return -EINVAL; + + mutex_lock(&xfrm6_protocol_mutex); + + for (pprev = proto_handlers(protocol); + (t = rcu_dereference_protected(*pprev, + lockdep_is_held(&xfrm6_protocol_mutex))) != NULL; + pprev = &t->next) { + if (t == handler) { + *pprev = handler->next; + ret = 0; + break; + } + } + + if (!rcu_dereference_protected(*proto_handlers(protocol), + lockdep_is_held(&xfrm6_protocol_mutex))) { + if (inet6_del_protocol(netproto(protocol), protocol) < 0) { + pr_err("%s: can't remove protocol\n", __func__); + ret = -EAGAIN; + } 
+ } + + mutex_unlock(&xfrm6_protocol_mutex); + + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL(xfrm6_protocol_deregister); + +int __init xfrm6_protocol_init(void) +{ + return xfrm_input_register_afinfo(&xfrm6_input_afinfo); +} + +void xfrm6_protocol_fini(void) +{ + xfrm_input_unregister_afinfo(&xfrm6_input_afinfo); +} diff --git a/net/ipv6/xfrm6_state.c b/net/ipv6/xfrm6_state.c new file mode 100644 index 000000000..6610b2198 --- /dev/null +++ b/net/ipv6/xfrm6_state.c @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm6_state.c: based on xfrm4_state.c + * + * Authors: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * IPv6 support + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific portion + * + */ + +#include + +static struct xfrm_state_afinfo xfrm6_state_afinfo = { + .family = AF_INET6, + .proto = IPPROTO_IPV6, + .output = xfrm6_output, + .transport_finish = xfrm6_transport_finish, + .local_error = xfrm6_local_error, +}; + +int __init xfrm6_state_init(void) +{ + return xfrm_state_register_afinfo(&xfrm6_state_afinfo); +} + +void xfrm6_state_fini(void) +{ + xfrm_state_unregister_afinfo(&xfrm6_state_afinfo); +} diff --git a/net/ipv6/xfrm6_tunnel.c b/net/ipv6/xfrm6_tunnel.c new file mode 100644 index 000000000..1323f2f69 --- /dev/null +++ b/net/ipv6/xfrm6_tunnel.c @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C)2003,2004 USAGI/WIDE Project + * + * Authors Mitsuru KANDA + * YOSHIFUJI Hideaki + * + * Based on net/ipv4/xfrm4_tunnel.c + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define XFRM6_TUNNEL_SPI_BYADDR_HSIZE 256 +#define XFRM6_TUNNEL_SPI_BYSPI_HSIZE 256 + +#define XFRM6_TUNNEL_SPI_MIN 1 +#define XFRM6_TUNNEL_SPI_MAX 0xffffffff + +struct xfrm6_tunnel_net { + struct hlist_head spi_byaddr[XFRM6_TUNNEL_SPI_BYADDR_HSIZE]; + struct hlist_head spi_byspi[XFRM6_TUNNEL_SPI_BYSPI_HSIZE]; + u32 spi; +}; + +static unsigned int xfrm6_tunnel_net_id __read_mostly; +static inline struct xfrm6_tunnel_net *xfrm6_tunnel_pernet(struct net *net) +{ + return net_generic(net, xfrm6_tunnel_net_id); +} + +/* + * xfrm_tunnel_spi things are for allocating unique id ("spi") + * per xfrm_address_t. 
+ */ +struct xfrm6_tunnel_spi { + struct hlist_node list_byaddr; + struct hlist_node list_byspi; + xfrm_address_t addr; + u32 spi; + refcount_t refcnt; + struct rcu_head rcu_head; +}; + +static DEFINE_SPINLOCK(xfrm6_tunnel_spi_lock); + +static struct kmem_cache *xfrm6_tunnel_spi_kmem __read_mostly; + +static inline unsigned int xfrm6_tunnel_spi_hash_byaddr(const xfrm_address_t *addr) +{ + unsigned int h; + + h = ipv6_addr_hash((const struct in6_addr *)addr); + h ^= h >> 16; + h ^= h >> 8; + h &= XFRM6_TUNNEL_SPI_BYADDR_HSIZE - 1; + + return h; +} + +static inline unsigned int xfrm6_tunnel_spi_hash_byspi(u32 spi) +{ + return spi % XFRM6_TUNNEL_SPI_BYSPI_HSIZE; +} + +static struct xfrm6_tunnel_spi *__xfrm6_tunnel_spi_lookup(struct net *net, const xfrm_address_t *saddr) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + struct xfrm6_tunnel_spi *x6spi; + + hlist_for_each_entry_rcu(x6spi, + &xfrm6_tn->spi_byaddr[xfrm6_tunnel_spi_hash_byaddr(saddr)], + list_byaddr, lockdep_is_held(&xfrm6_tunnel_spi_lock)) { + if (xfrm6_addr_equal(&x6spi->addr, saddr)) + return x6spi; + } + + return NULL; +} + +__be32 xfrm6_tunnel_spi_lookup(struct net *net, const xfrm_address_t *saddr) +{ + struct xfrm6_tunnel_spi *x6spi; + u32 spi; + + rcu_read_lock_bh(); + x6spi = __xfrm6_tunnel_spi_lookup(net, saddr); + spi = x6spi ? x6spi->spi : 0; + rcu_read_unlock_bh(); + return htonl(spi); +} +EXPORT_SYMBOL(xfrm6_tunnel_spi_lookup); + +static int __xfrm6_tunnel_spi_check(struct net *net, u32 spi) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + struct xfrm6_tunnel_spi *x6spi; + int index = xfrm6_tunnel_spi_hash_byspi(spi); + + hlist_for_each_entry(x6spi, + &xfrm6_tn->spi_byspi[index], + list_byspi) { + if (x6spi->spi == spi) + return -1; + } + return index; +} + +static u32 __xfrm6_tunnel_alloc_spi(struct net *net, xfrm_address_t *saddr) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + u32 spi; + struct xfrm6_tunnel_spi *x6spi; + int index; + + if (xfrm6_tn->spi < XFRM6_TUNNEL_SPI_MIN || + xfrm6_tn->spi >= XFRM6_TUNNEL_SPI_MAX) + xfrm6_tn->spi = XFRM6_TUNNEL_SPI_MIN; + else + xfrm6_tn->spi++; + + for (spi = xfrm6_tn->spi; spi <= XFRM6_TUNNEL_SPI_MAX; spi++) { + index = __xfrm6_tunnel_spi_check(net, spi); + if (index >= 0) + goto alloc_spi; + + if (spi == XFRM6_TUNNEL_SPI_MAX) + break; + } + for (spi = XFRM6_TUNNEL_SPI_MIN; spi < xfrm6_tn->spi; spi++) { + index = __xfrm6_tunnel_spi_check(net, spi); + if (index >= 0) + goto alloc_spi; + } + spi = 0; + goto out; +alloc_spi: + xfrm6_tn->spi = spi; + x6spi = kmem_cache_alloc(xfrm6_tunnel_spi_kmem, GFP_ATOMIC); + if (!x6spi) + goto out; + + memcpy(&x6spi->addr, saddr, sizeof(x6spi->addr)); + x6spi->spi = spi; + refcount_set(&x6spi->refcnt, 1); + + hlist_add_head_rcu(&x6spi->list_byspi, &xfrm6_tn->spi_byspi[index]); + + index = xfrm6_tunnel_spi_hash_byaddr(saddr); + hlist_add_head_rcu(&x6spi->list_byaddr, &xfrm6_tn->spi_byaddr[index]); +out: + return spi; +} + +__be32 xfrm6_tunnel_alloc_spi(struct net *net, xfrm_address_t *saddr) +{ + struct xfrm6_tunnel_spi *x6spi; + u32 spi; + + spin_lock_bh(&xfrm6_tunnel_spi_lock); + x6spi = __xfrm6_tunnel_spi_lookup(net, saddr); + if (x6spi) { + refcount_inc(&x6spi->refcnt); + spi = x6spi->spi; + } else + spi = __xfrm6_tunnel_alloc_spi(net, saddr); + spin_unlock_bh(&xfrm6_tunnel_spi_lock); + + return htonl(spi); +} +EXPORT_SYMBOL(xfrm6_tunnel_alloc_spi); + +static void x6spi_destroy_rcu(struct rcu_head *head) +{ + kmem_cache_free(xfrm6_tunnel_spi_kmem, + container_of(head, struct 
xfrm6_tunnel_spi, rcu_head)); +} + +static void xfrm6_tunnel_free_spi(struct net *net, xfrm_address_t *saddr) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + struct xfrm6_tunnel_spi *x6spi; + struct hlist_node *n; + + spin_lock_bh(&xfrm6_tunnel_spi_lock); + + hlist_for_each_entry_safe(x6spi, n, + &xfrm6_tn->spi_byaddr[xfrm6_tunnel_spi_hash_byaddr(saddr)], + list_byaddr) + { + if (xfrm6_addr_equal(&x6spi->addr, saddr)) { + if (refcount_dec_and_test(&x6spi->refcnt)) { + hlist_del_rcu(&x6spi->list_byaddr); + hlist_del_rcu(&x6spi->list_byspi); + call_rcu(&x6spi->rcu_head, x6spi_destroy_rcu); + break; + } + } + } + spin_unlock_bh(&xfrm6_tunnel_spi_lock); +} + +static int xfrm6_tunnel_output(struct xfrm_state *x, struct sk_buff *skb) +{ + skb_push(skb, -skb_network_offset(skb)); + return 0; +} + +static int xfrm6_tunnel_input(struct xfrm_state *x, struct sk_buff *skb) +{ + return skb_network_header(skb)[IP6CB(skb)->nhoff]; +} + +static int xfrm6_tunnel_rcv(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + const struct ipv6hdr *iph = ipv6_hdr(skb); + __be32 spi; + + spi = xfrm6_tunnel_spi_lookup(net, (const xfrm_address_t *)&iph->saddr); + return xfrm6_rcv_spi(skb, IPPROTO_IPV6, spi, NULL); +} + +static int xfrm6_tunnel_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + /* xfrm6_tunnel native err handling */ + switch (type) { + case ICMPV6_DEST_UNREACH: + switch (code) { + case ICMPV6_NOROUTE: + case ICMPV6_ADM_PROHIBITED: + case ICMPV6_NOT_NEIGHBOUR: + case ICMPV6_ADDR_UNREACH: + case ICMPV6_PORT_UNREACH: + default: + break; + } + break; + case ICMPV6_PKT_TOOBIG: + break; + case ICMPV6_TIME_EXCEED: + switch (code) { + case ICMPV6_EXC_HOPLIMIT: + break; + case ICMPV6_EXC_FRAGTIME: + default: + break; + } + break; + case ICMPV6_PARAMPROB: + switch (code) { + case ICMPV6_HDR_FIELD: break; + case ICMPV6_UNK_NEXTHDR: break; + case ICMPV6_UNK_OPTION: break; + } + break; + default: + break; + } + + return 0; +} + +static int xfrm6_tunnel_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + if (x->props.mode != XFRM_MODE_TUNNEL) { + NL_SET_ERR_MSG(extack, "IPv6 tunnel can only be used with tunnel mode"); + return -EINVAL; + } + + if (x->encap) { + NL_SET_ERR_MSG(extack, "IPv6 tunnel is not compatible with encapsulation"); + return -EINVAL; + } + + x->props.header_len = sizeof(struct ipv6hdr); + + return 0; +} + +static void xfrm6_tunnel_destroy(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + + xfrm6_tunnel_free_spi(net, (xfrm_address_t *)&x->props.saddr); +} + +static const struct xfrm_type xfrm6_tunnel_type = { + .owner = THIS_MODULE, + .proto = IPPROTO_IPV6, + .init_state = xfrm6_tunnel_init_state, + .destructor = xfrm6_tunnel_destroy, + .input = xfrm6_tunnel_input, + .output = xfrm6_tunnel_output, +}; + +static struct xfrm6_tunnel xfrm6_tunnel_handler __read_mostly = { + .handler = xfrm6_tunnel_rcv, + .err_handler = xfrm6_tunnel_err, + .priority = 3, +}; + +static struct xfrm6_tunnel xfrm46_tunnel_handler __read_mostly = { + .handler = xfrm6_tunnel_rcv, + .err_handler = xfrm6_tunnel_err, + .priority = 3, +}; + +static int __net_init xfrm6_tunnel_net_init(struct net *net) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + unsigned int i; + + for (i = 0; i < XFRM6_TUNNEL_SPI_BYADDR_HSIZE; i++) + INIT_HLIST_HEAD(&xfrm6_tn->spi_byaddr[i]); + for (i = 0; i < XFRM6_TUNNEL_SPI_BYSPI_HSIZE; i++) + INIT_HLIST_HEAD(&xfrm6_tn->spi_byspi[i]); + xfrm6_tn->spi = 0; + + return 0; +} + 
+static void __net_exit xfrm6_tunnel_net_exit(struct net *net) +{ + struct xfrm6_tunnel_net *xfrm6_tn = xfrm6_tunnel_pernet(net); + unsigned int i; + + xfrm_flush_gc(); + xfrm_state_flush(net, 0, false, true); + + for (i = 0; i < XFRM6_TUNNEL_SPI_BYADDR_HSIZE; i++) + WARN_ON_ONCE(!hlist_empty(&xfrm6_tn->spi_byaddr[i])); + + for (i = 0; i < XFRM6_TUNNEL_SPI_BYSPI_HSIZE; i++) + WARN_ON_ONCE(!hlist_empty(&xfrm6_tn->spi_byspi[i])); +} + +static struct pernet_operations xfrm6_tunnel_net_ops = { + .init = xfrm6_tunnel_net_init, + .exit = xfrm6_tunnel_net_exit, + .id = &xfrm6_tunnel_net_id, + .size = sizeof(struct xfrm6_tunnel_net), +}; + +static int __init xfrm6_tunnel_init(void) +{ + int rv; + + xfrm6_tunnel_spi_kmem = kmem_cache_create("xfrm6_tunnel_spi", + sizeof(struct xfrm6_tunnel_spi), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!xfrm6_tunnel_spi_kmem) + return -ENOMEM; + rv = register_pernet_subsys(&xfrm6_tunnel_net_ops); + if (rv < 0) + goto out_pernet; + rv = xfrm_register_type(&xfrm6_tunnel_type, AF_INET6); + if (rv < 0) + goto out_type; + rv = xfrm6_tunnel_register(&xfrm6_tunnel_handler, AF_INET6); + if (rv < 0) + goto out_xfrm6; + rv = xfrm6_tunnel_register(&xfrm46_tunnel_handler, AF_INET); + if (rv < 0) + goto out_xfrm46; + return 0; + +out_xfrm46: + xfrm6_tunnel_deregister(&xfrm6_tunnel_handler, AF_INET6); +out_xfrm6: + xfrm_unregister_type(&xfrm6_tunnel_type, AF_INET6); +out_type: + unregister_pernet_subsys(&xfrm6_tunnel_net_ops); +out_pernet: + kmem_cache_destroy(xfrm6_tunnel_spi_kmem); + return rv; +} + +static void __exit xfrm6_tunnel_fini(void) +{ + xfrm6_tunnel_deregister(&xfrm46_tunnel_handler, AF_INET); + xfrm6_tunnel_deregister(&xfrm6_tunnel_handler, AF_INET6); + xfrm_unregister_type(&xfrm6_tunnel_type, AF_INET6); + unregister_pernet_subsys(&xfrm6_tunnel_net_ops); + /* Someone maybe has gotten the xfrm6_tunnel_spi. + * So need to wait it. + */ + rcu_barrier(); + kmem_cache_destroy(xfrm6_tunnel_spi_kmem); +} + +module_init(xfrm6_tunnel_init); +module_exit(xfrm6_tunnel_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_XFRM_TYPE(AF_INET6, XFRM_PROTO_IPV6); diff --git a/net/iucv/Kconfig b/net/iucv/Kconfig new file mode 100644 index 000000000..5cfddc9c6 --- /dev/null +++ b/net/iucv/Kconfig @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: GPL-2.0-only +config IUCV + depends on S390 + def_tristate y if S390 + prompt "IUCV support (S390 - z/VM only)" + help + Select this option if you want to use inter-user communication + under VM or VIF. If you run on z/VM, say "Y" to enable a fast + communication link between VM guests. + +config AFIUCV + depends on S390 + def_tristate m if QETH_L3 || IUCV + prompt "AF_IUCV Socket support (S390 - z/VM and HiperSockets transport)" + help + Select this option if you want to use AF_IUCV socket applications + based on z/VM inter-user communication vehicle or based on + HiperSockets. diff --git a/net/iucv/Makefile b/net/iucv/Makefile new file mode 100644 index 000000000..984d7ff05 --- /dev/null +++ b/net/iucv/Makefile @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for IUCV +# + +obj-$(CONFIG_IUCV) += iucv.o +obj-$(CONFIG_AFIUCV) += af_iucv.o diff --git a/net/iucv/af_iucv.c b/net/iucv/af_iucv.c new file mode 100644 index 000000000..498a0c35b --- /dev/null +++ b/net/iucv/af_iucv.c @@ -0,0 +1,2332 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IUCV protocol stack for Linux on zSeries + * + * Copyright IBM Corp. 
2006, 2009 + * + * Author(s): Jennifer Hunt + * Hendrik Brueckner + * PM functions: + * Ursula Braun + */ + +#define KMSG_COMPONENT "af_iucv" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define VERSION "1.2" + +static char iucv_userid[80]; + +static struct proto iucv_proto = { + .name = "AF_IUCV", + .owner = THIS_MODULE, + .obj_size = sizeof(struct iucv_sock), +}; + +static struct iucv_interface *pr_iucv; +static struct iucv_handler af_iucv_handler; + +/* special AF_IUCV IPRM messages */ +static const u8 iprm_shutdown[8] = + {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + +#define TRGCLS_SIZE sizeof_field(struct iucv_message, class) + +#define __iucv_sock_wait(sk, condition, timeo, ret) \ +do { \ + DEFINE_WAIT(__wait); \ + long __timeo = timeo; \ + ret = 0; \ + prepare_to_wait(sk_sleep(sk), &__wait, TASK_INTERRUPTIBLE); \ + while (!(condition)) { \ + if (!__timeo) { \ + ret = -EAGAIN; \ + break; \ + } \ + if (signal_pending(current)) { \ + ret = sock_intr_errno(__timeo); \ + break; \ + } \ + release_sock(sk); \ + __timeo = schedule_timeout(__timeo); \ + lock_sock(sk); \ + ret = sock_error(sk); \ + if (ret) \ + break; \ + } \ + finish_wait(sk_sleep(sk), &__wait); \ +} while (0) + +#define iucv_sock_wait(sk, condition, timeo) \ +({ \ + int __ret = 0; \ + if (!(condition)) \ + __iucv_sock_wait(sk, condition, timeo, __ret); \ + __ret; \ +}) + +static struct sock *iucv_accept_dequeue(struct sock *parent, + struct socket *newsock); +static void iucv_sock_kill(struct sock *sk); +static void iucv_sock_close(struct sock *sk); + +static void afiucv_hs_callback_txnotify(struct sock *sk, enum iucv_tx_notify); + +static struct iucv_sock_list iucv_sk_list = { + .lock = __RW_LOCK_UNLOCKED(iucv_sk_list.lock), + .autobind_name = ATOMIC_INIT(0) +}; + +static inline void high_nmcpy(unsigned char *dst, char *src) +{ + memcpy(dst, src, 8); +} + +static inline void low_nmcpy(unsigned char *dst, char *src) +{ + memcpy(&dst[8], src, 8); +} + +/** + * iucv_msg_length() - Returns the length of an iucv message. + * @msg: Pointer to struct iucv_message, MUST NOT be NULL + * + * The function returns the length of the specified iucv message @msg of data + * stored in a buffer and of data stored in the parameter list (PRMDATA). + * + * For IUCV_IPRMDATA, AF_IUCV uses the following convention to transport socket + * data: + * PRMDATA[0..6] socket data (max 7 bytes); + * PRMDATA[7] socket data length value (len is 0xff - PRMDATA[7]) + * + * The socket data length is computed by subtracting the socket data length + * value from 0xFF. + * If the socket data len is greater 7, then PRMDATA can be used for special + * notifications (see iucv_sock_shutdown); and further, + * if the socket data len is > 7, the function returns 8. + * + * Use this function to allocate socket buffers to store iucv message data. + */ +static inline size_t iucv_msg_length(struct iucv_message *msg) +{ + size_t datalen; + + if (msg->flags & IUCV_IPRMDATA) { + datalen = 0xff - msg->rmmsg[7]; + return (datalen < 8) ? datalen : 8; + } + return msg->length; +} + +/** + * iucv_sock_in_state() - check for specific states + * @sk: sock structure + * @state: first iucv sk state + * @state2: second iucv sk state + * + * Returns true if the socket in either in the first or second state. 
+ */ +static int iucv_sock_in_state(struct sock *sk, int state, int state2) +{ + return (sk->sk_state == state || sk->sk_state == state2); +} + +/** + * iucv_below_msglim() - function to check if messages can be sent + * @sk: sock structure + * + * Returns true if the send queue length is lower than the message limit. + * Always returns true if the socket is not connected (no iucv path for + * checking the message limit). + */ +static inline int iucv_below_msglim(struct sock *sk) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (sk->sk_state != IUCV_CONNECTED) + return 1; + if (iucv->transport == AF_IUCV_TRANS_IUCV) + return (atomic_read(&iucv->skbs_in_xmit) < iucv->path->msglim); + else + return ((atomic_read(&iucv->msg_sent) < iucv->msglimit_peer) && + (atomic_read(&iucv->pendings) <= 0)); +} + +/* + * iucv_sock_wake_msglim() - Wake up thread waiting on msg limit + */ +static void iucv_sock_wake_msglim(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_all(&wq->wait); + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + rcu_read_unlock(); +} + +/* + * afiucv_hs_send() - send a message through HiperSockets transport + */ +static int afiucv_hs_send(struct iucv_message *imsg, struct sock *sock, + struct sk_buff *skb, u8 flags) +{ + struct iucv_sock *iucv = iucv_sk(sock); + struct af_iucv_trans_hdr *phs_hdr; + int err, confirm_recv = 0; + + phs_hdr = skb_push(skb, sizeof(*phs_hdr)); + memset(phs_hdr, 0, sizeof(*phs_hdr)); + skb_reset_network_header(skb); + + phs_hdr->magic = ETH_P_AF_IUCV; + phs_hdr->version = 1; + phs_hdr->flags = flags; + if (flags == AF_IUCV_FLAG_SYN) + phs_hdr->window = iucv->msglimit; + else if ((flags == AF_IUCV_FLAG_WIN) || !flags) { + confirm_recv = atomic_read(&iucv->msg_recv); + phs_hdr->window = confirm_recv; + if (confirm_recv) + phs_hdr->flags = phs_hdr->flags | AF_IUCV_FLAG_WIN; + } + memcpy(phs_hdr->destUserID, iucv->dst_user_id, 8); + memcpy(phs_hdr->destAppName, iucv->dst_name, 8); + memcpy(phs_hdr->srcUserID, iucv->src_user_id, 8); + memcpy(phs_hdr->srcAppName, iucv->src_name, 8); + ASCEBC(phs_hdr->destUserID, sizeof(phs_hdr->destUserID)); + ASCEBC(phs_hdr->destAppName, sizeof(phs_hdr->destAppName)); + ASCEBC(phs_hdr->srcUserID, sizeof(phs_hdr->srcUserID)); + ASCEBC(phs_hdr->srcAppName, sizeof(phs_hdr->srcAppName)); + if (imsg) + memcpy(&phs_hdr->iucv_hdr, imsg, sizeof(struct iucv_message)); + + skb->dev = iucv->hs_dev; + if (!skb->dev) { + err = -ENODEV; + goto err_free; + } + + dev_hard_header(skb, skb->dev, ETH_P_AF_IUCV, NULL, NULL, skb->len); + + if (!(skb->dev->flags & IFF_UP) || !netif_carrier_ok(skb->dev)) { + err = -ENETDOWN; + goto err_free; + } + if (skb->len > skb->dev->mtu) { + if (sock->sk_type == SOCK_SEQPACKET) { + err = -EMSGSIZE; + goto err_free; + } + err = pskb_trim(skb, skb->dev->mtu); + if (err) + goto err_free; + } + skb->protocol = cpu_to_be16(ETH_P_AF_IUCV); + + atomic_inc(&iucv->skbs_in_xmit); + err = dev_queue_xmit(skb); + if (net_xmit_eval(err)) { + atomic_dec(&iucv->skbs_in_xmit); + } else { + atomic_sub(confirm_recv, &iucv->msg_recv); + WARN_ON(atomic_read(&iucv->msg_recv) < 0); + } + return net_xmit_eval(err); + +err_free: + kfree_skb(skb); + return err; +} + +static struct sock *__iucv_get_sock_by_name(char *nm) +{ + struct sock *sk; + + sk_for_each(sk, &iucv_sk_list.head) + if (!memcmp(&iucv_sk(sk)->src_name, nm, 8)) + return sk; + + return NULL; +} + +static void iucv_sock_destruct(struct sock *sk) +{ + 
skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_error_queue); + + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Attempt to release alive iucv socket %p\n", sk); + return; + } + + WARN_ON(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); + WARN_ON(sk->sk_wmem_queued); + WARN_ON(sk->sk_forward_alloc); +} + +/* Cleanup Listen */ +static void iucv_sock_cleanup_listen(struct sock *parent) +{ + struct sock *sk; + + /* Close non-accepted connections */ + while ((sk = iucv_accept_dequeue(parent, NULL))) { + iucv_sock_close(sk); + iucv_sock_kill(sk); + } + + parent->sk_state = IUCV_CLOSED; +} + +static void iucv_sock_link(struct iucv_sock_list *l, struct sock *sk) +{ + write_lock_bh(&l->lock); + sk_add_node(sk, &l->head); + write_unlock_bh(&l->lock); +} + +static void iucv_sock_unlink(struct iucv_sock_list *l, struct sock *sk) +{ + write_lock_bh(&l->lock); + sk_del_node_init(sk); + write_unlock_bh(&l->lock); +} + +/* Kill socket (only if zapped and orphaned) */ +static void iucv_sock_kill(struct sock *sk) +{ + if (!sock_flag(sk, SOCK_ZAPPED) || sk->sk_socket) + return; + + iucv_sock_unlink(&iucv_sk_list, sk); + sock_set_flag(sk, SOCK_DEAD); + sock_put(sk); +} + +/* Terminate an IUCV path */ +static void iucv_sever_path(struct sock *sk, int with_user_data) +{ + unsigned char user_data[16]; + struct iucv_sock *iucv = iucv_sk(sk); + struct iucv_path *path = iucv->path; + + if (iucv->path) { + iucv->path = NULL; + if (with_user_data) { + low_nmcpy(user_data, iucv->src_name); + high_nmcpy(user_data, iucv->dst_name); + ASCEBC(user_data, sizeof(user_data)); + pr_iucv->path_sever(path, user_data); + } else + pr_iucv->path_sever(path, NULL); + iucv_path_free(path); + } +} + +/* Send controlling flags through an IUCV socket for HIPER transport */ +static int iucv_send_ctrl(struct sock *sk, u8 flags) +{ + struct iucv_sock *iucv = iucv_sk(sk); + int err = 0; + int blen; + struct sk_buff *skb; + u8 shutdown = 0; + + blen = sizeof(struct af_iucv_trans_hdr) + + LL_RESERVED_SPACE(iucv->hs_dev); + if (sk->sk_shutdown & SEND_SHUTDOWN) { + /* controlling flags should be sent anyway */ + shutdown = sk->sk_shutdown; + sk->sk_shutdown &= RCV_SHUTDOWN; + } + skb = sock_alloc_send_skb(sk, blen, 1, &err); + if (skb) { + skb_reserve(skb, blen); + err = afiucv_hs_send(NULL, sk, skb, flags); + } + if (shutdown) + sk->sk_shutdown = shutdown; + return err; +} + +/* Close an IUCV socket */ +static void iucv_sock_close(struct sock *sk) +{ + struct iucv_sock *iucv = iucv_sk(sk); + unsigned long timeo; + int err = 0; + + lock_sock(sk); + + switch (sk->sk_state) { + case IUCV_LISTEN: + iucv_sock_cleanup_listen(sk); + break; + + case IUCV_CONNECTED: + if (iucv->transport == AF_IUCV_TRANS_HIPER) { + err = iucv_send_ctrl(sk, AF_IUCV_FLAG_FIN); + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + } + fallthrough; + + case IUCV_DISCONN: + sk->sk_state = IUCV_CLOSING; + sk->sk_state_change(sk); + + if (!err && atomic_read(&iucv->skbs_in_xmit) > 0) { + if (sock_flag(sk, SOCK_LINGER) && sk->sk_lingertime) + timeo = sk->sk_lingertime; + else + timeo = IUCV_DISCONN_TIMEOUT; + iucv_sock_wait(sk, + iucv_sock_in_state(sk, IUCV_CLOSED, 0), + timeo); + } + fallthrough; + + case IUCV_CLOSING: + sk->sk_state = IUCV_CLOSED; + sk->sk_state_change(sk); + + sk->sk_err = ECONNRESET; + sk->sk_state_change(sk); + + skb_queue_purge(&iucv->send_skb_q); + skb_queue_purge(&iucv->backlog_skb_q); + fallthrough; + + default: + iucv_sever_path(sk, 1); + } + + if (iucv->hs_dev) { + dev_put(iucv->hs_dev); + 
iucv->hs_dev = NULL; + sk->sk_bound_dev_if = 0; + } + + /* mark socket for deletion by iucv_sock_kill() */ + sock_set_flag(sk, SOCK_ZAPPED); + + release_sock(sk); +} + +static void iucv_sock_init(struct sock *sk, struct sock *parent) +{ + if (parent) { + sk->sk_type = parent->sk_type; + security_sk_clone(parent, sk); + } +} + +static struct sock *iucv_sock_alloc(struct socket *sock, int proto, gfp_t prio, int kern) +{ + struct sock *sk; + struct iucv_sock *iucv; + + sk = sk_alloc(&init_net, PF_IUCV, prio, &iucv_proto, kern); + if (!sk) + return NULL; + iucv = iucv_sk(sk); + + sock_init_data(sock, sk); + INIT_LIST_HEAD(&iucv->accept_q); + spin_lock_init(&iucv->accept_q_lock); + skb_queue_head_init(&iucv->send_skb_q); + INIT_LIST_HEAD(&iucv->message_q.list); + spin_lock_init(&iucv->message_q.lock); + skb_queue_head_init(&iucv->backlog_skb_q); + iucv->send_tag = 0; + atomic_set(&iucv->pendings, 0); + iucv->flags = 0; + iucv->msglimit = 0; + atomic_set(&iucv->skbs_in_xmit, 0); + atomic_set(&iucv->msg_sent, 0); + atomic_set(&iucv->msg_recv, 0); + iucv->path = NULL; + iucv->sk_txnotify = afiucv_hs_callback_txnotify; + memset(&iucv->init, 0, sizeof(iucv->init)); + if (pr_iucv) + iucv->transport = AF_IUCV_TRANS_IUCV; + else + iucv->transport = AF_IUCV_TRANS_HIPER; + + sk->sk_destruct = iucv_sock_destruct; + sk->sk_sndtimeo = IUCV_CONN_TIMEOUT; + + sock_reset_flag(sk, SOCK_ZAPPED); + + sk->sk_protocol = proto; + sk->sk_state = IUCV_OPEN; + + iucv_sock_link(&iucv_sk_list, sk); + return sk; +} + +static void iucv_accept_enqueue(struct sock *parent, struct sock *sk) +{ + unsigned long flags; + struct iucv_sock *par = iucv_sk(parent); + + sock_hold(sk); + spin_lock_irqsave(&par->accept_q_lock, flags); + list_add_tail(&iucv_sk(sk)->accept_q, &par->accept_q); + spin_unlock_irqrestore(&par->accept_q_lock, flags); + iucv_sk(sk)->parent = parent; + sk_acceptq_added(parent); +} + +static void iucv_accept_unlink(struct sock *sk) +{ + unsigned long flags; + struct iucv_sock *par = iucv_sk(iucv_sk(sk)->parent); + + spin_lock_irqsave(&par->accept_q_lock, flags); + list_del_init(&iucv_sk(sk)->accept_q); + spin_unlock_irqrestore(&par->accept_q_lock, flags); + sk_acceptq_removed(iucv_sk(sk)->parent); + iucv_sk(sk)->parent = NULL; + sock_put(sk); +} + +static struct sock *iucv_accept_dequeue(struct sock *parent, + struct socket *newsock) +{ + struct iucv_sock *isk, *n; + struct sock *sk; + + list_for_each_entry_safe(isk, n, &iucv_sk(parent)->accept_q, accept_q) { + sk = (struct sock *) isk; + lock_sock(sk); + + if (sk->sk_state == IUCV_CLOSED) { + iucv_accept_unlink(sk); + release_sock(sk); + continue; + } + + if (sk->sk_state == IUCV_CONNECTED || + sk->sk_state == IUCV_DISCONN || + !newsock) { + iucv_accept_unlink(sk); + if (newsock) + sock_graft(sk, newsock); + + release_sock(sk); + return sk; + } + + release_sock(sk); + } + return NULL; +} + +static void __iucv_auto_name(struct iucv_sock *iucv) +{ + char name[12]; + + sprintf(name, "%08x", atomic_inc_return(&iucv_sk_list.autobind_name)); + while (__iucv_get_sock_by_name(name)) { + sprintf(name, "%08x", + atomic_inc_return(&iucv_sk_list.autobind_name)); + } + memcpy(iucv->src_name, name, 8); +} + +/* Bind an unbound socket */ +static int iucv_sock_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + DECLARE_SOCKADDR(struct sockaddr_iucv *, sa, addr); + char uid[sizeof(sa->siucv_user_id)]; + struct sock *sk = sock->sk; + struct iucv_sock *iucv; + int err = 0; + struct net_device *dev; + + /* Verify the input sockaddr */ + if (addr_len < 
sizeof(struct sockaddr_iucv) || + addr->sa_family != AF_IUCV) + return -EINVAL; + + lock_sock(sk); + if (sk->sk_state != IUCV_OPEN) { + err = -EBADFD; + goto done; + } + + write_lock_bh(&iucv_sk_list.lock); + + iucv = iucv_sk(sk); + if (__iucv_get_sock_by_name(sa->siucv_name)) { + err = -EADDRINUSE; + goto done_unlock; + } + if (iucv->path) + goto done_unlock; + + /* Bind the socket */ + if (pr_iucv) + if (!memcmp(sa->siucv_user_id, iucv_userid, 8)) + goto vm_bind; /* VM IUCV transport */ + + /* try hiper transport */ + memcpy(uid, sa->siucv_user_id, sizeof(uid)); + ASCEBC(uid, 8); + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if (!memcmp(dev->perm_addr, uid, 8)) { + memcpy(iucv->src_user_id, sa->siucv_user_id, 8); + /* Check for uninitialized siucv_name */ + if (strncmp(sa->siucv_name, " ", 8) == 0) + __iucv_auto_name(iucv); + else + memcpy(iucv->src_name, sa->siucv_name, 8); + sk->sk_bound_dev_if = dev->ifindex; + iucv->hs_dev = dev; + dev_hold(dev); + sk->sk_state = IUCV_BOUND; + iucv->transport = AF_IUCV_TRANS_HIPER; + if (!iucv->msglimit) + iucv->msglimit = IUCV_HIPER_MSGLIM_DEFAULT; + rcu_read_unlock(); + goto done_unlock; + } + } + rcu_read_unlock(); +vm_bind: + if (pr_iucv) { + /* use local userid for backward compat */ + memcpy(iucv->src_name, sa->siucv_name, 8); + memcpy(iucv->src_user_id, iucv_userid, 8); + sk->sk_state = IUCV_BOUND; + iucv->transport = AF_IUCV_TRANS_IUCV; + sk->sk_allocation |= GFP_DMA; + if (!iucv->msglimit) + iucv->msglimit = IUCV_QUEUELEN_DEFAULT; + goto done_unlock; + } + /* found no dev to bind */ + err = -ENODEV; +done_unlock: + /* Release the socket list lock */ + write_unlock_bh(&iucv_sk_list.lock); +done: + release_sock(sk); + return err; +} + +/* Automatically bind an unbound socket */ +static int iucv_sock_autobind(struct sock *sk) +{ + struct iucv_sock *iucv = iucv_sk(sk); + int err = 0; + + if (unlikely(!pr_iucv)) + return -EPROTO; + + memcpy(iucv->src_user_id, iucv_userid, 8); + iucv->transport = AF_IUCV_TRANS_IUCV; + sk->sk_allocation |= GFP_DMA; + + write_lock_bh(&iucv_sk_list.lock); + __iucv_auto_name(iucv); + write_unlock_bh(&iucv_sk_list.lock); + + if (!iucv->msglimit) + iucv->msglimit = IUCV_QUEUELEN_DEFAULT; + + return err; +} + +static int afiucv_path_connect(struct socket *sock, struct sockaddr *addr) +{ + DECLARE_SOCKADDR(struct sockaddr_iucv *, sa, addr); + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + unsigned char user_data[16]; + int err; + + high_nmcpy(user_data, sa->siucv_name); + low_nmcpy(user_data, iucv->src_name); + ASCEBC(user_data, sizeof(user_data)); + + /* Create path. 
*/ + iucv->path = iucv_path_alloc(iucv->msglimit, + IUCV_IPRMDATA, GFP_KERNEL); + if (!iucv->path) { + err = -ENOMEM; + goto done; + } + err = pr_iucv->path_connect(iucv->path, &af_iucv_handler, + sa->siucv_user_id, NULL, user_data, + sk); + if (err) { + iucv_path_free(iucv->path); + iucv->path = NULL; + switch (err) { + case 0x0b: /* Target communicator is not logged on */ + err = -ENETUNREACH; + break; + case 0x0d: /* Max connections for this guest exceeded */ + case 0x0e: /* Max connections for target guest exceeded */ + err = -EAGAIN; + break; + case 0x0f: /* Missing IUCV authorization */ + err = -EACCES; + break; + default: + err = -ECONNREFUSED; + break; + } + } +done: + return err; +} + +/* Connect an unconnected socket */ +static int iucv_sock_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_iucv *, sa, addr); + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + int err; + + if (alen < sizeof(struct sockaddr_iucv) || addr->sa_family != AF_IUCV) + return -EINVAL; + + if (sk->sk_state != IUCV_OPEN && sk->sk_state != IUCV_BOUND) + return -EBADFD; + + if (sk->sk_state == IUCV_OPEN && + iucv->transport == AF_IUCV_TRANS_HIPER) + return -EBADFD; /* explicit bind required */ + + if (sk->sk_type != SOCK_STREAM && sk->sk_type != SOCK_SEQPACKET) + return -EINVAL; + + if (sk->sk_state == IUCV_OPEN) { + err = iucv_sock_autobind(sk); + if (unlikely(err)) + return err; + } + + lock_sock(sk); + + /* Set the destination information */ + memcpy(iucv->dst_user_id, sa->siucv_user_id, 8); + memcpy(iucv->dst_name, sa->siucv_name, 8); + + if (iucv->transport == AF_IUCV_TRANS_HIPER) + err = iucv_send_ctrl(sock->sk, AF_IUCV_FLAG_SYN); + else + err = afiucv_path_connect(sock, addr); + if (err) + goto done; + + if (sk->sk_state != IUCV_CONNECTED) + err = iucv_sock_wait(sk, iucv_sock_in_state(sk, IUCV_CONNECTED, + IUCV_DISCONN), + sock_sndtimeo(sk, flags & O_NONBLOCK)); + + if (sk->sk_state == IUCV_DISCONN || sk->sk_state == IUCV_CLOSED) + err = -ECONNREFUSED; + + if (err && iucv->transport == AF_IUCV_TRANS_IUCV) + iucv_sever_path(sk, 0); + +done: + release_sock(sk); + return err; +} + +/* Move a socket into listening state. 
*/ +static int iucv_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err; + + lock_sock(sk); + + err = -EINVAL; + if (sk->sk_state != IUCV_BOUND) + goto done; + + if (sock->type != SOCK_STREAM && sock->type != SOCK_SEQPACKET) + goto done; + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + sk->sk_state = IUCV_LISTEN; + err = 0; + +done: + release_sock(sk); + return err; +} + +/* Accept a pending connection */ +static int iucv_sock_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + DECLARE_WAITQUEUE(wait, current); + struct sock *sk = sock->sk, *nsk; + long timeo; + int err = 0; + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (sk->sk_state != IUCV_LISTEN) { + err = -EBADFD; + goto done; + } + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + /* Wait for an incoming connection */ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (!(nsk = iucv_accept_dequeue(sk, newsock))) { + set_current_state(TASK_INTERRUPTIBLE); + if (!timeo) { + err = -EAGAIN; + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (sk->sk_state != IUCV_LISTEN) { + err = -EBADFD; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + } + + set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + + if (err) + goto done; + + newsock->state = SS_CONNECTED; + +done: + release_sock(sk); + return err; +} + +static int iucv_sock_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + DECLARE_SOCKADDR(struct sockaddr_iucv *, siucv, addr); + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + + addr->sa_family = AF_IUCV; + + if (peer) { + memcpy(siucv->siucv_user_id, iucv->dst_user_id, 8); + memcpy(siucv->siucv_name, iucv->dst_name, 8); + } else { + memcpy(siucv->siucv_user_id, iucv->src_user_id, 8); + memcpy(siucv->siucv_name, iucv->src_name, 8); + } + memset(&siucv->siucv_port, 0, sizeof(siucv->siucv_port)); + memset(&siucv->siucv_addr, 0, sizeof(siucv->siucv_addr)); + memset(&siucv->siucv_nodeid, 0, sizeof(siucv->siucv_nodeid)); + + return sizeof(struct sockaddr_iucv); +} + +/** + * iucv_send_iprm() - Send socket data in parameter list of an iucv message. + * @path: IUCV path + * @msg: Pointer to a struct iucv_message + * @skb: The socket data to send, skb->len MUST BE <= 7 + * + * Send the socket data in the parameter list in the iucv message + * (IUCV_IPRMDATA). The socket data is stored at index 0 to 6 in the parameter + * list and the socket data len at index 7 (last byte). + * See also iucv_msg_length(). + * + * Returns the error code from the iucv_message_send() call. 
+ */ +static int iucv_send_iprm(struct iucv_path *path, struct iucv_message *msg, + struct sk_buff *skb) +{ + u8 prmdata[8]; + + memcpy(prmdata, (void *) skb->data, skb->len); + prmdata[7] = 0xff - (u8) skb->len; + return pr_iucv->message_send(path, msg, IUCV_IPRMDATA, 0, + (void *) prmdata, 8); +} + +static int iucv_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + size_t headroom = 0; + size_t linear; + struct sk_buff *skb; + struct iucv_message txmsg = {0}; + struct cmsghdr *cmsg; + int cmsg_done; + long timeo; + char user_id[9]; + char appl_id[9]; + int err; + int noblock = msg->msg_flags & MSG_DONTWAIT; + + err = sock_error(sk); + if (err) + return err; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* SOCK_SEQPACKET: we do not support segmented records */ + if (sk->sk_type == SOCK_SEQPACKET && !(msg->msg_flags & MSG_EOR)) + return -EOPNOTSUPP; + + lock_sock(sk); + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + err = -EPIPE; + goto out; + } + + /* Return if the socket is not in connected state */ + if (sk->sk_state != IUCV_CONNECTED) { + err = -ENOTCONN; + goto out; + } + + /* initialize defaults */ + cmsg_done = 0; /* check for duplicate headers */ + + /* iterate over control messages */ + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) { + err = -EINVAL; + goto out; + } + + if (cmsg->cmsg_level != SOL_IUCV) + continue; + + if (cmsg->cmsg_type & cmsg_done) { + err = -EINVAL; + goto out; + } + cmsg_done |= cmsg->cmsg_type; + + switch (cmsg->cmsg_type) { + case SCM_IUCV_TRGCLS: + if (cmsg->cmsg_len != CMSG_LEN(TRGCLS_SIZE)) { + err = -EINVAL; + goto out; + } + + /* set iucv message target class */ + memcpy(&txmsg.class, + (void *) CMSG_DATA(cmsg), TRGCLS_SIZE); + + break; + + default: + err = -EINVAL; + goto out; + } + } + + /* allocate one skb for each iucv message: + * this is fine for SOCK_SEQPACKET (unless we want to support + * segmented records using the MSG_EOR flag), but + * for SOCK_STREAM we might want to improve it in future */ + if (iucv->transport == AF_IUCV_TRANS_HIPER) { + headroom = sizeof(struct af_iucv_trans_hdr) + + LL_RESERVED_SPACE(iucv->hs_dev); + linear = min(len, PAGE_SIZE - headroom); + } else { + if (len < PAGE_SIZE) { + linear = len; + } else { + /* In nonlinear "classic" iucv skb, + * reserve space for iucv_array + */ + headroom = sizeof(struct iucv_array) * + (MAX_SKB_FRAGS + 1); + linear = PAGE_SIZE - headroom; + } + } + skb = sock_alloc_send_pskb(sk, headroom + linear, len - linear, + noblock, &err, 0); + if (!skb) + goto out; + if (headroom) + skb_reserve(skb, headroom); + skb_put(skb, linear); + skb->len = len; + skb->data_len = len - linear; + err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, len); + if (err) + goto fail; + + /* wait if outstanding messages for iucv path has reached */ + timeo = sock_sndtimeo(sk, noblock); + err = iucv_sock_wait(sk, iucv_below_msglim(sk), timeo); + if (err) + goto fail; + + /* return -ECONNRESET if the socket is no longer connected */ + if (sk->sk_state != IUCV_CONNECTED) { + err = -ECONNRESET; + goto fail; + } + + /* increment and save iucv message tag for msg_completion cbk */ + txmsg.tag = iucv->send_tag++; + IUCV_SKB_CB(skb)->tag = txmsg.tag; + + if (iucv->transport == AF_IUCV_TRANS_HIPER) { + atomic_inc(&iucv->msg_sent); + err = afiucv_hs_send(&txmsg, sk, skb, 0); + if (err) { + atomic_dec(&iucv->msg_sent); + goto out; + } + } else { /* Classic VM IUCV transport */ + 
skb_queue_tail(&iucv->send_skb_q, skb); + atomic_inc(&iucv->skbs_in_xmit); + + if (((iucv->path->flags & IUCV_IPRMDATA) & iucv->flags) && + skb->len <= 7) { + err = iucv_send_iprm(iucv->path, &txmsg, skb); + + /* on success: there is no message_complete callback */ + /* for an IPRMDATA msg; remove skb from send queue */ + if (err == 0) { + atomic_dec(&iucv->skbs_in_xmit); + skb_unlink(skb, &iucv->send_skb_q); + consume_skb(skb); + } + + /* this error should never happen since the */ + /* IUCV_IPRMDATA path flag is set... sever path */ + if (err == 0x15) { + pr_iucv->path_sever(iucv->path, NULL); + atomic_dec(&iucv->skbs_in_xmit); + skb_unlink(skb, &iucv->send_skb_q); + err = -EPIPE; + goto fail; + } + } else if (skb_is_nonlinear(skb)) { + struct iucv_array *iba = (struct iucv_array *)skb->head; + int i; + + /* skip iucv_array lying in the headroom */ + iba[0].address = (u32)(addr_t)skb->data; + iba[0].length = (u32)skb_headlen(skb); + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + iba[i + 1].address = + (u32)(addr_t)skb_frag_address(frag); + iba[i + 1].length = (u32)skb_frag_size(frag); + } + err = pr_iucv->message_send(iucv->path, &txmsg, + IUCV_IPBUFLST, 0, + (void *)iba, skb->len); + } else { /* non-IPRM Linear skb */ + err = pr_iucv->message_send(iucv->path, &txmsg, + 0, 0, (void *)skb->data, skb->len); + } + if (err) { + if (err == 3) { + user_id[8] = 0; + memcpy(user_id, iucv->dst_user_id, 8); + appl_id[8] = 0; + memcpy(appl_id, iucv->dst_name, 8); + pr_err( + "Application %s on z/VM guest %s exceeds message limit\n", + appl_id, user_id); + err = -EAGAIN; + } else { + err = -EPIPE; + } + + atomic_dec(&iucv->skbs_in_xmit); + skb_unlink(skb, &iucv->send_skb_q); + goto fail; + } + } + + release_sock(sk); + return len; + +fail: + kfree_skb(skb); +out: + release_sock(sk); + return err; +} + +static struct sk_buff *alloc_iucv_recv_skb(unsigned long len) +{ + size_t headroom, linear; + struct sk_buff *skb; + int err; + + if (len < PAGE_SIZE) { + headroom = 0; + linear = len; + } else { + headroom = sizeof(struct iucv_array) * (MAX_SKB_FRAGS + 1); + linear = PAGE_SIZE - headroom; + } + skb = alloc_skb_with_frags(headroom + linear, len - linear, + 0, &err, GFP_ATOMIC | GFP_DMA); + WARN_ONCE(!skb, + "alloc of recv iucv skb len=%lu failed with errcode=%d\n", + len, err); + if (skb) { + if (headroom) + skb_reserve(skb, headroom); + skb_put(skb, linear); + skb->len = len; + skb->data_len = len - linear; + } + return skb; +} + +/* iucv_process_message() - Receive a single outstanding IUCV message + * + * Locking: must be called with message_q.lock held + */ +static void iucv_process_message(struct sock *sk, struct sk_buff *skb, + struct iucv_path *path, + struct iucv_message *msg) +{ + int rc; + unsigned int len; + + len = iucv_msg_length(msg); + + /* store msg target class in the second 4 bytes of skb ctrl buffer */ + /* Note: the first 4 bytes are reserved for msg tag */ + IUCV_SKB_CB(skb)->class = msg->class; + + /* check for special IPRM messages (e.g. 
iucv_sock_shutdown) */ + if ((msg->flags & IUCV_IPRMDATA) && len > 7) { + if (memcmp(msg->rmmsg, iprm_shutdown, 8) == 0) { + skb->data = NULL; + skb->len = 0; + } + } else { + if (skb_is_nonlinear(skb)) { + struct iucv_array *iba = (struct iucv_array *)skb->head; + int i; + + iba[0].address = (u32)(addr_t)skb->data; + iba[0].length = (u32)skb_headlen(skb); + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + iba[i + 1].address = + (u32)(addr_t)skb_frag_address(frag); + iba[i + 1].length = (u32)skb_frag_size(frag); + } + rc = pr_iucv->message_receive(path, msg, + IUCV_IPBUFLST, + (void *)iba, len, NULL); + } else { + rc = pr_iucv->message_receive(path, msg, + msg->flags & IUCV_IPRMDATA, + skb->data, len, NULL); + } + if (rc) { + kfree_skb(skb); + return; + } + WARN_ON_ONCE(skb->len != len); + } + + IUCV_SKB_CB(skb)->offset = 0; + if (sk_filter(sk, skb)) { + atomic_inc(&sk->sk_drops); /* skb rejected by filter */ + kfree_skb(skb); + return; + } + if (__sock_queue_rcv_skb(sk, skb)) /* handle rcv queue full */ + skb_queue_tail(&iucv_sk(sk)->backlog_skb_q, skb); +} + +/* iucv_process_message_q() - Process outstanding IUCV messages + * + * Locking: must be called with message_q.lock held + */ +static void iucv_process_message_q(struct sock *sk) +{ + struct iucv_sock *iucv = iucv_sk(sk); + struct sk_buff *skb; + struct sock_msg_q *p, *n; + + list_for_each_entry_safe(p, n, &iucv->message_q.list, list) { + skb = alloc_iucv_recv_skb(iucv_msg_length(&p->msg)); + if (!skb) + break; + iucv_process_message(sk, skb, p->path, &p->msg); + list_del(&p->list); + kfree(p); + if (!skb_queue_empty(&iucv->backlog_skb_q)) + break; + } +} + +static int iucv_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + unsigned int copied, rlen; + struct sk_buff *skb, *rskb, *cskb; + int err = 0; + u32 offset; + + if ((sk->sk_state == IUCV_DISCONN) && + skb_queue_empty(&iucv->backlog_skb_q) && + skb_queue_empty(&sk->sk_receive_queue) && + list_empty(&iucv->message_q.list)) + return 0; + + if (flags & (MSG_OOB)) + return -EOPNOTSUPP; + + /* receive/dequeue next skb: + * the function understands MSG_PEEK and, thus, does not dequeue skb */ + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) { + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 0; + return err; + } + + offset = IUCV_SKB_CB(skb)->offset; + rlen = skb->len - offset; /* real length of skb */ + copied = min_t(unsigned int, rlen, len); + if (!rlen) + sk->sk_shutdown = sk->sk_shutdown | RCV_SHUTDOWN; + + cskb = skb; + if (skb_copy_datagram_msg(cskb, offset, msg, copied)) { + if (!(flags & MSG_PEEK)) + skb_queue_head(&sk->sk_receive_queue, skb); + return -EFAULT; + } + + /* SOCK_SEQPACKET: set MSG_TRUNC if recv buf size is too small */ + if (sk->sk_type == SOCK_SEQPACKET) { + if (copied < rlen) + msg->msg_flags |= MSG_TRUNC; + /* each iucv message contains a complete record */ + msg->msg_flags |= MSG_EOR; + } + + /* create control message to store iucv msg target class: + * get the trgcls from the control buffer of the skb due to + * fragmentation of original iucv message. 
*/ + err = put_cmsg(msg, SOL_IUCV, SCM_IUCV_TRGCLS, + sizeof(IUCV_SKB_CB(skb)->class), + (void *)&IUCV_SKB_CB(skb)->class); + if (err) { + if (!(flags & MSG_PEEK)) + skb_queue_head(&sk->sk_receive_queue, skb); + return err; + } + + /* Mark read part of skb as used */ + if (!(flags & MSG_PEEK)) { + + /* SOCK_STREAM: re-queue skb if it contains unreceived data */ + if (sk->sk_type == SOCK_STREAM) { + if (copied < rlen) { + IUCV_SKB_CB(skb)->offset = offset + copied; + skb_queue_head(&sk->sk_receive_queue, skb); + goto done; + } + } + + consume_skb(skb); + if (iucv->transport == AF_IUCV_TRANS_HIPER) { + atomic_inc(&iucv->msg_recv); + if (atomic_read(&iucv->msg_recv) > iucv->msglimit) { + WARN_ON(1); + iucv_sock_close(sk); + return -EFAULT; + } + } + + /* Queue backlog skbs */ + spin_lock_bh(&iucv->message_q.lock); + rskb = skb_dequeue(&iucv->backlog_skb_q); + while (rskb) { + IUCV_SKB_CB(rskb)->offset = 0; + if (__sock_queue_rcv_skb(sk, rskb)) { + /* handle rcv queue full */ + skb_queue_head(&iucv->backlog_skb_q, + rskb); + break; + } + rskb = skb_dequeue(&iucv->backlog_skb_q); + } + if (skb_queue_empty(&iucv->backlog_skb_q)) { + if (!list_empty(&iucv->message_q.list)) + iucv_process_message_q(sk); + if (atomic_read(&iucv->msg_recv) >= + iucv->msglimit / 2) { + err = iucv_send_ctrl(sk, AF_IUCV_FLAG_WIN); + if (err) { + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + } + } + } + spin_unlock_bh(&iucv->message_q.lock); + } + +done: + /* SOCK_SEQPACKET: return real length if MSG_TRUNC is set */ + if (sk->sk_type == SOCK_SEQPACKET && (flags & MSG_TRUNC)) + copied = rlen; + + return copied; +} + +static inline __poll_t iucv_accept_poll(struct sock *parent) +{ + struct iucv_sock *isk, *n; + struct sock *sk; + + list_for_each_entry_safe(isk, n, &iucv_sk(parent)->accept_q, accept_q) { + sk = (struct sock *) isk; + + if (sk->sk_state == IUCV_CONNECTED) + return EPOLLIN | EPOLLRDNORM; + } + + return 0; +} + +static __poll_t iucv_sock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask = 0; + + sock_poll_wait(file, sock, wait); + + if (sk->sk_state == IUCV_LISTEN) + return iucv_accept_poll(sk); + + if (sk->sk_err || !skb_queue_empty(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? 
EPOLLPRI : 0); + + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP; + + if (sk->sk_shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + if (!skb_queue_empty(&sk->sk_receive_queue) || + (sk->sk_shutdown & RCV_SHUTDOWN)) + mask |= EPOLLIN | EPOLLRDNORM; + + if (sk->sk_state == IUCV_CLOSED) + mask |= EPOLLHUP; + + if (sk->sk_state == IUCV_DISCONN) + mask |= EPOLLIN; + + if (sock_writeable(sk) && iucv_below_msglim(sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + else + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + return mask; +} + +static int iucv_sock_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + struct iucv_message txmsg; + int err = 0; + + how++; + + if ((how & ~SHUTDOWN_MASK) || !how) + return -EINVAL; + + lock_sock(sk); + switch (sk->sk_state) { + case IUCV_LISTEN: + case IUCV_DISCONN: + case IUCV_CLOSING: + case IUCV_CLOSED: + err = -ENOTCONN; + goto fail; + default: + break; + } + + if ((how == SEND_SHUTDOWN || how == SHUTDOWN_MASK) && + sk->sk_state == IUCV_CONNECTED) { + if (iucv->transport == AF_IUCV_TRANS_IUCV) { + txmsg.class = 0; + txmsg.tag = 0; + err = pr_iucv->message_send(iucv->path, &txmsg, + IUCV_IPRMDATA, 0, (void *) iprm_shutdown, 8); + if (err) { + switch (err) { + case 1: + err = -ENOTCONN; + break; + case 2: + err = -ECONNRESET; + break; + default: + err = -ENOTCONN; + break; + } + } + } else + iucv_send_ctrl(sk, AF_IUCV_FLAG_SHT); + } + + sk->sk_shutdown |= how; + if (how == RCV_SHUTDOWN || how == SHUTDOWN_MASK) { + if ((iucv->transport == AF_IUCV_TRANS_IUCV) && + iucv->path) { + err = pr_iucv->path_quiesce(iucv->path, NULL); + if (err) + err = -ENOTCONN; +/* skb_queue_purge(&sk->sk_receive_queue); */ + } + skb_queue_purge(&sk->sk_receive_queue); + } + + /* Wake up anyone sleeping in poll */ + sk->sk_state_change(sk); + +fail: + release_sock(sk); + return err; +} + +static int iucv_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + int err = 0; + + if (!sk) + return 0; + + iucv_sock_close(sk); + + sock_orphan(sk); + iucv_sock_kill(sk); + return err; +} + +/* getsockopt and setsockopt */ +static int iucv_sock_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + int val; + int rc; + + if (level != SOL_IUCV) + return -ENOPROTOOPT; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + rc = 0; + + lock_sock(sk); + switch (optname) { + case SO_IPRMDATA_MSG: + if (val) + iucv->flags |= IUCV_IPRMDATA; + else + iucv->flags &= ~IUCV_IPRMDATA; + break; + case SO_MSGLIMIT: + switch (sk->sk_state) { + case IUCV_OPEN: + case IUCV_BOUND: + if (val < 1 || val > U16_MAX) + rc = -EINVAL; + else + iucv->msglimit = val; + break; + default: + rc = -EINVAL; + break; + } + break; + default: + rc = -ENOPROTOOPT; + break; + } + release_sock(sk); + + return rc; +} + +static int iucv_sock_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct iucv_sock *iucv = iucv_sk(sk); + unsigned int val; + int len; + + if (level != SOL_IUCV) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < 0) + return -EINVAL; + + len = min_t(unsigned int, len, sizeof(int)); + + switch (optname) { + case SO_IPRMDATA_MSG: + val = (iucv->flags & IUCV_IPRMDATA) ? 
1 : 0; + break; + case SO_MSGLIMIT: + lock_sock(sk); + val = (iucv->path != NULL) ? iucv->path->msglim /* connected */ + : iucv->msglimit; /* default */ + release_sock(sk); + break; + case SO_MSGSIZE: + if (sk->sk_state == IUCV_OPEN) + return -EBADFD; + val = (iucv->hs_dev) ? iucv->hs_dev->mtu - + sizeof(struct af_iucv_trans_hdr) - ETH_HLEN : + 0x7fffffff; + break; + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + + +/* Callback wrappers - called from iucv base support */ +static int iucv_callback_connreq(struct iucv_path *path, + u8 ipvmid[8], u8 ipuser[16]) +{ + unsigned char user_data[16]; + unsigned char nuser_data[16]; + unsigned char src_name[8]; + struct sock *sk, *nsk; + struct iucv_sock *iucv, *niucv; + int err; + + memcpy(src_name, ipuser, 8); + EBCASC(src_name, 8); + /* Find out if this path belongs to af_iucv. */ + read_lock(&iucv_sk_list.lock); + iucv = NULL; + sk = NULL; + sk_for_each(sk, &iucv_sk_list.head) + if (sk->sk_state == IUCV_LISTEN && + !memcmp(&iucv_sk(sk)->src_name, src_name, 8)) { + /* + * Found a listening socket with + * src_name == ipuser[0-7]. + */ + iucv = iucv_sk(sk); + break; + } + read_unlock(&iucv_sk_list.lock); + if (!iucv) + /* No socket found, not one of our paths. */ + return -EINVAL; + + bh_lock_sock(sk); + + /* Check if parent socket is listening */ + low_nmcpy(user_data, iucv->src_name); + high_nmcpy(user_data, iucv->dst_name); + ASCEBC(user_data, sizeof(user_data)); + if (sk->sk_state != IUCV_LISTEN) { + err = pr_iucv->path_sever(path, user_data); + iucv_path_free(path); + goto fail; + } + + /* Check for backlog size */ + if (sk_acceptq_is_full(sk)) { + err = pr_iucv->path_sever(path, user_data); + iucv_path_free(path); + goto fail; + } + + /* Create the new socket */ + nsk = iucv_sock_alloc(NULL, sk->sk_protocol, GFP_ATOMIC, 0); + if (!nsk) { + err = pr_iucv->path_sever(path, user_data); + iucv_path_free(path); + goto fail; + } + + niucv = iucv_sk(nsk); + iucv_sock_init(nsk, sk); + niucv->transport = AF_IUCV_TRANS_IUCV; + nsk->sk_allocation |= GFP_DMA; + + /* Set the new iucv_sock */ + memcpy(niucv->dst_name, ipuser + 8, 8); + EBCASC(niucv->dst_name, 8); + memcpy(niucv->dst_user_id, ipvmid, 8); + memcpy(niucv->src_name, iucv->src_name, 8); + memcpy(niucv->src_user_id, iucv->src_user_id, 8); + niucv->path = path; + + /* Call iucv_accept */ + high_nmcpy(nuser_data, ipuser + 8); + memcpy(nuser_data + 8, niucv->src_name, 8); + ASCEBC(nuser_data + 8, 8); + + /* set message limit for path based on msglimit of accepting socket */ + niucv->msglimit = iucv->msglimit; + path->msglim = iucv->msglimit; + err = pr_iucv->path_accept(path, &af_iucv_handler, nuser_data, nsk); + if (err) { + iucv_sever_path(nsk, 1); + iucv_sock_kill(nsk); + goto fail; + } + + iucv_accept_enqueue(sk, nsk); + + /* Wake up accept */ + nsk->sk_state = IUCV_CONNECTED; + sk->sk_data_ready(sk); + err = 0; +fail: + bh_unlock_sock(sk); + return 0; +} + +static void iucv_callback_connack(struct iucv_path *path, u8 ipuser[16]) +{ + struct sock *sk = path->private; + + sk->sk_state = IUCV_CONNECTED; + sk->sk_state_change(sk); +} + +static void iucv_callback_rx(struct iucv_path *path, struct iucv_message *msg) +{ + struct sock *sk = path->private; + struct iucv_sock *iucv = iucv_sk(sk); + struct sk_buff *skb; + struct sock_msg_q *save_msg; + int len; + + if (sk->sk_shutdown & RCV_SHUTDOWN) { + pr_iucv->message_reject(path, msg); + return; + } + + spin_lock(&iucv->message_q.lock); 
+ + if (!list_empty(&iucv->message_q.list) || + !skb_queue_empty(&iucv->backlog_skb_q)) + goto save_message; + + len = atomic_read(&sk->sk_rmem_alloc); + len += SKB_TRUESIZE(iucv_msg_length(msg)); + if (len > sk->sk_rcvbuf) + goto save_message; + + skb = alloc_iucv_recv_skb(iucv_msg_length(msg)); + if (!skb) + goto save_message; + + iucv_process_message(sk, skb, path, msg); + goto out_unlock; + +save_message: + save_msg = kzalloc(sizeof(struct sock_msg_q), GFP_ATOMIC | GFP_DMA); + if (!save_msg) + goto out_unlock; + save_msg->path = path; + save_msg->msg = *msg; + + list_add_tail(&save_msg->list, &iucv->message_q.list); + +out_unlock: + spin_unlock(&iucv->message_q.lock); +} + +static void iucv_callback_txdone(struct iucv_path *path, + struct iucv_message *msg) +{ + struct sock *sk = path->private; + struct sk_buff *this = NULL; + struct sk_buff_head *list; + struct sk_buff *list_skb; + struct iucv_sock *iucv; + unsigned long flags; + + iucv = iucv_sk(sk); + list = &iucv->send_skb_q; + + bh_lock_sock(sk); + + spin_lock_irqsave(&list->lock, flags); + skb_queue_walk(list, list_skb) { + if (msg->tag == IUCV_SKB_CB(list_skb)->tag) { + this = list_skb; + break; + } + } + if (this) { + atomic_dec(&iucv->skbs_in_xmit); + __skb_unlink(this, list); + } + + spin_unlock_irqrestore(&list->lock, flags); + + if (this) { + consume_skb(this); + /* wake up any process waiting for sending */ + iucv_sock_wake_msglim(sk); + } + + if (sk->sk_state == IUCV_CLOSING) { + if (atomic_read(&iucv->skbs_in_xmit) == 0) { + sk->sk_state = IUCV_CLOSED; + sk->sk_state_change(sk); + } + } + bh_unlock_sock(sk); + +} + +static void iucv_callback_connrej(struct iucv_path *path, u8 ipuser[16]) +{ + struct sock *sk = path->private; + + if (sk->sk_state == IUCV_CLOSED) + return; + + bh_lock_sock(sk); + iucv_sever_path(sk, 1); + sk->sk_state = IUCV_DISCONN; + + sk->sk_state_change(sk); + bh_unlock_sock(sk); +} + +/* called if the other communication side shuts down its RECV direction; + * in turn, the callback sets SEND_SHUTDOWN to disable sending of data. 
+ */ +static void iucv_callback_shutdown(struct iucv_path *path, u8 ipuser[16]) +{ + struct sock *sk = path->private; + + bh_lock_sock(sk); + if (sk->sk_state != IUCV_CLOSED) { + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + } + bh_unlock_sock(sk); +} + +static struct iucv_handler af_iucv_handler = { + .path_pending = iucv_callback_connreq, + .path_complete = iucv_callback_connack, + .path_severed = iucv_callback_connrej, + .message_pending = iucv_callback_rx, + .message_complete = iucv_callback_txdone, + .path_quiesced = iucv_callback_shutdown, +}; + +/***************** HiperSockets transport callbacks ********************/ +static void afiucv_swap_src_dest(struct sk_buff *skb) +{ + struct af_iucv_trans_hdr *trans_hdr = iucv_trans_hdr(skb); + char tmpID[8]; + char tmpName[8]; + + ASCEBC(trans_hdr->destUserID, sizeof(trans_hdr->destUserID)); + ASCEBC(trans_hdr->destAppName, sizeof(trans_hdr->destAppName)); + ASCEBC(trans_hdr->srcUserID, sizeof(trans_hdr->srcUserID)); + ASCEBC(trans_hdr->srcAppName, sizeof(trans_hdr->srcAppName)); + memcpy(tmpID, trans_hdr->srcUserID, 8); + memcpy(tmpName, trans_hdr->srcAppName, 8); + memcpy(trans_hdr->srcUserID, trans_hdr->destUserID, 8); + memcpy(trans_hdr->srcAppName, trans_hdr->destAppName, 8); + memcpy(trans_hdr->destUserID, tmpID, 8); + memcpy(trans_hdr->destAppName, tmpName, 8); + skb_push(skb, ETH_HLEN); + memset(skb->data, 0, ETH_HLEN); +} + +/* + * afiucv_hs_callback_syn - react on received SYN + */ +static int afiucv_hs_callback_syn(struct sock *sk, struct sk_buff *skb) +{ + struct af_iucv_trans_hdr *trans_hdr = iucv_trans_hdr(skb); + struct sock *nsk; + struct iucv_sock *iucv, *niucv; + int err; + + iucv = iucv_sk(sk); + if (!iucv) { + /* no sock - connection refused */ + afiucv_swap_src_dest(skb); + trans_hdr->flags = AF_IUCV_FLAG_SYN | AF_IUCV_FLAG_FIN; + err = dev_queue_xmit(skb); + goto out; + } + + nsk = iucv_sock_alloc(NULL, sk->sk_protocol, GFP_ATOMIC, 0); + bh_lock_sock(sk); + if ((sk->sk_state != IUCV_LISTEN) || + sk_acceptq_is_full(sk) || + !nsk) { + /* error on server socket - connection refused */ + afiucv_swap_src_dest(skb); + trans_hdr->flags = AF_IUCV_FLAG_SYN | AF_IUCV_FLAG_FIN; + err = dev_queue_xmit(skb); + iucv_sock_kill(nsk); + bh_unlock_sock(sk); + goto out; + } + + niucv = iucv_sk(nsk); + iucv_sock_init(nsk, sk); + niucv->transport = AF_IUCV_TRANS_HIPER; + niucv->msglimit = iucv->msglimit; + if (!trans_hdr->window) + niucv->msglimit_peer = IUCV_HIPER_MSGLIM_DEFAULT; + else + niucv->msglimit_peer = trans_hdr->window; + memcpy(niucv->dst_name, trans_hdr->srcAppName, 8); + memcpy(niucv->dst_user_id, trans_hdr->srcUserID, 8); + memcpy(niucv->src_name, iucv->src_name, 8); + memcpy(niucv->src_user_id, iucv->src_user_id, 8); + nsk->sk_bound_dev_if = sk->sk_bound_dev_if; + niucv->hs_dev = iucv->hs_dev; + dev_hold(niucv->hs_dev); + afiucv_swap_src_dest(skb); + trans_hdr->flags = AF_IUCV_FLAG_SYN | AF_IUCV_FLAG_ACK; + trans_hdr->window = niucv->msglimit; + /* if receiver acks the xmit connection is established */ + err = dev_queue_xmit(skb); + if (!err) { + iucv_accept_enqueue(sk, nsk); + nsk->sk_state = IUCV_CONNECTED; + sk->sk_data_ready(sk); + } else + iucv_sock_kill(nsk); + bh_unlock_sock(sk); + +out: + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_callback_synack() - react on received SYN-ACK + */ +static int afiucv_hs_callback_synack(struct sock *sk, struct sk_buff *skb) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (!iucv || sk->sk_state != IUCV_BOUND) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + 
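/*
 * Editorial note (illustrative annotation, not part of the upstream patch):
 * a SYN|ACK means the peer accepted the connection request. Under the socket
 * lock the peer's advertised window from the transport header is recorded as
 * msglimit_peer and the socket moves from IUCV_BOUND to IUCV_CONNECTED;
 * sk_state_change() then wakes any task waiting on the state transition.
 */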
bh_lock_sock(sk); + iucv->msglimit_peer = iucv_trans_hdr(skb)->window; + sk->sk_state = IUCV_CONNECTED; + sk->sk_state_change(sk); + bh_unlock_sock(sk); + consume_skb(skb); + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_callback_synfin() - react on received SYN_FIN + */ +static int afiucv_hs_callback_synfin(struct sock *sk, struct sk_buff *skb) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (!iucv || sk->sk_state != IUCV_BOUND) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + bh_lock_sock(sk); + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + bh_unlock_sock(sk); + consume_skb(skb); + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_callback_fin() - react on received FIN + */ +static int afiucv_hs_callback_fin(struct sock *sk, struct sk_buff *skb) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + /* other end of connection closed */ + if (!iucv) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + bh_lock_sock(sk); + if (sk->sk_state == IUCV_CONNECTED) { + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + } + bh_unlock_sock(sk); + consume_skb(skb); + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_callback_win() - react on received WIN + */ +static int afiucv_hs_callback_win(struct sock *sk, struct sk_buff *skb) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (!iucv) + return NET_RX_SUCCESS; + + if (sk->sk_state != IUCV_CONNECTED) + return NET_RX_SUCCESS; + + atomic_sub(iucv_trans_hdr(skb)->window, &iucv->msg_sent); + iucv_sock_wake_msglim(sk); + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_callback_rx() - react on received data + */ +static int afiucv_hs_callback_rx(struct sock *sk, struct sk_buff *skb) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (!iucv) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + if (sk->sk_state != IUCV_CONNECTED) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + /* write stuff from iucv_msg to skb cb */ + skb_pull(skb, sizeof(struct af_iucv_trans_hdr)); + skb_reset_transport_header(skb); + skb_reset_network_header(skb); + IUCV_SKB_CB(skb)->offset = 0; + if (sk_filter(sk, skb)) { + atomic_inc(&sk->sk_drops); /* skb rejected by filter */ + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + spin_lock(&iucv->message_q.lock); + if (skb_queue_empty(&iucv->backlog_skb_q)) { + if (__sock_queue_rcv_skb(sk, skb)) + /* handle rcv queue full */ + skb_queue_tail(&iucv->backlog_skb_q, skb); + } else + skb_queue_tail(&iucv_sk(sk)->backlog_skb_q, skb); + spin_unlock(&iucv->message_q.lock); + return NET_RX_SUCCESS; +} + +/* + * afiucv_hs_rcv() - base function for arriving data through HiperSockets + * transport + * called from netif RX softirq + */ +static int afiucv_hs_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct sock *sk; + struct iucv_sock *iucv; + struct af_iucv_trans_hdr *trans_hdr; + int err = NET_RX_SUCCESS; + char nullstring[8]; + + if (!pskb_may_pull(skb, sizeof(*trans_hdr))) { + kfree_skb(skb); + return NET_RX_SUCCESS; + } + + trans_hdr = iucv_trans_hdr(skb); + EBCASC(trans_hdr->destAppName, sizeof(trans_hdr->destAppName)); + EBCASC(trans_hdr->destUserID, sizeof(trans_hdr->destUserID)); + EBCASC(trans_hdr->srcAppName, sizeof(trans_hdr->srcAppName)); + EBCASC(trans_hdr->srcUserID, sizeof(trans_hdr->srcUserID)); + memset(nullstring, 0, sizeof(nullstring)); + iucv = NULL; + sk = NULL; + read_lock(&iucv_sk_list.lock); + sk_for_each(sk, &iucv_sk_list.head) { + if (trans_hdr->flags == 
AF_IUCV_FLAG_SYN) { + if ((!memcmp(&iucv_sk(sk)->src_name, + trans_hdr->destAppName, 8)) && + (!memcmp(&iucv_sk(sk)->src_user_id, + trans_hdr->destUserID, 8)) && + (!memcmp(&iucv_sk(sk)->dst_name, nullstring, 8)) && + (!memcmp(&iucv_sk(sk)->dst_user_id, + nullstring, 8))) { + iucv = iucv_sk(sk); + break; + } + } else { + if ((!memcmp(&iucv_sk(sk)->src_name, + trans_hdr->destAppName, 8)) && + (!memcmp(&iucv_sk(sk)->src_user_id, + trans_hdr->destUserID, 8)) && + (!memcmp(&iucv_sk(sk)->dst_name, + trans_hdr->srcAppName, 8)) && + (!memcmp(&iucv_sk(sk)->dst_user_id, + trans_hdr->srcUserID, 8))) { + iucv = iucv_sk(sk); + break; + } + } + } + read_unlock(&iucv_sk_list.lock); + if (!iucv) + sk = NULL; + + /* no sock + how should we send with no sock + 1) send without sock no send rc checking? + 2) introduce default sock to handle this cases + + SYN -> send SYN|ACK in good case, send SYN|FIN in bad case + data -> send FIN + SYN|ACK, SYN|FIN, FIN -> no action? */ + + switch (trans_hdr->flags) { + case AF_IUCV_FLAG_SYN: + /* connect request */ + err = afiucv_hs_callback_syn(sk, skb); + break; + case (AF_IUCV_FLAG_SYN | AF_IUCV_FLAG_ACK): + /* connect request confirmed */ + err = afiucv_hs_callback_synack(sk, skb); + break; + case (AF_IUCV_FLAG_SYN | AF_IUCV_FLAG_FIN): + /* connect request refused */ + err = afiucv_hs_callback_synfin(sk, skb); + break; + case (AF_IUCV_FLAG_FIN): + /* close request */ + err = afiucv_hs_callback_fin(sk, skb); + break; + case (AF_IUCV_FLAG_WIN): + err = afiucv_hs_callback_win(sk, skb); + if (skb->len == sizeof(struct af_iucv_trans_hdr)) { + consume_skb(skb); + break; + } + fallthrough; /* and receive non-zero length data */ + case (AF_IUCV_FLAG_SHT): + /* shutdown request */ + fallthrough; /* and receive zero length data */ + case 0: + /* plain data frame */ + IUCV_SKB_CB(skb)->class = trans_hdr->iucv_hdr.class; + err = afiucv_hs_callback_rx(sk, skb); + break; + default: + kfree_skb(skb); + } + + return err; +} + +/* + * afiucv_hs_callback_txnotify() - handle send notifications from HiperSockets + * transport + */ +static void afiucv_hs_callback_txnotify(struct sock *sk, enum iucv_tx_notify n) +{ + struct iucv_sock *iucv = iucv_sk(sk); + + if (sock_flag(sk, SOCK_ZAPPED)) + return; + + switch (n) { + case TX_NOTIFY_OK: + atomic_dec(&iucv->skbs_in_xmit); + iucv_sock_wake_msglim(sk); + break; + case TX_NOTIFY_PENDING: + atomic_inc(&iucv->pendings); + break; + case TX_NOTIFY_DELAYED_OK: + atomic_dec(&iucv->skbs_in_xmit); + if (atomic_dec_return(&iucv->pendings) <= 0) + iucv_sock_wake_msglim(sk); + break; + default: + atomic_dec(&iucv->skbs_in_xmit); + if (sk->sk_state == IUCV_CONNECTED) { + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + } + } + + if (sk->sk_state == IUCV_CLOSING) { + if (atomic_read(&iucv->skbs_in_xmit) == 0) { + sk->sk_state = IUCV_CLOSED; + sk->sk_state_change(sk); + } + } +} + +/* + * afiucv_netdev_event: handle netdev notifier chain events + */ +static int afiucv_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *event_dev = netdev_notifier_info_to_dev(ptr); + struct sock *sk; + struct iucv_sock *iucv; + + switch (event) { + case NETDEV_REBOOT: + case NETDEV_GOING_DOWN: + sk_for_each(sk, &iucv_sk_list.head) { + iucv = iucv_sk(sk); + if ((iucv->hs_dev == event_dev) && + (sk->sk_state == IUCV_CONNECTED)) { + if (event == NETDEV_GOING_DOWN) + iucv_send_ctrl(sk, AF_IUCV_FLAG_FIN); + sk->sk_state = IUCV_DISCONN; + sk->sk_state_change(sk); + } + } + break; + case NETDEV_DOWN: + case NETDEV_UNREGISTER: 
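/*
 * Editorial note (illustrative annotation, not part of the upstream patch):
 * NETDEV_DOWN and NETDEV_UNREGISTER need no extra handling here and fall
 * through to the default no-op. Only NETDEV_REBOOT and NETDEV_GOING_DOWN
 * (handled above) mark affected HiperSockets-transport sockets as
 * IUCV_DISCONN, sending a FIN first in the GOING_DOWN case.
 */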
+ default: + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block afiucv_netdev_notifier = { + .notifier_call = afiucv_netdev_event, +}; + +static const struct proto_ops iucv_sock_ops = { + .family = PF_IUCV, + .owner = THIS_MODULE, + .release = iucv_sock_release, + .bind = iucv_sock_bind, + .connect = iucv_sock_connect, + .listen = iucv_sock_listen, + .accept = iucv_sock_accept, + .getname = iucv_sock_getname, + .sendmsg = iucv_sock_sendmsg, + .recvmsg = iucv_sock_recvmsg, + .poll = iucv_sock_poll, + .ioctl = sock_no_ioctl, + .mmap = sock_no_mmap, + .socketpair = sock_no_socketpair, + .shutdown = iucv_sock_shutdown, + .setsockopt = iucv_sock_setsockopt, + .getsockopt = iucv_sock_getsockopt, +}; + +static int iucv_sock_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + if (protocol && protocol != PF_IUCV) + return -EPROTONOSUPPORT; + + sock->state = SS_UNCONNECTED; + + switch (sock->type) { + case SOCK_STREAM: + case SOCK_SEQPACKET: + /* currently, proto ops can handle both sk types */ + sock->ops = &iucv_sock_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sk = iucv_sock_alloc(sock, protocol, GFP_KERNEL, kern); + if (!sk) + return -ENOMEM; + + iucv_sock_init(sk, NULL); + + return 0; +} + +static const struct net_proto_family iucv_sock_family_ops = { + .family = AF_IUCV, + .owner = THIS_MODULE, + .create = iucv_sock_create, +}; + +static struct packet_type iucv_packet_type = { + .type = cpu_to_be16(ETH_P_AF_IUCV), + .func = afiucv_hs_rcv, +}; + +static int __init afiucv_init(void) +{ + int err; + + if (MACHINE_IS_VM && IS_ENABLED(CONFIG_IUCV)) { + cpcmd("QUERY USERID", iucv_userid, sizeof(iucv_userid), &err); + if (unlikely(err)) { + WARN_ON(err); + err = -EPROTONOSUPPORT; + goto out; + } + + pr_iucv = &iucv_if; + } else { + memset(&iucv_userid, 0, sizeof(iucv_userid)); + pr_iucv = NULL; + } + + err = proto_register(&iucv_proto, 0); + if (err) + goto out; + err = sock_register(&iucv_sock_family_ops); + if (err) + goto out_proto; + + if (pr_iucv) { + err = pr_iucv->iucv_register(&af_iucv_handler, 0); + if (err) + goto out_sock; + } + + err = register_netdevice_notifier(&afiucv_netdev_notifier); + if (err) + goto out_notifier; + + dev_add_pack(&iucv_packet_type); + return 0; + +out_notifier: + if (pr_iucv) + pr_iucv->iucv_unregister(&af_iucv_handler, 0); +out_sock: + sock_unregister(PF_IUCV); +out_proto: + proto_unregister(&iucv_proto); +out: + return err; +} + +static void __exit afiucv_exit(void) +{ + if (pr_iucv) + pr_iucv->iucv_unregister(&af_iucv_handler, 0); + + unregister_netdevice_notifier(&afiucv_netdev_notifier); + dev_remove_pack(&iucv_packet_type); + sock_unregister(PF_IUCV); + proto_unregister(&iucv_proto); +} + +module_init(afiucv_init); +module_exit(afiucv_exit); + +MODULE_AUTHOR("Jennifer Hunt "); +MODULE_DESCRIPTION("IUCV Sockets ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_IUCV); diff --git a/net/iucv/iucv.c b/net/iucv/iucv.c new file mode 100644 index 000000000..fc3fddeb6 --- /dev/null +++ b/net/iucv/iucv.c @@ -0,0 +1,1908 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IUCV base infrastructure. + * + * Copyright IBM Corp. 2001, 2009 + * + * Author(s): + * Original source: + * Alan Altmark (Alan_Altmark@us.ibm.com) Sept. 
2000 + * Xenia Tkatschow (xenia@us.ibm.com) + * 2Gb awareness and general cleanup: + * Fritz Elfert (elfert@de.ibm.com, felfert@millenux.com) + * Rewritten for af_iucv: + * Martin Schwidefsky + * PM functions: + * Ursula Braun (ursula.braun@de.ibm.com) + * + * Documentation used: + * The original source + * CP Programming Service, IBM document # SC24-5760 + */ + +#define KMSG_COMPONENT "iucv" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * FLAGS: + * All flags are defined in the field IPFLAGS1 of each function + * and can be found in CP Programming Services. + * IPSRCCLS - Indicates you have specified a source class. + * IPTRGCLS - Indicates you have specified a target class. + * IPFGPID - Indicates you have specified a pathid. + * IPFGMID - Indicates you have specified a message ID. + * IPNORPY - Indicates a one-way message. No reply expected. + * IPALL - Indicates that all paths are affected. + */ +#define IUCV_IPSRCCLS 0x01 +#define IUCV_IPTRGCLS 0x01 +#define IUCV_IPFGPID 0x02 +#define IUCV_IPFGMID 0x04 +#define IUCV_IPNORPY 0x10 +#define IUCV_IPALL 0x80 + +static int iucv_bus_match(struct device *dev, struct device_driver *drv) +{ + return 0; +} + +struct bus_type iucv_bus = { + .name = "iucv", + .match = iucv_bus_match, +}; +EXPORT_SYMBOL(iucv_bus); + +struct device *iucv_root; +EXPORT_SYMBOL(iucv_root); + +static int iucv_available; + +/* General IUCV interrupt structure */ +struct iucv_irq_data { + u16 ippathid; + u8 ipflags1; + u8 iptype; + u32 res2[9]; +}; + +struct iucv_irq_list { + struct list_head list; + struct iucv_irq_data data; +}; + +static struct iucv_irq_data *iucv_irq_data[NR_CPUS]; +static cpumask_t iucv_buffer_cpumask = { CPU_BITS_NONE }; +static cpumask_t iucv_irq_cpumask = { CPU_BITS_NONE }; + +/* + * Queue of interrupt buffers lock for delivery via the tasklet + * (fast but can't call smp_call_function). + */ +static LIST_HEAD(iucv_task_queue); + +/* + * The tasklet for fast delivery of iucv interrupts. + */ +static void iucv_tasklet_fn(unsigned long); +static DECLARE_TASKLET_OLD(iucv_tasklet, iucv_tasklet_fn); + +/* + * Queue of interrupt buffers for delivery via a work queue + * (slower but can call smp_call_function). + */ +static LIST_HEAD(iucv_work_queue); + +/* + * The work element to deliver path pending interrupts. + */ +static void iucv_work_fn(struct work_struct *work); +static DECLARE_WORK(iucv_work, iucv_work_fn); + +/* + * Spinlock protecting task and work queue. + */ +static DEFINE_SPINLOCK(iucv_queue_lock); + +enum iucv_command_codes { + IUCV_QUERY = 0, + IUCV_RETRIEVE_BUFFER = 2, + IUCV_SEND = 4, + IUCV_RECEIVE = 5, + IUCV_REPLY = 6, + IUCV_REJECT = 8, + IUCV_PURGE = 9, + IUCV_ACCEPT = 10, + IUCV_CONNECT = 11, + IUCV_DECLARE_BUFFER = 12, + IUCV_QUIESCE = 13, + IUCV_RESUME = 14, + IUCV_SEVER = 15, + IUCV_SETMASK = 16, + IUCV_SETCONTROLMASK = 17, +}; + +/* + * Error messages that are used with the iucv_sever function. They get + * converted to EBCDIC. + */ +static char iucv_error_no_listener[16] = "NO LISTENER"; +static char iucv_error_no_memory[16] = "NO MEMORY"; +static char iucv_error_pathid[16] = "INVALID PATHID"; + +/* + * iucv_handler_list: List of registered handlers. + */ +static LIST_HEAD(iucv_handler_list); + +/* + * iucv_path_table: an array of iucv_path structures. 
+ */ +static struct iucv_path **iucv_path_table; +static unsigned long iucv_max_pathid; + +/* + * iucv_lock: spinlock protecting iucv_handler_list and iucv_pathid_table + */ +static DEFINE_SPINLOCK(iucv_table_lock); + +/* + * iucv_active_cpu: contains the number of the cpu executing the tasklet + * or the work handler. Needed for iucv_path_sever called from tasklet. + */ +static int iucv_active_cpu = -1; + +/* + * Mutex and wait queue for iucv_register/iucv_unregister. + */ +static DEFINE_MUTEX(iucv_register_mutex); + +/* + * Counter for number of non-smp capable handlers. + */ +static int iucv_nonsmp_handler; + +/* + * IUCV control data structure. Used by iucv_path_accept, iucv_path_connect, + * iucv_path_quiesce and iucv_path_sever. + */ +struct iucv_cmd_control { + u16 ippathid; + u8 ipflags1; + u8 iprcode; + u16 ipmsglim; + u16 res1; + u8 ipvmid[8]; + u8 ipuser[16]; + u8 iptarget[8]; +} __attribute__ ((packed,aligned(8))); + +/* + * Data in parameter list iucv structure. Used by iucv_message_send, + * iucv_message_send2way and iucv_message_reply. + */ +struct iucv_cmd_dpl { + u16 ippathid; + u8 ipflags1; + u8 iprcode; + u32 ipmsgid; + u32 iptrgcls; + u8 iprmmsg[8]; + u32 ipsrccls; + u32 ipmsgtag; + u32 ipbfadr2; + u32 ipbfln2f; + u32 res; +} __attribute__ ((packed,aligned(8))); + +/* + * Data in buffer iucv structure. Used by iucv_message_receive, + * iucv_message_reject, iucv_message_send, iucv_message_send2way + * and iucv_declare_cpu. + */ +struct iucv_cmd_db { + u16 ippathid; + u8 ipflags1; + u8 iprcode; + u32 ipmsgid; + u32 iptrgcls; + u32 ipbfadr1; + u32 ipbfln1f; + u32 ipsrccls; + u32 ipmsgtag; + u32 ipbfadr2; + u32 ipbfln2f; + u32 res; +} __attribute__ ((packed,aligned(8))); + +/* + * Purge message iucv structure. Used by iucv_message_purge. + */ +struct iucv_cmd_purge { + u16 ippathid; + u8 ipflags1; + u8 iprcode; + u32 ipmsgid; + u8 ipaudit[3]; + u8 res1[5]; + u32 res2; + u32 ipsrccls; + u32 ipmsgtag; + u32 res3[3]; +} __attribute__ ((packed,aligned(8))); + +/* + * Set mask iucv structure. Used by iucv_enable_cpu. + */ +struct iucv_cmd_set_mask { + u8 ipmask; + u8 res1[2]; + u8 iprcode; + u32 res2[9]; +} __attribute__ ((packed,aligned(8))); + +union iucv_param { + struct iucv_cmd_control ctrl; + struct iucv_cmd_dpl dpl; + struct iucv_cmd_db db; + struct iucv_cmd_purge purge; + struct iucv_cmd_set_mask set_mask; +}; + +/* + * Anchor for per-cpu IUCV command parameter block. + */ +static union iucv_param *iucv_param[NR_CPUS]; +static union iucv_param *iucv_param_irq[NR_CPUS]; + +/** + * __iucv_call_b2f0 + * @command: identifier of IUCV call to CP. + * @parm: pointer to a struct iucv_parm block + * + * Calls CP to execute IUCV commands. + * + * Returns the result of the CP IUCV call. + */ +static inline int __iucv_call_b2f0(int command, union iucv_param *parm) +{ + int cc; + + asm volatile( + " lgr 0,%[reg0]\n" + " lgr 1,%[reg1]\n" + " .long 0xb2f01000\n" + " ipm %[cc]\n" + " srl %[cc],28\n" + : [cc] "=&d" (cc), "+m" (*parm) + : [reg0] "d" ((unsigned long)command), + [reg1] "d" ((unsigned long)parm) + : "cc", "0", "1"); + return cc; +} + +static inline int iucv_call_b2f0(int command, union iucv_param *parm) +{ + int ccode; + + ccode = __iucv_call_b2f0(command, parm); + return ccode == 1 ? parm->ctrl.iprcode : ccode; +} + +/* + * iucv_query_maxconn + * + * Determines the maximum number of connections that may be established. + * + * Returns the maximum number of connections or -EPERM is IUCV is not + * available. 
+ */ +static int __iucv_query_maxconn(void *param, unsigned long *max_pathid) +{ + unsigned long reg1 = virt_to_phys(param); + int cc; + + asm volatile ( + " lghi 0,%[cmd]\n" + " lgr 1,%[reg1]\n" + " .long 0xb2f01000\n" + " ipm %[cc]\n" + " srl %[cc],28\n" + " lgr %[reg1],1\n" + : [cc] "=&d" (cc), [reg1] "+&d" (reg1) + : [cmd] "K" (IUCV_QUERY) + : "cc", "0", "1"); + *max_pathid = reg1; + return cc; +} + +static int iucv_query_maxconn(void) +{ + unsigned long max_pathid; + void *param; + int ccode; + + param = kzalloc(sizeof(union iucv_param), GFP_KERNEL | GFP_DMA); + if (!param) + return -ENOMEM; + ccode = __iucv_query_maxconn(param, &max_pathid); + if (ccode == 0) + iucv_max_pathid = max_pathid; + kfree(param); + return ccode ? -EPERM : 0; +} + +/** + * iucv_allow_cpu + * @data: unused + * + * Allow iucv interrupts on this cpu. + */ +static void iucv_allow_cpu(void *data) +{ + int cpu = smp_processor_id(); + union iucv_param *parm; + + /* + * Enable all iucv interrupts. + * ipmask contains bits for the different interrupts + * 0x80 - Flag to allow nonpriority message pending interrupts + * 0x40 - Flag to allow priority message pending interrupts + * 0x20 - Flag to allow nonpriority message completion interrupts + * 0x10 - Flag to allow priority message completion interrupts + * 0x08 - Flag to allow IUCV control interrupts + */ + parm = iucv_param_irq[cpu]; + memset(parm, 0, sizeof(union iucv_param)); + parm->set_mask.ipmask = 0xf8; + iucv_call_b2f0(IUCV_SETMASK, parm); + + /* + * Enable all iucv control interrupts. + * ipmask contains bits for the different interrupts + * 0x80 - Flag to allow pending connections interrupts + * 0x40 - Flag to allow connection complete interrupts + * 0x20 - Flag to allow connection severed interrupts + * 0x10 - Flag to allow connection quiesced interrupts + * 0x08 - Flag to allow connection resumed interrupts + */ + memset(parm, 0, sizeof(union iucv_param)); + parm->set_mask.ipmask = 0xf8; + iucv_call_b2f0(IUCV_SETCONTROLMASK, parm); + /* Set indication that iucv interrupts are allowed for this cpu. */ + cpumask_set_cpu(cpu, &iucv_irq_cpumask); +} + +/** + * iucv_block_cpu + * @data: unused + * + * Block iucv interrupts on this cpu. + */ +static void iucv_block_cpu(void *data) +{ + int cpu = smp_processor_id(); + union iucv_param *parm; + + /* Disable all iucv interrupts. */ + parm = iucv_param_irq[cpu]; + memset(parm, 0, sizeof(union iucv_param)); + iucv_call_b2f0(IUCV_SETMASK, parm); + + /* Clear indication that iucv interrupts are allowed for this cpu. */ + cpumask_clear_cpu(cpu, &iucv_irq_cpumask); +} + +/** + * iucv_declare_cpu + * @data: unused + * + * Declare a interrupt buffer on this cpu. + */ +static void iucv_declare_cpu(void *data) +{ + int cpu = smp_processor_id(); + union iucv_param *parm; + int rc; + + if (cpumask_test_cpu(cpu, &iucv_buffer_cpumask)) + return; + + /* Declare interrupt buffer. 
*/ + parm = iucv_param_irq[cpu]; + memset(parm, 0, sizeof(union iucv_param)); + parm->db.ipbfadr1 = virt_to_phys(iucv_irq_data[cpu]); + rc = iucv_call_b2f0(IUCV_DECLARE_BUFFER, parm); + if (rc) { + char *err = "Unknown"; + switch (rc) { + case 0x03: + err = "Directory error"; + break; + case 0x0a: + err = "Invalid length"; + break; + case 0x13: + err = "Buffer already exists"; + break; + case 0x3e: + err = "Buffer overlap"; + break; + case 0x5c: + err = "Paging or storage error"; + break; + } + pr_warn("Defining an interrupt buffer on CPU %i failed with 0x%02x (%s)\n", + cpu, rc, err); + return; + } + + /* Set indication that an iucv buffer exists for this cpu. */ + cpumask_set_cpu(cpu, &iucv_buffer_cpumask); + + if (iucv_nonsmp_handler == 0 || cpumask_empty(&iucv_irq_cpumask)) + /* Enable iucv interrupts on this cpu. */ + iucv_allow_cpu(NULL); + else + /* Disable iucv interrupts on this cpu. */ + iucv_block_cpu(NULL); +} + +/** + * iucv_retrieve_cpu + * @data: unused + * + * Retrieve interrupt buffer on this cpu. + */ +static void iucv_retrieve_cpu(void *data) +{ + int cpu = smp_processor_id(); + union iucv_param *parm; + + if (!cpumask_test_cpu(cpu, &iucv_buffer_cpumask)) + return; + + /* Block iucv interrupts. */ + iucv_block_cpu(NULL); + + /* Retrieve interrupt buffer. */ + parm = iucv_param_irq[cpu]; + iucv_call_b2f0(IUCV_RETRIEVE_BUFFER, parm); + + /* Clear indication that an iucv buffer exists for this cpu. */ + cpumask_clear_cpu(cpu, &iucv_buffer_cpumask); +} + +/* + * iucv_setmask_mp + * + * Allow iucv interrupts on all cpus. + */ +static void iucv_setmask_mp(void) +{ + int cpu; + + cpus_read_lock(); + for_each_online_cpu(cpu) + /* Enable all cpus with a declared buffer. */ + if (cpumask_test_cpu(cpu, &iucv_buffer_cpumask) && + !cpumask_test_cpu(cpu, &iucv_irq_cpumask)) + smp_call_function_single(cpu, iucv_allow_cpu, + NULL, 1); + cpus_read_unlock(); +} + +/* + * iucv_setmask_up + * + * Allow iucv interrupts on a single cpu. + */ +static void iucv_setmask_up(void) +{ + cpumask_t cpumask; + int cpu; + + /* Disable all cpu but the first in cpu_irq_cpumask. */ + cpumask_copy(&cpumask, &iucv_irq_cpumask); + cpumask_clear_cpu(cpumask_first(&iucv_irq_cpumask), &cpumask); + for_each_cpu(cpu, &cpumask) + smp_call_function_single(cpu, iucv_block_cpu, NULL, 1); +} + +/* + * iucv_enable + * + * This function makes iucv ready for use. It allocates the pathid + * table, declares an iucv interrupt buffer and enables the iucv + * interrupts. Called when the first user has registered an iucv + * handler. + */ +static int iucv_enable(void) +{ + size_t alloc_size; + int cpu, rc; + + cpus_read_lock(); + rc = -ENOMEM; + alloc_size = iucv_max_pathid * sizeof(struct iucv_path); + iucv_path_table = kzalloc(alloc_size, GFP_KERNEL); + if (!iucv_path_table) + goto out; + /* Declare per cpu buffers. */ + rc = -EIO; + for_each_online_cpu(cpu) + smp_call_function_single(cpu, iucv_declare_cpu, NULL, 1); + if (cpumask_empty(&iucv_buffer_cpumask)) + /* No cpu could declare an iucv buffer. */ + goto out; + cpus_read_unlock(); + return 0; +out: + kfree(iucv_path_table); + iucv_path_table = NULL; + cpus_read_unlock(); + return rc; +} + +/* + * iucv_disable + * + * This function shuts down iucv. It disables iucv interrupts, retrieves + * the iucv interrupt buffer and frees the pathid table. Called after the + * last user unregister its iucv handler. 
+ */ +static void iucv_disable(void) +{ + cpus_read_lock(); + on_each_cpu(iucv_retrieve_cpu, NULL, 1); + kfree(iucv_path_table); + iucv_path_table = NULL; + cpus_read_unlock(); +} + +static int iucv_cpu_dead(unsigned int cpu) +{ + kfree(iucv_param_irq[cpu]); + iucv_param_irq[cpu] = NULL; + kfree(iucv_param[cpu]); + iucv_param[cpu] = NULL; + kfree(iucv_irq_data[cpu]); + iucv_irq_data[cpu] = NULL; + return 0; +} + +static int iucv_cpu_prepare(unsigned int cpu) +{ + /* Note: GFP_DMA used to get memory below 2G */ + iucv_irq_data[cpu] = kmalloc_node(sizeof(struct iucv_irq_data), + GFP_KERNEL|GFP_DMA, cpu_to_node(cpu)); + if (!iucv_irq_data[cpu]) + goto out_free; + + /* Allocate parameter blocks. */ + iucv_param[cpu] = kmalloc_node(sizeof(union iucv_param), + GFP_KERNEL|GFP_DMA, cpu_to_node(cpu)); + if (!iucv_param[cpu]) + goto out_free; + + iucv_param_irq[cpu] = kmalloc_node(sizeof(union iucv_param), + GFP_KERNEL|GFP_DMA, cpu_to_node(cpu)); + if (!iucv_param_irq[cpu]) + goto out_free; + + return 0; + +out_free: + iucv_cpu_dead(cpu); + return -ENOMEM; +} + +static int iucv_cpu_online(unsigned int cpu) +{ + if (!iucv_path_table) + return 0; + iucv_declare_cpu(NULL); + return 0; +} + +static int iucv_cpu_down_prep(unsigned int cpu) +{ + cpumask_t cpumask; + + if (!iucv_path_table) + return 0; + + cpumask_copy(&cpumask, &iucv_buffer_cpumask); + cpumask_clear_cpu(cpu, &cpumask); + if (cpumask_empty(&cpumask)) + /* Can't offline last IUCV enabled cpu. */ + return -EINVAL; + + iucv_retrieve_cpu(NULL); + if (!cpumask_empty(&iucv_irq_cpumask)) + return 0; + smp_call_function_single(cpumask_first(&iucv_buffer_cpumask), + iucv_allow_cpu, NULL, 1); + return 0; +} + +/** + * iucv_sever_pathid + * @pathid: path identification number. + * @userdata: 16-bytes of user data. + * + * Sever an iucv path to free up the pathid. Used internally. + */ +static int iucv_sever_pathid(u16 pathid, u8 *userdata) +{ + union iucv_param *parm; + + parm = iucv_param_irq[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (userdata) + memcpy(parm->ctrl.ipuser, userdata, sizeof(parm->ctrl.ipuser)); + parm->ctrl.ippathid = pathid; + return iucv_call_b2f0(IUCV_SEVER, parm); +} + +/** + * __iucv_cleanup_queue + * @dummy: unused dummy argument + * + * Nop function called via smp_call_function to force work items from + * pending external iucv interrupts to the work queue. + */ +static void __iucv_cleanup_queue(void *dummy) +{ +} + +/** + * iucv_cleanup_queue + * + * Function called after a path has been severed to find all remaining + * work items for the now stale pathid. The caller needs to hold the + * iucv_table_lock. + */ +static void iucv_cleanup_queue(void) +{ + struct iucv_irq_list *p, *n; + + /* + * When a path is severed, the pathid can be reused immediately + * on a iucv connect or a connection pending interrupt. Remove + * all entries from the task queue that refer to a stale pathid + * (iucv_path_table[ix] == NULL). Only then do the iucv connect + * or deliver the connection pending interrupt. To get all the + * pending interrupts force them to the work queue by calling + * an empty function on all cpus. + */ + smp_call_function(__iucv_cleanup_queue, NULL, 1); + spin_lock_irq(&iucv_queue_lock); + list_for_each_entry_safe(p, n, &iucv_task_queue, list) { + /* Remove stale work items from the task queue. 
*/ + if (iucv_path_table[p->data.ippathid] == NULL) { + list_del(&p->list); + kfree(p); + } + } + spin_unlock_irq(&iucv_queue_lock); +} + +/** + * iucv_register: + * @handler: address of iucv handler structure + * @smp: != 0 indicates that the handler can deal with out of order messages + * + * Registers a driver with IUCV. + * + * Returns 0 on success, -ENOMEM if the memory allocation for the pathid + * table failed, or -EIO if IUCV_DECLARE_BUFFER failed on all cpus. + */ +int iucv_register(struct iucv_handler *handler, int smp) +{ + int rc; + + if (!iucv_available) + return -ENOSYS; + mutex_lock(&iucv_register_mutex); + if (!smp) + iucv_nonsmp_handler++; + if (list_empty(&iucv_handler_list)) { + rc = iucv_enable(); + if (rc) + goto out_mutex; + } else if (!smp && iucv_nonsmp_handler == 1) + iucv_setmask_up(); + INIT_LIST_HEAD(&handler->paths); + + spin_lock_bh(&iucv_table_lock); + list_add_tail(&handler->list, &iucv_handler_list); + spin_unlock_bh(&iucv_table_lock); + rc = 0; +out_mutex: + mutex_unlock(&iucv_register_mutex); + return rc; +} +EXPORT_SYMBOL(iucv_register); + +/** + * iucv_unregister + * @handler: address of iucv handler structure + * @smp: != 0 indicates that the handler can deal with out of order messages + * + * Unregister driver from IUCV. + */ +void iucv_unregister(struct iucv_handler *handler, int smp) +{ + struct iucv_path *p, *n; + + mutex_lock(&iucv_register_mutex); + spin_lock_bh(&iucv_table_lock); + /* Remove handler from the iucv_handler_list. */ + list_del_init(&handler->list); + /* Sever all pathids still referring to the handler. */ + list_for_each_entry_safe(p, n, &handler->paths, list) { + iucv_sever_pathid(p->pathid, NULL); + iucv_path_table[p->pathid] = NULL; + list_del(&p->list); + iucv_path_free(p); + } + spin_unlock_bh(&iucv_table_lock); + if (!smp) + iucv_nonsmp_handler--; + if (list_empty(&iucv_handler_list)) + iucv_disable(); + else if (!smp && iucv_nonsmp_handler == 0) + iucv_setmask_mp(); + mutex_unlock(&iucv_register_mutex); +} +EXPORT_SYMBOL(iucv_unregister); + +static int iucv_reboot_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + int i; + + if (cpumask_empty(&iucv_irq_cpumask)) + return NOTIFY_DONE; + + cpus_read_lock(); + on_each_cpu_mask(&iucv_irq_cpumask, iucv_block_cpu, NULL, 1); + preempt_disable(); + for (i = 0; i < iucv_max_pathid; i++) { + if (iucv_path_table[i]) + iucv_sever_pathid(i, NULL); + } + preempt_enable(); + cpus_read_unlock(); + iucv_disable(); + return NOTIFY_DONE; +} + +static struct notifier_block iucv_reboot_notifier = { + .notifier_call = iucv_reboot_event, +}; + +/** + * iucv_path_accept + * @path: address of iucv path structure + * @handler: address of iucv handler structure + * @userdata: 16 bytes of data reflected to the communication partner + * @private: private data passed to interrupt handlers for this path + * + * This function is issued after the user received a connection pending + * external interrupt and now wishes to complete the IUCV communication path. + * + * Returns the result of the CP IUCV call. + */ +int iucv_path_accept(struct iucv_path *path, struct iucv_handler *handler, + u8 *userdata, void *private) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + /* Prepare parameter block. 
*/ + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + parm->ctrl.ippathid = path->pathid; + parm->ctrl.ipmsglim = path->msglim; + if (userdata) + memcpy(parm->ctrl.ipuser, userdata, sizeof(parm->ctrl.ipuser)); + parm->ctrl.ipflags1 = path->flags; + + rc = iucv_call_b2f0(IUCV_ACCEPT, parm); + if (!rc) { + path->private = private; + path->msglim = parm->ctrl.ipmsglim; + path->flags = parm->ctrl.ipflags1; + } +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_path_accept); + +/** + * iucv_path_connect + * @path: address of iucv path structure + * @handler: address of iucv handler structure + * @userid: 8-byte user identification + * @system: 8-byte target system identification + * @userdata: 16 bytes of data reflected to the communication partner + * @private: private data passed to interrupt handlers for this path + * + * This function establishes an IUCV path. Although the connect may complete + * successfully, you are not able to use the path until you receive an IUCV + * Connection Complete external interrupt. + * + * Returns the result of the CP IUCV call. + */ +int iucv_path_connect(struct iucv_path *path, struct iucv_handler *handler, + u8 *userid, u8 *system, u8 *userdata, + void *private) +{ + union iucv_param *parm; + int rc; + + spin_lock_bh(&iucv_table_lock); + iucv_cleanup_queue(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + parm->ctrl.ipmsglim = path->msglim; + parm->ctrl.ipflags1 = path->flags; + if (userid) { + memcpy(parm->ctrl.ipvmid, userid, sizeof(parm->ctrl.ipvmid)); + ASCEBC(parm->ctrl.ipvmid, sizeof(parm->ctrl.ipvmid)); + EBC_TOUPPER(parm->ctrl.ipvmid, sizeof(parm->ctrl.ipvmid)); + } + if (system) { + memcpy(parm->ctrl.iptarget, system, + sizeof(parm->ctrl.iptarget)); + ASCEBC(parm->ctrl.iptarget, sizeof(parm->ctrl.iptarget)); + EBC_TOUPPER(parm->ctrl.iptarget, sizeof(parm->ctrl.iptarget)); + } + if (userdata) + memcpy(parm->ctrl.ipuser, userdata, sizeof(parm->ctrl.ipuser)); + + rc = iucv_call_b2f0(IUCV_CONNECT, parm); + if (!rc) { + if (parm->ctrl.ippathid < iucv_max_pathid) { + path->pathid = parm->ctrl.ippathid; + path->msglim = parm->ctrl.ipmsglim; + path->flags = parm->ctrl.ipflags1; + path->handler = handler; + path->private = private; + list_add_tail(&path->list, &handler->paths); + iucv_path_table[path->pathid] = path; + } else { + iucv_sever_pathid(parm->ctrl.ippathid, + iucv_error_pathid); + rc = -EIO; + } + } +out: + spin_unlock_bh(&iucv_table_lock); + return rc; +} +EXPORT_SYMBOL(iucv_path_connect); + +/** + * iucv_path_quiesce: + * @path: address of iucv path structure + * @userdata: 16 bytes of data reflected to the communication partner + * + * This function temporarily suspends incoming messages on an IUCV path. + * You can later reactivate the path by invoking the iucv_resume function. + * + * Returns the result from the CP IUCV call. 
+ */ +int iucv_path_quiesce(struct iucv_path *path, u8 *userdata) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (userdata) + memcpy(parm->ctrl.ipuser, userdata, sizeof(parm->ctrl.ipuser)); + parm->ctrl.ippathid = path->pathid; + rc = iucv_call_b2f0(IUCV_QUIESCE, parm); +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_path_quiesce); + +/** + * iucv_path_resume: + * @path: address of iucv path structure + * @userdata: 16 bytes of data reflected to the communication partner + * + * This function resumes incoming messages on an IUCV path that has + * been stopped with iucv_path_quiesce. + * + * Returns the result from the CP IUCV call. + */ +int iucv_path_resume(struct iucv_path *path, u8 *userdata) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (userdata) + memcpy(parm->ctrl.ipuser, userdata, sizeof(parm->ctrl.ipuser)); + parm->ctrl.ippathid = path->pathid; + rc = iucv_call_b2f0(IUCV_RESUME, parm); +out: + local_bh_enable(); + return rc; +} + +/** + * iucv_path_sever + * @path: address of iucv path structure + * @userdata: 16 bytes of data reflected to the communication partner + * + * This function terminates an IUCV path. + * + * Returns the result from the CP IUCV call. + */ +int iucv_path_sever(struct iucv_path *path, u8 *userdata) +{ + int rc; + + preempt_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + if (iucv_active_cpu != smp_processor_id()) + spin_lock_bh(&iucv_table_lock); + rc = iucv_sever_pathid(path->pathid, userdata); + iucv_path_table[path->pathid] = NULL; + list_del_init(&path->list); + if (iucv_active_cpu != smp_processor_id()) + spin_unlock_bh(&iucv_table_lock); +out: + preempt_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_path_sever); + +/** + * iucv_message_purge + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @srccls: source class of message + * + * Cancels a message you have sent. + * + * Returns the result from the CP IUCV call. + */ +int iucv_message_purge(struct iucv_path *path, struct iucv_message *msg, + u32 srccls) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + parm->purge.ippathid = path->pathid; + parm->purge.ipmsgid = msg->id; + parm->purge.ipsrccls = srccls; + parm->purge.ipflags1 = IUCV_IPSRCCLS | IUCV_IPFGMID | IUCV_IPFGPID; + rc = iucv_call_b2f0(IUCV_PURGE, parm); + if (!rc) { + msg->audit = (*(u32 *) &parm->purge.ipaudit) >> 8; + msg->tag = parm->purge.ipmsgtag; + } +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_purge); + +/** + * iucv_message_receive_iprmdata + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is received (IUCV_IPBUFLST) + * @buffer: address of data buffer or address of struct iucv_array + * @size: length of data buffer + * @residual: + * + * Internal function used by iucv_message_receive and __iucv_message_receive + * to receive RMDATA data stored in struct iucv_message. 
+ */ +static int iucv_message_receive_iprmdata(struct iucv_path *path, + struct iucv_message *msg, + u8 flags, void *buffer, + size_t size, size_t *residual) +{ + struct iucv_array *array; + u8 *rmmsg; + size_t copy; + + /* + * Message is 8 bytes long and has been stored to the + * message descriptor itself. + */ + if (residual) + *residual = abs(size - 8); + rmmsg = msg->rmmsg; + if (flags & IUCV_IPBUFLST) { + /* Copy to struct iucv_array. */ + size = (size < 8) ? size : 8; + for (array = buffer; size > 0; array++) { + copy = min_t(size_t, size, array->length); + memcpy((u8 *)(addr_t) array->address, + rmmsg, copy); + rmmsg += copy; + size -= copy; + } + } else { + /* Copy to direct buffer. */ + memcpy(buffer, rmmsg, min_t(size_t, size, 8)); + } + return 0; +} + +/** + * __iucv_message_receive + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is received (IUCV_IPBUFLST) + * @buffer: address of data buffer or address of struct iucv_array + * @size: length of data buffer + * @residual: + * + * This function receives messages that are being sent to you over + * established paths. This function will deal with RMDATA messages + * embedded in struct iucv_message as well. + * + * Locking: no locking + * + * Returns the result from the CP IUCV call. + */ +int __iucv_message_receive(struct iucv_path *path, struct iucv_message *msg, + u8 flags, void *buffer, size_t size, size_t *residual) +{ + union iucv_param *parm; + int rc; + + if (msg->flags & IUCV_IPRMDATA) + return iucv_message_receive_iprmdata(path, msg, flags, + buffer, size, residual); + if (cpumask_empty(&iucv_buffer_cpumask)) + return -EIO; + + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + parm->db.ipbfadr1 = (u32)(addr_t) buffer; + parm->db.ipbfln1f = (u32) size; + parm->db.ipmsgid = msg->id; + parm->db.ippathid = path->pathid; + parm->db.iptrgcls = msg->class; + parm->db.ipflags1 = (flags | IUCV_IPFGPID | + IUCV_IPFGMID | IUCV_IPTRGCLS); + rc = iucv_call_b2f0(IUCV_RECEIVE, parm); + if (!rc || rc == 5) { + msg->flags = parm->db.ipflags1; + if (residual) + *residual = parm->db.ipbfln1f; + } + return rc; +} +EXPORT_SYMBOL(__iucv_message_receive); + +/** + * iucv_message_receive + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is received (IUCV_IPBUFLST) + * @buffer: address of data buffer or address of struct iucv_array + * @size: length of data buffer + * @residual: + * + * This function receives messages that are being sent to you over + * established paths. This function will deal with RMDATA messages + * embedded in struct iucv_message as well. + * + * Locking: local_bh_enable/local_bh_disable + * + * Returns the result from the CP IUCV call. + */ +int iucv_message_receive(struct iucv_path *path, struct iucv_message *msg, + u8 flags, void *buffer, size_t size, size_t *residual) +{ + int rc; + + if (msg->flags & IUCV_IPRMDATA) + return iucv_message_receive_iprmdata(path, msg, flags, + buffer, size, residual); + local_bh_disable(); + rc = __iucv_message_receive(path, msg, flags, buffer, size, residual); + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_receive); + +/** + * iucv_message_reject + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * + * The reject function refuses a specified message. Between the time you + * are notified of a message and the time that you complete the message, + * the message may be rejected. 
+ * + * Returns the result from the CP IUCV call. + */ +int iucv_message_reject(struct iucv_path *path, struct iucv_message *msg) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + parm->db.ippathid = path->pathid; + parm->db.ipmsgid = msg->id; + parm->db.iptrgcls = msg->class; + parm->db.ipflags1 = (IUCV_IPTRGCLS | IUCV_IPFGMID | IUCV_IPFGPID); + rc = iucv_call_b2f0(IUCV_REJECT, parm); +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_reject); + +/** + * iucv_message_reply + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the reply is sent (IUCV_IPRMDATA, IUCV_IPPRTY, IUCV_IPBUFLST) + * @reply: address of reply data buffer or address of struct iucv_array + * @size: length of reply data buffer + * + * This function responds to the two-way messages that you receive. You + * must identify completely the message to which you wish to reply. ie, + * pathid, msgid, and trgcls. Prmmsg signifies the data is moved into + * the parameter list. + * + * Returns the result from the CP IUCV call. + */ +int iucv_message_reply(struct iucv_path *path, struct iucv_message *msg, + u8 flags, void *reply, size_t size) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (flags & IUCV_IPRMDATA) { + parm->dpl.ippathid = path->pathid; + parm->dpl.ipflags1 = flags; + parm->dpl.ipmsgid = msg->id; + parm->dpl.iptrgcls = msg->class; + memcpy(parm->dpl.iprmmsg, reply, min_t(size_t, size, 8)); + } else { + parm->db.ipbfadr1 = (u32)(addr_t) reply; + parm->db.ipbfln1f = (u32) size; + parm->db.ippathid = path->pathid; + parm->db.ipflags1 = flags; + parm->db.ipmsgid = msg->id; + parm->db.iptrgcls = msg->class; + } + rc = iucv_call_b2f0(IUCV_REPLY, parm); +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_reply); + +/** + * __iucv_message_send + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is sent (IUCV_IPRMDATA, IUCV_IPPRTY, IUCV_IPBUFLST) + * @srccls: source class of message + * @buffer: address of send buffer or address of struct iucv_array + * @size: length of send buffer + * + * This function transmits data to another application. Data to be + * transmitted is in a buffer and this is a one-way message and the + * receiver will not reply to the message. + * + * Locking: no locking + * + * Returns the result from the CP IUCV call. + */ +int __iucv_message_send(struct iucv_path *path, struct iucv_message *msg, + u8 flags, u32 srccls, void *buffer, size_t size) +{ + union iucv_param *parm; + int rc; + + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (flags & IUCV_IPRMDATA) { + /* Message of 8 bytes can be placed into the parameter list. 
*/ + parm->dpl.ippathid = path->pathid; + parm->dpl.ipflags1 = flags | IUCV_IPNORPY; + parm->dpl.iptrgcls = msg->class; + parm->dpl.ipsrccls = srccls; + parm->dpl.ipmsgtag = msg->tag; + memcpy(parm->dpl.iprmmsg, buffer, 8); + } else { + parm->db.ipbfadr1 = (u32)(addr_t) buffer; + parm->db.ipbfln1f = (u32) size; + parm->db.ippathid = path->pathid; + parm->db.ipflags1 = flags | IUCV_IPNORPY; + parm->db.iptrgcls = msg->class; + parm->db.ipsrccls = srccls; + parm->db.ipmsgtag = msg->tag; + } + rc = iucv_call_b2f0(IUCV_SEND, parm); + if (!rc) + msg->id = parm->db.ipmsgid; +out: + return rc; +} +EXPORT_SYMBOL(__iucv_message_send); + +/** + * iucv_message_send + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is sent (IUCV_IPRMDATA, IUCV_IPPRTY, IUCV_IPBUFLST) + * @srccls: source class of message + * @buffer: address of send buffer or address of struct iucv_array + * @size: length of send buffer + * + * This function transmits data to another application. Data to be + * transmitted is in a buffer and this is a one-way message and the + * receiver will not reply to the message. + * + * Locking: local_bh_enable/local_bh_disable + * + * Returns the result from the CP IUCV call. + */ +int iucv_message_send(struct iucv_path *path, struct iucv_message *msg, + u8 flags, u32 srccls, void *buffer, size_t size) +{ + int rc; + + local_bh_disable(); + rc = __iucv_message_send(path, msg, flags, srccls, buffer, size); + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_send); + +/** + * iucv_message_send2way + * @path: address of iucv path structure + * @msg: address of iucv msg structure + * @flags: how the message is sent and the reply is received + * (IUCV_IPRMDATA, IUCV_IPBUFLST, IUCV_IPPRTY, IUCV_ANSLST) + * @srccls: source class of message + * @buffer: address of send buffer or address of struct iucv_array + * @size: length of send buffer + * @answer: address of answer buffer or address of struct iucv_array + * @asize: size of reply buffer + * @residual: ignored + * + * This function transmits data to another application. Data to be + * transmitted is in a buffer. The receiver of the send is expected to + * reply to the message and a buffer is provided into which IUCV moves + * the reply to this message. + * + * Returns the result from the CP IUCV call. 
+ */ +int iucv_message_send2way(struct iucv_path *path, struct iucv_message *msg, + u8 flags, u32 srccls, void *buffer, size_t size, + void *answer, size_t asize, size_t *residual) +{ + union iucv_param *parm; + int rc; + + local_bh_disable(); + if (cpumask_empty(&iucv_buffer_cpumask)) { + rc = -EIO; + goto out; + } + parm = iucv_param[smp_processor_id()]; + memset(parm, 0, sizeof(union iucv_param)); + if (flags & IUCV_IPRMDATA) { + parm->dpl.ippathid = path->pathid; + parm->dpl.ipflags1 = path->flags; /* priority message */ + parm->dpl.iptrgcls = msg->class; + parm->dpl.ipsrccls = srccls; + parm->dpl.ipmsgtag = msg->tag; + parm->dpl.ipbfadr2 = (u32)(addr_t) answer; + parm->dpl.ipbfln2f = (u32) asize; + memcpy(parm->dpl.iprmmsg, buffer, 8); + } else { + parm->db.ippathid = path->pathid; + parm->db.ipflags1 = path->flags; /* priority message */ + parm->db.iptrgcls = msg->class; + parm->db.ipsrccls = srccls; + parm->db.ipmsgtag = msg->tag; + parm->db.ipbfadr1 = (u32)(addr_t) buffer; + parm->db.ipbfln1f = (u32) size; + parm->db.ipbfadr2 = (u32)(addr_t) answer; + parm->db.ipbfln2f = (u32) asize; + } + rc = iucv_call_b2f0(IUCV_SEND, parm); + if (!rc) + msg->id = parm->db.ipmsgid; +out: + local_bh_enable(); + return rc; +} +EXPORT_SYMBOL(iucv_message_send2way); + +struct iucv_path_pending { + u16 ippathid; + u8 ipflags1; + u8 iptype; + u16 ipmsglim; + u16 res1; + u8 ipvmid[8]; + u8 ipuser[16]; + u32 res3; + u8 ippollfg; + u8 res4[3]; +} __packed; + +/** + * iucv_path_pending + * @data: Pointer to external interrupt buffer + * + * Process connection pending work item. Called from tasklet while holding + * iucv_table_lock. + */ +static void iucv_path_pending(struct iucv_irq_data *data) +{ + struct iucv_path_pending *ipp = (void *) data; + struct iucv_handler *handler; + struct iucv_path *path; + char *error; + + BUG_ON(iucv_path_table[ipp->ippathid]); + /* New pathid, handler found. Create a new path struct. */ + error = iucv_error_no_memory; + path = iucv_path_alloc(ipp->ipmsglim, ipp->ipflags1, GFP_ATOMIC); + if (!path) + goto out_sever; + path->pathid = ipp->ippathid; + iucv_path_table[path->pathid] = path; + EBCASC(ipp->ipvmid, 8); + + /* Call registered handler until one is found that wants the path. */ + list_for_each_entry(handler, &iucv_handler_list, list) { + if (!handler->path_pending) + continue; + /* + * Add path to handler to allow a call to iucv_path_sever + * inside the path_pending function. If the handler returns + * an error remove the path from the handler again. + */ + list_add(&path->list, &handler->paths); + path->handler = handler; + if (!handler->path_pending(path, ipp->ipvmid, ipp->ipuser)) + return; + list_del(&path->list); + path->handler = NULL; + } + /* No handler wanted the path. */ + iucv_path_table[path->pathid] = NULL; + iucv_path_free(path); + error = iucv_error_no_listener; +out_sever: + iucv_sever_pathid(ipp->ippathid, error); +} + +struct iucv_path_complete { + u16 ippathid; + u8 ipflags1; + u8 iptype; + u16 ipmsglim; + u16 res1; + u8 res2[8]; + u8 ipuser[16]; + u32 res3; + u8 ippollfg; + u8 res4[3]; +} __packed; + +/** + * iucv_path_complete + * @data: Pointer to external interrupt buffer + * + * Process connection complete work item. Called from tasklet while holding + * iucv_table_lock. 
+ */ +static void iucv_path_complete(struct iucv_irq_data *data) +{ + struct iucv_path_complete *ipc = (void *) data; + struct iucv_path *path = iucv_path_table[ipc->ippathid]; + + if (path) + path->flags = ipc->ipflags1; + if (path && path->handler && path->handler->path_complete) + path->handler->path_complete(path, ipc->ipuser); +} + +struct iucv_path_severed { + u16 ippathid; + u8 res1; + u8 iptype; + u32 res2; + u8 res3[8]; + u8 ipuser[16]; + u32 res4; + u8 ippollfg; + u8 res5[3]; +} __packed; + +/** + * iucv_path_severed + * @data: Pointer to external interrupt buffer + * + * Process connection severed work item. Called from tasklet while holding + * iucv_table_lock. + */ +static void iucv_path_severed(struct iucv_irq_data *data) +{ + struct iucv_path_severed *ips = (void *) data; + struct iucv_path *path = iucv_path_table[ips->ippathid]; + + if (!path || !path->handler) /* Already severed */ + return; + if (path->handler->path_severed) + path->handler->path_severed(path, ips->ipuser); + else { + iucv_sever_pathid(path->pathid, NULL); + iucv_path_table[path->pathid] = NULL; + list_del(&path->list); + iucv_path_free(path); + } +} + +struct iucv_path_quiesced { + u16 ippathid; + u8 res1; + u8 iptype; + u32 res2; + u8 res3[8]; + u8 ipuser[16]; + u32 res4; + u8 ippollfg; + u8 res5[3]; +} __packed; + +/** + * iucv_path_quiesced + * @data: Pointer to external interrupt buffer + * + * Process connection quiesced work item. Called from tasklet while holding + * iucv_table_lock. + */ +static void iucv_path_quiesced(struct iucv_irq_data *data) +{ + struct iucv_path_quiesced *ipq = (void *) data; + struct iucv_path *path = iucv_path_table[ipq->ippathid]; + + if (path && path->handler && path->handler->path_quiesced) + path->handler->path_quiesced(path, ipq->ipuser); +} + +struct iucv_path_resumed { + u16 ippathid; + u8 res1; + u8 iptype; + u32 res2; + u8 res3[8]; + u8 ipuser[16]; + u32 res4; + u8 ippollfg; + u8 res5[3]; +} __packed; + +/** + * iucv_path_resumed + * @data: Pointer to external interrupt buffer + * + * Process connection resumed work item. Called from tasklet while holding + * iucv_table_lock. + */ +static void iucv_path_resumed(struct iucv_irq_data *data) +{ + struct iucv_path_resumed *ipr = (void *) data; + struct iucv_path *path = iucv_path_table[ipr->ippathid]; + + if (path && path->handler && path->handler->path_resumed) + path->handler->path_resumed(path, ipr->ipuser); +} + +struct iucv_message_complete { + u16 ippathid; + u8 ipflags1; + u8 iptype; + u32 ipmsgid; + u32 ipaudit; + u8 iprmmsg[8]; + u32 ipsrccls; + u32 ipmsgtag; + u32 res; + u32 ipbfln2f; + u8 ippollfg; + u8 res2[3]; +} __packed; + +/** + * iucv_message_complete + * @data: Pointer to external interrupt buffer + * + * Process message complete work item. Called from tasklet while holding + * iucv_table_lock. 
+ */ +static void iucv_message_complete(struct iucv_irq_data *data) +{ + struct iucv_message_complete *imc = (void *) data; + struct iucv_path *path = iucv_path_table[imc->ippathid]; + struct iucv_message msg; + + if (path && path->handler && path->handler->message_complete) { + msg.flags = imc->ipflags1; + msg.id = imc->ipmsgid; + msg.audit = imc->ipaudit; + memcpy(msg.rmmsg, imc->iprmmsg, 8); + msg.class = imc->ipsrccls; + msg.tag = imc->ipmsgtag; + msg.length = imc->ipbfln2f; + path->handler->message_complete(path, &msg); + } +} + +struct iucv_message_pending { + u16 ippathid; + u8 ipflags1; + u8 iptype; + u32 ipmsgid; + u32 iptrgcls; + struct { + union { + u32 iprmmsg1_u32; + u8 iprmmsg1[4]; + } ln1msg1; + union { + u32 ipbfln1f; + u8 iprmmsg2[4]; + } ln1msg2; + } rmmsg; + u32 res1[3]; + u32 ipbfln2f; + u8 ippollfg; + u8 res2[3]; +} __packed; + +/** + * iucv_message_pending + * @data: Pointer to external interrupt buffer + * + * Process message pending work item. Called from tasklet while holding + * iucv_table_lock. + */ +static void iucv_message_pending(struct iucv_irq_data *data) +{ + struct iucv_message_pending *imp = (void *) data; + struct iucv_path *path = iucv_path_table[imp->ippathid]; + struct iucv_message msg; + + if (path && path->handler && path->handler->message_pending) { + msg.flags = imp->ipflags1; + msg.id = imp->ipmsgid; + msg.class = imp->iptrgcls; + if (imp->ipflags1 & IUCV_IPRMDATA) { + memcpy(msg.rmmsg, &imp->rmmsg, 8); + msg.length = 8; + } else + msg.length = imp->rmmsg.ln1msg2.ipbfln1f; + msg.reply_size = imp->ipbfln2f; + path->handler->message_pending(path, &msg); + } +} + +/* + * iucv_tasklet_fn: + * + * This tasklet loops over the queue of irq buffers created by + * iucv_external_interrupt, calls the appropriate action handler + * and then frees the buffer. + */ +static void iucv_tasklet_fn(unsigned long ignored) +{ + typedef void iucv_irq_fn(struct iucv_irq_data *); + static iucv_irq_fn *irq_fn[] = { + [0x02] = iucv_path_complete, + [0x03] = iucv_path_severed, + [0x04] = iucv_path_quiesced, + [0x05] = iucv_path_resumed, + [0x06] = iucv_message_complete, + [0x07] = iucv_message_complete, + [0x08] = iucv_message_pending, + [0x09] = iucv_message_pending, + }; + LIST_HEAD(task_queue); + struct iucv_irq_list *p, *n; + + /* Serialize tasklet, iucv_path_sever and iucv_path_connect. */ + if (!spin_trylock(&iucv_table_lock)) { + tasklet_schedule(&iucv_tasklet); + return; + } + iucv_active_cpu = smp_processor_id(); + + spin_lock_irq(&iucv_queue_lock); + list_splice_init(&iucv_task_queue, &task_queue); + spin_unlock_irq(&iucv_queue_lock); + + list_for_each_entry_safe(p, n, &task_queue, list) { + list_del_init(&p->list); + irq_fn[p->data.iptype](&p->data); + kfree(p); + } + + iucv_active_cpu = -1; + spin_unlock(&iucv_table_lock); +} + +/* + * iucv_work_fn: + * + * This work function loops over the queue of path pending irq blocks + * created by iucv_external_interrupt, calls the appropriate action + * handler and then frees the buffer. + */ +static void iucv_work_fn(struct work_struct *work) +{ + LIST_HEAD(work_queue); + struct iucv_irq_list *p, *n; + + /* Serialize tasklet, iucv_path_sever and iucv_path_connect. 
*/ + spin_lock_bh(&iucv_table_lock); + iucv_active_cpu = smp_processor_id(); + + spin_lock_irq(&iucv_queue_lock); + list_splice_init(&iucv_work_queue, &work_queue); + spin_unlock_irq(&iucv_queue_lock); + + iucv_cleanup_queue(); + list_for_each_entry_safe(p, n, &work_queue, list) { + list_del_init(&p->list); + iucv_path_pending(&p->data); + kfree(p); + } + + iucv_active_cpu = -1; + spin_unlock_bh(&iucv_table_lock); +} + +/* + * iucv_external_interrupt + * + * Handles external interrupts coming in from CP. + * Places the interrupt buffer on a queue and schedules iucv_tasklet_fn(). + */ +static void iucv_external_interrupt(struct ext_code ext_code, + unsigned int param32, unsigned long param64) +{ + struct iucv_irq_data *p; + struct iucv_irq_list *work; + + inc_irq_stat(IRQEXT_IUC); + p = iucv_irq_data[smp_processor_id()]; + if (p->ippathid >= iucv_max_pathid) { + WARN_ON(p->ippathid >= iucv_max_pathid); + iucv_sever_pathid(p->ippathid, iucv_error_no_listener); + return; + } + BUG_ON(p->iptype < 0x01 || p->iptype > 0x09); + work = kmalloc(sizeof(struct iucv_irq_list), GFP_ATOMIC); + if (!work) { + pr_warn("iucv_external_interrupt: out of memory\n"); + return; + } + memcpy(&work->data, p, sizeof(work->data)); + spin_lock(&iucv_queue_lock); + if (p->iptype == 0x01) { + /* Path pending interrupt. */ + list_add_tail(&work->list, &iucv_work_queue); + schedule_work(&iucv_work); + } else { + /* The other interrupts. */ + list_add_tail(&work->list, &iucv_task_queue); + tasklet_schedule(&iucv_tasklet); + } + spin_unlock(&iucv_queue_lock); +} + +struct iucv_interface iucv_if = { + .message_receive = iucv_message_receive, + .__message_receive = __iucv_message_receive, + .message_reply = iucv_message_reply, + .message_reject = iucv_message_reject, + .message_send = iucv_message_send, + .__message_send = __iucv_message_send, + .message_send2way = iucv_message_send2way, + .message_purge = iucv_message_purge, + .path_accept = iucv_path_accept, + .path_connect = iucv_path_connect, + .path_quiesce = iucv_path_quiesce, + .path_resume = iucv_path_resume, + .path_sever = iucv_path_sever, + .iucv_register = iucv_register, + .iucv_unregister = iucv_unregister, + .bus = NULL, + .root = NULL, +}; +EXPORT_SYMBOL(iucv_if); + +static enum cpuhp_state iucv_online; +/** + * iucv_init + * + * Allocates and initializes various data structures. 
+ */ +static int __init iucv_init(void) +{ + int rc; + + if (!MACHINE_IS_VM) { + rc = -EPROTONOSUPPORT; + goto out; + } + ctl_set_bit(0, 1); + rc = iucv_query_maxconn(); + if (rc) + goto out_ctl; + rc = register_external_irq(EXT_IRQ_IUCV, iucv_external_interrupt); + if (rc) + goto out_ctl; + iucv_root = root_device_register("iucv"); + if (IS_ERR(iucv_root)) { + rc = PTR_ERR(iucv_root); + goto out_int; + } + + rc = cpuhp_setup_state(CPUHP_NET_IUCV_PREPARE, "net/iucv:prepare", + iucv_cpu_prepare, iucv_cpu_dead); + if (rc) + goto out_dev; + rc = cpuhp_setup_state(CPUHP_AP_ONLINE_DYN, "net/iucv:online", + iucv_cpu_online, iucv_cpu_down_prep); + if (rc < 0) + goto out_prep; + iucv_online = rc; + + rc = register_reboot_notifier(&iucv_reboot_notifier); + if (rc) + goto out_remove_hp; + ASCEBC(iucv_error_no_listener, 16); + ASCEBC(iucv_error_no_memory, 16); + ASCEBC(iucv_error_pathid, 16); + iucv_available = 1; + rc = bus_register(&iucv_bus); + if (rc) + goto out_reboot; + iucv_if.root = iucv_root; + iucv_if.bus = &iucv_bus; + return 0; + +out_reboot: + unregister_reboot_notifier(&iucv_reboot_notifier); +out_remove_hp: + cpuhp_remove_state(iucv_online); +out_prep: + cpuhp_remove_state(CPUHP_NET_IUCV_PREPARE); +out_dev: + root_device_unregister(iucv_root); +out_int: + unregister_external_irq(EXT_IRQ_IUCV, iucv_external_interrupt); +out_ctl: + ctl_clear_bit(0, 1); +out: + return rc; +} + +/** + * iucv_exit + * + * Frees everything allocated from iucv_init. + */ +static void __exit iucv_exit(void) +{ + struct iucv_irq_list *p, *n; + + spin_lock_irq(&iucv_queue_lock); + list_for_each_entry_safe(p, n, &iucv_task_queue, list) + kfree(p); + list_for_each_entry_safe(p, n, &iucv_work_queue, list) + kfree(p); + spin_unlock_irq(&iucv_queue_lock); + unregister_reboot_notifier(&iucv_reboot_notifier); + + cpuhp_remove_state_nocalls(iucv_online); + cpuhp_remove_state(CPUHP_NET_IUCV_PREPARE); + root_device_unregister(iucv_root); + bus_unregister(&iucv_bus); + unregister_external_irq(EXT_IRQ_IUCV, iucv_external_interrupt); +} + +subsys_initcall(iucv_init); +module_exit(iucv_exit); + +MODULE_AUTHOR("(C) 2001 IBM Corp. by Fritz Elfert (felfert@millenux.com)"); +MODULE_DESCRIPTION("Linux for S/390 IUCV lowlevel driver"); +MODULE_LICENSE("GPL"); diff --git a/net/kcm/Kconfig b/net/kcm/Kconfig new file mode 100644 index 000000000..16f39f256 --- /dev/null +++ b/net/kcm/Kconfig @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0-only + +config AF_KCM + tristate "KCM sockets" + depends on INET + select BPF_SYSCALL + select STREAM_PARSER + help + KCM (Kernel Connection Multiplexor) sockets provide a method + for multiplexing messages of a message based application + protocol over kernel connectons (e.g. TCP connections). 
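A minimal userspace sketch of what the Kconfig help above describes — attaching an already-connected TCP socket to a KCM socket so messages can be exchanged with sendmsg()/recvmsg() — assuming a separately loaded BPF_PROG_TYPE_SOCKET_FILTER program (bpf_fd) that delineates message boundaries and the uapi header <linux/kcm.h> on the build host; the helper name attach_tcp_to_kcm() is illustrative only:

/* Attach a connected TCP socket (tcp_fd) to a new KCM socket.
 * bpf_fd is a previously loaded socket-filter BPF program that
 * returns the length of each message in the TCP byte stream.
 * Returns the KCM socket fd on success, -1 on error.
 */
#include <stdio.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>		/* AF_KCM */
#include <linux/kcm.h>		/* struct kcm_attach, SIOCKCMATTACH, KCMPROTO_CONNECTED */

int attach_tcp_to_kcm(int tcp_fd, int bpf_fd)
{
	struct kcm_attach attach = {
		.fd = tcp_fd,		/* connected TCP socket */
		.bpf_fd = bpf_fd,	/* message framing program */
	};
	int kcm_fd;

	kcm_fd = socket(AF_KCM, SOCK_DGRAM, KCMPROTO_CONNECTED);
	if (kcm_fd < 0) {
		perror("socket(AF_KCM)");
		return -1;
	}

	/* Hand the TCP socket to the KCM multiplexor; after this,
	 * messages are sent and received on kcm_fd.
	 */
	if (ioctl(kcm_fd, SIOCKCMATTACH, &attach) < 0) {
		perror("ioctl(SIOCKCMATTACH)");
		close(kcm_fd);
		return -1;
	}

	return kcm_fd;
}
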
diff --git a/net/kcm/Makefile b/net/kcm/Makefile new file mode 100644 index 000000000..6c4569221 --- /dev/null +++ b/net/kcm/Makefile @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_AF_KCM) += kcm.o + +kcm-y := kcmsock.o kcmproc.o diff --git a/net/kcm/kcmproc.c b/net/kcm/kcmproc.c new file mode 100644 index 000000000..25c1007f1 --- /dev/null +++ b/net/kcm/kcmproc.c @@ -0,0 +1,387 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_PROC_FS +static struct kcm_mux *kcm_get_first(struct seq_file *seq) +{ + struct net *net = seq_file_net(seq); + struct kcm_net *knet = net_generic(net, kcm_net_id); + + return list_first_or_null_rcu(&knet->mux_list, + struct kcm_mux, kcm_mux_list); +} + +static struct kcm_mux *kcm_get_next(struct kcm_mux *mux) +{ + struct kcm_net *knet = mux->knet; + + return list_next_or_null_rcu(&knet->mux_list, &mux->kcm_mux_list, + struct kcm_mux, kcm_mux_list); +} + +static struct kcm_mux *kcm_get_idx(struct seq_file *seq, loff_t pos) +{ + struct net *net = seq_file_net(seq); + struct kcm_net *knet = net_generic(net, kcm_net_id); + struct kcm_mux *m; + + list_for_each_entry_rcu(m, &knet->mux_list, kcm_mux_list) { + if (!pos) + return m; + --pos; + } + return NULL; +} + +static void *kcm_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + void *p; + + if (v == SEQ_START_TOKEN) + p = kcm_get_first(seq); + else + p = kcm_get_next(v); + ++*pos; + return p; +} + +static void *kcm_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + rcu_read_lock(); + + if (!*pos) + return SEQ_START_TOKEN; + else + return kcm_get_idx(seq, *pos - 1); +} + +static void kcm_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +struct kcm_proc_mux_state { + struct seq_net_private p; + int idx; +}; + +static void kcm_format_mux_header(struct seq_file *seq) +{ + struct net *net = seq_file_net(seq); + struct kcm_net *knet = net_generic(net, kcm_net_id); + + seq_printf(seq, + "*** KCM statistics (%d MUX) ****\n", + knet->count); + + seq_printf(seq, + "%-14s %-10s %-16s %-10s %-16s %-8s %-8s %-8s %-8s %s", + "Object", + "RX-Msgs", + "RX-Bytes", + "TX-Msgs", + "TX-Bytes", + "Recv-Q", + "Rmem", + "Send-Q", + "Smem", + "Status"); + + /* XXX: pdsts header stuff here */ + seq_puts(seq, "\n"); +} + +static void kcm_format_sock(struct kcm_sock *kcm, struct seq_file *seq, + int i, int *len) +{ + seq_printf(seq, + " kcm-%-7u %-10llu %-16llu %-10llu %-16llu %-8d %-8d %-8d %-8s ", + kcm->index, + kcm->stats.rx_msgs, + kcm->stats.rx_bytes, + kcm->stats.tx_msgs, + kcm->stats.tx_bytes, + kcm->sk.sk_receive_queue.qlen, + sk_rmem_alloc_get(&kcm->sk), + kcm->sk.sk_write_queue.qlen, + "-"); + + if (kcm->tx_psock) + seq_printf(seq, "Psck-%u ", kcm->tx_psock->index); + + if (kcm->tx_wait) + seq_puts(seq, "TxWait "); + + if (kcm->tx_wait_more) + seq_puts(seq, "WMore "); + + if (kcm->rx_wait) + seq_puts(seq, "RxWait "); + + seq_puts(seq, "\n"); +} + +static void kcm_format_psock(struct kcm_psock *psock, struct seq_file *seq, + int i, int *len) +{ + seq_printf(seq, + " psock-%-5u %-10llu %-16llu %-10llu %-16llu %-8d %-8d %-8d %-8d ", + psock->index, + psock->strp.stats.msgs, + psock->strp.stats.bytes, + psock->stats.tx_msgs, + psock->stats.tx_bytes, + psock->sk->sk_receive_queue.qlen, + atomic_read(&psock->sk->sk_rmem_alloc), + psock->sk->sk_write_queue.qlen, + refcount_read(&psock->sk->sk_wmem_alloc)); + + if 
(psock->done) + seq_puts(seq, "Done "); + + if (psock->tx_stopped) + seq_puts(seq, "TxStop "); + + if (psock->strp.stopped) + seq_puts(seq, "RxStop "); + + if (psock->tx_kcm) + seq_printf(seq, "Rsvd-%d ", psock->tx_kcm->index); + + if (!psock->strp.paused && !psock->ready_rx_msg) { + if (psock->sk->sk_receive_queue.qlen) { + if (psock->strp.need_bytes) + seq_printf(seq, "RxWait=%u ", + psock->strp.need_bytes); + else + seq_printf(seq, "RxWait "); + } + } else { + if (psock->strp.paused) + seq_puts(seq, "RxPause "); + + if (psock->ready_rx_msg) + seq_puts(seq, "RdyRx "); + } + + seq_puts(seq, "\n"); +} + +static void +kcm_format_mux(struct kcm_mux *mux, loff_t idx, struct seq_file *seq) +{ + int i, len; + struct kcm_sock *kcm; + struct kcm_psock *psock; + + /* mux information */ + seq_printf(seq, + "%-6s%-8s %-10llu %-16llu %-10llu %-16llu %-8s %-8s %-8s %-8s ", + "mux", "", + mux->stats.rx_msgs, + mux->stats.rx_bytes, + mux->stats.tx_msgs, + mux->stats.tx_bytes, + "-", "-", "-", "-"); + + seq_printf(seq, "KCMs: %d, Psocks %d\n", + mux->kcm_socks_cnt, mux->psocks_cnt); + + /* kcm sock information */ + i = 0; + spin_lock_bh(&mux->lock); + list_for_each_entry(kcm, &mux->kcm_socks, kcm_sock_list) { + kcm_format_sock(kcm, seq, i, &len); + i++; + } + i = 0; + list_for_each_entry(psock, &mux->psocks, psock_list) { + kcm_format_psock(psock, seq, i, &len); + i++; + } + spin_unlock_bh(&mux->lock); +} + +static int kcm_seq_show(struct seq_file *seq, void *v) +{ + struct kcm_proc_mux_state *mux_state; + + mux_state = seq->private; + if (v == SEQ_START_TOKEN) { + mux_state->idx = 0; + kcm_format_mux_header(seq); + } else { + kcm_format_mux(v, mux_state->idx, seq); + mux_state->idx++; + } + return 0; +} + +static const struct seq_operations kcm_seq_ops = { + .show = kcm_seq_show, + .start = kcm_seq_start, + .next = kcm_seq_next, + .stop = kcm_seq_stop, +}; + +static int kcm_stats_seq_show(struct seq_file *seq, void *v) +{ + struct kcm_psock_stats psock_stats; + struct kcm_mux_stats mux_stats; + struct strp_aggr_stats strp_stats; + struct kcm_mux *mux; + struct kcm_psock *psock; + struct net *net = seq->private; + struct kcm_net *knet = net_generic(net, kcm_net_id); + + memset(&mux_stats, 0, sizeof(mux_stats)); + memset(&psock_stats, 0, sizeof(psock_stats)); + memset(&strp_stats, 0, sizeof(strp_stats)); + + mutex_lock(&knet->mutex); + + aggregate_mux_stats(&knet->aggregate_mux_stats, &mux_stats); + aggregate_psock_stats(&knet->aggregate_psock_stats, + &psock_stats); + aggregate_strp_stats(&knet->aggregate_strp_stats, + &strp_stats); + + list_for_each_entry(mux, &knet->mux_list, kcm_mux_list) { + spin_lock_bh(&mux->lock); + aggregate_mux_stats(&mux->stats, &mux_stats); + aggregate_psock_stats(&mux->aggregate_psock_stats, + &psock_stats); + aggregate_strp_stats(&mux->aggregate_strp_stats, + &strp_stats); + list_for_each_entry(psock, &mux->psocks, psock_list) { + aggregate_psock_stats(&psock->stats, &psock_stats); + save_strp_stats(&psock->strp, &strp_stats); + } + + spin_unlock_bh(&mux->lock); + } + + mutex_unlock(&knet->mutex); + + seq_printf(seq, + "%-8s %-10s %-16s %-10s %-16s %-10s %-10s %-10s %-10s %-10s\n", + "MUX", + "RX-Msgs", + "RX-Bytes", + "TX-Msgs", + "TX-Bytes", + "TX-Retries", + "Attach", + "Unattach", + "UnattchRsvd", + "RX-RdyDrops"); + + seq_printf(seq, + "%-8s %-10llu %-16llu %-10llu %-16llu %-10u %-10u %-10u %-10u %-10u\n", + "", + mux_stats.rx_msgs, + mux_stats.rx_bytes, + mux_stats.tx_msgs, + mux_stats.tx_bytes, + mux_stats.tx_retries, + mux_stats.psock_attach, + 
mux_stats.psock_unattach_rsvd, + mux_stats.psock_unattach, + mux_stats.rx_ready_drops); + + seq_printf(seq, + "%-8s %-10s %-16s %-10s %-16s %-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s %-10s\n", + "Psock", + "RX-Msgs", + "RX-Bytes", + "TX-Msgs", + "TX-Bytes", + "Reserved", + "Unreserved", + "RX-Aborts", + "RX-Intr", + "RX-Unrecov", + "RX-MemFail", + "RX-NeedMor", + "RX-BadLen", + "RX-TooBig", + "RX-Timeout", + "TX-Aborts"); + + seq_printf(seq, + "%-8s %-10llu %-16llu %-10llu %-16llu %-10llu %-10llu %-10u %-10u %-10u %-10u %-10u %-10u %-10u %-10u %-10u\n", + "", + strp_stats.msgs, + strp_stats.bytes, + psock_stats.tx_msgs, + psock_stats.tx_bytes, + psock_stats.reserved, + psock_stats.unreserved, + strp_stats.aborts, + strp_stats.interrupted, + strp_stats.unrecov_intr, + strp_stats.mem_fail, + strp_stats.need_more_hdr, + strp_stats.bad_hdr_len, + strp_stats.msg_too_big, + strp_stats.msg_timeouts, + psock_stats.tx_aborts); + + return 0; +} + +static int kcm_proc_init_net(struct net *net) +{ + if (!proc_create_net_single("kcm_stats", 0444, net->proc_net, + kcm_stats_seq_show, NULL)) + goto out_kcm_stats; + + if (!proc_create_net("kcm", 0444, net->proc_net, &kcm_seq_ops, + sizeof(struct kcm_proc_mux_state))) + goto out_kcm; + + return 0; + +out_kcm: + remove_proc_entry("kcm_stats", net->proc_net); +out_kcm_stats: + return -ENOMEM; +} + +static void kcm_proc_exit_net(struct net *net) +{ + remove_proc_entry("kcm", net->proc_net); + remove_proc_entry("kcm_stats", net->proc_net); +} + +static struct pernet_operations kcm_net_ops = { + .init = kcm_proc_init_net, + .exit = kcm_proc_exit_net, +}; + +int __init kcm_proc_init(void) +{ + return register_pernet_subsys(&kcm_net_ops); +} + +void __exit kcm_proc_exit(void) +{ + unregister_pernet_subsys(&kcm_net_ops); +} + +#endif /* CONFIG_PROC_FS */ diff --git a/net/kcm/kcmsock.c b/net/kcm/kcmsock.c new file mode 100644 index 000000000..65845c59c --- /dev/null +++ b/net/kcm/kcmsock.c @@ -0,0 +1,2071 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Kernel Connection Multiplexor + * + * Copyright (c) 2016 Tom Herbert + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +unsigned int kcm_net_id; + +static struct kmem_cache *kcm_psockp __read_mostly; +static struct kmem_cache *kcm_muxp __read_mostly; +static struct workqueue_struct *kcm_wq; + +static inline struct kcm_sock *kcm_sk(const struct sock *sk) +{ + return (struct kcm_sock *)sk; +} + +static inline struct kcm_tx_msg *kcm_tx_msg(struct sk_buff *skb) +{ + return (struct kcm_tx_msg *)skb->cb; +} + +static void report_csk_error(struct sock *csk, int err) +{ + csk->sk_err = EPIPE; + sk_error_report(csk); +} + +static void kcm_abort_tx_psock(struct kcm_psock *psock, int err, + bool wakeup_kcm) +{ + struct sock *csk = psock->sk; + struct kcm_mux *mux = psock->mux; + + /* Unrecoverable error in transmit */ + + spin_lock_bh(&mux->lock); + + if (psock->tx_stopped) { + spin_unlock_bh(&mux->lock); + return; + } + + psock->tx_stopped = 1; + KCM_STATS_INCR(psock->stats.tx_aborts); + + if (!psock->tx_kcm) { + /* Take off psocks_avail list */ + list_del(&psock->psock_avail_list); + } else if (wakeup_kcm) { + /* In this case psock is being aborted while outside of + * write_msgs and psock is reserved. Schedule tx_work + * to handle the failure there. Need to commit tx_stopped + * before queuing work. 
+ */ + smp_mb(); + + queue_work(kcm_wq, &psock->tx_kcm->tx_work); + } + + spin_unlock_bh(&mux->lock); + + /* Report error on lower socket */ + report_csk_error(csk, err); +} + +/* RX mux lock held. */ +static void kcm_update_rx_mux_stats(struct kcm_mux *mux, + struct kcm_psock *psock) +{ + STRP_STATS_ADD(mux->stats.rx_bytes, + psock->strp.stats.bytes - + psock->saved_rx_bytes); + mux->stats.rx_msgs += + psock->strp.stats.msgs - psock->saved_rx_msgs; + psock->saved_rx_msgs = psock->strp.stats.msgs; + psock->saved_rx_bytes = psock->strp.stats.bytes; +} + +static void kcm_update_tx_mux_stats(struct kcm_mux *mux, + struct kcm_psock *psock) +{ + KCM_STATS_ADD(mux->stats.tx_bytes, + psock->stats.tx_bytes - psock->saved_tx_bytes); + mux->stats.tx_msgs += + psock->stats.tx_msgs - psock->saved_tx_msgs; + psock->saved_tx_msgs = psock->stats.tx_msgs; + psock->saved_tx_bytes = psock->stats.tx_bytes; +} + +static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); + +/* KCM is ready to receive messages on its queue-- either the KCM is new or + * has become unblocked after being blocked on full socket buffer. Queue any + * pending ready messages on a psock. RX mux lock held. + */ +static void kcm_rcv_ready(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + struct kcm_psock *psock; + struct sk_buff *skb; + + if (unlikely(kcm->rx_wait || kcm->rx_psock || kcm->rx_disabled)) + return; + + while (unlikely((skb = __skb_dequeue(&mux->rx_hold_queue)))) { + if (kcm_queue_rcv_skb(&kcm->sk, skb)) { + /* Assuming buffer limit has been reached */ + skb_queue_head(&mux->rx_hold_queue, skb); + WARN_ON(!sk_rmem_alloc_get(&kcm->sk)); + return; + } + } + + while (!list_empty(&mux->psocks_ready)) { + psock = list_first_entry(&mux->psocks_ready, struct kcm_psock, + psock_ready_list); + + if (kcm_queue_rcv_skb(&kcm->sk, psock->ready_rx_msg)) { + /* Assuming buffer limit has been reached */ + WARN_ON(!sk_rmem_alloc_get(&kcm->sk)); + return; + } + + /* Consumed the ready message on the psock. Schedule rx_work to + * get more messages. 
+ */ + list_del(&psock->psock_ready_list); + psock->ready_rx_msg = NULL; + /* Commit clearing of ready_rx_msg for queuing work */ + smp_mb(); + + strp_unpause(&psock->strp); + strp_check_rcv(&psock->strp); + } + + /* Buffer limit is okay now, add to ready list */ + list_add_tail(&kcm->wait_rx_list, + &kcm->mux->kcm_rx_waiters); + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_wait, true); +} + +static void kcm_rfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + struct kcm_sock *kcm = kcm_sk(sk); + struct kcm_mux *mux = kcm->mux; + unsigned int len = skb->truesize; + + sk_mem_uncharge(sk, len); + atomic_sub(len, &sk->sk_rmem_alloc); + + /* For reading rx_wait and rx_psock without holding lock */ + smp_mb__after_atomic(); + + if (!READ_ONCE(kcm->rx_wait) && !READ_ONCE(kcm->rx_psock) && + sk_rmem_alloc_get(sk) < sk->sk_rcvlowat) { + spin_lock_bh(&mux->rx_lock); + kcm_rcv_ready(kcm); + spin_unlock_bh(&mux->rx_lock); + } +} + +static int kcm_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff_head *list = &sk->sk_receive_queue; + + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) + return -ENOMEM; + + if (!sk_rmem_schedule(sk, skb, skb->truesize)) + return -ENOBUFS; + + skb->dev = NULL; + + skb_orphan(skb); + skb->sk = sk; + skb->destructor = kcm_rfree; + atomic_add(skb->truesize, &sk->sk_rmem_alloc); + sk_mem_charge(sk, skb->truesize); + + skb_queue_tail(list, skb); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + + return 0; +} + +/* Requeue received messages for a kcm socket to other kcm sockets. This is + * called with a kcm socket is receive disabled. + * RX mux lock held. + */ +static void requeue_rx_msgs(struct kcm_mux *mux, struct sk_buff_head *head) +{ + struct sk_buff *skb; + struct kcm_sock *kcm; + + while ((skb = skb_dequeue(head))) { + /* Reset destructor to avoid calling kcm_rcv_ready */ + skb->destructor = sock_rfree; + skb_orphan(skb); +try_again: + if (list_empty(&mux->kcm_rx_waiters)) { + skb_queue_tail(&mux->rx_hold_queue, skb); + continue; + } + + kcm = list_first_entry(&mux->kcm_rx_waiters, + struct kcm_sock, wait_rx_list); + + if (kcm_queue_rcv_skb(&kcm->sk, skb)) { + /* Should mean socket buffer full */ + list_del(&kcm->wait_rx_list); + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_wait, false); + + /* Commit rx_wait to read in kcm_free */ + smp_wmb(); + + goto try_again; + } + } +} + +/* Lower sock lock held */ +static struct kcm_sock *reserve_rx_kcm(struct kcm_psock *psock, + struct sk_buff *head) +{ + struct kcm_mux *mux = psock->mux; + struct kcm_sock *kcm; + + WARN_ON(psock->ready_rx_msg); + + if (psock->rx_kcm) + return psock->rx_kcm; + + spin_lock_bh(&mux->rx_lock); + + if (psock->rx_kcm) { + spin_unlock_bh(&mux->rx_lock); + return psock->rx_kcm; + } + + kcm_update_rx_mux_stats(mux, psock); + + if (list_empty(&mux->kcm_rx_waiters)) { + psock->ready_rx_msg = head; + strp_pause(&psock->strp); + list_add_tail(&psock->psock_ready_list, + &mux->psocks_ready); + spin_unlock_bh(&mux->rx_lock); + return NULL; + } + + kcm = list_first_entry(&mux->kcm_rx_waiters, + struct kcm_sock, wait_rx_list); + list_del(&kcm->wait_rx_list); + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_wait, false); + + psock->rx_kcm = kcm; + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_psock, psock); + + spin_unlock_bh(&mux->rx_lock); + + return kcm; +} + +static void kcm_done(struct kcm_sock *kcm); + +static void kcm_done_work(struct work_struct *w) +{ + 
kcm_done(container_of(w, struct kcm_sock, done_work)); +} + +/* Lower sock held */ +static void unreserve_rx_kcm(struct kcm_psock *psock, + bool rcv_ready) +{ + struct kcm_sock *kcm = psock->rx_kcm; + struct kcm_mux *mux = psock->mux; + + if (!kcm) + return; + + spin_lock_bh(&mux->rx_lock); + + psock->rx_kcm = NULL; + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_psock, NULL); + + /* Commit kcm->rx_psock before sk_rmem_alloc_get to sync with + * kcm_rfree + */ + smp_mb(); + + if (unlikely(kcm->done)) { + spin_unlock_bh(&mux->rx_lock); + + /* Need to run kcm_done in a task since we need to qcquire + * callback locks which may already be held here. + */ + INIT_WORK(&kcm->done_work, kcm_done_work); + schedule_work(&kcm->done_work); + return; + } + + if (unlikely(kcm->rx_disabled)) { + requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue); + } else if (rcv_ready || unlikely(!sk_rmem_alloc_get(&kcm->sk))) { + /* Check for degenerative race with rx_wait that all + * data was dequeued (accounted for in kcm_rfree). + */ + kcm_rcv_ready(kcm); + } + spin_unlock_bh(&mux->rx_lock); +} + +/* Lower sock lock held */ +static void psock_data_ready(struct sock *sk) +{ + struct kcm_psock *psock; + + read_lock_bh(&sk->sk_callback_lock); + + psock = (struct kcm_psock *)sk->sk_user_data; + if (likely(psock)) + strp_data_ready(&psock->strp); + + read_unlock_bh(&sk->sk_callback_lock); +} + +/* Called with lower sock held */ +static void kcm_rcv_strparser(struct strparser *strp, struct sk_buff *skb) +{ + struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp); + struct kcm_sock *kcm; + +try_queue: + kcm = reserve_rx_kcm(psock, skb); + if (!kcm) { + /* Unable to reserve a KCM, message is held in psock and strp + * is paused. + */ + return; + } + + if (kcm_queue_rcv_skb(&kcm->sk, skb)) { + /* Should mean socket buffer full */ + unreserve_rx_kcm(psock, false); + goto try_queue; + } +} + +static int kcm_parse_func_strparser(struct strparser *strp, struct sk_buff *skb) +{ + struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp); + struct bpf_prog *prog = psock->bpf_prog; + int res; + + res = bpf_prog_run_pin_on_cpu(prog, skb); + return res; +} + +static int kcm_read_sock_done(struct strparser *strp, int err) +{ + struct kcm_psock *psock = container_of(strp, struct kcm_psock, strp); + + unreserve_rx_kcm(psock, true); + + return err; +} + +static void psock_state_change(struct sock *sk) +{ + /* TCP only does a EPOLLIN for a half close. Do a EPOLLHUP here + * since application will normally not poll with EPOLLIN + * on the TCP sockets. + */ + + report_csk_error(sk, EPIPE); +} + +static void psock_write_space(struct sock *sk) +{ + struct kcm_psock *psock; + struct kcm_mux *mux; + struct kcm_sock *kcm; + + read_lock_bh(&sk->sk_callback_lock); + + psock = (struct kcm_psock *)sk->sk_user_data; + if (unlikely(!psock)) + goto out; + mux = psock->mux; + + spin_lock_bh(&mux->lock); + + /* Check if the socket is reserved so someone is waiting for sending. */ + kcm = psock->tx_kcm; + if (kcm && !unlikely(kcm->tx_stopped)) + queue_work(kcm_wq, &kcm->tx_work); + + spin_unlock_bh(&mux->lock); +out: + read_unlock_bh(&sk->sk_callback_lock); +} + +static void unreserve_psock(struct kcm_sock *kcm); + +/* kcm sock is locked. 
*/ +static struct kcm_psock *reserve_psock(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + struct kcm_psock *psock; + + psock = kcm->tx_psock; + + smp_rmb(); /* Must read tx_psock before tx_wait */ + + if (psock) { + WARN_ON(kcm->tx_wait); + if (unlikely(psock->tx_stopped)) + unreserve_psock(kcm); + else + return kcm->tx_psock; + } + + spin_lock_bh(&mux->lock); + + /* Check again under lock to see if psock was reserved for this + * psock via psock_unreserve. + */ + psock = kcm->tx_psock; + if (unlikely(psock)) { + WARN_ON(kcm->tx_wait); + spin_unlock_bh(&mux->lock); + return kcm->tx_psock; + } + + if (!list_empty(&mux->psocks_avail)) { + psock = list_first_entry(&mux->psocks_avail, + struct kcm_psock, + psock_avail_list); + list_del(&psock->psock_avail_list); + if (kcm->tx_wait) { + list_del(&kcm->wait_psock_list); + kcm->tx_wait = false; + } + kcm->tx_psock = psock; + psock->tx_kcm = kcm; + KCM_STATS_INCR(psock->stats.reserved); + } else if (!kcm->tx_wait) { + list_add_tail(&kcm->wait_psock_list, + &mux->kcm_tx_waiters); + kcm->tx_wait = true; + } + + spin_unlock_bh(&mux->lock); + + return psock; +} + +/* mux lock held */ +static void psock_now_avail(struct kcm_psock *psock) +{ + struct kcm_mux *mux = psock->mux; + struct kcm_sock *kcm; + + if (list_empty(&mux->kcm_tx_waiters)) { + list_add_tail(&psock->psock_avail_list, + &mux->psocks_avail); + } else { + kcm = list_first_entry(&mux->kcm_tx_waiters, + struct kcm_sock, + wait_psock_list); + list_del(&kcm->wait_psock_list); + kcm->tx_wait = false; + psock->tx_kcm = kcm; + + /* Commit before changing tx_psock since that is read in + * reserve_psock before queuing work. + */ + smp_mb(); + + kcm->tx_psock = psock; + KCM_STATS_INCR(psock->stats.reserved); + queue_work(kcm_wq, &kcm->tx_work); + } +} + +/* kcm sock is locked. */ +static void unreserve_psock(struct kcm_sock *kcm) +{ + struct kcm_psock *psock; + struct kcm_mux *mux = kcm->mux; + + spin_lock_bh(&mux->lock); + + psock = kcm->tx_psock; + + if (WARN_ON(!psock)) { + spin_unlock_bh(&mux->lock); + return; + } + + smp_rmb(); /* Read tx_psock before tx_wait */ + + kcm_update_tx_mux_stats(mux, psock); + + WARN_ON(kcm->tx_wait); + + kcm->tx_psock = NULL; + psock->tx_kcm = NULL; + KCM_STATS_INCR(psock->stats.unreserved); + + if (unlikely(psock->tx_stopped)) { + if (psock->done) { + /* Deferred free */ + list_del(&psock->psock_list); + mux->psocks_cnt--; + sock_put(psock->sk); + fput(psock->sk->sk_socket->file); + kmem_cache_free(kcm_psockp, psock); + } + + /* Don't put back on available list */ + + spin_unlock_bh(&mux->lock); + + return; + } + + psock_now_avail(psock); + + spin_unlock_bh(&mux->lock); +} + +static void kcm_report_tx_retry(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + + spin_lock_bh(&mux->lock); + KCM_STATS_INCR(mux->stats.tx_retries); + spin_unlock_bh(&mux->lock); +} + +/* Write any messages ready on the kcm socket. Called with kcm sock lock + * held. Return bytes actually sent or error. + */ +static int kcm_write_msgs(struct kcm_sock *kcm) +{ + struct sock *sk = &kcm->sk; + struct kcm_psock *psock; + struct sk_buff *skb, *head; + struct kcm_tx_msg *txm; + unsigned short fragidx, frag_offset; + unsigned int sent, total_sent = 0; + int ret = 0; + + kcm->tx_wait_more = false; + psock = kcm->tx_psock; + if (unlikely(psock && psock->tx_stopped)) { + /* A reserved psock was aborted asynchronously. Unreserve + * it and we'll retry the message. 
+ */ + unreserve_psock(kcm); + kcm_report_tx_retry(kcm); + if (skb_queue_empty(&sk->sk_write_queue)) + return 0; + + kcm_tx_msg(skb_peek(&sk->sk_write_queue))->sent = 0; + + } else if (skb_queue_empty(&sk->sk_write_queue)) { + return 0; + } + + head = skb_peek(&sk->sk_write_queue); + txm = kcm_tx_msg(head); + + if (txm->sent) { + /* Send of first skbuff in queue already in progress */ + if (WARN_ON(!psock)) { + ret = -EINVAL; + goto out; + } + sent = txm->sent; + frag_offset = txm->frag_offset; + fragidx = txm->fragidx; + skb = txm->frag_skb; + + goto do_frag; + } + +try_again: + psock = reserve_psock(kcm); + if (!psock) + goto out; + + do { + skb = head; + txm = kcm_tx_msg(head); + sent = 0; + +do_frag_list: + if (WARN_ON(!skb_shinfo(skb)->nr_frags)) { + ret = -EINVAL; + goto out; + } + + for (fragidx = 0; fragidx < skb_shinfo(skb)->nr_frags; + fragidx++) { + skb_frag_t *frag; + + frag_offset = 0; +do_frag: + frag = &skb_shinfo(skb)->frags[fragidx]; + if (WARN_ON(!skb_frag_size(frag))) { + ret = -EINVAL; + goto out; + } + + ret = kernel_sendpage(psock->sk->sk_socket, + skb_frag_page(frag), + skb_frag_off(frag) + frag_offset, + skb_frag_size(frag) - frag_offset, + MSG_DONTWAIT); + if (ret <= 0) { + if (ret == -EAGAIN) { + /* Save state to try again when there's + * write space on the socket + */ + txm->sent = sent; + txm->frag_offset = frag_offset; + txm->fragidx = fragidx; + txm->frag_skb = skb; + + ret = 0; + goto out; + } + + /* Hard failure in sending message, abort this + * psock since it has lost framing + * synchronization and retry sending the + * message from the beginning. + */ + kcm_abort_tx_psock(psock, ret ? -ret : EPIPE, + true); + unreserve_psock(kcm); + + txm->sent = 0; + kcm_report_tx_retry(kcm); + ret = 0; + + goto try_again; + } + + sent += ret; + frag_offset += ret; + KCM_STATS_ADD(psock->stats.tx_bytes, ret); + if (frag_offset < skb_frag_size(frag)) { + /* Not finished with this frag */ + goto do_frag; + } + } + + if (skb == head) { + if (skb_has_frag_list(skb)) { + skb = skb_shinfo(skb)->frag_list; + goto do_frag_list; + } + } else if (skb->next) { + skb = skb->next; + goto do_frag_list; + } + + /* Successfully sent the whole packet, account for it. */ + skb_dequeue(&sk->sk_write_queue); + kfree_skb(head); + sk->sk_wmem_queued -= sent; + total_sent += sent; + KCM_STATS_INCR(psock->stats.tx_msgs); + } while ((head = skb_peek(&sk->sk_write_queue))); +out: + if (!head) { + /* Done with all queued messages. */ + WARN_ON(!skb_queue_empty(&sk->sk_write_queue)); + unreserve_psock(kcm); + } + + /* Check if write space is available */ + sk->sk_write_space(sk); + + return total_sent ? 
: ret; +} + +static void kcm_tx_work(struct work_struct *w) +{ + struct kcm_sock *kcm = container_of(w, struct kcm_sock, tx_work); + struct sock *sk = &kcm->sk; + int err; + + lock_sock(sk); + + /* Primarily for SOCK_DGRAM sockets, also handle asynchronous tx + * aborts + */ + err = kcm_write_msgs(kcm); + if (err < 0) { + /* Hard failure in write, report error on KCM socket */ + pr_warn("KCM: Hard failure on kcm_write_msgs %d\n", err); + report_csk_error(&kcm->sk, -err); + goto out; + } + + /* Primarily for SOCK_SEQPACKET sockets */ + if (likely(sk->sk_socket) && + test_bit(SOCK_NOSPACE, &sk->sk_socket->flags)) { + clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + sk->sk_write_space(sk); + } + +out: + release_sock(sk); +} + +static void kcm_push(struct kcm_sock *kcm) +{ + if (kcm->tx_wait_more) + kcm_write_msgs(kcm); +} + +static ssize_t kcm_sendpage(struct socket *sock, struct page *page, + int offset, size_t size, int flags) + +{ + struct sock *sk = sock->sk; + struct kcm_sock *kcm = kcm_sk(sk); + struct sk_buff *skb = NULL, *head = NULL; + long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + bool eor; + int err = 0; + int i; + + if (flags & MSG_SENDPAGE_NOTLAST) + flags |= MSG_MORE; + + /* No MSG_EOR from splice, only look at MSG_MORE */ + eor = !(flags & MSG_MORE); + + lock_sock(sk); + + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + err = -EPIPE; + if (sk->sk_err) + goto out_error; + + if (kcm->seq_skb) { + /* Previously opened message */ + head = kcm->seq_skb; + skb = kcm_tx_msg(head)->last_skb; + i = skb_shinfo(skb)->nr_frags; + + if (skb_can_coalesce(skb, i, page, offset)) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], size); + skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG; + goto coalesced; + } + + if (i >= MAX_SKB_FRAGS) { + struct sk_buff *tskb; + + tskb = alloc_skb(0, sk->sk_allocation); + while (!tskb) { + kcm_push(kcm); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + } + + if (head == skb) + skb_shinfo(head)->frag_list = tskb; + else + skb->next = tskb; + + skb = tskb; + skb->ip_summed = CHECKSUM_UNNECESSARY; + i = 0; + } + } else { + /* Call the sk_stream functions to manage the sndbuf mem. */ + if (!sk_stream_memory_free(sk)) { + kcm_push(kcm); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + } + + head = alloc_skb(0, sk->sk_allocation); + while (!head) { + kcm_push(kcm); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + } + + skb = head; + i = 0; + } + + get_page(page); + skb_fill_page_desc_noacc(skb, i, page, offset, size); + skb_shinfo(skb)->flags |= SKBFL_SHARED_FRAG; + +coalesced: + skb->len += size; + skb->data_len += size; + skb->truesize += size; + sk->sk_wmem_queued += size; + sk_mem_charge(sk, size); + + if (head != skb) { + head->len += size; + head->data_len += size; + head->truesize += size; + } + + if (eor) { + bool not_busy = skb_queue_empty(&sk->sk_write_queue); + + /* Message complete, queue it on send buffer */ + __skb_queue_tail(&sk->sk_write_queue, head); + kcm->seq_skb = NULL; + KCM_STATS_INCR(kcm->stats.tx_msgs); + + if (flags & MSG_BATCH) { + kcm->tx_wait_more = true; + } else if (kcm->tx_wait_more || not_busy) { + err = kcm_write_msgs(kcm); + if (err < 0) { + /* We got a hard error in write_msgs but have + * already queued this message. 
Report an error + * in the socket, but don't affect return value + * from sendmsg + */ + pr_warn("KCM: Hard failure on kcm_write_msgs\n"); + report_csk_error(&kcm->sk, -err); + } + } + } else { + /* Message not complete, save state */ + kcm->seq_skb = head; + kcm_tx_msg(head)->last_skb = skb; + } + + KCM_STATS_ADD(kcm->stats.tx_bytes, size); + + release_sock(sk); + return size; + +out_error: + kcm_push(kcm); + + err = sk_stream_error(sk, flags, err); + + /* make sure we wake any epoll edge trigger waiter */ + if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN)) + sk->sk_write_space(sk); + + release_sock(sk); + return err; +} + +static int kcm_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct kcm_sock *kcm = kcm_sk(sk); + struct sk_buff *skb = NULL, *head = NULL; + size_t copy, copied = 0; + long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + int eor = (sock->type == SOCK_DGRAM) ? + !(msg->msg_flags & MSG_MORE) : !!(msg->msg_flags & MSG_EOR); + int err = -EPIPE; + + lock_sock(sk); + + /* Per tcp_sendmsg this should be in poll */ + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + if (sk->sk_err) + goto out_error; + + if (kcm->seq_skb) { + /* Previously opened message */ + head = kcm->seq_skb; + skb = kcm_tx_msg(head)->last_skb; + goto start; + } + + /* Call the sk_stream functions to manage the sndbuf mem. */ + if (!sk_stream_memory_free(sk)) { + kcm_push(kcm); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + } + + if (msg_data_left(msg)) { + /* New message, alloc head skb */ + head = alloc_skb(0, sk->sk_allocation); + while (!head) { + kcm_push(kcm); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + + head = alloc_skb(0, sk->sk_allocation); + } + + skb = head; + + /* Set ip_summed to CHECKSUM_UNNECESSARY to avoid calling + * csum_and_copy_from_iter from skb_do_copy_data_nocache. + */ + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + +start: + while (msg_data_left(msg)) { + bool merge = true; + int i = skb_shinfo(skb)->nr_frags; + struct page_frag *pfrag = sk_page_frag(sk); + + if (!sk_page_frag_refill(sk, pfrag)) + goto wait_for_memory; + + if (!skb_can_coalesce(skb, i, pfrag->page, + pfrag->offset)) { + if (i == MAX_SKB_FRAGS) { + struct sk_buff *tskb; + + tskb = alloc_skb(0, sk->sk_allocation); + if (!tskb) + goto wait_for_memory; + + if (head == skb) + skb_shinfo(head)->frag_list = tskb; + else + skb->next = tskb; + + skb = tskb; + skb->ip_summed = CHECKSUM_UNNECESSARY; + continue; + } + merge = false; + } + + copy = min_t(int, msg_data_left(msg), + pfrag->size - pfrag->offset); + + if (!sk_wmem_schedule(sk, copy)) + goto wait_for_memory; + + err = skb_copy_to_page_nocache(sk, &msg->msg_iter, skb, + pfrag->page, + pfrag->offset, + copy); + if (err) + goto out_error; + + /* Update the skb. 
*/ + if (merge) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + } else { + skb_fill_page_desc(skb, i, pfrag->page, + pfrag->offset, copy); + get_page(pfrag->page); + } + + pfrag->offset += copy; + copied += copy; + if (head != skb) { + head->len += copy; + head->data_len += copy; + } + + continue; + +wait_for_memory: + kcm_push(kcm); + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto out_error; + } + + if (eor) { + bool not_busy = skb_queue_empty(&sk->sk_write_queue); + + if (head) { + /* Message complete, queue it on send buffer */ + __skb_queue_tail(&sk->sk_write_queue, head); + kcm->seq_skb = NULL; + KCM_STATS_INCR(kcm->stats.tx_msgs); + } + + if (msg->msg_flags & MSG_BATCH) { + kcm->tx_wait_more = true; + } else if (kcm->tx_wait_more || not_busy) { + err = kcm_write_msgs(kcm); + if (err < 0) { + /* We got a hard error in write_msgs but have + * already queued this message. Report an error + * in the socket, but don't affect return value + * from sendmsg + */ + pr_warn("KCM: Hard failure on kcm_write_msgs\n"); + report_csk_error(&kcm->sk, -err); + } + } + } else { + /* Message not complete, save state */ +partial_message: + if (head) { + kcm->seq_skb = head; + kcm_tx_msg(head)->last_skb = skb; + } + } + + KCM_STATS_ADD(kcm->stats.tx_bytes, copied); + + release_sock(sk); + return copied; + +out_error: + kcm_push(kcm); + + if (sock->type == SOCK_SEQPACKET) { + /* Wrote some bytes before encountering an + * error, return partial success. + */ + if (copied) + goto partial_message; + if (head != kcm->seq_skb) + kfree_skb(head); + } else { + kfree_skb(head); + kcm->seq_skb = NULL; + } + + err = sk_stream_error(sk, msg->msg_flags, err); + + /* make sure we wake any epoll edge trigger waiter */ + if (unlikely(skb_queue_len(&sk->sk_write_queue) == 0 && err == -EAGAIN)) + sk->sk_write_space(sk); + + release_sock(sk); + return err; +} + +static int kcm_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + struct kcm_sock *kcm = kcm_sk(sk); + int err = 0; + struct strp_msg *stm; + int copied = 0; + struct sk_buff *skb; + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + /* Okay, have a message on the receive queue */ + + stm = strp_msg(skb); + + if (len > stm->full_len) + len = stm->full_len; + + err = skb_copy_datagram_msg(skb, stm->offset, msg, len); + if (err < 0) + goto out; + + copied = len; + if (likely(!(flags & MSG_PEEK))) { + KCM_STATS_ADD(kcm->stats.rx_bytes, copied); + if (copied < stm->full_len) { + if (sock->type == SOCK_DGRAM) { + /* Truncated message */ + msg->msg_flags |= MSG_TRUNC; + goto msg_finished; + } + stm->offset += copied; + stm->full_len -= copied; + } else { +msg_finished: + /* Finished with message */ + msg->msg_flags |= MSG_EOR; + KCM_STATS_INCR(kcm->stats.rx_msgs); + } + } + +out: + skb_free_datagram(sk, skb); + return copied ? 
: err; +} + +static ssize_t kcm_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags) +{ + struct sock *sk = sock->sk; + struct kcm_sock *kcm = kcm_sk(sk); + struct strp_msg *stm; + int err = 0; + ssize_t copied; + struct sk_buff *skb; + + /* Only support splice for SOCKSEQPACKET */ + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto err_out; + + /* Okay, have a message on the receive queue */ + + stm = strp_msg(skb); + + if (len > stm->full_len) + len = stm->full_len; + + copied = skb_splice_bits(skb, sk, stm->offset, pipe, len, flags); + if (copied < 0) { + err = copied; + goto err_out; + } + + KCM_STATS_ADD(kcm->stats.rx_bytes, copied); + + stm->offset += copied; + stm->full_len -= copied; + + /* We have no way to return MSG_EOR. If all the bytes have been + * read we still leave the message in the receive socket buffer. + * A subsequent recvmsg needs to be done to return MSG_EOR and + * finish reading the message. + */ + + skb_free_datagram(sk, skb); + return copied; + +err_out: + skb_free_datagram(sk, skb); + return err; +} + +/* kcm sock lock held */ +static void kcm_recv_disable(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + + if (kcm->rx_disabled) + return; + + spin_lock_bh(&mux->rx_lock); + + kcm->rx_disabled = 1; + + /* If a psock is reserved we'll do cleanup in unreserve */ + if (!kcm->rx_psock) { + if (kcm->rx_wait) { + list_del(&kcm->wait_rx_list); + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_wait, false); + } + + requeue_rx_msgs(mux, &kcm->sk.sk_receive_queue); + } + + spin_unlock_bh(&mux->rx_lock); +} + +/* kcm sock lock held */ +static void kcm_recv_enable(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + + if (!kcm->rx_disabled) + return; + + spin_lock_bh(&mux->rx_lock); + + kcm->rx_disabled = 0; + kcm_rcv_ready(kcm); + + spin_unlock_bh(&mux->rx_lock); +} + +static int kcm_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct kcm_sock *kcm = kcm_sk(sock->sk); + int val, valbool; + int err = 0; + + if (level != SOL_KCM) + return -ENOPROTOOPT; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + valbool = val ? 
1 : 0; + + switch (optname) { + case KCM_RECV_DISABLE: + lock_sock(&kcm->sk); + if (valbool) + kcm_recv_disable(kcm); + else + kcm_recv_enable(kcm); + release_sock(&kcm->sk); + break; + default: + err = -ENOPROTOOPT; + } + + return err; +} + +static int kcm_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct kcm_sock *kcm = kcm_sk(sock->sk); + int val, len; + + if (level != SOL_KCM) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(int)); + if (len < 0) + return -EINVAL; + + switch (optname) { + case KCM_RECV_DISABLE: + val = kcm->rx_disabled; + break; + default: + return -ENOPROTOOPT; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +static void init_kcm_sock(struct kcm_sock *kcm, struct kcm_mux *mux) +{ + struct kcm_sock *tkcm; + struct list_head *head; + int index = 0; + + /* For SOCK_SEQPACKET sock type, datagram_poll checks the sk_state, so + * we set sk_state, otherwise epoll_wait always returns right away with + * EPOLLHUP + */ + kcm->sk.sk_state = TCP_ESTABLISHED; + + /* Add to mux's kcm sockets list */ + kcm->mux = mux; + spin_lock_bh(&mux->lock); + + head = &mux->kcm_socks; + list_for_each_entry(tkcm, &mux->kcm_socks, kcm_sock_list) { + if (tkcm->index != index) + break; + head = &tkcm->kcm_sock_list; + index++; + } + + list_add(&kcm->kcm_sock_list, head); + kcm->index = index; + + mux->kcm_socks_cnt++; + spin_unlock_bh(&mux->lock); + + INIT_WORK(&kcm->tx_work, kcm_tx_work); + + spin_lock_bh(&mux->rx_lock); + kcm_rcv_ready(kcm); + spin_unlock_bh(&mux->rx_lock); +} + +static int kcm_attach(struct socket *sock, struct socket *csock, + struct bpf_prog *prog) +{ + struct kcm_sock *kcm = kcm_sk(sock->sk); + struct kcm_mux *mux = kcm->mux; + struct sock *csk; + struct kcm_psock *psock = NULL, *tpsock; + struct list_head *head; + int index = 0; + static const struct strp_callbacks cb = { + .rcv_msg = kcm_rcv_strparser, + .parse_msg = kcm_parse_func_strparser, + .read_sock_done = kcm_read_sock_done, + }; + int err = 0; + + csk = csock->sk; + if (!csk) + return -EINVAL; + + lock_sock(csk); + + /* Only allow TCP sockets to be attached for now */ + if ((csk->sk_family != AF_INET && csk->sk_family != AF_INET6) || + csk->sk_protocol != IPPROTO_TCP) { + err = -EOPNOTSUPP; + goto out; + } + + /* Don't allow listeners or closed sockets */ + if (csk->sk_state == TCP_LISTEN || csk->sk_state == TCP_CLOSE) { + err = -EOPNOTSUPP; + goto out; + } + + psock = kmem_cache_zalloc(kcm_psockp, GFP_KERNEL); + if (!psock) { + err = -ENOMEM; + goto out; + } + + psock->mux = mux; + psock->sk = csk; + psock->bpf_prog = prog; + + write_lock_bh(&csk->sk_callback_lock); + + /* Check if sk_user_data is already by KCM or someone else. + * Must be done under lock to prevent race conditions. 
+ */ + if (csk->sk_user_data) { + write_unlock_bh(&csk->sk_callback_lock); + kmem_cache_free(kcm_psockp, psock); + err = -EALREADY; + goto out; + } + + err = strp_init(&psock->strp, csk, &cb); + if (err) { + write_unlock_bh(&csk->sk_callback_lock); + kmem_cache_free(kcm_psockp, psock); + goto out; + } + + psock->save_data_ready = csk->sk_data_ready; + psock->save_write_space = csk->sk_write_space; + psock->save_state_change = csk->sk_state_change; + csk->sk_user_data = psock; + csk->sk_data_ready = psock_data_ready; + csk->sk_write_space = psock_write_space; + csk->sk_state_change = psock_state_change; + + write_unlock_bh(&csk->sk_callback_lock); + + sock_hold(csk); + + /* Finished initialization, now add the psock to the MUX. */ + spin_lock_bh(&mux->lock); + head = &mux->psocks; + list_for_each_entry(tpsock, &mux->psocks, psock_list) { + if (tpsock->index != index) + break; + head = &tpsock->psock_list; + index++; + } + + list_add(&psock->psock_list, head); + psock->index = index; + + KCM_STATS_INCR(mux->stats.psock_attach); + mux->psocks_cnt++; + psock_now_avail(psock); + spin_unlock_bh(&mux->lock); + + /* Schedule RX work in case there are already bytes queued */ + strp_check_rcv(&psock->strp); + +out: + release_sock(csk); + + return err; +} + +static int kcm_attach_ioctl(struct socket *sock, struct kcm_attach *info) +{ + struct socket *csock; + struct bpf_prog *prog; + int err; + + csock = sockfd_lookup(info->fd, &err); + if (!csock) + return -ENOENT; + + prog = bpf_prog_get_type(info->bpf_fd, BPF_PROG_TYPE_SOCKET_FILTER); + if (IS_ERR(prog)) { + err = PTR_ERR(prog); + goto out; + } + + err = kcm_attach(sock, csock, prog); + if (err) { + bpf_prog_put(prog); + goto out; + } + + /* Keep reference on file also */ + + return 0; +out: + sockfd_put(csock); + return err; +} + +static void kcm_unattach(struct kcm_psock *psock) +{ + struct sock *csk = psock->sk; + struct kcm_mux *mux = psock->mux; + + lock_sock(csk); + + /* Stop getting callbacks from TCP socket. After this there should + * be no way to reserve a kcm for this psock. + */ + write_lock_bh(&csk->sk_callback_lock); + csk->sk_user_data = NULL; + csk->sk_data_ready = psock->save_data_ready; + csk->sk_write_space = psock->save_write_space; + csk->sk_state_change = psock->save_state_change; + strp_stop(&psock->strp); + + if (WARN_ON(psock->rx_kcm)) { + write_unlock_bh(&csk->sk_callback_lock); + release_sock(csk); + return; + } + + spin_lock_bh(&mux->rx_lock); + + /* Stop receiver activities. After this point psock should not be + * able to get onto ready list either through callbacks or work. + */ + if (psock->ready_rx_msg) { + list_del(&psock->psock_ready_list); + kfree_skb(psock->ready_rx_msg); + psock->ready_rx_msg = NULL; + KCM_STATS_INCR(mux->stats.rx_ready_drops); + } + + spin_unlock_bh(&mux->rx_lock); + + write_unlock_bh(&csk->sk_callback_lock); + + /* Call strp_done without sock lock */ + release_sock(csk); + strp_done(&psock->strp); + lock_sock(csk); + + bpf_prog_put(psock->bpf_prog); + + spin_lock_bh(&mux->lock); + + aggregate_psock_stats(&psock->stats, &mux->aggregate_psock_stats); + save_strp_stats(&psock->strp, &mux->aggregate_strp_stats); + + KCM_STATS_INCR(mux->stats.psock_unattach); + + if (psock->tx_kcm) { + /* psock was reserved. Just mark it finished and we will clean + * up in the kcm paths, we need kcm lock which can not be + * acquired here. + */ + KCM_STATS_INCR(mux->stats.psock_unattach_rsvd); + spin_unlock_bh(&mux->lock); + + /* We are unattaching a socket that is reserved. 
Abort the + * socket since we may be out of sync in sending on it. We need + * to do this without the mux lock. + */ + kcm_abort_tx_psock(psock, EPIPE, false); + + spin_lock_bh(&mux->lock); + if (!psock->tx_kcm) { + /* psock now unreserved in window mux was unlocked */ + goto no_reserved; + } + psock->done = 1; + + /* Commit done before queuing work to process it */ + smp_mb(); + + /* Queue tx work to make sure psock->done is handled */ + queue_work(kcm_wq, &psock->tx_kcm->tx_work); + spin_unlock_bh(&mux->lock); + } else { +no_reserved: + if (!psock->tx_stopped) + list_del(&psock->psock_avail_list); + list_del(&psock->psock_list); + mux->psocks_cnt--; + spin_unlock_bh(&mux->lock); + + sock_put(csk); + fput(csk->sk_socket->file); + kmem_cache_free(kcm_psockp, psock); + } + + release_sock(csk); +} + +static int kcm_unattach_ioctl(struct socket *sock, struct kcm_unattach *info) +{ + struct kcm_sock *kcm = kcm_sk(sock->sk); + struct kcm_mux *mux = kcm->mux; + struct kcm_psock *psock; + struct socket *csock; + struct sock *csk; + int err; + + csock = sockfd_lookup(info->fd, &err); + if (!csock) + return -ENOENT; + + csk = csock->sk; + if (!csk) { + err = -EINVAL; + goto out; + } + + err = -ENOENT; + + spin_lock_bh(&mux->lock); + + list_for_each_entry(psock, &mux->psocks, psock_list) { + if (psock->sk != csk) + continue; + + /* Found the matching psock */ + + if (psock->unattaching || WARN_ON(psock->done)) { + err = -EALREADY; + break; + } + + psock->unattaching = 1; + + spin_unlock_bh(&mux->lock); + + /* Lower socket lock should already be held */ + kcm_unattach(psock); + + err = 0; + goto out; + } + + spin_unlock_bh(&mux->lock); + +out: + sockfd_put(csock); + return err; +} + +static struct proto kcm_proto = { + .name = "KCM", + .owner = THIS_MODULE, + .obj_size = sizeof(struct kcm_sock), +}; + +/* Clone a kcm socket. 
*/ +static struct file *kcm_clone(struct socket *osock) +{ + struct socket *newsock; + struct sock *newsk; + + newsock = sock_alloc(); + if (!newsock) + return ERR_PTR(-ENFILE); + + newsock->type = osock->type; + newsock->ops = osock->ops; + + __module_get(newsock->ops->owner); + + newsk = sk_alloc(sock_net(osock->sk), PF_KCM, GFP_KERNEL, + &kcm_proto, false); + if (!newsk) { + sock_release(newsock); + return ERR_PTR(-ENOMEM); + } + sock_init_data(newsock, newsk); + init_kcm_sock(kcm_sk(newsk), kcm_sk(osock->sk)->mux); + + return sock_alloc_file(newsock, 0, osock->sk->sk_prot_creator->name); +} + +static int kcm_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + int err; + + switch (cmd) { + case SIOCKCMATTACH: { + struct kcm_attach info; + + if (copy_from_user(&info, (void __user *)arg, sizeof(info))) + return -EFAULT; + + err = kcm_attach_ioctl(sock, &info); + + break; + } + case SIOCKCMUNATTACH: { + struct kcm_unattach info; + + if (copy_from_user(&info, (void __user *)arg, sizeof(info))) + return -EFAULT; + + err = kcm_unattach_ioctl(sock, &info); + + break; + } + case SIOCKCMCLONE: { + struct kcm_clone info; + struct file *file; + + info.fd = get_unused_fd_flags(0); + if (unlikely(info.fd < 0)) + return info.fd; + + file = kcm_clone(sock); + if (IS_ERR(file)) { + put_unused_fd(info.fd); + return PTR_ERR(file); + } + if (copy_to_user((void __user *)arg, &info, + sizeof(info))) { + put_unused_fd(info.fd); + fput(file); + return -EFAULT; + } + fd_install(info.fd, file); + err = 0; + break; + } + default: + err = -ENOIOCTLCMD; + break; + } + + return err; +} + +static void free_mux(struct rcu_head *rcu) +{ + struct kcm_mux *mux = container_of(rcu, + struct kcm_mux, rcu); + + kmem_cache_free(kcm_muxp, mux); +} + +static void release_mux(struct kcm_mux *mux) +{ + struct kcm_net *knet = mux->knet; + struct kcm_psock *psock, *tmp_psock; + + /* Release psocks */ + list_for_each_entry_safe(psock, tmp_psock, + &mux->psocks, psock_list) { + if (!WARN_ON(psock->unattaching)) + kcm_unattach(psock); + } + + if (WARN_ON(mux->psocks_cnt)) + return; + + __skb_queue_purge(&mux->rx_hold_queue); + + mutex_lock(&knet->mutex); + aggregate_mux_stats(&mux->stats, &knet->aggregate_mux_stats); + aggregate_psock_stats(&mux->aggregate_psock_stats, + &knet->aggregate_psock_stats); + aggregate_strp_stats(&mux->aggregate_strp_stats, + &knet->aggregate_strp_stats); + list_del_rcu(&mux->kcm_mux_list); + knet->count--; + mutex_unlock(&knet->mutex); + + call_rcu(&mux->rcu, free_mux); +} + +static void kcm_done(struct kcm_sock *kcm) +{ + struct kcm_mux *mux = kcm->mux; + struct sock *sk = &kcm->sk; + int socks_cnt; + + spin_lock_bh(&mux->rx_lock); + if (kcm->rx_psock) { + /* Cleanup in unreserve_rx_kcm */ + WARN_ON(kcm->done); + kcm->rx_disabled = 1; + kcm->done = 1; + spin_unlock_bh(&mux->rx_lock); + return; + } + + if (kcm->rx_wait) { + list_del(&kcm->wait_rx_list); + /* paired with lockless reads in kcm_rfree() */ + WRITE_ONCE(kcm->rx_wait, false); + } + /* Move any pending receive messages to other kcm sockets */ + requeue_rx_msgs(mux, &sk->sk_receive_queue); + + spin_unlock_bh(&mux->rx_lock); + + if (WARN_ON(sk_rmem_alloc_get(sk))) + return; + + /* Detach from MUX */ + spin_lock_bh(&mux->lock); + + list_del(&kcm->kcm_sock_list); + mux->kcm_socks_cnt--; + socks_cnt = mux->kcm_socks_cnt; + + spin_unlock_bh(&mux->lock); + + if (!socks_cnt) { + /* We are done with the mux now. 
*/ + release_mux(mux); + } + + WARN_ON(kcm->rx_wait); + + sock_put(&kcm->sk); +} + +/* Called by kcm_release to close a KCM socket. + * If this is the last KCM socket on the MUX, destroy the MUX. + */ +static int kcm_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct kcm_sock *kcm; + struct kcm_mux *mux; + struct kcm_psock *psock; + + if (!sk) + return 0; + + kcm = kcm_sk(sk); + mux = kcm->mux; + + lock_sock(sk); + sock_orphan(sk); + kfree_skb(kcm->seq_skb); + + /* Purge queue under lock to avoid race condition with tx_work trying + * to act when queue is nonempty. If tx_work runs after this point + * it will just return. + */ + __skb_queue_purge(&sk->sk_write_queue); + + /* Set tx_stopped. This is checked when psock is bound to a kcm and we + * get a writespace callback. This prevents further work being queued + * from the callback (unbinding the psock occurs after canceling work. + */ + kcm->tx_stopped = 1; + + release_sock(sk); + + spin_lock_bh(&mux->lock); + if (kcm->tx_wait) { + /* Take of tx_wait list, after this point there should be no way + * that a psock will be assigned to this kcm. + */ + list_del(&kcm->wait_psock_list); + kcm->tx_wait = false; + } + spin_unlock_bh(&mux->lock); + + /* Cancel work. After this point there should be no outside references + * to the kcm socket. + */ + cancel_work_sync(&kcm->tx_work); + + lock_sock(sk); + psock = kcm->tx_psock; + if (psock) { + /* A psock was reserved, so we need to kill it since it + * may already have some bytes queued from a message. We + * need to do this after removing kcm from tx_wait list. + */ + kcm_abort_tx_psock(psock, EPIPE, false); + unreserve_psock(kcm); + } + release_sock(sk); + + WARN_ON(kcm->tx_wait); + WARN_ON(kcm->tx_psock); + + sock->sk = NULL; + + kcm_done(kcm); + + return 0; +} + +static const struct proto_ops kcm_dgram_ops = { + .family = PF_KCM, + .owner = THIS_MODULE, + .release = kcm_release, + .bind = sock_no_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = kcm_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = kcm_setsockopt, + .getsockopt = kcm_getsockopt, + .sendmsg = kcm_sendmsg, + .recvmsg = kcm_recvmsg, + .mmap = sock_no_mmap, + .sendpage = kcm_sendpage, +}; + +static const struct proto_ops kcm_seqpacket_ops = { + .family = PF_KCM, + .owner = THIS_MODULE, + .release = kcm_release, + .bind = sock_no_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = kcm_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = kcm_setsockopt, + .getsockopt = kcm_getsockopt, + .sendmsg = kcm_sendmsg, + .recvmsg = kcm_recvmsg, + .mmap = sock_no_mmap, + .sendpage = kcm_sendpage, + .splice_read = kcm_splice_read, +}; + +/* Create proto operation for kcm sockets */ +static int kcm_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct kcm_net *knet = net_generic(net, kcm_net_id); + struct sock *sk; + struct kcm_mux *mux; + + switch (sock->type) { + case SOCK_DGRAM: + sock->ops = &kcm_dgram_ops; + break; + case SOCK_SEQPACKET: + sock->ops = &kcm_seqpacket_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + if (protocol != KCMPROTO_CONNECTED) + return -EPROTONOSUPPORT; + + sk = sk_alloc(net, PF_KCM, GFP_KERNEL, &kcm_proto, kern); + if (!sk) + return -ENOMEM; + + /* Allocate a kcm mux, 
shared between KCM sockets */ + mux = kmem_cache_zalloc(kcm_muxp, GFP_KERNEL); + if (!mux) { + sk_free(sk); + return -ENOMEM; + } + + spin_lock_init(&mux->lock); + spin_lock_init(&mux->rx_lock); + INIT_LIST_HEAD(&mux->kcm_socks); + INIT_LIST_HEAD(&mux->kcm_rx_waiters); + INIT_LIST_HEAD(&mux->kcm_tx_waiters); + + INIT_LIST_HEAD(&mux->psocks); + INIT_LIST_HEAD(&mux->psocks_ready); + INIT_LIST_HEAD(&mux->psocks_avail); + + mux->knet = knet; + + /* Add new MUX to list */ + mutex_lock(&knet->mutex); + list_add_rcu(&mux->kcm_mux_list, &knet->mux_list); + knet->count++; + mutex_unlock(&knet->mutex); + + skb_queue_head_init(&mux->rx_hold_queue); + + /* Init KCM socket */ + sock_init_data(sock, sk); + init_kcm_sock(kcm_sk(sk), mux); + + return 0; +} + +static const struct net_proto_family kcm_family_ops = { + .family = PF_KCM, + .create = kcm_create, + .owner = THIS_MODULE, +}; + +static __net_init int kcm_init_net(struct net *net) +{ + struct kcm_net *knet = net_generic(net, kcm_net_id); + + INIT_LIST_HEAD_RCU(&knet->mux_list); + mutex_init(&knet->mutex); + + return 0; +} + +static __net_exit void kcm_exit_net(struct net *net) +{ + struct kcm_net *knet = net_generic(net, kcm_net_id); + + /* All KCM sockets should be closed at this point, which should mean + * that all multiplexors and psocks have been destroyed. + */ + WARN_ON(!list_empty(&knet->mux_list)); + + mutex_destroy(&knet->mutex); +} + +static struct pernet_operations kcm_net_ops = { + .init = kcm_init_net, + .exit = kcm_exit_net, + .id = &kcm_net_id, + .size = sizeof(struct kcm_net), +}; + +static int __init kcm_init(void) +{ + int err = -ENOMEM; + + kcm_muxp = kmem_cache_create("kcm_mux_cache", + sizeof(struct kcm_mux), 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!kcm_muxp) + goto fail; + + kcm_psockp = kmem_cache_create("kcm_psock_cache", + sizeof(struct kcm_psock), 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!kcm_psockp) + goto fail; + + kcm_wq = create_singlethread_workqueue("kkcmd"); + if (!kcm_wq) + goto fail; + + err = proto_register(&kcm_proto, 1); + if (err) + goto fail; + + err = register_pernet_device(&kcm_net_ops); + if (err) + goto net_ops_fail; + + err = sock_register(&kcm_family_ops); + if (err) + goto sock_register_fail; + + err = kcm_proc_init(); + if (err) + goto proc_init_fail; + + return 0; + +proc_init_fail: + sock_unregister(PF_KCM); + +sock_register_fail: + unregister_pernet_device(&kcm_net_ops); + +net_ops_fail: + proto_unregister(&kcm_proto); + +fail: + kmem_cache_destroy(kcm_muxp); + kmem_cache_destroy(kcm_psockp); + + if (kcm_wq) + destroy_workqueue(kcm_wq); + + return err; +} + +static void __exit kcm_exit(void) +{ + kcm_proc_exit(); + sock_unregister(PF_KCM); + unregister_pernet_device(&kcm_net_ops); + proto_unregister(&kcm_proto); + destroy_workqueue(kcm_wq); + + kmem_cache_destroy(kcm_muxp); + kmem_cache_destroy(kcm_psockp); +} + +module_init(kcm_init); +module_exit(kcm_exit); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_KCM); diff --git a/net/key/Makefile b/net/key/Makefile new file mode 100644 index 000000000..ed779c22f --- /dev/null +++ b/net/key/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the key AF. +# + +obj-$(CONFIG_NET_KEY) += af_key.o diff --git a/net/key/af_key.c b/net/key/af_key.c new file mode 100644 index 000000000..8a8f2429d --- /dev/null +++ b/net/key/af_key.c @@ -0,0 +1,3930 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/key/af_key.c An implementation of PF_KEYv2 sockets. + * + * Authors: Maxim Giryaev + * David S. 
Miller + * Alexey Kuznetsov + * Kunihiro Ishiguro + * Kazunori MIYAZAWA / USAGI Project + * Derek Atkins + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define _X2KEY(x) ((x) == XFRM_INF ? 0 : (x)) +#define _KEY2X(x) ((x) == 0 ? XFRM_INF : (x)) + +static unsigned int pfkey_net_id __read_mostly; +struct netns_pfkey { + /* List of all pfkey sockets. */ + struct hlist_head table; + atomic_t socks_nr; +}; +static DEFINE_MUTEX(pfkey_mutex); + +#define DUMMY_MARK 0 +static const struct xfrm_mark dummy_mark = {0, 0}; +struct pfkey_sock { + /* struct sock must be the first member of struct pfkey_sock */ + struct sock sk; + int registered; + int promisc; + + struct { + uint8_t msg_version; + uint32_t msg_portid; + int (*dump)(struct pfkey_sock *sk); + void (*done)(struct pfkey_sock *sk); + union { + struct xfrm_policy_walk policy; + struct xfrm_state_walk state; + } u; + struct sk_buff *skb; + } dump; + struct mutex dump_lock; +}; + +static int parse_sockaddr_pair(struct sockaddr *sa, int ext_len, + xfrm_address_t *saddr, xfrm_address_t *daddr, + u16 *family); + +static inline struct pfkey_sock *pfkey_sk(struct sock *sk) +{ + return (struct pfkey_sock *)sk; +} + +static int pfkey_can_dump(const struct sock *sk) +{ + if (3 * atomic_read(&sk->sk_rmem_alloc) <= 2 * sk->sk_rcvbuf) + return 1; + return 0; +} + +static void pfkey_terminate_dump(struct pfkey_sock *pfk) +{ + if (pfk->dump.dump) { + if (pfk->dump.skb) { + kfree_skb(pfk->dump.skb); + pfk->dump.skb = NULL; + } + pfk->dump.done(pfk); + pfk->dump.dump = NULL; + pfk->dump.done = NULL; + } +} + +static void pfkey_sock_destruct(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + pfkey_terminate_dump(pfkey_sk(sk)); + skb_queue_purge(&sk->sk_receive_queue); + + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Attempt to release alive pfkey socket: %p\n", sk); + return; + } + + WARN_ON(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); + + atomic_dec(&net_pfkey->socks_nr); +} + +static const struct proto_ops pfkey_ops; + +static void pfkey_insert(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + mutex_lock(&pfkey_mutex); + sk_add_node_rcu(sk, &net_pfkey->table); + mutex_unlock(&pfkey_mutex); +} + +static void pfkey_remove(struct sock *sk) +{ + mutex_lock(&pfkey_mutex); + sk_del_node_init_rcu(sk); + mutex_unlock(&pfkey_mutex); +} + +static struct proto key_proto = { + .name = "KEY", + .owner = THIS_MODULE, + .obj_size = sizeof(struct pfkey_sock), +}; + +static int pfkey_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + struct sock *sk; + struct pfkey_sock *pfk; + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + if (protocol != PF_KEY_V2) + return -EPROTONOSUPPORT; + + sk = sk_alloc(net, PF_KEY, GFP_KERNEL, &key_proto, kern); + if (sk == NULL) + return -ENOMEM; + + pfk = pfkey_sk(sk); + mutex_init(&pfk->dump_lock); + + sock->ops = &pfkey_ops; + sock_init_data(sock, sk); + + sk->sk_family = PF_KEY; + sk->sk_destruct = pfkey_sock_destruct; + + atomic_inc(&net_pfkey->socks_nr); + + pfkey_insert(sk); + + return 0; +} + +static int pfkey_release(struct socket *sock) +{ + struct sock *sk = 
sock->sk; + + if (!sk) + return 0; + + pfkey_remove(sk); + + sock_orphan(sk); + sock->sk = NULL; + skb_queue_purge(&sk->sk_write_queue); + + synchronize_rcu(); + sock_put(sk); + + return 0; +} + +static int pfkey_broadcast_one(struct sk_buff *skb, gfp_t allocation, + struct sock *sk) +{ + int err = -ENOBUFS; + + if (atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf) + return err; + + skb = skb_clone(skb, allocation); + + if (skb) { + skb_set_owner_r(skb, sk); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + err = 0; + } + return err; +} + +/* Send SKB to all pfkey sockets matching selected criteria. */ +#define BROADCAST_ALL 0 +#define BROADCAST_ONE 1 +#define BROADCAST_REGISTERED 2 +#define BROADCAST_PROMISC_ONLY 4 +static int pfkey_broadcast(struct sk_buff *skb, gfp_t allocation, + int broadcast_flags, struct sock *one_sk, + struct net *net) +{ + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + struct sock *sk; + int err = -ESRCH; + + /* XXX Do we need something like netlink_overrun? I think + * XXX PF_KEY socket apps will not mind current behavior. + */ + if (!skb) + return -ENOMEM; + + rcu_read_lock(); + sk_for_each_rcu(sk, &net_pfkey->table) { + struct pfkey_sock *pfk = pfkey_sk(sk); + int err2; + + /* Yes, it means that if you are meant to receive this + * pfkey message you receive it twice as promiscuous + * socket. + */ + if (pfk->promisc) + pfkey_broadcast_one(skb, GFP_ATOMIC, sk); + + /* the exact target will be processed later */ + if (sk == one_sk) + continue; + if (broadcast_flags != BROADCAST_ALL) { + if (broadcast_flags & BROADCAST_PROMISC_ONLY) + continue; + if ((broadcast_flags & BROADCAST_REGISTERED) && + !pfk->registered) + continue; + if (broadcast_flags & BROADCAST_ONE) + continue; + } + + err2 = pfkey_broadcast_one(skb, GFP_ATOMIC, sk); + + /* Error is cleared after successful sending to at least one + * registered KM */ + if ((broadcast_flags & BROADCAST_REGISTERED) && err) + err = err2; + } + rcu_read_unlock(); + + if (one_sk != NULL) + err = pfkey_broadcast_one(skb, allocation, one_sk); + + kfree_skb(skb); + return err; +} + +static int pfkey_do_dump(struct pfkey_sock *pfk) +{ + struct sadb_msg *hdr; + int rc; + + mutex_lock(&pfk->dump_lock); + if (!pfk->dump.dump) { + rc = 0; + goto out; + } + + rc = pfk->dump.dump(pfk); + if (rc == -ENOBUFS) { + rc = 0; + goto out; + } + + if (pfk->dump.skb) { + if (!pfkey_can_dump(&pfk->sk)) { + rc = 0; + goto out; + } + + hdr = (struct sadb_msg *) pfk->dump.skb->data; + hdr->sadb_msg_seq = 0; + hdr->sadb_msg_errno = rc; + pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE, + &pfk->sk, sock_net(&pfk->sk)); + pfk->dump.skb = NULL; + } + + pfkey_terminate_dump(pfk); + +out: + mutex_unlock(&pfk->dump_lock); + return rc; +} + +static inline void pfkey_hdr_dup(struct sadb_msg *new, + const struct sadb_msg *orig) +{ + *new = *orig; +} + +static int pfkey_error(const struct sadb_msg *orig, int err, struct sock *sk) +{ + struct sk_buff *skb = alloc_skb(sizeof(struct sadb_msg) + 16, GFP_KERNEL); + struct sadb_msg *hdr; + + if (!skb) + return -ENOBUFS; + + /* Woe be to the platform trying to support PFKEY yet + * having normal errnos outside the 1-255 range, inclusive. 
+ */ + err = -err; + if (err == ERESTARTSYS || + err == ERESTARTNOHAND || + err == ERESTARTNOINTR) + err = EINTR; + if (err >= 512) + err = EINVAL; + BUG_ON(err <= 0 || err >= 256); + + hdr = skb_put(skb, sizeof(struct sadb_msg)); + pfkey_hdr_dup(hdr, orig); + hdr->sadb_msg_errno = (uint8_t) err; + hdr->sadb_msg_len = (sizeof(struct sadb_msg) / + sizeof(uint64_t)); + + pfkey_broadcast(skb, GFP_KERNEL, BROADCAST_ONE, sk, sock_net(sk)); + + return 0; +} + +static const u8 sadb_ext_min_len[] = { + [SADB_EXT_RESERVED] = (u8) 0, + [SADB_EXT_SA] = (u8) sizeof(struct sadb_sa), + [SADB_EXT_LIFETIME_CURRENT] = (u8) sizeof(struct sadb_lifetime), + [SADB_EXT_LIFETIME_HARD] = (u8) sizeof(struct sadb_lifetime), + [SADB_EXT_LIFETIME_SOFT] = (u8) sizeof(struct sadb_lifetime), + [SADB_EXT_ADDRESS_SRC] = (u8) sizeof(struct sadb_address), + [SADB_EXT_ADDRESS_DST] = (u8) sizeof(struct sadb_address), + [SADB_EXT_ADDRESS_PROXY] = (u8) sizeof(struct sadb_address), + [SADB_EXT_KEY_AUTH] = (u8) sizeof(struct sadb_key), + [SADB_EXT_KEY_ENCRYPT] = (u8) sizeof(struct sadb_key), + [SADB_EXT_IDENTITY_SRC] = (u8) sizeof(struct sadb_ident), + [SADB_EXT_IDENTITY_DST] = (u8) sizeof(struct sadb_ident), + [SADB_EXT_SENSITIVITY] = (u8) sizeof(struct sadb_sens), + [SADB_EXT_PROPOSAL] = (u8) sizeof(struct sadb_prop), + [SADB_EXT_SUPPORTED_AUTH] = (u8) sizeof(struct sadb_supported), + [SADB_EXT_SUPPORTED_ENCRYPT] = (u8) sizeof(struct sadb_supported), + [SADB_EXT_SPIRANGE] = (u8) sizeof(struct sadb_spirange), + [SADB_X_EXT_KMPRIVATE] = (u8) sizeof(struct sadb_x_kmprivate), + [SADB_X_EXT_POLICY] = (u8) sizeof(struct sadb_x_policy), + [SADB_X_EXT_SA2] = (u8) sizeof(struct sadb_x_sa2), + [SADB_X_EXT_NAT_T_TYPE] = (u8) sizeof(struct sadb_x_nat_t_type), + [SADB_X_EXT_NAT_T_SPORT] = (u8) sizeof(struct sadb_x_nat_t_port), + [SADB_X_EXT_NAT_T_DPORT] = (u8) sizeof(struct sadb_x_nat_t_port), + [SADB_X_EXT_NAT_T_OA] = (u8) sizeof(struct sadb_address), + [SADB_X_EXT_SEC_CTX] = (u8) sizeof(struct sadb_x_sec_ctx), + [SADB_X_EXT_KMADDRESS] = (u8) sizeof(struct sadb_x_kmaddress), + [SADB_X_EXT_FILTER] = (u8) sizeof(struct sadb_x_filter), +}; + +/* Verify sadb_address_{len,prefixlen} against sa_family. */ +static int verify_address_len(const void *p) +{ + const struct sadb_address *sp = p; + const struct sockaddr *addr = (const struct sockaddr *)(sp + 1); + const struct sockaddr_in *sin; +#if IS_ENABLED(CONFIG_IPV6) + const struct sockaddr_in6 *sin6; +#endif + int len; + + if (sp->sadb_address_len < + DIV_ROUND_UP(sizeof(*sp) + offsetofend(typeof(*addr), sa_family), + sizeof(uint64_t))) + return -EINVAL; + + switch (addr->sa_family) { + case AF_INET: + len = DIV_ROUND_UP(sizeof(*sp) + sizeof(*sin), sizeof(uint64_t)); + if (sp->sadb_address_len != len || + sp->sadb_address_prefixlen > 32) + return -EINVAL; + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + len = DIV_ROUND_UP(sizeof(*sp) + sizeof(*sin6), sizeof(uint64_t)); + if (sp->sadb_address_len != len || + sp->sadb_address_prefixlen > 128) + return -EINVAL; + break; +#endif + default: + /* It is user using kernel to keep track of security + * associations for another protocol, such as + * OSPF/RSVP/RIPV2/MIP. It is user's job to verify + * lengths. + * + * XXX Actually, association/policy database is not yet + * XXX able to cope with arbitrary sockaddr families. + * XXX When it can, remove this -EINVAL. 
-DaveM + */ + return -EINVAL; + } + + return 0; +} + +static inline int sadb_key_len(const struct sadb_key *key) +{ + int key_bytes = DIV_ROUND_UP(key->sadb_key_bits, 8); + + return DIV_ROUND_UP(sizeof(struct sadb_key) + key_bytes, + sizeof(uint64_t)); +} + +static int verify_key_len(const void *p) +{ + const struct sadb_key *key = p; + + if (sadb_key_len(key) > key->sadb_key_len) + return -EINVAL; + + return 0; +} + +static inline int pfkey_sec_ctx_len(const struct sadb_x_sec_ctx *sec_ctx) +{ + return DIV_ROUND_UP(sizeof(struct sadb_x_sec_ctx) + + sec_ctx->sadb_x_ctx_len, + sizeof(uint64_t)); +} + +static inline int verify_sec_ctx_len(const void *p) +{ + const struct sadb_x_sec_ctx *sec_ctx = p; + int len = sec_ctx->sadb_x_ctx_len; + + if (len > PAGE_SIZE) + return -EINVAL; + + len = pfkey_sec_ctx_len(sec_ctx); + + if (sec_ctx->sadb_x_sec_len != len) + return -EINVAL; + + return 0; +} + +static inline struct xfrm_user_sec_ctx *pfkey_sadb2xfrm_user_sec_ctx(const struct sadb_x_sec_ctx *sec_ctx, + gfp_t gfp) +{ + struct xfrm_user_sec_ctx *uctx = NULL; + int ctx_size = sec_ctx->sadb_x_ctx_len; + + uctx = kmalloc((sizeof(*uctx)+ctx_size), gfp); + + if (!uctx) + return NULL; + + uctx->len = pfkey_sec_ctx_len(sec_ctx); + uctx->exttype = sec_ctx->sadb_x_sec_exttype; + uctx->ctx_doi = sec_ctx->sadb_x_ctx_doi; + uctx->ctx_alg = sec_ctx->sadb_x_ctx_alg; + uctx->ctx_len = sec_ctx->sadb_x_ctx_len; + memcpy(uctx + 1, sec_ctx + 1, + uctx->ctx_len); + + return uctx; +} + +static int present_and_same_family(const struct sadb_address *src, + const struct sadb_address *dst) +{ + const struct sockaddr *s_addr, *d_addr; + + if (!src || !dst) + return 0; + + s_addr = (const struct sockaddr *)(src + 1); + d_addr = (const struct sockaddr *)(dst + 1); + if (s_addr->sa_family != d_addr->sa_family) + return 0; + if (s_addr->sa_family != AF_INET +#if IS_ENABLED(CONFIG_IPV6) + && s_addr->sa_family != AF_INET6 +#endif + ) + return 0; + + return 1; +} + +static int parse_exthdrs(struct sk_buff *skb, const struct sadb_msg *hdr, void **ext_hdrs) +{ + const char *p = (char *) hdr; + int len = skb->len; + + len -= sizeof(*hdr); + p += sizeof(*hdr); + while (len > 0) { + const struct sadb_ext *ehdr = (const struct sadb_ext *) p; + uint16_t ext_type; + int ext_len; + + if (len < sizeof(*ehdr)) + return -EINVAL; + + ext_len = ehdr->sadb_ext_len; + ext_len *= sizeof(uint64_t); + ext_type = ehdr->sadb_ext_type; + if (ext_len < sizeof(uint64_t) || + ext_len > len || + ext_type == SADB_EXT_RESERVED) + return -EINVAL; + + if (ext_type <= SADB_EXT_MAX) { + int min = (int) sadb_ext_min_len[ext_type]; + if (ext_len < min) + return -EINVAL; + if (ext_hdrs[ext_type-1] != NULL) + return -EINVAL; + switch (ext_type) { + case SADB_EXT_ADDRESS_SRC: + case SADB_EXT_ADDRESS_DST: + case SADB_EXT_ADDRESS_PROXY: + case SADB_X_EXT_NAT_T_OA: + if (verify_address_len(p)) + return -EINVAL; + break; + case SADB_X_EXT_SEC_CTX: + if (verify_sec_ctx_len(p)) + return -EINVAL; + break; + case SADB_EXT_KEY_AUTH: + case SADB_EXT_KEY_ENCRYPT: + if (verify_key_len(p)) + return -EINVAL; + break; + default: + break; + } + ext_hdrs[ext_type-1] = (void *) p; + } + p += ext_len; + len -= ext_len; + } + + return 0; +} + +static uint16_t +pfkey_satype2proto(uint8_t satype) +{ + switch (satype) { + case SADB_SATYPE_UNSPEC: + return IPSEC_PROTO_ANY; + case SADB_SATYPE_AH: + return IPPROTO_AH; + case SADB_SATYPE_ESP: + return IPPROTO_ESP; + case SADB_X_SATYPE_IPCOMP: + return IPPROTO_COMP; + default: + return 0; + } + /* NOTREACHED */ +} + +static uint8_t 
+pfkey_proto2satype(uint16_t proto) +{ + switch (proto) { + case IPPROTO_AH: + return SADB_SATYPE_AH; + case IPPROTO_ESP: + return SADB_SATYPE_ESP; + case IPPROTO_COMP: + return SADB_X_SATYPE_IPCOMP; + default: + return 0; + } + /* NOTREACHED */ +} + +/* BTW, this scheme means that there is no way with PFKEY2 sockets to + * say specifically 'just raw sockets' as we encode them as 255. + */ + +static uint8_t pfkey_proto_to_xfrm(uint8_t proto) +{ + return proto == IPSEC_PROTO_ANY ? 0 : proto; +} + +static uint8_t pfkey_proto_from_xfrm(uint8_t proto) +{ + return proto ? proto : IPSEC_PROTO_ANY; +} + +static inline int pfkey_sockaddr_len(sa_family_t family) +{ + switch (family) { + case AF_INET: + return sizeof(struct sockaddr_in); +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + return sizeof(struct sockaddr_in6); +#endif + } + return 0; +} + +static +int pfkey_sockaddr_extract(const struct sockaddr *sa, xfrm_address_t *xaddr) +{ + switch (sa->sa_family) { + case AF_INET: + xaddr->a4 = + ((struct sockaddr_in *)sa)->sin_addr.s_addr; + return AF_INET; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + memcpy(xaddr->a6, + &((struct sockaddr_in6 *)sa)->sin6_addr, + sizeof(struct in6_addr)); + return AF_INET6; +#endif + } + return 0; +} + +static +int pfkey_sadb_addr2xfrm_addr(const struct sadb_address *addr, xfrm_address_t *xaddr) +{ + return pfkey_sockaddr_extract((struct sockaddr *)(addr + 1), + xaddr); +} + +static struct xfrm_state *pfkey_xfrm_state_lookup(struct net *net, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + const struct sadb_sa *sa; + const struct sadb_address *addr; + uint16_t proto; + unsigned short family; + xfrm_address_t *xaddr; + + sa = ext_hdrs[SADB_EXT_SA - 1]; + if (sa == NULL) + return NULL; + + proto = pfkey_satype2proto(hdr->sadb_msg_satype); + if (proto == 0) + return NULL; + + /* sadb_address_len should be checked by caller */ + addr = ext_hdrs[SADB_EXT_ADDRESS_DST - 1]; + if (addr == NULL) + return NULL; + + family = ((const struct sockaddr *)(addr + 1))->sa_family; + switch (family) { + case AF_INET: + xaddr = (xfrm_address_t *)&((const struct sockaddr_in *)(addr + 1))->sin_addr; + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + xaddr = (xfrm_address_t *)&((const struct sockaddr_in6 *)(addr + 1))->sin6_addr; + break; +#endif + default: + xaddr = NULL; + } + + if (!xaddr) + return NULL; + + return xfrm_state_lookup(net, DUMMY_MARK, xaddr, sa->sadb_sa_spi, proto, family); +} + +#define PFKEY_ALIGN8(a) (1 + (((a) - 1) | (8 - 1))) + +static int +pfkey_sockaddr_size(sa_family_t family) +{ + return PFKEY_ALIGN8(pfkey_sockaddr_len(family)); +} + +static inline int pfkey_mode_from_xfrm(int mode) +{ + switch(mode) { + case XFRM_MODE_TRANSPORT: + return IPSEC_MODE_TRANSPORT; + case XFRM_MODE_TUNNEL: + return IPSEC_MODE_TUNNEL; + case XFRM_MODE_BEET: + return IPSEC_MODE_BEET; + default: + return -1; + } +} + +static inline int pfkey_mode_to_xfrm(int mode) +{ + switch(mode) { + case IPSEC_MODE_ANY: /*XXX*/ + case IPSEC_MODE_TRANSPORT: + return XFRM_MODE_TRANSPORT; + case IPSEC_MODE_TUNNEL: + return XFRM_MODE_TUNNEL; + case IPSEC_MODE_BEET: + return XFRM_MODE_BEET; + default: + return -1; + } +} + +static unsigned int pfkey_sockaddr_fill(const xfrm_address_t *xaddr, __be16 port, + struct sockaddr *sa, + unsigned short family) +{ + switch (family) { + case AF_INET: + { + struct sockaddr_in *sin = (struct sockaddr_in *)sa; + sin->sin_family = AF_INET; + sin->sin_port = port; + sin->sin_addr.s_addr = xaddr->a4; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + 
return 32; + } +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sa; + sin6->sin6_family = AF_INET6; + sin6->sin6_port = port; + sin6->sin6_flowinfo = 0; + sin6->sin6_addr = xaddr->in6; + sin6->sin6_scope_id = 0; + return 128; + } +#endif + } + return 0; +} + +static struct sk_buff *__pfkey_xfrm_state2msg(const struct xfrm_state *x, + int add_keys, int hsc) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + struct sadb_sa *sa; + struct sadb_lifetime *lifetime; + struct sadb_address *addr; + struct sadb_key *key; + struct sadb_x_sa2 *sa2; + struct sadb_x_sec_ctx *sec_ctx; + struct xfrm_sec_ctx *xfrm_ctx; + int ctx_size = 0; + int size; + int auth_key_size = 0; + int encrypt_key_size = 0; + int sockaddr_size; + struct xfrm_encap_tmpl *natt = NULL; + int mode; + + /* address family check */ + sockaddr_size = pfkey_sockaddr_size(x->props.family); + if (!sockaddr_size) + return ERR_PTR(-EINVAL); + + /* base, SA, (lifetime (HSC),) address(SD), (address(P),) + key(AE), (identity(SD),) (sensitivity)> */ + size = sizeof(struct sadb_msg) +sizeof(struct sadb_sa) + + sizeof(struct sadb_lifetime) + + ((hsc & 1) ? sizeof(struct sadb_lifetime) : 0) + + ((hsc & 2) ? sizeof(struct sadb_lifetime) : 0) + + sizeof(struct sadb_address)*2 + + sockaddr_size*2 + + sizeof(struct sadb_x_sa2); + + if ((xfrm_ctx = x->security)) { + ctx_size = PFKEY_ALIGN8(xfrm_ctx->ctx_len); + size += sizeof(struct sadb_x_sec_ctx) + ctx_size; + } + + /* identity & sensitivity */ + if (!xfrm_addr_equal(&x->sel.saddr, &x->props.saddr, x->props.family)) + size += sizeof(struct sadb_address) + sockaddr_size; + + if (add_keys) { + if (x->aalg && x->aalg->alg_key_len) { + auth_key_size = + PFKEY_ALIGN8((x->aalg->alg_key_len + 7) / 8); + size += sizeof(struct sadb_key) + auth_key_size; + } + if (x->ealg && x->ealg->alg_key_len) { + encrypt_key_size = + PFKEY_ALIGN8((x->ealg->alg_key_len+7) / 8); + size += sizeof(struct sadb_key) + encrypt_key_size; + } + } + if (x->encap) + natt = x->encap; + + if (natt && natt->encap_type) { + size += sizeof(struct sadb_x_nat_t_type); + size += sizeof(struct sadb_x_nat_t_port); + size += sizeof(struct sadb_x_nat_t_port); + } + + skb = alloc_skb(size + 16, GFP_ATOMIC); + if (skb == NULL) + return ERR_PTR(-ENOBUFS); + + /* call should fill header later */ + hdr = skb_put(skb, sizeof(struct sadb_msg)); + memset(hdr, 0, size); /* XXX do we need this ? */ + hdr->sadb_msg_len = size / sizeof(uint64_t); + + /* sa */ + sa = skb_put(skb, sizeof(struct sadb_sa)); + sa->sadb_sa_len = sizeof(struct sadb_sa)/sizeof(uint64_t); + sa->sadb_sa_exttype = SADB_EXT_SA; + sa->sadb_sa_spi = x->id.spi; + sa->sadb_sa_replay = x->props.replay_window; + switch (x->km.state) { + case XFRM_STATE_VALID: + sa->sadb_sa_state = x->km.dying ? + SADB_SASTATE_DYING : SADB_SASTATE_MATURE; + break; + case XFRM_STATE_ACQ: + sa->sadb_sa_state = SADB_SASTATE_LARVAL; + break; + default: + sa->sadb_sa_state = SADB_SASTATE_DEAD; + break; + } + sa->sadb_sa_auth = 0; + if (x->aalg) { + struct xfrm_algo_desc *a = xfrm_aalg_get_byname(x->aalg->alg_name, 0); + sa->sadb_sa_auth = (a && a->pfkey_supported) ? + a->desc.sadb_alg_id : 0; + } + sa->sadb_sa_encrypt = 0; + BUG_ON(x->ealg && x->calg); + if (x->ealg) { + struct xfrm_algo_desc *a = xfrm_ealg_get_byname(x->ealg->alg_name, 0); + sa->sadb_sa_encrypt = (a && a->pfkey_supported) ? 
+ a->desc.sadb_alg_id : 0; + } + /* KAME compatible: sadb_sa_encrypt is overloaded with calg id */ + if (x->calg) { + struct xfrm_algo_desc *a = xfrm_calg_get_byname(x->calg->alg_name, 0); + sa->sadb_sa_encrypt = (a && a->pfkey_supported) ? + a->desc.sadb_alg_id : 0; + } + + sa->sadb_sa_flags = 0; + if (x->props.flags & XFRM_STATE_NOECN) + sa->sadb_sa_flags |= SADB_SAFLAGS_NOECN; + if (x->props.flags & XFRM_STATE_DECAP_DSCP) + sa->sadb_sa_flags |= SADB_SAFLAGS_DECAP_DSCP; + if (x->props.flags & XFRM_STATE_NOPMTUDISC) + sa->sadb_sa_flags |= SADB_SAFLAGS_NOPMTUDISC; + + /* hard time */ + if (hsc & 2) { + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_HARD; + lifetime->sadb_lifetime_allocations = _X2KEY(x->lft.hard_packet_limit); + lifetime->sadb_lifetime_bytes = _X2KEY(x->lft.hard_byte_limit); + lifetime->sadb_lifetime_addtime = x->lft.hard_add_expires_seconds; + lifetime->sadb_lifetime_usetime = x->lft.hard_use_expires_seconds; + } + /* soft time */ + if (hsc & 1) { + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_SOFT; + lifetime->sadb_lifetime_allocations = _X2KEY(x->lft.soft_packet_limit); + lifetime->sadb_lifetime_bytes = _X2KEY(x->lft.soft_byte_limit); + lifetime->sadb_lifetime_addtime = x->lft.soft_add_expires_seconds; + lifetime->sadb_lifetime_usetime = x->lft.soft_use_expires_seconds; + } + /* current time */ + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT; + lifetime->sadb_lifetime_allocations = x->curlft.packets; + lifetime->sadb_lifetime_bytes = x->curlft.bytes; + lifetime->sadb_lifetime_addtime = x->curlft.add_time; + lifetime->sadb_lifetime_usetime = x->curlft.use_time; + /* src address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_SRC; + /* "if the ports are non-zero, then the sadb_address_proto field, + normally zero, MUST be filled in with the transport + protocol's number." 
- RFC2367 */ + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(&x->props.saddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + BUG_ON(!addr->sadb_address_prefixlen); + + /* dst address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_DST; + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(&x->id.daddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + BUG_ON(!addr->sadb_address_prefixlen); + + if (!xfrm_addr_equal(&x->sel.saddr, &x->props.saddr, + x->props.family)) { + addr = skb_put(skb, + sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_PROXY; + addr->sadb_address_proto = + pfkey_proto_from_xfrm(x->sel.proto); + addr->sadb_address_prefixlen = x->sel.prefixlen_s; + addr->sadb_address_reserved = 0; + + pfkey_sockaddr_fill(&x->sel.saddr, x->sel.sport, + (struct sockaddr *) (addr + 1), + x->props.family); + } + + /* auth key */ + if (add_keys && auth_key_size) { + key = skb_put(skb, sizeof(struct sadb_key) + auth_key_size); + key->sadb_key_len = (sizeof(struct sadb_key) + auth_key_size) / + sizeof(uint64_t); + key->sadb_key_exttype = SADB_EXT_KEY_AUTH; + key->sadb_key_bits = x->aalg->alg_key_len; + key->sadb_key_reserved = 0; + memcpy(key + 1, x->aalg->alg_key, (x->aalg->alg_key_len+7)/8); + } + /* encrypt key */ + if (add_keys && encrypt_key_size) { + key = skb_put(skb, sizeof(struct sadb_key) + encrypt_key_size); + key->sadb_key_len = (sizeof(struct sadb_key) + + encrypt_key_size) / sizeof(uint64_t); + key->sadb_key_exttype = SADB_EXT_KEY_ENCRYPT; + key->sadb_key_bits = x->ealg->alg_key_len; + key->sadb_key_reserved = 0; + memcpy(key + 1, x->ealg->alg_key, + (x->ealg->alg_key_len+7)/8); + } + + /* sa */ + sa2 = skb_put(skb, sizeof(struct sadb_x_sa2)); + sa2->sadb_x_sa2_len = sizeof(struct sadb_x_sa2)/sizeof(uint64_t); + sa2->sadb_x_sa2_exttype = SADB_X_EXT_SA2; + if ((mode = pfkey_mode_from_xfrm(x->props.mode)) < 0) { + kfree_skb(skb); + return ERR_PTR(-EINVAL); + } + sa2->sadb_x_sa2_mode = mode; + sa2->sadb_x_sa2_reserved1 = 0; + sa2->sadb_x_sa2_reserved2 = 0; + sa2->sadb_x_sa2_sequence = 0; + sa2->sadb_x_sa2_reqid = x->props.reqid; + + if (natt && natt->encap_type) { + struct sadb_x_nat_t_type *n_type; + struct sadb_x_nat_t_port *n_port; + + /* type */ + n_type = skb_put(skb, sizeof(*n_type)); + n_type->sadb_x_nat_t_type_len = sizeof(*n_type)/sizeof(uint64_t); + n_type->sadb_x_nat_t_type_exttype = SADB_X_EXT_NAT_T_TYPE; + n_type->sadb_x_nat_t_type_type = natt->encap_type; + n_type->sadb_x_nat_t_type_reserved[0] = 0; + n_type->sadb_x_nat_t_type_reserved[1] = 0; + n_type->sadb_x_nat_t_type_reserved[2] = 0; + + /* source port */ + n_port = skb_put(skb, sizeof(*n_port)); + n_port->sadb_x_nat_t_port_len = sizeof(*n_port)/sizeof(uint64_t); + n_port->sadb_x_nat_t_port_exttype = SADB_X_EXT_NAT_T_SPORT; + n_port->sadb_x_nat_t_port_port = natt->encap_sport; + n_port->sadb_x_nat_t_port_reserved = 0; + + /* dest port */ + n_port = skb_put(skb, sizeof(*n_port)); + n_port->sadb_x_nat_t_port_len = sizeof(*n_port)/sizeof(uint64_t); + n_port->sadb_x_nat_t_port_exttype = SADB_X_EXT_NAT_T_DPORT; + n_port->sadb_x_nat_t_port_port = 
natt->encap_dport; + n_port->sadb_x_nat_t_port_reserved = 0; + } + + /* security context */ + if (xfrm_ctx) { + sec_ctx = skb_put(skb, + sizeof(struct sadb_x_sec_ctx) + ctx_size); + sec_ctx->sadb_x_sec_len = + (sizeof(struct sadb_x_sec_ctx) + ctx_size) / sizeof(uint64_t); + sec_ctx->sadb_x_sec_exttype = SADB_X_EXT_SEC_CTX; + sec_ctx->sadb_x_ctx_doi = xfrm_ctx->ctx_doi; + sec_ctx->sadb_x_ctx_alg = xfrm_ctx->ctx_alg; + sec_ctx->sadb_x_ctx_len = xfrm_ctx->ctx_len; + memcpy(sec_ctx + 1, xfrm_ctx->ctx_str, + xfrm_ctx->ctx_len); + } + + return skb; +} + + +static inline struct sk_buff *pfkey_xfrm_state2msg(const struct xfrm_state *x) +{ + struct sk_buff *skb; + + skb = __pfkey_xfrm_state2msg(x, 1, 3); + + return skb; +} + +static inline struct sk_buff *pfkey_xfrm_state2msg_expire(const struct xfrm_state *x, + int hsc) +{ + return __pfkey_xfrm_state2msg(x, 0, hsc); +} + +static struct xfrm_state * pfkey_msg2xfrm_state(struct net *net, + const struct sadb_msg *hdr, + void * const *ext_hdrs) +{ + struct xfrm_state *x; + const struct sadb_lifetime *lifetime; + const struct sadb_sa *sa; + const struct sadb_key *key; + const struct sadb_x_sec_ctx *sec_ctx; + uint16_t proto; + int err; + + + sa = ext_hdrs[SADB_EXT_SA - 1]; + if (!sa || + !present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1])) + return ERR_PTR(-EINVAL); + if (hdr->sadb_msg_satype == SADB_SATYPE_ESP && + !ext_hdrs[SADB_EXT_KEY_ENCRYPT-1]) + return ERR_PTR(-EINVAL); + if (hdr->sadb_msg_satype == SADB_SATYPE_AH && + !ext_hdrs[SADB_EXT_KEY_AUTH-1]) + return ERR_PTR(-EINVAL); + if (!!ext_hdrs[SADB_EXT_LIFETIME_HARD-1] != + !!ext_hdrs[SADB_EXT_LIFETIME_SOFT-1]) + return ERR_PTR(-EINVAL); + + proto = pfkey_satype2proto(hdr->sadb_msg_satype); + if (proto == 0) + return ERR_PTR(-EINVAL); + + /* default error is no buffer space */ + err = -ENOBUFS; + + /* RFC2367: + + Only SADB_SASTATE_MATURE SAs may be submitted in an SADB_ADD message. + SADB_SASTATE_LARVAL SAs are created by SADB_GETSPI and it is not + sensible to add a new SA in the DYING or SADB_SASTATE_DEAD state. + Therefore, the sadb_sa_state field of all submitted SAs MUST be + SADB_SASTATE_MATURE and the kernel MUST return an error if this is + not true. + + However, KAME setkey always uses SADB_SASTATE_LARVAL. + Hence, we have to _ignore_ sadb_sa_state, which is also reasonable. 
+ */ + if (sa->sadb_sa_auth > SADB_AALG_MAX || + (hdr->sadb_msg_satype == SADB_X_SATYPE_IPCOMP && + sa->sadb_sa_encrypt > SADB_X_CALG_MAX) || + sa->sadb_sa_encrypt > SADB_EALG_MAX) + return ERR_PTR(-EINVAL); + key = ext_hdrs[SADB_EXT_KEY_AUTH - 1]; + if (key != NULL && + sa->sadb_sa_auth != SADB_X_AALG_NULL && + key->sadb_key_bits == 0) + return ERR_PTR(-EINVAL); + key = ext_hdrs[SADB_EXT_KEY_ENCRYPT-1]; + if (key != NULL && + sa->sadb_sa_encrypt != SADB_EALG_NULL && + key->sadb_key_bits == 0) + return ERR_PTR(-EINVAL); + + x = xfrm_state_alloc(net); + if (x == NULL) + return ERR_PTR(-ENOBUFS); + + x->id.proto = proto; + x->id.spi = sa->sadb_sa_spi; + x->props.replay_window = min_t(unsigned int, sa->sadb_sa_replay, + (sizeof(x->replay.bitmap) * 8)); + if (sa->sadb_sa_flags & SADB_SAFLAGS_NOECN) + x->props.flags |= XFRM_STATE_NOECN; + if (sa->sadb_sa_flags & SADB_SAFLAGS_DECAP_DSCP) + x->props.flags |= XFRM_STATE_DECAP_DSCP; + if (sa->sadb_sa_flags & SADB_SAFLAGS_NOPMTUDISC) + x->props.flags |= XFRM_STATE_NOPMTUDISC; + + lifetime = ext_hdrs[SADB_EXT_LIFETIME_HARD - 1]; + if (lifetime != NULL) { + x->lft.hard_packet_limit = _KEY2X(lifetime->sadb_lifetime_allocations); + x->lft.hard_byte_limit = _KEY2X(lifetime->sadb_lifetime_bytes); + x->lft.hard_add_expires_seconds = lifetime->sadb_lifetime_addtime; + x->lft.hard_use_expires_seconds = lifetime->sadb_lifetime_usetime; + } + lifetime = ext_hdrs[SADB_EXT_LIFETIME_SOFT - 1]; + if (lifetime != NULL) { + x->lft.soft_packet_limit = _KEY2X(lifetime->sadb_lifetime_allocations); + x->lft.soft_byte_limit = _KEY2X(lifetime->sadb_lifetime_bytes); + x->lft.soft_add_expires_seconds = lifetime->sadb_lifetime_addtime; + x->lft.soft_use_expires_seconds = lifetime->sadb_lifetime_usetime; + } + + sec_ctx = ext_hdrs[SADB_X_EXT_SEC_CTX - 1]; + if (sec_ctx != NULL) { + struct xfrm_user_sec_ctx *uctx = pfkey_sadb2xfrm_user_sec_ctx(sec_ctx, GFP_KERNEL); + + if (!uctx) + goto out; + + err = security_xfrm_state_alloc(x, uctx); + kfree(uctx); + + if (err) + goto out; + } + + err = -ENOBUFS; + key = ext_hdrs[SADB_EXT_KEY_AUTH - 1]; + if (sa->sadb_sa_auth) { + int keysize = 0; + struct xfrm_algo_desc *a = xfrm_aalg_get_byid(sa->sadb_sa_auth); + if (!a || !a->pfkey_supported) { + err = -ENOSYS; + goto out; + } + if (key) + keysize = (key->sadb_key_bits + 7) / 8; + x->aalg = kmalloc(sizeof(*x->aalg) + keysize, GFP_KERNEL); + if (!x->aalg) { + err = -ENOMEM; + goto out; + } + strcpy(x->aalg->alg_name, a->name); + x->aalg->alg_key_len = 0; + if (key) { + x->aalg->alg_key_len = key->sadb_key_bits; + memcpy(x->aalg->alg_key, key+1, keysize); + } + x->aalg->alg_trunc_len = a->uinfo.auth.icv_truncbits; + x->props.aalgo = sa->sadb_sa_auth; + /* x->algo.flags = sa->sadb_sa_flags; */ + } + if (sa->sadb_sa_encrypt) { + if (hdr->sadb_msg_satype == SADB_X_SATYPE_IPCOMP) { + struct xfrm_algo_desc *a = xfrm_calg_get_byid(sa->sadb_sa_encrypt); + if (!a || !a->pfkey_supported) { + err = -ENOSYS; + goto out; + } + x->calg = kmalloc(sizeof(*x->calg), GFP_KERNEL); + if (!x->calg) { + err = -ENOMEM; + goto out; + } + strcpy(x->calg->alg_name, a->name); + x->props.calgo = sa->sadb_sa_encrypt; + } else { + int keysize = 0; + struct xfrm_algo_desc *a = xfrm_ealg_get_byid(sa->sadb_sa_encrypt); + if (!a || !a->pfkey_supported) { + err = -ENOSYS; + goto out; + } + key = (struct sadb_key*) ext_hdrs[SADB_EXT_KEY_ENCRYPT-1]; + if (key) + keysize = (key->sadb_key_bits + 7) / 8; + x->ealg = kmalloc(sizeof(*x->ealg) + keysize, GFP_KERNEL); + if (!x->ealg) { + err = -ENOMEM; + goto out; + } + 
strcpy(x->ealg->alg_name, a->name); + x->ealg->alg_key_len = 0; + if (key) { + x->ealg->alg_key_len = key->sadb_key_bits; + memcpy(x->ealg->alg_key, key+1, keysize); + } + x->props.ealgo = sa->sadb_sa_encrypt; + x->geniv = a->uinfo.encr.geniv; + } + } + /* x->algo.flags = sa->sadb_sa_flags; */ + + x->props.family = pfkey_sadb_addr2xfrm_addr((struct sadb_address *) ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + &x->props.saddr); + pfkey_sadb_addr2xfrm_addr((struct sadb_address *) ext_hdrs[SADB_EXT_ADDRESS_DST-1], + &x->id.daddr); + + if (ext_hdrs[SADB_X_EXT_SA2-1]) { + const struct sadb_x_sa2 *sa2 = ext_hdrs[SADB_X_EXT_SA2-1]; + int mode = pfkey_mode_to_xfrm(sa2->sadb_x_sa2_mode); + if (mode < 0) { + err = -EINVAL; + goto out; + } + x->props.mode = mode; + x->props.reqid = sa2->sadb_x_sa2_reqid; + } + + if (ext_hdrs[SADB_EXT_ADDRESS_PROXY-1]) { + const struct sadb_address *addr = ext_hdrs[SADB_EXT_ADDRESS_PROXY-1]; + + /* Nobody uses this, but we try. */ + x->sel.family = pfkey_sadb_addr2xfrm_addr(addr, &x->sel.saddr); + x->sel.prefixlen_s = addr->sadb_address_prefixlen; + } + + if (!x->sel.family) + x->sel.family = x->props.family; + + if (ext_hdrs[SADB_X_EXT_NAT_T_TYPE-1]) { + const struct sadb_x_nat_t_type* n_type; + struct xfrm_encap_tmpl *natt; + + x->encap = kmalloc(sizeof(*x->encap), GFP_KERNEL); + if (!x->encap) { + err = -ENOMEM; + goto out; + } + + natt = x->encap; + n_type = ext_hdrs[SADB_X_EXT_NAT_T_TYPE-1]; + natt->encap_type = n_type->sadb_x_nat_t_type_type; + + if (ext_hdrs[SADB_X_EXT_NAT_T_SPORT-1]) { + const struct sadb_x_nat_t_port *n_port = + ext_hdrs[SADB_X_EXT_NAT_T_SPORT-1]; + natt->encap_sport = n_port->sadb_x_nat_t_port_port; + } + if (ext_hdrs[SADB_X_EXT_NAT_T_DPORT-1]) { + const struct sadb_x_nat_t_port *n_port = + ext_hdrs[SADB_X_EXT_NAT_T_DPORT-1]; + natt->encap_dport = n_port->sadb_x_nat_t_port_port; + } + memset(&natt->encap_oa, 0, sizeof(natt->encap_oa)); + } + + err = xfrm_init_state(x); + if (err) + goto out; + + x->km.seq = hdr->sadb_msg_seq; + return x; + +out: + x->km.state = XFRM_STATE_DEAD; + xfrm_state_put(x); + return ERR_PTR(err); +} + +static int pfkey_reserved(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + return -EOPNOTSUPP; +} + +static int pfkey_getspi(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + struct sk_buff *resp_skb; + struct sadb_x_sa2 *sa2; + struct sadb_address *saddr, *daddr; + struct sadb_msg *out_hdr; + struct sadb_spirange *range; + struct xfrm_state *x = NULL; + int mode; + int err; + u32 min_spi, max_spi; + u32 reqid; + u8 proto; + unsigned short family; + xfrm_address_t *xsaddr = NULL, *xdaddr = NULL; + + if (!present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1])) + return -EINVAL; + + proto = pfkey_satype2proto(hdr->sadb_msg_satype); + if (proto == 0) + return -EINVAL; + + if ((sa2 = ext_hdrs[SADB_X_EXT_SA2-1]) != NULL) { + mode = pfkey_mode_to_xfrm(sa2->sadb_x_sa2_mode); + if (mode < 0) + return -EINVAL; + reqid = sa2->sadb_x_sa2_reqid; + } else { + mode = 0; + reqid = 0; + } + + saddr = ext_hdrs[SADB_EXT_ADDRESS_SRC-1]; + daddr = ext_hdrs[SADB_EXT_ADDRESS_DST-1]; + + family = ((struct sockaddr *)(saddr + 1))->sa_family; + switch (family) { + case AF_INET: + xdaddr = (xfrm_address_t *)&((struct sockaddr_in *)(daddr + 1))->sin_addr.s_addr; + xsaddr = (xfrm_address_t *)&((struct sockaddr_in *)(saddr + 1))->sin_addr.s_addr; + break; +#if IS_ENABLED(CONFIG_IPV6) + case 
AF_INET6: + xdaddr = (xfrm_address_t *)&((struct sockaddr_in6 *)(daddr + 1))->sin6_addr; + xsaddr = (xfrm_address_t *)&((struct sockaddr_in6 *)(saddr + 1))->sin6_addr; + break; +#endif + } + + if (hdr->sadb_msg_seq) { + x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq); + if (x && !xfrm_addr_equal(&x->id.daddr, xdaddr, family)) { + xfrm_state_put(x); + x = NULL; + } + } + + if (!x) + x = xfrm_find_acq(net, &dummy_mark, mode, reqid, 0, proto, xdaddr, xsaddr, 1, family); + + if (x == NULL) + return -ENOENT; + + min_spi = 0x100; + max_spi = 0x0fffffff; + + range = ext_hdrs[SADB_EXT_SPIRANGE-1]; + if (range) { + min_spi = range->sadb_spirange_min; + max_spi = range->sadb_spirange_max; + } + + err = verify_spi_info(x->id.proto, min_spi, max_spi); + if (err) { + xfrm_state_put(x); + return err; + } + + err = xfrm_alloc_spi(x, min_spi, max_spi); + resp_skb = err ? ERR_PTR(err) : pfkey_xfrm_state2msg(x); + + if (IS_ERR(resp_skb)) { + xfrm_state_put(x); + return PTR_ERR(resp_skb); + } + + out_hdr = (struct sadb_msg *) resp_skb->data; + out_hdr->sadb_msg_version = hdr->sadb_msg_version; + out_hdr->sadb_msg_type = SADB_GETSPI; + out_hdr->sadb_msg_satype = pfkey_proto2satype(proto); + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_reserved = 0; + out_hdr->sadb_msg_seq = hdr->sadb_msg_seq; + out_hdr->sadb_msg_pid = hdr->sadb_msg_pid; + + xfrm_state_put(x); + + pfkey_broadcast(resp_skb, GFP_KERNEL, BROADCAST_ONE, sk, net); + + return 0; +} + +static int pfkey_acquire(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + struct xfrm_state *x; + + if (hdr->sadb_msg_len != sizeof(struct sadb_msg)/8) + return -EOPNOTSUPP; + + if (hdr->sadb_msg_seq == 0 || hdr->sadb_msg_errno == 0) + return 0; + + x = xfrm_find_acq_byseq(net, DUMMY_MARK, hdr->sadb_msg_seq); + if (x == NULL) + return 0; + + spin_lock_bh(&x->lock); + if (x->km.state == XFRM_STATE_ACQ) + x->km.state = XFRM_STATE_ERROR; + + spin_unlock_bh(&x->lock); + xfrm_state_put(x); + return 0; +} + +static inline int event2poltype(int event) +{ + switch (event) { + case XFRM_MSG_DELPOLICY: + return SADB_X_SPDDELETE; + case XFRM_MSG_NEWPOLICY: + return SADB_X_SPDADD; + case XFRM_MSG_UPDPOLICY: + return SADB_X_SPDUPDATE; + case XFRM_MSG_POLEXPIRE: + // return SADB_X_SPDEXPIRE; + default: + pr_err("pfkey: Unknown policy event %d\n", event); + break; + } + + return 0; +} + +static inline int event2keytype(int event) +{ + switch (event) { + case XFRM_MSG_DELSA: + return SADB_DELETE; + case XFRM_MSG_NEWSA: + return SADB_ADD; + case XFRM_MSG_UPDSA: + return SADB_UPDATE; + case XFRM_MSG_EXPIRE: + return SADB_EXPIRE; + default: + pr_err("pfkey: Unknown SA event %d\n", event); + break; + } + + return 0; +} + +/* ADD/UPD/DEL */ +static int key_notify_sa(struct xfrm_state *x, const struct km_event *c) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + + skb = pfkey_xfrm_state2msg(x); + + if (IS_ERR(skb)) + return PTR_ERR(skb); + + hdr = (struct sadb_msg *) skb->data; + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_type = event2keytype(c->event); + hdr->sadb_msg_satype = pfkey_proto2satype(x->id.proto); + hdr->sadb_msg_errno = 0; + hdr->sadb_msg_reserved = 0; + hdr->sadb_msg_seq = c->seq; + hdr->sadb_msg_pid = c->portid; + + pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, xs_net(x)); + + return 0; +} + +static int pfkey_add(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + struct xfrm_state *x; 
+ int err; + struct km_event c; + + x = pfkey_msg2xfrm_state(net, hdr, ext_hdrs); + if (IS_ERR(x)) + return PTR_ERR(x); + + xfrm_state_hold(x); + if (hdr->sadb_msg_type == SADB_ADD) + err = xfrm_state_add(x); + else + err = xfrm_state_update(x); + + xfrm_audit_state_add(x, err ? 0 : 1, true); + + if (err < 0) { + x->km.state = XFRM_STATE_DEAD; + __xfrm_state_put(x); + goto out; + } + + if (hdr->sadb_msg_type == SADB_ADD) + c.event = XFRM_MSG_NEWSA; + else + c.event = XFRM_MSG_UPDSA; + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + km_state_notify(x, &c); +out: + xfrm_state_put(x); + return err; +} + +static int pfkey_delete(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + struct xfrm_state *x; + struct km_event c; + int err; + + if (!ext_hdrs[SADB_EXT_SA-1] || + !present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1])) + return -EINVAL; + + x = pfkey_xfrm_state_lookup(net, hdr, ext_hdrs); + if (x == NULL) + return -ESRCH; + + if ((err = security_xfrm_state_delete(x))) + goto out; + + if (xfrm_state_kern(x)) { + err = -EPERM; + goto out; + } + + err = xfrm_state_delete(x); + + if (err < 0) + goto out; + + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + c.event = XFRM_MSG_DELSA; + km_state_notify(x, &c); +out: + xfrm_audit_state_delete(x, err ? 0 : 1, true); + xfrm_state_put(x); + + return err; +} + +static int pfkey_get(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + __u8 proto; + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + struct xfrm_state *x; + + if (!ext_hdrs[SADB_EXT_SA-1] || + !present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1])) + return -EINVAL; + + x = pfkey_xfrm_state_lookup(net, hdr, ext_hdrs); + if (x == NULL) + return -ESRCH; + + out_skb = pfkey_xfrm_state2msg(x); + proto = x->id.proto; + xfrm_state_put(x); + if (IS_ERR(out_skb)) + return PTR_ERR(out_skb); + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = hdr->sadb_msg_version; + out_hdr->sadb_msg_type = SADB_GET; + out_hdr->sadb_msg_satype = pfkey_proto2satype(proto); + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_reserved = 0; + out_hdr->sadb_msg_seq = hdr->sadb_msg_seq; + out_hdr->sadb_msg_pid = hdr->sadb_msg_pid; + pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk, sock_net(sk)); + + return 0; +} + +static struct sk_buff *compose_sadb_supported(const struct sadb_msg *orig, + gfp_t allocation) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + int len, auth_len, enc_len, i; + + auth_len = xfrm_count_pfkey_auth_supported(); + if (auth_len) { + auth_len *= sizeof(struct sadb_alg); + auth_len += sizeof(struct sadb_supported); + } + + enc_len = xfrm_count_pfkey_enc_supported(); + if (enc_len) { + enc_len *= sizeof(struct sadb_alg); + enc_len += sizeof(struct sadb_supported); + } + + len = enc_len + auth_len + sizeof(struct sadb_msg); + + skb = alloc_skb(len + 16, allocation); + if (!skb) + goto out_put_algs; + + hdr = skb_put(skb, sizeof(*hdr)); + pfkey_hdr_dup(hdr, orig); + hdr->sadb_msg_errno = 0; + hdr->sadb_msg_len = len / sizeof(uint64_t); + + if (auth_len) { + struct sadb_supported *sp; + struct sadb_alg *ap; + + sp = skb_put(skb, auth_len); + ap = (struct sadb_alg *) (sp + 1); + + sp->sadb_supported_len = auth_len / sizeof(uint64_t); + sp->sadb_supported_exttype = SADB_EXT_SUPPORTED_AUTH; + + for (i = 0; ; 
i++) { + struct xfrm_algo_desc *aalg = xfrm_aalg_get_byidx(i); + if (!aalg) + break; + if (!aalg->pfkey_supported) + continue; + if (aalg->available) + *ap++ = aalg->desc; + } + } + + if (enc_len) { + struct sadb_supported *sp; + struct sadb_alg *ap; + + sp = skb_put(skb, enc_len); + ap = (struct sadb_alg *) (sp + 1); + + sp->sadb_supported_len = enc_len / sizeof(uint64_t); + sp->sadb_supported_exttype = SADB_EXT_SUPPORTED_ENCRYPT; + + for (i = 0; ; i++) { + struct xfrm_algo_desc *ealg = xfrm_ealg_get_byidx(i); + if (!ealg) + break; + if (!ealg->pfkey_supported) + continue; + if (ealg->available) + *ap++ = ealg->desc; + } + } + +out_put_algs: + return skb; +} + +static int pfkey_register(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct pfkey_sock *pfk = pfkey_sk(sk); + struct sk_buff *supp_skb; + + if (hdr->sadb_msg_satype > SADB_SATYPE_MAX) + return -EINVAL; + + if (hdr->sadb_msg_satype != SADB_SATYPE_UNSPEC) { + if (pfk->registered&(1<sadb_msg_satype)) + return -EEXIST; + pfk->registered |= (1<sadb_msg_satype); + } + + mutex_lock(&pfkey_mutex); + xfrm_probe_algs(); + + supp_skb = compose_sadb_supported(hdr, GFP_KERNEL | __GFP_ZERO); + mutex_unlock(&pfkey_mutex); + + if (!supp_skb) { + if (hdr->sadb_msg_satype != SADB_SATYPE_UNSPEC) + pfk->registered &= ~(1<sadb_msg_satype); + + return -ENOBUFS; + } + + pfkey_broadcast(supp_skb, GFP_KERNEL, BROADCAST_REGISTERED, sk, + sock_net(sk)); + return 0; +} + +static int unicast_flush_resp(struct sock *sk, const struct sadb_msg *ihdr) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + + skb = alloc_skb(sizeof(struct sadb_msg) + 16, GFP_ATOMIC); + if (!skb) + return -ENOBUFS; + + hdr = skb_put_data(skb, ihdr, sizeof(struct sadb_msg)); + hdr->sadb_msg_errno = (uint8_t) 0; + hdr->sadb_msg_len = (sizeof(struct sadb_msg) / sizeof(uint64_t)); + + return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ONE, sk, + sock_net(sk)); +} + +static int key_notify_sa_flush(const struct km_event *c) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + + skb = alloc_skb(sizeof(struct sadb_msg) + 16, GFP_ATOMIC); + if (!skb) + return -ENOBUFS; + hdr = skb_put(skb, sizeof(struct sadb_msg)); + hdr->sadb_msg_satype = pfkey_proto2satype(c->data.proto); + hdr->sadb_msg_type = SADB_FLUSH; + hdr->sadb_msg_seq = c->seq; + hdr->sadb_msg_pid = c->portid; + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_errno = (uint8_t) 0; + hdr->sadb_msg_len = (sizeof(struct sadb_msg) / sizeof(uint64_t)); + hdr->sadb_msg_reserved = 0; + + pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, c->net); + + return 0; +} + +static int pfkey_flush(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + unsigned int proto; + struct km_event c; + int err, err2; + + proto = pfkey_satype2proto(hdr->sadb_msg_satype); + if (proto == 0) + return -EINVAL; + + err = xfrm_state_flush(net, proto, true, false); + err2 = unicast_flush_resp(sk, hdr); + if (err || err2) { + if (err == -ESRCH) /* empty table - go quietly */ + err = 0; + return err ? 
err : err2; + } + + c.data.proto = proto; + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + c.event = XFRM_MSG_FLUSHSA; + c.net = net; + km_state_notify(NULL, &c); + + return 0; +} + +static int dump_sa(struct xfrm_state *x, int count, void *ptr) +{ + struct pfkey_sock *pfk = ptr; + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + + if (!pfkey_can_dump(&pfk->sk)) + return -ENOBUFS; + + out_skb = pfkey_xfrm_state2msg(x); + if (IS_ERR(out_skb)) + return PTR_ERR(out_skb); + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = pfk->dump.msg_version; + out_hdr->sadb_msg_type = SADB_DUMP; + out_hdr->sadb_msg_satype = pfkey_proto2satype(x->id.proto); + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_reserved = 0; + out_hdr->sadb_msg_seq = count + 1; + out_hdr->sadb_msg_pid = pfk->dump.msg_portid; + + if (pfk->dump.skb) + pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE, + &pfk->sk, sock_net(&pfk->sk)); + pfk->dump.skb = out_skb; + + return 0; +} + +static int pfkey_dump_sa(struct pfkey_sock *pfk) +{ + struct net *net = sock_net(&pfk->sk); + return xfrm_state_walk(net, &pfk->dump.u.state, dump_sa, (void *) pfk); +} + +static void pfkey_dump_sa_done(struct pfkey_sock *pfk) +{ + struct net *net = sock_net(&pfk->sk); + + xfrm_state_walk_done(&pfk->dump.u.state, net); +} + +static int pfkey_dump(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + u8 proto; + struct xfrm_address_filter *filter = NULL; + struct pfkey_sock *pfk = pfkey_sk(sk); + + mutex_lock(&pfk->dump_lock); + if (pfk->dump.dump != NULL) { + mutex_unlock(&pfk->dump_lock); + return -EBUSY; + } + + proto = pfkey_satype2proto(hdr->sadb_msg_satype); + if (proto == 0) { + mutex_unlock(&pfk->dump_lock); + return -EINVAL; + } + + if (ext_hdrs[SADB_X_EXT_FILTER - 1]) { + struct sadb_x_filter *xfilter = ext_hdrs[SADB_X_EXT_FILTER - 1]; + + if ((xfilter->sadb_x_filter_splen > + (sizeof(xfrm_address_t) << 3)) || + (xfilter->sadb_x_filter_dplen > + (sizeof(xfrm_address_t) << 3))) { + mutex_unlock(&pfk->dump_lock); + return -EINVAL; + } + filter = kmalloc(sizeof(*filter), GFP_KERNEL); + if (filter == NULL) { + mutex_unlock(&pfk->dump_lock); + return -ENOMEM; + } + + memcpy(&filter->saddr, &xfilter->sadb_x_filter_saddr, + sizeof(xfrm_address_t)); + memcpy(&filter->daddr, &xfilter->sadb_x_filter_daddr, + sizeof(xfrm_address_t)); + filter->family = xfilter->sadb_x_filter_family; + filter->splen = xfilter->sadb_x_filter_splen; + filter->dplen = xfilter->sadb_x_filter_dplen; + } + + pfk->dump.msg_version = hdr->sadb_msg_version; + pfk->dump.msg_portid = hdr->sadb_msg_pid; + pfk->dump.dump = pfkey_dump_sa; + pfk->dump.done = pfkey_dump_sa_done; + xfrm_state_walk_init(&pfk->dump.u.state, proto, filter); + mutex_unlock(&pfk->dump_lock); + + return pfkey_do_dump(pfk); +} + +static int pfkey_promisc(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct pfkey_sock *pfk = pfkey_sk(sk); + int satype = hdr->sadb_msg_satype; + bool reset_errno = false; + + if (hdr->sadb_msg_len == (sizeof(*hdr) / sizeof(uint64_t))) { + reset_errno = true; + if (satype != 0 && satype != 1) + return -EINVAL; + pfk->promisc = satype; + } + if (reset_errno && skb_cloned(skb)) + skb = skb_copy(skb, GFP_KERNEL); + else + skb = skb_clone(skb, GFP_KERNEL); + + if (reset_errno && skb) { + struct sadb_msg *new_hdr = (struct sadb_msg *) skb->data; + new_hdr->sadb_msg_errno = 0; + } + + pfkey_broadcast(skb, GFP_KERNEL, BROADCAST_ALL, NULL, 
sock_net(sk)); + return 0; +} + +static int check_reqid(struct xfrm_policy *xp, int dir, int count, void *ptr) +{ + int i; + u32 reqid = *(u32*)ptr; + + for (i=0; ixfrm_nr; i++) { + if (xp->xfrm_vec[i].reqid == reqid) + return -EEXIST; + } + return 0; +} + +static u32 gen_reqid(struct net *net) +{ + struct xfrm_policy_walk walk; + u32 start; + int rc; + static u32 reqid = IPSEC_MANUAL_REQID_MAX; + + start = reqid; + do { + ++reqid; + if (reqid == 0) + reqid = IPSEC_MANUAL_REQID_MAX+1; + xfrm_policy_walk_init(&walk, XFRM_POLICY_TYPE_MAIN); + rc = xfrm_policy_walk(net, &walk, check_reqid, (void*)&reqid); + xfrm_policy_walk_done(&walk, net); + if (rc != -EEXIST) + return reqid; + } while (reqid != start); + return 0; +} + +static int +parse_ipsecrequest(struct xfrm_policy *xp, struct sadb_x_policy *pol, + struct sadb_x_ipsecrequest *rq) +{ + struct net *net = xp_net(xp); + struct xfrm_tmpl *t = xp->xfrm_vec + xp->xfrm_nr; + int mode; + + if (xp->xfrm_nr >= XFRM_MAX_DEPTH) + return -ELOOP; + + if (rq->sadb_x_ipsecrequest_mode == 0) + return -EINVAL; + if (!xfrm_id_proto_valid(rq->sadb_x_ipsecrequest_proto)) + return -EINVAL; + + t->id.proto = rq->sadb_x_ipsecrequest_proto; + if ((mode = pfkey_mode_to_xfrm(rq->sadb_x_ipsecrequest_mode)) < 0) + return -EINVAL; + t->mode = mode; + if (rq->sadb_x_ipsecrequest_level == IPSEC_LEVEL_USE) { + if ((mode == XFRM_MODE_TUNNEL || mode == XFRM_MODE_BEET) && + pol->sadb_x_policy_dir == IPSEC_DIR_OUTBOUND) + return -EINVAL; + t->optional = 1; + } else if (rq->sadb_x_ipsecrequest_level == IPSEC_LEVEL_UNIQUE) { + t->reqid = rq->sadb_x_ipsecrequest_reqid; + if (t->reqid > IPSEC_MANUAL_REQID_MAX) + t->reqid = 0; + if (!t->reqid && !(t->reqid = gen_reqid(net))) + return -ENOBUFS; + } + + /* addresses present only in tunnel mode */ + if (t->mode == XFRM_MODE_TUNNEL) { + int err; + + err = parse_sockaddr_pair( + (struct sockaddr *)(rq + 1), + rq->sadb_x_ipsecrequest_len - sizeof(*rq), + &t->saddr, &t->id.daddr, &t->encap_family); + if (err) + return err; + } else + t->encap_family = xp->family; + + /* No way to set this via kame pfkey */ + t->allalgs = 1; + xp->xfrm_nr++; + return 0; +} + +static int +parse_ipsecrequests(struct xfrm_policy *xp, struct sadb_x_policy *pol) +{ + int err; + int len = pol->sadb_x_policy_len*8 - sizeof(struct sadb_x_policy); + struct sadb_x_ipsecrequest *rq = (void*)(pol+1); + + if (pol->sadb_x_policy_len * 8 < sizeof(struct sadb_x_policy)) + return -EINVAL; + + while (len >= sizeof(*rq)) { + if (len < rq->sadb_x_ipsecrequest_len || + rq->sadb_x_ipsecrequest_len < sizeof(*rq)) + return -EINVAL; + + if ((err = parse_ipsecrequest(xp, pol, rq)) < 0) + return err; + len -= rq->sadb_x_ipsecrequest_len; + rq = (void*)((u8*)rq + rq->sadb_x_ipsecrequest_len); + } + return 0; +} + +static inline int pfkey_xfrm_policy2sec_ctx_size(const struct xfrm_policy *xp) +{ + struct xfrm_sec_ctx *xfrm_ctx = xp->security; + + if (xfrm_ctx) { + int len = sizeof(struct sadb_x_sec_ctx); + len += xfrm_ctx->ctx_len; + return PFKEY_ALIGN8(len); + } + return 0; +} + +static int pfkey_xfrm_policy2msg_size(const struct xfrm_policy *xp) +{ + const struct xfrm_tmpl *t; + int sockaddr_size = pfkey_sockaddr_size(xp->family); + int socklen = 0; + int i; + + for (i=0; ixfrm_nr; i++) { + t = xp->xfrm_vec + i; + socklen += pfkey_sockaddr_len(t->encap_family); + } + + return sizeof(struct sadb_msg) + + (sizeof(struct sadb_lifetime) * 3) + + (sizeof(struct sadb_address) * 2) + + (sockaddr_size * 2) + + sizeof(struct sadb_x_policy) + + (xp->xfrm_nr * sizeof(struct 
sadb_x_ipsecrequest)) + + (socklen * 2) + + pfkey_xfrm_policy2sec_ctx_size(xp); +} + +static struct sk_buff * pfkey_xfrm_policy2msg_prep(const struct xfrm_policy *xp) +{ + struct sk_buff *skb; + int size; + + size = pfkey_xfrm_policy2msg_size(xp); + + skb = alloc_skb(size + 16, GFP_ATOMIC); + if (skb == NULL) + return ERR_PTR(-ENOBUFS); + + return skb; +} + +static int pfkey_xfrm_policy2msg(struct sk_buff *skb, const struct xfrm_policy *xp, int dir) +{ + struct sadb_msg *hdr; + struct sadb_address *addr; + struct sadb_lifetime *lifetime; + struct sadb_x_policy *pol; + struct sadb_x_sec_ctx *sec_ctx; + struct xfrm_sec_ctx *xfrm_ctx; + int i; + int size; + int sockaddr_size = pfkey_sockaddr_size(xp->family); + int socklen = pfkey_sockaddr_len(xp->family); + + size = pfkey_xfrm_policy2msg_size(xp); + + /* call should fill header later */ + hdr = skb_put(skb, sizeof(struct sadb_msg)); + memset(hdr, 0, size); /* XXX do we need this ? */ + + /* src address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_SRC; + addr->sadb_address_proto = pfkey_proto_from_xfrm(xp->selector.proto); + addr->sadb_address_prefixlen = xp->selector.prefixlen_s; + addr->sadb_address_reserved = 0; + if (!pfkey_sockaddr_fill(&xp->selector.saddr, + xp->selector.sport, + (struct sockaddr *) (addr + 1), + xp->family)) + BUG(); + + /* dst address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_DST; + addr->sadb_address_proto = pfkey_proto_from_xfrm(xp->selector.proto); + addr->sadb_address_prefixlen = xp->selector.prefixlen_d; + addr->sadb_address_reserved = 0; + + pfkey_sockaddr_fill(&xp->selector.daddr, xp->selector.dport, + (struct sockaddr *) (addr + 1), + xp->family); + + /* hard time */ + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_HARD; + lifetime->sadb_lifetime_allocations = _X2KEY(xp->lft.hard_packet_limit); + lifetime->sadb_lifetime_bytes = _X2KEY(xp->lft.hard_byte_limit); + lifetime->sadb_lifetime_addtime = xp->lft.hard_add_expires_seconds; + lifetime->sadb_lifetime_usetime = xp->lft.hard_use_expires_seconds; + /* soft time */ + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_SOFT; + lifetime->sadb_lifetime_allocations = _X2KEY(xp->lft.soft_packet_limit); + lifetime->sadb_lifetime_bytes = _X2KEY(xp->lft.soft_byte_limit); + lifetime->sadb_lifetime_addtime = xp->lft.soft_add_expires_seconds; + lifetime->sadb_lifetime_usetime = xp->lft.soft_use_expires_seconds; + /* current time */ + lifetime = skb_put(skb, sizeof(struct sadb_lifetime)); + lifetime->sadb_lifetime_len = + sizeof(struct sadb_lifetime)/sizeof(uint64_t); + lifetime->sadb_lifetime_exttype = SADB_EXT_LIFETIME_CURRENT; + lifetime->sadb_lifetime_allocations = xp->curlft.packets; + lifetime->sadb_lifetime_bytes = xp->curlft.bytes; + lifetime->sadb_lifetime_addtime = xp->curlft.add_time; + lifetime->sadb_lifetime_usetime = xp->curlft.use_time; + + pol = skb_put(skb, sizeof(struct sadb_x_policy)); + pol->sadb_x_policy_len = sizeof(struct 
sadb_x_policy)/sizeof(uint64_t); + pol->sadb_x_policy_exttype = SADB_X_EXT_POLICY; + pol->sadb_x_policy_type = IPSEC_POLICY_DISCARD; + if (xp->action == XFRM_POLICY_ALLOW) { + if (xp->xfrm_nr) + pol->sadb_x_policy_type = IPSEC_POLICY_IPSEC; + else + pol->sadb_x_policy_type = IPSEC_POLICY_NONE; + } + pol->sadb_x_policy_dir = dir+1; + pol->sadb_x_policy_reserved = 0; + pol->sadb_x_policy_id = xp->index; + pol->sadb_x_policy_priority = xp->priority; + + for (i=0; ixfrm_nr; i++) { + const struct xfrm_tmpl *t = xp->xfrm_vec + i; + struct sadb_x_ipsecrequest *rq; + int req_size; + int mode; + + req_size = sizeof(struct sadb_x_ipsecrequest); + if (t->mode == XFRM_MODE_TUNNEL) { + socklen = pfkey_sockaddr_len(t->encap_family); + req_size += socklen * 2; + } else { + size -= 2*socklen; + } + rq = skb_put(skb, req_size); + pol->sadb_x_policy_len += req_size/8; + memset(rq, 0, sizeof(*rq)); + rq->sadb_x_ipsecrequest_len = req_size; + rq->sadb_x_ipsecrequest_proto = t->id.proto; + if ((mode = pfkey_mode_from_xfrm(t->mode)) < 0) + return -EINVAL; + rq->sadb_x_ipsecrequest_mode = mode; + rq->sadb_x_ipsecrequest_level = IPSEC_LEVEL_REQUIRE; + if (t->reqid) + rq->sadb_x_ipsecrequest_level = IPSEC_LEVEL_UNIQUE; + if (t->optional) + rq->sadb_x_ipsecrequest_level = IPSEC_LEVEL_USE; + rq->sadb_x_ipsecrequest_reqid = t->reqid; + + if (t->mode == XFRM_MODE_TUNNEL) { + u8 *sa = (void *)(rq + 1); + pfkey_sockaddr_fill(&t->saddr, 0, + (struct sockaddr *)sa, + t->encap_family); + pfkey_sockaddr_fill(&t->id.daddr, 0, + (struct sockaddr *) (sa + socklen), + t->encap_family); + } + } + + /* security context */ + if ((xfrm_ctx = xp->security)) { + int ctx_size = pfkey_xfrm_policy2sec_ctx_size(xp); + + sec_ctx = skb_put(skb, ctx_size); + sec_ctx->sadb_x_sec_len = ctx_size / sizeof(uint64_t); + sec_ctx->sadb_x_sec_exttype = SADB_X_EXT_SEC_CTX; + sec_ctx->sadb_x_ctx_doi = xfrm_ctx->ctx_doi; + sec_ctx->sadb_x_ctx_alg = xfrm_ctx->ctx_alg; + sec_ctx->sadb_x_ctx_len = xfrm_ctx->ctx_len; + memcpy(sec_ctx + 1, xfrm_ctx->ctx_str, + xfrm_ctx->ctx_len); + } + + hdr->sadb_msg_len = size / sizeof(uint64_t); + hdr->sadb_msg_reserved = refcount_read(&xp->refcnt); + + return 0; +} + +static int key_notify_policy(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + int err; + + out_skb = pfkey_xfrm_policy2msg_prep(xp); + if (IS_ERR(out_skb)) + return PTR_ERR(out_skb); + + err = pfkey_xfrm_policy2msg(out_skb, xp, dir); + if (err < 0) { + kfree_skb(out_skb); + return err; + } + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = PF_KEY_V2; + + if (c->data.byid && c->event == XFRM_MSG_DELPOLICY) + out_hdr->sadb_msg_type = SADB_X_SPDDELETE2; + else + out_hdr->sadb_msg_type = event2poltype(c->event); + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_seq = c->seq; + out_hdr->sadb_msg_pid = c->portid; + pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ALL, NULL, xp_net(xp)); + return 0; + +} + +static int pfkey_spdadd(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + int err = 0; + struct sadb_lifetime *lifetime; + struct sadb_address *sa; + struct sadb_x_policy *pol; + struct xfrm_policy *xp; + struct km_event c; + struct sadb_x_sec_ctx *sec_ctx; + + if (!present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1]) || + !ext_hdrs[SADB_X_EXT_POLICY-1]) + return -EINVAL; + + pol = ext_hdrs[SADB_X_EXT_POLICY-1]; + if (pol->sadb_x_policy_type > 
IPSEC_POLICY_IPSEC) + return -EINVAL; + if (!pol->sadb_x_policy_dir || pol->sadb_x_policy_dir >= IPSEC_DIR_MAX) + return -EINVAL; + + xp = xfrm_policy_alloc(net, GFP_KERNEL); + if (xp == NULL) + return -ENOBUFS; + + xp->action = (pol->sadb_x_policy_type == IPSEC_POLICY_DISCARD ? + XFRM_POLICY_BLOCK : XFRM_POLICY_ALLOW); + xp->priority = pol->sadb_x_policy_priority; + + sa = ext_hdrs[SADB_EXT_ADDRESS_SRC-1]; + xp->family = pfkey_sadb_addr2xfrm_addr(sa, &xp->selector.saddr); + xp->selector.family = xp->family; + xp->selector.prefixlen_s = sa->sadb_address_prefixlen; + xp->selector.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + xp->selector.sport = ((struct sockaddr_in *)(sa+1))->sin_port; + if (xp->selector.sport) + xp->selector.sport_mask = htons(0xffff); + + sa = ext_hdrs[SADB_EXT_ADDRESS_DST-1]; + pfkey_sadb_addr2xfrm_addr(sa, &xp->selector.daddr); + xp->selector.prefixlen_d = sa->sadb_address_prefixlen; + + /* Amusing, we set this twice. KAME apps appear to set same value + * in both addresses. + */ + xp->selector.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + + xp->selector.dport = ((struct sockaddr_in *)(sa+1))->sin_port; + if (xp->selector.dport) + xp->selector.dport_mask = htons(0xffff); + + sec_ctx = ext_hdrs[SADB_X_EXT_SEC_CTX - 1]; + if (sec_ctx != NULL) { + struct xfrm_user_sec_ctx *uctx = pfkey_sadb2xfrm_user_sec_ctx(sec_ctx, GFP_KERNEL); + + if (!uctx) { + err = -ENOBUFS; + goto out; + } + + err = security_xfrm_policy_alloc(&xp->security, uctx, GFP_KERNEL); + kfree(uctx); + + if (err) + goto out; + } + + xp->lft.soft_byte_limit = XFRM_INF; + xp->lft.hard_byte_limit = XFRM_INF; + xp->lft.soft_packet_limit = XFRM_INF; + xp->lft.hard_packet_limit = XFRM_INF; + if ((lifetime = ext_hdrs[SADB_EXT_LIFETIME_HARD-1]) != NULL) { + xp->lft.hard_packet_limit = _KEY2X(lifetime->sadb_lifetime_allocations); + xp->lft.hard_byte_limit = _KEY2X(lifetime->sadb_lifetime_bytes); + xp->lft.hard_add_expires_seconds = lifetime->sadb_lifetime_addtime; + xp->lft.hard_use_expires_seconds = lifetime->sadb_lifetime_usetime; + } + if ((lifetime = ext_hdrs[SADB_EXT_LIFETIME_SOFT-1]) != NULL) { + xp->lft.soft_packet_limit = _KEY2X(lifetime->sadb_lifetime_allocations); + xp->lft.soft_byte_limit = _KEY2X(lifetime->sadb_lifetime_bytes); + xp->lft.soft_add_expires_seconds = lifetime->sadb_lifetime_addtime; + xp->lft.soft_use_expires_seconds = lifetime->sadb_lifetime_usetime; + } + xp->xfrm_nr = 0; + if (pol->sadb_x_policy_type == IPSEC_POLICY_IPSEC && + (err = parse_ipsecrequests(xp, pol)) < 0) + goto out; + + err = xfrm_policy_insert(pol->sadb_x_policy_dir-1, xp, + hdr->sadb_msg_type != SADB_X_SPDUPDATE); + + xfrm_audit_policy_add(xp, err ? 
0 : 1, true); + + if (err) + goto out; + + if (hdr->sadb_msg_type == SADB_X_SPDUPDATE) + c.event = XFRM_MSG_UPDPOLICY; + else + c.event = XFRM_MSG_NEWPOLICY; + + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + + km_policy_notify(xp, pol->sadb_x_policy_dir-1, &c); + xfrm_pol_put(xp); + return 0; + +out: + xp->walk.dead = 1; + xfrm_policy_destroy(xp); + return err; +} + +static int pfkey_spddelete(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + int err; + struct sadb_address *sa; + struct sadb_x_policy *pol; + struct xfrm_policy *xp; + struct xfrm_selector sel; + struct km_event c; + struct sadb_x_sec_ctx *sec_ctx; + struct xfrm_sec_ctx *pol_ctx = NULL; + + if (!present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC-1], + ext_hdrs[SADB_EXT_ADDRESS_DST-1]) || + !ext_hdrs[SADB_X_EXT_POLICY-1]) + return -EINVAL; + + pol = ext_hdrs[SADB_X_EXT_POLICY-1]; + if (!pol->sadb_x_policy_dir || pol->sadb_x_policy_dir >= IPSEC_DIR_MAX) + return -EINVAL; + + memset(&sel, 0, sizeof(sel)); + + sa = ext_hdrs[SADB_EXT_ADDRESS_SRC-1]; + sel.family = pfkey_sadb_addr2xfrm_addr(sa, &sel.saddr); + sel.prefixlen_s = sa->sadb_address_prefixlen; + sel.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + sel.sport = ((struct sockaddr_in *)(sa+1))->sin_port; + if (sel.sport) + sel.sport_mask = htons(0xffff); + + sa = ext_hdrs[SADB_EXT_ADDRESS_DST-1]; + pfkey_sadb_addr2xfrm_addr(sa, &sel.daddr); + sel.prefixlen_d = sa->sadb_address_prefixlen; + sel.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + sel.dport = ((struct sockaddr_in *)(sa+1))->sin_port; + if (sel.dport) + sel.dport_mask = htons(0xffff); + + sec_ctx = ext_hdrs[SADB_X_EXT_SEC_CTX - 1]; + if (sec_ctx != NULL) { + struct xfrm_user_sec_ctx *uctx = pfkey_sadb2xfrm_user_sec_ctx(sec_ctx, GFP_KERNEL); + + if (!uctx) + return -ENOMEM; + + err = security_xfrm_policy_alloc(&pol_ctx, uctx, GFP_KERNEL); + kfree(uctx); + if (err) + return err; + } + + xp = xfrm_policy_bysel_ctx(net, &dummy_mark, 0, XFRM_POLICY_TYPE_MAIN, + pol->sadb_x_policy_dir - 1, &sel, pol_ctx, + 1, &err); + security_xfrm_policy_free(pol_ctx); + if (xp == NULL) + return -ENOENT; + + xfrm_audit_policy_delete(xp, err ? 
0 : 1, true); + + if (err) + goto out; + + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + c.data.byid = 0; + c.event = XFRM_MSG_DELPOLICY; + km_policy_notify(xp, pol->sadb_x_policy_dir-1, &c); + +out: + xfrm_pol_put(xp); + return err; +} + +static int key_pol_get_resp(struct sock *sk, struct xfrm_policy *xp, const struct sadb_msg *hdr, int dir) +{ + int err; + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + err = 0; + + out_skb = pfkey_xfrm_policy2msg_prep(xp); + if (IS_ERR(out_skb)) { + err = PTR_ERR(out_skb); + goto out; + } + err = pfkey_xfrm_policy2msg(out_skb, xp, dir); + if (err < 0) { + kfree_skb(out_skb); + goto out; + } + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = hdr->sadb_msg_version; + out_hdr->sadb_msg_type = hdr->sadb_msg_type; + out_hdr->sadb_msg_satype = 0; + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_seq = hdr->sadb_msg_seq; + out_hdr->sadb_msg_pid = hdr->sadb_msg_pid; + pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_ONE, sk, xp_net(xp)); + err = 0; + +out: + return err; +} + +static int pfkey_sockaddr_pair_size(sa_family_t family) +{ + return PFKEY_ALIGN8(pfkey_sockaddr_len(family) * 2); +} + +static int parse_sockaddr_pair(struct sockaddr *sa, int ext_len, + xfrm_address_t *saddr, xfrm_address_t *daddr, + u16 *family) +{ + int af, socklen; + + if (ext_len < 2 || ext_len < pfkey_sockaddr_pair_size(sa->sa_family)) + return -EINVAL; + + af = pfkey_sockaddr_extract(sa, saddr); + if (!af) + return -EINVAL; + + socklen = pfkey_sockaddr_len(af); + if (pfkey_sockaddr_extract((struct sockaddr *) (((u8 *)sa) + socklen), + daddr) != af) + return -EINVAL; + + *family = af; + return 0; +} + +#ifdef CONFIG_NET_KEY_MIGRATE +static int ipsecrequests_to_migrate(struct sadb_x_ipsecrequest *rq1, int len, + struct xfrm_migrate *m) +{ + int err; + struct sadb_x_ipsecrequest *rq2; + int mode; + + if (len < sizeof(*rq1) || + len < rq1->sadb_x_ipsecrequest_len || + rq1->sadb_x_ipsecrequest_len < sizeof(*rq1)) + return -EINVAL; + + /* old endoints */ + err = parse_sockaddr_pair((struct sockaddr *)(rq1 + 1), + rq1->sadb_x_ipsecrequest_len - sizeof(*rq1), + &m->old_saddr, &m->old_daddr, + &m->old_family); + if (err) + return err; + + rq2 = (struct sadb_x_ipsecrequest *)((u8 *)rq1 + rq1->sadb_x_ipsecrequest_len); + len -= rq1->sadb_x_ipsecrequest_len; + + if (len <= sizeof(*rq2) || + len < rq2->sadb_x_ipsecrequest_len || + rq2->sadb_x_ipsecrequest_len < sizeof(*rq2)) + return -EINVAL; + + /* new endpoints */ + err = parse_sockaddr_pair((struct sockaddr *)(rq2 + 1), + rq2->sadb_x_ipsecrequest_len - sizeof(*rq2), + &m->new_saddr, &m->new_daddr, + &m->new_family); + if (err) + return err; + + if (rq1->sadb_x_ipsecrequest_proto != rq2->sadb_x_ipsecrequest_proto || + rq1->sadb_x_ipsecrequest_mode != rq2->sadb_x_ipsecrequest_mode || + rq1->sadb_x_ipsecrequest_reqid != rq2->sadb_x_ipsecrequest_reqid) + return -EINVAL; + + m->proto = rq1->sadb_x_ipsecrequest_proto; + if ((mode = pfkey_mode_to_xfrm(rq1->sadb_x_ipsecrequest_mode)) < 0) + return -EINVAL; + m->mode = mode; + m->reqid = rq1->sadb_x_ipsecrequest_reqid; + + return ((int)(rq1->sadb_x_ipsecrequest_len + + rq2->sadb_x_ipsecrequest_len)); +} + +static int pfkey_migrate(struct sock *sk, struct sk_buff *skb, + const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + int i, len, ret, err = -EINVAL; + u8 dir; + struct sadb_address *sa; + struct sadb_x_kmaddress *kma; + struct sadb_x_policy *pol; + struct sadb_x_ipsecrequest *rq; + struct xfrm_selector sel; + struct xfrm_migrate 
m[XFRM_MAX_DEPTH]; + struct xfrm_kmaddress k; + struct net *net = sock_net(sk); + + if (!present_and_same_family(ext_hdrs[SADB_EXT_ADDRESS_SRC - 1], + ext_hdrs[SADB_EXT_ADDRESS_DST - 1]) || + !ext_hdrs[SADB_X_EXT_POLICY - 1]) { + err = -EINVAL; + goto out; + } + + kma = ext_hdrs[SADB_X_EXT_KMADDRESS - 1]; + pol = ext_hdrs[SADB_X_EXT_POLICY - 1]; + + if (pol->sadb_x_policy_dir >= IPSEC_DIR_MAX) { + err = -EINVAL; + goto out; + } + + if (kma) { + /* convert sadb_x_kmaddress to xfrm_kmaddress */ + k.reserved = kma->sadb_x_kmaddress_reserved; + ret = parse_sockaddr_pair((struct sockaddr *)(kma + 1), + 8*(kma->sadb_x_kmaddress_len) - sizeof(*kma), + &k.local, &k.remote, &k.family); + if (ret < 0) { + err = ret; + goto out; + } + } + + dir = pol->sadb_x_policy_dir - 1; + memset(&sel, 0, sizeof(sel)); + + /* set source address info of selector */ + sa = ext_hdrs[SADB_EXT_ADDRESS_SRC - 1]; + sel.family = pfkey_sadb_addr2xfrm_addr(sa, &sel.saddr); + sel.prefixlen_s = sa->sadb_address_prefixlen; + sel.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + sel.sport = ((struct sockaddr_in *)(sa + 1))->sin_port; + if (sel.sport) + sel.sport_mask = htons(0xffff); + + /* set destination address info of selector */ + sa = ext_hdrs[SADB_EXT_ADDRESS_DST - 1]; + pfkey_sadb_addr2xfrm_addr(sa, &sel.daddr); + sel.prefixlen_d = sa->sadb_address_prefixlen; + sel.proto = pfkey_proto_to_xfrm(sa->sadb_address_proto); + sel.dport = ((struct sockaddr_in *)(sa + 1))->sin_port; + if (sel.dport) + sel.dport_mask = htons(0xffff); + + rq = (struct sadb_x_ipsecrequest *)(pol + 1); + + /* extract ipsecrequests */ + i = 0; + len = pol->sadb_x_policy_len * 8 - sizeof(struct sadb_x_policy); + + while (len > 0 && i < XFRM_MAX_DEPTH) { + ret = ipsecrequests_to_migrate(rq, len, &m[i]); + if (ret < 0) { + err = ret; + goto out; + } else { + rq = (struct sadb_x_ipsecrequest *)((u8 *)rq + ret); + len -= ret; + i++; + } + } + + if (!i || len > 0) { + err = -EINVAL; + goto out; + } + + return xfrm_migrate(&sel, dir, XFRM_POLICY_TYPE_MAIN, m, i, + kma ? &k : NULL, net, NULL, 0); + + out: + return err; +} +#else +static int pfkey_migrate(struct sock *sk, struct sk_buff *skb, + const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + return -ENOPROTOOPT; +} +#endif + + +static int pfkey_spdget(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + unsigned int dir; + int err = 0, delete; + struct sadb_x_policy *pol; + struct xfrm_policy *xp; + struct km_event c; + + if ((pol = ext_hdrs[SADB_X_EXT_POLICY-1]) == NULL) + return -EINVAL; + + dir = xfrm_policy_id2dir(pol->sadb_x_policy_id); + if (dir >= XFRM_POLICY_MAX) + return -EINVAL; + + delete = (hdr->sadb_msg_type == SADB_X_SPDDELETE2); + xp = xfrm_policy_byid(net, &dummy_mark, 0, XFRM_POLICY_TYPE_MAIN, + dir, pol->sadb_x_policy_id, delete, &err); + if (xp == NULL) + return -ENOENT; + + if (delete) { + xfrm_audit_policy_delete(xp, err ? 
0 : 1, true); + + if (err) + goto out; + c.seq = hdr->sadb_msg_seq; + c.portid = hdr->sadb_msg_pid; + c.data.byid = 1; + c.event = XFRM_MSG_DELPOLICY; + km_policy_notify(xp, dir, &c); + } else { + err = key_pol_get_resp(sk, xp, hdr, dir); + } + +out: + xfrm_pol_put(xp); + return err; +} + +static int dump_sp(struct xfrm_policy *xp, int dir, int count, void *ptr) +{ + struct pfkey_sock *pfk = ptr; + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + int err; + + if (!pfkey_can_dump(&pfk->sk)) + return -ENOBUFS; + + out_skb = pfkey_xfrm_policy2msg_prep(xp); + if (IS_ERR(out_skb)) + return PTR_ERR(out_skb); + + err = pfkey_xfrm_policy2msg(out_skb, xp, dir); + if (err < 0) { + kfree_skb(out_skb); + return err; + } + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = pfk->dump.msg_version; + out_hdr->sadb_msg_type = SADB_X_SPDDUMP; + out_hdr->sadb_msg_satype = SADB_SATYPE_UNSPEC; + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_seq = count + 1; + out_hdr->sadb_msg_pid = pfk->dump.msg_portid; + + if (pfk->dump.skb) + pfkey_broadcast(pfk->dump.skb, GFP_ATOMIC, BROADCAST_ONE, + &pfk->sk, sock_net(&pfk->sk)); + pfk->dump.skb = out_skb; + + return 0; +} + +static int pfkey_dump_sp(struct pfkey_sock *pfk) +{ + struct net *net = sock_net(&pfk->sk); + return xfrm_policy_walk(net, &pfk->dump.u.policy, dump_sp, (void *) pfk); +} + +static void pfkey_dump_sp_done(struct pfkey_sock *pfk) +{ + struct net *net = sock_net((struct sock *)pfk); + + xfrm_policy_walk_done(&pfk->dump.u.policy, net); +} + +static int pfkey_spddump(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct pfkey_sock *pfk = pfkey_sk(sk); + + mutex_lock(&pfk->dump_lock); + if (pfk->dump.dump != NULL) { + mutex_unlock(&pfk->dump_lock); + return -EBUSY; + } + + pfk->dump.msg_version = hdr->sadb_msg_version; + pfk->dump.msg_portid = hdr->sadb_msg_pid; + pfk->dump.dump = pfkey_dump_sp; + pfk->dump.done = pfkey_dump_sp_done; + xfrm_policy_walk_init(&pfk->dump.u.policy, XFRM_POLICY_TYPE_MAIN); + mutex_unlock(&pfk->dump_lock); + + return pfkey_do_dump(pfk); +} + +static int key_notify_policy_flush(const struct km_event *c) +{ + struct sk_buff *skb_out; + struct sadb_msg *hdr; + + skb_out = alloc_skb(sizeof(struct sadb_msg) + 16, GFP_ATOMIC); + if (!skb_out) + return -ENOBUFS; + hdr = skb_put(skb_out, sizeof(struct sadb_msg)); + hdr->sadb_msg_type = SADB_X_SPDFLUSH; + hdr->sadb_msg_seq = c->seq; + hdr->sadb_msg_pid = c->portid; + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_errno = (uint8_t) 0; + hdr->sadb_msg_satype = SADB_SATYPE_UNSPEC; + hdr->sadb_msg_len = (sizeof(struct sadb_msg) / sizeof(uint64_t)); + hdr->sadb_msg_reserved = 0; + pfkey_broadcast(skb_out, GFP_ATOMIC, BROADCAST_ALL, NULL, c->net); + return 0; + +} + +static int pfkey_spdflush(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr, void * const *ext_hdrs) +{ + struct net *net = sock_net(sk); + struct km_event c; + int err, err2; + + err = xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, true); + err2 = unicast_flush_resp(sk, hdr); + if (err || err2) { + if (err == -ESRCH) /* empty table - old silent behavior */ + return 0; + return err; + } + + c.data.type = XFRM_POLICY_TYPE_MAIN; + c.event = XFRM_MSG_FLUSHPOLICY; + c.portid = hdr->sadb_msg_pid; + c.seq = hdr->sadb_msg_seq; + c.net = net; + km_policy_notify(NULL, 0, &c); + + return 0; +} + +typedef int (*pfkey_handler)(struct sock *sk, struct sk_buff *skb, + const struct sadb_msg *hdr, void * const *ext_hdrs); +static const 
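+/* Handlers for inbound PF_KEY messages, indexed by sadb_msg_type.
+ * Message types left NULL here are rejected by pfkey_process() with
+ * -EOPNOTSUPP.
+ */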
pfkey_handler pfkey_funcs[SADB_MAX + 1] = { + [SADB_RESERVED] = pfkey_reserved, + [SADB_GETSPI] = pfkey_getspi, + [SADB_UPDATE] = pfkey_add, + [SADB_ADD] = pfkey_add, + [SADB_DELETE] = pfkey_delete, + [SADB_GET] = pfkey_get, + [SADB_ACQUIRE] = pfkey_acquire, + [SADB_REGISTER] = pfkey_register, + [SADB_EXPIRE] = NULL, + [SADB_FLUSH] = pfkey_flush, + [SADB_DUMP] = pfkey_dump, + [SADB_X_PROMISC] = pfkey_promisc, + [SADB_X_PCHANGE] = NULL, + [SADB_X_SPDUPDATE] = pfkey_spdadd, + [SADB_X_SPDADD] = pfkey_spdadd, + [SADB_X_SPDDELETE] = pfkey_spddelete, + [SADB_X_SPDGET] = pfkey_spdget, + [SADB_X_SPDACQUIRE] = NULL, + [SADB_X_SPDDUMP] = pfkey_spddump, + [SADB_X_SPDFLUSH] = pfkey_spdflush, + [SADB_X_SPDSETIDX] = pfkey_spdadd, + [SADB_X_SPDDELETE2] = pfkey_spdget, + [SADB_X_MIGRATE] = pfkey_migrate, +}; + +static int pfkey_process(struct sock *sk, struct sk_buff *skb, const struct sadb_msg *hdr) +{ + void *ext_hdrs[SADB_EXT_MAX]; + int err; + + /* Non-zero return value of pfkey_broadcast() does not always signal + * an error and even on an actual error we may still want to process + * the message so rather ignore the return value. + */ + pfkey_broadcast(skb_clone(skb, GFP_KERNEL), GFP_KERNEL, + BROADCAST_PROMISC_ONLY, NULL, sock_net(sk)); + + memset(ext_hdrs, 0, sizeof(ext_hdrs)); + err = parse_exthdrs(skb, hdr, ext_hdrs); + if (!err) { + err = -EOPNOTSUPP; + if (pfkey_funcs[hdr->sadb_msg_type]) + err = pfkey_funcs[hdr->sadb_msg_type](sk, skb, hdr, ext_hdrs); + } + return err; +} + +static struct sadb_msg *pfkey_get_base_msg(struct sk_buff *skb, int *errp) +{ + struct sadb_msg *hdr = NULL; + + if (skb->len < sizeof(*hdr)) { + *errp = -EMSGSIZE; + } else { + hdr = (struct sadb_msg *) skb->data; + if (hdr->sadb_msg_version != PF_KEY_V2 || + hdr->sadb_msg_reserved != 0 || + (hdr->sadb_msg_type <= SADB_RESERVED || + hdr->sadb_msg_type > SADB_MAX)) { + hdr = NULL; + *errp = -EINVAL; + } else if (hdr->sadb_msg_len != (skb->len / + sizeof(uint64_t)) || + hdr->sadb_msg_len < (sizeof(struct sadb_msg) / + sizeof(uint64_t))) { + hdr = NULL; + *errp = -EMSGSIZE; + } else { + *errp = 0; + } + } + return hdr; +} + +static inline int aalg_tmpl_set(const struct xfrm_tmpl *t, + const struct xfrm_algo_desc *d) +{ + unsigned int id = d->desc.sadb_alg_id; + + if (id >= sizeof(t->aalgos) * 8) + return 0; + + return (t->aalgos >> id) & 1; +} + +static inline int ealg_tmpl_set(const struct xfrm_tmpl *t, + const struct xfrm_algo_desc *d) +{ + unsigned int id = d->desc.sadb_alg_id; + + if (id >= sizeof(t->ealgos) * 8) + return 0; + + return (t->ealgos >> id) & 1; +} + +static int count_ah_combs(const struct xfrm_tmpl *t) +{ + int i, sz = 0; + + for (i = 0; ; i++) { + const struct xfrm_algo_desc *aalg = xfrm_aalg_get_byidx(i); + if (!aalg) + break; + if (!aalg->pfkey_supported) + continue; + if (aalg_tmpl_set(t, aalg)) + sz += sizeof(struct sadb_comb); + } + return sz + sizeof(struct sadb_prop); +} + +static int count_esp_combs(const struct xfrm_tmpl *t) +{ + int i, k, sz = 0; + + for (i = 0; ; i++) { + const struct xfrm_algo_desc *ealg = xfrm_ealg_get_byidx(i); + if (!ealg) + break; + + if (!ealg->pfkey_supported) + continue; + + if (!(ealg_tmpl_set(t, ealg))) + continue; + + for (k = 1; ; k++) { + const struct xfrm_algo_desc *aalg = xfrm_aalg_get_byidx(k); + if (!aalg) + break; + + if (!aalg->pfkey_supported) + continue; + + if (aalg_tmpl_set(t, aalg)) + sz += sizeof(struct sadb_comb); + } + } + return sz + sizeof(struct sadb_prop); +} + +static int dump_ah_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) +{ + struct 
sadb_prop *p; + int sz = 0; + int i; + + p = skb_put(skb, sizeof(struct sadb_prop)); + p->sadb_prop_len = sizeof(struct sadb_prop)/8; + p->sadb_prop_exttype = SADB_EXT_PROPOSAL; + p->sadb_prop_replay = 32; + memset(p->sadb_prop_reserved, 0, sizeof(p->sadb_prop_reserved)); + + for (i = 0; ; i++) { + const struct xfrm_algo_desc *aalg = xfrm_aalg_get_byidx(i); + if (!aalg) + break; + + if (!aalg->pfkey_supported) + continue; + + if (aalg_tmpl_set(t, aalg) && aalg->available) { + struct sadb_comb *c; + c = skb_put_zero(skb, sizeof(struct sadb_comb)); + p->sadb_prop_len += sizeof(struct sadb_comb)/8; + c->sadb_comb_auth = aalg->desc.sadb_alg_id; + c->sadb_comb_auth_minbits = aalg->desc.sadb_alg_minbits; + c->sadb_comb_auth_maxbits = aalg->desc.sadb_alg_maxbits; + c->sadb_comb_hard_addtime = 24*60*60; + c->sadb_comb_soft_addtime = 20*60*60; + c->sadb_comb_hard_usetime = 8*60*60; + c->sadb_comb_soft_usetime = 7*60*60; + sz += sizeof(*c); + } + } + + return sz + sizeof(*p); +} + +static int dump_esp_combs(struct sk_buff *skb, const struct xfrm_tmpl *t) +{ + struct sadb_prop *p; + int sz = 0; + int i, k; + + p = skb_put(skb, sizeof(struct sadb_prop)); + p->sadb_prop_len = sizeof(struct sadb_prop)/8; + p->sadb_prop_exttype = SADB_EXT_PROPOSAL; + p->sadb_prop_replay = 32; + memset(p->sadb_prop_reserved, 0, sizeof(p->sadb_prop_reserved)); + + for (i=0; ; i++) { + const struct xfrm_algo_desc *ealg = xfrm_ealg_get_byidx(i); + if (!ealg) + break; + + if (!ealg->pfkey_supported) + continue; + + if (!(ealg_tmpl_set(t, ealg) && ealg->available)) + continue; + + for (k = 1; ; k++) { + struct sadb_comb *c; + const struct xfrm_algo_desc *aalg = xfrm_aalg_get_byidx(k); + if (!aalg) + break; + if (!aalg->pfkey_supported) + continue; + if (!(aalg_tmpl_set(t, aalg) && aalg->available)) + continue; + c = skb_put(skb, sizeof(struct sadb_comb)); + memset(c, 0, sizeof(*c)); + p->sadb_prop_len += sizeof(struct sadb_comb)/8; + c->sadb_comb_auth = aalg->desc.sadb_alg_id; + c->sadb_comb_auth_minbits = aalg->desc.sadb_alg_minbits; + c->sadb_comb_auth_maxbits = aalg->desc.sadb_alg_maxbits; + c->sadb_comb_encrypt = ealg->desc.sadb_alg_id; + c->sadb_comb_encrypt_minbits = ealg->desc.sadb_alg_minbits; + c->sadb_comb_encrypt_maxbits = ealg->desc.sadb_alg_maxbits; + c->sadb_comb_hard_addtime = 24*60*60; + c->sadb_comb_soft_addtime = 20*60*60; + c->sadb_comb_hard_usetime = 8*60*60; + c->sadb_comb_soft_usetime = 7*60*60; + sz += sizeof(*c); + } + } + + return sz + sizeof(*p); +} + +static int key_notify_policy_expire(struct xfrm_policy *xp, const struct km_event *c) +{ + return 0; +} + +static int key_notify_sa_expire(struct xfrm_state *x, const struct km_event *c) +{ + struct sk_buff *out_skb; + struct sadb_msg *out_hdr; + int hard; + int hsc; + + hard = c->data.hard; + if (hard) + hsc = 2; + else + hsc = 1; + + out_skb = pfkey_xfrm_state2msg_expire(x, hsc); + if (IS_ERR(out_skb)) + return PTR_ERR(out_skb); + + out_hdr = (struct sadb_msg *) out_skb->data; + out_hdr->sadb_msg_version = PF_KEY_V2; + out_hdr->sadb_msg_type = SADB_EXPIRE; + out_hdr->sadb_msg_satype = pfkey_proto2satype(x->id.proto); + out_hdr->sadb_msg_errno = 0; + out_hdr->sadb_msg_reserved = 0; + out_hdr->sadb_msg_seq = 0; + out_hdr->sadb_msg_pid = 0; + + pfkey_broadcast(out_skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, + xs_net(x)); + return 0; +} + +static int pfkey_send_notify(struct xfrm_state *x, const struct km_event *c) +{ + struct net *net = x ? 
xs_net(x) : c->net; + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + if (atomic_read(&net_pfkey->socks_nr) == 0) + return 0; + + switch (c->event) { + case XFRM_MSG_EXPIRE: + return key_notify_sa_expire(x, c); + case XFRM_MSG_DELSA: + case XFRM_MSG_NEWSA: + case XFRM_MSG_UPDSA: + return key_notify_sa(x, c); + case XFRM_MSG_FLUSHSA: + return key_notify_sa_flush(c); + case XFRM_MSG_NEWAE: /* not yet supported */ + break; + default: + pr_err("pfkey: Unknown SA event %d\n", c->event); + break; + } + + return 0; +} + +static int pfkey_send_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + if (xp && xp->type != XFRM_POLICY_TYPE_MAIN) + return 0; + + switch (c->event) { + case XFRM_MSG_POLEXPIRE: + return key_notify_policy_expire(xp, c); + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDPOLICY: + return key_notify_policy(xp, dir, c); + case XFRM_MSG_FLUSHPOLICY: + if (c->data.type != XFRM_POLICY_TYPE_MAIN) + break; + return key_notify_policy_flush(c); + default: + pr_err("pfkey: Unknown policy event %d\n", c->event); + break; + } + + return 0; +} + +static u32 get_acqseq(void) +{ + u32 res; + static atomic_t acqseq; + + do { + res = atomic_inc_return(&acqseq); + } while (!res); + return res; +} + +static bool pfkey_is_alive(const struct km_event *c) +{ + struct netns_pfkey *net_pfkey = net_generic(c->net, pfkey_net_id); + struct sock *sk; + bool is_alive = false; + + rcu_read_lock(); + sk_for_each_rcu(sk, &net_pfkey->table) { + if (pfkey_sk(sk)->registered) { + is_alive = true; + break; + } + } + rcu_read_unlock(); + + return is_alive; +} + +static int pfkey_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *xp) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + struct sadb_address *addr; + struct sadb_x_policy *pol; + int sockaddr_size; + int size; + struct sadb_x_sec_ctx *sec_ctx; + struct xfrm_sec_ctx *xfrm_ctx; + int ctx_size = 0; + int alg_size = 0; + + sockaddr_size = pfkey_sockaddr_size(x->props.family); + if (!sockaddr_size) + return -EINVAL; + + size = sizeof(struct sadb_msg) + + (sizeof(struct sadb_address) * 2) + + (sockaddr_size * 2) + + sizeof(struct sadb_x_policy); + + if (x->id.proto == IPPROTO_AH) + alg_size = count_ah_combs(t); + else if (x->id.proto == IPPROTO_ESP) + alg_size = count_esp_combs(t); + + if ((xfrm_ctx = x->security)) { + ctx_size = PFKEY_ALIGN8(xfrm_ctx->ctx_len); + size += sizeof(struct sadb_x_sec_ctx) + ctx_size; + } + + skb = alloc_skb(size + alg_size + 16, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + hdr = skb_put(skb, sizeof(struct sadb_msg)); + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_type = SADB_ACQUIRE; + hdr->sadb_msg_satype = pfkey_proto2satype(x->id.proto); + hdr->sadb_msg_len = size / sizeof(uint64_t); + hdr->sadb_msg_errno = 0; + hdr->sadb_msg_reserved = 0; + hdr->sadb_msg_seq = x->km.seq = get_acqseq(); + hdr->sadb_msg_pid = 0; + + /* src address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_SRC; + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(&x->props.saddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + if (!addr->sadb_address_prefixlen) + BUG(); + + /* dst address */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct 
sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_DST; + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(&x->id.daddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + if (!addr->sadb_address_prefixlen) + BUG(); + + pol = skb_put(skb, sizeof(struct sadb_x_policy)); + pol->sadb_x_policy_len = sizeof(struct sadb_x_policy)/sizeof(uint64_t); + pol->sadb_x_policy_exttype = SADB_X_EXT_POLICY; + pol->sadb_x_policy_type = IPSEC_POLICY_IPSEC; + pol->sadb_x_policy_dir = XFRM_POLICY_OUT + 1; + pol->sadb_x_policy_reserved = 0; + pol->sadb_x_policy_id = xp->index; + pol->sadb_x_policy_priority = xp->priority; + + /* Set sadb_comb's. */ + alg_size = 0; + if (x->id.proto == IPPROTO_AH) + alg_size = dump_ah_combs(skb, t); + else if (x->id.proto == IPPROTO_ESP) + alg_size = dump_esp_combs(skb, t); + + hdr->sadb_msg_len += alg_size / 8; + + /* security context */ + if (xfrm_ctx) { + sec_ctx = skb_put(skb, + sizeof(struct sadb_x_sec_ctx) + ctx_size); + sec_ctx->sadb_x_sec_len = + (sizeof(struct sadb_x_sec_ctx) + ctx_size) / sizeof(uint64_t); + sec_ctx->sadb_x_sec_exttype = SADB_X_EXT_SEC_CTX; + sec_ctx->sadb_x_ctx_doi = xfrm_ctx->ctx_doi; + sec_ctx->sadb_x_ctx_alg = xfrm_ctx->ctx_alg; + sec_ctx->sadb_x_ctx_len = xfrm_ctx->ctx_len; + memcpy(sec_ctx + 1, xfrm_ctx->ctx_str, + xfrm_ctx->ctx_len); + } + + return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, + xs_net(x)); +} + +static struct xfrm_policy *pfkey_compile_policy(struct sock *sk, int opt, + u8 *data, int len, int *dir) +{ + struct net *net = sock_net(sk); + struct xfrm_policy *xp; + struct sadb_x_policy *pol = (struct sadb_x_policy*)data; + struct sadb_x_sec_ctx *sec_ctx; + + switch (sk->sk_family) { + case AF_INET: + if (opt != IP_IPSEC_POLICY) { + *dir = -EOPNOTSUPP; + return NULL; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + if (opt != IPV6_IPSEC_POLICY) { + *dir = -EOPNOTSUPP; + return NULL; + } + break; +#endif + default: + *dir = -EINVAL; + return NULL; + } + + *dir = -EINVAL; + + if (len < sizeof(struct sadb_x_policy) || + pol->sadb_x_policy_len*8 > len || + pol->sadb_x_policy_type > IPSEC_POLICY_BYPASS || + (!pol->sadb_x_policy_dir || pol->sadb_x_policy_dir > IPSEC_DIR_OUTBOUND)) + return NULL; + + xp = xfrm_policy_alloc(net, GFP_ATOMIC); + if (xp == NULL) { + *dir = -ENOBUFS; + return NULL; + } + + xp->action = (pol->sadb_x_policy_type == IPSEC_POLICY_DISCARD ? 
+ XFRM_POLICY_BLOCK : XFRM_POLICY_ALLOW); + + xp->lft.soft_byte_limit = XFRM_INF; + xp->lft.hard_byte_limit = XFRM_INF; + xp->lft.soft_packet_limit = XFRM_INF; + xp->lft.hard_packet_limit = XFRM_INF; + xp->family = sk->sk_family; + + xp->xfrm_nr = 0; + if (pol->sadb_x_policy_type == IPSEC_POLICY_IPSEC && + (*dir = parse_ipsecrequests(xp, pol)) < 0) + goto out; + + /* security context too */ + if (len >= (pol->sadb_x_policy_len*8 + + sizeof(struct sadb_x_sec_ctx))) { + char *p = (char *)pol; + struct xfrm_user_sec_ctx *uctx; + + p += pol->sadb_x_policy_len*8; + sec_ctx = (struct sadb_x_sec_ctx *)p; + if (len < pol->sadb_x_policy_len*8 + + sec_ctx->sadb_x_sec_len*8) { + *dir = -EINVAL; + goto out; + } + if ((*dir = verify_sec_ctx_len(p))) + goto out; + uctx = pfkey_sadb2xfrm_user_sec_ctx(sec_ctx, GFP_ATOMIC); + *dir = security_xfrm_policy_alloc(&xp->security, uctx, GFP_ATOMIC); + kfree(uctx); + + if (*dir) + goto out; + } + + *dir = pol->sadb_x_policy_dir-1; + return xp; + +out: + xp->walk.dead = 1; + xfrm_policy_destroy(xp); + return NULL; +} + +static int pfkey_send_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport) +{ + struct sk_buff *skb; + struct sadb_msg *hdr; + struct sadb_sa *sa; + struct sadb_address *addr; + struct sadb_x_nat_t_port *n_port; + int sockaddr_size; + int size; + __u8 satype = (x->id.proto == IPPROTO_ESP ? SADB_SATYPE_ESP : 0); + struct xfrm_encap_tmpl *natt = NULL; + + sockaddr_size = pfkey_sockaddr_size(x->props.family); + if (!sockaddr_size) + return -EINVAL; + + if (!satype) + return -EINVAL; + + if (!x->encap) + return -EINVAL; + + natt = x->encap; + + /* Build an SADB_X_NAT_T_NEW_MAPPING message: + * + * HDR | SA | ADDRESS_SRC (old addr) | NAT_T_SPORT (old port) | + * ADDRESS_DST (new addr) | NAT_T_DPORT (new port) + */ + + size = sizeof(struct sadb_msg) + + sizeof(struct sadb_sa) + + (sizeof(struct sadb_address) * 2) + + (sockaddr_size * 2) + + (sizeof(struct sadb_x_nat_t_port) * 2); + + skb = alloc_skb(size + 16, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + hdr = skb_put(skb, sizeof(struct sadb_msg)); + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_type = SADB_X_NAT_T_NEW_MAPPING; + hdr->sadb_msg_satype = satype; + hdr->sadb_msg_len = size / sizeof(uint64_t); + hdr->sadb_msg_errno = 0; + hdr->sadb_msg_reserved = 0; + hdr->sadb_msg_seq = x->km.seq; + hdr->sadb_msg_pid = 0; + + /* SA */ + sa = skb_put(skb, sizeof(struct sadb_sa)); + sa->sadb_sa_len = sizeof(struct sadb_sa)/sizeof(uint64_t); + sa->sadb_sa_exttype = SADB_EXT_SA; + sa->sadb_sa_spi = x->id.spi; + sa->sadb_sa_replay = 0; + sa->sadb_sa_state = 0; + sa->sadb_sa_auth = 0; + sa->sadb_sa_encrypt = 0; + sa->sadb_sa_flags = 0; + + /* ADDRESS_SRC (old addr) */ + addr = skb_put(skb, sizeof(struct sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_SRC; + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(&x->props.saddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + if (!addr->sadb_address_prefixlen) + BUG(); + + /* NAT_T_SPORT (old port) */ + n_port = skb_put(skb, sizeof(*n_port)); + n_port->sadb_x_nat_t_port_len = sizeof(*n_port)/sizeof(uint64_t); + n_port->sadb_x_nat_t_port_exttype = SADB_X_EXT_NAT_T_SPORT; + n_port->sadb_x_nat_t_port_port = natt->encap_sport; + n_port->sadb_x_nat_t_port_reserved = 0; + + /* ADDRESS_DST (new addr) */ + addr = skb_put(skb, sizeof(struct 
sadb_address) + sockaddr_size); + addr->sadb_address_len = + (sizeof(struct sadb_address)+sockaddr_size)/ + sizeof(uint64_t); + addr->sadb_address_exttype = SADB_EXT_ADDRESS_DST; + addr->sadb_address_proto = 0; + addr->sadb_address_reserved = 0; + addr->sadb_address_prefixlen = + pfkey_sockaddr_fill(ipaddr, 0, + (struct sockaddr *) (addr + 1), + x->props.family); + if (!addr->sadb_address_prefixlen) + BUG(); + + /* NAT_T_DPORT (new port) */ + n_port = skb_put(skb, sizeof(*n_port)); + n_port->sadb_x_nat_t_port_len = sizeof(*n_port)/sizeof(uint64_t); + n_port->sadb_x_nat_t_port_exttype = SADB_X_EXT_NAT_T_DPORT; + n_port->sadb_x_nat_t_port_port = sport; + n_port->sadb_x_nat_t_port_reserved = 0; + + return pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_REGISTERED, NULL, + xs_net(x)); +} + +#ifdef CONFIG_NET_KEY_MIGRATE +static int set_sadb_address(struct sk_buff *skb, int sasize, int type, + const struct xfrm_selector *sel) +{ + struct sadb_address *addr; + addr = skb_put(skb, sizeof(struct sadb_address) + sasize); + addr->sadb_address_len = (sizeof(struct sadb_address) + sasize)/8; + addr->sadb_address_exttype = type; + addr->sadb_address_proto = sel->proto; + addr->sadb_address_reserved = 0; + + switch (type) { + case SADB_EXT_ADDRESS_SRC: + addr->sadb_address_prefixlen = sel->prefixlen_s; + pfkey_sockaddr_fill(&sel->saddr, 0, + (struct sockaddr *)(addr + 1), + sel->family); + break; + case SADB_EXT_ADDRESS_DST: + addr->sadb_address_prefixlen = sel->prefixlen_d; + pfkey_sockaddr_fill(&sel->daddr, 0, + (struct sockaddr *)(addr + 1), + sel->family); + break; + default: + return -EINVAL; + } + + return 0; +} + + +static int set_sadb_kmaddress(struct sk_buff *skb, const struct xfrm_kmaddress *k) +{ + struct sadb_x_kmaddress *kma; + u8 *sa; + int family = k->family; + int socklen = pfkey_sockaddr_len(family); + int size_req; + + size_req = (sizeof(struct sadb_x_kmaddress) + + pfkey_sockaddr_pair_size(family)); + + kma = skb_put_zero(skb, size_req); + kma->sadb_x_kmaddress_len = size_req / 8; + kma->sadb_x_kmaddress_exttype = SADB_X_EXT_KMADDRESS; + kma->sadb_x_kmaddress_reserved = k->reserved; + + sa = (u8 *)(kma + 1); + if (!pfkey_sockaddr_fill(&k->local, 0, (struct sockaddr *)sa, family) || + !pfkey_sockaddr_fill(&k->remote, 0, (struct sockaddr *)(sa+socklen), family)) + return -EINVAL; + + return 0; +} + +static int set_ipsecrequest(struct sk_buff *skb, + uint8_t proto, uint8_t mode, int level, + uint32_t reqid, uint8_t family, + const xfrm_address_t *src, const xfrm_address_t *dst) +{ + struct sadb_x_ipsecrequest *rq; + u8 *sa; + int socklen = pfkey_sockaddr_len(family); + int size_req; + + size_req = sizeof(struct sadb_x_ipsecrequest) + + pfkey_sockaddr_pair_size(family); + + rq = skb_put_zero(skb, size_req); + rq->sadb_x_ipsecrequest_len = size_req; + rq->sadb_x_ipsecrequest_proto = proto; + rq->sadb_x_ipsecrequest_mode = mode; + rq->sadb_x_ipsecrequest_level = level; + rq->sadb_x_ipsecrequest_reqid = reqid; + + sa = (u8 *) (rq + 1); + if (!pfkey_sockaddr_fill(src, 0, (struct sockaddr *)sa, family) || + !pfkey_sockaddr_fill(dst, 0, (struct sockaddr *)(sa + socklen), family)) + return -EINVAL; + + return 0; +} +#endif + +#ifdef CONFIG_NET_KEY_MIGRATE +static int pfkey_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + const struct xfrm_migrate *m, int num_bundles, + const struct xfrm_kmaddress *k, + const struct xfrm_encap_tmpl *encap) +{ + int i; + int sasize_sel; + int size = 0; + int size_pol = 0; + struct sk_buff *skb; + struct sadb_msg *hdr; + struct sadb_x_policy *pol; + 
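+	/* The SADB_X_MIGRATE message built below is laid out as
+	 * HDR | [KMADDRESS] | ADDRESS_SRC | ADDRESS_DST | X_POLICY |
+	 * { old, new } sadb_x_ipsecrequest pairs (one pair per bundle),
+	 * and is broadcast to all PF_KEY sockets (BROADCAST_ALL).
+	 */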
const struct xfrm_migrate *mp; + + if (type != XFRM_POLICY_TYPE_MAIN) + return 0; + + if (num_bundles <= 0 || num_bundles > XFRM_MAX_DEPTH) + return -EINVAL; + + if (k != NULL) { + /* addresses for KM */ + size += PFKEY_ALIGN8(sizeof(struct sadb_x_kmaddress) + + pfkey_sockaddr_pair_size(k->family)); + } + + /* selector */ + sasize_sel = pfkey_sockaddr_size(sel->family); + if (!sasize_sel) + return -EINVAL; + size += (sizeof(struct sadb_address) + sasize_sel) * 2; + + /* policy info */ + size_pol += sizeof(struct sadb_x_policy); + + /* ipsecrequests */ + for (i = 0, mp = m; i < num_bundles; i++, mp++) { + /* old locator pair */ + size_pol += sizeof(struct sadb_x_ipsecrequest) + + pfkey_sockaddr_pair_size(mp->old_family); + /* new locator pair */ + size_pol += sizeof(struct sadb_x_ipsecrequest) + + pfkey_sockaddr_pair_size(mp->new_family); + } + + size += sizeof(struct sadb_msg) + size_pol; + + /* alloc buffer */ + skb = alloc_skb(size, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + hdr = skb_put(skb, sizeof(struct sadb_msg)); + hdr->sadb_msg_version = PF_KEY_V2; + hdr->sadb_msg_type = SADB_X_MIGRATE; + hdr->sadb_msg_satype = pfkey_proto2satype(m->proto); + hdr->sadb_msg_len = size / 8; + hdr->sadb_msg_errno = 0; + hdr->sadb_msg_reserved = 0; + hdr->sadb_msg_seq = 0; + hdr->sadb_msg_pid = 0; + + /* Addresses to be used by KM for negotiation, if ext is available */ + if (k != NULL && (set_sadb_kmaddress(skb, k) < 0)) + goto err; + + /* selector src */ + set_sadb_address(skb, sasize_sel, SADB_EXT_ADDRESS_SRC, sel); + + /* selector dst */ + set_sadb_address(skb, sasize_sel, SADB_EXT_ADDRESS_DST, sel); + + /* policy information */ + pol = skb_put(skb, sizeof(struct sadb_x_policy)); + pol->sadb_x_policy_len = size_pol / 8; + pol->sadb_x_policy_exttype = SADB_X_EXT_POLICY; + pol->sadb_x_policy_type = IPSEC_POLICY_IPSEC; + pol->sadb_x_policy_dir = dir + 1; + pol->sadb_x_policy_reserved = 0; + pol->sadb_x_policy_id = 0; + pol->sadb_x_policy_priority = 0; + + for (i = 0, mp = m; i < num_bundles; i++, mp++) { + /* old ipsecrequest */ + int mode = pfkey_mode_from_xfrm(mp->mode); + if (mode < 0) + goto err; + if (set_ipsecrequest(skb, mp->proto, mode, + (mp->reqid ? IPSEC_LEVEL_UNIQUE : IPSEC_LEVEL_REQUIRE), + mp->reqid, mp->old_family, + &mp->old_saddr, &mp->old_daddr) < 0) + goto err; + + /* new ipsecrequest */ + if (set_ipsecrequest(skb, mp->proto, mode, + (mp->reqid ? 
IPSEC_LEVEL_UNIQUE : IPSEC_LEVEL_REQUIRE), + mp->reqid, mp->new_family, + &mp->new_saddr, &mp->new_daddr) < 0) + goto err; + } + + /* broadcast migrate message to sockets */ + pfkey_broadcast(skb, GFP_ATOMIC, BROADCAST_ALL, NULL, &init_net); + + return 0; + +err: + kfree_skb(skb); + return -EINVAL; +} +#else +static int pfkey_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + const struct xfrm_migrate *m, int num_bundles, + const struct xfrm_kmaddress *k, + const struct xfrm_encap_tmpl *encap) +{ + return -ENOPROTOOPT; +} +#endif + +static int pfkey_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb = NULL; + struct sadb_msg *hdr = NULL; + int err; + struct net *net = sock_net(sk); + + err = -EOPNOTSUPP; + if (msg->msg_flags & MSG_OOB) + goto out; + + err = -EMSGSIZE; + if ((unsigned int)len > sk->sk_sndbuf - 32) + goto out; + + err = -ENOBUFS; + skb = alloc_skb(len, GFP_KERNEL); + if (skb == NULL) + goto out; + + err = -EFAULT; + if (memcpy_from_msg(skb_put(skb,len), msg, len)) + goto out; + + hdr = pfkey_get_base_msg(skb, &err); + if (!hdr) + goto out; + + mutex_lock(&net->xfrm.xfrm_cfg_mutex); + err = pfkey_process(sk, skb, hdr); + mutex_unlock(&net->xfrm.xfrm_cfg_mutex); + +out: + if (err && hdr && pfkey_error(hdr, err, sk) == 0) + err = 0; + kfree_skb(skb); + + return err ? : len; +} + +static int pfkey_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk = sock->sk; + struct pfkey_sock *pfk = pfkey_sk(sk); + struct sk_buff *skb; + int copied, err; + + err = -EINVAL; + if (flags & ~(MSG_PEEK|MSG_DONTWAIT|MSG_TRUNC|MSG_CMSG_COMPAT)) + goto out; + + skb = skb_recv_datagram(sk, flags, &err); + if (skb == NULL) + goto out; + + copied = skb->len; + if (copied > len) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + skb_reset_transport_header(skb); + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto out_free; + + sock_recv_cmsgs(msg, sk, skb); + + err = (flags & MSG_TRUNC) ? skb->len : copied; + + if (pfk->dump.dump != NULL && + 3 * atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf) + pfkey_do_dump(pfk); + +out_free: + skb_free_datagram(sk, skb); +out: + return err; +} + +static const struct proto_ops pfkey_ops = { + .family = PF_KEY, + .owner = THIS_MODULE, + /* Operations that make no sense on pfkey sockets. */ + .bind = sock_no_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, + + /* Now the operations that really occur. 
*/ + .release = pfkey_release, + .poll = datagram_poll, + .sendmsg = pfkey_sendmsg, + .recvmsg = pfkey_recvmsg, +}; + +static const struct net_proto_family pfkey_family_ops = { + .family = PF_KEY, + .create = pfkey_create, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_PROC_FS +static int pfkey_seq_show(struct seq_file *f, void *v) +{ + struct sock *s = sk_entry(v); + + if (v == SEQ_START_TOKEN) + seq_printf(f ,"sk RefCnt Rmem Wmem User Inode\n"); + else + seq_printf(f, "%pK %-6d %-6u %-6u %-6u %-6lu\n", + s, + refcount_read(&s->sk_refcnt), + sk_rmem_alloc_get(s), + sk_wmem_alloc_get(s), + from_kuid_munged(seq_user_ns(f), sock_i_uid(s)), + sock_i_ino(s) + ); + return 0; +} + +static void *pfkey_seq_start(struct seq_file *f, loff_t *ppos) + __acquires(rcu) +{ + struct net *net = seq_file_net(f); + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + rcu_read_lock(); + return seq_hlist_start_head_rcu(&net_pfkey->table, *ppos); +} + +static void *pfkey_seq_next(struct seq_file *f, void *v, loff_t *ppos) +{ + struct net *net = seq_file_net(f); + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + return seq_hlist_next_rcu(v, &net_pfkey->table, ppos); +} + +static void pfkey_seq_stop(struct seq_file *f, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +static const struct seq_operations pfkey_seq_ops = { + .start = pfkey_seq_start, + .next = pfkey_seq_next, + .stop = pfkey_seq_stop, + .show = pfkey_seq_show, +}; + +static int __net_init pfkey_init_proc(struct net *net) +{ + struct proc_dir_entry *e; + + e = proc_create_net("pfkey", 0, net->proc_net, &pfkey_seq_ops, + sizeof(struct seq_net_private)); + if (e == NULL) + return -ENOMEM; + + return 0; +} + +static void __net_exit pfkey_exit_proc(struct net *net) +{ + remove_proc_entry("pfkey", net->proc_net); +} +#else +static inline int pfkey_init_proc(struct net *net) +{ + return 0; +} + +static inline void pfkey_exit_proc(struct net *net) +{ +} +#endif + +static struct xfrm_mgr pfkeyv2_mgr = +{ + .notify = pfkey_send_notify, + .acquire = pfkey_send_acquire, + .compile_policy = pfkey_compile_policy, + .new_mapping = pfkey_send_new_mapping, + .notify_policy = pfkey_send_policy_notify, + .migrate = pfkey_send_migrate, + .is_alive = pfkey_is_alive, +}; + +static int __net_init pfkey_net_init(struct net *net) +{ + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + int rv; + + INIT_HLIST_HEAD(&net_pfkey->table); + atomic_set(&net_pfkey->socks_nr, 0); + + rv = pfkey_init_proc(net); + + return rv; +} + +static void __net_exit pfkey_net_exit(struct net *net) +{ + struct netns_pfkey *net_pfkey = net_generic(net, pfkey_net_id); + + pfkey_exit_proc(net); + WARN_ON(!hlist_empty(&net_pfkey->table)); +} + +static struct pernet_operations pfkey_net_ops = { + .init = pfkey_net_init, + .exit = pfkey_net_exit, + .id = &pfkey_net_id, + .size = sizeof(struct netns_pfkey), +}; + +static void __exit ipsec_pfkey_exit(void) +{ + xfrm_unregister_km(&pfkeyv2_mgr); + sock_unregister(PF_KEY); + unregister_pernet_subsys(&pfkey_net_ops); + proto_unregister(&key_proto); +} + +static int __init ipsec_pfkey_init(void) +{ + int err = proto_register(&key_proto, 0); + + if (err != 0) + goto out; + + err = register_pernet_subsys(&pfkey_net_ops); + if (err != 0) + goto out_unregister_key_proto; + err = sock_register(&pfkey_family_ops); + if (err != 0) + goto out_unregister_pernet; + xfrm_register_km(&pfkeyv2_mgr); +out: + return err; + +out_unregister_pernet: + unregister_pernet_subsys(&pfkey_net_ops); +out_unregister_key_proto: + 
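+	/* Error unwind: tear down earlier registrations in reverse order. */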
proto_unregister(&key_proto); + goto out; +} + +module_init(ipsec_pfkey_init); +module_exit(ipsec_pfkey_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_KEY); diff --git a/net/l2tp/Kconfig b/net/l2tp/Kconfig new file mode 100644 index 000000000..b7856748e --- /dev/null +++ b/net/l2tp/Kconfig @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Layer Two Tunneling Protocol (L2TP) +# + +menuconfig L2TP + tristate "Layer Two Tunneling Protocol (L2TP)" + depends on (IPV6 || IPV6=n) + depends on INET + select NET_UDP_TUNNEL + help + Layer Two Tunneling Protocol + + From RFC 2661 . + + L2TP facilitates the tunneling of packets across an + intervening network in a way that is as transparent as + possible to both end-users and applications. + + L2TP is often used to tunnel PPP traffic over IP + tunnels. One IP tunnel may carry thousands of individual PPP + connections. L2TP is also used as a VPN protocol, popular + with home workers to connect to their offices. + + L2TPv3 allows other protocols as well as PPP to be carried + over L2TP tunnels. L2TPv3 is defined in RFC 3931 + . + + The kernel component handles only L2TP data packets: a + userland daemon handles L2TP the control protocol (tunnel + and session setup). One such daemon is OpenL2TP + (http://openl2tp.org/). + + If you don't need L2TP, say N. To compile all L2TP code as + modules, choose M here. + +config L2TP_DEBUGFS + tristate "L2TP debugfs support" + depends on L2TP && DEBUG_FS + help + Support for l2tp directory in debugfs filesystem. This may be + used to dump internal state of the l2tp drivers for problem + analysis. + + If unsure, say 'Y'. + + To compile this driver as a module, choose M here. The module + will be called l2tp_debugfs. + +config L2TP_V3 + bool "L2TPv3 support" + depends on L2TP + help + Layer Two Tunneling Protocol Version 3 + + From RFC 3931 . + + The Layer Two Tunneling Protocol (L2TP) provides a dynamic + mechanism for tunneling Layer 2 (L2) "circuits" across a + packet-oriented data network (e.g., over IP). L2TP, as + originally defined in RFC 2661, is a standard method for + tunneling Point-to-Point Protocol (PPP) [RFC1661] sessions. + L2TP has since been adopted for tunneling a number of other + L2 protocols, including ATM, Frame Relay, HDLC and even raw + ethernet frames. + + If you are connecting to L2TPv3 equipment, or you want to + tunnel raw ethernet frames using L2TP, say Y here. If + unsure, say N. + +config L2TP_IP + tristate "L2TP IP encapsulation for L2TPv3" + depends on L2TP_V3 + help + Support for L2TP-over-IP socket family. + + The L2TPv3 protocol defines two possible encapsulations for + L2TP frames, namely UDP and plain IP (without UDP). This + driver provides a new L2TPIP socket family with which + userspace L2TPv3 daemons may create L2TP/IP tunnel sockets + when UDP encapsulation is not required. When L2TP is carried + in IP packets, it used IP protocol number 115, so this port + must be enabled in firewalls. + + To compile this driver as a module, choose M here. The module + will be called l2tp_ip. + +config L2TP_ETH + tristate "L2TP ethernet pseudowire support for L2TPv3" + depends on L2TP_V3 + help + Support for carrying raw ethernet frames over L2TPv3. + + From RFC 4719 . + + The Layer 2 Tunneling Protocol, Version 3 (L2TPv3) can be + used as a control protocol and for data encapsulation to set + up Pseudowires for transporting layer 2 Packet Data Units + across an IP network [RFC3931]. 
+ + This driver provides an ethernet virtual interface for each + L2TP ethernet pseudowire instance. Standard Linux tools may + be used to assign an IP address to the local virtual + interface, or add the interface to a bridge. + + If you are using L2TPv3, you will almost certainly want to + enable this option. + + To compile this driver as a module, choose M here. The module + will be called l2tp_eth. diff --git a/net/l2tp/Makefile b/net/l2tp/Makefile new file mode 100644 index 000000000..cf8f27071 --- /dev/null +++ b/net/l2tp/Makefile @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the L2TP. +# + +obj-$(CONFIG_L2TP) += l2tp_core.o + +CFLAGS_l2tp_core.o += -I$(src) + +# Build l2tp as modules if L2TP is M +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_PPPOL2TP)) += l2tp_ppp.o +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_L2TP_IP)) += l2tp_ip.o +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_L2TP_V3)) += l2tp_netlink.o +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_L2TP_ETH)) += l2tp_eth.o +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_L2TP_DEBUGFS)) += l2tp_debugfs.o +ifneq ($(CONFIG_IPV6),) +obj-$(subst y,$(CONFIG_L2TP),$(CONFIG_L2TP_IP)) += l2tp_ip6.o +endif diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c new file mode 100644 index 000000000..8d21ff25f --- /dev/null +++ b/net/l2tp/l2tp_core.c @@ -0,0 +1,1723 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* L2TP core. + * + * Copyright (c) 2008,2009,2010 Katalix Systems Ltd + * + * This file contains some code of the original L2TPv2 pppol2tp + * driver, which has the following copyright: + * + * Authors: Martijn van Oosterhout + * James Chapman (jchapman@katalix.com) + * Contributors: + * Michal Ostrowski + * Arnaldo Carvalho de Melo + * David S. Miller (davem@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "l2tp_core.h" +#include "trace.h" + +#define CREATE_TRACE_POINTS +#include "trace.h" + +#define L2TP_DRV_VERSION "V2.0" + +/* L2TP header constants */ +#define L2TP_HDRFLAG_T 0x8000 +#define L2TP_HDRFLAG_L 0x4000 +#define L2TP_HDRFLAG_S 0x0800 +#define L2TP_HDRFLAG_O 0x0200 +#define L2TP_HDRFLAG_P 0x0100 + +#define L2TP_HDR_VER_MASK 0x000F +#define L2TP_HDR_VER_2 0x0002 +#define L2TP_HDR_VER_3 0x0003 + +/* L2TPv3 default L2-specific sublayer */ +#define L2TP_SLFLAG_S 0x40000000 +#define L2TP_SL_SEQ_MASK 0x00ffffff + +#define L2TP_HDR_SIZE_MAX 14 + +/* Default trace flags */ +#define L2TP_DEFAULT_DEBUG_FLAGS 0 + +/* Private data stored for received packets in the skb. 
+ */ +struct l2tp_skb_cb { + u32 ns; + u16 has_seq; + u16 length; + unsigned long expires; +}; + +#define L2TP_SKB_CB(skb) ((struct l2tp_skb_cb *)&(skb)->cb[sizeof(struct inet_skb_parm)]) + +static struct workqueue_struct *l2tp_wq; + +/* per-net private data for this module */ +static unsigned int l2tp_net_id; +struct l2tp_net { + /* Lock for write access to l2tp_tunnel_idr */ + spinlock_t l2tp_tunnel_idr_lock; + struct idr l2tp_tunnel_idr; + struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2]; + /* Lock for write access to l2tp_session_hlist */ + spinlock_t l2tp_session_hlist_lock; +}; + +#if IS_ENABLED(CONFIG_IPV6) +static bool l2tp_sk_is_v6(struct sock *sk) +{ + return sk->sk_family == PF_INET6 && + !ipv6_addr_v4mapped(&sk->sk_v6_daddr); +} +#endif + +static inline struct l2tp_net *l2tp_pernet(const struct net *net) +{ + return net_generic(net, l2tp_net_id); +} + +/* Session hash global list for L2TPv3. + * The session_id SHOULD be random according to RFC3931, but several + * L2TP implementations use incrementing session_ids. So we do a real + * hash on the session_id, rather than a simple bitmask. + */ +static inline struct hlist_head * +l2tp_session_id_hash_2(struct l2tp_net *pn, u32 session_id) +{ + return &pn->l2tp_session_hlist[hash_32(session_id, L2TP_HASH_BITS_2)]; +} + +/* Session hash list. + * The session_id SHOULD be random according to RFC2661, but several + * L2TP implementations (Cisco and Microsoft) use incrementing + * session_ids. So we do a real hash on the session_id, rather than a + * simple bitmask. + */ +static inline struct hlist_head * +l2tp_session_id_hash(struct l2tp_tunnel *tunnel, u32 session_id) +{ + return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)]; +} + +static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel) +{ + trace_free_tunnel(tunnel); + sock_put(tunnel->sock); + /* the tunnel is freed in the socket destructor */ +} + +static void l2tp_session_free(struct l2tp_session *session) +{ + trace_free_session(session); + if (session->tunnel) + l2tp_tunnel_dec_refcount(session->tunnel); + kfree(session); +} + +struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk) +{ + struct l2tp_tunnel *tunnel = sk->sk_user_data; + + if (tunnel) + if (WARN_ON(tunnel->magic != L2TP_TUNNEL_MAGIC)) + return NULL; + + return tunnel; +} +EXPORT_SYMBOL_GPL(l2tp_sk_to_tunnel); + +void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel) +{ + refcount_inc(&tunnel->ref_count); +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_inc_refcount); + +void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel) +{ + if (refcount_dec_and_test(&tunnel->ref_count)) + l2tp_tunnel_free(tunnel); +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_dec_refcount); + +void l2tp_session_inc_refcount(struct l2tp_session *session) +{ + refcount_inc(&session->ref_count); +} +EXPORT_SYMBOL_GPL(l2tp_session_inc_refcount); + +void l2tp_session_dec_refcount(struct l2tp_session *session) +{ + if (refcount_dec_and_test(&session->ref_count)) + l2tp_session_free(session); +} +EXPORT_SYMBOL_GPL(l2tp_session_dec_refcount); + +/* Lookup a tunnel. A new reference is held on the returned tunnel. 
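As a usage sketch (an assumed caller, not code from this file), the lookup below pairs with the refcount helpers above roughly as follows:

	// Hypothetical caller: look up tunnel 42 in a namespace and drop
	// the reference taken by l2tp_tunnel_get() once done with it.
	struct l2tp_tunnel *tunnel = l2tp_tunnel_get(net, 42);

	if (tunnel) {
		pr_info("tunnel %u is L2TPv%d\n",
			tunnel->tunnel_id, tunnel->version);
		l2tp_tunnel_dec_refcount(tunnel);
	}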
*/ +struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) +{ + const struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_tunnel *tunnel; + + rcu_read_lock_bh(); + tunnel = idr_find(&pn->l2tp_tunnel_idr, tunnel_id); + if (tunnel && refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_get); + +struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth) +{ + struct l2tp_net *pn = l2tp_pernet(net); + unsigned long tunnel_id, tmp; + struct l2tp_tunnel *tunnel; + int count = 0; + + rcu_read_lock_bh(); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel && ++count > nth && + refcount_inc_not_zero(&tunnel->ref_count)) { + rcu_read_unlock_bh(); + return tunnel; + } + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_get_nth); + +struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, + u32 session_id) +{ + struct hlist_head *session_list; + struct l2tp_session *session; + + session_list = l2tp_session_id_hash(tunnel, session_id); + + rcu_read_lock_bh(); + hlist_for_each_entry_rcu(session, session_list, hlist) + if (session->session_id == session_id) { + l2tp_session_inc_refcount(session); + rcu_read_unlock_bh(); + + return session; + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session); + +struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id) +{ + struct hlist_head *session_list; + struct l2tp_session *session; + + session_list = l2tp_session_id_hash_2(l2tp_pernet(net), session_id); + + rcu_read_lock_bh(); + hlist_for_each_entry_rcu(session, session_list, global_hlist) + if (session->session_id == session_id) { + l2tp_session_inc_refcount(session); + rcu_read_unlock_bh(); + + return session; + } + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_session_get); + +struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) +{ + int hash; + struct l2tp_session *session; + int count = 0; + + rcu_read_lock_bh(); + for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { + hlist_for_each_entry_rcu(session, &tunnel->session_hlist[hash], hlist) { + if (++count > nth) { + l2tp_session_inc_refcount(session); + rcu_read_unlock_bh(); + return session; + } + } + } + + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_session_get_nth); + +/* Lookup a session by interface name. + * This is very inefficient but is only used by management interfaces. 
+ */ +struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, + const char *ifname) +{ + struct l2tp_net *pn = l2tp_pernet(net); + int hash; + struct l2tp_session *session; + + rcu_read_lock_bh(); + for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) { + hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { + if (!strcmp(session->ifname, ifname)) { + l2tp_session_inc_refcount(session); + rcu_read_unlock_bh(); + + return session; + } + } + } + + rcu_read_unlock_bh(); + + return NULL; +} +EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname); + +int l2tp_session_register(struct l2tp_session *session, + struct l2tp_tunnel *tunnel) +{ + struct l2tp_session *session_walk; + struct hlist_head *g_head; + struct hlist_head *head; + struct l2tp_net *pn; + int err; + + head = l2tp_session_id_hash(tunnel, session->session_id); + + spin_lock_bh(&tunnel->hlist_lock); + if (!tunnel->acpt_newsess) { + err = -ENODEV; + goto err_tlock; + } + + hlist_for_each_entry(session_walk, head, hlist) + if (session_walk->session_id == session->session_id) { + err = -EEXIST; + goto err_tlock; + } + + if (tunnel->version == L2TP_HDR_VER_3) { + pn = l2tp_pernet(tunnel->l2tp_net); + g_head = l2tp_session_id_hash_2(pn, session->session_id); + + spin_lock_bh(&pn->l2tp_session_hlist_lock); + + /* IP encap expects session IDs to be globally unique, while + * UDP encap doesn't. + */ + hlist_for_each_entry(session_walk, g_head, global_hlist) + if (session_walk->session_id == session->session_id && + (session_walk->tunnel->encap == L2TP_ENCAPTYPE_IP || + tunnel->encap == L2TP_ENCAPTYPE_IP)) { + err = -EEXIST; + goto err_tlock_pnlock; + } + + l2tp_tunnel_inc_refcount(tunnel); + hlist_add_head_rcu(&session->global_hlist, g_head); + + spin_unlock_bh(&pn->l2tp_session_hlist_lock); + } else { + l2tp_tunnel_inc_refcount(tunnel); + } + + hlist_add_head_rcu(&session->hlist, head); + spin_unlock_bh(&tunnel->hlist_lock); + + trace_register_session(session); + + return 0; + +err_tlock_pnlock: + spin_unlock_bh(&pn->l2tp_session_hlist_lock); +err_tlock: + spin_unlock_bh(&tunnel->hlist_lock); + + return err; +} +EXPORT_SYMBOL_GPL(l2tp_session_register); + +/***************************************************************************** + * Receive data handling + *****************************************************************************/ + +/* Queue a skb in order. We come here only if the skb has an L2TP sequence + * number. + */ +static void l2tp_recv_queue_skb(struct l2tp_session *session, struct sk_buff *skb) +{ + struct sk_buff *skbp; + struct sk_buff *tmp; + u32 ns = L2TP_SKB_CB(skb)->ns; + + spin_lock_bh(&session->reorder_q.lock); + skb_queue_walk_safe(&session->reorder_q, skbp, tmp) { + if (L2TP_SKB_CB(skbp)->ns > ns) { + __skb_queue_before(&session->reorder_q, skbp, skb); + atomic_long_inc(&session->stats.rx_oos_packets); + goto out; + } + } + + __skb_queue_tail(&session->reorder_q, skb); + +out: + spin_unlock_bh(&session->reorder_q.lock); +} + +/* Dequeue a single skb. + */ +static void l2tp_recv_dequeue_skb(struct l2tp_session *session, struct sk_buff *skb) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + int length = L2TP_SKB_CB(skb)->length; + + /* We're about to requeue the skb, so return resources + * to its current owner (a socket receive buffer). 
+ */ + skb_orphan(skb); + + atomic_long_inc(&tunnel->stats.rx_packets); + atomic_long_add(length, &tunnel->stats.rx_bytes); + atomic_long_inc(&session->stats.rx_packets); + atomic_long_add(length, &session->stats.rx_bytes); + + if (L2TP_SKB_CB(skb)->has_seq) { + /* Bump our Nr */ + session->nr++; + session->nr &= session->nr_max; + trace_session_seqnum_update(session); + } + + /* call private receive handler */ + if (session->recv_skb) + (*session->recv_skb)(session, skb, L2TP_SKB_CB(skb)->length); + else + kfree_skb(skb); +} + +/* Dequeue skbs from the session's reorder_q, subject to packet order. + * Skbs that have been in the queue for too long are simply discarded. + */ +static void l2tp_recv_dequeue(struct l2tp_session *session) +{ + struct sk_buff *skb; + struct sk_buff *tmp; + + /* If the pkt at the head of the queue has the nr that we + * expect to send up next, dequeue it and any other + * in-sequence packets behind it. + */ +start: + spin_lock_bh(&session->reorder_q.lock); + skb_queue_walk_safe(&session->reorder_q, skb, tmp) { + struct l2tp_skb_cb *cb = L2TP_SKB_CB(skb); + + /* If the packet has been pending on the queue for too long, discard it */ + if (time_after(jiffies, cb->expires)) { + atomic_long_inc(&session->stats.rx_seq_discards); + atomic_long_inc(&session->stats.rx_errors); + trace_session_pkt_expired(session, cb->ns); + session->reorder_skip = 1; + __skb_unlink(skb, &session->reorder_q); + kfree_skb(skb); + continue; + } + + if (cb->has_seq) { + if (session->reorder_skip) { + session->reorder_skip = 0; + session->nr = cb->ns; + trace_session_seqnum_reset(session); + } + if (cb->ns != session->nr) + goto out; + } + __skb_unlink(skb, &session->reorder_q); + + /* Process the skb. We release the queue lock while we + * do so to let other contexts process the queue. + */ + spin_unlock_bh(&session->reorder_q.lock); + l2tp_recv_dequeue_skb(session, skb); + goto start; + } + +out: + spin_unlock_bh(&session->reorder_q.lock); +} + +static int l2tp_seq_check_rx_window(struct l2tp_session *session, u32 nr) +{ + u32 nws; + + if (nr >= session->nr) + nws = nr - session->nr; + else + nws = (session->nr_max + 1) - (session->nr - nr); + + return nws < session->nr_window_size; +} + +/* If packet has sequence numbers, queue it if acceptable. Returns 0 if + * acceptable, else non-zero. + */ +static int l2tp_recv_data_seq(struct l2tp_session *session, struct sk_buff *skb) +{ + struct l2tp_skb_cb *cb = L2TP_SKB_CB(skb); + + if (!l2tp_seq_check_rx_window(session, cb->ns)) { + /* Packet sequence number is outside allowed window. + * Discard it. + */ + trace_session_pkt_outside_rx_window(session, cb->ns); + goto discard; + } + + if (session->reorder_timeout != 0) { + /* Packet reordering enabled. Add skb to session's + * reorder queue, in order of ns. + */ + l2tp_recv_queue_skb(session, skb); + goto out; + } + + /* Packet reordering disabled. Discard out-of-sequence packets, while + * tracking the number if in-sequence packets after the first OOS packet + * is seen. After nr_oos_count_max in-sequence packets, reset the + * sequence number to re-enable packet reception. 
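To illustrate the wrap-around arithmetic in l2tp_seq_check_rx_window() above, here is a small self-contained userspace re-implementation (an illustrative copy, not shared code), using the L2TPv2 defaults set up in l2tp_session_create(): nr_max = 0xffff and a window of nr_max / 2.

	// cc -o nr_window nr_window.c && ./nr_window
	#include <stdio.h>
	#include <stdint.h>

	#define NR_MAX		0xffffu		// L2TPv2: 16-bit sequence space
	#define NR_WINDOW	(NR_MAX / 2)	// default nr_window_size

	// Mirrors l2tp_seq_check_rx_window(): is nr within the receive
	// window that starts at the expected sequence number expected_nr?
	static int nr_in_window(uint32_t expected_nr, uint32_t nr)
	{
		uint32_t nws;

		if (nr >= expected_nr)
			nws = nr - expected_nr;
		else
			nws = (NR_MAX + 1) - (expected_nr - nr);

		return nws < NR_WINDOW;
	}

	int main(void)
	{
		printf("%d\n", nr_in_window(0x0001, 0x0100));	// 1: inside the window
		printf("%d\n", nr_in_window(0x0001, 0xfffe));	// 0: outside the window
		printf("%d\n", nr_in_window(0xfff0, 0x0010));	// 1: accepted across wrap
		return 0;
	}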
+ */ + if (cb->ns == session->nr) { + skb_queue_tail(&session->reorder_q, skb); + } else { + u32 nr_oos = cb->ns; + u32 nr_next = (session->nr_oos + 1) & session->nr_max; + + if (nr_oos == nr_next) + session->nr_oos_count++; + else + session->nr_oos_count = 0; + + session->nr_oos = nr_oos; + if (session->nr_oos_count > session->nr_oos_count_max) { + session->reorder_skip = 1; + } + if (!session->reorder_skip) { + atomic_long_inc(&session->stats.rx_seq_discards); + trace_session_pkt_oos(session, cb->ns); + goto discard; + } + skb_queue_tail(&session->reorder_q, skb); + } + +out: + return 0; + +discard: + return 1; +} + +/* Do receive processing of L2TP data frames. We handle both L2TPv2 + * and L2TPv3 data frames here. + * + * L2TPv2 Data Message Header + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |T|L|x|x|S|x|O|P|x|x|x|x| Ver | Length (opt) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Tunnel ID | Session ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Ns (opt) | Nr (opt) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Offset Size (opt) | Offset pad... (opt) + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Data frames are marked by T=0. All other fields are the same as + * those in L2TP control frames. + * + * L2TPv3 Data Message Header + * + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | L2TP Session Header | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | L2-Specific Sublayer | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Tunnel Payload ... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * L2TPv3 Session Header Over IP + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Session ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Cookie (optional, maximum 64 bits)... + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * L2TPv3 L2-Specific Sublayer Format + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |x|S|x|x|x|x|x|x| Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Cookie value and sublayer format are negotiated with the peer when + * the session is set up. Unlike L2TPv2, we do not need to parse the + * packet header to determine if optional fields are present. + * + * Caller must already have parsed the frame and determined that it is + * a data (not control) frame before coming here. Fields up to the + * session-id have already been parsed and ptr points to the data + * after the session-id. + */ +void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, + unsigned char *ptr, unsigned char *optr, u16 hdrflags, + int length) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + int offset; + + /* Parse and check optional cookie */ + if (session->peer_cookie_len > 0) { + if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) { + pr_debug_ratelimited("%s: cookie mismatch (%u/%u). 
Discarding.\n", + tunnel->name, tunnel->tunnel_id, + session->session_id); + atomic_long_inc(&session->stats.rx_cookie_discards); + goto discard; + } + ptr += session->peer_cookie_len; + } + + /* Handle the optional sequence numbers. Sequence numbers are + * in different places for L2TPv2 and L2TPv3. + * + * If we are the LAC, enable/disable sequence numbers under + * the control of the LNS. If no sequence numbers present but + * we were expecting them, discard frame. + */ + L2TP_SKB_CB(skb)->has_seq = 0; + if (tunnel->version == L2TP_HDR_VER_2) { + if (hdrflags & L2TP_HDRFLAG_S) { + /* Store L2TP info in the skb */ + L2TP_SKB_CB(skb)->ns = ntohs(*(__be16 *)ptr); + L2TP_SKB_CB(skb)->has_seq = 1; + ptr += 2; + /* Skip past nr in the header */ + ptr += 2; + + } + } else if (session->l2specific_type == L2TP_L2SPECTYPE_DEFAULT) { + u32 l2h = ntohl(*(__be32 *)ptr); + + if (l2h & 0x40000000) { + /* Store L2TP info in the skb */ + L2TP_SKB_CB(skb)->ns = l2h & 0x00ffffff; + L2TP_SKB_CB(skb)->has_seq = 1; + } + ptr += 4; + } + + if (L2TP_SKB_CB(skb)->has_seq) { + /* Received a packet with sequence numbers. If we're the LAC, + * check if we sre sending sequence numbers and if not, + * configure it so. + */ + if (!session->lns_mode && !session->send_seq) { + trace_session_seqnum_lns_enable(session); + session->send_seq = 1; + l2tp_session_set_header_len(session, tunnel->version); + } + } else { + /* No sequence numbers. + * If user has configured mandatory sequence numbers, discard. + */ + if (session->recv_seq) { + pr_debug_ratelimited("%s: recv data has no seq numbers when required. Discarding.\n", + session->name); + atomic_long_inc(&session->stats.rx_seq_discards); + goto discard; + } + + /* If we're the LAC and we're sending sequence numbers, the + * LNS has requested that we no longer send sequence numbers. + * If we're the LNS and we're sending sequence numbers, the + * LAC is broken. Discard the frame. + */ + if (!session->lns_mode && session->send_seq) { + trace_session_seqnum_lns_disable(session); + session->send_seq = 0; + l2tp_session_set_header_len(session, tunnel->version); + } else if (session->send_seq) { + pr_debug_ratelimited("%s: recv data has no seq numbers when required. Discarding.\n", + session->name); + atomic_long_inc(&session->stats.rx_seq_discards); + goto discard; + } + } + + /* Session data offset is defined only for L2TPv2 and is + * indicated by an optional 16-bit value in the header. + */ + if (tunnel->version == L2TP_HDR_VER_2) { + /* If offset bit set, skip it. */ + if (hdrflags & L2TP_HDRFLAG_O) { + offset = ntohs(*(__be16 *)ptr); + ptr += 2 + offset; + } + } + + offset = ptr - optr; + if (!pskb_may_pull(skb, offset)) + goto discard; + + __skb_pull(skb, offset); + + /* Prepare skb for adding to the session's reorder_q. Hold + * packets for max reorder_timeout or 1 second if not + * reordering. + */ + L2TP_SKB_CB(skb)->length = length; + L2TP_SKB_CB(skb)->expires = jiffies + + (session->reorder_timeout ? session->reorder_timeout : HZ); + + /* Add packet to the session's receive queue. Reordering is done here, if + * enabled. Saved L2TP protocol info is stored in skb->sb[]. + */ + if (L2TP_SKB_CB(skb)->has_seq) { + if (l2tp_recv_data_seq(session, skb)) + goto discard; + } else { + /* No sequence numbers. Add the skb to the tail of the + * reorder queue. This ensures that it will be + * delivered after all previous sequenced skbs. + */ + skb_queue_tail(&session->reorder_q, skb); + } + + /* Try to dequeue as many skbs from reorder_q as we can. 
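As a worked illustration of the L2TPv2 data-header layout in the diagram above, this self-contained userspace parser (an illustrative sketch, independent of the kernel code) walks the optional fields in the same order as the receive path here:

	// cc -o l2tpv2_parse l2tpv2_parse.c && ./l2tpv2_parse
	#include <stdio.h>
	#include <stdint.h>

	#define HDRFLAG_T	0x8000	// control frame
	#define HDRFLAG_L	0x4000	// length field present
	#define HDRFLAG_S	0x0800	// Ns/Nr present
	#define HDRFLAG_O	0x0200	// offset field present
	#define HDR_VER_MASK	0x000f

	static uint16_t get_be16(const uint8_t *p)
	{
		return (uint16_t)((p[0] << 8) | p[1]);
	}

	static void parse_v2_data_hdr(const uint8_t *buf)
	{
		const uint8_t *p = buf;
		uint16_t flags = get_be16(p);

		p += 2;
		printf("version %u, %s frame\n", flags & HDR_VER_MASK,
		       (flags & HDRFLAG_T) ? "control" : "data");

		if (flags & HDRFLAG_L)		// optional length field
			p += 2;

		printf("tunnel %u, session %u\n", get_be16(p), get_be16(p + 2));
		p += 4;

		if (flags & HDRFLAG_S) {	// optional sequence numbers
			printf("ns %u, nr %u\n", get_be16(p), get_be16(p + 2));
			p += 4;
		}

		if (flags & HDRFLAG_O)		// optional offset size + pad
			p += 2 + get_be16(p);

		printf("payload starts at offset %ld\n", (long)(p - buf));
	}

	int main(void)
	{
		// Data frame, version 2, S bit set: tunnel 1, session 2, ns 5, nr 9.
		uint8_t pkt[] = { 0x08, 0x02, 0x00, 0x01, 0x00, 0x02,
				  0x00, 0x05, 0x00, 0x09 };

		parse_v2_data_hdr(pkt);
		return 0;
	}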
*/ + l2tp_recv_dequeue(session); + + return; + +discard: + atomic_long_inc(&session->stats.rx_errors); + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(l2tp_recv_common); + +/* Drop skbs from the session's reorder_q + */ +static void l2tp_session_queue_purge(struct l2tp_session *session) +{ + struct sk_buff *skb = NULL; + + while ((skb = skb_dequeue(&session->reorder_q))) { + atomic_long_inc(&session->stats.rx_errors); + kfree_skb(skb); + } +} + +/* Internal UDP receive frame. Do the real work of receiving an L2TP data frame + * here. The skb is not on a list when we get here. + * Returns 0 if the packet was a data packet and was successfully passed on. + * Returns 1 if the packet was not a good data packet and could not be + * forwarded. All such packets are passed up to userspace to deal with. + */ +static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb) +{ + struct l2tp_session *session = NULL; + unsigned char *ptr, *optr; + u16 hdrflags; + u32 tunnel_id, session_id; + u16 version; + int length; + + /* UDP has verified checksum */ + + /* UDP always verifies the packet length. */ + __skb_pull(skb, sizeof(struct udphdr)); + + /* Short packet? */ + if (!pskb_may_pull(skb, L2TP_HDR_SIZE_MAX)) { + pr_debug_ratelimited("%s: recv short packet (len=%d)\n", + tunnel->name, skb->len); + goto invalid; + } + + /* Point to L2TP header */ + optr = skb->data; + ptr = skb->data; + + /* Get L2TP header flags */ + hdrflags = ntohs(*(__be16 *)ptr); + + /* Check protocol version */ + version = hdrflags & L2TP_HDR_VER_MASK; + if (version != tunnel->version) { + pr_debug_ratelimited("%s: recv protocol version mismatch: got %d expected %d\n", + tunnel->name, version, tunnel->version); + goto invalid; + } + + /* Get length of L2TP packet */ + length = skb->len; + + /* If type is control packet, it is handled by userspace. */ + if (hdrflags & L2TP_HDRFLAG_T) + goto pass; + + /* Skip flags */ + ptr += 2; + + if (tunnel->version == L2TP_HDR_VER_2) { + /* If length is present, skip it */ + if (hdrflags & L2TP_HDRFLAG_L) + ptr += 2; + + /* Extract tunnel and session ID */ + tunnel_id = ntohs(*(__be16 *)ptr); + ptr += 2; + session_id = ntohs(*(__be16 *)ptr); + ptr += 2; + } else { + ptr += 2; /* skip reserved bits */ + tunnel_id = tunnel->tunnel_id; + session_id = ntohl(*(__be32 *)ptr); + ptr += 4; + } + + /* Find the session context */ + session = l2tp_tunnel_get_session(tunnel, session_id); + if (!session || !session->recv_skb) { + if (session) + l2tp_session_dec_refcount(session); + + /* Not found? Pass to userspace to deal with */ + pr_debug_ratelimited("%s: no session found (%u/%u). Passing up.\n", + tunnel->name, tunnel_id, session_id); + goto pass; + } + + if (tunnel->version == L2TP_HDR_VER_3 && + l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) { + l2tp_session_dec_refcount(session); + goto invalid; + } + + l2tp_recv_common(session, skb, ptr, optr, hdrflags, length); + l2tp_session_dec_refcount(session); + + return 0; + +invalid: + atomic_long_inc(&tunnel->stats.rx_invalid); + +pass: + /* Put UDP header back */ + __skb_push(skb, sizeof(struct udphdr)); + + return 1; +} + +/* UDP encapsulation receive handler. See net/ipv4/udp.c. + * Return codes: + * 0 : success. + * <0: error + * >0: skb should be passed up to userspace as UDP. + */ +int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb) +{ + struct l2tp_tunnel *tunnel; + + /* Note that this is called from the encap_rcv hook inside an + * RCU-protected region, but without the socket being locked. 
+ * Hence we use rcu_dereference_sk_user_data to access the + * tunnel data structure rather the usual l2tp_sk_to_tunnel + * accessor function. + */ + tunnel = rcu_dereference_sk_user_data(sk); + if (!tunnel) + goto pass_up; + if (WARN_ON(tunnel->magic != L2TP_TUNNEL_MAGIC)) + goto pass_up; + + if (l2tp_udp_recv_core(tunnel, skb)) + goto pass_up; + + return 0; + +pass_up: + return 1; +} +EXPORT_SYMBOL_GPL(l2tp_udp_encap_recv); + +/************************************************************************ + * Transmit handling + ***********************************************************************/ + +/* Build an L2TP header for the session into the buffer provided. + */ +static int l2tp_build_l2tpv2_header(struct l2tp_session *session, void *buf) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + __be16 *bufp = buf; + __be16 *optr = buf; + u16 flags = L2TP_HDR_VER_2; + u32 tunnel_id = tunnel->peer_tunnel_id; + u32 session_id = session->peer_session_id; + + if (session->send_seq) + flags |= L2TP_HDRFLAG_S; + + /* Setup L2TP header. */ + *bufp++ = htons(flags); + *bufp++ = htons(tunnel_id); + *bufp++ = htons(session_id); + if (session->send_seq) { + *bufp++ = htons(session->ns); + *bufp++ = 0; + session->ns++; + session->ns &= 0xffff; + trace_session_seqnum_update(session); + } + + return bufp - optr; +} + +static int l2tp_build_l2tpv3_header(struct l2tp_session *session, void *buf) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + char *bufp = buf; + char *optr = bufp; + + /* Setup L2TP header. The header differs slightly for UDP and + * IP encapsulations. For UDP, there is 4 bytes of flags. + */ + if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { + u16 flags = L2TP_HDR_VER_3; + *((__be16 *)bufp) = htons(flags); + bufp += 2; + *((__be16 *)bufp) = 0; + bufp += 2; + } + + *((__be32 *)bufp) = htonl(session->peer_session_id); + bufp += 4; + if (session->cookie_len) { + memcpy(bufp, &session->cookie[0], session->cookie_len); + bufp += session->cookie_len; + } + if (session->l2specific_type == L2TP_L2SPECTYPE_DEFAULT) { + u32 l2h = 0; + + if (session->send_seq) { + l2h = 0x40000000 | session->ns; + session->ns++; + session->ns &= 0xffffff; + trace_session_seqnum_update(session); + } + + *((__be32 *)bufp) = htonl(l2h); + bufp += 4; + } + + return bufp - optr; +} + +/* Queue the packet to IP for output: tunnel socket lock must be held */ +static int l2tp_xmit_queue(struct l2tp_tunnel *tunnel, struct sk_buff *skb, struct flowi *fl) +{ + int err; + + skb->ignore_df = 1; + skb_dst_drop(skb); +#if IS_ENABLED(CONFIG_IPV6) + if (l2tp_sk_is_v6(tunnel->sock)) + err = inet6_csk_xmit(tunnel->sock, skb, NULL); + else +#endif + err = ip_queue_xmit(tunnel->sock, skb, fl); + + return err >= 0 ? NET_XMIT_SUCCESS : NET_XMIT_DROP; +} + +static int l2tp_xmit_core(struct l2tp_session *session, struct sk_buff *skb, unsigned int *len) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + unsigned int data_len = skb->len; + struct sock *sk = tunnel->sock; + int headroom, uhlen, udp_len; + int ret = NET_XMIT_SUCCESS; + struct inet_sock *inet; + struct udphdr *uh; + + /* Check that there's enough headroom in the skb to insert IP, + * UDP and L2TP headers. If not enough, expand it to + * make room. Adjust truesize. + */ + uhlen = (tunnel->encap == L2TP_ENCAPTYPE_UDP) ? 
sizeof(*uh) : 0; + headroom = NET_SKB_PAD + sizeof(struct iphdr) + uhlen + session->hdr_len; + if (skb_cow_head(skb, headroom)) { + kfree_skb(skb); + return NET_XMIT_DROP; + } + + /* Setup L2TP header */ + if (tunnel->version == L2TP_HDR_VER_2) + l2tp_build_l2tpv2_header(session, __skb_push(skb, session->hdr_len)); + else + l2tp_build_l2tpv3_header(session, __skb_push(skb, session->hdr_len)); + + /* Reset skb netfilter state */ + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + IPCB(skb)->flags &= ~(IPSKB_XFRM_TUNNEL_SIZE | IPSKB_XFRM_TRANSFORMED | IPSKB_REROUTED); + nf_reset_ct(skb); + + bh_lock_sock_nested(sk); + if (sock_owned_by_user(sk)) { + kfree_skb(skb); + ret = NET_XMIT_DROP; + goto out_unlock; + } + + /* The user-space may change the connection status for the user-space + * provided socket at run time: we must check it under the socket lock + */ + if (tunnel->fd >= 0 && sk->sk_state != TCP_ESTABLISHED) { + kfree_skb(skb); + ret = NET_XMIT_DROP; + goto out_unlock; + } + + /* Report transmitted length before we add encap header, which keeps + * statistics consistent for both UDP and IP encap tx/rx paths. + */ + *len = skb->len; + + inet = inet_sk(sk); + switch (tunnel->encap) { + case L2TP_ENCAPTYPE_UDP: + /* Setup UDP header */ + __skb_push(skb, sizeof(*uh)); + skb_reset_transport_header(skb); + uh = udp_hdr(skb); + uh->source = inet->inet_sport; + uh->dest = inet->inet_dport; + udp_len = uhlen + session->hdr_len + data_len; + uh->len = htons(udp_len); + + /* Calculate UDP checksum if configured to do so */ +#if IS_ENABLED(CONFIG_IPV6) + if (l2tp_sk_is_v6(sk)) + udp6_set_csum(udp_get_no_check6_tx(sk), + skb, &inet6_sk(sk)->saddr, + &sk->sk_v6_daddr, udp_len); + else +#endif + udp_set_csum(sk->sk_no_check_tx, skb, inet->inet_saddr, + inet->inet_daddr, udp_len); + break; + + case L2TP_ENCAPTYPE_IP: + break; + } + + ret = l2tp_xmit_queue(tunnel, skb, &inet->cork.fl); + +out_unlock: + bh_unlock_sock(sk); + + return ret; +} + +/* If caller requires the skb to have a ppp header, the header must be + * inserted in the skb data before calling this function. + */ +int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb) +{ + unsigned int len = 0; + int ret; + + ret = l2tp_xmit_core(session, skb, &len); + if (ret == NET_XMIT_SUCCESS) { + atomic_long_inc(&session->tunnel->stats.tx_packets); + atomic_long_add(len, &session->tunnel->stats.tx_bytes); + atomic_long_inc(&session->stats.tx_packets); + atomic_long_add(len, &session->stats.tx_bytes); + } else { + atomic_long_inc(&session->tunnel->stats.tx_errors); + atomic_long_inc(&session->stats.tx_errors); + } + return ret; +} +EXPORT_SYMBOL_GPL(l2tp_xmit_skb); + +/***************************************************************************** + * Tinnel and session create/destroy. + *****************************************************************************/ + +/* Tunnel socket destruct hook. + * The tunnel context is deleted only when all session sockets have been + * closed. + */ +static void l2tp_tunnel_destruct(struct sock *sk) +{ + struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + + if (!tunnel) + goto end; + + /* Disable udp encapsulation */ + switch (tunnel->encap) { + case L2TP_ENCAPTYPE_UDP: + /* No longer an encapsulation socket. 
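As a cross-check on the transmit-path builders above, the following self-contained userspace sketch (illustrative only, not shared with the kernel code) lays out an L2TPv3-over-UDP data header the same way l2tp_build_l2tpv3_header() does, for a session with a 4-byte cookie and the default L2-specific sublayer:

	// cc -o l2tpv3_hdr l2tpv3_hdr.c && ./l2tpv3_hdr
	#include <stdio.h>
	#include <string.h>
	#include <stdint.h>

	static void put_be16(uint8_t *p, uint16_t v)
	{
		p[0] = v >> 8;
		p[1] = v & 0xff;
	}

	static void put_be32(uint8_t *p, uint32_t v)
	{
		put_be16(p, v >> 16);
		put_be16(p + 2, v & 0xffff);
	}

	int main(void)
	{
		uint8_t hdr[32];
		const uint8_t cookie[4] = { 0xde, 0xad, 0xbe, 0xef };
		uint32_t peer_session_id = 0x1234, ns = 7;
		int i, off = 0;

		put_be16(hdr + off, 0x0003);		// T=0 (data), version 3
		off += 2;
		put_be16(hdr + off, 0);			// reserved (UDP encap only)
		off += 2;
		put_be32(hdr + off, peer_session_id);	// peer session id
		off += 4;
		memcpy(hdr + off, cookie, sizeof(cookie));
		off += sizeof(cookie);
		put_be32(hdr + off, 0x40000000 | ns);	// default sublayer: S bit + Ns
		off += 4;

		// 2 + 2 + 4 + 4 + 4 = 16 bytes, matching the hdr_len that
		// l2tp_session_set_header_len() computes for this configuration.
		printf("header length %d bytes:", off);
		for (i = 0; i < off; i++)
			printf(" %02x", hdr[i]);
		printf("\n");
		return 0;
	}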
See net/ipv4/udp.c */ + WRITE_ONCE(udp_sk(sk)->encap_type, 0); + udp_sk(sk)->encap_rcv = NULL; + udp_sk(sk)->encap_destroy = NULL; + break; + case L2TP_ENCAPTYPE_IP: + break; + } + + /* Remove hooks into tunnel socket */ + write_lock_bh(&sk->sk_callback_lock); + sk->sk_destruct = tunnel->old_sk_destruct; + sk->sk_user_data = NULL; + write_unlock_bh(&sk->sk_callback_lock); + + /* Call the original destructor */ + if (sk->sk_destruct) + (*sk->sk_destruct)(sk); + + kfree_rcu(tunnel, rcu); +end: + return; +} + +/* Remove an l2tp session from l2tp_core's hash lists. */ +static void l2tp_session_unhash(struct l2tp_session *session) +{ + struct l2tp_tunnel *tunnel = session->tunnel; + + /* Remove the session from core hashes */ + if (tunnel) { + /* Remove from the per-tunnel hash */ + spin_lock_bh(&tunnel->hlist_lock); + hlist_del_init_rcu(&session->hlist); + spin_unlock_bh(&tunnel->hlist_lock); + + /* For L2TPv3 we have a per-net hash: remove from there, too */ + if (tunnel->version != L2TP_HDR_VER_2) { + struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net); + + spin_lock_bh(&pn->l2tp_session_hlist_lock); + hlist_del_init_rcu(&session->global_hlist); + spin_unlock_bh(&pn->l2tp_session_hlist_lock); + } + + synchronize_rcu(); + } +} + +/* When the tunnel is closed, all the attached sessions need to go too. + */ +static void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) +{ + struct l2tp_session *session; + int hash; + + spin_lock_bh(&tunnel->hlist_lock); + tunnel->acpt_newsess = false; + for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { +again: + hlist_for_each_entry_rcu(session, &tunnel->session_hlist[hash], hlist) { + hlist_del_init_rcu(&session->hlist); + + spin_unlock_bh(&tunnel->hlist_lock); + l2tp_session_delete(session); + spin_lock_bh(&tunnel->hlist_lock); + + /* Now restart from the beginning of this hash + * chain. We always remove a session from the + * list so we are guaranteed to make forward + * progress. + */ + goto again; + } + } + spin_unlock_bh(&tunnel->hlist_lock); +} + +/* Tunnel socket destroy hook for UDP encapsulation */ +static void l2tp_udp_encap_destroy(struct sock *sk) +{ + struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + + if (tunnel) + l2tp_tunnel_delete(tunnel); +} + +static void l2tp_tunnel_remove(struct net *net, struct l2tp_tunnel *tunnel) +{ + struct l2tp_net *pn = l2tp_pernet(net); + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_remove(&pn->l2tp_tunnel_idr, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); +} + +/* Workqueue tunnel deletion function */ +static void l2tp_tunnel_del_work(struct work_struct *work) +{ + struct l2tp_tunnel *tunnel = container_of(work, struct l2tp_tunnel, + del_work); + struct sock *sk = tunnel->sock; + struct socket *sock = sk->sk_socket; + + l2tp_tunnel_closeall(tunnel); + + /* If the tunnel socket was created within the kernel, use + * the sk API to release it here. + */ + if (tunnel->fd < 0) { + if (sock) { + kernel_sock_shutdown(sock, SHUT_RDWR); + sock_release(sock); + } + } + + l2tp_tunnel_remove(tunnel->l2tp_net, tunnel); + /* drop initial ref */ + l2tp_tunnel_dec_refcount(tunnel); + + /* drop workqueue ref */ + l2tp_tunnel_dec_refcount(tunnel); +} + +/* Create a socket for the tunnel, if one isn't set up by + * userspace. This is used for static tunnels where there is no + * managing L2TP daemon. + * + * Since we don't want these sockets to keep a namespace alive by + * themselves, we drop the socket's namespace refcount after creation. 
+ * These sockets are freed when the namespace exits using the pernet + * exit hook. + */ +static int l2tp_tunnel_sock_create(struct net *net, + u32 tunnel_id, + u32 peer_tunnel_id, + struct l2tp_tunnel_cfg *cfg, + struct socket **sockp) +{ + int err = -EINVAL; + struct socket *sock = NULL; + struct udp_port_cfg udp_conf; + + switch (cfg->encap) { + case L2TP_ENCAPTYPE_UDP: + memset(&udp_conf, 0, sizeof(udp_conf)); + +#if IS_ENABLED(CONFIG_IPV6) + if (cfg->local_ip6 && cfg->peer_ip6) { + udp_conf.family = AF_INET6; + memcpy(&udp_conf.local_ip6, cfg->local_ip6, + sizeof(udp_conf.local_ip6)); + memcpy(&udp_conf.peer_ip6, cfg->peer_ip6, + sizeof(udp_conf.peer_ip6)); + udp_conf.use_udp6_tx_checksums = + !cfg->udp6_zero_tx_checksums; + udp_conf.use_udp6_rx_checksums = + !cfg->udp6_zero_rx_checksums; + } else +#endif + { + udp_conf.family = AF_INET; + udp_conf.local_ip = cfg->local_ip; + udp_conf.peer_ip = cfg->peer_ip; + udp_conf.use_udp_checksums = cfg->use_udp_checksums; + } + + udp_conf.local_udp_port = htons(cfg->local_udp_port); + udp_conf.peer_udp_port = htons(cfg->peer_udp_port); + + err = udp_sock_create(net, &udp_conf, &sock); + if (err < 0) + goto out; + + break; + + case L2TP_ENCAPTYPE_IP: +#if IS_ENABLED(CONFIG_IPV6) + if (cfg->local_ip6 && cfg->peer_ip6) { + struct sockaddr_l2tpip6 ip6_addr = {0}; + + err = sock_create_kern(net, AF_INET6, SOCK_DGRAM, + IPPROTO_L2TP, &sock); + if (err < 0) + goto out; + + ip6_addr.l2tp_family = AF_INET6; + memcpy(&ip6_addr.l2tp_addr, cfg->local_ip6, + sizeof(ip6_addr.l2tp_addr)); + ip6_addr.l2tp_conn_id = tunnel_id; + err = kernel_bind(sock, (struct sockaddr *)&ip6_addr, + sizeof(ip6_addr)); + if (err < 0) + goto out; + + ip6_addr.l2tp_family = AF_INET6; + memcpy(&ip6_addr.l2tp_addr, cfg->peer_ip6, + sizeof(ip6_addr.l2tp_addr)); + ip6_addr.l2tp_conn_id = peer_tunnel_id; + err = kernel_connect(sock, + (struct sockaddr *)&ip6_addr, + sizeof(ip6_addr), 0); + if (err < 0) + goto out; + } else +#endif + { + struct sockaddr_l2tpip ip_addr = {0}; + + err = sock_create_kern(net, AF_INET, SOCK_DGRAM, + IPPROTO_L2TP, &sock); + if (err < 0) + goto out; + + ip_addr.l2tp_family = AF_INET; + ip_addr.l2tp_addr = cfg->local_ip; + ip_addr.l2tp_conn_id = tunnel_id; + err = kernel_bind(sock, (struct sockaddr *)&ip_addr, + sizeof(ip_addr)); + if (err < 0) + goto out; + + ip_addr.l2tp_family = AF_INET; + ip_addr.l2tp_addr = cfg->peer_ip; + ip_addr.l2tp_conn_id = peer_tunnel_id; + err = kernel_connect(sock, (struct sockaddr *)&ip_addr, + sizeof(ip_addr), 0); + if (err < 0) + goto out; + } + break; + + default: + goto out; + } + +out: + *sockp = sock; + if (err < 0 && sock) { + kernel_sock_shutdown(sock, SHUT_RDWR); + sock_release(sock); + *sockp = NULL; + } + + return err; +} + +int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, u32 peer_tunnel_id, + struct l2tp_tunnel_cfg *cfg, struct l2tp_tunnel **tunnelp) +{ + struct l2tp_tunnel *tunnel = NULL; + int err; + enum l2tp_encap_type encap = L2TP_ENCAPTYPE_UDP; + + if (cfg) + encap = cfg->encap; + + tunnel = kzalloc(sizeof(*tunnel), GFP_KERNEL); + if (!tunnel) { + err = -ENOMEM; + goto err; + } + + tunnel->version = version; + tunnel->tunnel_id = tunnel_id; + tunnel->peer_tunnel_id = peer_tunnel_id; + + tunnel->magic = L2TP_TUNNEL_MAGIC; + sprintf(&tunnel->name[0], "tunl %u", tunnel_id); + spin_lock_init(&tunnel->hlist_lock); + tunnel->acpt_newsess = true; + + tunnel->encap = encap; + + refcount_set(&tunnel->ref_count, 1); + tunnel->fd = fd; + + /* Init delete workqueue struct */ + INIT_WORK(&tunnel->del_work, 
l2tp_tunnel_del_work); + + INIT_LIST_HEAD(&tunnel->list); + + err = 0; +err: + if (tunnelp) + *tunnelp = tunnel; + + return err; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_create); + +static int l2tp_validate_socket(const struct sock *sk, const struct net *net, + enum l2tp_encap_type encap) +{ + if (!net_eq(sock_net(sk), net)) + return -EINVAL; + + if (sk->sk_type != SOCK_DGRAM) + return -EPROTONOSUPPORT; + + if (sk->sk_family != PF_INET && sk->sk_family != PF_INET6) + return -EPROTONOSUPPORT; + + if ((encap == L2TP_ENCAPTYPE_UDP && sk->sk_protocol != IPPROTO_UDP) || + (encap == L2TP_ENCAPTYPE_IP && sk->sk_protocol != IPPROTO_L2TP)) + return -EPROTONOSUPPORT; + + if (sk->sk_user_data) + return -EBUSY; + + return 0; +} + +int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, + struct l2tp_tunnel_cfg *cfg) +{ + struct l2tp_net *pn = l2tp_pernet(net); + u32 tunnel_id = tunnel->tunnel_id; + struct socket *sock; + struct sock *sk; + int ret; + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + ret = idr_alloc_u32(&pn->l2tp_tunnel_idr, NULL, &tunnel_id, tunnel_id, + GFP_ATOMIC); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); + if (ret) + return ret == -ENOSPC ? -EEXIST : ret; + + if (tunnel->fd < 0) { + ret = l2tp_tunnel_sock_create(net, tunnel->tunnel_id, + tunnel->peer_tunnel_id, cfg, + &sock); + if (ret < 0) + goto err; + } else { + sock = sockfd_lookup(tunnel->fd, &ret); + if (!sock) + goto err; + } + + sk = sock->sk; + lock_sock(sk); + write_lock_bh(&sk->sk_callback_lock); + ret = l2tp_validate_socket(sk, net, tunnel->encap); + if (ret < 0) + goto err_inval_sock; + rcu_assign_sk_user_data(sk, tunnel); + write_unlock_bh(&sk->sk_callback_lock); + + if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { + struct udp_tunnel_sock_cfg udp_cfg = { + .sk_user_data = tunnel, + .encap_type = UDP_ENCAP_L2TPINUDP, + .encap_rcv = l2tp_udp_encap_recv, + .encap_destroy = l2tp_udp_encap_destroy, + }; + + setup_udp_tunnel_sock(net, sock, &udp_cfg); + } + + tunnel->old_sk_destruct = sk->sk_destruct; + sk->sk_destruct = &l2tp_tunnel_destruct; + sk->sk_allocation = GFP_ATOMIC; + release_sock(sk); + + sock_hold(sk); + tunnel->sock = sk; + tunnel->l2tp_net = net; + + spin_lock_bh(&pn->l2tp_tunnel_idr_lock); + idr_replace(&pn->l2tp_tunnel_idr, tunnel, tunnel->tunnel_id); + spin_unlock_bh(&pn->l2tp_tunnel_idr_lock); + + trace_register_tunnel(tunnel); + + if (tunnel->fd >= 0) + sockfd_put(sock); + + return 0; + +err_inval_sock: + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + + if (tunnel->fd < 0) + sock_release(sock); + else + sockfd_put(sock); +err: + l2tp_tunnel_remove(net, tunnel); + return ret; +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_register); + +/* This function is used by the netlink TUNNEL_DELETE command. + */ +void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel) +{ + if (!test_and_set_bit(0, &tunnel->dead)) { + trace_delete_tunnel(tunnel); + l2tp_tunnel_inc_refcount(tunnel); + queue_work(l2tp_wq, &tunnel->del_work); + } +} +EXPORT_SYMBOL_GPL(l2tp_tunnel_delete); + +void l2tp_session_delete(struct l2tp_session *session) +{ + if (test_and_set_bit(0, &session->dead)) + return; + + trace_delete_session(session); + l2tp_session_unhash(session); + l2tp_session_queue_purge(session); + if (session->session_close) + (*session->session_close)(session); + + l2tp_session_dec_refcount(session); +} +EXPORT_SYMBOL_GPL(l2tp_session_delete); + +/* We come here whenever a session's send_seq, cookie_len or + * l2specific_type parameters are set. 
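For orientation, a worked example (not text from the original source): an L2TPv2 session with send_seq enabled gets 6 + 4 = 10 header bytes, while an L2TPv3 session over UDP with an 8-byte cookie and the default L2-specific sublayer gets 4 + 8 + 4 + 4 = 20 bytes.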
+ */ +void l2tp_session_set_header_len(struct l2tp_session *session, int version) +{ + if (version == L2TP_HDR_VER_2) { + session->hdr_len = 6; + if (session->send_seq) + session->hdr_len += 4; + } else { + session->hdr_len = 4 + session->cookie_len; + session->hdr_len += l2tp_get_l2specific_len(session); + if (session->tunnel->encap == L2TP_ENCAPTYPE_UDP) + session->hdr_len += 4; + } +} +EXPORT_SYMBOL_GPL(l2tp_session_set_header_len); + +struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, + u32 peer_session_id, struct l2tp_session_cfg *cfg) +{ + struct l2tp_session *session; + + session = kzalloc(sizeof(*session) + priv_size, GFP_KERNEL); + if (session) { + session->magic = L2TP_SESSION_MAGIC; + session->tunnel = tunnel; + + session->session_id = session_id; + session->peer_session_id = peer_session_id; + session->nr = 0; + if (tunnel->version == L2TP_HDR_VER_2) + session->nr_max = 0xffff; + else + session->nr_max = 0xffffff; + session->nr_window_size = session->nr_max / 2; + session->nr_oos_count_max = 4; + + /* Use NR of first received packet */ + session->reorder_skip = 1; + + sprintf(&session->name[0], "sess %u/%u", + tunnel->tunnel_id, session->session_id); + + skb_queue_head_init(&session->reorder_q); + + INIT_HLIST_NODE(&session->hlist); + INIT_HLIST_NODE(&session->global_hlist); + + if (cfg) { + session->pwtype = cfg->pw_type; + session->send_seq = cfg->send_seq; + session->recv_seq = cfg->recv_seq; + session->lns_mode = cfg->lns_mode; + session->reorder_timeout = cfg->reorder_timeout; + session->l2specific_type = cfg->l2specific_type; + session->cookie_len = cfg->cookie_len; + memcpy(&session->cookie[0], &cfg->cookie[0], cfg->cookie_len); + session->peer_cookie_len = cfg->peer_cookie_len; + memcpy(&session->peer_cookie[0], &cfg->peer_cookie[0], cfg->peer_cookie_len); + } + + l2tp_session_set_header_len(session, tunnel->version); + + refcount_set(&session->ref_count, 1); + + return session; + } + + return ERR_PTR(-ENOMEM); +} +EXPORT_SYMBOL_GPL(l2tp_session_create); + +/***************************************************************************** + * Init and cleanup + *****************************************************************************/ + +static __net_init int l2tp_init_net(struct net *net) +{ + struct l2tp_net *pn = net_generic(net, l2tp_net_id); + int hash; + + idr_init(&pn->l2tp_tunnel_idr); + spin_lock_init(&pn->l2tp_tunnel_idr_lock); + + for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) + INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]); + + spin_lock_init(&pn->l2tp_session_hlist_lock); + + return 0; +} + +static __net_exit void l2tp_exit_net(struct net *net) +{ + struct l2tp_net *pn = l2tp_pernet(net); + struct l2tp_tunnel *tunnel = NULL; + unsigned long tunnel_id, tmp; + int hash; + + rcu_read_lock_bh(); + idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { + if (tunnel) + l2tp_tunnel_delete(tunnel); + } + rcu_read_unlock_bh(); + + if (l2tp_wq) + flush_workqueue(l2tp_wq); + rcu_barrier(); + + for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) + WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash])); + idr_destroy(&pn->l2tp_tunnel_idr); +} + +static struct pernet_operations l2tp_net_ops = { + .init = l2tp_init_net, + .exit = l2tp_exit_net, + .id = &l2tp_net_id, + .size = sizeof(struct l2tp_net), +}; + +static int __init l2tp_init(void) +{ + int rc = 0; + + rc = register_pernet_device(&l2tp_net_ops); + if (rc) + goto out; + + l2tp_wq = alloc_workqueue("l2tp", WQ_UNBOUND, 0); + if (!l2tp_wq) { + 
pr_err("alloc_workqueue failed\n"); + unregister_pernet_device(&l2tp_net_ops); + rc = -ENOMEM; + goto out; + } + + pr_info("L2TP core driver, %s\n", L2TP_DRV_VERSION); + +out: + return rc; +} + +static void __exit l2tp_exit(void) +{ + unregister_pernet_device(&l2tp_net_ops); + if (l2tp_wq) { + destroy_workqueue(l2tp_wq); + l2tp_wq = NULL; + } +} + +module_init(l2tp_init); +module_exit(l2tp_exit); + +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("L2TP core"); +MODULE_LICENSE("GPL"); +MODULE_VERSION(L2TP_DRV_VERSION); diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h new file mode 100644 index 000000000..a88e070b4 --- /dev/null +++ b/net/l2tp/l2tp_core.h @@ -0,0 +1,346 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* L2TP internal definitions. + * + * Copyright (c) 2008,2009 Katalix Systems Ltd + */ +#include + +#ifndef _L2TP_CORE_H_ +#define _L2TP_CORE_H_ + +#include +#include + +#ifdef CONFIG_XFRM +#include +#endif + +/* Random numbers used for internal consistency checks of tunnel and session structures */ +#define L2TP_TUNNEL_MAGIC 0x42114DDA +#define L2TP_SESSION_MAGIC 0x0C04EB7D + +/* Per tunnel session hash table size */ +#define L2TP_HASH_BITS 4 +#define L2TP_HASH_SIZE BIT(L2TP_HASH_BITS) + +/* System-wide session hash table size */ +#define L2TP_HASH_BITS_2 8 +#define L2TP_HASH_SIZE_2 BIT(L2TP_HASH_BITS_2) + +struct sk_buff; + +struct l2tp_stats { + atomic_long_t tx_packets; + atomic_long_t tx_bytes; + atomic_long_t tx_errors; + atomic_long_t rx_packets; + atomic_long_t rx_bytes; + atomic_long_t rx_seq_discards; + atomic_long_t rx_oos_packets; + atomic_long_t rx_errors; + atomic_long_t rx_cookie_discards; + atomic_long_t rx_invalid; +}; + +struct l2tp_tunnel; + +/* L2TP session configuration */ +struct l2tp_session_cfg { + enum l2tp_pwtype pw_type; + unsigned int recv_seq:1; /* expect receive packets with sequence numbers? */ + unsigned int send_seq:1; /* send packets with sequence numbers? */ + unsigned int lns_mode:1; /* behave as LNS? + * LAC enables sequence numbers under LNS control. + */ + u16 l2specific_type; /* Layer 2 specific type */ + u8 cookie[8]; /* optional cookie */ + int cookie_len; /* 0, 4 or 8 bytes */ + u8 peer_cookie[8]; /* peer's cookie */ + int peer_cookie_len; /* 0, 4 or 8 bytes */ + int reorder_timeout; /* configured reorder timeout (in jiffies) */ + char *ifname; +}; + +/* Represents a session (pseudowire) instance. + * Tracks runtime state including cookies, dataplane packet sequencing, and IO statistics. + * Is linked into a per-tunnel session hashlist; and in the case of an L2TPv3 session into + * an additional per-net ("global") hashlist. + */ +#define L2TP_SESSION_NAME_MAX 32 +struct l2tp_session { + int magic; /* should be L2TP_SESSION_MAGIC */ + long dead; + + struct l2tp_tunnel *tunnel; /* back pointer to tunnel context */ + u32 session_id; + u32 peer_session_id; + u8 cookie[8]; + int cookie_len; + u8 peer_cookie[8]; + int peer_cookie_len; + u16 l2specific_type; + u16 hdr_len; + u32 nr; /* session NR state (receive) */ + u32 ns; /* session NR state (send) */ + struct sk_buff_head reorder_q; /* receive reorder queue */ + u32 nr_max; /* max NR. 
Depends on tunnel */ + u32 nr_window_size; /* NR window size */ + u32 nr_oos; /* NR of last OOS packet */ + int nr_oos_count; /* for OOS recovery */ + int nr_oos_count_max; + struct hlist_node hlist; /* hash list node */ + refcount_t ref_count; + + char name[L2TP_SESSION_NAME_MAX]; /* for logging */ + char ifname[IFNAMSIZ]; + unsigned int recv_seq:1; /* expect receive packets with sequence numbers? */ + unsigned int send_seq:1; /* send packets with sequence numbers? */ + unsigned int lns_mode:1; /* behave as LNS? + * LAC enables sequence numbers under LNS control. + */ + int reorder_timeout; /* configured reorder timeout (in jiffies) */ + int reorder_skip; /* set if skip to next nr */ + enum l2tp_pwtype pwtype; + struct l2tp_stats stats; + struct hlist_node global_hlist; /* global hash list node */ + + /* Session receive handler for data packets. + * Each pseudowire implementation should implement this callback in order to + * handle incoming packets. Packets are passed to the pseudowire handler after + * reordering, if data sequence numbers are enabled for the session. + */ + void (*recv_skb)(struct l2tp_session *session, struct sk_buff *skb, int data_len); + + /* Session close handler. + * Each pseudowire implementation may implement this callback in order to carry + * out pseudowire-specific shutdown actions. + * The callback is called by core after unhashing the session and purging its + * reorder queue. + */ + void (*session_close)(struct l2tp_session *session); + + /* Session show handler. + * Pseudowire-specific implementation of debugfs session rendering. + * The callback is called by l2tp_debugfs.c after rendering core session + * information. + */ + void (*show)(struct seq_file *m, void *priv); + + u8 priv[]; /* private data */ +}; + +/* L2TP tunnel configuration */ +struct l2tp_tunnel_cfg { + enum l2tp_encap_type encap; + + /* Used only for kernel-created sockets */ + struct in_addr local_ip; + struct in_addr peer_ip; +#if IS_ENABLED(CONFIG_IPV6) + struct in6_addr *local_ip6; + struct in6_addr *peer_ip6; +#endif + u16 local_udp_port; + u16 peer_udp_port; + unsigned int use_udp_checksums:1, + udp6_zero_tx_checksums:1, + udp6_zero_rx_checksums:1; +}; + +/* Represents a tunnel instance. + * Tracks runtime state including IO statistics. + * Holds the tunnel socket (either passed from userspace or directly created by the kernel). + * Maintains a hashlist of sessions belonging to the tunnel instance. + * Is linked into a per-net list of tunnels. + */ +#define L2TP_TUNNEL_NAME_MAX 20 +struct l2tp_tunnel { + int magic; /* Should be L2TP_TUNNEL_MAGIC */ + + unsigned long dead; + + struct rcu_head rcu; + spinlock_t hlist_lock; /* write-protection for session_hlist */ + bool acpt_newsess; /* indicates whether this tunnel accepts + * new sessions. Protected by hlist_lock. 
+ */ + struct hlist_head session_hlist[L2TP_HASH_SIZE]; + /* hashed list of sessions, hashed by id */ + u32 tunnel_id; + u32 peer_tunnel_id; + int version; /* 2=>L2TPv2, 3=>L2TPv3 */ + + char name[L2TP_TUNNEL_NAME_MAX]; /* for logging */ + enum l2tp_encap_type encap; + struct l2tp_stats stats; + + struct list_head list; /* list node on per-namespace list of tunnels */ + struct net *l2tp_net; /* the net we belong to */ + + refcount_t ref_count; + void (*old_sk_destruct)(struct sock *sk); + struct sock *sock; /* parent socket */ + int fd; /* parent fd, if tunnel socket was created + * by userspace + */ + + struct work_struct del_work; +}; + +/* Pseudowire ops callbacks for use with the l2tp genetlink interface */ +struct l2tp_nl_cmd_ops { + /* The pseudowire session create callback is responsible for creating a session + * instance for a specific pseudowire type. + * It must call l2tp_session_create and l2tp_session_register to register the + * session instance, as well as carry out any pseudowire-specific initialisation. + * It must return >= 0 on success, or an appropriate negative errno value on failure. + */ + int (*session_create)(struct net *net, struct l2tp_tunnel *tunnel, + u32 session_id, u32 peer_session_id, + struct l2tp_session_cfg *cfg); + + /* The pseudowire session delete callback is responsible for initiating the deletion + * of a session instance. + * It must call l2tp_session_delete, as well as carry out any pseudowire-specific + * teardown actions. + */ + void (*session_delete)(struct l2tp_session *session); +}; + +static inline void *l2tp_session_priv(struct l2tp_session *session) +{ + return &session->priv[0]; +} + +/* Tunnel and session refcounts */ +void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel); +void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel); +void l2tp_session_inc_refcount(struct l2tp_session *session); +void l2tp_session_dec_refcount(struct l2tp_session *session); + +/* Tunnel and session lookup. + * These functions take a reference on the instances they return, so + * the caller must ensure that the reference is dropped appropriately. + */ +struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id); +struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth); +struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, + u32 session_id); + +struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id); +struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth); +struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, + const char *ifname); + +/* Tunnel and session lifetime management. + * Creation of a new instance is a two-step process: create, then register. + * Destruction is triggered using the *_delete functions, and completes asynchronously. + */ +int l2tp_tunnel_create(int fd, int version, u32 tunnel_id, + u32 peer_tunnel_id, struct l2tp_tunnel_cfg *cfg, + struct l2tp_tunnel **tunnelp); +int l2tp_tunnel_register(struct l2tp_tunnel *tunnel, struct net *net, + struct l2tp_tunnel_cfg *cfg); +void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel); + +struct l2tp_session *l2tp_session_create(int priv_size, + struct l2tp_tunnel *tunnel, + u32 session_id, u32 peer_session_id, + struct l2tp_session_cfg *cfg); +int l2tp_session_register(struct l2tp_session *session, + struct l2tp_tunnel *tunnel); +void l2tp_session_delete(struct l2tp_session *session); + +/* Receive path helpers. 
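Tying the two-step create/register pattern above to the l2tp_nl_cmd_ops callbacks, here is a hypothetical minimal pseudowire sketch (an assumption about how a typical driver uses this API, not code from this header; error handling trimmed):

	static void my_pw_recv(struct l2tp_session *session,
			       struct sk_buff *skb, int data_len)
	{
		// A real pseudowire would hand the de-encapsulated frame to a
		// netdev or socket here; this sketch just drops it.
		kfree_skb(skb);
	}

	static int my_pw_session_create(struct net *net, struct l2tp_tunnel *tunnel,
					u32 session_id, u32 peer_session_id,
					struct l2tp_session_cfg *cfg)
	{
		struct l2tp_session *session;
		int err;

		session = l2tp_session_create(0, tunnel, session_id,
					      peer_session_id, cfg);
		if (IS_ERR(session))
			return PTR_ERR(session);

		session->recv_skb = my_pw_recv;

		err = l2tp_session_register(session, tunnel);
		if (err)
			l2tp_session_dec_refcount(session);

		return err;
	}

	static const struct l2tp_nl_cmd_ops my_pw_ops = {
		.session_create	= my_pw_session_create,
		// .session_delete omitted from this sketch
	};
	// A module init function would then pass &my_pw_ops to
	// l2tp_nl_register_ops() with the driver's pseudowire type,
	// and unregister it again on module exit.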
If data sequencing is enabled for the session these + * functions handle queuing and reordering prior to passing packets to the + * pseudowire code to be passed to userspace. + */ +void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, + unsigned char *ptr, unsigned char *optr, u16 hdrflags, + int length); +int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb); + +/* Transmit path helpers for sending packets over the tunnel socket. */ +void l2tp_session_set_header_len(struct l2tp_session *session, int version); +int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb); + +/* Pseudowire management. + * Pseudowires should register with l2tp core on module init, and unregister + * on module exit. + */ +int l2tp_nl_register_ops(enum l2tp_pwtype pw_type, const struct l2tp_nl_cmd_ops *ops); +void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type); + +/* IOCTL helper for IP encap modules. */ +int l2tp_ioctl(struct sock *sk, int cmd, unsigned long arg); + +/* Extract the tunnel structure from a socket's sk_user_data pointer, + * validating the tunnel magic feather. + */ +struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk); + +static inline int l2tp_get_l2specific_len(struct l2tp_session *session) +{ + switch (session->l2specific_type) { + case L2TP_L2SPECTYPE_DEFAULT: + return 4; + case L2TP_L2SPECTYPE_NONE: + default: + return 0; + } +} + +static inline u32 l2tp_tunnel_dst_mtu(const struct l2tp_tunnel *tunnel) +{ + struct dst_entry *dst; + u32 mtu; + + dst = sk_dst_get(tunnel->sock); + if (!dst) + return 0; + + mtu = dst_mtu(dst); + dst_release(dst); + + return mtu; +} + +#ifdef CONFIG_XFRM +static inline bool l2tp_tunnel_uses_xfrm(const struct l2tp_tunnel *tunnel) +{ + struct sock *sk = tunnel->sock; + + return sk && (rcu_access_pointer(sk->sk_policy[0]) || + rcu_access_pointer(sk->sk_policy[1])); +} +#else +static inline bool l2tp_tunnel_uses_xfrm(const struct l2tp_tunnel *tunnel) +{ + return false; +} +#endif + +static inline int l2tp_v3_ensure_opt_in_linear(struct l2tp_session *session, struct sk_buff *skb, + unsigned char **ptr, unsigned char **optr) +{ + int opt_len = session->peer_cookie_len + l2tp_get_l2specific_len(session); + + if (opt_len > 0) { + int off = *ptr - *optr; + + if (!pskb_may_pull(skb, off + opt_len)) + return -1; + + if (skb->data != *optr) { + *optr = skb->data; + *ptr = skb->data + off; + } + } + + return 0; +} + +#define MODULE_ALIAS_L2TP_PWTYPE(type) \ + MODULE_ALIAS("net-l2tp-type-" __stringify(type)) + +#endif /* _L2TP_CORE_H_ */ diff --git a/net/l2tp/l2tp_debugfs.c b/net/l2tp/l2tp_debugfs.c new file mode 100644 index 000000000..4595b56d1 --- /dev/null +++ b/net/l2tp/l2tp_debugfs.c @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* L2TP subsystem debugfs + * + * Copyright (c) 2010 Katalix Systems Ltd + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "l2tp_core.h" + +static struct dentry *rootdir; + +struct l2tp_dfs_seq_data { + struct net *net; + netns_tracker ns_tracker; + int tunnel_idx; /* current tunnel */ + int session_idx; /* index of session within current tunnel */ + struct l2tp_tunnel *tunnel; + struct l2tp_session *session; /* NULL means get next tunnel */ +}; + +static void l2tp_dfs_next_tunnel(struct l2tp_dfs_seq_data *pd) +{ + /* Drop reference taken during previous invocation */ + if 
(pd->tunnel) + l2tp_tunnel_dec_refcount(pd->tunnel); + + pd->tunnel = l2tp_tunnel_get_nth(pd->net, pd->tunnel_idx); + pd->tunnel_idx++; +} + +static void l2tp_dfs_next_session(struct l2tp_dfs_seq_data *pd) +{ + /* Drop reference taken during previous invocation */ + if (pd->session) + l2tp_session_dec_refcount(pd->session); + + pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx); + pd->session_idx++; + + if (!pd->session) { + pd->session_idx = 0; + l2tp_dfs_next_tunnel(pd); + } +} + +static void *l2tp_dfs_seq_start(struct seq_file *m, loff_t *offs) +{ + struct l2tp_dfs_seq_data *pd = SEQ_START_TOKEN; + loff_t pos = *offs; + + if (!pos) + goto out; + + if (WARN_ON(!m->private)) { + pd = NULL; + goto out; + } + pd = m->private; + + if (!pd->tunnel) + l2tp_dfs_next_tunnel(pd); + else + l2tp_dfs_next_session(pd); + + /* NULL tunnel and session indicates end of list */ + if (!pd->tunnel && !pd->session) + pd = NULL; + +out: + return pd; +} + +static void *l2tp_dfs_seq_next(struct seq_file *m, void *v, loff_t *pos) +{ + (*pos)++; + return NULL; +} + +static void l2tp_dfs_seq_stop(struct seq_file *p, void *v) +{ + struct l2tp_dfs_seq_data *pd = v; + + if (!pd || pd == SEQ_START_TOKEN) + return; + + /* Drop reference taken by last invocation of l2tp_dfs_next_session() + * or l2tp_dfs_next_tunnel(). + */ + if (pd->session) { + l2tp_session_dec_refcount(pd->session); + pd->session = NULL; + } + if (pd->tunnel) { + l2tp_tunnel_dec_refcount(pd->tunnel); + pd->tunnel = NULL; + } +} + +static void l2tp_dfs_seq_tunnel_show(struct seq_file *m, void *v) +{ + struct l2tp_tunnel *tunnel = v; + struct l2tp_session *session; + int session_count = 0; + int hash; + + rcu_read_lock_bh(); + for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { + hlist_for_each_entry_rcu(session, &tunnel->session_hlist[hash], hlist) { + /* Session ID of zero is a dummy/reserved value used by pppol2tp */ + if (session->session_id == 0) + continue; + + session_count++; + } + } + rcu_read_unlock_bh(); + + seq_printf(m, "\nTUNNEL %u peer %u", tunnel->tunnel_id, tunnel->peer_tunnel_id); + if (tunnel->sock) { + struct inet_sock *inet = inet_sk(tunnel->sock); + +#if IS_ENABLED(CONFIG_IPV6) + if (tunnel->sock->sk_family == AF_INET6) { + const struct ipv6_pinfo *np = inet6_sk(tunnel->sock); + + seq_printf(m, " from %pI6c to %pI6c\n", + &np->saddr, &tunnel->sock->sk_v6_daddr); + } +#endif + if (tunnel->sock->sk_family == AF_INET) + seq_printf(m, " from %pI4 to %pI4\n", + &inet->inet_saddr, &inet->inet_daddr); + + if (tunnel->encap == L2TP_ENCAPTYPE_UDP) + seq_printf(m, " source port %hu, dest port %hu\n", + ntohs(inet->inet_sport), ntohs(inet->inet_dport)); + } + seq_printf(m, " L2TPv%d, %s\n", tunnel->version, + tunnel->encap == L2TP_ENCAPTYPE_UDP ? "UDP" : + tunnel->encap == L2TP_ENCAPTYPE_IP ? "IP" : + ""); + seq_printf(m, " %d sessions, refcnt %d/%d\n", session_count, + tunnel->sock ? 
refcount_read(&tunnel->sock->sk_refcnt) : 0, + refcount_read(&tunnel->ref_count)); + seq_printf(m, " %08x rx %ld/%ld/%ld rx %ld/%ld/%ld\n", + 0, + atomic_long_read(&tunnel->stats.tx_packets), + atomic_long_read(&tunnel->stats.tx_bytes), + atomic_long_read(&tunnel->stats.tx_errors), + atomic_long_read(&tunnel->stats.rx_packets), + atomic_long_read(&tunnel->stats.rx_bytes), + atomic_long_read(&tunnel->stats.rx_errors)); +} + +static void l2tp_dfs_seq_session_show(struct seq_file *m, void *v) +{ + struct l2tp_session *session = v; + + seq_printf(m, " SESSION %u, peer %u, %s\n", session->session_id, + session->peer_session_id, + session->pwtype == L2TP_PWTYPE_ETH ? "ETH" : + session->pwtype == L2TP_PWTYPE_PPP ? "PPP" : + ""); + if (session->send_seq || session->recv_seq) + seq_printf(m, " nr %u, ns %u\n", session->nr, session->ns); + seq_printf(m, " refcnt %d\n", refcount_read(&session->ref_count)); + seq_printf(m, " config 0/0/%c/%c/-/%s %08x %u\n", + session->recv_seq ? 'R' : '-', + session->send_seq ? 'S' : '-', + session->lns_mode ? "LNS" : "LAC", + 0, + jiffies_to_msecs(session->reorder_timeout)); + seq_printf(m, " offset 0 l2specific %hu/%d\n", + session->l2specific_type, l2tp_get_l2specific_len(session)); + if (session->cookie_len) { + seq_printf(m, " cookie %02x%02x%02x%02x", + session->cookie[0], session->cookie[1], + session->cookie[2], session->cookie[3]); + if (session->cookie_len == 8) + seq_printf(m, "%02x%02x%02x%02x", + session->cookie[4], session->cookie[5], + session->cookie[6], session->cookie[7]); + seq_puts(m, "\n"); + } + if (session->peer_cookie_len) { + seq_printf(m, " peer cookie %02x%02x%02x%02x", + session->peer_cookie[0], session->peer_cookie[1], + session->peer_cookie[2], session->peer_cookie[3]); + if (session->peer_cookie_len == 8) + seq_printf(m, "%02x%02x%02x%02x", + session->peer_cookie[4], session->peer_cookie[5], + session->peer_cookie[6], session->peer_cookie[7]); + seq_puts(m, "\n"); + } + + seq_printf(m, " %u/%u tx %ld/%ld/%ld rx %ld/%ld/%ld\n", + session->nr, session->ns, + atomic_long_read(&session->stats.tx_packets), + atomic_long_read(&session->stats.tx_bytes), + atomic_long_read(&session->stats.tx_errors), + atomic_long_read(&session->stats.rx_packets), + atomic_long_read(&session->stats.rx_bytes), + atomic_long_read(&session->stats.rx_errors)); + + if (session->show) + session->show(m, session); +} + +static int l2tp_dfs_seq_show(struct seq_file *m, void *v) +{ + struct l2tp_dfs_seq_data *pd = v; + + /* display header on line 1 */ + if (v == SEQ_START_TOKEN) { + seq_puts(m, "TUNNEL ID, peer ID from IP to IP\n"); + seq_puts(m, " L2TPv2/L2TPv3, UDP/IP\n"); + seq_puts(m, " sessions session-count, refcnt refcnt/sk->refcnt\n"); + seq_puts(m, " debug tx-pkts/bytes/errs rx-pkts/bytes/errs\n"); + seq_puts(m, " SESSION ID, peer ID, PWTYPE\n"); + seq_puts(m, " refcnt cnt\n"); + seq_puts(m, " offset OFFSET l2specific TYPE/LEN\n"); + seq_puts(m, " [ cookie ]\n"); + seq_puts(m, " [ peer cookie ]\n"); + seq_puts(m, " config mtu/mru/rcvseq/sendseq/dataseq/lns debug reorderto\n"); + seq_puts(m, " nr/ns tx-pkts/bytes/errs rx-pkts/bytes/errs\n"); + goto out; + } + + if (!pd->session) + l2tp_dfs_seq_tunnel_show(m, pd->tunnel); + else + l2tp_dfs_seq_session_show(m, pd->session); + +out: + return 0; +} + +static const struct seq_operations l2tp_dfs_seq_ops = { + .start = l2tp_dfs_seq_start, + .next = l2tp_dfs_seq_next, + .stop = l2tp_dfs_seq_stop, + .show = l2tp_dfs_seq_show, +}; + +static int l2tp_dfs_seq_open(struct inode *inode, struct file *file) +{ + struct 
l2tp_dfs_seq_data *pd; + struct seq_file *seq; + int rc = -ENOMEM; + + pd = kzalloc(sizeof(*pd), GFP_KERNEL); + if (!pd) + goto out; + + /* Derive the network namespace from the pid opening the + * file. + */ + pd->net = get_net_ns_by_pid(current->pid); + if (IS_ERR(pd->net)) { + rc = PTR_ERR(pd->net); + goto err_free_pd; + } + netns_tracker_alloc(pd->net, &pd->ns_tracker, GFP_KERNEL); + rc = seq_open(file, &l2tp_dfs_seq_ops); + if (rc) + goto err_free_net; + + seq = file->private_data; + seq->private = pd; + +out: + return rc; + +err_free_net: + put_net_track(pd->net, &pd->ns_tracker); +err_free_pd: + kfree(pd); + goto out; +} + +static int l2tp_dfs_seq_release(struct inode *inode, struct file *file) +{ + struct l2tp_dfs_seq_data *pd; + struct seq_file *seq; + + seq = file->private_data; + pd = seq->private; + if (pd->net) + put_net_track(pd->net, &pd->ns_tracker); + kfree(pd); + seq_release(inode, file); + + return 0; +} + +static const struct file_operations l2tp_dfs_fops = { + .owner = THIS_MODULE, + .open = l2tp_dfs_seq_open, + .read = seq_read, + .llseek = seq_lseek, + .release = l2tp_dfs_seq_release, +}; + +static int __init l2tp_debugfs_init(void) +{ + rootdir = debugfs_create_dir("l2tp", NULL); + + debugfs_create_file("tunnels", 0600, rootdir, NULL, &l2tp_dfs_fops); + + pr_info("L2TP debugfs support\n"); + + return 0; +} + +static void __exit l2tp_debugfs_exit(void) +{ + debugfs_remove_recursive(rootdir); +} + +module_init(l2tp_debugfs_init); +module_exit(l2tp_debugfs_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("L2TP debugfs driver"); +MODULE_VERSION("1.0"); diff --git a/net/l2tp/l2tp_eth.c b/net/l2tp/l2tp_eth.c new file mode 100644 index 000000000..f2ae03c40 --- /dev/null +++ b/net/l2tp/l2tp_eth.c @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* L2TPv3 ethernet pseudowire driver + * + * Copyright (c) 2008,2009,2010 Katalix Systems Ltd + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "l2tp_core.h" + +/* Default device name. May be overridden by name specified by user */ +#define L2TP_ETH_DEV_NAME "l2tpeth%d" + +/* via netdev_priv() */ +struct l2tp_eth { + struct l2tp_session *session; + atomic_long_t tx_bytes; + atomic_long_t tx_packets; + atomic_long_t tx_dropped; + atomic_long_t rx_bytes; + atomic_long_t rx_packets; + atomic_long_t rx_errors; +}; + +/* via l2tp_session_priv() */ +struct l2tp_eth_sess { + struct net_device __rcu *dev; +}; + +static int l2tp_eth_dev_init(struct net_device *dev) +{ + eth_hw_addr_random(dev); + eth_broadcast_addr(dev->broadcast); + netdev_lockdep_set_classes(dev); + + return 0; +} + +static void l2tp_eth_dev_uninit(struct net_device *dev) +{ + struct l2tp_eth *priv = netdev_priv(dev); + struct l2tp_eth_sess *spriv; + + spriv = l2tp_session_priv(priv->session); + RCU_INIT_POINTER(spriv->dev, NULL); + /* No need for synchronize_net() here. We're called by + * unregister_netdev*(), which does the synchronisation for us. 
+ */ +} + +static netdev_tx_t l2tp_eth_dev_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct l2tp_eth *priv = netdev_priv(dev); + struct l2tp_session *session = priv->session; + unsigned int len = skb->len; + int ret = l2tp_xmit_skb(session, skb); + + if (likely(ret == NET_XMIT_SUCCESS)) { + atomic_long_add(len, &priv->tx_bytes); + atomic_long_inc(&priv->tx_packets); + } else { + atomic_long_inc(&priv->tx_dropped); + } + return NETDEV_TX_OK; +} + +static void l2tp_eth_get_stats64(struct net_device *dev, + struct rtnl_link_stats64 *stats) +{ + struct l2tp_eth *priv = netdev_priv(dev); + + stats->tx_bytes = (unsigned long)atomic_long_read(&priv->tx_bytes); + stats->tx_packets = (unsigned long)atomic_long_read(&priv->tx_packets); + stats->tx_dropped = (unsigned long)atomic_long_read(&priv->tx_dropped); + stats->rx_bytes = (unsigned long)atomic_long_read(&priv->rx_bytes); + stats->rx_packets = (unsigned long)atomic_long_read(&priv->rx_packets); + stats->rx_errors = (unsigned long)atomic_long_read(&priv->rx_errors); +} + +static const struct net_device_ops l2tp_eth_netdev_ops = { + .ndo_init = l2tp_eth_dev_init, + .ndo_uninit = l2tp_eth_dev_uninit, + .ndo_start_xmit = l2tp_eth_dev_xmit, + .ndo_get_stats64 = l2tp_eth_get_stats64, + .ndo_set_mac_address = eth_mac_addr, +}; + +static struct device_type l2tpeth_type = { + .name = "l2tpeth", +}; + +static void l2tp_eth_dev_setup(struct net_device *dev) +{ + SET_NETDEV_DEVTYPE(dev, &l2tpeth_type); + ether_setup(dev); + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->features |= NETIF_F_LLTX; + dev->netdev_ops = &l2tp_eth_netdev_ops; + dev->needs_free_netdev = true; +} + +static void l2tp_eth_dev_recv(struct l2tp_session *session, struct sk_buff *skb, int data_len) +{ + struct l2tp_eth_sess *spriv = l2tp_session_priv(session); + struct net_device *dev; + struct l2tp_eth *priv; + + if (!pskb_may_pull(skb, ETH_HLEN)) + goto error; + + secpath_reset(skb); + + /* checksums verified by L2TP */ + skb->ip_summed = CHECKSUM_NONE; + + skb_dst_drop(skb); + nf_reset_ct(skb); + + rcu_read_lock(); + dev = rcu_dereference(spriv->dev); + if (!dev) + goto error_rcu; + + priv = netdev_priv(dev); + if (dev_forward_skb(dev, skb) == NET_RX_SUCCESS) { + atomic_long_inc(&priv->rx_packets); + atomic_long_add(data_len, &priv->rx_bytes); + } else { + atomic_long_inc(&priv->rx_errors); + } + rcu_read_unlock(); + + return; + +error_rcu: + rcu_read_unlock(); +error: + kfree_skb(skb); +} + +static void l2tp_eth_delete(struct l2tp_session *session) +{ + struct l2tp_eth_sess *spriv; + struct net_device *dev; + + if (session) { + spriv = l2tp_session_priv(session); + + rtnl_lock(); + dev = rtnl_dereference(spriv->dev); + if (dev) { + unregister_netdevice(dev); + rtnl_unlock(); + module_put(THIS_MODULE); + } else { + rtnl_unlock(); + } + } +} + +static void l2tp_eth_show(struct seq_file *m, void *arg) +{ + struct l2tp_session *session = arg; + struct l2tp_eth_sess *spriv = l2tp_session_priv(session); + struct net_device *dev; + + rcu_read_lock(); + dev = rcu_dereference(spriv->dev); + if (!dev) { + rcu_read_unlock(); + return; + } + dev_hold(dev); + rcu_read_unlock(); + + seq_printf(m, " interface %s\n", dev->name); + + dev_put(dev); +} + +static void l2tp_eth_adjust_mtu(struct l2tp_tunnel *tunnel, + struct l2tp_session *session, + struct net_device *dev) +{ + unsigned int overhead = 0; + u32 l3_overhead = 0; + u32 mtu; + + /* if the encap is UDP, account for UDP header size */ + if (tunnel->encap == L2TP_ENCAPTYPE_UDP) { + overhead += sizeof(struct udphdr); + 
dev->needed_headroom += sizeof(struct udphdr); + } + + lock_sock(tunnel->sock); + l3_overhead = kernel_sock_ip_overhead(tunnel->sock); + release_sock(tunnel->sock); + + if (l3_overhead == 0) { + /* L3 Overhead couldn't be identified, this could be + * because tunnel->sock was NULL or the socket's + * address family was not IPv4 or IPv6, + * dev mtu stays at 1500. + */ + return; + } + /* Adjust MTU, factor overhead - underlay L3, overlay L2 hdr + * UDP overhead, if any, was already factored in above. + */ + overhead += session->hdr_len + ETH_HLEN + l3_overhead; + + mtu = l2tp_tunnel_dst_mtu(tunnel) - overhead; + if (mtu < dev->min_mtu || mtu > dev->max_mtu) + dev->mtu = ETH_DATA_LEN - overhead; + else + dev->mtu = mtu; + + dev->needed_headroom += session->hdr_len; +} + +static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel, + u32 session_id, u32 peer_session_id, + struct l2tp_session_cfg *cfg) +{ + unsigned char name_assign_type; + struct net_device *dev; + char name[IFNAMSIZ]; + struct l2tp_session *session; + struct l2tp_eth *priv; + struct l2tp_eth_sess *spriv; + int rc; + + if (cfg->ifname) { + strscpy(name, cfg->ifname, IFNAMSIZ); + name_assign_type = NET_NAME_USER; + } else { + strcpy(name, L2TP_ETH_DEV_NAME); + name_assign_type = NET_NAME_ENUM; + } + + session = l2tp_session_create(sizeof(*spriv), tunnel, session_id, + peer_session_id, cfg); + if (IS_ERR(session)) { + rc = PTR_ERR(session); + goto err; + } + + dev = alloc_netdev(sizeof(*priv), name, name_assign_type, + l2tp_eth_dev_setup); + if (!dev) { + rc = -ENOMEM; + goto err_sess; + } + + dev_net_set(dev, net); + dev->min_mtu = 0; + dev->max_mtu = ETH_MAX_MTU; + l2tp_eth_adjust_mtu(tunnel, session, dev); + + priv = netdev_priv(dev); + priv->session = session; + + session->recv_skb = l2tp_eth_dev_recv; + session->session_close = l2tp_eth_delete; + if (IS_ENABLED(CONFIG_L2TP_DEBUGFS)) + session->show = l2tp_eth_show; + + spriv = l2tp_session_priv(session); + + l2tp_session_inc_refcount(session); + + rtnl_lock(); + + /* Register both device and session while holding the rtnl lock. This + * ensures that l2tp_eth_delete() will see that there's a device to + * unregister, even if it happened to run before we assign spriv->dev. 
+ */ + rc = l2tp_session_register(session, tunnel); + if (rc < 0) { + rtnl_unlock(); + goto err_sess_dev; + } + + rc = register_netdevice(dev); + if (rc < 0) { + rtnl_unlock(); + l2tp_session_delete(session); + l2tp_session_dec_refcount(session); + free_netdev(dev); + + return rc; + } + + strscpy(session->ifname, dev->name, IFNAMSIZ); + rcu_assign_pointer(spriv->dev, dev); + + rtnl_unlock(); + + l2tp_session_dec_refcount(session); + + __module_get(THIS_MODULE); + + return 0; + +err_sess_dev: + l2tp_session_dec_refcount(session); + free_netdev(dev); +err_sess: + kfree(session); +err: + return rc; +} + +static const struct l2tp_nl_cmd_ops l2tp_eth_nl_cmd_ops = { + .session_create = l2tp_eth_create, + .session_delete = l2tp_session_delete, +}; + +static int __init l2tp_eth_init(void) +{ + int err = 0; + + err = l2tp_nl_register_ops(L2TP_PWTYPE_ETH, &l2tp_eth_nl_cmd_ops); + if (err) + goto err; + + pr_info("L2TP ethernet pseudowire support (L2TPv3)\n"); + + return 0; + +err: + return err; +} + +static void __exit l2tp_eth_exit(void) +{ + l2tp_nl_unregister_ops(L2TP_PWTYPE_ETH); +} + +module_init(l2tp_eth_init); +module_exit(l2tp_eth_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("L2TP ethernet pseudowire driver"); +MODULE_VERSION("1.0"); +MODULE_ALIAS_L2TP_PWTYPE(5); diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c new file mode 100644 index 000000000..41a74fc84 --- /dev/null +++ b/net/l2tp/l2tp_ip.c @@ -0,0 +1,684 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* L2TPv3 IP encapsulation support + * + * Copyright (c) 2008,2009,2010 Katalix Systems Ltd + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "l2tp_core.h" + +struct l2tp_ip_sock { + /* inet_sock has to be the first member of l2tp_ip_sock */ + struct inet_sock inet; + + u32 conn_id; + u32 peer_conn_id; +}; + +static DEFINE_RWLOCK(l2tp_ip_lock); +static struct hlist_head l2tp_ip_table; +static struct hlist_head l2tp_ip_bind_table; + +static inline struct l2tp_ip_sock *l2tp_ip_sk(const struct sock *sk) +{ + return (struct l2tp_ip_sock *)sk; +} + +static struct sock *__l2tp_ip_bind_lookup(const struct net *net, __be32 laddr, + __be32 raddr, int dif, u32 tunnel_id) +{ + struct sock *sk; + + sk_for_each_bound(sk, &l2tp_ip_bind_table) { + const struct l2tp_ip_sock *l2tp = l2tp_ip_sk(sk); + const struct inet_sock *inet = inet_sk(sk); + int bound_dev_if; + + if (!net_eq(sock_net(sk), net)) + continue; + + bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + if (bound_dev_if && dif && bound_dev_if != dif) + continue; + + if (inet->inet_rcv_saddr && laddr && + inet->inet_rcv_saddr != laddr) + continue; + + if (inet->inet_daddr && raddr && inet->inet_daddr != raddr) + continue; + + if (l2tp->conn_id != tunnel_id) + continue; + + goto found; + } + + sk = NULL; +found: + return sk; +} + +/* When processing receive frames, there are two cases to + * consider. Data frames consist of a non-zero session-id and an + * optional cookie. Control frames consist of a regular L2TP header + * preceded by 32-bits of zeros. + * + * L2TPv3 Session Header Over IP + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Session ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Cookie (optional, maximum 64 bits)... 
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * L2TPv3 Control Message Header Over IP + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | (32 bits of zeros) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |T|L|x|x|S|x|x|x|x|x|x|x| Ver | Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Control Connection ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Ns | Nr | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * All control frames are passed to userspace. + */ +static int l2tp_ip_recv(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct sock *sk; + u32 session_id; + u32 tunnel_id; + unsigned char *ptr, *optr; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel = NULL; + struct iphdr *iph; + + if (!pskb_may_pull(skb, 4)) + goto discard; + + /* Point to L2TP header */ + optr = skb->data; + ptr = skb->data; + session_id = ntohl(*((__be32 *)ptr)); + ptr += 4; + + /* RFC3931: L2TP/IP packets have the first 4 bytes containing + * the session_id. If it is 0, the packet is a L2TP control + * frame and the session_id value can be discarded. + */ + if (session_id == 0) { + __skb_pull(skb, 4); + goto pass_up; + } + + /* Ok, this is a data packet. Lookup the session. */ + session = l2tp_session_get(net, session_id); + if (!session) + goto discard; + + tunnel = session->tunnel; + if (!tunnel) + goto discard_sess; + + if (l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) + goto discard_sess; + + l2tp_recv_common(session, skb, ptr, optr, 0, skb->len); + l2tp_session_dec_refcount(session); + + return 0; + +pass_up: + /* Get the tunnel_id from the L2TP header */ + if (!pskb_may_pull(skb, 12)) + goto discard; + + if ((skb->data[0] & 0xc0) != 0xc0) + goto discard; + + tunnel_id = ntohl(*(__be32 *)&skb->data[4]); + iph = (struct iphdr *)skb_network_header(skb); + + read_lock_bh(&l2tp_ip_lock); + sk = __l2tp_ip_bind_lookup(net, iph->daddr, iph->saddr, inet_iif(skb), + tunnel_id); + if (!sk) { + read_unlock_bh(&l2tp_ip_lock); + goto discard; + } + sock_hold(sk); + read_unlock_bh(&l2tp_ip_lock); + + if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb)) + goto discard_put; + + nf_reset_ct(skb); + + return sk_receive_skb(sk, skb, 1); + +discard_sess: + l2tp_session_dec_refcount(session); + goto discard; + +discard_put: + sock_put(sk); + +discard: + kfree_skb(skb); + return 0; +} + +static int l2tp_ip_hash(struct sock *sk) +{ + if (sk_unhashed(sk)) { + write_lock_bh(&l2tp_ip_lock); + sk_add_node(sk, &l2tp_ip_table); + write_unlock_bh(&l2tp_ip_lock); + } + return 0; +} + +static void l2tp_ip_unhash(struct sock *sk) +{ + if (sk_unhashed(sk)) + return; + write_lock_bh(&l2tp_ip_lock); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip_lock); +} + +static int l2tp_ip_open(struct sock *sk) +{ + /* Prevent autobind. We don't have ports. 
*/ + inet_sk(sk)->inet_num = IPPROTO_L2TP; + + l2tp_ip_hash(sk); + return 0; +} + +static void l2tp_ip_close(struct sock *sk, long timeout) +{ + write_lock_bh(&l2tp_ip_lock); + hlist_del_init(&sk->sk_bind_node); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip_lock); + sk_common_release(sk); +} + +static void l2tp_ip_destroy_sock(struct sock *sk) +{ + struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + struct sk_buff *skb; + + while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL) + kfree_skb(skb); + + if (tunnel) + l2tp_tunnel_delete(tunnel); +} + +static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + struct sockaddr_l2tpip *addr = (struct sockaddr_l2tpip *)uaddr; + struct net *net = sock_net(sk); + int ret; + int chk_addr_ret; + + if (addr_len < sizeof(struct sockaddr_l2tpip)) + return -EINVAL; + if (addr->l2tp_family != AF_INET) + return -EINVAL; + + lock_sock(sk); + + ret = -EINVAL; + if (!sock_flag(sk, SOCK_ZAPPED)) + goto out; + + if (sk->sk_state != TCP_CLOSE) + goto out; + + chk_addr_ret = inet_addr_type(net, addr->l2tp_addr.s_addr); + ret = -EADDRNOTAVAIL; + if (addr->l2tp_addr.s_addr && chk_addr_ret != RTN_LOCAL && + chk_addr_ret != RTN_MULTICAST && chk_addr_ret != RTN_BROADCAST) + goto out; + + if (addr->l2tp_addr.s_addr) { + inet->inet_rcv_saddr = addr->l2tp_addr.s_addr; + inet->inet_saddr = addr->l2tp_addr.s_addr; + } + if (chk_addr_ret == RTN_MULTICAST || chk_addr_ret == RTN_BROADCAST) + inet->inet_saddr = 0; /* Use device */ + + write_lock_bh(&l2tp_ip_lock); + if (__l2tp_ip_bind_lookup(net, addr->l2tp_addr.s_addr, 0, + sk->sk_bound_dev_if, addr->l2tp_conn_id)) { + write_unlock_bh(&l2tp_ip_lock); + ret = -EADDRINUSE; + goto out; + } + + sk_dst_reset(sk); + l2tp_ip_sk(sk)->conn_id = addr->l2tp_conn_id; + + sk_add_bind_node(sk, &l2tp_ip_bind_table); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip_lock); + + ret = 0; + sock_reset_flag(sk, SOCK_ZAPPED); + +out: + release_sock(sk); + + return ret; +} + +static int l2tp_ip_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct sockaddr_l2tpip *lsa = (struct sockaddr_l2tpip *)uaddr; + int rc; + + if (addr_len < sizeof(*lsa)) + return -EINVAL; + + if (ipv4_is_multicast(lsa->l2tp_addr.s_addr)) + return -EINVAL; + + lock_sock(sk); + + /* Must bind first - autobinding does not work */ + if (sock_flag(sk, SOCK_ZAPPED)) { + rc = -EINVAL; + goto out_sk; + } + + rc = __ip4_datagram_connect(sk, uaddr, addr_len); + if (rc < 0) + goto out_sk; + + l2tp_ip_sk(sk)->peer_conn_id = lsa->l2tp_conn_id; + + write_lock_bh(&l2tp_ip_lock); + hlist_del_init(&sk->sk_bind_node); + sk_add_bind_node(sk, &l2tp_ip_bind_table); + write_unlock_bh(&l2tp_ip_lock); + +out_sk: + release_sock(sk); + + return rc; +} + +static int l2tp_ip_disconnect(struct sock *sk, int flags) +{ + if (sock_flag(sk, SOCK_ZAPPED)) + return 0; + + return __udp_disconnect(sk, flags); +} + +static int l2tp_ip_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sock *sk = sock->sk; + struct inet_sock *inet = inet_sk(sk); + struct l2tp_ip_sock *lsk = l2tp_ip_sk(sk); + struct sockaddr_l2tpip *lsa = (struct sockaddr_l2tpip *)uaddr; + + memset(lsa, 0, sizeof(*lsa)); + lsa->l2tp_family = AF_INET; + if (peer) { + if (!inet->inet_dport) + return -ENOTCONN; + lsa->l2tp_conn_id = lsk->peer_conn_id; + lsa->l2tp_addr.s_addr = inet->inet_daddr; + } else { + __be32 addr = inet->inet_rcv_saddr; + + if (!addr) + addr = inet->inet_saddr; + lsa->l2tp_conn_id = lsk->conn_id; + 
lsa->l2tp_addr.s_addr = addr; + } + return sizeof(*lsa); +} + +static int l2tp_ip_backlog_recv(struct sock *sk, struct sk_buff *skb) +{ + int rc; + + /* Charge it to the socket, dropping if the queue is full. */ + rc = sock_queue_rcv_skb(sk, skb); + if (rc < 0) + goto drop; + + return 0; + +drop: + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_INDISCARDS); + kfree_skb(skb); + return 0; +} + +/* Userspace will call sendmsg() on the tunnel socket to send L2TP + * control frames. + */ +static int l2tp_ip_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct sk_buff *skb; + int rc; + struct inet_sock *inet = inet_sk(sk); + struct rtable *rt = NULL; + struct flowi4 *fl4; + int connected = 0; + __be32 daddr; + + lock_sock(sk); + + rc = -ENOTCONN; + if (sock_flag(sk, SOCK_DEAD)) + goto out; + + /* Get and verify the address. */ + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_l2tpip *, lip, msg->msg_name); + + rc = -EINVAL; + if (msg->msg_namelen < sizeof(*lip)) + goto out; + + if (lip->l2tp_family != AF_INET) { + rc = -EAFNOSUPPORT; + if (lip->l2tp_family != AF_UNSPEC) + goto out; + } + + daddr = lip->l2tp_addr.s_addr; + } else { + rc = -EDESTADDRREQ; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + + daddr = inet->inet_daddr; + connected = 1; + } + + /* Allocate a socket buffer */ + rc = -ENOMEM; + skb = sock_wmalloc(sk, 2 + NET_SKB_PAD + sizeof(struct iphdr) + + 4 + len, 0, GFP_KERNEL); + if (!skb) + goto error; + + /* Reserve space for headers, putting IP header on 4-byte boundary. */ + skb_reserve(skb, 2 + NET_SKB_PAD); + skb_reset_network_header(skb); + skb_reserve(skb, sizeof(struct iphdr)); + skb_reset_transport_header(skb); + + /* Insert 0 session_id */ + *((__be32 *)skb_put(skb, 4)) = 0; + + /* Copy user data into skb */ + rc = memcpy_from_msg(skb_put(skb, len), msg, len); + if (rc < 0) { + kfree_skb(skb); + goto error; + } + + fl4 = &inet->cork.fl.u.ip4; + if (connected) + rt = (struct rtable *)__sk_dst_check(sk, 0); + + rcu_read_lock(); + if (!rt) { + const struct ip_options_rcu *inet_opt; + + inet_opt = rcu_dereference(inet->inet_opt); + + /* Use correct destination address if we have options. */ + if (inet_opt && inet_opt->opt.srr) + daddr = inet_opt->opt.faddr; + + /* If this fails, retransmit mechanism of transport layer will + * keep trying until route appears or the connection times + * itself out. + */ + rt = ip_route_output_ports(sock_net(sk), fl4, sk, + daddr, inet->inet_saddr, + inet->inet_dport, inet->inet_sport, + sk->sk_protocol, RT_CONN_FLAGS(sk), + sk->sk_bound_dev_if); + if (IS_ERR(rt)) + goto no_route; + if (connected) { + sk_setup_caps(sk, &rt->dst); + } else { + skb_dst_set(skb, &rt->dst); + goto xmit; + } + } + + /* We don't need to clone dst here, it is guaranteed to not disappear. + * __dev_xmit_skb() might force a refcount if needed. 
+ */ + skb_dst_set_noref(skb, &rt->dst); + +xmit: + /* Queue the packet to IP for output */ + rc = ip_queue_xmit(sk, skb, &inet->cork.fl); + rcu_read_unlock(); + +error: + if (rc >= 0) + rc = len; + +out: + release_sock(sk); + return rc; + +no_route: + rcu_read_unlock(); + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES); + kfree_skb(skb); + rc = -EHOSTUNREACH; + goto out; +} + +static int l2tp_ip_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags, int *addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + size_t copied = 0; + int err = -EOPNOTSUPP; + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + struct sk_buff *skb; + + if (flags & MSG_OOB) + goto out; + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_timestamp(msg, sk, skb); + + /* Copy the address. */ + if (sin) { + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + sin->sin_port = 0; + memset(&sin->sin_zero, 0, sizeof(sin->sin_zero)); + *addr_len = sizeof(*sin); + } + if (inet->cmsg_flags) + ip_cmsg_recv(msg, skb); + if (flags & MSG_TRUNC) + copied = skb->len; +done: + skb_free_datagram(sk, skb); +out: + return err ? err : copied; +} + +int l2tp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct sk_buff *skb; + int amount; + + switch (cmd) { + case SIOCOUTQ: + amount = sk_wmem_alloc_get(sk); + break; + case SIOCINQ: + spin_lock_bh(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + amount = skb ? skb->len : 0; + spin_unlock_bh(&sk->sk_receive_queue.lock); + break; + + default: + return -ENOIOCTLCMD; + } + + return put_user(amount, (int __user *)arg); +} +EXPORT_SYMBOL_GPL(l2tp_ioctl); + +static struct proto l2tp_ip_prot = { + .name = "L2TP/IP", + .owner = THIS_MODULE, + .init = l2tp_ip_open, + .close = l2tp_ip_close, + .bind = l2tp_ip_bind, + .connect = l2tp_ip_connect, + .disconnect = l2tp_ip_disconnect, + .ioctl = l2tp_ioctl, + .destroy = l2tp_ip_destroy_sock, + .setsockopt = ip_setsockopt, + .getsockopt = ip_getsockopt, + .sendmsg = l2tp_ip_sendmsg, + .recvmsg = l2tp_ip_recvmsg, + .backlog_rcv = l2tp_ip_backlog_recv, + .hash = l2tp_ip_hash, + .unhash = l2tp_ip_unhash, + .obj_size = sizeof(struct l2tp_ip_sock), +}; + +static const struct proto_ops l2tp_ip_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = inet_bind, + .connect = inet_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = l2tp_ip_getname, + .poll = datagram_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct inet_protosw l2tp_ip_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_L2TP, + .prot = &l2tp_ip_prot, + .ops = &l2tp_ip_ops, +}; + +static struct net_protocol l2tp_ip_protocol __read_mostly = { + .handler = l2tp_ip_recv, +}; + +static int __init l2tp_ip_init(void) +{ + int err; + + pr_info("L2TP IP encapsulation support (L2TPv3)\n"); + + err = proto_register(&l2tp_ip_prot, 1); + if (err != 0) + goto out; + + err = inet_add_protocol(&l2tp_ip_protocol, IPPROTO_L2TP); + if (err) + goto out1; + + 
inet_register_protosw(&l2tp_ip_protosw); + return 0; + +out1: + proto_unregister(&l2tp_ip_prot); +out: + return err; +} + +static void __exit l2tp_ip_exit(void) +{ + inet_unregister_protosw(&l2tp_ip_protosw); + inet_del_protocol(&l2tp_ip_protocol, IPPROTO_L2TP); + proto_unregister(&l2tp_ip_prot); +} + +module_init(l2tp_ip_init); +module_exit(l2tp_ip_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("L2TP over IP"); +MODULE_VERSION("1.0"); + +/* Use the values of SOCK_DGRAM (2) as type and IPPROTO_L2TP (115) as protocol, + * because __stringify doesn't like enums + */ +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET, 115, 2); +MODULE_ALIAS_NET_PF_PROTO(PF_INET, 115); diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c new file mode 100644 index 000000000..314ec3a51 --- /dev/null +++ b/net/l2tp/l2tp_ip6.c @@ -0,0 +1,813 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* L2TPv3 IP encapsulation support for IPv6 + * + * Copyright (c) 2012 Katalix Systems Ltd + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "l2tp_core.h" + +struct l2tp_ip6_sock { + /* inet_sock has to be the first member of l2tp_ip6_sock */ + struct inet_sock inet; + + u32 conn_id; + u32 peer_conn_id; + + /* ipv6_pinfo has to be the last member of l2tp_ip6_sock, see + * inet6_sk_generic + */ + struct ipv6_pinfo inet6; +}; + +static DEFINE_RWLOCK(l2tp_ip6_lock); +static struct hlist_head l2tp_ip6_table; +static struct hlist_head l2tp_ip6_bind_table; + +static inline struct l2tp_ip6_sock *l2tp_ip6_sk(const struct sock *sk) +{ + return (struct l2tp_ip6_sock *)sk; +} + +static struct sock *__l2tp_ip6_bind_lookup(const struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *raddr, + int dif, u32 tunnel_id) +{ + struct sock *sk; + + sk_for_each_bound(sk, &l2tp_ip6_bind_table) { + const struct in6_addr *sk_laddr = inet6_rcv_saddr(sk); + const struct in6_addr *sk_raddr = &sk->sk_v6_daddr; + const struct l2tp_ip6_sock *l2tp = l2tp_ip6_sk(sk); + int bound_dev_if; + + if (!net_eq(sock_net(sk), net)) + continue; + + bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + if (bound_dev_if && dif && bound_dev_if != dif) + continue; + + if (sk_laddr && !ipv6_addr_any(sk_laddr) && + !ipv6_addr_any(laddr) && !ipv6_addr_equal(sk_laddr, laddr)) + continue; + + if (!ipv6_addr_any(sk_raddr) && raddr && + !ipv6_addr_any(raddr) && !ipv6_addr_equal(sk_raddr, raddr)) + continue; + + if (l2tp->conn_id != tunnel_id) + continue; + + goto found; + } + + sk = NULL; +found: + return sk; +} + +/* When processing receive frames, there are two cases to + * consider. Data frames consist of a non-zero session-id and an + * optional cookie. Control frames consist of a regular L2TP header + * preceded by 32-bits of zeros. + * + * L2TPv3 Session Header Over IP + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Session ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Cookie (optional, maximum 64 bits)... 
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * L2TPv3 Control Message Header Over IP + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | (32 bits of zeros) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * |T|L|x|x|S|x|x|x|x|x|x|x| Ver | Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Control Connection ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Ns | Nr | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * All control frames are passed to userspace. + */ +static int l2tp_ip6_recv(struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct sock *sk; + u32 session_id; + u32 tunnel_id; + unsigned char *ptr, *optr; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel = NULL; + struct ipv6hdr *iph; + + if (!pskb_may_pull(skb, 4)) + goto discard; + + /* Point to L2TP header */ + optr = skb->data; + ptr = skb->data; + session_id = ntohl(*((__be32 *)ptr)); + ptr += 4; + + /* RFC3931: L2TP/IP packets have the first 4 bytes containing + * the session_id. If it is 0, the packet is a L2TP control + * frame and the session_id value can be discarded. + */ + if (session_id == 0) { + __skb_pull(skb, 4); + goto pass_up; + } + + /* Ok, this is a data packet. Lookup the session. */ + session = l2tp_session_get(net, session_id); + if (!session) + goto discard; + + tunnel = session->tunnel; + if (!tunnel) + goto discard_sess; + + if (l2tp_v3_ensure_opt_in_linear(session, skb, &ptr, &optr)) + goto discard_sess; + + l2tp_recv_common(session, skb, ptr, optr, 0, skb->len); + l2tp_session_dec_refcount(session); + + return 0; + +pass_up: + /* Get the tunnel_id from the L2TP header */ + if (!pskb_may_pull(skb, 12)) + goto discard; + + if ((skb->data[0] & 0xc0) != 0xc0) + goto discard; + + tunnel_id = ntohl(*(__be32 *)&skb->data[4]); + iph = ipv6_hdr(skb); + + read_lock_bh(&l2tp_ip6_lock); + sk = __l2tp_ip6_bind_lookup(net, &iph->daddr, &iph->saddr, + inet6_iif(skb), tunnel_id); + if (!sk) { + read_unlock_bh(&l2tp_ip6_lock); + goto discard; + } + sock_hold(sk); + read_unlock_bh(&l2tp_ip6_lock); + + if (!xfrm6_policy_check(sk, XFRM_POLICY_IN, skb)) + goto discard_put; + + nf_reset_ct(skb); + + return sk_receive_skb(sk, skb, 1); + +discard_sess: + l2tp_session_dec_refcount(session); + goto discard; + +discard_put: + sock_put(sk); + +discard: + kfree_skb(skb); + return 0; +} + +static int l2tp_ip6_hash(struct sock *sk) +{ + if (sk_unhashed(sk)) { + write_lock_bh(&l2tp_ip6_lock); + sk_add_node(sk, &l2tp_ip6_table); + write_unlock_bh(&l2tp_ip6_lock); + } + return 0; +} + +static void l2tp_ip6_unhash(struct sock *sk) +{ + if (sk_unhashed(sk)) + return; + write_lock_bh(&l2tp_ip6_lock); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip6_lock); +} + +static int l2tp_ip6_open(struct sock *sk) +{ + /* Prevent autobind. We don't have ports. 
*/ + inet_sk(sk)->inet_num = IPPROTO_L2TP; + + l2tp_ip6_hash(sk); + return 0; +} + +static void l2tp_ip6_close(struct sock *sk, long timeout) +{ + write_lock_bh(&l2tp_ip6_lock); + hlist_del_init(&sk->sk_bind_node); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip6_lock); + + sk_common_release(sk); +} + +static void l2tp_ip6_destroy_sock(struct sock *sk) +{ + struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk); + + lock_sock(sk); + ip6_flush_pending_frames(sk); + release_sock(sk); + + if (tunnel) + l2tp_tunnel_delete(tunnel); +} + +static int l2tp_ip6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + struct sockaddr_l2tpip6 *addr = (struct sockaddr_l2tpip6 *)uaddr; + struct net *net = sock_net(sk); + __be32 v4addr = 0; + int bound_dev_if; + int addr_type; + int err; + + if (addr->l2tp_family != AF_INET6) + return -EINVAL; + if (addr_len < sizeof(*addr)) + return -EINVAL; + + addr_type = ipv6_addr_type(&addr->l2tp_addr); + + /* l2tp_ip6 sockets are IPv6 only */ + if (addr_type == IPV6_ADDR_MAPPED) + return -EADDRNOTAVAIL; + + /* L2TP is point-point, not multicast */ + if (addr_type & IPV6_ADDR_MULTICAST) + return -EADDRNOTAVAIL; + + lock_sock(sk); + + err = -EINVAL; + if (!sock_flag(sk, SOCK_ZAPPED)) + goto out_unlock; + + if (sk->sk_state != TCP_CLOSE) + goto out_unlock; + + bound_dev_if = sk->sk_bound_dev_if; + + /* Check if the address belongs to the host. */ + rcu_read_lock(); + if (addr_type != IPV6_ADDR_ANY) { + struct net_device *dev = NULL; + + if (addr_type & IPV6_ADDR_LINKLOCAL) { + if (addr->l2tp_scope_id) + bound_dev_if = addr->l2tp_scope_id; + + /* Binding to link-local address requires an + * interface. + */ + if (!bound_dev_if) + goto out_unlock_rcu; + + err = -ENODEV; + dev = dev_get_by_index_rcu(sock_net(sk), bound_dev_if); + if (!dev) + goto out_unlock_rcu; + } + + /* ipv4 addr of the socket is invalid. Only the + * unspecified and mapped address have a v4 equivalent. 
+ */ + v4addr = LOOPBACK4_IPV6; + err = -EADDRNOTAVAIL; + if (!ipv6_chk_addr(sock_net(sk), &addr->l2tp_addr, dev, 0)) + goto out_unlock_rcu; + } + rcu_read_unlock(); + + write_lock_bh(&l2tp_ip6_lock); + if (__l2tp_ip6_bind_lookup(net, &addr->l2tp_addr, NULL, bound_dev_if, + addr->l2tp_conn_id)) { + write_unlock_bh(&l2tp_ip6_lock); + err = -EADDRINUSE; + goto out_unlock; + } + + inet->inet_saddr = v4addr; + inet->inet_rcv_saddr = v4addr; + sk->sk_bound_dev_if = bound_dev_if; + sk->sk_v6_rcv_saddr = addr->l2tp_addr; + np->saddr = addr->l2tp_addr; + + l2tp_ip6_sk(sk)->conn_id = addr->l2tp_conn_id; + + sk_add_bind_node(sk, &l2tp_ip6_bind_table); + sk_del_node_init(sk); + write_unlock_bh(&l2tp_ip6_lock); + + sock_reset_flag(sk, SOCK_ZAPPED); + release_sock(sk); + return 0; + +out_unlock_rcu: + rcu_read_unlock(); +out_unlock: + release_sock(sk); + + return err; +} + +static int l2tp_ip6_connect(struct sock *sk, struct sockaddr *uaddr, + int addr_len) +{ + struct sockaddr_l2tpip6 *lsa = (struct sockaddr_l2tpip6 *)uaddr; + struct sockaddr_in6 *usin = (struct sockaddr_in6 *)uaddr; + struct in6_addr *daddr; + int addr_type; + int rc; + + if (addr_len < sizeof(*lsa)) + return -EINVAL; + + if (usin->sin6_family != AF_INET6) + return -EINVAL; + + addr_type = ipv6_addr_type(&usin->sin6_addr); + if (addr_type & IPV6_ADDR_MULTICAST) + return -EINVAL; + + if (addr_type & IPV6_ADDR_MAPPED) { + daddr = &usin->sin6_addr; + if (ipv4_is_multicast(daddr->s6_addr32[3])) + return -EINVAL; + } + + lock_sock(sk); + + /* Must bind first - autobinding does not work */ + if (sock_flag(sk, SOCK_ZAPPED)) { + rc = -EINVAL; + goto out_sk; + } + + rc = __ip6_datagram_connect(sk, uaddr, addr_len); + if (rc < 0) + goto out_sk; + + l2tp_ip6_sk(sk)->peer_conn_id = lsa->l2tp_conn_id; + + write_lock_bh(&l2tp_ip6_lock); + hlist_del_init(&sk->sk_bind_node); + sk_add_bind_node(sk, &l2tp_ip6_bind_table); + write_unlock_bh(&l2tp_ip6_lock); + +out_sk: + release_sock(sk); + + return rc; +} + +static int l2tp_ip6_disconnect(struct sock *sk, int flags) +{ + if (sock_flag(sk, SOCK_ZAPPED)) + return 0; + + return __udp_disconnect(sk, flags); +} + +static int l2tp_ip6_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_l2tpip6 *lsa = (struct sockaddr_l2tpip6 *)uaddr; + struct sock *sk = sock->sk; + struct ipv6_pinfo *np = inet6_sk(sk); + struct l2tp_ip6_sock *lsk = l2tp_ip6_sk(sk); + + lsa->l2tp_family = AF_INET6; + lsa->l2tp_flowinfo = 0; + lsa->l2tp_scope_id = 0; + lsa->l2tp_unused = 0; + if (peer) { + if (!lsk->peer_conn_id) + return -ENOTCONN; + lsa->l2tp_conn_id = lsk->peer_conn_id; + lsa->l2tp_addr = sk->sk_v6_daddr; + if (np->sndflow) + lsa->l2tp_flowinfo = np->flow_label; + } else { + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + lsa->l2tp_addr = np->saddr; + else + lsa->l2tp_addr = sk->sk_v6_rcv_saddr; + + lsa->l2tp_conn_id = lsk->conn_id; + } + if (ipv6_addr_type(&lsa->l2tp_addr) & IPV6_ADDR_LINKLOCAL) + lsa->l2tp_scope_id = READ_ONCE(sk->sk_bound_dev_if); + return sizeof(*lsa); +} + +static int l2tp_ip6_backlog_recv(struct sock *sk, struct sk_buff *skb) +{ + int rc; + + /* Charge it to the socket, dropping if the queue is full. 
*/ + rc = sock_queue_rcv_skb(sk, skb); + if (rc < 0) + goto drop; + + return 0; + +drop: + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_INDISCARDS); + kfree_skb(skb); + return -1; +} + +static int l2tp_ip6_push_pending_frames(struct sock *sk) +{ + struct sk_buff *skb; + __be32 *transhdr = NULL; + int err = 0; + + skb = skb_peek(&sk->sk_write_queue); + if (!skb) + goto out; + + transhdr = (__be32 *)skb_transport_header(skb); + *transhdr = 0; + + err = ip6_push_pending_frames(sk); + +out: + return err; +} + +/* Userspace will call sendmsg() on the tunnel socket to send L2TP + * control frames. + */ +static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct ipv6_txoptions opt_space; + DECLARE_SOCKADDR(struct sockaddr_l2tpip6 *, lsa, msg->msg_name); + struct in6_addr *daddr, *final_p, final; + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_txoptions *opt_to_free = NULL; + struct ipv6_txoptions *opt = NULL; + struct ip6_flowlabel *flowlabel = NULL; + struct dst_entry *dst = NULL; + struct flowi6 fl6; + struct ipcm6_cookie ipc6; + int addr_len = msg->msg_namelen; + int transhdrlen = 4; /* zero session-id */ + int ulen; + int err; + + /* Rough check on arithmetic overflow, + * better check is made in ip6_append_data(). + */ + if (len > INT_MAX - transhdrlen) + return -EMSGSIZE; + + /* Mirror BSD error message compatibility */ + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* Get and verify the address */ + memset(&fl6, 0, sizeof(fl6)); + + fl6.flowi6_mark = READ_ONCE(sk->sk_mark); + fl6.flowi6_uid = sk->sk_uid; + + ipcm6_init(&ipc6); + + if (lsa) { + if (addr_len < SIN6_LEN_RFC2133) + return -EINVAL; + + if (lsa->l2tp_family && lsa->l2tp_family != AF_INET6) + return -EAFNOSUPPORT; + + daddr = &lsa->l2tp_addr; + if (np->sndflow) { + fl6.flowlabel = lsa->l2tp_flowinfo & IPV6_FLOWINFO_MASK; + if (fl6.flowlabel & IPV6_FLOWLABEL_MASK) { + flowlabel = fl6_sock_lookup(sk, fl6.flowlabel); + if (IS_ERR(flowlabel)) + return -EINVAL; + } + } + + /* Otherwise it will be difficult to maintain + * sk->sk_dst_cache. 
+		 */
+		if (sk->sk_state == TCP_ESTABLISHED &&
+		    ipv6_addr_equal(daddr, &sk->sk_v6_daddr))
+			daddr = &sk->sk_v6_daddr;
+
+		if (addr_len >= sizeof(struct sockaddr_in6) &&
+		    lsa->l2tp_scope_id &&
+		    ipv6_addr_type(daddr) & IPV6_ADDR_LINKLOCAL)
+			fl6.flowi6_oif = lsa->l2tp_scope_id;
+	} else {
+		if (sk->sk_state != TCP_ESTABLISHED)
+			return -EDESTADDRREQ;
+
+		daddr = &sk->sk_v6_daddr;
+		fl6.flowlabel = np->flow_label;
+	}
+
+	if (fl6.flowi6_oif == 0)
+		fl6.flowi6_oif = READ_ONCE(sk->sk_bound_dev_if);
+
+	if (msg->msg_controllen) {
+		opt = &opt_space;
+		memset(opt, 0, sizeof(struct ipv6_txoptions));
+		opt->tot_len = sizeof(struct ipv6_txoptions);
+		ipc6.opt = opt;
+
+		err = ip6_datagram_send_ctl(sock_net(sk), sk, msg, &fl6, &ipc6);
+		if (err < 0) {
+			fl6_sock_release(flowlabel);
+			return err;
+		}
+		if ((fl6.flowlabel & IPV6_FLOWLABEL_MASK) && !flowlabel) {
+			flowlabel = fl6_sock_lookup(sk, fl6.flowlabel);
+			if (IS_ERR(flowlabel))
+				return -EINVAL;
+		}
+		if (!(opt->opt_nflen | opt->opt_flen))
+			opt = NULL;
+	}
+
+	if (!opt) {
+		opt = txopt_get(np);
+		opt_to_free = opt;
+	}
+	if (flowlabel)
+		opt = fl6_merge_options(&opt_space, flowlabel, opt);
+	opt = ipv6_fixup_options(&opt_space, opt);
+	ipc6.opt = opt;
+
+	fl6.flowi6_proto = sk->sk_protocol;
+	if (!ipv6_addr_any(daddr))
+		fl6.daddr = *daddr;
+	else
+		fl6.daddr.s6_addr[15] = 0x1; /* :: means loopback (BSD'ism) */
+	if (ipv6_addr_any(&fl6.saddr) && !ipv6_addr_any(&np->saddr))
+		fl6.saddr = np->saddr;
+
+	final_p = fl6_update_dst(&fl6, opt, &final);
+
+	if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr))
+		fl6.flowi6_oif = np->mcast_oif;
+	else if (!fl6.flowi6_oif)
+		fl6.flowi6_oif = np->ucast_oif;
+
+	security_sk_classify_flow(sk, flowi6_to_flowi_common(&fl6));
+
+	if (ipc6.tclass < 0)
+		ipc6.tclass = np->tclass;
+
+	fl6.flowlabel = ip6_make_flowinfo(ipc6.tclass, fl6.flowlabel);
+
+	dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p);
+	if (IS_ERR(dst)) {
+		err = PTR_ERR(dst);
+		goto out;
+	}
+
+	if (ipc6.hlimit < 0)
+		ipc6.hlimit = ip6_sk_dst_hoplimit(np, &fl6, dst);
+
+	if (ipc6.dontfrag < 0)
+		ipc6.dontfrag = np->dontfrag;
+
+	if (msg->msg_flags & MSG_CONFIRM)
+		goto do_confirm;
+
+back_from_confirm:
+	lock_sock(sk);
+	/* Only the first chunk of a corked send accounts for the zero session ID */
+	ulen = len + (skb_queue_empty(&sk->sk_write_queue) ? transhdrlen : 0);
+	err = ip6_append_data(sk, ip_generic_getfrag, msg,
+			      ulen, transhdrlen, &ipc6,
+			      &fl6, (struct rt6_info *)dst,
+			      msg->msg_flags);
+	if (err)
+		ip6_flush_pending_frames(sk);
+	else if (!(msg->msg_flags & MSG_MORE))
+		err = l2tp_ip6_push_pending_frames(sk);
+	release_sock(sk);
+done:
+	dst_release(dst);
+out:
+	fl6_sock_release(flowlabel);
+	txopt_put(opt_to_free);
+
+	return err < 0 ?
err : len; + +do_confirm: + if (msg->msg_flags & MSG_PROBE) + dst_confirm_neigh(dst, &fl6.daddr); + if (!(msg->msg_flags & MSG_PROBE) || len) + goto back_from_confirm; + err = 0; + goto done; +} + +static int l2tp_ip6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_l2tpip6 *, lsa, msg->msg_name); + size_t copied = 0; + int err = -EOPNOTSUPP; + struct sk_buff *skb; + + if (flags & MSG_OOB) + goto out; + + if (flags & MSG_ERRQUEUE) + return ipv6_recv_error(sk, msg, len, addr_len); + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto done; + + sock_recv_timestamp(msg, sk, skb); + + /* Copy the address. */ + if (lsa) { + lsa->l2tp_family = AF_INET6; + lsa->l2tp_unused = 0; + lsa->l2tp_addr = ipv6_hdr(skb)->saddr; + lsa->l2tp_flowinfo = 0; + lsa->l2tp_scope_id = 0; + lsa->l2tp_conn_id = 0; + if (ipv6_addr_type(&lsa->l2tp_addr) & IPV6_ADDR_LINKLOCAL) + lsa->l2tp_scope_id = inet6_iif(skb); + *addr_len = sizeof(*lsa); + } + + if (np->rxopt.all) + ip6_datagram_recv_ctl(sk, msg, skb); + + if (flags & MSG_TRUNC) + copied = skb->len; +done: + skb_free_datagram(sk, skb); +out: + return err ? err : copied; +} + +static struct proto l2tp_ip6_prot = { + .name = "L2TP/IPv6", + .owner = THIS_MODULE, + .init = l2tp_ip6_open, + .close = l2tp_ip6_close, + .bind = l2tp_ip6_bind, + .connect = l2tp_ip6_connect, + .disconnect = l2tp_ip6_disconnect, + .ioctl = l2tp_ioctl, + .destroy = l2tp_ip6_destroy_sock, + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .sendmsg = l2tp_ip6_sendmsg, + .recvmsg = l2tp_ip6_recvmsg, + .backlog_rcv = l2tp_ip6_backlog_recv, + .hash = l2tp_ip6_hash, + .unhash = l2tp_ip6_unhash, + .obj_size = sizeof(struct l2tp_ip6_sock), +}; + +static const struct proto_ops l2tp_ip6_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = inet_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = l2tp_ip6_getname, + .poll = datagram_poll, + .ioctl = inet6_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static struct inet_protosw l2tp_ip6_protosw = { + .type = SOCK_DGRAM, + .protocol = IPPROTO_L2TP, + .prot = &l2tp_ip6_prot, + .ops = &l2tp_ip6_ops, +}; + +static struct inet6_protocol l2tp_ip6_protocol __read_mostly = { + .handler = l2tp_ip6_recv, +}; + +static int __init l2tp_ip6_init(void) +{ + int err; + + pr_info("L2TP IP encapsulation support for IPv6 (L2TPv3)\n"); + + err = proto_register(&l2tp_ip6_prot, 1); + if (err != 0) + goto out; + + err = inet6_add_protocol(&l2tp_ip6_protocol, IPPROTO_L2TP); + if (err) + goto out1; + + inet6_register_protosw(&l2tp_ip6_protosw); + return 0; + +out1: + proto_unregister(&l2tp_ip6_prot); +out: + return err; +} + +static void __exit l2tp_ip6_exit(void) +{ + inet6_unregister_protosw(&l2tp_ip6_protosw); + inet6_del_protocol(&l2tp_ip6_protocol, IPPROTO_L2TP); + proto_unregister(&l2tp_ip6_prot); +} + +module_init(l2tp_ip6_init); 
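For readers wiring userspace to the IP-encapsulation sockets registered above, a minimal sketch follows. It assumes the UAPI definitions from <linux/l2tp.h> (struct sockaddr_l2tpip6, IPPROTO_L2TP); the helper name open_l2tp_ip6_ctl, the addresses and the connection IDs are illustrative, not part of this patch. As l2tp_ip6_bind()/l2tp_ip6_connect() require, the socket is bound first (there is no autobind) and then connected; the bytes passed to send() are the raw L2TPv3 control message, since l2tp_ip6_sendmsg() itself prepends the 4-byte zero session ID and l2tp_ip6_recv() strips it before passing control frames up.

/* Hypothetical userspace helper: names, addresses and IDs are illustrative. */
#include <stdint.h>
#include <unistd.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <linux/l2tp.h>		/* struct sockaddr_l2tpip6, IPPROTO_L2TP */

static int open_l2tp_ip6_ctl(const char *laddr, uint32_t local_conn_id,
			     const char *raddr, uint32_t peer_conn_id)
{
	struct sockaddr_l2tpip6 local = { 0 }, peer = { 0 };
	int fd = socket(AF_INET6, SOCK_DGRAM, IPPROTO_L2TP);

	if (fd < 0)
		return -1;

	local.l2tp_family = AF_INET6;
	local.l2tp_conn_id = local_conn_id;	/* tunnel ID we answer to */
	inet_pton(AF_INET6, laddr, &local.l2tp_addr);
	/* No autobind: bind() must precede connect(), as l2tp_ip6_connect() enforces. */
	if (bind(fd, (struct sockaddr *)&local, sizeof(local)) < 0)
		goto fail;

	peer.l2tp_family = AF_INET6;
	peer.l2tp_conn_id = peer_conn_id;	/* peer's tunnel ID */
	inet_pton(AF_INET6, raddr, &peer.l2tp_addr);
	if (connect(fd, (struct sockaddr *)&peer, sizeof(peer)) < 0)
		goto fail;

	/* send()/recv() now carry raw L2TPv3 control messages */
	return fd;

fail:
	close(fd);
	return -1;
}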
+module_exit(l2tp_ip6_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Chris Elston "); +MODULE_DESCRIPTION("L2TP IP encapsulation for IPv6"); +MODULE_VERSION("1.0"); + +/* Use the values of SOCK_DGRAM (2) as type and IPPROTO_L2TP (115) as protocol, + * because __stringify doesn't like enums + */ +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_INET6, 115, 2); +MODULE_ALIAS_NET_PF_PROTO(PF_INET6, 115); diff --git a/net/l2tp/l2tp_netlink.c b/net/l2tp/l2tp_netlink.c new file mode 100644 index 000000000..a901fd14f --- /dev/null +++ b/net/l2tp/l2tp_netlink.c @@ -0,0 +1,1048 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* L2TP netlink layer, for management + * + * Copyright (c) 2008,2009,2010 Katalix Systems Ltd + * + * Partly based on the IrDA nelink implementation + * (see net/irda/irnetlink.c) which is: + * Copyright (c) 2007 Samuel Ortiz + * which is in turn partly based on the wireless netlink code: + * Copyright 2006 Johannes Berg + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "l2tp_core.h" + +static struct genl_family l2tp_nl_family; + +static const struct genl_multicast_group l2tp_multicast_group[] = { + { + .name = L2TP_GENL_MCGROUP, + }, +}; + +static int l2tp_nl_tunnel_send(struct sk_buff *skb, u32 portid, u32 seq, + int flags, struct l2tp_tunnel *tunnel, u8 cmd); +static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq, + int flags, struct l2tp_session *session, + u8 cmd); + +/* Accessed under genl lock */ +static const struct l2tp_nl_cmd_ops *l2tp_nl_cmd_ops[__L2TP_PWTYPE_MAX]; + +static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info) +{ + u32 tunnel_id; + u32 session_id; + char *ifname; + struct l2tp_tunnel *tunnel; + struct l2tp_session *session = NULL; + struct net *net = genl_info_net(info); + + if (info->attrs[L2TP_ATTR_IFNAME]) { + ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]); + session = l2tp_session_get_by_ifname(net, ifname); + } else if ((info->attrs[L2TP_ATTR_SESSION_ID]) && + (info->attrs[L2TP_ATTR_CONN_ID])) { + tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); + session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]); + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (tunnel) { + session = l2tp_tunnel_get_session(tunnel, session_id); + l2tp_tunnel_dec_refcount(tunnel); + } + } + + return session; +} + +static int l2tp_nl_cmd_noop(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + void *hdr; + int ret = -ENOBUFS; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto out; + } + + hdr = genlmsg_put(msg, info->snd_portid, info->snd_seq, + &l2tp_nl_family, 0, L2TP_CMD_NOOP); + if (!hdr) { + ret = -EMSGSIZE; + goto err_out; + } + + genlmsg_end(msg, hdr); + + return genlmsg_unicast(genl_info_net(info), msg, info->snd_portid); + +err_out: + nlmsg_free(msg); + +out: + return ret; +} + +static int l2tp_tunnel_notify(struct genl_family *family, + struct genl_info *info, + struct l2tp_tunnel *tunnel, + u8 cmd) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = l2tp_nl_tunnel_send(msg, info->snd_portid, info->snd_seq, + NLM_F_ACK, tunnel, cmd); + + if (ret >= 0) { + ret = genlmsg_multicast_allns(family, msg, 0, 0, GFP_ATOMIC); + /* We don't care if no one is listening */ + if (ret == -ESRCH) + ret = 0; + return ret; + } + + nlmsg_free(msg); + + return ret; +} + +static int 
l2tp_session_notify(struct genl_family *family, + struct genl_info *info, + struct l2tp_session *session, + u8 cmd) +{ + struct sk_buff *msg; + int ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + ret = l2tp_nl_session_send(msg, info->snd_portid, info->snd_seq, + NLM_F_ACK, session, cmd); + + if (ret >= 0) { + ret = genlmsg_multicast_allns(family, msg, 0, 0, GFP_ATOMIC); + /* We don't care if no one is listening */ + if (ret == -ESRCH) + ret = 0; + return ret; + } + + nlmsg_free(msg); + + return ret; +} + +static int l2tp_nl_cmd_tunnel_create_get_addr(struct nlattr **attrs, struct l2tp_tunnel_cfg *cfg) +{ + if (attrs[L2TP_ATTR_UDP_SPORT]) + cfg->local_udp_port = nla_get_u16(attrs[L2TP_ATTR_UDP_SPORT]); + if (attrs[L2TP_ATTR_UDP_DPORT]) + cfg->peer_udp_port = nla_get_u16(attrs[L2TP_ATTR_UDP_DPORT]); + cfg->use_udp_checksums = nla_get_flag(attrs[L2TP_ATTR_UDP_CSUM]); + + /* Must have either AF_INET or AF_INET6 address for source and destination */ +#if IS_ENABLED(CONFIG_IPV6) + if (attrs[L2TP_ATTR_IP6_SADDR] && attrs[L2TP_ATTR_IP6_DADDR]) { + cfg->local_ip6 = nla_data(attrs[L2TP_ATTR_IP6_SADDR]); + cfg->peer_ip6 = nla_data(attrs[L2TP_ATTR_IP6_DADDR]); + cfg->udp6_zero_tx_checksums = nla_get_flag(attrs[L2TP_ATTR_UDP_ZERO_CSUM6_TX]); + cfg->udp6_zero_rx_checksums = nla_get_flag(attrs[L2TP_ATTR_UDP_ZERO_CSUM6_RX]); + return 0; + } +#endif + if (attrs[L2TP_ATTR_IP_SADDR] && attrs[L2TP_ATTR_IP_DADDR]) { + cfg->local_ip.s_addr = nla_get_in_addr(attrs[L2TP_ATTR_IP_SADDR]); + cfg->peer_ip.s_addr = nla_get_in_addr(attrs[L2TP_ATTR_IP_DADDR]); + return 0; + } + return -EINVAL; +} + +static int l2tp_nl_cmd_tunnel_create(struct sk_buff *skb, struct genl_info *info) +{ + u32 tunnel_id; + u32 peer_tunnel_id; + int proto_version; + int fd = -1; + int ret = 0; + struct l2tp_tunnel_cfg cfg = { 0, }; + struct l2tp_tunnel *tunnel; + struct net *net = genl_info_net(info); + struct nlattr **attrs = info->attrs; + + if (!attrs[L2TP_ATTR_CONN_ID]) { + ret = -EINVAL; + goto out; + } + tunnel_id = nla_get_u32(attrs[L2TP_ATTR_CONN_ID]); + + if (!attrs[L2TP_ATTR_PEER_CONN_ID]) { + ret = -EINVAL; + goto out; + } + peer_tunnel_id = nla_get_u32(attrs[L2TP_ATTR_PEER_CONN_ID]); + + if (!attrs[L2TP_ATTR_PROTO_VERSION]) { + ret = -EINVAL; + goto out; + } + proto_version = nla_get_u8(attrs[L2TP_ATTR_PROTO_VERSION]); + + if (!attrs[L2TP_ATTR_ENCAP_TYPE]) { + ret = -EINVAL; + goto out; + } + cfg.encap = nla_get_u16(attrs[L2TP_ATTR_ENCAP_TYPE]); + + /* Managed tunnels take the tunnel socket from userspace. + * Unmanaged tunnels must call out the source and destination addresses + * for the kernel to create the tunnel socket itself. 
+ */ + if (attrs[L2TP_ATTR_FD]) { + fd = nla_get_u32(attrs[L2TP_ATTR_FD]); + } else { + ret = l2tp_nl_cmd_tunnel_create_get_addr(attrs, &cfg); + if (ret < 0) + goto out; + } + + ret = -EINVAL; + switch (cfg.encap) { + case L2TP_ENCAPTYPE_UDP: + case L2TP_ENCAPTYPE_IP: + ret = l2tp_tunnel_create(fd, proto_version, tunnel_id, + peer_tunnel_id, &cfg, &tunnel); + break; + } + + if (ret < 0) + goto out; + + l2tp_tunnel_inc_refcount(tunnel); + ret = l2tp_tunnel_register(tunnel, net, &cfg); + if (ret < 0) { + kfree(tunnel); + goto out; + } + ret = l2tp_tunnel_notify(&l2tp_nl_family, info, tunnel, + L2TP_CMD_TUNNEL_CREATE); + l2tp_tunnel_dec_refcount(tunnel); + +out: + return ret; +} + +static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info) +{ + struct l2tp_tunnel *tunnel; + u32 tunnel_id; + int ret = 0; + struct net *net = genl_info_net(info); + + if (!info->attrs[L2TP_ATTR_CONN_ID]) { + ret = -EINVAL; + goto out; + } + tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); + + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (!tunnel) { + ret = -ENODEV; + goto out; + } + + l2tp_tunnel_notify(&l2tp_nl_family, info, + tunnel, L2TP_CMD_TUNNEL_DELETE); + + l2tp_tunnel_delete(tunnel); + + l2tp_tunnel_dec_refcount(tunnel); + +out: + return ret; +} + +static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info) +{ + struct l2tp_tunnel *tunnel; + u32 tunnel_id; + int ret = 0; + struct net *net = genl_info_net(info); + + if (!info->attrs[L2TP_ATTR_CONN_ID]) { + ret = -EINVAL; + goto out; + } + tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); + + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (!tunnel) { + ret = -ENODEV; + goto out; + } + + ret = l2tp_tunnel_notify(&l2tp_nl_family, info, + tunnel, L2TP_CMD_TUNNEL_MODIFY); + + l2tp_tunnel_dec_refcount(tunnel); + +out: + return ret; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int l2tp_nl_tunnel_send_addr6(struct sk_buff *skb, struct sock *sk, + enum l2tp_encap_type encap) +{ + struct inet_sock *inet = inet_sk(sk); + struct ipv6_pinfo *np = inet6_sk(sk); + + switch (encap) { + case L2TP_ENCAPTYPE_UDP: + if (udp_get_no_check6_tx(sk) && + nla_put_flag(skb, L2TP_ATTR_UDP_ZERO_CSUM6_TX)) + return -1; + if (udp_get_no_check6_rx(sk) && + nla_put_flag(skb, L2TP_ATTR_UDP_ZERO_CSUM6_RX)) + return -1; + if (nla_put_u16(skb, L2TP_ATTR_UDP_SPORT, ntohs(inet->inet_sport)) || + nla_put_u16(skb, L2TP_ATTR_UDP_DPORT, ntohs(inet->inet_dport))) + return -1; + fallthrough; + case L2TP_ENCAPTYPE_IP: + if (nla_put_in6_addr(skb, L2TP_ATTR_IP6_SADDR, &np->saddr) || + nla_put_in6_addr(skb, L2TP_ATTR_IP6_DADDR, &sk->sk_v6_daddr)) + return -1; + break; + } + return 0; +} +#endif + +static int l2tp_nl_tunnel_send_addr4(struct sk_buff *skb, struct sock *sk, + enum l2tp_encap_type encap) +{ + struct inet_sock *inet = inet_sk(sk); + + switch (encap) { + case L2TP_ENCAPTYPE_UDP: + if (nla_put_u8(skb, L2TP_ATTR_UDP_CSUM, !sk->sk_no_check_tx) || + nla_put_u16(skb, L2TP_ATTR_UDP_SPORT, ntohs(inet->inet_sport)) || + nla_put_u16(skb, L2TP_ATTR_UDP_DPORT, ntohs(inet->inet_dport))) + return -1; + fallthrough; + case L2TP_ENCAPTYPE_IP: + if (nla_put_in_addr(skb, L2TP_ATTR_IP_SADDR, inet->inet_saddr) || + nla_put_in_addr(skb, L2TP_ATTR_IP_DADDR, inet->inet_daddr)) + return -1; + break; + } + + return 0; +} + +/* Append attributes for the tunnel address, handling the different attribute types + * used for different tunnel encapsulation and AF_INET v.s. AF_INET6. 
+ */ +static int l2tp_nl_tunnel_send_addr(struct sk_buff *skb, struct l2tp_tunnel *tunnel) +{ + struct sock *sk = tunnel->sock; + + if (!sk) + return 0; + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + return l2tp_nl_tunnel_send_addr6(skb, sk, tunnel->encap); +#endif + return l2tp_nl_tunnel_send_addr4(skb, sk, tunnel->encap); +} + +static int l2tp_nl_tunnel_send(struct sk_buff *skb, u32 portid, u32 seq, int flags, + struct l2tp_tunnel *tunnel, u8 cmd) +{ + void *hdr; + struct nlattr *nest; + + hdr = genlmsg_put(skb, portid, seq, &l2tp_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u8(skb, L2TP_ATTR_PROTO_VERSION, tunnel->version) || + nla_put_u32(skb, L2TP_ATTR_CONN_ID, tunnel->tunnel_id) || + nla_put_u32(skb, L2TP_ATTR_PEER_CONN_ID, tunnel->peer_tunnel_id) || + nla_put_u32(skb, L2TP_ATTR_DEBUG, 0) || + nla_put_u16(skb, L2TP_ATTR_ENCAP_TYPE, tunnel->encap)) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, L2TP_ATTR_STATS); + if (!nest) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, L2TP_ATTR_TX_PACKETS, + atomic_long_read(&tunnel->stats.tx_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_TX_BYTES, + atomic_long_read(&tunnel->stats.tx_bytes), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_TX_ERRORS, + atomic_long_read(&tunnel->stats.tx_errors), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_PACKETS, + atomic_long_read(&tunnel->stats.rx_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_BYTES, + atomic_long_read(&tunnel->stats.rx_bytes), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_SEQ_DISCARDS, + atomic_long_read(&tunnel->stats.rx_seq_discards), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_COOKIE_DISCARDS, + atomic_long_read(&tunnel->stats.rx_cookie_discards), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_OOS_PACKETS, + atomic_long_read(&tunnel->stats.rx_oos_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_ERRORS, + atomic_long_read(&tunnel->stats.rx_errors), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_INVALID, + atomic_long_read(&tunnel->stats.rx_invalid), + L2TP_ATTR_STATS_PAD)) + goto nla_put_failure; + nla_nest_end(skb, nest); + + if (l2tp_nl_tunnel_send_addr(skb, tunnel)) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -1; +} + +static int l2tp_nl_cmd_tunnel_get(struct sk_buff *skb, struct genl_info *info) +{ + struct l2tp_tunnel *tunnel; + struct sk_buff *msg; + u32 tunnel_id; + int ret = -ENOBUFS; + struct net *net = genl_info_net(info); + + if (!info->attrs[L2TP_ATTR_CONN_ID]) { + ret = -EINVAL; + goto err; + } + + tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto err; + } + + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (!tunnel) { + ret = -ENODEV; + goto err_nlmsg; + } + + ret = l2tp_nl_tunnel_send(msg, info->snd_portid, info->snd_seq, + NLM_F_ACK, tunnel, L2TP_CMD_TUNNEL_GET); + if (ret < 0) + goto err_nlmsg_tunnel; + + l2tp_tunnel_dec_refcount(tunnel); + + return genlmsg_unicast(net, msg, info->snd_portid); + +err_nlmsg_tunnel: + l2tp_tunnel_dec_refcount(tunnel); +err_nlmsg: + nlmsg_free(msg); +err: + return ret; +} + +static int l2tp_nl_cmd_tunnel_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int ti = cb->args[0]; + struct l2tp_tunnel *tunnel; + struct net 
*net = sock_net(skb->sk); + + for (;;) { + tunnel = l2tp_tunnel_get_nth(net, ti); + if (!tunnel) + goto out; + + if (l2tp_nl_tunnel_send(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + tunnel, L2TP_CMD_TUNNEL_GET) < 0) { + l2tp_tunnel_dec_refcount(tunnel); + goto out; + } + l2tp_tunnel_dec_refcount(tunnel); + + ti++; + } + +out: + cb->args[0] = ti; + + return skb->len; +} + +static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *info) +{ + u32 tunnel_id = 0; + u32 session_id; + u32 peer_session_id; + int ret = 0; + struct l2tp_tunnel *tunnel; + struct l2tp_session *session; + struct l2tp_session_cfg cfg = { 0, }; + struct net *net = genl_info_net(info); + + if (!info->attrs[L2TP_ATTR_CONN_ID]) { + ret = -EINVAL; + goto out; + } + + tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (!tunnel) { + ret = -ENODEV; + goto out; + } + + if (!info->attrs[L2TP_ATTR_SESSION_ID]) { + ret = -EINVAL; + goto out_tunnel; + } + session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]); + + if (!info->attrs[L2TP_ATTR_PEER_SESSION_ID]) { + ret = -EINVAL; + goto out_tunnel; + } + peer_session_id = nla_get_u32(info->attrs[L2TP_ATTR_PEER_SESSION_ID]); + + if (!info->attrs[L2TP_ATTR_PW_TYPE]) { + ret = -EINVAL; + goto out_tunnel; + } + cfg.pw_type = nla_get_u16(info->attrs[L2TP_ATTR_PW_TYPE]); + if (cfg.pw_type >= __L2TP_PWTYPE_MAX) { + ret = -EINVAL; + goto out_tunnel; + } + + /* L2TPv2 only accepts PPP pseudo-wires */ + if (tunnel->version == 2 && cfg.pw_type != L2TP_PWTYPE_PPP) { + ret = -EPROTONOSUPPORT; + goto out_tunnel; + } + + if (tunnel->version > 2) { + if (info->attrs[L2TP_ATTR_L2SPEC_TYPE]) { + cfg.l2specific_type = nla_get_u8(info->attrs[L2TP_ATTR_L2SPEC_TYPE]); + if (cfg.l2specific_type != L2TP_L2SPECTYPE_DEFAULT && + cfg.l2specific_type != L2TP_L2SPECTYPE_NONE) { + ret = -EINVAL; + goto out_tunnel; + } + } else { + cfg.l2specific_type = L2TP_L2SPECTYPE_DEFAULT; + } + + if (info->attrs[L2TP_ATTR_COOKIE]) { + u16 len = nla_len(info->attrs[L2TP_ATTR_COOKIE]); + + if (len > 8) { + ret = -EINVAL; + goto out_tunnel; + } + cfg.cookie_len = len; + memcpy(&cfg.cookie[0], nla_data(info->attrs[L2TP_ATTR_COOKIE]), len); + } + if (info->attrs[L2TP_ATTR_PEER_COOKIE]) { + u16 len = nla_len(info->attrs[L2TP_ATTR_PEER_COOKIE]); + + if (len > 8) { + ret = -EINVAL; + goto out_tunnel; + } + cfg.peer_cookie_len = len; + memcpy(&cfg.peer_cookie[0], nla_data(info->attrs[L2TP_ATTR_PEER_COOKIE]), len); + } + if (info->attrs[L2TP_ATTR_IFNAME]) + cfg.ifname = nla_data(info->attrs[L2TP_ATTR_IFNAME]); + } + + if (info->attrs[L2TP_ATTR_RECV_SEQ]) + cfg.recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]); + + if (info->attrs[L2TP_ATTR_SEND_SEQ]) + cfg.send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]); + + if (info->attrs[L2TP_ATTR_LNS_MODE]) + cfg.lns_mode = nla_get_u8(info->attrs[L2TP_ATTR_LNS_MODE]); + + if (info->attrs[L2TP_ATTR_RECV_TIMEOUT]) + cfg.reorder_timeout = nla_get_msecs(info->attrs[L2TP_ATTR_RECV_TIMEOUT]); + +#ifdef CONFIG_MODULES + if (!l2tp_nl_cmd_ops[cfg.pw_type]) { + genl_unlock(); + request_module("net-l2tp-type-%u", cfg.pw_type); + genl_lock(); + } +#endif + if (!l2tp_nl_cmd_ops[cfg.pw_type] || !l2tp_nl_cmd_ops[cfg.pw_type]->session_create) { + ret = -EPROTONOSUPPORT; + goto out_tunnel; + } + + ret = l2tp_nl_cmd_ops[cfg.pw_type]->session_create(net, tunnel, + session_id, + peer_session_id, + &cfg); + + if (ret >= 0) { + session = l2tp_tunnel_get_session(tunnel, session_id); + if (session) { + ret = 
l2tp_session_notify(&l2tp_nl_family, info, session, + L2TP_CMD_SESSION_CREATE); + l2tp_session_dec_refcount(session); + } + } + +out_tunnel: + l2tp_tunnel_dec_refcount(tunnel); +out: + return ret; +} + +static int l2tp_nl_cmd_session_delete(struct sk_buff *skb, struct genl_info *info) +{ + int ret = 0; + struct l2tp_session *session; + u16 pw_type; + + session = l2tp_nl_session_get(info); + if (!session) { + ret = -ENODEV; + goto out; + } + + l2tp_session_notify(&l2tp_nl_family, info, + session, L2TP_CMD_SESSION_DELETE); + + pw_type = session->pwtype; + if (pw_type < __L2TP_PWTYPE_MAX) + if (l2tp_nl_cmd_ops[pw_type] && l2tp_nl_cmd_ops[pw_type]->session_delete) + l2tp_nl_cmd_ops[pw_type]->session_delete(session); + + l2tp_session_dec_refcount(session); + +out: + return ret; +} + +static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *info) +{ + int ret = 0; + struct l2tp_session *session; + + session = l2tp_nl_session_get(info); + if (!session) { + ret = -ENODEV; + goto out; + } + + if (info->attrs[L2TP_ATTR_RECV_SEQ]) + session->recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]); + + if (info->attrs[L2TP_ATTR_SEND_SEQ]) { + session->send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]); + l2tp_session_set_header_len(session, session->tunnel->version); + } + + if (info->attrs[L2TP_ATTR_LNS_MODE]) + session->lns_mode = nla_get_u8(info->attrs[L2TP_ATTR_LNS_MODE]); + + if (info->attrs[L2TP_ATTR_RECV_TIMEOUT]) + session->reorder_timeout = nla_get_msecs(info->attrs[L2TP_ATTR_RECV_TIMEOUT]); + + ret = l2tp_session_notify(&l2tp_nl_family, info, + session, L2TP_CMD_SESSION_MODIFY); + + l2tp_session_dec_refcount(session); + +out: + return ret; +} + +static int l2tp_nl_session_send(struct sk_buff *skb, u32 portid, u32 seq, int flags, + struct l2tp_session *session, u8 cmd) +{ + void *hdr; + struct nlattr *nest; + struct l2tp_tunnel *tunnel = session->tunnel; + + hdr = genlmsg_put(skb, portid, seq, &l2tp_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u32(skb, L2TP_ATTR_CONN_ID, tunnel->tunnel_id) || + nla_put_u32(skb, L2TP_ATTR_SESSION_ID, session->session_id) || + nla_put_u32(skb, L2TP_ATTR_PEER_CONN_ID, tunnel->peer_tunnel_id) || + nla_put_u32(skb, L2TP_ATTR_PEER_SESSION_ID, session->peer_session_id) || + nla_put_u32(skb, L2TP_ATTR_DEBUG, 0) || + nla_put_u16(skb, L2TP_ATTR_PW_TYPE, session->pwtype)) + goto nla_put_failure; + + if ((session->ifname[0] && + nla_put_string(skb, L2TP_ATTR_IFNAME, session->ifname)) || + (session->cookie_len && + nla_put(skb, L2TP_ATTR_COOKIE, session->cookie_len, session->cookie)) || + (session->peer_cookie_len && + nla_put(skb, L2TP_ATTR_PEER_COOKIE, session->peer_cookie_len, session->peer_cookie)) || + nla_put_u8(skb, L2TP_ATTR_RECV_SEQ, session->recv_seq) || + nla_put_u8(skb, L2TP_ATTR_SEND_SEQ, session->send_seq) || + nla_put_u8(skb, L2TP_ATTR_LNS_MODE, session->lns_mode) || + (l2tp_tunnel_uses_xfrm(tunnel) && + nla_put_u8(skb, L2TP_ATTR_USING_IPSEC, 1)) || + (session->reorder_timeout && + nla_put_msecs(skb, L2TP_ATTR_RECV_TIMEOUT, + session->reorder_timeout, L2TP_ATTR_PAD))) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, L2TP_ATTR_STATS); + if (!nest) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, L2TP_ATTR_TX_PACKETS, + atomic_long_read(&session->stats.tx_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_TX_BYTES, + atomic_long_read(&session->stats.tx_bytes), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_TX_ERRORS, + 
atomic_long_read(&session->stats.tx_errors), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_PACKETS, + atomic_long_read(&session->stats.rx_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_BYTES, + atomic_long_read(&session->stats.rx_bytes), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_SEQ_DISCARDS, + atomic_long_read(&session->stats.rx_seq_discards), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_COOKIE_DISCARDS, + atomic_long_read(&session->stats.rx_cookie_discards), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_OOS_PACKETS, + atomic_long_read(&session->stats.rx_oos_packets), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_ERRORS, + atomic_long_read(&session->stats.rx_errors), + L2TP_ATTR_STATS_PAD) || + nla_put_u64_64bit(skb, L2TP_ATTR_RX_INVALID, + atomic_long_read(&session->stats.rx_invalid), + L2TP_ATTR_STATS_PAD)) + goto nla_put_failure; + nla_nest_end(skb, nest); + + genlmsg_end(skb, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(skb, hdr); + return -1; +} + +static int l2tp_nl_cmd_session_get(struct sk_buff *skb, struct genl_info *info) +{ + struct l2tp_session *session; + struct sk_buff *msg; + int ret; + + session = l2tp_nl_session_get(info); + if (!session) { + ret = -ENODEV; + goto err; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + ret = -ENOMEM; + goto err_ref; + } + + ret = l2tp_nl_session_send(msg, info->snd_portid, info->snd_seq, + 0, session, L2TP_CMD_SESSION_GET); + if (ret < 0) + goto err_ref_msg; + + ret = genlmsg_unicast(genl_info_net(info), msg, info->snd_portid); + + l2tp_session_dec_refcount(session); + + return ret; + +err_ref_msg: + nlmsg_free(msg); +err_ref: + l2tp_session_dec_refcount(session); +err: + return ret; +} + +static int l2tp_nl_cmd_session_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct l2tp_session *session; + struct l2tp_tunnel *tunnel = NULL; + int ti = cb->args[0]; + int si = cb->args[1]; + + for (;;) { + if (!tunnel) { + tunnel = l2tp_tunnel_get_nth(net, ti); + if (!tunnel) + goto out; + } + + session = l2tp_session_get_nth(tunnel, si); + if (!session) { + ti++; + l2tp_tunnel_dec_refcount(tunnel); + tunnel = NULL; + si = 0; + continue; + } + + if (l2tp_nl_session_send(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + session, L2TP_CMD_SESSION_GET) < 0) { + l2tp_session_dec_refcount(session); + l2tp_tunnel_dec_refcount(tunnel); + break; + } + l2tp_session_dec_refcount(session); + + si++; + } + +out: + cb->args[0] = ti; + cb->args[1] = si; + + return skb->len; +} + +static const struct nla_policy l2tp_nl_policy[L2TP_ATTR_MAX + 1] = { + [L2TP_ATTR_NONE] = { .type = NLA_UNSPEC, }, + [L2TP_ATTR_PW_TYPE] = { .type = NLA_U16, }, + [L2TP_ATTR_ENCAP_TYPE] = { .type = NLA_U16, }, + [L2TP_ATTR_OFFSET] = { .type = NLA_U16, }, + [L2TP_ATTR_DATA_SEQ] = { .type = NLA_U8, }, + [L2TP_ATTR_L2SPEC_TYPE] = { .type = NLA_U8, }, + [L2TP_ATTR_L2SPEC_LEN] = { .type = NLA_U8, }, + [L2TP_ATTR_PROTO_VERSION] = { .type = NLA_U8, }, + [L2TP_ATTR_CONN_ID] = { .type = NLA_U32, }, + [L2TP_ATTR_PEER_CONN_ID] = { .type = NLA_U32, }, + [L2TP_ATTR_SESSION_ID] = { .type = NLA_U32, }, + [L2TP_ATTR_PEER_SESSION_ID] = { .type = NLA_U32, }, + [L2TP_ATTR_UDP_CSUM] = { .type = NLA_U8, }, + [L2TP_ATTR_VLAN_ID] = { .type = NLA_U16, }, + [L2TP_ATTR_DEBUG] = { .type = NLA_U32, }, + [L2TP_ATTR_RECV_SEQ] = { .type = NLA_U8, }, + [L2TP_ATTR_SEND_SEQ] = { .type = NLA_U8, }, 
+ [L2TP_ATTR_LNS_MODE] = { .type = NLA_U8, }, + [L2TP_ATTR_USING_IPSEC] = { .type = NLA_U8, }, + [L2TP_ATTR_RECV_TIMEOUT] = { .type = NLA_MSECS, }, + [L2TP_ATTR_FD] = { .type = NLA_U32, }, + [L2TP_ATTR_IP_SADDR] = { .type = NLA_U32, }, + [L2TP_ATTR_IP_DADDR] = { .type = NLA_U32, }, + [L2TP_ATTR_UDP_SPORT] = { .type = NLA_U16, }, + [L2TP_ATTR_UDP_DPORT] = { .type = NLA_U16, }, + [L2TP_ATTR_MTU] = { .type = NLA_U16, }, + [L2TP_ATTR_MRU] = { .type = NLA_U16, }, + [L2TP_ATTR_STATS] = { .type = NLA_NESTED, }, + [L2TP_ATTR_IP6_SADDR] = { + .type = NLA_BINARY, + .len = sizeof(struct in6_addr), + }, + [L2TP_ATTR_IP6_DADDR] = { + .type = NLA_BINARY, + .len = sizeof(struct in6_addr), + }, + [L2TP_ATTR_IFNAME] = { + .type = NLA_NUL_STRING, + .len = IFNAMSIZ - 1, + }, + [L2TP_ATTR_COOKIE] = { + .type = NLA_BINARY, + .len = 8, + }, + [L2TP_ATTR_PEER_COOKIE] = { + .type = NLA_BINARY, + .len = 8, + }, +}; + +static const struct genl_small_ops l2tp_nl_ops[] = { + { + .cmd = L2TP_CMD_NOOP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_noop, + /* can be retrieved by unprivileged users */ + }, + { + .cmd = L2TP_CMD_TUNNEL_CREATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_tunnel_create, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_TUNNEL_DELETE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_tunnel_delete, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_TUNNEL_MODIFY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_tunnel_modify, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_TUNNEL_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_tunnel_get, + .dumpit = l2tp_nl_cmd_tunnel_dump, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_SESSION_CREATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_session_create, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_SESSION_DELETE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_session_delete, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_SESSION_MODIFY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_session_modify, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = L2TP_CMD_SESSION_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = l2tp_nl_cmd_session_get, + .dumpit = l2tp_nl_cmd_session_dump, + .flags = GENL_UNS_ADMIN_PERM, + }, +}; + +static struct genl_family l2tp_nl_family __ro_after_init = { + .name = L2TP_GENL_NAME, + .version = L2TP_GENL_VERSION, + .hdrsize = 0, + .maxattr = L2TP_ATTR_MAX, + .policy = l2tp_nl_policy, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = l2tp_nl_ops, + .n_small_ops = ARRAY_SIZE(l2tp_nl_ops), + .resv_start_op = L2TP_CMD_SESSION_GET + 1, + .mcgrps = l2tp_multicast_group, + .n_mcgrps = ARRAY_SIZE(l2tp_multicast_group), +}; + +int l2tp_nl_register_ops(enum l2tp_pwtype pw_type, const struct l2tp_nl_cmd_ops *ops) +{ + int ret; + + ret = -EINVAL; + if (pw_type >= __L2TP_PWTYPE_MAX) + goto err; + + genl_lock(); + ret = -EBUSY; + if (l2tp_nl_cmd_ops[pw_type]) + goto out; + + l2tp_nl_cmd_ops[pw_type] = ops; + ret = 0; + +out: + genl_unlock(); +err: + return ret; +} +EXPORT_SYMBOL_GPL(l2tp_nl_register_ops); + +void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type) +{ + if (pw_type < __L2TP_PWTYPE_MAX) { 
+ genl_lock(); + l2tp_nl_cmd_ops[pw_type] = NULL; + genl_unlock(); + } +} +EXPORT_SYMBOL_GPL(l2tp_nl_unregister_ops); + +static int __init l2tp_nl_init(void) +{ + pr_info("L2TP netlink interface\n"); + return genl_register_family(&l2tp_nl_family); +} + +static void l2tp_nl_cleanup(void) +{ + genl_unregister_family(&l2tp_nl_family); +} + +module_init(l2tp_nl_init); +module_exit(l2tp_nl_cleanup); + +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("L2TP netlink"); +MODULE_LICENSE("GPL"); +MODULE_VERSION("1.0"); +MODULE_ALIAS_GENL_FAMILY("l2tp"); diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c new file mode 100644 index 000000000..f011af660 --- /dev/null +++ b/net/l2tp/l2tp_ppp.c @@ -0,0 +1,1743 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/***************************************************************************** + * Linux PPP over L2TP (PPPoX/PPPoL2TP) Sockets + * + * PPPoX --- Generic PPP encapsulation socket family + * PPPoL2TP --- PPP over L2TP (RFC 2661) + * + * Version: 2.0.0 + * + * Authors: James Chapman (jchapman@katalix.com) + * + * Based on original work by Martijn van Oosterhout + * + * License: + */ + +/* This driver handles only L2TP data frames; control frames are handled by a + * userspace application. + * + * To send data in an L2TP session, userspace opens a PPPoL2TP socket and + * attaches it to a bound UDP socket with local tunnel_id / session_id and + * peer tunnel_id / session_id set. Data can then be sent or received using + * regular socket sendmsg() / recvmsg() calls. Kernel parameters of the socket + * can be read or modified using ioctl() or [gs]etsockopt() calls. + * + * When a PPPoL2TP socket is connected with local and peer session_id values + * zero, the socket is treated as a special tunnel management socket. + * + * Here's example userspace code to create a socket for sending/receiving data + * over an L2TP session:- + * + * struct sockaddr_pppol2tp sax; + * int fd; + * int session_fd; + * + * fd = socket(AF_PPPOX, SOCK_DGRAM, PX_PROTO_OL2TP); + * + * sax.sa_family = AF_PPPOX; + * sax.sa_protocol = PX_PROTO_OL2TP; + * sax.pppol2tp.fd = tunnel_fd; // bound UDP socket + * sax.pppol2tp.addr.sin_addr.s_addr = addr->sin_addr.s_addr; + * sax.pppol2tp.addr.sin_port = addr->sin_port; + * sax.pppol2tp.addr.sin_family = AF_INET; + * sax.pppol2tp.s_tunnel = tunnel_id; + * sax.pppol2tp.s_session = session_id; + * sax.pppol2tp.d_tunnel = peer_tunnel_id; + * sax.pppol2tp.d_session = peer_session_id; + * + * session_fd = connect(fd, (struct sockaddr *)&sax, sizeof(sax)); + * + * A pppd plugin that allows PPP traffic to be carried over L2TP using + * this driver is available from the OpenL2TP project at + * http://openl2tp.sourceforge.net. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "l2tp_core.h" + +#define PPPOL2TP_DRV_VERSION "V2.0" + +/* Space for UDP, L2TP and PPP headers */ +#define PPPOL2TP_HEADER_OVERHEAD 40 + +/* Number of bytes to build transmit L2TP headers. + * Unfortunately the size is different depending on whether sequence numbers + * are enabled. 
+ */ +#define PPPOL2TP_L2TP_HDR_SIZE_SEQ 10 +#define PPPOL2TP_L2TP_HDR_SIZE_NOSEQ 6 + +/* Private data of each session. This data lives at the end of struct + * l2tp_session, referenced via session->priv[]. + */ +struct pppol2tp_session { + int owner; /* pid that opened the socket */ + + struct mutex sk_lock; /* Protects .sk */ + struct sock __rcu *sk; /* Pointer to the session PPPoX socket */ + struct sock *__sk; /* Copy of .sk, for cleanup */ + struct rcu_head rcu; /* For asynchronous release */ +}; + +static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb); + +static const struct ppp_channel_ops pppol2tp_chan_ops = { + .start_xmit = pppol2tp_xmit, +}; + +static const struct proto_ops pppol2tp_ops; + +/* Retrieves the pppol2tp socket associated to a session. + * A reference is held on the returned socket, so this function must be paired + * with sock_put(). + */ +static struct sock *pppol2tp_session_get_sock(struct l2tp_session *session) +{ + struct pppol2tp_session *ps = l2tp_session_priv(session); + struct sock *sk; + + rcu_read_lock(); + sk = rcu_dereference(ps->sk); + if (sk) + sock_hold(sk); + rcu_read_unlock(); + + return sk; +} + +/* Helpers to obtain tunnel/session contexts from sockets. + */ +static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) +{ + struct l2tp_session *session; + + if (!sk) + return NULL; + + sock_hold(sk); + session = (struct l2tp_session *)(sk->sk_user_data); + if (!session) { + sock_put(sk); + goto out; + } + if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) { + session = NULL; + sock_put(sk); + goto out; + } + +out: + return session; +} + +/***************************************************************************** + * Receive data handling + *****************************************************************************/ + +/* Receive message. This is the recvmsg for the PPPoL2TP socket. + */ +static int pppol2tp_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + int err; + struct sk_buff *skb; + struct sock *sk = sock->sk; + + err = -EIO; + if (sk->sk_state & PPPOX_BOUND) + goto end; + + err = 0; + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) + goto end; + + if (len > skb->len) + len = skb->len; + else if (len < skb->len) + msg->msg_flags |= MSG_TRUNC; + + err = skb_copy_datagram_msg(skb, 0, msg, len); + if (likely(err == 0)) + err = len; + + kfree_skb(skb); +end: + return err; +} + +static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int data_len) +{ + struct pppol2tp_session *ps = l2tp_session_priv(session); + struct sock *sk = NULL; + + /* If the socket is bound, send it in to PPP's input queue. Otherwise + * queue it on the session socket. + */ + rcu_read_lock(); + sk = rcu_dereference(ps->sk); + if (!sk) + goto no_sock; + + /* If the first two bytes are 0xFF03, consider that it is the PPP's + * Address and Control fields and skip them. The L2TP module has always + * worked this way, although, in theory, the use of these fields should + * be negotiated and handled at the PPP layer. These fields are + * constant: 0xFF is the All-Stations Address and 0x03 the Unnumbered + * Information command with Poll/Final bit set to zero (RFC 1662). 
+ */ + if (pskb_may_pull(skb, 2) && skb->data[0] == PPP_ALLSTATIONS && + skb->data[1] == PPP_UI) + skb_pull(skb, 2); + + if (sk->sk_state & PPPOX_BOUND) { + struct pppox_sock *po; + + po = pppox_sk(sk); + ppp_input(&po->chan, skb); + } else { + if (sock_queue_rcv_skb(sk, skb) < 0) { + atomic_long_inc(&session->stats.rx_errors); + kfree_skb(skb); + } + } + rcu_read_unlock(); + + return; + +no_sock: + rcu_read_unlock(); + pr_warn_ratelimited("%s: no socket in recv\n", session->name); + kfree_skb(skb); +} + +/************************************************************************ + * Transmit handling + ***********************************************************************/ + +/* This is the sendmsg for the PPPoL2TP pppol2tp_session socket. We come here + * when a user application does a sendmsg() on the session socket. L2TP and + * PPP headers must be inserted into the user's data. + */ +static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m, + size_t total_len) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int error; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel; + int uhlen; + + error = -ENOTCONN; + if (sock_flag(sk, SOCK_DEAD) || !(sk->sk_state & PPPOX_CONNECTED)) + goto error; + + /* Get session and tunnel contexts */ + error = -EBADF; + session = pppol2tp_sock_to_session(sk); + if (!session) + goto error; + + tunnel = session->tunnel; + + uhlen = (tunnel->encap == L2TP_ENCAPTYPE_UDP) ? sizeof(struct udphdr) : 0; + + /* Allocate a socket buffer */ + error = -ENOMEM; + skb = sock_wmalloc(sk, NET_SKB_PAD + sizeof(struct iphdr) + + uhlen + session->hdr_len + + 2 + total_len, /* 2 bytes for PPP_ALLSTATIONS & PPP_UI */ + 0, GFP_KERNEL); + if (!skb) + goto error_put_sess; + + /* Reserve space for headers. */ + skb_reserve(skb, NET_SKB_PAD); + skb_reset_network_header(skb); + skb_reserve(skb, sizeof(struct iphdr)); + skb_reset_transport_header(skb); + skb_reserve(skb, uhlen); + + /* Add PPP header */ + skb->data[0] = PPP_ALLSTATIONS; + skb->data[1] = PPP_UI; + skb_put(skb, 2); + + /* Copy user data into skb */ + error = memcpy_from_msg(skb_put(skb, total_len), m, total_len); + if (error < 0) { + kfree_skb(skb); + goto error_put_sess; + } + + local_bh_disable(); + l2tp_xmit_skb(session, skb); + local_bh_enable(); + + sock_put(sk); + + return total_len; + +error_put_sess: + sock_put(sk); +error: + return error; +} + +/* Transmit function called by generic PPP driver. Sends PPP frame + * over PPPoL2TP socket. + * + * This is almost the same as pppol2tp_sendmsg(), but rather than + * being called with a msghdr from userspace, it is called with a skb + * from the kernel. + * + * The supplied skb from ppp doesn't have enough headroom for the + * insertion of L2TP, UDP and IP headers so we need to allocate more + * headroom in the skb. This will create a cloned skb. But we must be + * careful in the error case because the caller will expect to free + * the skb it supplied, not our cloned skb. So we take care to always + * leave the original skb unfreed if we return an error. 
+ */ +static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) +{ + struct sock *sk = (struct sock *)chan->private; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel; + int uhlen, headroom; + + if (sock_flag(sk, SOCK_DEAD) || !(sk->sk_state & PPPOX_CONNECTED)) + goto abort; + + /* Get session and tunnel contexts from the socket */ + session = pppol2tp_sock_to_session(sk); + if (!session) + goto abort; + + tunnel = session->tunnel; + + uhlen = (tunnel->encap == L2TP_ENCAPTYPE_UDP) ? sizeof(struct udphdr) : 0; + headroom = NET_SKB_PAD + + sizeof(struct iphdr) + /* IP header */ + uhlen + /* UDP header (if L2TP_ENCAPTYPE_UDP) */ + session->hdr_len + /* L2TP header */ + 2; /* 2 bytes for PPP_ALLSTATIONS & PPP_UI */ + if (skb_cow_head(skb, headroom)) + goto abort_put_sess; + + /* Setup PPP header */ + __skb_push(skb, 2); + skb->data[0] = PPP_ALLSTATIONS; + skb->data[1] = PPP_UI; + + local_bh_disable(); + l2tp_xmit_skb(session, skb); + local_bh_enable(); + + sock_put(sk); + + return 1; + +abort_put_sess: + sock_put(sk); +abort: + /* Free the original skb */ + kfree_skb(skb); + return 1; +} + +/***************************************************************************** + * Session (and tunnel control) socket create/destroy. + *****************************************************************************/ + +static void pppol2tp_put_sk(struct rcu_head *head) +{ + struct pppol2tp_session *ps; + + ps = container_of(head, typeof(*ps), rcu); + sock_put(ps->__sk); +} + +/* Really kill the session socket. (Called from sock_put() if + * refcnt == 0.) + */ +static void pppol2tp_session_destruct(struct sock *sk) +{ + struct l2tp_session *session = sk->sk_user_data; + + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); + + if (session) { + sk->sk_user_data = NULL; + if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) + return; + l2tp_session_dec_refcount(session); + } +} + +/* Called when the PPPoX socket (session) is closed. + */ +static int pppol2tp_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct l2tp_session *session; + int error; + + if (!sk) + return 0; + + error = -EBADF; + lock_sock(sk); + if (sock_flag(sk, SOCK_DEAD) != 0) + goto error; + + pppox_unbind_sock(sk); + + /* Signal the death of the socket. */ + sk->sk_state = PPPOX_DEAD; + sock_orphan(sk); + sock->sk = NULL; + + session = pppol2tp_sock_to_session(sk); + if (session) { + struct pppol2tp_session *ps; + + l2tp_session_delete(session); + + ps = l2tp_session_priv(session); + mutex_lock(&ps->sk_lock); + ps->__sk = rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)); + RCU_INIT_POINTER(ps->sk, NULL); + mutex_unlock(&ps->sk_lock); + call_rcu(&ps->rcu, pppol2tp_put_sk); + + /* Rely on the sock_put() call at the end of the function for + * dropping the reference held by pppol2tp_sock_to_session(). + * The last reference will be dropped by pppol2tp_put_sk(). + */ + } + + release_sock(sk); + + /* This will delete the session context via + * pppol2tp_session_destruct() if the socket's refcnt drops to + * zero. + */ + sock_put(sk); + + return 0; + +error: + release_sock(sk); + return error; +} + +static struct proto pppol2tp_sk_proto = { + .name = "PPPOL2TP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct pppox_sock), +}; + +static int pppol2tp_backlog_recv(struct sock *sk, struct sk_buff *skb) +{ + int rc; + + rc = l2tp_udp_encap_recv(sk, skb); + if (rc) + kfree_skb(skb); + + return NET_RX_SUCCESS; +} + +/* socket() handler. Initialize a new struct sock. 
+ */ +static int pppol2tp_create(struct net *net, struct socket *sock, int kern) +{ + int error = -ENOMEM; + struct sock *sk; + + sk = sk_alloc(net, PF_PPPOX, GFP_KERNEL, &pppol2tp_sk_proto, kern); + if (!sk) + goto out; + + sock_init_data(sock, sk); + + sock->state = SS_UNCONNECTED; + sock->ops = &pppol2tp_ops; + + sk->sk_backlog_rcv = pppol2tp_backlog_recv; + sk->sk_protocol = PX_PROTO_OL2TP; + sk->sk_family = PF_PPPOX; + sk->sk_state = PPPOX_NONE; + sk->sk_type = SOCK_STREAM; + sk->sk_destruct = pppol2tp_session_destruct; + + error = 0; + +out: + return error; +} + +static void pppol2tp_show(struct seq_file *m, void *arg) +{ + struct l2tp_session *session = arg; + struct sock *sk; + + sk = pppol2tp_session_get_sock(session); + if (sk) { + struct pppox_sock *po = pppox_sk(sk); + + seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan)); + sock_put(sk); + } +} + +static void pppol2tp_session_init(struct l2tp_session *session) +{ + struct pppol2tp_session *ps; + + session->recv_skb = pppol2tp_recv; + if (IS_ENABLED(CONFIG_L2TP_DEBUGFS)) + session->show = pppol2tp_show; + + ps = l2tp_session_priv(session); + mutex_init(&ps->sk_lock); + ps->owner = current->pid; +} + +struct l2tp_connect_info { + u8 version; + int fd; + u32 tunnel_id; + u32 peer_tunnel_id; + u32 session_id; + u32 peer_session_id; +}; + +static int pppol2tp_sockaddr_get_info(const void *sa, int sa_len, + struct l2tp_connect_info *info) +{ + switch (sa_len) { + case sizeof(struct sockaddr_pppol2tp): + { + const struct sockaddr_pppol2tp *sa_v2in4 = sa; + + if (sa_v2in4->sa_protocol != PX_PROTO_OL2TP) + return -EINVAL; + + info->version = 2; + info->fd = sa_v2in4->pppol2tp.fd; + info->tunnel_id = sa_v2in4->pppol2tp.s_tunnel; + info->peer_tunnel_id = sa_v2in4->pppol2tp.d_tunnel; + info->session_id = sa_v2in4->pppol2tp.s_session; + info->peer_session_id = sa_v2in4->pppol2tp.d_session; + + break; + } + case sizeof(struct sockaddr_pppol2tpv3): + { + const struct sockaddr_pppol2tpv3 *sa_v3in4 = sa; + + if (sa_v3in4->sa_protocol != PX_PROTO_OL2TP) + return -EINVAL; + + info->version = 3; + info->fd = sa_v3in4->pppol2tp.fd; + info->tunnel_id = sa_v3in4->pppol2tp.s_tunnel; + info->peer_tunnel_id = sa_v3in4->pppol2tp.d_tunnel; + info->session_id = sa_v3in4->pppol2tp.s_session; + info->peer_session_id = sa_v3in4->pppol2tp.d_session; + + break; + } + case sizeof(struct sockaddr_pppol2tpin6): + { + const struct sockaddr_pppol2tpin6 *sa_v2in6 = sa; + + if (sa_v2in6->sa_protocol != PX_PROTO_OL2TP) + return -EINVAL; + + info->version = 2; + info->fd = sa_v2in6->pppol2tp.fd; + info->tunnel_id = sa_v2in6->pppol2tp.s_tunnel; + info->peer_tunnel_id = sa_v2in6->pppol2tp.d_tunnel; + info->session_id = sa_v2in6->pppol2tp.s_session; + info->peer_session_id = sa_v2in6->pppol2tp.d_session; + + break; + } + case sizeof(struct sockaddr_pppol2tpv3in6): + { + const struct sockaddr_pppol2tpv3in6 *sa_v3in6 = sa; + + if (sa_v3in6->sa_protocol != PX_PROTO_OL2TP) + return -EINVAL; + + info->version = 3; + info->fd = sa_v3in6->pppol2tp.fd; + info->tunnel_id = sa_v3in6->pppol2tp.s_tunnel; + info->peer_tunnel_id = sa_v3in6->pppol2tp.d_tunnel; + info->session_id = sa_v3in6->pppol2tp.s_session; + info->peer_session_id = sa_v3in6->pppol2tp.d_session; + + break; + } + default: + return -EINVAL; + } + + return 0; +} + +/* Rough estimation of the maximum payload size a tunnel can transmit without + * fragmenting at the lower IP layer. Assumes L2TPv2 with sequence + * numbers and no IP option. Not quite accurate, but the result is mostly + * unused anyway. 
+ */ +static int pppol2tp_tunnel_mtu(const struct l2tp_tunnel *tunnel) +{ + int mtu; + + mtu = l2tp_tunnel_dst_mtu(tunnel); + if (mtu <= PPPOL2TP_HEADER_OVERHEAD) + return 1500 - PPPOL2TP_HEADER_OVERHEAD; + + return mtu - PPPOL2TP_HEADER_OVERHEAD; +} + +static struct l2tp_tunnel *pppol2tp_tunnel_get(struct net *net, + const struct l2tp_connect_info *info, + bool *new_tunnel) +{ + struct l2tp_tunnel *tunnel; + int error; + + *new_tunnel = false; + + tunnel = l2tp_tunnel_get(net, info->tunnel_id); + + /* Special case: create tunnel context if session_id and + * peer_session_id is 0. Otherwise look up tunnel using supplied + * tunnel id. + */ + if (!info->session_id && !info->peer_session_id) { + if (!tunnel) { + struct l2tp_tunnel_cfg tcfg = { + .encap = L2TP_ENCAPTYPE_UDP, + }; + + /* Prevent l2tp_tunnel_register() from trying to set up + * a kernel socket. + */ + if (info->fd < 0) + return ERR_PTR(-EBADF); + + error = l2tp_tunnel_create(info->fd, + info->version, + info->tunnel_id, + info->peer_tunnel_id, &tcfg, + &tunnel); + if (error < 0) + return ERR_PTR(error); + + l2tp_tunnel_inc_refcount(tunnel); + error = l2tp_tunnel_register(tunnel, net, &tcfg); + if (error < 0) { + kfree(tunnel); + return ERR_PTR(error); + } + + *new_tunnel = true; + } + } else { + /* Error if we can't find the tunnel */ + if (!tunnel) + return ERR_PTR(-ENOENT); + + /* Error if socket is not prepped */ + if (!tunnel->sock) { + l2tp_tunnel_dec_refcount(tunnel); + return ERR_PTR(-ENOENT); + } + } + + return tunnel; +} + +/* connect() handler. Attach a PPPoX socket to a tunnel UDP socket + */ +static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, + int sockaddr_len, int flags) +{ + struct sock *sk = sock->sk; + struct pppox_sock *po = pppox_sk(sk); + struct l2tp_session *session = NULL; + struct l2tp_connect_info info; + struct l2tp_tunnel *tunnel; + struct pppol2tp_session *ps; + struct l2tp_session_cfg cfg = { 0, }; + bool drop_refcnt = false; + bool new_session = false; + bool new_tunnel = false; + int error; + + error = pppol2tp_sockaddr_get_info(uservaddr, sockaddr_len, &info); + if (error < 0) + return error; + + /* Don't bind if tunnel_id is 0 */ + if (!info.tunnel_id) + return -EINVAL; + + tunnel = pppol2tp_tunnel_get(sock_net(sk), &info, &new_tunnel); + if (IS_ERR(tunnel)) + return PTR_ERR(tunnel); + + lock_sock(sk); + + /* Check for already bound sockets */ + error = -EBUSY; + if (sk->sk_state & PPPOX_CONNECTED) + goto end; + + /* We don't supporting rebinding anyway */ + error = -EALREADY; + if (sk->sk_user_data) + goto end; /* socket is already attached */ + + if (tunnel->peer_tunnel_id == 0) + tunnel->peer_tunnel_id = info.peer_tunnel_id; + + session = l2tp_tunnel_get_session(tunnel, info.session_id); + if (session) { + drop_refcnt = true; + + if (session->pwtype != L2TP_PWTYPE_PPP) { + error = -EPROTOTYPE; + goto end; + } + + ps = l2tp_session_priv(session); + + /* Using a pre-existing session is fine as long as it hasn't + * been connected yet. 
+ */ + mutex_lock(&ps->sk_lock); + if (rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)) || + ps->__sk) { + mutex_unlock(&ps->sk_lock); + error = -EEXIST; + goto end; + } + } else { + cfg.pw_type = L2TP_PWTYPE_PPP; + + session = l2tp_session_create(sizeof(struct pppol2tp_session), + tunnel, info.session_id, + info.peer_session_id, &cfg); + if (IS_ERR(session)) { + error = PTR_ERR(session); + goto end; + } + + pppol2tp_session_init(session); + ps = l2tp_session_priv(session); + l2tp_session_inc_refcount(session); + + mutex_lock(&ps->sk_lock); + error = l2tp_session_register(session, tunnel); + if (error < 0) { + mutex_unlock(&ps->sk_lock); + kfree(session); + goto end; + } + drop_refcnt = true; + new_session = true; + } + + /* Special case: if source & dest session_id == 0x0000, this + * socket is being created to manage the tunnel. Just set up + * the internal context for use by ioctl() and sockopt() + * handlers. + */ + if (session->session_id == 0 && session->peer_session_id == 0) { + error = 0; + goto out_no_ppp; + } + + /* The only header we need to worry about is the L2TP + * header. This size is different depending on whether + * sequence numbers are enabled for the data channel. + */ + po->chan.hdrlen = PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; + + po->chan.private = sk; + po->chan.ops = &pppol2tp_chan_ops; + po->chan.mtu = pppol2tp_tunnel_mtu(tunnel); + + error = ppp_register_net_channel(sock_net(sk), &po->chan); + if (error) { + mutex_unlock(&ps->sk_lock); + goto end; + } + +out_no_ppp: + /* This is how we get the session context from the socket. */ + sk->sk_user_data = session; + rcu_assign_pointer(ps->sk, sk); + mutex_unlock(&ps->sk_lock); + + /* Keep the reference we've grabbed on the session: sk doesn't expect + * the session to disappear. pppol2tp_session_destruct() is responsible + * for dropping it. + */ + drop_refcnt = false; + + sk->sk_state = PPPOX_CONNECTED; + +end: + if (error) { + if (new_session) + l2tp_session_delete(session); + if (new_tunnel) + l2tp_tunnel_delete(tunnel); + } + if (drop_refcnt) + l2tp_session_dec_refcount(session); + l2tp_tunnel_dec_refcount(tunnel); + release_sock(sk); + + return error; +} + +#ifdef CONFIG_L2TP_V3 + +/* Called when creating sessions via the netlink interface. */ +static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel, + u32 session_id, u32 peer_session_id, + struct l2tp_session_cfg *cfg) +{ + int error; + struct l2tp_session *session; + + /* Error if tunnel socket is not prepped */ + if (!tunnel->sock) { + error = -ENOENT; + goto err; + } + + /* Allocate and initialize a new session context. */ + session = l2tp_session_create(sizeof(struct pppol2tp_session), + tunnel, session_id, + peer_session_id, cfg); + if (IS_ERR(session)) { + error = PTR_ERR(session); + goto err; + } + + pppol2tp_session_init(session); + + error = l2tp_session_register(session, tunnel); + if (error < 0) + goto err_sess; + + return 0; + +err_sess: + kfree(session); +err: + return error; +} + +#endif /* CONFIG_L2TP_V3 */ + +/* getname() support. 
+ */ +static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + int len = 0; + int error = 0; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel; + struct sock *sk = sock->sk; + struct inet_sock *inet; + struct pppol2tp_session *pls; + + error = -ENOTCONN; + if (!sk) + goto end; + if (!(sk->sk_state & PPPOX_CONNECTED)) + goto end; + + error = -EBADF; + session = pppol2tp_sock_to_session(sk); + if (!session) + goto end; + + pls = l2tp_session_priv(session); + tunnel = session->tunnel; + + inet = inet_sk(tunnel->sock); + if (tunnel->version == 2 && tunnel->sock->sk_family == AF_INET) { + struct sockaddr_pppol2tp sp; + + len = sizeof(sp); + memset(&sp, 0, len); + sp.sa_family = AF_PPPOX; + sp.sa_protocol = PX_PROTO_OL2TP; + sp.pppol2tp.fd = tunnel->fd; + sp.pppol2tp.pid = pls->owner; + sp.pppol2tp.s_tunnel = tunnel->tunnel_id; + sp.pppol2tp.d_tunnel = tunnel->peer_tunnel_id; + sp.pppol2tp.s_session = session->session_id; + sp.pppol2tp.d_session = session->peer_session_id; + sp.pppol2tp.addr.sin_family = AF_INET; + sp.pppol2tp.addr.sin_port = inet->inet_dport; + sp.pppol2tp.addr.sin_addr.s_addr = inet->inet_daddr; + memcpy(uaddr, &sp, len); +#if IS_ENABLED(CONFIG_IPV6) + } else if (tunnel->version == 2 && tunnel->sock->sk_family == AF_INET6) { + struct sockaddr_pppol2tpin6 sp; + + len = sizeof(sp); + memset(&sp, 0, len); + sp.sa_family = AF_PPPOX; + sp.sa_protocol = PX_PROTO_OL2TP; + sp.pppol2tp.fd = tunnel->fd; + sp.pppol2tp.pid = pls->owner; + sp.pppol2tp.s_tunnel = tunnel->tunnel_id; + sp.pppol2tp.d_tunnel = tunnel->peer_tunnel_id; + sp.pppol2tp.s_session = session->session_id; + sp.pppol2tp.d_session = session->peer_session_id; + sp.pppol2tp.addr.sin6_family = AF_INET6; + sp.pppol2tp.addr.sin6_port = inet->inet_dport; + memcpy(&sp.pppol2tp.addr.sin6_addr, &tunnel->sock->sk_v6_daddr, + sizeof(tunnel->sock->sk_v6_daddr)); + memcpy(uaddr, &sp, len); + } else if (tunnel->version == 3 && tunnel->sock->sk_family == AF_INET6) { + struct sockaddr_pppol2tpv3in6 sp; + + len = sizeof(sp); + memset(&sp, 0, len); + sp.sa_family = AF_PPPOX; + sp.sa_protocol = PX_PROTO_OL2TP; + sp.pppol2tp.fd = tunnel->fd; + sp.pppol2tp.pid = pls->owner; + sp.pppol2tp.s_tunnel = tunnel->tunnel_id; + sp.pppol2tp.d_tunnel = tunnel->peer_tunnel_id; + sp.pppol2tp.s_session = session->session_id; + sp.pppol2tp.d_session = session->peer_session_id; + sp.pppol2tp.addr.sin6_family = AF_INET6; + sp.pppol2tp.addr.sin6_port = inet->inet_dport; + memcpy(&sp.pppol2tp.addr.sin6_addr, &tunnel->sock->sk_v6_daddr, + sizeof(tunnel->sock->sk_v6_daddr)); + memcpy(uaddr, &sp, len); +#endif + } else if (tunnel->version == 3) { + struct sockaddr_pppol2tpv3 sp; + + len = sizeof(sp); + memset(&sp, 0, len); + sp.sa_family = AF_PPPOX; + sp.sa_protocol = PX_PROTO_OL2TP; + sp.pppol2tp.fd = tunnel->fd; + sp.pppol2tp.pid = pls->owner; + sp.pppol2tp.s_tunnel = tunnel->tunnel_id; + sp.pppol2tp.d_tunnel = tunnel->peer_tunnel_id; + sp.pppol2tp.s_session = session->session_id; + sp.pppol2tp.d_session = session->peer_session_id; + sp.pppol2tp.addr.sin_family = AF_INET; + sp.pppol2tp.addr.sin_port = inet->inet_dport; + sp.pppol2tp.addr.sin_addr.s_addr = inet->inet_daddr; + memcpy(uaddr, &sp, len); + } + + error = len; + + sock_put(sk); +end: + return error; +} + +/**************************************************************************** + * ioctl() handlers. + * + * The PPPoX socket is created for L2TP sessions: tunnels have their own UDP + * sockets. 
However, in order to control kernel tunnel features, we allow + * userspace to create a special "tunnel" PPPoX socket which is used for + * control only. Tunnel PPPoX sockets have session_id == 0 and simply allow + * the user application to issue L2TP setsockopt(), getsockopt() and ioctl() + * calls. + ****************************************************************************/ + +static void pppol2tp_copy_stats(struct pppol2tp_ioc_stats *dest, + const struct l2tp_stats *stats) +{ + memset(dest, 0, sizeof(*dest)); + + dest->tx_packets = atomic_long_read(&stats->tx_packets); + dest->tx_bytes = atomic_long_read(&stats->tx_bytes); + dest->tx_errors = atomic_long_read(&stats->tx_errors); + dest->rx_packets = atomic_long_read(&stats->rx_packets); + dest->rx_bytes = atomic_long_read(&stats->rx_bytes); + dest->rx_seq_discards = atomic_long_read(&stats->rx_seq_discards); + dest->rx_oos_packets = atomic_long_read(&stats->rx_oos_packets); + dest->rx_errors = atomic_long_read(&stats->rx_errors); +} + +static int pppol2tp_tunnel_copy_stats(struct pppol2tp_ioc_stats *stats, + struct l2tp_tunnel *tunnel) +{ + struct l2tp_session *session; + + if (!stats->session_id) { + pppol2tp_copy_stats(stats, &tunnel->stats); + return 0; + } + + /* If session_id is set, search the corresponding session in the + * context of this tunnel and record the session's statistics. + */ + session = l2tp_tunnel_get_session(tunnel, stats->session_id); + if (!session) + return -EBADR; + + if (session->pwtype != L2TP_PWTYPE_PPP) { + l2tp_session_dec_refcount(session); + return -EBADR; + } + + pppol2tp_copy_stats(stats, &session->stats); + l2tp_session_dec_refcount(session); + + return 0; +} + +static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct pppol2tp_ioc_stats stats; + struct l2tp_session *session; + + switch (cmd) { + case PPPIOCGMRU: + case PPPIOCGFLAGS: + session = sock->sk->sk_user_data; + if (!session) + return -ENOTCONN; + + if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) + return -EBADF; + + /* Not defined for tunnels */ + if (!session->session_id && !session->peer_session_id) + return -ENOSYS; + + if (put_user(0, (int __user *)arg)) + return -EFAULT; + break; + + case PPPIOCSMRU: + case PPPIOCSFLAGS: + session = sock->sk->sk_user_data; + if (!session) + return -ENOTCONN; + + if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) + return -EBADF; + + /* Not defined for tunnels */ + if (!session->session_id && !session->peer_session_id) + return -ENOSYS; + + if (!access_ok((int __user *)arg, sizeof(int))) + return -EFAULT; + break; + + case PPPIOCGL2TPSTATS: + session = sock->sk->sk_user_data; + if (!session) + return -ENOTCONN; + + if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) + return -EBADF; + + /* Session 0 represents the parent tunnel */ + if (!session->session_id && !session->peer_session_id) { + u32 session_id; + int err; + + if (copy_from_user(&stats, (void __user *)arg, + sizeof(stats))) + return -EFAULT; + + session_id = stats.session_id; + err = pppol2tp_tunnel_copy_stats(&stats, + session->tunnel); + if (err < 0) + return err; + + stats.session_id = session_id; + } else { + pppol2tp_copy_stats(&stats, &session->stats); + stats.session_id = session->session_id; + } + stats.tunnel_id = session->tunnel->tunnel_id; + stats.using_ipsec = l2tp_tunnel_uses_xfrm(session->tunnel); + + if (copy_to_user((void __user *)arg, &stats, sizeof(stats))) + return -EFAULT; + break; + + default: + return -ENOIOCTLCMD; + } + + return 0; +} + 
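+/* Illustrative only (not part of the original driver): a hedged userspace
+ * sketch of how the PPPIOCGL2TPSTATS handler above might be exercised.
+ * It assumes struct pppol2tp_ioc_stats and the ioctl number come from the
+ * uapi headers (linux/if_pppol2tp.h and linux/ppp-ioctl.h) and that
+ * session_fd is a connected PPPoL2TP socket, as in the example in the
+ * file header comment:
+ *
+ *	struct pppol2tp_ioc_stats stats;
+ *
+ *	memset(&stats, 0, sizeof(stats));
+ *	stats.session_id = 0;	// on a tunnel socket, 0 selects tunnel counters
+ *	if (ioctl(session_fd, PPPIOCGL2TPSTATS, &stats) == 0)
+ *		printf("tx %llu/%llu rx %llu/%llu pkts/bytes\n",
+ *		       (unsigned long long)stats.tx_packets,
+ *		       (unsigned long long)stats.tx_bytes,
+ *		       (unsigned long long)stats.rx_packets,
+ *		       (unsigned long long)stats.rx_bytes);
+ *
+ * On a tunnel management socket (local and peer session_id both zero) the
+ * handler returns the tunnel counters, or those of the session named in
+ * stats.session_id; on a regular session socket it returns that session's
+ * own counters and ignores the session_id supplied on input.
+ */
+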
+/***************************************************************************** + * setsockopt() / getsockopt() support. + * + * The PPPoX socket is created for L2TP sessions: tunnels have their own UDP + * sockets. In order to control kernel tunnel features, we allow userspace to + * create a special "tunnel" PPPoX socket which is used for control only. + * Tunnel PPPoX sockets have session_id == 0 and simply allow the user + * application to issue L2TP setsockopt(), getsockopt() and ioctl() calls. + *****************************************************************************/ + +/* Tunnel setsockopt() helper. + */ +static int pppol2tp_tunnel_setsockopt(struct sock *sk, + struct l2tp_tunnel *tunnel, + int optname, int val) +{ + int err = 0; + + switch (optname) { + case PPPOL2TP_SO_DEBUG: + /* Tunnel debug flags option is deprecated */ + break; + + default: + err = -ENOPROTOOPT; + break; + } + + return err; +} + +/* Session setsockopt helper. + */ +static int pppol2tp_session_setsockopt(struct sock *sk, + struct l2tp_session *session, + int optname, int val) +{ + int err = 0; + + switch (optname) { + case PPPOL2TP_SO_RECVSEQ: + if (val != 0 && val != 1) { + err = -EINVAL; + break; + } + session->recv_seq = !!val; + break; + + case PPPOL2TP_SO_SENDSEQ: + if (val != 0 && val != 1) { + err = -EINVAL; + break; + } + session->send_seq = !!val; + { + struct pppox_sock *po = pppox_sk(sk); + + po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ : + PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; + } + l2tp_session_set_header_len(session, session->tunnel->version); + break; + + case PPPOL2TP_SO_LNSMODE: + if (val != 0 && val != 1) { + err = -EINVAL; + break; + } + session->lns_mode = !!val; + break; + + case PPPOL2TP_SO_DEBUG: + /* Session debug flags option is deprecated */ + break; + + case PPPOL2TP_SO_REORDERTO: + session->reorder_timeout = msecs_to_jiffies(val); + break; + + default: + err = -ENOPROTOOPT; + break; + } + + return err; +} + +/* Main setsockopt() entry point. + * Does API checks, then calls either the tunnel or session setsockopt + * handler, according to whether the PPPoL2TP socket is a for a regular + * session or the special tunnel type. + */ +static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel; + int val; + int err; + + if (level != SOL_PPPOL2TP) + return -EINVAL; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + err = -ENOTCONN; + if (!sk->sk_user_data) + goto end; + + /* Get session context from the socket */ + err = -EBADF; + session = pppol2tp_sock_to_session(sk); + if (!session) + goto end; + + /* Special case: if session_id == 0x0000, treat as operation on tunnel + */ + if (session->session_id == 0 && session->peer_session_id == 0) { + tunnel = session->tunnel; + err = pppol2tp_tunnel_setsockopt(sk, tunnel, optname, val); + } else { + err = pppol2tp_session_setsockopt(sk, session, optname, val); + } + + sock_put(sk); +end: + return err; +} + +/* Tunnel getsockopt helper. Called with sock locked. + */ +static int pppol2tp_tunnel_getsockopt(struct sock *sk, + struct l2tp_tunnel *tunnel, + int optname, int *val) +{ + int err = 0; + + switch (optname) { + case PPPOL2TP_SO_DEBUG: + /* Tunnel debug flags option is deprecated */ + *val = 0; + break; + + default: + err = -ENOPROTOOPT; + break; + } + + return err; +} + +/* Session getsockopt helper. 
Called with sock locked. + */ +static int pppol2tp_session_getsockopt(struct sock *sk, + struct l2tp_session *session, + int optname, int *val) +{ + int err = 0; + + switch (optname) { + case PPPOL2TP_SO_RECVSEQ: + *val = session->recv_seq; + break; + + case PPPOL2TP_SO_SENDSEQ: + *val = session->send_seq; + break; + + case PPPOL2TP_SO_LNSMODE: + *val = session->lns_mode; + break; + + case PPPOL2TP_SO_DEBUG: + /* Session debug flags option is deprecated */ + *val = 0; + break; + + case PPPOL2TP_SO_REORDERTO: + *val = (int)jiffies_to_msecs(session->reorder_timeout); + break; + + default: + err = -ENOPROTOOPT; + } + + return err; +} + +/* Main getsockopt() entry point. + * Does API checks, then calls either the tunnel or session getsockopt + * handler, according to whether the PPPoX socket is a for a regular session + * or the special tunnel type. + */ +static int pppol2tp_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct l2tp_session *session; + struct l2tp_tunnel *tunnel; + int val, len; + int err; + + if (level != SOL_PPPOL2TP) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(int)); + + if (len < 0) + return -EINVAL; + + err = -ENOTCONN; + if (!sk->sk_user_data) + goto end; + + /* Get the session context */ + err = -EBADF; + session = pppol2tp_sock_to_session(sk); + if (!session) + goto end; + + /* Special case: if session_id == 0x0000, treat as operation on tunnel */ + if (session->session_id == 0 && session->peer_session_id == 0) { + tunnel = session->tunnel; + err = pppol2tp_tunnel_getsockopt(sk, tunnel, optname, &val); + if (err) + goto end_put_sess; + } else { + err = pppol2tp_session_getsockopt(sk, session, optname, &val); + if (err) + goto end_put_sess; + } + + err = -EFAULT; + if (put_user(len, optlen)) + goto end_put_sess; + + if (copy_to_user((void __user *)optval, &val, len)) + goto end_put_sess; + + err = 0; + +end_put_sess: + sock_put(sk); +end: + return err; +} + +/***************************************************************************** + * /proc filesystem for debug + * Since the original pppol2tp driver provided /proc/net/pppol2tp for + * L2TPv2, we dump only L2TPv2 tunnels and sessions here. 
+ *****************************************************************************/ + +static unsigned int pppol2tp_net_id; + +#ifdef CONFIG_PROC_FS + +struct pppol2tp_seq_data { + struct seq_net_private p; + int tunnel_idx; /* current tunnel */ + int session_idx; /* index of session within current tunnel */ + struct l2tp_tunnel *tunnel; + struct l2tp_session *session; /* NULL means get next tunnel */ +}; + +static void pppol2tp_next_tunnel(struct net *net, struct pppol2tp_seq_data *pd) +{ + /* Drop reference taken during previous invocation */ + if (pd->tunnel) + l2tp_tunnel_dec_refcount(pd->tunnel); + + for (;;) { + pd->tunnel = l2tp_tunnel_get_nth(net, pd->tunnel_idx); + pd->tunnel_idx++; + + /* Only accept L2TPv2 tunnels */ + if (!pd->tunnel || pd->tunnel->version == 2) + return; + + l2tp_tunnel_dec_refcount(pd->tunnel); + } +} + +static void pppol2tp_next_session(struct net *net, struct pppol2tp_seq_data *pd) +{ + /* Drop reference taken during previous invocation */ + if (pd->session) + l2tp_session_dec_refcount(pd->session); + + pd->session = l2tp_session_get_nth(pd->tunnel, pd->session_idx); + pd->session_idx++; + + if (!pd->session) { + pd->session_idx = 0; + pppol2tp_next_tunnel(net, pd); + } +} + +static void *pppol2tp_seq_start(struct seq_file *m, loff_t *offs) +{ + struct pppol2tp_seq_data *pd = SEQ_START_TOKEN; + loff_t pos = *offs; + struct net *net; + + if (!pos) + goto out; + + if (WARN_ON(!m->private)) { + pd = NULL; + goto out; + } + + pd = m->private; + net = seq_file_net(m); + + if (!pd->tunnel) + pppol2tp_next_tunnel(net, pd); + else + pppol2tp_next_session(net, pd); + + /* NULL tunnel and session indicates end of list */ + if (!pd->tunnel && !pd->session) + pd = NULL; + +out: + return pd; +} + +static void *pppol2tp_seq_next(struct seq_file *m, void *v, loff_t *pos) +{ + (*pos)++; + return NULL; +} + +static void pppol2tp_seq_stop(struct seq_file *p, void *v) +{ + struct pppol2tp_seq_data *pd = v; + + if (!pd || pd == SEQ_START_TOKEN) + return; + + /* Drop reference taken by last invocation of pppol2tp_next_session() + * or pppol2tp_next_tunnel(). + */ + if (pd->session) { + l2tp_session_dec_refcount(pd->session); + pd->session = NULL; + } + if (pd->tunnel) { + l2tp_tunnel_dec_refcount(pd->tunnel); + pd->tunnel = NULL; + } +} + +static void pppol2tp_seq_tunnel_show(struct seq_file *m, void *v) +{ + struct l2tp_tunnel *tunnel = v; + + seq_printf(m, "\nTUNNEL '%s', %c %d\n", + tunnel->name, + (tunnel == tunnel->sock->sk_user_data) ? 'Y' : 'N', + refcount_read(&tunnel->ref_count) - 1); + seq_printf(m, " %08x %ld/%ld/%ld %ld/%ld/%ld\n", + 0, + atomic_long_read(&tunnel->stats.tx_packets), + atomic_long_read(&tunnel->stats.tx_bytes), + atomic_long_read(&tunnel->stats.tx_errors), + atomic_long_read(&tunnel->stats.rx_packets), + atomic_long_read(&tunnel->stats.rx_bytes), + atomic_long_read(&tunnel->stats.rx_errors)); +} + +static void pppol2tp_seq_session_show(struct seq_file *m, void *v) +{ + struct l2tp_session *session = v; + struct l2tp_tunnel *tunnel = session->tunnel; + unsigned char state; + char user_data_ok; + struct sock *sk; + u32 ip = 0; + u16 port = 0; + + if (tunnel->sock) { + struct inet_sock *inet = inet_sk(tunnel->sock); + + ip = ntohl(inet->inet_saddr); + port = ntohs(inet->inet_sport); + } + + sk = pppol2tp_session_get_sock(session); + if (sk) { + state = sk->sk_state; + user_data_ok = (session == sk->sk_user_data) ? 
'Y' : 'N'; + } else { + state = 0; + user_data_ok = 'N'; + } + + seq_printf(m, " SESSION '%s' %08X/%d %04X/%04X -> %04X/%04X %d %c\n", + session->name, ip, port, + tunnel->tunnel_id, + session->session_id, + tunnel->peer_tunnel_id, + session->peer_session_id, + state, user_data_ok); + seq_printf(m, " 0/0/%c/%c/%s %08x %u\n", + session->recv_seq ? 'R' : '-', + session->send_seq ? 'S' : '-', + session->lns_mode ? "LNS" : "LAC", + 0, + jiffies_to_msecs(session->reorder_timeout)); + seq_printf(m, " %u/%u %ld/%ld/%ld %ld/%ld/%ld\n", + session->nr, session->ns, + atomic_long_read(&session->stats.tx_packets), + atomic_long_read(&session->stats.tx_bytes), + atomic_long_read(&session->stats.tx_errors), + atomic_long_read(&session->stats.rx_packets), + atomic_long_read(&session->stats.rx_bytes), + atomic_long_read(&session->stats.rx_errors)); + + if (sk) { + struct pppox_sock *po = pppox_sk(sk); + + seq_printf(m, " interface %s\n", ppp_dev_name(&po->chan)); + sock_put(sk); + } +} + +static int pppol2tp_seq_show(struct seq_file *m, void *v) +{ + struct pppol2tp_seq_data *pd = v; + + /* display header on line 1 */ + if (v == SEQ_START_TOKEN) { + seq_puts(m, "PPPoL2TP driver info, " PPPOL2TP_DRV_VERSION "\n"); + seq_puts(m, "TUNNEL name, user-data-ok session-count\n"); + seq_puts(m, " debug tx-pkts/bytes/errs rx-pkts/bytes/errs\n"); + seq_puts(m, " SESSION name, addr/port src-tid/sid dest-tid/sid state user-data-ok\n"); + seq_puts(m, " mtu/mru/rcvseq/sendseq/lns debug reorderto\n"); + seq_puts(m, " nr/ns tx-pkts/bytes/errs rx-pkts/bytes/errs\n"); + goto out; + } + + if (!pd->session) + pppol2tp_seq_tunnel_show(m, pd->tunnel); + else + pppol2tp_seq_session_show(m, pd->session); + +out: + return 0; +} + +static const struct seq_operations pppol2tp_seq_ops = { + .start = pppol2tp_seq_start, + .next = pppol2tp_seq_next, + .stop = pppol2tp_seq_stop, + .show = pppol2tp_seq_show, +}; +#endif /* CONFIG_PROC_FS */ + +/***************************************************************************** + * Network namespace + *****************************************************************************/ + +static __net_init int pppol2tp_init_net(struct net *net) +{ + struct proc_dir_entry *pde; + int err = 0; + + pde = proc_create_net("pppol2tp", 0444, net->proc_net, + &pppol2tp_seq_ops, sizeof(struct pppol2tp_seq_data)); + if (!pde) { + err = -ENOMEM; + goto out; + } + +out: + return err; +} + +static __net_exit void pppol2tp_exit_net(struct net *net) +{ + remove_proc_entry("pppol2tp", net->proc_net); +} + +static struct pernet_operations pppol2tp_net_ops = { + .init = pppol2tp_init_net, + .exit = pppol2tp_exit_net, + .id = &pppol2tp_net_id, +}; + +/***************************************************************************** + * Init and cleanup + *****************************************************************************/ + +static const struct proto_ops pppol2tp_ops = { + .family = AF_PPPOX, + .owner = THIS_MODULE, + .release = pppol2tp_release, + .bind = sock_no_bind, + .connect = pppol2tp_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = pppol2tp_getname, + .poll = datagram_poll, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = pppol2tp_setsockopt, + .getsockopt = pppol2tp_getsockopt, + .sendmsg = pppol2tp_sendmsg, + .recvmsg = pppol2tp_recvmsg, + .mmap = sock_no_mmap, + .ioctl = pppox_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = pppox_compat_ioctl, +#endif +}; + +static const struct pppox_proto pppol2tp_proto = { + .create = pppol2tp_create, + 
.ioctl = pppol2tp_ioctl, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_L2TP_V3 + +static const struct l2tp_nl_cmd_ops pppol2tp_nl_cmd_ops = { + .session_create = pppol2tp_session_create, + .session_delete = l2tp_session_delete, +}; + +#endif /* CONFIG_L2TP_V3 */ + +static int __init pppol2tp_init(void) +{ + int err; + + err = register_pernet_device(&pppol2tp_net_ops); + if (err) + goto out; + + err = proto_register(&pppol2tp_sk_proto, 0); + if (err) + goto out_unregister_pppol2tp_pernet; + + err = register_pppox_proto(PX_PROTO_OL2TP, &pppol2tp_proto); + if (err) + goto out_unregister_pppol2tp_proto; + +#ifdef CONFIG_L2TP_V3 + err = l2tp_nl_register_ops(L2TP_PWTYPE_PPP, &pppol2tp_nl_cmd_ops); + if (err) + goto out_unregister_pppox; +#endif + + pr_info("PPPoL2TP kernel driver, %s\n", PPPOL2TP_DRV_VERSION); + +out: + return err; + +#ifdef CONFIG_L2TP_V3 +out_unregister_pppox: + unregister_pppox_proto(PX_PROTO_OL2TP); +#endif +out_unregister_pppol2tp_proto: + proto_unregister(&pppol2tp_sk_proto); +out_unregister_pppol2tp_pernet: + unregister_pernet_device(&pppol2tp_net_ops); + goto out; +} + +static void __exit pppol2tp_exit(void) +{ +#ifdef CONFIG_L2TP_V3 + l2tp_nl_unregister_ops(L2TP_PWTYPE_PPP); +#endif + unregister_pppox_proto(PX_PROTO_OL2TP); + proto_unregister(&pppol2tp_sk_proto); + unregister_pernet_device(&pppol2tp_net_ops); +} + +module_init(pppol2tp_init); +module_exit(pppol2tp_exit); + +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("PPP over L2TP over UDP"); +MODULE_LICENSE("GPL"); +MODULE_VERSION(PPPOL2TP_DRV_VERSION); +MODULE_ALIAS_NET_PF_PROTO(PF_PPPOX, PX_PROTO_OL2TP); +MODULE_ALIAS_L2TP_PWTYPE(7); diff --git a/net/l2tp/trace.h b/net/l2tp/trace.h new file mode 100644 index 000000000..8596eaa12 --- /dev/null +++ b/net/l2tp/trace.h @@ -0,0 +1,211 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#undef TRACE_SYSTEM +#define TRACE_SYSTEM l2tp + +#if !defined(_TRACE_L2TP_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_L2TP_H + +#include +#include +#include "l2tp_core.h" + +#define encap_type_name(e) { L2TP_ENCAPTYPE_##e, #e } +#define show_encap_type_name(val) \ + __print_symbolic(val, \ + encap_type_name(UDP), \ + encap_type_name(IP)) + +#define pw_type_name(p) { L2TP_PWTYPE_##p, #p } +#define show_pw_type_name(val) \ + __print_symbolic(val, \ + pw_type_name(ETH_VLAN), \ + pw_type_name(ETH), \ + pw_type_name(PPP), \ + pw_type_name(PPP_AC), \ + pw_type_name(IP)) + +DECLARE_EVENT_CLASS(tunnel_only_evt, + TP_PROTO(struct l2tp_tunnel *tunnel), + TP_ARGS(tunnel), + TP_STRUCT__entry( + __array(char, name, L2TP_TUNNEL_NAME_MAX) + ), + TP_fast_assign( + memcpy(__entry->name, tunnel->name, L2TP_TUNNEL_NAME_MAX); + ), + TP_printk("%s", __entry->name) +); + +DECLARE_EVENT_CLASS(session_only_evt, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session), + TP_STRUCT__entry( + __array(char, name, L2TP_SESSION_NAME_MAX) + ), + TP_fast_assign( + memcpy(__entry->name, session->name, L2TP_SESSION_NAME_MAX); + ), + TP_printk("%s", __entry->name) +); + +TRACE_EVENT(register_tunnel, + TP_PROTO(struct l2tp_tunnel *tunnel), + TP_ARGS(tunnel), + TP_STRUCT__entry( + __array(char, name, L2TP_TUNNEL_NAME_MAX) + __field(int, fd) + __field(u32, tid) + __field(u32, ptid) + __field(int, version) + __field(enum l2tp_encap_type, encap) + ), + TP_fast_assign( + memcpy(__entry->name, tunnel->name, L2TP_TUNNEL_NAME_MAX); + __entry->fd = tunnel->fd; + __entry->tid = tunnel->tunnel_id; + __entry->ptid = tunnel->peer_tunnel_id; + __entry->version = tunnel->version; + __entry->encap = tunnel->encap; + ), 
+ TP_printk("%s: type=%s encap=%s version=L2TPv%d tid=%u ptid=%u fd=%d", + __entry->name, + __entry->fd > 0 ? "managed" : "unmanaged", + show_encap_type_name(__entry->encap), + __entry->version, + __entry->tid, + __entry->ptid, + __entry->fd) +); + +DEFINE_EVENT(tunnel_only_evt, delete_tunnel, + TP_PROTO(struct l2tp_tunnel *tunnel), + TP_ARGS(tunnel) +); + +DEFINE_EVENT(tunnel_only_evt, free_tunnel, + TP_PROTO(struct l2tp_tunnel *tunnel), + TP_ARGS(tunnel) +); + +TRACE_EVENT(register_session, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session), + TP_STRUCT__entry( + __array(char, name, L2TP_SESSION_NAME_MAX) + __field(u32, tid) + __field(u32, ptid) + __field(u32, sid) + __field(u32, psid) + __field(enum l2tp_pwtype, pwtype) + ), + TP_fast_assign( + memcpy(__entry->name, session->name, L2TP_SESSION_NAME_MAX); + __entry->tid = session->tunnel ? session->tunnel->tunnel_id : 0; + __entry->ptid = session->tunnel ? session->tunnel->peer_tunnel_id : 0; + __entry->sid = session->session_id; + __entry->psid = session->peer_session_id; + __entry->pwtype = session->pwtype; + ), + TP_printk("%s: pseudowire=%s sid=%u psid=%u tid=%u ptid=%u", + __entry->name, + show_pw_type_name(__entry->pwtype), + __entry->sid, + __entry->psid, + __entry->sid, + __entry->psid) +); + +DEFINE_EVENT(session_only_evt, delete_session, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DEFINE_EVENT(session_only_evt, free_session, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DEFINE_EVENT(session_only_evt, session_seqnum_lns_enable, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DEFINE_EVENT(session_only_evt, session_seqnum_lns_disable, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DECLARE_EVENT_CLASS(session_seqnum_evt, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session), + TP_STRUCT__entry( + __array(char, name, L2TP_SESSION_NAME_MAX) + __field(u32, ns) + __field(u32, nr) + ), + TP_fast_assign( + memcpy(__entry->name, session->name, L2TP_SESSION_NAME_MAX); + __entry->ns = session->ns; + __entry->nr = session->nr; + ), + TP_printk("%s: ns=%u nr=%u", + __entry->name, + __entry->ns, + __entry->nr) +); + +DEFINE_EVENT(session_seqnum_evt, session_seqnum_update, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DEFINE_EVENT(session_seqnum_evt, session_seqnum_reset, + TP_PROTO(struct l2tp_session *session), + TP_ARGS(session) +); + +DECLARE_EVENT_CLASS(session_pkt_discard_evt, + TP_PROTO(struct l2tp_session *session, u32 pkt_ns), + TP_ARGS(session, pkt_ns), + TP_STRUCT__entry( + __array(char, name, L2TP_SESSION_NAME_MAX) + __field(u32, pkt_ns) + __field(u32, my_nr) + __field(u32, reorder_q_len) + ), + TP_fast_assign( + memcpy(__entry->name, session->name, L2TP_SESSION_NAME_MAX); + __entry->pkt_ns = pkt_ns, + __entry->my_nr = session->nr; + __entry->reorder_q_len = skb_queue_len(&session->reorder_q); + ), + TP_printk("%s: pkt_ns=%u my_nr=%u reorder_q_len=%u", + __entry->name, + __entry->pkt_ns, + __entry->my_nr, + __entry->reorder_q_len) +); + +DEFINE_EVENT(session_pkt_discard_evt, session_pkt_expired, + TP_PROTO(struct l2tp_session *session, u32 pkt_ns), + TP_ARGS(session, pkt_ns) +); + +DEFINE_EVENT(session_pkt_discard_evt, session_pkt_outside_rx_window, + TP_PROTO(struct l2tp_session *session, u32 pkt_ns), + TP_ARGS(session, pkt_ns) +); + +DEFINE_EVENT(session_pkt_discard_evt, session_pkt_oos, + TP_PROTO(struct l2tp_session *session, u32 pkt_ns), + TP_ARGS(session, pkt_ns) +); + +#endif /* _TRACE_L2TP_H */ + +/* 
This part must be outside protection */ +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/l3mdev/Kconfig b/net/l3mdev/Kconfig new file mode 100644 index 000000000..2b2861e1f --- /dev/null +++ b/net/l3mdev/Kconfig @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Configuration for L3 master device support +# + +config NET_L3_MASTER_DEV + bool "L3 Master device support" + depends on INET || IPV6 + help + This module provides glue between core networking code and device + drivers to support L3 master devices like VRF. diff --git a/net/l3mdev/Makefile b/net/l3mdev/Makefile new file mode 100644 index 000000000..9e7da0acc --- /dev/null +++ b/net/l3mdev/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the L3 device API +# + +obj-y += l3mdev.o diff --git a/net/l3mdev/l3mdev.c b/net/l3mdev/l3mdev.c new file mode 100644 index 000000000..ca1091634 --- /dev/null +++ b/net/l3mdev/l3mdev.c @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/l3mdev/l3mdev.c - L3 master device implementation + * Copyright (c) 2015 Cumulus Networks + * Copyright (c) 2015 David Ahern + */ + +#include +#include +#include + +static DEFINE_SPINLOCK(l3mdev_lock); + +struct l3mdev_handler { + lookup_by_table_id_t dev_lookup; +}; + +static struct l3mdev_handler l3mdev_handlers[L3MDEV_TYPE_MAX + 1]; + +static int l3mdev_check_type(enum l3mdev_type l3type) +{ + if (l3type <= L3MDEV_TYPE_UNSPEC || l3type > L3MDEV_TYPE_MAX) + return -EINVAL; + + return 0; +} + +int l3mdev_table_lookup_register(enum l3mdev_type l3type, + lookup_by_table_id_t fn) +{ + struct l3mdev_handler *hdlr; + int res; + + res = l3mdev_check_type(l3type); + if (res) + return res; + + hdlr = &l3mdev_handlers[l3type]; + + spin_lock(&l3mdev_lock); + + if (hdlr->dev_lookup) { + res = -EBUSY; + goto unlock; + } + + hdlr->dev_lookup = fn; + res = 0; + +unlock: + spin_unlock(&l3mdev_lock); + + return res; +} +EXPORT_SYMBOL_GPL(l3mdev_table_lookup_register); + +void l3mdev_table_lookup_unregister(enum l3mdev_type l3type, + lookup_by_table_id_t fn) +{ + struct l3mdev_handler *hdlr; + + if (l3mdev_check_type(l3type)) + return; + + hdlr = &l3mdev_handlers[l3type]; + + spin_lock(&l3mdev_lock); + + if (hdlr->dev_lookup == fn) + hdlr->dev_lookup = NULL; + + spin_unlock(&l3mdev_lock); +} +EXPORT_SYMBOL_GPL(l3mdev_table_lookup_unregister); + +int l3mdev_ifindex_lookup_by_table_id(enum l3mdev_type l3type, + struct net *net, u32 table_id) +{ + lookup_by_table_id_t lookup; + struct l3mdev_handler *hdlr; + int ifindex = -EINVAL; + int res; + + res = l3mdev_check_type(l3type); + if (res) + return res; + + hdlr = &l3mdev_handlers[l3type]; + + spin_lock(&l3mdev_lock); + + lookup = hdlr->dev_lookup; + if (!lookup) + goto unlock; + + ifindex = lookup(net, table_id); + +unlock: + spin_unlock(&l3mdev_lock); + + return ifindex; +} +EXPORT_SYMBOL_GPL(l3mdev_ifindex_lookup_by_table_id); + +/** + * l3mdev_master_ifindex_rcu - get index of L3 master device + * @dev: targeted interface + */ + +int l3mdev_master_ifindex_rcu(const struct net_device *dev) +{ + int ifindex = 0; + + if (!dev) + return 0; + + if (netif_is_l3_master(dev)) { + ifindex = dev->ifindex; + } else if (netif_is_l3_slave(dev)) { + struct net_device *master; + struct net_device *_dev = (struct net_device *)dev; + + /* netdev_master_upper_dev_get_rcu calls + * list_first_or_null_rcu to walk the upper dev list. 
+ * list_first_or_null_rcu does not handle a const arg. We aren't + * making changes, just want the master device from that list so + * typecast to remove the const + */ + master = netdev_master_upper_dev_get_rcu(_dev); + if (master) + ifindex = master->ifindex; + } + + return ifindex; +} +EXPORT_SYMBOL_GPL(l3mdev_master_ifindex_rcu); + +/** + * l3mdev_master_upper_ifindex_by_index_rcu - get index of upper l3 master + * device + * @net: network namespace for device index lookup + * @ifindex: targeted interface + */ +int l3mdev_master_upper_ifindex_by_index_rcu(struct net *net, int ifindex) +{ + struct net_device *dev; + + dev = dev_get_by_index_rcu(net, ifindex); + while (dev && !netif_is_l3_master(dev)) + dev = netdev_master_upper_dev_get_rcu(dev); + + return dev ? dev->ifindex : 0; +} +EXPORT_SYMBOL_GPL(l3mdev_master_upper_ifindex_by_index_rcu); + +/** + * l3mdev_fib_table_rcu - get FIB table id associated with an L3 + * master interface + * @dev: targeted interface + */ + +u32 l3mdev_fib_table_rcu(const struct net_device *dev) +{ + u32 tb_id = 0; + + if (!dev) + return 0; + + if (netif_is_l3_master(dev)) { + if (dev->l3mdev_ops->l3mdev_fib_table) + tb_id = dev->l3mdev_ops->l3mdev_fib_table(dev); + } else if (netif_is_l3_slave(dev)) { + /* Users of netdev_master_upper_dev_get_rcu need non-const, + * but current inet_*type functions take a const + */ + struct net_device *_dev = (struct net_device *) dev; + const struct net_device *master; + + master = netdev_master_upper_dev_get_rcu(_dev); + if (master && + master->l3mdev_ops->l3mdev_fib_table) + tb_id = master->l3mdev_ops->l3mdev_fib_table(master); + } + + return tb_id; +} +EXPORT_SYMBOL_GPL(l3mdev_fib_table_rcu); + +u32 l3mdev_fib_table_by_index(struct net *net, int ifindex) +{ + struct net_device *dev; + u32 tb_id = 0; + + if (!ifindex) + return 0; + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, ifindex); + if (dev) + tb_id = l3mdev_fib_table_rcu(dev); + + rcu_read_unlock(); + + return tb_id; +} +EXPORT_SYMBOL_GPL(l3mdev_fib_table_by_index); + +/** + * l3mdev_link_scope_lookup - IPv6 route lookup based on flow for link + * local and multicast addresses + * @net: network namespace for device index lookup + * @fl6: IPv6 flow struct for lookup + * This function does not hold refcnt on the returned dst. + * Caller must hold rcu_read_lock(). 
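+ * Returns NULL when the flow has no output interface, when that
+ * interface is not an L3 master (or a slave of one), or when the
+ * master does not implement l3mdev_link_scope_lookup.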
+ */ + +struct dst_entry *l3mdev_link_scope_lookup(struct net *net, + struct flowi6 *fl6) +{ + struct dst_entry *dst = NULL; + struct net_device *dev; + + WARN_ON_ONCE(!rcu_read_lock_held()); + if (fl6->flowi6_oif) { + dev = dev_get_by_index_rcu(net, fl6->flowi6_oif); + if (dev && netif_is_l3_slave(dev)) + dev = netdev_master_upper_dev_get_rcu(dev); + + if (dev && netif_is_l3_master(dev) && + dev->l3mdev_ops->l3mdev_link_scope_lookup) + dst = dev->l3mdev_ops->l3mdev_link_scope_lookup(dev, fl6); + } + + return dst; +} +EXPORT_SYMBOL_GPL(l3mdev_link_scope_lookup); + +/** + * l3mdev_fib_rule_match - Determine if flowi references an + * L3 master device + * @net: network namespace for device index lookup + * @fl: flow struct + * @arg: store the table the rule matched with here + */ + +int l3mdev_fib_rule_match(struct net *net, struct flowi *fl, + struct fib_lookup_arg *arg) +{ + struct net_device *dev; + int rc = 0; + + /* update flow ensures flowi_l3mdev is set when relevant */ + if (!fl->flowi_l3mdev) + return 0; + + rcu_read_lock(); + + dev = dev_get_by_index_rcu(net, fl->flowi_l3mdev); + if (dev && netif_is_l3_master(dev) && + dev->l3mdev_ops->l3mdev_fib_table) { + arg->table = dev->l3mdev_ops->l3mdev_fib_table(dev); + rc = 1; + } + + rcu_read_unlock(); + + return rc; +} + +void l3mdev_update_flow(struct net *net, struct flowi *fl) +{ + struct net_device *dev; + + rcu_read_lock(); + + if (fl->flowi_oif) { + dev = dev_get_by_index_rcu(net, fl->flowi_oif); + if (dev) { + if (!fl->flowi_l3mdev) + fl->flowi_l3mdev = l3mdev_master_ifindex_rcu(dev); + + /* oif set to L3mdev directs lookup to its table; + * reset to avoid oif match in fib_lookup + */ + if (netif_is_l3_master(dev)) + fl->flowi_oif = 0; + goto out; + } + } + + if (fl->flowi_iif > LOOPBACK_IFINDEX && !fl->flowi_l3mdev) { + dev = dev_get_by_index_rcu(net, fl->flowi_iif); + if (dev) + fl->flowi_l3mdev = l3mdev_master_ifindex_rcu(dev); + } + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(l3mdev_update_flow); diff --git a/net/lapb/Kconfig b/net/lapb/Kconfig new file mode 100644 index 000000000..da87b47f0 --- /dev/null +++ b/net/lapb/Kconfig @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# LAPB Data Link Drive +# + +config LAPB + tristate "LAPB Data Link Driver" + help + Link Access Procedure, Balanced (LAPB) is the data link layer (i.e. + the lower) part of the X.25 protocol. It offers a reliable + connection service to exchange data frames with one other host, and + it is used to transport higher level protocols (mostly X.25 Packet + Layer, the higher part of X.25, but others are possible as well). + Usually, LAPB is used with specialized X.21 network cards, but Linux + currently supports LAPB only over Ethernet connections. If you want + to use LAPB connections over Ethernet, say Y here and to "LAPB over + Ethernet driver" below. Read + for technical + details. + + To compile this driver as a module, choose M here: the + module will be called lapb. If unsure, say N. diff --git a/net/lapb/Makefile b/net/lapb/Makefile new file mode 100644 index 000000000..7be91b4c0 --- /dev/null +++ b/net/lapb/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux LAPB layer. 
+# + +obj-$(CONFIG_LAPB) += lapb.o + +lapb-y := lapb_in.o lapb_out.o lapb_subr.o lapb_timer.o lapb_iface.o diff --git a/net/lapb/lapb_iface.c b/net/lapb/lapb_iface.c new file mode 100644 index 000000000..0971ca48b --- /dev/null +++ b/net/lapb/lapb_iface.c @@ -0,0 +1,553 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * LAPB release 002 + * + * This code REQUIRES 2.1.15 or higher/ NET3.038 + * + * History + * LAPB 001 Jonathan Naylor Started Coding + * LAPB 002 Jonathan Naylor New timer architecture. + * 2000-10-29 Henner Eisen lapb_data_indication() return status. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(lapb_list); +static DEFINE_RWLOCK(lapb_list_lock); + +/* + * Free an allocated lapb control block. + */ +static void lapb_free_cb(struct lapb_cb *lapb) +{ + kfree(lapb); +} + +static __inline__ void lapb_hold(struct lapb_cb *lapb) +{ + refcount_inc(&lapb->refcnt); +} + +static __inline__ void lapb_put(struct lapb_cb *lapb) +{ + if (refcount_dec_and_test(&lapb->refcnt)) + lapb_free_cb(lapb); +} + +/* + * Socket removal during an interrupt is now safe. + */ +static void __lapb_remove_cb(struct lapb_cb *lapb) +{ + if (lapb->node.next) { + list_del(&lapb->node); + lapb_put(lapb); + } +} + +/* + * Add a socket to the bound sockets list. + */ +static void __lapb_insert_cb(struct lapb_cb *lapb) +{ + list_add(&lapb->node, &lapb_list); + lapb_hold(lapb); +} + +static struct lapb_cb *__lapb_devtostruct(struct net_device *dev) +{ + struct lapb_cb *lapb, *use = NULL; + + list_for_each_entry(lapb, &lapb_list, node) { + if (lapb->dev == dev) { + use = lapb; + break; + } + } + + if (use) + lapb_hold(use); + + return use; +} + +static struct lapb_cb *lapb_devtostruct(struct net_device *dev) +{ + struct lapb_cb *rc; + + read_lock_bh(&lapb_list_lock); + rc = __lapb_devtostruct(dev); + read_unlock_bh(&lapb_list_lock); + + return rc; +} +/* + * Create an empty LAPB control block. 
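+ * The block is initialised with the protocol defaults (T1, T2, N2,
+ * window and mode), starts in LAPB_STATE_0 and holds one reference.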
+ */ +static struct lapb_cb *lapb_create_cb(void) +{ + struct lapb_cb *lapb = kzalloc(sizeof(*lapb), GFP_ATOMIC); + + if (!lapb) + goto out; + + skb_queue_head_init(&lapb->write_queue); + skb_queue_head_init(&lapb->ack_queue); + + timer_setup(&lapb->t1timer, NULL, 0); + timer_setup(&lapb->t2timer, NULL, 0); + lapb->t1timer_running = false; + lapb->t2timer_running = false; + + lapb->t1 = LAPB_DEFAULT_T1; + lapb->t2 = LAPB_DEFAULT_T2; + lapb->n2 = LAPB_DEFAULT_N2; + lapb->mode = LAPB_DEFAULT_MODE; + lapb->window = LAPB_DEFAULT_WINDOW; + lapb->state = LAPB_STATE_0; + + spin_lock_init(&lapb->lock); + refcount_set(&lapb->refcnt, 1); +out: + return lapb; +} + +int lapb_register(struct net_device *dev, + const struct lapb_register_struct *callbacks) +{ + struct lapb_cb *lapb; + int rc = LAPB_BADTOKEN; + + write_lock_bh(&lapb_list_lock); + + lapb = __lapb_devtostruct(dev); + if (lapb) { + lapb_put(lapb); + goto out; + } + + lapb = lapb_create_cb(); + rc = LAPB_NOMEM; + if (!lapb) + goto out; + + lapb->dev = dev; + lapb->callbacks = callbacks; + + __lapb_insert_cb(lapb); + + lapb_start_t1timer(lapb); + + rc = LAPB_OK; +out: + write_unlock_bh(&lapb_list_lock); + return rc; +} +EXPORT_SYMBOL(lapb_register); + +int lapb_unregister(struct net_device *dev) +{ + struct lapb_cb *lapb; + int rc = LAPB_BADTOKEN; + + write_lock_bh(&lapb_list_lock); + lapb = __lapb_devtostruct(dev); + if (!lapb) + goto out; + lapb_put(lapb); + + /* Wait for other refs to "lapb" to drop */ + while (refcount_read(&lapb->refcnt) > 2) + usleep_range(1, 10); + + spin_lock_bh(&lapb->lock); + + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + + lapb_clear_queues(lapb); + + spin_unlock_bh(&lapb->lock); + + /* Wait for running timers to stop */ + del_timer_sync(&lapb->t1timer); + del_timer_sync(&lapb->t2timer); + + __lapb_remove_cb(lapb); + + lapb_put(lapb); + rc = LAPB_OK; +out: + write_unlock_bh(&lapb_list_lock); + return rc; +} +EXPORT_SYMBOL(lapb_unregister); + +int lapb_getparms(struct net_device *dev, struct lapb_parms_struct *parms) +{ + int rc = LAPB_BADTOKEN; + struct lapb_cb *lapb = lapb_devtostruct(dev); + + if (!lapb) + goto out; + + spin_lock_bh(&lapb->lock); + + parms->t1 = lapb->t1 / HZ; + parms->t2 = lapb->t2 / HZ; + parms->n2 = lapb->n2; + parms->n2count = lapb->n2count; + parms->state = lapb->state; + parms->window = lapb->window; + parms->mode = lapb->mode; + + if (!timer_pending(&lapb->t1timer)) + parms->t1timer = 0; + else + parms->t1timer = (lapb->t1timer.expires - jiffies) / HZ; + + if (!timer_pending(&lapb->t2timer)) + parms->t2timer = 0; + else + parms->t2timer = (lapb->t2timer.expires - jiffies) / HZ; + + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); + rc = LAPB_OK; +out: + return rc; +} +EXPORT_SYMBOL(lapb_getparms); + +int lapb_setparms(struct net_device *dev, struct lapb_parms_struct *parms) +{ + int rc = LAPB_BADTOKEN; + struct lapb_cb *lapb = lapb_devtostruct(dev); + + if (!lapb) + goto out; + + spin_lock_bh(&lapb->lock); + + rc = LAPB_INVALUE; + if (parms->t1 < 1 || parms->t2 < 1 || parms->n2 < 1) + goto out_put; + + if (lapb->state == LAPB_STATE_0) { + if (parms->mode & LAPB_EXTENDED) { + if (parms->window < 1 || parms->window > 127) + goto out_put; + } else { + if (parms->window < 1 || parms->window > 7) + goto out_put; + } + lapb->mode = parms->mode; + lapb->window = parms->window; + } + + lapb->t1 = parms->t1 * HZ; + lapb->t2 = parms->t2 * HZ; + lapb->n2 = parms->n2; + + rc = LAPB_OK; +out_put: + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); +out: + return rc; +} 
+EXPORT_SYMBOL(lapb_setparms); + +int lapb_connect_request(struct net_device *dev) +{ + struct lapb_cb *lapb = lapb_devtostruct(dev); + int rc = LAPB_BADTOKEN; + + if (!lapb) + goto out; + + spin_lock_bh(&lapb->lock); + + rc = LAPB_OK; + if (lapb->state == LAPB_STATE_1) + goto out_put; + + rc = LAPB_CONNECTED; + if (lapb->state == LAPB_STATE_3 || lapb->state == LAPB_STATE_4) + goto out_put; + + lapb_establish_data_link(lapb); + + lapb_dbg(0, "(%p) S0 -> S1\n", lapb->dev); + lapb->state = LAPB_STATE_1; + + rc = LAPB_OK; +out_put: + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); +out: + return rc; +} +EXPORT_SYMBOL(lapb_connect_request); + +static int __lapb_disconnect_request(struct lapb_cb *lapb) +{ + switch (lapb->state) { + case LAPB_STATE_0: + return LAPB_NOTCONNECTED; + + case LAPB_STATE_1: + lapb_dbg(1, "(%p) S1 TX DISC(1)\n", lapb->dev); + lapb_dbg(0, "(%p) S1 -> S0\n", lapb->dev); + lapb_send_control(lapb, LAPB_DISC, LAPB_POLLON, LAPB_COMMAND); + lapb->state = LAPB_STATE_0; + lapb_start_t1timer(lapb); + return LAPB_NOTCONNECTED; + + case LAPB_STATE_2: + return LAPB_OK; + } + + lapb_clear_queues(lapb); + lapb->n2count = 0; + lapb_send_control(lapb, LAPB_DISC, LAPB_POLLON, LAPB_COMMAND); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_2; + + lapb_dbg(1, "(%p) S3 DISC(1)\n", lapb->dev); + lapb_dbg(0, "(%p) S3 -> S2\n", lapb->dev); + + return LAPB_OK; +} + +int lapb_disconnect_request(struct net_device *dev) +{ + struct lapb_cb *lapb = lapb_devtostruct(dev); + int rc = LAPB_BADTOKEN; + + if (!lapb) + goto out; + + spin_lock_bh(&lapb->lock); + + rc = __lapb_disconnect_request(lapb); + + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); +out: + return rc; +} +EXPORT_SYMBOL(lapb_disconnect_request); + +int lapb_data_request(struct net_device *dev, struct sk_buff *skb) +{ + struct lapb_cb *lapb = lapb_devtostruct(dev); + int rc = LAPB_BADTOKEN; + + if (!lapb) + goto out; + + spin_lock_bh(&lapb->lock); + + rc = LAPB_NOTCONNECTED; + if (lapb->state != LAPB_STATE_3 && lapb->state != LAPB_STATE_4) + goto out_put; + + skb_queue_tail(&lapb->write_queue, skb); + lapb_kick(lapb); + rc = LAPB_OK; +out_put: + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); +out: + return rc; +} +EXPORT_SYMBOL(lapb_data_request); + +int lapb_data_received(struct net_device *dev, struct sk_buff *skb) +{ + struct lapb_cb *lapb = lapb_devtostruct(dev); + int rc = LAPB_BADTOKEN; + + if (lapb) { + spin_lock_bh(&lapb->lock); + lapb_data_input(lapb, skb); + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); + rc = LAPB_OK; + } + + return rc; +} +EXPORT_SYMBOL(lapb_data_received); + +void lapb_connect_confirmation(struct lapb_cb *lapb, int reason) +{ + if (lapb->callbacks->connect_confirmation) + lapb->callbacks->connect_confirmation(lapb->dev, reason); +} + +void lapb_connect_indication(struct lapb_cb *lapb, int reason) +{ + if (lapb->callbacks->connect_indication) + lapb->callbacks->connect_indication(lapb->dev, reason); +} + +void lapb_disconnect_confirmation(struct lapb_cb *lapb, int reason) +{ + if (lapb->callbacks->disconnect_confirmation) + lapb->callbacks->disconnect_confirmation(lapb->dev, reason); +} + +void lapb_disconnect_indication(struct lapb_cb *lapb, int reason) +{ + if (lapb->callbacks->disconnect_indication) + lapb->callbacks->disconnect_indication(lapb->dev, reason); +} + +int lapb_data_indication(struct lapb_cb *lapb, struct sk_buff *skb) +{ + if (lapb->callbacks->data_indication) + return lapb->callbacks->data_indication(lapb->dev, skb); + + kfree_skb(skb); + return 
NET_RX_SUCCESS; /* For now; must be != NET_RX_DROP */ +} + +int lapb_data_transmit(struct lapb_cb *lapb, struct sk_buff *skb) +{ + int used = 0; + + if (lapb->callbacks->data_transmit) { + lapb->callbacks->data_transmit(lapb->dev, skb); + used = 1; + } + + return used; +} + +/* Handle device status changes. */ +static int lapb_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct lapb_cb *lapb; + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (dev->type != ARPHRD_X25) + return NOTIFY_DONE; + + lapb = lapb_devtostruct(dev); + if (!lapb) + return NOTIFY_DONE; + + spin_lock_bh(&lapb->lock); + + switch (event) { + case NETDEV_UP: + lapb_dbg(0, "(%p) Interface up: %s\n", dev, dev->name); + + if (netif_carrier_ok(dev)) { + lapb_dbg(0, "(%p): Carrier is already up: %s\n", dev, + dev->name); + if (lapb->mode & LAPB_DCE) { + lapb_start_t1timer(lapb); + } else { + if (lapb->state == LAPB_STATE_0) { + lapb->state = LAPB_STATE_1; + lapb_establish_data_link(lapb); + } + } + } + break; + case NETDEV_GOING_DOWN: + if (netif_carrier_ok(dev)) + __lapb_disconnect_request(lapb); + break; + case NETDEV_DOWN: + lapb_dbg(0, "(%p) Interface down: %s\n", dev, dev->name); + lapb_dbg(0, "(%p) S%d -> S0\n", dev, lapb->state); + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb->n2count = 0; + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + break; + case NETDEV_CHANGE: + if (netif_carrier_ok(dev)) { + lapb_dbg(0, "(%p): Carrier detected: %s\n", dev, + dev->name); + if (lapb->mode & LAPB_DCE) { + lapb_start_t1timer(lapb); + } else { + if (lapb->state == LAPB_STATE_0) { + lapb->state = LAPB_STATE_1; + lapb_establish_data_link(lapb); + } + } + } else { + lapb_dbg(0, "(%p) Carrier lost: %s\n", dev, dev->name); + lapb_dbg(0, "(%p) S%d -> S0\n", dev, lapb->state); + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb->n2count = 0; + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + } + break; + } + + spin_unlock_bh(&lapb->lock); + lapb_put(lapb); + return NOTIFY_DONE; +} + +static struct notifier_block lapb_dev_notifier = { + .notifier_call = lapb_device_event, +}; + +static int __init lapb_init(void) +{ + return register_netdevice_notifier(&lapb_dev_notifier); +} + +static void __exit lapb_exit(void) +{ + WARN_ON(!list_empty(&lapb_list)); + + unregister_netdevice_notifier(&lapb_dev_notifier); +} + +MODULE_AUTHOR("Jonathan Naylor "); +MODULE_DESCRIPTION("The X.25 Link Access Procedure B link layer protocol"); +MODULE_LICENSE("GPL"); + +module_init(lapb_init); +module_exit(lapb_exit); diff --git a/net/lapb/lapb_in.c b/net/lapb/lapb_in.c new file mode 100644 index 000000000..38ae23c09 --- /dev/null +++ b/net/lapb/lapb_in.c @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * LAPB release 002 + * + * This code REQUIRES 2.1.15 or higher/ NET3.038 + * + * History + * LAPB 001 Jonathan Naulor Started Coding + * LAPB 002 Jonathan Naylor New timer architecture. + * 2000-10-29 Henner Eisen lapb_data_indication() return status. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * State machine for state 0, Disconnected State. + * The handling of the timer(s) is in file lapb_timer.c. 
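+ * Only SABM, SABME and DISC frames are acted upon here; all other
+ * frame types are silently discarded.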
+ */ +static void lapb_state0_machine(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + switch (frame->type) { + case LAPB_SABM: + lapb_dbg(1, "(%p) S0 RX SABM(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S0 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } else { + lapb_dbg(1, "(%p) S0 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S0 -> S3\n", lapb->dev); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_3; + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_connect_indication(lapb, LAPB_OK); + } + break; + + case LAPB_SABME: + lapb_dbg(1, "(%p) S0 RX SABME(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S0 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S0 -> S3\n", lapb->dev); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_3; + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_connect_indication(lapb, LAPB_OK); + } else { + lapb_dbg(1, "(%p) S0 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } + break; + + case LAPB_DISC: + lapb_dbg(1, "(%p) S0 RX DISC(%d)\n", lapb->dev, frame->pf); + lapb_dbg(1, "(%p) S0 TX UA(%d)\n", lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, LAPB_RESPONSE); + break; + + default: + break; + } + + kfree_skb(skb); +} + +/* + * State machine for state 1, Awaiting Connection State. + * The handling of the timer(s) is in file lapb_timer.c. 
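+ * A UA response with the F bit set completes the connection; a DM
+ * response with the F bit set aborts it with LAPB_REFUSED.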
+ */ +static void lapb_state1_machine(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + switch (frame->type) { + case LAPB_SABM: + lapb_dbg(1, "(%p) S1 RX SABM(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S1 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } else { + lapb_dbg(1, "(%p) S1 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + } + break; + + case LAPB_SABME: + lapb_dbg(1, "(%p) S1 RX SABME(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S1 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + } else { + lapb_dbg(1, "(%p) S1 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } + break; + + case LAPB_DISC: + lapb_dbg(1, "(%p) S1 RX DISC(%d)\n", lapb->dev, frame->pf); + lapb_dbg(1, "(%p) S1 TX DM(%d)\n", lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, LAPB_RESPONSE); + break; + + case LAPB_UA: + lapb_dbg(1, "(%p) S1 RX UA(%d)\n", lapb->dev, frame->pf); + if (frame->pf) { + lapb_dbg(0, "(%p) S1 -> S3\n", lapb->dev); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_3; + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_connect_confirmation(lapb, LAPB_OK); + } + break; + + case LAPB_DM: + lapb_dbg(1, "(%p) S1 RX DM(%d)\n", lapb->dev, frame->pf); + if (frame->pf) { + lapb_dbg(0, "(%p) S1 -> S0\n", lapb->dev); + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb_disconnect_indication(lapb, LAPB_REFUSED); + } + break; + } + + kfree_skb(skb); +} + +/* + * State machine for state 2, Awaiting Release State. + * The handling of the timer(s) is in file lapb_timer.c + */ +static void lapb_state2_machine(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + switch (frame->type) { + case LAPB_SABM: + case LAPB_SABME: + lapb_dbg(1, "(%p) S2 RX {SABM,SABME}(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(1, "(%p) S2 TX DM(%d)\n", lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, LAPB_RESPONSE); + break; + + case LAPB_DISC: + lapb_dbg(1, "(%p) S2 RX DISC(%d)\n", lapb->dev, frame->pf); + lapb_dbg(1, "(%p) S2 TX UA(%d)\n", lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, LAPB_RESPONSE); + break; + + case LAPB_UA: + lapb_dbg(1, "(%p) S2 RX UA(%d)\n", lapb->dev, frame->pf); + if (frame->pf) { + lapb_dbg(0, "(%p) S2 -> S0\n", lapb->dev); + lapb->state = LAPB_STATE_0; + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb_disconnect_confirmation(lapb, LAPB_OK); + } + break; + + case LAPB_DM: + lapb_dbg(1, "(%p) S2 RX DM(%d)\n", lapb->dev, frame->pf); + if (frame->pf) { + lapb_dbg(0, "(%p) S2 -> S0\n", lapb->dev); + lapb->state = LAPB_STATE_0; + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb_disconnect_confirmation(lapb, LAPB_NOTCONNECTED); + } + break; + + case LAPB_I: + case LAPB_REJ: + case LAPB_RNR: + case LAPB_RR: + lapb_dbg(1, "(%p) S2 RX {I,REJ,RNR,RR}(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(1, "(%p) S2 RX DM(%d)\n", lapb->dev, frame->pf); + if (frame->pf) + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + break; + } + + kfree_skb(skb); +} + +/* + * State machine for state 3, Connected State. 
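+ * I frames are sequence-checked and passed up via lapb_data_indication();
+ * an invalid N(R) triggers an FRMR and a transition to state 4.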
+ * The handling of the timer(s) is in file lapb_timer.c + */ +static void lapb_state3_machine(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + int queued = 0; + int modulus = (lapb->mode & LAPB_EXTENDED) ? LAPB_EMODULUS : + LAPB_SMODULUS; + + switch (frame->type) { + case LAPB_SABM: + lapb_dbg(1, "(%p) S3 RX SABM(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S3 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } else { + lapb_dbg(1, "(%p) S3 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_requeue_frames(lapb); + } + break; + + case LAPB_SABME: + lapb_dbg(1, "(%p) S3 RX SABME(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S3 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_requeue_frames(lapb); + } else { + lapb_dbg(1, "(%p) S3 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } + break; + + case LAPB_DISC: + lapb_dbg(1, "(%p) S3 RX DISC(%d)\n", lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S3 -> S0\n", lapb->dev); + lapb_clear_queues(lapb); + lapb_send_control(lapb, LAPB_UA, frame->pf, LAPB_RESPONSE); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_0; + lapb_disconnect_indication(lapb, LAPB_OK); + break; + + case LAPB_DM: + lapb_dbg(1, "(%p) S3 RX DM(%d)\n", lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S3 -> S0\n", lapb->dev); + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb_disconnect_indication(lapb, LAPB_NOTCONNECTED); + break; + + case LAPB_RNR: + lapb_dbg(1, "(%p) S3 RX RNR(%d) R%d\n", + lapb->dev, frame->pf, frame->nr); + lapb->condition |= LAPB_PEER_RX_BUSY_CONDITION; + lapb_check_need_response(lapb, frame->cr, frame->pf); + if (lapb_validate_nr(lapb, frame->nr)) { + lapb_check_iframes_acked(lapb, frame->nr); + } else { + lapb->frmr_data = *frame; + lapb->frmr_type = LAPB_FRMR_Z; + lapb_transmit_frmr(lapb); + lapb_dbg(0, "(%p) S3 -> S4\n", lapb->dev); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_4; + lapb->n2count = 0; + } + break; + + case LAPB_RR: + lapb_dbg(1, "(%p) S3 RX RR(%d) R%d\n", + lapb->dev, frame->pf, frame->nr); + lapb->condition &= ~LAPB_PEER_RX_BUSY_CONDITION; + lapb_check_need_response(lapb, frame->cr, frame->pf); + if (lapb_validate_nr(lapb, frame->nr)) { + lapb_check_iframes_acked(lapb, frame->nr); + } else { + lapb->frmr_data = *frame; + lapb->frmr_type = LAPB_FRMR_Z; + lapb_transmit_frmr(lapb); + lapb_dbg(0, "(%p) S3 -> S4\n", lapb->dev); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_4; + lapb->n2count = 0; + } + break; + + case LAPB_REJ: + lapb_dbg(1, "(%p) S3 RX REJ(%d) R%d\n", + lapb->dev, frame->pf, frame->nr); + lapb->condition &= ~LAPB_PEER_RX_BUSY_CONDITION; + lapb_check_need_response(lapb, frame->cr, frame->pf); + if (lapb_validate_nr(lapb, frame->nr)) { + lapb_frames_acked(lapb, frame->nr); + lapb_stop_t1timer(lapb); + lapb->n2count = 0; + 
lapb_requeue_frames(lapb); + } else { + lapb->frmr_data = *frame; + lapb->frmr_type = LAPB_FRMR_Z; + lapb_transmit_frmr(lapb); + lapb_dbg(0, "(%p) S3 -> S4\n", lapb->dev); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_4; + lapb->n2count = 0; + } + break; + + case LAPB_I: + lapb_dbg(1, "(%p) S3 RX I(%d) S%d R%d\n", + lapb->dev, frame->pf, frame->ns, frame->nr); + if (!lapb_validate_nr(lapb, frame->nr)) { + lapb->frmr_data = *frame; + lapb->frmr_type = LAPB_FRMR_Z; + lapb_transmit_frmr(lapb); + lapb_dbg(0, "(%p) S3 -> S4\n", lapb->dev); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_4; + lapb->n2count = 0; + break; + } + if (lapb->condition & LAPB_PEER_RX_BUSY_CONDITION) + lapb_frames_acked(lapb, frame->nr); + else + lapb_check_iframes_acked(lapb, frame->nr); + + if (frame->ns == lapb->vr) { + int cn; + cn = lapb_data_indication(lapb, skb); + queued = 1; + /* + * If upper layer has dropped the frame, we + * basically ignore any further protocol + * processing. This will cause the peer + * to re-transmit the frame later like + * a frame lost on the wire. + */ + if (cn == NET_RX_DROP) { + pr_debug("rx congestion\n"); + break; + } + lapb->vr = (lapb->vr + 1) % modulus; + lapb->condition &= ~LAPB_REJECT_CONDITION; + if (frame->pf) + lapb_enquiry_response(lapb); + else { + if (!(lapb->condition & + LAPB_ACK_PENDING_CONDITION)) { + lapb->condition |= LAPB_ACK_PENDING_CONDITION; + lapb_start_t2timer(lapb); + } + } + } else { + if (lapb->condition & LAPB_REJECT_CONDITION) { + if (frame->pf) + lapb_enquiry_response(lapb); + } else { + lapb_dbg(1, "(%p) S3 TX REJ(%d) R%d\n", + lapb->dev, frame->pf, lapb->vr); + lapb->condition |= LAPB_REJECT_CONDITION; + lapb_send_control(lapb, LAPB_REJ, frame->pf, + LAPB_RESPONSE); + lapb->condition &= ~LAPB_ACK_PENDING_CONDITION; + } + } + break; + + case LAPB_FRMR: + lapb_dbg(1, "(%p) S3 RX FRMR(%d) %5ph\n", + lapb->dev, frame->pf, + skb->data); + lapb_establish_data_link(lapb); + lapb_dbg(0, "(%p) S3 -> S1\n", lapb->dev); + lapb_requeue_frames(lapb); + lapb->state = LAPB_STATE_1; + break; + + case LAPB_ILLEGAL: + lapb_dbg(1, "(%p) S3 RX ILLEGAL(%d)\n", lapb->dev, frame->pf); + lapb->frmr_data = *frame; + lapb->frmr_type = LAPB_FRMR_W; + lapb_transmit_frmr(lapb); + lapb_dbg(0, "(%p) S3 -> S4\n", lapb->dev); + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_4; + lapb->n2count = 0; + break; + } + + if (!queued) + kfree_skb(skb); +} + +/* + * State machine for state 4, Frame Reject State. + * The handling of the timer(s) is in file lapb_timer.c. 
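+ * Only SABM and SABME are handled; a successful link reset returns
+ * the state machine to state 3 (Connected).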
+ */ +static void lapb_state4_machine(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + switch (frame->type) { + case LAPB_SABM: + lapb_dbg(1, "(%p) S4 RX SABM(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S4 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } else { + lapb_dbg(1, "(%p) S4 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S4 -> S3\n", lapb->dev); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_3; + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_connect_indication(lapb, LAPB_OK); + } + break; + + case LAPB_SABME: + lapb_dbg(1, "(%p) S4 RX SABME(%d)\n", lapb->dev, frame->pf); + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S4 TX UA(%d)\n", + lapb->dev, frame->pf); + lapb_dbg(0, "(%p) S4 -> S3\n", lapb->dev); + lapb_send_control(lapb, LAPB_UA, frame->pf, + LAPB_RESPONSE); + lapb_stop_t1timer(lapb); + lapb_stop_t2timer(lapb); + lapb->state = LAPB_STATE_3; + lapb->condition = 0x00; + lapb->n2count = 0; + lapb->vs = 0; + lapb->vr = 0; + lapb->va = 0; + lapb_connect_indication(lapb, LAPB_OK); + } else { + lapb_dbg(1, "(%p) S4 TX DM(%d)\n", + lapb->dev, frame->pf); + lapb_send_control(lapb, LAPB_DM, frame->pf, + LAPB_RESPONSE); + } + break; + } + + kfree_skb(skb); +} + +/* + * Process an incoming LAPB frame + */ +void lapb_data_input(struct lapb_cb *lapb, struct sk_buff *skb) +{ + struct lapb_frame frame; + + if (lapb_decode(lapb, skb, &frame) < 0) { + kfree_skb(skb); + return; + } + + switch (lapb->state) { + case LAPB_STATE_0: + lapb_state0_machine(lapb, skb, &frame); break; + case LAPB_STATE_1: + lapb_state1_machine(lapb, skb, &frame); break; + case LAPB_STATE_2: + lapb_state2_machine(lapb, skb, &frame); break; + case LAPB_STATE_3: + lapb_state3_machine(lapb, skb, &frame); break; + case LAPB_STATE_4: + lapb_state4_machine(lapb, skb, &frame); break; + } + + lapb_kick(lapb); +} diff --git a/net/lapb/lapb_out.c b/net/lapb/lapb_out.c new file mode 100644 index 000000000..a966d29c7 --- /dev/null +++ b/net/lapb/lapb_out.c @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * LAPB release 002 + * + * This code REQUIRES 2.1.15 or higher/ NET3.038 + * + * History + * LAPB 001 Jonathan Naylor Started Coding + * LAPB 002 Jonathan Naylor New timer architecture. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This procedure is passed a buffer descriptor for an iframe. It builds + * the rest of the control part of the frame and then writes it out. + */ +static void lapb_send_iframe(struct lapb_cb *lapb, struct sk_buff *skb, int poll_bit) +{ + unsigned char *frame; + + if (!skb) + return; + + if (lapb->mode & LAPB_EXTENDED) { + frame = skb_push(skb, 2); + + frame[0] = LAPB_I; + frame[0] |= lapb->vs << 1; + frame[1] = poll_bit ? LAPB_EPF : 0; + frame[1] |= lapb->vr << 1; + } else { + frame = skb_push(skb, 1); + + *frame = LAPB_I; + *frame |= poll_bit ? 
LAPB_SPF : 0; + *frame |= lapb->vr << 5; + *frame |= lapb->vs << 1; + } + + lapb_dbg(1, "(%p) S%d TX I(%d) S%d R%d\n", + lapb->dev, lapb->state, poll_bit, lapb->vs, lapb->vr); + + lapb_transmit_buffer(lapb, skb, LAPB_COMMAND); +} + +void lapb_kick(struct lapb_cb *lapb) +{ + struct sk_buff *skb, *skbn; + unsigned short modulus, start, end; + + modulus = (lapb->mode & LAPB_EXTENDED) ? LAPB_EMODULUS : LAPB_SMODULUS; + start = !skb_peek(&lapb->ack_queue) ? lapb->va : lapb->vs; + end = (lapb->va + lapb->window) % modulus; + + if (!(lapb->condition & LAPB_PEER_RX_BUSY_CONDITION) && + start != end && skb_peek(&lapb->write_queue)) { + lapb->vs = start; + + /* + * Dequeue the frame and copy it. + */ + skb = skb_dequeue(&lapb->write_queue); + + do { + skbn = skb_copy(skb, GFP_ATOMIC); + if (!skbn) { + skb_queue_head(&lapb->write_queue, skb); + break; + } + + if (skb->sk) + skb_set_owner_w(skbn, skb->sk); + + /* + * Transmit the frame copy. + */ + lapb_send_iframe(lapb, skbn, LAPB_POLLOFF); + + lapb->vs = (lapb->vs + 1) % modulus; + + /* + * Requeue the original data frame. + */ + skb_queue_tail(&lapb->ack_queue, skb); + + } while (lapb->vs != end && (skb = skb_dequeue(&lapb->write_queue)) != NULL); + + lapb->condition &= ~LAPB_ACK_PENDING_CONDITION; + + if (!lapb_t1timer_running(lapb)) + lapb_start_t1timer(lapb); + } +} + +void lapb_transmit_buffer(struct lapb_cb *lapb, struct sk_buff *skb, int type) +{ + unsigned char *ptr; + + ptr = skb_push(skb, 1); + + if (lapb->mode & LAPB_MLP) { + if (lapb->mode & LAPB_DCE) { + if (type == LAPB_COMMAND) + *ptr = LAPB_ADDR_C; + if (type == LAPB_RESPONSE) + *ptr = LAPB_ADDR_D; + } else { + if (type == LAPB_COMMAND) + *ptr = LAPB_ADDR_D; + if (type == LAPB_RESPONSE) + *ptr = LAPB_ADDR_C; + } + } else { + if (lapb->mode & LAPB_DCE) { + if (type == LAPB_COMMAND) + *ptr = LAPB_ADDR_A; + if (type == LAPB_RESPONSE) + *ptr = LAPB_ADDR_B; + } else { + if (type == LAPB_COMMAND) + *ptr = LAPB_ADDR_B; + if (type == LAPB_RESPONSE) + *ptr = LAPB_ADDR_A; + } + } + + lapb_dbg(2, "(%p) S%d TX %3ph\n", lapb->dev, lapb->state, skb->data); + + if (!lapb_data_transmit(lapb, skb)) + kfree_skb(skb); +} + +void lapb_establish_data_link(struct lapb_cb *lapb) +{ + lapb->condition = 0x00; + lapb->n2count = 0; + + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S%d TX SABME(1)\n", lapb->dev, lapb->state); + lapb_send_control(lapb, LAPB_SABME, LAPB_POLLON, LAPB_COMMAND); + } else { + lapb_dbg(1, "(%p) S%d TX SABM(1)\n", lapb->dev, lapb->state); + lapb_send_control(lapb, LAPB_SABM, LAPB_POLLON, LAPB_COMMAND); + } + + lapb_start_t1timer(lapb); + lapb_stop_t2timer(lapb); +} + +void lapb_enquiry_response(struct lapb_cb *lapb) +{ + lapb_dbg(1, "(%p) S%d TX RR(1) R%d\n", + lapb->dev, lapb->state, lapb->vr); + + lapb_send_control(lapb, LAPB_RR, LAPB_POLLON, LAPB_RESPONSE); + + lapb->condition &= ~LAPB_ACK_PENDING_CONDITION; +} + +void lapb_timeout_response(struct lapb_cb *lapb) +{ + lapb_dbg(1, "(%p) S%d TX RR(0) R%d\n", + lapb->dev, lapb->state, lapb->vr); + lapb_send_control(lapb, LAPB_RR, LAPB_POLLOFF, LAPB_RESPONSE); + + lapb->condition &= ~LAPB_ACK_PENDING_CONDITION; +} + +void lapb_check_iframes_acked(struct lapb_cb *lapb, unsigned short nr) +{ + if (lapb->vs == nr) { + lapb_frames_acked(lapb, nr); + lapb_stop_t1timer(lapb); + lapb->n2count = 0; + } else if (lapb->va != nr) { + lapb_frames_acked(lapb, nr); + lapb_start_t1timer(lapb); + } +} + +void lapb_check_need_response(struct lapb_cb *lapb, int type, int pf) +{ + if (type == LAPB_COMMAND && pf) + lapb_enquiry_response(lapb); +} 
diff --git a/net/lapb/lapb_subr.c b/net/lapb/lapb_subr.c new file mode 100644 index 000000000..592a22d86 --- /dev/null +++ b/net/lapb/lapb_subr.c @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * LAPB release 002 + * + * This code REQUIRES 2.1.15 or higher/ NET3.038 + * + * History + * LAPB 001 Jonathan Naylor Started Coding + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This routine purges all the queues of frames. + */ +void lapb_clear_queues(struct lapb_cb *lapb) +{ + skb_queue_purge(&lapb->write_queue); + skb_queue_purge(&lapb->ack_queue); +} + +/* + * This routine purges the input queue of those frames that have been + * acknowledged. This replaces the boxes labelled "V(a) <- N(r)" on the + * SDL diagram. + */ +void lapb_frames_acked(struct lapb_cb *lapb, unsigned short nr) +{ + struct sk_buff *skb; + int modulus; + + modulus = (lapb->mode & LAPB_EXTENDED) ? LAPB_EMODULUS : LAPB_SMODULUS; + + /* + * Remove all the ack-ed frames from the ack queue. + */ + if (lapb->va != nr) + while (skb_peek(&lapb->ack_queue) && lapb->va != nr) { + skb = skb_dequeue(&lapb->ack_queue); + kfree_skb(skb); + lapb->va = (lapb->va + 1) % modulus; + } +} + +void lapb_requeue_frames(struct lapb_cb *lapb) +{ + struct sk_buff *skb, *skb_prev = NULL; + + /* + * Requeue all the un-ack-ed frames on the output queue to be picked + * up by lapb_kick called from the timer. This arrangement handles the + * possibility of an empty output queue. + */ + while ((skb = skb_dequeue(&lapb->ack_queue)) != NULL) { + if (!skb_prev) + skb_queue_head(&lapb->write_queue, skb); + else + skb_append(skb_prev, skb, &lapb->write_queue); + skb_prev = skb; + } +} + +/* + * Validate that the value of nr is between va and vs. Return true or + * false for testing. + */ +int lapb_validate_nr(struct lapb_cb *lapb, unsigned short nr) +{ + unsigned short vc = lapb->va; + int modulus; + + modulus = (lapb->mode & LAPB_EXTENDED) ? LAPB_EMODULUS : LAPB_SMODULUS; + + while (vc != lapb->vs) { + if (nr == vc) + return 1; + vc = (vc + 1) % modulus; + } + + return nr == lapb->vs; +} + +/* + * This routine is the centralised routine for parsing the control + * information for the different frame formats. + */ +int lapb_decode(struct lapb_cb *lapb, struct sk_buff *skb, + struct lapb_frame *frame) +{ + frame->type = LAPB_ILLEGAL; + + lapb_dbg(2, "(%p) S%d RX %3ph\n", lapb->dev, lapb->state, skb->data); + + /* We always need to look at 2 bytes, sometimes we need + * to look at 3 and those cases are handled below. 
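+ * (In extended mode, I and S frames use a two byte control field, so
+ * an additional pskb_may_pull() check is made for those cases.)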
+ */ + if (!pskb_may_pull(skb, 2)) + return -1; + + if (lapb->mode & LAPB_MLP) { + if (lapb->mode & LAPB_DCE) { + if (skb->data[0] == LAPB_ADDR_D) + frame->cr = LAPB_COMMAND; + if (skb->data[0] == LAPB_ADDR_C) + frame->cr = LAPB_RESPONSE; + } else { + if (skb->data[0] == LAPB_ADDR_C) + frame->cr = LAPB_COMMAND; + if (skb->data[0] == LAPB_ADDR_D) + frame->cr = LAPB_RESPONSE; + } + } else { + if (lapb->mode & LAPB_DCE) { + if (skb->data[0] == LAPB_ADDR_B) + frame->cr = LAPB_COMMAND; + if (skb->data[0] == LAPB_ADDR_A) + frame->cr = LAPB_RESPONSE; + } else { + if (skb->data[0] == LAPB_ADDR_A) + frame->cr = LAPB_COMMAND; + if (skb->data[0] == LAPB_ADDR_B) + frame->cr = LAPB_RESPONSE; + } + } + + skb_pull(skb, 1); + + if (lapb->mode & LAPB_EXTENDED) { + if (!(skb->data[0] & LAPB_S)) { + if (!pskb_may_pull(skb, 2)) + return -1; + /* + * I frame - carries NR/NS/PF + */ + frame->type = LAPB_I; + frame->ns = (skb->data[0] >> 1) & 0x7F; + frame->nr = (skb->data[1] >> 1) & 0x7F; + frame->pf = skb->data[1] & LAPB_EPF; + frame->control[0] = skb->data[0]; + frame->control[1] = skb->data[1]; + skb_pull(skb, 2); + } else if ((skb->data[0] & LAPB_U) == 1) { + if (!pskb_may_pull(skb, 2)) + return -1; + /* + * S frame - take out PF/NR + */ + frame->type = skb->data[0] & 0x0F; + frame->nr = (skb->data[1] >> 1) & 0x7F; + frame->pf = skb->data[1] & LAPB_EPF; + frame->control[0] = skb->data[0]; + frame->control[1] = skb->data[1]; + skb_pull(skb, 2); + } else if ((skb->data[0] & LAPB_U) == 3) { + /* + * U frame - take out PF + */ + frame->type = skb->data[0] & ~LAPB_SPF; + frame->pf = skb->data[0] & LAPB_SPF; + frame->control[0] = skb->data[0]; + frame->control[1] = 0x00; + skb_pull(skb, 1); + } + } else { + if (!(skb->data[0] & LAPB_S)) { + /* + * I frame - carries NR/NS/PF + */ + frame->type = LAPB_I; + frame->ns = (skb->data[0] >> 1) & 0x07; + frame->nr = (skb->data[0] >> 5) & 0x07; + frame->pf = skb->data[0] & LAPB_SPF; + } else if ((skb->data[0] & LAPB_U) == 1) { + /* + * S frame - take out PF/NR + */ + frame->type = skb->data[0] & 0x0F; + frame->nr = (skb->data[0] >> 5) & 0x07; + frame->pf = skb->data[0] & LAPB_SPF; + } else if ((skb->data[0] & LAPB_U) == 3) { + /* + * U frame - take out PF + */ + frame->type = skb->data[0] & ~LAPB_SPF; + frame->pf = skb->data[0] & LAPB_SPF; + } + + frame->control[0] = skb->data[0]; + + skb_pull(skb, 1); + } + + return 0; +} + +/* + * This routine is called when the HDLC layer internally generates a + * command or response for the remote machine ( eg. RR, UA etc. ). + * Only supervisory or unnumbered frames are processed, FRMRs are handled + * by lapb_transmit_frmr below. + */ +void lapb_send_control(struct lapb_cb *lapb, int frametype, + int poll_bit, int type) +{ + struct sk_buff *skb; + unsigned char *dptr; + + if ((skb = alloc_skb(LAPB_HEADER_LEN + 3, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, LAPB_HEADER_LEN + 1); + + if (lapb->mode & LAPB_EXTENDED) { + if ((frametype & LAPB_U) == LAPB_U) { + dptr = skb_put(skb, 1); + *dptr = frametype; + *dptr |= poll_bit ? LAPB_SPF : 0; + } else { + dptr = skb_put(skb, 2); + dptr[0] = frametype; + dptr[1] = (lapb->vr << 1); + dptr[1] |= poll_bit ? LAPB_EPF : 0; + } + } else { + dptr = skb_put(skb, 1); + *dptr = frametype; + *dptr |= poll_bit ? LAPB_SPF : 0; + if ((frametype & LAPB_U) == LAPB_S) /* S frames carry NR */ + *dptr |= (lapb->vr << 5); + } + + lapb_transmit_buffer(lapb, skb, type); +} + +/* + * This routine generates FRMRs based on information previously stored in + * the LAPB control block. 
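+ * In extended mode the FRMR carries a five byte information field;
+ * in basic mode it carries a three byte one.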
+ */ +void lapb_transmit_frmr(struct lapb_cb *lapb) +{ + struct sk_buff *skb; + unsigned char *dptr; + + if ((skb = alloc_skb(LAPB_HEADER_LEN + 7, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, LAPB_HEADER_LEN + 1); + + if (lapb->mode & LAPB_EXTENDED) { + dptr = skb_put(skb, 6); + *dptr++ = LAPB_FRMR; + *dptr++ = lapb->frmr_data.control[0]; + *dptr++ = lapb->frmr_data.control[1]; + *dptr++ = (lapb->vs << 1) & 0xFE; + *dptr = (lapb->vr << 1) & 0xFE; + if (lapb->frmr_data.cr == LAPB_RESPONSE) + *dptr |= 0x01; + dptr++; + *dptr++ = lapb->frmr_type; + + lapb_dbg(1, "(%p) S%d TX FRMR %5ph\n", + lapb->dev, lapb->state, + &skb->data[1]); + } else { + dptr = skb_put(skb, 4); + *dptr++ = LAPB_FRMR; + *dptr++ = lapb->frmr_data.control[0]; + *dptr = (lapb->vs << 1) & 0x0E; + *dptr |= (lapb->vr << 5) & 0xE0; + if (lapb->frmr_data.cr == LAPB_RESPONSE) + *dptr |= 0x10; + dptr++; + *dptr++ = lapb->frmr_type; + + lapb_dbg(1, "(%p) S%d TX FRMR %3ph\n", + lapb->dev, lapb->state, &skb->data[1]); + } + + lapb_transmit_buffer(lapb, skb, LAPB_RESPONSE); +} diff --git a/net/lapb/lapb_timer.c b/net/lapb/lapb_timer.c new file mode 100644 index 000000000..5be688690 --- /dev/null +++ b/net/lapb/lapb_timer.c @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * LAPB release 002 + * + * This code REQUIRES 2.1.15 or higher/ NET3.038 + * + * History + * LAPB 001 Jonathan Naylor Started Coding + * LAPB 002 Jonathan Naylor New timer architecture. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void lapb_t1timer_expiry(struct timer_list *); +static void lapb_t2timer_expiry(struct timer_list *); + +void lapb_start_t1timer(struct lapb_cb *lapb) +{ + del_timer(&lapb->t1timer); + + lapb->t1timer.function = lapb_t1timer_expiry; + lapb->t1timer.expires = jiffies + lapb->t1; + + lapb->t1timer_running = true; + add_timer(&lapb->t1timer); +} + +void lapb_start_t2timer(struct lapb_cb *lapb) +{ + del_timer(&lapb->t2timer); + + lapb->t2timer.function = lapb_t2timer_expiry; + lapb->t2timer.expires = jiffies + lapb->t2; + + lapb->t2timer_running = true; + add_timer(&lapb->t2timer); +} + +void lapb_stop_t1timer(struct lapb_cb *lapb) +{ + lapb->t1timer_running = false; + del_timer(&lapb->t1timer); +} + +void lapb_stop_t2timer(struct lapb_cb *lapb) +{ + lapb->t2timer_running = false; + del_timer(&lapb->t2timer); +} + +int lapb_t1timer_running(struct lapb_cb *lapb) +{ + return lapb->t1timer_running; +} + +static void lapb_t2timer_expiry(struct timer_list *t) +{ + struct lapb_cb *lapb = from_timer(lapb, t, t2timer); + + spin_lock_bh(&lapb->lock); + if (timer_pending(&lapb->t2timer)) /* A new timer has been set up */ + goto out; + if (!lapb->t2timer_running) /* The timer has been stopped */ + goto out; + + if (lapb->condition & LAPB_ACK_PENDING_CONDITION) { + lapb->condition &= ~LAPB_ACK_PENDING_CONDITION; + lapb_timeout_response(lapb); + } + lapb->t2timer_running = false; + +out: + spin_unlock_bh(&lapb->lock); +} + +static void lapb_t1timer_expiry(struct timer_list *t) +{ + struct lapb_cb *lapb = from_timer(lapb, t, t1timer); + + spin_lock_bh(&lapb->lock); + if (timer_pending(&lapb->t1timer)) /* A new timer has been set up */ + goto out; + if (!lapb->t1timer_running) /* The timer has been stopped */ + goto out; + + switch (lapb->state) { + + /* + * If we are a DCE, send DM up to N2 times, then switch to + * STATE_1 and send 
SABM(E). + */ + case LAPB_STATE_0: + if (lapb->mode & LAPB_DCE && + lapb->n2count != lapb->n2) { + lapb->n2count++; + lapb_send_control(lapb, LAPB_DM, LAPB_POLLOFF, LAPB_RESPONSE); + } else { + lapb->state = LAPB_STATE_1; + lapb_establish_data_link(lapb); + } + break; + + /* + * Awaiting connection state, send SABM(E), up to N2 times. + */ + case LAPB_STATE_1: + if (lapb->n2count == lapb->n2) { + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_disconnect_indication(lapb, LAPB_TIMEDOUT); + lapb_dbg(0, "(%p) S1 -> S0\n", lapb->dev); + lapb->t1timer_running = false; + goto out; + } else { + lapb->n2count++; + if (lapb->mode & LAPB_EXTENDED) { + lapb_dbg(1, "(%p) S1 TX SABME(1)\n", + lapb->dev); + lapb_send_control(lapb, LAPB_SABME, LAPB_POLLON, LAPB_COMMAND); + } else { + lapb_dbg(1, "(%p) S1 TX SABM(1)\n", + lapb->dev); + lapb_send_control(lapb, LAPB_SABM, LAPB_POLLON, LAPB_COMMAND); + } + } + break; + + /* + * Awaiting disconnection state, send DISC, up to N2 times. + */ + case LAPB_STATE_2: + if (lapb->n2count == lapb->n2) { + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_disconnect_confirmation(lapb, LAPB_TIMEDOUT); + lapb_dbg(0, "(%p) S2 -> S0\n", lapb->dev); + lapb->t1timer_running = false; + goto out; + } else { + lapb->n2count++; + lapb_dbg(1, "(%p) S2 TX DISC(1)\n", lapb->dev); + lapb_send_control(lapb, LAPB_DISC, LAPB_POLLON, LAPB_COMMAND); + } + break; + + /* + * Data transfer state, restransmit I frames, up to N2 times. + */ + case LAPB_STATE_3: + if (lapb->n2count == lapb->n2) { + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_stop_t2timer(lapb); + lapb_disconnect_indication(lapb, LAPB_TIMEDOUT); + lapb_dbg(0, "(%p) S3 -> S0\n", lapb->dev); + lapb->t1timer_running = false; + goto out; + } else { + lapb->n2count++; + lapb_requeue_frames(lapb); + lapb_kick(lapb); + } + break; + + /* + * Frame reject state, restransmit FRMR frames, up to N2 times. + */ + case LAPB_STATE_4: + if (lapb->n2count == lapb->n2) { + lapb_clear_queues(lapb); + lapb->state = LAPB_STATE_0; + lapb_disconnect_indication(lapb, LAPB_TIMEDOUT); + lapb_dbg(0, "(%p) S4 -> S0\n", lapb->dev); + lapb->t1timer_running = false; + goto out; + } else { + lapb->n2count++; + lapb_transmit_frmr(lapb); + } + break; + } + + lapb_start_t1timer(lapb); + +out: + spin_unlock_bh(&lapb->lock); +} diff --git a/net/llc/Kconfig b/net/llc/Kconfig new file mode 100644 index 000000000..7f79f5e13 --- /dev/null +++ b/net/llc/Kconfig @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +config LLC + tristate + +config LLC2 + tristate "ANSI/IEEE 802.2 LLC type 2 Support" + select LLC + help + This is a Logical Link Layer type 2, connection oriented support. + Select this if you want to have support for PF_LLC sockets. diff --git a/net/llc/Makefile b/net/llc/Makefile new file mode 100644 index 000000000..5e0ef436d --- /dev/null +++ b/net/llc/Makefile @@ -0,0 +1,25 @@ +########################################################################### +# Makefile for the Linux 802.2 LLC (fully-functional) layer. +# +# Copyright (c) 1997 by Procom Technology,Inc. +# 2001-2003 by Arnaldo Carvalho de Melo +# +# This program can be redistributed or modified under the terms of the +# GNU General Public License as published by the Free Software Foundation. +# This program is distributed without any warranty or implied warranty +# of merchantability or fitness for a particular purpose. +# +# See the GNU General Public License for more details. 
+########################################################################### + +obj-$(CONFIG_LLC) += llc.o + +llc-y := llc_core.o llc_input.o llc_output.o + +obj-$(CONFIG_LLC2) += llc2.o + +llc2-y := llc_if.o llc_c_ev.o llc_c_ac.o llc_conn.o llc_c_st.o llc_pdu.o \ + llc_sap.o llc_s_ac.o llc_s_ev.o llc_s_st.o af_llc.o llc_station.o + +llc2-$(CONFIG_PROC_FS) += llc_proc.o +llc2-$(CONFIG_SYSCTL) += sysctl_net_llc.o diff --git a/net/llc/af_llc.c b/net/llc/af_llc.c new file mode 100644 index 000000000..19c478bd8 --- /dev/null +++ b/net/llc/af_llc.c @@ -0,0 +1,1309 @@ +/* + * af_llc.c - LLC User Interface SAPs + * Description: + * Functions in this module are implementation of socket based llc + * communications for the Linux operating system. Support of llc class + * one and class two is provided via SOCK_DGRAM and SOCK_STREAM + * respectively. + * + * An llc2 connection is (mac + sap), only one llc2 sap connection + * is allowed per mac. Though one sap may have multiple mac + sap + * connections. + * + * Copyright (c) 2001 by Jay Schulist + * 2002-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +/* remember: uninitialized global data is zeroed because its in .bss */ +static u16 llc_ui_sap_last_autoport = LLC_SAP_DYN_START; +static u16 llc_ui_sap_link_no_max[256]; +static struct sockaddr_llc llc_ui_addrnull; +static const struct proto_ops llc_ui_ops; + +static bool llc_ui_wait_for_conn(struct sock *sk, long timeout); +static int llc_ui_wait_for_disc(struct sock *sk, long timeout); +static int llc_ui_wait_for_busy_core(struct sock *sk, long timeout); + +#if 0 +#define dprintk(args...) printk(KERN_DEBUG args) +#else +#define dprintk(args...) do {} while (0) +#endif + +/* Maybe we'll add some more in the future. */ +#define LLC_CMSG_PKTINFO 1 + + +/** + * llc_ui_next_link_no - return the next unused link number for a sap + * @sap: Address of sap to get link number from. + * + * Return the next unused link number for a given sap. + */ +static inline u16 llc_ui_next_link_no(int sap) +{ + return llc_ui_sap_link_no_max[sap]++; +} + +/** + * llc_proto_type - return eth protocol for ARP header type + * @arphrd: ARP header type. + * + * Given an ARP header type return the corresponding ethernet protocol. + */ +static inline __be16 llc_proto_type(u16 arphrd) +{ + return htons(ETH_P_802_2); +} + +/** + * llc_ui_addr_null - determines if a address structure is null + * @addr: Address to test if null. + */ +static inline u8 llc_ui_addr_null(struct sockaddr_llc *addr) +{ + return !memcmp(addr, &llc_ui_addrnull, sizeof(*addr)); +} + +/** + * llc_ui_header_len - return length of llc header based on operation + * @sk: Socket which contains a valid llc socket type. + * @addr: Complete sockaddr_llc structure received from the user. + * + * Provide the length of the llc header depending on what kind of + * operation the user would like to perform and the type of socket. + * Returns the correct llc header length. 
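+ * (LLC_PDU_LEN_U for TEST and datagram traffic, LLC_PDU_LEN_U_XID for
+ * XID requests and LLC_PDU_LEN_I for connection-oriented SOCK_STREAM
+ * sockets.)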
+ */ +static inline u8 llc_ui_header_len(struct sock *sk, struct sockaddr_llc *addr) +{ + u8 rc = LLC_PDU_LEN_U; + + if (addr->sllc_test) + rc = LLC_PDU_LEN_U; + else if (addr->sllc_xid) + /* We need to expand header to sizeof(struct llc_xid_info) + * since llc_pdu_init_as_xid_cmd() sets 4,5,6 bytes of LLC header + * as XID PDU. In llc_ui_sendmsg() we reserved header size and then + * filled all other space with user data. If we won't reserve this + * bytes, llc_pdu_init_as_xid_cmd() will overwrite user data + */ + rc = LLC_PDU_LEN_U_XID; + else if (sk->sk_type == SOCK_STREAM) + rc = LLC_PDU_LEN_I; + return rc; +} + +/** + * llc_ui_send_data - send data via reliable llc2 connection + * @sk: Connection the socket is using. + * @skb: Data the user wishes to send. + * @noblock: can we block waiting for data? + * + * Send data via reliable llc2 connection. + * Returns 0 upon success, non-zero if action did not succeed. + * + * This function always consumes a reference to the skb. + */ +static int llc_ui_send_data(struct sock* sk, struct sk_buff *skb, int noblock) +{ + struct llc_sock* llc = llc_sk(sk); + + if (unlikely(llc_data_accept_state(llc->state) || + llc->remote_busy_flag || + llc->p_flag)) { + long timeout = sock_sndtimeo(sk, noblock); + int rc; + + rc = llc_ui_wait_for_busy_core(sk, timeout); + if (rc) { + kfree_skb(skb); + return rc; + } + } + return llc_build_and_send_pkt(sk, skb); +} + +static void llc_ui_sk_init(struct socket *sock, struct sock *sk) +{ + sock_graft(sk, sock); + sk->sk_type = sock->type; + sock->ops = &llc_ui_ops; +} + +static struct proto llc_proto = { + .name = "LLC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct llc_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, +}; + +/** + * llc_ui_create - alloc and init a new llc_ui socket + * @net: network namespace (must be default network) + * @sock: Socket to initialize and attach allocated sk to. + * @protocol: Unused. + * @kern: on behalf of kernel or userspace + * + * Allocate and initialize a new llc_ui socket, validate the user wants a + * socket type we have available. + * Returns 0 upon success, negative upon failure. + */ +static int llc_ui_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + int rc = -ESOCKTNOSUPPORT; + + if (!ns_capable(net->user_ns, CAP_NET_RAW)) + return -EPERM; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + if (likely(sock->type == SOCK_DGRAM || sock->type == SOCK_STREAM)) { + rc = -ENOMEM; + sk = llc_sk_alloc(net, PF_LLC, GFP_KERNEL, &llc_proto, kern); + if (sk) { + rc = 0; + llc_ui_sk_init(sock, sk); + } + } + return rc; +} + +/** + * llc_ui_release - shutdown socket + * @sock: Socket to release. + * + * Shutdown and deallocate an existing socket. + */ +static int llc_ui_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct llc_sock *llc; + + if (unlikely(sk == NULL)) + goto out; + sock_hold(sk); + lock_sock(sk); + llc = llc_sk(sk); + dprintk("%s: closing local(%02X) remote(%02X)\n", __func__, + llc->laddr.lsap, llc->daddr.lsap); + if (!llc_send_disc(sk)) + llc_ui_wait_for_disc(sk, sk->sk_rcvtimeo); + if (!sock_flag(sk, SOCK_ZAPPED)) { + struct llc_sap *sap = llc->sap; + + /* Hold this for release_sock(), so that llc_backlog_rcv() + * could still use it. 
+ */ + llc_sap_hold(sap); + llc_sap_remove_socket(llc->sap, sk); + release_sock(sk); + llc_sap_put(sap); + } else { + release_sock(sk); + } + netdev_put(llc->dev, &llc->dev_tracker); + sock_put(sk); + llc_sk_free(sk); +out: + return 0; +} + +/** + * llc_ui_autoport - provide dynamically allocate SAP number + * + * Provide the caller with a dynamically allocated SAP number according + * to the rules that are set in this function. Returns: 0, upon failure, + * SAP number otherwise. + */ +static int llc_ui_autoport(void) +{ + struct llc_sap *sap; + int i, tries = 0; + + while (tries < LLC_SAP_DYN_TRIES) { + for (i = llc_ui_sap_last_autoport; + i < LLC_SAP_DYN_STOP; i += 2) { + sap = llc_sap_find(i); + if (!sap) { + llc_ui_sap_last_autoport = i + 2; + goto out; + } + llc_sap_put(sap); + } + llc_ui_sap_last_autoport = LLC_SAP_DYN_START; + tries++; + } + i = 0; +out: + return i; +} + +/** + * llc_ui_autobind - automatically bind a socket to a sap + * @sock: socket to bind + * @addr: address to connect to + * + * Used by llc_ui_connect and llc_ui_sendmsg when the user hasn't + * specifically used llc_ui_bind to bind to an specific address/sap + * + * Returns: 0 upon success, negative otherwise. + */ +static int llc_ui_autobind(struct socket *sock, struct sockaddr_llc *addr) +{ + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + struct net_device *dev = NULL; + struct llc_sap *sap; + int rc = -EINVAL; + + if (!sock_flag(sk, SOCK_ZAPPED)) + goto out; + if (!addr->sllc_arphrd) + addr->sllc_arphrd = ARPHRD_ETHER; + if (addr->sllc_arphrd != ARPHRD_ETHER) + goto out; + rc = -ENODEV; + if (sk->sk_bound_dev_if) { + dev = dev_get_by_index(&init_net, sk->sk_bound_dev_if); + if (dev && addr->sllc_arphrd != dev->type) { + dev_put(dev); + dev = NULL; + } + } else + dev = dev_getfirstbyhwtype(&init_net, addr->sllc_arphrd); + if (!dev) + goto out; + rc = -EUSERS; + llc->laddr.lsap = llc_ui_autoport(); + if (!llc->laddr.lsap) + goto out; + rc = -EBUSY; /* some other network layer is using the sap */ + sap = llc_sap_open(llc->laddr.lsap, NULL); + if (!sap) + goto out; + + /* Note: We do not expect errors from this point. */ + llc->dev = dev; + netdev_tracker_alloc(llc->dev, &llc->dev_tracker, GFP_KERNEL); + dev = NULL; + + memcpy(llc->laddr.mac, llc->dev->dev_addr, IFHWADDRLEN); + memcpy(&llc->addr, addr, sizeof(llc->addr)); + /* assign new connection to its SAP */ + llc_sap_add_socket(sap, sk); + sock_reset_flag(sk, SOCK_ZAPPED); + rc = 0; +out: + dev_put(dev); + return rc; +} + +/** + * llc_ui_bind - bind a socket to a specific address. + * @sock: Socket to bind an address to. + * @uaddr: Address the user wants the socket bound to. + * @addrlen: Length of the uaddr structure. + * + * Bind a socket to a specific address. For llc a user is able to bind to + * a specific sap only or mac + sap. + * If the user desires to bind to a specific mac + sap, it is possible to + * have multiple sap connections via multiple macs. + * Bind and autobind for that matter must enforce the correct sap usage + * otherwise all hell will break loose. + * Returns: 0 upon success, negative otherwise. 
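+ * If the caller leaves sllc_sap as zero, a dynamic SAP is allocated with
+ * llc_ui_autoport(), just as llc_ui_autobind() does.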
+ */ +static int llc_ui_bind(struct socket *sock, struct sockaddr *uaddr, int addrlen) +{ + struct sockaddr_llc *addr = (struct sockaddr_llc *)uaddr; + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + struct net_device *dev = NULL; + struct llc_sap *sap; + int rc = -EINVAL; + + lock_sock(sk); + if (unlikely(!sock_flag(sk, SOCK_ZAPPED) || addrlen != sizeof(*addr))) + goto out; + rc = -EAFNOSUPPORT; + if (!addr->sllc_arphrd) + addr->sllc_arphrd = ARPHRD_ETHER; + if (unlikely(addr->sllc_family != AF_LLC || addr->sllc_arphrd != ARPHRD_ETHER)) + goto out; + dprintk("%s: binding %02X\n", __func__, addr->sllc_sap); + rc = -ENODEV; + rcu_read_lock(); + if (sk->sk_bound_dev_if) { + dev = dev_get_by_index_rcu(&init_net, sk->sk_bound_dev_if); + if (dev) { + if (is_zero_ether_addr(addr->sllc_mac)) + memcpy(addr->sllc_mac, dev->dev_addr, + IFHWADDRLEN); + if (addr->sllc_arphrd != dev->type || + !ether_addr_equal(addr->sllc_mac, + dev->dev_addr)) { + rc = -EINVAL; + dev = NULL; + } + } + } else { + dev = dev_getbyhwaddr_rcu(&init_net, addr->sllc_arphrd, + addr->sllc_mac); + } + dev_hold(dev); + rcu_read_unlock(); + if (!dev) + goto out; + + if (!addr->sllc_sap) { + rc = -EUSERS; + addr->sllc_sap = llc_ui_autoport(); + if (!addr->sllc_sap) + goto out; + } + sap = llc_sap_find(addr->sllc_sap); + if (!sap) { + sap = llc_sap_open(addr->sllc_sap, NULL); + rc = -EBUSY; /* some other network layer is using the sap */ + if (!sap) + goto out; + } else { + struct llc_addr laddr, daddr; + struct sock *ask; + + memset(&laddr, 0, sizeof(laddr)); + memset(&daddr, 0, sizeof(daddr)); + /* + * FIXME: check if the address is multicast, + * only SOCK_DGRAM can do this. + */ + memcpy(laddr.mac, addr->sllc_mac, IFHWADDRLEN); + laddr.lsap = addr->sllc_sap; + rc = -EADDRINUSE; /* mac + sap clash. */ + ask = llc_lookup_established(sap, &daddr, &laddr); + if (ask) { + sock_put(ask); + goto out_put; + } + } + + /* Note: We do not expect errors from this point. */ + llc->dev = dev; + netdev_tracker_alloc(llc->dev, &llc->dev_tracker, GFP_KERNEL); + dev = NULL; + + llc->laddr.lsap = addr->sllc_sap; + memcpy(llc->laddr.mac, addr->sllc_mac, IFHWADDRLEN); + memcpy(&llc->addr, addr, sizeof(llc->addr)); + /* assign new connection to its SAP */ + llc_sap_add_socket(sap, sk); + sock_reset_flag(sk, SOCK_ZAPPED); + rc = 0; +out_put: + llc_sap_put(sap); +out: + dev_put(dev); + release_sock(sk); + return rc; +} + +/** + * llc_ui_shutdown - shutdown a connect llc2 socket. + * @sock: Socket to shutdown. + * @how: What part of the socket to shutdown. + * + * Shutdown a connected llc2 socket. Currently this function only supports + * shutting down both sends and receives (2), we could probably make this + * function such that a user can shutdown only half the connection but not + * right now. + * Returns: 0 upon success, negative otherwise. + */ +static int llc_ui_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int rc = -ENOTCONN; + + lock_sock(sk); + if (unlikely(sk->sk_state != TCP_ESTABLISHED)) + goto out; + rc = -EINVAL; + if (how != 2) + goto out; + rc = llc_send_disc(sk); + if (!rc) + rc = llc_ui_wait_for_disc(sk, sk->sk_rcvtimeo); + /* Wake up anyone sleeping in poll */ + sk->sk_state_change(sk); +out: + release_sock(sk); + return rc; +} + +/** + * llc_ui_connect - Connect to a remote llc2 mac + sap. + * @sock: Socket which will be connected to the remote destination. + * @uaddr: Remote and possibly the local address of the new connection. + * @addrlen: Size of uaddr structure. 
+ * @flags: Operational flags specified by the user. + * + * Connect to a remote llc2 mac + sap. The caller must specify the + * destination mac and address to connect to. If the user hasn't previously + * called bind(2) with a smac the address of the first interface of the + * specified arp type will be used. + * This function will autobind if user did not previously call bind. + * Returns: 0 upon success, negative otherwise. + */ +static int llc_ui_connect(struct socket *sock, struct sockaddr *uaddr, + int addrlen, int flags) +{ + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + struct sockaddr_llc *addr = (struct sockaddr_llc *)uaddr; + int rc = -EINVAL; + + lock_sock(sk); + if (unlikely(addrlen != sizeof(*addr))) + goto out; + rc = -EAFNOSUPPORT; + if (unlikely(addr->sllc_family != AF_LLC)) + goto out; + if (unlikely(sk->sk_type != SOCK_STREAM)) + goto out; + rc = -EALREADY; + if (unlikely(sock->state == SS_CONNECTING)) + goto out; + /* bind connection to sap if user hasn't done it. */ + if (sock_flag(sk, SOCK_ZAPPED)) { + /* bind to sap with null dev, exclusive */ + rc = llc_ui_autobind(sock, addr); + if (rc) + goto out; + } + llc->daddr.lsap = addr->sllc_sap; + memcpy(llc->daddr.mac, addr->sllc_mac, IFHWADDRLEN); + sock->state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; + llc->link = llc_ui_next_link_no(llc->sap->laddr.lsap); + rc = llc_establish_connection(sk, llc->dev->dev_addr, + addr->sllc_mac, addr->sllc_sap); + if (rc) { + dprintk("%s: llc_ui_send_conn failed :-(\n", __func__); + sock->state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; + goto out; + } + + if (sk->sk_state == TCP_SYN_SENT) { + const long timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); + + if (!timeo || !llc_ui_wait_for_conn(sk, timeo)) + goto out; + + rc = sock_intr_errno(timeo); + if (signal_pending(current)) + goto out; + } + + if (sk->sk_state == TCP_CLOSE) + goto sock_error; + + sock->state = SS_CONNECTED; + rc = 0; +out: + release_sock(sk); + return rc; +sock_error: + rc = sock_error(sk) ? : -ECONNABORTED; + sock->state = SS_UNCONNECTED; + goto out; +} + +/** + * llc_ui_listen - allow a normal socket to accept incoming connections + * @sock: Socket to allow incoming connections on. + * @backlog: Number of connections to queue. + * + * Allow a normal socket to accept incoming connections. + * Returns 0 upon success, negative otherwise. 
+ */ +static int llc_ui_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int rc = -EINVAL; + + lock_sock(sk); + if (unlikely(sock->state != SS_UNCONNECTED)) + goto out; + rc = -EOPNOTSUPP; + if (unlikely(sk->sk_type != SOCK_STREAM)) + goto out; + rc = -EAGAIN; + if (sock_flag(sk, SOCK_ZAPPED)) + goto out; + rc = 0; + if (!(unsigned int)backlog) /* BSDism */ + backlog = 1; + sk->sk_max_ack_backlog = backlog; + if (sk->sk_state != TCP_LISTEN) { + sk->sk_ack_backlog = 0; + sk->sk_state = TCP_LISTEN; + } + sk->sk_socket->flags |= __SO_ACCEPTCON; +out: + release_sock(sk); + return rc; +} + +static int llc_ui_wait_for_disc(struct sock *sk, long timeout) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int rc = 0; + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (sk_wait_event(sk, &timeout, + READ_ONCE(sk->sk_state) == TCP_CLOSE, &wait)) + break; + rc = -ERESTARTSYS; + if (signal_pending(current)) + break; + rc = -EAGAIN; + if (!timeout) + break; + rc = 0; + } + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +static bool llc_ui_wait_for_conn(struct sock *sk, long timeout) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (sk_wait_event(sk, &timeout, + READ_ONCE(sk->sk_state) != TCP_SYN_SENT, &wait)) + break; + if (signal_pending(current) || !timeout) + break; + } + remove_wait_queue(sk_sleep(sk), &wait); + return timeout; +} + +static int llc_ui_wait_for_busy_core(struct sock *sk, long timeout) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct llc_sock *llc = llc_sk(sk); + int rc; + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + rc = 0; + if (sk_wait_event(sk, &timeout, + (READ_ONCE(sk->sk_shutdown) & RCV_SHUTDOWN) || + (!llc_data_accept_state(llc->state) && + !llc->remote_busy_flag && + !llc->p_flag), &wait)) + break; + rc = -ERESTARTSYS; + if (signal_pending(current)) + break; + rc = -EAGAIN; + if (!timeout) + break; + } + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +static int llc_wait_data(struct sock *sk, long timeo) +{ + int rc; + + while (1) { + /* + * POSIX 1003.1g mandates this order. + */ + rc = sock_error(sk); + if (rc) + break; + rc = 0; + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + rc = -EAGAIN; + if (!timeo) + break; + rc = sock_intr_errno(timeo); + if (signal_pending(current)) + break; + rc = 0; + if (sk_wait_data(sk, &timeo, NULL)) + break; + } + return rc; +} + +static void llc_cmsg_rcv(struct msghdr *msg, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(skb->sk); + + if (llc->cmsg_flags & LLC_CMSG_PKTINFO) { + struct llc_pktinfo info; + + memset(&info, 0, sizeof(info)); + info.lpi_ifindex = llc_sk(skb->sk)->dev->ifindex; + llc_pdu_decode_dsap(skb, &info.lpi_sap); + llc_pdu_decode_da(skb, info.lpi_mac); + put_cmsg(msg, SOL_LLC, LLC_OPT_PKTINFO, sizeof(info), &info); + } +} + +/** + * llc_ui_accept - accept a new incoming connection. + * @sock: Socket which connections arrive on. + * @newsock: Socket to move incoming connection to. + * @flags: User specified operational flags. + * @kern: If the socket is kernel internal + * + * Accept a new incoming connection. + * Returns 0 upon success, negative otherwise. 
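+ * The new sock has already been created by the LLC connection code and
+ * queued on the listening socket's receive queue; here it is only
+ * dequeued and grafted onto @newsock.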
+ */ +static int llc_ui_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *sk = sock->sk, *newsk; + struct llc_sock *llc, *newllc; + struct sk_buff *skb; + int rc = -EOPNOTSUPP; + + dprintk("%s: accepting on %02X\n", __func__, + llc_sk(sk)->laddr.lsap); + lock_sock(sk); + if (unlikely(sk->sk_type != SOCK_STREAM)) + goto out; + rc = -EINVAL; + if (unlikely(sock->state != SS_UNCONNECTED || + sk->sk_state != TCP_LISTEN)) + goto out; + /* wait for a connection to arrive. */ + if (skb_queue_empty(&sk->sk_receive_queue)) { + rc = llc_wait_data(sk, sk->sk_rcvtimeo); + if (rc) + goto out; + } + dprintk("%s: got a new connection on %02X\n", __func__, + llc_sk(sk)->laddr.lsap); + skb = skb_dequeue(&sk->sk_receive_queue); + rc = -EINVAL; + if (!skb->sk) + goto frees; + rc = 0; + newsk = skb->sk; + /* attach connection to a new socket. */ + llc_ui_sk_init(newsock, newsk); + sock_reset_flag(newsk, SOCK_ZAPPED); + newsk->sk_state = TCP_ESTABLISHED; + newsock->state = SS_CONNECTED; + llc = llc_sk(sk); + newllc = llc_sk(newsk); + memcpy(&newllc->addr, &llc->addr, sizeof(newllc->addr)); + newllc->link = llc_ui_next_link_no(newllc->laddr.lsap); + + /* put original socket back into a clean listen state. */ + sk->sk_state = TCP_LISTEN; + sk_acceptq_removed(sk); + dprintk("%s: ok success on %02X, client on %02X\n", __func__, + llc_sk(sk)->addr.sllc_sap, newllc->daddr.lsap); +frees: + kfree_skb(skb); +out: + release_sock(sk); + return rc; +} + +/** + * llc_ui_recvmsg - copy received data to the socket user. + * @sock: Socket to copy data from. + * @msg: Various user space related information. + * @len: Size of user buffer. + * @flags: User specified flags. + * + * Copy received data to the socket user. + * Returns non-negative upon success, negative otherwise. + */ +static int llc_ui_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_llc *, uaddr, msg->msg_name); + const int nonblock = flags & MSG_DONTWAIT; + struct sk_buff *skb = NULL; + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + size_t copied = 0; + u32 peek_seq = 0; + u32 *seq, skb_len; + unsigned long used; + int target; /* Read at least this many bytes */ + long timeo; + + lock_sock(sk); + copied = -ENOTCONN; + if (unlikely(sk->sk_type == SOCK_STREAM && sk->sk_state == TCP_LISTEN)) + goto out; + + timeo = sock_rcvtimeo(sk, nonblock); + + seq = &llc->copied_seq; + if (flags & MSG_PEEK) { + peek_seq = llc->copied_seq; + seq = &peek_seq; + } + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + copied = 0; + + do { + u32 offset; + + /* + * We need to check signals first, to get correct SIGURG + * handling. FIXME: Need to check this doesn't impact 1003.1g + * and move it down to the bottom of the loop + */ + if (signal_pending(current)) { + if (copied) + break; + copied = timeo ? sock_intr_errno(timeo) : -EAGAIN; + break; + } + + /* Next get a buffer. */ + + skb = skb_peek(&sk->sk_receive_queue); + if (skb) { + offset = *seq; + goto found_ok_skb; + } + /* Well, if we have backlog, try to process it now yet. 
*/ + + if (copied >= target && !READ_ONCE(sk->sk_backlog.tail)) + break; + + if (copied) { + if (sk->sk_err || + sk->sk_state == TCP_CLOSE || + (sk->sk_shutdown & RCV_SHUTDOWN) || + !timeo || + (flags & MSG_PEEK)) + break; + } else { + if (sock_flag(sk, SOCK_DONE)) + break; + + if (sk->sk_err) { + copied = sock_error(sk); + break; + } + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + + if (sk->sk_type == SOCK_STREAM && sk->sk_state == TCP_CLOSE) { + if (!sock_flag(sk, SOCK_DONE)) { + /* + * This occurs when user tries to read + * from never connected socket. + */ + copied = -ENOTCONN; + break; + } + break; + } + if (!timeo) { + copied = -EAGAIN; + break; + } + } + + if (copied >= target) { /* Do not sleep, just process backlog. */ + release_sock(sk); + lock_sock(sk); + } else + sk_wait_data(sk, &timeo, NULL); + + if ((flags & MSG_PEEK) && peek_seq != llc->copied_seq) { + net_dbg_ratelimited("LLC(%s:%d): Application bug, race in MSG_PEEK\n", + current->comm, + task_pid_nr(current)); + peek_seq = llc->copied_seq; + } + continue; + found_ok_skb: + skb_len = skb->len; + /* Ok so how much can we use? */ + used = skb->len - offset; + if (len < used) + used = len; + + if (!(flags & MSG_TRUNC)) { + int rc = skb_copy_datagram_msg(skb, offset, msg, used); + if (rc) { + /* Exception. Bailout! */ + if (!copied) + copied = -EFAULT; + break; + } + } + + *seq += used; + copied += used; + len -= used; + + /* For non stream protcols we get one packet per recvmsg call */ + if (sk->sk_type != SOCK_STREAM) + goto copy_uaddr; + + if (!(flags & MSG_PEEK)) { + skb_unlink(skb, &sk->sk_receive_queue); + kfree_skb(skb); + *seq = 0; + } + + /* Partial read */ + if (used + offset < skb_len) + continue; + } while (len > 0); + +out: + release_sock(sk); + return copied; +copy_uaddr: + if (uaddr != NULL && skb != NULL) { + memcpy(uaddr, llc_ui_skb_cb(skb), sizeof(*uaddr)); + msg->msg_namelen = sizeof(*uaddr); + } + if (llc_sk(sk)->cmsg_flags) + llc_cmsg_rcv(msg, skb); + + if (!(flags & MSG_PEEK)) { + skb_unlink(skb, &sk->sk_receive_queue); + kfree_skb(skb); + *seq = 0; + } + + goto out; +} + +/** + * llc_ui_sendmsg - Transmit data provided by the socket user. + * @sock: Socket to transmit data from. + * @msg: Various user related information. + * @len: Length of data to transmit. + * + * Transmit data provided by the socket user. + * Returns non-negative upon success, negative otherwise. + */ +static int llc_ui_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_llc *, addr, msg->msg_name); + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + int flags = msg->msg_flags; + int noblock = flags & MSG_DONTWAIT; + int rc = -EINVAL, copied = 0, hdrlen, hh_len; + struct sk_buff *skb = NULL; + struct net_device *dev; + size_t size = 0; + + dprintk("%s: sending from %02X to %02X\n", __func__, + llc->laddr.lsap, llc->daddr.lsap); + lock_sock(sk); + if (addr) { + if (msg->msg_namelen < sizeof(*addr)) + goto out; + } else { + if (llc_ui_addr_null(&llc->addr)) + goto out; + addr = &llc->addr; + } + /* must bind connection to sap if user hasn't done it. */ + if (sock_flag(sk, SOCK_ZAPPED)) { + /* bind to sap with null dev, exclusive. 
*/ + rc = llc_ui_autobind(sock, addr); + if (rc) + goto out; + } + dev = llc->dev; + hh_len = LL_RESERVED_SPACE(dev); + hdrlen = llc_ui_header_len(sk, addr); + size = hdrlen + len; + size = min_t(size_t, size, READ_ONCE(dev->mtu)); + copied = size - hdrlen; + rc = -EINVAL; + if (copied < 0) + goto out; + release_sock(sk); + skb = sock_alloc_send_skb(sk, hh_len + size, noblock, &rc); + lock_sock(sk); + if (!skb) + goto out; + if (sock_flag(sk, SOCK_ZAPPED) || + llc->dev != dev || + hdrlen != llc_ui_header_len(sk, addr) || + hh_len != LL_RESERVED_SPACE(dev) || + size > READ_ONCE(dev->mtu)) + goto out; + skb->dev = dev; + skb->protocol = llc_proto_type(addr->sllc_arphrd); + skb_reserve(skb, hh_len + hdrlen); + rc = memcpy_from_msg(skb_put(skb, copied), msg, copied); + if (rc) + goto out; + if (sk->sk_type == SOCK_DGRAM || addr->sllc_ua) { + llc_build_and_send_ui_pkt(llc->sap, skb, addr->sllc_mac, + addr->sllc_sap); + skb = NULL; + goto out; + } + if (addr->sllc_test) { + llc_build_and_send_test_pkt(llc->sap, skb, addr->sllc_mac, + addr->sllc_sap); + skb = NULL; + goto out; + } + if (addr->sllc_xid) { + llc_build_and_send_xid_pkt(llc->sap, skb, addr->sllc_mac, + addr->sllc_sap); + skb = NULL; + goto out; + } + rc = -ENOPROTOOPT; + if (!(sk->sk_type == SOCK_STREAM && !addr->sllc_ua)) + goto out; + rc = llc_ui_send_data(sk, skb, noblock); + skb = NULL; +out: + kfree_skb(skb); + if (rc) + dprintk("%s: failed sending from %02X to %02X: %d\n", + __func__, llc->laddr.lsap, llc->daddr.lsap, rc); + release_sock(sk); + return rc ? : copied; +} + +/** + * llc_ui_getname - return the address info of a socket + * @sock: Socket to get address of. + * @uaddr: Address structure to return information. + * @peer: Does user want local or remote address information. + * + * Return the address information of a socket. + */ +static int llc_ui_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_llc sllc; + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + int rc = -EBADF; + + memset(&sllc, 0, sizeof(sllc)); + lock_sock(sk); + if (sock_flag(sk, SOCK_ZAPPED)) + goto out; + if (peer) { + rc = -ENOTCONN; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + if(llc->dev) + sllc.sllc_arphrd = llc->dev->type; + sllc.sllc_sap = llc->daddr.lsap; + memcpy(&sllc.sllc_mac, &llc->daddr.mac, IFHWADDRLEN); + } else { + rc = -EINVAL; + if (!llc->sap) + goto out; + sllc.sllc_sap = llc->sap->laddr.lsap; + + if (llc->dev) { + sllc.sllc_arphrd = llc->dev->type; + memcpy(&sllc.sllc_mac, llc->dev->dev_addr, + IFHWADDRLEN); + } + } + sllc.sllc_family = AF_LLC; + memcpy(uaddr, &sllc, sizeof(sllc)); + rc = sizeof(sllc); +out: + release_sock(sk); + return rc; +} + +/** + * llc_ui_ioctl - io controls for PF_LLC + * @sock: Socket to get/set info + * @cmd: command + * @arg: optional argument for cmd + * + * get/set info on llc sockets + */ +static int llc_ui_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + return -ENOIOCTLCMD; +} + +/** + * llc_ui_setsockopt - set various connection specific parameters. + * @sock: Socket to set options on. + * @level: Socket level user is requesting operations on. + * @optname: Operation name. + * @optval: User provided operation data. + * @optlen: Length of optval. + * + * Set various connection specific parameters. 
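+ * The options map onto the LLC2 tunables: retry count (N2), maximum PDU
+ * size (N1), the ack/P/REJ/busy timer expiries (in seconds) and the
+ * transmit/receive window sizes, plus the LLC_OPT_PKTINFO cmsg flag.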
+ */ +static int llc_ui_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + unsigned int opt; + int rc = -EINVAL; + + lock_sock(sk); + if (unlikely(level != SOL_LLC || optlen != sizeof(int))) + goto out; + rc = copy_from_sockptr(&opt, optval, sizeof(opt)); + if (rc) + goto out; + rc = -EINVAL; + switch (optname) { + case LLC_OPT_RETRY: + if (opt > LLC_OPT_MAX_RETRY) + goto out; + llc->n2 = opt; + break; + case LLC_OPT_SIZE: + if (opt > LLC_OPT_MAX_SIZE) + goto out; + llc->n1 = opt; + break; + case LLC_OPT_ACK_TMR_EXP: + if (opt > LLC_OPT_MAX_ACK_TMR_EXP) + goto out; + llc->ack_timer.expire = opt * HZ; + break; + case LLC_OPT_P_TMR_EXP: + if (opt > LLC_OPT_MAX_P_TMR_EXP) + goto out; + llc->pf_cycle_timer.expire = opt * HZ; + break; + case LLC_OPT_REJ_TMR_EXP: + if (opt > LLC_OPT_MAX_REJ_TMR_EXP) + goto out; + llc->rej_sent_timer.expire = opt * HZ; + break; + case LLC_OPT_BUSY_TMR_EXP: + if (opt > LLC_OPT_MAX_BUSY_TMR_EXP) + goto out; + llc->busy_state_timer.expire = opt * HZ; + break; + case LLC_OPT_TX_WIN: + if (opt > LLC_OPT_MAX_WIN) + goto out; + llc->k = opt; + break; + case LLC_OPT_RX_WIN: + if (opt > LLC_OPT_MAX_WIN) + goto out; + llc->rw = opt; + break; + case LLC_OPT_PKTINFO: + if (opt) + llc->cmsg_flags |= LLC_CMSG_PKTINFO; + else + llc->cmsg_flags &= ~LLC_CMSG_PKTINFO; + break; + default: + rc = -ENOPROTOOPT; + goto out; + } + rc = 0; +out: + release_sock(sk); + return rc; +} + +/** + * llc_ui_getsockopt - get connection specific socket info + * @sock: Socket to get information from. + * @level: Socket level user is requesting operations on. + * @optname: Operation name. + * @optval: Variable to return operation data in. + * @optlen: Length of optval. + * + * Get connection specific socket information. 
+ */ +static int llc_ui_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct llc_sock *llc = llc_sk(sk); + int val = 0, len = 0, rc = -EINVAL; + + lock_sock(sk); + if (unlikely(level != SOL_LLC)) + goto out; + rc = get_user(len, optlen); + if (rc) + goto out; + rc = -EINVAL; + if (len != sizeof(int)) + goto out; + switch (optname) { + case LLC_OPT_RETRY: + val = llc->n2; break; + case LLC_OPT_SIZE: + val = llc->n1; break; + case LLC_OPT_ACK_TMR_EXP: + val = llc->ack_timer.expire / HZ; break; + case LLC_OPT_P_TMR_EXP: + val = llc->pf_cycle_timer.expire / HZ; break; + case LLC_OPT_REJ_TMR_EXP: + val = llc->rej_sent_timer.expire / HZ; break; + case LLC_OPT_BUSY_TMR_EXP: + val = llc->busy_state_timer.expire / HZ; break; + case LLC_OPT_TX_WIN: + val = llc->k; break; + case LLC_OPT_RX_WIN: + val = llc->rw; break; + case LLC_OPT_PKTINFO: + val = (llc->cmsg_flags & LLC_CMSG_PKTINFO) != 0; + break; + default: + rc = -ENOPROTOOPT; + goto out; + } + rc = 0; + if (put_user(len, optlen) || copy_to_user(optval, &val, len)) + rc = -EFAULT; +out: + release_sock(sk); + return rc; +} + +static const struct net_proto_family llc_ui_family_ops = { + .family = PF_LLC, + .create = llc_ui_create, + .owner = THIS_MODULE, +}; + +static const struct proto_ops llc_ui_ops = { + .family = PF_LLC, + .owner = THIS_MODULE, + .release = llc_ui_release, + .bind = llc_ui_bind, + .connect = llc_ui_connect, + .socketpair = sock_no_socketpair, + .accept = llc_ui_accept, + .getname = llc_ui_getname, + .poll = datagram_poll, + .ioctl = llc_ui_ioctl, + .listen = llc_ui_listen, + .shutdown = llc_ui_shutdown, + .setsockopt = llc_ui_setsockopt, + .getsockopt = llc_ui_getsockopt, + .sendmsg = llc_ui_sendmsg, + .recvmsg = llc_ui_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static const char llc_proc_err_msg[] __initconst = + KERN_CRIT "LLC: Unable to register the proc_fs entries\n"; +static const char llc_sysctl_err_msg[] __initconst = + KERN_CRIT "LLC: Unable to register the sysctl entries\n"; +static const char llc_sock_err_msg[] __initconst = + KERN_CRIT "LLC: Unable to register the network family\n"; + +static int __init llc2_init(void) +{ + int rc = proto_register(&llc_proto, 0); + + if (rc != 0) + goto out; + + llc_build_offset_table(); + llc_station_init(); + llc_ui_sap_last_autoport = LLC_SAP_DYN_START; + rc = llc_proc_init(); + if (rc != 0) { + printk(llc_proc_err_msg); + goto out_station; + } + rc = llc_sysctl_init(); + if (rc) { + printk(llc_sysctl_err_msg); + goto out_proc; + } + rc = sock_register(&llc_ui_family_ops); + if (rc) { + printk(llc_sock_err_msg); + goto out_sysctl; + } + llc_add_pack(LLC_DEST_SAP, llc_sap_handler); + llc_add_pack(LLC_DEST_CONN, llc_conn_handler); +out: + return rc; +out_sysctl: + llc_sysctl_exit(); +out_proc: + llc_proc_exit(); +out_station: + llc_station_exit(); + proto_unregister(&llc_proto); + goto out; +} + +static void __exit llc2_exit(void) +{ + llc_station_exit(); + llc_remove_pack(LLC_DEST_SAP); + llc_remove_pack(LLC_DEST_CONN); + sock_unregister(PF_LLC); + llc_proc_exit(); + llc_sysctl_exit(); + proto_unregister(&llc_proto); +} + +module_init(llc2_init); +module_exit(llc2_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Procom 1997, Jay Schullist 2001, Arnaldo C. 
Melo 2001-2003"); +MODULE_DESCRIPTION("IEEE 802.2 PF_LLC support"); +MODULE_ALIAS_NETPROTO(PF_LLC); diff --git a/net/llc/llc_c_ac.c b/net/llc/llc_c_ac.c new file mode 100644 index 000000000..40ca3c1e4 --- /dev/null +++ b/net/llc/llc_c_ac.c @@ -0,0 +1,1451 @@ +/* + * llc_c_ac.c - actions performed during connection state transition. + * + * Description: + * Functions in this module are implementation of connection component actions + * Details of actions can be found in IEEE-802.2 standard document. + * All functions have one connection and one event as input argument. All of + * them return 0 On success and 1 otherwise. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +static int llc_conn_ac_inc_vs_by_1(struct sock *sk, struct sk_buff *skb); +static void llc_process_tmr_ev(struct sock *sk, struct sk_buff *skb); +static int llc_conn_ac_data_confirm(struct sock *sk, struct sk_buff *ev); + +static int llc_conn_ac_inc_npta_value(struct sock *sk, struct sk_buff *skb); + +static int llc_conn_ac_send_rr_rsp_f_set_ackpf(struct sock *sk, + struct sk_buff *skb); + +static int llc_conn_ac_set_p_flag_1(struct sock *sk, struct sk_buff *skb); + +#define INCORRECT 0 + +int llc_conn_ac_clear_remote_busy(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (llc->remote_busy_flag) { + u8 nr; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + llc->remote_busy_flag = 0; + del_timer(&llc->busy_state_timer.timer); + nr = LLC_I_GET_NR(pdu); + llc_conn_resend_i_pdu_as_cmd(sk, nr, 0); + } + return 0; +} + +int llc_conn_ac_conn_ind(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->ind_prim = LLC_CONN_PRIM; + return 0; +} + +int llc_conn_ac_conn_confirm(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->cfm_prim = LLC_CONN_PRIM; + return 0; +} + +static int llc_conn_ac_data_confirm(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->cfm_prim = LLC_DATA_PRIM; + return 0; +} + +int llc_conn_ac_data_ind(struct sock *sk, struct sk_buff *skb) +{ + llc_conn_rtn_pdu(sk, skb); + return 0; +} + +int llc_conn_ac_disc_ind(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + u8 reason = 0; + int rc = 0; + + if (ev->type == LLC_CONN_EV_TYPE_PDU) { + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + if (LLC_PDU_IS_RSP(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_2_PDU_RSP_DM) + reason = LLC_DISC_REASON_RX_DM_RSP_PDU; + else if (LLC_PDU_IS_CMD(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_2_PDU_CMD_DISC) + reason = LLC_DISC_REASON_RX_DISC_CMD_PDU; + } else if (ev->type == LLC_CONN_EV_TYPE_ACK_TMR) + reason = LLC_DISC_REASON_ACK_TMR_EXP; + else + rc = -EINVAL; + if (!rc) { + ev->reason = reason; + ev->ind_prim = LLC_DISC_PRIM; + } + return rc; +} + +int llc_conn_ac_disc_confirm(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->reason = 
ev->status; + ev->cfm_prim = LLC_DISC_PRIM; + return 0; +} + +int llc_conn_ac_rst_ind(struct sock *sk, struct sk_buff *skb) +{ + u8 reason = 0; + int rc = 1; + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + struct llc_sock *llc = llc_sk(sk); + + switch (ev->type) { + case LLC_CONN_EV_TYPE_PDU: + if (LLC_PDU_IS_RSP(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_2_PDU_RSP_FRMR) { + reason = LLC_RESET_REASON_LOCAL; + rc = 0; + } else if (LLC_PDU_IS_CMD(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_2_PDU_CMD_SABME) { + reason = LLC_RESET_REASON_REMOTE; + rc = 0; + } + break; + case LLC_CONN_EV_TYPE_ACK_TMR: + case LLC_CONN_EV_TYPE_P_TMR: + case LLC_CONN_EV_TYPE_REJ_TMR: + case LLC_CONN_EV_TYPE_BUSY_TMR: + if (llc->retry_count > llc->n2) { + reason = LLC_RESET_REASON_LOCAL; + rc = 0; + } + break; + } + if (!rc) { + ev->reason = reason; + ev->ind_prim = LLC_RESET_PRIM; + } + return rc; +} + +int llc_conn_ac_rst_confirm(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->reason = 0; + ev->cfm_prim = LLC_RESET_PRIM; + return 0; +} + +int llc_conn_ac_clear_remote_busy_if_f_eq_1(struct sock *sk, + struct sk_buff *skb) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + if (LLC_PDU_IS_RSP(pdu) && + LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_1(pdu) && llc_sk(sk)->ack_pf) + llc_conn_ac_clear_remote_busy(sk, skb); + return 0; +} + +int llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2(struct sock *sk, + struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (llc->data_flag == 2) + del_timer(&llc->rej_sent_timer.timer); + return 0; +} + +int llc_conn_ac_send_disc_cmd_p_set_x(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_disc_cmd(nskb, 1); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + llc_conn_ac_set_p_flag_1(sk, skb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_dm_rsp_f_set_p(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + u8 f_bit; + + llc_pdu_decode_pf_bit(skb, &f_bit); + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_dm_rsp(nskb, f_bit); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_dm_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_dm_rsp(nskb, 1); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + 
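+/*
+ * Most of the llc_conn_ac_send_*() and llc_conn_ac_resend_*() actions
+ * below follow the same shape: allocate a fresh frame with
+ * llc_alloc_frame(), fill in the LLC PDU header and the requested PDU
+ * type, build the MAC header and hand the result to llc_conn_send_pdu().
+ * If the MAC header cannot be built the frame is freed and the error is
+ * returned.
+ */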
+int llc_conn_ac_send_frmr_rsp_f_set_x(struct sock *sk, struct sk_buff *skb) +{ + u8 f_bit; + int rc = -ENOBUFS; + struct sk_buff *nskb; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + struct llc_sock *llc = llc_sk(sk); + + llc->rx_pdu_hdr = *((u32 *)pdu); + if (LLC_PDU_IS_CMD(pdu)) + llc_pdu_decode_pf_bit(skb, &f_bit); + else + f_bit = 0; + nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, + sizeof(struct llc_frmr_info)); + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_frmr_rsp(nskb, pdu, f_bit, llc->vS, + llc->vR, INCORRECT); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_resend_frmr_rsp_f_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, + sizeof(struct llc_frmr_info)); + + if (nskb) { + struct llc_sap *sap = llc->sap; + struct llc_pdu_sn *pdu = (struct llc_pdu_sn *)&llc->rx_pdu_hdr; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_frmr_rsp(nskb, pdu, 0, llc->vS, + llc->vR, INCORRECT); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_resend_frmr_rsp_f_set_p(struct sock *sk, struct sk_buff *skb) +{ + u8 f_bit; + int rc = -ENOBUFS; + struct sk_buff *nskb; + struct llc_sock *llc = llc_sk(sk); + + llc_pdu_decode_pf_bit(skb, &f_bit); + nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, + sizeof(struct llc_frmr_info)); + if (nskb) { + struct llc_sap *sap = llc->sap; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_frmr_rsp(nskb, pdu, f_bit, llc->vS, + llc->vR, INCORRECT); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_i_cmd_p_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc; + struct llc_sock *llc = llc_sk(sk); + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_I, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_i_cmd(skb, 1, llc->vS, llc->vR); + rc = llc_mac_hdr_init(skb, llc->dev->dev_addr, llc->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + llc_conn_send_pdu(sk, skb); + llc_conn_ac_inc_vs_by_1(sk, skb); + } + return rc; +} + +static int llc_conn_ac_send_i_cmd_p_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc; + struct llc_sock *llc = llc_sk(sk); + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_I, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_i_cmd(skb, 0, llc->vS, llc->vR); + rc = llc_mac_hdr_init(skb, llc->dev->dev_addr, llc->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + llc_conn_send_pdu(sk, skb); + llc_conn_ac_inc_vs_by_1(sk, skb); + } + return rc; +} + +int llc_conn_ac_send_i_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc; + struct llc_sock *llc = llc_sk(sk); + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_I, 
sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_i_cmd(skb, 0, llc->vS, llc->vR); + rc = llc_mac_hdr_init(skb, llc->dev->dev_addr, llc->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + llc_conn_send_pdu(sk, skb); + llc_conn_ac_inc_vs_by_1(sk, skb); + } + return 0; +} + +int llc_conn_ac_resend_i_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + u8 nr = LLC_I_GET_NR(pdu); + + llc_conn_resend_i_pdu_as_cmd(sk, nr, 0); + return 0; +} + +int llc_conn_ac_resend_i_xxx_x_set_0_or_send_rr(struct sock *sk, + struct sk_buff *skb) +{ + u8 nr; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (likely(!rc)) + llc_conn_send_pdu(sk, nskb); + else + kfree_skb(skb); + } + if (rc) { + nr = LLC_I_GET_NR(pdu); + rc = 0; + llc_conn_resend_i_pdu_as_cmd(sk, nr, 0); + } + return rc; +} + +int llc_conn_ac_resend_i_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + u8 nr = LLC_I_GET_NR(pdu); + + llc_conn_resend_i_pdu_as_rsp(sk, nr, 1); + return 0; +} + +int llc_conn_ac_send_rej_cmd_p_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_rej_cmd(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rej_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rej_rsp(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rej_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rej_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rnr_cmd_p_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, 
LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_rnr_cmd(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rnr_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rnr_rsp(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rnr_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rnr_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_set_remote_busy(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (!llc->remote_busy_flag) { + llc->remote_busy_flag = 1; + mod_timer(&llc->busy_state_timer.timer, + jiffies + llc->busy_state_timer.expire); + } + return 0; +} + +int llc_conn_ac_opt_send_rnr_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rnr_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rr_cmd_p_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_rr_cmd(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rr_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + u8 f_bit = 1; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, f_bit, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + 
return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_ack_rsp_f_set_1(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, 1, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_rr_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_ack_xxx_x_set_0(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, 0, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +void llc_conn_set_p_flag(struct sock *sk, u8 value) +{ + int state_changed = llc_sk(sk)->p_flag && !value; + + llc_sk(sk)->p_flag = value; + + if (state_changed) + sk->sk_state_change(sk); +} + +int llc_conn_ac_send_sabme_cmd_p_set_x(struct sock *sk, struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + const u8 *dmac = llc->daddr.mac; + + if (llc->dev->flags & IFF_LOOPBACK) + dmac = llc->dev->dev_addr; + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_sabme_cmd(nskb, 1); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, dmac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + llc_conn_set_p_flag(sk, 1); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_send_ua_rsp_f_set_p(struct sock *sk, struct sk_buff *skb) +{ + u8 f_bit; + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_U, 0); + + llc_pdu_decode_pf_bit(skb, &f_bit); + if (nskb) { + struct llc_sap *sap = llc->sap; + + nskb->dev = llc->dev; + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_ua_rsp(nskb, f_bit); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +int llc_conn_ac_set_s_flag_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->s_flag = 0; + return 0; +} + +int 
llc_conn_ac_set_s_flag_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->s_flag = 1; + return 0; +} + +int llc_conn_ac_start_p_timer(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + llc_conn_set_p_flag(sk, 1); + mod_timer(&llc->pf_cycle_timer.timer, + jiffies + llc->pf_cycle_timer.expire); + return 0; +} + +/** + * llc_conn_ac_send_ack_if_needed - check if ack is needed + * @sk: current connection structure + * @skb: current event + * + * Checks number of received PDUs which have not been acknowledged, yet, + * If number of them reaches to "npta"(Number of PDUs To Acknowledge) then + * sends an RR response as acknowledgement for them. Returns 0 for + * success, 1 otherwise. + */ +int llc_conn_ac_send_ack_if_needed(struct sock *sk, struct sk_buff *skb) +{ + u8 pf_bit; + struct llc_sock *llc = llc_sk(sk); + + llc_pdu_decode_pf_bit(skb, &pf_bit); + llc->ack_pf |= pf_bit & 1; + if (!llc->ack_must_be_send) { + llc->first_pdu_Ns = llc->vR; + llc->ack_must_be_send = 1; + llc->ack_pf = pf_bit & 1; + } + if (((llc->vR - llc->first_pdu_Ns + 1 + LLC_2_SEQ_NBR_MODULO) + % LLC_2_SEQ_NBR_MODULO) >= llc->npta) { + llc_conn_ac_send_rr_rsp_f_set_ackpf(sk, skb); + llc->ack_must_be_send = 0; + llc->ack_pf = 0; + llc_conn_ac_inc_npta_value(sk, skb); + } + return 0; +} + +/** + * llc_conn_ac_rst_sendack_flag - resets ack_must_be_send flag + * @sk: current connection structure + * @skb: current event + * + * This action resets ack_must_be_send flag of given connection, this flag + * indicates if there is any PDU which has not been acknowledged yet. + * Returns 0 for success, 1 otherwise. + */ +int llc_conn_ac_rst_sendack_flag(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->ack_must_be_send = llc_sk(sk)->ack_pf = 0; + return 0; +} + +/** + * llc_conn_ac_send_i_rsp_f_set_ackpf - acknowledge received PDUs + * @sk: current connection structure + * @skb: current event + * + * Sends an I response PDU with f-bit set to ack_pf flag as acknowledge to + * all received PDUs which have not been acknowledged, yet. ack_pf flag is + * set to one if one PDU with p-bit set to one is received. Returns 0 for + * success, 1 otherwise. + */ +static int llc_conn_ac_send_i_rsp_f_set_ackpf(struct sock *sk, + struct sk_buff *skb) +{ + int rc; + struct llc_sock *llc = llc_sk(sk); + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_I, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_i_cmd(skb, llc->ack_pf, llc->vS, llc->vR); + rc = llc_mac_hdr_init(skb, llc->dev->dev_addr, llc->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + llc_conn_send_pdu(sk, skb); + llc_conn_ac_inc_vs_by_1(sk, skb); + } + return rc; +} + +/** + * llc_conn_ac_send_i_as_ack - sends an I-format PDU to acknowledge rx PDUs + * @sk: current connection structure. + * @skb: current event. + * + * This action sends an I-format PDU as acknowledge to received PDUs which + * have not been acknowledged, yet, if there is any. By using of this + * action number of acknowledgements decreases, this technic is called + * piggy backing. Returns 0 for success, 1 otherwise. 
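+ *
+ * Editor's illustrative note, not part of the upstream comment: the
+ * decision made by this action reduces, in effect, to
+ *
+ *	if (llc->ack_must_be_send)
+ *		send an I response with the f-bit taken from ack_pf,
+ *		then clear ack_must_be_send and ack_pf;  -- ack rides on data
+ *	else
+ *		send an ordinary I command with the p-bit set to 0;
+ *
+ * so a pending acknowledgement is folded into the next outgoing I PDU
+ * instead of costing a separate RR response.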
+ */ +int llc_conn_ac_send_i_as_ack(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + int ret; + + if (llc->ack_must_be_send) { + ret = llc_conn_ac_send_i_rsp_f_set_ackpf(sk, skb); + llc->ack_must_be_send = 0 ; + llc->ack_pf = 0; + } else { + ret = llc_conn_ac_send_i_cmd_p_set_0(sk, skb); + } + + return ret; +} + +/** + * llc_conn_ac_send_rr_rsp_f_set_ackpf - ack all rx PDUs not yet acked + * @sk: current connection structure. + * @skb: current event. + * + * This action sends an RR response with f-bit set to ack_pf flag as + * acknowledge to all received PDUs which have not been acknowledged, yet, + * if there is any. ack_pf flag indicates if a PDU has been received with + * p-bit set to one. Returns 0 for success, 1 otherwise. + */ +static int llc_conn_ac_send_rr_rsp_f_set_ackpf(struct sock *sk, + struct sk_buff *skb) +{ + int rc = -ENOBUFS; + struct llc_sock *llc = llc_sk(sk); + struct sk_buff *nskb = llc_alloc_frame(sk, llc->dev, LLC_PDU_TYPE_S, 0); + + if (nskb) { + struct llc_sap *sap = llc->sap; + + llc_pdu_header_init(nskb, LLC_PDU_TYPE_S, sap->laddr.lsap, + llc->daddr.lsap, LLC_PDU_RSP); + llc_pdu_init_as_rr_rsp(nskb, llc->ack_pf, llc->vR); + rc = llc_mac_hdr_init(nskb, llc->dev->dev_addr, llc->daddr.mac); + if (unlikely(rc)) + goto free; + llc_conn_send_pdu(sk, nskb); + } +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +/** + * llc_conn_ac_inc_npta_value - tries to make value of npta greater + * @sk: current connection structure. + * @skb: current event. + * + * After "inc_cntr" times calling of this action, "npta" increase by one. + * this action tries to make vale of "npta" greater as possible; number of + * acknowledgements decreases by increasing of "npta". Returns 0 for + * success, 1 otherwise. + */ +static int llc_conn_ac_inc_npta_value(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (!llc->inc_cntr) { + llc->dec_step = 0; + llc->dec_cntr = llc->inc_cntr = 2; + ++llc->npta; + if (llc->npta > (u8) ~LLC_2_SEQ_NBR_MODULO) + llc->npta = (u8) ~LLC_2_SEQ_NBR_MODULO; + } else + --llc->inc_cntr; + return 0; +} + +/** + * llc_conn_ac_adjust_npta_by_rr - decreases "npta" by one + * @sk: current connection structure. + * @skb: current event. + * + * After receiving "dec_cntr" times RR command, this action decreases + * "npta" by one. Returns 0 for success, 1 otherwise. + */ +int llc_conn_ac_adjust_npta_by_rr(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (!llc->connect_step && !llc->remote_busy_flag) { + if (!llc->dec_step) { + if (!llc->dec_cntr) { + llc->inc_cntr = llc->dec_cntr = 2; + if (llc->npta > 0) + llc->npta = llc->npta - 1; + } else + llc->dec_cntr -=1; + } + } else + llc->connect_step = 0 ; + return 0; +} + +/** + * llc_conn_ac_adjust_npta_by_rnr - decreases "npta" by one + * @sk: current connection structure. + * @skb: current event. + * + * After receiving "dec_cntr" times RNR command, this action decreases + * "npta" by one. Returns 0 for success, 1 otherwise. + */ +int llc_conn_ac_adjust_npta_by_rnr(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (llc->remote_busy_flag) + if (!llc->dec_step) { + if (!llc->dec_cntr) { + llc->inc_cntr = llc->dec_cntr = 2; + if (llc->npta > 0) + --llc->npta; + } else + --llc->dec_cntr; + } + return 0; +} + +/** + * llc_conn_ac_dec_tx_win_size - decreases tx window size + * @sk: current connection structure. + * @skb: current event. 
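+ *
+ * Editor's note, not part of the upstream comment -- a worked example of
+ * the adjustment implemented below: with k == 7 and three PDUs still
+ * sitting in pdu_unack_q the window shrinks to 4; with k == 2 and five
+ * PDUs outstanding it is clamped to the minimum of 1.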
+ * + * After receiving of a REJ command or response, transmit window size is + * decreased by number of PDUs which are outstanding yet. Returns 0 for + * success, 1 otherwise. + */ +int llc_conn_ac_dec_tx_win_size(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + u8 unacked_pdu = skb_queue_len(&llc->pdu_unack_q); + + if (llc->k - unacked_pdu < 1) + llc->k = 1; + else + llc->k -= unacked_pdu; + return 0; +} + +/** + * llc_conn_ac_inc_tx_win_size - tx window size is inc by 1 + * @sk: current connection structure. + * @skb: current event. + * + * After receiving an RR response with f-bit set to one, transmit window + * size is increased by one. Returns 0 for success, 1 otherwise. + */ +int llc_conn_ac_inc_tx_win_size(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + llc->k += 1; + if (llc->k > (u8) ~LLC_2_SEQ_NBR_MODULO) + llc->k = (u8) ~LLC_2_SEQ_NBR_MODULO; + return 0; +} + +int llc_conn_ac_stop_all_timers(struct sock *sk, struct sk_buff *skb) +{ + llc_sk_stop_all_timers(sk, false); + return 0; +} + +int llc_conn_ac_stop_other_timers(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + del_timer(&llc->rej_sent_timer.timer); + del_timer(&llc->pf_cycle_timer.timer); + del_timer(&llc->busy_state_timer.timer); + llc->ack_must_be_send = 0; + llc->ack_pf = 0; + return 0; +} + +int llc_conn_ac_start_ack_timer(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + mod_timer(&llc->ack_timer.timer, jiffies + llc->ack_timer.expire); + return 0; +} + +int llc_conn_ac_start_rej_timer(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + mod_timer(&llc->rej_sent_timer.timer, + jiffies + llc->rej_sent_timer.expire); + return 0; +} + +int llc_conn_ac_start_ack_tmr_if_not_running(struct sock *sk, + struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + if (!timer_pending(&llc->ack_timer.timer)) + mod_timer(&llc->ack_timer.timer, + jiffies + llc->ack_timer.expire); + return 0; +} + +int llc_conn_ac_stop_ack_timer(struct sock *sk, struct sk_buff *skb) +{ + del_timer(&llc_sk(sk)->ack_timer.timer); + return 0; +} + +int llc_conn_ac_stop_p_timer(struct sock *sk, struct sk_buff *skb) +{ + struct llc_sock *llc = llc_sk(sk); + + del_timer(&llc->pf_cycle_timer.timer); + llc_conn_set_p_flag(sk, 0); + return 0; +} + +int llc_conn_ac_stop_rej_timer(struct sock *sk, struct sk_buff *skb) +{ + del_timer(&llc_sk(sk)->rej_sent_timer.timer); + return 0; +} + +int llc_conn_ac_upd_nr_received(struct sock *sk, struct sk_buff *skb) +{ + int acked; + u16 unacked = 0; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + struct llc_sock *llc = llc_sk(sk); + + llc->last_nr = PDU_SUPV_GET_Nr(pdu); + acked = llc_conn_remove_acked_pdus(sk, llc->last_nr, &unacked); + /* On loopback we don't queue I frames in unack_pdu_q queue. */ + if (acked > 0 || (llc->dev->flags & IFF_LOOPBACK)) { + llc->retry_count = 0; + del_timer(&llc->ack_timer.timer); + if (llc->failed_data_req) { + /* already, we did not accept data from upper layer + * (tx_window full or unacceptable state). Now, we + * can send data and must inform to upper layer. 
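+			 * (Editor's note, not in the upstream source: the two
+			 * statements right below do exactly that -- they clear
+			 * failed_data_req and issue a data confirm towards the
+			 * upper layer so that it may resume sending.)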
+ */ + llc->failed_data_req = 0; + llc_conn_ac_data_confirm(sk, skb); + } + if (unacked) + mod_timer(&llc->ack_timer.timer, + jiffies + llc->ack_timer.expire); + } else if (llc->failed_data_req) { + u8 f_bit; + + llc_pdu_decode_pf_bit(skb, &f_bit); + if (f_bit == 1) { + llc->failed_data_req = 0; + llc_conn_ac_data_confirm(sk, skb); + } + } + return 0; +} + +int llc_conn_ac_upd_p_flag(struct sock *sk, struct sk_buff *skb) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + if (LLC_PDU_IS_RSP(pdu)) { + u8 f_bit; + + llc_pdu_decode_pf_bit(skb, &f_bit); + if (f_bit) { + llc_conn_set_p_flag(sk, 0); + llc_conn_ac_stop_p_timer(sk, skb); + } + } + return 0; +} + +int llc_conn_ac_set_data_flag_2(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->data_flag = 2; + return 0; +} + +int llc_conn_ac_set_data_flag_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->data_flag = 0; + return 0; +} + +int llc_conn_ac_set_data_flag_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->data_flag = 1; + return 0; +} + +int llc_conn_ac_set_data_flag_1_if_data_flag_eq_0(struct sock *sk, + struct sk_buff *skb) +{ + if (!llc_sk(sk)->data_flag) + llc_sk(sk)->data_flag = 1; + return 0; +} + +int llc_conn_ac_set_p_flag_0(struct sock *sk, struct sk_buff *skb) +{ + llc_conn_set_p_flag(sk, 0); + return 0; +} + +static int llc_conn_ac_set_p_flag_1(struct sock *sk, struct sk_buff *skb) +{ + llc_conn_set_p_flag(sk, 1); + return 0; +} + +int llc_conn_ac_set_remote_busy_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->remote_busy_flag = 0; + return 0; +} + +int llc_conn_ac_set_cause_flag_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->cause_flag = 0; + return 0; +} + +int llc_conn_ac_set_cause_flag_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->cause_flag = 1; + return 0; +} + +int llc_conn_ac_set_retry_cnt_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->retry_count = 0; + return 0; +} + +int llc_conn_ac_inc_retry_cnt_by_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->retry_count++; + return 0; +} + +int llc_conn_ac_set_vr_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->vR = 0; + return 0; +} + +int llc_conn_ac_inc_vr_by_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->vR = PDU_GET_NEXT_Vr(llc_sk(sk)->vR); + return 0; +} + +int llc_conn_ac_set_vs_0(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->vS = 0; + return 0; +} + +int llc_conn_ac_set_vs_nr(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->vS = llc_sk(sk)->last_nr; + return 0; +} + +static int llc_conn_ac_inc_vs_by_1(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->vS = (llc_sk(sk)->vS + 1) % LLC_2_SEQ_NBR_MODULO; + return 0; +} + +static void llc_conn_tmr_common_cb(struct sock *sk, u8 type) +{ + struct sk_buff *skb = alloc_skb(0, GFP_ATOMIC); + + bh_lock_sock(sk); + if (skb) { + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + skb_set_owner_r(skb, sk); + ev->type = type; + llc_process_tmr_ev(sk, skb); + } + bh_unlock_sock(sk); +} + +void llc_conn_pf_cycle_tmr_cb(struct timer_list *t) +{ + struct llc_sock *llc = from_timer(llc, t, pf_cycle_timer.timer); + + llc_conn_tmr_common_cb(&llc->sk, LLC_CONN_EV_TYPE_P_TMR); +} + +void llc_conn_busy_tmr_cb(struct timer_list *t) +{ + struct llc_sock *llc = from_timer(llc, t, busy_state_timer.timer); + + llc_conn_tmr_common_cb(&llc->sk, LLC_CONN_EV_TYPE_BUSY_TMR); +} + +void llc_conn_ack_tmr_cb(struct timer_list *t) +{ + struct llc_sock *llc = from_timer(llc, t, ack_timer.timer); + + llc_conn_tmr_common_cb(&llc->sk, 
LLC_CONN_EV_TYPE_ACK_TMR); +} + +void llc_conn_rej_tmr_cb(struct timer_list *t) +{ + struct llc_sock *llc = from_timer(llc, t, rej_sent_timer.timer); + + llc_conn_tmr_common_cb(&llc->sk, LLC_CONN_EV_TYPE_REJ_TMR); +} + +int llc_conn_ac_rst_vs(struct sock *sk, struct sk_buff *skb) +{ + llc_sk(sk)->X = llc_sk(sk)->vS; + llc_conn_ac_set_vs_nr(sk, skb); + return 0; +} + +int llc_conn_ac_upd_vs(struct sock *sk, struct sk_buff *skb) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + u8 nr = PDU_SUPV_GET_Nr(pdu); + + if (llc_circular_between(llc_sk(sk)->vS, nr, llc_sk(sk)->X)) + llc_conn_ac_set_vs_nr(sk, skb); + return 0; +} + +/* + * Non-standard actions; these not contained in IEEE specification; for + * our own usage + */ +/** + * llc_conn_disc - removes connection from SAP list and frees it + * @sk: closed connection + * @skb: occurred event + */ +int llc_conn_disc(struct sock *sk, struct sk_buff *skb) +{ + /* FIXME: this thing seems to want to die */ + return 0; +} + +/** + * llc_conn_reset - resets connection + * @sk : reseting connection. + * @skb: occurred event. + * + * Stop all timers, empty all queues and reset all flags. + */ +int llc_conn_reset(struct sock *sk, struct sk_buff *skb) +{ + llc_sk_reset(sk); + return 0; +} + +/** + * llc_circular_between - designates that b is between a and c or not + * @a: lower bound + * @b: element to see if is between a and b + * @c: upper bound + * + * This function designates that b is between a and c or not (for example, + * 0 is between 127 and 1). Returns 1 if b is between a and c, 0 + * otherwise. + */ +u8 llc_circular_between(u8 a, u8 b, u8 c) +{ + b = b - a; + c = c - a; + return b <= c; +} + +/** + * llc_process_tmr_ev - timer backend + * @sk: active connection + * @skb: occurred event + * + * This function is called from timer callback functions. When connection + * is busy (during sending a data frame) timer expiration event must be + * queued. Otherwise this event can be sent to connection state machine. + * Queued events will process by llc_backlog_rcv function after sending + * data frame. + */ +static void llc_process_tmr_ev(struct sock *sk, struct sk_buff *skb) +{ + if (llc_sk(sk)->state == LLC_CONN_OUT_OF_SVC) { + printk(KERN_WARNING "%s: timer called on closed connection\n", + __func__); + kfree_skb(skb); + } else { + if (!sock_owned_by_user(sk)) + llc_conn_state_process(sk, skb); + else { + llc_set_backlog_type(skb, LLC_EVENT); + __sk_add_backlog(sk, skb); + } + } +} diff --git a/net/llc/llc_c_ev.c b/net/llc/llc_c_ev.c new file mode 100644 index 000000000..d6627a80c --- /dev/null +++ b/net/llc/llc_c_ev.c @@ -0,0 +1,748 @@ +/* + * llc_c_ev.c - Connection component state transition event qualifiers + * + * A 'state' consists of a number of possible event matching functions, + * the actions associated with each being executed when that event is + * matched; a 'state machine' accepts events in a serial fashion from an + * event queue. Each event is passed to each successive event matching + * function until a match is made (the event matching function returns + * success, or '0') or the list of event matching functions is exhausted. + * If a match is made, the actions associated with the event are executed + * and the state is changed to that event's transition state. Before some + * events are recognized, even after a match has been made, a certain + * number of 'event qualifier' functions must also be executed. If these + * all execute successfully, then the event is finally executed. 
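+ *
+ * Editor's illustrative sketch, not part of the upstream header (the real
+ * dispatcher lives in llc_conn.c and its identifiers may differ): the way
+ * these matchers and qualifiers are consumed can be pictured roughly as
+ *
+ *	for (each transition registered for the current state) {
+ *		if (transition->ev(sk, skb))
+ *			continue;                 -- event did not match
+ *		if (any function in transition->ev_qualifiers is non-zero)
+ *			continue;                 -- a qualifier failed
+ *		run transition->ev_actions, move to transition->next_state;
+ *		break;
+ *	}
+ *
+ * which is why the matchers below return 0 on a match and 1 otherwise.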
+ * + * These event functions must return 0 for success, to show a matched + * event, of 1 if the event does not match. Event qualifier functions + * must return a 0 for success or a non-zero for failure. Each function + * is simply responsible for verifying one single thing and returning + * either a success or failure. + * + * All of followed event functions are described in 802.2 LLC Protocol + * standard document except two functions that we added that will explain + * in their comments, at below. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include + +#if 1 +#define dprintk(args...) printk(KERN_DEBUG args) +#else +#define dprintk(args...) +#endif + +/** + * llc_util_ns_inside_rx_window - check if sequence number is in rx window + * @ns: sequence number of received pdu. + * @vr: sequence number which receiver expects to receive. + * @rw: receive window size of receiver. + * + * Checks if sequence number of received PDU is in range of receive + * window. Returns 0 for success, 1 otherwise + */ +static u16 llc_util_ns_inside_rx_window(u8 ns, u8 vr, u8 rw) +{ + return !llc_circular_between(vr, ns, + (vr + rw - 1) % LLC_2_SEQ_NBR_MODULO); +} + +/** + * llc_util_nr_inside_tx_window - check if sequence number is in tx window + * @sk: current connection. + * @nr: N(R) of received PDU. + * + * This routine checks if N(R) of received PDU is in range of transmit + * window; on the other hand checks if received PDU acknowledges some + * outstanding PDUs that are in transmit window. Returns 0 for success, 1 + * otherwise. + */ +static u16 llc_util_nr_inside_tx_window(struct sock *sk, u8 nr) +{ + u8 nr1, nr2; + struct sk_buff *skb; + struct llc_pdu_sn *pdu; + struct llc_sock *llc = llc_sk(sk); + int rc = 0; + + if (llc->dev->flags & IFF_LOOPBACK) + goto out; + rc = 1; + if (skb_queue_empty(&llc->pdu_unack_q)) + goto out; + skb = skb_peek(&llc->pdu_unack_q); + pdu = llc_pdu_sn_hdr(skb); + nr1 = LLC_I_GET_NS(pdu); + skb = skb_peek_tail(&llc->pdu_unack_q); + pdu = llc_pdu_sn_hdr(skb); + nr2 = LLC_I_GET_NS(pdu); + rc = !llc_circular_between(nr1, nr, (nr2 + 1) % LLC_2_SEQ_NBR_MODULO); +out: + return rc; +} + +int llc_conn_ev_conn_req(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->prim == LLC_CONN_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; +} + +int llc_conn_ev_data_req(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->prim == LLC_DATA_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; +} + +int llc_conn_ev_disc_req(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->prim == LLC_DISC_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; +} + +int llc_conn_ev_rst_req(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->prim == LLC_RESET_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 
0 : 1; +} + +int llc_conn_ev_local_busy_detected(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type == LLC_CONN_EV_TYPE_SIMPLE && + ev->prim_type == LLC_CONN_EV_LOCAL_BUSY_DETECTED ? 0 : 1; +} + +int llc_conn_ev_local_busy_cleared(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type == LLC_CONN_EV_TYPE_SIMPLE && + ev->prim_type == LLC_CONN_EV_LOCAL_BUSY_CLEARED ? 0 : 1; +} + +int llc_conn_ev_rx_bad_pdu(struct sock *sk, struct sk_buff *skb) +{ + return 1; +} + +int llc_conn_ev_rx_disc_cmd_pbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_2_PDU_CMD_DISC ? 0 : 1; +} + +int llc_conn_ev_rx_dm_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_2_PDU_RSP_DM ? 0 : 1; +} + +int llc_conn_ev_rx_frmr_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_2_PDU_RSP_FRMR ? 0 : 1; +} + +int llc_conn_ev_rx_i_cmd_pbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_0(pdu) && + LLC_I_GET_NS(pdu) == llc_sk(sk)->vR ? 0 : 1; +} + +int llc_conn_ev_rx_i_cmd_pbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_1(pdu) && + LLC_I_GET_NS(pdu) == llc_sk(sk)->vR ? 0 : 1; +} + +int llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_0(pdu) && ns != vr && + !llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; +} + +int llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_1(pdu) && ns != vr && + !llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; +} + +int llc_conn_ev_rx_i_cmd_pbit_set_x_inval_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn * pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + const u16 rc = LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + ns != vr && + llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; + if (!rc) + dprintk("%s: matched, state=%d, ns=%d, vr=%d\n", + __func__, llc_sk(sk)->state, ns, vr); + return rc; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_0(pdu) && + LLC_I_GET_NS(pdu) == llc_sk(sk)->vR ? 
0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_1(pdu) && + LLC_I_GET_NS(pdu) == llc_sk(sk)->vR ? 0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_GET_NS(pdu) == llc_sk(sk)->vR ? 0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_0(pdu) && ns != vr && + !llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + LLC_I_PF_IS_1(pdu) && ns != vr && + !llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_x_unexpd_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && ns != vr && + !llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; +} + +int llc_conn_ev_rx_i_rsp_fbit_set_x_inval_ns(struct sock *sk, + struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vr = llc_sk(sk)->vR; + const u8 ns = LLC_I_GET_NS(pdu); + const u16 rc = LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_I(pdu) && + ns != vr && + llc_util_ns_inside_rx_window(ns, vr, llc_sk(sk)->rw) ? 0 : 1; + if (!rc) + dprintk("%s: matched, state=%d, ns=%d, vr=%d\n", + __func__, llc_sk(sk)->state, ns, vr); + return rc; +} + +int llc_conn_ev_rx_rej_cmd_pbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_REJ ? 0 : 1; +} + +int llc_conn_ev_rx_rej_cmd_pbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_REJ ? 0 : 1; +} + +int llc_conn_ev_rx_rej_rsp_fbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_REJ ? 0 : 1; +} + +int llc_conn_ev_rx_rej_rsp_fbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_REJ ? 0 : 1; +} + +int llc_conn_ev_rx_rej_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_REJ ? 
0 : 1; +} + +int llc_conn_ev_rx_rnr_cmd_pbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_RNR ? 0 : 1; +} + +int llc_conn_ev_rx_rnr_cmd_pbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_RNR ? 0 : 1; +} + +int llc_conn_ev_rx_rnr_rsp_fbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_RNR ? 0 : 1; +} + +int llc_conn_ev_rx_rnr_rsp_fbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_RNR ? 0 : 1; +} + +int llc_conn_ev_rx_rr_cmd_pbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_RR ? 0 : 1; +} + +int llc_conn_ev_rx_rr_cmd_pbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_CMD(pdu) == LLC_2_PDU_CMD_RR ? 0 : 1; +} + +int llc_conn_ev_rx_rr_rsp_fbit_set_0(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_0(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_RR ? 0 : 1; +} + +int llc_conn_ev_rx_rr_rsp_fbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + return llc_conn_space(sk, skb) && + LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_S(pdu) && + LLC_S_PF_IS_1(pdu) && + LLC_S_PDU_RSP(pdu) == LLC_2_PDU_RSP_RR ? 0 : 1; +} + +int llc_conn_ev_rx_sabme_cmd_pbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_2_PDU_CMD_SABME ? 0 : 1; +} + +int llc_conn_ev_rx_ua_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_RSP(pdu) && LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_2_PDU_RSP_UA ? 
0 : 1; +} + +int llc_conn_ev_rx_xxx_cmd_pbit_set_1(struct sock *sk, struct sk_buff *skb) +{ + u16 rc = 1; + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + if (LLC_PDU_IS_CMD(pdu)) { + if (LLC_PDU_TYPE_IS_I(pdu) || LLC_PDU_TYPE_IS_S(pdu)) { + if (LLC_I_PF_IS_1(pdu)) + rc = 0; + } else if (LLC_PDU_TYPE_IS_U(pdu) && LLC_U_PF_IS_1(pdu)) + rc = 0; + } + return rc; +} + +int llc_conn_ev_rx_xxx_cmd_pbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + u16 rc = 1; + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + if (LLC_PDU_IS_CMD(pdu)) { + if (LLC_PDU_TYPE_IS_I(pdu) || LLC_PDU_TYPE_IS_S(pdu)) + rc = 0; + else if (LLC_PDU_TYPE_IS_U(pdu)) + switch (LLC_U_PDU_CMD(pdu)) { + case LLC_2_PDU_CMD_SABME: + case LLC_2_PDU_CMD_DISC: + rc = 0; + break; + } + } + return rc; +} + +int llc_conn_ev_rx_xxx_rsp_fbit_set_x(struct sock *sk, struct sk_buff *skb) +{ + u16 rc = 1; + const struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + if (LLC_PDU_IS_RSP(pdu)) { + if (LLC_PDU_TYPE_IS_I(pdu) || LLC_PDU_TYPE_IS_S(pdu)) + rc = 0; + else if (LLC_PDU_TYPE_IS_U(pdu)) + switch (LLC_U_PDU_RSP(pdu)) { + case LLC_2_PDU_RSP_UA: + case LLC_2_PDU_RSP_DM: + case LLC_2_PDU_RSP_FRMR: + rc = 0; + break; + } + } + + return rc; +} + +int llc_conn_ev_rx_zzz_cmd_pbit_set_x_inval_nr(struct sock *sk, + struct sk_buff *skb) +{ + u16 rc = 1; + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vs = llc_sk(sk)->vS; + const u8 nr = LLC_I_GET_NR(pdu); + + if (LLC_PDU_IS_CMD(pdu) && + (LLC_PDU_TYPE_IS_I(pdu) || LLC_PDU_TYPE_IS_S(pdu)) && + nr != vs && llc_util_nr_inside_tx_window(sk, nr)) { + dprintk("%s: matched, state=%d, vs=%d, nr=%d\n", + __func__, llc_sk(sk)->state, vs, nr); + rc = 0; + } + return rc; +} + +int llc_conn_ev_rx_zzz_rsp_fbit_set_x_inval_nr(struct sock *sk, + struct sk_buff *skb) +{ + u16 rc = 1; + const struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + const u8 vs = llc_sk(sk)->vS; + const u8 nr = LLC_I_GET_NR(pdu); + + if (LLC_PDU_IS_RSP(pdu) && + (LLC_PDU_TYPE_IS_I(pdu) || LLC_PDU_TYPE_IS_S(pdu)) && + nr != vs && llc_util_nr_inside_tx_window(sk, nr)) { + rc = 0; + dprintk("%s: matched, state=%d, vs=%d, nr=%d\n", + __func__, llc_sk(sk)->state, vs, nr); + } + return rc; +} + +int llc_conn_ev_rx_any_frame(struct sock *sk, struct sk_buff *skb) +{ + return 0; +} + +int llc_conn_ev_p_tmr_exp(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type != LLC_CONN_EV_TYPE_P_TMR; +} + +int llc_conn_ev_ack_tmr_exp(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type != LLC_CONN_EV_TYPE_ACK_TMR; +} + +int llc_conn_ev_rej_tmr_exp(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type != LLC_CONN_EV_TYPE_REJ_TMR; +} + +int llc_conn_ev_busy_tmr_exp(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type != LLC_CONN_EV_TYPE_BUSY_TMR; +} + +int llc_conn_ev_init_p_f_cycle(struct sock *sk, struct sk_buff *skb) +{ + return 1; +} + +int llc_conn_ev_tx_buffer_full(struct sock *sk, struct sk_buff *skb) +{ + const struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + return ev->type == LLC_CONN_EV_TYPE_SIMPLE && + ev->prim_type == LLC_CONN_EV_TX_BUFF_FULL ? 
0 : 1; +} + +/* Event qualifier functions + * + * These functions simply verify the value of a state flag associated with + * the connection and return either 0 for success or a non-zero value + * for failure; each one verifies a single condition over and above the + * event match itself. + */ +int llc_conn_ev_qlfy_data_flag_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->data_flag != 1; +} + +int llc_conn_ev_qlfy_data_flag_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->data_flag; +} + +int llc_conn_ev_qlfy_data_flag_eq_2(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->data_flag != 2; +} + +int llc_conn_ev_qlfy_p_flag_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->p_flag != 1; +} + +/** + * llc_conn_ev_qlfy_last_frame_eq_1 - checks if frame is last in tx window + * @sk: current connection structure. + * @skb: current event. + * + * This function determines whether the frame being sent is the last frame + * of the transmit window; if it is, the function returns zero, otherwise + * it returns one. It is used so that the last frame of the transmit window + * can be sent as an I-format command with the p-bit set to one. Returns 0 + * if the frame is the last frame, 1 otherwise. + */ +int llc_conn_ev_qlfy_last_frame_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return !(skb_queue_len(&llc_sk(sk)->pdu_unack_q) + 1 == llc_sk(sk)->k); +} + +/** + * llc_conn_ev_qlfy_last_frame_eq_0 - checks if frame isn't last in tx window + * @sk: current connection structure. + * @skb: current event. + * + * This function determines whether the frame being sent is not the last + * frame of the transmit window; if it is not the last one, the function + * returns zero, otherwise it returns one. Returns 0 if the frame isn't the + * last frame, 1 otherwise. + */ +int llc_conn_ev_qlfy_last_frame_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return skb_queue_len(&llc_sk(sk)->pdu_unack_q) + 1 == llc_sk(sk)->k; +} + +int llc_conn_ev_qlfy_p_flag_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->p_flag; +} + +int llc_conn_ev_qlfy_p_flag_eq_f(struct sock *sk, struct sk_buff *skb) +{ + u8 f_bit; + + llc_pdu_decode_pf_bit(skb, &f_bit); + return llc_sk(sk)->p_flag == f_bit ? 
0 : 1; +} + +int llc_conn_ev_qlfy_remote_busy_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->remote_busy_flag; +} + +int llc_conn_ev_qlfy_remote_busy_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return !llc_sk(sk)->remote_busy_flag; +} + +int llc_conn_ev_qlfy_retry_cnt_lt_n2(struct sock *sk, struct sk_buff *skb) +{ + return !(llc_sk(sk)->retry_count < llc_sk(sk)->n2); +} + +int llc_conn_ev_qlfy_retry_cnt_gte_n2(struct sock *sk, struct sk_buff *skb) +{ + return !(llc_sk(sk)->retry_count >= llc_sk(sk)->n2); +} + +int llc_conn_ev_qlfy_s_flag_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return !llc_sk(sk)->s_flag; +} + +int llc_conn_ev_qlfy_s_flag_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->s_flag; +} + +int llc_conn_ev_qlfy_cause_flag_eq_1(struct sock *sk, struct sk_buff *skb) +{ + return !llc_sk(sk)->cause_flag; +} + +int llc_conn_ev_qlfy_cause_flag_eq_0(struct sock *sk, struct sk_buff *skb) +{ + return llc_sk(sk)->cause_flag; +} + +int llc_conn_ev_qlfy_set_status_conn(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_CONN; + return 0; +} + +int llc_conn_ev_qlfy_set_status_disc(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_DISC; + return 0; +} + +int llc_conn_ev_qlfy_set_status_failed(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_FAILED; + return 0; +} + +int llc_conn_ev_qlfy_set_status_remote_busy(struct sock *sk, + struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_REMOTE_BUSY; + return 0; +} + +int llc_conn_ev_qlfy_set_status_refuse(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_REFUSE; + return 0; +} + +int llc_conn_ev_qlfy_set_status_conflict(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_CONFLICT; + return 0; +} + +int llc_conn_ev_qlfy_set_status_rst_done(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->status = LLC_STATUS_RESET_DONE; + return 0; +} diff --git a/net/llc/llc_c_st.c b/net/llc/llc_c_st.c new file mode 100644 index 000000000..2467573b5 --- /dev/null +++ b/net/llc/llc_c_st.c @@ -0,0 +1,4946 @@ +/* + * llc_c_st.c - This module contains state transition of connection component. + * + * Description of event functions and actions there is in 802.2 LLC standard, + * or in "llc_c_ac.c" and "llc_c_ev.c" modules. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. 
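+ *
+ * Editor's orientation note, not part of the upstream comment: each
+ * transition below is a struct llc_conn_state_trans that bundles one
+ * event matcher (.ev), an optional NULL-terminated list of qualifiers
+ * (.ev_qualifiers, or NONE), a NULL-terminated list of actions
+ * (.ev_actions) and the state to move to (.next_state).  For instance,
+ * llc_common_state_trans_1 pairs llc_conn_ev_disc_req with the five
+ * actions of llc_common_actions_1 and targets LLC_CONN_STATE_D_CONN.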
+ */ +#include +#include +#include +#include +#include +#include + +#define NONE NULL + +/* COMMON CONNECTION STATE transitions + * Common transitions for + * LLC_CONN_STATE_NORMAL, + * LLC_CONN_STATE_BUSY, + * LLC_CONN_STATE_REJ, + * LLC_CONN_STATE_AWAIT, + * LLC_CONN_STATE_AWAIT_BUSY and + * LLC_CONN_STATE_AWAIT_REJ states + */ +/* State transitions for LLC_CONN_EV_DISC_REQ event */ +static const llc_conn_action_t llc_common_actions_1[] = { + [0] = llc_conn_ac_send_disc_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_1 = { + .ev = llc_conn_ev_disc_req, + .next_state = LLC_CONN_STATE_D_CONN, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RESET_REQ event */ +static const llc_conn_action_t llc_common_actions_2[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_2 = { + .ev = llc_conn_ev_rst_req, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_common_actions_3[] = { + [0] = llc_conn_ac_stop_all_timers, + [1] = llc_conn_ac_set_vs_0, + [2] = llc_conn_ac_set_vr_0, + [3] = llc_conn_ac_send_ua_rsp_f_set_p, + [4] = llc_conn_ac_rst_ind, + [5] = llc_conn_ac_set_p_flag_0, + [6] = llc_conn_ac_set_remote_busy_0, + [7] = llc_conn_reset, + [8] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_3 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_common_actions_4[] = { + [0] = llc_conn_ac_stop_all_timers, + [1] = llc_conn_ac_send_ua_rsp_f_set_p, + [2] = llc_conn_ac_disc_ind, + [3] = llc_conn_disc, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_4 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_FRMR_RSP_Fbit_SET_X event */ +static const llc_conn_action_t llc_common_actions_5[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_rst_ind, + [5] = llc_conn_ac_set_cause_flag_0, + [6] = llc_conn_reset, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_5 = { + .ev = llc_conn_ev_rx_frmr_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_5, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event */ +static const llc_conn_action_t llc_common_actions_6[] = { + [0] = llc_conn_ac_disc_ind, + [1] = llc_conn_ac_stop_all_timers, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_6 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions 
= llc_common_actions_6, +}; + +/* State transitions for LLC_CONN_EV_RX_ZZZ_CMD_Pbit_SET_X_INVAL_Nr event */ +static const llc_conn_action_t llc_common_actions_7a[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_7a = { + .ev = llc_conn_ev_rx_zzz_cmd_pbit_set_x_inval_nr, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_7a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_X_INVAL_Ns event */ +static const llc_conn_action_t llc_common_actions_7b[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_7b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_x_inval_ns, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_7b, +}; + +/* State transitions for LLC_CONN_EV_RX_ZZZ_RSP_Fbit_SET_X_INVAL_Nr event */ +static const llc_conn_action_t llc_common_actions_8a[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_8a = { + .ev = llc_conn_ev_rx_zzz_rsp_fbit_set_x_inval_nr, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_8a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_X_INVAL_Ns event */ +static const llc_conn_action_t llc_common_actions_8b[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_8b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_x_inval_ns, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_8b, +}; + +/* State transitions for LLC_CONN_EV_RX_BAD_PDU event */ +static const llc_conn_action_t llc_common_actions_8c[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_8c = { + .ev = llc_conn_ev_rx_bad_pdu, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_8c, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event */ +static const llc_conn_action_t llc_common_actions_9[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_9 = { + .ev = llc_conn_ev_rx_ua_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_common_actions_9, +}; + +/* State transitions for LLC_CONN_EV_RX_XXX_RSP_Fbit_SET_1 event */ +#if 0 +static const llc_conn_ev_qfyr_t llc_common_ev_qfyrs_10[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_common_actions_10[] = { + [0] = llc_conn_ac_send_frmr_rsp_f_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = 
llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_10 = { + .ev = llc_conn_ev_rx_xxx_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = llc_common_ev_qfyrs_10, + .ev_actions = llc_common_actions_10, +}; +#endif + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_common_ev_qfyrs_11a[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_common_actions_11a[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_0, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_11a = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_common_ev_qfyrs_11a, + .ev_actions = llc_common_actions_11a, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_common_ev_qfyrs_11b[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_common_actions_11b[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_0, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_11b = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_common_ev_qfyrs_11b, + .ev_actions = llc_common_actions_11b, +}; + +/* State transitions for LLC_CONN_EV_REJ_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_common_ev_qfyrs_11c[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_common_actions_11c[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_0, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_11c = { + .ev = llc_conn_ev_rej_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_common_ev_qfyrs_11c, + .ev_actions = llc_common_actions_11c, +}; + +/* State transitions for LLC_CONN_EV_BUSY_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_common_ev_qfyrs_11d[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_common_actions_11d[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_stop_other_timers, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_0, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_common_state_trans_11d = { + .ev = llc_conn_ev_busy_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_common_ev_qfyrs_11d, + .ev_actions = llc_common_actions_11d, +}; + +/* + * Common dummy state transition; must be last entry for all state + * transition groups - it'll be on .bss, so will be zeroed. 
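+ * (Editor's note, not part of the upstream comment: being zero-filled
+ * means its .ev matcher is NULL, which gives the transition-table walker
+ * a natural end-of-list sentinel.)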
+ */ +static struct llc_conn_state_trans llc_common_state_trans_end; + +/* LLC_CONN_STATE_ADM transitions */ +/* State transitions for LLC_CONN_EV_CONN_REQ event */ +static const llc_conn_action_t llc_adm_actions_1[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_set_retry_cnt_0, + [3] = llc_conn_ac_set_s_flag_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_adm_state_trans_1 = { + .ev = llc_conn_ev_conn_req, + .next_state = LLC_CONN_STATE_SETUP, + .ev_qualifiers = NONE, + .ev_actions = llc_adm_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_adm_actions_2[] = { + [0] = llc_conn_ac_send_ua_rsp_f_set_p, + [1] = llc_conn_ac_set_vs_0, + [2] = llc_conn_ac_set_vr_0, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_p_flag_0, + [5] = llc_conn_ac_set_remote_busy_0, + [6] = llc_conn_ac_conn_ind, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_adm_state_trans_2 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_adm_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_adm_actions_3[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_p, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_adm_state_trans_3 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_adm_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_XXX_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_adm_actions_4[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_1, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_adm_state_trans_4 = { + .ev = llc_conn_ev_rx_xxx_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_adm_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_XXX_YYY event */ +static const llc_conn_action_t llc_adm_actions_5[] = { + [0] = llc_conn_disc, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_adm_state_trans_5 = { + .ev = llc_conn_ev_rx_any_frame, + .next_state = LLC_CONN_OUT_OF_SVC, + .ev_qualifiers = NONE, + .ev_actions = llc_adm_actions_5, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_adm_state_transitions[] = { + [0] = &llc_adm_state_trans_1, /* Request */ + [1] = &llc_common_state_trans_end, + [2] = &llc_common_state_trans_end, /* local_busy */ + [3] = &llc_common_state_trans_end, /* init_pf_cycle */ + [4] = &llc_common_state_trans_end, /* timer */ + [5] = &llc_adm_state_trans_2, /* Receive frame */ + [6] = &llc_adm_state_trans_3, + [7] = &llc_adm_state_trans_4, + [8] = &llc_adm_state_trans_5, + [9] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_SETUP transitions */ +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_setup_actions_1[] = { + [0] = llc_conn_ac_send_ua_rsp_f_set_p, + [1] = llc_conn_ac_set_vs_0, + [2] = llc_conn_ac_set_vr_0, + [3] = llc_conn_ac_set_s_flag_1, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_setup_state_trans_1 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_SETUP, + .ev_qualifiers = NONE, + .ev_actions = llc_setup_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event */ 
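+/*
+ * Editor's note, not part of the upstream file: this UA transition only
+ * fires when llc_conn_ev_qlfy_p_flag_eq_f() sees the received F bit echo
+ * the P bit stored when the SABME was sent, and
+ * llc_conn_ev_qlfy_set_status_conn() tags the event LLC_STATUS_CONN
+ * before llc_conn_ac_conn_confirm() reports the new connection upward.
+ */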
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_2[] = {
+	[0] = llc_conn_ev_qlfy_p_flag_eq_f,
+	[1] = llc_conn_ev_qlfy_set_status_conn,
+	[2] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_2[] = {
+	[0] = llc_conn_ac_stop_ack_timer,
+	[1] = llc_conn_ac_set_vs_0,
+	[2] = llc_conn_ac_set_vr_0,
+	[3] = llc_conn_ac_upd_p_flag,
+	[4] = llc_conn_ac_set_remote_busy_0,
+	[5] = llc_conn_ac_conn_confirm,
+	[6] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_2 = {
+	.ev = llc_conn_ev_rx_ua_rsp_fbit_set_x,
+	.next_state = LLC_CONN_STATE_NORMAL,
+	.ev_qualifiers = llc_setup_ev_qfyrs_2,
+	.ev_actions = llc_setup_actions_2,
+};
+
+/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_3[] = {
+	[0] = llc_conn_ev_qlfy_s_flag_eq_1,
+	[1] = llc_conn_ev_qlfy_set_status_conn,
+	[2] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_3[] = {
+	[0] = llc_conn_ac_set_p_flag_0,
+	[1] = llc_conn_ac_set_remote_busy_0,
+	[2] = llc_conn_ac_conn_confirm,
+	[3] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_3 = {
+	.ev = llc_conn_ev_ack_tmr_exp,
+	.next_state = LLC_CONN_STATE_NORMAL,
+	.ev_qualifiers = llc_setup_ev_qfyrs_3,
+	.ev_actions = llc_setup_actions_3,
+};
+
+/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event */
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_4[] = {
+	[0] = llc_conn_ev_qlfy_set_status_disc,
+	[1] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_4[] = {
+	[0] = llc_conn_ac_send_dm_rsp_f_set_p,
+	[1] = llc_conn_ac_stop_ack_timer,
+	[2] = llc_conn_ac_conn_confirm,
+	[3] = llc_conn_disc,
+	[4] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_4 = {
+	.ev = llc_conn_ev_rx_disc_cmd_pbit_set_x,
+	.next_state = LLC_CONN_STATE_ADM,
+	.ev_qualifiers = llc_setup_ev_qfyrs_4,
+	.ev_actions = llc_setup_actions_4,
+};
+
+/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event */
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_5[] = {
+	[0] = llc_conn_ev_qlfy_set_status_disc,
+	[1] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_5[] = {
+	[0] = llc_conn_ac_stop_ack_timer,
+	[1] = llc_conn_ac_conn_confirm,
+	[2] = llc_conn_disc,
+	[3] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_5 = {
+	.ev = llc_conn_ev_rx_dm_rsp_fbit_set_x,
+	.next_state = LLC_CONN_STATE_ADM,
+	.ev_qualifiers = llc_setup_ev_qfyrs_5,
+	.ev_actions = llc_setup_actions_5,
+};
+
+/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_7[] = {
+	[0] = llc_conn_ev_qlfy_retry_cnt_lt_n2,
+	[1] = llc_conn_ev_qlfy_s_flag_eq_0,
+	[2] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_7[] = {
+	[0] = llc_conn_ac_send_sabme_cmd_p_set_x,
+	[1] = llc_conn_ac_start_ack_timer,
+	[2] = llc_conn_ac_inc_retry_cnt_by_1,
+	[3] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_7 = {
+	.ev = llc_conn_ev_ack_tmr_exp,
+	.next_state = LLC_CONN_STATE_SETUP,
+	.ev_qualifiers = llc_setup_ev_qfyrs_7,
+	.ev_actions = llc_setup_actions_7,
+};
+
+/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */
+static const llc_conn_ev_qfyr_t llc_setup_ev_qfyrs_8[] = {
+	[0] = llc_conn_ev_qlfy_retry_cnt_gte_n2,
+	[1] = llc_conn_ev_qlfy_s_flag_eq_0,
+	[2] = llc_conn_ev_qlfy_set_status_failed,
+	[3] = NULL,
+};
+
+static const llc_conn_action_t llc_setup_actions_8[] = {
+	[0] = llc_conn_ac_conn_confirm,
+	[1] = llc_conn_disc,
+	[2] = NULL,
+};
+
+static struct llc_conn_state_trans llc_setup_state_trans_8 = {
+	.ev = llc_conn_ev_ack_tmr_exp,
+	.next_state = LLC_CONN_STATE_ADM,
+	.ev_qualifiers = llc_setup_ev_qfyrs_8,
+	.ev_actions = llc_setup_actions_8,
+};
+
+/*
+ * Array of pointers;
+ * one to each transition
+ */
+static struct llc_conn_state_trans *llc_setup_state_transitions[] = {
+	[0] = &llc_common_state_trans_end, /* Request */
+	[1] = &llc_common_state_trans_end, /* local busy */
+	[2] = &llc_common_state_trans_end, /* init_pf_cycle */
+	[3] = &llc_setup_state_trans_3, /* Timer */
+	[4] = &llc_setup_state_trans_7,
+	[5] = &llc_setup_state_trans_8,
+	[6] = &llc_common_state_trans_end,
+	[7] = &llc_setup_state_trans_1, /* Receive frame */
+	[8] = &llc_setup_state_trans_2,
+	[9] = &llc_setup_state_trans_4,
+	[10] = &llc_setup_state_trans_5,
+	[11] = &llc_common_state_trans_end,
+};
+
+/* LLC_CONN_STATE_NORMAL transitions */
+/* State transitions for LLC_CONN_EV_DATA_REQ event */
+static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_1[] = {
+	[0] = llc_conn_ev_qlfy_remote_busy_eq_0,
+	[1] = llc_conn_ev_qlfy_p_flag_eq_0,
+	[2] = llc_conn_ev_qlfy_last_frame_eq_0,
+	[3] = NULL,
+};
+
+static const llc_conn_action_t llc_normal_actions_1[] = {
+	[0] = llc_conn_ac_send_i_as_ack,
+	[1] = llc_conn_ac_start_ack_tmr_if_not_running,
+	[2] = NULL,
+};
+
+static struct llc_conn_state_trans llc_normal_state_trans_1 = {
+	.ev = llc_conn_ev_data_req,
+	.next_state = LLC_CONN_STATE_NORMAL,
+	.ev_qualifiers = llc_normal_ev_qfyrs_1,
+	.ev_actions = llc_normal_actions_1,
+};
+
+/* State transitions for LLC_CONN_EV_DATA_REQ event */
+static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_2[] = {
+	[0] = llc_conn_ev_qlfy_remote_busy_eq_0,
+	[1] = llc_conn_ev_qlfy_p_flag_eq_0,
+	[2] = llc_conn_ev_qlfy_last_frame_eq_1,
+	[3] = NULL,
+};
+
+static const llc_conn_action_t llc_normal_actions_2[] = {
+	[0] = llc_conn_ac_send_i_cmd_p_set_1,
+	[1] = llc_conn_ac_start_p_timer,
+	[2] = NULL,
+};
+
+static struct llc_conn_state_trans llc_normal_state_trans_2 = {
+	.ev = llc_conn_ev_data_req,
+	.next_state = LLC_CONN_STATE_NORMAL,
+	.ev_qualifiers = llc_normal_ev_qfyrs_2,
+	.ev_actions = llc_normal_actions_2,
+};
+
+/* State transitions for LLC_CONN_EV_DATA_REQ event */
+static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_2_1[] = {
+	[0] = llc_conn_ev_qlfy_remote_busy_eq_1,
+	[1] = llc_conn_ev_qlfy_set_status_remote_busy,
+	[2] = NULL,
+};
+
+/* just one member, NULL, .bss zeroes it */
+static const llc_conn_action_t llc_normal_actions_2_1[1];
+
+static struct llc_conn_state_trans llc_normal_state_trans_2_1 = {
+	.ev = llc_conn_ev_data_req,
+	.next_state = LLC_CONN_STATE_NORMAL,
+	.ev_qualifiers = llc_normal_ev_qfyrs_2_1,
+	.ev_actions = llc_normal_actions_2_1,
+};
+
+/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */
+static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_3[] = {
+	[0] = llc_conn_ev_qlfy_p_flag_eq_0,
+	[1] = NULL,
+};
+
+static const llc_conn_action_t llc_normal_actions_3[] = {
+	[0] = llc_conn_ac_rst_sendack_flag,
+	[1] = llc_conn_ac_send_rnr_xxx_x_set_0,
+	[2] = llc_conn_ac_set_data_flag_0,
+	[3] = NULL,
+};
+
+static struct llc_conn_state_trans llc_normal_state_trans_3 = {
+	.ev = llc_conn_ev_local_busy_detected,
+	.next_state = LLC_CONN_STATE_BUSY,
+	.ev_qualifiers = llc_normal_ev_qfyrs_3,
+	.ev_actions = llc_normal_actions_3,
+};
+
+/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */
+static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_4[] = {
+	[0] = llc_conn_ev_qlfy_p_flag_eq_1,
+	[1] = NULL,
+};
+
+static const
llc_conn_action_t llc_normal_actions_4[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rnr_xxx_x_set_0, + [2] = llc_conn_ac_set_data_flag_0, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_4 = { + .ev = llc_conn_ev_local_busy_detected, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_normal_ev_qfyrs_4, + .ev_actions = llc_normal_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_5a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_5a[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_xxx_x_set_0, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_upd_p_flag, + [4] = llc_conn_ac_start_rej_timer, + [5] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_5a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_normal_ev_qfyrs_5a, + .ev_actions = llc_normal_actions_5a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_5b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_5b[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_xxx_x_set_0, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_upd_p_flag, + [4] = llc_conn_ac_start_rej_timer, + [5] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_5b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_normal_ev_qfyrs_5b, + .ev_actions = llc_normal_actions_5b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_5c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_5c[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_xxx_x_set_0, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_upd_p_flag, + [4] = llc_conn_ac_start_rej_timer, + [5] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_5c = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_normal_ev_qfyrs_5c, + .ev_actions = llc_normal_actions_5c, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_6a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_6a[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_xxx_x_set_0, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_start_rej_timer, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_6a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_normal_ev_qfyrs_6a, + .ev_actions = llc_normal_actions_6a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_6b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const 
llc_conn_action_t llc_normal_actions_6b[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_xxx_x_set_0, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_start_rej_timer, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_6b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_normal_ev_qfyrs_6b, + .ev_actions = llc_normal_actions_6b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_normal_actions_7[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rej_rsp_f_set_1, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_start_rej_timer, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_7 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_7, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_8a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_8[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [5] = llc_conn_ac_send_ack_if_needed, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_8a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_8a, + .ev_actions = llc_normal_actions_8, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_8b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_8b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_8b, + .ev_actions = llc_normal_actions_8, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_9a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_9a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_send_ack_if_needed, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_9a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_9a, + .ev_actions = llc_normal_actions_9a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_9b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_9b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_send_ack_if_needed, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_9b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_9b, + .ev_actions = llc_normal_actions_9b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_normal_actions_10[] = { + [0] = 
llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_send_ack_rsp_f_set_1, + [2] = llc_conn_ac_rst_sendack_flag, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_data_ind, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_10 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_10, +}; + +/* State transitions for * LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_normal_actions_11a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_11a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_11a, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_normal_actions_11b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_11b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_11b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_11c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_11c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_inc_tx_win_size, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_11c = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_11c, + .ev_actions = llc_normal_actions_11c, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_normal_actions_12[] = { + [0] = llc_conn_ac_send_ack_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_adjust_npta_by_rr, + [3] = llc_conn_ac_rst_sendack_flag, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_12 = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_12, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_normal_actions_13a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_13a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_13a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_normal_actions_13b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_13b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_13b, +}; + +/* State transitions for 
LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_13c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_13c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_13c = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_13c, + .ev_actions = llc_normal_actions_13c, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_normal_actions_14[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_adjust_npta_by_rnr, + [3] = llc_conn_ac_rst_sendack_flag, + [4] = llc_conn_ac_set_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_14 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_14, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_15a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_15a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_dec_tx_win_size, + [4] = llc_conn_ac_resend_i_xxx_x_set_0, + [5] = llc_conn_ac_clear_remote_busy, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_15a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_15a, + .ev_actions = llc_normal_actions_15a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_15b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_15b[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_dec_tx_win_size, + [4] = llc_conn_ac_resend_i_xxx_x_set_0, + [5] = llc_conn_ac_clear_remote_busy, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_15b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_15b, + .ev_actions = llc_normal_actions_15b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_16a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_16a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_dec_tx_win_size, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_16a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_16a, + .ev_actions = llc_normal_actions_16a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_16b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_16b[] = { + 
[0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_dec_tx_win_size, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_16b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_16b, + .ev_actions = llc_normal_actions_16b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_normal_actions_17[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_dec_tx_win_size, + [3] = llc_conn_ac_resend_i_rsp_f_set_1, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_17 = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_normal_actions_17, +}; + +/* State transitions for LLC_CONN_EV_INIT_P_F_CYCLE event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_18[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_18[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_18 = { + .ev = llc_conn_ev_init_p_f_cycle, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_18, + .ev_actions = llc_normal_actions_18, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_19[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_19[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rr_cmd_p_set_1, + [2] = llc_conn_ac_rst_vs, + [3] = llc_conn_ac_start_p_timer, + [4] = llc_conn_ac_inc_retry_cnt_by_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_19 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_normal_ev_qfyrs_19, + .ev_actions = llc_normal_actions_19, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_20a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_20a[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rr_cmd_p_set_1, + [2] = llc_conn_ac_rst_vs, + [3] = llc_conn_ac_start_p_timer, + [4] = llc_conn_ac_inc_retry_cnt_by_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_20a = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_normal_ev_qfyrs_20a, + .ev_actions = llc_normal_actions_20a, +}; + +/* State transitions for LLC_CONN_EV_BUSY_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_20b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_20b[] = { + [0] = llc_conn_ac_rst_sendack_flag, + [1] = llc_conn_ac_send_rr_cmd_p_set_1, + [2] = llc_conn_ac_rst_vs, + [3] = llc_conn_ac_start_p_timer, + [4] = llc_conn_ac_inc_retry_cnt_by_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_20b = { + .ev = llc_conn_ev_busy_tmr_exp, + .next_state = 
LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_normal_ev_qfyrs_20b, + .ev_actions = llc_normal_actions_20b, +}; + +/* State transitions for LLC_CONN_EV_TX_BUFF_FULL event */ +static const llc_conn_ev_qfyr_t llc_normal_ev_qfyrs_21[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_normal_actions_21[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_normal_state_trans_21 = { + .ev = llc_conn_ev_tx_buffer_full, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_normal_ev_qfyrs_21, + .ev_actions = llc_normal_actions_21, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_normal_state_transitions[] = { + [0] = &llc_normal_state_trans_1, /* Requests */ + [1] = &llc_normal_state_trans_2, + [2] = &llc_normal_state_trans_2_1, + [3] = &llc_common_state_trans_1, + [4] = &llc_common_state_trans_2, + [5] = &llc_common_state_trans_end, + [6] = &llc_normal_state_trans_21, + [7] = &llc_normal_state_trans_3, /* Local busy */ + [8] = &llc_normal_state_trans_4, + [9] = &llc_common_state_trans_end, + [10] = &llc_normal_state_trans_18, /* Init pf cycle */ + [11] = &llc_common_state_trans_end, + [12] = &llc_common_state_trans_11a, /* Timers */ + [13] = &llc_common_state_trans_11b, + [14] = &llc_common_state_trans_11c, + [15] = &llc_common_state_trans_11d, + [16] = &llc_normal_state_trans_19, + [17] = &llc_normal_state_trans_20a, + [18] = &llc_normal_state_trans_20b, + [19] = &llc_common_state_trans_end, + [20] = &llc_normal_state_trans_8b, /* Receive frames */ + [21] = &llc_normal_state_trans_9b, + [22] = &llc_normal_state_trans_10, + [23] = &llc_normal_state_trans_11b, + [24] = &llc_normal_state_trans_11c, + [25] = &llc_normal_state_trans_5a, + [26] = &llc_normal_state_trans_5b, + [27] = &llc_normal_state_trans_5c, + [28] = &llc_normal_state_trans_6a, + [29] = &llc_normal_state_trans_6b, + [30] = &llc_normal_state_trans_7, + [31] = &llc_normal_state_trans_8a, + [32] = &llc_normal_state_trans_9a, + [33] = &llc_normal_state_trans_11a, + [34] = &llc_normal_state_trans_12, + [35] = &llc_normal_state_trans_13a, + [36] = &llc_normal_state_trans_13b, + [37] = &llc_normal_state_trans_13c, + [38] = &llc_normal_state_trans_14, + [39] = &llc_normal_state_trans_15a, + [40] = &llc_normal_state_trans_15b, + [41] = &llc_normal_state_trans_16a, + [42] = &llc_normal_state_trans_16b, + [43] = &llc_normal_state_trans_17, + [44] = &llc_common_state_trans_3, + [45] = &llc_common_state_trans_4, + [46] = &llc_common_state_trans_5, + [47] = &llc_common_state_trans_6, + [48] = &llc_common_state_trans_7a, + [49] = &llc_common_state_trans_7b, + [50] = &llc_common_state_trans_8a, + [51] = &llc_common_state_trans_8b, + [52] = &llc_common_state_trans_8c, + [53] = &llc_common_state_trans_9, + /* [54] = &llc_common_state_trans_10, */ + [54] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_BUSY transitions */ +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_1[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_1[] = { + [0] = llc_conn_ac_send_i_xxx_x_set_0, + [1] = llc_conn_ac_start_ack_tmr_if_not_running, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_1 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = 
llc_busy_ev_qfyrs_1, + .ev_actions = llc_busy_actions_1, +}; + +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_2[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_1, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_2[] = { + [0] = llc_conn_ac_send_i_xxx_x_set_0, + [1] = llc_conn_ac_start_ack_tmr_if_not_running, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_2 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_2, + .ev_actions = llc_busy_actions_2, +}; + +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_2_1[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_1, + [1] = llc_conn_ev_qlfy_set_status_remote_busy, + [2] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_busy_actions_2_1[1]; + +static struct llc_conn_state_trans llc_busy_state_trans_2_1 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_2_1, + .ev_actions = llc_busy_actions_2_1, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_3[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_1, + [1] = llc_conn_ev_qlfy_p_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_3[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_start_rej_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_3 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_busy_ev_qfyrs_3, + .ev_actions = llc_busy_actions_3, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_4[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_1, + [1] = llc_conn_ev_qlfy_p_flag_eq_1, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_4[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_start_rej_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_4 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_busy_ev_qfyrs_4, + .ev_actions = llc_busy_actions_4, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_5[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_5[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_5 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_busy_ev_qfyrs_5, + .ev_actions = llc_busy_actions_5, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_6[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_1, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_6[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_6 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_busy_ev_qfyrs_6, + .ev_actions = llc_busy_actions_6, +}; + +/* State 
transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_7[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_2, + [1] = llc_conn_ev_qlfy_p_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_7[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_7 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_busy_ev_qfyrs_7, + .ev_actions = llc_busy_actions_7, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_8[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_2, + [1] = llc_conn_ev_qlfy_p_flag_eq_1, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_8[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_8 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_busy_ev_qfyrs_8, + .ev_actions = llc_busy_actions_8, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_X_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_9a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_9a[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_p_flag, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_set_data_flag_1_if_data_flag_eq_0, + [4] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_9a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_x_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_9a, + .ev_actions = llc_busy_actions_9a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_9b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_9b[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_p_flag, + [2] = llc_conn_ac_upd_nr_received, + [3] = llc_conn_ac_set_data_flag_1_if_data_flag_eq_0, + [4] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_9b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_9b, + .ev_actions = llc_busy_actions_9b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_10a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_10a[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_data_flag_1_if_data_flag_eq_0, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_10a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_10a, + .ev_actions = llc_busy_actions_10a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_10b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_10b[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = 
llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_data_flag_1_if_data_flag_eq_0, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_10b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_10b, + .ev_actions = llc_busy_actions_10b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_busy_actions_11[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_data_flag_1_if_data_flag_eq_0, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_11 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_11, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_busy_actions_12[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rnr_rsp_f_set_1, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_12 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_12, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_13a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_13a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2, + [6] = llc_conn_ac_set_data_flag_0, + [7] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [8] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_13a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_13a, + .ev_actions = llc_busy_actions_13a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_13b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_13b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2, + [6] = llc_conn_ac_set_data_flag_0, + [7] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [8] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_13b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_13b, + .ev_actions = llc_busy_actions_13b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_14a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_14a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, 
+}; + +static struct llc_conn_state_trans llc_busy_state_trans_14a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_14a, + .ev_actions = llc_busy_actions_14a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_14b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_14b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_tmr_if_data_flag_eq_2, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_14b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_14b, + .ev_actions = llc_busy_actions_14b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_busy_actions_15a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_15a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_15a, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_busy_actions_15b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_15b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_15b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_15c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_15c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_15c = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_15c, + .ev_actions = llc_busy_actions_15c, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_busy_actions_16[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_16 = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_16, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_busy_actions_17a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_17a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_17a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t 
llc_busy_actions_17b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_17b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_17b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_17c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_17c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_17c = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_17c, + .ev_actions = llc_busy_actions_17c, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_busy_actions_18[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_18 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_18, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_19a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_19a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_19a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_19a, + .ev_actions = llc_busy_actions_19a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_19b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_19b[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_19b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_19b, + .ev_actions = llc_busy_actions_19b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_20a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_20a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_resend_i_xxx_x_set_0, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_20a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_20a, + .ev_actions = llc_busy_actions_20a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t 
llc_busy_ev_qfyrs_20b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_20b[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_resend_i_xxx_x_set_0, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_20b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_20b, + .ev_actions = llc_busy_actions_20b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_busy_actions_21[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_send_rnr_rsp_f_set_1, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_21 = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_busy_actions_21, +}; + +/* State transitions for LLC_CONN_EV_INIT_P_F_CYCLE event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_22[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_22[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_22 = { + .ev = llc_conn_ev_init_p_f_cycle, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_22, + .ev_actions = llc_busy_actions_22, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_23[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_23[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_rst_vs, + [2] = llc_conn_ac_start_p_timer, + [3] = llc_conn_ac_inc_retry_cnt_by_1, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_23 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_23, + .ev_actions = llc_busy_actions_23, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_24a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_24a[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = llc_conn_ac_rst_vs, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_24a = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_24a, + .ev_actions = llc_busy_actions_24a, +}; + +/* State transitions for LLC_CONN_EV_BUSY_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_24b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_24b[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = llc_conn_ac_rst_vs, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_24b = { + .ev = llc_conn_ev_busy_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = 
llc_busy_ev_qfyrs_24b, + .ev_actions = llc_busy_actions_24b, +}; + +/* State transitions for LLC_CONN_EV_REJ_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_25[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_25[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = llc_conn_ac_rst_vs, + [4] = llc_conn_ac_set_data_flag_1, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_25 = { + .ev = llc_conn_ev_rej_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_25, + .ev_actions = llc_busy_actions_25, +}; + +/* State transitions for LLC_CONN_EV_REJ_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_busy_ev_qfyrs_26[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_busy_actions_26[] = { + [0] = llc_conn_ac_set_data_flag_1, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_busy_state_trans_26 = { + .ev = llc_conn_ev_rej_tmr_exp, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_busy_ev_qfyrs_26, + .ev_actions = llc_busy_actions_26, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_busy_state_transitions[] = { + [0] = &llc_common_state_trans_1, /* Request */ + [1] = &llc_common_state_trans_2, + [2] = &llc_busy_state_trans_1, + [3] = &llc_busy_state_trans_2, + [4] = &llc_busy_state_trans_2_1, + [5] = &llc_common_state_trans_end, + [6] = &llc_busy_state_trans_3, /* Local busy */ + [7] = &llc_busy_state_trans_4, + [8] = &llc_busy_state_trans_5, + [9] = &llc_busy_state_trans_6, + [10] = &llc_busy_state_trans_7, + [11] = &llc_busy_state_trans_8, + [12] = &llc_common_state_trans_end, + [13] = &llc_busy_state_trans_22, /* Initiate PF cycle */ + [14] = &llc_common_state_trans_end, + [15] = &llc_common_state_trans_11a, /* Timer */ + [16] = &llc_common_state_trans_11b, + [17] = &llc_common_state_trans_11c, + [18] = &llc_common_state_trans_11d, + [19] = &llc_busy_state_trans_23, + [20] = &llc_busy_state_trans_24a, + [21] = &llc_busy_state_trans_24b, + [22] = &llc_busy_state_trans_25, + [23] = &llc_busy_state_trans_26, + [24] = &llc_common_state_trans_end, + [25] = &llc_busy_state_trans_9a, /* Receive frame */ + [26] = &llc_busy_state_trans_9b, + [27] = &llc_busy_state_trans_10a, + [28] = &llc_busy_state_trans_10b, + [29] = &llc_busy_state_trans_11, + [30] = &llc_busy_state_trans_12, + [31] = &llc_busy_state_trans_13a, + [32] = &llc_busy_state_trans_13b, + [33] = &llc_busy_state_trans_14a, + [34] = &llc_busy_state_trans_14b, + [35] = &llc_busy_state_trans_15a, + [36] = &llc_busy_state_trans_15b, + [37] = &llc_busy_state_trans_15c, + [38] = &llc_busy_state_trans_16, + [39] = &llc_busy_state_trans_17a, + [40] = &llc_busy_state_trans_17b, + [41] = &llc_busy_state_trans_17c, + [42] = &llc_busy_state_trans_18, + [43] = &llc_busy_state_trans_19a, + [44] = &llc_busy_state_trans_19b, + [45] = &llc_busy_state_trans_20a, + [46] = &llc_busy_state_trans_20b, + [47] = &llc_busy_state_trans_21, + [48] = &llc_common_state_trans_3, + [49] = &llc_common_state_trans_4, + [50] = &llc_common_state_trans_5, + [51] = &llc_common_state_trans_6, + [52] = &llc_common_state_trans_7a, + [53] = &llc_common_state_trans_7b, + [54] = &llc_common_state_trans_8a, + [55] = &llc_common_state_trans_8b, + [56] = 
&llc_common_state_trans_8c, + [57] = &llc_common_state_trans_9, + /* [58] = &llc_common_state_trans_10, */ + [58] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_REJ transitions */ +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_1[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_1[] = { + [0] = llc_conn_ac_send_i_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_1 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_1, + .ev_actions = llc_reject_actions_1, +}; + +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_2[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_0, + [1] = llc_conn_ev_qlfy_p_flag_eq_1, + [2] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_2[] = { + [0] = llc_conn_ac_send_i_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_2 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_2, + .ev_actions = llc_reject_actions_2, +}; + +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_2_1[] = { + [0] = llc_conn_ev_qlfy_remote_busy_eq_1, + [1] = llc_conn_ev_qlfy_set_status_remote_busy, + [2] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_reject_actions_2_1[1]; + +static struct llc_conn_state_trans llc_reject_state_trans_2_1 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_2_1, + .ev_actions = llc_reject_actions_2_1, +}; + + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_3[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_3[] = { + [0] = llc_conn_ac_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_set_data_flag_2, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_3 = { + .ev = llc_conn_ev_local_busy_detected, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_reject_ev_qfyrs_3, + .ev_actions = llc_reject_actions_3, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_4[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_4[] = { + [0] = llc_conn_ac_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_set_data_flag_2, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_4 = { + .ev = llc_conn_ev_local_busy_detected, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = llc_reject_ev_qfyrs_4, + .ev_actions = llc_reject_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_reject_actions_5a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_p_flag, + [2] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_5a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_5a, +}; + +/* State transitions for 
LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_reject_actions_5b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_p_flag, + [2] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_5b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_5b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_5c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_5c[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_p_flag, + [2] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_5c = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_5c, + .ev_actions = llc_reject_actions_5c, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_reject_actions_6[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_6 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_6, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_7a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_7a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_send_ack_xxx_x_set_0, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [6] = llc_conn_ac_stop_rej_timer, + [7] = NULL, + +}; + +static struct llc_conn_state_trans llc_reject_state_trans_7a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_reject_ev_qfyrs_7a, + .ev_actions = llc_reject_actions_7a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_7b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_7b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_send_ack_xxx_x_set_0, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_clear_remote_busy_if_f_eq_1, + [6] = llc_conn_ac_stop_rej_timer, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_7b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_reject_ev_qfyrs_7b, + .ev_actions = llc_reject_actions_7b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_8a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_8a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_ack_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_timer, + [5] = NULL, +}; + +static struct 
llc_conn_state_trans llc_reject_state_trans_8a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_reject_ev_qfyrs_8a, + .ev_actions = llc_reject_actions_8a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_8b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_8b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_ack_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_timer, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_8b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_reject_ev_qfyrs_8b, + .ev_actions = llc_reject_actions_8b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_reject_actions_9[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_ack_rsp_f_set_1, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_stop_rej_timer, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_9 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_9, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_reject_actions_10a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_10a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_10a, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_reject_actions_10b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_10b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_10b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_10c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_10c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_10c = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_10c, + .ev_actions = llc_reject_actions_10c, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_reject_actions_11[] = { + [0] = llc_conn_ac_send_ack_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_11 = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_11, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static 
const llc_conn_action_t llc_reject_actions_12a[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_12a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_12a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_reject_actions_12b[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_12b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_12b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_12c[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_12c[] = { + [0] = llc_conn_ac_upd_p_flag, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_12c = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_12c, + .ev_actions = llc_reject_actions_12c, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_reject_actions_13[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_13 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_13, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_14a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_14a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_14a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_14a, + .ev_actions = llc_reject_actions_14a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_X event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_14b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_14b[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_p_flag, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_14b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_14b, + .ev_actions = llc_reject_actions_14b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_15a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const 
llc_conn_action_t llc_reject_actions_15a[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_resend_i_xxx_x_set_0, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_15a = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_15a, + .ev_actions = llc_reject_actions_15a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_15b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_15b[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_resend_i_xxx_x_set_0, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_15b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_15b, + .ev_actions = llc_reject_actions_15b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_reject_actions_16[] = { + [0] = llc_conn_ac_set_vs_nr, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_resend_i_rsp_f_set_1, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_16 = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_reject_actions_16, +}; + +/* State transitions for LLC_CONN_EV_INIT_P_F_CYCLE event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_17[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_17[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_17 = { + .ev = llc_conn_ev_init_p_f_cycle, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_17, + .ev_actions = llc_reject_actions_17, +}; + +/* State transitions for LLC_CONN_EV_REJ_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_18[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_18[] = { + [0] = llc_conn_ac_send_rej_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_start_rej_timer, + [3] = llc_conn_ac_inc_retry_cnt_by_1, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_18 = { + .ev = llc_conn_ev_rej_tmr_exp, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_18, + .ev_actions = llc_reject_actions_18, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_19[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_19[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_start_rej_timer, + [3] = llc_conn_ac_inc_retry_cnt_by_1, + [4] = llc_conn_ac_rst_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_19 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_19, + .ev_actions = llc_reject_actions_19, +}; + +/* State transitions 
for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_20a[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_20a[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_start_rej_timer, + [3] = llc_conn_ac_inc_retry_cnt_by_1, + [4] = llc_conn_ac_rst_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_20a = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_20a, + .ev_actions = llc_reject_actions_20a, +}; + +/* State transitions for LLC_CONN_EV_BUSY_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_reject_ev_qfyrs_20b[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_0, + [1] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [2] = NULL, +}; + +static const llc_conn_action_t llc_reject_actions_20b[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_start_rej_timer, + [3] = llc_conn_ac_inc_retry_cnt_by_1, + [4] = llc_conn_ac_rst_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_reject_state_trans_20b = { + .ev = llc_conn_ev_busy_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_reject_ev_qfyrs_20b, + .ev_actions = llc_reject_actions_20b, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_reject_state_transitions[] = { + [0] = &llc_common_state_trans_1, /* Request */ + [1] = &llc_common_state_trans_2, + [2] = &llc_common_state_trans_end, + [3] = &llc_reject_state_trans_1, + [4] = &llc_reject_state_trans_2, + [5] = &llc_reject_state_trans_2_1, + [6] = &llc_reject_state_trans_3, /* Local busy */ + [7] = &llc_reject_state_trans_4, + [8] = &llc_common_state_trans_end, + [9] = &llc_reject_state_trans_17, /* Initiate PF cycle */ + [10] = &llc_common_state_trans_end, + [11] = &llc_common_state_trans_11a, /* Timer */ + [12] = &llc_common_state_trans_11b, + [13] = &llc_common_state_trans_11c, + [14] = &llc_common_state_trans_11d, + [15] = &llc_reject_state_trans_18, + [16] = &llc_reject_state_trans_19, + [17] = &llc_reject_state_trans_20a, + [18] = &llc_reject_state_trans_20b, + [19] = &llc_common_state_trans_end, + [20] = &llc_common_state_trans_3, /* Receive frame */ + [21] = &llc_common_state_trans_4, + [22] = &llc_common_state_trans_5, + [23] = &llc_common_state_trans_6, + [24] = &llc_common_state_trans_7a, + [25] = &llc_common_state_trans_7b, + [26] = &llc_common_state_trans_8a, + [27] = &llc_common_state_trans_8b, + [28] = &llc_common_state_trans_8c, + [29] = &llc_common_state_trans_9, + /* [30] = &llc_common_state_trans_10, */ + [30] = &llc_reject_state_trans_5a, + [31] = &llc_reject_state_trans_5b, + [32] = &llc_reject_state_trans_5c, + [33] = &llc_reject_state_trans_6, + [34] = &llc_reject_state_trans_7a, + [35] = &llc_reject_state_trans_7b, + [36] = &llc_reject_state_trans_8a, + [37] = &llc_reject_state_trans_8b, + [38] = &llc_reject_state_trans_9, + [39] = &llc_reject_state_trans_10a, + [40] = &llc_reject_state_trans_10b, + [41] = &llc_reject_state_trans_10c, + [42] = &llc_reject_state_trans_11, + [43] = &llc_reject_state_trans_12a, + [44] = &llc_reject_state_trans_12b, + [45] = &llc_reject_state_trans_12c, + [46] = &llc_reject_state_trans_13, + [47] = &llc_reject_state_trans_14a, + [48] = &llc_reject_state_trans_14b, + [49] = &llc_reject_state_trans_15a, + [50] = 
&llc_reject_state_trans_15b, + [51] = &llc_reject_state_trans_16, + [52] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_AWAIT transitions */ +/* State transitions for LLC_CONN_EV_DATA_REQ event */ +static const llc_conn_ev_qfyr_t llc_await_ev_qfyrs_1_0[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_await_actions_1_0[1]; + +static struct llc_conn_state_trans llc_await_state_trans_1_0 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_await_ev_qfyrs_1_0, + .ev_actions = llc_await_actions_1_0, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */ +static const llc_conn_action_t llc_await_actions_1[] = { + [0] = llc_conn_ac_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_set_data_flag_0, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_1 = { + .ev = llc_conn_ev_local_busy_detected, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_actions_2[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_stop_p_timer, + [4] = llc_conn_ac_resend_i_xxx_x_set_0, + [5] = llc_conn_ac_start_rej_timer, + [6] = llc_conn_ac_clear_remote_busy, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_2 = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_actions_3a[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_start_rej_timer, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_3a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_3a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_actions_3b[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_start_rej_timer, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_3b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_3b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_actions_4[] = { + [0] = llc_conn_ac_send_rej_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_start_rej_timer, + [4] = llc_conn_ac_start_p_timer, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_4 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_5[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = 
llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = llc_conn_ac_resend_i_xxx_x_set_0_or_send_rr, + [6] = llc_conn_ac_clear_remote_busy, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_5 = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_5, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_6a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_6a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_6a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_6b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_xxx_x_set_0, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_6b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_6b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_7[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_rsp_f_set_1, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_7 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_7, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_8a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_8a = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_8a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_8b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_8b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_8b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_9a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_9a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_9a, +}; + +/* 
State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_9b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_9b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_9b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_9c[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_9c = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_9c, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_9d[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_9d = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_9d, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_10a[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_10a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_10a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_10b[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_10b = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_10b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_11[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_11 = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_11, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_12a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_12a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_12a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_actions_12b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] 
= llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_12b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_12b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_actions_13[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_13 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_actions_13, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_await_ev_qfyrs_14[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_actions_14[] = { + [0] = llc_conn_ac_send_rr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_state_trans_14 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_await_ev_qfyrs_14, + .ev_actions = llc_await_actions_14, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_await_state_transitions[] = { + [0] = &llc_common_state_trans_1, /* Request */ + [1] = &llc_common_state_trans_2, + [2] = &llc_await_state_trans_1_0, + [3] = &llc_common_state_trans_end, + [4] = &llc_await_state_trans_1, /* Local busy */ + [5] = &llc_common_state_trans_end, + [6] = &llc_common_state_trans_end, /* Initiate PF Cycle */ + [7] = &llc_common_state_trans_11a, /* Timer */ + [8] = &llc_common_state_trans_11b, + [9] = &llc_common_state_trans_11c, + [10] = &llc_common_state_trans_11d, + [11] = &llc_await_state_trans_14, + [12] = &llc_common_state_trans_end, + [13] = &llc_common_state_trans_3, /* Receive frame */ + [14] = &llc_common_state_trans_4, + [15] = &llc_common_state_trans_5, + [16] = &llc_common_state_trans_6, + [17] = &llc_common_state_trans_7a, + [18] = &llc_common_state_trans_7b, + [19] = &llc_common_state_trans_8a, + [20] = &llc_common_state_trans_8b, + [21] = &llc_common_state_trans_8c, + [22] = &llc_common_state_trans_9, + /* [23] = &llc_common_state_trans_10, */ + [23] = &llc_await_state_trans_2, + [24] = &llc_await_state_trans_3a, + [25] = &llc_await_state_trans_3b, + [26] = &llc_await_state_trans_4, + [27] = &llc_await_state_trans_5, + [28] = &llc_await_state_trans_6a, + [29] = &llc_await_state_trans_6b, + [30] = &llc_await_state_trans_7, + [31] = &llc_await_state_trans_8a, + [32] = &llc_await_state_trans_8b, + [33] = &llc_await_state_trans_9a, + [34] = &llc_await_state_trans_9b, + [35] = &llc_await_state_trans_9c, + [36] = &llc_await_state_trans_9d, + [37] = &llc_await_state_trans_10a, + [38] = &llc_await_state_trans_10b, + [39] = &llc_await_state_trans_11, + [40] = &llc_await_state_trans_12a, + [41] = &llc_await_state_trans_12b, + [42] = &llc_await_state_trans_13, + [43] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_AWAIT_BUSY transitions */ +/* State transitions for LLC_CONN_EV_DATA_CONN_REQ event */ +static const llc_conn_ev_qfyr_t llc_await_busy_ev_qfyrs_1_0[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static 
const llc_conn_action_t llc_await_busy_actions_1_0[1]; + +static struct llc_conn_state_trans llc_await_busy_state_trans_1_0 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = llc_await_busy_ev_qfyrs_1_0, + .ev_actions = llc_await_busy_actions_1_0, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_await_busy_ev_qfyrs_1[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_1, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_busy_actions_1[] = { + [0] = llc_conn_ac_send_rej_xxx_x_set_0, + [1] = llc_conn_ac_start_rej_timer, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_1 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_await_busy_ev_qfyrs_1, + .ev_actions = llc_await_busy_actions_1, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_await_busy_ev_qfyrs_2[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_0, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_busy_actions_2[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_2 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = llc_await_busy_ev_qfyrs_2, + .ev_actions = llc_await_busy_actions_2, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_CLEARED event */ +static const llc_conn_ev_qfyr_t llc_await_busy_ev_qfyrs_3[] = { + [0] = llc_conn_ev_qlfy_data_flag_eq_2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_busy_actions_3[] = { + [0] = llc_conn_ac_send_rr_xxx_x_set_0, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_3 = { + .ev = llc_conn_ev_local_busy_cleared, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_await_busy_ev_qfyrs_3, + .ev_actions = llc_await_busy_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_busy_actions_4[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_stop_p_timer, + [4] = llc_conn_ac_set_data_flag_1, + [5] = llc_conn_ac_clear_remote_busy, + [6] = llc_conn_ac_resend_i_xxx_x_set_0, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_4 = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_busy_actions_5a[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_data_flag_1, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_5a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_5a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_busy_actions_5b[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_data_flag_1, + [4] = NULL, +}; + +static struct 
llc_conn_state_trans llc_await_busy_state_trans_5b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_5b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_busy_actions_6[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_data_flag_1, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_6 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_6, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_7[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_inc_vr_by_1, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_stop_p_timer, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_upd_vs, + [6] = llc_conn_ac_set_data_flag_0, + [7] = llc_conn_ac_clear_remote_busy, + [8] = llc_conn_ac_resend_i_xxx_x_set_0, + [9] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_7 = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_7, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_8a[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_inc_vr_by_1, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_8a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_8a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_8b[] = { + [0] = llc_conn_ac_opt_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_inc_vr_by_1, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_8b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_8b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_9[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_inc_vr_by_1, + [2] = llc_conn_ac_data_ind, + [3] = llc_conn_ac_upd_nr_received, + [4] = llc_conn_ac_upd_vs, + [5] = llc_conn_ac_set_data_flag_0, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_9 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_9, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_10a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + 
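Each transition table entry above binds an event to a NULL-terminated qualifier list and a NULL-terminated action list (with NONE standing in for "no qualifiers"). As a hedged illustration only, the sketch below walks one such struct llc_conn_state_trans the way a dispatcher could; the function name demo_run_trans and its return convention are assumptions for illustration and are not code from this patch or from the LLC core.

/*
 * Illustrative sketch (assumption, not part of the upstream file): evaluate
 * one transition by running its NULL-terminated qualifier and action lists.
 */
static int demo_run_trans(struct sock *sk, struct sk_buff *skb,
			  const struct llc_conn_state_trans *trans)
{
	const llc_conn_ev_qfyr_t *qfyr;
	const llc_conn_action_t *act;
	int rc = 0;

	/* Every qualifier must return 0 for the transition to apply. */
	for (qfyr = trans->ev_qualifiers; qfyr && *qfyr; qfyr++)
		if ((*qfyr)(sk, skb))
			return 1;

	/* Actions run in array order; the first nonzero result stops the walk. */
	for (act = trans->ev_actions; act && *act && !rc; act++)
		rc = (*act)(sk, skb);

	return rc;
}

The NULL guard on both pointers covers the NONE qualifier lists and the all-zero action arrays (such as llc_await_busy_actions_1_0[1]) that rely on static zero-initialization.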
+static struct llc_conn_state_trans llc_await_busy_state_trans_10a = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_10a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_10b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_10b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_10b, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_11a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_11a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_11a, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_11b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_11b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_11b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_11c[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_11c = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_11c, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_11d[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_11d = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_11d, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_12a[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_12a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_12a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_12b[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + 
[4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_12b = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_12b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_13[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_13 = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_13, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_14a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_14a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_14a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_busy_actions_14b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_14b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_14b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_busy_actions_15[] = { + [0] = llc_conn_ac_send_rnr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_15 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_busy_actions_15, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_await_busy_ev_qfyrs_16[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_busy_actions_16[] = { + [0] = llc_conn_ac_send_rnr_cmd_p_set_1, + [1] = llc_conn_ac_start_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_busy_state_trans_16 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = llc_await_busy_ev_qfyrs_16, + .ev_actions = llc_await_busy_actions_16, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_await_busy_state_transitions[] = { + [0] = &llc_common_state_trans_1, /* Request */ + [1] = &llc_common_state_trans_2, + [2] = &llc_await_busy_state_trans_1_0, + [3] = &llc_common_state_trans_end, + [4] = &llc_await_busy_state_trans_1, /* Local busy */ + [5] = &llc_await_busy_state_trans_2, + [6] = &llc_await_busy_state_trans_3, + [7] = &llc_common_state_trans_end, + [8] = &llc_common_state_trans_end, /* Initiate PF cycle */ + [9] = &llc_common_state_trans_11a, /* Timer */ + [10] = &llc_common_state_trans_11b, + [11] = &llc_common_state_trans_11c, + [12] = 
&llc_common_state_trans_11d, + [13] = &llc_await_busy_state_trans_16, + [14] = &llc_common_state_trans_end, + [15] = &llc_await_busy_state_trans_4, /* Receive frame */ + [16] = &llc_await_busy_state_trans_5a, + [17] = &llc_await_busy_state_trans_5b, + [18] = &llc_await_busy_state_trans_6, + [19] = &llc_await_busy_state_trans_7, + [20] = &llc_await_busy_state_trans_8a, + [21] = &llc_await_busy_state_trans_8b, + [22] = &llc_await_busy_state_trans_9, + [23] = &llc_await_busy_state_trans_10a, + [24] = &llc_await_busy_state_trans_10b, + [25] = &llc_await_busy_state_trans_11a, + [26] = &llc_await_busy_state_trans_11b, + [27] = &llc_await_busy_state_trans_11c, + [28] = &llc_await_busy_state_trans_11d, + [29] = &llc_await_busy_state_trans_12a, + [30] = &llc_await_busy_state_trans_12b, + [31] = &llc_await_busy_state_trans_13, + [32] = &llc_await_busy_state_trans_14a, + [33] = &llc_await_busy_state_trans_14b, + [34] = &llc_await_busy_state_trans_15, + [35] = &llc_common_state_trans_3, + [36] = &llc_common_state_trans_4, + [37] = &llc_common_state_trans_5, + [38] = &llc_common_state_trans_6, + [39] = &llc_common_state_trans_7a, + [40] = &llc_common_state_trans_7b, + [41] = &llc_common_state_trans_8a, + [42] = &llc_common_state_trans_8b, + [43] = &llc_common_state_trans_8c, + [44] = &llc_common_state_trans_9, + /* [45] = &llc_common_state_trans_10, */ + [45] = &llc_common_state_trans_end, +}; + +/* ----------------- LLC_CONN_STATE_AWAIT_REJ transitions --------------- */ +/* State transitions for LLC_CONN_EV_DATA_CONN_REQ event */ +static const llc_conn_ev_qfyr_t llc_await_reject_ev_qfyrs_1_0[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_await_reject_actions_1_0[1]; + +static struct llc_conn_state_trans llc_await_reject_state_trans_1_0 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_await_reject_ev_qfyrs_1_0, + .ev_actions = llc_await_reject_actions_1_0, +}; + +/* State transitions for LLC_CONN_EV_LOCAL_BUSY_DETECTED event */ +static const llc_conn_action_t llc_await_rejct_actions_1[] = { + [0] = llc_conn_ac_send_rnr_xxx_x_set_0, + [1] = llc_conn_ac_set_data_flag_2, + [2] = NULL +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_1 = { + .ev = llc_conn_ev_local_busy_detected, + .next_state = LLC_CONN_STATE_AWAIT_BUSY, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_rejct_actions_2a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = NULL +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_2a = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_2a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_rejct_actions_2b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = NULL +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_2b = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_2b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t 
llc_await_rejct_actions_3[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = NULL +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_3 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_4[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_stop_rej_timer, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_upd_vs, + [6] = llc_conn_ac_resend_i_xxx_x_set_0_or_send_rr, + [7] = llc_conn_ac_clear_remote_busy, + [8] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_4 = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_5a[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_xxx_x_set_0, + [3] = llc_conn_ac_stop_rej_timer, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_upd_vs, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_5a = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_5a, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_5b[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_xxx_x_set_0, + [3] = llc_conn_ac_stop_rej_timer, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_upd_vs, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_5b = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_5b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_6[] = { + [0] = llc_conn_ac_inc_vr_by_1, + [1] = llc_conn_ac_data_ind, + [2] = llc_conn_ac_send_rr_rsp_f_set_1, + [3] = llc_conn_ac_stop_rej_timer, + [4] = llc_conn_ac_upd_nr_received, + [5] = llc_conn_ac_upd_vs, + [6] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_6 = { + .ev = llc_conn_ev_rx_i_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_6, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_7a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_7a = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_7a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_7b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = 
llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_7b = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_7b, +}; + +/* State transitions for LLC_CONN_EV_RX_I_RSP_Fbit_SET_1_UNEXPD_Ns event */ +static const llc_conn_action_t llc_await_rejct_actions_7c[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_resend_i_xxx_x_set_0, + [4] = llc_conn_ac_clear_remote_busy, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_7c = { + .ev = llc_conn_ev_rx_i_rsp_fbit_set_1_unexpd_ns, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_7c, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_8a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_8a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_8a, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_8b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_8b = { + .ev = llc_conn_ev_rx_rr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_8b, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_8c[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_8c = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_8c, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_8d[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_clear_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_8d = { + .ev = llc_conn_ev_rx_rej_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_8d, +}; + +/* State transitions for LLC_CONN_EV_RX_RR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_9a[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_9a = { + .ev = llc_conn_ev_rx_rr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_9a, +}; + +/* State transitions for LLC_CONN_EV_RX_REJ_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t 
llc_await_rejct_actions_9b[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_clear_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_9b = { + .ev = llc_conn_ev_rx_rej_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_9b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_10[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_stop_p_timer, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_10 = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_1, + .next_state = LLC_CONN_STATE_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_10, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_11a[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_11a = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_11a, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_RSP_Fbit_SET_0 event */ +static const llc_conn_action_t llc_await_rejct_actions_11b[] = { + [0] = llc_conn_ac_upd_nr_received, + [1] = llc_conn_ac_upd_vs, + [2] = llc_conn_ac_set_remote_busy, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_11b = { + .ev = llc_conn_ev_rx_rnr_rsp_fbit_set_0, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_11b, +}; + +/* State transitions for LLC_CONN_EV_RX_RNR_CMD_Pbit_SET_1 event */ +static const llc_conn_action_t llc_await_rejct_actions_12[] = { + [0] = llc_conn_ac_send_rr_rsp_f_set_1, + [1] = llc_conn_ac_upd_nr_received, + [2] = llc_conn_ac_upd_vs, + [3] = llc_conn_ac_set_remote_busy, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_12 = { + .ev = llc_conn_ev_rx_rnr_cmd_pbit_set_1, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = NONE, + .ev_actions = llc_await_rejct_actions_12, +}; + +/* State transitions for LLC_CONN_EV_P_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_await_rejct_ev_qfyrs_13[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_await_rejct_actions_13[] = { + [0] = llc_conn_ac_send_rej_cmd_p_set_1, + [1] = llc_conn_ac_stop_p_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_await_rejct_state_trans_13 = { + .ev = llc_conn_ev_p_tmr_exp, + .next_state = LLC_CONN_STATE_AWAIT_REJ, + .ev_qualifiers = llc_await_rejct_ev_qfyrs_13, + .ev_actions = llc_await_rejct_actions_13, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_await_rejct_state_transitions[] = { + [0] = &llc_await_reject_state_trans_1_0, + [1] = &llc_common_state_trans_1, /* requests */ + [2] = &llc_common_state_trans_2, + [3] = &llc_common_state_trans_end, + [4] = &llc_await_rejct_state_trans_1, /* local busy */ + [5] = &llc_common_state_trans_end, + [6] = &llc_common_state_trans_end, /* Initiate PF cycle */ + [7] = 
&llc_await_rejct_state_trans_13, /* timers */ + [8] = &llc_common_state_trans_11a, + [9] = &llc_common_state_trans_11b, + [10] = &llc_common_state_trans_11c, + [11] = &llc_common_state_trans_11d, + [12] = &llc_common_state_trans_end, + [13] = &llc_await_rejct_state_trans_2a, /* receive frames */ + [14] = &llc_await_rejct_state_trans_2b, + [15] = &llc_await_rejct_state_trans_3, + [16] = &llc_await_rejct_state_trans_4, + [17] = &llc_await_rejct_state_trans_5a, + [18] = &llc_await_rejct_state_trans_5b, + [19] = &llc_await_rejct_state_trans_6, + [20] = &llc_await_rejct_state_trans_7a, + [21] = &llc_await_rejct_state_trans_7b, + [22] = &llc_await_rejct_state_trans_7c, + [23] = &llc_await_rejct_state_trans_8a, + [24] = &llc_await_rejct_state_trans_8b, + [25] = &llc_await_rejct_state_trans_8c, + [26] = &llc_await_rejct_state_trans_8d, + [27] = &llc_await_rejct_state_trans_9a, + [28] = &llc_await_rejct_state_trans_9b, + [29] = &llc_await_rejct_state_trans_10, + [30] = &llc_await_rejct_state_trans_11a, + [31] = &llc_await_rejct_state_trans_11b, + [32] = &llc_await_rejct_state_trans_12, + [33] = &llc_common_state_trans_3, + [34] = &llc_common_state_trans_4, + [35] = &llc_common_state_trans_5, + [36] = &llc_common_state_trans_6, + [37] = &llc_common_state_trans_7a, + [38] = &llc_common_state_trans_7b, + [39] = &llc_common_state_trans_8a, + [40] = &llc_common_state_trans_8b, + [41] = &llc_common_state_trans_8c, + [42] = &llc_common_state_trans_9, + /* [43] = &llc_common_state_trans_10, */ + [43] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_D_CONN transitions */ +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event, + * cause_flag = 1 */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_1[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_1, + [1] = llc_conn_ev_qlfy_set_status_conflict, + [2] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_1[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_p, + [1] = llc_conn_ac_stop_ack_timer, + [2] = llc_conn_ac_disc_confirm, + [3] = llc_conn_disc, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_1 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_1, + .ev_actions = llc_d_conn_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event, + * cause_flag = 0 + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_1_1[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_0, + [1] = llc_conn_ev_qlfy_set_status_conflict, + [2] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_1_1[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_p, + [1] = llc_conn_ac_stop_ack_timer, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_1_1 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_1_1, + .ev_actions = llc_d_conn_actions_1_1, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event, + * cause_flag = 1 + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_2[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = llc_conn_ev_qlfy_cause_flag_eq_1, + [2] = llc_conn_ev_qlfy_set_status_disc, + [3] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_2[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_ac_disc_confirm, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_2 = { + .ev = llc_conn_ev_rx_ua_rsp_fbit_set_x, + 
.next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_2, + .ev_actions = llc_d_conn_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event, + * cause_flag = 0 + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_2_1[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = llc_conn_ev_qlfy_cause_flag_eq_0, + [2] = llc_conn_ev_qlfy_set_status_disc, + [3] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_2_1[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_2_1 = { + .ev = llc_conn_ev_rx_ua_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_2_1, + .ev_actions = llc_d_conn_actions_2_1, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_d_conn_actions_3[] = { + [0] = llc_conn_ac_send_ua_rsp_f_set_p, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_3 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_D_CONN, + .ev_qualifiers = NONE, + .ev_actions = llc_d_conn_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event, + * cause_flag = 1 + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_4[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_1, + [1] = llc_conn_ev_qlfy_set_status_disc, + [2] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_4[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_ac_disc_confirm, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_4 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_4, + .ev_actions = llc_d_conn_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event, + * cause_flag = 0 + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_4_1[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_0, + [1] = llc_conn_ev_qlfy_set_status_disc, + [2] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_4_1[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_4_1 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_4_1, + .ev_actions = llc_d_conn_actions_4_1, +}; + +/* + * State transition for + * LLC_CONN_EV_DATA_CONN_REQ event + */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_5[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_d_conn_actions_5[1]; + +static struct llc_conn_state_trans llc_d_conn_state_trans_5 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_D_CONN, + .ev_qualifiers = llc_d_conn_ev_qfyrs_5, + .ev_actions = llc_d_conn_actions_5, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_6[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_6[] = { + [0] = llc_conn_ac_send_disc_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_6 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_D_CONN, + .ev_qualifiers = 
llc_d_conn_ev_qfyrs_6, + .ev_actions = llc_d_conn_actions_6, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event, cause_flag = 1 */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_7[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = llc_conn_ev_qlfy_cause_flag_eq_1, + [2] = llc_conn_ev_qlfy_set_status_failed, + [3] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_7[] = { + [0] = llc_conn_ac_disc_confirm, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_7 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_7, + .ev_actions = llc_d_conn_actions_7, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event, cause_flag = 0 */ +static const llc_conn_ev_qfyr_t llc_d_conn_ev_qfyrs_8[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = llc_conn_ev_qlfy_cause_flag_eq_0, + [2] = llc_conn_ev_qlfy_set_status_failed, + [3] = NULL, +}; + +static const llc_conn_action_t llc_d_conn_actions_8[] = { + [0] = llc_conn_disc, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_d_conn_state_trans_8 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_d_conn_ev_qfyrs_8, + .ev_actions = llc_d_conn_actions_8, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_d_conn_state_transitions[] = { + [0] = &llc_d_conn_state_trans_5, /* Request */ + [1] = &llc_common_state_trans_end, + [2] = &llc_common_state_trans_end, /* Local busy */ + [3] = &llc_common_state_trans_end, /* Initiate PF cycle */ + [4] = &llc_d_conn_state_trans_6, /* Timer */ + [5] = &llc_d_conn_state_trans_7, + [6] = &llc_d_conn_state_trans_8, + [7] = &llc_common_state_trans_end, + [8] = &llc_d_conn_state_trans_1, /* Receive frame */ + [9] = &llc_d_conn_state_trans_1_1, + [10] = &llc_d_conn_state_trans_2, + [11] = &llc_d_conn_state_trans_2_1, + [12] = &llc_d_conn_state_trans_3, + [13] = &llc_d_conn_state_trans_4, + [14] = &llc_d_conn_state_trans_4_1, + [15] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_RESET transitions */ +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_rst_actions_1[] = { + [0] = llc_conn_ac_set_vs_0, + [1] = llc_conn_ac_set_vr_0, + [2] = llc_conn_ac_set_s_flag_1, + [3] = llc_conn_ac_send_ua_rsp_f_set_p, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_1 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = NONE, + .ev_actions = llc_rst_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event, + * cause_flag = 1 + */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_2[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = llc_conn_ev_qlfy_cause_flag_eq_1, + [2] = llc_conn_ev_qlfy_set_status_conn, + [3] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_2[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_ac_set_vs_0, + [2] = llc_conn_ac_set_vr_0, + [3] = llc_conn_ac_upd_p_flag, + [4] = llc_conn_ac_rst_confirm, + [5] = llc_conn_ac_set_remote_busy_0, + [6] = llc_conn_reset, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_2 = { + .ev = llc_conn_ev_rx_ua_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_rst_ev_qfyrs_2, + .ev_actions = llc_rst_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_UA_RSP_Fbit_SET_X event, + * cause_flag = 0 + 
*/ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_2_1[] = { + [0] = llc_conn_ev_qlfy_p_flag_eq_f, + [1] = llc_conn_ev_qlfy_cause_flag_eq_0, + [2] = llc_conn_ev_qlfy_set_status_rst_done, + [3] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_2_1[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_ac_set_vs_0, + [2] = llc_conn_ac_set_vr_0, + [3] = llc_conn_ac_upd_p_flag, + [4] = llc_conn_ac_rst_confirm, + [5] = llc_conn_ac_set_remote_busy_0, + [6] = llc_conn_reset, + [7] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_2_1 = { + .ev = llc_conn_ev_rx_ua_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_rst_ev_qfyrs_2_1, + .ev_actions = llc_rst_actions_2_1, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_3[] = { + [0] = llc_conn_ev_qlfy_s_flag_eq_1, + [1] = llc_conn_ev_qlfy_set_status_rst_done, + [2] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_3[] = { + [0] = llc_conn_ac_set_p_flag_0, + [1] = llc_conn_ac_set_remote_busy_0, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_3 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = llc_rst_ev_qfyrs_3, + .ev_actions = llc_rst_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event, + * cause_flag = 1 + */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_4[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_1, + [1] = llc_conn_ev_qlfy_set_status_disc, + [2] = NULL, +}; +static const llc_conn_action_t llc_rst_actions_4[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_p, + [1] = llc_conn_ac_disc_ind, + [2] = llc_conn_ac_stop_ack_timer, + [3] = llc_conn_disc, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_4 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_4, + .ev_actions = llc_rst_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event, + * cause_flag = 0 + */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_4_1[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_0, + [1] = llc_conn_ev_qlfy_set_status_refuse, + [2] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_4_1[] = { + [0] = llc_conn_ac_send_dm_rsp_f_set_p, + [1] = llc_conn_ac_stop_ack_timer, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_4_1 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_4_1, + .ev_actions = llc_rst_actions_4_1, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event, + * cause_flag = 1 + */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_5[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_1, + [1] = llc_conn_ev_qlfy_set_status_disc, + [2] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_5[] = { + [0] = llc_conn_ac_disc_ind, + [1] = llc_conn_ac_stop_ack_timer, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_5 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_5, + .ev_actions = llc_rst_actions_5, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event, + * cause_flag = 0 + */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_5_1[] = { + [0] = llc_conn_ev_qlfy_cause_flag_eq_0, + [1] = llc_conn_ev_qlfy_set_status_refuse, + [2] = NULL, 
+}; + +static const llc_conn_action_t llc_rst_actions_5_1[] = { + [0] = llc_conn_ac_stop_ack_timer, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_5_1 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_5_1, + .ev_actions = llc_rst_actions_5_1, +}; + +/* State transitions for DATA_CONN_REQ event */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_6[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_rst_actions_6[1]; + +static struct llc_conn_state_trans llc_rst_state_trans_6 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_rst_ev_qfyrs_6, + .ev_actions = llc_rst_actions_6, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_7[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = llc_conn_ev_qlfy_s_flag_eq_0, + [2] = NULL, +}; + +static const llc_conn_action_t llc_rst_actions_7[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_7 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_rst_ev_qfyrs_7, + .ev_actions = llc_rst_actions_7, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_8[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = llc_conn_ev_qlfy_s_flag_eq_0, + [2] = llc_conn_ev_qlfy_cause_flag_eq_1, + [3] = llc_conn_ev_qlfy_set_status_failed, + [4] = NULL, +}; +static const llc_conn_action_t llc_rst_actions_8[] = { + [0] = llc_conn_ac_disc_ind, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_8 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_8, + .ev_actions = llc_rst_actions_8, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_rst_ev_qfyrs_8_1[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = llc_conn_ev_qlfy_s_flag_eq_0, + [2] = llc_conn_ev_qlfy_cause_flag_eq_0, + [3] = llc_conn_ev_qlfy_set_status_failed, + [4] = NULL, +}; +static const llc_conn_action_t llc_rst_actions_8_1[] = { + [0] = llc_conn_ac_disc_ind, + [1] = llc_conn_disc, + [2] = NULL, +}; + +static struct llc_conn_state_trans llc_rst_state_trans_8_1 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = llc_rst_ev_qfyrs_8_1, + .ev_actions = llc_rst_actions_8_1, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_rst_state_transitions[] = { + [0] = &llc_rst_state_trans_6, /* Request */ + [1] = &llc_common_state_trans_end, + [2] = &llc_common_state_trans_end, /* Local busy */ + [3] = &llc_common_state_trans_end, /* Initiate PF cycle */ + [4] = &llc_rst_state_trans_3, /* Timer */ + [5] = &llc_rst_state_trans_7, + [6] = &llc_rst_state_trans_8, + [7] = &llc_rst_state_trans_8_1, + [8] = &llc_common_state_trans_end, + [9] = &llc_rst_state_trans_1, /* Receive frame */ + [10] = &llc_rst_state_trans_2, + [11] = &llc_rst_state_trans_2_1, + [12] = &llc_rst_state_trans_4, + [13] = &llc_rst_state_trans_4_1, + [14] = &llc_rst_state_trans_5, + [15] = &llc_rst_state_trans_5_1, + [16] = 
&llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_ERROR transitions */ +/* State transitions for LLC_CONN_EV_RX_SABME_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_error_actions_1[] = { + [0] = llc_conn_ac_set_vs_0, + [1] = llc_conn_ac_set_vr_0, + [2] = llc_conn_ac_send_ua_rsp_f_set_p, + [3] = llc_conn_ac_rst_ind, + [4] = llc_conn_ac_set_p_flag_0, + [5] = llc_conn_ac_set_remote_busy_0, + [6] = llc_conn_ac_stop_ack_timer, + [7] = llc_conn_reset, + [8] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_1 = { + .ev = llc_conn_ev_rx_sabme_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_NORMAL, + .ev_qualifiers = NONE, + .ev_actions = llc_error_actions_1, +}; + +/* State transitions for LLC_CONN_EV_RX_DISC_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_error_actions_2[] = { + [0] = llc_conn_ac_send_ua_rsp_f_set_p, + [1] = llc_conn_ac_disc_ind, + [2] = llc_conn_ac_stop_ack_timer, + [3] = llc_conn_disc, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_2 = { + .ev = llc_conn_ev_rx_disc_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_error_actions_2, +}; + +/* State transitions for LLC_CONN_EV_RX_DM_RSP_Fbit_SET_X event */ +static const llc_conn_action_t llc_error_actions_3[] = { + [0] = llc_conn_ac_disc_ind, + [1] = llc_conn_ac_stop_ack_timer, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_3 = { + .ev = llc_conn_ev_rx_dm_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_error_actions_3, +}; + +/* State transitions for LLC_CONN_EV_RX_FRMR_RSP_Fbit_SET_X event */ +static const llc_conn_action_t llc_error_actions_4[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_set_retry_cnt_0, + [3] = llc_conn_ac_set_cause_flag_0, + [4] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_4 = { + .ev = llc_conn_ev_rx_frmr_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = NONE, + .ev_actions = llc_error_actions_4, +}; + +/* State transitions for LLC_CONN_EV_RX_XXX_CMD_Pbit_SET_X event */ +static const llc_conn_action_t llc_error_actions_5[] = { + [0] = llc_conn_ac_resend_frmr_rsp_f_set_p, + [1] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_5 = { + .ev = llc_conn_ev_rx_xxx_cmd_pbit_set_x, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = llc_error_actions_5, +}; + +/* State transitions for LLC_CONN_EV_RX_XXX_RSP_Fbit_SET_X event */ +static struct llc_conn_state_trans llc_error_state_trans_6 = { + .ev = llc_conn_ev_rx_xxx_rsp_fbit_set_x, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = NONE, + .ev_actions = NONE, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_error_ev_qfyrs_7[] = { + [0] = llc_conn_ev_qlfy_retry_cnt_lt_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_error_actions_7[] = { + [0] = llc_conn_ac_resend_frmr_rsp_f_set_0, + [1] = llc_conn_ac_start_ack_timer, + [2] = llc_conn_ac_inc_retry_cnt_by_1, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_7 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = llc_error_ev_qfyrs_7, + .ev_actions = llc_error_actions_7, +}; + +/* State transitions for LLC_CONN_EV_ACK_TMR_EXP event */ +static const llc_conn_ev_qfyr_t llc_error_ev_qfyrs_8[] = { + [0] = 
llc_conn_ev_qlfy_retry_cnt_gte_n2, + [1] = NULL, +}; + +static const llc_conn_action_t llc_error_actions_8[] = { + [0] = llc_conn_ac_send_sabme_cmd_p_set_x, + [1] = llc_conn_ac_set_s_flag_0, + [2] = llc_conn_ac_start_ack_timer, + [3] = llc_conn_ac_set_retry_cnt_0, + [4] = llc_conn_ac_set_cause_flag_0, + [5] = NULL, +}; + +static struct llc_conn_state_trans llc_error_state_trans_8 = { + .ev = llc_conn_ev_ack_tmr_exp, + .next_state = LLC_CONN_STATE_RESET, + .ev_qualifiers = llc_error_ev_qfyrs_8, + .ev_actions = llc_error_actions_8, +}; + +/* State transitions for LLC_CONN_EV_DATA_CONN_REQ event */ +static const llc_conn_ev_qfyr_t llc_error_ev_qfyrs_9[] = { + [0] = llc_conn_ev_qlfy_set_status_refuse, + [1] = NULL, +}; + +/* just one member, NULL, .bss zeroes it */ +static const llc_conn_action_t llc_error_actions_9[1]; + +static struct llc_conn_state_trans llc_error_state_trans_9 = { + .ev = llc_conn_ev_data_req, + .next_state = LLC_CONN_STATE_ERROR, + .ev_qualifiers = llc_error_ev_qfyrs_9, + .ev_actions = llc_error_actions_9, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_error_state_transitions[] = { + [0] = &llc_error_state_trans_9, /* Request */ + [1] = &llc_common_state_trans_end, + [2] = &llc_common_state_trans_end, /* Local busy */ + [3] = &llc_common_state_trans_end, /* Initiate PF cycle */ + [4] = &llc_error_state_trans_7, /* Timer */ + [5] = &llc_error_state_trans_8, + [6] = &llc_common_state_trans_end, + [7] = &llc_error_state_trans_1, /* Receive frame */ + [8] = &llc_error_state_trans_2, + [9] = &llc_error_state_trans_3, + [10] = &llc_error_state_trans_4, + [11] = &llc_error_state_trans_5, + [12] = &llc_error_state_trans_6, + [13] = &llc_common_state_trans_end, +}; + +/* LLC_CONN_STATE_TEMP transitions */ +/* State transitions for LLC_CONN_EV_DISC_REQ event */ +static const llc_conn_action_t llc_temp_actions_1[] = { + [0] = llc_conn_ac_stop_all_timers, + [1] = llc_conn_ac_send_disc_cmd_p_set_x, + [2] = llc_conn_disc, + [3] = NULL, +}; + +static struct llc_conn_state_trans llc_temp_state_trans_1 = { + .ev = llc_conn_ev_disc_req, + .next_state = LLC_CONN_STATE_ADM, + .ev_qualifiers = NONE, + .ev_actions = llc_temp_actions_1, +}; + +/* + * Array of pointers; + * one to each transition + */ +static struct llc_conn_state_trans *llc_temp_state_transitions[] = { + [0] = &llc_temp_state_trans_1, /* requests */ + [1] = &llc_common_state_trans_end, + [2] = &llc_common_state_trans_end, /* local busy */ + [3] = &llc_common_state_trans_end, /* init_pf_cycle */ + [4] = &llc_common_state_trans_end, /* timer */ + [5] = &llc_common_state_trans_end, /* receive */ +}; + +/* Connection State Transition Table */ +struct llc_conn_state llc_conn_state_table[NBR_CONN_STATES] = { + [LLC_CONN_STATE_ADM - 1] = { + .current_state = LLC_CONN_STATE_ADM, + .transitions = llc_adm_state_transitions, + }, + [LLC_CONN_STATE_SETUP - 1] = { + .current_state = LLC_CONN_STATE_SETUP, + .transitions = llc_setup_state_transitions, + }, + [LLC_CONN_STATE_NORMAL - 1] = { + .current_state = LLC_CONN_STATE_NORMAL, + .transitions = llc_normal_state_transitions, + }, + [LLC_CONN_STATE_BUSY - 1] = { + .current_state = LLC_CONN_STATE_BUSY, + .transitions = llc_busy_state_transitions, + }, + [LLC_CONN_STATE_REJ - 1] = { + .current_state = LLC_CONN_STATE_REJ, + .transitions = llc_reject_state_transitions, + }, + [LLC_CONN_STATE_AWAIT - 1] = { + .current_state = LLC_CONN_STATE_AWAIT, + .transitions = llc_await_state_transitions, + }, + [LLC_CONN_STATE_AWAIT_BUSY - 1] = { + 
.current_state = LLC_CONN_STATE_AWAIT_BUSY, + .transitions = llc_await_busy_state_transitions, + }, + [LLC_CONN_STATE_AWAIT_REJ - 1] = { + .current_state = LLC_CONN_STATE_AWAIT_REJ, + .transitions = llc_await_rejct_state_transitions, + }, + [LLC_CONN_STATE_D_CONN - 1] = { + .current_state = LLC_CONN_STATE_D_CONN, + .transitions = llc_d_conn_state_transitions, + }, + [LLC_CONN_STATE_RESET - 1] = { + .current_state = LLC_CONN_STATE_RESET, + .transitions = llc_rst_state_transitions, + }, + [LLC_CONN_STATE_ERROR - 1] = { + .current_state = LLC_CONN_STATE_ERROR, + .transitions = llc_error_state_transitions, + }, + [LLC_CONN_STATE_TEMP - 1] = { + .current_state = LLC_CONN_STATE_TEMP, + .transitions = llc_temp_state_transitions, + }, +}; diff --git a/net/llc/llc_conn.c b/net/llc/llc_conn.c new file mode 100644 index 000000000..912aa9bd5 --- /dev/null +++ b/net/llc/llc_conn.c @@ -0,0 +1,1020 @@ +/* + * llc_conn.c - Driver routines for connection component. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if 0 +#define dprintk(args...) printk(KERN_DEBUG args) +#else +#define dprintk(args...) +#endif + +static int llc_find_offset(int state, int ev_type); +static void llc_conn_send_pdus(struct sock *sk); +static int llc_conn_service(struct sock *sk, struct sk_buff *skb); +static int llc_exec_conn_trans_actions(struct sock *sk, + struct llc_conn_state_trans *trans, + struct sk_buff *ev); +static struct llc_conn_state_trans *llc_qualify_conn_ev(struct sock *sk, + struct sk_buff *skb); + +/* Offset table on connection states transition diagram */ +static int llc_offset_table[NBR_CONN_STATES][NBR_CONN_EV]; + +int sysctl_llc2_ack_timeout = LLC2_ACK_TIME * HZ; +int sysctl_llc2_p_timeout = LLC2_P_TIME * HZ; +int sysctl_llc2_rej_timeout = LLC2_REJ_TIME * HZ; +int sysctl_llc2_busy_timeout = LLC2_BUSY_TIME * HZ; + +/** + * llc_conn_state_process - sends event to connection state machine + * @sk: connection + * @skb: occurred event + * + * Sends an event to connection state machine. After processing event + * (executing it's actions and changing state), upper layer will be + * indicated or confirmed, if needed. Returns 0 for success, 1 for + * failure. The socket lock has to be held before calling this function. + * + * This function always consumes a reference to the skb. 
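+ *
+ * For reference, the calling convention used by the senders later in
+ * this patch (llc_establish_connection() and llc_send_disc() in
+ * llc_if.c) looks roughly like this; the event values shown are the
+ * ones from llc_send_disc() and are illustrative only:
+ *
+ *	struct sk_buff *skb = alloc_skb(0, GFP_ATOMIC);
+ *	struct llc_conn_state_ev *ev;
+ *	int rc = -ENOMEM;
+ *
+ *	if (skb) {
+ *		skb_set_owner_w(skb, sk);
+ *		ev = llc_conn_ev(skb);
+ *		ev->type = LLC_CONN_EV_TYPE_PRIM;
+ *		ev->prim = LLC_DISC_PRIM;
+ *		ev->prim_type = LLC_PRIM_TYPE_REQ;
+ *		rc = llc_conn_state_process(sk, skb);
+ *	}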
+ */ +int llc_conn_state_process(struct sock *sk, struct sk_buff *skb) +{ + int rc; + struct llc_sock *llc = llc_sk(skb->sk); + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->ind_prim = ev->cfm_prim = 0; + /* + * Send event to state machine + */ + rc = llc_conn_service(skb->sk, skb); + if (unlikely(rc != 0)) { + printk(KERN_ERR "%s: llc_conn_service failed\n", __func__); + goto out_skb_put; + } + + switch (ev->ind_prim) { + case LLC_DATA_PRIM: + skb_get(skb); + llc_save_primitive(sk, skb, LLC_DATA_PRIM); + if (unlikely(sock_queue_rcv_skb(sk, skb))) { + /* + * shouldn't happen + */ + printk(KERN_ERR "%s: sock_queue_rcv_skb failed!\n", + __func__); + kfree_skb(skb); + } + break; + case LLC_CONN_PRIM: + /* + * Can't be sock_queue_rcv_skb, because we have to leave the + * skb->sk pointing to the newly created struct sock in + * llc_conn_handler. -acme + */ + skb_get(skb); + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_state_change(sk); + break; + case LLC_DISC_PRIM: + sock_hold(sk); + if (sk->sk_type == SOCK_STREAM && + sk->sk_state == TCP_ESTABLISHED) { + sk->sk_shutdown = SHUTDOWN_MASK; + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; + if (!sock_flag(sk, SOCK_DEAD)) { + sock_set_flag(sk, SOCK_DEAD); + sk->sk_state_change(sk); + } + } + sock_put(sk); + break; + case LLC_RESET_PRIM: + /* + * FIXME: + * RESET is not being notified to upper layers for now + */ + printk(KERN_INFO "%s: received a reset ind!\n", __func__); + break; + default: + if (ev->ind_prim) + printk(KERN_INFO "%s: received unknown %d prim!\n", + __func__, ev->ind_prim); + /* No indication */ + break; + } + + switch (ev->cfm_prim) { + case LLC_DATA_PRIM: + if (!llc_data_accept_state(llc->state)) + sk->sk_write_space(sk); + else + rc = llc->failed_data_req = 1; + break; + case LLC_CONN_PRIM: + if (sk->sk_type == SOCK_STREAM && + sk->sk_state == TCP_SYN_SENT) { + if (ev->status) { + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; + } else { + sk->sk_socket->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + } + sk->sk_state_change(sk); + } + break; + case LLC_DISC_PRIM: + sock_hold(sk); + if (sk->sk_type == SOCK_STREAM && sk->sk_state == TCP_CLOSING) { + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_state = TCP_CLOSE; + sk->sk_state_change(sk); + } + sock_put(sk); + break; + case LLC_RESET_PRIM: + /* + * FIXME: + * RESET is not being notified to upper layers for now + */ + printk(KERN_INFO "%s: received a reset conf!\n", __func__); + break; + default: + if (ev->cfm_prim) + printk(KERN_INFO "%s: received unknown %d prim!\n", + __func__, ev->cfm_prim); + /* No confirmation */ + break; + } +out_skb_put: + kfree_skb(skb); + return rc; +} + +void llc_conn_send_pdu(struct sock *sk, struct sk_buff *skb) +{ + /* queue PDU to send to MAC layer */ + skb_queue_tail(&sk->sk_write_queue, skb); + llc_conn_send_pdus(sk); +} + +/** + * llc_conn_rtn_pdu - sends received data pdu to upper layer + * @sk: Active connection + * @skb: Received data frame + * + * Sends received data pdu to upper layer (by using indicate function). + * Prepares service parameters (prim and prim_data). calling indication + * function will be done in llc_conn_state_process. 
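+ *
+ * Note that this helper only marks the event (ev->ind_prim =
+ * LLC_DATA_PRIM); the frame itself is handed to the socket afterwards
+ * by the LLC_DATA_PRIM case in llc_conn_state_process() above, roughly:
+ *
+ *	skb_get(skb);
+ *	llc_save_primitive(sk, skb, LLC_DATA_PRIM);
+ *	if (sock_queue_rcv_skb(sk, skb))
+ *		kfree_skb(skb);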
+ */ +void llc_conn_rtn_pdu(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->ind_prim = LLC_DATA_PRIM; +} + +/** + * llc_conn_resend_i_pdu_as_cmd - resend all all unacknowledged I PDUs + * @sk: active connection + * @nr: NR + * @first_p_bit: p_bit value of first pdu + * + * Resend all unacknowledged I PDUs, starting with the NR; send first as + * command PDU with P bit equal first_p_bit; if more than one send + * subsequent as command PDUs with P bit equal zero (0). + */ +void llc_conn_resend_i_pdu_as_cmd(struct sock *sk, u8 nr, u8 first_p_bit) +{ + struct sk_buff *skb; + struct llc_pdu_sn *pdu; + u16 nbr_unack_pdus; + struct llc_sock *llc; + u8 howmany_resend = 0; + + llc_conn_remove_acked_pdus(sk, nr, &nbr_unack_pdus); + if (!nbr_unack_pdus) + goto out; + /* + * Process unack PDUs only if unack queue is not empty; remove + * appropriate PDUs, fix them up, and put them on mac_pdu_q. + */ + llc = llc_sk(sk); + + while ((skb = skb_dequeue(&llc->pdu_unack_q)) != NULL) { + pdu = llc_pdu_sn_hdr(skb); + llc_pdu_set_cmd_rsp(skb, LLC_PDU_CMD); + llc_pdu_set_pf_bit(skb, first_p_bit); + skb_queue_tail(&sk->sk_write_queue, skb); + first_p_bit = 0; + llc->vS = LLC_I_GET_NS(pdu); + howmany_resend++; + } + if (howmany_resend > 0) + llc->vS = (llc->vS + 1) % LLC_2_SEQ_NBR_MODULO; + /* any PDUs to re-send are queued up; start sending to MAC */ + llc_conn_send_pdus(sk); +out:; +} + +/** + * llc_conn_resend_i_pdu_as_rsp - Resend all unacknowledged I PDUs + * @sk: active connection. + * @nr: NR + * @first_f_bit: f_bit value of first pdu. + * + * Resend all unacknowledged I PDUs, starting with the NR; send first as + * response PDU with F bit equal first_f_bit; if more than one send + * subsequent as response PDUs with F bit equal zero (0). + */ +void llc_conn_resend_i_pdu_as_rsp(struct sock *sk, u8 nr, u8 first_f_bit) +{ + struct sk_buff *skb; + u16 nbr_unack_pdus; + struct llc_sock *llc = llc_sk(sk); + u8 howmany_resend = 0; + + llc_conn_remove_acked_pdus(sk, nr, &nbr_unack_pdus); + if (!nbr_unack_pdus) + goto out; + /* + * Process unack PDUs only if unack queue is not empty; remove + * appropriate PDUs, fix them up, and put them on mac_pdu_q + */ + while ((skb = skb_dequeue(&llc->pdu_unack_q)) != NULL) { + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + llc_pdu_set_cmd_rsp(skb, LLC_PDU_RSP); + llc_pdu_set_pf_bit(skb, first_f_bit); + skb_queue_tail(&sk->sk_write_queue, skb); + first_f_bit = 0; + llc->vS = LLC_I_GET_NS(pdu); + howmany_resend++; + } + if (howmany_resend > 0) + llc->vS = (llc->vS + 1) % LLC_2_SEQ_NBR_MODULO; + /* any PDUs to re-send are queued up; start sending to MAC */ + llc_conn_send_pdus(sk); +out:; +} + +/** + * llc_conn_remove_acked_pdus - Removes acknowledged pdus from tx queue + * @sk: active connection + * @nr: NR + * @how_many_unacked: size of pdu_unack_q after removing acked pdus + * + * Removes acknowledged pdus from transmit queue (pdu_unack_q). Returns + * the number of pdus that removed from queue. 
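+ *
+ * A worked example of the wrap-around arithmetic used below, assuming
+ * the usual modulo-128 LLC2 send sequence space (LLC_2_SEQ_NBR_MODULO):
+ *
+ *	head of pdu_unack_q carries N(S) = 126, the received N(R) is 2
+ *	pdu_pos = (128 + 2 - 126) % 128 = 4
+ *
+ * so the four PDUs numbered 126, 127, 0 and 1 are dequeued and freed
+ * as acknowledged.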
+ */ +int llc_conn_remove_acked_pdus(struct sock *sk, u8 nr, u16 *how_many_unacked) +{ + int pdu_pos, i; + struct sk_buff *skb; + struct llc_pdu_sn *pdu; + int nbr_acked = 0; + struct llc_sock *llc = llc_sk(sk); + int q_len = skb_queue_len(&llc->pdu_unack_q); + + if (!q_len) + goto out; + skb = skb_peek(&llc->pdu_unack_q); + pdu = llc_pdu_sn_hdr(skb); + + /* finding position of last acked pdu in queue */ + pdu_pos = ((int)LLC_2_SEQ_NBR_MODULO + (int)nr - + (int)LLC_I_GET_NS(pdu)) % LLC_2_SEQ_NBR_MODULO; + + for (i = 0; i < pdu_pos && i < q_len; i++) { + skb = skb_dequeue(&llc->pdu_unack_q); + kfree_skb(skb); + nbr_acked++; + } +out: + *how_many_unacked = skb_queue_len(&llc->pdu_unack_q); + return nbr_acked; +} + +/** + * llc_conn_send_pdus - Sends queued PDUs + * @sk: active connection + * + * Sends queued pdus to MAC layer for transmission. + */ +static void llc_conn_send_pdus(struct sock *sk) +{ + struct sk_buff *skb; + + while ((skb = skb_dequeue(&sk->sk_write_queue)) != NULL) { + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + if (LLC_PDU_TYPE_IS_I(pdu) && + !(skb->dev->flags & IFF_LOOPBACK)) { + struct sk_buff *skb2 = skb_clone(skb, GFP_ATOMIC); + + skb_queue_tail(&llc_sk(sk)->pdu_unack_q, skb); + if (!skb2) + break; + skb = skb2; + } + dev_queue_xmit(skb); + } +} + +/** + * llc_conn_service - finds transition and changes state of connection + * @sk: connection + * @skb: happened event + * + * This function finds transition that matches with happened event, then + * executes related actions and finally changes state of connection. + * Returns 0 for success, 1 for failure. + */ +static int llc_conn_service(struct sock *sk, struct sk_buff *skb) +{ + int rc = 1; + struct llc_sock *llc = llc_sk(sk); + struct llc_conn_state_trans *trans; + + if (llc->state > NBR_CONN_STATES) + goto out; + rc = 0; + trans = llc_qualify_conn_ev(sk, skb); + if (trans) { + rc = llc_exec_conn_trans_actions(sk, trans, skb); + if (!rc && trans->next_state != NO_STATE_CHANGE) { + llc->state = trans->next_state; + if (!llc_data_accept_state(llc->state)) + sk->sk_state_change(sk); + } + } +out: + return rc; +} + +/** + * llc_qualify_conn_ev - finds transition for event + * @sk: connection + * @skb: happened event + * + * This function finds transition that matches with happened event. + * Returns pointer to found transition on success, %NULL otherwise. 
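+ *
+ * As a concrete example, take llc_d_conn_state_trans_2 from the tables
+ * above: it is selected only if its .ev handler,
+ * llc_conn_ev_rx_ua_rsp_fbit_set_x(), returns 0 for this skb and every
+ * qualifier in llc_d_conn_ev_qfyrs_2 (llc_conn_ev_qlfy_p_flag_eq_f,
+ * llc_conn_ev_qlfy_cause_flag_eq_1, llc_conn_ev_qlfy_set_status_disc)
+ * returns 0 as well; llc_conn_service() then runs its actions and moves
+ * the connection to LLC_CONN_STATE_ADM.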
+ */ +static struct llc_conn_state_trans *llc_qualify_conn_ev(struct sock *sk, + struct sk_buff *skb) +{ + struct llc_conn_state_trans **next_trans; + const llc_conn_ev_qfyr_t *next_qualifier; + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + struct llc_sock *llc = llc_sk(sk); + struct llc_conn_state *curr_state = + &llc_conn_state_table[llc->state - 1]; + + /* search thru events for this state until + * list exhausted or until no more + */ + for (next_trans = curr_state->transitions + + llc_find_offset(llc->state - 1, ev->type); + (*next_trans)->ev; next_trans++) { + if (!((*next_trans)->ev)(sk, skb)) { + /* got POSSIBLE event match; the event may require + * qualification based on the values of a number of + * state flags; if all qualifications are met (i.e., + * if all qualifying functions return success, or 0, + * then this is THE event we're looking for + */ + for (next_qualifier = (*next_trans)->ev_qualifiers; + next_qualifier && *next_qualifier && + !(*next_qualifier)(sk, skb); next_qualifier++) + /* nothing */; + if (!next_qualifier || !*next_qualifier) + /* all qualifiers executed successfully; this is + * our transition; return it so we can perform + * the associated actions & change the state + */ + return *next_trans; + } + } + return NULL; +} + +/** + * llc_exec_conn_trans_actions - executes related actions + * @sk: connection + * @trans: transition that it's actions must be performed + * @skb: event + * + * Executes actions that is related to happened event. Returns 0 for + * success, 1 to indicate failure of at least one action. + */ +static int llc_exec_conn_trans_actions(struct sock *sk, + struct llc_conn_state_trans *trans, + struct sk_buff *skb) +{ + int rc = 0; + const llc_conn_action_t *next_action; + + for (next_action = trans->ev_actions; + next_action && *next_action; next_action++) { + int rc2 = (*next_action)(sk, skb); + + if (rc2 == 2) { + rc = rc2; + break; + } else if (rc2) + rc = 1; + } + return rc; +} + +static inline bool llc_estab_match(const struct llc_sap *sap, + const struct llc_addr *daddr, + const struct llc_addr *laddr, + const struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + return llc->laddr.lsap == laddr->lsap && + llc->daddr.lsap == daddr->lsap && + ether_addr_equal(llc->laddr.mac, laddr->mac) && + ether_addr_equal(llc->daddr.mac, daddr->mac); +} + +/** + * __llc_lookup_established - Finds connection for the remote/local sap/mac + * @sap: SAP + * @daddr: address of remote LLC (MAC + SAP) + * @laddr: address of local LLC (MAC + SAP) + * + * Search connection list of the SAP and finds connection using the remote + * mac, remote sap, local mac, and local sap. Returns pointer for + * connection found, %NULL otherwise. + * Caller has to make sure local_bh is disabled. 
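+ *
+ * A sketch of how the address arguments are typically built, following
+ * llc_establish_connection() in llc_if.c later in this patch (lmac,
+ * dmac and dsap stand for caller-supplied values):
+ *
+ *	struct llc_addr laddr, daddr;
+ *
+ *	laddr.lsap = llc->sap->laddr.lsap;
+ *	daddr.lsap = dsap;
+ *	memcpy(laddr.mac, lmac, sizeof(laddr.mac));
+ *	memcpy(daddr.mac, dmac, sizeof(daddr.mac));
+ *	sk = llc_lookup_established(llc->sap, &daddr, &laddr);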
+ */ +static struct sock *__llc_lookup_established(struct llc_sap *sap, + struct llc_addr *daddr, + struct llc_addr *laddr) +{ + struct sock *rc; + struct hlist_nulls_node *node; + int slot = llc_sk_laddr_hashfn(sap, laddr); + struct hlist_nulls_head *laddr_hb = &sap->sk_laddr_hash[slot]; + + rcu_read_lock(); +again: + sk_nulls_for_each_rcu(rc, node, laddr_hb) { + if (llc_estab_match(sap, daddr, laddr, rc)) { + /* Extra checks required by SLAB_TYPESAFE_BY_RCU */ + if (unlikely(!refcount_inc_not_zero(&rc->sk_refcnt))) + goto again; + if (unlikely(llc_sk(rc)->sap != sap || + !llc_estab_match(sap, daddr, laddr, rc))) { + sock_put(rc); + continue; + } + goto found; + } + } + rc = NULL; + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (unlikely(get_nulls_value(node) != slot)) + goto again; +found: + rcu_read_unlock(); + return rc; +} + +struct sock *llc_lookup_established(struct llc_sap *sap, + struct llc_addr *daddr, + struct llc_addr *laddr) +{ + struct sock *sk; + + local_bh_disable(); + sk = __llc_lookup_established(sap, daddr, laddr); + local_bh_enable(); + return sk; +} + +static inline bool llc_listener_match(const struct llc_sap *sap, + const struct llc_addr *laddr, + const struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + return sk->sk_type == SOCK_STREAM && sk->sk_state == TCP_LISTEN && + llc->laddr.lsap == laddr->lsap && + ether_addr_equal(llc->laddr.mac, laddr->mac); +} + +static struct sock *__llc_lookup_listener(struct llc_sap *sap, + struct llc_addr *laddr) +{ + struct sock *rc; + struct hlist_nulls_node *node; + int slot = llc_sk_laddr_hashfn(sap, laddr); + struct hlist_nulls_head *laddr_hb = &sap->sk_laddr_hash[slot]; + + rcu_read_lock(); +again: + sk_nulls_for_each_rcu(rc, node, laddr_hb) { + if (llc_listener_match(sap, laddr, rc)) { + /* Extra checks required by SLAB_TYPESAFE_BY_RCU */ + if (unlikely(!refcount_inc_not_zero(&rc->sk_refcnt))) + goto again; + if (unlikely(llc_sk(rc)->sap != sap || + !llc_listener_match(sap, laddr, rc))) { + sock_put(rc); + continue; + } + goto found; + } + } + rc = NULL; + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (unlikely(get_nulls_value(node) != slot)) + goto again; +found: + rcu_read_unlock(); + return rc; +} + +/** + * llc_lookup_listener - Finds listener for local MAC + SAP + * @sap: SAP + * @laddr: address of local LLC (MAC + SAP) + * + * Search connection list of the SAP and finds connection listening on + * local mac, and local sap. Returns pointer for parent socket found, + * %NULL otherwise. + * Caller has to make sure local_bh is disabled. + */ +static struct sock *llc_lookup_listener(struct llc_sap *sap, + struct llc_addr *laddr) +{ + static struct llc_addr null_addr; + struct sock *rc = __llc_lookup_listener(sap, laddr); + + if (!rc) + rc = __llc_lookup_listener(sap, &null_addr); + + return rc; +} + +static struct sock *__llc_lookup(struct llc_sap *sap, + struct llc_addr *daddr, + struct llc_addr *laddr) +{ + struct sock *sk = __llc_lookup_established(sap, daddr, laddr); + + return sk ? : llc_lookup_listener(sap, laddr); +} + +/** + * llc_data_accept_state - designates if in this state data can be sent. + * @state: state of connection. + * + * Returns 0 if data can be sent, 1 otherwise. 
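+ *
+ * Concretely, for the connection states used in this file:
+ *
+ *	LLC_CONN_STATE_NORMAL, _BUSY, _REJ  ->  0 (data transfer allowed)
+ *	any other state                     ->  1 (data refused)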
+ */ +u8 llc_data_accept_state(u8 state) +{ + return state != LLC_CONN_STATE_NORMAL && state != LLC_CONN_STATE_BUSY && + state != LLC_CONN_STATE_REJ; +} + +/** + * llc_find_next_offset - finds offset for next category of transitions + * @state: state table. + * @offset: start offset. + * + * Finds offset of next category of transitions in transition table. + * Returns the start index of next category. + */ +static u16 __init llc_find_next_offset(struct llc_conn_state *state, u16 offset) +{ + u16 cnt = 0; + struct llc_conn_state_trans **next_trans; + + for (next_trans = state->transitions + offset; + (*next_trans)->ev; next_trans++) + ++cnt; + return cnt; +} + +/** + * llc_build_offset_table - builds offset table of connection + * + * Fills offset table of connection state transition table + * (llc_offset_table). + */ +void __init llc_build_offset_table(void) +{ + struct llc_conn_state *curr_state; + int state, ev_type, next_offset; + + for (state = 0; state < NBR_CONN_STATES; state++) { + curr_state = &llc_conn_state_table[state]; + next_offset = 0; + for (ev_type = 0; ev_type < NBR_CONN_EV; ev_type++) { + llc_offset_table[state][ev_type] = next_offset; + next_offset += llc_find_next_offset(curr_state, + next_offset) + 1; + } + } +} + +/** + * llc_find_offset - finds start offset of category of transitions + * @state: state of connection + * @ev_type: type of happened event + * + * Finds start offset of desired category of transitions. Returns the + * desired start offset. + */ +static int llc_find_offset(int state, int ev_type) +{ + int rc = 0; + /* at this stage, llc_offset_table[..][2] is not important. it is for + * init_pf_cycle and I don't know what is it. + */ + switch (ev_type) { + case LLC_CONN_EV_TYPE_PRIM: + rc = llc_offset_table[state][0]; break; + case LLC_CONN_EV_TYPE_PDU: + rc = llc_offset_table[state][4]; break; + case LLC_CONN_EV_TYPE_SIMPLE: + rc = llc_offset_table[state][1]; break; + case LLC_CONN_EV_TYPE_P_TMR: + case LLC_CONN_EV_TYPE_ACK_TMR: + case LLC_CONN_EV_TYPE_REJ_TMR: + case LLC_CONN_EV_TYPE_BUSY_TMR: + rc = llc_offset_table[state][3]; break; + } + return rc; +} + +/** + * llc_sap_add_socket - adds a socket to a SAP + * @sap: SAP + * @sk: socket + * + * This function adds a socket to the hash tables of a SAP. + */ +void llc_sap_add_socket(struct llc_sap *sap, struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + struct hlist_head *dev_hb = llc_sk_dev_hash(sap, llc->dev->ifindex); + struct hlist_nulls_head *laddr_hb = llc_sk_laddr_hash(sap, &llc->laddr); + + llc_sap_hold(sap); + llc_sk(sk)->sap = sap; + + spin_lock_bh(&sap->sk_lock); + sock_set_flag(sk, SOCK_RCU_FREE); + sap->sk_count++; + sk_nulls_add_node_rcu(sk, laddr_hb); + hlist_add_head(&llc->dev_hash_node, dev_hb); + spin_unlock_bh(&sap->sk_lock); +} + +/** + * llc_sap_remove_socket - removes a socket from SAP + * @sap: SAP + * @sk: socket + * + * This function removes a connection from the hash tables of a SAP if + * the connection was in this list. + */ +void llc_sap_remove_socket(struct llc_sap *sap, struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + spin_lock_bh(&sap->sk_lock); + sk_nulls_del_node_init_rcu(sk); + hlist_del(&llc->dev_hash_node); + sap->sk_count--; + spin_unlock_bh(&sap->sk_lock); + llc_sap_put(sap); +} + +/** + * llc_conn_rcv - sends received pdus to the connection state machine + * @sk: current connection structure. + * @skb: received frame. + * + * Sends received pdus to the connection state machine. 
+ */ +static int llc_conn_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->type = LLC_CONN_EV_TYPE_PDU; + ev->reason = 0; + return llc_conn_state_process(sk, skb); +} + +static struct sock *llc_create_incoming_sock(struct sock *sk, + struct net_device *dev, + struct llc_addr *saddr, + struct llc_addr *daddr) +{ + struct sock *newsk = llc_sk_alloc(sock_net(sk), sk->sk_family, GFP_ATOMIC, + sk->sk_prot, 0); + struct llc_sock *newllc, *llc = llc_sk(sk); + + if (!newsk) + goto out; + newllc = llc_sk(newsk); + memcpy(&newllc->laddr, daddr, sizeof(newllc->laddr)); + memcpy(&newllc->daddr, saddr, sizeof(newllc->daddr)); + newllc->dev = dev; + dev_hold(dev); + llc_sap_add_socket(llc->sap, newsk); + llc_sap_hold(llc->sap); +out: + return newsk; +} + +void llc_conn_handler(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_addr saddr, daddr; + struct sock *sk; + + llc_pdu_decode_sa(skb, saddr.mac); + llc_pdu_decode_ssap(skb, &saddr.lsap); + llc_pdu_decode_da(skb, daddr.mac); + llc_pdu_decode_dsap(skb, &daddr.lsap); + + sk = __llc_lookup(sap, &saddr, &daddr); + if (!sk) + goto drop; + + bh_lock_sock(sk); + /* + * This has to be done here and not at the upper layer ->accept + * method because of the way the PROCOM state machine works: + * it needs to set several state variables (see, for instance, + * llc_adm_actions_2 in net/llc/llc_c_st.c) and send a packet to + * the originator of the new connection, and this state has to be + * in the newly created struct sock private area. -acme + */ + if (unlikely(sk->sk_state == TCP_LISTEN)) { + struct sock *newsk = llc_create_incoming_sock(sk, skb->dev, + &saddr, &daddr); + if (!newsk) + goto drop_unlock; + skb_set_owner_r(skb, newsk); + } else { + /* + * Can't be skb_set_owner_r, this will be done at the + * llc_conn_state_process function, later on, when we will use + * skb_queue_rcv_skb to send it to upper layers, this is + * another trick required to cope with how the PROCOM state + * machine works. -acme + */ + skb_orphan(skb); + sock_hold(sk); + skb->sk = sk; + skb->destructor = sock_efree; + } + if (!sock_owned_by_user(sk)) + llc_conn_rcv(sk, skb); + else { + dprintk("%s: adding to backlog...\n", __func__); + llc_set_backlog_type(skb, LLC_PACKET); + if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf))) + goto drop_unlock; + } +out: + bh_unlock_sock(sk); + sock_put(sk); + return; +drop: + kfree_skb(skb); + return; +drop_unlock: + kfree_skb(skb); + goto out; +} + +#undef LLC_REFCNT_DEBUG +#ifdef LLC_REFCNT_DEBUG +static atomic_t llc_sock_nr; +#endif + +/** + * llc_backlog_rcv - Processes rx frames and expired timers. + * @sk: LLC sock (p8022 connection) + * @skb: queued rx frame or event + * + * This function processes frames that has received and timers that has + * expired during sending an I pdu (refer to data_req_handler). frames + * queue by llc_rcv function (llc_mac.c) and timers queue by timer + * callback functions(llc_c_ac.c). 
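+ *
+ * For reference, frames land here when the socket is owned by the user:
+ * llc_conn_handler() above tags and queues them with
+ *
+ *	llc_set_backlog_type(skb, LLC_PACKET);
+ *	sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf));
+ *
+ * and llc_sk_init() below installs this function as sk->sk_backlog_rcv.
+ * The LLC_EVENT case is assumed to be queued analogously by the timer
+ * callbacks in llc_c_ac.c, which are not part of this file.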
+ */ +static int llc_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + int rc = 0; + struct llc_sock *llc = llc_sk(sk); + + if (likely(llc_backlog_type(skb) == LLC_PACKET)) { + if (likely(llc->state > 1)) /* not closed */ + rc = llc_conn_rcv(sk, skb); + else + goto out_kfree_skb; + } else if (llc_backlog_type(skb) == LLC_EVENT) { + /* timer expiration event */ + if (likely(llc->state > 1)) /* not closed */ + rc = llc_conn_state_process(sk, skb); + else + goto out_kfree_skb; + } else { + printk(KERN_ERR "%s: invalid skb in backlog\n", __func__); + goto out_kfree_skb; + } +out: + return rc; +out_kfree_skb: + kfree_skb(skb); + goto out; +} + +/** + * llc_sk_init - Initializes a socket with default llc values. + * @sk: socket to initialize. + * + * Initializes a socket with default llc values. + */ +static void llc_sk_init(struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + llc->state = LLC_CONN_STATE_ADM; + llc->inc_cntr = llc->dec_cntr = 2; + llc->dec_step = llc->connect_step = 1; + + timer_setup(&llc->ack_timer.timer, llc_conn_ack_tmr_cb, 0); + llc->ack_timer.expire = sysctl_llc2_ack_timeout; + + timer_setup(&llc->pf_cycle_timer.timer, llc_conn_pf_cycle_tmr_cb, 0); + llc->pf_cycle_timer.expire = sysctl_llc2_p_timeout; + + timer_setup(&llc->rej_sent_timer.timer, llc_conn_rej_tmr_cb, 0); + llc->rej_sent_timer.expire = sysctl_llc2_rej_timeout; + + timer_setup(&llc->busy_state_timer.timer, llc_conn_busy_tmr_cb, 0); + llc->busy_state_timer.expire = sysctl_llc2_busy_timeout; + + llc->n2 = 2; /* max retransmit */ + llc->k = 2; /* tx win size, will adjust dynam */ + llc->rw = 128; /* rx win size (opt and equal to + * tx_win of remote LLC) */ + skb_queue_head_init(&llc->pdu_unack_q); + sk->sk_backlog_rcv = llc_backlog_rcv; +} + +/** + * llc_sk_alloc - Allocates LLC sock + * @net: network namespace + * @family: upper layer protocol family + * @priority: for allocation (%GFP_KERNEL, %GFP_ATOMIC, etc) + * @prot: struct proto associated with this new sock instance + * @kern: is this to be a kernel socket? + * + * Allocates a LLC sock and initializes it. 
Returns the new LLC sock + * or %NULL if there's no memory available for one + */ +struct sock *llc_sk_alloc(struct net *net, int family, gfp_t priority, struct proto *prot, int kern) +{ + struct sock *sk = sk_alloc(net, family, priority, prot, kern); + + if (!sk) + goto out; + llc_sk_init(sk); + sock_init_data(NULL, sk); +#ifdef LLC_REFCNT_DEBUG + atomic_inc(&llc_sock_nr); + printk(KERN_DEBUG "LLC socket %p created in %s, now we have %d alive\n", sk, + __func__, atomic_read(&llc_sock_nr)); +#endif +out: + return sk; +} + +void llc_sk_stop_all_timers(struct sock *sk, bool sync) +{ + struct llc_sock *llc = llc_sk(sk); + + if (sync) { + del_timer_sync(&llc->pf_cycle_timer.timer); + del_timer_sync(&llc->ack_timer.timer); + del_timer_sync(&llc->rej_sent_timer.timer); + del_timer_sync(&llc->busy_state_timer.timer); + } else { + del_timer(&llc->pf_cycle_timer.timer); + del_timer(&llc->ack_timer.timer); + del_timer(&llc->rej_sent_timer.timer); + del_timer(&llc->busy_state_timer.timer); + } + + llc->ack_must_be_send = 0; + llc->ack_pf = 0; +} + +/** + * llc_sk_free - Frees a LLC socket + * @sk: - socket to free + * + * Frees a LLC socket + */ +void llc_sk_free(struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + llc->state = LLC_CONN_OUT_OF_SVC; + /* Stop all (possibly) running timers */ + llc_sk_stop_all_timers(sk, true); +#ifdef DEBUG_LLC_CONN_ALLOC + printk(KERN_INFO "%s: unackq=%d, txq=%d\n", __func__, + skb_queue_len(&llc->pdu_unack_q), + skb_queue_len(&sk->sk_write_queue)); +#endif + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&sk->sk_write_queue); + skb_queue_purge(&llc->pdu_unack_q); +#ifdef LLC_REFCNT_DEBUG + if (refcount_read(&sk->sk_refcnt) != 1) { + printk(KERN_DEBUG "Destruction of LLC sock %p delayed in %s, cnt=%d\n", + sk, __func__, refcount_read(&sk->sk_refcnt)); + printk(KERN_DEBUG "%d LLC sockets are still alive\n", + atomic_read(&llc_sock_nr)); + } else { + atomic_dec(&llc_sock_nr); + printk(KERN_DEBUG "LLC socket %p released in %s, %d are still alive\n", sk, + __func__, atomic_read(&llc_sock_nr)); + } +#endif + sock_put(sk); +} + +/** + * llc_sk_reset - resets a connection + * @sk: LLC socket to reset + * + * Resets a connection to the out of service state. Stops its timers + * and frees any frames in the queues of the connection. + */ +void llc_sk_reset(struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + llc_conn_ac_stop_all_timers(sk, NULL); + skb_queue_purge(&sk->sk_write_queue); + skb_queue_purge(&llc->pdu_unack_q); + llc->remote_busy_flag = 0; + llc->cause_flag = 0; + llc->retry_count = 0; + llc_conn_set_p_flag(sk, 0); + llc->f_flag = 0; + llc->s_flag = 0; + llc->ack_pf = 0; + llc->first_pdu_Ns = 0; + llc->ack_must_be_send = 0; + llc->dec_step = 1; + llc->inc_cntr = 2; + llc->dec_cntr = 2; + llc->X = 0; + llc->failed_data_req = 0 ; + llc->last_nr = 0; +} diff --git a/net/llc/llc_core.c b/net/llc/llc_core.c new file mode 100644 index 000000000..4f16d9c88 --- /dev/null +++ b/net/llc/llc_core.c @@ -0,0 +1,159 @@ +/* + * llc_core.c - Minimum needed routines for sap handling and module init/exit + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +LIST_HEAD(llc_sap_list); +static DEFINE_SPINLOCK(llc_sap_list_lock); + +/** + * llc_sap_alloc - allocates and initializes sap. + * + * Allocates and initializes sap. + */ +static struct llc_sap *llc_sap_alloc(void) +{ + struct llc_sap *sap = kzalloc(sizeof(*sap), GFP_ATOMIC); + int i; + + if (sap) { + /* sap->laddr.mac - leave as a null, it's filled by bind */ + sap->state = LLC_SAP_STATE_ACTIVE; + spin_lock_init(&sap->sk_lock); + for (i = 0; i < LLC_SK_LADDR_HASH_ENTRIES; i++) + INIT_HLIST_NULLS_HEAD(&sap->sk_laddr_hash[i], i); + refcount_set(&sap->refcnt, 1); + } + return sap; +} + +static struct llc_sap *__llc_sap_find(unsigned char sap_value) +{ + struct llc_sap *sap; + + list_for_each_entry(sap, &llc_sap_list, node) + if (sap->laddr.lsap == sap_value) + goto out; + sap = NULL; +out: + return sap; +} + +/** + * llc_sap_find - searches a SAP in station + * @sap_value: sap to be found + * + * Searches for a sap in the sap list of the LLC's station upon the sap ID. + * If the sap is found it will be refcounted and the user will have to do + * a llc_sap_put after use. + * Returns the sap or %NULL if not found. + */ +struct llc_sap *llc_sap_find(unsigned char sap_value) +{ + struct llc_sap *sap; + + rcu_read_lock_bh(); + sap = __llc_sap_find(sap_value); + if (!sap || !llc_sap_hold_safe(sap)) + sap = NULL; + rcu_read_unlock_bh(); + return sap; +} + +/** + * llc_sap_open - open interface to the upper layers. + * @lsap: SAP number. + * @func: rcv func for datalink protos + * + * Interface function to upper layer. Each one who wants to get a SAP + * (for example NetBEUI) should call this function. Returns the opened + * SAP for success, NULL for failure. + */ +struct llc_sap *llc_sap_open(unsigned char lsap, + int (*func)(struct sk_buff *skb, + struct net_device *dev, + struct packet_type *pt, + struct net_device *orig_dev)) +{ + struct llc_sap *sap = NULL; + + spin_lock_bh(&llc_sap_list_lock); + if (__llc_sap_find(lsap)) /* SAP already exists */ + goto out; + sap = llc_sap_alloc(); + if (!sap) + goto out; + sap->laddr.lsap = lsap; + sap->rcv_func = func; + list_add_tail_rcu(&sap->node, &llc_sap_list); +out: + spin_unlock_bh(&llc_sap_list_lock); + return sap; +} + +/** + * llc_sap_close - close interface for upper layers. + * @sap: SAP to be closed. + * + * Close interface function to upper layer. Each one who wants to + * close an open SAP (for example NetBEUI) should call this function. + * Removes this sap from the list of saps in the station and then + * frees the memory for this sap. + */ +void llc_sap_close(struct llc_sap *sap) +{ + WARN_ON(sap->sk_count); + + spin_lock_bh(&llc_sap_list_lock); + list_del_rcu(&sap->node); + spin_unlock_bh(&llc_sap_list_lock); + + kfree_rcu(sap, rcu); +} + +static struct packet_type llc_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_802_2), + .func = llc_rcv, +}; + +static int __init llc_init(void) +{ + dev_add_pack(&llc_packet_type); + return 0; +} + +static void __exit llc_exit(void) +{ + dev_remove_pack(&llc_packet_type); +} + +module_init(llc_init); +module_exit(llc_exit); + +EXPORT_SYMBOL(llc_sap_list); +EXPORT_SYMBOL(llc_sap_find); +EXPORT_SYMBOL(llc_sap_open); +EXPORT_SYMBOL(llc_sap_close); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Procom 1997, Jay Schullist 2001, Arnaldo C. 
Melo 2001-2003"); +MODULE_DESCRIPTION("LLC IEEE 802.2 core support"); diff --git a/net/llc/llc_if.c b/net/llc/llc_if.c new file mode 100644 index 000000000..dde9bf08a --- /dev/null +++ b/net/llc/llc_if.c @@ -0,0 +1,157 @@ +/* + * llc_if.c - Defines LLC interface to upper layer + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * llc_build_and_send_pkt - Connection data sending for upper layers. + * @sk: connection + * @skb: packet to send + * + * This function is called when upper layer wants to send data using + * connection oriented communication mode. During sending data, connection + * will be locked and received frames and expired timers will be queued. + * Returns 0 for success, -ECONNABORTED when the connection already + * closed and -EBUSY when sending data is not permitted in this state or + * LLC has send an I pdu with p bit set to 1 and is waiting for it's + * response. + * + * This function always consumes a reference to the skb. + */ +int llc_build_and_send_pkt(struct sock *sk, struct sk_buff *skb) +{ + struct llc_conn_state_ev *ev; + int rc = -ECONNABORTED; + struct llc_sock *llc = llc_sk(sk); + + if (unlikely(llc->state == LLC_CONN_STATE_ADM)) + goto out_free; + rc = -EBUSY; + if (unlikely(llc_data_accept_state(llc->state) || /* data_conn_refuse */ + llc->p_flag)) { + llc->failed_data_req = 1; + goto out_free; + } + ev = llc_conn_ev(skb); + ev->type = LLC_CONN_EV_TYPE_PRIM; + ev->prim = LLC_DATA_PRIM; + ev->prim_type = LLC_PRIM_TYPE_REQ; + skb->dev = llc->dev; + return llc_conn_state_process(sk, skb); + +out_free: + kfree_skb(skb); + return rc; +} + +/** + * llc_establish_connection - Called by upper layer to establish a conn + * @sk: connection + * @lmac: local mac address + * @dmac: destination mac address + * @dsap: destination sap + * + * Upper layer calls this to establish an LLC connection with a remote + * machine. This function packages a proper event and sends it connection + * component state machine. Success or failure of connection + * establishment will inform to upper layer via calling it's confirm + * function and passing proper information. 
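+ *
+ * A minimal, hypothetical caller sketch (a socket-layer caller such as
+ * the af_llc connect path is outside this file; the destination MAC
+ * and the 0x42 SAP below are placeholders):
+ *
+ *	struct llc_sock *llc = llc_sk(sk);
+ *	u8 dmac[ETH_ALEN] = { 0x02, 0x00, 0x00, 0x00, 0x00, 0x01 };
+ *	int rc;
+ *
+ *	rc = llc_establish_connection(sk, llc->dev->dev_addr, dmac, 0x42);
+ *	if (rc)
+ *		return rc;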
+ */ +int llc_establish_connection(struct sock *sk, const u8 *lmac, u8 *dmac, u8 dsap) +{ + int rc = -EISCONN; + struct llc_addr laddr, daddr; + struct sk_buff *skb; + struct llc_sock *llc = llc_sk(sk); + struct sock *existing; + + laddr.lsap = llc->sap->laddr.lsap; + daddr.lsap = dsap; + memcpy(daddr.mac, dmac, sizeof(daddr.mac)); + memcpy(laddr.mac, lmac, sizeof(laddr.mac)); + existing = llc_lookup_established(llc->sap, &daddr, &laddr); + if (existing) { + if (existing->sk_state == TCP_ESTABLISHED) { + sk = existing; + goto out_put; + } else + sock_put(existing); + } + sock_hold(sk); + rc = -ENOMEM; + skb = alloc_skb(0, GFP_ATOMIC); + if (skb) { + struct llc_conn_state_ev *ev = llc_conn_ev(skb); + + ev->type = LLC_CONN_EV_TYPE_PRIM; + ev->prim = LLC_CONN_PRIM; + ev->prim_type = LLC_PRIM_TYPE_REQ; + skb_set_owner_w(skb, sk); + rc = llc_conn_state_process(sk, skb); + } +out_put: + sock_put(sk); + return rc; +} + +/** + * llc_send_disc - Called by upper layer to close a connection + * @sk: connection to be closed + * + * Upper layer calls this when it wants to close an established LLC + * connection with a remote machine. This function packages a proper event + * and sends it to connection component state machine. Returns 0 for + * success, 1 otherwise. + */ +int llc_send_disc(struct sock *sk) +{ + u16 rc = 1; + struct llc_conn_state_ev *ev; + struct sk_buff *skb; + + sock_hold(sk); + if (sk->sk_type != SOCK_STREAM || sk->sk_state != TCP_ESTABLISHED || + llc_sk(sk)->state == LLC_CONN_STATE_ADM || + llc_sk(sk)->state == LLC_CONN_OUT_OF_SVC) + goto out; + /* + * Postpone unassigning the connection from its SAP and returning the + * connection until all ACTIONs have been completely executed + */ + skb = alloc_skb(0, GFP_ATOMIC); + if (!skb) + goto out; + skb_set_owner_w(skb, sk); + sk->sk_state = TCP_CLOSING; + ev = llc_conn_ev(skb); + ev->type = LLC_CONN_EV_TYPE_PRIM; + ev->prim = LLC_DISC_PRIM; + ev->prim_type = LLC_PRIM_TYPE_REQ; + rc = llc_conn_state_process(sk, skb); +out: + sock_put(sk); + return rc; +} diff --git a/net/llc/llc_input.c b/net/llc/llc_input.c new file mode 100644 index 000000000..51bccfb00 --- /dev/null +++ b/net/llc/llc_input.c @@ -0,0 +1,230 @@ +/* + * llc_input.c - Minimal input path for LLC + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include + +#if 0 +#define dprintk(args...) printk(KERN_DEBUG args) +#else +#define dprintk(args...) +#endif + +/* + * Packet handler for the station, registerable because in the minimal + * LLC core that is taking shape only the very minimal subset of LLC that + * is needed for things like IPX, Appletalk, etc will stay, with all the + * rest in the llc1 and llc2 modules. + */ +static void (*llc_station_handler)(struct sk_buff *skb); + +/* + * Packet handlers for LLC_DEST_SAP and LLC_DEST_CONN. 
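+ *
+ * A sketch of how these two slots are expected to be filled when the
+ * llc2 component initialises (the registration site is not part of
+ * this patch, and llc_sap_handler is assumed to be the datagram
+ * counterpart of llc_conn_handler from llc_conn.c):
+ *
+ *	llc_add_pack(LLC_DEST_CONN, llc_conn_handler);
+ *	llc_add_pack(LLC_DEST_SAP, llc_sap_handler);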
+ */ +static void (*llc_type_handlers[2])(struct llc_sap *sap, + struct sk_buff *skb); + +void llc_add_pack(int type, void (*handler)(struct llc_sap *sap, + struct sk_buff *skb)) +{ + smp_wmb(); /* ensure initialisation is complete before it's called */ + if (type == LLC_DEST_SAP || type == LLC_DEST_CONN) + llc_type_handlers[type - 1] = handler; +} + +void llc_remove_pack(int type) +{ + if (type == LLC_DEST_SAP || type == LLC_DEST_CONN) + llc_type_handlers[type - 1] = NULL; + synchronize_net(); +} + +void llc_set_station_handler(void (*handler)(struct sk_buff *skb)) +{ + /* Ensure initialisation is complete before it's called */ + if (handler) + smp_wmb(); + + llc_station_handler = handler; + + if (!handler) + synchronize_net(); +} + +/** + * llc_pdu_type - returns which LLC component must handle for PDU + * @skb: input skb + * + * This function returns which LLC component must handle this PDU. + */ +static __inline__ int llc_pdu_type(struct sk_buff *skb) +{ + int type = LLC_DEST_CONN; /* I-PDU or S-PDU type */ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + if ((pdu->ctrl_1 & LLC_PDU_TYPE_MASK) != LLC_PDU_TYPE_U) + goto out; + switch (LLC_U_PDU_CMD(pdu)) { + case LLC_1_PDU_CMD_XID: + case LLC_1_PDU_CMD_UI: + case LLC_1_PDU_CMD_TEST: + type = LLC_DEST_SAP; + break; + case LLC_2_PDU_CMD_SABME: + case LLC_2_PDU_CMD_DISC: + case LLC_2_PDU_RSP_UA: + case LLC_2_PDU_RSP_DM: + case LLC_2_PDU_RSP_FRMR: + break; + default: + type = LLC_DEST_INVALID; + break; + } +out: + return type; +} + +/** + * llc_fixup_skb - initializes skb pointers + * @skb: This argument points to incoming skb + * + * Initializes internal skb pointer to start of network layer by deriving + * length of LLC header; finds length of LLC control field in LLC header + * by looking at the two lowest-order bits of the first control field + * byte; field is either 3 or 4 bytes long. + */ +static inline int llc_fixup_skb(struct sk_buff *skb) +{ + u8 llc_len = 2; + struct llc_pdu_un *pdu; + + if (unlikely(!pskb_may_pull(skb, sizeof(*pdu)))) + return 0; + + pdu = (struct llc_pdu_un *)skb->data; + if ((pdu->ctrl_1 & LLC_PDU_TYPE_MASK) == LLC_PDU_TYPE_U) + llc_len = 1; + llc_len += 2; + + if (unlikely(!pskb_may_pull(skb, llc_len))) + return 0; + + skb->transport_header += llc_len; + skb_pull(skb, llc_len); + if (skb->protocol == htons(ETH_P_802_2)) { + __be16 pdulen; + s32 data_size; + + if (skb->mac_len < ETH_HLEN) + return 0; + + pdulen = eth_hdr(skb)->h_proto; + data_size = ntohs(pdulen) - llc_len; + + if (data_size < 0 || + !pskb_may_pull(skb, data_size)) + return 0; + if (unlikely(pskb_trim_rcsum(skb, data_size))) + return 0; + } + return 1; +} + +/** + * llc_rcv - 802.2 entry point from net lower layers + * @skb: received pdu + * @dev: device that receive pdu + * @pt: packet type + * @orig_dev: the original receive net device + * + * When the system receives a 802.2 frame this function is called. It + * checks SAP and connection of received pdu and passes frame to + * llc_{station,sap,conn}_rcv for sending to proper state machine. If + * the frame is related to a busy connection (a connection is sending + * data now), it queues this frame in the connection's backlog. 
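+ *
+ * Upper layers that only need the connectionless path install the
+ * rcv_func consulted below through llc_sap_open() from llc_core.c; a
+ * hypothetical registration sketch (ipx_rcv and the 0xE0 SAP are
+ * illustrative only):
+ *
+ *	struct llc_sap *sap = llc_sap_open(0xE0, ipx_rcv);
+ *
+ *	if (!sap)
+ *		return -EBUSY;	(the SAP is already taken)
+ *	...
+ *	llc_sap_close(sap);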
+ */ +int llc_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct llc_sap *sap; + struct llc_pdu_sn *pdu; + int dest; + int (*rcv)(struct sk_buff *, struct net_device *, + struct packet_type *, struct net_device *); + void (*sta_handler)(struct sk_buff *skb); + void (*sap_handler)(struct llc_sap *sap, struct sk_buff *skb); + + /* + * When the interface is in promisc. mode, drop all the crap that it + * receives, do not try to analyse it. + */ + if (unlikely(skb->pkt_type == PACKET_OTHERHOST)) { + dprintk("%s: PACKET_OTHERHOST\n", __func__); + goto drop; + } + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + goto out; + if (unlikely(!llc_fixup_skb(skb))) + goto drop; + pdu = llc_pdu_sn_hdr(skb); + if (unlikely(!pdu->dsap)) /* NULL DSAP, refer to station */ + goto handle_station; + sap = llc_sap_find(pdu->dsap); + if (unlikely(!sap)) {/* unknown SAP */ + dprintk("%s: llc_sap_find(%02X) failed!\n", __func__, + pdu->dsap); + goto drop; + } + /* + * First the upper layer protocols that don't need the full + * LLC functionality + */ + rcv = rcu_dereference(sap->rcv_func); + dest = llc_pdu_type(skb); + sap_handler = dest ? READ_ONCE(llc_type_handlers[dest - 1]) : NULL; + if (unlikely(!sap_handler)) { + if (rcv) + rcv(skb, dev, pt, orig_dev); + else + kfree_skb(skb); + } else { + if (rcv) { + struct sk_buff *cskb = skb_clone(skb, GFP_ATOMIC); + if (cskb) + rcv(cskb, dev, pt, orig_dev); + } + sap_handler(sap, skb); + } + llc_sap_put(sap); +out: + return 0; +drop: + kfree_skb(skb); + goto out; +handle_station: + sta_handler = READ_ONCE(llc_station_handler); + if (!sta_handler) + goto drop; + sta_handler(skb); + goto out; +} + +EXPORT_SYMBOL(llc_add_pack); +EXPORT_SYMBOL(llc_remove_pack); +EXPORT_SYMBOL(llc_set_station_handler); diff --git a/net/llc/llc_output.c b/net/llc/llc_output.c new file mode 100644 index 000000000..5a6466fc6 --- /dev/null +++ b/net/llc/llc_output.c @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * llc_output.c - LLC minimal output path + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include +#include + +/** + * llc_mac_hdr_init - fills MAC header fields + * @skb: Address of the frame to initialize its MAC header + * @sa: The MAC source address + * @da: The MAC destination address + * + * Fills MAC header fields, depending on MAC type. Returns 0, If MAC type + * is a valid type and initialization completes correctly 1, otherwise. + */ +int llc_mac_hdr_init(struct sk_buff *skb, + const unsigned char *sa, const unsigned char *da) +{ + int rc = -EINVAL; + + switch (skb->dev->type) { + case ARPHRD_ETHER: + case ARPHRD_LOOPBACK: + rc = dev_hard_header(skb, skb->dev, ETH_P_802_2, da, sa, + skb->len); + if (rc > 0) + rc = 0; + break; + default: + break; + } + return rc; +} + +/** + * llc_build_and_send_ui_pkt - unitdata request interface for upper layers + * @sap: sap to use + * @skb: packet to send + * @dmac: destination mac address + * @dsap: destination sap + * + * Upper layers calls this function when upper layer wants to send data + * using connection-less mode communication (UI pdu). 
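+ *
+ * Minimal caller sketch (error handling trimmed; sk, dev, sap, payload,
+ * len, dmac and dsap are placeholders supplied by the caller):
+ *
+ *	skb = llc_alloc_frame(sk, dev, LLC_PDU_TYPE_U, len);
+ *	if (skb) {
+ *		skb_put_data(skb, payload, len);
+ *		llc_build_and_send_ui_pkt(sap, skb, dmac, dsap);
+ *	}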
+ * + * Accept data frame from network layer to be sent using connection- + * less mode communication; timeout/retries handled by network layer; + * package primitive as an event and send to SAP event handler + */ +int llc_build_and_send_ui_pkt(struct llc_sap *sap, struct sk_buff *skb, + const unsigned char *dmac, unsigned char dsap) +{ + int rc; + llc_pdu_header_init(skb, LLC_PDU_TYPE_U, sap->laddr.lsap, + dsap, LLC_PDU_CMD); + llc_pdu_init_as_ui_cmd(skb); + rc = llc_mac_hdr_init(skb, skb->dev->dev_addr, dmac); + if (likely(!rc)) + rc = dev_queue_xmit(skb); + else + kfree_skb(skb); + return rc; +} + +EXPORT_SYMBOL(llc_mac_hdr_init); +EXPORT_SYMBOL(llc_build_and_send_ui_pkt); diff --git a/net/llc/llc_pdu.c b/net/llc/llc_pdu.c new file mode 100644 index 000000000..63749dde5 --- /dev/null +++ b/net/llc/llc_pdu.c @@ -0,0 +1,372 @@ +/* + * llc_pdu.c - access to PDU internals + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ + +#include +#include + +static void llc_pdu_decode_pdu_type(struct sk_buff *skb, u8 *type); +static u8 llc_pdu_get_pf_bit(struct llc_pdu_sn *pdu); + +void llc_pdu_set_cmd_rsp(struct sk_buff *skb, u8 pdu_type) +{ + llc_pdu_un_hdr(skb)->ssap |= pdu_type; +} + +/** + * llc_pdu_set_pf_bit - sets poll/final bit in LLC header + * @skb: Frame to set bit in + * @bit_value: poll/final bit (0 or 1). + * + * This function sets poll/final bit in LLC header (based on type of PDU). + * in I or S pdus, p/f bit is right bit of fourth byte in header. in U + * pdus p/f bit is fifth bit of third byte. + */ +void llc_pdu_set_pf_bit(struct sk_buff *skb, u8 bit_value) +{ + u8 pdu_type; + struct llc_pdu_sn *pdu; + + llc_pdu_decode_pdu_type(skb, &pdu_type); + pdu = llc_pdu_sn_hdr(skb); + + switch (pdu_type) { + case LLC_PDU_TYPE_I: + case LLC_PDU_TYPE_S: + pdu->ctrl_2 = (pdu->ctrl_2 & 0xFE) | bit_value; + break; + case LLC_PDU_TYPE_U: + pdu->ctrl_1 |= (pdu->ctrl_1 & 0xEF) | (bit_value << 4); + break; + } +} + +/** + * llc_pdu_decode_pf_bit - extracs poll/final bit from LLC header + * @skb: input skb that p/f bit must be extracted from it + * @pf_bit: poll/final bit (0 or 1) + * + * This function extracts poll/final bit from LLC header (based on type of + * PDU). In I or S pdus, p/f bit is right bit of fourth byte in header. In + * U pdus p/f bit is fifth bit of third byte. + */ +void llc_pdu_decode_pf_bit(struct sk_buff *skb, u8 *pf_bit) +{ + u8 pdu_type; + struct llc_pdu_sn *pdu; + + llc_pdu_decode_pdu_type(skb, &pdu_type); + pdu = llc_pdu_sn_hdr(skb); + + switch (pdu_type) { + case LLC_PDU_TYPE_I: + case LLC_PDU_TYPE_S: + *pf_bit = pdu->ctrl_2 & LLC_S_PF_BIT_MASK; + break; + case LLC_PDU_TYPE_U: + *pf_bit = (pdu->ctrl_1 & LLC_U_PF_BIT_MASK) >> 4; + break; + } +} + +/** + * llc_pdu_init_as_disc_cmd - Builds DISC PDU + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * + * Builds a pdu frame as a DISC command. 
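+ *
+ * Sketch of how the helpers in this file compose for a U-format command
+ * (sk, dev, sap, dsap and dmac are placeholders; real callers in the
+ * connection state machine may queue the frame rather than transmit it
+ * directly):
+ *
+ *	skb = llc_alloc_frame(sk, dev, LLC_PDU_TYPE_U, 0);
+ *	if (skb) {
+ *		llc_pdu_header_init(skb, LLC_PDU_TYPE_U, sap->laddr.lsap,
+ *				    dsap, LLC_PDU_CMD);
+ *		llc_pdu_init_as_disc_cmd(skb, 1);
+ *		if (!llc_mac_hdr_init(skb, dev->dev_addr, dmac))
+ *			dev_queue_xmit(skb);
+ *	}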
+ */ +void llc_pdu_init_as_disc_cmd(struct sk_buff *skb, u8 p_bit) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_U; + pdu->ctrl_1 |= LLC_2_PDU_CMD_DISC; + pdu->ctrl_1 |= ((p_bit & 1) << 4) & LLC_U_PF_BIT_MASK; +} + +/** + * llc_pdu_init_as_i_cmd - builds I pdu + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * @ns: The sequence number of the data PDU + * @nr: The seq. number of the expected I PDU from the remote + * + * Builds a pdu frame as an I command. + */ +void llc_pdu_init_as_i_cmd(struct sk_buff *skb, u8 p_bit, u8 ns, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_I; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= (p_bit & LLC_I_PF_BIT_MASK); /* p/f bit */ + pdu->ctrl_1 |= (ns << 1) & 0xFE; /* set N(S) in bits 2..8 */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_rej_cmd - builds REJ PDU + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * @nr: The seq. number of the expected I PDU from the remote + * + * Builds a pdu frame as a REJ command. + */ +void llc_pdu_init_as_rej_cmd(struct sk_buff *skb, u8 p_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_CMD_REJ; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= p_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_rnr_cmd - builds RNR pdu + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * @nr: The seq. number of the expected I PDU from the remote + * + * Builds a pdu frame as an RNR command. + */ +void llc_pdu_init_as_rnr_cmd(struct sk_buff *skb, u8 p_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_CMD_RNR; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= p_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_rr_cmd - Builds RR pdu + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * @nr: The seq. number of the expected I PDU from the remote + * + * Builds a pdu frame as an RR command. + */ +void llc_pdu_init_as_rr_cmd(struct sk_buff *skb, u8 p_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_CMD_RR; + pdu->ctrl_2 = p_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_sabme_cmd - builds SABME pdu + * @skb: Address of the skb to build + * @p_bit: The P bit to set in the PDU + * + * Builds a pdu frame as an SABME command. + */ +void llc_pdu_init_as_sabme_cmd(struct sk_buff *skb, u8 p_bit) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_U; + pdu->ctrl_1 |= LLC_2_PDU_CMD_SABME; + pdu->ctrl_1 |= ((p_bit & 1) << 4) & LLC_U_PF_BIT_MASK; +} + +/** + * llc_pdu_init_as_dm_rsp - builds DM response pdu + * @skb: Address of the skb to build + * @f_bit: The F bit to set in the PDU + * + * Builds a pdu frame as a DM response. 
+ */ +void llc_pdu_init_as_dm_rsp(struct sk_buff *skb, u8 f_bit) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_U; + pdu->ctrl_1 |= LLC_2_PDU_RSP_DM; + pdu->ctrl_1 |= ((f_bit & 1) << 4) & LLC_U_PF_BIT_MASK; +} + +/** + * llc_pdu_init_as_frmr_rsp - builds FRMR response PDU + * @skb: Address of the frame to build + * @prev_pdu: The rejected PDU frame + * @f_bit: The F bit to set in the PDU + * @vs: tx state vari value for the data link conn at the rejecting LLC + * @vr: rx state var value for the data link conn at the rejecting LLC + * @vzyxw: completely described in the IEEE Std 802.2 document (Pg 55) + * + * Builds a pdu frame as a FRMR response. + */ +void llc_pdu_init_as_frmr_rsp(struct sk_buff *skb, struct llc_pdu_sn *prev_pdu, + u8 f_bit, u8 vs, u8 vr, u8 vzyxw) +{ + struct llc_frmr_info *frmr_info; + u8 prev_pf = 0; + u8 *ctrl; + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_U; + pdu->ctrl_1 |= LLC_2_PDU_RSP_FRMR; + pdu->ctrl_1 |= ((f_bit & 1) << 4) & LLC_U_PF_BIT_MASK; + + frmr_info = (struct llc_frmr_info *)&pdu->ctrl_2; + ctrl = (u8 *)&prev_pdu->ctrl_1; + FRMR_INFO_SET_REJ_CNTRL(frmr_info,ctrl); + FRMR_INFO_SET_Vs(frmr_info, vs); + FRMR_INFO_SET_Vr(frmr_info, vr); + prev_pf = llc_pdu_get_pf_bit(prev_pdu); + FRMR_INFO_SET_C_R_BIT(frmr_info, prev_pf); + FRMR_INFO_SET_INVALID_PDU_CTRL_IND(frmr_info, vzyxw); + FRMR_INFO_SET_INVALID_PDU_INFO_IND(frmr_info, vzyxw); + FRMR_INFO_SET_PDU_INFO_2LONG_IND(frmr_info, vzyxw); + FRMR_INFO_SET_PDU_INVALID_Nr_IND(frmr_info, vzyxw); + FRMR_INFO_SET_PDU_INVALID_Ns_IND(frmr_info, vzyxw); + skb_put(skb, sizeof(struct llc_frmr_info)); +} + +/** + * llc_pdu_init_as_rr_rsp - builds RR response pdu + * @skb: Address of the skb to build + * @f_bit: The F bit to set in the PDU + * @nr: The seq. number of the expected data PDU from the remote + * + * Builds a pdu frame as an RR response. + */ +void llc_pdu_init_as_rr_rsp(struct sk_buff *skb, u8 f_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_RSP_RR; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= f_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_rej_rsp - builds REJ response pdu + * @skb: Address of the skb to build + * @f_bit: The F bit to set in the PDU + * @nr: The seq. number of the expected data PDU from the remote + * + * Builds a pdu frame as a REJ response. + */ +void llc_pdu_init_as_rej_rsp(struct sk_buff *skb, u8 f_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_RSP_REJ; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= f_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_rnr_rsp - builds RNR response pdu + * @skb: Address of the frame to build + * @f_bit: The F bit to set in the PDU + * @nr: The seq. number of the expected data PDU from the remote + * + * Builds a pdu frame as an RNR response. 
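+ *
+ * The S-format responses in this file follow the same composition pattern
+ * as the U-format sketch earlier in the file, with the direction set to
+ * response and the receive sequence number filled in (sk, dev, sap, dsap,
+ * dmac, f_bit and vR are placeholders for caller state):
+ *
+ *	skb = llc_alloc_frame(sk, dev, LLC_PDU_TYPE_S, 0);
+ *	if (skb) {
+ *		llc_pdu_header_init(skb, LLC_PDU_TYPE_S, sap->laddr.lsap,
+ *				    dsap, LLC_PDU_RSP);
+ *		llc_pdu_init_as_rnr_rsp(skb, f_bit, vR);
+ *		if (!llc_mac_hdr_init(skb, dev->dev_addr, dmac))
+ *			dev_queue_xmit(skb);
+ *	}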
+ */ +void llc_pdu_init_as_rnr_rsp(struct sk_buff *skb, u8 f_bit, u8 nr) +{ + struct llc_pdu_sn *pdu = llc_pdu_sn_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_S; + pdu->ctrl_1 |= LLC_2_PDU_RSP_RNR; + pdu->ctrl_2 = 0; + pdu->ctrl_2 |= f_bit & LLC_S_PF_BIT_MASK; + pdu->ctrl_1 &= 0x0F; /* setting bits 5..8 to zero(reserved) */ + pdu->ctrl_2 |= (nr << 1) & 0xFE; /* set N(R) in bits 10..16 */ +} + +/** + * llc_pdu_init_as_ua_rsp - builds UA response pdu + * @skb: Address of the frame to build + * @f_bit: The F bit to set in the PDU + * + * Builds a pdu frame as a UA response. + */ +void llc_pdu_init_as_ua_rsp(struct sk_buff *skb, u8 f_bit) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + pdu->ctrl_1 = LLC_PDU_TYPE_U; + pdu->ctrl_1 |= LLC_2_PDU_RSP_UA; + pdu->ctrl_1 |= ((f_bit & 1) << 4) & LLC_U_PF_BIT_MASK; +} + +/** + * llc_pdu_decode_pdu_type - designates PDU type + * @skb: input skb that type of it must be designated. + * @type: type of PDU (output argument). + * + * This function designates type of PDU (I, S or U). + */ +static void llc_pdu_decode_pdu_type(struct sk_buff *skb, u8 *type) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + if (pdu->ctrl_1 & 1) { + if ((pdu->ctrl_1 & LLC_PDU_TYPE_U) == LLC_PDU_TYPE_U) + *type = LLC_PDU_TYPE_U; + else + *type = LLC_PDU_TYPE_S; + } else + *type = LLC_PDU_TYPE_I; +} + +/** + * llc_pdu_get_pf_bit - extracts p/f bit of input PDU + * @pdu: pointer to LLC header. + * + * This function extracts p/f bit of input PDU. at first examines type of + * PDU and then extracts p/f bit. Returns the p/f bit. + */ +static u8 llc_pdu_get_pf_bit(struct llc_pdu_sn *pdu) +{ + u8 pdu_type; + u8 pf_bit = 0; + + if (pdu->ctrl_1 & 1) { + if ((pdu->ctrl_1 & LLC_PDU_TYPE_U) == LLC_PDU_TYPE_U) + pdu_type = LLC_PDU_TYPE_U; + else + pdu_type = LLC_PDU_TYPE_S; + } else + pdu_type = LLC_PDU_TYPE_I; + switch (pdu_type) { + case LLC_PDU_TYPE_I: + case LLC_PDU_TYPE_S: + pf_bit = pdu->ctrl_2 & LLC_S_PF_BIT_MASK; + break; + case LLC_PDU_TYPE_U: + pf_bit = (pdu->ctrl_1 & LLC_U_PF_BIT_MASK) >> 4; + break; + } + return pf_bit; +} diff --git a/net/llc/llc_proc.c b/net/llc/llc_proc.c new file mode 100644 index 000000000..07e9abb59 --- /dev/null +++ b/net/llc/llc_proc.c @@ -0,0 +1,251 @@ +/* + * proc_llc.c - proc interface for LLC + * + * Copyright (c) 2001 by Jay Schulist + * 2002-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void llc_ui_format_mac(struct seq_file *seq, const u8 *addr) +{ + seq_printf(seq, "%pM", addr); +} + +static struct sock *llc_get_sk_idx(loff_t pos) +{ + struct llc_sap *sap; + struct sock *sk = NULL; + int i; + + list_for_each_entry_rcu(sap, &llc_sap_list, node) { + spin_lock_bh(&sap->sk_lock); + for (i = 0; i < LLC_SK_LADDR_HASH_ENTRIES; i++) { + struct hlist_nulls_head *head = &sap->sk_laddr_hash[i]; + struct hlist_nulls_node *node; + + sk_nulls_for_each(sk, node, head) { + if (!pos) + goto found; /* keep the lock */ + --pos; + } + } + spin_unlock_bh(&sap->sk_lock); + } + sk = NULL; +found: + return sk; +} + +static void *llc_seq_start(struct seq_file *seq, loff_t *pos) __acquires(RCU) +{ + loff_t l = *pos; + + rcu_read_lock_bh(); + return l ? llc_get_sk_idx(--l) : SEQ_START_TOKEN; +} + +static struct sock *laddr_hash_next(struct llc_sap *sap, int bucket) +{ + struct hlist_nulls_node *node; + struct sock *sk = NULL; + + while (++bucket < LLC_SK_LADDR_HASH_ENTRIES) + sk_nulls_for_each(sk, node, &sap->sk_laddr_hash[bucket]) + goto out; + +out: + return sk; +} + +static void *llc_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock* sk, *next; + struct llc_sock *llc; + struct llc_sap *sap; + + ++*pos; + if (v == SEQ_START_TOKEN) { + sk = llc_get_sk_idx(0); + goto out; + } + sk = v; + next = sk_nulls_next(sk); + if (next) { + sk = next; + goto out; + } + llc = llc_sk(sk); + sap = llc->sap; + sk = laddr_hash_next(sap, llc_sk_laddr_hashfn(sap, &llc->laddr)); + if (sk) + goto out; + spin_unlock_bh(&sap->sk_lock); + list_for_each_entry_continue_rcu(sap, &llc_sap_list, node) { + spin_lock_bh(&sap->sk_lock); + sk = laddr_hash_next(sap, -1); + if (sk) + break; /* keep the lock */ + spin_unlock_bh(&sap->sk_lock); + } +out: + return sk; +} + +static void llc_seq_stop(struct seq_file *seq, void *v) +{ + if (v && v != SEQ_START_TOKEN) { + struct sock *sk = v; + struct llc_sock *llc = llc_sk(sk); + struct llc_sap *sap = llc->sap; + + spin_unlock_bh(&sap->sk_lock); + } + rcu_read_unlock_bh(); +} + +static int llc_seq_socket_show(struct seq_file *seq, void *v) +{ + struct sock* sk; + struct llc_sock *llc; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "SKt Mc local_mac_sap remote_mac_sap " + " tx_queue rx_queue st uid link\n"); + goto out; + } + sk = v; + llc = llc_sk(sk); + + /* FIXME: check if the address is multicast */ + seq_printf(seq, "%2X %2X ", sk->sk_type, 0); + + if (llc->dev) + llc_ui_format_mac(seq, llc->dev->dev_addr); + else { + u8 addr[6] = {0,0,0,0,0,0}; + llc_ui_format_mac(seq, addr); + } + seq_printf(seq, "@%02X ", llc->sap->laddr.lsap); + llc_ui_format_mac(seq, llc->daddr.mac); + seq_printf(seq, "@%02X %8d %8d %2d %3u %4d\n", llc->daddr.lsap, + sk_wmem_alloc_get(sk), + sk_rmem_alloc_get(sk) - llc->copied_seq, + sk->sk_state, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), + llc->link); +out: + return 0; +} + +static const char *const llc_conn_state_names[] = { + [LLC_CONN_STATE_ADM] = "adm", + [LLC_CONN_STATE_SETUP] = "setup", + [LLC_CONN_STATE_NORMAL] = "normal", + [LLC_CONN_STATE_BUSY] = "busy", + [LLC_CONN_STATE_REJ] = "rej", + [LLC_CONN_STATE_AWAIT] = "await", + [LLC_CONN_STATE_AWAIT_BUSY] = "await_busy", + [LLC_CONN_STATE_AWAIT_REJ] = "await_rej", + [LLC_CONN_STATE_D_CONN] = "d_conn", + [LLC_CONN_STATE_RESET] = "reset", + [LLC_CONN_STATE_ERROR] = "error", + [LLC_CONN_STATE_TEMP] = "temp", +}; + +static int 
llc_seq_core_show(struct seq_file *seq, void *v) +{ + struct sock* sk; + struct llc_sock *llc; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "Connection list:\n" + "dsap state retr txw rxw pf ff sf df rs cs " + "tack tpfc trs tbs blog busr\n"); + goto out; + } + sk = v; + llc = llc_sk(sk); + + seq_printf(seq, " %02X %-10s %3d %3d %3d %2d %2d %2d %2d %2d %2d " + "%4d %4d %3d %3d %4d %4d\n", + llc->daddr.lsap, llc_conn_state_names[llc->state], + llc->retry_count, llc->k, llc->rw, llc->p_flag, llc->f_flag, + llc->s_flag, llc->data_flag, llc->remote_busy_flag, + llc->cause_flag, timer_pending(&llc->ack_timer.timer), + timer_pending(&llc->pf_cycle_timer.timer), + timer_pending(&llc->rej_sent_timer.timer), + timer_pending(&llc->busy_state_timer.timer), + !!sk->sk_backlog.tail, sock_owned_by_user_nocheck(sk)); +out: + return 0; +} + +static const struct seq_operations llc_seq_socket_ops = { + .start = llc_seq_start, + .next = llc_seq_next, + .stop = llc_seq_stop, + .show = llc_seq_socket_show, +}; + +static const struct seq_operations llc_seq_core_ops = { + .start = llc_seq_start, + .next = llc_seq_next, + .stop = llc_seq_stop, + .show = llc_seq_core_show, +}; + +static struct proc_dir_entry *llc_proc_dir; + +int __init llc_proc_init(void) +{ + int rc = -ENOMEM; + struct proc_dir_entry *p; + + llc_proc_dir = proc_mkdir("llc", init_net.proc_net); + if (!llc_proc_dir) + goto out; + + p = proc_create_seq("socket", 0444, llc_proc_dir, &llc_seq_socket_ops); + if (!p) + goto out_socket; + + p = proc_create_seq("core", 0444, llc_proc_dir, &llc_seq_core_ops); + if (!p) + goto out_core; + + rc = 0; +out: + return rc; +out_core: + remove_proc_entry("socket", llc_proc_dir); +out_socket: + remove_proc_entry("llc", init_net.proc_net); + goto out; +} + +void llc_proc_exit(void) +{ + remove_proc_entry("socket", llc_proc_dir); + remove_proc_entry("core", llc_proc_dir); + remove_proc_entry("llc", init_net.proc_net); +} diff --git a/net/llc/llc_s_ac.c b/net/llc/llc_s_ac.c new file mode 100644 index 000000000..06fb8e694 --- /dev/null +++ b/net/llc/llc_s_ac.c @@ -0,0 +1,217 @@ +/* + * llc_s_ac.c - actions performed during sap state transition. + * + * Description : + * Functions in this module are implementation of sap component actions. + * Details of actions can be found in IEEE-802.2 standard document. + * All functions have one sap and one event as input argument. All of + * them return 0 On success and 1 otherwise. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. 
+ */ + +#include +#include +#include +#include +#include +#include + + +/** + * llc_sap_action_unitdata_ind - forward UI PDU to network layer + * @sap: SAP + * @skb: the event to forward + * + * Received a UI PDU from MAC layer; forward to network layer as a + * UNITDATA INDICATION; verify our event is the kind we expect + */ +int llc_sap_action_unitdata_ind(struct llc_sap *sap, struct sk_buff *skb) +{ + llc_sap_rtn_pdu(sap, skb); + return 0; +} + +/** + * llc_sap_action_send_ui - sends UI PDU resp to UNITDATA REQ to MAC layer + * @sap: SAP + * @skb: the event to send + * + * Sends a UI PDU to the MAC layer in response to a UNITDATA REQUEST + * primitive from the network layer. Verifies event is a primitive type of + * event. Verify the primitive is a UNITDATA REQUEST. + */ +int llc_sap_action_send_ui(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + int rc; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_U, ev->saddr.lsap, + ev->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_ui_cmd(skb); + rc = llc_mac_hdr_init(skb, ev->saddr.mac, ev->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + rc = dev_queue_xmit(skb); + } + return rc; +} + +/** + * llc_sap_action_send_xid_c - send XID PDU as response to XID REQ + * @sap: SAP + * @skb: the event to send + * + * Send a XID command PDU to MAC layer in response to a XID REQUEST + * primitive from the network layer. Verify event is a primitive type + * event. Verify the primitive is a XID REQUEST. + */ +int llc_sap_action_send_xid_c(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + int rc; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_U_XID, ev->saddr.lsap, + ev->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_xid_cmd(skb, LLC_XID_NULL_CLASS_2, 0); + rc = llc_mac_hdr_init(skb, ev->saddr.mac, ev->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + rc = dev_queue_xmit(skb); + } + return rc; +} + +/** + * llc_sap_action_send_xid_r - send XID PDU resp to MAC for received XID + * @sap: SAP + * @skb: the event to send + * + * Send XID response PDU to MAC in response to an earlier received XID + * command PDU. Verify event is a PDU type event + */ +int llc_sap_action_send_xid_r(struct llc_sap *sap, struct sk_buff *skb) +{ + u8 mac_da[ETH_ALEN], mac_sa[ETH_ALEN], dsap; + int rc = 1; + struct sk_buff *nskb; + + llc_pdu_decode_sa(skb, mac_da); + llc_pdu_decode_da(skb, mac_sa); + llc_pdu_decode_ssap(skb, &dsap); + nskb = llc_alloc_frame(NULL, skb->dev, LLC_PDU_TYPE_U, + sizeof(struct llc_xid_info)); + if (!nskb) + goto out; + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, dsap, + LLC_PDU_RSP); + llc_pdu_init_as_xid_rsp(nskb, LLC_XID_NULL_CLASS_2, 0); + rc = llc_mac_hdr_init(nskb, mac_sa, mac_da); + if (likely(!rc)) + rc = dev_queue_xmit(nskb); +out: + return rc; +} + +/** + * llc_sap_action_send_test_c - send TEST PDU to MAC in resp to TEST REQ + * @sap: SAP + * @skb: the event to send + * + * Send a TEST command PDU to the MAC layer in response to a TEST REQUEST + * primitive from the network layer. Verify event is a primitive type + * event; verify the primitive is a TEST REQUEST. 
+ */ +int llc_sap_action_send_test_c(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + int rc; + + llc_pdu_header_init(skb, LLC_PDU_TYPE_U, ev->saddr.lsap, + ev->daddr.lsap, LLC_PDU_CMD); + llc_pdu_init_as_test_cmd(skb); + rc = llc_mac_hdr_init(skb, ev->saddr.mac, ev->daddr.mac); + if (likely(!rc)) { + skb_get(skb); + rc = dev_queue_xmit(skb); + } + return rc; +} + +int llc_sap_action_send_test_r(struct llc_sap *sap, struct sk_buff *skb) +{ + u8 mac_da[ETH_ALEN], mac_sa[ETH_ALEN], dsap; + struct sk_buff *nskb; + int rc = 1; + u32 data_size; + + if (skb->mac_len < ETH_HLEN) + return 1; + + llc_pdu_decode_sa(skb, mac_da); + llc_pdu_decode_da(skb, mac_sa); + llc_pdu_decode_ssap(skb, &dsap); + + /* The test request command is type U (llc_len = 3) */ + data_size = ntohs(eth_hdr(skb)->h_proto) - 3; + nskb = llc_alloc_frame(NULL, skb->dev, LLC_PDU_TYPE_U, data_size); + if (!nskb) + goto out; + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, sap->laddr.lsap, dsap, + LLC_PDU_RSP); + llc_pdu_init_as_test_rsp(nskb, skb); + rc = llc_mac_hdr_init(nskb, mac_sa, mac_da); + if (likely(!rc)) + rc = dev_queue_xmit(nskb); +out: + return rc; +} + +/** + * llc_sap_action_report_status - report data link status to layer mgmt + * @sap: SAP + * @skb: the event to send + * + * Report data link status to layer management. Verify our event is the + * kind we expect. + */ +int llc_sap_action_report_status(struct llc_sap *sap, struct sk_buff *skb) +{ + return 0; +} + +/** + * llc_sap_action_xid_ind - send XID PDU resp to net layer via XID IND + * @sap: SAP + * @skb: the event to send + * + * Send a XID response PDU to the network layer via a XID INDICATION + * primitive. + */ +int llc_sap_action_xid_ind(struct llc_sap *sap, struct sk_buff *skb) +{ + llc_sap_rtn_pdu(sap, skb); + return 0; +} + +/** + * llc_sap_action_test_ind - send TEST PDU to net layer via TEST IND + * @sap: SAP + * @skb: the event to send + * + * Send a TEST response PDU to the network layer via a TEST INDICATION + * primitive. Verify our event is a PDU type event. + */ +int llc_sap_action_test_ind(struct llc_sap *sap, struct sk_buff *skb) +{ + llc_sap_rtn_pdu(sap, skb); + return 0; +} diff --git a/net/llc/llc_s_ev.c b/net/llc/llc_s_ev.c new file mode 100644 index 000000000..a74d2a1d6 --- /dev/null +++ b/net/llc/llc_s_ev.c @@ -0,0 +1,115 @@ +/* + * llc_s_ev.c - Defines SAP component events + * + * The followed event functions are SAP component events which are described + * in 802.2 LLC protocol standard document. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include + +int llc_sap_ev_activation_req(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + return ev->type == LLC_SAP_EV_TYPE_SIMPLE && + ev->prim_type == LLC_SAP_EV_ACTIVATION_REQ ? 
0 : 1; +} + +int llc_sap_ev_rx_ui(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return ev->type == LLC_SAP_EV_TYPE_PDU && LLC_PDU_IS_CMD(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_1_PDU_CMD_UI ? 0 : 1; +} + +int llc_sap_ev_unitdata_req(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + return ev->type == LLC_SAP_EV_TYPE_PRIM && + ev->prim == LLC_DATAUNIT_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; + +} + +int llc_sap_ev_xid_req(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + return ev->type == LLC_SAP_EV_TYPE_PRIM && + ev->prim == LLC_XID_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; +} + +int llc_sap_ev_rx_xid_c(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return ev->type == LLC_SAP_EV_TYPE_PDU && LLC_PDU_IS_CMD(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_1_PDU_CMD_XID ? 0 : 1; +} + +int llc_sap_ev_rx_xid_r(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return ev->type == LLC_SAP_EV_TYPE_PDU && LLC_PDU_IS_RSP(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_1_PDU_CMD_XID ? 0 : 1; +} + +int llc_sap_ev_test_req(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + return ev->type == LLC_SAP_EV_TYPE_PRIM && + ev->prim == LLC_TEST_PRIM && + ev->prim_type == LLC_PRIM_TYPE_REQ ? 0 : 1; +} + +int llc_sap_ev_rx_test_c(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return ev->type == LLC_SAP_EV_TYPE_PDU && LLC_PDU_IS_CMD(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_CMD(pdu) == LLC_1_PDU_CMD_TEST ? 0 : 1; +} + +int llc_sap_ev_rx_test_r(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return ev->type == LLC_SAP_EV_TYPE_PDU && LLC_PDU_IS_RSP(pdu) && + LLC_PDU_TYPE_IS_U(pdu) && + LLC_U_PDU_RSP(pdu) == LLC_1_PDU_CMD_TEST ? 0 : 1; +} + +int llc_sap_ev_deactivation_req(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + return ev->type == LLC_SAP_EV_TYPE_SIMPLE && + ev->prim_type == LLC_SAP_EV_DEACTIVATION_REQ ? 0 : 1; +} diff --git a/net/llc/llc_s_st.c b/net/llc/llc_s_st.c new file mode 100644 index 000000000..308c61688 --- /dev/null +++ b/net/llc/llc_s_st.c @@ -0,0 +1,183 @@ +/* + * llc_s_st.c - Defines SAP component state machine transitions. + * + * The followed transitions are SAP component state machine transitions + * which are described in 802.2 LLC protocol standard document. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. 
+ */ +#include +#include +#include +#include +#include + +/* dummy last-transition indicator; common to all state transition groups + * last entry for this state + * all members are zeros, .bss zeroes it + */ +static struct llc_sap_state_trans llc_sap_state_trans_end; + +/* state LLC_SAP_STATE_INACTIVE transition for + * LLC_SAP_EV_ACTIVATION_REQ event + */ +static const llc_sap_action_t llc_sap_inactive_state_actions_1[] = { + [0] = llc_sap_action_report_status, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_inactive_state_trans_1 = { + .ev = llc_sap_ev_activation_req, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_inactive_state_actions_1, +}; + +/* array of pointers; one to each transition */ +static struct llc_sap_state_trans *llc_sap_inactive_state_transitions[] = { + [0] = &llc_sap_inactive_state_trans_1, + [1] = &llc_sap_state_trans_end, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_RX_UI event */ +static const llc_sap_action_t llc_sap_active_state_actions_1[] = { + [0] = llc_sap_action_unitdata_ind, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_1 = { + .ev = llc_sap_ev_rx_ui, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_1, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_UNITDATA_REQ event */ +static const llc_sap_action_t llc_sap_active_state_actions_2[] = { + [0] = llc_sap_action_send_ui, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_2 = { + .ev = llc_sap_ev_unitdata_req, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_2, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_XID_REQ event */ +static const llc_sap_action_t llc_sap_active_state_actions_3[] = { + [0] = llc_sap_action_send_xid_c, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_3 = { + .ev = llc_sap_ev_xid_req, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_3, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_RX_XID_C event */ +static const llc_sap_action_t llc_sap_active_state_actions_4[] = { + [0] = llc_sap_action_send_xid_r, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_4 = { + .ev = llc_sap_ev_rx_xid_c, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_4, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_RX_XID_R event */ +static const llc_sap_action_t llc_sap_active_state_actions_5[] = { + [0] = llc_sap_action_xid_ind, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_5 = { + .ev = llc_sap_ev_rx_xid_r, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_5, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_TEST_REQ event */ +static const llc_sap_action_t llc_sap_active_state_actions_6[] = { + [0] = llc_sap_action_send_test_c, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_6 = { + .ev = llc_sap_ev_test_req, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_6, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_RX_TEST_C event */ +static const llc_sap_action_t llc_sap_active_state_actions_7[] = { + [0] = llc_sap_action_send_test_r, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_7 = { + .ev = llc_sap_ev_rx_test_c, + .next_state = LLC_SAP_STATE_ACTIVE, + 
.ev_actions = llc_sap_active_state_actions_7 +}; + +/* state LLC_SAP_STATE_ACTIVE transition for LLC_SAP_EV_RX_TEST_R event */ +static const llc_sap_action_t llc_sap_active_state_actions_8[] = { + [0] = llc_sap_action_test_ind, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_8 = { + .ev = llc_sap_ev_rx_test_r, + .next_state = LLC_SAP_STATE_ACTIVE, + .ev_actions = llc_sap_active_state_actions_8, +}; + +/* state LLC_SAP_STATE_ACTIVE transition for + * LLC_SAP_EV_DEACTIVATION_REQ event + */ +static const llc_sap_action_t llc_sap_active_state_actions_9[] = { + [0] = llc_sap_action_report_status, + [1] = NULL, +}; + +static struct llc_sap_state_trans llc_sap_active_state_trans_9 = { + .ev = llc_sap_ev_deactivation_req, + .next_state = LLC_SAP_STATE_INACTIVE, + .ev_actions = llc_sap_active_state_actions_9 +}; + +/* array of pointers; one to each transition */ +static struct llc_sap_state_trans *llc_sap_active_state_transitions[] = { + [0] = &llc_sap_active_state_trans_2, + [1] = &llc_sap_active_state_trans_1, + [2] = &llc_sap_active_state_trans_3, + [3] = &llc_sap_active_state_trans_4, + [4] = &llc_sap_active_state_trans_5, + [5] = &llc_sap_active_state_trans_6, + [6] = &llc_sap_active_state_trans_7, + [7] = &llc_sap_active_state_trans_8, + [8] = &llc_sap_active_state_trans_9, + [9] = &llc_sap_state_trans_end, +}; + +/* SAP state transition table */ +struct llc_sap_state llc_sap_state_table[LLC_NR_SAP_STATES] = { + [LLC_SAP_STATE_INACTIVE - 1] = { + .curr_state = LLC_SAP_STATE_INACTIVE, + .transitions = llc_sap_inactive_state_transitions, + }, + [LLC_SAP_STATE_ACTIVE - 1] = { + .curr_state = LLC_SAP_STATE_ACTIVE, + .transitions = llc_sap_active_state_transitions, + }, +}; diff --git a/net/llc/llc_sap.c b/net/llc/llc_sap.c new file mode 100644 index 000000000..6805ce43a --- /dev/null +++ b/net/llc/llc_sap.c @@ -0,0 +1,439 @@ +/* + * llc_sap.c - driver routines for SAP component. + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int llc_mac_header_len(unsigned short devtype) +{ + switch (devtype) { + case ARPHRD_ETHER: + case ARPHRD_LOOPBACK: + return sizeof(struct ethhdr); + } + return 0; +} + +/** + * llc_alloc_frame - allocates sk_buff for frame + * @sk: socket to allocate frame to + * @dev: network device this skb will be sent over + * @type: pdu type to allocate + * @data_size: data size to allocate + * + * Allocates an sk_buff for frame and initializes sk_buff fields. + * Returns allocated skb or %NULL when out of memory. + */ +struct sk_buff *llc_alloc_frame(struct sock *sk, struct net_device *dev, + u8 type, u32 data_size) +{ + int hlen = type == LLC_PDU_TYPE_U ? 
3 : 4; + struct sk_buff *skb; + + hlen += llc_mac_header_len(dev->type); + skb = alloc_skb(hlen + data_size, GFP_ATOMIC); + + if (skb) { + skb_reset_mac_header(skb); + skb_reserve(skb, hlen); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + skb->protocol = htons(ETH_P_802_2); + skb->dev = dev; + if (sk != NULL) + skb_set_owner_w(skb, sk); + } + return skb; +} + +void llc_save_primitive(struct sock *sk, struct sk_buff *skb, u8 prim) +{ + struct sockaddr_llc *addr; + + /* save primitive for use by the user. */ + addr = llc_ui_skb_cb(skb); + + memset(addr, 0, sizeof(*addr)); + addr->sllc_family = sk->sk_family; + addr->sllc_arphrd = skb->dev->type; + addr->sllc_test = prim == LLC_TEST_PRIM; + addr->sllc_xid = prim == LLC_XID_PRIM; + addr->sllc_ua = prim == LLC_DATAUNIT_PRIM; + llc_pdu_decode_sa(skb, addr->sllc_mac); + llc_pdu_decode_ssap(skb, &addr->sllc_sap); +} + +/** + * llc_sap_rtn_pdu - Informs upper layer on rx of an UI, XID or TEST pdu. + * @sap: pointer to SAP + * @skb: received pdu + */ +void llc_sap_rtn_pdu(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + switch (LLC_U_PDU_RSP(pdu)) { + case LLC_1_PDU_CMD_TEST: + ev->prim = LLC_TEST_PRIM; break; + case LLC_1_PDU_CMD_XID: + ev->prim = LLC_XID_PRIM; break; + case LLC_1_PDU_CMD_UI: + ev->prim = LLC_DATAUNIT_PRIM; break; + } + ev->ind_cfm_flag = LLC_IND; +} + +/** + * llc_find_sap_trans - finds transition for event + * @sap: pointer to SAP + * @skb: happened event + * + * This function finds transition that matches with happened event. + * Returns the pointer to found transition on success or %NULL for + * failure. + */ +static struct llc_sap_state_trans *llc_find_sap_trans(struct llc_sap *sap, + struct sk_buff *skb) +{ + int i = 0; + struct llc_sap_state_trans *rc = NULL; + struct llc_sap_state_trans **next_trans; + struct llc_sap_state *curr_state = &llc_sap_state_table[sap->state - 1]; + /* + * Search thru events for this state until list exhausted or until + * its obvious the event is not valid for the current state + */ + for (next_trans = curr_state->transitions; next_trans[i]->ev; i++) + if (!next_trans[i]->ev(sap, skb)) { + rc = next_trans[i]; /* got event match; return it */ + break; + } + return rc; +} + +/** + * llc_exec_sap_trans_actions - execute actions related to event + * @sap: pointer to SAP + * @trans: pointer to transition that it's actions must be performed + * @skb: happened event. + * + * This function executes actions that is related to happened event. + * Returns 0 for success and 1 for failure of at least one action. + */ +static int llc_exec_sap_trans_actions(struct llc_sap *sap, + struct llc_sap_state_trans *trans, + struct sk_buff *skb) +{ + int rc = 0; + const llc_sap_action_t *next_action = trans->ev_actions; + + for (; next_action && *next_action; next_action++) + if ((*next_action)(sap, skb)) + rc = 1; + return rc; +} + +/** + * llc_sap_next_state - finds transition, execs actions & change SAP state + * @sap: pointer to SAP + * @skb: happened event + * + * This function finds transition that matches with happened event, then + * executes related actions and finally changes state of SAP. It returns + * 0 on success and 1 for failure. 
+ */ +static int llc_sap_next_state(struct llc_sap *sap, struct sk_buff *skb) +{ + int rc = 1; + struct llc_sap_state_trans *trans; + + if (sap->state > LLC_NR_SAP_STATES) + goto out; + trans = llc_find_sap_trans(sap, skb); + if (!trans) + goto out; + /* + * Got the state to which we next transition; perform the actions + * associated with this transition before actually transitioning to the + * next state + */ + rc = llc_exec_sap_trans_actions(sap, trans, skb); + if (rc) + goto out; + /* + * Transition SAP to next state if all actions execute successfully + */ + sap->state = trans->next_state; +out: + return rc; +} + +/** + * llc_sap_state_process - sends event to SAP state machine + * @sap: sap to use + * @skb: pointer to occurred event + * + * After executing actions of the event, upper layer will be indicated + * if needed(on receiving an UI frame). sk can be null for the + * datalink_proto case. + * + * This function always consumes a reference to the skb. + */ +static void llc_sap_state_process(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + ev->ind_cfm_flag = 0; + llc_sap_next_state(sap, skb); + + if (ev->ind_cfm_flag == LLC_IND && skb->sk->sk_state != TCP_LISTEN) { + llc_save_primitive(skb->sk, skb, ev->prim); + + /* queue skb to the user. */ + if (sock_queue_rcv_skb(skb->sk, skb) == 0) + return; + } + kfree_skb(skb); +} + +/** + * llc_build_and_send_test_pkt - TEST interface for upper layers. + * @sap: sap to use + * @skb: packet to send + * @dmac: destination mac address + * @dsap: destination sap + * + * This function is called when upper layer wants to send a TEST pdu. + * Returns 0 for success, 1 otherwise. + */ +void llc_build_and_send_test_pkt(struct llc_sap *sap, + struct sk_buff *skb, u8 *dmac, u8 dsap) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + ev->saddr.lsap = sap->laddr.lsap; + ev->daddr.lsap = dsap; + memcpy(ev->saddr.mac, skb->dev->dev_addr, IFHWADDRLEN); + memcpy(ev->daddr.mac, dmac, IFHWADDRLEN); + + ev->type = LLC_SAP_EV_TYPE_PRIM; + ev->prim = LLC_TEST_PRIM; + ev->prim_type = LLC_PRIM_TYPE_REQ; + llc_sap_state_process(sap, skb); +} + +/** + * llc_build_and_send_xid_pkt - XID interface for upper layers + * @sap: sap to use + * @skb: packet to send + * @dmac: destination mac address + * @dsap: destination sap + * + * This function is called when upper layer wants to send a XID pdu. + * Returns 0 for success, 1 otherwise. + */ +void llc_build_and_send_xid_pkt(struct llc_sap *sap, struct sk_buff *skb, + u8 *dmac, u8 dsap) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + ev->saddr.lsap = sap->laddr.lsap; + ev->daddr.lsap = dsap; + memcpy(ev->saddr.mac, skb->dev->dev_addr, IFHWADDRLEN); + memcpy(ev->daddr.mac, dmac, IFHWADDRLEN); + + ev->type = LLC_SAP_EV_TYPE_PRIM; + ev->prim = LLC_XID_PRIM; + ev->prim_type = LLC_PRIM_TYPE_REQ; + llc_sap_state_process(sap, skb); +} + +/** + * llc_sap_rcv - sends received pdus to the sap state machine + * @sap: current sap component structure. + * @skb: received frame. + * @sk: socket to associate to frame + * + * Sends received pdus to the sap state machine. 
+ */ +static void llc_sap_rcv(struct llc_sap *sap, struct sk_buff *skb, + struct sock *sk) +{ + struct llc_sap_state_ev *ev = llc_sap_ev(skb); + + ev->type = LLC_SAP_EV_TYPE_PDU; + ev->reason = 0; + skb_orphan(skb); + sock_hold(sk); + skb->sk = sk; + skb->destructor = sock_efree; + llc_sap_state_process(sap, skb); +} + +static inline bool llc_dgram_match(const struct llc_sap *sap, + const struct llc_addr *laddr, + const struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + return sk->sk_type == SOCK_DGRAM && + llc->laddr.lsap == laddr->lsap && + ether_addr_equal(llc->laddr.mac, laddr->mac); +} + +/** + * llc_lookup_dgram - Finds dgram socket for the local sap/mac + * @sap: SAP + * @laddr: address of local LLC (MAC + SAP) + * + * Search socket list of the SAP and finds connection using the local + * mac, and local sap. Returns pointer for socket found, %NULL otherwise. + */ +static struct sock *llc_lookup_dgram(struct llc_sap *sap, + const struct llc_addr *laddr) +{ + struct sock *rc; + struct hlist_nulls_node *node; + int slot = llc_sk_laddr_hashfn(sap, laddr); + struct hlist_nulls_head *laddr_hb = &sap->sk_laddr_hash[slot]; + + rcu_read_lock_bh(); +again: + sk_nulls_for_each_rcu(rc, node, laddr_hb) { + if (llc_dgram_match(sap, laddr, rc)) { + /* Extra checks required by SLAB_TYPESAFE_BY_RCU */ + if (unlikely(!refcount_inc_not_zero(&rc->sk_refcnt))) + goto again; + if (unlikely(llc_sk(rc)->sap != sap || + !llc_dgram_match(sap, laddr, rc))) { + sock_put(rc); + continue; + } + goto found; + } + } + rc = NULL; + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (unlikely(get_nulls_value(node) != slot)) + goto again; +found: + rcu_read_unlock_bh(); + return rc; +} + +static inline bool llc_mcast_match(const struct llc_sap *sap, + const struct llc_addr *laddr, + const struct sk_buff *skb, + const struct sock *sk) +{ + struct llc_sock *llc = llc_sk(sk); + + return sk->sk_type == SOCK_DGRAM && + llc->laddr.lsap == laddr->lsap && + llc->dev == skb->dev; +} + +static void llc_do_mcast(struct llc_sap *sap, struct sk_buff *skb, + struct sock **stack, int count) +{ + struct sk_buff *skb1; + int i; + + for (i = 0; i < count; i++) { + skb1 = skb_clone(skb, GFP_ATOMIC); + if (!skb1) { + sock_put(stack[i]); + continue; + } + + llc_sap_rcv(sap, skb1, stack[i]); + sock_put(stack[i]); + } +} + +/** + * llc_sap_mcast - Deliver multicast PDU's to all matching datagram sockets. + * @sap: SAP + * @laddr: address of local LLC (MAC + SAP) + * @skb: PDU to deliver + * + * Search socket list of the SAP and finds connections with same sap. + * Deliver clone to each. 
+ */ +static void llc_sap_mcast(struct llc_sap *sap, + const struct llc_addr *laddr, + struct sk_buff *skb) +{ + int i = 0; + struct sock *sk; + struct sock *stack[256 / sizeof(struct sock *)]; + struct llc_sock *llc; + struct hlist_head *dev_hb = llc_sk_dev_hash(sap, skb->dev->ifindex); + + spin_lock_bh(&sap->sk_lock); + hlist_for_each_entry(llc, dev_hb, dev_hash_node) { + + sk = &llc->sk; + + if (!llc_mcast_match(sap, laddr, skb, sk)) + continue; + + sock_hold(sk); + if (i < ARRAY_SIZE(stack)) + stack[i++] = sk; + else { + llc_do_mcast(sap, skb, stack, i); + i = 0; + } + } + spin_unlock_bh(&sap->sk_lock); + + llc_do_mcast(sap, skb, stack, i); +} + + +void llc_sap_handler(struct llc_sap *sap, struct sk_buff *skb) +{ + struct llc_addr laddr; + + llc_pdu_decode_da(skb, laddr.mac); + llc_pdu_decode_dsap(skb, &laddr.lsap); + + if (is_multicast_ether_addr(laddr.mac)) { + llc_sap_mcast(sap, &laddr, skb); + kfree_skb(skb); + } else { + struct sock *sk = llc_lookup_dgram(sap, &laddr); + if (sk) { + llc_sap_rcv(sap, skb, sk); + sock_put(sk); + } else + kfree_skb(skb); + } +} diff --git a/net/llc/llc_station.c b/net/llc/llc_station.c new file mode 100644 index 000000000..f50654292 --- /dev/null +++ b/net/llc/llc_station.c @@ -0,0 +1,126 @@ +/* + * llc_station.c - station component of LLC + * + * Copyright (c) 1997 by Procom Technology, Inc. + * 2001-2003 by Arnaldo Carvalho de Melo + * + * This program can be redistributed or modified under the terms of the + * GNU General Public License as published by the Free Software Foundation. + * This program is distributed without any warranty or implied warranty + * of merchantability or fitness for a particular purpose. + * + * See the GNU General Public License for more details. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int llc_stat_ev_rx_null_dsap_xid_c(struct sk_buff *skb) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && /* command PDU */ + LLC_PDU_TYPE_IS_U(pdu) && /* U type PDU */ + LLC_U_PDU_CMD(pdu) == LLC_1_PDU_CMD_XID && + !pdu->dsap; /* NULL DSAP value */ +} + +static int llc_stat_ev_rx_null_dsap_test_c(struct sk_buff *skb) +{ + struct llc_pdu_un *pdu = llc_pdu_un_hdr(skb); + + return LLC_PDU_IS_CMD(pdu) && /* command PDU */ + LLC_PDU_TYPE_IS_U(pdu) && /* U type PDU */ + LLC_U_PDU_CMD(pdu) == LLC_1_PDU_CMD_TEST && + !pdu->dsap; /* NULL DSAP */ +} + +static int llc_station_ac_send_xid_r(struct sk_buff *skb) +{ + u8 mac_da[ETH_ALEN], dsap; + int rc = 1; + struct sk_buff *nskb = llc_alloc_frame(NULL, skb->dev, LLC_PDU_TYPE_U, + sizeof(struct llc_xid_info)); + + if (!nskb) + goto out; + llc_pdu_decode_sa(skb, mac_da); + llc_pdu_decode_ssap(skb, &dsap); + llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, 0, dsap, LLC_PDU_RSP); + llc_pdu_init_as_xid_rsp(nskb, LLC_XID_NULL_CLASS_2, 127); + rc = llc_mac_hdr_init(nskb, skb->dev->dev_addr, mac_da); + if (unlikely(rc)) + goto free; + dev_queue_xmit(nskb); +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +static int llc_station_ac_send_test_r(struct sk_buff *skb) +{ + u8 mac_da[ETH_ALEN], dsap; + int rc = 1; + u32 data_size; + struct sk_buff *nskb; + + if (skb->mac_len < ETH_HLEN) + goto out; + + /* The test request command is type U (llc_len = 3) */ + data_size = ntohs(eth_hdr(skb)->h_proto) - 3; + nskb = llc_alloc_frame(NULL, skb->dev, LLC_PDU_TYPE_U, data_size); + + if (!nskb) + goto out; + llc_pdu_decode_sa(skb, mac_da); + llc_pdu_decode_ssap(skb, &dsap); + 
llc_pdu_header_init(nskb, LLC_PDU_TYPE_U, 0, dsap, LLC_PDU_RSP); + llc_pdu_init_as_test_rsp(nskb, skb); + rc = llc_mac_hdr_init(nskb, skb->dev->dev_addr, mac_da); + if (unlikely(rc)) + goto free; + dev_queue_xmit(nskb); +out: + return rc; +free: + kfree_skb(nskb); + goto out; +} + +/** + * llc_station_rcv - send received pdu to the station state machine + * @skb: received frame. + * + * Sends data unit to station state machine. + */ +static void llc_station_rcv(struct sk_buff *skb) +{ + if (llc_stat_ev_rx_null_dsap_xid_c(skb)) + llc_station_ac_send_xid_r(skb); + else if (llc_stat_ev_rx_null_dsap_test_c(skb)) + llc_station_ac_send_test_r(skb); + kfree_skb(skb); +} + +void __init llc_station_init(void) +{ + llc_set_station_handler(llc_station_rcv); +} + +void llc_station_exit(void) +{ + llc_set_station_handler(NULL); +} diff --git a/net/llc/sysctl_net_llc.c b/net/llc/sysctl_net_llc.c new file mode 100644 index 000000000..8443a6d84 --- /dev/null +++ b/net/llc/sysctl_net_llc.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * sysctl_net_llc.c: sysctl interface to LLC net subsystem. + * + * Arnaldo Carvalho de Melo + */ + +#include +#include +#include +#include +#include + +#ifndef CONFIG_SYSCTL +#error This file should not be compiled without CONFIG_SYSCTL defined +#endif + +static struct ctl_table llc2_timeout_table[] = { + { + .procname = "ack", + .data = &sysctl_llc2_ack_timeout, + .maxlen = sizeof(sysctl_llc2_ack_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "busy", + .data = &sysctl_llc2_busy_timeout, + .maxlen = sizeof(sysctl_llc2_busy_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "p", + .data = &sysctl_llc2_p_timeout, + .maxlen = sizeof(sysctl_llc2_p_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "rej", + .data = &sysctl_llc2_rej_timeout, + .maxlen = sizeof(sysctl_llc2_rej_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { }, +}; + +static struct ctl_table llc_station_table[] = { + { }, +}; + +static struct ctl_table_header *llc2_timeout_header; +static struct ctl_table_header *llc_station_header; + +int __init llc_sysctl_init(void) +{ + llc2_timeout_header = register_net_sysctl(&init_net, "net/llc/llc2/timeout", llc2_timeout_table); + llc_station_header = register_net_sysctl(&init_net, "net/llc/station", llc_station_table); + + if (!llc2_timeout_header || !llc_station_header) { + llc_sysctl_exit(); + return -ENOMEM; + } + return 0; +} + +void llc_sysctl_exit(void) +{ + if (llc2_timeout_header) { + unregister_net_sysctl_table(llc2_timeout_header); + llc2_timeout_header = NULL; + } + if (llc_station_header) { + unregister_net_sysctl_table(llc_station_header); + llc_station_header = NULL; + } +} diff --git a/net/mac80211/Kconfig b/net/mac80211/Kconfig new file mode 100644 index 000000000..51ec8256b --- /dev/null +++ b/net/mac80211/Kconfig @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: GPL-2.0-only +config MAC80211 + tristate "Generic IEEE 802.11 Networking Stack (mac80211)" + depends on CFG80211 + select CRYPTO + select CRYPTO_LIB_ARC4 + select CRYPTO_AES + select CRYPTO_CCM + select CRYPTO_GCM + select CRYPTO_CMAC + select CRC32 + help + This option enables the hardware independent IEEE 802.11 + networking stack. 
+ +comment "CFG80211 needs to be enabled for MAC80211" + depends on CFG80211=n + +if MAC80211 != n + +config MAC80211_HAS_RC + bool + +config MAC80211_RC_MINSTREL + bool "Minstrel" if EXPERT + select MAC80211_HAS_RC + default y + help + This option enables the 'minstrel' TX rate control algorithm + +choice + prompt "Default rate control algorithm" + depends on MAC80211_HAS_RC + default MAC80211_RC_DEFAULT_MINSTREL + help + This option selects the default rate control algorithm + mac80211 will use. Note that this default can still be + overridden through the ieee80211_default_rc_algo module + parameter if different algorithms are available. + +config MAC80211_RC_DEFAULT_MINSTREL + bool "Minstrel" + depends on MAC80211_RC_MINSTREL + help + Select Minstrel as the default rate control algorithm. + + +endchoice + +config MAC80211_RC_DEFAULT + string + default "minstrel_ht" if MAC80211_RC_DEFAULT_MINSTREL + default "" + +endif + +comment "Some wireless drivers require a rate control algorithm" + depends on MAC80211 && MAC80211_HAS_RC=n + +config MAC80211_MESH + bool "Enable mac80211 mesh networking support" + depends on MAC80211 + help + Select this option to enable 802.11 mesh operation in mac80211 + drivers that support it. 802.11 mesh connects multiple stations + over (possibly multi-hop) wireless links to form a single logical + LAN. + +config MAC80211_LEDS + bool "Enable LED triggers" + depends on MAC80211 + depends on LEDS_CLASS=y || LEDS_CLASS=MAC80211 + select LEDS_TRIGGERS + help + This option enables a few LED triggers for different + packet receive/transmit events. + +config MAC80211_DEBUGFS + bool "Export mac80211 internals in DebugFS" + depends on MAC80211 && DEBUG_FS + help + Select this to see extensive information about + the internal state of mac80211 in debugfs. + + Say N unless you know you need this. + +config MAC80211_MESSAGE_TRACING + bool "Trace all mac80211 debug messages" + depends on MAC80211 + help + Select this option to have mac80211 register the + mac80211_msg trace subsystem with tracepoints to + collect all debugging messages, independent of + printing them into the kernel log. + + The overhead in this option is that all the messages + need to be present in the binary and formatted at + runtime for tracing. + +menuconfig MAC80211_DEBUG_MENU + bool "Select mac80211 debugging features" + depends on MAC80211 + help + This option collects various mac80211 debug settings. + +config MAC80211_NOINLINE + bool "Do not inline TX/RX handlers" + depends on MAC80211_DEBUG_MENU + help + This option affects code generation in mac80211, when + selected some functions are marked "noinline" to allow + easier debugging of problems in the transmit and receive + paths. + + This option increases code size a bit and inserts a lot + of function calls in the code, but is otherwise safe to + enable. + + If unsure, say N unless you expect to be finding problems + in mac80211. + +config MAC80211_VERBOSE_DEBUG + bool "Verbose debugging output" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out + many debugging messages. It should not be selected + on production systems as some of the messages are + remotely triggerable. + + Do not select this option. + +config MAC80211_MLME_DEBUG + bool "Verbose managed MLME output" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out + debugging messages for the managed-mode MLME. 
It + should not be selected on production systems as some + of the messages are remotely triggerable. + + Do not select this option. + +config MAC80211_STA_DEBUG + bool "Verbose station debugging" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out + debugging messages for station addition/removal. + + Do not select this option. + +config MAC80211_HT_DEBUG + bool "Verbose HT debugging" + depends on MAC80211_DEBUG_MENU + help + This option enables 802.11n High Throughput features + debug tracing output. + + It should not be selected on production systems as some + of the messages are remotely triggerable. + + Do not select this option. + +config MAC80211_OCB_DEBUG + bool "Verbose OCB debugging" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out + very verbose OCB debugging messages. It should not + be selected on production systems as those messages + are remotely triggerable. + + Do not select this option. + +config MAC80211_IBSS_DEBUG + bool "Verbose IBSS debugging" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out + very verbose IBSS debugging messages. It should not + be selected on production systems as those messages + are remotely triggerable. + + Do not select this option. + +config MAC80211_PS_DEBUG + bool "Verbose powersave mode debugging" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out very + verbose power save mode debugging messages (when mac80211 + is an AP and has power saving stations.) + It should not be selected on production systems as those + messages are remotely triggerable. + + Do not select this option. + +config MAC80211_MPL_DEBUG + bool "Verbose mesh peer link debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very + verbose mesh peer link debugging messages (when mac80211 + is taking part in a mesh network). + It should not be selected on production systems as those + messages are remotely triggerable. + + Do not select this option. + +config MAC80211_MPATH_DEBUG + bool "Verbose mesh path debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very + verbose mesh path selection debugging messages (when mac80211 + is taking part in a mesh network). + It should not be selected on production systems as those + messages are remotely triggerable. + + Do not select this option. + +config MAC80211_MHWMP_DEBUG + bool "Verbose mesh HWMP routing debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very + verbose mesh routing (HWMP) debugging messages (when mac80211 + is taking part in a mesh network). + It should not be selected on production systems as those + messages are remotely triggerable. + + Do not select this option. + +config MAC80211_MESH_SYNC_DEBUG + bool "Verbose mesh synchronization debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very verbose mesh + synchronization debugging messages (when mac80211 is taking part in a + mesh network). + + Do not select this option. 
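+
+# An illustrative fragment (assumed, development-only values, not mandated by
+# this file) showing how the mesh debug messages above are switched on: each
+# of these options is gated by both the debug menu and mesh support:
+#
+#   CONFIG_MAC80211_DEBUG_MENU=y
+#   CONFIG_MAC80211_MESH=y
+#   CONFIG_MAC80211_MPL_DEBUG=y
+#   CONFIG_MAC80211_MPATH_DEBUG=y
+#   CONFIG_MAC80211_MHWMP_DEBUG=y
+#   CONFIG_MAC80211_MESH_SYNC_DEBUG=y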
+ +config MAC80211_MESH_CSA_DEBUG + bool "Verbose mesh channel switch debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very verbose mesh + channel switch debugging messages (when mac80211 is taking part in a + mesh network). + + Do not select this option. + +config MAC80211_MESH_PS_DEBUG + bool "Verbose mesh powersave debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_MESH + help + Selecting this option causes mac80211 to print out very verbose mesh + powersave debugging messages (when mac80211 is taking part in a + mesh network). + + Do not select this option. + +config MAC80211_TDLS_DEBUG + bool "Verbose TDLS debugging" + depends on MAC80211_DEBUG_MENU + help + Selecting this option causes mac80211 to print out very + verbose TDLS selection debugging messages (when mac80211 + is a TDLS STA). + It should not be selected on production systems as those + messages are remotely triggerable. + + Do not select this option. + +config MAC80211_DEBUG_COUNTERS + bool "Extra statistics for TX/RX debugging" + depends on MAC80211_DEBUG_MENU + depends on MAC80211_DEBUGFS + help + Selecting this option causes mac80211 to keep additional + and very verbose statistics about TX and RX handler use + as well as a few selected dot11 counters. These will be + exposed in debugfs. + + Note that some of the counters are not concurrency safe + and may thus not always be accurate. + + If unsure, say N. + +config MAC80211_STA_HASH_MAX_SIZE + int "Station hash table maximum size" if MAC80211_DEBUG_MENU + default 0 + help + Setting this option to a low value (e.g. 4) allows testing the + hash table with collisions relatively deterministically (just + connect more stations than the number selected here.) + + If unsure, leave the default of 0. diff --git a/net/mac80211/Makefile b/net/mac80211/Makefile new file mode 100644 index 000000000..b8de44da1 --- /dev/null +++ b/net/mac80211/Makefile @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_MAC80211) += mac80211.o + +# mac80211 objects +mac80211-y := \ + main.o status.o \ + driver-ops.o \ + sta_info.o \ + wep.o \ + aead_api.o \ + wpa.o \ + scan.o offchannel.o \ + ht.o agg-tx.o agg-rx.o \ + vht.o \ + he.o \ + s1g.o \ + ibss.o \ + iface.o \ + link.o \ + rate.o \ + michael.o \ + tkip.o \ + aes_cmac.o \ + aes_gmac.o \ + fils_aead.o \ + cfg.o \ + ethtool.o \ + rx.o \ + spectmgmt.o \ + tx.o \ + key.o \ + util.o \ + wme.o \ + chan.o \ + trace.o mlme.o \ + tdls.o \ + ocb.o \ + airtime.o \ + eht.o + +mac80211-$(CONFIG_MAC80211_LEDS) += led.o +mac80211-$(CONFIG_MAC80211_DEBUGFS) += \ + debugfs.o \ + debugfs_sta.o \ + debugfs_netdev.o \ + debugfs_key.o + +mac80211-$(CONFIG_MAC80211_MESH) += \ + mesh.o \ + mesh_pathtbl.o \ + mesh_plink.o \ + mesh_hwmp.o \ + mesh_sync.o \ + mesh_ps.o + +mac80211-$(CONFIG_PM) += pm.o + +CFLAGS_trace.o := -I$(src) + +rc80211_minstrel-y := \ + rc80211_minstrel_ht.o + +rc80211_minstrel-$(CONFIG_MAC80211_DEBUGFS) += \ + rc80211_minstrel_ht_debugfs.o + +mac80211-$(CONFIG_MAC80211_RC_MINSTREL) += $(rc80211_minstrel-y) + +ccflags-y += -DDEBUG diff --git a/net/mac80211/aead_api.c b/net/mac80211/aead_api.c new file mode 100644 index 000000000..b00d6f5b3 --- /dev/null +++ b/net/mac80211/aead_api.c @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2003-2004, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2014-2015, Qualcomm Atheros, Inc. 
+ * + * Rewrite: Copyright (C) 2013 Linaro Ltd + */ + +#include +#include +#include +#include +#include + +#include "aead_api.h" + +int aead_encrypt(struct crypto_aead *tfm, u8 *b_0, u8 *aad, size_t aad_len, + u8 *data, size_t data_len, u8 *mic) +{ + size_t mic_len = crypto_aead_authsize(tfm); + struct scatterlist sg[3]; + struct aead_request *aead_req; + int reqsize = sizeof(*aead_req) + crypto_aead_reqsize(tfm); + u8 *__aad; + int ret; + + aead_req = kzalloc(reqsize + aad_len, GFP_ATOMIC); + if (!aead_req) + return -ENOMEM; + + __aad = (u8 *)aead_req + reqsize; + memcpy(__aad, aad, aad_len); + + sg_init_table(sg, 3); + sg_set_buf(&sg[0], __aad, aad_len); + sg_set_buf(&sg[1], data, data_len); + sg_set_buf(&sg[2], mic, mic_len); + + aead_request_set_tfm(aead_req, tfm); + aead_request_set_crypt(aead_req, sg, sg, data_len, b_0); + aead_request_set_ad(aead_req, sg[0].length); + + ret = crypto_aead_encrypt(aead_req); + kfree_sensitive(aead_req); + + return ret; +} + +int aead_decrypt(struct crypto_aead *tfm, u8 *b_0, u8 *aad, size_t aad_len, + u8 *data, size_t data_len, u8 *mic) +{ + size_t mic_len = crypto_aead_authsize(tfm); + struct scatterlist sg[3]; + struct aead_request *aead_req; + int reqsize = sizeof(*aead_req) + crypto_aead_reqsize(tfm); + u8 *__aad; + int err; + + if (data_len == 0) + return -EINVAL; + + aead_req = kzalloc(reqsize + aad_len, GFP_ATOMIC); + if (!aead_req) + return -ENOMEM; + + __aad = (u8 *)aead_req + reqsize; + memcpy(__aad, aad, aad_len); + + sg_init_table(sg, 3); + sg_set_buf(&sg[0], __aad, aad_len); + sg_set_buf(&sg[1], data, data_len); + sg_set_buf(&sg[2], mic, mic_len); + + aead_request_set_tfm(aead_req, tfm); + aead_request_set_crypt(aead_req, sg, sg, data_len + mic_len, b_0); + aead_request_set_ad(aead_req, sg[0].length); + + err = crypto_aead_decrypt(aead_req); + kfree_sensitive(aead_req); + + return err; +} + +struct crypto_aead * +aead_key_setup_encrypt(const char *alg, const u8 key[], + size_t key_len, size_t mic_len) +{ + struct crypto_aead *tfm; + int err; + + tfm = crypto_alloc_aead(alg, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + return tfm; + + err = crypto_aead_setkey(tfm, key, key_len); + if (err) + goto free_aead; + err = crypto_aead_setauthsize(tfm, mic_len); + if (err) + goto free_aead; + + return tfm; + +free_aead: + crypto_free_aead(tfm); + return ERR_PTR(err); +} + +void aead_key_free(struct crypto_aead *tfm) +{ + crypto_free_aead(tfm); +} diff --git a/net/mac80211/aead_api.h b/net/mac80211/aead_api.h new file mode 100644 index 000000000..7d463b809 --- /dev/null +++ b/net/mac80211/aead_api.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ + +#ifndef _AEAD_API_H +#define _AEAD_API_H + +#include +#include + +struct crypto_aead * +aead_key_setup_encrypt(const char *alg, const u8 key[], + size_t key_len, size_t mic_len); + +int aead_encrypt(struct crypto_aead *tfm, u8 *b_0, u8 *aad, + size_t aad_len, u8 *data, + size_t data_len, u8 *mic); + +int aead_decrypt(struct crypto_aead *tfm, u8 *b_0, u8 *aad, + size_t aad_len, u8 *data, + size_t data_len, u8 *mic); + +void aead_key_free(struct crypto_aead *tfm); + +#endif /* _AEAD_API_H */ diff --git a/net/mac80211/aes_ccm.h b/net/mac80211/aes_ccm.h new file mode 100644 index 000000000..96256193c --- /dev/null +++ b/net/mac80211/aes_ccm.h @@ -0,0 +1,45 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2003-2004, Instant802 Networks, Inc. + * Copyright 2006, Devicescape Software, Inc. 
+ */ + +#ifndef AES_CCM_H +#define AES_CCM_H + +#include "aead_api.h" + +#define CCM_AAD_LEN 32 + +static inline struct crypto_aead * +ieee80211_aes_key_setup_encrypt(const u8 key[], size_t key_len, size_t mic_len) +{ + return aead_key_setup_encrypt("ccm(aes)", key, key_len, mic_len); +} + +static inline int +ieee80211_aes_ccm_encrypt(struct crypto_aead *tfm, + u8 *b_0, u8 *aad, u8 *data, + size_t data_len, u8 *mic) +{ + return aead_encrypt(tfm, b_0, aad + 2, + be16_to_cpup((__be16 *)aad), + data, data_len, mic); +} + +static inline int +ieee80211_aes_ccm_decrypt(struct crypto_aead *tfm, + u8 *b_0, u8 *aad, u8 *data, + size_t data_len, u8 *mic) +{ + return aead_decrypt(tfm, b_0, aad + 2, + be16_to_cpup((__be16 *)aad), + data, data_len, mic); +} + +static inline void ieee80211_aes_key_free(struct crypto_aead *tfm) +{ + return aead_key_free(tfm); +} + +#endif /* AES_CCM_H */ diff --git a/net/mac80211/aes_cmac.c b/net/mac80211/aes_cmac.c new file mode 100644 index 000000000..48c04f89d --- /dev/null +++ b/net/mac80211/aes_cmac.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * AES-128-CMAC with TLen 16 for IEEE 802.11w BIP + * Copyright 2008, Jouni Malinen + * Copyright (C) 2020 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include + +#include +#include "key.h" +#include "aes_cmac.h" + +#define CMAC_TLEN 8 /* CMAC TLen = 64 bits (8 octets) */ +#define CMAC_TLEN_256 16 /* CMAC TLen = 128 bits (16 octets) */ +#define AAD_LEN 20 + +static const u8 zero[CMAC_TLEN_256]; + +void ieee80211_aes_cmac(struct crypto_shash *tfm, const u8 *aad, + const u8 *data, size_t data_len, u8 *mic) +{ + SHASH_DESC_ON_STACK(desc, tfm); + u8 out[AES_BLOCK_SIZE]; + const __le16 *fc; + + desc->tfm = tfm; + + crypto_shash_init(desc); + crypto_shash_update(desc, aad, AAD_LEN); + fc = (const __le16 *)aad; + if (ieee80211_is_beacon(*fc)) { + /* mask Timestamp field to zero */ + crypto_shash_update(desc, zero, 8); + crypto_shash_update(desc, data + 8, data_len - 8 - CMAC_TLEN); + } else { + crypto_shash_update(desc, data, data_len - CMAC_TLEN); + } + crypto_shash_finup(desc, zero, CMAC_TLEN, out); + + memcpy(mic, out, CMAC_TLEN); +} + +void ieee80211_aes_cmac_256(struct crypto_shash *tfm, const u8 *aad, + const u8 *data, size_t data_len, u8 *mic) +{ + SHASH_DESC_ON_STACK(desc, tfm); + const __le16 *fc; + + desc->tfm = tfm; + + crypto_shash_init(desc); + crypto_shash_update(desc, aad, AAD_LEN); + fc = (const __le16 *)aad; + if (ieee80211_is_beacon(*fc)) { + /* mask Timestamp field to zero */ + crypto_shash_update(desc, zero, 8); + crypto_shash_update(desc, data + 8, + data_len - 8 - CMAC_TLEN_256); + } else { + crypto_shash_update(desc, data, data_len - CMAC_TLEN_256); + } + crypto_shash_finup(desc, zero, CMAC_TLEN_256, mic); +} + +struct crypto_shash *ieee80211_aes_cmac_key_setup(const u8 key[], + size_t key_len) +{ + struct crypto_shash *tfm; + + tfm = crypto_alloc_shash("cmac(aes)", 0, 0); + if (!IS_ERR(tfm)) { + int err = crypto_shash_setkey(tfm, key, key_len); + + if (err) { + crypto_free_shash(tfm); + return ERR_PTR(err); + } + } + + return tfm; +} + +void ieee80211_aes_cmac_key_free(struct crypto_shash *tfm) +{ + crypto_free_shash(tfm); +} diff --git a/net/mac80211/aes_cmac.h b/net/mac80211/aes_cmac.h new file mode 100644 index 000000000..76817446f --- /dev/null +++ b/net/mac80211/aes_cmac.h @@ -0,0 +1,20 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2008, Jouni Malinen + */ + +#ifndef AES_CMAC_H +#define AES_CMAC_H + +#include +#include + +struct 
crypto_shash *ieee80211_aes_cmac_key_setup(const u8 key[], + size_t key_len); +void ieee80211_aes_cmac(struct crypto_shash *tfm, const u8 *aad, + const u8 *data, size_t data_len, u8 *mic); +void ieee80211_aes_cmac_256(struct crypto_shash *tfm, const u8 *aad, + const u8 *data, size_t data_len, u8 *mic); +void ieee80211_aes_cmac_key_free(struct crypto_shash *tfm); + +#endif /* AES_CMAC_H */ diff --git a/net/mac80211/aes_gcm.h b/net/mac80211/aes_gcm.h new file mode 100644 index 000000000..b14093b2f --- /dev/null +++ b/net/mac80211/aes_gcm.h @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2014-2015, Qualcomm Atheros, Inc. + */ + +#ifndef AES_GCM_H +#define AES_GCM_H + +#include "aead_api.h" + +#define GCM_AAD_LEN 32 + +static inline int ieee80211_aes_gcm_encrypt(struct crypto_aead *tfm, + u8 *j_0, u8 *aad, u8 *data, + size_t data_len, u8 *mic) +{ + return aead_encrypt(tfm, j_0, aad + 2, + be16_to_cpup((__be16 *)aad), + data, data_len, mic); +} + +static inline int ieee80211_aes_gcm_decrypt(struct crypto_aead *tfm, + u8 *j_0, u8 *aad, u8 *data, + size_t data_len, u8 *mic) +{ + return aead_decrypt(tfm, j_0, aad + 2, + be16_to_cpup((__be16 *)aad), + data, data_len, mic); +} + +static inline struct crypto_aead * +ieee80211_aes_gcm_key_setup_encrypt(const u8 key[], size_t key_len) +{ + return aead_key_setup_encrypt("gcm(aes)", key, + key_len, IEEE80211_GCMP_MIC_LEN); +} + +static inline void ieee80211_aes_gcm_key_free(struct crypto_aead *tfm) +{ + return aead_key_free(tfm); +} + +#endif /* AES_GCM_H */ diff --git a/net/mac80211/aes_gmac.c b/net/mac80211/aes_gmac.c new file mode 100644 index 000000000..512cab073 --- /dev/null +++ b/net/mac80211/aes_gmac.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * AES-GMAC for IEEE 802.11 BIP-GMAC-128 and BIP-GMAC-256 + * Copyright 2015, Qualcomm Atheros, Inc. 
+ */ + +#include +#include +#include +#include +#include + +#include +#include "key.h" +#include "aes_gmac.h" + +int ieee80211_aes_gmac(struct crypto_aead *tfm, const u8 *aad, u8 *nonce, + const u8 *data, size_t data_len, u8 *mic) +{ + struct scatterlist sg[5]; + u8 *zero, *__aad, iv[AES_BLOCK_SIZE]; + struct aead_request *aead_req; + int reqsize = sizeof(*aead_req) + crypto_aead_reqsize(tfm); + const __le16 *fc; + int ret; + + if (data_len < GMAC_MIC_LEN) + return -EINVAL; + + aead_req = kzalloc(reqsize + GMAC_MIC_LEN + GMAC_AAD_LEN, GFP_ATOMIC); + if (!aead_req) + return -ENOMEM; + + zero = (u8 *)aead_req + reqsize; + __aad = zero + GMAC_MIC_LEN; + memcpy(__aad, aad, GMAC_AAD_LEN); + + fc = (const __le16 *)aad; + if (ieee80211_is_beacon(*fc)) { + /* mask Timestamp field to zero */ + sg_init_table(sg, 5); + sg_set_buf(&sg[0], __aad, GMAC_AAD_LEN); + sg_set_buf(&sg[1], zero, 8); + sg_set_buf(&sg[2], data + 8, data_len - 8 - GMAC_MIC_LEN); + sg_set_buf(&sg[3], zero, GMAC_MIC_LEN); + sg_set_buf(&sg[4], mic, GMAC_MIC_LEN); + } else { + sg_init_table(sg, 4); + sg_set_buf(&sg[0], __aad, GMAC_AAD_LEN); + sg_set_buf(&sg[1], data, data_len - GMAC_MIC_LEN); + sg_set_buf(&sg[2], zero, GMAC_MIC_LEN); + sg_set_buf(&sg[3], mic, GMAC_MIC_LEN); + } + + memcpy(iv, nonce, GMAC_NONCE_LEN); + memset(iv + GMAC_NONCE_LEN, 0, sizeof(iv) - GMAC_NONCE_LEN); + iv[AES_BLOCK_SIZE - 1] = 0x01; + + aead_request_set_tfm(aead_req, tfm); + aead_request_set_crypt(aead_req, sg, sg, 0, iv); + aead_request_set_ad(aead_req, GMAC_AAD_LEN + data_len); + + ret = crypto_aead_encrypt(aead_req); + kfree_sensitive(aead_req); + + return ret; +} + +struct crypto_aead *ieee80211_aes_gmac_key_setup(const u8 key[], + size_t key_len) +{ + struct crypto_aead *tfm; + int err; + + tfm = crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + return tfm; + + err = crypto_aead_setkey(tfm, key, key_len); + if (!err) + err = crypto_aead_setauthsize(tfm, GMAC_MIC_LEN); + if (!err) + return tfm; + + crypto_free_aead(tfm); + return ERR_PTR(err); +} + +void ieee80211_aes_gmac_key_free(struct crypto_aead *tfm) +{ + crypto_free_aead(tfm); +} diff --git a/net/mac80211/aes_gmac.h b/net/mac80211/aes_gmac.h new file mode 100644 index 000000000..c739356ba --- /dev/null +++ b/net/mac80211/aes_gmac.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2015, Qualcomm Atheros, Inc. + */ + +#ifndef AES_GMAC_H +#define AES_GMAC_H + +#include + +#define GMAC_AAD_LEN 20 +#define GMAC_MIC_LEN 16 +#define GMAC_NONCE_LEN 12 + +struct crypto_aead *ieee80211_aes_gmac_key_setup(const u8 key[], + size_t key_len); +int ieee80211_aes_gmac(struct crypto_aead *tfm, const u8 *aad, u8 *nonce, + const u8 *data, size_t data_len, u8 *mic); +void ieee80211_aes_gmac_key_free(struct crypto_aead *tfm); + +#endif /* AES_GMAC_H */ diff --git a/net/mac80211/agg-rx.c b/net/mac80211/agg-rx.c new file mode 100644 index 000000000..9414d3bbd --- /dev/null +++ b/net/mac80211/agg-rx.c @@ -0,0 +1,562 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * HT handling + * + * Copyright 2003, Jouni Malinen + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2007-2010, Intel Corporation + * Copyright(c) 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +/** + * DOC: RX A-MPDU aggregation + * + * Aggregation on the RX side requires only implementing the + * @ampdu_action callback that is invoked to start/stop any + * block-ack sessions for RX aggregation. + * + * When RX aggregation is started by the peer, the driver is + * notified via @ampdu_action function, with the + * %IEEE80211_AMPDU_RX_START action, and may reject the request + * in which case a negative response is sent to the peer, if it + * accepts it a positive response is sent. + * + * While the session is active, the device/driver are required + * to de-aggregate frames and pass them up one by one to mac80211, + * which will handle the reorder buffer. + * + * When the aggregation session is stopped again by the peer or + * ourselves, the driver's @ampdu_action function will be called + * with the action %IEEE80211_AMPDU_RX_STOP. In this case, the + * call must not fail. + */ + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" + +static void ieee80211_free_tid_rx(struct rcu_head *h) +{ + struct tid_ampdu_rx *tid_rx = + container_of(h, struct tid_ampdu_rx, rcu_head); + int i; + + for (i = 0; i < tid_rx->buf_size; i++) + __skb_queue_purge(&tid_rx->reorder_buf[i]); + kfree(tid_rx->reorder_buf); + kfree(tid_rx->reorder_time); + kfree(tid_rx); +} + +void ___ieee80211_stop_rx_ba_session(struct sta_info *sta, u16 tid, + u16 initiator, u16 reason, bool tx) +{ + struct ieee80211_local *local = sta->local; + struct tid_ampdu_rx *tid_rx; + struct ieee80211_ampdu_params params = { + .sta = &sta->sta, + .action = IEEE80211_AMPDU_RX_STOP, + .tid = tid, + .amsdu = false, + .timeout = 0, + .ssn = 0, + }; + + lockdep_assert_held(&sta->ampdu_mlme.mtx); + + tid_rx = rcu_dereference_protected(sta->ampdu_mlme.tid_rx[tid], + lockdep_is_held(&sta->ampdu_mlme.mtx)); + + if (!test_bit(tid, sta->ampdu_mlme.agg_session_valid)) + return; + + RCU_INIT_POINTER(sta->ampdu_mlme.tid_rx[tid], NULL); + __clear_bit(tid, sta->ampdu_mlme.agg_session_valid); + + ht_dbg(sta->sdata, + "Rx BA session stop requested for %pM tid %u %s reason: %d\n", + sta->sta.addr, tid, + initiator == WLAN_BACK_RECIPIENT ? "recipient" : "initiator", + (int)reason); + + if (drv_ampdu_action(local, sta->sdata, ¶ms)) + sdata_info(sta->sdata, + "HW problem - can not stop rx aggregation for %pM tid %d\n", + sta->sta.addr, tid); + + /* check if this is a self generated aggregation halt */ + if (initiator == WLAN_BACK_RECIPIENT && tx) + ieee80211_send_delba(sta->sdata, sta->sta.addr, + tid, WLAN_BACK_RECIPIENT, reason); + + /* + * return here in case tid_rx is not assigned - which will happen if + * IEEE80211_HW_SUPPORTS_REORDERING_BUFFER is set. 
+ */ + if (!tid_rx) + return; + + del_timer_sync(&tid_rx->session_timer); + + /* make sure ieee80211_sta_reorder_release() doesn't re-arm the timer */ + spin_lock_bh(&tid_rx->reorder_lock); + tid_rx->removed = true; + spin_unlock_bh(&tid_rx->reorder_lock); + del_timer_sync(&tid_rx->reorder_timer); + + call_rcu(&tid_rx->rcu_head, ieee80211_free_tid_rx); +} + +void __ieee80211_stop_rx_ba_session(struct sta_info *sta, u16 tid, + u16 initiator, u16 reason, bool tx) +{ + mutex_lock(&sta->ampdu_mlme.mtx); + ___ieee80211_stop_rx_ba_session(sta, tid, initiator, reason, tx); + mutex_unlock(&sta->ampdu_mlme.mtx); +} + +void ieee80211_stop_rx_ba_session(struct ieee80211_vif *vif, u16 ba_rx_bitmap, + const u8 *addr) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct sta_info *sta; + int i; + + rcu_read_lock(); + sta = sta_info_get_bss(sdata, addr); + if (!sta) { + rcu_read_unlock(); + return; + } + + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + if (ba_rx_bitmap & BIT(i)) + set_bit(i, sta->ampdu_mlme.tid_rx_stop_requested); + + ieee80211_queue_work(&sta->local->hw, &sta->ampdu_mlme.work); + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_stop_rx_ba_session); + +/* + * After accepting the AddBA Request we activated a timer, + * resetting it after each frame that arrives from the originator. + */ +static void sta_rx_agg_session_timer_expired(struct timer_list *t) +{ + struct tid_ampdu_rx *tid_rx = from_timer(tid_rx, t, session_timer); + struct sta_info *sta = tid_rx->sta; + u8 tid = tid_rx->tid; + unsigned long timeout; + + timeout = tid_rx->last_rx + TU_TO_JIFFIES(tid_rx->timeout); + if (time_is_after_jiffies(timeout)) { + mod_timer(&tid_rx->session_timer, timeout); + return; + } + + ht_dbg(sta->sdata, "RX session timer expired on %pM tid %d\n", + sta->sta.addr, tid); + + set_bit(tid, sta->ampdu_mlme.tid_rx_timer_expired); + ieee80211_queue_work(&sta->local->hw, &sta->ampdu_mlme.work); +} + +static void sta_rx_agg_reorder_timer_expired(struct timer_list *t) +{ + struct tid_ampdu_rx *tid_rx = from_timer(tid_rx, t, reorder_timer); + + rcu_read_lock(); + ieee80211_release_reorder_timeout(tid_rx->sta, tid_rx->tid); + rcu_read_unlock(); +} + +static void ieee80211_add_addbaext(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + const struct ieee80211_addba_ext_ie *req, + u16 buf_size) +{ + struct ieee80211_supported_band *sband; + struct ieee80211_addba_ext_ie *resp; + const struct ieee80211_sta_he_cap *he_cap; + u8 frag_level, cap_frag_level; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return; + he_cap = ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + if (!he_cap) + return; + + pos = skb_put_zero(skb, 2 + sizeof(struct ieee80211_addba_ext_ie)); + *pos++ = WLAN_EID_ADDBA_EXT; + *pos++ = sizeof(struct ieee80211_addba_ext_ie); + resp = (struct ieee80211_addba_ext_ie *)pos; + resp->data = req->data & IEEE80211_ADDBA_EXT_NO_FRAG; + + frag_level = u32_get_bits(req->data, + IEEE80211_ADDBA_EXT_FRAG_LEVEL_MASK); + cap_frag_level = u32_get_bits(he_cap->he_cap_elem.mac_cap_info[0], + IEEE80211_HE_MAC_CAP0_DYNAMIC_FRAG_MASK); + if (frag_level > cap_frag_level) + frag_level = cap_frag_level; + resp->data |= u8_encode_bits(frag_level, + IEEE80211_ADDBA_EXT_FRAG_LEVEL_MASK); + resp->data |= u8_encode_bits(buf_size >> IEEE80211_ADDBA_EXT_BUF_SIZE_SHIFT, + IEEE80211_ADDBA_EXT_BUF_SIZE_MASK); +} + +static void ieee80211_send_addba_resp(struct sta_info *sta, u8 *da, u16 tid, + u8 dialog_token, u16 status, u16 policy, + u16 buf_size, u16 timeout, + 
const struct ieee80211_addba_ext_ie *addbaext) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + bool amsdu = ieee80211_hw_check(&local->hw, SUPPORTS_AMSDU_IN_AMPDU); + u16 capab; + + skb = dev_alloc_skb(sizeof(*mgmt) + + 2 + sizeof(struct ieee80211_addba_ext_ie) + + local->hw.extra_tx_headroom); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + mgmt = skb_put_zero(skb, 24); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + if (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + sdata->vif.type == NL80211_IFTYPE_MESH_POINT) + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_STATION) + memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + memcpy(mgmt->bssid, sdata->u.ibss.bssid, ETH_ALEN); + + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + skb_put(skb, 1 + sizeof(mgmt->u.action.u.addba_resp)); + mgmt->u.action.category = WLAN_CATEGORY_BACK; + mgmt->u.action.u.addba_resp.action_code = WLAN_ACTION_ADDBA_RESP; + mgmt->u.action.u.addba_resp.dialog_token = dialog_token; + + capab = u16_encode_bits(amsdu, IEEE80211_ADDBA_PARAM_AMSDU_MASK); + capab |= u16_encode_bits(policy, IEEE80211_ADDBA_PARAM_POLICY_MASK); + capab |= u16_encode_bits(tid, IEEE80211_ADDBA_PARAM_TID_MASK); + capab |= u16_encode_bits(buf_size, IEEE80211_ADDBA_PARAM_BUF_SIZE_MASK); + + mgmt->u.action.u.addba_resp.capab = cpu_to_le16(capab); + mgmt->u.action.u.addba_resp.timeout = cpu_to_le16(timeout); + mgmt->u.action.u.addba_resp.status = cpu_to_le16(status); + + if (sta->sta.deflink.he_cap.has_he && addbaext) + ieee80211_add_addbaext(sdata, skb, addbaext, buf_size); + + ieee80211_tx_skb(sdata, skb); +} + +void ___ieee80211_start_rx_ba_session(struct sta_info *sta, + u8 dialog_token, u16 timeout, + u16 start_seq_num, u16 ba_policy, u16 tid, + u16 buf_size, bool tx, bool auto_seq, + const struct ieee80211_addba_ext_ie *addbaext) +{ + struct ieee80211_local *local = sta->sdata->local; + struct tid_ampdu_rx *tid_agg_rx; + struct ieee80211_ampdu_params params = { + .sta = &sta->sta, + .action = IEEE80211_AMPDU_RX_START, + .tid = tid, + .amsdu = false, + .timeout = timeout, + .ssn = start_seq_num, + }; + int i, ret = -EOPNOTSUPP; + u16 status = WLAN_STATUS_REQUEST_DECLINED; + u16 max_buf_size; + + if (tid >= IEEE80211_FIRST_TSPEC_TSID) { + ht_dbg(sta->sdata, + "STA %pM requests BA session on unsupported tid %d\n", + sta->sta.addr, tid); + goto end; + } + + if (!sta->sta.deflink.ht_cap.ht_supported && + sta->sdata->vif.bss_conf.chandef.chan->band != NL80211_BAND_6GHZ) { + ht_dbg(sta->sdata, + "STA %pM erroneously requests BA session on tid %d w/o QoS\n", + sta->sta.addr, tid); + /* send a response anyway, it's an error case if we get here */ + goto end; + } + + if (test_sta_flag(sta, WLAN_STA_BLOCK_BA)) { + ht_dbg(sta->sdata, + "Suspend in progress - Denying ADDBA request (%pM tid %d)\n", + sta->sta.addr, tid); + goto end; + } + + if (sta->sta.deflink.eht_cap.has_eht) + max_buf_size = IEEE80211_MAX_AMPDU_BUF_EHT; + else if (sta->sta.deflink.he_cap.has_he) + max_buf_size = IEEE80211_MAX_AMPDU_BUF_HE; + else + max_buf_size = IEEE80211_MAX_AMPDU_BUF_HT; + + /* sanity check for incoming parameters: + * check if configuration can support the BA policy + * and if buffer size does not exceeds max value 
*/ + /* XXX: check own ht delayed BA capability?? */ + if (((ba_policy != 1) && + (!(sta->sta.deflink.ht_cap.cap & IEEE80211_HT_CAP_DELAY_BA))) || + (buf_size > max_buf_size)) { + status = WLAN_STATUS_INVALID_QOS_PARAM; + ht_dbg_ratelimited(sta->sdata, + "AddBA Req with bad params from %pM on tid %u. policy %d, buffer size %d\n", + sta->sta.addr, tid, ba_policy, buf_size); + goto end; + } + /* determine default buffer size */ + if (buf_size == 0) + buf_size = max_buf_size; + + /* make sure the size doesn't exceed the maximum supported by the hw */ + if (buf_size > sta->sta.max_rx_aggregation_subframes) + buf_size = sta->sta.max_rx_aggregation_subframes; + params.buf_size = buf_size; + + ht_dbg(sta->sdata, "AddBA Req buf_size=%d for %pM\n", + buf_size, sta->sta.addr); + + /* examine state machine */ + lockdep_assert_held(&sta->ampdu_mlme.mtx); + + if (test_bit(tid, sta->ampdu_mlme.agg_session_valid)) { + if (sta->ampdu_mlme.tid_rx_token[tid] == dialog_token) { + struct tid_ampdu_rx *tid_rx; + + ht_dbg_ratelimited(sta->sdata, + "updated AddBA Req from %pM on tid %u\n", + sta->sta.addr, tid); + /* We have no API to update the timeout value in the + * driver so reject the timeout update if the timeout + * changed. If it did not change, i.e., no real update, + * just reply with success. + */ + rcu_read_lock(); + tid_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); + if (tid_rx && tid_rx->timeout == timeout) + status = WLAN_STATUS_SUCCESS; + else + status = WLAN_STATUS_REQUEST_DECLINED; + rcu_read_unlock(); + goto end; + } + + ht_dbg_ratelimited(sta->sdata, + "unexpected AddBA Req from %pM on tid %u\n", + sta->sta.addr, tid); + + /* delete existing Rx BA session on the same tid */ + ___ieee80211_stop_rx_ba_session(sta, tid, WLAN_BACK_RECIPIENT, + WLAN_STATUS_UNSPECIFIED_QOS, + false); + } + + if (ieee80211_hw_check(&local->hw, SUPPORTS_REORDERING_BUFFER)) { + ret = drv_ampdu_action(local, sta->sdata, ¶ms); + ht_dbg(sta->sdata, + "Rx A-MPDU request on %pM tid %d result %d\n", + sta->sta.addr, tid, ret); + if (!ret) + status = WLAN_STATUS_SUCCESS; + goto end; + } + + /* prepare A-MPDU MLME for Rx aggregation */ + tid_agg_rx = kzalloc(sizeof(*tid_agg_rx), GFP_KERNEL); + if (!tid_agg_rx) + goto end; + + spin_lock_init(&tid_agg_rx->reorder_lock); + + /* rx timer */ + timer_setup(&tid_agg_rx->session_timer, + sta_rx_agg_session_timer_expired, TIMER_DEFERRABLE); + + /* rx reorder timer */ + timer_setup(&tid_agg_rx->reorder_timer, + sta_rx_agg_reorder_timer_expired, 0); + + /* prepare reordering buffer */ + tid_agg_rx->reorder_buf = + kcalloc(buf_size, sizeof(struct sk_buff_head), GFP_KERNEL); + tid_agg_rx->reorder_time = + kcalloc(buf_size, sizeof(unsigned long), GFP_KERNEL); + if (!tid_agg_rx->reorder_buf || !tid_agg_rx->reorder_time) { + kfree(tid_agg_rx->reorder_buf); + kfree(tid_agg_rx->reorder_time); + kfree(tid_agg_rx); + goto end; + } + + for (i = 0; i < buf_size; i++) + __skb_queue_head_init(&tid_agg_rx->reorder_buf[i]); + + ret = drv_ampdu_action(local, sta->sdata, ¶ms); + ht_dbg(sta->sdata, "Rx A-MPDU request on %pM tid %d result %d\n", + sta->sta.addr, tid, ret); + if (ret) { + kfree(tid_agg_rx->reorder_buf); + kfree(tid_agg_rx->reorder_time); + kfree(tid_agg_rx); + goto end; + } + + /* update data */ + tid_agg_rx->ssn = start_seq_num; + tid_agg_rx->head_seq_num = start_seq_num; + tid_agg_rx->buf_size = buf_size; + tid_agg_rx->timeout = timeout; + tid_agg_rx->stored_mpdu_num = 0; + tid_agg_rx->auto_seq = auto_seq; + tid_agg_rx->started = false; + tid_agg_rx->reorder_buf_filtered = 0; + 
tid_agg_rx->tid = tid; + tid_agg_rx->sta = sta; + status = WLAN_STATUS_SUCCESS; + + /* activate it for RX */ + rcu_assign_pointer(sta->ampdu_mlme.tid_rx[tid], tid_agg_rx); + + if (timeout) { + mod_timer(&tid_agg_rx->session_timer, TU_TO_EXP_TIME(timeout)); + tid_agg_rx->last_rx = jiffies; + } + +end: + if (status == WLAN_STATUS_SUCCESS) { + __set_bit(tid, sta->ampdu_mlme.agg_session_valid); + __clear_bit(tid, sta->ampdu_mlme.unexpected_agg); + sta->ampdu_mlme.tid_rx_token[tid] = dialog_token; + } + + if (tx) + ieee80211_send_addba_resp(sta, sta->sta.addr, tid, + dialog_token, status, 1, buf_size, + timeout, addbaext); +} + +static void __ieee80211_start_rx_ba_session(struct sta_info *sta, + u8 dialog_token, u16 timeout, + u16 start_seq_num, u16 ba_policy, + u16 tid, u16 buf_size, bool tx, + bool auto_seq, + const struct ieee80211_addba_ext_ie *addbaext) +{ + mutex_lock(&sta->ampdu_mlme.mtx); + ___ieee80211_start_rx_ba_session(sta, dialog_token, timeout, + start_seq_num, ba_policy, tid, + buf_size, tx, auto_seq, addbaext); + mutex_unlock(&sta->ampdu_mlme.mtx); +} + +void ieee80211_process_addba_request(struct ieee80211_local *local, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + u16 capab, tid, timeout, ba_policy, buf_size, start_seq_num; + struct ieee802_11_elems *elems = NULL; + u8 dialog_token; + int ies_len; + + /* extract session parameters from addba request frame */ + dialog_token = mgmt->u.action.u.addba_req.dialog_token; + timeout = le16_to_cpu(mgmt->u.action.u.addba_req.timeout); + start_seq_num = + le16_to_cpu(mgmt->u.action.u.addba_req.start_seq_num) >> 4; + + capab = le16_to_cpu(mgmt->u.action.u.addba_req.capab); + ba_policy = (capab & IEEE80211_ADDBA_PARAM_POLICY_MASK) >> 1; + tid = (capab & IEEE80211_ADDBA_PARAM_TID_MASK) >> 2; + buf_size = (capab & IEEE80211_ADDBA_PARAM_BUF_SIZE_MASK) >> 6; + + ies_len = len - offsetof(struct ieee80211_mgmt, + u.action.u.addba_req.variable); + if (ies_len) { + elems = ieee802_11_parse_elems(mgmt->u.action.u.addba_req.variable, + ies_len, true, NULL); + if (!elems || elems->parse_error) + goto free; + } + + if (sta->sta.deflink.eht_cap.has_eht && elems && elems->addba_ext_ie) { + u8 buf_size_1k = u8_get_bits(elems->addba_ext_ie->data, + IEEE80211_ADDBA_EXT_BUF_SIZE_MASK); + + buf_size |= buf_size_1k << IEEE80211_ADDBA_EXT_BUF_SIZE_SHIFT; + } + + __ieee80211_start_rx_ba_session(sta, dialog_token, timeout, + start_seq_num, ba_policy, tid, + buf_size, true, false, + elems ? 
elems->addba_ext_ie : NULL); +free: + kfree(elems); +} + +void ieee80211_manage_rx_ba_offl(struct ieee80211_vif *vif, + const u8 *addr, unsigned int tid) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + rcu_read_lock(); + sta = sta_info_get_bss(sdata, addr); + if (!sta) + goto unlock; + + set_bit(tid, sta->ampdu_mlme.tid_rx_manage_offl); + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + unlock: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_manage_rx_ba_offl); + +void ieee80211_rx_ba_timer_expired(struct ieee80211_vif *vif, + const u8 *addr, unsigned int tid) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + rcu_read_lock(); + sta = sta_info_get_bss(sdata, addr); + if (!sta) + goto unlock; + + set_bit(tid, sta->ampdu_mlme.tid_rx_timer_expired); + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + + unlock: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_rx_ba_timer_expired); diff --git a/net/mac80211/agg-tx.c b/net/mac80211/agg-tx.c new file mode 100644 index 000000000..85d2b9e4b --- /dev/null +++ b/net/mac80211/agg-tx.c @@ -0,0 +1,1044 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * HT handling + * + * Copyright 2003, Jouni Malinen + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2007-2010, Intel Corporation + * Copyright(c) 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018 - 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "wme.h" + +/** + * DOC: TX A-MPDU aggregation + * + * Aggregation on the TX side requires setting the hardware flag + * %IEEE80211_HW_AMPDU_AGGREGATION. The driver will then be handed + * packets with a flag indicating A-MPDU aggregation. The driver + * or device is responsible for actually aggregating the frames, + * as well as deciding how many and which to aggregate. + * + * When TX aggregation is started by some subsystem (usually the rate + * control algorithm would be appropriate) by calling the + * ieee80211_start_tx_ba_session() function, the driver will be + * notified via its @ampdu_action function, with the + * %IEEE80211_AMPDU_TX_START action. + * + * In response to that, the driver is later required to call the + * ieee80211_start_tx_ba_cb_irqsafe() function, which will really + * start the aggregation session after the peer has also responded. + * If the peer responds negatively, the session will be stopped + * again right away. Note that it is possible for the aggregation + * session to be stopped before the driver has indicated that it + * is done setting it up, in which case it must not indicate the + * setup completion. + * + * Also note that, since we also need to wait for a response from + * the peer, the driver is notified of the completion of the + * handshake by the %IEEE80211_AMPDU_TX_OPERATIONAL action to the + * @ampdu_action callback. + * + * Similarly, when the aggregation session is stopped by the peer + * or something calling ieee80211_stop_tx_ba_session(), the driver's + * @ampdu_action function will be called with the action + * %IEEE80211_AMPDU_TX_STOP. In this case, the call must not fail, + * and the driver must later call ieee80211_stop_tx_ba_cb_irqsafe(). 
+ * Note that the sta can get destroyed before the BA tear down is + * complete. + */ + +static void ieee80211_send_addba_request(struct ieee80211_sub_if_data *sdata, + const u8 *da, u16 tid, + u8 dialog_token, u16 start_seq_num, + u16 agg_size, u16 timeout) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + u16 capab; + + skb = dev_alloc_skb(sizeof(*mgmt) + local->hw.extra_tx_headroom); + + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + mgmt = skb_put_zero(skb, 24); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + if (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + sdata->vif.type == NL80211_IFTYPE_MESH_POINT) + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_STATION) + memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + memcpy(mgmt->bssid, sdata->u.ibss.bssid, ETH_ALEN); + + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + skb_put(skb, 1 + sizeof(mgmt->u.action.u.addba_req)); + + mgmt->u.action.category = WLAN_CATEGORY_BACK; + mgmt->u.action.u.addba_req.action_code = WLAN_ACTION_ADDBA_REQ; + + mgmt->u.action.u.addba_req.dialog_token = dialog_token; + capab = IEEE80211_ADDBA_PARAM_AMSDU_MASK; + capab |= IEEE80211_ADDBA_PARAM_POLICY_MASK; + capab |= u16_encode_bits(tid, IEEE80211_ADDBA_PARAM_TID_MASK); + capab |= u16_encode_bits(agg_size, IEEE80211_ADDBA_PARAM_BUF_SIZE_MASK); + + mgmt->u.action.u.addba_req.capab = cpu_to_le16(capab); + + mgmt->u.action.u.addba_req.timeout = cpu_to_le16(timeout); + mgmt->u.action.u.addba_req.start_seq_num = + cpu_to_le16(start_seq_num << 4); + + ieee80211_tx_skb_tid(sdata, skb, tid, -1); +} + +void ieee80211_send_bar(struct ieee80211_vif *vif, u8 *ra, u16 tid, u16 ssn) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_bar *bar; + u16 bar_control = 0; + + skb = dev_alloc_skb(sizeof(*bar) + local->hw.extra_tx_headroom); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + bar = skb_put_zero(skb, sizeof(*bar)); + bar->frame_control = cpu_to_le16(IEEE80211_FTYPE_CTL | + IEEE80211_STYPE_BACK_REQ); + memcpy(bar->ra, ra, ETH_ALEN); + memcpy(bar->ta, sdata->vif.addr, ETH_ALEN); + bar_control |= (u16)IEEE80211_BAR_CTRL_ACK_POLICY_NORMAL; + bar_control |= (u16)IEEE80211_BAR_CTRL_CBMTID_COMPRESSED_BA; + bar_control |= (u16)(tid << IEEE80211_BAR_CTRL_TID_INFO_SHIFT); + bar->control = cpu_to_le16(bar_control); + bar->start_seq_num = cpu_to_le16(ssn); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + IEEE80211_TX_CTL_REQ_TX_STATUS; + ieee80211_tx_skb_tid(sdata, skb, tid, -1); +} +EXPORT_SYMBOL(ieee80211_send_bar); + +void ieee80211_assign_tid_tx(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx) +{ + lockdep_assert_held(&sta->ampdu_mlme.mtx); + lockdep_assert_held(&sta->lock); + rcu_assign_pointer(sta->ampdu_mlme.tid_tx[tid], tid_tx); +} + +/* + * When multiple aggregation sessions on multiple stations + * are being created/destroyed simultaneously, we need to + * refcount the global queue stop caused by that in order + * to not get into a situation where one of the aggregation + * setup or teardown re-enables queues before the other is + * ready to handle that. 
+ * + * These two functions take care of this issue by keeping + * a global "agg_queue_stop" refcount. + */ +static void __acquires(agg_queue) +ieee80211_stop_queue_agg(struct ieee80211_sub_if_data *sdata, int tid) +{ + int queue = sdata->vif.hw_queue[ieee80211_ac_from_tid(tid)]; + + /* we do refcounting here, so don't use the queue reason refcounting */ + + if (atomic_inc_return(&sdata->local->agg_queue_stop[queue]) == 1) + ieee80211_stop_queue_by_reason( + &sdata->local->hw, queue, + IEEE80211_QUEUE_STOP_REASON_AGGREGATION, + false); + __acquire(agg_queue); +} + +static void __releases(agg_queue) +ieee80211_wake_queue_agg(struct ieee80211_sub_if_data *sdata, int tid) +{ + int queue = sdata->vif.hw_queue[ieee80211_ac_from_tid(tid)]; + + if (atomic_dec_return(&sdata->local->agg_queue_stop[queue]) == 0) + ieee80211_wake_queue_by_reason( + &sdata->local->hw, queue, + IEEE80211_QUEUE_STOP_REASON_AGGREGATION, + false); + __release(agg_queue); +} + +static void +ieee80211_agg_stop_txq(struct sta_info *sta, int tid) +{ + struct ieee80211_txq *txq = sta->sta.txq[tid]; + struct ieee80211_sub_if_data *sdata; + struct fq *fq; + struct txq_info *txqi; + + if (!txq) + return; + + txqi = to_txq_info(txq); + sdata = vif_to_sdata(txq->vif); + fq = &sdata->local->fq; + + /* Lock here to protect against further seqno updates on dequeue */ + spin_lock_bh(&fq->lock); + set_bit(IEEE80211_TXQ_STOP, &txqi->flags); + spin_unlock_bh(&fq->lock); +} + +static void +ieee80211_agg_start_txq(struct sta_info *sta, int tid, bool enable) +{ + struct ieee80211_txq *txq = sta->sta.txq[tid]; + struct txq_info *txqi; + + lockdep_assert_held(&sta->ampdu_mlme.mtx); + + if (!txq) + return; + + txqi = to_txq_info(txq); + + if (enable) + set_bit(IEEE80211_TXQ_AMPDU, &txqi->flags); + else + clear_bit(IEEE80211_TXQ_AMPDU, &txqi->flags); + + clear_bit(IEEE80211_TXQ_STOP, &txqi->flags); + local_bh_disable(); + rcu_read_lock(); + schedule_and_wake_txq(sta->sdata->local, txqi); + rcu_read_unlock(); + local_bh_enable(); +} + +/* + * splice packets from the STA's pending to the local pending, + * requires a call to ieee80211_agg_splice_finish later + */ +static void __acquires(agg_queue) +ieee80211_agg_splice_packets(struct ieee80211_sub_if_data *sdata, + struct tid_ampdu_tx *tid_tx, u16 tid) +{ + struct ieee80211_local *local = sdata->local; + int queue = sdata->vif.hw_queue[ieee80211_ac_from_tid(tid)]; + unsigned long flags; + + ieee80211_stop_queue_agg(sdata, tid); + + if (WARN(!tid_tx, + "TID %d gone but expected when splicing aggregates from the pending queue\n", + tid)) + return; + + if (!skb_queue_empty(&tid_tx->pending)) { + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + /* copy over remaining packets */ + skb_queue_splice_tail_init(&tid_tx->pending, + &local->pending[queue]); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + } +} + +static void __releases(agg_queue) +ieee80211_agg_splice_finish(struct ieee80211_sub_if_data *sdata, u16 tid) +{ + ieee80211_wake_queue_agg(sdata, tid); +} + +static void ieee80211_remove_tid_tx(struct sta_info *sta, int tid) +{ + struct tid_ampdu_tx *tid_tx; + + lockdep_assert_held(&sta->ampdu_mlme.mtx); + lockdep_assert_held(&sta->lock); + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + + /* + * When we get here, the TX path will not be lockless any more wrt. + * aggregation, since the OPERATIONAL bit has long been cleared. + * Thus it will block on getting the lock, if it occurs. 
So if we + * stop the queue now, we will not get any more packets, and any + * that might be being processed will wait for us here, thereby + * guaranteeing that no packets go to the tid_tx pending queue any + * more. + */ + + ieee80211_agg_splice_packets(sta->sdata, tid_tx, tid); + + /* future packets must not find the tid_tx struct any more */ + ieee80211_assign_tid_tx(sta, tid, NULL); + + ieee80211_agg_splice_finish(sta->sdata, tid); + + kfree_rcu(tid_tx, rcu_head); +} + +int ___ieee80211_stop_tx_ba_session(struct sta_info *sta, u16 tid, + enum ieee80211_agg_stop_reason reason) +{ + struct ieee80211_local *local = sta->local; + struct tid_ampdu_tx *tid_tx; + struct ieee80211_ampdu_params params = { + .sta = &sta->sta, + .tid = tid, + .buf_size = 0, + .amsdu = false, + .timeout = 0, + .ssn = 0, + }; + int ret; + + lockdep_assert_held(&sta->ampdu_mlme.mtx); + + switch (reason) { + case AGG_STOP_DECLINED: + case AGG_STOP_LOCAL_REQUEST: + case AGG_STOP_PEER_REQUEST: + params.action = IEEE80211_AMPDU_TX_STOP_CONT; + break; + case AGG_STOP_DESTROY_STA: + params.action = IEEE80211_AMPDU_TX_STOP_FLUSH; + break; + default: + WARN_ON_ONCE(1); + return -EINVAL; + } + + spin_lock_bh(&sta->lock); + + /* free struct pending for start, if present */ + tid_tx = sta->ampdu_mlme.tid_start_tx[tid]; + kfree(tid_tx); + sta->ampdu_mlme.tid_start_tx[tid] = NULL; + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + if (!tid_tx) { + spin_unlock_bh(&sta->lock); + return -ENOENT; + } + + /* + * if we're already stopping ignore any new requests to stop + * unless we're destroying it in which case notify the driver + */ + if (test_bit(HT_AGG_STATE_STOPPING, &tid_tx->state)) { + spin_unlock_bh(&sta->lock); + if (reason != AGG_STOP_DESTROY_STA) + return -EALREADY; + params.action = IEEE80211_AMPDU_TX_STOP_FLUSH_CONT; + ret = drv_ampdu_action(local, sta->sdata, ¶ms); + WARN_ON_ONCE(ret); + return 0; + } + + if (test_bit(HT_AGG_STATE_WANT_START, &tid_tx->state)) { + /* not even started yet! */ + ieee80211_assign_tid_tx(sta, tid, NULL); + spin_unlock_bh(&sta->lock); + kfree_rcu(tid_tx, rcu_head); + return 0; + } + + set_bit(HT_AGG_STATE_STOPPING, &tid_tx->state); + + ieee80211_agg_stop_txq(sta, tid); + + spin_unlock_bh(&sta->lock); + + ht_dbg(sta->sdata, "Tx BA session stop requested for %pM tid %u\n", + sta->sta.addr, tid); + + del_timer_sync(&tid_tx->addba_resp_timer); + del_timer_sync(&tid_tx->session_timer); + + /* + * After this packets are no longer handed right through + * to the driver but are put onto tid_tx->pending instead, + * with locking to ensure proper access. + */ + clear_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state); + + /* + * There might be a few packets being processed right now (on + * another CPU) that have already gotten past the aggregation + * check when it was still OPERATIONAL and consequently have + * IEEE80211_TX_CTL_AMPDU set. In that case, this code might + * call into the driver at the same time or even before the + * TX paths calls into it, which could confuse the driver. + * + * Wait for all currently running TX paths to finish before + * telling the driver. New packets will not go through since + * the aggregation session is no longer OPERATIONAL. + */ + if (!local->in_reconfig) + synchronize_net(); + + tid_tx->stop_initiator = reason == AGG_STOP_PEER_REQUEST ? 
+ WLAN_BACK_RECIPIENT : + WLAN_BACK_INITIATOR; + tid_tx->tx_stop = reason == AGG_STOP_LOCAL_REQUEST; + + ret = drv_ampdu_action(local, sta->sdata, ¶ms); + + /* HW shall not deny going back to legacy */ + if (WARN_ON(ret)) { + /* + * We may have pending packets get stuck in this case... + * Not bothering with a workaround for now. + */ + } + + /* + * In the case of AGG_STOP_DESTROY_STA, the driver won't + * necessarily call ieee80211_stop_tx_ba_cb(), so this may + * seem like we can leave the tid_tx data pending forever. + * This is true, in a way, but "forever" is only until the + * station struct is actually destroyed. In the meantime, + * leaving it around ensures that we don't transmit packets + * to the driver on this TID which might confuse it. + */ + + return 0; +} + +/* + * After sending add Block Ack request we activated a timer until + * add Block Ack response will arrive from the recipient. + * If this timer expires sta_addba_resp_timer_expired will be executed. + */ +static void sta_addba_resp_timer_expired(struct timer_list *t) +{ + struct tid_ampdu_tx *tid_tx = from_timer(tid_tx, t, addba_resp_timer); + struct sta_info *sta = tid_tx->sta; + u8 tid = tid_tx->tid; + + /* check if the TID waits for addBA response */ + if (test_bit(HT_AGG_STATE_RESPONSE_RECEIVED, &tid_tx->state)) { + ht_dbg(sta->sdata, + "timer expired on %pM tid %d not expecting addBA response\n", + sta->sta.addr, tid); + return; + } + + ht_dbg(sta->sdata, "addBA response timer expired on %pM tid %d\n", + sta->sta.addr, tid); + + ieee80211_stop_tx_ba_session(&sta->sta, tid); +} + +static void ieee80211_send_addba_with_timeout(struct sta_info *sta, + struct tid_ampdu_tx *tid_tx) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sta->local; + u8 tid = tid_tx->tid; + u16 buf_size; + + /* activate the timer for the recipient's addBA response */ + mod_timer(&tid_tx->addba_resp_timer, jiffies + ADDBA_RESP_INTERVAL); + ht_dbg(sdata, "activated addBA response timer on %pM tid %d\n", + sta->sta.addr, tid); + + spin_lock_bh(&sta->lock); + sta->ampdu_mlme.last_addba_req_time[tid] = jiffies; + sta->ampdu_mlme.addba_req_num[tid]++; + spin_unlock_bh(&sta->lock); + + if (sta->sta.deflink.he_cap.has_he) { + buf_size = local->hw.max_tx_aggregation_subframes; + } else { + /* + * We really should use what the driver told us it will + * transmit as the maximum, but certain APs (e.g. the + * LinkSys WRT120N with FW v1.0.07 build 002 Jun 18 2012) + * will crash when we use a lower number. + */ + buf_size = IEEE80211_MAX_AMPDU_BUF_HT; + } + + /* send AddBA request */ + ieee80211_send_addba_request(sdata, sta->sta.addr, tid, + tid_tx->dialog_token, tid_tx->ssn, + buf_size, tid_tx->timeout); + + WARN_ON(test_and_set_bit(HT_AGG_STATE_SENT_ADDBA, &tid_tx->state)); +} + +void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid) +{ + struct tid_ampdu_tx *tid_tx; + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_ampdu_params params = { + .sta = &sta->sta, + .action = IEEE80211_AMPDU_TX_START, + .tid = tid, + .buf_size = 0, + .amsdu = false, + .timeout = 0, + }; + int ret; + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + + /* + * Start queuing up packets for this aggregation session. + * We're going to release them once the driver is OK with + * that. + */ + clear_bit(HT_AGG_STATE_WANT_START, &tid_tx->state); + + /* + * Make sure no packets are being processed. 
This ensures that + * we have a valid starting sequence number and that in-flight + * packets have been flushed out and no packets for this TID + * will go into the driver during the ampdu_action call. + */ + synchronize_net(); + + sdata = sta->sdata; + params.ssn = sta->tid_seq[tid] >> 4; + ret = drv_ampdu_action(local, sdata, ¶ms); + tid_tx->ssn = params.ssn; + if (ret == IEEE80211_AMPDU_TX_START_DELAY_ADDBA) { + return; + } else if (ret == IEEE80211_AMPDU_TX_START_IMMEDIATE) { + /* + * We didn't send the request yet, so don't need to check + * here if we already got a response, just mark as driver + * ready immediately. + */ + set_bit(HT_AGG_STATE_DRV_READY, &tid_tx->state); + } else if (ret) { + if (!sdata) + return; + + ht_dbg(sdata, + "BA request denied - HW unavailable for %pM tid %d\n", + sta->sta.addr, tid); + spin_lock_bh(&sta->lock); + ieee80211_agg_splice_packets(sdata, tid_tx, tid); + ieee80211_assign_tid_tx(sta, tid, NULL); + ieee80211_agg_splice_finish(sdata, tid); + spin_unlock_bh(&sta->lock); + + ieee80211_agg_start_txq(sta, tid, false); + + kfree_rcu(tid_tx, rcu_head); + return; + } + + ieee80211_send_addba_with_timeout(sta, tid_tx); +} + +/* + * After accepting the AddBA Response we activated a timer, + * resetting it after each frame that we send. + */ +static void sta_tx_agg_session_timer_expired(struct timer_list *t) +{ + struct tid_ampdu_tx *tid_tx = from_timer(tid_tx, t, session_timer); + struct sta_info *sta = tid_tx->sta; + u8 tid = tid_tx->tid; + unsigned long timeout; + + if (test_bit(HT_AGG_STATE_STOPPING, &tid_tx->state)) { + return; + } + + timeout = tid_tx->last_tx + TU_TO_JIFFIES(tid_tx->timeout); + if (time_is_after_jiffies(timeout)) { + mod_timer(&tid_tx->session_timer, timeout); + return; + } + + ht_dbg(sta->sdata, "tx session timer expired on %pM tid %d\n", + sta->sta.addr, tid); + + ieee80211_stop_tx_ba_session(&sta->sta, tid); +} + +int ieee80211_start_tx_ba_session(struct ieee80211_sta *pubsta, u16 tid, + u16 timeout) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct tid_ampdu_tx *tid_tx; + int ret = 0; + + trace_api_start_tx_ba_session(pubsta, tid); + + if (WARN(sta->reserved_tid == tid, + "Requested to start BA session on reserved tid=%d", tid)) + return -EINVAL; + + if (!pubsta->deflink.ht_cap.ht_supported && + sta->sdata->vif.bss_conf.chandef.chan->band != NL80211_BAND_6GHZ) + return -EINVAL; + + if (WARN_ON_ONCE(!local->ops->ampdu_action)) + return -EINVAL; + + if ((tid >= IEEE80211_NUM_TIDS) || + !ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION) || + ieee80211_hw_check(&local->hw, TX_AMPDU_SETUP_IN_HW)) + return -EINVAL; + + if (WARN_ON(tid >= IEEE80211_FIRST_TSPEC_TSID)) + return -EINVAL; + + ht_dbg(sdata, "Open BA session requested for %pM tid %u\n", + pubsta->addr, tid); + + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_ADHOC) + return -EINVAL; + + if (test_sta_flag(sta, WLAN_STA_BLOCK_BA)) { + ht_dbg(sdata, + "BA sessions blocked - Denying BA session request %pM tid %d\n", + sta->sta.addr, tid); + return -EINVAL; + } + + if (test_sta_flag(sta, WLAN_STA_MFP) && + !test_sta_flag(sta, WLAN_STA_AUTHORIZED)) { + ht_dbg(sdata, + "MFP STA not authorized - deny BA session request %pM tid %d\n", + sta->sta.addr, tid); + return -EINVAL; + } + + 
/* + * 802.11n-2009 11.5.1.1: If the initiating STA is an HT STA, is a + * member of an IBSS, and has no other existing Block Ack agreement + * with the recipient STA, then the initiating STA shall transmit a + * Probe Request frame to the recipient STA and shall not transmit an + * ADDBA Request frame unless it receives a Probe Response frame + * from the recipient within dot11ADDBAFailureTimeout. + * + * The probe request mechanism for ADDBA is currently not implemented, + * but we only build up Block Ack session with HT STAs. This information + * is set when we receive a bss info from a probe response or a beacon. + */ + if (sta->sdata->vif.type == NL80211_IFTYPE_ADHOC && + !sta->sta.deflink.ht_cap.ht_supported) { + ht_dbg(sdata, + "BA request denied - IBSS STA %pM does not advertise HT support\n", + pubsta->addr); + return -EINVAL; + } + + spin_lock_bh(&sta->lock); + + /* we have tried too many times, receiver does not want A-MPDU */ + if (sta->ampdu_mlme.addba_req_num[tid] > HT_AGG_MAX_RETRIES) { + ret = -EBUSY; + goto err_unlock_sta; + } + + /* + * if we have tried more than HT_AGG_BURST_RETRIES times we + * will spread our requests in time to avoid stalling connection + * for too long + */ + if (sta->ampdu_mlme.addba_req_num[tid] > HT_AGG_BURST_RETRIES && + time_before(jiffies, sta->ampdu_mlme.last_addba_req_time[tid] + + HT_AGG_RETRIES_PERIOD)) { + ht_dbg(sdata, + "BA request denied - %d failed requests on %pM tid %u\n", + sta->ampdu_mlme.addba_req_num[tid], sta->sta.addr, tid); + ret = -EBUSY; + goto err_unlock_sta; + } + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + /* check if the TID is not in aggregation flow already */ + if (tid_tx || sta->ampdu_mlme.tid_start_tx[tid]) { + ht_dbg(sdata, + "BA request denied - session is not idle on %pM tid %u\n", + sta->sta.addr, tid); + ret = -EAGAIN; + goto err_unlock_sta; + } + + /* prepare A-MPDU MLME for Tx aggregation */ + tid_tx = kzalloc(sizeof(struct tid_ampdu_tx), GFP_ATOMIC); + if (!tid_tx) { + ret = -ENOMEM; + goto err_unlock_sta; + } + + skb_queue_head_init(&tid_tx->pending); + __set_bit(HT_AGG_STATE_WANT_START, &tid_tx->state); + + tid_tx->timeout = timeout; + tid_tx->sta = sta; + tid_tx->tid = tid; + + /* response timer */ + timer_setup(&tid_tx->addba_resp_timer, sta_addba_resp_timer_expired, 0); + + /* tx timer */ + timer_setup(&tid_tx->session_timer, + sta_tx_agg_session_timer_expired, TIMER_DEFERRABLE); + + /* assign a dialog token */ + sta->ampdu_mlme.dialog_token_allocator++; + tid_tx->dialog_token = sta->ampdu_mlme.dialog_token_allocator; + + /* + * Finally, assign it to the start array; the work item will + * collect it and move it to the normal array. 
+ */ + sta->ampdu_mlme.tid_start_tx[tid] = tid_tx; + + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + + /* this flow continues off the work */ + err_unlock_sta: + spin_unlock_bh(&sta->lock); + return ret; +} +EXPORT_SYMBOL(ieee80211_start_tx_ba_session); + +static void ieee80211_agg_tx_operational(struct ieee80211_local *local, + struct sta_info *sta, u16 tid) +{ + struct tid_ampdu_tx *tid_tx; + struct ieee80211_ampdu_params params = { + .sta = &sta->sta, + .action = IEEE80211_AMPDU_TX_OPERATIONAL, + .tid = tid, + .timeout = 0, + .ssn = 0, + }; + + lockdep_assert_held(&sta->ampdu_mlme.mtx); + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + params.buf_size = tid_tx->buf_size; + params.amsdu = tid_tx->amsdu; + + ht_dbg(sta->sdata, "Aggregation is on for %pM tid %d\n", + sta->sta.addr, tid); + + drv_ampdu_action(local, sta->sdata, ¶ms); + + /* + * synchronize with TX path, while splicing the TX path + * should block so it won't put more packets onto pending. + */ + spin_lock_bh(&sta->lock); + + ieee80211_agg_splice_packets(sta->sdata, tid_tx, tid); + /* + * Now mark as operational. This will be visible + * in the TX path, and lets it go lock-free in + * the common case. + */ + set_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state); + ieee80211_agg_splice_finish(sta->sdata, tid); + + spin_unlock_bh(&sta->lock); + + ieee80211_agg_start_txq(sta, tid, true); +} + +void ieee80211_start_tx_ba_cb(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + + if (WARN_ON(test_and_set_bit(HT_AGG_STATE_DRV_READY, &tid_tx->state))) + return; + + if (!test_bit(HT_AGG_STATE_SENT_ADDBA, &tid_tx->state)) { + ieee80211_send_addba_with_timeout(sta, tid_tx); + /* RESPONSE_RECEIVED state whould trigger the flow again */ + return; + } + + if (test_bit(HT_AGG_STATE_RESPONSE_RECEIVED, &tid_tx->state)) + ieee80211_agg_tx_operational(local, sta, tid); +} + +static struct tid_ampdu_tx * +ieee80211_lookup_tid_tx(struct ieee80211_sub_if_data *sdata, + const u8 *ra, u16 tid, struct sta_info **sta) +{ + struct tid_ampdu_tx *tid_tx; + + if (tid >= IEEE80211_NUM_TIDS) { + ht_dbg(sdata, "Bad TID value: tid = %d (>= %d)\n", + tid, IEEE80211_NUM_TIDS); + return NULL; + } + + *sta = sta_info_get_bss(sdata, ra); + if (!*sta) { + ht_dbg(sdata, "Could not find station: %pM\n", ra); + return NULL; + } + + tid_tx = rcu_dereference((*sta)->ampdu_mlme.tid_tx[tid]); + + if (WARN_ON(!tid_tx)) + ht_dbg(sdata, "addBA was not requested!\n"); + + return tid_tx; +} + +void ieee80211_start_tx_ba_cb_irqsafe(struct ieee80211_vif *vif, + const u8 *ra, u16 tid) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct tid_ampdu_tx *tid_tx; + + trace_api_start_tx_ba_cb(sdata, ra, tid); + + rcu_read_lock(); + tid_tx = ieee80211_lookup_tid_tx(sdata, ra, tid, &sta); + if (!tid_tx) + goto out; + + set_bit(HT_AGG_STATE_START_CB, &tid_tx->state); + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_start_tx_ba_cb_irqsafe); + +int __ieee80211_stop_tx_ba_session(struct sta_info *sta, u16 tid, + enum ieee80211_agg_stop_reason reason) +{ + int ret; + + mutex_lock(&sta->ampdu_mlme.mtx); + + ret = ___ieee80211_stop_tx_ba_session(sta, tid, reason); + + mutex_unlock(&sta->ampdu_mlme.mtx); + + return ret; +} + +int ieee80211_stop_tx_ba_session(struct ieee80211_sta *pubsta, u16 tid) +{ + struct 
sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct tid_ampdu_tx *tid_tx; + int ret = 0; + + trace_api_stop_tx_ba_session(pubsta, tid); + + if (!local->ops->ampdu_action) + return -EINVAL; + + if (tid >= IEEE80211_NUM_TIDS) + return -EINVAL; + + spin_lock_bh(&sta->lock); + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + + if (!tid_tx) { + ret = -ENOENT; + goto unlock; + } + + WARN(sta->reserved_tid == tid, + "Requested to stop BA session on reserved tid=%d", tid); + + if (test_bit(HT_AGG_STATE_STOPPING, &tid_tx->state)) { + /* already in progress stopping it */ + ret = 0; + goto unlock; + } + + set_bit(HT_AGG_STATE_WANT_STOP, &tid_tx->state); + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + + unlock: + spin_unlock_bh(&sta->lock); + return ret; +} +EXPORT_SYMBOL(ieee80211_stop_tx_ba_session); + +void ieee80211_stop_tx_ba_cb(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + bool send_delba = false; + bool start_txq = false; + + ht_dbg(sdata, "Stopping Tx BA session for %pM tid %d\n", + sta->sta.addr, tid); + + spin_lock_bh(&sta->lock); + + if (!test_bit(HT_AGG_STATE_STOPPING, &tid_tx->state)) { + ht_dbg(sdata, + "unexpected callback to A-MPDU stop for %pM tid %d\n", + sta->sta.addr, tid); + goto unlock_sta; + } + + if (tid_tx->stop_initiator == WLAN_BACK_INITIATOR && tid_tx->tx_stop) + send_delba = true; + + ieee80211_remove_tid_tx(sta, tid); + start_txq = true; + + unlock_sta: + spin_unlock_bh(&sta->lock); + + if (start_txq) + ieee80211_agg_start_txq(sta, tid, false); + + if (send_delba) + ieee80211_send_delba(sdata, sta->sta.addr, tid, + WLAN_BACK_INITIATOR, WLAN_REASON_QSTA_NOT_USE); +} + +void ieee80211_stop_tx_ba_cb_irqsafe(struct ieee80211_vif *vif, + const u8 *ra, u16 tid) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct tid_ampdu_tx *tid_tx; + + trace_api_stop_tx_ba_cb(sdata, ra, tid); + + rcu_read_lock(); + tid_tx = ieee80211_lookup_tid_tx(sdata, ra, tid, &sta); + if (!tid_tx) + goto out; + + set_bit(HT_AGG_STATE_STOP_CB, &tid_tx->state); + ieee80211_queue_work(&local->hw, &sta->ampdu_mlme.work); + out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_stop_tx_ba_cb_irqsafe); + + +void ieee80211_process_addba_resp(struct ieee80211_local *local, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + struct tid_ampdu_tx *tid_tx; + struct ieee80211_txq *txq; + u16 capab, tid, buf_size; + bool amsdu; + + capab = le16_to_cpu(mgmt->u.action.u.addba_resp.capab); + amsdu = capab & IEEE80211_ADDBA_PARAM_AMSDU_MASK; + tid = u16_get_bits(capab, IEEE80211_ADDBA_PARAM_TID_MASK); + buf_size = u16_get_bits(capab, IEEE80211_ADDBA_PARAM_BUF_SIZE_MASK); + buf_size = min(buf_size, local->hw.max_tx_aggregation_subframes); + + txq = sta->sta.txq[tid]; + if (!amsdu && txq) + set_bit(IEEE80211_TXQ_NO_AMSDU, &to_txq_info(txq)->flags); + + mutex_lock(&sta->ampdu_mlme.mtx); + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + if (!tid_tx) + goto out; + + if (mgmt->u.action.u.addba_resp.dialog_token != tid_tx->dialog_token) { + ht_dbg(sta->sdata, "wrong addBA response token, %pM tid %d\n", + sta->sta.addr, tid); + goto out; + } + + del_timer_sync(&tid_tx->addba_resp_timer); + + ht_dbg(sta->sdata, "switched off addBA timer for %pM tid %d\n", + sta->sta.addr, tid); + + /* + * addba_resp_timer 
may have fired before we got here, and + * caused WANT_STOP to be set. If the stop then was already + * processed further, STOPPING might be set. + */ + if (test_bit(HT_AGG_STATE_WANT_STOP, &tid_tx->state) || + test_bit(HT_AGG_STATE_STOPPING, &tid_tx->state)) { + ht_dbg(sta->sdata, + "got addBA resp for %pM tid %d but we already gave up\n", + sta->sta.addr, tid); + goto out; + } + + /* + * IEEE 802.11-2007 7.3.1.14: + * In an ADDBA Response frame, when the Status Code field + * is set to 0, the Buffer Size subfield is set to a value + * of at least 1. + */ + if (le16_to_cpu(mgmt->u.action.u.addba_resp.status) + == WLAN_STATUS_SUCCESS && buf_size) { + if (test_and_set_bit(HT_AGG_STATE_RESPONSE_RECEIVED, + &tid_tx->state)) { + /* ignore duplicate response */ + goto out; + } + + tid_tx->buf_size = buf_size; + tid_tx->amsdu = amsdu; + + if (test_bit(HT_AGG_STATE_DRV_READY, &tid_tx->state)) + ieee80211_agg_tx_operational(local, sta, tid); + + sta->ampdu_mlme.addba_req_num[tid] = 0; + + tid_tx->timeout = + le16_to_cpu(mgmt->u.action.u.addba_resp.timeout); + + if (tid_tx->timeout) { + mod_timer(&tid_tx->session_timer, + TU_TO_EXP_TIME(tid_tx->timeout)); + tid_tx->last_tx = jiffies; + } + + } else { + ___ieee80211_stop_tx_ba_session(sta, tid, AGG_STOP_DECLINED); + } + + out: + mutex_unlock(&sta->ampdu_mlme.mtx); +} diff --git a/net/mac80211/airtime.c b/net/mac80211/airtime.c new file mode 100644 index 000000000..e8ebd343e --- /dev/null +++ b/net/mac80211/airtime.c @@ -0,0 +1,711 @@ +// SPDX-License-Identifier: ISC +/* + * Copyright (C) 2019 Felix Fietkau + * Copyright (C) 2021-2022 Intel Corporation + */ + +#include +#include "ieee80211_i.h" +#include "sta_info.h" + +#define AVG_PKT_SIZE 1024 + +/* Number of bits for an average sized packet */ +#define MCS_NBITS (AVG_PKT_SIZE << 3) + +/* Number of kilo-symbols (symbols * 1024) for a packet with (bps) bits per + * symbol. We use k-symbols to avoid rounding in the _TIME macros below. + */ +#define MCS_N_KSYMS(bps) DIV_ROUND_UP(MCS_NBITS << 10, (bps)) + +/* Transmission time (in 1024 * usec) for a packet containing (ksyms) * 1024 + * symbols. + */ +#define MCS_SYMBOL_TIME(sgi, ksyms) \ + (sgi ? \ + ((ksyms) * 4 * 18) / 20 : /* 3.6 us per sym */ \ + ((ksyms) * 4) /* 4.0 us per sym */ \ + ) + +/* Transmit duration for the raw data part of an average sized packet */ +#define MCS_DURATION(streams, sgi, bps) \ + ((u32)MCS_SYMBOL_TIME(sgi, MCS_N_KSYMS((streams) * (bps)))) + +#define MCS_DURATION_S(shift, streams, sgi, bps) \ + ((u16)((MCS_DURATION(streams, sgi, bps) >> shift))) + +/* These should match the values in enum nl80211_he_gi */ +#define HE_GI_08 0 +#define HE_GI_16 1 +#define HE_GI_32 2 + +/* Transmission time (1024 usec) for a packet containing (ksyms) * k-symbols */ +#define HE_SYMBOL_TIME(gi, ksyms) \ + (gi == HE_GI_08 ? \ + ((ksyms) * 16 * 17) / 20 : /* 13.6 us per sym */ \ + (gi == HE_GI_16 ? 
\ + ((ksyms) * 16 * 18) / 20 : /* 14.4 us per sym */ \ + ((ksyms) * 16) /* 16.0 us per sym */ \ + )) + +/* Transmit duration for the raw data part of an average sized packet */ +#define HE_DURATION(streams, gi, bps) \ + ((u32)HE_SYMBOL_TIME(gi, MCS_N_KSYMS((streams) * (bps)))) + +#define HE_DURATION_S(shift, streams, gi, bps) \ + (HE_DURATION(streams, gi, bps) >> shift) + +#define BW_20 0 +#define BW_40 1 +#define BW_80 2 +#define BW_160 3 + +/* + * Define group sort order: HT40 -> SGI -> #streams + */ +#define IEEE80211_MAX_STREAMS 4 +#define IEEE80211_HT_STREAM_GROUPS 4 /* BW(=2) * SGI(=2) */ +#define IEEE80211_VHT_STREAM_GROUPS 8 /* BW(=4) * SGI(=2) */ + +#define IEEE80211_HE_MAX_STREAMS 8 + +#define IEEE80211_HT_GROUPS_NB (IEEE80211_MAX_STREAMS * \ + IEEE80211_HT_STREAM_GROUPS) +#define IEEE80211_VHT_GROUPS_NB (IEEE80211_MAX_STREAMS * \ + IEEE80211_VHT_STREAM_GROUPS) + +#define IEEE80211_HT_GROUP_0 0 +#define IEEE80211_VHT_GROUP_0 (IEEE80211_HT_GROUP_0 + IEEE80211_HT_GROUPS_NB) +#define IEEE80211_HE_GROUP_0 (IEEE80211_VHT_GROUP_0 + IEEE80211_VHT_GROUPS_NB) + +#define MCS_GROUP_RATES 12 + +#define HT_GROUP_IDX(_streams, _sgi, _ht40) \ + IEEE80211_HT_GROUP_0 + \ + IEEE80211_MAX_STREAMS * 2 * _ht40 + \ + IEEE80211_MAX_STREAMS * _sgi + \ + _streams - 1 + +#define _MAX(a, b) (((a)>(b))?(a):(b)) + +#define GROUP_SHIFT(duration) \ + _MAX(0, 16 - __builtin_clz(duration)) + +/* MCS rate information for an MCS group */ +#define __MCS_GROUP(_streams, _sgi, _ht40, _s) \ + [HT_GROUP_IDX(_streams, _sgi, _ht40)] = { \ + .shift = _s, \ + .duration = { \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 54 : 26), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 108 : 52), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 162 : 78), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 216 : 104), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 324 : 156), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 432 : 208), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 486 : 234), \ + MCS_DURATION_S(_s, _streams, _sgi, _ht40 ? 540 : 260) \ + } \ +} + +#define MCS_GROUP_SHIFT(_streams, _sgi, _ht40) \ + GROUP_SHIFT(MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26)) + +#define MCS_GROUP(_streams, _sgi, _ht40) \ + __MCS_GROUP(_streams, _sgi, _ht40, \ + MCS_GROUP_SHIFT(_streams, _sgi, _ht40)) + +#define VHT_GROUP_IDX(_streams, _sgi, _bw) \ + (IEEE80211_VHT_GROUP_0 + \ + IEEE80211_MAX_STREAMS * 2 * (_bw) + \ + IEEE80211_MAX_STREAMS * (_sgi) + \ + (_streams) - 1) + +#define BW2VBPS(_bw, r4, r3, r2, r1) \ + (_bw == BW_160 ? r4 : _bw == BW_80 ? r3 : _bw == BW_40 ? 
r2 : r1) + +#define __VHT_GROUP(_streams, _sgi, _bw, _s) \ + [VHT_GROUP_IDX(_streams, _sgi, _bw)] = { \ + .shift = _s, \ + .duration = { \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 234, 117, 54, 26)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 468, 234, 108, 52)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 702, 351, 162, 78)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 936, 468, 216, 104)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 1404, 702, 324, 156)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 1872, 936, 432, 208)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 2106, 1053, 486, 234)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 2340, 1170, 540, 260)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 2808, 1404, 648, 312)), \ + MCS_DURATION_S(_s, _streams, _sgi, \ + BW2VBPS(_bw, 3120, 1560, 720, 346)) \ + } \ +} + +#define VHT_GROUP_SHIFT(_streams, _sgi, _bw) \ + GROUP_SHIFT(MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 243, 117, 54, 26))) + +#define VHT_GROUP(_streams, _sgi, _bw) \ + __VHT_GROUP(_streams, _sgi, _bw, \ + VHT_GROUP_SHIFT(_streams, _sgi, _bw)) + + +#define HE_GROUP_IDX(_streams, _gi, _bw) \ + (IEEE80211_HE_GROUP_0 + \ + IEEE80211_HE_MAX_STREAMS * 3 * (_bw) + \ + IEEE80211_HE_MAX_STREAMS * (_gi) + \ + (_streams) - 1) + +#define __HE_GROUP(_streams, _gi, _bw, _s) \ + [HE_GROUP_IDX(_streams, _gi, _bw)] = { \ + .shift = _s, \ + .duration = { \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 979, 489, 230, 115)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 1958, 979, 475, 230)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 2937, 1468, 705, 345)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 3916, 1958, 936, 475)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 5875, 2937, 1411, 705)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 7833, 3916, 1872, 936)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 8827, 4406, 2102, 1051)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 9806, 4896, 2347, 1166)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 11764, 5875, 2808, 1411)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 13060, 6523, 3124, 1555)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 14702, 7344, 3513, 1756)), \ + HE_DURATION_S(_s, _streams, _gi, \ + BW2VBPS(_bw, 16329, 8164, 3902, 1944)) \ + } \ +} + +#define HE_GROUP_SHIFT(_streams, _gi, _bw) \ + GROUP_SHIFT(HE_DURATION(_streams, _gi, \ + BW2VBPS(_bw, 979, 489, 230, 115))) + +#define HE_GROUP(_streams, _gi, _bw) \ + __HE_GROUP(_streams, _gi, _bw, \ + HE_GROUP_SHIFT(_streams, _gi, _bw)) +struct mcs_group { + u8 shift; + u16 duration[MCS_GROUP_RATES]; +}; + +static const struct mcs_group airtime_mcs_groups[] = { + MCS_GROUP(1, 0, BW_20), + MCS_GROUP(2, 0, BW_20), + MCS_GROUP(3, 0, BW_20), + MCS_GROUP(4, 0, BW_20), + + MCS_GROUP(1, 1, BW_20), + MCS_GROUP(2, 1, BW_20), + MCS_GROUP(3, 1, BW_20), + MCS_GROUP(4, 1, BW_20), + + MCS_GROUP(1, 0, BW_40), + MCS_GROUP(2, 0, BW_40), + MCS_GROUP(3, 0, BW_40), + MCS_GROUP(4, 0, BW_40), + + MCS_GROUP(1, 1, BW_40), + MCS_GROUP(2, 1, BW_40), + MCS_GROUP(3, 1, BW_40), + MCS_GROUP(4, 1, BW_40), + + VHT_GROUP(1, 0, BW_20), + VHT_GROUP(2, 0, BW_20), + VHT_GROUP(3, 0, BW_20), + VHT_GROUP(4, 0, BW_20), + + VHT_GROUP(1, 1, BW_20), + VHT_GROUP(2, 1, BW_20), + VHT_GROUP(3, 1, BW_20), + VHT_GROUP(4, 1, BW_20), + + VHT_GROUP(1, 0, BW_40), + VHT_GROUP(2, 0, BW_40), + VHT_GROUP(3, 0, BW_40), + VHT_GROUP(4, 0, BW_40), 
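+ /* the short-GI 40 MHz, 80 MHz and 160 MHz VHT groups and the HE
+  * groups follow; HT_GROUP_IDX()/VHT_GROUP_IDX()/HE_GROUP_IDX()
+  * map (streams, GI, bandwidth) onto positions in this array */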
+ + VHT_GROUP(1, 1, BW_40), + VHT_GROUP(2, 1, BW_40), + VHT_GROUP(3, 1, BW_40), + VHT_GROUP(4, 1, BW_40), + + VHT_GROUP(1, 0, BW_80), + VHT_GROUP(2, 0, BW_80), + VHT_GROUP(3, 0, BW_80), + VHT_GROUP(4, 0, BW_80), + + VHT_GROUP(1, 1, BW_80), + VHT_GROUP(2, 1, BW_80), + VHT_GROUP(3, 1, BW_80), + VHT_GROUP(4, 1, BW_80), + + VHT_GROUP(1, 0, BW_160), + VHT_GROUP(2, 0, BW_160), + VHT_GROUP(3, 0, BW_160), + VHT_GROUP(4, 0, BW_160), + + VHT_GROUP(1, 1, BW_160), + VHT_GROUP(2, 1, BW_160), + VHT_GROUP(3, 1, BW_160), + VHT_GROUP(4, 1, BW_160), + + HE_GROUP(1, HE_GI_08, BW_20), + HE_GROUP(2, HE_GI_08, BW_20), + HE_GROUP(3, HE_GI_08, BW_20), + HE_GROUP(4, HE_GI_08, BW_20), + HE_GROUP(5, HE_GI_08, BW_20), + HE_GROUP(6, HE_GI_08, BW_20), + HE_GROUP(7, HE_GI_08, BW_20), + HE_GROUP(8, HE_GI_08, BW_20), + + HE_GROUP(1, HE_GI_16, BW_20), + HE_GROUP(2, HE_GI_16, BW_20), + HE_GROUP(3, HE_GI_16, BW_20), + HE_GROUP(4, HE_GI_16, BW_20), + HE_GROUP(5, HE_GI_16, BW_20), + HE_GROUP(6, HE_GI_16, BW_20), + HE_GROUP(7, HE_GI_16, BW_20), + HE_GROUP(8, HE_GI_16, BW_20), + + HE_GROUP(1, HE_GI_32, BW_20), + HE_GROUP(2, HE_GI_32, BW_20), + HE_GROUP(3, HE_GI_32, BW_20), + HE_GROUP(4, HE_GI_32, BW_20), + HE_GROUP(5, HE_GI_32, BW_20), + HE_GROUP(6, HE_GI_32, BW_20), + HE_GROUP(7, HE_GI_32, BW_20), + HE_GROUP(8, HE_GI_32, BW_20), + + HE_GROUP(1, HE_GI_08, BW_40), + HE_GROUP(2, HE_GI_08, BW_40), + HE_GROUP(3, HE_GI_08, BW_40), + HE_GROUP(4, HE_GI_08, BW_40), + HE_GROUP(5, HE_GI_08, BW_40), + HE_GROUP(6, HE_GI_08, BW_40), + HE_GROUP(7, HE_GI_08, BW_40), + HE_GROUP(8, HE_GI_08, BW_40), + + HE_GROUP(1, HE_GI_16, BW_40), + HE_GROUP(2, HE_GI_16, BW_40), + HE_GROUP(3, HE_GI_16, BW_40), + HE_GROUP(4, HE_GI_16, BW_40), + HE_GROUP(5, HE_GI_16, BW_40), + HE_GROUP(6, HE_GI_16, BW_40), + HE_GROUP(7, HE_GI_16, BW_40), + HE_GROUP(8, HE_GI_16, BW_40), + + HE_GROUP(1, HE_GI_32, BW_40), + HE_GROUP(2, HE_GI_32, BW_40), + HE_GROUP(3, HE_GI_32, BW_40), + HE_GROUP(4, HE_GI_32, BW_40), + HE_GROUP(5, HE_GI_32, BW_40), + HE_GROUP(6, HE_GI_32, BW_40), + HE_GROUP(7, HE_GI_32, BW_40), + HE_GROUP(8, HE_GI_32, BW_40), + + HE_GROUP(1, HE_GI_08, BW_80), + HE_GROUP(2, HE_GI_08, BW_80), + HE_GROUP(3, HE_GI_08, BW_80), + HE_GROUP(4, HE_GI_08, BW_80), + HE_GROUP(5, HE_GI_08, BW_80), + HE_GROUP(6, HE_GI_08, BW_80), + HE_GROUP(7, HE_GI_08, BW_80), + HE_GROUP(8, HE_GI_08, BW_80), + + HE_GROUP(1, HE_GI_16, BW_80), + HE_GROUP(2, HE_GI_16, BW_80), + HE_GROUP(3, HE_GI_16, BW_80), + HE_GROUP(4, HE_GI_16, BW_80), + HE_GROUP(5, HE_GI_16, BW_80), + HE_GROUP(6, HE_GI_16, BW_80), + HE_GROUP(7, HE_GI_16, BW_80), + HE_GROUP(8, HE_GI_16, BW_80), + + HE_GROUP(1, HE_GI_32, BW_80), + HE_GROUP(2, HE_GI_32, BW_80), + HE_GROUP(3, HE_GI_32, BW_80), + HE_GROUP(4, HE_GI_32, BW_80), + HE_GROUP(5, HE_GI_32, BW_80), + HE_GROUP(6, HE_GI_32, BW_80), + HE_GROUP(7, HE_GI_32, BW_80), + HE_GROUP(8, HE_GI_32, BW_80), + + HE_GROUP(1, HE_GI_08, BW_160), + HE_GROUP(2, HE_GI_08, BW_160), + HE_GROUP(3, HE_GI_08, BW_160), + HE_GROUP(4, HE_GI_08, BW_160), + HE_GROUP(5, HE_GI_08, BW_160), + HE_GROUP(6, HE_GI_08, BW_160), + HE_GROUP(7, HE_GI_08, BW_160), + HE_GROUP(8, HE_GI_08, BW_160), + + HE_GROUP(1, HE_GI_16, BW_160), + HE_GROUP(2, HE_GI_16, BW_160), + HE_GROUP(3, HE_GI_16, BW_160), + HE_GROUP(4, HE_GI_16, BW_160), + HE_GROUP(5, HE_GI_16, BW_160), + HE_GROUP(6, HE_GI_16, BW_160), + HE_GROUP(7, HE_GI_16, BW_160), + HE_GROUP(8, HE_GI_16, BW_160), + + HE_GROUP(1, HE_GI_32, BW_160), + HE_GROUP(2, HE_GI_32, BW_160), + HE_GROUP(3, HE_GI_32, BW_160), + HE_GROUP(4, HE_GI_32, BW_160), + HE_GROUP(5, HE_GI_32, 
BW_160), + HE_GROUP(6, HE_GI_32, BW_160), + HE_GROUP(7, HE_GI_32, BW_160), + HE_GROUP(8, HE_GI_32, BW_160), +}; + +static u32 +ieee80211_calc_legacy_rate_duration(u16 bitrate, bool short_pre, + bool cck, int len) +{ + u32 duration; + + if (cck) { + duration = 144 + 48; /* preamble + PLCP */ + if (short_pre) + duration >>= 1; + + duration += 10; /* SIFS */ + } else { + duration = 20 + 16; /* premable + SIFS */ + } + + len <<= 3; + duration += (len * 10) / bitrate; + + return duration; +} + +static u32 ieee80211_get_rate_duration(struct ieee80211_hw *hw, + struct ieee80211_rx_status *status, + u32 *overhead) +{ + bool sgi = status->enc_flags & RX_ENC_FLAG_SHORT_GI; + int bw, streams; + int group, idx; + u32 duration; + + switch (status->bw) { + case RATE_INFO_BW_20: + bw = BW_20; + break; + case RATE_INFO_BW_40: + bw = BW_40; + break; + case RATE_INFO_BW_80: + bw = BW_80; + break; + case RATE_INFO_BW_160: + bw = BW_160; + break; + default: + WARN_ON_ONCE(1); + return 0; + } + + switch (status->encoding) { + case RX_ENC_VHT: + streams = status->nss; + idx = status->rate_idx; + group = VHT_GROUP_IDX(streams, sgi, bw); + break; + case RX_ENC_HT: + streams = ((status->rate_idx >> 3) & 3) + 1; + idx = status->rate_idx & 7; + group = HT_GROUP_IDX(streams, sgi, bw); + break; + case RX_ENC_HE: + streams = status->nss; + idx = status->rate_idx; + group = HE_GROUP_IDX(streams, status->he_gi, bw); + break; + default: + WARN_ON_ONCE(1); + return 0; + } + + if (WARN_ON_ONCE((status->encoding != RX_ENC_HE && streams > 4) || + (status->encoding == RX_ENC_HE && streams > 8))) + return 0; + + if (idx >= MCS_GROUP_RATES) + return 0; + + duration = airtime_mcs_groups[group].duration[idx]; + duration <<= airtime_mcs_groups[group].shift; + *overhead = 36 + (streams << 2); + + return duration; +} + + +u32 ieee80211_calc_rx_airtime(struct ieee80211_hw *hw, + struct ieee80211_rx_status *status, + int len) +{ + struct ieee80211_supported_band *sband; + u32 duration, overhead = 0; + + if (status->encoding == RX_ENC_LEGACY) { + const struct ieee80211_rate *rate; + bool sp = status->enc_flags & RX_ENC_FLAG_SHORTPRE; + bool cck; + + /* on 60GHz or sub-1GHz band, there are no legacy rates */ + if (WARN_ON_ONCE(status->band == NL80211_BAND_60GHZ || + status->band == NL80211_BAND_S1GHZ)) + return 0; + + sband = hw->wiphy->bands[status->band]; + if (!sband || status->rate_idx >= sband->n_bitrates) + return 0; + + rate = &sband->bitrates[status->rate_idx]; + cck = rate->flags & IEEE80211_RATE_MANDATORY_B; + + return ieee80211_calc_legacy_rate_duration(rate->bitrate, sp, + cck, len); + } + + duration = ieee80211_get_rate_duration(hw, status, &overhead); + if (!duration) + return 0; + + duration *= len; + duration /= AVG_PKT_SIZE; + duration /= 1024; + + return duration + overhead; +} +EXPORT_SYMBOL_GPL(ieee80211_calc_rx_airtime); + +static bool ieee80211_fill_rate_info(struct ieee80211_hw *hw, + struct ieee80211_rx_status *stat, u8 band, + struct rate_info *ri) +{ + struct ieee80211_supported_band *sband = hw->wiphy->bands[band]; + int i; + + if (!ri || !sband) + return false; + + stat->bw = ri->bw; + stat->nss = ri->nss; + stat->rate_idx = ri->mcs; + + if (ri->flags & RATE_INFO_FLAGS_HE_MCS) + stat->encoding = RX_ENC_HE; + else if (ri->flags & RATE_INFO_FLAGS_VHT_MCS) + stat->encoding = RX_ENC_VHT; + else if (ri->flags & RATE_INFO_FLAGS_MCS) + stat->encoding = RX_ENC_HT; + else + stat->encoding = RX_ENC_LEGACY; + + if (ri->flags & RATE_INFO_FLAGS_SHORT_GI) + stat->enc_flags |= RX_ENC_FLAG_SHORT_GI; + + stat->he_gi = 
ri->he_gi; + + if (stat->encoding != RX_ENC_LEGACY) + return true; + + stat->rate_idx = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if (ri->legacy != sband->bitrates[i].bitrate) + continue; + + stat->rate_idx = i; + return true; + } + + return false; +} + +static int ieee80211_fill_rx_status(struct ieee80211_rx_status *stat, + struct ieee80211_hw *hw, + struct ieee80211_tx_rate *rate, + struct rate_info *ri, u8 band, int len) +{ + memset(stat, 0, sizeof(*stat)); + stat->band = band; + + if (ieee80211_fill_rate_info(hw, stat, band, ri)) + return 0; + + if (rate->idx < 0 || !rate->count) + return -1; + + if (rate->flags & IEEE80211_TX_RC_160_MHZ_WIDTH) + stat->bw = RATE_INFO_BW_160; + else if (rate->flags & IEEE80211_TX_RC_80_MHZ_WIDTH) + stat->bw = RATE_INFO_BW_80; + else if (rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH) + stat->bw = RATE_INFO_BW_40; + else + stat->bw = RATE_INFO_BW_20; + + stat->enc_flags = 0; + if (rate->flags & IEEE80211_TX_RC_USE_SHORT_PREAMBLE) + stat->enc_flags |= RX_ENC_FLAG_SHORTPRE; + if (rate->flags & IEEE80211_TX_RC_SHORT_GI) + stat->enc_flags |= RX_ENC_FLAG_SHORT_GI; + + stat->rate_idx = rate->idx; + if (rate->flags & IEEE80211_TX_RC_VHT_MCS) { + stat->encoding = RX_ENC_VHT; + stat->rate_idx = ieee80211_rate_get_vht_mcs(rate); + stat->nss = ieee80211_rate_get_vht_nss(rate); + } else if (rate->flags & IEEE80211_TX_RC_MCS) { + stat->encoding = RX_ENC_HT; + } else { + stat->encoding = RX_ENC_LEGACY; + } + + return 0; +} + +static u32 ieee80211_calc_tx_airtime_rate(struct ieee80211_hw *hw, + struct ieee80211_tx_rate *rate, + struct rate_info *ri, + u8 band, int len) +{ + struct ieee80211_rx_status stat; + + if (ieee80211_fill_rx_status(&stat, hw, rate, ri, band, len)) + return 0; + + return ieee80211_calc_rx_airtime(hw, &stat, len); +} + +u32 ieee80211_calc_tx_airtime(struct ieee80211_hw *hw, + struct ieee80211_tx_info *info, + int len) +{ + u32 duration = 0; + int i; + + for (i = 0; i < ARRAY_SIZE(info->status.rates); i++) { + struct ieee80211_tx_rate *rate = &info->status.rates[i]; + u32 cur_duration; + + cur_duration = ieee80211_calc_tx_airtime_rate(hw, rate, NULL, + info->band, len); + if (!cur_duration) + break; + + duration += cur_duration * rate->count; + } + + return duration; +} +EXPORT_SYMBOL_GPL(ieee80211_calc_tx_airtime); + +u32 ieee80211_calc_expected_tx_airtime(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_sta *pubsta, + int len, bool ampdu) +{ + struct ieee80211_supported_band *sband; + struct ieee80211_chanctx_conf *conf; + int rateidx, shift = 0; + bool cck, short_pream; + u32 basic_rates; + u8 band = 0; + u16 rate; + + len += 38; /* Ethernet header length */ + + conf = rcu_dereference(vif->bss_conf.chanctx_conf); + if (conf) { + band = conf->def.chan->band; + shift = ieee80211_chandef_get_shift(&conf->def); + } + + if (pubsta) { + struct sta_info *sta = container_of(pubsta, struct sta_info, + sta); + struct ieee80211_rx_status stat; + struct ieee80211_tx_rate *tx_rate = &sta->deflink.tx_stats.last_rate; + struct rate_info *ri = &sta->deflink.tx_stats.last_rate_info; + u32 duration, overhead; + u8 agg_shift; + + if (ieee80211_fill_rx_status(&stat, hw, tx_rate, ri, band, len)) + return 0; + + if (stat.encoding == RX_ENC_LEGACY || !ampdu) + return ieee80211_calc_rx_airtime(hw, &stat, len); + + duration = ieee80211_get_rate_duration(hw, &stat, &overhead); + /* + * Assume that HT/VHT transmission on any AC except VO will + * use aggregation. 
Since we don't have reliable reporting + * of aggregation length, assume an average size based on the + * tx rate. + * This will not be very accurate, but much better than simply + * assuming un-aggregated tx in all cases. + */ + if (duration > 400 * 1024) /* <= VHT20 MCS2 1S */ + agg_shift = 1; + else if (duration > 250 * 1024) /* <= VHT20 MCS3 1S or MCS1 2S */ + agg_shift = 2; + else if (duration > 150 * 1024) /* <= VHT20 MCS5 1S or MCS2 2S */ + agg_shift = 3; + else if (duration > 70 * 1024) /* <= VHT20 MCS5 2S */ + agg_shift = 4; + else if (stat.encoding != RX_ENC_HE || + duration > 20 * 1024) /* <= HE40 MCS6 2S */ + agg_shift = 5; + else + agg_shift = 6; + + duration *= len; + duration /= AVG_PKT_SIZE; + duration /= 1024; + duration += (overhead >> agg_shift); + + return max_t(u32, duration, 4); + } + + if (!conf) + return 0; + + /* No station to get latest rate from, so calculate the worst-case + * duration using the lowest configured basic rate. + */ + sband = hw->wiphy->bands[band]; + + basic_rates = vif->bss_conf.basic_rates; + short_pream = vif->bss_conf.use_short_preamble; + + rateidx = basic_rates ? ffs(basic_rates) - 1 : 0; + rate = sband->bitrates[rateidx].bitrate << shift; + cck = sband->bitrates[rateidx].flags & IEEE80211_RATE_MANDATORY_B; + + return ieee80211_calc_legacy_rate_duration(rate, short_pream, cck, len); +} diff --git a/net/mac80211/cfg.c b/net/mac80211/cfg.c new file mode 100644 index 000000000..a2c486608 --- /dev/null +++ b/net/mac80211/cfg.c @@ -0,0 +1,4989 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * mac80211 configuration hooks for cfg80211 + * + * Copyright 2006-2010 Johannes Berg + * Copyright 2013-2015 Intel Mobile Communications GmbH + * Copyright (C) 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "mesh.h" +#include "wme.h" + +static struct ieee80211_link_data * +ieee80211_link_or_deflink(struct ieee80211_sub_if_data *sdata, int link_id, + bool require_valid) +{ + struct ieee80211_link_data *link; + + if (link_id < 0) { + /* + * For keys, if sdata is not an MLD, we might not use + * the return value at all (if it's not a pairwise key), + * so in that case (require_valid==false) don't error. 
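+ * For an MLD (valid_links set), callers that do need a specific link
+ * cannot be handed the default link, hence the -EINVAL below.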
+ */ + if (require_valid && sdata->vif.valid_links) + return ERR_PTR(-EINVAL); + + return &sdata->deflink; + } + + link = sdata_dereference(sdata->link[link_id], sdata); + if (!link) + return ERR_PTR(-ENOLINK); + return link; +} + +static void ieee80211_set_mu_mimo_follow(struct ieee80211_sub_if_data *sdata, + struct vif_params *params) +{ + bool mu_mimo_groups = false; + bool mu_mimo_follow = false; + + if (params->vht_mumimo_groups) { + u64 membership; + + BUILD_BUG_ON(sizeof(membership) != WLAN_MEMBERSHIP_LEN); + + memcpy(sdata->vif.bss_conf.mu_group.membership, + params->vht_mumimo_groups, WLAN_MEMBERSHIP_LEN); + memcpy(sdata->vif.bss_conf.mu_group.position, + params->vht_mumimo_groups + WLAN_MEMBERSHIP_LEN, + WLAN_USER_POSITION_LEN); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_MU_GROUPS); + /* don't care about endianness - just check for 0 */ + memcpy(&membership, params->vht_mumimo_groups, + WLAN_MEMBERSHIP_LEN); + mu_mimo_groups = membership != 0; + } + + if (params->vht_mumimo_follow_addr) { + mu_mimo_follow = + is_valid_ether_addr(params->vht_mumimo_follow_addr); + ether_addr_copy(sdata->u.mntr.mu_follow_addr, + params->vht_mumimo_follow_addr); + } + + sdata->vif.bss_conf.mu_mimo_owner = mu_mimo_groups || mu_mimo_follow; +} + +static int ieee80211_set_mon_options(struct ieee80211_sub_if_data *sdata, + struct vif_params *params) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *monitor_sdata; + + /* check flags first */ + if (params->flags && ieee80211_sdata_running(sdata)) { + u32 mask = MONITOR_FLAG_COOK_FRAMES | MONITOR_FLAG_ACTIVE; + + /* + * Prohibit MONITOR_FLAG_COOK_FRAMES and + * MONITOR_FLAG_ACTIVE to be changed while the + * interface is up. + * Else we would need to add a lot of cruft + * to update everything: + * cooked_mntrs, monitor and all fif_* counters + * reconfigure hardware + */ + if ((params->flags & mask) != (sdata->u.mntr.flags & mask)) + return -EBUSY; + } + + /* also validate MU-MIMO change */ + monitor_sdata = wiphy_dereference(local->hw.wiphy, + local->monitor_sdata); + + if (!monitor_sdata && + (params->vht_mumimo_groups || params->vht_mumimo_follow_addr)) + return -EOPNOTSUPP; + + /* apply all changes now - no failures allowed */ + + if (monitor_sdata) + ieee80211_set_mu_mimo_follow(monitor_sdata, params); + + if (params->flags) { + if (ieee80211_sdata_running(sdata)) { + ieee80211_adjust_monitor_flags(sdata, -1); + sdata->u.mntr.flags = params->flags; + ieee80211_adjust_monitor_flags(sdata, 1); + + ieee80211_configure_filter(local); + } else { + /* + * Because the interface is down, ieee80211_do_stop + * and ieee80211_do_open take care of "everything" + * mentioned in the comment above. 
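+ * (i.e. the cooked_mntrs, monitor and fif_* counters and the
+ * hardware filter reconfiguration listed there).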
+ */ + sdata->u.mntr.flags = params->flags; + } + } + + return 0; +} + +static int ieee80211_set_ap_mbssid_options(struct ieee80211_sub_if_data *sdata, + struct cfg80211_mbssid_config params, + struct ieee80211_bss_conf *link_conf) +{ + struct ieee80211_sub_if_data *tx_sdata; + + sdata->vif.mbssid_tx_vif = NULL; + link_conf->bssid_index = 0; + link_conf->nontransmitted = false; + link_conf->ema_ap = false; + link_conf->bssid_indicator = 0; + + if (sdata->vif.type != NL80211_IFTYPE_AP || !params.tx_wdev) + return -EINVAL; + + tx_sdata = IEEE80211_WDEV_TO_SUB_IF(params.tx_wdev); + if (!tx_sdata) + return -EINVAL; + + if (tx_sdata == sdata) { + sdata->vif.mbssid_tx_vif = &sdata->vif; + } else { + sdata->vif.mbssid_tx_vif = &tx_sdata->vif; + link_conf->nontransmitted = true; + link_conf->bssid_index = params.index; + } + if (params.ema) + link_conf->ema_ap = true; + + return 0; +} + +static struct wireless_dev *ieee80211_add_iface(struct wiphy *wiphy, + const char *name, + unsigned char name_assign_type, + enum nl80211_iftype type, + struct vif_params *params) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct wireless_dev *wdev; + struct ieee80211_sub_if_data *sdata; + int err; + + err = ieee80211_if_add(local, name, name_assign_type, &wdev, type, params); + if (err) + return ERR_PTR(err); + + sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + if (type == NL80211_IFTYPE_MONITOR) { + err = ieee80211_set_mon_options(sdata, params); + if (err) { + ieee80211_if_remove(sdata); + return NULL; + } + } + + return wdev; +} + +static int ieee80211_del_iface(struct wiphy *wiphy, struct wireless_dev *wdev) +{ + ieee80211_if_remove(IEEE80211_WDEV_TO_SUB_IF(wdev)); + + return 0; +} + +static int ieee80211_change_iface(struct wiphy *wiphy, + struct net_device *dev, + enum nl80211_iftype type, + struct vif_params *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + int ret; + + ret = ieee80211_if_change_type(sdata, type); + if (ret) + return ret; + + if (type == NL80211_IFTYPE_AP_VLAN && params->use_4addr == 0) { + RCU_INIT_POINTER(sdata->u.vlan.sta, NULL); + ieee80211_check_fast_rx_iface(sdata); + } else if (type == NL80211_IFTYPE_STATION && params->use_4addr >= 0) { + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + if (params->use_4addr == ifmgd->use_4addr) + return 0; + + /* FIXME: no support for 4-addr MLO yet */ + if (sdata->vif.valid_links) + return -EOPNOTSUPP; + + sdata->u.mgd.use_4addr = params->use_4addr; + if (!ifmgd->associated) + return 0; + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, sdata->deflink.u.mgd.bssid); + if (sta) + drv_sta_set_4addr(local, sdata, &sta->sta, + params->use_4addr); + mutex_unlock(&local->sta_mtx); + + if (params->use_4addr) + ieee80211_send_4addr_nullfunc(local, sdata); + } + + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) { + ret = ieee80211_set_mon_options(sdata, params); + if (ret) + return ret; + } + + return 0; +} + +static int ieee80211_start_p2p_device(struct wiphy *wiphy, + struct wireless_dev *wdev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + int ret; + + mutex_lock(&sdata->local->chanctx_mtx); + ret = ieee80211_check_combinations(sdata, NULL, 0, 0); + mutex_unlock(&sdata->local->chanctx_mtx); + if (ret < 0) + return ret; + + return ieee80211_do_open(wdev, true); +} + +static void ieee80211_stop_p2p_device(struct wiphy *wiphy, + struct wireless_dev *wdev) +{ + 
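+ /*
+  * Nothing P2P-specific to undo here: ieee80211_sdata_stop() simply
+  * tears the interface back down, releasing what ieee80211_do_open()
+  * set up in the start path above.
+  */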
ieee80211_sdata_stop(IEEE80211_WDEV_TO_SUB_IF(wdev)); +} + +static int ieee80211_start_nan(struct wiphy *wiphy, + struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + int ret; + + mutex_lock(&sdata->local->chanctx_mtx); + ret = ieee80211_check_combinations(sdata, NULL, 0, 0); + mutex_unlock(&sdata->local->chanctx_mtx); + if (ret < 0) + return ret; + + ret = ieee80211_do_open(wdev, true); + if (ret) + return ret; + + ret = drv_start_nan(sdata->local, sdata, conf); + if (ret) + ieee80211_sdata_stop(sdata); + + sdata->u.nan.conf = *conf; + + return ret; +} + +static void ieee80211_stop_nan(struct wiphy *wiphy, + struct wireless_dev *wdev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + drv_stop_nan(sdata->local, sdata); + ieee80211_sdata_stop(sdata); +} + +static int ieee80211_nan_change_conf(struct wiphy *wiphy, + struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf, + u32 changes) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct cfg80211_nan_conf new_conf; + int ret = 0; + + if (sdata->vif.type != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!ieee80211_sdata_running(sdata)) + return -ENETDOWN; + + new_conf = sdata->u.nan.conf; + + if (changes & CFG80211_NAN_CONF_CHANGED_PREF) + new_conf.master_pref = conf->master_pref; + + if (changes & CFG80211_NAN_CONF_CHANGED_BANDS) + new_conf.bands = conf->bands; + + ret = drv_nan_change_conf(sdata->local, sdata, &new_conf, changes); + if (!ret) + sdata->u.nan.conf = new_conf; + + return ret; +} + +static int ieee80211_add_nan_func(struct wiphy *wiphy, + struct wireless_dev *wdev, + struct cfg80211_nan_func *nan_func) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + int ret; + + if (sdata->vif.type != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!ieee80211_sdata_running(sdata)) + return -ENETDOWN; + + spin_lock_bh(&sdata->u.nan.func_lock); + + ret = idr_alloc(&sdata->u.nan.function_inst_ids, + nan_func, 1, sdata->local->hw.max_nan_de_entries + 1, + GFP_ATOMIC); + spin_unlock_bh(&sdata->u.nan.func_lock); + + if (ret < 0) + return ret; + + nan_func->instance_id = ret; + + WARN_ON(nan_func->instance_id == 0); + + ret = drv_add_nan_func(sdata->local, sdata, nan_func); + if (ret) { + spin_lock_bh(&sdata->u.nan.func_lock); + idr_remove(&sdata->u.nan.function_inst_ids, + nan_func->instance_id); + spin_unlock_bh(&sdata->u.nan.func_lock); + } + + return ret; +} + +static struct cfg80211_nan_func * +ieee80211_find_nan_func_by_cookie(struct ieee80211_sub_if_data *sdata, + u64 cookie) +{ + struct cfg80211_nan_func *func; + int id; + + lockdep_assert_held(&sdata->u.nan.func_lock); + + idr_for_each_entry(&sdata->u.nan.function_inst_ids, func, id) { + if (func->cookie == cookie) + return func; + } + + return NULL; +} + +static void ieee80211_del_nan_func(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct cfg80211_nan_func *func; + u8 instance_id = 0; + + if (sdata->vif.type != NL80211_IFTYPE_NAN || + !ieee80211_sdata_running(sdata)) + return; + + spin_lock_bh(&sdata->u.nan.func_lock); + + func = ieee80211_find_nan_func_by_cookie(sdata, cookie); + if (func) + instance_id = func->instance_id; + + spin_unlock_bh(&sdata->u.nan.func_lock); + + if (instance_id) + drv_del_nan_func(sdata->local, sdata, instance_id); +} + +static int ieee80211_set_noack_map(struct wiphy *wiphy, + struct 
net_device *dev, + u16 noack_map) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + sdata->noack_map = noack_map; + + ieee80211_check_fast_xmit_iface(sdata); + + return 0; +} + +static int ieee80211_set_tx(struct ieee80211_sub_if_data *sdata, + const u8 *mac_addr, u8 key_idx) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_key *key; + struct sta_info *sta; + int ret = -EINVAL; + + if (!wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_EXT_KEY_ID)) + return -EINVAL; + + sta = sta_info_get_bss(sdata, mac_addr); + + if (!sta) + return -EINVAL; + + if (sta->ptk_idx == key_idx) + return 0; + + mutex_lock(&local->key_mtx); + key = key_mtx_dereference(local, sta->ptk[key_idx]); + + if (key && key->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX) + ret = ieee80211_set_tx_key(key); + + mutex_unlock(&local->key_mtx); + return ret; +} + +static int ieee80211_add_key(struct wiphy *wiphy, struct net_device *dev, + int link_id, u8 key_idx, bool pairwise, + const u8 *mac_addr, struct key_params *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link = + ieee80211_link_or_deflink(sdata, link_id, false); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta = NULL; + struct ieee80211_key *key; + int err; + + if (!ieee80211_sdata_running(sdata)) + return -ENETDOWN; + + if (IS_ERR(link)) + return PTR_ERR(link); + + if (pairwise && params->mode == NL80211_KEY_SET_TX) + return ieee80211_set_tx(sdata, mac_addr, key_idx); + + /* reject WEP and TKIP keys if WEP failed to initialize */ + switch (params->cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_TKIP: + case WLAN_CIPHER_SUITE_WEP104: + if (link_id >= 0) + return -EINVAL; + if (WARN_ON_ONCE(fips_enabled)) + return -EINVAL; + break; + default: + break; + } + + key = ieee80211_key_alloc(params->cipher, key_idx, params->key_len, + params->key, params->seq_len, params->seq); + if (IS_ERR(key)) + return PTR_ERR(key); + + key->conf.link_id = link_id; + + if (pairwise) + key->conf.flags |= IEEE80211_KEY_FLAG_PAIRWISE; + + if (params->mode == NL80211_KEY_NO_TX) + key->conf.flags |= IEEE80211_KEY_FLAG_NO_AUTO_TX; + + mutex_lock(&local->sta_mtx); + + if (mac_addr) { + sta = sta_info_get_bss(sdata, mac_addr); + /* + * The ASSOC test makes sure the driver is ready to + * receive the key. When wpa_supplicant has roamed + * using FT, it attempts to set the key before + * association has completed, this rejects that attempt + * so it will set the key again after association. + * + * TODO: accept the key if we have a station entry and + * add it to the device after the station. 
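+ * Until that is implemented, the -ENOENT below makes the caller
+ * install the key again once association has completed.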
+ */ + if (!sta || !test_sta_flag(sta, WLAN_STA_ASSOC)) { + ieee80211_key_free_unused(key); + err = -ENOENT; + goto out_unlock; + } + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + if (sdata->u.mgd.mfp != IEEE80211_MFP_DISABLED) + key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT; + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + /* Keys without a station are used for TX only */ + if (sta && test_sta_flag(sta, WLAN_STA_MFP)) + key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT; + break; + case NL80211_IFTYPE_ADHOC: + /* no MFP (yet) */ + break; + case NL80211_IFTYPE_MESH_POINT: +#ifdef CONFIG_MAC80211_MESH + if (sdata->u.mesh.security != IEEE80211_MESH_SEC_NONE) + key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT; + break; +#endif + case NL80211_IFTYPE_WDS: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_OCB: + /* shouldn't happen */ + WARN_ON_ONCE(1); + break; + } + + err = ieee80211_key_link(key, link, sta); + /* KRACK protection, shouldn't happen but just silently accept key */ + if (err == -EALREADY) + err = 0; + + out_unlock: + mutex_unlock(&local->sta_mtx); + + return err; +} + +static struct ieee80211_key * +ieee80211_lookup_key(struct ieee80211_sub_if_data *sdata, int link_id, + u8 key_idx, bool pairwise, const u8 *mac_addr) +{ + struct ieee80211_local *local __maybe_unused = sdata->local; + struct ieee80211_link_data *link = &sdata->deflink; + struct ieee80211_key *key; + + if (link_id >= 0) { + link = rcu_dereference_check(sdata->link[link_id], + lockdep_is_held(&sdata->wdev.mtx)); + if (!link) + return NULL; + } + + if (mac_addr) { + struct sta_info *sta; + struct link_sta_info *link_sta; + + sta = sta_info_get_bss(sdata, mac_addr); + if (!sta) + return NULL; + + if (link_id >= 0) { + link_sta = rcu_dereference_check(sta->link[link_id], + lockdep_is_held(&local->sta_mtx)); + if (!link_sta) + return NULL; + } else { + link_sta = &sta->deflink; + } + + if (pairwise && key_idx < NUM_DEFAULT_KEYS) + return rcu_dereference_check_key_mtx(local, + sta->ptk[key_idx]); + + if (!pairwise && + key_idx < NUM_DEFAULT_KEYS + + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS) + return rcu_dereference_check_key_mtx(local, + link_sta->gtk[key_idx]); + + return NULL; + } + + if (pairwise && key_idx < NUM_DEFAULT_KEYS) + return rcu_dereference_check_key_mtx(local, + sdata->keys[key_idx]); + + key = rcu_dereference_check_key_mtx(local, link->gtk[key_idx]); + if (key) + return key; + + /* or maybe it was a WEP key */ + if (key_idx < NUM_DEFAULT_KEYS) + return rcu_dereference_check_key_mtx(local, sdata->keys[key_idx]); + + return NULL; +} + +static int ieee80211_del_key(struct wiphy *wiphy, struct net_device *dev, + int link_id, u8 key_idx, bool pairwise, + const u8 *mac_addr) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct ieee80211_key *key; + int ret; + + mutex_lock(&local->sta_mtx); + mutex_lock(&local->key_mtx); + + key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr); + if (!key) { + ret = -ENOENT; + goto out_unlock; + } + + ieee80211_key_free(key, sdata->vif.type == NL80211_IFTYPE_STATION); + + ret = 0; + out_unlock: + mutex_unlock(&local->key_mtx); + mutex_unlock(&local->sta_mtx); + + return ret; +} + +static int ieee80211_get_key(struct wiphy *wiphy, struct net_device *dev, + 
int link_id, u8 key_idx, bool pairwise, + const u8 *mac_addr, void *cookie, + void (*callback)(void *cookie, + struct key_params *params)) +{ + struct ieee80211_sub_if_data *sdata; + u8 seq[6] = {0}; + struct key_params params; + struct ieee80211_key *key; + u64 pn64; + u32 iv32; + u16 iv16; + int err = -ENOENT; + struct ieee80211_key_seq kseq = {}; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + + key = ieee80211_lookup_key(sdata, link_id, key_idx, pairwise, mac_addr); + if (!key) + goto out; + + memset(¶ms, 0, sizeof(params)); + + params.cipher = key->conf.cipher; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_TKIP: + pn64 = atomic64_read(&key->conf.tx_pn); + iv32 = TKIP_PN_TO_IV32(pn64); + iv16 = TKIP_PN_TO_IV16(pn64); + + if (key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE && + !(key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV)) { + drv_get_key_seq(sdata->local, key, &kseq); + iv32 = kseq.tkip.iv32; + iv16 = kseq.tkip.iv16; + } + + seq[0] = iv16 & 0xff; + seq[1] = (iv16 >> 8) & 0xff; + seq[2] = iv32 & 0xff; + seq[3] = (iv32 >> 8) & 0xff; + seq[4] = (iv32 >> 16) & 0xff; + seq[5] = (iv32 >> 24) & 0xff; + params.seq = seq; + params.seq_len = 6; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + BUILD_BUG_ON(offsetof(typeof(kseq), ccmp) != + offsetof(typeof(kseq), aes_cmac)); + fallthrough; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + BUILD_BUG_ON(offsetof(typeof(kseq), ccmp) != + offsetof(typeof(kseq), aes_gmac)); + fallthrough; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + BUILD_BUG_ON(offsetof(typeof(kseq), ccmp) != + offsetof(typeof(kseq), gcmp)); + + if (key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE && + !(key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV)) { + drv_get_key_seq(sdata->local, key, &kseq); + memcpy(seq, kseq.ccmp.pn, 6); + } else { + pn64 = atomic64_read(&key->conf.tx_pn); + seq[0] = pn64; + seq[1] = pn64 >> 8; + seq[2] = pn64 >> 16; + seq[3] = pn64 >> 24; + seq[4] = pn64 >> 32; + seq[5] = pn64 >> 40; + } + params.seq = seq; + params.seq_len = 6; + break; + default: + if (!(key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE)) + break; + if (WARN_ON(key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV)) + break; + drv_get_key_seq(sdata->local, key, &kseq); + params.seq = kseq.hw.seq; + params.seq_len = kseq.hw.seq_len; + break; + } + + params.key = key->conf.key; + params.key_len = key->conf.keylen; + + callback(cookie, ¶ms); + err = 0; + + out: + rcu_read_unlock(); + return err; +} + +static int ieee80211_config_default_key(struct wiphy *wiphy, + struct net_device *dev, + int link_id, u8 key_idx, bool uni, + bool multi) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link = + ieee80211_link_or_deflink(sdata, link_id, false); + + if (IS_ERR(link)) + return PTR_ERR(link); + + ieee80211_set_default_key(link, key_idx, uni, multi); + + return 0; +} + +static int ieee80211_config_default_mgmt_key(struct wiphy *wiphy, + struct net_device *dev, + int link_id, u8 key_idx) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link = + ieee80211_link_or_deflink(sdata, link_id, true); + + if (IS_ERR(link)) + return PTR_ERR(link); + + ieee80211_set_default_mgmt_key(link, key_idx); + + return 0; +} + +static int ieee80211_config_default_beacon_key(struct wiphy *wiphy, + struct net_device *dev, + int link_id, u8 key_idx) 
+{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link = + ieee80211_link_or_deflink(sdata, link_id, true); + + if (IS_ERR(link)) + return PTR_ERR(link); + + ieee80211_set_default_beacon_key(link, key_idx); + + return 0; +} + +void sta_set_rate_info_tx(struct sta_info *sta, + const struct ieee80211_tx_rate *rate, + struct rate_info *rinfo) +{ + rinfo->flags = 0; + if (rate->flags & IEEE80211_TX_RC_MCS) { + rinfo->flags |= RATE_INFO_FLAGS_MCS; + rinfo->mcs = rate->idx; + } else if (rate->flags & IEEE80211_TX_RC_VHT_MCS) { + rinfo->flags |= RATE_INFO_FLAGS_VHT_MCS; + rinfo->mcs = ieee80211_rate_get_vht_mcs(rate); + rinfo->nss = ieee80211_rate_get_vht_nss(rate); + } else { + struct ieee80211_supported_band *sband; + int shift = ieee80211_vif_get_shift(&sta->sdata->vif); + u16 brate; + + sband = ieee80211_get_sband(sta->sdata); + WARN_ON_ONCE(sband && !sband->bitrates); + if (sband && sband->bitrates) { + brate = sband->bitrates[rate->idx].bitrate; + rinfo->legacy = DIV_ROUND_UP(brate, 1 << shift); + } + } + if (rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH) + rinfo->bw = RATE_INFO_BW_40; + else if (rate->flags & IEEE80211_TX_RC_80_MHZ_WIDTH) + rinfo->bw = RATE_INFO_BW_80; + else if (rate->flags & IEEE80211_TX_RC_160_MHZ_WIDTH) + rinfo->bw = RATE_INFO_BW_160; + else + rinfo->bw = RATE_INFO_BW_20; + if (rate->flags & IEEE80211_TX_RC_SHORT_GI) + rinfo->flags |= RATE_INFO_FLAGS_SHORT_GI; +} + +static int ieee80211_dump_station(struct wiphy *wiphy, struct net_device *dev, + int idx, u8 *mac, struct station_info *sinfo) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + int ret = -ENOENT; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get_by_idx(sdata, idx); + if (sta) { + ret = 0; + memcpy(mac, sta->sta.addr, ETH_ALEN); + sta_set_sinfo(sta, sinfo, true); + } + + mutex_unlock(&local->sta_mtx); + + return ret; +} + +static int ieee80211_dump_survey(struct wiphy *wiphy, struct net_device *dev, + int idx, struct survey_info *survey) +{ + struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr); + + return drv_get_survey(local, idx, survey); +} + +static int ieee80211_get_station(struct wiphy *wiphy, struct net_device *dev, + const u8 *mac, struct station_info *sinfo) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + int ret = -ENOENT; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get_bss(sdata, mac); + if (sta) { + ret = 0; + sta_set_sinfo(sta, sinfo, true); + } + + mutex_unlock(&local->sta_mtx); + + return ret; +} + +static int ieee80211_set_monitor_channel(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata; + int ret = 0; + + if (cfg80211_chandef_identical(&local->monitor_chandef, chandef)) + return 0; + + mutex_lock(&local->mtx); + if (local->use_chanctx) { + sdata = wiphy_dereference(local->hw.wiphy, + local->monitor_sdata); + if (sdata) { + ieee80211_link_release_channel(&sdata->deflink); + ret = ieee80211_link_use_channel(&sdata->deflink, + chandef, + IEEE80211_CHANCTX_EXCLUSIVE); + } + } else if (local->open_count == local->monitors) { + local->_oper_chandef = *chandef; + ieee80211_hw_config(local, 0); + } + + if (ret == 0) + local->monitor_chandef = *chandef; + mutex_unlock(&local->mtx); + + return ret; +} + +static int 
+ieee80211_set_probe_resp(struct ieee80211_sub_if_data *sdata, + const u8 *resp, size_t resp_len, + const struct ieee80211_csa_settings *csa, + const struct ieee80211_color_change_settings *cca, + struct ieee80211_link_data *link) +{ + struct probe_resp *new, *old; + + if (!resp || !resp_len) + return 1; + + old = sdata_dereference(link->u.ap.probe_resp, sdata); + + new = kzalloc(sizeof(struct probe_resp) + resp_len, GFP_KERNEL); + if (!new) + return -ENOMEM; + + new->len = resp_len; + memcpy(new->data, resp, resp_len); + + if (csa) + memcpy(new->cntdwn_counter_offsets, csa->counter_offsets_presp, + csa->n_counter_offsets_presp * + sizeof(new->cntdwn_counter_offsets[0])); + else if (cca) + new->cntdwn_counter_offsets[0] = cca->counter_offset_presp; + + rcu_assign_pointer(link->u.ap.probe_resp, new); + if (old) + kfree_rcu(old, rcu_head); + + return 0; +} + +static int ieee80211_set_fils_discovery(struct ieee80211_sub_if_data *sdata, + struct cfg80211_fils_discovery *params, + struct ieee80211_link_data *link, + struct ieee80211_bss_conf *link_conf) +{ + struct fils_discovery_data *new, *old = NULL; + struct ieee80211_fils_discovery *fd; + + if (!params->tmpl || !params->tmpl_len) + return -EINVAL; + + fd = &link_conf->fils_discovery; + fd->min_interval = params->min_interval; + fd->max_interval = params->max_interval; + + old = sdata_dereference(link->u.ap.fils_discovery, sdata); + new = kzalloc(sizeof(*new) + params->tmpl_len, GFP_KERNEL); + if (!new) + return -ENOMEM; + new->len = params->tmpl_len; + memcpy(new->data, params->tmpl, params->tmpl_len); + rcu_assign_pointer(link->u.ap.fils_discovery, new); + + if (old) + kfree_rcu(old, rcu_head); + + return 0; +} + +static int +ieee80211_set_unsol_bcast_probe_resp(struct ieee80211_sub_if_data *sdata, + struct cfg80211_unsol_bcast_probe_resp *params, + struct ieee80211_link_data *link, + struct ieee80211_bss_conf *link_conf) +{ + struct unsol_bcast_probe_resp_data *new, *old = NULL; + + if (!params->tmpl || !params->tmpl_len) + return -EINVAL; + + old = sdata_dereference(link->u.ap.unsol_bcast_probe_resp, sdata); + new = kzalloc(sizeof(*new) + params->tmpl_len, GFP_KERNEL); + if (!new) + return -ENOMEM; + new->len = params->tmpl_len; + memcpy(new->data, params->tmpl, params->tmpl_len); + rcu_assign_pointer(link->u.ap.unsol_bcast_probe_resp, new); + + if (old) + kfree_rcu(old, rcu_head); + + link_conf->unsol_bcast_probe_resp_interval = params->interval; + + return 0; +} + +static int ieee80211_set_ftm_responder_params( + struct ieee80211_sub_if_data *sdata, + const u8 *lci, size_t lci_len, + const u8 *civicloc, size_t civicloc_len, + struct ieee80211_bss_conf *link_conf) +{ + struct ieee80211_ftm_responder_params *new, *old; + u8 *pos; + int len; + + if (!lci_len && !civicloc_len) + return 0; + + old = link_conf->ftmr_params; + len = lci_len + civicloc_len; + + new = kzalloc(sizeof(*new) + len, GFP_KERNEL); + if (!new) + return -ENOMEM; + + pos = (u8 *)(new + 1); + if (lci_len) { + new->lci_len = lci_len; + new->lci = pos; + memcpy(pos, lci, lci_len); + pos += lci_len; + } + + if (civicloc_len) { + new->civicloc_len = civicloc_len; + new->civicloc = pos; + memcpy(pos, civicloc, civicloc_len); + pos += civicloc_len; + } + + link_conf->ftmr_params = new; + kfree(old); + + return 0; +} + +static int +ieee80211_copy_mbssid_beacon(u8 *pos, struct cfg80211_mbssid_elems *dst, + struct cfg80211_mbssid_elems *src) +{ + int i, offset = 0; + + for (i = 0; i < src->cnt; i++) { + memcpy(pos + offset, src->elem[i].data, src->elem[i].len); + 
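+ /* dst->elem[] entries point directly into the buffer at pos */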
dst->elem[i].len = src->elem[i].len; + dst->elem[i].data = pos + offset; + offset += dst->elem[i].len; + } + dst->cnt = src->cnt; + + return offset; +} + +static int ieee80211_assign_beacon(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct cfg80211_beacon_data *params, + const struct ieee80211_csa_settings *csa, + const struct ieee80211_color_change_settings *cca) +{ + struct cfg80211_mbssid_elems *mbssid = NULL; + struct beacon_data *new, *old; + int new_head_len, new_tail_len; + int size, err; + u32 changed = BSS_CHANGED_BEACON; + struct ieee80211_bss_conf *link_conf = link->conf; + + old = sdata_dereference(link->u.ap.beacon, sdata); + + /* Need to have a beacon head if we don't have one yet */ + if (!params->head && !old) + return -EINVAL; + + /* new or old head? */ + if (params->head) + new_head_len = params->head_len; + else + new_head_len = old->head_len; + + /* new or old tail? */ + if (params->tail || !old) + /* params->tail_len will be zero for !params->tail */ + new_tail_len = params->tail_len; + else + new_tail_len = old->tail_len; + + size = sizeof(*new) + new_head_len + new_tail_len; + + /* new or old multiple BSSID elements? */ + if (params->mbssid_ies) { + mbssid = params->mbssid_ies; + size += struct_size(new->mbssid_ies, elem, mbssid->cnt); + size += ieee80211_get_mbssid_beacon_len(mbssid); + } else if (old && old->mbssid_ies) { + mbssid = old->mbssid_ies; + size += struct_size(new->mbssid_ies, elem, mbssid->cnt); + size += ieee80211_get_mbssid_beacon_len(mbssid); + } + + new = kzalloc(size, GFP_KERNEL); + if (!new) + return -ENOMEM; + + /* start filling the new info now */ + + /* + * pointers go into the block we allocated, + * memory is | beacon_data | head | tail | mbssid_ies + */ + new->head = ((u8 *) new) + sizeof(*new); + new->tail = new->head + new_head_len; + new->head_len = new_head_len; + new->tail_len = new_tail_len; + /* copy in optional mbssid_ies */ + if (mbssid) { + u8 *pos = new->tail + new->tail_len; + + new->mbssid_ies = (void *)pos; + pos += struct_size(new->mbssid_ies, elem, mbssid->cnt); + ieee80211_copy_mbssid_beacon(pos, new->mbssid_ies, mbssid); + /* update bssid_indicator */ + link_conf->bssid_indicator = + ilog2(__roundup_pow_of_two(mbssid->cnt + 1)); + } + + if (csa) { + new->cntdwn_current_counter = csa->count; + memcpy(new->cntdwn_counter_offsets, csa->counter_offsets_beacon, + csa->n_counter_offsets_beacon * + sizeof(new->cntdwn_counter_offsets[0])); + } else if (cca) { + new->cntdwn_current_counter = cca->count; + new->cntdwn_counter_offsets[0] = cca->counter_offset_beacon; + } + + /* copy in head */ + if (params->head) + memcpy(new->head, params->head, new_head_len); + else + memcpy(new->head, old->head, new_head_len); + + /* copy in optional tail */ + if (params->tail) + memcpy(new->tail, params->tail, new_tail_len); + else + if (old) + memcpy(new->tail, old->tail, new_tail_len); + + err = ieee80211_set_probe_resp(sdata, params->probe_resp, + params->probe_resp_len, csa, cca, link); + if (err < 0) { + kfree(new); + return err; + } + if (err == 0) + changed |= BSS_CHANGED_AP_PROBE_RESP; + + if (params->ftm_responder != -1) { + link_conf->ftm_responder = params->ftm_responder; + err = ieee80211_set_ftm_responder_params(sdata, + params->lci, + params->lci_len, + params->civicloc, + params->civicloc_len, + link_conf); + + if (err < 0) { + kfree(new); + return err; + } + + changed |= BSS_CHANGED_FTM_RESPONDER; + } + + rcu_assign_pointer(link->u.ap.beacon, new); + sdata->u.ap.active = true; + + if (old) + 
kfree_rcu(old, rcu_head);
+
+ return changed;
+}
+
+static int ieee80211_start_ap(struct wiphy *wiphy, struct net_device *dev,
+ struct cfg80211_ap_settings *params)
+{
+ struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+ struct ieee80211_local *local = sdata->local;
+ struct beacon_data *old;
+ struct ieee80211_sub_if_data *vlan;
+ u32 changed = BSS_CHANGED_BEACON_INT |
+ BSS_CHANGED_BEACON_ENABLED |
+ BSS_CHANGED_BEACON |
+ BSS_CHANGED_P2P_PS |
+ BSS_CHANGED_TXPOWER |
+ BSS_CHANGED_TWT;
+ int i, err;
+ int prev_beacon_int;
+ unsigned int link_id = params->beacon.link_id;
+ struct ieee80211_link_data *link;
+ struct ieee80211_bss_conf *link_conf;
+
+ link = sdata_dereference(sdata->link[link_id], sdata);
+ if (!link)
+ return -ENOLINK;
+
+ link_conf = link->conf;
+
+ old = sdata_dereference(link->u.ap.beacon, sdata);
+ if (old)
+ return -EALREADY;
+
+ if (params->smps_mode != NL80211_SMPS_OFF)
+ return -ENOTSUPP;
+
+ link->smps_mode = IEEE80211_SMPS_OFF;
+
+ link->needed_rx_chains = sdata->local->rx_chains;
+
+ prev_beacon_int = link_conf->beacon_int;
+ link_conf->beacon_int = params->beacon_interval;
+
+ if (params->he_cap && params->he_oper) {
+ link_conf->he_support = true;
+ link_conf->htc_trig_based_pkt_ext =
+ le32_get_bits(params->he_oper->he_oper_params,
+ IEEE80211_HE_OPERATION_DFLT_PE_DURATION_MASK);
+ link_conf->frame_time_rts_th =
+ le32_get_bits(params->he_oper->he_oper_params,
+ IEEE80211_HE_OPERATION_RTS_THRESHOLD_MASK);
+ changed |= BSS_CHANGED_HE_OBSS_PD;
+
+ if (params->beacon.he_bss_color.enabled)
+ changed |= BSS_CHANGED_HE_BSS_COLOR;
+ }
+
+ if (sdata->vif.type == NL80211_IFTYPE_AP &&
+ params->mbssid_config.tx_wdev) {
+ err = ieee80211_set_ap_mbssid_options(sdata,
+ params->mbssid_config,
+ link_conf);
+ if (err)
+ return err;
+ }
+
+ mutex_lock(&local->mtx);
+ err = ieee80211_link_use_channel(link, &params->chandef,
+ IEEE80211_CHANCTX_SHARED);
+ if (!err)
+ ieee80211_link_copy_chanctx_to_vlans(link, false);
+ mutex_unlock(&local->mtx);
+ if (err) {
+ link_conf->beacon_int = prev_beacon_int;
+ return err;
+ }
+
+ /*
+ * Apply control port protocol, this allows us to
+ * not encrypt dynamic WEP control frames.
+ */
+ sdata->control_port_protocol = params->crypto.control_port_ethertype;
+ sdata->control_port_no_encrypt = params->crypto.control_port_no_encrypt;
+ sdata->control_port_over_nl80211 =
+ params->crypto.control_port_over_nl80211;
+ sdata->control_port_no_preauth =
+ params->crypto.control_port_no_preauth;
+
+ list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) {
+ vlan->control_port_protocol =
+ params->crypto.control_port_ethertype;
+ vlan->control_port_no_encrypt =
+ params->crypto.control_port_no_encrypt;
+ vlan->control_port_over_nl80211 =
+ params->crypto.control_port_over_nl80211;
+ vlan->control_port_no_preauth =
+ params->crypto.control_port_no_preauth;
+ }
+
+ link_conf->dtim_period = params->dtim_period;
+ link_conf->enable_beacon = true;
+ link_conf->allow_p2p_go_ps = sdata->vif.p2p;
+ link_conf->twt_responder = params->twt_responder;
+ link_conf->he_obss_pd = params->he_obss_pd;
+ link_conf->he_bss_color = params->beacon.he_bss_color;
+ sdata->vif.cfg.s1g = params->chandef.chan->band ==
+ NL80211_BAND_S1GHZ;
+
+ sdata->vif.cfg.ssid_len = params->ssid_len;
+ if (params->ssid_len)
+ memcpy(sdata->vif.cfg.ssid, params->ssid,
+ params->ssid_len);
+ link_conf->hidden_ssid =
+ (params->hidden_ssid != NL80211_HIDDEN_SSID_NOT_IN_USE);
+
+ memset(&link_conf->p2p_noa_attr, 0,
+ sizeof(link_conf->p2p_noa_attr));
+ link_conf->p2p_noa_attr.oppps_ctwindow =
+ params->p2p_ctwindow & IEEE80211_P2P_OPPPS_CTWINDOW_MASK;
+ if (params->p2p_opp_ps)
+ link_conf->p2p_noa_attr.oppps_ctwindow |=
+ IEEE80211_P2P_OPPPS_ENABLE_BIT;
+
+ sdata->beacon_rate_set = false;
+ if (wiphy_ext_feature_isset(local->hw.wiphy,
+ NL80211_EXT_FEATURE_BEACON_RATE_LEGACY)) {
+ for (i = 0; i < NUM_NL80211_BANDS; i++) {
+ sdata->beacon_rateidx_mask[i] =
+ params->beacon_rate.control[i].legacy;
+ if (sdata->beacon_rateidx_mask[i])
+ sdata->beacon_rate_set = true;
+ }
+ }
+
+ if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL))
+ link_conf->beacon_tx_rate = params->beacon_rate;
+
+ err = ieee80211_assign_beacon(sdata, link, &params->beacon, NULL, NULL);
+ if (err < 0)
+ goto error;
+ changed |= err;
+
+ if (params->fils_discovery.max_interval) {
+ err = ieee80211_set_fils_discovery(sdata,
+ &params->fils_discovery,
+ link, link_conf);
+ if (err < 0)
+ goto error;
+ changed |= BSS_CHANGED_FILS_DISCOVERY;
+ }
+
+ if (params->unsol_bcast_probe_resp.interval) {
+ err = ieee80211_set_unsol_bcast_probe_resp(sdata,
+ &params->unsol_bcast_probe_resp,
+ link, link_conf);
+ if (err < 0)
+ goto error;
+ changed |= BSS_CHANGED_UNSOL_BCAST_PROBE_RESP;
+ }
+
+ err = drv_start_ap(sdata->local, sdata, link_conf);
+ if (err) {
+ old = sdata_dereference(link->u.ap.beacon, sdata);
+
+ if (old)
+ kfree_rcu(old, rcu_head);
+ RCU_INIT_POINTER(link->u.ap.beacon, NULL);
+ sdata->u.ap.active = false;
+ goto error;
+ }
+
+ ieee80211_recalc_dtim(local, sdata);
+ ieee80211_vif_cfg_change_notify(sdata, BSS_CHANGED_SSID);
+ ieee80211_link_info_change_notify(sdata, link, changed);
+
+ netif_carrier_on(dev);
+ list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list)
+ netif_carrier_on(vlan->dev);
+
+ return 0;
+
+error:
+ mutex_lock(&local->mtx);
+ ieee80211_link_release_channel(link);
+ mutex_unlock(&local->mtx);
+
+ return err;
+}
+
+static int ieee80211_change_beacon(struct wiphy *wiphy, struct net_device *dev,
+ struct cfg80211_beacon_data *params)
+{
+ struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+ struct ieee80211_link_data *link;
+ struct beacon_data *old;
+ int err;
+ struct ieee80211_bss_conf *link_conf;
+
+
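+ /* Changing the beacon goes through ieee80211_assign_beacon() as
+ * well: any part of params left unset (head, tail, multiple BSSID
+ * elements) falls back to the data already published for this
+ * link, so userspace may update only the pieces it cares about.
+ */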
sdata_assert_lock(sdata); + + link = sdata_dereference(sdata->link[params->link_id], sdata); + if (!link) + return -ENOLINK; + + link_conf = link->conf; + + /* don't allow changing the beacon while a countdown is in place - offset + * of channel switch counter may change + */ + if (link_conf->csa_active || link_conf->color_change_active) + return -EBUSY; + + old = sdata_dereference(link->u.ap.beacon, sdata); + if (!old) + return -ENOENT; + + err = ieee80211_assign_beacon(sdata, link, params, NULL, NULL); + if (err < 0) + return err; + + if (params->he_bss_color_valid && + params->he_bss_color.enabled != link_conf->he_bss_color.enabled) { + link_conf->he_bss_color.enabled = params->he_bss_color.enabled; + err |= BSS_CHANGED_HE_BSS_COLOR; + } + + ieee80211_link_info_change_notify(sdata, link, err); + return 0; +} + +static void ieee80211_free_next_beacon(struct ieee80211_link_data *link) +{ + if (!link->u.ap.next_beacon) + return; + + kfree(link->u.ap.next_beacon->mbssid_ies); + kfree(link->u.ap.next_beacon); + link->u.ap.next_beacon = NULL; +} + +static int ieee80211_stop_ap(struct wiphy *wiphy, struct net_device *dev, + unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_sub_if_data *vlan; + struct ieee80211_local *local = sdata->local; + struct beacon_data *old_beacon; + struct probe_resp *old_probe_resp; + struct fils_discovery_data *old_fils_discovery; + struct unsol_bcast_probe_resp_data *old_unsol_bcast_probe_resp; + struct cfg80211_chan_def chandef; + struct ieee80211_link_data *link = + sdata_dereference(sdata->link[link_id], sdata); + struct ieee80211_bss_conf *link_conf = link->conf; + + sdata_assert_lock(sdata); + + old_beacon = sdata_dereference(link->u.ap.beacon, sdata); + if (!old_beacon) + return -ENOENT; + old_probe_resp = sdata_dereference(link->u.ap.probe_resp, + sdata); + old_fils_discovery = sdata_dereference(link->u.ap.fils_discovery, + sdata); + old_unsol_bcast_probe_resp = + sdata_dereference(link->u.ap.unsol_bcast_probe_resp, + sdata); + + /* abort any running channel switch or color change */ + mutex_lock(&local->mtx); + link_conf->csa_active = false; + link_conf->color_change_active = false; + if (link->csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + link->csa_block_tx = false; + } + + mutex_unlock(&local->mtx); + + ieee80211_free_next_beacon(link); + + /* turn off carrier for this interface and dependent VLANs */ + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) + netif_carrier_off(vlan->dev); + netif_carrier_off(dev); + + /* remove beacon and probe response */ + sdata->u.ap.active = false; + RCU_INIT_POINTER(link->u.ap.beacon, NULL); + RCU_INIT_POINTER(link->u.ap.probe_resp, NULL); + RCU_INIT_POINTER(link->u.ap.fils_discovery, NULL); + RCU_INIT_POINTER(link->u.ap.unsol_bcast_probe_resp, NULL); + kfree_rcu(old_beacon, rcu_head); + if (old_probe_resp) + kfree_rcu(old_probe_resp, rcu_head); + if (old_fils_discovery) + kfree_rcu(old_fils_discovery, rcu_head); + if (old_unsol_bcast_probe_resp) + kfree_rcu(old_unsol_bcast_probe_resp, rcu_head); + + kfree(link_conf->ftmr_params); + link_conf->ftmr_params = NULL; + + sdata->vif.mbssid_tx_vif = NULL; + link_conf->bssid_index = 0; + link_conf->nontransmitted = false; + link_conf->ema_ap = false; + link_conf->bssid_indicator = 0; + + __sta_info_flush(sdata, true); + ieee80211_free_keys(sdata, true); + + link_conf->enable_beacon = false; + sdata->beacon_rate_set = false; + sdata->vif.cfg.ssid_len = 0; + 
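+ /* The beacon and template pointers cleared above were published
+ * with rcu_assign_pointer(), so they are detached with
+ * RCU_INIT_POINTER() and released via kfree_rcu() only after any
+ * concurrent readers have finished with them.
+ */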
clear_bit(SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, &sdata->state); + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_BEACON_ENABLED); + + if (sdata->wdev.cac_started) { + chandef = link_conf->chandef; + cancel_delayed_work_sync(&link->dfs_cac_timer_work); + cfg80211_cac_event(sdata->dev, &chandef, + NL80211_RADAR_CAC_ABORTED, + GFP_KERNEL); + } + + drv_stop_ap(sdata->local, sdata, link_conf); + + /* free all potentially still buffered bcast frames */ + local->total_ps_buffered -= skb_queue_len(&sdata->u.ap.ps.bc_buf); + ieee80211_purge_tx_queue(&local->hw, &sdata->u.ap.ps.bc_buf); + + mutex_lock(&local->mtx); + ieee80211_link_copy_chanctx_to_vlans(link, true); + ieee80211_link_release_channel(link); + mutex_unlock(&local->mtx); + + return 0; +} + +static int sta_apply_auth_flags(struct ieee80211_local *local, + struct sta_info *sta, + u32 mask, u32 set) +{ + int ret; + + if (mask & BIT(NL80211_STA_FLAG_AUTHENTICATED) && + set & BIT(NL80211_STA_FLAG_AUTHENTICATED) && + !test_sta_flag(sta, WLAN_STA_AUTH)) { + ret = sta_info_move_state(sta, IEEE80211_STA_AUTH); + if (ret) + return ret; + } + + if (mask & BIT(NL80211_STA_FLAG_ASSOCIATED) && + set & BIT(NL80211_STA_FLAG_ASSOCIATED) && + !test_sta_flag(sta, WLAN_STA_ASSOC)) { + /* + * When peer becomes associated, init rate control as + * well. Some drivers require rate control initialized + * before drv_sta_state() is called. + */ + if (!test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) + rate_control_rate_init(sta); + + ret = sta_info_move_state(sta, IEEE80211_STA_ASSOC); + if (ret) + return ret; + } + + if (mask & BIT(NL80211_STA_FLAG_AUTHORIZED)) { + if (set & BIT(NL80211_STA_FLAG_AUTHORIZED)) + ret = sta_info_move_state(sta, IEEE80211_STA_AUTHORIZED); + else if (test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + ret = sta_info_move_state(sta, IEEE80211_STA_ASSOC); + else + ret = 0; + if (ret) + return ret; + } + + if (mask & BIT(NL80211_STA_FLAG_ASSOCIATED) && + !(set & BIT(NL80211_STA_FLAG_ASSOCIATED)) && + test_sta_flag(sta, WLAN_STA_ASSOC)) { + ret = sta_info_move_state(sta, IEEE80211_STA_AUTH); + if (ret) + return ret; + } + + if (mask & BIT(NL80211_STA_FLAG_AUTHENTICATED) && + !(set & BIT(NL80211_STA_FLAG_AUTHENTICATED)) && + test_sta_flag(sta, WLAN_STA_AUTH)) { + ret = sta_info_move_state(sta, IEEE80211_STA_NONE); + if (ret) + return ret; + } + + return 0; +} + +static void sta_apply_mesh_params(struct ieee80211_local *local, + struct sta_info *sta, + struct station_parameters *params) +{ +#ifdef CONFIG_MAC80211_MESH + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 changed = 0; + + if (params->sta_modify_mask & STATION_PARAM_APPLY_PLINK_STATE) { + switch (params->plink_state) { + case NL80211_PLINK_ESTAB: + if (sta->mesh->plink_state != NL80211_PLINK_ESTAB) + changed = mesh_plink_inc_estab_count(sdata); + sta->mesh->plink_state = params->plink_state; + sta->mesh->aid = params->peer_aid; + + ieee80211_mps_sta_status_update(sta); + changed |= ieee80211_mps_set_sta_local_pm(sta, + sdata->u.mesh.mshcfg.power_mode); + + ewma_mesh_tx_rate_avg_init(&sta->mesh->tx_rate_avg); + /* init at low value */ + ewma_mesh_tx_rate_avg_add(&sta->mesh->tx_rate_avg, 10); + + break; + case NL80211_PLINK_LISTEN: + case NL80211_PLINK_BLOCKED: + case NL80211_PLINK_OPN_SNT: + case NL80211_PLINK_OPN_RCVD: + case NL80211_PLINK_CNF_RCVD: + case NL80211_PLINK_HOLDING: + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + changed = mesh_plink_dec_estab_count(sdata); + sta->mesh->plink_state = params->plink_state; + + ieee80211_mps_sta_status_update(sta); + 
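+ /* if the peer was in ESTAB, the established-peer count was
+ * dropped above; its local power save mode is reset to
+ * "unknown" below
+ */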
changed |= ieee80211_mps_set_sta_local_pm(sta, + NL80211_MESH_POWER_UNKNOWN); + break; + default: + /* nothing */ + break; + } + } + + switch (params->plink_action) { + case NL80211_PLINK_ACTION_NO_ACTION: + /* nothing */ + break; + case NL80211_PLINK_ACTION_OPEN: + changed |= mesh_plink_open(sta); + break; + case NL80211_PLINK_ACTION_BLOCK: + changed |= mesh_plink_block(sta); + break; + } + + if (params->local_pm) + changed |= ieee80211_mps_set_sta_local_pm(sta, + params->local_pm); + + ieee80211_mbss_info_change_notify(sdata, changed); +#endif +} + +static int sta_link_apply_parameters(struct ieee80211_local *local, + struct sta_info *sta, bool new_link, + struct link_station_parameters *params) +{ + int ret = 0; + struct ieee80211_supported_band *sband; + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 link_id = params->link_id < 0 ? 0 : params->link_id; + struct ieee80211_link_data *link = + sdata_dereference(sdata->link[link_id], sdata); + struct link_sta_info *link_sta = + rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&local->sta_mtx)); + + /* + * If there are no changes, then accept a link that exist, + * unless it's a new link. + */ + if (params->link_id >= 0 && !new_link && + !params->link_mac && !params->txpwr_set && + !params->supported_rates_len && + !params->ht_capa && !params->vht_capa && + !params->he_capa && !params->eht_capa && + !params->opmode_notif_used) + return 0; + + if (!link || !link_sta) + return -EINVAL; + + sband = ieee80211_get_link_sband(link); + if (!sband) + return -EINVAL; + + if (params->link_mac) { + if (new_link) { + memcpy(link_sta->addr, params->link_mac, ETH_ALEN); + memcpy(link_sta->pub->addr, params->link_mac, ETH_ALEN); + } else if (!ether_addr_equal(link_sta->addr, + params->link_mac)) { + return -EINVAL; + } + } else if (new_link) { + return -EINVAL; + } + + if (params->txpwr_set) { + link_sta->pub->txpwr.type = params->txpwr.type; + if (params->txpwr.type == NL80211_TX_POWER_LIMITED) + link_sta->pub->txpwr.power = params->txpwr.power; + ret = drv_sta_set_txpwr(local, sdata, sta); + if (ret) + return ret; + } + + if (params->supported_rates && + params->supported_rates_len) { + ieee80211_parse_bitrates(link->conf->chandef.width, + sband, params->supported_rates, + params->supported_rates_len, + &link_sta->pub->supp_rates[sband->band]); + } + + if (params->ht_capa) + ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband, + params->ht_capa, link_sta); + + /* VHT can override some HT caps such as the A-MSDU max length */ + if (params->vht_capa) + ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband, + params->vht_capa, NULL, + link_sta); + + if (params->he_capa) + ieee80211_he_cap_ie_to_sta_he_cap(sdata, sband, + (void *)params->he_capa, + params->he_capa_len, + (void *)params->he_6ghz_capa, + link_sta); + + if (params->eht_capa) + ieee80211_eht_cap_ie_to_sta_eht_cap(sdata, sband, + (u8 *)params->he_capa, + params->he_capa_len, + params->eht_capa, + params->eht_capa_len, + link_sta); + + if (params->opmode_notif_used) { + /* returned value is only needed for rc update, but the + * rc isn't initialized here yet, so ignore it + */ + __ieee80211_vht_handle_opmode(sdata, link_sta, + params->opmode_notif, + sband->band); + } + + return ret; +} + +static int sta_apply_parameters(struct ieee80211_local *local, + struct sta_info *sta, + struct station_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 mask, set; + int ret = 0; + + mask = params->sta_flags_mask; + set = params->sta_flags_set; + + if 
(ieee80211_vif_is_mesh(&sdata->vif)) { + /* + * In mesh mode, ASSOCIATED isn't part of the nl80211 + * API but must follow AUTHENTICATED for driver state. + */ + if (mask & BIT(NL80211_STA_FLAG_AUTHENTICATED)) + mask |= BIT(NL80211_STA_FLAG_ASSOCIATED); + if (set & BIT(NL80211_STA_FLAG_AUTHENTICATED)) + set |= BIT(NL80211_STA_FLAG_ASSOCIATED); + } else if (test_sta_flag(sta, WLAN_STA_TDLS_PEER)) { + /* + * TDLS -- everything follows authorized, but + * only becoming authorized is possible, not + * going back + */ + if (set & BIT(NL80211_STA_FLAG_AUTHORIZED)) { + set |= BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED); + mask |= BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED); + } + } + + if (mask & BIT(NL80211_STA_FLAG_WME) && + local->hw.queues >= IEEE80211_NUM_ACS) + sta->sta.wme = set & BIT(NL80211_STA_FLAG_WME); + + /* auth flags will be set later for TDLS, + * and for unassociated stations that move to associated */ + if (!test_sta_flag(sta, WLAN_STA_TDLS_PEER) && + !((mask & BIT(NL80211_STA_FLAG_ASSOCIATED)) && + (set & BIT(NL80211_STA_FLAG_ASSOCIATED)))) { + ret = sta_apply_auth_flags(local, sta, mask, set); + if (ret) + return ret; + } + + if (mask & BIT(NL80211_STA_FLAG_SHORT_PREAMBLE)) { + if (set & BIT(NL80211_STA_FLAG_SHORT_PREAMBLE)) + set_sta_flag(sta, WLAN_STA_SHORT_PREAMBLE); + else + clear_sta_flag(sta, WLAN_STA_SHORT_PREAMBLE); + } + + if (mask & BIT(NL80211_STA_FLAG_MFP)) { + sta->sta.mfp = !!(set & BIT(NL80211_STA_FLAG_MFP)); + if (set & BIT(NL80211_STA_FLAG_MFP)) + set_sta_flag(sta, WLAN_STA_MFP); + else + clear_sta_flag(sta, WLAN_STA_MFP); + } + + if (mask & BIT(NL80211_STA_FLAG_TDLS_PEER)) { + if (set & BIT(NL80211_STA_FLAG_TDLS_PEER)) + set_sta_flag(sta, WLAN_STA_TDLS_PEER); + else + clear_sta_flag(sta, WLAN_STA_TDLS_PEER); + } + + /* mark TDLS channel switch support, if the AP allows it */ + if (test_sta_flag(sta, WLAN_STA_TDLS_PEER) && + !sdata->deflink.u.mgd.tdls_chan_switch_prohibited && + params->ext_capab_len >= 4 && + params->ext_capab[3] & WLAN_EXT_CAPA4_TDLS_CHAN_SWITCH) + set_sta_flag(sta, WLAN_STA_TDLS_CHAN_SWITCH); + + if (test_sta_flag(sta, WLAN_STA_TDLS_PEER) && + !sdata->u.mgd.tdls_wider_bw_prohibited && + ieee80211_hw_check(&local->hw, TDLS_WIDER_BW) && + params->ext_capab_len >= 8 && + params->ext_capab[7] & WLAN_EXT_CAPA8_TDLS_WIDE_BW_ENABLED) + set_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW); + + if (params->sta_modify_mask & STATION_PARAM_APPLY_UAPSD) { + sta->sta.uapsd_queues = params->uapsd_queues; + sta->sta.max_sp = params->max_sp; + } + + ieee80211_sta_set_max_amsdu_subframes(sta, params->ext_capab, + params->ext_capab_len); + + /* + * cfg80211 validates this (1-2007) and allows setting the AID + * only when creating a new station entry + */ + if (params->aid) + sta->sta.aid = params->aid; + + /* + * Some of the following updates would be racy if called on an + * existing station, via ieee80211_change_station(). However, + * all such changes are rejected by cfg80211 except for updates + * changing the supported rates on an existing but not yet used + * TDLS peer. 
+ */
+
+ if (params->listen_interval >= 0)
+ sta->listen_interval = params->listen_interval;
+
+ ret = sta_link_apply_parameters(local, sta, false,
+ &params->link_sta_params);
+ if (ret)
+ return ret;
+
+ if (params->support_p2p_ps >= 0)
+ sta->sta.support_p2p_ps = params->support_p2p_ps;
+
+ if (ieee80211_vif_is_mesh(&sdata->vif))
+ sta_apply_mesh_params(local, sta, params);
+
+ if (params->airtime_weight)
+ sta->airtime_weight = params->airtime_weight;
+
+ /* set the STA state after all sta info from usermode has been set */
+ if (test_sta_flag(sta, WLAN_STA_TDLS_PEER) ||
+ set & BIT(NL80211_STA_FLAG_ASSOCIATED)) {
+ ret = sta_apply_auth_flags(local, sta, mask, set);
+ if (ret)
+ return ret;
+ }
+
+ /* Mark the STA as MLO if MLD MAC address is available */
+ if (params->link_sta_params.mld_mac)
+ sta->sta.mlo = true;
+
+ return 0;
+}
+
+static int ieee80211_add_station(struct wiphy *wiphy, struct net_device *dev,
+ const u8 *mac,
+ struct station_parameters *params)
+{
+ struct ieee80211_local *local = wiphy_priv(wiphy);
+ struct sta_info *sta;
+ struct ieee80211_sub_if_data *sdata;
+ int err;
+
+ if (params->vlan) {
+ sdata = IEEE80211_DEV_TO_SUB_IF(params->vlan);
+
+ if (sdata->vif.type != NL80211_IFTYPE_AP_VLAN &&
+ sdata->vif.type != NL80211_IFTYPE_AP)
+ return -EINVAL;
+ } else
+ sdata = IEEE80211_DEV_TO_SUB_IF(dev);
+
+ if (ether_addr_equal(mac, sdata->vif.addr))
+ return -EINVAL;
+
+ if (!is_valid_ether_addr(mac))
+ return -EINVAL;
+
+ if (params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER) &&
+ sdata->vif.type == NL80211_IFTYPE_STATION &&
+ !sdata->u.mgd.associated)
+ return -EINVAL;
+
+ /*
+ * If we have a link ID, it can be a non-MLO station on an AP MLD,
+ * but we need to have a link_mac in that case as well, so use the
+ * STA's MAC address in that case.
+ */
+ if (params->link_sta_params.link_id >= 0)
+ sta = sta_info_alloc_with_link(sdata, mac,
+ params->link_sta_params.link_id,
+ params->link_sta_params.link_mac ?: mac,
+ GFP_KERNEL);
+ else
+ sta = sta_info_alloc(sdata, mac, GFP_KERNEL);
+
+ if (!sta)
+ return -ENOMEM;
+
+ if (params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER))
+ sta->sta.tdls = true;
+
+ /* Though the mutex is not needed here (since the station is not
+ * visible yet), sta_apply_parameters (and inner functions) require
+ * the mutex due to other paths.
+ */ + mutex_lock(&local->sta_mtx); + err = sta_apply_parameters(local, sta, params); + mutex_unlock(&local->sta_mtx); + if (err) { + sta_info_free(local, sta); + return err; + } + + /* + * for TDLS and for unassociated station, rate control should be + * initialized only when rates are known and station is marked + * authorized/associated + */ + if (!test_sta_flag(sta, WLAN_STA_TDLS_PEER) && + test_sta_flag(sta, WLAN_STA_ASSOC)) + rate_control_rate_init(sta); + + return sta_info_insert(sta); +} + +static int ieee80211_del_station(struct wiphy *wiphy, struct net_device *dev, + struct station_del_parameters *params) +{ + struct ieee80211_sub_if_data *sdata; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + if (params->mac) + return sta_info_destroy_addr_bss(sdata, params->mac); + + sta_info_flush(sdata); + return 0; +} + +static int ieee80211_change_station(struct wiphy *wiphy, + struct net_device *dev, const u8 *mac, + struct station_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = wiphy_priv(wiphy); + struct sta_info *sta; + struct ieee80211_sub_if_data *vlansdata; + enum cfg80211_station_type statype; + int err; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get_bss(sdata, mac); + if (!sta) { + err = -ENOENT; + goto out_err; + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_MESH_POINT: + if (sdata->u.mesh.user_mpm) + statype = CFG80211_STA_MESH_PEER_USER; + else + statype = CFG80211_STA_MESH_PEER_KERNEL; + break; + case NL80211_IFTYPE_ADHOC: + statype = CFG80211_STA_IBSS; + break; + case NL80211_IFTYPE_STATION: + if (!test_sta_flag(sta, WLAN_STA_TDLS_PEER)) { + statype = CFG80211_STA_AP_STA; + break; + } + if (test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + statype = CFG80211_STA_TDLS_PEER_ACTIVE; + else + statype = CFG80211_STA_TDLS_PEER_SETUP; + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + if (test_sta_flag(sta, WLAN_STA_ASSOC)) + statype = CFG80211_STA_AP_CLIENT; + else + statype = CFG80211_STA_AP_CLIENT_UNASSOC; + break; + default: + err = -EOPNOTSUPP; + goto out_err; + } + + err = cfg80211_check_station_change(wiphy, params, statype); + if (err) + goto out_err; + + if (params->vlan && params->vlan != sta->sdata->dev) { + vlansdata = IEEE80211_DEV_TO_SUB_IF(params->vlan); + + if (params->vlan->ieee80211_ptr->use_4addr) { + if (vlansdata->u.vlan.sta) { + err = -EBUSY; + goto out_err; + } + + rcu_assign_pointer(vlansdata->u.vlan.sta, sta); + __ieee80211_check_fast_rx_iface(vlansdata); + drv_sta_set_4addr(local, sta->sdata, &sta->sta, true); + } + + if (sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN && + sta->sdata->u.vlan.sta) { + ieee80211_clear_fast_rx(sta); + RCU_INIT_POINTER(sta->sdata->u.vlan.sta, NULL); + } + + if (test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + ieee80211_vif_dec_num_mcast(sta->sdata); + + sta->sdata = vlansdata; + ieee80211_check_fast_xmit(sta); + + if (test_sta_flag(sta, WLAN_STA_AUTHORIZED)) { + ieee80211_vif_inc_num_mcast(sta->sdata); + cfg80211_send_layer2_update(sta->sdata->dev, + sta->sta.addr); + } + } + + /* we use sta_info_get_bss() so this might be different */ + if (sdata != sta->sdata) { + mutex_lock_nested(&sta->sdata->wdev.mtx, 1); + err = sta_apply_parameters(local, sta, params); + mutex_unlock(&sta->sdata->wdev.mtx); + } else { + err = sta_apply_parameters(local, sta, params); + } + if (err) + goto out_err; + + mutex_unlock(&local->sta_mtx); + + if (sdata->vif.type == NL80211_IFTYPE_STATION && + params->sta_flags_mask & 
BIT(NL80211_STA_FLAG_AUTHORIZED)) { + ieee80211_recalc_ps(local); + ieee80211_recalc_ps_vif(sdata); + } + + return 0; +out_err: + mutex_unlock(&local->sta_mtx); + return err; +} + +#ifdef CONFIG_MAC80211_MESH +static int ieee80211_add_mpath(struct wiphy *wiphy, struct net_device *dev, + const u8 *dst, const u8 *next_hop) +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + struct sta_info *sta; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + sta = sta_info_get(sdata, next_hop); + if (!sta) { + rcu_read_unlock(); + return -ENOENT; + } + + mpath = mesh_path_add(sdata, dst); + if (IS_ERR(mpath)) { + rcu_read_unlock(); + return PTR_ERR(mpath); + } + + mesh_path_fix_nexthop(mpath, sta); + + rcu_read_unlock(); + return 0; +} + +static int ieee80211_del_mpath(struct wiphy *wiphy, struct net_device *dev, + const u8 *dst) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + if (dst) + return mesh_path_del(sdata, dst); + + mesh_path_flush_by_iface(sdata); + return 0; +} + +static int ieee80211_change_mpath(struct wiphy *wiphy, struct net_device *dev, + const u8 *dst, const u8 *next_hop) +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + struct sta_info *sta; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + + sta = sta_info_get(sdata, next_hop); + if (!sta) { + rcu_read_unlock(); + return -ENOENT; + } + + mpath = mesh_path_lookup(sdata, dst); + if (!mpath) { + rcu_read_unlock(); + return -ENOENT; + } + + mesh_path_fix_nexthop(mpath, sta); + + rcu_read_unlock(); + return 0; +} + +static void mpath_set_pinfo(struct mesh_path *mpath, u8 *next_hop, + struct mpath_info *pinfo) +{ + struct sta_info *next_hop_sta = rcu_dereference(mpath->next_hop); + + if (next_hop_sta) + memcpy(next_hop, next_hop_sta->sta.addr, ETH_ALEN); + else + eth_zero_addr(next_hop); + + memset(pinfo, 0, sizeof(*pinfo)); + + pinfo->generation = mpath->sdata->u.mesh.mesh_paths_generation; + + pinfo->filled = MPATH_INFO_FRAME_QLEN | + MPATH_INFO_SN | + MPATH_INFO_METRIC | + MPATH_INFO_EXPTIME | + MPATH_INFO_DISCOVERY_TIMEOUT | + MPATH_INFO_DISCOVERY_RETRIES | + MPATH_INFO_FLAGS | + MPATH_INFO_HOP_COUNT | + MPATH_INFO_PATH_CHANGE; + + pinfo->frame_qlen = mpath->frame_queue.qlen; + pinfo->sn = mpath->sn; + pinfo->metric = mpath->metric; + if (time_before(jiffies, mpath->exp_time)) + pinfo->exptime = jiffies_to_msecs(mpath->exp_time - jiffies); + pinfo->discovery_timeout = + jiffies_to_msecs(mpath->discovery_timeout); + pinfo->discovery_retries = mpath->discovery_retries; + if (mpath->flags & MESH_PATH_ACTIVE) + pinfo->flags |= NL80211_MPATH_FLAG_ACTIVE; + if (mpath->flags & MESH_PATH_RESOLVING) + pinfo->flags |= NL80211_MPATH_FLAG_RESOLVING; + if (mpath->flags & MESH_PATH_SN_VALID) + pinfo->flags |= NL80211_MPATH_FLAG_SN_VALID; + if (mpath->flags & MESH_PATH_FIXED) + pinfo->flags |= NL80211_MPATH_FLAG_FIXED; + if (mpath->flags & MESH_PATH_RESOLVED) + pinfo->flags |= NL80211_MPATH_FLAG_RESOLVED; + pinfo->hop_count = mpath->hop_count; + pinfo->path_change_count = mpath->path_change_count; +} + +static int ieee80211_get_mpath(struct wiphy *wiphy, struct net_device *dev, + u8 *dst, u8 *next_hop, struct mpath_info *pinfo) + +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, dst); + if (!mpath) { + rcu_read_unlock(); + return -ENOENT; + } + memcpy(dst, mpath->dst, ETH_ALEN); + mpath_set_pinfo(mpath, next_hop, pinfo); + rcu_read_unlock(); 
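+ /* dst, next_hop and pinfo were all filled in under rcu_read_lock();
+ * the path entry must not be touched once the read side is released
+ */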
+ return 0; +} + +static int ieee80211_dump_mpath(struct wiphy *wiphy, struct net_device *dev, + int idx, u8 *dst, u8 *next_hop, + struct mpath_info *pinfo) +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + mpath = mesh_path_lookup_by_idx(sdata, idx); + if (!mpath) { + rcu_read_unlock(); + return -ENOENT; + } + memcpy(dst, mpath->dst, ETH_ALEN); + mpath_set_pinfo(mpath, next_hop, pinfo); + rcu_read_unlock(); + return 0; +} + +static void mpp_set_pinfo(struct mesh_path *mpath, u8 *mpp, + struct mpath_info *pinfo) +{ + memset(pinfo, 0, sizeof(*pinfo)); + memcpy(mpp, mpath->mpp, ETH_ALEN); + + pinfo->generation = mpath->sdata->u.mesh.mpp_paths_generation; +} + +static int ieee80211_get_mpp(struct wiphy *wiphy, struct net_device *dev, + u8 *dst, u8 *mpp, struct mpath_info *pinfo) + +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + mpath = mpp_path_lookup(sdata, dst); + if (!mpath) { + rcu_read_unlock(); + return -ENOENT; + } + memcpy(dst, mpath->dst, ETH_ALEN); + mpp_set_pinfo(mpath, mpp, pinfo); + rcu_read_unlock(); + return 0; +} + +static int ieee80211_dump_mpp(struct wiphy *wiphy, struct net_device *dev, + int idx, u8 *dst, u8 *mpp, + struct mpath_info *pinfo) +{ + struct ieee80211_sub_if_data *sdata; + struct mesh_path *mpath; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + rcu_read_lock(); + mpath = mpp_path_lookup_by_idx(sdata, idx); + if (!mpath) { + rcu_read_unlock(); + return -ENOENT; + } + memcpy(dst, mpath->dst, ETH_ALEN); + mpp_set_pinfo(mpath, mpp, pinfo); + rcu_read_unlock(); + return 0; +} + +static int ieee80211_get_mesh_config(struct wiphy *wiphy, + struct net_device *dev, + struct mesh_config *conf) +{ + struct ieee80211_sub_if_data *sdata; + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + memcpy(conf, &(sdata->u.mesh.mshcfg), sizeof(struct mesh_config)); + return 0; +} + +static inline bool _chg_mesh_attr(enum nl80211_meshconf_params parm, u32 mask) +{ + return (mask >> (parm-1)) & 0x1; +} + +static int copy_mesh_setup(struct ieee80211_if_mesh *ifmsh, + const struct mesh_setup *setup) +{ + u8 *new_ie; + struct ieee80211_sub_if_data *sdata = container_of(ifmsh, + struct ieee80211_sub_if_data, u.mesh); + int i; + + /* allocate information elements */ + new_ie = NULL; + + if (setup->ie_len) { + new_ie = kmemdup(setup->ie, setup->ie_len, + GFP_KERNEL); + if (!new_ie) + return -ENOMEM; + } + ifmsh->ie_len = setup->ie_len; + ifmsh->ie = new_ie; + + /* now copy the rest of the setup parameters */ + ifmsh->mesh_id_len = setup->mesh_id_len; + memcpy(ifmsh->mesh_id, setup->mesh_id, ifmsh->mesh_id_len); + ifmsh->mesh_sp_id = setup->sync_method; + ifmsh->mesh_pp_id = setup->path_sel_proto; + ifmsh->mesh_pm_id = setup->path_metric; + ifmsh->user_mpm = setup->user_mpm; + ifmsh->mesh_auth_id = setup->auth_id; + ifmsh->security = IEEE80211_MESH_SEC_NONE; + ifmsh->userspace_handles_dfs = setup->userspace_handles_dfs; + if (setup->is_authenticated) + ifmsh->security |= IEEE80211_MESH_SEC_AUTHED; + if (setup->is_secure) + ifmsh->security |= IEEE80211_MESH_SEC_SECURED; + + /* mcast rate setting in Mesh Node */ + memcpy(sdata->vif.bss_conf.mcast_rate, setup->mcast_rate, + sizeof(setup->mcast_rate)); + sdata->vif.bss_conf.basic_rates = setup->basic_rates; + + sdata->vif.bss_conf.beacon_int = setup->beacon_interval; + sdata->vif.bss_conf.dtim_period = setup->dtim_period; + + sdata->beacon_rate_set = false; + if 
(wiphy_ext_feature_isset(sdata->local->hw.wiphy, + NL80211_EXT_FEATURE_BEACON_RATE_LEGACY)) { + for (i = 0; i < NUM_NL80211_BANDS; i++) { + sdata->beacon_rateidx_mask[i] = + setup->beacon_rate.control[i].legacy; + if (sdata->beacon_rateidx_mask[i]) + sdata->beacon_rate_set = true; + } + } + + return 0; +} + +static int ieee80211_update_mesh_config(struct wiphy *wiphy, + struct net_device *dev, u32 mask, + const struct mesh_config *nconf) +{ + struct mesh_config *conf; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_if_mesh *ifmsh; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + ifmsh = &sdata->u.mesh; + + /* Set the config options which we are interested in setting */ + conf = &(sdata->u.mesh.mshcfg); + if (_chg_mesh_attr(NL80211_MESHCONF_RETRY_TIMEOUT, mask)) + conf->dot11MeshRetryTimeout = nconf->dot11MeshRetryTimeout; + if (_chg_mesh_attr(NL80211_MESHCONF_CONFIRM_TIMEOUT, mask)) + conf->dot11MeshConfirmTimeout = nconf->dot11MeshConfirmTimeout; + if (_chg_mesh_attr(NL80211_MESHCONF_HOLDING_TIMEOUT, mask)) + conf->dot11MeshHoldingTimeout = nconf->dot11MeshHoldingTimeout; + if (_chg_mesh_attr(NL80211_MESHCONF_MAX_PEER_LINKS, mask)) + conf->dot11MeshMaxPeerLinks = nconf->dot11MeshMaxPeerLinks; + if (_chg_mesh_attr(NL80211_MESHCONF_MAX_RETRIES, mask)) + conf->dot11MeshMaxRetries = nconf->dot11MeshMaxRetries; + if (_chg_mesh_attr(NL80211_MESHCONF_TTL, mask)) + conf->dot11MeshTTL = nconf->dot11MeshTTL; + if (_chg_mesh_attr(NL80211_MESHCONF_ELEMENT_TTL, mask)) + conf->element_ttl = nconf->element_ttl; + if (_chg_mesh_attr(NL80211_MESHCONF_AUTO_OPEN_PLINKS, mask)) { + if (ifmsh->user_mpm) + return -EBUSY; + conf->auto_open_plinks = nconf->auto_open_plinks; + } + if (_chg_mesh_attr(NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR, mask)) + conf->dot11MeshNbrOffsetMaxNeighbor = + nconf->dot11MeshNbrOffsetMaxNeighbor; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES, mask)) + conf->dot11MeshHWMPmaxPREQretries = + nconf->dot11MeshHWMPmaxPREQretries; + if (_chg_mesh_attr(NL80211_MESHCONF_PATH_REFRESH_TIME, mask)) + conf->path_refresh_time = nconf->path_refresh_time; + if (_chg_mesh_attr(NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT, mask)) + conf->min_discovery_timeout = nconf->min_discovery_timeout; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT, mask)) + conf->dot11MeshHWMPactivePathTimeout = + nconf->dot11MeshHWMPactivePathTimeout; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL, mask)) + conf->dot11MeshHWMPpreqMinInterval = + nconf->dot11MeshHWMPpreqMinInterval; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL, mask)) + conf->dot11MeshHWMPperrMinInterval = + nconf->dot11MeshHWMPperrMinInterval; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME, + mask)) + conf->dot11MeshHWMPnetDiameterTraversalTime = + nconf->dot11MeshHWMPnetDiameterTraversalTime; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_ROOTMODE, mask)) { + conf->dot11MeshHWMPRootMode = nconf->dot11MeshHWMPRootMode; + ieee80211_mesh_root_setup(ifmsh); + } + if (_chg_mesh_attr(NL80211_MESHCONF_GATE_ANNOUNCEMENTS, mask)) { + /* our current gate announcement implementation rides on root + * announcements, so require this ifmsh to also be a root node + * */ + if (nconf->dot11MeshGateAnnouncementProtocol && + !(conf->dot11MeshHWMPRootMode > IEEE80211_ROOTMODE_ROOT)) { + conf->dot11MeshHWMPRootMode = IEEE80211_PROACTIVE_RANN; + ieee80211_mesh_root_setup(ifmsh); + } + conf->dot11MeshGateAnnouncementProtocol = + nconf->dot11MeshGateAnnouncementProtocol; + } + if 
(_chg_mesh_attr(NL80211_MESHCONF_HWMP_RANN_INTERVAL, mask)) + conf->dot11MeshHWMPRannInterval = + nconf->dot11MeshHWMPRannInterval; + if (_chg_mesh_attr(NL80211_MESHCONF_FORWARDING, mask)) + conf->dot11MeshForwarding = nconf->dot11MeshForwarding; + if (_chg_mesh_attr(NL80211_MESHCONF_RSSI_THRESHOLD, mask)) { + /* our RSSI threshold implementation is supported only for + * devices that report signal in dBm. + */ + if (!ieee80211_hw_check(&sdata->local->hw, SIGNAL_DBM)) + return -ENOTSUPP; + conf->rssi_threshold = nconf->rssi_threshold; + } + if (_chg_mesh_attr(NL80211_MESHCONF_HT_OPMODE, mask)) { + conf->ht_opmode = nconf->ht_opmode; + sdata->vif.bss_conf.ht_operation_mode = nconf->ht_opmode; + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_HT); + } + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT, mask)) + conf->dot11MeshHWMPactivePathToRootTimeout = + nconf->dot11MeshHWMPactivePathToRootTimeout; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_ROOT_INTERVAL, mask)) + conf->dot11MeshHWMProotInterval = + nconf->dot11MeshHWMProotInterval; + if (_chg_mesh_attr(NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL, mask)) + conf->dot11MeshHWMPconfirmationInterval = + nconf->dot11MeshHWMPconfirmationInterval; + if (_chg_mesh_attr(NL80211_MESHCONF_POWER_MODE, mask)) { + conf->power_mode = nconf->power_mode; + ieee80211_mps_local_status_update(sdata); + } + if (_chg_mesh_attr(NL80211_MESHCONF_AWAKE_WINDOW, mask)) + conf->dot11MeshAwakeWindowDuration = + nconf->dot11MeshAwakeWindowDuration; + if (_chg_mesh_attr(NL80211_MESHCONF_PLINK_TIMEOUT, mask)) + conf->plink_timeout = nconf->plink_timeout; + if (_chg_mesh_attr(NL80211_MESHCONF_CONNECTED_TO_GATE, mask)) + conf->dot11MeshConnectedToMeshGate = + nconf->dot11MeshConnectedToMeshGate; + if (_chg_mesh_attr(NL80211_MESHCONF_NOLEARN, mask)) + conf->dot11MeshNolearn = nconf->dot11MeshNolearn; + if (_chg_mesh_attr(NL80211_MESHCONF_CONNECTED_TO_AS, mask)) + conf->dot11MeshConnectedToAuthServer = + nconf->dot11MeshConnectedToAuthServer; + ieee80211_mbss_info_change_notify(sdata, BSS_CHANGED_BEACON); + return 0; +} + +static int ieee80211_join_mesh(struct wiphy *wiphy, struct net_device *dev, + const struct mesh_config *conf, + const struct mesh_setup *setup) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + int err; + + memcpy(&ifmsh->mshcfg, conf, sizeof(struct mesh_config)); + err = copy_mesh_setup(ifmsh, setup); + if (err) + return err; + + sdata->control_port_over_nl80211 = setup->control_port_over_nl80211; + + /* can mesh use other SMPS modes? 
*/ + sdata->deflink.smps_mode = IEEE80211_SMPS_OFF; + sdata->deflink.needed_rx_chains = sdata->local->rx_chains; + + mutex_lock(&sdata->local->mtx); + err = ieee80211_link_use_channel(&sdata->deflink, &setup->chandef, + IEEE80211_CHANCTX_SHARED); + mutex_unlock(&sdata->local->mtx); + if (err) + return err; + + return ieee80211_start_mesh(sdata); +} + +static int ieee80211_leave_mesh(struct wiphy *wiphy, struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + ieee80211_stop_mesh(sdata); + mutex_lock(&sdata->local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + kfree(sdata->u.mesh.ie); + mutex_unlock(&sdata->local->mtx); + + return 0; +} +#endif + +static int ieee80211_change_bss(struct wiphy *wiphy, + struct net_device *dev, + struct bss_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_supported_band *sband; + u32 changed = 0; + + if (!sdata_dereference(sdata->deflink.u.ap.beacon, sdata)) + return -ENOENT; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + if (params->use_cts_prot >= 0) { + sdata->vif.bss_conf.use_cts_prot = params->use_cts_prot; + changed |= BSS_CHANGED_ERP_CTS_PROT; + } + if (params->use_short_preamble >= 0) { + sdata->vif.bss_conf.use_short_preamble = + params->use_short_preamble; + changed |= BSS_CHANGED_ERP_PREAMBLE; + } + + if (!sdata->vif.bss_conf.use_short_slot && + (sband->band == NL80211_BAND_5GHZ || + sband->band == NL80211_BAND_6GHZ)) { + sdata->vif.bss_conf.use_short_slot = true; + changed |= BSS_CHANGED_ERP_SLOT; + } + + if (params->use_short_slot_time >= 0) { + sdata->vif.bss_conf.use_short_slot = + params->use_short_slot_time; + changed |= BSS_CHANGED_ERP_SLOT; + } + + if (params->basic_rates) { + ieee80211_parse_bitrates(sdata->vif.bss_conf.chandef.width, + wiphy->bands[sband->band], + params->basic_rates, + params->basic_rates_len, + &sdata->vif.bss_conf.basic_rates); + changed |= BSS_CHANGED_BASIC_RATES; + ieee80211_check_rate_mask(&sdata->deflink); + } + + if (params->ap_isolate >= 0) { + if (params->ap_isolate) + sdata->flags |= IEEE80211_SDATA_DONT_BRIDGE_PACKETS; + else + sdata->flags &= ~IEEE80211_SDATA_DONT_BRIDGE_PACKETS; + ieee80211_check_fast_rx_iface(sdata); + } + + if (params->ht_opmode >= 0) { + sdata->vif.bss_conf.ht_operation_mode = + (u16) params->ht_opmode; + changed |= BSS_CHANGED_HT; + } + + if (params->p2p_ctwindow >= 0) { + sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow &= + ~IEEE80211_P2P_OPPPS_CTWINDOW_MASK; + sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow |= + params->p2p_ctwindow & IEEE80211_P2P_OPPPS_CTWINDOW_MASK; + changed |= BSS_CHANGED_P2P_PS; + } + + if (params->p2p_opp_ps > 0) { + sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow |= + IEEE80211_P2P_OPPPS_ENABLE_BIT; + changed |= BSS_CHANGED_P2P_PS; + } else if (params->p2p_opp_ps == 0) { + sdata->vif.bss_conf.p2p_noa_attr.oppps_ctwindow &= + ~IEEE80211_P2P_OPPPS_ENABLE_BIT; + changed |= BSS_CHANGED_P2P_PS; + } + + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + + return 0; +} + +static int ieee80211_set_txq_params(struct wiphy *wiphy, + struct net_device *dev, + struct ieee80211_txq_params *params) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link = + ieee80211_link_or_deflink(sdata, params->link_id, true); + struct ieee80211_tx_queue_params p; + + if (!local->ops->conf_tx) + return -EOPNOTSUPP; + + if 
(local->hw.queues < IEEE80211_NUM_ACS) + return -EOPNOTSUPP; + + if (IS_ERR(link)) + return PTR_ERR(link); + + memset(&p, 0, sizeof(p)); + p.aifs = params->aifs; + p.cw_max = params->cwmax; + p.cw_min = params->cwmin; + p.txop = params->txop; + + /* + * Setting tx queue params disables u-apsd because it's only + * called in master mode. + */ + p.uapsd = false; + + ieee80211_regulatory_limit_wmm_params(sdata, &p, params->ac); + + link->tx_conf[params->ac] = p; + if (drv_conf_tx(local, link, params->ac, &p)) { + wiphy_debug(local->hw.wiphy, + "failed to set TX queue parameters for AC %d\n", + params->ac); + return -EINVAL; + } + + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_QOS); + + return 0; +} + +#ifdef CONFIG_PM +static int ieee80211_suspend(struct wiphy *wiphy, + struct cfg80211_wowlan *wowlan) +{ + return __ieee80211_suspend(wiphy_priv(wiphy), wowlan); +} + +static int ieee80211_resume(struct wiphy *wiphy) +{ + return __ieee80211_resume(wiphy_priv(wiphy)); +} +#else +#define ieee80211_suspend NULL +#define ieee80211_resume NULL +#endif + +static int ieee80211_scan(struct wiphy *wiphy, + struct cfg80211_scan_request *req) +{ + struct ieee80211_sub_if_data *sdata; + + sdata = IEEE80211_WDEV_TO_SUB_IF(req->wdev); + + switch (ieee80211_vif_type_p2p(&sdata->vif)) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_DEVICE: + break; + case NL80211_IFTYPE_P2P_GO: + if (sdata->local->ops->hw_scan) + break; + /* + * FIXME: implement NoA while scanning in software, + * for now fall through to allow scanning only when + * beaconing hasn't been configured yet + */ + fallthrough; + case NL80211_IFTYPE_AP: + /* + * If the scan has been forced (and the driver supports + * forcing), don't care about being beaconing already. + * This will create problems to the attached stations (e.g. 
all + * the frames sent while scanning on other channel will be + * lost) + */ + if (sdata->deflink.u.ap.beacon && + (!(wiphy->features & NL80211_FEATURE_AP_SCAN) || + !(req->flags & NL80211_SCAN_FLAG_AP))) + return -EOPNOTSUPP; + break; + case NL80211_IFTYPE_NAN: + default: + return -EOPNOTSUPP; + } + + return ieee80211_request_scan(sdata, req); +} + +static void ieee80211_abort_scan(struct wiphy *wiphy, struct wireless_dev *wdev) +{ + ieee80211_scan_cancel(wiphy_priv(wiphy)); +} + +static int +ieee80211_sched_scan_start(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_sched_scan_request *req) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + if (!sdata->local->ops->sched_scan_start) + return -EOPNOTSUPP; + + return ieee80211_request_sched_scan_start(sdata, req); +} + +static int +ieee80211_sched_scan_stop(struct wiphy *wiphy, struct net_device *dev, + u64 reqid) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + if (!local->ops->sched_scan_stop) + return -EOPNOTSUPP; + + return ieee80211_request_sched_scan_stop(local); +} + +static int ieee80211_auth(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_auth_request *req) +{ + return ieee80211_mgd_auth(IEEE80211_DEV_TO_SUB_IF(dev), req); +} + +static int ieee80211_assoc(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_assoc_request *req) +{ + return ieee80211_mgd_assoc(IEEE80211_DEV_TO_SUB_IF(dev), req); +} + +static int ieee80211_deauth(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_deauth_request *req) +{ + return ieee80211_mgd_deauth(IEEE80211_DEV_TO_SUB_IF(dev), req); +} + +static int ieee80211_disassoc(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_disassoc_request *req) +{ + return ieee80211_mgd_disassoc(IEEE80211_DEV_TO_SUB_IF(dev), req); +} + +static int ieee80211_join_ibss(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_ibss_params *params) +{ + return ieee80211_ibss_join(IEEE80211_DEV_TO_SUB_IF(dev), params); +} + +static int ieee80211_leave_ibss(struct wiphy *wiphy, struct net_device *dev) +{ + return ieee80211_ibss_leave(IEEE80211_DEV_TO_SUB_IF(dev)); +} + +static int ieee80211_join_ocb(struct wiphy *wiphy, struct net_device *dev, + struct ocb_setup *setup) +{ + return ieee80211_ocb_join(IEEE80211_DEV_TO_SUB_IF(dev), setup); +} + +static int ieee80211_leave_ocb(struct wiphy *wiphy, struct net_device *dev) +{ + return ieee80211_ocb_leave(IEEE80211_DEV_TO_SUB_IF(dev)); +} + +static int ieee80211_set_mcast_rate(struct wiphy *wiphy, struct net_device *dev, + int rate[NUM_NL80211_BANDS]) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + memcpy(sdata->vif.bss_conf.mcast_rate, rate, + sizeof(int) * NUM_NL80211_BANDS); + + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_MCAST_RATE); + + return 0; +} + +static int ieee80211_set_wiphy_params(struct wiphy *wiphy, u32 changed) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + int err; + + if (changed & WIPHY_PARAM_FRAG_THRESHOLD) { + ieee80211_check_fast_xmit_all(local); + + err = drv_set_frag_threshold(local, wiphy->frag_threshold); + + if (err) { + ieee80211_check_fast_xmit_all(local); + return err; + } + } + + if ((changed & WIPHY_PARAM_COVERAGE_CLASS) || + (changed & WIPHY_PARAM_DYN_ACK)) { + s16 coverage_class; + + coverage_class = changed & WIPHY_PARAM_COVERAGE_CLASS ? 
+ wiphy->coverage_class : -1; + err = drv_set_coverage_class(local, coverage_class); + + if (err) + return err; + } + + if (changed & WIPHY_PARAM_RTS_THRESHOLD) { + err = drv_set_rts_threshold(local, wiphy->rts_threshold); + + if (err) + return err; + } + + if (changed & WIPHY_PARAM_RETRY_SHORT) { + if (wiphy->retry_short > IEEE80211_MAX_TX_RETRY) + return -EINVAL; + local->hw.conf.short_frame_max_tx_count = wiphy->retry_short; + } + if (changed & WIPHY_PARAM_RETRY_LONG) { + if (wiphy->retry_long > IEEE80211_MAX_TX_RETRY) + return -EINVAL; + local->hw.conf.long_frame_max_tx_count = wiphy->retry_long; + } + if (changed & + (WIPHY_PARAM_RETRY_SHORT | WIPHY_PARAM_RETRY_LONG)) + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_RETRY_LIMITS); + + if (changed & (WIPHY_PARAM_TXQ_LIMIT | + WIPHY_PARAM_TXQ_MEMORY_LIMIT | + WIPHY_PARAM_TXQ_QUANTUM)) + ieee80211_txq_set_params(local); + + return 0; +} + +static int ieee80211_set_tx_power(struct wiphy *wiphy, + struct wireless_dev *wdev, + enum nl80211_tx_power_setting type, int mbm) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata; + enum nl80211_tx_power_setting txp_type = type; + bool update_txp_type = false; + bool has_monitor = false; + + if (wdev) { + sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) { + sdata = wiphy_dereference(local->hw.wiphy, + local->monitor_sdata); + if (!sdata) + return -EOPNOTSUPP; + } + + switch (type) { + case NL80211_TX_POWER_AUTOMATIC: + sdata->deflink.user_power_level = + IEEE80211_UNSET_POWER_LEVEL; + txp_type = NL80211_TX_POWER_LIMITED; + break; + case NL80211_TX_POWER_LIMITED: + case NL80211_TX_POWER_FIXED: + if (mbm < 0 || (mbm % 100)) + return -EOPNOTSUPP; + sdata->deflink.user_power_level = MBM_TO_DBM(mbm); + break; + } + + if (txp_type != sdata->vif.bss_conf.txpower_type) { + update_txp_type = true; + sdata->vif.bss_conf.txpower_type = txp_type; + } + + ieee80211_recalc_txpower(sdata, update_txp_type); + + return 0; + } + + switch (type) { + case NL80211_TX_POWER_AUTOMATIC: + local->user_power_level = IEEE80211_UNSET_POWER_LEVEL; + txp_type = NL80211_TX_POWER_LIMITED; + break; + case NL80211_TX_POWER_LIMITED: + case NL80211_TX_POWER_FIXED: + if (mbm < 0 || (mbm % 100)) + return -EOPNOTSUPP; + local->user_power_level = MBM_TO_DBM(mbm); + break; + } + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) { + has_monitor = true; + continue; + } + sdata->deflink.user_power_level = local->user_power_level; + if (txp_type != sdata->vif.bss_conf.txpower_type) + update_txp_type = true; + sdata->vif.bss_conf.txpower_type = txp_type; + } + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) + continue; + ieee80211_recalc_txpower(sdata, update_txp_type); + } + mutex_unlock(&local->iflist_mtx); + + if (has_monitor) { + sdata = wiphy_dereference(local->hw.wiphy, + local->monitor_sdata); + if (sdata) { + sdata->deflink.user_power_level = local->user_power_level; + if (txp_type != sdata->vif.bss_conf.txpower_type) + update_txp_type = true; + sdata->vif.bss_conf.txpower_type = txp_type; + + ieee80211_recalc_txpower(sdata, update_txp_type); + } + } + + return 0; +} + +static int ieee80211_get_tx_power(struct wiphy *wiphy, + struct wireless_dev *wdev, + int *dbm) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + if 
(local->ops->get_txpower) + return drv_get_txpower(local, sdata, dbm); + + if (!local->use_chanctx) + *dbm = local->hw.conf.power_level; + else + *dbm = sdata->vif.bss_conf.txpower; + + /* INT_MIN indicates no power level was set yet */ + if (*dbm == INT_MIN) + return -EINVAL; + + return 0; +} + +static void ieee80211_rfkill_poll(struct wiphy *wiphy) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + drv_rfkill_poll(local); +} + +#ifdef CONFIG_NL80211_TESTMODE +static int ieee80211_testmode_cmd(struct wiphy *wiphy, + struct wireless_dev *wdev, + void *data, int len) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_vif *vif = NULL; + + if (!local->ops->testmode_cmd) + return -EOPNOTSUPP; + + if (wdev) { + struct ieee80211_sub_if_data *sdata; + + sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + if (sdata->flags & IEEE80211_SDATA_IN_DRIVER) + vif = &sdata->vif; + } + + return local->ops->testmode_cmd(&local->hw, vif, data, len); +} + +static int ieee80211_testmode_dump(struct wiphy *wiphy, + struct sk_buff *skb, + struct netlink_callback *cb, + void *data, int len) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + if (!local->ops->testmode_dump) + return -EOPNOTSUPP; + + return local->ops->testmode_dump(&local->hw, skb, cb, data, len); +} +#endif + +int __ieee80211_request_smps_mgd(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + enum ieee80211_smps_mode smps_mode) +{ + const u8 *ap; + enum ieee80211_smps_mode old_req; + int err; + struct sta_info *sta; + bool tdls_peer_found = false; + + lockdep_assert_held(&sdata->wdev.mtx); + + if (WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_STATION)) + return -EINVAL; + + old_req = link->u.mgd.req_smps; + link->u.mgd.req_smps = smps_mode; + + if (old_req == smps_mode && + smps_mode != IEEE80211_SMPS_AUTOMATIC) + return 0; + + /* + * If not associated, or current association is not an HT + * association, there's no need to do anything, just store + * the new value until we associate. 
+ */ + if (!sdata->u.mgd.associated || + link->conf->chandef.width == NL80211_CHAN_WIDTH_20_NOHT) + return 0; + + ap = link->u.mgd.bssid; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) { + if (!sta->sta.tdls || sta->sdata != sdata || !sta->uploaded || + !test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + continue; + + tdls_peer_found = true; + break; + } + rcu_read_unlock(); + + if (smps_mode == IEEE80211_SMPS_AUTOMATIC) { + if (tdls_peer_found || !sdata->u.mgd.powersave) + smps_mode = IEEE80211_SMPS_OFF; + else + smps_mode = IEEE80211_SMPS_DYNAMIC; + } + + /* send SM PS frame to AP */ + err = ieee80211_send_smps_action(sdata, smps_mode, + ap, ap); + if (err) + link->u.mgd.req_smps = old_req; + else if (smps_mode != IEEE80211_SMPS_OFF && tdls_peer_found) + ieee80211_teardown_tdls_peers(sdata); + + return err; +} + +static int ieee80211_set_power_mgmt(struct wiphy *wiphy, struct net_device *dev, + bool enabled, int timeout) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr); + unsigned int link_id; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_PS)) + return -EOPNOTSUPP; + + if (enabled == sdata->u.mgd.powersave && + timeout == local->dynamic_ps_forced_timeout) + return 0; + + sdata->u.mgd.powersave = enabled; + local->dynamic_ps_forced_timeout = timeout; + + /* no change, but if automatic follow powersave */ + sdata_lock(sdata); + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + + if (!link) + continue; + __ieee80211_request_smps_mgd(sdata, link, + link->u.mgd.req_smps); + } + sdata_unlock(sdata); + + if (ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS)) + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + + ieee80211_recalc_ps(local); + ieee80211_recalc_ps_vif(sdata); + ieee80211_check_fast_rx_iface(sdata); + + return 0; +} + +static int ieee80211_set_cqm_rssi_config(struct wiphy *wiphy, + struct net_device *dev, + s32 rssi_thold, u32 rssi_hyst) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_vif *vif = &sdata->vif; + struct ieee80211_bss_conf *bss_conf = &vif->bss_conf; + + if (rssi_thold == bss_conf->cqm_rssi_thold && + rssi_hyst == bss_conf->cqm_rssi_hyst) + return 0; + + if (sdata->vif.driver_flags & IEEE80211_VIF_BEACON_FILTER && + !(sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_CQM_RSSI)) + return -EOPNOTSUPP; + + bss_conf->cqm_rssi_thold = rssi_thold; + bss_conf->cqm_rssi_hyst = rssi_hyst; + bss_conf->cqm_rssi_low = 0; + bss_conf->cqm_rssi_high = 0; + sdata->deflink.u.mgd.last_cqm_event_signal = 0; + + /* tell the driver upon association, unless already associated */ + if (sdata->u.mgd.associated && + sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_CQM_RSSI) + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_CQM); + + return 0; +} + +static int ieee80211_set_cqm_rssi_range_config(struct wiphy *wiphy, + struct net_device *dev, + s32 rssi_low, s32 rssi_high) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_vif *vif = &sdata->vif; + struct ieee80211_bss_conf *bss_conf = &vif->bss_conf; + + if (sdata->vif.driver_flags & IEEE80211_VIF_BEACON_FILTER) + return -EOPNOTSUPP; + + bss_conf->cqm_rssi_low = rssi_low; + bss_conf->cqm_rssi_high = rssi_high; + 
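+ /* range-triggered CQM and the threshold/hysteresis variant are
+ * mutually exclusive, so the other form is zeroed out below (the
+ * threshold setter clears the range fields in the same way)
+ */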
bss_conf->cqm_rssi_thold = 0; + bss_conf->cqm_rssi_hyst = 0; + sdata->deflink.u.mgd.last_cqm_event_signal = 0; + + /* tell the driver upon association, unless already associated */ + if (sdata->u.mgd.associated && + sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_CQM_RSSI) + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_CQM); + + return 0; +} + +static int ieee80211_set_bitrate_mask(struct wiphy *wiphy, + struct net_device *dev, + unsigned int link_id, + const u8 *addr, + const struct cfg80211_bitrate_mask *mask) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr); + int i, ret; + + if (!ieee80211_sdata_running(sdata)) + return -ENETDOWN; + + /* + * If active validate the setting and reject it if it doesn't leave + * at least one basic rate usable, since we really have to be able + * to send something, and if we're an AP we have to be able to do + * so at a basic rate so that all clients can receive it. + */ + if (rcu_access_pointer(sdata->vif.bss_conf.chanctx_conf) && + sdata->vif.bss_conf.chandef.chan) { + u32 basic_rates = sdata->vif.bss_conf.basic_rates; + enum nl80211_band band = sdata->vif.bss_conf.chandef.chan->band; + + if (!(mask->control[band].legacy & basic_rates)) + return -EINVAL; + } + + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) { + ret = drv_set_bitrate_mask(local, sdata, mask); + if (ret) + return ret; + } + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + struct ieee80211_supported_band *sband = wiphy->bands[i]; + int j; + + sdata->rc_rateidx_mask[i] = mask->control[i].legacy; + memcpy(sdata->rc_rateidx_mcs_mask[i], mask->control[i].ht_mcs, + sizeof(mask->control[i].ht_mcs)); + memcpy(sdata->rc_rateidx_vht_mcs_mask[i], + mask->control[i].vht_mcs, + sizeof(mask->control[i].vht_mcs)); + + sdata->rc_has_mcs_mask[i] = false; + sdata->rc_has_vht_mcs_mask[i] = false; + if (!sband) + continue; + + for (j = 0; j < IEEE80211_HT_MCS_MASK_LEN; j++) { + if (sdata->rc_rateidx_mcs_mask[i][j] != 0xff) { + sdata->rc_has_mcs_mask[i] = true; + break; + } + } + + for (j = 0; j < NL80211_VHT_NSS_MAX; j++) { + if (sdata->rc_rateidx_vht_mcs_mask[i][j] != 0xffff) { + sdata->rc_has_vht_mcs_mask[i] = true; + break; + } + } + } + + return 0; +} + +static int ieee80211_start_radar_detection(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_chan_def *chandef, + u32 cac_time_ms) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + int err; + + mutex_lock(&local->mtx); + if (!list_empty(&local->roc_list) || local->scanning) { + err = -EBUSY; + goto out_unlock; + } + + /* whatever, but channel contexts should not complain about that one */ + sdata->deflink.smps_mode = IEEE80211_SMPS_OFF; + sdata->deflink.needed_rx_chains = local->rx_chains; + + err = ieee80211_link_use_channel(&sdata->deflink, chandef, + IEEE80211_CHANCTX_SHARED); + if (err) + goto out_unlock; + + ieee80211_queue_delayed_work(&sdata->local->hw, + &sdata->deflink.dfs_cac_timer_work, + msecs_to_jiffies(cac_time_ms)); + + out_unlock: + mutex_unlock(&local->mtx); + return err; +} + +static void ieee80211_end_cac(struct wiphy *wiphy, + struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + + mutex_lock(&local->mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + /* it might be waiting for the local->mtx, but then + * by the time it gets 
it, sdata->wdev.cac_started + * will no longer be true + */ + cancel_delayed_work(&sdata->deflink.dfs_cac_timer_work); + + if (sdata->wdev.cac_started) { + ieee80211_link_release_channel(&sdata->deflink); + sdata->wdev.cac_started = false; + } + } + mutex_unlock(&local->mtx); +} + +static struct cfg80211_beacon_data * +cfg80211_beacon_dup(struct cfg80211_beacon_data *beacon) +{ + struct cfg80211_beacon_data *new_beacon; + u8 *pos; + int len; + + len = beacon->head_len + beacon->tail_len + beacon->beacon_ies_len + + beacon->proberesp_ies_len + beacon->assocresp_ies_len + + beacon->probe_resp_len + beacon->lci_len + beacon->civicloc_len + + ieee80211_get_mbssid_beacon_len(beacon->mbssid_ies); + + new_beacon = kzalloc(sizeof(*new_beacon) + len, GFP_KERNEL); + if (!new_beacon) + return NULL; + + if (beacon->mbssid_ies && beacon->mbssid_ies->cnt) { + new_beacon->mbssid_ies = + kzalloc(struct_size(new_beacon->mbssid_ies, + elem, beacon->mbssid_ies->cnt), + GFP_KERNEL); + if (!new_beacon->mbssid_ies) { + kfree(new_beacon); + return NULL; + } + } + + pos = (u8 *)(new_beacon + 1); + if (beacon->head_len) { + new_beacon->head_len = beacon->head_len; + new_beacon->head = pos; + memcpy(pos, beacon->head, beacon->head_len); + pos += beacon->head_len; + } + if (beacon->tail_len) { + new_beacon->tail_len = beacon->tail_len; + new_beacon->tail = pos; + memcpy(pos, beacon->tail, beacon->tail_len); + pos += beacon->tail_len; + } + if (beacon->beacon_ies_len) { + new_beacon->beacon_ies_len = beacon->beacon_ies_len; + new_beacon->beacon_ies = pos; + memcpy(pos, beacon->beacon_ies, beacon->beacon_ies_len); + pos += beacon->beacon_ies_len; + } + if (beacon->proberesp_ies_len) { + new_beacon->proberesp_ies_len = beacon->proberesp_ies_len; + new_beacon->proberesp_ies = pos; + memcpy(pos, beacon->proberesp_ies, beacon->proberesp_ies_len); + pos += beacon->proberesp_ies_len; + } + if (beacon->assocresp_ies_len) { + new_beacon->assocresp_ies_len = beacon->assocresp_ies_len; + new_beacon->assocresp_ies = pos; + memcpy(pos, beacon->assocresp_ies, beacon->assocresp_ies_len); + pos += beacon->assocresp_ies_len; + } + if (beacon->probe_resp_len) { + new_beacon->probe_resp_len = beacon->probe_resp_len; + new_beacon->probe_resp = pos; + memcpy(pos, beacon->probe_resp, beacon->probe_resp_len); + pos += beacon->probe_resp_len; + } + if (beacon->mbssid_ies && beacon->mbssid_ies->cnt) + pos += ieee80211_copy_mbssid_beacon(pos, + new_beacon->mbssid_ies, + beacon->mbssid_ies); + + /* might copy -1, meaning no changes requested */ + new_beacon->ftm_responder = beacon->ftm_responder; + if (beacon->lci) { + new_beacon->lci_len = beacon->lci_len; + new_beacon->lci = pos; + memcpy(pos, beacon->lci, beacon->lci_len); + pos += beacon->lci_len; + } + if (beacon->civicloc) { + new_beacon->civicloc_len = beacon->civicloc_len; + new_beacon->civicloc = pos; + memcpy(pos, beacon->civicloc, beacon->civicloc_len); + pos += beacon->civicloc_len; + } + + return new_beacon; +} + +void ieee80211_csa_finish(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + + rcu_read_lock(); + + if (vif->mbssid_tx_vif == vif) { + /* Trigger ieee80211_csa_finish() on the non-transmitting + * interfaces when channel switch is received on + * transmitting interface + */ + struct ieee80211_sub_if_data *iter; + + list_for_each_entry_rcu(iter, &local->interfaces, list) { + if (!ieee80211_sdata_running(iter)) + continue; + + if (iter == sdata || iter->vif.mbssid_tx_vif != vif) + 
continue; + + ieee80211_queue_work(&iter->local->hw, + &iter->deflink.csa_finalize_work); + } + } + ieee80211_queue_work(&local->hw, &sdata->deflink.csa_finalize_work); + + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_csa_finish); + +void ieee80211_channel_switch_disconnect(struct ieee80211_vif *vif, bool block_tx) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_local *local = sdata->local; + + sdata->deflink.csa_block_tx = block_tx; + sdata_info(sdata, "channel switch failed, disconnecting\n"); + ieee80211_queue_work(&local->hw, &ifmgd->csa_connection_drop_work); +} +EXPORT_SYMBOL(ieee80211_channel_switch_disconnect); + +static int ieee80211_set_after_csa_beacon(struct ieee80211_sub_if_data *sdata, + u32 *changed) +{ + int err; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + if (!sdata->deflink.u.ap.next_beacon) + return -EINVAL; + + err = ieee80211_assign_beacon(sdata, &sdata->deflink, + sdata->deflink.u.ap.next_beacon, + NULL, NULL); + ieee80211_free_next_beacon(&sdata->deflink); + + if (err < 0) + return err; + *changed |= err; + break; + case NL80211_IFTYPE_ADHOC: + err = ieee80211_ibss_finish_csa(sdata); + if (err < 0) + return err; + *changed |= err; + break; +#ifdef CONFIG_MAC80211_MESH + case NL80211_IFTYPE_MESH_POINT: + err = ieee80211_mesh_finish_csa(sdata); + if (err < 0) + return err; + *changed |= err; + break; +#endif + default: + WARN_ON(1); + return -EINVAL; + } + + return 0; +} + +static int __ieee80211_csa_finalize(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + u32 changed = 0; + int err; + + sdata_assert_lock(sdata); + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + /* + * using reservation isn't immediate as it may be deferred until later + * with multi-vif. 
once reservation is complete it will re-schedule the + * work with no reserved_chanctx so verify chandef to check if it + * completed successfully + */ + + if (sdata->deflink.reserved_chanctx) { + /* + * with multi-vif csa driver may call ieee80211_csa_finish() + * many times while waiting for other interfaces to use their + * reservations + */ + if (sdata->deflink.reserved_ready) + return 0; + + return ieee80211_link_use_reserved_context(&sdata->deflink); + } + + if (!cfg80211_chandef_identical(&sdata->vif.bss_conf.chandef, + &sdata->deflink.csa_chandef)) + return -EINVAL; + + sdata->vif.bss_conf.csa_active = false; + + err = ieee80211_set_after_csa_beacon(sdata, &changed); + if (err) + return err; + + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + + if (sdata->deflink.csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + sdata->deflink.csa_block_tx = false; + } + + err = drv_post_channel_switch(sdata); + if (err) + return err; + + cfg80211_ch_switch_notify(sdata->dev, &sdata->deflink.csa_chandef, 0); + + return 0; +} + +static void ieee80211_csa_finalize(struct ieee80211_sub_if_data *sdata) +{ + if (__ieee80211_csa_finalize(sdata)) { + sdata_info(sdata, "failed to finalize CSA, disconnecting\n"); + cfg80211_stop_iface(sdata->local->hw.wiphy, &sdata->wdev, + GFP_KERNEL); + } +} + +void ieee80211_csa_finalize_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + deflink.csa_finalize_work); + struct ieee80211_local *local = sdata->local; + + sdata_lock(sdata); + mutex_lock(&local->mtx); + mutex_lock(&local->chanctx_mtx); + + /* AP might have been stopped while waiting for the lock. */ + if (!sdata->vif.bss_conf.csa_active) + goto unlock; + + if (!ieee80211_sdata_running(sdata)) + goto unlock; + + ieee80211_csa_finalize(sdata); + +unlock: + mutex_unlock(&local->chanctx_mtx); + mutex_unlock(&local->mtx); + sdata_unlock(sdata); +} + +static int ieee80211_set_csa_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *params, + u32 *changed) +{ + struct ieee80211_csa_settings csa = {}; + int err; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + sdata->deflink.u.ap.next_beacon = + cfg80211_beacon_dup(¶ms->beacon_after); + if (!sdata->deflink.u.ap.next_beacon) + return -ENOMEM; + + /* + * With a count of 0, we don't have to wait for any + * TBTT before switching, so complete the CSA + * immediately. In theory, with a count == 1 we + * should delay the switch until just before the next + * TBTT, but that would complicate things so we switch + * immediately too. If we would delay the switch + * until the next TBTT, we would have to set the probe + * response here. + * + * TODO: A channel switch with count <= 1 without + * sending a CSA action frame is kind of useless, + * because the clients won't know we're changing + * channels. The action frame must be implemented + * either here or in the userspace. 
+ */ + if (params->count <= 1) + break; + + if ((params->n_counter_offsets_beacon > + IEEE80211_MAX_CNTDWN_COUNTERS_NUM) || + (params->n_counter_offsets_presp > + IEEE80211_MAX_CNTDWN_COUNTERS_NUM)) { + ieee80211_free_next_beacon(&sdata->deflink); + return -EINVAL; + } + + csa.counter_offsets_beacon = params->counter_offsets_beacon; + csa.counter_offsets_presp = params->counter_offsets_presp; + csa.n_counter_offsets_beacon = params->n_counter_offsets_beacon; + csa.n_counter_offsets_presp = params->n_counter_offsets_presp; + csa.count = params->count; + + err = ieee80211_assign_beacon(sdata, &sdata->deflink, + ¶ms->beacon_csa, &csa, + NULL); + if (err < 0) { + ieee80211_free_next_beacon(&sdata->deflink); + return err; + } + *changed |= err; + + break; + case NL80211_IFTYPE_ADHOC: + if (!sdata->vif.cfg.ibss_joined) + return -EINVAL; + + if (params->chandef.width != sdata->u.ibss.chandef.width) + return -EINVAL; + + switch (params->chandef.width) { + case NL80211_CHAN_WIDTH_40: + if (cfg80211_get_chandef_type(¶ms->chandef) != + cfg80211_get_chandef_type(&sdata->u.ibss.chandef)) + return -EINVAL; + break; + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + break; + default: + return -EINVAL; + } + + /* changes into another band are not supported */ + if (sdata->u.ibss.chandef.chan->band != + params->chandef.chan->band) + return -EINVAL; + + /* see comments in the NL80211_IFTYPE_AP block */ + if (params->count > 1) { + err = ieee80211_ibss_csa_beacon(sdata, params); + if (err < 0) + return err; + *changed |= err; + } + + ieee80211_send_action_csa(sdata, params); + + break; +#ifdef CONFIG_MAC80211_MESH + case NL80211_IFTYPE_MESH_POINT: { + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + /* changes into another band are not supported */ + if (sdata->vif.bss_conf.chandef.chan->band != + params->chandef.chan->band) + return -EINVAL; + + if (ifmsh->csa_role == IEEE80211_MESH_CSA_ROLE_NONE) { + ifmsh->csa_role = IEEE80211_MESH_CSA_ROLE_INIT; + if (!ifmsh->pre_value) + ifmsh->pre_value = 1; + else + ifmsh->pre_value++; + } + + /* see comments in the NL80211_IFTYPE_AP block */ + if (params->count > 1) { + err = ieee80211_mesh_csa_beacon(sdata, params); + if (err < 0) { + ifmsh->csa_role = IEEE80211_MESH_CSA_ROLE_NONE; + return err; + } + *changed |= err; + } + + if (ifmsh->csa_role == IEEE80211_MESH_CSA_ROLE_INIT) + ieee80211_send_action_csa(sdata, params); + + break; + } +#endif + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static void ieee80211_color_change_abort(struct ieee80211_sub_if_data *sdata) +{ + sdata->vif.bss_conf.color_change_active = false; + + ieee80211_free_next_beacon(&sdata->deflink); + + cfg80211_color_change_aborted_notify(sdata->dev); +} + +static int +__ieee80211_channel_switch(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_csa_settings *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct ieee80211_channel_switch ch_switch; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *chanctx; + u32 changed = 0; + int err; + + sdata_assert_lock(sdata); + lockdep_assert_held(&local->mtx); + + if (!list_empty(&local->roc_list) || local->scanning) + return -EBUSY; + + if (sdata->wdev.cac_started) + return -EBUSY; + + if (cfg80211_chandef_identical(¶ms->chandef, + &sdata->vif.bss_conf.chandef)) + return -EINVAL; + + /* don't allow another channel switch if one is already active. 
*/ + if (sdata->vif.bss_conf.csa_active) + return -EBUSY; + + mutex_lock(&local->chanctx_mtx); + conf = rcu_dereference_protected(sdata->vif.bss_conf.chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (!conf) { + err = -EBUSY; + goto out; + } + + if (params->chandef.chan->freq_offset) { + /* this may work, but is untested */ + err = -EOPNOTSUPP; + goto out; + } + + chanctx = container_of(conf, struct ieee80211_chanctx, conf); + + ch_switch.timestamp = 0; + ch_switch.device_timestamp = 0; + ch_switch.block_tx = params->block_tx; + ch_switch.chandef = params->chandef; + ch_switch.count = params->count; + + err = drv_pre_channel_switch(sdata, &ch_switch); + if (err) + goto out; + + err = ieee80211_link_reserve_chanctx(&sdata->deflink, ¶ms->chandef, + chanctx->mode, + params->radar_required); + if (err) + goto out; + + /* if reservation is invalid then this will fail */ + err = ieee80211_check_combinations(sdata, NULL, chanctx->mode, 0); + if (err) { + ieee80211_link_unreserve_chanctx(&sdata->deflink); + goto out; + } + + /* if there is a color change in progress, abort it */ + if (sdata->vif.bss_conf.color_change_active) + ieee80211_color_change_abort(sdata); + + err = ieee80211_set_csa_beacon(sdata, params, &changed); + if (err) { + ieee80211_link_unreserve_chanctx(&sdata->deflink); + goto out; + } + + sdata->deflink.csa_chandef = params->chandef; + sdata->deflink.csa_block_tx = params->block_tx; + sdata->vif.bss_conf.csa_active = true; + + if (sdata->deflink.csa_block_tx) + ieee80211_stop_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + + cfg80211_ch_switch_started_notify(sdata->dev, + &sdata->deflink.csa_chandef, 0, + params->count, params->block_tx); + + if (changed) { + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + changed); + drv_channel_switch_beacon(sdata, ¶ms->chandef); + } else { + /* if the beacon didn't change, we can finalize immediately */ + ieee80211_csa_finalize(sdata); + } + +out: + mutex_unlock(&local->chanctx_mtx); + return err; +} + +int ieee80211_channel_switch(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_csa_settings *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + int err; + + mutex_lock(&local->mtx); + err = __ieee80211_channel_switch(wiphy, dev, params); + mutex_unlock(&local->mtx); + + return err; +} + +u64 ieee80211_mgmt_tx_cookie(struct ieee80211_local *local) +{ + lockdep_assert_held(&local->mtx); + + local->roc_cookie_counter++; + + /* wow, you wrapped 64 bits ... 
more likely a bug */ + if (WARN_ON(local->roc_cookie_counter == 0)) + local->roc_cookie_counter++; + + return local->roc_cookie_counter; +} + +int ieee80211_attach_ack_skb(struct ieee80211_local *local, struct sk_buff *skb, + u64 *cookie, gfp_t gfp) +{ + unsigned long spin_flags; + struct sk_buff *ack_skb; + int id; + + ack_skb = skb_copy(skb, gfp); + if (!ack_skb) + return -ENOMEM; + + spin_lock_irqsave(&local->ack_status_lock, spin_flags); + id = idr_alloc(&local->ack_status_frames, ack_skb, + 1, 0x2000, GFP_ATOMIC); + spin_unlock_irqrestore(&local->ack_status_lock, spin_flags); + + if (id < 0) { + kfree_skb(ack_skb); + return -ENOMEM; + } + + IEEE80211_SKB_CB(skb)->ack_frame_id = id; + + *cookie = ieee80211_mgmt_tx_cookie(local); + IEEE80211_SKB_CB(ack_skb)->ack.cookie = *cookie; + + return 0; +} + +static void +ieee80211_update_mgmt_frame_registrations(struct wiphy *wiphy, + struct wireless_dev *wdev, + struct mgmt_frame_regs *upd) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + u32 preq_mask = BIT(IEEE80211_STYPE_PROBE_REQ >> 4); + u32 action_mask = BIT(IEEE80211_STYPE_ACTION >> 4); + bool global_change, intf_change; + + global_change = + (local->probe_req_reg != !!(upd->global_stypes & preq_mask)) || + (local->rx_mcast_action_reg != + !!(upd->global_mcast_stypes & action_mask)); + local->probe_req_reg = upd->global_stypes & preq_mask; + local->rx_mcast_action_reg = upd->global_mcast_stypes & action_mask; + + intf_change = (sdata->vif.probe_req_reg != + !!(upd->interface_stypes & preq_mask)) || + (sdata->vif.rx_mcast_action_reg != + !!(upd->interface_mcast_stypes & action_mask)); + sdata->vif.probe_req_reg = upd->interface_stypes & preq_mask; + sdata->vif.rx_mcast_action_reg = + upd->interface_mcast_stypes & action_mask; + + if (!local->open_count) + return; + + if (intf_change && ieee80211_sdata_running(sdata)) + drv_config_iface_filter(local, sdata, + sdata->vif.probe_req_reg ? 
+ FIF_PROBE_REQ : 0, + FIF_PROBE_REQ); + + if (global_change) + ieee80211_configure_filter(local); +} + +static int ieee80211_set_antenna(struct wiphy *wiphy, u32 tx_ant, u32 rx_ant) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + if (local->started) + return -EOPNOTSUPP; + + return drv_set_antenna(local, tx_ant, rx_ant); +} + +static int ieee80211_get_antenna(struct wiphy *wiphy, u32 *tx_ant, u32 *rx_ant) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + return drv_get_antenna(local, tx_ant, rx_ant); +} + +static int ieee80211_set_rekey_data(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_gtk_rekey_data *data) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + if (!local->ops->set_rekey_data) + return -EOPNOTSUPP; + + drv_set_rekey_data(local, sdata, data); + + return 0; +} + +static int ieee80211_probe_client(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u64 *cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct ieee80211_qos_hdr *nullfunc; + struct sk_buff *skb; + int size = sizeof(*nullfunc); + __le16 fc; + bool qos; + struct ieee80211_tx_info *info; + struct sta_info *sta; + struct ieee80211_chanctx_conf *chanctx_conf; + enum nl80211_band band; + int ret; + + /* the lock is needed to assign the cookie later */ + mutex_lock(&local->mtx); + + rcu_read_lock(); + sta = sta_info_get_bss(sdata, peer); + if (!sta) { + ret = -ENOLINK; + goto unlock; + } + + qos = sta->sta.wme; + + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + ret = -EINVAL; + goto unlock; + } + band = chanctx_conf->def.chan->band; + + if (qos) { + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | + IEEE80211_STYPE_QOS_NULLFUNC | + IEEE80211_FCTL_FROMDS); + } else { + size -= 2; + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | + IEEE80211_STYPE_NULLFUNC | + IEEE80211_FCTL_FROMDS); + } + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + size); + if (!skb) { + ret = -ENOMEM; + goto unlock; + } + + skb->dev = dev; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + nullfunc = skb_put(skb, size); + nullfunc->frame_control = fc; + nullfunc->duration_id = 0; + memcpy(nullfunc->addr1, sta->sta.addr, ETH_ALEN); + memcpy(nullfunc->addr2, sdata->vif.addr, ETH_ALEN); + memcpy(nullfunc->addr3, sdata->vif.addr, ETH_ALEN); + nullfunc->seq_ctrl = 0; + + info = IEEE80211_SKB_CB(skb); + + info->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_INTFL_NL80211_FRAME_TX; + info->band = band; + + skb_set_queue_mapping(skb, IEEE80211_AC_VO); + skb->priority = 7; + if (qos) + nullfunc->qos_ctrl = cpu_to_le16(7); + + ret = ieee80211_attach_ack_skb(local, skb, cookie, GFP_ATOMIC); + if (ret) { + kfree_skb(skb); + goto unlock; + } + + local_bh_disable(); + ieee80211_xmit(sdata, sta, skb); + local_bh_enable(); + + ret = 0; +unlock: + rcu_read_unlock(); + mutex_unlock(&local->mtx); + + return ret; +} + +static int ieee80211_cfg_get_channel(struct wiphy *wiphy, + struct wireless_dev *wdev, + unsigned int link_id, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_link_data *link; + int ret = -ENODATA; + + rcu_read_lock(); + link = rcu_dereference(sdata->link[link_id]); + if (!link) { + ret = -ENOLINK; + goto out; + } 
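+
+	/* If the link has a channel context assigned, report the link's own
+	 * chandef; otherwise fall back below to the monitor/operating chandef
+	 * when only monitor interfaces are running.
+	 */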
+ + chanctx_conf = rcu_dereference(link->conf->chanctx_conf); + if (chanctx_conf) { + *chandef = link->conf->chandef; + ret = 0; + } else if (local->open_count > 0 && + local->open_count == local->monitors && + sdata->vif.type == NL80211_IFTYPE_MONITOR) { + if (local->use_chanctx) + *chandef = local->monitor_chandef; + else + *chandef = local->_oper_chandef; + ret = 0; + } +out: + rcu_read_unlock(); + + return ret; +} + +#ifdef CONFIG_PM +static void ieee80211_set_wakeup(struct wiphy *wiphy, bool enabled) +{ + drv_set_wakeup(wiphy_priv(wiphy), enabled); +} +#endif + +static int ieee80211_set_qos_map(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_qos_map *qos_map) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct mac80211_qos_map *new_qos_map, *old_qos_map; + + if (qos_map) { + new_qos_map = kzalloc(sizeof(*new_qos_map), GFP_KERNEL); + if (!new_qos_map) + return -ENOMEM; + memcpy(&new_qos_map->qos_map, qos_map, sizeof(*qos_map)); + } else { + /* A NULL qos_map was passed to disable QoS mapping */ + new_qos_map = NULL; + } + + old_qos_map = sdata_dereference(sdata->qos_map, sdata); + rcu_assign_pointer(sdata->qos_map, new_qos_map); + if (old_qos_map) + kfree_rcu(old_qos_map, rcu_head); + + return 0; +} + +static int ieee80211_set_ap_chanwidth(struct wiphy *wiphy, + struct net_device *dev, + unsigned int link_id, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_link_data *link; + int ret; + u32 changed = 0; + + link = sdata_dereference(sdata->link[link_id], sdata); + + ret = ieee80211_link_change_bandwidth(link, chandef, &changed); + if (ret == 0) + ieee80211_link_info_change_notify(sdata, link, changed); + + return ret; +} + +static int ieee80211_add_tx_ts(struct wiphy *wiphy, struct net_device *dev, + u8 tsid, const u8 *peer, u8 up, + u16 admitted_time) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + int ac = ieee802_1d_to_ac[up]; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + if (!(sdata->wmm_acm & BIT(up))) + return -EINVAL; + + if (ifmgd->tx_tspec[ac].admitted_time) + return -EBUSY; + + if (admitted_time) { + ifmgd->tx_tspec[ac].admitted_time = 32 * admitted_time; + ifmgd->tx_tspec[ac].tsid = tsid; + ifmgd->tx_tspec[ac].up = up; + } + + return 0; +} + +static int ieee80211_del_tx_ts(struct wiphy *wiphy, struct net_device *dev, + u8 tsid, const u8 *peer) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_local *local = wiphy_priv(wiphy); + int ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + struct ieee80211_sta_tx_tspec *tx_tspec = &ifmgd->tx_tspec[ac]; + + /* skip unused entries */ + if (!tx_tspec->admitted_time) + continue; + + if (tx_tspec->tsid != tsid) + continue; + + /* due to this new packets will be reassigned to non-ACM ACs */ + tx_tspec->up = -1; + + /* Make sure that all packets have been sent to avoid to + * restore the QoS params on packets that are still on the + * queues. 
+ */ + synchronize_net(); + ieee80211_flush_queues(local, sdata, false); + + /* restore the normal QoS parameters + * (unconditionally to avoid races) + */ + tx_tspec->action = TX_TSPEC_ACTION_STOP_DOWNGRADE; + tx_tspec->downgraded = false; + ieee80211_sta_handle_tspec_ac_params(sdata); + + /* finally clear all the data */ + memset(tx_tspec, 0, sizeof(*tx_tspec)); + + return 0; + } + + return -ENOENT; +} + +void ieee80211_nan_func_terminated(struct ieee80211_vif *vif, + u8 inst_id, + enum nl80211_nan_func_term_reason reason, + gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct cfg80211_nan_func *func; + u64 cookie; + + if (WARN_ON(vif->type != NL80211_IFTYPE_NAN)) + return; + + spin_lock_bh(&sdata->u.nan.func_lock); + + func = idr_find(&sdata->u.nan.function_inst_ids, inst_id); + if (WARN_ON(!func)) { + spin_unlock_bh(&sdata->u.nan.func_lock); + return; + } + + cookie = func->cookie; + idr_remove(&sdata->u.nan.function_inst_ids, inst_id); + + spin_unlock_bh(&sdata->u.nan.func_lock); + + cfg80211_free_nan_func(func); + + cfg80211_nan_func_terminated(ieee80211_vif_to_wdev(vif), inst_id, + reason, cookie, gfp); +} +EXPORT_SYMBOL(ieee80211_nan_func_terminated); + +void ieee80211_nan_func_match(struct ieee80211_vif *vif, + struct cfg80211_nan_match_params *match, + gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct cfg80211_nan_func *func; + + if (WARN_ON(vif->type != NL80211_IFTYPE_NAN)) + return; + + spin_lock_bh(&sdata->u.nan.func_lock); + + func = idr_find(&sdata->u.nan.function_inst_ids, match->inst_id); + if (WARN_ON(!func)) { + spin_unlock_bh(&sdata->u.nan.func_lock); + return; + } + match->cookie = func->cookie; + + spin_unlock_bh(&sdata->u.nan.func_lock); + + cfg80211_nan_match(ieee80211_vif_to_wdev(vif), match, gfp); +} +EXPORT_SYMBOL(ieee80211_nan_func_match); + +static int ieee80211_set_multicast_to_unicast(struct wiphy *wiphy, + struct net_device *dev, + const bool enabled) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + sdata->u.ap.multicast_to_unicast = enabled; + + return 0; +} + +void ieee80211_fill_txq_stats(struct cfg80211_txq_stats *txqstats, + struct txq_info *txqi) +{ + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_BACKLOG_BYTES))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_BACKLOG_BYTES); + txqstats->backlog_bytes = txqi->tin.backlog_bytes; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_BACKLOG_PACKETS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_BACKLOG_PACKETS); + txqstats->backlog_packets = txqi->tin.backlog_packets; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_FLOWS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_FLOWS); + txqstats->flows = txqi->tin.flows; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_DROPS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_DROPS); + txqstats->drops = txqi->cstats.drop_count; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_ECN_MARKS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_ECN_MARKS); + txqstats->ecn_marks = txqi->cstats.ecn_mark; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_OVERLIMIT))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_OVERLIMIT); + txqstats->overlimit = txqi->tin.overlimit; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_COLLISIONS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_COLLISIONS); + txqstats->collisions = txqi->tin.collisions; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_TX_BYTES))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_TX_BYTES); + 
txqstats->tx_bytes = txqi->tin.tx_bytes; + } + + if (!(txqstats->filled & BIT(NL80211_TXQ_STATS_TX_PACKETS))) { + txqstats->filled |= BIT(NL80211_TXQ_STATS_TX_PACKETS); + txqstats->tx_packets = txqi->tin.tx_packets; + } +} + +static int ieee80211_get_txq_stats(struct wiphy *wiphy, + struct wireless_dev *wdev, + struct cfg80211_txq_stats *txqstats) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata; + int ret = 0; + + if (!local->ops->wake_tx_queue) + return 1; + + spin_lock_bh(&local->fq.lock); + rcu_read_lock(); + + if (wdev) { + sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + if (!sdata->vif.txq) { + ret = 1; + goto out; + } + ieee80211_fill_txq_stats(txqstats, to_txq_info(sdata->vif.txq)); + } else { + /* phy stats */ + txqstats->filled |= BIT(NL80211_TXQ_STATS_BACKLOG_PACKETS) | + BIT(NL80211_TXQ_STATS_BACKLOG_BYTES) | + BIT(NL80211_TXQ_STATS_OVERLIMIT) | + BIT(NL80211_TXQ_STATS_OVERMEMORY) | + BIT(NL80211_TXQ_STATS_COLLISIONS) | + BIT(NL80211_TXQ_STATS_MAX_FLOWS); + txqstats->backlog_packets = local->fq.backlog; + txqstats->backlog_bytes = local->fq.memory_usage; + txqstats->overlimit = local->fq.overlimit; + txqstats->overmemory = local->fq.overmemory; + txqstats->collisions = local->fq.collisions; + txqstats->max_flows = local->fq.flows_cnt; + } + +out: + rcu_read_unlock(); + spin_unlock_bh(&local->fq.lock); + + return ret; +} + +static int +ieee80211_get_ftm_responder_stats(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + return drv_get_ftm_responder_stats(local, sdata, ftm_stats); +} + +static int +ieee80211_start_pmsr(struct wiphy *wiphy, struct wireless_dev *dev, + struct cfg80211_pmsr_request *request) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(dev); + + return drv_start_pmsr(local, sdata, request); +} + +static void +ieee80211_abort_pmsr(struct wiphy *wiphy, struct wireless_dev *dev, + struct cfg80211_pmsr_request *request) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(dev); + + return drv_abort_pmsr(local, sdata, request); +} + +static int ieee80211_set_tid_config(struct wiphy *wiphy, + struct net_device *dev, + struct cfg80211_tid_config *tid_conf) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct sta_info *sta; + int ret; + + if (!sdata->local->ops->set_tid_config) + return -EOPNOTSUPP; + + if (!tid_conf->peer) + return drv_set_tid_config(sdata->local, sdata, NULL, tid_conf); + + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get_bss(sdata, tid_conf->peer); + if (!sta) { + mutex_unlock(&sdata->local->sta_mtx); + return -ENOENT; + } + + ret = drv_set_tid_config(sdata->local, sdata, &sta->sta, tid_conf); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +static int ieee80211_reset_tid_config(struct wiphy *wiphy, + struct net_device *dev, + const u8 *peer, u8 tids) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct sta_info *sta; + int ret; + + if (!sdata->local->ops->reset_tid_config) + return -EOPNOTSUPP; + + if (!peer) + return drv_reset_tid_config(sdata->local, sdata, NULL, tids); + + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get_bss(sdata, peer); + if (!sta) { + mutex_unlock(&sdata->local->sta_mtx); + return -ENOENT; + } 
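+
+	/* Station found: forward the per-TID reset to the driver while
+	 * sta_mtx is still held so the entry cannot go away underneath us.
+	 */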
+ + ret = drv_reset_tid_config(sdata->local, sdata, &sta->sta, tids); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +static int ieee80211_set_sar_specs(struct wiphy *wiphy, + struct cfg80211_sar_specs *sar) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + if (!local->ops->set_sar_specs) + return -EOPNOTSUPP; + + return local->ops->set_sar_specs(&local->hw, sar); +} + +static int +ieee80211_set_after_color_change_beacon(struct ieee80211_sub_if_data *sdata, + u32 *changed) +{ + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: { + int ret; + + if (!sdata->deflink.u.ap.next_beacon) + return -EINVAL; + + ret = ieee80211_assign_beacon(sdata, &sdata->deflink, + sdata->deflink.u.ap.next_beacon, + NULL, NULL); + ieee80211_free_next_beacon(&sdata->deflink); + + if (ret < 0) + return ret; + + *changed |= ret; + break; + } + default: + WARN_ON_ONCE(1); + return -EINVAL; + } + + return 0; +} + +static int +ieee80211_set_color_change_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_color_change_settings *params, + u32 *changed) +{ + struct ieee80211_color_change_settings color_change = {}; + int err; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + sdata->deflink.u.ap.next_beacon = + cfg80211_beacon_dup(¶ms->beacon_next); + if (!sdata->deflink.u.ap.next_beacon) + return -ENOMEM; + + if (params->count <= 1) + break; + + color_change.counter_offset_beacon = + params->counter_offset_beacon; + color_change.counter_offset_presp = + params->counter_offset_presp; + color_change.count = params->count; + + err = ieee80211_assign_beacon(sdata, &sdata->deflink, + ¶ms->beacon_color_change, + NULL, &color_change); + if (err < 0) { + ieee80211_free_next_beacon(&sdata->deflink); + return err; + } + *changed |= err; + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static void +ieee80211_color_change_bss_config_notify(struct ieee80211_sub_if_data *sdata, + u8 color, int enable, u32 changed) +{ + sdata->vif.bss_conf.he_bss_color.color = color; + sdata->vif.bss_conf.he_bss_color.enabled = enable; + changed |= BSS_CHANGED_HE_BSS_COLOR; + + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + + if (!sdata->vif.bss_conf.nontransmitted && sdata->vif.mbssid_tx_vif) { + struct ieee80211_sub_if_data *child; + + mutex_lock(&sdata->local->iflist_mtx); + list_for_each_entry(child, &sdata->local->interfaces, list) { + if (child != sdata && child->vif.mbssid_tx_vif == &sdata->vif) { + child->vif.bss_conf.he_bss_color.color = color; + child->vif.bss_conf.he_bss_color.enabled = enable; + ieee80211_link_info_change_notify(child, + &child->deflink, + BSS_CHANGED_HE_BSS_COLOR); + } + } + mutex_unlock(&sdata->local->iflist_mtx); + } +} + +static int ieee80211_color_change_finalize(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + u32 changed = 0; + int err; + + sdata_assert_lock(sdata); + lockdep_assert_held(&local->mtx); + + sdata->vif.bss_conf.color_change_active = false; + + err = ieee80211_set_after_color_change_beacon(sdata, &changed); + if (err) { + cfg80211_color_change_aborted_notify(sdata->dev); + return err; + } + + ieee80211_color_change_bss_config_notify(sdata, + sdata->vif.bss_conf.color_change_color, + 1, changed); + cfg80211_color_change_notify(sdata->dev); + + return 0; +} + +void ieee80211_color_change_finalize_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + deflink.color_change_finalize_work); + struct ieee80211_local 
*local = sdata->local; + + sdata_lock(sdata); + mutex_lock(&local->mtx); + + /* AP might have been stopped while waiting for the lock. */ + if (!sdata->vif.bss_conf.color_change_active) + goto unlock; + + if (!ieee80211_sdata_running(sdata)) + goto unlock; + + ieee80211_color_change_finalize(sdata); + +unlock: + mutex_unlock(&local->mtx); + sdata_unlock(sdata); +} + +void ieee80211_color_collision_detection_work(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct ieee80211_link_data *link = + container_of(delayed_work, struct ieee80211_link_data, + color_collision_detect_work); + struct ieee80211_sub_if_data *sdata = link->sdata; + + sdata_lock(sdata); + cfg80211_obss_color_collision_notify(sdata->dev, link->color_bitmap, + GFP_KERNEL); + sdata_unlock(sdata); +} + +void ieee80211_color_change_finish(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + ieee80211_queue_work(&sdata->local->hw, + &sdata->deflink.color_change_finalize_work); +} +EXPORT_SYMBOL_GPL(ieee80211_color_change_finish); + +void +ieeee80211_obss_color_collision_notify(struct ieee80211_vif *vif, + u64 color_bitmap, gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_link_data *link = &sdata->deflink; + + if (sdata->vif.bss_conf.color_change_active || sdata->vif.bss_conf.csa_active) + return; + + if (delayed_work_pending(&link->color_collision_detect_work)) + return; + + link->color_bitmap = color_bitmap; + /* queue the color collision detection event every 500 ms in order to + * avoid sending too much netlink messages to userspace. + */ + ieee80211_queue_delayed_work(&sdata->local->hw, + &link->color_collision_detect_work, + msecs_to_jiffies(500)); +} +EXPORT_SYMBOL_GPL(ieeee80211_obss_color_collision_notify); + +static int +ieee80211_color_change(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_color_change_settings *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + u32 changed = 0; + int err; + + sdata_assert_lock(sdata); + + if (sdata->vif.bss_conf.nontransmitted) + return -EINVAL; + + mutex_lock(&local->mtx); + + /* don't allow another color change if one is already active or if csa + * is active + */ + if (sdata->vif.bss_conf.color_change_active || sdata->vif.bss_conf.csa_active) { + err = -EBUSY; + goto out; + } + + err = ieee80211_set_color_change_beacon(sdata, params, &changed); + if (err) + goto out; + + sdata->vif.bss_conf.color_change_active = true; + sdata->vif.bss_conf.color_change_color = params->color; + + cfg80211_color_change_started_notify(sdata->dev, params->count); + + if (changed) + ieee80211_color_change_bss_config_notify(sdata, 0, 0, changed); + else + /* if the beacon didn't change, we can finalize immediately */ + ieee80211_color_change_finalize(sdata); + +out: + mutex_unlock(&local->mtx); + + return err; +} + +static int +ieee80211_set_radar_background(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + if (!local->ops->set_radar_background) + return -EOPNOTSUPP; + + return local->ops->set_radar_background(&local->hw, chandef); +} + +static int ieee80211_add_intf_link(struct wiphy *wiphy, + struct wireless_dev *wdev, + unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + int res; + + if (wdev->use_4addr) + return -EOPNOTSUPP; + + mutex_lock(&sdata->local->mtx); + res = 
ieee80211_vif_set_links(sdata, wdev->valid_links); + mutex_unlock(&sdata->local->mtx); + + return res; +} + +static void ieee80211_del_intf_link(struct wiphy *wiphy, + struct wireless_dev *wdev, + unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + mutex_lock(&sdata->local->mtx); + ieee80211_vif_set_links(sdata, wdev->valid_links); + mutex_unlock(&sdata->local->mtx); +} + +static int sta_add_link_station(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct link_station_parameters *params) +{ + struct sta_info *sta; + int ret; + + sta = sta_info_get_bss(sdata, params->mld_mac); + if (!sta) + return -ENOENT; + + if (!sta->sta.valid_links) + return -EINVAL; + + if (sta->sta.valid_links & BIT(params->link_id)) + return -EALREADY; + + ret = ieee80211_sta_allocate_link(sta, params->link_id); + if (ret) + return ret; + + ret = sta_link_apply_parameters(local, sta, true, params); + if (ret) { + ieee80211_sta_free_link(sta, params->link_id); + return ret; + } + + /* ieee80211_sta_activate_link frees the link upon failure */ + return ieee80211_sta_activate_link(sta, params->link_id); +} + +static int +ieee80211_add_link_station(struct wiphy *wiphy, struct net_device *dev, + struct link_station_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = wiphy_priv(wiphy); + int ret; + + mutex_lock(&sdata->local->sta_mtx); + ret = sta_add_link_station(local, sdata, params); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +static int sta_mod_link_station(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct link_station_parameters *params) +{ + struct sta_info *sta; + + sta = sta_info_get_bss(sdata, params->mld_mac); + if (!sta) + return -ENOENT; + + if (!(sta->sta.valid_links & BIT(params->link_id))) + return -EINVAL; + + return sta_link_apply_parameters(local, sta, false, params); +} + +static int +ieee80211_mod_link_station(struct wiphy *wiphy, struct net_device *dev, + struct link_station_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = wiphy_priv(wiphy); + int ret; + + mutex_lock(&sdata->local->sta_mtx); + ret = sta_mod_link_station(local, sdata, params); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +static int sta_del_link_station(struct ieee80211_sub_if_data *sdata, + struct link_station_del_parameters *params) +{ + struct sta_info *sta; + + sta = sta_info_get_bss(sdata, params->mld_mac); + if (!sta) + return -ENOENT; + + if (!(sta->sta.valid_links & BIT(params->link_id))) + return -EINVAL; + + /* must not create a STA without links */ + if (sta->sta.valid_links == BIT(params->link_id)) + return -EINVAL; + + ieee80211_sta_remove_link(sta, params->link_id); + + return 0; +} + +static int +ieee80211_del_link_station(struct wiphy *wiphy, struct net_device *dev, + struct link_station_del_parameters *params) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + int ret; + + mutex_lock(&sdata->local->sta_mtx); + ret = sta_del_link_station(sdata, params); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +const struct cfg80211_ops mac80211_config_ops = { + .add_virtual_intf = ieee80211_add_iface, + .del_virtual_intf = ieee80211_del_iface, + .change_virtual_intf = ieee80211_change_iface, + .start_p2p_device = ieee80211_start_p2p_device, + .stop_p2p_device = ieee80211_stop_p2p_device, + .add_key = 
ieee80211_add_key, + .del_key = ieee80211_del_key, + .get_key = ieee80211_get_key, + .set_default_key = ieee80211_config_default_key, + .set_default_mgmt_key = ieee80211_config_default_mgmt_key, + .set_default_beacon_key = ieee80211_config_default_beacon_key, + .start_ap = ieee80211_start_ap, + .change_beacon = ieee80211_change_beacon, + .stop_ap = ieee80211_stop_ap, + .add_station = ieee80211_add_station, + .del_station = ieee80211_del_station, + .change_station = ieee80211_change_station, + .get_station = ieee80211_get_station, + .dump_station = ieee80211_dump_station, + .dump_survey = ieee80211_dump_survey, +#ifdef CONFIG_MAC80211_MESH + .add_mpath = ieee80211_add_mpath, + .del_mpath = ieee80211_del_mpath, + .change_mpath = ieee80211_change_mpath, + .get_mpath = ieee80211_get_mpath, + .dump_mpath = ieee80211_dump_mpath, + .get_mpp = ieee80211_get_mpp, + .dump_mpp = ieee80211_dump_mpp, + .update_mesh_config = ieee80211_update_mesh_config, + .get_mesh_config = ieee80211_get_mesh_config, + .join_mesh = ieee80211_join_mesh, + .leave_mesh = ieee80211_leave_mesh, +#endif + .join_ocb = ieee80211_join_ocb, + .leave_ocb = ieee80211_leave_ocb, + .change_bss = ieee80211_change_bss, + .set_txq_params = ieee80211_set_txq_params, + .set_monitor_channel = ieee80211_set_monitor_channel, + .suspend = ieee80211_suspend, + .resume = ieee80211_resume, + .scan = ieee80211_scan, + .abort_scan = ieee80211_abort_scan, + .sched_scan_start = ieee80211_sched_scan_start, + .sched_scan_stop = ieee80211_sched_scan_stop, + .auth = ieee80211_auth, + .assoc = ieee80211_assoc, + .deauth = ieee80211_deauth, + .disassoc = ieee80211_disassoc, + .join_ibss = ieee80211_join_ibss, + .leave_ibss = ieee80211_leave_ibss, + .set_mcast_rate = ieee80211_set_mcast_rate, + .set_wiphy_params = ieee80211_set_wiphy_params, + .set_tx_power = ieee80211_set_tx_power, + .get_tx_power = ieee80211_get_tx_power, + .rfkill_poll = ieee80211_rfkill_poll, + CFG80211_TESTMODE_CMD(ieee80211_testmode_cmd) + CFG80211_TESTMODE_DUMP(ieee80211_testmode_dump) + .set_power_mgmt = ieee80211_set_power_mgmt, + .set_bitrate_mask = ieee80211_set_bitrate_mask, + .remain_on_channel = ieee80211_remain_on_channel, + .cancel_remain_on_channel = ieee80211_cancel_remain_on_channel, + .mgmt_tx = ieee80211_mgmt_tx, + .mgmt_tx_cancel_wait = ieee80211_mgmt_tx_cancel_wait, + .set_cqm_rssi_config = ieee80211_set_cqm_rssi_config, + .set_cqm_rssi_range_config = ieee80211_set_cqm_rssi_range_config, + .update_mgmt_frame_registrations = + ieee80211_update_mgmt_frame_registrations, + .set_antenna = ieee80211_set_antenna, + .get_antenna = ieee80211_get_antenna, + .set_rekey_data = ieee80211_set_rekey_data, + .tdls_oper = ieee80211_tdls_oper, + .tdls_mgmt = ieee80211_tdls_mgmt, + .tdls_channel_switch = ieee80211_tdls_channel_switch, + .tdls_cancel_channel_switch = ieee80211_tdls_cancel_channel_switch, + .probe_client = ieee80211_probe_client, + .set_noack_map = ieee80211_set_noack_map, +#ifdef CONFIG_PM + .set_wakeup = ieee80211_set_wakeup, +#endif + .get_channel = ieee80211_cfg_get_channel, + .start_radar_detection = ieee80211_start_radar_detection, + .end_cac = ieee80211_end_cac, + .channel_switch = ieee80211_channel_switch, + .set_qos_map = ieee80211_set_qos_map, + .set_ap_chanwidth = ieee80211_set_ap_chanwidth, + .add_tx_ts = ieee80211_add_tx_ts, + .del_tx_ts = ieee80211_del_tx_ts, + .start_nan = ieee80211_start_nan, + .stop_nan = ieee80211_stop_nan, + .nan_change_conf = ieee80211_nan_change_conf, + .add_nan_func = ieee80211_add_nan_func, + .del_nan_func = 
ieee80211_del_nan_func, + .set_multicast_to_unicast = ieee80211_set_multicast_to_unicast, + .tx_control_port = ieee80211_tx_control_port, + .get_txq_stats = ieee80211_get_txq_stats, + .get_ftm_responder_stats = ieee80211_get_ftm_responder_stats, + .start_pmsr = ieee80211_start_pmsr, + .abort_pmsr = ieee80211_abort_pmsr, + .probe_mesh_link = ieee80211_probe_mesh_link, + .set_tid_config = ieee80211_set_tid_config, + .reset_tid_config = ieee80211_reset_tid_config, + .set_sar_specs = ieee80211_set_sar_specs, + .color_change = ieee80211_color_change, + .set_radar_background = ieee80211_set_radar_background, + .add_intf_link = ieee80211_add_intf_link, + .del_intf_link = ieee80211_del_intf_link, + .add_link_station = ieee80211_add_link_station, + .mod_link_station = ieee80211_mod_link_station, + .del_link_station = ieee80211_del_link_station, +}; diff --git a/net/mac80211/chan.c b/net/mac80211/chan.c new file mode 100644 index 000000000..f07e34bed --- /dev/null +++ b/net/mac80211/chan.c @@ -0,0 +1,2064 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * mac80211 - channel management + * Copyright 2020 - 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" + +static int ieee80211_chanctx_num_assigned(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + struct ieee80211_link_data *link; + int num = 0; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(link, &ctx->assigned_links, assigned_chanctx_list) + num++; + + return num; +} + +static int ieee80211_chanctx_num_reserved(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + struct ieee80211_link_data *link; + int num = 0; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(link, &ctx->reserved_links, reserved_chanctx_list) + num++; + + return num; +} + +int ieee80211_chanctx_refcount(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + return ieee80211_chanctx_num_assigned(local, ctx) + + ieee80211_chanctx_num_reserved(local, ctx); +} + +static int ieee80211_num_chanctx(struct ieee80211_local *local) +{ + struct ieee80211_chanctx *ctx; + int num = 0; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(ctx, &local->chanctx_list, list) + num++; + + return num; +} + +static bool ieee80211_can_create_new_chanctx(struct ieee80211_local *local) +{ + lockdep_assert_held(&local->chanctx_mtx); + return ieee80211_num_chanctx(local) < ieee80211_max_num_channels(local); +} + +static struct ieee80211_chanctx * +ieee80211_link_get_chanctx(struct ieee80211_link_data *link) +{ + struct ieee80211_local *local __maybe_unused = link->sdata->local; + struct ieee80211_chanctx_conf *conf; + + conf = rcu_dereference_protected(link->conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (!conf) + return NULL; + + return container_of(conf, struct ieee80211_chanctx, conf); +} + +static const struct cfg80211_chan_def * +ieee80211_chanctx_reserved_chandef(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + const struct cfg80211_chan_def *compat) +{ + struct ieee80211_link_data *link; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(link, &ctx->reserved_links, + reserved_chanctx_list) { + if (!compat) + compat = &link->reserved_chandef; + + compat = cfg80211_chandef_compatible(&link->reserved_chandef, + compat); + if (!compat) + break; + } + + return compat; +} + +static const struct cfg80211_chan_def * 
+ieee80211_chanctx_non_reserved_chandef(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + const struct cfg80211_chan_def *compat) +{ + struct ieee80211_link_data *link; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(link, &ctx->assigned_links, + assigned_chanctx_list) { + struct ieee80211_bss_conf *link_conf = link->conf; + + if (link->reserved_chanctx) + continue; + + if (!compat) + compat = &link_conf->chandef; + + compat = cfg80211_chandef_compatible( + &link_conf->chandef, compat); + if (!compat) + break; + } + + return compat; +} + +static const struct cfg80211_chan_def * +ieee80211_chanctx_combined_chandef(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + const struct cfg80211_chan_def *compat) +{ + lockdep_assert_held(&local->chanctx_mtx); + + compat = ieee80211_chanctx_reserved_chandef(local, ctx, compat); + if (!compat) + return NULL; + + compat = ieee80211_chanctx_non_reserved_chandef(local, ctx, compat); + if (!compat) + return NULL; + + return compat; +} + +static bool +ieee80211_chanctx_can_reserve_chandef(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + const struct cfg80211_chan_def *def) +{ + lockdep_assert_held(&local->chanctx_mtx); + + if (ieee80211_chanctx_combined_chandef(local, ctx, def)) + return true; + + if (!list_empty(&ctx->reserved_links) && + ieee80211_chanctx_reserved_chandef(local, ctx, def)) + return true; + + return false; +} + +static struct ieee80211_chanctx * +ieee80211_find_reservation_chanctx(struct ieee80211_local *local, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode) +{ + struct ieee80211_chanctx *ctx; + + lockdep_assert_held(&local->chanctx_mtx); + + if (mode == IEEE80211_CHANCTX_EXCLUSIVE) + return NULL; + + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state == IEEE80211_CHANCTX_WILL_BE_REPLACED) + continue; + + if (ctx->mode == IEEE80211_CHANCTX_EXCLUSIVE) + continue; + + if (!ieee80211_chanctx_can_reserve_chandef(local, ctx, + chandef)) + continue; + + return ctx; + } + + return NULL; +} + +static enum nl80211_chan_width ieee80211_get_sta_bw(struct sta_info *sta, + unsigned int link_id) +{ + enum ieee80211_sta_rx_bandwidth width; + struct link_sta_info *link_sta; + + link_sta = rcu_dereference(sta->link[link_id]); + + /* no effect if this STA has no presence on this link */ + if (!link_sta) + return NL80211_CHAN_WIDTH_20_NOHT; + + width = ieee80211_sta_cap_rx_bw(link_sta); + + switch (width) { + case IEEE80211_STA_RX_BW_20: + if (link_sta->pub->ht_cap.ht_supported) + return NL80211_CHAN_WIDTH_20; + else + return NL80211_CHAN_WIDTH_20_NOHT; + case IEEE80211_STA_RX_BW_40: + return NL80211_CHAN_WIDTH_40; + case IEEE80211_STA_RX_BW_80: + return NL80211_CHAN_WIDTH_80; + case IEEE80211_STA_RX_BW_160: + /* + * This applied for both 160 and 80+80. since we use + * the returned value to consider degradation of + * ctx->conf.min_def, we have to make sure to take + * the bigger one (NL80211_CHAN_WIDTH_160). + * Otherwise we might try degrading even when not + * needed, as the max required sta_bw returned (80+80) + * might be smaller than the configured bw (160). 
+ */ + return NL80211_CHAN_WIDTH_160; + case IEEE80211_STA_RX_BW_320: + return NL80211_CHAN_WIDTH_320; + default: + WARN_ON(1); + return NL80211_CHAN_WIDTH_20; + } +} + +static enum nl80211_chan_width +ieee80211_get_max_required_bw(struct ieee80211_sub_if_data *sdata, + unsigned int link_id) +{ + enum nl80211_chan_width max_bw = NL80211_CHAN_WIDTH_20_NOHT; + struct sta_info *sta; + + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) { + if (sdata != sta->sdata && + !(sta->sdata->bss && sta->sdata->bss == sdata->bss)) + continue; + + max_bw = max(max_bw, ieee80211_get_sta_bw(sta, link_id)); + } + + return max_bw; +} + +static enum nl80211_chan_width +ieee80211_get_chanctx_vif_max_required_bw(struct ieee80211_sub_if_data *sdata, + struct ieee80211_chanctx *ctx, + struct ieee80211_link_data *rsvd_for) +{ + enum nl80211_chan_width max_bw = NL80211_CHAN_WIDTH_20_NOHT; + struct ieee80211_vif *vif = &sdata->vif; + int link_id; + + rcu_read_lock(); + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + enum nl80211_chan_width width = NL80211_CHAN_WIDTH_20_NOHT; + struct ieee80211_link_data *link = + rcu_dereference(sdata->link[link_id]); + + if (!link) + continue; + + if (link != rsvd_for && + rcu_access_pointer(link->conf->chanctx_conf) != &ctx->conf) + continue; + + switch (vif->type) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + width = ieee80211_get_max_required_bw(sdata, link_id); + break; + case NL80211_IFTYPE_STATION: + /* + * The ap's sta->bandwidth is not set yet at this + * point, so take the width from the chandef, but + * account also for TDLS peers + */ + width = max(link->conf->chandef.width, + ieee80211_get_max_required_bw(sdata, link_id)); + break; + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + continue; + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_OCB: + width = link->conf->chandef.width; + break; + case NL80211_IFTYPE_WDS: + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + WARN_ON_ONCE(1); + } + + max_bw = max(max_bw, width); + } + rcu_read_unlock(); + + return max_bw; +} + +static enum nl80211_chan_width +ieee80211_get_chanctx_max_required_bw(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + struct ieee80211_link_data *rsvd_for) +{ + struct ieee80211_sub_if_data *sdata; + enum nl80211_chan_width max_bw = NL80211_CHAN_WIDTH_20_NOHT; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + enum nl80211_chan_width width; + + if (!ieee80211_sdata_running(sdata)) + continue; + + width = ieee80211_get_chanctx_vif_max_required_bw(sdata, ctx, + rsvd_for); + + max_bw = max(max_bw, width); + } + + /* use the configured bandwidth in case of monitor interface */ + sdata = rcu_dereference(local->monitor_sdata); + if (sdata && + rcu_access_pointer(sdata->vif.bss_conf.chanctx_conf) == &ctx->conf) + max_bw = max(max_bw, ctx->conf.def.width); + + rcu_read_unlock(); + + return max_bw; +} + +/* + * recalc the min required chan width of the channel context, which is + * the max of min required widths of all the interfaces bound to this + * channel context. 
+ */ +static u32 +_ieee80211_recalc_chanctx_min_def(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + struct ieee80211_link_data *rsvd_for) +{ + enum nl80211_chan_width max_bw; + struct cfg80211_chan_def min_def; + + lockdep_assert_held(&local->chanctx_mtx); + + /* don't optimize non-20MHz based and radar_enabled confs */ + if (ctx->conf.def.width == NL80211_CHAN_WIDTH_5 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_10 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_1 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_2 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_4 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_8 || + ctx->conf.def.width == NL80211_CHAN_WIDTH_16 || + ctx->conf.radar_enabled) { + ctx->conf.min_def = ctx->conf.def; + return 0; + } + + max_bw = ieee80211_get_chanctx_max_required_bw(local, ctx, rsvd_for); + + /* downgrade chandef up to max_bw */ + min_def = ctx->conf.def; + while (min_def.width > max_bw) + ieee80211_chandef_downgrade(&min_def); + + if (cfg80211_chandef_identical(&ctx->conf.min_def, &min_def)) + return 0; + + ctx->conf.min_def = min_def; + if (!ctx->driver_present) + return 0; + + return IEEE80211_CHANCTX_CHANGE_MIN_WIDTH; +} + +/* calling this function is assuming that station vif is updated to + * lates changes by calling ieee80211_link_update_chandef + */ +static void ieee80211_chan_bw_change(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + bool narrowed) +{ + struct sta_info *sta; + struct ieee80211_supported_band *sband = + local->hw.wiphy->bands[ctx->conf.def.chan->band]; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, + list) { + struct ieee80211_sub_if_data *sdata = sta->sdata; + enum ieee80211_sta_rx_bandwidth new_sta_bw; + unsigned int link_id; + + if (!ieee80211_sdata_running(sta->sdata)) + continue; + + for (link_id = 0; link_id < ARRAY_SIZE(sta->sdata->link); link_id++) { + struct ieee80211_bss_conf *link_conf = + rcu_dereference(sdata->vif.link_conf[link_id]); + struct link_sta_info *link_sta; + + if (!link_conf) + continue; + + if (rcu_access_pointer(link_conf->chanctx_conf) != &ctx->conf) + continue; + + link_sta = rcu_dereference(sta->link[link_id]); + if (!link_sta) + continue; + + new_sta_bw = ieee80211_sta_cur_vht_bw(link_sta); + + /* nothing change */ + if (new_sta_bw == link_sta->pub->bandwidth) + continue; + + /* vif changed to narrow BW and narrow BW for station wasn't + * requested or vise versa */ + if ((new_sta_bw < link_sta->pub->bandwidth) == !narrowed) + continue; + + link_sta->pub->bandwidth = new_sta_bw; + rate_control_rate_update(local, sband, sta, link_id, + IEEE80211_RC_BW_CHANGED); + } + } + rcu_read_unlock(); +} + +/* + * recalc the min required chan width of the channel context, which is + * the max of min required widths of all the interfaces bound to this + * channel context. 
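+ * For example, an 80 MHz AP context whose only associated station is
+ * 20 MHz-capable can have its min_def downgraded to 20 MHz (and the
+ * driver notified) until a wider station shows up.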
+ */ +void ieee80211_recalc_chanctx_min_def(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + struct ieee80211_link_data *rsvd_for) +{ + u32 changed = _ieee80211_recalc_chanctx_min_def(local, ctx, rsvd_for); + + if (!changed) + return; + + /* check is BW narrowed */ + ieee80211_chan_bw_change(local, ctx, true); + + drv_change_chanctx(local, ctx, changed); + + /* check is BW wider */ + ieee80211_chan_bw_change(local, ctx, false); +} + +static void _ieee80211_change_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + struct ieee80211_chanctx *old_ctx, + const struct cfg80211_chan_def *chandef, + struct ieee80211_link_data *rsvd_for) +{ + u32 changed; + + /* expected to handle only 20/40/80/160/320 channel widths */ + switch (chandef->width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_40: + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_160: + case NL80211_CHAN_WIDTH_320: + break; + default: + WARN_ON(1); + } + + /* Check maybe BW narrowed - we do this _before_ calling recalc_chanctx_min_def + * due to maybe not returning from it, e.g in case new context was added + * first time with all parameters up to date. + */ + ieee80211_chan_bw_change(local, old_ctx, true); + + if (cfg80211_chandef_identical(&ctx->conf.def, chandef)) { + ieee80211_recalc_chanctx_min_def(local, ctx, rsvd_for); + return; + } + + WARN_ON(!cfg80211_chandef_compatible(&ctx->conf.def, chandef)); + + ctx->conf.def = *chandef; + + /* check if min chanctx also changed */ + changed = IEEE80211_CHANCTX_CHANGE_WIDTH | + _ieee80211_recalc_chanctx_min_def(local, ctx, rsvd_for); + drv_change_chanctx(local, ctx, changed); + + if (!local->use_chanctx) { + local->_oper_chandef = *chandef; + ieee80211_hw_config(local, 0); + } + + /* check is BW wider */ + ieee80211_chan_bw_change(local, old_ctx, false); +} + +static void ieee80211_change_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + struct ieee80211_chanctx *old_ctx, + const struct cfg80211_chan_def *chandef) +{ + _ieee80211_change_chanctx(local, ctx, old_ctx, chandef, NULL); +} + +static struct ieee80211_chanctx * +ieee80211_find_chanctx(struct ieee80211_local *local, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode) +{ + struct ieee80211_chanctx *ctx; + + lockdep_assert_held(&local->chanctx_mtx); + + if (mode == IEEE80211_CHANCTX_EXCLUSIVE) + return NULL; + + list_for_each_entry(ctx, &local->chanctx_list, list) { + const struct cfg80211_chan_def *compat; + + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACE_NONE) + continue; + + if (ctx->mode == IEEE80211_CHANCTX_EXCLUSIVE) + continue; + + compat = cfg80211_chandef_compatible(&ctx->conf.def, chandef); + if (!compat) + continue; + + compat = ieee80211_chanctx_reserved_chandef(local, ctx, + compat); + if (!compat) + continue; + + ieee80211_change_chanctx(local, ctx, ctx, compat); + + return ctx; + } + + return NULL; +} + +bool ieee80211_is_radar_required(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + lockdep_assert_held(&local->mtx); + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + unsigned int link_id; + + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_link_data *link; + + link = rcu_dereference(sdata->link[link_id]); + + if (link && link->radar_required) { + rcu_read_unlock(); + return true; + } + } + } + rcu_read_unlock(); + + return false; +} + 
+static bool +ieee80211_chanctx_radar_required(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + struct ieee80211_chanctx_conf *conf = &ctx->conf; + struct ieee80211_sub_if_data *sdata; + bool required = false; + + lockdep_assert_held(&local->chanctx_mtx); + lockdep_assert_held(&local->mtx); + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + unsigned int link_id; + + if (!ieee80211_sdata_running(sdata)) + continue; + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_link_data *link; + + link = rcu_dereference(sdata->link[link_id]); + if (!link) + continue; + + if (rcu_access_pointer(link->conf->chanctx_conf) != conf) + continue; + if (!link->radar_required) + continue; + required = true; + break; + } + + if (required) + break; + } + rcu_read_unlock(); + + return required; +} + +static struct ieee80211_chanctx * +ieee80211_alloc_chanctx(struct ieee80211_local *local, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode) +{ + struct ieee80211_chanctx *ctx; + + lockdep_assert_held(&local->chanctx_mtx); + + ctx = kzalloc(sizeof(*ctx) + local->hw.chanctx_data_size, GFP_KERNEL); + if (!ctx) + return NULL; + + INIT_LIST_HEAD(&ctx->assigned_links); + INIT_LIST_HEAD(&ctx->reserved_links); + ctx->conf.def = *chandef; + ctx->conf.rx_chains_static = 1; + ctx->conf.rx_chains_dynamic = 1; + ctx->mode = mode; + ctx->conf.radar_enabled = false; + _ieee80211_recalc_chanctx_min_def(local, ctx, NULL); + + return ctx; +} + +static int ieee80211_add_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + u32 changed; + int err; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + if (!local->use_chanctx) + local->hw.conf.radar_enabled = ctx->conf.radar_enabled; + + /* turn idle off *before* setting channel -- some drivers need that */ + changed = ieee80211_idle_off(local); + if (changed) + ieee80211_hw_config(local, changed); + + if (!local->use_chanctx) { + local->_oper_chandef = ctx->conf.def; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_CHANNEL); + } else { + err = drv_add_chanctx(local, ctx); + if (err) { + ieee80211_recalc_idle(local); + return err; + } + } + + return 0; +} + +static struct ieee80211_chanctx * +ieee80211_new_chanctx(struct ieee80211_local *local, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode) +{ + struct ieee80211_chanctx *ctx; + int err; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + ctx = ieee80211_alloc_chanctx(local, chandef, mode); + if (!ctx) + return ERR_PTR(-ENOMEM); + + err = ieee80211_add_chanctx(local, ctx); + if (err) { + kfree(ctx); + return ERR_PTR(err); + } + + list_add_rcu(&ctx->list, &local->chanctx_list); + return ctx; +} + +static void ieee80211_del_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + lockdep_assert_held(&local->chanctx_mtx); + + if (!local->use_chanctx) { + struct cfg80211_chan_def *chandef = &local->_oper_chandef; + /* S1G doesn't have 20MHz, so get the correct width for the + * current channel. + */ + if (chandef->chan->band == NL80211_BAND_S1GHZ) + chandef->width = + ieee80211_s1g_channel_width(chandef->chan); + else + chandef->width = NL80211_CHAN_WIDTH_20_NOHT; + chandef->center_freq1 = chandef->chan->center_freq; + chandef->freq1_offset = chandef->chan->freq_offset; + chandef->center_freq2 = 0; + + /* NOTE: Disabling radar is only valid here for + * single channel context. 
To be sure, check it ... + */ + WARN_ON(local->hw.conf.radar_enabled && + !list_empty(&local->chanctx_list)); + + local->hw.conf.radar_enabled = false; + + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_CHANNEL); + } else { + drv_remove_chanctx(local, ctx); + } + + ieee80211_recalc_idle(local); +} + +static void ieee80211_free_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + lockdep_assert_held(&local->chanctx_mtx); + + WARN_ON_ONCE(ieee80211_chanctx_refcount(local, ctx) != 0); + + list_del_rcu(&ctx->list); + ieee80211_del_chanctx(local, ctx); + kfree_rcu(ctx, rcu_head); +} + +void ieee80211_recalc_chanctx_chantype(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + struct ieee80211_chanctx_conf *conf = &ctx->conf; + struct ieee80211_sub_if_data *sdata; + const struct cfg80211_chan_def *compat = NULL; + struct sta_info *sta; + + lockdep_assert_held(&local->chanctx_mtx); + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + int link_id; + + if (!ieee80211_sdata_running(sdata)) + continue; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + continue; + + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_bss_conf *link_conf = + rcu_dereference(sdata->vif.link_conf[link_id]); + + if (!link_conf) + continue; + + if (rcu_access_pointer(link_conf->chanctx_conf) != conf) + continue; + + if (!compat) + compat = &link_conf->chandef; + + compat = cfg80211_chandef_compatible(&link_conf->chandef, + compat); + if (WARN_ON_ONCE(!compat)) + break; + } + } + + /* TDLS peers can sometimes affect the chandef width */ + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (!sta->uploaded || + !test_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW) || + !test_sta_flag(sta, WLAN_STA_AUTHORIZED) || + !sta->tdls_chandef.chan) + continue; + + compat = cfg80211_chandef_compatible(&sta->tdls_chandef, + compat); + if (WARN_ON_ONCE(!compat)) + break; + } + rcu_read_unlock(); + + if (!compat) + return; + + ieee80211_change_chanctx(local, ctx, ctx, compat); +} + +static void ieee80211_recalc_radar_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *chanctx) +{ + bool radar_enabled; + + lockdep_assert_held(&local->chanctx_mtx); + /* for ieee80211_is_radar_required */ + lockdep_assert_held(&local->mtx); + + radar_enabled = ieee80211_chanctx_radar_required(local, chanctx); + + if (radar_enabled == chanctx->conf.radar_enabled) + return; + + chanctx->conf.radar_enabled = radar_enabled; + + if (!local->use_chanctx) { + local->hw.conf.radar_enabled = chanctx->conf.radar_enabled; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_CHANNEL); + } + + drv_change_chanctx(local, chanctx, IEEE80211_CHANCTX_CHANGE_RADAR); +} + +static int ieee80211_assign_link_chanctx(struct ieee80211_link_data *link, + struct ieee80211_chanctx *new_ctx) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *curr_ctx = NULL; + int ret = 0; + + if (WARN_ON(sdata->vif.type == NL80211_IFTYPE_NAN)) + return -ENOTSUPP; + + conf = rcu_dereference_protected(link->conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + + if (conf) { + curr_ctx = container_of(conf, struct ieee80211_chanctx, conf); + + drv_unassign_vif_chanctx(local, sdata, link->conf, curr_ctx); + conf = NULL; + list_del(&link->assigned_chanctx_list); + } + + if (new_ctx) { + /* recalc considering the link we'll use it for now */ + 
ieee80211_recalc_chanctx_min_def(local, new_ctx, link); + + ret = drv_assign_vif_chanctx(local, sdata, link->conf, new_ctx); + if (ret) + goto out; + + conf = &new_ctx->conf; + list_add(&link->assigned_chanctx_list, + &new_ctx->assigned_links); + } + +out: + rcu_assign_pointer(link->conf->chanctx_conf, conf); + + sdata->vif.cfg.idle = !conf; + + if (curr_ctx && ieee80211_chanctx_num_assigned(local, curr_ctx) > 0) { + ieee80211_recalc_chanctx_chantype(local, curr_ctx); + ieee80211_recalc_smps_chanctx(local, curr_ctx); + ieee80211_recalc_radar_chanctx(local, curr_ctx); + ieee80211_recalc_chanctx_min_def(local, curr_ctx, NULL); + } + + if (new_ctx && ieee80211_chanctx_num_assigned(local, new_ctx) > 0) { + ieee80211_recalc_txpower(sdata, false); + ieee80211_recalc_chanctx_min_def(local, new_ctx, NULL); + } + + if (sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && + sdata->vif.type != NL80211_IFTYPE_MONITOR) + ieee80211_vif_cfg_change_notify(sdata, BSS_CHANGED_IDLE); + + ieee80211_check_fast_xmit_iface(sdata); + + return ret; +} + +void ieee80211_recalc_smps_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *chanctx) +{ + struct ieee80211_sub_if_data *sdata; + u8 rx_chains_static, rx_chains_dynamic; + + lockdep_assert_held(&local->chanctx_mtx); + + rx_chains_static = 1; + rx_chains_dynamic = 1; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + u8 needed_static, needed_dynamic; + unsigned int link_id; + + if (!ieee80211_sdata_running(sdata)) + continue; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + if (!sdata->u.mgd.associated) + continue; + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_OCB: + break; + default: + continue; + } + + for (link_id = 0; link_id < ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_link_data *link; + + link = rcu_dereference(sdata->link[link_id]); + + if (!link) + continue; + + if (rcu_access_pointer(link->conf->chanctx_conf) != &chanctx->conf) + continue; + + switch (link->smps_mode) { + default: + WARN_ONCE(1, "Invalid SMPS mode %d\n", + link->smps_mode); + fallthrough; + case IEEE80211_SMPS_OFF: + needed_static = link->needed_rx_chains; + needed_dynamic = link->needed_rx_chains; + break; + case IEEE80211_SMPS_DYNAMIC: + needed_static = 1; + needed_dynamic = link->needed_rx_chains; + break; + case IEEE80211_SMPS_STATIC: + needed_static = 1; + needed_dynamic = 1; + break; + } + + rx_chains_static = max(rx_chains_static, needed_static); + rx_chains_dynamic = max(rx_chains_dynamic, needed_dynamic); + } + } + + /* Disable SMPS for the monitor interface */ + sdata = rcu_dereference(local->monitor_sdata); + if (sdata && + rcu_access_pointer(sdata->vif.bss_conf.chanctx_conf) == &chanctx->conf) + rx_chains_dynamic = rx_chains_static = local->rx_chains; + + rcu_read_unlock(); + + if (!local->use_chanctx) { + if (rx_chains_static > 1) + local->smps_mode = IEEE80211_SMPS_OFF; + else if (rx_chains_dynamic > 1) + local->smps_mode = IEEE80211_SMPS_DYNAMIC; + else + local->smps_mode = IEEE80211_SMPS_STATIC; + ieee80211_hw_config(local, 0); + } + + if (rx_chains_static == chanctx->conf.rx_chains_static && + rx_chains_dynamic == chanctx->conf.rx_chains_dynamic) + return; + + chanctx->conf.rx_chains_static = rx_chains_static; + chanctx->conf.rx_chains_dynamic = rx_chains_dynamic; + drv_change_chanctx(local, chanctx, IEEE80211_CHANCTX_CHANGE_RX_CHAINS); +} + +static void +__ieee80211_link_copy_chanctx_to_vlans(struct ieee80211_link_data *link, 
+ bool clear) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + unsigned int link_id = link->link_id; + struct ieee80211_bss_conf *link_conf = link->conf; + struct ieee80211_local *local __maybe_unused = sdata->local; + struct ieee80211_sub_if_data *vlan; + struct ieee80211_chanctx_conf *conf; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_AP)) + return; + + lockdep_assert_held(&local->mtx); + + /* Check that conf exists, even when clearing this function + * must be called with the AP's channel context still there + * as it would otherwise cause VLANs to have an invalid + * channel context pointer for a while, possibly pointing + * to a channel context that has already been freed. + */ + conf = rcu_dereference_protected(link_conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + WARN_ON(!conf); + + if (clear) + conf = NULL; + + rcu_read_lock(); + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) { + struct ieee80211_bss_conf *vlan_conf; + + vlan_conf = rcu_dereference(vlan->vif.link_conf[link_id]); + if (WARN_ON(!vlan_conf)) + continue; + + rcu_assign_pointer(vlan_conf->chanctx_conf, conf); + } + rcu_read_unlock(); +} + +void ieee80211_link_copy_chanctx_to_vlans(struct ieee80211_link_data *link, + bool clear) +{ + struct ieee80211_local *local = link->sdata->local; + + mutex_lock(&local->chanctx_mtx); + + __ieee80211_link_copy_chanctx_to_vlans(link, clear); + + mutex_unlock(&local->chanctx_mtx); +} + +int ieee80211_link_unreserve_chanctx(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_chanctx *ctx = link->reserved_chanctx; + + lockdep_assert_held(&sdata->local->chanctx_mtx); + + if (WARN_ON(!ctx)) + return -EINVAL; + + list_del(&link->reserved_chanctx_list); + link->reserved_chanctx = NULL; + + if (ieee80211_chanctx_refcount(sdata->local, ctx) == 0) { + if (ctx->replace_state == IEEE80211_CHANCTX_REPLACES_OTHER) { + if (WARN_ON(!ctx->replace_ctx)) + return -EINVAL; + + WARN_ON(ctx->replace_ctx->replace_state != + IEEE80211_CHANCTX_WILL_BE_REPLACED); + WARN_ON(ctx->replace_ctx->replace_ctx != ctx); + + ctx->replace_ctx->replace_ctx = NULL; + ctx->replace_ctx->replace_state = + IEEE80211_CHANCTX_REPLACE_NONE; + + list_del_rcu(&ctx->list); + kfree_rcu(ctx, rcu_head); + } else { + ieee80211_free_chanctx(sdata->local, ctx); + } + } + + return 0; +} + +int ieee80211_link_reserve_chanctx(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode, + bool radar_required) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx *new_ctx, *curr_ctx, *ctx; + + lockdep_assert_held(&local->chanctx_mtx); + + curr_ctx = ieee80211_link_get_chanctx(link); + if (curr_ctx && local->use_chanctx && !local->ops->switch_vif_chanctx) + return -ENOTSUPP; + + new_ctx = ieee80211_find_reservation_chanctx(local, chandef, mode); + if (!new_ctx) { + if (ieee80211_can_create_new_chanctx(local)) { + new_ctx = ieee80211_new_chanctx(local, chandef, mode); + if (IS_ERR(new_ctx)) + return PTR_ERR(new_ctx); + } else { + if (!curr_ctx || + (curr_ctx->replace_state == + IEEE80211_CHANCTX_WILL_BE_REPLACED) || + !list_empty(&curr_ctx->reserved_links)) { + /* + * Another link already requested this context + * for a reservation. Find another one hoping + * all links assigned to it will also switch + * soon enough. 
+ * + * TODO: This needs a little more work as some + * cases (more than 2 chanctx capable devices) + * may fail which could otherwise succeed + * provided some channel context juggling was + * performed. + * + * Consider ctx1..3, link1..6, each ctx has 2 + * links. link1 and link2 from ctx1 request new + * different chandefs starting 2 in-place + * reserations with ctx4 and ctx5 replacing + * ctx1 and ctx2 respectively. Next link5 and + * link6 from ctx3 reserve ctx4. If link3 and + * link4 remain on ctx2 as they are then this + * fails unless `replace_ctx` from ctx5 is + * replaced with ctx3. + */ + list_for_each_entry(ctx, &local->chanctx_list, + list) { + if (ctx->replace_state != + IEEE80211_CHANCTX_REPLACE_NONE) + continue; + + if (!list_empty(&ctx->reserved_links)) + continue; + + curr_ctx = ctx; + break; + } + } + + /* + * If that's true then all available contexts already + * have reservations and cannot be used. + */ + if (!curr_ctx || + (curr_ctx->replace_state == + IEEE80211_CHANCTX_WILL_BE_REPLACED) || + !list_empty(&curr_ctx->reserved_links)) + return -EBUSY; + + new_ctx = ieee80211_alloc_chanctx(local, chandef, mode); + if (!new_ctx) + return -ENOMEM; + + new_ctx->replace_ctx = curr_ctx; + new_ctx->replace_state = + IEEE80211_CHANCTX_REPLACES_OTHER; + + curr_ctx->replace_ctx = new_ctx; + curr_ctx->replace_state = + IEEE80211_CHANCTX_WILL_BE_REPLACED; + + list_add_rcu(&new_ctx->list, &local->chanctx_list); + } + } + + list_add(&link->reserved_chanctx_list, &new_ctx->reserved_links); + link->reserved_chanctx = new_ctx; + link->reserved_chandef = *chandef; + link->reserved_radar_required = radar_required; + link->reserved_ready = false; + + return 0; +} + +static void +ieee80211_link_chanctx_reservation_complete(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_OCB: + ieee80211_queue_work(&sdata->local->hw, + &link->csa_finalize_work); + break; + case NL80211_IFTYPE_STATION: + ieee80211_queue_work(&sdata->local->hw, + &link->u.mgd.chswitch_work); + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_WDS: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + case NUM_NL80211_IFTYPES: + WARN_ON(1); + break; + } +} + +static void +ieee80211_link_update_chandef(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + unsigned int link_id = link->link_id; + struct ieee80211_sub_if_data *vlan; + + link->conf->chandef = *chandef; + + if (sdata->vif.type != NL80211_IFTYPE_AP) + return; + + rcu_read_lock(); + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) { + struct ieee80211_bss_conf *vlan_conf; + + vlan_conf = rcu_dereference(vlan->vif.link_conf[link_id]); + if (WARN_ON(!vlan_conf)) + continue; + + vlan_conf->chandef = *chandef; + } + rcu_read_unlock(); +} + +static int +ieee80211_link_use_reserved_reassign(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_bss_conf *link_conf = link->conf; + struct ieee80211_local *local = sdata->local; + struct ieee80211_vif_chanctx_switch vif_chsw[1] = {}; + struct ieee80211_chanctx *old_ctx, *new_ctx; + const struct cfg80211_chan_def *chandef; + u32 changed = 0; + int err; + + 
lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + new_ctx = link->reserved_chanctx; + old_ctx = ieee80211_link_get_chanctx(link); + + if (WARN_ON(!link->reserved_ready)) + return -EBUSY; + + if (WARN_ON(!new_ctx)) + return -EINVAL; + + if (WARN_ON(!old_ctx)) + return -EINVAL; + + if (WARN_ON(new_ctx->replace_state == + IEEE80211_CHANCTX_REPLACES_OTHER)) + return -EINVAL; + + chandef = ieee80211_chanctx_non_reserved_chandef(local, new_ctx, + &link->reserved_chandef); + if (WARN_ON(!chandef)) + return -EINVAL; + + if (link_conf->chandef.width != link->reserved_chandef.width) + changed = BSS_CHANGED_BANDWIDTH; + + ieee80211_link_update_chandef(link, &link->reserved_chandef); + + _ieee80211_change_chanctx(local, new_ctx, old_ctx, chandef, link); + + vif_chsw[0].vif = &sdata->vif; + vif_chsw[0].old_ctx = &old_ctx->conf; + vif_chsw[0].new_ctx = &new_ctx->conf; + vif_chsw[0].link_conf = link->conf; + + list_del(&link->reserved_chanctx_list); + link->reserved_chanctx = NULL; + + err = drv_switch_vif_chanctx(local, vif_chsw, 1, + CHANCTX_SWMODE_REASSIGN_VIF); + if (err) { + if (ieee80211_chanctx_refcount(local, new_ctx) == 0) + ieee80211_free_chanctx(local, new_ctx); + + goto out; + } + + list_move(&link->assigned_chanctx_list, &new_ctx->assigned_links); + rcu_assign_pointer(link_conf->chanctx_conf, &new_ctx->conf); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + __ieee80211_link_copy_chanctx_to_vlans(link, false); + + ieee80211_check_fast_xmit_iface(sdata); + + if (ieee80211_chanctx_refcount(local, old_ctx) == 0) + ieee80211_free_chanctx(local, old_ctx); + + ieee80211_recalc_chanctx_min_def(local, new_ctx, NULL); + ieee80211_recalc_smps_chanctx(local, new_ctx); + ieee80211_recalc_radar_chanctx(local, new_ctx); + + if (changed) + ieee80211_link_info_change_notify(sdata, link, changed); + +out: + ieee80211_link_chanctx_reservation_complete(link); + return err; +} + +static int +ieee80211_link_use_reserved_assign(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx *old_ctx, *new_ctx; + const struct cfg80211_chan_def *chandef; + int err; + + old_ctx = ieee80211_link_get_chanctx(link); + new_ctx = link->reserved_chanctx; + + if (WARN_ON(!link->reserved_ready)) + return -EINVAL; + + if (WARN_ON(old_ctx)) + return -EINVAL; + + if (WARN_ON(!new_ctx)) + return -EINVAL; + + if (WARN_ON(new_ctx->replace_state == + IEEE80211_CHANCTX_REPLACES_OTHER)) + return -EINVAL; + + chandef = ieee80211_chanctx_non_reserved_chandef(local, new_ctx, + &link->reserved_chandef); + if (WARN_ON(!chandef)) + return -EINVAL; + + ieee80211_change_chanctx(local, new_ctx, new_ctx, chandef); + + list_del(&link->reserved_chanctx_list); + link->reserved_chanctx = NULL; + + err = ieee80211_assign_link_chanctx(link, new_ctx); + if (err) { + if (ieee80211_chanctx_refcount(local, new_ctx) == 0) + ieee80211_free_chanctx(local, new_ctx); + + goto out; + } + +out: + ieee80211_link_chanctx_reservation_complete(link); + return err; +} + +static bool +ieee80211_link_has_in_place_reservation(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_chanctx *old_ctx, *new_ctx; + + lockdep_assert_held(&sdata->local->chanctx_mtx); + + new_ctx = link->reserved_chanctx; + old_ctx = ieee80211_link_get_chanctx(link); + + if (!old_ctx) + return false; + + if (WARN_ON(!new_ctx)) + return false; + + if (old_ctx->replace_state != IEEE80211_CHANCTX_WILL_BE_REPLACED) 
+ return false; + + if (new_ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + return false; + + return true; +} + +static int ieee80211_chsw_switch_hwconf(struct ieee80211_local *local, + struct ieee80211_chanctx *new_ctx) +{ + const struct cfg80211_chan_def *chandef; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + chandef = ieee80211_chanctx_reserved_chandef(local, new_ctx, NULL); + if (WARN_ON(!chandef)) + return -EINVAL; + + local->hw.conf.radar_enabled = new_ctx->conf.radar_enabled; + local->_oper_chandef = *chandef; + ieee80211_hw_config(local, 0); + + return 0; +} + +static int ieee80211_chsw_switch_vifs(struct ieee80211_local *local, + int n_vifs) +{ + struct ieee80211_vif_chanctx_switch *vif_chsw; + struct ieee80211_link_data *link; + struct ieee80211_chanctx *ctx, *old_ctx; + int i, err; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + vif_chsw = kcalloc(n_vifs, sizeof(vif_chsw[0]), GFP_KERNEL); + if (!vif_chsw) + return -ENOMEM; + + i = 0; + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + if (WARN_ON(!ctx->replace_ctx)) { + err = -EINVAL; + goto out; + } + + list_for_each_entry(link, &ctx->reserved_links, + reserved_chanctx_list) { + if (!ieee80211_link_has_in_place_reservation(link)) + continue; + + old_ctx = ieee80211_link_get_chanctx(link); + vif_chsw[i].vif = &link->sdata->vif; + vif_chsw[i].old_ctx = &old_ctx->conf; + vif_chsw[i].new_ctx = &ctx->conf; + vif_chsw[i].link_conf = link->conf; + + i++; + } + } + + err = drv_switch_vif_chanctx(local, vif_chsw, n_vifs, + CHANCTX_SWMODE_SWAP_CONTEXTS); + +out: + kfree(vif_chsw); + return err; +} + +static int ieee80211_chsw_switch_ctxs(struct ieee80211_local *local) +{ + struct ieee80211_chanctx *ctx; + int err; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + if (!list_empty(&ctx->replace_ctx->assigned_links)) + continue; + + ieee80211_del_chanctx(local, ctx->replace_ctx); + err = ieee80211_add_chanctx(local, ctx); + if (err) + goto err; + } + + return 0; + +err: + WARN_ON(ieee80211_add_chanctx(local, ctx)); + list_for_each_entry_continue_reverse(ctx, &local->chanctx_list, list) { + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + if (!list_empty(&ctx->replace_ctx->assigned_links)) + continue; + + ieee80211_del_chanctx(local, ctx); + WARN_ON(ieee80211_add_chanctx(local, ctx->replace_ctx)); + } + + return err; +} + +static int ieee80211_vif_use_reserved_switch(struct ieee80211_local *local) +{ + struct ieee80211_chanctx *ctx, *ctx_tmp, *old_ctx; + struct ieee80211_chanctx *new_ctx = NULL; + int err, n_assigned, n_reserved, n_ready; + int n_ctx = 0, n_vifs_switch = 0, n_vifs_assign = 0, n_vifs_ctxless = 0; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + /* + * If there are 2 independent pairs of channel contexts performing + * cross-switch of their vifs this code will still wait until both are + * ready even though it could be possible to switch one before the + * other is ready. + * + * For practical reasons and code simplicity just do a single huge + * switch. + */ + + /* + * Verify if the reservation is still feasible. 
+ * - if it's not then disconnect + * - if it is but not all vifs necessary are ready then defer + */ + + list_for_each_entry(ctx, &local->chanctx_list, list) { + struct ieee80211_link_data *link; + + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + if (WARN_ON(!ctx->replace_ctx)) { + err = -EINVAL; + goto err; + } + + if (!local->use_chanctx) + new_ctx = ctx; + + n_ctx++; + + n_assigned = 0; + n_reserved = 0; + n_ready = 0; + + list_for_each_entry(link, &ctx->replace_ctx->assigned_links, + assigned_chanctx_list) { + n_assigned++; + if (link->reserved_chanctx) { + n_reserved++; + if (link->reserved_ready) + n_ready++; + } + } + + if (n_assigned != n_reserved) { + if (n_ready == n_reserved) { + wiphy_info(local->hw.wiphy, + "channel context reservation cannot be finalized because some interfaces aren't switching\n"); + err = -EBUSY; + goto err; + } + + return -EAGAIN; + } + + ctx->conf.radar_enabled = false; + list_for_each_entry(link, &ctx->reserved_links, + reserved_chanctx_list) { + if (ieee80211_link_has_in_place_reservation(link) && + !link->reserved_ready) + return -EAGAIN; + + old_ctx = ieee80211_link_get_chanctx(link); + if (old_ctx) { + if (old_ctx->replace_state == + IEEE80211_CHANCTX_WILL_BE_REPLACED) + n_vifs_switch++; + else + n_vifs_assign++; + } else { + n_vifs_ctxless++; + } + + if (link->reserved_radar_required) + ctx->conf.radar_enabled = true; + } + } + + if (WARN_ON(n_ctx == 0) || + WARN_ON(n_vifs_switch == 0 && + n_vifs_assign == 0 && + n_vifs_ctxless == 0) || + WARN_ON(n_ctx > 1 && !local->use_chanctx) || + WARN_ON(!new_ctx && !local->use_chanctx)) { + err = -EINVAL; + goto err; + } + + /* + * All necessary vifs are ready. Perform the switch now depending on + * reservations and driver capabilities. + */ + + if (local->use_chanctx) { + if (n_vifs_switch > 0) { + err = ieee80211_chsw_switch_vifs(local, n_vifs_switch); + if (err) + goto err; + } + + if (n_vifs_assign > 0 || n_vifs_ctxless > 0) { + err = ieee80211_chsw_switch_ctxs(local); + if (err) + goto err; + } + } else { + err = ieee80211_chsw_switch_hwconf(local, new_ctx); + if (err) + goto err; + } + + /* + * Update all structures, values and pointers to point to new channel + * context(s). 
+ */ + list_for_each_entry(ctx, &local->chanctx_list, list) { + struct ieee80211_link_data *link, *link_tmp; + + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + if (WARN_ON(!ctx->replace_ctx)) { + err = -EINVAL; + goto err; + } + + list_for_each_entry(link, &ctx->reserved_links, + reserved_chanctx_list) { + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_bss_conf *link_conf = link->conf; + u32 changed = 0; + + if (!ieee80211_link_has_in_place_reservation(link)) + continue; + + rcu_assign_pointer(link_conf->chanctx_conf, + &ctx->conf); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + __ieee80211_link_copy_chanctx_to_vlans(link, + false); + + ieee80211_check_fast_xmit_iface(sdata); + + link->radar_required = link->reserved_radar_required; + + if (link_conf->chandef.width != link->reserved_chandef.width) + changed = BSS_CHANGED_BANDWIDTH; + + ieee80211_link_update_chandef(link, &link->reserved_chandef); + if (changed) + ieee80211_link_info_change_notify(sdata, + link, + changed); + + ieee80211_recalc_txpower(sdata, false); + } + + ieee80211_recalc_chanctx_chantype(local, ctx); + ieee80211_recalc_smps_chanctx(local, ctx); + ieee80211_recalc_radar_chanctx(local, ctx); + ieee80211_recalc_chanctx_min_def(local, ctx, NULL); + + list_for_each_entry_safe(link, link_tmp, &ctx->reserved_links, + reserved_chanctx_list) { + if (ieee80211_link_get_chanctx(link) != ctx) + continue; + + list_del(&link->reserved_chanctx_list); + list_move(&link->assigned_chanctx_list, + &ctx->assigned_links); + link->reserved_chanctx = NULL; + + ieee80211_link_chanctx_reservation_complete(link); + } + + /* + * This context might have been a dependency for an already + * ready re-assign reservation interface that was deferred. Do + * not propagate error to the caller though. The in-place + * reservation for originally requested interface has already + * succeeded at this point. 
+ */ + list_for_each_entry_safe(link, link_tmp, &ctx->reserved_links, + reserved_chanctx_list) { + if (WARN_ON(ieee80211_link_has_in_place_reservation(link))) + continue; + + if (WARN_ON(link->reserved_chanctx != ctx)) + continue; + + if (!link->reserved_ready) + continue; + + if (ieee80211_link_get_chanctx(link)) + err = ieee80211_link_use_reserved_reassign(link); + else + err = ieee80211_link_use_reserved_assign(link); + + if (err) { + link_info(link, + "failed to finalize (re-)assign reservation (err=%d)\n", + err); + ieee80211_link_unreserve_chanctx(link); + cfg80211_stop_iface(local->hw.wiphy, + &link->sdata->wdev, + GFP_KERNEL); + } + } + } + + /* + * Finally free old contexts + */ + + list_for_each_entry_safe(ctx, ctx_tmp, &local->chanctx_list, list) { + if (ctx->replace_state != IEEE80211_CHANCTX_WILL_BE_REPLACED) + continue; + + ctx->replace_ctx->replace_ctx = NULL; + ctx->replace_ctx->replace_state = + IEEE80211_CHANCTX_REPLACE_NONE; + + list_del_rcu(&ctx->list); + kfree_rcu(ctx, rcu_head); + } + + return 0; + +err: + list_for_each_entry(ctx, &local->chanctx_list, list) { + struct ieee80211_link_data *link, *link_tmp; + + if (ctx->replace_state != IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + list_for_each_entry_safe(link, link_tmp, &ctx->reserved_links, + reserved_chanctx_list) { + ieee80211_link_unreserve_chanctx(link); + ieee80211_link_chanctx_reservation_complete(link); + } + } + + return err; +} + +static void __ieee80211_link_release_channel(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_bss_conf *link_conf = link->conf; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *ctx; + bool use_reserved_switch = false; + + lockdep_assert_held(&local->chanctx_mtx); + + conf = rcu_dereference_protected(link_conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (!conf) + return; + + ctx = container_of(conf, struct ieee80211_chanctx, conf); + + if (link->reserved_chanctx) { + if (link->reserved_chanctx->replace_state == IEEE80211_CHANCTX_REPLACES_OTHER && + ieee80211_chanctx_num_reserved(local, link->reserved_chanctx) > 1) + use_reserved_switch = true; + + ieee80211_link_unreserve_chanctx(link); + } + + ieee80211_assign_link_chanctx(link, NULL); + if (ieee80211_chanctx_refcount(local, ctx) == 0) + ieee80211_free_chanctx(local, ctx); + + link->radar_required = false; + + /* Unreserving may ready an in-place reservation. 
*/ + if (use_reserved_switch) + ieee80211_vif_use_reserved_switch(local); +} + +int ieee80211_link_use_channel(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx *ctx; + u8 radar_detect_width = 0; + int ret; + + lockdep_assert_held(&local->mtx); + + if (sdata->vif.active_links && + !(sdata->vif.active_links & BIT(link->link_id))) { + ieee80211_link_update_chandef(link, chandef); + return 0; + } + + mutex_lock(&local->chanctx_mtx); + + ret = cfg80211_chandef_dfs_required(local->hw.wiphy, + chandef, + sdata->wdev.iftype); + if (ret < 0) + goto out; + if (ret > 0) + radar_detect_width = BIT(chandef->width); + + link->radar_required = ret; + + ret = ieee80211_check_combinations(sdata, chandef, mode, + radar_detect_width); + if (ret < 0) + goto out; + + __ieee80211_link_release_channel(link); + + ctx = ieee80211_find_chanctx(local, chandef, mode); + if (!ctx) + ctx = ieee80211_new_chanctx(local, chandef, mode); + if (IS_ERR(ctx)) { + ret = PTR_ERR(ctx); + goto out; + } + + ieee80211_link_update_chandef(link, chandef); + + ret = ieee80211_assign_link_chanctx(link, ctx); + if (ret) { + /* if assign fails refcount stays the same */ + if (ieee80211_chanctx_refcount(local, ctx) == 0) + ieee80211_free_chanctx(local, ctx); + goto out; + } + + ieee80211_recalc_smps_chanctx(local, ctx); + ieee80211_recalc_radar_chanctx(local, ctx); + out: + if (ret) + link->radar_required = false; + + mutex_unlock(&local->chanctx_mtx); + return ret; +} + +int ieee80211_link_use_reserved_context(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx *new_ctx; + struct ieee80211_chanctx *old_ctx; + int err; + + lockdep_assert_held(&local->mtx); + lockdep_assert_held(&local->chanctx_mtx); + + new_ctx = link->reserved_chanctx; + old_ctx = ieee80211_link_get_chanctx(link); + + if (WARN_ON(!new_ctx)) + return -EINVAL; + + if (WARN_ON(new_ctx->replace_state == + IEEE80211_CHANCTX_WILL_BE_REPLACED)) + return -EINVAL; + + if (WARN_ON(link->reserved_ready)) + return -EINVAL; + + link->reserved_ready = true; + + if (new_ctx->replace_state == IEEE80211_CHANCTX_REPLACE_NONE) { + if (old_ctx) + return ieee80211_link_use_reserved_reassign(link); + + return ieee80211_link_use_reserved_assign(link); + } + + /* + * In-place reservation may need to be finalized now either if: + * a) sdata is taking part in the swapping itself and is the last one + * b) sdata has switched with a re-assign reservation to an existing + * context readying in-place switching of old_ctx + * + * In case of (b) do not propagate the error up because the requested + * sdata already switched successfully. Just spill an extra warning. + * The ieee80211_vif_use_reserved_switch() already stops all necessary + * interfaces upon failure. 
+ */ + if ((old_ctx && + old_ctx->replace_state == IEEE80211_CHANCTX_WILL_BE_REPLACED) || + new_ctx->replace_state == IEEE80211_CHANCTX_REPLACES_OTHER) { + err = ieee80211_vif_use_reserved_switch(local); + if (err && err != -EAGAIN) { + if (new_ctx->replace_state == + IEEE80211_CHANCTX_REPLACES_OTHER) + return err; + + wiphy_info(local->hw.wiphy, + "depending in-place reservation failed (err=%d)\n", + err); + } + } + + return 0; +} + +int ieee80211_link_change_bandwidth(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + u32 *changed) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_bss_conf *link_conf = link->conf; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *ctx; + const struct cfg80211_chan_def *compat; + int ret; + + if (!cfg80211_chandef_usable(sdata->local->hw.wiphy, chandef, + IEEE80211_CHAN_DISABLED)) + return -EINVAL; + + mutex_lock(&local->chanctx_mtx); + if (cfg80211_chandef_identical(chandef, &link_conf->chandef)) { + ret = 0; + goto out; + } + + if (chandef->width == NL80211_CHAN_WIDTH_20_NOHT || + link_conf->chandef.width == NL80211_CHAN_WIDTH_20_NOHT) { + ret = -EINVAL; + goto out; + } + + conf = rcu_dereference_protected(link_conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (!conf) { + ret = -EINVAL; + goto out; + } + + ctx = container_of(conf, struct ieee80211_chanctx, conf); + + compat = cfg80211_chandef_compatible(&conf->def, chandef); + if (!compat) { + ret = -EINVAL; + goto out; + } + + switch (ctx->replace_state) { + case IEEE80211_CHANCTX_REPLACE_NONE: + if (!ieee80211_chanctx_reserved_chandef(local, ctx, compat)) { + ret = -EBUSY; + goto out; + } + break; + case IEEE80211_CHANCTX_WILL_BE_REPLACED: + /* TODO: Perhaps the bandwidth change could be treated as a + * reservation itself? 
*/ + ret = -EBUSY; + goto out; + case IEEE80211_CHANCTX_REPLACES_OTHER: + /* channel context that is going to replace another channel + * context doesn't really exist and shouldn't be assigned + * anywhere yet */ + WARN_ON(1); + break; + } + + ieee80211_link_update_chandef(link, chandef); + + ieee80211_recalc_chanctx_chantype(local, ctx); + + *changed |= BSS_CHANGED_BANDWIDTH; + ret = 0; + out: + mutex_unlock(&local->chanctx_mtx); + return ret; +} + +void ieee80211_link_release_channel(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + + mutex_lock(&sdata->local->chanctx_mtx); + if (rcu_access_pointer(link->conf->chanctx_conf)) { + lockdep_assert_held(&sdata->local->mtx); + __ieee80211_link_release_channel(link); + } + mutex_unlock(&sdata->local->chanctx_mtx); +} + +void ieee80211_link_vlan_copy_chanctx(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + unsigned int link_id = link->link_id; + struct ieee80211_bss_conf *link_conf = link->conf; + struct ieee80211_bss_conf *ap_conf; + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *ap; + struct ieee80211_chanctx_conf *conf; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_AP_VLAN || !sdata->bss)) + return; + + ap = container_of(sdata->bss, struct ieee80211_sub_if_data, u.ap); + + mutex_lock(&local->chanctx_mtx); + + rcu_read_lock(); + ap_conf = rcu_dereference(ap->vif.link_conf[link_id]); + conf = rcu_dereference_protected(ap_conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + rcu_assign_pointer(link_conf->chanctx_conf, conf); + rcu_read_unlock(); + mutex_unlock(&local->chanctx_mtx); +} + +void ieee80211_iter_chan_contexts_atomic( + struct ieee80211_hw *hw, + void (*iter)(struct ieee80211_hw *hw, + struct ieee80211_chanctx_conf *chanctx_conf, + void *data), + void *iter_data) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_chanctx *ctx; + + rcu_read_lock(); + list_for_each_entry_rcu(ctx, &local->chanctx_list, list) + if (ctx->driver_present) + iter(hw, &ctx->conf, iter_data); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(ieee80211_iter_chan_contexts_atomic); diff --git a/net/mac80211/debug.h b/net/mac80211/debug.h new file mode 100644 index 000000000..b4c20f5e7 --- /dev/null +++ b/net/mac80211/debug.h @@ -0,0 +1,234 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Portions + * Copyright (C) 2022 Intel Corporation + */ +#ifndef __MAC80211_DEBUG_H +#define __MAC80211_DEBUG_H +#include + +#ifdef CONFIG_MAC80211_OCB_DEBUG +#define MAC80211_OCB_DEBUG 1 +#else +#define MAC80211_OCB_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_IBSS_DEBUG +#define MAC80211_IBSS_DEBUG 1 +#else +#define MAC80211_IBSS_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_PS_DEBUG +#define MAC80211_PS_DEBUG 1 +#else +#define MAC80211_PS_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_HT_DEBUG +#define MAC80211_HT_DEBUG 1 +#else +#define MAC80211_HT_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MPL_DEBUG +#define MAC80211_MPL_DEBUG 1 +#else +#define MAC80211_MPL_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MPATH_DEBUG +#define MAC80211_MPATH_DEBUG 1 +#else +#define MAC80211_MPATH_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MHWMP_DEBUG +#define MAC80211_MHWMP_DEBUG 1 +#else +#define MAC80211_MHWMP_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MESH_SYNC_DEBUG +#define MAC80211_MESH_SYNC_DEBUG 1 +#else +#define MAC80211_MESH_SYNC_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MESH_CSA_DEBUG +#define MAC80211_MESH_CSA_DEBUG 1 +#else +#define MAC80211_MESH_CSA_DEBUG 0 
+#endif + +#ifdef CONFIG_MAC80211_MESH_PS_DEBUG +#define MAC80211_MESH_PS_DEBUG 1 +#else +#define MAC80211_MESH_PS_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_TDLS_DEBUG +#define MAC80211_TDLS_DEBUG 1 +#else +#define MAC80211_TDLS_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_STA_DEBUG +#define MAC80211_STA_DEBUG 1 +#else +#define MAC80211_STA_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MLME_DEBUG +#define MAC80211_MLME_DEBUG 1 +#else +#define MAC80211_MLME_DEBUG 0 +#endif + +#ifdef CONFIG_MAC80211_MESSAGE_TRACING +void __sdata_info(const char *fmt, ...) __printf(1, 2); +void __sdata_dbg(bool print, const char *fmt, ...) __printf(2, 3); +void __sdata_err(const char *fmt, ...) __printf(1, 2); +void __wiphy_dbg(struct wiphy *wiphy, bool print, const char *fmt, ...) + __printf(3, 4); + +#define _sdata_info(sdata, fmt, ...) \ + __sdata_info("%s: " fmt, (sdata)->name, ##__VA_ARGS__) +#define _sdata_dbg(print, sdata, fmt, ...) \ + __sdata_dbg(print, "%s: " fmt, (sdata)->name, ##__VA_ARGS__) +#define _sdata_err(sdata, fmt, ...) \ + __sdata_err("%s: " fmt, (sdata)->name, ##__VA_ARGS__) +#define _wiphy_dbg(print, wiphy, fmt, ...) \ + __wiphy_dbg(wiphy, print, fmt, ##__VA_ARGS__) +#else +#define _sdata_info(sdata, fmt, ...) \ +do { \ + pr_info("%s: " fmt, \ + (sdata)->name, ##__VA_ARGS__); \ +} while (0) + +#define _sdata_dbg(print, sdata, fmt, ...) \ +do { \ + if (print) \ + pr_debug("%s: " fmt, \ + (sdata)->name, ##__VA_ARGS__); \ +} while (0) + +#define _sdata_err(sdata, fmt, ...) \ +do { \ + pr_err("%s: " fmt, \ + (sdata)->name, ##__VA_ARGS__); \ +} while (0) + +#define _wiphy_dbg(print, wiphy, fmt, ...) \ +do { \ + if (print) \ + wiphy_dbg((wiphy), fmt, ##__VA_ARGS__); \ +} while (0) +#endif + +#define sdata_info(sdata, fmt, ...) \ + _sdata_info(sdata, fmt, ##__VA_ARGS__) +#define sdata_err(sdata, fmt, ...) \ + _sdata_err(sdata, fmt, ##__VA_ARGS__) +#define sdata_dbg(sdata, fmt, ...) \ + _sdata_dbg(1, sdata, fmt, ##__VA_ARGS__) + +#define link_info(link, fmt, ...) \ + do { \ + if ((link)->sdata->vif.valid_links) \ + _sdata_info((link)->sdata, "[link %d] " fmt, \ + (link)->link_id, \ + ##__VA_ARGS__); \ + else \ + _sdata_info((link)->sdata, fmt, ##__VA_ARGS__); \ + } while (0) +#define link_err(link, fmt, ...) \ + do { \ + if ((link)->sdata->vif.valid_links) \ + _sdata_err((link)->sdata, "[link %d] " fmt, \ + (link)->link_id, \ + ##__VA_ARGS__); \ + else \ + _sdata_err((link)->sdata, fmt, ##__VA_ARGS__); \ + } while (0) +#define link_dbg(link, fmt, ...) \ + do { \ + if ((link)->sdata->vif.valid_links) \ + _sdata_dbg(1, (link)->sdata, "[link %d] " fmt, \ + (link)->link_id, \ + ##__VA_ARGS__); \ + else \ + _sdata_dbg(1, (link)->sdata, fmt, \ + ##__VA_ARGS__); \ + } while (0) + +#define ht_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_HT_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define ht_dbg_ratelimited(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_HT_DEBUG && net_ratelimit(), \ + sdata, fmt, ##__VA_ARGS__) + +#define ocb_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_OCB_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define ibss_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_IBSS_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define ps_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_PS_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define ps_dbg_hw(hw, fmt, ...) \ + _wiphy_dbg(MAC80211_PS_DEBUG, \ + (hw)->wiphy, fmt, ##__VA_ARGS__) + +#define ps_dbg_ratelimited(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_PS_DEBUG && net_ratelimit(), \ + sdata, fmt, ##__VA_ARGS__) + +#define mpl_dbg(sdata, fmt, ...) 
\ + _sdata_dbg(MAC80211_MPL_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mpath_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MPATH_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mhwmp_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MHWMP_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define msync_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MESH_SYNC_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mcsa_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MESH_CSA_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mps_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MESH_PS_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define tdls_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_TDLS_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define sta_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_STA_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mlme_dbg(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MLME_DEBUG, \ + sdata, fmt, ##__VA_ARGS__) + +#define mlme_dbg_ratelimited(sdata, fmt, ...) \ + _sdata_dbg(MAC80211_MLME_DEBUG && net_ratelimit(), \ + sdata, fmt, ##__VA_ARGS__) + +#endif /* __MAC80211_DEBUG_H */ diff --git a/net/mac80211/debugfs.c b/net/mac80211/debugfs.c new file mode 100644 index 000000000..78c7d60e8 --- /dev/null +++ b/net/mac80211/debugfs.c @@ -0,0 +1,711 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * mac80211 debugfs for wireless PHYs + * + * Copyright 2007 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2018 - 2019, 2021-2022 Intel Corporation + */ + +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "debugfs.h" + +#define DEBUGFS_FORMAT_BUFFER_SIZE 100 + +int mac80211_format_buffer(char __user *userbuf, size_t count, + loff_t *ppos, char *fmt, ...) +{ + va_list args; + char buf[DEBUGFS_FORMAT_BUFFER_SIZE]; + int res; + + va_start(args, fmt); + res = vscnprintf(buf, sizeof(buf), fmt, args); + va_end(args); + + return simple_read_from_buffer(userbuf, count, ppos, buf, res); +} + +#define DEBUGFS_READONLY_FILE_FN(name, fmt, value...) \ +static ssize_t name## _read(struct file *file, char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + struct ieee80211_local *local = file->private_data; \ + \ + return mac80211_format_buffer(userbuf, count, ppos, \ + fmt "\n", ##value); \ +} + +#define DEBUGFS_READONLY_FILE_OPS(name) \ +static const struct file_operations name## _ops = { \ + .read = name## _read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +}; + +#define DEBUGFS_READONLY_FILE(name, fmt, value...) \ + DEBUGFS_READONLY_FILE_FN(name, fmt, value) \ + DEBUGFS_READONLY_FILE_OPS(name) + +#define DEBUGFS_ADD(name) \ + debugfs_create_file(#name, 0400, phyd, local, &name## _ops) + +#define DEBUGFS_ADD_MODE(name, mode) \ + debugfs_create_file(#name, mode, phyd, local, &name## _ops); + + +DEBUGFS_READONLY_FILE(hw_conf, "%x", + local->hw.conf.flags); +DEBUGFS_READONLY_FILE(user_power, "%d", + local->user_power_level); +DEBUGFS_READONLY_FILE(power, "%d", + local->hw.conf.power_level); +DEBUGFS_READONLY_FILE(total_ps_buffered, "%d", + local->total_ps_buffered); +DEBUGFS_READONLY_FILE(wep_iv, "%#08x", + local->wep_iv & 0xffffff); +DEBUGFS_READONLY_FILE(rate_ctrl_alg, "%s", + local->rate_ctrl ? 
local->rate_ctrl->ops->name : "hw/driver"); + +static ssize_t aqm_read(struct file *file, + char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + struct fq *fq = &local->fq; + char buf[200]; + int len = 0; + + spin_lock_bh(&local->fq.lock); + rcu_read_lock(); + + len = scnprintf(buf, sizeof(buf), + "access name value\n" + "R fq_flows_cnt %u\n" + "R fq_backlog %u\n" + "R fq_overlimit %u\n" + "R fq_overmemory %u\n" + "R fq_collisions %u\n" + "R fq_memory_usage %u\n" + "RW fq_memory_limit %u\n" + "RW fq_limit %u\n" + "RW fq_quantum %u\n", + fq->flows_cnt, + fq->backlog, + fq->overmemory, + fq->overlimit, + fq->collisions, + fq->memory_usage, + fq->memory_limit, + fq->limit, + fq->quantum); + + rcu_read_unlock(); + spin_unlock_bh(&local->fq.lock); + + return simple_read_from_buffer(user_buf, count, ppos, + buf, len); +} + +static ssize_t aqm_write(struct file *file, + const char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[100]; + + if (count >= sizeof(buf)) + return -EINVAL; + + if (copy_from_user(buf, user_buf, count)) + return -EFAULT; + + if (count && buf[count - 1] == '\n') + buf[count - 1] = '\0'; + else + buf[count] = '\0'; + + if (sscanf(buf, "fq_limit %u", &local->fq.limit) == 1) + return count; + else if (sscanf(buf, "fq_memory_limit %u", &local->fq.memory_limit) == 1) + return count; + else if (sscanf(buf, "fq_quantum %u", &local->fq.quantum) == 1) + return count; + + return -EINVAL; +} + +static const struct file_operations aqm_ops = { + .write = aqm_write, + .read = aqm_read, + .open = simple_open, + .llseek = default_llseek, +}; + +static ssize_t airtime_flags_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[128] = {}, *pos, *end; + + pos = buf; + end = pos + sizeof(buf) - 1; + + if (local->airtime_flags & AIRTIME_USE_TX) + pos += scnprintf(pos, end - pos, "AIRTIME_TX\t(%lx)\n", + AIRTIME_USE_TX); + if (local->airtime_flags & AIRTIME_USE_RX) + pos += scnprintf(pos, end - pos, "AIRTIME_RX\t(%lx)\n", + AIRTIME_USE_RX); + + return simple_read_from_buffer(user_buf, count, ppos, buf, + strlen(buf)); +} + +static ssize_t airtime_flags_write(struct file *file, + const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[16]; + + if (count >= sizeof(buf)) + return -EINVAL; + + if (copy_from_user(buf, user_buf, count)) + return -EFAULT; + + if (count && buf[count - 1] == '\n') + buf[count - 1] = '\0'; + else + buf[count] = '\0'; + + if (kstrtou16(buf, 0, &local->airtime_flags)) + return -EINVAL; + + return count; +} + +static const struct file_operations airtime_flags_ops = { + .write = airtime_flags_write, + .read = airtime_flags_read, + .open = simple_open, + .llseek = default_llseek, +}; + +static ssize_t aql_pending_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[400]; + int len = 0; + + len = scnprintf(buf, sizeof(buf), + "AC AQL pending\n" + "VO %u us\n" + "VI %u us\n" + "BE %u us\n" + "BK %u us\n" + "total %u us\n", + atomic_read(&local->aql_ac_pending_airtime[IEEE80211_AC_VO]), + atomic_read(&local->aql_ac_pending_airtime[IEEE80211_AC_VI]), + atomic_read(&local->aql_ac_pending_airtime[IEEE80211_AC_BE]), + atomic_read(&local->aql_ac_pending_airtime[IEEE80211_AC_BK]), + 
atomic_read(&local->aql_total_pending_airtime)); + return simple_read_from_buffer(user_buf, count, ppos, + buf, len); +} + +static const struct file_operations aql_pending_ops = { + .read = aql_pending_read, + .open = simple_open, + .llseek = default_llseek, +}; + +static ssize_t aql_txq_limit_read(struct file *file, + char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[400]; + int len = 0; + + len = scnprintf(buf, sizeof(buf), + "AC AQL limit low AQL limit high\n" + "VO %u %u\n" + "VI %u %u\n" + "BE %u %u\n" + "BK %u %u\n", + local->aql_txq_limit_low[IEEE80211_AC_VO], + local->aql_txq_limit_high[IEEE80211_AC_VO], + local->aql_txq_limit_low[IEEE80211_AC_VI], + local->aql_txq_limit_high[IEEE80211_AC_VI], + local->aql_txq_limit_low[IEEE80211_AC_BE], + local->aql_txq_limit_high[IEEE80211_AC_BE], + local->aql_txq_limit_low[IEEE80211_AC_BK], + local->aql_txq_limit_high[IEEE80211_AC_BK]); + return simple_read_from_buffer(user_buf, count, ppos, + buf, len); +} + +static ssize_t aql_txq_limit_write(struct file *file, + const char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[100]; + u32 ac, q_limit_low, q_limit_high, q_limit_low_old, q_limit_high_old; + struct sta_info *sta; + + if (count >= sizeof(buf)) + return -EINVAL; + + if (copy_from_user(buf, user_buf, count)) + return -EFAULT; + + if (count && buf[count - 1] == '\n') + buf[count - 1] = '\0'; + else + buf[count] = '\0'; + + if (sscanf(buf, "%u %u %u", &ac, &q_limit_low, &q_limit_high) != 3) + return -EINVAL; + + if (ac >= IEEE80211_NUM_ACS) + return -EINVAL; + + q_limit_low_old = local->aql_txq_limit_low[ac]; + q_limit_high_old = local->aql_txq_limit_high[ac]; + + local->aql_txq_limit_low[ac] = q_limit_low; + local->aql_txq_limit_high[ac] = q_limit_high; + + mutex_lock(&local->sta_mtx); + list_for_each_entry(sta, &local->sta_list, list) { + /* If a sta has customized queue limits, keep it */ + if (sta->airtime[ac].aql_limit_low == q_limit_low_old && + sta->airtime[ac].aql_limit_high == q_limit_high_old) { + sta->airtime[ac].aql_limit_low = q_limit_low; + sta->airtime[ac].aql_limit_high = q_limit_high; + } + } + mutex_unlock(&local->sta_mtx); + return count; +} + +static const struct file_operations aql_txq_limit_ops = { + .write = aql_txq_limit_write, + .read = aql_txq_limit_read, + .open = simple_open, + .llseek = default_llseek, +}; + +static ssize_t aql_enable_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + char buf[3]; + int len; + + len = scnprintf(buf, sizeof(buf), "%d\n", + !static_key_false(&aql_disable.key)); + + return simple_read_from_buffer(user_buf, count, ppos, buf, len); +} + +static ssize_t aql_enable_write(struct file *file, const char __user *user_buf, + size_t count, loff_t *ppos) +{ + bool aql_disabled = static_key_false(&aql_disable.key); + char buf[3]; + size_t len; + + if (count > sizeof(buf)) + return -EINVAL; + + if (copy_from_user(buf, user_buf, count)) + return -EFAULT; + + buf[sizeof(buf) - 1] = '\0'; + len = strlen(buf); + if (len > 0 && buf[len - 1] == '\n') + buf[len - 1] = 0; + + if (buf[0] == '0' && buf[1] == '\0') { + if (!aql_disabled) + static_branch_inc(&aql_disable); + } else if (buf[0] == '1' && buf[1] == '\0') { + if (aql_disabled) + static_branch_dec(&aql_disable); + } else { + return -EINVAL; + } + + return count; +} + +static const struct file_operations aql_enable_ops = { + .write = aql_enable_write, + .read = 
aql_enable_read, + .open = simple_open, + .llseek = default_llseek, +}; + +static ssize_t force_tx_status_read(struct file *file, + char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[3]; + int len = 0; + + len = scnprintf(buf, sizeof(buf), "%d\n", (int)local->force_tx_status); + + return simple_read_from_buffer(user_buf, count, ppos, + buf, len); +} + +static ssize_t force_tx_status_write(struct file *file, + const char __user *user_buf, + size_t count, + loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + char buf[3]; + + if (count >= sizeof(buf)) + return -EINVAL; + + if (copy_from_user(buf, user_buf, count)) + return -EFAULT; + + if (count && buf[count - 1] == '\n') + buf[count - 1] = '\0'; + else + buf[count] = '\0'; + + if (buf[0] == '0' && buf[1] == '\0') + local->force_tx_status = 0; + else if (buf[0] == '1' && buf[1] == '\0') + local->force_tx_status = 1; + else + return -EINVAL; + + return count; +} + +static const struct file_operations force_tx_status_ops = { + .write = force_tx_status_write, + .read = force_tx_status_read, + .open = simple_open, + .llseek = default_llseek, +}; + +#ifdef CONFIG_PM +static ssize_t reset_write(struct file *file, const char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + int ret; + + rtnl_lock(); + wiphy_lock(local->hw.wiphy); + __ieee80211_suspend(&local->hw, NULL); + ret = __ieee80211_resume(&local->hw); + wiphy_unlock(local->hw.wiphy); + + if (ret) + cfg80211_shutdown_all_interfaces(local->hw.wiphy); + + rtnl_unlock(); + + return count; +} + +static const struct file_operations reset_ops = { + .write = reset_write, + .open = simple_open, + .llseek = noop_llseek, +}; +#endif + +static const char *hw_flag_names[] = { +#define FLAG(F) [IEEE80211_HW_##F] = #F + FLAG(HAS_RATE_CONTROL), + FLAG(RX_INCLUDES_FCS), + FLAG(HOST_BROADCAST_PS_BUFFERING), + FLAG(SIGNAL_UNSPEC), + FLAG(SIGNAL_DBM), + FLAG(NEED_DTIM_BEFORE_ASSOC), + FLAG(SPECTRUM_MGMT), + FLAG(AMPDU_AGGREGATION), + FLAG(SUPPORTS_PS), + FLAG(PS_NULLFUNC_STACK), + FLAG(SUPPORTS_DYNAMIC_PS), + FLAG(MFP_CAPABLE), + FLAG(WANT_MONITOR_VIF), + FLAG(NO_AUTO_VIF), + FLAG(SW_CRYPTO_CONTROL), + FLAG(SUPPORT_FAST_XMIT), + FLAG(REPORTS_TX_ACK_STATUS), + FLAG(CONNECTION_MONITOR), + FLAG(QUEUE_CONTROL), + FLAG(SUPPORTS_PER_STA_GTK), + FLAG(AP_LINK_PS), + FLAG(TX_AMPDU_SETUP_IN_HW), + FLAG(SUPPORTS_RC_TABLE), + FLAG(P2P_DEV_ADDR_FOR_INTF), + FLAG(TIMING_BEACON_ONLY), + FLAG(SUPPORTS_HT_CCK_RATES), + FLAG(CHANCTX_STA_CSA), + FLAG(SUPPORTS_CLONED_SKBS), + FLAG(SINGLE_SCAN_ON_ALL_BANDS), + FLAG(TDLS_WIDER_BW), + FLAG(SUPPORTS_AMSDU_IN_AMPDU), + FLAG(BEACON_TX_STATUS), + FLAG(NEEDS_UNIQUE_STA_ADDR), + FLAG(SUPPORTS_REORDERING_BUFFER), + FLAG(USES_RSS), + FLAG(TX_AMSDU), + FLAG(TX_FRAG_LIST), + FLAG(REPORTS_LOW_ACK), + FLAG(SUPPORTS_TX_FRAG), + FLAG(SUPPORTS_TDLS_BUFFER_STA), + FLAG(DEAUTH_NEED_MGD_TX_PREP), + FLAG(DOESNT_SUPPORT_QOS_NDP), + FLAG(BUFF_MMPDU_TXQ), + FLAG(SUPPORTS_VHT_EXT_NSS_BW), + FLAG(STA_MMPDU_TXQ), + FLAG(TX_STATUS_NO_AMPDU_LEN), + FLAG(SUPPORTS_MULTI_BSSID), + FLAG(SUPPORTS_ONLY_HE_MULTI_BSSID), + FLAG(AMPDU_KEYBORDER_SUPPORT), + FLAG(SUPPORTS_TX_ENCAP_OFFLOAD), + FLAG(SUPPORTS_RX_DECAP_OFFLOAD), + FLAG(SUPPORTS_CONC_MON_RX_DECAP), + FLAG(DETECTS_COLOR_COLLISION), + FLAG(MLO_MCAST_MULTI_LINK_TX), +#undef FLAG +}; + +static ssize_t hwflags_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local 
*local = file->private_data; + size_t bufsz = 30 * NUM_IEEE80211_HW_FLAGS; + char *buf = kzalloc(bufsz, GFP_KERNEL); + char *pos = buf, *end = buf + bufsz - 1; + ssize_t rv; + int i; + + if (!buf) + return -ENOMEM; + + /* fail compilation if somebody adds or removes + * a flag without updating the name array above + */ + BUILD_BUG_ON(ARRAY_SIZE(hw_flag_names) != NUM_IEEE80211_HW_FLAGS); + + for (i = 0; i < NUM_IEEE80211_HW_FLAGS; i++) { + if (test_bit(i, local->hw.flags)) + pos += scnprintf(pos, end - pos, "%s\n", + hw_flag_names[i]); + } + + rv = simple_read_from_buffer(user_buf, count, ppos, buf, strlen(buf)); + kfree(buf); + return rv; +} + +static ssize_t misc_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + /* Max len of each line is 16 characters, plus 9 for 'pending:\n' */ + size_t bufsz = IEEE80211_MAX_QUEUES * 16 + 9; + char *buf; + char *pos, *end; + ssize_t rv; + int i; + int ln; + + buf = kzalloc(bufsz, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + pos = buf; + end = buf + bufsz - 1; + + pos += scnprintf(pos, end - pos, "pending:\n"); + + for (i = 0; i < IEEE80211_MAX_QUEUES; i++) { + ln = skb_queue_len(&local->pending[i]); + pos += scnprintf(pos, end - pos, "[%i] %d\n", + i, ln); + } + + rv = simple_read_from_buffer(user_buf, count, ppos, buf, strlen(buf)); + kfree(buf); + return rv; +} + +static ssize_t queues_read(struct file *file, char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct ieee80211_local *local = file->private_data; + unsigned long flags; + char buf[IEEE80211_MAX_QUEUES * 20]; + int q, res = 0; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + for (q = 0; q < local->hw.queues; q++) + res += sprintf(buf + res, "%02d: %#.8lx/%d\n", q, + local->queue_stop_reasons[q], + skb_queue_len(&local->pending[q])); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + return simple_read_from_buffer(user_buf, count, ppos, buf, res); +} + +DEBUGFS_READONLY_FILE_OPS(hwflags); +DEBUGFS_READONLY_FILE_OPS(queues); +DEBUGFS_READONLY_FILE_OPS(misc); + +/* statistics stuff */ + +static ssize_t format_devstat_counter(struct ieee80211_local *local, + char __user *userbuf, + size_t count, loff_t *ppos, + int (*printvalue)(struct ieee80211_low_level_stats *stats, char *buf, + int buflen)) +{ + struct ieee80211_low_level_stats stats; + char buf[20]; + int res; + + rtnl_lock(); + res = drv_get_stats(local, &stats); + rtnl_unlock(); + if (res) + return res; + res = printvalue(&stats, buf, sizeof(buf)); + return simple_read_from_buffer(userbuf, count, ppos, buf, res); +} + +#define DEBUGFS_DEVSTATS_FILE(name) \ +static int print_devstats_##name(struct ieee80211_low_level_stats *stats,\ + char *buf, int buflen) \ +{ \ + return scnprintf(buf, buflen, "%u\n", stats->name); \ +} \ +static ssize_t stats_ ##name## _read(struct file *file, \ + char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + return format_devstat_counter(file->private_data, \ + userbuf, \ + count, \ + ppos, \ + print_devstats_##name); \ +} \ + \ +static const struct file_operations stats_ ##name## _ops = { \ + .read = stats_ ##name## _read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +}; + +#ifdef CONFIG_MAC80211_DEBUG_COUNTERS +#define DEBUGFS_STATS_ADD(name) \ + debugfs_create_u32(#name, 0400, statsd, &local->name); +#endif +#define DEBUGFS_DEVSTATS_ADD(name) \ + debugfs_create_file(#name, 0400, statsd, local, &stats_ ##name## _ops); + 
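The dot11* counter files that follow are all generated by the DEBUGFS_DEVSTATS_FILE() macro defined above. As an illustrative sketch of what that token pasting produces (not verbatim output of the preprocessor; whitespace and comments added), DEBUGFS_DEVSTATS_FILE(dot11ACKFailureCount) expands to roughly:

/* Sketch of the macro expansion for one counter; illustrative only. */
static int print_devstats_dot11ACKFailureCount(struct ieee80211_low_level_stats *stats,
					       char *buf, int buflen)
{
	/* Format the single counter value handed over by the caller. */
	return scnprintf(buf, buflen, "%u\n", stats->dot11ACKFailureCount);
}

static ssize_t stats_dot11ACKFailureCount_read(struct file *file,
					       char __user *userbuf,
					       size_t count, loff_t *ppos)
{
	/* format_devstat_counter() fetches the stats via drv_get_stats()
	 * under the RTNL and copies the formatted value to user space.
	 */
	return format_devstat_counter(file->private_data, userbuf, count,
				      ppos, print_devstats_dot11ACKFailureCount);
}

static const struct file_operations stats_dot11ACKFailureCount_ops = {
	.read = stats_dot11ACKFailureCount_read,
	.open = simple_open,
	.llseek = generic_file_llseek,
};

Each such _ops structure is then registered with debugfs_create_file() through DEBUGFS_DEVSTATS_ADD() in debugfs_hw_add() below.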
+DEBUGFS_DEVSTATS_FILE(dot11ACKFailureCount); +DEBUGFS_DEVSTATS_FILE(dot11RTSFailureCount); +DEBUGFS_DEVSTATS_FILE(dot11FCSErrorCount); +DEBUGFS_DEVSTATS_FILE(dot11RTSSuccessCount); + +void debugfs_hw_add(struct ieee80211_local *local) +{ + struct dentry *phyd = local->hw.wiphy->debugfsdir; + struct dentry *statsd; + + if (!phyd) + return; + + local->debugfs.keys = debugfs_create_dir("keys", phyd); + + DEBUGFS_ADD(total_ps_buffered); + DEBUGFS_ADD(wep_iv); + DEBUGFS_ADD(rate_ctrl_alg); + DEBUGFS_ADD(queues); + DEBUGFS_ADD(misc); +#ifdef CONFIG_PM + DEBUGFS_ADD_MODE(reset, 0200); +#endif + DEBUGFS_ADD(hwflags); + DEBUGFS_ADD(user_power); + DEBUGFS_ADD(power); + DEBUGFS_ADD(hw_conf); + DEBUGFS_ADD_MODE(force_tx_status, 0600); + DEBUGFS_ADD_MODE(aql_enable, 0600); + DEBUGFS_ADD(aql_pending); + + if (local->ops->wake_tx_queue) + DEBUGFS_ADD_MODE(aqm, 0600); + + DEBUGFS_ADD_MODE(airtime_flags, 0600); + + DEBUGFS_ADD(aql_txq_limit); + debugfs_create_u32("aql_threshold", 0600, + phyd, &local->aql_threshold); + + statsd = debugfs_create_dir("statistics", phyd); + + /* if the dir failed, don't put all the other things into the root! */ + if (!statsd) + return; + +#ifdef CONFIG_MAC80211_DEBUG_COUNTERS + DEBUGFS_STATS_ADD(dot11TransmittedFragmentCount); + DEBUGFS_STATS_ADD(dot11MulticastTransmittedFrameCount); + DEBUGFS_STATS_ADD(dot11FailedCount); + DEBUGFS_STATS_ADD(dot11RetryCount); + DEBUGFS_STATS_ADD(dot11MultipleRetryCount); + DEBUGFS_STATS_ADD(dot11FrameDuplicateCount); + DEBUGFS_STATS_ADD(dot11ReceivedFragmentCount); + DEBUGFS_STATS_ADD(dot11MulticastReceivedFrameCount); + DEBUGFS_STATS_ADD(dot11TransmittedFrameCount); + DEBUGFS_STATS_ADD(tx_handlers_drop); + DEBUGFS_STATS_ADD(tx_handlers_queued); + DEBUGFS_STATS_ADD(tx_handlers_drop_wep); + DEBUGFS_STATS_ADD(tx_handlers_drop_not_assoc); + DEBUGFS_STATS_ADD(tx_handlers_drop_unauth_port); + DEBUGFS_STATS_ADD(rx_handlers_drop); + DEBUGFS_STATS_ADD(rx_handlers_queued); + DEBUGFS_STATS_ADD(rx_handlers_drop_nullfunc); + DEBUGFS_STATS_ADD(rx_handlers_drop_defrag); + DEBUGFS_STATS_ADD(tx_expand_skb_head); + DEBUGFS_STATS_ADD(tx_expand_skb_head_cloned); + DEBUGFS_STATS_ADD(rx_expand_skb_head_defrag); + DEBUGFS_STATS_ADD(rx_handlers_fragments); + DEBUGFS_STATS_ADD(tx_status_drop); +#endif + DEBUGFS_DEVSTATS_ADD(dot11ACKFailureCount); + DEBUGFS_DEVSTATS_ADD(dot11RTSFailureCount); + DEBUGFS_DEVSTATS_ADD(dot11FCSErrorCount); + DEBUGFS_DEVSTATS_ADD(dot11RTSSuccessCount); +} diff --git a/net/mac80211/debugfs.h b/net/mac80211/debugfs.h new file mode 100644 index 000000000..d2c424787 --- /dev/null +++ b/net/mac80211/debugfs.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __MAC80211_DEBUGFS_H +#define __MAC80211_DEBUGFS_H + +#include "ieee80211_i.h" + +#ifdef CONFIG_MAC80211_DEBUGFS +void debugfs_hw_add(struct ieee80211_local *local); +int __printf(4, 5) mac80211_format_buffer(char __user *userbuf, size_t count, + loff_t *ppos, char *fmt, ...); +#else +static inline void debugfs_hw_add(struct ieee80211_local *local) +{ +} +#endif + +#endif /* __MAC80211_DEBUGFS_H */ diff --git a/net/mac80211/debugfs_key.c b/net/mac80211/debugfs_key.c new file mode 100644 index 000000000..16a04330e --- /dev/null +++ b/net/mac80211/debugfs_key.c @@ -0,0 +1,472 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2003-2005 Devicescape Software, Inc. 
+ * Copyright (c) 2006 Jiri Benc + * Copyright 2007 Johannes Berg + * Copyright (C) 2015 Intel Deutschland GmbH + * Copyright (C) 2021-2022 Intel Corporation + */ + +#include +#include +#include "ieee80211_i.h" +#include "key.h" +#include "debugfs.h" +#include "debugfs_key.h" + +#define KEY_READ(name, prop, format_string) \ +static ssize_t key_##name##_read(struct file *file, \ + char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + struct ieee80211_key *key = file->private_data; \ + return mac80211_format_buffer(userbuf, count, ppos, \ + format_string, key->prop); \ +} +#define KEY_READ_X(name) KEY_READ(name, name, "0x%x\n") + +#define KEY_OPS(name) \ +static const struct file_operations key_ ##name## _ops = { \ + .read = key_##name##_read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define KEY_OPS_W(name) \ +static const struct file_operations key_ ##name## _ops = { \ + .read = key_##name##_read, \ + .write = key_##name##_write, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define KEY_FILE(name, format) \ + KEY_READ_##format(name) \ + KEY_OPS(name) + +#define KEY_CONF_READ(name, format_string) \ + KEY_READ(conf_##name, conf.name, format_string) +#define KEY_CONF_READ_D(name) KEY_CONF_READ(name, "%d\n") + +#define KEY_CONF_OPS(name) \ +static const struct file_operations key_ ##name## _ops = { \ + .read = key_conf_##name##_read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define KEY_CONF_FILE(name, format) \ + KEY_CONF_READ_##format(name) \ + KEY_CONF_OPS(name) + +KEY_CONF_FILE(keylen, D); +KEY_CONF_FILE(keyidx, D); +KEY_CONF_FILE(hw_key_idx, D); +KEY_FILE(flags, X); +KEY_READ(ifindex, sdata->name, "%s\n"); +KEY_OPS(ifindex); + +static ssize_t key_algorithm_read(struct file *file, + char __user *userbuf, + size_t count, loff_t *ppos) +{ + char buf[15]; + struct ieee80211_key *key = file->private_data; + u32 c = key->conf.cipher; + + sprintf(buf, "%.2x-%.2x-%.2x:%d\n", + c >> 24, (c >> 16) & 0xff, (c >> 8) & 0xff, c & 0xff); + return simple_read_from_buffer(userbuf, count, ppos, buf, strlen(buf)); +} +KEY_OPS(algorithm); + +static ssize_t key_tx_spec_write(struct file *file, const char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + u64 pn; + int ret; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + return -EINVAL; + case WLAN_CIPHER_SUITE_TKIP: + /* not supported yet */ + return -EOPNOTSUPP; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + ret = kstrtou64_from_user(userbuf, count, 16, &pn); + if (ret) + return ret; + /* PN is a 48-bit counter */ + if (pn >= (1ULL << 48)) + return -ERANGE; + atomic64_set(&key->conf.tx_pn, pn); + return count; + default: + return 0; + } +} + +static ssize_t key_tx_spec_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + u64 pn; + char buf[20]; + int len; + struct ieee80211_key *key = file->private_data; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + len = scnprintf(buf, sizeof(buf), "\n"); + break; + case WLAN_CIPHER_SUITE_TKIP: + pn = atomic64_read(&key->conf.tx_pn); + len = scnprintf(buf, sizeof(buf), "%08x %04x\n", + TKIP_PN_TO_IV32(pn), + 
TKIP_PN_TO_IV16(pn)); + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + pn = atomic64_read(&key->conf.tx_pn); + len = scnprintf(buf, sizeof(buf), "%02x%02x%02x%02x%02x%02x\n", + (u8)(pn >> 40), (u8)(pn >> 32), (u8)(pn >> 24), + (u8)(pn >> 16), (u8)(pn >> 8), (u8)pn); + break; + default: + return 0; + } + return simple_read_from_buffer(userbuf, count, ppos, buf, len); +} +KEY_OPS_W(tx_spec); + +static ssize_t key_rx_spec_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + char buf[14*IEEE80211_NUM_TIDS+1], *p = buf; + int i, len; + const u8 *rpn; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + len = scnprintf(buf, sizeof(buf), "\n"); + break; + case WLAN_CIPHER_SUITE_TKIP: + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + p += scnprintf(p, sizeof(buf)+buf-p, + "%08x %04x\n", + key->u.tkip.rx[i].iv32, + key->u.tkip.rx[i].iv16); + len = p - buf; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++) { + rpn = key->u.ccmp.rx_pn[i]; + p += scnprintf(p, sizeof(buf)+buf-p, + "%02x%02x%02x%02x%02x%02x\n", + rpn[0], rpn[1], rpn[2], + rpn[3], rpn[4], rpn[5]); + } + len = p - buf; + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + rpn = key->u.aes_cmac.rx_pn; + p += scnprintf(p, sizeof(buf)+buf-p, + "%02x%02x%02x%02x%02x%02x\n", + rpn[0], rpn[1], rpn[2], + rpn[3], rpn[4], rpn[5]); + len = p - buf; + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + rpn = key->u.aes_gmac.rx_pn; + p += scnprintf(p, sizeof(buf)+buf-p, + "%02x%02x%02x%02x%02x%02x\n", + rpn[0], rpn[1], rpn[2], + rpn[3], rpn[4], rpn[5]); + len = p - buf; + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++) { + rpn = key->u.gcmp.rx_pn[i]; + p += scnprintf(p, sizeof(buf)+buf-p, + "%02x%02x%02x%02x%02x%02x\n", + rpn[0], rpn[1], rpn[2], + rpn[3], rpn[4], rpn[5]); + } + len = p - buf; + break; + default: + return 0; + } + return simple_read_from_buffer(userbuf, count, ppos, buf, len); +} +KEY_OPS(rx_spec); + +static ssize_t key_replays_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + char buf[20]; + int len; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + len = scnprintf(buf, sizeof(buf), "%u\n", key->u.ccmp.replays); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + len = scnprintf(buf, sizeof(buf), "%u\n", + key->u.aes_cmac.replays); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + len = scnprintf(buf, sizeof(buf), "%u\n", + key->u.aes_gmac.replays); + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + len = scnprintf(buf, sizeof(buf), "%u\n", key->u.gcmp.replays); + break; + default: + return 0; + } + return simple_read_from_buffer(userbuf, count, ppos, buf, len); +} +KEY_OPS(replays); + +static ssize_t key_icverrors_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + 
char buf[20]; + int len; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + len = scnprintf(buf, sizeof(buf), "%u\n", + key->u.aes_cmac.icverrors); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + len = scnprintf(buf, sizeof(buf), "%u\n", + key->u.aes_gmac.icverrors); + break; + default: + return 0; + } + return simple_read_from_buffer(userbuf, count, ppos, buf, len); +} +KEY_OPS(icverrors); + +static ssize_t key_mic_failures_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + char buf[20]; + int len; + + if (key->conf.cipher != WLAN_CIPHER_SUITE_TKIP) + return -EINVAL; + + len = scnprintf(buf, sizeof(buf), "%u\n", key->u.tkip.mic_failures); + + return simple_read_from_buffer(userbuf, count, ppos, buf, len); +} +KEY_OPS(mic_failures); + +static ssize_t key_key_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct ieee80211_key *key = file->private_data; + int i, bufsize = 2 * key->conf.keylen + 2; + char *buf = kmalloc(bufsize, GFP_KERNEL); + char *p = buf; + ssize_t res; + + if (!buf) + return -ENOMEM; + + for (i = 0; i < key->conf.keylen; i++) + p += scnprintf(p, bufsize + buf - p, "%02x", key->conf.key[i]); + p += scnprintf(p, bufsize+buf-p, "\n"); + res = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return res; +} +KEY_OPS(key); + +#define DEBUGFS_ADD(name) \ + debugfs_create_file(#name, 0400, key->debugfs.dir, \ + key, &key_##name##_ops) +#define DEBUGFS_ADD_W(name) \ + debugfs_create_file(#name, 0600, key->debugfs.dir, \ + key, &key_##name##_ops); + +void ieee80211_debugfs_key_add(struct ieee80211_key *key) +{ + static int keycount; + char buf[100]; + struct sta_info *sta; + + if (!key->local->debugfs.keys) + return; + + sprintf(buf, "%d", keycount); + key->debugfs.cnt = keycount; + keycount++; + key->debugfs.dir = debugfs_create_dir(buf, + key->local->debugfs.keys); + + sta = key->sta; + if (sta) { + sprintf(buf, "../../netdev:%s/stations/%pM", + sta->sdata->name, sta->sta.addr); + key->debugfs.stalink = + debugfs_create_symlink("station", key->debugfs.dir, buf); + } + + DEBUGFS_ADD(keylen); + DEBUGFS_ADD(flags); + DEBUGFS_ADD(keyidx); + DEBUGFS_ADD(hw_key_idx); + DEBUGFS_ADD(algorithm); + DEBUGFS_ADD_W(tx_spec); + DEBUGFS_ADD(rx_spec); + DEBUGFS_ADD(replays); + DEBUGFS_ADD(icverrors); + DEBUGFS_ADD(mic_failures); + DEBUGFS_ADD(key); + DEBUGFS_ADD(ifindex); +}; + +void ieee80211_debugfs_key_remove(struct ieee80211_key *key) +{ + if (!key) + return; + + debugfs_remove_recursive(key->debugfs.dir); + key->debugfs.dir = NULL; +} + +void ieee80211_debugfs_key_update_default(struct ieee80211_sub_if_data *sdata) +{ + char buf[50]; + struct ieee80211_key *key; + + if (!sdata->vif.debugfs_dir) + return; + + lockdep_assert_held(&sdata->local->key_mtx); + + debugfs_remove(sdata->debugfs.default_unicast_key); + sdata->debugfs.default_unicast_key = NULL; + + if (sdata->default_unicast_key) { + key = key_mtx_dereference(sdata->local, + sdata->default_unicast_key); + sprintf(buf, "../keys/%d", key->debugfs.cnt); + sdata->debugfs.default_unicast_key = + debugfs_create_symlink("default_unicast_key", + sdata->vif.debugfs_dir, buf); + } + + debugfs_remove(sdata->debugfs.default_multicast_key); + sdata->debugfs.default_multicast_key = NULL; + + if (sdata->deflink.default_multicast_key) { + key = key_mtx_dereference(sdata->local, + 
sdata->deflink.default_multicast_key); + sprintf(buf, "../keys/%d", key->debugfs.cnt); + sdata->debugfs.default_multicast_key = + debugfs_create_symlink("default_multicast_key", + sdata->vif.debugfs_dir, buf); + } +} + +void ieee80211_debugfs_key_add_mgmt_default(struct ieee80211_sub_if_data *sdata) +{ + char buf[50]; + struct ieee80211_key *key; + + if (!sdata->vif.debugfs_dir) + return; + + key = key_mtx_dereference(sdata->local, + sdata->deflink.default_mgmt_key); + if (key) { + sprintf(buf, "../keys/%d", key->debugfs.cnt); + sdata->debugfs.default_mgmt_key = + debugfs_create_symlink("default_mgmt_key", + sdata->vif.debugfs_dir, buf); + } else + ieee80211_debugfs_key_remove_mgmt_default(sdata); +} + +void ieee80211_debugfs_key_remove_mgmt_default(struct ieee80211_sub_if_data *sdata) +{ + if (!sdata) + return; + + debugfs_remove(sdata->debugfs.default_mgmt_key); + sdata->debugfs.default_mgmt_key = NULL; +} + +void +ieee80211_debugfs_key_add_beacon_default(struct ieee80211_sub_if_data *sdata) +{ + char buf[50]; + struct ieee80211_key *key; + + if (!sdata->vif.debugfs_dir) + return; + + key = key_mtx_dereference(sdata->local, + sdata->deflink.default_beacon_key); + if (key) { + sprintf(buf, "../keys/%d", key->debugfs.cnt); + sdata->debugfs.default_beacon_key = + debugfs_create_symlink("default_beacon_key", + sdata->vif.debugfs_dir, buf); + } else { + ieee80211_debugfs_key_remove_beacon_default(sdata); + } +} + +void +ieee80211_debugfs_key_remove_beacon_default(struct ieee80211_sub_if_data *sdata) +{ + if (!sdata) + return; + + debugfs_remove(sdata->debugfs.default_beacon_key); + sdata->debugfs.default_beacon_key = NULL; +} + +void ieee80211_debugfs_key_sta_del(struct ieee80211_key *key, + struct sta_info *sta) +{ + debugfs_remove(key->debugfs.stalink); + key->debugfs.stalink = NULL; +} diff --git a/net/mac80211/debugfs_key.h b/net/mac80211/debugfs_key.h new file mode 100644 index 000000000..af7cf495f --- /dev/null +++ b/net/mac80211/debugfs_key.h @@ -0,0 +1,44 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __MAC80211_DEBUGFS_KEY_H +#define __MAC80211_DEBUGFS_KEY_H + +#ifdef CONFIG_MAC80211_DEBUGFS +void ieee80211_debugfs_key_add(struct ieee80211_key *key); +void ieee80211_debugfs_key_remove(struct ieee80211_key *key); +void ieee80211_debugfs_key_update_default(struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_key_add_mgmt_default( + struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_key_remove_mgmt_default( + struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_key_add_beacon_default( + struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_key_remove_beacon_default( + struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_key_sta_del(struct ieee80211_key *key, + struct sta_info *sta); +#else +static inline void ieee80211_debugfs_key_add(struct ieee80211_key *key) +{} +static inline void ieee80211_debugfs_key_remove(struct ieee80211_key *key) +{} +static inline void ieee80211_debugfs_key_update_default( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_key_add_mgmt_default( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_key_remove_mgmt_default( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_key_add_beacon_default( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_key_remove_beacon_default( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_key_sta_del(struct ieee80211_key *key, + 
struct sta_info *sta) +{} +#endif + +#endif /* __MAC80211_DEBUGFS_KEY_H */ diff --git a/net/mac80211/debugfs_netdev.c b/net/mac80211/debugfs_netdev.c new file mode 100644 index 000000000..08a1d7564 --- /dev/null +++ b/net/mac80211/debugfs_netdev.c @@ -0,0 +1,862 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2006 Jiri Benc + * Copyright 2007 Johannes Berg + * Copyright (C) 2020-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "rate.h" +#include "debugfs.h" +#include "debugfs_netdev.h" +#include "driver-ops.h" + +static ssize_t ieee80211_if_read( + struct ieee80211_sub_if_data *sdata, + char __user *userbuf, + size_t count, loff_t *ppos, + ssize_t (*format)(const struct ieee80211_sub_if_data *, char *, int)) +{ + char buf[200]; + ssize_t ret = -EINVAL; + + read_lock(&dev_base_lock); + ret = (*format)(sdata, buf, sizeof(buf)); + read_unlock(&dev_base_lock); + + if (ret >= 0) + ret = simple_read_from_buffer(userbuf, count, ppos, buf, ret); + + return ret; +} + +static ssize_t ieee80211_if_write( + struct ieee80211_sub_if_data *sdata, + const char __user *userbuf, + size_t count, loff_t *ppos, + ssize_t (*write)(struct ieee80211_sub_if_data *, const char *, int)) +{ + char buf[64]; + ssize_t ret; + + if (count >= sizeof(buf)) + return -E2BIG; + + if (copy_from_user(buf, userbuf, count)) + return -EFAULT; + buf[count] = '\0'; + + rtnl_lock(); + ret = (*write)(sdata, buf, count); + rtnl_unlock(); + + return ret; +} + +#define IEEE80211_IF_FMT(name, field, format_string) \ +static ssize_t ieee80211_if_fmt_##name( \ + const struct ieee80211_sub_if_data *sdata, char *buf, \ + int buflen) \ +{ \ + return scnprintf(buf, buflen, format_string, sdata->field); \ +} +#define IEEE80211_IF_FMT_DEC(name, field) \ + IEEE80211_IF_FMT(name, field, "%d\n") +#define IEEE80211_IF_FMT_HEX(name, field) \ + IEEE80211_IF_FMT(name, field, "%#x\n") +#define IEEE80211_IF_FMT_LHEX(name, field) \ + IEEE80211_IF_FMT(name, field, "%#lx\n") + +#define IEEE80211_IF_FMT_HEXARRAY(name, field) \ +static ssize_t ieee80211_if_fmt_##name( \ + const struct ieee80211_sub_if_data *sdata, \ + char *buf, int buflen) \ +{ \ + char *p = buf; \ + int i; \ + for (i = 0; i < sizeof(sdata->field); i++) { \ + p += scnprintf(p, buflen + buf - p, "%.2x ", \ + sdata->field[i]); \ + } \ + p += scnprintf(p, buflen + buf - p, "\n"); \ + return p - buf; \ +} + +#define IEEE80211_IF_FMT_ATOMIC(name, field) \ +static ssize_t ieee80211_if_fmt_##name( \ + const struct ieee80211_sub_if_data *sdata, \ + char *buf, int buflen) \ +{ \ + return scnprintf(buf, buflen, "%d\n", atomic_read(&sdata->field));\ +} + +#define IEEE80211_IF_FMT_MAC(name, field) \ +static ssize_t ieee80211_if_fmt_##name( \ + const struct ieee80211_sub_if_data *sdata, char *buf, \ + int buflen) \ +{ \ + return scnprintf(buf, buflen, "%pM\n", sdata->field); \ +} + +#define IEEE80211_IF_FMT_JIFFIES_TO_MS(name, field) \ +static ssize_t ieee80211_if_fmt_##name( \ + const struct ieee80211_sub_if_data *sdata, \ + char *buf, int buflen) \ +{ \ + return scnprintf(buf, buflen, "%d\n", \ + jiffies_to_msecs(sdata->field)); \ +} + +#define _IEEE80211_IF_FILE_OPS(name, _read, _write) \ +static const struct file_operations name##_ops = { \ + .read = (_read), \ + .write = (_write), \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define _IEEE80211_IF_FILE_R_FN(name) \ +static ssize_t ieee80211_if_read_##name(struct file *file, \ + char 
__user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + return ieee80211_if_read(file->private_data, \ + userbuf, count, ppos, \ + ieee80211_if_fmt_##name); \ +} + +#define _IEEE80211_IF_FILE_W_FN(name) \ +static ssize_t ieee80211_if_write_##name(struct file *file, \ + const char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + return ieee80211_if_write(file->private_data, userbuf, count, \ + ppos, ieee80211_if_parse_##name); \ +} + +#define IEEE80211_IF_FILE_R(name) \ + _IEEE80211_IF_FILE_R_FN(name) \ + _IEEE80211_IF_FILE_OPS(name, ieee80211_if_read_##name, NULL) + +#define IEEE80211_IF_FILE_W(name) \ + _IEEE80211_IF_FILE_W_FN(name) \ + _IEEE80211_IF_FILE_OPS(name, NULL, ieee80211_if_write_##name) + +#define IEEE80211_IF_FILE_RW(name) \ + _IEEE80211_IF_FILE_R_FN(name) \ + _IEEE80211_IF_FILE_W_FN(name) \ + _IEEE80211_IF_FILE_OPS(name, ieee80211_if_read_##name, \ + ieee80211_if_write_##name) + +#define IEEE80211_IF_FILE(name, field, format) \ + IEEE80211_IF_FMT_##format(name, field) \ + IEEE80211_IF_FILE_R(name) + +/* common attributes */ +IEEE80211_IF_FILE(rc_rateidx_mask_2ghz, rc_rateidx_mask[NL80211_BAND_2GHZ], + HEX); +IEEE80211_IF_FILE(rc_rateidx_mask_5ghz, rc_rateidx_mask[NL80211_BAND_5GHZ], + HEX); +IEEE80211_IF_FILE(rc_rateidx_mcs_mask_2ghz, + rc_rateidx_mcs_mask[NL80211_BAND_2GHZ], HEXARRAY); +IEEE80211_IF_FILE(rc_rateidx_mcs_mask_5ghz, + rc_rateidx_mcs_mask[NL80211_BAND_5GHZ], HEXARRAY); + +static ssize_t ieee80211_if_fmt_rc_rateidx_vht_mcs_mask_2ghz( + const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + int i, len = 0; + const u16 *mask = sdata->rc_rateidx_vht_mcs_mask[NL80211_BAND_2GHZ]; + + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) + len += scnprintf(buf + len, buflen - len, "%04x ", mask[i]); + len += scnprintf(buf + len, buflen - len, "\n"); + + return len; +} + +IEEE80211_IF_FILE_R(rc_rateidx_vht_mcs_mask_2ghz); + +static ssize_t ieee80211_if_fmt_rc_rateidx_vht_mcs_mask_5ghz( + const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + int i, len = 0; + const u16 *mask = sdata->rc_rateidx_vht_mcs_mask[NL80211_BAND_5GHZ]; + + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) + len += scnprintf(buf + len, buflen - len, "%04x ", mask[i]); + len += scnprintf(buf + len, buflen - len, "\n"); + + return len; +} + +IEEE80211_IF_FILE_R(rc_rateidx_vht_mcs_mask_5ghz); + +IEEE80211_IF_FILE(flags, flags, HEX); +IEEE80211_IF_FILE(state, state, LHEX); +IEEE80211_IF_FILE(txpower, vif.bss_conf.txpower, DEC); +IEEE80211_IF_FILE(ap_power_level, deflink.ap_power_level, DEC); +IEEE80211_IF_FILE(user_power_level, deflink.user_power_level, DEC); + +static ssize_t +ieee80211_if_fmt_hw_queues(const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + int len; + + len = scnprintf(buf, buflen, "AC queues: VO:%d VI:%d BE:%d BK:%d\n", + sdata->vif.hw_queue[IEEE80211_AC_VO], + sdata->vif.hw_queue[IEEE80211_AC_VI], + sdata->vif.hw_queue[IEEE80211_AC_BE], + sdata->vif.hw_queue[IEEE80211_AC_BK]); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + len += scnprintf(buf + len, buflen - len, "cab queue: %d\n", + sdata->vif.cab_queue); + + return len; +} +IEEE80211_IF_FILE_R(hw_queues); + +/* STA attributes */ +IEEE80211_IF_FILE(bssid, deflink.u.mgd.bssid, MAC); +IEEE80211_IF_FILE(aid, vif.cfg.aid, DEC); +IEEE80211_IF_FILE(beacon_timeout, u.mgd.beacon_timeout, JIFFIES_TO_MS); + +static int ieee80211_set_smps(struct ieee80211_sub_if_data *sdata, + enum ieee80211_smps_mode smps_mode) +{ + struct ieee80211_local *local = sdata->local; + int err; + + if 
(!(local->hw.wiphy->features & NL80211_FEATURE_STATIC_SMPS) && + smps_mode == IEEE80211_SMPS_STATIC) + return -EINVAL; + + /* auto should be dynamic if in PS mode */ + if (!(local->hw.wiphy->features & NL80211_FEATURE_DYNAMIC_SMPS) && + (smps_mode == IEEE80211_SMPS_DYNAMIC || + smps_mode == IEEE80211_SMPS_AUTOMATIC)) + return -EINVAL; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + sdata_lock(sdata); + err = __ieee80211_request_smps_mgd(sdata, &sdata->deflink, smps_mode); + sdata_unlock(sdata); + + return err; +} + +static const char *smps_modes[IEEE80211_SMPS_NUM_MODES] = { + [IEEE80211_SMPS_AUTOMATIC] = "auto", + [IEEE80211_SMPS_OFF] = "off", + [IEEE80211_SMPS_STATIC] = "static", + [IEEE80211_SMPS_DYNAMIC] = "dynamic", +}; + +static ssize_t ieee80211_if_fmt_smps(const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + if (sdata->vif.type == NL80211_IFTYPE_STATION) + return snprintf(buf, buflen, "request: %s\nused: %s\n", + smps_modes[sdata->deflink.u.mgd.req_smps], + smps_modes[sdata->deflink.smps_mode]); + return -EINVAL; +} + +static ssize_t ieee80211_if_parse_smps(struct ieee80211_sub_if_data *sdata, + const char *buf, int buflen) +{ + enum ieee80211_smps_mode mode; + + for (mode = 0; mode < IEEE80211_SMPS_NUM_MODES; mode++) { + if (strncmp(buf, smps_modes[mode], buflen) == 0) { + int err = ieee80211_set_smps(sdata, mode); + if (!err) + return buflen; + return err; + } + } + + return -EINVAL; +} +IEEE80211_IF_FILE_RW(smps); + +static ssize_t ieee80211_if_parse_tkip_mic_test( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + struct ieee80211_local *local = sdata->local; + u8 addr[ETH_ALEN]; + struct sk_buff *skb; + struct ieee80211_hdr *hdr; + __le16 fc; + + if (!mac_pton(buf, addr)) + return -EINVAL; + + if (!ieee80211_sdata_running(sdata)) + return -ENOTCONN; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + 24 + 100); + if (!skb) + return -ENOMEM; + skb_reserve(skb, local->hw.extra_tx_headroom); + + hdr = skb_put_zero(skb, 24); + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | IEEE80211_STYPE_DATA); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS); + /* DA BSSID SA */ + memcpy(hdr->addr1, addr, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + memcpy(hdr->addr3, sdata->vif.addr, ETH_ALEN); + break; + case NL80211_IFTYPE_STATION: + fc |= cpu_to_le16(IEEE80211_FCTL_TODS); + /* BSSID SA DA */ + sdata_lock(sdata); + if (!sdata->u.mgd.associated) { + sdata_unlock(sdata); + dev_kfree_skb(skb); + return -ENOTCONN; + } + memcpy(hdr->addr1, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + memcpy(hdr->addr3, addr, ETH_ALEN); + sdata_unlock(sdata); + break; + default: + dev_kfree_skb(skb); + return -EOPNOTSUPP; + } + hdr->frame_control = fc; + + /* + * Add some length to the test frame to make it look bit more valid. + * The exact contents does not matter since the recipient is required + * to drop this because of the Michael MIC failure. 
+ */ + skb_put_zero(skb, 50); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_TKIP_MIC_FAILURE; + + ieee80211_tx_skb(sdata, skb); + + return buflen; +} +IEEE80211_IF_FILE_W(tkip_mic_test); + +static ssize_t ieee80211_if_parse_beacon_loss( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + if (!ieee80211_sdata_running(sdata) || !sdata->vif.cfg.assoc) + return -ENOTCONN; + + ieee80211_beacon_loss(&sdata->vif); + + return buflen; +} +IEEE80211_IF_FILE_W(beacon_loss); + +static ssize_t ieee80211_if_fmt_uapsd_queues( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + const struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + return snprintf(buf, buflen, "0x%x\n", ifmgd->uapsd_queues); +} + +static ssize_t ieee80211_if_parse_uapsd_queues( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 val; + int ret; + + ret = kstrtou8(buf, 0, &val); + if (ret) + return ret; + + if (val & ~IEEE80211_WMM_IE_STA_QOSINFO_AC_MASK) + return -ERANGE; + + ifmgd->uapsd_queues = val; + + return buflen; +} +IEEE80211_IF_FILE_RW(uapsd_queues); + +static ssize_t ieee80211_if_fmt_uapsd_max_sp_len( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + const struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + return snprintf(buf, buflen, "0x%x\n", ifmgd->uapsd_max_sp_len); +} + +static ssize_t ieee80211_if_parse_uapsd_max_sp_len( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + unsigned long val; + int ret; + + ret = kstrtoul(buf, 0, &val); + if (ret) + return -EINVAL; + + if (val & ~IEEE80211_WMM_IE_STA_QOSINFO_SP_MASK) + return -ERANGE; + + ifmgd->uapsd_max_sp_len = val; + + return buflen; +} +IEEE80211_IF_FILE_RW(uapsd_max_sp_len); + +static ssize_t ieee80211_if_fmt_tdls_wider_bw( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + const struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + bool tdls_wider_bw; + + tdls_wider_bw = ieee80211_hw_check(&sdata->local->hw, TDLS_WIDER_BW) && + !ifmgd->tdls_wider_bw_prohibited; + + return snprintf(buf, buflen, "%d\n", tdls_wider_bw); +} + +static ssize_t ieee80211_if_parse_tdls_wider_bw( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 val; + int ret; + + ret = kstrtou8(buf, 0, &val); + if (ret) + return ret; + + ifmgd->tdls_wider_bw_prohibited = !val; + return buflen; +} +IEEE80211_IF_FILE_RW(tdls_wider_bw); + +/* AP attributes */ +IEEE80211_IF_FILE(num_mcast_sta, u.ap.num_mcast_sta, ATOMIC); +IEEE80211_IF_FILE(num_sta_ps, u.ap.ps.num_sta_ps, ATOMIC); +IEEE80211_IF_FILE(dtim_count, u.ap.ps.dtim_count, DEC); +IEEE80211_IF_FILE(num_mcast_sta_vlan, u.vlan.num_mcast_sta, ATOMIC); + +static ssize_t ieee80211_if_fmt_num_buffered_multicast( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + return scnprintf(buf, buflen, "%u\n", + skb_queue_len(&sdata->u.ap.ps.bc_buf)); +} +IEEE80211_IF_FILE_R(num_buffered_multicast); + +static ssize_t ieee80211_if_fmt_aqm( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + struct ieee80211_local *local = sdata->local; + struct txq_info *txqi; + int len; + + if (!sdata->vif.txq) + return 0; + + txqi = to_txq_info(sdata->vif.txq); + + spin_lock_bh(&local->fq.lock); + rcu_read_lock(); + + len = scnprintf(buf, + buflen, + "ac backlog-bytes backlog-packets new-flows drops marks 
overlimit collisions tx-bytes tx-packets\n" + "%u %u %u %u %u %u %u %u %u %u\n", + txqi->txq.ac, + txqi->tin.backlog_bytes, + txqi->tin.backlog_packets, + txqi->tin.flows, + txqi->cstats.drop_count, + txqi->cstats.ecn_mark, + txqi->tin.overlimit, + txqi->tin.collisions, + txqi->tin.tx_bytes, + txqi->tin.tx_packets); + + rcu_read_unlock(); + spin_unlock_bh(&local->fq.lock); + + return len; +} +IEEE80211_IF_FILE_R(aqm); + +IEEE80211_IF_FILE(multicast_to_unicast, u.ap.multicast_to_unicast, HEX); + +/* IBSS attributes */ +static ssize_t ieee80211_if_fmt_tsf( + const struct ieee80211_sub_if_data *sdata, char *buf, int buflen) +{ + struct ieee80211_local *local = sdata->local; + u64 tsf; + + tsf = drv_get_tsf(local, (struct ieee80211_sub_if_data *)sdata); + + return scnprintf(buf, buflen, "0x%016llx\n", (unsigned long long) tsf); +} + +static ssize_t ieee80211_if_parse_tsf( + struct ieee80211_sub_if_data *sdata, const char *buf, int buflen) +{ + struct ieee80211_local *local = sdata->local; + unsigned long long tsf; + int ret; + int tsf_is_delta = 0; + + if (strncmp(buf, "reset", 5) == 0) { + if (local->ops->reset_tsf) { + drv_reset_tsf(local, sdata); + wiphy_info(local->hw.wiphy, "debugfs reset TSF\n"); + } + } else { + if (buflen > 10 && buf[1] == '=') { + if (buf[0] == '+') + tsf_is_delta = 1; + else if (buf[0] == '-') + tsf_is_delta = -1; + else + return -EINVAL; + buf += 2; + } + ret = kstrtoull(buf, 10, &tsf); + if (ret < 0) + return ret; + if (tsf_is_delta && local->ops->offset_tsf) { + drv_offset_tsf(local, sdata, tsf_is_delta * tsf); + wiphy_info(local->hw.wiphy, + "debugfs offset TSF by %018lld\n", + tsf_is_delta * tsf); + } else if (local->ops->set_tsf) { + if (tsf_is_delta) + tsf = drv_get_tsf(local, sdata) + + tsf_is_delta * tsf; + drv_set_tsf(local, sdata, tsf); + wiphy_info(local->hw.wiphy, + "debugfs set TSF to %#018llx\n", tsf); + } + } + + ieee80211_recalc_dtim(local, sdata); + return buflen; +} +IEEE80211_IF_FILE_RW(tsf); + +static ssize_t ieee80211_if_fmt_valid_links(const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + return snprintf(buf, buflen, "0x%x\n", sdata->vif.valid_links); +} +IEEE80211_IF_FILE_R(valid_links); + +static ssize_t ieee80211_if_fmt_active_links(const struct ieee80211_sub_if_data *sdata, + char *buf, int buflen) +{ + return snprintf(buf, buflen, "0x%x\n", sdata->vif.active_links); +} + +static ssize_t ieee80211_if_parse_active_links(struct ieee80211_sub_if_data *sdata, + const char *buf, int buflen) +{ + u16 active_links; + + if (kstrtou16(buf, 0, &active_links)) + return -EINVAL; + + return ieee80211_set_active_links(&sdata->vif, active_links) ?: buflen; +} +IEEE80211_IF_FILE_RW(active_links); + +#ifdef CONFIG_MAC80211_MESH +IEEE80211_IF_FILE(estab_plinks, u.mesh.estab_plinks, ATOMIC); + +/* Mesh stats attributes */ +IEEE80211_IF_FILE(fwded_mcast, u.mesh.mshstats.fwded_mcast, DEC); +IEEE80211_IF_FILE(fwded_unicast, u.mesh.mshstats.fwded_unicast, DEC); +IEEE80211_IF_FILE(fwded_frames, u.mesh.mshstats.fwded_frames, DEC); +IEEE80211_IF_FILE(dropped_frames_ttl, u.mesh.mshstats.dropped_frames_ttl, DEC); +IEEE80211_IF_FILE(dropped_frames_congestion, + u.mesh.mshstats.dropped_frames_congestion, DEC); +IEEE80211_IF_FILE(dropped_frames_no_route, + u.mesh.mshstats.dropped_frames_no_route, DEC); + +/* Mesh parameters */ +IEEE80211_IF_FILE(dot11MeshMaxRetries, + u.mesh.mshcfg.dot11MeshMaxRetries, DEC); +IEEE80211_IF_FILE(dot11MeshRetryTimeout, + u.mesh.mshcfg.dot11MeshRetryTimeout, DEC); +IEEE80211_IF_FILE(dot11MeshConfirmTimeout, + 
u.mesh.mshcfg.dot11MeshConfirmTimeout, DEC); +IEEE80211_IF_FILE(dot11MeshHoldingTimeout, + u.mesh.mshcfg.dot11MeshHoldingTimeout, DEC); +IEEE80211_IF_FILE(dot11MeshTTL, u.mesh.mshcfg.dot11MeshTTL, DEC); +IEEE80211_IF_FILE(element_ttl, u.mesh.mshcfg.element_ttl, DEC); +IEEE80211_IF_FILE(auto_open_plinks, u.mesh.mshcfg.auto_open_plinks, DEC); +IEEE80211_IF_FILE(dot11MeshMaxPeerLinks, + u.mesh.mshcfg.dot11MeshMaxPeerLinks, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPactivePathTimeout, + u.mesh.mshcfg.dot11MeshHWMPactivePathTimeout, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPpreqMinInterval, + u.mesh.mshcfg.dot11MeshHWMPpreqMinInterval, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPperrMinInterval, + u.mesh.mshcfg.dot11MeshHWMPperrMinInterval, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPnetDiameterTraversalTime, + u.mesh.mshcfg.dot11MeshHWMPnetDiameterTraversalTime, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPmaxPREQretries, + u.mesh.mshcfg.dot11MeshHWMPmaxPREQretries, DEC); +IEEE80211_IF_FILE(path_refresh_time, + u.mesh.mshcfg.path_refresh_time, DEC); +IEEE80211_IF_FILE(min_discovery_timeout, + u.mesh.mshcfg.min_discovery_timeout, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPRootMode, + u.mesh.mshcfg.dot11MeshHWMPRootMode, DEC); +IEEE80211_IF_FILE(dot11MeshGateAnnouncementProtocol, + u.mesh.mshcfg.dot11MeshGateAnnouncementProtocol, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPRannInterval, + u.mesh.mshcfg.dot11MeshHWMPRannInterval, DEC); +IEEE80211_IF_FILE(dot11MeshForwarding, u.mesh.mshcfg.dot11MeshForwarding, DEC); +IEEE80211_IF_FILE(rssi_threshold, u.mesh.mshcfg.rssi_threshold, DEC); +IEEE80211_IF_FILE(ht_opmode, u.mesh.mshcfg.ht_opmode, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPactivePathToRootTimeout, + u.mesh.mshcfg.dot11MeshHWMPactivePathToRootTimeout, DEC); +IEEE80211_IF_FILE(dot11MeshHWMProotInterval, + u.mesh.mshcfg.dot11MeshHWMProotInterval, DEC); +IEEE80211_IF_FILE(dot11MeshHWMPconfirmationInterval, + u.mesh.mshcfg.dot11MeshHWMPconfirmationInterval, DEC); +IEEE80211_IF_FILE(power_mode, u.mesh.mshcfg.power_mode, DEC); +IEEE80211_IF_FILE(dot11MeshAwakeWindowDuration, + u.mesh.mshcfg.dot11MeshAwakeWindowDuration, DEC); +IEEE80211_IF_FILE(dot11MeshConnectedToMeshGate, + u.mesh.mshcfg.dot11MeshConnectedToMeshGate, DEC); +IEEE80211_IF_FILE(dot11MeshNolearn, u.mesh.mshcfg.dot11MeshNolearn, DEC); +IEEE80211_IF_FILE(dot11MeshConnectedToAuthServer, + u.mesh.mshcfg.dot11MeshConnectedToAuthServer, DEC); +#endif + +#define DEBUGFS_ADD_MODE(name, mode) \ + debugfs_create_file(#name, mode, sdata->vif.debugfs_dir, \ + sdata, &name##_ops) + +#define DEBUGFS_ADD(name) DEBUGFS_ADD_MODE(name, 0400) + +static void add_common_files(struct ieee80211_sub_if_data *sdata) +{ + DEBUGFS_ADD(rc_rateidx_mask_2ghz); + DEBUGFS_ADD(rc_rateidx_mask_5ghz); + DEBUGFS_ADD(rc_rateidx_mcs_mask_2ghz); + DEBUGFS_ADD(rc_rateidx_mcs_mask_5ghz); + DEBUGFS_ADD(rc_rateidx_vht_mcs_mask_2ghz); + DEBUGFS_ADD(rc_rateidx_vht_mcs_mask_5ghz); + DEBUGFS_ADD(hw_queues); + + if (sdata->local->ops->wake_tx_queue && + sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && + sdata->vif.type != NL80211_IFTYPE_NAN) + DEBUGFS_ADD(aqm); +} + +static void add_sta_files(struct ieee80211_sub_if_data *sdata) +{ + DEBUGFS_ADD(bssid); + DEBUGFS_ADD(aid); + DEBUGFS_ADD(beacon_timeout); + DEBUGFS_ADD_MODE(smps, 0600); + DEBUGFS_ADD_MODE(tkip_mic_test, 0200); + DEBUGFS_ADD_MODE(beacon_loss, 0200); + DEBUGFS_ADD_MODE(uapsd_queues, 0600); + DEBUGFS_ADD_MODE(uapsd_max_sp_len, 0600); + DEBUGFS_ADD_MODE(tdls_wider_bw, 0600); + DEBUGFS_ADD_MODE(valid_links, 0400); + DEBUGFS_ADD_MODE(active_links, 0600); +} + 
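Each DEBUGFS_ADD(name) call above registers a name##_ops structure that was generated earlier by the IEEE80211_IF_FILE() family of macros. As an illustrative sketch (formatting and comments added), the single line IEEE80211_IF_FILE(aid, vif.cfg.aid, DEC) expands to roughly:

/* Sketch of the IEEE80211_IF_FILE(aid, vif.cfg.aid, DEC) expansion;
 * illustrative only.
 */
static ssize_t ieee80211_if_fmt_aid(const struct ieee80211_sub_if_data *sdata,
				    char *buf, int buflen)
{
	return scnprintf(buf, buflen, "%d\n", sdata->vif.cfg.aid);
}

static ssize_t ieee80211_if_read_aid(struct file *file, char __user *userbuf,
				     size_t count, loff_t *ppos)
{
	/* ieee80211_if_read() formats into a stack buffer under
	 * dev_base_lock and then copies the result to user space.
	 */
	return ieee80211_if_read(file->private_data, userbuf, count, ppos,
				 ieee80211_if_fmt_aid);
}

static const struct file_operations aid_ops = {
	.read = ieee80211_if_read_aid,
	.write = NULL,
	.open = simple_open,
	.llseek = generic_file_llseek,
};

Writable attributes (IEEE80211_IF_FILE_W/_RW) additionally generate an ieee80211_if_write_##name wrapper that copies the user buffer and calls the corresponding ieee80211_if_parse_##name handler under the RTNL.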
+static void add_ap_files(struct ieee80211_sub_if_data *sdata) +{ + DEBUGFS_ADD(num_mcast_sta); + DEBUGFS_ADD_MODE(smps, 0600); + DEBUGFS_ADD(num_sta_ps); + DEBUGFS_ADD(dtim_count); + DEBUGFS_ADD(num_buffered_multicast); + DEBUGFS_ADD_MODE(tkip_mic_test, 0200); + DEBUGFS_ADD_MODE(multicast_to_unicast, 0600); +} + +static void add_vlan_files(struct ieee80211_sub_if_data *sdata) +{ + /* add num_mcast_sta_vlan using name num_mcast_sta */ + debugfs_create_file("num_mcast_sta", 0400, sdata->vif.debugfs_dir, + sdata, &num_mcast_sta_vlan_ops); +} + +static void add_ibss_files(struct ieee80211_sub_if_data *sdata) +{ + DEBUGFS_ADD_MODE(tsf, 0600); +} + +#ifdef CONFIG_MAC80211_MESH + +static void add_mesh_files(struct ieee80211_sub_if_data *sdata) +{ + DEBUGFS_ADD_MODE(tsf, 0600); + DEBUGFS_ADD_MODE(estab_plinks, 0400); +} + +static void add_mesh_stats(struct ieee80211_sub_if_data *sdata) +{ + struct dentry *dir = debugfs_create_dir("mesh_stats", + sdata->vif.debugfs_dir); +#define MESHSTATS_ADD(name)\ + debugfs_create_file(#name, 0400, dir, sdata, &name##_ops) + + MESHSTATS_ADD(fwded_mcast); + MESHSTATS_ADD(fwded_unicast); + MESHSTATS_ADD(fwded_frames); + MESHSTATS_ADD(dropped_frames_ttl); + MESHSTATS_ADD(dropped_frames_no_route); + MESHSTATS_ADD(dropped_frames_congestion); +#undef MESHSTATS_ADD +} + +static void add_mesh_config(struct ieee80211_sub_if_data *sdata) +{ + struct dentry *dir = debugfs_create_dir("mesh_config", + sdata->vif.debugfs_dir); + +#define MESHPARAMS_ADD(name) \ + debugfs_create_file(#name, 0600, dir, sdata, &name##_ops) + + MESHPARAMS_ADD(dot11MeshMaxRetries); + MESHPARAMS_ADD(dot11MeshRetryTimeout); + MESHPARAMS_ADD(dot11MeshConfirmTimeout); + MESHPARAMS_ADD(dot11MeshHoldingTimeout); + MESHPARAMS_ADD(dot11MeshTTL); + MESHPARAMS_ADD(element_ttl); + MESHPARAMS_ADD(auto_open_plinks); + MESHPARAMS_ADD(dot11MeshMaxPeerLinks); + MESHPARAMS_ADD(dot11MeshHWMPactivePathTimeout); + MESHPARAMS_ADD(dot11MeshHWMPpreqMinInterval); + MESHPARAMS_ADD(dot11MeshHWMPperrMinInterval); + MESHPARAMS_ADD(dot11MeshHWMPnetDiameterTraversalTime); + MESHPARAMS_ADD(dot11MeshHWMPmaxPREQretries); + MESHPARAMS_ADD(path_refresh_time); + MESHPARAMS_ADD(min_discovery_timeout); + MESHPARAMS_ADD(dot11MeshHWMPRootMode); + MESHPARAMS_ADD(dot11MeshHWMPRannInterval); + MESHPARAMS_ADD(dot11MeshForwarding); + MESHPARAMS_ADD(dot11MeshGateAnnouncementProtocol); + MESHPARAMS_ADD(rssi_threshold); + MESHPARAMS_ADD(ht_opmode); + MESHPARAMS_ADD(dot11MeshHWMPactivePathToRootTimeout); + MESHPARAMS_ADD(dot11MeshHWMProotInterval); + MESHPARAMS_ADD(dot11MeshHWMPconfirmationInterval); + MESHPARAMS_ADD(power_mode); + MESHPARAMS_ADD(dot11MeshAwakeWindowDuration); + MESHPARAMS_ADD(dot11MeshConnectedToMeshGate); + MESHPARAMS_ADD(dot11MeshNolearn); + MESHPARAMS_ADD(dot11MeshConnectedToAuthServer); +#undef MESHPARAMS_ADD +} +#endif + +static void add_files(struct ieee80211_sub_if_data *sdata) +{ + if (!sdata->vif.debugfs_dir) + return; + + DEBUGFS_ADD(flags); + DEBUGFS_ADD(state); + DEBUGFS_ADD(txpower); + DEBUGFS_ADD(user_power_level); + DEBUGFS_ADD(ap_power_level); + + if (sdata->vif.type != NL80211_IFTYPE_MONITOR) + add_common_files(sdata); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_MESH_POINT: +#ifdef CONFIG_MAC80211_MESH + add_mesh_files(sdata); + add_mesh_stats(sdata); + add_mesh_config(sdata); +#endif + break; + case NL80211_IFTYPE_STATION: + add_sta_files(sdata); + break; + case NL80211_IFTYPE_ADHOC: + add_ibss_files(sdata); + break; + case NL80211_IFTYPE_AP: + add_ap_files(sdata); + break; + case 
NL80211_IFTYPE_AP_VLAN: + add_vlan_files(sdata); + break; + default: + break; + } +} + +void ieee80211_debugfs_add_netdev(struct ieee80211_sub_if_data *sdata) +{ + char buf[10+IFNAMSIZ]; + + sprintf(buf, "netdev:%s", sdata->name); + sdata->vif.debugfs_dir = debugfs_create_dir(buf, + sdata->local->hw.wiphy->debugfsdir); + sdata->debugfs.subdir_stations = debugfs_create_dir("stations", + sdata->vif.debugfs_dir); + add_files(sdata); +} + +void ieee80211_debugfs_remove_netdev(struct ieee80211_sub_if_data *sdata) +{ + if (!sdata->vif.debugfs_dir) + return; + + debugfs_remove_recursive(sdata->vif.debugfs_dir); + sdata->vif.debugfs_dir = NULL; + sdata->debugfs.subdir_stations = NULL; +} + +void ieee80211_debugfs_rename_netdev(struct ieee80211_sub_if_data *sdata) +{ + struct dentry *dir; + char buf[10 + IFNAMSIZ]; + + dir = sdata->vif.debugfs_dir; + + if (IS_ERR_OR_NULL(dir)) + return; + + sprintf(buf, "netdev:%s", sdata->name); + debugfs_rename(dir->d_parent, dir, dir->d_parent, buf); +} diff --git a/net/mac80211/debugfs_netdev.h b/net/mac80211/debugfs_netdev.h new file mode 100644 index 000000000..a7e9d8d51 --- /dev/null +++ b/net/mac80211/debugfs_netdev.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* routines exported for debugfs handling */ + +#ifndef __IEEE80211_DEBUGFS_NETDEV_H +#define __IEEE80211_DEBUGFS_NETDEV_H + +#include "ieee80211_i.h" + +#ifdef CONFIG_MAC80211_DEBUGFS +void ieee80211_debugfs_add_netdev(struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_remove_netdev(struct ieee80211_sub_if_data *sdata); +void ieee80211_debugfs_rename_netdev(struct ieee80211_sub_if_data *sdata); +#else +static inline void ieee80211_debugfs_add_netdev( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_remove_netdev( + struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211_debugfs_rename_netdev( + struct ieee80211_sub_if_data *sdata) +{} +#endif + +#endif /* __IEEE80211_DEBUGFS_NETDEV_H */ diff --git a/net/mac80211/debugfs_sta.c b/net/mac80211/debugfs_sta.c new file mode 100644 index 000000000..b057253db --- /dev/null +++ b/net/mac80211/debugfs_sta.c @@ -0,0 +1,1079 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2003-2005 Devicescape Software, Inc. 
+ * Copyright (c) 2006 Jiri Benc + * Copyright 2007 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright(c) 2016 Intel Deutschland GmbH + * Copyright (C) 2018 - 2021 Intel Corporation + */ + +#include +#include +#include "ieee80211_i.h" +#include "debugfs.h" +#include "debugfs_sta.h" +#include "sta_info.h" +#include "driver-ops.h" + +/* sta attributtes */ + +#define STA_READ(name, field, format_string) \ +static ssize_t sta_ ##name## _read(struct file *file, \ + char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + struct sta_info *sta = file->private_data; \ + return mac80211_format_buffer(userbuf, count, ppos, \ + format_string, sta->field); \ +} +#define STA_READ_D(name, field) STA_READ(name, field, "%d\n") + +#define STA_OPS(name) \ +static const struct file_operations sta_ ##name## _ops = { \ + .read = sta_##name##_read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define STA_OPS_RW(name) \ +static const struct file_operations sta_ ##name## _ops = { \ + .read = sta_##name##_read, \ + .write = sta_##name##_write, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +#define STA_FILE(name, field, format) \ + STA_READ_##format(name, field) \ + STA_OPS(name) + +STA_FILE(aid, sta.aid, D); + +static const char * const sta_flag_names[] = { +#define FLAG(F) [WLAN_STA_##F] = #F + FLAG(AUTH), + FLAG(ASSOC), + FLAG(PS_STA), + FLAG(AUTHORIZED), + FLAG(SHORT_PREAMBLE), + FLAG(WDS), + FLAG(CLEAR_PS_FILT), + FLAG(MFP), + FLAG(BLOCK_BA), + FLAG(PS_DRIVER), + FLAG(PSPOLL), + FLAG(TDLS_PEER), + FLAG(TDLS_PEER_AUTH), + FLAG(TDLS_INITIATOR), + FLAG(TDLS_CHAN_SWITCH), + FLAG(TDLS_OFF_CHANNEL), + FLAG(TDLS_WIDER_BW), + FLAG(UAPSD), + FLAG(SP), + FLAG(4ADDR_EVENT), + FLAG(INSERTED), + FLAG(RATE_CONTROL), + FLAG(TOFFSET_KNOWN), + FLAG(MPSP_OWNER), + FLAG(MPSP_RECIPIENT), + FLAG(PS_DELIVER), + FLAG(USES_ENCRYPTION), + FLAG(DECAP_OFFLOAD), +#undef FLAG +}; + +static ssize_t sta_flags_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + char buf[16 * NUM_WLAN_STA_FLAGS], *pos = buf; + char *end = buf + sizeof(buf) - 1; + struct sta_info *sta = file->private_data; + unsigned int flg; + + BUILD_BUG_ON(ARRAY_SIZE(sta_flag_names) != NUM_WLAN_STA_FLAGS); + + for (flg = 0; flg < NUM_WLAN_STA_FLAGS; flg++) { + if (test_sta_flag(sta, flg)) + pos += scnprintf(pos, end - pos, "%s\n", + sta_flag_names[flg]); + } + + return simple_read_from_buffer(userbuf, count, ppos, buf, strlen(buf)); +} +STA_OPS(flags); + +static ssize_t sta_num_ps_buf_frames_read(struct file *file, + char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + char buf[17*IEEE80211_NUM_ACS], *p = buf; + int ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + p += scnprintf(p, sizeof(buf)+buf-p, "AC%d: %d\n", ac, + skb_queue_len(&sta->ps_tx_buf[ac]) + + skb_queue_len(&sta->tx_filtered[ac])); + return simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); +} +STA_OPS(num_ps_buf_frames); + +static ssize_t sta_last_seq_ctrl_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + char buf[15*IEEE80211_NUM_TIDS], *p = buf; + int i; + struct sta_info *sta = file->private_data; + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + p += scnprintf(p, sizeof(buf)+buf-p, "%x ", + le16_to_cpu(sta->last_seq_ctrl[i])); + p += scnprintf(p, sizeof(buf)+buf-p, "\n"); + return simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); +} +STA_OPS(last_seq_ctrl); + +#define AQM_TXQ_ENTRY_LEN 130 + 
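The per-station attributes below end up as ordinary debugfs files in the per-peer stations/<addr> directory (the "netdev:<ifname>/stations/<addr>" layout also appears in the key symlinks earlier in this patch). A minimal user-space sketch for dumping one of them is shown here; the debugfs mount point, phy name, interface name and station address are assumed placeholders, not values defined by this patch:

/* Illustrative user-space reader for a per-station debugfs file; the
 * path components below (phy0, wlan0, the station address) are assumed
 * examples for a typical setup.
 */
#include <fcntl.h>
#include <stdio.h>
#include <unistd.h>

int main(void)
{
	static const char path[] =
		"/sys/kernel/debug/ieee80211/phy0/"
		"netdev:wlan0/stations/02:00:00:00:01:00/aqm";
	char buf[4096];
	ssize_t n;
	int fd = open(path, O_RDONLY);

	if (fd < 0) {
		perror("open");
		return 1;
	}
	/* Dump the file contents produced by the kernel-side read handler. */
	while ((n = read(fd, buf, sizeof(buf))) > 0)
		fwrite(buf, 1, (size_t)n, stdout);
	close(fd);
	return 0;
}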
+static ssize_t sta_aqm_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->local; + size_t bufsz = AQM_TXQ_ENTRY_LEN * (IEEE80211_NUM_TIDS + 2); + char *buf = kzalloc(bufsz, GFP_KERNEL), *p = buf; + struct txq_info *txqi; + ssize_t rv; + int i; + + if (!buf) + return -ENOMEM; + + spin_lock_bh(&local->fq.lock); + rcu_read_lock(); + + p += scnprintf(p, + bufsz + buf - p, + "target %uus interval %uus ecn %s\n", + codel_time_to_us(sta->cparams.target), + codel_time_to_us(sta->cparams.interval), + sta->cparams.ecn ? "yes" : "no"); + p += scnprintf(p, + bufsz + buf - p, + "tid ac backlog-bytes backlog-packets new-flows drops marks overlimit collisions tx-bytes tx-packets flags\n"); + + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + if (!sta->sta.txq[i]) + continue; + txqi = to_txq_info(sta->sta.txq[i]); + p += scnprintf(p, bufsz + buf - p, + "%d %d %u %u %u %u %u %u %u %u %u 0x%lx(%s%s%s%s)\n", + txqi->txq.tid, + txqi->txq.ac, + txqi->tin.backlog_bytes, + txqi->tin.backlog_packets, + txqi->tin.flows, + txqi->cstats.drop_count, + txqi->cstats.ecn_mark, + txqi->tin.overlimit, + txqi->tin.collisions, + txqi->tin.tx_bytes, + txqi->tin.tx_packets, + txqi->flags, + test_bit(IEEE80211_TXQ_STOP, &txqi->flags) ? "STOP" : "RUN", + test_bit(IEEE80211_TXQ_AMPDU, &txqi->flags) ? " AMPDU" : "", + test_bit(IEEE80211_TXQ_NO_AMSDU, &txqi->flags) ? " NO-AMSDU" : "", + test_bit(IEEE80211_TXQ_DIRTY, &txqi->flags) ? " DIRTY" : ""); + } + + rcu_read_unlock(); + spin_unlock_bh(&local->fq.lock); + + rv = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return rv; +} +STA_OPS(aqm); + +static ssize_t sta_airtime_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->sdata->local; + size_t bufsz = 400; + char *buf = kzalloc(bufsz, GFP_KERNEL), *p = buf; + u64 rx_airtime = 0, tx_airtime = 0; + s32 deficit[IEEE80211_NUM_ACS]; + ssize_t rv; + int ac; + + if (!buf) + return -ENOMEM; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + spin_lock_bh(&local->active_txq_lock[ac]); + rx_airtime += sta->airtime[ac].rx_airtime; + tx_airtime += sta->airtime[ac].tx_airtime; + deficit[ac] = sta->airtime[ac].deficit; + spin_unlock_bh(&local->active_txq_lock[ac]); + } + + p += scnprintf(p, bufsz + buf - p, + "RX: %llu us\nTX: %llu us\nWeight: %u\n" + "Deficit: VO: %d us VI: %d us BE: %d us BK: %d us\n", + rx_airtime, tx_airtime, sta->airtime_weight, + deficit[0], deficit[1], deficit[2], deficit[3]); + + rv = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return rv; +} + +static ssize_t sta_airtime_write(struct file *file, const char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->sdata->local; + int ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + spin_lock_bh(&local->active_txq_lock[ac]); + sta->airtime[ac].rx_airtime = 0; + sta->airtime[ac].tx_airtime = 0; + sta->airtime[ac].deficit = sta->airtime_weight; + spin_unlock_bh(&local->active_txq_lock[ac]); + } + + return count; +} +STA_OPS_RW(airtime); + +static ssize_t sta_aql_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + struct ieee80211_local *local = sta->sdata->local; + size_t bufsz = 400; + char *buf = kzalloc(bufsz, GFP_KERNEL), *p = buf; + 
u32 q_depth[IEEE80211_NUM_ACS]; + u32 q_limit_l[IEEE80211_NUM_ACS], q_limit_h[IEEE80211_NUM_ACS]; + ssize_t rv; + int ac; + + if (!buf) + return -ENOMEM; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + spin_lock_bh(&local->active_txq_lock[ac]); + q_limit_l[ac] = sta->airtime[ac].aql_limit_low; + q_limit_h[ac] = sta->airtime[ac].aql_limit_high; + spin_unlock_bh(&local->active_txq_lock[ac]); + q_depth[ac] = atomic_read(&sta->airtime[ac].aql_tx_pending); + } + + p += scnprintf(p, bufsz + buf - p, + "Q depth: VO: %u us VI: %u us BE: %u us BK: %u us\n" + "Q limit[low/high]: VO: %u/%u VI: %u/%u BE: %u/%u BK: %u/%u\n", + q_depth[0], q_depth[1], q_depth[2], q_depth[3], + q_limit_l[0], q_limit_h[0], q_limit_l[1], q_limit_h[1], + q_limit_l[2], q_limit_h[2], q_limit_l[3], q_limit_h[3]); + + rv = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return rv; +} + +static ssize_t sta_aql_write(struct file *file, const char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct sta_info *sta = file->private_data; + u32 ac, q_limit_l, q_limit_h; + char _buf[100] = {}, *buf = _buf; + + if (count > sizeof(_buf)) + return -EINVAL; + + if (copy_from_user(buf, userbuf, count)) + return -EFAULT; + + buf[sizeof(_buf) - 1] = '\0'; + if (sscanf(buf, "limit %u %u %u", &ac, &q_limit_l, &q_limit_h) + != 3) + return -EINVAL; + + if (ac >= IEEE80211_NUM_ACS) + return -EINVAL; + + sta->airtime[ac].aql_limit_low = q_limit_l; + sta->airtime[ac].aql_limit_high = q_limit_h; + + return count; +} +STA_OPS_RW(aql); + + +static ssize_t sta_agg_status_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + char *buf, *p; + ssize_t bufsz = 71 + IEEE80211_NUM_TIDS * 40; + int i; + struct sta_info *sta = file->private_data; + struct tid_ampdu_rx *tid_rx; + struct tid_ampdu_tx *tid_tx; + ssize_t ret; + + buf = kzalloc(bufsz, GFP_KERNEL); + if (!buf) + return -ENOMEM; + p = buf; + + rcu_read_lock(); + + p += scnprintf(p, bufsz + buf - p, "next dialog_token: %#02x\n", + sta->ampdu_mlme.dialog_token_allocator + 1); + p += scnprintf(p, bufsz + buf - p, + "TID\t\tRX\tDTKN\tSSN\t\tTX\tDTKN\tpending\n"); + + for (i = 0; i < IEEE80211_NUM_TIDS; i++) { + bool tid_rx_valid; + + tid_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[i]); + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[i]); + tid_rx_valid = test_bit(i, sta->ampdu_mlme.agg_session_valid); + + p += scnprintf(p, bufsz + buf - p, "%02d", i); + p += scnprintf(p, bufsz + buf - p, "\t\t%x", + tid_rx_valid); + p += scnprintf(p, bufsz + buf - p, "\t%#.2x", + tid_rx_valid ? + sta->ampdu_mlme.tid_rx_token[i] : 0); + p += scnprintf(p, bufsz + buf - p, "\t%#.3x", + tid_rx ? tid_rx->ssn : 0); + + p += scnprintf(p, bufsz + buf - p, "\t\t%x", !!tid_tx); + p += scnprintf(p, bufsz + buf - p, "\t%#.2x", + tid_tx ? tid_tx->dialog_token : 0); + p += scnprintf(p, bufsz + buf - p, "\t%03d", + tid_tx ? 
skb_queue_len(&tid_tx->pending) : 0); + p += scnprintf(p, bufsz + buf - p, "\n"); + } + rcu_read_unlock(); + + ret = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return ret; +} + +static ssize_t sta_agg_status_write(struct file *file, const char __user *userbuf, + size_t count, loff_t *ppos) +{ + char _buf[25] = {}, *buf = _buf; + struct sta_info *sta = file->private_data; + bool start, tx; + unsigned long tid; + char *pos; + int ret, timeout = 5000; + + if (count > sizeof(_buf)) + return -EINVAL; + + if (copy_from_user(buf, userbuf, count)) + return -EFAULT; + + buf[sizeof(_buf) - 1] = '\0'; + pos = buf; + buf = strsep(&pos, " "); + if (!buf) + return -EINVAL; + + if (!strcmp(buf, "tx")) + tx = true; + else if (!strcmp(buf, "rx")) + tx = false; + else + return -EINVAL; + + buf = strsep(&pos, " "); + if (!buf) + return -EINVAL; + if (!strcmp(buf, "start")) { + start = true; + if (!tx) + return -EINVAL; + } else if (!strcmp(buf, "stop")) { + start = false; + } else { + return -EINVAL; + } + + buf = strsep(&pos, " "); + if (!buf) + return -EINVAL; + if (sscanf(buf, "timeout=%d", &timeout) == 1) { + buf = strsep(&pos, " "); + if (!buf || !tx || !start) + return -EINVAL; + } + + ret = kstrtoul(buf, 0, &tid); + if (ret || tid >= IEEE80211_NUM_TIDS) + return -EINVAL; + + if (tx) { + if (start) + ret = ieee80211_start_tx_ba_session(&sta->sta, tid, + timeout); + else + ret = ieee80211_stop_tx_ba_session(&sta->sta, tid); + } else { + __ieee80211_stop_rx_ba_session(sta, tid, WLAN_BACK_RECIPIENT, + 3, true); + ret = 0; + } + + return ret ?: count; +} +STA_OPS_RW(agg_status); + +static ssize_t sta_ht_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ +#define PRINT_HT_CAP(_cond, _str) \ + do { \ + if (_cond) \ + p += scnprintf(p, bufsz + buf - p, "\t" _str "\n"); \ + } while (0) + char *buf, *p; + int i; + ssize_t bufsz = 512; + struct sta_info *sta = file->private_data; + struct ieee80211_sta_ht_cap *htc = &sta->sta.deflink.ht_cap; + ssize_t ret; + + buf = kzalloc(bufsz, GFP_KERNEL); + if (!buf) + return -ENOMEM; + p = buf; + + p += scnprintf(p, bufsz + buf - p, "ht %ssupported\n", + htc->ht_supported ? "" : "not "); + if (htc->ht_supported) { + p += scnprintf(p, bufsz + buf - p, "cap: %#.4x\n", htc->cap); + + PRINT_HT_CAP((htc->cap & BIT(0)), "RX LDPC"); + PRINT_HT_CAP((htc->cap & BIT(1)), "HT20/HT40"); + PRINT_HT_CAP(!(htc->cap & BIT(1)), "HT20"); + + PRINT_HT_CAP(((htc->cap >> 2) & 0x3) == 0, "Static SM Power Save"); + PRINT_HT_CAP(((htc->cap >> 2) & 0x3) == 1, "Dynamic SM Power Save"); + PRINT_HT_CAP(((htc->cap >> 2) & 0x3) == 3, "SM Power Save disabled"); + + PRINT_HT_CAP((htc->cap & BIT(4)), "RX Greenfield"); + PRINT_HT_CAP((htc->cap & BIT(5)), "RX HT20 SGI"); + PRINT_HT_CAP((htc->cap & BIT(6)), "RX HT40 SGI"); + PRINT_HT_CAP((htc->cap & BIT(7)), "TX STBC"); + + PRINT_HT_CAP(((htc->cap >> 8) & 0x3) == 0, "No RX STBC"); + PRINT_HT_CAP(((htc->cap >> 8) & 0x3) == 1, "RX STBC 1-stream"); + PRINT_HT_CAP(((htc->cap >> 8) & 0x3) == 2, "RX STBC 2-streams"); + PRINT_HT_CAP(((htc->cap >> 8) & 0x3) == 3, "RX STBC 3-streams"); + + PRINT_HT_CAP((htc->cap & BIT(10)), "HT Delayed Block Ack"); + + PRINT_HT_CAP(!(htc->cap & BIT(11)), "Max AMSDU length: " + "3839 bytes"); + PRINT_HT_CAP((htc->cap & BIT(11)), "Max AMSDU length: " + "7935 bytes"); + + /* + * For beacons and probe response this would mean the BSS + * does or does not allow the usage of DSSS/CCK HT40. + * Otherwise it means the STA does or does not use + * DSSS/CCK HT40. 
+ */ + PRINT_HT_CAP((htc->cap & BIT(12)), "DSSS/CCK HT40"); + PRINT_HT_CAP(!(htc->cap & BIT(12)), "No DSSS/CCK HT40"); + + /* BIT(13) is reserved */ + + PRINT_HT_CAP((htc->cap & BIT(14)), "40 MHz Intolerant"); + + PRINT_HT_CAP((htc->cap & BIT(15)), "L-SIG TXOP protection"); + + p += scnprintf(p, bufsz + buf - p, "ampdu factor/density: %d/%d\n", + htc->ampdu_factor, htc->ampdu_density); + p += scnprintf(p, bufsz + buf - p, "MCS mask:"); + + for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++) + p += scnprintf(p, bufsz + buf - p, " %.2x", + htc->mcs.rx_mask[i]); + p += scnprintf(p, bufsz + buf - p, "\n"); + + /* If not set this is meaningless */ + if (le16_to_cpu(htc->mcs.rx_highest)) { + p += scnprintf(p, bufsz + buf - p, + "MCS rx highest: %d Mbps\n", + le16_to_cpu(htc->mcs.rx_highest)); + } + + p += scnprintf(p, bufsz + buf - p, "MCS tx params: %x\n", + htc->mcs.tx_params); + } + + ret = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return ret; +} +STA_OPS(ht_capa); + +static ssize_t sta_vht_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + char *buf, *p; + struct sta_info *sta = file->private_data; + struct ieee80211_sta_vht_cap *vhtc = &sta->sta.deflink.vht_cap; + ssize_t ret; + ssize_t bufsz = 512; + + buf = kzalloc(bufsz, GFP_KERNEL); + if (!buf) + return -ENOMEM; + p = buf; + + p += scnprintf(p, bufsz + buf - p, "VHT %ssupported\n", + vhtc->vht_supported ? "" : "not "); + if (vhtc->vht_supported) { + p += scnprintf(p, bufsz + buf - p, "cap: %#.8x\n", + vhtc->cap); +#define PFLAG(a, b) \ + do { \ + if (vhtc->cap & IEEE80211_VHT_CAP_ ## a) \ + p += scnprintf(p, bufsz + buf - p, \ + "\t\t%s\n", b); \ + } while (0) + + switch (vhtc->cap & 0x3) { + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_3895: + p += scnprintf(p, bufsz + buf - p, + "\t\tMAX-MPDU-3895\n"); + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_7991: + p += scnprintf(p, bufsz + buf - p, + "\t\tMAX-MPDU-7991\n"); + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_11454: + p += scnprintf(p, bufsz + buf - p, + "\t\tMAX-MPDU-11454\n"); + break; + default: + p += scnprintf(p, bufsz + buf - p, + "\t\tMAX-MPDU-UNKNOWN\n"); + } + switch (vhtc->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK) { + case 0: + p += scnprintf(p, bufsz + buf - p, + "\t\t80Mhz\n"); + break; + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ: + p += scnprintf(p, bufsz + buf - p, + "\t\t160Mhz\n"); + break; + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ: + p += scnprintf(p, bufsz + buf - p, + "\t\t80+80Mhz\n"); + break; + default: + p += scnprintf(p, bufsz + buf - p, + "\t\tUNKNOWN-MHZ: 0x%x\n", + (vhtc->cap >> 2) & 0x3); + } + PFLAG(RXLDPC, "RXLDPC"); + PFLAG(SHORT_GI_80, "SHORT-GI-80"); + PFLAG(SHORT_GI_160, "SHORT-GI-160"); + PFLAG(TXSTBC, "TXSTBC"); + p += scnprintf(p, bufsz + buf - p, + "\t\tRXSTBC_%d\n", (vhtc->cap >> 8) & 0x7); + PFLAG(SU_BEAMFORMER_CAPABLE, "SU-BEAMFORMER-CAPABLE"); + PFLAG(SU_BEAMFORMEE_CAPABLE, "SU-BEAMFORMEE-CAPABLE"); + p += scnprintf(p, bufsz + buf - p, + "\t\tBEAMFORMEE-STS: 0x%x\n", + (vhtc->cap & IEEE80211_VHT_CAP_BEAMFORMEE_STS_MASK) >> + IEEE80211_VHT_CAP_BEAMFORMEE_STS_SHIFT); + p += scnprintf(p, bufsz + buf - p, + "\t\tSOUNDING-DIMENSIONS: 0x%x\n", + (vhtc->cap & IEEE80211_VHT_CAP_SOUNDING_DIMENSIONS_MASK) + >> IEEE80211_VHT_CAP_SOUNDING_DIMENSIONS_SHIFT); + PFLAG(MU_BEAMFORMER_CAPABLE, "MU-BEAMFORMER-CAPABLE"); + PFLAG(MU_BEAMFORMEE_CAPABLE, "MU-BEAMFORMEE-CAPABLE"); + PFLAG(VHT_TXOP_PS, "TXOP-PS"); + PFLAG(HTC_VHT, "HTC-VHT"); + p += scnprintf(p, bufsz + buf - 
p, + "\t\tMPDU-LENGTH-EXPONENT: 0x%x\n", + (vhtc->cap & IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK) >> + IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_SHIFT); + PFLAG(VHT_LINK_ADAPTATION_VHT_UNSOL_MFB, + "LINK-ADAPTATION-VHT-UNSOL-MFB"); + p += scnprintf(p, bufsz + buf - p, + "\t\tLINK-ADAPTATION-VHT-MRQ-MFB: 0x%x\n", + (vhtc->cap & IEEE80211_VHT_CAP_VHT_LINK_ADAPTATION_VHT_MRQ_MFB) >> 26); + PFLAG(RX_ANTENNA_PATTERN, "RX-ANTENNA-PATTERN"); + PFLAG(TX_ANTENNA_PATTERN, "TX-ANTENNA-PATTERN"); + + p += scnprintf(p, bufsz + buf - p, "RX MCS: %.4x\n", + le16_to_cpu(vhtc->vht_mcs.rx_mcs_map)); + if (vhtc->vht_mcs.rx_highest) + p += scnprintf(p, bufsz + buf - p, + "MCS RX highest: %d Mbps\n", + le16_to_cpu(vhtc->vht_mcs.rx_highest)); + p += scnprintf(p, bufsz + buf - p, "TX MCS: %.4x\n", + le16_to_cpu(vhtc->vht_mcs.tx_mcs_map)); + if (vhtc->vht_mcs.tx_highest) + p += scnprintf(p, bufsz + buf - p, + "MCS TX highest: %d Mbps\n", + le16_to_cpu(vhtc->vht_mcs.tx_highest)); +#undef PFLAG + } + + ret = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return ret; +} +STA_OPS(vht_capa); + +static ssize_t sta_he_capa_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + char *buf, *p; + size_t buf_sz = PAGE_SIZE; + struct sta_info *sta = file->private_data; + struct ieee80211_sta_he_cap *hec = &sta->sta.deflink.he_cap; + struct ieee80211_he_mcs_nss_supp *nss = &hec->he_mcs_nss_supp; + u8 ppe_size; + u8 *cap; + int i; + ssize_t ret; + + buf = kmalloc(buf_sz, GFP_KERNEL); + if (!buf) + return -ENOMEM; + p = buf; + + p += scnprintf(p, buf_sz + buf - p, "HE %ssupported\n", + hec->has_he ? "" : "not "); + if (!hec->has_he) + goto out; + + cap = hec->he_cap_elem.mac_cap_info; + p += scnprintf(p, buf_sz + buf - p, + "MAC-CAP: %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x\n", + cap[0], cap[1], cap[2], cap[3], cap[4], cap[5]); + +#define PRINT(fmt, ...) 
\ + p += scnprintf(p, buf_sz + buf - p, "\t\t" fmt "\n", \ + ##__VA_ARGS__) + +#define PFLAG(t, n, a, b) \ + do { \ + if (cap[n] & IEEE80211_HE_##t##_CAP##n##_##a) \ + PRINT("%s", b); \ + } while (0) + +#define PFLAG_RANGE(t, i, n, s, m, off, fmt) \ + do { \ + u8 msk = IEEE80211_HE_##t##_CAP##i##_##n##_MASK; \ + u8 idx = ((cap[i] & msk) >> (ffs(msk) - 1)) + off; \ + PRINT(fmt, (s << idx) + (m * idx)); \ + } while (0) + +#define PFLAG_RANGE_DEFAULT(t, i, n, s, m, off, fmt, a, b) \ + do { \ + if (cap[i] == IEEE80211_HE_##t ##_CAP##i##_##n##_##a) { \ + PRINT("%s", b); \ + break; \ + } \ + PFLAG_RANGE(t, i, n, s, m, off, fmt); \ + } while (0) + + PFLAG(MAC, 0, HTC_HE, "HTC-HE"); + PFLAG(MAC, 0, TWT_REQ, "TWT-REQ"); + PFLAG(MAC, 0, TWT_RES, "TWT-RES"); + PFLAG_RANGE_DEFAULT(MAC, 0, DYNAMIC_FRAG, 0, 1, 0, + "DYNAMIC-FRAG-LEVEL-%d", NOT_SUPP, "NOT-SUPP"); + PFLAG_RANGE_DEFAULT(MAC, 0, MAX_NUM_FRAG_MSDU, 1, 0, 0, + "MAX-NUM-FRAG-MSDU-%d", UNLIMITED, "UNLIMITED"); + + PFLAG_RANGE_DEFAULT(MAC, 1, MIN_FRAG_SIZE, 128, 0, -1, + "MIN-FRAG-SIZE-%d", UNLIMITED, "UNLIMITED"); + PFLAG_RANGE_DEFAULT(MAC, 1, TF_MAC_PAD_DUR, 0, 8, 0, + "TF-MAC-PAD-DUR-%dUS", MASK, "UNKNOWN"); + PFLAG_RANGE(MAC, 1, MULTI_TID_AGG_RX_QOS, 0, 1, 1, + "MULTI-TID-AGG-RX-QOS-%d"); + + if (cap[0] & IEEE80211_HE_MAC_CAP0_HTC_HE) { + switch (((cap[2] << 1) | (cap[1] >> 7)) & 0x3) { + case 0: + PRINT("LINK-ADAPTATION-NO-FEEDBACK"); + break; + case 1: + PRINT("LINK-ADAPTATION-RESERVED"); + break; + case 2: + PRINT("LINK-ADAPTATION-UNSOLICITED-FEEDBACK"); + break; + case 3: + PRINT("LINK-ADAPTATION-BOTH"); + break; + } + } + + PFLAG(MAC, 2, ALL_ACK, "ALL-ACK"); + PFLAG(MAC, 2, TRS, "TRS"); + PFLAG(MAC, 2, BSR, "BSR"); + PFLAG(MAC, 2, BCAST_TWT, "BCAST-TWT"); + PFLAG(MAC, 2, 32BIT_BA_BITMAP, "32BIT-BA-BITMAP"); + PFLAG(MAC, 2, MU_CASCADING, "MU-CASCADING"); + PFLAG(MAC, 2, ACK_EN, "ACK-EN"); + + PFLAG(MAC, 3, OMI_CONTROL, "OMI-CONTROL"); + PFLAG(MAC, 3, OFDMA_RA, "OFDMA-RA"); + + switch (cap[3] & IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_MASK) { + case IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_EXT_0: + PRINT("MAX-AMPDU-LEN-EXP-USE-EXT-0"); + break; + case IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_EXT_1: + PRINT("MAX-AMPDU-LEN-EXP-VHT-EXT-1"); + break; + case IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_EXT_2: + PRINT("MAX-AMPDU-LEN-EXP-VHT-EXT-2"); + break; + case IEEE80211_HE_MAC_CAP3_MAX_AMPDU_LEN_EXP_EXT_3: + PRINT("MAX-AMPDU-LEN-EXP-VHT-EXT-3"); + break; + } + + PFLAG(MAC, 3, AMSDU_FRAG, "AMSDU-FRAG"); + PFLAG(MAC, 3, FLEX_TWT_SCHED, "FLEX-TWT-SCHED"); + PFLAG(MAC, 3, RX_CTRL_FRAME_TO_MULTIBSS, "RX-CTRL-FRAME-TO-MULTIBSS"); + + PFLAG(MAC, 4, BSRP_BQRP_A_MPDU_AGG, "BSRP-BQRP-A-MPDU-AGG"); + PFLAG(MAC, 4, QTP, "QTP"); + PFLAG(MAC, 4, BQR, "BQR"); + PFLAG(MAC, 4, PSR_RESP, "PSR-RESP"); + PFLAG(MAC, 4, NDP_FB_REP, "NDP-FB-REP"); + PFLAG(MAC, 4, OPS, "OPS"); + PFLAG(MAC, 4, AMSDU_IN_AMPDU, "AMSDU-IN-AMPDU"); + + PRINT("MULTI-TID-AGG-TX-QOS-%d", ((cap[5] << 1) | (cap[4] >> 7)) & 0x7); + + PFLAG(MAC, 5, SUBCHAN_SELECTIVE_TRANSMISSION, + "SUBCHAN-SELECTIVE-TRANSMISSION"); + PFLAG(MAC, 5, UL_2x996_TONE_RU, "UL-2x996-TONE-RU"); + PFLAG(MAC, 5, OM_CTRL_UL_MU_DATA_DIS_RX, "OM-CTRL-UL-MU-DATA-DIS-RX"); + PFLAG(MAC, 5, HE_DYNAMIC_SM_PS, "HE-DYNAMIC-SM-PS"); + PFLAG(MAC, 5, PUNCTURED_SOUNDING, "PUNCTURED-SOUNDING"); + PFLAG(MAC, 5, HT_VHT_TRIG_FRAME_RX, "HT-VHT-TRIG-FRAME-RX"); + + cap = hec->he_cap_elem.phy_cap_info; + p += scnprintf(p, buf_sz + buf - p, + "PHY CAP: %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x %#.2x\n", + cap[0], cap[1], cap[2], 
cap[3], cap[4], cap[5], cap[6], + cap[7], cap[8], cap[9], cap[10]); + + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_40MHZ_IN_2G, + "CHANNEL-WIDTH-SET-40MHZ-IN-2G"); + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_40MHZ_80MHZ_IN_5G, + "CHANNEL-WIDTH-SET-40MHZ-80MHZ-IN-5G"); + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_160MHZ_IN_5G, + "CHANNEL-WIDTH-SET-160MHZ-IN-5G"); + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G, + "CHANNEL-WIDTH-SET-80PLUS80-MHZ-IN-5G"); + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_RU_MAPPING_IN_2G, + "CHANNEL-WIDTH-SET-RU-MAPPING-IN-2G"); + PFLAG(PHY, 0, CHANNEL_WIDTH_SET_RU_MAPPING_IN_5G, + "CHANNEL-WIDTH-SET-RU-MAPPING-IN-5G"); + + switch (cap[1] & IEEE80211_HE_PHY_CAP1_PREAMBLE_PUNC_RX_MASK) { + case IEEE80211_HE_PHY_CAP1_PREAMBLE_PUNC_RX_80MHZ_ONLY_SECOND_20MHZ: + PRINT("PREAMBLE-PUNC-RX-80MHZ-ONLY-SECOND-20MHZ"); + break; + case IEEE80211_HE_PHY_CAP1_PREAMBLE_PUNC_RX_80MHZ_ONLY_SECOND_40MHZ: + PRINT("PREAMBLE-PUNC-RX-80MHZ-ONLY-SECOND-40MHZ"); + break; + case IEEE80211_HE_PHY_CAP1_PREAMBLE_PUNC_RX_160MHZ_ONLY_SECOND_20MHZ: + PRINT("PREAMBLE-PUNC-RX-160MHZ-ONLY-SECOND-20MHZ"); + break; + case IEEE80211_HE_PHY_CAP1_PREAMBLE_PUNC_RX_160MHZ_ONLY_SECOND_40MHZ: + PRINT("PREAMBLE-PUNC-RX-160MHZ-ONLY-SECOND-40MHZ"); + break; + } + + PFLAG(PHY, 1, DEVICE_CLASS_A, + "IEEE80211-HE-PHY-CAP1-DEVICE-CLASS-A"); + PFLAG(PHY, 1, LDPC_CODING_IN_PAYLOAD, + "LDPC-CODING-IN-PAYLOAD"); + PFLAG(PHY, 1, HE_LTF_AND_GI_FOR_HE_PPDUS_0_8US, + "HY-CAP1-HE-LTF-AND-GI-FOR-HE-PPDUS-0-8US"); + PRINT("MIDAMBLE-RX-MAX-NSTS-%d", ((cap[2] << 1) | (cap[1] >> 7)) & 0x3); + + PFLAG(PHY, 2, NDP_4x_LTF_AND_3_2US, "NDP-4X-LTF-AND-3-2US"); + PFLAG(PHY, 2, STBC_TX_UNDER_80MHZ, "STBC-TX-UNDER-80MHZ"); + PFLAG(PHY, 2, STBC_RX_UNDER_80MHZ, "STBC-RX-UNDER-80MHZ"); + PFLAG(PHY, 2, DOPPLER_TX, "DOPPLER-TX"); + PFLAG(PHY, 2, DOPPLER_RX, "DOPPLER-RX"); + PFLAG(PHY, 2, UL_MU_FULL_MU_MIMO, "UL-MU-FULL-MU-MIMO"); + PFLAG(PHY, 2, UL_MU_PARTIAL_MU_MIMO, "UL-MU-PARTIAL-MU-MIMO"); + + switch (cap[3] & IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_TX_MASK) { + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_TX_NO_DCM: + PRINT("DCM-MAX-CONST-TX-NO-DCM"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_TX_BPSK: + PRINT("DCM-MAX-CONST-TX-BPSK"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_TX_QPSK: + PRINT("DCM-MAX-CONST-TX-QPSK"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_TX_16_QAM: + PRINT("DCM-MAX-CONST-TX-16-QAM"); + break; + } + + PFLAG(PHY, 3, DCM_MAX_TX_NSS_1, "DCM-MAX-TX-NSS-1"); + PFLAG(PHY, 3, DCM_MAX_TX_NSS_2, "DCM-MAX-TX-NSS-2"); + + switch (cap[3] & IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_RX_MASK) { + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_RX_NO_DCM: + PRINT("DCM-MAX-CONST-RX-NO-DCM"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_RX_BPSK: + PRINT("DCM-MAX-CONST-RX-BPSK"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_RX_QPSK: + PRINT("DCM-MAX-CONST-RX-QPSK"); + break; + case IEEE80211_HE_PHY_CAP3_DCM_MAX_CONST_RX_16_QAM: + PRINT("DCM-MAX-CONST-RX-16-QAM"); + break; + } + + PFLAG(PHY, 3, DCM_MAX_RX_NSS_1, "DCM-MAX-RX-NSS-1"); + PFLAG(PHY, 3, DCM_MAX_RX_NSS_2, "DCM-MAX-RX-NSS-2"); + PFLAG(PHY, 3, RX_PARTIAL_BW_SU_IN_20MHZ_MU, + "RX-PARTIAL-BW-SU-IN-20MHZ-MU"); + PFLAG(PHY, 3, SU_BEAMFORMER, "SU-BEAMFORMER"); + + PFLAG(PHY, 4, SU_BEAMFORMEE, "SU-BEAMFORMEE"); + PFLAG(PHY, 4, MU_BEAMFORMER, "MU-BEAMFORMER"); + + PFLAG_RANGE(PHY, 4, BEAMFORMEE_MAX_STS_UNDER_80MHZ, 0, 1, 4, + "BEAMFORMEE-MAX-STS-UNDER-%d"); + PFLAG_RANGE(PHY, 4, BEAMFORMEE_MAX_STS_ABOVE_80MHZ, 0, 1, 4, + "BEAMFORMEE-MAX-STS-ABOVE-%d"); + + PFLAG_RANGE(PHY, 5, 
BEAMFORMEE_NUM_SND_DIM_UNDER_80MHZ, 0, 1, 1, + "NUM-SND-DIM-UNDER-80MHZ-%d"); + PFLAG_RANGE(PHY, 5, BEAMFORMEE_NUM_SND_DIM_ABOVE_80MHZ, 0, 1, 1, + "NUM-SND-DIM-ABOVE-80MHZ-%d"); + PFLAG(PHY, 5, NG16_SU_FEEDBACK, "NG16-SU-FEEDBACK"); + PFLAG(PHY, 5, NG16_MU_FEEDBACK, "NG16-MU-FEEDBACK"); + + PFLAG(PHY, 6, CODEBOOK_SIZE_42_SU, "CODEBOOK-SIZE-42-SU"); + PFLAG(PHY, 6, CODEBOOK_SIZE_75_MU, "CODEBOOK-SIZE-75-MU"); + PFLAG(PHY, 6, TRIG_SU_BEAMFORMING_FB, "TRIG-SU-BEAMFORMING-FB"); + PFLAG(PHY, 6, TRIG_MU_BEAMFORMING_PARTIAL_BW_FB, + "MU-BEAMFORMING-PARTIAL-BW-FB"); + PFLAG(PHY, 6, TRIG_CQI_FB, "TRIG-CQI-FB"); + PFLAG(PHY, 6, PARTIAL_BW_EXT_RANGE, "PARTIAL-BW-EXT-RANGE"); + PFLAG(PHY, 6, PARTIAL_BANDWIDTH_DL_MUMIMO, + "PARTIAL-BANDWIDTH-DL-MUMIMO"); + PFLAG(PHY, 6, PPE_THRESHOLD_PRESENT, "PPE-THRESHOLD-PRESENT"); + + PFLAG(PHY, 7, PSR_BASED_SR, "PSR-BASED-SR"); + PFLAG(PHY, 7, POWER_BOOST_FACTOR_SUPP, "POWER-BOOST-FACTOR-SUPP"); + PFLAG(PHY, 7, HE_SU_MU_PPDU_4XLTF_AND_08_US_GI, + "HE-SU-MU-PPDU-4XLTF-AND-08-US-GI"); + PFLAG_RANGE(PHY, 7, MAX_NC, 0, 1, 1, "MAX-NC-%d"); + PFLAG(PHY, 7, STBC_TX_ABOVE_80MHZ, "STBC-TX-ABOVE-80MHZ"); + PFLAG(PHY, 7, STBC_RX_ABOVE_80MHZ, "STBC-RX-ABOVE-80MHZ"); + + PFLAG(PHY, 8, HE_ER_SU_PPDU_4XLTF_AND_08_US_GI, + "HE-ER-SU-PPDU-4XLTF-AND-08-US-GI"); + PFLAG(PHY, 8, 20MHZ_IN_40MHZ_HE_PPDU_IN_2G, + "20MHZ-IN-40MHZ-HE-PPDU-IN-2G"); + PFLAG(PHY, 8, 20MHZ_IN_160MHZ_HE_PPDU, "20MHZ-IN-160MHZ-HE-PPDU"); + PFLAG(PHY, 8, 80MHZ_IN_160MHZ_HE_PPDU, "80MHZ-IN-160MHZ-HE-PPDU"); + PFLAG(PHY, 8, HE_ER_SU_1XLTF_AND_08_US_GI, + "HE-ER-SU-1XLTF-AND-08-US-GI"); + PFLAG(PHY, 8, MIDAMBLE_RX_TX_2X_AND_1XLTF, + "MIDAMBLE-RX-TX-2X-AND-1XLTF"); + + switch (cap[8] & IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_MASK) { + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_242: + PRINT("DCM-MAX-RU-242"); + break; + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_484: + PRINT("DCM-MAX-RU-484"); + break; + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_996: + PRINT("DCM-MAX-RU-996"); + break; + case IEEE80211_HE_PHY_CAP8_DCM_MAX_RU_2x996: + PRINT("DCM-MAX-RU-2x996"); + break; + } + + PFLAG(PHY, 9, LONGER_THAN_16_SIGB_OFDM_SYM, + "LONGER-THAN-16-SIGB-OFDM-SYM"); + PFLAG(PHY, 9, NON_TRIGGERED_CQI_FEEDBACK, + "NON-TRIGGERED-CQI-FEEDBACK"); + PFLAG(PHY, 9, TX_1024_QAM_LESS_THAN_242_TONE_RU, + "TX-1024-QAM-LESS-THAN-242-TONE-RU"); + PFLAG(PHY, 9, RX_1024_QAM_LESS_THAN_242_TONE_RU, + "RX-1024-QAM-LESS-THAN-242-TONE-RU"); + PFLAG(PHY, 9, RX_FULL_BW_SU_USING_MU_WITH_COMP_SIGB, + "RX-FULL-BW-SU-USING-MU-WITH-COMP-SIGB"); + PFLAG(PHY, 9, RX_FULL_BW_SU_USING_MU_WITH_NON_COMP_SIGB, + "RX-FULL-BW-SU-USING-MU-WITH-NON-COMP-SIGB"); + + switch (u8_get_bits(cap[9], + IEEE80211_HE_PHY_CAP9_NOMINAL_PKT_PADDING_MASK)) { + case IEEE80211_HE_PHY_CAP9_NOMINAL_PKT_PADDING_0US: + PRINT("NOMINAL-PACKET-PADDING-0US"); + break; + case IEEE80211_HE_PHY_CAP9_NOMINAL_PKT_PADDING_8US: + PRINT("NOMINAL-PACKET-PADDING-8US"); + break; + case IEEE80211_HE_PHY_CAP9_NOMINAL_PKT_PADDING_16US: + PRINT("NOMINAL-PACKET-PADDING-16US"); + break; + } + +#undef PFLAG_RANGE_DEFAULT +#undef PFLAG_RANGE +#undef PFLAG + +#define PRINT_NSS_SUPP(f, n) \ + do { \ + int _i; \ + u16 v = le16_to_cpu(nss->f); \ + p += scnprintf(p, buf_sz + buf - p, n ": %#.4x\n", v); \ + for (_i = 0; _i < 8; _i += 2) { \ + switch ((v >> _i) & 0x3) { \ + case 0: \ + PRINT(n "-%d-SUPPORT-0-7", _i / 2); \ + break; \ + case 1: \ + PRINT(n "-%d-SUPPORT-0-9", _i / 2); \ + break; \ + case 2: \ + PRINT(n "-%d-SUPPORT-0-11", _i / 2); \ + break; \ + case 3: \ + PRINT(n "-%d-NOT-SUPPORTED", _i / 2); \ + break; \ + } \ + } \ + } 
while (0) + + PRINT_NSS_SUPP(rx_mcs_80, "RX-MCS-80"); + PRINT_NSS_SUPP(tx_mcs_80, "TX-MCS-80"); + + if (cap[0] & IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G) { + PRINT_NSS_SUPP(rx_mcs_160, "RX-MCS-160"); + PRINT_NSS_SUPP(tx_mcs_160, "TX-MCS-160"); + } + + if (cap[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G) { + PRINT_NSS_SUPP(rx_mcs_80p80, "RX-MCS-80P80"); + PRINT_NSS_SUPP(tx_mcs_80p80, "TX-MCS-80P80"); + } + +#undef PRINT_NSS_SUPP +#undef PRINT + + if (!(cap[6] & IEEE80211_HE_PHY_CAP6_PPE_THRESHOLD_PRESENT)) + goto out; + + p += scnprintf(p, buf_sz + buf - p, "PPE-THRESHOLDS: %#.2x", + hec->ppe_thres[0]); + + ppe_size = ieee80211_he_ppe_size(hec->ppe_thres[0], cap); + for (i = 1; i < ppe_size; i++) { + p += scnprintf(p, buf_sz + buf - p, " %#.2x", + hec->ppe_thres[i]); + } + p += scnprintf(p, buf_sz + buf - p, "\n"); + +out: + ret = simple_read_from_buffer(userbuf, count, ppos, buf, p - buf); + kfree(buf); + return ret; +} +STA_OPS(he_capa); + +#define DEBUGFS_ADD(name) \ + debugfs_create_file(#name, 0400, \ + sta->debugfs_dir, sta, &sta_ ##name## _ops) + +#define DEBUGFS_ADD_COUNTER(name, field) \ + debugfs_create_ulong(#name, 0400, sta->debugfs_dir, &sta->field); + +void ieee80211_sta_debugfs_add(struct sta_info *sta) +{ + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct dentry *stations_dir = sta->sdata->debugfs.subdir_stations; + u8 mac[3*ETH_ALEN]; + + if (!stations_dir) + return; + + snprintf(mac, sizeof(mac), "%pM", sta->sta.addr); + + /* + * This might fail due to a race condition: + * When mac80211 unlinks a station, the debugfs entries + * remain, but it is already possible to link a new + * station with the same address which triggers adding + * it to debugfs; therefore, if the old station isn't + * destroyed quickly enough the old station's debugfs + * dir might still be around. 
+ */ + sta->debugfs_dir = debugfs_create_dir(mac, stations_dir); + + DEBUGFS_ADD(flags); + DEBUGFS_ADD(aid); + DEBUGFS_ADD(num_ps_buf_frames); + DEBUGFS_ADD(last_seq_ctrl); + DEBUGFS_ADD(agg_status); + DEBUGFS_ADD(ht_capa); + DEBUGFS_ADD(vht_capa); + DEBUGFS_ADD(he_capa); + + DEBUGFS_ADD_COUNTER(rx_duplicates, deflink.rx_stats.num_duplicates); + DEBUGFS_ADD_COUNTER(rx_fragments, deflink.rx_stats.fragments); + DEBUGFS_ADD_COUNTER(tx_filtered, deflink.status_stats.filtered); + + if (local->ops->wake_tx_queue) { + DEBUGFS_ADD(aqm); + DEBUGFS_ADD(airtime); + } + + if (wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_AQL)) + DEBUGFS_ADD(aql); + + debugfs_create_xul("driver_buffered_tids", 0400, sta->debugfs_dir, + &sta->driver_buffered_tids); + + drv_sta_add_debugfs(local, sdata, &sta->sta, sta->debugfs_dir); +} + +void ieee80211_sta_debugfs_remove(struct sta_info *sta) +{ + debugfs_remove_recursive(sta->debugfs_dir); + sta->debugfs_dir = NULL; +} diff --git a/net/mac80211/debugfs_sta.h b/net/mac80211/debugfs_sta.h new file mode 100644 index 000000000..d2e7c27ad --- /dev/null +++ b/net/mac80211/debugfs_sta.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __MAC80211_DEBUGFS_STA_H +#define __MAC80211_DEBUGFS_STA_H + +#include "sta_info.h" + +#ifdef CONFIG_MAC80211_DEBUGFS +void ieee80211_sta_debugfs_add(struct sta_info *sta); +void ieee80211_sta_debugfs_remove(struct sta_info *sta); +#else +static inline void ieee80211_sta_debugfs_add(struct sta_info *sta) {} +static inline void ieee80211_sta_debugfs_remove(struct sta_info *sta) {} +#endif + +#endif /* __MAC80211_DEBUGFS_STA_H */ diff --git a/net/mac80211/driver-ops.c b/net/mac80211/driver-ops.c new file mode 100644 index 000000000..c08d3c9a4 --- /dev/null +++ b/net/mac80211/driver-ops.c @@ -0,0 +1,523 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2015 Intel Deutschland GmbH + * Copyright (C) 2022 Intel Corporation + */ +#include <net/mac80211.h> +#include "ieee80211_i.h" +#include "trace.h" +#include "driver-ops.h" + +int drv_start(struct ieee80211_local *local) +{ + int ret; + + might_sleep(); + + if (WARN_ON(local->started)) + return -EALREADY; + + trace_drv_start(local); + local->started = true; + /* allow rx frames */ + smp_mb(); + ret = local->ops->start(&local->hw); + trace_drv_return_int(local, ret); + + if (ret) + local->started = false; + + return ret; +} + +void drv_stop(struct ieee80211_local *local) +{ + might_sleep(); + + if (WARN_ON(!local->started)) + return; + + trace_drv_stop(local); + local->ops->stop(&local->hw); + trace_drv_return_void(local); + + /* sync away all work on the tasklet before clearing started */ + tasklet_disable(&local->tasklet); + tasklet_enable(&local->tasklet); + + barrier(); + + local->started = false; +} + +int drv_add_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + int ret; + + might_sleep(); + + if (WARN_ON(sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + (sdata->vif.type == NL80211_IFTYPE_MONITOR && + !ieee80211_hw_check(&local->hw, WANT_MONITOR_VIF) && + !(sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE)))) + return -EINVAL; + + trace_drv_add_interface(local, sdata); + ret = local->ops->add_interface(&local->hw, &sdata->vif); + trace_drv_return_int(local, ret); + + if (ret == 0) + sdata->flags |= IEEE80211_SDATA_IN_DRIVER; + + return ret; +} + +int drv_change_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type, bool p2p) +{ + int ret; + + might_sleep(); + + if 
(!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_change_interface(local, sdata, type, p2p); + ret = local->ops->change_interface(&local->hw, &sdata->vif, type, p2p); + trace_drv_return_int(local, ret); + return ret; +} + +void drv_remove_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_remove_interface(local, sdata); + local->ops->remove_interface(&local->hw, &sdata->vif); + sdata->flags &= ~IEEE80211_SDATA_IN_DRIVER; + trace_drv_return_void(local); +} + +__must_check +int drv_sta_state(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + enum ieee80211_sta_state old_state, + enum ieee80211_sta_state new_state) +{ + int ret = 0; + + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_sta_state(local, sdata, &sta->sta, old_state, new_state); + if (local->ops->sta_state) { + ret = local->ops->sta_state(&local->hw, &sdata->vif, &sta->sta, + old_state, new_state); + } else if (old_state == IEEE80211_STA_AUTH && + new_state == IEEE80211_STA_ASSOC) { + ret = drv_sta_add(local, sdata, &sta->sta); + if (ret == 0) { + sta->uploaded = true; + if (rcu_access_pointer(sta->sta.rates)) + drv_sta_rate_tbl_update(local, sdata, &sta->sta); + } + } else if (old_state == IEEE80211_STA_ASSOC && + new_state == IEEE80211_STA_AUTH) { + drv_sta_remove(local, sdata, &sta->sta); + } + trace_drv_return_int(local, ret); + return ret; +} + +__must_check +int drv_sta_set_txpwr(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_sta_set_txpwr(local, sdata, &sta->sta); + if (local->ops->sta_set_txpwr) + ret = local->ops->sta_set_txpwr(&local->hw, &sdata->vif, + &sta->sta); + trace_drv_return_int(local, ret); + return ret; +} + +void drv_sta_rc_update(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, u32 changed) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + WARN_ON(changed & IEEE80211_RC_SUPP_RATES_CHANGED && + (sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT)); + + trace_drv_sta_rc_update(local, sdata, sta, changed); + if (local->ops->sta_rc_update) + local->ops->sta_rc_update(&local->hw, &sdata->vif, + sta, changed); + + trace_drv_return_void(local); +} + +int drv_conf_tx(struct ieee80211_local *local, + struct ieee80211_link_data *link, u16 ac, + const struct ieee80211_tx_queue_params *params) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + if (sdata->vif.active_links && + !(sdata->vif.active_links & BIT(link->link_id))) + return 0; + + if (params->cw_min == 0 || params->cw_min > params->cw_max) { + /* + * If we can't configure hardware anyway, don't warn. We may + * never have initialized the CW parameters. 
+ */ + WARN_ONCE(local->ops->conf_tx, + "%s: invalid CW_min/CW_max: %d/%d\n", + sdata->name, params->cw_min, params->cw_max); + return -EINVAL; + } + + trace_drv_conf_tx(local, sdata, link->link_id, ac, params); + if (local->ops->conf_tx) + ret = local->ops->conf_tx(&local->hw, &sdata->vif, + link->link_id, ac, params); + trace_drv_return_int(local, ret); + return ret; +} + +u64 drv_get_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + u64 ret = -1ULL; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return ret; + + trace_drv_get_tsf(local, sdata); + if (local->ops->get_tsf) + ret = local->ops->get_tsf(&local->hw, &sdata->vif); + trace_drv_return_u64(local, ret); + return ret; +} + +void drv_set_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u64 tsf) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_set_tsf(local, sdata, tsf); + if (local->ops->set_tsf) + local->ops->set_tsf(&local->hw, &sdata->vif, tsf); + trace_drv_return_void(local); +} + +void drv_offset_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + s64 offset) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_offset_tsf(local, sdata, offset); + if (local->ops->offset_tsf) + local->ops->offset_tsf(&local->hw, &sdata->vif, offset); + trace_drv_return_void(local); +} + +void drv_reset_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_reset_tsf(local, sdata); + if (local->ops->reset_tsf) + local->ops->reset_tsf(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +int drv_assign_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx) +{ + int ret = 0; + + drv_verify_link_exists(sdata, link_conf); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + if (sdata->vif.active_links && + !(sdata->vif.active_links & BIT(link_conf->link_id))) + return 0; + + trace_drv_assign_vif_chanctx(local, sdata, link_conf, ctx); + if (local->ops->assign_vif_chanctx) { + WARN_ON_ONCE(!ctx->driver_present); + ret = local->ops->assign_vif_chanctx(&local->hw, + &sdata->vif, + link_conf, + &ctx->conf); + } + trace_drv_return_int(local, ret); + + return ret; +} + +void drv_unassign_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx) +{ + might_sleep(); + + drv_verify_link_exists(sdata, link_conf); + if (!check_sdata_in_driver(sdata)) + return; + + if (sdata->vif.active_links && + !(sdata->vif.active_links & BIT(link_conf->link_id))) + return; + + trace_drv_unassign_vif_chanctx(local, sdata, link_conf, ctx); + if (local->ops->unassign_vif_chanctx) { + WARN_ON_ONCE(!ctx->driver_present); + local->ops->unassign_vif_chanctx(&local->hw, + &sdata->vif, + link_conf, + &ctx->conf); + } + trace_drv_return_void(local); +} + +int drv_switch_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_vif_chanctx_switch *vifs, + int n_vifs, enum ieee80211_chanctx_switch_mode mode) +{ + int ret = 0; + int i; + + might_sleep(); + + if (!local->ops->switch_vif_chanctx) + return -EOPNOTSUPP; + + for (i = 0; i < n_vifs; i++) { + struct ieee80211_chanctx *new_ctx = + container_of(vifs[i].new_ctx, + struct ieee80211_chanctx, + conf); + struct ieee80211_chanctx *old_ctx = + container_of(vifs[i].old_ctx, + 
struct ieee80211_chanctx, + conf); + + WARN_ON_ONCE(!old_ctx->driver_present); + WARN_ON_ONCE((mode == CHANCTX_SWMODE_SWAP_CONTEXTS && + new_ctx->driver_present) || + (mode == CHANCTX_SWMODE_REASSIGN_VIF && + !new_ctx->driver_present)); + } + + trace_drv_switch_vif_chanctx(local, vifs, n_vifs, mode); + ret = local->ops->switch_vif_chanctx(&local->hw, + vifs, n_vifs, mode); + trace_drv_return_int(local, ret); + + if (!ret && mode == CHANCTX_SWMODE_SWAP_CONTEXTS) { + for (i = 0; i < n_vifs; i++) { + struct ieee80211_chanctx *new_ctx = + container_of(vifs[i].new_ctx, + struct ieee80211_chanctx, + conf); + struct ieee80211_chanctx *old_ctx = + container_of(vifs[i].old_ctx, + struct ieee80211_chanctx, + conf); + + new_ctx->driver_present = true; + old_ctx->driver_present = false; + } + } + + return ret; +} + +int drv_ampdu_action(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_ampdu_params *params) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (!sdata) + return -EIO; + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_ampdu_action(local, sdata, params); + + if (local->ops->ampdu_action) + ret = local->ops->ampdu_action(&local->hw, &sdata->vif, params); + + trace_drv_return_int(local, ret); + + return ret; +} + +void drv_link_info_changed(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *info, + int link_id, u64 changed) +{ + might_sleep(); + + if (WARN_ON_ONCE(changed & (BSS_CHANGED_BEACON | + BSS_CHANGED_BEACON_ENABLED) && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_OCB)) + return; + + if (WARN_ON_ONCE(sdata->vif.type == NL80211_IFTYPE_P2P_DEVICE || + sdata->vif.type == NL80211_IFTYPE_NAN || + (sdata->vif.type == NL80211_IFTYPE_MONITOR && + !sdata->vif.bss_conf.mu_mimo_owner && + !(changed & BSS_CHANGED_TXPOWER)))) + return; + + if (!check_sdata_in_driver(sdata)) + return; + + if (sdata->vif.active_links && + !(sdata->vif.active_links & BIT(link_id))) + return; + + trace_drv_link_info_changed(local, sdata, info, changed); + if (local->ops->link_info_changed) + local->ops->link_info_changed(&local->hw, &sdata->vif, + info, changed); + else if (local->ops->bss_info_changed) + local->ops->bss_info_changed(&local->hw, &sdata->vif, + info, changed); + trace_drv_return_void(local); +} + +int drv_set_key(struct ieee80211_local *local, + enum set_key_cmd cmd, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key) +{ + int ret; + + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + if (WARN_ON(key->link_id >= 0 && sdata->vif.active_links && + !(sdata->vif.active_links & BIT(key->link_id)))) + return -ENOLINK; + + trace_drv_set_key(local, cmd, sdata, sta, key); + ret = local->ops->set_key(&local->hw, cmd, &sdata->vif, sta, key); + trace_drv_return_int(local, ret); + return ret; +} + +int drv_change_vif_links(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 old_links, u16 new_links, + struct ieee80211_bss_conf *old[IEEE80211_MLD_MAX_NUM_LINKS]) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + if (old_links == new_links) + return 0; + + trace_drv_change_vif_links(local, sdata, old_links, new_links); + if (local->ops->change_vif_links) + ret = 
local->ops->change_vif_links(&local->hw, &sdata->vif, + old_links, new_links, old); + trace_drv_return_int(local, ret); + + return ret; +} + +int drv_change_sta_links(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + u16 old_links, u16 new_links) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + old_links &= sdata->vif.active_links; + new_links &= sdata->vif.active_links; + + if (old_links == new_links) + return 0; + + trace_drv_change_sta_links(local, sdata, sta, old_links, new_links); + if (local->ops->change_sta_links) + ret = local->ops->change_sta_links(&local->hw, &sdata->vif, sta, + old_links, new_links); + trace_drv_return_int(local, ret); + + return ret; +} diff --git a/net/mac80211/driver-ops.h b/net/mac80211/driver-ops.h new file mode 100644 index 000000000..e685c1275 --- /dev/null +++ b/net/mac80211/driver-ops.h @@ -0,0 +1,1482 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* +* Portions of this file +* Copyright(c) 2016 Intel Deutschland GmbH +* Copyright (C) 2018 - 2019, 2021 Intel Corporation +*/ + +#ifndef __MAC80211_DRIVER_OPS +#define __MAC80211_DRIVER_OPS + +#include <net/mac80211.h> +#include "ieee80211_i.h" +#include "trace.h" + +#define check_sdata_in_driver(sdata) ({ \ + !WARN_ONCE(!(sdata->flags & IEEE80211_SDATA_IN_DRIVER), \ + "%s: Failed check-sdata-in-driver check, flags: 0x%x\n", \ + sdata->dev ? sdata->dev->name : sdata->name, sdata->flags); \ +}) + +static inline struct ieee80211_sub_if_data * +get_bss_sdata(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(sdata->bss, struct ieee80211_sub_if_data, + u.ap); + + return sdata; +} + +static inline void drv_tx(struct ieee80211_local *local, + struct ieee80211_tx_control *control, + struct sk_buff *skb) +{ + local->ops->tx(&local->hw, control, skb); +} + +static inline void drv_sync_rx_queues(struct ieee80211_local *local, + struct sta_info *sta) +{ + if (local->ops->sync_rx_queues) { + trace_drv_sync_rx_queues(local, sta->sdata, &sta->sta); + local->ops->sync_rx_queues(&local->hw); + trace_drv_return_void(local); + } +} + +static inline void drv_get_et_strings(struct ieee80211_sub_if_data *sdata, + u32 sset, u8 *data) +{ + struct ieee80211_local *local = sdata->local; + if (local->ops->get_et_strings) { + trace_drv_get_et_strings(local, sset); + local->ops->get_et_strings(&local->hw, &sdata->vif, sset, data); + trace_drv_return_void(local); + } +} + +static inline void drv_get_et_stats(struct ieee80211_sub_if_data *sdata, + struct ethtool_stats *stats, + u64 *data) +{ + struct ieee80211_local *local = sdata->local; + if (local->ops->get_et_stats) { + trace_drv_get_et_stats(local); + local->ops->get_et_stats(&local->hw, &sdata->vif, stats, data); + trace_drv_return_void(local); + } +} + +static inline int drv_get_et_sset_count(struct ieee80211_sub_if_data *sdata, + int sset) +{ + struct ieee80211_local *local = sdata->local; + int rv = 0; + if (local->ops->get_et_sset_count) { + trace_drv_get_et_sset_count(local, sset); + rv = local->ops->get_et_sset_count(&local->hw, &sdata->vif, + sset); + trace_drv_return_int(local, rv); + } + return rv; +} + +int drv_start(struct ieee80211_local *local); +void drv_stop(struct ieee80211_local *local); + +#ifdef CONFIG_PM +static inline int drv_suspend(struct ieee80211_local *local, + struct cfg80211_wowlan *wowlan) +{ + int ret; + + might_sleep(); + + trace_drv_suspend(local); + ret = local->ops->suspend(&local->hw, wowlan); + 
trace_drv_return_int(local, ret); + return ret; +} + +static inline int drv_resume(struct ieee80211_local *local) +{ + int ret; + + might_sleep(); + + trace_drv_resume(local); + ret = local->ops->resume(&local->hw); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_set_wakeup(struct ieee80211_local *local, + bool enabled) +{ + might_sleep(); + + if (!local->ops->set_wakeup) + return; + + trace_drv_set_wakeup(local, enabled); + local->ops->set_wakeup(&local->hw, enabled); + trace_drv_return_void(local); +} +#endif + +int drv_add_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); + +int drv_change_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type, bool p2p); + +void drv_remove_interface(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); + +static inline int drv_config(struct ieee80211_local *local, u32 changed) +{ + int ret; + + might_sleep(); + + trace_drv_config(local, changed); + ret = local->ops->config(&local->hw, changed); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_vif_cfg_changed(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u64 changed) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_vif_cfg_changed(local, sdata, changed); + if (local->ops->vif_cfg_changed) + local->ops->vif_cfg_changed(&local->hw, &sdata->vif, changed); + else if (local->ops->bss_info_changed) + local->ops->bss_info_changed(&local->hw, &sdata->vif, + &sdata->vif.bss_conf, changed); + trace_drv_return_void(local); +} + +void drv_link_info_changed(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *info, + int link_id, u64 changed); + +static inline u64 drv_prepare_multicast(struct ieee80211_local *local, + struct netdev_hw_addr_list *mc_list) +{ + u64 ret = 0; + + trace_drv_prepare_multicast(local, mc_list->count); + + if (local->ops->prepare_multicast) + ret = local->ops->prepare_multicast(&local->hw, mc_list); + + trace_drv_return_u64(local, ret); + + return ret; +} + +static inline void drv_configure_filter(struct ieee80211_local *local, + unsigned int changed_flags, + unsigned int *total_flags, + u64 multicast) +{ + might_sleep(); + + trace_drv_configure_filter(local, changed_flags, total_flags, + multicast); + local->ops->configure_filter(&local->hw, changed_flags, total_flags, + multicast); + trace_drv_return_void(local); +} + +static inline void drv_config_iface_filter(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + unsigned int filter_flags, + unsigned int changed_flags) +{ + might_sleep(); + + trace_drv_config_iface_filter(local, sdata, filter_flags, + changed_flags); + if (local->ops->config_iface_filter) + local->ops->config_iface_filter(&local->hw, &sdata->vif, + filter_flags, + changed_flags); + trace_drv_return_void(local); +} + +static inline int drv_set_tim(struct ieee80211_local *local, + struct ieee80211_sta *sta, bool set) +{ + int ret = 0; + trace_drv_set_tim(local, sta, set); + if (local->ops->set_tim) + ret = local->ops->set_tim(&local->hw, sta, set); + trace_drv_return_int(local, ret); + return ret; +} + +int drv_set_key(struct ieee80211_local *local, + enum set_key_cmd cmd, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key); + +static inline void drv_update_tkip_key(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct 
ieee80211_key_conf *conf, + struct sta_info *sta, u32 iv32, + u16 *phase1key) +{ + struct ieee80211_sta *ista = NULL; + + if (sta) + ista = &sta->sta; + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_update_tkip_key(local, sdata, conf, ista, iv32); + if (local->ops->update_tkip_key) + local->ops->update_tkip_key(&local->hw, &sdata->vif, conf, + ista, iv32, phase1key); + trace_drv_return_void(local); +} + +static inline int drv_hw_scan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_scan_request *req) +{ + int ret; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_hw_scan(local, sdata); + ret = local->ops->hw_scan(&local->hw, &sdata->vif, req); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_cancel_hw_scan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_cancel_hw_scan(local, sdata); + local->ops->cancel_hw_scan(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline int +drv_sched_scan_start(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_sched_scan_request *req, + struct ieee80211_scan_ies *ies) +{ + int ret; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_sched_scan_start(local, sdata); + ret = local->ops->sched_scan_start(&local->hw, &sdata->vif, + req, ies); + trace_drv_return_int(local, ret); + return ret; +} + +static inline int drv_sched_scan_stop(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + int ret; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_sched_scan_stop(local, sdata); + ret = local->ops->sched_scan_stop(&local->hw, &sdata->vif); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_sw_scan_start(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const u8 *mac_addr) +{ + might_sleep(); + + trace_drv_sw_scan_start(local, sdata, mac_addr); + if (local->ops->sw_scan_start) + local->ops->sw_scan_start(&local->hw, &sdata->vif, mac_addr); + trace_drv_return_void(local); +} + +static inline void drv_sw_scan_complete(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + + trace_drv_sw_scan_complete(local, sdata); + if (local->ops->sw_scan_complete) + local->ops->sw_scan_complete(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline int drv_get_stats(struct ieee80211_local *local, + struct ieee80211_low_level_stats *stats) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (local->ops->get_stats) + ret = local->ops->get_stats(&local->hw, stats); + trace_drv_get_stats(local, stats, ret); + + return ret; +} + +static inline void drv_get_key_seq(struct ieee80211_local *local, + struct ieee80211_key *key, + struct ieee80211_key_seq *seq) +{ + if (local->ops->get_key_seq) + local->ops->get_key_seq(&local->hw, &key->conf, seq); + trace_drv_get_key_seq(local, &key->conf); +} + +static inline int drv_set_frag_threshold(struct ieee80211_local *local, + u32 value) +{ + int ret = 0; + + might_sleep(); + + trace_drv_set_frag_threshold(local, value); + if (local->ops->set_frag_threshold) + ret = local->ops->set_frag_threshold(&local->hw, value); + trace_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_rts_threshold(struct ieee80211_local 
*local, + u32 value) +{ + int ret = 0; + + might_sleep(); + + trace_drv_set_rts_threshold(local, value); + if (local->ops->set_rts_threshold) + ret = local->ops->set_rts_threshold(&local->hw, value); + trace_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_coverage_class(struct ieee80211_local *local, + s16 value) +{ + int ret = 0; + might_sleep(); + + trace_drv_set_coverage_class(local, value); + if (local->ops->set_coverage_class) + local->ops->set_coverage_class(&local->hw, value); + else + ret = -EOPNOTSUPP; + + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_sta_notify(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum sta_notify_cmd cmd, + struct ieee80211_sta *sta) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_notify(local, sdata, cmd, sta); + if (local->ops->sta_notify) + local->ops->sta_notify(&local->hw, &sdata->vif, cmd, sta); + trace_drv_return_void(local); +} + +static inline int drv_sta_add(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta) +{ + int ret = 0; + + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_sta_add(local, sdata, sta); + if (local->ops->sta_add) + ret = local->ops->sta_add(&local->hw, &sdata->vif, sta); + + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_sta_remove(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta) +{ + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_remove(local, sdata, sta); + if (local->ops->sta_remove) + local->ops->sta_remove(&local->hw, &sdata->vif, sta); + + trace_drv_return_void(local); +} + +#ifdef CONFIG_MAC80211_DEBUGFS +static inline void drv_sta_add_debugfs(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct dentry *dir) +{ + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + if (local->ops->sta_add_debugfs) + local->ops->sta_add_debugfs(&local->hw, &sdata->vif, + sta, dir); +} +#endif + +static inline void drv_sta_pre_rcu_remove(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + might_sleep(); + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_pre_rcu_remove(local, sdata, &sta->sta); + if (local->ops->sta_pre_rcu_remove) + local->ops->sta_pre_rcu_remove(&local->hw, &sdata->vif, + &sta->sta); + trace_drv_return_void(local); +} + +__must_check +int drv_sta_state(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + enum ieee80211_sta_state old_state, + enum ieee80211_sta_state new_state); + +__must_check +int drv_sta_set_txpwr(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta); + +void drv_sta_rc_update(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, u32 changed); + +static inline void drv_sta_rate_tbl_update(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_rate_tbl_update(local, sdata, sta); + if (local->ops->sta_rate_tbl_update) + 
local->ops->sta_rate_tbl_update(&local->hw, &sdata->vif, sta); + + trace_drv_return_void(local); +} + +static inline void drv_sta_statistics(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct station_info *sinfo) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_statistics(local, sdata, sta); + if (local->ops->sta_statistics) + local->ops->sta_statistics(&local->hw, &sdata->vif, sta, sinfo); + trace_drv_return_void(local); +} + +int drv_conf_tx(struct ieee80211_local *local, + struct ieee80211_link_data *link, u16 ac, + const struct ieee80211_tx_queue_params *params); + +u64 drv_get_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +void drv_set_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u64 tsf); +void drv_offset_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + s64 offset); +void drv_reset_tsf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); + +static inline int drv_tx_last_beacon(struct ieee80211_local *local) +{ + int ret = 0; /* default unsupported op for less congestion */ + + might_sleep(); + + trace_drv_tx_last_beacon(local); + if (local->ops->tx_last_beacon) + ret = local->ops->tx_last_beacon(&local->hw); + trace_drv_return_int(local, ret); + return ret; +} + +int drv_ampdu_action(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_ampdu_params *params); + +static inline int drv_get_survey(struct ieee80211_local *local, int idx, + struct survey_info *survey) +{ + int ret = -EOPNOTSUPP; + + trace_drv_get_survey(local, idx, survey); + + if (local->ops->get_survey) + ret = local->ops->get_survey(&local->hw, idx, survey); + + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_rfkill_poll(struct ieee80211_local *local) +{ + might_sleep(); + + if (local->ops->rfkill_poll) + local->ops->rfkill_poll(&local->hw); +} + +static inline void drv_flush(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u32 queues, bool drop) +{ + struct ieee80211_vif *vif = sdata ? 
&sdata->vif : NULL; + + might_sleep(); + + if (sdata && !check_sdata_in_driver(sdata)) + return; + + trace_drv_flush(local, queues, drop); + if (local->ops->flush) + local->ops->flush(&local->hw, vif, queues, drop); + trace_drv_return_void(local); +} + +static inline void drv_channel_switch(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch) +{ + might_sleep(); + + trace_drv_channel_switch(local, sdata, ch_switch); + local->ops->channel_switch(&local->hw, &sdata->vif, ch_switch); + trace_drv_return_void(local); +} + + +static inline int drv_set_antenna(struct ieee80211_local *local, + u32 tx_ant, u32 rx_ant) +{ + int ret = -EOPNOTSUPP; + might_sleep(); + if (local->ops->set_antenna) + ret = local->ops->set_antenna(&local->hw, tx_ant, rx_ant); + trace_drv_set_antenna(local, tx_ant, rx_ant, ret); + return ret; +} + +static inline int drv_get_antenna(struct ieee80211_local *local, + u32 *tx_ant, u32 *rx_ant) +{ + int ret = -EOPNOTSUPP; + might_sleep(); + if (local->ops->get_antenna) + ret = local->ops->get_antenna(&local->hw, tx_ant, rx_ant); + trace_drv_get_antenna(local, *tx_ant, *rx_ant, ret); + return ret; +} + +static inline int drv_remain_on_channel(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel *chan, + unsigned int duration, + enum ieee80211_roc_type type) +{ + int ret; + + might_sleep(); + + trace_drv_remain_on_channel(local, sdata, chan, duration, type); + ret = local->ops->remain_on_channel(&local->hw, &sdata->vif, + chan, duration, type); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline int +drv_cancel_remain_on_channel(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + int ret; + + might_sleep(); + + trace_drv_cancel_remain_on_channel(local, sdata); + ret = local->ops->cancel_remain_on_channel(&local->hw, &sdata->vif); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline int drv_set_ringparam(struct ieee80211_local *local, + u32 tx, u32 rx) +{ + int ret = -ENOTSUPP; + + might_sleep(); + + trace_drv_set_ringparam(local, tx, rx); + if (local->ops->set_ringparam) + ret = local->ops->set_ringparam(&local->hw, tx, rx); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_get_ringparam(struct ieee80211_local *local, + u32 *tx, u32 *tx_max, u32 *rx, u32 *rx_max) +{ + might_sleep(); + + trace_drv_get_ringparam(local, tx, tx_max, rx, rx_max); + if (local->ops->get_ringparam) + local->ops->get_ringparam(&local->hw, tx, tx_max, rx, rx_max); + trace_drv_return_void(local); +} + +static inline bool drv_tx_frames_pending(struct ieee80211_local *local) +{ + bool ret = false; + + might_sleep(); + + trace_drv_tx_frames_pending(local); + if (local->ops->tx_frames_pending) + ret = local->ops->tx_frames_pending(&local->hw); + trace_drv_return_bool(local, ret); + + return ret; +} + +static inline int drv_set_bitrate_mask(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct cfg80211_bitrate_mask *mask) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_set_bitrate_mask(local, sdata, mask); + if (local->ops->set_bitrate_mask) + ret = local->ops->set_bitrate_mask(&local->hw, + &sdata->vif, mask); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_set_rekey_data(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_gtk_rekey_data *data) +{ + 
if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_set_rekey_data(local, sdata, data); + if (local->ops->set_rekey_data) + local->ops->set_rekey_data(&local->hw, &sdata->vif, data); + trace_drv_return_void(local); +} + +static inline void drv_event_callback(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct ieee80211_event *event) +{ + trace_drv_event_callback(local, sdata, event); + if (local->ops->event_callback) + local->ops->event_callback(&local->hw, &sdata->vif, event); + trace_drv_return_void(local); +} + +static inline void +drv_release_buffered_frames(struct ieee80211_local *local, + struct sta_info *sta, u16 tids, int num_frames, + enum ieee80211_frame_release_type reason, + bool more_data) +{ + trace_drv_release_buffered_frames(local, &sta->sta, tids, num_frames, + reason, more_data); + if (local->ops->release_buffered_frames) + local->ops->release_buffered_frames(&local->hw, &sta->sta, tids, + num_frames, reason, + more_data); + trace_drv_return_void(local); +} + +static inline void +drv_allow_buffered_frames(struct ieee80211_local *local, + struct sta_info *sta, u16 tids, int num_frames, + enum ieee80211_frame_release_type reason, + bool more_data) +{ + trace_drv_allow_buffered_frames(local, &sta->sta, tids, num_frames, + reason, more_data); + if (local->ops->allow_buffered_frames) + local->ops->allow_buffered_frames(&local->hw, &sta->sta, + tids, num_frames, reason, + more_data); + trace_drv_return_void(local); +} + +static inline void drv_mgd_prepare_tx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_prep_tx_info *info) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_STATION); + + trace_drv_mgd_prepare_tx(local, sdata, info->duration, + info->subtype, info->success); + if (local->ops->mgd_prepare_tx) + local->ops->mgd_prepare_tx(&local->hw, &sdata->vif, info); + trace_drv_return_void(local); +} + +static inline void drv_mgd_complete_tx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_prep_tx_info *info) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_STATION); + + trace_drv_mgd_complete_tx(local, sdata, info->duration, + info->subtype, info->success); + if (local->ops->mgd_complete_tx) + local->ops->mgd_complete_tx(&local->hw, &sdata->vif, info); + trace_drv_return_void(local); +} + +static inline void +drv_mgd_protect_tdls_discover(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_STATION); + + trace_drv_mgd_protect_tdls_discover(local, sdata); + if (local->ops->mgd_protect_tdls_discover) + local->ops->mgd_protect_tdls_discover(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline int drv_add_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + + trace_drv_add_chanctx(local, ctx); + if (local->ops->add_chanctx) + ret = local->ops->add_chanctx(&local->hw, &ctx->conf); + trace_drv_return_int(local, ret); + if (!ret) + ctx->driver_present = true; + + return ret; +} + +static inline void drv_remove_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + might_sleep(); + + if (WARN_ON(!ctx->driver_present)) + return; + + trace_drv_remove_chanctx(local, ctx); + if 
(local->ops->remove_chanctx) + local->ops->remove_chanctx(&local->hw, &ctx->conf); + trace_drv_return_void(local); + ctx->driver_present = false; +} + +static inline void drv_change_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + u32 changed) +{ + might_sleep(); + + trace_drv_change_chanctx(local, ctx, changed); + if (local->ops->change_chanctx) { + WARN_ON_ONCE(!ctx->driver_present); + local->ops->change_chanctx(&local->hw, &ctx->conf, changed); + } + trace_drv_return_void(local); +} + +static inline void drv_verify_link_exists(struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf) +{ + /* deflink always exists, so need to check only for other links */ + if (sdata->deflink.conf != link_conf) + sdata_assert_lock(sdata); +} + +int drv_assign_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx); +void drv_unassign_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx); +int drv_switch_vif_chanctx(struct ieee80211_local *local, + struct ieee80211_vif_chanctx_switch *vifs, + int n_vifs, enum ieee80211_chanctx_switch_mode mode); + +static inline int drv_start_ap(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf) +{ + int ret = 0; + + /* make sure link_conf is protected */ + drv_verify_link_exists(sdata, link_conf); + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_start_ap(local, sdata, link_conf); + if (local->ops->start_ap) + ret = local->ops->start_ap(&local->hw, &sdata->vif, link_conf); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_stop_ap(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf) +{ + /* make sure link_conf is protected */ + drv_verify_link_exists(sdata, link_conf); + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_stop_ap(local, sdata, link_conf); + if (local->ops->stop_ap) + local->ops->stop_ap(&local->hw, &sdata->vif, link_conf); + trace_drv_return_void(local); +} + +static inline void +drv_reconfig_complete(struct ieee80211_local *local, + enum ieee80211_reconfig_type reconfig_type) +{ + might_sleep(); + + trace_drv_reconfig_complete(local, reconfig_type); + if (local->ops->reconfig_complete) + local->ops->reconfig_complete(&local->hw, reconfig_type); + trace_drv_return_void(local); +} + +static inline void +drv_set_default_unicast_key(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + int key_idx) +{ + if (!check_sdata_in_driver(sdata)) + return; + + WARN_ON_ONCE(key_idx < -1 || key_idx > 3); + + trace_drv_set_default_unicast_key(local, sdata, key_idx); + if (local->ops->set_default_unicast_key) + local->ops->set_default_unicast_key(&local->hw, &sdata->vif, + key_idx); + trace_drv_return_void(local); +} + +#if IS_ENABLED(CONFIG_IPV6) +static inline void drv_ipv6_addr_change(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct inet6_dev *idev) +{ + trace_drv_ipv6_addr_change(local, sdata); + if (local->ops->ipv6_addr_change) + local->ops->ipv6_addr_change(&local->hw, &sdata->vif, idev); + trace_drv_return_void(local); +} +#endif + +static inline void +drv_channel_switch_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_local *local = 
sdata->local; + + if (local->ops->channel_switch_beacon) { + trace_drv_channel_switch_beacon(local, sdata, chandef); + local->ops->channel_switch_beacon(&local->hw, &sdata->vif, + chandef); + } +} + +static inline int +drv_pre_channel_switch(struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch) +{ + struct ieee80211_local *local = sdata->local; + int ret = 0; + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_pre_channel_switch(local, sdata, ch_switch); + if (local->ops->pre_channel_switch) + ret = local->ops->pre_channel_switch(&local->hw, &sdata->vif, + ch_switch); + trace_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_post_channel_switch(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + int ret = 0; + + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_post_channel_switch(local, sdata); + if (local->ops->post_channel_switch) + ret = local->ops->post_channel_switch(&local->hw, &sdata->vif); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void +drv_abort_channel_switch(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_abort_channel_switch(local, sdata); + + if (local->ops->abort_channel_switch) + local->ops->abort_channel_switch(&local->hw, &sdata->vif); +} + +static inline void +drv_channel_switch_rx_beacon(struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch) +{ + struct ieee80211_local *local = sdata->local; + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_channel_switch_rx_beacon(local, sdata, ch_switch); + if (local->ops->channel_switch_rx_beacon) + local->ops->channel_switch_rx_beacon(&local->hw, &sdata->vif, + ch_switch); +} + +static inline int drv_join_ibss(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + int ret = 0; + + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_join_ibss(local, sdata, &sdata->vif.bss_conf); + if (local->ops->join_ibss) + ret = local->ops->join_ibss(&local->hw, &sdata->vif); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_leave_ibss(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_leave_ibss(local, sdata); + if (local->ops->leave_ibss) + local->ops->leave_ibss(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline u32 drv_get_expected_throughput(struct ieee80211_local *local, + struct sta_info *sta) +{ + u32 ret = 0; + + trace_drv_get_expected_throughput(&sta->sta); + if (local->ops->get_expected_throughput && sta->uploaded) + ret = local->ops->get_expected_throughput(&local->hw, &sta->sta); + trace_drv_return_u32(local, ret); + + return ret; +} + +static inline int drv_get_txpower(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, int *dbm) +{ + int ret; + + if (!local->ops->get_txpower) + return -EOPNOTSUPP; + + ret = local->ops->get_txpower(&local->hw, &sdata->vif, dbm); + trace_drv_get_txpower(local, sdata, *dbm, ret); + + return ret; +} + +static inline int +drv_tdls_channel_switch(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, u8 oper_class, + struct cfg80211_chan_def *chandef, + struct sk_buff *tmpl_skb, u32 ch_sw_tm_ie) +{ + int ret; + + might_sleep(); + if 
(!check_sdata_in_driver(sdata)) + return -EIO; + + if (!local->ops->tdls_channel_switch) + return -EOPNOTSUPP; + + trace_drv_tdls_channel_switch(local, sdata, sta, oper_class, chandef); + ret = local->ops->tdls_channel_switch(&local->hw, &sdata->vif, sta, + oper_class, chandef, tmpl_skb, + ch_sw_tm_ie); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void +drv_tdls_cancel_channel_switch(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta) +{ + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return; + + if (!local->ops->tdls_cancel_channel_switch) + return; + + trace_drv_tdls_cancel_channel_switch(local, sdata, sta); + local->ops->tdls_cancel_channel_switch(&local->hw, &sdata->vif, sta); + trace_drv_return_void(local); +} + +static inline void +drv_tdls_recv_channel_switch(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_tdls_ch_sw_params *params) +{ + trace_drv_tdls_recv_channel_switch(local, sdata, params); + if (local->ops->tdls_recv_channel_switch) + local->ops->tdls_recv_channel_switch(&local->hw, &sdata->vif, + params); + trace_drv_return_void(local); +} + +static inline void drv_wake_tx_queue(struct ieee80211_local *local, + struct txq_info *txq) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(txq->txq.vif); + + /* In reconfig don't transmit now, but mark for waking later */ + if (local->in_reconfig) { + set_bit(IEEE80211_TXQ_DIRTY, &txq->flags); + return; + } + + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_wake_tx_queue(local, sdata, txq); + local->ops->wake_tx_queue(&local->hw, &txq->txq); +} + +static inline void schedule_and_wake_txq(struct ieee80211_local *local, + struct txq_info *txqi) +{ + ieee80211_schedule_txq(&local->hw, &txqi->txq); + drv_wake_tx_queue(local, txqi); +} + +static inline int drv_can_aggregate_in_amsdu(struct ieee80211_local *local, + struct sk_buff *head, + struct sk_buff *skb) +{ + if (!local->ops->can_aggregate_in_amsdu) + return true; + + return local->ops->can_aggregate_in_amsdu(&local->hw, head, skb); +} + +static inline int +drv_get_ftm_responder_stats(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + u32 ret = -EOPNOTSUPP; + + if (local->ops->get_ftm_responder_stats) + ret = local->ops->get_ftm_responder_stats(&local->hw, + &sdata->vif, + ftm_stats); + trace_drv_get_ftm_responder_stats(local, sdata, ftm_stats); + + return ret; +} + +static inline int drv_start_pmsr(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_pmsr_request *request) +{ + int ret = -EOPNOTSUPP; + + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_start_pmsr(local, sdata); + + if (local->ops->start_pmsr) + ret = local->ops->start_pmsr(&local->hw, &sdata->vif, request); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_abort_pmsr(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_pmsr_request *request) +{ + trace_drv_abort_pmsr(local, sdata); + + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return; + + if (local->ops->abort_pmsr) + local->ops->abort_pmsr(&local->hw, &sdata->vif, request); + trace_drv_return_void(local); +} + +static inline int drv_start_nan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_nan_conf *conf) +{ + int ret; + + might_sleep(); + check_sdata_in_driver(sdata); + + 
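/* NAN interfaces are only available when the driver implements the NAN + * ops, so start_nan/stop_nan are called here without a NULL check. + */ +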
trace_drv_start_nan(local, sdata, conf); + ret = local->ops->start_nan(&local->hw, &sdata->vif, conf); + trace_drv_return_int(local, ret); + return ret; +} + +static inline void drv_stop_nan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + check_sdata_in_driver(sdata); + + trace_drv_stop_nan(local, sdata); + local->ops->stop_nan(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline int drv_nan_change_conf(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_nan_conf *conf, + u32 changes) +{ + int ret; + + might_sleep(); + check_sdata_in_driver(sdata); + + if (!local->ops->nan_change_conf) + return -EOPNOTSUPP; + + trace_drv_nan_change_conf(local, sdata, conf, changes); + ret = local->ops->nan_change_conf(&local->hw, &sdata->vif, conf, + changes); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline int drv_add_nan_func(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct cfg80211_nan_func *nan_func) +{ + int ret; + + might_sleep(); + check_sdata_in_driver(sdata); + + if (!local->ops->add_nan_func) + return -EOPNOTSUPP; + + trace_drv_add_nan_func(local, sdata, nan_func); + ret = local->ops->add_nan_func(&local->hw, &sdata->vif, nan_func); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_del_nan_func(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u8 instance_id) +{ + might_sleep(); + check_sdata_in_driver(sdata); + + trace_drv_del_nan_func(local, sdata, instance_id); + if (local->ops->del_nan_func) + local->ops->del_nan_func(&local->hw, &sdata->vif, instance_id); + trace_drv_return_void(local); +} + +static inline int drv_set_tid_config(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct cfg80211_tid_config *tid_conf) +{ + int ret; + + might_sleep(); + ret = local->ops->set_tid_config(&local->hw, &sdata->vif, sta, + tid_conf); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline int drv_reset_tid_config(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, u8 tids) +{ + int ret; + + might_sleep(); + ret = local->ops->reset_tid_config(&local->hw, &sdata->vif, sta, tids); + trace_drv_return_int(local, ret); + + return ret; +} + +static inline void drv_update_vif_offload(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + might_sleep(); + check_sdata_in_driver(sdata); + + if (!local->ops->update_vif_offload) + return; + + trace_drv_update_vif_offload(local, sdata); + local->ops->update_vif_offload(&local->hw, &sdata->vif); + trace_drv_return_void(local); +} + +static inline void drv_sta_set_4addr(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, bool enabled) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_set_4addr(local, sdata, sta, enabled); + if (local->ops->sta_set_4addr) + local->ops->sta_set_4addr(&local->hw, &sdata->vif, sta, enabled); + trace_drv_return_void(local); +} + +static inline void drv_sta_set_decap_offload(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + bool enabled) +{ + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return; + + trace_drv_sta_set_decap_offload(local, sdata, sta, enabled); + if (local->ops->sta_set_decap_offload) + 
local->ops->sta_set_decap_offload(&local->hw, &sdata->vif, sta, + enabled); + trace_drv_return_void(local); +} + +static inline void drv_add_twt_setup(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct ieee80211_twt_setup *twt) +{ + struct ieee80211_twt_params *twt_agrt; + + might_sleep(); + + if (!check_sdata_in_driver(sdata)) + return; + + twt_agrt = (void *)twt->params; + + trace_drv_add_twt_setup(local, sta, twt, twt_agrt); + local->ops->add_twt_setup(&local->hw, sta, twt); + trace_drv_return_void(local); +} + +static inline void drv_twt_teardown_request(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + u8 flowid) +{ + might_sleep(); + if (!check_sdata_in_driver(sdata)) + return; + + if (!local->ops->twt_teardown_request) + return; + + trace_drv_twt_teardown_request(local, sta, flowid); + local->ops->twt_teardown_request(&local->hw, sta, flowid); + trace_drv_return_void(local); +} + +static inline int drv_net_fill_forward_path(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + int ret = -EOPNOTSUPP; + + sdata = get_bss_sdata(sdata); + if (!check_sdata_in_driver(sdata)) + return -EIO; + + trace_drv_net_fill_forward_path(local, sdata, sta); + if (local->ops->net_fill_forward_path) + ret = local->ops->net_fill_forward_path(&local->hw, + &sdata->vif, sta, + ctx, path); + trace_drv_return_int(local, ret); + + return ret; +} + +int drv_change_vif_links(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 old_links, u16 new_links, + struct ieee80211_bss_conf *old[IEEE80211_MLD_MAX_NUM_LINKS]); +int drv_change_sta_links(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + u16 old_links, u16 new_links); + +#endif /* __MAC80211_DRIVER_OPS */ diff --git a/net/mac80211/eht.c b/net/mac80211/eht.c new file mode 100644 index 000000000..18bc6b78b --- /dev/null +++ b/net/mac80211/eht.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * EHT handling + * + * Copyright(c) 2021-2022 Intel Corporation + */ + +#include "ieee80211_i.h" + +void +ieee80211_eht_cap_ie_to_sta_eht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const u8 *he_cap_ie, u8 he_cap_len, + const struct ieee80211_eht_cap_elem *eht_cap_ie_elem, + u8 eht_cap_len, + struct link_sta_info *link_sta) +{ + struct ieee80211_sta_eht_cap *eht_cap = &link_sta->pub->eht_cap; + struct ieee80211_he_cap_elem *he_cap_ie_elem = (void *)he_cap_ie; + u8 eht_ppe_size = 0; + u8 mcs_nss_size; + u8 eht_total_size = sizeof(eht_cap->eht_cap_elem); + u8 *pos = (u8 *)eht_cap_ie_elem; + + memset(eht_cap, 0, sizeof(*eht_cap)); + + if (!eht_cap_ie_elem || + !ieee80211_get_eht_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif))) + return; + + mcs_nss_size = ieee80211_eht_mcs_nss_size(he_cap_ie_elem, + &eht_cap_ie_elem->fixed, + sdata->vif.type == + NL80211_IFTYPE_STATION); + + eht_total_size += mcs_nss_size; + + /* Calculate the PPE thresholds length only if the header is present */ + if (eht_cap_ie_elem->fixed.phy_cap_info[5] & + IEEE80211_EHT_PHY_CAP5_PPE_THRESHOLD_PRESENT) { + u16 eht_ppe_hdr; + + if (eht_cap_len < eht_total_size + sizeof(u16)) + return; + + eht_ppe_hdr = get_unaligned_le16(eht_cap_ie_elem->optional + mcs_nss_size); + eht_ppe_size = + ieee80211_eht_ppe_size(eht_ppe_hdr, + 
eht_cap_ie_elem->fixed.phy_cap_info); + eht_total_size += eht_ppe_size; + + /* we calculate as if NSS > 8 are valid, but don't handle that */ + if (eht_ppe_size > sizeof(eht_cap->eht_ppe_thres)) + return; + } + + if (eht_cap_len < eht_total_size) + return; + + /* Copy the static portion of the EHT capabilities */ + memcpy(&eht_cap->eht_cap_elem, pos, sizeof(eht_cap->eht_cap_elem)); + pos += sizeof(eht_cap->eht_cap_elem); + + /* Copy MCS/NSS which depends on the peer capabilities */ + memset(&eht_cap->eht_mcs_nss_supp, 0, + sizeof(eht_cap->eht_mcs_nss_supp)); + memcpy(&eht_cap->eht_mcs_nss_supp, pos, mcs_nss_size); + + if (eht_ppe_size) + memcpy(eht_cap->eht_ppe_thres, + &eht_cap_ie_elem->optional[mcs_nss_size], + eht_ppe_size); + + eht_cap->has_eht = true; + + link_sta->cur_max_bandwidth = ieee80211_sta_cap_rx_bw(link_sta); + link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta); +} diff --git a/net/mac80211/ethtool.c b/net/mac80211/ethtool.c new file mode 100644 index 000000000..a3830d925 --- /dev/null +++ b/net/mac80211/ethtool.c @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * mac80211 ethtool hooks for cfg80211 + * + * Copied from cfg.c - originally + * Copyright 2006-2010 Johannes Berg + * Copyright 2014 Intel Corporation (Author: Johannes Berg) + * Copyright (C) 2018, 2022 Intel Corporation + */ +#include +#include +#include "ieee80211_i.h" +#include "sta_info.h" +#include "driver-ops.h" + +static int ieee80211_set_ringparam(struct net_device *dev, + struct ethtool_ringparam *rp, + struct kernel_ethtool_ringparam *kernel_rp, + struct netlink_ext_ack *extack) +{ + struct ieee80211_local *local = wiphy_priv(dev->ieee80211_ptr->wiphy); + + if (rp->rx_mini_pending != 0 || rp->rx_jumbo_pending != 0) + return -EINVAL; + + return drv_set_ringparam(local, rp->tx_pending, rp->rx_pending); +} + +static void ieee80211_get_ringparam(struct net_device *dev, + struct ethtool_ringparam *rp, + struct kernel_ethtool_ringparam *kernel_rp, + struct netlink_ext_ack *extack) +{ + struct ieee80211_local *local = wiphy_priv(dev->ieee80211_ptr->wiphy); + + memset(rp, 0, sizeof(*rp)); + + drv_get_ringparam(local, &rp->tx_pending, &rp->tx_max_pending, + &rp->rx_pending, &rp->rx_max_pending); +} + +static const char ieee80211_gstrings_sta_stats[][ETH_GSTRING_LEN] = { + "rx_packets", "rx_bytes", + "rx_duplicates", "rx_fragments", "rx_dropped", + "tx_packets", "tx_bytes", + "tx_filtered", "tx_retry_failed", "tx_retries", + "sta_state", "txrate", "rxrate", "signal", + "channel", "noise", "ch_time", "ch_time_busy", + "ch_time_ext_busy", "ch_time_rx", "ch_time_tx" +}; +#define STA_STATS_LEN ARRAY_SIZE(ieee80211_gstrings_sta_stats) + +static int ieee80211_get_sset_count(struct net_device *dev, int sset) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + int rv = 0; + + if (sset == ETH_SS_STATS) + rv += STA_STATS_LEN; + + rv += drv_get_et_sset_count(sdata, sset); + + if (rv == 0) + return -EOPNOTSUPP; + return rv; +} + +static void ieee80211_get_stats(struct net_device *dev, + struct ethtool_stats *stats, + u64 *data) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_channel *channel; + struct sta_info *sta; + struct ieee80211_local *local = sdata->local; + struct station_info sinfo; + struct survey_info survey; + int i, q; +#define STA_STATS_SURVEY_LEN 7 + + memset(data, 0, sizeof(u64) * STA_STATS_LEN); + +#define ADD_STA_STATS(sta) \ + do { \ + data[i++] += sinfo.rx_packets; \ 
+ data[i++] += sinfo.rx_bytes; \ + data[i++] += (sta)->rx_stats.num_duplicates; \ + data[i++] += (sta)->rx_stats.fragments; \ + data[i++] += sinfo.rx_dropped_misc; \ + \ + data[i++] += sinfo.tx_packets; \ + data[i++] += sinfo.tx_bytes; \ + data[i++] += (sta)->status_stats.filtered; \ + data[i++] += sinfo.tx_failed; \ + data[i++] += sinfo.tx_retries; \ + } while (0) + + /* For Managed stations, find the single station based on BSSID + * and use that. For interface types, iterate through all available + * stations and add stats for any station that is assigned to this + * network device. + */ + + mutex_lock(&local->sta_mtx); + + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + sta = sta_info_get_bss(sdata, sdata->deflink.u.mgd.bssid); + + if (!(sta && !WARN_ON(sta->sdata->dev != dev))) + goto do_survey; + + memset(&sinfo, 0, sizeof(sinfo)); + sta_set_sinfo(sta, &sinfo, false); + + i = 0; + ADD_STA_STATS(&sta->deflink); + + data[i++] = sta->sta_state; + + + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE)) + data[i] = 100000ULL * + cfg80211_calculate_bitrate(&sinfo.txrate); + i++; + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_RX_BITRATE)) + data[i] = 100000ULL * + cfg80211_calculate_bitrate(&sinfo.rxrate); + i++; + + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_SIGNAL_AVG)) + data[i] = (u8)sinfo.signal_avg; + i++; + } else { + list_for_each_entry(sta, &local->sta_list, list) { + /* Make sure this station belongs to the proper dev */ + if (sta->sdata->dev != dev) + continue; + + memset(&sinfo, 0, sizeof(sinfo)); + sta_set_sinfo(sta, &sinfo, false); + i = 0; + ADD_STA_STATS(&sta->deflink); + } + } + +do_survey: + i = STA_STATS_LEN - STA_STATS_SURVEY_LEN; + /* Get survey stats for current channel */ + survey.filled = 0; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (chanctx_conf) + channel = chanctx_conf->def.chan; + else + channel = NULL; + rcu_read_unlock(); + + if (channel) { + q = 0; + do { + survey.filled = 0; + if (drv_get_survey(local, q, &survey) != 0) { + survey.filled = 0; + break; + } + q++; + } while (channel != survey.channel); + } + + if (survey.filled) + data[i++] = survey.channel->center_freq; + else + data[i++] = 0; + if (survey.filled & SURVEY_INFO_NOISE_DBM) + data[i++] = (u8)survey.noise; + else + data[i++] = -1LL; + if (survey.filled & SURVEY_INFO_TIME) + data[i++] = survey.time; + else + data[i++] = -1LL; + if (survey.filled & SURVEY_INFO_TIME_BUSY) + data[i++] = survey.time_busy; + else + data[i++] = -1LL; + if (survey.filled & SURVEY_INFO_TIME_EXT_BUSY) + data[i++] = survey.time_ext_busy; + else + data[i++] = -1LL; + if (survey.filled & SURVEY_INFO_TIME_RX) + data[i++] = survey.time_rx; + else + data[i++] = -1LL; + if (survey.filled & SURVEY_INFO_TIME_TX) + data[i++] = survey.time_tx; + else + data[i++] = -1LL; + + mutex_unlock(&local->sta_mtx); + + if (WARN_ON(i != STA_STATS_LEN)) + return; + + drv_get_et_stats(sdata, stats, &(data[STA_STATS_LEN])); +} + +static void ieee80211_get_strings(struct net_device *dev, u32 sset, u8 *data) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + int sz_sta_stats = 0; + + if (sset == ETH_SS_STATS) { + sz_sta_stats = sizeof(ieee80211_gstrings_sta_stats); + memcpy(data, ieee80211_gstrings_sta_stats, sz_sta_stats); + } + drv_get_et_strings(sdata, sset, &(data[sz_sta_stats])); +} + +static int ieee80211_get_regs_len(struct net_device *dev) +{ + return 0; +} + +static void ieee80211_get_regs(struct net_device *dev, + struct ethtool_regs *regs, + void *data) +{ + 
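/* mac80211 has no device registers to dump; expose only the wiphy + * hardware version and report a zero-length register blob. + */ +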
struct wireless_dev *wdev = dev->ieee80211_ptr; + + regs->version = wdev->wiphy->hw_version; + regs->len = 0; +} + +const struct ethtool_ops ieee80211_ethtool_ops = { + .get_drvinfo = cfg80211_get_drvinfo, + .get_regs_len = ieee80211_get_regs_len, + .get_regs = ieee80211_get_regs, + .get_link = ethtool_op_get_link, + .get_ringparam = ieee80211_get_ringparam, + .set_ringparam = ieee80211_set_ringparam, + .get_strings = ieee80211_get_strings, + .get_ethtool_stats = ieee80211_get_stats, + .get_sset_count = ieee80211_get_sset_count, +}; diff --git a/net/mac80211/fils_aead.c b/net/mac80211/fils_aead.c new file mode 100644 index 000000000..e1d4cfd99 --- /dev/null +++ b/net/mac80211/fils_aead.c @@ -0,0 +1,333 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * FILS AEAD for (Re)Association Request/Response frames + * Copyright 2016, Qualcomm Atheros, Inc. + */ + +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "aes_cmac.h" +#include "fils_aead.h" + +static void gf_mulx(u8 *pad) +{ + u64 a = get_unaligned_be64(pad); + u64 b = get_unaligned_be64(pad + 8); + + put_unaligned_be64((a << 1) | (b >> 63), pad); + put_unaligned_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0), pad + 8); +} + +static int aes_s2v(struct crypto_shash *tfm, + size_t num_elem, const u8 *addr[], size_t len[], u8 *v) +{ + u8 d[AES_BLOCK_SIZE], tmp[AES_BLOCK_SIZE] = {}; + SHASH_DESC_ON_STACK(desc, tfm); + size_t i; + + desc->tfm = tfm; + + /* D = AES-CMAC(K, ) */ + crypto_shash_digest(desc, tmp, AES_BLOCK_SIZE, d); + + for (i = 0; i < num_elem - 1; i++) { + /* D = dbl(D) xor AES_CMAC(K, Si) */ + gf_mulx(d); /* dbl */ + crypto_shash_digest(desc, addr[i], len[i], tmp); + crypto_xor(d, tmp, AES_BLOCK_SIZE); + } + + crypto_shash_init(desc); + + if (len[i] >= AES_BLOCK_SIZE) { + /* len(Sn) >= 128 */ + /* T = Sn xorend D */ + crypto_shash_update(desc, addr[i], len[i] - AES_BLOCK_SIZE); + crypto_xor(d, addr[i] + len[i] - AES_BLOCK_SIZE, + AES_BLOCK_SIZE); + } else { + /* len(Sn) < 128 */ + /* T = dbl(D) xor pad(Sn) */ + gf_mulx(d); /* dbl */ + crypto_xor(d, addr[i], len[i]); + d[len[i]] ^= 0x80; + } + /* V = AES-CMAC(K, T) */ + crypto_shash_finup(desc, d, AES_BLOCK_SIZE, v); + + return 0; +} + +/* Note: addr[] and len[] needs to have one extra slot at the end. */ +static int aes_siv_encrypt(const u8 *key, size_t key_len, + const u8 *plain, size_t plain_len, + size_t num_elem, const u8 *addr[], + size_t len[], u8 *out) +{ + u8 v[AES_BLOCK_SIZE]; + struct crypto_shash *tfm; + struct crypto_skcipher *tfm2; + struct skcipher_request *req; + int res; + struct scatterlist src[1], dst[1]; + u8 *tmp; + + key_len /= 2; /* S2V key || CTR key */ + + addr[num_elem] = plain; + len[num_elem] = plain_len; + num_elem++; + + /* S2V */ + + tfm = crypto_alloc_shash("cmac(aes)", 0, 0); + if (IS_ERR(tfm)) + return PTR_ERR(tfm); + /* K1 for S2V */ + res = crypto_shash_setkey(tfm, key, key_len); + if (!res) + res = aes_s2v(tfm, num_elem, addr, len, v); + crypto_free_shash(tfm); + if (res) + return res; + + /* Use a temporary buffer of the plaintext to handle need for + * overwriting this during AES-CTR. 
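The buffer written below + * uses the RFC 5297 SIV layout: the 16-byte synthetic IV V (which also + * serves as the integrity tag on decryption) followed by the CTR-encrypted + * data.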
+ */ + tmp = kmemdup(plain, plain_len, GFP_KERNEL); + if (!tmp) + return -ENOMEM; + + /* IV for CTR before encrypted data */ + memcpy(out, v, AES_BLOCK_SIZE); + + /* Synthetic IV to be used as the initial counter in CTR: + * Q = V bitand (1^64 || 0^1 || 1^31 || 0^1 || 1^31) + */ + v[8] &= 0x7f; + v[12] &= 0x7f; + + /* CTR */ + + tfm2 = crypto_alloc_skcipher("ctr(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm2)) { + kfree(tmp); + return PTR_ERR(tfm2); + } + /* K2 for CTR */ + res = crypto_skcipher_setkey(tfm2, key + key_len, key_len); + if (res) + goto fail; + + req = skcipher_request_alloc(tfm2, GFP_KERNEL); + if (!req) { + res = -ENOMEM; + goto fail; + } + + sg_init_one(src, tmp, plain_len); + sg_init_one(dst, out + AES_BLOCK_SIZE, plain_len); + skcipher_request_set_crypt(req, src, dst, plain_len, v); + res = crypto_skcipher_encrypt(req); + skcipher_request_free(req); +fail: + kfree(tmp); + crypto_free_skcipher(tfm2); + return res; +} + +/* Note: addr[] and len[] needs to have one extra slot at the end. */ +static int aes_siv_decrypt(const u8 *key, size_t key_len, + const u8 *iv_crypt, size_t iv_c_len, + size_t num_elem, const u8 *addr[], size_t len[], + u8 *out) +{ + struct crypto_shash *tfm; + struct crypto_skcipher *tfm2; + struct skcipher_request *req; + struct scatterlist src[1], dst[1]; + size_t crypt_len; + int res; + u8 frame_iv[AES_BLOCK_SIZE], iv[AES_BLOCK_SIZE]; + u8 check[AES_BLOCK_SIZE]; + + crypt_len = iv_c_len - AES_BLOCK_SIZE; + key_len /= 2; /* S2V key || CTR key */ + addr[num_elem] = out; + len[num_elem] = crypt_len; + num_elem++; + + memcpy(iv, iv_crypt, AES_BLOCK_SIZE); + memcpy(frame_iv, iv_crypt, AES_BLOCK_SIZE); + + /* Synthetic IV to be used as the initial counter in CTR: + * Q = V bitand (1^64 || 0^1 || 1^31 || 0^1 || 1^31) + */ + iv[8] &= 0x7f; + iv[12] &= 0x7f; + + /* CTR */ + + tfm2 = crypto_alloc_skcipher("ctr(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm2)) + return PTR_ERR(tfm2); + /* K2 for CTR */ + res = crypto_skcipher_setkey(tfm2, key + key_len, key_len); + if (res) { + crypto_free_skcipher(tfm2); + return res; + } + + req = skcipher_request_alloc(tfm2, GFP_KERNEL); + if (!req) { + crypto_free_skcipher(tfm2); + return -ENOMEM; + } + + sg_init_one(src, iv_crypt + AES_BLOCK_SIZE, crypt_len); + sg_init_one(dst, out, crypt_len); + skcipher_request_set_crypt(req, src, dst, crypt_len, iv); + res = crypto_skcipher_decrypt(req); + skcipher_request_free(req); + crypto_free_skcipher(tfm2); + if (res) + return res; + + /* S2V */ + + tfm = crypto_alloc_shash("cmac(aes)", 0, 0); + if (IS_ERR(tfm)) + return PTR_ERR(tfm); + /* K1 for S2V */ + res = crypto_shash_setkey(tfm, key, key_len); + if (!res) + res = aes_s2v(tfm, num_elem, addr, len, check); + crypto_free_shash(tfm); + if (res) + return res; + if (memcmp(check, frame_iv, AES_BLOCK_SIZE) != 0) + return -EINVAL; + return 0; +} + +int fils_encrypt_assoc_req(struct sk_buff *skb, + struct ieee80211_mgd_assoc_data *assoc_data) +{ + struct ieee80211_mgmt *mgmt = (void *)skb->data; + u8 *capab, *ies, *encr; + const u8 *addr[5 + 1]; + const struct element *session; + size_t len[5 + 1]; + size_t crypt_len; + + if (ieee80211_is_reassoc_req(mgmt->frame_control)) { + capab = (u8 *)&mgmt->u.reassoc_req.capab_info; + ies = mgmt->u.reassoc_req.variable; + } else { + capab = (u8 *)&mgmt->u.assoc_req.capab_info; + ies = mgmt->u.assoc_req.variable; + } + + session = cfg80211_find_ext_elem(WLAN_EID_EXT_FILS_SESSION, + ies, skb->data + skb->len - ies); + if (!session || session->datalen != 1 + 8) + return -EINVAL; + /* encrypt after 
FILS Session element */ + encr = (u8 *)session->data + 1 + 8; + + /* AES-SIV AAD vectors */ + + /* The STA's MAC address */ + addr[0] = mgmt->sa; + len[0] = ETH_ALEN; + /* The AP's BSSID */ + addr[1] = mgmt->da; + len[1] = ETH_ALEN; + /* The STA's nonce */ + addr[2] = assoc_data->fils_nonces; + len[2] = FILS_NONCE_LEN; + /* The AP's nonce */ + addr[3] = &assoc_data->fils_nonces[FILS_NONCE_LEN]; + len[3] = FILS_NONCE_LEN; + /* The (Re)Association Request frame from the Capability Information + * field to the FILS Session element (both inclusive). + */ + addr[4] = capab; + len[4] = encr - capab; + + crypt_len = skb->data + skb->len - encr; + skb_put(skb, AES_BLOCK_SIZE); + return aes_siv_encrypt(assoc_data->fils_kek, assoc_data->fils_kek_len, + encr, crypt_len, 5, addr, len, encr); +} + +int fils_decrypt_assoc_resp(struct ieee80211_sub_if_data *sdata, + u8 *frame, size_t *frame_len, + struct ieee80211_mgd_assoc_data *assoc_data) +{ + struct ieee80211_mgmt *mgmt = (void *)frame; + u8 *capab, *ies, *encr; + const u8 *addr[5 + 1]; + const struct element *session; + size_t len[5 + 1]; + int res; + size_t crypt_len; + + if (*frame_len < 24 + 6) + return -EINVAL; + + capab = (u8 *)&mgmt->u.assoc_resp.capab_info; + ies = mgmt->u.assoc_resp.variable; + session = cfg80211_find_ext_elem(WLAN_EID_EXT_FILS_SESSION, + ies, frame + *frame_len - ies); + if (!session || session->datalen != 1 + 8) { + mlme_dbg(sdata, + "No (valid) FILS Session element in (Re)Association Response frame from %pM", + mgmt->sa); + return -EINVAL; + } + /* decrypt after FILS Session element */ + encr = (u8 *)session->data + 1 + 8; + + /* AES-SIV AAD vectors */ + + /* The AP's BSSID */ + addr[0] = mgmt->sa; + len[0] = ETH_ALEN; + /* The STA's MAC address */ + addr[1] = mgmt->da; + len[1] = ETH_ALEN; + /* The AP's nonce */ + addr[2] = &assoc_data->fils_nonces[FILS_NONCE_LEN]; + len[2] = FILS_NONCE_LEN; + /* The STA's nonce */ + addr[3] = assoc_data->fils_nonces; + len[3] = FILS_NONCE_LEN; + /* The (Re)Association Response frame from the Capability Information + * field to the FILS Session element (both inclusive). + */ + addr[4] = capab; + len[4] = encr - capab; + + crypt_len = frame + *frame_len - encr; + if (crypt_len < AES_BLOCK_SIZE) { + mlme_dbg(sdata, + "Not enough room for AES-SIV data after FILS Session element in (Re)Association Response frame from %pM", + mgmt->sa); + return -EINVAL; + } + res = aes_siv_decrypt(assoc_data->fils_kek, assoc_data->fils_kek_len, + encr, crypt_len, 5, addr, len, encr); + if (res != 0) { + mlme_dbg(sdata, + "AES-SIV decryption of (Re)Association Response frame from %pM failed", + mgmt->sa); + return res; + } + *frame_len -= AES_BLOCK_SIZE; + return 0; +} diff --git a/net/mac80211/fils_aead.h b/net/mac80211/fils_aead.h new file mode 100644 index 000000000..c868153f8 --- /dev/null +++ b/net/mac80211/fils_aead.h @@ -0,0 +1,16 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * FILS AEAD for (Re)Association Request/Response frames + * Copyright 2016, Qualcomm Atheros, Inc. 
+ */ + +#ifndef FILS_AEAD_H +#define FILS_AEAD_H + +int fils_encrypt_assoc_req(struct sk_buff *skb, + struct ieee80211_mgd_assoc_data *assoc_data); +int fils_decrypt_assoc_resp(struct ieee80211_sub_if_data *sdata, + u8 *frame, size_t *frame_len, + struct ieee80211_mgd_assoc_data *assoc_data); + +#endif /* FILS_AEAD_H */ diff --git a/net/mac80211/he.c b/net/mac80211/he.c new file mode 100644 index 000000000..0322abae0 --- /dev/null +++ b/net/mac80211/he.c @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * HE handling + * + * Copyright(c) 2017 Intel Deutschland GmbH + * Copyright(c) 2019 - 2023 Intel Corporation + */ + +#include "ieee80211_i.h" + +static void +ieee80211_update_from_he_6ghz_capa(const struct ieee80211_he_6ghz_capa *he_6ghz_capa, + struct link_sta_info *link_sta) +{ + struct sta_info *sta = link_sta->sta; + enum ieee80211_smps_mode smps_mode; + + if (sta->sdata->vif.type == NL80211_IFTYPE_AP || + sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + switch (le16_get_bits(he_6ghz_capa->capa, + IEEE80211_HE_6GHZ_CAP_SM_PS)) { + case WLAN_HT_CAP_SM_PS_INVALID: + case WLAN_HT_CAP_SM_PS_STATIC: + smps_mode = IEEE80211_SMPS_STATIC; + break; + case WLAN_HT_CAP_SM_PS_DYNAMIC: + smps_mode = IEEE80211_SMPS_DYNAMIC; + break; + case WLAN_HT_CAP_SM_PS_DISABLED: + smps_mode = IEEE80211_SMPS_OFF; + break; + } + + link_sta->pub->smps_mode = smps_mode; + } else { + link_sta->pub->smps_mode = IEEE80211_SMPS_OFF; + } + + switch (le16_get_bits(he_6ghz_capa->capa, + IEEE80211_HE_6GHZ_CAP_MAX_MPDU_LEN)) { + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_11454: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_11454; + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_7991: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_7991; + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_3895: + default: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_3895; + break; + } + + ieee80211_sta_recalc_aggregates(&sta->sta); + + link_sta->pub->he_6ghz_capa = *he_6ghz_capa; +} + +static void ieee80211_he_mcs_disable(__le16 *he_mcs) +{ + u32 i; + + for (i = 0; i < 8; i++) + *he_mcs |= cpu_to_le16(IEEE80211_HE_MCS_NOT_SUPPORTED << i * 2); +} + +static void ieee80211_he_mcs_intersection(__le16 *he_own_rx, __le16 *he_peer_rx, + __le16 *he_own_tx, __le16 *he_peer_tx) +{ + u32 i; + u16 own_rx, own_tx, peer_rx, peer_tx; + + for (i = 0; i < 8; i++) { + own_rx = le16_to_cpu(*he_own_rx); + own_rx = (own_rx >> i * 2) & IEEE80211_HE_MCS_NOT_SUPPORTED; + + own_tx = le16_to_cpu(*he_own_tx); + own_tx = (own_tx >> i * 2) & IEEE80211_HE_MCS_NOT_SUPPORTED; + + peer_rx = le16_to_cpu(*he_peer_rx); + peer_rx = (peer_rx >> i * 2) & IEEE80211_HE_MCS_NOT_SUPPORTED; + + peer_tx = le16_to_cpu(*he_peer_tx); + peer_tx = (peer_tx >> i * 2) & IEEE80211_HE_MCS_NOT_SUPPORTED; + + if (peer_tx != IEEE80211_HE_MCS_NOT_SUPPORTED) { + if (own_rx == IEEE80211_HE_MCS_NOT_SUPPORTED) + peer_tx = IEEE80211_HE_MCS_NOT_SUPPORTED; + else if (own_rx < peer_tx) + peer_tx = own_rx; + } + + if (peer_rx != IEEE80211_HE_MCS_NOT_SUPPORTED) { + if (own_tx == IEEE80211_HE_MCS_NOT_SUPPORTED) + peer_rx = IEEE80211_HE_MCS_NOT_SUPPORTED; + else if (own_tx < peer_rx) + peer_rx = own_tx; + } + + *he_peer_rx &= + ~cpu_to_le16(IEEE80211_HE_MCS_NOT_SUPPORTED << i * 2); + *he_peer_rx |= cpu_to_le16(peer_rx << i * 2); + + *he_peer_tx &= + ~cpu_to_le16(IEEE80211_HE_MCS_NOT_SUPPORTED << i * 2); + *he_peer_tx |= cpu_to_le16(peer_tx << i * 2); + } +} + +void +ieee80211_he_cap_ie_to_sta_he_cap(struct ieee80211_sub_if_data *sdata, + struct 
ieee80211_supported_band *sband, + const u8 *he_cap_ie, u8 he_cap_len, + const struct ieee80211_he_6ghz_capa *he_6ghz_capa, + struct link_sta_info *link_sta) +{ + struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap; + const struct ieee80211_sta_he_cap *own_he_cap_ptr; + struct ieee80211_sta_he_cap own_he_cap; + struct ieee80211_he_cap_elem *he_cap_ie_elem = (void *)he_cap_ie; + u8 he_ppe_size; + u8 mcs_nss_size; + u8 he_total_size; + bool own_160, peer_160, own_80p80, peer_80p80; + + memset(he_cap, 0, sizeof(*he_cap)); + + if (!he_cap_ie) + return; + + own_he_cap_ptr = + ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + if (!own_he_cap_ptr) + return; + + own_he_cap = *own_he_cap_ptr; + + /* Make sure size is OK */ + mcs_nss_size = ieee80211_he_mcs_nss_size(he_cap_ie_elem); + he_ppe_size = + ieee80211_he_ppe_size(he_cap_ie[sizeof(he_cap->he_cap_elem) + + mcs_nss_size], + he_cap_ie_elem->phy_cap_info); + he_total_size = sizeof(he_cap->he_cap_elem) + mcs_nss_size + + he_ppe_size; + if (he_cap_len < he_total_size) + return; + + memcpy(&he_cap->he_cap_elem, he_cap_ie, sizeof(he_cap->he_cap_elem)); + + /* HE Tx/Rx HE MCS NSS Support Field */ + memcpy(&he_cap->he_mcs_nss_supp, + &he_cap_ie[sizeof(he_cap->he_cap_elem)], mcs_nss_size); + + /* Check if there are (optional) PPE Thresholds */ + if (he_cap->he_cap_elem.phy_cap_info[6] & + IEEE80211_HE_PHY_CAP6_PPE_THRESHOLD_PRESENT) + memcpy(he_cap->ppe_thres, + &he_cap_ie[sizeof(he_cap->he_cap_elem) + mcs_nss_size], + he_ppe_size); + + he_cap->has_he = true; + + link_sta->cur_max_bandwidth = ieee80211_sta_cap_rx_bw(link_sta); + link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta); + + if (sband->band == NL80211_BAND_6GHZ && he_6ghz_capa) + ieee80211_update_from_he_6ghz_capa(he_6ghz_capa, link_sta); + + ieee80211_he_mcs_intersection(&own_he_cap.he_mcs_nss_supp.rx_mcs_80, + &he_cap->he_mcs_nss_supp.rx_mcs_80, + &own_he_cap.he_mcs_nss_supp.tx_mcs_80, + &he_cap->he_mcs_nss_supp.tx_mcs_80); + + own_160 = own_he_cap.he_cap_elem.phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + peer_160 = he_cap->he_cap_elem.phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + + if (peer_160 && own_160) { + ieee80211_he_mcs_intersection(&own_he_cap.he_mcs_nss_supp.rx_mcs_160, + &he_cap->he_mcs_nss_supp.rx_mcs_160, + &own_he_cap.he_mcs_nss_supp.tx_mcs_160, + &he_cap->he_mcs_nss_supp.tx_mcs_160); + } else if (peer_160 && !own_160) { + ieee80211_he_mcs_disable(&he_cap->he_mcs_nss_supp.rx_mcs_160); + ieee80211_he_mcs_disable(&he_cap->he_mcs_nss_supp.tx_mcs_160); + he_cap->he_cap_elem.phy_cap_info[0] &= + ~IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + } + + own_80p80 = own_he_cap.he_cap_elem.phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G; + peer_80p80 = he_cap->he_cap_elem.phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G; + + if (peer_80p80 && own_80p80) { + ieee80211_he_mcs_intersection(&own_he_cap.he_mcs_nss_supp.rx_mcs_80p80, + &he_cap->he_mcs_nss_supp.rx_mcs_80p80, + &own_he_cap.he_mcs_nss_supp.tx_mcs_80p80, + &he_cap->he_mcs_nss_supp.tx_mcs_80p80); + } else if (peer_80p80 && !own_80p80) { + ieee80211_he_mcs_disable(&he_cap->he_mcs_nss_supp.rx_mcs_80p80); + ieee80211_he_mcs_disable(&he_cap->he_mcs_nss_supp.tx_mcs_80p80); + he_cap->he_cap_elem.phy_cap_info[0] &= + ~IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G; + } +} + +void +ieee80211_he_op_ie_to_bss_conf(struct ieee80211_vif *vif, + const struct 
ieee80211_he_operation *he_op_ie) +{ + memset(&vif->bss_conf.he_oper, 0, sizeof(vif->bss_conf.he_oper)); + if (!he_op_ie) + return; + + vif->bss_conf.he_oper.params = __le32_to_cpu(he_op_ie->he_oper_params); + vif->bss_conf.he_oper.nss_set = __le16_to_cpu(he_op_ie->he_mcs_nss_set); +} + +void +ieee80211_he_spr_ie_to_bss_conf(struct ieee80211_vif *vif, + const struct ieee80211_he_spr *he_spr_ie_elem) +{ + struct ieee80211_he_obss_pd *he_obss_pd = + &vif->bss_conf.he_obss_pd; + const u8 *data; + + memset(he_obss_pd, 0, sizeof(*he_obss_pd)); + + if (!he_spr_ie_elem) + return; + data = he_spr_ie_elem->optional; + + if (he_spr_ie_elem->he_sr_control & + IEEE80211_HE_SPR_NON_SRG_OFFSET_PRESENT) + data++; + if (he_spr_ie_elem->he_sr_control & + IEEE80211_HE_SPR_SRG_INFORMATION_PRESENT) { + he_obss_pd->max_offset = *data++; + he_obss_pd->min_offset = *data++; + he_obss_pd->enable = true; + } +} diff --git a/net/mac80211/ht.c b/net/mac80211/ht.c new file mode 100644 index 000000000..9bfe128ad --- /dev/null +++ b/net/mac80211/ht.c @@ -0,0 +1,617 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * HT handling + * + * Copyright 2003, Jouni Malinen + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2007-2010, Intel Corporation + * Copyright 2017 Intel Deutschland GmbH + * Copyright(c) 2020-2022 Intel Corporation + */ + +#include +#include +#include +#include "ieee80211_i.h" +#include "rate.h" + +static void __check_htcap_disable(struct ieee80211_ht_cap *ht_capa, + struct ieee80211_ht_cap *ht_capa_mask, + struct ieee80211_sta_ht_cap *ht_cap, + u16 flag) +{ + __le16 le_flag = cpu_to_le16(flag); + if (ht_capa_mask->cap_info & le_flag) { + if (!(ht_capa->cap_info & le_flag)) + ht_cap->cap &= ~flag; + } +} + +static void __check_htcap_enable(struct ieee80211_ht_cap *ht_capa, + struct ieee80211_ht_cap *ht_capa_mask, + struct ieee80211_sta_ht_cap *ht_cap, + u16 flag) +{ + __le16 le_flag = cpu_to_le16(flag); + + if ((ht_capa_mask->cap_info & le_flag) && + (ht_capa->cap_info & le_flag)) + ht_cap->cap |= flag; +} + +void ieee80211_apply_htcap_overrides(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_ht_cap *ht_cap) +{ + struct ieee80211_ht_cap *ht_capa, *ht_capa_mask; + u8 *scaps, *smask; + int i; + + if (!ht_cap->ht_supported) + return; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + ht_capa = &sdata->u.mgd.ht_capa; + ht_capa_mask = &sdata->u.mgd.ht_capa_mask; + break; + case NL80211_IFTYPE_ADHOC: + ht_capa = &sdata->u.ibss.ht_capa; + ht_capa_mask = &sdata->u.ibss.ht_capa_mask; + break; + default: + WARN_ON_ONCE(1); + return; + } + + scaps = (u8 *)(&ht_capa->mcs.rx_mask); + smask = (u8 *)(&ht_capa_mask->mcs.rx_mask); + + /* NOTE: If you add more over-rides here, update register_hw + * ht_capa_mod_mask logic in main.c as well. + * And, if this method can ever change ht_cap.ht_supported, fix + * the check in ieee80211_add_ht_ie. + */ + + /* check for HT over-rides, MCS rates first. */ + for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++) { + u8 m = smask[i]; + ht_cap->mcs.rx_mask[i] &= ~m; /* turn off all masked bits */ + /* Add back rates that are supported */ + ht_cap->mcs.rx_mask[i] |= (m & scaps[i]); + } + + /* Force removal of HT-40 capabilities? 
*/ + __check_htcap_disable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_SUP_WIDTH_20_40); + __check_htcap_disable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_SGI_40); + + /* Allow user to disable SGI-20 (SGI-40 is handled above) */ + __check_htcap_disable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_SGI_20); + + /* Allow user to disable the max-AMSDU bit. */ + __check_htcap_disable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_MAX_AMSDU); + + /* Allow user to disable LDPC */ + __check_htcap_disable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_LDPC_CODING); + + /* Allow user to enable 40 MHz intolerant bit. */ + __check_htcap_enable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_40MHZ_INTOLERANT); + + /* Allow user to enable TX STBC bit */ + __check_htcap_enable(ht_capa, ht_capa_mask, ht_cap, + IEEE80211_HT_CAP_TX_STBC); + + /* Allow user to configure RX STBC bits */ + if (ht_capa_mask->cap_info & cpu_to_le16(IEEE80211_HT_CAP_RX_STBC)) + ht_cap->cap |= le16_to_cpu(ht_capa->cap_info) & + IEEE80211_HT_CAP_RX_STBC; + + /* Allow user to decrease AMPDU factor */ + if (ht_capa_mask->ampdu_params_info & + IEEE80211_HT_AMPDU_PARM_FACTOR) { + u8 n = ht_capa->ampdu_params_info & + IEEE80211_HT_AMPDU_PARM_FACTOR; + if (n < ht_cap->ampdu_factor) + ht_cap->ampdu_factor = n; + } + + /* Allow the user to increase AMPDU density. */ + if (ht_capa_mask->ampdu_params_info & + IEEE80211_HT_AMPDU_PARM_DENSITY) { + u8 n = (ht_capa->ampdu_params_info & + IEEE80211_HT_AMPDU_PARM_DENSITY) + >> IEEE80211_HT_AMPDU_PARM_DENSITY_SHIFT; + if (n > ht_cap->ampdu_density) + ht_cap->ampdu_density = n; + } +} + + +bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct ieee80211_ht_cap *ht_cap_ie, + struct link_sta_info *link_sta) +{ + struct ieee80211_bss_conf *link_conf; + struct sta_info *sta = link_sta->sta; + struct ieee80211_sta_ht_cap ht_cap, own_cap; + u8 ampdu_info, tx_mcs_set_cap; + int i, max_tx_streams; + bool changed; + enum ieee80211_sta_rx_bandwidth bw; + enum nl80211_chan_width width; + + memset(&ht_cap, 0, sizeof(ht_cap)); + + if (!ht_cap_ie || !sband->ht_cap.ht_supported) + goto apply; + + ht_cap.ht_supported = true; + + own_cap = sband->ht_cap; + + /* + * If user has specified capability over-rides, take care + * of that if the station we're setting up is the AP or TDLS peer that + * we advertised a restricted capability set to. Override + * our own capabilities and then use those below. + */ + if (sdata->vif.type == NL80211_IFTYPE_STATION || + sdata->vif.type == NL80211_IFTYPE_ADHOC) + ieee80211_apply_htcap_overrides(sdata, &own_cap); + + /* + * The bits listed in this expression should be + * the same for the peer and us, if the station + * advertises more then we can't use those thus + * we mask them out. + */ + ht_cap.cap = le16_to_cpu(ht_cap_ie->cap_info) & + (own_cap.cap | ~(IEEE80211_HT_CAP_LDPC_CODING | + IEEE80211_HT_CAP_SUP_WIDTH_20_40 | + IEEE80211_HT_CAP_GRN_FLD | + IEEE80211_HT_CAP_SGI_20 | + IEEE80211_HT_CAP_SGI_40 | + IEEE80211_HT_CAP_DSSSCCK40)); + + /* + * The STBC bits are asymmetric -- if we don't have + * TX then mask out the peer's RX and vice versa. 
+ */ + if (!(own_cap.cap & IEEE80211_HT_CAP_TX_STBC)) + ht_cap.cap &= ~IEEE80211_HT_CAP_RX_STBC; + if (!(own_cap.cap & IEEE80211_HT_CAP_RX_STBC)) + ht_cap.cap &= ~IEEE80211_HT_CAP_TX_STBC; + + ampdu_info = ht_cap_ie->ampdu_params_info; + ht_cap.ampdu_factor = + ampdu_info & IEEE80211_HT_AMPDU_PARM_FACTOR; + ht_cap.ampdu_density = + (ampdu_info & IEEE80211_HT_AMPDU_PARM_DENSITY) >> 2; + + /* own MCS TX capabilities */ + tx_mcs_set_cap = own_cap.mcs.tx_params; + + /* Copy peer MCS TX capabilities, the driver might need them. */ + ht_cap.mcs.tx_params = ht_cap_ie->mcs.tx_params; + + /* can we TX with MCS rates? */ + if (!(tx_mcs_set_cap & IEEE80211_HT_MCS_TX_DEFINED)) + goto apply; + + /* Counting from 0, therefore +1 */ + if (tx_mcs_set_cap & IEEE80211_HT_MCS_TX_RX_DIFF) + max_tx_streams = + ((tx_mcs_set_cap & IEEE80211_HT_MCS_TX_MAX_STREAMS_MASK) + >> IEEE80211_HT_MCS_TX_MAX_STREAMS_SHIFT) + 1; + else + max_tx_streams = IEEE80211_HT_MCS_TX_MAX_STREAMS; + + /* + * 802.11n-2009 20.3.5 / 20.6 says: + * - indices 0 to 7 and 32 are single spatial stream + * - 8 to 31 are multiple spatial streams using equal modulation + * [8..15 for two streams, 16..23 for three and 24..31 for four] + * - remainder are multiple spatial streams using unequal modulation + */ + for (i = 0; i < max_tx_streams; i++) + ht_cap.mcs.rx_mask[i] = + own_cap.mcs.rx_mask[i] & ht_cap_ie->mcs.rx_mask[i]; + + if (tx_mcs_set_cap & IEEE80211_HT_MCS_TX_UNEQUAL_MODULATION) + for (i = IEEE80211_HT_MCS_UNEQUAL_MODULATION_START_BYTE; + i < IEEE80211_HT_MCS_MASK_LEN; i++) + ht_cap.mcs.rx_mask[i] = + own_cap.mcs.rx_mask[i] & + ht_cap_ie->mcs.rx_mask[i]; + + /* handle MCS rate 32 too */ + if (own_cap.mcs.rx_mask[32/8] & ht_cap_ie->mcs.rx_mask[32/8] & 1) + ht_cap.mcs.rx_mask[32/8] |= 1; + + /* set Rx highest rate */ + ht_cap.mcs.rx_highest = ht_cap_ie->mcs.rx_highest; + + if (ht_cap.cap & IEEE80211_HT_CAP_MAX_AMSDU) + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_HT_7935; + else + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_HT_3839; + + ieee80211_sta_recalc_aggregates(&sta->sta); + + apply: + changed = memcmp(&link_sta->pub->ht_cap, &ht_cap, sizeof(ht_cap)); + + memcpy(&link_sta->pub->ht_cap, &ht_cap, sizeof(ht_cap)); + + rcu_read_lock(); + link_conf = rcu_dereference(sdata->vif.link_conf[link_sta->link_id]); + if (WARN_ON(!link_conf)) + width = NL80211_CHAN_WIDTH_20_NOHT; + else + width = link_conf->chandef.width; + + switch (width) { + default: + WARN_ON_ONCE(1); + fallthrough; + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + bw = IEEE80211_STA_RX_BW_20; + break; + case NL80211_CHAN_WIDTH_40: + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_160: + case NL80211_CHAN_WIDTH_320: + bw = ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ? + IEEE80211_STA_RX_BW_40 : IEEE80211_STA_RX_BW_20; + break; + } + rcu_read_unlock(); + + link_sta->pub->bandwidth = bw; + + link_sta->cur_max_bandwidth = + ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ? 
+ IEEE80211_STA_RX_BW_40 : IEEE80211_STA_RX_BW_20; + + if (sta->sdata->vif.type == NL80211_IFTYPE_AP || + sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + enum ieee80211_smps_mode smps_mode; + + switch ((ht_cap.cap & IEEE80211_HT_CAP_SM_PS) + >> IEEE80211_HT_CAP_SM_PS_SHIFT) { + case WLAN_HT_CAP_SM_PS_INVALID: + case WLAN_HT_CAP_SM_PS_STATIC: + smps_mode = IEEE80211_SMPS_STATIC; + break; + case WLAN_HT_CAP_SM_PS_DYNAMIC: + smps_mode = IEEE80211_SMPS_DYNAMIC; + break; + case WLAN_HT_CAP_SM_PS_DISABLED: + smps_mode = IEEE80211_SMPS_OFF; + break; + } + + if (smps_mode != link_sta->pub->smps_mode) + changed = true; + link_sta->pub->smps_mode = smps_mode; + } else { + link_sta->pub->smps_mode = IEEE80211_SMPS_OFF; + } + + return changed; +} + +void ieee80211_sta_tear_down_BA_sessions(struct sta_info *sta, + enum ieee80211_agg_stop_reason reason) +{ + int i; + + mutex_lock(&sta->ampdu_mlme.mtx); + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + ___ieee80211_stop_rx_ba_session(sta, i, WLAN_BACK_RECIPIENT, + WLAN_REASON_QSTA_LEAVE_QBSS, + reason != AGG_STOP_DESTROY_STA && + reason != AGG_STOP_PEER_REQUEST); + + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + ___ieee80211_stop_tx_ba_session(sta, i, reason); + mutex_unlock(&sta->ampdu_mlme.mtx); + + /* + * In case the tear down is part of a reconfigure due to HW restart + * request, it is possible that the low level driver requested to stop + * the BA session, so handle it to properly clean tid_tx data. + */ + if(reason == AGG_STOP_DESTROY_STA) { + cancel_work_sync(&sta->ampdu_mlme.work); + + mutex_lock(&sta->ampdu_mlme.mtx); + for (i = 0; i < IEEE80211_NUM_TIDS; i++) { + struct tid_ampdu_tx *tid_tx = + rcu_dereference_protected_tid_tx(sta, i); + + if (!tid_tx) + continue; + + if (test_and_clear_bit(HT_AGG_STATE_STOP_CB, &tid_tx->state)) + ieee80211_stop_tx_ba_cb(sta, i, tid_tx); + } + mutex_unlock(&sta->ampdu_mlme.mtx); + } +} + +void ieee80211_ba_session_work(struct work_struct *work) +{ + struct sta_info *sta = + container_of(work, struct sta_info, ampdu_mlme.work); + struct tid_ampdu_tx *tid_tx; + bool blocked; + int tid; + + /* When this flag is set, new sessions should be blocked. 
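The flag is set, + * for example, while the station is being suspended or torn down.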
*/ + blocked = test_sta_flag(sta, WLAN_STA_BLOCK_BA); + + mutex_lock(&sta->ampdu_mlme.mtx); + for (tid = 0; tid < IEEE80211_NUM_TIDS; tid++) { + if (test_and_clear_bit(tid, sta->ampdu_mlme.tid_rx_timer_expired)) + ___ieee80211_stop_rx_ba_session( + sta, tid, WLAN_BACK_RECIPIENT, + WLAN_REASON_QSTA_TIMEOUT, true); + + if (test_and_clear_bit(tid, + sta->ampdu_mlme.tid_rx_stop_requested)) + ___ieee80211_stop_rx_ba_session( + sta, tid, WLAN_BACK_RECIPIENT, + WLAN_REASON_UNSPECIFIED, true); + + if (!blocked && + test_and_clear_bit(tid, + sta->ampdu_mlme.tid_rx_manage_offl)) + ___ieee80211_start_rx_ba_session(sta, 0, 0, 0, 1, tid, + IEEE80211_MAX_AMPDU_BUF_HT, + false, true, NULL); + + if (test_and_clear_bit(tid + IEEE80211_NUM_TIDS, + sta->ampdu_mlme.tid_rx_manage_offl)) + ___ieee80211_stop_rx_ba_session( + sta, tid, WLAN_BACK_RECIPIENT, + 0, false); + + spin_lock_bh(&sta->lock); + + tid_tx = sta->ampdu_mlme.tid_start_tx[tid]; + if (!blocked && tid_tx) { + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + + if (local->ops->wake_tx_queue) { + struct txq_info *txqi = + to_txq_info(sta->sta.txq[tid]); + struct fq *fq = &local->fq; + + spin_lock_bh(&fq->lock); + + /* Allow only frags to be dequeued */ + set_bit(IEEE80211_TXQ_STOP, &txqi->flags); + + if (!skb_queue_empty(&txqi->frags)) { + /* Fragmented Tx is ongoing, wait for it + * to finish. Reschedule worker to retry + * later. + */ + + spin_unlock_bh(&fq->lock); + spin_unlock_bh(&sta->lock); + + /* Give the task working on the txq a + * chance to send out the queued frags + */ + synchronize_net(); + + mutex_unlock(&sta->ampdu_mlme.mtx); + + ieee80211_queue_work(&sdata->local->hw, + work); + return; + } + + spin_unlock_bh(&fq->lock); + } + + /* + * Assign it over to the normal tid_tx array + * where it "goes live". + */ + + sta->ampdu_mlme.tid_start_tx[tid] = NULL; + /* could there be a race? 
*/ + if (sta->ampdu_mlme.tid_tx[tid]) + kfree(tid_tx); + else + ieee80211_assign_tid_tx(sta, tid, tid_tx); + spin_unlock_bh(&sta->lock); + + ieee80211_tx_ba_session_handle_start(sta, tid); + continue; + } + spin_unlock_bh(&sta->lock); + + tid_tx = rcu_dereference_protected_tid_tx(sta, tid); + if (!tid_tx) + continue; + + if (!blocked && + test_and_clear_bit(HT_AGG_STATE_START_CB, &tid_tx->state)) + ieee80211_start_tx_ba_cb(sta, tid, tid_tx); + if (test_and_clear_bit(HT_AGG_STATE_WANT_STOP, &tid_tx->state)) + ___ieee80211_stop_tx_ba_session(sta, tid, + AGG_STOP_LOCAL_REQUEST); + if (test_and_clear_bit(HT_AGG_STATE_STOP_CB, &tid_tx->state)) + ieee80211_stop_tx_ba_cb(sta, tid, tid_tx); + } + mutex_unlock(&sta->ampdu_mlme.mtx); +} + +void ieee80211_send_delba(struct ieee80211_sub_if_data *sdata, + const u8 *da, u16 tid, + u16 initiator, u16 reason_code) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + u16 params; + + skb = dev_alloc_skb(sizeof(*mgmt) + local->hw.extra_tx_headroom); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + mgmt = skb_put_zero(skb, 24); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + if (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + sdata->vif.type == NL80211_IFTYPE_MESH_POINT) + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_STATION) + memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + memcpy(mgmt->bssid, sdata->u.ibss.bssid, ETH_ALEN); + + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + skb_put(skb, 1 + sizeof(mgmt->u.action.u.delba)); + + mgmt->u.action.category = WLAN_CATEGORY_BACK; + mgmt->u.action.u.delba.action_code = WLAN_ACTION_DELBA; + params = (u16)(initiator << 11); /* bit 11 initiator */ + params |= (u16)(tid << 12); /* bit 15:12 TID number */ + + mgmt->u.action.u.delba.params = cpu_to_le16(params); + mgmt->u.action.u.delba.reason_code = cpu_to_le16(reason_code); + + ieee80211_tx_skb(sdata, skb); +} + +void ieee80211_process_delba(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, size_t len) +{ + u16 tid, params; + u16 initiator; + + params = le16_to_cpu(mgmt->u.action.u.delba.params); + tid = (params & IEEE80211_DELBA_PARAM_TID_MASK) >> 12; + initiator = (params & IEEE80211_DELBA_PARAM_INITIATOR_MASK) >> 11; + + ht_dbg_ratelimited(sdata, "delba from %pM (%s) tid %d reason code %d\n", + mgmt->sa, initiator ? 
"initiator" : "recipient", + tid, + le16_to_cpu(mgmt->u.action.u.delba.reason_code)); + + if (initiator == WLAN_BACK_INITIATOR) + __ieee80211_stop_rx_ba_session(sta, tid, WLAN_BACK_INITIATOR, 0, + true); + else + __ieee80211_stop_tx_ba_session(sta, tid, AGG_STOP_PEER_REQUEST); +} + +enum nl80211_smps_mode +ieee80211_smps_mode_to_smps_mode(enum ieee80211_smps_mode smps) +{ + switch (smps) { + case IEEE80211_SMPS_OFF: + return NL80211_SMPS_OFF; + case IEEE80211_SMPS_STATIC: + return NL80211_SMPS_STATIC; + case IEEE80211_SMPS_DYNAMIC: + return NL80211_SMPS_DYNAMIC; + default: + return NL80211_SMPS_OFF; + } +} + +int ieee80211_send_smps_action(struct ieee80211_sub_if_data *sdata, + enum ieee80211_smps_mode smps, const u8 *da, + const u8 *bssid) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *action_frame; + + /* 27 = header + category + action + smps mode */ + skb = dev_alloc_skb(27 + local->hw.extra_tx_headroom); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, local->hw.extra_tx_headroom); + action_frame = skb_put(skb, 27); + memcpy(action_frame->da, da, ETH_ALEN); + memcpy(action_frame->sa, sdata->dev->dev_addr, ETH_ALEN); + memcpy(action_frame->bssid, bssid, ETH_ALEN); + action_frame->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + action_frame->u.action.category = WLAN_CATEGORY_HT; + action_frame->u.action.u.ht_smps.action = WLAN_HT_ACTION_SMPS; + switch (smps) { + case IEEE80211_SMPS_AUTOMATIC: + case IEEE80211_SMPS_NUM_MODES: + WARN_ON(1); + fallthrough; + case IEEE80211_SMPS_OFF: + action_frame->u.action.u.ht_smps.smps_control = + WLAN_HT_SMPS_CONTROL_DISABLED; + break; + case IEEE80211_SMPS_STATIC: + action_frame->u.action.u.ht_smps.smps_control = + WLAN_HT_SMPS_CONTROL_STATIC; + break; + case IEEE80211_SMPS_DYNAMIC: + action_frame->u.action.u.ht_smps.smps_control = + WLAN_HT_SMPS_CONTROL_DYNAMIC; + break; + } + + /* we'll do more on status of this frame */ + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; + ieee80211_tx_skb(sdata, skb); + + return 0; +} + +void ieee80211_request_smps(struct ieee80211_vif *vif, unsigned int link_id, + enum ieee80211_smps_mode smps_mode) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_link_data *link; + + if (WARN_ON_ONCE(vif->type != NL80211_IFTYPE_STATION)) + return; + + rcu_read_lock(); + link = rcu_dereference(sdata->link[link_id]); + if (WARN_ON(!link)) + goto out; + + if (link->u.mgd.driver_smps_mode == smps_mode) + goto out; + + link->u.mgd.driver_smps_mode = smps_mode; + ieee80211_queue_work(&sdata->local->hw, &link->u.mgd.request_smps_work); +out: + rcu_read_unlock(); +} +/* this might change ... don't want non-open drivers using it */ +EXPORT_SYMBOL_GPL(ieee80211_request_smps); diff --git a/net/mac80211/ibss.c b/net/mac80211/ibss.c new file mode 100644 index 000000000..79d2c5505 --- /dev/null +++ b/net/mac80211/ibss.c @@ -0,0 +1,1888 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IBSS mode implementation + * Copyright 2003-2008, Jouni Malinen + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2009, Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright(c) 2016 Intel Deutschland GmbH + * Copyright(c) 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" + +#define IEEE80211_SCAN_INTERVAL (2 * HZ) +#define IEEE80211_IBSS_JOIN_TIMEOUT (7 * HZ) + +#define IEEE80211_IBSS_MERGE_INTERVAL (30 * HZ) +#define IEEE80211_IBSS_INACTIVITY_LIMIT (60 * HZ) +#define IEEE80211_IBSS_RSN_INACTIVITY_LIMIT (10 * HZ) + +#define IEEE80211_IBSS_MAX_STA_ENTRIES 128 + +static struct beacon_data * +ieee80211_ibss_build_presp(struct ieee80211_sub_if_data *sdata, + const int beacon_int, const u32 basic_rates, + const u16 capability, u64 tsf, + struct cfg80211_chan_def *chandef, + bool *have_higher_than_11mbit, + struct cfg80211_csa_settings *csa_settings) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + int rates_n = 0, i, ri; + struct ieee80211_mgmt *mgmt; + u8 *pos; + struct ieee80211_supported_band *sband; + u32 rate_flags, rates = 0, rates_added = 0; + struct beacon_data *presp; + int frame_len; + int shift; + + /* Build IBSS probe response */ + frame_len = sizeof(struct ieee80211_hdr_3addr) + + 12 /* struct ieee80211_mgmt.u.beacon */ + + 2 + IEEE80211_MAX_SSID_LEN /* max SSID */ + + 2 + 8 /* max Supported Rates */ + + 3 /* max DS params */ + + 4 /* IBSS params */ + + 5 /* Channel Switch Announcement */ + + 2 + (IEEE80211_MAX_SUPP_RATES - 8) + + 2 + sizeof(struct ieee80211_ht_cap) + + 2 + sizeof(struct ieee80211_ht_operation) + + 2 + sizeof(struct ieee80211_vht_cap) + + 2 + sizeof(struct ieee80211_vht_operation) + + ifibss->ie_len; + presp = kzalloc(sizeof(*presp) + frame_len, GFP_KERNEL); + if (!presp) + return NULL; + + presp->head = (void *)(presp + 1); + + mgmt = (void *) presp->head; + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_PROBE_RESP); + eth_broadcast_addr(mgmt->da); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, ifibss->bssid, ETH_ALEN); + mgmt->u.beacon.beacon_int = cpu_to_le16(beacon_int); + mgmt->u.beacon.timestamp = cpu_to_le64(tsf); + mgmt->u.beacon.capab_info = cpu_to_le16(capability); + + pos = (u8 *)mgmt + offsetof(struct ieee80211_mgmt, u.beacon.variable); + + *pos++ = WLAN_EID_SSID; + *pos++ = ifibss->ssid_len; + memcpy(pos, ifibss->ssid, ifibss->ssid_len); + pos += ifibss->ssid_len; + + sband = local->hw.wiphy->bands[chandef->chan->band]; + rate_flags = ieee80211_chandef_rate_flags(chandef); + shift = ieee80211_chandef_get_shift(chandef); + rates_n = 0; + if (have_higher_than_11mbit) + *have_higher_than_11mbit = false; + + for (i = 0; i < sband->n_bitrates; i++) { + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + if (sband->bitrates[i].bitrate > 110 && + have_higher_than_11mbit) + *have_higher_than_11mbit = true; + + rates |= BIT(i); + rates_n++; + } + + *pos++ = WLAN_EID_SUPP_RATES; + *pos++ = min_t(int, 8, rates_n); + for (ri = 0; ri < sband->n_bitrates; ri++) { + int rate = DIV_ROUND_UP(sband->bitrates[ri].bitrate, + 5 * (1 << shift)); + u8 basic = 0; + if (!(rates & BIT(ri))) + continue; + + if (basic_rates & BIT(ri)) + basic = 0x80; + *pos++ = basic | (u8) rate; + if (++rates_added == 8) { + ri++; /* continue at next rate for EXT_SUPP_RATES */ + break; + } + } + + if (sband->band == NL80211_BAND_2GHZ) { + 
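The rate loop just above follows the usual Supported Rates encoding: each entry is the bitrate expressed in 500 kbit/s units (the bitrate fields are kept in 100 kbit/s units, so they are divided by 5 on an unshifted 20 MHz channel), bit 7 (0x80) marks a basic rate, the first eight entries go into the Supported Rates element, and the remainder are deferred to Extended Supported Rates. A standalone userspace sketch of that encoding rule, using hypothetical rate arrays rather than the kernel's sband structures:

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

#define EID_SUPP_RATES     1
#define EID_EXT_SUPP_RATES 50

/* Encode bitrates (given in 100 kbit/s units) into Supported Rates /
 * Extended Supported Rates style bytes; basic rates get bit 7 set. */
static size_t encode_rates(const int *rates_100k, const int *is_basic,
                           size_t n, uint8_t *out)
{
    size_t i, w = 0, first = n < 8 ? n : 8;

    out[w++] = EID_SUPP_RATES;
    out[w++] = (uint8_t)first;
    for (i = 0; i < first; i++)
        out[w++] = (uint8_t)((rates_100k[i] + 4) / 5) |
                   (is_basic[i] ? 0x80 : 0x00);

    if (n > 8) {
        out[w++] = EID_EXT_SUPP_RATES;
        out[w++] = (uint8_t)(n - 8);
        for (i = 8; i < n; i++)
            out[w++] = (uint8_t)((rates_100k[i] + 4) / 5) |
                       (is_basic[i] ? 0x80 : 0x00);
    }
    return w;
}

int main(void)
{
    /* 1, 2, 5.5, 11, 6, 9, 12, 18, 24, 36 Mbit/s; the first four are basic */
    const int rates[] = { 10, 20, 55, 110, 60, 90, 120, 180, 240, 360 };
    const int basic[] = {  1,  1,  1,   1,  0,  0,   0,   0,   0,   0 };
    uint8_t buf[32];
    size_t i, len = encode_rates(rates, basic, 10, buf);

    for (i = 0; i < len; i++)
        printf("%02x ", buf[i]);
    printf("\n");
    return 0;
}

For the classic 802.11b/g set above this should print 01 08 82 84 8b 96 0c 12 18 24 32 02 30 48, i.e. the familiar beacon rate bytes with the basic rates flagged.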
*pos++ = WLAN_EID_DS_PARAMS; + *pos++ = 1; + *pos++ = ieee80211_frequency_to_channel( + chandef->chan->center_freq); + } + + *pos++ = WLAN_EID_IBSS_PARAMS; + *pos++ = 2; + /* FIX: set ATIM window based on scan results */ + *pos++ = 0; + *pos++ = 0; + + if (csa_settings) { + *pos++ = WLAN_EID_CHANNEL_SWITCH; + *pos++ = 3; + *pos++ = csa_settings->block_tx ? 1 : 0; + *pos++ = ieee80211_frequency_to_channel( + csa_settings->chandef.chan->center_freq); + presp->cntdwn_counter_offsets[0] = (pos - presp->head); + *pos++ = csa_settings->count; + presp->cntdwn_current_counter = csa_settings->count; + } + + /* put the remaining rates in WLAN_EID_EXT_SUPP_RATES */ + if (rates_n > 8) { + *pos++ = WLAN_EID_EXT_SUPP_RATES; + *pos++ = rates_n - 8; + for (; ri < sband->n_bitrates; ri++) { + int rate = DIV_ROUND_UP(sband->bitrates[ri].bitrate, + 5 * (1 << shift)); + u8 basic = 0; + if (!(rates & BIT(ri))) + continue; + + if (basic_rates & BIT(ri)) + basic = 0x80; + *pos++ = basic | (u8) rate; + } + } + + if (ifibss->ie_len) { + memcpy(pos, ifibss->ie, ifibss->ie_len); + pos += ifibss->ie_len; + } + + /* add HT capability and information IEs */ + if (chandef->width != NL80211_CHAN_WIDTH_20_NOHT && + chandef->width != NL80211_CHAN_WIDTH_5 && + chandef->width != NL80211_CHAN_WIDTH_10 && + sband->ht_cap.ht_supported) { + struct ieee80211_sta_ht_cap ht_cap; + + memcpy(&ht_cap, &sband->ht_cap, sizeof(ht_cap)); + ieee80211_apply_htcap_overrides(sdata, &ht_cap); + + pos = ieee80211_ie_build_ht_cap(pos, &ht_cap, ht_cap.cap); + /* + * Note: According to 802.11n-2009 9.13.3.1, HT Protection + * field and RIFS Mode are reserved in IBSS mode, therefore + * keep them at 0 + */ + pos = ieee80211_ie_build_ht_oper(pos, &sband->ht_cap, + chandef, 0, false); + + /* add VHT capability and information IEs */ + if (chandef->width != NL80211_CHAN_WIDTH_20 && + chandef->width != NL80211_CHAN_WIDTH_40 && + sband->vht_cap.vht_supported) { + pos = ieee80211_ie_build_vht_cap(pos, &sband->vht_cap, + sband->vht_cap.cap); + pos = ieee80211_ie_build_vht_oper(pos, &sband->vht_cap, + chandef); + } + } + + if (local->hw.queues >= IEEE80211_NUM_ACS) + pos = ieee80211_add_wmm_info_ie(pos, 0); /* U-APSD not in use */ + + presp->head_len = pos - presp->head; + if (WARN_ON(presp->head_len > frame_len)) + goto error; + + return presp; +error: + kfree(presp); + return NULL; +} + +static void __ieee80211_sta_join_ibss(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const int beacon_int, + struct cfg80211_chan_def *req_chandef, + const u32 basic_rates, + const u16 capability, u64 tsf, + bool creator) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct ieee80211_mgmt *mgmt; + struct cfg80211_bss *bss; + u32 bss_change; + struct cfg80211_chan_def chandef; + struct ieee80211_channel *chan; + struct beacon_data *presp; + struct cfg80211_inform_bss bss_meta = {}; + bool have_higher_than_11mbit; + bool radar_required; + int err; + + sdata_assert_lock(sdata); + + /* Reset own TSF to allow time synchronization work. 
*/ + drv_reset_tsf(local, sdata); + + if (!ether_addr_equal(ifibss->bssid, bssid)) + sta_info_flush(sdata); + + /* if merging, indicate to driver that we leave the old IBSS */ + if (sdata->vif.cfg.ibss_joined) { + sdata->vif.cfg.ibss_joined = false; + sdata->vif.cfg.ibss_creator = false; + sdata->vif.bss_conf.enable_beacon = false; + netif_carrier_off(sdata->dev); + ieee80211_bss_info_change_notify(sdata, + BSS_CHANGED_IBSS | + BSS_CHANGED_BEACON_ENABLED); + drv_leave_ibss(local, sdata); + } + + presp = sdata_dereference(ifibss->presp, sdata); + RCU_INIT_POINTER(ifibss->presp, NULL); + if (presp) + kfree_rcu(presp, rcu_head); + + /* make a copy of the chandef, it could be modified below. */ + chandef = *req_chandef; + chan = chandef.chan; + if (!cfg80211_reg_can_beacon(local->hw.wiphy, &chandef, + NL80211_IFTYPE_ADHOC)) { + if (chandef.width == NL80211_CHAN_WIDTH_5 || + chandef.width == NL80211_CHAN_WIDTH_10 || + chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + chandef.width == NL80211_CHAN_WIDTH_20) { + sdata_info(sdata, + "Failed to join IBSS, beacons forbidden\n"); + return; + } + chandef.width = NL80211_CHAN_WIDTH_20; + chandef.center_freq1 = chan->center_freq; + /* check again for downgraded chandef */ + if (!cfg80211_reg_can_beacon(local->hw.wiphy, &chandef, + NL80211_IFTYPE_ADHOC)) { + sdata_info(sdata, + "Failed to join IBSS, beacons forbidden\n"); + return; + } + } + + err = cfg80211_chandef_dfs_required(sdata->local->hw.wiphy, + &chandef, NL80211_IFTYPE_ADHOC); + if (err < 0) { + sdata_info(sdata, + "Failed to join IBSS, invalid chandef\n"); + return; + } + if (err > 0 && !ifibss->userspace_handles_dfs) { + sdata_info(sdata, + "Failed to join IBSS, DFS channel without control program\n"); + return; + } + + radar_required = err; + + mutex_lock(&local->mtx); + if (ieee80211_link_use_channel(&sdata->deflink, &chandef, + ifibss->fixed_channel ? + IEEE80211_CHANCTX_SHARED : + IEEE80211_CHANCTX_EXCLUSIVE)) { + sdata_info(sdata, "Failed to join IBSS, no channel context\n"); + mutex_unlock(&local->mtx); + return; + } + sdata->deflink.radar_required = radar_required; + mutex_unlock(&local->mtx); + + memcpy(ifibss->bssid, bssid, ETH_ALEN); + + presp = ieee80211_ibss_build_presp(sdata, beacon_int, basic_rates, + capability, tsf, &chandef, + &have_higher_than_11mbit, NULL); + if (!presp) + return; + + rcu_assign_pointer(ifibss->presp, presp); + mgmt = (void *)presp->head; + + sdata->vif.bss_conf.enable_beacon = true; + sdata->vif.bss_conf.beacon_int = beacon_int; + sdata->vif.bss_conf.basic_rates = basic_rates; + sdata->vif.cfg.ssid_len = ifibss->ssid_len; + memcpy(sdata->vif.cfg.ssid, ifibss->ssid, ifibss->ssid_len); + bss_change = BSS_CHANGED_BEACON_INT; + bss_change |= ieee80211_reset_erp_info(sdata); + bss_change |= BSS_CHANGED_BSSID; + bss_change |= BSS_CHANGED_BEACON; + bss_change |= BSS_CHANGED_BEACON_ENABLED; + bss_change |= BSS_CHANGED_BASIC_RATES; + bss_change |= BSS_CHANGED_HT; + bss_change |= BSS_CHANGED_IBSS; + bss_change |= BSS_CHANGED_SSID; + + /* + * In 5 GHz/802.11a, we can always use short slot time. + * (IEEE 802.11-2012 18.3.8.7) + * + * In 2.4GHz, we must always use long slots in IBSS for compatibility + * reasons. + * (IEEE 802.11-2012 19.4.5) + * + * HT follows these specifications (IEEE 802.11-2012 20.3.18) + */ + sdata->vif.bss_conf.use_short_slot = chan->band == NL80211_BAND_5GHZ; + bss_change |= BSS_CHANGED_ERP_SLOT; + + /* cf. 
IEEE 802.11 9.2.12 */ + sdata->deflink.operating_11g_mode = + chan->band == NL80211_BAND_2GHZ && have_higher_than_11mbit; + + ieee80211_set_wmm_default(&sdata->deflink, true, false); + + sdata->vif.cfg.ibss_joined = true; + sdata->vif.cfg.ibss_creator = creator; + + err = drv_join_ibss(local, sdata); + if (err) { + sdata->vif.cfg.ibss_joined = false; + sdata->vif.cfg.ibss_creator = false; + sdata->vif.bss_conf.enable_beacon = false; + sdata->vif.cfg.ssid_len = 0; + RCU_INIT_POINTER(ifibss->presp, NULL); + kfree_rcu(presp, rcu_head); + mutex_lock(&local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&local->mtx); + sdata_info(sdata, "Failed to join IBSS, driver failure: %d\n", + err); + return; + } + + ieee80211_bss_info_change_notify(sdata, bss_change); + + ifibss->state = IEEE80211_IBSS_MLME_JOINED; + mod_timer(&ifibss->timer, + round_jiffies(jiffies + IEEE80211_IBSS_MERGE_INTERVAL)); + + bss_meta.chan = chan; + bss_meta.scan_width = cfg80211_chandef_to_scan_width(&chandef); + bss = cfg80211_inform_bss_frame_data(local->hw.wiphy, &bss_meta, mgmt, + presp->head_len, GFP_KERNEL); + + cfg80211_put_bss(local->hw.wiphy, bss); + netif_carrier_on(sdata->dev); + cfg80211_ibss_joined(sdata->dev, ifibss->bssid, chan, GFP_KERNEL); +} + +static void ieee80211_sta_join_ibss(struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss *bss) +{ + struct cfg80211_bss *cbss = + container_of((void *)bss, struct cfg80211_bss, priv); + struct ieee80211_supported_band *sband; + struct cfg80211_chan_def chandef; + u32 basic_rates; + int i, j; + u16 beacon_int = cbss->beacon_interval; + const struct cfg80211_bss_ies *ies; + enum nl80211_channel_type chan_type; + u64 tsf; + u32 rate_flags; + int shift; + + sdata_assert_lock(sdata); + + if (beacon_int < 10) + beacon_int = 10; + + switch (sdata->u.ibss.chandef.width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_40: + chan_type = cfg80211_get_chandef_type(&sdata->u.ibss.chandef); + cfg80211_chandef_create(&chandef, cbss->channel, chan_type); + break; + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + cfg80211_chandef_create(&chandef, cbss->channel, + NL80211_CHAN_NO_HT); + chandef.width = sdata->u.ibss.chandef.width; + break; + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_160: + chandef = sdata->u.ibss.chandef; + chandef.chan = cbss->channel; + break; + default: + /* fall back to 20 MHz for unsupported modes */ + cfg80211_chandef_create(&chandef, cbss->channel, + NL80211_CHAN_NO_HT); + break; + } + + sband = sdata->local->hw.wiphy->bands[cbss->channel->band]; + rate_flags = ieee80211_chandef_rate_flags(&sdata->u.ibss.chandef); + shift = ieee80211_vif_get_shift(&sdata->vif); + + basic_rates = 0; + + for (i = 0; i < bss->supp_rates_len; i++) { + int rate = bss->supp_rates[i] & 0x7f; + bool is_basic = !!(bss->supp_rates[i] & 0x80); + + for (j = 0; j < sband->n_bitrates; j++) { + int brate; + if ((rate_flags & sband->bitrates[j].flags) + != rate_flags) + continue; + + brate = DIV_ROUND_UP(sband->bitrates[j].bitrate, + 5 * (1 << shift)); + if (brate == rate) { + if (is_basic) + basic_rates |= BIT(j); + break; + } + } + } + + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + tsf = ies->tsf; + rcu_read_unlock(); + + __ieee80211_sta_join_ibss(sdata, cbss->bssid, + beacon_int, + &chandef, + basic_rates, + cbss->capability, + tsf, false); +} + +int ieee80211_ibss_csa_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings 
*csa_settings) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct beacon_data *presp, *old_presp; + struct cfg80211_bss *cbss; + const struct cfg80211_bss_ies *ies; + u16 capability = WLAN_CAPABILITY_IBSS; + u64 tsf; + + sdata_assert_lock(sdata); + + if (ifibss->privacy) + capability |= WLAN_CAPABILITY_PRIVACY; + + cbss = cfg80211_get_bss(sdata->local->hw.wiphy, ifibss->chandef.chan, + ifibss->bssid, ifibss->ssid, + ifibss->ssid_len, IEEE80211_BSS_TYPE_IBSS, + IEEE80211_PRIVACY(ifibss->privacy)); + + if (WARN_ON(!cbss)) + return -EINVAL; + + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + tsf = ies->tsf; + rcu_read_unlock(); + cfg80211_put_bss(sdata->local->hw.wiphy, cbss); + + old_presp = sdata_dereference(ifibss->presp, sdata); + + presp = ieee80211_ibss_build_presp(sdata, + sdata->vif.bss_conf.beacon_int, + sdata->vif.bss_conf.basic_rates, + capability, tsf, &ifibss->chandef, + NULL, csa_settings); + if (!presp) + return -ENOMEM; + + rcu_assign_pointer(ifibss->presp, presp); + if (old_presp) + kfree_rcu(old_presp, rcu_head); + + return BSS_CHANGED_BEACON; +} + +int ieee80211_ibss_finish_csa(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct cfg80211_bss *cbss; + + sdata_assert_lock(sdata); + + /* When not connected/joined, sending CSA doesn't make sense. */ + if (ifibss->state != IEEE80211_IBSS_MLME_JOINED) + return -ENOLINK; + + /* update cfg80211 bss information with the new channel */ + if (!is_zero_ether_addr(ifibss->bssid)) { + cbss = cfg80211_get_bss(sdata->local->hw.wiphy, + ifibss->chandef.chan, + ifibss->bssid, ifibss->ssid, + ifibss->ssid_len, + IEEE80211_BSS_TYPE_IBSS, + IEEE80211_PRIVACY(ifibss->privacy)); + /* XXX: should not really modify cfg80211 data */ + if (cbss) { + cbss->channel = sdata->deflink.csa_chandef.chan; + cfg80211_put_bss(sdata->local->hw.wiphy, cbss); + } + } + + ifibss->chandef = sdata->deflink.csa_chandef; + + /* generate the beacon */ + return ieee80211_ibss_csa_beacon(sdata, NULL); +} + +void ieee80211_ibss_stop(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + + cancel_work_sync(&ifibss->csa_connection_drop_work); +} + +static struct sta_info *ieee80211_ibss_finish_sta(struct sta_info *sta) + __acquires(RCU) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u8 addr[ETH_ALEN]; + + memcpy(addr, sta->sta.addr, ETH_ALEN); + + ibss_dbg(sdata, "Adding new IBSS station %pM\n", addr); + + sta_info_pre_move_state(sta, IEEE80211_STA_AUTH); + sta_info_pre_move_state(sta, IEEE80211_STA_ASSOC); + /* authorize the station only if the network is not RSN protected. If + * not wait for the userspace to authorize it */ + if (!sta->sdata->u.ibss.control_port) + sta_info_pre_move_state(sta, IEEE80211_STA_AUTHORIZED); + + rate_control_rate_init(sta); + + /* If it fails, maybe we raced another insertion? */ + if (sta_info_insert_rcu(sta)) + return sta_info_get(sdata, addr); + return sta; +} + +static struct sta_info * +ieee80211_ibss_add_sta(struct ieee80211_sub_if_data *sdata, const u8 *bssid, + const u8 *addr, u32 supp_rates) + __acquires(RCU) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_supported_band *sband; + enum nl80211_bss_scan_width scan_width; + int band; + + /* + * XXX: Consider removing the least recently used entry and + * allow new one to be added. 
+ */ + if (local->num_sta >= IEEE80211_IBSS_MAX_STA_ENTRIES) { + net_info_ratelimited("%s: No room for a new IBSS STA entry %pM\n", + sdata->name, addr); + rcu_read_lock(); + return NULL; + } + + if (ifibss->state == IEEE80211_IBSS_MLME_SEARCH) { + rcu_read_lock(); + return NULL; + } + + if (!ether_addr_equal(bssid, sdata->u.ibss.bssid)) { + rcu_read_lock(); + return NULL; + } + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON_ONCE(!chanctx_conf)) + return NULL; + band = chanctx_conf->def.chan->band; + scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def); + rcu_read_unlock(); + + sta = sta_info_alloc(sdata, addr, GFP_KERNEL); + if (!sta) { + rcu_read_lock(); + return NULL; + } + + /* make sure mandatory rates are always added */ + sband = local->hw.wiphy->bands[band]; + sta->sta.deflink.supp_rates[band] = supp_rates | + ieee80211_mandatory_rates(sband, scan_width); + + return ieee80211_ibss_finish_sta(sta); +} + +static int ieee80211_sta_active_ibss(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + int active = 0; + struct sta_info *sta; + + sdata_assert_lock(sdata); + + rcu_read_lock(); + + list_for_each_entry_rcu(sta, &local->sta_list, list) { + unsigned long last_active = ieee80211_sta_last_active(sta); + + if (sta->sdata == sdata && + time_is_after_jiffies(last_active + + IEEE80211_IBSS_MERGE_INTERVAL)) { + active++; + break; + } + } + + rcu_read_unlock(); + + return active; +} + +static void ieee80211_ibss_disconnect(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct cfg80211_bss *cbss; + struct beacon_data *presp; + struct sta_info *sta; + + if (!is_zero_ether_addr(ifibss->bssid)) { + cbss = cfg80211_get_bss(local->hw.wiphy, ifibss->chandef.chan, + ifibss->bssid, ifibss->ssid, + ifibss->ssid_len, + IEEE80211_BSS_TYPE_IBSS, + IEEE80211_PRIVACY(ifibss->privacy)); + + if (cbss) { + cfg80211_unlink_bss(local->hw.wiphy, cbss); + cfg80211_put_bss(sdata->local->hw.wiphy, cbss); + } + } + + ifibss->state = IEEE80211_IBSS_MLME_SEARCH; + + sta_info_flush(sdata); + + spin_lock_bh(&ifibss->incomplete_lock); + while (!list_empty(&ifibss->incomplete_stations)) { + sta = list_first_entry(&ifibss->incomplete_stations, + struct sta_info, list); + list_del(&sta->list); + spin_unlock_bh(&ifibss->incomplete_lock); + + sta_info_free(local, sta); + spin_lock_bh(&ifibss->incomplete_lock); + } + spin_unlock_bh(&ifibss->incomplete_lock); + + netif_carrier_off(sdata->dev); + + sdata->vif.cfg.ibss_joined = false; + sdata->vif.cfg.ibss_creator = false; + sdata->vif.bss_conf.enable_beacon = false; + sdata->vif.cfg.ssid_len = 0; + + /* remove beacon */ + presp = sdata_dereference(ifibss->presp, sdata); + RCU_INIT_POINTER(sdata->u.ibss.presp, NULL); + if (presp) + kfree_rcu(presp, rcu_head); + + clear_bit(SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, &sdata->state); + ieee80211_bss_info_change_notify(sdata, BSS_CHANGED_BEACON_ENABLED | + BSS_CHANGED_IBSS); + drv_leave_ibss(local, sdata); + mutex_lock(&local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&local->mtx); +} + +static void ieee80211_csa_connection_drop_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + u.ibss.csa_connection_drop_work); + + sdata_lock(sdata); + + ieee80211_ibss_disconnect(sdata); + synchronize_rcu(); + skb_queue_purge(&sdata->skb_queue); + + /* 
trigger a scan to find another IBSS network to join */ + ieee80211_queue_work(&sdata->local->hw, &sdata->work); + + sdata_unlock(sdata); +} + +static void ieee80211_ibss_csa_mark_radar(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + int err; + + /* if the current channel is a DFS channel, mark the channel as + * unavailable. + */ + err = cfg80211_chandef_dfs_required(sdata->local->hw.wiphy, + &ifibss->chandef, + NL80211_IFTYPE_ADHOC); + if (err > 0) + cfg80211_radar_event(sdata->local->hw.wiphy, &ifibss->chandef, + GFP_ATOMIC); +} + +static bool +ieee80211_ibss_process_chanswitch(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, + bool beacon) +{ + struct cfg80211_csa_settings params; + struct ieee80211_csa_ie csa_ie; + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + enum nl80211_channel_type ch_type; + int err; + ieee80211_conn_flags_t conn_flags; + u32 vht_cap_info = 0; + + sdata_assert_lock(sdata); + + conn_flags = IEEE80211_CONN_DISABLE_VHT; + + switch (ifibss->chandef.width) { + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + case NL80211_CHAN_WIDTH_20_NOHT: + conn_flags |= IEEE80211_CONN_DISABLE_HT; + fallthrough; + case NL80211_CHAN_WIDTH_20: + conn_flags |= IEEE80211_CONN_DISABLE_40MHZ; + break; + default: + break; + } + + if (elems->vht_cap_elem) + vht_cap_info = le32_to_cpu(elems->vht_cap_elem->vht_cap_info); + + memset(&params, 0, sizeof(params)); + err = ieee80211_parse_ch_switch_ie(sdata, elems, + ifibss->chandef.chan->band, + vht_cap_info, + conn_flags, ifibss->bssid, &csa_ie); + /* can't switch to destination channel, fail */ + if (err < 0) + goto disconnect; + + /* did not contain a CSA */ + if (err) + return false; + + /* channel switch is not supported, disconnect */ + if (!(sdata->local->hw.wiphy->flags & WIPHY_FLAG_HAS_CHANNEL_SWITCH)) + goto disconnect; + + params.count = csa_ie.count; + params.chandef = csa_ie.chandef; + + switch (ifibss->chandef.width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_40: + /* keep our current HT mode (HT20/HT40+/HT40-), even if + * another mode has been announced. The mode is not adopted + * within the beacon while doing CSA and we should therefore + * keep the mode which we announce. + */ + ch_type = cfg80211_get_chandef_type(&ifibss->chandef); + cfg80211_chandef_create(&params.chandef, params.chandef.chan, + ch_type); + break; + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + if (params.chandef.width != ifibss->chandef.width) { + sdata_info(sdata, + "IBSS %pM received channel switch from incompatible channel width (%d MHz, width:%d, CF1/2: %d/%d MHz), disconnecting\n", + ifibss->bssid, + params.chandef.chan->center_freq, + params.chandef.width, + params.chandef.center_freq1, + params.chandef.center_freq2); + goto disconnect; + } + break; + default: + /* should not happen, conn_flags should prevent VHT modes.
*/ + WARN_ON(1); + goto disconnect; + } + + if (!cfg80211_reg_can_beacon(sdata->local->hw.wiphy, &params.chandef, + NL80211_IFTYPE_ADHOC)) { + sdata_info(sdata, + "IBSS %pM switches to unsupported channel (%d MHz, width:%d, CF1/2: %d/%d MHz), disconnecting\n", + ifibss->bssid, + params.chandef.chan->center_freq, + params.chandef.width, + params.chandef.center_freq1, + params.chandef.center_freq2); + goto disconnect; + } + + err = cfg80211_chandef_dfs_required(sdata->local->hw.wiphy, + &params.chandef, + NL80211_IFTYPE_ADHOC); + if (err < 0) + goto disconnect; + if (err > 0 && !ifibss->userspace_handles_dfs) { + /* IBSS-DFS only allowed with a control program */ + goto disconnect; + } + + params.radar_required = err; + + if (cfg80211_chandef_identical(&params.chandef, + &sdata->vif.bss_conf.chandef)) { + ibss_dbg(sdata, + "received csa with an identical chandef, ignoring\n"); + return true; + } + + /* all checks done, now perform the channel switch. */ + ibss_dbg(sdata, + "received channel switch announcement to go to channel %d MHz\n", + params.chandef.chan->center_freq); + + params.block_tx = !!csa_ie.mode; + + if (ieee80211_channel_switch(sdata->local->hw.wiphy, sdata->dev, + &params)) + goto disconnect; + + ieee80211_ibss_csa_mark_radar(sdata); + + return true; +disconnect: + ibss_dbg(sdata, "Can't handle channel switch, disconnect\n"); + ieee80211_queue_work(&sdata->local->hw, + &ifibss->csa_connection_drop_work); + + ieee80211_ibss_csa_mark_radar(sdata); + + return true; +} + +static void +ieee80211_rx_mgmt_spectrum_mgmt(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status, + struct ieee802_11_elems *elems) +{ + int required_len; + + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + return; + + /* CSA is the only action we handle for now */ + if (mgmt->u.action.u.measurement.action_code != + WLAN_ACTION_SPCT_CHL_SWITCH) + return; + + required_len = IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.chan_switch); + if (len < required_len) + return; + + if (!sdata->vif.bss_conf.csa_active) + ieee80211_ibss_process_chanswitch(sdata, elems, false); +} + +static void ieee80211_rx_mgmt_deauth_ibss(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + u16 reason = le16_to_cpu(mgmt->u.deauth.reason_code); + + if (len < IEEE80211_DEAUTH_FRAME_LEN) + return; + + ibss_dbg(sdata, "RX DeAuth SA=%pM DA=%pM\n", mgmt->sa, mgmt->da); + ibss_dbg(sdata, "\tBSSID=%pM (reason: %d)\n", mgmt->bssid, reason); + sta_info_destroy_addr(sdata, mgmt->sa); +} + +static void ieee80211_rx_mgmt_auth_ibss(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + u16 auth_alg, auth_transaction; + + sdata_assert_lock(sdata); + + if (len < 24 + 6) + return; + + auth_alg = le16_to_cpu(mgmt->u.auth.auth_alg); + auth_transaction = le16_to_cpu(mgmt->u.auth.auth_transaction); + + ibss_dbg(sdata, "RX Auth SA=%pM DA=%pM\n", mgmt->sa, mgmt->da); + ibss_dbg(sdata, "\tBSSID=%pM (auth_transaction=%d)\n", + mgmt->bssid, auth_transaction); + + if (auth_alg != WLAN_AUTH_OPEN || auth_transaction != 1) + return; + + /* + * IEEE 802.11 standard does not require authentication in IBSS + * networks and most implementations do not seem to use it. + * However, try to reply to authentication attempts if someone + * has actually implemented this.
+ */ + ieee80211_send_auth(sdata, 2, WLAN_AUTH_OPEN, 0, NULL, 0, + mgmt->sa, sdata->u.ibss.bssid, NULL, 0, 0, 0); +} + +static void ieee80211_update_sta_info(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status, + struct ieee802_11_elems *elems, + struct ieee80211_channel *channel) +{ + struct sta_info *sta; + enum nl80211_band band = rx_status->band; + enum nl80211_bss_scan_width scan_width; + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + bool rates_updated = false; + u32 supp_rates = 0; + + if (sdata->vif.type != NL80211_IFTYPE_ADHOC) + return; + + if (!ether_addr_equal(mgmt->bssid, sdata->u.ibss.bssid)) + return; + + sband = local->hw.wiphy->bands[band]; + if (WARN_ON(!sband)) + return; + + rcu_read_lock(); + sta = sta_info_get(sdata, mgmt->sa); + + if (elems->supp_rates) { + supp_rates = ieee80211_sta_get_rates(sdata, elems, + band, NULL); + if (sta) { + u32 prev_rates; + + prev_rates = sta->sta.deflink.supp_rates[band]; + /* make sure mandatory rates are always added */ + scan_width = NL80211_BSS_CHAN_WIDTH_20; + if (rx_status->bw == RATE_INFO_BW_5) + scan_width = NL80211_BSS_CHAN_WIDTH_5; + else if (rx_status->bw == RATE_INFO_BW_10) + scan_width = NL80211_BSS_CHAN_WIDTH_10; + + sta->sta.deflink.supp_rates[band] = supp_rates | + ieee80211_mandatory_rates(sband, scan_width); + if (sta->sta.deflink.supp_rates[band] != prev_rates) { + ibss_dbg(sdata, + "updated supp_rates set for %pM based on beacon/probe_resp (0x%x -> 0x%x)\n", + sta->sta.addr, prev_rates, + sta->sta.deflink.supp_rates[band]); + rates_updated = true; + } + } else { + rcu_read_unlock(); + sta = ieee80211_ibss_add_sta(sdata, mgmt->bssid, + mgmt->sa, supp_rates); + } + } + + if (sta && !sta->sta.wme && + (elems->wmm_info || elems->s1g_capab) && + local->hw.queues >= IEEE80211_NUM_ACS) { + sta->sta.wme = true; + ieee80211_check_fast_xmit(sta); + } + + if (sta && elems->ht_operation && elems->ht_cap_elem && + sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_20_NOHT && + sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_5 && + sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_10) { + /* we both use HT */ + struct ieee80211_ht_cap htcap_ie; + struct cfg80211_chan_def chandef; + enum ieee80211_sta_rx_bandwidth bw = sta->sta.deflink.bandwidth; + + cfg80211_chandef_create(&chandef, channel, NL80211_CHAN_NO_HT); + ieee80211_chandef_ht_oper(elems->ht_operation, &chandef); + + memcpy(&htcap_ie, elems->ht_cap_elem, sizeof(htcap_ie)); + rates_updated |= ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband, + &htcap_ie, + &sta->deflink); + + if (elems->vht_operation && elems->vht_cap_elem && + sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_20 && + sdata->u.ibss.chandef.width != NL80211_CHAN_WIDTH_40) { + /* we both use VHT */ + struct ieee80211_vht_cap cap_ie; + struct ieee80211_sta_vht_cap cap = sta->sta.deflink.vht_cap; + u32 vht_cap_info = + le32_to_cpu(elems->vht_cap_elem->vht_cap_info); + + ieee80211_chandef_vht_oper(&local->hw, vht_cap_info, + elems->vht_operation, + elems->ht_operation, + &chandef); + memcpy(&cap_ie, elems->vht_cap_elem, sizeof(cap_ie)); + ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband, + &cap_ie, NULL, + &sta->deflink); + if (memcmp(&cap, &sta->sta.deflink.vht_cap, sizeof(cap))) + rates_updated |= true; + } + + if (bw != sta->sta.deflink.bandwidth) + rates_updated |= true; + + if (!cfg80211_chandef_compatible(&sdata->u.ibss.chandef, + &chandef)) + WARN_ON_ONCE(1); + } + + if (sta && rates_updated) { + 
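The supp_rates bitmaps updated above come from matching each received rate byte (low seven bits, again in 500 kbit/s units, bit 7 flagging a basic rate) against the band's bitrate table, as ieee80211_sta_join_ibss does earlier in this file. A minimal standalone sketch of that matching step, using a made-up band table:

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

/* Match IE rate bytes (low 7 bits = rate in 500 kbit/s units, bit 7 = basic)
 * against a band rate table given in 100 kbit/s units and build bitmaps of
 * the indices that matched. */
static uint32_t rates_to_bitmap(const uint8_t *ie, size_t ie_len,
                                const int *band_100k, size_t n_band,
                                uint32_t *basic_out)
{
    uint32_t supp = 0, basic = 0;
    size_t i, j;

    for (i = 0; i < ie_len; i++) {
        int rate = ie[i] & 0x7f;
        int is_basic = ie[i] & 0x80;

        for (j = 0; j < n_band; j++) {
            if ((band_100k[j] + 4) / 5 != rate)
                continue;
            supp |= 1u << j;
            if (is_basic)
                basic |= 1u << j;
            break;
        }
    }
    *basic_out = basic;
    return supp;
}

int main(void)
{
    const int band[] = { 10, 20, 55, 110, 60, 90, 120, 180 };
    const uint8_t ie[] = { 0x82, 0x84, 0x8b, 0x96, 0x0c, 0x18 };
    uint32_t basic, supp;

    /* prints supp=0x5f basic=0x0f for 1/2/5.5/11 (basic), 6 and 12 Mbit/s */
    supp = rates_to_bitmap(ie, sizeof(ie), band, 8, &basic);
    printf("supp=0x%02x basic=0x%02x\n", (unsigned)supp, (unsigned)basic);
    return 0;
}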
u32 changed = IEEE80211_RC_SUPP_RATES_CHANGED; + u8 rx_nss = sta->sta.deflink.rx_nss; + + /* Force rx_nss recalculation */ + sta->sta.deflink.rx_nss = 0; + rate_control_rate_init(sta); + if (sta->sta.deflink.rx_nss != rx_nss) + changed |= IEEE80211_RC_NSS_CHANGED; + + drv_sta_rc_update(local, sdata, &sta->sta, changed); + } + + rcu_read_unlock(); +} + +static void ieee80211_rx_bss_info(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status, + struct ieee802_11_elems *elems) +{ + struct ieee80211_local *local = sdata->local; + struct cfg80211_bss *cbss; + struct ieee80211_bss *bss; + struct ieee80211_channel *channel; + u64 beacon_timestamp, rx_timestamp; + u32 supp_rates = 0; + enum nl80211_band band = rx_status->band; + + channel = ieee80211_get_channel(local->hw.wiphy, rx_status->freq); + if (!channel) + return; + + ieee80211_update_sta_info(sdata, mgmt, len, rx_status, elems, channel); + + bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, channel); + if (!bss) + return; + + cbss = container_of((void *)bss, struct cfg80211_bss, priv); + + /* same for beacon and probe response */ + beacon_timestamp = le64_to_cpu(mgmt->u.beacon.timestamp); + + /* check if we need to merge IBSS */ + + /* not an IBSS */ + if (!(cbss->capability & WLAN_CAPABILITY_IBSS)) + goto put_bss; + + /* different channel */ + if (sdata->u.ibss.fixed_channel && + sdata->u.ibss.chandef.chan != cbss->channel) + goto put_bss; + + /* different SSID */ + if (elems->ssid_len != sdata->u.ibss.ssid_len || + memcmp(elems->ssid, sdata->u.ibss.ssid, + sdata->u.ibss.ssid_len)) + goto put_bss; + + /* process channel switch */ + if (sdata->vif.bss_conf.csa_active || + ieee80211_ibss_process_chanswitch(sdata, elems, true)) + goto put_bss; + + /* same BSSID */ + if (ether_addr_equal(cbss->bssid, sdata->u.ibss.bssid)) + goto put_bss; + + /* we use a fixed BSSID */ + if (sdata->u.ibss.fixed_bssid) + goto put_bss; + + if (ieee80211_have_rx_timestamp(rx_status)) { + /* time when timestamp field was received */ + rx_timestamp = + ieee80211_calculate_rx_timestamp(local, rx_status, + len + FCS_LEN, 24); + } else { + /* + * second best option: get current TSF + * (will return -1 if not supported) + */ + rx_timestamp = drv_get_tsf(local, sdata); + } + + ibss_dbg(sdata, "RX beacon SA=%pM BSSID=%pM TSF=0x%llx\n", + mgmt->sa, mgmt->bssid, + (unsigned long long)rx_timestamp); + ibss_dbg(sdata, "\tBCN=0x%llx diff=%lld @%lu\n", + (unsigned long long)beacon_timestamp, + (unsigned long long)(rx_timestamp - beacon_timestamp), + jiffies); + + if (beacon_timestamp > rx_timestamp) { + ibss_dbg(sdata, + "beacon TSF higher than local TSF - IBSS merge with BSSID %pM\n", + mgmt->bssid); + ieee80211_sta_join_ibss(sdata, bss); + supp_rates = ieee80211_sta_get_rates(sdata, elems, band, NULL); + ieee80211_ibss_add_sta(sdata, mgmt->bssid, mgmt->sa, + supp_rates); + rcu_read_unlock(); + } + + put_bss: + ieee80211_rx_bss_put(local, bss); +} + +void ieee80211_ibss_rx_no_sta(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const u8 *addr, + u32 supp_rates) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_supported_band *sband; + enum nl80211_bss_scan_width scan_width; + int band; + + /* + * XXX: Consider removing the least recently used entry and + * allow new one to be added. 
+ */ + if (local->num_sta >= IEEE80211_IBSS_MAX_STA_ENTRIES) { + net_info_ratelimited("%s: No room for a new IBSS STA entry %pM\n", + sdata->name, addr); + return; + } + + if (ifibss->state == IEEE80211_IBSS_MLME_SEARCH) + return; + + if (!ether_addr_equal(bssid, sdata->u.ibss.bssid)) + return; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON_ONCE(!chanctx_conf)) { + rcu_read_unlock(); + return; + } + band = chanctx_conf->def.chan->band; + scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def); + rcu_read_unlock(); + + sta = sta_info_alloc(sdata, addr, GFP_ATOMIC); + if (!sta) + return; + + /* make sure mandatory rates are always added */ + sband = local->hw.wiphy->bands[band]; + sta->sta.deflink.supp_rates[band] = supp_rates | + ieee80211_mandatory_rates(sband, scan_width); + + spin_lock(&ifibss->incomplete_lock); + list_add(&sta->list, &ifibss->incomplete_stations); + spin_unlock(&ifibss->incomplete_lock); + ieee80211_queue_work(&local->hw, &sdata->work); +} + +static void ieee80211_ibss_sta_expire(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta, *tmp; + unsigned long exp_time = IEEE80211_IBSS_INACTIVITY_LIMIT; + unsigned long exp_rsn = IEEE80211_IBSS_RSN_INACTIVITY_LIMIT; + + mutex_lock(&local->sta_mtx); + + list_for_each_entry_safe(sta, tmp, &local->sta_list, list) { + unsigned long last_active = ieee80211_sta_last_active(sta); + + if (sdata != sta->sdata) + continue; + + if (time_is_before_jiffies(last_active + exp_time) || + (time_is_before_jiffies(last_active + exp_rsn) && + sta->sta_state != IEEE80211_STA_AUTHORIZED)) { + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + sta_dbg(sta->sdata, "expiring inactive %sSTA %pM\n", + sta->sta_state != IEEE80211_STA_AUTHORIZED ? + "not authorized " : "", sta->sta.addr); + + ieee80211_send_deauth_disassoc(sdata, sta->sta.addr, + ifibss->bssid, + IEEE80211_STYPE_DEAUTH, + WLAN_REASON_DEAUTH_LEAVING, + true, frame_buf); + WARN_ON(__sta_info_destroy(sta)); + } + } + + mutex_unlock(&local->sta_mtx); +} + +/* + * This function is called with state == IEEE80211_IBSS_MLME_JOINED + */ + +static void ieee80211_sta_merge_ibss(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + enum nl80211_bss_scan_width scan_width; + + sdata_assert_lock(sdata); + + mod_timer(&ifibss->timer, + round_jiffies(jiffies + IEEE80211_IBSS_MERGE_INTERVAL)); + + ieee80211_ibss_sta_expire(sdata); + + if (time_before(jiffies, ifibss->last_scan_completed + + IEEE80211_IBSS_MERGE_INTERVAL)) + return; + + if (ieee80211_sta_active_ibss(sdata)) + return; + + if (ifibss->fixed_channel) + return; + + sdata_info(sdata, + "No active IBSS STAs - trying to scan for other IBSS networks with same SSID (merge)\n"); + + scan_width = cfg80211_chandef_to_scan_width(&ifibss->chandef); + ieee80211_request_ibss_scan(sdata, ifibss->ssid, ifibss->ssid_len, + NULL, 0, scan_width); +} + +static void ieee80211_sta_create_ibss(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + u8 bssid[ETH_ALEN]; + u16 capability; + int i; + + sdata_assert_lock(sdata); + + if (ifibss->fixed_bssid) { + memcpy(bssid, ifibss->bssid, ETH_ALEN); + } else { + /* Generate random, not broadcast, locally administered BSSID. Mix in + * own MAC address to make sure that devices that do not have proper + * random number generator get different BSSID. 
*/ + get_random_bytes(bssid, ETH_ALEN); + for (i = 0; i < ETH_ALEN; i++) + bssid[i] ^= sdata->vif.addr[i]; + bssid[0] &= ~0x01; + bssid[0] |= 0x02; + } + + sdata_info(sdata, "Creating new IBSS network, BSSID %pM\n", bssid); + + capability = WLAN_CAPABILITY_IBSS; + + if (ifibss->privacy) + capability |= WLAN_CAPABILITY_PRIVACY; + + __ieee80211_sta_join_ibss(sdata, bssid, sdata->vif.bss_conf.beacon_int, + &ifibss->chandef, ifibss->basic_rates, + capability, 0, true); +} + +static unsigned int ibss_setup_channels(struct wiphy *wiphy, + struct ieee80211_channel **channels, + unsigned int channels_max, + u32 center_freq, u32 width) +{ + struct ieee80211_channel *chan = NULL; + unsigned int n_chan = 0; + u32 start_freq, end_freq, freq; + + if (width <= 20) { + start_freq = center_freq; + end_freq = center_freq; + } else { + start_freq = center_freq - width / 2 + 10; + end_freq = center_freq + width / 2 - 10; + } + + for (freq = start_freq; freq <= end_freq; freq += 20) { + chan = ieee80211_get_channel(wiphy, freq); + if (!chan) + continue; + if (n_chan >= channels_max) + return n_chan; + + channels[n_chan] = chan; + n_chan++; + } + + return n_chan; +} + +static unsigned int +ieee80211_ibss_setup_scan_channels(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef, + struct ieee80211_channel **channels, + unsigned int channels_max) +{ + unsigned int n_chan = 0; + u32 width, cf1, cf2 = 0; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_40: + width = 40; + break; + case NL80211_CHAN_WIDTH_80P80: + cf2 = chandef->center_freq2; + fallthrough; + case NL80211_CHAN_WIDTH_80: + width = 80; + break; + case NL80211_CHAN_WIDTH_160: + width = 160; + break; + default: + width = 20; + break; + } + + cf1 = chandef->center_freq1; + + n_chan = ibss_setup_channels(wiphy, channels, channels_max, cf1, width); + + if (cf2) + n_chan += ibss_setup_channels(wiphy, &channels[n_chan], + channels_max - n_chan, cf2, + width); + + return n_chan; +} + +/* + * This function is called with state == IEEE80211_IBSS_MLME_SEARCH + */ + +static void ieee80211_sta_find_ibss(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + struct cfg80211_bss *cbss; + struct ieee80211_channel *chan = NULL; + const u8 *bssid = NULL; + enum nl80211_bss_scan_width scan_width; + int active_ibss; + + sdata_assert_lock(sdata); + + active_ibss = ieee80211_sta_active_ibss(sdata); + ibss_dbg(sdata, "sta_find_ibss (active_ibss=%d)\n", active_ibss); + + if (active_ibss) + return; + + if (ifibss->fixed_bssid) + bssid = ifibss->bssid; + if (ifibss->fixed_channel) + chan = ifibss->chandef.chan; + if (!is_zero_ether_addr(ifibss->bssid)) + bssid = ifibss->bssid; + cbss = cfg80211_get_bss(local->hw.wiphy, chan, bssid, + ifibss->ssid, ifibss->ssid_len, + IEEE80211_BSS_TYPE_IBSS, + IEEE80211_PRIVACY(ifibss->privacy)); + + if (cbss) { + struct ieee80211_bss *bss; + + bss = (void *)cbss->priv; + ibss_dbg(sdata, + "sta_find_ibss: selected %pM current %pM\n", + cbss->bssid, ifibss->bssid); + sdata_info(sdata, + "Selected IBSS BSSID %pM based on configured SSID\n", + cbss->bssid); + + ieee80211_sta_join_ibss(sdata, bss); + ieee80211_rx_bss_put(local, bss); + return; + } + + /* if a fixed bssid and a fixed freq have been provided create the IBSS + * directly and do not waste time scanning + */ + if (ifibss->fixed_bssid && ifibss->fixed_channel) { + sdata_info(sdata, "Created IBSS using preconfigured BSSID %pM\n", + bssid); + ieee80211_sta_create_ibss(sdata); + return; 
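The BSSID generated above is random but well-formed: after mixing the interface address into the random bytes, the individual/group bit of the first octet is cleared so the result is not a multicast address, and the locally-administered bit is set. A standalone sketch of the same bit handling, with rand() standing in for a proper random source (illustration only):

#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>

#define ETH_ALEN 6

/* Stand-in for a real random source; illustration only. */
static void rand_bytes(uint8_t *buf, size_t len)
{
    size_t i;

    for (i = 0; i < len; i++)
        buf[i] = (uint8_t)rand();
}

/* Random, locally administered, non-multicast BSSID: mix the interface
 * address into the random bytes, clear the I/G bit, set the U/L bit. */
static void make_ibss_bssid(uint8_t bssid[ETH_ALEN], const uint8_t mac[ETH_ALEN])
{
    int i;

    rand_bytes(bssid, ETH_ALEN);
    for (i = 0; i < ETH_ALEN; i++)
        bssid[i] ^= mac[i];
    bssid[0] &= ~0x01;  /* not a group (multicast) address */
    bssid[0] |= 0x02;   /* locally administered */
}

int main(void)
{
    const uint8_t mac[ETH_ALEN] = { 0x00, 0x11, 0x22, 0x33, 0x44, 0x55 };
    uint8_t bssid[ETH_ALEN];

    make_ibss_bssid(bssid, mac);
    printf("%02x:%02x:%02x:%02x:%02x:%02x\n",
           bssid[0], bssid[1], bssid[2], bssid[3], bssid[4], bssid[5]);
    return 0;
}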
+ } + + + ibss_dbg(sdata, "sta_find_ibss: did not try to join ibss\n"); + + /* Selected IBSS not found in current scan results - try to scan */ + if (time_after(jiffies, ifibss->last_scan_completed + + IEEE80211_SCAN_INTERVAL)) { + struct ieee80211_channel *channels[8]; + unsigned int num; + + sdata_info(sdata, "Trigger new scan to find an IBSS to join\n"); + + scan_width = cfg80211_chandef_to_scan_width(&ifibss->chandef); + + if (ifibss->fixed_channel) { + num = ieee80211_ibss_setup_scan_channels(local->hw.wiphy, + &ifibss->chandef, + channels, + ARRAY_SIZE(channels)); + ieee80211_request_ibss_scan(sdata, ifibss->ssid, + ifibss->ssid_len, channels, + num, scan_width); + } else { + ieee80211_request_ibss_scan(sdata, ifibss->ssid, + ifibss->ssid_len, NULL, + 0, scan_width); + } + } else { + int interval = IEEE80211_SCAN_INTERVAL; + + if (time_after(jiffies, ifibss->ibss_join_req + + IEEE80211_IBSS_JOIN_TIMEOUT)) + ieee80211_sta_create_ibss(sdata); + + mod_timer(&ifibss->timer, + round_jiffies(jiffies + interval)); + } +} + +static void ieee80211_rx_mgmt_probe_req(struct ieee80211_sub_if_data *sdata, + struct sk_buff *req) +{ + struct ieee80211_mgmt *mgmt = (void *)req->data; + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_local *local = sdata->local; + int tx_last_beacon, len = req->len; + struct sk_buff *skb; + struct beacon_data *presp; + u8 *pos, *end; + + sdata_assert_lock(sdata); + + presp = sdata_dereference(ifibss->presp, sdata); + + if (ifibss->state != IEEE80211_IBSS_MLME_JOINED || + len < 24 + 2 || !presp) + return; + + tx_last_beacon = drv_tx_last_beacon(local); + + ibss_dbg(sdata, "RX ProbeReq SA=%pM DA=%pM\n", mgmt->sa, mgmt->da); + ibss_dbg(sdata, "\tBSSID=%pM (tx_last_beacon=%d)\n", + mgmt->bssid, tx_last_beacon); + + if (!tx_last_beacon && is_multicast_ether_addr(mgmt->da)) + return; + + if (!ether_addr_equal(mgmt->bssid, ifibss->bssid) && + !is_broadcast_ether_addr(mgmt->bssid)) + return; + + end = ((u8 *) mgmt) + len; + pos = mgmt->u.probe_req.variable; + if (pos[0] != WLAN_EID_SSID || + pos + 2 + pos[1] > end) { + ibss_dbg(sdata, "Invalid SSID IE in ProbeReq from %pM\n", + mgmt->sa); + return; + } + if (pos[1] != 0 && + (pos[1] != ifibss->ssid_len || + memcmp(pos + 2, ifibss->ssid, ifibss->ssid_len))) { + /* Ignore ProbeReq for foreign SSID */ + return; + } + + /* Reply with ProbeResp */ + skb = dev_alloc_skb(local->tx_headroom + presp->head_len); + if (!skb) + return; + + skb_reserve(skb, local->tx_headroom); + skb_put_data(skb, presp->head, presp->head_len); + + memcpy(((struct ieee80211_mgmt *) skb->data)->da, mgmt->sa, ETH_ALEN); + ibss_dbg(sdata, "Sending ProbeResp to %pM\n", mgmt->sa); + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + + /* avoid excessive retries for probe request to wildcard SSIDs */ + if (pos[1] == 0) + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_NO_ACK; + + ieee80211_tx_skb(sdata, skb); +} + +static +void ieee80211_rx_mgmt_probe_beacon(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status) +{ + size_t baselen; + struct ieee802_11_elems *elems; + + BUILD_BUG_ON(offsetof(typeof(mgmt->u.probe_resp), variable) != + offsetof(typeof(mgmt->u.beacon), variable)); + + /* + * either beacon or probe_resp but the variable field is at the + * same offset + */ + baselen = (u8 *) mgmt->u.probe_resp.variable - (u8 *) mgmt; + if (baselen > len) + return; + + elems = ieee802_11_parse_elems(mgmt->u.probe_resp.variable, + len - baselen, false, NULL); + + 
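The probe request handler above only trusts the SSID element after checking that its two-byte header plus pos[1] payload bytes fit before end; the element parsers invoked here apply the same rule to every element. A small standalone sketch of that bounds-checked TLV walk (hypothetical helper, not the kernel's ieee802_11_parse_elems):

#include <stdio.h>
#include <stdint.h>
#include <stddef.h>

/* Walk an information-element buffer: each element is one ID byte, one
 * length byte, then that many payload bytes.  Stop (and report) if an
 * element would run past the end of the buffer. */
static void walk_elems(const uint8_t *buf, size_t len)
{
    const uint8_t *pos = buf, *end = buf + len;

    while (pos + 2 <= end) {
        unsigned int id = pos[0], elen = pos[1];

        if (pos + 2 + elen > end) {
            printf("truncated element %u\n", id);
            return;
        }
        printf("element %u, %u byte(s)\n", id, elen);
        pos += 2 + elen;
    }
}

int main(void)
{
    /* SSID "ibss" followed by a Supported Rates element */
    const uint8_t ies[] = { 0, 4, 'i', 'b', 's', 's',
                            1, 4, 0x82, 0x84, 0x8b, 0x96 };

    walk_elems(ies, sizeof(ies));
    return 0;
}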
if (elems) { + ieee80211_rx_bss_info(sdata, mgmt, len, rx_status, elems); + kfree(elems); + } +} + +void ieee80211_ibss_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_rx_status *rx_status; + struct ieee80211_mgmt *mgmt; + u16 fc; + struct ieee802_11_elems *elems; + int ies_len; + + rx_status = IEEE80211_SKB_RXCB(skb); + mgmt = (struct ieee80211_mgmt *) skb->data; + fc = le16_to_cpu(mgmt->frame_control); + + sdata_lock(sdata); + + if (!sdata->u.ibss.ssid_len) + goto mgmt_out; /* not ready to merge yet */ + + switch (fc & IEEE80211_FCTL_STYPE) { + case IEEE80211_STYPE_PROBE_REQ: + ieee80211_rx_mgmt_probe_req(sdata, skb); + break; + case IEEE80211_STYPE_PROBE_RESP: + case IEEE80211_STYPE_BEACON: + ieee80211_rx_mgmt_probe_beacon(sdata, mgmt, skb->len, + rx_status); + break; + case IEEE80211_STYPE_AUTH: + ieee80211_rx_mgmt_auth_ibss(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_DEAUTH: + ieee80211_rx_mgmt_deauth_ibss(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_ACTION: + switch (mgmt->u.action.category) { + case WLAN_CATEGORY_SPECTRUM_MGMT: + ies_len = skb->len - + offsetof(struct ieee80211_mgmt, + u.action.u.chan_switch.variable); + + if (ies_len < 0) + break; + + elems = ieee802_11_parse_elems( + mgmt->u.action.u.chan_switch.variable, + ies_len, true, NULL); + + if (elems && !elems->parse_error) + ieee80211_rx_mgmt_spectrum_mgmt(sdata, mgmt, + skb->len, + rx_status, + elems); + kfree(elems); + break; + } + } + + mgmt_out: + sdata_unlock(sdata); +} + +void ieee80211_ibss_work(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct sta_info *sta; + + sdata_lock(sdata); + + /* + * Work could be scheduled after scan or similar + * when we aren't even joined (or trying) with a + * network. 
*/ + if (!ifibss->ssid_len) + goto out; + + spin_lock_bh(&ifibss->incomplete_lock); + while (!list_empty(&ifibss->incomplete_stations)) { + sta = list_first_entry(&ifibss->incomplete_stations, + struct sta_info, list); + list_del(&sta->list); + spin_unlock_bh(&ifibss->incomplete_lock); + + ieee80211_ibss_finish_sta(sta); + rcu_read_unlock(); + spin_lock_bh(&ifibss->incomplete_lock); + } + spin_unlock_bh(&ifibss->incomplete_lock); + + switch (ifibss->state) { + case IEEE80211_IBSS_MLME_SEARCH: + ieee80211_sta_find_ibss(sdata); + break; + case IEEE80211_IBSS_MLME_JOINED: + ieee80211_sta_merge_ibss(sdata); + break; + default: + WARN_ON(1); + break; + } + + out: + sdata_unlock(sdata); +} + +static void ieee80211_ibss_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.ibss.timer); + + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +void ieee80211_ibss_setup_sdata(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + + timer_setup(&ifibss->timer, ieee80211_ibss_timer, 0); + INIT_LIST_HEAD(&ifibss->incomplete_stations); + spin_lock_init(&ifibss->incomplete_lock); + INIT_WORK(&ifibss->csa_connection_drop_work, + ieee80211_csa_connection_drop_work); +} + +/* scan finished notification */ +void ieee80211_ibss_notify_scan_completed(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + if (sdata->vif.type != NL80211_IFTYPE_ADHOC) + continue; + sdata->u.ibss.last_scan_completed = jiffies; + } + mutex_unlock(&local->iflist_mtx); +} + +int ieee80211_ibss_join(struct ieee80211_sub_if_data *sdata, + struct cfg80211_ibss_params *params) +{ + u32 changed = 0; + u32 rate_flags; + struct ieee80211_supported_band *sband; + enum ieee80211_chanctx_mode chanmode; + struct ieee80211_local *local = sdata->local; + int radar_detect_width = 0; + int i; + int ret; + + if (params->chandef.chan->freq_offset) { + /* this may work, but is untested */ + return -EOPNOTSUPP; + } + + ret = cfg80211_chandef_dfs_required(local->hw.wiphy, + &params->chandef, + sdata->wdev.iftype); + if (ret < 0) + return ret; + + if (ret > 0) { + if (!params->userspace_handles_dfs) + return -EINVAL; + radar_detect_width = BIT(params->chandef.width); + } + + chanmode = (params->channel_fixed && !ret) ?
+ IEEE80211_CHANCTX_SHARED : IEEE80211_CHANCTX_EXCLUSIVE; + + mutex_lock(&local->chanctx_mtx); + ret = ieee80211_check_combinations(sdata, &params->chandef, chanmode, + radar_detect_width); + mutex_unlock(&local->chanctx_mtx); + if (ret < 0) + return ret; + + if (params->bssid) { + memcpy(sdata->u.ibss.bssid, params->bssid, ETH_ALEN); + sdata->u.ibss.fixed_bssid = true; + } else + sdata->u.ibss.fixed_bssid = false; + + sdata->u.ibss.privacy = params->privacy; + sdata->u.ibss.control_port = params->control_port; + sdata->u.ibss.userspace_handles_dfs = params->userspace_handles_dfs; + sdata->u.ibss.basic_rates = params->basic_rates; + sdata->u.ibss.last_scan_completed = jiffies; + + /* fix basic_rates if channel does not support these rates */ + rate_flags = ieee80211_chandef_rate_flags(&params->chandef); + sband = local->hw.wiphy->bands[params->chandef.chan->band]; + for (i = 0; i < sband->n_bitrates; i++) { + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + sdata->u.ibss.basic_rates &= ~BIT(i); + } + memcpy(sdata->vif.bss_conf.mcast_rate, params->mcast_rate, + sizeof(params->mcast_rate)); + + sdata->vif.bss_conf.beacon_int = params->beacon_interval; + + sdata->u.ibss.chandef = params->chandef; + sdata->u.ibss.fixed_channel = params->channel_fixed; + + if (params->ie) { + sdata->u.ibss.ie = kmemdup(params->ie, params->ie_len, + GFP_KERNEL); + if (sdata->u.ibss.ie) + sdata->u.ibss.ie_len = params->ie_len; + } + + sdata->u.ibss.state = IEEE80211_IBSS_MLME_SEARCH; + sdata->u.ibss.ibss_join_req = jiffies; + + memcpy(sdata->u.ibss.ssid, params->ssid, params->ssid_len); + sdata->u.ibss.ssid_len = params->ssid_len; + + memcpy(&sdata->u.ibss.ht_capa, &params->ht_capa, + sizeof(sdata->u.ibss.ht_capa)); + memcpy(&sdata->u.ibss.ht_capa_mask, &params->ht_capa_mask, + sizeof(sdata->u.ibss.ht_capa_mask)); + + /* + * 802.11n-2009 9.13.3.1: In an IBSS, the HT Protection field is + * reserved, but an HT STA shall protect HT transmissions as though + * the HT Protection field were set to non-HT mixed mode. + * + * In an IBSS, the RIFS Mode field of the HT Operation element is + * also reserved, but an HT STA shall operate as though this field + * were set to 1.
+ */ + + sdata->vif.bss_conf.ht_operation_mode |= + IEEE80211_HT_OP_MODE_PROTECTION_NONHT_MIXED + | IEEE80211_HT_PARAM_RIFS_MODE; + + changed |= BSS_CHANGED_HT | BSS_CHANGED_MCAST_RATE; + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + + sdata->deflink.smps_mode = IEEE80211_SMPS_OFF; + sdata->deflink.needed_rx_chains = local->rx_chains; + sdata->control_port_over_nl80211 = params->control_port_over_nl80211; + + ieee80211_queue_work(&local->hw, &sdata->work); + + return 0; +} + +int ieee80211_ibss_leave(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + + ieee80211_ibss_disconnect(sdata); + ifibss->ssid_len = 0; + eth_zero_addr(ifibss->bssid); + + /* remove beacon */ + kfree(sdata->u.ibss.ie); + sdata->u.ibss.ie = NULL; + sdata->u.ibss.ie_len = 0; + + /* on the next join, re-program HT parameters */ + memset(&ifibss->ht_capa, 0, sizeof(ifibss->ht_capa)); + memset(&ifibss->ht_capa_mask, 0, sizeof(ifibss->ht_capa_mask)); + + synchronize_rcu(); + + skb_queue_purge(&sdata->skb_queue); + + del_timer_sync(&sdata->u.ibss.timer); + + return 0; +} diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h new file mode 100644 index 000000000..d5dd2d9e8 --- /dev/null +++ b/net/mac80211/ieee80211_i.h @@ -0,0 +1,2559 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007-2010 Johannes Berg + * Copyright 2013-2015 Intel Mobile Communications GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#ifndef IEEE80211_I_H +#define IEEE80211_I_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "key.h" +#include "sta_info.h" +#include "debug.h" + +extern const struct cfg80211_ops mac80211_config_ops; + +struct ieee80211_local; + +/* Maximum number of broadcast/multicast frames to buffer when some of the + * associated stations are using power saving. */ +#define AP_MAX_BC_BUFFER 128 + +/* Maximum number of frames buffered to all STAs, including multicast frames. + * Note: increasing this limit increases the potential memory requirement. Each + * frame can be up to about 2 kB long. */ +#define TOTAL_MAX_TX_BUFFER 512 + +/* Required encryption head and tailroom */ +#define IEEE80211_ENCRYPT_HEADROOM 8 +#define IEEE80211_ENCRYPT_TAILROOM 18 + +/* power level hasn't been configured (or set to automatic) */ +#define IEEE80211_UNSET_POWER_LEVEL INT_MIN + +/* + * Some APs experience problems when working with U-APSD. Decreasing the + * probability of that happening by using legacy mode for all ACs but VO isn't + * enough. + * + * Cisco 4410N originally forced us to enable VO by default only because it + * treated non-VO ACs as legacy. + * + * However some APs (notably Netgear R7000) silently reclassify packets to + * different ACs. Since u-APSD ACs require trigger frames for frame retrieval + * clients would never see some frames (e.g. ARP responses) or would fetch them + * accidentally after a long time. + * + * It makes little sense to enable u-APSD queues by default because it needs + * userspace applications to be aware of it to actually take advantage of the + * possible additional powersavings. Implicitly depending on driver autotrigger + * frame support doesn't make much sense. 
+ */ +#define IEEE80211_DEFAULT_UAPSD_QUEUES 0 + +#define IEEE80211_DEFAULT_MAX_SP_LEN \ + IEEE80211_WMM_IE_STA_QOSINFO_SP_ALL + +extern const u8 ieee80211_ac_to_qos_mask[IEEE80211_NUM_ACS]; + +#define IEEE80211_DEAUTH_FRAME_LEN (24 /* hdr */ + 2 /* reason */) + +#define IEEE80211_MAX_NAN_INSTANCE_ID 255 + + +/* + * Keep a station's queues on the active list for deficit accounting purposes + * if it was active or queued during the last 100ms + */ +#define AIRTIME_ACTIVE_DURATION (HZ / 10) + +struct ieee80211_bss { + u32 device_ts_beacon, device_ts_presp; + + bool wmm_used; + bool uapsd_supported; + +#define IEEE80211_MAX_SUPP_RATES 32 + u8 supp_rates[IEEE80211_MAX_SUPP_RATES]; + size_t supp_rates_len; + struct ieee80211_rate *beacon_rate; + + u32 vht_cap_info; + + /* + * During association, we save an ERP value from a probe response so + * that we can feed ERP info to the driver when handling the + * association completes. these fields probably won't be up-to-date + * otherwise, you probably don't want to use them. + */ + bool has_erp_value; + u8 erp_value; + + /* Keep track of the corruption of the last beacon/probe response. */ + u8 corrupt_data; + + /* Keep track of what bits of information we have valid info for. */ + u8 valid_data; +}; + +/** + * enum ieee80211_corrupt_data_flags - BSS data corruption flags + * @IEEE80211_BSS_CORRUPT_BEACON: last beacon frame received was corrupted + * @IEEE80211_BSS_CORRUPT_PROBE_RESP: last probe response received was corrupted + * + * These are bss flags that are attached to a bss in the + * @corrupt_data field of &struct ieee80211_bss. + */ +enum ieee80211_bss_corrupt_data_flags { + IEEE80211_BSS_CORRUPT_BEACON = BIT(0), + IEEE80211_BSS_CORRUPT_PROBE_RESP = BIT(1) +}; + +/** + * enum ieee80211_valid_data_flags - BSS valid data flags + * @IEEE80211_BSS_VALID_WMM: WMM/UAPSD data was gathered from non-corrupt IE + * @IEEE80211_BSS_VALID_RATES: Supported rates were gathered from non-corrupt IE + * @IEEE80211_BSS_VALID_ERP: ERP flag was gathered from non-corrupt IE + * + * These are bss flags that are attached to a bss in the + * @valid_data field of &struct ieee80211_bss. They show which parts + * of the data structure were received as a result of an un-corrupted + * beacon/probe response. 
+ */ +enum ieee80211_bss_valid_data_flags { + IEEE80211_BSS_VALID_WMM = BIT(1), + IEEE80211_BSS_VALID_RATES = BIT(2), + IEEE80211_BSS_VALID_ERP = BIT(3) +}; + +typedef unsigned __bitwise ieee80211_tx_result; +#define TX_CONTINUE ((__force ieee80211_tx_result) 0u) +#define TX_DROP ((__force ieee80211_tx_result) 1u) +#define TX_QUEUED ((__force ieee80211_tx_result) 2u) + +#define IEEE80211_TX_UNICAST BIT(1) +#define IEEE80211_TX_PS_BUFFERED BIT(2) + +struct ieee80211_tx_data { + struct sk_buff *skb; + struct sk_buff_head skbs; + struct ieee80211_local *local; + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + struct ieee80211_key *key; + struct ieee80211_tx_rate rate; + + unsigned int flags; +}; + + +typedef unsigned __bitwise ieee80211_rx_result; +#define RX_CONTINUE ((__force ieee80211_rx_result) 0u) +#define RX_DROP_UNUSABLE ((__force ieee80211_rx_result) 1u) +#define RX_DROP_MONITOR ((__force ieee80211_rx_result) 2u) +#define RX_QUEUED ((__force ieee80211_rx_result) 3u) + +/** + * enum ieee80211_packet_rx_flags - packet RX flags + * @IEEE80211_RX_AMSDU: a-MSDU packet + * @IEEE80211_RX_MALFORMED_ACTION_FRM: action frame is malformed + * @IEEE80211_RX_DEFERRED_RELEASE: frame was subjected to receive reordering + * + * These are per-frame flags that are attached to a frame in the + * @rx_flags field of &struct ieee80211_rx_status. + */ +enum ieee80211_packet_rx_flags { + IEEE80211_RX_AMSDU = BIT(3), + IEEE80211_RX_MALFORMED_ACTION_FRM = BIT(4), + IEEE80211_RX_DEFERRED_RELEASE = BIT(5), +}; + +/** + * enum ieee80211_rx_flags - RX data flags + * + * @IEEE80211_RX_CMNTR: received on cooked monitor already + * @IEEE80211_RX_BEACON_REPORTED: This frame was already reported + * to cfg80211_report_obss_beacon(). + * + * These flags are used across handling multiple interfaces + * for a single frame. + */ +enum ieee80211_rx_flags { + IEEE80211_RX_CMNTR = BIT(0), + IEEE80211_RX_BEACON_REPORTED = BIT(1), +}; + +struct ieee80211_rx_data { + struct list_head *list; + struct sk_buff *skb; + struct ieee80211_local *local; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_link_data *link; + struct sta_info *sta; + struct link_sta_info *link_sta; + struct ieee80211_key *key; + + unsigned int flags; + + /* + * Index into sequence numbers array, 0..16 + * since the last (16) is used for non-QoS, + * will be 16 on non-QoS frames. + */ + int seqno_idx; + + /* + * Index into the security IV/PN arrays, 0..16 + * since the last (16) is used for CCMP-encrypted + * management frames, will be set to 16 on mgmt + * frames and 0 on non-QoS frames. 
+ */ + int security_idx; + + int link_id; + + union { + struct { + u32 iv32; + u16 iv16; + } tkip; + struct { + u8 pn[IEEE80211_CCMP_PN_LEN]; + } ccm_gcm; + }; +}; + +struct ieee80211_csa_settings { + const u16 *counter_offsets_beacon; + const u16 *counter_offsets_presp; + + int n_counter_offsets_beacon; + int n_counter_offsets_presp; + + u8 count; +}; + +struct ieee80211_color_change_settings { + u16 counter_offset_beacon; + u16 counter_offset_presp; + u8 count; +}; + +struct beacon_data { + u8 *head, *tail; + int head_len, tail_len; + struct ieee80211_meshconf_ie *meshconf; + u16 cntdwn_counter_offsets[IEEE80211_MAX_CNTDWN_COUNTERS_NUM]; + u8 cntdwn_current_counter; + struct cfg80211_mbssid_elems *mbssid_ies; + struct rcu_head rcu_head; +}; + +struct probe_resp { + struct rcu_head rcu_head; + int len; + u16 cntdwn_counter_offsets[IEEE80211_MAX_CNTDWN_COUNTERS_NUM]; + u8 data[]; +}; + +struct fils_discovery_data { + struct rcu_head rcu_head; + int len; + u8 data[]; +}; + +struct unsol_bcast_probe_resp_data { + struct rcu_head rcu_head; + int len; + u8 data[]; +}; + +struct ps_data { + /* yes, this looks ugly, but guarantees that we can later use + * bitmap_empty :) + * NB: don't touch this bitmap, use sta_info_{set,clear}_tim_bit */ + u8 tim[sizeof(unsigned long) * BITS_TO_LONGS(IEEE80211_MAX_AID + 1)] + __aligned(__alignof__(unsigned long)); + struct sk_buff_head bc_buf; + atomic_t num_sta_ps; /* number of stations in PS mode */ + int dtim_count; + bool dtim_bc_mc; +}; + +struct ieee80211_if_ap { + struct list_head vlans; /* write-protected with RTNL and local->mtx */ + + struct ps_data ps; + atomic_t num_mcast_sta; /* number of stations receiving multicast */ + + bool multicast_to_unicast; + bool active; +}; + +struct ieee80211_if_vlan { + struct list_head list; /* write-protected with RTNL and local->mtx */ + + /* used for all tx if the VLAN is configured to 4-addr mode */ + struct sta_info __rcu *sta; + atomic_t num_mcast_sta; /* number of stations receiving multicast */ +}; + +struct mesh_stats { + __u32 fwded_mcast; /* Mesh forwarded multicast frames */ + __u32 fwded_unicast; /* Mesh forwarded unicast frames */ + __u32 fwded_frames; /* Mesh total forwarded frames */ + __u32 dropped_frames_ttl; /* Not transmitted since mesh_ttl == 0*/ + __u32 dropped_frames_no_route; /* Not transmitted, no route found */ + __u32 dropped_frames_congestion;/* Not forwarded due to congestion */ +}; + +#define PREQ_Q_F_START 0x1 +#define PREQ_Q_F_REFRESH 0x2 +struct mesh_preq_queue { + struct list_head list; + u8 dst[ETH_ALEN]; + u8 flags; +}; + +struct ieee80211_roc_work { + struct list_head list; + + struct ieee80211_sub_if_data *sdata; + + struct ieee80211_channel *chan; + + bool started, abort, hw_begun, notified; + bool on_channel; + + unsigned long start_time; + + u32 duration, req_duration; + struct sk_buff *frame; + u64 cookie, mgmt_tx_cookie; + enum ieee80211_roc_type type; +}; + +/* flags used in struct ieee80211_if_managed.flags */ +enum ieee80211_sta_flags { + IEEE80211_STA_CONNECTION_POLL = BIT(1), + IEEE80211_STA_CONTROL_PORT = BIT(2), + IEEE80211_STA_MFP_ENABLED = BIT(6), + IEEE80211_STA_UAPSD_ENABLED = BIT(7), + IEEE80211_STA_NULLFUNC_ACKED = BIT(8), + IEEE80211_STA_ENABLE_RRM = BIT(15), +}; + +typedef u32 __bitwise ieee80211_conn_flags_t; + +enum ieee80211_conn_flags { + IEEE80211_CONN_DISABLE_HT = (__force ieee80211_conn_flags_t)BIT(0), + IEEE80211_CONN_DISABLE_40MHZ = (__force ieee80211_conn_flags_t)BIT(1), + IEEE80211_CONN_DISABLE_VHT = (__force ieee80211_conn_flags_t)BIT(2), + 
IEEE80211_CONN_DISABLE_80P80MHZ = (__force ieee80211_conn_flags_t)BIT(3), + IEEE80211_CONN_DISABLE_160MHZ = (__force ieee80211_conn_flags_t)BIT(4), + IEEE80211_CONN_DISABLE_HE = (__force ieee80211_conn_flags_t)BIT(5), + IEEE80211_CONN_DISABLE_EHT = (__force ieee80211_conn_flags_t)BIT(6), + IEEE80211_CONN_DISABLE_320MHZ = (__force ieee80211_conn_flags_t)BIT(7), +}; + +struct ieee80211_mgd_auth_data { + struct cfg80211_bss *bss; + unsigned long timeout; + int tries; + u16 algorithm, expected_transaction; + + u8 key[WLAN_KEY_LEN_WEP104]; + u8 key_len, key_idx; + bool done, waiting; + bool peer_confirmed; + bool timeout_started; + int link_id; + + u8 ap_addr[ETH_ALEN] __aligned(2); + + u16 sae_trans, sae_status; + size_t data_len; + u8 data[]; +}; + +struct ieee80211_mgd_assoc_data { + struct { + struct cfg80211_bss *bss; + + u8 addr[ETH_ALEN] __aligned(2); + + u8 ap_ht_param; + + struct ieee80211_vht_cap ap_vht_cap; + + size_t elems_len; + u8 *elems; /* pointing to inside ie[] below */ + + ieee80211_conn_flags_t conn_flags; + } link[IEEE80211_MLD_MAX_NUM_LINKS]; + + u8 ap_addr[ETH_ALEN] __aligned(2); + + /* this is for a workaround, so we use it only for non-MLO */ + const u8 *supp_rates; + u8 supp_rates_len; + + unsigned long timeout; + int tries; + + u8 prev_ap_addr[ETH_ALEN]; + u8 ssid[IEEE80211_MAX_SSID_LEN]; + u8 ssid_len; + bool wmm, uapsd; + bool need_beacon; + bool synced; + bool timeout_started; + bool s1g; + + unsigned int assoc_link_id; + + u8 fils_nonces[2 * FILS_NONCE_LEN]; + u8 fils_kek[FILS_MAX_KEK_LEN]; + size_t fils_kek_len; + + size_t ie_len; + u8 *ie_pos; /* used to fill ie[] with link[].elems */ + u8 ie[]; +}; + +struct ieee80211_sta_tx_tspec { + /* timestamp of the first packet in the time slice */ + unsigned long time_slice_start; + + u32 admitted_time; /* in usecs, unlike over the air */ + u8 tsid; + s8 up; /* signed to be able to invalidate with -1 during teardown */ + + /* consumed TX time in microseconds in the time slice */ + u32 consumed_tx_time; + enum { + TX_TSPEC_ACTION_NONE = 0, + TX_TSPEC_ACTION_DOWNGRADE, + TX_TSPEC_ACTION_STOP_DOWNGRADE, + } action; + bool downgraded; +}; + +DECLARE_EWMA(beacon_signal, 4, 4) + +struct ieee80211_if_managed { + struct timer_list timer; + struct timer_list conn_mon_timer; + struct timer_list bcn_mon_timer; + struct work_struct monitor_work; + struct work_struct beacon_connection_loss_work; + struct work_struct csa_connection_drop_work; + + unsigned long beacon_timeout; + unsigned long probe_timeout; + int probe_send_count; + bool nullfunc_failed; + u8 connection_loss:1, + driver_disconnect:1, + reconnect:1, + associated:1; + + struct ieee80211_mgd_auth_data *auth_data; + struct ieee80211_mgd_assoc_data *assoc_data; + + bool powersave; /* powersave requested for this iface */ + bool broken_ap; /* AP is broken -- turn off powersave */ + + unsigned int flags; + + bool status_acked; + bool status_received; + __le16 status_fc; + + enum { + IEEE80211_MFP_DISABLED, + IEEE80211_MFP_OPTIONAL, + IEEE80211_MFP_REQUIRED + } mfp; /* management frame protection */ + + /* + * Bitmask of enabled u-apsd queues, + * IEEE80211_WMM_IE_STA_QOSINFO_AC_BE & co. Needs a new association + * to take effect. + */ + unsigned int uapsd_queues; + + /* + * Maximum number of buffered frames AP can deliver during a + * service period, IEEE80211_WMM_IE_STA_QOSINFO_SP_ALL or similar. + * Needs a new association to take effect. 
+ */ + unsigned int uapsd_max_sp_len; + + u8 use_4addr; + + /* + * State variables for keeping track of RSSI of the AP currently + * connected to and informing driver when RSSI has gone + * below/above a certain threshold. + */ + int rssi_min_thold, rssi_max_thold; + + struct ieee80211_ht_cap ht_capa; /* configured ht-cap over-rides */ + struct ieee80211_ht_cap ht_capa_mask; /* Valid parts of ht_capa */ + struct ieee80211_vht_cap vht_capa; /* configured VHT overrides */ + struct ieee80211_vht_cap vht_capa_mask; /* Valid parts of vht_capa */ + struct ieee80211_s1g_cap s1g_capa; /* configured S1G overrides */ + struct ieee80211_s1g_cap s1g_capa_mask; /* valid s1g_capa bits */ + + /* TDLS support */ + u8 tdls_peer[ETH_ALEN] __aligned(2); + struct delayed_work tdls_peer_del_work; + struct sk_buff *orig_teardown_skb; /* The original teardown skb */ + struct sk_buff *teardown_skb; /* A copy to send through the AP */ + spinlock_t teardown_lock; /* To lock changing teardown_skb */ + bool tdls_wider_bw_prohibited; + + /* WMM-AC TSPEC support */ + struct ieee80211_sta_tx_tspec tx_tspec[IEEE80211_NUM_ACS]; + /* Use a separate work struct so that we can do something here + * while the sdata->work is flushing the queues, for example. + * otherwise, in scenarios where we hardly get any traffic out + * on the BE queue, but there's a lot of VO traffic, we might + * get stuck in a downgraded situation and flush takes forever. + */ + struct delayed_work tx_tspec_wk; + + /* Information elements from the last transmitted (Re)Association + * Request frame. + */ + u8 *assoc_req_ies; + size_t assoc_req_ies_len; +}; + +struct ieee80211_if_ibss { + struct timer_list timer; + struct work_struct csa_connection_drop_work; + + unsigned long last_scan_completed; + + u32 basic_rates; + + bool fixed_bssid; + bool fixed_channel; + bool privacy; + + bool control_port; + bool userspace_handles_dfs; + + u8 bssid[ETH_ALEN] __aligned(2); + u8 ssid[IEEE80211_MAX_SSID_LEN]; + u8 ssid_len, ie_len; + u8 *ie; + struct cfg80211_chan_def chandef; + + unsigned long ibss_join_req; + /* probe response/beacon for IBSS */ + struct beacon_data __rcu *presp; + + struct ieee80211_ht_cap ht_capa; /* configured ht-cap over-rides */ + struct ieee80211_ht_cap ht_capa_mask; /* Valid parts of ht_capa */ + + spinlock_t incomplete_lock; + struct list_head incomplete_stations; + + enum { + IEEE80211_IBSS_MLME_SEARCH, + IEEE80211_IBSS_MLME_JOINED, + } state; +}; + +/** + * struct ieee80211_if_ocb - OCB mode state + * + * @housekeeping_timer: timer for periodic invocation of a housekeeping task + * @wrkq_flags: OCB deferred task action + * @incomplete_lock: delayed STA insertion lock + * @incomplete_stations: list of STAs waiting for delayed insertion + * @joined: indication if the interface is connected to an OCB network + */ +struct ieee80211_if_ocb { + struct timer_list housekeeping_timer; + unsigned long wrkq_flags; + + spinlock_t incomplete_lock; + struct list_head incomplete_stations; + + bool joined; +}; + +/** + * struct ieee80211_mesh_sync_ops - Extensible synchronization framework interface + * + * these declarations define the interface, which enables + * vendor-specific mesh synchronization + * + */ +struct ieee802_11_elems; +struct ieee80211_mesh_sync_ops { + void (*rx_bcn_presp)(struct ieee80211_sub_if_data *sdata, u16 stype, + struct ieee80211_mgmt *mgmt, unsigned int len, + const struct ieee80211_meshconf_ie *mesh_cfg, + struct ieee80211_rx_status *rx_status); + + /* should be called with beacon_data under RCU read lock */ + void 
(*adjust_tsf)(struct ieee80211_sub_if_data *sdata, + struct beacon_data *beacon); + /* add other framework functions here */ +}; + +struct mesh_csa_settings { + struct rcu_head rcu_head; + struct cfg80211_csa_settings settings; +}; + +/** + * struct mesh_table + * + * @known_gates: list of known mesh gates and their mpaths by the station. The + * gate's mpath may or may not be resolved and active. + * @gates_lock: protects updates to known_gates + * @rhead: the rhashtable containing struct mesh_paths, keyed by dest addr + * @walk_head: linked list containing all mesh_path objects + * @walk_lock: lock protecting walk_head + * @entries: number of entries in the table + */ +struct mesh_table { + struct hlist_head known_gates; + spinlock_t gates_lock; + struct rhashtable rhead; + struct hlist_head walk_head; + spinlock_t walk_lock; + atomic_t entries; /* Up to MAX_MESH_NEIGHBOURS */ +}; + +struct ieee80211_if_mesh { + struct timer_list housekeeping_timer; + struct timer_list mesh_path_timer; + struct timer_list mesh_path_root_timer; + + unsigned long wrkq_flags; + unsigned long mbss_changed; + + bool userspace_handles_dfs; + + u8 mesh_id[IEEE80211_MAX_MESH_ID_LEN]; + size_t mesh_id_len; + /* Active Path Selection Protocol Identifier */ + u8 mesh_pp_id; + /* Active Path Selection Metric Identifier */ + u8 mesh_pm_id; + /* Congestion Control Mode Identifier */ + u8 mesh_cc_id; + /* Synchronization Protocol Identifier */ + u8 mesh_sp_id; + /* Authentication Protocol Identifier */ + u8 mesh_auth_id; + /* Local mesh Sequence Number */ + u32 sn; + /* Last used PREQ ID */ + u32 preq_id; + atomic_t mpaths; + /* Timestamp of last SN update */ + unsigned long last_sn_update; + /* Time when it's ok to send next PERR */ + unsigned long next_perr; + /* Timestamp of last PREQ sent */ + unsigned long last_preq; + struct mesh_rmc *rmc; + spinlock_t mesh_preq_queue_lock; + struct mesh_preq_queue preq_queue; + int preq_queue_len; + struct mesh_stats mshstats; + struct mesh_config mshcfg; + atomic_t estab_plinks; + u32 mesh_seqnum; + bool accepting_plinks; + int num_gates; + struct beacon_data __rcu *beacon; + const u8 *ie; + u8 ie_len; + enum { + IEEE80211_MESH_SEC_NONE = 0x0, + IEEE80211_MESH_SEC_AUTHED = 0x1, + IEEE80211_MESH_SEC_SECURED = 0x2, + } security; + bool user_mpm; + /* Extensible Synchronization Framework */ + const struct ieee80211_mesh_sync_ops *sync_ops; + s64 sync_offset_clockdrift_max; + spinlock_t sync_offset_lock; + /* mesh power save */ + enum nl80211_mesh_power_mode nonpeer_pm; + int ps_peers_light_sleep; + int ps_peers_deep_sleep; + struct ps_data ps; + /* Channel Switching Support */ + struct mesh_csa_settings __rcu *csa; + enum { + IEEE80211_MESH_CSA_ROLE_NONE, + IEEE80211_MESH_CSA_ROLE_INIT, + IEEE80211_MESH_CSA_ROLE_REPEATER, + } csa_role; + u8 chsw_ttl; + u16 pre_value; + + /* offset from skb->data while building IE */ + int meshconf_offset; + + struct mesh_table mesh_paths; + struct mesh_table mpp_paths; /* Store paths for MPP&MAP */ + int mesh_paths_generation; + int mpp_paths_generation; +}; + +#ifdef CONFIG_MAC80211_MESH +#define IEEE80211_IFSTA_MESH_CTR_INC(msh, name) \ + do { (msh)->mshstats.name++; } while (0) +#else +#define IEEE80211_IFSTA_MESH_CTR_INC(msh, name) \ + do { } while (0) +#endif + +/** + * enum ieee80211_sub_if_data_flags - virtual interface flags + * + * @IEEE80211_SDATA_ALLMULTI: interface wants all multicast packets + * @IEEE80211_SDATA_DONT_BRIDGE_PACKETS: bridge packets between + * associated stations and deliver multicast frames both + * back to wireless 
media and to the local net stack. + * @IEEE80211_SDATA_DISCONNECT_RESUME: Disconnect after resume. + * @IEEE80211_SDATA_IN_DRIVER: indicates interface was added to driver + * @IEEE80211_SDATA_DISCONNECT_HW_RESTART: Disconnect after hardware restart + * recovery + */ +enum ieee80211_sub_if_data_flags { + IEEE80211_SDATA_ALLMULTI = BIT(0), + IEEE80211_SDATA_DONT_BRIDGE_PACKETS = BIT(3), + IEEE80211_SDATA_DISCONNECT_RESUME = BIT(4), + IEEE80211_SDATA_IN_DRIVER = BIT(5), + IEEE80211_SDATA_DISCONNECT_HW_RESTART = BIT(6), +}; + +/** + * enum ieee80211_sdata_state_bits - virtual interface state bits + * @SDATA_STATE_RUNNING: virtual interface is up & running; this + * mirrors netif_running() but is separate for interface type + * change handling while the interface is up + * @SDATA_STATE_OFFCHANNEL: This interface is currently in offchannel + * mode, so queues are stopped + * @SDATA_STATE_OFFCHANNEL_BEACON_STOPPED: Beaconing was stopped due + * to offchannel, reset when offchannel returns + */ +enum ieee80211_sdata_state_bits { + SDATA_STATE_RUNNING, + SDATA_STATE_OFFCHANNEL, + SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, +}; + +/** + * enum ieee80211_chanctx_mode - channel context configuration mode + * + * @IEEE80211_CHANCTX_SHARED: channel context may be used by + * multiple interfaces + * @IEEE80211_CHANCTX_EXCLUSIVE: channel context can be used + * only by a single interface. This can be used for example for + * non-fixed channel IBSS. + */ +enum ieee80211_chanctx_mode { + IEEE80211_CHANCTX_SHARED, + IEEE80211_CHANCTX_EXCLUSIVE +}; + +/** + * enum ieee80211_chanctx_replace_state - channel context replacement state + * + * This is used for channel context in-place reservations that require channel + * context switch/swap. + * + * @IEEE80211_CHANCTX_REPLACE_NONE: no replacement is taking place + * @IEEE80211_CHANCTX_WILL_BE_REPLACED: this channel context will be replaced + * by a (not yet registered) channel context pointed by %replace_ctx. + * @IEEE80211_CHANCTX_REPLACES_OTHER: this (not yet registered) channel context + * replaces an existing channel context pointed to by %replace_ctx. + */ +enum ieee80211_chanctx_replace_state { + IEEE80211_CHANCTX_REPLACE_NONE, + IEEE80211_CHANCTX_WILL_BE_REPLACED, + IEEE80211_CHANCTX_REPLACES_OTHER, +}; + +struct ieee80211_chanctx { + struct list_head list; + struct rcu_head rcu_head; + + struct list_head assigned_links; + struct list_head reserved_links; + + enum ieee80211_chanctx_replace_state replace_state; + struct ieee80211_chanctx *replace_ctx; + + enum ieee80211_chanctx_mode mode; + bool driver_present; + + struct ieee80211_chanctx_conf conf; +}; + +struct mac80211_qos_map { + struct cfg80211_qos_map qos_map; + struct rcu_head rcu_head; +}; + +enum txq_info_flags { + IEEE80211_TXQ_STOP, + IEEE80211_TXQ_AMPDU, + IEEE80211_TXQ_NO_AMSDU, + IEEE80211_TXQ_DIRTY, +}; + +/** + * struct txq_info - per tid queue + * + * @tin: contains packets split into multiple flows + * @def_flow: used as a fallback flow when a packet destined to @tin hashes to + * a fq_flow which is already owned by a different tin + * @def_cvars: codel vars for @def_flow + * @frags: used to keep fragments created after dequeue + * @schedule_order: used with ieee80211_local->active_txqs + * @schedule_round: counter to prevent infinite loops on TXQ scheduling + */ +struct txq_info { + struct fq_tin tin; + struct codel_vars def_cvars; + struct codel_stats cstats; + + u16 schedule_round; + struct list_head schedule_order; + + struct sk_buff_head frags; + + unsigned long flags; + + /* keep last! 
*/ + struct ieee80211_txq txq; +}; + +struct ieee80211_if_mntr { + u32 flags; + u8 mu_follow_addr[ETH_ALEN] __aligned(2); + + struct list_head list; +}; + +/** + * struct ieee80211_if_nan - NAN state + * + * @conf: current NAN configuration + * @func_ids: a bitmap of available instance_id's + */ +struct ieee80211_if_nan { + struct cfg80211_nan_conf conf; + + /* protects function_inst_ids */ + spinlock_t func_lock; + struct idr function_inst_ids; +}; + +struct ieee80211_link_data_managed { + u8 bssid[ETH_ALEN] __aligned(2); + + u8 dtim_period; + enum ieee80211_smps_mode req_smps, /* requested smps mode */ + driver_smps_mode; /* smps mode request */ + + ieee80211_conn_flags_t conn_flags; + + s16 p2p_noa_index; + + bool tdls_chan_switch_prohibited; + + bool have_beacon; + bool tracking_signal_avg; + bool disable_wmm_tracking; + bool operating_11g_mode; + + bool csa_waiting_bcn; + bool csa_ignored_same_chan; + struct timer_list chswitch_timer; + struct work_struct chswitch_work; + + struct work_struct request_smps_work; + bool beacon_crc_valid; + u32 beacon_crc; + struct ewma_beacon_signal ave_beacon_signal; + int last_ave_beacon_signal; + + /* + * Number of Beacon frames used in ave_beacon_signal. This can be used + * to avoid generating less reliable cqm events that would be based + * only on couple of received frames. + */ + unsigned int count_beacon_signal; + + /* Number of times beacon loss was invoked. */ + unsigned int beacon_loss_count; + + /* + * Last Beacon frame signal strength average (ave_beacon_signal / 16) + * that triggered a cqm event. 0 indicates that no event has been + * generated for the current association. + */ + int last_cqm_event_signal; + + int wmm_last_param_set; + int mu_edca_last_param_set; + + struct cfg80211_bss *bss; +}; + +struct ieee80211_link_data_ap { + struct beacon_data __rcu *beacon; + struct probe_resp __rcu *probe_resp; + struct fils_discovery_data __rcu *fils_discovery; + struct unsol_bcast_probe_resp_data __rcu *unsol_bcast_probe_resp; + + /* to be used after channel switch. 
*/ + struct cfg80211_beacon_data *next_beacon; +}; + +struct ieee80211_link_data { + struct ieee80211_sub_if_data *sdata; + unsigned int link_id; + + struct list_head assigned_chanctx_list; /* protected by chanctx_mtx */ + struct list_head reserved_chanctx_list; /* protected by chanctx_mtx */ + + /* multicast keys only */ + struct ieee80211_key __rcu *gtk[NUM_DEFAULT_KEYS + + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS]; + struct ieee80211_key __rcu *default_multicast_key; + struct ieee80211_key __rcu *default_mgmt_key; + struct ieee80211_key __rcu *default_beacon_key; + + struct work_struct csa_finalize_work; + bool csa_block_tx; /* write-protected by sdata_lock and local->mtx */ + + bool operating_11g_mode; + + struct cfg80211_chan_def csa_chandef; + + struct work_struct color_change_finalize_work; + struct delayed_work color_collision_detect_work; + u64 color_bitmap; + + /* context reservation -- protected with chanctx_mtx */ + struct ieee80211_chanctx *reserved_chanctx; + struct cfg80211_chan_def reserved_chandef; + bool reserved_radar_required; + bool reserved_ready; + + u8 needed_rx_chains; + enum ieee80211_smps_mode smps_mode; + + int user_power_level; /* in dBm */ + int ap_power_level; /* in dBm */ + + bool radar_required; + struct delayed_work dfs_cac_timer_work; + + union { + struct ieee80211_link_data_managed mgd; + struct ieee80211_link_data_ap ap; + } u; + + struct ieee80211_tx_queue_params tx_conf[IEEE80211_NUM_ACS]; + + struct ieee80211_bss_conf *conf; +}; + +struct ieee80211_sub_if_data { + struct list_head list; + + struct wireless_dev wdev; + + /* keys */ + struct list_head key_list; + + /* count for keys needing tailroom space allocation */ + int crypto_tx_tailroom_needed_cnt; + int crypto_tx_tailroom_pending_dec; + struct delayed_work dec_tailroom_needed_wk; + + struct net_device *dev; + struct ieee80211_local *local; + + unsigned int flags; + + unsigned long state; + + char name[IFNAMSIZ]; + + struct ieee80211_fragment_cache frags; + + /* TID bitmap for NoAck policy */ + u16 noack_map; + + /* bit field of ACM bits (BIT(802.1D tag)) */ + u8 wmm_acm; + + struct ieee80211_key __rcu *keys[NUM_DEFAULT_KEYS]; + struct ieee80211_key __rcu *default_unicast_key; + + u16 sequence_number; + u16 mld_mcast_seq; + __be16 control_port_protocol; + bool control_port_no_encrypt; + bool control_port_no_preauth; + bool control_port_over_nl80211; + + atomic_t num_tx_queued; + struct mac80211_qos_map __rcu *qos_map; + + /* used to reconfigure hardware SM PS */ + struct work_struct recalc_smps; + + struct work_struct work; + struct sk_buff_head skb_queue; + struct sk_buff_head status_queue; + + /* + * AP this belongs to: self in AP mode and + * corresponding AP in VLAN mode, NULL for + * all others (might be needed later in IBSS) + */ + struct ieee80211_if_ap *bss; + + /* bitmap of allowed (non-MCS) rate indexes for rate control */ + u32 rc_rateidx_mask[NUM_NL80211_BANDS]; + + bool rc_has_mcs_mask[NUM_NL80211_BANDS]; + u8 rc_rateidx_mcs_mask[NUM_NL80211_BANDS][IEEE80211_HT_MCS_MASK_LEN]; + + bool rc_has_vht_mcs_mask[NUM_NL80211_BANDS]; + u16 rc_rateidx_vht_mcs_mask[NUM_NL80211_BANDS][NL80211_VHT_NSS_MAX]; + + /* Beacon frame (non-MCS) rate (as a bitmap) */ + u32 beacon_rateidx_mask[NUM_NL80211_BANDS]; + bool beacon_rate_set; + + union { + struct ieee80211_if_ap ap; + struct ieee80211_if_vlan vlan; + struct ieee80211_if_managed mgd; + struct ieee80211_if_ibss ibss; + struct ieee80211_if_mesh mesh; + struct ieee80211_if_ocb ocb; + struct ieee80211_if_mntr mntr; + struct ieee80211_if_nan 
nan; + } u; + + struct ieee80211_link_data deflink; + struct ieee80211_link_data __rcu *link[IEEE80211_MLD_MAX_NUM_LINKS]; + + /* for ieee80211_set_active_links_async() */ + struct work_struct activate_links_work; + u16 desired_active_links; + +#ifdef CONFIG_MAC80211_DEBUGFS + struct { + struct dentry *subdir_stations; + struct dentry *default_unicast_key; + struct dentry *default_multicast_key; + struct dentry *default_mgmt_key; + struct dentry *default_beacon_key; + } debugfs; +#endif + + /* must be last, dynamically sized area in this! */ + struct ieee80211_vif vif; +}; + +static inline +struct ieee80211_sub_if_data *vif_to_sdata(struct ieee80211_vif *p) +{ + return container_of(p, struct ieee80211_sub_if_data, vif); +} + +static inline void sdata_lock(struct ieee80211_sub_if_data *sdata) + __acquires(&sdata->wdev.mtx) +{ + mutex_lock(&sdata->wdev.mtx); + __acquire(&sdata->wdev.mtx); +} + +static inline void sdata_unlock(struct ieee80211_sub_if_data *sdata) + __releases(&sdata->wdev.mtx) +{ + mutex_unlock(&sdata->wdev.mtx); + __release(&sdata->wdev.mtx); +} + +#define sdata_dereference(p, sdata) \ + rcu_dereference_protected(p, lockdep_is_held(&sdata->wdev.mtx)) + +static inline void +sdata_assert_lock(struct ieee80211_sub_if_data *sdata) +{ + lockdep_assert_held(&sdata->wdev.mtx); +} + +static inline int +ieee80211_chanwidth_get_shift(enum nl80211_chan_width width) +{ + switch (width) { + case NL80211_CHAN_WIDTH_5: + return 2; + case NL80211_CHAN_WIDTH_10: + return 1; + default: + return 0; + } +} + +static inline int +ieee80211_chandef_get_shift(struct cfg80211_chan_def *chandef) +{ + return ieee80211_chanwidth_get_shift(chandef->width); +} + +static inline int +ieee80211_vif_get_shift(struct ieee80211_vif *vif) +{ + struct ieee80211_chanctx_conf *chanctx_conf; + int shift = 0; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(vif->bss_conf.chanctx_conf); + if (chanctx_conf) + shift = ieee80211_chandef_get_shift(&chanctx_conf->def); + rcu_read_unlock(); + + return shift; +} + +static inline int +ieee80211_get_mbssid_beacon_len(struct cfg80211_mbssid_elems *elems) +{ + int i, len = 0; + + if (!elems) + return 0; + + for (i = 0; i < elems->cnt; i++) + len += elems->elem[i].len; + + return len; +} + +enum { + IEEE80211_RX_MSG = 1, + IEEE80211_TX_STATUS_MSG = 2, +}; + +enum queue_stop_reason { + IEEE80211_QUEUE_STOP_REASON_DRIVER, + IEEE80211_QUEUE_STOP_REASON_PS, + IEEE80211_QUEUE_STOP_REASON_CSA, + IEEE80211_QUEUE_STOP_REASON_AGGREGATION, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + IEEE80211_QUEUE_STOP_REASON_SKB_ADD, + IEEE80211_QUEUE_STOP_REASON_OFFCHANNEL, + IEEE80211_QUEUE_STOP_REASON_FLUSH, + IEEE80211_QUEUE_STOP_REASON_TDLS_TEARDOWN, + IEEE80211_QUEUE_STOP_REASON_RESERVE_TID, + IEEE80211_QUEUE_STOP_REASON_IFTYPE_CHANGE, + + IEEE80211_QUEUE_STOP_REASONS, +}; + +#ifdef CONFIG_MAC80211_LEDS +struct tpt_led_trigger { + char name[32]; + const struct ieee80211_tpt_blink *blink_table; + unsigned int blink_table_len; + struct timer_list timer; + struct ieee80211_local *local; + unsigned long prev_traffic; + unsigned long tx_bytes, rx_bytes; + unsigned int active, want; + bool running; +}; +#endif + +/** + * mac80211 scan flags - currently active scan mode + * + * @SCAN_SW_SCANNING: We're currently in the process of scanning but may as + * well be on the operating channel + * @SCAN_HW_SCANNING: The hardware is scanning for us, we have no way to + * determine if we are on the operating channel or not + * @SCAN_ONCHANNEL_SCANNING: Do a software scan on only the current operating + * 
channel. This should not interrupt normal traffic. + * @SCAN_COMPLETED: Set for our scan work function when the driver reported + * that the scan completed. + * @SCAN_ABORTED: Set for our scan work function when the driver reported + * a scan complete for an aborted scan. + * @SCAN_HW_CANCELLED: Set for our scan work function when the scan is being + * cancelled. + * @SCAN_BEACON_WAIT: Set whenever we're passive scanning because of radar/no-IR + * and could send a probe request after receiving a beacon. + * @SCAN_BEACON_DONE: Beacon received, we can now send a probe request + */ +enum { + SCAN_SW_SCANNING, + SCAN_HW_SCANNING, + SCAN_ONCHANNEL_SCANNING, + SCAN_COMPLETED, + SCAN_ABORTED, + SCAN_HW_CANCELLED, + SCAN_BEACON_WAIT, + SCAN_BEACON_DONE, +}; + +/** + * enum mac80211_scan_state - scan state machine states + * + * @SCAN_DECISION: Main entry point to the scan state machine, this state + * determines if we should keep on scanning or switch back to the + * operating channel + * @SCAN_SET_CHANNEL: Set the next channel to be scanned + * @SCAN_SEND_PROBE: Send probe requests and wait for probe responses + * @SCAN_SUSPEND: Suspend the scan and go back to operating channel to + * send out data + * @SCAN_RESUME: Resume the scan and scan the next channel + * @SCAN_ABORT: Abort the scan and go back to operating channel + */ +enum mac80211_scan_state { + SCAN_DECISION, + SCAN_SET_CHANNEL, + SCAN_SEND_PROBE, + SCAN_SUSPEND, + SCAN_RESUME, + SCAN_ABORT, +}; + +DECLARE_STATIC_KEY_FALSE(aql_disable); + +struct ieee80211_local { + /* embed the driver visible part. + * don't cast (use the static inlines below), but we keep + * it first anyway so they become a no-op */ + struct ieee80211_hw hw; + + struct fq fq; + struct codel_vars *cvars; + struct codel_params cparams; + + /* protects active_txqs and txqi->schedule_order */ + spinlock_t active_txq_lock[IEEE80211_NUM_ACS]; + struct list_head active_txqs[IEEE80211_NUM_ACS]; + u16 schedule_round[IEEE80211_NUM_ACS]; + + u16 airtime_flags; + u32 aql_txq_limit_low[IEEE80211_NUM_ACS]; + u32 aql_txq_limit_high[IEEE80211_NUM_ACS]; + u32 aql_threshold; + atomic_t aql_total_pending_airtime; + atomic_t aql_ac_pending_airtime[IEEE80211_NUM_ACS]; + + const struct ieee80211_ops *ops; + + /* + * private workqueue to mac80211. mac80211 makes this accessible + * via ieee80211_queue_work() + */ + struct workqueue_struct *workqueue; + + unsigned long queue_stop_reasons[IEEE80211_MAX_QUEUES]; + int q_stop_reasons[IEEE80211_MAX_QUEUES][IEEE80211_QUEUE_STOP_REASONS]; + /* also used to protect ampdu_ac_queue and amdpu_ac_stop_refcnt */ + spinlock_t queue_stop_reason_lock; + + int open_count; + int monitors, cooked_mntrs; + /* number of interfaces with corresponding FIF_ flags */ + int fif_fcsfail, fif_plcpfail, fif_control, fif_other_bss, fif_pspoll, + fif_probe_req; + bool probe_req_reg; + bool rx_mcast_action_reg; + unsigned int filter_flags; /* FIF_* */ + + bool wiphy_ciphers_allocated; + + bool use_chanctx; + + /* protects the aggregated multicast list and filter calls */ + spinlock_t filter_lock; + + /* used for uploading changed mc list */ + struct work_struct reconfig_filter; + + /* aggregated multicast list */ + struct netdev_hw_addr_list mc_list; + + bool tim_in_locked_section; /* see ieee80211_beacon_get() */ + + /* + * suspended is true if we finished all the suspend _and_ we have + * not yet come up from resume. This is to be used by mac80211 + * to ensure driver sanity during suspend and mac80211's own + * sanity. It can eventually be used for WoW as well. 
+ */ + bool suspended; + + /* suspending is true during the whole suspend process */ + bool suspending; + + /* + * Resuming is true while suspended, but when we're reprogramming the + * hardware -- at that time it's allowed to use ieee80211_queue_work() + * again even though some other parts of the stack are still suspended + * and we still drop received frames to avoid waking the stack. + */ + bool resuming; + + /* + * quiescing is true during the suspend process _only_ to + * ease timer cancelling etc. + */ + bool quiescing; + + /* device is started */ + bool started; + + /* device is during a HW reconfig */ + bool in_reconfig; + + /* wowlan is enabled -- don't reconfig on resume */ + bool wowlan; + + struct wiphy_work radar_detected_work; + + /* number of RX chains the hardware has */ + u8 rx_chains; + + /* bitmap of which sbands were copied */ + u8 sband_allocated; + + int tx_headroom; /* required headroom for hardware/radiotap */ + + /* Tasklet and skb queue to process calls from IRQ mode. All frames + * added to skb_queue will be processed, but frames in + * skb_queue_unreliable may be dropped if the total length of these + * queues increases over the limit. */ +#define IEEE80211_IRQSAFE_QUEUE_LIMIT 128 + struct tasklet_struct tasklet; + struct sk_buff_head skb_queue; + struct sk_buff_head skb_queue_unreliable; + + spinlock_t rx_path_lock; + + /* Station data */ + /* + * The mutex only protects the list, hash table and + * counter, reads are done with RCU. + */ + struct mutex sta_mtx; + spinlock_t tim_lock; + unsigned long num_sta; + struct list_head sta_list; + struct rhltable sta_hash; + struct rhltable link_sta_hash; + struct timer_list sta_cleanup; + int sta_generation; + + struct sk_buff_head pending[IEEE80211_MAX_QUEUES]; + struct tasklet_struct tx_pending_tasklet; + struct tasklet_struct wake_txqs_tasklet; + + atomic_t agg_queue_stop[IEEE80211_MAX_QUEUES]; + + /* number of interfaces with allmulti RX */ + atomic_t iff_allmultis; + + struct rate_control_ref *rate_ctrl; + + struct arc4_ctx wep_tx_ctx; + struct arc4_ctx wep_rx_ctx; + u32 wep_iv; + + /* see iface.c */ + struct list_head interfaces; + struct list_head mon_list; /* only that are IFF_UP && !cooked */ + struct mutex iflist_mtx; + + /* + * Key mutex, protects sdata's key_list and sta_info's + * key pointers and ptk_idx (write access, they're RCU.) 
+ */ + struct mutex key_mtx; + + /* mutex for scan and work locking */ + struct mutex mtx; + + /* Scanning and BSS list */ + unsigned long scanning; + struct cfg80211_ssid scan_ssid; + struct cfg80211_scan_request *int_scan_req; + struct cfg80211_scan_request __rcu *scan_req; + struct ieee80211_scan_request *hw_scan_req; + struct cfg80211_chan_def scan_chandef; + enum nl80211_band hw_scan_band; + int scan_channel_idx; + int scan_ies_len; + int hw_scan_ies_bufsize; + struct cfg80211_scan_info scan_info; + + struct wiphy_work sched_scan_stopped_work; + struct ieee80211_sub_if_data __rcu *sched_scan_sdata; + struct cfg80211_sched_scan_request __rcu *sched_scan_req; + u8 scan_addr[ETH_ALEN]; + + unsigned long leave_oper_channel_time; + enum mac80211_scan_state next_scan_state; + struct wiphy_delayed_work scan_work; + struct ieee80211_sub_if_data __rcu *scan_sdata; + /* For backward compatibility only -- do not use */ + struct cfg80211_chan_def _oper_chandef; + + /* Temporary remain-on-channel for off-channel operations */ + struct ieee80211_channel *tmp_channel; + + /* channel contexts */ + struct list_head chanctx_list; + struct mutex chanctx_mtx; + +#ifdef CONFIG_MAC80211_LEDS + struct led_trigger tx_led, rx_led, assoc_led, radio_led; + struct led_trigger tpt_led; + atomic_t tx_led_active, rx_led_active, assoc_led_active; + atomic_t radio_led_active, tpt_led_active; + struct tpt_led_trigger *tpt_led_trigger; +#endif + +#ifdef CONFIG_MAC80211_DEBUG_COUNTERS + /* SNMP counters */ + /* dot11CountersTable */ + u32 dot11TransmittedFragmentCount; + u32 dot11MulticastTransmittedFrameCount; + u32 dot11FailedCount; + u32 dot11RetryCount; + u32 dot11MultipleRetryCount; + u32 dot11FrameDuplicateCount; + u32 dot11ReceivedFragmentCount; + u32 dot11MulticastReceivedFrameCount; + u32 dot11TransmittedFrameCount; + + /* TX/RX handler statistics */ + unsigned int tx_handlers_drop; + unsigned int tx_handlers_queued; + unsigned int tx_handlers_drop_wep; + unsigned int tx_handlers_drop_not_assoc; + unsigned int tx_handlers_drop_unauth_port; + unsigned int rx_handlers_drop; + unsigned int rx_handlers_queued; + unsigned int rx_handlers_drop_nullfunc; + unsigned int rx_handlers_drop_defrag; + unsigned int tx_expand_skb_head; + unsigned int tx_expand_skb_head_cloned; + unsigned int rx_expand_skb_head_defrag; + unsigned int rx_handlers_fragments; + unsigned int tx_status_drop; +#define I802_DEBUG_INC(c) (c)++ +#else /* CONFIG_MAC80211_DEBUG_COUNTERS */ +#define I802_DEBUG_INC(c) do { } while (0) +#endif /* CONFIG_MAC80211_DEBUG_COUNTERS */ + + + int total_ps_buffered; /* total number of all buffered unicast and + * multicast packets for power saving stations + */ + + bool pspolling; + /* + * PS can only be enabled when we have exactly one managed + * interface (and monitors) in PS, this then points there. + */ + struct ieee80211_sub_if_data *ps_sdata; + struct work_struct dynamic_ps_enable_work; + struct work_struct dynamic_ps_disable_work; + struct timer_list dynamic_ps_timer; + struct notifier_block ifa_notifier; + struct notifier_block ifa6_notifier; + + /* + * The dynamic ps timeout configured from user space via WEXT - + * this will override whatever chosen by mac80211 internally. 
+ */ + int dynamic_ps_forced_timeout; + + int user_power_level; /* in dBm, for all interfaces */ + + enum ieee80211_smps_mode smps_mode; + + struct work_struct restart_work; + +#ifdef CONFIG_MAC80211_DEBUGFS + struct local_debugfsdentries { + struct dentry *rcdir; + struct dentry *keys; + } debugfs; + bool force_tx_status; +#endif + + /* + * Remain-on-channel support + */ + struct wiphy_delayed_work roc_work; + struct list_head roc_list; + struct wiphy_work hw_roc_start, hw_roc_done; + unsigned long hw_roc_start_time; + u64 roc_cookie_counter; + + struct idr ack_status_frames; + spinlock_t ack_status_lock; + + struct ieee80211_sub_if_data __rcu *p2p_sdata; + + /* virtual monitor interface */ + struct ieee80211_sub_if_data __rcu *monitor_sdata; + struct cfg80211_chan_def monitor_chandef; + + /* extended capabilities provided by mac80211 */ + u8 ext_capa[8]; +}; + +static inline struct ieee80211_sub_if_data * +IEEE80211_DEV_TO_SUB_IF(const struct net_device *dev) +{ + return netdev_priv(dev); +} + +static inline struct ieee80211_sub_if_data * +IEEE80211_WDEV_TO_SUB_IF(struct wireless_dev *wdev) +{ + return container_of(wdev, struct ieee80211_sub_if_data, wdev); +} + +static inline struct ieee80211_supported_band * +ieee80211_get_sband(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + enum nl80211_band band; + + WARN_ON(sdata->vif.valid_links); + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + + if (!chanctx_conf) { + rcu_read_unlock(); + return NULL; + } + + band = chanctx_conf->def.chan->band; + rcu_read_unlock(); + + return local->hw.wiphy->bands[band]; +} + +static inline struct ieee80211_supported_band * +ieee80211_get_link_sband(struct ieee80211_link_data *link) +{ + struct ieee80211_local *local = link->sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + enum nl80211_band band; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(link->conf->chanctx_conf); + if (!chanctx_conf) { + rcu_read_unlock(); + return NULL; + } + + band = chanctx_conf->def.chan->band; + rcu_read_unlock(); + + return local->hw.wiphy->bands[band]; +} + +/* this struct holds the value parsing from channel switch IE */ +struct ieee80211_csa_ie { + struct cfg80211_chan_def chandef; + u8 mode; + u8 count; + u8 ttl; + u16 pre_value; + u16 reason_code; + u32 max_switch_time; +}; + +/* Parsed Information Elements */ +struct ieee802_11_elems { + const u8 *ie_start; + size_t total_len; + u32 crc; + + /* pointers to IEs */ + const struct ieee80211_tdls_lnkie *lnk_id; + const struct ieee80211_ch_switch_timing *ch_sw_timing; + const u8 *ext_capab; + const u8 *ssid; + const u8 *supp_rates; + const u8 *ds_params; + const struct ieee80211_tim_ie *tim; + const u8 *rsn; + const u8 *rsnx; + const u8 *erp_info; + const u8 *ext_supp_rates; + const u8 *wmm_info; + const u8 *wmm_param; + const struct ieee80211_ht_cap *ht_cap_elem; + const struct ieee80211_ht_operation *ht_operation; + const struct ieee80211_vht_cap *vht_cap_elem; + const struct ieee80211_vht_operation *vht_operation; + const struct ieee80211_meshconf_ie *mesh_config; + const u8 *he_cap; + const struct ieee80211_he_operation *he_operation; + const struct ieee80211_he_spr *he_spr; + const struct ieee80211_mu_edca_param_set *mu_edca_param_set; + const struct ieee80211_he_6ghz_capa *he_6ghz_capa; + const struct ieee80211_tx_pwr_env *tx_pwr_env[IEEE80211_TPE_MAX_IE_COUNT]; + const u8 *uora_element; + const u8 *mesh_id; + const u8 
*peering; + const __le16 *awake_window; + const u8 *preq; + const u8 *prep; + const u8 *perr; + const struct ieee80211_rann_ie *rann; + const struct ieee80211_channel_sw_ie *ch_switch_ie; + const struct ieee80211_ext_chansw_ie *ext_chansw_ie; + const struct ieee80211_wide_bw_chansw_ie *wide_bw_chansw_ie; + const u8 *max_channel_switch_time; + const u8 *country_elem; + const u8 *pwr_constr_elem; + const u8 *cisco_dtpc_elem; + const struct ieee80211_timeout_interval_ie *timeout_int; + const u8 *opmode_notif; + const struct ieee80211_sec_chan_offs_ie *sec_chan_offs; + struct ieee80211_mesh_chansw_params_ie *mesh_chansw_params_ie; + const struct ieee80211_bss_max_idle_period_ie *max_idle_period_ie; + const struct ieee80211_multiple_bssid_configuration *mbssid_config_ie; + const struct ieee80211_bssid_index *bssid_index; + u8 max_bssid_indicator; + u8 dtim_count; + u8 dtim_period; + const struct ieee80211_addba_ext_ie *addba_ext_ie; + const struct ieee80211_s1g_cap *s1g_capab; + const struct ieee80211_s1g_oper_ie *s1g_oper; + const struct ieee80211_s1g_bcn_compat_ie *s1g_bcn_compat; + const struct ieee80211_aid_response_ie *aid_resp; + const struct ieee80211_eht_cap_elem *eht_cap; + const struct ieee80211_eht_operation *eht_operation; + const struct ieee80211_multi_link_elem *multi_link; + + /* length of them, respectively */ + u8 ext_capab_len; + u8 ssid_len; + u8 supp_rates_len; + u8 tim_len; + u8 rsn_len; + u8 rsnx_len; + u8 ext_supp_rates_len; + u8 wmm_info_len; + u8 wmm_param_len; + u8 he_cap_len; + u8 mesh_id_len; + u8 peering_len; + u8 preq_len; + u8 prep_len; + u8 perr_len; + u8 country_elem_len; + u8 bssid_index_len; + u8 tx_pwr_env_len[IEEE80211_TPE_MAX_IE_COUNT]; + u8 tx_pwr_env_num; + u8 eht_cap_len; + + /* whether a parse error occurred while retrieving these elements */ + bool parse_error; + + /* + * scratch buffer that can be used for various element parsing related + * tasks, e.g., element de-fragmentation etc. + */ + size_t scratch_len; + u8 *scratch_pos; + u8 scratch[]; +}; + +static inline struct ieee80211_local *hw_to_local( + struct ieee80211_hw *hw) +{ + return container_of(hw, struct ieee80211_local, hw); +} + +static inline struct txq_info *to_txq_info(struct ieee80211_txq *txq) +{ + return container_of(txq, struct txq_info, txq); +} + +static inline bool txq_has_queue(struct ieee80211_txq *txq) +{ + struct txq_info *txqi = to_txq_info(txq); + + return !(skb_queue_empty(&txqi->frags) && !txqi->tin.backlog_packets); +} + +static inline bool +ieee80211_have_rx_timestamp(struct ieee80211_rx_status *status) +{ + WARN_ON_ONCE(status->flag & RX_FLAG_MACTIME_START && + status->flag & RX_FLAG_MACTIME_END); + return !!(status->flag & (RX_FLAG_MACTIME_START | RX_FLAG_MACTIME_END | + RX_FLAG_MACTIME_PLCP_START)); +} + +void ieee80211_vif_inc_num_mcast(struct ieee80211_sub_if_data *sdata); +void ieee80211_vif_dec_num_mcast(struct ieee80211_sub_if_data *sdata); + +/* This function returns the number of multicast stations connected to this + * interface. It returns -1 if that number is not tracked, that is for netdevs + * not in AP or AP_VLAN mode or when using 4addr. 
+ */ +static inline int +ieee80211_vif_get_num_mcast_if(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.type == NL80211_IFTYPE_AP) + return atomic_read(&sdata->u.ap.num_mcast_sta); + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN && !sdata->u.vlan.sta) + return atomic_read(&sdata->u.vlan.num_mcast_sta); + return -1; +} + +u64 ieee80211_calculate_rx_timestamp(struct ieee80211_local *local, + struct ieee80211_rx_status *status, + unsigned int mpdu_len, + unsigned int mpdu_offset); +int ieee80211_hw_config(struct ieee80211_local *local, u32 changed); +void ieee80211_tx_set_protected(struct ieee80211_tx_data *tx); +void ieee80211_bss_info_change_notify(struct ieee80211_sub_if_data *sdata, + u64 changed); +void ieee80211_vif_cfg_change_notify(struct ieee80211_sub_if_data *sdata, + u64 changed); +void ieee80211_link_info_change_notify(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + u64 changed); +void ieee80211_configure_filter(struct ieee80211_local *local); +u32 ieee80211_reset_erp_info(struct ieee80211_sub_if_data *sdata); + +u64 ieee80211_mgmt_tx_cookie(struct ieee80211_local *local); +int ieee80211_attach_ack_skb(struct ieee80211_local *local, struct sk_buff *skb, + u64 *cookie, gfp_t gfp); + +void ieee80211_check_fast_rx(struct sta_info *sta); +void __ieee80211_check_fast_rx_iface(struct ieee80211_sub_if_data *sdata); +void ieee80211_check_fast_rx_iface(struct ieee80211_sub_if_data *sdata); +void ieee80211_clear_fast_rx(struct sta_info *sta); + +bool ieee80211_is_our_addr(struct ieee80211_sub_if_data *sdata, + const u8 *addr, int *out_link_id); + +/* STA code */ +void ieee80211_sta_setup_sdata(struct ieee80211_sub_if_data *sdata); +int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, + struct cfg80211_auth_request *req); +int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata, + struct cfg80211_assoc_request *req); +int ieee80211_mgd_deauth(struct ieee80211_sub_if_data *sdata, + struct cfg80211_deauth_request *req); +int ieee80211_mgd_disassoc(struct ieee80211_sub_if_data *sdata, + struct cfg80211_disassoc_request *req); +void ieee80211_send_pspoll(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +void ieee80211_recalc_ps(struct ieee80211_local *local); +void ieee80211_recalc_ps_vif(struct ieee80211_sub_if_data *sdata); +int ieee80211_set_arp_filter(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_work(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void ieee80211_sta_rx_queued_ext(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void ieee80211_sta_reset_beacon_monitor(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_reset_conn_monitor(struct ieee80211_sub_if_data *sdata); +void ieee80211_mgd_stop(struct ieee80211_sub_if_data *sdata); +void ieee80211_mgd_conn_tx_status(struct ieee80211_sub_if_data *sdata, + __le16 fc, bool acked); +void ieee80211_mgd_quiesce(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_restart(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_handle_tspec_ac_params(struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_connection_lost(struct ieee80211_sub_if_data *sdata, + u8 reason, bool tx); +void ieee80211_mgd_setup_link(struct ieee80211_link_data *link); +void ieee80211_mgd_stop_link(struct ieee80211_link_data *link); +void ieee80211_mgd_set_link_qos_params(struct ieee80211_link_data *link); + +/* IBSS code */ +void ieee80211_ibss_notify_scan_completed(struct 
ieee80211_local *local); +void ieee80211_ibss_setup_sdata(struct ieee80211_sub_if_data *sdata); +void ieee80211_ibss_rx_no_sta(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const u8 *addr, u32 supp_rates); +int ieee80211_ibss_join(struct ieee80211_sub_if_data *sdata, + struct cfg80211_ibss_params *params); +int ieee80211_ibss_leave(struct ieee80211_sub_if_data *sdata); +void ieee80211_ibss_work(struct ieee80211_sub_if_data *sdata); +void ieee80211_ibss_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int ieee80211_ibss_csa_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *csa_settings); +int ieee80211_ibss_finish_csa(struct ieee80211_sub_if_data *sdata); +void ieee80211_ibss_stop(struct ieee80211_sub_if_data *sdata); + +/* OCB code */ +void ieee80211_ocb_work(struct ieee80211_sub_if_data *sdata); +void ieee80211_ocb_rx_no_sta(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const u8 *addr, u32 supp_rates); +void ieee80211_ocb_setup_sdata(struct ieee80211_sub_if_data *sdata); +int ieee80211_ocb_join(struct ieee80211_sub_if_data *sdata, + struct ocb_setup *setup); +int ieee80211_ocb_leave(struct ieee80211_sub_if_data *sdata); + +/* mesh code */ +void ieee80211_mesh_work(struct ieee80211_sub_if_data *sdata); +void ieee80211_mesh_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int ieee80211_mesh_csa_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *csa_settings); +int ieee80211_mesh_finish_csa(struct ieee80211_sub_if_data *sdata); + +/* scan/BSS handling */ +void ieee80211_scan_work(struct wiphy *wiphy, struct wiphy_work *work); +int ieee80211_request_ibss_scan(struct ieee80211_sub_if_data *sdata, + const u8 *ssid, u8 ssid_len, + struct ieee80211_channel **channels, + unsigned int n_channels, + enum nl80211_bss_scan_width scan_width); +int ieee80211_request_scan(struct ieee80211_sub_if_data *sdata, + struct cfg80211_scan_request *req); +void ieee80211_scan_cancel(struct ieee80211_local *local); +void ieee80211_run_deferred_scan(struct ieee80211_local *local); +void ieee80211_scan_rx(struct ieee80211_local *local, struct sk_buff *skb); + +void ieee80211_mlme_notify_scan_completed(struct ieee80211_local *local); +struct ieee80211_bss * +ieee80211_bss_info_update(struct ieee80211_local *local, + struct ieee80211_rx_status *rx_status, + struct ieee80211_mgmt *mgmt, + size_t len, + struct ieee80211_channel *channel); +void ieee80211_rx_bss_put(struct ieee80211_local *local, + struct ieee80211_bss *bss); + +/* scheduled scan handling */ +int +__ieee80211_request_sched_scan_start(struct ieee80211_sub_if_data *sdata, + struct cfg80211_sched_scan_request *req); +int ieee80211_request_sched_scan_start(struct ieee80211_sub_if_data *sdata, + struct cfg80211_sched_scan_request *req); +int ieee80211_request_sched_scan_stop(struct ieee80211_local *local); +void ieee80211_sched_scan_end(struct ieee80211_local *local); +void ieee80211_sched_scan_stopped_work(struct wiphy *wiphy, + struct wiphy_work *work); + +/* off-channel/mgmt-tx */ +void ieee80211_offchannel_stop_vifs(struct ieee80211_local *local); +void ieee80211_offchannel_return(struct ieee80211_local *local); +void ieee80211_roc_setup(struct ieee80211_local *local); +void ieee80211_start_next_roc(struct ieee80211_local *local); +void ieee80211_roc_purge(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +int ieee80211_remain_on_channel(struct wiphy *wiphy, struct wireless_dev *wdev, + struct ieee80211_channel 
*chan, + unsigned int duration, u64 *cookie); +int ieee80211_cancel_remain_on_channel(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie); +int ieee80211_mgmt_tx(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params, u64 *cookie); +int ieee80211_mgmt_tx_cancel_wait(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie); + +/* channel switch handling */ +void ieee80211_csa_finalize_work(struct work_struct *work); +int ieee80211_channel_switch(struct wiphy *wiphy, struct net_device *dev, + struct cfg80211_csa_settings *params); + +/* color change handling */ +void ieee80211_color_change_finalize_work(struct work_struct *work); +void ieee80211_color_collision_detection_work(struct work_struct *work); + +/* interface handling */ +#define MAC80211_SUPPORTED_FEATURES_TX (NETIF_F_IP_CSUM | NETIF_F_IPV6_CSUM | \ + NETIF_F_HW_CSUM | NETIF_F_SG | \ + NETIF_F_HIGHDMA | NETIF_F_GSO_SOFTWARE) +#define MAC80211_SUPPORTED_FEATURES_RX (NETIF_F_RXCSUM) +#define MAC80211_SUPPORTED_FEATURES (MAC80211_SUPPORTED_FEATURES_TX | \ + MAC80211_SUPPORTED_FEATURES_RX) + +int ieee80211_iface_init(void); +void ieee80211_iface_exit(void); +int ieee80211_if_add(struct ieee80211_local *local, const char *name, + unsigned char name_assign_type, + struct wireless_dev **new_wdev, enum nl80211_iftype type, + struct vif_params *params); +int ieee80211_if_change_type(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type); +void ieee80211_if_remove(struct ieee80211_sub_if_data *sdata); +void ieee80211_remove_interfaces(struct ieee80211_local *local); +u32 ieee80211_idle_off(struct ieee80211_local *local); +void ieee80211_recalc_idle(struct ieee80211_local *local); +void ieee80211_adjust_monitor_flags(struct ieee80211_sub_if_data *sdata, + const int offset); +int ieee80211_do_open(struct wireless_dev *wdev, bool coming_up); +void ieee80211_sdata_stop(struct ieee80211_sub_if_data *sdata); +int ieee80211_add_virtual_monitor(struct ieee80211_local *local); +void ieee80211_del_virtual_monitor(struct ieee80211_local *local); + +bool __ieee80211_recalc_txpower(struct ieee80211_sub_if_data *sdata); +void ieee80211_recalc_txpower(struct ieee80211_sub_if_data *sdata, + bool update_bss); +void ieee80211_recalc_offload(struct ieee80211_local *local); + +static inline bool ieee80211_sdata_running(struct ieee80211_sub_if_data *sdata) +{ + return test_bit(SDATA_STATE_RUNNING, &sdata->state); +} + +/* link handling */ +void ieee80211_link_setup(struct ieee80211_link_data *link); +void ieee80211_link_init(struct ieee80211_sub_if_data *sdata, + int link_id, + struct ieee80211_link_data *link, + struct ieee80211_bss_conf *link_conf); +void ieee80211_link_stop(struct ieee80211_link_data *link); +int ieee80211_vif_set_links(struct ieee80211_sub_if_data *sdata, + u16 new_links); +void ieee80211_vif_clear_links(struct ieee80211_sub_if_data *sdata); + +/* tx handling */ +void ieee80211_clear_tx_pending(struct ieee80211_local *local); +void ieee80211_tx_pending(struct tasklet_struct *t); +netdev_tx_t ieee80211_monitor_start_xmit(struct sk_buff *skb, + struct net_device *dev); +netdev_tx_t ieee80211_subif_start_xmit(struct sk_buff *skb, + struct net_device *dev); +netdev_tx_t ieee80211_subif_start_xmit_8023(struct sk_buff *skb, + struct net_device *dev); +void __ieee80211_subif_start_xmit(struct sk_buff *skb, + struct net_device *dev, + u32 info_flags, + u32 ctrl_flags, + u64 *cookie); +void ieee80211_purge_tx_queue(struct ieee80211_hw *hw, + struct sk_buff_head *skbs); +struct sk_buff * 
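ieee80211_sdata_running() is the usual guard before touching an interface from deferred context, since the SDATA_STATE_RUNNING bit is cleared at the very start of ieee80211_do_stop(). A minimal sketch of that pattern (example_work is hypothetical; sdata->work is the per-interface work item used further below):

static void example_work(struct work_struct *work)
{
	struct ieee80211_sub_if_data *sdata =
		container_of(work, struct ieee80211_sub_if_data, work);

	/* the interface may have been stopped since the work was queued */
	if (!ieee80211_sdata_running(sdata))
		return;

	/* safe to operate on the running interface here */
}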
+ieee80211_build_data_template(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u32 info_flags); +void ieee80211_tx_monitor(struct ieee80211_local *local, struct sk_buff *skb, + int retry_count, int shift, bool send_to_cooked, + struct ieee80211_tx_status *status); + +void ieee80211_check_fast_xmit(struct sta_info *sta); +void ieee80211_check_fast_xmit_all(struct ieee80211_local *local); +void ieee80211_check_fast_xmit_iface(struct ieee80211_sub_if_data *sdata); +void ieee80211_clear_fast_xmit(struct sta_info *sta); +int ieee80211_tx_control_port(struct wiphy *wiphy, struct net_device *dev, + const u8 *buf, size_t len, + const u8 *dest, __be16 proto, bool unencrypted, + int link_id, u64 *cookie); +int ieee80211_probe_mesh_link(struct wiphy *wiphy, struct net_device *dev, + const u8 *buf, size_t len); + +/* HT */ +void ieee80211_apply_htcap_overrides(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_ht_cap *ht_cap); +bool ieee80211_ht_cap_ie_to_sta_ht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct ieee80211_ht_cap *ht_cap_ie, + struct link_sta_info *link_sta); +void ieee80211_send_delba(struct ieee80211_sub_if_data *sdata, + const u8 *da, u16 tid, + u16 initiator, u16 reason_code); +int ieee80211_send_smps_action(struct ieee80211_sub_if_data *sdata, + enum ieee80211_smps_mode smps, const u8 *da, + const u8 *bssid); +bool ieee80211_smps_is_restrictive(enum ieee80211_smps_mode smps_mode_old, + enum ieee80211_smps_mode smps_mode_new); + +void ___ieee80211_stop_rx_ba_session(struct sta_info *sta, u16 tid, + u16 initiator, u16 reason, bool stop); +void __ieee80211_stop_rx_ba_session(struct sta_info *sta, u16 tid, + u16 initiator, u16 reason, bool stop); +void ___ieee80211_start_rx_ba_session(struct sta_info *sta, + u8 dialog_token, u16 timeout, + u16 start_seq_num, u16 ba_policy, u16 tid, + u16 buf_size, bool tx, bool auto_seq, + const struct ieee80211_addba_ext_ie *addbaext); +void ieee80211_sta_tear_down_BA_sessions(struct sta_info *sta, + enum ieee80211_agg_stop_reason reason); +void ieee80211_process_delba(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, size_t len); +void ieee80211_process_addba_resp(struct ieee80211_local *local, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, + size_t len); +void ieee80211_process_addba_request(struct ieee80211_local *local, + struct sta_info *sta, + struct ieee80211_mgmt *mgmt, + size_t len); + +int __ieee80211_stop_tx_ba_session(struct sta_info *sta, u16 tid, + enum ieee80211_agg_stop_reason reason); +int ___ieee80211_stop_tx_ba_session(struct sta_info *sta, u16 tid, + enum ieee80211_agg_stop_reason reason); +void ieee80211_start_tx_ba_cb(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx); +void ieee80211_stop_tx_ba_cb(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx); +void ieee80211_ba_session_work(struct work_struct *work); +void ieee80211_tx_ba_session_handle_start(struct sta_info *sta, int tid); +void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid); + +u8 ieee80211_mcs_to_chains(const struct ieee80211_mcs_info *mcs); +enum nl80211_smps_mode +ieee80211_smps_mode_to_smps_mode(enum ieee80211_smps_mode smps); + +/* VHT */ +void +ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct ieee80211_vht_cap *vht_cap_ie, + const struct ieee80211_vht_cap *vht_cap_ie2, + struct link_sta_info *link_sta); +enum 
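ieee80211_mcs_to_chains() condenses an HT MCS set into the number of spatial streams it advertises. A small sketch reading it from a band's own HT capabilities (example_rx_chains is hypothetical):

static u8 example_rx_chains(struct ieee80211_supported_band *sband)
{
	/* single-stream fallback when HT is not supported at all */
	if (!sband->ht_cap.ht_supported)
		return 1;

	/* number of spatial streams advertised in the HT MCS set */
	return ieee80211_mcs_to_chains(&sband->ht_cap.mcs);
}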
ieee80211_sta_rx_bandwidth +ieee80211_sta_cap_rx_bw(struct link_sta_info *link_sta); +enum ieee80211_sta_rx_bandwidth +ieee80211_sta_cur_vht_bw(struct link_sta_info *link_sta); +void ieee80211_sta_set_rx_nss(struct link_sta_info *link_sta); +enum ieee80211_sta_rx_bandwidth +ieee80211_chan_width_to_rx_bw(enum nl80211_chan_width width); +enum nl80211_chan_width +ieee80211_sta_cap_chan_bw(struct link_sta_info *link_sta); +void ieee80211_process_mu_groups(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct ieee80211_mgmt *mgmt); +u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata, + struct link_sta_info *sta, + u8 opmode, enum nl80211_band band); +void ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata, + struct link_sta_info *sta, + u8 opmode, enum nl80211_band band); +void ieee80211_apply_vhtcap_overrides(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_vht_cap *vht_cap); +void ieee80211_get_vht_mask_from_cap(__le16 vht_cap, + u16 vht_mask[NL80211_VHT_NSS_MAX]); +enum nl80211_chan_width +ieee80211_sta_rx_bw_to_chan_width(struct link_sta_info *sta); + +/* HE */ +void +ieee80211_he_cap_ie_to_sta_he_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const u8 *he_cap_ie, u8 he_cap_len, + const struct ieee80211_he_6ghz_capa *he_6ghz_capa, + struct link_sta_info *link_sta); +void +ieee80211_he_spr_ie_to_bss_conf(struct ieee80211_vif *vif, + const struct ieee80211_he_spr *he_spr_ie_elem); + +void +ieee80211_he_op_ie_to_bss_conf(struct ieee80211_vif *vif, + const struct ieee80211_he_operation *he_op_ie_elem); + +/* S1G */ +void ieee80211_s1g_sta_rate_init(struct sta_info *sta); +bool ieee80211_s1g_is_twt_setup(struct sk_buff *skb); +void ieee80211_s1g_rx_twt_action(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void ieee80211_s1g_status_twt_action(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); + +/* Spectrum management */ +void ieee80211_process_measurement_req(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len); +/** + * ieee80211_parse_ch_switch_ie - parses channel switch IEs + * @sdata: the sdata of the interface which has received the frame + * @elems: parsed 802.11 elements received with the frame + * @current_band: indicates the current band + * @vht_cap_info: VHT capabilities of the transmitter + * @conn_flags: contains information about own capabilities and restrictions + * to decide which channel switch announcements can be accepted, using + * flags from &enum ieee80211_conn_flags. + * @bssid: the currently connected bssid (for reporting) + * @csa_ie: parsed 802.11 csa elements on count, mode, chandef and mesh ttl. + * All of them will be filled in only on success. + * Return: 0 on success, <0 on error and >0 if there is nothing to parse.
+ */ +int ieee80211_parse_ch_switch_ie(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, + enum nl80211_band current_band, + u32 vht_cap_info, + ieee80211_conn_flags_t conn_flags, u8 *bssid, + struct ieee80211_csa_ie *csa_ie); + +/* Suspend/resume and hw reconfiguration */ +int ieee80211_reconfig(struct ieee80211_local *local); +void ieee80211_stop_device(struct ieee80211_local *local); + +int __ieee80211_suspend(struct ieee80211_hw *hw, + struct cfg80211_wowlan *wowlan); + +static inline int __ieee80211_resume(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + WARN(test_bit(SCAN_HW_SCANNING, &local->scanning) && + !test_bit(SCAN_COMPLETED, &local->scanning), + "%s: resume with hardware scan still in progress\n", + wiphy_name(hw->wiphy)); + + return ieee80211_reconfig(hw_to_local(hw)); +} + +/* utility functions/constants */ +extern const void *const mac80211_wiphy_privid; /* for wiphy privid */ +int ieee80211_frame_duration(enum nl80211_band band, size_t len, + int rate, int erp, int short_preamble, + int shift); +void ieee80211_regulatory_limit_wmm_params(struct ieee80211_sub_if_data *sdata, + struct ieee80211_tx_queue_params *qparam, + int ac); +void ieee80211_set_wmm_default(struct ieee80211_link_data *link, + bool bss_notify, bool enable_qos); +void ieee80211_xmit(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb); + +void __ieee80211_tx_skb_tid_band(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, int tid, int link_id, + enum nl80211_band band); + +/* sta_out needs to be checked for ERR_PTR() before using */ +int ieee80211_lookup_ra_sta(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct sta_info **sta_out); + +static inline void +ieee80211_tx_skb_tid_band(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, int tid, + enum nl80211_band band) +{ + rcu_read_lock(); + __ieee80211_tx_skb_tid_band(sdata, skb, tid, -1, band); + rcu_read_unlock(); +} + +void ieee80211_tx_skb_tid(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, int tid, int link_id); + +static inline void ieee80211_tx_skb(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + /* Send all internal mgmt frames on VO. Accordingly set TID to 7. 
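ieee80211_parse_ch_switch_ie() uses the three-way return value documented above. A hedged sketch of a caller that distinguishes the three cases (example_handle_csa and its argument plumbing are hypothetical):

static void example_handle_csa(struct ieee80211_sub_if_data *sdata,
			       struct ieee802_11_elems *elems,
			       enum nl80211_band band, u32 vht_cap_info,
			       ieee80211_conn_flags_t conn_flags, u8 *bssid)
{
	struct ieee80211_csa_ie csa_ie;
	int err;

	err = ieee80211_parse_ch_switch_ie(sdata, elems, band, vht_cap_info,
					   conn_flags, bssid, &csa_ie);
	if (err < 0)
		return;	/* invalid or unsupported announcement */
	if (err)
		return;	/* no channel switch elements in this frame */

	/* success: csa_ie (count, mode, chandef, ...) is now filled in */
}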
*/ + ieee80211_tx_skb_tid(sdata, skb, 7, -1); +} + +/** + * struct ieee80211_elems_parse_params - element parsing parameters + * @start: pointer to the elements + * @len: length of the elements + * @action: %true if the elements came from an action frame + * @filter: bitmap of element IDs to filter out while calculating + * the element CRC + * @crc: CRC starting value + * @bss: the BSS to parse this as, for multi-BSSID cases this can + * represent a non-transmitting BSS in which case the data + * for that non-transmitting BSS is returned + * @link_id: the link ID to parse elements for, if a STA profile + * is present in the multi-link element, or -1 to ignore + * @from_ap: frame is received from an AP (currently used only + * for EHT capabilities parsing) + */ +struct ieee80211_elems_parse_params { + const u8 *start; + size_t len; + bool action; + u64 filter; + u32 crc; + struct cfg80211_bss *bss; + int link_id; + bool from_ap; +}; + +struct ieee802_11_elems * +ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params); + +static inline struct ieee802_11_elems * +ieee802_11_parse_elems_crc(const u8 *start, size_t len, bool action, + u64 filter, u32 crc, + struct cfg80211_bss *bss) +{ + struct ieee80211_elems_parse_params params = { + .start = start, + .len = len, + .action = action, + .filter = filter, + .crc = crc, + .bss = bss, + .link_id = -1, + }; + + return ieee802_11_parse_elems_full(¶ms); +} + +static inline struct ieee802_11_elems * +ieee802_11_parse_elems(const u8 *start, size_t len, bool action, + struct cfg80211_bss *bss) +{ + return ieee802_11_parse_elems_crc(start, len, action, 0, 0, bss); +} + +void ieee80211_fragment_element(struct sk_buff *skb, u8 *len_pos); + +extern const int ieee802_1d_to_ac[8]; + +static inline int ieee80211_ac_from_tid(int tid) +{ + return ieee802_1d_to_ac[tid & 7]; +} + +void ieee80211_dynamic_ps_enable_work(struct work_struct *work); +void ieee80211_dynamic_ps_disable_work(struct work_struct *work); +void ieee80211_dynamic_ps_timer(struct timer_list *t); +void ieee80211_send_nullfunc(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + bool powersave); +void ieee80211_send_4addr_nullfunc(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +void ieee80211_sta_tx_notify(struct ieee80211_sub_if_data *sdata, + struct ieee80211_hdr *hdr, bool ack, u16 tx_time); + +void ieee80211_wake_queues_by_reason(struct ieee80211_hw *hw, + unsigned long queues, + enum queue_stop_reason reason, + bool refcounted); +void ieee80211_stop_vif_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum queue_stop_reason reason); +void ieee80211_wake_vif_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum queue_stop_reason reason); +void ieee80211_stop_queues_by_reason(struct ieee80211_hw *hw, + unsigned long queues, + enum queue_stop_reason reason, + bool refcounted); +void ieee80211_wake_queue_by_reason(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted); +void ieee80211_stop_queue_by_reason(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted); +void ieee80211_propagate_queue_wake(struct ieee80211_local *local, int queue); +void ieee80211_add_pending_skb(struct ieee80211_local *local, + struct sk_buff *skb); +void ieee80211_add_pending_skbs(struct ieee80211_local *local, + struct sk_buff_head *skbs); +void ieee80211_flush_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, 
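ieee802_11_parse_elems_full() consumes the parameter structure documented above and returns a single caller-owned allocation, released with kfree() elsewhere in mac80211. A minimal sketch (example_parse_mgmt_elems is hypothetical):

static void example_parse_mgmt_elems(const u8 *ies, size_t ies_len)
{
	struct ieee80211_elems_parse_params params = {
		.start = ies,
		.len = ies_len,
		.action = false,
		.link_id = -1,
	};
	struct ieee802_11_elems *elems;

	elems = ieee802_11_parse_elems_full(&params);
	if (!elems)
		return;

	/* ... look at elems->tim, elems->ht_operation, etc. ... */

	kfree(elems);
}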
bool drop); +void __ieee80211_flush_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + unsigned int queues, bool drop); + +static inline bool ieee80211_can_run_worker(struct ieee80211_local *local) +{ + /* + * It's unsafe to try to do any work during reconfigure flow. + * When the flow ends the work will be requeued. + */ + if (local->in_reconfig) + return false; + + /* + * If quiescing is set, we are racing with __ieee80211_suspend. + * __ieee80211_suspend flushes the workers after setting quiescing, + * and we check quiescing / suspended before enqueing new workers. + * We should abort the worker to avoid the races below. + */ + if (local->quiescing) + return false; + + /* + * We might already be suspended if the following scenario occurs: + * __ieee80211_suspend Control path + * + * if (local->quiescing) + * return; + * local->quiescing = true; + * flush_workqueue(); + * queue_work(...); + * local->suspended = true; + * local->quiescing = false; + * worker starts running... + */ + if (local->suspended) + return false; + + return true; +} + +int ieee80211_txq_setup_flows(struct ieee80211_local *local); +void ieee80211_txq_set_params(struct ieee80211_local *local); +void ieee80211_txq_teardown_flows(struct ieee80211_local *local); +void ieee80211_txq_init(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct txq_info *txq, int tid); +void ieee80211_txq_purge(struct ieee80211_local *local, + struct txq_info *txqi); +void ieee80211_txq_remove_vlan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +void ieee80211_fill_txq_stats(struct cfg80211_txq_stats *txqstats, + struct txq_info *txqi); +void ieee80211_wake_txqs(struct tasklet_struct *t); +void ieee80211_send_auth(struct ieee80211_sub_if_data *sdata, + u16 transaction, u16 auth_alg, u16 status, + const u8 *extra, size_t extra_len, const u8 *bssid, + const u8 *da, const u8 *key, u8 key_len, u8 key_idx, + u32 tx_flags); +void ieee80211_send_deauth_disassoc(struct ieee80211_sub_if_data *sdata, + const u8 *da, const u8 *bssid, + u16 stype, u16 reason, + bool send_frame, u8 *frame_buf); + +enum { + IEEE80211_PROBE_FLAG_DIRECTED = BIT(0), + IEEE80211_PROBE_FLAG_MIN_CONTENT = BIT(1), + IEEE80211_PROBE_FLAG_RANDOM_SN = BIT(2), +}; + +int ieee80211_build_preq_ies(struct ieee80211_sub_if_data *sdata, u8 *buffer, + size_t buffer_len, + struct ieee80211_scan_ies *ie_desc, + const u8 *ie, size_t ie_len, + u8 bands_used, u32 *rate_masks, + struct cfg80211_chan_def *chandef, + u32 flags); +struct sk_buff *ieee80211_build_probe_req(struct ieee80211_sub_if_data *sdata, + const u8 *src, const u8 *dst, + u32 ratemask, + struct ieee80211_channel *chan, + const u8 *ssid, size_t ssid_len, + const u8 *ie, size_t ie_len, + u32 flags); +u32 ieee80211_sta_get_rates(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, + enum nl80211_band band, u32 *basic_rates); +int __ieee80211_request_smps_mgd(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + enum ieee80211_smps_mode smps_mode); +void ieee80211_recalc_smps(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link); +void ieee80211_recalc_min_chandef(struct ieee80211_sub_if_data *sdata, + int link_id); + +size_t ieee80211_ie_split_vendor(const u8 *ies, size_t ielen, size_t offset); +u8 *ieee80211_ie_build_ht_cap(u8 *pos, struct ieee80211_sta_ht_cap *ht_cap, + u16 cap); +u8 *ieee80211_ie_build_ht_oper(u8 *pos, struct ieee80211_sta_ht_cap *ht_cap, + const struct cfg80211_chan_def *chandef, + u16 
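ieee80211_can_run_worker() encodes the suspend/reconfigure races described in its comments. A minimal sketch of a guarded helper that simply bails out and relies on the requeue performed when reconfiguration ends (example_run_guarded is hypothetical):

static void example_run_guarded(struct ieee80211_local *local)
{
	/* bail out while suspending/resuming or during reconfiguration;
	 * the caller's work item is requeued once that flow completes
	 */
	if (!ieee80211_can_run_worker(local))
		return;

	/* safe to touch device state here */
}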
prot_mode, bool rifs_mode); +void ieee80211_ie_build_wide_bw_cs(u8 *pos, + const struct cfg80211_chan_def *chandef); +u8 *ieee80211_ie_build_vht_cap(u8 *pos, struct ieee80211_sta_vht_cap *vht_cap, + u32 cap); +u8 *ieee80211_ie_build_vht_oper(u8 *pos, struct ieee80211_sta_vht_cap *vht_cap, + const struct cfg80211_chan_def *chandef); +u8 ieee80211_ie_len_he_cap(struct ieee80211_sub_if_data *sdata, u8 iftype); +u8 *ieee80211_ie_build_he_cap(ieee80211_conn_flags_t disable_flags, u8 *pos, + const struct ieee80211_sta_he_cap *he_cap, + u8 *end); +void ieee80211_ie_build_he_6ghz_cap(struct ieee80211_sub_if_data *sdata, + enum ieee80211_smps_mode smps_mode, + struct sk_buff *skb); +u8 *ieee80211_ie_build_he_oper(u8 *pos, struct cfg80211_chan_def *chandef); +int ieee80211_parse_bitrates(enum nl80211_chan_width width, + const struct ieee80211_supported_band *sband, + const u8 *srates, int srates_len, u32 *rates); +int ieee80211_add_srates_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, bool need_basic, + enum nl80211_band band); +int ieee80211_add_ext_srates_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, bool need_basic, + enum nl80211_band band); +u8 *ieee80211_add_wmm_info_ie(u8 *buf, u8 qosinfo); +void ieee80211_add_s1g_capab_ie(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_s1g_cap *caps, + struct sk_buff *skb); +void ieee80211_add_aid_request_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); + +/* channel management */ +bool ieee80211_chandef_ht_oper(const struct ieee80211_ht_operation *ht_oper, + struct cfg80211_chan_def *chandef); +bool ieee80211_chandef_vht_oper(struct ieee80211_hw *hw, u32 vht_cap_info, + const struct ieee80211_vht_operation *oper, + const struct ieee80211_ht_operation *htop, + struct cfg80211_chan_def *chandef); +void ieee80211_chandef_eht_oper(const struct ieee80211_eht_operation *eht_oper, + bool support_160, bool support_320, + struct cfg80211_chan_def *chandef); +bool ieee80211_chandef_he_6ghz_oper(struct ieee80211_sub_if_data *sdata, + const struct ieee80211_he_operation *he_oper, + const struct ieee80211_eht_operation *eht_oper, + struct cfg80211_chan_def *chandef); +bool ieee80211_chandef_s1g_oper(const struct ieee80211_s1g_oper_ie *oper, + struct cfg80211_chan_def *chandef); +ieee80211_conn_flags_t ieee80211_chandef_downgrade(struct cfg80211_chan_def *c); + +int __must_check +ieee80211_link_use_channel(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode); +int __must_check +ieee80211_link_reserve_chanctx(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode mode, + bool radar_required); +int __must_check +ieee80211_link_use_reserved_context(struct ieee80211_link_data *link); +int ieee80211_link_unreserve_chanctx(struct ieee80211_link_data *link); + +int __must_check +ieee80211_link_change_bandwidth(struct ieee80211_link_data *link, + const struct cfg80211_chan_def *chandef, + u32 *changed); +void ieee80211_link_release_channel(struct ieee80211_link_data *link); +void ieee80211_link_vlan_copy_chanctx(struct ieee80211_link_data *link); +void ieee80211_link_copy_chanctx_to_vlans(struct ieee80211_link_data *link, + bool clear); +int ieee80211_chanctx_refcount(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx); + +void ieee80211_recalc_smps_chanctx(struct ieee80211_local *local, + struct ieee80211_chanctx *chanctx); +void ieee80211_recalc_chanctx_min_def(struct ieee80211_local *local, + 
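The chandef helpers above are typically used to grow a plain 20 MHz definition according to received operation elements and then bind a link to the result. A hedged sketch, assuming client-style IEEE80211_CHANCTX_SHARED mode and using cfg80211_chandef_create() from cfg80211 (example_use_ht_oper is hypothetical):

static int example_use_ht_oper(struct ieee80211_link_data *link,
			       const struct ieee80211_ht_operation *ht_oper,
			       struct ieee80211_channel *chan)
{
	struct cfg80211_chan_def chandef;

	/* start from a plain 20 MHz definition on the control channel */
	cfg80211_chandef_create(&chandef, chan, NL80211_CHAN_NO_HT);

	/* widen it if the HT Operation element allows HT40 */
	if (ht_oper)
		ieee80211_chandef_ht_oper(ht_oper, &chandef);

	/* bind the link to a (possibly shared) channel context */
	return ieee80211_link_use_channel(link, &chandef,
					  IEEE80211_CHANCTX_SHARED);
}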
struct ieee80211_chanctx *ctx, + struct ieee80211_link_data *rsvd_for); +bool ieee80211_is_radar_required(struct ieee80211_local *local); + +void ieee80211_dfs_cac_timer(unsigned long data); +void ieee80211_dfs_cac_timer_work(struct work_struct *work); +void ieee80211_dfs_cac_cancel(struct ieee80211_local *local); +void ieee80211_dfs_radar_detected_work(struct wiphy *wiphy, + struct wiphy_work *work); +int ieee80211_send_action_csa(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *csa_settings); + +void ieee80211_recalc_dtim(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata); +int ieee80211_check_combinations(struct ieee80211_sub_if_data *sdata, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode chanmode, + u8 radar_detect); +int ieee80211_max_num_channels(struct ieee80211_local *local); +void ieee80211_recalc_chanctx_chantype(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx); + +/* TDLS */ +int ieee80211_tdls_mgmt(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len); +int ieee80211_tdls_oper(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, enum nl80211_tdls_operation oper); +void ieee80211_tdls_peer_del_work(struct work_struct *wk); +int ieee80211_tdls_channel_switch(struct wiphy *wiphy, struct net_device *dev, + const u8 *addr, u8 oper_class, + struct cfg80211_chan_def *chandef); +void ieee80211_tdls_cancel_channel_switch(struct wiphy *wiphy, + struct net_device *dev, + const u8 *addr); +void ieee80211_teardown_tdls_peers(struct ieee80211_sub_if_data *sdata); +void ieee80211_tdls_handle_disconnect(struct ieee80211_sub_if_data *sdata, + const u8 *peer, u16 reason); +void +ieee80211_process_tdls_channel_switch(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); + + +const char *ieee80211_get_reason_code_string(u16 reason_code); +u16 ieee80211_encode_usf(int val); +u8 *ieee80211_get_bssid(struct ieee80211_hdr *hdr, size_t len, + enum nl80211_iftype type); + +extern const struct ethtool_ops ieee80211_ethtool_ops; + +u32 ieee80211_calc_expected_tx_airtime(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_sta *pubsta, + int len, bool ampdu); +#ifdef CONFIG_MAC80211_NOINLINE +#define debug_noinline noinline +#else +#define debug_noinline +#endif + +void ieee80211_init_frag_cache(struct ieee80211_fragment_cache *cache); +void ieee80211_destroy_frag_cache(struct ieee80211_fragment_cache *cache); + +u8 ieee80211_ie_len_eht_cap(struct ieee80211_sub_if_data *sdata, u8 iftype); +u8 *ieee80211_ie_build_eht_cap(u8 *pos, + const struct ieee80211_sta_he_cap *he_cap, + const struct ieee80211_sta_eht_cap *eht_cap, + u8 *end, + bool for_ap); + +void +ieee80211_eht_cap_ie_to_sta_eht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const u8 *he_cap_ie, u8 he_cap_len, + const struct ieee80211_eht_cap_elem *eht_cap_ie_elem, + u8 eht_cap_len, + struct link_sta_info *link_sta); +#endif /* IEEE80211_I_H */ diff --git a/net/mac80211/iface.c b/net/mac80211/iface.c new file mode 100644 index 000000000..e00e1bf0f --- /dev/null +++ b/net/mac80211/iface.c @@ -0,0 +1,2404 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Interface handling + * + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. 
+ * Copyright (c) 2006 Jiri Benc + * Copyright 2008, Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (c) 2016 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "sta_info.h" +#include "debugfs_netdev.h" +#include "mesh.h" +#include "led.h" +#include "driver-ops.h" +#include "wme.h" +#include "rate.h" + +/** + * DOC: Interface list locking + * + * The interface list in each struct ieee80211_local is protected + * three-fold: + * + * (1) modifications may only be done under the RTNL + * (2) modifications and readers are protected against each other by + * the iflist_mtx. + * (3) modifications are done in an RCU manner so atomic readers + * can traverse the list in RCU-safe blocks. + * + * As a consequence, reads (traversals) of the list can be protected + * by either the RTNL, the iflist_mtx or RCU. + */ + +static void ieee80211_iface_work(struct work_struct *work); + +bool __ieee80211_recalc_txpower(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_chanctx_conf *chanctx_conf; + int power; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (!chanctx_conf) { + rcu_read_unlock(); + return false; + } + + power = ieee80211_chandef_max_power(&chanctx_conf->def); + rcu_read_unlock(); + + if (sdata->deflink.user_power_level != IEEE80211_UNSET_POWER_LEVEL) + power = min(power, sdata->deflink.user_power_level); + + if (sdata->deflink.ap_power_level != IEEE80211_UNSET_POWER_LEVEL) + power = min(power, sdata->deflink.ap_power_level); + + if (power != sdata->vif.bss_conf.txpower) { + sdata->vif.bss_conf.txpower = power; + ieee80211_hw_config(sdata->local, 0); + return true; + } + + return false; +} + +void ieee80211_recalc_txpower(struct ieee80211_sub_if_data *sdata, + bool update_bss) +{ + if (__ieee80211_recalc_txpower(sdata) || + (update_bss && ieee80211_sdata_running(sdata))) + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_TXPOWER); +} + +static u32 __ieee80211_idle_off(struct ieee80211_local *local) +{ + if (!(local->hw.conf.flags & IEEE80211_CONF_IDLE)) + return 0; + + local->hw.conf.flags &= ~IEEE80211_CONF_IDLE; + return IEEE80211_CONF_CHANGE_IDLE; +} + +static u32 __ieee80211_idle_on(struct ieee80211_local *local) +{ + if (local->hw.conf.flags & IEEE80211_CONF_IDLE) + return 0; + + ieee80211_flush_queues(local, NULL, false); + + local->hw.conf.flags |= IEEE80211_CONF_IDLE; + return IEEE80211_CONF_CHANGE_IDLE; +} + +static u32 __ieee80211_recalc_idle(struct ieee80211_local *local, + bool force_active) +{ + bool working, scanning, active; + unsigned int led_trig_start = 0, led_trig_stop = 0; + + lockdep_assert_held(&local->mtx); + + active = force_active || + !list_empty(&local->chanctx_list) || + local->monitors; + + working = !local->ops->remain_on_channel && + !list_empty(&local->roc_list); + + scanning = test_bit(SCAN_SW_SCANNING, &local->scanning) || + test_bit(SCAN_ONCHANNEL_SCANNING, &local->scanning); + + if (working || scanning) + led_trig_start |= IEEE80211_TPT_LEDTRIG_FL_WORK; + else + led_trig_stop |= IEEE80211_TPT_LEDTRIG_FL_WORK; + + if (active) + led_trig_start |= IEEE80211_TPT_LEDTRIG_FL_CONNECTED; + else + led_trig_stop |= IEEE80211_TPT_LEDTRIG_FL_CONNECTED; + + ieee80211_mod_tpt_led_trig(local, led_trig_start, led_trig_stop); + + if (working || scanning || active) + return __ieee80211_idle_off(local); + return 
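The locking notes above allow atomic readers to walk the interface list under RCU. A minimal sketch of such a traversal (example_for_each_running_iface is hypothetical):

static void example_for_each_running_iface(struct ieee80211_local *local)
{
	struct ieee80211_sub_if_data *sdata;

	rcu_read_lock();
	list_for_each_entry_rcu(sdata, &local->interfaces, list) {
		if (!ieee80211_sdata_running(sdata))
			continue;

		/* read-only, RCU-safe per-interface work goes here */
	}
	rcu_read_unlock();
}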
__ieee80211_idle_on(local); +} + +u32 ieee80211_idle_off(struct ieee80211_local *local) +{ + return __ieee80211_recalc_idle(local, true); +} + +void ieee80211_recalc_idle(struct ieee80211_local *local) +{ + u32 change = __ieee80211_recalc_idle(local, false); + if (change) + ieee80211_hw_config(local, change); +} + +static int ieee80211_verify_mac(struct ieee80211_sub_if_data *sdata, u8 *addr, + bool check_dup) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *iter; + u64 new, mask, tmp; + u8 *m; + int ret = 0; + + if (is_zero_ether_addr(local->hw.wiphy->addr_mask)) + return 0; + + m = addr; + new = ((u64)m[0] << 5*8) | ((u64)m[1] << 4*8) | + ((u64)m[2] << 3*8) | ((u64)m[3] << 2*8) | + ((u64)m[4] << 1*8) | ((u64)m[5] << 0*8); + + m = local->hw.wiphy->addr_mask; + mask = ((u64)m[0] << 5*8) | ((u64)m[1] << 4*8) | + ((u64)m[2] << 3*8) | ((u64)m[3] << 2*8) | + ((u64)m[4] << 1*8) | ((u64)m[5] << 0*8); + + if (!check_dup) + return ret; + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(iter, &local->interfaces, list) { + if (iter == sdata) + continue; + + if (iter->vif.type == NL80211_IFTYPE_MONITOR && + !(iter->u.mntr.flags & MONITOR_FLAG_ACTIVE)) + continue; + + m = iter->vif.addr; + tmp = ((u64)m[0] << 5*8) | ((u64)m[1] << 4*8) | + ((u64)m[2] << 3*8) | ((u64)m[3] << 2*8) | + ((u64)m[4] << 1*8) | ((u64)m[5] << 0*8); + + if ((new & ~mask) != (tmp & ~mask)) { + ret = -EINVAL; + break; + } + } + mutex_unlock(&local->iflist_mtx); + + return ret; +} + +static int ieee80211_can_powered_addr_change(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_roc_work *roc; + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *scan_sdata; + int ret = 0; + + /* To be the most flexible here we want to only limit changing the + * address if the specific interface is doing offchannel work or + * scanning. + */ + if (netif_carrier_ok(sdata->dev)) + return -EBUSY; + + mutex_lock(&local->mtx); + + /* First check no ROC work is happening on this iface */ + list_for_each_entry(roc, &local->roc_list, list) { + if (roc->sdata != sdata) + continue; + + if (roc->started) { + ret = -EBUSY; + goto unlock; + } + } + + /* And if this iface is scanning */ + if (local->scanning) { + scan_sdata = rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx)); + if (sdata == scan_sdata) + ret = -EBUSY; + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + /* More interface types could be added here but changing the + * address while powered makes the most sense in client modes. 
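ieee80211_verify_mac() packs addresses into 64-bit values, but the rule it enforces is simply that two addresses may only differ in bits covered by the wiphy's addr_mask. The same check, written byte by byte for clarity (example_addr_within_mask is hypothetical):

static bool example_addr_within_mask(const u8 *addr, const u8 *ref,
				     const u8 *addr_mask)
{
	int i;

	for (i = 0; i < ETH_ALEN; i++) {
		/* bits outside the mask must match the reference address */
		if ((addr[i] & ~addr_mask[i]) != (ref[i] & ~addr_mask[i]))
			return false;
	}

	return true;
}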
+ */ + break; + default: + ret = -EOPNOTSUPP; + } + +unlock: + mutex_unlock(&local->mtx); + return ret; +} + +static int ieee80211_change_mac(struct net_device *dev, void *addr) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sockaddr *sa = addr; + bool check_dup = true; + bool live = false; + int ret; + + if (ieee80211_sdata_running(sdata)) { + ret = ieee80211_can_powered_addr_change(sdata); + if (ret) + return ret; + + live = true; + } + + if (sdata->vif.type == NL80211_IFTYPE_MONITOR && + !(sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE)) + check_dup = false; + + ret = ieee80211_verify_mac(sdata, sa->sa_data, check_dup); + if (ret) + return ret; + + if (live) + drv_remove_interface(local, sdata); + ret = eth_mac_addr(dev, sa); + + if (ret == 0) { + memcpy(sdata->vif.addr, sa->sa_data, ETH_ALEN); + ether_addr_copy(sdata->vif.bss_conf.addr, sdata->vif.addr); + } + + /* Regardless of eth_mac_addr() return we still want to add the + * interface back. This should not fail... + */ + if (live) + WARN_ON(drv_add_interface(local, sdata)); + + return ret; +} + +static inline int identical_mac_addr_allowed(int type1, int type2) +{ + return type1 == NL80211_IFTYPE_MONITOR || + type2 == NL80211_IFTYPE_MONITOR || + type1 == NL80211_IFTYPE_P2P_DEVICE || + type2 == NL80211_IFTYPE_P2P_DEVICE || + (type1 == NL80211_IFTYPE_AP && type2 == NL80211_IFTYPE_AP_VLAN) || + (type1 == NL80211_IFTYPE_AP_VLAN && + (type2 == NL80211_IFTYPE_AP || + type2 == NL80211_IFTYPE_AP_VLAN)); +} + +static int ieee80211_check_concurrent_iface(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype iftype) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *nsdata; + int ret; + + ASSERT_RTNL(); + + /* we hold the RTNL here so can safely walk the list */ + list_for_each_entry(nsdata, &local->interfaces, list) { + if (nsdata != sdata && ieee80211_sdata_running(nsdata)) { + /* + * Only OCB and monitor mode may coexist + */ + if ((sdata->vif.type == NL80211_IFTYPE_OCB && + nsdata->vif.type != NL80211_IFTYPE_MONITOR) || + (sdata->vif.type != NL80211_IFTYPE_MONITOR && + nsdata->vif.type == NL80211_IFTYPE_OCB)) + return -EBUSY; + + /* + * Allow only a single IBSS interface to be up at any + * time. This is restricted because beacon distribution + * cannot work properly if both are in the same IBSS. + * + * To remove this restriction we'd have to disallow them + * from setting the same SSID on different IBSS interfaces + * belonging to the same hardware. Then, however, we're + * faced with having to adopt two different TSF timers... + */ + if (iftype == NL80211_IFTYPE_ADHOC && + nsdata->vif.type == NL80211_IFTYPE_ADHOC) + return -EBUSY; + /* + * will not add another interface while any channel + * switch is active. + */ + if (nsdata->vif.bss_conf.csa_active) + return -EBUSY; + + /* + * The remaining checks are only performed for interfaces + * with the same MAC address. 
+ */ + if (!ether_addr_equal(sdata->vif.addr, + nsdata->vif.addr)) + continue; + + /* + * check whether it may have the same address + */ + if (!identical_mac_addr_allowed(iftype, + nsdata->vif.type)) + return -ENOTUNIQ; + + /* No support for VLAN with MLO yet */ + if (iftype == NL80211_IFTYPE_AP_VLAN && + sdata->wdev.use_4addr && + nsdata->vif.type == NL80211_IFTYPE_AP && + nsdata->vif.valid_links) + return -EOPNOTSUPP; + + /* + * can only add VLANs to enabled APs + */ + if (iftype == NL80211_IFTYPE_AP_VLAN && + nsdata->vif.type == NL80211_IFTYPE_AP) + sdata->bss = &nsdata->u.ap; + } + } + + mutex_lock(&local->chanctx_mtx); + ret = ieee80211_check_combinations(sdata, NULL, 0, 0); + mutex_unlock(&local->chanctx_mtx); + return ret; +} + +static int ieee80211_check_queues(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype iftype) +{ + int n_queues = sdata->local->hw.queues; + int i; + + if (iftype == NL80211_IFTYPE_NAN) + return 0; + + if (iftype != NL80211_IFTYPE_P2P_DEVICE) { + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + if (WARN_ON_ONCE(sdata->vif.hw_queue[i] == + IEEE80211_INVAL_HW_QUEUE)) + return -EINVAL; + if (WARN_ON_ONCE(sdata->vif.hw_queue[i] >= + n_queues)) + return -EINVAL; + } + } + + if ((iftype != NL80211_IFTYPE_AP && + iftype != NL80211_IFTYPE_P2P_GO && + iftype != NL80211_IFTYPE_MESH_POINT) || + !ieee80211_hw_check(&sdata->local->hw, QUEUE_CONTROL)) { + sdata->vif.cab_queue = IEEE80211_INVAL_HW_QUEUE; + return 0; + } + + if (WARN_ON_ONCE(sdata->vif.cab_queue == IEEE80211_INVAL_HW_QUEUE)) + return -EINVAL; + + if (WARN_ON_ONCE(sdata->vif.cab_queue >= n_queues)) + return -EINVAL; + + return 0; +} + +static int ieee80211_open(struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + int err; + + /* fail early if user set an invalid address */ + if (!is_valid_ether_addr(dev->dev_addr)) + return -EADDRNOTAVAIL; + + err = ieee80211_check_concurrent_iface(sdata, sdata->vif.type); + if (err) + return err; + + wiphy_lock(sdata->local->hw.wiphy); + err = ieee80211_do_open(&sdata->wdev, true); + wiphy_unlock(sdata->local->hw.wiphy); + + return err; +} + +static void ieee80211_do_stop(struct ieee80211_sub_if_data *sdata, bool going_down) +{ + struct ieee80211_local *local = sdata->local; + unsigned long flags; + struct sk_buff *skb, *tmp; + u32 hw_reconf_flags = 0; + int i, flushed; + struct ps_data *ps; + struct cfg80211_chan_def chandef; + bool cancel_scan; + struct cfg80211_nan_func *func; + + clear_bit(SDATA_STATE_RUNNING, &sdata->state); + synchronize_rcu(); /* flush _ieee80211_wake_txqs() */ + + cancel_scan = rcu_access_pointer(local->scan_sdata) == sdata; + if (cancel_scan) + ieee80211_scan_cancel(local); + + /* + * Stop TX on this interface first. + */ + if (!local->ops->wake_tx_queue && sdata->dev) + netif_tx_stop_all_queues(sdata->dev); + + ieee80211_roc_purge(local, sdata); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + ieee80211_mgd_stop(sdata); + break; + case NL80211_IFTYPE_ADHOC: + ieee80211_ibss_stop(sdata); + break; + case NL80211_IFTYPE_MONITOR: + if (sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES) + break; + list_del_rcu(&sdata->u.mntr.list); + break; + default: + break; + } + + /* + * Remove all stations associated with this interface. 
+ * + * This must be done before calling ops->remove_interface() + * because otherwise we can later invoke ops->sta_notify() + * whenever the STAs are removed, and that invalidates driver + * assumptions about always getting a vif pointer that is valid + * (because if we remove a STA after ops->remove_interface() + * the driver will have removed the vif info already!) + * + * For AP_VLANs stations may exist since there's nothing else that + * would have removed them, but in other modes there shouldn't + * be any stations. + */ + flushed = sta_info_flush(sdata); + WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_AP_VLAN && flushed > 0); + + /* don't count this interface for allmulti while it is down */ + if (sdata->flags & IEEE80211_SDATA_ALLMULTI) + atomic_dec(&local->iff_allmultis); + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + local->fif_pspoll--; + local->fif_probe_req--; + } else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) { + local->fif_probe_req--; + } + + if (sdata->dev) { + netif_addr_lock_bh(sdata->dev); + spin_lock_bh(&local->filter_lock); + __hw_addr_unsync(&local->mc_list, &sdata->dev->mc, + sdata->dev->addr_len); + spin_unlock_bh(&local->filter_lock); + netif_addr_unlock_bh(sdata->dev); + } + + del_timer_sync(&local->dynamic_ps_timer); + cancel_work_sync(&local->dynamic_ps_enable_work); + + cancel_work_sync(&sdata->recalc_smps); + + sdata_lock(sdata); + WARN(sdata->vif.valid_links, + "destroying interface with valid links 0x%04x\n", + sdata->vif.valid_links); + + mutex_lock(&local->mtx); + sdata->vif.bss_conf.csa_active = false; + if (sdata->vif.type == NL80211_IFTYPE_STATION) + sdata->deflink.u.mgd.csa_waiting_bcn = false; + if (sdata->deflink.csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + sdata->deflink.csa_block_tx = false; + } + mutex_unlock(&local->mtx); + sdata_unlock(sdata); + + cancel_work_sync(&sdata->deflink.csa_finalize_work); + cancel_work_sync(&sdata->deflink.color_change_finalize_work); + + cancel_delayed_work_sync(&sdata->deflink.dfs_cac_timer_work); + + if (sdata->wdev.cac_started) { + chandef = sdata->vif.bss_conf.chandef; + WARN_ON(local->suspended); + mutex_lock(&local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&local->mtx); + cfg80211_cac_event(sdata->dev, &chandef, + NL80211_RADAR_CAC_ABORTED, + GFP_KERNEL); + } + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + WARN_ON(!list_empty(&sdata->u.ap.vlans)); + } else if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + /* remove all packets in parent bc_buf pointing to this dev */ + ps = &sdata->bss->ps; + + spin_lock_irqsave(&ps->bc_buf.lock, flags); + skb_queue_walk_safe(&ps->bc_buf, skb, tmp) { + if (skb->dev == sdata->dev) { + __skb_unlink(skb, &ps->bc_buf); + local->total_ps_buffered--; + ieee80211_free_txskb(&local->hw, skb); + } + } + spin_unlock_irqrestore(&ps->bc_buf.lock, flags); + } + + if (going_down) + local->open_count--; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + mutex_lock(&local->mtx); + list_del(&sdata->u.vlan.list); + mutex_unlock(&local->mtx); + RCU_INIT_POINTER(sdata->vif.bss_conf.chanctx_conf, NULL); + /* see comment in the default case below */ + ieee80211_free_keys(sdata, true); + /* no need to tell driver */ + break; + case NL80211_IFTYPE_MONITOR: + if (sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES) { + local->cooked_mntrs--; + break; + } + + local->monitors--; + if (local->monitors == 0) { + local->hw.conf.flags &= ~IEEE80211_CONF_MONITOR; + hw_reconf_flags |= 
IEEE80211_CONF_CHANGE_MONITOR; + } + + ieee80211_adjust_monitor_flags(sdata, -1); + break; + case NL80211_IFTYPE_NAN: + /* clean all the functions */ + spin_lock_bh(&sdata->u.nan.func_lock); + + idr_for_each_entry(&sdata->u.nan.function_inst_ids, func, i) { + idr_remove(&sdata->u.nan.function_inst_ids, i); + cfg80211_free_nan_func(func); + } + idr_destroy(&sdata->u.nan.function_inst_ids); + + spin_unlock_bh(&sdata->u.nan.func_lock); + break; + case NL80211_IFTYPE_P2P_DEVICE: + /* relies on synchronize_rcu() below */ + RCU_INIT_POINTER(local->p2p_sdata, NULL); + fallthrough; + default: + cancel_work_sync(&sdata->work); + /* + * When we get here, the interface is marked down. + * Free the remaining keys, if there are any + * (which can happen in AP mode if userspace sets + * keys before the interface is operating) + * + * Force the key freeing to always synchronize_net() + * to wait for the RX path in case it is using this + * interface enqueuing frames at this very time on + * another CPU. + */ + ieee80211_free_keys(sdata, true); + skb_queue_purge(&sdata->skb_queue); + skb_queue_purge(&sdata->status_queue); + } + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + for (i = 0; i < IEEE80211_MAX_QUEUES; i++) { + skb_queue_walk_safe(&local->pending[i], skb, tmp) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + if (info->control.vif == &sdata->vif) { + __skb_unlink(skb, &local->pending[i]); + ieee80211_free_txskb(&local->hw, skb); + } + } + } + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + ieee80211_txq_remove_vlan(local, sdata); + + sdata->bss = NULL; + + if (local->open_count == 0) + ieee80211_clear_tx_pending(local); + + sdata->vif.bss_conf.beacon_int = 0; + + /* + * If the interface goes down while suspended, presumably because + * the device was unplugged and that happens before our resume, + * then the driver is already unconfigured and the remainder of + * this function isn't needed. + * XXX: what about WoWLAN? If the device has software state, e.g. + * memory allocated, it might expect teardown commands from + * mac80211 here? + */ + if (local->suspended) { + WARN_ON(local->wowlan); + WARN_ON(rcu_access_pointer(local->monitor_sdata)); + return; + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + break; + case NL80211_IFTYPE_MONITOR: + if (local->monitors == 0) + ieee80211_del_virtual_monitor(local); + + mutex_lock(&local->mtx); + ieee80211_recalc_idle(local); + mutex_unlock(&local->mtx); + + if (!(sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE)) + break; + + fallthrough; + default: + if (going_down) + drv_remove_interface(local, sdata); + } + + ieee80211_recalc_ps(local); + + if (cancel_scan) + wiphy_delayed_work_flush(local->hw.wiphy, &local->scan_work); + + if (local->open_count == 0) { + ieee80211_stop_device(local); + + /* no reconfiguring after stop! 
*/ + return; + } + + /* do after stop to avoid reconfiguring when we stop anyway */ + ieee80211_configure_filter(local); + ieee80211_hw_config(local, hw_reconf_flags); + + if (local->monitors == local->open_count) + ieee80211_add_virtual_monitor(local); +} + +static void ieee80211_stop_mbssid(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_sub_if_data *tx_sdata, *non_tx_sdata, *tmp_sdata; + struct ieee80211_vif *tx_vif = sdata->vif.mbssid_tx_vif; + + if (!tx_vif) + return; + + tx_sdata = vif_to_sdata(tx_vif); + sdata->vif.mbssid_tx_vif = NULL; + + list_for_each_entry_safe(non_tx_sdata, tmp_sdata, + &tx_sdata->local->interfaces, list) { + if (non_tx_sdata != sdata && non_tx_sdata != tx_sdata && + non_tx_sdata->vif.mbssid_tx_vif == tx_vif && + ieee80211_sdata_running(non_tx_sdata)) { + non_tx_sdata->vif.mbssid_tx_vif = NULL; + dev_close(non_tx_sdata->wdev.netdev); + } + } + + if (sdata != tx_sdata && ieee80211_sdata_running(tx_sdata)) { + tx_sdata->vif.mbssid_tx_vif = NULL; + dev_close(tx_sdata->wdev.netdev); + } +} + +static int ieee80211_stop(struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + + /* close dependent VLAN and MBSSID interfaces before locking wiphy */ + if (sdata->vif.type == NL80211_IFTYPE_AP) { + struct ieee80211_sub_if_data *vlan, *tmpsdata; + + list_for_each_entry_safe(vlan, tmpsdata, &sdata->u.ap.vlans, + u.vlan.list) + dev_close(vlan->dev); + + ieee80211_stop_mbssid(sdata); + } + + cancel_work_sync(&sdata->activate_links_work); + + wiphy_lock(sdata->local->hw.wiphy); + ieee80211_do_stop(sdata, true); + wiphy_unlock(sdata->local->hw.wiphy); + + return 0; +} + +static void ieee80211_set_multicast_list(struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + int allmulti, sdata_allmulti; + + allmulti = !!(dev->flags & IFF_ALLMULTI); + sdata_allmulti = !!(sdata->flags & IEEE80211_SDATA_ALLMULTI); + + if (allmulti != sdata_allmulti) { + if (dev->flags & IFF_ALLMULTI) + atomic_inc(&local->iff_allmultis); + else + atomic_dec(&local->iff_allmultis); + sdata->flags ^= IEEE80211_SDATA_ALLMULTI; + } + + spin_lock_bh(&local->filter_lock); + __hw_addr_sync(&local->mc_list, &dev->mc, dev->addr_len); + spin_unlock_bh(&local->filter_lock); + ieee80211_queue_work(&local->hw, &local->reconfig_filter); +} + +/* + * Called when the netdev is removed or, by the code below, before + * the interface type changes. 
+ */ +static void ieee80211_teardown_sdata(struct ieee80211_sub_if_data *sdata) +{ + /* free extra data */ + ieee80211_free_keys(sdata, false); + + ieee80211_debugfs_remove_netdev(sdata); + + ieee80211_destroy_frag_cache(&sdata->frags); + + if (ieee80211_vif_is_mesh(&sdata->vif)) + ieee80211_mesh_teardown_sdata(sdata); + + ieee80211_vif_clear_links(sdata); + ieee80211_link_stop(&sdata->deflink); +} + +static void ieee80211_uninit(struct net_device *dev) +{ + ieee80211_teardown_sdata(IEEE80211_DEV_TO_SUB_IF(dev)); +} + +static u16 ieee80211_netdev_select_queue(struct net_device *dev, + struct sk_buff *skb, + struct net_device *sb_dev) +{ + return ieee80211_select_queue(IEEE80211_DEV_TO_SUB_IF(dev), skb); +} + +static void +ieee80211_get_stats64(struct net_device *dev, struct rtnl_link_stats64 *stats) +{ + dev_fetch_sw_netstats(stats, dev->tstats); +} + +static const struct net_device_ops ieee80211_dataif_ops = { + .ndo_open = ieee80211_open, + .ndo_stop = ieee80211_stop, + .ndo_uninit = ieee80211_uninit, + .ndo_start_xmit = ieee80211_subif_start_xmit, + .ndo_set_rx_mode = ieee80211_set_multicast_list, + .ndo_set_mac_address = ieee80211_change_mac, + .ndo_select_queue = ieee80211_netdev_select_queue, + .ndo_get_stats64 = ieee80211_get_stats64, +}; + +static u16 ieee80211_monitor_select_queue(struct net_device *dev, + struct sk_buff *skb, + struct net_device *sb_dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr; + int len_rthdr; + + if (local->hw.queues < IEEE80211_NUM_ACS) + return 0; + + /* reset flags and info before parsing radiotap header */ + memset(info, 0, sizeof(*info)); + + if (!ieee80211_parse_tx_radiotap(skb, dev)) + return 0; /* doesn't matter, frame will be dropped */ + + len_rthdr = ieee80211_get_radiotap_len(skb->data); + hdr = (struct ieee80211_hdr *)(skb->data + len_rthdr); + if (skb->len < len_rthdr + 2 || + skb->len < len_rthdr + ieee80211_hdrlen(hdr->frame_control)) + return 0; /* doesn't matter, frame will be dropped */ + + return ieee80211_select_queue_80211(sdata, skb, hdr); +} + +static const struct net_device_ops ieee80211_monitorif_ops = { + .ndo_open = ieee80211_open, + .ndo_stop = ieee80211_stop, + .ndo_uninit = ieee80211_uninit, + .ndo_start_xmit = ieee80211_monitor_start_xmit, + .ndo_set_rx_mode = ieee80211_set_multicast_list, + .ndo_set_mac_address = ieee80211_change_mac, + .ndo_select_queue = ieee80211_monitor_select_queue, + .ndo_get_stats64 = ieee80211_get_stats64, +}; + +static int ieee80211_netdev_fill_forward_path(struct net_device_path_ctx *ctx, + struct net_device_path *path) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_local *local; + struct sta_info *sta; + int ret = -ENOENT; + + sdata = IEEE80211_DEV_TO_SUB_IF(ctx->dev); + local = sdata->local; + + if (!local->ops->net_fill_forward_path) + return -EOPNOTSUPP; + + rcu_read_lock(); + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + sta = rcu_dereference(sdata->u.vlan.sta); + if (sta) + break; + if (sdata->wdev.use_4addr) + goto out; + if (is_multicast_ether_addr(ctx->daddr)) + goto out; + sta = sta_info_get_bss(sdata, ctx->daddr); + break; + case NL80211_IFTYPE_AP: + if (is_multicast_ether_addr(ctx->daddr)) + goto out; + sta = sta_info_get(sdata, ctx->daddr); + break; + case NL80211_IFTYPE_STATION: + if (sdata->wdev.wiphy->flags & WIPHY_FLAG_SUPPORTS_TDLS) { + sta = sta_info_get(sdata, ctx->daddr); + if (sta && 
test_sta_flag(sta, WLAN_STA_TDLS_PEER)) { + if (!test_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH)) + goto out; + + break; + } + } + + sta = sta_info_get(sdata, sdata->deflink.u.mgd.bssid); + break; + default: + goto out; + } + + if (!sta) + goto out; + + ret = drv_net_fill_forward_path(local, sdata, &sta->sta, ctx, path); +out: + rcu_read_unlock(); + + return ret; +} + +static const struct net_device_ops ieee80211_dataif_8023_ops = { + .ndo_open = ieee80211_open, + .ndo_stop = ieee80211_stop, + .ndo_uninit = ieee80211_uninit, + .ndo_start_xmit = ieee80211_subif_start_xmit_8023, + .ndo_set_rx_mode = ieee80211_set_multicast_list, + .ndo_set_mac_address = ieee80211_change_mac, + .ndo_select_queue = ieee80211_netdev_select_queue, + .ndo_get_stats64 = ieee80211_get_stats64, + .ndo_fill_forward_path = ieee80211_netdev_fill_forward_path, +}; + +static bool ieee80211_iftype_supports_hdr_offload(enum nl80211_iftype iftype) +{ + switch (iftype) { + /* P2P GO and client are mapped to AP/STATION types */ + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_STATION: + return true; + default: + return false; + } +} + +static bool ieee80211_set_sdata_offload_flags(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + u32 flags; + + flags = sdata->vif.offload_flags; + + if (ieee80211_hw_check(&local->hw, SUPPORTS_TX_ENCAP_OFFLOAD) && + ieee80211_iftype_supports_hdr_offload(sdata->vif.type)) { + flags |= IEEE80211_OFFLOAD_ENCAP_ENABLED; + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_TX_FRAG) && + local->hw.wiphy->frag_threshold != (u32)-1) + flags &= ~IEEE80211_OFFLOAD_ENCAP_ENABLED; + + if (local->monitors) + flags &= ~IEEE80211_OFFLOAD_ENCAP_ENABLED; + } else { + flags &= ~IEEE80211_OFFLOAD_ENCAP_ENABLED; + } + + if (ieee80211_hw_check(&local->hw, SUPPORTS_RX_DECAP_OFFLOAD) && + ieee80211_iftype_supports_hdr_offload(sdata->vif.type)) { + flags |= IEEE80211_OFFLOAD_DECAP_ENABLED; + + if (local->monitors && + !ieee80211_hw_check(&local->hw, SUPPORTS_CONC_MON_RX_DECAP)) + flags &= ~IEEE80211_OFFLOAD_DECAP_ENABLED; + } else { + flags &= ~IEEE80211_OFFLOAD_DECAP_ENABLED; + } + + if (sdata->vif.offload_flags == flags) + return false; + + sdata->vif.offload_flags = flags; + ieee80211_check_fast_rx_iface(sdata); + return true; +} + +static void ieee80211_set_vif_encap_ops(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *bss = sdata; + bool enabled; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + if (!sdata->bss) + return; + + bss = container_of(sdata->bss, struct ieee80211_sub_if_data, u.ap); + } + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_TX_ENCAP_OFFLOAD) || + !ieee80211_iftype_supports_hdr_offload(bss->vif.type)) + return; + + enabled = bss->vif.offload_flags & IEEE80211_OFFLOAD_ENCAP_ENABLED; + if (sdata->wdev.use_4addr && + !(bss->vif.offload_flags & IEEE80211_OFFLOAD_ENCAP_4ADDR)) + enabled = false; + + sdata->dev->netdev_ops = enabled ? 
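ieee80211_set_sdata_offload_flags() spreads the TX encapsulation offload decision over several updates of the flags word. The same conditions, collapsed into one predicate for readability (example_can_offload_tx_encap is hypothetical):

static bool example_can_offload_tx_encap(struct ieee80211_local *local,
					 struct ieee80211_sub_if_data *sdata)
{
	if (!ieee80211_hw_check(&local->hw, SUPPORTS_TX_ENCAP_OFFLOAD))
		return false;

	/* only AP and station (and their P2P variants) qualify */
	if (!ieee80211_iftype_supports_hdr_offload(sdata->vif.type))
		return false;

	/* software fragmentation is incompatible with 802.3 TX */
	if (!ieee80211_hw_check(&local->hw, SUPPORTS_TX_FRAG) &&
	    local->hw.wiphy->frag_threshold != (u32)-1)
		return false;

	/* an active monitor interface forces the normal data path */
	if (local->monitors)
		return false;

	return true;
}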
&ieee80211_dataif_8023_ops : + &ieee80211_dataif_ops; +} + +static void ieee80211_recalc_sdata_offload(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *vsdata; + + if (ieee80211_set_sdata_offload_flags(sdata)) { + drv_update_vif_offload(local, sdata); + ieee80211_set_vif_encap_ops(sdata); + } + + list_for_each_entry(vsdata, &local->interfaces, list) { + if (vsdata->vif.type != NL80211_IFTYPE_AP_VLAN || + vsdata->bss != &sdata->u.ap) + continue; + + ieee80211_set_vif_encap_ops(vsdata); + } +} + +void ieee80211_recalc_offload(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_TX_ENCAP_OFFLOAD)) + return; + + mutex_lock(&local->iflist_mtx); + + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + ieee80211_recalc_sdata_offload(sdata); + } + + mutex_unlock(&local->iflist_mtx); +} + +void ieee80211_adjust_monitor_flags(struct ieee80211_sub_if_data *sdata, + const int offset) +{ + struct ieee80211_local *local = sdata->local; + u32 flags = sdata->u.mntr.flags; + +#define ADJUST(_f, _s) do { \ + if (flags & MONITOR_FLAG_##_f) \ + local->fif_##_s += offset; \ + } while (0) + + ADJUST(FCSFAIL, fcsfail); + ADJUST(PLCPFAIL, plcpfail); + ADJUST(CONTROL, control); + ADJUST(CONTROL, pspoll); + ADJUST(OTHER_BSS, other_bss); + +#undef ADJUST +} + +static void ieee80211_set_default_queues(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + int i; + + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + if (ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) + sdata->vif.hw_queue[i] = IEEE80211_INVAL_HW_QUEUE; + else if (local->hw.queues >= IEEE80211_NUM_ACS) + sdata->vif.hw_queue[i] = i; + else + sdata->vif.hw_queue[i] = 0; + } + sdata->vif.cab_queue = IEEE80211_INVAL_HW_QUEUE; +} + +static void ieee80211_sdata_init(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + sdata->local = local; + + /* + * Initialize the default link, so we can use link_id 0 for non-MLD, + * and that continues to work for non-MLD-aware drivers that use just + * vif.bss_conf instead of vif.link_conf. + * + * Note that we never change this, so if link ID 0 isn't used in an + * MLD connection, we get a separate allocation for it. + */ + ieee80211_link_init(sdata, -1, &sdata->deflink, &sdata->vif.bss_conf); +} + +int ieee80211_add_virtual_monitor(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + int ret; + + if (!ieee80211_hw_check(&local->hw, WANT_MONITOR_VIF)) + return 0; + + ASSERT_RTNL(); + lockdep_assert_wiphy(local->hw.wiphy); + + if (local->monitor_sdata) + return 0; + + sdata = kzalloc(sizeof(*sdata) + local->hw.vif_data_size, GFP_KERNEL); + if (!sdata) + return -ENOMEM; + + /* set up data */ + sdata->vif.type = NL80211_IFTYPE_MONITOR; + snprintf(sdata->name, IFNAMSIZ, "%s-monitor", + wiphy_name(local->hw.wiphy)); + sdata->wdev.iftype = NL80211_IFTYPE_MONITOR; + + ieee80211_sdata_init(local, sdata); + + ieee80211_set_default_queues(sdata); + + ret = drv_add_interface(local, sdata); + if (WARN_ON(ret)) { + /* ok .. stupid driver, it asked for this! 
*/ + kfree(sdata); + return ret; + } + + set_bit(SDATA_STATE_RUNNING, &sdata->state); + + ret = ieee80211_check_queues(sdata, NL80211_IFTYPE_MONITOR); + if (ret) { + kfree(sdata); + return ret; + } + + mutex_lock(&local->iflist_mtx); + rcu_assign_pointer(local->monitor_sdata, sdata); + mutex_unlock(&local->iflist_mtx); + + mutex_lock(&local->mtx); + ret = ieee80211_link_use_channel(&sdata->deflink, &local->monitor_chandef, + IEEE80211_CHANCTX_EXCLUSIVE); + mutex_unlock(&local->mtx); + if (ret) { + mutex_lock(&local->iflist_mtx); + RCU_INIT_POINTER(local->monitor_sdata, NULL); + mutex_unlock(&local->iflist_mtx); + synchronize_net(); + drv_remove_interface(local, sdata); + kfree(sdata); + return ret; + } + + skb_queue_head_init(&sdata->skb_queue); + skb_queue_head_init(&sdata->status_queue); + INIT_WORK(&sdata->work, ieee80211_iface_work); + + return 0; +} + +void ieee80211_del_virtual_monitor(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + if (!ieee80211_hw_check(&local->hw, WANT_MONITOR_VIF)) + return; + + ASSERT_RTNL(); + lockdep_assert_wiphy(local->hw.wiphy); + + mutex_lock(&local->iflist_mtx); + + sdata = rcu_dereference_protected(local->monitor_sdata, + lockdep_is_held(&local->iflist_mtx)); + if (!sdata) { + mutex_unlock(&local->iflist_mtx); + return; + } + + RCU_INIT_POINTER(local->monitor_sdata, NULL); + mutex_unlock(&local->iflist_mtx); + + synchronize_net(); + + mutex_lock(&local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&local->mtx); + + drv_remove_interface(local, sdata); + + kfree(sdata); +} + +/* + * NOTE: Be very careful when changing this function, it must NOT return + * an error on interface type changes that have been pre-checked, so most + * checks should be in ieee80211_check_concurrent_iface. 
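+ * (ieee80211_runtime_change_iftype() reopens the interface through this
+ * function only after drv_change_interface() has already switched the
+ * type, so a late failure here could only be WARNed about, not undone.)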
+ */ +int ieee80211_do_open(struct wireless_dev *wdev, bool coming_up) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct net_device *dev = wdev->netdev; + struct ieee80211_local *local = sdata->local; + u32 changed = 0; + int res; + u32 hw_reconf_flags = 0; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: { + struct ieee80211_sub_if_data *master; + + if (!sdata->bss) + return -ENOLINK; + + mutex_lock(&local->mtx); + list_add(&sdata->u.vlan.list, &sdata->bss->vlans); + mutex_unlock(&local->mtx); + + master = container_of(sdata->bss, + struct ieee80211_sub_if_data, u.ap); + sdata->control_port_protocol = + master->control_port_protocol; + sdata->control_port_no_encrypt = + master->control_port_no_encrypt; + sdata->control_port_over_nl80211 = + master->control_port_over_nl80211; + sdata->control_port_no_preauth = + master->control_port_no_preauth; + sdata->vif.cab_queue = master->vif.cab_queue; + memcpy(sdata->vif.hw_queue, master->vif.hw_queue, + sizeof(sdata->vif.hw_queue)); + sdata->vif.bss_conf.chandef = master->vif.bss_conf.chandef; + + mutex_lock(&local->key_mtx); + sdata->crypto_tx_tailroom_needed_cnt += + master->crypto_tx_tailroom_needed_cnt; + mutex_unlock(&local->key_mtx); + + break; + } + case NL80211_IFTYPE_AP: + sdata->bss = &sdata->u.ap; + break; + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_NAN: + /* no special treatment */ + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_WDS: + /* cannot happen */ + WARN_ON(1); + break; + } + + if (local->open_count == 0) { + res = drv_start(local); + if (res) + goto err_del_bss; + /* we're brought up, everything changes */ + hw_reconf_flags = ~0; + ieee80211_led_radio(local, true); + ieee80211_mod_tpt_led_trig(local, + IEEE80211_TPT_LEDTRIG_FL_RADIO, 0); + } + + /* + * Copy the hopefully now-present MAC address to + * this interface, if it has the special null one. 
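+ * (The dev check matters because P2P-Device and NAN interfaces are
+ * created without a netdev, see ieee80211_if_add().)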
+ */ + if (dev && is_zero_ether_addr(dev->dev_addr)) { + eth_hw_addr_set(dev, local->hw.wiphy->perm_addr); + memcpy(dev->perm_addr, dev->dev_addr, ETH_ALEN); + + if (!is_valid_ether_addr(dev->dev_addr)) { + res = -EADDRNOTAVAIL; + goto err_stop; + } + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + /* no need to tell driver, but set carrier and chanctx */ + if (sdata->bss->active) { + ieee80211_link_vlan_copy_chanctx(&sdata->deflink); + netif_carrier_on(dev); + ieee80211_set_vif_encap_ops(sdata); + } else { + netif_carrier_off(dev); + } + break; + case NL80211_IFTYPE_MONITOR: + if (sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES) { + local->cooked_mntrs++; + break; + } + + if (sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE) { + res = drv_add_interface(local, sdata); + if (res) + goto err_stop; + } else if (local->monitors == 0 && local->open_count == 0) { + res = ieee80211_add_virtual_monitor(local); + if (res) + goto err_stop; + } + + /* must be before the call to ieee80211_configure_filter */ + local->monitors++; + if (local->monitors == 1) { + local->hw.conf.flags |= IEEE80211_CONF_MONITOR; + hw_reconf_flags |= IEEE80211_CONF_CHANGE_MONITOR; + } + + ieee80211_adjust_monitor_flags(sdata, 1); + ieee80211_configure_filter(local); + ieee80211_recalc_offload(local); + mutex_lock(&local->mtx); + ieee80211_recalc_idle(local); + mutex_unlock(&local->mtx); + + netif_carrier_on(dev); + break; + default: + if (coming_up) { + ieee80211_del_virtual_monitor(local); + ieee80211_set_sdata_offload_flags(sdata); + + res = drv_add_interface(local, sdata); + if (res) + goto err_stop; + + ieee80211_set_vif_encap_ops(sdata); + res = ieee80211_check_queues(sdata, + ieee80211_vif_type_p2p(&sdata->vif)); + if (res) + goto err_del_interface; + } + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + local->fif_pspoll++; + local->fif_probe_req++; + + ieee80211_configure_filter(local); + } else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) { + local->fif_probe_req++; + } + + if (sdata->vif.probe_req_reg) + drv_config_iface_filter(local, sdata, + FIF_PROBE_REQ, + FIF_PROBE_REQ); + + if (sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && + sdata->vif.type != NL80211_IFTYPE_NAN) + changed |= ieee80211_reset_erp_info(sdata); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + changed); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_OCB: + netif_carrier_off(dev); + break; + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + break; + default: + /* not reached */ + WARN_ON(1); + } + + /* + * Set default queue parameters so drivers don't + * need to initialise the hardware if the hardware + * doesn't start up with sane defaults. + * Enable QoS for anything but station interfaces. + */ + ieee80211_set_wmm_default(&sdata->deflink, true, + sdata->vif.type != NL80211_IFTYPE_STATION); + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_P2P_DEVICE: + rcu_assign_pointer(local->p2p_sdata, sdata); + break; + case NL80211_IFTYPE_MONITOR: + if (sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES) + break; + list_add_tail_rcu(&sdata->u.mntr.list, &local->mon_list); + break; + default: + break; + } + + /* + * set_multicast_list will be invoked by the networking core + * which will check whether any increments here were done in + * error and sync them down to the hardware as filter flags. 
+ */ + if (sdata->flags & IEEE80211_SDATA_ALLMULTI) + atomic_inc(&local->iff_allmultis); + + if (coming_up) + local->open_count++; + + if (hw_reconf_flags) + ieee80211_hw_config(local, hw_reconf_flags); + + ieee80211_recalc_ps(local); + + if (sdata->vif.type == NL80211_IFTYPE_MONITOR || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + local->ops->wake_tx_queue) { + /* XXX: for AP_VLAN, actually track AP queues */ + if (dev) + netif_tx_start_all_queues(dev); + } else if (dev) { + unsigned long flags; + int n_acs = IEEE80211_NUM_ACS; + int ac; + + if (local->hw.queues < IEEE80211_NUM_ACS) + n_acs = 1; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + if (sdata->vif.cab_queue == IEEE80211_INVAL_HW_QUEUE || + (local->queue_stop_reasons[sdata->vif.cab_queue] == 0 && + skb_queue_empty(&local->pending[sdata->vif.cab_queue]))) { + for (ac = 0; ac < n_acs; ac++) { + int ac_queue = sdata->vif.hw_queue[ac]; + + if (local->queue_stop_reasons[ac_queue] == 0 && + skb_queue_empty(&local->pending[ac_queue])) + netif_start_subqueue(dev, ac); + } + } + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + } + + set_bit(SDATA_STATE_RUNNING, &sdata->state); + + return 0; + err_del_interface: + drv_remove_interface(local, sdata); + err_stop: + if (!local->open_count) + drv_stop(local); + err_del_bss: + sdata->bss = NULL; + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + mutex_lock(&local->mtx); + list_del(&sdata->u.vlan.list); + mutex_unlock(&local->mtx); + } + /* might already be clear but that doesn't matter */ + clear_bit(SDATA_STATE_RUNNING, &sdata->state); + return res; +} + +static void ieee80211_if_free(struct net_device *dev) +{ + free_percpu(dev->tstats); +} + +static void ieee80211_if_setup(struct net_device *dev) +{ + ether_setup(dev); + dev->priv_flags &= ~IFF_TX_SKB_SHARING; + dev->netdev_ops = &ieee80211_dataif_ops; + dev->needs_free_netdev = true; + dev->priv_destructor = ieee80211_if_free; +} + +static void ieee80211_if_setup_no_queue(struct net_device *dev) +{ + ieee80211_if_setup(dev); + dev->priv_flags |= IFF_NO_QUEUE; +} + +static void ieee80211_iface_process_skb(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (void *)skb->data; + + if (ieee80211_is_action(mgmt->frame_control) && + mgmt->u.action.category == WLAN_CATEGORY_BACK) { + struct sta_info *sta; + int len = skb->len; + + mutex_lock(&local->sta_mtx); + sta = sta_info_get_bss(sdata, mgmt->sa); + if (sta) { + switch (mgmt->u.action.u.addba_req.action_code) { + case WLAN_ACTION_ADDBA_REQ: + ieee80211_process_addba_request(local, sta, + mgmt, len); + break; + case WLAN_ACTION_ADDBA_RESP: + ieee80211_process_addba_resp(local, sta, + mgmt, len); + break; + case WLAN_ACTION_DELBA: + ieee80211_process_delba(sdata, sta, + mgmt, len); + break; + default: + WARN_ON(1); + break; + } + } + mutex_unlock(&local->sta_mtx); + } else if (ieee80211_is_action(mgmt->frame_control) && + mgmt->u.action.category == WLAN_CATEGORY_VHT) { + switch (mgmt->u.action.u.vht_group_notif.action_code) { + case WLAN_VHT_ACTION_OPMODE_NOTIF: { + struct ieee80211_rx_status *status; + enum nl80211_band band; + struct sta_info *sta; + u8 opmode; + + status = IEEE80211_SKB_RXCB(skb); + band = status->band; + opmode = mgmt->u.action.u.vht_opmode_notif.operating_mode; + + mutex_lock(&local->sta_mtx); + sta = sta_info_get_bss(sdata, mgmt->sa); + + if (sta) + ieee80211_vht_handle_opmode(sdata, + &sta->deflink, + opmode, band); + + mutex_unlock(&local->sta_mtx); + break; + 
} + case WLAN_VHT_ACTION_GROUPID_MGMT: + ieee80211_process_mu_groups(sdata, &sdata->deflink, + mgmt); + break; + default: + WARN_ON(1); + break; + } + } else if (ieee80211_is_action(mgmt->frame_control) && + mgmt->u.action.category == WLAN_CATEGORY_S1G) { + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_TEARDOWN: + case WLAN_S1G_TWT_SETUP: + ieee80211_s1g_rx_twt_action(sdata, skb); + break; + default: + break; + } + } else if (ieee80211_is_ext(mgmt->frame_control)) { + if (sdata->vif.type == NL80211_IFTYPE_STATION) + ieee80211_sta_rx_queued_ext(sdata, skb); + else + WARN_ON(1); + } else if (ieee80211_is_data_qos(mgmt->frame_control)) { + struct ieee80211_hdr *hdr = (void *)mgmt; + struct sta_info *sta; + + /* + * So the frame isn't mgmt, but frame_control + * is at the right place anyway, of course, so + * the if statement is correct. + * + * Warn if we have other data frame types here, + * they must not get here. + */ + WARN_ON(hdr->frame_control & + cpu_to_le16(IEEE80211_STYPE_NULLFUNC)); + WARN_ON(!(hdr->seq_ctrl & + cpu_to_le16(IEEE80211_SCTL_FRAG))); + /* + * This was a fragment of a frame, received while + * a block-ack session was active. That cannot be + * right, so terminate the session. + */ + mutex_lock(&local->sta_mtx); + sta = sta_info_get_bss(sdata, mgmt->sa); + if (sta) { + u16 tid = ieee80211_get_tid(hdr); + + __ieee80211_stop_rx_ba_session( + sta, tid, WLAN_BACK_RECIPIENT, + WLAN_REASON_QSTA_REQUIRE_SETUP, + true); + } + mutex_unlock(&local->sta_mtx); + } else switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + ieee80211_sta_rx_queued_mgmt(sdata, skb); + break; + case NL80211_IFTYPE_ADHOC: + ieee80211_ibss_rx_queued_mgmt(sdata, skb); + break; + case NL80211_IFTYPE_MESH_POINT: + if (!ieee80211_vif_is_mesh(&sdata->vif)) + break; + ieee80211_mesh_rx_queued_mgmt(sdata, skb); + break; + default: + WARN(1, "frame for unexpected interface type"); + break; + } +} + +static void ieee80211_iface_process_status(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (void *)skb->data; + + if (ieee80211_is_action(mgmt->frame_control) && + mgmt->u.action.category == WLAN_CATEGORY_S1G) { + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_TEARDOWN: + case WLAN_S1G_TWT_SETUP: + ieee80211_s1g_status_twt_action(sdata, skb); + break; + default: + break; + } + } +} + +static void ieee80211_iface_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, work); + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + + if (!ieee80211_sdata_running(sdata)) + return; + + if (test_bit(SCAN_SW_SCANNING, &local->scanning)) + return; + + if (!ieee80211_can_run_worker(local)) + return; + + /* first process frames */ + while ((skb = skb_dequeue(&sdata->skb_queue))) { + kcov_remote_start_common(skb_get_kcov_handle(skb)); + + if (skb->protocol == cpu_to_be16(ETH_P_TDLS)) + ieee80211_process_tdls_channel_switch(sdata, skb); + else + ieee80211_iface_process_skb(local, sdata, skb); + + kfree_skb(skb); + kcov_remote_stop(); + } + + /* process status queue */ + while ((skb = skb_dequeue(&sdata->status_queue))) { + kcov_remote_start_common(skb_get_kcov_handle(skb)); + + ieee80211_iface_process_status(sdata, skb); + kfree_skb(skb); + + kcov_remote_stop(); + } + + /* then other type-dependent work */ + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + ieee80211_sta_work(sdata); + break; + case NL80211_IFTYPE_ADHOC: + 
ieee80211_ibss_work(sdata); + break; + case NL80211_IFTYPE_MESH_POINT: + if (!ieee80211_vif_is_mesh(&sdata->vif)) + break; + ieee80211_mesh_work(sdata); + break; + case NL80211_IFTYPE_OCB: + ieee80211_ocb_work(sdata); + break; + default: + break; + } +} + +static void ieee80211_recalc_smps_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, recalc_smps); + + ieee80211_recalc_smps(sdata, &sdata->deflink); +} + +static void ieee80211_activate_links_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + activate_links_work); + + ieee80211_set_active_links(&sdata->vif, sdata->desired_active_links); +} + +/* + * Helper function to initialise an interface to a specific type. + */ +static void ieee80211_setup_sdata(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type) +{ + static const u8 bssid_wildcard[ETH_ALEN] = {0xff, 0xff, 0xff, + 0xff, 0xff, 0xff}; + + /* clear type-dependent unions */ + memset(&sdata->u, 0, sizeof(sdata->u)); + memset(&sdata->deflink.u, 0, sizeof(sdata->deflink.u)); + + /* and set some type-dependent values */ + sdata->vif.type = type; + sdata->vif.p2p = false; + sdata->wdev.iftype = type; + + sdata->control_port_protocol = cpu_to_be16(ETH_P_PAE); + sdata->control_port_no_encrypt = false; + sdata->control_port_over_nl80211 = false; + sdata->control_port_no_preauth = false; + sdata->vif.cfg.idle = true; + sdata->vif.bss_conf.txpower = INT_MIN; /* unset */ + + sdata->noack_map = 0; + + /* only monitor/p2p-device differ */ + if (sdata->dev) { + sdata->dev->netdev_ops = &ieee80211_dataif_ops; + sdata->dev->type = ARPHRD_ETHER; + } + + skb_queue_head_init(&sdata->skb_queue); + skb_queue_head_init(&sdata->status_queue); + INIT_WORK(&sdata->work, ieee80211_iface_work); + INIT_WORK(&sdata->recalc_smps, ieee80211_recalc_smps_work); + INIT_WORK(&sdata->activate_links_work, ieee80211_activate_links_work); + + switch (type) { + case NL80211_IFTYPE_P2P_GO: + type = NL80211_IFTYPE_AP; + sdata->vif.type = type; + sdata->vif.p2p = true; + fallthrough; + case NL80211_IFTYPE_AP: + skb_queue_head_init(&sdata->u.ap.ps.bc_buf); + INIT_LIST_HEAD(&sdata->u.ap.vlans); + sdata->vif.bss_conf.bssid = sdata->vif.addr; + break; + case NL80211_IFTYPE_P2P_CLIENT: + type = NL80211_IFTYPE_STATION; + sdata->vif.type = type; + sdata->vif.p2p = true; + fallthrough; + case NL80211_IFTYPE_STATION: + sdata->vif.bss_conf.bssid = sdata->deflink.u.mgd.bssid; + ieee80211_sta_setup_sdata(sdata); + break; + case NL80211_IFTYPE_OCB: + sdata->vif.bss_conf.bssid = bssid_wildcard; + ieee80211_ocb_setup_sdata(sdata); + break; + case NL80211_IFTYPE_ADHOC: + sdata->vif.bss_conf.bssid = sdata->u.ibss.bssid; + ieee80211_ibss_setup_sdata(sdata); + break; + case NL80211_IFTYPE_MESH_POINT: + if (ieee80211_vif_is_mesh(&sdata->vif)) + ieee80211_mesh_init_sdata(sdata); + break; + case NL80211_IFTYPE_MONITOR: + sdata->dev->type = ARPHRD_IEEE80211_RADIOTAP; + sdata->dev->netdev_ops = &ieee80211_monitorif_ops; + sdata->u.mntr.flags = MONITOR_FLAG_CONTROL | + MONITOR_FLAG_OTHER_BSS; + break; + case NL80211_IFTYPE_NAN: + idr_init(&sdata->u.nan.function_inst_ids); + spin_lock_init(&sdata->u.nan.func_lock); + sdata->vif.bss_conf.bssid = sdata->vif.addr; + break; + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_DEVICE: + sdata->vif.bss_conf.bssid = sdata->vif.addr; + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NL80211_IFTYPE_WDS: + case NUM_NL80211_IFTYPES: + WARN_ON(1); + 
break; + } + + /* need to do this after the switch so vif.type is correct */ + ieee80211_link_setup(&sdata->deflink); + + ieee80211_debugfs_add_netdev(sdata); +} + +static int ieee80211_runtime_change_iftype(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type) +{ + struct ieee80211_local *local = sdata->local; + int ret, err; + enum nl80211_iftype internal_type = type; + bool p2p = false; + + ASSERT_RTNL(); + + if (!local->ops->change_interface) + return -EBUSY; + + /* for now, don't support changing while links exist */ + if (sdata->vif.valid_links) + return -EBUSY; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + if (!list_empty(&sdata->u.ap.vlans)) + return -EBUSY; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_OCB: + /* + * Could maybe also all others here? + * Just not sure how that interacts + * with the RX/config path e.g. for + * mesh. + */ + break; + default: + return -EBUSY; + } + + switch (type) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_OCB: + /* + * Could probably support everything + * but here. + */ + break; + case NL80211_IFTYPE_P2P_CLIENT: + p2p = true; + internal_type = NL80211_IFTYPE_STATION; + break; + case NL80211_IFTYPE_P2P_GO: + p2p = true; + internal_type = NL80211_IFTYPE_AP; + break; + default: + return -EBUSY; + } + + ret = ieee80211_check_concurrent_iface(sdata, internal_type); + if (ret) + return ret; + + ieee80211_stop_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_IFTYPE_CHANGE); + synchronize_net(); + + ieee80211_do_stop(sdata, false); + + ieee80211_teardown_sdata(sdata); + + ieee80211_set_sdata_offload_flags(sdata); + ret = drv_change_interface(local, sdata, internal_type, p2p); + if (ret) + type = ieee80211_vif_type_p2p(&sdata->vif); + + /* + * Ignore return value here, there's not much we can do since + * the driver changed the interface type internally already. + * The warnings will hopefully make driver authors fix it :-) + */ + ieee80211_check_queues(sdata, type); + + ieee80211_setup_sdata(sdata, type); + ieee80211_set_vif_encap_ops(sdata); + + err = ieee80211_do_open(&sdata->wdev, false); + WARN(err, "type change: do_open returned %d", err); + + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_IFTYPE_CHANGE); + return ret; +} + +int ieee80211_if_change_type(struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type) +{ + int ret; + + ASSERT_RTNL(); + + if (type == ieee80211_vif_type_p2p(&sdata->vif)) + return 0; + + if (ieee80211_sdata_running(sdata)) { + ret = ieee80211_runtime_change_iftype(sdata, type); + if (ret) + return ret; + } else { + /* Purge and reset type-dependent state. */ + ieee80211_teardown_sdata(sdata); + ieee80211_setup_sdata(sdata, type); + } + + /* reset some values that shouldn't be kept across type changes */ + if (type == NL80211_IFTYPE_STATION) + sdata->u.mgd.use_4addr = false; + + return 0; +} + +static void ieee80211_assign_perm_addr(struct ieee80211_local *local, + u8 *perm_addr, enum nl80211_iftype type) +{ + struct ieee80211_sub_if_data *sdata; + u64 mask, start, addr, val, inc; + u8 *m; + u8 tmp_addr[ETH_ALEN]; + int i; + + /* default ... 
something at least */ + memcpy(perm_addr, local->hw.wiphy->perm_addr, ETH_ALEN); + + if (is_zero_ether_addr(local->hw.wiphy->addr_mask) && + local->hw.wiphy->n_addresses <= 1) + return; + + mutex_lock(&local->iflist_mtx); + + switch (type) { + case NL80211_IFTYPE_MONITOR: + /* doesn't matter */ + break; + case NL80211_IFTYPE_AP_VLAN: + /* match up with an AP interface */ + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type != NL80211_IFTYPE_AP) + continue; + memcpy(perm_addr, sdata->vif.addr, ETH_ALEN); + break; + } + /* keep default if no AP interface present */ + break; + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + if (ieee80211_hw_check(&local->hw, P2P_DEV_ADDR_FOR_INTF)) { + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE) + continue; + if (!ieee80211_sdata_running(sdata)) + continue; + memcpy(perm_addr, sdata->vif.addr, ETH_ALEN); + goto out_unlock; + } + } + fallthrough; + default: + /* assign a new address if possible -- try n_addresses first */ + for (i = 0; i < local->hw.wiphy->n_addresses; i++) { + bool used = false; + + list_for_each_entry(sdata, &local->interfaces, list) { + if (ether_addr_equal(local->hw.wiphy->addresses[i].addr, + sdata->vif.addr)) { + used = true; + break; + } + } + + if (!used) { + memcpy(perm_addr, + local->hw.wiphy->addresses[i].addr, + ETH_ALEN); + break; + } + } + + /* try mask if available */ + if (is_zero_ether_addr(local->hw.wiphy->addr_mask)) + break; + + m = local->hw.wiphy->addr_mask; + mask = ((u64)m[0] << 5*8) | ((u64)m[1] << 4*8) | + ((u64)m[2] << 3*8) | ((u64)m[3] << 2*8) | + ((u64)m[4] << 1*8) | ((u64)m[5] << 0*8); + + if (__ffs64(mask) + hweight64(mask) != fls64(mask)) { + /* not a contiguous mask ... not handled now! */ + pr_info("not contiguous\n"); + break; + } + + /* + * Pick address of existing interface in case user changed + * MAC address manually, default to perm_addr. 
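+ * (The loop below then walks the addresses that addr_mask allows to
+ * vary, starting from that base, until it finds one that no existing
+ * interface is using.)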
+ */ + m = local->hw.wiphy->perm_addr; + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) + continue; + m = sdata->vif.addr; + break; + } + start = ((u64)m[0] << 5*8) | ((u64)m[1] << 4*8) | + ((u64)m[2] << 3*8) | ((u64)m[3] << 2*8) | + ((u64)m[4] << 1*8) | ((u64)m[5] << 0*8); + + inc = 1ULL<<__ffs64(mask); + val = (start & mask); + addr = (start & ~mask) | (val & mask); + do { + bool used = false; + + tmp_addr[5] = addr >> 0*8; + tmp_addr[4] = addr >> 1*8; + tmp_addr[3] = addr >> 2*8; + tmp_addr[2] = addr >> 3*8; + tmp_addr[1] = addr >> 4*8; + tmp_addr[0] = addr >> 5*8; + + val += inc; + + list_for_each_entry(sdata, &local->interfaces, list) { + if (ether_addr_equal(tmp_addr, sdata->vif.addr)) { + used = true; + break; + } + } + + if (!used) { + memcpy(perm_addr, tmp_addr, ETH_ALEN); + break; + } + addr = (start & ~mask) | (val & mask); + } while (addr != start); + + break; + } + + out_unlock: + mutex_unlock(&local->iflist_mtx); +} + +int ieee80211_if_add(struct ieee80211_local *local, const char *name, + unsigned char name_assign_type, + struct wireless_dev **new_wdev, enum nl80211_iftype type, + struct vif_params *params) +{ + struct net_device *ndev = NULL; + struct ieee80211_sub_if_data *sdata = NULL; + struct txq_info *txqi; + void (*if_setup)(struct net_device *dev); + int ret, i; + int txqs = 1; + + ASSERT_RTNL(); + + if (type == NL80211_IFTYPE_P2P_DEVICE || type == NL80211_IFTYPE_NAN) { + struct wireless_dev *wdev; + + sdata = kzalloc(sizeof(*sdata) + local->hw.vif_data_size, + GFP_KERNEL); + if (!sdata) + return -ENOMEM; + wdev = &sdata->wdev; + + sdata->dev = NULL; + strscpy(sdata->name, name, IFNAMSIZ); + ieee80211_assign_perm_addr(local, wdev->address, type); + memcpy(sdata->vif.addr, wdev->address, ETH_ALEN); + ether_addr_copy(sdata->vif.bss_conf.addr, sdata->vif.addr); + } else { + int size = ALIGN(sizeof(*sdata) + local->hw.vif_data_size, + sizeof(void *)); + int txq_size = 0; + + if (local->ops->wake_tx_queue && + type != NL80211_IFTYPE_AP_VLAN && + (type != NL80211_IFTYPE_MONITOR || + (params->flags & MONITOR_FLAG_ACTIVE))) + txq_size += sizeof(struct txq_info) + + local->hw.txq_data_size; + + if (local->ops->wake_tx_queue) { + if_setup = ieee80211_if_setup_no_queue; + } else { + if_setup = ieee80211_if_setup; + if (local->hw.queues >= IEEE80211_NUM_ACS) + txqs = IEEE80211_NUM_ACS; + } + + ndev = alloc_netdev_mqs(size + txq_size, + name, name_assign_type, + if_setup, txqs, 1); + if (!ndev) + return -ENOMEM; + + if (!local->ops->wake_tx_queue && local->hw.wiphy->tx_queue_len) + ndev->tx_queue_len = local->hw.wiphy->tx_queue_len; + + dev_net_set(ndev, wiphy_net(local->hw.wiphy)); + + ndev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!ndev->tstats) { + free_netdev(ndev); + return -ENOMEM; + } + + ndev->needed_headroom = local->tx_headroom + + 4*6 /* four MAC addresses */ + + 2 + 2 + 2 + 2 /* ctl, dur, seq, qos */ + + 6 /* mesh */ + + 8 /* rfc1042/bridge tunnel */ + - ETH_HLEN /* ethernet hard_header_len */ + + IEEE80211_ENCRYPT_HEADROOM; + ndev->needed_tailroom = IEEE80211_ENCRYPT_TAILROOM; + + ret = dev_alloc_name(ndev, ndev->name); + if (ret < 0) { + ieee80211_if_free(ndev); + free_netdev(ndev); + return ret; + } + + ieee80211_assign_perm_addr(local, ndev->perm_addr, type); + if (is_valid_ether_addr(params->macaddr)) + eth_hw_addr_set(ndev, params->macaddr); + else + eth_hw_addr_set(ndev, ndev->perm_addr); + SET_NETDEV_DEV(ndev, wiphy_dev(local->hw.wiphy)); + + /* don't use IEEE80211_DEV_TO_SUB_IF 
-- it checks too much */ + sdata = netdev_priv(ndev); + ndev->ieee80211_ptr = &sdata->wdev; + memcpy(sdata->vif.addr, ndev->dev_addr, ETH_ALEN); + ether_addr_copy(sdata->vif.bss_conf.addr, sdata->vif.addr); + memcpy(sdata->name, ndev->name, IFNAMSIZ); + + if (txq_size) { + txqi = netdev_priv(ndev) + size; + ieee80211_txq_init(sdata, NULL, txqi, 0); + } + + sdata->dev = ndev; + } + + /* initialise type-independent data */ + sdata->wdev.wiphy = local->hw.wiphy; + + ieee80211_sdata_init(local, sdata); + + ieee80211_init_frag_cache(&sdata->frags); + + INIT_LIST_HEAD(&sdata->key_list); + + INIT_DELAYED_WORK(&sdata->dec_tailroom_needed_wk, + ieee80211_delayed_tailroom_dec); + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + struct ieee80211_supported_band *sband; + sband = local->hw.wiphy->bands[i]; + sdata->rc_rateidx_mask[i] = + sband ? (1 << sband->n_bitrates) - 1 : 0; + if (sband) { + __le16 cap; + u16 *vht_rate_mask; + + memcpy(sdata->rc_rateidx_mcs_mask[i], + sband->ht_cap.mcs.rx_mask, + sizeof(sdata->rc_rateidx_mcs_mask[i])); + + cap = sband->vht_cap.vht_mcs.rx_mcs_map; + vht_rate_mask = sdata->rc_rateidx_vht_mcs_mask[i]; + ieee80211_get_vht_mask_from_cap(cap, vht_rate_mask); + } else { + memset(sdata->rc_rateidx_mcs_mask[i], 0, + sizeof(sdata->rc_rateidx_mcs_mask[i])); + memset(sdata->rc_rateidx_vht_mcs_mask[i], 0, + sizeof(sdata->rc_rateidx_vht_mcs_mask[i])); + } + } + + ieee80211_set_default_queues(sdata); + + sdata->deflink.ap_power_level = IEEE80211_UNSET_POWER_LEVEL; + sdata->deflink.user_power_level = local->user_power_level; + + /* setup type-dependent data */ + ieee80211_setup_sdata(sdata, type); + + if (ndev) { + ndev->ieee80211_ptr->use_4addr = params->use_4addr; + if (type == NL80211_IFTYPE_STATION) + sdata->u.mgd.use_4addr = params->use_4addr; + + ndev->features |= local->hw.netdev_features; + ndev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + ndev->hw_features |= ndev->features & + MAC80211_SUPPORTED_FEATURES_TX; + + netdev_set_default_ethtool_ops(ndev, &ieee80211_ethtool_ops); + + /* MTU range is normally 256 - 2304, where the upper limit is + * the maximum MSDU size. Monitor interfaces send and receive + * MPDU and A-MSDU frames which may be much larger so we do + * not impose an upper limit in that case. 
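+ * (A max_mtu of zero means the networking core does not enforce any
+ * upper bound when the MTU is changed.)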
+ */ + ndev->min_mtu = 256; + if (type == NL80211_IFTYPE_MONITOR) + ndev->max_mtu = 0; + else + ndev->max_mtu = local->hw.max_mtu; + + ret = cfg80211_register_netdevice(ndev); + if (ret) { + free_netdev(ndev); + return ret; + } + } + + mutex_lock(&local->iflist_mtx); + list_add_tail_rcu(&sdata->list, &local->interfaces); + mutex_unlock(&local->iflist_mtx); + + if (new_wdev) + *new_wdev = &sdata->wdev; + + return 0; +} + +void ieee80211_if_remove(struct ieee80211_sub_if_data *sdata) +{ + ASSERT_RTNL(); + + mutex_lock(&sdata->local->iflist_mtx); + list_del_rcu(&sdata->list); + mutex_unlock(&sdata->local->iflist_mtx); + + if (sdata->vif.txq) + ieee80211_txq_purge(sdata->local, to_txq_info(sdata->vif.txq)); + + synchronize_rcu(); + + cfg80211_unregister_wdev(&sdata->wdev); + + if (!sdata->dev) { + ieee80211_teardown_sdata(sdata); + kfree(sdata); + } +} + +void ieee80211_sdata_stop(struct ieee80211_sub_if_data *sdata) +{ + if (WARN_ON_ONCE(!test_bit(SDATA_STATE_RUNNING, &sdata->state))) + return; + ieee80211_do_stop(sdata, true); +} + +void ieee80211_remove_interfaces(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata, *tmp; + LIST_HEAD(unreg_list); + LIST_HEAD(wdev_list); + + ASSERT_RTNL(); + + /* Before destroying the interfaces, make sure they're all stopped so + * that the hardware is stopped. Otherwise, the driver might still be + * iterating the interfaces during the shutdown, e.g. from a worker + * or from RX processing or similar, and if it does so (using atomic + * iteration) while we're manipulating the list, the iteration will + * crash. + * + * After this, the hardware should be stopped and the driver should + * have stopped all of its activities, so that we can do RCU-unaware + * manipulations of the interface list below. 
+ */ + cfg80211_shutdown_all_interfaces(local->hw.wiphy); + + WARN(local->open_count, "%s: open count remains %d\n", + wiphy_name(local->hw.wiphy), local->open_count); + + ieee80211_txq_teardown_flows(local); + + mutex_lock(&local->iflist_mtx); + list_for_each_entry_safe(sdata, tmp, &local->interfaces, list) { + list_del(&sdata->list); + + if (sdata->dev) + unregister_netdevice_queue(sdata->dev, &unreg_list); + else + list_add(&sdata->list, &wdev_list); + } + mutex_unlock(&local->iflist_mtx); + + unregister_netdevice_many(&unreg_list); + + wiphy_lock(local->hw.wiphy); + list_for_each_entry_safe(sdata, tmp, &wdev_list, list) { + list_del(&sdata->list); + cfg80211_unregister_wdev(&sdata->wdev); + kfree(sdata); + } + wiphy_unlock(local->hw.wiphy); +} + +static int netdev_notify(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct ieee80211_sub_if_data *sdata; + + if (state != NETDEV_CHANGENAME) + return NOTIFY_DONE; + + if (!dev->ieee80211_ptr || !dev->ieee80211_ptr->wiphy) + return NOTIFY_DONE; + + if (dev->ieee80211_ptr->wiphy->privid != mac80211_wiphy_privid) + return NOTIFY_DONE; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + memcpy(sdata->name, dev->name, IFNAMSIZ); + ieee80211_debugfs_rename_netdev(sdata); + + return NOTIFY_OK; +} + +static struct notifier_block mac80211_netdev_notifier = { + .notifier_call = netdev_notify, +}; + +int ieee80211_iface_init(void) +{ + return register_netdevice_notifier(&mac80211_netdev_notifier); +} + +void ieee80211_iface_exit(void) +{ + unregister_netdevice_notifier(&mac80211_netdev_notifier); +} + +void ieee80211_vif_inc_num_mcast(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.type == NL80211_IFTYPE_AP) + atomic_inc(&sdata->u.ap.num_mcast_sta); + else if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + atomic_inc(&sdata->u.vlan.num_mcast_sta); +} + +void ieee80211_vif_dec_num_mcast(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.type == NL80211_IFTYPE_AP) + atomic_dec(&sdata->u.ap.num_mcast_sta); + else if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + atomic_dec(&sdata->u.vlan.num_mcast_sta); +} diff --git a/net/mac80211/key.c b/net/mac80211/key.c new file mode 100644 index 000000000..23bb24243 --- /dev/null +++ b/net/mac80211/key.c @@ -0,0 +1,1481 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007-2008 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2015-2017 Intel Deutschland GmbH + * Copyright 2018-2020, 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "debugfs_key.h" +#include "aes_ccm.h" +#include "aes_cmac.h" +#include "aes_gmac.h" +#include "aes_gcm.h" + + +/** + * DOC: Key handling basics + * + * Key handling in mac80211 is done based on per-interface (sub_if_data) + * keys and per-station keys. Since each station belongs to an interface, + * each station key also belongs to that interface. + * + * Hardware acceleration is done on a best-effort basis for algorithms + * that are implemented in software, for each key the hardware is asked + * to enable that key for offloading but if it cannot do that the key is + * simply kept for software encryption (unless it is for an algorithm + * that isn't implemented in software). 
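+ * (See ieee80211_key_enable_hw_accel() below: if drv_set_key() fails for
+ * a cipher that mac80211 also implements in software, the key is simply
+ * left without KEY_FLAG_UPLOADED_TO_HARDWARE and the software crypto
+ * paths handle it, unless the driver uses SW_CRYPTO_CONTROL.)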
+ * There is currently no way of knowing whether a key is handled in SW + * or HW except by looking into debugfs. + * + * All key management is internally protected by a mutex. Within all + * other parts of mac80211, key references are, just as STA structure + * references, protected by RCU. Note, however, that some things are + * unprotected, namely the key->sta dereferences within the hardware + * acceleration functions. This means that sta_info_destroy() must + * remove the key which waits for an RCU grace period. + */ + +static const u8 bcast_addr[ETH_ALEN] = { 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF }; + +static void assert_key_lock(struct ieee80211_local *local) +{ + lockdep_assert_held(&local->key_mtx); +} + +static void +update_vlan_tailroom_need_count(struct ieee80211_sub_if_data *sdata, int delta) +{ + struct ieee80211_sub_if_data *vlan; + + if (sdata->vif.type != NL80211_IFTYPE_AP) + return; + + /* crypto_tx_tailroom_needed_cnt is protected by this */ + assert_key_lock(sdata->local); + + rcu_read_lock(); + + list_for_each_entry_rcu(vlan, &sdata->u.ap.vlans, u.vlan.list) + vlan->crypto_tx_tailroom_needed_cnt += delta; + + rcu_read_unlock(); +} + +static void increment_tailroom_need_count(struct ieee80211_sub_if_data *sdata) +{ + /* + * When this count is zero, SKB resizing for allocating tailroom + * for IV or MMIC is skipped. But, this check has created two race + * cases in xmit path while transiting from zero count to one: + * + * 1. SKB resize was skipped because no key was added but just before + * the xmit key is added and SW encryption kicks off. + * + * 2. SKB resize was skipped because all the keys were hw planted but + * just before xmit one of the key is deleted and SW encryption kicks + * off. + * + * In both the above case SW encryption will find not enough space for + * tailroom and exits with WARN_ON. (See WARN_ONs at wpa.c) + * + * Solution has been explained at + * http://mid.gmane.org/1308590980.4322.19.camel@jlt3.sipsolutions.net + */ + + assert_key_lock(sdata->local); + + update_vlan_tailroom_need_count(sdata, 1); + + if (!sdata->crypto_tx_tailroom_needed_cnt++) { + /* + * Flush all XMIT packets currently using HW encryption or no + * encryption at all if the count transition is from 0 -> 1. + */ + synchronize_net(); + } +} + +static void decrease_tailroom_need_count(struct ieee80211_sub_if_data *sdata, + int delta) +{ + assert_key_lock(sdata->local); + + WARN_ON_ONCE(sdata->crypto_tx_tailroom_needed_cnt < delta); + + update_vlan_tailroom_need_count(sdata, -delta); + sdata->crypto_tx_tailroom_needed_cnt -= delta; +} + +static int ieee80211_key_enable_hw_accel(struct ieee80211_key *key) +{ + struct ieee80211_sub_if_data *sdata = key->sdata; + struct sta_info *sta; + int ret = -EOPNOTSUPP; + + might_sleep(); + + if (key->flags & KEY_FLAG_TAINTED) { + /* If we get here, it's during resume and the key is + * tainted so shouldn't be used/programmed any more. + * However, its flags may still indicate that it was + * programmed into the device (since we're in resume) + * so clear that flag now to avoid trying to remove + * it again later. 
+ */ + if (key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE && + !(key->conf.flags & (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE | + IEEE80211_KEY_FLAG_RESERVE_TAILROOM))) + increment_tailroom_need_count(sdata); + + key->flags &= ~KEY_FLAG_UPLOADED_TO_HARDWARE; + return -EINVAL; + } + + if (!key->local->ops->set_key) + goto out_unsupported; + + assert_key_lock(key->local); + + sta = key->sta; + + /* + * If this is a per-STA GTK, check if it + * is supported; if not, return. + */ + if (sta && !(key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE) && + !ieee80211_hw_check(&key->local->hw, SUPPORTS_PER_STA_GTK)) + goto out_unsupported; + + if (sta && !sta->uploaded) + goto out_unsupported; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + /* + * The driver doesn't know anything about VLAN interfaces. + * Hence, don't send GTKs for VLAN interfaces to the driver. + */ + if (!(key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE)) { + ret = 1; + goto out_unsupported; + } + } + + if (key->conf.link_id >= 0 && sdata->vif.active_links && + !(sdata->vif.active_links & BIT(key->conf.link_id))) + return 0; + + ret = drv_set_key(key->local, SET_KEY, sdata, + sta ? &sta->sta : NULL, &key->conf); + + if (!ret) { + key->flags |= KEY_FLAG_UPLOADED_TO_HARDWARE; + + if (!(key->conf.flags & (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE | + IEEE80211_KEY_FLAG_RESERVE_TAILROOM))) + decrease_tailroom_need_count(sdata, 1); + + WARN_ON((key->conf.flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE) && + (key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV)); + + WARN_ON((key->conf.flags & IEEE80211_KEY_FLAG_PUT_MIC_SPACE) && + (key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_MMIC)); + + return 0; + } + + if (ret != -ENOSPC && ret != -EOPNOTSUPP && ret != 1) + sdata_err(sdata, + "failed to set key (%d, %pM) to hardware (%d)\n", + key->conf.keyidx, + sta ? sta->sta.addr : bcast_addr, ret); + + out_unsupported: + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + case WLAN_CIPHER_SUITE_TKIP: + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + /* all of these we can do in software - if driver can */ + if (ret == 1) + return 0; + if (ieee80211_hw_check(&key->local->hw, SW_CRYPTO_CONTROL)) + return -EINVAL; + return 0; + default: + return -EINVAL; + } +} + +static void ieee80211_key_disable_hw_accel(struct ieee80211_key *key) +{ + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + int ret; + + might_sleep(); + + if (!key || !key->local->ops->set_key) + return; + + assert_key_lock(key->local); + + if (!(key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE)) + return; + + sta = key->sta; + sdata = key->sdata; + + if (key->conf.link_id >= 0 && sdata->vif.active_links && + !(sdata->vif.active_links & BIT(key->conf.link_id))) + return; + + if (!(key->conf.flags & (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE | + IEEE80211_KEY_FLAG_RESERVE_TAILROOM))) + increment_tailroom_need_count(sdata); + + key->flags &= ~KEY_FLAG_UPLOADED_TO_HARDWARE; + ret = drv_set_key(key->local, DISABLE_KEY, sdata, + sta ? &sta->sta : NULL, &key->conf); + + if (ret) + sdata_err(sdata, + "failed to remove key (%d, %pM) from hardware (%d)\n", + key->conf.keyidx, + sta ? 
sta->sta.addr : bcast_addr, ret); +} + +static int _ieee80211_set_tx_key(struct ieee80211_key *key, bool force) +{ + struct sta_info *sta = key->sta; + struct ieee80211_local *local = key->local; + + assert_key_lock(local); + + set_sta_flag(sta, WLAN_STA_USES_ENCRYPTION); + + sta->ptk_idx = key->conf.keyidx; + + if (force || !ieee80211_hw_check(&local->hw, AMPDU_KEYBORDER_SUPPORT)) + clear_sta_flag(sta, WLAN_STA_BLOCK_BA); + ieee80211_check_fast_xmit(sta); + + return 0; +} + +int ieee80211_set_tx_key(struct ieee80211_key *key) +{ + return _ieee80211_set_tx_key(key, false); +} + +static void ieee80211_pairwise_rekey(struct ieee80211_key *old, + struct ieee80211_key *new) +{ + struct ieee80211_local *local = new->local; + struct sta_info *sta = new->sta; + int i; + + assert_key_lock(local); + + if (new->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX) { + /* Extended Key ID key install, initial one or rekey */ + + if (sta->ptk_idx != INVALID_PTK_KEYIDX && + !ieee80211_hw_check(&local->hw, AMPDU_KEYBORDER_SUPPORT)) { + /* Aggregation Sessions with Extended Key ID must not + * mix MPDUs with different keyIDs within one A-MPDU. + * Tear down running Tx aggregation sessions and block + * new Rx/Tx aggregation requests during rekey to + * ensure there are no A-MPDUs when the driver is not + * supporting A-MPDU key borders. (Blocking Tx only + * would be sufficient but WLAN_STA_BLOCK_BA gets the + * job done for the few ms we need it.) + */ + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + mutex_lock(&sta->ampdu_mlme.mtx); + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + ___ieee80211_stop_tx_ba_session(sta, i, + AGG_STOP_LOCAL_REQUEST); + mutex_unlock(&sta->ampdu_mlme.mtx); + } + } else if (old) { + /* Rekey without Extended Key ID. + * Aggregation sessions are OK when running on SW crypto. + * A broken remote STA may cause issues not observed with HW + * crypto, though. + */ + if (!(old->flags & KEY_FLAG_UPLOADED_TO_HARDWARE)) + return; + + /* Stop Tx till we are on the new key */ + old->flags |= KEY_FLAG_TAINTED; + ieee80211_clear_fast_xmit(sta); + if (ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION)) { + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + ieee80211_sta_tear_down_BA_sessions(sta, + AGG_STOP_LOCAL_REQUEST); + } + if (!wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_CAN_REPLACE_PTK0)) { + pr_warn_ratelimited("Rekeying PTK for STA %pM but driver can't safely do that.", + sta->sta.addr); + /* Flushing the driver queues *may* help prevent + * the clear text leaks and freezes. 
+ */ + ieee80211_flush_queues(local, old->sdata, false); + } + } +} + +static void __ieee80211_set_default_key(struct ieee80211_link_data *link, + int idx, bool uni, bool multi) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_key *key = NULL; + + assert_key_lock(sdata->local); + + if (idx >= 0 && idx < NUM_DEFAULT_KEYS) { + key = key_mtx_dereference(sdata->local, sdata->keys[idx]); + if (!key) + key = key_mtx_dereference(sdata->local, link->gtk[idx]); + } + + if (uni) { + rcu_assign_pointer(sdata->default_unicast_key, key); + ieee80211_check_fast_xmit_iface(sdata); + if (sdata->vif.type != NL80211_IFTYPE_AP_VLAN) + drv_set_default_unicast_key(sdata->local, sdata, idx); + } + + if (multi) + rcu_assign_pointer(link->default_multicast_key, key); + + ieee80211_debugfs_key_update_default(sdata); +} + +void ieee80211_set_default_key(struct ieee80211_link_data *link, int idx, + bool uni, bool multi) +{ + mutex_lock(&link->sdata->local->key_mtx); + __ieee80211_set_default_key(link, idx, uni, multi); + mutex_unlock(&link->sdata->local->key_mtx); +} + +static void +__ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link, int idx) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_key *key = NULL; + + assert_key_lock(sdata->local); + + if (idx >= NUM_DEFAULT_KEYS && + idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS) + key = key_mtx_dereference(sdata->local, link->gtk[idx]); + + rcu_assign_pointer(link->default_mgmt_key, key); + + ieee80211_debugfs_key_update_default(sdata); +} + +void ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link, + int idx) +{ + mutex_lock(&link->sdata->local->key_mtx); + __ieee80211_set_default_mgmt_key(link, idx); + mutex_unlock(&link->sdata->local->key_mtx); +} + +static void +__ieee80211_set_default_beacon_key(struct ieee80211_link_data *link, int idx) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_key *key = NULL; + + assert_key_lock(sdata->local); + + if (idx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS && + idx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS) + key = key_mtx_dereference(sdata->local, link->gtk[idx]); + + rcu_assign_pointer(link->default_beacon_key, key); + + ieee80211_debugfs_key_update_default(sdata); +} + +void ieee80211_set_default_beacon_key(struct ieee80211_link_data *link, + int idx) +{ + mutex_lock(&link->sdata->local->key_mtx); + __ieee80211_set_default_beacon_key(link, idx); + mutex_unlock(&link->sdata->local->key_mtx); +} + +static int ieee80211_key_replace(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct sta_info *sta, + bool pairwise, + struct ieee80211_key *old, + struct ieee80211_key *new) +{ + struct link_sta_info *link_sta = sta ? 
&sta->deflink : NULL; + int link_id; + int idx; + int ret = 0; + bool defunikey, defmultikey, defmgmtkey, defbeaconkey; + bool is_wep; + + /* caller must provide at least one old/new */ + if (WARN_ON(!new && !old)) + return 0; + + if (new) { + idx = new->conf.keyidx; + is_wep = new->conf.cipher == WLAN_CIPHER_SUITE_WEP40 || + new->conf.cipher == WLAN_CIPHER_SUITE_WEP104; + link_id = new->conf.link_id; + } else { + idx = old->conf.keyidx; + is_wep = old->conf.cipher == WLAN_CIPHER_SUITE_WEP40 || + old->conf.cipher == WLAN_CIPHER_SUITE_WEP104; + link_id = old->conf.link_id; + } + + if (WARN(old && old->conf.link_id != link_id, + "old link ID %d doesn't match new link ID %d\n", + old->conf.link_id, link_id)) + return -EINVAL; + + if (link_id >= 0) { + if (!link) { + link = sdata_dereference(sdata->link[link_id], sdata); + if (!link) + return -ENOLINK; + } + + if (sta) { + link_sta = rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&sta->local->sta_mtx)); + if (!link_sta) + return -ENOLINK; + } + } else { + link = &sdata->deflink; + } + + if ((is_wep || pairwise) && idx >= NUM_DEFAULT_KEYS) + return -EINVAL; + + WARN_ON(new && old && new->conf.keyidx != old->conf.keyidx); + + if (new && sta && pairwise) { + /* Unicast rekey needs special handling. With Extended Key ID + * old is still NULL for the first rekey. + */ + ieee80211_pairwise_rekey(old, new); + } + + if (old) { + if (old->flags & KEY_FLAG_UPLOADED_TO_HARDWARE) { + ieee80211_key_disable_hw_accel(old); + + if (new) + ret = ieee80211_key_enable_hw_accel(new); + } + } else { + if (!new->local->wowlan) + ret = ieee80211_key_enable_hw_accel(new); + } + + if (ret) + return ret; + + if (new) + list_add_tail_rcu(&new->list, &sdata->key_list); + + if (sta) { + if (pairwise) { + rcu_assign_pointer(sta->ptk[idx], new); + if (new && + !(new->conf.flags & IEEE80211_KEY_FLAG_NO_AUTO_TX)) + _ieee80211_set_tx_key(new, true); + } else { + rcu_assign_pointer(link_sta->gtk[idx], new); + } + /* Only needed for transition from no key -> key. + * Still triggers unnecessary when using Extended Key ID + * and installing the second key ID the first time. 
+ */ + if (new && !old) + ieee80211_check_fast_rx(sta); + } else { + defunikey = old && + old == key_mtx_dereference(sdata->local, + sdata->default_unicast_key); + defmultikey = old && + old == key_mtx_dereference(sdata->local, + link->default_multicast_key); + defmgmtkey = old && + old == key_mtx_dereference(sdata->local, + link->default_mgmt_key); + defbeaconkey = old && + old == key_mtx_dereference(sdata->local, + link->default_beacon_key); + + if (defunikey && !new) + __ieee80211_set_default_key(link, -1, true, false); + if (defmultikey && !new) + __ieee80211_set_default_key(link, -1, false, true); + if (defmgmtkey && !new) + __ieee80211_set_default_mgmt_key(link, -1); + if (defbeaconkey && !new) + __ieee80211_set_default_beacon_key(link, -1); + + if (is_wep || pairwise) + rcu_assign_pointer(sdata->keys[idx], new); + else + rcu_assign_pointer(link->gtk[idx], new); + + if (defunikey && new) + __ieee80211_set_default_key(link, new->conf.keyidx, + true, false); + if (defmultikey && new) + __ieee80211_set_default_key(link, new->conf.keyidx, + false, true); + if (defmgmtkey && new) + __ieee80211_set_default_mgmt_key(link, + new->conf.keyidx); + if (defbeaconkey && new) + __ieee80211_set_default_beacon_key(link, + new->conf.keyidx); + } + + if (old) + list_del_rcu(&old->list); + + return 0; +} + +struct ieee80211_key * +ieee80211_key_alloc(u32 cipher, int idx, size_t key_len, + const u8 *key_data, + size_t seq_len, const u8 *seq) +{ + struct ieee80211_key *key; + int i, j, err; + + if (WARN_ON(idx < 0 || + idx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS)) + return ERR_PTR(-EINVAL); + + key = kzalloc(sizeof(struct ieee80211_key) + key_len, GFP_KERNEL); + if (!key) + return ERR_PTR(-ENOMEM); + + /* + * Default to software encryption; we'll later upload the + * key to the hardware if possible. + */ + key->conf.flags = 0; + key->flags = 0; + + key->conf.link_id = -1; + key->conf.cipher = cipher; + key->conf.keyidx = idx; + key->conf.keylen = key_len; + switch (cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + key->conf.iv_len = IEEE80211_WEP_IV_LEN; + key->conf.icv_len = IEEE80211_WEP_ICV_LEN; + break; + case WLAN_CIPHER_SUITE_TKIP: + key->conf.iv_len = IEEE80211_TKIP_IV_LEN; + key->conf.icv_len = IEEE80211_TKIP_ICV_LEN; + if (seq) { + for (i = 0; i < IEEE80211_NUM_TIDS; i++) { + key->u.tkip.rx[i].iv32 = + get_unaligned_le32(&seq[2]); + key->u.tkip.rx[i].iv16 = + get_unaligned_le16(seq); + } + } + spin_lock_init(&key->u.tkip.txlock); + break; + case WLAN_CIPHER_SUITE_CCMP: + key->conf.iv_len = IEEE80211_CCMP_HDR_LEN; + key->conf.icv_len = IEEE80211_CCMP_MIC_LEN; + if (seq) { + for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++) + for (j = 0; j < IEEE80211_CCMP_PN_LEN; j++) + key->u.ccmp.rx_pn[i][j] = + seq[IEEE80211_CCMP_PN_LEN - j - 1]; + } + /* + * Initialize AES key state here as an optimization so that + * it does not need to be initialized for every packet. 
+ */ + key->u.ccmp.tfm = ieee80211_aes_key_setup_encrypt( + key_data, key_len, IEEE80211_CCMP_MIC_LEN); + if (IS_ERR(key->u.ccmp.tfm)) { + err = PTR_ERR(key->u.ccmp.tfm); + kfree(key); + return ERR_PTR(err); + } + break; + case WLAN_CIPHER_SUITE_CCMP_256: + key->conf.iv_len = IEEE80211_CCMP_256_HDR_LEN; + key->conf.icv_len = IEEE80211_CCMP_256_MIC_LEN; + for (i = 0; seq && i < IEEE80211_NUM_TIDS + 1; i++) + for (j = 0; j < IEEE80211_CCMP_256_PN_LEN; j++) + key->u.ccmp.rx_pn[i][j] = + seq[IEEE80211_CCMP_256_PN_LEN - j - 1]; + /* Initialize AES key state here as an optimization so that + * it does not need to be initialized for every packet. + */ + key->u.ccmp.tfm = ieee80211_aes_key_setup_encrypt( + key_data, key_len, IEEE80211_CCMP_256_MIC_LEN); + if (IS_ERR(key->u.ccmp.tfm)) { + err = PTR_ERR(key->u.ccmp.tfm); + kfree(key); + return ERR_PTR(err); + } + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + key->conf.iv_len = 0; + if (cipher == WLAN_CIPHER_SUITE_AES_CMAC) + key->conf.icv_len = sizeof(struct ieee80211_mmie); + else + key->conf.icv_len = sizeof(struct ieee80211_mmie_16); + if (seq) + for (j = 0; j < IEEE80211_CMAC_PN_LEN; j++) + key->u.aes_cmac.rx_pn[j] = + seq[IEEE80211_CMAC_PN_LEN - j - 1]; + /* + * Initialize AES key state here as an optimization so that + * it does not need to be initialized for every packet. + */ + key->u.aes_cmac.tfm = + ieee80211_aes_cmac_key_setup(key_data, key_len); + if (IS_ERR(key->u.aes_cmac.tfm)) { + err = PTR_ERR(key->u.aes_cmac.tfm); + kfree(key); + return ERR_PTR(err); + } + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + key->conf.iv_len = 0; + key->conf.icv_len = sizeof(struct ieee80211_mmie_16); + if (seq) + for (j = 0; j < IEEE80211_GMAC_PN_LEN; j++) + key->u.aes_gmac.rx_pn[j] = + seq[IEEE80211_GMAC_PN_LEN - j - 1]; + /* Initialize AES key state here as an optimization so that + * it does not need to be initialized for every packet. + */ + key->u.aes_gmac.tfm = + ieee80211_aes_gmac_key_setup(key_data, key_len); + if (IS_ERR(key->u.aes_gmac.tfm)) { + err = PTR_ERR(key->u.aes_gmac.tfm); + kfree(key); + return ERR_PTR(err); + } + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + key->conf.iv_len = IEEE80211_GCMP_HDR_LEN; + key->conf.icv_len = IEEE80211_GCMP_MIC_LEN; + for (i = 0; seq && i < IEEE80211_NUM_TIDS + 1; i++) + for (j = 0; j < IEEE80211_GCMP_PN_LEN; j++) + key->u.gcmp.rx_pn[i][j] = + seq[IEEE80211_GCMP_PN_LEN - j - 1]; + /* Initialize AES key state here as an optimization so that + * it does not need to be initialized for every packet. 
+ */ + key->u.gcmp.tfm = ieee80211_aes_gcm_key_setup_encrypt(key_data, + key_len); + if (IS_ERR(key->u.gcmp.tfm)) { + err = PTR_ERR(key->u.gcmp.tfm); + kfree(key); + return ERR_PTR(err); + } + break; + } + memcpy(key->conf.key, key_data, key_len); + INIT_LIST_HEAD(&key->list); + + return key; +} + +static void ieee80211_key_free_common(struct ieee80211_key *key) +{ + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + ieee80211_aes_key_free(key->u.ccmp.tfm); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + ieee80211_aes_cmac_key_free(key->u.aes_cmac.tfm); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + ieee80211_aes_gmac_key_free(key->u.aes_gmac.tfm); + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + ieee80211_aes_gcm_key_free(key->u.gcmp.tfm); + break; + } + kfree_sensitive(key); +} + +static void __ieee80211_key_destroy(struct ieee80211_key *key, + bool delay_tailroom) +{ + if (key->local) { + struct ieee80211_sub_if_data *sdata = key->sdata; + + ieee80211_debugfs_key_remove(key); + + if (delay_tailroom) { + /* see ieee80211_delayed_tailroom_dec */ + sdata->crypto_tx_tailroom_pending_dec++; + schedule_delayed_work(&sdata->dec_tailroom_needed_wk, + HZ/2); + } else { + decrease_tailroom_need_count(sdata, 1); + } + } + + ieee80211_key_free_common(key); +} + +static void ieee80211_key_destroy(struct ieee80211_key *key, + bool delay_tailroom) +{ + if (!key) + return; + + /* + * Synchronize so the TX path and rcu key iterators + * can no longer be using this key before we free/remove it. + */ + synchronize_net(); + + __ieee80211_key_destroy(key, delay_tailroom); +} + +void ieee80211_key_free_unused(struct ieee80211_key *key) +{ + WARN_ON(key->sdata || key->local); + ieee80211_key_free_common(key); +} + +static bool ieee80211_key_identical(struct ieee80211_sub_if_data *sdata, + struct ieee80211_key *old, + struct ieee80211_key *new) +{ + u8 tkip_old[WLAN_KEY_LEN_TKIP], tkip_new[WLAN_KEY_LEN_TKIP]; + u8 *tk_old, *tk_new; + + if (!old || new->conf.keylen != old->conf.keylen) + return false; + + tk_old = old->conf.key; + tk_new = new->conf.key; + + /* + * In station mode, don't compare the TX MIC key, as it's never used + * and offloaded rekeying may not care to send it to the host. This + * is the case in iwlwifi, for example. + */ + if (sdata->vif.type == NL80211_IFTYPE_STATION && + new->conf.cipher == WLAN_CIPHER_SUITE_TKIP && + new->conf.keylen == WLAN_KEY_LEN_TKIP && + !(new->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE)) { + memcpy(tkip_old, tk_old, WLAN_KEY_LEN_TKIP); + memcpy(tkip_new, tk_new, WLAN_KEY_LEN_TKIP); + memset(tkip_old + NL80211_TKIP_DATA_OFFSET_TX_MIC_KEY, 0, 8); + memset(tkip_new + NL80211_TKIP_DATA_OFFSET_TX_MIC_KEY, 0, 8); + tk_old = tkip_old; + tk_new = tkip_new; + } + + return !crypto_memneq(tk_old, tk_new, new->conf.keylen); +} + +int ieee80211_key_link(struct ieee80211_key *key, + struct ieee80211_link_data *link, + struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + static atomic_t key_color = ATOMIC_INIT(0); + struct ieee80211_key *old_key = NULL; + int idx = key->conf.keyidx; + bool pairwise = key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE; + /* + * We want to delay tailroom updates only for station - in that + * case it helps roaming speed, but in other cases it hurts and + * can cause warnings to appear. 
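+ * (The delayed decrement is performed by ieee80211_delayed_tailroom_dec(),
+ * scheduled with an HZ/2 delay from __ieee80211_key_destroy().)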
+ */ + bool delay_tailroom = sdata->vif.type == NL80211_IFTYPE_STATION; + int ret = -EOPNOTSUPP; + + mutex_lock(&sdata->local->key_mtx); + + if (sta && pairwise) { + struct ieee80211_key *alt_key; + + old_key = key_mtx_dereference(sdata->local, sta->ptk[idx]); + alt_key = key_mtx_dereference(sdata->local, sta->ptk[idx ^ 1]); + + /* The rekey code assumes that the old and new key are using + * the same cipher. Enforce the assumption for pairwise keys. + */ + if ((alt_key && alt_key->conf.cipher != key->conf.cipher) || + (old_key && old_key->conf.cipher != key->conf.cipher)) + goto out; + } else if (sta) { + struct link_sta_info *link_sta = &sta->deflink; + int link_id = key->conf.link_id; + + if (link_id >= 0) { + link_sta = rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&sta->local->sta_mtx)); + if (!link_sta) { + ret = -ENOLINK; + goto out; + } + } + + old_key = key_mtx_dereference(sdata->local, link_sta->gtk[idx]); + } else { + if (idx < NUM_DEFAULT_KEYS) + old_key = key_mtx_dereference(sdata->local, + sdata->keys[idx]); + if (!old_key) + old_key = key_mtx_dereference(sdata->local, + link->gtk[idx]); + } + + /* Non-pairwise keys must also not switch the cipher on rekey */ + if (!pairwise) { + if (old_key && old_key->conf.cipher != key->conf.cipher) + goto out; + } + + /* + * Silently accept key re-installation without really installing the + * new version of the key to avoid nonce reuse or replay issues. + */ + if (ieee80211_key_identical(sdata, old_key, key)) { + ieee80211_key_free_unused(key); + ret = -EALREADY; + goto out; + } + + key->local = sdata->local; + key->sdata = sdata; + key->sta = sta; + + /* + * Assign a unique ID to every key so we can easily prevent mixed + * key and fragment cache attacks. + */ + key->color = atomic_inc_return(&key_color); + + increment_tailroom_need_count(sdata); + + ret = ieee80211_key_replace(sdata, link, sta, pairwise, old_key, key); + + if (!ret) { + ieee80211_debugfs_key_add(key); + ieee80211_key_destroy(old_key, delay_tailroom); + } else { + ieee80211_key_free(key, delay_tailroom); + } + + out: + mutex_unlock(&sdata->local->key_mtx); + + return ret; +} + +void ieee80211_key_free(struct ieee80211_key *key, bool delay_tailroom) +{ + if (!key) + return; + + /* + * Replace key with nothingness if it was ever used. 
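+	 * (key->sdata is only set once the key was linked via
+	 * ieee80211_key_link(), so a key that was never installed skips
+	 * the replace below and is simply destroyed.)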
+ */ + if (key->sdata) + ieee80211_key_replace(key->sdata, NULL, key->sta, + key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE, + key, NULL); + ieee80211_key_destroy(key, delay_tailroom); +} + +void ieee80211_reenable_keys(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_key *key; + struct ieee80211_sub_if_data *vlan; + + lockdep_assert_wiphy(sdata->local->hw.wiphy); + + mutex_lock(&sdata->local->key_mtx); + + sdata->crypto_tx_tailroom_needed_cnt = 0; + sdata->crypto_tx_tailroom_pending_dec = 0; + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) { + vlan->crypto_tx_tailroom_needed_cnt = 0; + vlan->crypto_tx_tailroom_pending_dec = 0; + } + } + + if (ieee80211_sdata_running(sdata)) { + list_for_each_entry(key, &sdata->key_list, list) { + increment_tailroom_need_count(sdata); + ieee80211_key_enable_hw_accel(key); + } + } + + mutex_unlock(&sdata->local->key_mtx); +} + +void ieee80211_iter_keys(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + void (*iter)(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key, + void *data), + void *iter_data) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_key *key, *tmp; + struct ieee80211_sub_if_data *sdata; + + lockdep_assert_wiphy(hw->wiphy); + + mutex_lock(&local->key_mtx); + if (vif) { + sdata = vif_to_sdata(vif); + list_for_each_entry_safe(key, tmp, &sdata->key_list, list) + iter(hw, &sdata->vif, + key->sta ? &key->sta->sta : NULL, + &key->conf, iter_data); + } else { + list_for_each_entry(sdata, &local->interfaces, list) + list_for_each_entry_safe(key, tmp, + &sdata->key_list, list) + iter(hw, &sdata->vif, + key->sta ? &key->sta->sta : NULL, + &key->conf, iter_data); + } + mutex_unlock(&local->key_mtx); +} +EXPORT_SYMBOL(ieee80211_iter_keys); + +static void +_ieee80211_iter_keys_rcu(struct ieee80211_hw *hw, + struct ieee80211_sub_if_data *sdata, + void (*iter)(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key, + void *data), + void *iter_data) +{ + struct ieee80211_key *key; + + list_for_each_entry_rcu(key, &sdata->key_list, list) { + /* skip keys of station in removal process */ + if (key->sta && key->sta->removed) + continue; + if (!(key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE)) + continue; + + iter(hw, &sdata->vif, + key->sta ? 
&key->sta->sta : NULL, + &key->conf, iter_data); + } +} + +void ieee80211_iter_keys_rcu(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + void (*iter)(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key, + void *data), + void *iter_data) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata; + + if (vif) { + sdata = vif_to_sdata(vif); + _ieee80211_iter_keys_rcu(hw, sdata, iter, iter_data); + } else { + list_for_each_entry_rcu(sdata, &local->interfaces, list) + _ieee80211_iter_keys_rcu(hw, sdata, iter, iter_data); + } +} +EXPORT_SYMBOL(ieee80211_iter_keys_rcu); + +static void ieee80211_free_keys_iface(struct ieee80211_sub_if_data *sdata, + struct list_head *keys) +{ + struct ieee80211_key *key, *tmp; + + decrease_tailroom_need_count(sdata, + sdata->crypto_tx_tailroom_pending_dec); + sdata->crypto_tx_tailroom_pending_dec = 0; + + ieee80211_debugfs_key_remove_mgmt_default(sdata); + ieee80211_debugfs_key_remove_beacon_default(sdata); + + list_for_each_entry_safe(key, tmp, &sdata->key_list, list) { + ieee80211_key_replace(key->sdata, NULL, key->sta, + key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE, + key, NULL); + list_add_tail(&key->list, keys); + } + + ieee80211_debugfs_key_update_default(sdata); +} + +void ieee80211_remove_link_keys(struct ieee80211_link_data *link, + struct list_head *keys) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_key *key, *tmp; + + mutex_lock(&local->key_mtx); + list_for_each_entry_safe(key, tmp, &sdata->key_list, list) { + if (key->conf.link_id != link->link_id) + continue; + ieee80211_key_replace(key->sdata, link, key->sta, + key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE, + key, NULL); + list_add_tail(&key->list, keys); + } + mutex_unlock(&local->key_mtx); +} + +void ieee80211_free_key_list(struct ieee80211_local *local, + struct list_head *keys) +{ + struct ieee80211_key *key, *tmp; + + mutex_lock(&local->key_mtx); + list_for_each_entry_safe(key, tmp, keys, list) + __ieee80211_key_destroy(key, false); + mutex_unlock(&local->key_mtx); +} + +void ieee80211_free_keys(struct ieee80211_sub_if_data *sdata, + bool force_synchronize) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *vlan; + struct ieee80211_sub_if_data *master; + struct ieee80211_key *key, *tmp; + LIST_HEAD(keys); + + cancel_delayed_work_sync(&sdata->dec_tailroom_needed_wk); + + mutex_lock(&local->key_mtx); + + ieee80211_free_keys_iface(sdata, &keys); + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) + ieee80211_free_keys_iface(vlan, &keys); + } + + if (!list_empty(&keys) || force_synchronize) + synchronize_net(); + list_for_each_entry_safe(key, tmp, &keys, list) + __ieee80211_key_destroy(key, false); + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + if (sdata->bss) { + master = container_of(sdata->bss, + struct ieee80211_sub_if_data, + u.ap); + + WARN_ON_ONCE(sdata->crypto_tx_tailroom_needed_cnt != + master->crypto_tx_tailroom_needed_cnt); + } + } else { + WARN_ON_ONCE(sdata->crypto_tx_tailroom_needed_cnt || + sdata->crypto_tx_tailroom_pending_dec); + } + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + list_for_each_entry(vlan, &sdata->u.ap.vlans, u.vlan.list) + WARN_ON_ONCE(vlan->crypto_tx_tailroom_needed_cnt || + vlan->crypto_tx_tailroom_pending_dec); + } + + mutex_unlock(&local->key_mtx); +} + +void 
ieee80211_free_sta_keys(struct ieee80211_local *local, + struct sta_info *sta) +{ + struct ieee80211_key *key; + int i; + + mutex_lock(&local->key_mtx); + for (i = 0; i < ARRAY_SIZE(sta->deflink.gtk); i++) { + key = key_mtx_dereference(local, sta->deflink.gtk[i]); + if (!key) + continue; + ieee80211_key_replace(key->sdata, NULL, key->sta, + key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE, + key, NULL); + __ieee80211_key_destroy(key, key->sdata->vif.type == + NL80211_IFTYPE_STATION); + } + + for (i = 0; i < NUM_DEFAULT_KEYS; i++) { + key = key_mtx_dereference(local, sta->ptk[i]); + if (!key) + continue; + ieee80211_key_replace(key->sdata, NULL, key->sta, + key->conf.flags & IEEE80211_KEY_FLAG_PAIRWISE, + key, NULL); + __ieee80211_key_destroy(key, key->sdata->vif.type == + NL80211_IFTYPE_STATION); + } + + mutex_unlock(&local->key_mtx); +} + +void ieee80211_delayed_tailroom_dec(struct work_struct *wk) +{ + struct ieee80211_sub_if_data *sdata; + + sdata = container_of(wk, struct ieee80211_sub_if_data, + dec_tailroom_needed_wk.work); + + /* + * The reason for the delayed tailroom needed decrementing is to + * make roaming faster: during roaming, all keys are first deleted + * and then new keys are installed. The first new key causes the + * crypto_tx_tailroom_needed_cnt to go from 0 to 1, which invokes + * the cost of synchronize_net() (which can be slow). Avoid this + * by deferring the crypto_tx_tailroom_needed_cnt decrementing on + * key removal for a while, so if we roam the value is larger than + * zero and no 0->1 transition happens. + * + * The cost is that if the AP switching was from an AP with keys + * to one without, we still allocate tailroom while it would no + * longer be needed. However, in the typical (fast) roaming case + * within an ESS this usually won't happen. 
+ */ + + mutex_lock(&sdata->local->key_mtx); + decrease_tailroom_need_count(sdata, + sdata->crypto_tx_tailroom_pending_dec); + sdata->crypto_tx_tailroom_pending_dec = 0; + mutex_unlock(&sdata->local->key_mtx); +} + +void ieee80211_gtk_rekey_notify(struct ieee80211_vif *vif, const u8 *bssid, + const u8 *replay_ctr, gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + trace_api_gtk_rekey_notify(sdata, bssid, replay_ctr); + + cfg80211_gtk_rekey_notify(sdata->dev, bssid, replay_ctr, gfp); +} +EXPORT_SYMBOL_GPL(ieee80211_gtk_rekey_notify); + +void ieee80211_get_key_rx_seq(struct ieee80211_key_conf *keyconf, + int tid, struct ieee80211_key_seq *seq) +{ + struct ieee80211_key *key; + const u8 *pn; + + key = container_of(keyconf, struct ieee80211_key, conf); + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_TKIP: + if (WARN_ON(tid < 0 || tid >= IEEE80211_NUM_TIDS)) + return; + seq->tkip.iv32 = key->u.tkip.rx[tid].iv32; + seq->tkip.iv16 = key->u.tkip.rx[tid].iv16; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + if (WARN_ON(tid < -1 || tid >= IEEE80211_NUM_TIDS)) + return; + if (tid < 0) + pn = key->u.ccmp.rx_pn[IEEE80211_NUM_TIDS]; + else + pn = key->u.ccmp.rx_pn[tid]; + memcpy(seq->ccmp.pn, pn, IEEE80211_CCMP_PN_LEN); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + if (WARN_ON(tid != 0)) + return; + pn = key->u.aes_cmac.rx_pn; + memcpy(seq->aes_cmac.pn, pn, IEEE80211_CMAC_PN_LEN); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + if (WARN_ON(tid != 0)) + return; + pn = key->u.aes_gmac.rx_pn; + memcpy(seq->aes_gmac.pn, pn, IEEE80211_GMAC_PN_LEN); + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + if (WARN_ON(tid < -1 || tid >= IEEE80211_NUM_TIDS)) + return; + if (tid < 0) + pn = key->u.gcmp.rx_pn[IEEE80211_NUM_TIDS]; + else + pn = key->u.gcmp.rx_pn[tid]; + memcpy(seq->gcmp.pn, pn, IEEE80211_GCMP_PN_LEN); + break; + } +} +EXPORT_SYMBOL(ieee80211_get_key_rx_seq); + +void ieee80211_set_key_rx_seq(struct ieee80211_key_conf *keyconf, + int tid, struct ieee80211_key_seq *seq) +{ + struct ieee80211_key *key; + u8 *pn; + + key = container_of(keyconf, struct ieee80211_key, conf); + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_TKIP: + if (WARN_ON(tid < 0 || tid >= IEEE80211_NUM_TIDS)) + return; + key->u.tkip.rx[tid].iv32 = seq->tkip.iv32; + key->u.tkip.rx[tid].iv16 = seq->tkip.iv16; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + if (WARN_ON(tid < -1 || tid >= IEEE80211_NUM_TIDS)) + return; + if (tid < 0) + pn = key->u.ccmp.rx_pn[IEEE80211_NUM_TIDS]; + else + pn = key->u.ccmp.rx_pn[tid]; + memcpy(pn, seq->ccmp.pn, IEEE80211_CCMP_PN_LEN); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + if (WARN_ON(tid != 0)) + return; + pn = key->u.aes_cmac.rx_pn; + memcpy(pn, seq->aes_cmac.pn, IEEE80211_CMAC_PN_LEN); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + if (WARN_ON(tid != 0)) + return; + pn = key->u.aes_gmac.rx_pn; + memcpy(pn, seq->aes_gmac.pn, IEEE80211_GMAC_PN_LEN); + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + if (WARN_ON(tid < -1 || tid >= IEEE80211_NUM_TIDS)) + return; + if (tid < 0) + pn = key->u.gcmp.rx_pn[IEEE80211_NUM_TIDS]; + else + pn = key->u.gcmp.rx_pn[tid]; + memcpy(pn, seq->gcmp.pn, IEEE80211_GCMP_PN_LEN); + break; + default: + WARN_ON(1); + break; + } +} 
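+/*
+ * Note for both the RX-seq get and set helpers above: a negative tid
+ * selects the counter used for robust management frames, which is kept
+ * at index IEEE80211_NUM_TIDS of the per-TID rx_pn array.
+ */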
+EXPORT_SYMBOL_GPL(ieee80211_set_key_rx_seq); + +void ieee80211_remove_key(struct ieee80211_key_conf *keyconf) +{ + struct ieee80211_key *key; + + key = container_of(keyconf, struct ieee80211_key, conf); + + assert_key_lock(key->local); + + /* + * if key was uploaded, we assume the driver will/has remove(d) + * it, so adjust bookkeeping accordingly + */ + if (key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE) { + key->flags &= ~KEY_FLAG_UPLOADED_TO_HARDWARE; + + if (!(key->conf.flags & (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE | + IEEE80211_KEY_FLAG_RESERVE_TAILROOM))) + increment_tailroom_need_count(key->sdata); + } + + ieee80211_key_free(key, false); +} +EXPORT_SYMBOL_GPL(ieee80211_remove_key); + +struct ieee80211_key_conf * +ieee80211_gtk_rekey_add(struct ieee80211_vif *vif, + struct ieee80211_key_conf *keyconf) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct ieee80211_key *key; + int err; + + if (WARN_ON(!local->wowlan)) + return ERR_PTR(-EINVAL); + + if (WARN_ON(vif->type != NL80211_IFTYPE_STATION)) + return ERR_PTR(-EINVAL); + + key = ieee80211_key_alloc(keyconf->cipher, keyconf->keyidx, + keyconf->keylen, keyconf->key, + 0, NULL); + if (IS_ERR(key)) + return ERR_CAST(key); + + if (sdata->u.mgd.mfp != IEEE80211_MFP_DISABLED) + key->conf.flags |= IEEE80211_KEY_FLAG_RX_MGMT; + + /* FIXME: this function needs to get a link ID */ + err = ieee80211_key_link(key, &sdata->deflink, NULL); + if (err) + return ERR_PTR(err); + + return &key->conf; +} +EXPORT_SYMBOL_GPL(ieee80211_gtk_rekey_add); + +void ieee80211_key_mic_failure(struct ieee80211_key_conf *keyconf) +{ + struct ieee80211_key *key; + + key = container_of(keyconf, struct ieee80211_key, conf); + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + key->u.aes_cmac.icverrors++; + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + key->u.aes_gmac.icverrors++; + break; + default: + /* ignore the others for now, we don't keep counters now */ + break; + } +} +EXPORT_SYMBOL_GPL(ieee80211_key_mic_failure); + +void ieee80211_key_replay(struct ieee80211_key_conf *keyconf) +{ + struct ieee80211_key *key; + + key = container_of(keyconf, struct ieee80211_key, conf); + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + key->u.ccmp.replays++; + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + key->u.aes_cmac.replays++; + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + key->u.aes_gmac.replays++; + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + key->u.gcmp.replays++; + break; + } +} +EXPORT_SYMBOL_GPL(ieee80211_key_replay); + +int ieee80211_key_switch_links(struct ieee80211_sub_if_data *sdata, + unsigned long del_links_mask, + unsigned long add_links_mask) +{ + struct ieee80211_key *key; + int ret; + + list_for_each_entry(key, &sdata->key_list, list) { + if (key->conf.link_id < 0 || + !(del_links_mask & BIT(key->conf.link_id))) + continue; + + /* shouldn't happen for per-link keys */ + WARN_ON(key->sta); + + ieee80211_key_disable_hw_accel(key); + } + + list_for_each_entry(key, &sdata->key_list, list) { + if (key->conf.link_id < 0 || + !(add_links_mask & BIT(key->conf.link_id))) + continue; + + /* shouldn't happen for per-link keys */ + WARN_ON(key->sta); + + ret = ieee80211_key_enable_hw_accel(key); + if (ret) + return ret; + 
} + + return 0; +} diff --git a/net/mac80211/key.h b/net/mac80211/key.h new file mode 100644 index 000000000..f3df97df4 --- /dev/null +++ b/net/mac80211/key.h @@ -0,0 +1,179 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. + * Copyright (C) 2019, 2022 Intel Corporation + */ + +#ifndef IEEE80211_KEY_H +#define IEEE80211_KEY_H + +#include +#include +#include +#include +#include +#include + +#define NUM_DEFAULT_KEYS 4 +#define NUM_DEFAULT_MGMT_KEYS 2 +#define NUM_DEFAULT_BEACON_KEYS 2 +#define INVALID_PTK_KEYIDX 2 /* Keyidx always pointing to a NULL key for PTK */ + +struct ieee80211_local; +struct ieee80211_sub_if_data; +struct ieee80211_link_data; +struct sta_info; + +/** + * enum ieee80211_internal_key_flags - internal key flags + * + * @KEY_FLAG_UPLOADED_TO_HARDWARE: Indicates that this key is present + * in the hardware for TX crypto hardware acceleration. + * @KEY_FLAG_TAINTED: Key is tainted and packets should be dropped. + */ +enum ieee80211_internal_key_flags { + KEY_FLAG_UPLOADED_TO_HARDWARE = BIT(0), + KEY_FLAG_TAINTED = BIT(1), +}; + +enum ieee80211_internal_tkip_state { + TKIP_STATE_NOT_INIT, + TKIP_STATE_PHASE1_DONE, + TKIP_STATE_PHASE1_HW_UPLOADED, +}; + +struct tkip_ctx { + u16 p1k[5]; /* p1k cache */ + u32 p1k_iv32; /* iv32 for which p1k computed */ + enum ieee80211_internal_tkip_state state; +}; + +struct tkip_ctx_rx { + struct tkip_ctx ctx; + u32 iv32; /* current iv32 */ + u16 iv16; /* current iv16 */ +}; + +struct ieee80211_key { + struct ieee80211_local *local; + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + + /* for sdata list */ + struct list_head list; + + /* protected by key mutex */ + unsigned int flags; + + union { + struct { + /* protects tx context */ + spinlock_t txlock; + + /* last used TSC */ + struct tkip_ctx tx; + + /* last received RSC */ + struct tkip_ctx_rx rx[IEEE80211_NUM_TIDS]; + + /* number of mic failures */ + u32 mic_failures; + } tkip; + struct { + /* + * Last received packet number. The first + * IEEE80211_NUM_TIDS counters are used with Data + * frames and the last counter is used with Robust + * Management frames. + */ + u8 rx_pn[IEEE80211_NUM_TIDS + 1][IEEE80211_CCMP_PN_LEN]; + struct crypto_aead *tfm; + u32 replays; /* dot11RSNAStatsCCMPReplays */ + } ccmp; + struct { + u8 rx_pn[IEEE80211_CMAC_PN_LEN]; + struct crypto_shash *tfm; + u32 replays; /* dot11RSNAStatsCMACReplays */ + u32 icverrors; /* dot11RSNAStatsCMACICVErrors */ + } aes_cmac; + struct { + u8 rx_pn[IEEE80211_GMAC_PN_LEN]; + struct crypto_aead *tfm; + u32 replays; /* dot11RSNAStatsCMACReplays */ + u32 icverrors; /* dot11RSNAStatsCMACICVErrors */ + } aes_gmac; + struct { + /* Last received packet number. The first + * IEEE80211_NUM_TIDS counters are used with Data + * frames and the last counter is used with Robust + * Management frames. 
+ */ + u8 rx_pn[IEEE80211_NUM_TIDS + 1][IEEE80211_GCMP_PN_LEN]; + struct crypto_aead *tfm; + u32 replays; /* dot11RSNAStatsGCMPReplays */ + } gcmp; + struct { + /* generic cipher scheme */ + u8 rx_pn[IEEE80211_NUM_TIDS + 1][IEEE80211_MAX_PN_LEN]; + } gen; + } u; + +#ifdef CONFIG_MAC80211_DEBUGFS + struct { + struct dentry *stalink; + struct dentry *dir; + int cnt; + } debugfs; +#endif + + unsigned int color; + + /* + * key config, must be last because it contains key + * material as variable length member + */ + struct ieee80211_key_conf conf; +}; + +struct ieee80211_key * +ieee80211_key_alloc(u32 cipher, int idx, size_t key_len, + const u8 *key_data, + size_t seq_len, const u8 *seq); +/* + * Insert a key into data structures (sdata, sta if necessary) + * to make it used, free old key. On failure, also free the new key. + */ +int ieee80211_key_link(struct ieee80211_key *key, + struct ieee80211_link_data *link, + struct sta_info *sta); +int ieee80211_set_tx_key(struct ieee80211_key *key); +void ieee80211_key_free(struct ieee80211_key *key, bool delay_tailroom); +void ieee80211_key_free_unused(struct ieee80211_key *key); +void ieee80211_set_default_key(struct ieee80211_link_data *link, int idx, + bool uni, bool multi); +void ieee80211_set_default_mgmt_key(struct ieee80211_link_data *link, + int idx); +void ieee80211_set_default_beacon_key(struct ieee80211_link_data *link, + int idx); +void ieee80211_remove_link_keys(struct ieee80211_link_data *link, + struct list_head *keys); +void ieee80211_free_key_list(struct ieee80211_local *local, + struct list_head *keys); +void ieee80211_free_keys(struct ieee80211_sub_if_data *sdata, + bool force_synchronize); +void ieee80211_free_sta_keys(struct ieee80211_local *local, + struct sta_info *sta); +void ieee80211_reenable_keys(struct ieee80211_sub_if_data *sdata); +int ieee80211_key_switch_links(struct ieee80211_sub_if_data *sdata, + unsigned long del_links_mask, + unsigned long add_links_mask); + +#define key_mtx_dereference(local, ref) \ + rcu_dereference_protected(ref, lockdep_is_held(&((local)->key_mtx))) +#define rcu_dereference_check_key_mtx(local, ref) \ + rcu_dereference_check(ref, lockdep_is_held(&((local)->key_mtx))) + +void ieee80211_delayed_tailroom_dec(struct work_struct *wk); + +#endif /* IEEE80211_KEY_H */ diff --git a/net/mac80211/led.c b/net/mac80211/led.c new file mode 100644 index 000000000..6de8d0ad5 --- /dev/null +++ b/net/mac80211/led.c @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2006, Johannes Berg + */ + +/* just for IFNAMSIZ */ +#include +#include +#include +#include "led.h" + +void ieee80211_led_assoc(struct ieee80211_local *local, bool associated) +{ + if (!atomic_read(&local->assoc_led_active)) + return; + if (associated) + led_trigger_event(&local->assoc_led, LED_FULL); + else + led_trigger_event(&local->assoc_led, LED_OFF); +} + +void ieee80211_led_radio(struct ieee80211_local *local, bool enabled) +{ + if (!atomic_read(&local->radio_led_active)) + return; + if (enabled) + led_trigger_event(&local->radio_led, LED_FULL); + else + led_trigger_event(&local->radio_led, LED_OFF); +} + +void ieee80211_alloc_led_names(struct ieee80211_local *local) +{ + local->rx_led.name = kasprintf(GFP_KERNEL, "%srx", + wiphy_name(local->hw.wiphy)); + local->tx_led.name = kasprintf(GFP_KERNEL, "%stx", + wiphy_name(local->hw.wiphy)); + local->assoc_led.name = kasprintf(GFP_KERNEL, "%sassoc", + wiphy_name(local->hw.wiphy)); + local->radio_led.name = kasprintf(GFP_KERNEL, "%sradio", + wiphy_name(local->hw.wiphy)); 
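+	/*
+	 * kasprintf() may fail; a NULL name is harmless here because
+	 * ieee80211_led_init() only registers triggers whose name was
+	 * successfully allocated.
+	 */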
+} + +void ieee80211_free_led_names(struct ieee80211_local *local) +{ + kfree(local->rx_led.name); + kfree(local->tx_led.name); + kfree(local->assoc_led.name); + kfree(local->radio_led.name); +} + +static int ieee80211_tx_led_activate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + tx_led); + + atomic_inc(&local->tx_led_active); + + return 0; +} + +static void ieee80211_tx_led_deactivate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + tx_led); + + atomic_dec(&local->tx_led_active); +} + +static int ieee80211_rx_led_activate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + rx_led); + + atomic_inc(&local->rx_led_active); + + return 0; +} + +static void ieee80211_rx_led_deactivate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + rx_led); + + atomic_dec(&local->rx_led_active); +} + +static int ieee80211_assoc_led_activate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + assoc_led); + + atomic_inc(&local->assoc_led_active); + + return 0; +} + +static void ieee80211_assoc_led_deactivate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + assoc_led); + + atomic_dec(&local->assoc_led_active); +} + +static int ieee80211_radio_led_activate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + radio_led); + + atomic_inc(&local->radio_led_active); + + return 0; +} + +static void ieee80211_radio_led_deactivate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + radio_led); + + atomic_dec(&local->radio_led_active); +} + +static int ieee80211_tpt_led_activate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + tpt_led); + + atomic_inc(&local->tpt_led_active); + + return 0; +} + +static void ieee80211_tpt_led_deactivate(struct led_classdev *led_cdev) +{ + struct ieee80211_local *local = container_of(led_cdev->trigger, + struct ieee80211_local, + tpt_led); + + atomic_dec(&local->tpt_led_active); +} + +void ieee80211_led_init(struct ieee80211_local *local) +{ + atomic_set(&local->rx_led_active, 0); + local->rx_led.activate = ieee80211_rx_led_activate; + local->rx_led.deactivate = ieee80211_rx_led_deactivate; + if (local->rx_led.name && led_trigger_register(&local->rx_led)) { + kfree(local->rx_led.name); + local->rx_led.name = NULL; + } + + atomic_set(&local->tx_led_active, 0); + local->tx_led.activate = ieee80211_tx_led_activate; + local->tx_led.deactivate = ieee80211_tx_led_deactivate; + if (local->tx_led.name && led_trigger_register(&local->tx_led)) { + kfree(local->tx_led.name); + local->tx_led.name = NULL; + } + + atomic_set(&local->assoc_led_active, 0); + local->assoc_led.activate = ieee80211_assoc_led_activate; + local->assoc_led.deactivate = ieee80211_assoc_led_deactivate; + if (local->assoc_led.name && led_trigger_register(&local->assoc_led)) { + kfree(local->assoc_led.name); + local->assoc_led.name = NULL; + } + + atomic_set(&local->radio_led_active, 0); + local->radio_led.activate = ieee80211_radio_led_activate; 
+ local->radio_led.deactivate = ieee80211_radio_led_deactivate; + if (local->radio_led.name && led_trigger_register(&local->radio_led)) { + kfree(local->radio_led.name); + local->radio_led.name = NULL; + } + + atomic_set(&local->tpt_led_active, 0); + if (local->tpt_led_trigger) { + local->tpt_led.activate = ieee80211_tpt_led_activate; + local->tpt_led.deactivate = ieee80211_tpt_led_deactivate; + if (led_trigger_register(&local->tpt_led)) { + kfree(local->tpt_led_trigger); + local->tpt_led_trigger = NULL; + } + } +} + +void ieee80211_led_exit(struct ieee80211_local *local) +{ + if (local->radio_led.name) + led_trigger_unregister(&local->radio_led); + if (local->assoc_led.name) + led_trigger_unregister(&local->assoc_led); + if (local->tx_led.name) + led_trigger_unregister(&local->tx_led); + if (local->rx_led.name) + led_trigger_unregister(&local->rx_led); + + if (local->tpt_led_trigger) { + led_trigger_unregister(&local->tpt_led); + kfree(local->tpt_led_trigger); + } +} + +const char *__ieee80211_get_radio_led_name(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + return local->radio_led.name; +} +EXPORT_SYMBOL(__ieee80211_get_radio_led_name); + +const char *__ieee80211_get_assoc_led_name(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + return local->assoc_led.name; +} +EXPORT_SYMBOL(__ieee80211_get_assoc_led_name); + +const char *__ieee80211_get_tx_led_name(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + return local->tx_led.name; +} +EXPORT_SYMBOL(__ieee80211_get_tx_led_name); + +const char *__ieee80211_get_rx_led_name(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + return local->rx_led.name; +} +EXPORT_SYMBOL(__ieee80211_get_rx_led_name); + +static unsigned long tpt_trig_traffic(struct ieee80211_local *local, + struct tpt_led_trigger *tpt_trig) +{ + unsigned long traffic, delta; + + traffic = tpt_trig->tx_bytes + tpt_trig->rx_bytes; + + delta = traffic - tpt_trig->prev_traffic; + tpt_trig->prev_traffic = traffic; + return DIV_ROUND_UP(delta, 1024 / 8); +} + +static void tpt_trig_timer(struct timer_list *t) +{ + struct tpt_led_trigger *tpt_trig = from_timer(tpt_trig, t, timer); + struct ieee80211_local *local = tpt_trig->local; + unsigned long on, off, tpt; + int i; + + if (!tpt_trig->running) + return; + + mod_timer(&tpt_trig->timer, round_jiffies(jiffies + HZ)); + + tpt = tpt_trig_traffic(local, tpt_trig); + + /* default to just solid on */ + on = 1; + off = 0; + + for (i = tpt_trig->blink_table_len - 1; i >= 0; i--) { + if (tpt_trig->blink_table[i].throughput < 0 || + tpt > tpt_trig->blink_table[i].throughput) { + off = tpt_trig->blink_table[i].blink_time / 2; + on = tpt_trig->blink_table[i].blink_time - off; + break; + } + } + + led_trigger_blink(&local->tpt_led, &on, &off); +} + +const char * +__ieee80211_create_tpt_led_trigger(struct ieee80211_hw *hw, + unsigned int flags, + const struct ieee80211_tpt_blink *blink_table, + unsigned int blink_table_len) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct tpt_led_trigger *tpt_trig; + + if (WARN_ON(local->tpt_led_trigger)) + return NULL; + + tpt_trig = kzalloc(sizeof(struct tpt_led_trigger), GFP_KERNEL); + if (!tpt_trig) + return NULL; + + snprintf(tpt_trig->name, sizeof(tpt_trig->name), + "%stpt", wiphy_name(local->hw.wiphy)); + + local->tpt_led.name = tpt_trig->name; + + tpt_trig->blink_table = blink_table; + tpt_trig->blink_table_len = blink_table_len; + tpt_trig->want = flags; + 
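+	/* 'want' holds the IEEE80211_TPT_LEDTRIG_FL_* types the driver asked
+	 * to blink for; ieee80211_mod_tpt_led_trig() compares it against the
+	 * currently active types before running the trigger.
+	 */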
tpt_trig->local = local; + + timer_setup(&tpt_trig->timer, tpt_trig_timer, 0); + + local->tpt_led_trigger = tpt_trig; + + return tpt_trig->name; +} +EXPORT_SYMBOL(__ieee80211_create_tpt_led_trigger); + +static void ieee80211_start_tpt_led_trig(struct ieee80211_local *local) +{ + struct tpt_led_trigger *tpt_trig = local->tpt_led_trigger; + + if (tpt_trig->running) + return; + + /* reset traffic */ + tpt_trig_traffic(local, tpt_trig); + tpt_trig->running = true; + + tpt_trig_timer(&tpt_trig->timer); + mod_timer(&tpt_trig->timer, round_jiffies(jiffies + HZ)); +} + +static void ieee80211_stop_tpt_led_trig(struct ieee80211_local *local) +{ + struct tpt_led_trigger *tpt_trig = local->tpt_led_trigger; + + if (!tpt_trig->running) + return; + + tpt_trig->running = false; + del_timer_sync(&tpt_trig->timer); + + led_trigger_event(&local->tpt_led, LED_OFF); +} + +void ieee80211_mod_tpt_led_trig(struct ieee80211_local *local, + unsigned int types_on, unsigned int types_off) +{ + struct tpt_led_trigger *tpt_trig = local->tpt_led_trigger; + bool allowed; + + WARN_ON(types_on & types_off); + + if (!tpt_trig) + return; + + tpt_trig->active &= ~types_off; + tpt_trig->active |= types_on; + + /* + * Regardless of wanted state, we shouldn't blink when + * the radio is disabled -- this can happen due to some + * code ordering issues with __ieee80211_recalc_idle() + * being called before the radio is started. + */ + allowed = tpt_trig->active & IEEE80211_TPT_LEDTRIG_FL_RADIO; + + if (!allowed || !(tpt_trig->active & tpt_trig->want)) + ieee80211_stop_tpt_led_trig(local); + else + ieee80211_start_tpt_led_trig(local); +} diff --git a/net/mac80211/led.h b/net/mac80211/led.h new file mode 100644 index 000000000..b71a1428d --- /dev/null +++ b/net/mac80211/led.h @@ -0,0 +1,90 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2006, Johannes Berg + */ + +#include +#include +#include +#include "ieee80211_i.h" + +#define MAC80211_BLINK_DELAY 50 /* ms */ + +static inline void ieee80211_led_rx(struct ieee80211_local *local) +{ +#ifdef CONFIG_MAC80211_LEDS + unsigned long led_delay = MAC80211_BLINK_DELAY; + + if (!atomic_read(&local->rx_led_active)) + return; + led_trigger_blink_oneshot(&local->rx_led, &led_delay, &led_delay, 0); +#endif +} + +static inline void ieee80211_led_tx(struct ieee80211_local *local) +{ +#ifdef CONFIG_MAC80211_LEDS + unsigned long led_delay = MAC80211_BLINK_DELAY; + + if (!atomic_read(&local->tx_led_active)) + return; + led_trigger_blink_oneshot(&local->tx_led, &led_delay, &led_delay, 0); +#endif +} + +#ifdef CONFIG_MAC80211_LEDS +void ieee80211_led_assoc(struct ieee80211_local *local, + bool associated); +void ieee80211_led_radio(struct ieee80211_local *local, + bool enabled); +void ieee80211_alloc_led_names(struct ieee80211_local *local); +void ieee80211_free_led_names(struct ieee80211_local *local); +void ieee80211_led_init(struct ieee80211_local *local); +void ieee80211_led_exit(struct ieee80211_local *local); +void ieee80211_mod_tpt_led_trig(struct ieee80211_local *local, + unsigned int types_on, unsigned int types_off); +#else +static inline void ieee80211_led_assoc(struct ieee80211_local *local, + bool associated) +{ +} +static inline void ieee80211_led_radio(struct ieee80211_local *local, + bool enabled) +{ +} +static inline void ieee80211_alloc_led_names(struct ieee80211_local *local) +{ +} +static inline void ieee80211_free_led_names(struct ieee80211_local *local) +{ +} +static inline void ieee80211_led_init(struct ieee80211_local *local) +{ +} +static inline void 
ieee80211_led_exit(struct ieee80211_local *local) +{ +} +static inline void ieee80211_mod_tpt_led_trig(struct ieee80211_local *local, + unsigned int types_on, + unsigned int types_off) +{ +} +#endif + +static inline void +ieee80211_tpt_led_trig_tx(struct ieee80211_local *local, int bytes) +{ +#ifdef CONFIG_MAC80211_LEDS + if (atomic_read(&local->tpt_led_active)) + local->tpt_led_trigger->tx_bytes += bytes; +#endif +} + +static inline void +ieee80211_tpt_led_trig_rx(struct ieee80211_local *local, int bytes) +{ +#ifdef CONFIG_MAC80211_LEDS + if (atomic_read(&local->tpt_led_active)) + local->tpt_led_trigger->rx_bytes += bytes; +#endif +} diff --git a/net/mac80211/link.c b/net/mac80211/link.c new file mode 100644 index 000000000..a85b44c1b --- /dev/null +++ b/net/mac80211/link.c @@ -0,0 +1,476 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * MLO link handling + * + * Copyright (C) 2022-2023 Intel Corporation + */ +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "key.h" + +void ieee80211_link_setup(struct ieee80211_link_data *link) +{ + if (link->sdata->vif.type == NL80211_IFTYPE_STATION) + ieee80211_mgd_setup_link(link); +} + +void ieee80211_link_init(struct ieee80211_sub_if_data *sdata, + int link_id, + struct ieee80211_link_data *link, + struct ieee80211_bss_conf *link_conf) +{ + bool deflink = link_id < 0; + + if (link_id < 0) + link_id = 0; + + rcu_assign_pointer(sdata->vif.link_conf[link_id], link_conf); + rcu_assign_pointer(sdata->link[link_id], link); + + link->sdata = sdata; + link->link_id = link_id; + link->conf = link_conf; + link_conf->link_id = link_id; + + INIT_WORK(&link->csa_finalize_work, + ieee80211_csa_finalize_work); + INIT_WORK(&link->color_change_finalize_work, + ieee80211_color_change_finalize_work); + INIT_DELAYED_WORK(&link->color_collision_detect_work, + ieee80211_color_collision_detection_work); + INIT_LIST_HEAD(&link->assigned_chanctx_list); + INIT_LIST_HEAD(&link->reserved_chanctx_list); + INIT_DELAYED_WORK(&link->dfs_cac_timer_work, + ieee80211_dfs_cac_timer_work); + + if (!deflink) { + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + ether_addr_copy(link_conf->addr, + sdata->wdev.links[link_id].addr); + link_conf->bssid = link_conf->addr; + WARN_ON(!(sdata->wdev.valid_links & BIT(link_id))); + break; + case NL80211_IFTYPE_STATION: + /* station sets the bssid in ieee80211_mgd_setup_link */ + break; + default: + WARN_ON(1); + } + } +} + +void ieee80211_link_stop(struct ieee80211_link_data *link) +{ + if (link->sdata->vif.type == NL80211_IFTYPE_STATION) + ieee80211_mgd_stop_link(link); + + cancel_delayed_work_sync(&link->color_collision_detect_work); + ieee80211_link_release_channel(link); +} + +struct link_container { + struct ieee80211_link_data data; + struct ieee80211_bss_conf conf; +}; + +static void ieee80211_tear_down_links(struct ieee80211_sub_if_data *sdata, + struct link_container **links, u16 mask) +{ + struct ieee80211_link_data *link; + LIST_HEAD(keys); + unsigned int link_id; + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + if (!(mask & BIT(link_id))) + continue; + link = &links[link_id]->data; + if (link_id == 0 && !link) + link = &sdata->deflink; + if (WARN_ON(!link)) + continue; + ieee80211_remove_link_keys(link, &keys); + ieee80211_link_stop(link); + } + + synchronize_rcu(); + + ieee80211_free_key_list(sdata->local, &keys); +} + +static void ieee80211_free_links(struct ieee80211_sub_if_data *sdata, + struct link_container **links) +{ + unsigned int link_id; + + for 
(link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) + kfree(links[link_id]); +} + +static int ieee80211_check_dup_link_addrs(struct ieee80211_sub_if_data *sdata) +{ + unsigned int i, j; + + for (i = 0; i < IEEE80211_MLD_MAX_NUM_LINKS; i++) { + struct ieee80211_link_data *link1; + + link1 = sdata_dereference(sdata->link[i], sdata); + if (!link1) + continue; + for (j = i + 1; j < IEEE80211_MLD_MAX_NUM_LINKS; j++) { + struct ieee80211_link_data *link2; + + link2 = sdata_dereference(sdata->link[j], sdata); + if (!link2) + continue; + + if (ether_addr_equal(link1->conf->addr, + link2->conf->addr)) + return -EALREADY; + } + } + + return 0; +} + +static void ieee80211_set_vif_links_bitmaps(struct ieee80211_sub_if_data *sdata, + u16 links) +{ + sdata->vif.valid_links = links; + + if (!links) { + sdata->vif.active_links = 0; + return; + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + /* in an AP all links are always active */ + sdata->vif.active_links = links; + break; + case NL80211_IFTYPE_STATION: + if (sdata->vif.active_links) + break; + WARN_ON(hweight16(links) > 1); + sdata->vif.active_links = links; + break; + default: + WARN_ON(1); + } +} + +static int ieee80211_vif_update_links(struct ieee80211_sub_if_data *sdata, + struct link_container **to_free, + u16 new_links) +{ + u16 old_links = sdata->vif.valid_links; + u16 old_active = sdata->vif.active_links; + unsigned long add = new_links & ~old_links; + unsigned long rem = old_links & ~new_links; + unsigned int link_id; + int ret; + struct link_container *links[IEEE80211_MLD_MAX_NUM_LINKS] = {}, *link; + struct ieee80211_bss_conf *old[IEEE80211_MLD_MAX_NUM_LINKS]; + struct ieee80211_link_data *old_data[IEEE80211_MLD_MAX_NUM_LINKS]; + bool use_deflink = old_links == 0; /* set for error case */ + + sdata_assert_lock(sdata); + + memset(to_free, 0, sizeof(links)); + + if (old_links == new_links) + return 0; + + /* if there were no old links, need to clear the pointers to deflink */ + if (!old_links) + rem |= BIT(0); + + /* allocate new link structures first */ + for_each_set_bit(link_id, &add, IEEE80211_MLD_MAX_NUM_LINKS) { + link = kzalloc(sizeof(*link), GFP_KERNEL); + if (!link) { + ret = -ENOMEM; + goto free; + } + links[link_id] = link; + } + + /* keep track of the old pointers for the driver */ + BUILD_BUG_ON(sizeof(old) != sizeof(sdata->vif.link_conf)); + memcpy(old, sdata->vif.link_conf, sizeof(old)); + /* and for us in error cases */ + BUILD_BUG_ON(sizeof(old_data) != sizeof(sdata->link)); + memcpy(old_data, sdata->link, sizeof(old_data)); + + /* grab old links to free later */ + for_each_set_bit(link_id, &rem, IEEE80211_MLD_MAX_NUM_LINKS) { + if (rcu_access_pointer(sdata->link[link_id]) != &sdata->deflink) { + /* + * we must have allocated the data through this path so + * we know we can free both at the same time + */ + to_free[link_id] = container_of(rcu_access_pointer(sdata->link[link_id]), + typeof(*links[link_id]), + data); + } + + RCU_INIT_POINTER(sdata->link[link_id], NULL); + RCU_INIT_POINTER(sdata->vif.link_conf[link_id], NULL); + } + + /* link them into data structures */ + for_each_set_bit(link_id, &add, IEEE80211_MLD_MAX_NUM_LINKS) { + WARN_ON(!use_deflink && + rcu_access_pointer(sdata->link[link_id]) == &sdata->deflink); + + link = links[link_id]; + ieee80211_link_init(sdata, link_id, &link->data, &link->conf); + ieee80211_link_setup(&link->data); + } + + if (new_links == 0) + ieee80211_link_init(sdata, -1, &sdata->deflink, + &sdata->vif.bss_conf); + + ret = ieee80211_check_dup_link_addrs(sdata); + if 
(!ret) { + /* for keys we will not be able to undo this */ + ieee80211_tear_down_links(sdata, to_free, rem); + + ieee80211_set_vif_links_bitmaps(sdata, new_links); + + /* tell the driver */ + ret = drv_change_vif_links(sdata->local, sdata, + old_links & old_active, + new_links & sdata->vif.active_links, + old); + } + + if (ret) { + /* restore config */ + memcpy(sdata->link, old_data, sizeof(old_data)); + memcpy(sdata->vif.link_conf, old, sizeof(old)); + ieee80211_set_vif_links_bitmaps(sdata, old_links); + /* and free (only) the newly allocated links */ + memset(to_free, 0, sizeof(links)); + goto free; + } + + /* use deflink/bss_conf again if and only if there are no more links */ + use_deflink = new_links == 0; + + goto deinit; +free: + /* if we failed during allocation, only free all */ + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + kfree(links[link_id]); + links[link_id] = NULL; + } +deinit: + if (use_deflink) + ieee80211_link_init(sdata, -1, &sdata->deflink, + &sdata->vif.bss_conf); + return ret; +} + +int ieee80211_vif_set_links(struct ieee80211_sub_if_data *sdata, + u16 new_links) +{ + struct link_container *links[IEEE80211_MLD_MAX_NUM_LINKS]; + int ret; + + ret = ieee80211_vif_update_links(sdata, links, new_links); + ieee80211_free_links(sdata, links); + + return ret; +} + +void ieee80211_vif_clear_links(struct ieee80211_sub_if_data *sdata) +{ + struct link_container *links[IEEE80211_MLD_MAX_NUM_LINKS]; + + /* + * The locking here is different because when we free links + * in the station case we need to be able to cancel_work_sync() + * something that also takes the lock. + */ + + sdata_lock(sdata); + ieee80211_vif_update_links(sdata, links, 0); + sdata_unlock(sdata); + + ieee80211_free_links(sdata, links); +} + +static int _ieee80211_set_active_links(struct ieee80211_sub_if_data *sdata, + u16 active_links) +{ + struct ieee80211_bss_conf *link_confs[IEEE80211_MLD_MAX_NUM_LINKS]; + struct ieee80211_local *local = sdata->local; + u16 old_active = sdata->vif.active_links; + unsigned long rem = old_active & ~active_links; + unsigned long add = active_links & ~old_active; + struct sta_info *sta; + unsigned int link_id; + int ret, i; + + if (!ieee80211_sdata_running(sdata)) + return -ENETDOWN; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return -EINVAL; + + /* cannot activate links that don't exist */ + if (active_links & ~sdata->vif.valid_links) + return -EINVAL; + + /* nothing to do */ + if (old_active == active_links) + return 0; + + for (i = 0; i < IEEE80211_MLD_MAX_NUM_LINKS; i++) + link_confs[i] = sdata_dereference(sdata->vif.link_conf[i], + sdata); + + if (add) { + sdata->vif.active_links |= active_links; + ret = drv_change_vif_links(local, sdata, + old_active, + sdata->vif.active_links, + link_confs); + if (ret) { + sdata->vif.active_links = old_active; + return ret; + } + } + + for_each_set_bit(link_id, &rem, IEEE80211_MLD_MAX_NUM_LINKS) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + + /* FIXME: kill TDLS connections on the link */ + + ieee80211_link_release_channel(link); + } + + list_for_each_entry(sta, &local->sta_list, list) { + if (sdata != sta->sdata) + continue; + ret = drv_change_sta_links(local, sdata, &sta->sta, + old_active, + old_active | active_links); + WARN_ON_ONCE(ret); + } + + ret = ieee80211_key_switch_links(sdata, rem, add); + WARN_ON_ONCE(ret); + + list_for_each_entry(sta, &local->sta_list, list) { + if (sdata != sta->sdata) + continue; + ret = drv_change_sta_links(local, 
sdata, &sta->sta, + old_active | active_links, + active_links); + WARN_ON_ONCE(ret); + } + + for_each_set_bit(link_id, &add, IEEE80211_MLD_MAX_NUM_LINKS) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + + ret = ieee80211_link_use_channel(link, &link->conf->chandef, + IEEE80211_CHANCTX_SHARED); + WARN_ON_ONCE(ret); + + ieee80211_mgd_set_link_qos_params(link); + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_ERP_CTS_PROT | + BSS_CHANGED_ERP_PREAMBLE | + BSS_CHANGED_ERP_SLOT | + BSS_CHANGED_HT | + BSS_CHANGED_BASIC_RATES | + BSS_CHANGED_BSSID | + BSS_CHANGED_CQM | + BSS_CHANGED_QOS | + BSS_CHANGED_TXPOWER | + BSS_CHANGED_BANDWIDTH | + BSS_CHANGED_TWT | + BSS_CHANGED_HE_OBSS_PD | + BSS_CHANGED_HE_BSS_COLOR); + } + + old_active = sdata->vif.active_links; + sdata->vif.active_links = active_links; + + if (rem) { + ret = drv_change_vif_links(local, sdata, old_active, + active_links, link_confs); + WARN_ON_ONCE(ret); + } + + return 0; +} + +int ieee80211_set_active_links(struct ieee80211_vif *vif, u16 active_links) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + u16 old_active; + int ret; + + sdata_lock(sdata); + mutex_lock(&local->sta_mtx); + mutex_lock(&local->mtx); + mutex_lock(&local->key_mtx); + old_active = sdata->vif.active_links; + if (old_active & active_links) { + /* + * if there's at least one link that stays active across + * the change then switch to it (to those) first, and + * then enable the additional links + */ + ret = _ieee80211_set_active_links(sdata, + old_active & active_links); + if (!ret) + ret = _ieee80211_set_active_links(sdata, active_links); + } else { + /* otherwise switch directly */ + ret = _ieee80211_set_active_links(sdata, active_links); + } + mutex_unlock(&local->key_mtx); + mutex_unlock(&local->mtx); + mutex_unlock(&local->sta_mtx); + sdata_unlock(sdata); + + return ret; +} +EXPORT_SYMBOL_GPL(ieee80211_set_active_links); + +void ieee80211_set_active_links_async(struct ieee80211_vif *vif, + u16 active_links) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (!ieee80211_sdata_running(sdata)) + return; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return; + + /* cannot activate links that don't exist */ + if (active_links & ~sdata->vif.valid_links) + return; + + /* nothing to do */ + if (sdata->vif.active_links == active_links) + return; + + sdata->desired_active_links = active_links; + schedule_work(&sdata->activate_links_work); +} +EXPORT_SYMBOL_GPL(ieee80211_set_active_links_async); diff --git a/net/mac80211/main.c b/net/mac80211/main.c new file mode 100644 index 000000000..6faba47b7 --- /dev/null +++ b/net/mac80211/main.c @@ -0,0 +1,1575 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "mesh.h" +#include "wep.h" +#include "led.h" +#include "debugfs.h" + +void ieee80211_configure_filter(struct ieee80211_local *local) +{ + u64 mc; + unsigned int changed_flags; + unsigned int new_flags = 0; + + if (atomic_read(&local->iff_allmultis)) + new_flags |= FIF_ALLMULTI; + + if (local->monitors || test_bit(SCAN_SW_SCANNING, &local->scanning) || + test_bit(SCAN_ONCHANNEL_SCANNING, &local->scanning)) + new_flags |= FIF_BCN_PRBRESP_PROMISC; + + if (local->fif_probe_req || local->probe_req_reg) + new_flags |= FIF_PROBE_REQ; + + if (local->fif_fcsfail) + new_flags |= FIF_FCSFAIL; + + if (local->fif_plcpfail) + new_flags |= FIF_PLCPFAIL; + + if (local->fif_control) + new_flags |= FIF_CONTROL; + + if (local->fif_other_bss) + new_flags |= FIF_OTHER_BSS; + + if (local->fif_pspoll) + new_flags |= FIF_PSPOLL; + + if (local->rx_mcast_action_reg) + new_flags |= FIF_MCAST_ACTION; + + spin_lock_bh(&local->filter_lock); + changed_flags = local->filter_flags ^ new_flags; + + mc = drv_prepare_multicast(local, &local->mc_list); + spin_unlock_bh(&local->filter_lock); + + /* be a bit nasty */ + new_flags |= (1<<31); + + drv_configure_filter(local, changed_flags, &new_flags, mc); + + WARN_ON(new_flags & (1<<31)); + + local->filter_flags = new_flags & ~(1<<31); +} + +static void ieee80211_reconfig_filter(struct work_struct *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, reconfig_filter); + + ieee80211_configure_filter(local); +} + +static u32 ieee80211_hw_conf_chan(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + struct cfg80211_chan_def chandef = {}; + u32 changed = 0; + int power; + u32 offchannel_flag; + + offchannel_flag = local->hw.conf.flags & IEEE80211_CONF_OFFCHANNEL; + + if (local->scan_chandef.chan) { + chandef = local->scan_chandef; + } else if (local->tmp_channel) { + chandef.chan = local->tmp_channel; + chandef.width = NL80211_CHAN_WIDTH_20_NOHT; + chandef.center_freq1 = chandef.chan->center_freq; + chandef.freq1_offset = chandef.chan->freq_offset; + } else + chandef = local->_oper_chandef; + + WARN(!cfg80211_chandef_valid(&chandef), + "control:%d.%03d MHz width:%d center: %d.%03d/%d MHz", + chandef.chan->center_freq, chandef.chan->freq_offset, + chandef.width, chandef.center_freq1, chandef.freq1_offset, + chandef.center_freq2); + + if (!cfg80211_chandef_identical(&chandef, &local->_oper_chandef)) + local->hw.conf.flags |= IEEE80211_CONF_OFFCHANNEL; + else + local->hw.conf.flags &= ~IEEE80211_CONF_OFFCHANNEL; + + offchannel_flag ^= local->hw.conf.flags & IEEE80211_CONF_OFFCHANNEL; + + if (offchannel_flag || + !cfg80211_chandef_identical(&local->hw.conf.chandef, + &local->_oper_chandef)) { + local->hw.conf.chandef = chandef; + changed |= IEEE80211_CONF_CHANGE_CHANNEL; + } + + if (!conf_is_ht(&local->hw.conf)) { + /* + * mac80211.h documents that this is only valid + * when the channel is set to an HT type, and + * that otherwise STATIC is used. 
+ */ + local->hw.conf.smps_mode = IEEE80211_SMPS_STATIC; + } else if (local->hw.conf.smps_mode != local->smps_mode) { + local->hw.conf.smps_mode = local->smps_mode; + changed |= IEEE80211_CONF_CHANGE_SMPS; + } + + power = ieee80211_chandef_max_power(&chandef); + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!rcu_access_pointer(sdata->vif.bss_conf.chanctx_conf)) + continue; + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + continue; + if (sdata->vif.bss_conf.txpower == INT_MIN) + continue; + power = min(power, sdata->vif.bss_conf.txpower); + } + rcu_read_unlock(); + + if (local->hw.conf.power_level != power) { + changed |= IEEE80211_CONF_CHANGE_POWER; + local->hw.conf.power_level = power; + } + + return changed; +} + +int ieee80211_hw_config(struct ieee80211_local *local, u32 changed) +{ + int ret = 0; + + might_sleep(); + + if (!local->use_chanctx) + changed |= ieee80211_hw_conf_chan(local); + else + changed &= ~(IEEE80211_CONF_CHANGE_CHANNEL | + IEEE80211_CONF_CHANGE_POWER | + IEEE80211_CONF_CHANGE_SMPS); + + if (changed && local->open_count) { + ret = drv_config(local, changed); + /* + * Goal: + * HW reconfiguration should never fail, the driver has told + * us what it can support so it should live up to that promise. + * + * Current status: + * rfkill is not integrated with mac80211 and a + * configuration command can thus fail if hardware rfkill + * is enabled + * + * FIXME: integrate rfkill with mac80211 and then add this + * WARN_ON() back + * + */ + /* WARN_ON(ret); */ + } + + return ret; +} + +#define BSS_CHANGED_VIF_CFG_FLAGS (BSS_CHANGED_ASSOC |\ + BSS_CHANGED_IDLE |\ + BSS_CHANGED_PS |\ + BSS_CHANGED_IBSS |\ + BSS_CHANGED_ARP_FILTER |\ + BSS_CHANGED_SSID) + +void ieee80211_bss_info_change_notify(struct ieee80211_sub_if_data *sdata, + u64 changed) +{ + struct ieee80211_local *local = sdata->local; + + might_sleep(); + + if (!changed || sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + return; + + if (WARN_ON_ONCE(changed & (BSS_CHANGED_BEACON | + BSS_CHANGED_BEACON_ENABLED) && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_OCB)) + return; + + if (WARN_ON_ONCE(sdata->vif.type == NL80211_IFTYPE_P2P_DEVICE || + sdata->vif.type == NL80211_IFTYPE_NAN || + (sdata->vif.type == NL80211_IFTYPE_MONITOR && + !sdata->vif.bss_conf.mu_mimo_owner && + !(changed & BSS_CHANGED_TXPOWER)))) + return; + + if (!check_sdata_in_driver(sdata)) + return; + + if (changed & BSS_CHANGED_VIF_CFG_FLAGS) { + u64 ch = changed & BSS_CHANGED_VIF_CFG_FLAGS; + + trace_drv_vif_cfg_changed(local, sdata, changed); + if (local->ops->vif_cfg_changed) + local->ops->vif_cfg_changed(&local->hw, &sdata->vif, ch); + } + + if (changed & ~BSS_CHANGED_VIF_CFG_FLAGS) { + u64 ch = changed & ~BSS_CHANGED_VIF_CFG_FLAGS; + + /* FIXME: should be for each link */ + trace_drv_link_info_changed(local, sdata, &sdata->vif.bss_conf, + changed); + if (local->ops->link_info_changed) + local->ops->link_info_changed(&local->hw, &sdata->vif, + &sdata->vif.bss_conf, ch); + } + + if (local->ops->bss_info_changed) + local->ops->bss_info_changed(&local->hw, &sdata->vif, + &sdata->vif.bss_conf, changed); + trace_drv_return_void(local); +} + +void ieee80211_vif_cfg_change_notify(struct ieee80211_sub_if_data *sdata, + u64 changed) +{ + struct ieee80211_local *local = sdata->local; + + WARN_ON_ONCE(changed & ~BSS_CHANGED_VIF_CFG_FLAGS); + + if (!changed || sdata->vif.type == 
NL80211_IFTYPE_AP_VLAN) + return; + + drv_vif_cfg_changed(local, sdata, changed); +} + +void ieee80211_link_info_change_notify(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + u64 changed) +{ + struct ieee80211_local *local = sdata->local; + + WARN_ON_ONCE(changed & BSS_CHANGED_VIF_CFG_FLAGS); + + if (!changed || sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + return; + + if (!check_sdata_in_driver(sdata)) + return; + + drv_link_info_changed(local, sdata, link->conf, link->link_id, changed); +} + +u32 ieee80211_reset_erp_info(struct ieee80211_sub_if_data *sdata) +{ + sdata->vif.bss_conf.use_cts_prot = false; + sdata->vif.bss_conf.use_short_preamble = false; + sdata->vif.bss_conf.use_short_slot = false; + return BSS_CHANGED_ERP_CTS_PROT | + BSS_CHANGED_ERP_PREAMBLE | + BSS_CHANGED_ERP_SLOT; +} + +static void ieee80211_tasklet_handler(struct tasklet_struct *t) +{ + struct ieee80211_local *local = from_tasklet(local, t, tasklet); + struct sk_buff *skb; + + while ((skb = skb_dequeue(&local->skb_queue)) || + (skb = skb_dequeue(&local->skb_queue_unreliable))) { + switch (skb->pkt_type) { + case IEEE80211_RX_MSG: + /* Clear skb->pkt_type in order to not confuse kernel + * netstack. */ + skb->pkt_type = 0; + ieee80211_rx(&local->hw, skb); + break; + case IEEE80211_TX_STATUS_MSG: + skb->pkt_type = 0; + ieee80211_tx_status(&local->hw, skb); + break; + default: + WARN(1, "mac80211: Packet is of unknown type %d\n", + skb->pkt_type); + dev_kfree_skb(skb); + break; + } + } +} + +static void ieee80211_restart_work(struct work_struct *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, restart_work); + struct ieee80211_sub_if_data *sdata; + int ret; + + flush_workqueue(local->workqueue); + + rtnl_lock(); + /* we might do interface manipulations, so need both */ + wiphy_lock(local->hw.wiphy); + + WARN(test_bit(SCAN_HW_SCANNING, &local->scanning), + "%s called with hardware scan in progress\n", __func__); + + list_for_each_entry(sdata, &local->interfaces, list) { + /* + * XXX: there may be more work for other vif types and even + * for station mode: a good thing would be to run most of + * the iface type's dependent _stop (ieee80211_mg_stop, + * ieee80211_ibss_stop) etc... + * For now, fix only the specific bug that was seen: race + * between csa_connection_drop_work and us. + */ + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + /* + * This worker is scheduled from the iface worker that + * runs on mac80211's workqueue, so we can't be + * scheduling this worker after the cancel right here. + * The exception is ieee80211_chswitch_done. + * Then we can have a race... 
+ */ + cancel_work_sync(&sdata->u.mgd.csa_connection_drop_work); + if (sdata->vif.bss_conf.csa_active) { + sdata_lock(sdata); + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_UNSPECIFIED, + false); + sdata_unlock(sdata); + } + } + flush_delayed_work(&sdata->dec_tailroom_needed_wk); + } + ieee80211_scan_cancel(local); + + /* make sure any new ROC will consider local->in_reconfig */ + wiphy_delayed_work_flush(local->hw.wiphy, &local->roc_work); + wiphy_work_flush(local->hw.wiphy, &local->hw_roc_done); + + /* wait for all packet processing to be done */ + synchronize_net(); + + ret = ieee80211_reconfig(local); + wiphy_unlock(local->hw.wiphy); + + if (ret) + cfg80211_shutdown_all_interfaces(local->hw.wiphy); + + rtnl_unlock(); +} + +void ieee80211_restart_hw(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_restart_hw(local); + + wiphy_info(hw->wiphy, + "Hardware restart was requested\n"); + + /* use this reason, ieee80211_reconfig will unblock it */ + ieee80211_stop_queues_by_reason(hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + false); + + /* + * Stop all Rx during the reconfig. We don't want state changes + * or driver callbacks while this is in progress. + */ + local->in_reconfig = true; + barrier(); + + queue_work(system_freezable_wq, &local->restart_work); +} +EXPORT_SYMBOL(ieee80211_restart_hw); + +#ifdef CONFIG_INET +static int ieee80211_ifa_changed(struct notifier_block *nb, + unsigned long data, void *arg) +{ + struct in_ifaddr *ifa = arg; + struct ieee80211_local *local = + container_of(nb, struct ieee80211_local, + ifa_notifier); + struct net_device *ndev = ifa->ifa_dev->dev; + struct wireless_dev *wdev = ndev->ieee80211_ptr; + struct in_device *idev; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_vif_cfg *vif_cfg; + struct ieee80211_if_managed *ifmgd; + int c = 0; + + /* Make sure it's our interface that got changed */ + if (!wdev) + return NOTIFY_DONE; + + if (wdev->wiphy != local->hw.wiphy) + return NOTIFY_DONE; + + sdata = IEEE80211_DEV_TO_SUB_IF(ndev); + vif_cfg = &sdata->vif.cfg; + + /* ARP filtering is only supported in managed mode */ + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return NOTIFY_DONE; + + idev = __in_dev_get_rtnl(sdata->dev); + if (!idev) + return NOTIFY_DONE; + + ifmgd = &sdata->u.mgd; + sdata_lock(sdata); + + /* Copy the addresses to the vif config list */ + ifa = rtnl_dereference(idev->ifa_list); + while (ifa) { + if (c < IEEE80211_BSS_ARP_ADDR_LIST_LEN) + vif_cfg->arp_addr_list[c] = ifa->ifa_address; + ifa = rtnl_dereference(ifa->ifa_next); + c++; + } + + vif_cfg->arp_addr_cnt = c; + + /* Configure driver only if associated (which also implies it is up) */ + if (ifmgd->associated) + ieee80211_vif_cfg_change_notify(sdata, BSS_CHANGED_ARP_FILTER); + + sdata_unlock(sdata); + + return NOTIFY_OK; +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +static int ieee80211_ifa6_changed(struct notifier_block *nb, + unsigned long data, void *arg) +{ + struct inet6_ifaddr *ifa = (struct inet6_ifaddr *)arg; + struct inet6_dev *idev = ifa->idev; + struct net_device *ndev = ifa->idev->dev; + struct ieee80211_local *local = + container_of(nb, struct ieee80211_local, ifa6_notifier); + struct wireless_dev *wdev = ndev->ieee80211_ptr; + struct ieee80211_sub_if_data *sdata; + + /* Make sure it's our interface that got changed */ + if (!wdev || wdev->wiphy != local->hw.wiphy) + return NOTIFY_DONE; + + sdata = IEEE80211_DEV_TO_SUB_IF(ndev); + + /* + * For now only support station mode. 
This is mostly because + * doing AP would have to handle AP_VLAN in some way ... + */ + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return NOTIFY_DONE; + + drv_ipv6_addr_change(local, sdata, idev); + + return NOTIFY_OK; +} +#endif + +/* There isn't a lot of sense in it, but you can transmit anything you like */ +static const struct ieee80211_txrx_stypes +ieee80211_default_mgmt_stypes[NUM_NL80211_IFTYPES] = { + [NL80211_IFTYPE_ADHOC] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ACTION >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_DEAUTH >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4), + }, + [NL80211_IFTYPE_STATION] = { + .tx = 0xffff, + /* + * To support Pre Association Security Negotiation (PASN) while + * already associated to one AP, allow user space to register to + * Rx authentication frames, so that the user space logic would + * be able to receive/handle authentication frames from a + * different AP as part of PASN. + * It is expected that user space would intelligently register + * for Rx authentication frames, i.e., only when PASN is used + * and configure a match filter only for PASN authentication + * algorithm, as otherwise the MLME functionality of mac80211 + * would be broken. + */ + .rx = BIT(IEEE80211_STYPE_ACTION >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4), + }, + [NL80211_IFTYPE_AP] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_REASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4) | + BIT(IEEE80211_STYPE_DISASSOC >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_DEAUTH >> 4) | + BIT(IEEE80211_STYPE_ACTION >> 4), + }, + [NL80211_IFTYPE_AP_VLAN] = { + /* copy AP */ + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_REASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4) | + BIT(IEEE80211_STYPE_DISASSOC >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_DEAUTH >> 4) | + BIT(IEEE80211_STYPE_ACTION >> 4), + }, + [NL80211_IFTYPE_P2P_CLIENT] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ACTION >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4), + }, + [NL80211_IFTYPE_P2P_GO] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_REASSOC_REQ >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4) | + BIT(IEEE80211_STYPE_DISASSOC >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_DEAUTH >> 4) | + BIT(IEEE80211_STYPE_ACTION >> 4), + }, + [NL80211_IFTYPE_MESH_POINT] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ACTION >> 4) | + BIT(IEEE80211_STYPE_AUTH >> 4) | + BIT(IEEE80211_STYPE_DEAUTH >> 4), + }, + [NL80211_IFTYPE_P2P_DEVICE] = { + .tx = 0xffff, + .rx = BIT(IEEE80211_STYPE_ACTION >> 4) | + BIT(IEEE80211_STYPE_PROBE_REQ >> 4), + }, +}; + +static const struct ieee80211_ht_cap mac80211_ht_capa_mod_mask = { + .ampdu_params_info = IEEE80211_HT_AMPDU_PARM_FACTOR | + IEEE80211_HT_AMPDU_PARM_DENSITY, + + .cap_info = cpu_to_le16(IEEE80211_HT_CAP_SUP_WIDTH_20_40 | + IEEE80211_HT_CAP_MAX_AMSDU | + IEEE80211_HT_CAP_SGI_20 | + IEEE80211_HT_CAP_SGI_40 | + IEEE80211_HT_CAP_TX_STBC | + IEEE80211_HT_CAP_RX_STBC | + IEEE80211_HT_CAP_LDPC_CODING | + IEEE80211_HT_CAP_40MHZ_INTOLERANT), + .mcs = { + .rx_mask = { 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, }, + }, +}; + +static const struct ieee80211_vht_cap mac80211_vht_capa_mod_mask = { + .vht_cap_info = + cpu_to_le32(IEEE80211_VHT_CAP_RXLDPC | + IEEE80211_VHT_CAP_SHORT_GI_80 | + IEEE80211_VHT_CAP_SHORT_GI_160 | + 
IEEE80211_VHT_CAP_RXSTBC_MASK | + IEEE80211_VHT_CAP_TXSTBC | + IEEE80211_VHT_CAP_SU_BEAMFORMER_CAPABLE | + IEEE80211_VHT_CAP_SU_BEAMFORMEE_CAPABLE | + IEEE80211_VHT_CAP_TX_ANTENNA_PATTERN | + IEEE80211_VHT_CAP_RX_ANTENNA_PATTERN | + IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK), + .supp_mcs = { + .rx_mcs_map = cpu_to_le16(~0), + .tx_mcs_map = cpu_to_le16(~0), + }, +}; + +struct ieee80211_hw *ieee80211_alloc_hw_nm(size_t priv_data_len, + const struct ieee80211_ops *ops, + const char *requested_name) +{ + struct ieee80211_local *local; + int priv_size, i; + struct wiphy *wiphy; + bool use_chanctx; + + if (WARN_ON(!ops->tx || !ops->start || !ops->stop || !ops->config || + !ops->add_interface || !ops->remove_interface || + !ops->configure_filter)) + return NULL; + + if (WARN_ON(ops->sta_state && (ops->sta_add || ops->sta_remove))) + return NULL; + + if (WARN_ON(!!ops->link_info_changed != !!ops->vif_cfg_changed || + (ops->link_info_changed && ops->bss_info_changed))) + return NULL; + + /* check all or no channel context operations exist */ + i = !!ops->add_chanctx + !!ops->remove_chanctx + + !!ops->change_chanctx + !!ops->assign_vif_chanctx + + !!ops->unassign_vif_chanctx; + if (WARN_ON(i != 0 && i != 5)) + return NULL; + use_chanctx = i == 5; + + /* Ensure 32-byte alignment of our private data and hw private data. + * We use the wiphy priv data for both our ieee80211_local and for + * the driver's private data + * + * In memory it'll be like this: + * + * +-------------------------+ + * | struct wiphy | + * +-------------------------+ + * | struct ieee80211_local | + * +-------------------------+ + * | driver's private data | + * +-------------------------+ + * + */ + priv_size = ALIGN(sizeof(*local), NETDEV_ALIGN) + priv_data_len; + + wiphy = wiphy_new_nm(&mac80211_config_ops, priv_size, requested_name); + + if (!wiphy) + return NULL; + + wiphy->mgmt_stypes = ieee80211_default_mgmt_stypes; + + wiphy->privid = mac80211_wiphy_privid; + + wiphy->flags |= WIPHY_FLAG_NETNS_OK | + WIPHY_FLAG_4ADDR_AP | + WIPHY_FLAG_4ADDR_STATION | + WIPHY_FLAG_REPORTS_OBSS | + WIPHY_FLAG_OFFCHAN_TX; + + if (!use_chanctx || ops->remain_on_channel) + wiphy->flags |= WIPHY_FLAG_HAS_REMAIN_ON_CHANNEL; + + wiphy->features |= NL80211_FEATURE_SK_TX_STATUS | + NL80211_FEATURE_SAE | + NL80211_FEATURE_HT_IBSS | + NL80211_FEATURE_VIF_TXPOWER | + NL80211_FEATURE_MAC_ON_CREATE | + NL80211_FEATURE_USERSPACE_MPM | + NL80211_FEATURE_FULL_AP_CLIENT_STATE; + wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_FILS_STA); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_CONTROL_PORT_NO_PREAUTH); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211_TX_STATUS); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_SCAN_FREQ_KHZ); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_POWERED_ADDR_CHANGE); + + if (!ops->hw_scan) { + wiphy->features |= NL80211_FEATURE_LOW_PRIORITY_SCAN | + NL80211_FEATURE_AP_SCAN; + /* + * if the driver behaves correctly using the probe request + * (template) from mac80211, then both of these should be + * supported even with hw scan - but let drivers opt in. 
+ */ + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_SCAN_RANDOM_SN); + wiphy_ext_feature_set(wiphy, + NL80211_EXT_FEATURE_SCAN_MIN_PREQ_CONTENT); + } + + if (!ops->set_key) + wiphy->flags |= WIPHY_FLAG_IBSS_RSN; + + if (ops->wake_tx_queue) + wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_TXQS); + + wiphy_ext_feature_set(wiphy, NL80211_EXT_FEATURE_RRM); + + wiphy->bss_priv_size = sizeof(struct ieee80211_bss); + + local = wiphy_priv(wiphy); + + if (sta_info_init(local)) + goto err_free; + + local->hw.wiphy = wiphy; + + local->hw.priv = (char *)local + ALIGN(sizeof(*local), NETDEV_ALIGN); + + local->ops = ops; + local->use_chanctx = use_chanctx; + + /* + * We need a bit of data queued to build aggregates properly, so + * instruct the TCP stack to allow more than a single ms of data + * to be queued in the stack. The value is a bit-shift of 1 + * second, so 7 is ~8ms of queued data. Only affects local TCP + * sockets. + * This is the default, anyhow - drivers may need to override it + * for local reasons (longer buffers, longer completion time, or + * similar). + */ + local->hw.tx_sk_pacing_shift = 7; + + /* set up some defaults */ + local->hw.queues = 1; + local->hw.max_rates = 1; + local->hw.max_report_rates = 0; + local->hw.max_rx_aggregation_subframes = IEEE80211_MAX_AMPDU_BUF_HT; + local->hw.max_tx_aggregation_subframes = IEEE80211_MAX_AMPDU_BUF_HT; + local->hw.offchannel_tx_hw_queue = IEEE80211_INVAL_HW_QUEUE; + local->hw.conf.long_frame_max_tx_count = wiphy->retry_long; + local->hw.conf.short_frame_max_tx_count = wiphy->retry_short; + local->hw.radiotap_mcs_details = IEEE80211_RADIOTAP_MCS_HAVE_MCS | + IEEE80211_RADIOTAP_MCS_HAVE_GI | + IEEE80211_RADIOTAP_MCS_HAVE_BW; + local->hw.radiotap_vht_details = IEEE80211_RADIOTAP_VHT_KNOWN_GI | + IEEE80211_RADIOTAP_VHT_KNOWN_BANDWIDTH; + local->hw.uapsd_queues = IEEE80211_DEFAULT_UAPSD_QUEUES; + local->hw.uapsd_max_sp_len = IEEE80211_DEFAULT_MAX_SP_LEN; + local->hw.max_mtu = IEEE80211_MAX_DATA_LEN; + local->user_power_level = IEEE80211_UNSET_POWER_LEVEL; + wiphy->ht_capa_mod_mask = &mac80211_ht_capa_mod_mask; + wiphy->vht_capa_mod_mask = &mac80211_vht_capa_mod_mask; + + local->ext_capa[7] = WLAN_EXT_CAPA8_OPMODE_NOTIF; + + wiphy->extended_capabilities = local->ext_capa; + wiphy->extended_capabilities_mask = local->ext_capa; + wiphy->extended_capabilities_len = + ARRAY_SIZE(local->ext_capa); + + INIT_LIST_HEAD(&local->interfaces); + INIT_LIST_HEAD(&local->mon_list); + + __hw_addr_init(&local->mc_list); + + mutex_init(&local->iflist_mtx); + mutex_init(&local->mtx); + + mutex_init(&local->key_mtx); + spin_lock_init(&local->filter_lock); + spin_lock_init(&local->rx_path_lock); + spin_lock_init(&local->queue_stop_reason_lock); + + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + INIT_LIST_HEAD(&local->active_txqs[i]); + spin_lock_init(&local->active_txq_lock[i]); + local->aql_txq_limit_low[i] = IEEE80211_DEFAULT_AQL_TXQ_LIMIT_L; + local->aql_txq_limit_high[i] = + IEEE80211_DEFAULT_AQL_TXQ_LIMIT_H; + atomic_set(&local->aql_ac_pending_airtime[i], 0); + } + + local->airtime_flags = AIRTIME_USE_TX | AIRTIME_USE_RX; + local->aql_threshold = IEEE80211_AQL_THRESHOLD; + atomic_set(&local->aql_total_pending_airtime, 0); + + INIT_LIST_HEAD(&local->chanctx_list); + mutex_init(&local->chanctx_mtx); + + wiphy_delayed_work_init(&local->scan_work, ieee80211_scan_work); + + INIT_WORK(&local->restart_work, ieee80211_restart_work); + + wiphy_work_init(&local->radar_detected_work, + ieee80211_dfs_radar_detected_work); + + INIT_WORK(&local->reconfig_filter, 
ieee80211_reconfig_filter); + local->smps_mode = IEEE80211_SMPS_OFF; + + INIT_WORK(&local->dynamic_ps_enable_work, + ieee80211_dynamic_ps_enable_work); + INIT_WORK(&local->dynamic_ps_disable_work, + ieee80211_dynamic_ps_disable_work); + timer_setup(&local->dynamic_ps_timer, ieee80211_dynamic_ps_timer, 0); + + wiphy_work_init(&local->sched_scan_stopped_work, + ieee80211_sched_scan_stopped_work); + + spin_lock_init(&local->ack_status_lock); + idr_init(&local->ack_status_frames); + + for (i = 0; i < IEEE80211_MAX_QUEUES; i++) { + skb_queue_head_init(&local->pending[i]); + atomic_set(&local->agg_queue_stop[i], 0); + } + tasklet_setup(&local->tx_pending_tasklet, ieee80211_tx_pending); + + if (ops->wake_tx_queue) + tasklet_setup(&local->wake_txqs_tasklet, ieee80211_wake_txqs); + + tasklet_setup(&local->tasklet, ieee80211_tasklet_handler); + + skb_queue_head_init(&local->skb_queue); + skb_queue_head_init(&local->skb_queue_unreliable); + + ieee80211_alloc_led_names(local); + + ieee80211_roc_setup(local); + + local->hw.radiotap_timestamp.units_pos = -1; + local->hw.radiotap_timestamp.accuracy = -1; + + return &local->hw; + err_free: + wiphy_free(wiphy); + return NULL; +} +EXPORT_SYMBOL(ieee80211_alloc_hw_nm); + +static int ieee80211_init_cipher_suites(struct ieee80211_local *local) +{ + bool have_wep = !fips_enabled; /* FIPS does not permit the use of RC4 */ + bool have_mfp = ieee80211_hw_check(&local->hw, MFP_CAPABLE); + int r = 0, w = 0; + u32 *suites; + static const u32 cipher_suites[] = { + /* keep WEP first, it may be removed below */ + WLAN_CIPHER_SUITE_WEP40, + WLAN_CIPHER_SUITE_WEP104, + WLAN_CIPHER_SUITE_TKIP, + WLAN_CIPHER_SUITE_CCMP, + WLAN_CIPHER_SUITE_CCMP_256, + WLAN_CIPHER_SUITE_GCMP, + WLAN_CIPHER_SUITE_GCMP_256, + + /* keep last -- depends on hw flags! */ + WLAN_CIPHER_SUITE_AES_CMAC, + WLAN_CIPHER_SUITE_BIP_CMAC_256, + WLAN_CIPHER_SUITE_BIP_GMAC_128, + WLAN_CIPHER_SUITE_BIP_GMAC_256, + }; + + if (ieee80211_hw_check(&local->hw, SW_CRYPTO_CONTROL) || + local->hw.wiphy->cipher_suites) { + /* If the driver advertises, or doesn't support SW crypto, + * we only need to remove WEP if necessary. + */ + if (have_wep) + return 0; + + /* well if it has _no_ ciphers ... 
fine */ + if (!local->hw.wiphy->n_cipher_suites) + return 0; + + /* Driver provides cipher suites, but we need to exclude WEP */ + suites = kmemdup(local->hw.wiphy->cipher_suites, + sizeof(u32) * local->hw.wiphy->n_cipher_suites, + GFP_KERNEL); + if (!suites) + return -ENOMEM; + + for (r = 0; r < local->hw.wiphy->n_cipher_suites; r++) { + u32 suite = local->hw.wiphy->cipher_suites[r]; + + if (suite == WLAN_CIPHER_SUITE_WEP40 || + suite == WLAN_CIPHER_SUITE_WEP104) + continue; + suites[w++] = suite; + } + } else { + /* assign the (software supported and perhaps offloaded) + * cipher suites + */ + local->hw.wiphy->cipher_suites = cipher_suites; + local->hw.wiphy->n_cipher_suites = ARRAY_SIZE(cipher_suites); + + if (!have_mfp) + local->hw.wiphy->n_cipher_suites -= 4; + + if (!have_wep) { + local->hw.wiphy->cipher_suites += 2; + local->hw.wiphy->n_cipher_suites -= 2; + } + + /* not dynamically allocated, so just return */ + return 0; + } + + local->hw.wiphy->cipher_suites = suites; + local->hw.wiphy->n_cipher_suites = w; + local->wiphy_ciphers_allocated = true; + + return 0; +} + +int ieee80211_register_hw(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + int result, i; + enum nl80211_band band; + int channels, max_bitrates; + bool supp_ht, supp_vht, supp_he, supp_eht; + struct cfg80211_chan_def dflt_chandef = {}; + + if (ieee80211_hw_check(hw, QUEUE_CONTROL) && + (local->hw.offchannel_tx_hw_queue == IEEE80211_INVAL_HW_QUEUE || + local->hw.offchannel_tx_hw_queue >= local->hw.queues)) + return -EINVAL; + + if ((hw->wiphy->features & NL80211_FEATURE_TDLS_CHANNEL_SWITCH) && + (!local->ops->tdls_channel_switch || + !local->ops->tdls_cancel_channel_switch || + !local->ops->tdls_recv_channel_switch)) + return -EOPNOTSUPP; + + if (WARN_ON(ieee80211_hw_check(hw, SUPPORTS_TX_FRAG) && + !local->ops->set_frag_threshold)) + return -EINVAL; + + if (WARN_ON(local->hw.wiphy->interface_modes & + BIT(NL80211_IFTYPE_NAN) && + (!local->ops->start_nan || !local->ops->stop_nan))) + return -EINVAL; + + if (hw->wiphy->flags & WIPHY_FLAG_SUPPORTS_MLO) { + /* + * For drivers capable of doing MLO, assume modern driver + * or firmware facilities, so software doesn't have to do + * as much, e.g. monitoring beacons would be hard if we + * might not even know which link is active at which time. 
+ */ + if (WARN_ON(!local->use_chanctx)) + return -EINVAL; + + if (WARN_ON(!local->ops->link_info_changed)) + return -EINVAL; + + if (WARN_ON(!ieee80211_hw_check(hw, HAS_RATE_CONTROL))) + return -EINVAL; + + if (WARN_ON(!ieee80211_hw_check(hw, AMPDU_AGGREGATION))) + return -EINVAL; + + if (WARN_ON(ieee80211_hw_check(hw, HOST_BROADCAST_PS_BUFFERING))) + return -EINVAL; + + if (WARN_ON(ieee80211_hw_check(hw, SUPPORTS_PS) && + (!ieee80211_hw_check(hw, SUPPORTS_DYNAMIC_PS) || + ieee80211_hw_check(hw, PS_NULLFUNC_STACK)))) + return -EINVAL; + + if (WARN_ON(!ieee80211_hw_check(hw, MFP_CAPABLE))) + return -EINVAL; + + if (WARN_ON(!ieee80211_hw_check(hw, CONNECTION_MONITOR))) + return -EINVAL; + + if (WARN_ON(ieee80211_hw_check(hw, NEED_DTIM_BEFORE_ASSOC))) + return -EINVAL; + + if (WARN_ON(ieee80211_hw_check(hw, TIMING_BEACON_ONLY))) + return -EINVAL; + + if (WARN_ON(!ieee80211_hw_check(hw, AP_LINK_PS))) + return -EINVAL; + + if (WARN_ON(ieee80211_hw_check(hw, DEAUTH_NEED_MGD_TX_PREP))) + return -EINVAL; + } + +#ifdef CONFIG_PM + if (hw->wiphy->wowlan && (!local->ops->suspend || !local->ops->resume)) + return -EINVAL; +#endif + + if (!local->use_chanctx) { + for (i = 0; i < local->hw.wiphy->n_iface_combinations; i++) { + const struct ieee80211_iface_combination *comb; + + comb = &local->hw.wiphy->iface_combinations[i]; + + if (comb->num_different_channels > 1) + return -EINVAL; + } + } else { + /* DFS is not supported with multi-channel combinations yet */ + for (i = 0; i < local->hw.wiphy->n_iface_combinations; i++) { + const struct ieee80211_iface_combination *comb; + + comb = &local->hw.wiphy->iface_combinations[i]; + + if (comb->radar_detect_widths && + comb->num_different_channels > 1) + return -EINVAL; + } + } + + /* Only HW csum features are currently compatible with mac80211 */ + if (WARN_ON(hw->netdev_features & ~MAC80211_SUPPORTED_FEATURES)) + return -EINVAL; + + if (hw->max_report_rates == 0) + hw->max_report_rates = hw->max_rates; + + local->rx_chains = 1; + + /* + * generic code guarantees at least one band, + * set this very early because much code assumes + * that hw.conf.channel is assigned + */ + channels = 0; + max_bitrates = 0; + supp_ht = false; + supp_vht = false; + supp_he = false; + supp_eht = false; + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband; + + sband = local->hw.wiphy->bands[band]; + if (!sband) + continue; + + if (!dflt_chandef.chan) { + /* + * Assign the first enabled channel to dflt_chandef + * from the list of channels + */ + for (i = 0; i < sband->n_channels; i++) + if (!(sband->channels[i].flags & + IEEE80211_CHAN_DISABLED)) + break; + /* if none found then use the first anyway */ + if (i == sband->n_channels) + i = 0; + cfg80211_chandef_create(&dflt_chandef, + &sband->channels[i], + NL80211_CHAN_NO_HT); + /* init channel we're on */ + if (!local->use_chanctx && !local->_oper_chandef.chan) { + local->hw.conf.chandef = dflt_chandef; + local->_oper_chandef = dflt_chandef; + } + local->monitor_chandef = dflt_chandef; + } + + channels += sband->n_channels; + + if (max_bitrates < sband->n_bitrates) + max_bitrates = sband->n_bitrates; + supp_ht = supp_ht || sband->ht_cap.ht_supported; + supp_vht = supp_vht || sband->vht_cap.vht_supported; + + for (i = 0; i < sband->n_iftype_data; i++) { + const struct ieee80211_sband_iftype_data *iftd; + + iftd = &sband->iftype_data[i]; + + supp_he = supp_he || iftd->he_cap.has_he; + supp_eht = supp_eht || iftd->eht_cap.has_eht; + } + + /* HT, VHT, HE require QoS, thus >= 4 queues */ + 
if (WARN_ON(local->hw.queues < IEEE80211_NUM_ACS && + (supp_ht || supp_vht || supp_he))) + return -EINVAL; + + /* EHT requires HE support */ + if (WARN_ON(supp_eht && !supp_he)) + return -EINVAL; + + if (!sband->ht_cap.ht_supported) + continue; + + /* TODO: consider VHT for RX chains, hopefully it's the same */ + local->rx_chains = + max(ieee80211_mcs_to_chains(&sband->ht_cap.mcs), + local->rx_chains); + + /* no need to mask, SM_PS_DISABLED has all bits set */ + sband->ht_cap.cap |= WLAN_HT_CAP_SM_PS_DISABLED << + IEEE80211_HT_CAP_SM_PS_SHIFT; + } + + /* if low-level driver supports AP, we also support VLAN. + * drivers advertising SW_CRYPTO_CONTROL should enable AP_VLAN + * based on their support to transmit SW encrypted packets. + */ + if (local->hw.wiphy->interface_modes & BIT(NL80211_IFTYPE_AP) && + !ieee80211_hw_check(&local->hw, SW_CRYPTO_CONTROL)) { + hw->wiphy->interface_modes |= BIT(NL80211_IFTYPE_AP_VLAN); + hw->wiphy->software_iftypes |= BIT(NL80211_IFTYPE_AP_VLAN); + } + + /* mac80211 always supports monitor */ + hw->wiphy->interface_modes |= BIT(NL80211_IFTYPE_MONITOR); + hw->wiphy->software_iftypes |= BIT(NL80211_IFTYPE_MONITOR); + + /* mac80211 doesn't support more than one IBSS interface right now */ + for (i = 0; i < hw->wiphy->n_iface_combinations; i++) { + const struct ieee80211_iface_combination *c; + int j; + + c = &hw->wiphy->iface_combinations[i]; + + for (j = 0; j < c->n_limits; j++) + if ((c->limits[j].types & BIT(NL80211_IFTYPE_ADHOC)) && + c->limits[j].max > 1) + return -EINVAL; + } + + local->int_scan_req = kzalloc(sizeof(*local->int_scan_req) + + sizeof(void *) * channels, GFP_KERNEL); + if (!local->int_scan_req) + return -ENOMEM; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + if (!local->hw.wiphy->bands[band]) + continue; + local->int_scan_req->rates[band] = (u32) -1; + } + +#ifndef CONFIG_MAC80211_MESH + /* mesh depends on Kconfig, but drivers should set it if they want */ + local->hw.wiphy->interface_modes &= ~BIT(NL80211_IFTYPE_MESH_POINT); +#endif + + /* if the underlying driver supports mesh, mac80211 will (at least) + * provide routing of mesh authentication frames to userspace */ + if (local->hw.wiphy->interface_modes & BIT(NL80211_IFTYPE_MESH_POINT)) + local->hw.wiphy->flags |= WIPHY_FLAG_MESH_AUTH; + + /* mac80211 supports control port protocol changing */ + local->hw.wiphy->flags |= WIPHY_FLAG_CONTROL_PORT_PROTOCOL; + + if (ieee80211_hw_check(&local->hw, SIGNAL_DBM)) { + local->hw.wiphy->signal_type = CFG80211_SIGNAL_TYPE_MBM; + } else if (ieee80211_hw_check(&local->hw, SIGNAL_UNSPEC)) { + local->hw.wiphy->signal_type = CFG80211_SIGNAL_TYPE_UNSPEC; + if (hw->max_signal <= 0) { + result = -EINVAL; + goto fail_workqueue; + } + } + + /* Mac80211 and therefore all drivers using SW crypto only + * are able to handle PTK rekeys and Extended Key ID. + */ + if (!local->ops->set_key) { + wiphy_ext_feature_set(local->hw.wiphy, + NL80211_EXT_FEATURE_CAN_REPLACE_PTK0); + wiphy_ext_feature_set(local->hw.wiphy, + NL80211_EXT_FEATURE_EXT_KEY_ID); + } + + if (local->hw.wiphy->interface_modes & BIT(NL80211_IFTYPE_ADHOC)) + wiphy_ext_feature_set(local->hw.wiphy, + NL80211_EXT_FEATURE_DEL_IBSS_STA); + + /* + * Calculate scan IE length -- we need this to alloc + * memory and to subtract from the driver limit. It + * includes the DS Params, (extended) supported rates, and HT + * information -- SSID is the driver's responsibility. 
+ */ + local->scan_ies_len = 4 + max_bitrates /* (ext) supp rates */ + + 3 /* DS Params */; + if (supp_ht) + local->scan_ies_len += 2 + sizeof(struct ieee80211_ht_cap); + + if (supp_vht) + local->scan_ies_len += + 2 + sizeof(struct ieee80211_vht_cap); + + /* + * HE cap element is variable in size - set len to allow max size */ + if (supp_he) { + local->scan_ies_len += + 3 + sizeof(struct ieee80211_he_cap_elem) + + sizeof(struct ieee80211_he_mcs_nss_supp) + + IEEE80211_HE_PPE_THRES_MAX_LEN; + + if (supp_eht) + local->scan_ies_len += + 3 + sizeof(struct ieee80211_eht_cap_elem) + + sizeof(struct ieee80211_eht_mcs_nss_supp) + + IEEE80211_EHT_PPE_THRES_MAX_LEN; + } + + if (!local->ops->hw_scan) { + /* For hw_scan, driver needs to set these up. */ + local->hw.wiphy->max_scan_ssids = 4; + local->hw.wiphy->max_scan_ie_len = IEEE80211_MAX_DATA_LEN; + } + + /* + * If the driver supports any scan IEs, then assume the + * limit includes the IEs mac80211 will add, otherwise + * leave it at zero and let the driver sort it out; we + * still pass our IEs to the driver but userspace will + * not be allowed to in that case. + */ + if (local->hw.wiphy->max_scan_ie_len) + local->hw.wiphy->max_scan_ie_len -= local->scan_ies_len; + + result = ieee80211_init_cipher_suites(local); + if (result < 0) + goto fail_workqueue; + + if (!local->ops->remain_on_channel) + local->hw.wiphy->max_remain_on_channel_duration = 5000; + + /* mac80211 based drivers don't support internal TDLS setup */ + if (local->hw.wiphy->flags & WIPHY_FLAG_SUPPORTS_TDLS) + local->hw.wiphy->flags |= WIPHY_FLAG_TDLS_EXTERNAL_SETUP; + + /* mac80211 supports eCSA, if the driver supports STA CSA at all */ + if (ieee80211_hw_check(&local->hw, CHANCTX_STA_CSA)) + local->ext_capa[0] |= WLAN_EXT_CAPA1_EXT_CHANNEL_SWITCHING; + + /* mac80211 supports multi BSSID, if the driver supports it */ + if (ieee80211_hw_check(&local->hw, SUPPORTS_MULTI_BSSID)) { + local->hw.wiphy->support_mbssid = true; + if (ieee80211_hw_check(&local->hw, + SUPPORTS_ONLY_HE_MULTI_BSSID)) + local->hw.wiphy->support_only_he_mbssid = true; + else + local->ext_capa[2] |= + WLAN_EXT_CAPA3_MULTI_BSSID_SUPPORT; + } + + local->hw.wiphy->max_num_csa_counters = IEEE80211_MAX_CNTDWN_COUNTERS_NUM; + + /* + * We use the number of queues for feature tests (QoS, HT) internally + * so restrict them appropriately. + */ + if (hw->queues > IEEE80211_MAX_QUEUES) + hw->queues = IEEE80211_MAX_QUEUES; + + local->workqueue = + alloc_ordered_workqueue("%s", 0, wiphy_name(local->hw.wiphy)); + if (!local->workqueue) { + result = -ENOMEM; + goto fail_workqueue; + } + + /* + * The hardware needs headroom for sending the frame, + * and we need some headroom for passing the frame to monitor + * interfaces, but never both at the same time. 
+ */ + local->tx_headroom = max_t(unsigned int , local->hw.extra_tx_headroom, + IEEE80211_TX_STATUS_HEADROOM); + + /* + * if the driver doesn't specify a max listen interval we + * use 5 which should be a safe default + */ + if (local->hw.max_listen_interval == 0) + local->hw.max_listen_interval = 5; + + local->hw.conf.listen_interval = local->hw.max_listen_interval; + + local->dynamic_ps_forced_timeout = -1; + + if (!local->hw.max_nan_de_entries) + local->hw.max_nan_de_entries = IEEE80211_MAX_NAN_INSTANCE_ID; + + if (!local->hw.weight_multiplier) + local->hw.weight_multiplier = 1; + + ieee80211_wep_init(local); + + local->hw.conf.flags = IEEE80211_CONF_IDLE; + + ieee80211_led_init(local); + + result = ieee80211_txq_setup_flows(local); + if (result) + goto fail_flows; + + rtnl_lock(); + result = ieee80211_init_rate_ctrl_alg(local, + hw->rate_control_algorithm); + rtnl_unlock(); + if (result < 0) { + wiphy_debug(local->hw.wiphy, + "Failed to initialize rate control algorithm\n"); + goto fail_rate; + } + + if (local->rate_ctrl) { + clear_bit(IEEE80211_HW_SUPPORTS_VHT_EXT_NSS_BW, hw->flags); + if (local->rate_ctrl->ops->capa & RATE_CTRL_CAPA_VHT_EXT_NSS_BW) + ieee80211_hw_set(hw, SUPPORTS_VHT_EXT_NSS_BW); + } + + /* + * If the VHT capabilities don't have IEEE80211_VHT_EXT_NSS_BW_CAPABLE, + * or have it when we don't, copy the sband structure and set/clear it. + * This is necessary because rate scaling algorithms could be switched + * and have different support values. + * Print a message so that in the common case the reallocation can be + * avoided. + */ + BUILD_BUG_ON(NUM_NL80211_BANDS > 8 * sizeof(local->sband_allocated)); + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband; + bool local_cap, ie_cap; + + local_cap = ieee80211_hw_check(hw, SUPPORTS_VHT_EXT_NSS_BW); + + sband = local->hw.wiphy->bands[band]; + if (!sband || !sband->vht_cap.vht_supported) + continue; + + ie_cap = !!(sband->vht_cap.vht_mcs.tx_highest & + cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE)); + + if (local_cap == ie_cap) + continue; + + sband = kmemdup(sband, sizeof(*sband), GFP_KERNEL); + if (!sband) { + result = -ENOMEM; + goto fail_rate; + } + + wiphy_dbg(hw->wiphy, "copying sband (band %d) due to VHT EXT NSS BW flag\n", + band); + + sband->vht_cap.vht_mcs.tx_highest ^= + cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE); + + local->hw.wiphy->bands[band] = sband; + local->sband_allocated |= BIT(band); + } + + result = wiphy_register(local->hw.wiphy); + if (result < 0) + goto fail_wiphy_register; + + debugfs_hw_add(local); + rate_control_add_debugfs(local); + + rtnl_lock(); + wiphy_lock(hw->wiphy); + + /* add one default STA interface if supported */ + if (local->hw.wiphy->interface_modes & BIT(NL80211_IFTYPE_STATION) && + !ieee80211_hw_check(hw, NO_AUTO_VIF)) { + struct vif_params params = {0}; + + result = ieee80211_if_add(local, "wlan%d", NET_NAME_ENUM, NULL, + NL80211_IFTYPE_STATION, &params); + if (result) + wiphy_warn(local->hw.wiphy, + "Failed to add default virtual iface\n"); + } + + wiphy_unlock(hw->wiphy); + rtnl_unlock(); + +#ifdef CONFIG_INET + local->ifa_notifier.notifier_call = ieee80211_ifa_changed; + result = register_inetaddr_notifier(&local->ifa_notifier); + if (result) + goto fail_ifa; +#endif + +#if IS_ENABLED(CONFIG_IPV6) + local->ifa6_notifier.notifier_call = ieee80211_ifa6_changed; + result = register_inet6addr_notifier(&local->ifa6_notifier); + if (result) + goto fail_ifa6; +#endif + + return 0; + +#if IS_ENABLED(CONFIG_IPV6) + fail_ifa6: +#ifdef CONFIG_INET + 
unregister_inetaddr_notifier(&local->ifa_notifier); +#endif +#endif +#if defined(CONFIG_INET) || defined(CONFIG_IPV6) + fail_ifa: +#endif + wiphy_unregister(local->hw.wiphy); + fail_wiphy_register: + rtnl_lock(); + rate_control_deinitialize(local); + ieee80211_remove_interfaces(local); + rtnl_unlock(); + fail_rate: + fail_flows: + ieee80211_led_exit(local); + destroy_workqueue(local->workqueue); + fail_workqueue: + if (local->wiphy_ciphers_allocated) { + kfree(local->hw.wiphy->cipher_suites); + local->wiphy_ciphers_allocated = false; + } + kfree(local->int_scan_req); + return result; +} +EXPORT_SYMBOL(ieee80211_register_hw); + +void ieee80211_unregister_hw(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + tasklet_kill(&local->tx_pending_tasklet); + tasklet_kill(&local->tasklet); + +#ifdef CONFIG_INET + unregister_inetaddr_notifier(&local->ifa_notifier); +#endif +#if IS_ENABLED(CONFIG_IPV6) + unregister_inet6addr_notifier(&local->ifa6_notifier); +#endif + + rtnl_lock(); + + /* + * At this point, interface list manipulations are fine + * because the driver cannot be handing us frames any + * more and the tasklet is killed. + */ + ieee80211_remove_interfaces(local); + + wiphy_lock(local->hw.wiphy); + wiphy_delayed_work_cancel(local->hw.wiphy, &local->roc_work); + wiphy_work_cancel(local->hw.wiphy, &local->sched_scan_stopped_work); + wiphy_work_cancel(local->hw.wiphy, &local->radar_detected_work); + wiphy_unlock(local->hw.wiphy); + rtnl_unlock(); + + cancel_work_sync(&local->restart_work); + cancel_work_sync(&local->reconfig_filter); + + ieee80211_clear_tx_pending(local); + rate_control_deinitialize(local); + + if (skb_queue_len(&local->skb_queue) || + skb_queue_len(&local->skb_queue_unreliable)) + wiphy_warn(local->hw.wiphy, "skb_queue not empty\n"); + skb_queue_purge(&local->skb_queue); + skb_queue_purge(&local->skb_queue_unreliable); + + wiphy_unregister(local->hw.wiphy); + destroy_workqueue(local->workqueue); + ieee80211_led_exit(local); + kfree(local->int_scan_req); +} +EXPORT_SYMBOL(ieee80211_unregister_hw); + +static int ieee80211_free_ack_frame(int id, void *p, void *data) +{ + WARN_ONCE(1, "Have pending ack frames!\n"); + kfree_skb(p); + return 0; +} + +void ieee80211_free_hw(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + enum nl80211_band band; + + mutex_destroy(&local->iflist_mtx); + mutex_destroy(&local->mtx); + + if (local->wiphy_ciphers_allocated) { + kfree(local->hw.wiphy->cipher_suites); + local->wiphy_ciphers_allocated = false; + } + + idr_for_each(&local->ack_status_frames, + ieee80211_free_ack_frame, NULL); + idr_destroy(&local->ack_status_frames); + + sta_info_stop(local); + + ieee80211_free_led_names(local); + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + if (!(local->sband_allocated & BIT(band))) + continue; + kfree(local->hw.wiphy->bands[band]); + } + + wiphy_free(local->hw.wiphy); +} +EXPORT_SYMBOL(ieee80211_free_hw); + +static int __init ieee80211_init(void) +{ + struct sk_buff *skb; + int ret; + + BUILD_BUG_ON(sizeof(struct ieee80211_tx_info) > sizeof(skb->cb)); + BUILD_BUG_ON(offsetof(struct ieee80211_tx_info, driver_data) + + IEEE80211_TX_INFO_DRIVER_DATA_SIZE > sizeof(skb->cb)); + + ret = rc80211_minstrel_init(); + if (ret) + return ret; + + ret = ieee80211_iface_init(); + if (ret) + goto err_netdev; + + return 0; + err_netdev: + rc80211_minstrel_exit(); + + return ret; +} + +static void __exit ieee80211_exit(void) +{ + rc80211_minstrel_exit(); + + ieee80211s_stop(); + + 
ieee80211_iface_exit(); + + rcu_barrier(); +} + + +subsys_initcall(ieee80211_init); +module_exit(ieee80211_exit); + +MODULE_DESCRIPTION("IEEE 802.11 subsystem"); +MODULE_LICENSE("GPL"); diff --git a/net/mac80211/mesh.c b/net/mac80211/mesh.c new file mode 100644 index 000000000..5a99b8f6e --- /dev/null +++ b/net/mac80211/mesh.c @@ -0,0 +1,1650 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, 2009 open80211s Ltd. + * Copyright (C) 2018 - 2022 Intel Corporation + * Authors: Luis Carlos Cobo + * Javier Cardona + */ + +#include <linux/slab.h> +#include <asm/unaligned.h> +#include "ieee80211_i.h" +#include "mesh.h" +#include "driver-ops.h" + +static int mesh_allocated; +static struct kmem_cache *rm_cache; + +bool mesh_action_is_path_sel(struct ieee80211_mgmt *mgmt) +{ + return (mgmt->u.action.u.mesh_action.action_code == + WLAN_MESH_ACTION_HWMP_PATH_SELECTION); +} + +void ieee80211s_init(void) +{ + mesh_allocated = 1; + rm_cache = kmem_cache_create("mesh_rmc", sizeof(struct rmc_entry), + 0, 0, NULL); +} + +void ieee80211s_stop(void) +{ + if (!mesh_allocated) + return; + kmem_cache_destroy(rm_cache); +} + +static void ieee80211_mesh_housekeeping_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mesh.housekeeping_timer); + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + set_bit(MESH_WORK_HOUSEKEEPING, &ifmsh->wrkq_flags); + + ieee80211_queue_work(&local->hw, &sdata->work); +} + +/** + * mesh_matches_local - check if the config of a mesh point matches ours + * + * @sdata: local mesh subif + * @ie: information elements of a management frame from the mesh peer + * + * This function checks if the mesh configuration of a mesh point matches the + * local mesh configuration, i.e. if both nodes belong to the same mesh network. 
+ */ +bool mesh_matches_local(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *ie) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u32 basic_rates = 0; + struct cfg80211_chan_def sta_chan_def; + struct ieee80211_supported_band *sband; + u32 vht_cap_info = 0; + + /* + * As support for each feature is added, check for matching + * - On mesh config capabilities + * - Power Save Support En + * - Sync support enabled + * - Sync support active + * - Sync support required from peer + * - MDA enabled + * - Power management control on fc + */ + if (!(ifmsh->mesh_id_len == ie->mesh_id_len && + memcmp(ifmsh->mesh_id, ie->mesh_id, ie->mesh_id_len) == 0 && + (ifmsh->mesh_pp_id == ie->mesh_config->meshconf_psel) && + (ifmsh->mesh_pm_id == ie->mesh_config->meshconf_pmetric) && + (ifmsh->mesh_cc_id == ie->mesh_config->meshconf_congest) && + (ifmsh->mesh_sp_id == ie->mesh_config->meshconf_synch) && + (ifmsh->mesh_auth_id == ie->mesh_config->meshconf_auth))) + return false; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return false; + + ieee80211_sta_get_rates(sdata, ie, sband->band, + &basic_rates); + + if (sdata->vif.bss_conf.basic_rates != basic_rates) + return false; + + cfg80211_chandef_create(&sta_chan_def, sdata->vif.bss_conf.chandef.chan, + NL80211_CHAN_NO_HT); + ieee80211_chandef_ht_oper(ie->ht_operation, &sta_chan_def); + + if (ie->vht_cap_elem) + vht_cap_info = le32_to_cpu(ie->vht_cap_elem->vht_cap_info); + + ieee80211_chandef_vht_oper(&sdata->local->hw, vht_cap_info, + ie->vht_operation, ie->ht_operation, + &sta_chan_def); + ieee80211_chandef_he_6ghz_oper(sdata, ie->he_operation, NULL, + &sta_chan_def); + + if (!cfg80211_chandef_compatible(&sdata->vif.bss_conf.chandef, + &sta_chan_def)) + return false; + + return true; +} + +/** + * mesh_peer_accepts_plinks - check if an mp is willing to establish peer links + * + * @ie: information elements of a management frame from the mesh peer + */ +bool mesh_peer_accepts_plinks(struct ieee802_11_elems *ie) +{ + return (ie->mesh_config->meshconf_cap & + IEEE80211_MESHCONF_CAPAB_ACCEPT_PLINKS) != 0; +} + +/** + * mesh_accept_plinks_update - update accepting_plink in local mesh beacons + * + * @sdata: mesh interface in which mesh beacons are going to be updated + * + * Returns: beacon changed flag if the beacon content changed. + */ +u32 mesh_accept_plinks_update(struct ieee80211_sub_if_data *sdata) +{ + bool free_plinks; + u32 changed = 0; + + /* In case mesh_plink_free_count > 0 and mesh_plinktbl_capacity == 0, + * the mesh interface might be able to establish plinks with peers that + * are already on the table but are not on PLINK_ESTAB state. However, + * in general the mesh interface is not accepting peer link requests + * from new peers, and that must be reflected in the beacon + */ + free_plinks = mesh_plink_availables(sdata); + + if (free_plinks != sdata->u.mesh.accepting_plinks) { + sdata->u.mesh.accepting_plinks = free_plinks; + changed = BSS_CHANGED_BEACON; + } + + return changed; +} + +/* + * mesh_sta_cleanup - clean up any mesh sta state + * + * @sta: mesh sta to clean up. 
+ */ +void mesh_sta_cleanup(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 changed = mesh_plink_deactivate(sta); + + if (changed) + ieee80211_mbss_info_change_notify(sdata, changed); +} + +int mesh_rmc_init(struct ieee80211_sub_if_data *sdata) +{ + int i; + + sdata->u.mesh.rmc = kmalloc(sizeof(struct mesh_rmc), GFP_KERNEL); + if (!sdata->u.mesh.rmc) + return -ENOMEM; + sdata->u.mesh.rmc->idx_mask = RMC_BUCKETS - 1; + for (i = 0; i < RMC_BUCKETS; i++) + INIT_HLIST_HEAD(&sdata->u.mesh.rmc->bucket[i]); + return 0; +} + +void mesh_rmc_free(struct ieee80211_sub_if_data *sdata) +{ + struct mesh_rmc *rmc = sdata->u.mesh.rmc; + struct rmc_entry *p; + struct hlist_node *n; + int i; + + if (!sdata->u.mesh.rmc) + return; + + for (i = 0; i < RMC_BUCKETS; i++) { + hlist_for_each_entry_safe(p, n, &rmc->bucket[i], list) { + hlist_del(&p->list); + kmem_cache_free(rm_cache, p); + } + } + + kfree(rmc); + sdata->u.mesh.rmc = NULL; +} + +/** + * mesh_rmc_check - Check frame in recent multicast cache and add if absent. + * + * @sdata: interface + * @sa: source address + * @mesh_hdr: mesh_header + * + * Returns: 0 if the frame is not in the cache, nonzero otherwise. + * + * Checks using the source address and the mesh sequence number if we have + * received this frame lately. If the frame is not in the cache, it is added to + * it. + */ +int mesh_rmc_check(struct ieee80211_sub_if_data *sdata, + const u8 *sa, struct ieee80211s_hdr *mesh_hdr) +{ + struct mesh_rmc *rmc = sdata->u.mesh.rmc; + u32 seqnum = 0; + int entries = 0; + u8 idx; + struct rmc_entry *p; + struct hlist_node *n; + + if (!rmc) + return -1; + + /* Don't care about endianness since only match matters */ + memcpy(&seqnum, &mesh_hdr->seqnum, sizeof(mesh_hdr->seqnum)); + idx = le32_to_cpu(mesh_hdr->seqnum) & rmc->idx_mask; + hlist_for_each_entry_safe(p, n, &rmc->bucket[idx], list) { + ++entries; + if (time_after(jiffies, p->exp_time) || + entries == RMC_QUEUE_MAX_LEN) { + hlist_del(&p->list); + kmem_cache_free(rm_cache, p); + --entries; + } else if ((seqnum == p->seqnum) && ether_addr_equal(sa, p->sa)) + return -1; + } + + p = kmem_cache_alloc(rm_cache, GFP_ATOMIC); + if (!p) + return 0; + + p->seqnum = seqnum; + p->exp_time = jiffies + RMC_TIMEOUT; + memcpy(p->sa, sa, ETH_ALEN); + hlist_add_head(&p->list, &rmc->bucket[idx]); + return 0; +} + +int mesh_add_meshconf_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u8 *pos, neighbors; + u8 meshconf_len = sizeof(struct ieee80211_meshconf_ie); + bool is_connected_to_gate = ifmsh->num_gates > 0 || + ifmsh->mshcfg.dot11MeshGateAnnouncementProtocol || + ifmsh->mshcfg.dot11MeshConnectedToMeshGate; + bool is_connected_to_as = ifmsh->mshcfg.dot11MeshConnectedToAuthServer; + + if (skb_tailroom(skb) < 2 + meshconf_len) + return -ENOMEM; + + pos = skb_put(skb, 2 + meshconf_len); + *pos++ = WLAN_EID_MESH_CONFIG; + *pos++ = meshconf_len; + + /* save a pointer for quick updates in pre-tbtt */ + ifmsh->meshconf_offset = pos - skb->data; + + /* Active path selection protocol ID */ + *pos++ = ifmsh->mesh_pp_id; + /* Active path selection metric ID */ + *pos++ = ifmsh->mesh_pm_id; + /* Congestion control mode identifier */ + *pos++ = ifmsh->mesh_cc_id; + /* Synchronization protocol identifier */ + *pos++ = ifmsh->mesh_sp_id; + /* Authentication Protocol identifier */ + *pos++ = ifmsh->mesh_auth_id; + /* Mesh Formation Info - number of neighbors */ + neighbors = atomic_read(&ifmsh->estab_plinks); + neighbors = 
min_t(int, neighbors, IEEE80211_MAX_MESH_PEERINGS); + *pos++ = (is_connected_to_as << 7) | + (neighbors << 1) | + is_connected_to_gate; + /* Mesh capability */ + *pos = 0x00; + *pos |= ifmsh->mshcfg.dot11MeshForwarding ? + IEEE80211_MESHCONF_CAPAB_FORWARDING : 0x00; + *pos |= ifmsh->accepting_plinks ? + IEEE80211_MESHCONF_CAPAB_ACCEPT_PLINKS : 0x00; + /* Mesh PS mode. See IEEE802.11-2012 8.4.2.100.8 */ + *pos |= ifmsh->ps_peers_deep_sleep ? + IEEE80211_MESHCONF_CAPAB_POWER_SAVE_LEVEL : 0x00; + return 0; +} + +int mesh_add_meshid_ie(struct ieee80211_sub_if_data *sdata, struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u8 *pos; + + if (skb_tailroom(skb) < 2 + ifmsh->mesh_id_len) + return -ENOMEM; + + pos = skb_put(skb, 2 + ifmsh->mesh_id_len); + *pos++ = WLAN_EID_MESH_ID; + *pos++ = ifmsh->mesh_id_len; + if (ifmsh->mesh_id_len) + memcpy(pos, ifmsh->mesh_id, ifmsh->mesh_id_len); + + return 0; +} + +static int mesh_add_awake_window_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u8 *pos; + + /* see IEEE802.11-2012 13.14.6 */ + if (ifmsh->ps_peers_light_sleep == 0 && + ifmsh->ps_peers_deep_sleep == 0 && + ifmsh->nonpeer_pm == NL80211_MESH_POWER_ACTIVE) + return 0; + + if (skb_tailroom(skb) < 4) + return -ENOMEM; + + pos = skb_put(skb, 2 + 2); + *pos++ = WLAN_EID_MESH_AWAKE_WINDOW; + *pos++ = 2; + put_unaligned_le16(ifmsh->mshcfg.dot11MeshAwakeWindowDuration, pos); + + return 0; +} + +int mesh_add_vendor_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u8 offset, len; + const u8 *data; + + if (!ifmsh->ie || !ifmsh->ie_len) + return 0; + + /* fast-forward to vendor IEs */ + offset = ieee80211_ie_split_vendor(ifmsh->ie, ifmsh->ie_len, 0); + + if (offset < ifmsh->ie_len) { + len = ifmsh->ie_len - offset; + data = ifmsh->ie + offset; + if (skb_tailroom(skb) < len) + return -ENOMEM; + skb_put_data(skb, data, len); + } + + return 0; +} + +int mesh_add_rsn_ie(struct ieee80211_sub_if_data *sdata, struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u8 len = 0; + const u8 *data; + + if (!ifmsh->ie || !ifmsh->ie_len) + return 0; + + /* find RSN IE */ + data = cfg80211_find_ie(WLAN_EID_RSN, ifmsh->ie, ifmsh->ie_len); + if (!data) + return 0; + + len = data[1] + 2; + + if (skb_tailroom(skb) < len) + return -ENOMEM; + skb_put_data(skb, data, len); + + return 0; +} + +static int mesh_add_ds_params_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_channel *chan; + u8 *pos; + + if (skb_tailroom(skb) < 3) + return -ENOMEM; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + return -EINVAL; + } + chan = chanctx_conf->def.chan; + rcu_read_unlock(); + + pos = skb_put(skb, 2 + 1); + *pos++ = WLAN_EID_DS_PARAMS; + *pos++ = 1; + *pos++ = ieee80211_frequency_to_channel(chan->center_freq); + + return 0; +} + +int mesh_add_ht_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_supported_band *sband; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + /* HT not allowed in 6 GHz */ + if (sband->band == NL80211_BAND_6GHZ) + return 0; + + if (!sband->ht_cap.ht_supported || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == 
NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + if (skb_tailroom(skb) < 2 + sizeof(struct ieee80211_ht_cap)) + return -ENOMEM; + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_ht_cap)); + ieee80211_ie_build_ht_cap(pos, &sband->ht_cap, sband->ht_cap.cap); + + return 0; +} + +int mesh_add_ht_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_channel *channel; + struct ieee80211_supported_band *sband; + struct ieee80211_sta_ht_cap *ht_cap; + u8 *pos; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + return -EINVAL; + } + channel = chanctx_conf->def.chan; + rcu_read_unlock(); + + sband = local->hw.wiphy->bands[channel->band]; + ht_cap = &sband->ht_cap; + + /* HT not allowed in 6 GHz */ + if (sband->band == NL80211_BAND_6GHZ) + return 0; + + if (!ht_cap->ht_supported || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + if (skb_tailroom(skb) < 2 + sizeof(struct ieee80211_ht_operation)) + return -ENOMEM; + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_ht_operation)); + ieee80211_ie_build_ht_oper(pos, ht_cap, &sdata->vif.bss_conf.chandef, + sdata->vif.bss_conf.ht_operation_mode, + false); + + return 0; +} + +int mesh_add_vht_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_supported_band *sband; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + /* VHT not allowed in 6 GHz */ + if (sband->band == NL80211_BAND_6GHZ) + return 0; + + if (!sband->vht_cap.vht_supported || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + if (skb_tailroom(skb) < 2 + sizeof(struct ieee80211_vht_cap)) + return -ENOMEM; + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_vht_cap)); + ieee80211_ie_build_vht_cap(pos, &sband->vht_cap, sband->vht_cap.cap); + + return 0; +} + +int mesh_add_vht_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_channel *channel; + struct ieee80211_supported_band *sband; + struct ieee80211_sta_vht_cap *vht_cap; + u8 *pos; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + return -EINVAL; + } + channel = chanctx_conf->def.chan; + rcu_read_unlock(); + + sband = local->hw.wiphy->bands[channel->band]; + vht_cap = &sband->vht_cap; + + /* VHT not allowed in 6 GHz */ + if (sband->band == NL80211_BAND_6GHZ) + return 0; + + if (!vht_cap->vht_supported || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + if (skb_tailroom(skb) < 2 + sizeof(struct ieee80211_vht_operation)) + return -ENOMEM; + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_vht_operation)); + ieee80211_ie_build_vht_oper(pos, vht_cap, + &sdata->vif.bss_conf.chandef); + + return 0; +} + +int 
mesh_add_he_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u8 ie_len) +{ + const struct ieee80211_sta_he_cap *he_cap; + struct ieee80211_supported_band *sband; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + he_cap = ieee80211_get_he_iftype_cap(sband, NL80211_IFTYPE_MESH_POINT); + + if (!he_cap || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + if (skb_tailroom(skb) < ie_len) + return -ENOMEM; + + pos = skb_put(skb, ie_len); + ieee80211_ie_build_he_cap(0, pos, he_cap, pos + ie_len); + + return 0; +} + +int mesh_add_he_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + const struct ieee80211_sta_he_cap *he_cap; + struct ieee80211_supported_band *sband; + u32 len; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + he_cap = ieee80211_get_he_iftype_cap(sband, NL80211_IFTYPE_MESH_POINT); + if (!he_cap || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_10) + return 0; + + len = 2 + 1 + sizeof(struct ieee80211_he_operation); + if (sdata->vif.bss_conf.chandef.chan->band == NL80211_BAND_6GHZ) + len += sizeof(struct ieee80211_he_6ghz_oper); + + if (skb_tailroom(skb) < len) + return -ENOMEM; + + pos = skb_put(skb, len); + ieee80211_ie_build_he_oper(pos, &sdata->vif.bss_conf.chandef); + + return 0; +} + +int mesh_add_he_6ghz_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_supported_band *sband; + const struct ieee80211_sband_iftype_data *iftd; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return -EINVAL; + + iftd = ieee80211_get_sband_iftype_data(sband, + NL80211_IFTYPE_MESH_POINT); + /* The device doesn't support HE in mesh mode or at all */ + if (!iftd) + return 0; + + ieee80211_ie_build_he_6ghz_cap(sdata, sdata->deflink.smps_mode, skb); + return 0; +} + +static void ieee80211_mesh_path_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mesh.mesh_path_timer); + + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +static void ieee80211_mesh_path_root_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mesh.mesh_path_root_timer); + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + set_bit(MESH_WORK_ROOT, &ifmsh->wrkq_flags); + + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +void ieee80211_mesh_root_setup(struct ieee80211_if_mesh *ifmsh) +{ + if (ifmsh->mshcfg.dot11MeshHWMPRootMode > IEEE80211_ROOTMODE_ROOT) + set_bit(MESH_WORK_ROOT, &ifmsh->wrkq_flags); + else { + clear_bit(MESH_WORK_ROOT, &ifmsh->wrkq_flags); + /* stop running timer */ + del_timer_sync(&ifmsh->mesh_path_root_timer); + } +} + +static void +ieee80211_mesh_update_bss_params(struct ieee80211_sub_if_data *sdata, + u8 *ie, u8 ie_len) +{ + struct ieee80211_supported_band *sband; + const struct element *cap; + const struct ieee80211_he_operation *he_oper = NULL; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return; + + if (!ieee80211_get_he_iftype_cap(sband, NL80211_IFTYPE_MESH_POINT) || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_20_NOHT || + sdata->vif.bss_conf.chandef.width == NL80211_CHAN_WIDTH_5 || + sdata->vif.bss_conf.chandef.width 
== NL80211_CHAN_WIDTH_10) + return; + + sdata->vif.bss_conf.he_support = true; + + cap = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_OPERATION, ie, ie_len); + if (cap && cap->datalen >= 1 + sizeof(*he_oper) && + cap->datalen >= 1 + ieee80211_he_oper_size(cap->data + 1)) + he_oper = (void *)(cap->data + 1); + + if (he_oper) + sdata->vif.bss_conf.he_oper.params = + __le32_to_cpu(he_oper->he_oper_params); +} + +/** + * ieee80211_fill_mesh_addresses - fill addresses of a locally originated mesh frame + * @hdr: 802.11 frame header + * @fc: frame control field + * @meshda: destination address in the mesh + * @meshsa: source address in the mesh. Same as TA, as frame is + * locally originated. + * + * Return the length of the 802.11 (does not include a mesh control header) + */ +int ieee80211_fill_mesh_addresses(struct ieee80211_hdr *hdr, __le16 *fc, + const u8 *meshda, const u8 *meshsa) +{ + if (is_multicast_ether_addr(meshda)) { + *fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS); + /* DA TA SA */ + memcpy(hdr->addr1, meshda, ETH_ALEN); + memcpy(hdr->addr2, meshsa, ETH_ALEN); + memcpy(hdr->addr3, meshsa, ETH_ALEN); + return 24; + } else { + *fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS | IEEE80211_FCTL_TODS); + /* RA TA DA SA */ + eth_zero_addr(hdr->addr1); /* RA is resolved later */ + memcpy(hdr->addr2, meshsa, ETH_ALEN); + memcpy(hdr->addr3, meshda, ETH_ALEN); + memcpy(hdr->addr4, meshsa, ETH_ALEN); + return 30; + } +} + +/** + * ieee80211_new_mesh_header - create a new mesh header + * @sdata: mesh interface to be used + * @meshhdr: uninitialized mesh header + * @addr4or5: 1st address in the ae header, which may correspond to address 4 + * (if addr6 is NULL) or address 5 (if addr6 is present). It may + * be NULL. + * @addr6: 2nd address in the ae header, which corresponds to addr6 of the + * mesh frame + * + * Return the header length. 
+ */ +unsigned int ieee80211_new_mesh_header(struct ieee80211_sub_if_data *sdata, + struct ieee80211s_hdr *meshhdr, + const char *addr4or5, const char *addr6) +{ + if (WARN_ON(!addr4or5 && addr6)) + return 0; + + memset(meshhdr, 0, sizeof(*meshhdr)); + + meshhdr->ttl = sdata->u.mesh.mshcfg.dot11MeshTTL; + + /* FIXME: racy -- TX on multiple queues can be concurrent */ + put_unaligned(cpu_to_le32(sdata->u.mesh.mesh_seqnum), &meshhdr->seqnum); + sdata->u.mesh.mesh_seqnum++; + + if (addr4or5 && !addr6) { + meshhdr->flags |= MESH_FLAGS_AE_A4; + memcpy(meshhdr->eaddr1, addr4or5, ETH_ALEN); + return 2 * ETH_ALEN; + } else if (addr4or5 && addr6) { + meshhdr->flags |= MESH_FLAGS_AE_A5_A6; + memcpy(meshhdr->eaddr1, addr4or5, ETH_ALEN); + memcpy(meshhdr->eaddr2, addr6, ETH_ALEN); + return 3 * ETH_ALEN; + } + + return ETH_ALEN; +} + +static void ieee80211_mesh_housekeeping(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u32 changed; + + if (ifmsh->mshcfg.plink_timeout > 0) + ieee80211_sta_expire(sdata, ifmsh->mshcfg.plink_timeout * HZ); + mesh_path_expire(sdata); + + changed = mesh_accept_plinks_update(sdata); + ieee80211_mbss_info_change_notify(sdata, changed); + + mod_timer(&ifmsh->housekeeping_timer, + round_jiffies(jiffies + + IEEE80211_MESH_HOUSEKEEPING_INTERVAL)); +} + +static void ieee80211_mesh_rootpath(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u32 interval; + + mesh_path_tx_root_frame(sdata); + + if (ifmsh->mshcfg.dot11MeshHWMPRootMode == IEEE80211_PROACTIVE_RANN) + interval = ifmsh->mshcfg.dot11MeshHWMPRannInterval; + else + interval = ifmsh->mshcfg.dot11MeshHWMProotInterval; + + mod_timer(&ifmsh->mesh_path_root_timer, + round_jiffies(TU_TO_EXP_TIME(interval))); +} + +static int +ieee80211_mesh_build_beacon(struct ieee80211_if_mesh *ifmsh) +{ + struct beacon_data *bcn; + int head_len, tail_len; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + struct ieee80211_chanctx_conf *chanctx_conf; + struct mesh_csa_settings *csa; + enum nl80211_band band; + u8 ie_len_he_cap; + u8 *pos; + struct ieee80211_sub_if_data *sdata; + int hdr_len = offsetofend(struct ieee80211_mgmt, u.beacon); + + sdata = container_of(ifmsh, struct ieee80211_sub_if_data, u.mesh); + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + band = chanctx_conf->def.chan->band; + rcu_read_unlock(); + + ie_len_he_cap = ieee80211_ie_len_he_cap(sdata, + NL80211_IFTYPE_MESH_POINT); + head_len = hdr_len + + 2 + /* NULL SSID */ + /* Channel Switch Announcement */ + 2 + sizeof(struct ieee80211_channel_sw_ie) + + /* Mesh Channel Switch Parameters */ + 2 + sizeof(struct ieee80211_mesh_chansw_params_ie) + + /* Channel Switch Wrapper + Wide Bandwidth CSA IE */ + 2 + 2 + sizeof(struct ieee80211_wide_bw_chansw_ie) + + 2 + sizeof(struct ieee80211_sec_chan_offs_ie) + + 2 + 8 + /* supported rates */ + 2 + 3; /* DS params */ + tail_len = 2 + (IEEE80211_MAX_SUPP_RATES - 8) + + 2 + sizeof(struct ieee80211_ht_cap) + + 2 + sizeof(struct ieee80211_ht_operation) + + 2 + ifmsh->mesh_id_len + + 2 + sizeof(struct ieee80211_meshconf_ie) + + 2 + sizeof(__le16) + /* awake window */ + 2 + sizeof(struct ieee80211_vht_cap) + + 2 + sizeof(struct ieee80211_vht_operation) + + ie_len_he_cap + + 2 + 1 + sizeof(struct ieee80211_he_operation) + + sizeof(struct ieee80211_he_6ghz_oper) + + 2 + 1 + sizeof(struct ieee80211_he_6ghz_capa) + + ifmsh->ie_len; + + bcn = kzalloc(sizeof(*bcn) + head_len + tail_len, GFP_KERNEL); + /* need an skb for 
IE builders to operate on */ + skb = __dev_alloc_skb(max(head_len, tail_len), GFP_KERNEL); + + if (!bcn || !skb) + goto out_free; + + /* + * pointers go into the block we allocated, + * memory is | beacon_data | head | tail | + */ + bcn->head = ((u8 *) bcn) + sizeof(*bcn); + + /* fill in the head */ + mgmt = skb_put_zero(skb, hdr_len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_BEACON); + eth_broadcast_addr(mgmt->da); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + ieee80211_mps_set_frame_flags(sdata, NULL, (void *) mgmt); + mgmt->u.beacon.beacon_int = + cpu_to_le16(sdata->vif.bss_conf.beacon_int); + mgmt->u.beacon.capab_info |= cpu_to_le16( + sdata->u.mesh.security ? WLAN_CAPABILITY_PRIVACY : 0); + + pos = skb_put(skb, 2); + *pos++ = WLAN_EID_SSID; + *pos++ = 0x0; + + rcu_read_lock(); + csa = rcu_dereference(ifmsh->csa); + if (csa) { + enum nl80211_channel_type ct; + struct cfg80211_chan_def *chandef; + int ie_len = 2 + sizeof(struct ieee80211_channel_sw_ie) + + 2 + sizeof(struct ieee80211_mesh_chansw_params_ie); + + pos = skb_put_zero(skb, ie_len); + *pos++ = WLAN_EID_CHANNEL_SWITCH; + *pos++ = 3; + *pos++ = 0x0; + *pos++ = ieee80211_frequency_to_channel( + csa->settings.chandef.chan->center_freq); + bcn->cntdwn_current_counter = csa->settings.count; + bcn->cntdwn_counter_offsets[0] = hdr_len + 6; + *pos++ = csa->settings.count; + *pos++ = WLAN_EID_CHAN_SWITCH_PARAM; + *pos++ = 6; + if (ifmsh->csa_role == IEEE80211_MESH_CSA_ROLE_INIT) { + *pos++ = ifmsh->mshcfg.dot11MeshTTL; + *pos |= WLAN_EID_CHAN_SWITCH_PARAM_INITIATOR; + } else { + *pos++ = ifmsh->chsw_ttl; + } + *pos++ |= csa->settings.block_tx ? + WLAN_EID_CHAN_SWITCH_PARAM_TX_RESTRICT : 0x00; + put_unaligned_le16(WLAN_REASON_MESH_CHAN, pos); + pos += 2; + put_unaligned_le16(ifmsh->pre_value, pos); + pos += 2; + + switch (csa->settings.chandef.width) { + case NL80211_CHAN_WIDTH_40: + ie_len = 2 + sizeof(struct ieee80211_sec_chan_offs_ie); + pos = skb_put_zero(skb, ie_len); + + *pos++ = WLAN_EID_SECONDARY_CHANNEL_OFFSET; /* EID */ + *pos++ = 1; /* len */ + ct = cfg80211_get_chandef_type(&csa->settings.chandef); + if (ct == NL80211_CHAN_HT40PLUS) + *pos++ = IEEE80211_HT_PARAM_CHA_SEC_ABOVE; + else + *pos++ = IEEE80211_HT_PARAM_CHA_SEC_BELOW; + break; + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_160: + /* Channel Switch Wrapper + Wide Bandwidth CSA IE */ + ie_len = 2 + 2 + + sizeof(struct ieee80211_wide_bw_chansw_ie); + pos = skb_put_zero(skb, ie_len); + + *pos++ = WLAN_EID_CHANNEL_SWITCH_WRAPPER; /* EID */ + *pos++ = 5; /* len */ + /* put sub IE */ + chandef = &csa->settings.chandef; + ieee80211_ie_build_wide_bw_cs(pos, chandef); + break; + default: + break; + } + } + rcu_read_unlock(); + + if (ieee80211_add_srates_ie(sdata, skb, true, band) || + mesh_add_ds_params_ie(sdata, skb)) + goto out_free; + + bcn->head_len = skb->len; + memcpy(bcn->head, skb->data, bcn->head_len); + + /* now the tail */ + skb_trim(skb, 0); + bcn->tail = bcn->head + bcn->head_len; + + if (ieee80211_add_ext_srates_ie(sdata, skb, true, band) || + mesh_add_rsn_ie(sdata, skb) || + mesh_add_ht_cap_ie(sdata, skb) || + mesh_add_ht_oper_ie(sdata, skb) || + mesh_add_meshid_ie(sdata, skb) || + mesh_add_meshconf_ie(sdata, skb) || + mesh_add_awake_window_ie(sdata, skb) || + mesh_add_vht_cap_ie(sdata, skb) || + mesh_add_vht_oper_ie(sdata, skb) || + mesh_add_he_cap_ie(sdata, skb, ie_len_he_cap) || + mesh_add_he_oper_ie(sdata, skb) || + 
mesh_add_he_6ghz_cap_ie(sdata, skb) || + mesh_add_vendor_ies(sdata, skb)) + goto out_free; + + bcn->tail_len = skb->len; + memcpy(bcn->tail, skb->data, bcn->tail_len); + ieee80211_mesh_update_bss_params(sdata, bcn->tail, bcn->tail_len); + bcn->meshconf = (struct ieee80211_meshconf_ie *) + (bcn->tail + ifmsh->meshconf_offset); + + dev_kfree_skb(skb); + rcu_assign_pointer(ifmsh->beacon, bcn); + return 0; +out_free: + kfree(bcn); + dev_kfree_skb(skb); + return -ENOMEM; +} + +static int +ieee80211_mesh_rebuild_beacon(struct ieee80211_sub_if_data *sdata) +{ + struct beacon_data *old_bcn; + int ret; + + old_bcn = sdata_dereference(sdata->u.mesh.beacon, sdata); + ret = ieee80211_mesh_build_beacon(&sdata->u.mesh); + if (ret) + /* just reuse old beacon */ + return ret; + + if (old_bcn) + kfree_rcu(old_bcn, rcu_head); + return 0; +} + +void ieee80211_mbss_info_change_notify(struct ieee80211_sub_if_data *sdata, + u32 changed) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + unsigned long bits = changed; + u32 bit; + + if (!bits) + return; + + /* if we race with running work, worst case this work becomes a noop */ + for_each_set_bit(bit, &bits, sizeof(changed) * BITS_PER_BYTE) + set_bit(bit, &ifmsh->mbss_changed); + set_bit(MESH_WORK_MBSS_CHANGED, &ifmsh->wrkq_flags); + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +int ieee80211_start_mesh(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee80211_local *local = sdata->local; + u32 changed = BSS_CHANGED_BEACON | + BSS_CHANGED_BEACON_ENABLED | + BSS_CHANGED_HT | + BSS_CHANGED_BASIC_RATES | + BSS_CHANGED_BEACON_INT | + BSS_CHANGED_MCAST_RATE; + + local->fif_other_bss++; + /* mesh ifaces must set allmulti to forward mcast traffic */ + atomic_inc(&local->iff_allmultis); + ieee80211_configure_filter(local); + + ifmsh->mesh_cc_id = 0; /* Disabled */ + /* register sync ops from extensible synchronization framework */ + ifmsh->sync_ops = ieee80211_mesh_sync_ops_get(ifmsh->mesh_sp_id); + ifmsh->sync_offset_clockdrift_max = 0; + set_bit(MESH_WORK_HOUSEKEEPING, &ifmsh->wrkq_flags); + ieee80211_mesh_root_setup(ifmsh); + ieee80211_queue_work(&local->hw, &sdata->work); + sdata->vif.bss_conf.ht_operation_mode = + ifmsh->mshcfg.ht_opmode; + sdata->vif.bss_conf.enable_beacon = true; + + changed |= ieee80211_mps_local_status_update(sdata); + + if (ieee80211_mesh_build_beacon(ifmsh)) { + ieee80211_stop_mesh(sdata); + return -ENOMEM; + } + + ieee80211_recalc_dtim(local, sdata); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); + + netif_carrier_on(sdata->dev); + return 0; +} + +void ieee80211_stop_mesh(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct beacon_data *bcn; + + netif_carrier_off(sdata->dev); + + /* flush STAs and mpaths on this iface */ + sta_info_flush(sdata); + ieee80211_free_keys(sdata, true); + mesh_path_flush_by_iface(sdata); + + /* stop the beacon */ + ifmsh->mesh_id_len = 0; + sdata->vif.bss_conf.enable_beacon = false; + sdata->beacon_rate_set = false; + clear_bit(SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, &sdata->state); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_BEACON_ENABLED); + + /* remove beacon */ + bcn = sdata_dereference(ifmsh->beacon, sdata); + RCU_INIT_POINTER(ifmsh->beacon, NULL); + kfree_rcu(bcn, rcu_head); + + /* free all potentially still buffered group-addressed frames */ + local->total_ps_buffered -= 
skb_queue_len(&ifmsh->ps.bc_buf); + skb_queue_purge(&ifmsh->ps.bc_buf); + + del_timer_sync(&sdata->u.mesh.housekeeping_timer); + del_timer_sync(&sdata->u.mesh.mesh_path_root_timer); + del_timer_sync(&sdata->u.mesh.mesh_path_timer); + + /* clear any mesh work (for next join) we may have accrued */ + ifmsh->wrkq_flags = 0; + ifmsh->mbss_changed = 0; + + local->fif_other_bss--; + atomic_dec(&local->iff_allmultis); + ieee80211_configure_filter(local); +} + +static void ieee80211_mesh_csa_mark_radar(struct ieee80211_sub_if_data *sdata) +{ + int err; + + /* if the current channel is a DFS channel, mark the channel as + * unavailable. + */ + err = cfg80211_chandef_dfs_required(sdata->local->hw.wiphy, + &sdata->vif.bss_conf.chandef, + NL80211_IFTYPE_MESH_POINT); + if (err > 0) + cfg80211_radar_event(sdata->local->hw.wiphy, + &sdata->vif.bss_conf.chandef, GFP_ATOMIC); +} + +static bool +ieee80211_mesh_process_chnswitch(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, bool beacon) +{ + struct cfg80211_csa_settings params; + struct ieee80211_csa_ie csa_ie; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee80211_supported_band *sband; + int err; + ieee80211_conn_flags_t conn_flags = 0; + u32 vht_cap_info = 0; + + sdata_assert_lock(sdata); + + sband = ieee80211_get_sband(sdata); + if (!sband) + return false; + + switch (sdata->vif.bss_conf.chandef.width) { + case NL80211_CHAN_WIDTH_20_NOHT: + conn_flags |= IEEE80211_CONN_DISABLE_HT; + fallthrough; + case NL80211_CHAN_WIDTH_20: + conn_flags |= IEEE80211_CONN_DISABLE_40MHZ; + fallthrough; + case NL80211_CHAN_WIDTH_40: + conn_flags |= IEEE80211_CONN_DISABLE_VHT; + break; + default: + break; + } + + if (elems->vht_cap_elem) + vht_cap_info = + le32_to_cpu(elems->vht_cap_elem->vht_cap_info); + + memset(¶ms, 0, sizeof(params)); + err = ieee80211_parse_ch_switch_ie(sdata, elems, sband->band, + vht_cap_info, + conn_flags, sdata->vif.addr, + &csa_ie); + if (err < 0) + return false; + if (err) + return false; + + /* Mark the channel unavailable if the reason for the switch is + * regulatory. 
+ */ + if (csa_ie.reason_code == WLAN_REASON_MESH_CHAN_REGULATORY) + ieee80211_mesh_csa_mark_radar(sdata); + + params.chandef = csa_ie.chandef; + params.count = csa_ie.count; + + if (!cfg80211_chandef_usable(sdata->local->hw.wiphy, ¶ms.chandef, + IEEE80211_CHAN_DISABLED) || + !cfg80211_reg_can_beacon(sdata->local->hw.wiphy, ¶ms.chandef, + NL80211_IFTYPE_MESH_POINT)) { + sdata_info(sdata, + "mesh STA %pM switches to unsupported channel (%d MHz, width:%d, CF1/2: %d/%d MHz), aborting\n", + sdata->vif.addr, + params.chandef.chan->center_freq, + params.chandef.width, + params.chandef.center_freq1, + params.chandef.center_freq2); + return false; + } + + err = cfg80211_chandef_dfs_required(sdata->local->hw.wiphy, + ¶ms.chandef, + NL80211_IFTYPE_MESH_POINT); + if (err < 0) + return false; + if (err > 0 && !ifmsh->userspace_handles_dfs) { + sdata_info(sdata, + "mesh STA %pM switches to channel requiring DFS (%d MHz, width:%d, CF1/2: %d/%d MHz), aborting\n", + sdata->vif.addr, + params.chandef.chan->center_freq, + params.chandef.width, + params.chandef.center_freq1, + params.chandef.center_freq2); + return false; + } + + params.radar_required = err; + + if (cfg80211_chandef_identical(¶ms.chandef, + &sdata->vif.bss_conf.chandef)) { + mcsa_dbg(sdata, + "received csa with an identical chandef, ignoring\n"); + return true; + } + + mcsa_dbg(sdata, + "received channel switch announcement to go to channel %d MHz\n", + params.chandef.chan->center_freq); + + params.block_tx = csa_ie.mode & WLAN_EID_CHAN_SWITCH_PARAM_TX_RESTRICT; + if (beacon) { + ifmsh->chsw_ttl = csa_ie.ttl - 1; + if (ifmsh->pre_value >= csa_ie.pre_value) + return false; + ifmsh->pre_value = csa_ie.pre_value; + } + + if (ifmsh->chsw_ttl >= ifmsh->mshcfg.dot11MeshTTL) + return false; + + ifmsh->csa_role = IEEE80211_MESH_CSA_ROLE_REPEATER; + + if (ieee80211_channel_switch(sdata->local->hw.wiphy, sdata->dev, + ¶ms) < 0) + return false; + + return true; +} + +static void +ieee80211_mesh_rx_probe_req(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct sk_buff *presp; + struct beacon_data *bcn; + struct ieee80211_mgmt *hdr; + struct ieee802_11_elems *elems; + size_t baselen; + u8 *pos; + + pos = mgmt->u.probe_req.variable; + baselen = (u8 *) pos - (u8 *) mgmt; + if (baselen > len) + return; + + elems = ieee802_11_parse_elems(pos, len - baselen, false, NULL); + if (!elems) + return; + + if (!elems->mesh_id) + goto free; + + /* 802.11-2012 10.1.4.3.2 */ + if ((!ether_addr_equal(mgmt->da, sdata->vif.addr) && + !is_broadcast_ether_addr(mgmt->da)) || + elems->ssid_len != 0) + goto free; + + if (elems->mesh_id_len != 0 && + (elems->mesh_id_len != ifmsh->mesh_id_len || + memcmp(elems->mesh_id, ifmsh->mesh_id, ifmsh->mesh_id_len))) + goto free; + + rcu_read_lock(); + bcn = rcu_dereference(ifmsh->beacon); + + if (!bcn) + goto out; + + presp = dev_alloc_skb(local->tx_headroom + + bcn->head_len + bcn->tail_len); + if (!presp) + goto out; + + skb_reserve(presp, local->tx_headroom); + skb_put_data(presp, bcn->head, bcn->head_len); + skb_put_data(presp, bcn->tail, bcn->tail_len); + hdr = (struct ieee80211_mgmt *) presp->data; + hdr->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_PROBE_RESP); + memcpy(hdr->da, mgmt->sa, ETH_ALEN); + IEEE80211_SKB_CB(presp)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + ieee80211_tx_skb(sdata, presp); +out: + rcu_read_unlock(); +free: + kfree(elems); +} + +static void 
ieee80211_mesh_rx_bcn_presp(struct ieee80211_sub_if_data *sdata, + u16 stype, + struct ieee80211_mgmt *mgmt, + size_t len, + struct ieee80211_rx_status *rx_status) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee802_11_elems *elems; + struct ieee80211_channel *channel; + size_t baselen; + int freq; + enum nl80211_band band = rx_status->band; + + /* ignore ProbeResp to foreign address */ + if (stype == IEEE80211_STYPE_PROBE_RESP && + !ether_addr_equal(mgmt->da, sdata->vif.addr)) + return; + + baselen = (u8 *) mgmt->u.probe_resp.variable - (u8 *) mgmt; + if (baselen > len) + return; + + elems = ieee802_11_parse_elems(mgmt->u.probe_resp.variable, + len - baselen, + false, NULL); + if (!elems) + return; + + /* ignore non-mesh or secure / unsecure mismatch */ + if ((!elems->mesh_id || !elems->mesh_config) || + (elems->rsn && sdata->u.mesh.security == IEEE80211_MESH_SEC_NONE) || + (!elems->rsn && sdata->u.mesh.security != IEEE80211_MESH_SEC_NONE)) + goto free; + + if (elems->ds_params) + freq = ieee80211_channel_to_frequency(elems->ds_params[0], band); + else + freq = rx_status->freq; + + channel = ieee80211_get_channel(local->hw.wiphy, freq); + + if (!channel || channel->flags & IEEE80211_CHAN_DISABLED) + goto free; + + if (mesh_matches_local(sdata, elems)) { + mpl_dbg(sdata, "rssi_threshold=%d,rx_status->signal=%d\n", + sdata->u.mesh.mshcfg.rssi_threshold, rx_status->signal); + if (!sdata->u.mesh.user_mpm || + sdata->u.mesh.mshcfg.rssi_threshold == 0 || + sdata->u.mesh.mshcfg.rssi_threshold < rx_status->signal) + mesh_neighbour_update(sdata, mgmt->sa, elems, + rx_status); + + if (ifmsh->csa_role != IEEE80211_MESH_CSA_ROLE_INIT && + !sdata->vif.bss_conf.csa_active) + ieee80211_mesh_process_chnswitch(sdata, elems, true); + } + + if (ifmsh->sync_ops) + ifmsh->sync_ops->rx_bcn_presp(sdata, stype, mgmt, len, + elems->mesh_config, rx_status); +free: + kfree(elems); +} + +int ieee80211_mesh_finish_csa(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_csa_settings *tmp_csa_settings; + int ret = 0; + int changed = 0; + + /* Reset the TTL value and Initiator flag */ + ifmsh->csa_role = IEEE80211_MESH_CSA_ROLE_NONE; + ifmsh->chsw_ttl = 0; + + /* Remove the CSA and MCSP elements from the beacon */ + tmp_csa_settings = sdata_dereference(ifmsh->csa, sdata); + RCU_INIT_POINTER(ifmsh->csa, NULL); + if (tmp_csa_settings) + kfree_rcu(tmp_csa_settings, rcu_head); + ret = ieee80211_mesh_rebuild_beacon(sdata); + if (ret) + return -EINVAL; + + changed |= BSS_CHANGED_BEACON; + + mcsa_dbg(sdata, "complete switching to center freq %d MHz", + sdata->vif.bss_conf.chandef.chan->center_freq); + return changed; +} + +int ieee80211_mesh_csa_beacon(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *csa_settings) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_csa_settings *tmp_csa_settings; + int ret = 0; + + lockdep_assert_held(&sdata->wdev.mtx); + + tmp_csa_settings = kmalloc(sizeof(*tmp_csa_settings), + GFP_ATOMIC); + if (!tmp_csa_settings) + return -ENOMEM; + + memcpy(&tmp_csa_settings->settings, csa_settings, + sizeof(struct cfg80211_csa_settings)); + + rcu_assign_pointer(ifmsh->csa, tmp_csa_settings); + + ret = ieee80211_mesh_rebuild_beacon(sdata); + if (ret) { + tmp_csa_settings = rcu_dereference(ifmsh->csa); + RCU_INIT_POINTER(ifmsh->csa, NULL); + kfree_rcu(tmp_csa_settings, rcu_head); + return ret; + } + + return BSS_CHANGED_BEACON; +} + +static int 
mesh_fwd_csa_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee802_11_elems *elems) +{ + struct ieee80211_mgmt *mgmt_fwd; + struct sk_buff *skb; + struct ieee80211_local *local = sdata->local; + + skb = dev_alloc_skb(local->tx_headroom + len); + if (!skb) + return -ENOMEM; + skb_reserve(skb, local->tx_headroom); + mgmt_fwd = skb_put(skb, len); + + elems->mesh_chansw_params_ie->mesh_ttl--; + elems->mesh_chansw_params_ie->mesh_flags &= + ~WLAN_EID_CHAN_SWITCH_PARAM_INITIATOR; + + memcpy(mgmt_fwd, mgmt, len); + eth_broadcast_addr(mgmt_fwd->da); + memcpy(mgmt_fwd->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt_fwd->bssid, sdata->vif.addr, ETH_ALEN); + + ieee80211_tx_skb(sdata, skb); + return 0; +} + +static void mesh_rx_csa_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee802_11_elems *elems; + u16 pre_value; + bool fwd_csa = true; + size_t baselen; + u8 *pos; + + if (mgmt->u.action.u.measurement.action_code != + WLAN_ACTION_SPCT_CHL_SWITCH) + return; + + pos = mgmt->u.action.u.chan_switch.variable; + baselen = offsetof(struct ieee80211_mgmt, + u.action.u.chan_switch.variable); + elems = ieee802_11_parse_elems(pos, len - baselen, true, NULL); + if (!elems) + return; + + if (!mesh_matches_local(sdata, elems)) + goto free; + + ifmsh->chsw_ttl = elems->mesh_chansw_params_ie->mesh_ttl; + if (!--ifmsh->chsw_ttl) + fwd_csa = false; + + pre_value = le16_to_cpu(elems->mesh_chansw_params_ie->mesh_pre_value); + if (ifmsh->pre_value >= pre_value) + goto free; + + ifmsh->pre_value = pre_value; + + if (!sdata->vif.bss_conf.csa_active && + !ieee80211_mesh_process_chnswitch(sdata, elems, false)) { + mcsa_dbg(sdata, "Failed to process CSA action frame"); + goto free; + } + + /* forward or re-broadcast the CSA frame */ + if (fwd_csa) { + if (mesh_fwd_csa_frame(sdata, mgmt, len, elems) < 0) + mcsa_dbg(sdata, "Failed to forward the CSA frame"); + } +free: + kfree(elems); +} + +static void ieee80211_mesh_rx_mgmt_action(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len, + struct ieee80211_rx_status *rx_status) +{ + switch (mgmt->u.action.category) { + case WLAN_CATEGORY_SELF_PROTECTED: + switch (mgmt->u.action.u.self_prot.action_code) { + case WLAN_SP_MESH_PEERING_OPEN: + case WLAN_SP_MESH_PEERING_CLOSE: + case WLAN_SP_MESH_PEERING_CONFIRM: + mesh_rx_plink_frame(sdata, mgmt, len, rx_status); + break; + } + break; + case WLAN_CATEGORY_MESH_ACTION: + if (mesh_action_is_path_sel(mgmt)) + mesh_rx_path_sel_frame(sdata, mgmt, len); + break; + case WLAN_CATEGORY_SPECTRUM_MGMT: + mesh_rx_csa_frame(sdata, mgmt, len); + break; + } +} + +void ieee80211_mesh_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_rx_status *rx_status; + struct ieee80211_mgmt *mgmt; + u16 stype; + + sdata_lock(sdata); + + /* mesh already went down */ + if (!sdata->u.mesh.mesh_id_len) + goto out; + + rx_status = IEEE80211_SKB_RXCB(skb); + mgmt = (struct ieee80211_mgmt *) skb->data; + stype = le16_to_cpu(mgmt->frame_control) & IEEE80211_FCTL_STYPE; + + switch (stype) { + case IEEE80211_STYPE_PROBE_RESP: + case IEEE80211_STYPE_BEACON: + ieee80211_mesh_rx_bcn_presp(sdata, stype, mgmt, skb->len, + rx_status); + break; + case IEEE80211_STYPE_PROBE_REQ: + ieee80211_mesh_rx_probe_req(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_ACTION: + ieee80211_mesh_rx_mgmt_action(sdata, mgmt, skb->len, rx_status); + break; + } +out: 
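+
+	/*
+	 * The dispatcher above recovers the management subtype by masking the
+	 * CPU-order frame_control value (after le16_to_cpu()) with
+	 * IEEE80211_FCTL_STYPE.  As an illustrative, self-contained sketch --
+	 * the constants below are local stand-ins mirroring the 802.11
+	 * frame-control layout, not definitions taken from this file -- the
+	 * same masking looks like:
+	 *
+	 *	#include <stdint.h>
+	 *	#include <stdio.h>
+	 *
+	 *	#define FCTL_FTYPE   0x000c	// frame type: bits 2-3
+	 *	#define FCTL_STYPE   0x00f0	// frame subtype: bits 4-7
+	 *	#define FTYPE_MGMT   0x0000
+	 *	#define STYPE_BEACON 0x0080
+	 *
+	 *	int main(void)
+	 *	{
+	 *		uint16_t fc = FTYPE_MGMT | STYPE_BEACON; // host order, demo only
+	 *
+	 *		if ((fc & FCTL_FTYPE) == FTYPE_MGMT &&
+	 *		    (fc & FCTL_STYPE) == STYPE_BEACON)
+	 *			printf("beacon\n");
+	 *		return 0;
+	 *	}
+	 *
+	 * In the switch statement above the same test selects between the
+	 * beacon/probe response, probe request and action frame handlers.
+	 */
+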
+ sdata_unlock(sdata); +} + +static void mesh_bss_info_changed(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u32 bit, changed = 0; + + for_each_set_bit(bit, &ifmsh->mbss_changed, + sizeof(changed) * BITS_PER_BYTE) { + clear_bit(bit, &ifmsh->mbss_changed); + changed |= BIT(bit); + } + + if (sdata->vif.bss_conf.enable_beacon && + (changed & (BSS_CHANGED_BEACON | + BSS_CHANGED_HT | + BSS_CHANGED_BASIC_RATES | + BSS_CHANGED_BEACON_INT))) + if (ieee80211_mesh_rebuild_beacon(sdata)) + return; + + ieee80211_link_info_change_notify(sdata, &sdata->deflink, changed); +} + +void ieee80211_mesh_work(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + sdata_lock(sdata); + + /* mesh already went down */ + if (!sdata->u.mesh.mesh_id_len) + goto out; + + if (ifmsh->preq_queue_len && + time_after(jiffies, + ifmsh->last_preq + msecs_to_jiffies(ifmsh->mshcfg.dot11MeshHWMPpreqMinInterval))) + mesh_path_start_discovery(sdata); + + if (test_and_clear_bit(MESH_WORK_HOUSEKEEPING, &ifmsh->wrkq_flags)) + ieee80211_mesh_housekeeping(sdata); + + if (test_and_clear_bit(MESH_WORK_ROOT, &ifmsh->wrkq_flags)) + ieee80211_mesh_rootpath(sdata); + + if (test_and_clear_bit(MESH_WORK_DRIFT_ADJUST, &ifmsh->wrkq_flags)) + mesh_sync_adjust_tsf(sdata); + + if (test_and_clear_bit(MESH_WORK_MBSS_CHANGED, &ifmsh->wrkq_flags)) + mesh_bss_info_changed(sdata); +out: + sdata_unlock(sdata); +} + + +void ieee80211_mesh_init_sdata(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + static u8 zero_addr[ETH_ALEN] = {}; + + timer_setup(&ifmsh->housekeeping_timer, + ieee80211_mesh_housekeeping_timer, 0); + + ifmsh->accepting_plinks = true; + atomic_set(&ifmsh->mpaths, 0); + mesh_rmc_init(sdata); + ifmsh->last_preq = jiffies; + ifmsh->next_perr = jiffies; + ifmsh->csa_role = IEEE80211_MESH_CSA_ROLE_NONE; + /* Allocate all mesh structures when creating the first mesh interface. */ + if (!mesh_allocated) + ieee80211s_init(); + + mesh_pathtbl_init(sdata); + + timer_setup(&ifmsh->mesh_path_timer, ieee80211_mesh_path_timer, 0); + timer_setup(&ifmsh->mesh_path_root_timer, + ieee80211_mesh_path_root_timer, 0); + INIT_LIST_HEAD(&ifmsh->preq_queue.list); + skb_queue_head_init(&ifmsh->ps.bc_buf); + spin_lock_init(&ifmsh->mesh_preq_queue_lock); + spin_lock_init(&ifmsh->sync_offset_lock); + RCU_INIT_POINTER(ifmsh->beacon, NULL); + + sdata->vif.bss_conf.bssid = zero_addr; +} + +void ieee80211_mesh_teardown_sdata(struct ieee80211_sub_if_data *sdata) +{ + mesh_rmc_free(sdata); + mesh_pathtbl_unregister(sdata); +} diff --git a/net/mac80211/mesh.h b/net/mac80211/mesh.h new file mode 100644 index 000000000..b2b717a78 --- /dev/null +++ b/net/mac80211/mesh.h @@ -0,0 +1,350 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2008, 2009 open80211s Ltd. 
+ * Authors: Luis Carlos Cobo + * Javier Cardona + */ + +#ifndef IEEE80211S_H +#define IEEE80211S_H + +#include +#include +#include "ieee80211_i.h" + + +/* Data structures */ + +/** + * enum mesh_path_flags - mac80211 mesh path flags + * + * @MESH_PATH_ACTIVE: the mesh path can be used for forwarding + * @MESH_PATH_RESOLVING: the discovery process is running for this mesh path + * @MESH_PATH_SN_VALID: the mesh path contains a valid destination sequence + * number + * @MESH_PATH_FIXED: the mesh path has been manually set and should not be + * modified + * @MESH_PATH_RESOLVED: the mesh path can has been resolved + * @MESH_PATH_REQ_QUEUED: there is an unsent path request for this destination + * already queued up, waiting for the discovery process to start. + * @MESH_PATH_DELETED: the mesh path has been deleted and should no longer + * be used + * + * MESH_PATH_RESOLVED is used by the mesh path timer to + * decide when to stop or cancel the mesh path discovery. + */ +enum mesh_path_flags { + MESH_PATH_ACTIVE = BIT(0), + MESH_PATH_RESOLVING = BIT(1), + MESH_PATH_SN_VALID = BIT(2), + MESH_PATH_FIXED = BIT(3), + MESH_PATH_RESOLVED = BIT(4), + MESH_PATH_REQ_QUEUED = BIT(5), + MESH_PATH_DELETED = BIT(6), +}; + +/** + * enum mesh_deferred_task_flags - mac80211 mesh deferred tasks + * + * + * + * @MESH_WORK_HOUSEKEEPING: run the periodic mesh housekeeping tasks + * @MESH_WORK_ROOT: the mesh root station needs to send a frame + * @MESH_WORK_DRIFT_ADJUST: time to compensate for clock drift relative to other + * mesh nodes + * @MESH_WORK_MBSS_CHANGED: rebuild beacon and notify driver of BSS changes + */ +enum mesh_deferred_task_flags { + MESH_WORK_HOUSEKEEPING, + MESH_WORK_ROOT, + MESH_WORK_DRIFT_ADJUST, + MESH_WORK_MBSS_CHANGED, +}; + +/** + * struct mesh_path - mac80211 mesh path structure + * + * @dst: mesh path destination mac address + * @mpp: mesh proxy mac address + * @rhash: rhashtable list pointer + * @walk_list: linked list containing all mesh_path objects. + * @gate_list: list pointer for known gates list + * @sdata: mesh subif + * @next_hop: mesh neighbor to which frames for this destination will be + * forwarded + * @timer: mesh path discovery timer + * @frame_queue: pending queue for frames sent to this destination while the + * path is unresolved + * @rcu: rcu head for freeing mesh path + * @sn: target sequence number + * @metric: current metric to this destination + * @hop_count: hops to destination + * @exp_time: in jiffies, when the path will expire or when it expired + * @discovery_timeout: timeout (lapse in jiffies) used for the last discovery + * retry + * @discovery_retries: number of discovery retries + * @flags: mesh path flags, as specified on &enum mesh_path_flags + * @state_lock: mesh path state lock used to protect changes to the + * mpath itself. No need to take this lock when adding or removing + * an mpath to a hash bucket on a path table. + * @rann_snd_addr: the RANN sender address + * @rann_metric: the aggregated path metric towards the root node + * @last_preq_to_root: Timestamp of last PREQ sent to root + * @is_root: the destination station of this path is a root node + * @is_gate: the destination station of this path is a mesh gate + * @path_change_count: the number of path changes to destination + * + * + * The dst address is unique in the mesh path table. Since the mesh_path is + * protected by RCU, deleting the next_hop STA must remove / substitute the + * mesh_path structure and wait until that is no longer reachable before + * destroying the STA completely. 
+ */ +struct mesh_path { + u8 dst[ETH_ALEN]; + u8 mpp[ETH_ALEN]; /* used for MPP or MAP */ + struct rhash_head rhash; + struct hlist_node walk_list; + struct hlist_node gate_list; + struct ieee80211_sub_if_data *sdata; + struct sta_info __rcu *next_hop; + struct timer_list timer; + struct sk_buff_head frame_queue; + struct rcu_head rcu; + u32 sn; + u32 metric; + u8 hop_count; + unsigned long exp_time; + u32 discovery_timeout; + u8 discovery_retries; + enum mesh_path_flags flags; + spinlock_t state_lock; + u8 rann_snd_addr[ETH_ALEN]; + u32 rann_metric; + unsigned long last_preq_to_root; + bool is_root; + bool is_gate; + u32 path_change_count; +}; + +/* Recent multicast cache */ +/* RMC_BUCKETS must be a power of 2, maximum 256 */ +#define RMC_BUCKETS 256 +#define RMC_QUEUE_MAX_LEN 4 +#define RMC_TIMEOUT (3 * HZ) + +/** + * struct rmc_entry - entry in the Recent Multicast Cache + * + * @seqnum: mesh sequence number of the frame + * @exp_time: expiration time of the entry, in jiffies + * @sa: source address of the frame + * @list: hashtable list pointer + * + * The Recent Multicast Cache keeps track of the latest multicast frames that + * have been received by a mesh interface and discards received multicast frames + * that are found in the cache. + */ +struct rmc_entry { + struct hlist_node list; + unsigned long exp_time; + u32 seqnum; + u8 sa[ETH_ALEN]; +}; + +struct mesh_rmc { + struct hlist_head bucket[RMC_BUCKETS]; + u32 idx_mask; +}; + +#define IEEE80211_MESH_HOUSEKEEPING_INTERVAL (60 * HZ) + +#define MESH_PATH_EXPIRE (600 * HZ) + +/* Default maximum number of plinks per interface */ +#define MESH_MAX_PLINKS 256 + +/* Maximum number of paths per interface */ +#define MESH_MAX_MPATHS 1024 + +/* Number of frames buffered per destination for unresolved destinations */ +#define MESH_FRAME_QUEUE_LEN 10 + +/* Public interfaces */ +/* Various */ +int ieee80211_fill_mesh_addresses(struct ieee80211_hdr *hdr, __le16 *fc, + const u8 *da, const u8 *sa); +unsigned int ieee80211_new_mesh_header(struct ieee80211_sub_if_data *sdata, + struct ieee80211s_hdr *meshhdr, + const char *addr4or5, const char *addr6); +int mesh_rmc_check(struct ieee80211_sub_if_data *sdata, + const u8 *addr, struct ieee80211s_hdr *mesh_hdr); +bool mesh_matches_local(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *ie); +void mesh_ids_set_default(struct ieee80211_if_mesh *mesh); +int mesh_add_meshconf_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_meshid_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_rsn_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_vendor_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_ht_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_ht_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_vht_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_vht_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_he_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u8 ie_len); +int mesh_add_he_oper_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_add_he_6ghz_cap_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void mesh_rmc_free(struct ieee80211_sub_if_data *sdata); +int mesh_rmc_init(struct ieee80211_sub_if_data *sdata); +void ieee80211s_init(void); +void ieee80211s_update_metric(struct ieee80211_local 
*local, + struct sta_info *sta, + struct ieee80211_tx_status *st); +void ieee80211_mesh_init_sdata(struct ieee80211_sub_if_data *sdata); +void ieee80211_mesh_teardown_sdata(struct ieee80211_sub_if_data *sdata); +int ieee80211_start_mesh(struct ieee80211_sub_if_data *sdata); +void ieee80211_stop_mesh(struct ieee80211_sub_if_data *sdata); +void ieee80211_mesh_root_setup(struct ieee80211_if_mesh *ifmsh); +const struct ieee80211_mesh_sync_ops *ieee80211_mesh_sync_ops_get(u8 method); +/* wrapper for ieee80211_bss_info_change_notify() */ +void ieee80211_mbss_info_change_notify(struct ieee80211_sub_if_data *sdata, + u32 changed); + +/* mesh power save */ +u32 ieee80211_mps_local_status_update(struct ieee80211_sub_if_data *sdata); +u32 ieee80211_mps_set_sta_local_pm(struct sta_info *sta, + enum nl80211_mesh_power_mode pm); +void ieee80211_mps_set_frame_flags(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_hdr *hdr); +void ieee80211_mps_sta_status_update(struct sta_info *sta); +void ieee80211_mps_rx_h_sta_process(struct sta_info *sta, + struct ieee80211_hdr *hdr); +void ieee80211_mpsp_trigger_process(u8 *qc, struct sta_info *sta, + bool tx, bool acked); +void ieee80211_mps_frame_release(struct sta_info *sta, + struct ieee802_11_elems *elems); + +/* Mesh paths */ +int mesh_nexthop_lookup(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +int mesh_nexthop_resolve(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void mesh_path_start_discovery(struct ieee80211_sub_if_data *sdata); +struct mesh_path *mesh_path_lookup(struct ieee80211_sub_if_data *sdata, + const u8 *dst); +struct mesh_path *mpp_path_lookup(struct ieee80211_sub_if_data *sdata, + const u8 *dst); +int mpp_path_add(struct ieee80211_sub_if_data *sdata, + const u8 *dst, const u8 *mpp); +struct mesh_path * +mesh_path_lookup_by_idx(struct ieee80211_sub_if_data *sdata, int idx); +struct mesh_path * +mpp_path_lookup_by_idx(struct ieee80211_sub_if_data *sdata, int idx); +void mesh_path_fix_nexthop(struct mesh_path *mpath, struct sta_info *next_hop); +void mesh_path_expire(struct ieee80211_sub_if_data *sdata); +void mesh_rx_path_sel_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len); +struct mesh_path * +mesh_path_add(struct ieee80211_sub_if_data *sdata, const u8 *dst); + +int mesh_path_add_gate(struct mesh_path *mpath); +int mesh_path_send_to_gates(struct mesh_path *mpath); +int mesh_gate_num(struct ieee80211_sub_if_data *sdata); +u32 airtime_link_metric_get(struct ieee80211_local *local, + struct sta_info *sta); + +/* Mesh plinks */ +void mesh_neighbour_update(struct ieee80211_sub_if_data *sdata, + u8 *hw_addr, struct ieee802_11_elems *ie, + struct ieee80211_rx_status *rx_status); +bool mesh_peer_accepts_plinks(struct ieee802_11_elems *ie); +u32 mesh_accept_plinks_update(struct ieee80211_sub_if_data *sdata); +void mesh_plink_timer(struct timer_list *t); +void mesh_plink_broken(struct sta_info *sta); +u32 mesh_plink_deactivate(struct sta_info *sta); +u32 mesh_plink_open(struct sta_info *sta); +u32 mesh_plink_block(struct sta_info *sta); +void mesh_rx_plink_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status); +void mesh_sta_cleanup(struct sta_info *sta); + +/* Private interfaces */ +/* Mesh paths */ +int mesh_path_error_tx(struct ieee80211_sub_if_data *sdata, + u8 ttl, const u8 *target, u32 target_sn, + u16 target_rcode, const u8 *ra); +void mesh_path_assign_nexthop(struct mesh_path 
*mpath, struct sta_info *sta); +void mesh_path_flush_pending(struct mesh_path *mpath); +void mesh_path_tx_pending(struct mesh_path *mpath); +void mesh_pathtbl_init(struct ieee80211_sub_if_data *sdata); +void mesh_pathtbl_unregister(struct ieee80211_sub_if_data *sdata); +int mesh_path_del(struct ieee80211_sub_if_data *sdata, const u8 *addr); +void mesh_path_timer(struct timer_list *t); +void mesh_path_flush_by_nexthop(struct sta_info *sta); +void mesh_path_discard_frame(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void mesh_path_tx_root_frame(struct ieee80211_sub_if_data *sdata); + +bool mesh_action_is_path_sel(struct ieee80211_mgmt *mgmt); + +#ifdef CONFIG_MAC80211_MESH +static inline +u32 mesh_plink_inc_estab_count(struct ieee80211_sub_if_data *sdata) +{ + atomic_inc(&sdata->u.mesh.estab_plinks); + return mesh_accept_plinks_update(sdata) | BSS_CHANGED_BEACON; +} + +static inline +u32 mesh_plink_dec_estab_count(struct ieee80211_sub_if_data *sdata) +{ + atomic_dec(&sdata->u.mesh.estab_plinks); + return mesh_accept_plinks_update(sdata) | BSS_CHANGED_BEACON; +} + +static inline int mesh_plink_free_count(struct ieee80211_sub_if_data *sdata) +{ + return sdata->u.mesh.mshcfg.dot11MeshMaxPeerLinks - + atomic_read(&sdata->u.mesh.estab_plinks); +} + +static inline bool mesh_plink_availables(struct ieee80211_sub_if_data *sdata) +{ + return (min_t(long, mesh_plink_free_count(sdata), + MESH_MAX_PLINKS - sdata->local->num_sta)) > 0; +} + +static inline void mesh_path_activate(struct mesh_path *mpath) +{ + mpath->flags |= MESH_PATH_ACTIVE | MESH_PATH_RESOLVED; +} + +static inline bool mesh_path_sel_is_hwmp(struct ieee80211_sub_if_data *sdata) +{ + return sdata->u.mesh.mesh_pp_id == IEEE80211_PATH_PROTOCOL_HWMP; +} + +void mesh_path_flush_by_iface(struct ieee80211_sub_if_data *sdata); +void mesh_sync_adjust_tsf(struct ieee80211_sub_if_data *sdata); +void ieee80211s_stop(void); +#else +static inline bool mesh_path_sel_is_hwmp(struct ieee80211_sub_if_data *sdata) +{ return false; } +static inline void mesh_path_flush_by_iface(struct ieee80211_sub_if_data *sdata) +{} +static inline void ieee80211s_stop(void) {} +#endif + +#endif /* IEEE80211S_H */ diff --git a/net/mac80211/mesh_hwmp.c b/net/mac80211/mesh_hwmp.c new file mode 100644 index 000000000..9b1ce7c39 --- /dev/null +++ b/net/mac80211/mesh_hwmp.c @@ -0,0 +1,1332 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, 2009 open80211s Ltd. 
+ * Copyright (C) 2019, 2021-2022 Intel Corporation + * Author: Luis Carlos Cobo + */ + +#include +#include +#include +#include "wme.h" +#include "mesh.h" + +#define TEST_FRAME_LEN 8192 +#define MAX_METRIC 0xffffffff +#define ARITH_SHIFT 8 +#define LINK_FAIL_THRESH 95 + +#define MAX_PREQ_QUEUE_LEN 64 + +static void mesh_queue_preq(struct mesh_path *, u8); + +static inline u32 u32_field_get(const u8 *preq_elem, int offset, bool ae) +{ + if (ae) + offset += 6; + return get_unaligned_le32(preq_elem + offset); +} + +static inline u16 u16_field_get(const u8 *preq_elem, int offset, bool ae) +{ + if (ae) + offset += 6; + return get_unaligned_le16(preq_elem + offset); +} + +/* HWMP IE processing macros */ +#define AE_F (1<<6) +#define AE_F_SET(x) (*x & AE_F) +#define PREQ_IE_FLAGS(x) (*(x)) +#define PREQ_IE_HOPCOUNT(x) (*(x + 1)) +#define PREQ_IE_TTL(x) (*(x + 2)) +#define PREQ_IE_PREQ_ID(x) u32_field_get(x, 3, 0) +#define PREQ_IE_ORIG_ADDR(x) (x + 7) +#define PREQ_IE_ORIG_SN(x) u32_field_get(x, 13, 0) +#define PREQ_IE_LIFETIME(x) u32_field_get(x, 17, AE_F_SET(x)) +#define PREQ_IE_METRIC(x) u32_field_get(x, 21, AE_F_SET(x)) +#define PREQ_IE_TARGET_F(x) (*(AE_F_SET(x) ? x + 32 : x + 26)) +#define PREQ_IE_TARGET_ADDR(x) (AE_F_SET(x) ? x + 33 : x + 27) +#define PREQ_IE_TARGET_SN(x) u32_field_get(x, 33, AE_F_SET(x)) + + +#define PREP_IE_FLAGS(x) PREQ_IE_FLAGS(x) +#define PREP_IE_HOPCOUNT(x) PREQ_IE_HOPCOUNT(x) +#define PREP_IE_TTL(x) PREQ_IE_TTL(x) +#define PREP_IE_ORIG_ADDR(x) (AE_F_SET(x) ? x + 27 : x + 21) +#define PREP_IE_ORIG_SN(x) u32_field_get(x, 27, AE_F_SET(x)) +#define PREP_IE_LIFETIME(x) u32_field_get(x, 13, AE_F_SET(x)) +#define PREP_IE_METRIC(x) u32_field_get(x, 17, AE_F_SET(x)) +#define PREP_IE_TARGET_ADDR(x) (x + 3) +#define PREP_IE_TARGET_SN(x) u32_field_get(x, 9, 0) + +#define PERR_IE_TTL(x) (*(x)) +#define PERR_IE_TARGET_FLAGS(x) (*(x + 2)) +#define PERR_IE_TARGET_ADDR(x) (x + 3) +#define PERR_IE_TARGET_SN(x) u32_field_get(x, 9, 0) +#define PERR_IE_TARGET_RCODE(x) u16_field_get(x, 13, 0) + +#define MSEC_TO_TU(x) (x*1000/1024) +#define SN_GT(x, y) ((s32)(y - x) < 0) +#define SN_LT(x, y) ((s32)(x - y) < 0) +#define MAX_SANE_SN_DELTA 32 + +static inline u32 SN_DELTA(u32 x, u32 y) +{ + return x >= y ? 
x - y : y - x; +} + +#define net_traversal_jiffies(s) \ + msecs_to_jiffies(s->u.mesh.mshcfg.dot11MeshHWMPnetDiameterTraversalTime) +#define default_lifetime(s) \ + MSEC_TO_TU(s->u.mesh.mshcfg.dot11MeshHWMPactivePathTimeout) +#define min_preq_int_jiff(s) \ + (msecs_to_jiffies(s->u.mesh.mshcfg.dot11MeshHWMPpreqMinInterval)) +#define max_preq_retries(s) (s->u.mesh.mshcfg.dot11MeshHWMPmaxPREQretries) +#define disc_timeout_jiff(s) \ + msecs_to_jiffies(sdata->u.mesh.mshcfg.min_discovery_timeout) +#define root_path_confirmation_jiffies(s) \ + msecs_to_jiffies(sdata->u.mesh.mshcfg.dot11MeshHWMPconfirmationInterval) + +enum mpath_frame_type { + MPATH_PREQ = 0, + MPATH_PREP, + MPATH_PERR, + MPATH_RANN +}; + +static const u8 broadcast_addr[ETH_ALEN] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + +static int mesh_path_sel_frame_tx(enum mpath_frame_type action, u8 flags, + const u8 *orig_addr, u32 orig_sn, + u8 target_flags, const u8 *target, + u32 target_sn, const u8 *da, + u8 hop_count, u8 ttl, + u32 lifetime, u32 metric, u32 preq_id, + struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + u8 *pos, ie_len; + int hdr_len = offsetofend(struct ieee80211_mgmt, + u.action.u.mesh_action); + + skb = dev_alloc_skb(local->tx_headroom + + hdr_len + + 2 + 37); /* max HWMP IE */ + if (!skb) + return -1; + skb_reserve(skb, local->tx_headroom); + mgmt = skb_put_zero(skb, hdr_len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + /* BSSID == SA */ + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + mgmt->u.action.category = WLAN_CATEGORY_MESH_ACTION; + mgmt->u.action.u.mesh_action.action_code = + WLAN_MESH_ACTION_HWMP_PATH_SELECTION; + + switch (action) { + case MPATH_PREQ: + mhwmp_dbg(sdata, "sending PREQ to %pM\n", target); + ie_len = 37; + pos = skb_put(skb, 2 + ie_len); + *pos++ = WLAN_EID_PREQ; + break; + case MPATH_PREP: + mhwmp_dbg(sdata, "sending PREP to %pM\n", orig_addr); + ie_len = 31; + pos = skb_put(skb, 2 + ie_len); + *pos++ = WLAN_EID_PREP; + break; + case MPATH_RANN: + mhwmp_dbg(sdata, "sending RANN from %pM\n", orig_addr); + ie_len = sizeof(struct ieee80211_rann_ie); + pos = skb_put(skb, 2 + ie_len); + *pos++ = WLAN_EID_RANN; + break; + default: + kfree_skb(skb); + return -ENOTSUPP; + } + *pos++ = ie_len; + *pos++ = flags; + *pos++ = hop_count; + *pos++ = ttl; + if (action == MPATH_PREP) { + memcpy(pos, target, ETH_ALEN); + pos += ETH_ALEN; + put_unaligned_le32(target_sn, pos); + pos += 4; + } else { + if (action == MPATH_PREQ) { + put_unaligned_le32(preq_id, pos); + pos += 4; + } + memcpy(pos, orig_addr, ETH_ALEN); + pos += ETH_ALEN; + put_unaligned_le32(orig_sn, pos); + pos += 4; + } + put_unaligned_le32(lifetime, pos); /* interval for RANN */ + pos += 4; + put_unaligned_le32(metric, pos); + pos += 4; + if (action == MPATH_PREQ) { + *pos++ = 1; /* destination count */ + *pos++ = target_flags; + memcpy(pos, target, ETH_ALEN); + pos += ETH_ALEN; + put_unaligned_le32(target_sn, pos); + pos += 4; + } else if (action == MPATH_PREP) { + memcpy(pos, orig_addr, ETH_ALEN); + pos += ETH_ALEN; + put_unaligned_le32(orig_sn, pos); + pos += 4; + } + + ieee80211_tx_skb(sdata, skb); + return 0; +} + + +/* Headroom is not adjusted. Caller should ensure that skb has sufficient + * headroom in case the frame is encrypted. 
*/ +static void prepare_frame_for_deferred_tx(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb_reset_transport_header(skb); + + /* Send all internal mgmt frames on VO. Accordingly set TID to 7. */ + skb_set_queue_mapping(skb, IEEE80211_AC_VO); + skb->priority = 7; + + info->control.vif = &sdata->vif; + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + ieee80211_set_qos_hdr(sdata, skb); + ieee80211_mps_set_frame_flags(sdata, NULL, hdr); +} + +/** + * mesh_path_error_tx - Sends a PERR mesh management frame + * + * @ttl: allowed remaining hops + * @target: broken destination + * @target_sn: SN of the broken destination + * @target_rcode: reason code for this PERR + * @ra: node this frame is addressed to + * @sdata: local mesh subif + * + * Note: This function may be called with driver locks taken that the driver + * also acquires in the TX path. To avoid a deadlock we don't transmit the + * frame directly but add it to the pending queue instead. + */ +int mesh_path_error_tx(struct ieee80211_sub_if_data *sdata, + u8 ttl, const u8 *target, u32 target_sn, + u16 target_rcode, const u8 *ra) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee80211_mgmt *mgmt; + u8 *pos, ie_len; + int hdr_len = offsetofend(struct ieee80211_mgmt, + u.action.u.mesh_action); + + if (time_before(jiffies, ifmsh->next_perr)) + return -EAGAIN; + + skb = dev_alloc_skb(local->tx_headroom + + IEEE80211_ENCRYPT_HEADROOM + + IEEE80211_ENCRYPT_TAILROOM + + hdr_len + + 2 + 15 /* PERR IE */); + if (!skb) + return -1; + skb_reserve(skb, local->tx_headroom + IEEE80211_ENCRYPT_HEADROOM); + mgmt = skb_put_zero(skb, hdr_len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + memcpy(mgmt->da, ra, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + /* BSSID == SA */ + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + mgmt->u.action.category = WLAN_CATEGORY_MESH_ACTION; + mgmt->u.action.u.mesh_action.action_code = + WLAN_MESH_ACTION_HWMP_PATH_SELECTION; + ie_len = 15; + pos = skb_put(skb, 2 + ie_len); + *pos++ = WLAN_EID_PERR; + *pos++ = ie_len; + /* ttl */ + *pos++ = ttl; + /* number of destinations */ + *pos++ = 1; + /* Flags field has AE bit only as defined in + * sec 8.4.2.117 IEEE802.11-2012 + */ + *pos = 0; + pos++; + memcpy(pos, target, ETH_ALEN); + pos += ETH_ALEN; + put_unaligned_le32(target_sn, pos); + pos += 4; + put_unaligned_le16(target_rcode, pos); + + /* see note in function header */ + prepare_frame_for_deferred_tx(sdata, skb); + ifmsh->next_perr = TU_TO_EXP_TIME( + ifmsh->mshcfg.dot11MeshHWMPperrMinInterval); + ieee80211_add_pending_skb(local, skb); + return 0; +} + +void ieee80211s_update_metric(struct ieee80211_local *local, + struct sta_info *sta, + struct ieee80211_tx_status *st) +{ + struct ieee80211_tx_info *txinfo = st->info; + int failed; + struct rate_info rinfo; + + failed = !(txinfo->flags & IEEE80211_TX_STAT_ACK); + + /* moving average, scaled to 100. 
+ * feed failure as 100 and success as 0 + */ + ewma_mesh_fail_avg_add(&sta->mesh->fail_avg, failed * 100); + if (ewma_mesh_fail_avg_read(&sta->mesh->fail_avg) > + LINK_FAIL_THRESH) + mesh_plink_broken(sta); + + /* use rate info set by the driver directly if present */ + if (st->n_rates) + rinfo = sta->deflink.tx_stats.last_rate_info; + else + sta_set_rate_info_tx(sta, &sta->deflink.tx_stats.last_rate, &rinfo); + + ewma_mesh_tx_rate_avg_add(&sta->mesh->tx_rate_avg, + cfg80211_calculate_bitrate(&rinfo)); +} + +u32 airtime_link_metric_get(struct ieee80211_local *local, + struct sta_info *sta) +{ + /* This should be adjusted for each device */ + int device_constant = 1 << ARITH_SHIFT; + int test_frame_len = TEST_FRAME_LEN << ARITH_SHIFT; + int s_unit = 1 << ARITH_SHIFT; + int rate, err; + u32 tx_time, estimated_retx; + u64 result; + unsigned long fail_avg = + ewma_mesh_fail_avg_read(&sta->mesh->fail_avg); + + if (sta->mesh->plink_state != NL80211_PLINK_ESTAB) + return MAX_METRIC; + + /* Try to get rate based on HW/SW RC algorithm. + * Rate is returned in units of Kbps, correct this + * to comply with airtime calculation units + * Round up in case we get rate < 100Kbps + */ + rate = DIV_ROUND_UP(sta_get_expected_throughput(sta), 100); + + if (rate) { + err = 0; + } else { + if (fail_avg > LINK_FAIL_THRESH) + return MAX_METRIC; + + rate = ewma_mesh_tx_rate_avg_read(&sta->mesh->tx_rate_avg); + if (WARN_ON(!rate)) + return MAX_METRIC; + + err = (fail_avg << ARITH_SHIFT) / 100; + } + + /* bitrate is in units of 100 Kbps, while we need rate in units of + * 1Mbps. This will be corrected on tx_time computation. + */ + tx_time = (device_constant + 10 * test_frame_len / rate); + estimated_retx = ((1 << (2 * ARITH_SHIFT)) / (s_unit - err)); + result = ((u64)tx_time * estimated_retx) >> (2 * ARITH_SHIFT); + return (u32)result; +} + +/** + * hwmp_route_info_get - Update routing info to originator and transmitter + * + * @sdata: local mesh subif + * @mgmt: mesh management frame + * @hwmp_ie: hwmp information element (PREP or PREQ) + * @action: type of hwmp ie + * + * This function updates the path routing information to the originator and the + * transmitter of a HWMP PREQ or PREP frame. + * + * Returns: metric to frame originator or 0 if the frame should not be further + * processed + * + * Notes: this function is the only place (besides user-provided info) where + * path routing information is updated. + */ +static u32 hwmp_route_info_get(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + const u8 *hwmp_ie, enum mpath_frame_type action) +{ + struct ieee80211_local *local = sdata->local; + struct mesh_path *mpath; + struct sta_info *sta; + bool fresh_info; + const u8 *orig_addr, *ta; + u32 orig_sn, orig_metric; + unsigned long orig_lifetime, exp_time; + u32 last_hop_metric, new_metric; + bool process = true; + u8 hopcount; + + rcu_read_lock(); + sta = sta_info_get(sdata, mgmt->sa); + if (!sta) { + rcu_read_unlock(); + return 0; + } + + last_hop_metric = airtime_link_metric_get(local, sta); + /* Update and check originator routing info */ + fresh_info = true; + + switch (action) { + case MPATH_PREQ: + orig_addr = PREQ_IE_ORIG_ADDR(hwmp_ie); + orig_sn = PREQ_IE_ORIG_SN(hwmp_ie); + orig_lifetime = PREQ_IE_LIFETIME(hwmp_ie); + orig_metric = PREQ_IE_METRIC(hwmp_ie); + hopcount = PREQ_IE_HOPCOUNT(hwmp_ie) + 1; + break; + case MPATH_PREP: + /* Originator here refers to the MP that was the target in the + * Path Request. 
We divert from the nomenclature in the draft + * so that we can easily use a single function to gather path + * information from both PREQ and PREP frames. + */ + orig_addr = PREP_IE_TARGET_ADDR(hwmp_ie); + orig_sn = PREP_IE_TARGET_SN(hwmp_ie); + orig_lifetime = PREP_IE_LIFETIME(hwmp_ie); + orig_metric = PREP_IE_METRIC(hwmp_ie); + hopcount = PREP_IE_HOPCOUNT(hwmp_ie) + 1; + break; + default: + rcu_read_unlock(); + return 0; + } + new_metric = orig_metric + last_hop_metric; + if (new_metric < orig_metric) + new_metric = MAX_METRIC; + exp_time = TU_TO_EXP_TIME(orig_lifetime); + + if (ether_addr_equal(orig_addr, sdata->vif.addr)) { + /* This MP is the originator, we are not interested in this + * frame, except for updating transmitter's path info. + */ + process = false; + fresh_info = false; + } else { + mpath = mesh_path_lookup(sdata, orig_addr); + if (mpath) { + spin_lock_bh(&mpath->state_lock); + if (mpath->flags & MESH_PATH_FIXED) + fresh_info = false; + else if ((mpath->flags & MESH_PATH_ACTIVE) && + (mpath->flags & MESH_PATH_SN_VALID)) { + if (SN_GT(mpath->sn, orig_sn) || + (mpath->sn == orig_sn && + (rcu_access_pointer(mpath->next_hop) != + sta ? + mult_frac(new_metric, 10, 9) : + new_metric) >= mpath->metric)) { + process = false; + fresh_info = false; + } + } else if (!(mpath->flags & MESH_PATH_ACTIVE)) { + bool have_sn, newer_sn, bounced; + + have_sn = mpath->flags & MESH_PATH_SN_VALID; + newer_sn = have_sn && SN_GT(orig_sn, mpath->sn); + bounced = have_sn && + (SN_DELTA(orig_sn, mpath->sn) > + MAX_SANE_SN_DELTA); + + if (!have_sn || newer_sn) { + /* if SN is newer than what we had + * then we can take it */; + } else if (bounced) { + /* if SN is way different than what + * we had then assume the other side + * rebooted or restarted */; + } else { + process = false; + fresh_info = false; + } + } + } else { + mpath = mesh_path_add(sdata, orig_addr); + if (IS_ERR(mpath)) { + rcu_read_unlock(); + return 0; + } + spin_lock_bh(&mpath->state_lock); + } + + if (fresh_info) { + if (rcu_access_pointer(mpath->next_hop) != sta) + mpath->path_change_count++; + mesh_path_assign_nexthop(mpath, sta); + mpath->flags |= MESH_PATH_SN_VALID; + mpath->metric = new_metric; + mpath->sn = orig_sn; + mpath->exp_time = time_after(mpath->exp_time, exp_time) + ? mpath->exp_time : exp_time; + mpath->hop_count = hopcount; + mesh_path_activate(mpath); + spin_unlock_bh(&mpath->state_lock); + ewma_mesh_fail_avg_init(&sta->mesh->fail_avg); + /* init it at a low value - 0 start is tricky */ + ewma_mesh_fail_avg_add(&sta->mesh->fail_avg, 1); + mesh_path_tx_pending(mpath); + /* draft says preq_id should be saved to, but there does + * not seem to be any use for it, skipping by now + */ + } else + spin_unlock_bh(&mpath->state_lock); + } + + /* Update and check transmitter routing info */ + ta = mgmt->sa; + if (ether_addr_equal(orig_addr, ta)) + fresh_info = false; + else { + fresh_info = true; + + mpath = mesh_path_lookup(sdata, ta); + if (mpath) { + spin_lock_bh(&mpath->state_lock); + if ((mpath->flags & MESH_PATH_FIXED) || + ((mpath->flags & MESH_PATH_ACTIVE) && + ((rcu_access_pointer(mpath->next_hop) != sta ? 
+ mult_frac(last_hop_metric, 10, 9) : + last_hop_metric) > mpath->metric))) + fresh_info = false; + } else { + mpath = mesh_path_add(sdata, ta); + if (IS_ERR(mpath)) { + rcu_read_unlock(); + return 0; + } + spin_lock_bh(&mpath->state_lock); + } + + if (fresh_info) { + if (rcu_access_pointer(mpath->next_hop) != sta) + mpath->path_change_count++; + mesh_path_assign_nexthop(mpath, sta); + mpath->metric = last_hop_metric; + mpath->exp_time = time_after(mpath->exp_time, exp_time) + ? mpath->exp_time : exp_time; + mpath->hop_count = 1; + mesh_path_activate(mpath); + spin_unlock_bh(&mpath->state_lock); + ewma_mesh_fail_avg_init(&sta->mesh->fail_avg); + /* init it at a low value - 0 start is tricky */ + ewma_mesh_fail_avg_add(&sta->mesh->fail_avg, 1); + mesh_path_tx_pending(mpath); + } else + spin_unlock_bh(&mpath->state_lock); + } + + rcu_read_unlock(); + + return process ? new_metric : 0; +} + +static void hwmp_preq_frame_process(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + const u8 *preq_elem, u32 orig_metric) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_path *mpath = NULL; + const u8 *target_addr, *orig_addr; + const u8 *da; + u8 target_flags, ttl, flags; + u32 orig_sn, target_sn, lifetime, target_metric = 0; + bool reply = false; + bool forward = true; + bool root_is_gate; + + /* Update target SN, if present */ + target_addr = PREQ_IE_TARGET_ADDR(preq_elem); + orig_addr = PREQ_IE_ORIG_ADDR(preq_elem); + target_sn = PREQ_IE_TARGET_SN(preq_elem); + orig_sn = PREQ_IE_ORIG_SN(preq_elem); + target_flags = PREQ_IE_TARGET_F(preq_elem); + /* Proactive PREQ gate announcements */ + flags = PREQ_IE_FLAGS(preq_elem); + root_is_gate = !!(flags & RANN_FLAG_IS_GATE); + + mhwmp_dbg(sdata, "received PREQ from %pM\n", orig_addr); + + if (ether_addr_equal(target_addr, sdata->vif.addr)) { + mhwmp_dbg(sdata, "PREQ is for us\n"); + forward = false; + reply = true; + target_metric = 0; + + if (SN_GT(target_sn, ifmsh->sn)) + ifmsh->sn = target_sn; + + if (time_after(jiffies, ifmsh->last_sn_update + + net_traversal_jiffies(sdata)) || + time_before(jiffies, ifmsh->last_sn_update)) { + ++ifmsh->sn; + ifmsh->last_sn_update = jiffies; + } + target_sn = ifmsh->sn; + } else if (is_broadcast_ether_addr(target_addr) && + (target_flags & IEEE80211_PREQ_TO_FLAG)) { + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, orig_addr); + if (mpath) { + if (flags & IEEE80211_PREQ_PROACTIVE_PREP_FLAG) { + reply = true; + target_addr = sdata->vif.addr; + target_sn = ++ifmsh->sn; + target_metric = 0; + ifmsh->last_sn_update = jiffies; + } + if (root_is_gate) + mesh_path_add_gate(mpath); + } + rcu_read_unlock(); + } else { + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, target_addr); + if (mpath) { + if ((!(mpath->flags & MESH_PATH_SN_VALID)) || + SN_LT(mpath->sn, target_sn)) { + mpath->sn = target_sn; + mpath->flags |= MESH_PATH_SN_VALID; + } else if ((!(target_flags & IEEE80211_PREQ_TO_FLAG)) && + (mpath->flags & MESH_PATH_ACTIVE)) { + reply = true; + target_metric = mpath->metric; + target_sn = mpath->sn; + /* Case E2 of sec 13.10.9.3 IEEE 802.11-2012*/ + target_flags |= IEEE80211_PREQ_TO_FLAG; + } + } + rcu_read_unlock(); + } + + if (reply) { + lifetime = PREQ_IE_LIFETIME(preq_elem); + ttl = ifmsh->mshcfg.element_ttl; + if (ttl != 0) { + mhwmp_dbg(sdata, "replying to the PREQ\n"); + mesh_path_sel_frame_tx(MPATH_PREP, 0, orig_addr, + orig_sn, 0, target_addr, + target_sn, mgmt->sa, 0, ttl, + lifetime, target_metric, 0, + sdata); + } else { + ifmsh->mshstats.dropped_frames_ttl++; + } 
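+
+	/*
+	 * The sequence-number checks above (SN_GT()/SN_LT()) rely on the
+	 * signed cast of an unsigned difference, so comparisons stay correct
+	 * across 32-bit wraparound.  A minimal stand-alone illustration of
+	 * that idea (sn_gt() below mirrors SN_GT() rather than reusing it;
+	 * the numbers are made-up examples):
+	 *
+	 *	#include <stdint.h>
+	 *	#include <assert.h>
+	 *
+	 *	static int sn_gt(uint32_t x, uint32_t y)
+	 *	{
+	 *		return (int32_t)(y - x) < 0;	// x is "newer" than y
+	 *	}
+	 *
+	 *	int main(void)
+	 *	{
+	 *		assert(sn_gt(10, 5));			// plain case
+	 *		assert(sn_gt(2, 0xfffffffeU));		// 2 is newer: SN wrapped
+	 *		assert(!sn_gt(0xfffffffeU, 2));
+	 *		return 0;
+	 *	}
+	 *
+	 * MAX_SANE_SN_DELTA is used further up to treat a very large jump as a
+	 * peer reboot or restart rather than as a newer path.
+	 */
+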
+ } + + if (forward && ifmsh->mshcfg.dot11MeshForwarding) { + u32 preq_id; + u8 hopcount; + + ttl = PREQ_IE_TTL(preq_elem); + lifetime = PREQ_IE_LIFETIME(preq_elem); + if (ttl <= 1) { + ifmsh->mshstats.dropped_frames_ttl++; + return; + } + mhwmp_dbg(sdata, "forwarding the PREQ from %pM\n", orig_addr); + --ttl; + preq_id = PREQ_IE_PREQ_ID(preq_elem); + hopcount = PREQ_IE_HOPCOUNT(preq_elem) + 1; + da = (mpath && mpath->is_root) ? + mpath->rann_snd_addr : broadcast_addr; + + if (flags & IEEE80211_PREQ_PROACTIVE_PREP_FLAG) { + target_addr = PREQ_IE_TARGET_ADDR(preq_elem); + target_sn = PREQ_IE_TARGET_SN(preq_elem); + } + + mesh_path_sel_frame_tx(MPATH_PREQ, flags, orig_addr, + orig_sn, target_flags, target_addr, + target_sn, da, hopcount, ttl, lifetime, + orig_metric, preq_id, sdata); + if (!is_multicast_ether_addr(da)) + ifmsh->mshstats.fwded_unicast++; + else + ifmsh->mshstats.fwded_mcast++; + ifmsh->mshstats.fwded_frames++; + } +} + + +static inline struct sta_info * +next_hop_deref_protected(struct mesh_path *mpath) +{ + return rcu_dereference_protected(mpath->next_hop, + lockdep_is_held(&mpath->state_lock)); +} + + +static void hwmp_prep_frame_process(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + const u8 *prep_elem, u32 metric) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_path *mpath; + const u8 *target_addr, *orig_addr; + u8 ttl, hopcount, flags; + u8 next_hop[ETH_ALEN]; + u32 target_sn, orig_sn, lifetime; + + mhwmp_dbg(sdata, "received PREP from %pM\n", + PREP_IE_TARGET_ADDR(prep_elem)); + + orig_addr = PREP_IE_ORIG_ADDR(prep_elem); + if (ether_addr_equal(orig_addr, sdata->vif.addr)) + /* destination, no forwarding required */ + return; + + if (!ifmsh->mshcfg.dot11MeshForwarding) + return; + + ttl = PREP_IE_TTL(prep_elem); + if (ttl <= 1) { + sdata->u.mesh.mshstats.dropped_frames_ttl++; + return; + } + + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, orig_addr); + if (mpath) + spin_lock_bh(&mpath->state_lock); + else + goto fail; + if (!(mpath->flags & MESH_PATH_ACTIVE)) { + spin_unlock_bh(&mpath->state_lock); + goto fail; + } + memcpy(next_hop, next_hop_deref_protected(mpath)->sta.addr, ETH_ALEN); + spin_unlock_bh(&mpath->state_lock); + --ttl; + flags = PREP_IE_FLAGS(prep_elem); + lifetime = PREP_IE_LIFETIME(prep_elem); + hopcount = PREP_IE_HOPCOUNT(prep_elem) + 1; + target_addr = PREP_IE_TARGET_ADDR(prep_elem); + target_sn = PREP_IE_TARGET_SN(prep_elem); + orig_sn = PREP_IE_ORIG_SN(prep_elem); + + mesh_path_sel_frame_tx(MPATH_PREP, flags, orig_addr, orig_sn, 0, + target_addr, target_sn, next_hop, hopcount, + ttl, lifetime, metric, 0, sdata); + rcu_read_unlock(); + + sdata->u.mesh.mshstats.fwded_unicast++; + sdata->u.mesh.mshstats.fwded_frames++; + return; + +fail: + rcu_read_unlock(); + sdata->u.mesh.mshstats.dropped_frames_no_route++; +} + +static void hwmp_perr_frame_process(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + const u8 *perr_elem) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_path *mpath; + u8 ttl; + const u8 *ta, *target_addr; + u32 target_sn; + u16 target_rcode; + + ta = mgmt->sa; + ttl = PERR_IE_TTL(perr_elem); + if (ttl <= 1) { + ifmsh->mshstats.dropped_frames_ttl++; + return; + } + ttl--; + target_addr = PERR_IE_TARGET_ADDR(perr_elem); + target_sn = PERR_IE_TARGET_SN(perr_elem); + target_rcode = PERR_IE_TARGET_RCODE(perr_elem); + + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, target_addr); + if (mpath) { + struct sta_info *sta; + + 
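+
+	/*
+	 * The PREQ/PREP metrics used throughout this file come from
+	 * airtime_link_metric_get() above: a fixed-point estimate of
+	 * (overhead + frame transmission time) scaled by the expected number
+	 * of retransmissions.  A self-contained restatement of that
+	 * arithmetic with example inputs (rate is in units of 100 kbps as in
+	 * the code above; the sample values are made up):
+	 *
+	 *	#include <stdint.h>
+	 *	#include <stdio.h>
+	 *
+	 *	#define TEST_FRAME_LEN	8192
+	 *	#define ARITH_SHIFT	8
+	 *
+	 *	static uint32_t airtime_metric(int rate, int err)
+	 *	{
+	 *		int device_constant = 1 << ARITH_SHIFT;
+	 *		int test_frame_len = TEST_FRAME_LEN << ARITH_SHIFT;
+	 *		int s_unit = 1 << ARITH_SHIFT;
+	 *		uint32_t tx_time, estimated_retx;
+	 *		uint64_t result;
+	 *
+	 *		tx_time = device_constant + 10 * test_frame_len / rate;
+	 *		estimated_retx = (1 << (2 * ARITH_SHIFT)) / (s_unit - err);
+	 *		result = ((uint64_t)tx_time * estimated_retx) >> (2 * ARITH_SHIFT);
+	 *		return (uint32_t)result;
+	 *	}
+	 *
+	 *	int main(void)
+	 *	{
+	 *		// e.g. a 54 Mbit/s link (rate = 540) with ~10% recent
+	 *		// failures (err ~ 26 out of 256)
+	 *		printf("metric = %u\n", airtime_metric(540, 26));
+	 *		return 0;
+	 *	}
+	 *
+	 * A lower metric therefore means a faster, more reliable link.
+	 */
+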
spin_lock_bh(&mpath->state_lock); + sta = next_hop_deref_protected(mpath); + if (mpath->flags & MESH_PATH_ACTIVE && + ether_addr_equal(ta, sta->sta.addr) && + !(mpath->flags & MESH_PATH_FIXED) && + (!(mpath->flags & MESH_PATH_SN_VALID) || + SN_GT(target_sn, mpath->sn) || target_sn == 0)) { + mpath->flags &= ~MESH_PATH_ACTIVE; + if (target_sn != 0) + mpath->sn = target_sn; + else + mpath->sn += 1; + spin_unlock_bh(&mpath->state_lock); + if (!ifmsh->mshcfg.dot11MeshForwarding) + goto endperr; + mesh_path_error_tx(sdata, ttl, target_addr, + target_sn, target_rcode, + broadcast_addr); + } else + spin_unlock_bh(&mpath->state_lock); + } +endperr: + rcu_read_unlock(); +} + +static void hwmp_rann_frame_process(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + const struct ieee80211_rann_ie *rann) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct mesh_path *mpath; + u8 ttl, flags, hopcount; + const u8 *orig_addr; + u32 orig_sn, new_metric, orig_metric, last_hop_metric, interval; + bool root_is_gate; + + ttl = rann->rann_ttl; + flags = rann->rann_flags; + root_is_gate = !!(flags & RANN_FLAG_IS_GATE); + orig_addr = rann->rann_addr; + orig_sn = le32_to_cpu(rann->rann_seq); + interval = le32_to_cpu(rann->rann_interval); + hopcount = rann->rann_hopcount; + hopcount++; + orig_metric = le32_to_cpu(rann->rann_metric); + + /* Ignore our own RANNs */ + if (ether_addr_equal(orig_addr, sdata->vif.addr)) + return; + + mhwmp_dbg(sdata, + "received RANN from %pM via neighbour %pM (is_gate=%d)\n", + orig_addr, mgmt->sa, root_is_gate); + + rcu_read_lock(); + sta = sta_info_get(sdata, mgmt->sa); + if (!sta) { + rcu_read_unlock(); + return; + } + + last_hop_metric = airtime_link_metric_get(local, sta); + new_metric = orig_metric + last_hop_metric; + if (new_metric < orig_metric) + new_metric = MAX_METRIC; + + mpath = mesh_path_lookup(sdata, orig_addr); + if (!mpath) { + mpath = mesh_path_add(sdata, orig_addr); + if (IS_ERR(mpath)) { + rcu_read_unlock(); + sdata->u.mesh.mshstats.dropped_frames_no_route++; + return; + } + } + + if (!(SN_LT(mpath->sn, orig_sn)) && + !(mpath->sn == orig_sn && new_metric < mpath->rann_metric)) { + rcu_read_unlock(); + return; + } + + if ((!(mpath->flags & (MESH_PATH_ACTIVE | MESH_PATH_RESOLVING)) || + (time_after(jiffies, mpath->last_preq_to_root + + root_path_confirmation_jiffies(sdata)) || + time_before(jiffies, mpath->last_preq_to_root))) && + !(mpath->flags & MESH_PATH_FIXED) && (ttl != 0)) { + mhwmp_dbg(sdata, + "time to refresh root mpath %pM\n", + orig_addr); + mesh_queue_preq(mpath, PREQ_Q_F_START | PREQ_Q_F_REFRESH); + mpath->last_preq_to_root = jiffies; + } + + mpath->sn = orig_sn; + mpath->rann_metric = new_metric; + mpath->is_root = true; + /* Recording RANNs sender address to send individually + * addressed PREQs destined for root mesh STA */ + memcpy(mpath->rann_snd_addr, mgmt->sa, ETH_ALEN); + + if (root_is_gate) + mesh_path_add_gate(mpath); + + if (ttl <= 1) { + ifmsh->mshstats.dropped_frames_ttl++; + rcu_read_unlock(); + return; + } + ttl--; + + if (ifmsh->mshcfg.dot11MeshForwarding) { + mesh_path_sel_frame_tx(MPATH_RANN, flags, orig_addr, + orig_sn, 0, NULL, 0, broadcast_addr, + hopcount, ttl, interval, + new_metric, 0, sdata); + } + + rcu_read_unlock(); +} + + +void mesh_rx_path_sel_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee802_11_elems *elems; + size_t baselen; + u32 path_metric; + struct sta_info *sta; + 
+ /* need action_code */ + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + return; + + rcu_read_lock(); + sta = sta_info_get(sdata, mgmt->sa); + if (!sta || sta->mesh->plink_state != NL80211_PLINK_ESTAB) { + rcu_read_unlock(); + return; + } + rcu_read_unlock(); + + baselen = (u8 *) mgmt->u.action.u.mesh_action.variable - (u8 *) mgmt; + elems = ieee802_11_parse_elems(mgmt->u.action.u.mesh_action.variable, + len - baselen, false, NULL); + if (!elems) + return; + + if (elems->preq) { + if (elems->preq_len != 37) + /* Right now we support just 1 destination and no AE */ + goto free; + path_metric = hwmp_route_info_get(sdata, mgmt, elems->preq, + MPATH_PREQ); + if (path_metric) + hwmp_preq_frame_process(sdata, mgmt, elems->preq, + path_metric); + } + if (elems->prep) { + if (elems->prep_len != 31) + /* Right now we support no AE */ + goto free; + path_metric = hwmp_route_info_get(sdata, mgmt, elems->prep, + MPATH_PREP); + if (path_metric) + hwmp_prep_frame_process(sdata, mgmt, elems->prep, + path_metric); + } + if (elems->perr) { + if (elems->perr_len != 15) + /* Right now we support only one destination per PERR */ + goto free; + hwmp_perr_frame_process(sdata, mgmt, elems->perr); + } + if (elems->rann) + hwmp_rann_frame_process(sdata, mgmt, elems->rann); +free: + kfree(elems); +} + +/** + * mesh_queue_preq - queue a PREQ to a given destination + * + * @mpath: mesh path to discover + * @flags: special attributes of the PREQ to be sent + * + * Locking: the function must be called from within a rcu read lock block. + * + */ +static void mesh_queue_preq(struct mesh_path *mpath, u8 flags) +{ + struct ieee80211_sub_if_data *sdata = mpath->sdata; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_preq_queue *preq_node; + + preq_node = kmalloc(sizeof(struct mesh_preq_queue), GFP_ATOMIC); + if (!preq_node) { + mhwmp_dbg(sdata, "could not allocate PREQ node\n"); + return; + } + + spin_lock_bh(&ifmsh->mesh_preq_queue_lock); + if (ifmsh->preq_queue_len == MAX_PREQ_QUEUE_LEN) { + spin_unlock_bh(&ifmsh->mesh_preq_queue_lock); + kfree(preq_node); + if (printk_ratelimit()) + mhwmp_dbg(sdata, "PREQ node queue full\n"); + return; + } + + spin_lock(&mpath->state_lock); + if (mpath->flags & MESH_PATH_REQ_QUEUED) { + spin_unlock(&mpath->state_lock); + spin_unlock_bh(&ifmsh->mesh_preq_queue_lock); + kfree(preq_node); + return; + } + + memcpy(preq_node->dst, mpath->dst, ETH_ALEN); + preq_node->flags = flags; + + mpath->flags |= MESH_PATH_REQ_QUEUED; + spin_unlock(&mpath->state_lock); + + list_add_tail(&preq_node->list, &ifmsh->preq_queue.list); + ++ifmsh->preq_queue_len; + spin_unlock_bh(&ifmsh->mesh_preq_queue_lock); + + if (time_after(jiffies, ifmsh->last_preq + min_preq_int_jiff(sdata))) + ieee80211_queue_work(&sdata->local->hw, &sdata->work); + + else if (time_before(jiffies, ifmsh->last_preq)) { + /* avoid long wait if did not send preqs for a long time + * and jiffies wrapped around + */ + ifmsh->last_preq = jiffies - min_preq_int_jiff(sdata) - 1; + ieee80211_queue_work(&sdata->local->hw, &sdata->work); + } else + mod_timer(&ifmsh->mesh_path_timer, ifmsh->last_preq + + min_preq_int_jiff(sdata)); +} + +/** + * mesh_path_start_discovery - launch a path discovery from the PREQ queue + * + * @sdata: local mesh subif + */ +void mesh_path_start_discovery(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_preq_queue *preq_node; + struct mesh_path *mpath; + u8 ttl, target_flags = 0; + const u8 *da; + u32 lifetime; + + 
spin_lock_bh(&ifmsh->mesh_preq_queue_lock); + if (!ifmsh->preq_queue_len || + time_before(jiffies, ifmsh->last_preq + + min_preq_int_jiff(sdata))) { + spin_unlock_bh(&ifmsh->mesh_preq_queue_lock); + return; + } + + preq_node = list_first_entry(&ifmsh->preq_queue.list, + struct mesh_preq_queue, list); + list_del(&preq_node->list); + --ifmsh->preq_queue_len; + spin_unlock_bh(&ifmsh->mesh_preq_queue_lock); + + rcu_read_lock(); + mpath = mesh_path_lookup(sdata, preq_node->dst); + if (!mpath) + goto enddiscovery; + + spin_lock_bh(&mpath->state_lock); + if (mpath->flags & (MESH_PATH_DELETED | MESH_PATH_FIXED)) { + spin_unlock_bh(&mpath->state_lock); + goto enddiscovery; + } + mpath->flags &= ~MESH_PATH_REQ_QUEUED; + if (preq_node->flags & PREQ_Q_F_START) { + if (mpath->flags & MESH_PATH_RESOLVING) { + spin_unlock_bh(&mpath->state_lock); + goto enddiscovery; + } else { + mpath->flags &= ~MESH_PATH_RESOLVED; + mpath->flags |= MESH_PATH_RESOLVING; + mpath->discovery_retries = 0; + mpath->discovery_timeout = disc_timeout_jiff(sdata); + } + } else if (!(mpath->flags & MESH_PATH_RESOLVING) || + mpath->flags & MESH_PATH_RESOLVED) { + mpath->flags &= ~MESH_PATH_RESOLVING; + spin_unlock_bh(&mpath->state_lock); + goto enddiscovery; + } + + ifmsh->last_preq = jiffies; + + if (time_after(jiffies, ifmsh->last_sn_update + + net_traversal_jiffies(sdata)) || + time_before(jiffies, ifmsh->last_sn_update)) { + ++ifmsh->sn; + sdata->u.mesh.last_sn_update = jiffies; + } + lifetime = default_lifetime(sdata); + ttl = sdata->u.mesh.mshcfg.element_ttl; + if (ttl == 0) { + sdata->u.mesh.mshstats.dropped_frames_ttl++; + spin_unlock_bh(&mpath->state_lock); + goto enddiscovery; + } + + if (preq_node->flags & PREQ_Q_F_REFRESH) + target_flags |= IEEE80211_PREQ_TO_FLAG; + else + target_flags &= ~IEEE80211_PREQ_TO_FLAG; + + spin_unlock_bh(&mpath->state_lock); + da = (mpath->is_root) ? mpath->rann_snd_addr : broadcast_addr; + mesh_path_sel_frame_tx(MPATH_PREQ, 0, sdata->vif.addr, ifmsh->sn, + target_flags, mpath->dst, mpath->sn, da, 0, + ttl, lifetime, 0, ifmsh->preq_id++, sdata); + + spin_lock_bh(&mpath->state_lock); + if (!(mpath->flags & MESH_PATH_DELETED)) + mod_timer(&mpath->timer, jiffies + mpath->discovery_timeout); + spin_unlock_bh(&mpath->state_lock); + +enddiscovery: + rcu_read_unlock(); + kfree(preq_node); +} + +/** + * mesh_nexthop_resolve - lookup next hop; conditionally start path discovery + * + * @skb: 802.11 frame to be sent + * @sdata: network subif the frame will be sent through + * + * Lookup next hop for given skb and start path discovery if no + * forwarding information is found. + * + * Returns: 0 if the next hop was found and -ENOENT if the frame was queued. + * skb is freed here if no mpath could be allocated. 
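+ *
+ * A -ENOENT return means the skb was placed on mpath->frame_queue with
+ * IEEE80211_TX_INTCFL_NEED_TXPROCESSING set and will be transmitted once
+ * discovery completes, so the caller must not free or resend it.
+ * Minimal caller sketch (illustrative only, not part of this file):
+ *
+ *	if (mesh_nexthop_resolve(sdata, skb))
+ *		return;	/* skb now owned by the mpath queue (or discarded) */
+ *	/* next hop resolved, continue the normal TX path */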
+ */ +int mesh_nexthop_resolve(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct mesh_path *mpath; + struct sk_buff *skb_to_free = NULL; + u8 *target_addr = hdr->addr3; + + /* Nulls are only sent to peers for PS and should be pre-addressed */ + if (ieee80211_is_qos_nullfunc(hdr->frame_control)) + return 0; + + /* Allow injected packets to bypass mesh routing */ + if (info->control.flags & IEEE80211_TX_CTRL_SKIP_MPATH_LOOKUP) + return 0; + + if (!mesh_nexthop_lookup(sdata, skb)) + return 0; + + /* no nexthop found, start resolving */ + mpath = mesh_path_lookup(sdata, target_addr); + if (!mpath) { + mpath = mesh_path_add(sdata, target_addr); + if (IS_ERR(mpath)) { + mesh_path_discard_frame(sdata, skb); + return PTR_ERR(mpath); + } + } + + if (!(mpath->flags & MESH_PATH_RESOLVING) && + mesh_path_sel_is_hwmp(sdata)) + mesh_queue_preq(mpath, PREQ_Q_F_START); + + if (skb_queue_len(&mpath->frame_queue) >= MESH_FRAME_QUEUE_LEN) + skb_to_free = skb_dequeue(&mpath->frame_queue); + + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + ieee80211_set_qos_hdr(sdata, skb); + skb_queue_tail(&mpath->frame_queue, skb); + if (skb_to_free) + mesh_path_discard_frame(sdata, skb_to_free); + + return -ENOENT; +} + +/** + * mesh_nexthop_lookup_nolearn - try to set next hop without path discovery + * @skb: 802.11 frame to be sent + * @sdata: network subif the frame will be sent through + * + * Check if the meshDA (addr3) of a unicast frame is a direct neighbor. + * And if so, set the RA (addr1) to it to transmit to this node directly, + * avoiding PREQ/PREP path discovery. + * + * Returns: 0 if the next hop was found and -ENOENT otherwise. + */ +static int mesh_nexthop_lookup_nolearn(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + struct sta_info *sta; + + if (is_multicast_ether_addr(hdr->addr1)) + return -ENOENT; + + rcu_read_lock(); + sta = sta_info_get(sdata, hdr->addr3); + + if (!sta || sta->mesh->plink_state != NL80211_PLINK_ESTAB) { + rcu_read_unlock(); + return -ENOENT; + } + rcu_read_unlock(); + + memcpy(hdr->addr1, hdr->addr3, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + return 0; +} + +/** + * mesh_nexthop_lookup - put the appropriate next hop on a mesh frame. Calling + * this function is considered "using" the associated mpath, so preempt a path + * refresh if this mpath expires soon. + * + * @skb: 802.11 frame to be sent + * @sdata: network subif the frame will be sent through + * + * Returns: 0 if the next hop was found. Nonzero otherwise. 
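+ *
+ * On success, addr1 of the frame is rewritten to the next hop station and
+ * addr2 to our own address; if the path expires within
+ * mshcfg.path_refresh_time, a refresh PREQ is also queued as a side effect.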
+ */ +int mesh_nexthop_lookup(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct mesh_path *mpath; + struct sta_info *next_hop; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + u8 *target_addr = hdr->addr3; + + if (ifmsh->mshcfg.dot11MeshNolearn && + !mesh_nexthop_lookup_nolearn(sdata, skb)) + return 0; + + mpath = mesh_path_lookup(sdata, target_addr); + if (!mpath || !(mpath->flags & MESH_PATH_ACTIVE)) + return -ENOENT; + + if (time_after(jiffies, + mpath->exp_time - + msecs_to_jiffies(sdata->u.mesh.mshcfg.path_refresh_time)) && + ether_addr_equal(sdata->vif.addr, hdr->addr4) && + !(mpath->flags & MESH_PATH_RESOLVING) && + !(mpath->flags & MESH_PATH_FIXED)) + mesh_queue_preq(mpath, PREQ_Q_F_START | PREQ_Q_F_REFRESH); + + next_hop = rcu_dereference(mpath->next_hop); + if (next_hop) { + memcpy(hdr->addr1, next_hop->sta.addr, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + ieee80211_mps_set_frame_flags(sdata, next_hop, hdr); + return 0; + } + + return -ENOENT; +} + +void mesh_path_timer(struct timer_list *t) +{ + struct mesh_path *mpath = from_timer(mpath, t, timer); + struct ieee80211_sub_if_data *sdata = mpath->sdata; + int ret; + + if (sdata->local->quiescing) + return; + + spin_lock_bh(&mpath->state_lock); + if (mpath->flags & MESH_PATH_RESOLVED || + (!(mpath->flags & MESH_PATH_RESOLVING))) { + mpath->flags &= ~(MESH_PATH_RESOLVING | MESH_PATH_RESOLVED); + spin_unlock_bh(&mpath->state_lock); + } else if (mpath->discovery_retries < max_preq_retries(sdata)) { + ++mpath->discovery_retries; + mpath->discovery_timeout *= 2; + mpath->flags &= ~MESH_PATH_REQ_QUEUED; + spin_unlock_bh(&mpath->state_lock); + mesh_queue_preq(mpath, 0); + } else { + mpath->flags &= ~(MESH_PATH_RESOLVING | + MESH_PATH_RESOLVED | + MESH_PATH_REQ_QUEUED); + mpath->exp_time = jiffies; + spin_unlock_bh(&mpath->state_lock); + if (!mpath->is_gate && mesh_gate_num(sdata) > 0) { + ret = mesh_path_send_to_gates(mpath); + if (ret) + mhwmp_dbg(sdata, "no gate was reachable\n"); + } else + mesh_path_flush_pending(mpath); + } +} + +void mesh_path_tx_root_frame(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u32 interval = ifmsh->mshcfg.dot11MeshHWMPRannInterval; + u8 flags, target_flags = 0; + + flags = (ifmsh->mshcfg.dot11MeshGateAnnouncementProtocol) + ? RANN_FLAG_IS_GATE : 0; + + switch (ifmsh->mshcfg.dot11MeshHWMPRootMode) { + case IEEE80211_PROACTIVE_RANN: + mesh_path_sel_frame_tx(MPATH_RANN, flags, sdata->vif.addr, + ++ifmsh->sn, 0, NULL, 0, broadcast_addr, + 0, ifmsh->mshcfg.element_ttl, + interval, 0, 0, sdata); + break; + case IEEE80211_PROACTIVE_PREQ_WITH_PREP: + flags |= IEEE80211_PREQ_PROACTIVE_PREP_FLAG; + fallthrough; + case IEEE80211_PROACTIVE_PREQ_NO_PREP: + interval = ifmsh->mshcfg.dot11MeshHWMPactivePathToRootTimeout; + target_flags |= IEEE80211_PREQ_TO_FLAG | + IEEE80211_PREQ_USN_FLAG; + mesh_path_sel_frame_tx(MPATH_PREQ, flags, sdata->vif.addr, + ++ifmsh->sn, target_flags, + (u8 *) broadcast_addr, 0, broadcast_addr, + 0, ifmsh->mshcfg.element_ttl, interval, + 0, ifmsh->preq_id++, sdata); + break; + default: + mhwmp_dbg(sdata, "Proactive mechanism not supported\n"); + return; + } +} diff --git a/net/mac80211/mesh_pathtbl.c b/net/mac80211/mesh_pathtbl.c new file mode 100644 index 000000000..69d5e1ec6 --- /dev/null +++ b/net/mac80211/mesh_pathtbl.c @@ -0,0 +1,790 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, 2009 open80211s Ltd. 
+ * Author: Luis Carlos Cobo + */ + +#include <linux/etherdevice.h> +#include <linux/list.h> +#include <linux/random.h> +#include <linux/slab.h> +#include <linux/spinlock.h> +#include <linux/string.h> +#include <net/mac80211.h> +#include "wme.h" +#include "ieee80211_i.h" +#include "mesh.h" + +static void mesh_path_free_rcu(struct mesh_table *tbl, struct mesh_path *mpath); + +static u32 mesh_table_hash(const void *addr, u32 len, u32 seed) +{ + /* Use last four bytes of hw addr as hash index */ + return jhash_1word(__get_unaligned_cpu32((u8 *)addr + 2), seed); +} + +static const struct rhashtable_params mesh_rht_params = { + .nelem_hint = 2, + .automatic_shrinking = true, + .key_len = ETH_ALEN, + .key_offset = offsetof(struct mesh_path, dst), + .head_offset = offsetof(struct mesh_path, rhash), + .hashfn = mesh_table_hash, +}; + +static inline bool mpath_expired(struct mesh_path *mpath) +{ + return (mpath->flags & MESH_PATH_ACTIVE) && + time_after(jiffies, mpath->exp_time) && + !(mpath->flags & MESH_PATH_FIXED); +} + +static void mesh_path_rht_free(void *ptr, void *tblptr) +{ + struct mesh_path *mpath = ptr; + struct mesh_table *tbl = tblptr; + + mesh_path_free_rcu(tbl, mpath); +} + +static void mesh_table_init(struct mesh_table *tbl) +{ + INIT_HLIST_HEAD(&tbl->known_gates); + INIT_HLIST_HEAD(&tbl->walk_head); + atomic_set(&tbl->entries, 0); + spin_lock_init(&tbl->gates_lock); + spin_lock_init(&tbl->walk_lock); + + /* rhashtable_init() may fail only in case of wrong + * mesh_rht_params + */ + WARN_ON(rhashtable_init(&tbl->rhead, &mesh_rht_params)); +} + +static void mesh_table_free(struct mesh_table *tbl) +{ + rhashtable_free_and_destroy(&tbl->rhead, + mesh_path_rht_free, tbl); +} + +/** + * mesh_path_assign_nexthop - update mesh path next hop + * + * @mpath: mesh path to update + * @sta: next hop to assign + * + * Locking: mpath->state_lock must be held when calling this function + */ +void mesh_path_assign_nexthop(struct mesh_path *mpath, struct sta_info *sta) +{ + struct sk_buff *skb; + struct ieee80211_hdr *hdr; + unsigned long flags; + + rcu_assign_pointer(mpath->next_hop, sta); + + spin_lock_irqsave(&mpath->frame_queue.lock, flags); + skb_queue_walk(&mpath->frame_queue, skb) { + hdr = (struct ieee80211_hdr *) skb->data; + memcpy(hdr->addr1, sta->sta.addr, ETH_ALEN); + memcpy(hdr->addr2, mpath->sdata->vif.addr, ETH_ALEN); + ieee80211_mps_set_frame_flags(sta->sdata, sta, hdr); + } + + spin_unlock_irqrestore(&mpath->frame_queue.lock, flags); +} + +static void prepare_for_gate(struct sk_buff *skb, char *dst_addr, + struct mesh_path *gate_mpath) +{ + struct ieee80211_hdr *hdr; + struct ieee80211s_hdr *mshdr; + int mesh_hdrlen, hdrlen; + char *next_hop; + + hdr = (struct ieee80211_hdr *) skb->data; + hdrlen = ieee80211_hdrlen(hdr->frame_control); + mshdr = (struct ieee80211s_hdr *) (skb->data + hdrlen); + + if (!(mshdr->flags & MESH_FLAGS_AE)) { + /* size of the fixed part of the mesh header */ + mesh_hdrlen = 6; + + /* make room for the two extended addresses */ + skb_push(skb, 2 * ETH_ALEN); + memmove(skb->data, hdr, hdrlen + mesh_hdrlen); + + hdr = (struct ieee80211_hdr *) skb->data; + + /* we preserve the previous mesh header and only add + * the new addresses */ + mshdr = (struct ieee80211s_hdr *) (skb->data + hdrlen); + mshdr->flags = MESH_FLAGS_AE_A5_A6; + memcpy(mshdr->eaddr1, hdr->addr3, ETH_ALEN); + memcpy(mshdr->eaddr2, hdr->addr4, ETH_ALEN); + } + + /* update next hop */ + hdr = (struct ieee80211_hdr *) skb->data; + rcu_read_lock(); + next_hop = rcu_dereference(gate_mpath->next_hop)->sta.addr; + memcpy(hdr->addr1, next_hop, ETH_ALEN); + rcu_read_unlock(); + memcpy(hdr->addr2,
gate_mpath->sdata->vif.addr, ETH_ALEN); + memcpy(hdr->addr3, dst_addr, ETH_ALEN); +} + +/** + * mesh_path_move_to_queue - Move or copy frames from one mpath queue to another + * + * This function is used to transfer or copy frames from an unresolved mpath to + * a gate mpath. The function also adds the Address Extension field and + * updates the next hop. + * + * If a frame already has an Address Extension field, only the next hop and + * destination addresses are updated. + * + * The gate mpath must be an active mpath with a valid mpath->next_hop. + * + * @gate_mpath: An active mpath the frames will be sent to (i.e. the gate) + * @from_mpath: The failed mpath + * @copy: When true, copy all the frames to the new mpath queue. When false, + * move them. + */ +static void mesh_path_move_to_queue(struct mesh_path *gate_mpath, + struct mesh_path *from_mpath, + bool copy) +{ + struct sk_buff *skb, *fskb, *tmp; + struct sk_buff_head failq; + unsigned long flags; + + if (WARN_ON(gate_mpath == from_mpath)) + return; + if (WARN_ON(!gate_mpath->next_hop)) + return; + + __skb_queue_head_init(&failq); + + spin_lock_irqsave(&from_mpath->frame_queue.lock, flags); + skb_queue_splice_init(&from_mpath->frame_queue, &failq); + spin_unlock_irqrestore(&from_mpath->frame_queue.lock, flags); + + skb_queue_walk_safe(&failq, fskb, tmp) { + if (skb_queue_len(&gate_mpath->frame_queue) >= + MESH_FRAME_QUEUE_LEN) { + mpath_dbg(gate_mpath->sdata, "mpath queue full!\n"); + break; + } + + skb = skb_copy(fskb, GFP_ATOMIC); + if (WARN_ON(!skb)) + break; + + prepare_for_gate(skb, gate_mpath->dst, gate_mpath); + skb_queue_tail(&gate_mpath->frame_queue, skb); + + if (copy) + continue; + + __skb_unlink(fskb, &failq); + kfree_skb(fskb); + } + + mpath_dbg(gate_mpath->sdata, "Mpath queue for gate %pM has %d frames\n", + gate_mpath->dst, skb_queue_len(&gate_mpath->frame_queue)); + + if (!copy) + return; + + spin_lock_irqsave(&from_mpath->frame_queue.lock, flags); + skb_queue_splice(&failq, &from_mpath->frame_queue); + spin_unlock_irqrestore(&from_mpath->frame_queue.lock, flags); +} + + +static struct mesh_path *mpath_lookup(struct mesh_table *tbl, const u8 *dst, + struct ieee80211_sub_if_data *sdata) +{ + struct mesh_path *mpath; + + mpath = rhashtable_lookup(&tbl->rhead, dst, mesh_rht_params); + + if (mpath && mpath_expired(mpath)) { + spin_lock_bh(&mpath->state_lock); + mpath->flags &= ~MESH_PATH_ACTIVE; + spin_unlock_bh(&mpath->state_lock); + } + return mpath; +} + +/** + * mesh_path_lookup - look up a path in the mesh path table + * @sdata: local subif + * @dst: hardware address (ETH_ALEN length) of destination + * + * Returns: pointer to the mesh path structure, or NULL if not found + * + * Locking: must be called within a read rcu section. 
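+ *
+ * If the entry has expired, MESH_PATH_ACTIVE is cleared before it is
+ * returned, so callers should re-check that flag. Illustrative usage
+ * sketch (not part of this file):
+ *
+ *	rcu_read_lock();
+ *	mpath = mesh_path_lookup(sdata, dst);
+ *	if (mpath && (mpath->flags & MESH_PATH_ACTIVE))
+ *		next_hop = rcu_dereference(mpath->next_hop);
+ *	rcu_read_unlock();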
+ */ +struct mesh_path * +mesh_path_lookup(struct ieee80211_sub_if_data *sdata, const u8 *dst) +{ + return mpath_lookup(&sdata->u.mesh.mesh_paths, dst, sdata); +} + +struct mesh_path * +mpp_path_lookup(struct ieee80211_sub_if_data *sdata, const u8 *dst) +{ + return mpath_lookup(&sdata->u.mesh.mpp_paths, dst, sdata); +} + +static struct mesh_path * +__mesh_path_lookup_by_idx(struct mesh_table *tbl, int idx) +{ + int i = 0; + struct mesh_path *mpath; + + hlist_for_each_entry_rcu(mpath, &tbl->walk_head, walk_list) { + if (i++ == idx) + break; + } + + if (!mpath) + return NULL; + + if (mpath_expired(mpath)) { + spin_lock_bh(&mpath->state_lock); + mpath->flags &= ~MESH_PATH_ACTIVE; + spin_unlock_bh(&mpath->state_lock); + } + return mpath; +} + +/** + * mesh_path_lookup_by_idx - look up a path in the mesh path table by its index + * @idx: index + * @sdata: local subif, or NULL for all entries + * + * Returns: pointer to the mesh path structure, or NULL if not found. + * + * Locking: must be called within a read rcu section. + */ +struct mesh_path * +mesh_path_lookup_by_idx(struct ieee80211_sub_if_data *sdata, int idx) +{ + return __mesh_path_lookup_by_idx(&sdata->u.mesh.mesh_paths, idx); +} + +/** + * mpp_path_lookup_by_idx - look up a path in the proxy path table by its index + * @idx: index + * @sdata: local subif, or NULL for all entries + * + * Returns: pointer to the proxy path structure, or NULL if not found. + * + * Locking: must be called within a read rcu section. + */ +struct mesh_path * +mpp_path_lookup_by_idx(struct ieee80211_sub_if_data *sdata, int idx) +{ + return __mesh_path_lookup_by_idx(&sdata->u.mesh.mpp_paths, idx); +} + +/** + * mesh_path_add_gate - add the given mpath to a mesh gate to our path table + * @mpath: gate path to add to table + */ +int mesh_path_add_gate(struct mesh_path *mpath) +{ + struct mesh_table *tbl; + int err; + + rcu_read_lock(); + tbl = &mpath->sdata->u.mesh.mesh_paths; + + spin_lock_bh(&mpath->state_lock); + if (mpath->is_gate) { + err = -EEXIST; + spin_unlock_bh(&mpath->state_lock); + goto err_rcu; + } + mpath->is_gate = true; + mpath->sdata->u.mesh.num_gates++; + + spin_lock(&tbl->gates_lock); + hlist_add_head_rcu(&mpath->gate_list, &tbl->known_gates); + spin_unlock(&tbl->gates_lock); + + spin_unlock_bh(&mpath->state_lock); + + mpath_dbg(mpath->sdata, + "Mesh path: Recorded new gate: %pM. %d known gates\n", + mpath->dst, mpath->sdata->u.mesh.num_gates); + err = 0; +err_rcu: + rcu_read_unlock(); + return err; +} + +/** + * mesh_gate_del - remove a mesh gate from the list of known gates + * @tbl: table which holds our list of known gates + * @mpath: gate mpath + */ +static void mesh_gate_del(struct mesh_table *tbl, struct mesh_path *mpath) +{ + lockdep_assert_held(&mpath->state_lock); + if (!mpath->is_gate) + return; + + mpath->is_gate = false; + spin_lock_bh(&tbl->gates_lock); + hlist_del_rcu(&mpath->gate_list); + mpath->sdata->u.mesh.num_gates--; + spin_unlock_bh(&tbl->gates_lock); + + mpath_dbg(mpath->sdata, + "Mesh path: Deleted gate: %pM. 
%d known gates\n", + mpath->dst, mpath->sdata->u.mesh.num_gates); +} + +/** + * mesh_gate_num - number of gates known to this interface + * @sdata: subif data + */ +int mesh_gate_num(struct ieee80211_sub_if_data *sdata) +{ + return sdata->u.mesh.num_gates; +} + +static +struct mesh_path *mesh_path_new(struct ieee80211_sub_if_data *sdata, + const u8 *dst, gfp_t gfp_flags) +{ + struct mesh_path *new_mpath; + + new_mpath = kzalloc(sizeof(struct mesh_path), gfp_flags); + if (!new_mpath) + return NULL; + + memcpy(new_mpath->dst, dst, ETH_ALEN); + eth_broadcast_addr(new_mpath->rann_snd_addr); + new_mpath->is_root = false; + new_mpath->sdata = sdata; + new_mpath->flags = 0; + skb_queue_head_init(&new_mpath->frame_queue); + new_mpath->exp_time = jiffies; + spin_lock_init(&new_mpath->state_lock); + timer_setup(&new_mpath->timer, mesh_path_timer, 0); + + return new_mpath; +} + +/** + * mesh_path_add - allocate and add a new path to the mesh path table + * @dst: destination address of the path (ETH_ALEN length) + * @sdata: local subif + * + * Returns: 0 on success + * + * State: the initial state of the new path is set to 0 + */ +struct mesh_path *mesh_path_add(struct ieee80211_sub_if_data *sdata, + const u8 *dst) +{ + struct mesh_table *tbl; + struct mesh_path *mpath, *new_mpath; + + if (ether_addr_equal(dst, sdata->vif.addr)) + /* never add ourselves as neighbours */ + return ERR_PTR(-ENOTSUPP); + + if (is_multicast_ether_addr(dst)) + return ERR_PTR(-ENOTSUPP); + + if (atomic_add_unless(&sdata->u.mesh.mpaths, 1, MESH_MAX_MPATHS) == 0) + return ERR_PTR(-ENOSPC); + + new_mpath = mesh_path_new(sdata, dst, GFP_ATOMIC); + if (!new_mpath) + return ERR_PTR(-ENOMEM); + + tbl = &sdata->u.mesh.mesh_paths; + spin_lock_bh(&tbl->walk_lock); + mpath = rhashtable_lookup_get_insert_fast(&tbl->rhead, + &new_mpath->rhash, + mesh_rht_params); + if (!mpath) + hlist_add_head(&new_mpath->walk_list, &tbl->walk_head); + spin_unlock_bh(&tbl->walk_lock); + + if (mpath) { + kfree(new_mpath); + + if (IS_ERR(mpath)) + return mpath; + + new_mpath = mpath; + } + + sdata->u.mesh.mesh_paths_generation++; + return new_mpath; +} + +int mpp_path_add(struct ieee80211_sub_if_data *sdata, + const u8 *dst, const u8 *mpp) +{ + struct mesh_table *tbl; + struct mesh_path *new_mpath; + int ret; + + if (ether_addr_equal(dst, sdata->vif.addr)) + /* never add ourselves as neighbours */ + return -ENOTSUPP; + + if (is_multicast_ether_addr(dst)) + return -ENOTSUPP; + + new_mpath = mesh_path_new(sdata, dst, GFP_ATOMIC); + + if (!new_mpath) + return -ENOMEM; + + memcpy(new_mpath->mpp, mpp, ETH_ALEN); + tbl = &sdata->u.mesh.mpp_paths; + + spin_lock_bh(&tbl->walk_lock); + ret = rhashtable_lookup_insert_fast(&tbl->rhead, + &new_mpath->rhash, + mesh_rht_params); + if (!ret) + hlist_add_head_rcu(&new_mpath->walk_list, &tbl->walk_head); + spin_unlock_bh(&tbl->walk_lock); + + if (ret) + kfree(new_mpath); + + sdata->u.mesh.mpp_paths_generation++; + return ret; +} + + +/** + * mesh_plink_broken - deactivates paths and sends perr when a link breaks + * + * @sta: broken peer link + * + * This function must be called from the rate control algorithm if enough + * delivery errors suggest that a peer link is no longer usable. 
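+ *
+ * Every active, non-fixed path whose next hop is @sta is deactivated, its
+ * sequence number is bumped, and a PERR with reason
+ * WLAN_REASON_MESH_PATH_DEST_UNREACHABLE is broadcast so downstream nodes
+ * drop the stale route as well.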
+ */ +void mesh_plink_broken(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct mesh_table *tbl = &sdata->u.mesh.mesh_paths; + static const u8 bcast[ETH_ALEN] = {0xff, 0xff, 0xff, 0xff, 0xff, 0xff}; + struct mesh_path *mpath; + + rcu_read_lock(); + hlist_for_each_entry_rcu(mpath, &tbl->walk_head, walk_list) { + if (rcu_access_pointer(mpath->next_hop) == sta && + mpath->flags & MESH_PATH_ACTIVE && + !(mpath->flags & MESH_PATH_FIXED)) { + spin_lock_bh(&mpath->state_lock); + mpath->flags &= ~MESH_PATH_ACTIVE; + ++mpath->sn; + spin_unlock_bh(&mpath->state_lock); + mesh_path_error_tx(sdata, + sdata->u.mesh.mshcfg.element_ttl, + mpath->dst, mpath->sn, + WLAN_REASON_MESH_PATH_DEST_UNREACHABLE, bcast); + } + } + rcu_read_unlock(); +} + +static void mesh_path_free_rcu(struct mesh_table *tbl, + struct mesh_path *mpath) +{ + struct ieee80211_sub_if_data *sdata = mpath->sdata; + + spin_lock_bh(&mpath->state_lock); + mpath->flags |= MESH_PATH_RESOLVING | MESH_PATH_DELETED; + mesh_gate_del(tbl, mpath); + spin_unlock_bh(&mpath->state_lock); + del_timer_sync(&mpath->timer); + atomic_dec(&sdata->u.mesh.mpaths); + atomic_dec(&tbl->entries); + mesh_path_flush_pending(mpath); + kfree_rcu(mpath, rcu); +} + +static void __mesh_path_del(struct mesh_table *tbl, struct mesh_path *mpath) +{ + hlist_del_rcu(&mpath->walk_list); + rhashtable_remove_fast(&tbl->rhead, &mpath->rhash, mesh_rht_params); + mesh_path_free_rcu(tbl, mpath); +} + +/** + * mesh_path_flush_by_nexthop - Deletes mesh paths if their next hop matches + * + * @sta: mesh peer to match + * + * RCU notes: this function is called when a mesh plink transitions from + * PLINK_ESTAB to any other state, since PLINK_ESTAB state is the only one that + * allows path creation. This will happen before the sta can be freed (because + * sta_info_destroy() calls this) so any reader in a rcu read block will be + * protected against the plink disappearing. + */ +void mesh_path_flush_by_nexthop(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct mesh_table *tbl = &sdata->u.mesh.mesh_paths; + struct mesh_path *mpath; + struct hlist_node *n; + + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { + if (rcu_access_pointer(mpath->next_hop) == sta) + __mesh_path_del(tbl, mpath); + } + spin_unlock_bh(&tbl->walk_lock); +} + +static void mpp_flush_by_proxy(struct ieee80211_sub_if_data *sdata, + const u8 *proxy) +{ + struct mesh_table *tbl = &sdata->u.mesh.mpp_paths; + struct mesh_path *mpath; + struct hlist_node *n; + + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { + if (ether_addr_equal(mpath->mpp, proxy)) + __mesh_path_del(tbl, mpath); + } + spin_unlock_bh(&tbl->walk_lock); +} + +static void table_flush_by_iface(struct mesh_table *tbl) +{ + struct mesh_path *mpath; + struct hlist_node *n; + + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { + __mesh_path_del(tbl, mpath); + } + spin_unlock_bh(&tbl->walk_lock); +} + +/** + * mesh_path_flush_by_iface - Deletes all mesh paths associated with a given iface + * + * This function deletes both mesh paths as well as mesh portal paths. 
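+ * Each table is walked under its walk_lock and every entry is removed via
+ * __mesh_path_del(), which also frees any frames still queued on the path.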
+ * + * @sdata: interface data to match + * + */ +void mesh_path_flush_by_iface(struct ieee80211_sub_if_data *sdata) +{ + table_flush_by_iface(&sdata->u.mesh.mesh_paths); + table_flush_by_iface(&sdata->u.mesh.mpp_paths); +} + +/** + * table_path_del - delete a path from the mesh or mpp table + * + * @tbl: mesh or mpp path table + * @sdata: local subif + * @addr: dst address (ETH_ALEN length) + * + * Returns: 0 if successful + */ +static int table_path_del(struct mesh_table *tbl, + struct ieee80211_sub_if_data *sdata, + const u8 *addr) +{ + struct mesh_path *mpath; + + spin_lock_bh(&tbl->walk_lock); + mpath = rhashtable_lookup_fast(&tbl->rhead, addr, mesh_rht_params); + if (!mpath) { + spin_unlock_bh(&tbl->walk_lock); + return -ENXIO; + } + + __mesh_path_del(tbl, mpath); + spin_unlock_bh(&tbl->walk_lock); + return 0; +} + + +/** + * mesh_path_del - delete a mesh path from the table + * + * @addr: dst address (ETH_ALEN length) + * @sdata: local subif + * + * Returns: 0 if successful + */ +int mesh_path_del(struct ieee80211_sub_if_data *sdata, const u8 *addr) +{ + int err; + + /* flush relevant mpp entries first */ + mpp_flush_by_proxy(sdata, addr); + + err = table_path_del(&sdata->u.mesh.mesh_paths, sdata, addr); + sdata->u.mesh.mesh_paths_generation++; + return err; +} + +/** + * mesh_path_tx_pending - sends pending frames in a mesh path queue + * + * @mpath: mesh path to activate + * + * Locking: the state_lock of the mpath structure must NOT be held when calling + * this function. + */ +void mesh_path_tx_pending(struct mesh_path *mpath) +{ + if (mpath->flags & MESH_PATH_ACTIVE) + ieee80211_add_pending_skbs(mpath->sdata->local, + &mpath->frame_queue); +} + +/** + * mesh_path_send_to_gates - sends pending frames to all known mesh gates + * + * @mpath: mesh path whose queue will be emptied + * + * If there is only one gate, the frames are transferred from the failed mpath + * queue to that gate's queue. If there are more than one gates, the frames + * are copied from each gate to the next. After frames are copied, the + * mpath queues are emptied onto the transmission queue. + */ +int mesh_path_send_to_gates(struct mesh_path *mpath) +{ + struct ieee80211_sub_if_data *sdata = mpath->sdata; + struct mesh_table *tbl; + struct mesh_path *from_mpath = mpath; + struct mesh_path *gate; + bool copy = false; + + tbl = &sdata->u.mesh.mesh_paths; + + rcu_read_lock(); + hlist_for_each_entry_rcu(gate, &tbl->known_gates, gate_list) { + if (gate->flags & MESH_PATH_ACTIVE) { + mpath_dbg(sdata, "Forwarding to %pM\n", gate->dst); + mesh_path_move_to_queue(gate, from_mpath, copy); + from_mpath = gate; + copy = true; + } else { + mpath_dbg(sdata, + "Not forwarding to %pM (flags %#x)\n", + gate->dst, gate->flags); + } + } + + hlist_for_each_entry_rcu(gate, &tbl->known_gates, gate_list) { + mpath_dbg(sdata, "Sending to %pM\n", gate->dst); + mesh_path_tx_pending(gate); + } + rcu_read_unlock(); + + return (from_mpath == mpath) ? 
-EHOSTUNREACH : 0; +} + +/** + * mesh_path_discard_frame - discard a frame whose path could not be resolved + * + * @skb: frame to discard + * @sdata: network subif the frame was to be sent through + * + * Locking: the function must be called within a rcu_read_lock region + */ +void mesh_path_discard_frame(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + ieee80211_free_txskb(&sdata->local->hw, skb); + sdata->u.mesh.mshstats.dropped_frames_no_route++; +} + +/** + * mesh_path_flush_pending - free the pending queue of a mesh path + * + * @mpath: mesh path whose queue has to be freed + * + * Locking: the function must be called within a rcu_read_lock region + */ +void mesh_path_flush_pending(struct mesh_path *mpath) +{ + struct sk_buff *skb; + + while ((skb = skb_dequeue(&mpath->frame_queue)) != NULL) + mesh_path_discard_frame(mpath->sdata, skb); +} + +/** + * mesh_path_fix_nexthop - force a specific next hop for a mesh path + * + * @mpath: the mesh path to modify + * @next_hop: the next hop to force + * + * Locking: this function must be called holding mpath->state_lock + */ +void mesh_path_fix_nexthop(struct mesh_path *mpath, struct sta_info *next_hop) +{ + spin_lock_bh(&mpath->state_lock); + mesh_path_assign_nexthop(mpath, next_hop); + mpath->sn = 0xffff; + mpath->metric = 0; + mpath->hop_count = 0; + mpath->exp_time = 0; + mpath->flags = MESH_PATH_FIXED | MESH_PATH_SN_VALID; + mesh_path_activate(mpath); + spin_unlock_bh(&mpath->state_lock); + ewma_mesh_fail_avg_init(&next_hop->mesh->fail_avg); + /* init it at a low value - 0 start is tricky */ + ewma_mesh_fail_avg_add(&next_hop->mesh->fail_avg, 1); + mesh_path_tx_pending(mpath); +} + +void mesh_pathtbl_init(struct ieee80211_sub_if_data *sdata) +{ + mesh_table_init(&sdata->u.mesh.mesh_paths); + mesh_table_init(&sdata->u.mesh.mpp_paths); +} + +static +void mesh_path_tbl_expire(struct ieee80211_sub_if_data *sdata, + struct mesh_table *tbl) +{ + struct mesh_path *mpath; + struct hlist_node *n; + + spin_lock_bh(&tbl->walk_lock); + hlist_for_each_entry_safe(mpath, n, &tbl->walk_head, walk_list) { + if ((!(mpath->flags & MESH_PATH_RESOLVING)) && + (!(mpath->flags & MESH_PATH_FIXED)) && + time_after(jiffies, mpath->exp_time + MESH_PATH_EXPIRE)) + __mesh_path_del(tbl, mpath); + } + spin_unlock_bh(&tbl->walk_lock); +} + +void mesh_path_expire(struct ieee80211_sub_if_data *sdata) +{ + mesh_path_tbl_expire(sdata, &sdata->u.mesh.mesh_paths); + mesh_path_tbl_expire(sdata, &sdata->u.mesh.mpp_paths); +} + +void mesh_pathtbl_unregister(struct ieee80211_sub_if_data *sdata) +{ + mesh_table_free(&sdata->u.mesh.mesh_paths); + mesh_table_free(&sdata->u.mesh.mpp_paths); +} diff --git a/net/mac80211/mesh_plink.c b/net/mac80211/mesh_plink.c new file mode 100644 index 000000000..711c3377f --- /dev/null +++ b/net/mac80211/mesh_plink.c @@ -0,0 +1,1237 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, 2009 open80211s Ltd.
+ * Copyright (C) 2019, 2021-2022 Intel Corporation + * Author: Luis Carlos Cobo + */ +#include <linux/gfp.h> +#include <linux/kernel.h> +#include <linux/random.h> +#include <linux/rculist.h> + +#include "ieee80211_i.h" +#include "rate.h" +#include "mesh.h" + +#define PLINK_CNF_AID(mgmt) ((mgmt)->u.action.u.self_prot.variable + 2) +#define PLINK_GET_LLID(p) (p + 2) +#define PLINK_GET_PLID(p) (p + 4) + +#define mod_plink_timer(s, t) (mod_timer(&s->mesh->plink_timer, \ + jiffies + msecs_to_jiffies(t))) + +enum plink_event { + PLINK_UNDEFINED, + OPN_ACPT, + OPN_RJCT, + OPN_IGNR, + CNF_ACPT, + CNF_RJCT, + CNF_IGNR, + CLS_ACPT, + CLS_IGNR +}; + +static const char * const mplstates[] = { + [NL80211_PLINK_LISTEN] = "LISTEN", + [NL80211_PLINK_OPN_SNT] = "OPN-SNT", + [NL80211_PLINK_OPN_RCVD] = "OPN-RCVD", + [NL80211_PLINK_CNF_RCVD] = "CNF_RCVD", + [NL80211_PLINK_ESTAB] = "ESTAB", + [NL80211_PLINK_HOLDING] = "HOLDING", + [NL80211_PLINK_BLOCKED] = "BLOCKED" +}; + +static const char * const mplevents[] = { + [PLINK_UNDEFINED] = "NONE", + [OPN_ACPT] = "OPN_ACPT", + [OPN_RJCT] = "OPN_RJCT", + [OPN_IGNR] = "OPN_IGNR", + [CNF_ACPT] = "CNF_ACPT", + [CNF_RJCT] = "CNF_RJCT", + [CNF_IGNR] = "CNF_IGNR", + [CLS_ACPT] = "CLS_ACPT", + [CLS_IGNR] = "CLS_IGNR" +}; + +/* We only need a valid sta if user configured a minimum rssi_threshold. */ +static bool rssi_threshold_check(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + s32 rssi_threshold = sdata->u.mesh.mshcfg.rssi_threshold; + return rssi_threshold == 0 || + (sta && + (s8)-ewma_signal_read(&sta->deflink.rx_stats_avg.signal) > + rssi_threshold); +} + +/** + * mesh_plink_fsm_restart - restart a mesh peer link finite state machine + * + * @sta: mesh peer link to restart + * + * Locking: this function must be called holding sta->mesh->plink_lock + */ +static inline void mesh_plink_fsm_restart(struct sta_info *sta) +{ + lockdep_assert_held(&sta->mesh->plink_lock); + sta->mesh->plink_state = NL80211_PLINK_LISTEN; + sta->mesh->llid = sta->mesh->plid = sta->mesh->reason = 0; + sta->mesh->plink_retries = 0; +} + +/* + * mesh_set_short_slot_time - enable / disable ERP short slot time. + * + * The standard indirectly mandates mesh STAs to turn off short slot time by + * disallowing advertising this (802.11-2012 8.4.1.4), but that doesn't mean we + * can't be sneaky about it. Enable short slot time if all mesh STAs in the + * MBSS support ERP rates. + * + * Returns BSS_CHANGED_ERP_SLOT or 0 for no change.
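+ * It is re-evaluated from the peering code whenever a peer link is
+ * established and when an established link is closed (see
+ * mesh_plink_establish() and the ESTAB case in mesh_plink_fsm()).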
+ */ +static u32 mesh_set_short_slot_time(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + struct sta_info *sta; + u32 erp_rates = 0, changed = 0; + int i; + bool short_slot = false; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return changed; + + if (sband->band == NL80211_BAND_5GHZ) { + /* (IEEE 802.11-2012 19.4.5) */ + short_slot = true; + goto out; + } else if (sband->band != NL80211_BAND_2GHZ) { + goto out; + } + + for (i = 0; i < sband->n_bitrates; i++) + if (sband->bitrates[i].flags & IEEE80211_RATE_ERP_G) + erp_rates |= BIT(i); + + if (!erp_rates) + goto out; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata || + sta->mesh->plink_state != NL80211_PLINK_ESTAB) + continue; + + short_slot = false; + if (erp_rates & sta->sta.deflink.supp_rates[sband->band]) + short_slot = true; + else + break; + } + rcu_read_unlock(); + +out: + if (sdata->vif.bss_conf.use_short_slot != short_slot) { + sdata->vif.bss_conf.use_short_slot = short_slot; + changed = BSS_CHANGED_ERP_SLOT; + mpl_dbg(sdata, "mesh_plink %pM: ERP short slot time %d\n", + sdata->vif.addr, short_slot); + } + return changed; +} + +/** + * mesh_set_ht_prot_mode - set correct HT protection mode + * @sdata: the (mesh) interface to handle + * + * Section 9.23.3.5 of IEEE 80211-2012 describes the protection rules for HT + * mesh STA in a MBSS. Three HT protection modes are supported for now, non-HT + * mixed mode, 20MHz-protection and no-protection mode. non-HT mixed mode is + * selected if any non-HT peers are present in our MBSS. 20MHz-protection mode + * is selected if all peers in our 20/40MHz MBSS support HT and at least one + * HT20 peer is present. Otherwise no-protection mode is selected. 
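+ *
+ * Returns BSS_CHANGED_HT if the operation mode changed and 0 otherwise;
+ * the selected mode is also mirrored into mshcfg.ht_opmode.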
+ */ +static u32 mesh_set_ht_prot_mode(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + u16 ht_opmode; + bool non_ht_sta = false, ht20_sta = false; + + switch (sdata->vif.bss_conf.chandef.width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + return 0; + default: + break; + } + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata || + sta->mesh->plink_state != NL80211_PLINK_ESTAB) + continue; + + if (sta->sta.deflink.bandwidth > IEEE80211_STA_RX_BW_20) + continue; + + if (!sta->sta.deflink.ht_cap.ht_supported) { + mpl_dbg(sdata, "nonHT sta (%pM) is present\n", + sta->sta.addr); + non_ht_sta = true; + break; + } + + mpl_dbg(sdata, "HT20 sta (%pM) is present\n", sta->sta.addr); + ht20_sta = true; + } + rcu_read_unlock(); + + if (non_ht_sta) + ht_opmode = IEEE80211_HT_OP_MODE_PROTECTION_NONHT_MIXED; + else if (ht20_sta && + sdata->vif.bss_conf.chandef.width > NL80211_CHAN_WIDTH_20) + ht_opmode = IEEE80211_HT_OP_MODE_PROTECTION_20MHZ; + else + ht_opmode = IEEE80211_HT_OP_MODE_PROTECTION_NONE; + + if (sdata->vif.bss_conf.ht_operation_mode == ht_opmode) + return 0; + + sdata->vif.bss_conf.ht_operation_mode = ht_opmode; + sdata->u.mesh.mshcfg.ht_opmode = ht_opmode; + mpl_dbg(sdata, "selected new HT protection mode %d\n", ht_opmode); + return BSS_CHANGED_HT; +} + +static int mesh_plink_frame_tx(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + enum ieee80211_self_protected_actioncode action, + u8 *da, u16 llid, u16 plid, u16 reason) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_tx_info *info; + struct ieee80211_mgmt *mgmt; + bool include_plid = false; + u16 peering_proto = 0; + u8 *pos, ie_len = 4; + u8 ie_len_he_cap; + int hdr_len = offsetofend(struct ieee80211_mgmt, u.action.u.self_prot); + int err = -ENOMEM; + + ie_len_he_cap = ieee80211_ie_len_he_cap(sdata, + NL80211_IFTYPE_MESH_POINT); + skb = dev_alloc_skb(local->tx_headroom + + hdr_len + + 2 + /* capability info */ + 2 + /* AID */ + 2 + 8 + /* supported rates */ + 2 + (IEEE80211_MAX_SUPP_RATES - 8) + + 2 + sdata->u.mesh.mesh_id_len + + 2 + sizeof(struct ieee80211_meshconf_ie) + + 2 + sizeof(struct ieee80211_ht_cap) + + 2 + sizeof(struct ieee80211_ht_operation) + + 2 + sizeof(struct ieee80211_vht_cap) + + 2 + sizeof(struct ieee80211_vht_operation) + + ie_len_he_cap + + 2 + 1 + sizeof(struct ieee80211_he_operation) + + sizeof(struct ieee80211_he_6ghz_oper) + + 2 + 1 + sizeof(struct ieee80211_he_6ghz_capa) + + 2 + 8 + /* peering IE */ + sdata->u.mesh.ie_len); + if (!skb) + return err; + info = IEEE80211_SKB_CB(skb); + skb_reserve(skb, local->tx_headroom); + mgmt = skb_put_zero(skb, hdr_len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + mgmt->u.action.category = WLAN_CATEGORY_SELF_PROTECTED; + mgmt->u.action.u.self_prot.action_code = action; + + if (action != WLAN_SP_MESH_PEERING_CLOSE) { + struct ieee80211_supported_band *sband; + enum nl80211_band band; + + sband = ieee80211_get_sband(sdata); + if (!sband) { + err = -EINVAL; + goto free; + } + band = sband->band; + + /* capability info */ + pos = skb_put_zero(skb, 2); + if (action == WLAN_SP_MESH_PEERING_CONFIRM) { + /* AID */ + pos = skb_put(skb, 2); + put_unaligned_le16(sta->sta.aid, pos); + } + if 
(ieee80211_add_srates_ie(sdata, skb, true, band) || + ieee80211_add_ext_srates_ie(sdata, skb, true, band) || + mesh_add_rsn_ie(sdata, skb) || + mesh_add_meshid_ie(sdata, skb) || + mesh_add_meshconf_ie(sdata, skb)) + goto free; + } else { /* WLAN_SP_MESH_PEERING_CLOSE */ + info->flags |= IEEE80211_TX_CTL_NO_ACK; + if (mesh_add_meshid_ie(sdata, skb)) + goto free; + } + + /* Add Mesh Peering Management element */ + switch (action) { + case WLAN_SP_MESH_PEERING_OPEN: + break; + case WLAN_SP_MESH_PEERING_CONFIRM: + ie_len += 2; + include_plid = true; + break; + case WLAN_SP_MESH_PEERING_CLOSE: + if (plid) { + ie_len += 2; + include_plid = true; + } + ie_len += 2; /* reason code */ + break; + default: + err = -EINVAL; + goto free; + } + + if (WARN_ON(skb_tailroom(skb) < 2 + ie_len)) + goto free; + + pos = skb_put(skb, 2 + ie_len); + *pos++ = WLAN_EID_PEER_MGMT; + *pos++ = ie_len; + memcpy(pos, &peering_proto, 2); + pos += 2; + put_unaligned_le16(llid, pos); + pos += 2; + if (include_plid) { + put_unaligned_le16(plid, pos); + pos += 2; + } + if (action == WLAN_SP_MESH_PEERING_CLOSE) { + put_unaligned_le16(reason, pos); + pos += 2; + } + + if (action != WLAN_SP_MESH_PEERING_CLOSE) { + if (mesh_add_ht_cap_ie(sdata, skb) || + mesh_add_ht_oper_ie(sdata, skb) || + mesh_add_vht_cap_ie(sdata, skb) || + mesh_add_vht_oper_ie(sdata, skb) || + mesh_add_he_cap_ie(sdata, skb, ie_len_he_cap) || + mesh_add_he_oper_ie(sdata, skb) || + mesh_add_he_6ghz_cap_ie(sdata, skb)) + goto free; + } + + if (mesh_add_vendor_ies(sdata, skb)) + goto free; + + ieee80211_tx_skb(sdata, skb); + return 0; +free: + kfree_skb(skb); + return err; +} + +/** + * __mesh_plink_deactivate - deactivate mesh peer link + * + * @sta: mesh peer link to deactivate + * + * Mesh paths with this peer as next hop should be flushed + * by the caller outside of plink_lock. + * + * Returns beacon changed flag if the beacon content changed. 
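+ *
+ * The peer is moved to NL80211_PLINK_BLOCKED and its local power save
+ * state is reset to NL80211_MESH_POWER_UNKNOWN.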
+ * + * Locking: the caller must hold sta->mesh->plink_lock + */ +static u32 __mesh_plink_deactivate(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 changed = 0; + + lockdep_assert_held(&sta->mesh->plink_lock); + + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + changed = mesh_plink_dec_estab_count(sdata); + sta->mesh->plink_state = NL80211_PLINK_BLOCKED; + + ieee80211_mps_sta_status_update(sta); + changed |= ieee80211_mps_set_sta_local_pm(sta, + NL80211_MESH_POWER_UNKNOWN); + + return changed; +} + +/** + * mesh_plink_deactivate - deactivate mesh peer link + * + * @sta: mesh peer link to deactivate + * + * All mesh paths with this peer as next hop will be flushed + */ +u32 mesh_plink_deactivate(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 changed; + + spin_lock_bh(&sta->mesh->plink_lock); + changed = __mesh_plink_deactivate(sta); + + if (!sdata->u.mesh.user_mpm) { + sta->mesh->reason = WLAN_REASON_MESH_PEER_CANCELED; + mesh_plink_frame_tx(sdata, sta, WLAN_SP_MESH_PEERING_CLOSE, + sta->sta.addr, sta->mesh->llid, + sta->mesh->plid, sta->mesh->reason); + } + spin_unlock_bh(&sta->mesh->plink_lock); + if (!sdata->u.mesh.user_mpm) + del_timer_sync(&sta->mesh->plink_timer); + mesh_path_flush_by_nexthop(sta); + + /* make sure no readers can access nexthop sta from here on */ + synchronize_net(); + + return changed; +} + +static void mesh_sta_info_init(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee802_11_elems *elems) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + u32 rates, basic_rates = 0, changed = 0; + enum ieee80211_sta_rx_bandwidth bw = sta->sta.deflink.bandwidth; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return; + + rates = ieee80211_sta_get_rates(sdata, elems, sband->band, + &basic_rates); + + spin_lock_bh(&sta->mesh->plink_lock); + sta->deflink.rx_stats.last_rx = jiffies; + + /* rates and capabilities don't change during peering */ + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB && + sta->mesh->processed_beacon) + goto out; + sta->mesh->processed_beacon = true; + + if (sta->sta.deflink.supp_rates[sband->band] != rates) + changed |= IEEE80211_RC_SUPP_RATES_CHANGED; + sta->sta.deflink.supp_rates[sband->band] = rates; + + if (ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband, + elems->ht_cap_elem, + &sta->deflink)) + changed |= IEEE80211_RC_BW_CHANGED; + + ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband, + elems->vht_cap_elem, NULL, + &sta->deflink); + + ieee80211_he_cap_ie_to_sta_he_cap(sdata, sband, elems->he_cap, + elems->he_cap_len, + elems->he_6ghz_capa, + &sta->deflink); + + if (bw != sta->sta.deflink.bandwidth) + changed |= IEEE80211_RC_BW_CHANGED; + + /* HT peer is operating 20MHz-only */ + if (elems->ht_operation && + !(elems->ht_operation->ht_param & + IEEE80211_HT_PARAM_CHAN_WIDTH_ANY)) { + if (sta->sta.deflink.bandwidth != IEEE80211_STA_RX_BW_20) + changed |= IEEE80211_RC_BW_CHANGED; + sta->sta.deflink.bandwidth = IEEE80211_STA_RX_BW_20; + } + + if (!test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) + rate_control_rate_init(sta); + else + rate_control_rate_update(local, sband, sta, 0, changed); +out: + spin_unlock_bh(&sta->mesh->plink_lock); +} + +static int mesh_allocate_aid(struct ieee80211_sub_if_data *sdata) +{ + struct sta_info *sta; + unsigned long *aid_map; + int aid; + + aid_map = bitmap_zalloc(IEEE80211_MAX_AID + 1, GFP_KERNEL); + if (!aid_map) + return -ENOMEM; + + /* reserve aid 0 for mcast indication */ + 
__set_bit(0, aid_map); + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) + __set_bit(sta->sta.aid, aid_map); + rcu_read_unlock(); + + aid = find_first_zero_bit(aid_map, IEEE80211_MAX_AID + 1); + bitmap_free(aid_map); + + if (aid > IEEE80211_MAX_AID) + return -ENOBUFS; + + return aid; +} + +static struct sta_info * +__mesh_sta_info_alloc(struct ieee80211_sub_if_data *sdata, u8 *hw_addr) +{ + struct sta_info *sta; + int aid; + + if (sdata->local->num_sta >= MESH_MAX_PLINKS) + return NULL; + + aid = mesh_allocate_aid(sdata); + if (aid < 0) + return NULL; + + sta = sta_info_alloc(sdata, hw_addr, GFP_KERNEL); + if (!sta) + return NULL; + + sta->mesh->plink_state = NL80211_PLINK_LISTEN; + sta->sta.wme = true; + sta->sta.aid = aid; + + sta_info_pre_move_state(sta, IEEE80211_STA_AUTH); + sta_info_pre_move_state(sta, IEEE80211_STA_ASSOC); + sta_info_pre_move_state(sta, IEEE80211_STA_AUTHORIZED); + + return sta; +} + +static struct sta_info * +mesh_sta_info_alloc(struct ieee80211_sub_if_data *sdata, u8 *addr, + struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status) +{ + struct sta_info *sta = NULL; + + /* Userspace handles station allocation */ + if (sdata->u.mesh.user_mpm || + sdata->u.mesh.security & IEEE80211_MESH_SEC_AUTHED) { + if (mesh_peer_accepts_plinks(elems) && + mesh_plink_availables(sdata)) { + int sig = 0; + + if (ieee80211_hw_check(&sdata->local->hw, SIGNAL_DBM)) + sig = rx_status->signal; + + cfg80211_notify_new_peer_candidate(sdata->dev, addr, + elems->ie_start, + elems->total_len, + sig, GFP_KERNEL); + } + } else + sta = __mesh_sta_info_alloc(sdata, addr); + + return sta; +} + +/* + * mesh_sta_info_get - return mesh sta info entry for @addr. + * + * @sdata: local meshif + * @addr: peer's address + * @elems: IEs from beacon or mesh peering frame. + * @rx_status: rx status for the frame for signal reporting + * + * Return existing or newly allocated sta_info under RCU read lock. + * (re)initialize with given IEs. + */ +static struct sta_info * +mesh_sta_info_get(struct ieee80211_sub_if_data *sdata, + u8 *addr, struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status) __acquires(RCU) +{ + struct sta_info *sta = NULL; + + rcu_read_lock(); + sta = sta_info_get(sdata, addr); + if (sta) { + mesh_sta_info_init(sdata, sta, elems); + } else { + rcu_read_unlock(); + /* can't run atomic */ + sta = mesh_sta_info_alloc(sdata, addr, elems, rx_status); + if (!sta) { + rcu_read_lock(); + return NULL; + } + + mesh_sta_info_init(sdata, sta, elems); + + if (sta_info_insert_rcu(sta)) + return NULL; + } + + return sta; +} + +/* + * mesh_neighbour_update - update or initialize new mesh neighbor. + * + * @sdata: local meshif + * @addr: peer's address + * @elems: IEs from beacon or mesh peering frame + * @rx_status: rx status for the frame for signal reporting + * + * Initiates peering if appropriate. 
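+ *
+ * Peering is only started when the link is still in the LISTEN state, the
+ * peer accepts plinks, the local MBSS is accepting plinks, auto_open_plinks
+ * is enabled and the peer passes the rssi_threshold check.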
+ */ +void mesh_neighbour_update(struct ieee80211_sub_if_data *sdata, + u8 *hw_addr, + struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status) +{ + struct sta_info *sta; + u32 changed = 0; + + sta = mesh_sta_info_get(sdata, hw_addr, elems, rx_status); + if (!sta) + goto out; + + sta->mesh->connected_to_gate = elems->mesh_config->meshconf_form & + IEEE80211_MESHCONF_FORM_CONNECTED_TO_GATE; + + if (mesh_peer_accepts_plinks(elems) && + sta->mesh->plink_state == NL80211_PLINK_LISTEN && + sdata->u.mesh.accepting_plinks && + sdata->u.mesh.mshcfg.auto_open_plinks && + rssi_threshold_check(sdata, sta)) + changed = mesh_plink_open(sta); + + ieee80211_mps_frame_release(sta, elems); +out: + rcu_read_unlock(); + ieee80211_mbss_info_change_notify(sdata, changed); +} + +void mesh_plink_timer(struct timer_list *t) +{ + struct mesh_sta *mesh = from_timer(mesh, t, plink_timer); + struct sta_info *sta; + u16 reason = 0; + struct ieee80211_sub_if_data *sdata; + struct mesh_config *mshcfg; + enum ieee80211_self_protected_actioncode action = 0; + + /* + * This STA is valid because sta_info_destroy() will + * del_timer_sync() this timer after having made sure + * it cannot be readded (by deleting the plink.) + */ + sta = mesh->plink_sta; + + if (sta->sdata->local->quiescing) + return; + + spin_lock_bh(&sta->mesh->plink_lock); + + /* If a timer fires just before a state transition on another CPU, + * we may have already extended the timeout and changed state by the + * time we've acquired the lock and arrived here. In that case, + * skip this timer and wait for the new one. + */ + if (time_before(jiffies, sta->mesh->plink_timer.expires)) { + mpl_dbg(sta->sdata, + "Ignoring timer for %pM in state %s (timer adjusted)", + sta->sta.addr, mplstates[sta->mesh->plink_state]); + spin_unlock_bh(&sta->mesh->plink_lock); + return; + } + + /* del_timer() and handler may race when entering these states */ + if (sta->mesh->plink_state == NL80211_PLINK_LISTEN || + sta->mesh->plink_state == NL80211_PLINK_ESTAB) { + mpl_dbg(sta->sdata, + "Ignoring timer for %pM in state %s (timer deleted)", + sta->sta.addr, mplstates[sta->mesh->plink_state]); + spin_unlock_bh(&sta->mesh->plink_lock); + return; + } + + mpl_dbg(sta->sdata, + "Mesh plink timer for %pM fired on state %s\n", + sta->sta.addr, mplstates[sta->mesh->plink_state]); + sdata = sta->sdata; + mshcfg = &sdata->u.mesh.mshcfg; + + switch (sta->mesh->plink_state) { + case NL80211_PLINK_OPN_RCVD: + case NL80211_PLINK_OPN_SNT: + /* retry timer */ + if (sta->mesh->plink_retries < mshcfg->dot11MeshMaxRetries) { + u32 rand; + mpl_dbg(sta->sdata, + "Mesh plink for %pM (retry, timeout): %d %d\n", + sta->sta.addr, sta->mesh->plink_retries, + sta->mesh->plink_timeout); + get_random_bytes(&rand, sizeof(u32)); + sta->mesh->plink_timeout = sta->mesh->plink_timeout + + rand % sta->mesh->plink_timeout; + ++sta->mesh->plink_retries; + mod_plink_timer(sta, sta->mesh->plink_timeout); + action = WLAN_SP_MESH_PEERING_OPEN; + break; + } + reason = WLAN_REASON_MESH_MAX_RETRIES; + fallthrough; + case NL80211_PLINK_CNF_RCVD: + /* confirm timer */ + if (!reason) + reason = WLAN_REASON_MESH_CONFIRM_TIMEOUT; + sta->mesh->plink_state = NL80211_PLINK_HOLDING; + mod_plink_timer(sta, mshcfg->dot11MeshHoldingTimeout); + action = WLAN_SP_MESH_PEERING_CLOSE; + break; + case NL80211_PLINK_HOLDING: + /* holding timer */ + del_timer(&sta->mesh->plink_timer); + mesh_plink_fsm_restart(sta); + break; + default: + break; + } + spin_unlock_bh(&sta->mesh->plink_lock); + if (action) + 
mesh_plink_frame_tx(sdata, sta, action, sta->sta.addr, + sta->mesh->llid, sta->mesh->plid, reason); +} + +static inline void mesh_plink_timer_set(struct sta_info *sta, u32 timeout) +{ + sta->mesh->plink_timeout = timeout; + mod_timer(&sta->mesh->plink_timer, jiffies + msecs_to_jiffies(timeout)); +} + +static bool llid_in_use(struct ieee80211_sub_if_data *sdata, + u16 llid) +{ + struct ieee80211_local *local = sdata->local; + bool in_use = false; + struct sta_info *sta; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata) + continue; + + if (!memcmp(&sta->mesh->llid, &llid, sizeof(llid))) { + in_use = true; + break; + } + } + rcu_read_unlock(); + + return in_use; +} + +static u16 mesh_get_new_llid(struct ieee80211_sub_if_data *sdata) +{ + u16 llid; + + do { + get_random_bytes(&llid, sizeof(llid)); + } while (llid_in_use(sdata, llid)); + + return llid; +} + +u32 mesh_plink_open(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u32 changed; + + if (!test_sta_flag(sta, WLAN_STA_AUTH)) + return 0; + + spin_lock_bh(&sta->mesh->plink_lock); + sta->mesh->llid = mesh_get_new_llid(sdata); + if (sta->mesh->plink_state != NL80211_PLINK_LISTEN && + sta->mesh->plink_state != NL80211_PLINK_BLOCKED) { + spin_unlock_bh(&sta->mesh->plink_lock); + return 0; + } + sta->mesh->plink_state = NL80211_PLINK_OPN_SNT; + mesh_plink_timer_set(sta, sdata->u.mesh.mshcfg.dot11MeshRetryTimeout); + spin_unlock_bh(&sta->mesh->plink_lock); + mpl_dbg(sdata, + "Mesh plink: starting establishment with %pM\n", + sta->sta.addr); + + /* set the non-peer mode to active during peering */ + changed = ieee80211_mps_local_status_update(sdata); + + mesh_plink_frame_tx(sdata, sta, WLAN_SP_MESH_PEERING_OPEN, + sta->sta.addr, sta->mesh->llid, 0, 0); + return changed; +} + +u32 mesh_plink_block(struct sta_info *sta) +{ + u32 changed; + + spin_lock_bh(&sta->mesh->plink_lock); + changed = __mesh_plink_deactivate(sta); + sta->mesh->plink_state = NL80211_PLINK_BLOCKED; + spin_unlock_bh(&sta->mesh->plink_lock); + mesh_path_flush_by_nexthop(sta); + + return changed; +} + +static void mesh_plink_close(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + enum plink_event event) +{ + struct mesh_config *mshcfg = &sdata->u.mesh.mshcfg; + u16 reason = (event == CLS_ACPT) ? 
+ WLAN_REASON_MESH_CLOSE : WLAN_REASON_MESH_CONFIG; + + sta->mesh->reason = reason; + sta->mesh->plink_state = NL80211_PLINK_HOLDING; + mod_plink_timer(sta, mshcfg->dot11MeshHoldingTimeout); +} + +static u32 mesh_plink_establish(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + struct mesh_config *mshcfg = &sdata->u.mesh.mshcfg; + u32 changed = 0; + + del_timer(&sta->mesh->plink_timer); + sta->mesh->plink_state = NL80211_PLINK_ESTAB; + changed |= mesh_plink_inc_estab_count(sdata); + changed |= mesh_set_ht_prot_mode(sdata); + changed |= mesh_set_short_slot_time(sdata); + mpl_dbg(sdata, "Mesh plink with %pM ESTABLISHED\n", sta->sta.addr); + ieee80211_mps_sta_status_update(sta); + changed |= ieee80211_mps_set_sta_local_pm(sta, mshcfg->power_mode); + return changed; +} + +/** + * mesh_plink_fsm - step @sta MPM based on @event + * + * @sdata: interface + * @sta: mesh neighbor + * @event: peering event + * + * Return: changed MBSS flags + */ +static u32 mesh_plink_fsm(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, enum plink_event event) +{ + struct mesh_config *mshcfg = &sdata->u.mesh.mshcfg; + enum ieee80211_self_protected_actioncode action = 0; + u32 changed = 0; + bool flush = false; + + mpl_dbg(sdata, "peer %pM in state %s got event %s\n", sta->sta.addr, + mplstates[sta->mesh->plink_state], mplevents[event]); + + spin_lock_bh(&sta->mesh->plink_lock); + switch (sta->mesh->plink_state) { + case NL80211_PLINK_LISTEN: + switch (event) { + case CLS_ACPT: + mesh_plink_fsm_restart(sta); + break; + case OPN_ACPT: + sta->mesh->plink_state = NL80211_PLINK_OPN_RCVD; + sta->mesh->llid = mesh_get_new_llid(sdata); + mesh_plink_timer_set(sta, + mshcfg->dot11MeshRetryTimeout); + + /* set the non-peer mode to active during peering */ + changed |= ieee80211_mps_local_status_update(sdata); + action = WLAN_SP_MESH_PEERING_OPEN; + break; + default: + break; + } + break; + case NL80211_PLINK_OPN_SNT: + switch (event) { + case OPN_RJCT: + case CNF_RJCT: + case CLS_ACPT: + mesh_plink_close(sdata, sta, event); + action = WLAN_SP_MESH_PEERING_CLOSE; + break; + case OPN_ACPT: + /* retry timer is left untouched */ + sta->mesh->plink_state = NL80211_PLINK_OPN_RCVD; + action = WLAN_SP_MESH_PEERING_CONFIRM; + break; + case CNF_ACPT: + sta->mesh->plink_state = NL80211_PLINK_CNF_RCVD; + mod_plink_timer(sta, mshcfg->dot11MeshConfirmTimeout); + break; + default: + break; + } + break; + case NL80211_PLINK_OPN_RCVD: + switch (event) { + case OPN_RJCT: + case CNF_RJCT: + case CLS_ACPT: + mesh_plink_close(sdata, sta, event); + action = WLAN_SP_MESH_PEERING_CLOSE; + break; + case OPN_ACPT: + action = WLAN_SP_MESH_PEERING_CONFIRM; + break; + case CNF_ACPT: + changed |= mesh_plink_establish(sdata, sta); + break; + default: + break; + } + break; + case NL80211_PLINK_CNF_RCVD: + switch (event) { + case OPN_RJCT: + case CNF_RJCT: + case CLS_ACPT: + mesh_plink_close(sdata, sta, event); + action = WLAN_SP_MESH_PEERING_CLOSE; + break; + case OPN_ACPT: + changed |= mesh_plink_establish(sdata, sta); + action = WLAN_SP_MESH_PEERING_CONFIRM; + break; + default: + break; + } + break; + case NL80211_PLINK_ESTAB: + switch (event) { + case CLS_ACPT: + changed |= __mesh_plink_deactivate(sta); + changed |= mesh_set_ht_prot_mode(sdata); + changed |= mesh_set_short_slot_time(sdata); + mesh_plink_close(sdata, sta, event); + action = WLAN_SP_MESH_PEERING_CLOSE; + flush = true; + break; + case OPN_ACPT: + action = WLAN_SP_MESH_PEERING_CONFIRM; + break; + default: + break; + } + break; + case NL80211_PLINK_HOLDING: + switch 
(event) { + case CLS_ACPT: + del_timer(&sta->mesh->plink_timer); + mesh_plink_fsm_restart(sta); + break; + case OPN_ACPT: + case CNF_ACPT: + case OPN_RJCT: + case CNF_RJCT: + action = WLAN_SP_MESH_PEERING_CLOSE; + break; + default: + break; + } + break; + default: + /* should not get here, PLINK_BLOCKED is dealt with at the + * beginning of the function + */ + break; + } + spin_unlock_bh(&sta->mesh->plink_lock); + if (flush) + mesh_path_flush_by_nexthop(sta); + if (action) { + mesh_plink_frame_tx(sdata, sta, action, sta->sta.addr, + sta->mesh->llid, sta->mesh->plid, + sta->mesh->reason); + + /* also send confirm in open case */ + if (action == WLAN_SP_MESH_PEERING_OPEN) { + mesh_plink_frame_tx(sdata, sta, + WLAN_SP_MESH_PEERING_CONFIRM, + sta->sta.addr, sta->mesh->llid, + sta->mesh->plid, 0); + } + } + + return changed; +} + +/* + * mesh_plink_get_event - get correct MPM event + * + * @sdata: interface + * @sta: peer, leave NULL if processing a frame from a new suitable peer + * @elems: peering management IEs + * @ftype: frame type + * @llid: peer's peer link ID + * @plid: peer's local link ID + * + * Return: new peering event for @sta, but PLINK_UNDEFINED should be treated as + * an error. + */ +static enum plink_event +mesh_plink_get_event(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee802_11_elems *elems, + enum ieee80211_self_protected_actioncode ftype, + u16 llid, u16 plid) +{ + enum plink_event event = PLINK_UNDEFINED; + u8 ie_len = elems->peering_len; + bool matches_local; + + matches_local = (ftype == WLAN_SP_MESH_PEERING_CLOSE || + mesh_matches_local(sdata, elems)); + + /* deny open request from non-matching peer */ + if (!matches_local && !sta) { + event = OPN_RJCT; + goto out; + } + + if (!sta) { + if (ftype != WLAN_SP_MESH_PEERING_OPEN) { + mpl_dbg(sdata, "Mesh plink: cls or cnf from unknown peer\n"); + goto out; + } + /* ftype == WLAN_SP_MESH_PEERING_OPEN */ + if (!mesh_plink_free_count(sdata)) { + mpl_dbg(sdata, "Mesh plink error: no more free plinks\n"); + goto out; + } + + /* new matching peer */ + event = OPN_ACPT; + goto out; + } else { + if (!test_sta_flag(sta, WLAN_STA_AUTH)) { + mpl_dbg(sdata, "Mesh plink: Action frame from non-authed peer\n"); + goto out; + } + if (sta->mesh->plink_state == NL80211_PLINK_BLOCKED) + goto out; + } + + switch (ftype) { + case WLAN_SP_MESH_PEERING_OPEN: + if (!matches_local) + event = OPN_RJCT; + else if (!mesh_plink_free_count(sdata) || + (sta->mesh->plid && sta->mesh->plid != plid)) + event = OPN_IGNR; + else + event = OPN_ACPT; + break; + case WLAN_SP_MESH_PEERING_CONFIRM: + if (!matches_local) + event = CNF_RJCT; + else if (!mesh_plink_free_count(sdata) || + sta->mesh->llid != llid || + (sta->mesh->plid && sta->mesh->plid != plid)) + event = CNF_IGNR; + else + event = CNF_ACPT; + break; + case WLAN_SP_MESH_PEERING_CLOSE: + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + /* Do not check for llid or plid. This does not + * follow the standard but since multiple plinks + * per sta are not supported, it is necessary in + * order to avoid a livelock when MP A sees an + * establish peer link to MP B but MP B does not + * see it. This can be caused by a timeout in + * B's peer link establishment or B beign + * restarted. 
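/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * mesh_plink_fsm() above spreads the MPM state machine over nested
 * switch statements.  The table below restates a few of the common
 * transitions in one place as a reading aid; the struct and array names
 * are invented for the sketch and it covers only a subset of the cases
 * handled by the real code.
 */
struct example_plink_transition {
	const char *state;	/* current peer link state */
	const char *event;	/* classified peering event */
	const char *next;	/* resulting state */
	const char *tx;		/* frame sent in response, if any */
};

static const struct example_plink_transition example_plink_table[] = {
	{ "LISTEN",   "OPN_ACPT", "OPN_RCVD", "OPEN + CONFIRM" },
	{ "OPN_SNT",  "CNF_ACPT", "CNF_RCVD", NULL },
	{ "OPN_RCVD", "CNF_ACPT", "ESTAB",    NULL },
	{ "CNF_RCVD", "OPN_ACPT", "ESTAB",    "CONFIRM" },
	{ "ESTAB",    "CLS_ACPT", "HOLDING",  "CLOSE" },
	{ "HOLDING",  "CLS_ACPT", "LISTEN",   NULL },
};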
+ */ + event = CLS_ACPT; + else if (sta->mesh->plid != plid) + event = CLS_IGNR; + else if (ie_len == 8 && sta->mesh->llid != llid) + event = CLS_IGNR; + else + event = CLS_ACPT; + break; + default: + mpl_dbg(sdata, "Mesh plink: unknown frame subtype\n"); + break; + } + +out: + return event; +} + +static void +mesh_process_plink_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status) +{ + + struct sta_info *sta; + enum plink_event event; + enum ieee80211_self_protected_actioncode ftype; + u32 changed = 0; + u8 ie_len = elems->peering_len; + u16 plid, llid = 0; + + if (!elems->peering) { + mpl_dbg(sdata, + "Mesh plink: missing necessary peer link ie\n"); + return; + } + + if (elems->rsn_len && + sdata->u.mesh.security == IEEE80211_MESH_SEC_NONE) { + mpl_dbg(sdata, + "Mesh plink: can't establish link with secure peer\n"); + return; + } + + ftype = mgmt->u.action.u.self_prot.action_code; + if ((ftype == WLAN_SP_MESH_PEERING_OPEN && ie_len != 4) || + (ftype == WLAN_SP_MESH_PEERING_CONFIRM && ie_len != 6) || + (ftype == WLAN_SP_MESH_PEERING_CLOSE && ie_len != 6 + && ie_len != 8)) { + mpl_dbg(sdata, + "Mesh plink: incorrect plink ie length %d %d\n", + ftype, ie_len); + return; + } + + if (ftype != WLAN_SP_MESH_PEERING_CLOSE && + (!elems->mesh_id || !elems->mesh_config)) { + mpl_dbg(sdata, "Mesh plink: missing necessary ie\n"); + return; + } + /* Note the lines below are correct, the llid in the frame is the plid + * from the point of view of this host. + */ + plid = get_unaligned_le16(PLINK_GET_LLID(elems->peering)); + if (ftype == WLAN_SP_MESH_PEERING_CONFIRM || + (ftype == WLAN_SP_MESH_PEERING_CLOSE && ie_len == 8)) + llid = get_unaligned_le16(PLINK_GET_PLID(elems->peering)); + + /* WARNING: Only for sta pointer, is dropped & re-acquired */ + rcu_read_lock(); + + sta = sta_info_get(sdata, mgmt->sa); + + if (ftype == WLAN_SP_MESH_PEERING_OPEN && + !rssi_threshold_check(sdata, sta)) { + mpl_dbg(sdata, "Mesh plink: %pM does not meet rssi threshold\n", + mgmt->sa); + goto unlock_rcu; + } + + /* Now we will figure out the appropriate event... 
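/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * The comment in mesh_process_plink_frame() below ("the llid in the
 * frame is the plid from the point of view of this host") is easy to
 * misread: the two link IDs swap roles between sender and receiver.
 * The helper spells that out; the struct and names are invented for the
 * sketch and do not describe the on-air element layout parsed by the
 * PLINK_GET_* macros.
 */
struct example_peering_ids {
	unsigned short sender_local_id;	/* "LLID" as written by the peer */
	unsigned short sender_peer_id;	/* "PLID" as written by the peer */
};

static void example_swap_perspective(const struct example_peering_ids *rx,
				     unsigned short *our_plid,
				     unsigned short *our_llid)
{
	/* the peer's local ID identifies the peer: it becomes our plid */
	*our_plid = rx->sender_local_id;
	/* the peer's idea of *our* ID should match the llid we generated */
	*our_llid = rx->sender_peer_id;
}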
*/ + event = mesh_plink_get_event(sdata, sta, elems, ftype, llid, plid); + + if (event == OPN_ACPT) { + rcu_read_unlock(); + /* allocate sta entry if necessary and update info */ + sta = mesh_sta_info_get(sdata, mgmt->sa, elems, rx_status); + if (!sta) { + mpl_dbg(sdata, "Mesh plink: failed to init peer!\n"); + goto unlock_rcu; + } + sta->mesh->plid = plid; + } else if (!sta && event == OPN_RJCT) { + mesh_plink_frame_tx(sdata, NULL, WLAN_SP_MESH_PEERING_CLOSE, + mgmt->sa, 0, plid, + WLAN_REASON_MESH_CONFIG); + goto unlock_rcu; + } else if (!sta || event == PLINK_UNDEFINED) { + /* something went wrong */ + goto unlock_rcu; + } + + if (event == CNF_ACPT) { + /* 802.11-2012 13.3.7.2 - update plid on CNF if not set */ + if (!sta->mesh->plid) + sta->mesh->plid = plid; + + sta->mesh->aid = get_unaligned_le16(PLINK_CNF_AID(mgmt)); + } + + changed |= mesh_plink_fsm(sdata, sta, event); + +unlock_rcu: + rcu_read_unlock(); + + if (changed) + ieee80211_mbss_info_change_notify(sdata, changed); +} + +void mesh_rx_plink_frame(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status) +{ + struct ieee802_11_elems *elems; + size_t baselen; + u8 *baseaddr; + + /* need action_code, aux */ + if (len < IEEE80211_MIN_ACTION_SIZE + 3) + return; + + if (sdata->u.mesh.user_mpm) + /* userspace must register for these */ + return; + + if (is_multicast_ether_addr(mgmt->da)) { + mpl_dbg(sdata, + "Mesh plink: ignore frame from multicast address\n"); + return; + } + + baseaddr = mgmt->u.action.u.self_prot.variable; + baselen = (u8 *) mgmt->u.action.u.self_prot.variable - (u8 *) mgmt; + if (mgmt->u.action.u.self_prot.action_code == + WLAN_SP_MESH_PEERING_CONFIRM) { + baseaddr += 4; + baselen += 4; + + if (baselen > len) + return; + } + elems = ieee802_11_parse_elems(baseaddr, len - baselen, true, NULL); + if (elems) { + mesh_process_plink_frame(sdata, mgmt, elems, rx_status); + kfree(elems); + } +} diff --git a/net/mac80211/mesh_ps.c b/net/mac80211/mesh_ps.c new file mode 100644 index 000000000..3fbd0b9ff --- /dev/null +++ b/net/mac80211/mesh_ps.c @@ -0,0 +1,607 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2012-2013, Marco Porsch + * Copyright 2012-2013, cozybit Inc. 
+ * Copyright (C) 2021 Intel Corporation + */ + +#include "mesh.h" +#include "wme.h" + + +/* mesh PS management */ + +/** + * mps_qos_null_get - create pre-addressed QoS Null frame for mesh powersave + * @sta: the station to get the frame for + */ +static struct sk_buff *mps_qos_null_get(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_hdr *nullfunc; /* use 4addr header */ + struct sk_buff *skb; + int size = sizeof(*nullfunc); + __le16 fc; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + size + 2); + if (!skb) + return NULL; + skb_reserve(skb, local->hw.extra_tx_headroom); + + nullfunc = skb_put(skb, size); + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | IEEE80211_STYPE_QOS_NULLFUNC); + ieee80211_fill_mesh_addresses(nullfunc, &fc, sta->sta.addr, + sdata->vif.addr); + nullfunc->frame_control = fc; + nullfunc->duration_id = 0; + nullfunc->seq_ctrl = 0; + /* no address resolution for this frame -> set addr 1 immediately */ + memcpy(nullfunc->addr1, sta->sta.addr, ETH_ALEN); + skb_put_zero(skb, 2); /* append QoS control field */ + ieee80211_mps_set_frame_flags(sdata, sta, nullfunc); + + return skb; +} + +/** + * mps_qos_null_tx - send a QoS Null to indicate link-specific power mode + * @sta: the station to send to + */ +static void mps_qos_null_tx(struct sta_info *sta) +{ + struct sk_buff *skb; + + skb = mps_qos_null_get(sta); + if (!skb) + return; + + mps_dbg(sta->sdata, "announcing peer-specific power mode to %pM\n", + sta->sta.addr); + + /* don't unintentionally start a MPSP */ + if (!test_sta_flag(sta, WLAN_STA_PS_STA)) { + u8 *qc = ieee80211_get_qos_ctl((void *) skb->data); + + qc[0] |= IEEE80211_QOS_CTL_EOSP; + } + + ieee80211_tx_skb(sta->sdata, skb); +} + +/** + * ieee80211_mps_local_status_update - track status of local link-specific PMs + * + * @sdata: local mesh subif + * + * sets the non-peer power mode and triggers the driver PS (re-)configuration + * Return BSS_CHANGED_BEACON if a beacon update is necessary. + */ +u32 ieee80211_mps_local_status_update(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct sta_info *sta; + bool peering = false; + int light_sleep_cnt = 0; + int deep_sleep_cnt = 0; + u32 changed = 0; + enum nl80211_mesh_power_mode nonpeer_pm; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) { + if (sdata != sta->sdata) + continue; + + switch (sta->mesh->plink_state) { + case NL80211_PLINK_OPN_SNT: + case NL80211_PLINK_OPN_RCVD: + case NL80211_PLINK_CNF_RCVD: + peering = true; + break; + case NL80211_PLINK_ESTAB: + if (sta->mesh->local_pm == NL80211_MESH_POWER_LIGHT_SLEEP) + light_sleep_cnt++; + else if (sta->mesh->local_pm == NL80211_MESH_POWER_DEEP_SLEEP) + deep_sleep_cnt++; + break; + default: + break; + } + } + rcu_read_unlock(); + + /* + * Set non-peer mode to active during peering/scanning/authentication + * (see IEEE802.11-2012 13.14.8.3). The non-peer mesh power mode is + * deep sleep if the local STA is in light or deep sleep towards at + * least one mesh peer (see 13.14.3.1). Otherwise, set it to the + * user-configured default value. 
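/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * The non-peer power-mode rule described in the comment above (and
 * implemented just below) reduces to a pure function of the peering and
 * sleep counters.  The enum and function names are invented for the
 * sketch; the kernel code uses the nl80211 NL80211_MESH_POWER_* values.
 */
enum example_mesh_pm { EXAMPLE_PM_ACTIVE, EXAMPLE_PM_LIGHT, EXAMPLE_PM_DEEP };

static enum example_mesh_pm
example_nonpeer_pm(int peering_cnt, int light_sleep_cnt, int deep_sleep_cnt,
		   enum example_mesh_pm user_default)
{
	if (peering_cnt)			/* stay awake while peering */
		return EXAMPLE_PM_ACTIVE;
	if (light_sleep_cnt || deep_sleep_cnt)	/* dozing towards some peer */
		return EXAMPLE_PM_DEEP;
	return user_default;			/* configured default otherwise */
}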
+ */ + if (peering) { + mps_dbg(sdata, "setting non-peer PM to active for peering\n"); + nonpeer_pm = NL80211_MESH_POWER_ACTIVE; + } else if (light_sleep_cnt || deep_sleep_cnt) { + mps_dbg(sdata, "setting non-peer PM to deep sleep\n"); + nonpeer_pm = NL80211_MESH_POWER_DEEP_SLEEP; + } else { + mps_dbg(sdata, "setting non-peer PM to user value\n"); + nonpeer_pm = ifmsh->mshcfg.power_mode; + } + + /* need update if sleep counts move between 0 and non-zero */ + if (ifmsh->nonpeer_pm != nonpeer_pm || + !ifmsh->ps_peers_light_sleep != !light_sleep_cnt || + !ifmsh->ps_peers_deep_sleep != !deep_sleep_cnt) + changed = BSS_CHANGED_BEACON; + + ifmsh->nonpeer_pm = nonpeer_pm; + ifmsh->ps_peers_light_sleep = light_sleep_cnt; + ifmsh->ps_peers_deep_sleep = deep_sleep_cnt; + + return changed; +} + +/** + * ieee80211_mps_set_sta_local_pm - set local PM towards a mesh STA + * + * @sta: mesh STA + * @pm: the power mode to set + * Return BSS_CHANGED_BEACON if a beacon update is in order. + */ +u32 ieee80211_mps_set_sta_local_pm(struct sta_info *sta, + enum nl80211_mesh_power_mode pm) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + + if (sta->mesh->local_pm == pm) + return 0; + + mps_dbg(sdata, "local STA operates in mode %d with %pM\n", + pm, sta->sta.addr); + + sta->mesh->local_pm = pm; + + /* + * announce peer-specific power mode transition + * (see IEEE802.11-2012 13.14.3.2 and 13.14.3.3) + */ + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + mps_qos_null_tx(sta); + + return ieee80211_mps_local_status_update(sdata); +} + +/** + * ieee80211_mps_set_frame_flags - set mesh PS flags in FC (and QoS Control) + * + * @sdata: local mesh subif + * @sta: mesh STA + * @hdr: 802.11 frame header + * + * see IEEE802.11-2012 8.2.4.1.7 and 8.2.4.5.11 + * + * NOTE: sta must be given when an individually-addressed QoS frame header + * is handled, for group-addressed and management frames it is not used + */ +void ieee80211_mps_set_frame_flags(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_hdr *hdr) +{ + enum nl80211_mesh_power_mode pm; + u8 *qc; + + if (WARN_ON(is_unicast_ether_addr(hdr->addr1) && + ieee80211_is_data_qos(hdr->frame_control) && + !sta)) + return; + + if (is_unicast_ether_addr(hdr->addr1) && + ieee80211_is_data_qos(hdr->frame_control) && + sta->mesh->plink_state == NL80211_PLINK_ESTAB) + pm = sta->mesh->local_pm; + else + pm = sdata->u.mesh.nonpeer_pm; + + if (pm == NL80211_MESH_POWER_ACTIVE) + hdr->frame_control &= cpu_to_le16(~IEEE80211_FCTL_PM); + else + hdr->frame_control |= cpu_to_le16(IEEE80211_FCTL_PM); + + if (!ieee80211_is_data_qos(hdr->frame_control)) + return; + + qc = ieee80211_get_qos_ctl(hdr); + + if ((is_unicast_ether_addr(hdr->addr1) && + pm == NL80211_MESH_POWER_DEEP_SLEEP) || + (is_multicast_ether_addr(hdr->addr1) && + sdata->u.mesh.ps_peers_deep_sleep > 0)) + qc[1] |= (IEEE80211_QOS_CTL_MESH_PS_LEVEL >> 8); + else + qc[1] &= ~(IEEE80211_QOS_CTL_MESH_PS_LEVEL >> 8); +} + +/** + * ieee80211_mps_sta_status_update - update buffering status of neighbor STA + * + * @sta: mesh STA + * + * called after change of peering status or non-peer/peer-specific power mode + */ +void ieee80211_mps_sta_status_update(struct sta_info *sta) +{ + enum nl80211_mesh_power_mode pm; + bool do_buffer; + + /* For non-assoc STA, prevent buffering or frame transmission */ + if (sta->sta_state < IEEE80211_STA_ASSOC) + return; + + /* + * use peer-specific power mode if peering is established and the + * peer's power mode is known + */ + if (sta->mesh->plink_state == 
NL80211_PLINK_ESTAB && + sta->mesh->peer_pm != NL80211_MESH_POWER_UNKNOWN) + pm = sta->mesh->peer_pm; + else + pm = sta->mesh->nonpeer_pm; + + do_buffer = (pm != NL80211_MESH_POWER_ACTIVE); + + /* clear the MPSP flags for non-peers or active STA */ + if (sta->mesh->plink_state != NL80211_PLINK_ESTAB) { + clear_sta_flag(sta, WLAN_STA_MPSP_OWNER); + clear_sta_flag(sta, WLAN_STA_MPSP_RECIPIENT); + } else if (!do_buffer) { + clear_sta_flag(sta, WLAN_STA_MPSP_OWNER); + } + + /* Don't let the same PS state be set twice */ + if (test_sta_flag(sta, WLAN_STA_PS_STA) == do_buffer) + return; + + if (do_buffer) { + set_sta_flag(sta, WLAN_STA_PS_STA); + atomic_inc(&sta->sdata->u.mesh.ps.num_sta_ps); + mps_dbg(sta->sdata, "start PS buffering frames towards %pM\n", + sta->sta.addr); + } else { + ieee80211_sta_ps_deliver_wakeup(sta); + } +} + +static void mps_set_sta_peer_pm(struct sta_info *sta, + struct ieee80211_hdr *hdr) +{ + enum nl80211_mesh_power_mode pm; + u8 *qc = ieee80211_get_qos_ctl(hdr); + + /* + * Test Power Management field of frame control (PW) and + * mesh power save level subfield of QoS control field (PSL) + * + * | PM | PSL| Mesh PM | + * +----+----+---------+ + * | 0 |Rsrv| Active | + * | 1 | 0 | Light | + * | 1 | 1 | Deep | + */ + if (ieee80211_has_pm(hdr->frame_control)) { + if (qc[1] & (IEEE80211_QOS_CTL_MESH_PS_LEVEL >> 8)) + pm = NL80211_MESH_POWER_DEEP_SLEEP; + else + pm = NL80211_MESH_POWER_LIGHT_SLEEP; + } else { + pm = NL80211_MESH_POWER_ACTIVE; + } + + if (sta->mesh->peer_pm == pm) + return; + + mps_dbg(sta->sdata, "STA %pM enters mode %d\n", + sta->sta.addr, pm); + + sta->mesh->peer_pm = pm; + + ieee80211_mps_sta_status_update(sta); +} + +static void mps_set_sta_nonpeer_pm(struct sta_info *sta, + struct ieee80211_hdr *hdr) +{ + enum nl80211_mesh_power_mode pm; + + if (ieee80211_has_pm(hdr->frame_control)) + pm = NL80211_MESH_POWER_DEEP_SLEEP; + else + pm = NL80211_MESH_POWER_ACTIVE; + + if (sta->mesh->nonpeer_pm == pm) + return; + + mps_dbg(sta->sdata, "STA %pM sets non-peer mode to %d\n", + sta->sta.addr, pm); + + sta->mesh->nonpeer_pm = pm; + + ieee80211_mps_sta_status_update(sta); +} + +/** + * ieee80211_mps_rx_h_sta_process - frame receive handler for mesh powersave + * + * @sta: STA info that transmitted the frame + * @hdr: IEEE 802.11 (QoS) Header + */ +void ieee80211_mps_rx_h_sta_process(struct sta_info *sta, + struct ieee80211_hdr *hdr) +{ + if (is_unicast_ether_addr(hdr->addr1) && + ieee80211_is_data_qos(hdr->frame_control)) { + /* + * individually addressed QoS Data/Null frames contain + * peer link-specific PS mode towards the local STA + */ + mps_set_sta_peer_pm(sta, hdr); + + /* check for mesh Peer Service Period trigger frames */ + ieee80211_mpsp_trigger_process(ieee80211_get_qos_ctl(hdr), + sta, false, false); + } else { + /* + * can only determine non-peer PS mode + * (see IEEE802.11-2012 8.2.4.1.7) + */ + mps_set_sta_nonpeer_pm(sta, hdr); + } +} + + +/* mesh PS frame release */ + +static void mpsp_trigger_send(struct sta_info *sta, bool rspi, bool eosp) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct sk_buff *skb; + struct ieee80211_hdr *nullfunc; + struct ieee80211_tx_info *info; + u8 *qc; + + skb = mps_qos_null_get(sta); + if (!skb) + return; + + nullfunc = (struct ieee80211_hdr *) skb->data; + if (!eosp) + nullfunc->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + /* + * | RSPI | EOSP | MPSP triggering | + * +------+------+--------------------+ + * | 0 | 0 | local STA is owner | + * | 0 | 1 | no MPSP (MPSP end) | + * | 1 | 0 | 
both STA are owner | + * | 1 | 1 | peer STA is owner | see IEEE802.11-2012 13.14.9.2 + */ + qc = ieee80211_get_qos_ctl(nullfunc); + if (rspi) + qc[1] |= (IEEE80211_QOS_CTL_RSPI >> 8); + if (eosp) + qc[0] |= IEEE80211_QOS_CTL_EOSP; + + info = IEEE80211_SKB_CB(skb); + + info->flags |= IEEE80211_TX_CTL_NO_PS_BUFFER | + IEEE80211_TX_CTL_REQ_TX_STATUS; + + mps_dbg(sdata, "sending MPSP trigger%s%s to %pM\n", + rspi ? " RSPI" : "", eosp ? " EOSP" : "", sta->sta.addr); + + ieee80211_tx_skb(sdata, skb); +} + +/** + * mpsp_qos_null_append - append QoS Null frame to MPSP skb queue if needed + * @sta: the station to handle + * @frames: the frame list to append to + * + * To properly end a mesh MPSP the last transmitted frame has to set the EOSP + * flag in the QoS Control field. In case the current tailing frame is not a + * QoS Data frame, append a QoS Null to carry the flag. + */ +static void mpsp_qos_null_append(struct sta_info *sta, + struct sk_buff_head *frames) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct sk_buff *new_skb, *skb = skb_peek_tail(frames); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_tx_info *info; + + if (ieee80211_is_data_qos(hdr->frame_control)) + return; + + new_skb = mps_qos_null_get(sta); + if (!new_skb) + return; + + mps_dbg(sdata, "appending QoS Null in MPSP towards %pM\n", + sta->sta.addr); + /* + * This frame has to be transmitted last. Assign lowest priority to + * make sure it cannot pass other frames when releasing multiple ACs. + */ + new_skb->priority = 1; + skb_set_queue_mapping(new_skb, IEEE80211_AC_BK); + ieee80211_set_qos_hdr(sdata, new_skb); + + info = IEEE80211_SKB_CB(new_skb); + info->control.vif = &sdata->vif; + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + + __skb_queue_tail(frames, new_skb); +} + +/** + * mps_frame_deliver - transmit frames during mesh powersave + * + * @sta: STA info to transmit to + * @n_frames: number of frames to transmit. -1 for all + */ +static void mps_frame_deliver(struct sta_info *sta, int n_frames) +{ + struct ieee80211_local *local = sta->sdata->local; + int ac; + struct sk_buff_head frames; + struct sk_buff *skb; + bool more_data = false; + + skb_queue_head_init(&frames); + + /* collect frame(s) from buffers */ + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + while (n_frames != 0) { + skb = skb_dequeue(&sta->tx_filtered[ac]); + if (!skb) { + skb = skb_dequeue( + &sta->ps_tx_buf[ac]); + if (skb) + local->total_ps_buffered--; + } + if (!skb) + break; + n_frames--; + __skb_queue_tail(&frames, skb); + } + + if (!skb_queue_empty(&sta->tx_filtered[ac]) || + !skb_queue_empty(&sta->ps_tx_buf[ac])) + more_data = true; + } + + /* nothing to send? -> EOSP */ + if (skb_queue_empty(&frames)) { + mpsp_trigger_send(sta, false, true); + return; + } + + /* in a MPSP make sure the last skb is a QoS Data frame */ + if (test_sta_flag(sta, WLAN_STA_MPSP_OWNER)) + mpsp_qos_null_append(sta, &frames); + + mps_dbg(sta->sdata, "sending %d frames to PS STA %pM\n", + skb_queue_len(&frames), sta->sta.addr); + + /* prepare collected frames for transmission */ + skb_queue_walk(&frames, skb) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (void *) skb->data; + + /* + * Tell TX path to send this frame even though the + * STA may still remain is PS mode after this frame + * exchange. 
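/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * mps_set_sta_peer_pm() above decodes the peer's link-specific power
 * mode from two bits: the PM flag in the frame control field and the
 * mesh PS level bit in the QoS control field (see the PM/PSL table in
 * that function).  A minimal restatement, with names invented for the
 * sketch:
 */
#include <stdbool.h>

enum example_peer_pm {
	EXAMPLE_PEER_ACTIVE,
	EXAMPLE_PEER_LIGHT,
	EXAMPLE_PEER_DEEP,
};

static enum example_peer_pm example_decode_peer_pm(bool pm_bit, bool ps_level_bit)
{
	if (!pm_bit)
		return EXAMPLE_PEER_ACTIVE;	/* PS level is reserved here */
	return ps_level_bit ? EXAMPLE_PEER_DEEP : EXAMPLE_PEER_LIGHT;
}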
+ */ + info->flags |= IEEE80211_TX_CTL_NO_PS_BUFFER; + + if (more_data || !skb_queue_is_last(&frames, skb)) + hdr->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + else + hdr->frame_control &= + cpu_to_le16(~IEEE80211_FCTL_MOREDATA); + + if (skb_queue_is_last(&frames, skb) && + ieee80211_is_data_qos(hdr->frame_control)) { + u8 *qoshdr = ieee80211_get_qos_ctl(hdr); + + /* MPSP trigger frame ends service period */ + *qoshdr |= IEEE80211_QOS_CTL_EOSP; + info->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; + } + } + + ieee80211_add_pending_skbs(local, &frames); + sta_info_recalc_tim(sta); +} + +/** + * ieee80211_mpsp_trigger_process - track status of mesh Peer Service Periods + * + * @qc: QoS Control field + * @sta: peer to start a MPSP with + * @tx: frame was transmitted by the local STA + * @acked: frame has been transmitted successfully + * + * NOTE: active mode STA may only serve as MPSP owner + */ +void ieee80211_mpsp_trigger_process(u8 *qc, struct sta_info *sta, + bool tx, bool acked) +{ + u8 rspi = qc[1] & (IEEE80211_QOS_CTL_RSPI >> 8); + u8 eosp = qc[0] & IEEE80211_QOS_CTL_EOSP; + + if (tx) { + if (rspi && acked) + set_sta_flag(sta, WLAN_STA_MPSP_RECIPIENT); + + if (eosp) + clear_sta_flag(sta, WLAN_STA_MPSP_OWNER); + else if (acked && + test_sta_flag(sta, WLAN_STA_PS_STA) && + !test_and_set_sta_flag(sta, WLAN_STA_MPSP_OWNER)) + mps_frame_deliver(sta, -1); + } else { + if (eosp) + clear_sta_flag(sta, WLAN_STA_MPSP_RECIPIENT); + else if (sta->mesh->local_pm != NL80211_MESH_POWER_ACTIVE) + set_sta_flag(sta, WLAN_STA_MPSP_RECIPIENT); + + if (rspi && !test_and_set_sta_flag(sta, WLAN_STA_MPSP_OWNER)) + mps_frame_deliver(sta, -1); + } +} + +/** + * ieee80211_mps_frame_release - release frames buffered due to mesh power save + * + * @sta: mesh STA + * @elems: IEs of beacon or probe response + * + * For peers if we have individually-addressed frames buffered or the peer + * indicates buffered frames, send a corresponding MPSP trigger frame. Since + * we do not evaluate the awake window duration, QoS Nulls are used as MPSP + * trigger frames. If the neighbour STA is not a peer, only send single frames. + */ +void ieee80211_mps_frame_release(struct sta_info *sta, + struct ieee802_11_elems *elems) +{ + int ac, buffer_local = 0; + bool has_buffered = false; + + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + has_buffered = ieee80211_check_tim(elems->tim, elems->tim_len, + sta->mesh->aid); + + if (has_buffered) + mps_dbg(sta->sdata, "%pM indicates buffered frames\n", + sta->sta.addr); + + /* only transmit to PS STA with announced, non-zero awake window */ + if (test_sta_flag(sta, WLAN_STA_PS_STA) && + (!elems->awake_window || !get_unaligned_le16(elems->awake_window))) + return; + + if (!test_sta_flag(sta, WLAN_STA_MPSP_OWNER)) + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + buffer_local += skb_queue_len(&sta->ps_tx_buf[ac]) + + skb_queue_len(&sta->tx_filtered[ac]); + + if (!has_buffered && !buffer_local) + return; + + if (sta->mesh->plink_state == NL80211_PLINK_ESTAB) + mpsp_trigger_send(sta, has_buffered, !buffer_local); + else + mps_frame_deliver(sta, 1); +} diff --git a/net/mac80211/mesh_sync.c b/net/mac80211/mesh_sync.c new file mode 100644 index 000000000..9e342cc25 --- /dev/null +++ b/net/mac80211/mesh_sync.c @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2011-2012, Pavel Zubarev + * Copyright 2011-2012, Marco Porsch + * Copyright 2011-2012, cozybit Inc. 
+ * Copyright (C) 2021 Intel Corporation + */ + +#include "ieee80211_i.h" +#include "mesh.h" +#include "driver-ops.h" + +/* This is not in the standard. It represents a tolerable tsf drift below + * which we do no TSF adjustment. + */ +#define TOFFSET_MINIMUM_ADJUSTMENT 10 + +/* This is not in the standard. It is a margin added to the + * Toffset setpoint to mitigate TSF overcorrection + * introduced by TSF adjustment latency. + */ +#define TOFFSET_SET_MARGIN 20 + +/* This is not in the standard. It represents the maximum Toffset jump above + * which we'll invalidate the Toffset setpoint and choose a new setpoint. This + * could be, for instance, in case a neighbor is restarted and its TSF counter + * reset. + */ +#define TOFFSET_MAXIMUM_ADJUSTMENT 800 /* 0.8 ms */ + +struct sync_method { + u8 method; + struct ieee80211_mesh_sync_ops ops; +}; + +/** + * mesh_peer_tbtt_adjusting - check if an mp is currently adjusting its TBTT + * + * @cfg: mesh config element from the mesh peer (or %NULL) + */ +static bool mesh_peer_tbtt_adjusting(const struct ieee80211_meshconf_ie *cfg) +{ + return cfg && + (cfg->meshconf_cap & IEEE80211_MESHCONF_CAPAB_TBTT_ADJUSTING); +} + +void mesh_sync_adjust_tsf(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + /* sdata->vif.bss_conf.beacon_int in 1024us units, 0.04% */ + u64 beacon_int_fraction = sdata->vif.bss_conf.beacon_int * 1024 / 2500; + u64 tsf; + u64 tsfdelta; + + spin_lock_bh(&ifmsh->sync_offset_lock); + if (ifmsh->sync_offset_clockdrift_max < beacon_int_fraction) { + msync_dbg(sdata, "TSF : max clockdrift=%lld; adjusting\n", + (long long) ifmsh->sync_offset_clockdrift_max); + tsfdelta = -ifmsh->sync_offset_clockdrift_max; + ifmsh->sync_offset_clockdrift_max = 0; + } else { + msync_dbg(sdata, "TSF : max clockdrift=%lld; adjusting by %llu\n", + (long long) ifmsh->sync_offset_clockdrift_max, + (unsigned long long) beacon_int_fraction); + tsfdelta = -beacon_int_fraction; + ifmsh->sync_offset_clockdrift_max -= beacon_int_fraction; + } + spin_unlock_bh(&ifmsh->sync_offset_lock); + + if (local->ops->offset_tsf) { + drv_offset_tsf(local, sdata, tsfdelta); + } else { + tsf = drv_get_tsf(local, sdata); + if (tsf != -1ULL) + drv_set_tsf(local, sdata, tsf + tsfdelta); + } +} + +static void +mesh_sync_offset_rx_bcn_presp(struct ieee80211_sub_if_data *sdata, u16 stype, + struct ieee80211_mgmt *mgmt, unsigned int len, + const struct ieee80211_meshconf_ie *mesh_cfg, + struct ieee80211_rx_status *rx_status) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + u64 t_t, t_r; + + WARN_ON(ifmsh->mesh_sp_id != IEEE80211_SYNC_METHOD_NEIGHBOR_OFFSET); + + /* standard mentions only beacons */ + if (stype != IEEE80211_STYPE_BEACON) + return; + + /* + * Get time when timestamp field was received. If we don't + * have rx timestamps, then use current tsf as an approximation. + * drv_get_tsf() must be called before entering the rcu-read + * section. 
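/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * mesh_sync_adjust_tsf() above never slews the local TSF by more than a
 * small fraction of the beacon interval per pass (beacon_int * 1024 /
 * 2500 microseconds, the "0.04%" mentioned in its comment); any
 * remaining drift is carried over to the next adjustment.  A standalone
 * restatement of that clamping, with names invented for the sketch:
 */
static long long example_tsf_step(long long *drift_us, int beacon_int_tu)
{
	long long max_step = (long long)beacon_int_tu * 1024 / 2500;
	long long step = (*drift_us < max_step) ? *drift_us : max_step;

	*drift_us -= step;	/* leftover drift handled on the next pass */
	return -step;		/* delta applied to the local TSF */
}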
+ */ + if (ieee80211_have_rx_timestamp(rx_status)) + t_r = ieee80211_calculate_rx_timestamp(local, rx_status, + len + FCS_LEN, 24); + else + t_r = drv_get_tsf(local, sdata); + + rcu_read_lock(); + sta = sta_info_get(sdata, mgmt->sa); + if (!sta) + goto no_sync; + + /* check offset sync conditions (13.13.2.2.1) + * + * TODO also sync to + * dot11MeshNbrOffsetMaxNeighbor non-peer non-MBSS neighbors + */ + + if (mesh_peer_tbtt_adjusting(mesh_cfg)) { + msync_dbg(sdata, "STA %pM : is adjusting TBTT\n", + sta->sta.addr); + goto no_sync; + } + + /* Timing offset calculation (see 13.13.2.2.2) */ + t_t = le64_to_cpu(mgmt->u.beacon.timestamp); + sta->mesh->t_offset = t_t - t_r; + + if (test_sta_flag(sta, WLAN_STA_TOFFSET_KNOWN)) { + s64 t_clockdrift = sta->mesh->t_offset_setpoint - sta->mesh->t_offset; + msync_dbg(sdata, + "STA %pM : t_offset=%lld, t_offset_setpoint=%lld, t_clockdrift=%lld\n", + sta->sta.addr, (long long) sta->mesh->t_offset, + (long long) sta->mesh->t_offset_setpoint, + (long long) t_clockdrift); + + if (t_clockdrift > TOFFSET_MAXIMUM_ADJUSTMENT || + t_clockdrift < -TOFFSET_MAXIMUM_ADJUSTMENT) { + msync_dbg(sdata, + "STA %pM : t_clockdrift=%lld too large, setpoint reset\n", + sta->sta.addr, + (long long) t_clockdrift); + clear_sta_flag(sta, WLAN_STA_TOFFSET_KNOWN); + goto no_sync; + } + + spin_lock_bh(&ifmsh->sync_offset_lock); + if (t_clockdrift > ifmsh->sync_offset_clockdrift_max) + ifmsh->sync_offset_clockdrift_max = t_clockdrift; + spin_unlock_bh(&ifmsh->sync_offset_lock); + } else { + sta->mesh->t_offset_setpoint = sta->mesh->t_offset - TOFFSET_SET_MARGIN; + set_sta_flag(sta, WLAN_STA_TOFFSET_KNOWN); + msync_dbg(sdata, + "STA %pM : offset was invalid, t_offset=%lld\n", + sta->sta.addr, + (long long) sta->mesh->t_offset); + } + +no_sync: + rcu_read_unlock(); +} + +static void mesh_sync_offset_adjust_tsf(struct ieee80211_sub_if_data *sdata, + struct beacon_data *beacon) +{ + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + WARN_ON(ifmsh->mesh_sp_id != IEEE80211_SYNC_METHOD_NEIGHBOR_OFFSET); + WARN_ON(!rcu_read_lock_held()); + + spin_lock_bh(&ifmsh->sync_offset_lock); + + if (ifmsh->sync_offset_clockdrift_max > TOFFSET_MINIMUM_ADJUSTMENT) { + /* Since ajusting the tsf here would + * require a possibly blocking call + * to the driver tsf setter, we punt + * the tsf adjustment to the mesh tasklet + */ + msync_dbg(sdata, + "TSF : kicking off TSF adjustment with clockdrift_max=%lld\n", + ifmsh->sync_offset_clockdrift_max); + set_bit(MESH_WORK_DRIFT_ADJUST, &ifmsh->wrkq_flags); + } else { + msync_dbg(sdata, + "TSF : max clockdrift=%lld; too small to adjust\n", + (long long)ifmsh->sync_offset_clockdrift_max); + ifmsh->sync_offset_clockdrift_max = 0; + } + spin_unlock_bh(&ifmsh->sync_offset_lock); +} + +static const struct sync_method sync_methods[] = { + { + .method = IEEE80211_SYNC_METHOD_NEIGHBOR_OFFSET, + .ops = { + .rx_bcn_presp = &mesh_sync_offset_rx_bcn_presp, + .adjust_tsf = &mesh_sync_offset_adjust_tsf, + } + }, +}; + +const struct ieee80211_mesh_sync_ops *ieee80211_mesh_sync_ops_get(u8 method) +{ + int i; + + for (i = 0 ; i < ARRAY_SIZE(sync_methods); ++i) { + if (sync_methods[i].method == method) + return &sync_methods[i].ops; + } + return NULL; +} diff --git a/net/mac80211/michael.c b/net/mac80211/michael.c new file mode 100644 index 000000000..a57502d9f --- /dev/null +++ b/net/mac80211/michael.c @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Michael MIC implementation - optimized for TKIP MIC operations + * Copyright 2002-2003, Instant802 Networks, 
Inc. + */ +#include +#include +#include +#include + +#include "michael.h" + +static void michael_block(struct michael_mic_ctx *mctx, u32 val) +{ + mctx->l ^= val; + mctx->r ^= rol32(mctx->l, 17); + mctx->l += mctx->r; + mctx->r ^= ((mctx->l & 0xff00ff00) >> 8) | + ((mctx->l & 0x00ff00ff) << 8); + mctx->l += mctx->r; + mctx->r ^= rol32(mctx->l, 3); + mctx->l += mctx->r; + mctx->r ^= ror32(mctx->l, 2); + mctx->l += mctx->r; +} + +static void michael_mic_hdr(struct michael_mic_ctx *mctx, const u8 *key, + struct ieee80211_hdr *hdr) +{ + u8 *da, *sa, tid; + + da = ieee80211_get_DA(hdr); + sa = ieee80211_get_SA(hdr); + if (ieee80211_is_data_qos(hdr->frame_control)) + tid = ieee80211_get_tid(hdr); + else + tid = 0; + + mctx->l = get_unaligned_le32(key); + mctx->r = get_unaligned_le32(key + 4); + + /* + * A pseudo header (DA, SA, Priority, 0, 0, 0) is used in Michael MIC + * calculation, but it is _not_ transmitted + */ + michael_block(mctx, get_unaligned_le32(da)); + michael_block(mctx, get_unaligned_le16(&da[4]) | + (get_unaligned_le16(sa) << 16)); + michael_block(mctx, get_unaligned_le32(&sa[2])); + michael_block(mctx, tid); +} + +void michael_mic(const u8 *key, struct ieee80211_hdr *hdr, + const u8 *data, size_t data_len, u8 *mic) +{ + u32 val; + size_t block, blocks, left; + struct michael_mic_ctx mctx; + + michael_mic_hdr(&mctx, key, hdr); + + /* Real data */ + blocks = data_len / 4; + left = data_len % 4; + + for (block = 0; block < blocks; block++) + michael_block(&mctx, get_unaligned_le32(&data[block * 4])); + + /* Partial block of 0..3 bytes and padding: 0x5a + 4..7 zeros to make + * total length a multiple of 4. */ + val = 0x5a; + while (left > 0) { + val <<= 8; + left--; + val |= data[blocks * 4 + left]; + } + + michael_block(&mctx, val); + michael_block(&mctx, 0); + + put_unaligned_le32(mctx.l, mic); + put_unaligned_le32(mctx.r, mic + 4); +} diff --git a/net/mac80211/michael.h b/net/mac80211/michael.h new file mode 100644 index 000000000..a7fdb8e84 --- /dev/null +++ b/net/mac80211/michael.h @@ -0,0 +1,22 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Michael MIC implementation - optimized for TKIP MIC operations + * Copyright 2002-2003, Instant802 Networks, Inc. + */ + +#ifndef MICHAEL_H +#define MICHAEL_H + +#include +#include + +#define MICHAEL_MIC_LEN 8 + +struct michael_mic_ctx { + u32 l, r; +}; + +void michael_mic(const u8 *key, struct ieee80211_hdr *hdr, + const u8 *data, size_t data_len, u8 *mic); + +#endif /* MICHAEL_H */ diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c new file mode 100644 index 000000000..c07645c99 --- /dev/null +++ b/net/mac80211/mlme.c @@ -0,0 +1,7444 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * BSS client mode implementation + * Copyright 2003-2008, Jouni Malinen + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2015 - 2017 Intel Deutschland GmbH + * Copyright (C) 2018 - 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "led.h" +#include "fils_aead.h" + +#define IEEE80211_AUTH_TIMEOUT (HZ / 5) +#define IEEE80211_AUTH_TIMEOUT_LONG (HZ / 2) +#define IEEE80211_AUTH_TIMEOUT_SHORT (HZ / 10) +#define IEEE80211_AUTH_TIMEOUT_SAE (HZ * 2) +#define IEEE80211_AUTH_MAX_TRIES 3 +#define IEEE80211_AUTH_WAIT_ASSOC (HZ * 5) +#define IEEE80211_AUTH_WAIT_SAE_RETRY (HZ * 2) +#define IEEE80211_ASSOC_TIMEOUT (HZ / 5) +#define IEEE80211_ASSOC_TIMEOUT_LONG (HZ / 2) +#define IEEE80211_ASSOC_TIMEOUT_SHORT (HZ / 10) +#define IEEE80211_ASSOC_MAX_TRIES 3 + +static int max_nullfunc_tries = 2; +module_param(max_nullfunc_tries, int, 0644); +MODULE_PARM_DESC(max_nullfunc_tries, + "Maximum nullfunc tx tries before disconnecting (reason 4)."); + +static int max_probe_tries = 5; +module_param(max_probe_tries, int, 0644); +MODULE_PARM_DESC(max_probe_tries, + "Maximum probe tries before disconnecting (reason 4)."); + +/* + * Beacon loss timeout is calculated as N frames times the + * advertised beacon interval. This may need to be somewhat + * higher than what hardware might detect to account for + * delays in the host processing frames. But since we also + * probe on beacon miss before declaring the connection lost + * default to what we want. + */ +static int beacon_loss_count = 7; +module_param(beacon_loss_count, int, 0644); +MODULE_PARM_DESC(beacon_loss_count, + "Number of beacon intervals before we decide beacon was lost."); + +/* + * Time the connection can be idle before we probe + * it to see if we can still talk to the AP. + */ +#define IEEE80211_CONNECTION_IDLE_TIME (30 * HZ) +/* + * Time we wait for a probe response after sending + * a probe request because of beacon loss or for + * checking the connection still works. + */ +static int probe_wait_ms = 500; +module_param(probe_wait_ms, int, 0644); +MODULE_PARM_DESC(probe_wait_ms, + "Maximum time(ms) to wait for probe response" + " before disconnecting (reason 4)."); + +/* + * How many Beacon frames need to have been used in average signal strength + * before starting to indicate signal change events. + */ +#define IEEE80211_SIGNAL_AVE_MIN_COUNT 4 + +/* + * We can have multiple work items (and connection probing) + * scheduling this timer, but we need to take care to only + * reschedule it when it should fire _earlier_ than it was + * asked for before, or if it's not pending right now. This + * function ensures that. Note that it then is required to + * run this function for all timeouts after the first one + * has happened -- the work that runs from this timer will + * do that. 
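/*
 * Editor's note -- illustrative sketch, not part of this patch.
 * "Beacon loss timeout is calculated as N frames times the advertised
 * beacon interval": beacon intervals are advertised in time units
 * (1 TU = 1024 us), so with the default beacon_loss_count of 7 and a
 * typical 100 TU interval the timeout is 7 * 100 * 1024 us, roughly
 * 717 ms.  The helper below only performs that conversion; its name is
 * invented for the sketch and its rounding differs from the
 * jiffies-based kernel code.
 */
static unsigned int example_beacon_loss_timeout_ms(int loss_count,
						   unsigned int beacon_int_tu)
{
	/* TU -> microseconds -> milliseconds, rounded up */
	unsigned long long us =
		(unsigned long long)loss_count * beacon_int_tu * 1024;

	return (unsigned int)((us + 999) / 1000);
}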
+ */ +static void run_again(struct ieee80211_sub_if_data *sdata, + unsigned long timeout) +{ + sdata_assert_lock(sdata); + + if (!timer_pending(&sdata->u.mgd.timer) || + time_before(timeout, sdata->u.mgd.timer.expires)) + mod_timer(&sdata->u.mgd.timer, timeout); +} + +void ieee80211_sta_reset_beacon_monitor(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.driver_flags & IEEE80211_VIF_BEACON_FILTER) + return; + + if (ieee80211_hw_check(&sdata->local->hw, CONNECTION_MONITOR)) + return; + + mod_timer(&sdata->u.mgd.bcn_mon_timer, + round_jiffies_up(jiffies + sdata->u.mgd.beacon_timeout)); +} + +void ieee80211_sta_reset_conn_monitor(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + if (unlikely(!ifmgd->associated)) + return; + + if (ifmgd->probe_send_count) + ifmgd->probe_send_count = 0; + + if (ieee80211_hw_check(&sdata->local->hw, CONNECTION_MONITOR)) + return; + + mod_timer(&ifmgd->conn_mon_timer, + round_jiffies_up(jiffies + IEEE80211_CONNECTION_IDLE_TIME)); +} + +static int ecw2cw(int ecw) +{ + return (1 << ecw) - 1; +} + +static ieee80211_conn_flags_t +ieee80211_determine_chantype(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + ieee80211_conn_flags_t conn_flags, + struct ieee80211_supported_band *sband, + struct ieee80211_channel *channel, + u32 vht_cap_info, + const struct ieee80211_ht_operation *ht_oper, + const struct ieee80211_vht_operation *vht_oper, + const struct ieee80211_he_operation *he_oper, + const struct ieee80211_eht_operation *eht_oper, + const struct ieee80211_s1g_oper_ie *s1g_oper, + struct cfg80211_chan_def *chandef, bool tracking) +{ + struct cfg80211_chan_def vht_chandef; + struct ieee80211_sta_ht_cap sta_ht_cap; + ieee80211_conn_flags_t ret; + u32 ht_cfreq; + + memset(chandef, 0, sizeof(struct cfg80211_chan_def)); + chandef->chan = channel; + chandef->width = NL80211_CHAN_WIDTH_20_NOHT; + chandef->center_freq1 = channel->center_freq; + chandef->freq1_offset = channel->freq_offset; + + if (channel->band == NL80211_BAND_6GHZ) { + if (!ieee80211_chandef_he_6ghz_oper(sdata, he_oper, eht_oper, + chandef)) { + mlme_dbg(sdata, + "bad 6 GHz operation, disabling HT/VHT/HE/EHT\n"); + ret = IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + } else { + ret = 0; + } + vht_chandef = *chandef; + goto out; + } else if (sband->band == NL80211_BAND_S1GHZ) { + if (!ieee80211_chandef_s1g_oper(s1g_oper, chandef)) { + sdata_info(sdata, + "Missing S1G Operation Element? 
Trying operating == primary\n"); + chandef->width = ieee80211_s1g_channel_width(channel); + } + + ret = IEEE80211_CONN_DISABLE_HT | IEEE80211_CONN_DISABLE_40MHZ | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_80P80MHZ | + IEEE80211_CONN_DISABLE_160MHZ; + goto out; + } + + memcpy(&sta_ht_cap, &sband->ht_cap, sizeof(sta_ht_cap)); + ieee80211_apply_htcap_overrides(sdata, &sta_ht_cap); + + if (!ht_oper || !sta_ht_cap.ht_supported) { + mlme_dbg(sdata, "HT operation missing / HT not supported\n"); + ret = IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + goto out; + } + + chandef->width = NL80211_CHAN_WIDTH_20; + + ht_cfreq = ieee80211_channel_to_frequency(ht_oper->primary_chan, + channel->band); + /* check that channel matches the right operating channel */ + if (!tracking && channel->center_freq != ht_cfreq) { + /* + * It's possible that some APs are confused here; + * Netgear WNDR3700 sometimes reports 4 higher than + * the actual channel in association responses, but + * since we look at probe response/beacon data here + * it should be OK. + */ + sdata_info(sdata, + "Wrong control channel: center-freq: %d ht-cfreq: %d ht->primary_chan: %d band: %d - Disabling HT\n", + channel->center_freq, ht_cfreq, + ht_oper->primary_chan, channel->band); + ret = IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + goto out; + } + + /* check 40 MHz support, if we have it */ + if (sta_ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40) { + ieee80211_chandef_ht_oper(ht_oper, chandef); + } else { + mlme_dbg(sdata, "40 MHz not supported\n"); + /* 40 MHz (and 80 MHz) must be supported for VHT */ + ret = IEEE80211_CONN_DISABLE_VHT; + /* also mark 40 MHz disabled */ + ret |= IEEE80211_CONN_DISABLE_40MHZ; + goto out; + } + + if (!vht_oper || !sband->vht_cap.vht_supported) { + mlme_dbg(sdata, "VHT operation missing / VHT not supported\n"); + ret = IEEE80211_CONN_DISABLE_VHT; + goto out; + } + + vht_chandef = *chandef; + if (!(conn_flags & IEEE80211_CONN_DISABLE_HE) && + he_oper && + (le32_to_cpu(he_oper->he_oper_params) & + IEEE80211_HE_OPERATION_VHT_OPER_INFO)) { + struct ieee80211_vht_operation he_oper_vht_cap; + + /* + * Set only first 3 bytes (other 2 aren't used in + * ieee80211_chandef_vht_oper() anyway) + */ + memcpy(&he_oper_vht_cap, he_oper->optional, 3); + he_oper_vht_cap.basic_mcs_set = cpu_to_le16(0); + + if (!ieee80211_chandef_vht_oper(&sdata->local->hw, vht_cap_info, + &he_oper_vht_cap, ht_oper, + &vht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_HE)) + sdata_info(sdata, + "HE AP VHT information is invalid, disabling HE\n"); + ret = IEEE80211_CONN_DISABLE_HE | IEEE80211_CONN_DISABLE_EHT; + goto out; + } + } else if (!ieee80211_chandef_vht_oper(&sdata->local->hw, + vht_cap_info, + vht_oper, ht_oper, + &vht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_VHT)) + sdata_info(sdata, + "AP VHT information is invalid, disabling VHT\n"); + ret = IEEE80211_CONN_DISABLE_VHT; + goto out; + } + + if (!cfg80211_chandef_valid(&vht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_VHT)) + sdata_info(sdata, + "AP VHT information is invalid, disabling VHT\n"); + ret = IEEE80211_CONN_DISABLE_VHT; + goto out; + } + + if (cfg80211_chandef_identical(chandef, &vht_chandef)) { + ret = 0; + goto out; + } + + if (!cfg80211_chandef_compatible(chandef, &vht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_VHT)) + sdata_info(sdata, + "AP VHT information 
doesn't match HT, disabling VHT\n"); + ret = IEEE80211_CONN_DISABLE_VHT; + goto out; + } + + *chandef = vht_chandef; + + /* + * handle the case that the EHT operation indicates that it holds EHT + * operation information (in case that the channel width differs from + * the channel width reported in HT/VHT/HE). + */ + if (eht_oper && (eht_oper->params & IEEE80211_EHT_OPER_INFO_PRESENT)) { + struct cfg80211_chan_def eht_chandef = *chandef; + + ieee80211_chandef_eht_oper(eht_oper, + eht_chandef.width == + NL80211_CHAN_WIDTH_160, + false, &eht_chandef); + + if (!cfg80211_chandef_valid(&eht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_EHT)) + sdata_info(sdata, + "AP EHT information is invalid, disabling EHT\n"); + ret = IEEE80211_CONN_DISABLE_EHT; + goto out; + } + + if (!cfg80211_chandef_compatible(chandef, &eht_chandef)) { + if (!(conn_flags & IEEE80211_CONN_DISABLE_EHT)) + sdata_info(sdata, + "AP EHT information is incompatible, disabling EHT\n"); + ret = IEEE80211_CONN_DISABLE_EHT; + goto out; + } + + *chandef = eht_chandef; + } + + ret = 0; + +out: + /* + * When tracking the current AP, don't do any further checks if the + * new chandef is identical to the one we're currently using for the + * connection. This keeps us from playing ping-pong with regulatory, + * without it the following can happen (for example): + * - connect to an AP with 80 MHz, world regdom allows 80 MHz + * - AP advertises regdom US + * - CRDA loads regdom US with 80 MHz prohibited (old database) + * - the code below detects an unsupported channel, downgrades, and + * we disconnect from the AP in the caller + * - disconnect causes CRDA to reload world regdomain and the game + * starts anew. + * (see https://bugzilla.kernel.org/show_bug.cgi?id=70881) + * + * It seems possible that there are still scenarios with CSA or real + * bandwidth changes where a this could happen, but those cases are + * less common and wouldn't completely prevent using the AP. + */ + if (tracking && + cfg80211_chandef_identical(chandef, &link->conf->chandef)) + return ret; + + /* don't print the message below for VHT mismatch if VHT is disabled */ + if (ret & IEEE80211_CONN_DISABLE_VHT) + vht_chandef = *chandef; + + /* + * Ignore the DISABLED flag when we're already connected and only + * tracking the APs beacon for bandwidth changes - otherwise we + * might get disconnected here if we connect to an AP, update our + * regulatory information based on the AP's country IE and the + * information we have is wrong/outdated and disables the channel + * that we're actually using for the connection to the AP. + */ + while (!cfg80211_chandef_usable(sdata->local->hw.wiphy, chandef, + tracking ? 
0 : + IEEE80211_CHAN_DISABLED)) { + if (WARN_ON(chandef->width == NL80211_CHAN_WIDTH_20_NOHT)) { + ret = IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + break; + } + + ret |= ieee80211_chandef_downgrade(chandef); + } + + if (!he_oper || !cfg80211_chandef_usable(sdata->wdev.wiphy, chandef, + IEEE80211_CHAN_NO_HE)) + ret |= IEEE80211_CONN_DISABLE_HE | IEEE80211_CONN_DISABLE_EHT; + + if (!eht_oper || !cfg80211_chandef_usable(sdata->wdev.wiphy, chandef, + IEEE80211_CHAN_NO_EHT)) + ret |= IEEE80211_CONN_DISABLE_EHT; + + if (chandef->width != vht_chandef.width && !tracking) + sdata_info(sdata, + "capabilities/regulatory prevented using AP HT/VHT configuration, downgraded\n"); + + WARN_ON_ONCE(!cfg80211_chandef_valid(chandef)); + return ret; +} + +static int ieee80211_config_bw(struct ieee80211_link_data *link, + const struct ieee80211_ht_cap *ht_cap, + const struct ieee80211_vht_cap *vht_cap, + const struct ieee80211_ht_operation *ht_oper, + const struct ieee80211_vht_operation *vht_oper, + const struct ieee80211_he_operation *he_oper, + const struct ieee80211_eht_operation *eht_oper, + const struct ieee80211_s1g_oper_ie *s1g_oper, + const u8 *bssid, u32 *changed) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_channel *chan = link->conf->chandef.chan; + struct ieee80211_supported_band *sband = + local->hw.wiphy->bands[chan->band]; + struct cfg80211_chan_def chandef; + u16 ht_opmode; + ieee80211_conn_flags_t flags; + u32 vht_cap_info = 0; + int ret; + + /* if HT was/is disabled, don't track any bandwidth changes */ + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT || !ht_oper) + return 0; + + /* don't check VHT if we associated as non-VHT station */ + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT) + vht_oper = NULL; + + /* don't check HE if we associated as non-HE station */ + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE || + !ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif))) { + he_oper = NULL; + eht_oper = NULL; + } + + /* don't check EHT if we associated as non-EHT station */ + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_EHT || + !ieee80211_get_eht_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif))) + eht_oper = NULL; + + /* + * if bss configuration changed store the new one - + * this may be applicable even if channel is identical + */ + ht_opmode = le16_to_cpu(ht_oper->operation_mode); + if (link->conf->ht_operation_mode != ht_opmode) { + *changed |= BSS_CHANGED_HT; + link->conf->ht_operation_mode = ht_opmode; + } + + if (vht_cap) + vht_cap_info = le32_to_cpu(vht_cap->vht_cap_info); + + /* calculate new channel (type) based on HT/VHT/HE operation IEs */ + flags = ieee80211_determine_chantype(sdata, link, + link->u.mgd.conn_flags, + sband, chan, vht_cap_info, + ht_oper, vht_oper, + he_oper, eht_oper, + s1g_oper, &chandef, true); + + /* + * Downgrade the new channel if we associated with restricted + * capabilities. For example, if we associated as a 20 MHz STA + * to a 40 MHz AP (due to regulatory, capabilities or config + * reasons) then switching to a 40 MHz channel now won't do us + * any good -- we couldn't use it with the AP. 
+ */ + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_80P80MHZ && + chandef.width == NL80211_CHAN_WIDTH_80P80) + flags |= ieee80211_chandef_downgrade(&chandef); + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_160MHZ && + chandef.width == NL80211_CHAN_WIDTH_160) + flags |= ieee80211_chandef_downgrade(&chandef); + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_40MHZ && + chandef.width > NL80211_CHAN_WIDTH_20) + flags |= ieee80211_chandef_downgrade(&chandef); + + if (cfg80211_chandef_identical(&chandef, &link->conf->chandef)) + return 0; + + link_info(link, + "AP %pM changed bandwidth, new config is %d.%03d MHz, width %d (%d.%03d/%d MHz)\n", + link->u.mgd.bssid, chandef.chan->center_freq, + chandef.chan->freq_offset, chandef.width, + chandef.center_freq1, chandef.freq1_offset, + chandef.center_freq2); + + if (flags != (link->u.mgd.conn_flags & + (IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT | + IEEE80211_CONN_DISABLE_40MHZ | + IEEE80211_CONN_DISABLE_80P80MHZ | + IEEE80211_CONN_DISABLE_160MHZ | + IEEE80211_CONN_DISABLE_320MHZ)) || + !cfg80211_chandef_valid(&chandef)) { + sdata_info(sdata, + "AP %pM changed caps/bw in a way we can't support (0x%x/0x%x) - disconnect\n", + link->u.mgd.bssid, flags, ifmgd->flags); + return -EINVAL; + } + + ret = ieee80211_link_change_bandwidth(link, &chandef, changed); + + if (ret) { + sdata_info(sdata, + "AP %pM changed bandwidth to incompatible one - disconnect\n", + link->u.mgd.bssid); + return ret; + } + + return 0; +} + +/* frame sending functions */ + +static void ieee80211_add_ht_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u8 ap_ht_param, + struct ieee80211_supported_band *sband, + struct ieee80211_channel *channel, + enum ieee80211_smps_mode smps, + ieee80211_conn_flags_t conn_flags) +{ + u8 *pos; + u32 flags = channel->flags; + u16 cap; + struct ieee80211_sta_ht_cap ht_cap; + + BUILD_BUG_ON(sizeof(ht_cap) != sizeof(sband->ht_cap)); + + memcpy(&ht_cap, &sband->ht_cap, sizeof(ht_cap)); + ieee80211_apply_htcap_overrides(sdata, &ht_cap); + + /* determine capability flags */ + cap = ht_cap.cap; + + switch (ap_ht_param & IEEE80211_HT_PARAM_CHA_SEC_OFFSET) { + case IEEE80211_HT_PARAM_CHA_SEC_ABOVE: + if (flags & IEEE80211_CHAN_NO_HT40PLUS) { + cap &= ~IEEE80211_HT_CAP_SUP_WIDTH_20_40; + cap &= ~IEEE80211_HT_CAP_SGI_40; + } + break; + case IEEE80211_HT_PARAM_CHA_SEC_BELOW: + if (flags & IEEE80211_CHAN_NO_HT40MINUS) { + cap &= ~IEEE80211_HT_CAP_SUP_WIDTH_20_40; + cap &= ~IEEE80211_HT_CAP_SGI_40; + } + break; + } + + /* + * If 40 MHz was disabled associate as though we weren't + * capable of 40 MHz -- some broken APs will never fall + * back to trying to transmit in 20 MHz. 
+ */ + if (conn_flags & IEEE80211_CONN_DISABLE_40MHZ) { + cap &= ~IEEE80211_HT_CAP_SUP_WIDTH_20_40; + cap &= ~IEEE80211_HT_CAP_SGI_40; + } + + /* set SM PS mode properly */ + cap &= ~IEEE80211_HT_CAP_SM_PS; + switch (smps) { + case IEEE80211_SMPS_AUTOMATIC: + case IEEE80211_SMPS_NUM_MODES: + WARN_ON(1); + fallthrough; + case IEEE80211_SMPS_OFF: + cap |= WLAN_HT_CAP_SM_PS_DISABLED << + IEEE80211_HT_CAP_SM_PS_SHIFT; + break; + case IEEE80211_SMPS_STATIC: + cap |= WLAN_HT_CAP_SM_PS_STATIC << + IEEE80211_HT_CAP_SM_PS_SHIFT; + break; + case IEEE80211_SMPS_DYNAMIC: + cap |= WLAN_HT_CAP_SM_PS_DYNAMIC << + IEEE80211_HT_CAP_SM_PS_SHIFT; + break; + } + + /* reserve and fill IE */ + pos = skb_put(skb, sizeof(struct ieee80211_ht_cap) + 2); + ieee80211_ie_build_ht_cap(pos, &ht_cap, cap); +} + +/* This function determines vht capability flags for the association + * and builds the IE. + * Note - the function returns true to own the MU-MIMO capability + */ +static bool ieee80211_add_vht_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct ieee80211_supported_band *sband, + struct ieee80211_vht_cap *ap_vht_cap, + ieee80211_conn_flags_t conn_flags) +{ + struct ieee80211_local *local = sdata->local; + u8 *pos; + u32 cap; + struct ieee80211_sta_vht_cap vht_cap; + u32 mask, ap_bf_sts, our_bf_sts; + bool mu_mimo_owner = false; + + BUILD_BUG_ON(sizeof(vht_cap) != sizeof(sband->vht_cap)); + + memcpy(&vht_cap, &sband->vht_cap, sizeof(vht_cap)); + ieee80211_apply_vhtcap_overrides(sdata, &vht_cap); + + /* determine capability flags */ + cap = vht_cap.cap; + + if (conn_flags & IEEE80211_CONN_DISABLE_80P80MHZ) { + u32 bw = cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + + cap &= ~IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + if (bw == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ || + bw == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ) + cap |= IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ; + } + + if (conn_flags & IEEE80211_CONN_DISABLE_160MHZ) { + cap &= ~IEEE80211_VHT_CAP_SHORT_GI_160; + cap &= ~IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + } + + /* + * Some APs apparently get confused if our capabilities are better + * than theirs, so restrict what we advertise in the assoc request. + */ + if (!(ap_vht_cap->vht_cap_info & + cpu_to_le32(IEEE80211_VHT_CAP_SU_BEAMFORMER_CAPABLE))) + cap &= ~(IEEE80211_VHT_CAP_SU_BEAMFORMEE_CAPABLE | + IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE); + else if (!(ap_vht_cap->vht_cap_info & + cpu_to_le32(IEEE80211_VHT_CAP_MU_BEAMFORMER_CAPABLE))) + cap &= ~IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE; + + /* + * If some other vif is using the MU-MIMO capability we cannot associate + * using MU-MIMO - this will lead to contradictions in the group-id + * mechanism. + * Ownership is defined since association request, in order to avoid + * simultaneous associations with MU-MIMO. 
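+ * (Separately, note that in the beamformee STS handling below we never
+ * advertise a larger STS capability than the AP itself reports: the
+ * IEEE80211_VHT_CAP_BEAMFORMEE_STS_MASK subfield is clamped to the
+ * AP's value.)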
+ */ + if (cap & IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE) { + bool disable_mu_mimo = false; + struct ieee80211_sub_if_data *other; + + list_for_each_entry_rcu(other, &local->interfaces, list) { + if (other->vif.bss_conf.mu_mimo_owner) { + disable_mu_mimo = true; + break; + } + } + if (disable_mu_mimo) + cap &= ~IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE; + else + mu_mimo_owner = true; + } + + mask = IEEE80211_VHT_CAP_BEAMFORMEE_STS_MASK; + + ap_bf_sts = le32_to_cpu(ap_vht_cap->vht_cap_info) & mask; + our_bf_sts = cap & mask; + + if (ap_bf_sts < our_bf_sts) { + cap &= ~mask; + cap |= ap_bf_sts; + } + + /* reserve and fill IE */ + pos = skb_put(skb, sizeof(struct ieee80211_vht_cap) + 2); + ieee80211_ie_build_vht_cap(pos, &vht_cap, cap); + + return mu_mimo_owner; +} + +/* This function determines HE capability flags for the association + * and builds the IE. + */ +static void ieee80211_add_he_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct ieee80211_supported_band *sband, + enum ieee80211_smps_mode smps_mode, + ieee80211_conn_flags_t conn_flags) +{ + u8 *pos, *pre_he_pos; + const struct ieee80211_sta_he_cap *he_cap; + u8 he_cap_size; + + he_cap = ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + if (WARN_ON(!he_cap)) + return; + + /* get a max size estimate */ + he_cap_size = + 2 + 1 + sizeof(he_cap->he_cap_elem) + + ieee80211_he_mcs_nss_size(&he_cap->he_cap_elem) + + ieee80211_he_ppe_size(he_cap->ppe_thres[0], + he_cap->he_cap_elem.phy_cap_info); + pos = skb_put(skb, he_cap_size); + pre_he_pos = pos; + pos = ieee80211_ie_build_he_cap(conn_flags, + pos, he_cap, pos + he_cap_size); + /* trim excess if any */ + skb_trim(skb, skb->len - (pre_he_pos + he_cap_size - pos)); + + ieee80211_ie_build_he_6ghz_cap(sdata, smps_mode, skb); +} + +static void ieee80211_add_eht_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct ieee80211_supported_band *sband) +{ + u8 *pos; + const struct ieee80211_sta_he_cap *he_cap; + const struct ieee80211_sta_eht_cap *eht_cap; + u8 eht_cap_size; + + he_cap = ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + eht_cap = ieee80211_get_eht_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + + /* + * EHT capabilities element is only added if the HE capabilities element + * was added so assume that 'he_cap' is valid and don't check it. + */ + if (WARN_ON(!he_cap || !eht_cap)) + return; + + eht_cap_size = + 2 + 1 + sizeof(eht_cap->eht_cap_elem) + + ieee80211_eht_mcs_nss_size(&he_cap->he_cap_elem, + &eht_cap->eht_cap_elem, + false) + + ieee80211_eht_ppe_size(eht_cap->eht_ppe_thres[0], + eht_cap->eht_cap_elem.phy_cap_info); + pos = skb_put(skb, eht_cap_size); + ieee80211_ie_build_eht_cap(pos, he_cap, eht_cap, pos + eht_cap_size, + false); +} + +static void ieee80211_assoc_add_rates(struct sk_buff *skb, + enum nl80211_chan_width width, + struct ieee80211_supported_band *sband, + struct ieee80211_mgd_assoc_data *assoc_data) +{ + unsigned int shift = ieee80211_chanwidth_get_shift(width); + unsigned int rates_len, supp_rates_len; + u32 rates = 0; + int i, count; + u8 *pos; + + if (assoc_data->supp_rates_len) { + /* + * Get all rates supported by the device and the AP as + * some APs don't like getting a superset of their rates + * in the association request (e.g. D-Link DAP 1353 in + * b-only mode)... 
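+ * (As a reminder of the encoding used below: sband->bitrates[] holds
+ * rates in units of 100 kbps while the element uses units of 500 kbps,
+ * so e.g. 54 Mbps is stored as 540 and written as
+ * DIV_ROUND_UP(540, 5) = 108; on 5/10 MHz channels the extra shift
+ * halves/quarters that value. Only the first eight rates fit into
+ * WLAN_EID_SUPP_RATES, the remainder goes into WLAN_EID_EXT_SUPP_RATES.)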
+ */ + rates_len = ieee80211_parse_bitrates(width, sband, + assoc_data->supp_rates, + assoc_data->supp_rates_len, + &rates); + } else { + /* + * In case AP not provide any supported rates information + * before association, we send information element(s) with + * all rates that we support. + */ + rates_len = sband->n_bitrates; + for (i = 0; i < sband->n_bitrates; i++) + rates |= BIT(i); + } + + supp_rates_len = rates_len; + if (supp_rates_len > 8) + supp_rates_len = 8; + + pos = skb_put(skb, supp_rates_len + 2); + *pos++ = WLAN_EID_SUPP_RATES; + *pos++ = supp_rates_len; + + count = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if (BIT(i) & rates) { + int rate = DIV_ROUND_UP(sband->bitrates[i].bitrate, + 5 * (1 << shift)); + *pos++ = (u8)rate; + if (++count == 8) + break; + } + } + + if (rates_len > count) { + pos = skb_put(skb, rates_len - count + 2); + *pos++ = WLAN_EID_EXT_SUPP_RATES; + *pos++ = rates_len - count; + + for (i++; i < sband->n_bitrates; i++) { + if (BIT(i) & rates) { + int rate; + + rate = DIV_ROUND_UP(sband->bitrates[i].bitrate, + 5 * (1 << shift)); + *pos++ = (u8)rate; + } + } + } +} + +static size_t ieee80211_add_before_ht_elems(struct sk_buff *skb, + const u8 *elems, + size_t elems_len, + size_t offset) +{ + size_t noffset; + + static const u8 before_ht[] = { + WLAN_EID_SSID, + WLAN_EID_SUPP_RATES, + WLAN_EID_EXT_SUPP_RATES, + WLAN_EID_PWR_CAPABILITY, + WLAN_EID_SUPPORTED_CHANNELS, + WLAN_EID_RSN, + WLAN_EID_QOS_CAPA, + WLAN_EID_RRM_ENABLED_CAPABILITIES, + WLAN_EID_MOBILITY_DOMAIN, + WLAN_EID_FAST_BSS_TRANSITION, /* reassoc only */ + WLAN_EID_RIC_DATA, /* reassoc only */ + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + }; + static const u8 after_ric[] = { + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + WLAN_EID_HT_CAPABILITY, + WLAN_EID_BSS_COEX_2040, + /* luckily this is almost always there */ + WLAN_EID_EXT_CAPABILITY, + WLAN_EID_QOS_TRAFFIC_CAPA, + WLAN_EID_TIM_BCAST_REQ, + WLAN_EID_INTERWORKING, + /* 60 GHz (Multi-band, DMG, MMS) can't happen */ + WLAN_EID_VHT_CAPABILITY, + WLAN_EID_OPMODE_NOTIF, + }; + + if (!elems_len) + return offset; + + noffset = ieee80211_ie_split_ric(elems, elems_len, + before_ht, + ARRAY_SIZE(before_ht), + after_ric, + ARRAY_SIZE(after_ric), + offset); + skb_put_data(skb, elems + offset, noffset - offset); + + return noffset; +} + +static size_t ieee80211_add_before_vht_elems(struct sk_buff *skb, + const u8 *elems, + size_t elems_len, + size_t offset) +{ + static const u8 before_vht[] = { + /* + * no need to list the ones split off before HT + * or generated here + */ + WLAN_EID_BSS_COEX_2040, + WLAN_EID_EXT_CAPABILITY, + WLAN_EID_QOS_TRAFFIC_CAPA, + WLAN_EID_TIM_BCAST_REQ, + WLAN_EID_INTERWORKING, + /* 60 GHz (Multi-band, DMG, MMS) can't happen */ + }; + size_t noffset; + + if (!elems_len) + return offset; + + /* RIC already taken care of in ieee80211_add_before_ht_elems() */ + noffset = ieee80211_ie_split(elems, elems_len, + before_vht, ARRAY_SIZE(before_vht), + offset); + skb_put_data(skb, elems + offset, noffset - offset); + + return noffset; +} + +static size_t ieee80211_add_before_he_elems(struct sk_buff *skb, + const u8 *elems, + size_t elems_len, + size_t offset) +{ + static const u8 before_he[] = { + /* + * no need to list the ones split off before VHT + * or generated here + */ + WLAN_EID_OPMODE_NOTIF, + WLAN_EID_EXTENSION, WLAN_EID_EXT_FUTURE_CHAN_GUIDANCE, + /* 11ai elements */ + WLAN_EID_EXTENSION, WLAN_EID_EXT_FILS_SESSION, + WLAN_EID_EXTENSION, WLAN_EID_EXT_FILS_PUBLIC_KEY, + WLAN_EID_EXTENSION, WLAN_EID_EXT_FILS_KEY_CONFIRM, + 
WLAN_EID_EXTENSION, WLAN_EID_EXT_FILS_HLP_CONTAINER, + WLAN_EID_EXTENSION, WLAN_EID_EXT_FILS_IP_ADDR_ASSIGN, + /* TODO: add 11ah/11aj/11ak elements */ + }; + size_t noffset; + + if (!elems_len) + return offset; + + /* RIC already taken care of in ieee80211_add_before_ht_elems() */ + noffset = ieee80211_ie_split(elems, elems_len, + before_he, ARRAY_SIZE(before_he), + offset); + skb_put_data(skb, elems + offset, noffset - offset); + + return noffset; +} + +#define PRESENT_ELEMS_MAX 8 +#define PRESENT_ELEM_EXT_OFFS 0x100 + +static void ieee80211_assoc_add_ml_elem(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u16 capab, + const struct element *ext_capa, + const u16 *present_elems); + +static size_t ieee80211_assoc_link_elems(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u16 *capab, + const struct element *ext_capa, + const u8 *extra_elems, + size_t extra_elems_len, + unsigned int link_id, + struct ieee80211_link_data *link, + u16 *present_elems) +{ + enum nl80211_iftype iftype = ieee80211_vif_type_p2p(&sdata->vif); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + struct ieee80211_channel *chan = cbss->channel; + const struct ieee80211_sband_iftype_data *iftd; + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + enum nl80211_chan_width width = NL80211_CHAN_WIDTH_20; + struct ieee80211_chanctx_conf *chanctx_conf; + enum ieee80211_smps_mode smps_mode; + u16 orig_capab = *capab; + size_t offset = 0; + int present_elems_len = 0; + u8 *pos; + int i; + +#define ADD_PRESENT_ELEM(id) do { \ + /* need a last for termination - we use 0 == SSID */ \ + if (!WARN_ON(present_elems_len >= PRESENT_ELEMS_MAX - 1)) \ + present_elems[present_elems_len++] = (id); \ +} while (0) +#define ADD_PRESENT_EXT_ELEM(id) ADD_PRESENT_ELEM(PRESENT_ELEM_EXT_OFFS | (id)) + + if (link) + smps_mode = link->smps_mode; + else if (sdata->u.mgd.powersave) + smps_mode = IEEE80211_SMPS_DYNAMIC; + else + smps_mode = IEEE80211_SMPS_OFF; + + if (link) { + /* + * 5/10 MHz scenarios are only viable without MLO, in which + * case this pointer should be used ... All of this is a bit + * unclear though, not sure this even works at all. + */ + rcu_read_lock(); + chanctx_conf = rcu_dereference(link->conf->chanctx_conf); + if (chanctx_conf) + width = chanctx_conf->def.width; + rcu_read_unlock(); + } + + sband = local->hw.wiphy->bands[chan->band]; + iftd = ieee80211_get_sband_iftype_data(sband, iftype); + + if (sband->band == NL80211_BAND_2GHZ) { + *capab |= WLAN_CAPABILITY_SHORT_SLOT_TIME; + *capab |= WLAN_CAPABILITY_SHORT_PREAMBLE; + } + + if ((cbss->capability & WLAN_CAPABILITY_SPECTRUM_MGMT) && + ieee80211_hw_check(&local->hw, SPECTRUM_MGMT)) + *capab |= WLAN_CAPABILITY_SPECTRUM_MGMT; + + if (sband->band != NL80211_BAND_S1GHZ) + ieee80211_assoc_add_rates(skb, width, sband, assoc_data); + + if (*capab & WLAN_CAPABILITY_SPECTRUM_MGMT || + *capab & WLAN_CAPABILITY_RADIO_MEASURE) { + struct cfg80211_chan_def chandef = { + .width = width, + .chan = chan, + }; + + pos = skb_put(skb, 4); + *pos++ = WLAN_EID_PWR_CAPABILITY; + *pos++ = 2; + *pos++ = 0; /* min tx power */ + /* max tx power */ + *pos++ = ieee80211_chandef_max_power(&chandef); + ADD_PRESENT_ELEM(WLAN_EID_PWR_CAPABILITY); + } + + /* + * Per spec, we shouldn't include the list of channels if we advertise + * support for extended channel switching, but we've always done that; + * (for now?) 
apply this restriction only on the (new) 6 GHz band. + */ + if (*capab & WLAN_CAPABILITY_SPECTRUM_MGMT && + (sband->band != NL80211_BAND_6GHZ || + !ext_capa || ext_capa->datalen < 1 || + !(ext_capa->data[0] & WLAN_EXT_CAPA1_EXT_CHANNEL_SWITCHING))) { + /* TODO: get this in reg domain format */ + pos = skb_put(skb, 2 * sband->n_channels + 2); + *pos++ = WLAN_EID_SUPPORTED_CHANNELS; + *pos++ = 2 * sband->n_channels; + for (i = 0; i < sband->n_channels; i++) { + int cf = sband->channels[i].center_freq; + + *pos++ = ieee80211_frequency_to_channel(cf); + *pos++ = 1; /* one channel in the subband*/ + } + ADD_PRESENT_ELEM(WLAN_EID_SUPPORTED_CHANNELS); + } + + /* if present, add any custom IEs that go before HT */ + offset = ieee80211_add_before_ht_elems(skb, extra_elems, + extra_elems_len, + offset); + + if (sband->band != NL80211_BAND_6GHZ && + !(assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_HT)) { + ieee80211_add_ht_ie(sdata, skb, + assoc_data->link[link_id].ap_ht_param, + sband, chan, smps_mode, + assoc_data->link[link_id].conn_flags); + ADD_PRESENT_ELEM(WLAN_EID_HT_CAPABILITY); + } + + /* if present, add any custom IEs that go before VHT */ + offset = ieee80211_add_before_vht_elems(skb, extra_elems, + extra_elems_len, + offset); + + if (sband->band != NL80211_BAND_6GHZ && + !(assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_VHT)) { + bool mu_mimo_owner = + ieee80211_add_vht_ie(sdata, skb, sband, + &assoc_data->link[link_id].ap_vht_cap, + assoc_data->link[link_id].conn_flags); + + if (link) + link->conf->mu_mimo_owner = mu_mimo_owner; + ADD_PRESENT_ELEM(WLAN_EID_VHT_CAPABILITY); + } + + /* + * If AP doesn't support HT, mark HE and EHT as disabled. + * If on the 5GHz band, make sure it supports VHT. + */ + if (assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_HT || + (sband->band == NL80211_BAND_5GHZ && + assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_VHT)) + assoc_data->link[link_id].conn_flags |= + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + + /* if present, add any custom IEs that go before HE */ + offset = ieee80211_add_before_he_elems(skb, extra_elems, + extra_elems_len, + offset); + + if (!(assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_HE)) { + ieee80211_add_he_ie(sdata, skb, sband, smps_mode, + assoc_data->link[link_id].conn_flags); + ADD_PRESENT_EXT_ELEM(WLAN_EID_EXT_HE_CAPABILITY); + } + + /* + * careful - need to know about all the present elems before + * calling ieee80211_assoc_add_ml_elem(), so add this one if + * we're going to put it after the ML element + */ + if (!(assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_EHT)) + ADD_PRESENT_EXT_ELEM(WLAN_EID_EXT_EHT_CAPABILITY); + + if (link_id == assoc_data->assoc_link_id) + ieee80211_assoc_add_ml_elem(sdata, skb, orig_capab, ext_capa, + present_elems); + + /* crash if somebody gets it wrong */ + present_elems = NULL; + + if (!(assoc_data->link[link_id].conn_flags & IEEE80211_CONN_DISABLE_EHT)) + ieee80211_add_eht_ie(sdata, skb, sband); + + if (sband->band == NL80211_BAND_S1GHZ) { + ieee80211_add_aid_request_ie(sdata, skb); + ieee80211_add_s1g_capab_ie(sdata, &sband->s1g_cap, skb); + } + + if (iftd && iftd->vendor_elems.data && iftd->vendor_elems.len) + skb_put_data(skb, iftd->vendor_elems.data, iftd->vendor_elems.len); + + if (link) + link->u.mgd.conn_flags = assoc_data->link[link_id].conn_flags; + + return offset; +} + +static void ieee80211_add_non_inheritance_elem(struct sk_buff *skb, + const u16 *outer, + const u16 *inner) +{ + 
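+ /*
+  * Rough sketch of what this helper emits, built from the IDs that are
+  * present in 'outer' but missing from 'inner':
+  *
+  *   WLAN_EID_EXTENSION | length | WLAN_EID_EXT_NON_INHERITANCE |
+  *   <#element IDs> <IDs ...> | <#element ID extensions> <ext IDs ...>
+  *
+  * An empty list is encoded as just a zero length byte, and if nothing
+  * needs to be listed at all the element is trimmed off again below.
+  */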
unsigned int skb_len = skb->len; + bool at_extension = false; + bool added = false; + int i, j; + u8 *len, *list_len = NULL; + + skb_put_u8(skb, WLAN_EID_EXTENSION); + len = skb_put(skb, 1); + skb_put_u8(skb, WLAN_EID_EXT_NON_INHERITANCE); + + for (i = 0; i < PRESENT_ELEMS_MAX && outer[i]; i++) { + u16 elem = outer[i]; + bool have_inner = false; + + /* should at least be sorted in the sense of normal -> ext */ + WARN_ON(at_extension && elem < PRESENT_ELEM_EXT_OFFS); + + /* switch to extension list */ + if (!at_extension && elem >= PRESENT_ELEM_EXT_OFFS) { + at_extension = true; + if (!list_len) + skb_put_u8(skb, 0); + list_len = NULL; + } + + for (j = 0; j < PRESENT_ELEMS_MAX && inner[j]; j++) { + if (elem == inner[j]) { + have_inner = true; + break; + } + } + + if (have_inner) + continue; + + if (!list_len) { + list_len = skb_put(skb, 1); + *list_len = 0; + } + *list_len += 1; + skb_put_u8(skb, (u8)elem); + added = true; + } + + /* if we added a list but no extension list, make a zero-len one */ + if (added && (!at_extension || !list_len)) + skb_put_u8(skb, 0); + + /* if nothing added remove extension element completely */ + if (!added) + skb_trim(skb, skb_len); + else + *len = skb->len - skb_len - 2; +} + +static void ieee80211_assoc_add_ml_elem(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u16 capab, + const struct element *ext_capa, + const u16 *outer_present_elems) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; + struct ieee80211_multi_link_elem *ml_elem; + struct ieee80211_mle_basic_common_info *common; + const struct wiphy_iftype_ext_capab *ift_ext_capa; + __le16 eml_capa = 0, mld_capa_ops = 0; + unsigned int link_id; + u8 *ml_elem_len; + void *capab_pos; + + if (!sdata->vif.valid_links) + return; + + ift_ext_capa = cfg80211_get_iftype_ext_capa(local->hw.wiphy, + ieee80211_vif_type_p2p(&sdata->vif)); + if (ift_ext_capa) { + eml_capa = cpu_to_le16(ift_ext_capa->eml_capabilities); + mld_capa_ops = cpu_to_le16(ift_ext_capa->mld_capa_and_ops); + } + + skb_put_u8(skb, WLAN_EID_EXTENSION); + ml_elem_len = skb_put(skb, 1); + skb_put_u8(skb, WLAN_EID_EXT_EHT_MULTI_LINK); + ml_elem = skb_put(skb, sizeof(*ml_elem)); + ml_elem->control = + cpu_to_le16(IEEE80211_ML_CONTROL_TYPE_BASIC | + IEEE80211_MLC_BASIC_PRES_MLD_CAPA_OP); + common = skb_put(skb, sizeof(*common)); + common->len = sizeof(*common) + + 2; /* MLD capa/ops */ + memcpy(common->mld_mac_addr, sdata->vif.addr, ETH_ALEN); + + /* add EML_CAPA only if needed, see Draft P802.11be_D2.1, 35.3.17 */ + if (eml_capa & + cpu_to_le16((IEEE80211_EML_CAP_EMLSR_SUPP | + IEEE80211_EML_CAP_EMLMR_SUPPORT))) { + common->len += 2; /* EML capabilities */ + ml_elem->control |= + cpu_to_le16(IEEE80211_MLC_BASIC_PRES_EML_CAPA); + skb_put_data(skb, &eml_capa, sizeof(eml_capa)); + } + /* need indication from userspace to support this */ + mld_capa_ops &= ~cpu_to_le16(IEEE80211_MLD_CAP_OP_TID_TO_LINK_MAP_NEG_SUPP); + skb_put_data(skb, &mld_capa_ops, sizeof(mld_capa_ops)); + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + u16 link_present_elems[PRESENT_ELEMS_MAX] = {}; + const u8 *extra_elems; + size_t extra_elems_len; + size_t extra_used; + u8 *subelem_len = NULL; + __le16 ctrl; + + if (!assoc_data->link[link_id].bss || + link_id == assoc_data->assoc_link_id) + continue; + + extra_elems = assoc_data->link[link_id].elems; + extra_elems_len = assoc_data->link[link_id].elems_len; + + skb_put_u8(skb, 
IEEE80211_MLE_SUBELEM_PER_STA_PROFILE); + subelem_len = skb_put(skb, 1); + + ctrl = cpu_to_le16(link_id | + IEEE80211_MLE_STA_CONTROL_COMPLETE_PROFILE | + IEEE80211_MLE_STA_CONTROL_STA_MAC_ADDR_PRESENT); + skb_put_data(skb, &ctrl, sizeof(ctrl)); + skb_put_u8(skb, 1 + ETH_ALEN); /* STA Info Length */ + skb_put_data(skb, assoc_data->link[link_id].addr, + ETH_ALEN); + /* + * Now add the contents of the (re)association request, + * but the "listen interval" and "current AP address" + * (if applicable) are skipped. So we only have + * the capability field (remember the position and fill + * later), followed by the elements added below by + * calling ieee80211_assoc_link_elems(). + */ + capab_pos = skb_put(skb, 2); + + extra_used = ieee80211_assoc_link_elems(sdata, skb, &capab, + ext_capa, + extra_elems, + extra_elems_len, + link_id, NULL, + link_present_elems); + if (extra_elems) + skb_put_data(skb, extra_elems + extra_used, + extra_elems_len - extra_used); + + put_unaligned_le16(capab, capab_pos); + + ieee80211_add_non_inheritance_elem(skb, outer_present_elems, + link_present_elems); + + ieee80211_fragment_element(skb, subelem_len); + } + + ieee80211_fragment_element(skb, ml_elem_len); +} + +static int ieee80211_send_assoc(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; + struct ieee80211_link_data *link; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + u8 *pos, qos_info, *ie_start; + size_t offset, noffset; + u16 capab = WLAN_CAPABILITY_ESS, link_capab; + __le16 listen_int; + struct element *ext_capa = NULL; + enum nl80211_iftype iftype = ieee80211_vif_type_p2p(&sdata->vif); + struct ieee80211_prep_tx_info info = {}; + unsigned int link_id, n_links = 0; + u16 present_elems[PRESENT_ELEMS_MAX] = {}; + void *capab_pos; + size_t size; + int ret; + + /* we know it's writable, cast away the const */ + if (assoc_data->ie_len) + ext_capa = (void *)cfg80211_find_elem(WLAN_EID_EXT_CAPABILITY, + assoc_data->ie, + assoc_data->ie_len); + + sdata_assert_lock(sdata); + + size = local->hw.extra_tx_headroom + + sizeof(*mgmt) + /* bit too much but doesn't matter */ + 2 + assoc_data->ssid_len + /* SSID */ + assoc_data->ie_len + /* extra IEs */ + (assoc_data->fils_kek_len ? 
16 /* AES-SIV */ : 0) + + 9; /* WMM */ + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + const struct ieee80211_sband_iftype_data *iftd; + struct ieee80211_supported_band *sband; + + if (!cbss) + continue; + + sband = local->hw.wiphy->bands[cbss->channel->band]; + + n_links++; + /* add STA profile elements length */ + size += assoc_data->link[link_id].elems_len; + /* and supported rates length */ + size += 4 + sband->n_bitrates; + /* supported channels */ + size += 2 + 2 * sband->n_channels; + + iftd = ieee80211_get_sband_iftype_data(sband, iftype); + if (iftd) + size += iftd->vendor_elems.len; + + /* power capability */ + size += 4; + + /* HT, VHT, HE, EHT */ + size += 2 + sizeof(struct ieee80211_ht_cap); + size += 2 + sizeof(struct ieee80211_vht_cap); + size += 2 + 1 + sizeof(struct ieee80211_he_cap_elem) + + sizeof(struct ieee80211_he_mcs_nss_supp) + + IEEE80211_HE_PPE_THRES_MAX_LEN; + + if (sband->band == NL80211_BAND_6GHZ) + size += 2 + 1 + sizeof(struct ieee80211_he_6ghz_capa); + + size += 2 + 1 + sizeof(struct ieee80211_eht_cap_elem) + + sizeof(struct ieee80211_eht_mcs_nss_supp) + + IEEE80211_EHT_PPE_THRES_MAX_LEN; + + /* non-inheritance element */ + size += 2 + 2 + PRESENT_ELEMS_MAX; + + /* should be the same across all BSSes */ + if (cbss->capability & WLAN_CAPABILITY_PRIVACY) + capab |= WLAN_CAPABILITY_PRIVACY; + } + + if (sdata->vif.valid_links) { + /* consider the multi-link element with STA profile */ + size += sizeof(struct ieee80211_multi_link_elem); + /* max common info field in basic multi-link element */ + size += sizeof(struct ieee80211_mle_basic_common_info) + + 2 + /* capa & op */ + 2; /* EML capa */ + + /* + * The capability elements were already considered above; + * note this over-estimates a bit because there's no + * STA profile for the assoc link. + */ + size += (n_links - 1) * + (1 + 1 + /* subelement ID/length */ + 2 + /* STA control */ + 1 + ETH_ALEN + 2 /* STA Info field */); + } + + link = sdata_dereference(sdata->link[assoc_data->assoc_link_id], sdata); + if (WARN_ON(!link)) + return -EINVAL; + + if (WARN_ON(!assoc_data->link[assoc_data->assoc_link_id].bss)) + return -EINVAL; + + skb = alloc_skb(size, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + if (ifmgd->flags & IEEE80211_STA_ENABLE_RRM) + capab |= WLAN_CAPABILITY_RADIO_MEASURE; + + /* Set MBSSID support for HE AP if needed */ + if (ieee80211_hw_check(&local->hw, SUPPORTS_ONLY_HE_MULTI_BSSID) && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE) && + ext_capa && ext_capa->datalen >= 3) + ext_capa->data[2] |= WLAN_EXT_CAPA3_MULTI_BSSID_SUPPORT; + + mgmt = skb_put_zero(skb, 24); + memcpy(mgmt->da, sdata->vif.cfg.ap_addr, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, sdata->vif.cfg.ap_addr, ETH_ALEN); + + listen_int = cpu_to_le16(assoc_data->s1g ? 
+ ieee80211_encode_usf(local->hw.conf.listen_interval) : + local->hw.conf.listen_interval); + if (!is_zero_ether_addr(assoc_data->prev_ap_addr)) { + skb_put(skb, 10); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_REASSOC_REQ); + capab_pos = &mgmt->u.reassoc_req.capab_info; + mgmt->u.reassoc_req.listen_interval = listen_int; + memcpy(mgmt->u.reassoc_req.current_ap, + assoc_data->prev_ap_addr, ETH_ALEN); + info.subtype = IEEE80211_STYPE_REASSOC_REQ; + } else { + skb_put(skb, 4); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ASSOC_REQ); + capab_pos = &mgmt->u.assoc_req.capab_info; + mgmt->u.assoc_req.listen_interval = listen_int; + info.subtype = IEEE80211_STYPE_ASSOC_REQ; + } + + /* SSID */ + pos = skb_put(skb, 2 + assoc_data->ssid_len); + ie_start = pos; + *pos++ = WLAN_EID_SSID; + *pos++ = assoc_data->ssid_len; + memcpy(pos, assoc_data->ssid, assoc_data->ssid_len); + + /* add the elements for the assoc (main) link */ + link_capab = capab; + offset = ieee80211_assoc_link_elems(sdata, skb, &link_capab, + ext_capa, + assoc_data->ie, + assoc_data->ie_len, + assoc_data->assoc_link_id, link, + present_elems); + put_unaligned_le16(link_capab, capab_pos); + + /* if present, add any custom non-vendor IEs */ + if (assoc_data->ie_len) { + noffset = ieee80211_ie_split_vendor(assoc_data->ie, + assoc_data->ie_len, + offset); + skb_put_data(skb, assoc_data->ie + offset, noffset - offset); + offset = noffset; + } + + if (assoc_data->wmm) { + if (assoc_data->uapsd) { + qos_info = ifmgd->uapsd_queues; + qos_info |= (ifmgd->uapsd_max_sp_len << + IEEE80211_WMM_IE_STA_QOSINFO_SP_SHIFT); + } else { + qos_info = 0; + } + + pos = ieee80211_add_wmm_info_ie(skb_put(skb, 9), qos_info); + } + + /* add any remaining custom (i.e. 
vendor specific here) IEs */ + if (assoc_data->ie_len) { + noffset = assoc_data->ie_len; + skb_put_data(skb, assoc_data->ie + offset, noffset - offset); + } + + if (assoc_data->fils_kek_len) { + ret = fils_encrypt_assoc_req(skb, assoc_data); + if (ret < 0) { + dev_kfree_skb(skb); + return ret; + } + } + + pos = skb_tail_pointer(skb); + kfree(ifmgd->assoc_req_ies); + ifmgd->assoc_req_ies = kmemdup(ie_start, pos - ie_start, GFP_ATOMIC); + if (!ifmgd->assoc_req_ies) { + dev_kfree_skb(skb); + return -ENOMEM; + } + + ifmgd->assoc_req_ies_len = pos - ie_start; + + drv_mgd_prepare_tx(local, sdata, &info); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_INTFL_MLME_CONN_TX; + ieee80211_tx_skb(sdata, skb); + + return 0; +} + +void ieee80211_send_pspoll(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_pspoll *pspoll; + struct sk_buff *skb; + + skb = ieee80211_pspoll_get(&local->hw, &sdata->vif); + if (!skb) + return; + + pspoll = (struct ieee80211_pspoll *) skb->data; + pspoll->frame_control |= cpu_to_le16(IEEE80211_FCTL_PM); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + ieee80211_tx_skb(sdata, skb); +} + +void ieee80211_send_nullfunc(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + bool powersave) +{ + struct sk_buff *skb; + struct ieee80211_hdr_3addr *nullfunc; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + skb = ieee80211_nullfunc_get(&local->hw, &sdata->vif, -1, + !ieee80211_hw_check(&local->hw, + DOESNT_SUPPORT_QOS_NDP)); + if (!skb) + return; + + nullfunc = (struct ieee80211_hdr_3addr *) skb->data; + if (powersave) + nullfunc->frame_control |= cpu_to_le16(IEEE80211_FCTL_PM); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + IEEE80211_TX_INTFL_OFFCHAN_TX_OK; + + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; + + if (ifmgd->flags & IEEE80211_STA_CONNECTION_POLL) + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_USE_MINRATE; + + ieee80211_tx_skb(sdata, skb); +} + +void ieee80211_send_4addr_nullfunc(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + struct sk_buff *skb; + struct ieee80211_hdr *nullfunc; + __le16 fc; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_STATION)) + return; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + 30); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + nullfunc = skb_put_zero(skb, 30); + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | IEEE80211_STYPE_NULLFUNC | + IEEE80211_FCTL_FROMDS | IEEE80211_FCTL_TODS); + nullfunc->frame_control = fc; + memcpy(nullfunc->addr1, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(nullfunc->addr2, sdata->vif.addr, ETH_ALEN); + memcpy(nullfunc->addr3, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(nullfunc->addr4, sdata->vif.addr, ETH_ALEN); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_USE_MINRATE; + ieee80211_tx_skb(sdata, skb); +} + +/* spectrum management related things */ +static void ieee80211_chswitch_work(struct work_struct *work) +{ + struct ieee80211_link_data *link = + container_of(work, struct ieee80211_link_data, u.mgd.chswitch_work); + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + 
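+ /*
+  * Lock nesting below: sdata lock first, then local->mtx, then
+  * local->chanctx_mtx, released in reverse order at "out".
+  */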
struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + int ret; + + if (!ieee80211_sdata_running(sdata)) + return; + + sdata_lock(sdata); + mutex_lock(&local->mtx); + mutex_lock(&local->chanctx_mtx); + + if (!ifmgd->associated) + goto out; + + if (!link->conf->csa_active) + goto out; + + /* + * using reservation isn't immediate as it may be deferred until later + * with multi-vif. once reservation is complete it will re-schedule the + * work with no reserved_chanctx so verify chandef to check if it + * completed successfully + */ + + if (link->reserved_chanctx) { + /* + * with multi-vif csa driver may call ieee80211_csa_finish() + * many times while waiting for other interfaces to use their + * reservations + */ + if (link->reserved_ready) + goto out; + + ret = ieee80211_link_use_reserved_context(link); + if (ret) { + sdata_info(sdata, + "failed to use reserved channel context, disconnecting (err=%d)\n", + ret); + ieee80211_queue_work(&sdata->local->hw, + &ifmgd->csa_connection_drop_work); + goto out; + } + + goto out; + } + + if (!cfg80211_chandef_identical(&link->conf->chandef, + &link->csa_chandef)) { + sdata_info(sdata, + "failed to finalize channel switch, disconnecting\n"); + ieee80211_queue_work(&sdata->local->hw, + &ifmgd->csa_connection_drop_work); + goto out; + } + + link->u.mgd.csa_waiting_bcn = true; + + ieee80211_sta_reset_beacon_monitor(sdata); + ieee80211_sta_reset_conn_monitor(sdata); + +out: + mutex_unlock(&local->chanctx_mtx); + mutex_unlock(&local->mtx); + sdata_unlock(sdata); +} + +static void ieee80211_chswitch_post_beacon(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + int ret; + + sdata_assert_lock(sdata); + + WARN_ON(!link->conf->csa_active); + + if (link->csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + link->csa_block_tx = false; + } + + link->conf->csa_active = false; + link->u.mgd.csa_waiting_bcn = false; + /* + * If the CSA IE is still present on the beacon after the switch, + * we need to consider it as a new CSA (possibly to self). 
+ */ + link->u.mgd.beacon_crc_valid = false; + + ret = drv_post_channel_switch(sdata); + if (ret) { + sdata_info(sdata, + "driver post channel switch failed, disconnecting\n"); + ieee80211_queue_work(&local->hw, + &ifmgd->csa_connection_drop_work); + return; + } + + cfg80211_ch_switch_notify(sdata->dev, &link->reserved_chandef, 0); +} + +void ieee80211_chswitch_done(struct ieee80211_vif *vif, bool success) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + if (WARN_ON(sdata->vif.valid_links)) + success = false; + + trace_api_chswitch_done(sdata, success); + if (!success) { + sdata_info(sdata, + "driver channel switch failed, disconnecting\n"); + ieee80211_queue_work(&sdata->local->hw, + &ifmgd->csa_connection_drop_work); + } else { + ieee80211_queue_work(&sdata->local->hw, + &sdata->deflink.u.mgd.chswitch_work); + } +} +EXPORT_SYMBOL(ieee80211_chswitch_done); + +static void ieee80211_chswitch_timer(struct timer_list *t) +{ + struct ieee80211_link_data *link = + from_timer(link, t, u.mgd.chswitch_timer); + + ieee80211_queue_work(&link->sdata->local->hw, + &link->u.mgd.chswitch_work); +} + +static void +ieee80211_sta_abort_chanswitch(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + + if (!local->ops->abort_channel_switch) + return; + + mutex_lock(&local->mtx); + + mutex_lock(&local->chanctx_mtx); + ieee80211_link_unreserve_chanctx(link); + mutex_unlock(&local->chanctx_mtx); + + if (link->csa_block_tx) + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + + link->csa_block_tx = false; + link->conf->csa_active = false; + + mutex_unlock(&local->mtx); + + drv_abort_channel_switch(sdata); +} + +static void +ieee80211_sta_process_chanswitch(struct ieee80211_link_data *link, + u64 timestamp, u32 device_timestamp, + struct ieee802_11_elems *elems, + bool beacon) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct cfg80211_bss *cbss = link->u.mgd.bss; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *chanctx; + enum nl80211_band current_band; + struct ieee80211_csa_ie csa_ie; + struct ieee80211_channel_switch ch_switch; + struct ieee80211_bss *bss; + int res; + + sdata_assert_lock(sdata); + + if (!cbss) + return; + + if (local->scanning) + return; + + current_band = cbss->channel->band; + bss = (void *)cbss->priv; + res = ieee80211_parse_ch_switch_ie(sdata, elems, current_band, + bss->vht_cap_info, + link->u.mgd.conn_flags, + link->u.mgd.bssid, &csa_ie); + + if (!res) { + ch_switch.timestamp = timestamp; + ch_switch.device_timestamp = device_timestamp; + ch_switch.block_tx = csa_ie.mode; + ch_switch.chandef = csa_ie.chandef; + ch_switch.count = csa_ie.count; + ch_switch.delay = csa_ie.max_switch_time; + } + + if (res < 0) + goto lock_and_drop_connection; + + if (beacon && link->conf->csa_active && + !link->u.mgd.csa_waiting_bcn) { + if (res) + ieee80211_sta_abort_chanswitch(link); + else + drv_channel_switch_rx_beacon(sdata, &ch_switch); + return; + } else if (link->conf->csa_active || res) { + /* disregard subsequent announcements if already processing */ + return; + } + + if (link->conf->chandef.chan->band != + csa_ie.chandef.chan->band) { + sdata_info(sdata, + "AP %pM switches to different band (%d MHz, width:%d, CF1/2: %d/%d MHz), disconnecting\n", + link->u.mgd.bssid, + 
csa_ie.chandef.chan->center_freq, + csa_ie.chandef.width, csa_ie.chandef.center_freq1, + csa_ie.chandef.center_freq2); + goto lock_and_drop_connection; + } + + if (!cfg80211_chandef_usable(local->hw.wiphy, &csa_ie.chandef, + IEEE80211_CHAN_DISABLED)) { + sdata_info(sdata, + "AP %pM switches to unsupported channel " + "(%d.%03d MHz, width:%d, CF1/2: %d.%03d/%d MHz), " + "disconnecting\n", + link->u.mgd.bssid, + csa_ie.chandef.chan->center_freq, + csa_ie.chandef.chan->freq_offset, + csa_ie.chandef.width, csa_ie.chandef.center_freq1, + csa_ie.chandef.freq1_offset, + csa_ie.chandef.center_freq2); + goto lock_and_drop_connection; + } + + if (cfg80211_chandef_identical(&csa_ie.chandef, + &link->conf->chandef) && + (!csa_ie.mode || !beacon)) { + if (link->u.mgd.csa_ignored_same_chan) + return; + sdata_info(sdata, + "AP %pM tries to chanswitch to same channel, ignore\n", + link->u.mgd.bssid); + link->u.mgd.csa_ignored_same_chan = true; + return; + } + + /* + * Drop all TDLS peers - either we disconnect or move to a different + * channel from this point on. There's no telling what our peer will do. + * The TDLS WIDER_BW scenario is also problematic, as peers might now + * have an incompatible wider chandef. + */ + ieee80211_teardown_tdls_peers(sdata); + + mutex_lock(&local->mtx); + mutex_lock(&local->chanctx_mtx); + conf = rcu_dereference_protected(link->conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (!conf) { + sdata_info(sdata, + "no channel context assigned to vif?, disconnecting\n"); + goto drop_connection; + } + + chanctx = container_of(conf, struct ieee80211_chanctx, conf); + + if (local->use_chanctx && + !ieee80211_hw_check(&local->hw, CHANCTX_STA_CSA)) { + sdata_info(sdata, + "driver doesn't support chan-switch with channel contexts\n"); + goto drop_connection; + } + + if (drv_pre_channel_switch(sdata, &ch_switch)) { + sdata_info(sdata, + "preparing for channel switch failed, disconnecting\n"); + goto drop_connection; + } + + res = ieee80211_link_reserve_chanctx(link, &csa_ie.chandef, + chanctx->mode, false); + if (res) { + sdata_info(sdata, + "failed to reserve channel context for channel switch, disconnecting (err=%d)\n", + res); + goto drop_connection; + } + mutex_unlock(&local->chanctx_mtx); + + link->conf->csa_active = true; + link->csa_chandef = csa_ie.chandef; + link->csa_block_tx = csa_ie.mode; + link->u.mgd.csa_ignored_same_chan = false; + link->u.mgd.beacon_crc_valid = false; + + if (link->csa_block_tx) + ieee80211_stop_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + mutex_unlock(&local->mtx); + + cfg80211_ch_switch_started_notify(sdata->dev, &csa_ie.chandef, 0, + csa_ie.count, csa_ie.mode); + + if (local->ops->channel_switch) { + /* use driver's channel switch callback */ + drv_channel_switch(local, sdata, &ch_switch); + return; + } + + /* channel switch handled in software */ + if (csa_ie.count <= 1) + ieee80211_queue_work(&local->hw, &link->u.mgd.chswitch_work); + else + mod_timer(&link->u.mgd.chswitch_timer, + TU_TO_EXP_TIME((csa_ie.count - 1) * + cbss->beacon_interval)); + return; + lock_and_drop_connection: + mutex_lock(&local->mtx); + mutex_lock(&local->chanctx_mtx); + drop_connection: + /* + * This is just so that the disconnect flow will know that + * we were trying to switch channel and failed. In case the + * mode is 1 (we are not allowed to Tx), we will know not to + * send a deauthentication frame. Those two fields will be + * reset when the disconnection worker runs. 
+ */ + link->conf->csa_active = true; + link->csa_block_tx = csa_ie.mode; + + ieee80211_queue_work(&local->hw, &ifmgd->csa_connection_drop_work); + mutex_unlock(&local->chanctx_mtx); + mutex_unlock(&local->mtx); +} + +static bool +ieee80211_find_80211h_pwr_constr(struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel *channel, + const u8 *country_ie, u8 country_ie_len, + const u8 *pwr_constr_elem, + int *chan_pwr, int *pwr_reduction) +{ + struct ieee80211_country_ie_triplet *triplet; + int chan = ieee80211_frequency_to_channel(channel->center_freq); + int i, chan_increment; + bool have_chan_pwr = false; + + /* Invalid IE */ + if (country_ie_len % 2 || country_ie_len < IEEE80211_COUNTRY_IE_MIN_LEN) + return false; + + triplet = (void *)(country_ie + 3); + country_ie_len -= 3; + + switch (channel->band) { + default: + WARN_ON_ONCE(1); + fallthrough; + case NL80211_BAND_2GHZ: + case NL80211_BAND_60GHZ: + case NL80211_BAND_LC: + chan_increment = 1; + break; + case NL80211_BAND_5GHZ: + chan_increment = 4; + break; + case NL80211_BAND_6GHZ: + /* + * In the 6 GHz band, the "maximum transmit power level" + * field in the triplets is reserved, and thus will be + * zero and we shouldn't use it to control TX power. + * The actual TX power will be given in the transmit + * power envelope element instead. + */ + return false; + } + + /* find channel */ + while (country_ie_len >= 3) { + u8 first_channel = triplet->chans.first_channel; + + if (first_channel >= IEEE80211_COUNTRY_EXTENSION_ID) + goto next; + + for (i = 0; i < triplet->chans.num_channels; i++) { + if (first_channel + i * chan_increment == chan) { + have_chan_pwr = true; + *chan_pwr = triplet->chans.max_power; + break; + } + } + if (have_chan_pwr) + break; + + next: + triplet++; + country_ie_len -= 3; + } + + if (have_chan_pwr && pwr_constr_elem) + *pwr_reduction = *pwr_constr_elem; + else + *pwr_reduction = 0; + + return have_chan_pwr; +} + +static void ieee80211_find_cisco_dtpc(struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel *channel, + const u8 *cisco_dtpc_ie, + int *pwr_level) +{ + /* From practical testing, the first data byte of the DTPC element + * seems to contain the requested dBm level, and the CLI on Cisco + * APs clearly state the range is -127 to 127 dBm, which indicates + * a signed byte, although it seemingly never actually goes negative. + * The other byte seems to always be zero. 
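+ * (The cast below simply reinterprets that byte as two's complement,
+ * e.g. 0x0a requests +10 dBm and 0xf6 would request -10 dBm.)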
+ */ + *pwr_level = (__s8)cisco_dtpc_ie[4]; +} + +static u32 ieee80211_handle_pwr_constr(struct ieee80211_link_data *link, + struct ieee80211_channel *channel, + struct ieee80211_mgmt *mgmt, + const u8 *country_ie, u8 country_ie_len, + const u8 *pwr_constr_ie, + const u8 *cisco_dtpc_ie) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + bool has_80211h_pwr = false, has_cisco_pwr = false; + int chan_pwr = 0, pwr_reduction_80211h = 0; + int pwr_level_cisco, pwr_level_80211h; + int new_ap_level; + __le16 capab = mgmt->u.probe_resp.capab_info; + + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) + return 0; /* TODO */ + + if (country_ie && + (capab & cpu_to_le16(WLAN_CAPABILITY_SPECTRUM_MGMT) || + capab & cpu_to_le16(WLAN_CAPABILITY_RADIO_MEASURE))) { + has_80211h_pwr = ieee80211_find_80211h_pwr_constr( + sdata, channel, country_ie, country_ie_len, + pwr_constr_ie, &chan_pwr, &pwr_reduction_80211h); + pwr_level_80211h = + max_t(int, 0, chan_pwr - pwr_reduction_80211h); + } + + if (cisco_dtpc_ie) { + ieee80211_find_cisco_dtpc( + sdata, channel, cisco_dtpc_ie, &pwr_level_cisco); + has_cisco_pwr = true; + } + + if (!has_80211h_pwr && !has_cisco_pwr) + return 0; + + /* If we have both 802.11h and Cisco DTPC, apply both limits + * by picking the smallest of the two power levels advertised. + */ + if (has_80211h_pwr && + (!has_cisco_pwr || pwr_level_80211h <= pwr_level_cisco)) { + new_ap_level = pwr_level_80211h; + + if (link->ap_power_level == new_ap_level) + return 0; + + sdata_dbg(sdata, + "Limiting TX power to %d (%d - %d) dBm as advertised by %pM\n", + pwr_level_80211h, chan_pwr, pwr_reduction_80211h, + link->u.mgd.bssid); + } else { /* has_cisco_pwr is always true here. */ + new_ap_level = pwr_level_cisco; + + if (link->ap_power_level == new_ap_level) + return 0; + + sdata_dbg(sdata, + "Limiting TX power to %d dBm as advertised by %pM\n", + pwr_level_cisco, link->u.mgd.bssid); + } + + link->ap_power_level = new_ap_level; + if (__ieee80211_recalc_txpower(sdata)) + return BSS_CHANGED_TXPOWER; + return 0; +} + +/* powersave */ +static void ieee80211_enable_ps(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_conf *conf = &local->hw.conf; + + /* + * If we are scanning right now then the parameters will + * take effect when scan finishes. 
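+ * (dynamic_ps_timeout below is in milliseconds; with the fallback of
+ * 100 used by ieee80211_recalc_ps() the timer is armed to fire 100 ms
+ * from now via msecs_to_jiffies().)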
+ */ + if (local->scanning) + return; + + if (conf->dynamic_ps_timeout > 0 && + !ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS)) { + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies(conf->dynamic_ps_timeout)); + } else { + if (ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK)) + ieee80211_send_nullfunc(local, sdata, true); + + if (ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK) && + ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + return; + + conf->flags |= IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + } +} + +static void ieee80211_change_ps(struct ieee80211_local *local) +{ + struct ieee80211_conf *conf = &local->hw.conf; + + if (local->ps_sdata) { + ieee80211_enable_ps(local, local->ps_sdata); + } else if (conf->flags & IEEE80211_CONF_PS) { + conf->flags &= ~IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + del_timer_sync(&local->dynamic_ps_timer); + cancel_work_sync(&local->dynamic_ps_enable_work); + } +} + +static bool ieee80211_powersave_allowed(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *mgd = &sdata->u.mgd; + struct sta_info *sta = NULL; + bool authorized = false; + + if (!mgd->powersave) + return false; + + if (mgd->broken_ap) + return false; + + if (!mgd->associated) + return false; + + if (mgd->flags & IEEE80211_STA_CONNECTION_POLL) + return false; + + if (!(local->hw.wiphy->flags & WIPHY_FLAG_SUPPORTS_MLO) && + !sdata->deflink.u.mgd.have_beacon) + return false; + + rcu_read_lock(); + sta = sta_info_get(sdata, sdata->vif.cfg.ap_addr); + if (sta) + authorized = test_sta_flag(sta, WLAN_STA_AUTHORIZED); + rcu_read_unlock(); + + return authorized; +} + +/* need to hold RTNL or interface lock */ +void ieee80211_recalc_ps(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata, *found = NULL; + int count = 0; + int timeout; + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_PS) || + ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS)) { + local->ps_sdata = NULL; + return; + } + + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + if (sdata->vif.type == NL80211_IFTYPE_AP) { + /* If an AP vif is found, then disable PS + * by setting the count to zero thereby setting + * ps_sdata to NULL. 
+ */ + count = 0; + break; + } + if (sdata->vif.type != NL80211_IFTYPE_STATION) + continue; + found = sdata; + count++; + } + + if (count == 1 && ieee80211_powersave_allowed(found)) { + u8 dtimper = found->deflink.u.mgd.dtim_period; + + timeout = local->dynamic_ps_forced_timeout; + if (timeout < 0) + timeout = 100; + local->hw.conf.dynamic_ps_timeout = timeout; + + /* If the TIM IE is invalid, pretend the value is 1 */ + if (!dtimper) + dtimper = 1; + + local->hw.conf.ps_dtim_period = dtimper; + local->ps_sdata = found; + } else { + local->ps_sdata = NULL; + } + + ieee80211_change_ps(local); +} + +void ieee80211_recalc_ps_vif(struct ieee80211_sub_if_data *sdata) +{ + bool ps_allowed = ieee80211_powersave_allowed(sdata); + + if (sdata->vif.cfg.ps != ps_allowed) { + sdata->vif.cfg.ps = ps_allowed; + ieee80211_vif_cfg_change_notify(sdata, BSS_CHANGED_PS); + } +} + +void ieee80211_dynamic_ps_disable_work(struct work_struct *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, + dynamic_ps_disable_work); + + if (local->hw.conf.flags & IEEE80211_CONF_PS) { + local->hw.conf.flags &= ~IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + } + + ieee80211_wake_queues_by_reason(&local->hw, + IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_PS, + false); +} + +void ieee80211_dynamic_ps_enable_work(struct work_struct *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, + dynamic_ps_enable_work); + struct ieee80211_sub_if_data *sdata = local->ps_sdata; + struct ieee80211_if_managed *ifmgd; + unsigned long flags; + int q; + + /* can only happen when PS was just disabled anyway */ + if (!sdata) + return; + + ifmgd = &sdata->u.mgd; + + if (local->hw.conf.flags & IEEE80211_CONF_PS) + return; + + if (local->hw.conf.dynamic_ps_timeout > 0) { + /* don't enter PS if TX frames are pending */ + if (drv_tx_frames_pending(local)) { + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies( + local->hw.conf.dynamic_ps_timeout)); + return; + } + + /* + * transmission can be stopped by others which leads to + * dynamic_ps_timer expiry. Postpone the ps timer if it + * is not the actual idle state. 
+ */ + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + for (q = 0; q < local->hw.queues; q++) { + if (local->queue_stop_reasons[q]) { + spin_unlock_irqrestore(&local->queue_stop_reason_lock, + flags); + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies( + local->hw.conf.dynamic_ps_timeout)); + return; + } + } + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + } + + if (ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK) && + !(ifmgd->flags & IEEE80211_STA_NULLFUNC_ACKED)) { + if (drv_tx_frames_pending(local)) { + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies( + local->hw.conf.dynamic_ps_timeout)); + } else { + ieee80211_send_nullfunc(local, sdata, true); + /* Flush to get the tx status of nullfunc frame */ + ieee80211_flush_queues(local, sdata, false); + } + } + + if (!(ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS) && + ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK)) || + (ifmgd->flags & IEEE80211_STA_NULLFUNC_ACKED)) { + ifmgd->flags &= ~IEEE80211_STA_NULLFUNC_ACKED; + local->hw.conf.flags |= IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + } +} + +void ieee80211_dynamic_ps_timer(struct timer_list *t) +{ + struct ieee80211_local *local = from_timer(local, t, dynamic_ps_timer); + + ieee80211_queue_work(&local->hw, &local->dynamic_ps_enable_work); +} + +void ieee80211_dfs_cac_timer_work(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct ieee80211_link_data *link = + container_of(delayed_work, struct ieee80211_link_data, + dfs_cac_timer_work); + struct cfg80211_chan_def chandef = link->conf->chandef; + struct ieee80211_sub_if_data *sdata = link->sdata; + + mutex_lock(&sdata->local->mtx); + if (sdata->wdev.cac_started) { + ieee80211_link_release_channel(link); + cfg80211_cac_event(sdata->dev, &chandef, + NL80211_RADAR_CAC_FINISHED, + GFP_KERNEL); + } + mutex_unlock(&sdata->local->mtx); +} + +static bool +__ieee80211_sta_handle_tspec_ac_params(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + bool ret = false; + int ac; + + if (local->hw.queues < IEEE80211_NUM_ACS) + return false; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + struct ieee80211_sta_tx_tspec *tx_tspec = &ifmgd->tx_tspec[ac]; + int non_acm_ac; + unsigned long now = jiffies; + + if (tx_tspec->action == TX_TSPEC_ACTION_NONE && + tx_tspec->admitted_time && + time_after(now, tx_tspec->time_slice_start + HZ)) { + tx_tspec->consumed_tx_time = 0; + tx_tspec->time_slice_start = now; + + if (tx_tspec->downgraded) + tx_tspec->action = + TX_TSPEC_ACTION_STOP_DOWNGRADE; + } + + switch (tx_tspec->action) { + case TX_TSPEC_ACTION_STOP_DOWNGRADE: + /* take the original parameters */ + if (drv_conf_tx(local, &sdata->deflink, ac, + &sdata->deflink.tx_conf[ac])) + link_err(&sdata->deflink, + "failed to set TX queue parameters for queue %d\n", + ac); + tx_tspec->action = TX_TSPEC_ACTION_NONE; + tx_tspec->downgraded = false; + ret = true; + break; + case TX_TSPEC_ACTION_DOWNGRADE: + if (time_after(now, tx_tspec->time_slice_start + HZ)) { + tx_tspec->action = TX_TSPEC_ACTION_NONE; + ret = true; + break; + } + /* downgrade next lower non-ACM AC */ + for (non_acm_ac = ac + 1; + non_acm_ac < IEEE80211_NUM_ACS; + non_acm_ac++) + if (!(sdata->wmm_acm & BIT(7 - 2 * non_acm_ac))) + break; + /* Usually the loop will result in using BK even if it + * requires admission control, but such a configuration + * makes no 
sense and we have to transmit somehow - the + * AC selection does the same thing. + * If we started out trying to downgrade from BK, then + * the extra condition here might be needed. + */ + if (non_acm_ac >= IEEE80211_NUM_ACS) + non_acm_ac = IEEE80211_AC_BK; + if (drv_conf_tx(local, &sdata->deflink, ac, + &sdata->deflink.tx_conf[non_acm_ac])) + link_err(&sdata->deflink, + "failed to set TX queue parameters for queue %d\n", + ac); + tx_tspec->action = TX_TSPEC_ACTION_NONE; + ret = true; + schedule_delayed_work(&ifmgd->tx_tspec_wk, + tx_tspec->time_slice_start + HZ - now + 1); + break; + case TX_TSPEC_ACTION_NONE: + /* nothing now */ + break; + } + } + + return ret; +} + +void ieee80211_sta_handle_tspec_ac_params(struct ieee80211_sub_if_data *sdata) +{ + if (__ieee80211_sta_handle_tspec_ac_params(sdata)) + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_QOS); +} + +static void ieee80211_sta_handle_tspec_ac_params_wk(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata; + + sdata = container_of(work, struct ieee80211_sub_if_data, + u.mgd.tx_tspec_wk.work); + ieee80211_sta_handle_tspec_ac_params(sdata); +} + +void ieee80211_mgd_set_link_qos_params(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_tx_queue_params *params = link->tx_conf; + u8 ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + mlme_dbg(sdata, + "WMM AC=%d acm=%d aifs=%d cWmin=%d cWmax=%d txop=%d uapsd=%d, downgraded=%d\n", + ac, params[ac].acm, + params[ac].aifs, params[ac].cw_min, params[ac].cw_max, + params[ac].txop, params[ac].uapsd, + ifmgd->tx_tspec[ac].downgraded); + if (!ifmgd->tx_tspec[ac].downgraded && + drv_conf_tx(local, link, ac, ¶ms[ac])) + link_err(link, + "failed to set TX queue parameters for AC %d\n", + ac); + } +} + +/* MLME */ +static bool +ieee80211_sta_wmm_params(struct ieee80211_local *local, + struct ieee80211_link_data *link, + const u8 *wmm_param, size_t wmm_param_len, + const struct ieee80211_mu_edca_param_set *mu_edca) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_tx_queue_params params[IEEE80211_NUM_ACS]; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + size_t left; + int count, mu_edca_count, ac; + const u8 *pos; + u8 uapsd_queues = 0; + + if (!local->ops->conf_tx) + return false; + + if (local->hw.queues < IEEE80211_NUM_ACS) + return false; + + if (!wmm_param) + return false; + + if (wmm_param_len < 8 || wmm_param[5] /* version */ != 1) + return false; + + if (ifmgd->flags & IEEE80211_STA_UAPSD_ENABLED) + uapsd_queues = ifmgd->uapsd_queues; + + count = wmm_param[6] & 0x0f; + /* -1 is the initial value of ifmgd->mu_edca_last_param_set. + * if mu_edca was preset before and now it disappeared tell + * the driver about it. + */ + mu_edca_count = mu_edca ? 
mu_edca->mu_qos_info & 0x0f : -1; + if (count == link->u.mgd.wmm_last_param_set && + mu_edca_count == link->u.mgd.mu_edca_last_param_set) + return false; + link->u.mgd.wmm_last_param_set = count; + link->u.mgd.mu_edca_last_param_set = mu_edca_count; + + pos = wmm_param + 8; + left = wmm_param_len - 8; + + memset(¶ms, 0, sizeof(params)); + + sdata->wmm_acm = 0; + for (; left >= 4; left -= 4, pos += 4) { + int aci = (pos[0] >> 5) & 0x03; + int acm = (pos[0] >> 4) & 0x01; + bool uapsd = false; + + switch (aci) { + case 1: /* AC_BK */ + ac = IEEE80211_AC_BK; + if (acm) + sdata->wmm_acm |= BIT(1) | BIT(2); /* BK/- */ + if (uapsd_queues & IEEE80211_WMM_IE_STA_QOSINFO_AC_BK) + uapsd = true; + params[ac].mu_edca = !!mu_edca; + if (mu_edca) + params[ac].mu_edca_param_rec = mu_edca->ac_bk; + break; + case 2: /* AC_VI */ + ac = IEEE80211_AC_VI; + if (acm) + sdata->wmm_acm |= BIT(4) | BIT(5); /* CL/VI */ + if (uapsd_queues & IEEE80211_WMM_IE_STA_QOSINFO_AC_VI) + uapsd = true; + params[ac].mu_edca = !!mu_edca; + if (mu_edca) + params[ac].mu_edca_param_rec = mu_edca->ac_vi; + break; + case 3: /* AC_VO */ + ac = IEEE80211_AC_VO; + if (acm) + sdata->wmm_acm |= BIT(6) | BIT(7); /* VO/NC */ + if (uapsd_queues & IEEE80211_WMM_IE_STA_QOSINFO_AC_VO) + uapsd = true; + params[ac].mu_edca = !!mu_edca; + if (mu_edca) + params[ac].mu_edca_param_rec = mu_edca->ac_vo; + break; + case 0: /* AC_BE */ + default: + ac = IEEE80211_AC_BE; + if (acm) + sdata->wmm_acm |= BIT(0) | BIT(3); /* BE/EE */ + if (uapsd_queues & IEEE80211_WMM_IE_STA_QOSINFO_AC_BE) + uapsd = true; + params[ac].mu_edca = !!mu_edca; + if (mu_edca) + params[ac].mu_edca_param_rec = mu_edca->ac_be; + break; + } + + params[ac].aifs = pos[0] & 0x0f; + + if (params[ac].aifs < 2) { + sdata_info(sdata, + "AP has invalid WMM params (AIFSN=%d for ACI %d), will use 2\n", + params[ac].aifs, aci); + params[ac].aifs = 2; + } + params[ac].cw_max = ecw2cw((pos[1] & 0xf0) >> 4); + params[ac].cw_min = ecw2cw(pos[1] & 0x0f); + params[ac].txop = get_unaligned_le16(pos + 2); + params[ac].acm = acm; + params[ac].uapsd = uapsd; + + if (params[ac].cw_min == 0 || + params[ac].cw_min > params[ac].cw_max) { + sdata_info(sdata, + "AP has invalid WMM params (CWmin/max=%d/%d for ACI %d), using defaults\n", + params[ac].cw_min, params[ac].cw_max, aci); + return false; + } + ieee80211_regulatory_limit_wmm_params(sdata, ¶ms[ac], ac); + } + + /* WMM specification requires all 4 ACIs. 
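+ * (Each AC record parsed above is four bytes -- ACI/AIFSN, ECWmin/ECWmax
+ * and TXOP limit -- with the contention windows expanded as
+ * CW = 2^ECW - 1, so an ECW byte of 0xa4 yields cw_min = 15 and
+ * cw_max = 1023.)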
*/ + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + if (params[ac].cw_min == 0) { + sdata_info(sdata, + "AP has invalid WMM params (missing AC %d), using defaults\n", + ac); + return false; + } + } + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + link->tx_conf[ac] = params[ac]; + + ieee80211_mgd_set_link_qos_params(link); + + /* enable WMM or activate new settings */ + link->conf->qos = true; + return true; +} + +static void __ieee80211_stop_poll(struct ieee80211_sub_if_data *sdata) +{ + lockdep_assert_held(&sdata->local->mtx); + + sdata->u.mgd.flags &= ~IEEE80211_STA_CONNECTION_POLL; + ieee80211_run_deferred_scan(sdata->local); +} + +static void ieee80211_stop_poll(struct ieee80211_sub_if_data *sdata) +{ + mutex_lock(&sdata->local->mtx); + __ieee80211_stop_poll(sdata); + mutex_unlock(&sdata->local->mtx); +} + +static u32 ieee80211_handle_bss_capability(struct ieee80211_link_data *link, + u16 capab, bool erp_valid, u8 erp) +{ + struct ieee80211_bss_conf *bss_conf = link->conf; + struct ieee80211_supported_band *sband; + u32 changed = 0; + bool use_protection; + bool use_short_preamble; + bool use_short_slot; + + sband = ieee80211_get_link_sband(link); + if (!sband) + return changed; + + if (erp_valid) { + use_protection = (erp & WLAN_ERP_USE_PROTECTION) != 0; + use_short_preamble = (erp & WLAN_ERP_BARKER_PREAMBLE) == 0; + } else { + use_protection = false; + use_short_preamble = !!(capab & WLAN_CAPABILITY_SHORT_PREAMBLE); + } + + use_short_slot = !!(capab & WLAN_CAPABILITY_SHORT_SLOT_TIME); + if (sband->band == NL80211_BAND_5GHZ || + sband->band == NL80211_BAND_6GHZ) + use_short_slot = true; + + if (use_protection != bss_conf->use_cts_prot) { + bss_conf->use_cts_prot = use_protection; + changed |= BSS_CHANGED_ERP_CTS_PROT; + } + + if (use_short_preamble != bss_conf->use_short_preamble) { + bss_conf->use_short_preamble = use_short_preamble; + changed |= BSS_CHANGED_ERP_PREAMBLE; + } + + if (use_short_slot != bss_conf->use_short_slot) { + bss_conf->use_short_slot = use_short_slot; + changed |= BSS_CHANGED_ERP_SLOT; + } + + return changed; +} + +static u32 ieee80211_link_set_associated(struct ieee80211_link_data *link, + struct cfg80211_bss *cbss) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_bss_conf *bss_conf = link->conf; + struct ieee80211_bss *bss = (void *)cbss->priv; + u32 changed = BSS_CHANGED_QOS; + + /* not really used in MLO */ + sdata->u.mgd.beacon_timeout = + usecs_to_jiffies(ieee80211_tu_to_usec(beacon_loss_count * + bss_conf->beacon_int)); + + changed |= ieee80211_handle_bss_capability(link, + bss_conf->assoc_capability, + bss->has_erp_value, + bss->erp_value); + + ieee80211_check_rate_mask(link); + + link->u.mgd.bss = cbss; + memcpy(link->u.mgd.bssid, cbss->bssid, ETH_ALEN); + + if (sdata->vif.p2p || + sdata->vif.driver_flags & IEEE80211_VIF_GET_NOA_UPDATE) { + const struct cfg80211_bss_ies *ies; + + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + if (ies) { + int ret; + + ret = cfg80211_get_p2p_attr( + ies->data, ies->len, + IEEE80211_P2P_ATTR_ABSENCE_NOTICE, + (u8 *) &bss_conf->p2p_noa_attr, + sizeof(bss_conf->p2p_noa_attr)); + if (ret >= 2) { + link->u.mgd.p2p_noa_index = + bss_conf->p2p_noa_attr.index; + changed |= BSS_CHANGED_P2P_PS; + } + } + rcu_read_unlock(); + } + + if (link->u.mgd.have_beacon) { + /* + * If the AP is buggy we may get here with no DTIM period + * known, so assume it's 1 which is the only safe assumption + * in that case, although if the TIM IE is broken powersave + * probably just won't work at all. 
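ieee80211_handle_bss_capability() above reduces the ERP element and the capability field to three booleans: CTS protection and preamble come from the ERP element when one was present (short preamble meaning the Barker-preamble bit is clear), otherwise from the capability bit, and short slot time is forced on outside 2.4 GHz. A stand-alone sketch of that decision, with bit values as defined in linux/ieee80211.h (illustration only, not part of the patch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define WLAN_ERP_USE_PROTECTION         (1 << 1)
#define WLAN_ERP_BARKER_PREAMBLE        (1 << 2)
#define WLAN_CAPABILITY_SHORT_PREAMBLE  (1 << 5)
#define WLAN_CAPABILITY_SHORT_SLOT_TIME (1 << 10)

static void eval_erp(uint16_t capab, bool erp_valid, uint8_t erp,
                     bool band_5_or_6_ghz,
                     bool *prot, bool *short_pre, bool *short_slot)
{
        if (erp_valid) {
                /* the ERP element wins when the AP sent one */
                *prot = erp & WLAN_ERP_USE_PROTECTION;
                *short_pre = !(erp & WLAN_ERP_BARKER_PREAMBLE);
        } else {
                *prot = false;
                *short_pre = capab & WLAN_CAPABILITY_SHORT_PREAMBLE;
        }
        /* short slot time is implied outside 2.4 GHz */
        *short_slot = band_5_or_6_ghz ||
                      (capab & WLAN_CAPABILITY_SHORT_SLOT_TIME);
}

int main(void)
{
        bool prot, pre, slot;

        /* 2.4 GHz AP requesting protection and long (Barker) preambles */
        eval_erp(0x0431, true,
                 WLAN_ERP_USE_PROTECTION | WLAN_ERP_BARKER_PREAMBLE,
                 false, &prot, &pre, &slot);
        printf("cts_prot=%d short_preamble=%d short_slot=%d\n",
               prot, pre, slot);
        return 0;
}

Each boolean is then compared against the current bss_conf value so that only real changes raise the corresponding BSS_CHANGED_ERP_* bit.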
+ */ + bss_conf->dtim_period = link->u.mgd.dtim_period ?: 1; + bss_conf->beacon_rate = bss->beacon_rate; + changed |= BSS_CHANGED_BEACON_INFO; + } else { + bss_conf->beacon_rate = NULL; + bss_conf->dtim_period = 0; + } + + /* Tell the driver to monitor connection quality (if supported) */ + if (sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_CQM_RSSI && + bss_conf->cqm_rssi_thold) + changed |= BSS_CHANGED_CQM; + + return changed; +} + +static void ieee80211_set_associated(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgd_assoc_data *assoc_data, + u64 changed[IEEE80211_MLD_MAX_NUM_LINKS]) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_vif_cfg *vif_cfg = &sdata->vif.cfg; + u64 vif_changed = BSS_CHANGED_ASSOC; + unsigned int link_id; + + sdata->u.mgd.associated = true; + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + struct ieee80211_link_data *link; + + if (!cbss) + continue; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (WARN_ON(!link)) + return; + + changed[link_id] |= ieee80211_link_set_associated(link, cbss); + } + + /* just to be sure */ + ieee80211_stop_poll(sdata); + + ieee80211_led_assoc(local, 1); + + vif_cfg->assoc = 1; + + /* Enable ARP filtering */ + if (vif_cfg->arp_addr_cnt) + vif_changed |= BSS_CHANGED_ARP_FILTER; + + if (sdata->vif.valid_links) { + for (link_id = 0; + link_id < IEEE80211_MLD_MAX_NUM_LINKS; + link_id++) { + struct ieee80211_link_data *link; + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + + if (!cbss) + continue; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (WARN_ON(!link)) + return; + + ieee80211_link_info_change_notify(sdata, link, + changed[link_id]); + + ieee80211_recalc_smps(sdata, link); + } + + ieee80211_vif_cfg_change_notify(sdata, vif_changed); + } else { + ieee80211_bss_info_change_notify(sdata, + vif_changed | changed[0]); + } + + mutex_lock(&local->iflist_mtx); + ieee80211_recalc_ps(local); + mutex_unlock(&local->iflist_mtx); + + /* leave this here to not change ordering in non-MLO cases */ + if (!sdata->vif.valid_links) + ieee80211_recalc_smps(sdata, &sdata->deflink); + ieee80211_recalc_ps_vif(sdata); + + netif_carrier_on(sdata->dev); +} + +static void ieee80211_set_disassoc(struct ieee80211_sub_if_data *sdata, + u16 stype, u16 reason, bool tx, + u8 *frame_buf) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_local *local = sdata->local; + unsigned int link_id; + u32 changed = 0; + struct ieee80211_prep_tx_info info = { + .subtype = stype, + }; + + sdata_assert_lock(sdata); + + if (WARN_ON_ONCE(tx && !frame_buf)) + return; + + if (WARN_ON(!ifmgd->associated)) + return; + + ieee80211_stop_poll(sdata); + + ifmgd->associated = false; + + /* other links will be destroyed */ + sdata->deflink.u.mgd.bss = NULL; + + netif_carrier_off(sdata->dev); + + /* + * if we want to get out of ps before disassoc (why?) we have + * to do it before sending disassoc, as otherwise the null-packet + * won't be valid. + */ + if (local->hw.conf.flags & IEEE80211_CONF_PS) { + local->hw.conf.flags &= ~IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + } + local->ps_sdata = NULL; + + /* disable per-vif ps */ + ieee80211_recalc_ps_vif(sdata); + + /* make sure ongoing transmission finishes */ + synchronize_net(); + + /* + * drop any frame before deauth/disassoc, this can be data or + * management frame. 
Since we are disconnecting, we should not + * insist sending these frames which can take time and delay + * the disconnection and possible the roaming. + */ + if (tx) + ieee80211_flush_queues(local, sdata, true); + + /* deauthenticate/disassociate now */ + if (tx || frame_buf) { + /* + * In multi channel scenarios guarantee that the virtual + * interface is granted immediate airtime to transmit the + * deauthentication frame by calling mgd_prepare_tx, if the + * driver requested so. + */ + if (ieee80211_hw_check(&local->hw, DEAUTH_NEED_MGD_TX_PREP) && + !sdata->deflink.u.mgd.have_beacon) { + drv_mgd_prepare_tx(sdata->local, sdata, &info); + } + + ieee80211_send_deauth_disassoc(sdata, sdata->vif.cfg.ap_addr, + sdata->vif.cfg.ap_addr, stype, + reason, tx, frame_buf); + } + + /* flush out frame - make sure the deauth was actually sent */ + if (tx) + ieee80211_flush_queues(local, sdata, false); + + drv_mgd_complete_tx(sdata->local, sdata, &info); + + /* clear AP addr only after building the needed mgmt frames */ + eth_zero_addr(sdata->deflink.u.mgd.bssid); + eth_zero_addr(sdata->vif.cfg.ap_addr); + + sdata->vif.cfg.ssid_len = 0; + + /* remove AP and TDLS peers */ + sta_info_flush(sdata); + + /* finally reset all BSS / config parameters */ + if (!sdata->vif.valid_links) + changed |= ieee80211_reset_erp_info(sdata); + + ieee80211_led_assoc(local, 0); + changed |= BSS_CHANGED_ASSOC; + sdata->vif.cfg.assoc = false; + + sdata->deflink.u.mgd.p2p_noa_index = -1; + memset(&sdata->vif.bss_conf.p2p_noa_attr, 0, + sizeof(sdata->vif.bss_conf.p2p_noa_attr)); + + /* on the next assoc, re-program HT/VHT parameters */ + memset(&ifmgd->ht_capa, 0, sizeof(ifmgd->ht_capa)); + memset(&ifmgd->ht_capa_mask, 0, sizeof(ifmgd->ht_capa_mask)); + memset(&ifmgd->vht_capa, 0, sizeof(ifmgd->vht_capa)); + memset(&ifmgd->vht_capa_mask, 0, sizeof(ifmgd->vht_capa_mask)); + + /* + * reset MU-MIMO ownership and group data in default link, + * if used, other links are destroyed + */ + memset(sdata->vif.bss_conf.mu_group.membership, 0, + sizeof(sdata->vif.bss_conf.mu_group.membership)); + memset(sdata->vif.bss_conf.mu_group.position, 0, + sizeof(sdata->vif.bss_conf.mu_group.position)); + if (!sdata->vif.valid_links) + changed |= BSS_CHANGED_MU_GROUPS; + sdata->vif.bss_conf.mu_mimo_owner = false; + + sdata->deflink.ap_power_level = IEEE80211_UNSET_POWER_LEVEL; + + del_timer_sync(&local->dynamic_ps_timer); + cancel_work_sync(&local->dynamic_ps_enable_work); + + /* Disable ARP filtering */ + if (sdata->vif.cfg.arp_addr_cnt) + changed |= BSS_CHANGED_ARP_FILTER; + + sdata->vif.bss_conf.qos = false; + if (!sdata->vif.valid_links) { + changed |= BSS_CHANGED_QOS; + /* The BSSID (not really interesting) and HT changed */ + changed |= BSS_CHANGED_BSSID | BSS_CHANGED_HT; + ieee80211_bss_info_change_notify(sdata, changed); + } else { + ieee80211_vif_cfg_change_notify(sdata, changed); + } + + /* disassociated - set to defaults now */ + ieee80211_set_wmm_default(&sdata->deflink, false, false); + + del_timer_sync(&sdata->u.mgd.conn_mon_timer); + del_timer_sync(&sdata->u.mgd.bcn_mon_timer); + del_timer_sync(&sdata->u.mgd.timer); + del_timer_sync(&sdata->deflink.u.mgd.chswitch_timer); + + sdata->vif.bss_conf.dtim_period = 0; + sdata->vif.bss_conf.beacon_rate = NULL; + + sdata->deflink.u.mgd.have_beacon = false; + sdata->deflink.u.mgd.tracking_signal_avg = false; + sdata->deflink.u.mgd.disable_wmm_tracking = false; + + ifmgd->flags = 0; + sdata->deflink.u.mgd.conn_flags = 0; + mutex_lock(&local->mtx); + + for (link_id = 0; link_id < 
ARRAY_SIZE(sdata->link); link_id++) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (!link) + continue; + ieee80211_link_release_channel(link); + } + + sdata->vif.bss_conf.csa_active = false; + sdata->deflink.u.mgd.csa_waiting_bcn = false; + sdata->deflink.u.mgd.csa_ignored_same_chan = false; + if (sdata->deflink.csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + sdata->deflink.csa_block_tx = false; + } + mutex_unlock(&local->mtx); + + /* existing TX TSPEC sessions no longer exist */ + memset(ifmgd->tx_tspec, 0, sizeof(ifmgd->tx_tspec)); + cancel_delayed_work_sync(&ifmgd->tx_tspec_wk); + + sdata->vif.bss_conf.pwr_reduction = 0; + sdata->vif.bss_conf.tx_pwr_env_num = 0; + memset(sdata->vif.bss_conf.tx_pwr_env, 0, + sizeof(sdata->vif.bss_conf.tx_pwr_env)); + + ieee80211_vif_set_links(sdata, 0); +} + +static void ieee80211_reset_ap_probe(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_local *local = sdata->local; + + mutex_lock(&local->mtx); + if (!(ifmgd->flags & IEEE80211_STA_CONNECTION_POLL)) + goto out; + + __ieee80211_stop_poll(sdata); + + mutex_lock(&local->iflist_mtx); + ieee80211_recalc_ps(local); + mutex_unlock(&local->iflist_mtx); + + if (ieee80211_hw_check(&sdata->local->hw, CONNECTION_MONITOR)) + goto out; + + /* + * We've received a probe response, but are not sure whether + * we have or will be receiving any beacons or data, so let's + * schedule the timers again, just in case. + */ + ieee80211_sta_reset_beacon_monitor(sdata); + + mod_timer(&ifmgd->conn_mon_timer, + round_jiffies_up(jiffies + + IEEE80211_CONNECTION_IDLE_TIME)); +out: + mutex_unlock(&local->mtx); +} + +static void ieee80211_sta_tx_wmm_ac_notify(struct ieee80211_sub_if_data *sdata, + struct ieee80211_hdr *hdr, + u16 tx_time) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u16 tid; + int ac; + struct ieee80211_sta_tx_tspec *tx_tspec; + unsigned long now = jiffies; + + if (!ieee80211_is_data_qos(hdr->frame_control)) + return; + + tid = ieee80211_get_tid(hdr); + ac = ieee80211_ac_from_tid(tid); + tx_tspec = &ifmgd->tx_tspec[ac]; + + if (likely(!tx_tspec->admitted_time)) + return; + + if (time_after(now, tx_tspec->time_slice_start + HZ)) { + tx_tspec->consumed_tx_time = 0; + tx_tspec->time_slice_start = now; + + if (tx_tspec->downgraded) { + tx_tspec->action = TX_TSPEC_ACTION_STOP_DOWNGRADE; + schedule_delayed_work(&ifmgd->tx_tspec_wk, 0); + } + } + + if (tx_tspec->downgraded) + return; + + tx_tspec->consumed_tx_time += tx_time; + + if (tx_tspec->consumed_tx_time >= tx_tspec->admitted_time) { + tx_tspec->downgraded = true; + tx_tspec->action = TX_TSPEC_ACTION_DOWNGRADE; + schedule_delayed_work(&ifmgd->tx_tspec_wk, 0); + } +} + +void ieee80211_sta_tx_notify(struct ieee80211_sub_if_data *sdata, + struct ieee80211_hdr *hdr, bool ack, u16 tx_time) +{ + ieee80211_sta_tx_wmm_ac_notify(sdata, hdr, tx_time); + + if (!ieee80211_is_any_nullfunc(hdr->frame_control) || + !sdata->u.mgd.probe_send_count) + return; + + if (ack) + sdata->u.mgd.probe_send_count = 0; + else + sdata->u.mgd.nullfunc_failed = true; + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +static void ieee80211_mlme_send_probe_req(struct ieee80211_sub_if_data *sdata, + const u8 *src, const u8 *dst, + const u8 *ssid, size_t ssid_len, + struct ieee80211_channel *channel) +{ + struct sk_buff *skb; + + skb = ieee80211_build_probe_req(sdata, src, dst, (u32)-1, channel, + ssid, 
ssid_len, NULL, 0, + IEEE80211_PROBE_FLAG_DIRECTED); + if (skb) + ieee80211_tx_skb(sdata, skb); +} + +static void ieee80211_mgd_probe_ap_send(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 *dst = sdata->vif.cfg.ap_addr; + u8 unicast_limit = max(1, max_probe_tries - 3); + struct sta_info *sta; + + if (WARN_ON(sdata->vif.valid_links)) + return; + + /* + * Try sending broadcast probe requests for the last three + * probe requests after the first ones failed since some + * buggy APs only support broadcast probe requests. + */ + if (ifmgd->probe_send_count >= unicast_limit) + dst = NULL; + + /* + * When the hardware reports an accurate Tx ACK status, it's + * better to send a nullfunc frame instead of a probe request, + * as it will kick us off the AP quickly if we aren't associated + * anymore. The timeout will be reset if the frame is ACKed by + * the AP. + */ + ifmgd->probe_send_count++; + + if (dst) { + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get(sdata, dst); + if (!WARN_ON(!sta)) + ieee80211_check_fast_rx(sta); + mutex_unlock(&sdata->local->sta_mtx); + } + + if (ieee80211_hw_check(&sdata->local->hw, REPORTS_TX_ACK_STATUS)) { + ifmgd->nullfunc_failed = false; + ieee80211_send_nullfunc(sdata->local, sdata, false); + } else { + ieee80211_mlme_send_probe_req(sdata, sdata->vif.addr, dst, + sdata->vif.cfg.ssid, + sdata->vif.cfg.ssid_len, + sdata->deflink.u.mgd.bss->channel); + } + + ifmgd->probe_timeout = jiffies + msecs_to_jiffies(probe_wait_ms); + run_again(sdata, ifmgd->probe_timeout); +} + +static void ieee80211_mgd_probe_ap(struct ieee80211_sub_if_data *sdata, + bool beacon) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + bool already = false; + + if (WARN_ON(sdata->vif.valid_links)) + return; + + if (!ieee80211_sdata_running(sdata)) + return; + + sdata_lock(sdata); + + if (!ifmgd->associated) + goto out; + + mutex_lock(&sdata->local->mtx); + + if (sdata->local->tmp_channel || sdata->local->scanning) { + mutex_unlock(&sdata->local->mtx); + goto out; + } + + if (sdata->local->suspending) { + /* reschedule after resume */ + mutex_unlock(&sdata->local->mtx); + ieee80211_reset_ap_probe(sdata); + goto out; + } + + if (beacon) { + mlme_dbg_ratelimited(sdata, + "detected beacon loss from AP (missed %d beacons) - probing\n", + beacon_loss_count); + + ieee80211_cqm_beacon_loss_notify(&sdata->vif, GFP_KERNEL); + } + + /* + * The driver/our work has already reported this event or the + * connection monitoring has kicked in and we have already sent + * a probe request. Or maybe the AP died and the driver keeps + * reporting until we disassociate... + * + * In either case we have to ignore the current call to this + * function (except for setting the correct probe reason bit) + * because otherwise we would reset the timer every time and + * never check whether we received a probe response! 
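For admitted-time (TSPEC) accounting, ieee80211_sta_tx_wmm_ac_notify() above charges every QoS-data transmission against the AC's admitted airtime within a one-second slice and flags the AC for downgrade once that budget is spent; the kernel lifts the downgrade from the tx_tspec_wk worker, which this simplified sketch folds into the slice rollover (illustration only, milliseconds instead of jiffies, not part of the patch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct tx_tspec {
        uint32_t admitted_time;         /* allowed airtime per 1 s slice, us */
        uint32_t consumed_tx_time;      /* airtime used in the current slice, us */
        uint64_t time_slice_start;      /* start of the current slice, ms */
        bool downgraded;
};

/* charge tx_time us of airtime at time now_ms; returns true when the AC
 * newly has to be downgraded to a non-ACM AC */
static bool tspec_account(struct tx_tspec *ts, uint64_t now_ms, uint32_t tx_time)
{
        if (!ts->admitted_time)                 /* no TSPEC on this AC */
                return false;

        if (now_ms - ts->time_slice_start >= 1000) {
                /* new slice: budget refilled, downgrade can be lifted */
                ts->consumed_tx_time = 0;
                ts->time_slice_start = now_ms;
                ts->downgraded = false;
        }

        if (ts->downgraded)
                return false;

        ts->consumed_tx_time += tx_time;
        if (ts->consumed_tx_time >= ts->admitted_time) {
                ts->downgraded = true;
                return true;
        }
        return false;
}

int main(void)
{
        struct tx_tspec ts = { .admitted_time = 4000 }; /* 4 ms per second */
        uint64_t t;

        for (t = 0; t <= 2500; t += 500)                /* 2.5 ms airtime each */
                printf("t=%4ums downgrade=%d\n", (unsigned int)t,
                       tspec_account(&ts, t, 2500));
        return 0;
}

The real code additionally switches the downgraded AC's queue parameters over to the next non-ACM AC, as done in the TSPEC handling at the top of this section.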
+ */ + if (ifmgd->flags & IEEE80211_STA_CONNECTION_POLL) + already = true; + + ifmgd->flags |= IEEE80211_STA_CONNECTION_POLL; + + mutex_unlock(&sdata->local->mtx); + + if (already) + goto out; + + mutex_lock(&sdata->local->iflist_mtx); + ieee80211_recalc_ps(sdata->local); + mutex_unlock(&sdata->local->iflist_mtx); + + ifmgd->probe_send_count = 0; + ieee80211_mgd_probe_ap_send(sdata); + out: + sdata_unlock(sdata); +} + +struct sk_buff *ieee80211_ap_probereq_get(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct cfg80211_bss *cbss; + struct sk_buff *skb; + const struct element *ssid; + int ssid_len; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_STATION || + sdata->vif.valid_links)) + return NULL; + + sdata_assert_lock(sdata); + + if (ifmgd->associated) + cbss = sdata->deflink.u.mgd.bss; + else if (ifmgd->auth_data) + cbss = ifmgd->auth_data->bss; + else if (ifmgd->assoc_data && ifmgd->assoc_data->link[0].bss) + cbss = ifmgd->assoc_data->link[0].bss; + else + return NULL; + + rcu_read_lock(); + ssid = ieee80211_bss_get_elem(cbss, WLAN_EID_SSID); + if (WARN_ONCE(!ssid || ssid->datalen > IEEE80211_MAX_SSID_LEN, + "invalid SSID element (len=%d)", + ssid ? ssid->datalen : -1)) + ssid_len = 0; + else + ssid_len = ssid->datalen; + + skb = ieee80211_build_probe_req(sdata, sdata->vif.addr, cbss->bssid, + (u32) -1, cbss->channel, + ssid->data, ssid_len, + NULL, 0, IEEE80211_PROBE_FLAG_DIRECTED); + rcu_read_unlock(); + + return skb; +} +EXPORT_SYMBOL(ieee80211_ap_probereq_get); + +static void ieee80211_report_disconnect(struct ieee80211_sub_if_data *sdata, + const u8 *buf, size_t len, bool tx, + u16 reason, bool reconnect) +{ + struct ieee80211_event event = { + .type = MLME_EVENT, + .u.mlme.data = tx ? DEAUTH_TX_EVENT : DEAUTH_RX_EVENT, + .u.mlme.reason = reason, + }; + + if (tx) + cfg80211_tx_mlme_mgmt(sdata->dev, buf, len, reconnect); + else + cfg80211_rx_mlme_mgmt(sdata->dev, buf, len); + + drv_event_callback(sdata->local, sdata, &event); +} + +static void __ieee80211_disconnect(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + bool tx; + + sdata_lock(sdata); + if (!ifmgd->associated) { + sdata_unlock(sdata); + return; + } + + /* in MLO assume we have a link where we can TX the frame */ + tx = sdata->vif.valid_links || !sdata->deflink.csa_block_tx; + + if (!ifmgd->driver_disconnect) { + unsigned int link_id; + + /* + * AP is probably out of range (or not reachable for another + * reason) so remove the bss structs for that AP. In the case + * of multi-link, it's not clear that all of them really are + * out of range, but if they weren't the driver likely would + * have switched to just have a single link active? + */ + for (link_id = 0; + link_id < ARRAY_SIZE(sdata->link); + link_id++) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (!link) + continue; + cfg80211_unlink_bss(local->hw.wiphy, link->u.mgd.bss); + link->u.mgd.bss = NULL; + } + } + + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, + ifmgd->driver_disconnect ? 
+ WLAN_REASON_DEAUTH_LEAVING : + WLAN_REASON_DISASSOC_DUE_TO_INACTIVITY, + tx, frame_buf); + mutex_lock(&local->mtx); + /* the other links will be destroyed */ + sdata->vif.bss_conf.csa_active = false; + sdata->deflink.u.mgd.csa_waiting_bcn = false; + if (sdata->deflink.csa_block_tx) { + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_CSA); + sdata->deflink.csa_block_tx = false; + } + mutex_unlock(&local->mtx); + + ieee80211_report_disconnect(sdata, frame_buf, sizeof(frame_buf), tx, + WLAN_REASON_DISASSOC_DUE_TO_INACTIVITY, + ifmgd->reconnect); + ifmgd->reconnect = false; + + sdata_unlock(sdata); +} + +static void ieee80211_beacon_connection_loss_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + u.mgd.beacon_connection_loss_work); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + if (ifmgd->connection_loss) { + sdata_info(sdata, "Connection to AP %pM lost\n", + sdata->vif.cfg.ap_addr); + __ieee80211_disconnect(sdata); + ifmgd->connection_loss = false; + } else if (ifmgd->driver_disconnect) { + sdata_info(sdata, + "Driver requested disconnection from AP %pM\n", + sdata->vif.cfg.ap_addr); + __ieee80211_disconnect(sdata); + ifmgd->driver_disconnect = false; + } else { + if (ifmgd->associated) + sdata->deflink.u.mgd.beacon_loss_count++; + ieee80211_mgd_probe_ap(sdata, true); + } +} + +static void ieee80211_csa_connection_drop_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + u.mgd.csa_connection_drop_work); + + __ieee80211_disconnect(sdata); +} + +void ieee80211_beacon_loss(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_hw *hw = &sdata->local->hw; + + trace_api_beacon_loss(sdata); + + sdata->u.mgd.connection_loss = false; + ieee80211_queue_work(hw, &sdata->u.mgd.beacon_connection_loss_work); +} +EXPORT_SYMBOL(ieee80211_beacon_loss); + +void ieee80211_connection_loss(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_hw *hw = &sdata->local->hw; + + trace_api_connection_loss(sdata); + + sdata->u.mgd.connection_loss = true; + ieee80211_queue_work(hw, &sdata->u.mgd.beacon_connection_loss_work); +} +EXPORT_SYMBOL(ieee80211_connection_loss); + +void ieee80211_disconnect(struct ieee80211_vif *vif, bool reconnect) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_hw *hw = &sdata->local->hw; + + trace_api_disconnect(sdata, reconnect); + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_STATION)) + return; + + sdata->u.mgd.driver_disconnect = true; + sdata->u.mgd.reconnect = reconnect; + ieee80211_queue_work(hw, &sdata->u.mgd.beacon_connection_loss_work); +} +EXPORT_SYMBOL(ieee80211_disconnect); + +static void ieee80211_destroy_auth_data(struct ieee80211_sub_if_data *sdata, + bool assoc) +{ + struct ieee80211_mgd_auth_data *auth_data = sdata->u.mgd.auth_data; + + sdata_assert_lock(sdata); + + if (!assoc) { + /* + * we are not authenticated yet, the only timer that could be + * running is the timeout for the authentication response which + * which is not relevant anymore. 
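ieee80211_beacon_loss() and ieee80211_connection_loss() above are the two driver-facing entry points into this machinery: the former makes mac80211 probe the AP first, the latter tears the connection down immediately, and ieee80211_disconnect() additionally lets the driver ask for a re-association attempt. A minimal sketch of how a driver might use the first two from its own firmware-event handler (hypothetical function and parameter names, illustration only, not part of the patch):

#include <net/mac80211.h>

/* "link_dead" and the function name are made up for this sketch */
static void example_fw_connection_event(struct ieee80211_vif *vif,
                                        bool link_dead)
{
        if (link_dead)
                /* firmware already gave up on the connection */
                ieee80211_connection_loss(vif);
        else
                /* beacons were missed: let mac80211 probe the AP first */
                ieee80211_beacon_loss(vif);
}

Both helpers only queue beacon_connection_loss_work; the actual probing or disconnect then happens in the work item shown above.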
+ */ + del_timer_sync(&sdata->u.mgd.timer); + sta_info_destroy_addr(sdata, auth_data->ap_addr); + + /* other links are destroyed */ + sdata->deflink.u.mgd.conn_flags = 0; + eth_zero_addr(sdata->deflink.u.mgd.bssid); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_BSSID); + sdata->u.mgd.flags = 0; + + mutex_lock(&sdata->local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + ieee80211_vif_set_links(sdata, 0); + mutex_unlock(&sdata->local->mtx); + } + + cfg80211_put_bss(sdata->local->hw.wiphy, auth_data->bss); + kfree(auth_data); + sdata->u.mgd.auth_data = NULL; +} + +enum assoc_status { + ASSOC_SUCCESS, + ASSOC_REJECTED, + ASSOC_TIMEOUT, + ASSOC_ABANDON, +}; + +static void ieee80211_destroy_assoc_data(struct ieee80211_sub_if_data *sdata, + enum assoc_status status) +{ + struct ieee80211_mgd_assoc_data *assoc_data = sdata->u.mgd.assoc_data; + + sdata_assert_lock(sdata); + + if (status != ASSOC_SUCCESS) { + /* + * we are not associated yet, the only timer that could be + * running is the timeout for the association response which + * which is not relevant anymore. + */ + del_timer_sync(&sdata->u.mgd.timer); + sta_info_destroy_addr(sdata, assoc_data->ap_addr); + + sdata->deflink.u.mgd.conn_flags = 0; + eth_zero_addr(sdata->deflink.u.mgd.bssid); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_BSSID); + sdata->u.mgd.flags = 0; + sdata->vif.bss_conf.mu_mimo_owner = false; + + if (status != ASSOC_REJECTED) { + struct cfg80211_assoc_failure data = { + .timeout = status == ASSOC_TIMEOUT, + }; + int i; + + BUILD_BUG_ON(ARRAY_SIZE(data.bss) != + ARRAY_SIZE(assoc_data->link)); + + for (i = 0; i < ARRAY_SIZE(data.bss); i++) + data.bss[i] = assoc_data->link[i].bss; + + if (sdata->vif.valid_links) + data.ap_mld_addr = assoc_data->ap_addr; + + cfg80211_assoc_failure(sdata->dev, &data); + } + + mutex_lock(&sdata->local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + ieee80211_vif_set_links(sdata, 0); + mutex_unlock(&sdata->local->mtx); + } + + kfree(assoc_data); + sdata->u.mgd.assoc_data = NULL; +} + +static void ieee80211_auth_challenge(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_mgd_auth_data *auth_data = sdata->u.mgd.auth_data; + const struct element *challenge; + u8 *pos; + u32 tx_flags = 0; + struct ieee80211_prep_tx_info info = { + .subtype = IEEE80211_STYPE_AUTH, + }; + + pos = mgmt->u.auth.variable; + challenge = cfg80211_find_elem(WLAN_EID_CHALLENGE, pos, + len - (pos - (u8 *)mgmt)); + if (!challenge) + return; + auth_data->expected_transaction = 4; + drv_mgd_prepare_tx(sdata->local, sdata, &info); + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + tx_flags = IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_INTFL_MLME_CONN_TX; + ieee80211_send_auth(sdata, 3, auth_data->algorithm, 0, + (void *)challenge, + challenge->datalen + sizeof(*challenge), + auth_data->ap_addr, auth_data->ap_addr, + auth_data->key, auth_data->key_len, + auth_data->key_idx, tx_flags); +} + +static bool ieee80211_mark_sta_auth(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + const u8 *ap_addr = ifmgd->auth_data->ap_addr; + struct sta_info *sta; + bool result = true; + + sdata_info(sdata, "authenticated\n"); + ifmgd->auth_data->done = true; + ifmgd->auth_data->timeout = jiffies + IEEE80211_AUTH_WAIT_ASSOC; + ifmgd->auth_data->timeout_started = true; + run_again(sdata, ifmgd->auth_data->timeout); 
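ieee80211_rx_mgmt_auth() below accepts an authentication frame only if the algorithm matches and the transaction number is the expected one; SAE may legitimately repeat transactions up to 2, while shared-key authentication steps through 2 and then 4 after the challenge. A stand-alone sketch of that acceptance test (illustration only, algorithm numbers as in linux/ieee80211.h, not part of the patch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define WLAN_AUTH_OPEN          0
#define WLAN_AUTH_SHARED_KEY    1
#define WLAN_AUTH_SAE           3

static bool auth_frame_expected(uint16_t alg, uint16_t trans,
                                uint16_t expected_alg, uint16_t expected_trans)
{
        if (alg != expected_alg)
                return false;
        if (alg == WLAN_AUTH_SAE)
                /* SAE commit/confirm may be retransmitted: accept expected..2 */
                return trans >= expected_trans && trans <= 2;
        /* open/shared key/FT/FILS: the transaction must match exactly */
        return trans == expected_trans;
}

int main(void)
{
        /* shared key: transaction 2 first, then 4 once the challenge was sent */
        printf("%d\n", auth_frame_expected(WLAN_AUTH_SHARED_KEY, 2,
                                           WLAN_AUTH_SHARED_KEY, 2));
        printf("%d\n", auth_frame_expected(WLAN_AUTH_SHARED_KEY, 4,
                                           WLAN_AUTH_SHARED_KEY, 2));
        printf("%d\n", auth_frame_expected(WLAN_AUTH_SAE, 2,
                                           WLAN_AUTH_SAE, 1));
        return 0;
}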
+ + /* move station state to auth */ + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get(sdata, ap_addr); + if (!sta) { + WARN_ONCE(1, "%s: STA %pM not found", sdata->name, ap_addr); + result = false; + goto out; + } + if (sta_info_move_state(sta, IEEE80211_STA_AUTH)) { + sdata_info(sdata, "failed moving %pM to auth\n", ap_addr); + result = false; + goto out; + } + +out: + mutex_unlock(&sdata->local->sta_mtx); + return result; +} + +static void ieee80211_rx_mgmt_auth(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u16 auth_alg, auth_transaction, status_code; + struct ieee80211_event event = { + .type = MLME_EVENT, + .u.mlme.data = AUTH_EVENT, + }; + struct ieee80211_prep_tx_info info = { + .subtype = IEEE80211_STYPE_AUTH, + }; + + sdata_assert_lock(sdata); + + if (len < 24 + 6) + return; + + if (!ifmgd->auth_data || ifmgd->auth_data->done) + return; + + if (!ether_addr_equal(ifmgd->auth_data->ap_addr, mgmt->bssid)) + return; + + auth_alg = le16_to_cpu(mgmt->u.auth.auth_alg); + auth_transaction = le16_to_cpu(mgmt->u.auth.auth_transaction); + status_code = le16_to_cpu(mgmt->u.auth.status_code); + + if (auth_alg != ifmgd->auth_data->algorithm || + (auth_alg != WLAN_AUTH_SAE && + auth_transaction != ifmgd->auth_data->expected_transaction) || + (auth_alg == WLAN_AUTH_SAE && + (auth_transaction < ifmgd->auth_data->expected_transaction || + auth_transaction > 2))) { + sdata_info(sdata, "%pM unexpected authentication state: alg %d (expected %d) transact %d (expected %d)\n", + mgmt->sa, auth_alg, ifmgd->auth_data->algorithm, + auth_transaction, + ifmgd->auth_data->expected_transaction); + goto notify_driver; + } + + if (status_code != WLAN_STATUS_SUCCESS) { + cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len); + + if (auth_alg == WLAN_AUTH_SAE && + (status_code == WLAN_STATUS_ANTI_CLOG_REQUIRED || + (auth_transaction == 1 && + (status_code == WLAN_STATUS_SAE_HASH_TO_ELEMENT || + status_code == WLAN_STATUS_SAE_PK)))) { + /* waiting for userspace now */ + ifmgd->auth_data->waiting = true; + ifmgd->auth_data->timeout = + jiffies + IEEE80211_AUTH_WAIT_SAE_RETRY; + ifmgd->auth_data->timeout_started = true; + run_again(sdata, ifmgd->auth_data->timeout); + goto notify_driver; + } + + sdata_info(sdata, "%pM denied authentication (status %d)\n", + mgmt->sa, status_code); + ieee80211_destroy_auth_data(sdata, false); + event.u.mlme.status = MLME_DENIED; + event.u.mlme.reason = status_code; + drv_event_callback(sdata->local, sdata, &event); + goto notify_driver; + } + + switch (ifmgd->auth_data->algorithm) { + case WLAN_AUTH_OPEN: + case WLAN_AUTH_LEAP: + case WLAN_AUTH_FT: + case WLAN_AUTH_SAE: + case WLAN_AUTH_FILS_SK: + case WLAN_AUTH_FILS_SK_PFS: + case WLAN_AUTH_FILS_PK: + break; + case WLAN_AUTH_SHARED_KEY: + if (ifmgd->auth_data->expected_transaction != 4) { + ieee80211_auth_challenge(sdata, mgmt, len); + /* need another frame */ + return; + } + break; + default: + WARN_ONCE(1, "invalid auth alg %d", + ifmgd->auth_data->algorithm); + goto notify_driver; + } + + event.u.mlme.status = MLME_SUCCESS; + info.success = 1; + drv_event_callback(sdata->local, sdata, &event); + if (ifmgd->auth_data->algorithm != WLAN_AUTH_SAE || + (auth_transaction == 2 && + ifmgd->auth_data->expected_transaction == 2)) { + if (!ieee80211_mark_sta_auth(sdata)) + return; /* ignore frame -- wait for timeout */ + } else if (ifmgd->auth_data->algorithm == WLAN_AUTH_SAE && + auth_transaction == 2) { + sdata_info(sdata, "SAE peer 
confirmed\n"); + ifmgd->auth_data->peer_confirmed = true; + } + + cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len); +notify_driver: + drv_mgd_complete_tx(sdata->local, sdata, &info); +} + +#define case_WLAN(type) \ + case WLAN_REASON_##type: return #type + +const char *ieee80211_get_reason_code_string(u16 reason_code) +{ + switch (reason_code) { + case_WLAN(UNSPECIFIED); + case_WLAN(PREV_AUTH_NOT_VALID); + case_WLAN(DEAUTH_LEAVING); + case_WLAN(DISASSOC_DUE_TO_INACTIVITY); + case_WLAN(DISASSOC_AP_BUSY); + case_WLAN(CLASS2_FRAME_FROM_NONAUTH_STA); + case_WLAN(CLASS3_FRAME_FROM_NONASSOC_STA); + case_WLAN(DISASSOC_STA_HAS_LEFT); + case_WLAN(STA_REQ_ASSOC_WITHOUT_AUTH); + case_WLAN(DISASSOC_BAD_POWER); + case_WLAN(DISASSOC_BAD_SUPP_CHAN); + case_WLAN(INVALID_IE); + case_WLAN(MIC_FAILURE); + case_WLAN(4WAY_HANDSHAKE_TIMEOUT); + case_WLAN(GROUP_KEY_HANDSHAKE_TIMEOUT); + case_WLAN(IE_DIFFERENT); + case_WLAN(INVALID_GROUP_CIPHER); + case_WLAN(INVALID_PAIRWISE_CIPHER); + case_WLAN(INVALID_AKMP); + case_WLAN(UNSUPP_RSN_VERSION); + case_WLAN(INVALID_RSN_IE_CAP); + case_WLAN(IEEE8021X_FAILED); + case_WLAN(CIPHER_SUITE_REJECTED); + case_WLAN(DISASSOC_UNSPECIFIED_QOS); + case_WLAN(DISASSOC_QAP_NO_BANDWIDTH); + case_WLAN(DISASSOC_LOW_ACK); + case_WLAN(DISASSOC_QAP_EXCEED_TXOP); + case_WLAN(QSTA_LEAVE_QBSS); + case_WLAN(QSTA_NOT_USE); + case_WLAN(QSTA_REQUIRE_SETUP); + case_WLAN(QSTA_TIMEOUT); + case_WLAN(QSTA_CIPHER_NOT_SUPP); + case_WLAN(MESH_PEER_CANCELED); + case_WLAN(MESH_MAX_PEERS); + case_WLAN(MESH_CONFIG); + case_WLAN(MESH_CLOSE); + case_WLAN(MESH_MAX_RETRIES); + case_WLAN(MESH_CONFIRM_TIMEOUT); + case_WLAN(MESH_INVALID_GTK); + case_WLAN(MESH_INCONSISTENT_PARAM); + case_WLAN(MESH_INVALID_SECURITY); + case_WLAN(MESH_PATH_ERROR); + case_WLAN(MESH_PATH_NOFORWARD); + case_WLAN(MESH_PATH_DEST_UNREACHABLE); + case_WLAN(MAC_EXISTS_IN_MBSS); + case_WLAN(MESH_CHAN_REGULATORY); + case_WLAN(MESH_CHAN); + default: return ""; + } +} + +static void ieee80211_rx_mgmt_deauth(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u16 reason_code = le16_to_cpu(mgmt->u.deauth.reason_code); + + sdata_assert_lock(sdata); + + if (len < 24 + 2) + return; + + if (!ether_addr_equal(mgmt->bssid, mgmt->sa)) { + ieee80211_tdls_handle_disconnect(sdata, mgmt->sa, reason_code); + return; + } + + if (ifmgd->associated && + ether_addr_equal(mgmt->bssid, sdata->vif.cfg.ap_addr)) { + sdata_info(sdata, "deauthenticated from %pM (Reason: %u=%s)\n", + sdata->vif.cfg.ap_addr, reason_code, + ieee80211_get_reason_code_string(reason_code)); + + ieee80211_set_disassoc(sdata, 0, 0, false, NULL); + + ieee80211_report_disconnect(sdata, (u8 *)mgmt, len, false, + reason_code, false); + return; + } + + if (ifmgd->assoc_data && + ether_addr_equal(mgmt->bssid, ifmgd->assoc_data->ap_addr)) { + sdata_info(sdata, + "deauthenticated from %pM while associating (Reason: %u=%s)\n", + ifmgd->assoc_data->ap_addr, reason_code, + ieee80211_get_reason_code_string(reason_code)); + + ieee80211_destroy_assoc_data(sdata, ASSOC_ABANDON); + + cfg80211_rx_mlme_mgmt(sdata->dev, (u8 *)mgmt, len); + return; + } +} + + +static void ieee80211_rx_mgmt_disassoc(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, size_t len) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u16 reason_code; + + sdata_assert_lock(sdata); + + if (len < 24 + 2) + return; + + if (!ifmgd->associated || + !ether_addr_equal(mgmt->bssid, sdata->vif.cfg.ap_addr)) + return; + + reason_code = 
le16_to_cpu(mgmt->u.disassoc.reason_code); + + if (!ether_addr_equal(mgmt->bssid, mgmt->sa)) { + ieee80211_tdls_handle_disconnect(sdata, mgmt->sa, reason_code); + return; + } + + sdata_info(sdata, "disassociated from %pM (Reason: %u=%s)\n", + sdata->vif.cfg.ap_addr, reason_code, + ieee80211_get_reason_code_string(reason_code)); + + ieee80211_set_disassoc(sdata, 0, 0, false, NULL); + + ieee80211_report_disconnect(sdata, (u8 *)mgmt, len, false, reason_code, + false); +} + +static void ieee80211_get_rates(struct ieee80211_supported_band *sband, + u8 *supp_rates, unsigned int supp_rates_len, + u32 *rates, u32 *basic_rates, + bool *have_higher_than_11mbit, + int *min_rate, int *min_rate_index, + int shift) +{ + int i, j; + + for (i = 0; i < supp_rates_len; i++) { + int rate = supp_rates[i] & 0x7f; + bool is_basic = !!(supp_rates[i] & 0x80); + + if ((rate * 5 * (1 << shift)) > 110) + *have_higher_than_11mbit = true; + + /* + * Skip HT, VHT, HE and SAE H2E only BSS membership selectors + * since they're not rates. + * + * Note: Even though the membership selector and the basic + * rate flag share the same bit, they are not exactly + * the same. + */ + if (supp_rates[i] == (0x80 | BSS_MEMBERSHIP_SELECTOR_HT_PHY) || + supp_rates[i] == (0x80 | BSS_MEMBERSHIP_SELECTOR_VHT_PHY) || + supp_rates[i] == (0x80 | BSS_MEMBERSHIP_SELECTOR_HE_PHY) || + supp_rates[i] == (0x80 | BSS_MEMBERSHIP_SELECTOR_SAE_H2E)) + continue; + + for (j = 0; j < sband->n_bitrates; j++) { + struct ieee80211_rate *br; + int brate; + + br = &sband->bitrates[j]; + + brate = DIV_ROUND_UP(br->bitrate, (1 << shift) * 5); + if (brate == rate) { + *rates |= BIT(j); + if (is_basic) + *basic_rates |= BIT(j); + if ((rate * 5) < *min_rate) { + *min_rate = rate * 5; + *min_rate_index = j; + } + break; + } + } + } +} + +static bool ieee80211_twt_req_supported(const struct link_sta_info *link_sta, + const struct ieee802_11_elems *elems) +{ + if (elems->ext_capab_len < 10) + return false; + + if (!(elems->ext_capab[9] & WLAN_EXT_CAPA10_TWT_RESPONDER_SUPPORT)) + return false; + + return link_sta->pub->he_cap.he_cap_elem.mac_cap_info[0] & + IEEE80211_HE_MAC_CAP0_TWT_RES; +} + +static int ieee80211_recalc_twt_req(struct ieee80211_link_data *link, + struct link_sta_info *link_sta, + struct ieee802_11_elems *elems) +{ + bool twt = ieee80211_twt_req_supported(link_sta, elems); + + if (link->conf->twt_requester != twt) { + link->conf->twt_requester = twt; + return BSS_CHANGED_TWT; + } + return 0; +} + +static bool ieee80211_twt_bcast_support(struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *bss_conf, + struct ieee80211_supported_band *sband, + struct link_sta_info *link_sta) +{ + const struct ieee80211_sta_he_cap *own_he_cap = + ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + + return bss_conf->he_support && + (link_sta->pub->he_cap.he_cap_elem.mac_cap_info[2] & + IEEE80211_HE_MAC_CAP2_BCAST_TWT) && + own_he_cap && + (own_he_cap->he_cap_elem.mac_cap_info[2] & + IEEE80211_HE_MAC_CAP2_BCAST_TWT); +} + +static bool ieee80211_assoc_config_link(struct ieee80211_link_data *link, + struct link_sta_info *link_sta, + struct cfg80211_bss *cbss, + struct ieee80211_mgmt *mgmt, + const u8 *elem_start, + unsigned int elem_len, + u64 *changed) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_mgd_assoc_data *assoc_data = sdata->u.mgd.assoc_data; + struct ieee80211_bss_conf *bss_conf = link->conf; + struct ieee80211_local *local = sdata->local; + struct ieee80211_elems_parse_params parse_params = { + 
.start = elem_start, + .len = elem_len, + .bss = cbss, + .link_id = link == &sdata->deflink ? -1 : link->link_id, + .from_ap = true, + }; + bool is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ; + bool is_s1g = cbss->channel->band == NL80211_BAND_S1GHZ; + const struct cfg80211_bss_ies *bss_ies = NULL; + struct ieee80211_supported_band *sband; + struct ieee802_11_elems *elems; + u16 capab_info; + bool ret; + + elems = ieee802_11_parse_elems_full(&parse_params); + if (!elems) + return false; + + /* FIXME: use from STA profile element after parsing that */ + capab_info = le16_to_cpu(mgmt->u.assoc_resp.capab_info); + + if (!is_s1g && !elems->supp_rates) { + sdata_info(sdata, "no SuppRates element in AssocResp\n"); + ret = false; + goto out; + } + + link->u.mgd.tdls_chan_switch_prohibited = + elems->ext_capab && elems->ext_capab_len >= 5 && + (elems->ext_capab[4] & WLAN_EXT_CAPA5_TDLS_CH_SW_PROHIBITED); + + /* + * Some APs are erroneously not including some information in their + * (re)association response frames. Try to recover by using the data + * from the beacon or probe response. This seems to afflict mobile + * 2G/3G/4G wifi routers, reported models include the "Onda PN51T", + * "Vodafone PocketWiFi 2", "ZTE MF60" and a similar T-Mobile device. + */ + if (!is_6ghz && + ((assoc_data->wmm && !elems->wmm_param) || + (!(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT) && + (!elems->ht_cap_elem || !elems->ht_operation)) || + (!(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT) && + (!elems->vht_cap_elem || !elems->vht_operation)))) { + const struct cfg80211_bss_ies *ies; + struct ieee802_11_elems *bss_elems; + + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + if (ies) + bss_ies = kmemdup(ies, sizeof(*ies) + ies->len, + GFP_ATOMIC); + rcu_read_unlock(); + if (!bss_ies) { + ret = false; + goto out; + } + + parse_params.start = bss_ies->data; + parse_params.len = bss_ies->len; + bss_elems = ieee802_11_parse_elems_full(&parse_params); + if (!bss_elems) { + ret = false; + goto out; + } + + if (assoc_data->wmm && + !elems->wmm_param && bss_elems->wmm_param) { + elems->wmm_param = bss_elems->wmm_param; + sdata_info(sdata, + "AP bug: WMM param missing from AssocResp\n"); + } + + /* + * Also check if we requested HT/VHT, otherwise the AP doesn't + * have to include the IEs in the (re)association response. + */ + if (!elems->ht_cap_elem && bss_elems->ht_cap_elem && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT)) { + elems->ht_cap_elem = bss_elems->ht_cap_elem; + sdata_info(sdata, + "AP bug: HT capability missing from AssocResp\n"); + } + if (!elems->ht_operation && bss_elems->ht_operation && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT)) { + elems->ht_operation = bss_elems->ht_operation; + sdata_info(sdata, + "AP bug: HT operation missing from AssocResp\n"); + } + if (!elems->vht_cap_elem && bss_elems->vht_cap_elem && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT)) { + elems->vht_cap_elem = bss_elems->vht_cap_elem; + sdata_info(sdata, + "AP bug: VHT capa missing from AssocResp\n"); + } + if (!elems->vht_operation && bss_elems->vht_operation && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT)) { + elems->vht_operation = bss_elems->vht_operation; + sdata_info(sdata, + "AP bug: VHT operation missing from AssocResp\n"); + } + + kfree(bss_elems); + } + + /* + * We previously checked these in the beacon/probe response, so + * they should be present here. This is just a safety net. 
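ieee80211_get_rates() above treats every Supported Rates byte as a rate in 0.5 Mb/s units in the low seven bits plus a basic-rate flag in the top bit, and skips the BSS membership selector values that reuse the same encoding. A stand-alone sketch of that decoding (illustration only, selector values as in linux/ieee80211.h, not part of the patch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define BSS_MEMBERSHIP_SELECTOR_HT_PHY  127
#define BSS_MEMBERSHIP_SELECTOR_VHT_PHY 126
#define BSS_MEMBERSHIP_SELECTOR_SAE_H2E 123
#define BSS_MEMBERSHIP_SELECTOR_HE_PHY  122

static bool is_membership_selector(uint8_t byte)
{
        return byte == (0x80 | BSS_MEMBERSHIP_SELECTOR_HT_PHY) ||
               byte == (0x80 | BSS_MEMBERSHIP_SELECTOR_VHT_PHY) ||
               byte == (0x80 | BSS_MEMBERSHIP_SELECTOR_HE_PHY) ||
               byte == (0x80 | BSS_MEMBERSHIP_SELECTOR_SAE_H2E);
}

int main(void)
{
        /* typical 2.4 GHz set: 1(B) 2(B) 5.5(B) 11(B) 18 24 36 54 Mb/s */
        const uint8_t rates[] = { 0x82, 0x84, 0x8b, 0x96,
                                  0x24, 0x30, 0x48, 0x6c };
        unsigned int i;

        for (i = 0; i < sizeof(rates); i++) {
                unsigned int units = rates[i] & 0x7f;   /* 0.5 Mb/s units */
                bool basic = rates[i] & 0x80;

                if (is_membership_selector(rates[i]))
                        continue;
                printf("%u.%u Mb/s%s\n", units / 2, (units & 1) * 5,
                       basic ? " (basic)" : "");
        }
        return 0;
}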
+ */ + if (!is_6ghz && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT) && + (!elems->wmm_param || !elems->ht_cap_elem || !elems->ht_operation)) { + sdata_info(sdata, + "HT AP is missing WMM params or HT capability/operation\n"); + ret = false; + goto out; + } + + if (!is_6ghz && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT) && + (!elems->vht_cap_elem || !elems->vht_operation)) { + sdata_info(sdata, + "VHT AP is missing VHT capability/operation\n"); + ret = false; + goto out; + } + + if (is_6ghz && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE) && + !elems->he_6ghz_capa) { + sdata_info(sdata, + "HE 6 GHz AP is missing HE 6 GHz band capability\n"); + ret = false; + goto out; + } + + if (WARN_ON(!link->conf->chandef.chan)) { + ret = false; + goto out; + } + sband = local->hw.wiphy->bands[link->conf->chandef.chan->band]; + + if (!(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE) && + (!elems->he_cap || !elems->he_operation)) { + sdata_info(sdata, + "HE AP is missing HE capability/operation\n"); + ret = false; + goto out; + } + + /* Set up internal HT/VHT capabilities */ + if (elems->ht_cap_elem && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT)) + ieee80211_ht_cap_ie_to_sta_ht_cap(sdata, sband, + elems->ht_cap_elem, + link_sta); + + if (elems->vht_cap_elem && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT)) { + const struct ieee80211_vht_cap *bss_vht_cap = NULL; + const struct cfg80211_bss_ies *ies; + + /* + * Cisco AP module 9115 with FW 17.3 has a bug and sends a + * too large maximum MPDU length in the association response + * (indicating 12k) that it cannot actually process ... + * Work around that. + */ + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + if (ies) { + const struct element *elem; + + elem = cfg80211_find_elem(WLAN_EID_VHT_CAPABILITY, + ies->data, ies->len); + if (elem && elem->datalen >= sizeof(*bss_vht_cap)) + bss_vht_cap = (const void *)elem->data; + } + + ieee80211_vht_cap_ie_to_sta_vht_cap(sdata, sband, + elems->vht_cap_elem, + bss_vht_cap, link_sta); + rcu_read_unlock(); + } + + if (elems->he_operation && !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE) && + elems->he_cap) { + ieee80211_he_cap_ie_to_sta_he_cap(sdata, sband, + elems->he_cap, + elems->he_cap_len, + elems->he_6ghz_capa, + link_sta); + + bss_conf->he_support = link_sta->pub->he_cap.has_he; + if (elems->rsnx && elems->rsnx_len && + (elems->rsnx[0] & WLAN_RSNX_CAPA_PROTECTED_TWT) && + wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_PROTECTED_TWT)) + bss_conf->twt_protected = true; + else + bss_conf->twt_protected = false; + + *changed |= ieee80211_recalc_twt_req(link, link_sta, elems); + + if (elems->eht_operation && elems->eht_cap && + !(link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_EHT)) { + ieee80211_eht_cap_ie_to_sta_eht_cap(sdata, sband, + elems->he_cap, + elems->he_cap_len, + elems->eht_cap, + elems->eht_cap_len, + link_sta); + + bss_conf->eht_support = link_sta->pub->eht_cap.has_eht; + } else { + bss_conf->eht_support = false; + } + } else { + bss_conf->he_support = false; + bss_conf->twt_requester = false; + bss_conf->twt_protected = false; + bss_conf->eht_support = false; + } + + bss_conf->twt_broadcast = + ieee80211_twt_bcast_support(sdata, bss_conf, sband, link_sta); + + if (bss_conf->he_support) { + bss_conf->he_bss_color.color = + le32_get_bits(elems->he_operation->he_oper_params, + IEEE80211_HE_OPERATION_BSS_COLOR_MASK); + bss_conf->he_bss_color.partial = + le32_get_bits(elems->he_operation->he_oper_params, + 
IEEE80211_HE_OPERATION_PARTIAL_BSS_COLOR); + bss_conf->he_bss_color.enabled = + !le32_get_bits(elems->he_operation->he_oper_params, + IEEE80211_HE_OPERATION_BSS_COLOR_DISABLED); + + if (bss_conf->he_bss_color.enabled) + *changed |= BSS_CHANGED_HE_BSS_COLOR; + + bss_conf->htc_trig_based_pkt_ext = + le32_get_bits(elems->he_operation->he_oper_params, + IEEE80211_HE_OPERATION_DFLT_PE_DURATION_MASK); + bss_conf->frame_time_rts_th = + le32_get_bits(elems->he_operation->he_oper_params, + IEEE80211_HE_OPERATION_RTS_THRESHOLD_MASK); + + bss_conf->uora_exists = !!elems->uora_element; + if (elems->uora_element) + bss_conf->uora_ocw_range = elems->uora_element[0]; + + ieee80211_he_op_ie_to_bss_conf(&sdata->vif, elems->he_operation); + ieee80211_he_spr_ie_to_bss_conf(&sdata->vif, elems->he_spr); + /* TODO: OPEN: what happens if BSS color disable is set? */ + } + + if (cbss->transmitted_bss) { + bss_conf->nontransmitted = true; + ether_addr_copy(bss_conf->transmitter_bssid, + cbss->transmitted_bss->bssid); + bss_conf->bssid_indicator = cbss->max_bssid_indicator; + bss_conf->bssid_index = cbss->bssid_index; + } + + /* + * Some APs, e.g. Netgear WNDR3700, report invalid HT operation data + * in their association response, so ignore that data for our own + * configuration. If it changed since the last beacon, we'll get the + * next beacon and update then. + */ + + /* + * If an operating mode notification IE is present, override the + * NSS calculation (that would be done in rate_control_rate_init()) + * and use the # of streams from that element. + */ + if (elems->opmode_notif && + !(*elems->opmode_notif & IEEE80211_OPMODE_NOTIF_RX_NSS_TYPE_BF)) { + u8 nss; + + nss = *elems->opmode_notif & IEEE80211_OPMODE_NOTIF_RX_NSS_MASK; + nss >>= IEEE80211_OPMODE_NOTIF_RX_NSS_SHIFT; + nss += 1; + link_sta->pub->rx_nss = nss; + } + + /* + * Always handle WMM once after association regardless + * of the first value the AP uses. Setting -1 here has + * that effect because the AP values is an unsigned + * 4-bit value. + */ + link->u.mgd.wmm_last_param_set = -1; + link->u.mgd.mu_edca_last_param_set = -1; + + if (link->u.mgd.disable_wmm_tracking) { + ieee80211_set_wmm_default(link, false, false); + } else if (!ieee80211_sta_wmm_params(local, link, elems->wmm_param, + elems->wmm_param_len, + elems->mu_edca_param_set)) { + /* still enable QoS since we might have HT/VHT */ + ieee80211_set_wmm_default(link, false, true); + /* disable WMM tracking in this case to disable + * tracking WMM parameter changes in the beacon if + * the parameters weren't actually valid. Doing so + * avoids changing parameters very strangely when + * the AP is going back and forth between valid and + * invalid parameters. 
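The le32_get_bits() calls above all follow the same pattern: read the little-endian HE Operation parameters word, mask out one field and shift it down by the mask's lowest set bit. A stand-alone sketch using the BSS color, color-disabled and RTS-threshold masks from linux/ieee80211.h under shortened names (illustration only, not part of the patch):

#include <stdint.h>
#include <stdio.h>

#define HE_OPER_BSS_COLOR_MASK          0x3f000000u
#define HE_OPER_BSS_COLOR_DISABLED      0x80000000u
#define HE_OPER_RTS_THRESHOLD_MASK      0x00003ff0u

/* FIELD_GET()-style helper: divide by the mask's lowest set bit */
static uint32_t field_get(uint32_t mask, uint32_t val)
{
        return (val & mask) / (mask & -mask);
}

int main(void)
{
        /* example: BSS color 42, color enabled, RTS threshold field 1023 */
        uint32_t params = (42u << 24) | (1023u << 4);

        printf("color %u disabled %u rts_th %u\n",
               field_get(HE_OPER_BSS_COLOR_MASK, params),
               field_get(HE_OPER_BSS_COLOR_DISABLED, params),
               field_get(HE_OPER_RTS_THRESHOLD_MASK, params));
        return 0;
}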
+ */ + link->u.mgd.disable_wmm_tracking = true; + } + + if (elems->max_idle_period_ie) { + bss_conf->max_idle_period = + le16_to_cpu(elems->max_idle_period_ie->max_idle_period); + bss_conf->protected_keep_alive = + !!(elems->max_idle_period_ie->idle_options & + WLAN_IDLE_OPTIONS_PROTECTED_KEEP_ALIVE); + *changed |= BSS_CHANGED_KEEP_ALIVE; + } else { + bss_conf->max_idle_period = 0; + bss_conf->protected_keep_alive = false; + } + + /* set assoc capability (AID was already set earlier), + * ieee80211_set_associated() will tell the driver */ + bss_conf->assoc_capability = capab_info; + + ret = true; +out: + kfree(elems); + kfree(bss_ies); + return ret; +} + +static int ieee80211_mgd_setup_link_sta(struct ieee80211_link_data *link, + struct sta_info *sta, + struct link_sta_info *link_sta, + struct cfg80211_bss *cbss) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_bss *bss = (void *)cbss->priv; + u32 rates = 0, basic_rates = 0; + bool have_higher_than_11mbit = false; + int min_rate = INT_MAX, min_rate_index = -1; + /* this is clearly wrong for MLO but we'll just remove it later */ + int shift = ieee80211_vif_get_shift(&sdata->vif); + struct ieee80211_supported_band *sband; + + memcpy(link_sta->addr, cbss->bssid, ETH_ALEN); + memcpy(link_sta->pub->addr, cbss->bssid, ETH_ALEN); + + /* TODO: S1G Basic Rate Set is expressed elsewhere */ + if (cbss->channel->band == NL80211_BAND_S1GHZ) { + ieee80211_s1g_sta_rate_init(sta); + return 0; + } + + sband = local->hw.wiphy->bands[cbss->channel->band]; + + ieee80211_get_rates(sband, bss->supp_rates, bss->supp_rates_len, + &rates, &basic_rates, &have_higher_than_11mbit, + &min_rate, &min_rate_index, shift); + + /* + * This used to be a workaround for basic rates missing + * in the association response frame. Now that we no + * longer use the basic rates from there, it probably + * doesn't happen any more, but keep the workaround so + * in case some *other* APs are buggy in different ways + * we can connect -- with a warning. + * Allow this workaround only in case the AP provided at least + * one rate. + */ + if (min_rate_index < 0) { + link_info(link, "No legacy rates in association response\n"); + return -EINVAL; + } else if (!basic_rates) { + link_info(link, "No basic rates, using min rate instead\n"); + basic_rates = BIT(min_rate_index); + } + + if (rates) + link_sta->pub->supp_rates[cbss->channel->band] = rates; + else + link_info(link, "No rates found, keeping mandatory only\n"); + + link->conf->basic_rates = basic_rates; + + /* cf. 
IEEE 802.11 9.2.12 */ + link->operating_11g_mode = sband->band == NL80211_BAND_2GHZ && + have_higher_than_11mbit; + + return 0; +} + +static u8 ieee80211_max_rx_chains(struct ieee80211_link_data *link, + struct cfg80211_bss *cbss) +{ + struct ieee80211_he_mcs_nss_supp *he_mcs_nss_supp; + const struct element *ht_cap_elem, *vht_cap_elem; + const struct cfg80211_bss_ies *ies; + const struct ieee80211_ht_cap *ht_cap; + const struct ieee80211_vht_cap *vht_cap; + const struct ieee80211_he_cap_elem *he_cap; + const struct element *he_cap_elem; + u16 mcs_80_map, mcs_160_map; + int i, mcs_nss_size; + bool support_160; + u8 chains = 1; + + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT) + return chains; + + ht_cap_elem = ieee80211_bss_get_elem(cbss, WLAN_EID_HT_CAPABILITY); + if (ht_cap_elem && ht_cap_elem->datalen >= sizeof(*ht_cap)) { + ht_cap = (void *)ht_cap_elem->data; + chains = ieee80211_mcs_to_chains(&ht_cap->mcs); + /* + * TODO: use "Tx Maximum Number Spatial Streams Supported" and + * "Tx Unequal Modulation Supported" fields. + */ + } + + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_VHT) + return chains; + + vht_cap_elem = ieee80211_bss_get_elem(cbss, WLAN_EID_VHT_CAPABILITY); + if (vht_cap_elem && vht_cap_elem->datalen >= sizeof(*vht_cap)) { + u8 nss; + u16 tx_mcs_map; + + vht_cap = (void *)vht_cap_elem->data; + tx_mcs_map = le16_to_cpu(vht_cap->supp_mcs.tx_mcs_map); + for (nss = 8; nss > 0; nss--) { + if (((tx_mcs_map >> (2 * (nss - 1))) & 3) != + IEEE80211_VHT_MCS_NOT_SUPPORTED) + break; + } + /* TODO: use "Tx Highest Supported Long GI Data Rate" field? */ + chains = max(chains, nss); + } + + if (link->u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HE) + return chains; + + ies = rcu_dereference(cbss->ies); + he_cap_elem = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_CAPABILITY, + ies->data, ies->len); + + if (!he_cap_elem || he_cap_elem->datalen < sizeof(*he_cap)) + return chains; + + /* skip one byte ext_tag_id */ + he_cap = (void *)(he_cap_elem->data + 1); + mcs_nss_size = ieee80211_he_mcs_nss_size(he_cap); + + /* invalid HE IE */ + if (he_cap_elem->datalen < 1 + mcs_nss_size + sizeof(*he_cap)) + return chains; + + /* mcs_nss is right after he_cap info */ + he_mcs_nss_supp = (void *)(he_cap + 1); + + mcs_80_map = le16_to_cpu(he_mcs_nss_supp->tx_mcs_80); + + for (i = 7; i >= 0; i--) { + u8 mcs_80 = mcs_80_map >> (2 * i) & 3; + + if (mcs_80 != IEEE80211_VHT_MCS_NOT_SUPPORTED) { + chains = max_t(u8, chains, i + 1); + break; + } + } + + support_160 = he_cap->phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + + if (!support_160) + return chains; + + mcs_160_map = le16_to_cpu(he_mcs_nss_supp->tx_mcs_160); + for (i = 7; i >= 0; i--) { + u8 mcs_160 = mcs_160_map >> (2 * i) & 3; + + if (mcs_160 != IEEE80211_VHT_MCS_NOT_SUPPORTED) { + chains = max_t(u8, chains, i + 1); + break; + } + } + + return chains; +} + +static bool +ieee80211_verify_peer_he_mcs_support(struct ieee80211_sub_if_data *sdata, + const struct cfg80211_bss_ies *ies, + const struct ieee80211_he_operation *he_op) +{ + const struct element *he_cap_elem; + const struct ieee80211_he_cap_elem *he_cap; + struct ieee80211_he_mcs_nss_supp *he_mcs_nss_supp; + u16 mcs_80_map_tx, mcs_80_map_rx; + u16 ap_min_req_set; + int mcs_nss_size; + int nss; + + he_cap_elem = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_CAPABILITY, + ies->data, ies->len); + + if (!he_cap_elem) + return false; + + /* invalid HE IE */ + if (he_cap_elem->datalen < 1 + sizeof(*he_cap)) { + sdata_info(sdata, + "Invalid HE elem, Disable HE\n"); + 
return false; + } + + /* skip one byte ext_tag_id */ + he_cap = (void *)(he_cap_elem->data + 1); + mcs_nss_size = ieee80211_he_mcs_nss_size(he_cap); + + /* invalid HE IE */ + if (he_cap_elem->datalen < 1 + sizeof(*he_cap) + mcs_nss_size) { + sdata_info(sdata, + "Invalid HE elem with nss size, Disable HE\n"); + return false; + } + + /* mcs_nss is right after he_cap info */ + he_mcs_nss_supp = (void *)(he_cap + 1); + + mcs_80_map_tx = le16_to_cpu(he_mcs_nss_supp->tx_mcs_80); + mcs_80_map_rx = le16_to_cpu(he_mcs_nss_supp->rx_mcs_80); + + /* P802.11-REVme/D0.3 + * 27.1.1 Introduction to the HE PHY + * ... + * An HE STA shall support the following features: + * ... + * Single spatial stream HE-MCSs 0 to 7 (transmit and receive) in all + * supported channel widths for HE SU PPDUs + */ + if ((mcs_80_map_tx & 0x3) == IEEE80211_HE_MCS_NOT_SUPPORTED || + (mcs_80_map_rx & 0x3) == IEEE80211_HE_MCS_NOT_SUPPORTED) { + sdata_info(sdata, + "Missing mandatory rates for 1 Nss, rx 0x%x, tx 0x%x, disable HE\n", + mcs_80_map_tx, mcs_80_map_rx); + return false; + } + + if (!he_op) + return true; + + ap_min_req_set = le16_to_cpu(he_op->he_mcs_nss_set); + + /* + * Apparently iPhone 13 (at least iOS version 15.3.1) sets this to all + * zeroes, which is nonsense, and completely inconsistent with itself + * (it doesn't have 8 streams). Accept the settings in this case anyway. + */ + if (!ap_min_req_set) + return true; + + /* make sure the AP is consistent with itself + * + * P802.11-REVme/D0.3 + * 26.17.1 Basic HE BSS operation + * + * A STA that is operating in an HE BSS shall be able to receive and + * transmit at each of the tuple values indicated by the + * Basic HE-MCS And NSS Set field of the HE Operation parameter of the + * MLME-START.request primitive and shall be able to receive at each of + * the tuple values indicated by the Supported HE-MCS and + * NSS Set field in the HE Capabilities parameter of the MLMESTART.request + * primitive + */ + for (nss = 8; nss > 0; nss--) { + u8 ap_op_val = (ap_min_req_set >> (2 * (nss - 1))) & 3; + u8 ap_rx_val; + u8 ap_tx_val; + + if (ap_op_val == IEEE80211_HE_MCS_NOT_SUPPORTED) + continue; + + ap_rx_val = (mcs_80_map_rx >> (2 * (nss - 1))) & 3; + ap_tx_val = (mcs_80_map_tx >> (2 * (nss - 1))) & 3; + + if (ap_rx_val == IEEE80211_HE_MCS_NOT_SUPPORTED || + ap_tx_val == IEEE80211_HE_MCS_NOT_SUPPORTED || + ap_rx_val < ap_op_val || ap_tx_val < ap_op_val) { + sdata_info(sdata, + "Invalid rates for %d Nss, rx %d, tx %d oper %d, disable HE\n", + nss, ap_rx_val, ap_rx_val, ap_op_val); + return false; + } + } + + return true; +} + +static bool +ieee80211_verify_sta_he_mcs_support(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct ieee80211_he_operation *he_op) +{ + const struct ieee80211_sta_he_cap *sta_he_cap = + ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + u16 ap_min_req_set; + int i; + + if (!sta_he_cap || !he_op) + return false; + + ap_min_req_set = le16_to_cpu(he_op->he_mcs_nss_set); + + /* + * Apparently iPhone 13 (at least iOS version 15.3.1) sets this to all + * zeroes, which is nonsense, and completely inconsistent with itself + * (it doesn't have 8 streams). Accept the settings in this case anyway. 
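Both HE MCS verification helpers here decode 16-bit MCS/NSS maps that carry two bits per spatial stream, NSS 1 in the lowest bit pair and the value 3 meaning "not supported"; that is what the (map >> (2 * (nss - 1))) & 3 expressions extract. A stand-alone sketch of the decoding (illustration only, not part of the patch):

#include <stdint.h>
#include <stdio.h>

#define HE_MCS_SUPPORT_0_7      0
#define HE_MCS_SUPPORT_0_9      1
#define HE_MCS_SUPPORT_0_11     2
#define HE_MCS_NOT_SUPPORTED    3

/* 2-bit support value for spatial stream nss (1..8) */
static unsigned int he_mcs_for_nss(uint16_t mcs_map, unsigned int nss)
{
        return (mcs_map >> (2 * (nss - 1))) & 0x3;
}

int main(void)
{
        /* two streams at MCS 0-11, the remaining six marked not supported */
        uint16_t mcs_map = 0xfffa;
        unsigned int nss;

        for (nss = 1; nss <= 8; nss++)
                printf("NSS %u -> %u%s\n", nss, he_mcs_for_nss(mcs_map, nss),
                       he_mcs_for_nss(mcs_map, nss) == HE_MCS_NOT_SUPPORTED ?
                       " (not supported)" : "");
        return 0;
}

The peer check requires MCS 0-7 for a single stream, and the STA check walks the same kind of map for 80, 160 and 80+80 MHz against the AP's Basic HE-MCS And NSS Set.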
+ */ + if (!ap_min_req_set) + return true; + + /* Need to go over for 80MHz, 160MHz and for 80+80 */ + for (i = 0; i < 3; i++) { + const struct ieee80211_he_mcs_nss_supp *sta_mcs_nss_supp = + &sta_he_cap->he_mcs_nss_supp; + u16 sta_mcs_map_rx = + le16_to_cpu(((__le16 *)sta_mcs_nss_supp)[2 * i]); + u16 sta_mcs_map_tx = + le16_to_cpu(((__le16 *)sta_mcs_nss_supp)[2 * i + 1]); + u8 nss; + bool verified = true; + + /* + * For each bandwidth there is a maximum of 8 spatial streams + * possible. Each of the sta_mcs_map_* is a 16-bit map built + * of 2 bits per NSS (1-8), with the values defined in enum + * ieee80211_he_mcs_support. Need to make sure STA TX and RX + * capabilities aren't less than the AP's minimum requirements + * for this HE BSS per SS. + * It is enough to find one such bandwidth that meets the requirements. + */ + for (nss = 8; nss > 0; nss--) { + u8 sta_rx_val = (sta_mcs_map_rx >> (2 * (nss - 1))) & 3; + u8 sta_tx_val = (sta_mcs_map_tx >> (2 * (nss - 1))) & 3; + u8 ap_val = (ap_min_req_set >> (2 * (nss - 1))) & 3; + + if (ap_val == IEEE80211_HE_MCS_NOT_SUPPORTED) + continue; + + /* + * Make sure the HE AP doesn't require MCSs that aren't + * supported by the client as required by spec + * + * P802.11-REVme/D0.3 + * 26.17.1 Basic HE BSS operation + * + * An HE STA shall not attempt to join (MLME-JOIN.request primitive) + * a BSS, unless it supports (i.e., is able to both transmit and + * receive using) all of the <HE-MCS, NSS> tuples in the basic + * HE-MCS and NSS set. + */ + if (sta_rx_val == IEEE80211_HE_MCS_NOT_SUPPORTED || + sta_tx_val == IEEE80211_HE_MCS_NOT_SUPPORTED || + (ap_val > sta_rx_val) || (ap_val > sta_tx_val)) { + verified = false; + break; + } + } + + if (verified) + return true; + } + + /* If here, STA doesn't meet AP's HE min requirements */ + return false; +} + +static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct cfg80211_bss *cbss, + ieee80211_conn_flags_t *conn_flags) +{ + struct ieee80211_local *local = sdata->local; + const struct ieee80211_ht_cap *ht_cap = NULL; + const struct ieee80211_ht_operation *ht_oper = NULL; + const struct ieee80211_vht_operation *vht_oper = NULL; + const struct ieee80211_he_operation *he_oper = NULL; + const struct ieee80211_eht_operation *eht_oper = NULL; + const struct ieee80211_s1g_oper_ie *s1g_oper = NULL; + struct ieee80211_supported_band *sband; + struct cfg80211_chan_def chandef; + bool is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ; + bool is_5ghz = cbss->channel->band == NL80211_BAND_5GHZ; + struct ieee80211_bss *bss = (void *)cbss->priv; + struct ieee80211_elems_parse_params parse_params = { + .bss = cbss, + .link_id = -1, + .from_ap = true, + }; + struct ieee802_11_elems *elems; + const struct cfg80211_bss_ies *ies; + int ret; + u32 i; + bool have_80mhz; + + rcu_read_lock(); + + ies = rcu_dereference(cbss->ies); + parse_params.start = ies->data; + parse_params.len = ies->len; + elems = ieee802_11_parse_elems_full(&parse_params); + if (!elems) { + rcu_read_unlock(); + return -ENOMEM; + } + + sband = local->hw.wiphy->bands[cbss->channel->band]; + + *conn_flags &= ~(IEEE80211_CONN_DISABLE_40MHZ | + IEEE80211_CONN_DISABLE_80P80MHZ | + IEEE80211_CONN_DISABLE_160MHZ); + + /* disable HT/VHT/HE if we don't support them */ + if (!sband->ht_cap.ht_supported && !is_6ghz) { + mlme_dbg(sdata, "HT not supported, disabling HT/VHT/HE/EHT\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_HT; + *conn_flags |= IEEE80211_CONN_DISABLE_VHT; + *conn_flags |= IEEE80211_CONN_DISABLE_HE; + *conn_flags |= 
IEEE80211_CONN_DISABLE_EHT; + } + + if (!sband->vht_cap.vht_supported && is_5ghz) { + mlme_dbg(sdata, "VHT not supported, disabling VHT/HE/EHT\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_VHT; + *conn_flags |= IEEE80211_CONN_DISABLE_HE; + *conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (!ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif))) { + mlme_dbg(sdata, "HE not supported, disabling HE and EHT\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_HE; + *conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (!ieee80211_get_eht_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif))) { + mlme_dbg(sdata, "EHT not supported, disabling EHT\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (!(*conn_flags & IEEE80211_CONN_DISABLE_HT) && !is_6ghz) { + ht_oper = elems->ht_operation; + ht_cap = elems->ht_cap_elem; + + if (!ht_cap) { + *conn_flags |= IEEE80211_CONN_DISABLE_HT; + ht_oper = NULL; + } + } + + if (!(*conn_flags & IEEE80211_CONN_DISABLE_VHT) && !is_6ghz) { + vht_oper = elems->vht_operation; + if (vht_oper && !ht_oper) { + vht_oper = NULL; + sdata_info(sdata, + "AP advertised VHT without HT, disabling HT/VHT/HE\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_HT; + *conn_flags |= IEEE80211_CONN_DISABLE_VHT; + *conn_flags |= IEEE80211_CONN_DISABLE_HE; + *conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (!elems->vht_cap_elem) { + *conn_flags |= IEEE80211_CONN_DISABLE_VHT; + vht_oper = NULL; + } + } + + if (!(*conn_flags & IEEE80211_CONN_DISABLE_HE)) { + he_oper = elems->he_operation; + + if (link && is_6ghz) { + struct ieee80211_bss_conf *bss_conf; + u8 j = 0; + + bss_conf = link->conf; + + if (elems->pwr_constr_elem) + bss_conf->pwr_reduction = *elems->pwr_constr_elem; + + BUILD_BUG_ON(ARRAY_SIZE(bss_conf->tx_pwr_env) != + ARRAY_SIZE(elems->tx_pwr_env)); + + for (i = 0; i < elems->tx_pwr_env_num; i++) { + if (elems->tx_pwr_env_len[i] > + sizeof(bss_conf->tx_pwr_env[j])) + continue; + + bss_conf->tx_pwr_env_num++; + memcpy(&bss_conf->tx_pwr_env[j], elems->tx_pwr_env[i], + elems->tx_pwr_env_len[i]); + j++; + } + } + + if (!ieee80211_verify_peer_he_mcs_support(sdata, ies, he_oper) || + !ieee80211_verify_sta_he_mcs_support(sdata, sband, he_oper)) + *conn_flags |= IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + } + + /* + * EHT requires HE to be supported as well. Specifically for 6 GHz + * channels, the operation channel information can only be deduced from + * both the 6 GHz operation information (from the HE operation IE) and + * EHT operation. 
+ */ + if (!(*conn_flags & + (IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT)) && + he_oper) { + const struct cfg80211_bss_ies *cbss_ies; + const u8 *eht_oper_ie; + + cbss_ies = rcu_dereference(cbss->ies); + eht_oper_ie = cfg80211_find_ext_ie(WLAN_EID_EXT_EHT_OPERATION, + cbss_ies->data, cbss_ies->len); + if (eht_oper_ie && eht_oper_ie[1] >= + 1 + sizeof(struct ieee80211_eht_operation)) + eht_oper = (void *)(eht_oper_ie + 3); + else + eht_oper = NULL; + } + + /* Allow VHT if at least one channel on the sband supports 80 MHz */ + have_80mhz = false; + for (i = 0; i < sband->n_channels; i++) { + if (sband->channels[i].flags & (IEEE80211_CHAN_DISABLED | + IEEE80211_CHAN_NO_80MHZ)) + continue; + + have_80mhz = true; + break; + } + + if (!have_80mhz) { + sdata_info(sdata, "80 MHz not supported, disabling VHT\n"); + *conn_flags |= IEEE80211_CONN_DISABLE_VHT; + } + + if (sband->band == NL80211_BAND_S1GHZ) { + s1g_oper = elems->s1g_oper; + if (!s1g_oper) + sdata_info(sdata, + "AP missing S1G operation element?\n"); + } + + *conn_flags |= + ieee80211_determine_chantype(sdata, link, *conn_flags, + sband, + cbss->channel, + bss->vht_cap_info, + ht_oper, vht_oper, + he_oper, eht_oper, + s1g_oper, + &chandef, false); + + if (link) + link->needed_rx_chains = + min(ieee80211_max_rx_chains(link, cbss), + local->rx_chains); + + rcu_read_unlock(); + /* the element data was RCU protected so no longer valid anyway */ + kfree(elems); + elems = NULL; + + if (*conn_flags & IEEE80211_CONN_DISABLE_HE && is_6ghz) { + sdata_info(sdata, "Rejecting non-HE 6/7 GHz connection"); + return -EINVAL; + } + + if (!link) + return 0; + + /* will change later if needed */ + link->smps_mode = IEEE80211_SMPS_OFF; + + mutex_lock(&local->mtx); + /* + * If this fails (possibly due to channel context sharing + * on incompatible channels, e.g. 80+80 and 160 sharing the + * same control channel) try to use a smaller bandwidth. + */ + ret = ieee80211_link_use_channel(link, &chandef, + IEEE80211_CHANCTX_SHARED); + + /* don't downgrade for 5 and 10 MHz channels, though. */ + if (chandef.width == NL80211_CHAN_WIDTH_5 || + chandef.width == NL80211_CHAN_WIDTH_10) + goto out; + + while (ret && chandef.width != NL80211_CHAN_WIDTH_20_NOHT) { + *conn_flags |= + ieee80211_chandef_downgrade(&chandef); + ret = ieee80211_link_use_channel(link, &chandef, + IEEE80211_CHANCTX_SHARED); + } + out: + mutex_unlock(&local->mtx); + return ret; +} + +static bool ieee80211_get_dtim(const struct cfg80211_bss_ies *ies, + u8 *dtim_count, u8 *dtim_period) +{ + const u8 *tim_ie = cfg80211_find_ie(WLAN_EID_TIM, ies->data, ies->len); + const u8 *idx_ie = cfg80211_find_ie(WLAN_EID_MULTI_BSSID_IDX, ies->data, + ies->len); + const struct ieee80211_tim_ie *tim = NULL; + const struct ieee80211_bssid_index *idx; + bool valid = tim_ie && tim_ie[1] >= 2; + + if (valid) + tim = (void *)(tim_ie + 2); + + if (dtim_count) + *dtim_count = valid ? tim->dtim_count : 0; + + if (dtim_period) + *dtim_period = valid ? 
tim->dtim_period : 0; + + /* Check if value is overridden by non-transmitted profile */ + if (!idx_ie || idx_ie[1] < 3) + return valid; + + idx = (void *)(idx_ie + 2); + + if (dtim_count) + *dtim_count = idx->dtim_count; + + if (dtim_period) + *dtim_period = idx->dtim_period; + + return true; +} + +static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + struct ieee802_11_elems *elems, + const u8 *elem_start, unsigned int elem_len) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; + struct ieee80211_local *local = sdata->local; + unsigned int link_id; + struct sta_info *sta; + u64 changed[IEEE80211_MLD_MAX_NUM_LINKS] = {}; + int err; + + mutex_lock(&sdata->local->sta_mtx); + /* + * station info was already allocated and inserted before + * the association and should be available to us + */ + sta = sta_info_get(sdata, assoc_data->ap_addr); + if (WARN_ON(!sta)) + goto out_err; + + if (sdata->vif.valid_links) { + u16 valid_links = 0; + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + if (!assoc_data->link[link_id].bss) + continue; + valid_links |= BIT(link_id); + + if (link_id != assoc_data->assoc_link_id) { + err = ieee80211_sta_allocate_link(sta, link_id); + if (err) + goto out_err; + } + } + + ieee80211_vif_set_links(sdata, valid_links); + } + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + struct ieee80211_link_data *link; + struct link_sta_info *link_sta; + + if (!assoc_data->link[link_id].bss) + continue; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (WARN_ON(!link)) + goto out_err; + + if (sdata->vif.valid_links) + link_info(link, + "local address %pM, AP link address %pM\n", + link->conf->addr, + assoc_data->link[link_id].bss->bssid); + + link_sta = rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&local->sta_mtx)); + if (WARN_ON(!link_sta)) + goto out_err; + + if (link_id != assoc_data->assoc_link_id) { + struct cfg80211_bss *cbss = assoc_data->link[link_id].bss; + const struct cfg80211_bss_ies *ies; + + rcu_read_lock(); + ies = rcu_dereference(cbss->ies); + ieee80211_get_dtim(ies, + &link->conf->sync_dtim_count, + &link->u.mgd.dtim_period); + link->conf->dtim_period = link->u.mgd.dtim_period ?: 1; + link->conf->beacon_int = cbss->beacon_interval; + rcu_read_unlock(); + + err = ieee80211_prep_channel(sdata, link, cbss, + &link->u.mgd.conn_flags); + if (err) { + link_info(link, "prep_channel failed\n"); + goto out_err; + } + } + + err = ieee80211_mgd_setup_link_sta(link, sta, link_sta, + assoc_data->link[link_id].bss); + if (err) + goto out_err; + + if (!ieee80211_assoc_config_link(link, link_sta, + assoc_data->link[link_id].bss, + mgmt, elem_start, elem_len, + &changed[link_id])) + goto out_err; + + if (link_id != assoc_data->assoc_link_id) { + err = ieee80211_sta_activate_link(sta, link_id); + if (err) + goto out_err; + } + } + + rate_control_rate_init(sta); + + if (ifmgd->flags & IEEE80211_STA_MFP_ENABLED) { + set_sta_flag(sta, WLAN_STA_MFP); + sta->sta.mfp = true; + } else { + sta->sta.mfp = false; + } + + ieee80211_sta_set_max_amsdu_subframes(sta, elems->ext_capab, + elems->ext_capab_len); + + sta->sta.wme = (elems->wmm_param || elems->s1g_capab) && + local->hw.queues >= IEEE80211_NUM_ACS; + + err = sta_info_move_state(sta, IEEE80211_STA_ASSOC); + if (!err && !(ifmgd->flags & IEEE80211_STA_CONTROL_PORT)) + err = sta_info_move_state(sta, IEEE80211_STA_AUTHORIZED); + if (err) { 
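+ /* Moving the station to ASSOC (or, when 802.1X port control is + * not in use, on to AUTHORIZED) failed; destroy the station + * entry and abort the association. + */ 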
+ sdata_info(sdata, + "failed to move station %pM to desired state\n", + sta->sta.addr); + WARN_ON(__sta_info_destroy(sta)); + goto out_err; + } + + if (sdata->wdev.use_4addr) + drv_sta_set_4addr(local, sdata, &sta->sta, true); + + mutex_unlock(&sdata->local->sta_mtx); + + ieee80211_set_associated(sdata, assoc_data, changed); + + /* + * If we're using 4-addr mode, let the AP know that we're + * doing so, so that it can create the STA VLAN on its side + */ + if (ifmgd->use_4addr) + ieee80211_send_4addr_nullfunc(local, sdata); + + /* + * Start timer to probe the connection to the AP now. + * Also start the timer that will detect beacon loss. + */ + ieee80211_sta_reset_beacon_monitor(sdata); + ieee80211_sta_reset_conn_monitor(sdata); + + return true; +out_err: + eth_zero_addr(sdata->vif.cfg.ap_addr); + mutex_unlock(&sdata->local->sta_mtx); + return false; +} + +static void ieee80211_rx_mgmt_assoc_resp(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data = ifmgd->assoc_data; + u16 capab_info, status_code, aid; + struct ieee80211_elems_parse_params parse_params = { + .bss = NULL, + .link_id = -1, + .from_ap = true, + }; + struct ieee802_11_elems *elems; + int ac; + const u8 *elem_start; + unsigned int elem_len; + bool reassoc; + struct ieee80211_event event = { + .type = MLME_EVENT, + .u.mlme.data = ASSOC_EVENT, + }; + struct ieee80211_prep_tx_info info = {}; + struct cfg80211_rx_assoc_resp resp = { + .uapsd_queues = -1, + }; + u8 ap_mld_addr[ETH_ALEN] __aligned(2); + unsigned int link_id; + + sdata_assert_lock(sdata); + + if (!assoc_data) + return; + + if (!ether_addr_equal(assoc_data->ap_addr, mgmt->bssid) || + !ether_addr_equal(assoc_data->ap_addr, mgmt->sa)) + return; + + /* + * AssocResp and ReassocResp have identical structure, so process both + * of them in this function. + */ + + if (len < 24 + 6) + return; + + reassoc = ieee80211_is_reassoc_resp(mgmt->frame_control); + capab_info = le16_to_cpu(mgmt->u.assoc_resp.capab_info); + status_code = le16_to_cpu(mgmt->u.assoc_resp.status_code); + if (assoc_data->s1g) + elem_start = mgmt->u.s1g_assoc_resp.variable; + else + elem_start = mgmt->u.assoc_resp.variable; + + /* + * Note: this may not be perfect, AP might misbehave - if + * anyone needs to rely on perfect complete notification + * with the exact right subtype, then we need to track what + * we actually transmitted. + */ + info.subtype = reassoc ? IEEE80211_STYPE_REASSOC_REQ : + IEEE80211_STYPE_ASSOC_REQ; + + if (assoc_data->fils_kek_len && + fils_decrypt_assoc_resp(sdata, (u8 *)mgmt, &len, assoc_data) < 0) + return; + + elem_len = len - (elem_start - (u8 *)mgmt); + parse_params.start = elem_start; + parse_params.len = elem_len; + elems = ieee802_11_parse_elems_full(&parse_params); + if (!elems) + goto notify_driver; + + if (elems->aid_resp) + aid = le16_to_cpu(elems->aid_resp->aid); + else if (assoc_data->s1g) + aid = 0; /* TODO */ + else + aid = le16_to_cpu(mgmt->u.assoc_resp.aid); + + /* + * The 5 MSB of the AID field are reserved + * (802.11-2016 9.4.1.8 AID field) + */ + aid &= 0x7ff; + + sdata_info(sdata, + "RX %sssocResp from %pM (capab=0x%x status=%d aid=%d)\n", + reassoc ? 
"Rea" : "A", assoc_data->ap_addr, + capab_info, status_code, (u16)(aid & ~(BIT(15) | BIT(14)))); + + ifmgd->broken_ap = false; + + if (status_code == WLAN_STATUS_ASSOC_REJECTED_TEMPORARILY && + elems->timeout_int && + elems->timeout_int->type == WLAN_TIMEOUT_ASSOC_COMEBACK) { + u32 tu, ms; + + cfg80211_assoc_comeback(sdata->dev, assoc_data->ap_addr, + le32_to_cpu(elems->timeout_int->value)); + + tu = le32_to_cpu(elems->timeout_int->value); + ms = tu * 1024 / 1000; + sdata_info(sdata, + "%pM rejected association temporarily; comeback duration %u TU (%u ms)\n", + assoc_data->ap_addr, tu, ms); + assoc_data->timeout = jiffies + msecs_to_jiffies(ms); + assoc_data->timeout_started = true; + if (ms > IEEE80211_ASSOC_TIMEOUT) + run_again(sdata, assoc_data->timeout); + goto notify_driver; + } + + if (status_code != WLAN_STATUS_SUCCESS) { + sdata_info(sdata, "%pM denied association (code=%d)\n", + assoc_data->ap_addr, status_code); + event.u.mlme.status = MLME_DENIED; + event.u.mlme.reason = status_code; + drv_event_callback(sdata->local, sdata, &event); + } else { + if (aid == 0 || aid > IEEE80211_MAX_AID) { + sdata_info(sdata, + "invalid AID value %d (out of range), turn off PS\n", + aid); + aid = 0; + ifmgd->broken_ap = true; + } + + if (sdata->vif.valid_links) { + if (!elems->multi_link) { + sdata_info(sdata, + "MLO association with %pM but no multi-link element in response!\n", + assoc_data->ap_addr); + goto abandon_assoc; + } + + if (le16_get_bits(elems->multi_link->control, + IEEE80211_ML_CONTROL_TYPE) != + IEEE80211_ML_CONTROL_TYPE_BASIC) { + sdata_info(sdata, + "bad multi-link element (control=0x%x)\n", + le16_to_cpu(elems->multi_link->control)); + goto abandon_assoc; + } else { + struct ieee80211_mle_basic_common_info *common; + + common = (void *)elems->multi_link->variable; + + if (memcmp(assoc_data->ap_addr, + common->mld_mac_addr, ETH_ALEN)) { + sdata_info(sdata, + "AP MLD MAC address mismatch: got %pM expected %pM\n", + common->mld_mac_addr, + assoc_data->ap_addr); + goto abandon_assoc; + } + } + } + + sdata->vif.cfg.aid = aid; + + if (!ieee80211_assoc_success(sdata, mgmt, elems, + elem_start, elem_len)) { + /* oops -- internal error -- send timeout for now */ + ieee80211_destroy_assoc_data(sdata, ASSOC_TIMEOUT); + goto notify_driver; + } + event.u.mlme.status = MLME_SUCCESS; + drv_event_callback(sdata->local, sdata, &event); + sdata_info(sdata, "associated\n"); + + info.success = 1; + } + + for (link_id = 0; link_id < IEEE80211_MLD_MAX_NUM_LINKS; link_id++) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (!link) + continue; + if (!assoc_data->link[link_id].bss) + continue; + resp.links[link_id].bss = assoc_data->link[link_id].bss; + resp.links[link_id].addr = link->conf->addr; + + /* get uapsd queues configuration - same for all links */ + resp.uapsd_queues = 0; + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + if (link->tx_conf[ac].uapsd) + resp.uapsd_queues |= ieee80211_ac_to_qos_mask[ac]; + } + + if (sdata->vif.valid_links) { + ether_addr_copy(ap_mld_addr, sdata->vif.cfg.ap_addr); + resp.ap_mld_addr = ap_mld_addr; + } + + ieee80211_destroy_assoc_data(sdata, + status_code == WLAN_STATUS_SUCCESS ? 
+ ASSOC_SUCCESS : + ASSOC_REJECTED); + + resp.buf = (u8 *)mgmt; + resp.len = len; + resp.req_ies = ifmgd->assoc_req_ies; + resp.req_ies_len = ifmgd->assoc_req_ies_len; + cfg80211_rx_assoc_resp(sdata->dev, &resp); +notify_driver: + drv_mgd_complete_tx(sdata->local, sdata, &info); + kfree(elems); + return; +abandon_assoc: + ieee80211_destroy_assoc_data(sdata, ASSOC_ABANDON); + goto notify_driver; +} + +static void ieee80211_rx_bss_info(struct ieee80211_link_data *link, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_rx_status *rx_status) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_bss *bss; + struct ieee80211_channel *channel; + + sdata_assert_lock(sdata); + + channel = ieee80211_get_channel_khz(local->hw.wiphy, + ieee80211_rx_status_to_khz(rx_status)); + if (!channel) + return; + + bss = ieee80211_bss_info_update(local, rx_status, mgmt, len, channel); + if (bss) { + link->conf->beacon_rate = bss->beacon_rate; + ieee80211_rx_bss_put(local, bss); + } +} + + +static void ieee80211_rx_mgmt_probe_resp(struct ieee80211_link_data *link, + struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_mgmt *mgmt = (void *)skb->data; + struct ieee80211_if_managed *ifmgd; + struct ieee80211_rx_status *rx_status = (void *) skb->cb; + struct ieee80211_channel *channel; + size_t baselen, len = skb->len; + + ifmgd = &sdata->u.mgd; + + sdata_assert_lock(sdata); + + /* + * According to Draft P802.11ax D6.0 clause 26.17.2.3.2: + * "If a 6 GHz AP receives a Probe Request frame and responds with + * a Probe Response frame [..], the Address 1 field of the Probe + * Response frame shall be set to the broadcast address [..]" + * So, on 6GHz band we should also accept broadcast responses. + */ + channel = ieee80211_get_channel(sdata->local->hw.wiphy, + rx_status->freq); + if (!channel) + return; + + if (!ether_addr_equal(mgmt->da, sdata->vif.addr) && + (channel->band != NL80211_BAND_6GHZ || + !is_broadcast_ether_addr(mgmt->da))) + return; /* ignore ProbeResp to foreign address */ + + baselen = (u8 *) mgmt->u.probe_resp.variable - (u8 *) mgmt; + if (baselen > len) + return; + + ieee80211_rx_bss_info(link, mgmt, len, rx_status); + + if (ifmgd->associated && + ether_addr_equal(mgmt->bssid, link->u.mgd.bssid)) + ieee80211_reset_ap_probe(sdata); +} + +/* + * This is the canonical list of information elements we care about, + * the filter code also gives us all changes to the Microsoft OUI + * (00:50:F2) vendor IE which is used for WMM which we need to track, + * as well as the DTPC IE (part of the Cisco OUI) used for signaling + * changes to requested client power. + * + * We implement beacon filtering in software since that means we can + * avoid processing the frame here and in cfg80211, and userspace + * will not be able to tell whether the hardware supports it or not. + * + * XXX: This list needs to be dynamic -- userspace needs to be able to + * add items it requires. It also needs to be able to tell us to + * look out for other vendor IEs. 
+ */ +static const u64 care_about_ies = + (1ULL << WLAN_EID_COUNTRY) | + (1ULL << WLAN_EID_ERP_INFO) | + (1ULL << WLAN_EID_CHANNEL_SWITCH) | + (1ULL << WLAN_EID_PWR_CONSTRAINT) | + (1ULL << WLAN_EID_HT_CAPABILITY) | + (1ULL << WLAN_EID_HT_OPERATION) | + (1ULL << WLAN_EID_EXT_CHANSWITCH_ANN); + +static void ieee80211_handle_beacon_sig(struct ieee80211_link_data *link, + struct ieee80211_if_managed *ifmgd, + struct ieee80211_bss_conf *bss_conf, + struct ieee80211_local *local, + struct ieee80211_rx_status *rx_status) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + + /* Track average RSSI from the Beacon frames of the current AP */ + + if (!link->u.mgd.tracking_signal_avg) { + link->u.mgd.tracking_signal_avg = true; + ewma_beacon_signal_init(&link->u.mgd.ave_beacon_signal); + link->u.mgd.last_cqm_event_signal = 0; + link->u.mgd.count_beacon_signal = 1; + link->u.mgd.last_ave_beacon_signal = 0; + } else { + link->u.mgd.count_beacon_signal++; + } + + ewma_beacon_signal_add(&link->u.mgd.ave_beacon_signal, + -rx_status->signal); + + if (ifmgd->rssi_min_thold != ifmgd->rssi_max_thold && + link->u.mgd.count_beacon_signal >= IEEE80211_SIGNAL_AVE_MIN_COUNT) { + int sig = -ewma_beacon_signal_read(&link->u.mgd.ave_beacon_signal); + int last_sig = link->u.mgd.last_ave_beacon_signal; + struct ieee80211_event event = { + .type = RSSI_EVENT, + }; + + /* + * if signal crosses either of the boundaries, invoke callback + * with appropriate parameters + */ + if (sig > ifmgd->rssi_max_thold && + (last_sig <= ifmgd->rssi_min_thold || last_sig == 0)) { + link->u.mgd.last_ave_beacon_signal = sig; + event.u.rssi.data = RSSI_EVENT_HIGH; + drv_event_callback(local, sdata, &event); + } else if (sig < ifmgd->rssi_min_thold && + (last_sig >= ifmgd->rssi_max_thold || + last_sig == 0)) { + link->u.mgd.last_ave_beacon_signal = sig; + event.u.rssi.data = RSSI_EVENT_LOW; + drv_event_callback(local, sdata, &event); + } + } + + if (bss_conf->cqm_rssi_thold && + link->u.mgd.count_beacon_signal >= IEEE80211_SIGNAL_AVE_MIN_COUNT && + !(sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_CQM_RSSI)) { + int sig = -ewma_beacon_signal_read(&link->u.mgd.ave_beacon_signal); + int last_event = link->u.mgd.last_cqm_event_signal; + int thold = bss_conf->cqm_rssi_thold; + int hyst = bss_conf->cqm_rssi_hyst; + + if (sig < thold && + (last_event == 0 || sig < last_event - hyst)) { + link->u.mgd.last_cqm_event_signal = sig; + ieee80211_cqm_rssi_notify( + &sdata->vif, + NL80211_CQM_RSSI_THRESHOLD_EVENT_LOW, + sig, GFP_KERNEL); + } else if (sig > thold && + (last_event == 0 || sig > last_event + hyst)) { + link->u.mgd.last_cqm_event_signal = sig; + ieee80211_cqm_rssi_notify( + &sdata->vif, + NL80211_CQM_RSSI_THRESHOLD_EVENT_HIGH, + sig, GFP_KERNEL); + } + } + + if (bss_conf->cqm_rssi_low && + link->u.mgd.count_beacon_signal >= IEEE80211_SIGNAL_AVE_MIN_COUNT) { + int sig = -ewma_beacon_signal_read(&link->u.mgd.ave_beacon_signal); + int last_event = link->u.mgd.last_cqm_event_signal; + int low = bss_conf->cqm_rssi_low; + int high = bss_conf->cqm_rssi_high; + + if (sig < low && + (last_event == 0 || last_event >= low)) { + link->u.mgd.last_cqm_event_signal = sig; + ieee80211_cqm_rssi_notify( + &sdata->vif, + NL80211_CQM_RSSI_THRESHOLD_EVENT_LOW, + sig, GFP_KERNEL); + } else if (sig > high && + (last_event == 0 || last_event <= high)) { + link->u.mgd.last_cqm_event_signal = sig; + ieee80211_cqm_rssi_notify( + &sdata->vif, + NL80211_CQM_RSSI_THRESHOLD_EVENT_HIGH, + sig, GFP_KERNEL); + } + } +} + +static bool ieee80211_rx_our_beacon(const u8 
*tx_bssid, + struct cfg80211_bss *bss) +{ + if (ether_addr_equal(tx_bssid, bss->bssid)) + return true; + if (!bss->transmitted_bss) + return false; + return ether_addr_equal(tx_bssid, bss->transmitted_bss->bssid); +} + +static void ieee80211_rx_mgmt_beacon(struct ieee80211_link_data *link, + struct ieee80211_hdr *hdr, size_t len, + struct ieee80211_rx_status *rx_status) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_bss_conf *bss_conf = &sdata->vif.bss_conf; + struct ieee80211_vif_cfg *vif_cfg = &sdata->vif.cfg; + struct ieee80211_mgmt *mgmt = (void *) hdr; + size_t baselen; + struct ieee802_11_elems *elems; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_channel *chan; + struct link_sta_info *link_sta; + struct sta_info *sta; + u32 changed = 0; + bool erp_valid; + u8 erp_value = 0; + u32 ncrc = 0; + u8 *bssid, *variable = mgmt->u.beacon.variable; + u8 deauth_buf[IEEE80211_DEAUTH_FRAME_LEN]; + struct ieee80211_elems_parse_params parse_params = { + .link_id = -1, + .from_ap = true, + }; + + sdata_assert_lock(sdata); + + /* Process beacon from the current BSS */ + bssid = ieee80211_get_bssid(hdr, len, sdata->vif.type); + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) { + struct ieee80211_ext *ext = (void *) mgmt; + + if (ieee80211_is_s1g_short_beacon(ext->frame_control)) + variable = ext->u.s1g_short_beacon.variable; + else + variable = ext->u.s1g_beacon.variable; + } + + baselen = (u8 *) variable - (u8 *) mgmt; + if (baselen > len) + return; + + parse_params.start = variable; + parse_params.len = len - baselen; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(link->conf->chanctx_conf); + if (!chanctx_conf) { + rcu_read_unlock(); + return; + } + + if (ieee80211_rx_status_to_khz(rx_status) != + ieee80211_channel_to_khz(chanctx_conf->def.chan)) { + rcu_read_unlock(); + return; + } + chan = chanctx_conf->def.chan; + rcu_read_unlock(); + + if (ifmgd->assoc_data && ifmgd->assoc_data->need_beacon && + !WARN_ON(sdata->vif.valid_links) && + ieee80211_rx_our_beacon(bssid, ifmgd->assoc_data->link[0].bss)) { + parse_params.bss = ifmgd->assoc_data->link[0].bss; + elems = ieee802_11_parse_elems_full(&parse_params); + if (!elems) + return; + + ieee80211_rx_bss_info(link, mgmt, len, rx_status); + + if (elems->dtim_period) + link->u.mgd.dtim_period = elems->dtim_period; + link->u.mgd.have_beacon = true; + ifmgd->assoc_data->need_beacon = false; + if (ieee80211_hw_check(&local->hw, TIMING_BEACON_ONLY)) { + link->conf->sync_tsf = + le64_to_cpu(mgmt->u.beacon.timestamp); + link->conf->sync_device_ts = + rx_status->device_timestamp; + link->conf->sync_dtim_count = elems->dtim_count; + } + + if (elems->mbssid_config_ie) + bss_conf->profile_periodicity = + elems->mbssid_config_ie->profile_periodicity; + else + bss_conf->profile_periodicity = 0; + + if (elems->ext_capab_len >= 11 && + (elems->ext_capab[10] & WLAN_EXT_CAPA11_EMA_SUPPORT)) + bss_conf->ema_ap = true; + else + bss_conf->ema_ap = false; + + /* continue assoc process */ + ifmgd->assoc_data->timeout = jiffies; + ifmgd->assoc_data->timeout_started = true; + run_again(sdata, ifmgd->assoc_data->timeout); + kfree(elems); + return; + } + + if (!ifmgd->associated || + !ieee80211_rx_our_beacon(bssid, link->u.mgd.bss)) + return; + bssid = link->u.mgd.bssid; + + if (!(rx_status->flag & RX_FLAG_NO_SIGNAL_VAL)) + ieee80211_handle_beacon_sig(link, ifmgd, bss_conf, + local, rx_status); + + if (ifmgd->flags & 
IEEE80211_STA_CONNECTION_POLL) { + mlme_dbg_ratelimited(sdata, + "cancelling AP probe due to a received beacon\n"); + ieee80211_reset_ap_probe(sdata); + } + + /* + * Push the beacon loss detection into the future since + * we are processing a beacon from the AP just now. + */ + ieee80211_sta_reset_beacon_monitor(sdata); + + /* TODO: CRC currently not calculated on S1G Beacon Compatibility + * element (which carries the beacon interval). Don't forget to add a + * bit to care_about_ies[] above if mac80211 is interested in a + * changing S1G element. + */ + if (!ieee80211_is_s1g_beacon(hdr->frame_control)) + ncrc = crc32_be(0, (void *)&mgmt->u.beacon.beacon_int, 4); + parse_params.bss = link->u.mgd.bss; + parse_params.filter = care_about_ies; + parse_params.crc = ncrc; + elems = ieee802_11_parse_elems_full(&parse_params); + if (!elems) + return; + ncrc = elems->crc; + + if (ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK) && + ieee80211_check_tim(elems->tim, elems->tim_len, vif_cfg->aid)) { + if (local->hw.conf.dynamic_ps_timeout > 0) { + if (local->hw.conf.flags & IEEE80211_CONF_PS) { + local->hw.conf.flags &= ~IEEE80211_CONF_PS; + ieee80211_hw_config(local, + IEEE80211_CONF_CHANGE_PS); + } + ieee80211_send_nullfunc(local, sdata, false); + } else if (!local->pspolling && sdata->u.mgd.powersave) { + local->pspolling = true; + + /* + * It is assumed here that the driver will be + * able to send a PS-Poll frame and receive a + * response even though power save mode is + * enabled, but some drivers might require + * disabling power save here. This needs + * to be investigated. + */ + ieee80211_send_pspoll(local, sdata); + } + } + + if (sdata->vif.p2p || + sdata->vif.driver_flags & IEEE80211_VIF_GET_NOA_UPDATE) { + struct ieee80211_p2p_noa_attr noa = {}; + int ret; + + ret = cfg80211_get_p2p_attr(variable, + len - baselen, + IEEE80211_P2P_ATTR_ABSENCE_NOTICE, + (u8 *) &noa, sizeof(noa)); + if (ret >= 2) { + if (link->u.mgd.p2p_noa_index != noa.index) { + /* valid noa_attr and index changed */ + link->u.mgd.p2p_noa_index = noa.index; + memcpy(&bss_conf->p2p_noa_attr, &noa, sizeof(noa)); + changed |= BSS_CHANGED_P2P_PS; + /* + * make sure we update all information, the CRC + * mechanism doesn't look at P2P attributes. + */ + link->u.mgd.beacon_crc_valid = false; + } + } else if (link->u.mgd.p2p_noa_index != -1) { + /* noa_attr not found and we had valid noa_attr before */ + link->u.mgd.p2p_noa_index = -1; + memset(&bss_conf->p2p_noa_attr, 0, sizeof(bss_conf->p2p_noa_attr)); + changed |= BSS_CHANGED_P2P_PS; + link->u.mgd.beacon_crc_valid = false; + } + } + + if (link->u.mgd.csa_waiting_bcn) + ieee80211_chswitch_post_beacon(link); + + /* + * Update beacon timing and dtim count on every beacon appearance. This + * will allow the driver to use the most up-to-date values. Do it before + * comparing this one with the last received beacon. + * IMPORTANT: These parameters may be out of sync by the time + * the driver uses them. The synchronized view is currently + * guaranteed only in certain callbacks. 
+ */ + if (ieee80211_hw_check(&local->hw, TIMING_BEACON_ONLY) && + !ieee80211_is_s1g_beacon(hdr->frame_control)) { + link->conf->sync_tsf = + le64_to_cpu(mgmt->u.beacon.timestamp); + link->conf->sync_device_ts = + rx_status->device_timestamp; + link->conf->sync_dtim_count = elems->dtim_count; + } + + if ((ncrc == link->u.mgd.beacon_crc && link->u.mgd.beacon_crc_valid) || + ieee80211_is_s1g_short_beacon(mgmt->frame_control)) + goto free; + link->u.mgd.beacon_crc = ncrc; + link->u.mgd.beacon_crc_valid = true; + + ieee80211_rx_bss_info(link, mgmt, len, rx_status); + + ieee80211_sta_process_chanswitch(link, rx_status->mactime, + rx_status->device_timestamp, + elems, true); + + if (!link->u.mgd.disable_wmm_tracking && + ieee80211_sta_wmm_params(local, link, elems->wmm_param, + elems->wmm_param_len, + elems->mu_edca_param_set)) + changed |= BSS_CHANGED_QOS; + + /* + * If we haven't had a beacon before, tell the driver about the + * DTIM period (and beacon timing if desired) now. + */ + if (!link->u.mgd.have_beacon) { + /* a few bogus AP send dtim_period = 0 or no TIM IE */ + bss_conf->dtim_period = elems->dtim_period ?: 1; + + changed |= BSS_CHANGED_BEACON_INFO; + link->u.mgd.have_beacon = true; + + mutex_lock(&local->iflist_mtx); + ieee80211_recalc_ps(local); + mutex_unlock(&local->iflist_mtx); + + ieee80211_recalc_ps_vif(sdata); + } + + if (elems->erp_info) { + erp_valid = true; + erp_value = elems->erp_info[0]; + } else { + erp_valid = false; + } + + if (!ieee80211_is_s1g_beacon(hdr->frame_control)) + changed |= ieee80211_handle_bss_capability(link, + le16_to_cpu(mgmt->u.beacon.capab_info), + erp_valid, erp_value); + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, sdata->vif.cfg.ap_addr); + if (WARN_ON(!sta)) { + mutex_unlock(&local->sta_mtx); + goto free; + } + link_sta = rcu_dereference_protected(sta->link[link->link_id], + lockdep_is_held(&local->sta_mtx)); + if (WARN_ON(!link_sta)) { + mutex_unlock(&local->sta_mtx); + goto free; + } + + changed |= ieee80211_recalc_twt_req(link, link_sta, elems); + + if (ieee80211_config_bw(link, elems->ht_cap_elem, + elems->vht_cap_elem, elems->ht_operation, + elems->vht_operation, elems->he_operation, + elems->eht_operation, + elems->s1g_oper, bssid, &changed)) { + mutex_unlock(&local->sta_mtx); + sdata_info(sdata, + "failed to follow AP %pM bandwidth change, disconnect\n", + bssid); + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, + WLAN_REASON_DEAUTH_LEAVING, + true, deauth_buf); + ieee80211_report_disconnect(sdata, deauth_buf, + sizeof(deauth_buf), true, + WLAN_REASON_DEAUTH_LEAVING, + false); + goto free; + } + + if (sta && elems->opmode_notif) + ieee80211_vht_handle_opmode(sdata, link_sta, + *elems->opmode_notif, + rx_status->band); + mutex_unlock(&local->sta_mtx); + + changed |= ieee80211_handle_pwr_constr(link, chan, mgmt, + elems->country_elem, + elems->country_elem_len, + elems->pwr_constr_elem, + elems->cisco_dtpc_elem); + + ieee80211_link_info_change_notify(sdata, link, changed); +free: + kfree(elems); +} + +void ieee80211_sta_rx_queued_ext(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_link_data *link = &sdata->deflink; + struct ieee80211_rx_status *rx_status; + struct ieee80211_hdr *hdr; + u16 fc; + + rx_status = (struct ieee80211_rx_status *) skb->cb; + hdr = (struct ieee80211_hdr *) skb->data; + fc = le16_to_cpu(hdr->frame_control); + + sdata_lock(sdata); + switch (fc & IEEE80211_FCTL_STYPE) { + case IEEE80211_STYPE_S1G_BEACON: + ieee80211_rx_mgmt_beacon(link, hdr, skb->len, rx_status); 
+ break; + } + sdata_unlock(sdata); +} + +void ieee80211_sta_rx_queued_mgmt(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_link_data *link = &sdata->deflink; + struct ieee80211_rx_status *rx_status; + struct ieee80211_mgmt *mgmt; + u16 fc; + int ies_len; + + rx_status = (struct ieee80211_rx_status *) skb->cb; + mgmt = (struct ieee80211_mgmt *) skb->data; + fc = le16_to_cpu(mgmt->frame_control); + + sdata_lock(sdata); + + if (rx_status->link_valid) { + link = sdata_dereference(sdata->link[rx_status->link_id], + sdata); + if (!link) + goto out; + } + + switch (fc & IEEE80211_FCTL_STYPE) { + case IEEE80211_STYPE_BEACON: + ieee80211_rx_mgmt_beacon(link, (void *)mgmt, + skb->len, rx_status); + break; + case IEEE80211_STYPE_PROBE_RESP: + ieee80211_rx_mgmt_probe_resp(link, skb); + break; + case IEEE80211_STYPE_AUTH: + ieee80211_rx_mgmt_auth(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_DEAUTH: + ieee80211_rx_mgmt_deauth(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_DISASSOC: + ieee80211_rx_mgmt_disassoc(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_ASSOC_RESP: + case IEEE80211_STYPE_REASSOC_RESP: + ieee80211_rx_mgmt_assoc_resp(sdata, mgmt, skb->len); + break; + case IEEE80211_STYPE_ACTION: + if (mgmt->u.action.category == WLAN_CATEGORY_SPECTRUM_MGMT) { + struct ieee802_11_elems *elems; + + ies_len = skb->len - + offsetof(struct ieee80211_mgmt, + u.action.u.chan_switch.variable); + + if (ies_len < 0) + break; + + /* CSA IE cannot be overridden, no need for BSSID */ + elems = ieee802_11_parse_elems( + mgmt->u.action.u.chan_switch.variable, + ies_len, true, NULL); + + if (elems && !elems->parse_error) + ieee80211_sta_process_chanswitch(link, + rx_status->mactime, + rx_status->device_timestamp, + elems, false); + kfree(elems); + } else if (mgmt->u.action.category == WLAN_CATEGORY_PUBLIC) { + struct ieee802_11_elems *elems; + + ies_len = skb->len - + offsetof(struct ieee80211_mgmt, + u.action.u.ext_chan_switch.variable); + + if (ies_len < 0) + break; + + /* + * extended CSA IE can't be overridden, no need for + * BSSID + */ + elems = ieee802_11_parse_elems( + mgmt->u.action.u.ext_chan_switch.variable, + ies_len, true, NULL); + + if (elems && !elems->parse_error) { + /* for the handling code pretend it was an IE */ + elems->ext_chansw_ie = + &mgmt->u.action.u.ext_chan_switch.data; + + ieee80211_sta_process_chanswitch(link, + rx_status->mactime, + rx_status->device_timestamp, + elems, false); + } + + kfree(elems); + } + break; + } +out: + sdata_unlock(sdata); +} + +static void ieee80211_sta_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mgd.timer); + + ieee80211_queue_work(&sdata->local->hw, &sdata->work); +} + +void ieee80211_sta_connection_lost(struct ieee80211_sub_if_data *sdata, + u8 reason, bool tx) +{ + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, reason, + tx, frame_buf); + + ieee80211_report_disconnect(sdata, frame_buf, sizeof(frame_buf), true, + reason, false); +} + +static int ieee80211_auth(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_auth_data *auth_data = ifmgd->auth_data; + u32 tx_flags = 0; + u16 trans = 1; + u16 status = 0; + struct ieee80211_prep_tx_info info = { + .subtype = IEEE80211_STYPE_AUTH, + }; + + sdata_assert_lock(sdata); + + if (WARN_ON_ONCE(!auth_data)) + return -EINVAL; + + 
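+ /* Each call counts as one transmission attempt; after + * IEEE80211_AUTH_MAX_TRIES unanswered attempts the BSS entry is + * unlinked and the authentication fails with -ETIMEDOUT. + */ 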
auth_data->tries++; + + if (auth_data->tries > IEEE80211_AUTH_MAX_TRIES) { + sdata_info(sdata, "authentication with %pM timed out\n", + auth_data->ap_addr); + + /* + * Most likely AP is not in the range so remove the + * bss struct for that AP. + */ + cfg80211_unlink_bss(local->hw.wiphy, auth_data->bss); + + return -ETIMEDOUT; + } + + if (auth_data->algorithm == WLAN_AUTH_SAE) + info.duration = jiffies_to_msecs(IEEE80211_AUTH_TIMEOUT_SAE); + + drv_mgd_prepare_tx(local, sdata, &info); + + sdata_info(sdata, "send auth to %pM (try %d/%d)\n", + auth_data->ap_addr, auth_data->tries, + IEEE80211_AUTH_MAX_TRIES); + + auth_data->expected_transaction = 2; + + if (auth_data->algorithm == WLAN_AUTH_SAE) { + trans = auth_data->sae_trans; + status = auth_data->sae_status; + auth_data->expected_transaction = trans; + } + + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + tx_flags = IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_INTFL_MLME_CONN_TX; + + ieee80211_send_auth(sdata, trans, auth_data->algorithm, status, + auth_data->data, auth_data->data_len, + auth_data->ap_addr, auth_data->ap_addr, + NULL, 0, 0, tx_flags); + + if (tx_flags == 0) { + if (auth_data->algorithm == WLAN_AUTH_SAE) + auth_data->timeout = jiffies + + IEEE80211_AUTH_TIMEOUT_SAE; + else + auth_data->timeout = jiffies + IEEE80211_AUTH_TIMEOUT; + } else { + auth_data->timeout = + round_jiffies_up(jiffies + IEEE80211_AUTH_TIMEOUT_LONG); + } + + auth_data->timeout_started = true; + run_again(sdata, auth_data->timeout); + + return 0; +} + +static int ieee80211_do_assoc(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_mgd_assoc_data *assoc_data = sdata->u.mgd.assoc_data; + struct ieee80211_local *local = sdata->local; + int ret; + + sdata_assert_lock(sdata); + + assoc_data->tries++; + if (assoc_data->tries > IEEE80211_ASSOC_MAX_TRIES) { + sdata_info(sdata, "association with %pM timed out\n", + assoc_data->ap_addr); + + /* + * Most likely AP is not in the range so remove the + * bss struct for that AP. 
+ */ + cfg80211_unlink_bss(local->hw.wiphy, + assoc_data->link[assoc_data->assoc_link_id].bss); + + return -ETIMEDOUT; + } + + sdata_info(sdata, "associate with %pM (try %d/%d)\n", + assoc_data->ap_addr, assoc_data->tries, + IEEE80211_ASSOC_MAX_TRIES); + ret = ieee80211_send_assoc(sdata); + if (ret) + return ret; + + if (!ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { + assoc_data->timeout = jiffies + IEEE80211_ASSOC_TIMEOUT; + assoc_data->timeout_started = true; + run_again(sdata, assoc_data->timeout); + } else { + assoc_data->timeout = + round_jiffies_up(jiffies + + IEEE80211_ASSOC_TIMEOUT_LONG); + assoc_data->timeout_started = true; + run_again(sdata, assoc_data->timeout); + } + + return 0; +} + +void ieee80211_mgd_conn_tx_status(struct ieee80211_sub_if_data *sdata, + __le16 fc, bool acked) +{ + struct ieee80211_local *local = sdata->local; + + sdata->u.mgd.status_fc = fc; + sdata->u.mgd.status_acked = acked; + sdata->u.mgd.status_received = true; + + ieee80211_queue_work(&local->hw, &sdata->work); +} + +void ieee80211_sta_work(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + sdata_lock(sdata); + + if (ifmgd->status_received) { + __le16 fc = ifmgd->status_fc; + bool status_acked = ifmgd->status_acked; + + ifmgd->status_received = false; + if (ifmgd->auth_data && ieee80211_is_auth(fc)) { + if (status_acked) { + if (ifmgd->auth_data->algorithm == + WLAN_AUTH_SAE) + ifmgd->auth_data->timeout = + jiffies + + IEEE80211_AUTH_TIMEOUT_SAE; + else + ifmgd->auth_data->timeout = + jiffies + + IEEE80211_AUTH_TIMEOUT_SHORT; + run_again(sdata, ifmgd->auth_data->timeout); + } else { + ifmgd->auth_data->timeout = jiffies - 1; + } + ifmgd->auth_data->timeout_started = true; + } else if (ifmgd->assoc_data && + (ieee80211_is_assoc_req(fc) || + ieee80211_is_reassoc_req(fc))) { + if (status_acked) { + ifmgd->assoc_data->timeout = + jiffies + IEEE80211_ASSOC_TIMEOUT_SHORT; + run_again(sdata, ifmgd->assoc_data->timeout); + } else { + ifmgd->assoc_data->timeout = jiffies - 1; + } + ifmgd->assoc_data->timeout_started = true; + } + } + + if (ifmgd->auth_data && ifmgd->auth_data->timeout_started && + time_after(jiffies, ifmgd->auth_data->timeout)) { + if (ifmgd->auth_data->done || ifmgd->auth_data->waiting) { + /* + * ok ... 
we waited for assoc or continuation but + * userspace didn't do it, so kill the auth data + */ + ieee80211_destroy_auth_data(sdata, false); + } else if (ieee80211_auth(sdata)) { + u8 ap_addr[ETH_ALEN]; + struct ieee80211_event event = { + .type = MLME_EVENT, + .u.mlme.data = AUTH_EVENT, + .u.mlme.status = MLME_TIMEOUT, + }; + + memcpy(ap_addr, ifmgd->auth_data->ap_addr, ETH_ALEN); + + ieee80211_destroy_auth_data(sdata, false); + + cfg80211_auth_timeout(sdata->dev, ap_addr); + drv_event_callback(sdata->local, sdata, &event); + } + } else if (ifmgd->auth_data && ifmgd->auth_data->timeout_started) + run_again(sdata, ifmgd->auth_data->timeout); + + if (ifmgd->assoc_data && ifmgd->assoc_data->timeout_started && + time_after(jiffies, ifmgd->assoc_data->timeout)) { + if ((ifmgd->assoc_data->need_beacon && + !sdata->deflink.u.mgd.have_beacon) || + ieee80211_do_assoc(sdata)) { + struct ieee80211_event event = { + .type = MLME_EVENT, + .u.mlme.data = ASSOC_EVENT, + .u.mlme.status = MLME_TIMEOUT, + }; + + ieee80211_destroy_assoc_data(sdata, ASSOC_TIMEOUT); + drv_event_callback(sdata->local, sdata, &event); + } + } else if (ifmgd->assoc_data && ifmgd->assoc_data->timeout_started) + run_again(sdata, ifmgd->assoc_data->timeout); + + if (ifmgd->flags & IEEE80211_STA_CONNECTION_POLL && + ifmgd->associated) { + u8 *bssid = sdata->deflink.u.mgd.bssid; + int max_tries; + + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) + max_tries = max_nullfunc_tries; + else + max_tries = max_probe_tries; + + /* ACK received for nullfunc probing frame */ + if (!ifmgd->probe_send_count) + ieee80211_reset_ap_probe(sdata); + else if (ifmgd->nullfunc_failed) { + if (ifmgd->probe_send_count < max_tries) { + mlme_dbg(sdata, + "No ack for nullfunc frame to AP %pM, try %d/%i\n", + bssid, ifmgd->probe_send_count, + max_tries); + ieee80211_mgd_probe_ap_send(sdata); + } else { + mlme_dbg(sdata, + "No ack for nullfunc frame to AP %pM, disconnecting.\n", + bssid); + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_DISASSOC_DUE_TO_INACTIVITY, + false); + } + } else if (time_is_after_jiffies(ifmgd->probe_timeout)) + run_again(sdata, ifmgd->probe_timeout); + else if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { + mlme_dbg(sdata, + "Failed to send nullfunc to AP %pM after %dms, disconnecting\n", + bssid, probe_wait_ms); + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_DISASSOC_DUE_TO_INACTIVITY, false); + } else if (ifmgd->probe_send_count < max_tries) { + mlme_dbg(sdata, + "No probe response from AP %pM after %dms, try %d/%i\n", + bssid, probe_wait_ms, + ifmgd->probe_send_count, max_tries); + ieee80211_mgd_probe_ap_send(sdata); + } else { + /* + * We actually lost the connection ... or did we? + * Let's make sure! 
+ */ + mlme_dbg(sdata, + "No probe response from AP %pM after %dms, disconnecting.\n", + bssid, probe_wait_ms); + + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_DISASSOC_DUE_TO_INACTIVITY, false); + } + } + + sdata_unlock(sdata); +} + +static void ieee80211_sta_bcn_mon_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mgd.bcn_mon_timer); + + if (WARN_ON(sdata->vif.valid_links)) + return; + + if (sdata->vif.bss_conf.csa_active && + !sdata->deflink.u.mgd.csa_waiting_bcn) + return; + + if (sdata->vif.driver_flags & IEEE80211_VIF_BEACON_FILTER) + return; + + sdata->u.mgd.connection_loss = false; + ieee80211_queue_work(&sdata->local->hw, + &sdata->u.mgd.beacon_connection_loss_work); +} + +static void ieee80211_sta_conn_mon_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.mgd.conn_mon_timer); + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + unsigned long timeout; + + if (WARN_ON(sdata->vif.valid_links)) + return; + + if (sdata->vif.bss_conf.csa_active && + !sdata->deflink.u.mgd.csa_waiting_bcn) + return; + + sta = sta_info_get(sdata, sdata->vif.cfg.ap_addr); + if (!sta) + return; + + timeout = sta->deflink.status_stats.last_ack; + if (time_before(sta->deflink.status_stats.last_ack, sta->deflink.rx_stats.last_rx)) + timeout = sta->deflink.rx_stats.last_rx; + timeout += IEEE80211_CONNECTION_IDLE_TIME; + + /* If timeout is after now, then update timer to fire at + * the later date, but do not actually probe at this time. + */ + if (time_is_after_jiffies(timeout)) { + mod_timer(&ifmgd->conn_mon_timer, round_jiffies_up(timeout)); + return; + } + + ieee80211_queue_work(&local->hw, &ifmgd->monitor_work); +} + +static void ieee80211_sta_monitor_work(struct work_struct *work) +{ + struct ieee80211_sub_if_data *sdata = + container_of(work, struct ieee80211_sub_if_data, + u.mgd.monitor_work); + + ieee80211_mgd_probe_ap(sdata, false); +} + +static void ieee80211_restart_sta_timer(struct ieee80211_sub_if_data *sdata) +{ + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + __ieee80211_stop_poll(sdata); + + /* let's probe the connection once */ + if (!ieee80211_hw_check(&sdata->local->hw, CONNECTION_MONITOR)) + ieee80211_queue_work(&sdata->local->hw, + &sdata->u.mgd.monitor_work); + } +} + +#ifdef CONFIG_PM +void ieee80211_mgd_quiesce(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + sdata_lock(sdata); + + if (ifmgd->auth_data || ifmgd->assoc_data) { + const u8 *ap_addr = ifmgd->auth_data ? + ifmgd->auth_data->ap_addr : + ifmgd->assoc_data->ap_addr; + + /* + * If we are trying to authenticate / associate while suspending, + * cfg80211 won't know and won't actually abort those attempts, + * thus we need to do that ourselves. + */ + ieee80211_send_deauth_disassoc(sdata, ap_addr, ap_addr, + IEEE80211_STYPE_DEAUTH, + WLAN_REASON_DEAUTH_LEAVING, + false, frame_buf); + if (ifmgd->assoc_data) + ieee80211_destroy_assoc_data(sdata, ASSOC_ABANDON); + if (ifmgd->auth_data) + ieee80211_destroy_auth_data(sdata, false); + cfg80211_tx_mlme_mgmt(sdata->dev, frame_buf, + IEEE80211_DEAUTH_FRAME_LEN, + false); + } + + /* This is a bit of a hack - we should find a better and more generic + * solution to this. Normally when suspending, cfg80211 will in fact + * deauthenticate. 
However, it doesn't (and cannot) stop an ongoing + * auth (not so important) or assoc (this is the problem) process. + * + * As a consequence, it can happen that we are in the process of both + * associating and suspending, and receive an association response + * after cfg80211 has checked if it needs to disconnect, but before + * we actually set the flag to drop incoming frames. This will then + * cause the workqueue flush to process the association response in + * the suspend, resulting in a successful association just before it + * tries to remove the interface from the driver, which now though + * has a channel context assigned ... this results in issues. + * + * To work around this (for now) simply deauth here again if we're + * now connected. + */ + if (ifmgd->associated && !sdata->local->wowlan) { + u8 bssid[ETH_ALEN]; + struct cfg80211_deauth_request req = { + .reason_code = WLAN_REASON_DEAUTH_LEAVING, + .bssid = bssid, + }; + + memcpy(bssid, sdata->vif.cfg.ap_addr, ETH_ALEN); + ieee80211_mgd_deauth(sdata, &req); + } + + sdata_unlock(sdata); +} +#endif + +void ieee80211_sta_restart(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + sdata_lock(sdata); + if (!ifmgd->associated) { + sdata_unlock(sdata); + return; + } + + if (sdata->flags & IEEE80211_SDATA_DISCONNECT_RESUME) { + sdata->flags &= ~IEEE80211_SDATA_DISCONNECT_RESUME; + mlme_dbg(sdata, "driver requested disconnect after resume\n"); + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_UNSPECIFIED, + true); + sdata_unlock(sdata); + return; + } + + if (sdata->flags & IEEE80211_SDATA_DISCONNECT_HW_RESTART) { + sdata->flags &= ~IEEE80211_SDATA_DISCONNECT_HW_RESTART; + mlme_dbg(sdata, "driver requested disconnect after hardware restart\n"); + ieee80211_sta_connection_lost(sdata, + WLAN_REASON_UNSPECIFIED, + true); + sdata_unlock(sdata); + return; + } + + sdata_unlock(sdata); +} + +static void ieee80211_request_smps_mgd_work(struct work_struct *work) +{ + struct ieee80211_link_data *link = + container_of(work, struct ieee80211_link_data, + u.mgd.request_smps_work); + + sdata_lock(link->sdata); + __ieee80211_request_smps_mgd(link->sdata, link, + link->u.mgd.driver_smps_mode); + sdata_unlock(link->sdata); +} + +/* interface setup */ +void ieee80211_sta_setup_sdata(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + INIT_WORK(&ifmgd->monitor_work, ieee80211_sta_monitor_work); + INIT_WORK(&ifmgd->beacon_connection_loss_work, + ieee80211_beacon_connection_loss_work); + INIT_WORK(&ifmgd->csa_connection_drop_work, + ieee80211_csa_connection_drop_work); + INIT_DELAYED_WORK(&ifmgd->tdls_peer_del_work, + ieee80211_tdls_peer_del_work); + timer_setup(&ifmgd->timer, ieee80211_sta_timer, 0); + timer_setup(&ifmgd->bcn_mon_timer, ieee80211_sta_bcn_mon_timer, 0); + timer_setup(&ifmgd->conn_mon_timer, ieee80211_sta_conn_mon_timer, 0); + INIT_DELAYED_WORK(&ifmgd->tx_tspec_wk, + ieee80211_sta_handle_tspec_ac_params_wk); + + ifmgd->flags = 0; + ifmgd->powersave = sdata->wdev.ps; + ifmgd->uapsd_queues = sdata->local->hw.uapsd_queues; + ifmgd->uapsd_max_sp_len = sdata->local->hw.uapsd_max_sp_len; + /* Setup TDLS data */ + spin_lock_init(&ifmgd->teardown_lock); + ifmgd->teardown_skb = NULL; + ifmgd->orig_teardown_skb = NULL; +} + +void ieee80211_mgd_setup_link(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + unsigned int link_id = link->link_id; + + 
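+ /* Reset the per-link MLME state: no P2P NoA attribute tracked yet, + * no connection restriction flags, and expose the MLME BSSID storage + * through the link's bss_conf. + */ 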
link->u.mgd.p2p_noa_index = -1; + link->u.mgd.conn_flags = 0; + link->conf->bssid = link->u.mgd.bssid; + + INIT_WORK(&link->u.mgd.request_smps_work, + ieee80211_request_smps_mgd_work); + if (local->hw.wiphy->features & NL80211_FEATURE_DYNAMIC_SMPS) + link->u.mgd.req_smps = IEEE80211_SMPS_AUTOMATIC; + else + link->u.mgd.req_smps = IEEE80211_SMPS_OFF; + + INIT_WORK(&link->u.mgd.chswitch_work, ieee80211_chswitch_work); + timer_setup(&link->u.mgd.chswitch_timer, ieee80211_chswitch_timer, 0); + + if (sdata->u.mgd.assoc_data) + ether_addr_copy(link->conf->addr, + sdata->u.mgd.assoc_data->link[link_id].addr); + else if (!is_valid_ether_addr(link->conf->addr)) + eth_random_addr(link->conf->addr); +} + +/* scan finished notification */ +void ieee80211_mlme_notify_scan_completed(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + /* Restart STA timers */ + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (ieee80211_sdata_running(sdata)) + ieee80211_restart_sta_timer(sdata); + } + rcu_read_unlock(); +} + +static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata, + struct cfg80211_bss *cbss, s8 link_id, + const u8 *ap_mld_addr, bool assoc, + bool override) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_bss *bss = (void *)cbss->priv; + struct sta_info *new_sta = NULL; + struct ieee80211_link_data *link; + bool have_sta = false; + bool mlo; + int err; + + if (link_id >= 0) { + mlo = true; + if (WARN_ON(!ap_mld_addr)) + return -EINVAL; + err = ieee80211_vif_set_links(sdata, BIT(link_id)); + } else { + if (WARN_ON(ap_mld_addr)) + return -EINVAL; + ap_mld_addr = cbss->bssid; + err = ieee80211_vif_set_links(sdata, 0); + link_id = 0; + mlo = false; + } + + if (err) + return err; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (WARN_ON(!link)) { + err = -ENOLINK; + goto out_err; + } + + if (WARN_ON(!ifmgd->auth_data && !ifmgd->assoc_data)) { + err = -EINVAL; + goto out_err; + } + + /* If a reconfig is happening, bail out */ + if (local->in_reconfig) { + err = -EBUSY; + goto out_err; + } + + if (assoc) { + rcu_read_lock(); + have_sta = sta_info_get(sdata, ap_mld_addr); + rcu_read_unlock(); + } + + if (!have_sta) { + if (mlo) + new_sta = sta_info_alloc_with_link(sdata, ap_mld_addr, + link_id, cbss->bssid, + GFP_KERNEL); + else + new_sta = sta_info_alloc(sdata, ap_mld_addr, GFP_KERNEL); + + if (!new_sta) { + err = -ENOMEM; + goto out_err; + } + + new_sta->sta.mlo = mlo; + } + + /* + * Set up the information for the new channel before setting the + * new channel. We can't - completely race-free - change the basic + * rates bitmap and the channel (sband) that it refers to, but if + * we set it up before we at least avoid calling into the driver's + * bss_info_changed() method with invalid information (since we do + * call that from changing the channel - only for IDLE and perhaps + * some others, but ...). + * + * So to avoid that, just set up all the new information before the + * channel, but tell the driver to apply it only afterwards, since + * it might need the new channel for that. 
+ */ + if (new_sta) { + const struct cfg80211_bss_ies *ies; + struct link_sta_info *link_sta; + + rcu_read_lock(); + link_sta = rcu_dereference(new_sta->link[link_id]); + if (WARN_ON(!link_sta)) { + rcu_read_unlock(); + sta_info_free(local, new_sta); + err = -EINVAL; + goto out_err; + } + + err = ieee80211_mgd_setup_link_sta(link, new_sta, + link_sta, cbss); + if (err) { + rcu_read_unlock(); + sta_info_free(local, new_sta); + goto out_err; + } + + memcpy(link->u.mgd.bssid, cbss->bssid, ETH_ALEN); + + /* set timing information */ + link->conf->beacon_int = cbss->beacon_interval; + ies = rcu_dereference(cbss->beacon_ies); + if (ies) { + link->conf->sync_tsf = ies->tsf; + link->conf->sync_device_ts = + bss->device_ts_beacon; + + ieee80211_get_dtim(ies, + &link->conf->sync_dtim_count, + NULL); + } else if (!ieee80211_hw_check(&sdata->local->hw, + TIMING_BEACON_ONLY)) { + ies = rcu_dereference(cbss->proberesp_ies); + /* must be non-NULL since beacon IEs were NULL */ + link->conf->sync_tsf = ies->tsf; + link->conf->sync_device_ts = + bss->device_ts_presp; + link->conf->sync_dtim_count = 0; + } else { + link->conf->sync_tsf = 0; + link->conf->sync_device_ts = 0; + link->conf->sync_dtim_count = 0; + } + rcu_read_unlock(); + } + + if (new_sta || override) { + err = ieee80211_prep_channel(sdata, link, cbss, + &link->u.mgd.conn_flags); + if (err) { + if (new_sta) + sta_info_free(local, new_sta); + goto out_err; + } + } + + if (new_sta) { + /* + * tell driver about BSSID, basic rates and timing + * this was set up above, before setting the channel + */ + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_BSSID | + BSS_CHANGED_BASIC_RATES | + BSS_CHANGED_BEACON_INT); + + if (assoc) + sta_info_pre_move_state(new_sta, IEEE80211_STA_AUTH); + + err = sta_info_insert(new_sta); + new_sta = NULL; + if (err) { + sdata_info(sdata, + "failed to insert STA entry for the AP (error %d)\n", + err); + goto out_err; + } + } else + WARN_ON_ONCE(!ether_addr_equal(link->u.mgd.bssid, cbss->bssid)); + + /* Cancel scan to ensure that nothing interferes with connection */ + if (local->scanning) + ieee80211_scan_cancel(local); + + return 0; + +out_err: + ieee80211_link_release_channel(&sdata->deflink); + ieee80211_vif_set_links(sdata, 0); + return err; +} + +/* config hooks */ +int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata, + struct cfg80211_auth_request *req) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_auth_data *auth_data; + u16 auth_alg; + int err; + bool cont_auth; + + /* prepare auth data structure */ + + switch (req->auth_type) { + case NL80211_AUTHTYPE_OPEN_SYSTEM: + auth_alg = WLAN_AUTH_OPEN; + break; + case NL80211_AUTHTYPE_SHARED_KEY: + if (fips_enabled) + return -EOPNOTSUPP; + auth_alg = WLAN_AUTH_SHARED_KEY; + break; + case NL80211_AUTHTYPE_FT: + auth_alg = WLAN_AUTH_FT; + break; + case NL80211_AUTHTYPE_NETWORK_EAP: + auth_alg = WLAN_AUTH_LEAP; + break; + case NL80211_AUTHTYPE_SAE: + auth_alg = WLAN_AUTH_SAE; + break; + case NL80211_AUTHTYPE_FILS_SK: + auth_alg = WLAN_AUTH_FILS_SK; + break; + case NL80211_AUTHTYPE_FILS_SK_PFS: + auth_alg = WLAN_AUTH_FILS_SK_PFS; + break; + case NL80211_AUTHTYPE_FILS_PK: + auth_alg = WLAN_AUTH_FILS_PK; + break; + default: + return -EOPNOTSUPP; + } + + if (ifmgd->assoc_data) + return -EBUSY; + + auth_data = kzalloc(sizeof(*auth_data) + req->auth_data_len + + req->ie_len, GFP_KERNEL); + if (!auth_data) + return -ENOMEM; + + memcpy(auth_data->ap_addr, + req->ap_mld_addr ?: 
req->bss->bssid, + ETH_ALEN); + auth_data->bss = req->bss; + auth_data->link_id = req->link_id; + + if (req->auth_data_len >= 4) { + if (req->auth_type == NL80211_AUTHTYPE_SAE) { + __le16 *pos = (__le16 *) req->auth_data; + + auth_data->sae_trans = le16_to_cpu(pos[0]); + auth_data->sae_status = le16_to_cpu(pos[1]); + } + memcpy(auth_data->data, req->auth_data + 4, + req->auth_data_len - 4); + auth_data->data_len += req->auth_data_len - 4; + } + + /* Check if continuing authentication or trying to authenticate with the + * same BSS that we were in the process of authenticating with and avoid + * removal and re-addition of the STA entry in + * ieee80211_prep_connection(). + */ + cont_auth = ifmgd->auth_data && req->bss == ifmgd->auth_data->bss && + ifmgd->auth_data->link_id == req->link_id; + + if (req->ie && req->ie_len) { + memcpy(&auth_data->data[auth_data->data_len], + req->ie, req->ie_len); + auth_data->data_len += req->ie_len; + } + + if (req->key && req->key_len) { + auth_data->key_len = req->key_len; + auth_data->key_idx = req->key_idx; + memcpy(auth_data->key, req->key, req->key_len); + } + + auth_data->algorithm = auth_alg; + + /* try to authenticate/probe */ + + if (ifmgd->auth_data) { + if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE) { + auth_data->peer_confirmed = + ifmgd->auth_data->peer_confirmed; + } + ieee80211_destroy_auth_data(sdata, cont_auth); + } + + /* prep auth_data so we don't go into idle on disassoc */ + ifmgd->auth_data = auth_data; + + /* If this is continuation of an ongoing SAE authentication exchange + * (i.e., request to send SAE Confirm) and the peer has already + * confirmed, mark authentication completed since we are about to send + * out SAE Confirm. + */ + if (cont_auth && req->auth_type == NL80211_AUTHTYPE_SAE && + auth_data->peer_confirmed && auth_data->sae_trans == 2) + ieee80211_mark_sta_auth(sdata); + + if (ifmgd->associated) { + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + sdata_info(sdata, + "disconnect from AP %pM for new auth to %pM\n", + sdata->vif.cfg.ap_addr, auth_data->ap_addr); + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, + WLAN_REASON_UNSPECIFIED, + false, frame_buf); + + ieee80211_report_disconnect(sdata, frame_buf, + sizeof(frame_buf), true, + WLAN_REASON_UNSPECIFIED, + false); + } + + sdata_info(sdata, "authenticate with %pM\n", auth_data->ap_addr); + + /* needed for transmitting the auth frame(s) properly */ + memcpy(sdata->vif.cfg.ap_addr, auth_data->ap_addr, ETH_ALEN); + + err = ieee80211_prep_connection(sdata, req->bss, req->link_id, + req->ap_mld_addr, cont_auth, false); + if (err) + goto err_clear; + + err = ieee80211_auth(sdata); + if (err) { + sta_info_destroy_addr(sdata, auth_data->ap_addr); + goto err_clear; + } + + /* hold our own reference */ + cfg80211_ref_bss(local->hw.wiphy, auth_data->bss); + return 0; + + err_clear: + if (!sdata->vif.valid_links) { + eth_zero_addr(sdata->deflink.u.mgd.bssid); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_BSSID); + mutex_lock(&sdata->local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&sdata->local->mtx); + } + ifmgd->auth_data = NULL; + kfree(auth_data); + return err; +} + +static ieee80211_conn_flags_t +ieee80211_setup_assoc_link(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgd_assoc_data *assoc_data, + struct cfg80211_assoc_request *req, + ieee80211_conn_flags_t conn_flags, + unsigned int link_id) +{ + struct ieee80211_local *local = sdata->local; + const struct cfg80211_bss_ies *beacon_ies; + 
struct ieee80211_supported_band *sband; + const struct element *ht_elem, *vht_elem; + struct ieee80211_link_data *link; + struct cfg80211_bss *cbss; + struct ieee80211_bss *bss; + bool is_5ghz, is_6ghz; + + cbss = assoc_data->link[link_id].bss; + if (WARN_ON(!cbss)) + return 0; + + bss = (void *)cbss->priv; + + sband = local->hw.wiphy->bands[cbss->channel->band]; + if (WARN_ON(!sband)) + return 0; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (WARN_ON(!link)) + return 0; + + is_5ghz = cbss->channel->band == NL80211_BAND_5GHZ; + is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ; + + /* for MLO connections assume advertising all rates is OK */ + if (!req->ap_mld_addr) { + assoc_data->supp_rates = bss->supp_rates; + assoc_data->supp_rates_len = bss->supp_rates_len; + } + + /* copy and link elems for the STA profile */ + if (req->links[link_id].elems_len) { + memcpy(assoc_data->ie_pos, req->links[link_id].elems, + req->links[link_id].elems_len); + assoc_data->link[link_id].elems = assoc_data->ie_pos; + assoc_data->link[link_id].elems_len = req->links[link_id].elems_len; + assoc_data->ie_pos += req->links[link_id].elems_len; + } + + rcu_read_lock(); + ht_elem = ieee80211_bss_get_elem(cbss, WLAN_EID_HT_OPERATION); + if (ht_elem && ht_elem->datalen >= sizeof(struct ieee80211_ht_operation)) + assoc_data->link[link_id].ap_ht_param = + ((struct ieee80211_ht_operation *)(ht_elem->data))->ht_param; + else if (!is_6ghz) + conn_flags |= IEEE80211_CONN_DISABLE_HT; + vht_elem = ieee80211_bss_get_elem(cbss, WLAN_EID_VHT_CAPABILITY); + if (vht_elem && vht_elem->datalen >= sizeof(struct ieee80211_vht_cap)) { + memcpy(&assoc_data->link[link_id].ap_vht_cap, vht_elem->data, + sizeof(struct ieee80211_vht_cap)); + } else if (is_5ghz) { + link_info(link, + "VHT capa missing/short, disabling VHT/HE/EHT\n"); + conn_flags |= IEEE80211_CONN_DISABLE_VHT | + IEEE80211_CONN_DISABLE_HE | + IEEE80211_CONN_DISABLE_EHT; + } + rcu_read_unlock(); + + link->u.mgd.beacon_crc_valid = false; + link->u.mgd.dtim_period = 0; + link->u.mgd.have_beacon = false; + + /* override HT/VHT configuration only if the AP and we support it */ + if (!(conn_flags & IEEE80211_CONN_DISABLE_HT)) { + struct ieee80211_sta_ht_cap sta_ht_cap; + + memcpy(&sta_ht_cap, &sband->ht_cap, sizeof(sta_ht_cap)); + ieee80211_apply_htcap_overrides(sdata, &sta_ht_cap); + } + + rcu_read_lock(); + beacon_ies = rcu_dereference(cbss->beacon_ies); + if (beacon_ies) { + const struct element *elem; + u8 dtim_count = 0; + + ieee80211_get_dtim(beacon_ies, &dtim_count, + &link->u.mgd.dtim_period); + + sdata->deflink.u.mgd.have_beacon = true; + + if (ieee80211_hw_check(&local->hw, TIMING_BEACON_ONLY)) { + link->conf->sync_tsf = beacon_ies->tsf; + link->conf->sync_device_ts = bss->device_ts_beacon; + link->conf->sync_dtim_count = dtim_count; + } + + elem = cfg80211_find_ext_elem(WLAN_EID_EXT_MULTIPLE_BSSID_CONFIGURATION, + beacon_ies->data, beacon_ies->len); + if (elem && elem->datalen >= 3) + link->conf->profile_periodicity = elem->data[2]; + else + link->conf->profile_periodicity = 0; + + elem = cfg80211_find_elem(WLAN_EID_EXT_CAPABILITY, + beacon_ies->data, beacon_ies->len); + if (elem && elem->datalen >= 11 && + (elem->data[10] & WLAN_EXT_CAPA11_EMA_SUPPORT)) + link->conf->ema_ap = true; + else + link->conf->ema_ap = false; + } + rcu_read_unlock(); + + if (bss->corrupt_data) { + char *corrupt_type = "data"; + + if (bss->corrupt_data & IEEE80211_BSS_CORRUPT_BEACON) { + if (bss->corrupt_data & IEEE80211_BSS_CORRUPT_PROBE_RESP) + corrupt_type = "beacon and 
probe response"; + else + corrupt_type = "beacon"; + } else if (bss->corrupt_data & IEEE80211_BSS_CORRUPT_PROBE_RESP) { + corrupt_type = "probe response"; + } + sdata_info(sdata, "associating to AP %pM with corrupt %s\n", + cbss->bssid, corrupt_type); + } + + if (link->u.mgd.req_smps == IEEE80211_SMPS_AUTOMATIC) { + if (sdata->u.mgd.powersave) + link->smps_mode = IEEE80211_SMPS_DYNAMIC; + else + link->smps_mode = IEEE80211_SMPS_OFF; + } else { + link->smps_mode = link->u.mgd.req_smps; + } + + return conn_flags; +} + +int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata, + struct cfg80211_assoc_request *req) +{ + unsigned int assoc_link_id = req->link_id < 0 ? 0 : req->link_id; + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_mgd_assoc_data *assoc_data; + const struct element *ssid_elem; + struct ieee80211_vif_cfg *vif_cfg = &sdata->vif.cfg; + ieee80211_conn_flags_t conn_flags = 0; + struct ieee80211_link_data *link; + struct cfg80211_bss *cbss; + struct ieee80211_bss *bss; + bool override; + int i, err; + size_t size = sizeof(*assoc_data) + req->ie_len; + + for (i = 0; i < IEEE80211_MLD_MAX_NUM_LINKS; i++) + size += req->links[i].elems_len; + + /* FIXME: no support for 4-addr MLO yet */ + if (sdata->u.mgd.use_4addr && req->link_id >= 0) + return -EOPNOTSUPP; + + assoc_data = kzalloc(size, GFP_KERNEL); + if (!assoc_data) + return -ENOMEM; + + cbss = req->link_id < 0 ? req->bss : req->links[req->link_id].bss; + + rcu_read_lock(); + ssid_elem = ieee80211_bss_get_elem(cbss, WLAN_EID_SSID); + if (!ssid_elem || ssid_elem->datalen > sizeof(assoc_data->ssid)) { + rcu_read_unlock(); + kfree(assoc_data); + return -EINVAL; + } + memcpy(assoc_data->ssid, ssid_elem->data, ssid_elem->datalen); + assoc_data->ssid_len = ssid_elem->datalen; + memcpy(vif_cfg->ssid, assoc_data->ssid, assoc_data->ssid_len); + vif_cfg->ssid_len = assoc_data->ssid_len; + rcu_read_unlock(); + + if (req->ap_mld_addr) { + for (i = 0; i < IEEE80211_MLD_MAX_NUM_LINKS; i++) { + if (!req->links[i].bss) + continue; + link = sdata_dereference(sdata->link[i], sdata); + if (link) + ether_addr_copy(assoc_data->link[i].addr, + link->conf->addr); + else + eth_random_addr(assoc_data->link[i].addr); + } + } else { + memcpy(assoc_data->link[0].addr, sdata->vif.addr, ETH_ALEN); + } + + assoc_data->s1g = cbss->channel->band == NL80211_BAND_S1GHZ; + + memcpy(assoc_data->ap_addr, + req->ap_mld_addr ?: req->bss->bssid, + ETH_ALEN); + + if (ifmgd->associated) { + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + sdata_info(sdata, + "disconnect from AP %pM for new assoc to %pM\n", + sdata->vif.cfg.ap_addr, assoc_data->ap_addr); + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, + WLAN_REASON_UNSPECIFIED, + false, frame_buf); + + ieee80211_report_disconnect(sdata, frame_buf, + sizeof(frame_buf), true, + WLAN_REASON_UNSPECIFIED, + false); + } + + if (ifmgd->auth_data && !ifmgd->auth_data->done) { + err = -EBUSY; + goto err_free; + } + + if (ifmgd->assoc_data) { + err = -EBUSY; + goto err_free; + } + + if (ifmgd->auth_data) { + bool match; + + /* keep sta info, bssid if matching */ + match = ether_addr_equal(ifmgd->auth_data->ap_addr, + assoc_data->ap_addr) && + ifmgd->auth_data->link_id == req->link_id; + ieee80211_destroy_auth_data(sdata, match); + } + + /* prepare assoc data */ + + bss = (void *)cbss->priv; + assoc_data->wmm = bss->wmm_used && + (local->hw.queues >= IEEE80211_NUM_ACS); + + /* + * IEEE802.11n does not allow TKIP/WEP as pairwise ciphers in HT mode. 
+ * We still associate in non-HT mode (11a/b/g) if any one of these + * ciphers is configured as pairwise. + * We can set this to true for non-11n hardware, that'll be checked + * separately along with the peer capabilities. + */ + for (i = 0; i < req->crypto.n_ciphers_pairwise; i++) { + if (req->crypto.ciphers_pairwise[i] == WLAN_CIPHER_SUITE_WEP40 || + req->crypto.ciphers_pairwise[i] == WLAN_CIPHER_SUITE_TKIP || + req->crypto.ciphers_pairwise[i] == WLAN_CIPHER_SUITE_WEP104) { + conn_flags |= IEEE80211_CONN_DISABLE_HT; + conn_flags |= IEEE80211_CONN_DISABLE_VHT; + conn_flags |= IEEE80211_CONN_DISABLE_HE; + conn_flags |= IEEE80211_CONN_DISABLE_EHT; + netdev_info(sdata->dev, + "disabling HT/VHT/HE due to WEP/TKIP use\n"); + } + } + + /* also disable HT/VHT/HE/EHT if the AP doesn't use WMM */ + if (!bss->wmm_used) { + conn_flags |= IEEE80211_CONN_DISABLE_HT; + conn_flags |= IEEE80211_CONN_DISABLE_VHT; + conn_flags |= IEEE80211_CONN_DISABLE_HE; + conn_flags |= IEEE80211_CONN_DISABLE_EHT; + netdev_info(sdata->dev, + "disabling HT/VHT/HE as WMM/QoS is not supported by the AP\n"); + } + + if (req->flags & ASSOC_REQ_DISABLE_HT) { + mlme_dbg(sdata, "HT disabled by flag, disabling HT/VHT/HE\n"); + conn_flags |= IEEE80211_CONN_DISABLE_HT; + conn_flags |= IEEE80211_CONN_DISABLE_VHT; + conn_flags |= IEEE80211_CONN_DISABLE_HE; + conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (req->flags & ASSOC_REQ_DISABLE_VHT) { + mlme_dbg(sdata, "VHT disabled by flag, disabling VHT\n"); + conn_flags |= IEEE80211_CONN_DISABLE_VHT; + } + + if (req->flags & ASSOC_REQ_DISABLE_HE) { + mlme_dbg(sdata, "HE disabled by flag, disabling HE/EHT\n"); + conn_flags |= IEEE80211_CONN_DISABLE_HE; + conn_flags |= IEEE80211_CONN_DISABLE_EHT; + } + + if (req->flags & ASSOC_REQ_DISABLE_EHT) + conn_flags |= IEEE80211_CONN_DISABLE_EHT; + + memcpy(&ifmgd->ht_capa, &req->ht_capa, sizeof(ifmgd->ht_capa)); + memcpy(&ifmgd->ht_capa_mask, &req->ht_capa_mask, + sizeof(ifmgd->ht_capa_mask)); + + memcpy(&ifmgd->vht_capa, &req->vht_capa, sizeof(ifmgd->vht_capa)); + memcpy(&ifmgd->vht_capa_mask, &req->vht_capa_mask, + sizeof(ifmgd->vht_capa_mask)); + + memcpy(&ifmgd->s1g_capa, &req->s1g_capa, sizeof(ifmgd->s1g_capa)); + memcpy(&ifmgd->s1g_capa_mask, &req->s1g_capa_mask, + sizeof(ifmgd->s1g_capa_mask)); + + if (req->ie && req->ie_len) { + memcpy(assoc_data->ie, req->ie, req->ie_len); + assoc_data->ie_len = req->ie_len; + assoc_data->ie_pos = assoc_data->ie + assoc_data->ie_len; + } else { + assoc_data->ie_pos = assoc_data->ie; + } + + if (req->fils_kek) { + /* should already be checked in cfg80211 - so warn */ + if (WARN_ON(req->fils_kek_len > FILS_MAX_KEK_LEN)) { + err = -EINVAL; + goto err_free; + } + memcpy(assoc_data->fils_kek, req->fils_kek, + req->fils_kek_len); + assoc_data->fils_kek_len = req->fils_kek_len; + } + + if (req->fils_nonces) + memcpy(assoc_data->fils_nonces, req->fils_nonces, + 2 * FILS_NONCE_LEN); + + /* default timeout */ + assoc_data->timeout = jiffies; + assoc_data->timeout_started = true; + + assoc_data->assoc_link_id = assoc_link_id; + + if (req->ap_mld_addr) { + for (i = 0; i < ARRAY_SIZE(assoc_data->link); i++) { + assoc_data->link[i].conn_flags = conn_flags; + assoc_data->link[i].bss = req->links[i].bss; + } + + /* if there was no authentication, set up the link */ + err = ieee80211_vif_set_links(sdata, BIT(assoc_link_id)); + if (err) + goto err_clear; + } else { + assoc_data->link[0].conn_flags = conn_flags; + assoc_data->link[0].bss = cbss; + } + + link = sdata_dereference(sdata->link[assoc_link_id], sdata); + if 
(WARN_ON(!link)) { + err = -EINVAL; + goto err_clear; + } + + /* keep old conn_flags from ieee80211_prep_channel() from auth */ + conn_flags |= link->u.mgd.conn_flags; + conn_flags |= ieee80211_setup_assoc_link(sdata, assoc_data, req, + conn_flags, assoc_link_id); + override = link->u.mgd.conn_flags != conn_flags; + link->u.mgd.conn_flags |= conn_flags; + + if (WARN((sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_UAPSD) && + ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK), + "U-APSD not supported with HW_PS_NULLFUNC_STACK\n")) + sdata->vif.driver_flags &= ~IEEE80211_VIF_SUPPORTS_UAPSD; + + if (bss->wmm_used && bss->uapsd_supported && + (sdata->vif.driver_flags & IEEE80211_VIF_SUPPORTS_UAPSD)) { + assoc_data->uapsd = true; + ifmgd->flags |= IEEE80211_STA_UAPSD_ENABLED; + } else { + assoc_data->uapsd = false; + ifmgd->flags &= ~IEEE80211_STA_UAPSD_ENABLED; + } + + if (req->prev_bssid) + memcpy(assoc_data->prev_ap_addr, req->prev_bssid, ETH_ALEN); + + if (req->use_mfp) { + ifmgd->mfp = IEEE80211_MFP_REQUIRED; + ifmgd->flags |= IEEE80211_STA_MFP_ENABLED; + } else { + ifmgd->mfp = IEEE80211_MFP_DISABLED; + ifmgd->flags &= ~IEEE80211_STA_MFP_ENABLED; + } + + if (req->flags & ASSOC_REQ_USE_RRM) + ifmgd->flags |= IEEE80211_STA_ENABLE_RRM; + else + ifmgd->flags &= ~IEEE80211_STA_ENABLE_RRM; + + if (req->crypto.control_port) + ifmgd->flags |= IEEE80211_STA_CONTROL_PORT; + else + ifmgd->flags &= ~IEEE80211_STA_CONTROL_PORT; + + sdata->control_port_protocol = req->crypto.control_port_ethertype; + sdata->control_port_no_encrypt = req->crypto.control_port_no_encrypt; + sdata->control_port_over_nl80211 = + req->crypto.control_port_over_nl80211; + sdata->control_port_no_preauth = req->crypto.control_port_no_preauth; + + /* kick off associate process */ + ifmgd->assoc_data = assoc_data; + + for (i = 0; i < ARRAY_SIZE(assoc_data->link); i++) { + if (!assoc_data->link[i].bss) + continue; + if (i == assoc_data->assoc_link_id) + continue; + /* only calculate the flags, hence link == NULL */ + err = ieee80211_prep_channel(sdata, NULL, assoc_data->link[i].bss, + &assoc_data->link[i].conn_flags); + if (err) + goto err_clear; + } + + /* needed for transmitting the assoc frames properly */ + memcpy(sdata->vif.cfg.ap_addr, assoc_data->ap_addr, ETH_ALEN); + + err = ieee80211_prep_connection(sdata, cbss, req->link_id, + req->ap_mld_addr, true, override); + if (err) + goto err_clear; + + assoc_data->link[assoc_data->assoc_link_id].conn_flags = + link->u.mgd.conn_flags; + + if (ieee80211_hw_check(&sdata->local->hw, NEED_DTIM_BEFORE_ASSOC)) { + const struct cfg80211_bss_ies *beacon_ies; + + rcu_read_lock(); + beacon_ies = rcu_dereference(req->bss->beacon_ies); + + if (beacon_ies) { + /* + * Wait up to one beacon interval ... + * should this be more if we miss one? 
+ */ + sdata_info(sdata, "waiting for beacon from %pM\n", + link->u.mgd.bssid); + assoc_data->timeout = TU_TO_EXP_TIME(req->bss->beacon_interval); + assoc_data->timeout_started = true; + assoc_data->need_beacon = true; + } + rcu_read_unlock(); + } + + run_again(sdata, assoc_data->timeout); + + return 0; + err_clear: + eth_zero_addr(sdata->deflink.u.mgd.bssid); + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_BSSID); + ifmgd->assoc_data = NULL; + err_free: + kfree(assoc_data); + return err; +} + +int ieee80211_mgd_deauth(struct ieee80211_sub_if_data *sdata, + struct cfg80211_deauth_request *req) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + bool tx = !req->local_state_change; + struct ieee80211_prep_tx_info info = { + .subtype = IEEE80211_STYPE_DEAUTH, + }; + + if (ifmgd->auth_data && + ether_addr_equal(ifmgd->auth_data->ap_addr, req->bssid)) { + sdata_info(sdata, + "aborting authentication with %pM by local choice (Reason: %u=%s)\n", + req->bssid, req->reason_code, + ieee80211_get_reason_code_string(req->reason_code)); + + drv_mgd_prepare_tx(sdata->local, sdata, &info); + ieee80211_send_deauth_disassoc(sdata, req->bssid, req->bssid, + IEEE80211_STYPE_DEAUTH, + req->reason_code, tx, + frame_buf); + ieee80211_destroy_auth_data(sdata, false); + ieee80211_report_disconnect(sdata, frame_buf, + sizeof(frame_buf), true, + req->reason_code, false); + drv_mgd_complete_tx(sdata->local, sdata, &info); + return 0; + } + + if (ifmgd->assoc_data && + ether_addr_equal(ifmgd->assoc_data->ap_addr, req->bssid)) { + sdata_info(sdata, + "aborting association with %pM by local choice (Reason: %u=%s)\n", + req->bssid, req->reason_code, + ieee80211_get_reason_code_string(req->reason_code)); + + drv_mgd_prepare_tx(sdata->local, sdata, &info); + ieee80211_send_deauth_disassoc(sdata, req->bssid, req->bssid, + IEEE80211_STYPE_DEAUTH, + req->reason_code, tx, + frame_buf); + ieee80211_destroy_assoc_data(sdata, ASSOC_ABANDON); + ieee80211_report_disconnect(sdata, frame_buf, + sizeof(frame_buf), true, + req->reason_code, false); + return 0; + } + + if (ifmgd->associated && + ether_addr_equal(sdata->vif.cfg.ap_addr, req->bssid)) { + sdata_info(sdata, + "deauthenticating from %pM by local choice (Reason: %u=%s)\n", + req->bssid, req->reason_code, + ieee80211_get_reason_code_string(req->reason_code)); + + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DEAUTH, + req->reason_code, tx, frame_buf); + ieee80211_report_disconnect(sdata, frame_buf, + sizeof(frame_buf), true, + req->reason_code, false); + drv_mgd_complete_tx(sdata->local, sdata, &info); + return 0; + } + + return -ENOTCONN; +} + +int ieee80211_mgd_disassoc(struct ieee80211_sub_if_data *sdata, + struct cfg80211_disassoc_request *req) +{ + u8 frame_buf[IEEE80211_DEAUTH_FRAME_LEN]; + + if (!sdata->u.mgd.associated || + memcmp(sdata->vif.cfg.ap_addr, req->ap_addr, ETH_ALEN)) + return -ENOTCONN; + + sdata_info(sdata, + "disassociating from %pM by local choice (Reason: %u=%s)\n", + req->ap_addr, req->reason_code, + ieee80211_get_reason_code_string(req->reason_code)); + + ieee80211_set_disassoc(sdata, IEEE80211_STYPE_DISASSOC, + req->reason_code, !req->local_state_change, + frame_buf); + + ieee80211_report_disconnect(sdata, frame_buf, sizeof(frame_buf), true, + req->reason_code, false); + + return 0; +} + +void ieee80211_mgd_stop_link(struct ieee80211_link_data *link) +{ + cancel_work_sync(&link->u.mgd.request_smps_work); + cancel_work_sync(&link->u.mgd.chswitch_work); +} + +void 
ieee80211_mgd_stop(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + /* + * Make sure some work items will not run after this, + * they will not do anything but might not have been + * cancelled when disconnecting. + */ + cancel_work_sync(&ifmgd->monitor_work); + cancel_work_sync(&ifmgd->beacon_connection_loss_work); + cancel_work_sync(&ifmgd->csa_connection_drop_work); + cancel_delayed_work_sync(&ifmgd->tdls_peer_del_work); + + sdata_lock(sdata); + if (ifmgd->assoc_data) + ieee80211_destroy_assoc_data(sdata, ASSOC_TIMEOUT); + if (ifmgd->auth_data) + ieee80211_destroy_auth_data(sdata, false); + spin_lock_bh(&ifmgd->teardown_lock); + if (ifmgd->teardown_skb) { + kfree_skb(ifmgd->teardown_skb); + ifmgd->teardown_skb = NULL; + ifmgd->orig_teardown_skb = NULL; + } + kfree(ifmgd->assoc_req_ies); + ifmgd->assoc_req_ies = NULL; + ifmgd->assoc_req_ies_len = 0; + spin_unlock_bh(&ifmgd->teardown_lock); + del_timer_sync(&ifmgd->timer); + sdata_unlock(sdata); +} + +void ieee80211_cqm_rssi_notify(struct ieee80211_vif *vif, + enum nl80211_cqm_rssi_threshold_event rssi_event, + s32 rssi_level, + gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + trace_api_cqm_rssi_notify(sdata, rssi_event, rssi_level); + + cfg80211_cqm_rssi_notify(sdata->dev, rssi_event, rssi_level, gfp); +} +EXPORT_SYMBOL(ieee80211_cqm_rssi_notify); + +void ieee80211_cqm_beacon_loss_notify(struct ieee80211_vif *vif, gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + trace_api_cqm_beacon_loss_notify(sdata->local, sdata); + + cfg80211_cqm_beacon_loss_notify(sdata->dev, gfp); +} +EXPORT_SYMBOL(ieee80211_cqm_beacon_loss_notify); + +static void _ieee80211_enable_rssi_reports(struct ieee80211_sub_if_data *sdata, + int rssi_min_thold, + int rssi_max_thold) +{ + trace_api_enable_rssi_reports(sdata, rssi_min_thold, rssi_max_thold); + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_STATION)) + return; + + /* + * Scale up threshold values before storing it, as the RSSI averaging + * algorithm uses a scaled up value as well. Change this scaling + * factor if the RSSI averaging algorithm changes. 
+ */ + sdata->u.mgd.rssi_min_thold = rssi_min_thold*16; + sdata->u.mgd.rssi_max_thold = rssi_max_thold*16; +} + +void ieee80211_enable_rssi_reports(struct ieee80211_vif *vif, + int rssi_min_thold, + int rssi_max_thold) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + WARN_ON(rssi_min_thold == rssi_max_thold || + rssi_min_thold > rssi_max_thold); + + _ieee80211_enable_rssi_reports(sdata, rssi_min_thold, + rssi_max_thold); +} +EXPORT_SYMBOL(ieee80211_enable_rssi_reports); + +void ieee80211_disable_rssi_reports(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + _ieee80211_enable_rssi_reports(sdata, 0, 0); +} +EXPORT_SYMBOL(ieee80211_disable_rssi_reports); diff --git a/net/mac80211/ocb.c b/net/mac80211/ocb.c new file mode 100644 index 000000000..a57dcbe99 --- /dev/null +++ b/net/mac80211/ocb.c @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * OCB mode implementation + * + * Copyright: (c) 2014 Czech Technical University in Prague + * (c) 2014 Volkswagen Group Research + * Copyright (C) 2022 Intel Corporation + * Author: Rostislav Lisovy + * Funded by: Volkswagen Group Research + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" + +#define IEEE80211_OCB_HOUSEKEEPING_INTERVAL (60 * HZ) +#define IEEE80211_OCB_PEER_INACTIVITY_LIMIT (240 * HZ) +#define IEEE80211_OCB_MAX_STA_ENTRIES 128 + +/** + * enum ocb_deferred_task_flags - mac80211 OCB deferred tasks + * @OCB_WORK_HOUSEKEEPING: run the periodic OCB housekeeping tasks + * + * These flags are used in @wrkq_flags field of &struct ieee80211_if_ocb + */ +enum ocb_deferred_task_flags { + OCB_WORK_HOUSEKEEPING, +}; + +void ieee80211_ocb_rx_no_sta(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const u8 *addr, + u32 supp_rates) +{ + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_supported_band *sband; + enum nl80211_bss_scan_width scan_width; + struct sta_info *sta; + int band; + + /* XXX: Consider removing the least recently used entry and + * allow new one to be added. 
+ */ + if (local->num_sta >= IEEE80211_OCB_MAX_STA_ENTRIES) { + net_info_ratelimited("%s: No room for a new OCB STA entry %pM\n", + sdata->name, addr); + return; + } + + ocb_dbg(sdata, "Adding new OCB station %pM\n", addr); + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON_ONCE(!chanctx_conf)) { + rcu_read_unlock(); + return; + } + band = chanctx_conf->def.chan->band; + scan_width = cfg80211_chandef_to_scan_width(&chanctx_conf->def); + rcu_read_unlock(); + + sta = sta_info_alloc(sdata, addr, GFP_ATOMIC); + if (!sta) + return; + + /* Add only mandatory rates for now */ + sband = local->hw.wiphy->bands[band]; + sta->sta.deflink.supp_rates[band] = + ieee80211_mandatory_rates(sband, scan_width); + + spin_lock(&ifocb->incomplete_lock); + list_add(&sta->list, &ifocb->incomplete_stations); + spin_unlock(&ifocb->incomplete_lock); + ieee80211_queue_work(&local->hw, &sdata->work); +} + +static struct sta_info *ieee80211_ocb_finish_sta(struct sta_info *sta) + __acquires(RCU) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u8 addr[ETH_ALEN]; + + memcpy(addr, sta->sta.addr, ETH_ALEN); + + ocb_dbg(sdata, "Adding new IBSS station %pM (dev=%s)\n", + addr, sdata->name); + + sta_info_move_state(sta, IEEE80211_STA_AUTH); + sta_info_move_state(sta, IEEE80211_STA_ASSOC); + sta_info_move_state(sta, IEEE80211_STA_AUTHORIZED); + + rate_control_rate_init(sta); + + /* If it fails, maybe we raced another insertion? */ + if (sta_info_insert_rcu(sta)) + return sta_info_get(sdata, addr); + return sta; +} + +static void ieee80211_ocb_housekeeping(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + + ocb_dbg(sdata, "Running ocb housekeeping\n"); + + ieee80211_sta_expire(sdata, IEEE80211_OCB_PEER_INACTIVITY_LIMIT); + + mod_timer(&ifocb->housekeeping_timer, + round_jiffies(jiffies + IEEE80211_OCB_HOUSEKEEPING_INTERVAL)); +} + +void ieee80211_ocb_work(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + struct sta_info *sta; + + if (ifocb->joined != true) + return; + + sdata_lock(sdata); + + spin_lock_bh(&ifocb->incomplete_lock); + while (!list_empty(&ifocb->incomplete_stations)) { + sta = list_first_entry(&ifocb->incomplete_stations, + struct sta_info, list); + list_del(&sta->list); + spin_unlock_bh(&ifocb->incomplete_lock); + + ieee80211_ocb_finish_sta(sta); + rcu_read_unlock(); + spin_lock_bh(&ifocb->incomplete_lock); + } + spin_unlock_bh(&ifocb->incomplete_lock); + + if (test_and_clear_bit(OCB_WORK_HOUSEKEEPING, &ifocb->wrkq_flags)) + ieee80211_ocb_housekeeping(sdata); + + sdata_unlock(sdata); +} + +static void ieee80211_ocb_housekeeping_timer(struct timer_list *t) +{ + struct ieee80211_sub_if_data *sdata = + from_timer(sdata, t, u.ocb.housekeeping_timer); + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + + set_bit(OCB_WORK_HOUSEKEEPING, &ifocb->wrkq_flags); + + ieee80211_queue_work(&local->hw, &sdata->work); +} + +void ieee80211_ocb_setup_sdata(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + + timer_setup(&ifocb->housekeeping_timer, + ieee80211_ocb_housekeeping_timer, 0); + INIT_LIST_HEAD(&ifocb->incomplete_stations); + spin_lock_init(&ifocb->incomplete_lock); +} + +int ieee80211_ocb_join(struct ieee80211_sub_if_data *sdata, + struct ocb_setup *setup) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + u32 changed = BSS_CHANGED_OCB | 
BSS_CHANGED_BSSID; + int err; + + if (ifocb->joined == true) + return -EINVAL; + + sdata->deflink.operating_11g_mode = true; + sdata->deflink.smps_mode = IEEE80211_SMPS_OFF; + sdata->deflink.needed_rx_chains = sdata->local->rx_chains; + + mutex_lock(&sdata->local->mtx); + err = ieee80211_link_use_channel(&sdata->deflink, &setup->chandef, + IEEE80211_CHANCTX_SHARED); + mutex_unlock(&sdata->local->mtx); + if (err) + return err; + + ieee80211_bss_info_change_notify(sdata, changed); + + ifocb->joined = true; + + set_bit(OCB_WORK_HOUSEKEEPING, &ifocb->wrkq_flags); + ieee80211_queue_work(&local->hw, &sdata->work); + + netif_carrier_on(sdata->dev); + return 0; +} + +int ieee80211_ocb_leave(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_if_ocb *ifocb = &sdata->u.ocb; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + ifocb->joined = false; + sta_info_flush(sdata); + + spin_lock_bh(&ifocb->incomplete_lock); + while (!list_empty(&ifocb->incomplete_stations)) { + sta = list_first_entry(&ifocb->incomplete_stations, + struct sta_info, list); + list_del(&sta->list); + spin_unlock_bh(&ifocb->incomplete_lock); + + sta_info_free(local, sta); + spin_lock_bh(&ifocb->incomplete_lock); + } + spin_unlock_bh(&ifocb->incomplete_lock); + + netif_carrier_off(sdata->dev); + clear_bit(SDATA_STATE_OFFCHANNEL, &sdata->state); + ieee80211_bss_info_change_notify(sdata, BSS_CHANGED_OCB); + + mutex_lock(&sdata->local->mtx); + ieee80211_link_release_channel(&sdata->deflink); + mutex_unlock(&sdata->local->mtx); + + skb_queue_purge(&sdata->skb_queue); + + del_timer_sync(&sdata->u.ocb.housekeeping_timer); + /* If the timer fired while we waited for it, it will have + * requeued the work. Now the work will be running again + * but will not rearm the timer again because it checks + * whether we are connected to the network or not -- at this + * point we shouldn't be anymore. + */ + + return 0; +} diff --git a/net/mac80211/offchannel.c b/net/mac80211/offchannel.c new file mode 100644 index 000000000..50dc379ca --- /dev/null +++ b/net/mac80211/offchannel.c @@ -0,0 +1,1030 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Off-channel operation helpers + * + * Copyright 2003, Jouni Malinen + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2009 Johannes Berg + * Copyright (C) 2019, 2022 Intel Corporation + */ +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" + +/* + * Tell our hardware to disable PS. + * Optionally inform AP that we will go to sleep so that it will buffer + * the frames while we are doing off-channel work. This is optional + * because we *may* be doing work on-operating channel, and want our + * hardware unconditionally awake, but still let the AP send us normal frames. + */ +static void ieee80211_offchannel_ps_enable(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + bool offchannel_ps_enabled = false; + + /* FIXME: what to do when local->pspolling is true? 
*/ + + del_timer_sync(&local->dynamic_ps_timer); + del_timer_sync(&ifmgd->bcn_mon_timer); + del_timer_sync(&ifmgd->conn_mon_timer); + + cancel_work_sync(&local->dynamic_ps_enable_work); + + if (local->hw.conf.flags & IEEE80211_CONF_PS) { + offchannel_ps_enabled = true; + local->hw.conf.flags &= ~IEEE80211_CONF_PS; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_PS); + } + + if (!offchannel_ps_enabled || + !ieee80211_hw_check(&local->hw, PS_NULLFUNC_STACK)) + /* + * If power save was enabled, no need to send a nullfunc + * frame because AP knows that we are sleeping. But if the + * hardware is creating the nullfunc frame for power save + * status (ie. IEEE80211_HW_PS_NULLFUNC_STACK is not + * enabled) and power save was enabled, the firmware just + * sent a null frame with power save disabled. So we need + * to send a new nullfunc frame to inform the AP that we + * are again sleeping. + */ + ieee80211_send_nullfunc(local, sdata, true); +} + +/* inform AP that we are awake again */ +static void ieee80211_offchannel_ps_disable(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + + if (!local->ps_sdata) + ieee80211_send_nullfunc(local, sdata, false); + else if (local->hw.conf.dynamic_ps_timeout > 0) { + /* + * the dynamic_ps_timer had been running before leaving the + * operating channel, restart the timer now and send a nullfunc + * frame to inform the AP that we are awake so that AP sends + * the buffered packets (if any). + */ + ieee80211_send_nullfunc(local, sdata, false); + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies(local->hw.conf.dynamic_ps_timeout)); + } + + ieee80211_sta_reset_beacon_monitor(sdata); + ieee80211_sta_reset_conn_monitor(sdata); +} + +void ieee80211_offchannel_stop_vifs(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + if (WARN_ON(local->use_chanctx)) + return; + + /* + * notify the AP about us leaving the channel and stop all + * STA interfaces. + */ + + /* + * Stop queues and transmit all frames queued by the driver + * before sending nullfunc to enable powersave at the AP. + */ + ieee80211_stop_queues_by_reason(&local->hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_OFFCHANNEL, + false); + ieee80211_flush_queues(local, NULL, false); + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + if (sdata->vif.type == NL80211_IFTYPE_P2P_DEVICE || + sdata->vif.type == NL80211_IFTYPE_NAN) + continue; + + if (sdata->vif.type != NL80211_IFTYPE_MONITOR) + set_bit(SDATA_STATE_OFFCHANNEL, &sdata->state); + + /* Check to see if we should disable beaconing. 
*/ + if (sdata->vif.bss_conf.enable_beacon) { + set_bit(SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, + &sdata->state); + sdata->vif.bss_conf.enable_beacon = false; + ieee80211_link_info_change_notify( + sdata, &sdata->deflink, + BSS_CHANGED_BEACON_ENABLED); + } + + if (sdata->vif.type == NL80211_IFTYPE_STATION && + sdata->u.mgd.associated) + ieee80211_offchannel_ps_enable(sdata); + } + mutex_unlock(&local->iflist_mtx); +} + +void ieee80211_offchannel_return(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + + if (WARN_ON(local->use_chanctx)) + return; + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type == NL80211_IFTYPE_P2P_DEVICE) + continue; + + if (sdata->vif.type != NL80211_IFTYPE_MONITOR) + clear_bit(SDATA_STATE_OFFCHANNEL, &sdata->state); + + if (!ieee80211_sdata_running(sdata)) + continue; + + /* Tell AP we're back */ + if (sdata->vif.type == NL80211_IFTYPE_STATION && + sdata->u.mgd.associated) + ieee80211_offchannel_ps_disable(sdata); + + if (test_and_clear_bit(SDATA_STATE_OFFCHANNEL_BEACON_STOPPED, + &sdata->state)) { + sdata->vif.bss_conf.enable_beacon = true; + ieee80211_link_info_change_notify( + sdata, &sdata->deflink, + BSS_CHANGED_BEACON_ENABLED); + } + } + mutex_unlock(&local->iflist_mtx); + + ieee80211_wake_queues_by_reason(&local->hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_OFFCHANNEL, + false); +} + +static void ieee80211_roc_notify_destroy(struct ieee80211_roc_work *roc) +{ + /* was never transmitted */ + if (roc->frame) { + cfg80211_mgmt_tx_status(&roc->sdata->wdev, roc->mgmt_tx_cookie, + roc->frame->data, roc->frame->len, + false, GFP_KERNEL); + ieee80211_free_txskb(&roc->sdata->local->hw, roc->frame); + } + + if (!roc->mgmt_tx_cookie) + cfg80211_remain_on_channel_expired(&roc->sdata->wdev, + roc->cookie, roc->chan, + GFP_KERNEL); + else + cfg80211_tx_mgmt_expired(&roc->sdata->wdev, + roc->mgmt_tx_cookie, + roc->chan, GFP_KERNEL); + + list_del(&roc->list); + kfree(roc); +} + +static unsigned long ieee80211_end_finished_rocs(struct ieee80211_local *local, + unsigned long now) +{ + struct ieee80211_roc_work *roc, *tmp; + long remaining_dur_min = LONG_MAX; + + lockdep_assert_held(&local->mtx); + + list_for_each_entry_safe(roc, tmp, &local->roc_list, list) { + long remaining; + + if (!roc->started) + break; + + remaining = roc->start_time + + msecs_to_jiffies(roc->duration) - + now; + + /* In case of HW ROC, it is possible that the HW finished the + * ROC session before the actual requested time. In such a case + * end the ROC session (disregarding the remaining time). 
+ */ + if (roc->abort || roc->hw_begun || remaining <= 0) + ieee80211_roc_notify_destroy(roc); + else + remaining_dur_min = min(remaining_dur_min, remaining); + } + + return remaining_dur_min; +} + +static bool ieee80211_recalc_sw_work(struct ieee80211_local *local, + unsigned long now) +{ + long dur = ieee80211_end_finished_rocs(local, now); + + if (dur == LONG_MAX) + return false; + + wiphy_delayed_work_queue(local->hw.wiphy, &local->roc_work, dur); + return true; +} + +static void ieee80211_handle_roc_started(struct ieee80211_roc_work *roc, + unsigned long start_time) +{ + if (WARN_ON(roc->notified)) + return; + + roc->start_time = start_time; + roc->started = true; + + if (roc->mgmt_tx_cookie) { + if (!WARN_ON(!roc->frame)) { + ieee80211_tx_skb_tid_band(roc->sdata, roc->frame, 7, + roc->chan->band); + roc->frame = NULL; + } + } else { + cfg80211_ready_on_channel(&roc->sdata->wdev, roc->cookie, + roc->chan, roc->req_duration, + GFP_KERNEL); + } + + roc->notified = true; +} + +static void ieee80211_hw_roc_start(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, hw_roc_start); + struct ieee80211_roc_work *roc; + + mutex_lock(&local->mtx); + + list_for_each_entry(roc, &local->roc_list, list) { + if (!roc->started) + break; + + roc->hw_begun = true; + ieee80211_handle_roc_started(roc, local->hw_roc_start_time); + } + + mutex_unlock(&local->mtx); +} + +void ieee80211_ready_on_channel(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + local->hw_roc_start_time = jiffies; + + trace_api_ready_on_channel(local); + + wiphy_work_queue(hw->wiphy, &local->hw_roc_start); +} +EXPORT_SYMBOL_GPL(ieee80211_ready_on_channel); + +static void _ieee80211_start_next_roc(struct ieee80211_local *local) +{ + struct ieee80211_roc_work *roc, *tmp; + enum ieee80211_roc_type type; + u32 min_dur, max_dur; + + lockdep_assert_held(&local->mtx); + + if (WARN_ON(list_empty(&local->roc_list))) + return; + + roc = list_first_entry(&local->roc_list, struct ieee80211_roc_work, + list); + + if (WARN_ON(roc->started)) + return; + + min_dur = roc->duration; + max_dur = roc->duration; + type = roc->type; + + list_for_each_entry(tmp, &local->roc_list, list) { + if (tmp == roc) + continue; + if (tmp->sdata != roc->sdata || tmp->chan != roc->chan) + break; + max_dur = max(tmp->duration, max_dur); + min_dur = min(tmp->duration, min_dur); + type = max(tmp->type, type); + } + + if (local->ops->remain_on_channel) { + int ret = drv_remain_on_channel(local, roc->sdata, roc->chan, + max_dur, type); + + if (ret) { + wiphy_warn(local->hw.wiphy, + "failed to start next HW ROC (%d)\n", ret); + /* + * queue the work struct again to avoid recursion + * when multiple failures occur + */ + list_for_each_entry(tmp, &local->roc_list, list) { + if (tmp->sdata != roc->sdata || + tmp->chan != roc->chan) + break; + tmp->started = true; + tmp->abort = true; + } + wiphy_work_queue(local->hw.wiphy, &local->hw_roc_done); + return; + } + + /* we'll notify about the start once the HW calls back */ + list_for_each_entry(tmp, &local->roc_list, list) { + if (tmp->sdata != roc->sdata || tmp->chan != roc->chan) + break; + tmp->started = true; + } + } else { + /* If actually operating on the desired channel (with at least + * 20 MHz channel width) don't stop all the operations but still + * treat it as though the ROC operation started properly, so + * other ROC operations won't interfere with this one. 
+ */ + roc->on_channel = roc->chan == local->_oper_chandef.chan && + local->_oper_chandef.width != NL80211_CHAN_WIDTH_5 && + local->_oper_chandef.width != NL80211_CHAN_WIDTH_10; + + /* start this ROC */ + ieee80211_recalc_idle(local); + + if (!roc->on_channel) { + ieee80211_offchannel_stop_vifs(local); + + local->tmp_channel = roc->chan; + ieee80211_hw_config(local, 0); + } + + wiphy_delayed_work_queue(local->hw.wiphy, &local->roc_work, + msecs_to_jiffies(min_dur)); + + /* tell userspace or send frame(s) */ + list_for_each_entry(tmp, &local->roc_list, list) { + if (tmp->sdata != roc->sdata || tmp->chan != roc->chan) + break; + + tmp->on_channel = roc->on_channel; + ieee80211_handle_roc_started(tmp, jiffies); + } + } +} + +void ieee80211_start_next_roc(struct ieee80211_local *local) +{ + struct ieee80211_roc_work *roc; + + lockdep_assert_held(&local->mtx); + + if (list_empty(&local->roc_list)) { + ieee80211_run_deferred_scan(local); + return; + } + + /* defer roc if driver is not started (i.e. during reconfig) */ + if (local->in_reconfig) + return; + + roc = list_first_entry(&local->roc_list, struct ieee80211_roc_work, + list); + + if (WARN_ON_ONCE(roc->started)) + return; + + if (local->ops->remain_on_channel) { + _ieee80211_start_next_roc(local); + } else { + /* delay it a bit */ + wiphy_delayed_work_queue(local->hw.wiphy, &local->roc_work, + round_jiffies_relative(HZ / 2)); + } +} + +static void __ieee80211_roc_work(struct ieee80211_local *local) +{ + struct ieee80211_roc_work *roc; + bool on_channel; + + lockdep_assert_held(&local->mtx); + + if (WARN_ON(local->ops->remain_on_channel)) + return; + + roc = list_first_entry_or_null(&local->roc_list, + struct ieee80211_roc_work, list); + if (!roc) + return; + + if (!roc->started) { + WARN_ON(local->use_chanctx); + _ieee80211_start_next_roc(local); + } else { + on_channel = roc->on_channel; + if (ieee80211_recalc_sw_work(local, jiffies)) + return; + + /* careful - roc pointer became invalid during recalc */ + + if (!on_channel) { + ieee80211_flush_queues(local, NULL, false); + + local->tmp_channel = NULL; + ieee80211_hw_config(local, 0); + + ieee80211_offchannel_return(local); + } + + ieee80211_recalc_idle(local); + ieee80211_start_next_roc(local); + } +} + +static void ieee80211_roc_work(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, roc_work.work); + + mutex_lock(&local->mtx); + __ieee80211_roc_work(local); + mutex_unlock(&local->mtx); +} + +static void ieee80211_hw_roc_done(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, hw_roc_done); + + mutex_lock(&local->mtx); + + ieee80211_end_finished_rocs(local, jiffies); + + /* if there's another roc, start it now */ + ieee80211_start_next_roc(local); + + mutex_unlock(&local->mtx); +} + +void ieee80211_remain_on_channel_expired(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_remain_on_channel_expired(local); + + wiphy_work_queue(hw->wiphy, &local->hw_roc_done); +} +EXPORT_SYMBOL_GPL(ieee80211_remain_on_channel_expired); + +static bool +ieee80211_coalesce_hw_started_roc(struct ieee80211_local *local, + struct ieee80211_roc_work *new_roc, + struct ieee80211_roc_work *cur_roc) +{ + unsigned long now = jiffies; + unsigned long remaining; + + if (WARN_ON(!cur_roc->started)) + return false; + + /* if it was scheduled in the hardware, but not started yet, + * we can only combine if the older one had a 
longer duration + */ + if (!cur_roc->hw_begun && new_roc->duration > cur_roc->duration) + return false; + + remaining = cur_roc->start_time + + msecs_to_jiffies(cur_roc->duration) - + now; + + /* if it doesn't fit entirely, schedule a new one */ + if (new_roc->duration > jiffies_to_msecs(remaining)) + return false; + + /* add just after the current one so we combine their finish later */ + list_add(&new_roc->list, &cur_roc->list); + + /* if the existing one has already begun then let this one also + * begin, otherwise they'll both be marked properly by the work + * struct that runs once the driver notifies us of the beginning + */ + if (cur_roc->hw_begun) { + new_roc->hw_begun = true; + ieee80211_handle_roc_started(new_roc, now); + } + + return true; +} + +static int ieee80211_start_roc_work(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel *channel, + unsigned int duration, u64 *cookie, + struct sk_buff *txskb, + enum ieee80211_roc_type type) +{ + struct ieee80211_roc_work *roc, *tmp; + bool queued = false, combine_started = true; + int ret; + + lockdep_assert_held(&local->mtx); + + if (channel->freq_offset) + /* this may work, but is untested */ + return -EOPNOTSUPP; + + if (local->use_chanctx && !local->ops->remain_on_channel) + return -EOPNOTSUPP; + + roc = kzalloc(sizeof(*roc), GFP_KERNEL); + if (!roc) + return -ENOMEM; + + /* + * If the duration is zero, then the driver + * wouldn't actually do anything. Set it to + * 10 for now. + * + * TODO: cancel the off-channel operation + * when we get the SKB's TX status and + * the wait time was zero before. + */ + if (!duration) + duration = 10; + + roc->chan = channel; + roc->duration = duration; + roc->req_duration = duration; + roc->frame = txskb; + roc->type = type; + roc->sdata = sdata; + + /* + * cookie is either the roc cookie (for normal roc) + * or the SKB (for mgmt TX) + */ + if (!txskb) { + roc->cookie = ieee80211_mgmt_tx_cookie(local); + *cookie = roc->cookie; + } else { + roc->mgmt_tx_cookie = *cookie; + } + + /* if there's no need to queue, handle it immediately */ + if (list_empty(&local->roc_list) && + !local->scanning && !ieee80211_is_radar_required(local)) { + /* if not HW assist, just queue & schedule work */ + if (!local->ops->remain_on_channel) { + list_add_tail(&roc->list, &local->roc_list); + wiphy_delayed_work_queue(local->hw.wiphy, + &local->roc_work, 0); + } else { + /* otherwise actually kick it off here + * (for error handling) + */ + ret = drv_remain_on_channel(local, sdata, channel, + duration, type); + if (ret) { + kfree(roc); + return ret; + } + roc->started = true; + list_add_tail(&roc->list, &local->roc_list); + } + + return 0; + } + + /* otherwise handle queueing */ + + list_for_each_entry(tmp, &local->roc_list, list) { + if (tmp->chan != channel || tmp->sdata != sdata) + continue; + + /* + * Extend this ROC if possible: If it hasn't started, add + * just after the new one to combine. + */ + if (!tmp->started) { + list_add(&roc->list, &tmp->list); + queued = true; + break; + } + + if (!combine_started) + continue; + + if (!local->ops->remain_on_channel) { + /* If there's no hardware remain-on-channel, and + * doing so won't push us over the maximum r-o-c + * we allow, then we can just add the new one to + * the list and mark it as having started now. + * If it would push over the limit, don't try to + * combine with other started ones (that haven't + * been running as long) but potentially sort it + * with others that had the same fate. 
+ */ + unsigned long now = jiffies; + u32 elapsed = jiffies_to_msecs(now - tmp->start_time); + struct wiphy *wiphy = local->hw.wiphy; + u32 max_roc = wiphy->max_remain_on_channel_duration; + + if (elapsed + roc->duration > max_roc) { + combine_started = false; + continue; + } + + list_add(&roc->list, &tmp->list); + queued = true; + roc->on_channel = tmp->on_channel; + ieee80211_handle_roc_started(roc, now); + ieee80211_recalc_sw_work(local, now); + break; + } + + queued = ieee80211_coalesce_hw_started_roc(local, roc, tmp); + if (queued) + break; + /* if it wasn't queued, perhaps it can be combined with + * another that also couldn't get combined previously, + * but no need to check for already started ones, since + * that can't work. + */ + combine_started = false; + } + + if (!queued) + list_add_tail(&roc->list, &local->roc_list); + + return 0; +} + +int ieee80211_remain_on_channel(struct wiphy *wiphy, struct wireless_dev *wdev, + struct ieee80211_channel *chan, + unsigned int duration, u64 *cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct ieee80211_local *local = sdata->local; + int ret; + + mutex_lock(&local->mtx); + ret = ieee80211_start_roc_work(local, sdata, chan, + duration, cookie, NULL, + IEEE80211_ROC_TYPE_NORMAL); + mutex_unlock(&local->mtx); + + return ret; +} + +static int ieee80211_cancel_roc(struct ieee80211_local *local, + u64 cookie, bool mgmt_tx) +{ + struct ieee80211_roc_work *roc, *tmp, *found = NULL; + int ret; + + if (!cookie) + return -ENOENT; + + wiphy_work_flush(local->hw.wiphy, &local->hw_roc_start); + + mutex_lock(&local->mtx); + list_for_each_entry_safe(roc, tmp, &local->roc_list, list) { + if (!mgmt_tx && roc->cookie != cookie) + continue; + else if (mgmt_tx && roc->mgmt_tx_cookie != cookie) + continue; + + found = roc; + break; + } + + if (!found) { + mutex_unlock(&local->mtx); + return -ENOENT; + } + + if (!found->started) { + ieee80211_roc_notify_destroy(found); + goto out_unlock; + } + + if (local->ops->remain_on_channel) { + ret = drv_cancel_remain_on_channel(local, roc->sdata); + if (WARN_ON_ONCE(ret)) { + mutex_unlock(&local->mtx); + return ret; + } + + /* TODO: + * if multiple items were combined here then we really shouldn't + * cancel them all - we should wait for as much time as needed + * for the longest remaining one, and only then cancel ... 
+ */ + list_for_each_entry_safe(roc, tmp, &local->roc_list, list) { + if (!roc->started) + break; + if (roc == found) + found = NULL; + ieee80211_roc_notify_destroy(roc); + } + + /* that really must not happen - it was started */ + WARN_ON(found); + + ieee80211_start_next_roc(local); + } else { + /* go through work struct to return to the operating channel */ + found->abort = true; + wiphy_delayed_work_queue(local->hw.wiphy, &local->roc_work, 0); + } + + out_unlock: + mutex_unlock(&local->mtx); + + return 0; +} + +int ieee80211_cancel_remain_on_channel(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct ieee80211_local *local = sdata->local; + + return ieee80211_cancel_roc(local, cookie, false); +} + +int ieee80211_mgmt_tx(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params, u64 *cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct sta_info *sta = NULL; + const struct ieee80211_mgmt *mgmt = (void *)params->buf; + bool need_offchan = false; + bool mlo_sta = false; + int link_id = -1; + u32 flags; + int ret; + u8 *data; + + if (params->dont_wait_for_ack) + flags = IEEE80211_TX_CTL_NO_ACK; + else + flags = IEEE80211_TX_INTFL_NL80211_FRAME_TX | + IEEE80211_TX_CTL_REQ_TX_STATUS; + + if (params->no_cck) + flags |= IEEE80211_TX_CTL_NO_CCK_RATE; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_ADHOC: + if (!sdata->vif.cfg.ibss_joined) + need_offchan = true; +#ifdef CONFIG_MAC80211_MESH + fallthrough; + case NL80211_IFTYPE_MESH_POINT: + if (ieee80211_vif_is_mesh(&sdata->vif) && + !sdata->u.mesh.mesh_id_len) + need_offchan = true; +#endif + fallthrough; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + if (sdata->vif.type != NL80211_IFTYPE_ADHOC && + !ieee80211_vif_is_mesh(&sdata->vif) && + !sdata->bss->active) + need_offchan = true; + + rcu_read_lock(); + sta = sta_info_get_bss(sdata, mgmt->da); + mlo_sta = sta && sta->sta.mlo; + + if (!ieee80211_is_action(mgmt->frame_control) || + mgmt->u.action.category == WLAN_CATEGORY_PUBLIC || + mgmt->u.action.category == WLAN_CATEGORY_SELF_PROTECTED || + mgmt->u.action.category == WLAN_CATEGORY_SPECTRUM_MGMT) { + rcu_read_unlock(); + break; + } + + if (!sta) { + rcu_read_unlock(); + return -ENOLINK; + } + if (params->link_id >= 0 && + !(sta->sta.valid_links & BIT(params->link_id))) { + rcu_read_unlock(); + return -ENOLINK; + } + link_id = params->link_id; + rcu_read_unlock(); + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + sdata_lock(sdata); + if (!sdata->u.mgd.associated || + (params->offchan && params->wait && + local->ops->remain_on_channel && + memcmp(sdata->vif.cfg.ap_addr, mgmt->bssid, ETH_ALEN))) + need_offchan = true; + sdata_unlock(sdata); + break; + case NL80211_IFTYPE_P2P_DEVICE: + need_offchan = true; + break; + case NL80211_IFTYPE_NAN: + default: + return -EOPNOTSUPP; + } + + /* configurations requiring offchan cannot work if no channel has been + * specified + */ + if (need_offchan && !params->chan) + return -EINVAL; + + mutex_lock(&local->mtx); + + /* Check if the operating channel is the requested channel */ + if (!params->chan && mlo_sta) { + need_offchan = false; + } else if (!need_offchan) { + struct ieee80211_chanctx_conf *chanctx_conf = NULL; + int i; + + rcu_read_lock(); + /* Check all the links first */ + for (i = 0; i < 
ARRAY_SIZE(sdata->vif.link_conf); i++) { + struct ieee80211_bss_conf *conf; + + conf = rcu_dereference(sdata->vif.link_conf[i]); + if (!conf) + continue; + + chanctx_conf = rcu_dereference(conf->chanctx_conf); + if (!chanctx_conf) + continue; + + if (mlo_sta && params->chan == chanctx_conf->def.chan && + ether_addr_equal(sdata->vif.addr, mgmt->sa)) { + link_id = i; + break; + } + + if (ether_addr_equal(conf->addr, mgmt->sa)) + break; + + chanctx_conf = NULL; + } + + if (chanctx_conf) { + need_offchan = params->chan && + (params->chan != + chanctx_conf->def.chan); + } else { + need_offchan = true; + } + rcu_read_unlock(); + } + + if (need_offchan && !params->offchan) { + ret = -EBUSY; + goto out_unlock; + } + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + params->len); + if (!skb) { + ret = -ENOMEM; + goto out_unlock; + } + skb_reserve(skb, local->hw.extra_tx_headroom); + + data = skb_put_data(skb, params->buf, params->len); + + /* Update CSA counters */ + if (sdata->vif.bss_conf.csa_active && + (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_MESH_POINT || + sdata->vif.type == NL80211_IFTYPE_ADHOC) && + params->n_csa_offsets) { + int i; + struct beacon_data *beacon = NULL; + + rcu_read_lock(); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + beacon = rcu_dereference(sdata->deflink.u.ap.beacon); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + beacon = rcu_dereference(sdata->u.ibss.presp); + else if (ieee80211_vif_is_mesh(&sdata->vif)) + beacon = rcu_dereference(sdata->u.mesh.beacon); + + if (beacon) + for (i = 0; i < params->n_csa_offsets; i++) + data[params->csa_offsets[i]] = + beacon->cntdwn_current_counter; + + rcu_read_unlock(); + } + + IEEE80211_SKB_CB(skb)->flags = flags; + + skb->dev = sdata->dev; + + if (!params->dont_wait_for_ack) { + /* make a copy to preserve the frame contents + * in case of encryption. + */ + ret = ieee80211_attach_ack_skb(local, skb, cookie, GFP_KERNEL); + if (ret) { + kfree_skb(skb); + goto out_unlock; + } + } else { + /* Assign a dummy non-zero cookie, it's not sent to + * userspace in this case but we rely on its value + * internally in the need_offchan case to distinguish + * mgmt-tx from remain-on-channel. 
+ */ + *cookie = 0xffffffff; + } + + if (!need_offchan) { + ieee80211_tx_skb_tid(sdata, skb, 7, link_id); + ret = 0; + goto out_unlock; + } + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_CTL_TX_OFFCHAN | + IEEE80211_TX_INTFL_OFFCHAN_TX_OK; + if (ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) + IEEE80211_SKB_CB(skb)->hw_queue = + local->hw.offchannel_tx_hw_queue; + + /* This will handle all kinds of coalescing and immediate TX */ + ret = ieee80211_start_roc_work(local, sdata, params->chan, + params->wait, cookie, skb, + IEEE80211_ROC_TYPE_MGMT_TX); + if (ret) + ieee80211_free_txskb(&local->hw, skb); + out_unlock: + mutex_unlock(&local->mtx); + return ret; +} + +int ieee80211_mgmt_tx_cancel_wait(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie) +{ + struct ieee80211_local *local = wiphy_priv(wiphy); + + return ieee80211_cancel_roc(local, cookie, true); +} + +void ieee80211_roc_setup(struct ieee80211_local *local) +{ + wiphy_work_init(&local->hw_roc_start, ieee80211_hw_roc_start); + wiphy_work_init(&local->hw_roc_done, ieee80211_hw_roc_done); + wiphy_delayed_work_init(&local->roc_work, ieee80211_roc_work); + INIT_LIST_HEAD(&local->roc_list); +} + +void ieee80211_roc_purge(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_roc_work *roc, *tmp; + bool work_to_do = false; + + mutex_lock(&local->mtx); + list_for_each_entry_safe(roc, tmp, &local->roc_list, list) { + if (sdata && roc->sdata != sdata) + continue; + + if (roc->started) { + if (local->ops->remain_on_channel) { + /* can race, so ignore return value */ + drv_cancel_remain_on_channel(local, sdata); + ieee80211_roc_notify_destroy(roc); + } else { + roc->abort = true; + work_to_do = true; + } + } else { + ieee80211_roc_notify_destroy(roc); + } + } + if (work_to_do) + __ieee80211_roc_work(local); + mutex_unlock(&local->mtx); +} diff --git a/net/mac80211/pm.c b/net/mac80211/pm.c new file mode 100644 index 000000000..0ccb5701c --- /dev/null +++ b/net/mac80211/pm.c @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Portions + * Copyright (C) 2020-2021 Intel Corporation + */ +#include +#include + +#include "ieee80211_i.h" +#include "mesh.h" +#include "driver-ops.h" +#include "led.h" + +static void ieee80211_sched_scan_cancel(struct ieee80211_local *local) +{ + if (ieee80211_request_sched_scan_stop(local)) + return; + cfg80211_sched_scan_stopped_locked(local->hw.wiphy, 0); +} + +int __ieee80211_suspend(struct ieee80211_hw *hw, struct cfg80211_wowlan *wowlan) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + + if (!local->open_count) + goto suspend; + + local->suspending = true; + mb(); /* make suspending visible before any cancellation */ + + ieee80211_scan_cancel(local); + + ieee80211_dfs_cac_cancel(local); + + ieee80211_roc_purge(local, NULL); + + ieee80211_del_virtual_monitor(local); + + if (ieee80211_hw_check(hw, AMPDU_AGGREGATION) && + !(wowlan && wowlan->any)) { + mutex_lock(&local->sta_mtx); + list_for_each_entry(sta, &local->sta_list, list) { + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + ieee80211_sta_tear_down_BA_sessions( + sta, AGG_STOP_LOCAL_REQUEST); + } + mutex_unlock(&local->sta_mtx); + } + + /* keep sched_scan only in case of 'any' trigger */ + if (!(wowlan && wowlan->any)) + ieee80211_sched_scan_cancel(local); + + ieee80211_stop_queues_by_reason(hw, + IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + false); + + /* flush out all packets */ + synchronize_net(); + + 
ieee80211_flush_queues(local, NULL, true); + + local->quiescing = true; + /* make quiescing visible to timers everywhere */ + mb(); + + flush_workqueue(local->workqueue); + + /* Don't try to run timers while suspended. */ + del_timer_sync(&local->sta_cleanup); + + /* + * Note that this particular timer doesn't need to be + * restarted at resume. + */ + cancel_work_sync(&local->dynamic_ps_enable_work); + del_timer_sync(&local->dynamic_ps_timer); + + local->wowlan = wowlan; + if (local->wowlan) { + int err; + + /* Drivers don't expect to suspend while some operations like + * authenticating or associating are in progress. It doesn't + * make sense anyway to accept that, since the authentication + * or association would never finish since the driver can't do + * that on its own. + * Thus, clean up in-progress auth/assoc first. + */ + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + if (sdata->vif.type != NL80211_IFTYPE_STATION) + continue; + ieee80211_mgd_quiesce(sdata); + /* If suspended during TX in progress, and wowlan + * is enabled (connection will be active) there + * can be a race where the driver is put out + * of power-save due to TX and during suspend + * dynamic_ps_timer is cancelled and TX packet + * is flushed, leaving the driver in ACTIVE even + * after resuming until dynamic_ps_timer puts + * driver back in DOZE. + */ + if (sdata->u.mgd.associated && + sdata->u.mgd.powersave && + !(local->hw.conf.flags & IEEE80211_CONF_PS)) { + local->hw.conf.flags |= IEEE80211_CONF_PS; + ieee80211_hw_config(local, + IEEE80211_CONF_CHANGE_PS); + } + } + + err = drv_suspend(local, wowlan); + if (err < 0) { + local->quiescing = false; + local->wowlan = false; + if (ieee80211_hw_check(hw, AMPDU_AGGREGATION)) { + mutex_lock(&local->sta_mtx); + list_for_each_entry(sta, + &local->sta_list, list) { + clear_sta_flag(sta, WLAN_STA_BLOCK_BA); + } + mutex_unlock(&local->sta_mtx); + } + ieee80211_wake_queues_by_reason(hw, + IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + false); + return err; + } else if (err > 0) { + WARN_ON(err != 1); + /* cfg80211 will call back into mac80211 to disconnect + * all interfaces, allow that to proceed properly + */ + ieee80211_wake_queues_by_reason(hw, + IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + false); + return err; + } else { + goto suspend; + } + } + + /* remove all interfaces that were created in the driver */ + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MONITOR: + continue; + case NL80211_IFTYPE_STATION: + ieee80211_mgd_quiesce(sdata); + break; + default: + break; + } + + flush_delayed_work(&sdata->dec_tailroom_needed_wk); + drv_remove_interface(local, sdata); + } + + /* + * We disconnected on all interfaces before suspend, all channel + * contexts should be released. + */ + WARN_ON(!list_empty(&local->chanctx_list)); + + /* stop hardware - this must stop RX */ + ieee80211_stop_device(local); + + suspend: + local->suspended = true; + /* need suspended to be visible before quiescing is false */ + barrier(); + local->quiescing = false; + local->suspending = false; + + return 0; +} + +/* + * __ieee80211_resume() is a static inline which just calls + * ieee80211_reconfig(), which is also needed for hardware + * hang/firmware failure/etc. recovery. 
+ */ + +void ieee80211_report_wowlan_wakeup(struct ieee80211_vif *vif, + struct cfg80211_wowlan_wakeup *wakeup, + gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + cfg80211_report_wowlan_wakeup(&sdata->wdev, wakeup, gfp); +} +EXPORT_SYMBOL(ieee80211_report_wowlan_wakeup); diff --git a/net/mac80211/rate.c b/net/mac80211/rate.c new file mode 100644 index 000000000..d5ea5f5bc --- /dev/null +++ b/net/mac80211/rate.c @@ -0,0 +1,1018 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright (c) 2006 Jiri Benc + * Copyright 2017 Intel Deutschland GmbH + * Copyright (C) 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include "rate.h" +#include "ieee80211_i.h" +#include "debugfs.h" + +struct rate_control_alg { + struct list_head list; + const struct rate_control_ops *ops; +}; + +static LIST_HEAD(rate_ctrl_algs); +static DEFINE_MUTEX(rate_ctrl_mutex); + +static char *ieee80211_default_rc_algo = CONFIG_MAC80211_RC_DEFAULT; +module_param(ieee80211_default_rc_algo, charp, 0644); +MODULE_PARM_DESC(ieee80211_default_rc_algo, + "Default rate control algorithm for mac80211 to use"); + +void rate_control_rate_init(struct sta_info *sta) +{ + struct ieee80211_local *local = sta->sdata->local; + struct rate_control_ref *ref = sta->rate_ctrl; + struct ieee80211_sta *ista = &sta->sta; + void *priv_sta = sta->rate_ctrl_priv; + struct ieee80211_supported_band *sband; + struct ieee80211_chanctx_conf *chanctx_conf; + + ieee80211_sta_set_rx_nss(&sta->deflink); + + if (!ref) + return; + + rcu_read_lock(); + + chanctx_conf = rcu_dereference(sta->sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + return; + } + + sband = local->hw.wiphy->bands[chanctx_conf->def.chan->band]; + + /* TODO: check for minstrel_s1g ? 
*/ + if (sband->band == NL80211_BAND_S1GHZ) { + ieee80211_s1g_sta_rate_init(sta); + rcu_read_unlock(); + return; + } + + spin_lock_bh(&sta->rate_ctrl_lock); + ref->ops->rate_init(ref->priv, sband, &chanctx_conf->def, ista, + priv_sta); + spin_unlock_bh(&sta->rate_ctrl_lock); + rcu_read_unlock(); + set_sta_flag(sta, WLAN_STA_RATE_CONTROL); +} + +void rate_control_tx_status(struct ieee80211_local *local, + struct ieee80211_tx_status *st) +{ + struct rate_control_ref *ref = local->rate_ctrl; + struct sta_info *sta = container_of(st->sta, struct sta_info, sta); + void *priv_sta = sta->rate_ctrl_priv; + struct ieee80211_supported_band *sband; + + if (!ref || !test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) + return; + + sband = local->hw.wiphy->bands[st->info->band]; + + spin_lock_bh(&sta->rate_ctrl_lock); + if (ref->ops->tx_status_ext) + ref->ops->tx_status_ext(ref->priv, sband, priv_sta, st); + else if (st->skb) + ref->ops->tx_status(ref->priv, sband, st->sta, priv_sta, st->skb); + else + WARN_ON_ONCE(1); + + spin_unlock_bh(&sta->rate_ctrl_lock); +} + +void rate_control_rate_update(struct ieee80211_local *local, + struct ieee80211_supported_band *sband, + struct sta_info *sta, unsigned int link_id, + u32 changed) +{ + struct rate_control_ref *ref = local->rate_ctrl; + struct ieee80211_sta *ista = &sta->sta; + void *priv_sta = sta->rate_ctrl_priv; + struct ieee80211_chanctx_conf *chanctx_conf; + + WARN_ON(link_id != 0); + + if (ref && ref->ops->rate_update) { + rcu_read_lock(); + + chanctx_conf = rcu_dereference(sta->sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + return; + } + + spin_lock_bh(&sta->rate_ctrl_lock); + ref->ops->rate_update(ref->priv, sband, &chanctx_conf->def, + ista, priv_sta, changed); + spin_unlock_bh(&sta->rate_ctrl_lock); + rcu_read_unlock(); + } + + drv_sta_rc_update(local, sta->sdata, &sta->sta, changed); +} + +int ieee80211_rate_control_register(const struct rate_control_ops *ops) +{ + struct rate_control_alg *alg; + + if (!ops->name) + return -EINVAL; + + mutex_lock(&rate_ctrl_mutex); + list_for_each_entry(alg, &rate_ctrl_algs, list) { + if (!strcmp(alg->ops->name, ops->name)) { + /* don't register an algorithm twice */ + WARN_ON(1); + mutex_unlock(&rate_ctrl_mutex); + return -EALREADY; + } + } + + alg = kzalloc(sizeof(*alg), GFP_KERNEL); + if (alg == NULL) { + mutex_unlock(&rate_ctrl_mutex); + return -ENOMEM; + } + alg->ops = ops; + + list_add_tail(&alg->list, &rate_ctrl_algs); + mutex_unlock(&rate_ctrl_mutex); + + return 0; +} +EXPORT_SYMBOL(ieee80211_rate_control_register); + +void ieee80211_rate_control_unregister(const struct rate_control_ops *ops) +{ + struct rate_control_alg *alg; + + mutex_lock(&rate_ctrl_mutex); + list_for_each_entry(alg, &rate_ctrl_algs, list) { + if (alg->ops == ops) { + list_del(&alg->list); + kfree(alg); + break; + } + } + mutex_unlock(&rate_ctrl_mutex); +} +EXPORT_SYMBOL(ieee80211_rate_control_unregister); + +static const struct rate_control_ops * +ieee80211_try_rate_control_ops_get(const char *name) +{ + struct rate_control_alg *alg; + const struct rate_control_ops *ops = NULL; + + if (!name) + return NULL; + + mutex_lock(&rate_ctrl_mutex); + list_for_each_entry(alg, &rate_ctrl_algs, list) { + if (!strcmp(alg->ops->name, name)) { + ops = alg->ops; + break; + } + } + mutex_unlock(&rate_ctrl_mutex); + return ops; +} + +/* Get the rate control algorithm. 
*/ +static const struct rate_control_ops * +ieee80211_rate_control_ops_get(const char *name) +{ + const struct rate_control_ops *ops; + const char *alg_name; + + kernel_param_lock(THIS_MODULE); + if (!name) + alg_name = ieee80211_default_rc_algo; + else + alg_name = name; + + ops = ieee80211_try_rate_control_ops_get(alg_name); + if (!ops && name) + /* try default if specific alg requested but not found */ + ops = ieee80211_try_rate_control_ops_get(ieee80211_default_rc_algo); + + /* Note: check for > 0 is intentional to avoid clang warning */ + if (!ops && (strlen(CONFIG_MAC80211_RC_DEFAULT) > 0)) + /* try built-in one if specific alg requested but not found */ + ops = ieee80211_try_rate_control_ops_get(CONFIG_MAC80211_RC_DEFAULT); + + kernel_param_unlock(THIS_MODULE); + + return ops; +} + +#ifdef CONFIG_MAC80211_DEBUGFS +static ssize_t rcname_read(struct file *file, char __user *userbuf, + size_t count, loff_t *ppos) +{ + struct rate_control_ref *ref = file->private_data; + int len = strlen(ref->ops->name); + + return simple_read_from_buffer(userbuf, count, ppos, + ref->ops->name, len); +} + +const struct file_operations rcname_ops = { + .read = rcname_read, + .open = simple_open, + .llseek = default_llseek, +}; +#endif + +static struct rate_control_ref * +rate_control_alloc(const char *name, struct ieee80211_local *local) +{ + struct rate_control_ref *ref; + + ref = kmalloc(sizeof(struct rate_control_ref), GFP_KERNEL); + if (!ref) + return NULL; + ref->ops = ieee80211_rate_control_ops_get(name); + if (!ref->ops) + goto free; + + ref->priv = ref->ops->alloc(&local->hw); + if (!ref->priv) + goto free; + return ref; + +free: + kfree(ref); + return NULL; +} + +static void rate_control_free(struct ieee80211_local *local, + struct rate_control_ref *ctrl_ref) +{ + ctrl_ref->ops->free(ctrl_ref->priv); + +#ifdef CONFIG_MAC80211_DEBUGFS + debugfs_remove_recursive(local->debugfs.rcdir); + local->debugfs.rcdir = NULL; +#endif + + kfree(ctrl_ref); +} + +void ieee80211_check_rate_mask(struct ieee80211_link_data *link) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + u32 user_mask, basic_rates = link->conf->basic_rates; + enum nl80211_band band; + + if (WARN_ON(!link->conf->chandef.chan)) + return; + + band = link->conf->chandef.chan->band; + if (band == NL80211_BAND_S1GHZ) { + /* TODO */ + return; + } + + if (WARN_ON_ONCE(!basic_rates)) + return; + + user_mask = sdata->rc_rateidx_mask[band]; + sband = local->hw.wiphy->bands[band]; + + if (user_mask & basic_rates) + return; + + sdata_dbg(sdata, + "no overlap between basic rates (0x%x) and user mask (0x%x on band %d) - clearing the latter", + basic_rates, user_mask, band); + sdata->rc_rateidx_mask[band] = (1 << sband->n_bitrates) - 1; +} + +static bool rc_no_data_or_no_ack_use_min(struct ieee80211_tx_rate_control *txrc) +{ + struct sk_buff *skb = txrc->skb; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + + return (info->flags & (IEEE80211_TX_CTL_NO_ACK | + IEEE80211_TX_CTL_USE_MINRATE)) || + !ieee80211_is_tx_data(skb); +} + +static void rc_send_low_basicrate(struct ieee80211_tx_rate *rate, + u32 basic_rates, + struct ieee80211_supported_band *sband) +{ + u8 i; + + if (sband->band == NL80211_BAND_S1GHZ) { + /* TODO */ + rate->flags |= IEEE80211_TX_RC_S1G_MCS; + rate->idx = 0; + return; + } + + if (basic_rates == 0) + return; /* assume basic rates unknown and accept rate */ + if (rate->idx < 0) + return; + if (basic_rates & (1 << rate->idx)) + 
return; /* selected rate is a basic rate */ + + for (i = rate->idx + 1; i <= sband->n_bitrates; i++) { + if (basic_rates & (1 << i)) { + rate->idx = i; + return; + } + } + + /* could not find a basic rate; use original selection */ +} + +static void __rate_control_send_low(struct ieee80211_hw *hw, + struct ieee80211_supported_band *sband, + struct ieee80211_sta *sta, + struct ieee80211_tx_info *info, + u32 rate_mask) +{ + int i; + u32 rate_flags = + ieee80211_chandef_rate_flags(&hw->conf.chandef); + + if (sband->band == NL80211_BAND_S1GHZ) { + info->control.rates[0].flags |= IEEE80211_TX_RC_S1G_MCS; + info->control.rates[0].idx = 0; + return; + } + + if ((sband->band == NL80211_BAND_2GHZ) && + (info->flags & IEEE80211_TX_CTL_NO_CCK_RATE)) + rate_flags |= IEEE80211_RATE_ERP_G; + + info->control.rates[0].idx = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if (!(rate_mask & BIT(i))) + continue; + + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + + if (!rate_supported(sta, sband->band, i)) + continue; + + info->control.rates[0].idx = i; + break; + } + WARN_ONCE(i == sband->n_bitrates, + "no supported rates for sta %pM (0x%x, band %d) in rate_mask 0x%x with flags 0x%x\n", + sta ? sta->addr : NULL, + sta ? sta->deflink.supp_rates[sband->band] : -1, + sband->band, + rate_mask, rate_flags); + + info->control.rates[0].count = + (info->flags & IEEE80211_TX_CTL_NO_ACK) ? + 1 : hw->max_rate_tries; + + info->control.skip_table = 1; +} + + +static bool rate_control_send_low(struct ieee80211_sta *pubsta, + struct ieee80211_tx_rate_control *txrc) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb); + struct ieee80211_supported_band *sband = txrc->sband; + struct sta_info *sta; + int mcast_rate; + bool use_basicrate = false; + + if (!pubsta || rc_no_data_or_no_ack_use_min(txrc)) { + __rate_control_send_low(txrc->hw, sband, pubsta, info, + txrc->rate_idx_mask); + + if (!pubsta && txrc->bss) { + mcast_rate = txrc->bss_conf->mcast_rate[sband->band]; + if (mcast_rate > 0) { + info->control.rates[0].idx = mcast_rate - 1; + return true; + } + use_basicrate = true; + } else if (pubsta) { + sta = container_of(pubsta, struct sta_info, sta); + if (ieee80211_vif_is_mesh(&sta->sdata->vif)) + use_basicrate = true; + } + + if (use_basicrate) + rc_send_low_basicrate(&info->control.rates[0], + txrc->bss_conf->basic_rates, + sband); + + return true; + } + return false; +} + +static bool rate_idx_match_legacy_mask(s8 *rate_idx, int n_bitrates, u32 mask) +{ + int j; + + /* See whether the selected rate or anything below it is allowed. */ + for (j = *rate_idx; j >= 0; j--) { + if (mask & (1 << j)) { + /* Okay, found a suitable rate. Use it. */ + *rate_idx = j; + return true; + } + } + + /* Try to find a higher rate that would be allowed */ + for (j = *rate_idx + 1; j < n_bitrates; j++) { + if (mask & (1 << j)) { + /* Okay, found a suitable rate. Use it. */ + *rate_idx = j; + return true; + } + } + return false; +} + +static bool rate_idx_match_mcs_mask(s8 *rate_idx, u8 *mcs_mask) +{ + int i, j; + int ridx, rbit; + + ridx = *rate_idx / 8; + rbit = *rate_idx % 8; + + /* sanity check */ + if (ridx < 0 || ridx >= IEEE80211_HT_MCS_MASK_LEN) + return false; + + /* See whether the selected rate or anything below it is allowed. 
*/ + for (i = ridx; i >= 0; i--) { + for (j = rbit; j >= 0; j--) + if (mcs_mask[i] & BIT(j)) { + *rate_idx = i * 8 + j; + return true; + } + rbit = 7; + } + + /* Try to find a higher rate that would be allowed */ + ridx = (*rate_idx + 1) / 8; + rbit = (*rate_idx + 1) % 8; + + for (i = ridx; i < IEEE80211_HT_MCS_MASK_LEN; i++) { + for (j = rbit; j < 8; j++) + if (mcs_mask[i] & BIT(j)) { + *rate_idx = i * 8 + j; + return true; + } + rbit = 0; + } + return false; +} + +static bool rate_idx_match_vht_mcs_mask(s8 *rate_idx, u16 *vht_mask) +{ + int i, j; + int ridx, rbit; + + ridx = *rate_idx >> 4; + rbit = *rate_idx & 0xf; + + if (ridx < 0 || ridx >= NL80211_VHT_NSS_MAX) + return false; + + /* See whether the selected rate or anything below it is allowed. */ + for (i = ridx; i >= 0; i--) { + for (j = rbit; j >= 0; j--) { + if (vht_mask[i] & BIT(j)) { + *rate_idx = (i << 4) | j; + return true; + } + } + rbit = 15; + } + + /* Try to find a higher rate that would be allowed */ + ridx = (*rate_idx + 1) >> 4; + rbit = (*rate_idx + 1) & 0xf; + + for (i = ridx; i < NL80211_VHT_NSS_MAX; i++) { + for (j = rbit; j < 16; j++) { + if (vht_mask[i] & BIT(j)) { + *rate_idx = (i << 4) | j; + return true; + } + } + rbit = 0; + } + return false; +} + +static void rate_idx_match_mask(s8 *rate_idx, u16 *rate_flags, + struct ieee80211_supported_band *sband, + enum nl80211_chan_width chan_width, + u32 mask, + u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN], + u16 vht_mask[NL80211_VHT_NSS_MAX]) +{ + if (*rate_flags & IEEE80211_TX_RC_VHT_MCS) { + /* handle VHT rates */ + if (rate_idx_match_vht_mcs_mask(rate_idx, vht_mask)) + return; + + *rate_idx = 0; + /* keep protection flags */ + *rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS | + IEEE80211_TX_RC_USE_CTS_PROTECT | + IEEE80211_TX_RC_USE_SHORT_PREAMBLE); + + *rate_flags |= IEEE80211_TX_RC_MCS; + if (chan_width == NL80211_CHAN_WIDTH_40) + *rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH; + + if (rate_idx_match_mcs_mask(rate_idx, mcs_mask)) + return; + + /* also try the legacy rates. */ + *rate_flags &= ~(IEEE80211_TX_RC_MCS | + IEEE80211_TX_RC_40_MHZ_WIDTH); + if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates, + mask)) + return; + } else if (*rate_flags & IEEE80211_TX_RC_MCS) { + /* handle HT rates */ + if (rate_idx_match_mcs_mask(rate_idx, mcs_mask)) + return; + + /* also try the legacy rates. */ + *rate_idx = 0; + /* keep protection flags */ + *rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS | + IEEE80211_TX_RC_USE_CTS_PROTECT | + IEEE80211_TX_RC_USE_SHORT_PREAMBLE); + if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates, + mask)) + return; + } else { + /* handle legacy rates */ + if (rate_idx_match_legacy_mask(rate_idx, sband->n_bitrates, + mask)) + return; + + /* if HT BSS, and we handle a data frame, also try HT rates */ + switch (chan_width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + return; + default: + break; + } + + *rate_idx = 0; + /* keep protection flags */ + *rate_flags &= (IEEE80211_TX_RC_USE_RTS_CTS | + IEEE80211_TX_RC_USE_CTS_PROTECT | + IEEE80211_TX_RC_USE_SHORT_PREAMBLE); + + *rate_flags |= IEEE80211_TX_RC_MCS; + + if (chan_width == NL80211_CHAN_WIDTH_40) + *rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH; + + if (rate_idx_match_mcs_mask(rate_idx, mcs_mask)) + return; + } + + /* + * Uh.. No suitable rate exists. This should not really happen with + * sane TX rate mask configurations. 
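rate_idx_match_legacy_mask() above first walks down from the selected index looking for an allowed rate and only then tries higher indexes. The same search on a plain bitmask, as a standalone sketch (the 12-rate table size and the mask values are invented for illustration):

#include <stdbool.h>
#include <stdio.h>

/* Mirror of rate_idx_match_legacy_mask(): prefer the selected index or
 * the closest allowed index below it, otherwise fall back to the
 * closest allowed index above it. */
static bool match_legacy_mask(int *rate_idx, int n_bitrates, unsigned int mask)
{
	int j;

	for (j = *rate_idx; j >= 0; j--) {
		if (mask & (1u << j)) {
			*rate_idx = j;
			return true;
		}
	}

	for (j = *rate_idx + 1; j < n_bitrates; j++) {
		if (mask & (1u << j)) {
			*rate_idx = j;
			return true;
		}
	}

	return false;	/* the mask allows nothing at all */
}

int main(void)
{
	/* 12 legacy rates; the mask only allows indexes 2, 4 and 7 */
	unsigned int mask = (1u << 2) | (1u << 4) | (1u << 7);
	int idx = 5;	/* index picked by rate control */

	if (match_legacy_mask(&idx, 12, mask))
		printf("adjusted rate index: %d\n", idx);	/* prints 4 */
	return 0;
}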
However, should someone manage to + * configure supported rates and TX rate mask in incompatible way, + * allow the frame to be transmitted with whatever the rate control + * selected. + */ +} + +static void rate_fixup_ratelist(struct ieee80211_vif *vif, + struct ieee80211_supported_band *sband, + struct ieee80211_tx_info *info, + struct ieee80211_tx_rate *rates, + int max_rates) +{ + struct ieee80211_rate *rate; + bool inval = false; + int i; + + /* + * Set up the RTS/CTS rate as the fastest basic rate + * that is not faster than the data rate unless there + * is no basic rate slower than the data rate, in which + * case we pick the slowest basic rate + * + * XXX: Should this check all retry rates? + */ + if (!(rates[0].flags & + (IEEE80211_TX_RC_MCS | IEEE80211_TX_RC_VHT_MCS))) { + u32 basic_rates = vif->bss_conf.basic_rates; + s8 baserate = basic_rates ? ffs(basic_rates) - 1 : 0; + + rate = &sband->bitrates[rates[0].idx]; + + for (i = 0; i < sband->n_bitrates; i++) { + /* must be a basic rate */ + if (!(basic_rates & BIT(i))) + continue; + /* must not be faster than the data rate */ + if (sband->bitrates[i].bitrate > rate->bitrate) + continue; + /* maximum */ + if (sband->bitrates[baserate].bitrate < + sband->bitrates[i].bitrate) + baserate = i; + } + + info->control.rts_cts_rate_idx = baserate; + } + + for (i = 0; i < max_rates; i++) { + /* + * make sure there's no valid rate following + * an invalid one, just in case drivers don't + * take the API seriously to stop at -1. + */ + if (inval) { + rates[i].idx = -1; + continue; + } + if (rates[i].idx < 0) { + inval = true; + continue; + } + + /* + * For now assume MCS is already set up correctly, this + * needs to be fixed. + */ + if (rates[i].flags & IEEE80211_TX_RC_MCS) { + WARN_ON(rates[i].idx > 76); + + if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) && + info->control.use_cts_prot) + rates[i].flags |= + IEEE80211_TX_RC_USE_CTS_PROTECT; + continue; + } + + if (rates[i].flags & IEEE80211_TX_RC_VHT_MCS) { + WARN_ON(ieee80211_rate_get_vht_mcs(&rates[i]) > 9); + continue; + } + + /* set up RTS protection if desired */ + if (info->control.use_rts) { + rates[i].flags |= IEEE80211_TX_RC_USE_RTS_CTS; + info->control.use_cts_prot = false; + } + + /* RC is busted */ + if (WARN_ON_ONCE(rates[i].idx >= sband->n_bitrates)) { + rates[i].idx = -1; + continue; + } + + rate = &sband->bitrates[rates[i].idx]; + + /* set up short preamble */ + if (info->control.short_preamble && + rate->flags & IEEE80211_RATE_SHORT_PREAMBLE) + rates[i].flags |= IEEE80211_TX_RC_USE_SHORT_PREAMBLE; + + /* set up G protection */ + if (!(rates[i].flags & IEEE80211_TX_RC_USE_RTS_CTS) && + info->control.use_cts_prot && + rate->flags & IEEE80211_RATE_ERP_G) + rates[i].flags |= IEEE80211_TX_RC_USE_CTS_PROTECT; + } +} + + +static void rate_control_fill_sta_table(struct ieee80211_sta *sta, + struct ieee80211_tx_info *info, + struct ieee80211_tx_rate *rates, + int max_rates) +{ + struct ieee80211_sta_rates *ratetbl = NULL; + int i; + + if (sta && !info->control.skip_table) + ratetbl = rcu_dereference(sta->rates); + + /* Fill remaining rate slots with data from the sta rate table. 
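rate_fixup_ratelist() above derives the RTS/CTS rate as the fastest basic rate that is not faster than the data rate, keeping the initial ffs() pick when no basic rate is slow enough. A standalone sketch of that selection; the bitrate table (in 100 kbit/s units) and the masks are invented for illustration:

#include <stdio.h>
#include <strings.h>	/* ffs() */

static const int bitrates[] = { 10, 20, 55, 110, 60, 90, 120, 180 };
#define N_BITRATES	(int)(sizeof(bitrates) / sizeof(bitrates[0]))

/* Fastest basic rate that is not faster than the data rate; if every
 * basic rate is faster, the initial ffs() pick (the lowest-index basic
 * rate) is kept. */
static int pick_rts_cts_rate(unsigned int basic_rates, int data_idx)
{
	int baserate = basic_rates ? ffs(basic_rates) - 1 : 0;
	int i;

	for (i = 0; i < N_BITRATES; i++) {
		if (!(basic_rates & (1u << i)))
			continue;
		if (bitrates[i] > bitrates[data_idx])
			continue;
		if (bitrates[baserate] < bitrates[i])
			baserate = i;
	}
	return baserate;
}

int main(void)
{
	/* basic rates 1, 2 and 11 Mbit/s; data rate 12 Mbit/s (index 6) */
	unsigned int basic = (1u << 0) | (1u << 1) | (1u << 3);

	printf("RTS/CTS rate index: %d\n", pick_rts_cts_rate(basic, 6)); /* 3 */
	return 0;
}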
*/ + max_rates = min_t(int, max_rates, IEEE80211_TX_RATE_TABLE_SIZE); + for (i = 0; i < max_rates; i++) { + if (i < ARRAY_SIZE(info->control.rates) && + info->control.rates[i].idx >= 0 && + info->control.rates[i].count) { + if (rates != info->control.rates) + rates[i] = info->control.rates[i]; + } else if (ratetbl) { + rates[i].idx = ratetbl->rate[i].idx; + rates[i].flags = ratetbl->rate[i].flags; + if (info->control.use_rts) + rates[i].count = ratetbl->rate[i].count_rts; + else if (info->control.use_cts_prot) + rates[i].count = ratetbl->rate[i].count_cts; + else + rates[i].count = ratetbl->rate[i].count; + } else { + rates[i].idx = -1; + rates[i].count = 0; + } + + if (rates[i].idx < 0 || !rates[i].count) + break; + } +} + +static bool rate_control_cap_mask(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + struct ieee80211_sta *sta, u32 *mask, + u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN], + u16 vht_mask[NL80211_VHT_NSS_MAX]) +{ + u32 i, flags; + + *mask = sdata->rc_rateidx_mask[sband->band]; + flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chandef); + for (i = 0; i < sband->n_bitrates; i++) { + if ((flags & sband->bitrates[i].flags) != flags) + *mask &= ~BIT(i); + } + + if (*mask == (1 << sband->n_bitrates) - 1 && + !sdata->rc_has_mcs_mask[sband->band] && + !sdata->rc_has_vht_mcs_mask[sband->band]) + return false; + + if (sdata->rc_has_mcs_mask[sband->band]) + memcpy(mcs_mask, sdata->rc_rateidx_mcs_mask[sband->band], + IEEE80211_HT_MCS_MASK_LEN); + else + memset(mcs_mask, 0xff, IEEE80211_HT_MCS_MASK_LEN); + + if (sdata->rc_has_vht_mcs_mask[sband->band]) + memcpy(vht_mask, sdata->rc_rateidx_vht_mcs_mask[sband->band], + sizeof(u16) * NL80211_VHT_NSS_MAX); + else + memset(vht_mask, 0xff, sizeof(u16) * NL80211_VHT_NSS_MAX); + + if (sta) { + __le16 sta_vht_cap; + u16 sta_vht_mask[NL80211_VHT_NSS_MAX]; + + /* Filter out rates that the STA does not support */ + *mask &= sta->deflink.supp_rates[sband->band]; + for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++) + mcs_mask[i] &= sta->deflink.ht_cap.mcs.rx_mask[i]; + + sta_vht_cap = sta->deflink.vht_cap.vht_mcs.rx_mcs_map; + ieee80211_get_vht_mask_from_cap(sta_vht_cap, sta_vht_mask); + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) + vht_mask[i] &= sta_vht_mask[i]; + } + + return true; +} + +static void +rate_control_apply_mask_ratetbl(struct sta_info *sta, + struct ieee80211_supported_band *sband, + struct ieee80211_sta_rates *rates) +{ + int i; + u32 mask; + u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN]; + u16 vht_mask[NL80211_VHT_NSS_MAX]; + enum nl80211_chan_width chan_width; + + if (!rate_control_cap_mask(sta->sdata, sband, &sta->sta, &mask, + mcs_mask, vht_mask)) + return; + + chan_width = sta->sdata->vif.bss_conf.chandef.width; + for (i = 0; i < IEEE80211_TX_RATE_TABLE_SIZE; i++) { + if (rates->rate[i].idx < 0) + break; + + rate_idx_match_mask(&rates->rate[i].idx, &rates->rate[i].flags, + sband, chan_width, mask, mcs_mask, + vht_mask); + } +} + +static void rate_control_apply_mask(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct ieee80211_supported_band *sband, + struct ieee80211_tx_rate *rates, + int max_rates) +{ + enum nl80211_chan_width chan_width; + u8 mcs_mask[IEEE80211_HT_MCS_MASK_LEN]; + u32 mask; + u16 rate_flags, vht_mask[NL80211_VHT_NSS_MAX]; + int i; + + /* + * Try to enforce the rateidx mask the user wanted. skip this if the + * default mask (allow all rates) is used to save some processing for + * the common case. 
+ */ + if (!rate_control_cap_mask(sdata, sband, sta, &mask, mcs_mask, + vht_mask)) + return; + + /* + * Make sure the rate index selected for each TX rate is + * included in the configured mask and change the rate indexes + * if needed. + */ + chan_width = sdata->vif.bss_conf.chandef.width; + for (i = 0; i < max_rates; i++) { + /* Skip invalid rates */ + if (rates[i].idx < 0) + break; + + rate_flags = rates[i].flags; + rate_idx_match_mask(&rates[i].idx, &rate_flags, sband, + chan_width, mask, mcs_mask, vht_mask); + rates[i].flags = rate_flags; + } +} + +void ieee80211_get_tx_rates(struct ieee80211_vif *vif, + struct ieee80211_sta *sta, + struct sk_buff *skb, + struct ieee80211_tx_rate *dest, + int max_rates) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_supported_band *sband; + + rate_control_fill_sta_table(sta, info, dest, max_rates); + + if (!vif) + return; + + sdata = vif_to_sdata(vif); + sband = sdata->local->hw.wiphy->bands[info->band]; + + if (ieee80211_is_tx_data(skb)) + rate_control_apply_mask(sdata, sta, sband, dest, max_rates); + + if (dest[0].idx < 0) + __rate_control_send_low(&sdata->local->hw, sband, sta, info, + sdata->rc_rateidx_mask[info->band]); + + if (sta) + rate_fixup_ratelist(vif, sband, info, dest, max_rates); +} +EXPORT_SYMBOL(ieee80211_get_tx_rates); + +void rate_control_get_rate(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_tx_rate_control *txrc) +{ + struct rate_control_ref *ref = sdata->local->rate_ctrl; + void *priv_sta = NULL; + struct ieee80211_sta *ista = NULL; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb); + int i; + + for (i = 0; i < IEEE80211_TX_MAX_RATES; i++) { + info->control.rates[i].idx = -1; + info->control.rates[i].flags = 0; + info->control.rates[i].count = 0; + } + + if (rate_control_send_low(sta ? &sta->sta : NULL, txrc)) + return; + + if (ieee80211_hw_check(&sdata->local->hw, HAS_RATE_CONTROL)) + return; + + if (sta && test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) { + ista = &sta->sta; + priv_sta = sta->rate_ctrl_priv; + } + + if (ista) { + spin_lock_bh(&sta->rate_ctrl_lock); + ref->ops->get_rate(ref->priv, ista, priv_sta, txrc); + spin_unlock_bh(&sta->rate_ctrl_lock); + } else { + rate_control_send_low(NULL, txrc); + } + + if (ieee80211_hw_check(&sdata->local->hw, SUPPORTS_RC_TABLE)) + return; + + ieee80211_get_tx_rates(&sdata->vif, ista, txrc->skb, + info->control.rates, + ARRAY_SIZE(info->control.rates)); +} + +int rate_control_set_rates(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, + struct ieee80211_sta_rates *rates) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_sta_rates *old; + struct ieee80211_supported_band *sband; + + sband = ieee80211_get_sband(sta->sdata); + if (!sband) + return -EINVAL; + rate_control_apply_mask_ratetbl(sta, sband, rates); + /* + * mac80211 guarantees that this function will not be called + * concurrently, so the following RCU access is safe, even without + * extra locking. This can not be checked easily, so we just set + * the condition to true. 
+ */ + old = rcu_dereference_protected(pubsta->rates, true); + rcu_assign_pointer(pubsta->rates, rates); + if (old) + kfree_rcu(old, rcu_head); + + if (sta->uploaded) + drv_sta_rate_tbl_update(hw_to_local(hw), sta->sdata, pubsta); + + ieee80211_sta_set_expected_throughput(pubsta, sta_get_expected_throughput(sta)); + + return 0; +} +EXPORT_SYMBOL(rate_control_set_rates); + +int ieee80211_init_rate_ctrl_alg(struct ieee80211_local *local, + const char *name) +{ + struct rate_control_ref *ref; + + ASSERT_RTNL(); + + if (local->open_count) + return -EBUSY; + + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) { + if (WARN_ON(!local->ops->set_rts_threshold)) + return -EINVAL; + return 0; + } + + ref = rate_control_alloc(name, local); + if (!ref) { + wiphy_warn(local->hw.wiphy, + "Failed to select rate control algorithm\n"); + return -ENOENT; + } + + WARN_ON(local->rate_ctrl); + local->rate_ctrl = ref; + + wiphy_debug(local->hw.wiphy, "Selected rate control algorithm '%s'\n", + ref->ops->name); + + return 0; +} + +void rate_control_deinitialize(struct ieee80211_local *local) +{ + struct rate_control_ref *ref; + + ref = local->rate_ctrl; + + if (!ref) + return; + + local->rate_ctrl = NULL; + rate_control_free(local, ref); +} diff --git a/net/mac80211/rate.h b/net/mac80211/rate.h new file mode 100644 index 000000000..d6190f10f --- /dev/null +++ b/net/mac80211/rate.h @@ -0,0 +1,112 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. + * Copyright (c) 2006 Jiri Benc + * Copyright (C) 2022 Intel Corporation + */ + +#ifndef IEEE80211_RATE_H +#define IEEE80211_RATE_H + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "sta_info.h" +#include "driver-ops.h" + +struct rate_control_ref { + const struct rate_control_ops *ops; + void *priv; +}; + +void rate_control_get_rate(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_tx_rate_control *txrc); + +void rate_control_tx_status(struct ieee80211_local *local, + struct ieee80211_tx_status *st); + +void rate_control_rate_init(struct sta_info *sta); +void rate_control_rate_update(struct ieee80211_local *local, + struct ieee80211_supported_band *sband, + struct sta_info *sta, + unsigned int link_id, + u32 changed); + +static inline void *rate_control_alloc_sta(struct rate_control_ref *ref, + struct sta_info *sta, gfp_t gfp) +{ + spin_lock_init(&sta->rate_ctrl_lock); + return ref->ops->alloc_sta(ref->priv, &sta->sta, gfp); +} + +static inline void rate_control_free_sta(struct sta_info *sta) +{ + struct rate_control_ref *ref = sta->rate_ctrl; + struct ieee80211_sta *ista = &sta->sta; + void *priv_sta = sta->rate_ctrl_priv; + + ref->ops->free_sta(ref->priv, ista, priv_sta); +} + +static inline void rate_control_add_sta_debugfs(struct sta_info *sta) +{ +#ifdef CONFIG_MAC80211_DEBUGFS + struct rate_control_ref *ref = sta->rate_ctrl; + if (ref && sta->debugfs_dir && ref->ops->add_sta_debugfs) + ref->ops->add_sta_debugfs(ref->priv, sta->rate_ctrl_priv, + sta->debugfs_dir); +#endif +} + +extern const struct file_operations rcname_ops; + +static inline void rate_control_add_debugfs(struct ieee80211_local *local) +{ +#ifdef CONFIG_MAC80211_DEBUGFS + struct dentry *debugfsdir; + + if (!local->rate_ctrl) + return; + + if (!local->rate_ctrl->ops->add_debugfs) + return; + + debugfsdir = debugfs_create_dir("rc", local->hw.wiphy->debugfsdir); + local->debugfs.rcdir = debugfsdir; + debugfs_create_file("name", 0400, debugfsdir, 
+ local->rate_ctrl, &rcname_ops); + + local->rate_ctrl->ops->add_debugfs(&local->hw, local->rate_ctrl->priv, + debugfsdir); +#endif +} + +void ieee80211_check_rate_mask(struct ieee80211_link_data *link); + +/* Get a reference to the rate control algorithm. If `name' is NULL, get the + * first available algorithm. */ +int ieee80211_init_rate_ctrl_alg(struct ieee80211_local *local, + const char *name); +void rate_control_deinitialize(struct ieee80211_local *local); + + +/* Rate control algorithms */ +#ifdef CONFIG_MAC80211_RC_MINSTREL +int rc80211_minstrel_init(void); +void rc80211_minstrel_exit(void); +#else +static inline int rc80211_minstrel_init(void) +{ + return 0; +} +static inline void rc80211_minstrel_exit(void) +{ +} +#endif + + +#endif /* IEEE80211_RATE_H */ diff --git a/net/mac80211/rc80211_minstrel_ht.c b/net/mac80211/rc80211_minstrel_ht.c new file mode 100644 index 000000000..3d91b98db --- /dev/null +++ b/net/mac80211/rc80211_minstrel_ht.c @@ -0,0 +1,2061 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2010-2013 Felix Fietkau + * Copyright (C) 2019-2022 Intel Corporation + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "rate.h" +#include "sta_info.h" +#include "rc80211_minstrel_ht.h" + +#define AVG_AMPDU_SIZE 16 +#define AVG_PKT_SIZE 1200 + +/* Number of bits for an average sized packet */ +#define MCS_NBITS ((AVG_PKT_SIZE * AVG_AMPDU_SIZE) << 3) + +/* Number of symbols for a packet with (bps) bits per symbol */ +#define MCS_NSYMS(bps) DIV_ROUND_UP(MCS_NBITS, (bps)) + +/* Transmission time (nanoseconds) for a packet containing (syms) symbols */ +#define MCS_SYMBOL_TIME(sgi, syms) \ + (sgi ? \ + ((syms) * 18000 + 4000) / 5 : /* syms * 3.6 us */ \ + ((syms) * 1000) << 2 /* syms * 4 us */ \ + ) + +/* Transmit duration for the raw data part of an average sized packet */ +#define MCS_DURATION(streams, sgi, bps) \ + (MCS_SYMBOL_TIME(sgi, MCS_NSYMS((streams) * (bps))) / AVG_AMPDU_SIZE) + +#define BW_20 0 +#define BW_40 1 +#define BW_80 2 + +/* + * Define group sort order: HT40 -> SGI -> #streams + */ +#define GROUP_IDX(_streams, _sgi, _ht40) \ + MINSTREL_HT_GROUP_0 + \ + MINSTREL_MAX_STREAMS * 2 * _ht40 + \ + MINSTREL_MAX_STREAMS * _sgi + \ + _streams - 1 + +#define _MAX(a, b) (((a)>(b))?(a):(b)) + +#define GROUP_SHIFT(duration) \ + _MAX(0, 16 - __builtin_clz(duration)) + +/* MCS rate information for an MCS group */ +#define __MCS_GROUP(_streams, _sgi, _ht40, _s) \ + [GROUP_IDX(_streams, _sgi, _ht40)] = { \ + .streams = _streams, \ + .shift = _s, \ + .bw = _ht40, \ + .flags = \ + IEEE80211_TX_RC_MCS | \ + (_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) | \ + (_ht40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0), \ + .duration = { \ + MCS_DURATION(_streams, _sgi, _ht40 ? 54 : 26) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 108 : 52) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 162 : 78) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 216 : 104) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 324 : 156) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 432 : 208) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 486 : 234) >> _s, \ + MCS_DURATION(_streams, _sgi, _ht40 ? 540 : 260) >> _s \ + } \ +} + +#define MCS_GROUP_SHIFT(_streams, _sgi, _ht40) \ + GROUP_SHIFT(MCS_DURATION(_streams, _sgi, _ht40 ? 
54 : 26)) + +#define MCS_GROUP(_streams, _sgi, _ht40) \ + __MCS_GROUP(_streams, _sgi, _ht40, \ + MCS_GROUP_SHIFT(_streams, _sgi, _ht40)) + +#define VHT_GROUP_IDX(_streams, _sgi, _bw) \ + (MINSTREL_VHT_GROUP_0 + \ + MINSTREL_MAX_STREAMS * 2 * (_bw) + \ + MINSTREL_MAX_STREAMS * (_sgi) + \ + (_streams) - 1) + +#define BW2VBPS(_bw, r3, r2, r1) \ + (_bw == BW_80 ? r3 : _bw == BW_40 ? r2 : r1) + +#define __VHT_GROUP(_streams, _sgi, _bw, _s) \ + [VHT_GROUP_IDX(_streams, _sgi, _bw)] = { \ + .streams = _streams, \ + .shift = _s, \ + .bw = _bw, \ + .flags = \ + IEEE80211_TX_RC_VHT_MCS | \ + (_sgi ? IEEE80211_TX_RC_SHORT_GI : 0) | \ + (_bw == BW_80 ? IEEE80211_TX_RC_80_MHZ_WIDTH : \ + _bw == BW_40 ? IEEE80211_TX_RC_40_MHZ_WIDTH : 0), \ + .duration = { \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 117, 54, 26)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 234, 108, 52)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 351, 162, 78)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 468, 216, 104)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 702, 324, 156)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 936, 432, 208)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 1053, 486, 234)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 1170, 540, 260)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 1404, 648, 312)) >> _s, \ + MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 1560, 720, 346)) >> _s \ + } \ +} + +#define VHT_GROUP_SHIFT(_streams, _sgi, _bw) \ + GROUP_SHIFT(MCS_DURATION(_streams, _sgi, \ + BW2VBPS(_bw, 117, 54, 26))) + +#define VHT_GROUP(_streams, _sgi, _bw) \ + __VHT_GROUP(_streams, _sgi, _bw, \ + VHT_GROUP_SHIFT(_streams, _sgi, _bw)) + +#define CCK_DURATION(_bitrate, _short) \ + (1000 * (10 /* SIFS */ + \ + (_short ? 72 + 24 : 144 + 48) + \ + (8 * (AVG_PKT_SIZE + 4) * 10) / (_bitrate))) + +#define CCK_DURATION_LIST(_short, _s) \ + CCK_DURATION(10, _short) >> _s, \ + CCK_DURATION(20, _short) >> _s, \ + CCK_DURATION(55, _short) >> _s, \ + CCK_DURATION(110, _short) >> _s + +#define __CCK_GROUP(_s) \ + [MINSTREL_CCK_GROUP] = { \ + .streams = 1, \ + .flags = 0, \ + .shift = _s, \ + .duration = { \ + CCK_DURATION_LIST(false, _s), \ + CCK_DURATION_LIST(true, _s) \ + } \ + } + +#define CCK_GROUP_SHIFT \ + GROUP_SHIFT(CCK_DURATION(10, false)) + +#define CCK_GROUP __CCK_GROUP(CCK_GROUP_SHIFT) + +#define OFDM_DURATION(_bitrate) \ + (1000 * (16 /* SIFS + signal ext */ + \ + 16 /* T_PREAMBLE */ + \ + 4 /* T_SIGNAL */ + \ + 4 * (((16 + 80 * (AVG_PKT_SIZE + 4) + 6) / \ + ((_bitrate) * 4))))) + +#define OFDM_DURATION_LIST(_s) \ + OFDM_DURATION(60) >> _s, \ + OFDM_DURATION(90) >> _s, \ + OFDM_DURATION(120) >> _s, \ + OFDM_DURATION(180) >> _s, \ + OFDM_DURATION(240) >> _s, \ + OFDM_DURATION(360) >> _s, \ + OFDM_DURATION(480) >> _s, \ + OFDM_DURATION(540) >> _s + +#define __OFDM_GROUP(_s) \ + [MINSTREL_OFDM_GROUP] = { \ + .streams = 1, \ + .flags = 0, \ + .shift = _s, \ + .duration = { \ + OFDM_DURATION_LIST(_s), \ + } \ + } + +#define OFDM_GROUP_SHIFT \ + GROUP_SHIFT(OFDM_DURATION(60)) + +#define OFDM_GROUP __OFDM_GROUP(OFDM_GROUP_SHIFT) + + +static bool minstrel_vht_only = true; +module_param(minstrel_vht_only, bool, 0644); +MODULE_PARM_DESC(minstrel_vht_only, + "Use only VHT rates when VHT is supported by sta."); + +/* + * To enable sufficiently targeted rate sampling, MCS rates are divided into + * groups, based on the number of streams and flags (HT40, SGI) that they + * use. 
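The duration entries above are plain arithmetic: the bits of an average-sized subframe are divided into OFDM symbols at the group's data bits per symbol, multiplied by the symbol time and divided by the assumed A-MPDU length. A standalone recomputation for a few HT rates, with DIV_ROUND_UP spelled out (the bits-per-symbol values are the ones used in the duration tables above):

#include <stdio.h>

#define AVG_AMPDU_SIZE	16
#define AVG_PKT_SIZE	1200

/* bits in an average packet, and symbols needed at (bps) bits/symbol */
#define MCS_NBITS	((AVG_PKT_SIZE * AVG_AMPDU_SIZE) << 3)
#define MCS_NSYMS(bps)	((MCS_NBITS + (bps) - 1) / (bps))

/* symbol time in nanoseconds: 3.6 us with short GI, 4 us otherwise */
#define MCS_SYMBOL_TIME(sgi, syms)			\
	((sgi) ? ((syms) * 18000 + 4000) / 5 :		\
		 ((syms) * 1000) << 2)

#define MCS_DURATION(streams, sgi, bps)			\
	(MCS_SYMBOL_TIME(sgi, MCS_NSYMS((streams) * (bps))) / AVG_AMPDU_SIZE)

int main(void)
{
	/* HT MCS0, 1 stream, 20 MHz, long GI: 26 data bits per symbol */
	printf("MCS0  1x20 LGI: %d ns\n", MCS_DURATION(1, 0, 26));
	/* HT MCS7, 1 stream, 40 MHz, short GI: 540 data bits per symbol */
	printf("MCS7  1x40 SGI: %d ns\n", MCS_DURATION(1, 1, 540));
	/* a second spatial stream simply doubles the bits per symbol */
	printf("MCS15 2x40 SGI: %d ns\n", MCS_DURATION(2, 1, 540));
	return 0;
}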
+ * + * Sortorder has to be fixed for GROUP_IDX macro to be applicable: + * BW -> SGI -> #streams + */ +const struct mcs_group minstrel_mcs_groups[] = { + MCS_GROUP(1, 0, BW_20), + MCS_GROUP(2, 0, BW_20), + MCS_GROUP(3, 0, BW_20), + MCS_GROUP(4, 0, BW_20), + + MCS_GROUP(1, 1, BW_20), + MCS_GROUP(2, 1, BW_20), + MCS_GROUP(3, 1, BW_20), + MCS_GROUP(4, 1, BW_20), + + MCS_GROUP(1, 0, BW_40), + MCS_GROUP(2, 0, BW_40), + MCS_GROUP(3, 0, BW_40), + MCS_GROUP(4, 0, BW_40), + + MCS_GROUP(1, 1, BW_40), + MCS_GROUP(2, 1, BW_40), + MCS_GROUP(3, 1, BW_40), + MCS_GROUP(4, 1, BW_40), + + CCK_GROUP, + OFDM_GROUP, + + VHT_GROUP(1, 0, BW_20), + VHT_GROUP(2, 0, BW_20), + VHT_GROUP(3, 0, BW_20), + VHT_GROUP(4, 0, BW_20), + + VHT_GROUP(1, 1, BW_20), + VHT_GROUP(2, 1, BW_20), + VHT_GROUP(3, 1, BW_20), + VHT_GROUP(4, 1, BW_20), + + VHT_GROUP(1, 0, BW_40), + VHT_GROUP(2, 0, BW_40), + VHT_GROUP(3, 0, BW_40), + VHT_GROUP(4, 0, BW_40), + + VHT_GROUP(1, 1, BW_40), + VHT_GROUP(2, 1, BW_40), + VHT_GROUP(3, 1, BW_40), + VHT_GROUP(4, 1, BW_40), + + VHT_GROUP(1, 0, BW_80), + VHT_GROUP(2, 0, BW_80), + VHT_GROUP(3, 0, BW_80), + VHT_GROUP(4, 0, BW_80), + + VHT_GROUP(1, 1, BW_80), + VHT_GROUP(2, 1, BW_80), + VHT_GROUP(3, 1, BW_80), + VHT_GROUP(4, 1, BW_80), +}; + +const s16 minstrel_cck_bitrates[4] = { 10, 20, 55, 110 }; +const s16 minstrel_ofdm_bitrates[8] = { 60, 90, 120, 180, 240, 360, 480, 540 }; +static u8 sample_table[SAMPLE_COLUMNS][MCS_GROUP_RATES] __read_mostly; +static const u8 minstrel_sample_seq[] = { + MINSTREL_SAMPLE_TYPE_INC, + MINSTREL_SAMPLE_TYPE_JUMP, + MINSTREL_SAMPLE_TYPE_INC, + MINSTREL_SAMPLE_TYPE_JUMP, + MINSTREL_SAMPLE_TYPE_INC, + MINSTREL_SAMPLE_TYPE_SLOW, +}; + +static void +minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi); + +/* + * Some VHT MCSes are invalid (when Ndbps / Nes is not an integer) + * e.g for MCS9@20MHzx1Nss: Ndbps=8x52*(5/6) Nes=1 + * + * Returns the valid mcs map for struct minstrel_mcs_group_data.supported + */ +static u16 +minstrel_get_valid_vht_rates(int bw, int nss, __le16 mcs_map) +{ + u16 mask = 0; + + if (bw == BW_20) { + if (nss != 3 && nss != 6) + mask = BIT(9); + } else if (bw == BW_80) { + if (nss == 3 || nss == 7) + mask = BIT(6); + else if (nss == 6) + mask = BIT(9); + } else { + WARN_ON(bw != BW_40); + } + + switch ((le16_to_cpu(mcs_map) >> (2 * (nss - 1))) & 3) { + case IEEE80211_VHT_MCS_SUPPORT_0_7: + mask |= 0x300; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_8: + mask |= 0x200; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_9: + break; + default: + mask = 0x3ff; + } + + return 0x3ff & ~mask; +} + +static bool +minstrel_ht_is_legacy_group(int group) +{ + return group == MINSTREL_CCK_GROUP || + group == MINSTREL_OFDM_GROUP; +} + +/* + * Look up an MCS group index based on mac80211 rate information + */ +static int +minstrel_ht_get_group_idx(struct ieee80211_tx_rate *rate) +{ + return GROUP_IDX((rate->idx / 8) + 1, + !!(rate->flags & IEEE80211_TX_RC_SHORT_GI), + !!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH)); +} + +/* + * Look up an MCS group index based on new cfg80211 rate_info. 
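minstrel_ht_get_group_idx() above maps a mac80211 HT rate (MCS index plus the SGI/HT40 flags) onto an index into minstrel_mcs_groups[]. A worked sketch of that mapping; MINSTREL_MAX_STREAMS and MINSTREL_HT_GROUP_0 are defined in rc80211_minstrel_ht.h, and the values used below (4 and 0) are assumptions matched to the 16 HT groups listed above:

#include <stdio.h>

/* assumed values, see rc80211_minstrel_ht.h */
#define MINSTREL_MAX_STREAMS	4
#define MINSTREL_HT_GROUP_0	0

/* same ordering as the GROUP_IDX() macro: HT40 -> SGI -> #streams */
#define GROUP_IDX(streams, sgi, ht40)			\
	(MINSTREL_HT_GROUP_0 +				\
	 MINSTREL_MAX_STREAMS * 2 * (ht40) +		\
	 MINSTREL_MAX_STREAMS * (sgi) +			\
	 (streams) - 1)

int main(void)
{
	/* HT MCS 13 with short GI on a 40 MHz channel */
	int mcs = 13, sgi = 1, ht40 = 1;
	int streams = mcs / 8 + 1;	/* MCS 8-15 use two streams */

	/* prints "group 13, rate offset 5", i.e. MCS_GROUP(2, 1, BW_40) */
	printf("group %d, rate offset %d\n",
	       GROUP_IDX(streams, sgi, ht40), mcs % 8);
	return 0;
}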
+ */ +static int +minstrel_ht_ri_get_group_idx(struct rate_info *rate) +{ + return GROUP_IDX((rate->mcs / 8) + 1, + !!(rate->flags & RATE_INFO_FLAGS_SHORT_GI), + !!(rate->bw & RATE_INFO_BW_40)); +} + +static int +minstrel_vht_get_group_idx(struct ieee80211_tx_rate *rate) +{ + return VHT_GROUP_IDX(ieee80211_rate_get_vht_nss(rate), + !!(rate->flags & IEEE80211_TX_RC_SHORT_GI), + !!(rate->flags & IEEE80211_TX_RC_40_MHZ_WIDTH) + + 2*!!(rate->flags & IEEE80211_TX_RC_80_MHZ_WIDTH)); +} + +/* + * Look up an MCS group index based on new cfg80211 rate_info. + */ +static int +minstrel_vht_ri_get_group_idx(struct rate_info *rate) +{ + return VHT_GROUP_IDX(rate->nss, + !!(rate->flags & RATE_INFO_FLAGS_SHORT_GI), + !!(rate->bw & RATE_INFO_BW_40) + + 2*!!(rate->bw & RATE_INFO_BW_80)); +} + +static struct minstrel_rate_stats * +minstrel_ht_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_tx_rate *rate) +{ + int group, idx; + + if (rate->flags & IEEE80211_TX_RC_MCS) { + group = minstrel_ht_get_group_idx(rate); + idx = rate->idx % 8; + goto out; + } + + if (rate->flags & IEEE80211_TX_RC_VHT_MCS) { + group = minstrel_vht_get_group_idx(rate); + idx = ieee80211_rate_get_vht_mcs(rate); + goto out; + } + + group = MINSTREL_CCK_GROUP; + for (idx = 0; idx < ARRAY_SIZE(mp->cck_rates); idx++) { + if (!(mi->supported[group] & BIT(idx))) + continue; + + if (rate->idx != mp->cck_rates[idx]) + continue; + + /* short preamble */ + if ((mi->supported[group] & BIT(idx + 4)) && + (rate->flags & IEEE80211_TX_RC_USE_SHORT_PREAMBLE)) + idx += 4; + goto out; + } + + group = MINSTREL_OFDM_GROUP; + for (idx = 0; idx < ARRAY_SIZE(mp->ofdm_rates[0]); idx++) + if (rate->idx == mp->ofdm_rates[mi->band][idx]) + goto out; + + idx = 0; +out: + return &mi->groups[group].rates[idx]; +} + +/* + * Get the minstrel rate statistics for specified STA and rate info. 
+ */ +static struct minstrel_rate_stats * +minstrel_ht_ri_get_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_rate_status *rate_status) +{ + int group, idx; + struct rate_info *rate = &rate_status->rate_idx; + + if (rate->flags & RATE_INFO_FLAGS_MCS) { + group = minstrel_ht_ri_get_group_idx(rate); + idx = rate->mcs % 8; + goto out; + } + + if (rate->flags & RATE_INFO_FLAGS_VHT_MCS) { + group = minstrel_vht_ri_get_group_idx(rate); + idx = rate->mcs; + goto out; + } + + group = MINSTREL_CCK_GROUP; + for (idx = 0; idx < ARRAY_SIZE(mp->cck_rates); idx++) { + if (rate->legacy != minstrel_cck_bitrates[ mp->cck_rates[idx] ]) + continue; + + /* short preamble */ + if ((mi->supported[group] & BIT(idx + 4)) && + mi->use_short_preamble) + idx += 4; + goto out; + } + + group = MINSTREL_OFDM_GROUP; + for (idx = 0; idx < ARRAY_SIZE(mp->ofdm_rates[0]); idx++) + if (rate->legacy == minstrel_ofdm_bitrates[ mp->ofdm_rates[mi->band][idx] ]) + goto out; + + idx = 0; +out: + return &mi->groups[group].rates[idx]; +} + +static inline struct minstrel_rate_stats * +minstrel_get_ratestats(struct minstrel_ht_sta *mi, int index) +{ + return &mi->groups[MI_RATE_GROUP(index)].rates[MI_RATE_IDX(index)]; +} + +static inline int minstrel_get_duration(int index) +{ + const struct mcs_group *group = &minstrel_mcs_groups[MI_RATE_GROUP(index)]; + unsigned int duration = group->duration[MI_RATE_IDX(index)]; + + return duration << group->shift; +} + +static unsigned int +minstrel_ht_avg_ampdu_len(struct minstrel_ht_sta *mi) +{ + int duration; + + if (mi->avg_ampdu_len) + return MINSTREL_TRUNC(mi->avg_ampdu_len); + + if (minstrel_ht_is_legacy_group(MI_RATE_GROUP(mi->max_tp_rate[0]))) + return 1; + + duration = minstrel_get_duration(mi->max_tp_rate[0]); + + if (duration > 400 * 1000) + return 2; + + if (duration > 250 * 1000) + return 4; + + if (duration > 150 * 1000) + return 8; + + return 16; +} + +/* + * Return current throughput based on the average A-MPDU length, taking into + * account the expected number of retransmissions and their expected length + */ +int +minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate, + int prob_avg) +{ + unsigned int nsecs = 0, overhead = mi->overhead; + unsigned int ampdu_len = 1; + + /* do not account throughput if success prob is below 10% */ + if (prob_avg < MINSTREL_FRAC(10, 100)) + return 0; + + if (minstrel_ht_is_legacy_group(group)) + overhead = mi->overhead_legacy; + else + ampdu_len = minstrel_ht_avg_ampdu_len(mi); + + nsecs = 1000 * overhead / ampdu_len; + nsecs += minstrel_mcs_groups[group].duration[rate] << + minstrel_mcs_groups[group].shift; + + /* + * For the throughput calculation, limit the probability value to 90% to + * account for collision related packet error rate fluctuation + * (prob is scaled - see MINSTREL_FRAC above) + */ + if (prob_avg > MINSTREL_FRAC(90, 100)) + prob_avg = MINSTREL_FRAC(90, 100); + + return MINSTREL_TRUNC(100 * ((prob_avg * 1000000) / nsecs)); +} + +/* + * Find & sort topmost throughput rates + * + * If multiple rates provide equal throughput the sorting is based on their + * current success probability. Higher success probability is preferred among + * MCS groups, CCK rates do not provide aggregation and are therefore at last. 
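minstrel_ht_get_tp_avg() above works in fixed point: prob_avg is a success probability scaled by 2^MINSTREL_SCALE, and the returned value is 100 times the probability-weighted number of average-sized frames per millisecond, i.e. roughly the expected number of delivered frames per 100 ms. A standalone recomputation for one rate; MINSTREL_SCALE and the FRAC/TRUNC helpers live in rc80211_minstrel_ht.h, and the value 12 below is an assumption:

#include <stdio.h>

/* fixed-point helpers assumed from rc80211_minstrel_ht.h */
#define MINSTREL_SCALE		12
#define MINSTREL_FRAC(val, div)	(((val) << MINSTREL_SCALE) / (div))
#define MINSTREL_TRUNC(val)	((val) >> MINSTREL_SCALE)

int main(void)
{
	/* illustrative inputs: ~1.48 ms of airtime per average subframe
	 * (HT MCS0 from the tables above), 85% delivery probability,
	 * per-A-MPDU overhead ignored for brevity */
	unsigned long long airtime_ns = 1477000;
	unsigned long long prob = MINSTREL_FRAC(85, 100);

	if (prob > (unsigned long long)MINSTREL_FRAC(90, 100))
		prob = MINSTREL_FRAC(90, 100);	/* 90% cap, as above */

	/* same shape as the return statement above, done in 64 bits;
	 * prints "tp_avg = 57", i.e. ~57 delivered frames per 100 ms */
	printf("tp_avg = %llu\n",
	       MINSTREL_TRUNC(100 * (prob * 1000000 / airtime_ns)));
	return 0;
}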
+ */ +static void +minstrel_ht_sort_best_tp_rates(struct minstrel_ht_sta *mi, u16 index, + u16 *tp_list) +{ + int cur_group, cur_idx, cur_tp_avg, cur_prob; + int tmp_group, tmp_idx, tmp_tp_avg, tmp_prob; + int j = MAX_THR_RATES; + + cur_group = MI_RATE_GROUP(index); + cur_idx = MI_RATE_IDX(index); + cur_prob = mi->groups[cur_group].rates[cur_idx].prob_avg; + cur_tp_avg = minstrel_ht_get_tp_avg(mi, cur_group, cur_idx, cur_prob); + + do { + tmp_group = MI_RATE_GROUP(tp_list[j - 1]); + tmp_idx = MI_RATE_IDX(tp_list[j - 1]); + tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg; + tmp_tp_avg = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, + tmp_prob); + if (cur_tp_avg < tmp_tp_avg || + (cur_tp_avg == tmp_tp_avg && cur_prob <= tmp_prob)) + break; + j--; + } while (j > 0); + + if (j < MAX_THR_RATES - 1) { + memmove(&tp_list[j + 1], &tp_list[j], (sizeof(*tp_list) * + (MAX_THR_RATES - (j + 1)))); + } + if (j < MAX_THR_RATES) + tp_list[j] = index; +} + +/* + * Find and set the topmost probability rate per sta and per group + */ +static void +minstrel_ht_set_best_prob_rate(struct minstrel_ht_sta *mi, u16 *dest, u16 index) +{ + struct minstrel_mcs_group_data *mg; + struct minstrel_rate_stats *mrs; + int tmp_group, tmp_idx, tmp_tp_avg, tmp_prob; + int max_tp_group, max_tp_idx, max_tp_prob; + int cur_tp_avg, cur_group, cur_idx; + int max_gpr_group, max_gpr_idx; + int max_gpr_tp_avg, max_gpr_prob; + + cur_group = MI_RATE_GROUP(index); + cur_idx = MI_RATE_IDX(index); + mg = &mi->groups[cur_group]; + mrs = &mg->rates[cur_idx]; + + tmp_group = MI_RATE_GROUP(*dest); + tmp_idx = MI_RATE_IDX(*dest); + tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg; + tmp_tp_avg = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob); + + /* if max_tp_rate[0] is from MCS_GROUP max_prob_rate get selected from + * MCS_GROUP as well as CCK_GROUP rates do not allow aggregation */ + max_tp_group = MI_RATE_GROUP(mi->max_tp_rate[0]); + max_tp_idx = MI_RATE_IDX(mi->max_tp_rate[0]); + max_tp_prob = mi->groups[max_tp_group].rates[max_tp_idx].prob_avg; + + if (minstrel_ht_is_legacy_group(MI_RATE_GROUP(index)) && + !minstrel_ht_is_legacy_group(max_tp_group)) + return; + + /* skip rates faster than max tp rate with lower prob */ + if (minstrel_get_duration(mi->max_tp_rate[0]) > minstrel_get_duration(index) && + mrs->prob_avg < max_tp_prob) + return; + + max_gpr_group = MI_RATE_GROUP(mg->max_group_prob_rate); + max_gpr_idx = MI_RATE_IDX(mg->max_group_prob_rate); + max_gpr_prob = mi->groups[max_gpr_group].rates[max_gpr_idx].prob_avg; + + if (mrs->prob_avg > MINSTREL_FRAC(75, 100)) { + cur_tp_avg = minstrel_ht_get_tp_avg(mi, cur_group, cur_idx, + mrs->prob_avg); + if (cur_tp_avg > tmp_tp_avg) + *dest = index; + + max_gpr_tp_avg = minstrel_ht_get_tp_avg(mi, max_gpr_group, + max_gpr_idx, + max_gpr_prob); + if (cur_tp_avg > max_gpr_tp_avg) + mg->max_group_prob_rate = index; + } else { + if (mrs->prob_avg > tmp_prob) + *dest = index; + if (mrs->prob_avg > max_gpr_prob) + mg->max_group_prob_rate = index; + } +} + + +/* + * Assign new rate set per sta and use CCK rates only if the fastest + * rate (max_tp_rate[0]) is from CCK group. This prohibits such sorted + * rate sets where MCS and CCK rates are mixed, because CCK rates can + * not use aggregation. 
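minstrel_ht_sort_best_tp_rates() above keeps the small max_tp_rate list ordered by shifting lower-ranked entries down and dropping whatever falls off the end. The same pattern on plain integers; MAX_THR_RATES comes from the header and is taken to be 4 here, and the sample values are invented:

#include <stdio.h>
#include <string.h>

#define MAX_THR_RATES	4	/* assumed, see rc80211_minstrel_ht.h */

/* Insert a throughput value into a descending top-N list, in the same
 * "shift down and insert" style as minstrel_ht_sort_best_tp_rates()
 * (the probability tie-break is ignored for brevity). */
static void insert_best(unsigned int *tp_list, unsigned int tp)
{
	int j = MAX_THR_RATES;

	while (j > 0 && tp > tp_list[j - 1])
		j--;

	if (j < MAX_THR_RATES - 1)
		memmove(&tp_list[j + 1], &tp_list[j],
			sizeof(*tp_list) * (MAX_THR_RATES - (j + 1)));
	if (j < MAX_THR_RATES)
		tp_list[j] = tp;
}

int main(void)
{
	unsigned int best[MAX_THR_RATES] = { 0 };
	unsigned int samples[] = { 120, 310, 90, 250, 400 };
	unsigned int i;

	for (i = 0; i < sizeof(samples) / sizeof(samples[0]); i++)
		insert_best(best, samples[i]);

	/* prints "400 310 250 120" */
	printf("%u %u %u %u\n", best[0], best[1], best[2], best[3]);
	return 0;
}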
+ */ +static void +minstrel_ht_assign_best_tp_rates(struct minstrel_ht_sta *mi, + u16 tmp_mcs_tp_rate[MAX_THR_RATES], + u16 tmp_legacy_tp_rate[MAX_THR_RATES]) +{ + unsigned int tmp_group, tmp_idx, tmp_cck_tp, tmp_mcs_tp, tmp_prob; + int i; + + tmp_group = MI_RATE_GROUP(tmp_legacy_tp_rate[0]); + tmp_idx = MI_RATE_IDX(tmp_legacy_tp_rate[0]); + tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg; + tmp_cck_tp = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob); + + tmp_group = MI_RATE_GROUP(tmp_mcs_tp_rate[0]); + tmp_idx = MI_RATE_IDX(tmp_mcs_tp_rate[0]); + tmp_prob = mi->groups[tmp_group].rates[tmp_idx].prob_avg; + tmp_mcs_tp = minstrel_ht_get_tp_avg(mi, tmp_group, tmp_idx, tmp_prob); + + if (tmp_cck_tp > tmp_mcs_tp) { + for(i = 0; i < MAX_THR_RATES; i++) { + minstrel_ht_sort_best_tp_rates(mi, tmp_legacy_tp_rate[i], + tmp_mcs_tp_rate); + } + } + +} + +/* + * Try to increase robustness of max_prob rate by decrease number of + * streams if possible. + */ +static inline void +minstrel_ht_prob_rate_reduce_streams(struct minstrel_ht_sta *mi) +{ + struct minstrel_mcs_group_data *mg; + int tmp_max_streams, group, tmp_idx, tmp_prob; + int tmp_tp = 0; + + if (!mi->sta->deflink.ht_cap.ht_supported) + return; + + group = MI_RATE_GROUP(mi->max_tp_rate[0]); + tmp_max_streams = minstrel_mcs_groups[group].streams; + for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) { + mg = &mi->groups[group]; + if (!mi->supported[group] || group == MINSTREL_CCK_GROUP) + continue; + + tmp_idx = MI_RATE_IDX(mg->max_group_prob_rate); + tmp_prob = mi->groups[group].rates[tmp_idx].prob_avg; + + if (tmp_tp < minstrel_ht_get_tp_avg(mi, group, tmp_idx, tmp_prob) && + (minstrel_mcs_groups[group].streams < tmp_max_streams)) { + mi->max_prob_rate = mg->max_group_prob_rate; + tmp_tp = minstrel_ht_get_tp_avg(mi, group, + tmp_idx, + tmp_prob); + } + } +} + +static u16 +__minstrel_ht_get_sample_rate(struct minstrel_ht_sta *mi, + enum minstrel_sample_type type) +{ + u16 *rates = mi->sample[type].sample_rates; + u16 cur; + int i; + + for (i = 0; i < MINSTREL_SAMPLE_RATES; i++) { + if (!rates[i]) + continue; + + cur = rates[i]; + rates[i] = 0; + return cur; + } + + return 0; +} + +static inline int +minstrel_ewma(int old, int new, int weight) +{ + int diff, incr; + + diff = new - old; + incr = (EWMA_DIV - weight) * diff / EWMA_DIV; + + return old + incr; +} + +static inline int minstrel_filter_avg_add(u16 *prev_1, u16 *prev_2, s32 in) +{ + s32 out_1 = *prev_1; + s32 out_2 = *prev_2; + s32 val; + + if (!in) + in += 1; + + if (!out_1) { + val = out_1 = in; + goto out; + } + + val = MINSTREL_AVG_COEFF1 * in; + val += MINSTREL_AVG_COEFF2 * out_1; + val += MINSTREL_AVG_COEFF3 * out_2; + val >>= MINSTREL_SCALE; + + if (val > 1 << MINSTREL_SCALE) + val = 1 << MINSTREL_SCALE; + if (val < 0) + val = 1; + +out: + *prev_2 = out_1; + *prev_1 = val; + + return val; +} + +/* +* Recalculate statistics and counters of a given rate +*/ +static void +minstrel_ht_calc_rate_stats(struct minstrel_priv *mp, + struct minstrel_rate_stats *mrs) +{ + unsigned int cur_prob; + + if (unlikely(mrs->attempts > 0)) { + cur_prob = MINSTREL_FRAC(mrs->success, mrs->attempts); + minstrel_filter_avg_add(&mrs->prob_avg, + &mrs->prob_avg_1, cur_prob); + mrs->att_hist += mrs->attempts; + mrs->succ_hist += mrs->success; + } + + mrs->last_success = mrs->success; + mrs->last_attempts = mrs->attempts; + mrs->success = 0; + mrs->attempts = 0; +} + +static bool +minstrel_ht_find_sample_rate(struct minstrel_ht_sta *mi, int type, int idx) +{ + int i; + + for 
(i = 0; i < MINSTREL_SAMPLE_RATES; i++) { + u16 cur = mi->sample[type].sample_rates[i]; + + if (cur == idx) + return true; + + if (!cur) + break; + } + + return false; +} + +static int +minstrel_ht_move_sample_rates(struct minstrel_ht_sta *mi, int type, + u32 fast_rate_dur, u32 slow_rate_dur) +{ + u16 *rates = mi->sample[type].sample_rates; + int i, j; + + for (i = 0, j = 0; i < MINSTREL_SAMPLE_RATES; i++) { + u32 duration; + bool valid = false; + u16 cur; + + cur = rates[i]; + if (!cur) + continue; + + duration = minstrel_get_duration(cur); + switch (type) { + case MINSTREL_SAMPLE_TYPE_SLOW: + valid = duration > fast_rate_dur && + duration < slow_rate_dur; + break; + case MINSTREL_SAMPLE_TYPE_INC: + case MINSTREL_SAMPLE_TYPE_JUMP: + valid = duration < fast_rate_dur; + break; + default: + valid = false; + break; + } + + if (!valid) { + rates[i] = 0; + continue; + } + + if (i == j) + continue; + + rates[j++] = cur; + rates[i] = 0; + } + + return j; +} + +static int +minstrel_ht_group_min_rate_offset(struct minstrel_ht_sta *mi, int group, + u32 max_duration) +{ + u16 supported = mi->supported[group]; + int i; + + for (i = 0; i < MCS_GROUP_RATES && supported; i++, supported >>= 1) { + if (!(supported & BIT(0))) + continue; + + if (minstrel_get_duration(MI_RATE(group, i)) >= max_duration) + continue; + + return i; + } + + return -1; +} + +/* + * Incremental update rates: + * Flip through groups and pick the first group rate that is faster than the + * highest currently selected rate + */ +static u16 +minstrel_ht_next_inc_rate(struct minstrel_ht_sta *mi, u32 fast_rate_dur) +{ + u8 type = MINSTREL_SAMPLE_TYPE_INC; + int i, index = 0; + u8 group; + + group = mi->sample[type].sample_group; + for (i = 0; i < ARRAY_SIZE(minstrel_mcs_groups); i++) { + group = (group + 1) % ARRAY_SIZE(minstrel_mcs_groups); + + index = minstrel_ht_group_min_rate_offset(mi, group, + fast_rate_dur); + if (index < 0) + continue; + + index = MI_RATE(group, index & 0xf); + if (!minstrel_ht_find_sample_rate(mi, type, index)) + goto out; + } + index = 0; + +out: + mi->sample[type].sample_group = group; + + return index; +} + +static int +minstrel_ht_next_group_sample_rate(struct minstrel_ht_sta *mi, int group, + u16 supported, int offset) +{ + struct minstrel_mcs_group_data *mg = &mi->groups[group]; + u16 idx; + int i; + + for (i = 0; i < MCS_GROUP_RATES; i++) { + idx = sample_table[mg->column][mg->index]; + if (++mg->index >= MCS_GROUP_RATES) { + mg->index = 0; + if (++mg->column >= ARRAY_SIZE(sample_table)) + mg->column = 0; + } + + if (idx < offset) + continue; + + if (!(supported & BIT(idx))) + continue; + + return MI_RATE(group, idx); + } + + return -1; +} + +/* + * Jump rates: + * Sample random rates, use those that are faster than the highest + * currently selected rate. 
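minstrel_ht_move_sample_rates() above keeps a sampling candidate only if its airtime still fits the bucket it is stored in: INC/JUMP candidates must be faster than the fastest rate currently in use, SLOW candidates must sit between the fastest and the slowest. A trivial standalone classifier with invented durations:

#include <stdio.h>

enum bucket { BUCKET_FAST, BUCKET_SLOW, BUCKET_DROP };

/* same bounds as the switch in minstrel_ht_move_sample_rates() */
static enum bucket classify(unsigned int duration,
			    unsigned int fast_rate_dur,
			    unsigned int slow_rate_dur)
{
	if (duration < fast_rate_dur)
		return BUCKET_FAST;		/* INC/JUMP candidate */
	if (duration > fast_rate_dur && duration < slow_rate_dur)
		return BUCKET_SLOW;
	return BUCKET_DROP;
}

int main(void)
{
	/* bounds come from the current max_tp/max_prob rates; all values
	 * here are made up, in nanoseconds */
	unsigned int fast = 64000, slow = 1477000;
	unsigned int candidates[] = { 32000, 64000, 500000, 2000000 };
	static const char *names[] = { "fast (INC/JUMP)", "slow", "drop" };
	unsigned int i;

	for (i = 0; i < sizeof(candidates) / sizeof(candidates[0]); i++)
		printf("%7u ns -> %s\n", candidates[i],
		       names[classify(candidates[i], fast, slow)]);
	return 0;
}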
Rates between the fastest and the slowest + * get sorted into the slow sample bucket, but only if it has room + */ +static u16 +minstrel_ht_next_jump_rate(struct minstrel_ht_sta *mi, u32 fast_rate_dur, + u32 slow_rate_dur, int *slow_rate_ofs) +{ + struct minstrel_rate_stats *mrs; + u32 max_duration = slow_rate_dur; + int i, index, offset; + u16 *slow_rates; + u16 supported; + u32 duration; + u8 group; + + if (*slow_rate_ofs >= MINSTREL_SAMPLE_RATES) + max_duration = fast_rate_dur; + + slow_rates = mi->sample[MINSTREL_SAMPLE_TYPE_SLOW].sample_rates; + group = mi->sample[MINSTREL_SAMPLE_TYPE_JUMP].sample_group; + for (i = 0; i < ARRAY_SIZE(minstrel_mcs_groups); i++) { + u8 type; + + group = (group + 1) % ARRAY_SIZE(minstrel_mcs_groups); + + supported = mi->supported[group]; + if (!supported) + continue; + + offset = minstrel_ht_group_min_rate_offset(mi, group, + max_duration); + if (offset < 0) + continue; + + index = minstrel_ht_next_group_sample_rate(mi, group, supported, + offset); + if (index < 0) + continue; + + duration = minstrel_get_duration(index); + if (duration < fast_rate_dur) + type = MINSTREL_SAMPLE_TYPE_JUMP; + else + type = MINSTREL_SAMPLE_TYPE_SLOW; + + if (minstrel_ht_find_sample_rate(mi, type, index)) + continue; + + if (type == MINSTREL_SAMPLE_TYPE_JUMP) + goto found; + + if (*slow_rate_ofs >= MINSTREL_SAMPLE_RATES) + continue; + + if (duration >= slow_rate_dur) + continue; + + /* skip slow rates with high success probability */ + mrs = minstrel_get_ratestats(mi, index); + if (mrs->prob_avg > MINSTREL_FRAC(95, 100)) + continue; + + slow_rates[(*slow_rate_ofs)++] = index; + if (*slow_rate_ofs >= MINSTREL_SAMPLE_RATES) + max_duration = fast_rate_dur; + } + index = 0; + +found: + mi->sample[MINSTREL_SAMPLE_TYPE_JUMP].sample_group = group; + + return index; +} + +static void +minstrel_ht_refill_sample_rates(struct minstrel_ht_sta *mi) +{ + u32 prob_dur = minstrel_get_duration(mi->max_prob_rate); + u32 tp_dur = minstrel_get_duration(mi->max_tp_rate[0]); + u32 tp2_dur = minstrel_get_duration(mi->max_tp_rate[1]); + u32 fast_rate_dur = min(min(tp_dur, tp2_dur), prob_dur); + u32 slow_rate_dur = max(max(tp_dur, tp2_dur), prob_dur); + u16 *rates; + int i, j; + + rates = mi->sample[MINSTREL_SAMPLE_TYPE_INC].sample_rates; + i = minstrel_ht_move_sample_rates(mi, MINSTREL_SAMPLE_TYPE_INC, + fast_rate_dur, slow_rate_dur); + while (i < MINSTREL_SAMPLE_RATES) { + rates[i] = minstrel_ht_next_inc_rate(mi, tp_dur); + if (!rates[i]) + break; + + i++; + } + + rates = mi->sample[MINSTREL_SAMPLE_TYPE_JUMP].sample_rates; + i = minstrel_ht_move_sample_rates(mi, MINSTREL_SAMPLE_TYPE_JUMP, + fast_rate_dur, slow_rate_dur); + j = minstrel_ht_move_sample_rates(mi, MINSTREL_SAMPLE_TYPE_SLOW, + fast_rate_dur, slow_rate_dur); + while (i < MINSTREL_SAMPLE_RATES) { + rates[i] = minstrel_ht_next_jump_rate(mi, fast_rate_dur, + slow_rate_dur, &j); + if (!rates[i]) + break; + + i++; + } + + for (i = 0; i < ARRAY_SIZE(mi->sample); i++) + memcpy(mi->sample[i].cur_sample_rates, mi->sample[i].sample_rates, + sizeof(mi->sample[i].cur_sample_rates)); +} + + +/* + * Update rate statistics and select new primary rates + * + * Rules for rate selection: + * - max_prob_rate must use only one stream, as a tradeoff between delivery + * probability and throughput during strong fluctuations + * - as long as the max prob rate has a probability of more than 75%, pick + * higher throughput rates, even if the probablity is a bit lower + */ +static void +minstrel_ht_update_stats(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) 
+{ + struct minstrel_mcs_group_data *mg; + struct minstrel_rate_stats *mrs; + int group, i, j, cur_prob; + u16 tmp_mcs_tp_rate[MAX_THR_RATES], tmp_group_tp_rate[MAX_THR_RATES]; + u16 tmp_legacy_tp_rate[MAX_THR_RATES], tmp_max_prob_rate; + u16 index; + bool ht_supported = mi->sta->deflink.ht_cap.ht_supported; + + if (mi->ampdu_packets > 0) { + if (!ieee80211_hw_check(mp->hw, TX_STATUS_NO_AMPDU_LEN)) + mi->avg_ampdu_len = minstrel_ewma(mi->avg_ampdu_len, + MINSTREL_FRAC(mi->ampdu_len, mi->ampdu_packets), + EWMA_LEVEL); + else + mi->avg_ampdu_len = 0; + mi->ampdu_len = 0; + mi->ampdu_packets = 0; + } + + if (mi->supported[MINSTREL_CCK_GROUP]) + group = MINSTREL_CCK_GROUP; + else if (mi->supported[MINSTREL_OFDM_GROUP]) + group = MINSTREL_OFDM_GROUP; + else + group = 0; + + index = MI_RATE(group, 0); + for (j = 0; j < ARRAY_SIZE(tmp_legacy_tp_rate); j++) + tmp_legacy_tp_rate[j] = index; + + if (mi->supported[MINSTREL_VHT_GROUP_0]) + group = MINSTREL_VHT_GROUP_0; + else if (ht_supported) + group = MINSTREL_HT_GROUP_0; + else if (mi->supported[MINSTREL_CCK_GROUP]) + group = MINSTREL_CCK_GROUP; + else + group = MINSTREL_OFDM_GROUP; + + index = MI_RATE(group, 0); + tmp_max_prob_rate = index; + for (j = 0; j < ARRAY_SIZE(tmp_mcs_tp_rate); j++) + tmp_mcs_tp_rate[j] = index; + + /* Find best rate sets within all MCS groups*/ + for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) { + u16 *tp_rate = tmp_mcs_tp_rate; + u16 last_prob = 0; + + mg = &mi->groups[group]; + if (!mi->supported[group]) + continue; + + /* (re)Initialize group rate indexes */ + for(j = 0; j < MAX_THR_RATES; j++) + tmp_group_tp_rate[j] = MI_RATE(group, 0); + + if (group == MINSTREL_CCK_GROUP && ht_supported) + tp_rate = tmp_legacy_tp_rate; + + for (i = MCS_GROUP_RATES - 1; i >= 0; i--) { + if (!(mi->supported[group] & BIT(i))) + continue; + + index = MI_RATE(group, i); + + mrs = &mg->rates[i]; + mrs->retry_updated = false; + minstrel_ht_calc_rate_stats(mp, mrs); + + if (mrs->att_hist) + last_prob = max(last_prob, mrs->prob_avg); + else + mrs->prob_avg = max(last_prob, mrs->prob_avg); + cur_prob = mrs->prob_avg; + + if (minstrel_ht_get_tp_avg(mi, group, i, cur_prob) == 0) + continue; + + /* Find max throughput rate set */ + minstrel_ht_sort_best_tp_rates(mi, index, tp_rate); + + /* Find max throughput rate set within a group */ + minstrel_ht_sort_best_tp_rates(mi, index, + tmp_group_tp_rate); + } + + memcpy(mg->max_group_tp_rate, tmp_group_tp_rate, + sizeof(mg->max_group_tp_rate)); + } + + /* Assign new rate set per sta */ + minstrel_ht_assign_best_tp_rates(mi, tmp_mcs_tp_rate, + tmp_legacy_tp_rate); + memcpy(mi->max_tp_rate, tmp_mcs_tp_rate, sizeof(mi->max_tp_rate)); + + for (group = 0; group < ARRAY_SIZE(minstrel_mcs_groups); group++) { + if (!mi->supported[group]) + continue; + + mg = &mi->groups[group]; + mg->max_group_prob_rate = MI_RATE(group, 0); + + for (i = 0; i < MCS_GROUP_RATES; i++) { + if (!(mi->supported[group] & BIT(i))) + continue; + + index = MI_RATE(group, i); + + /* Find max probability rate per group and global */ + minstrel_ht_set_best_prob_rate(mi, &tmp_max_prob_rate, + index); + } + } + + mi->max_prob_rate = tmp_max_prob_rate; + + /* Try to increase robustness of max_prob_rate*/ + minstrel_ht_prob_rate_reduce_streams(mi); + minstrel_ht_refill_sample_rates(mi); + +#ifdef CONFIG_MAC80211_DEBUGFS + /* use fixed index if set */ + if (mp->fixed_rate_idx != -1) { + for (i = 0; i < 4; i++) + mi->max_tp_rate[i] = mp->fixed_rate_idx; + mi->max_prob_rate = mp->fixed_rate_idx; + } +#endif + + /* Reset update 
timer */ + mi->last_stats_update = jiffies; + mi->sample_time = jiffies; +} + +static bool +minstrel_ht_txstat_valid(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_tx_rate *rate) +{ + int i; + + if (rate->idx < 0) + return false; + + if (!rate->count) + return false; + + if (rate->flags & IEEE80211_TX_RC_MCS || + rate->flags & IEEE80211_TX_RC_VHT_MCS) + return true; + + for (i = 0; i < ARRAY_SIZE(mp->cck_rates); i++) + if (rate->idx == mp->cck_rates[i]) + return true; + + for (i = 0; i < ARRAY_SIZE(mp->ofdm_rates[0]); i++) + if (rate->idx == mp->ofdm_rates[mi->band][i]) + return true; + + return false; +} + +/* + * Check whether rate_status contains valid information. + */ +static bool +minstrel_ht_ri_txstat_valid(struct minstrel_priv *mp, + struct minstrel_ht_sta *mi, + struct ieee80211_rate_status *rate_status) +{ + int i; + + if (!rate_status) + return false; + if (!rate_status->try_count) + return false; + + if (rate_status->rate_idx.flags & RATE_INFO_FLAGS_MCS || + rate_status->rate_idx.flags & RATE_INFO_FLAGS_VHT_MCS) + return true; + + for (i = 0; i < ARRAY_SIZE(mp->cck_rates); i++) { + if (rate_status->rate_idx.legacy == + minstrel_cck_bitrates[ mp->cck_rates[i] ]) + return true; + } + + for (i = 0; i < ARRAY_SIZE(mp->ofdm_rates); i++) { + if (rate_status->rate_idx.legacy == + minstrel_ofdm_bitrates[ mp->ofdm_rates[mi->band][i] ]) + return true; + } + + return false; +} + +static void +minstrel_downgrade_rate(struct minstrel_ht_sta *mi, u16 *idx, bool primary) +{ + int group, orig_group; + + orig_group = group = MI_RATE_GROUP(*idx); + while (group > 0) { + group--; + + if (!mi->supported[group]) + continue; + + if (minstrel_mcs_groups[group].streams > + minstrel_mcs_groups[orig_group].streams) + continue; + + if (primary) + *idx = mi->groups[group].max_group_tp_rate[0]; + else + *idx = mi->groups[group].max_group_tp_rate[1]; + break; + } +} + +static void +minstrel_ht_tx_status(void *priv, struct ieee80211_supported_band *sband, + void *priv_sta, struct ieee80211_tx_status *st) +{ + struct ieee80211_tx_info *info = st->info; + struct minstrel_ht_sta *mi = priv_sta; + struct ieee80211_tx_rate *ar = info->status.rates; + struct minstrel_rate_stats *rate, *rate2; + struct minstrel_priv *mp = priv; + u32 update_interval = mp->update_interval; + bool last, update = false; + int i; + + /* Ignore packet that was sent with noAck flag */ + if (info->flags & IEEE80211_TX_CTL_NO_ACK) + return; + + /* This packet was aggregated but doesn't carry status info */ + if ((info->flags & IEEE80211_TX_CTL_AMPDU) && + !(info->flags & IEEE80211_TX_STAT_AMPDU)) + return; + + if (!(info->flags & IEEE80211_TX_STAT_AMPDU)) { + info->status.ampdu_ack_len = + (info->flags & IEEE80211_TX_STAT_ACK ? 
1 : 0); + info->status.ampdu_len = 1; + } + + /* wraparound */ + if (mi->total_packets >= ~0 - info->status.ampdu_len) { + mi->total_packets = 0; + mi->sample_packets = 0; + } + + mi->total_packets += info->status.ampdu_len; + if (info->flags & IEEE80211_TX_CTL_RATE_CTRL_PROBE) + mi->sample_packets += info->status.ampdu_len; + + mi->ampdu_packets++; + mi->ampdu_len += info->status.ampdu_len; + + if (st->rates && st->n_rates) { + last = !minstrel_ht_ri_txstat_valid(mp, mi, &(st->rates[0])); + for (i = 0; !last; i++) { + last = (i == st->n_rates - 1) || + !minstrel_ht_ri_txstat_valid(mp, mi, + &(st->rates[i + 1])); + + rate = minstrel_ht_ri_get_stats(mp, mi, + &(st->rates[i])); + + if (last) + rate->success += info->status.ampdu_ack_len; + + rate->attempts += st->rates[i].try_count * + info->status.ampdu_len; + } + } else { + last = !minstrel_ht_txstat_valid(mp, mi, &ar[0]); + for (i = 0; !last; i++) { + last = (i == IEEE80211_TX_MAX_RATES - 1) || + !minstrel_ht_txstat_valid(mp, mi, &ar[i + 1]); + + rate = minstrel_ht_get_stats(mp, mi, &ar[i]); + if (last) + rate->success += info->status.ampdu_ack_len; + + rate->attempts += ar[i].count * info->status.ampdu_len; + } + } + + if (mp->hw->max_rates > 1) { + /* + * check for sudden death of spatial multiplexing, + * downgrade to a lower number of streams if necessary. + */ + rate = minstrel_get_ratestats(mi, mi->max_tp_rate[0]); + if (rate->attempts > 30 && + rate->success < rate->attempts / 4) { + minstrel_downgrade_rate(mi, &mi->max_tp_rate[0], true); + update = true; + } + + rate2 = minstrel_get_ratestats(mi, mi->max_tp_rate[1]); + if (rate2->attempts > 30 && + rate2->success < rate2->attempts / 4) { + minstrel_downgrade_rate(mi, &mi->max_tp_rate[1], false); + update = true; + } + } + + if (time_after(jiffies, mi->last_stats_update + update_interval)) { + update = true; + minstrel_ht_update_stats(mp, mi); + } + + if (update) + minstrel_ht_update_rates(mp, mi); +} + +static void +minstrel_calc_retransmit(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + int index) +{ + struct minstrel_rate_stats *mrs; + unsigned int tx_time, tx_time_rtscts, tx_time_data; + unsigned int cw = mp->cw_min; + unsigned int ctime = 0; + unsigned int t_slot = 9; /* FIXME */ + unsigned int ampdu_len = minstrel_ht_avg_ampdu_len(mi); + unsigned int overhead = 0, overhead_rtscts = 0; + + mrs = minstrel_get_ratestats(mi, index); + if (mrs->prob_avg < MINSTREL_FRAC(1, 10)) { + mrs->retry_count = 1; + mrs->retry_count_rtscts = 1; + return; + } + + mrs->retry_count = 2; + mrs->retry_count_rtscts = 2; + mrs->retry_updated = true; + + tx_time_data = minstrel_get_duration(index) * ampdu_len / 1000; + + /* Contention time for first 2 tries */ + ctime = (t_slot * cw) >> 1; + cw = min((cw << 1) | 1, mp->cw_max); + ctime += (t_slot * cw) >> 1; + cw = min((cw << 1) | 1, mp->cw_max); + + if (minstrel_ht_is_legacy_group(MI_RATE_GROUP(index))) { + overhead = mi->overhead_legacy; + overhead_rtscts = mi->overhead_legacy_rtscts; + } else { + overhead = mi->overhead; + overhead_rtscts = mi->overhead_rtscts; + } + + /* Total TX time for data and Contention after first 2 tries */ + tx_time = ctime + 2 * (overhead + tx_time_data); + tx_time_rtscts = ctime + 2 * (overhead_rtscts + tx_time_data); + + /* See how many more tries we can fit inside segment size */ + do { + /* Contention time for this try */ + ctime = (t_slot * cw) >> 1; + cw = min((cw << 1) | 1, mp->cw_max); + + /* Total TX time after this try */ + tx_time += ctime + overhead + tx_time_data; + tx_time_rtscts += ctime + 
overhead_rtscts + tx_time_data; + + if (tx_time_rtscts < mp->segment_size) + mrs->retry_count_rtscts++; + } while ((tx_time < mp->segment_size) && + (++mrs->retry_count < mp->max_retry)); +} + + +static void +minstrel_ht_set_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_sta_rates *ratetbl, int offset, int index) +{ + int group_idx = MI_RATE_GROUP(index); + const struct mcs_group *group = &minstrel_mcs_groups[group_idx]; + struct minstrel_rate_stats *mrs; + u8 idx; + u16 flags = group->flags; + + mrs = minstrel_get_ratestats(mi, index); + if (!mrs->retry_updated) + minstrel_calc_retransmit(mp, mi, index); + + if (mrs->prob_avg < MINSTREL_FRAC(20, 100) || !mrs->retry_count) { + ratetbl->rate[offset].count = 2; + ratetbl->rate[offset].count_rts = 2; + ratetbl->rate[offset].count_cts = 2; + } else { + ratetbl->rate[offset].count = mrs->retry_count; + ratetbl->rate[offset].count_cts = mrs->retry_count; + ratetbl->rate[offset].count_rts = mrs->retry_count_rtscts; + } + + index = MI_RATE_IDX(index); + if (group_idx == MINSTREL_CCK_GROUP) + idx = mp->cck_rates[index % ARRAY_SIZE(mp->cck_rates)]; + else if (group_idx == MINSTREL_OFDM_GROUP) + idx = mp->ofdm_rates[mi->band][index % + ARRAY_SIZE(mp->ofdm_rates[0])]; + else if (flags & IEEE80211_TX_RC_VHT_MCS) + idx = ((group->streams - 1) << 4) | + (index & 0xF); + else + idx = index + (group->streams - 1) * 8; + + /* enable RTS/CTS if needed: + * - if station is in dynamic SMPS (and streams > 1) + * - for fallback rates, to increase chances of getting through + */ + if (offset > 0 || + (mi->sta->deflink.smps_mode == IEEE80211_SMPS_DYNAMIC && + group->streams > 1)) { + ratetbl->rate[offset].count = ratetbl->rate[offset].count_rts; + flags |= IEEE80211_TX_RC_USE_RTS_CTS; + } + + ratetbl->rate[offset].idx = idx; + ratetbl->rate[offset].flags = flags; +} + +static inline int +minstrel_ht_get_prob_avg(struct minstrel_ht_sta *mi, int rate) +{ + int group = MI_RATE_GROUP(rate); + rate = MI_RATE_IDX(rate); + return mi->groups[group].rates[rate].prob_avg; +} + +static int +minstrel_ht_get_max_amsdu_len(struct minstrel_ht_sta *mi) +{ + int group = MI_RATE_GROUP(mi->max_prob_rate); + const struct mcs_group *g = &minstrel_mcs_groups[group]; + int rate = MI_RATE_IDX(mi->max_prob_rate); + unsigned int duration; + + /* Disable A-MSDU if max_prob_rate is bad */ + if (mi->groups[group].rates[rate].prob_avg < MINSTREL_FRAC(50, 100)) + return 1; + + duration = g->duration[rate]; + duration <<= g->shift; + + /* If the rate is slower than single-stream MCS1, make A-MSDU limit small */ + if (duration > MCS_DURATION(1, 0, 52)) + return 500; + + /* + * If the rate is slower than single-stream MCS4, limit A-MSDU to usual + * data packet size + */ + if (duration > MCS_DURATION(1, 0, 104)) + return 1600; + + /* + * If the rate is slower than single-stream MCS7, or if the max throughput + * rate success probability is less than 75%, limit A-MSDU to twice the usual + * data packet size + */ + if (duration > MCS_DURATION(1, 0, 260) || + (minstrel_ht_get_prob_avg(mi, mi->max_tp_rate[0]) < + MINSTREL_FRAC(75, 100))) + return 3200; + + /* + * HT A-MPDU limits maximum MPDU size under BA agreement to 4095 bytes. + * Since aggregation sessions are started/stopped without txq flush, use + * the limit here to avoid the complexity of having to de-aggregate + * packets in the queue. 
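+ * VHT peers negotiate their maximum MPDU size separately, so this cap is + * only applied when the station does not support VHT.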
+ */ + if (!mi->sta->deflink.vht_cap.vht_supported) + return IEEE80211_MAX_MPDU_LEN_HT_BA; + + /* unlimited */ + return 0; +} + +static void +minstrel_ht_update_rates(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) +{ + struct ieee80211_sta_rates *rates; + int i = 0; + int max_rates = min_t(int, mp->hw->max_rates, IEEE80211_TX_RATE_TABLE_SIZE); + + rates = kzalloc(sizeof(*rates), GFP_ATOMIC); + if (!rates) + return; + + /* Start with max_tp_rate[0] */ + minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_tp_rate[0]); + + /* Fill up remaining, keep one entry for max_probe_rate */ + for (; i < (max_rates - 1); i++) + minstrel_ht_set_rate(mp, mi, rates, i, mi->max_tp_rate[i]); + + if (i < max_rates) + minstrel_ht_set_rate(mp, mi, rates, i++, mi->max_prob_rate); + + if (i < IEEE80211_TX_RATE_TABLE_SIZE) + rates->rate[i].idx = -1; + + mi->sta->deflink.agg.max_rc_amsdu_len = minstrel_ht_get_max_amsdu_len(mi); + ieee80211_sta_recalc_aggregates(mi->sta); + rate_control_set_rates(mp->hw, mi->sta, rates); +} + +static u16 +minstrel_ht_get_sample_rate(struct minstrel_priv *mp, struct minstrel_ht_sta *mi) +{ + u8 seq; + + if (mp->hw->max_rates > 1) { + seq = mi->sample_seq; + mi->sample_seq = (seq + 1) % ARRAY_SIZE(minstrel_sample_seq); + seq = minstrel_sample_seq[seq]; + } else { + seq = MINSTREL_SAMPLE_TYPE_INC; + } + + return __minstrel_ht_get_sample_rate(mi, seq); +} + +static void +minstrel_ht_get_rate(void *priv, struct ieee80211_sta *sta, void *priv_sta, + struct ieee80211_tx_rate_control *txrc) +{ + const struct mcs_group *sample_group; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(txrc->skb); + struct ieee80211_tx_rate *rate = &info->status.rates[0]; + struct minstrel_ht_sta *mi = priv_sta; + struct minstrel_priv *mp = priv; + u16 sample_idx; + + info->flags |= mi->tx_flags; + +#ifdef CONFIG_MAC80211_DEBUGFS + if (mp->fixed_rate_idx != -1) + return; +#endif + + /* Don't use EAPOL frames for sampling on non-mrr hw */ + if (mp->hw->max_rates == 1 && + (info->control.flags & IEEE80211_TX_CTRL_PORT_CTRL_PROTO)) + return; + + if (time_is_after_jiffies(mi->sample_time)) + return; + + mi->sample_time = jiffies + MINSTREL_SAMPLE_INTERVAL; + sample_idx = minstrel_ht_get_sample_rate(mp, mi); + if (!sample_idx) + return; + + sample_group = &minstrel_mcs_groups[MI_RATE_GROUP(sample_idx)]; + sample_idx = MI_RATE_IDX(sample_idx); + + if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP] && + (sample_idx >= 4) != txrc->short_preamble) + return; + + info->flags |= IEEE80211_TX_CTL_RATE_CTRL_PROBE; + rate->count = 1; + + if (sample_group == &minstrel_mcs_groups[MINSTREL_CCK_GROUP]) { + int idx = sample_idx % ARRAY_SIZE(mp->cck_rates); + rate->idx = mp->cck_rates[idx]; + } else if (sample_group == &minstrel_mcs_groups[MINSTREL_OFDM_GROUP]) { + int idx = sample_idx % ARRAY_SIZE(mp->ofdm_rates[0]); + rate->idx = mp->ofdm_rates[mi->band][idx]; + } else if (sample_group->flags & IEEE80211_TX_RC_VHT_MCS) { + ieee80211_rate_set_vht(rate, MI_RATE_IDX(sample_idx), + sample_group->streams); + } else { + rate->idx = sample_idx + (sample_group->streams - 1) * 8; + } + + rate->flags = sample_group->flags; +} + +static void +minstrel_ht_update_cck(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_supported_band *sband, + struct ieee80211_sta *sta) +{ + int i; + + if (sband->band != NL80211_BAND_2GHZ) + return; + + if (sta->deflink.ht_cap.ht_supported && + !ieee80211_hw_check(mp->hw, SUPPORTS_HT_CCK_RATES)) + return; + + for (i = 0; i < 4; i++) { + if (mp->cck_rates[i] == 0xff || + 
!rate_supported(sta, sband->band, mp->cck_rates[i])) + continue; + + mi->supported[MINSTREL_CCK_GROUP] |= BIT(i); + if (sband->bitrates[i].flags & IEEE80211_RATE_SHORT_PREAMBLE) + mi->supported[MINSTREL_CCK_GROUP] |= BIT(i + 4); + } +} + +static void +minstrel_ht_update_ofdm(struct minstrel_priv *mp, struct minstrel_ht_sta *mi, + struct ieee80211_supported_band *sband, + struct ieee80211_sta *sta) +{ + const u8 *rates; + int i; + + if (sta->deflink.ht_cap.ht_supported) + return; + + rates = mp->ofdm_rates[sband->band]; + for (i = 0; i < ARRAY_SIZE(mp->ofdm_rates[0]); i++) { + if (rates[i] == 0xff || + !rate_supported(sta, sband->band, rates[i])) + continue; + + mi->supported[MINSTREL_OFDM_GROUP] |= BIT(i); + } +} + +static void +minstrel_ht_update_caps(void *priv, struct ieee80211_supported_band *sband, + struct cfg80211_chan_def *chandef, + struct ieee80211_sta *sta, void *priv_sta) +{ + struct minstrel_priv *mp = priv; + struct minstrel_ht_sta *mi = priv_sta; + struct ieee80211_mcs_info *mcs = &sta->deflink.ht_cap.mcs; + u16 ht_cap = sta->deflink.ht_cap.cap; + struct ieee80211_sta_vht_cap *vht_cap = &sta->deflink.vht_cap; + const struct ieee80211_rate *ctl_rate; + struct sta_info *sta_info; + bool ldpc, erp; + int use_vht; + int n_supported = 0; + int ack_dur; + int stbc; + int i; + + BUILD_BUG_ON(ARRAY_SIZE(minstrel_mcs_groups) != MINSTREL_GROUPS_NB); + + if (vht_cap->vht_supported) + use_vht = vht_cap->vht_mcs.tx_mcs_map != cpu_to_le16(~0); + else + use_vht = 0; + + memset(mi, 0, sizeof(*mi)); + + mi->sta = sta; + mi->band = sband->band; + mi->last_stats_update = jiffies; + + ack_dur = ieee80211_frame_duration(sband->band, 10, 60, 1, 1, 0); + mi->overhead = ieee80211_frame_duration(sband->band, 0, 60, 1, 1, 0); + mi->overhead += ack_dur; + mi->overhead_rtscts = mi->overhead + 2 * ack_dur; + + ctl_rate = &sband->bitrates[rate_lowest_index(sband, sta)]; + erp = ctl_rate->flags & IEEE80211_RATE_ERP_G; + ack_dur = ieee80211_frame_duration(sband->band, 10, + ctl_rate->bitrate, erp, 1, + ieee80211_chandef_get_shift(chandef)); + mi->overhead_legacy = ack_dur; + mi->overhead_legacy_rtscts = mi->overhead_legacy + 2 * ack_dur; + + mi->avg_ampdu_len = MINSTREL_FRAC(1, 1); + + if (!use_vht) { + stbc = (ht_cap & IEEE80211_HT_CAP_RX_STBC) >> + IEEE80211_HT_CAP_RX_STBC_SHIFT; + + ldpc = ht_cap & IEEE80211_HT_CAP_LDPC_CODING; + } else { + stbc = (vht_cap->cap & IEEE80211_VHT_CAP_RXSTBC_MASK) >> + IEEE80211_VHT_CAP_RXSTBC_SHIFT; + + ldpc = vht_cap->cap & IEEE80211_VHT_CAP_RXLDPC; + } + + mi->tx_flags |= stbc << IEEE80211_TX_CTL_STBC_SHIFT; + if (ldpc) + mi->tx_flags |= IEEE80211_TX_CTL_LDPC; + + for (i = 0; i < ARRAY_SIZE(mi->groups); i++) { + u32 gflags = minstrel_mcs_groups[i].flags; + int bw, nss; + + mi->supported[i] = 0; + if (minstrel_ht_is_legacy_group(i)) + continue; + + if (gflags & IEEE80211_TX_RC_SHORT_GI) { + if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) { + if (!(ht_cap & IEEE80211_HT_CAP_SGI_40)) + continue; + } else { + if (!(ht_cap & IEEE80211_HT_CAP_SGI_20)) + continue; + } + } + + if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH && + sta->deflink.bandwidth < IEEE80211_STA_RX_BW_40) + continue; + + nss = minstrel_mcs_groups[i].streams; + + /* Mark MCS > 7 as unsupported if STA is in static SMPS mode */ + if (sta->deflink.smps_mode == IEEE80211_SMPS_STATIC && nss > 1) + continue; + + /* HT rate */ + if (gflags & IEEE80211_TX_RC_MCS) { + if (use_vht && minstrel_vht_only) + continue; + + mi->supported[i] = mcs->rx_mask[nss - 1]; + if (mi->supported[i]) + n_supported++; + continue; + } + + /* VHT 
rate */ + if (!vht_cap->vht_supported || + WARN_ON(!(gflags & IEEE80211_TX_RC_VHT_MCS)) || + WARN_ON(gflags & IEEE80211_TX_RC_160_MHZ_WIDTH)) + continue; + + if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH) { + if (sta->deflink.bandwidth < IEEE80211_STA_RX_BW_80 || + ((gflags & IEEE80211_TX_RC_SHORT_GI) && + !(vht_cap->cap & IEEE80211_VHT_CAP_SHORT_GI_80))) { + continue; + } + } + + if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) + bw = BW_40; + else if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH) + bw = BW_80; + else + bw = BW_20; + + mi->supported[i] = minstrel_get_valid_vht_rates(bw, nss, + vht_cap->vht_mcs.tx_mcs_map); + + if (mi->supported[i]) + n_supported++; + } + + sta_info = container_of(sta, struct sta_info, sta); + mi->use_short_preamble = test_sta_flag(sta_info, WLAN_STA_SHORT_PREAMBLE) && + sta_info->sdata->vif.bss_conf.use_short_preamble; + + minstrel_ht_update_cck(mp, mi, sband, sta); + minstrel_ht_update_ofdm(mp, mi, sband, sta); + + /* create an initial rate table with the lowest supported rates */ + minstrel_ht_update_stats(mp, mi); + minstrel_ht_update_rates(mp, mi); +} + +static void +minstrel_ht_rate_init(void *priv, struct ieee80211_supported_band *sband, + struct cfg80211_chan_def *chandef, + struct ieee80211_sta *sta, void *priv_sta) +{ + minstrel_ht_update_caps(priv, sband, chandef, sta, priv_sta); +} + +static void +minstrel_ht_rate_update(void *priv, struct ieee80211_supported_band *sband, + struct cfg80211_chan_def *chandef, + struct ieee80211_sta *sta, void *priv_sta, + u32 changed) +{ + minstrel_ht_update_caps(priv, sband, chandef, sta, priv_sta); +} + +static void * +minstrel_ht_alloc_sta(void *priv, struct ieee80211_sta *sta, gfp_t gfp) +{ + struct ieee80211_supported_band *sband; + struct minstrel_ht_sta *mi; + struct minstrel_priv *mp = priv; + struct ieee80211_hw *hw = mp->hw; + int max_rates = 0; + int i; + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + sband = hw->wiphy->bands[i]; + if (sband && sband->n_bitrates > max_rates) + max_rates = sband->n_bitrates; + } + + return kzalloc(sizeof(*mi), gfp); +} + +static void +minstrel_ht_free_sta(void *priv, struct ieee80211_sta *sta, void *priv_sta) +{ + kfree(priv_sta); +} + +static void +minstrel_ht_fill_rate_array(u8 *dest, struct ieee80211_supported_band *sband, + const s16 *bitrates, int n_rates, u32 rate_flags) +{ + int i, j; + + for (i = 0; i < sband->n_bitrates; i++) { + struct ieee80211_rate *rate = &sband->bitrates[i]; + + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + + for (j = 0; j < n_rates; j++) { + if (rate->bitrate != bitrates[j]) + continue; + + dest[j] = i; + break; + } + } +} + +static void +minstrel_ht_init_cck_rates(struct minstrel_priv *mp) +{ + static const s16 bitrates[4] = { 10, 20, 55, 110 }; + struct ieee80211_supported_band *sband; + u32 rate_flags = ieee80211_chandef_rate_flags(&mp->hw->conf.chandef); + + memset(mp->cck_rates, 0xff, sizeof(mp->cck_rates)); + sband = mp->hw->wiphy->bands[NL80211_BAND_2GHZ]; + if (!sband) + return; + + BUILD_BUG_ON(ARRAY_SIZE(mp->cck_rates) != ARRAY_SIZE(bitrates)); + minstrel_ht_fill_rate_array(mp->cck_rates, sband, + minstrel_cck_bitrates, + ARRAY_SIZE(minstrel_cck_bitrates), + rate_flags); +} + +static void +minstrel_ht_init_ofdm_rates(struct minstrel_priv *mp, enum nl80211_band band) +{ + static const s16 bitrates[8] = { 60, 90, 120, 180, 240, 360, 480, 540 }; + struct ieee80211_supported_band *sband; + u32 rate_flags = ieee80211_chandef_rate_flags(&mp->hw->conf.chandef); + + memset(mp->ofdm_rates[band], 0xff, 
sizeof(mp->ofdm_rates[band])); + sband = mp->hw->wiphy->bands[band]; + if (!sband) + return; + + BUILD_BUG_ON(ARRAY_SIZE(mp->ofdm_rates[band]) != ARRAY_SIZE(bitrates)); + minstrel_ht_fill_rate_array(mp->ofdm_rates[band], sband, + minstrel_ofdm_bitrates, + ARRAY_SIZE(minstrel_ofdm_bitrates), + rate_flags); +} + +static void * +minstrel_ht_alloc(struct ieee80211_hw *hw) +{ + struct minstrel_priv *mp; + int i; + + mp = kzalloc(sizeof(struct minstrel_priv), GFP_ATOMIC); + if (!mp) + return NULL; + + /* contention window settings + * Just an approximation. Using the per-queue values would complicate + * the calculations and is probably unnecessary */ + mp->cw_min = 15; + mp->cw_max = 1023; + + /* maximum time that the hw is allowed to stay in one MRR segment */ + mp->segment_size = 6000; + + if (hw->max_rate_tries > 0) + mp->max_retry = hw->max_rate_tries; + else + /* safe default, does not necessarily have to match hw properties */ + mp->max_retry = 7; + + if (hw->max_rates >= 4) + mp->has_mrr = true; + + mp->hw = hw; + mp->update_interval = HZ / 20; + + minstrel_ht_init_cck_rates(mp); + for (i = 0; i < ARRAY_SIZE(mp->hw->wiphy->bands); i++) + minstrel_ht_init_ofdm_rates(mp, i); + + return mp; +} + +#ifdef CONFIG_MAC80211_DEBUGFS +static void minstrel_ht_add_debugfs(struct ieee80211_hw *hw, void *priv, + struct dentry *debugfsdir) +{ + struct minstrel_priv *mp = priv; + + mp->fixed_rate_idx = (u32) -1; + debugfs_create_u32("fixed_rate_idx", S_IRUGO | S_IWUGO, debugfsdir, + &mp->fixed_rate_idx); +} +#endif + +static void +minstrel_ht_free(void *priv) +{ + kfree(priv); +} + +static u32 minstrel_ht_get_expected_throughput(void *priv_sta) +{ + struct minstrel_ht_sta *mi = priv_sta; + int i, j, prob, tp_avg; + + i = MI_RATE_GROUP(mi->max_tp_rate[0]); + j = MI_RATE_IDX(mi->max_tp_rate[0]); + prob = mi->groups[i].rates[j].prob_avg; + + /* convert tp_avg from pkt per second in kbps */ + tp_avg = minstrel_ht_get_tp_avg(mi, i, j, prob) * 10; + tp_avg = tp_avg * AVG_PKT_SIZE * 8 / 1024; + + return tp_avg; +} + +static const struct rate_control_ops mac80211_minstrel_ht = { + .name = "minstrel_ht", + .capa = RATE_CTRL_CAPA_AMPDU_TRIGGER, + .tx_status_ext = minstrel_ht_tx_status, + .get_rate = minstrel_ht_get_rate, + .rate_init = minstrel_ht_rate_init, + .rate_update = minstrel_ht_rate_update, + .alloc_sta = minstrel_ht_alloc_sta, + .free_sta = minstrel_ht_free_sta, + .alloc = minstrel_ht_alloc, + .free = minstrel_ht_free, +#ifdef CONFIG_MAC80211_DEBUGFS + .add_debugfs = minstrel_ht_add_debugfs, + .add_sta_debugfs = minstrel_ht_add_sta_debugfs, +#endif + .get_expected_throughput = minstrel_ht_get_expected_throughput, +}; + + +static void __init init_sample_table(void) +{ + int col, i, new_idx; + u8 rnd[MCS_GROUP_RATES]; + + memset(sample_table, 0xff, sizeof(sample_table)); + for (col = 0; col < SAMPLE_COLUMNS; col++) { + get_random_bytes(rnd, sizeof(rnd)); + for (i = 0; i < MCS_GROUP_RATES; i++) { + new_idx = (i + rnd[i]) % MCS_GROUP_RATES; + while (sample_table[col][new_idx] != 0xff) + new_idx = (new_idx + 1) % MCS_GROUP_RATES; + + sample_table[col][new_idx] = i; + } + } +} + +int __init +rc80211_minstrel_init(void) +{ + init_sample_table(); + return ieee80211_rate_control_register(&mac80211_minstrel_ht); +} + +void +rc80211_minstrel_exit(void) +{ + ieee80211_rate_control_unregister(&mac80211_minstrel_ht); +} diff --git a/net/mac80211/rc80211_minstrel_ht.h b/net/mac80211/rc80211_minstrel_ht.h new file mode 100644 index 000000000..1766ff0c7 --- /dev/null +++ b/net/mac80211/rc80211_minstrel_ht.h @@ -0,0 
+1,203 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (C) 2010 Felix Fietkau + */ + +#ifndef __RC_MINSTREL_HT_H +#define __RC_MINSTREL_HT_H + +#include + +/* number of highest throughput rates to consider*/ +#define MAX_THR_RATES 4 +#define SAMPLE_COLUMNS 10 /* number of columns in sample table */ + +/* scaled fraction values */ +#define MINSTREL_SCALE 12 +#define MINSTREL_FRAC(val, div) (((val) << MINSTREL_SCALE) / div) +#define MINSTREL_TRUNC(val) ((val) >> MINSTREL_SCALE) + +#define EWMA_LEVEL 96 /* ewma weighting factor [/EWMA_DIV] */ +#define EWMA_DIV 128 + +/* + * Coefficients for moving average with noise filter (period=16), + * scaled by 10 bits + * + * a1 = exp(-pi * sqrt(2) / period) + * coeff2 = 2 * a1 * cos(sqrt(2) * 2 * pi / period) + * coeff3 = -sqr(a1) + * coeff1 = 1 - coeff2 - coeff3 + */ +#define MINSTREL_AVG_COEFF1 (MINSTREL_FRAC(1, 1) - \ + MINSTREL_AVG_COEFF2 - \ + MINSTREL_AVG_COEFF3) +#define MINSTREL_AVG_COEFF2 0x00001499 +#define MINSTREL_AVG_COEFF3 -0x0000092e + +/* + * The number of streams can be changed to 2 to reduce code + * size and memory footprint. + */ +#define MINSTREL_MAX_STREAMS 4 +#define MINSTREL_HT_STREAM_GROUPS 4 /* BW(=2) * SGI(=2) */ +#define MINSTREL_VHT_STREAM_GROUPS 6 /* BW(=3) * SGI(=2) */ + +#define MINSTREL_HT_GROUPS_NB (MINSTREL_MAX_STREAMS * \ + MINSTREL_HT_STREAM_GROUPS) +#define MINSTREL_VHT_GROUPS_NB (MINSTREL_MAX_STREAMS * \ + MINSTREL_VHT_STREAM_GROUPS) +#define MINSTREL_LEGACY_GROUPS_NB 2 +#define MINSTREL_GROUPS_NB (MINSTREL_HT_GROUPS_NB + \ + MINSTREL_VHT_GROUPS_NB + \ + MINSTREL_LEGACY_GROUPS_NB) + +#define MINSTREL_HT_GROUP_0 0 +#define MINSTREL_CCK_GROUP (MINSTREL_HT_GROUP_0 + MINSTREL_HT_GROUPS_NB) +#define MINSTREL_OFDM_GROUP (MINSTREL_CCK_GROUP + 1) +#define MINSTREL_VHT_GROUP_0 (MINSTREL_OFDM_GROUP + 1) + +#define MCS_GROUP_RATES 10 + +#define MI_RATE_IDX_MASK GENMASK(3, 0) +#define MI_RATE_GROUP_MASK GENMASK(15, 4) + +#define MI_RATE(_group, _idx) \ + (FIELD_PREP(MI_RATE_GROUP_MASK, _group) | \ + FIELD_PREP(MI_RATE_IDX_MASK, _idx)) + +#define MI_RATE_IDX(_rate) FIELD_GET(MI_RATE_IDX_MASK, _rate) +#define MI_RATE_GROUP(_rate) FIELD_GET(MI_RATE_GROUP_MASK, _rate) + +#define MINSTREL_SAMPLE_RATES 5 /* rates per sample type */ +#define MINSTREL_SAMPLE_INTERVAL (HZ / 50) + +struct minstrel_priv { + struct ieee80211_hw *hw; + bool has_mrr; + unsigned int cw_min; + unsigned int cw_max; + unsigned int max_retry; + unsigned int segment_size; + unsigned int update_interval; + + u8 cck_rates[4]; + u8 ofdm_rates[NUM_NL80211_BANDS][8]; + +#ifdef CONFIG_MAC80211_DEBUGFS + /* + * enable fixed rate processing per RC + * - write static index to debugfs:ieee80211/phyX/rc/fixed_rate_idx + * - write -1 to enable RC processing again + * - setting will be applied on next update + */ + u32 fixed_rate_idx; +#endif +}; + + +struct mcs_group { + u16 flags; + u8 streams; + u8 shift; + u8 bw; + u16 duration[MCS_GROUP_RATES]; +}; + +extern const s16 minstrel_cck_bitrates[4]; +extern const s16 minstrel_ofdm_bitrates[8]; +extern const struct mcs_group minstrel_mcs_groups[]; + +struct minstrel_rate_stats { + /* current / last sampling period attempts/success counters */ + u16 attempts, last_attempts; + u16 success, last_success; + + /* total attempts/success counters */ + u32 att_hist, succ_hist; + + /* prob_avg - moving average of prob */ + u16 prob_avg; + u16 prob_avg_1; + + /* maximum retry counts */ + u8 retry_count; + u8 retry_count_rtscts; + + bool retry_updated; +}; + +enum minstrel_sample_type { + MINSTREL_SAMPLE_TYPE_INC, + 
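/* INC probes the next rate above the current best throughput rate */ + 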
MINSTREL_SAMPLE_TYPE_JUMP, + MINSTREL_SAMPLE_TYPE_SLOW, + __MINSTREL_SAMPLE_TYPE_MAX +}; + +struct minstrel_mcs_group_data { + u8 index; + u8 column; + + /* sorted rate set within a MCS group*/ + u16 max_group_tp_rate[MAX_THR_RATES]; + u16 max_group_prob_rate; + + /* MCS rate statistics */ + struct minstrel_rate_stats rates[MCS_GROUP_RATES]; +}; + +struct minstrel_sample_category { + u8 sample_group; + u16 sample_rates[MINSTREL_SAMPLE_RATES]; + u16 cur_sample_rates[MINSTREL_SAMPLE_RATES]; +}; + +struct minstrel_ht_sta { + struct ieee80211_sta *sta; + + /* ampdu length (average, per sampling interval) */ + unsigned int ampdu_len; + unsigned int ampdu_packets; + + /* ampdu length (EWMA) */ + unsigned int avg_ampdu_len; + + /* overall sorted rate set */ + u16 max_tp_rate[MAX_THR_RATES]; + u16 max_prob_rate; + + /* time of last status update */ + unsigned long last_stats_update; + + /* overhead time in usec for each frame */ + unsigned int overhead; + unsigned int overhead_rtscts; + unsigned int overhead_legacy; + unsigned int overhead_legacy_rtscts; + + unsigned int total_packets; + unsigned int sample_packets; + + /* tx flags to add for frames for this sta */ + u32 tx_flags; + bool use_short_preamble; + u8 band; + + u8 sample_seq; + u16 sample_rate; + + unsigned long sample_time; + struct minstrel_sample_category sample[__MINSTREL_SAMPLE_TYPE_MAX]; + + /* Bitfield of supported MCS rates of all groups */ + u16 supported[MINSTREL_GROUPS_NB]; + + /* MCS rate group info and statistics */ + struct minstrel_mcs_group_data groups[MINSTREL_GROUPS_NB]; +}; + +void minstrel_ht_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir); +int minstrel_ht_get_tp_avg(struct minstrel_ht_sta *mi, int group, int rate, + int prob_avg); + +#endif diff --git a/net/mac80211/rc80211_minstrel_ht_debugfs.c b/net/mac80211/rc80211_minstrel_ht_debugfs.c new file mode 100644 index 000000000..25b8a67a6 --- /dev/null +++ b/net/mac80211/rc80211_minstrel_ht_debugfs.c @@ -0,0 +1,336 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2010 Felix Fietkau + */ +#include +#include +#include +#include +#include +#include +#include +#include "rc80211_minstrel_ht.h" + +struct minstrel_debugfs_info { + size_t len; + char buf[]; +}; + +static ssize_t +minstrel_stats_read(struct file *file, char __user *buf, size_t len, loff_t *ppos) +{ + struct minstrel_debugfs_info *ms; + + ms = file->private_data; + return simple_read_from_buffer(buf, len, ppos, ms->buf, ms->len); +} + +static int +minstrel_stats_release(struct inode *inode, struct file *file) +{ + kfree(file->private_data); + return 0; +} + +static bool +minstrel_ht_is_sample_rate(struct minstrel_ht_sta *mi, int idx) +{ + int type, i; + + for (type = 0; type < ARRAY_SIZE(mi->sample); type++) + for (i = 0; i < MINSTREL_SAMPLE_RATES; i++) + if (mi->sample[type].cur_sample_rates[i] == idx) + return true; + return false; +} + +static char * +minstrel_ht_stats_dump(struct minstrel_ht_sta *mi, int i, char *p) +{ + const struct mcs_group *mg; + unsigned int j, tp_max, tp_avg, eprob, tx_time; + char htmode = '2'; + char gimode = 'L'; + u32 gflags; + + if (!mi->supported[i]) + return p; + + mg = &minstrel_mcs_groups[i]; + gflags = mg->flags; + + if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) + htmode = '4'; + else if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH) + htmode = '8'; + if (gflags & IEEE80211_TX_RC_SHORT_GI) + gimode = 'S'; + + for (j = 0; j < MCS_GROUP_RATES; j++) { + struct minstrel_rate_stats *mrs = &mi->groups[i].rates[j]; + int idx = MI_RATE(i, j); + unsigned int 
duration; + + if (!(mi->supported[i] & BIT(j))) + continue; + + if (gflags & IEEE80211_TX_RC_MCS) { + p += sprintf(p, "HT%c0 ", htmode); + p += sprintf(p, "%cGI ", gimode); + p += sprintf(p, "%d ", mg->streams); + } else if (gflags & IEEE80211_TX_RC_VHT_MCS) { + p += sprintf(p, "VHT%c0 ", htmode); + p += sprintf(p, "%cGI ", gimode); + p += sprintf(p, "%d ", mg->streams); + } else if (i == MINSTREL_OFDM_GROUP) { + p += sprintf(p, "OFDM "); + p += sprintf(p, "1 "); + } else { + p += sprintf(p, "CCK "); + p += sprintf(p, "%cP ", j < 4 ? 'L' : 'S'); + p += sprintf(p, "1 "); + } + + *(p++) = (idx == mi->max_tp_rate[0]) ? 'A' : ' '; + *(p++) = (idx == mi->max_tp_rate[1]) ? 'B' : ' '; + *(p++) = (idx == mi->max_tp_rate[2]) ? 'C' : ' '; + *(p++) = (idx == mi->max_tp_rate[3]) ? 'D' : ' '; + *(p++) = (idx == mi->max_prob_rate) ? 'P' : ' '; + *(p++) = minstrel_ht_is_sample_rate(mi, idx) ? 'S' : ' '; + + if (gflags & IEEE80211_TX_RC_MCS) { + p += sprintf(p, " MCS%-2u", (mg->streams - 1) * 8 + j); + } else if (gflags & IEEE80211_TX_RC_VHT_MCS) { + p += sprintf(p, " MCS%-1u/%1u", j, mg->streams); + } else { + int r; + + if (i == MINSTREL_OFDM_GROUP) + r = minstrel_ofdm_bitrates[j % 8]; + else + r = minstrel_cck_bitrates[j % 4]; + + p += sprintf(p, " %2u.%1uM", r / 10, r % 10); + } + + p += sprintf(p, " %3u ", idx); + + /* tx_time[rate(i)] in usec */ + duration = mg->duration[j]; + duration <<= mg->shift; + tx_time = DIV_ROUND_CLOSEST(duration, 1000); + p += sprintf(p, "%6u ", tx_time); + + tp_max = minstrel_ht_get_tp_avg(mi, i, j, MINSTREL_FRAC(100, 100)); + tp_avg = minstrel_ht_get_tp_avg(mi, i, j, mrs->prob_avg); + eprob = MINSTREL_TRUNC(mrs->prob_avg * 1000); + + p += sprintf(p, "%4u.%1u %4u.%1u %3u.%1u" + " %3u %3u %-3u " + "%9llu %-9llu\n", + tp_max / 10, tp_max % 10, + tp_avg / 10, tp_avg % 10, + eprob / 10, eprob % 10, + mrs->retry_count, + mrs->last_success, + mrs->last_attempts, + (unsigned long long)mrs->succ_hist, + (unsigned long long)mrs->att_hist); + } + + return p; +} + +static int +minstrel_ht_stats_open(struct inode *inode, struct file *file) +{ + struct minstrel_ht_sta *mi = inode->i_private; + struct minstrel_debugfs_info *ms; + unsigned int i; + char *p; + + ms = kmalloc(32768, GFP_KERNEL); + if (!ms) + return -ENOMEM; + + file->private_data = ms; + p = ms->buf; + + p += sprintf(p, "\n"); + p += sprintf(p, + " best ____________rate__________ ____statistics___ _____last____ ______sum-of________\n"); + p += sprintf(p, + "mode guard # rate [name idx airtime max_tp] [avg(tp) avg(prob)] [retry|suc|att] [#success | #attempts]\n"); + + p = minstrel_ht_stats_dump(mi, MINSTREL_CCK_GROUP, p); + for (i = 0; i < MINSTREL_CCK_GROUP; i++) + p = minstrel_ht_stats_dump(mi, i, p); + for (i++; i < ARRAY_SIZE(mi->groups); i++) + p = minstrel_ht_stats_dump(mi, i, p); + + p += sprintf(p, "\nTotal packet count:: ideal %d " + "lookaround %d\n", + max(0, (int) mi->total_packets - (int) mi->sample_packets), + mi->sample_packets); + if (mi->avg_ampdu_len) + p += sprintf(p, "Average # of aggregated frames per A-MPDU: %d.%d\n", + MINSTREL_TRUNC(mi->avg_ampdu_len), + MINSTREL_TRUNC(mi->avg_ampdu_len * 10) % 10); + ms->len = p - ms->buf; + WARN_ON(ms->len + sizeof(*ms) > 32768); + + return nonseekable_open(inode, file); +} + +static const struct file_operations minstrel_ht_stat_fops = { + .owner = THIS_MODULE, + .open = minstrel_ht_stats_open, + .read = minstrel_stats_read, + .release = minstrel_stats_release, + .llseek = no_llseek, +}; + +static char * +minstrel_ht_stats_csv_dump(struct minstrel_ht_sta *mi, int 
i, char *p) +{ + const struct mcs_group *mg; + unsigned int j, tp_max, tp_avg, eprob, tx_time; + char htmode = '2'; + char gimode = 'L'; + u32 gflags; + + if (!mi->supported[i]) + return p; + + mg = &minstrel_mcs_groups[i]; + gflags = mg->flags; + + if (gflags & IEEE80211_TX_RC_40_MHZ_WIDTH) + htmode = '4'; + else if (gflags & IEEE80211_TX_RC_80_MHZ_WIDTH) + htmode = '8'; + if (gflags & IEEE80211_TX_RC_SHORT_GI) + gimode = 'S'; + + for (j = 0; j < MCS_GROUP_RATES; j++) { + struct minstrel_rate_stats *mrs = &mi->groups[i].rates[j]; + int idx = MI_RATE(i, j); + unsigned int duration; + + if (!(mi->supported[i] & BIT(j))) + continue; + + if (gflags & IEEE80211_TX_RC_MCS) { + p += sprintf(p, "HT%c0,", htmode); + p += sprintf(p, "%cGI,", gimode); + p += sprintf(p, "%d,", mg->streams); + } else if (gflags & IEEE80211_TX_RC_VHT_MCS) { + p += sprintf(p, "VHT%c0,", htmode); + p += sprintf(p, "%cGI,", gimode); + p += sprintf(p, "%d,", mg->streams); + } else if (i == MINSTREL_OFDM_GROUP) { + p += sprintf(p, "OFDM,,1,"); + } else { + p += sprintf(p, "CCK,"); + p += sprintf(p, "%cP,", j < 4 ? 'L' : 'S'); + p += sprintf(p, "1,"); + } + + p += sprintf(p, "%s" ,((idx == mi->max_tp_rate[0]) ? "A" : "")); + p += sprintf(p, "%s" ,((idx == mi->max_tp_rate[1]) ? "B" : "")); + p += sprintf(p, "%s" ,((idx == mi->max_tp_rate[2]) ? "C" : "")); + p += sprintf(p, "%s" ,((idx == mi->max_tp_rate[3]) ? "D" : "")); + p += sprintf(p, "%s" ,((idx == mi->max_prob_rate) ? "P" : "")); + p += sprintf(p, "%s", (minstrel_ht_is_sample_rate(mi, idx) ? "S" : "")); + + if (gflags & IEEE80211_TX_RC_MCS) { + p += sprintf(p, ",MCS%-2u,", (mg->streams - 1) * 8 + j); + } else if (gflags & IEEE80211_TX_RC_VHT_MCS) { + p += sprintf(p, ",MCS%-1u/%1u,", j, mg->streams); + } else { + int r; + + if (i == MINSTREL_OFDM_GROUP) + r = minstrel_ofdm_bitrates[j % 8]; + else + r = minstrel_cck_bitrates[j % 4]; + + p += sprintf(p, ",%2u.%1uM,", r / 10, r % 10); + } + + p += sprintf(p, "%u,", idx); + + duration = mg->duration[j]; + duration <<= mg->shift; + tx_time = DIV_ROUND_CLOSEST(duration, 1000); + p += sprintf(p, "%u,", tx_time); + + tp_max = minstrel_ht_get_tp_avg(mi, i, j, MINSTREL_FRAC(100, 100)); + tp_avg = minstrel_ht_get_tp_avg(mi, i, j, mrs->prob_avg); + eprob = MINSTREL_TRUNC(mrs->prob_avg * 1000); + + p += sprintf(p, "%u.%u,%u.%u,%u.%u,%u,%u," + "%u,%llu,%llu,", + tp_max / 10, tp_max % 10, + tp_avg / 10, tp_avg % 10, + eprob / 10, eprob % 10, + mrs->retry_count, + mrs->last_success, + mrs->last_attempts, + (unsigned long long)mrs->succ_hist, + (unsigned long long)mrs->att_hist); + p += sprintf(p, "%d,%d,%d.%d\n", + max(0, (int) mi->total_packets - + (int) mi->sample_packets), + mi->sample_packets, + MINSTREL_TRUNC(mi->avg_ampdu_len), + MINSTREL_TRUNC(mi->avg_ampdu_len * 10) % 10); + } + + return p; +} + +static int +minstrel_ht_stats_csv_open(struct inode *inode, struct file *file) +{ + struct minstrel_ht_sta *mi = inode->i_private; + struct minstrel_debugfs_info *ms; + unsigned int i; + char *p; + + ms = kmalloc(32768, GFP_KERNEL); + if (!ms) + return -ENOMEM; + + file->private_data = ms; + + p = ms->buf; + + p = minstrel_ht_stats_csv_dump(mi, MINSTREL_CCK_GROUP, p); + for (i = 0; i < MINSTREL_CCK_GROUP; i++) + p = minstrel_ht_stats_csv_dump(mi, i, p); + for (i++; i < ARRAY_SIZE(mi->groups); i++) + p = minstrel_ht_stats_csv_dump(mi, i, p); + + ms->len = p - ms->buf; + WARN_ON(ms->len + sizeof(*ms) > 32768); + + return nonseekable_open(inode, file); +} + +static const struct file_operations minstrel_ht_stat_csv_fops = { + .owner = 
THIS_MODULE, + .open = minstrel_ht_stats_csv_open, + .read = minstrel_stats_read, + .release = minstrel_stats_release, + .llseek = no_llseek, +}; + +void +minstrel_ht_add_sta_debugfs(void *priv, void *priv_sta, struct dentry *dir) +{ + debugfs_create_file("rc_stats", 0444, dir, priv_sta, + &minstrel_ht_stat_fops); + debugfs_create_file("rc_stats_csv", 0444, dir, priv_sta, + &minstrel_ht_stat_csv_fops); +} diff --git a/net/mac80211/rx.c b/net/mac80211/rx.c new file mode 100644 index 000000000..c4c80037d --- /dev/null +++ b/net/mac80211/rx.c @@ -0,0 +1,5309 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007-2010 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright(c) 2015 - 2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "led.h" +#include "mesh.h" +#include "wep.h" +#include "wpa.h" +#include "tkip.h" +#include "wme.h" +#include "rate.h" + +/* + * monitor mode reception + * + * This function cleans up the SKB, i.e. it removes all the stuff + * only useful for monitoring. + */ +static struct sk_buff *ieee80211_clean_skb(struct sk_buff *skb, + unsigned int present_fcs_len, + unsigned int rtap_space) +{ + struct ieee80211_hdr *hdr; + unsigned int hdrlen; + __le16 fc; + + if (present_fcs_len) + __pskb_trim(skb, skb->len - present_fcs_len); + pskb_pull(skb, rtap_space); + + hdr = (void *)skb->data; + fc = hdr->frame_control; + + /* + * Remove the HT-Control field (if present) on management + * frames after we've sent the frame to monitoring. We + * (currently) don't need it, and don't properly parse + * frames with it present, due to the assumption of a + * fixed management header length. 
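+ * Presence of the field is indicated by the order bit in the frame + * control, which is cleared below along with the field itself.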
+ */ + if (likely(!ieee80211_is_mgmt(fc) || !ieee80211_has_order(fc))) + return skb; + + hdrlen = ieee80211_hdrlen(fc); + hdr->frame_control &= ~cpu_to_le16(IEEE80211_FCTL_ORDER); + + if (!pskb_may_pull(skb, hdrlen)) { + dev_kfree_skb(skb); + return NULL; + } + + memmove(skb->data + IEEE80211_HT_CTL_LEN, skb->data, + hdrlen - IEEE80211_HT_CTL_LEN); + pskb_pull(skb, IEEE80211_HT_CTL_LEN); + + return skb; +} + +static inline bool should_drop_frame(struct sk_buff *skb, int present_fcs_len, + unsigned int rtap_space) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr; + + hdr = (void *)(skb->data + rtap_space); + + if (status->flag & (RX_FLAG_FAILED_FCS_CRC | + RX_FLAG_FAILED_PLCP_CRC | + RX_FLAG_ONLY_MONITOR | + RX_FLAG_NO_PSDU)) + return true; + + if (unlikely(skb->len < 16 + present_fcs_len + rtap_space)) + return true; + + if (ieee80211_is_ctl(hdr->frame_control) && + !ieee80211_is_pspoll(hdr->frame_control) && + !ieee80211_is_back_req(hdr->frame_control)) + return true; + + return false; +} + +static int +ieee80211_rx_radiotap_hdrlen(struct ieee80211_local *local, + struct ieee80211_rx_status *status, + struct sk_buff *skb) +{ + int len; + + /* always present fields */ + len = sizeof(struct ieee80211_radiotap_header) + 8; + + /* allocate extra bitmaps */ + if (status->chains) + len += 4 * hweight8(status->chains); + /* vendor presence bitmap */ + if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) + len += 4; + + if (ieee80211_have_rx_timestamp(status)) { + len = ALIGN(len, 8); + len += 8; + } + if (ieee80211_hw_check(&local->hw, SIGNAL_DBM)) + len += 1; + + /* antenna field, if we don't have per-chain info */ + if (!status->chains) + len += 1; + + /* padding for RX_FLAGS if necessary */ + len = ALIGN(len, 2); + + if (status->encoding == RX_ENC_HT) /* HT info */ + len += 3; + + if (status->flag & RX_FLAG_AMPDU_DETAILS) { + len = ALIGN(len, 4); + len += 8; + } + + if (status->encoding == RX_ENC_VHT) { + len = ALIGN(len, 2); + len += 12; + } + + if (local->hw.radiotap_timestamp.units_pos >= 0) { + len = ALIGN(len, 8); + len += 12; + } + + if (status->encoding == RX_ENC_HE && + status->flag & RX_FLAG_RADIOTAP_HE) { + len = ALIGN(len, 2); + len += 12; + BUILD_BUG_ON(sizeof(struct ieee80211_radiotap_he) != 12); + } + + if (status->encoding == RX_ENC_HE && + status->flag & RX_FLAG_RADIOTAP_HE_MU) { + len = ALIGN(len, 2); + len += 12; + BUILD_BUG_ON(sizeof(struct ieee80211_radiotap_he_mu) != 12); + } + + if (status->flag & RX_FLAG_NO_PSDU) + len += 1; + + if (status->flag & RX_FLAG_RADIOTAP_LSIG) { + len = ALIGN(len, 2); + len += 4; + BUILD_BUG_ON(sizeof(struct ieee80211_radiotap_lsig) != 4); + } + + if (status->chains) { + /* antenna and antenna signal fields */ + len += 2 * hweight8(status->chains); + } + + if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) { + struct ieee80211_vendor_radiotap *rtap; + int vendor_data_offset = 0; + + /* + * The position to look at depends on the existence (or non- + * existence) of other elements, so take that into account... 
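+ * The driver-supplied HE, HE-MU and L-SIG structs sit at the start of + * the skb data, so skip over whichever of them are present.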
+ */ + if (status->flag & RX_FLAG_RADIOTAP_HE) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_he); + if (status->flag & RX_FLAG_RADIOTAP_HE_MU) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_he_mu); + if (status->flag & RX_FLAG_RADIOTAP_LSIG) + vendor_data_offset += + sizeof(struct ieee80211_radiotap_lsig); + + rtap = (void *)&skb->data[vendor_data_offset]; + + /* alignment for fixed 6-byte vendor data header */ + len = ALIGN(len, 2); + /* vendor data header */ + len += 6; + if (WARN_ON(rtap->align == 0)) + rtap->align = 1; + len = ALIGN(len, rtap->align); + len += rtap->len + rtap->pad; + } + + return len; +} + +static void __ieee80211_queue_skb_to_iface(struct ieee80211_sub_if_data *sdata, + int link_id, + struct sta_info *sta, + struct sk_buff *skb) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + + if (link_id >= 0) { + status->link_valid = 1; + status->link_id = link_id; + } else { + status->link_valid = 0; + } + + skb_queue_tail(&sdata->skb_queue, skb); + ieee80211_queue_work(&sdata->local->hw, &sdata->work); + if (sta) + sta->deflink.rx_stats.packets++; +} + +static void ieee80211_queue_skb_to_iface(struct ieee80211_sub_if_data *sdata, + int link_id, + struct sta_info *sta, + struct sk_buff *skb) +{ + skb->protocol = 0; + __ieee80211_queue_skb_to_iface(sdata, link_id, sta, skb); +} + +static void ieee80211_handle_mu_mimo_mon(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + int rtap_space) +{ + struct { + struct ieee80211_hdr_3addr hdr; + u8 category; + u8 action_code; + } __packed __aligned(2) action; + + if (!sdata) + return; + + BUILD_BUG_ON(sizeof(action) != IEEE80211_MIN_ACTION_SIZE + 1); + + if (skb->len < rtap_space + sizeof(action) + + VHT_MUMIMO_GROUPS_DATA_LEN) + return; + + if (!is_valid_ether_addr(sdata->u.mntr.mu_follow_addr)) + return; + + skb_copy_bits(skb, rtap_space, &action, sizeof(action)); + + if (!ieee80211_is_action(action.hdr.frame_control)) + return; + + if (action.category != WLAN_CATEGORY_VHT) + return; + + if (action.action_code != WLAN_VHT_ACTION_GROUPID_MGMT) + return; + + if (!ether_addr_equal(action.hdr.addr1, sdata->u.mntr.mu_follow_addr)) + return; + + skb = skb_copy(skb, GFP_ATOMIC); + if (!skb) + return; + + ieee80211_queue_skb_to_iface(sdata, -1, NULL, skb); +} + +/* + * ieee80211_add_rx_radiotap_header - add radiotap header + * + * add a radiotap header containing all the fields which the hardware provided. 
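+ * The header is pushed in front of the frame data, so the skb needs + * rtap_len bytes of headroom (see ieee80211_rx_radiotap_hdrlen()).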
+ */ +static void +ieee80211_add_rx_radiotap_header(struct ieee80211_local *local, + struct sk_buff *skb, + struct ieee80211_rate *rate, + int rtap_len, bool has_fcs) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_radiotap_header *rthdr; + unsigned char *pos; + __le32 *it_present; + u32 it_present_val; + u16 rx_flags = 0; + u16 channel_flags = 0; + int mpdulen, chain; + unsigned long chains = status->chains; + struct ieee80211_vendor_radiotap rtap = {}; + struct ieee80211_radiotap_he he = {}; + struct ieee80211_radiotap_he_mu he_mu = {}; + struct ieee80211_radiotap_lsig lsig = {}; + + if (status->flag & RX_FLAG_RADIOTAP_HE) { + he = *(struct ieee80211_radiotap_he *)skb->data; + skb_pull(skb, sizeof(he)); + WARN_ON_ONCE(status->encoding != RX_ENC_HE); + } + + if (status->flag & RX_FLAG_RADIOTAP_HE_MU) { + he_mu = *(struct ieee80211_radiotap_he_mu *)skb->data; + skb_pull(skb, sizeof(he_mu)); + } + + if (status->flag & RX_FLAG_RADIOTAP_LSIG) { + lsig = *(struct ieee80211_radiotap_lsig *)skb->data; + skb_pull(skb, sizeof(lsig)); + } + + if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) { + rtap = *(struct ieee80211_vendor_radiotap *)skb->data; + /* rtap.len and rtap.pad are undone immediately */ + skb_pull(skb, sizeof(rtap) + rtap.len + rtap.pad); + } + + mpdulen = skb->len; + if (!(has_fcs && ieee80211_hw_check(&local->hw, RX_INCLUDES_FCS))) + mpdulen += FCS_LEN; + + rthdr = skb_push(skb, rtap_len); + memset(rthdr, 0, rtap_len - rtap.len - rtap.pad); + it_present = &rthdr->it_present; + + /* radiotap header, set always present flags */ + rthdr->it_len = cpu_to_le16(rtap_len); + it_present_val = BIT(IEEE80211_RADIOTAP_FLAGS) | + BIT(IEEE80211_RADIOTAP_CHANNEL) | + BIT(IEEE80211_RADIOTAP_RX_FLAGS); + + if (!status->chains) + it_present_val |= BIT(IEEE80211_RADIOTAP_ANTENNA); + + for_each_set_bit(chain, &chains, IEEE80211_MAX_CHAINS) { + it_present_val |= + BIT(IEEE80211_RADIOTAP_EXT) | + BIT(IEEE80211_RADIOTAP_RADIOTAP_NAMESPACE); + put_unaligned_le32(it_present_val, it_present); + it_present++; + it_present_val = BIT(IEEE80211_RADIOTAP_ANTENNA) | + BIT(IEEE80211_RADIOTAP_DBM_ANTSIGNAL); + } + + if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) { + it_present_val |= BIT(IEEE80211_RADIOTAP_VENDOR_NAMESPACE) | + BIT(IEEE80211_RADIOTAP_EXT); + put_unaligned_le32(it_present_val, it_present); + it_present++; + it_present_val = rtap.present; + } + + put_unaligned_le32(it_present_val, it_present); + + /* This references through an offset into it_optional[] rather + * than via it_present otherwise later uses of pos will cause + * the compiler to think we have walked past the end of the + * struct member. 
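+ * The result is the same: pos points at the byte just past the last + * it_present word written above.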
+ */ + pos = (void *)&rthdr->it_optional[it_present + 1 - rthdr->it_optional]; + + /* the order of the following fields is important */ + + /* IEEE80211_RADIOTAP_TSFT */ + if (ieee80211_have_rx_timestamp(status)) { + /* padding */ + while ((pos - (u8 *)rthdr) & 7) + *pos++ = 0; + put_unaligned_le64( + ieee80211_calculate_rx_timestamp(local, status, + mpdulen, 0), + pos); + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_TSFT)); + pos += 8; + } + + /* IEEE80211_RADIOTAP_FLAGS */ + if (has_fcs && ieee80211_hw_check(&local->hw, RX_INCLUDES_FCS)) + *pos |= IEEE80211_RADIOTAP_F_FCS; + if (status->flag & (RX_FLAG_FAILED_FCS_CRC | RX_FLAG_FAILED_PLCP_CRC)) + *pos |= IEEE80211_RADIOTAP_F_BADFCS; + if (status->enc_flags & RX_ENC_FLAG_SHORTPRE) + *pos |= IEEE80211_RADIOTAP_F_SHORTPRE; + pos++; + + /* IEEE80211_RADIOTAP_RATE */ + if (!rate || status->encoding != RX_ENC_LEGACY) { + /* + * Without rate information don't add it. If we have, + * MCS information is a separate field in radiotap, + * added below. The byte here is needed as padding + * for the channel though, so initialise it to 0. + */ + *pos = 0; + } else { + int shift = 0; + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_RATE)); + if (status->bw == RATE_INFO_BW_10) + shift = 1; + else if (status->bw == RATE_INFO_BW_5) + shift = 2; + *pos = DIV_ROUND_UP(rate->bitrate, 5 * (1 << shift)); + } + pos++; + + /* IEEE80211_RADIOTAP_CHANNEL */ + /* TODO: frequency offset in KHz */ + put_unaligned_le16(status->freq, pos); + pos += 2; + if (status->bw == RATE_INFO_BW_10) + channel_flags |= IEEE80211_CHAN_HALF; + else if (status->bw == RATE_INFO_BW_5) + channel_flags |= IEEE80211_CHAN_QUARTER; + + if (status->band == NL80211_BAND_5GHZ || + status->band == NL80211_BAND_6GHZ) + channel_flags |= IEEE80211_CHAN_OFDM | IEEE80211_CHAN_5GHZ; + else if (status->encoding != RX_ENC_LEGACY) + channel_flags |= IEEE80211_CHAN_DYN | IEEE80211_CHAN_2GHZ; + else if (rate && rate->flags & IEEE80211_RATE_ERP_G) + channel_flags |= IEEE80211_CHAN_OFDM | IEEE80211_CHAN_2GHZ; + else if (rate) + channel_flags |= IEEE80211_CHAN_CCK | IEEE80211_CHAN_2GHZ; + else + channel_flags |= IEEE80211_CHAN_2GHZ; + put_unaligned_le16(channel_flags, pos); + pos += 2; + + /* IEEE80211_RADIOTAP_DBM_ANTSIGNAL */ + if (ieee80211_hw_check(&local->hw, SIGNAL_DBM) && + !(status->flag & RX_FLAG_NO_SIGNAL_VAL)) { + *pos = status->signal; + rthdr->it_present |= + cpu_to_le32(BIT(IEEE80211_RADIOTAP_DBM_ANTSIGNAL)); + pos++; + } + + /* IEEE80211_RADIOTAP_LOCK_QUALITY is missing */ + + if (!status->chains) { + /* IEEE80211_RADIOTAP_ANTENNA */ + *pos = status->antenna; + pos++; + } + + /* IEEE80211_RADIOTAP_DB_ANTNOISE is not used */ + + /* IEEE80211_RADIOTAP_RX_FLAGS */ + /* ensure 2 byte alignment for the 2 byte field as required */ + if ((pos - (u8 *)rthdr) & 1) + *pos++ = 0; + if (status->flag & RX_FLAG_FAILED_PLCP_CRC) + rx_flags |= IEEE80211_RADIOTAP_F_RX_BADPLCP; + put_unaligned_le16(rx_flags, pos); + pos += 2; + + if (status->encoding == RX_ENC_HT) { + unsigned int stbc; + + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_MCS)); + *pos = local->hw.radiotap_mcs_details; + if (status->enc_flags & RX_ENC_FLAG_HT_GF) + *pos |= IEEE80211_RADIOTAP_MCS_HAVE_FMT; + if (status->enc_flags & RX_ENC_FLAG_LDPC) + *pos |= IEEE80211_RADIOTAP_MCS_HAVE_FEC; + pos++; + *pos = 0; + if (status->enc_flags & RX_ENC_FLAG_SHORT_GI) + *pos |= IEEE80211_RADIOTAP_MCS_SGI; + if (status->bw == RATE_INFO_BW_40) + *pos |= IEEE80211_RADIOTAP_MCS_BW_40; + if (status->enc_flags & RX_ENC_FLAG_HT_GF) + 
*pos |= IEEE80211_RADIOTAP_MCS_FMT_GF; + if (status->enc_flags & RX_ENC_FLAG_LDPC) + *pos |= IEEE80211_RADIOTAP_MCS_FEC_LDPC; + stbc = (status->enc_flags & RX_ENC_FLAG_STBC_MASK) >> RX_ENC_FLAG_STBC_SHIFT; + *pos |= stbc << IEEE80211_RADIOTAP_MCS_STBC_SHIFT; + pos++; + *pos++ = status->rate_idx; + } + + if (status->flag & RX_FLAG_AMPDU_DETAILS) { + u16 flags = 0; + + /* ensure 4 byte alignment */ + while ((pos - (u8 *)rthdr) & 3) + pos++; + rthdr->it_present |= + cpu_to_le32(BIT(IEEE80211_RADIOTAP_AMPDU_STATUS)); + put_unaligned_le32(status->ampdu_reference, pos); + pos += 4; + if (status->flag & RX_FLAG_AMPDU_LAST_KNOWN) + flags |= IEEE80211_RADIOTAP_AMPDU_LAST_KNOWN; + if (status->flag & RX_FLAG_AMPDU_IS_LAST) + flags |= IEEE80211_RADIOTAP_AMPDU_IS_LAST; + if (status->flag & RX_FLAG_AMPDU_DELIM_CRC_ERROR) + flags |= IEEE80211_RADIOTAP_AMPDU_DELIM_CRC_ERR; + if (status->flag & RX_FLAG_AMPDU_DELIM_CRC_KNOWN) + flags |= IEEE80211_RADIOTAP_AMPDU_DELIM_CRC_KNOWN; + if (status->flag & RX_FLAG_AMPDU_EOF_BIT_KNOWN) + flags |= IEEE80211_RADIOTAP_AMPDU_EOF_KNOWN; + if (status->flag & RX_FLAG_AMPDU_EOF_BIT) + flags |= IEEE80211_RADIOTAP_AMPDU_EOF; + put_unaligned_le16(flags, pos); + pos += 2; + if (status->flag & RX_FLAG_AMPDU_DELIM_CRC_KNOWN) + *pos++ = status->ampdu_delimiter_crc; + else + *pos++ = 0; + *pos++ = 0; + } + + if (status->encoding == RX_ENC_VHT) { + u16 known = local->hw.radiotap_vht_details; + + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_VHT)); + put_unaligned_le16(known, pos); + pos += 2; + /* flags */ + if (status->enc_flags & RX_ENC_FLAG_SHORT_GI) + *pos |= IEEE80211_RADIOTAP_VHT_FLAG_SGI; + /* in VHT, STBC is binary */ + if (status->enc_flags & RX_ENC_FLAG_STBC_MASK) + *pos |= IEEE80211_RADIOTAP_VHT_FLAG_STBC; + if (status->enc_flags & RX_ENC_FLAG_BF) + *pos |= IEEE80211_RADIOTAP_VHT_FLAG_BEAMFORMED; + pos++; + /* bandwidth */ + switch (status->bw) { + case RATE_INFO_BW_80: + *pos++ = 4; + break; + case RATE_INFO_BW_160: + *pos++ = 11; + break; + case RATE_INFO_BW_40: + *pos++ = 1; + break; + default: + *pos++ = 0; + } + /* MCS/NSS */ + *pos = (status->rate_idx << 4) | status->nss; + pos += 4; + /* coding field */ + if (status->enc_flags & RX_ENC_FLAG_LDPC) + *pos |= IEEE80211_RADIOTAP_CODING_LDPC_USER0; + pos++; + /* group ID */ + pos++; + /* partial_aid */ + pos += 2; + } + + if (local->hw.radiotap_timestamp.units_pos >= 0) { + u16 accuracy = 0; + u8 flags = IEEE80211_RADIOTAP_TIMESTAMP_FLAG_32BIT; + + rthdr->it_present |= + cpu_to_le32(BIT(IEEE80211_RADIOTAP_TIMESTAMP)); + + /* ensure 8 byte alignment */ + while ((pos - (u8 *)rthdr) & 7) + pos++; + + put_unaligned_le64(status->device_timestamp, pos); + pos += sizeof(u64); + + if (local->hw.radiotap_timestamp.accuracy >= 0) { + accuracy = local->hw.radiotap_timestamp.accuracy; + flags |= IEEE80211_RADIOTAP_TIMESTAMP_FLAG_ACCURACY; + } + put_unaligned_le16(accuracy, pos); + pos += sizeof(u16); + + *pos++ = local->hw.radiotap_timestamp.units_pos; + *pos++ = flags; + } + + if (status->encoding == RX_ENC_HE && + status->flag & RX_FLAG_RADIOTAP_HE) { +#define HE_PREP(f, val) le16_encode_bits(val, IEEE80211_RADIOTAP_HE_##f) + + if (status->enc_flags & RX_ENC_FLAG_STBC_MASK) { + he.data6 |= HE_PREP(DATA6_NSTS, + FIELD_GET(RX_ENC_FLAG_STBC_MASK, + status->enc_flags)); + he.data3 |= HE_PREP(DATA3_STBC, 1); + } else { + he.data6 |= HE_PREP(DATA6_NSTS, status->nss); + } + +#define CHECK_GI(s) \ + BUILD_BUG_ON(IEEE80211_RADIOTAP_HE_DATA5_GI_##s != \ + (int)NL80211_RATE_INFO_HE_GI_##s) + + CHECK_GI(0_8); + CHECK_GI(1_6); + 
CHECK_GI(3_2); + + he.data3 |= HE_PREP(DATA3_DATA_MCS, status->rate_idx); + he.data3 |= HE_PREP(DATA3_DATA_DCM, status->he_dcm); + he.data3 |= HE_PREP(DATA3_CODING, + !!(status->enc_flags & RX_ENC_FLAG_LDPC)); + + he.data5 |= HE_PREP(DATA5_GI, status->he_gi); + + switch (status->bw) { + case RATE_INFO_BW_20: + he.data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_20MHZ); + break; + case RATE_INFO_BW_40: + he.data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_40MHZ); + break; + case RATE_INFO_BW_80: + he.data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_80MHZ); + break; + case RATE_INFO_BW_160: + he.data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_160MHZ); + break; + case RATE_INFO_BW_HE_RU: +#define CHECK_RU_ALLOC(s) \ + BUILD_BUG_ON(IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_##s##T != \ + NL80211_RATE_INFO_HE_RU_ALLOC_##s + 4) + + CHECK_RU_ALLOC(26); + CHECK_RU_ALLOC(52); + CHECK_RU_ALLOC(106); + CHECK_RU_ALLOC(242); + CHECK_RU_ALLOC(484); + CHECK_RU_ALLOC(996); + CHECK_RU_ALLOC(2x996); + + he.data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + status->he_ru + 4); + break; + default: + WARN_ONCE(1, "Invalid SU BW %d\n", status->bw); + } + + /* ensure 2 byte alignment */ + while ((pos - (u8 *)rthdr) & 1) + pos++; + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_HE)); + memcpy(pos, &he, sizeof(he)); + pos += sizeof(he); + } + + if (status->encoding == RX_ENC_HE && + status->flag & RX_FLAG_RADIOTAP_HE_MU) { + /* ensure 2 byte alignment */ + while ((pos - (u8 *)rthdr) & 1) + pos++; + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_HE_MU)); + memcpy(pos, &he_mu, sizeof(he_mu)); + pos += sizeof(he_mu); + } + + if (status->flag & RX_FLAG_NO_PSDU) { + rthdr->it_present |= + cpu_to_le32(BIT(IEEE80211_RADIOTAP_ZERO_LEN_PSDU)); + *pos++ = status->zero_length_psdu_type; + } + + if (status->flag & RX_FLAG_RADIOTAP_LSIG) { + /* ensure 2 byte alignment */ + while ((pos - (u8 *)rthdr) & 1) + pos++; + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_LSIG)); + memcpy(pos, &lsig, sizeof(lsig)); + pos += sizeof(lsig); + } + + for_each_set_bit(chain, &chains, IEEE80211_MAX_CHAINS) { + *pos++ = status->chain_signal[chain]; + *pos++ = chain; + } + + if (status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA) { + /* ensure 2 byte alignment for the vendor field as required */ + if ((pos - (u8 *)rthdr) & 1) + *pos++ = 0; + *pos++ = rtap.oui[0]; + *pos++ = rtap.oui[1]; + *pos++ = rtap.oui[2]; + *pos++ = rtap.subns; + put_unaligned_le16(rtap.len, pos); + pos += 2; + /* align the actual payload as requested */ + while ((pos - (u8 *)rthdr) & (rtap.align - 1)) + *pos++ = 0; + /* data (and possible padding) already follows */ + } +} + +static struct sk_buff * +ieee80211_make_monitor_skb(struct ieee80211_local *local, + struct sk_buff **origskb, + struct ieee80211_rate *rate, + int rtap_space, bool use_origskb) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(*origskb); + int rt_hdrlen, needed_headroom; + struct sk_buff *skb; + + /* room for the radiotap header based on driver features */ + rt_hdrlen = ieee80211_rx_radiotap_hdrlen(local, status, *origskb); + needed_headroom = rt_hdrlen - rtap_space; + + if (use_origskb) { + /* only need to expand headroom if necessary */ + skb = *origskb; + *origskb = NULL; + + /* + * This shouldn't trigger often because most devices have an + * RX header they pull before we get here, and that should + * be big enough for 
our radiotap information. We should + * probably export the length to drivers so that we can have + * them allocate enough headroom to start with. + */ + if (skb_headroom(skb) < needed_headroom && + pskb_expand_head(skb, needed_headroom, 0, GFP_ATOMIC)) { + dev_kfree_skb(skb); + return NULL; + } + } else { + /* + * Need to make a copy and possibly remove radiotap header + * and FCS from the original. + */ + skb = skb_copy_expand(*origskb, needed_headroom + NET_SKB_PAD, + 0, GFP_ATOMIC); + + if (!skb) + return NULL; + } + + /* prepend radiotap information */ + ieee80211_add_rx_radiotap_header(local, skb, rate, rt_hdrlen, true); + + skb_reset_mac_header(skb); + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->pkt_type = PACKET_OTHERHOST; + skb->protocol = htons(ETH_P_802_2); + + return skb; +} + +/* + * This function copies a received frame to all monitor interfaces and + * returns a cleaned-up SKB that no longer includes the FCS nor the + * radiotap header the driver might have added. + */ +static struct sk_buff * +ieee80211_rx_monitor(struct ieee80211_local *local, struct sk_buff *origskb, + struct ieee80211_rate *rate) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(origskb); + struct ieee80211_sub_if_data *sdata; + struct sk_buff *monskb = NULL; + int present_fcs_len = 0; + unsigned int rtap_space = 0; + struct ieee80211_sub_if_data *monitor_sdata = + rcu_dereference(local->monitor_sdata); + bool only_monitor = false; + unsigned int min_head_len; + + if (status->flag & RX_FLAG_RADIOTAP_HE) + rtap_space += sizeof(struct ieee80211_radiotap_he); + + if (status->flag & RX_FLAG_RADIOTAP_HE_MU) + rtap_space += sizeof(struct ieee80211_radiotap_he_mu); + + if (status->flag & RX_FLAG_RADIOTAP_LSIG) + rtap_space += sizeof(struct ieee80211_radiotap_lsig); + + if (unlikely(status->flag & RX_FLAG_RADIOTAP_VENDOR_DATA)) { + struct ieee80211_vendor_radiotap *rtap = + (void *)(origskb->data + rtap_space); + + rtap_space += sizeof(*rtap) + rtap->len + rtap->pad; + } + + min_head_len = rtap_space; + + /* + * First, we may need to make a copy of the skb because + * (1) we need to modify it for radiotap (if not present), and + * (2) the other RX handlers will modify the skb we got. + * + * We don't need to, of course, if we aren't going to return + * the SKB because it has a bad FCS/PLCP checksum. 
+ */ + + if (!(status->flag & RX_FLAG_NO_PSDU)) { + if (ieee80211_hw_check(&local->hw, RX_INCLUDES_FCS)) { + if (unlikely(origskb->len <= FCS_LEN + rtap_space)) { + /* driver bug */ + WARN_ON(1); + dev_kfree_skb(origskb); + return NULL; + } + present_fcs_len = FCS_LEN; + } + + /* also consider the hdr->frame_control */ + min_head_len += 2; + } + + /* ensure that the expected data elements are in skb head */ + if (!pskb_may_pull(origskb, min_head_len)) { + dev_kfree_skb(origskb); + return NULL; + } + + only_monitor = should_drop_frame(origskb, present_fcs_len, rtap_space); + + if (!local->monitors || (status->flag & RX_FLAG_SKIP_MONITOR)) { + if (only_monitor) { + dev_kfree_skb(origskb); + return NULL; + } + + return ieee80211_clean_skb(origskb, present_fcs_len, + rtap_space); + } + + ieee80211_handle_mu_mimo_mon(monitor_sdata, origskb, rtap_space); + + list_for_each_entry_rcu(sdata, &local->mon_list, u.mntr.list) { + bool last_monitor = list_is_last(&sdata->u.mntr.list, + &local->mon_list); + + if (!monskb) + monskb = ieee80211_make_monitor_skb(local, &origskb, + rate, rtap_space, + only_monitor && + last_monitor); + + if (monskb) { + struct sk_buff *skb; + + if (last_monitor) { + skb = monskb; + monskb = NULL; + } else { + skb = skb_clone(monskb, GFP_ATOMIC); + } + + if (skb) { + skb->dev = sdata->dev; + dev_sw_netstats_rx_add(skb->dev, skb->len); + netif_receive_skb(skb); + } + } + + if (last_monitor) + break; + } + + /* this happens if last_monitor was erroneously false */ + dev_kfree_skb(monskb); + + /* ditto */ + if (!origskb) + return NULL; + + return ieee80211_clean_skb(origskb, present_fcs_len, rtap_space); +} + +static void ieee80211_parse_qos(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + int tid, seqno_idx, security_idx; + + /* does the frame have a qos control field? */ + if (ieee80211_is_data_qos(hdr->frame_control)) { + u8 *qc = ieee80211_get_qos_ctl(hdr); + /* frame has qos control */ + tid = *qc & IEEE80211_QOS_CTL_TID_MASK; + if (*qc & IEEE80211_QOS_CTL_A_MSDU_PRESENT) + status->rx_flags |= IEEE80211_RX_AMSDU; + + seqno_idx = tid; + security_idx = tid; + } else { + /* + * IEEE 802.11-2007, 7.1.3.4.1 ("Sequence Number field"): + * + * Sequence numbers for management frames, QoS data + * frames with a broadcast/multicast address in the + * Address 1 field, and all non-QoS data frames sent + * by QoS STAs are assigned using an additional single + * modulo-4096 counter, [...] + * + * We also use that counter for non-QoS STAs. + */ + seqno_idx = IEEE80211_NUM_TIDS; + security_idx = 0; + if (ieee80211_is_mgmt(hdr->frame_control)) + security_idx = IEEE80211_NUM_TIDS; + tid = 0; + } + + rx->seqno_idx = seqno_idx; + rx->security_idx = security_idx; + /* Set skb->priority to 1d tag if highest order bit of TID is not set. + * For now, set skb->priority to 0 for other cases. */ + rx->skb->priority = (tid > 7) ? 0 : tid; +} + +/** + * DOC: Packet alignment + * + * Drivers always need to pass packets that are aligned to two-byte boundaries + * to the stack. + * + * Additionally, should, if possible, align the payload data in a way that + * guarantees that the contained IP header is aligned to a four-byte + * boundary. In the case of regular frames, this simply means aligning the + * payload to a four-byte boundary (because either the IP header is directly + * contained, or IV/RFC1042 headers that have a length divisible by four are + * in front of it). 
If the payload data is not properly aligned and the + * architecture doesn't support efficient unaligned operations, mac80211 + * will align the data. + * + * With A-MSDU frames, however, the payload data address must yield two modulo + * four because there are 14-byte 802.3 headers within the A-MSDU frames that + * push the IP header further back to a multiple of four again. Thankfully, the + * specs were sane enough this time around to require padding each A-MSDU + * subframe to a length that is a multiple of four. + * + * Padding like Atheros hardware adds which is between the 802.11 header and + * the payload is not supported, the driver is required to move the 802.11 + * header to be directly in front of the payload in that case. + */ +static void ieee80211_verify_alignment(struct ieee80211_rx_data *rx) +{ +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG + WARN_ON_ONCE((unsigned long)rx->skb->data & 1); +#endif +} + + +/* rx handlers */ + +static int ieee80211_is_unicast_robust_mgmt_frame(struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + + if (is_multicast_ether_addr(hdr->addr1)) + return 0; + + return ieee80211_is_robust_mgmt_frame(skb); +} + + +static int ieee80211_is_multicast_robust_mgmt_frame(struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + + if (!is_multicast_ether_addr(hdr->addr1)) + return 0; + + return ieee80211_is_robust_mgmt_frame(skb); +} + + +/* Get the BIP key index from MMIE; return -1 if this is not a BIP frame */ +static int ieee80211_get_mmie_keyidx(struct sk_buff *skb) +{ + struct ieee80211_mgmt *hdr = (struct ieee80211_mgmt *) skb->data; + struct ieee80211_mmie *mmie; + struct ieee80211_mmie_16 *mmie16; + + if (skb->len < 24 + sizeof(*mmie) || !is_multicast_ether_addr(hdr->da)) + return -1; + + if (!ieee80211_is_robust_mgmt_frame(skb) && + !ieee80211_is_beacon(hdr->frame_control)) + return -1; /* not a robust management frame */ + + mmie = (struct ieee80211_mmie *) + (skb->data + skb->len - sizeof(*mmie)); + if (mmie->element_id == WLAN_EID_MMIE && + mmie->length == sizeof(*mmie) - 2) + return le16_to_cpu(mmie->key_id); + + mmie16 = (struct ieee80211_mmie_16 *) + (skb->data + skb->len - sizeof(*mmie16)); + if (skb->len >= 24 + sizeof(*mmie16) && + mmie16->element_id == WLAN_EID_MMIE && + mmie16->length == sizeof(*mmie16) - 2) + return le16_to_cpu(mmie16->key_id); + + return -1; +} + +static int ieee80211_get_keyid(struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + __le16 fc = hdr->frame_control; + int hdrlen = ieee80211_hdrlen(fc); + u8 keyid; + + /* WEP, TKIP, CCMP and GCMP */ + if (unlikely(skb->len < hdrlen + IEEE80211_WEP_IV_LEN)) + return -EINVAL; + + skb_copy_bits(skb, hdrlen + 3, &keyid, 1); + + keyid >>= 6; + + return keyid; +} + +static ieee80211_rx_result ieee80211_rx_mesh_check(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + char *dev_addr = rx->sdata->vif.addr; + + if (ieee80211_is_data(hdr->frame_control)) { + if (is_multicast_ether_addr(hdr->addr1)) { + if (ieee80211_has_tods(hdr->frame_control) || + !ieee80211_has_fromds(hdr->frame_control)) + return RX_DROP_MONITOR; + if (ether_addr_equal(hdr->addr3, dev_addr)) + return RX_DROP_MONITOR; + } else { + if (!ieee80211_has_a4(hdr->frame_control)) + return RX_DROP_MONITOR; + if (ether_addr_equal(hdr->addr4, dev_addr)) + return RX_DROP_MONITOR; + } + } + + /* If there is not an established peer link and this is not a peer link + * 
establishment frame, beacon or probe, drop the frame.
+ */
+
+ if (!rx->sta || sta_plink_state(rx->sta) != NL80211_PLINK_ESTAB) {
+ struct ieee80211_mgmt *mgmt;
+
+ if (!ieee80211_is_mgmt(hdr->frame_control))
+ return RX_DROP_MONITOR;
+
+ if (ieee80211_is_action(hdr->frame_control)) {
+ u8 category;
+
+ /* make sure category field is present */
+ if (rx->skb->len < IEEE80211_MIN_ACTION_SIZE)
+ return RX_DROP_MONITOR;
+
+ mgmt = (struct ieee80211_mgmt *)hdr;
+ category = mgmt->u.action.category;
+ if (category != WLAN_CATEGORY_MESH_ACTION &&
+ category != WLAN_CATEGORY_SELF_PROTECTED)
+ return RX_DROP_MONITOR;
+ return RX_CONTINUE;
+ }
+
+ if (ieee80211_is_probe_req(hdr->frame_control) ||
+ ieee80211_is_probe_resp(hdr->frame_control) ||
+ ieee80211_is_beacon(hdr->frame_control) ||
+ ieee80211_is_auth(hdr->frame_control))
+ return RX_CONTINUE;
+
+ return RX_DROP_MONITOR;
+ }
+
+ return RX_CONTINUE;
+}
+
+static inline bool ieee80211_rx_reorder_ready(struct tid_ampdu_rx *tid_agg_rx,
+ int index)
+{
+ struct sk_buff_head *frames = &tid_agg_rx->reorder_buf[index];
+ struct sk_buff *tail = skb_peek_tail(frames);
+ struct ieee80211_rx_status *status;
+
+ if (tid_agg_rx->reorder_buf_filtered &&
+ tid_agg_rx->reorder_buf_filtered & BIT_ULL(index))
+ return true;
+
+ if (!tail)
+ return false;
+
+ status = IEEE80211_SKB_RXCB(tail);
+ if (status->flag & RX_FLAG_AMSDU_MORE)
+ return false;
+
+ return true;
+}
+
+static void ieee80211_release_reorder_frame(struct ieee80211_sub_if_data *sdata,
+ struct tid_ampdu_rx *tid_agg_rx,
+ int index,
+ struct sk_buff_head *frames)
+{
+ struct sk_buff_head *skb_list = &tid_agg_rx->reorder_buf[index];
+ struct sk_buff *skb;
+ struct ieee80211_rx_status *status;
+
+ lockdep_assert_held(&tid_agg_rx->reorder_lock);
+
+ if (skb_queue_empty(skb_list))
+ goto no_frame;
+
+ if (!ieee80211_rx_reorder_ready(tid_agg_rx, index)) {
+ __skb_queue_purge(skb_list);
+ goto no_frame;
+ }
+
+ /* release frames from the reorder ring buffer */
+ tid_agg_rx->stored_mpdu_num--;
+ while ((skb = __skb_dequeue(skb_list))) {
+ status = IEEE80211_SKB_RXCB(skb);
+ status->rx_flags |= IEEE80211_RX_DEFERRED_RELEASE;
+ __skb_queue_tail(frames, skb);
+ }
+
+no_frame:
+ if (tid_agg_rx->reorder_buf_filtered)
+ tid_agg_rx->reorder_buf_filtered &= ~BIT_ULL(index);
+ tid_agg_rx->head_seq_num = ieee80211_sn_inc(tid_agg_rx->head_seq_num);
+}
+
+static void ieee80211_release_reorder_frames(struct ieee80211_sub_if_data *sdata,
+ struct tid_ampdu_rx *tid_agg_rx,
+ u16 head_seq_num,
+ struct sk_buff_head *frames)
+{
+ int index;
+
+ lockdep_assert_held(&tid_agg_rx->reorder_lock);
+
+ while (ieee80211_sn_less(tid_agg_rx->head_seq_num, head_seq_num)) {
+ index = tid_agg_rx->head_seq_num % tid_agg_rx->buf_size;
+ ieee80211_release_reorder_frame(sdata, tid_agg_rx, index,
+ frames);
+ }
+}
+
+/*
+ * Timeout (in jiffies) for skb's that are waiting in the RX reorder buffer. If
+ * the skb was added to the buffer longer than this time ago, the earlier
+ * frames that have not yet been received are assumed to be lost and the skb
+ * can be released for processing. This may also release other skb's from the
+ * reorder buffer if there are no additional gaps between the frames.
+ *
+ * Callers must hold tid_agg_rx->reorder_lock.
+ */ +#define HT_RX_REORDER_BUF_TIMEOUT (HZ / 10) + +static void ieee80211_sta_reorder_release(struct ieee80211_sub_if_data *sdata, + struct tid_ampdu_rx *tid_agg_rx, + struct sk_buff_head *frames) +{ + int index, i, j; + + lockdep_assert_held(&tid_agg_rx->reorder_lock); + + /* release the buffer until next missing frame */ + index = tid_agg_rx->head_seq_num % tid_agg_rx->buf_size; + if (!ieee80211_rx_reorder_ready(tid_agg_rx, index) && + tid_agg_rx->stored_mpdu_num) { + /* + * No buffers ready to be released, but check whether any + * frames in the reorder buffer have timed out. + */ + int skipped = 1; + for (j = (index + 1) % tid_agg_rx->buf_size; j != index; + j = (j + 1) % tid_agg_rx->buf_size) { + if (!ieee80211_rx_reorder_ready(tid_agg_rx, j)) { + skipped++; + continue; + } + if (skipped && + !time_after(jiffies, tid_agg_rx->reorder_time[j] + + HT_RX_REORDER_BUF_TIMEOUT)) + goto set_release_timer; + + /* don't leave incomplete A-MSDUs around */ + for (i = (index + 1) % tid_agg_rx->buf_size; i != j; + i = (i + 1) % tid_agg_rx->buf_size) + __skb_queue_purge(&tid_agg_rx->reorder_buf[i]); + + ht_dbg_ratelimited(sdata, + "release an RX reorder frame due to timeout on earlier frames\n"); + ieee80211_release_reorder_frame(sdata, tid_agg_rx, j, + frames); + + /* + * Increment the head seq# also for the skipped slots. + */ + tid_agg_rx->head_seq_num = + (tid_agg_rx->head_seq_num + + skipped) & IEEE80211_SN_MASK; + skipped = 0; + } + } else while (ieee80211_rx_reorder_ready(tid_agg_rx, index)) { + ieee80211_release_reorder_frame(sdata, tid_agg_rx, index, + frames); + index = tid_agg_rx->head_seq_num % tid_agg_rx->buf_size; + } + + if (tid_agg_rx->stored_mpdu_num) { + j = index = tid_agg_rx->head_seq_num % tid_agg_rx->buf_size; + + for (; j != (index - 1) % tid_agg_rx->buf_size; + j = (j + 1) % tid_agg_rx->buf_size) { + if (ieee80211_rx_reorder_ready(tid_agg_rx, j)) + break; + } + + set_release_timer: + + if (!tid_agg_rx->removed) + mod_timer(&tid_agg_rx->reorder_timer, + tid_agg_rx->reorder_time[j] + 1 + + HT_RX_REORDER_BUF_TIMEOUT); + } else { + del_timer(&tid_agg_rx->reorder_timer); + } +} + +/* + * As this function belongs to the RX path it must be under + * rcu_read_lock protection. It returns false if the frame + * can be processed immediately, true if it was consumed. + */ +static bool ieee80211_sta_manage_reorder_buf(struct ieee80211_sub_if_data *sdata, + struct tid_ampdu_rx *tid_agg_rx, + struct sk_buff *skb, + struct sk_buff_head *frames) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + u16 sc = le16_to_cpu(hdr->seq_ctrl); + u16 mpdu_seq_num = (sc & IEEE80211_SCTL_SEQ) >> 4; + u16 head_seq_num, buf_size; + int index; + bool ret = true; + + spin_lock(&tid_agg_rx->reorder_lock); + + /* + * Offloaded BA sessions have no known starting sequence number so pick + * one from first Rxed frame for this tid after BA was started. + */ + if (unlikely(tid_agg_rx->auto_seq)) { + tid_agg_rx->auto_seq = false; + tid_agg_rx->ssn = mpdu_seq_num; + tid_agg_rx->head_seq_num = mpdu_seq_num; + } + + buf_size = tid_agg_rx->buf_size; + head_seq_num = tid_agg_rx->head_seq_num; + + /* + * If the current MPDU's SN is smaller than the SSN, it shouldn't + * be reordered. 
+ */ + if (unlikely(!tid_agg_rx->started)) { + if (ieee80211_sn_less(mpdu_seq_num, head_seq_num)) { + ret = false; + goto out; + } + tid_agg_rx->started = true; + } + + /* frame with out of date sequence number */ + if (ieee80211_sn_less(mpdu_seq_num, head_seq_num)) { + dev_kfree_skb(skb); + goto out; + } + + /* + * If frame the sequence number exceeds our buffering window + * size release some previous frames to make room for this one. + */ + if (!ieee80211_sn_less(mpdu_seq_num, head_seq_num + buf_size)) { + head_seq_num = ieee80211_sn_inc( + ieee80211_sn_sub(mpdu_seq_num, buf_size)); + /* release stored frames up to new head to stack */ + ieee80211_release_reorder_frames(sdata, tid_agg_rx, + head_seq_num, frames); + } + + /* Now the new frame is always in the range of the reordering buffer */ + + index = mpdu_seq_num % tid_agg_rx->buf_size; + + /* check if we already stored this frame */ + if (ieee80211_rx_reorder_ready(tid_agg_rx, index)) { + dev_kfree_skb(skb); + goto out; + } + + /* + * If the current MPDU is in the right order and nothing else + * is stored we can process it directly, no need to buffer it. + * If it is first but there's something stored, we may be able + * to release frames after this one. + */ + if (mpdu_seq_num == tid_agg_rx->head_seq_num && + tid_agg_rx->stored_mpdu_num == 0) { + if (!(status->flag & RX_FLAG_AMSDU_MORE)) + tid_agg_rx->head_seq_num = + ieee80211_sn_inc(tid_agg_rx->head_seq_num); + ret = false; + goto out; + } + + /* put the frame in the reordering buffer */ + __skb_queue_tail(&tid_agg_rx->reorder_buf[index], skb); + if (!(status->flag & RX_FLAG_AMSDU_MORE)) { + tid_agg_rx->reorder_time[index] = jiffies; + tid_agg_rx->stored_mpdu_num++; + ieee80211_sta_reorder_release(sdata, tid_agg_rx, frames); + } + + out: + spin_unlock(&tid_agg_rx->reorder_lock); + return ret; +} + +/* + * Reorder MPDUs from A-MPDUs, keeping them on a buffer. Returns + * true if the MPDU was buffered, false if it should be processed. 
+ */ +static void ieee80211_rx_reorder_ampdu(struct ieee80211_rx_data *rx, + struct sk_buff_head *frames) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct sta_info *sta = rx->sta; + struct tid_ampdu_rx *tid_agg_rx; + u16 sc; + u8 tid, ack_policy; + + if (!ieee80211_is_data_qos(hdr->frame_control) || + is_multicast_ether_addr(hdr->addr1)) + goto dont_reorder; + + /* + * filter the QoS data rx stream according to + * STA/TID and check if this STA/TID is on aggregation + */ + + if (!sta) + goto dont_reorder; + + ack_policy = *ieee80211_get_qos_ctl(hdr) & + IEEE80211_QOS_CTL_ACK_POLICY_MASK; + tid = ieee80211_get_tid(hdr); + + tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); + if (!tid_agg_rx) { + if (ack_policy == IEEE80211_QOS_CTL_ACK_POLICY_BLOCKACK && + !test_bit(tid, rx->sta->ampdu_mlme.agg_session_valid) && + !test_and_set_bit(tid, rx->sta->ampdu_mlme.unexpected_agg)) + ieee80211_send_delba(rx->sdata, rx->sta->sta.addr, tid, + WLAN_BACK_RECIPIENT, + WLAN_REASON_QSTA_REQUIRE_SETUP); + goto dont_reorder; + } + + /* qos null data frames are excluded */ + if (unlikely(hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_NULLFUNC))) + goto dont_reorder; + + /* not part of a BA session */ + if (ack_policy == IEEE80211_QOS_CTL_ACK_POLICY_NOACK) + goto dont_reorder; + + /* new, potentially un-ordered, ampdu frame - process it */ + + /* reset session timer */ + if (tid_agg_rx->timeout) + tid_agg_rx->last_rx = jiffies; + + /* if this mpdu is fragmented - terminate rx aggregation session */ + sc = le16_to_cpu(hdr->seq_ctrl); + if (sc & IEEE80211_SCTL_FRAG) { + ieee80211_queue_skb_to_iface(rx->sdata, rx->link_id, NULL, skb); + return; + } + + /* + * No locking needed -- we will only ever process one + * RX packet at a time, and thus own tid_agg_rx. All + * other code manipulating it needs to (and does) make + * sure that we cannot get to it any more before doing + * anything with it. + */ + if (ieee80211_sta_manage_reorder_buf(rx->sdata, tid_agg_rx, skb, + frames)) + return; + + dont_reorder: + __skb_queue_tail(frames, skb); +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_check_dup(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + + if (status->flag & RX_FLAG_DUP_VALIDATED) + return RX_CONTINUE; + + /* + * Drop duplicate 802.11 retransmissions + * (IEEE 802.11-2012: 9.3.2.10 "Duplicate detection and recovery") + */ + + if (rx->skb->len < 24) + return RX_CONTINUE; + + if (ieee80211_is_ctl(hdr->frame_control) || + ieee80211_is_any_nullfunc(hdr->frame_control) || + is_multicast_ether_addr(hdr->addr1)) + return RX_CONTINUE; + + if (!rx->sta) + return RX_CONTINUE; + + if (unlikely(ieee80211_has_retry(hdr->frame_control) && + rx->sta->last_seq_ctrl[rx->seqno_idx] == hdr->seq_ctrl)) { + I802_DEBUG_INC(rx->local->dot11FrameDuplicateCount); + rx->link_sta->rx_stats.num_duplicates++; + return RX_DROP_UNUSABLE; + } else if (!(status->flag & RX_FLAG_AMSDU_MORE)) { + rx->sta->last_seq_ctrl[rx->seqno_idx] = hdr->seq_ctrl; + } + + return RX_CONTINUE; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_check(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + + /* Drop disallowed frame classes based on STA auth/assoc state; + * IEEE 802.11, Chap 5.5. + * + * mac80211 filters only based on association state, i.e. 
it drops
+ * Class 3 frames from not associated stations. hostapd sends
+ * deauth/disassoc frames when needed. In addition, hostapd is
+ * responsible for filtering on both auth and assoc states.
+ */
+
+ if (ieee80211_vif_is_mesh(&rx->sdata->vif))
+ return ieee80211_rx_mesh_check(rx);
+
+ if (unlikely((ieee80211_is_data(hdr->frame_control) ||
+ ieee80211_is_pspoll(hdr->frame_control)) &&
+ rx->sdata->vif.type != NL80211_IFTYPE_ADHOC &&
+ rx->sdata->vif.type != NL80211_IFTYPE_OCB &&
+ (!rx->sta || !test_sta_flag(rx->sta, WLAN_STA_ASSOC)))) {
+ /*
+ * accept port control frames from the AP even when it's not
+ * yet marked ASSOC to prevent a race where we don't set the
+ * assoc bit quickly enough before it sends the first frame
+ */
+ if (rx->sta && rx->sdata->vif.type == NL80211_IFTYPE_STATION &&
+ ieee80211_is_data_present(hdr->frame_control)) {
+ unsigned int hdrlen;
+ __be16 ethertype;
+
+ hdrlen = ieee80211_hdrlen(hdr->frame_control);
+
+ if (rx->skb->len < hdrlen + 8)
+ return RX_DROP_MONITOR;
+
+ skb_copy_bits(rx->skb, hdrlen + 6, &ethertype, 2);
+ if (ethertype == rx->sdata->control_port_protocol)
+ return RX_CONTINUE;
+ }
+
+ if (rx->sdata->vif.type == NL80211_IFTYPE_AP &&
+ cfg80211_rx_spurious_frame(rx->sdata->dev,
+ hdr->addr2,
+ GFP_ATOMIC))
+ return RX_DROP_UNUSABLE;
+
+ return RX_DROP_MONITOR;
+ }
+
+ return RX_CONTINUE;
+}
+
+
+static ieee80211_rx_result debug_noinline
+ieee80211_rx_h_check_more_data(struct ieee80211_rx_data *rx)
+{
+ struct ieee80211_local *local;
+ struct ieee80211_hdr *hdr;
+ struct sk_buff *skb;
+
+ local = rx->local;
+ skb = rx->skb;
+ hdr = (struct ieee80211_hdr *) skb->data;
+
+ if (!local->pspolling)
+ return RX_CONTINUE;
+
+ if (!ieee80211_has_fromds(hdr->frame_control))
+ /* this is not from AP */
+ return RX_CONTINUE;
+
+ if (!ieee80211_is_data(hdr->frame_control))
+ return RX_CONTINUE;
+
+ if (!ieee80211_has_moredata(hdr->frame_control)) {
+ /* AP has no more frames buffered for us */
+ local->pspolling = false;
+ return RX_CONTINUE;
+ }
+
+ /* more data bit is set, let's request a new frame from the AP */
+ ieee80211_send_pspoll(local, rx->sdata);
+
+ return RX_CONTINUE;
+}
+
+static void sta_ps_start(struct sta_info *sta)
+{
+ struct ieee80211_sub_if_data *sdata = sta->sdata;
+ struct ieee80211_local *local = sdata->local;
+ struct ps_data *ps;
+ int tid;
+
+ if (sta->sdata->vif.type == NL80211_IFTYPE_AP ||
+ sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN)
+ ps = &sdata->bss->ps;
+ else
+ return;
+
+ atomic_inc(&ps->num_sta_ps);
+ set_sta_flag(sta, WLAN_STA_PS_STA);
+ if (!ieee80211_hw_check(&local->hw, AP_LINK_PS))
+ drv_sta_notify(local, sdata, STA_NOTIFY_SLEEP, &sta->sta);
+ ps_dbg(sdata, "STA %pM aid %d enters power save mode\n",
+ sta->sta.addr, sta->sta.aid);
+
+ ieee80211_clear_fast_xmit(sta);
+
+ if (!sta->sta.txq[0])
+ return;
+
+ for (tid = 0; tid < IEEE80211_NUM_TIDS; tid++) {
+ struct ieee80211_txq *txq = sta->sta.txq[tid];
+ struct txq_info *txqi = to_txq_info(txq);
+
+ spin_lock(&local->active_txq_lock[txq->ac]);
+ if (!list_empty(&txqi->schedule_order))
+ list_del_init(&txqi->schedule_order);
+ spin_unlock(&local->active_txq_lock[txq->ac]);
+
+ if (txq_has_queue(txq))
+ set_bit(tid, &sta->txq_buffered_tids);
+ else
+ clear_bit(tid, &sta->txq_buffered_tids);
+ }
+}
+
+static void sta_ps_end(struct sta_info *sta)
+{
+ ps_dbg(sta->sdata, "STA %pM aid %d exits power save mode\n",
+ sta->sta.addr, sta->sta.aid);
+
+ if (test_sta_flag(sta, WLAN_STA_PS_DRIVER)) {
+ /*
+ * Clear the flag only if the other one is still set
+ * so that the
TX path won't start TX'ing new frames + * directly ... In the case that the driver flag isn't + * set ieee80211_sta_ps_deliver_wakeup() will clear it. + */ + clear_sta_flag(sta, WLAN_STA_PS_STA); + ps_dbg(sta->sdata, "STA %pM aid %d driver-ps-blocked\n", + sta->sta.addr, sta->sta.aid); + return; + } + + set_sta_flag(sta, WLAN_STA_PS_DELIVER); + clear_sta_flag(sta, WLAN_STA_PS_STA); + ieee80211_sta_ps_deliver_wakeup(sta); +} + +int ieee80211_sta_ps_transition(struct ieee80211_sta *pubsta, bool start) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + bool in_ps; + + WARN_ON(!ieee80211_hw_check(&sta->local->hw, AP_LINK_PS)); + + /* Don't let the same PS state be set twice */ + in_ps = test_sta_flag(sta, WLAN_STA_PS_STA); + if ((start && in_ps) || (!start && !in_ps)) + return -EINVAL; + + if (start) + sta_ps_start(sta); + else + sta_ps_end(sta); + + return 0; +} +EXPORT_SYMBOL(ieee80211_sta_ps_transition); + +void ieee80211_sta_pspoll(struct ieee80211_sta *pubsta) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + if (test_sta_flag(sta, WLAN_STA_SP)) + return; + + if (!test_sta_flag(sta, WLAN_STA_PS_DRIVER)) + ieee80211_sta_ps_deliver_poll_response(sta); + else + set_sta_flag(sta, WLAN_STA_PSPOLL); +} +EXPORT_SYMBOL(ieee80211_sta_pspoll); + +void ieee80211_sta_uapsd_trigger(struct ieee80211_sta *pubsta, u8 tid) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + int ac = ieee80211_ac_from_tid(tid); + + /* + * If this AC is not trigger-enabled do nothing unless the + * driver is calling us after it already checked. + * + * NB: This could/should check a separate bitmap of trigger- + * enabled queues, but for now we only implement uAPSD w/o + * TSPEC changes to the ACs, so they're always the same. + */ + if (!(sta->sta.uapsd_queues & ieee80211_ac_to_qos_mask[ac]) && + tid != IEEE80211_NUM_TIDS) + return; + + /* if we are in a service period, do nothing */ + if (test_sta_flag(sta, WLAN_STA_SP)) + return; + + if (!test_sta_flag(sta, WLAN_STA_PS_DRIVER)) + ieee80211_sta_ps_deliver_uapsd(sta); + else + set_sta_flag(sta, WLAN_STA_UAPSD); +} +EXPORT_SYMBOL(ieee80211_sta_uapsd_trigger); + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_uapsd_and_pspoll(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_hdr *hdr = (void *)rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + + if (!rx->sta) + return RX_CONTINUE; + + if (sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_AP_VLAN) + return RX_CONTINUE; + + /* + * The device handles station powersave, so don't do anything about + * uAPSD and PS-Poll frames (the latter shouldn't even come up from + * it to mac80211 since they're handled.) + */ + if (ieee80211_hw_check(&sdata->local->hw, AP_LINK_PS)) + return RX_CONTINUE; + + /* + * Don't do anything if the station isn't already asleep. In + * the uAPSD case, the station will probably be marked asleep, + * in the PS-Poll case the station must be confused ... + */ + if (!test_sta_flag(rx->sta, WLAN_STA_PS_STA)) + return RX_CONTINUE; + + if (unlikely(ieee80211_is_pspoll(hdr->frame_control))) { + ieee80211_sta_pspoll(&rx->sta->sta); + + /* Free PS Poll skb here instead of returning RX_DROP that would + * count as an dropped frame. 
*/
+ dev_kfree_skb(rx->skb);
+
+ return RX_QUEUED;
+ } else if (!ieee80211_has_morefrags(hdr->frame_control) &&
+ !(status->rx_flags & IEEE80211_RX_DEFERRED_RELEASE) &&
+ ieee80211_has_pm(hdr->frame_control) &&
+ (ieee80211_is_data_qos(hdr->frame_control) ||
+ ieee80211_is_qos_nullfunc(hdr->frame_control))) {
+ u8 tid = ieee80211_get_tid(hdr);
+
+ ieee80211_sta_uapsd_trigger(&rx->sta->sta, tid);
+ }
+
+ return RX_CONTINUE;
+}
+
+static ieee80211_rx_result debug_noinline
+ieee80211_rx_h_sta_process(struct ieee80211_rx_data *rx)
+{
+ struct sta_info *sta = rx->sta;
+ struct link_sta_info *link_sta = rx->link_sta;
+ struct sk_buff *skb = rx->skb;
+ struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
+ struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data;
+ int i;
+
+ if (!sta || !link_sta)
+ return RX_CONTINUE;
+
+ /*
+ * Update last_rx only for IBSS packets which are for the current
+ * BSSID and for station already AUTHORIZED to avoid keeping the
+ * current IBSS network alive in cases where other STAs start
+ * using different BSSID. This will also give the station another
+ * chance to restart the authentication/authorization in case
+ * something went wrong the first time.
+ */
+ if (rx->sdata->vif.type == NL80211_IFTYPE_ADHOC) {
+ u8 *bssid = ieee80211_get_bssid(hdr, rx->skb->len,
+ NL80211_IFTYPE_ADHOC);
+ if (ether_addr_equal(bssid, rx->sdata->u.ibss.bssid) &&
+ test_sta_flag(sta, WLAN_STA_AUTHORIZED)) {
+ link_sta->rx_stats.last_rx = jiffies;
+ if (ieee80211_is_data(hdr->frame_control) &&
+ !is_multicast_ether_addr(hdr->addr1))
+ link_sta->rx_stats.last_rate =
+ sta_stats_encode_rate(status);
+ }
+ } else if (rx->sdata->vif.type == NL80211_IFTYPE_OCB) {
+ link_sta->rx_stats.last_rx = jiffies;
+ } else if (!ieee80211_is_s1g_beacon(hdr->frame_control) &&
+ !is_multicast_ether_addr(hdr->addr1)) {
+ /*
+ * Mesh beacons will update last_rx if they are found to
+ * match the current local configuration when processed.
+ */ + link_sta->rx_stats.last_rx = jiffies; + if (ieee80211_is_data(hdr->frame_control)) + link_sta->rx_stats.last_rate = sta_stats_encode_rate(status); + } + + link_sta->rx_stats.fragments++; + + u64_stats_update_begin(&link_sta->rx_stats.syncp); + link_sta->rx_stats.bytes += rx->skb->len; + u64_stats_update_end(&link_sta->rx_stats.syncp); + + if (!(status->flag & RX_FLAG_NO_SIGNAL_VAL)) { + link_sta->rx_stats.last_signal = status->signal; + ewma_signal_add(&link_sta->rx_stats_avg.signal, + -status->signal); + } + + if (status->chains) { + link_sta->rx_stats.chains = status->chains; + for (i = 0; i < ARRAY_SIZE(status->chain_signal); i++) { + int signal = status->chain_signal[i]; + + if (!(status->chains & BIT(i))) + continue; + + link_sta->rx_stats.chain_signal_last[i] = signal; + ewma_signal_add(&link_sta->rx_stats_avg.chain_signal[i], + -signal); + } + } + + if (ieee80211_is_s1g_beacon(hdr->frame_control)) + return RX_CONTINUE; + + /* + * Change STA power saving mode only at the end of a frame + * exchange sequence, and only for a data or management + * frame as specified in IEEE 802.11-2016 11.2.3.2 + */ + if (!ieee80211_hw_check(&sta->local->hw, AP_LINK_PS) && + !ieee80211_has_morefrags(hdr->frame_control) && + !is_multicast_ether_addr(hdr->addr1) && + (ieee80211_is_mgmt(hdr->frame_control) || + ieee80211_is_data(hdr->frame_control)) && + !(status->rx_flags & IEEE80211_RX_DEFERRED_RELEASE) && + (rx->sdata->vif.type == NL80211_IFTYPE_AP || + rx->sdata->vif.type == NL80211_IFTYPE_AP_VLAN)) { + if (test_sta_flag(sta, WLAN_STA_PS_STA)) { + if (!ieee80211_has_pm(hdr->frame_control)) + sta_ps_end(sta); + } else { + if (ieee80211_has_pm(hdr->frame_control)) + sta_ps_start(sta); + } + } + + /* mesh power save support */ + if (ieee80211_vif_is_mesh(&rx->sdata->vif)) + ieee80211_mps_rx_h_sta_process(sta, hdr); + + /* + * Drop (qos-)data::nullfunc frames silently, since they + * are used only to control station power saving mode. + */ + if (ieee80211_is_any_nullfunc(hdr->frame_control)) { + I802_DEBUG_INC(rx->local->rx_handlers_drop_nullfunc); + + /* + * If we receive a 4-addr nullfunc frame from a STA + * that was not moved to a 4-addr STA vlan yet send + * the event to userspace and for older hostapd drop + * the frame to the monitor interface. + */ + if (ieee80211_has_a4(hdr->frame_control) && + (rx->sdata->vif.type == NL80211_IFTYPE_AP || + (rx->sdata->vif.type == NL80211_IFTYPE_AP_VLAN && + !rx->sdata->u.vlan.sta))) { + if (!test_and_set_sta_flag(sta, WLAN_STA_4ADDR_EVENT)) + cfg80211_rx_unexpected_4addr_frame( + rx->sdata->dev, sta->sta.addr, + GFP_ATOMIC); + return RX_DROP_MONITOR; + } + /* + * Update counter and free packet here to avoid + * counting this as a dropped packed. + */ + link_sta->rx_stats.packets++; + dev_kfree_skb(rx->skb); + return RX_QUEUED; + } + + return RX_CONTINUE; +} /* ieee80211_rx_h_sta_process */ + +static struct ieee80211_key * +ieee80211_rx_get_bigtk(struct ieee80211_rx_data *rx, int idx) +{ + struct ieee80211_key *key = NULL; + int idx2; + + /* Make sure key gets set if either BIGTK key index is set so that + * ieee80211_drop_unencrypted_mgmt() can properly drop both unprotected + * Beacon frames and Beacon frames that claim to use another BIGTK key + * index (i.e., a key that we do not have). 
+ */ + + if (idx < 0) { + idx = NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS; + idx2 = idx + 1; + } else { + if (idx == NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS) + idx2 = idx + 1; + else + idx2 = idx - 1; + } + + if (rx->link_sta) + key = rcu_dereference(rx->link_sta->gtk[idx]); + if (!key) + key = rcu_dereference(rx->link->gtk[idx]); + if (!key && rx->link_sta) + key = rcu_dereference(rx->link_sta->gtk[idx2]); + if (!key) + key = rcu_dereference(rx->link->gtk[idx2]); + + return key; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_decrypt(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + int keyidx; + ieee80211_rx_result result = RX_DROP_UNUSABLE; + struct ieee80211_key *sta_ptk = NULL; + struct ieee80211_key *ptk_idx = NULL; + int mmie_keyidx = -1; + __le16 fc; + + if (ieee80211_is_ext(hdr->frame_control)) + return RX_CONTINUE; + + /* + * Key selection 101 + * + * There are five types of keys: + * - GTK (group keys) + * - IGTK (group keys for management frames) + * - BIGTK (group keys for Beacon frames) + * - PTK (pairwise keys) + * - STK (station-to-station pairwise keys) + * + * When selecting a key, we have to distinguish between multicast + * (including broadcast) and unicast frames, the latter can only + * use PTKs and STKs while the former always use GTKs, IGTKs, and + * BIGTKs. Unless, of course, actual WEP keys ("pre-RSNA") are used, + * then unicast frames can also use key indices like GTKs. Hence, if we + * don't have a PTK/STK we check the key index for a WEP key. + * + * Note that in a regular BSS, multicast frames are sent by the + * AP only, associated stations unicast the frame to the AP first + * which then multicasts it on their behalf. + * + * There is also a slight problem in IBSS mode: GTKs are negotiated + * with each station, that is something we don't currently handle. + * The spec seems to expect that one negotiates the same key with + * every station but there's no such requirement; VLANs could be + * possible. + */ + + /* start without a key */ + rx->key = NULL; + fc = hdr->frame_control; + + if (rx->sta) { + int keyid = rx->sta->ptk_idx; + sta_ptk = rcu_dereference(rx->sta->ptk[keyid]); + + if (ieee80211_has_protected(fc) && + !(status->flag & RX_FLAG_IV_STRIPPED)) { + keyid = ieee80211_get_keyid(rx->skb); + + if (unlikely(keyid < 0)) + return RX_DROP_UNUSABLE; + + ptk_idx = rcu_dereference(rx->sta->ptk[keyid]); + } + } + + if (!ieee80211_has_protected(fc)) + mmie_keyidx = ieee80211_get_mmie_keyidx(rx->skb); + + if (!is_multicast_ether_addr(hdr->addr1) && sta_ptk) { + rx->key = ptk_idx ? ptk_idx : sta_ptk; + if ((status->flag & RX_FLAG_DECRYPTED) && + (status->flag & RX_FLAG_IV_STRIPPED)) + return RX_CONTINUE; + /* Skip decryption if the frame is not protected. 
*/ + if (!ieee80211_has_protected(fc)) + return RX_CONTINUE; + } else if (mmie_keyidx >= 0 && ieee80211_is_beacon(fc)) { + /* Broadcast/multicast robust management frame / BIP */ + if ((status->flag & RX_FLAG_DECRYPTED) && + (status->flag & RX_FLAG_IV_STRIPPED)) + return RX_CONTINUE; + + if (mmie_keyidx < NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS || + mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS) { + if (rx->sdata->dev) + cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev, + skb->data, + skb->len); + return RX_DROP_MONITOR; /* unexpected BIP keyidx */ + } + + rx->key = ieee80211_rx_get_bigtk(rx, mmie_keyidx); + if (!rx->key) + return RX_CONTINUE; /* Beacon protection not in use */ + } else if (mmie_keyidx >= 0) { + /* Broadcast/multicast robust management frame / BIP */ + if ((status->flag & RX_FLAG_DECRYPTED) && + (status->flag & RX_FLAG_IV_STRIPPED)) + return RX_CONTINUE; + + if (mmie_keyidx < NUM_DEFAULT_KEYS || + mmie_keyidx >= NUM_DEFAULT_KEYS + NUM_DEFAULT_MGMT_KEYS) + return RX_DROP_MONITOR; /* unexpected BIP keyidx */ + if (rx->link_sta) { + if (ieee80211_is_group_privacy_action(skb) && + test_sta_flag(rx->sta, WLAN_STA_MFP)) + return RX_DROP_MONITOR; + + rx->key = rcu_dereference(rx->link_sta->gtk[mmie_keyidx]); + } + if (!rx->key) + rx->key = rcu_dereference(rx->link->gtk[mmie_keyidx]); + } else if (!ieee80211_has_protected(fc)) { + /* + * The frame was not protected, so skip decryption. However, we + * need to set rx->key if there is a key that could have been + * used so that the frame may be dropped if encryption would + * have been expected. + */ + struct ieee80211_key *key = NULL; + int i; + + if (ieee80211_is_beacon(fc)) { + key = ieee80211_rx_get_bigtk(rx, -1); + } else if (ieee80211_is_mgmt(fc) && + is_multicast_ether_addr(hdr->addr1)) { + key = rcu_dereference(rx->link->default_mgmt_key); + } else { + if (rx->link_sta) { + for (i = 0; i < NUM_DEFAULT_KEYS; i++) { + key = rcu_dereference(rx->link_sta->gtk[i]); + if (key) + break; + } + } + if (!key) { + for (i = 0; i < NUM_DEFAULT_KEYS; i++) { + key = rcu_dereference(rx->link->gtk[i]); + if (key) + break; + } + } + } + if (key) + rx->key = key; + return RX_CONTINUE; + } else { + /* + * The device doesn't give us the IV so we won't be + * able to look up the key. That's ok though, we + * don't need to decrypt the frame, we just won't + * be able to keep statistics accurate. + * Except for key threshold notifications, should + * we somehow allow the driver to tell us which key + * the hardware used if this flag is set? + */ + if ((status->flag & RX_FLAG_DECRYPTED) && + (status->flag & RX_FLAG_IV_STRIPPED)) + return RX_CONTINUE; + + keyidx = ieee80211_get_keyid(rx->skb); + + if (unlikely(keyidx < 0)) + return RX_DROP_UNUSABLE; + + /* check per-station GTK first, if multicast packet */ + if (is_multicast_ether_addr(hdr->addr1) && rx->link_sta) + rx->key = rcu_dereference(rx->link_sta->gtk[keyidx]); + + /* if not found, try default key */ + if (!rx->key) { + if (is_multicast_ether_addr(hdr->addr1)) + rx->key = rcu_dereference(rx->link->gtk[keyidx]); + if (!rx->key) + rx->key = rcu_dereference(rx->sdata->keys[keyidx]); + + /* + * RSNA-protected unicast frames should always be + * sent with pairwise or station-to-station keys, + * but for WEP we allow using a key index as well. 
+ */ + if (rx->key && + rx->key->conf.cipher != WLAN_CIPHER_SUITE_WEP40 && + rx->key->conf.cipher != WLAN_CIPHER_SUITE_WEP104 && + !is_multicast_ether_addr(hdr->addr1)) + rx->key = NULL; + } + } + + if (rx->key) { + if (unlikely(rx->key->flags & KEY_FLAG_TAINTED)) + return RX_DROP_MONITOR; + + /* TODO: add threshold stuff again */ + } else { + return RX_DROP_MONITOR; + } + + switch (rx->key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + result = ieee80211_crypto_wep_decrypt(rx); + break; + case WLAN_CIPHER_SUITE_TKIP: + result = ieee80211_crypto_tkip_decrypt(rx); + break; + case WLAN_CIPHER_SUITE_CCMP: + result = ieee80211_crypto_ccmp_decrypt( + rx, IEEE80211_CCMP_MIC_LEN); + break; + case WLAN_CIPHER_SUITE_CCMP_256: + result = ieee80211_crypto_ccmp_decrypt( + rx, IEEE80211_CCMP_256_MIC_LEN); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + result = ieee80211_crypto_aes_cmac_decrypt(rx); + break; + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + result = ieee80211_crypto_aes_cmac_256_decrypt(rx); + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + result = ieee80211_crypto_aes_gmac_decrypt(rx); + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + result = ieee80211_crypto_gcmp_decrypt(rx); + break; + default: + result = RX_DROP_UNUSABLE; + } + + /* the hdr variable is invalid after the decrypt handlers */ + + /* either the frame has been decrypted or will be dropped */ + status->flag |= RX_FLAG_DECRYPTED; + + if (unlikely(ieee80211_is_beacon(fc) && result == RX_DROP_UNUSABLE && + rx->sdata->dev)) + cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev, + skb->data, skb->len); + + return result; +} + +void ieee80211_init_frag_cache(struct ieee80211_fragment_cache *cache) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(cache->entries); i++) + skb_queue_head_init(&cache->entries[i].skb_list); +} + +void ieee80211_destroy_frag_cache(struct ieee80211_fragment_cache *cache) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(cache->entries); i++) + __skb_queue_purge(&cache->entries[i].skb_list); +} + +static inline struct ieee80211_fragment_entry * +ieee80211_reassemble_add(struct ieee80211_fragment_cache *cache, + unsigned int frag, unsigned int seq, int rx_queue, + struct sk_buff **skb) +{ + struct ieee80211_fragment_entry *entry; + + entry = &cache->entries[cache->next++]; + if (cache->next >= IEEE80211_FRAGMENT_MAX) + cache->next = 0; + + __skb_queue_purge(&entry->skb_list); + + __skb_queue_tail(&entry->skb_list, *skb); /* no need for locking */ + *skb = NULL; + entry->first_frag_time = jiffies; + entry->seq = seq; + entry->rx_queue = rx_queue; + entry->last_frag = frag; + entry->check_sequential_pn = false; + entry->extra_len = 0; + + return entry; +} + +static inline struct ieee80211_fragment_entry * +ieee80211_reassemble_find(struct ieee80211_fragment_cache *cache, + unsigned int frag, unsigned int seq, + int rx_queue, struct ieee80211_hdr *hdr) +{ + struct ieee80211_fragment_entry *entry; + int i, idx; + + idx = cache->next; + for (i = 0; i < IEEE80211_FRAGMENT_MAX; i++) { + struct ieee80211_hdr *f_hdr; + struct sk_buff *f_skb; + + idx--; + if (idx < 0) + idx = IEEE80211_FRAGMENT_MAX - 1; + + entry = &cache->entries[idx]; + if (skb_queue_empty(&entry->skb_list) || entry->seq != seq || + entry->rx_queue != rx_queue || + entry->last_frag + 1 != frag) + continue; + + f_skb = __skb_peek(&entry->skb_list); + f_hdr = (struct ieee80211_hdr *) f_skb->data; + + /* + * Check ftype and addresses are equal, else check next fragment + */ + 
if (((hdr->frame_control ^ f_hdr->frame_control) & + cpu_to_le16(IEEE80211_FCTL_FTYPE)) || + !ether_addr_equal(hdr->addr1, f_hdr->addr1) || + !ether_addr_equal(hdr->addr2, f_hdr->addr2)) + continue; + + if (time_after(jiffies, entry->first_frag_time + 2 * HZ)) { + __skb_queue_purge(&entry->skb_list); + continue; + } + return entry; + } + + return NULL; +} + +static bool requires_sequential_pn(struct ieee80211_rx_data *rx, __le16 fc) +{ + return rx->key && + (rx->key->conf.cipher == WLAN_CIPHER_SUITE_CCMP || + rx->key->conf.cipher == WLAN_CIPHER_SUITE_CCMP_256 || + rx->key->conf.cipher == WLAN_CIPHER_SUITE_GCMP || + rx->key->conf.cipher == WLAN_CIPHER_SUITE_GCMP_256) && + ieee80211_has_protected(fc); +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_defragment(struct ieee80211_rx_data *rx) +{ + struct ieee80211_fragment_cache *cache = &rx->sdata->frags; + struct ieee80211_hdr *hdr; + u16 sc; + __le16 fc; + unsigned int frag, seq; + struct ieee80211_fragment_entry *entry; + struct sk_buff *skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + + hdr = (struct ieee80211_hdr *)rx->skb->data; + fc = hdr->frame_control; + + if (ieee80211_is_ctl(fc) || ieee80211_is_ext(fc)) + return RX_CONTINUE; + + sc = le16_to_cpu(hdr->seq_ctrl); + frag = sc & IEEE80211_SCTL_FRAG; + + if (rx->sta) + cache = &rx->sta->frags; + + if (likely(!ieee80211_has_morefrags(fc) && frag == 0)) + goto out; + + if (is_multicast_ether_addr(hdr->addr1)) + return RX_DROP_MONITOR; + + I802_DEBUG_INC(rx->local->rx_handlers_fragments); + + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + + /* + * skb_linearize() might change the skb->data and + * previously cached variables (in this case, hdr) need to + * be refreshed with the new data. + */ + hdr = (struct ieee80211_hdr *)rx->skb->data; + seq = (sc & IEEE80211_SCTL_SEQ) >> 4; + + if (frag == 0) { + /* This is the first fragment of a new frame. */ + entry = ieee80211_reassemble_add(cache, frag, seq, + rx->seqno_idx, &(rx->skb)); + if (requires_sequential_pn(rx, fc)) { + int queue = rx->security_idx; + + /* Store CCMP/GCMP PN so that we can verify that the + * next fragment has a sequential PN value. + */ + entry->check_sequential_pn = true; + entry->is_protected = true; + entry->key_color = rx->key->color; + memcpy(entry->last_pn, + rx->key->u.ccmp.rx_pn[queue], + IEEE80211_CCMP_PN_LEN); + BUILD_BUG_ON(offsetof(struct ieee80211_key, + u.ccmp.rx_pn) != + offsetof(struct ieee80211_key, + u.gcmp.rx_pn)); + BUILD_BUG_ON(sizeof(rx->key->u.ccmp.rx_pn[queue]) != + sizeof(rx->key->u.gcmp.rx_pn[queue])); + BUILD_BUG_ON(IEEE80211_CCMP_PN_LEN != + IEEE80211_GCMP_PN_LEN); + } else if (rx->key && + (ieee80211_has_protected(fc) || + (status->flag & RX_FLAG_DECRYPTED))) { + entry->is_protected = true; + entry->key_color = rx->key->color; + } + return RX_QUEUED; + } + + /* This is a fragment for a frame that should already be pending in + * fragment cache. Add this fragment to the end of the pending entry. + */ + entry = ieee80211_reassemble_find(cache, frag, seq, + rx->seqno_idx, hdr); + if (!entry) { + I802_DEBUG_INC(rx->local->rx_handlers_drop_defrag); + return RX_DROP_MONITOR; + } + + /* "The receiver shall discard MSDUs and MMPDUs whose constituent + * MPDU PN values are not incrementing in steps of 1." 
+ * see IEEE P802.11-REVmc/D5.0, 12.5.3.4.4, item d (for CCMP)
+ * and IEEE P802.11-REVmc/D5.0, 12.5.5.4.4, item d (for GCMP)
+ */
+ if (entry->check_sequential_pn) {
+ int i;
+ u8 pn[IEEE80211_CCMP_PN_LEN], *rpn;
+
+ if (!requires_sequential_pn(rx, fc))
+ return RX_DROP_UNUSABLE;
+
+ /* Prevent mixed key and fragment cache attacks */
+ if (entry->key_color != rx->key->color)
+ return RX_DROP_UNUSABLE;
+
+ memcpy(pn, entry->last_pn, IEEE80211_CCMP_PN_LEN);
+ for (i = IEEE80211_CCMP_PN_LEN - 1; i >= 0; i--) {
+ pn[i]++;
+ if (pn[i])
+ break;
+ }
+
+ rpn = rx->ccm_gcm.pn;
+ if (memcmp(pn, rpn, IEEE80211_CCMP_PN_LEN))
+ return RX_DROP_UNUSABLE;
+ memcpy(entry->last_pn, pn, IEEE80211_CCMP_PN_LEN);
+ } else if (entry->is_protected &&
+ (!rx->key ||
+ (!ieee80211_has_protected(fc) &&
+ !(status->flag & RX_FLAG_DECRYPTED)) ||
+ rx->key->color != entry->key_color)) {
+ /* Drop this as a mixed key or fragment cache attack, even
+ * if for TKIP Michael MIC should protect us, and WEP is a
+ * lost cause anyway.
+ */
+ return RX_DROP_UNUSABLE;
+ } else if (entry->is_protected && rx->key &&
+ entry->key_color != rx->key->color &&
+ (status->flag & RX_FLAG_DECRYPTED)) {
+ return RX_DROP_UNUSABLE;
+ }
+
+ skb_pull(rx->skb, ieee80211_hdrlen(fc));
+ __skb_queue_tail(&entry->skb_list, rx->skb);
+ entry->last_frag = frag;
+ entry->extra_len += rx->skb->len;
+ if (ieee80211_has_morefrags(fc)) {
+ rx->skb = NULL;
+ return RX_QUEUED;
+ }
+
+ rx->skb = __skb_dequeue(&entry->skb_list);
+ if (skb_tailroom(rx->skb) < entry->extra_len) {
+ I802_DEBUG_INC(rx->local->rx_expand_skb_head_defrag);
+ if (unlikely(pskb_expand_head(rx->skb, 0, entry->extra_len,
+ GFP_ATOMIC))) {
+ I802_DEBUG_INC(rx->local->rx_handlers_drop_defrag);
+ __skb_queue_purge(&entry->skb_list);
+ return RX_DROP_UNUSABLE;
+ }
+ }
+ while ((skb = __skb_dequeue(&entry->skb_list))) {
+ skb_put_data(rx->skb, skb->data, skb->len);
+ dev_kfree_skb(skb);
+ }
+
+ out:
+ ieee80211_led_rx(rx->local);
+ if (rx->sta)
+ rx->link_sta->rx_stats.packets++;
+ return RX_CONTINUE;
+}
+
+static int ieee80211_802_1x_port_control(struct ieee80211_rx_data *rx)
+{
+ if (unlikely(!rx->sta || !test_sta_flag(rx->sta, WLAN_STA_AUTHORIZED)))
+ return -EACCES;
+
+ return 0;
+}
+
+static int ieee80211_drop_unencrypted(struct ieee80211_rx_data *rx, __le16 fc)
+{
+ struct ieee80211_hdr *hdr = (void *)rx->skb->data;
+ struct sk_buff *skb = rx->skb;
+ struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb);
+
+ /*
+ * Pass through unencrypted frames if the hardware has
+ * decrypted them already.
+ */
+ if (status->flag & RX_FLAG_DECRYPTED)
+ return 0;
+
+ /* check mesh EAPOL frames first */
+ if (unlikely(rx->sta && ieee80211_vif_is_mesh(&rx->sdata->vif) &&
+ ieee80211_is_data(fc))) {
+ struct ieee80211s_hdr *mesh_hdr;
+ u16 hdr_len = ieee80211_hdrlen(fc);
+ u16 ethertype_offset;
+ __be16 ethertype;
+
+ if (!ether_addr_equal(hdr->addr1, rx->sdata->vif.addr))
+ goto drop_check;
+
+ /* make sure fixed part of mesh header is there, also checks skb len */
+ if (!pskb_may_pull(rx->skb, hdr_len + 6))
+ goto drop_check;
+
+ mesh_hdr = (struct ieee80211s_hdr *)(skb->data + hdr_len);
+ ethertype_offset = hdr_len + ieee80211_get_mesh_hdrlen(mesh_hdr) +
+ sizeof(rfc1042_header);
+
+ if (skb_copy_bits(rx->skb, ethertype_offset, &ethertype, 2) == 0 &&
+ ethertype == rx->sdata->control_port_protocol)
+ return 0;
+ }
+
+drop_check:
+ /* Drop unencrypted frames if key is set.
*/ + if (unlikely(!ieee80211_has_protected(fc) && + !ieee80211_is_any_nullfunc(fc) && + ieee80211_is_data(fc) && rx->key)) + return -EACCES; + + return 0; +} + +static int ieee80211_drop_unencrypted_mgmt(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + __le16 fc = hdr->frame_control; + + /* + * Pass through unencrypted frames if the hardware has + * decrypted them already. + */ + if (status->flag & RX_FLAG_DECRYPTED) + return 0; + + if (rx->sta && test_sta_flag(rx->sta, WLAN_STA_MFP)) { + if (unlikely(!ieee80211_has_protected(fc) && + ieee80211_is_unicast_robust_mgmt_frame(rx->skb) && + rx->key)) { + if (ieee80211_is_deauth(fc) || + ieee80211_is_disassoc(fc)) + cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev, + rx->skb->data, + rx->skb->len); + return -EACCES; + } + /* BIP does not use Protected field, so need to check MMIE */ + if (unlikely(ieee80211_is_multicast_robust_mgmt_frame(rx->skb) && + ieee80211_get_mmie_keyidx(rx->skb) < 0)) { + if (ieee80211_is_deauth(fc) || + ieee80211_is_disassoc(fc)) + cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev, + rx->skb->data, + rx->skb->len); + return -EACCES; + } + if (unlikely(ieee80211_is_beacon(fc) && rx->key && + ieee80211_get_mmie_keyidx(rx->skb) < 0)) { + cfg80211_rx_unprot_mlme_mgmt(rx->sdata->dev, + rx->skb->data, + rx->skb->len); + return -EACCES; + } + /* + * When using MFP, Action frames are not allowed prior to + * having configured keys. + */ + if (unlikely(ieee80211_is_action(fc) && !rx->key && + ieee80211_is_robust_mgmt_frame(rx->skb))) + return -EACCES; + } + + return 0; +} + +static int +__ieee80211_data_to_8023(struct ieee80211_rx_data *rx, bool *port_control) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + bool check_port_control = false; + struct ethhdr *ehdr; + int ret; + + *port_control = false; + if (ieee80211_has_a4(hdr->frame_control) && + sdata->vif.type == NL80211_IFTYPE_AP_VLAN && !sdata->u.vlan.sta) + return -1; + + if (sdata->vif.type == NL80211_IFTYPE_STATION && + !!sdata->u.mgd.use_4addr != !!ieee80211_has_a4(hdr->frame_control)) { + + if (!sdata->u.mgd.use_4addr) + return -1; + else if (!ether_addr_equal(hdr->addr1, sdata->vif.addr)) + check_port_control = true; + } + + if (is_multicast_ether_addr(hdr->addr1) && + sdata->vif.type == NL80211_IFTYPE_AP_VLAN && sdata->u.vlan.sta) + return -1; + + ret = ieee80211_data_to_8023(rx->skb, sdata->vif.addr, sdata->vif.type); + if (ret < 0) + return ret; + + ehdr = (struct ethhdr *) rx->skb->data; + if (ehdr->h_proto == rx->sdata->control_port_protocol) + *port_control = true; + else if (check_port_control) + return -1; + + return 0; +} + +bool ieee80211_is_our_addr(struct ieee80211_sub_if_data *sdata, + const u8 *addr, int *out_link_id) +{ + unsigned int link_id; + + /* non-MLO, or MLD address replaced by hardware */ + if (ether_addr_equal(sdata->vif.addr, addr)) + return true; + + if (!sdata->vif.valid_links) + return false; + + for (link_id = 0; link_id < ARRAY_SIZE(sdata->vif.link_conf); link_id++) { + struct ieee80211_bss_conf *conf; + + conf = rcu_dereference(sdata->vif.link_conf[link_id]); + + if (!conf) + continue; + if (ether_addr_equal(conf->addr, addr)) { + if (out_link_id) + *out_link_id = link_id; + return true; + } + } + + return false; +} + +/* + * requires that rx->skb is a frame with ethernet header + */ +static bool ieee80211_frame_allowed(struct ieee80211_rx_data *rx, 
__le16 fc) +{ + static const u8 pae_group_addr[ETH_ALEN] __aligned(2) + = { 0x01, 0x80, 0xC2, 0x00, 0x00, 0x03 }; + struct ethhdr *ehdr = (struct ethhdr *) rx->skb->data; + + /* + * Allow EAPOL frames to us/the PAE group address regardless of + * whether the frame was encrypted or not, and always disallow + * all other destination addresses for them. + */ + if (unlikely(ehdr->h_proto == rx->sdata->control_port_protocol)) + return ieee80211_is_our_addr(rx->sdata, ehdr->h_dest, NULL) || + ether_addr_equal(ehdr->h_dest, pae_group_addr); + + if (ieee80211_802_1x_port_control(rx) || + ieee80211_drop_unencrypted(rx, fc)) + return false; + + return true; +} + +static void ieee80211_deliver_skb_to_local_stack(struct sk_buff *skb, + struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct net_device *dev = sdata->dev; + + if (unlikely((skb->protocol == sdata->control_port_protocol || + (skb->protocol == cpu_to_be16(ETH_P_PREAUTH) && + !sdata->control_port_no_preauth)) && + sdata->control_port_over_nl80211)) { + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + bool noencrypt = !(status->flag & RX_FLAG_DECRYPTED); + + cfg80211_rx_control_port(dev, skb, noencrypt); + dev_kfree_skb(skb); + } else { + struct ethhdr *ehdr = (void *)skb_mac_header(skb); + + memset(skb->cb, 0, sizeof(skb->cb)); + + /* + * 802.1X over 802.11 requires that the authenticator address + * be used for EAPOL frames. However, 802.1X allows the use of + * the PAE group address instead. If the interface is part of + * a bridge and we pass the frame with the PAE group address, + * then the bridge will forward it to the network (even if the + * client was not associated yet), which isn't supposed to + * happen. + * To avoid that, rewrite the destination address to our own + * address, so that the authenticator (e.g. hostapd) will see + * the frame, but bridge won't forward it anywhere else. Note + * that due to earlier filtering, the only other address can + * be the PAE group address, unless the hardware allowed them + * through in 802.3 offloaded mode. + */ + if (unlikely(skb->protocol == sdata->control_port_protocol && + !ether_addr_equal(ehdr->h_dest, sdata->vif.addr))) + ether_addr_copy(ehdr->h_dest, sdata->vif.addr); + + /* deliver to local stack */ + if (rx->list) + list_add_tail(&skb->list, rx->list); + else + netif_receive_skb(skb); + } +} + +/* + * requires that rx->skb is a frame with ethernet header + */ +static void +ieee80211_deliver_skb(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct net_device *dev = sdata->dev; + struct sk_buff *skb, *xmit_skb; + struct ethhdr *ehdr = (struct ethhdr *) rx->skb->data; + struct sta_info *dsta; + + skb = rx->skb; + xmit_skb = NULL; + + dev_sw_netstats_rx_add(dev, skb->len); + + if (rx->sta) { + /* The seqno index has the same property as needed + * for the rx_msdu field, i.e. it is IEEE80211_NUM_TIDS + * for non-QoS-data frames. Here we know it's a data + * frame, so count MSDUs. 
+ */ + u64_stats_update_begin(&rx->link_sta->rx_stats.syncp); + rx->link_sta->rx_stats.msdu[rx->seqno_idx]++; + u64_stats_update_end(&rx->link_sta->rx_stats.syncp); + } + + if ((sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN) && + !(sdata->flags & IEEE80211_SDATA_DONT_BRIDGE_PACKETS) && + ehdr->h_proto != rx->sdata->control_port_protocol && + (sdata->vif.type != NL80211_IFTYPE_AP_VLAN || !sdata->u.vlan.sta)) { + if (is_multicast_ether_addr(ehdr->h_dest) && + ieee80211_vif_get_num_mcast_if(sdata) != 0) { + /* + * send multicast frames both to higher layers in + * local net stack and back to the wireless medium + */ + xmit_skb = skb_copy(skb, GFP_ATOMIC); + if (!xmit_skb) + net_info_ratelimited("%s: failed to clone multicast frame\n", + dev->name); + } else if (!is_multicast_ether_addr(ehdr->h_dest) && + !ether_addr_equal(ehdr->h_dest, ehdr->h_source)) { + dsta = sta_info_get(sdata, ehdr->h_dest); + if (dsta) { + /* + * The destination station is associated to + * this AP (in this VLAN), so send the frame + * directly to it and do not pass it to local + * net stack. + */ + xmit_skb = skb; + skb = NULL; + } + } + } + +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + if (skb) { + /* 'align' will only take the values 0 or 2 here since all + * frames are required to be aligned to 2-byte boundaries + * when being passed to mac80211; the code here works just + * as well if that isn't true, but mac80211 assumes it can + * access fields as 2-byte aligned (e.g. for ether_addr_equal) + */ + int align; + + align = (unsigned long)(skb->data + sizeof(struct ethhdr)) & 3; + if (align) { + if (WARN_ON(skb_headroom(skb) < 3)) { + dev_kfree_skb(skb); + skb = NULL; + } else { + u8 *data = skb->data; + size_t len = skb_headlen(skb); + skb->data -= align; + memmove(skb->data, data, len); + skb_set_tail_pointer(skb, len); + } + } + } +#endif + + if (skb) { + skb->protocol = eth_type_trans(skb, dev); + ieee80211_deliver_skb_to_local_stack(skb, rx); + } + + if (xmit_skb) { + /* + * Send to wireless media and increase priority by 256 to + * keep the received priority instead of reclassifying + * the frame (see cfg80211_classify8021d). 
+ */ + xmit_skb->priority += 256; + xmit_skb->protocol = htons(ETH_P_802_3); + skb_reset_network_header(xmit_skb); + skb_reset_mac_header(xmit_skb); + dev_queue_xmit(xmit_skb); + } +} + +static ieee80211_rx_result debug_noinline +__ieee80211_rx_h_amsdu(struct ieee80211_rx_data *rx, u8 data_offset) +{ + struct net_device *dev = rx->sdata->dev; + struct sk_buff *skb = rx->skb; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + __le16 fc = hdr->frame_control; + struct sk_buff_head frame_list; + struct ethhdr ethhdr; + const u8 *check_da = ethhdr.h_dest, *check_sa = ethhdr.h_source; + + if (unlikely(ieee80211_has_a4(hdr->frame_control))) { + check_da = NULL; + check_sa = NULL; + } else switch (rx->sdata->vif.type) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + check_da = NULL; + break; + case NL80211_IFTYPE_STATION: + if (!rx->sta || + !test_sta_flag(rx->sta, WLAN_STA_TDLS_PEER)) + check_sa = NULL; + break; + case NL80211_IFTYPE_MESH_POINT: + check_sa = NULL; + break; + default: + break; + } + + skb->dev = dev; + __skb_queue_head_init(&frame_list); + + if (ieee80211_data_to_8023_exthdr(skb, &ethhdr, + rx->sdata->vif.addr, + rx->sdata->vif.type, + data_offset, true)) + return RX_DROP_UNUSABLE; + + ieee80211_amsdu_to_8023s(skb, &frame_list, dev->dev_addr, + rx->sdata->vif.type, + rx->local->hw.extra_tx_headroom, + check_da, check_sa); + + while (!skb_queue_empty(&frame_list)) { + rx->skb = __skb_dequeue(&frame_list); + + if (!ieee80211_frame_allowed(rx, fc)) { + dev_kfree_skb(rx->skb); + continue; + } + + ieee80211_deliver_skb(rx); + } + + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_amsdu(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + __le16 fc = hdr->frame_control; + + if (!(status->rx_flags & IEEE80211_RX_AMSDU)) + return RX_CONTINUE; + + if (unlikely(!ieee80211_is_data(fc))) + return RX_CONTINUE; + + if (unlikely(!ieee80211_is_data_present(fc))) + return RX_DROP_MONITOR; + + if (unlikely(ieee80211_has_a4(hdr->frame_control))) { + switch (rx->sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + if (!rx->sdata->u.vlan.sta) + return RX_DROP_UNUSABLE; + break; + case NL80211_IFTYPE_STATION: + if (!rx->sdata->u.mgd.use_4addr) + return RX_DROP_UNUSABLE; + break; + default: + return RX_DROP_UNUSABLE; + } + } + + if (is_multicast_ether_addr(hdr->addr1)) + return RX_DROP_UNUSABLE; + + if (rx->key) { + /* + * We should not receive A-MSDUs on pre-HT connections, + * and HT connections cannot use old ciphers. Thus drop + * them, as in those cases we couldn't even have SPP + * A-MSDUs or such. 
+ */ + switch (rx->key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + case WLAN_CIPHER_SUITE_TKIP: + return RX_DROP_UNUSABLE; + default: + break; + } + } + + return __ieee80211_rx_h_amsdu(rx, 0); +} + +#ifdef CONFIG_MAC80211_MESH +static ieee80211_rx_result +ieee80211_rx_h_mesh_fwding(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *fwd_hdr, *hdr; + struct ieee80211_tx_info *info; + struct ieee80211s_hdr *mesh_hdr; + struct sk_buff *skb = rx->skb, *fwd_skb; + struct ieee80211_local *local = rx->local; + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + u16 ac, q, hdrlen; + int tailroom = 0; + + hdr = (struct ieee80211_hdr *) skb->data; + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + /* make sure fixed part of mesh header is there, also checks skb len */ + if (!pskb_may_pull(rx->skb, hdrlen + 6)) + return RX_DROP_MONITOR; + + mesh_hdr = (struct ieee80211s_hdr *) (skb->data + hdrlen); + + /* make sure full mesh header is there, also checks skb len */ + if (!pskb_may_pull(rx->skb, + hdrlen + ieee80211_get_mesh_hdrlen(mesh_hdr))) + return RX_DROP_MONITOR; + + /* reload pointers */ + hdr = (struct ieee80211_hdr *) skb->data; + mesh_hdr = (struct ieee80211s_hdr *) (skb->data + hdrlen); + + if (ieee80211_drop_unencrypted(rx, hdr->frame_control)) + return RX_DROP_MONITOR; + + /* frame is in RMC, don't forward */ + if (ieee80211_is_data(hdr->frame_control) && + is_multicast_ether_addr(hdr->addr1) && + mesh_rmc_check(rx->sdata, hdr->addr3, mesh_hdr)) + return RX_DROP_MONITOR; + + if (!ieee80211_is_data(hdr->frame_control)) + return RX_CONTINUE; + + if (!mesh_hdr->ttl) + return RX_DROP_MONITOR; + + if (mesh_hdr->flags & MESH_FLAGS_AE) { + struct mesh_path *mppath; + char *proxied_addr; + char *mpp_addr; + + if (is_multicast_ether_addr(hdr->addr1)) { + mpp_addr = hdr->addr3; + proxied_addr = mesh_hdr->eaddr1; + } else if ((mesh_hdr->flags & MESH_FLAGS_AE) == + MESH_FLAGS_AE_A5_A6) { + /* has_a4 already checked in ieee80211_rx_mesh_check */ + mpp_addr = hdr->addr4; + proxied_addr = mesh_hdr->eaddr2; + } else { + return RX_DROP_MONITOR; + } + + rcu_read_lock(); + mppath = mpp_path_lookup(sdata, proxied_addr); + if (!mppath) { + mpp_path_add(sdata, proxied_addr, mpp_addr); + } else { + spin_lock_bh(&mppath->state_lock); + if (!ether_addr_equal(mppath->mpp, mpp_addr)) + memcpy(mppath->mpp, mpp_addr, ETH_ALEN); + mppath->exp_time = jiffies; + spin_unlock_bh(&mppath->state_lock); + } + rcu_read_unlock(); + } + + /* Frame has reached destination. 
Don't forward */ + if (!is_multicast_ether_addr(hdr->addr1) && + ether_addr_equal(sdata->vif.addr, hdr->addr3)) + return RX_CONTINUE; + + ac = ieee802_1d_to_ac[skb->priority]; + q = sdata->vif.hw_queue[ac]; + if (ieee80211_queue_stopped(&local->hw, q)) { + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, dropped_frames_congestion); + return RX_DROP_MONITOR; + } + skb_set_queue_mapping(skb, ac); + + if (!--mesh_hdr->ttl) { + if (!is_multicast_ether_addr(hdr->addr1)) + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, + dropped_frames_ttl); + goto out; + } + + if (!ifmsh->mshcfg.dot11MeshForwarding) + goto out; + + if (sdata->crypto_tx_tailroom_needed_cnt) + tailroom = IEEE80211_ENCRYPT_TAILROOM; + + fwd_skb = skb_copy_expand(skb, local->tx_headroom + + IEEE80211_ENCRYPT_HEADROOM, + tailroom, GFP_ATOMIC); + if (!fwd_skb) + goto out; + + fwd_skb->dev = sdata->dev; + fwd_hdr = (struct ieee80211_hdr *) fwd_skb->data; + fwd_hdr->frame_control &= ~cpu_to_le16(IEEE80211_FCTL_RETRY); + info = IEEE80211_SKB_CB(fwd_skb); + memset(info, 0, sizeof(*info)); + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + info->control.vif = &rx->sdata->vif; + info->control.jiffies = jiffies; + if (is_multicast_ether_addr(fwd_hdr->addr1)) { + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, fwded_mcast); + memcpy(fwd_hdr->addr2, sdata->vif.addr, ETH_ALEN); + /* update power mode indication when forwarding */ + ieee80211_mps_set_frame_flags(sdata, NULL, fwd_hdr); + } else if (!mesh_nexthop_lookup(sdata, fwd_skb)) { + /* mesh power mode flags updated in mesh_nexthop_lookup */ + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, fwded_unicast); + } else { + /* unable to resolve next hop */ + mesh_path_error_tx(sdata, ifmsh->mshcfg.element_ttl, + fwd_hdr->addr3, 0, + WLAN_REASON_MESH_PATH_NOFORWARD, + fwd_hdr->addr2); + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, dropped_frames_no_route); + kfree_skb(fwd_skb); + return RX_DROP_MONITOR; + } + + IEEE80211_IFSTA_MESH_CTR_INC(ifmsh, fwded_frames); + ieee80211_add_pending_skb(local, fwd_skb); + out: + if (is_multicast_ether_addr(hdr->addr1)) + return RX_CONTINUE; + return RX_DROP_MONITOR; +} +#endif + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_data(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_local *local = rx->local; + struct net_device *dev = sdata->dev; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + __le16 fc = hdr->frame_control; + bool port_control; + int err; + + if (unlikely(!ieee80211_is_data(hdr->frame_control))) + return RX_CONTINUE; + + if (unlikely(!ieee80211_is_data_present(hdr->frame_control))) + return RX_DROP_MONITOR; + + /* + * Send unexpected-4addr-frame event to hostapd. For older versions, + * also drop the frame to cooked monitor interfaces. 
+ */ + if (ieee80211_has_a4(hdr->frame_control) && + sdata->vif.type == NL80211_IFTYPE_AP) { + if (rx->sta && + !test_and_set_sta_flag(rx->sta, WLAN_STA_4ADDR_EVENT)) + cfg80211_rx_unexpected_4addr_frame( + rx->sdata->dev, rx->sta->sta.addr, GFP_ATOMIC); + return RX_DROP_MONITOR; + } + + err = __ieee80211_data_to_8023(rx, &port_control); + if (unlikely(err)) + return RX_DROP_UNUSABLE; + + if (!ieee80211_frame_allowed(rx, fc)) + return RX_DROP_MONITOR; + + /* directly handle TDLS channel switch requests/responses */ + if (unlikely(((struct ethhdr *)rx->skb->data)->h_proto == + cpu_to_be16(ETH_P_TDLS))) { + struct ieee80211_tdls_data *tf = (void *)rx->skb->data; + + if (pskb_may_pull(rx->skb, + offsetof(struct ieee80211_tdls_data, u)) && + tf->payload_type == WLAN_TDLS_SNAP_RFTYPE && + tf->category == WLAN_CATEGORY_TDLS && + (tf->action_code == WLAN_TDLS_CHANNEL_SWITCH_REQUEST || + tf->action_code == WLAN_TDLS_CHANNEL_SWITCH_RESPONSE)) { + rx->skb->protocol = cpu_to_be16(ETH_P_TDLS); + __ieee80211_queue_skb_to_iface(sdata, rx->link_id, + rx->sta, rx->skb); + return RX_QUEUED; + } + } + + if (rx->sdata->vif.type == NL80211_IFTYPE_AP_VLAN && + unlikely(port_control) && sdata->bss) { + sdata = container_of(sdata->bss, struct ieee80211_sub_if_data, + u.ap); + dev = sdata->dev; + rx->sdata = sdata; + } + + rx->skb->dev = dev; + + if (!ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS) && + local->ps_sdata && local->hw.conf.dynamic_ps_timeout > 0 && + !is_multicast_ether_addr( + ((struct ethhdr *)rx->skb->data)->h_dest) && + (!local->scanning && + !test_bit(SDATA_STATE_OFFCHANNEL, &sdata->state))) + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies(local->hw.conf.dynamic_ps_timeout)); + + ieee80211_deliver_skb(rx); + + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_ctrl(struct ieee80211_rx_data *rx, struct sk_buff_head *frames) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_bar *bar = (struct ieee80211_bar *)skb->data; + struct tid_ampdu_rx *tid_agg_rx; + u16 start_seq_num; + u16 tid; + + if (likely(!ieee80211_is_ctl(bar->frame_control))) + return RX_CONTINUE; + + if (ieee80211_is_back_req(bar->frame_control)) { + struct { + __le16 control, start_seq_num; + } __packed bar_data; + struct ieee80211_event event = { + .type = BAR_RX_EVENT, + }; + + if (!rx->sta) + return RX_DROP_MONITOR; + + if (skb_copy_bits(skb, offsetof(struct ieee80211_bar, control), + &bar_data, sizeof(bar_data))) + return RX_DROP_MONITOR; + + tid = le16_to_cpu(bar_data.control) >> 12; + + if (!test_bit(tid, rx->sta->ampdu_mlme.agg_session_valid) && + !test_and_set_bit(tid, rx->sta->ampdu_mlme.unexpected_agg)) + ieee80211_send_delba(rx->sdata, rx->sta->sta.addr, tid, + WLAN_BACK_RECIPIENT, + WLAN_REASON_QSTA_REQUIRE_SETUP); + + tid_agg_rx = rcu_dereference(rx->sta->ampdu_mlme.tid_rx[tid]); + if (!tid_agg_rx) + return RX_DROP_MONITOR; + + start_seq_num = le16_to_cpu(bar_data.start_seq_num) >> 4; + event.u.ba.tid = tid; + event.u.ba.ssn = start_seq_num; + event.u.ba.sta = &rx->sta->sta; + + /* reset session timer */ + if (tid_agg_rx->timeout) + mod_timer(&tid_agg_rx->session_timer, + TU_TO_EXP_TIME(tid_agg_rx->timeout)); + + spin_lock(&tid_agg_rx->reorder_lock); + /* release stored frames up to start of BAR */ + ieee80211_release_reorder_frames(rx->sdata, tid_agg_rx, + start_seq_num, frames); + spin_unlock(&tid_agg_rx->reorder_lock); + + drv_event_callback(rx->local, rx->sdata, &event); + + kfree_skb(skb); + return RX_QUEUED; + } + + /* + * After this point, we only want 
management frames, + * so we can drop all remaining control frames to + * cooked monitor interfaces. + */ + return RX_DROP_MONITOR; +} + +static void ieee80211_process_sa_query_req(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *resp; + + if (!ether_addr_equal(mgmt->da, sdata->vif.addr)) { + /* Not to own unicast address */ + return; + } + + if (!ether_addr_equal(mgmt->sa, sdata->deflink.u.mgd.bssid) || + !ether_addr_equal(mgmt->bssid, sdata->deflink.u.mgd.bssid)) { + /* Not from the current AP or not associated yet. */ + return; + } + + if (len < 24 + 1 + sizeof(resp->u.action.u.sa_query)) { + /* Too short SA Query request frame */ + return; + } + + skb = dev_alloc_skb(sizeof(*resp) + local->hw.extra_tx_headroom); + if (skb == NULL) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + resp = skb_put_zero(skb, 24); + memcpy(resp->da, mgmt->sa, ETH_ALEN); + memcpy(resp->sa, sdata->vif.addr, ETH_ALEN); + memcpy(resp->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + resp->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + skb_put(skb, 1 + sizeof(resp->u.action.u.sa_query)); + resp->u.action.category = WLAN_CATEGORY_SA_QUERY; + resp->u.action.u.sa_query.action = WLAN_ACTION_SA_QUERY_RESPONSE; + memcpy(resp->u.action.u.sa_query.trans_id, + mgmt->u.action.u.sa_query.trans_id, + WLAN_SA_QUERY_TR_ID_LEN); + + ieee80211_tx_skb(sdata, skb); +} + +static void +ieee80211_rx_check_bss_color_collision(struct ieee80211_rx_data *rx) +{ + struct ieee80211_mgmt *mgmt = (void *)rx->skb->data; + const struct element *ie; + size_t baselen; + + if (!wiphy_ext_feature_isset(rx->local->hw.wiphy, + NL80211_EXT_FEATURE_BSS_COLOR)) + return; + + if (ieee80211_hw_check(&rx->local->hw, DETECTS_COLOR_COLLISION)) + return; + + if (rx->sdata->vif.bss_conf.csa_active) + return; + + baselen = mgmt->u.beacon.variable - rx->skb->data; + if (baselen > rx->skb->len) + return; + + ie = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_OPERATION, + mgmt->u.beacon.variable, + rx->skb->len - baselen); + if (ie && ie->datalen >= sizeof(struct ieee80211_he_operation) && + ie->datalen >= ieee80211_he_oper_size(ie->data + 1)) { + struct ieee80211_bss_conf *bss_conf = &rx->sdata->vif.bss_conf; + const struct ieee80211_he_operation *he_oper; + u8 color; + + he_oper = (void *)(ie->data + 1); + if (le32_get_bits(he_oper->he_oper_params, + IEEE80211_HE_OPERATION_BSS_COLOR_DISABLED)) + return; + + color = le32_get_bits(he_oper->he_oper_params, + IEEE80211_HE_OPERATION_BSS_COLOR_MASK); + if (color == bss_conf->he_bss_color.color) + ieeee80211_obss_color_collision_notify(&rx->sdata->vif, + BIT_ULL(color), + GFP_ATOMIC); + } +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_mgmt_check(struct ieee80211_rx_data *rx) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *) rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) + return RX_CONTINUE; + + /* + * From here on, look only at management frames. + * Data and control frames are already handled, + * and unknown (reserved) frames are useless. 
+ */ + if (rx->skb->len < 24) + return RX_DROP_MONITOR; + + if (!ieee80211_is_mgmt(mgmt->frame_control)) + return RX_DROP_MONITOR; + + if (rx->sdata->vif.type == NL80211_IFTYPE_AP && + ieee80211_is_beacon(mgmt->frame_control) && + !(rx->flags & IEEE80211_RX_BEACON_REPORTED)) { + int sig = 0; + + /* sw bss color collision detection */ + ieee80211_rx_check_bss_color_collision(rx); + + if (ieee80211_hw_check(&rx->local->hw, SIGNAL_DBM) && + !(status->flag & RX_FLAG_NO_SIGNAL_VAL)) + sig = status->signal; + + cfg80211_report_obss_beacon_khz(rx->local->hw.wiphy, + rx->skb->data, rx->skb->len, + ieee80211_rx_status_to_khz(status), + sig); + rx->flags |= IEEE80211_RX_BEACON_REPORTED; + } + + if (ieee80211_drop_unencrypted_mgmt(rx)) + return RX_DROP_UNUSABLE; + + return RX_CONTINUE; +} + +static bool +ieee80211_process_rx_twt_action(struct ieee80211_rx_data *rx) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)rx->skb->data; + struct ieee80211_sub_if_data *sdata = rx->sdata; + + /* TWT actions are only supported in AP for the moment */ + if (sdata->vif.type != NL80211_IFTYPE_AP) + return false; + + if (!rx->local->ops->add_twt_setup) + return false; + + if (!sdata->vif.bss_conf.twt_responder) + return false; + + if (!rx->sta) + return false; + + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_SETUP: { + struct ieee80211_twt_setup *twt; + + if (rx->skb->len < IEEE80211_MIN_ACTION_SIZE + + 1 + /* action code */ + sizeof(struct ieee80211_twt_setup) + + 2 /* TWT req_type agrt */) + break; + + twt = (void *)mgmt->u.action.u.s1g.variable; + if (twt->element_id != WLAN_EID_S1G_TWT) + break; + + if (rx->skb->len < IEEE80211_MIN_ACTION_SIZE + + 4 + /* action code + token + tlv */ + twt->length) + break; + + return true; /* queue the frame */ + } + case WLAN_S1G_TWT_TEARDOWN: + if (rx->skb->len < IEEE80211_MIN_ACTION_SIZE + 2) + break; + + return true; /* queue the frame */ + default: + break; + } + + return false; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_action(struct ieee80211_rx_data *rx) +{ + struct ieee80211_local *local = rx->local; + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *) rx->skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + int len = rx->skb->len; + + if (!ieee80211_is_action(mgmt->frame_control)) + return RX_CONTINUE; + + /* drop too small frames */ + if (len < IEEE80211_MIN_ACTION_SIZE) + return RX_DROP_UNUSABLE; + + if (!rx->sta && mgmt->u.action.category != WLAN_CATEGORY_PUBLIC && + mgmt->u.action.category != WLAN_CATEGORY_SELF_PROTECTED && + mgmt->u.action.category != WLAN_CATEGORY_SPECTRUM_MGMT) + return RX_DROP_UNUSABLE; + + switch (mgmt->u.action.category) { + case WLAN_CATEGORY_HT: + /* reject HT action frames from stations not supporting HT */ + if (!rx->link_sta->pub->ht_cap.ht_supported) + goto invalid; + + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_ADHOC) + break; + + /* verify action & smps_control/chanwidth are present */ + if (len < IEEE80211_MIN_ACTION_SIZE + 2) + goto invalid; + + switch (mgmt->u.action.u.ht_smps.action) { + case WLAN_HT_ACTION_SMPS: { + struct ieee80211_supported_band *sband; + enum ieee80211_smps_mode smps_mode; + struct sta_opmode_info sta_opmode = {}; + + if (sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != 
NL80211_IFTYPE_AP_VLAN) + goto handled; + + /* convert to HT capability */ + switch (mgmt->u.action.u.ht_smps.smps_control) { + case WLAN_HT_SMPS_CONTROL_DISABLED: + smps_mode = IEEE80211_SMPS_OFF; + break; + case WLAN_HT_SMPS_CONTROL_STATIC: + smps_mode = IEEE80211_SMPS_STATIC; + break; + case WLAN_HT_SMPS_CONTROL_DYNAMIC: + smps_mode = IEEE80211_SMPS_DYNAMIC; + break; + default: + goto invalid; + } + + /* if no change do nothing */ + if (rx->link_sta->pub->smps_mode == smps_mode) + goto handled; + rx->link_sta->pub->smps_mode = smps_mode; + sta_opmode.smps_mode = + ieee80211_smps_mode_to_smps_mode(smps_mode); + sta_opmode.changed = STA_OPMODE_SMPS_MODE_CHANGED; + + sband = rx->local->hw.wiphy->bands[status->band]; + + rate_control_rate_update(local, sband, rx->sta, 0, + IEEE80211_RC_SMPS_CHANGED); + cfg80211_sta_opmode_change_notify(sdata->dev, + rx->sta->addr, + &sta_opmode, + GFP_ATOMIC); + goto handled; + } + case WLAN_HT_ACTION_NOTIFY_CHANWIDTH: { + struct ieee80211_supported_band *sband; + u8 chanwidth = mgmt->u.action.u.ht_notify_cw.chanwidth; + enum ieee80211_sta_rx_bandwidth max_bw, new_bw; + struct sta_opmode_info sta_opmode = {}; + + /* If it doesn't support 40 MHz it can't change ... */ + if (!(rx->link_sta->pub->ht_cap.cap & + IEEE80211_HT_CAP_SUP_WIDTH_20_40)) + goto handled; + + if (chanwidth == IEEE80211_HT_CHANWIDTH_20MHZ) + max_bw = IEEE80211_STA_RX_BW_20; + else + max_bw = ieee80211_sta_cap_rx_bw(rx->link_sta); + + /* set cur_max_bandwidth and recalc sta bw */ + rx->link_sta->cur_max_bandwidth = max_bw; + new_bw = ieee80211_sta_cur_vht_bw(rx->link_sta); + + if (rx->link_sta->pub->bandwidth == new_bw) + goto handled; + + rx->link_sta->pub->bandwidth = new_bw; + sband = rx->local->hw.wiphy->bands[status->band]; + sta_opmode.bw = + ieee80211_sta_rx_bw_to_chan_width(rx->link_sta); + sta_opmode.changed = STA_OPMODE_MAX_BW_CHANGED; + + rate_control_rate_update(local, sband, rx->sta, 0, + IEEE80211_RC_BW_CHANGED); + cfg80211_sta_opmode_change_notify(sdata->dev, + rx->sta->addr, + &sta_opmode, + GFP_ATOMIC); + goto handled; + } + default: + goto invalid; + } + + break; + case WLAN_CATEGORY_PUBLIC: + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + goto invalid; + if (sdata->vif.type != NL80211_IFTYPE_STATION) + break; + if (!rx->sta) + break; + if (!ether_addr_equal(mgmt->bssid, sdata->deflink.u.mgd.bssid)) + break; + if (mgmt->u.action.u.ext_chan_switch.action_code != + WLAN_PUB_ACTION_EXT_CHANSW_ANN) + break; + if (len < offsetof(struct ieee80211_mgmt, + u.action.u.ext_chan_switch.variable)) + goto invalid; + goto queue; + case WLAN_CATEGORY_VHT: + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_ADHOC) + break; + + /* verify action code is present */ + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + goto invalid; + + switch (mgmt->u.action.u.vht_opmode_notif.action_code) { + case WLAN_VHT_ACTION_OPMODE_NOTIF: { + /* verify opmode is present */ + if (len < IEEE80211_MIN_ACTION_SIZE + 2) + goto invalid; + goto queue; + } + case WLAN_VHT_ACTION_GROUPID_MGMT: { + if (len < IEEE80211_MIN_ACTION_SIZE + 25) + goto invalid; + goto queue; + } + default: + break; + } + break; + case WLAN_CATEGORY_BACK: + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT && + sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != 
NL80211_IFTYPE_ADHOC) + break; + + /* verify action_code is present */ + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + break; + + switch (mgmt->u.action.u.addba_req.action_code) { + case WLAN_ACTION_ADDBA_REQ: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.addba_req))) + goto invalid; + break; + case WLAN_ACTION_ADDBA_RESP: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.addba_resp))) + goto invalid; + break; + case WLAN_ACTION_DELBA: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.delba))) + goto invalid; + break; + default: + goto invalid; + } + + goto queue; + case WLAN_CATEGORY_SPECTRUM_MGMT: + /* verify action_code is present */ + if (len < IEEE80211_MIN_ACTION_SIZE + 1) + break; + + switch (mgmt->u.action.u.measurement.action_code) { + case WLAN_ACTION_SPCT_MSR_REQ: + if (status->band != NL80211_BAND_5GHZ) + break; + + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.measurement))) + break; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + break; + + ieee80211_process_measurement_req(sdata, mgmt, len); + goto handled; + case WLAN_ACTION_SPCT_CHL_SWITCH: { + u8 *bssid; + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.chan_switch))) + break; + + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT) + break; + + if (sdata->vif.type == NL80211_IFTYPE_STATION) + bssid = sdata->deflink.u.mgd.bssid; + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + bssid = sdata->u.ibss.bssid; + else if (sdata->vif.type == NL80211_IFTYPE_MESH_POINT) + bssid = mgmt->sa; + else + break; + + if (!ether_addr_equal(mgmt->bssid, bssid)) + break; + + goto queue; + } + } + break; + case WLAN_CATEGORY_SELF_PROTECTED: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.self_prot.action_code))) + break; + + switch (mgmt->u.action.u.self_prot.action_code) { + case WLAN_SP_MESH_PEERING_OPEN: + case WLAN_SP_MESH_PEERING_CLOSE: + case WLAN_SP_MESH_PEERING_CONFIRM: + if (!ieee80211_vif_is_mesh(&sdata->vif)) + goto invalid; + if (sdata->u.mesh.user_mpm) + /* userspace handles this frame */ + break; + goto queue; + case WLAN_SP_MGK_INFORM: + case WLAN_SP_MGK_ACK: + if (!ieee80211_vif_is_mesh(&sdata->vif)) + goto invalid; + break; + } + break; + case WLAN_CATEGORY_MESH_ACTION: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.mesh_action.action_code))) + break; + + if (!ieee80211_vif_is_mesh(&sdata->vif)) + break; + if (mesh_action_is_path_sel(mgmt) && + !mesh_path_sel_is_hwmp(sdata)) + break; + goto queue; + case WLAN_CATEGORY_S1G: + if (len < offsetofend(typeof(*mgmt), + u.action.u.s1g.action_code)) + break; + + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_SETUP: + case WLAN_S1G_TWT_TEARDOWN: + if (ieee80211_process_rx_twt_action(rx)) + goto queue; + break; + default: + break; + } + break; + } + + return RX_CONTINUE; + + invalid: + status->rx_flags |= IEEE80211_RX_MALFORMED_ACTION_FRM; + /* will return in the next handlers */ + return RX_CONTINUE; + + handled: + if (rx->sta) + rx->link_sta->rx_stats.packets++; + dev_kfree_skb(rx->skb); + return RX_QUEUED; + + queue: + ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb); + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_userspace_mgmt(struct ieee80211_rx_data *rx) +{ + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + struct cfg80211_rx_info info = { + .freq = 
ieee80211_rx_status_to_khz(status), + .buf = rx->skb->data, + .len = rx->skb->len, + .link_id = rx->link_id, + .have_link_id = rx->link_id >= 0, + }; + + /* skip known-bad action frames and return them in the next handler */ + if (status->rx_flags & IEEE80211_RX_MALFORMED_ACTION_FRM) + return RX_CONTINUE; + + /* + * Getting here means the kernel doesn't know how to handle + * it, but maybe userspace does ... include returned frames + * so userspace can register for those to know whether ones + * it transmitted were processed or returned. + */ + + if (ieee80211_hw_check(&rx->local->hw, SIGNAL_DBM) && + !(status->flag & RX_FLAG_NO_SIGNAL_VAL)) + info.sig_dbm = status->signal; + + if (ieee80211_is_timing_measurement(rx->skb) || + ieee80211_is_ftm(rx->skb)) { + info.rx_tstamp = ktime_to_ns(skb_hwtstamps(rx->skb)->hwtstamp); + info.ack_tstamp = ktime_to_ns(status->ack_tx_hwtstamp); + } + + if (cfg80211_rx_mgmt_ext(&rx->sdata->wdev, &info)) { + if (rx->sta) + rx->link_sta->rx_stats.packets++; + dev_kfree_skb(rx->skb); + return RX_QUEUED; + } + + return RX_CONTINUE; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_action_post_userspace(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *) rx->skb->data; + int len = rx->skb->len; + + if (!ieee80211_is_action(mgmt->frame_control)) + return RX_CONTINUE; + + switch (mgmt->u.action.category) { + case WLAN_CATEGORY_SA_QUERY: + if (len < (IEEE80211_MIN_ACTION_SIZE + + sizeof(mgmt->u.action.u.sa_query))) + break; + + switch (mgmt->u.action.u.sa_query.action) { + case WLAN_ACTION_SA_QUERY_REQUEST: + if (sdata->vif.type != NL80211_IFTYPE_STATION) + break; + ieee80211_process_sa_query_req(sdata, mgmt, len); + goto handled; + } + break; + } + + return RX_CONTINUE; + + handled: + if (rx->sta) + rx->link_sta->rx_stats.packets++; + dev_kfree_skb(rx->skb); + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_action_return(struct ieee80211_rx_data *rx) +{ + struct ieee80211_local *local = rx->local; + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *) rx->skb->data; + struct sk_buff *nskb; + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + + if (!ieee80211_is_action(mgmt->frame_control)) + return RX_CONTINUE; + + /* + * For AP mode, hostapd is responsible for handling any action + * frames that we didn't handle, including returning unknown + * ones. For all other modes we will return them to the sender, + * setting the 0x80 bit in the action category, as required by + * 802.11-2012 9.24.4. + * Newer versions of hostapd shall also use the management frame + * registration mechanisms, but older ones still use cooked + * monitor interfaces so push all frames there. 
+ */ + if (!(status->rx_flags & IEEE80211_RX_MALFORMED_ACTION_FRM) && + (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN)) + return RX_DROP_MONITOR; + + if (is_multicast_ether_addr(mgmt->da)) + return RX_DROP_MONITOR; + + /* do not return rejected action frames */ + if (mgmt->u.action.category & 0x80) + return RX_DROP_UNUSABLE; + + nskb = skb_copy_expand(rx->skb, local->hw.extra_tx_headroom, 0, + GFP_ATOMIC); + if (nskb) { + struct ieee80211_mgmt *nmgmt = (void *)nskb->data; + + nmgmt->u.action.category |= 0x80; + memcpy(nmgmt->da, nmgmt->sa, ETH_ALEN); + memcpy(nmgmt->sa, rx->sdata->vif.addr, ETH_ALEN); + + memset(nskb->cb, 0, sizeof(nskb->cb)); + + if (rx->sdata->vif.type == NL80211_IFTYPE_P2P_DEVICE) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(nskb); + + info->flags = IEEE80211_TX_CTL_TX_OFFCHAN | + IEEE80211_TX_INTFL_OFFCHAN_TX_OK | + IEEE80211_TX_CTL_NO_CCK_RATE; + if (ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) + info->hw_queue = + local->hw.offchannel_tx_hw_queue; + } + + __ieee80211_tx_skb_tid_band(rx->sdata, nskb, 7, -1, + status->band); + } + dev_kfree_skb(rx->skb); + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_ext(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_hdr *hdr = (void *)rx->skb->data; + + if (!ieee80211_is_ext(hdr->frame_control)) + return RX_CONTINUE; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return RX_DROP_MONITOR; + + /* for now only beacons are ext, so queue them */ + ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb); + + return RX_QUEUED; +} + +static ieee80211_rx_result debug_noinline +ieee80211_rx_h_mgmt(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_mgmt *mgmt = (void *)rx->skb->data; + __le16 stype; + + stype = mgmt->frame_control & cpu_to_le16(IEEE80211_FCTL_STYPE); + + if (!ieee80211_vif_is_mesh(&sdata->vif) && + sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_OCB && + sdata->vif.type != NL80211_IFTYPE_STATION) + return RX_DROP_MONITOR; + + switch (stype) { + case cpu_to_le16(IEEE80211_STYPE_AUTH): + case cpu_to_le16(IEEE80211_STYPE_BEACON): + case cpu_to_le16(IEEE80211_STYPE_PROBE_RESP): + /* process for all: mesh, mlme, ibss */ + break; + case cpu_to_le16(IEEE80211_STYPE_DEAUTH): + if (is_multicast_ether_addr(mgmt->da) && + !is_broadcast_ether_addr(mgmt->da)) + return RX_DROP_MONITOR; + + /* process only for station/IBSS */ + if (sdata->vif.type != NL80211_IFTYPE_STATION && + sdata->vif.type != NL80211_IFTYPE_ADHOC) + return RX_DROP_MONITOR; + break; + case cpu_to_le16(IEEE80211_STYPE_ASSOC_RESP): + case cpu_to_le16(IEEE80211_STYPE_REASSOC_RESP): + case cpu_to_le16(IEEE80211_STYPE_DISASSOC): + if (is_multicast_ether_addr(mgmt->da) && + !is_broadcast_ether_addr(mgmt->da)) + return RX_DROP_MONITOR; + + /* process only for station */ + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return RX_DROP_MONITOR; + break; + case cpu_to_le16(IEEE80211_STYPE_PROBE_REQ): + /* process only for ibss and mesh */ + if (sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT) + return RX_DROP_MONITOR; + break; + default: + return RX_DROP_MONITOR; + } + + ieee80211_queue_skb_to_iface(sdata, rx->link_id, rx->sta, rx->skb); + + return RX_QUEUED; +} + +static void ieee80211_rx_cooked_monitor(struct ieee80211_rx_data *rx, + struct ieee80211_rate *rate) +{ + struct ieee80211_sub_if_data 
*sdata; + struct ieee80211_local *local = rx->local; + struct sk_buff *skb = rx->skb, *skb2; + struct net_device *prev_dev = NULL; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + int needed_headroom; + + /* + * If cooked monitor has been processed already, then + * don't do it again. If not, set the flag. + */ + if (rx->flags & IEEE80211_RX_CMNTR) + goto out_free_skb; + rx->flags |= IEEE80211_RX_CMNTR; + + /* If there are no cooked monitor interfaces, just free the SKB */ + if (!local->cooked_mntrs) + goto out_free_skb; + + /* vendor data is long removed here */ + status->flag &= ~RX_FLAG_RADIOTAP_VENDOR_DATA; + /* room for the radiotap header based on driver features */ + needed_headroom = ieee80211_rx_radiotap_hdrlen(local, status, skb); + + if (skb_headroom(skb) < needed_headroom && + pskb_expand_head(skb, needed_headroom, 0, GFP_ATOMIC)) + goto out_free_skb; + + /* prepend radiotap information */ + ieee80211_add_rx_radiotap_header(local, skb, rate, needed_headroom, + false); + + skb_reset_mac_header(skb); + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->pkt_type = PACKET_OTHERHOST; + skb->protocol = htons(ETH_P_802_2); + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + if (sdata->vif.type != NL80211_IFTYPE_MONITOR || + !(sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES)) + continue; + + if (prev_dev) { + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) { + skb2->dev = prev_dev; + netif_receive_skb(skb2); + } + } + + prev_dev = sdata->dev; + dev_sw_netstats_rx_add(sdata->dev, skb->len); + } + + if (prev_dev) { + skb->dev = prev_dev; + netif_receive_skb(skb); + return; + } + + out_free_skb: + dev_kfree_skb(skb); +} + +static void ieee80211_rx_handlers_result(struct ieee80211_rx_data *rx, + ieee80211_rx_result res) +{ + switch (res) { + case RX_DROP_MONITOR: + I802_DEBUG_INC(rx->sdata->local->rx_handlers_drop); + if (rx->sta) + rx->link_sta->rx_stats.dropped++; + fallthrough; + case RX_CONTINUE: { + struct ieee80211_rate *rate = NULL; + struct ieee80211_supported_band *sband; + struct ieee80211_rx_status *status; + + status = IEEE80211_SKB_RXCB((rx->skb)); + + sband = rx->local->hw.wiphy->bands[status->band]; + if (status->encoding == RX_ENC_LEGACY) + rate = &sband->bitrates[status->rate_idx]; + + ieee80211_rx_cooked_monitor(rx, rate); + break; + } + case RX_DROP_UNUSABLE: + I802_DEBUG_INC(rx->sdata->local->rx_handlers_drop); + if (rx->sta) + rx->link_sta->rx_stats.dropped++; + dev_kfree_skb(rx->skb); + break; + case RX_QUEUED: + I802_DEBUG_INC(rx->sdata->local->rx_handlers_queued); + break; + } +} + +static void ieee80211_rx_handlers(struct ieee80211_rx_data *rx, + struct sk_buff_head *frames) +{ + ieee80211_rx_result res = RX_DROP_MONITOR; + struct sk_buff *skb; + +#define CALL_RXH(rxh) \ + do { \ + res = rxh(rx); \ + if (res != RX_CONTINUE) \ + goto rxh_next; \ + } while (0) + + /* Lock here to avoid hitting all of the data used in the RX + * path (e.g. key data, station data, ...) concurrently when + * a frame is released from the reorder buffer due to timeout + * from the timer, potentially concurrently with RX from the + * driver. 
+ */ + spin_lock_bh(&rx->local->rx_path_lock); + + while ((skb = __skb_dequeue(frames))) { + /* + * all the other fields are valid across frames + * that belong to an aMPDU since they are on the + * same TID from the same station + */ + rx->skb = skb; + + if (WARN_ON_ONCE(!rx->link)) + goto rxh_next; + + CALL_RXH(ieee80211_rx_h_check_more_data); + CALL_RXH(ieee80211_rx_h_uapsd_and_pspoll); + CALL_RXH(ieee80211_rx_h_sta_process); + CALL_RXH(ieee80211_rx_h_decrypt); + CALL_RXH(ieee80211_rx_h_defragment); + CALL_RXH(ieee80211_rx_h_michael_mic_verify); + /* must be after MMIC verify so header is counted in MPDU mic */ +#ifdef CONFIG_MAC80211_MESH + if (ieee80211_vif_is_mesh(&rx->sdata->vif)) + CALL_RXH(ieee80211_rx_h_mesh_fwding); +#endif + CALL_RXH(ieee80211_rx_h_amsdu); + CALL_RXH(ieee80211_rx_h_data); + + /* special treatment -- needs the queue */ + res = ieee80211_rx_h_ctrl(rx, frames); + if (res != RX_CONTINUE) + goto rxh_next; + + CALL_RXH(ieee80211_rx_h_mgmt_check); + CALL_RXH(ieee80211_rx_h_action); + CALL_RXH(ieee80211_rx_h_userspace_mgmt); + CALL_RXH(ieee80211_rx_h_action_post_userspace); + CALL_RXH(ieee80211_rx_h_action_return); + CALL_RXH(ieee80211_rx_h_ext); + CALL_RXH(ieee80211_rx_h_mgmt); + + rxh_next: + ieee80211_rx_handlers_result(rx, res); + +#undef CALL_RXH + } + + spin_unlock_bh(&rx->local->rx_path_lock); +} + +static void ieee80211_invoke_rx_handlers(struct ieee80211_rx_data *rx) +{ + struct sk_buff_head reorder_release; + ieee80211_rx_result res = RX_DROP_MONITOR; + + __skb_queue_head_init(&reorder_release); + +#define CALL_RXH(rxh) \ + do { \ + res = rxh(rx); \ + if (res != RX_CONTINUE) \ + goto rxh_next; \ + } while (0) + + CALL_RXH(ieee80211_rx_h_check_dup); + CALL_RXH(ieee80211_rx_h_check); + + ieee80211_rx_reorder_ampdu(rx, &reorder_release); + + ieee80211_rx_handlers(rx, &reorder_release); + return; + + rxh_next: + ieee80211_rx_handlers_result(rx, res); + +#undef CALL_RXH +} + +static bool +ieee80211_rx_is_valid_sta_link_id(struct ieee80211_sta *sta, u8 link_id) +{ + return !!(sta->valid_links & BIT(link_id)); +} + +static bool ieee80211_rx_data_set_link(struct ieee80211_rx_data *rx, + u8 link_id) +{ + rx->link_id = link_id; + rx->link = rcu_dereference(rx->sdata->link[link_id]); + + if (!rx->sta) + return rx->link; + + if (!ieee80211_rx_is_valid_sta_link_id(&rx->sta->sta, link_id)) + return false; + + rx->link_sta = rcu_dereference(rx->sta->link[link_id]); + + return rx->link && rx->link_sta; +} + +static bool ieee80211_rx_data_set_sta(struct ieee80211_rx_data *rx, + struct sta_info *sta, int link_id) +{ + rx->link_id = link_id; + rx->sta = sta; + + if (sta) { + rx->local = sta->sdata->local; + if (!rx->sdata) + rx->sdata = sta->sdata; + rx->link_sta = &sta->deflink; + } + + if (link_id < 0) + rx->link = &rx->sdata->deflink; + else if (!ieee80211_rx_data_set_link(rx, link_id)) + return false; + + return true; +} + +/* + * This function makes calls into the RX path, therefore + * it has to be invoked under RCU read lock. 
+ */ +void ieee80211_release_reorder_timeout(struct sta_info *sta, int tid) +{ + struct sk_buff_head frames; + struct ieee80211_rx_data rx = { + /* This is OK -- must be QoS data frame */ + .security_idx = tid, + .seqno_idx = tid, + }; + struct tid_ampdu_rx *tid_agg_rx; + int link_id = -1; + + /* FIXME: statistics won't be right with this */ + if (sta->sta.valid_links) + link_id = ffs(sta->sta.valid_links) - 1; + + if (!ieee80211_rx_data_set_sta(&rx, sta, link_id)) + return; + + tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); + if (!tid_agg_rx) + return; + + __skb_queue_head_init(&frames); + + spin_lock(&tid_agg_rx->reorder_lock); + ieee80211_sta_reorder_release(sta->sdata, tid_agg_rx, &frames); + spin_unlock(&tid_agg_rx->reorder_lock); + + if (!skb_queue_empty(&frames)) { + struct ieee80211_event event = { + .type = BA_FRAME_TIMEOUT, + .u.ba.tid = tid, + .u.ba.sta = &sta->sta, + }; + drv_event_callback(rx.local, rx.sdata, &event); + } + + ieee80211_rx_handlers(&rx, &frames); +} + +void ieee80211_mark_rx_ba_filtered_frames(struct ieee80211_sta *pubsta, u8 tid, + u16 ssn, u64 filtered, + u16 received_mpdus) +{ + struct ieee80211_local *local; + struct sta_info *sta; + struct tid_ampdu_rx *tid_agg_rx; + struct sk_buff_head frames; + struct ieee80211_rx_data rx = { + /* This is OK -- must be QoS data frame */ + .security_idx = tid, + .seqno_idx = tid, + }; + int i, diff; + + if (WARN_ON(!pubsta || tid >= IEEE80211_NUM_TIDS)) + return; + + __skb_queue_head_init(&frames); + + sta = container_of(pubsta, struct sta_info, sta); + + local = sta->sdata->local; + WARN_ONCE(local->hw.max_rx_aggregation_subframes > 64, + "RX BA marker can't support max_rx_aggregation_subframes %u > 64\n", + local->hw.max_rx_aggregation_subframes); + + if (!ieee80211_rx_data_set_sta(&rx, sta, -1)) + return; + + rcu_read_lock(); + tid_agg_rx = rcu_dereference(sta->ampdu_mlme.tid_rx[tid]); + if (!tid_agg_rx) + goto out; + + spin_lock_bh(&tid_agg_rx->reorder_lock); + + if (received_mpdus >= IEEE80211_SN_MODULO >> 1) { + int release; + + /* release all frames in the reorder buffer */ + release = (tid_agg_rx->head_seq_num + tid_agg_rx->buf_size) % + IEEE80211_SN_MODULO; + ieee80211_release_reorder_frames(sta->sdata, tid_agg_rx, + release, &frames); + /* update ssn to match received ssn */ + tid_agg_rx->head_seq_num = ssn; + } else { + ieee80211_release_reorder_frames(sta->sdata, tid_agg_rx, ssn, + &frames); + } + + /* handle the case that received ssn is behind the mac ssn. 
+ * it can be tid_agg_rx->buf_size behind and still be valid */ + diff = (tid_agg_rx->head_seq_num - ssn) & IEEE80211_SN_MASK; + if (diff >= tid_agg_rx->buf_size) { + tid_agg_rx->reorder_buf_filtered = 0; + goto release; + } + filtered = filtered >> diff; + ssn += diff; + + /* update bitmap */ + for (i = 0; i < tid_agg_rx->buf_size; i++) { + int index = (ssn + i) % tid_agg_rx->buf_size; + + tid_agg_rx->reorder_buf_filtered &= ~BIT_ULL(index); + if (filtered & BIT_ULL(i)) + tid_agg_rx->reorder_buf_filtered |= BIT_ULL(index); + } + + /* now process also frames that the filter marking released */ + ieee80211_sta_reorder_release(sta->sdata, tid_agg_rx, &frames); + +release: + spin_unlock_bh(&tid_agg_rx->reorder_lock); + + ieee80211_rx_handlers(&rx, &frames); + + out: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_mark_rx_ba_filtered_frames); + +/* main receive path */ + +static inline int ieee80211_bssid_match(const u8 *raddr, const u8 *addr) +{ + return ether_addr_equal(raddr, addr) || + is_broadcast_ether_addr(raddr); +} + +static bool ieee80211_accept_frame(struct ieee80211_rx_data *rx) +{ + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct sk_buff *skb = rx->skb; + struct ieee80211_hdr *hdr = (void *)skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + u8 *bssid = ieee80211_get_bssid(hdr, skb->len, sdata->vif.type); + bool multicast = is_multicast_ether_addr(hdr->addr1) || + ieee80211_is_s1g_beacon(hdr->frame_control); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + if (!bssid && !sdata->u.mgd.use_4addr) + return false; + if (ieee80211_is_robust_mgmt_frame(skb) && !rx->sta) + return false; + if (multicast) + return true; + return ieee80211_is_our_addr(sdata, hdr->addr1, &rx->link_id); + case NL80211_IFTYPE_ADHOC: + if (!bssid) + return false; + if (ether_addr_equal(sdata->vif.addr, hdr->addr2) || + ether_addr_equal(sdata->u.ibss.bssid, hdr->addr2) || + !is_valid_ether_addr(hdr->addr2)) + return false; + if (ieee80211_is_beacon(hdr->frame_control)) + return true; + if (!ieee80211_bssid_match(bssid, sdata->u.ibss.bssid)) + return false; + if (!multicast && + !ether_addr_equal(sdata->vif.addr, hdr->addr1)) + return false; + if (!rx->sta) { + int rate_idx; + if (status->encoding != RX_ENC_LEGACY) + rate_idx = 0; /* TODO: HT/VHT rates */ + else + rate_idx = status->rate_idx; + ieee80211_ibss_rx_no_sta(sdata, bssid, hdr->addr2, + BIT(rate_idx)); + } + return true; + case NL80211_IFTYPE_OCB: + if (!bssid) + return false; + if (!ieee80211_is_data_present(hdr->frame_control)) + return false; + if (!is_broadcast_ether_addr(bssid)) + return false; + if (!multicast && + !ether_addr_equal(sdata->dev->dev_addr, hdr->addr1)) + return false; + if (!rx->sta) { + int rate_idx; + if (status->encoding != RX_ENC_LEGACY) + rate_idx = 0; /* TODO: HT rates */ + else + rate_idx = status->rate_idx; + ieee80211_ocb_rx_no_sta(sdata, bssid, hdr->addr2, + BIT(rate_idx)); + } + return true; + case NL80211_IFTYPE_MESH_POINT: + if (ether_addr_equal(sdata->vif.addr, hdr->addr2)) + return false; + if (multicast) + return true; + return ether_addr_equal(sdata->vif.addr, hdr->addr1); + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_AP: + if (!bssid) + return ieee80211_is_our_addr(sdata, hdr->addr1, + &rx->link_id); + + if (!is_broadcast_ether_addr(bssid) && + !ieee80211_is_our_addr(sdata, bssid, NULL)) { + /* + * Accept public action frames even when the + * BSSID doesn't match, this is used for P2P + * and location updates. 
Note that mac80211 + * itself never looks at these frames. + */ + if (!multicast && + !ieee80211_is_our_addr(sdata, hdr->addr1, + &rx->link_id)) + return false; + if (ieee80211_is_public_action(hdr, skb->len)) + return true; + return ieee80211_is_beacon(hdr->frame_control); + } + + if (!ieee80211_has_tods(hdr->frame_control)) { + /* ignore data frames to TDLS-peers */ + if (ieee80211_is_data(hdr->frame_control)) + return false; + /* ignore action frames to TDLS-peers */ + if (ieee80211_is_action(hdr->frame_control) && + !is_broadcast_ether_addr(bssid) && + !ether_addr_equal(bssid, hdr->addr1)) + return false; + } + + /* + * 802.11-2016 Table 9-26 says that for data frames, A1 must be + * the BSSID - we've checked that already but may have accepted + * the wildcard (ff:ff:ff:ff:ff:ff). + * + * It also says: + * The BSSID of the Data frame is determined as follows: + * a) If the STA is contained within an AP or is associated + * with an AP, the BSSID is the address currently in use + * by the STA contained in the AP. + * + * So we should not accept data frames with an address that's + * multicast. + * + * Accepting it also opens a security problem because stations + * could encrypt it with the GTK and inject traffic that way. + */ + if (ieee80211_is_data(hdr->frame_control) && multicast) + return false; + + return true; + case NL80211_IFTYPE_P2P_DEVICE: + return ieee80211_is_public_action(hdr, skb->len) || + ieee80211_is_probe_req(hdr->frame_control) || + ieee80211_is_probe_resp(hdr->frame_control) || + ieee80211_is_beacon(hdr->frame_control); + case NL80211_IFTYPE_NAN: + /* Currently no frames on NAN interface are allowed */ + return false; + default: + break; + } + + WARN_ON_ONCE(1); + return false; +} + +void ieee80211_check_fast_rx(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_key *key; + struct ieee80211_fast_rx fastrx = { + .dev = sdata->dev, + .vif_type = sdata->vif.type, + .control_port_protocol = sdata->control_port_protocol, + }, *old, *new = NULL; + u32 offload_flags; + bool set_offload = false; + bool assign = false; + bool offload; + + /* use sparse to check that we don't return without updating */ + __acquire(check_fast_rx); + + BUILD_BUG_ON(sizeof(fastrx.rfc1042_hdr) != sizeof(rfc1042_header)); + BUILD_BUG_ON(sizeof(fastrx.rfc1042_hdr) != ETH_ALEN); + ether_addr_copy(fastrx.rfc1042_hdr, rfc1042_header); + ether_addr_copy(fastrx.vif_addr, sdata->vif.addr); + + fastrx.uses_rss = ieee80211_hw_check(&local->hw, USES_RSS); + + /* fast-rx doesn't do reordering */ + if (ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION) && + !ieee80211_hw_check(&local->hw, SUPPORTS_REORDERING_BUFFER)) + goto clear; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + if (sta->sta.tdls) { + fastrx.da_offs = offsetof(struct ieee80211_hdr, addr1); + fastrx.sa_offs = offsetof(struct ieee80211_hdr, addr2); + fastrx.expected_ds_bits = 0; + } else { + fastrx.da_offs = offsetof(struct ieee80211_hdr, addr1); + fastrx.sa_offs = offsetof(struct ieee80211_hdr, addr3); + fastrx.expected_ds_bits = + cpu_to_le16(IEEE80211_FCTL_FROMDS); + } + + if (sdata->u.mgd.use_4addr && !sta->sta.tdls) { + fastrx.expected_ds_bits |= + cpu_to_le16(IEEE80211_FCTL_TODS); + fastrx.da_offs = offsetof(struct ieee80211_hdr, addr3); + fastrx.sa_offs = offsetof(struct ieee80211_hdr, addr4); + } + + if (!sdata->u.mgd.powersave) + break; + + /* software powersave is a huge mess, avoid all of it */ + if (ieee80211_hw_check(&local->hw, 
PS_NULLFUNC_STACK)) + goto clear; + if (ieee80211_hw_check(&local->hw, SUPPORTS_PS) && + !ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS)) + goto clear; + break; + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_AP: + /* parallel-rx requires this, at least with calls to + * ieee80211_sta_ps_transition() + */ + if (!ieee80211_hw_check(&local->hw, AP_LINK_PS)) + goto clear; + fastrx.da_offs = offsetof(struct ieee80211_hdr, addr3); + fastrx.sa_offs = offsetof(struct ieee80211_hdr, addr2); + fastrx.expected_ds_bits = cpu_to_le16(IEEE80211_FCTL_TODS); + + fastrx.internal_forward = + !(sdata->flags & IEEE80211_SDATA_DONT_BRIDGE_PACKETS) && + (sdata->vif.type != NL80211_IFTYPE_AP_VLAN || + !sdata->u.vlan.sta); + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN && + sdata->u.vlan.sta) { + fastrx.expected_ds_bits |= + cpu_to_le16(IEEE80211_FCTL_FROMDS); + fastrx.sa_offs = offsetof(struct ieee80211_hdr, addr4); + fastrx.internal_forward = 0; + } + + break; + default: + goto clear; + } + + if (!test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + goto clear; + + rcu_read_lock(); + key = rcu_dereference(sta->ptk[sta->ptk_idx]); + if (!key) + key = rcu_dereference(sdata->default_unicast_key); + if (key) { + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_TKIP: + /* we don't want to deal with MMIC in fast-rx */ + goto clear_rcu; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + break; + default: + /* We also don't want to deal with + * WEP or cipher scheme. + */ + goto clear_rcu; + } + + fastrx.key = true; + fastrx.icv_len = key->conf.icv_len; + } + + assign = true; + clear_rcu: + rcu_read_unlock(); + clear: + __release(check_fast_rx); + + if (assign) + new = kmemdup(&fastrx, sizeof(fastrx), GFP_KERNEL); + + offload_flags = get_bss_sdata(sdata)->vif.offload_flags; + offload = offload_flags & IEEE80211_OFFLOAD_DECAP_ENABLED; + + if (assign && offload) + set_offload = !test_and_set_sta_flag(sta, WLAN_STA_DECAP_OFFLOAD); + else + set_offload = test_and_clear_sta_flag(sta, WLAN_STA_DECAP_OFFLOAD); + + if (set_offload) + drv_sta_set_decap_offload(local, sdata, &sta->sta, assign); + + spin_lock_bh(&sta->lock); + old = rcu_dereference_protected(sta->fast_rx, true); + rcu_assign_pointer(sta->fast_rx, new); + spin_unlock_bh(&sta->lock); + + if (old) + kfree_rcu(old, rcu_head); +} + +void ieee80211_clear_fast_rx(struct sta_info *sta) +{ + struct ieee80211_fast_rx *old; + + spin_lock_bh(&sta->lock); + old = rcu_dereference_protected(sta->fast_rx, true); + RCU_INIT_POINTER(sta->fast_rx, NULL); + spin_unlock_bh(&sta->lock); + + if (old) + kfree_rcu(old, rcu_head); +} + +void __ieee80211_check_fast_rx_iface(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + lockdep_assert_held(&local->sta_mtx); + + list_for_each_entry(sta, &local->sta_list, list) { + if (sdata != sta->sdata && + (!sta->sdata->bss || sta->sdata->bss != sdata->bss)) + continue; + ieee80211_check_fast_rx(sta); + } +} + +void ieee80211_check_fast_rx_iface(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + + mutex_lock(&local->sta_mtx); + __ieee80211_check_fast_rx_iface(sdata); + mutex_unlock(&local->sta_mtx); +} + +static void ieee80211_rx_8023(struct ieee80211_rx_data *rx, + struct ieee80211_fast_rx *fast_rx, + int orig_len) +{ + struct ieee80211_sta_rx_stats *stats; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(rx->skb); + struct sta_info *sta = rx->sta; 
+ struct link_sta_info *link_sta; + struct sk_buff *skb = rx->skb; + void *sa = skb->data + ETH_ALEN; + void *da = skb->data; + + if (rx->link_id >= 0) { + link_sta = rcu_dereference(sta->link[rx->link_id]); + if (WARN_ON_ONCE(!link_sta)) { + dev_kfree_skb(rx->skb); + return; + } + } else { + link_sta = &sta->deflink; + } + + stats = &link_sta->rx_stats; + if (fast_rx->uses_rss) + stats = this_cpu_ptr(link_sta->pcpu_rx_stats); + + /* statistics part of ieee80211_rx_h_sta_process() */ + if (!(status->flag & RX_FLAG_NO_SIGNAL_VAL)) { + stats->last_signal = status->signal; + if (!fast_rx->uses_rss) + ewma_signal_add(&link_sta->rx_stats_avg.signal, + -status->signal); + } + + if (status->chains) { + int i; + + stats->chains = status->chains; + for (i = 0; i < ARRAY_SIZE(status->chain_signal); i++) { + int signal = status->chain_signal[i]; + + if (!(status->chains & BIT(i))) + continue; + + stats->chain_signal_last[i] = signal; + if (!fast_rx->uses_rss) + ewma_signal_add(&link_sta->rx_stats_avg.chain_signal[i], + -signal); + } + } + /* end of statistics */ + + stats->last_rx = jiffies; + stats->last_rate = sta_stats_encode_rate(status); + + stats->fragments++; + stats->packets++; + + skb->dev = fast_rx->dev; + + dev_sw_netstats_rx_add(fast_rx->dev, skb->len); + + /* The seqno index has the same property as needed + * for the rx_msdu field, i.e. it is IEEE80211_NUM_TIDS + * for non-QoS-data frames. Here we know it's a data + * frame, so count MSDUs. + */ + u64_stats_update_begin(&stats->syncp); + stats->msdu[rx->seqno_idx]++; + stats->bytes += orig_len; + u64_stats_update_end(&stats->syncp); + + if (fast_rx->internal_forward) { + struct sk_buff *xmit_skb = NULL; + if (is_multicast_ether_addr(da)) { + xmit_skb = skb_copy(skb, GFP_ATOMIC); + } else if (!ether_addr_equal(da, sa) && + sta_info_get(rx->sdata, da)) { + xmit_skb = skb; + skb = NULL; + } + + if (xmit_skb) { + /* + * Send to wireless media and increase priority by 256 + * to keep the received priority instead of + * reclassifying the frame (see cfg80211_classify8021d). 
+ */ + xmit_skb->priority += 256; + xmit_skb->protocol = htons(ETH_P_802_3); + skb_reset_network_header(xmit_skb); + skb_reset_mac_header(xmit_skb); + dev_queue_xmit(xmit_skb); + } + + if (!skb) + return; + } + + /* deliver to local stack */ + skb->protocol = eth_type_trans(skb, fast_rx->dev); + ieee80211_deliver_skb_to_local_stack(skb, rx); +} + +static bool ieee80211_invoke_fast_rx(struct ieee80211_rx_data *rx, + struct ieee80211_fast_rx *fast_rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_hdr *hdr = (void *)skb->data; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + int orig_len = skb->len; + int hdrlen = ieee80211_hdrlen(hdr->frame_control); + int snap_offs = hdrlen; + struct { + u8 snap[sizeof(rfc1042_header)]; + __be16 proto; + } *payload __aligned(2); + struct { + u8 da[ETH_ALEN]; + u8 sa[ETH_ALEN]; + } addrs __aligned(2); + struct ieee80211_sta_rx_stats *stats; + + /* for parallel-rx, we need to have DUP_VALIDATED, otherwise we write + * to a common data structure; drivers can implement that per queue + * but we don't have that information in mac80211 + */ + if (!(status->flag & RX_FLAG_DUP_VALIDATED)) + return false; + +#define FAST_RX_CRYPT_FLAGS (RX_FLAG_PN_VALIDATED | RX_FLAG_DECRYPTED) + + /* If using encryption, we also need to have: + * - PN_VALIDATED: similar, but the implementation is tricky + * - DECRYPTED: necessary for PN_VALIDATED + */ + if (fast_rx->key && + (status->flag & FAST_RX_CRYPT_FLAGS) != FAST_RX_CRYPT_FLAGS) + return false; + + if (unlikely(!ieee80211_is_data_present(hdr->frame_control))) + return false; + + if (unlikely(ieee80211_is_frag(hdr))) + return false; + + /* Since our interface address cannot be multicast, this + * implicitly also rejects multicast frames without the + * explicit check. + * + * We shouldn't get any *data* frames not addressed to us + * (AP mode will accept multicast *management* frames), but + * punting here will make it go through the full checks in + * ieee80211_accept_frame(). + */ + if (!ether_addr_equal(fast_rx->vif_addr, hdr->addr1)) + return false; + + if ((hdr->frame_control & cpu_to_le16(IEEE80211_FCTL_FROMDS | + IEEE80211_FCTL_TODS)) != + fast_rx->expected_ds_bits) + return false; + + /* assign the key to drop unencrypted frames (later) + * and strip the IV/MIC if necessary + */ + if (fast_rx->key && !(status->flag & RX_FLAG_IV_STRIPPED)) { + /* GCMP header length is the same */ + snap_offs += IEEE80211_CCMP_HDR_LEN; + } + + if (!(status->rx_flags & IEEE80211_RX_AMSDU)) { + if (!pskb_may_pull(skb, snap_offs + sizeof(*payload))) + return false; + + payload = (void *)(skb->data + snap_offs); + + if (!ether_addr_equal(payload->snap, fast_rx->rfc1042_hdr)) + return false; + + /* Don't handle these here since they require special code. + * Accept AARP and IPX even though they should come with a + * bridge-tunnel header - but if we get them this way then + * there's little point in discarding them. + */ + if (unlikely(payload->proto == cpu_to_be16(ETH_P_TDLS) || + payload->proto == fast_rx->control_port_protocol)) + return false; + } + + /* after this point, don't punt to the slowpath! 
*/ + + if (rx->key && !(status->flag & RX_FLAG_MIC_STRIPPED) && + pskb_trim(skb, skb->len - fast_rx->icv_len)) + goto drop; + + if (rx->key && !ieee80211_has_protected(hdr->frame_control)) + goto drop; + + if (status->rx_flags & IEEE80211_RX_AMSDU) { + if (__ieee80211_rx_h_amsdu(rx, snap_offs - hdrlen) != + RX_QUEUED) + goto drop; + + return true; + } + + /* do the header conversion - first grab the addresses */ + ether_addr_copy(addrs.da, skb->data + fast_rx->da_offs); + ether_addr_copy(addrs.sa, skb->data + fast_rx->sa_offs); + skb_postpull_rcsum(skb, skb->data + snap_offs, + sizeof(rfc1042_header) + 2); + /* remove the SNAP but leave the ethertype */ + skb_pull(skb, snap_offs + sizeof(rfc1042_header)); + /* push the addresses in front */ + memcpy(skb_push(skb, sizeof(addrs)), &addrs, sizeof(addrs)); + + ieee80211_rx_8023(rx, fast_rx, orig_len); + + return true; + drop: + dev_kfree_skb(skb); + + if (fast_rx->uses_rss) + stats = this_cpu_ptr(rx->link_sta->pcpu_rx_stats); + else + stats = &rx->link_sta->rx_stats; + + stats->dropped++; + return true; +} + +/* + * This function returns whether or not the SKB + * was destined for RX processing or not, which, + * if consume is true, is equivalent to whether + * or not the skb was consumed. + */ +static bool ieee80211_prepare_and_rx_handle(struct ieee80211_rx_data *rx, + struct sk_buff *skb, bool consume) +{ + struct ieee80211_local *local = rx->local; + struct ieee80211_sub_if_data *sdata = rx->sdata; + struct ieee80211_hdr *hdr = (void *)skb->data; + struct link_sta_info *link_sta = rx->link_sta; + struct ieee80211_link_data *link = rx->link; + + rx->skb = skb; + + /* See if we can do fast-rx; if we have to copy we already lost, + * so punt in that case. We should never have to deliver a data + * frame to multiple interfaces anyway. + * + * We skip the ieee80211_accept_frame() call and do the necessary + * checking inside ieee80211_invoke_fast_rx(). 
+ */ + if (consume && rx->sta) { + struct ieee80211_fast_rx *fast_rx; + + fast_rx = rcu_dereference(rx->sta->fast_rx); + if (fast_rx && ieee80211_invoke_fast_rx(rx, fast_rx)) + return true; + } + + if (!ieee80211_accept_frame(rx)) + return false; + + if (!consume) { + struct skb_shared_hwtstamps *shwt; + + rx->skb = skb_copy(skb, GFP_ATOMIC); + if (!rx->skb) { + if (net_ratelimit()) + wiphy_debug(local->hw.wiphy, + "failed to copy skb for %s\n", + sdata->name); + return true; + } + + /* skb_copy() does not copy the hw timestamps, so copy it + * explicitly + */ + shwt = skb_hwtstamps(rx->skb); + shwt->hwtstamp = skb_hwtstamps(skb)->hwtstamp; + + /* Update the hdr pointer to the new skb for translation below */ + hdr = (struct ieee80211_hdr *)rx->skb->data; + } + + if (unlikely(rx->sta && rx->sta->sta.mlo) && + is_unicast_ether_addr(hdr->addr1) && + !ieee80211_is_probe_resp(hdr->frame_control) && + !ieee80211_is_beacon(hdr->frame_control)) { + /* translate to MLD addresses */ + if (ether_addr_equal(link->conf->addr, hdr->addr1)) + ether_addr_copy(hdr->addr1, rx->sdata->vif.addr); + if (ether_addr_equal(link_sta->addr, hdr->addr2)) + ether_addr_copy(hdr->addr2, rx->sta->addr); + /* translate A3 only if it's the BSSID */ + if (!ieee80211_has_tods(hdr->frame_control) && + !ieee80211_has_fromds(hdr->frame_control)) { + if (ether_addr_equal(link_sta->addr, hdr->addr3)) + ether_addr_copy(hdr->addr3, rx->sta->addr); + else if (ether_addr_equal(link->conf->addr, hdr->addr3)) + ether_addr_copy(hdr->addr3, rx->sdata->vif.addr); + } + /* not needed for A4 since it can only carry the SA */ + } + + ieee80211_invoke_rx_handlers(rx); + return true; +} + +static void __ieee80211_rx_handle_8023(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, + struct sk_buff *skb, + struct list_head *list) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_fast_rx *fast_rx; + struct ieee80211_rx_data rx; + struct sta_info *sta; + int link_id = -1; + + memset(&rx, 0, sizeof(rx)); + rx.skb = skb; + rx.local = local; + rx.list = list; + rx.link_id = -1; + + I802_DEBUG_INC(local->dot11ReceivedFragmentCount); + + /* drop frame if too short for header */ + if (skb->len < sizeof(struct ethhdr)) + goto drop; + + if (!pubsta) + goto drop; + + if (status->link_valid) + link_id = status->link_id; + + /* + * TODO: Should the frame be dropped if the right link_id is not + * available? Or may be it is fine in the current form to proceed with + * the frame processing because with frame being in 802.3 format, + * link_id is used only for stats purpose and updating the stats on + * the deflink is fine? + */ + sta = container_of(pubsta, struct sta_info, sta); + if (!ieee80211_rx_data_set_sta(&rx, sta, link_id)) + goto drop; + + fast_rx = rcu_dereference(rx.sta->fast_rx); + if (!fast_rx) + goto drop; + + ieee80211_rx_8023(&rx, fast_rx, skb->len); + return; + +drop: + dev_kfree_skb(skb); +} + +static bool ieee80211_rx_for_interface(struct ieee80211_rx_data *rx, + struct sk_buff *skb, bool consume) +{ + struct link_sta_info *link_sta; + struct ieee80211_hdr *hdr = (void *)skb->data; + struct sta_info *sta; + int link_id = -1; + + /* + * Look up link station first, in case there's a + * chance that they might have a link address that + * is identical to the MLD address, that way we'll + * have the link information if needed. 
+ */ + link_sta = link_sta_info_get_bss(rx->sdata, hdr->addr2); + if (link_sta) { + sta = link_sta->sta; + link_id = link_sta->link_id; + } else { + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + + sta = sta_info_get_bss(rx->sdata, hdr->addr2); + if (status->link_valid) + link_id = status->link_id; + } + + if (!ieee80211_rx_data_set_sta(rx, sta, link_id)) + return false; + + return ieee80211_prepare_and_rx_handle(rx, skb, consume); +} + +/* + * This is the actual Rx frames handler. as it belongs to Rx path it must + * be called with rcu_read_lock protection. + */ +static void __ieee80211_rx_handle_packet(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, + struct sk_buff *skb, + struct list_head *list) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_sub_if_data *sdata; + struct ieee80211_hdr *hdr; + __le16 fc; + struct ieee80211_rx_data rx; + struct ieee80211_sub_if_data *prev; + struct rhlist_head *tmp; + int err = 0; + + fc = ((struct ieee80211_hdr *)skb->data)->frame_control; + memset(&rx, 0, sizeof(rx)); + rx.skb = skb; + rx.local = local; + rx.list = list; + rx.link_id = -1; + + if (ieee80211_is_data(fc) || ieee80211_is_mgmt(fc)) + I802_DEBUG_INC(local->dot11ReceivedFragmentCount); + + if (ieee80211_is_mgmt(fc)) { + /* drop frame if too short for header */ + if (skb->len < ieee80211_hdrlen(fc)) + err = -ENOBUFS; + else + err = skb_linearize(skb); + } else { + err = !pskb_may_pull(skb, ieee80211_hdrlen(fc)); + } + + if (err) { + dev_kfree_skb(skb); + return; + } + + hdr = (struct ieee80211_hdr *)skb->data; + ieee80211_parse_qos(&rx); + ieee80211_verify_alignment(&rx); + + if (unlikely(ieee80211_is_probe_resp(hdr->frame_control) || + ieee80211_is_beacon(hdr->frame_control) || + ieee80211_is_s1g_beacon(hdr->frame_control))) + ieee80211_scan_rx(local, skb); + + if (ieee80211_is_data(fc)) { + struct sta_info *sta, *prev_sta; + int link_id = -1; + + if (status->link_valid) + link_id = status->link_id; + + if (pubsta) { + sta = container_of(pubsta, struct sta_info, sta); + if (!ieee80211_rx_data_set_sta(&rx, sta, link_id)) + goto out; + + /* + * In MLO connection, fetch the link_id using addr2 + * when the driver does not pass link_id in status. + * When the address translation is already performed by + * driver/hw, the valid link_id must be passed in + * status. 
+ */ + + if (!status->link_valid && pubsta->mlo) { + struct ieee80211_hdr *hdr = (void *)skb->data; + struct link_sta_info *link_sta; + + link_sta = link_sta_info_get_bss(rx.sdata, + hdr->addr2); + if (!link_sta) + goto out; + + ieee80211_rx_data_set_link(&rx, link_sta->link_id); + } + + if (ieee80211_prepare_and_rx_handle(&rx, skb, true)) + return; + goto out; + } + + prev_sta = NULL; + + for_each_sta_info(local, hdr->addr2, sta, tmp) { + if (!prev_sta) { + prev_sta = sta; + continue; + } + + rx.sdata = prev_sta->sdata; + if (!ieee80211_rx_data_set_sta(&rx, prev_sta, link_id)) + goto out; + + if (!status->link_valid && prev_sta->sta.mlo) + continue; + + ieee80211_prepare_and_rx_handle(&rx, skb, false); + + prev_sta = sta; + } + + if (prev_sta) { + rx.sdata = prev_sta->sdata; + if (!ieee80211_rx_data_set_sta(&rx, prev_sta, link_id)) + goto out; + + if (!status->link_valid && prev_sta->sta.mlo) + goto out; + + if (ieee80211_prepare_and_rx_handle(&rx, skb, true)) + return; + goto out; + } + } + + prev = NULL; + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + if (sdata->vif.type == NL80211_IFTYPE_MONITOR || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + continue; + + /* + * frame is destined for this interface, but if it's + * not also for the previous one we handle that after + * the loop to avoid copying the SKB once too much + */ + + if (!prev) { + prev = sdata; + continue; + } + + rx.sdata = prev; + ieee80211_rx_for_interface(&rx, skb, false); + + prev = sdata; + } + + if (prev) { + rx.sdata = prev; + + if (ieee80211_rx_for_interface(&rx, skb, true)) + return; + } + + out: + dev_kfree_skb(skb); +} + +/* + * This is the receive path handler. It is called by a low level driver when an + * 802.11 MPDU is received from the hardware. + */ +void ieee80211_rx_list(struct ieee80211_hw *hw, struct ieee80211_sta *pubsta, + struct sk_buff *skb, struct list_head *list) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_rate *rate = NULL; + struct ieee80211_supported_band *sband; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + + WARN_ON_ONCE(softirq_count() == 0); + + if (WARN_ON(status->band >= NUM_NL80211_BANDS)) + goto drop; + + sband = local->hw.wiphy->bands[status->band]; + if (WARN_ON(!sband)) + goto drop; + + /* + * If we're suspending, it is possible although not too likely + * that we'd be receiving frames after having already partially + * quiesced the stack. We can't process such frames then since + * that might, for example, cause stations to be added or other + * driver callbacks be invoked. + */ + if (unlikely(local->quiescing || local->suspended)) + goto drop; + + /* We might be during a HW reconfig, prevent Rx for the same reason */ + if (unlikely(local->in_reconfig)) + goto drop; + + /* + * The same happens when we're not even started, + * but that's worth a warning. + */ + if (WARN_ON(!local->started)) + goto drop; + + if (likely(!(status->flag & RX_FLAG_FAILED_PLCP_CRC))) { + /* + * Validate the rate, unless a PLCP error means that + * we probably can't have a valid rate here anyway. + */ + + switch (status->encoding) { + case RX_ENC_HT: + /* + * rate_idx is MCS index, which can be [0-76] + * as documented on: + * + * https://wireless.wiki.kernel.org/en/developers/Documentation/ieee80211/802.11n + * + * Anything else would be some sort of driver or + * hardware error. The driver should catch hardware + * errors. 
+ */ + if (WARN(status->rate_idx > 76, + "Rate marked as an HT rate but passed " + "status->rate_idx is not " + "an MCS index [0-76]: %d (0x%02x)\n", + status->rate_idx, + status->rate_idx)) + goto drop; + break; + case RX_ENC_VHT: + if (WARN_ONCE(status->rate_idx > 11 || + !status->nss || + status->nss > 8, + "Rate marked as a VHT rate but data is invalid: MCS: %d, NSS: %d\n", + status->rate_idx, status->nss)) + goto drop; + break; + case RX_ENC_HE: + if (WARN_ONCE(status->rate_idx > 11 || + !status->nss || + status->nss > 8, + "Rate marked as an HE rate but data is invalid: MCS: %d, NSS: %d\n", + status->rate_idx, status->nss)) + goto drop; + break; + default: + WARN_ON_ONCE(1); + fallthrough; + case RX_ENC_LEGACY: + if (WARN_ON(status->rate_idx >= sband->n_bitrates)) + goto drop; + rate = &sband->bitrates[status->rate_idx]; + } + } + + if (WARN_ON_ONCE(status->link_id >= IEEE80211_LINK_UNSPECIFIED)) + goto drop; + + status->rx_flags = 0; + + kcov_remote_start_common(skb_get_kcov_handle(skb)); + + /* + * Frames with failed FCS/PLCP checksum are not returned, + * all other frames are returned without radiotap header + * if it was previously present. + * Also, frames with less than 16 bytes are dropped. + */ + if (!(status->flag & RX_FLAG_8023)) + skb = ieee80211_rx_monitor(local, skb, rate); + if (skb) { + if ((status->flag & RX_FLAG_8023) || + ieee80211_is_data_present(hdr->frame_control)) + ieee80211_tpt_led_trig_rx(local, skb->len); + + if (status->flag & RX_FLAG_8023) + __ieee80211_rx_handle_8023(hw, pubsta, skb, list); + else + __ieee80211_rx_handle_packet(hw, pubsta, skb, list); + } + + kcov_remote_stop(); + return; + drop: + kfree_skb(skb); +} +EXPORT_SYMBOL(ieee80211_rx_list); + +void ieee80211_rx_napi(struct ieee80211_hw *hw, struct ieee80211_sta *pubsta, + struct sk_buff *skb, struct napi_struct *napi) +{ + struct sk_buff *tmp; + LIST_HEAD(list); + + + /* + * key references and virtual interfaces are protected using RCU + * and this requires that we are in a read-side RCU section during + * receive processing + */ + rcu_read_lock(); + ieee80211_rx_list(hw, pubsta, skb, &list); + rcu_read_unlock(); + + if (!napi) { + netif_receive_skb_list(&list); + return; + } + + list_for_each_entry_safe(skb, tmp, &list, list) { + skb_list_del_init(skb); + napi_gro_receive(napi, skb); + } +} +EXPORT_SYMBOL(ieee80211_rx_napi); + +/* This is a version of the rx handler that can be called from hard irq + * context. 
Post the skb on the queue and schedule the tasklet */ +void ieee80211_rx_irqsafe(struct ieee80211_hw *hw, struct sk_buff *skb) +{ + struct ieee80211_local *local = hw_to_local(hw); + + BUILD_BUG_ON(sizeof(struct ieee80211_rx_status) > sizeof(skb->cb)); + + skb->pkt_type = IEEE80211_RX_MSG; + skb_queue_tail(&local->skb_queue, skb); + tasklet_schedule(&local->tasklet); +} +EXPORT_SYMBOL(ieee80211_rx_irqsafe); diff --git a/net/mac80211/s1g.c b/net/mac80211/s1g.c new file mode 100644 index 000000000..c1f964e99 --- /dev/null +++ b/net/mac80211/s1g.c @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * S1G handling + * Copyright(c) 2020 Adapt-IP + */ +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" + +void ieee80211_s1g_sta_rate_init(struct sta_info *sta) +{ + /* avoid indicating legacy bitrates for S1G STAs */ + sta->deflink.tx_stats.last_rate.flags |= IEEE80211_TX_RC_S1G_MCS; + sta->deflink.rx_stats.last_rate = + STA_STATS_FIELD(TYPE, STA_STATS_RATE_TYPE_S1G); +} + +bool ieee80211_s1g_is_twt_setup(struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)skb->data; + + if (likely(!ieee80211_is_action(mgmt->frame_control))) + return false; + + if (likely(mgmt->u.action.category != WLAN_CATEGORY_S1G)) + return false; + + return mgmt->u.action.u.s1g.action_code == WLAN_S1G_TWT_SETUP; +} + +static void +ieee80211_s1g_send_twt_setup(struct ieee80211_sub_if_data *sdata, const u8 *da, + const u8 *bssid, struct ieee80211_twt_setup *twt) +{ + int len = IEEE80211_MIN_ACTION_SIZE + 4 + twt->length; + struct ieee80211_local *local = sdata->local; + struct ieee80211_mgmt *mgmt; + struct sk_buff *skb; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + len); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + mgmt = skb_put_zero(skb, len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, bssid, ETH_ALEN); + + mgmt->u.action.category = WLAN_CATEGORY_S1G; + mgmt->u.action.u.s1g.action_code = WLAN_S1G_TWT_SETUP; + memcpy(mgmt->u.action.u.s1g.variable, twt, 3 + twt->length); + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + IEEE80211_TX_INTFL_MLME_CONN_TX | + IEEE80211_TX_CTL_REQ_TX_STATUS; + ieee80211_tx_skb(sdata, skb); +} + +static void +ieee80211_s1g_send_twt_teardown(struct ieee80211_sub_if_data *sdata, + const u8 *da, const u8 *bssid, u8 flowid) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_mgmt *mgmt; + struct sk_buff *skb; + u8 *id; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + + IEEE80211_MIN_ACTION_SIZE + 2); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + mgmt = skb_put_zero(skb, IEEE80211_MIN_ACTION_SIZE + 2); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, bssid, ETH_ALEN); + + mgmt->u.action.category = WLAN_CATEGORY_S1G; + mgmt->u.action.u.s1g.action_code = WLAN_S1G_TWT_TEARDOWN; + id = (u8 *)mgmt->u.action.u.s1g.variable; + *id = flowid; + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + IEEE80211_TX_CTL_REQ_TX_STATUS; + ieee80211_tx_skb(sdata, skb); +} + +static void +ieee80211_s1g_rx_twt_setup(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (void *)skb->data; + struct 
ieee80211_twt_setup *twt = (void *)mgmt->u.action.u.s1g.variable; + struct ieee80211_twt_params *twt_agrt = (void *)twt->params; + + twt_agrt->req_type &= cpu_to_le16(~IEEE80211_TWT_REQTYPE_REQUEST); + + /* broadcast TWT not supported yet */ + if (twt->control & IEEE80211_TWT_CONTROL_NEG_TYPE_BROADCAST) { + twt_agrt->req_type &= + ~cpu_to_le16(IEEE80211_TWT_REQTYPE_SETUP_CMD); + twt_agrt->req_type |= + le16_encode_bits(TWT_SETUP_CMD_REJECT, + IEEE80211_TWT_REQTYPE_SETUP_CMD); + goto out; + } + + /* TWT Information not supported yet */ + twt->control |= IEEE80211_TWT_CONTROL_RX_DISABLED; + + drv_add_twt_setup(sdata->local, sdata, &sta->sta, twt); +out: + ieee80211_s1g_send_twt_setup(sdata, mgmt->sa, sdata->vif.addr, twt); +} + +static void +ieee80211_s1g_rx_twt_teardown(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)skb->data; + + drv_twt_teardown_request(sdata->local, sdata, &sta->sta, + mgmt->u.action.u.s1g.variable[0]); +} + +static void +ieee80211_s1g_tx_twt_setup_fail(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)skb->data; + struct ieee80211_twt_setup *twt = (void *)mgmt->u.action.u.s1g.variable; + struct ieee80211_twt_params *twt_agrt = (void *)twt->params; + u8 flowid = le16_get_bits(twt_agrt->req_type, + IEEE80211_TWT_REQTYPE_FLOWID); + + drv_twt_teardown_request(sdata->local, sdata, &sta->sta, flowid); + + ieee80211_s1g_send_twt_teardown(sdata, mgmt->sa, sdata->vif.addr, + flowid); +} + +void ieee80211_s1g_rx_twt_action(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)skb->data; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get_bss(sdata, mgmt->sa); + if (!sta) + goto out; + + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_SETUP: + ieee80211_s1g_rx_twt_setup(sdata, sta, skb); + break; + case WLAN_S1G_TWT_TEARDOWN: + ieee80211_s1g_rx_twt_teardown(sdata, sta, skb); + break; + default: + break; + } + +out: + mutex_unlock(&local->sta_mtx); +} + +void ieee80211_s1g_status_twt_action(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)skb->data; + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get_bss(sdata, mgmt->da); + if (!sta) + goto out; + + switch (mgmt->u.action.u.s1g.action_code) { + case WLAN_S1G_TWT_SETUP: + /* process failed twt setup frames */ + ieee80211_s1g_tx_twt_setup_fail(sdata, sta, skb); + break; + default: + break; + } + +out: + mutex_unlock(&local->sta_mtx); +} diff --git a/net/mac80211/scan.c b/net/mac80211/scan.c new file mode 100644 index 000000000..c37e2576f --- /dev/null +++ b/net/mac80211/scan.c @@ -0,0 +1,1471 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Scanning implementation + * + * Copyright 2003, Jouni Malinen + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2013-2015 Intel Mobile Communications GmbH + * Copyright 2016-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2021 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "mesh.h" + +#define IEEE80211_PROBE_DELAY (HZ / 33) +#define IEEE80211_CHANNEL_TIME (HZ / 33) +#define IEEE80211_PASSIVE_CHANNEL_TIME (HZ / 9) + +void ieee80211_rx_bss_put(struct ieee80211_local *local, + struct ieee80211_bss *bss) +{ + if (!bss) + return; + cfg80211_put_bss(local->hw.wiphy, + container_of((void *)bss, struct cfg80211_bss, priv)); +} + +static bool is_uapsd_supported(struct ieee802_11_elems *elems) +{ + u8 qos_info; + + if (elems->wmm_info && elems->wmm_info_len == 7 + && elems->wmm_info[5] == 1) + qos_info = elems->wmm_info[6]; + else if (elems->wmm_param && elems->wmm_param_len == 24 + && elems->wmm_param[5] == 1) + qos_info = elems->wmm_param[6]; + else + /* no valid wmm information or parameter element found */ + return false; + + return qos_info & IEEE80211_WMM_IE_AP_QOSINFO_UAPSD; +} + +static void +ieee80211_update_bss_from_elems(struct ieee80211_local *local, + struct ieee80211_bss *bss, + struct ieee802_11_elems *elems, + struct ieee80211_rx_status *rx_status, + bool beacon) +{ + int clen, srlen; + + if (beacon) + bss->device_ts_beacon = rx_status->device_timestamp; + else + bss->device_ts_presp = rx_status->device_timestamp; + + if (elems->parse_error) { + if (beacon) + bss->corrupt_data |= IEEE80211_BSS_CORRUPT_BEACON; + else + bss->corrupt_data |= IEEE80211_BSS_CORRUPT_PROBE_RESP; + } else { + if (beacon) + bss->corrupt_data &= ~IEEE80211_BSS_CORRUPT_BEACON; + else + bss->corrupt_data &= ~IEEE80211_BSS_CORRUPT_PROBE_RESP; + } + + /* save the ERP value so that it is available at association time */ + if (elems->erp_info && (!elems->parse_error || + !(bss->valid_data & IEEE80211_BSS_VALID_ERP))) { + bss->erp_value = elems->erp_info[0]; + bss->has_erp_value = true; + if (!elems->parse_error) + bss->valid_data |= IEEE80211_BSS_VALID_ERP; + } + + /* replace old supported rates if we get new values */ + if (!elems->parse_error || + !(bss->valid_data & IEEE80211_BSS_VALID_RATES)) { + srlen = 0; + if (elems->supp_rates) { + clen = IEEE80211_MAX_SUPP_RATES; + if (clen > elems->supp_rates_len) + clen = elems->supp_rates_len; + memcpy(bss->supp_rates, elems->supp_rates, clen); + srlen += clen; + } + if (elems->ext_supp_rates) { + clen = IEEE80211_MAX_SUPP_RATES - srlen; + if (clen > elems->ext_supp_rates_len) + clen = elems->ext_supp_rates_len; + memcpy(bss->supp_rates + srlen, elems->ext_supp_rates, + clen); + srlen += clen; + } + if (srlen) { + bss->supp_rates_len = srlen; + if (!elems->parse_error) + bss->valid_data |= IEEE80211_BSS_VALID_RATES; + } + } + + if (!elems->parse_error || + !(bss->valid_data & IEEE80211_BSS_VALID_WMM)) { + bss->wmm_used = elems->wmm_param || elems->wmm_info; + bss->uapsd_supported = is_uapsd_supported(elems); + if (!elems->parse_error) + bss->valid_data |= IEEE80211_BSS_VALID_WMM; + } + + if (beacon) { + struct ieee80211_supported_band *sband = + local->hw.wiphy->bands[rx_status->band]; + if (!(rx_status->encoding == RX_ENC_HT) && + !(rx_status->encoding == RX_ENC_VHT)) + bss->beacon_rate = + &sband->bitrates[rx_status->rate_idx]; + } + + if (elems->vht_cap_elem) + bss->vht_cap_info = + le32_to_cpu(elems->vht_cap_elem->vht_cap_info); + else + bss->vht_cap_info = 0; +} + +struct 
ieee80211_bss * +ieee80211_bss_info_update(struct ieee80211_local *local, + struct ieee80211_rx_status *rx_status, + struct ieee80211_mgmt *mgmt, size_t len, + struct ieee80211_channel *channel) +{ + bool beacon = ieee80211_is_beacon(mgmt->frame_control) || + ieee80211_is_s1g_beacon(mgmt->frame_control); + struct cfg80211_bss *cbss, *non_tx_cbss; + struct ieee80211_bss *bss, *non_tx_bss; + struct cfg80211_inform_bss bss_meta = { + .boottime_ns = rx_status->boottime_ns, + }; + bool signal_valid; + struct ieee80211_sub_if_data *scan_sdata; + struct ieee802_11_elems *elems; + size_t baselen; + u8 *elements; + + if (rx_status->flag & RX_FLAG_NO_SIGNAL_VAL) + bss_meta.signal = 0; /* invalid signal indication */ + else if (ieee80211_hw_check(&local->hw, SIGNAL_DBM)) + bss_meta.signal = rx_status->signal * 100; + else if (ieee80211_hw_check(&local->hw, SIGNAL_UNSPEC)) + bss_meta.signal = (rx_status->signal * 100) / local->hw.max_signal; + + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_20; + if (rx_status->bw == RATE_INFO_BW_5) + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_5; + else if (rx_status->bw == RATE_INFO_BW_10) + bss_meta.scan_width = NL80211_BSS_CHAN_WIDTH_10; + + bss_meta.chan = channel; + + rcu_read_lock(); + scan_sdata = rcu_dereference(local->scan_sdata); + if (scan_sdata && scan_sdata->vif.type == NL80211_IFTYPE_STATION && + scan_sdata->vif.cfg.assoc && + ieee80211_have_rx_timestamp(rx_status)) { + bss_meta.parent_tsf = + ieee80211_calculate_rx_timestamp(local, rx_status, + len + FCS_LEN, 24); + ether_addr_copy(bss_meta.parent_bssid, + scan_sdata->vif.bss_conf.bssid); + } + rcu_read_unlock(); + + cbss = cfg80211_inform_bss_frame_data(local->hw.wiphy, &bss_meta, + mgmt, len, GFP_ATOMIC); + if (!cbss) + return NULL; + + if (ieee80211_is_probe_resp(mgmt->frame_control)) { + elements = mgmt->u.probe_resp.variable; + baselen = offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + } else if (ieee80211_is_s1g_beacon(mgmt->frame_control)) { + struct ieee80211_ext *ext = (void *) mgmt; + + baselen = offsetof(struct ieee80211_ext, u.s1g_beacon.variable); + elements = ext->u.s1g_beacon.variable; + } else { + baselen = offsetof(struct ieee80211_mgmt, u.beacon.variable); + elements = mgmt->u.beacon.variable; + } + + if (baselen > len) + return NULL; + + elems = ieee802_11_parse_elems(elements, len - baselen, false, cbss); + if (!elems) + return NULL; + + /* In case the signal is invalid update the status */ + signal_valid = channel == cbss->channel; + if (!signal_valid) + rx_status->flag |= RX_FLAG_NO_SIGNAL_VAL; + + bss = (void *)cbss->priv; + ieee80211_update_bss_from_elems(local, bss, elems, rx_status, beacon); + kfree(elems); + + list_for_each_entry(non_tx_cbss, &cbss->nontrans_list, nontrans_list) { + non_tx_bss = (void *)non_tx_cbss->priv; + + elems = ieee802_11_parse_elems(elements, len - baselen, false, + non_tx_cbss); + if (!elems) + continue; + + ieee80211_update_bss_from_elems(local, non_tx_bss, elems, + rx_status, beacon); + kfree(elems); + } + + return bss; +} + +static bool ieee80211_scan_accept_presp(struct ieee80211_sub_if_data *sdata, + u32 scan_flags, const u8 *da) +{ + if (!sdata) + return false; + /* accept broadcast for OCE */ + if (scan_flags & NL80211_SCAN_FLAG_ACCEPT_BCAST_PROBE_RESP && + is_broadcast_ether_addr(da)) + return true; + if (scan_flags & NL80211_SCAN_FLAG_RANDOM_ADDR) + return true; + return ether_addr_equal(da, sdata->vif.addr); +} + +void ieee80211_scan_rx(struct ieee80211_local *local, struct sk_buff *skb) +{ + struct ieee80211_rx_status *rx_status 
= IEEE80211_SKB_RXCB(skb); + struct ieee80211_sub_if_data *sdata1, *sdata2; + struct ieee80211_mgmt *mgmt = (void *)skb->data; + struct ieee80211_bss *bss; + struct ieee80211_channel *channel; + size_t min_hdr_len = offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + + if (!ieee80211_is_probe_resp(mgmt->frame_control) && + !ieee80211_is_beacon(mgmt->frame_control) && + !ieee80211_is_s1g_beacon(mgmt->frame_control)) + return; + + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) { + if (ieee80211_is_s1g_short_beacon(mgmt->frame_control)) + min_hdr_len = offsetof(struct ieee80211_ext, + u.s1g_short_beacon.variable); + else + min_hdr_len = offsetof(struct ieee80211_ext, + u.s1g_beacon); + } + + if (skb->len < min_hdr_len) + return; + + sdata1 = rcu_dereference(local->scan_sdata); + sdata2 = rcu_dereference(local->sched_scan_sdata); + + if (likely(!sdata1 && !sdata2)) + return; + + if (test_and_clear_bit(SCAN_BEACON_WAIT, &local->scanning)) { + /* + * we were passive scanning because of radar/no-IR, but + * the beacon/proberesp rx gives us an opportunity to upgrade + * to active scan + */ + set_bit(SCAN_BEACON_DONE, &local->scanning); + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, 0); + } + + if (ieee80211_is_probe_resp(mgmt->frame_control)) { + struct cfg80211_scan_request *scan_req; + struct cfg80211_sched_scan_request *sched_scan_req; + u32 scan_req_flags = 0, sched_scan_req_flags = 0; + + scan_req = rcu_dereference(local->scan_req); + sched_scan_req = rcu_dereference(local->sched_scan_req); + + if (scan_req) + scan_req_flags = scan_req->flags; + + if (sched_scan_req) + sched_scan_req_flags = sched_scan_req->flags; + + /* ignore ProbeResp to foreign address or non-bcast (OCE) + * unless scanning with randomised address + */ + if (!ieee80211_scan_accept_presp(sdata1, scan_req_flags, + mgmt->da) && + !ieee80211_scan_accept_presp(sdata2, sched_scan_req_flags, + mgmt->da)) + return; + } + + channel = ieee80211_get_channel_khz(local->hw.wiphy, + ieee80211_rx_status_to_khz(rx_status)); + + if (!channel || channel->flags & IEEE80211_CHAN_DISABLED) + return; + + bss = ieee80211_bss_info_update(local, rx_status, + mgmt, skb->len, + channel); + if (bss) + ieee80211_rx_bss_put(local, bss); +} + +static void +ieee80211_prepare_scan_chandef(struct cfg80211_chan_def *chandef, + enum nl80211_bss_scan_width scan_width) +{ + memset(chandef, 0, sizeof(*chandef)); + switch (scan_width) { + case NL80211_BSS_CHAN_WIDTH_5: + chandef->width = NL80211_CHAN_WIDTH_5; + break; + case NL80211_BSS_CHAN_WIDTH_10: + chandef->width = NL80211_CHAN_WIDTH_10; + break; + default: + chandef->width = NL80211_CHAN_WIDTH_20_NOHT; + break; + } +} + +/* return false if no more work */ +static bool ieee80211_prep_hw_scan(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct cfg80211_scan_request *req; + struct cfg80211_chan_def chandef; + u8 bands_used = 0; + int i, ielen, n_chans; + u32 flags = 0; + + req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + if (test_bit(SCAN_HW_CANCELLED, &local->scanning)) + return false; + + if (ieee80211_hw_check(&local->hw, SINGLE_SCAN_ON_ALL_BANDS)) { + for (i = 0; i < req->n_channels; i++) { + local->hw_scan_req->req.channels[i] = req->channels[i]; + bands_used |= BIT(req->channels[i]->band); + } + + n_chans = req->n_channels; + } else { + do { + if (local->hw_scan_band == NUM_NL80211_BANDS) + return false; + + n_chans = 0; + + for (i = 0; i < req->n_channels; i++) { + if (req->channels[i]->band != 
+ local->hw_scan_band) + continue; + local->hw_scan_req->req.channels[n_chans] = + req->channels[i]; + n_chans++; + bands_used |= BIT(req->channels[i]->band); + } + + local->hw_scan_band++; + } while (!n_chans); + } + + local->hw_scan_req->req.n_channels = n_chans; + ieee80211_prepare_scan_chandef(&chandef, req->scan_width); + + if (req->flags & NL80211_SCAN_FLAG_MIN_PREQ_CONTENT) + flags |= IEEE80211_PROBE_FLAG_MIN_CONTENT; + + ielen = ieee80211_build_preq_ies(sdata, + (u8 *)local->hw_scan_req->req.ie, + local->hw_scan_ies_bufsize, + &local->hw_scan_req->ies, + req->ie, req->ie_len, + bands_used, req->rates, &chandef, + flags); + local->hw_scan_req->req.ie_len = ielen; + local->hw_scan_req->req.no_cck = req->no_cck; + ether_addr_copy(local->hw_scan_req->req.mac_addr, req->mac_addr); + ether_addr_copy(local->hw_scan_req->req.mac_addr_mask, + req->mac_addr_mask); + ether_addr_copy(local->hw_scan_req->req.bssid, req->bssid); + + return true; +} + +static void __ieee80211_scan_completed(struct ieee80211_hw *hw, bool aborted) +{ + struct ieee80211_local *local = hw_to_local(hw); + bool hw_scan = test_bit(SCAN_HW_SCANNING, &local->scanning); + bool was_scanning = local->scanning; + struct cfg80211_scan_request *scan_req; + struct ieee80211_sub_if_data *scan_sdata; + struct ieee80211_sub_if_data *sdata; + + lockdep_assert_held(&local->mtx); + + /* + * It's ok to abort a not-yet-running scan (that + * we have one at all will be verified by checking + * local->scan_req next), but not to complete it + * successfully. + */ + if (WARN_ON(!local->scanning && !aborted)) + aborted = true; + + if (WARN_ON(!local->scan_req)) + return; + + scan_sdata = rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx)); + + if (hw_scan && !aborted && + !ieee80211_hw_check(&local->hw, SINGLE_SCAN_ON_ALL_BANDS) && + ieee80211_prep_hw_scan(scan_sdata)) { + int rc; + + rc = drv_hw_scan(local, + rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx)), + local->hw_scan_req); + + if (rc == 0) + return; + + /* HW scan failed and is going to be reported as aborted, + * so clear old scan info. + */ + memset(&local->scan_info, 0, sizeof(local->scan_info)); + aborted = true; + } + + kfree(local->hw_scan_req); + local->hw_scan_req = NULL; + + scan_req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + RCU_INIT_POINTER(local->scan_req, NULL); + RCU_INIT_POINTER(local->scan_sdata, NULL); + + local->scanning = 0; + local->scan_chandef.chan = NULL; + + synchronize_rcu(); + + if (scan_req != local->int_scan_req) { + local->scan_info.aborted = aborted; + cfg80211_scan_done(scan_req, &local->scan_info); + } + + /* Set power back to normal operating levels. */ + ieee80211_hw_config(local, 0); + + if (!hw_scan && was_scanning) { + ieee80211_configure_filter(local); + drv_sw_scan_complete(local, scan_sdata); + ieee80211_offchannel_return(local); + } + + ieee80211_recalc_idle(local); + + ieee80211_mlme_notify_scan_completed(local); + ieee80211_ibss_notify_scan_completed(local); + + /* Requeue all the work that might have been ignored while + * the scan was in progress; if there was none this will + * just be a no-op for the particular interface. 
+ */ + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (ieee80211_sdata_running(sdata)) + ieee80211_queue_work(&sdata->local->hw, &sdata->work); + } + + if (was_scanning) + ieee80211_start_next_roc(local); +} + +void ieee80211_scan_completed(struct ieee80211_hw *hw, + struct cfg80211_scan_info *info) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_scan_completed(local, info->aborted); + + set_bit(SCAN_COMPLETED, &local->scanning); + if (info->aborted) + set_bit(SCAN_ABORTED, &local->scanning); + + memcpy(&local->scan_info, info, sizeof(*info)); + + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, 0); +} +EXPORT_SYMBOL(ieee80211_scan_completed); + +static int ieee80211_start_sw_scan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + /* Software scan is not supported in multi-channel cases */ + if (local->use_chanctx) + return -EOPNOTSUPP; + + /* + * Hardware/driver doesn't support hw_scan, so use software + * scanning instead. First send a nullfunc frame with power save + * bit on so that AP will buffer the frames for us while we are not + * listening, then send probe requests to each channel and wait for + * the responses. After all channels are scanned, tune back to the + * original channel and send a nullfunc frame with power save bit + * off to trigger the AP to send us all the buffered frames. + * + * Note that while local->sw_scanning is true everything else but + * nullfunc frames and probe requests will be dropped in + * ieee80211_tx_h_check_assoc(). + */ + drv_sw_scan_start(local, sdata, local->scan_addr); + + local->leave_oper_channel_time = jiffies; + local->next_scan_state = SCAN_DECISION; + local->scan_channel_idx = 0; + + ieee80211_offchannel_stop_vifs(local); + + /* ensure nullfunc is transmitted before leaving operating channel */ + ieee80211_flush_queues(local, NULL, false); + + ieee80211_configure_filter(local); + + /* We need to set power level at maximum rate for scanning. 
*/ + ieee80211_hw_config(local, 0); + + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, 0); + + return 0; +} + +static bool __ieee80211_can_leave_ch(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *sdata_iter; + + if (!ieee80211_is_radar_required(local)) + return true; + + if (!regulatory_pre_cac_allowed(local->hw.wiphy)) + return false; + + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata_iter, &local->interfaces, list) { + if (sdata_iter->wdev.cac_started) { + mutex_unlock(&local->iflist_mtx); + return false; + } + } + mutex_unlock(&local->iflist_mtx); + + return true; +} + +static bool ieee80211_can_scan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + if (!__ieee80211_can_leave_ch(sdata)) + return false; + + if (!list_empty(&local->roc_list)) + return false; + + if (sdata->vif.type == NL80211_IFTYPE_STATION && + sdata->u.mgd.flags & IEEE80211_STA_CONNECTION_POLL) + return false; + + return true; +} + +void ieee80211_run_deferred_scan(struct ieee80211_local *local) +{ + lockdep_assert_held(&local->mtx); + + if (!local->scan_req || local->scanning) + return; + + if (!ieee80211_can_scan(local, + rcu_dereference_protected( + local->scan_sdata, + lockdep_is_held(&local->mtx)))) + return; + + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, + round_jiffies_relative(0)); +} + +static void ieee80211_send_scan_probe_req(struct ieee80211_sub_if_data *sdata, + const u8 *src, const u8 *dst, + const u8 *ssid, size_t ssid_len, + const u8 *ie, size_t ie_len, + u32 ratemask, u32 flags, u32 tx_flags, + struct ieee80211_channel *channel) +{ + struct sk_buff *skb; + + skb = ieee80211_build_probe_req(sdata, src, dst, ratemask, channel, + ssid, ssid_len, + ie, ie_len, flags); + + if (skb) { + if (flags & IEEE80211_PROBE_FLAG_RANDOM_SN) { + struct ieee80211_hdr *hdr = (void *)skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + u16 sn = get_random_u16(); + + info->control.flags |= IEEE80211_TX_CTRL_NO_SEQNO; + hdr->seq_ctrl = + cpu_to_le16(IEEE80211_SN_TO_SEQ(sn)); + } + IEEE80211_SKB_CB(skb)->flags |= tx_flags; + ieee80211_tx_skb_tid_band(sdata, skb, 7, channel->band); + } +} + +static void ieee80211_scan_state_send_probe(struct ieee80211_local *local, + unsigned long *next_delay) +{ + int i; + struct ieee80211_sub_if_data *sdata; + struct cfg80211_scan_request *scan_req; + enum nl80211_band band = local->hw.conf.chandef.chan->band; + u32 flags = 0, tx_flags; + + scan_req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + tx_flags = IEEE80211_TX_INTFL_OFFCHAN_TX_OK; + if (scan_req->no_cck) + tx_flags |= IEEE80211_TX_CTL_NO_CCK_RATE; + if (scan_req->flags & NL80211_SCAN_FLAG_MIN_PREQ_CONTENT) + flags |= IEEE80211_PROBE_FLAG_MIN_CONTENT; + if (scan_req->flags & NL80211_SCAN_FLAG_RANDOM_SN) + flags |= IEEE80211_PROBE_FLAG_RANDOM_SN; + + sdata = rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx)); + + for (i = 0; i < scan_req->n_ssids; i++) + ieee80211_send_scan_probe_req( + sdata, local->scan_addr, scan_req->bssid, + scan_req->ssids[i].ssid, scan_req->ssids[i].ssid_len, + scan_req->ie, scan_req->ie_len, + scan_req->rates[band], flags, + tx_flags, local->hw.conf.chandef.chan); + + /* + * After sending probe requests, wait for probe responses + * on the channel. 
+ */ + *next_delay = IEEE80211_CHANNEL_TIME; + local->next_scan_state = SCAN_DECISION; +} + +static int __ieee80211_start_scan(struct ieee80211_sub_if_data *sdata, + struct cfg80211_scan_request *req) +{ + struct ieee80211_local *local = sdata->local; + bool hw_scan = local->ops->hw_scan; + int rc; + + lockdep_assert_held(&local->mtx); + + if (local->scan_req) + return -EBUSY; + + if (!__ieee80211_can_leave_ch(sdata)) + return -EBUSY; + + if (!ieee80211_can_scan(local, sdata)) { + /* wait for the work to finish/time out */ + rcu_assign_pointer(local->scan_req, req); + rcu_assign_pointer(local->scan_sdata, sdata); + return 0; + } + + again: + if (hw_scan) { + u8 *ies; + + local->hw_scan_ies_bufsize = local->scan_ies_len + req->ie_len; + + if (ieee80211_hw_check(&local->hw, SINGLE_SCAN_ON_ALL_BANDS)) { + int i, n_bands = 0; + u8 bands_counted = 0; + + for (i = 0; i < req->n_channels; i++) { + if (bands_counted & BIT(req->channels[i]->band)) + continue; + bands_counted |= BIT(req->channels[i]->band); + n_bands++; + } + + local->hw_scan_ies_bufsize *= n_bands; + } + + local->hw_scan_req = kmalloc( + sizeof(*local->hw_scan_req) + + req->n_channels * sizeof(req->channels[0]) + + local->hw_scan_ies_bufsize, GFP_KERNEL); + if (!local->hw_scan_req) + return -ENOMEM; + + local->hw_scan_req->req.ssids = req->ssids; + local->hw_scan_req->req.n_ssids = req->n_ssids; + ies = (u8 *)local->hw_scan_req + + sizeof(*local->hw_scan_req) + + req->n_channels * sizeof(req->channels[0]); + local->hw_scan_req->req.ie = ies; + local->hw_scan_req->req.flags = req->flags; + eth_broadcast_addr(local->hw_scan_req->req.bssid); + local->hw_scan_req->req.duration = req->duration; + local->hw_scan_req->req.duration_mandatory = + req->duration_mandatory; + + local->hw_scan_band = 0; + local->hw_scan_req->req.n_6ghz_params = req->n_6ghz_params; + local->hw_scan_req->req.scan_6ghz_params = + req->scan_6ghz_params; + local->hw_scan_req->req.scan_6ghz = req->scan_6ghz; + + /* + * After allocating local->hw_scan_req, we must + * go through until ieee80211_prep_hw_scan(), so + * anything that might be changed here and leave + * this function early must not go after this + * allocation. + */ + } + + rcu_assign_pointer(local->scan_req, req); + rcu_assign_pointer(local->scan_sdata, sdata); + + if (req->flags & NL80211_SCAN_FLAG_RANDOM_ADDR) + get_random_mask_addr(local->scan_addr, + req->mac_addr, + req->mac_addr_mask); + else + memcpy(local->scan_addr, sdata->vif.addr, ETH_ALEN); + + if (hw_scan) { + __set_bit(SCAN_HW_SCANNING, &local->scanning); + } else if ((req->n_channels == 1) && + (req->channels[0] == local->_oper_chandef.chan)) { + /* + * If we are scanning only on the operating channel + * then we do not need to stop normal activities + */ + unsigned long next_delay; + + __set_bit(SCAN_ONCHANNEL_SCANNING, &local->scanning); + + ieee80211_recalc_idle(local); + + /* Notify driver scan is starting, keep order of operations + * same as normal software scan, in case that matters. */ + drv_sw_scan_start(local, sdata, local->scan_addr); + + ieee80211_configure_filter(local); /* accept probe-responses */ + + /* We need to ensure power level is at max for scanning. 
*/ + ieee80211_hw_config(local, 0); + + if ((req->channels[0]->flags & (IEEE80211_CHAN_NO_IR | + IEEE80211_CHAN_RADAR)) || + !req->n_ssids) { + next_delay = IEEE80211_PASSIVE_CHANNEL_TIME; + if (req->n_ssids) + set_bit(SCAN_BEACON_WAIT, &local->scanning); + } else { + ieee80211_scan_state_send_probe(local, &next_delay); + next_delay = IEEE80211_CHANNEL_TIME; + } + + /* Now, just wait a bit and we are all done! */ + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, + next_delay); + return 0; + } else { + /* Do normal software scan */ + __set_bit(SCAN_SW_SCANNING, &local->scanning); + } + + ieee80211_recalc_idle(local); + + if (hw_scan) { + WARN_ON(!ieee80211_prep_hw_scan(sdata)); + rc = drv_hw_scan(local, sdata, local->hw_scan_req); + } else { + rc = ieee80211_start_sw_scan(local, sdata); + } + + if (rc) { + kfree(local->hw_scan_req); + local->hw_scan_req = NULL; + local->scanning = 0; + + ieee80211_recalc_idle(local); + + local->scan_req = NULL; + RCU_INIT_POINTER(local->scan_sdata, NULL); + } + + if (hw_scan && rc == 1) { + /* + * we can't fall back to software for P2P-GO + * as it must update NoA etc. + */ + if (ieee80211_vif_type_p2p(&sdata->vif) == + NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + hw_scan = false; + goto again; + } + + return rc; +} + +static unsigned long +ieee80211_scan_get_channel_time(struct ieee80211_channel *chan) +{ + /* + * TODO: channel switching also consumes quite some time, + * add that delay as well to get a better estimation + */ + if (chan->flags & (IEEE80211_CHAN_NO_IR | IEEE80211_CHAN_RADAR)) + return IEEE80211_PASSIVE_CHANNEL_TIME; + return IEEE80211_PROBE_DELAY + IEEE80211_CHANNEL_TIME; +} + +static void ieee80211_scan_state_decision(struct ieee80211_local *local, + unsigned long *next_delay) +{ + bool associated = false; + bool tx_empty = true; + bool bad_latency; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_channel *next_chan; + enum mac80211_scan_state next_scan_state; + struct cfg80211_scan_request *scan_req; + + /* + * check if at least one STA interface is associated, + * check if at least one STA interface has pending tx frames + * and grab the lowest used beacon interval + */ + mutex_lock(&local->iflist_mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + if (sdata->u.mgd.associated) { + associated = true; + + if (!qdisc_all_tx_empty(sdata->dev)) { + tx_empty = false; + break; + } + } + } + } + mutex_unlock(&local->iflist_mtx); + + scan_req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + next_chan = scan_req->channels[local->scan_channel_idx]; + + /* + * we're currently scanning a different channel, let's + * see if we can scan another channel without interfering + * with the current traffic situation. + * + * Keep good latency, do not stay off-channel more than 125 ms. 
+ */ + + bad_latency = time_after(jiffies + + ieee80211_scan_get_channel_time(next_chan), + local->leave_oper_channel_time + HZ / 8); + + if (associated && !tx_empty) { + if (scan_req->flags & NL80211_SCAN_FLAG_LOW_PRIORITY) + next_scan_state = SCAN_ABORT; + else + next_scan_state = SCAN_SUSPEND; + } else if (associated && bad_latency) { + next_scan_state = SCAN_SUSPEND; + } else { + next_scan_state = SCAN_SET_CHANNEL; + } + + local->next_scan_state = next_scan_state; + + *next_delay = 0; +} + +static void ieee80211_scan_state_set_channel(struct ieee80211_local *local, + unsigned long *next_delay) +{ + int skip; + struct ieee80211_channel *chan; + enum nl80211_bss_scan_width oper_scan_width; + struct cfg80211_scan_request *scan_req; + + scan_req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + skip = 0; + chan = scan_req->channels[local->scan_channel_idx]; + + local->scan_chandef.chan = chan; + local->scan_chandef.center_freq1 = chan->center_freq; + local->scan_chandef.freq1_offset = chan->freq_offset; + local->scan_chandef.center_freq2 = 0; + + /* For scanning on the S1G band, ignore scan_width (which is constant + * across all channels) for now since channel width is specific to each + * channel. Detect the required channel width here and likely revisit + * later. Maybe scan_width could be used to build the channel scan list? + */ + if (chan->band == NL80211_BAND_S1GHZ) { + local->scan_chandef.width = ieee80211_s1g_channel_width(chan); + goto set_channel; + } + + switch (scan_req->scan_width) { + case NL80211_BSS_CHAN_WIDTH_5: + local->scan_chandef.width = NL80211_CHAN_WIDTH_5; + break; + case NL80211_BSS_CHAN_WIDTH_10: + local->scan_chandef.width = NL80211_CHAN_WIDTH_10; + break; + default: + case NL80211_BSS_CHAN_WIDTH_20: + /* If scanning on oper channel, use whatever channel-type + * is currently in use. + */ + oper_scan_width = cfg80211_chandef_to_scan_width( + &local->_oper_chandef); + if (chan == local->_oper_chandef.chan && + oper_scan_width == scan_req->scan_width) + local->scan_chandef = local->_oper_chandef; + else + local->scan_chandef.width = NL80211_CHAN_WIDTH_20_NOHT; + break; + case NL80211_BSS_CHAN_WIDTH_1: + case NL80211_BSS_CHAN_WIDTH_2: + /* shouldn't get here, S1G handled above */ + WARN_ON(1); + break; + } + +set_channel: + if (ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_CHANNEL)) + skip = 1; + + /* advance state machine to next channel/band */ + local->scan_channel_idx++; + + if (skip) { + /* if we skip this channel return to the decision state */ + local->next_scan_state = SCAN_DECISION; + return; + } + + /* + * Probe delay is used to update the NAV, cf. 11.1.3.2.2 + * (which unfortunately doesn't say _why_ step a) is done, + * but it waits for the probe delay or until a frame is + * received - and the received frame would update the NAV). + * For now, we do not support waiting until a frame is + * received. + * + * In any case, it is not necessary for a passive scan. 
+ */ + if ((chan->flags & (IEEE80211_CHAN_NO_IR | IEEE80211_CHAN_RADAR)) || + !scan_req->n_ssids) { + *next_delay = IEEE80211_PASSIVE_CHANNEL_TIME; + local->next_scan_state = SCAN_DECISION; + if (scan_req->n_ssids) + set_bit(SCAN_BEACON_WAIT, &local->scanning); + return; + } + + /* active scan, send probes */ + *next_delay = IEEE80211_PROBE_DELAY; + local->next_scan_state = SCAN_SEND_PROBE; +} + +static void ieee80211_scan_state_suspend(struct ieee80211_local *local, + unsigned long *next_delay) +{ + /* switch back to the operating channel */ + local->scan_chandef.chan = NULL; + ieee80211_hw_config(local, IEEE80211_CONF_CHANGE_CHANNEL); + + /* disable PS */ + ieee80211_offchannel_return(local); + + *next_delay = HZ / 5; + /* afterwards, resume scan & go to next channel */ + local->next_scan_state = SCAN_RESUME; +} + +static void ieee80211_scan_state_resume(struct ieee80211_local *local, + unsigned long *next_delay) +{ + ieee80211_offchannel_stop_vifs(local); + + if (local->ops->flush) { + ieee80211_flush_queues(local, NULL, false); + *next_delay = 0; + } else + *next_delay = HZ / 10; + + /* remember when we left the operating channel */ + local->leave_oper_channel_time = jiffies; + + /* advance to the next channel to be scanned */ + local->next_scan_state = SCAN_SET_CHANNEL; +} + +void ieee80211_scan_work(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, scan_work.work); + struct ieee80211_sub_if_data *sdata; + struct cfg80211_scan_request *scan_req; + unsigned long next_delay = 0; + bool aborted; + + mutex_lock(&local->mtx); + + if (!ieee80211_can_run_worker(local)) { + aborted = true; + goto out_complete; + } + + sdata = rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx)); + scan_req = rcu_dereference_protected(local->scan_req, + lockdep_is_held(&local->mtx)); + + /* When scanning on-channel, the first-callback means completed. 
*/ + if (test_bit(SCAN_ONCHANNEL_SCANNING, &local->scanning)) { + aborted = test_and_clear_bit(SCAN_ABORTED, &local->scanning); + goto out_complete; + } + + if (test_and_clear_bit(SCAN_COMPLETED, &local->scanning)) { + aborted = test_and_clear_bit(SCAN_ABORTED, &local->scanning); + goto out_complete; + } + + if (!sdata || !scan_req) + goto out; + + if (!local->scanning) { + int rc; + + RCU_INIT_POINTER(local->scan_req, NULL); + RCU_INIT_POINTER(local->scan_sdata, NULL); + + rc = __ieee80211_start_scan(sdata, scan_req); + if (rc) { + /* need to complete scan in cfg80211 */ + rcu_assign_pointer(local->scan_req, scan_req); + aborted = true; + goto out_complete; + } else + goto out; + } + + clear_bit(SCAN_BEACON_WAIT, &local->scanning); + + /* + * as long as no delay is required advance immediately + * without scheduling a new work + */ + do { + if (!ieee80211_sdata_running(sdata)) { + aborted = true; + goto out_complete; + } + + if (test_and_clear_bit(SCAN_BEACON_DONE, &local->scanning) && + local->next_scan_state == SCAN_DECISION) + local->next_scan_state = SCAN_SEND_PROBE; + + switch (local->next_scan_state) { + case SCAN_DECISION: + /* if no more bands/channels left, complete scan */ + if (local->scan_channel_idx >= scan_req->n_channels) { + aborted = false; + goto out_complete; + } + ieee80211_scan_state_decision(local, &next_delay); + break; + case SCAN_SET_CHANNEL: + ieee80211_scan_state_set_channel(local, &next_delay); + break; + case SCAN_SEND_PROBE: + ieee80211_scan_state_send_probe(local, &next_delay); + break; + case SCAN_SUSPEND: + ieee80211_scan_state_suspend(local, &next_delay); + break; + case SCAN_RESUME: + ieee80211_scan_state_resume(local, &next_delay); + break; + case SCAN_ABORT: + aborted = true; + goto out_complete; + } + } while (next_delay == 0); + + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, + next_delay); + goto out; + +out_complete: + __ieee80211_scan_completed(&local->hw, aborted); +out: + mutex_unlock(&local->mtx); +} + +int ieee80211_request_scan(struct ieee80211_sub_if_data *sdata, + struct cfg80211_scan_request *req) +{ + int res; + + mutex_lock(&sdata->local->mtx); + res = __ieee80211_start_scan(sdata, req); + mutex_unlock(&sdata->local->mtx); + + return res; +} + +int ieee80211_request_ibss_scan(struct ieee80211_sub_if_data *sdata, + const u8 *ssid, u8 ssid_len, + struct ieee80211_channel **channels, + unsigned int n_channels, + enum nl80211_bss_scan_width scan_width) +{ + struct ieee80211_local *local = sdata->local; + int ret = -EBUSY, i, n_ch = 0; + enum nl80211_band band; + + mutex_lock(&local->mtx); + + /* busy scanning */ + if (local->scan_req) + goto unlock; + + /* fill internal scan request */ + if (!channels) { + int max_n; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + if (!local->hw.wiphy->bands[band] || + band == NL80211_BAND_6GHZ) + continue; + + max_n = local->hw.wiphy->bands[band]->n_channels; + for (i = 0; i < max_n; i++) { + struct ieee80211_channel *tmp_ch = + &local->hw.wiphy->bands[band]->channels[i]; + + if (tmp_ch->flags & (IEEE80211_CHAN_NO_IR | + IEEE80211_CHAN_DISABLED)) + continue; + + local->int_scan_req->channels[n_ch] = tmp_ch; + n_ch++; + } + } + + if (WARN_ON_ONCE(n_ch == 0)) + goto unlock; + + local->int_scan_req->n_channels = n_ch; + } else { + for (i = 0; i < n_channels; i++) { + if (channels[i]->flags & (IEEE80211_CHAN_NO_IR | + IEEE80211_CHAN_DISABLED)) + continue; + + local->int_scan_req->channels[n_ch] = channels[i]; + n_ch++; + } + + if (WARN_ON_ONCE(n_ch == 0)) + goto unlock; + + 
local->int_scan_req->n_channels = n_ch; + } + + local->int_scan_req->ssids = &local->scan_ssid; + local->int_scan_req->n_ssids = 1; + local->int_scan_req->scan_width = scan_width; + memcpy(local->int_scan_req->ssids[0].ssid, ssid, IEEE80211_MAX_SSID_LEN); + local->int_scan_req->ssids[0].ssid_len = ssid_len; + + ret = __ieee80211_start_scan(sdata, sdata->local->int_scan_req); + unlock: + mutex_unlock(&local->mtx); + return ret; +} + +/* + * Only call this function when a scan can't be queued -- under RTNL. + */ +void ieee80211_scan_cancel(struct ieee80211_local *local) +{ + /* + * We are canceling software scan, or deferred scan that was not + * yet really started (see __ieee80211_start_scan ). + * + * Regarding hardware scan: + * - we can not call __ieee80211_scan_completed() as when + * SCAN_HW_SCANNING bit is set this function change + * local->hw_scan_req to operate on 5G band, what race with + * driver which can use local->hw_scan_req + * + * - we can not cancel scan_work since driver can schedule it + * by ieee80211_scan_completed(..., true) to finish scan + * + * Hence we only call the cancel_hw_scan() callback, but the low-level + * driver is still responsible for calling ieee80211_scan_completed() + * after the scan was completed/aborted. + */ + + mutex_lock(&local->mtx); + if (!local->scan_req) + goto out; + + /* + * We have a scan running and the driver already reported completion, + * but the worker hasn't run yet or is stuck on the mutex - mark it as + * cancelled. + */ + if (test_bit(SCAN_HW_SCANNING, &local->scanning) && + test_bit(SCAN_COMPLETED, &local->scanning)) { + set_bit(SCAN_HW_CANCELLED, &local->scanning); + goto out; + } + + if (test_bit(SCAN_HW_SCANNING, &local->scanning)) { + /* + * Make sure that __ieee80211_scan_completed doesn't trigger a + * scan on another band. 
+ */ + set_bit(SCAN_HW_CANCELLED, &local->scanning); + if (local->ops->cancel_hw_scan) + drv_cancel_hw_scan(local, + rcu_dereference_protected(local->scan_sdata, + lockdep_is_held(&local->mtx))); + goto out; + } + + wiphy_delayed_work_cancel(local->hw.wiphy, &local->scan_work); + /* and clean up */ + memset(&local->scan_info, 0, sizeof(local->scan_info)); + __ieee80211_scan_completed(&local->hw, true); +out: + mutex_unlock(&local->mtx); +} + +int __ieee80211_request_sched_scan_start(struct ieee80211_sub_if_data *sdata, + struct cfg80211_sched_scan_request *req) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_scan_ies sched_scan_ies = {}; + struct cfg80211_chan_def chandef; + int ret, i, iebufsz, num_bands = 0; + u32 rate_masks[NUM_NL80211_BANDS] = {}; + u8 bands_used = 0; + u8 *ie; + u32 flags = 0; + + iebufsz = local->scan_ies_len + req->ie_len; + + lockdep_assert_held(&local->mtx); + + if (!local->ops->sched_scan_start) + return -ENOTSUPP; + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + if (local->hw.wiphy->bands[i]) { + bands_used |= BIT(i); + rate_masks[i] = (u32) -1; + num_bands++; + } + } + + if (req->flags & NL80211_SCAN_FLAG_MIN_PREQ_CONTENT) + flags |= IEEE80211_PROBE_FLAG_MIN_CONTENT; + + ie = kcalloc(iebufsz, num_bands, GFP_KERNEL); + if (!ie) { + ret = -ENOMEM; + goto out; + } + + ieee80211_prepare_scan_chandef(&chandef, req->scan_width); + + ieee80211_build_preq_ies(sdata, ie, num_bands * iebufsz, + &sched_scan_ies, req->ie, + req->ie_len, bands_used, rate_masks, &chandef, + flags); + + ret = drv_sched_scan_start(local, sdata, req, &sched_scan_ies); + if (ret == 0) { + rcu_assign_pointer(local->sched_scan_sdata, sdata); + rcu_assign_pointer(local->sched_scan_req, req); + } + + kfree(ie); + +out: + if (ret) { + /* Clean in case of failure after HW restart or upon resume. */ + RCU_INIT_POINTER(local->sched_scan_sdata, NULL); + RCU_INIT_POINTER(local->sched_scan_req, NULL); + } + + return ret; +} + +int ieee80211_request_sched_scan_start(struct ieee80211_sub_if_data *sdata, + struct cfg80211_sched_scan_request *req) +{ + struct ieee80211_local *local = sdata->local; + int ret; + + mutex_lock(&local->mtx); + + if (rcu_access_pointer(local->sched_scan_sdata)) { + mutex_unlock(&local->mtx); + return -EBUSY; + } + + ret = __ieee80211_request_sched_scan_start(sdata, req); + + mutex_unlock(&local->mtx); + return ret; +} + +int ieee80211_request_sched_scan_stop(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sched_scan_sdata; + int ret = -ENOENT; + + mutex_lock(&local->mtx); + + if (!local->ops->sched_scan_stop) { + ret = -ENOTSUPP; + goto out; + } + + /* We don't want to restart sched scan anymore. 
*/ + RCU_INIT_POINTER(local->sched_scan_req, NULL); + + sched_scan_sdata = rcu_dereference_protected(local->sched_scan_sdata, + lockdep_is_held(&local->mtx)); + if (sched_scan_sdata) { + ret = drv_sched_scan_stop(local, sched_scan_sdata); + if (!ret) + RCU_INIT_POINTER(local->sched_scan_sdata, NULL); + } +out: + mutex_unlock(&local->mtx); + + return ret; +} + +void ieee80211_sched_scan_results(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_sched_scan_results(local); + + cfg80211_sched_scan_results(hw->wiphy, 0); +} +EXPORT_SYMBOL(ieee80211_sched_scan_results); + +void ieee80211_sched_scan_end(struct ieee80211_local *local) +{ + mutex_lock(&local->mtx); + + if (!rcu_access_pointer(local->sched_scan_sdata)) { + mutex_unlock(&local->mtx); + return; + } + + RCU_INIT_POINTER(local->sched_scan_sdata, NULL); + + /* If sched scan was aborted by the driver. */ + RCU_INIT_POINTER(local->sched_scan_req, NULL); + + mutex_unlock(&local->mtx); + + cfg80211_sched_scan_stopped_locked(local->hw.wiphy, 0); +} + +void ieee80211_sched_scan_stopped_work(struct wiphy *wiphy, + struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, + sched_scan_stopped_work); + + ieee80211_sched_scan_end(local); +} + +void ieee80211_sched_scan_stopped(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_sched_scan_stopped(local); + + /* + * this shouldn't really happen, so for simplicity + * simply ignore it, and let mac80211 reconfigure + * the sched scan later on. + */ + if (local->in_reconfig) + return; + + wiphy_work_queue(hw->wiphy, &local->sched_scan_stopped_work); +} +EXPORT_SYMBOL(ieee80211_sched_scan_stopped); diff --git a/net/mac80211/spectmgmt.c b/net/mac80211/spectmgmt.c new file mode 100644 index 000000000..871cdac2d --- /dev/null +++ b/net/mac80211/spectmgmt.c @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * spectrum management + * + * Copyright 2003, Jouni Malinen + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. 
+ * Copyright 2006-2007 Jiri Benc + * Copyright 2007, Michael Wu + * Copyright 2007-2008, Intel Corporation + * Copyright 2008, Johannes Berg + * Copyright (C) 2018, 2020, 2022 Intel Corporation + */ + +#include +#include +#include +#include "ieee80211_i.h" +#include "sta_info.h" +#include "wme.h" + +int ieee80211_parse_ch_switch_ie(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, + enum nl80211_band current_band, + u32 vht_cap_info, + ieee80211_conn_flags_t conn_flags, u8 *bssid, + struct ieee80211_csa_ie *csa_ie) +{ + enum nl80211_band new_band = current_band; + int new_freq; + u8 new_chan_no; + struct ieee80211_channel *new_chan; + struct cfg80211_chan_def new_vht_chandef = {}; + const struct ieee80211_sec_chan_offs_ie *sec_chan_offs; + const struct ieee80211_wide_bw_chansw_ie *wide_bw_chansw_ie; + int secondary_channel_offset = -1; + + memset(csa_ie, 0, sizeof(*csa_ie)); + + sec_chan_offs = elems->sec_chan_offs; + wide_bw_chansw_ie = elems->wide_bw_chansw_ie; + + if (conn_flags & (IEEE80211_CONN_DISABLE_HT | + IEEE80211_CONN_DISABLE_40MHZ)) { + sec_chan_offs = NULL; + wide_bw_chansw_ie = NULL; + } + + if (conn_flags & IEEE80211_CONN_DISABLE_VHT) + wide_bw_chansw_ie = NULL; + + if (elems->ext_chansw_ie) { + if (!ieee80211_operating_class_to_band( + elems->ext_chansw_ie->new_operating_class, + &new_band)) { + sdata_info(sdata, + "cannot understand ECSA IE operating class, %d, ignoring\n", + elems->ext_chansw_ie->new_operating_class); + } + new_chan_no = elems->ext_chansw_ie->new_ch_num; + csa_ie->count = elems->ext_chansw_ie->count; + csa_ie->mode = elems->ext_chansw_ie->mode; + } else if (elems->ch_switch_ie) { + new_chan_no = elems->ch_switch_ie->new_ch_num; + csa_ie->count = elems->ch_switch_ie->count; + csa_ie->mode = elems->ch_switch_ie->mode; + } else { + /* nothing here we understand */ + return 1; + } + + /* Mesh Channel Switch Parameters Element */ + if (elems->mesh_chansw_params_ie) { + csa_ie->ttl = elems->mesh_chansw_params_ie->mesh_ttl; + csa_ie->mode = elems->mesh_chansw_params_ie->mesh_flags; + csa_ie->pre_value = le16_to_cpu( + elems->mesh_chansw_params_ie->mesh_pre_value); + + if (elems->mesh_chansw_params_ie->mesh_flags & + WLAN_EID_CHAN_SWITCH_PARAM_REASON) + csa_ie->reason_code = le16_to_cpu( + elems->mesh_chansw_params_ie->mesh_reason); + } + + new_freq = ieee80211_channel_to_frequency(new_chan_no, new_band); + new_chan = ieee80211_get_channel(sdata->local->hw.wiphy, new_freq); + if (!new_chan || new_chan->flags & IEEE80211_CHAN_DISABLED) { + sdata_info(sdata, + "BSS %pM switches to unsupported channel (%d MHz), disconnecting\n", + bssid, new_freq); + return -EINVAL; + } + + if (sec_chan_offs) { + secondary_channel_offset = sec_chan_offs->sec_chan_offs; + } else if (!(conn_flags & IEEE80211_CONN_DISABLE_HT)) { + /* If the secondary channel offset IE is not present, + * we can't know what's the post-CSA offset, so the + * best we can do is use 20MHz. 
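 * (When HT is disabled for this connection, on the other hand,
 * secondary_channel_offset stays at -1 and the switch below creates a
 * NL80211_CHAN_NO_HT chandef, keeping the current 5/10 MHz width where
 * applicable.)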
+ */ + secondary_channel_offset = IEEE80211_HT_PARAM_CHA_SEC_NONE; + } + + switch (secondary_channel_offset) { + default: + /* secondary_channel_offset was present but is invalid */ + case IEEE80211_HT_PARAM_CHA_SEC_NONE: + cfg80211_chandef_create(&csa_ie->chandef, new_chan, + NL80211_CHAN_HT20); + break; + case IEEE80211_HT_PARAM_CHA_SEC_ABOVE: + cfg80211_chandef_create(&csa_ie->chandef, new_chan, + NL80211_CHAN_HT40PLUS); + break; + case IEEE80211_HT_PARAM_CHA_SEC_BELOW: + cfg80211_chandef_create(&csa_ie->chandef, new_chan, + NL80211_CHAN_HT40MINUS); + break; + case -1: + cfg80211_chandef_create(&csa_ie->chandef, new_chan, + NL80211_CHAN_NO_HT); + /* keep width for 5/10 MHz channels */ + switch (sdata->vif.bss_conf.chandef.width) { + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + csa_ie->chandef.width = + sdata->vif.bss_conf.chandef.width; + break; + default: + break; + } + break; + } + + if (wide_bw_chansw_ie) { + u8 new_seg1 = wide_bw_chansw_ie->new_center_freq_seg1; + struct ieee80211_vht_operation vht_oper = { + .chan_width = + wide_bw_chansw_ie->new_channel_width, + .center_freq_seg0_idx = + wide_bw_chansw_ie->new_center_freq_seg0, + .center_freq_seg1_idx = new_seg1, + /* .basic_mcs_set doesn't matter */ + }; + struct ieee80211_ht_operation ht_oper = { + .operation_mode = + cpu_to_le16(new_seg1 << + IEEE80211_HT_OP_MODE_CCFS2_SHIFT), + }; + + /* default, for the case of IEEE80211_VHT_CHANWIDTH_USE_HT, + * to the previously parsed chandef + */ + new_vht_chandef = csa_ie->chandef; + + /* ignore if parsing fails */ + if (!ieee80211_chandef_vht_oper(&sdata->local->hw, + vht_cap_info, + &vht_oper, &ht_oper, + &new_vht_chandef)) + new_vht_chandef.chan = NULL; + + if (conn_flags & IEEE80211_CONN_DISABLE_80P80MHZ && + new_vht_chandef.width == NL80211_CHAN_WIDTH_80P80) + ieee80211_chandef_downgrade(&new_vht_chandef); + if (conn_flags & IEEE80211_CONN_DISABLE_160MHZ && + new_vht_chandef.width == NL80211_CHAN_WIDTH_160) + ieee80211_chandef_downgrade(&new_vht_chandef); + } + + /* if VHT data is there validate & use it */ + if (new_vht_chandef.chan) { + if (!cfg80211_chandef_compatible(&new_vht_chandef, + &csa_ie->chandef)) { + sdata_info(sdata, + "BSS %pM: CSA has inconsistent channel data, disconnecting\n", + bssid); + return -EINVAL; + } + csa_ie->chandef = new_vht_chandef; + } + + if (elems->max_channel_switch_time) + csa_ie->max_switch_time = + (elems->max_channel_switch_time[0] << 0) | + (elems->max_channel_switch_time[1] << 8) | + (elems->max_channel_switch_time[2] << 16); + + return 0; +} + +static void ieee80211_send_refuse_measurement_request(struct ieee80211_sub_if_data *sdata, + struct ieee80211_msrment_ie *request_ie, + const u8 *da, const u8 *bssid, + u8 dialog_token) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *msr_report; + + skb = dev_alloc_skb(sizeof(*msr_report) + local->hw.extra_tx_headroom + + sizeof(struct ieee80211_msrment_ie)); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + msr_report = skb_put_zero(skb, 24); + memcpy(msr_report->da, da, ETH_ALEN); + memcpy(msr_report->sa, sdata->vif.addr, ETH_ALEN); + memcpy(msr_report->bssid, bssid, ETH_ALEN); + msr_report->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + skb_put(skb, 1 + sizeof(msr_report->u.action.u.measurement)); + msr_report->u.action.category = WLAN_CATEGORY_SPECTRUM_MGMT; + msr_report->u.action.u.measurement.action_code = + WLAN_ACTION_SPCT_MSR_RPRT; + 
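	/*
	 * (Descriptive note: the measurement report element that follows is
	 * zero-initialised and only its token, type and the REFUSED bit in
	 * the mode field are filled in below, i.e. the request is answered
	 * but the measurement itself is declined.)
	 */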
msr_report->u.action.u.measurement.dialog_token = dialog_token; + + msr_report->u.action.u.measurement.element_id = WLAN_EID_MEASURE_REPORT; + msr_report->u.action.u.measurement.length = + sizeof(struct ieee80211_msrment_ie); + + memset(&msr_report->u.action.u.measurement.msr_elem, 0, + sizeof(struct ieee80211_msrment_ie)); + msr_report->u.action.u.measurement.msr_elem.token = request_ie->token; + msr_report->u.action.u.measurement.msr_elem.mode |= + IEEE80211_SPCT_MSR_RPRT_MODE_REFUSED; + msr_report->u.action.u.measurement.msr_elem.type = request_ie->type; + + ieee80211_tx_skb(sdata, skb); +} + +void ieee80211_process_measurement_req(struct ieee80211_sub_if_data *sdata, + struct ieee80211_mgmt *mgmt, + size_t len) +{ + /* + * Ignoring measurement request is spec violation. + * Mandatory measurements must be reported optional + * measurements might be refused or reported incapable + * For now just refuse + * TODO: Answer basic measurement as unmeasured + */ + ieee80211_send_refuse_measurement_request(sdata, + &mgmt->u.action.u.measurement.msr_elem, + mgmt->sa, mgmt->bssid, + mgmt->u.action.u.measurement.dialog_token); +} diff --git a/net/mac80211/sta_info.c b/net/mac80211/sta_info.c new file mode 100644 index 000000000..f3d6c3e4c --- /dev/null +++ b/net/mac80211/sta_info.c @@ -0,0 +1,2935 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2015 - 2017 Intel Deutschland GmbH + * Copyright (C) 2018-2021 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "sta_info.h" +#include "debugfs_sta.h" +#include "mesh.h" +#include "wme.h" + +/** + * DOC: STA information lifetime rules + * + * STA info structures (&struct sta_info) are managed in a hash table + * for faster lookup and a list for iteration. They are managed using + * RCU, i.e. access to the list and hash table is protected by RCU. + * + * Upon allocating a STA info structure with sta_info_alloc(), the caller + * owns that structure. It must then insert it into the hash table using + * either sta_info_insert() or sta_info_insert_rcu(); only in the latter + * case (which acquires an rcu read section but must not be called from + * within one) will the pointer still be valid after the call. Note that + * the caller may not do much with the STA info before inserting it, in + * particular, it may not start any mesh peer link management or add + * encryption keys. + * + * When the insertion fails (sta_info_insert()) returns non-zero), the + * structure will have been freed by sta_info_insert()! + * + * Station entries are added by mac80211 when you establish a link with a + * peer. This means different things for the different type of interfaces + * we support. For a regular station this mean we add the AP sta when we + * receive an association response from the AP. For IBSS this occurs when + * get to know about a peer on the same IBSS. For WDS we add the sta for + * the peer immediately upon device open. When using AP mode we add stations + * for each respective station upon request from userspace through nl80211. + * + * In order to remove a STA info structure, various sta_info_destroy_*() + * calls are available. 
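 *
 * As a rough illustration (an editorial sketch, with error handling and
 * configuration elided), the resulting call sequence for a simple station
 * therefore looks like:
 *
 *	sta = sta_info_alloc(sdata, addr, GFP_KERNEL);
 *	... set up rates, capabilities, flags ...
 *	err = sta_info_insert(sta);	(frees sta itself on failure)
 *	...
 *	sta_info_destroy_addr(sdata, addr);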
+ * + * There is no concept of ownership on a STA entry, each structure is + * owned by the global hash table/list until it is removed. All users of + * the structure need to be RCU protected so that the structure won't be + * freed before they are done using it. + */ + +struct sta_link_alloc { + struct link_sta_info info; + struct ieee80211_link_sta sta; + struct rcu_head rcu_head; +}; + +static const struct rhashtable_params sta_rht_params = { + .nelem_hint = 3, /* start small */ + .automatic_shrinking = true, + .head_offset = offsetof(struct sta_info, hash_node), + .key_offset = offsetof(struct sta_info, addr), + .key_len = ETH_ALEN, + .max_size = CONFIG_MAC80211_STA_HASH_MAX_SIZE, +}; + +static const struct rhashtable_params link_sta_rht_params = { + .nelem_hint = 3, /* start small */ + .automatic_shrinking = true, + .head_offset = offsetof(struct link_sta_info, link_hash_node), + .key_offset = offsetof(struct link_sta_info, addr), + .key_len = ETH_ALEN, + .max_size = CONFIG_MAC80211_STA_HASH_MAX_SIZE, +}; + +/* Caller must hold local->sta_mtx */ +static int sta_info_hash_del(struct ieee80211_local *local, + struct sta_info *sta) +{ + return rhltable_remove(&local->sta_hash, &sta->hash_node, + sta_rht_params); +} + +static int link_sta_info_hash_add(struct ieee80211_local *local, + struct link_sta_info *link_sta) +{ + lockdep_assert_held(&local->sta_mtx); + return rhltable_insert(&local->link_sta_hash, + &link_sta->link_hash_node, + link_sta_rht_params); +} + +static int link_sta_info_hash_del(struct ieee80211_local *local, + struct link_sta_info *link_sta) +{ + lockdep_assert_held(&local->sta_mtx); + return rhltable_remove(&local->link_sta_hash, + &link_sta->link_hash_node, + link_sta_rht_params); +} + +static void __cleanup_single_sta(struct sta_info *sta) +{ + int ac, i; + struct tid_ampdu_tx *tid_tx; + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct ps_data *ps; + + if (test_sta_flag(sta, WLAN_STA_PS_STA) || + test_sta_flag(sta, WLAN_STA_PS_DRIVER) || + test_sta_flag(sta, WLAN_STA_PS_DELIVER)) { + if (sta->sdata->vif.type == NL80211_IFTYPE_AP || + sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + ps = &sdata->bss->ps; + else if (ieee80211_vif_is_mesh(&sdata->vif)) + ps = &sdata->u.mesh.ps; + else + return; + + clear_sta_flag(sta, WLAN_STA_PS_STA); + clear_sta_flag(sta, WLAN_STA_PS_DRIVER); + clear_sta_flag(sta, WLAN_STA_PS_DELIVER); + + atomic_dec(&ps->num_sta_ps); + } + + if (sta->sta.txq[0]) { + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + struct txq_info *txqi; + + if (!sta->sta.txq[i]) + continue; + + txqi = to_txq_info(sta->sta.txq[i]); + + ieee80211_txq_purge(local, txqi); + } + } + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + local->total_ps_buffered -= skb_queue_len(&sta->ps_tx_buf[ac]); + ieee80211_purge_tx_queue(&local->hw, &sta->ps_tx_buf[ac]); + ieee80211_purge_tx_queue(&local->hw, &sta->tx_filtered[ac]); + } + + if (ieee80211_vif_is_mesh(&sdata->vif)) + mesh_sta_cleanup(sta); + + cancel_work_sync(&sta->drv_deliver_wk); + + /* + * Destroy aggregation state here. It would be nice to wait for the + * driver to finish aggregation stop and then clean up, but for now + * drivers have to handle aggregation stop being requested, followed + * directly by station destruction. 
+ */ + for (i = 0; i < IEEE80211_NUM_TIDS; i++) { + kfree(sta->ampdu_mlme.tid_start_tx[i]); + tid_tx = rcu_dereference_raw(sta->ampdu_mlme.tid_tx[i]); + if (!tid_tx) + continue; + ieee80211_purge_tx_queue(&local->hw, &tid_tx->pending); + kfree(tid_tx); + } +} + +static void cleanup_single_sta(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + + __cleanup_single_sta(sta); + sta_info_free(local, sta); +} + +struct rhlist_head *sta_info_hash_lookup(struct ieee80211_local *local, + const u8 *addr) +{ + return rhltable_lookup(&local->sta_hash, addr, sta_rht_params); +} + +/* protected by RCU */ +struct sta_info *sta_info_get(struct ieee80211_sub_if_data *sdata, + const u8 *addr) +{ + struct ieee80211_local *local = sdata->local; + struct rhlist_head *tmp; + struct sta_info *sta; + + rcu_read_lock(); + for_each_sta_info(local, addr, sta, tmp) { + if (sta->sdata == sdata) { + rcu_read_unlock(); + /* this is safe as the caller must already hold + * another rcu read section or the mutex + */ + return sta; + } + } + rcu_read_unlock(); + return NULL; +} + +/* + * Get sta info either from the specified interface + * or from one of its vlans + */ +struct sta_info *sta_info_get_bss(struct ieee80211_sub_if_data *sdata, + const u8 *addr) +{ + struct ieee80211_local *local = sdata->local; + struct rhlist_head *tmp; + struct sta_info *sta; + + rcu_read_lock(); + for_each_sta_info(local, addr, sta, tmp) { + if (sta->sdata == sdata || + (sta->sdata->bss && sta->sdata->bss == sdata->bss)) { + rcu_read_unlock(); + /* this is safe as the caller must already hold + * another rcu read section or the mutex + */ + return sta; + } + } + rcu_read_unlock(); + return NULL; +} + +struct rhlist_head *link_sta_info_hash_lookup(struct ieee80211_local *local, + const u8 *addr) +{ + return rhltable_lookup(&local->link_sta_hash, addr, + link_sta_rht_params); +} + +struct link_sta_info * +link_sta_info_get_bss(struct ieee80211_sub_if_data *sdata, const u8 *addr) +{ + struct ieee80211_local *local = sdata->local; + struct rhlist_head *tmp; + struct link_sta_info *link_sta; + + rcu_read_lock(); + for_each_link_sta_info(local, addr, link_sta, tmp) { + struct sta_info *sta = link_sta->sta; + + if (sta->sdata == sdata || + (sta->sdata->bss && sta->sdata->bss == sdata->bss)) { + rcu_read_unlock(); + /* this is safe as the caller must already hold + * another rcu read section or the mutex + */ + return link_sta; + } + } + rcu_read_unlock(); + return NULL; +} + +struct ieee80211_sta * +ieee80211_find_sta_by_link_addrs(struct ieee80211_hw *hw, + const u8 *addr, + const u8 *localaddr, + unsigned int *link_id) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct link_sta_info *link_sta; + struct rhlist_head *tmp; + + for_each_link_sta_info(local, addr, link_sta, tmp) { + struct sta_info *sta = link_sta->sta; + struct ieee80211_link_data *link; + u8 _link_id = link_sta->link_id; + + if (!localaddr) { + if (link_id) + *link_id = _link_id; + return &sta->sta; + } + + link = rcu_dereference(sta->sdata->link[_link_id]); + if (!link) + continue; + + if (memcmp(link->conf->addr, localaddr, ETH_ALEN)) + continue; + + if (link_id) + *link_id = _link_id; + return &sta->sta; + } + + return NULL; +} +EXPORT_SYMBOL_GPL(ieee80211_find_sta_by_link_addrs); + +struct sta_info *sta_info_get_by_addrs(struct ieee80211_local *local, + const u8 *sta_addr, const u8 *vif_addr) +{ + struct rhlist_head *tmp; + struct sta_info *sta; + + for_each_sta_info(local, sta_addr, sta, tmp) 
{ + if (ether_addr_equal(vif_addr, sta->sdata->vif.addr)) + return sta; + } + + return NULL; +} + +struct sta_info *sta_info_get_by_idx(struct ieee80211_sub_if_data *sdata, + int idx) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + int i = 0; + + list_for_each_entry_rcu(sta, &local->sta_list, list, + lockdep_is_held(&local->sta_mtx)) { + if (sdata != sta->sdata) + continue; + if (i < idx) { + ++i; + continue; + } + return sta; + } + + return NULL; +} + +static void sta_info_free_link(struct link_sta_info *link_sta) +{ + free_percpu(link_sta->pcpu_rx_stats); +} + +static void sta_remove_link(struct sta_info *sta, unsigned int link_id, + bool unhash) +{ + struct sta_link_alloc *alloc = NULL; + struct link_sta_info *link_sta; + + link_sta = rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&sta->local->sta_mtx)); + + if (WARN_ON(!link_sta)) + return; + + if (unhash) + link_sta_info_hash_del(sta->local, link_sta); + + if (link_sta != &sta->deflink) + alloc = container_of(link_sta, typeof(*alloc), info); + + sta->sta.valid_links &= ~BIT(link_id); + RCU_INIT_POINTER(sta->link[link_id], NULL); + RCU_INIT_POINTER(sta->sta.link[link_id], NULL); + if (alloc) { + sta_info_free_link(&alloc->info); + kfree_rcu(alloc, rcu_head); + } + + ieee80211_sta_recalc_aggregates(&sta->sta); +} + +/** + * sta_info_free - free STA + * + * @local: pointer to the global information + * @sta: STA info to free + * + * This function must undo everything done by sta_info_alloc() + * that may happen before sta_info_insert(). It may only be + * called when sta_info_insert() has not been attempted (and + * if that fails, the station is freed anyway.) + */ +void sta_info_free(struct ieee80211_local *local, struct sta_info *sta) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(sta->link); i++) { + struct link_sta_info *link_sta; + + link_sta = rcu_access_pointer(sta->link[i]); + if (!link_sta) + continue; + + sta_remove_link(sta, i, false); + } + + /* + * If we had used sta_info_pre_move_state() then we might not + * have gone through the state transitions down again, so do + * it here now (and warn if it's inserted). + * + * This will clear state such as fast TX/RX that may have been + * allocated during state transitions. 
+ */ + while (sta->sta_state > IEEE80211_STA_NONE) { + int ret; + + WARN_ON_ONCE(test_sta_flag(sta, WLAN_STA_INSERTED)); + + ret = sta_info_move_state(sta, sta->sta_state - 1); + if (WARN_ONCE(ret, "sta_info_move_state() returned %d\n", ret)) + break; + } + + if (sta->rate_ctrl) + rate_control_free_sta(sta); + + sta_dbg(sta->sdata, "Destroyed STA %pM\n", sta->sta.addr); + + if (sta->sta.txq[0]) + kfree(to_txq_info(sta->sta.txq[0])); + kfree(rcu_dereference_raw(sta->sta.rates)); +#ifdef CONFIG_MAC80211_MESH + kfree(sta->mesh); +#endif + + sta_info_free_link(&sta->deflink); + kfree(sta); +} + +/* Caller must hold local->sta_mtx */ +static int sta_info_hash_add(struct ieee80211_local *local, + struct sta_info *sta) +{ + return rhltable_insert(&local->sta_hash, &sta->hash_node, + sta_rht_params); +} + +static void sta_deliver_ps_frames(struct work_struct *wk) +{ + struct sta_info *sta; + + sta = container_of(wk, struct sta_info, drv_deliver_wk); + + if (sta->dead) + return; + + local_bh_disable(); + if (!test_sta_flag(sta, WLAN_STA_PS_STA)) + ieee80211_sta_ps_deliver_wakeup(sta); + else if (test_and_clear_sta_flag(sta, WLAN_STA_PSPOLL)) + ieee80211_sta_ps_deliver_poll_response(sta); + else if (test_and_clear_sta_flag(sta, WLAN_STA_UAPSD)) + ieee80211_sta_ps_deliver_uapsd(sta); + local_bh_enable(); +} + +static int sta_prepare_rate_control(struct ieee80211_local *local, + struct sta_info *sta, gfp_t gfp) +{ + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) + return 0; + + sta->rate_ctrl = local->rate_ctrl; + sta->rate_ctrl_priv = rate_control_alloc_sta(sta->rate_ctrl, + sta, gfp); + if (!sta->rate_ctrl_priv) + return -ENOMEM; + + return 0; +} + +static int sta_info_alloc_link(struct ieee80211_local *local, + struct link_sta_info *link_info, + gfp_t gfp) +{ + struct ieee80211_hw *hw = &local->hw; + int i; + + if (ieee80211_hw_check(hw, USES_RSS)) { + link_info->pcpu_rx_stats = + alloc_percpu_gfp(struct ieee80211_sta_rx_stats, gfp); + if (!link_info->pcpu_rx_stats) + return -ENOMEM; + } + + link_info->rx_stats.last_rx = jiffies; + u64_stats_init(&link_info->rx_stats.syncp); + + ewma_signal_init(&link_info->rx_stats_avg.signal); + ewma_avg_signal_init(&link_info->status_stats.avg_ack_signal); + for (i = 0; i < ARRAY_SIZE(link_info->rx_stats_avg.chain_signal); i++) + ewma_signal_init(&link_info->rx_stats_avg.chain_signal[i]); + + return 0; +} + +static void sta_info_add_link(struct sta_info *sta, + unsigned int link_id, + struct link_sta_info *link_info, + struct ieee80211_link_sta *link_sta) +{ + link_info->sta = sta; + link_info->link_id = link_id; + link_info->pub = link_sta; + link_sta->link_id = link_id; + rcu_assign_pointer(sta->link[link_id], link_info); + rcu_assign_pointer(sta->sta.link[link_id], link_sta); + + link_sta->smps_mode = IEEE80211_SMPS_OFF; + link_sta->agg.max_rc_amsdu_len = IEEE80211_MAX_MPDU_LEN_HT_BA; +} + +static struct sta_info * +__sta_info_alloc(struct ieee80211_sub_if_data *sdata, + const u8 *addr, int link_id, const u8 *link_addr, + gfp_t gfp) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_hw *hw = &local->hw; + struct sta_info *sta; + int i; + + sta = kzalloc(sizeof(*sta) + hw->sta_data_size, gfp); + if (!sta) + return NULL; + + sta->local = local; + sta->sdata = sdata; + + if (sta_info_alloc_link(local, &sta->deflink, gfp)) + goto free; + + if (link_id >= 0) { + sta_info_add_link(sta, link_id, &sta->deflink, + &sta->sta.deflink); + sta->sta.valid_links = BIT(link_id); + } else { + sta_info_add_link(sta, 0, &sta->deflink, 
&sta->sta.deflink); + } + + sta->sta.cur = &sta->sta.deflink.agg; + + spin_lock_init(&sta->lock); + spin_lock_init(&sta->ps_lock); + INIT_WORK(&sta->drv_deliver_wk, sta_deliver_ps_frames); + INIT_WORK(&sta->ampdu_mlme.work, ieee80211_ba_session_work); + mutex_init(&sta->ampdu_mlme.mtx); +#ifdef CONFIG_MAC80211_MESH + if (ieee80211_vif_is_mesh(&sdata->vif)) { + sta->mesh = kzalloc(sizeof(*sta->mesh), gfp); + if (!sta->mesh) + goto free; + sta->mesh->plink_sta = sta; + spin_lock_init(&sta->mesh->plink_lock); + if (!sdata->u.mesh.user_mpm) + timer_setup(&sta->mesh->plink_timer, mesh_plink_timer, + 0); + sta->mesh->nonpeer_pm = NL80211_MESH_POWER_ACTIVE; + } +#endif + + memcpy(sta->addr, addr, ETH_ALEN); + memcpy(sta->sta.addr, addr, ETH_ALEN); + memcpy(sta->deflink.addr, link_addr, ETH_ALEN); + memcpy(sta->sta.deflink.addr, link_addr, ETH_ALEN); + sta->sta.max_rx_aggregation_subframes = + local->hw.max_rx_aggregation_subframes; + + /* TODO link specific alloc and assignments for MLO Link STA */ + + /* Extended Key ID needs to install keys for keyid 0 and 1 Rx-only. + * The Tx path starts to use a key as soon as the key slot ptk_idx + * references to is not NULL. To not use the initial Rx-only key + * prematurely for Tx initialize ptk_idx to an impossible PTK keyid + * which always will refer to a NULL key. + */ + BUILD_BUG_ON(ARRAY_SIZE(sta->ptk) <= INVALID_PTK_KEYIDX); + sta->ptk_idx = INVALID_PTK_KEYIDX; + + + ieee80211_init_frag_cache(&sta->frags); + + sta->sta_state = IEEE80211_STA_NONE; + + /* Mark TID as unreserved */ + sta->reserved_tid = IEEE80211_TID_UNRESERVED; + + sta->last_connected = ktime_get_seconds(); + + if (local->ops->wake_tx_queue) { + void *txq_data; + int size = sizeof(struct txq_info) + + ALIGN(hw->txq_data_size, sizeof(void *)); + + txq_data = kcalloc(ARRAY_SIZE(sta->sta.txq), size, gfp); + if (!txq_data) + goto free; + + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + struct txq_info *txq = txq_data + i * size; + + /* might not do anything for the bufferable MMPDU TXQ */ + ieee80211_txq_init(sdata, sta, txq, i); + } + } + + if (sta_prepare_rate_control(local, sta, gfp)) + goto free_txq; + + sta->airtime_weight = IEEE80211_DEFAULT_AIRTIME_WEIGHT; + + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + skb_queue_head_init(&sta->ps_tx_buf[i]); + skb_queue_head_init(&sta->tx_filtered[i]); + sta->airtime[i].deficit = sta->airtime_weight; + atomic_set(&sta->airtime[i].aql_tx_pending, 0); + sta->airtime[i].aql_limit_low = local->aql_txq_limit_low[i]; + sta->airtime[i].aql_limit_high = local->aql_txq_limit_high[i]; + } + + for (i = 0; i < IEEE80211_NUM_TIDS; i++) + sta->last_seq_ctrl[i] = cpu_to_le16(USHRT_MAX); + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + u32 mandatory = 0; + int r; + + if (!hw->wiphy->bands[i]) + continue; + + switch (i) { + case NL80211_BAND_2GHZ: + case NL80211_BAND_LC: + /* + * We use both here, even if we cannot really know for + * sure the station will support both, but the only use + * for this is when we don't know anything yet and send + * management frames, and then we'll pick the lowest + * possible rate anyway. 
+ * If we don't include _G here, we cannot find a rate + * in P2P, and thus trigger the WARN_ONCE() in rate.c + */ + mandatory = IEEE80211_RATE_MANDATORY_B | + IEEE80211_RATE_MANDATORY_G; + break; + case NL80211_BAND_5GHZ: + mandatory = IEEE80211_RATE_MANDATORY_A; + break; + case NL80211_BAND_60GHZ: + WARN_ON(1); + mandatory = 0; + break; + } + + for (r = 0; r < hw->wiphy->bands[i]->n_bitrates; r++) { + struct ieee80211_rate *rate; + + rate = &hw->wiphy->bands[i]->bitrates[r]; + + if (!(rate->flags & mandatory)) + continue; + sta->sta.deflink.supp_rates[i] |= BIT(r); + } + } + + sta->cparams.ce_threshold = CODEL_DISABLED_THRESHOLD; + sta->cparams.target = MS2TIME(20); + sta->cparams.interval = MS2TIME(100); + sta->cparams.ecn = true; + sta->cparams.ce_threshold_selector = 0; + sta->cparams.ce_threshold_mask = 0; + + sta_dbg(sdata, "Allocated STA %pM\n", sta->sta.addr); + + return sta; + +free_txq: + if (sta->sta.txq[0]) + kfree(to_txq_info(sta->sta.txq[0])); +free: + sta_info_free_link(&sta->deflink); +#ifdef CONFIG_MAC80211_MESH + kfree(sta->mesh); +#endif + kfree(sta); + return NULL; +} + +struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata, + const u8 *addr, gfp_t gfp) +{ + return __sta_info_alloc(sdata, addr, -1, addr, gfp); +} + +struct sta_info *sta_info_alloc_with_link(struct ieee80211_sub_if_data *sdata, + const u8 *mld_addr, + unsigned int link_id, + const u8 *link_addr, + gfp_t gfp) +{ + return __sta_info_alloc(sdata, mld_addr, link_id, link_addr, gfp); +} + +static int sta_info_insert_check(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + + /* + * Can't be a WARN_ON because it can be triggered through a race: + * something inserts a STA (on one CPU) without holding the RTNL + * and another CPU turns off the net device. + */ + if (unlikely(!ieee80211_sdata_running(sdata))) + return -ENETDOWN; + + if (WARN_ON(ether_addr_equal(sta->sta.addr, sdata->vif.addr) || + !is_valid_ether_addr(sta->sta.addr))) + return -EINVAL; + + /* The RCU read lock is required by rhashtable due to + * asynchronous resize/rehash. We also require the mutex + * for correctness. + */ + rcu_read_lock(); + lockdep_assert_held(&sdata->local->sta_mtx); + if (ieee80211_hw_check(&sdata->local->hw, NEEDS_UNIQUE_STA_ADDR) && + ieee80211_find_sta_by_ifaddr(&sdata->local->hw, sta->addr, NULL)) { + rcu_read_unlock(); + return -ENOTUNIQ; + } + rcu_read_unlock(); + + return 0; +} + +static int sta_info_insert_drv_state(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + enum ieee80211_sta_state state; + int err = 0; + + for (state = IEEE80211_STA_NOTEXIST; state < sta->sta_state; state++) { + err = drv_sta_state(local, sdata, sta, state, state + 1); + if (err) + break; + } + + if (!err) { + /* + * Drivers using legacy sta_add/sta_remove callbacks only + * get uploaded set to true after sta_add is called. 
+ */ + if (!local->ops->sta_add) + sta->uploaded = true; + return 0; + } + + if (sdata->vif.type == NL80211_IFTYPE_ADHOC) { + sdata_info(sdata, + "failed to move IBSS STA %pM to state %d (%d) - keeping it anyway\n", + sta->sta.addr, state + 1, err); + err = 0; + } + + /* unwind on error */ + for (; state > IEEE80211_STA_NOTEXIST; state--) + WARN_ON(drv_sta_state(local, sdata, sta, state, state - 1)); + + return err; +} + +static void +ieee80211_recalc_p2p_go_ps_allowed(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + bool allow_p2p_go_ps = sdata->vif.p2p; + struct sta_info *sta; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata || + !test_sta_flag(sta, WLAN_STA_ASSOC)) + continue; + if (!sta->sta.support_p2p_ps) { + allow_p2p_go_ps = false; + break; + } + } + rcu_read_unlock(); + + if (allow_p2p_go_ps != sdata->vif.bss_conf.allow_p2p_go_ps) { + sdata->vif.bss_conf.allow_p2p_go_ps = allow_p2p_go_ps; + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_P2P_PS); + } +} + +/* + * should be called with sta_mtx locked + * this function replaces the mutex lock + * with a RCU lock + */ +static int sta_info_insert_finish(struct sta_info *sta) __acquires(RCU) +{ + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct station_info *sinfo = NULL; + int err = 0; + + lockdep_assert_held(&local->sta_mtx); + + /* check if STA exists already */ + if (sta_info_get_bss(sdata, sta->sta.addr)) { + err = -EEXIST; + goto out_cleanup; + } + + sinfo = kzalloc(sizeof(struct station_info), GFP_KERNEL); + if (!sinfo) { + err = -ENOMEM; + goto out_cleanup; + } + + local->num_sta++; + local->sta_generation++; + smp_mb(); + + /* simplify things and don't accept BA sessions yet */ + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + + /* make the station visible */ + err = sta_info_hash_add(local, sta); + if (err) + goto out_drop_sta; + + if (sta->sta.valid_links) { + err = link_sta_info_hash_add(local, &sta->deflink); + if (err) { + sta_info_hash_del(local, sta); + goto out_drop_sta; + } + } + + list_add_tail_rcu(&sta->list, &local->sta_list); + + /* update channel context before notifying the driver about state + * change, this enables driver using the updated channel context right away. 
+ */ + if (sta->sta_state >= IEEE80211_STA_ASSOC) { + ieee80211_recalc_min_chandef(sta->sdata, -1); + if (!sta->sta.support_p2p_ps) + ieee80211_recalc_p2p_go_ps_allowed(sta->sdata); + } + + /* notify driver */ + err = sta_info_insert_drv_state(local, sdata, sta); + if (err) + goto out_remove; + + set_sta_flag(sta, WLAN_STA_INSERTED); + + /* accept BA sessions now */ + clear_sta_flag(sta, WLAN_STA_BLOCK_BA); + + ieee80211_sta_debugfs_add(sta); + rate_control_add_sta_debugfs(sta); + + sinfo->generation = local->sta_generation; + cfg80211_new_sta(sdata->dev, sta->sta.addr, sinfo, GFP_KERNEL); + kfree(sinfo); + + sta_dbg(sdata, "Inserted STA %pM\n", sta->sta.addr); + + /* move reference to rcu-protected */ + rcu_read_lock(); + mutex_unlock(&local->sta_mtx); + + if (ieee80211_vif_is_mesh(&sdata->vif)) + mesh_accept_plinks_update(sdata); + + return 0; + out_remove: + if (sta->sta.valid_links) + link_sta_info_hash_del(local, &sta->deflink); + sta_info_hash_del(local, sta); + list_del_rcu(&sta->list); + out_drop_sta: + local->num_sta--; + synchronize_net(); + out_cleanup: + cleanup_single_sta(sta); + mutex_unlock(&local->sta_mtx); + kfree(sinfo); + rcu_read_lock(); + return err; +} + +int sta_info_insert_rcu(struct sta_info *sta) __acquires(RCU) +{ + struct ieee80211_local *local = sta->local; + int err; + + might_sleep(); + + mutex_lock(&local->sta_mtx); + + err = sta_info_insert_check(sta); + if (err) { + sta_info_free(local, sta); + mutex_unlock(&local->sta_mtx); + rcu_read_lock(); + return err; + } + + return sta_info_insert_finish(sta); +} + +int sta_info_insert(struct sta_info *sta) +{ + int err = sta_info_insert_rcu(sta); + + rcu_read_unlock(); + + return err; +} + +static inline void __bss_tim_set(u8 *tim, u16 id) +{ + /* + * This format has been mandated by the IEEE specifications, + * so this line may not be changed to use the __set_bit() format. + */ + tim[id / 8] |= (1 << (id % 8)); +} + +static inline void __bss_tim_clear(u8 *tim, u16 id) +{ + /* + * This format has been mandated by the IEEE specifications, + * so this line may not be changed to use the __clear_bit() format. + */ + tim[id / 8] &= ~(1 << (id % 8)); +} + +static inline bool __bss_tim_get(u8 *tim, u16 id) +{ + /* + * This format has been mandated by the IEEE specifications, + * so this line may not be changed to use the test_bit() format. 
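 * (Worked example: AID 13 lives in octet 13 / 8 = 1, bit 13 % 8 = 5, so
 * this test reads tim[1] & 0x20.)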
+ */ + return tim[id / 8] & (1 << (id % 8)); +} + +static unsigned long ieee80211_tids_for_ac(int ac) +{ + /* If we ever support TIDs > 7, this obviously needs to be adjusted */ + switch (ac) { + case IEEE80211_AC_VO: + return BIT(6) | BIT(7); + case IEEE80211_AC_VI: + return BIT(4) | BIT(5); + case IEEE80211_AC_BE: + return BIT(0) | BIT(3); + case IEEE80211_AC_BK: + return BIT(1) | BIT(2); + default: + WARN_ON(1); + return 0; + } +} + +static void __sta_info_recalc_tim(struct sta_info *sta, bool ignore_pending) +{ + struct ieee80211_local *local = sta->local; + struct ps_data *ps; + bool indicate_tim = false; + u8 ignore_for_tim = sta->sta.uapsd_queues; + int ac; + u16 id = sta->sta.aid; + + if (sta->sdata->vif.type == NL80211_IFTYPE_AP || + sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + if (WARN_ON_ONCE(!sta->sdata->bss)) + return; + + ps = &sta->sdata->bss->ps; +#ifdef CONFIG_MAC80211_MESH + } else if (ieee80211_vif_is_mesh(&sta->sdata->vif)) { + ps = &sta->sdata->u.mesh.ps; +#endif + } else { + return; + } + + /* No need to do anything if the driver does all */ + if (ieee80211_hw_check(&local->hw, AP_LINK_PS) && !local->ops->set_tim) + return; + + if (sta->dead) + goto done; + + /* + * If all ACs are delivery-enabled then we should build + * the TIM bit for all ACs anyway; if only some are then + * we ignore those and build the TIM bit using only the + * non-enabled ones. + */ + if (ignore_for_tim == BIT(IEEE80211_NUM_ACS) - 1) + ignore_for_tim = 0; + + if (ignore_pending) + ignore_for_tim = BIT(IEEE80211_NUM_ACS) - 1; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + unsigned long tids; + + if (ignore_for_tim & ieee80211_ac_to_qos_mask[ac]) + continue; + + indicate_tim |= !skb_queue_empty(&sta->tx_filtered[ac]) || + !skb_queue_empty(&sta->ps_tx_buf[ac]); + if (indicate_tim) + break; + + tids = ieee80211_tids_for_ac(ac); + + indicate_tim |= + sta->driver_buffered_tids & tids; + indicate_tim |= + sta->txq_buffered_tids & tids; + } + + done: + spin_lock_bh(&local->tim_lock); + + if (indicate_tim == __bss_tim_get(ps->tim, id)) + goto out_unlock; + + if (indicate_tim) + __bss_tim_set(ps->tim, id); + else + __bss_tim_clear(ps->tim, id); + + if (local->ops->set_tim && !WARN_ON(sta->dead)) { + local->tim_in_locked_section = true; + drv_set_tim(local, &sta->sta, indicate_tim); + local->tim_in_locked_section = false; + } + +out_unlock: + spin_unlock_bh(&local->tim_lock); +} + +void sta_info_recalc_tim(struct sta_info *sta) +{ + __sta_info_recalc_tim(sta, false); +} + +static bool sta_info_buffer_expired(struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_tx_info *info; + int timeout; + + if (!skb) + return false; + + info = IEEE80211_SKB_CB(skb); + + /* Timeout: (2 * listen_interval * beacon_int * 1024 / 1000000) sec */ + timeout = (sta->listen_interval * + sta->sdata->vif.bss_conf.beacon_int * + 32 / 15625) * HZ; + if (timeout < STA_TX_BUFFER_EXPIRE) + timeout = STA_TX_BUFFER_EXPIRE; + return time_after(jiffies, info->control.jiffies + timeout); +} + + +static bool sta_info_cleanup_expire_buffered_ac(struct ieee80211_local *local, + struct sta_info *sta, int ac) +{ + unsigned long flags; + struct sk_buff *skb; + + /* + * First check for frames that should expire on the filtered + * queue. Frames here were rejected by the driver and are on + * a separate queue to avoid reordering with normal PS-buffered + * frames. They also aren't accounted for right now in the + * total_ps_buffered counter. 
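 * (Expiry itself is decided by sta_info_buffer_expired() above: roughly
 * two listen intervals, since 2 * 1024 / 1000000 reduces to 32 / 15625,
 * but never less than STA_TX_BUFFER_EXPIRE.)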
+ */ + for (;;) { + spin_lock_irqsave(&sta->tx_filtered[ac].lock, flags); + skb = skb_peek(&sta->tx_filtered[ac]); + if (sta_info_buffer_expired(sta, skb)) + skb = __skb_dequeue(&sta->tx_filtered[ac]); + else + skb = NULL; + spin_unlock_irqrestore(&sta->tx_filtered[ac].lock, flags); + + /* + * Frames are queued in order, so if this one + * hasn't expired yet we can stop testing. If + * we actually reached the end of the queue we + * also need to stop, of course. + */ + if (!skb) + break; + ieee80211_free_txskb(&local->hw, skb); + } + + /* + * Now also check the normal PS-buffered queue, this will + * only find something if the filtered queue was emptied + * since the filtered frames are all before the normal PS + * buffered frames. + */ + for (;;) { + spin_lock_irqsave(&sta->ps_tx_buf[ac].lock, flags); + skb = skb_peek(&sta->ps_tx_buf[ac]); + if (sta_info_buffer_expired(sta, skb)) + skb = __skb_dequeue(&sta->ps_tx_buf[ac]); + else + skb = NULL; + spin_unlock_irqrestore(&sta->ps_tx_buf[ac].lock, flags); + + /* + * frames are queued in order, so if this one + * hasn't expired yet (or we reached the end of + * the queue) we can stop testing + */ + if (!skb) + break; + + local->total_ps_buffered--; + ps_dbg(sta->sdata, "Buffered frame expired (STA %pM)\n", + sta->sta.addr); + ieee80211_free_txskb(&local->hw, skb); + } + + /* + * Finally, recalculate the TIM bit for this station -- it might + * now be clear because the station was too slow to retrieve its + * frames. + */ + sta_info_recalc_tim(sta); + + /* + * Return whether there are any frames still buffered, this is + * used to check whether the cleanup timer still needs to run, + * if there are no frames we don't need to rearm the timer. + */ + return !(skb_queue_empty(&sta->ps_tx_buf[ac]) && + skb_queue_empty(&sta->tx_filtered[ac])); +} + +static bool sta_info_cleanup_expire_buffered(struct ieee80211_local *local, + struct sta_info *sta) +{ + bool have_buffered = false; + int ac; + + /* This is only necessary for stations on BSS/MBSS interfaces */ + if (!sta->sdata->bss && + !ieee80211_vif_is_mesh(&sta->sdata->vif)) + return false; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + have_buffered |= + sta_info_cleanup_expire_buffered_ac(local, sta, ac); + + return have_buffered; +} + +static int __must_check __sta_info_destroy_part1(struct sta_info *sta) +{ + struct ieee80211_local *local; + struct ieee80211_sub_if_data *sdata; + int ret, i; + + might_sleep(); + + if (!sta) + return -ENOENT; + + local = sta->local; + sdata = sta->sdata; + + lockdep_assert_held(&local->sta_mtx); + + /* + * Before removing the station from the driver and + * rate control, it might still start new aggregation + * sessions -- block that to make sure the tear-down + * will be sufficient. + */ + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + ieee80211_sta_tear_down_BA_sessions(sta, AGG_STOP_DESTROY_STA); + + /* + * Before removing the station from the driver there might be pending + * rx frames on RSS queues sent prior to the disassociation - wait for + * all such frames to be processed. + */ + drv_sync_rx_queues(local, sta); + + for (i = 0; i < ARRAY_SIZE(sta->link); i++) { + struct link_sta_info *link_sta; + + if (!(sta->sta.valid_links & BIT(i))) + continue; + + link_sta = rcu_dereference_protected(sta->link[i], + lockdep_is_held(&local->sta_mtx)); + + link_sta_info_hash_del(local, link_sta); + } + + ret = sta_info_hash_del(local, sta); + if (WARN_ON(ret)) + return ret; + + /* + * for TDLS peers, make sure to return to the base channel before + * removal. 
+ */ + if (test_sta_flag(sta, WLAN_STA_TDLS_OFF_CHANNEL)) { + drv_tdls_cancel_channel_switch(local, sdata, &sta->sta); + clear_sta_flag(sta, WLAN_STA_TDLS_OFF_CHANNEL); + } + + list_del_rcu(&sta->list); + sta->removed = true; + + if (sta->uploaded) + drv_sta_pre_rcu_remove(local, sta->sdata, sta); + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN && + rcu_access_pointer(sdata->u.vlan.sta) == sta) + RCU_INIT_POINTER(sdata->u.vlan.sta, NULL); + + return 0; +} + +static void __sta_info_destroy_part2(struct sta_info *sta) +{ + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct station_info *sinfo; + int ret; + + /* + * NOTE: This assumes at least synchronize_net() was done + * after _part1 and before _part2! + */ + + might_sleep(); + lockdep_assert_held(&local->sta_mtx); + + if (sta->sta_state == IEEE80211_STA_AUTHORIZED) { + ret = sta_info_move_state(sta, IEEE80211_STA_ASSOC); + WARN_ON_ONCE(ret); + } + + /* now keys can no longer be reached */ + ieee80211_free_sta_keys(local, sta); + + /* disable TIM bit - last chance to tell driver */ + __sta_info_recalc_tim(sta, true); + + sta->dead = true; + + local->num_sta--; + local->sta_generation++; + + while (sta->sta_state > IEEE80211_STA_NONE) { + ret = sta_info_move_state(sta, sta->sta_state - 1); + if (ret) { + WARN_ON_ONCE(1); + break; + } + } + + if (sta->uploaded) { + ret = drv_sta_state(local, sdata, sta, IEEE80211_STA_NONE, + IEEE80211_STA_NOTEXIST); + WARN_ON_ONCE(ret != 0); + } + + sta_dbg(sdata, "Removed STA %pM\n", sta->sta.addr); + + sinfo = kzalloc(sizeof(*sinfo), GFP_KERNEL); + if (sinfo) + sta_set_sinfo(sta, sinfo, true); + cfg80211_del_sta_sinfo(sdata->dev, sta->sta.addr, sinfo, GFP_KERNEL); + kfree(sinfo); + + ieee80211_sta_debugfs_remove(sta); + + ieee80211_destroy_frag_cache(&sta->frags); + + cleanup_single_sta(sta); +} + +int __must_check __sta_info_destroy(struct sta_info *sta) +{ + int err = __sta_info_destroy_part1(sta); + + if (err) + return err; + + synchronize_net(); + + __sta_info_destroy_part2(sta); + + return 0; +} + +int sta_info_destroy_addr(struct ieee80211_sub_if_data *sdata, const u8 *addr) +{ + struct sta_info *sta; + int ret; + + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get(sdata, addr); + ret = __sta_info_destroy(sta); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +int sta_info_destroy_addr_bss(struct ieee80211_sub_if_data *sdata, + const u8 *addr) +{ + struct sta_info *sta; + int ret; + + mutex_lock(&sdata->local->sta_mtx); + sta = sta_info_get_bss(sdata, addr); + ret = __sta_info_destroy(sta); + mutex_unlock(&sdata->local->sta_mtx); + + return ret; +} + +static void sta_info_cleanup(struct timer_list *t) +{ + struct ieee80211_local *local = from_timer(local, t, sta_cleanup); + struct sta_info *sta; + bool timer_needed = false; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) + if (sta_info_cleanup_expire_buffered(local, sta)) + timer_needed = true; + rcu_read_unlock(); + + if (local->quiescing) + return; + + if (!timer_needed) + return; + + mod_timer(&local->sta_cleanup, + round_jiffies(jiffies + STA_INFO_CLEANUP_INTERVAL)); +} + +int sta_info_init(struct ieee80211_local *local) +{ + int err; + + err = rhltable_init(&local->sta_hash, &sta_rht_params); + if (err) + return err; + + err = rhltable_init(&local->link_sta_hash, &link_sta_rht_params); + if (err) { + rhltable_destroy(&local->sta_hash); + return err; + } + + spin_lock_init(&local->tim_lock); + mutex_init(&local->sta_mtx); + 
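	/*
	 * (Note: both hash tables initialised above are keyed by a 6-byte
	 * MAC address -- sta_hash by the station/MLD address and
	 * link_sta_hash by the per-link address; see sta_rht_params and
	 * link_sta_rht_params.)
	 */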
INIT_LIST_HEAD(&local->sta_list); + + timer_setup(&local->sta_cleanup, sta_info_cleanup, 0); + return 0; +} + +void sta_info_stop(struct ieee80211_local *local) +{ + del_timer_sync(&local->sta_cleanup); + rhltable_destroy(&local->sta_hash); + rhltable_destroy(&local->link_sta_hash); +} + + +int __sta_info_flush(struct ieee80211_sub_if_data *sdata, bool vlans) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta, *tmp; + LIST_HEAD(free_list); + int ret = 0; + + might_sleep(); + + WARN_ON(vlans && sdata->vif.type != NL80211_IFTYPE_AP); + WARN_ON(vlans && !sdata->bss); + + mutex_lock(&local->sta_mtx); + list_for_each_entry_safe(sta, tmp, &local->sta_list, list) { + if (sdata == sta->sdata || + (vlans && sdata->bss == sta->sdata->bss)) { + if (!WARN_ON(__sta_info_destroy_part1(sta))) + list_add(&sta->free_list, &free_list); + ret++; + } + } + + if (!list_empty(&free_list)) { + synchronize_net(); + list_for_each_entry_safe(sta, tmp, &free_list, free_list) + __sta_info_destroy_part2(sta); + } + mutex_unlock(&local->sta_mtx); + + return ret; +} + +void ieee80211_sta_expire(struct ieee80211_sub_if_data *sdata, + unsigned long exp_time) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta, *tmp; + + mutex_lock(&local->sta_mtx); + + list_for_each_entry_safe(sta, tmp, &local->sta_list, list) { + unsigned long last_active = ieee80211_sta_last_active(sta); + + if (sdata != sta->sdata) + continue; + + if (time_is_before_jiffies(last_active + exp_time)) { + sta_dbg(sta->sdata, "expiring inactive STA %pM\n", + sta->sta.addr); + + if (ieee80211_vif_is_mesh(&sdata->vif) && + test_sta_flag(sta, WLAN_STA_PS_STA)) + atomic_dec(&sdata->u.mesh.ps.num_sta_ps); + + WARN_ON(__sta_info_destroy(sta)); + } + } + + mutex_unlock(&local->sta_mtx); +} + +struct ieee80211_sta *ieee80211_find_sta_by_ifaddr(struct ieee80211_hw *hw, + const u8 *addr, + const u8 *localaddr) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct rhlist_head *tmp; + struct sta_info *sta; + + /* + * Just return a random station if localaddr is NULL + * ... first in list. 
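 * (Note that the walk stops at the first matching entry: if that station
 * has not been uploaded to the driver yet, NULL is returned rather than
 * trying any further entries.)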
+ */ + for_each_sta_info(local, addr, sta, tmp) { + if (localaddr && + !ether_addr_equal(sta->sdata->vif.addr, localaddr)) + continue; + if (!sta->uploaded) + return NULL; + return &sta->sta; + } + + return NULL; +} +EXPORT_SYMBOL_GPL(ieee80211_find_sta_by_ifaddr); + +struct ieee80211_sta *ieee80211_find_sta(struct ieee80211_vif *vif, + const u8 *addr) +{ + struct sta_info *sta; + + if (!vif) + return NULL; + + sta = sta_info_get_bss(vif_to_sdata(vif), addr); + if (!sta) + return NULL; + + if (!sta->uploaded) + return NULL; + + return &sta->sta; +} +EXPORT_SYMBOL(ieee80211_find_sta); + +/* powersave support code */ +void ieee80211_sta_ps_deliver_wakeup(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct sk_buff_head pending; + int filtered = 0, buffered = 0, ac, i; + unsigned long flags; + struct ps_data *ps; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(sdata->bss, struct ieee80211_sub_if_data, + u.ap); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + ps = &sdata->bss->ps; + else if (ieee80211_vif_is_mesh(&sdata->vif)) + ps = &sdata->u.mesh.ps; + else + return; + + clear_sta_flag(sta, WLAN_STA_SP); + + BUILD_BUG_ON(BITS_TO_LONGS(IEEE80211_NUM_TIDS) > 1); + sta->driver_buffered_tids = 0; + sta->txq_buffered_tids = 0; + + if (!ieee80211_hw_check(&local->hw, AP_LINK_PS)) + drv_sta_notify(local, sdata, STA_NOTIFY_AWAKE, &sta->sta); + + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + if (!sta->sta.txq[i] || !txq_has_queue(sta->sta.txq[i])) + continue; + + schedule_and_wake_txq(local, to_txq_info(sta->sta.txq[i])); + } + + skb_queue_head_init(&pending); + + /* sync with ieee80211_tx_h_unicast_ps_buf */ + spin_lock(&sta->ps_lock); + /* Send all buffered frames to the station */ + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + int count = skb_queue_len(&pending), tmp; + + spin_lock_irqsave(&sta->tx_filtered[ac].lock, flags); + skb_queue_splice_tail_init(&sta->tx_filtered[ac], &pending); + spin_unlock_irqrestore(&sta->tx_filtered[ac].lock, flags); + tmp = skb_queue_len(&pending); + filtered += tmp - count; + count = tmp; + + spin_lock_irqsave(&sta->ps_tx_buf[ac].lock, flags); + skb_queue_splice_tail_init(&sta->ps_tx_buf[ac], &pending); + spin_unlock_irqrestore(&sta->ps_tx_buf[ac].lock, flags); + tmp = skb_queue_len(&pending); + buffered += tmp - count; + } + + ieee80211_add_pending_skbs(local, &pending); + + /* now we're no longer in the deliver code */ + clear_sta_flag(sta, WLAN_STA_PS_DELIVER); + + /* The station might have polled and then woken up before we responded, + * so clear these flags now to avoid them sticking around. 
+ */ + clear_sta_flag(sta, WLAN_STA_PSPOLL); + clear_sta_flag(sta, WLAN_STA_UAPSD); + spin_unlock(&sta->ps_lock); + + atomic_dec(&ps->num_sta_ps); + + local->total_ps_buffered -= buffered; + + sta_info_recalc_tim(sta); + + ps_dbg(sdata, + "STA %pM aid %d sending %d filtered/%d PS frames since STA woke up\n", + sta->sta.addr, sta->sta.aid, filtered, buffered); + + ieee80211_check_fast_xmit(sta); +} + +static void ieee80211_send_null_response(struct sta_info *sta, int tid, + enum ieee80211_frame_release_type reason, + bool call_driver, bool more_data) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_qos_hdr *nullfunc; + struct sk_buff *skb; + int size = sizeof(*nullfunc); + __le16 fc; + bool qos = sta->sta.wme; + struct ieee80211_tx_info *info; + struct ieee80211_chanctx_conf *chanctx_conf; + + if (qos) { + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | + IEEE80211_STYPE_QOS_NULLFUNC | + IEEE80211_FCTL_FROMDS); + } else { + size -= 2; + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | + IEEE80211_STYPE_NULLFUNC | + IEEE80211_FCTL_FROMDS); + } + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + size); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + nullfunc = skb_put(skb, size); + nullfunc->frame_control = fc; + nullfunc->duration_id = 0; + memcpy(nullfunc->addr1, sta->sta.addr, ETH_ALEN); + memcpy(nullfunc->addr2, sdata->vif.addr, ETH_ALEN); + memcpy(nullfunc->addr3, sdata->vif.addr, ETH_ALEN); + nullfunc->seq_ctrl = 0; + + skb->priority = tid; + skb_set_queue_mapping(skb, ieee802_1d_to_ac[tid]); + if (qos) { + nullfunc->qos_ctrl = cpu_to_le16(tid); + + if (reason == IEEE80211_FRAME_RELEASE_UAPSD) { + nullfunc->qos_ctrl |= + cpu_to_le16(IEEE80211_QOS_CTL_EOSP); + if (more_data) + nullfunc->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + } + } + + info = IEEE80211_SKB_CB(skb); + + /* + * Tell TX path to send this frame even though the + * STA may still remain is PS mode after this frame + * exchange. Also set EOSP to indicate this packet + * ends the poll/service period. + */ + info->flags |= IEEE80211_TX_CTL_NO_PS_BUFFER | + IEEE80211_TX_STATUS_EOSP | + IEEE80211_TX_CTL_REQ_TX_STATUS; + + info->control.flags |= IEEE80211_TX_CTRL_PS_RESPONSE; + + if (call_driver) + drv_allow_buffered_frames(local, sta, BIT(tid), 1, + reason, false); + + skb->dev = sdata->dev; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + kfree_skb(skb); + return; + } + + info->band = chanctx_conf->def.chan->band; + ieee80211_xmit(sdata, sta, skb); + rcu_read_unlock(); +} + +static int find_highest_prio_tid(unsigned long tids) +{ + /* lower 3 TIDs aren't ordered perfectly */ + if (tids & 0xF8) + return fls(tids) - 1; + /* TID 0 is BE just like TID 3 */ + if (tids & BIT(0)) + return 0; + return fls(tids) - 1; +} + +/* Indicates if the MORE_DATA bit should be set in the last + * frame obtained by ieee80211_sta_ps_get_frames. + * Note that driver_release_tids is relevant only if + * reason = IEEE80211_FRAME_RELEASE_PSPOLL + */ +static bool +ieee80211_sta_ps_more_data(struct sta_info *sta, u8 ignored_acs, + enum ieee80211_frame_release_type reason, + unsigned long driver_release_tids) +{ + int ac; + + /* If the driver has data on more than one TID then + * certainly there's more data if we release just a + * single frame now (from a single TID). This will + * only happen for PS-Poll. 
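 * (Example: with driver-buffered data on TIDs 3 and 6, a PS-Poll releases
 * only a single frame from one of them, so more data necessarily remains
 * and the MoreData bit must be set.)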
+ */ + if (reason == IEEE80211_FRAME_RELEASE_PSPOLL && + hweight16(driver_release_tids) > 1) + return true; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + if (ignored_acs & ieee80211_ac_to_qos_mask[ac]) + continue; + + if (!skb_queue_empty(&sta->tx_filtered[ac]) || + !skb_queue_empty(&sta->ps_tx_buf[ac])) + return true; + } + + return false; +} + +static void +ieee80211_sta_ps_get_frames(struct sta_info *sta, int n_frames, u8 ignored_acs, + enum ieee80211_frame_release_type reason, + struct sk_buff_head *frames, + unsigned long *driver_release_tids) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + int ac; + + /* Get response frame(s) and more data bit for the last one. */ + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + unsigned long tids; + + if (ignored_acs & ieee80211_ac_to_qos_mask[ac]) + continue; + + tids = ieee80211_tids_for_ac(ac); + + /* if we already have frames from software, then we can't also + * release from hardware queues + */ + if (skb_queue_empty(frames)) { + *driver_release_tids |= + sta->driver_buffered_tids & tids; + *driver_release_tids |= sta->txq_buffered_tids & tids; + } + + if (!*driver_release_tids) { + struct sk_buff *skb; + + while (n_frames > 0) { + skb = skb_dequeue(&sta->tx_filtered[ac]); + if (!skb) { + skb = skb_dequeue( + &sta->ps_tx_buf[ac]); + if (skb) + local->total_ps_buffered--; + } + if (!skb) + break; + n_frames--; + __skb_queue_tail(frames, skb); + } + } + + /* If we have more frames buffered on this AC, then abort the + * loop since we can't send more data from other ACs before + * the buffered frames from this. + */ + if (!skb_queue_empty(&sta->tx_filtered[ac]) || + !skb_queue_empty(&sta->ps_tx_buf[ac])) + break; + } +} + +static void +ieee80211_sta_ps_deliver_response(struct sta_info *sta, + int n_frames, u8 ignored_acs, + enum ieee80211_frame_release_type reason) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + unsigned long driver_release_tids = 0; + struct sk_buff_head frames; + bool more_data; + + /* Service or PS-Poll period starts */ + set_sta_flag(sta, WLAN_STA_SP); + + __skb_queue_head_init(&frames); + + ieee80211_sta_ps_get_frames(sta, n_frames, ignored_acs, reason, + &frames, &driver_release_tids); + + more_data = ieee80211_sta_ps_more_data(sta, ignored_acs, reason, driver_release_tids); + + if (driver_release_tids && reason == IEEE80211_FRAME_RELEASE_PSPOLL) + driver_release_tids = + BIT(find_highest_prio_tid(driver_release_tids)); + + if (skb_queue_empty(&frames) && !driver_release_tids) { + int tid, ac; + + /* + * For PS-Poll, this can only happen due to a race condition + * when we set the TIM bit and the station notices it, but + * before it can poll for the frame we expire it. + * + * For uAPSD, this is said in the standard (11.2.1.5 h): + * At each unscheduled SP for a non-AP STA, the AP shall + * attempt to transmit at least one MSDU or MMPDU, but no + * more than the value specified in the Max SP Length field + * in the QoS Capability element from delivery-enabled ACs, + * that are destined for the non-AP STA. + * + * Since we have no other MSDU/MMPDU, transmit a QoS null frame. + */ + + /* This will evaluate to 1, 3, 5 or 7. 
*/ + for (ac = IEEE80211_AC_VO; ac < IEEE80211_NUM_ACS; ac++) + if (!(ignored_acs & ieee80211_ac_to_qos_mask[ac])) + break; + tid = 7 - 2 * ac; + + ieee80211_send_null_response(sta, tid, reason, true, false); + } else if (!driver_release_tids) { + struct sk_buff_head pending; + struct sk_buff *skb; + int num = 0; + u16 tids = 0; + bool need_null = false; + + skb_queue_head_init(&pending); + + while ((skb = __skb_dequeue(&frames))) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (void *) skb->data; + u8 *qoshdr = NULL; + + num++; + + /* + * Tell TX path to send this frame even though the + * STA may still remain is PS mode after this frame + * exchange. + */ + info->flags |= IEEE80211_TX_CTL_NO_PS_BUFFER; + info->control.flags |= IEEE80211_TX_CTRL_PS_RESPONSE; + + /* + * Use MoreData flag to indicate whether there are + * more buffered frames for this STA + */ + if (more_data || !skb_queue_empty(&frames)) + hdr->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + else + hdr->frame_control &= + cpu_to_le16(~IEEE80211_FCTL_MOREDATA); + + if (ieee80211_is_data_qos(hdr->frame_control) || + ieee80211_is_qos_nullfunc(hdr->frame_control)) + qoshdr = ieee80211_get_qos_ctl(hdr); + + tids |= BIT(skb->priority); + + __skb_queue_tail(&pending, skb); + + /* end service period after last frame or add one */ + if (!skb_queue_empty(&frames)) + continue; + + if (reason != IEEE80211_FRAME_RELEASE_UAPSD) { + /* for PS-Poll, there's only one frame */ + info->flags |= IEEE80211_TX_STATUS_EOSP | + IEEE80211_TX_CTL_REQ_TX_STATUS; + break; + } + + /* For uAPSD, things are a bit more complicated. If the + * last frame has a QoS header (i.e. is a QoS-data or + * QoS-nulldata frame) then just set the EOSP bit there + * and be done. + * If the frame doesn't have a QoS header (which means + * it should be a bufferable MMPDU) then we can't set + * the EOSP bit in the QoS header; add a QoS-nulldata + * frame to the list to send it after the MMPDU. + * + * Note that this code is only in the mac80211-release + * code path, we assume that the driver will not buffer + * anything but QoS-data frames, or if it does, will + * create the QoS-nulldata frame by itself if needed. + * + * Cf. 802.11-2012 10.2.1.10 (c). + */ + if (qoshdr) { + *qoshdr |= IEEE80211_QOS_CTL_EOSP; + + info->flags |= IEEE80211_TX_STATUS_EOSP | + IEEE80211_TX_CTL_REQ_TX_STATUS; + } else { + /* The standard isn't completely clear on this + * as it says the more-data bit should be set + * if there are more BUs. The QoS-Null frame + * we're about to send isn't buffered yet, we + * only create it below, but let's pretend it + * was buffered just in case some clients only + * expect more-data=0 when eosp=1. + */ + hdr->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + need_null = true; + num++; + } + break; + } + + drv_allow_buffered_frames(local, sta, tids, num, + reason, more_data); + + ieee80211_add_pending_skbs(local, &pending); + + if (need_null) + ieee80211_send_null_response( + sta, find_highest_prio_tid(tids), + reason, false, false); + + sta_info_recalc_tim(sta); + } else { + int tid; + + /* + * We need to release a frame that is buffered somewhere in the + * driver ... it'll have to handle that. 
+ * Note that the driver also has to check the number of frames + * on the TIDs we're releasing from - if there are more than + * n_frames it has to set the more-data bit (if we didn't ask + * it to set it anyway due to other buffered frames); if there + * are fewer than n_frames it has to make sure to adjust that + * to allow the service period to end properly. + */ + drv_release_buffered_frames(local, sta, driver_release_tids, + n_frames, reason, more_data); + + /* + * Note that we don't recalculate the TIM bit here as it would + * most likely have no effect at all unless the driver told us + * that the TID(s) became empty before returning here from the + * release function. + * Either way, however, when the driver tells us that the TID(s) + * became empty or we find that a txq became empty, we'll do the + * TIM recalculation. + */ + + if (!sta->sta.txq[0]) + return; + + for (tid = 0; tid < ARRAY_SIZE(sta->sta.txq); tid++) { + if (!sta->sta.txq[tid] || + !(driver_release_tids & BIT(tid)) || + txq_has_queue(sta->sta.txq[tid])) + continue; + + sta_info_recalc_tim(sta); + break; + } + } +} + +void ieee80211_sta_ps_deliver_poll_response(struct sta_info *sta) +{ + u8 ignore_for_response = sta->sta.uapsd_queues; + + /* + * If all ACs are delivery-enabled then we should reply + * from any of them, if only some are enabled we reply + * only from the non-enabled ones. + */ + if (ignore_for_response == BIT(IEEE80211_NUM_ACS) - 1) + ignore_for_response = 0; + + ieee80211_sta_ps_deliver_response(sta, 1, ignore_for_response, + IEEE80211_FRAME_RELEASE_PSPOLL); +} + +void ieee80211_sta_ps_deliver_uapsd(struct sta_info *sta) +{ + int n_frames = sta->sta.max_sp; + u8 delivery_enabled = sta->sta.uapsd_queues; + + /* + * If we ever grow support for TSPEC this might happen if + * the TSPEC update from hostapd comes in between a trigger + * frame setting WLAN_STA_UAPSD in the RX path and this + * actually getting called. + */ + if (!delivery_enabled) + return; + + switch (sta->sta.max_sp) { + case 1: + n_frames = 2; + break; + case 2: + n_frames = 4; + break; + case 3: + n_frames = 6; + break; + case 0: + /* XXX: what is a good value? 
*/ + n_frames = 128; + break; + } + + ieee80211_sta_ps_deliver_response(sta, n_frames, ~delivery_enabled, + IEEE80211_FRAME_RELEASE_UAPSD); +} + +void ieee80211_sta_block_awake(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, bool block) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + trace_api_sta_block_awake(sta->local, pubsta, block); + + if (block) { + set_sta_flag(sta, WLAN_STA_PS_DRIVER); + ieee80211_clear_fast_xmit(sta); + return; + } + + if (!test_sta_flag(sta, WLAN_STA_PS_DRIVER)) + return; + + if (!test_sta_flag(sta, WLAN_STA_PS_STA)) { + set_sta_flag(sta, WLAN_STA_PS_DELIVER); + clear_sta_flag(sta, WLAN_STA_PS_DRIVER); + ieee80211_queue_work(hw, &sta->drv_deliver_wk); + } else if (test_sta_flag(sta, WLAN_STA_PSPOLL) || + test_sta_flag(sta, WLAN_STA_UAPSD)) { + /* must be asleep in this case */ + clear_sta_flag(sta, WLAN_STA_PS_DRIVER); + ieee80211_queue_work(hw, &sta->drv_deliver_wk); + } else { + clear_sta_flag(sta, WLAN_STA_PS_DRIVER); + ieee80211_check_fast_xmit(sta); + } +} +EXPORT_SYMBOL(ieee80211_sta_block_awake); + +void ieee80211_sta_eosp(struct ieee80211_sta *pubsta) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_local *local = sta->local; + + trace_api_eosp(local, pubsta); + + clear_sta_flag(sta, WLAN_STA_SP); +} +EXPORT_SYMBOL(ieee80211_sta_eosp); + +void ieee80211_send_eosp_nullfunc(struct ieee80211_sta *pubsta, int tid) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + enum ieee80211_frame_release_type reason; + bool more_data; + + trace_api_send_eosp_nullfunc(sta->local, pubsta, tid); + + reason = IEEE80211_FRAME_RELEASE_UAPSD; + more_data = ieee80211_sta_ps_more_data(sta, ~sta->sta.uapsd_queues, + reason, 0); + + ieee80211_send_null_response(sta, tid, reason, false, more_data); +} +EXPORT_SYMBOL(ieee80211_send_eosp_nullfunc); + +void ieee80211_sta_set_buffered(struct ieee80211_sta *pubsta, + u8 tid, bool buffered) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + if (WARN_ON(tid >= IEEE80211_NUM_TIDS)) + return; + + trace_api_sta_set_buffered(sta->local, pubsta, tid, buffered); + + if (buffered) + set_bit(tid, &sta->driver_buffered_tids); + else + clear_bit(tid, &sta->driver_buffered_tids); + + sta_info_recalc_tim(sta); +} +EXPORT_SYMBOL(ieee80211_sta_set_buffered); + +void ieee80211_sta_register_airtime(struct ieee80211_sta *pubsta, u8 tid, + u32 tx_airtime, u32 rx_airtime) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_local *local = sta->sdata->local; + u8 ac = ieee80211_ac_from_tid(tid); + u32 airtime = 0; + u32 diff; + + if (sta->local->airtime_flags & AIRTIME_USE_TX) + airtime += tx_airtime; + if (sta->local->airtime_flags & AIRTIME_USE_RX) + airtime += rx_airtime; + + spin_lock_bh(&local->active_txq_lock[ac]); + sta->airtime[ac].tx_airtime += tx_airtime; + sta->airtime[ac].rx_airtime += rx_airtime; + + diff = (u32)jiffies - sta->airtime[ac].last_active; + if (diff <= AIRTIME_ACTIVE_DURATION) + sta->airtime[ac].deficit -= airtime; + + spin_unlock_bh(&local->active_txq_lock[ac]); +} +EXPORT_SYMBOL(ieee80211_sta_register_airtime); + +void ieee80211_sta_recalc_aggregates(struct ieee80211_sta *pubsta) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_link_sta *link_sta; + int link_id, i; + bool first = true; + + if (!pubsta->valid_links || !pubsta->mlo) { + pubsta->cur = &pubsta->deflink.agg; + return; + } + + rcu_read_lock(); + 
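The two exported powersave hooks defined just above, ieee80211_sta_set_buffered() and ieee80211_sta_block_awake(), are meant to be called from drivers. The sketch below is an illustrative aside rather than part of the patch: the mydrv_ps_buffered_event() name and the firmware event that triggers it are assumptions, while the two mac80211 calls use the prototypes shown above.

#include <net/mac80211.h>

/* Hypothetical driver-side sketch: react to a firmware report that it
 * is (or is no longer) holding frames for a powersaving station.
 * ieee80211_sta_set_buffered() feeds the TIM recalculation above;
 * ieee80211_sta_block_awake() keeps mac80211 from transmitting while
 * the driver is still flushing its own queues.
 */
static void mydrv_ps_buffered_event(struct ieee80211_hw *hw,
				    struct ieee80211_sta *pubsta,
				    u8 tid, bool fw_has_frames)
{
	/* update mac80211's view of driver-buffered TIDs */
	ieee80211_sta_set_buffered(pubsta, tid, fw_has_frames);

	/* block delivery while the driver flushes its queues ... */
	ieee80211_sta_block_awake(hw, pubsta, true);

	/* ... driver/firmware specific flush would go here ... */

	/* ... then unblock; mac80211 delivers anything it buffered */
	ieee80211_sta_block_awake(hw, pubsta, false);
}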
for_each_sta_active_link(&sta->sdata->vif, pubsta, link_sta, link_id) { + if (first) { + sta->cur = pubsta->deflink.agg; + first = false; + continue; + } + + sta->cur.max_amsdu_len = + min(sta->cur.max_amsdu_len, + link_sta->agg.max_amsdu_len); + sta->cur.max_rc_amsdu_len = + min(sta->cur.max_rc_amsdu_len, + link_sta->agg.max_rc_amsdu_len); + + for (i = 0; i < ARRAY_SIZE(sta->cur.max_tid_amsdu_len); i++) + sta->cur.max_tid_amsdu_len[i] = + min(sta->cur.max_tid_amsdu_len[i], + link_sta->agg.max_tid_amsdu_len[i]); + } + rcu_read_unlock(); + + pubsta->cur = &sta->cur; +} +EXPORT_SYMBOL(ieee80211_sta_recalc_aggregates); + +void ieee80211_sta_update_pending_airtime(struct ieee80211_local *local, + struct sta_info *sta, u8 ac, + u16 tx_airtime, bool tx_completed) +{ + int tx_pending; + + if (!wiphy_ext_feature_isset(local->hw.wiphy, NL80211_EXT_FEATURE_AQL)) + return; + + if (!tx_completed) { + if (sta) + atomic_add(tx_airtime, + &sta->airtime[ac].aql_tx_pending); + + atomic_add(tx_airtime, &local->aql_total_pending_airtime); + atomic_add(tx_airtime, &local->aql_ac_pending_airtime[ac]); + return; + } + + if (sta) { + tx_pending = atomic_sub_return(tx_airtime, + &sta->airtime[ac].aql_tx_pending); + if (tx_pending < 0) + atomic_cmpxchg(&sta->airtime[ac].aql_tx_pending, + tx_pending, 0); + } + + atomic_sub(tx_airtime, &local->aql_total_pending_airtime); + tx_pending = atomic_sub_return(tx_airtime, + &local->aql_ac_pending_airtime[ac]); + if (WARN_ONCE(tx_pending < 0, + "Device %s AC %d pending airtime underflow: %u, %u", + wiphy_name(local->hw.wiphy), ac, tx_pending, + tx_airtime)) { + atomic_cmpxchg(&local->aql_ac_pending_airtime[ac], + tx_pending, 0); + atomic_sub(tx_pending, &local->aql_total_pending_airtime); + } +} + +int sta_info_move_state(struct sta_info *sta, + enum ieee80211_sta_state new_state) +{ + might_sleep(); + + if (sta->sta_state == new_state) + return 0; + + /* check allowed transitions first */ + + switch (new_state) { + case IEEE80211_STA_NONE: + if (sta->sta_state != IEEE80211_STA_AUTH) + return -EINVAL; + break; + case IEEE80211_STA_AUTH: + if (sta->sta_state != IEEE80211_STA_NONE && + sta->sta_state != IEEE80211_STA_ASSOC) + return -EINVAL; + break; + case IEEE80211_STA_ASSOC: + if (sta->sta_state != IEEE80211_STA_AUTH && + sta->sta_state != IEEE80211_STA_AUTHORIZED) + return -EINVAL; + break; + case IEEE80211_STA_AUTHORIZED: + if (sta->sta_state != IEEE80211_STA_ASSOC) + return -EINVAL; + break; + default: + WARN(1, "invalid state %d", new_state); + return -EINVAL; + } + + sta_dbg(sta->sdata, "moving STA %pM to state %d\n", + sta->sta.addr, new_state); + + /* + * notify the driver before the actual changes so it can + * fail the transition + */ + if (test_sta_flag(sta, WLAN_STA_INSERTED)) { + int err = drv_sta_state(sta->local, sta->sdata, sta, + sta->sta_state, new_state); + if (err) + return err; + } + + /* reflect the change in all state variables */ + + switch (new_state) { + case IEEE80211_STA_NONE: + if (sta->sta_state == IEEE80211_STA_AUTH) + clear_bit(WLAN_STA_AUTH, &sta->_flags); + break; + case IEEE80211_STA_AUTH: + if (sta->sta_state == IEEE80211_STA_NONE) { + set_bit(WLAN_STA_AUTH, &sta->_flags); + } else if (sta->sta_state == IEEE80211_STA_ASSOC) { + clear_bit(WLAN_STA_ASSOC, &sta->_flags); + ieee80211_recalc_min_chandef(sta->sdata, -1); + if (!sta->sta.support_p2p_ps) + ieee80211_recalc_p2p_go_ps_allowed(sta->sdata); + } + break; + case IEEE80211_STA_ASSOC: + if (sta->sta_state == IEEE80211_STA_AUTH) { + set_bit(WLAN_STA_ASSOC, &sta->_flags); + sta->assoc_at = 
ktime_get_boottime_ns(); + ieee80211_recalc_min_chandef(sta->sdata, -1); + if (!sta->sta.support_p2p_ps) + ieee80211_recalc_p2p_go_ps_allowed(sta->sdata); + } else if (sta->sta_state == IEEE80211_STA_AUTHORIZED) { + ieee80211_vif_dec_num_mcast(sta->sdata); + clear_bit(WLAN_STA_AUTHORIZED, &sta->_flags); + ieee80211_clear_fast_xmit(sta); + ieee80211_clear_fast_rx(sta); + } + break; + case IEEE80211_STA_AUTHORIZED: + if (sta->sta_state == IEEE80211_STA_ASSOC) { + ieee80211_vif_inc_num_mcast(sta->sdata); + set_bit(WLAN_STA_AUTHORIZED, &sta->_flags); + ieee80211_check_fast_xmit(sta); + ieee80211_check_fast_rx(sta); + } + if (sta->sdata->vif.type == NL80211_IFTYPE_AP_VLAN || + sta->sdata->vif.type == NL80211_IFTYPE_AP) + cfg80211_send_layer2_update(sta->sdata->dev, + sta->sta.addr); + break; + default: + break; + } + + sta->sta_state = new_state; + + return 0; +} + +static struct ieee80211_sta_rx_stats * +sta_get_last_rx_stats(struct sta_info *sta) +{ + struct ieee80211_sta_rx_stats *stats = &sta->deflink.rx_stats; + int cpu; + + if (!sta->deflink.pcpu_rx_stats) + return stats; + + for_each_possible_cpu(cpu) { + struct ieee80211_sta_rx_stats *cpustats; + + cpustats = per_cpu_ptr(sta->deflink.pcpu_rx_stats, cpu); + + if (time_after(cpustats->last_rx, stats->last_rx)) + stats = cpustats; + } + + return stats; +} + +static void sta_stats_decode_rate(struct ieee80211_local *local, u32 rate, + struct rate_info *rinfo) +{ + rinfo->bw = STA_STATS_GET(BW, rate); + + switch (STA_STATS_GET(TYPE, rate)) { + case STA_STATS_RATE_TYPE_VHT: + rinfo->flags = RATE_INFO_FLAGS_VHT_MCS; + rinfo->mcs = STA_STATS_GET(VHT_MCS, rate); + rinfo->nss = STA_STATS_GET(VHT_NSS, rate); + if (STA_STATS_GET(SGI, rate)) + rinfo->flags |= RATE_INFO_FLAGS_SHORT_GI; + break; + case STA_STATS_RATE_TYPE_HT: + rinfo->flags = RATE_INFO_FLAGS_MCS; + rinfo->mcs = STA_STATS_GET(HT_MCS, rate); + if (STA_STATS_GET(SGI, rate)) + rinfo->flags |= RATE_INFO_FLAGS_SHORT_GI; + break; + case STA_STATS_RATE_TYPE_LEGACY: { + struct ieee80211_supported_band *sband; + u16 brate; + unsigned int shift; + int band = STA_STATS_GET(LEGACY_BAND, rate); + int rate_idx = STA_STATS_GET(LEGACY_IDX, rate); + + sband = local->hw.wiphy->bands[band]; + + if (WARN_ON_ONCE(!sband->bitrates)) + break; + + brate = sband->bitrates[rate_idx].bitrate; + if (rinfo->bw == RATE_INFO_BW_5) + shift = 2; + else if (rinfo->bw == RATE_INFO_BW_10) + shift = 1; + else + shift = 0; + rinfo->legacy = DIV_ROUND_UP(brate, 1 << shift); + break; + } + case STA_STATS_RATE_TYPE_HE: + rinfo->flags = RATE_INFO_FLAGS_HE_MCS; + rinfo->mcs = STA_STATS_GET(HE_MCS, rate); + rinfo->nss = STA_STATS_GET(HE_NSS, rate); + rinfo->he_gi = STA_STATS_GET(HE_GI, rate); + rinfo->he_ru_alloc = STA_STATS_GET(HE_RU, rate); + rinfo->he_dcm = STA_STATS_GET(HE_DCM, rate); + break; + } +} + +static int sta_set_rate_info_rx(struct sta_info *sta, struct rate_info *rinfo) +{ + u32 rate = READ_ONCE(sta_get_last_rx_stats(sta)->last_rate); + + if (rate == STA_STATS_RATE_INVALID) + return -EINVAL; + + sta_stats_decode_rate(sta->local, rate, rinfo); + return 0; +} + +static inline u64 sta_get_tidstats_msdu(struct ieee80211_sta_rx_stats *rxstats, + int tid) +{ + unsigned int start; + u64 value; + + do { + start = u64_stats_fetch_begin_irq(&rxstats->syncp); + value = rxstats->msdu[tid]; + } while (u64_stats_fetch_retry_irq(&rxstats->syncp, start)); + + return value; +} + +static void sta_set_tidstats(struct sta_info *sta, + struct cfg80211_tid_stats *tidstats, + int tid) +{ + struct ieee80211_local *local = sta->local; + 
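sta_info_move_state() above accepts only transitions between neighbouring states (NONE <-> AUTH <-> ASSOC <-> AUTHORIZED) and returns -EINVAL otherwise, so callers step through the intermediate states one at a time. A minimal sketch of that pattern follows; the helper name is hypothetical, and it assumes the caller holds the usual station locking expected by sta_info_move_state().

/* Hypothetical helper: promote a station to AUTHORIZED by walking
 * through the intermediate states, since sta_info_move_state()
 * rejects transitions that skip a state.
 */
static int mydrv_authorize_sta(struct sta_info *sta)
{
	static const enum ieee80211_sta_state path[] = {
		IEEE80211_STA_AUTH,
		IEEE80211_STA_ASSOC,
		IEEE80211_STA_AUTHORIZED,
	};
	int i, ret;

	for (i = 0; i < ARRAY_SIZE(path); i++) {
		if (sta->sta_state >= path[i])
			continue; /* already at or beyond this state */

		ret = sta_info_move_state(sta, path[i]);
		if (ret) /* e.g. the driver vetoed the transition */
			return ret;
	}

	return 0;
}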
int cpu; + + if (!(tidstats->filled & BIT(NL80211_TID_STATS_RX_MSDU))) { + tidstats->rx_msdu += sta_get_tidstats_msdu(&sta->deflink.rx_stats, + tid); + + if (sta->deflink.pcpu_rx_stats) { + for_each_possible_cpu(cpu) { + struct ieee80211_sta_rx_stats *cpurxs; + + cpurxs = per_cpu_ptr(sta->deflink.pcpu_rx_stats, + cpu); + tidstats->rx_msdu += + sta_get_tidstats_msdu(cpurxs, tid); + } + } + + tidstats->filled |= BIT(NL80211_TID_STATS_RX_MSDU); + } + + if (!(tidstats->filled & BIT(NL80211_TID_STATS_TX_MSDU))) { + tidstats->filled |= BIT(NL80211_TID_STATS_TX_MSDU); + tidstats->tx_msdu = sta->deflink.tx_stats.msdu[tid]; + } + + if (!(tidstats->filled & BIT(NL80211_TID_STATS_TX_MSDU_RETRIES)) && + ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { + tidstats->filled |= BIT(NL80211_TID_STATS_TX_MSDU_RETRIES); + tidstats->tx_msdu_retries = sta->deflink.status_stats.msdu_retries[tid]; + } + + if (!(tidstats->filled & BIT(NL80211_TID_STATS_TX_MSDU_FAILED)) && + ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { + tidstats->filled |= BIT(NL80211_TID_STATS_TX_MSDU_FAILED); + tidstats->tx_msdu_failed = sta->deflink.status_stats.msdu_failed[tid]; + } + + if (local->ops->wake_tx_queue && tid < IEEE80211_NUM_TIDS) { + spin_lock_bh(&local->fq.lock); + rcu_read_lock(); + + tidstats->filled |= BIT(NL80211_TID_STATS_TXQ_STATS); + ieee80211_fill_txq_stats(&tidstats->txq_stats, + to_txq_info(sta->sta.txq[tid])); + + rcu_read_unlock(); + spin_unlock_bh(&local->fq.lock); + } +} + +static inline u64 sta_get_stats_bytes(struct ieee80211_sta_rx_stats *rxstats) +{ + unsigned int start; + u64 value; + + do { + start = u64_stats_fetch_begin_irq(&rxstats->syncp); + value = rxstats->bytes; + } while (u64_stats_fetch_retry_irq(&rxstats->syncp, start)); + + return value; +} + +void sta_set_sinfo(struct sta_info *sta, struct station_info *sinfo, + bool tidstats) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + u32 thr = 0; + int i, ac, cpu; + struct ieee80211_sta_rx_stats *last_rxstats; + + last_rxstats = sta_get_last_rx_stats(sta); + + sinfo->generation = sdata->local->sta_generation; + + /* do before driver, so beacon filtering drivers have a + * chance to e.g. 
just add the number of filtered beacons + * (or just modify the value entirely, of course) + */ + if (sdata->vif.type == NL80211_IFTYPE_STATION) + sinfo->rx_beacon = sdata->deflink.u.mgd.count_beacon_signal; + + drv_sta_statistics(local, sdata, &sta->sta, sinfo); + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_INACTIVE_TIME) | + BIT_ULL(NL80211_STA_INFO_STA_FLAGS) | + BIT_ULL(NL80211_STA_INFO_BSS_PARAM) | + BIT_ULL(NL80211_STA_INFO_CONNECTED_TIME) | + BIT_ULL(NL80211_STA_INFO_ASSOC_AT_BOOTTIME) | + BIT_ULL(NL80211_STA_INFO_RX_DROP_MISC); + + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + sinfo->beacon_loss_count = + sdata->deflink.u.mgd.beacon_loss_count; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_BEACON_LOSS); + } + + sinfo->connected_time = ktime_get_seconds() - sta->last_connected; + sinfo->assoc_at = sta->assoc_at; + sinfo->inactive_time = + jiffies_to_msecs(jiffies - ieee80211_sta_last_active(sta)); + + if (!(sinfo->filled & (BIT_ULL(NL80211_STA_INFO_TX_BYTES64) | + BIT_ULL(NL80211_STA_INFO_TX_BYTES)))) { + sinfo->tx_bytes = 0; + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->tx_bytes += sta->deflink.tx_stats.bytes[ac]; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_BYTES64); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_PACKETS))) { + sinfo->tx_packets = 0; + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->tx_packets += sta->deflink.tx_stats.packets[ac]; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_PACKETS); + } + + if (!(sinfo->filled & (BIT_ULL(NL80211_STA_INFO_RX_BYTES64) | + BIT_ULL(NL80211_STA_INFO_RX_BYTES)))) { + sinfo->rx_bytes += sta_get_stats_bytes(&sta->deflink.rx_stats); + + if (sta->deflink.pcpu_rx_stats) { + for_each_possible_cpu(cpu) { + struct ieee80211_sta_rx_stats *cpurxs; + + cpurxs = per_cpu_ptr(sta->deflink.pcpu_rx_stats, + cpu); + sinfo->rx_bytes += sta_get_stats_bytes(cpurxs); + } + } + + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_RX_BYTES64); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_RX_PACKETS))) { + sinfo->rx_packets = sta->deflink.rx_stats.packets; + if (sta->deflink.pcpu_rx_stats) { + for_each_possible_cpu(cpu) { + struct ieee80211_sta_rx_stats *cpurxs; + + cpurxs = per_cpu_ptr(sta->deflink.pcpu_rx_stats, + cpu); + sinfo->rx_packets += cpurxs->packets; + } + } + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_RX_PACKETS); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_RETRIES))) { + sinfo->tx_retries = sta->deflink.status_stats.retry_count; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_RETRIES); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_FAILED))) { + sinfo->tx_failed = sta->deflink.status_stats.retry_failed; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_FAILED); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_RX_DURATION))) { + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->rx_duration += sta->airtime[ac].rx_airtime; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_RX_DURATION); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_DURATION))) { + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + sinfo->tx_duration += sta->airtime[ac].tx_airtime; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_DURATION); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_AIRTIME_WEIGHT))) { + sinfo->airtime_weight = sta->airtime_weight; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_AIRTIME_WEIGHT); + } + + sinfo->rx_dropped_misc = sta->deflink.rx_stats.dropped; + if (sta->deflink.pcpu_rx_stats) { + for_each_possible_cpu(cpu) { + struct ieee80211_sta_rx_stats *cpurxs; + + cpurxs = 
per_cpu_ptr(sta->deflink.pcpu_rx_stats, cpu); + sinfo->rx_dropped_misc += cpurxs->dropped; + } + } + + if (sdata->vif.type == NL80211_IFTYPE_STATION && + !(sdata->vif.driver_flags & IEEE80211_VIF_BEACON_FILTER)) { + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_BEACON_RX) | + BIT_ULL(NL80211_STA_INFO_BEACON_SIGNAL_AVG); + sinfo->rx_beacon_signal_avg = ieee80211_ave_rssi(&sdata->vif); + } + + if (ieee80211_hw_check(&sta->local->hw, SIGNAL_DBM) || + ieee80211_hw_check(&sta->local->hw, SIGNAL_UNSPEC)) { + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_SIGNAL))) { + sinfo->signal = (s8)last_rxstats->last_signal; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_SIGNAL); + } + + if (!sta->deflink.pcpu_rx_stats && + !(sinfo->filled & BIT_ULL(NL80211_STA_INFO_SIGNAL_AVG))) { + sinfo->signal_avg = + -ewma_signal_read(&sta->deflink.rx_stats_avg.signal); + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_SIGNAL_AVG); + } + } + + /* for the average - if pcpu_rx_stats isn't set - rxstats must point to + * the sta->rx_stats struct, so the check here is fine with and without + * pcpu statistics + */ + if (last_rxstats->chains && + !(sinfo->filled & (BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL) | + BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL_AVG)))) { + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL); + if (!sta->deflink.pcpu_rx_stats) + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL_AVG); + + sinfo->chains = last_rxstats->chains; + + for (i = 0; i < ARRAY_SIZE(sinfo->chain_signal); i++) { + sinfo->chain_signal[i] = + last_rxstats->chain_signal_last[i]; + sinfo->chain_signal_avg[i] = + -ewma_signal_read(&sta->deflink.rx_stats_avg.chain_signal[i]); + } + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE)) && + !sta->sta.valid_links) { + sta_set_rate_info_tx(sta, &sta->deflink.tx_stats.last_rate, + &sinfo->txrate); + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_BITRATE); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_RX_BITRATE)) && + !sta->sta.valid_links) { + if (sta_set_rate_info_rx(sta, &sinfo->rxrate) == 0) + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_RX_BITRATE); + } + + if (tidstats && !cfg80211_sinfo_alloc_tid_stats(sinfo, GFP_KERNEL)) { + for (i = 0; i < IEEE80211_NUM_TIDS + 1; i++) + sta_set_tidstats(sta, &sinfo->pertid[i], i); + } + + if (ieee80211_vif_is_mesh(&sdata->vif)) { +#ifdef CONFIG_MAC80211_MESH + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_LLID) | + BIT_ULL(NL80211_STA_INFO_PLID) | + BIT_ULL(NL80211_STA_INFO_PLINK_STATE) | + BIT_ULL(NL80211_STA_INFO_LOCAL_PM) | + BIT_ULL(NL80211_STA_INFO_PEER_PM) | + BIT_ULL(NL80211_STA_INFO_NONPEER_PM) | + BIT_ULL(NL80211_STA_INFO_CONNECTED_TO_GATE) | + BIT_ULL(NL80211_STA_INFO_CONNECTED_TO_AS); + + sinfo->llid = sta->mesh->llid; + sinfo->plid = sta->mesh->plid; + sinfo->plink_state = sta->mesh->plink_state; + if (test_sta_flag(sta, WLAN_STA_TOFFSET_KNOWN)) { + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_T_OFFSET); + sinfo->t_offset = sta->mesh->t_offset; + } + sinfo->local_pm = sta->mesh->local_pm; + sinfo->peer_pm = sta->mesh->peer_pm; + sinfo->nonpeer_pm = sta->mesh->nonpeer_pm; + sinfo->connected_to_gate = sta->mesh->connected_to_gate; + sinfo->connected_to_as = sta->mesh->connected_to_as; +#endif + } + + sinfo->bss_param.flags = 0; + if (sdata->vif.bss_conf.use_cts_prot) + sinfo->bss_param.flags |= BSS_PARAM_FLAGS_CTS_PROT; + if (sdata->vif.bss_conf.use_short_preamble) + sinfo->bss_param.flags |= BSS_PARAM_FLAGS_SHORT_PREAMBLE; + if (sdata->vif.bss_conf.use_short_slot) + sinfo->bss_param.flags |= BSS_PARAM_FLAGS_SHORT_SLOT_TIME; + 
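Because every block in sta_set_sinfo() is guarded by a check on sinfo->filled, a driver can pre-fill individual fields from its sta_statistics callback (invoked via drv_sta_statistics() near the top of this function) and mac80211 will leave them alone. The sketch below is illustrative only: the mydrv_* name and the constant values are made up, and the prototype is modelled on the mac80211 sta_statistics driver operation.

#include <net/mac80211.h>

/* Hypothetical driver callback: report values the hardware measures
 * itself. sta_set_sinfo() skips any field whose NL80211_STA_INFO_*
 * bit is already set in sinfo->filled, so these are not overwritten.
 */
static void mydrv_sta_statistics(struct ieee80211_hw *hw,
				 struct ieee80211_vif *vif,
				 struct ieee80211_sta *sta,
				 struct station_info *sinfo)
{
	/* placeholder values; a real driver would read per-station
	 * counters from firmware here
	 */
	sinfo->signal = -55; /* dBm */
	sinfo->filled |= BIT_ULL(NL80211_STA_INFO_SIGNAL);

	sinfo->tx_failed = 0;
	sinfo->filled |= BIT_ULL(NL80211_STA_INFO_TX_FAILED);
}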
sinfo->bss_param.dtim_period = sdata->vif.bss_conf.dtim_period; + sinfo->bss_param.beacon_interval = sdata->vif.bss_conf.beacon_int; + + sinfo->sta_flags.set = 0; + sinfo->sta_flags.mask = BIT(NL80211_STA_FLAG_AUTHORIZED) | + BIT(NL80211_STA_FLAG_SHORT_PREAMBLE) | + BIT(NL80211_STA_FLAG_WME) | + BIT(NL80211_STA_FLAG_MFP) | + BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED) | + BIT(NL80211_STA_FLAG_TDLS_PEER); + if (test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_AUTHORIZED); + if (test_sta_flag(sta, WLAN_STA_SHORT_PREAMBLE)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_SHORT_PREAMBLE); + if (sta->sta.wme) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_WME); + if (test_sta_flag(sta, WLAN_STA_MFP)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_MFP); + if (test_sta_flag(sta, WLAN_STA_AUTH)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_AUTHENTICATED); + if (test_sta_flag(sta, WLAN_STA_ASSOC)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_ASSOCIATED); + if (test_sta_flag(sta, WLAN_STA_TDLS_PEER)) + sinfo->sta_flags.set |= BIT(NL80211_STA_FLAG_TDLS_PEER); + + thr = sta_get_expected_throughput(sta); + + if (thr != 0) { + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_EXPECTED_THROUGHPUT); + sinfo->expected_throughput = thr; + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_ACK_SIGNAL)) && + sta->deflink.status_stats.ack_signal_filled) { + sinfo->ack_signal = sta->deflink.status_stats.last_ack_signal; + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_ACK_SIGNAL); + } + + if (!(sinfo->filled & BIT_ULL(NL80211_STA_INFO_ACK_SIGNAL_AVG)) && + sta->deflink.status_stats.ack_signal_filled) { + sinfo->avg_ack_signal = + -(s8)ewma_avg_signal_read( + &sta->deflink.status_stats.avg_ack_signal); + sinfo->filled |= + BIT_ULL(NL80211_STA_INFO_ACK_SIGNAL_AVG); + } + + if (ieee80211_vif_is_mesh(&sdata->vif)) { + sinfo->filled |= BIT_ULL(NL80211_STA_INFO_AIRTIME_LINK_METRIC); + sinfo->airtime_link_metric = + airtime_link_metric_get(local, sta); + } +} + +u32 sta_get_expected_throughput(struct sta_info *sta) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + struct rate_control_ref *ref = NULL; + u32 thr = 0; + + if (test_sta_flag(sta, WLAN_STA_RATE_CONTROL)) + ref = local->rate_ctrl; + + /* check if the driver has a SW RC implementation */ + if (ref && ref->ops->get_expected_throughput) + thr = ref->ops->get_expected_throughput(sta->rate_ctrl_priv); + else + thr = drv_get_expected_throughput(local, sta); + + return thr; +} + +unsigned long ieee80211_sta_last_active(struct sta_info *sta) +{ + struct ieee80211_sta_rx_stats *stats = sta_get_last_rx_stats(sta); + + if (!sta->deflink.status_stats.last_ack || + time_after(stats->last_rx, sta->deflink.status_stats.last_ack)) + return stats->last_rx; + return sta->deflink.status_stats.last_ack; +} + +static void sta_update_codel_params(struct sta_info *sta, u32 thr) +{ + if (!sta->sdata->local->ops->wake_tx_queue) + return; + + if (thr && thr < STA_SLOW_THRESHOLD * sta->local->num_sta) { + sta->cparams.target = MS2TIME(50); + sta->cparams.interval = MS2TIME(300); + sta->cparams.ecn = false; + } else { + sta->cparams.target = MS2TIME(20); + sta->cparams.interval = MS2TIME(100); + sta->cparams.ecn = true; + } +} + +void ieee80211_sta_set_expected_throughput(struct ieee80211_sta *pubsta, + u32 thr) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + sta_update_codel_params(sta, thr); +} + +int ieee80211_sta_allocate_link(struct 
sta_info *sta, unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct sta_link_alloc *alloc; + int ret; + + lockdep_assert_held(&sdata->local->sta_mtx); + + /* must represent an MLD from the start */ + if (WARN_ON(!sta->sta.valid_links)) + return -EINVAL; + + if (WARN_ON(sta->sta.valid_links & BIT(link_id) || + sta->link[link_id])) + return -EBUSY; + + alloc = kzalloc(sizeof(*alloc), GFP_KERNEL); + if (!alloc) + return -ENOMEM; + + ret = sta_info_alloc_link(sdata->local, &alloc->info, GFP_KERNEL); + if (ret) { + kfree(alloc); + return ret; + } + + sta_info_add_link(sta, link_id, &alloc->info, &alloc->sta); + + return 0; +} + +void ieee80211_sta_free_link(struct sta_info *sta, unsigned int link_id) +{ + lockdep_assert_held(&sta->sdata->local->sta_mtx); + + sta_remove_link(sta, link_id, false); +} + +int ieee80211_sta_activate_link(struct sta_info *sta, unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct link_sta_info *link_sta; + u16 old_links = sta->sta.valid_links; + u16 new_links = old_links | BIT(link_id); + int ret; + + link_sta = rcu_dereference_protected(sta->link[link_id], + lockdep_is_held(&sdata->local->sta_mtx)); + + if (WARN_ON(old_links == new_links || !link_sta)) + return -EINVAL; + + rcu_read_lock(); + if (link_sta_info_hash_lookup(sdata->local, link_sta->addr)) { + rcu_read_unlock(); + return -EALREADY; + } + /* we only modify under the mutex so this is fine */ + rcu_read_unlock(); + + sta->sta.valid_links = new_links; + + if (!test_sta_flag(sta, WLAN_STA_INSERTED)) + goto hash; + + ieee80211_recalc_min_chandef(sdata, link_id); + + /* Ensure the values are updated for the driver, + * redone by sta_remove_link on failure. + */ + ieee80211_sta_recalc_aggregates(&sta->sta); + + ret = drv_change_sta_links(sdata->local, sdata, &sta->sta, + old_links, new_links); + if (ret) { + sta->sta.valid_links = old_links; + sta_remove_link(sta, link_id, false); + return ret; + } + +hash: + ret = link_sta_info_hash_add(sdata->local, link_sta); + WARN_ON(ret); + return 0; +} + +void ieee80211_sta_remove_link(struct sta_info *sta, unsigned int link_id) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u16 old_links = sta->sta.valid_links; + + lockdep_assert_held(&sdata->local->sta_mtx); + + sta->sta.valid_links &= ~BIT(link_id); + + if (test_sta_flag(sta, WLAN_STA_INSERTED)) + drv_change_sta_links(sdata->local, sdata, &sta->sta, + old_links, sta->sta.valid_links); + + sta_remove_link(sta, link_id, true); +} + +void ieee80211_sta_set_max_amsdu_subframes(struct sta_info *sta, + const u8 *ext_capab, + unsigned int ext_capab_len) +{ + u8 val; + + sta->sta.max_amsdu_subframes = 0; + + if (ext_capab_len < 8) + return; + + /* The sender might not have sent the last bit, consider it to be 0 */ + val = u8_get_bits(ext_capab[7], WLAN_EXT_CAPA8_MAX_MSDU_IN_AMSDU_LSB); + + /* we did get all the bits, take the MSB as well */ + if (ext_capab_len >= 9) + val |= u8_get_bits(ext_capab[8], + WLAN_EXT_CAPA9_MAX_MSDU_IN_AMSDU_MSB) << 1; + + if (val) + sta->sta.max_amsdu_subframes = 4 << (4 - val); +} + +#ifdef CONFIG_LOCKDEP +bool lockdep_sta_mutex_held(struct ieee80211_sta *pubsta) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + + return lockdep_is_held(&sta->local->sta_mtx); +} +EXPORT_SYMBOL(lockdep_sta_mutex_held); +#endif diff --git a/net/mac80211/sta_info.h b/net/mac80211/sta_info.h new file mode 100644 index 000000000..2517ea714 --- /dev/null +++ b/net/mac80211/sta_info.h @@ -0,0 +1,993 @@ +/* 
SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2005, Devicescape Software, Inc. + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright(c) 2015-2017 Intel Deutschland GmbH + * Copyright(c) 2020-2022 Intel Corporation + */ + +#ifndef STA_INFO_H +#define STA_INFO_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "key.h" + +/** + * enum ieee80211_sta_info_flags - Stations flags + * + * These flags are used with &struct sta_info's @flags member, but + * only indirectly with set_sta_flag() and friends. + * + * @WLAN_STA_AUTH: Station is authenticated. + * @WLAN_STA_ASSOC: Station is associated. + * @WLAN_STA_PS_STA: Station is in power-save mode + * @WLAN_STA_AUTHORIZED: Station is authorized to send/receive traffic. + * This bit is always checked so needs to be enabled for all stations + * when virtual port control is not in use. + * @WLAN_STA_SHORT_PREAMBLE: Station is capable of receiving short-preamble + * frames. + * @WLAN_STA_WDS: Station is one of our WDS peers. + * @WLAN_STA_CLEAR_PS_FILT: Clear PS filter in hardware (using the + * IEEE80211_TX_CTL_CLEAR_PS_FILT control flag) when the next + * frame to this station is transmitted. + * @WLAN_STA_MFP: Management frame protection is used with this STA. + * @WLAN_STA_BLOCK_BA: Used to deny ADDBA requests (both TX and RX) + * during suspend/resume and station removal. + * @WLAN_STA_PS_DRIVER: driver requires keeping this station in + * power-save mode logically to flush frames that might still + * be in the queues + * @WLAN_STA_PSPOLL: Station sent PS-poll while driver was keeping + * station in power-save mode, reply when the driver unblocks. + * @WLAN_STA_TDLS_PEER: Station is a TDLS peer. + * @WLAN_STA_TDLS_PEER_AUTH: This TDLS peer is authorized to send direct + * packets. This means the link is enabled. + * @WLAN_STA_TDLS_INITIATOR: We are the initiator of the TDLS link with this + * station. + * @WLAN_STA_TDLS_CHAN_SWITCH: This TDLS peer supports TDLS channel-switching + * @WLAN_STA_TDLS_OFF_CHANNEL: The local STA is currently off-channel with this + * TDLS peer + * @WLAN_STA_TDLS_WIDER_BW: This TDLS peer supports working on a wider bw on + * the BSS base channel. + * @WLAN_STA_UAPSD: Station requested unscheduled SP while driver was + * keeping station in power-save mode, reply when the driver + * unblocks the station. + * @WLAN_STA_SP: Station is in a service period, so don't try to + * reply to other uAPSD trigger frames or PS-Poll. + * @WLAN_STA_4ADDR_EVENT: 4-addr event was already sent for this frame. + * @WLAN_STA_INSERTED: This station is inserted into the hash table. + * @WLAN_STA_RATE_CONTROL: rate control was initialized for this station. + * @WLAN_STA_TOFFSET_KNOWN: toffset calculated for this station is valid. + * @WLAN_STA_MPSP_OWNER: local STA is owner of a mesh Peer Service Period. + * @WLAN_STA_MPSP_RECIPIENT: local STA is recipient of a MPSP. + * @WLAN_STA_PS_DELIVER: station woke up, but we're still blocking TX + * until pending frames are delivered + * @WLAN_STA_USES_ENCRYPTION: This station was configured for encryption, + * so drop all packets without a key later. 
+ * @WLAN_STA_DECAP_OFFLOAD: This station uses rx decap offload + * + * @NUM_WLAN_STA_FLAGS: number of defined flags + */ +enum ieee80211_sta_info_flags { + WLAN_STA_AUTH, + WLAN_STA_ASSOC, + WLAN_STA_PS_STA, + WLAN_STA_AUTHORIZED, + WLAN_STA_SHORT_PREAMBLE, + WLAN_STA_WDS, + WLAN_STA_CLEAR_PS_FILT, + WLAN_STA_MFP, + WLAN_STA_BLOCK_BA, + WLAN_STA_PS_DRIVER, + WLAN_STA_PSPOLL, + WLAN_STA_TDLS_PEER, + WLAN_STA_TDLS_PEER_AUTH, + WLAN_STA_TDLS_INITIATOR, + WLAN_STA_TDLS_CHAN_SWITCH, + WLAN_STA_TDLS_OFF_CHANNEL, + WLAN_STA_TDLS_WIDER_BW, + WLAN_STA_UAPSD, + WLAN_STA_SP, + WLAN_STA_4ADDR_EVENT, + WLAN_STA_INSERTED, + WLAN_STA_RATE_CONTROL, + WLAN_STA_TOFFSET_KNOWN, + WLAN_STA_MPSP_OWNER, + WLAN_STA_MPSP_RECIPIENT, + WLAN_STA_PS_DELIVER, + WLAN_STA_USES_ENCRYPTION, + WLAN_STA_DECAP_OFFLOAD, + + NUM_WLAN_STA_FLAGS, +}; + +#define ADDBA_RESP_INTERVAL HZ +#define HT_AGG_MAX_RETRIES 15 +#define HT_AGG_BURST_RETRIES 3 +#define HT_AGG_RETRIES_PERIOD (15 * HZ) + +#define HT_AGG_STATE_DRV_READY 0 +#define HT_AGG_STATE_RESPONSE_RECEIVED 1 +#define HT_AGG_STATE_OPERATIONAL 2 +#define HT_AGG_STATE_STOPPING 3 +#define HT_AGG_STATE_WANT_START 4 +#define HT_AGG_STATE_WANT_STOP 5 +#define HT_AGG_STATE_START_CB 6 +#define HT_AGG_STATE_STOP_CB 7 +#define HT_AGG_STATE_SENT_ADDBA 8 + +DECLARE_EWMA(avg_signal, 10, 8) +enum ieee80211_agg_stop_reason { + AGG_STOP_DECLINED, + AGG_STOP_LOCAL_REQUEST, + AGG_STOP_PEER_REQUEST, + AGG_STOP_DESTROY_STA, +}; + +/* Debugfs flags to enable/disable use of RX/TX airtime in scheduler */ +#define AIRTIME_USE_TX BIT(0) +#define AIRTIME_USE_RX BIT(1) + +struct airtime_info { + u64 rx_airtime; + u64 tx_airtime; + u32 last_active; + s32 deficit; + atomic_t aql_tx_pending; /* Estimated airtime for frames pending */ + u32 aql_limit_low; + u32 aql_limit_high; +}; + +void ieee80211_sta_update_pending_airtime(struct ieee80211_local *local, + struct sta_info *sta, u8 ac, + u16 tx_airtime, bool tx_completed); + +struct sta_info; + +/** + * struct tid_ampdu_tx - TID aggregation information (Tx). + * + * @rcu_head: rcu head for freeing structure + * @session_timer: check if we keep Tx-ing on the TID (by timeout value) + * @addba_resp_timer: timer for peer's response to addba request + * @pending: pending frames queue -- use sta's spinlock to protect + * @sta: station we are attached to + * @dialog_token: dialog token for aggregation session + * @timeout: session timeout value to be filled in ADDBA requests + * @tid: TID number + * @state: session state (see above) + * @last_tx: jiffies of last tx activity + * @stop_initiator: initiator of a session stop + * @tx_stop: TX DelBA frame when stopping + * @buf_size: reorder buffer size at receiver + * @failed_bar_ssn: ssn of the last failed BAR tx attempt + * @bar_pending: BAR needs to be re-sent + * @amsdu: support A-MSDU withing A-MDPU + * @ssn: starting sequence number of the session + * + * This structure's lifetime is managed by RCU, assignments to + * the array holding it must hold the aggregation mutex. + * + * The TX path can access it under RCU lock-free if, and + * only if, the state has the flag %HT_AGG_STATE_OPERATIONAL + * set. Otherwise, the TX path must also acquire the spinlock + * and re-check the state, see comments in the tx code + * touching it. 
+ */
+struct tid_ampdu_tx {
+ struct rcu_head rcu_head;
+ struct timer_list session_timer;
+ struct timer_list addba_resp_timer;
+ struct sk_buff_head pending;
+ struct sta_info *sta;
+ unsigned long state;
+ unsigned long last_tx;
+ u16 timeout;
+ u8 dialog_token;
+ u8 stop_initiator;
+ bool tx_stop;
+ u16 buf_size;
+ u16 ssn;
+
+ u16 failed_bar_ssn;
+ bool bar_pending;
+ bool amsdu;
+ u8 tid;
+};
+
+/**
+ * struct tid_ampdu_rx - TID aggregation information (Rx).
+ *
+ * @reorder_buf: buffer to reorder incoming aggregated MPDUs. An MPDU may be an
+ * A-MSDU with individually reported subframes.
+ * @reorder_buf_filtered: bitmap indicating where there are filtered frames in
+ * the reorder buffer that should be ignored when releasing frames
+ * @reorder_time: jiffies when skb was added
+ * @session_timer: check if peer keeps Tx-ing on the TID (by timeout value)
+ * @reorder_timer: releases expired frames from the reorder buffer.
+ * @sta: station we are attached to
+ * @last_rx: jiffies of last rx activity
+ * @head_seq_num: head sequence number in reordering buffer.
+ * @stored_mpdu_num: number of MPDUs in reordering buffer
+ * @ssn: Starting Sequence Number expected to be aggregated.
+ * @buf_size: buffer size for incoming A-MPDUs
+ * @timeout: reset timer value (in TUs).
+ * @tid: TID number
+ * @rcu_head: RCU head used for freeing this struct
+ * @reorder_lock: serializes access to reorder buffer, see below.
+ * @auto_seq: used for offloaded BA sessions to automatically pick head_seq_num
+ * and ssn.
+ * @removed: this session is removed (but might have been found due to RCU)
+ * @started: this session has started (head ssn or higher was received)
+ *
+ * This structure's lifetime is managed by RCU, assignments to
+ * the array holding it must hold the aggregation mutex.
+ *
+ * The @reorder_lock is used to protect the members of this
+ * struct, except for @timeout, @buf_size and @dialog_token,
+ * which are constant across the lifetime of the struct (the
+ * dialog token being used only for debugging).
+ */
+struct tid_ampdu_rx {
+ struct rcu_head rcu_head;
+ spinlock_t reorder_lock;
+ u64 reorder_buf_filtered;
+ struct sk_buff_head *reorder_buf;
+ unsigned long *reorder_time;
+ struct sta_info *sta;
+ struct timer_list session_timer;
+ struct timer_list reorder_timer;
+ unsigned long last_rx;
+ u16 head_seq_num;
+ u16 stored_mpdu_num;
+ u16 ssn;
+ u16 buf_size;
+ u16 timeout;
+ u8 tid;
+ u8 auto_seq:1,
+ removed:1,
+ started:1;
+};
+
+/**
+ * struct sta_ampdu_mlme - STA aggregation information.
+ *
+ * @mtx: mutex to protect all TX data (except non-NULL assignments
+ * to tid_tx[idx], which are protected by the sta spinlock)
+ * tid_start_tx is also protected by sta->lock.
+ * @tid_rx: aggregation info for Rx per TID -- RCU protected + * @tid_rx_token: dialog tokens for valid aggregation sessions + * @tid_rx_timer_expired: bitmap indicating on which TIDs the + * RX timer expired until the work for it runs + * @tid_rx_stop_requested: bitmap indicating which BA sessions per TID the + * driver requested to close until the work for it runs + * @tid_rx_manage_offl: bitmap indicating which BA sessions were requested + * to be treated as started/stopped due to offloading + * @agg_session_valid: bitmap indicating which TID has a rx BA session open on + * @unexpected_agg: bitmap indicating which TID already sent a delBA due to + * unexpected aggregation related frames outside a session + * @work: work struct for starting/stopping aggregation + * @tid_tx: aggregation info for Tx per TID + * @tid_start_tx: sessions where start was requested + * @last_addba_req_time: timestamp of the last addBA request. + * @addba_req_num: number of times addBA request has been sent. + * @dialog_token_allocator: dialog token enumerator for each new session; + */ +struct sta_ampdu_mlme { + struct mutex mtx; + /* rx */ + struct tid_ampdu_rx __rcu *tid_rx[IEEE80211_NUM_TIDS]; + u8 tid_rx_token[IEEE80211_NUM_TIDS]; + unsigned long tid_rx_timer_expired[BITS_TO_LONGS(IEEE80211_NUM_TIDS)]; + unsigned long tid_rx_stop_requested[BITS_TO_LONGS(IEEE80211_NUM_TIDS)]; + unsigned long tid_rx_manage_offl[BITS_TO_LONGS(2 * IEEE80211_NUM_TIDS)]; + unsigned long agg_session_valid[BITS_TO_LONGS(IEEE80211_NUM_TIDS)]; + unsigned long unexpected_agg[BITS_TO_LONGS(IEEE80211_NUM_TIDS)]; + /* tx */ + struct work_struct work; + struct tid_ampdu_tx __rcu *tid_tx[IEEE80211_NUM_TIDS]; + struct tid_ampdu_tx *tid_start_tx[IEEE80211_NUM_TIDS]; + unsigned long last_addba_req_time[IEEE80211_NUM_TIDS]; + u8 addba_req_num[IEEE80211_NUM_TIDS]; + u8 dialog_token_allocator; +}; + + +/* Value to indicate no TID reservation */ +#define IEEE80211_TID_UNRESERVED 0xff + +#define IEEE80211_FAST_XMIT_MAX_IV 18 + +/** + * struct ieee80211_fast_tx - TX fastpath information + * @key: key to use for hw crypto + * @hdr: the 802.11 header to put with the frame + * @hdr_len: actual 802.11 header length + * @sa_offs: offset of the SA + * @da_offs: offset of the DA + * @pn_offs: offset where to put PN for crypto (or 0 if not needed) + * @band: band this will be transmitted on, for tx_info + * @rcu_head: RCU head to free this struct + * + * This struct is small enough so that the common case (maximum crypto + * header length of 8 like for CCMP/GCMP) fits into a single 64-byte + * cache line. 
+ */ +struct ieee80211_fast_tx { + struct ieee80211_key *key; + u8 hdr_len; + u8 sa_offs, da_offs, pn_offs; + u8 band; + u8 hdr[30 + 2 + IEEE80211_FAST_XMIT_MAX_IV + + sizeof(rfc1042_header)] __aligned(2); + + struct rcu_head rcu_head; +}; + +/** + * struct ieee80211_fast_rx - RX fastpath information + * @dev: netdevice for reporting the SKB + * @vif_type: (P2P-less) interface type of the original sdata (sdata->vif.type) + * @vif_addr: interface address + * @rfc1042_hdr: copy of the RFC 1042 SNAP header (to have in cache) + * @control_port_protocol: control port protocol copied from sdata + * @expected_ds_bits: from/to DS bits expected + * @icv_len: length of the MIC if present + * @key: bool indicating encryption is expected (key is set) + * @internal_forward: forward froms internally on AP/VLAN type interfaces + * @uses_rss: copy of USES_RSS hw flag + * @da_offs: offset of the DA in the header (for header conversion) + * @sa_offs: offset of the SA in the header (for header conversion) + * @rcu_head: RCU head for freeing this structure + */ +struct ieee80211_fast_rx { + struct net_device *dev; + enum nl80211_iftype vif_type; + u8 vif_addr[ETH_ALEN] __aligned(2); + u8 rfc1042_hdr[6] __aligned(2); + __be16 control_port_protocol; + __le16 expected_ds_bits; + u8 icv_len; + u8 key:1, + internal_forward:1, + uses_rss:1; + u8 da_offs, sa_offs; + + struct rcu_head rcu_head; +}; + +/* we use only values in the range 0-100, so pick a large precision */ +DECLARE_EWMA(mesh_fail_avg, 20, 8) +DECLARE_EWMA(mesh_tx_rate_avg, 8, 16) + +/** + * struct mesh_sta - mesh STA information + * @plink_lock: serialize access to plink fields + * @llid: Local link ID + * @plid: Peer link ID + * @aid: local aid supplied by peer + * @reason: Cancel reason on PLINK_HOLDING state + * @plink_retries: Retries in establishment + * @plink_state: peer link state + * @plink_timeout: timeout of peer link + * @plink_timer: peer link watch timer + * @plink_sta: peer link watch timer's sta_info + * @t_offset: timing offset relative to this host + * @t_offset_setpoint: reference timing offset of this sta to be used when + * calculating clockdrift + * @local_pm: local link-specific power save mode + * @peer_pm: peer-specific power save mode towards local STA + * @nonpeer_pm: STA power save mode towards non-peer neighbors + * @processed_beacon: set to true after peer rates and capabilities are + * processed + * @connected_to_gate: true if mesh STA has a path to a mesh gate + * @connected_to_as: true if mesh STA has a path to a authentication server + * @fail_avg: moving percentage of failed MSDUs + * @tx_rate_avg: moving average of tx bitrate + */ +struct mesh_sta { + struct timer_list plink_timer; + struct sta_info *plink_sta; + + s64 t_offset; + s64 t_offset_setpoint; + + spinlock_t plink_lock; + u16 llid; + u16 plid; + u16 aid; + u16 reason; + u8 plink_retries; + + bool processed_beacon; + bool connected_to_gate; + bool connected_to_as; + + enum nl80211_plink_state plink_state; + u32 plink_timeout; + + /* mesh power save */ + enum nl80211_mesh_power_mode local_pm; + enum nl80211_mesh_power_mode peer_pm; + enum nl80211_mesh_power_mode nonpeer_pm; + + /* moving percentage of failed MSDUs */ + struct ewma_mesh_fail_avg fail_avg; + /* moving average of tx bitrate */ + struct ewma_mesh_tx_rate_avg tx_rate_avg; +}; + +DECLARE_EWMA(signal, 10, 8) + +struct ieee80211_sta_rx_stats { + unsigned long packets; + unsigned long last_rx; + unsigned long num_duplicates; + unsigned long fragments; + unsigned long dropped; + int last_signal; + u8 
chains; + s8 chain_signal_last[IEEE80211_MAX_CHAINS]; + u32 last_rate; + struct u64_stats_sync syncp; + u64 bytes; + u64 msdu[IEEE80211_NUM_TIDS + 1]; +}; + +/* + * IEEE 802.11-2016 (10.6 "Defragmentation") recommends support for "concurrent + * reception of at least one MSDU per access category per associated STA" + * on APs, or "at least one MSDU per access category" on other interface types. + * + * This limit can be increased by changing this define, at the cost of slower + * frame reassembly and increased memory use while fragments are pending. + */ +#define IEEE80211_FRAGMENT_MAX 4 + +struct ieee80211_fragment_entry { + struct sk_buff_head skb_list; + unsigned long first_frag_time; + u16 seq; + u16 extra_len; + u16 last_frag; + u8 rx_queue; + u8 check_sequential_pn:1, /* needed for CCMP/GCMP */ + is_protected:1; + u8 last_pn[6]; /* PN of the last fragment if CCMP was used */ + unsigned int key_color; +}; + +struct ieee80211_fragment_cache { + struct ieee80211_fragment_entry entries[IEEE80211_FRAGMENT_MAX]; + unsigned int next; +}; + +/* + * The bandwidth threshold below which the per-station CoDel parameters will be + * scaled to be more lenient (to prevent starvation of slow stations). This + * value will be scaled by the number of active stations when it is being + * applied. + */ +#define STA_SLOW_THRESHOLD 6000 /* 6 Mbps */ + +/** + * struct link_sta_info - Link STA information + * All link specific sta info are stored here for reference. This can be + * a single entry for non-MLD STA or multiple entries for MLD STA + * @addr: Link MAC address - Can be same as MLD STA mac address and is always + * same for non-MLD STA. This is used as key for searching link STA + * @link_id: Link ID uniquely identifying the link STA. This is 0 for non-MLD + * and set to the corresponding vif LinkId for MLD STA + * @link_hash_node: hash node for rhashtable + * @sta: Points to the STA info + * @gtk: group keys negotiated with this station, if any + * @tx_stats: TX statistics + * @tx_stats.packets: # of packets transmitted + * @tx_stats.bytes: # of bytes in all packets transmitted + * @tx_stats.last_rate: last TX rate + * @tx_stats.msdu: # of transmitted MSDUs per TID + * @rx_stats: RX statistics + * @rx_stats_avg: averaged RX statistics + * @rx_stats_avg.signal: averaged signal + * @rx_stats_avg.chain_signal: averaged per-chain signal + * @pcpu_rx_stats: per-CPU RX statistics, assigned only if the driver needs + * this (by advertising the USES_RSS hw flag) + * @status_stats: TX status statistics + * @status_stats.filtered: # of filtered frames + * @status_stats.retry_failed: # of frames that failed after retry + * @status_stats.retry_count: # of retries attempted + * @status_stats.lost_packets: # of lost packets + * @status_stats.last_pkt_time: timestamp of last ACKed packet + * @status_stats.msdu_retries: # of MSDU retries + * @status_stats.msdu_failed: # of failed MSDUs + * @status_stats.last_ack: last ack timestamp (jiffies) + * @status_stats.last_ack_signal: last ACK signal + * @status_stats.ack_signal_filled: last ACK signal validity + * @status_stats.avg_ack_signal: average ACK signal + * @cur_max_bandwidth: maximum bandwidth to use for TX to the station, + * taken from HT/VHT capabilities or VHT operating mode notification + * @pub: public (driver visible) link STA data + * TODO Move other link params from sta_info as required for MLD operation + */ +struct link_sta_info { + u8 addr[ETH_ALEN]; + u8 link_id; + + struct rhlist_head link_hash_node; + + struct sta_info *sta; + struct 
ieee80211_key __rcu *gtk[NUM_DEFAULT_KEYS + + NUM_DEFAULT_MGMT_KEYS + + NUM_DEFAULT_BEACON_KEYS]; + struct ieee80211_sta_rx_stats __percpu *pcpu_rx_stats; + + /* Updated from RX path only, no locking requirements */ + struct ieee80211_sta_rx_stats rx_stats; + struct { + struct ewma_signal signal; + struct ewma_signal chain_signal[IEEE80211_MAX_CHAINS]; + } rx_stats_avg; + + /* Updated from TX status path only, no locking requirements */ + struct { + unsigned long filtered; + unsigned long retry_failed, retry_count; + unsigned int lost_packets; + unsigned long last_pkt_time; + u64 msdu_retries[IEEE80211_NUM_TIDS + 1]; + u64 msdu_failed[IEEE80211_NUM_TIDS + 1]; + unsigned long last_ack; + s8 last_ack_signal; + bool ack_signal_filled; + struct ewma_avg_signal avg_ack_signal; + } status_stats; + + /* Updated from TX path only, no locking requirements */ + struct { + u64 packets[IEEE80211_NUM_ACS]; + u64 bytes[IEEE80211_NUM_ACS]; + struct ieee80211_tx_rate last_rate; + struct rate_info last_rate_info; + u64 msdu[IEEE80211_NUM_TIDS + 1]; + } tx_stats; + + enum ieee80211_sta_rx_bandwidth cur_max_bandwidth; + + struct ieee80211_link_sta *pub; +}; + +/** + * struct sta_info - STA information + * + * This structure collects information about a station that + * mac80211 is communicating with. + * + * @list: global linked list entry + * @free_list: list entry for keeping track of stations to free + * @hash_node: hash node for rhashtable + * @addr: station's MAC address - duplicated from public part to + * let the hash table work with just a single cacheline + * @local: pointer to the global information + * @sdata: virtual interface this station belongs to + * @ptk: peer keys negotiated with this station, if any + * @ptk_idx: last installed peer key index + * @rate_ctrl: rate control algorithm reference + * @rate_ctrl_lock: spinlock used to protect rate control data + * (data inside the algorithm, so serializes calls there) + * @rate_ctrl_priv: rate control private per-STA pointer + * @lock: used for locking all fields that require locking, see comments + * in the header file. 
+ * @drv_deliver_wk: used for delivering frames after driver PS unblocking + * @listen_interval: listen interval of this station, when we're acting as AP + * @_flags: STA flags, see &enum ieee80211_sta_info_flags, do not use directly + * @ps_lock: used for powersave (when mac80211 is the AP) related locking + * @ps_tx_buf: buffers (per AC) of frames to transmit to this station + * when it leaves power saving state or polls + * @tx_filtered: buffers (per AC) of frames we already tried to + * transmit but were filtered by hardware due to STA having + * entered power saving state, these are also delivered to + * the station when it leaves powersave or polls for frames + * @driver_buffered_tids: bitmap of TIDs the driver has data buffered on + * @txq_buffered_tids: bitmap of TIDs that mac80211 has txq data buffered on + * @assoc_at: clock boottime (in ns) of last association + * @last_connected: time (in seconds) when a station got connected + * @last_seq_ctrl: last received seq/frag number from this STA (per TID + * plus one for non-QoS frames) + * @tid_seq: per-TID sequence numbers for sending to this STA + * @airtime: per-AC struct airtime_info describing airtime statistics for this + * station + * @airtime_weight: station weight for airtime fairness calculation purposes + * @ampdu_mlme: A-MPDU state machine state + * @mesh: mesh STA information + * @debugfs_dir: debug filesystem directory dentry + * @dead: set to true when sta is unlinked + * @removed: set to true when sta is being removed from sta_list + * @uploaded: set to true when sta is uploaded to the driver + * @sta: station information we share with the driver + * @sta_state: duplicates information about station state (for debug) + * @rcu_head: RCU head used for freeing this station struct + * @cur_max_bandwidth: maximum bandwidth to use for TX to the station, + * taken from HT/VHT capabilities or VHT operating mode notification + * @cparams: CoDel parameters for this station. + * @reserved_tid: reserved TID (if any, otherwise IEEE80211_TID_UNRESERVED) + * @fast_tx: TX fastpath information + * @fast_rx: RX fastpath information + * @tdls_chandef: a TDLS peer can have a wider chandef that is compatible to + * the BSS one. + * @frags: fragment cache + * @cur: storage for aggregation data + * &struct ieee80211_sta points either here or to deflink.agg. + * @deflink: This is the default link STA information, for non MLO STA all link + * specific STA information is accessed through @deflink or through + * link[0] which points to address of @deflink. For MLO Link STA + * the first added link STA will point to deflink. + * @link: reference to Link Sta entries. For Non MLO STA, except 1st link, + * i.e link[0] all links would be assigned to NULL by default and + * would access link information via @deflink or link[0]. For MLO + * STA, first link STA being added will point its link pointer to + * @deflink address and remaining would be allocated and the address + * would be assigned to link[link_id] where link_id is the id assigned + * by the AP. 
+ */ +struct sta_info { + /* General information, mostly static */ + struct list_head list, free_list; + struct rcu_head rcu_head; + struct rhlist_head hash_node; + u8 addr[ETH_ALEN]; + struct ieee80211_local *local; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_key __rcu *ptk[NUM_DEFAULT_KEYS]; + u8 ptk_idx; + struct rate_control_ref *rate_ctrl; + void *rate_ctrl_priv; + spinlock_t rate_ctrl_lock; + spinlock_t lock; + + struct ieee80211_fast_tx __rcu *fast_tx; + struct ieee80211_fast_rx __rcu *fast_rx; + +#ifdef CONFIG_MAC80211_MESH + struct mesh_sta *mesh; +#endif + + struct work_struct drv_deliver_wk; + + u16 listen_interval; + + bool dead; + bool removed; + + bool uploaded; + + enum ieee80211_sta_state sta_state; + + /* use the accessors defined below */ + unsigned long _flags; + + /* STA powersave lock and frame queues */ + spinlock_t ps_lock; + struct sk_buff_head ps_tx_buf[IEEE80211_NUM_ACS]; + struct sk_buff_head tx_filtered[IEEE80211_NUM_ACS]; + unsigned long driver_buffered_tids; + unsigned long txq_buffered_tids; + + u64 assoc_at; + long last_connected; + + /* Plus 1 for non-QoS frames */ + __le16 last_seq_ctrl[IEEE80211_NUM_TIDS + 1]; + + u16 tid_seq[IEEE80211_QOS_CTL_TID_MASK + 1]; + + struct airtime_info airtime[IEEE80211_NUM_ACS]; + u16 airtime_weight; + + /* + * Aggregation information, locked with lock. + */ + struct sta_ampdu_mlme ampdu_mlme; + +#ifdef CONFIG_MAC80211_DEBUGFS + struct dentry *debugfs_dir; +#endif + + struct codel_params cparams; + + u8 reserved_tid; + + struct cfg80211_chan_def tdls_chandef; + + struct ieee80211_fragment_cache frags; + + struct ieee80211_sta_aggregates cur; + struct link_sta_info deflink; + struct link_sta_info __rcu *link[IEEE80211_MLD_MAX_NUM_LINKS]; + + /* keep last! */ + struct ieee80211_sta sta; +}; + +static inline enum nl80211_plink_state sta_plink_state(struct sta_info *sta) +{ +#ifdef CONFIG_MAC80211_MESH + return sta->mesh->plink_state; +#endif + return NL80211_PLINK_LISTEN; +} + +static inline void set_sta_flag(struct sta_info *sta, + enum ieee80211_sta_info_flags flag) +{ + WARN_ON(flag == WLAN_STA_AUTH || + flag == WLAN_STA_ASSOC || + flag == WLAN_STA_AUTHORIZED); + set_bit(flag, &sta->_flags); +} + +static inline void clear_sta_flag(struct sta_info *sta, + enum ieee80211_sta_info_flags flag) +{ + WARN_ON(flag == WLAN_STA_AUTH || + flag == WLAN_STA_ASSOC || + flag == WLAN_STA_AUTHORIZED); + clear_bit(flag, &sta->_flags); +} + +static inline int test_sta_flag(struct sta_info *sta, + enum ieee80211_sta_info_flags flag) +{ + return test_bit(flag, &sta->_flags); +} + +static inline int test_and_clear_sta_flag(struct sta_info *sta, + enum ieee80211_sta_info_flags flag) +{ + WARN_ON(flag == WLAN_STA_AUTH || + flag == WLAN_STA_ASSOC || + flag == WLAN_STA_AUTHORIZED); + return test_and_clear_bit(flag, &sta->_flags); +} + +static inline int test_and_set_sta_flag(struct sta_info *sta, + enum ieee80211_sta_info_flags flag) +{ + WARN_ON(flag == WLAN_STA_AUTH || + flag == WLAN_STA_ASSOC || + flag == WLAN_STA_AUTHORIZED); + return test_and_set_bit(flag, &sta->_flags); +} + +int sta_info_move_state(struct sta_info *sta, + enum ieee80211_sta_state new_state); + +static inline void sta_info_pre_move_state(struct sta_info *sta, + enum ieee80211_sta_state new_state) +{ + int ret; + + WARN_ON_ONCE(test_sta_flag(sta, WLAN_STA_INSERTED)); + + ret = sta_info_move_state(sta, new_state); + WARN_ON_ONCE(ret); +} + + +void ieee80211_assign_tid_tx(struct sta_info *sta, int tid, + struct tid_ampdu_tx *tid_tx); + +static inline struct 
tid_ampdu_tx * +rcu_dereference_protected_tid_tx(struct sta_info *sta, int tid) +{ + return rcu_dereference_protected(sta->ampdu_mlme.tid_tx[tid], + lockdep_is_held(&sta->lock) || + lockdep_is_held(&sta->ampdu_mlme.mtx)); +} + +/* Maximum number of frames to buffer per power saving station per AC */ +#define STA_MAX_TX_BUFFER 64 + +/* Minimum buffered frame expiry time. If STA uses listen interval that is + * smaller than this value, the minimum value here is used instead. */ +#define STA_TX_BUFFER_EXPIRE (10 * HZ) + +/* How often station data is cleaned up (e.g., expiration of buffered frames) + */ +#define STA_INFO_CLEANUP_INTERVAL (10 * HZ) + +struct rhlist_head *sta_info_hash_lookup(struct ieee80211_local *local, + const u8 *addr); + +/* + * Get a STA info, must be under RCU read lock. + */ +struct sta_info *sta_info_get(struct ieee80211_sub_if_data *sdata, + const u8 *addr); + +struct sta_info *sta_info_get_bss(struct ieee80211_sub_if_data *sdata, + const u8 *addr); + +/* user must hold sta_mtx or be in RCU critical section */ +struct sta_info *sta_info_get_by_addrs(struct ieee80211_local *local, + const u8 *sta_addr, const u8 *vif_addr); + +#define for_each_sta_info(local, _addr, _sta, _tmp) \ + rhl_for_each_entry_rcu(_sta, _tmp, \ + sta_info_hash_lookup(local, _addr), hash_node) + +struct rhlist_head *link_sta_info_hash_lookup(struct ieee80211_local *local, + const u8 *addr); + +#define for_each_link_sta_info(local, _addr, _sta, _tmp) \ + rhl_for_each_entry_rcu(_sta, _tmp, \ + link_sta_info_hash_lookup(local, _addr), \ + link_hash_node) + +struct link_sta_info * +link_sta_info_get_bss(struct ieee80211_sub_if_data *sdata, const u8 *addr); + +/* + * Get STA info by index, BROKEN! + */ +struct sta_info *sta_info_get_by_idx(struct ieee80211_sub_if_data *sdata, + int idx); +/* + * Create a new STA info, caller owns returned structure + * until sta_info_insert(). + */ +struct sta_info *sta_info_alloc(struct ieee80211_sub_if_data *sdata, + const u8 *addr, gfp_t gfp); +struct sta_info *sta_info_alloc_with_link(struct ieee80211_sub_if_data *sdata, + const u8 *mld_addr, + unsigned int link_id, + const u8 *link_addr, + gfp_t gfp); + +void sta_info_free(struct ieee80211_local *local, struct sta_info *sta); + +/* + * Insert STA info into hash table/list, returns zero or a + * -EEXIST if (if the same MAC address is already present). + * + * Calling the non-rcu version makes the caller relinquish, + * the _rcu version calls read_lock_rcu() and must be called + * without it held. + */ +int sta_info_insert(struct sta_info *sta); +int sta_info_insert_rcu(struct sta_info *sta) __acquires(RCU); + +int __must_check __sta_info_destroy(struct sta_info *sta); +int sta_info_destroy_addr(struct ieee80211_sub_if_data *sdata, + const u8 *addr); +int sta_info_destroy_addr_bss(struct ieee80211_sub_if_data *sdata, + const u8 *addr); + +void sta_info_recalc_tim(struct sta_info *sta); + +int sta_info_init(struct ieee80211_local *local); +void sta_info_stop(struct ieee80211_local *local); + +/** + * __sta_info_flush - flush matching STA entries from the STA table + * + * Returns the number of removed STA entries. + * + * @sdata: sdata to remove all stations from + * @vlans: if the given interface is an AP interface, also flush VLANs + */ +int __sta_info_flush(struct ieee80211_sub_if_data *sdata, bool vlans); + +/** + * sta_info_flush - flush matching STA entries from the STA table + * + * Returns the number of removed STA entries. 
+ * + * @sdata: sdata to remove all stations from + */ +static inline int sta_info_flush(struct ieee80211_sub_if_data *sdata) +{ + return __sta_info_flush(sdata, false); +} + +void sta_set_rate_info_tx(struct sta_info *sta, + const struct ieee80211_tx_rate *rate, + struct rate_info *rinfo); +void sta_set_sinfo(struct sta_info *sta, struct station_info *sinfo, + bool tidstats); + +u32 sta_get_expected_throughput(struct sta_info *sta); + +void ieee80211_sta_expire(struct ieee80211_sub_if_data *sdata, + unsigned long exp_time); + +int ieee80211_sta_allocate_link(struct sta_info *sta, unsigned int link_id); +void ieee80211_sta_free_link(struct sta_info *sta, unsigned int link_id); +int ieee80211_sta_activate_link(struct sta_info *sta, unsigned int link_id); +void ieee80211_sta_remove_link(struct sta_info *sta, unsigned int link_id); + +void ieee80211_sta_ps_deliver_wakeup(struct sta_info *sta); +void ieee80211_sta_ps_deliver_poll_response(struct sta_info *sta); +void ieee80211_sta_ps_deliver_uapsd(struct sta_info *sta); + +unsigned long ieee80211_sta_last_active(struct sta_info *sta); + +void ieee80211_sta_set_max_amsdu_subframes(struct sta_info *sta, + const u8 *ext_capab, + unsigned int ext_capab_len); + +enum sta_stats_type { + STA_STATS_RATE_TYPE_INVALID = 0, + STA_STATS_RATE_TYPE_LEGACY, + STA_STATS_RATE_TYPE_HT, + STA_STATS_RATE_TYPE_VHT, + STA_STATS_RATE_TYPE_HE, + STA_STATS_RATE_TYPE_S1G, +}; + +#define STA_STATS_FIELD_HT_MCS GENMASK( 7, 0) +#define STA_STATS_FIELD_LEGACY_IDX GENMASK( 3, 0) +#define STA_STATS_FIELD_LEGACY_BAND GENMASK( 7, 4) +#define STA_STATS_FIELD_VHT_MCS GENMASK( 3, 0) +#define STA_STATS_FIELD_VHT_NSS GENMASK( 7, 4) +#define STA_STATS_FIELD_HE_MCS GENMASK( 3, 0) +#define STA_STATS_FIELD_HE_NSS GENMASK( 7, 4) +#define STA_STATS_FIELD_BW GENMASK(11, 8) +#define STA_STATS_FIELD_SGI GENMASK(12, 12) +#define STA_STATS_FIELD_TYPE GENMASK(15, 13) +#define STA_STATS_FIELD_HE_RU GENMASK(18, 16) +#define STA_STATS_FIELD_HE_GI GENMASK(20, 19) +#define STA_STATS_FIELD_HE_DCM GENMASK(21, 21) + +#define STA_STATS_FIELD(_n, _v) FIELD_PREP(STA_STATS_FIELD_ ## _n, _v) +#define STA_STATS_GET(_n, _v) FIELD_GET(STA_STATS_FIELD_ ## _n, _v) + +#define STA_STATS_RATE_INVALID 0 + +static inline u32 sta_stats_encode_rate(struct ieee80211_rx_status *s) +{ + u32 r; + + r = STA_STATS_FIELD(BW, s->bw); + + if (s->enc_flags & RX_ENC_FLAG_SHORT_GI) + r |= STA_STATS_FIELD(SGI, 1); + + switch (s->encoding) { + case RX_ENC_VHT: + r |= STA_STATS_FIELD(TYPE, STA_STATS_RATE_TYPE_VHT); + r |= STA_STATS_FIELD(VHT_NSS, s->nss); + r |= STA_STATS_FIELD(VHT_MCS, s->rate_idx); + break; + case RX_ENC_HT: + r |= STA_STATS_FIELD(TYPE, STA_STATS_RATE_TYPE_HT); + r |= STA_STATS_FIELD(HT_MCS, s->rate_idx); + break; + case RX_ENC_LEGACY: + r |= STA_STATS_FIELD(TYPE, STA_STATS_RATE_TYPE_LEGACY); + r |= STA_STATS_FIELD(LEGACY_BAND, s->band); + r |= STA_STATS_FIELD(LEGACY_IDX, s->rate_idx); + break; + case RX_ENC_HE: + r |= STA_STATS_FIELD(TYPE, STA_STATS_RATE_TYPE_HE); + r |= STA_STATS_FIELD(HE_NSS, s->nss); + r |= STA_STATS_FIELD(HE_MCS, s->rate_idx); + r |= STA_STATS_FIELD(HE_GI, s->he_gi); + r |= STA_STATS_FIELD(HE_RU, s->he_ru); + r |= STA_STATS_FIELD(HE_DCM, s->he_dcm); + break; + default: + WARN_ON(1); + return STA_STATS_RATE_INVALID; + } + + return r; +} + +#endif /* STA_INFO_H */ diff --git a/net/mac80211/status.c b/net/mac80211/status.c new file mode 100644 index 000000000..3f9ddd7f0 --- /dev/null +++ b/net/mac80211/status.c @@ -0,0 +1,1296 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 
2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2008-2010 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2021-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "rate.h" +#include "mesh.h" +#include "led.h" +#include "wme.h" + + +void ieee80211_tx_status_irqsafe(struct ieee80211_hw *hw, + struct sk_buff *skb) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int tmp; + + skb->pkt_type = IEEE80211_TX_STATUS_MSG; + skb_queue_tail(info->flags & IEEE80211_TX_CTL_REQ_TX_STATUS ? + &local->skb_queue : &local->skb_queue_unreliable, skb); + tmp = skb_queue_len(&local->skb_queue) + + skb_queue_len(&local->skb_queue_unreliable); + while (tmp > IEEE80211_IRQSAFE_QUEUE_LIMIT && + (skb = skb_dequeue(&local->skb_queue_unreliable))) { + ieee80211_free_txskb(hw, skb); + tmp--; + I802_DEBUG_INC(local->tx_status_drop); + } + tasklet_schedule(&local->tasklet); +} +EXPORT_SYMBOL(ieee80211_tx_status_irqsafe); + +static void ieee80211_handle_filtered_frame(struct ieee80211_local *local, + struct sta_info *sta, + struct sk_buff *skb) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (void *)skb->data; + int ac; + + if (info->flags & (IEEE80211_TX_CTL_NO_PS_BUFFER | + IEEE80211_TX_CTL_AMPDU | + IEEE80211_TX_CTL_HW_80211_ENCAP)) { + ieee80211_free_txskb(&local->hw, skb); + return; + } + + /* + * This skb 'survived' a round-trip through the driver, and + * hopefully the driver didn't mangle it too badly. However, + * we can definitely not rely on the control information + * being correct. Clear it so we don't get junk there, and + * indicate that it needs new processing, but must not be + * modified/encrypted again. + */ + memset(&info->control, 0, sizeof(info->control)); + + info->control.jiffies = jiffies; + info->control.vif = &sta->sdata->vif; + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + info->flags |= IEEE80211_TX_INTFL_RETRANSMISSION; + info->flags &= ~IEEE80211_TX_TEMPORARY_FLAGS; + + sta->deflink.status_stats.filtered++; + + /* + * Clear more-data bit on filtered frames, it might be set + * but later frames might time out so it might have to be + * clear again ... It's all rather unlikely (this frame + * should time out first, right?) but let's not confuse + * peers unnecessarily. + */ + if (hdr->frame_control & cpu_to_le16(IEEE80211_FCTL_MOREDATA)) + hdr->frame_control &= ~cpu_to_le16(IEEE80211_FCTL_MOREDATA); + + if (ieee80211_is_data_qos(hdr->frame_control)) { + u8 *p = ieee80211_get_qos_ctl(hdr); + int tid = *p & IEEE80211_QOS_CTL_TID_MASK; + + /* + * Clear EOSP if set, this could happen e.g. + * if an absence period (us being a P2P GO) + * shortens the SP. + */ + if (*p & IEEE80211_QOS_CTL_EOSP) + *p &= ~IEEE80211_QOS_CTL_EOSP; + ac = ieee80211_ac_from_tid(tid); + } else { + ac = IEEE80211_AC_BE; + } + + /* + * Clear the TX filter mask for this STA when sending the next + * packet. If the STA went to power save mode, this will happen + * when it wakes up for the next time. 
+ */ + set_sta_flag(sta, WLAN_STA_CLEAR_PS_FILT); + ieee80211_clear_fast_xmit(sta); + + /* + * This code races in the following way: + * + * (1) STA sends frame indicating it will go to sleep and does so + * (2) hardware/firmware adds STA to filter list, passes frame up + * (3) hardware/firmware processes TX fifo and suppresses a frame + * (4) we get TX status before having processed the frame and + * knowing that the STA has gone to sleep. + * + * This is actually quite unlikely even when both those events are + * processed from interrupts coming in quickly after one another or + * even at the same time because we queue both TX status events and + * RX frames to be processed by a tasklet and process them in the + * same order that they were received or TX status last. Hence, there + * is no race as long as the frame RX is processed before the next TX + * status, which drivers can ensure, see below. + * + * Note that this can only happen if the hardware or firmware can + * actually add STAs to the filter list, if this is done by the + * driver in response to set_tim() (which will only reduce the race + * this whole filtering tries to solve, not completely solve it) + * this situation cannot happen. + * + * To completely solve this race drivers need to make sure that they + * (a) don't mix the irq-safe/not irq-safe TX status/RX processing + * functions and + * (b) always process RX events before TX status events if ordering + * can be unknown, for example with different interrupt status + * bits. + * (c) if PS mode transitions are manual (i.e. the flag + * %IEEE80211_HW_AP_LINK_PS is set), always process PS state + * changes before calling TX status events if ordering can be + * unknown. + */ + if (test_sta_flag(sta, WLAN_STA_PS_STA) && + skb_queue_len(&sta->tx_filtered[ac]) < STA_MAX_TX_BUFFER) { + skb_queue_tail(&sta->tx_filtered[ac], skb); + sta_info_recalc_tim(sta); + + if (!timer_pending(&local->sta_cleanup)) + mod_timer(&local->sta_cleanup, + round_jiffies(jiffies + + STA_INFO_CLEANUP_INTERVAL)); + return; + } + + if (!test_sta_flag(sta, WLAN_STA_PS_STA) && + !(info->flags & IEEE80211_TX_INTFL_RETRIED)) { + /* Software retry the packet once */ + info->flags |= IEEE80211_TX_INTFL_RETRIED; + ieee80211_add_pending_skb(local, skb); + return; + } + + ps_dbg_ratelimited(sta->sdata, + "dropped TX filtered frame, queue_len=%d PS=%d @%lu\n", + skb_queue_len(&sta->tx_filtered[ac]), + !!test_sta_flag(sta, WLAN_STA_PS_STA), jiffies); + ieee80211_free_txskb(&local->hw, skb); +} + +static void ieee80211_check_pending_bar(struct sta_info *sta, u8 *addr, u8 tid) +{ + struct tid_ampdu_tx *tid_tx; + + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); + if (!tid_tx || !tid_tx->bar_pending) + return; + + tid_tx->bar_pending = false; + ieee80211_send_bar(&sta->sdata->vif, addr, tid, tid_tx->failed_bar_ssn); +} + +static void ieee80211_frame_acked(struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_mgmt *mgmt = (void *) skb->data; + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata = sta->sdata; + + if (ieee80211_is_data_qos(mgmt->frame_control)) { + struct ieee80211_hdr *hdr = (void *) skb->data; + u8 *qc = ieee80211_get_qos_ctl(hdr); + u16 tid = qc[0] & 0xf; + + ieee80211_check_pending_bar(sta, hdr->addr1, tid); + } + + if (ieee80211_is_action(mgmt->frame_control) && + !ieee80211_has_protected(mgmt->frame_control) && + mgmt->u.action.category == WLAN_CATEGORY_HT && + mgmt->u.action.u.ht_smps.action == WLAN_HT_ACTION_SMPS && + 
ieee80211_sdata_running(sdata)) { + enum ieee80211_smps_mode smps_mode; + + switch (mgmt->u.action.u.ht_smps.smps_control) { + case WLAN_HT_SMPS_CONTROL_DYNAMIC: + smps_mode = IEEE80211_SMPS_DYNAMIC; + break; + case WLAN_HT_SMPS_CONTROL_STATIC: + smps_mode = IEEE80211_SMPS_STATIC; + break; + case WLAN_HT_SMPS_CONTROL_DISABLED: + default: /* shouldn't happen since we don't send that */ + smps_mode = IEEE80211_SMPS_OFF; + break; + } + + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + /* + * This update looks racy, but isn't -- if we come + * here we've definitely got a station that we're + * talking to, and on a managed interface that can + * only be the AP. And the only other place updating + * this variable in managed mode is before association. + */ + sdata->deflink.smps_mode = smps_mode; + ieee80211_queue_work(&local->hw, &sdata->recalc_smps); + } + } +} + +static void ieee80211_set_bar_pending(struct sta_info *sta, u8 tid, u16 ssn) +{ + struct tid_ampdu_tx *tid_tx; + + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); + if (!tid_tx) + return; + + tid_tx->failed_bar_ssn = ssn; + tid_tx->bar_pending = true; +} + +static int ieee80211_tx_radiotap_len(struct ieee80211_tx_info *info, + struct ieee80211_tx_status *status) +{ + struct ieee80211_rate_status *status_rate = NULL; + int len = sizeof(struct ieee80211_radiotap_header); + + if (status && status->n_rates) + status_rate = &status->rates[status->n_rates - 1]; + + /* IEEE80211_RADIOTAP_RATE rate */ + if (status_rate && !(status_rate->rate_idx.flags & + (RATE_INFO_FLAGS_MCS | + RATE_INFO_FLAGS_DMG | + RATE_INFO_FLAGS_EDMG | + RATE_INFO_FLAGS_VHT_MCS | + RATE_INFO_FLAGS_HE_MCS))) + len += 2; + else if (info->status.rates[0].idx >= 0 && + !(info->status.rates[0].flags & + (IEEE80211_TX_RC_MCS | IEEE80211_TX_RC_VHT_MCS))) + len += 2; + + /* IEEE80211_RADIOTAP_TX_FLAGS */ + len += 2; + + /* IEEE80211_RADIOTAP_DATA_RETRIES */ + len += 1; + + /* IEEE80211_RADIOTAP_MCS + * IEEE80211_RADIOTAP_VHT */ + if (status_rate) { + if (status_rate->rate_idx.flags & RATE_INFO_FLAGS_MCS) + len += 3; + else if (status_rate->rate_idx.flags & RATE_INFO_FLAGS_VHT_MCS) + len = ALIGN(len, 2) + 12; + else if (status_rate->rate_idx.flags & RATE_INFO_FLAGS_HE_MCS) + len = ALIGN(len, 2) + 12; + } else if (info->status.rates[0].idx >= 0) { + if (info->status.rates[0].flags & IEEE80211_TX_RC_MCS) + len += 3; + else if (info->status.rates[0].flags & IEEE80211_TX_RC_VHT_MCS) + len = ALIGN(len, 2) + 12; + } + + return len; +} + +static void +ieee80211_add_tx_radiotap_header(struct ieee80211_local *local, + struct sk_buff *skb, int retry_count, + int rtap_len, int shift, + struct ieee80211_tx_status *status) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_radiotap_header *rthdr; + struct ieee80211_rate_status *status_rate = NULL; + unsigned char *pos; + u16 legacy_rate = 0; + u16 txflags; + + if (status && status->n_rates) + status_rate = &status->rates[status->n_rates - 1]; + + rthdr = skb_push(skb, rtap_len); + + memset(rthdr, 0, rtap_len); + rthdr->it_len = cpu_to_le16(rtap_len); + rthdr->it_present = + cpu_to_le32(BIT(IEEE80211_RADIOTAP_TX_FLAGS) | + BIT(IEEE80211_RADIOTAP_DATA_RETRIES)); + pos = (unsigned char *)(rthdr + 1); + + /* + * XXX: Once radiotap gets the bitmap reset thing the vendor + * extensions proposal contains, we can actually report + * the whole set of tries we did. 
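+ * A worked example for the RATE field below, for illustration only:
+ * mac80211 keeps legacy bitrates in units of 100 kbps while radiotap's
+ * RATE field uses 500 kbps units, so a 54 Mbps frame (bitrate 540) with
+ * shift == 0 is written as DIV_ROUND_UP(540, 5) == 108, i.e.
+ * 108 * 500 kbps == 54 Mbps.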
+ */ + + /* IEEE80211_RADIOTAP_RATE */ + + if (status_rate) { + if (!(status_rate->rate_idx.flags & + (RATE_INFO_FLAGS_MCS | + RATE_INFO_FLAGS_DMG | + RATE_INFO_FLAGS_EDMG | + RATE_INFO_FLAGS_VHT_MCS | + RATE_INFO_FLAGS_HE_MCS))) + legacy_rate = status_rate->rate_idx.legacy; + } else if (info->status.rates[0].idx >= 0 && + !(info->status.rates[0].flags & (IEEE80211_TX_RC_MCS | + IEEE80211_TX_RC_VHT_MCS))) { + struct ieee80211_supported_band *sband; + + sband = local->hw.wiphy->bands[info->band]; + legacy_rate = + sband->bitrates[info->status.rates[0].idx].bitrate; + } + + if (legacy_rate) { + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_RATE)); + *pos = DIV_ROUND_UP(legacy_rate, 5 * (1 << shift)); + /* padding for tx flags */ + pos += 2; + } + + /* IEEE80211_RADIOTAP_TX_FLAGS */ + txflags = 0; + if (!(info->flags & IEEE80211_TX_STAT_ACK) && + !is_multicast_ether_addr(hdr->addr1)) + txflags |= IEEE80211_RADIOTAP_F_TX_FAIL; + + if (info->status.rates[0].flags & IEEE80211_TX_RC_USE_CTS_PROTECT) + txflags |= IEEE80211_RADIOTAP_F_TX_CTS; + if (info->status.rates[0].flags & IEEE80211_TX_RC_USE_RTS_CTS) + txflags |= IEEE80211_RADIOTAP_F_TX_RTS; + + put_unaligned_le16(txflags, pos); + pos += 2; + + /* IEEE80211_RADIOTAP_DATA_RETRIES */ + /* for now report the total retry_count */ + *pos = retry_count; + pos++; + + if (status_rate && (status_rate->rate_idx.flags & RATE_INFO_FLAGS_MCS)) + { + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_MCS)); + pos[0] = IEEE80211_RADIOTAP_MCS_HAVE_MCS | + IEEE80211_RADIOTAP_MCS_HAVE_GI | + IEEE80211_RADIOTAP_MCS_HAVE_BW; + if (status_rate->rate_idx.flags & RATE_INFO_FLAGS_SHORT_GI) + pos[1] |= IEEE80211_RADIOTAP_MCS_SGI; + if (status_rate->rate_idx.bw == RATE_INFO_BW_40) + pos[1] |= IEEE80211_RADIOTAP_MCS_BW_40; + pos[2] = status_rate->rate_idx.mcs; + pos += 3; + } else if (status_rate && (status_rate->rate_idx.flags & + RATE_INFO_FLAGS_VHT_MCS)) + { + u16 known = local->hw.radiotap_vht_details & + (IEEE80211_RADIOTAP_VHT_KNOWN_GI | + IEEE80211_RADIOTAP_VHT_KNOWN_BANDWIDTH); + + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_VHT)); + + /* required alignment from rthdr */ + pos = (u8 *)rthdr + ALIGN(pos - (u8 *)rthdr, 2); + + /* u16 known - IEEE80211_RADIOTAP_VHT_KNOWN_* */ + put_unaligned_le16(known, pos); + pos += 2; + + /* u8 flags - IEEE80211_RADIOTAP_VHT_FLAG_* */ + if (status_rate->rate_idx.flags & RATE_INFO_FLAGS_SHORT_GI) + *pos |= IEEE80211_RADIOTAP_VHT_FLAG_SGI; + pos++; + + /* u8 bandwidth */ + switch (status_rate->rate_idx.bw) { + case RATE_INFO_BW_160: + *pos = 11; + break; + case RATE_INFO_BW_80: + *pos = 4; + break; + case RATE_INFO_BW_40: + *pos = 1; + break; + default: + *pos = 0; + break; + } + pos++; + + /* u8 mcs_nss[4] */ + *pos = (status_rate->rate_idx.mcs << 4) | + status_rate->rate_idx.nss; + pos += 4; + + /* u8 coding */ + pos++; + /* u8 group_id */ + pos++; + /* u16 partial_aid */ + pos += 2; + } else if (status_rate && (status_rate->rate_idx.flags & + RATE_INFO_FLAGS_HE_MCS)) + { + struct ieee80211_radiotap_he *he; + + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_HE)); + + /* required alignment from rthdr */ + pos = (u8 *)rthdr + ALIGN(pos - (u8 *)rthdr, 2); + he = (struct ieee80211_radiotap_he *)pos; + + he->data1 = cpu_to_le16(IEEE80211_RADIOTAP_HE_DATA1_FORMAT_SU | + IEEE80211_RADIOTAP_HE_DATA1_DATA_MCS_KNOWN | + IEEE80211_RADIOTAP_HE_DATA1_DATA_DCM_KNOWN | + IEEE80211_RADIOTAP_HE_DATA1_BW_RU_ALLOC_KNOWN); + + he->data2 = cpu_to_le16(IEEE80211_RADIOTAP_HE_DATA2_GI_KNOWN); + +#define HE_PREP(f, 
val) le16_encode_bits(val, IEEE80211_RADIOTAP_HE_##f) + + he->data6 |= HE_PREP(DATA6_NSTS, status_rate->rate_idx.nss); + +#define CHECK_GI(s) \ + BUILD_BUG_ON(IEEE80211_RADIOTAP_HE_DATA5_GI_##s != \ + (int)NL80211_RATE_INFO_HE_GI_##s) + + CHECK_GI(0_8); + CHECK_GI(1_6); + CHECK_GI(3_2); + + he->data3 |= HE_PREP(DATA3_DATA_MCS, status_rate->rate_idx.mcs); + he->data3 |= HE_PREP(DATA3_DATA_DCM, status_rate->rate_idx.he_dcm); + + he->data5 |= HE_PREP(DATA5_GI, status_rate->rate_idx.he_gi); + + switch (status_rate->rate_idx.bw) { + case RATE_INFO_BW_20: + he->data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_20MHZ); + break; + case RATE_INFO_BW_40: + he->data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_40MHZ); + break; + case RATE_INFO_BW_80: + he->data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_80MHZ); + break; + case RATE_INFO_BW_160: + he->data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_160MHZ); + break; + case RATE_INFO_BW_HE_RU: +#define CHECK_RU_ALLOC(s) \ + BUILD_BUG_ON(IEEE80211_RADIOTAP_HE_DATA5_DATA_BW_RU_ALLOC_##s##T != \ + NL80211_RATE_INFO_HE_RU_ALLOC_##s + 4) + + CHECK_RU_ALLOC(26); + CHECK_RU_ALLOC(52); + CHECK_RU_ALLOC(106); + CHECK_RU_ALLOC(242); + CHECK_RU_ALLOC(484); + CHECK_RU_ALLOC(996); + CHECK_RU_ALLOC(2x996); + + he->data5 |= HE_PREP(DATA5_DATA_BW_RU_ALLOC, + status_rate->rate_idx.he_ru_alloc + 4); + break; + default: + WARN_ONCE(1, "Invalid SU BW %d\n", status_rate->rate_idx.bw); + } + + pos += sizeof(struct ieee80211_radiotap_he); + } + + if (status_rate || info->status.rates[0].idx < 0) + return; + + /* IEEE80211_RADIOTAP_MCS + * IEEE80211_RADIOTAP_VHT */ + if (info->status.rates[0].flags & IEEE80211_TX_RC_MCS) { + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_MCS)); + pos[0] = IEEE80211_RADIOTAP_MCS_HAVE_MCS | + IEEE80211_RADIOTAP_MCS_HAVE_GI | + IEEE80211_RADIOTAP_MCS_HAVE_BW; + if (info->status.rates[0].flags & IEEE80211_TX_RC_SHORT_GI) + pos[1] |= IEEE80211_RADIOTAP_MCS_SGI; + if (info->status.rates[0].flags & IEEE80211_TX_RC_40_MHZ_WIDTH) + pos[1] |= IEEE80211_RADIOTAP_MCS_BW_40; + if (info->status.rates[0].flags & IEEE80211_TX_RC_GREEN_FIELD) + pos[1] |= IEEE80211_RADIOTAP_MCS_FMT_GF; + pos[2] = info->status.rates[0].idx; + pos += 3; + } else if (info->status.rates[0].flags & IEEE80211_TX_RC_VHT_MCS) { + u16 known = local->hw.radiotap_vht_details & + (IEEE80211_RADIOTAP_VHT_KNOWN_GI | + IEEE80211_RADIOTAP_VHT_KNOWN_BANDWIDTH); + + rthdr->it_present |= cpu_to_le32(BIT(IEEE80211_RADIOTAP_VHT)); + + /* required alignment from rthdr */ + pos = (u8 *)rthdr + ALIGN(pos - (u8 *)rthdr, 2); + + /* u16 known - IEEE80211_RADIOTAP_VHT_KNOWN_* */ + put_unaligned_le16(known, pos); + pos += 2; + + /* u8 flags - IEEE80211_RADIOTAP_VHT_FLAG_* */ + if (info->status.rates[0].flags & IEEE80211_TX_RC_SHORT_GI) + *pos |= IEEE80211_RADIOTAP_VHT_FLAG_SGI; + pos++; + + /* u8 bandwidth */ + if (info->status.rates[0].flags & IEEE80211_TX_RC_40_MHZ_WIDTH) + *pos = 1; + else if (info->status.rates[0].flags & IEEE80211_TX_RC_80_MHZ_WIDTH) + *pos = 4; + else if (info->status.rates[0].flags & IEEE80211_TX_RC_160_MHZ_WIDTH) + *pos = 11; + else /* IEEE80211_TX_RC_{20_MHZ_WIDTH,FIXME:DUP_DATA} */ + *pos = 0; + pos++; + + /* u8 mcs_nss[4] */ + *pos = (ieee80211_rate_get_vht_mcs(&info->status.rates[0]) << 4) | + ieee80211_rate_get_vht_nss(&info->status.rates[0]); + pos += 4; + + /* u8 coding */ + pos++; + /* u8 group_id */ + pos++; + 
/* u16 partial_aid */ + pos += 2; + } +} + +/* + * Handles the tx for TDLS teardown frames. + * If the frame wasn't ACKed by the peer - it will be re-sent through the AP + */ +static void ieee80211_tdls_td_tx_handle(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u32 flags) +{ + struct sk_buff *teardown_skb; + struct sk_buff *orig_teardown_skb; + bool is_teardown = false; + + /* Get the teardown data we need and free the lock */ + spin_lock(&sdata->u.mgd.teardown_lock); + teardown_skb = sdata->u.mgd.teardown_skb; + orig_teardown_skb = sdata->u.mgd.orig_teardown_skb; + if ((skb == orig_teardown_skb) && teardown_skb) { + sdata->u.mgd.teardown_skb = NULL; + sdata->u.mgd.orig_teardown_skb = NULL; + is_teardown = true; + } + spin_unlock(&sdata->u.mgd.teardown_lock); + + if (is_teardown) { + /* This mechanism relies on being able to get ACKs */ + WARN_ON(!ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)); + + /* Check if peer has ACKed */ + if (flags & IEEE80211_TX_STAT_ACK) { + dev_kfree_skb_any(teardown_skb); + } else { + tdls_dbg(sdata, + "TDLS Resending teardown through AP\n"); + + ieee80211_subif_start_xmit(teardown_skb, skb->dev); + } + } +} + +static struct ieee80211_sub_if_data * +ieee80211_sdata_from_skb(struct ieee80211_local *local, struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata; + + if (skb->dev) { + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!sdata->dev) + continue; + + if (skb->dev == sdata->dev) + return sdata; + } + + return NULL; + } + + return rcu_dereference(local->p2p_sdata); +} + +static void ieee80211_report_ack_skb(struct ieee80211_local *local, + struct sk_buff *orig_skb, + bool acked, bool dropped, + ktime_t ack_hwtstamp) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(orig_skb); + struct sk_buff *skb; + unsigned long flags; + + spin_lock_irqsave(&local->ack_status_lock, flags); + skb = idr_remove(&local->ack_status_frames, info->ack_frame_id); + spin_unlock_irqrestore(&local->ack_status_lock, flags); + + if (!skb) + return; + + if (info->flags & IEEE80211_TX_INTFL_NL80211_FRAME_TX) { + u64 cookie = IEEE80211_SKB_CB(skb)->ack.cookie; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_hdr *hdr = (void *)skb->data; + bool is_valid_ack_signal = + !!(info->status.flags & IEEE80211_TX_STATUS_ACK_SIGNAL_VALID); + struct cfg80211_tx_status status = { + .cookie = cookie, + .buf = skb->data, + .len = skb->len, + .ack = acked, + }; + + if (ieee80211_is_timing_measurement(orig_skb) || + ieee80211_is_ftm(orig_skb)) { + status.tx_tstamp = + ktime_to_ns(skb_hwtstamps(orig_skb)->hwtstamp); + status.ack_tstamp = ktime_to_ns(ack_hwtstamp); + } + + rcu_read_lock(); + sdata = ieee80211_sdata_from_skb(local, skb); + if (sdata) { + if (skb->protocol == sdata->control_port_protocol || + skb->protocol == cpu_to_be16(ETH_P_PREAUTH)) + cfg80211_control_port_tx_status(&sdata->wdev, + cookie, + skb->data, + skb->len, + acked, + GFP_ATOMIC); + else if (ieee80211_is_any_nullfunc(hdr->frame_control)) + cfg80211_probe_status(sdata->dev, hdr->addr1, + cookie, acked, + info->status.ack_signal, + is_valid_ack_signal, + GFP_ATOMIC); + else if (ieee80211_is_mgmt(hdr->frame_control)) + cfg80211_mgmt_tx_status_ext(&sdata->wdev, + &status, + GFP_ATOMIC); + else + pr_warn("Unknown status report in ack skb\n"); + + } + rcu_read_unlock(); + + dev_kfree_skb_any(skb); + } else if (dropped) { + dev_kfree_skb_any(skb); + } else { + /* consumes skb */ + skb_complete_wifi_ack(skb, acked); + } +} + +static void 
ieee80211_report_used_skb(struct ieee80211_local *local, + struct sk_buff *skb, bool dropped, + ktime_t ack_hwtstamp) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + u16 tx_time_est = ieee80211_info_get_tx_time_est(info); + struct ieee80211_hdr *hdr = (void *)skb->data; + bool acked = info->flags & IEEE80211_TX_STAT_ACK; + + if (dropped) + acked = false; + + if (tx_time_est) { + struct sta_info *sta; + + rcu_read_lock(); + + sta = sta_info_get_by_addrs(local, hdr->addr1, hdr->addr2); + ieee80211_sta_update_pending_airtime(local, sta, + skb_get_queue_mapping(skb), + tx_time_est, + true); + rcu_read_unlock(); + } + + if (info->flags & IEEE80211_TX_INTFL_MLME_CONN_TX) { + struct ieee80211_sub_if_data *sdata; + + rcu_read_lock(); + + sdata = ieee80211_sdata_from_skb(local, skb); + + if (!sdata) { + skb->dev = NULL; + } else if (!dropped) { + unsigned int hdr_size = + ieee80211_hdrlen(hdr->frame_control); + + /* Check to see if packet is a TDLS teardown packet */ + if (ieee80211_is_data(hdr->frame_control) && + (ieee80211_get_tdls_action(skb, hdr_size) == + WLAN_TDLS_TEARDOWN)) { + ieee80211_tdls_td_tx_handle(local, sdata, skb, + info->flags); + } else if (ieee80211_s1g_is_twt_setup(skb)) { + if (!acked) { + struct sk_buff *qskb; + + qskb = skb_clone(skb, GFP_ATOMIC); + if (qskb) { + skb_queue_tail(&sdata->status_queue, + qskb); + ieee80211_queue_work(&local->hw, + &sdata->work); + } + } + } else { + ieee80211_mgd_conn_tx_status(sdata, + hdr->frame_control, + acked); + } + } + + rcu_read_unlock(); + } else if (info->ack_frame_id) { + ieee80211_report_ack_skb(local, skb, acked, dropped, + ack_hwtstamp); + } + + if (!dropped && skb->destructor) { + skb->wifi_acked_valid = 1; + skb->wifi_acked = acked; + } + + ieee80211_led_tx(local); + + if (skb_has_frag_list(skb)) { + kfree_skb_list(skb_shinfo(skb)->frag_list); + skb_shinfo(skb)->frag_list = NULL; + } +} + +/* + * Use a static threshold for now, best value to be determined + * by testing ... + * Should it depend on: + * - on # of retransmissions + * - current throughput (higher value for higher tpt)? + */ +#define STA_LOST_PKT_THRESHOLD 50 +#define STA_LOST_PKT_TIME HZ /* 1 sec since last ACK */ +#define STA_LOST_TDLS_PKT_TIME (10*HZ) /* 10secs since last ACK */ + +static void ieee80211_lost_packet(struct sta_info *sta, + struct ieee80211_tx_info *info) +{ + unsigned long pkt_time = STA_LOST_PKT_TIME; + unsigned int pkt_thr = STA_LOST_PKT_THRESHOLD; + + /* If driver relies on its own algorithm for station kickout, skip + * mac80211 packet loss mechanism. + */ + if (ieee80211_hw_check(&sta->local->hw, REPORTS_LOW_ACK)) + return; + + /* This packet was aggregated but doesn't carry status info */ + if ((info->flags & IEEE80211_TX_CTL_AMPDU) && + !(info->flags & IEEE80211_TX_STAT_AMPDU)) + return; + + sta->deflink.status_stats.lost_packets++; + if (sta->sta.tdls) { + pkt_time = STA_LOST_TDLS_PKT_TIME; + pkt_thr = STA_LOST_PKT_THRESHOLD; + } + + /* + * If we're in TDLS mode, make sure that all STA_LOST_PKT_THRESHOLD + * of the last packets were lost, and that no ACK was received in the + * last STA_LOST_TDLS_PKT_TIME ms, before triggering the CQM packet-loss + * mechanism. 
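+ * With the defaults above that means: a cfg80211 CQM packet-loss event
+ * is only generated once at least 50 packets (STA_LOST_PKT_THRESHOLD)
+ * were lost since the last ACK and more than 1 second (10 seconds for a
+ * TDLS peer) has passed since the last acknowledged frame.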
+ * For non-TDLS, use STA_LOST_PKT_THRESHOLD and STA_LOST_PKT_TIME + */ + if (sta->deflink.status_stats.lost_packets < pkt_thr || + !time_after(jiffies, sta->deflink.status_stats.last_pkt_time + pkt_time)) + return; + + cfg80211_cqm_pktloss_notify(sta->sdata->dev, sta->sta.addr, + sta->deflink.status_stats.lost_packets, + GFP_ATOMIC); + sta->deflink.status_stats.lost_packets = 0; +} + +static int ieee80211_tx_get_rates(struct ieee80211_hw *hw, + struct ieee80211_tx_info *info, + int *retry_count) +{ + int count = -1; + int i; + + for (i = 0; i < IEEE80211_TX_MAX_RATES; i++) { + if ((info->flags & IEEE80211_TX_CTL_AMPDU) && + !(info->flags & IEEE80211_TX_STAT_AMPDU)) { + /* just the first aggr frame carry status info */ + info->status.rates[i].idx = -1; + info->status.rates[i].count = 0; + break; + } else if (info->status.rates[i].idx < 0) { + break; + } else if (i >= hw->max_report_rates) { + /* the HW cannot have attempted that rate */ + info->status.rates[i].idx = -1; + info->status.rates[i].count = 0; + break; + } + + count += info->status.rates[i].count; + } + + if (count < 0) + count = 0; + + *retry_count = count; + return i - 1; +} + +void ieee80211_tx_monitor(struct ieee80211_local *local, struct sk_buff *skb, + int retry_count, int shift, bool send_to_cooked, + struct ieee80211_tx_status *status) +{ + struct sk_buff *skb2; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_sub_if_data *sdata; + struct net_device *prev_dev = NULL; + int rtap_len; + + /* send frame to monitor interfaces now */ + rtap_len = ieee80211_tx_radiotap_len(info, status); + if (WARN_ON_ONCE(skb_headroom(skb) < rtap_len)) { + pr_err("ieee80211_tx_status: headroom too small\n"); + dev_kfree_skb(skb); + return; + } + ieee80211_add_tx_radiotap_header(local, skb, retry_count, + rtap_len, shift, status); + + /* XXX: is this sufficient for BPF? 
*/ + skb_reset_mac_header(skb); + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->pkt_type = PACKET_OTHERHOST; + skb->protocol = htons(ETH_P_802_2); + memset(skb->cb, 0, sizeof(skb->cb)); + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (sdata->vif.type == NL80211_IFTYPE_MONITOR) { + if (!ieee80211_sdata_running(sdata)) + continue; + + if ((sdata->u.mntr.flags & MONITOR_FLAG_COOK_FRAMES) && + !send_to_cooked) + continue; + + if (prev_dev) { + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) { + skb2->dev = prev_dev; + netif_rx(skb2); + } + } + + prev_dev = sdata->dev; + } + } + if (prev_dev) { + skb->dev = prev_dev; + netif_rx(skb); + skb = NULL; + } + rcu_read_unlock(); + dev_kfree_skb(skb); +} + +static void __ieee80211_tx_status(struct ieee80211_hw *hw, + struct ieee80211_tx_status *status, + int rates_idx, int retry_count) +{ + struct sk_buff *skb = status->skb; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_tx_info *info = status->info; + struct sta_info *sta; + __le16 fc; + bool send_to_cooked; + bool acked; + bool noack_success; + struct ieee80211_bar *bar; + int shift = 0; + int tid = IEEE80211_NUM_TIDS; + + fc = hdr->frame_control; + + if (status->sta) { + sta = container_of(status->sta, struct sta_info, sta); + shift = ieee80211_vif_get_shift(&sta->sdata->vif); + + if (info->flags & IEEE80211_TX_STATUS_EOSP) + clear_sta_flag(sta, WLAN_STA_SP); + + acked = !!(info->flags & IEEE80211_TX_STAT_ACK); + noack_success = !!(info->flags & + IEEE80211_TX_STAT_NOACK_TRANSMITTED); + + /* mesh Peer Service Period support */ + if (ieee80211_vif_is_mesh(&sta->sdata->vif) && + ieee80211_is_data_qos(fc)) + ieee80211_mpsp_trigger_process( + ieee80211_get_qos_ctl(hdr), sta, true, acked); + + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL) && + (ieee80211_is_data(hdr->frame_control)) && + (rates_idx != -1)) + sta->deflink.tx_stats.last_rate = + info->status.rates[rates_idx]; + + if ((info->flags & IEEE80211_TX_STAT_AMPDU_NO_BACK) && + (ieee80211_is_data_qos(fc))) { + u16 ssn; + u8 *qc; + + qc = ieee80211_get_qos_ctl(hdr); + tid = qc[0] & 0xf; + ssn = ((le16_to_cpu(hdr->seq_ctrl) + 0x10) + & IEEE80211_SCTL_SEQ); + ieee80211_send_bar(&sta->sdata->vif, hdr->addr1, + tid, ssn); + } else if (ieee80211_is_data_qos(fc)) { + u8 *qc = ieee80211_get_qos_ctl(hdr); + + tid = qc[0] & 0xf; + } + + if (!acked && ieee80211_is_back_req(fc)) { + u16 control; + + /* + * BAR failed, store the last SSN and retry sending + * the BAR when the next unicast transmission on the + * same TID succeeds. + */ + bar = (struct ieee80211_bar *) skb->data; + control = le16_to_cpu(bar->control); + if (!(control & IEEE80211_BAR_CTRL_MULTI_TID)) { + u16 ssn = le16_to_cpu(bar->start_seq_num); + + tid = (control & + IEEE80211_BAR_CTRL_TID_INFO_MASK) >> + IEEE80211_BAR_CTRL_TID_INFO_SHIFT; + + ieee80211_set_bar_pending(sta, tid, ssn); + } + } + + if (info->flags & IEEE80211_TX_STAT_TX_FILTERED) { + ieee80211_handle_filtered_frame(local, sta, skb); + return; + } else if (ieee80211_is_data_present(fc)) { + if (!acked && !noack_success) + sta->deflink.status_stats.msdu_failed[tid]++; + + sta->deflink.status_stats.msdu_retries[tid] += + retry_count; + } + + if (!(info->flags & IEEE80211_TX_CTL_INJECTED) && acked) + ieee80211_frame_acked(sta, skb); + + } + + /* SNMP counters + * Fragments are passed to low-level drivers as separate skbs, so these + * are actually fragments, not frames. 
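+ * (ieee80211_is_first_frag(), used below, simply tests that the fragment
+ * number in the sequence control field is zero.)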
Update frame counters only for + * the first fragment of the frame. */ + if ((info->flags & IEEE80211_TX_STAT_ACK) || + (info->flags & IEEE80211_TX_STAT_NOACK_TRANSMITTED)) { + if (ieee80211_is_first_frag(hdr->seq_ctrl)) { + I802_DEBUG_INC(local->dot11TransmittedFrameCount); + if (is_multicast_ether_addr(ieee80211_get_DA(hdr))) + I802_DEBUG_INC(local->dot11MulticastTransmittedFrameCount); + if (retry_count > 0) + I802_DEBUG_INC(local->dot11RetryCount); + if (retry_count > 1) + I802_DEBUG_INC(local->dot11MultipleRetryCount); + } + + /* This counter shall be incremented for an acknowledged MPDU + * with an individual address in the address 1 field or an MPDU + * with a multicast address in the address 1 field of type Data + * or Management. */ + if (!is_multicast_ether_addr(hdr->addr1) || + ieee80211_is_data(fc) || + ieee80211_is_mgmt(fc)) + I802_DEBUG_INC(local->dot11TransmittedFragmentCount); + } else { + if (ieee80211_is_first_frag(hdr->seq_ctrl)) + I802_DEBUG_INC(local->dot11FailedCount); + } + + if (ieee80211_is_any_nullfunc(fc) && + ieee80211_has_pm(fc) && + ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS) && + !(info->flags & IEEE80211_TX_CTL_INJECTED) && + local->ps_sdata && !(local->scanning)) { + if (info->flags & IEEE80211_TX_STAT_ACK) + local->ps_sdata->u.mgd.flags |= + IEEE80211_STA_NULLFUNC_ACKED; + mod_timer(&local->dynamic_ps_timer, + jiffies + msecs_to_jiffies(10)); + } + + ieee80211_report_used_skb(local, skb, false, status->ack_hwtstamp); + + /* this was a transmitted frame, but now we want to reuse it */ + skb_orphan(skb); + + /* Need to make a copy before skb->cb gets cleared */ + send_to_cooked = !!(info->flags & IEEE80211_TX_CTL_INJECTED) || + !(ieee80211_is_data(fc)); + + /* + * This is a bit racy but we can avoid a lot of work + * with this test... + */ + if (!local->monitors && (!send_to_cooked || !local->cooked_mntrs)) { + if (status->free_list) + list_add_tail(&skb->list, status->free_list); + else + dev_kfree_skb(skb); + return; + } + + /* send to monitor interfaces */ + ieee80211_tx_monitor(local, skb, retry_count, shift, + send_to_cooked, status); +} + +void ieee80211_tx_status(struct ieee80211_hw *hw, struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_tx_status status = { + .skb = skb, + .info = IEEE80211_SKB_CB(skb), + }; + struct sta_info *sta; + + rcu_read_lock(); + + sta = sta_info_get_by_addrs(local, hdr->addr1, hdr->addr2); + if (sta) + status.sta = &sta->sta; + + ieee80211_tx_status_ext(hw, &status); + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_tx_status); + +void ieee80211_tx_status_ext(struct ieee80211_hw *hw, + struct ieee80211_tx_status *status) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_tx_info *info = status->info; + struct ieee80211_sta *pubsta = status->sta; + struct sk_buff *skb = status->skb; + struct sta_info *sta = NULL; + int rates_idx, retry_count; + bool acked, noack_success, ack_signal_valid; + u16 tx_time_est; + + if (pubsta) { + sta = container_of(pubsta, struct sta_info, sta); + + if (status->n_rates) + sta->deflink.tx_stats.last_rate_info = + status->rates[status->n_rates - 1].rate_idx; + } + + if (skb && (tx_time_est = + ieee80211_info_get_tx_time_est(IEEE80211_SKB_CB(skb))) > 0) { + /* Do this here to avoid the expensive lookup of the sta + * in ieee80211_report_used_skb(). 
+ */ + ieee80211_sta_update_pending_airtime(local, sta, + skb_get_queue_mapping(skb), + tx_time_est, + true); + ieee80211_info_set_tx_time_est(IEEE80211_SKB_CB(skb), 0); + } + + if (!status->info) + goto free; + + rates_idx = ieee80211_tx_get_rates(hw, info, &retry_count); + + acked = !!(info->flags & IEEE80211_TX_STAT_ACK); + noack_success = !!(info->flags & IEEE80211_TX_STAT_NOACK_TRANSMITTED); + ack_signal_valid = + !!(info->status.flags & IEEE80211_TX_STATUS_ACK_SIGNAL_VALID); + + if (pubsta) { + struct ieee80211_sub_if_data *sdata = sta->sdata; + + if (!acked && !noack_success) + sta->deflink.status_stats.retry_failed++; + sta->deflink.status_stats.retry_count += retry_count; + + if (ieee80211_hw_check(&local->hw, REPORTS_TX_ACK_STATUS)) { + if (sdata->vif.type == NL80211_IFTYPE_STATION && + skb && !(info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP)) + ieee80211_sta_tx_notify(sdata, (void *) skb->data, + acked, info->status.tx_time); + + if (acked) { + sta->deflink.status_stats.last_ack = jiffies; + + if (sta->deflink.status_stats.lost_packets) + sta->deflink.status_stats.lost_packets = 0; + + /* Track when last packet was ACKed */ + sta->deflink.status_stats.last_pkt_time = jiffies; + + /* Reset connection monitor */ + if (sdata->vif.type == NL80211_IFTYPE_STATION && + unlikely(sdata->u.mgd.probe_send_count > 0)) + sdata->u.mgd.probe_send_count = 0; + + if (ack_signal_valid) { + sta->deflink.status_stats.last_ack_signal = + (s8)info->status.ack_signal; + sta->deflink.status_stats.ack_signal_filled = true; + ewma_avg_signal_add(&sta->deflink.status_stats.avg_ack_signal, + -info->status.ack_signal); + } + } else if (test_sta_flag(sta, WLAN_STA_PS_STA)) { + /* + * The STA is in power save mode, so assume + * that this TX packet failed because of that. 
+ */ + if (skb) + ieee80211_handle_filtered_frame(local, sta, skb); + return; + } else if (noack_success) { + /* nothing to do here, do not account as lost */ + } else { + ieee80211_lost_packet(sta, info); + } + } + + rate_control_tx_status(local, status); + if (ieee80211_vif_is_mesh(&sta->sdata->vif)) + ieee80211s_update_metric(local, sta, status); + } + + if (skb && !(info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP)) + return __ieee80211_tx_status(hw, status, rates_idx, + retry_count); + + if (acked || noack_success) { + I802_DEBUG_INC(local->dot11TransmittedFrameCount); + if (!pubsta) + I802_DEBUG_INC(local->dot11MulticastTransmittedFrameCount); + if (retry_count > 0) + I802_DEBUG_INC(local->dot11RetryCount); + if (retry_count > 1) + I802_DEBUG_INC(local->dot11MultipleRetryCount); + } else { + I802_DEBUG_INC(local->dot11FailedCount); + } + +free: + if (!skb) + return; + + ieee80211_report_used_skb(local, skb, false, status->ack_hwtstamp); + if (status->free_list) + list_add_tail(&skb->list, status->free_list); + else + dev_kfree_skb(skb); +} +EXPORT_SYMBOL(ieee80211_tx_status_ext); + +void ieee80211_tx_rate_update(struct ieee80211_hw *hw, + struct ieee80211_sta *pubsta, + struct ieee80211_tx_info *info) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_tx_status status = { + .info = info, + .sta = pubsta, + }; + + rate_control_tx_status(local, &status); + + if (ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) + sta->deflink.tx_stats.last_rate = info->status.rates[0]; +} +EXPORT_SYMBOL(ieee80211_tx_rate_update); + +void ieee80211_tx_status_8023(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_tx_status status = { + .skb = skb, + .info = IEEE80211_SKB_CB(skb), + }; + struct sta_info *sta; + + sdata = vif_to_sdata(vif); + + rcu_read_lock(); + + if (!ieee80211_lookup_ra_sta(sdata, skb, &sta) && !IS_ERR(sta)) + status.sta = &sta->sta; + + ieee80211_tx_status_ext(hw, &status); + + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_tx_status_8023); + +void ieee80211_report_low_ack(struct ieee80211_sta *pubsta, u32 num_packets) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + cfg80211_cqm_pktloss_notify(sta->sdata->dev, sta->sta.addr, + num_packets, GFP_ATOMIC); +} +EXPORT_SYMBOL(ieee80211_report_low_ack); + +void ieee80211_free_txskb(struct ieee80211_hw *hw, struct sk_buff *skb) +{ + struct ieee80211_local *local = hw_to_local(hw); + ktime_t kt = ktime_set(0, 0); + + ieee80211_report_used_skb(local, skb, true, kt); + dev_kfree_skb_any(skb); +} +EXPORT_SYMBOL(ieee80211_free_txskb); + +void ieee80211_purge_tx_queue(struct ieee80211_hw *hw, + struct sk_buff_head *skbs) +{ + struct sk_buff *skb; + + while ((skb = __skb_dequeue(skbs))) + ieee80211_free_txskb(hw, skb); +} diff --git a/net/mac80211/tdls.c b/net/mac80211/tdls.c new file mode 100644 index 000000000..f4b4d25ee --- /dev/null +++ b/net/mac80211/tdls.c @@ -0,0 +1,2010 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * mac80211 TDLS handling code + * + * Copyright 2006-2010 Johannes Berg + * Copyright 2014, Intel Corporation + * Copyright 2014 Intel Mobile Communications GmbH + * Copyright 2015 - 2016 Intel Deutschland GmbH + * Copyright (C) 2019, 2021-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "wme.h" + +/* give usermode some time for 
retries in setting up the TDLS session */ +#define TDLS_PEER_SETUP_TIMEOUT (15 * HZ) + +void ieee80211_tdls_peer_del_work(struct work_struct *wk) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_local *local; + + sdata = container_of(wk, struct ieee80211_sub_if_data, + u.mgd.tdls_peer_del_work.work); + local = sdata->local; + + mutex_lock(&local->mtx); + if (!is_zero_ether_addr(sdata->u.mgd.tdls_peer)) { + tdls_dbg(sdata, "TDLS del peer %pM\n", sdata->u.mgd.tdls_peer); + sta_info_destroy_addr(sdata, sdata->u.mgd.tdls_peer); + eth_zero_addr(sdata->u.mgd.tdls_peer); + } + mutex_unlock(&local->mtx); +} + +static void ieee80211_tdls_add_ext_capab(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + bool chan_switch = local->hw.wiphy->features & + NL80211_FEATURE_TDLS_CHANNEL_SWITCH; + bool wider_band = ieee80211_hw_check(&local->hw, TDLS_WIDER_BW) && + !ifmgd->tdls_wider_bw_prohibited; + bool buffer_sta = ieee80211_hw_check(&local->hw, + SUPPORTS_TDLS_BUFFER_STA); + struct ieee80211_supported_band *sband = ieee80211_get_sband(sdata); + bool vht = sband && sband->vht_cap.vht_supported; + u8 *pos = skb_put(skb, 10); + + *pos++ = WLAN_EID_EXT_CAPABILITY; + *pos++ = 8; /* len */ + *pos++ = 0x0; + *pos++ = 0x0; + *pos++ = 0x0; + *pos++ = (chan_switch ? WLAN_EXT_CAPA4_TDLS_CHAN_SWITCH : 0) | + (buffer_sta ? WLAN_EXT_CAPA4_TDLS_BUFFER_STA : 0); + *pos++ = WLAN_EXT_CAPA5_TDLS_ENABLED; + *pos++ = 0; + *pos++ = 0; + *pos++ = (vht && wider_band) ? WLAN_EXT_CAPA8_TDLS_WIDE_BW_ENABLED : 0; +} + +static u8 +ieee80211_tdls_add_subband(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u16 start, u16 end, + u16 spacing) +{ + u8 subband_cnt = 0, ch_cnt = 0; + struct ieee80211_channel *ch; + struct cfg80211_chan_def chandef; + int i, subband_start; + struct wiphy *wiphy = sdata->local->hw.wiphy; + + for (i = start; i <= end; i += spacing) { + if (!ch_cnt) + subband_start = i; + + ch = ieee80211_get_channel(sdata->local->hw.wiphy, i); + if (ch) { + /* we will be active on the channel */ + cfg80211_chandef_create(&chandef, ch, + NL80211_CHAN_NO_HT); + if (cfg80211_reg_can_beacon_relax(wiphy, &chandef, + sdata->wdev.iftype)) { + ch_cnt++; + /* + * check if the next channel is also part of + * this allowed range + */ + continue; + } + } + + /* + * we've reached the end of a range, with allowed channels + * found + */ + if (ch_cnt) { + u8 *pos = skb_put(skb, 2); + *pos++ = ieee80211_frequency_to_channel(subband_start); + *pos++ = ch_cnt; + + subband_cnt++; + ch_cnt = 0; + } + } + + /* all channels in the requested range are allowed - add them here */ + if (ch_cnt) { + u8 *pos = skb_put(skb, 2); + *pos++ = ieee80211_frequency_to_channel(subband_start); + *pos++ = ch_cnt; + + subband_cnt++; + } + + return subband_cnt; +} + +static void +ieee80211_tdls_add_supp_channels(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + /* + * Add possible channels for TDLS. These are channels that are allowed + * to be active. + */ + u8 subband_cnt; + u8 *pos = skb_put(skb, 2); + + *pos++ = WLAN_EID_SUPPORTED_CHANNELS; + + /* + * 5GHz and 2GHz channels numbers can overlap. Ignore this for now, as + * this doesn't happen in real world scenarios. 
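+ * Each subband added by ieee80211_tdls_add_subband() is encoded as a
+ * (first channel, number of channels) pair. As a worked example, for
+ * illustration only: if only channels 36-48 were allowed in the 5 GHz
+ * range below, a single pair (36, 4) would be emitted, covering
+ * channels 36, 40, 44 and 48 at 20 MHz spacing.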
+ */ + + /* 2GHz, with 5MHz spacing */ + subband_cnt = ieee80211_tdls_add_subband(sdata, skb, 2412, 2472, 5); + + /* 5GHz, with 20MHz spacing */ + subband_cnt += ieee80211_tdls_add_subband(sdata, skb, 5000, 5825, 20); + + /* length */ + *pos = 2 * subband_cnt; +} + +static void ieee80211_tdls_add_oper_classes(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + u8 *pos; + u8 op_class; + + if (!ieee80211_chandef_to_operating_class(&sdata->vif.bss_conf.chandef, + &op_class)) + return; + + pos = skb_put(skb, 4); + *pos++ = WLAN_EID_SUPPORTED_REGULATORY_CLASSES; + *pos++ = 2; /* len */ + + *pos++ = op_class; + *pos++ = op_class; /* give current operating class as alternate too */ +} + +static void ieee80211_tdls_add_bss_coex_ie(struct sk_buff *skb) +{ + u8 *pos = skb_put(skb, 3); + + *pos++ = WLAN_EID_BSS_COEX_2040; + *pos++ = 1; /* len */ + + *pos++ = WLAN_BSS_COEX_INFORMATION_REQUEST; +} + +static u16 ieee80211_get_tdls_sta_capab(struct ieee80211_sub_if_data *sdata, + u16 status_code) +{ + struct ieee80211_supported_band *sband; + + /* The capability will be 0 when sending a failure code */ + if (status_code != 0) + return 0; + + sband = ieee80211_get_sband(sdata); + if (sband && sband->band == NL80211_BAND_2GHZ) { + return WLAN_CAPABILITY_SHORT_SLOT_TIME | + WLAN_CAPABILITY_SHORT_PREAMBLE; + } + + return 0; +} + +static void ieee80211_tdls_add_link_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + bool initiator) +{ + struct ieee80211_tdls_lnkie *lnkid; + const u8 *init_addr, *rsp_addr; + + if (initiator) { + init_addr = sdata->vif.addr; + rsp_addr = peer; + } else { + init_addr = peer; + rsp_addr = sdata->vif.addr; + } + + lnkid = skb_put(skb, sizeof(struct ieee80211_tdls_lnkie)); + + lnkid->ie_type = WLAN_EID_LINK_ID; + lnkid->ie_len = sizeof(struct ieee80211_tdls_lnkie) - 2; + + memcpy(lnkid->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(lnkid->init_sta, init_addr, ETH_ALEN); + memcpy(lnkid->resp_sta, rsp_addr, ETH_ALEN); +} + +static void +ieee80211_tdls_add_aid(struct ieee80211_sub_if_data *sdata, struct sk_buff *skb) +{ + u8 *pos = skb_put(skb, 4); + + *pos++ = WLAN_EID_AID; + *pos++ = 2; /* len */ + put_unaligned_le16(sdata->vif.cfg.aid, pos); +} + +/* translate numbering in the WMM parameter IE to the mac80211 notation */ +static enum ieee80211_ac_numbers ieee80211_ac_from_wmm(int ac) +{ + switch (ac) { + default: + WARN_ON_ONCE(1); + fallthrough; + case 0: + return IEEE80211_AC_BE; + case 1: + return IEEE80211_AC_BK; + case 2: + return IEEE80211_AC_VI; + case 3: + return IEEE80211_AC_VO; + } +} + +static u8 ieee80211_wmm_aci_aifsn(int aifsn, bool acm, int aci) +{ + u8 ret; + + ret = aifsn & 0x0f; + if (acm) + ret |= 0x10; + ret |= (aci << 5) & 0x60; + return ret; +} + +static u8 ieee80211_wmm_ecw(u16 cw_min, u16 cw_max) +{ + return ((ilog2(cw_min + 1) << 0x0) & 0x0f) | + ((ilog2(cw_max + 1) << 0x4) & 0xf0); +} + +static void ieee80211_tdls_add_wmm_param_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_wmm_param_ie *wmm; + struct ieee80211_tx_queue_params *txq; + int i; + + wmm = skb_put_zero(skb, sizeof(*wmm)); + + wmm->element_id = WLAN_EID_VENDOR_SPECIFIC; + wmm->len = sizeof(*wmm) - 2; + + wmm->oui[0] = 0x00; /* Microsoft OUI 00:50:F2 */ + wmm->oui[1] = 0x50; + wmm->oui[2] = 0xf2; + wmm->oui_type = 2; /* WME */ + wmm->oui_subtype = 1; /* WME param */ + wmm->version = 1; /* WME ver */ + wmm->qos_info = 0; /* U-APSD not in use */ + + /* + * Use the EDCA parameters defined for the BSS, or 
default if the AP + * doesn't support it, as mandated by 802.11-2012 section 10.22.4 + */ + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + txq = &sdata->deflink.tx_conf[ieee80211_ac_from_wmm(i)]; + wmm->ac[i].aci_aifsn = ieee80211_wmm_aci_aifsn(txq->aifs, + txq->acm, i); + wmm->ac[i].cw = ieee80211_wmm_ecw(txq->cw_min, txq->cw_max); + wmm->ac[i].txop_limit = cpu_to_le16(txq->txop); + } +} + +static void +ieee80211_tdls_chandef_vht_upgrade(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + /* IEEE802.11ac-2013 Table E-4 */ + u16 centers_80mhz[] = { 5210, 5290, 5530, 5610, 5690, 5775 }; + struct cfg80211_chan_def uc = sta->tdls_chandef; + enum nl80211_chan_width max_width = + ieee80211_sta_cap_chan_bw(&sta->deflink); + int i; + + /* only support upgrading non-narrow channels up to 80Mhz */ + if (max_width == NL80211_CHAN_WIDTH_5 || + max_width == NL80211_CHAN_WIDTH_10) + return; + + if (max_width > NL80211_CHAN_WIDTH_80) + max_width = NL80211_CHAN_WIDTH_80; + + if (uc.width >= max_width) + return; + /* + * Channel usage constrains in the IEEE802.11ac-2013 specification only + * allow expanding a 20MHz channel to 80MHz in a single way. In + * addition, there are no 40MHz allowed channels that are not part of + * the allowed 80MHz range in the 5GHz spectrum (the relevant one here). + */ + for (i = 0; i < ARRAY_SIZE(centers_80mhz); i++) + if (abs(uc.chan->center_freq - centers_80mhz[i]) <= 30) { + uc.center_freq1 = centers_80mhz[i]; + uc.center_freq2 = 0; + uc.width = NL80211_CHAN_WIDTH_80; + break; + } + + if (!uc.center_freq1) + return; + + /* proceed to downgrade the chandef until usable or the same as AP BW */ + while (uc.width > max_width || + (uc.width > sta->tdls_chandef.width && + !cfg80211_reg_can_beacon_relax(sdata->local->hw.wiphy, &uc, + sdata->wdev.iftype))) + ieee80211_chandef_downgrade(&uc); + + if (!cfg80211_chandef_identical(&uc, &sta->tdls_chandef)) { + tdls_dbg(sdata, "TDLS ch width upgraded %d -> %d\n", + sta->tdls_chandef.width, uc.width); + + /* + * the station is not yet authorized when BW upgrade is done, + * locking is not required + */ + sta->tdls_chandef = uc; + } +} + +static void +ieee80211_tdls_add_setup_start_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + u8 action_code, bool initiator, + const u8 *extra_ies, size_t extra_ies_len) +{ + struct ieee80211_supported_band *sband; + struct ieee80211_local *local = sdata->local; + struct ieee80211_sta_ht_cap ht_cap; + struct ieee80211_sta_vht_cap vht_cap; + struct sta_info *sta = NULL; + size_t offset = 0, noffset; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return; + + ieee80211_add_srates_ie(sdata, skb, false, sband->band); + ieee80211_add_ext_srates_ie(sdata, skb, false, sband->band); + ieee80211_tdls_add_supp_channels(sdata, skb); + + /* add any custom IEs that go before Extended Capabilities */ + if (extra_ies_len) { + static const u8 before_ext_cap[] = { + WLAN_EID_SUPP_RATES, + WLAN_EID_COUNTRY, + WLAN_EID_EXT_SUPP_RATES, + WLAN_EID_SUPPORTED_CHANNELS, + WLAN_EID_RSN, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_ext_cap, + ARRAY_SIZE(before_ext_cap), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + ieee80211_tdls_add_ext_capab(sdata, skb); + + /* add the QoS element if we support it */ + if (local->hw.queues >= IEEE80211_NUM_ACS && + action_code != WLAN_PUB_ACTION_TDLS_DISCOVER_RES) + ieee80211_add_wmm_info_ie(skb_put(skb, 9), 0); /* no U-APSD */ + + /* add any custom IEs 
that go before HT capabilities */ + if (extra_ies_len) { + static const u8 before_ht_cap[] = { + WLAN_EID_SUPP_RATES, + WLAN_EID_COUNTRY, + WLAN_EID_EXT_SUPP_RATES, + WLAN_EID_SUPPORTED_CHANNELS, + WLAN_EID_RSN, + WLAN_EID_EXT_CAPABILITY, + WLAN_EID_QOS_CAPA, + WLAN_EID_FAST_BSS_TRANSITION, + WLAN_EID_TIMEOUT_INTERVAL, + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_ht_cap, + ARRAY_SIZE(before_ht_cap), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + mutex_lock(&local->sta_mtx); + + /* we should have the peer STA if we're already responding */ + if (action_code == WLAN_TDLS_SETUP_RESPONSE) { + sta = sta_info_get(sdata, peer); + if (WARN_ON_ONCE(!sta)) { + mutex_unlock(&local->sta_mtx); + return; + } + + sta->tdls_chandef = sdata->vif.bss_conf.chandef; + } + + ieee80211_tdls_add_oper_classes(sdata, skb); + + /* + * with TDLS we can switch channels, and HT-caps are not necessarily + * the same on all bands. The specification limits the setup to a + * single HT-cap, so use the current band for now. + */ + memcpy(&ht_cap, &sband->ht_cap, sizeof(ht_cap)); + + if ((action_code == WLAN_TDLS_SETUP_REQUEST || + action_code == WLAN_PUB_ACTION_TDLS_DISCOVER_RES) && + ht_cap.ht_supported) { + ieee80211_apply_htcap_overrides(sdata, &ht_cap); + + /* disable SMPS in TDLS initiator */ + ht_cap.cap |= WLAN_HT_CAP_SM_PS_DISABLED + << IEEE80211_HT_CAP_SM_PS_SHIFT; + + pos = skb_put(skb, sizeof(struct ieee80211_ht_cap) + 2); + ieee80211_ie_build_ht_cap(pos, &ht_cap, ht_cap.cap); + } else if (action_code == WLAN_TDLS_SETUP_RESPONSE && + ht_cap.ht_supported && sta->sta.deflink.ht_cap.ht_supported) { + /* the peer caps are already intersected with our own */ + memcpy(&ht_cap, &sta->sta.deflink.ht_cap, sizeof(ht_cap)); + + pos = skb_put(skb, sizeof(struct ieee80211_ht_cap) + 2); + ieee80211_ie_build_ht_cap(pos, &ht_cap, ht_cap.cap); + } + + if (ht_cap.ht_supported && + (ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40)) + ieee80211_tdls_add_bss_coex_ie(skb); + + ieee80211_tdls_add_link_ie(sdata, skb, peer, initiator); + + /* add any custom IEs that go before VHT capabilities */ + if (extra_ies_len) { + static const u8 before_vht_cap[] = { + WLAN_EID_SUPP_RATES, + WLAN_EID_COUNTRY, + WLAN_EID_EXT_SUPP_RATES, + WLAN_EID_SUPPORTED_CHANNELS, + WLAN_EID_RSN, + WLAN_EID_EXT_CAPABILITY, + WLAN_EID_QOS_CAPA, + WLAN_EID_FAST_BSS_TRANSITION, + WLAN_EID_TIMEOUT_INTERVAL, + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + WLAN_EID_MULTI_BAND, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_vht_cap, + ARRAY_SIZE(before_vht_cap), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + /* build the VHT-cap similarly to the HT-cap */ + memcpy(&vht_cap, &sband->vht_cap, sizeof(vht_cap)); + if ((action_code == WLAN_TDLS_SETUP_REQUEST || + action_code == WLAN_PUB_ACTION_TDLS_DISCOVER_RES) && + vht_cap.vht_supported) { + ieee80211_apply_vhtcap_overrides(sdata, &vht_cap); + + /* the AID is present only when VHT is implemented */ + if (action_code == WLAN_TDLS_SETUP_REQUEST) + ieee80211_tdls_add_aid(sdata, skb); + + pos = skb_put(skb, sizeof(struct ieee80211_vht_cap) + 2); + ieee80211_ie_build_vht_cap(pos, &vht_cap, vht_cap.cap); + } else if (action_code == WLAN_TDLS_SETUP_RESPONSE && + vht_cap.vht_supported && sta->sta.deflink.vht_cap.vht_supported) { + /* the peer caps are already intersected with our own */ + memcpy(&vht_cap, &sta->sta.deflink.vht_cap, 
sizeof(vht_cap)); + + /* the AID is present only when VHT is implemented */ + ieee80211_tdls_add_aid(sdata, skb); + + pos = skb_put(skb, sizeof(struct ieee80211_vht_cap) + 2); + ieee80211_ie_build_vht_cap(pos, &vht_cap, vht_cap.cap); + + /* + * if both peers support WIDER_BW, we can expand the chandef to + * a wider compatible one, up to 80MHz + */ + if (test_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW)) + ieee80211_tdls_chandef_vht_upgrade(sdata, sta); + } + + mutex_unlock(&local->sta_mtx); + + /* add any remaining IEs */ + if (extra_ies_len) { + noffset = extra_ies_len; + skb_put_data(skb, extra_ies + offset, noffset - offset); + } + +} + +static void +ieee80211_tdls_add_setup_cfm_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len) +{ + struct ieee80211_local *local = sdata->local; + size_t offset = 0, noffset; + struct sta_info *sta, *ap_sta; + struct ieee80211_supported_band *sband; + u8 *pos; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return; + + mutex_lock(&local->sta_mtx); + + sta = sta_info_get(sdata, peer); + ap_sta = sta_info_get(sdata, sdata->deflink.u.mgd.bssid); + if (WARN_ON_ONCE(!sta || !ap_sta)) { + mutex_unlock(&local->sta_mtx); + return; + } + + sta->tdls_chandef = sdata->vif.bss_conf.chandef; + + /* add any custom IEs that go before the QoS IE */ + if (extra_ies_len) { + static const u8 before_qos[] = { + WLAN_EID_RSN, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_qos, + ARRAY_SIZE(before_qos), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + /* add the QoS param IE if both the peer and we support it */ + if (local->hw.queues >= IEEE80211_NUM_ACS && sta->sta.wme) + ieee80211_tdls_add_wmm_param_ie(sdata, skb); + + /* add any custom IEs that go before HT operation */ + if (extra_ies_len) { + static const u8 before_ht_op[] = { + WLAN_EID_RSN, + WLAN_EID_QOS_CAPA, + WLAN_EID_FAST_BSS_TRANSITION, + WLAN_EID_TIMEOUT_INTERVAL, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_ht_op, + ARRAY_SIZE(before_ht_op), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + /* + * if HT support is only added in TDLS, we need an HT-operation IE. + * add the IE as required by IEEE802.11-2012 9.23.3.2. 
+ */ + if (!ap_sta->sta.deflink.ht_cap.ht_supported && sta->sta.deflink.ht_cap.ht_supported) { + u16 prot = IEEE80211_HT_OP_MODE_PROTECTION_NONHT_MIXED | + IEEE80211_HT_OP_MODE_NON_GF_STA_PRSNT | + IEEE80211_HT_OP_MODE_NON_HT_STA_PRSNT; + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_ht_operation)); + ieee80211_ie_build_ht_oper(pos, &sta->sta.deflink.ht_cap, + &sdata->vif.bss_conf.chandef, prot, + true); + } + + ieee80211_tdls_add_link_ie(sdata, skb, peer, initiator); + + /* only include VHT-operation if not on the 2.4GHz band */ + if (sband->band != NL80211_BAND_2GHZ && + sta->sta.deflink.vht_cap.vht_supported) { + /* + * if both peers support WIDER_BW, we can expand the chandef to + * a wider compatible one, up to 80MHz + */ + if (test_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW)) + ieee80211_tdls_chandef_vht_upgrade(sdata, sta); + + pos = skb_put(skb, 2 + sizeof(struct ieee80211_vht_operation)); + ieee80211_ie_build_vht_oper(pos, &sta->sta.deflink.vht_cap, + &sta->tdls_chandef); + } + + mutex_unlock(&local->sta_mtx); + + /* add any remaining IEs */ + if (extra_ies_len) { + noffset = extra_ies_len; + skb_put_data(skb, extra_ies + offset, noffset - offset); + } +} + +static void +ieee80211_tdls_add_chan_switch_req_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len, u8 oper_class, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_tdls_data *tf; + size_t offset = 0, noffset; + + if (WARN_ON_ONCE(!chandef)) + return; + + tf = (void *)skb->data; + tf->u.chan_switch_req.target_channel = + ieee80211_frequency_to_channel(chandef->chan->center_freq); + tf->u.chan_switch_req.oper_class = oper_class; + + if (extra_ies_len) { + static const u8 before_lnkie[] = { + WLAN_EID_SECONDARY_CHANNEL_OFFSET, + }; + noffset = ieee80211_ie_split(extra_ies, extra_ies_len, + before_lnkie, + ARRAY_SIZE(before_lnkie), + offset); + skb_put_data(skb, extra_ies + offset, noffset - offset); + offset = noffset; + } + + ieee80211_tdls_add_link_ie(sdata, skb, peer, initiator); + + /* add any remaining IEs */ + if (extra_ies_len) { + noffset = extra_ies_len; + skb_put_data(skb, extra_ies + offset, noffset - offset); + } +} + +static void +ieee80211_tdls_add_chan_switch_resp_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + u16 status_code, bool initiator, + const u8 *extra_ies, + size_t extra_ies_len) +{ + if (status_code == 0) + ieee80211_tdls_add_link_ie(sdata, skb, peer, initiator); + + if (extra_ies_len) + skb_put_data(skb, extra_ies, extra_ies_len); +} + +static void ieee80211_tdls_add_ies(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, const u8 *peer, + u8 action_code, u16 status_code, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len, u8 oper_class, + struct cfg80211_chan_def *chandef) +{ + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + case WLAN_TDLS_SETUP_RESPONSE: + case WLAN_PUB_ACTION_TDLS_DISCOVER_RES: + if (status_code == 0) + ieee80211_tdls_add_setup_start_ies(sdata, skb, peer, + action_code, + initiator, + extra_ies, + extra_ies_len); + break; + case WLAN_TDLS_SETUP_CONFIRM: + if (status_code == 0) + ieee80211_tdls_add_setup_cfm_ies(sdata, skb, peer, + initiator, extra_ies, + extra_ies_len); + break; + case WLAN_TDLS_TEARDOWN: + case WLAN_TDLS_DISCOVERY_REQUEST: + if (extra_ies_len) + skb_put_data(skb, extra_ies, extra_ies_len); + if (status_code == 0 || action_code == WLAN_TDLS_TEARDOWN) + ieee80211_tdls_add_link_ie(sdata, skb, peer, initiator); + 
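The ieee80211_tdls_add_link_ie() call just above (and in the other frame builders in this file) appends the TDLS Link Identifier element, whose initiator/responder address order is what the role bookkeeping elsewhere in this file is protecting. A hedged standalone sketch of that element, with the element ID value and field order assumed from 802.11-2012 8.4.2.64 rather than taken from this patch:

/*
 * Illustrative Link Identifier element: EID, length 18, BSSID, then
 * the TDLS initiator and responder STA addresses. Which of the local
 * and peer addresses goes where depends on who initiated the setup.
 */
#include <stdint.h>
#include <string.h>

#define EID_LINK_ID	101	/* assumed WLAN_EID_LINK_ID value */
#define ETH_ALEN	6

static uint8_t *build_link_id_ie(uint8_t *pos, const uint8_t bssid[ETH_ALEN],
				 const uint8_t own[ETH_ALEN],
				 const uint8_t peer[ETH_ALEN], int we_initiated)
{
	*pos++ = EID_LINK_ID;
	*pos++ = 3 * ETH_ALEN;			/* 18 octets of payload */
	memcpy(pos, bssid, ETH_ALEN);
	pos += ETH_ALEN;
	/* initiator address first, responder address second */
	memcpy(pos, we_initiated ? own : peer, ETH_ALEN);
	pos += ETH_ALEN;
	memcpy(pos, we_initiated ? peer : own, ETH_ALEN);
	return pos + ETH_ALEN;
}

int main(void)
{
	uint8_t buf[2 + 3 * ETH_ALEN];
	static const uint8_t bssid[ETH_ALEN] = { 0x02, 0, 0, 0, 0, 0x01 };
	static const uint8_t own[ETH_ALEN]   = { 0x02, 0, 0, 0, 0, 0x02 };
	static const uint8_t peer[ETH_ALEN]  = { 0x02, 0, 0, 0, 0, 0x03 };

	build_link_id_ie(buf, bssid, own, peer, 1 /* we initiated */);
	return buf[1] == 18 ? 0 : 1;
}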
break; + case WLAN_TDLS_CHANNEL_SWITCH_REQUEST: + ieee80211_tdls_add_chan_switch_req_ies(sdata, skb, peer, + initiator, extra_ies, + extra_ies_len, + oper_class, chandef); + break; + case WLAN_TDLS_CHANNEL_SWITCH_RESPONSE: + ieee80211_tdls_add_chan_switch_resp_ies(sdata, skb, peer, + status_code, + initiator, extra_ies, + extra_ies_len); + break; + } + +} + +static int +ieee80211_prep_tdls_encap_data(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_tdls_data *tf; + + tf = skb_put(skb, offsetof(struct ieee80211_tdls_data, u)); + + memcpy(tf->da, peer, ETH_ALEN); + memcpy(tf->sa, sdata->vif.addr, ETH_ALEN); + tf->ether_type = cpu_to_be16(ETH_P_TDLS); + tf->payload_type = WLAN_TDLS_SNAP_RFTYPE; + + /* network header is after the ethernet header */ + skb_set_network_header(skb, ETH_HLEN); + + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_SETUP_REQUEST; + + skb_put(skb, sizeof(tf->u.setup_req)); + tf->u.setup_req.dialog_token = dialog_token; + tf->u.setup_req.capability = + cpu_to_le16(ieee80211_get_tdls_sta_capab(sdata, + status_code)); + break; + case WLAN_TDLS_SETUP_RESPONSE: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_SETUP_RESPONSE; + + skb_put(skb, sizeof(tf->u.setup_resp)); + tf->u.setup_resp.status_code = cpu_to_le16(status_code); + tf->u.setup_resp.dialog_token = dialog_token; + tf->u.setup_resp.capability = + cpu_to_le16(ieee80211_get_tdls_sta_capab(sdata, + status_code)); + break; + case WLAN_TDLS_SETUP_CONFIRM: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_SETUP_CONFIRM; + + skb_put(skb, sizeof(tf->u.setup_cfm)); + tf->u.setup_cfm.status_code = cpu_to_le16(status_code); + tf->u.setup_cfm.dialog_token = dialog_token; + break; + case WLAN_TDLS_TEARDOWN: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_TEARDOWN; + + skb_put(skb, sizeof(tf->u.teardown)); + tf->u.teardown.reason_code = cpu_to_le16(status_code); + break; + case WLAN_TDLS_DISCOVERY_REQUEST: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_DISCOVERY_REQUEST; + + skb_put(skb, sizeof(tf->u.discover_req)); + tf->u.discover_req.dialog_token = dialog_token; + break; + case WLAN_TDLS_CHANNEL_SWITCH_REQUEST: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_CHANNEL_SWITCH_REQUEST; + + skb_put(skb, sizeof(tf->u.chan_switch_req)); + break; + case WLAN_TDLS_CHANNEL_SWITCH_RESPONSE: + tf->category = WLAN_CATEGORY_TDLS; + tf->action_code = WLAN_TDLS_CHANNEL_SWITCH_RESPONSE; + + skb_put(skb, sizeof(tf->u.chan_switch_resp)); + tf->u.chan_switch_resp.status_code = cpu_to_le16(status_code); + break; + default: + return -EINVAL; + } + + return 0; +} + +static int +ieee80211_prep_tdls_direct(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_mgmt *mgmt; + + mgmt = skb_put_zero(skb, 24); + memcpy(mgmt->da, peer, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + switch (action_code) { + case WLAN_PUB_ACTION_TDLS_DISCOVER_RES: + skb_put(skb, 1 + 
sizeof(mgmt->u.action.u.tdls_discover_resp)); + mgmt->u.action.category = WLAN_CATEGORY_PUBLIC; + mgmt->u.action.u.tdls_discover_resp.action_code = + WLAN_PUB_ACTION_TDLS_DISCOVER_RES; + mgmt->u.action.u.tdls_discover_resp.dialog_token = + dialog_token; + mgmt->u.action.u.tdls_discover_resp.capability = + cpu_to_le16(ieee80211_get_tdls_sta_capab(sdata, + status_code)); + break; + default: + return -EINVAL; + } + + return 0; +} + +static struct sk_buff * +ieee80211_tdls_build_mgmt_packet_data(struct ieee80211_sub_if_data *sdata, + const u8 *peer, u8 action_code, + u8 dialog_token, u16 status_code, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len, u8 oper_class, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + int ret; + + skb = netdev_alloc_skb(sdata->dev, + local->hw.extra_tx_headroom + + max(sizeof(struct ieee80211_mgmt), + sizeof(struct ieee80211_tdls_data)) + + 50 + /* supported rates */ + 10 + /* ext capab */ + 26 + /* max(WMM-info, WMM-param) */ + 2 + max(sizeof(struct ieee80211_ht_cap), + sizeof(struct ieee80211_ht_operation)) + + 2 + max(sizeof(struct ieee80211_vht_cap), + sizeof(struct ieee80211_vht_operation)) + + 50 + /* supported channels */ + 3 + /* 40/20 BSS coex */ + 4 + /* AID */ + 4 + /* oper classes */ + extra_ies_len + + sizeof(struct ieee80211_tdls_lnkie)); + if (!skb) + return NULL; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + case WLAN_TDLS_SETUP_RESPONSE: + case WLAN_TDLS_SETUP_CONFIRM: + case WLAN_TDLS_TEARDOWN: + case WLAN_TDLS_DISCOVERY_REQUEST: + case WLAN_TDLS_CHANNEL_SWITCH_REQUEST: + case WLAN_TDLS_CHANNEL_SWITCH_RESPONSE: + ret = ieee80211_prep_tdls_encap_data(local->hw.wiphy, + sdata->dev, peer, + action_code, dialog_token, + status_code, skb); + break; + case WLAN_PUB_ACTION_TDLS_DISCOVER_RES: + ret = ieee80211_prep_tdls_direct(local->hw.wiphy, sdata->dev, + peer, action_code, + dialog_token, status_code, + skb); + break; + default: + ret = -ENOTSUPP; + break; + } + + if (ret < 0) + goto fail; + + ieee80211_tdls_add_ies(sdata, skb, peer, action_code, status_code, + initiator, extra_ies, extra_ies_len, oper_class, + chandef); + return skb; + +fail: + dev_kfree_skb(skb); + return NULL; +} + +static int +ieee80211_tdls_prep_mgmt_packet(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len, u8 oper_class, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct sk_buff *skb = NULL; + struct sta_info *sta; + u32 flags = 0; + int ret = 0; + + rcu_read_lock(); + sta = sta_info_get(sdata, peer); + + /* infer the initiator if we can, to support old userspace */ + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + if (sta) { + set_sta_flag(sta, WLAN_STA_TDLS_INITIATOR); + sta->sta.tdls_initiator = false; + } + fallthrough; + case WLAN_TDLS_SETUP_CONFIRM: + case WLAN_TDLS_DISCOVERY_REQUEST: + initiator = true; + break; + case WLAN_TDLS_SETUP_RESPONSE: + /* + * In some testing scenarios, we send a request and response. + * Make the last packet sent take effect for the initiator + * value. 
+ */ + if (sta) { + clear_sta_flag(sta, WLAN_STA_TDLS_INITIATOR); + sta->sta.tdls_initiator = true; + } + fallthrough; + case WLAN_PUB_ACTION_TDLS_DISCOVER_RES: + initiator = false; + break; + case WLAN_TDLS_TEARDOWN: + case WLAN_TDLS_CHANNEL_SWITCH_REQUEST: + case WLAN_TDLS_CHANNEL_SWITCH_RESPONSE: + /* any value is ok */ + break; + default: + ret = -ENOTSUPP; + break; + } + + if (sta && test_sta_flag(sta, WLAN_STA_TDLS_INITIATOR)) + initiator = true; + + rcu_read_unlock(); + if (ret < 0) + goto fail; + + skb = ieee80211_tdls_build_mgmt_packet_data(sdata, peer, action_code, + dialog_token, status_code, + initiator, extra_ies, + extra_ies_len, oper_class, + chandef); + if (!skb) { + ret = -EINVAL; + goto fail; + } + + if (action_code == WLAN_PUB_ACTION_TDLS_DISCOVER_RES) { + ieee80211_tx_skb(sdata, skb); + return 0; + } + + /* + * According to 802.11z: Setup req/resp are sent in AC_BK, otherwise + * we should default to AC_VI. + */ + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + case WLAN_TDLS_SETUP_RESPONSE: + skb->priority = 256 + 2; + break; + default: + skb->priority = 256 + 5; + break; + } + skb_set_queue_mapping(skb, ieee80211_select_queue(sdata, skb)); + + /* + * Set the WLAN_TDLS_TEARDOWN flag to indicate a teardown in progress. + * Later, if no ACK is returned from peer, we will re-send the teardown + * packet through the AP. + */ + if ((action_code == WLAN_TDLS_TEARDOWN) && + ieee80211_hw_check(&sdata->local->hw, REPORTS_TX_ACK_STATUS)) { + bool try_resend; /* Should we keep skb for possible resend */ + + /* If not sending directly to peer - no point in keeping skb */ + rcu_read_lock(); + sta = sta_info_get(sdata, peer); + try_resend = sta && test_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH); + rcu_read_unlock(); + + spin_lock_bh(&sdata->u.mgd.teardown_lock); + if (try_resend && !sdata->u.mgd.teardown_skb) { + /* Mark it as requiring TX status callback */ + flags |= IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_INTFL_MLME_CONN_TX; + + /* + * skb is copied since mac80211 will later set + * properties that might not be the same as the AP, + * such as encryption, QoS, addresses, etc. + * + * No problem if skb_copy() fails, so no need to check. 
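The skb->priority values 256 + 2 and 256 + 5 chosen earlier in this hunk land the setup frames in AC_BK and everything else in AC_VI because the low bits are treated as an 802.1D user priority. A minimal sketch of the standard user-priority to access-category table those values rely on (assumed from 802.11, not a quote of cfg80211's classifier):

/* Illustrative 802.1D user priority -> WMM access category mapping. */
#include <stdio.h>

enum wmm_ac { AC_VO, AC_VI, AC_BE, AC_BK };

static enum wmm_ac up_to_ac(unsigned int up)
{
	switch (up & 7) {
	case 1:
	case 2:
		return AC_BK;	/* background: TDLS setup req/resp use UP 2 */
	case 4:
	case 5:
		return AC_VI;	/* video: other TDLS management uses UP 5 */
	case 6:
	case 7:
		return AC_VO;
	default:		/* UP 0 and 3 */
		return AC_BE;
	}
}

int main(void)
{
	printf("UP 2 -> AC %d, UP 5 -> AC %d\n", up_to_ac(2), up_to_ac(5));
	return 0;
}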
+ */ + sdata->u.mgd.teardown_skb = skb_copy(skb, GFP_ATOMIC); + sdata->u.mgd.orig_teardown_skb = skb; + } + spin_unlock_bh(&sdata->u.mgd.teardown_lock); + } + + /* disable bottom halves when entering the Tx path */ + local_bh_disable(); + __ieee80211_subif_start_xmit(skb, dev, flags, + IEEE80211_TX_CTRL_MLO_LINK_UNSPEC, NULL); + local_bh_enable(); + + return ret; + +fail: + dev_kfree_skb(skb); + return ret; +} + +static int +ieee80211_tdls_mgmt_setup(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, bool initiator, + const u8 *extra_ies, size_t extra_ies_len) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + enum ieee80211_smps_mode smps_mode = + sdata->deflink.u.mgd.driver_smps_mode; + int ret; + + /* don't support setup with forced SMPS mode that's not off */ + if (smps_mode != IEEE80211_SMPS_AUTOMATIC && + smps_mode != IEEE80211_SMPS_OFF) { + tdls_dbg(sdata, "Aborting TDLS setup due to SMPS mode %d\n", + smps_mode); + return -ENOTSUPP; + } + + mutex_lock(&local->mtx); + + /* we don't support concurrent TDLS peer setups */ + if (!is_zero_ether_addr(sdata->u.mgd.tdls_peer) && + !ether_addr_equal(sdata->u.mgd.tdls_peer, peer)) { + ret = -EBUSY; + goto out_unlock; + } + + /* + * make sure we have a STA representing the peer so we drop or buffer + * non-TDLS-setup frames to the peer. We can't send other packets + * during setup through the AP path. + * Allow error packets to be sent - sometimes we don't even add a STA + * before failing the setup. + */ + if (status_code == 0) { + rcu_read_lock(); + if (!sta_info_get(sdata, peer)) { + rcu_read_unlock(); + ret = -ENOLINK; + goto out_unlock; + } + rcu_read_unlock(); + } + + ieee80211_flush_queues(local, sdata, false); + memcpy(sdata->u.mgd.tdls_peer, peer, ETH_ALEN); + mutex_unlock(&local->mtx); + + /* we cannot take the mutex while preparing the setup packet */ + ret = ieee80211_tdls_prep_mgmt_packet(wiphy, dev, peer, action_code, + dialog_token, status_code, + peer_capability, initiator, + extra_ies, extra_ies_len, 0, + NULL); + if (ret < 0) { + mutex_lock(&local->mtx); + eth_zero_addr(sdata->u.mgd.tdls_peer); + mutex_unlock(&local->mtx); + return ret; + } + + ieee80211_queue_delayed_work(&sdata->local->hw, + &sdata->u.mgd.tdls_peer_del_work, + TDLS_PEER_SETUP_TIMEOUT); + return 0; + +out_unlock: + mutex_unlock(&local->mtx); + return ret; +} + +static int +ieee80211_tdls_mgmt_teardown(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + int ret; + + /* + * No packets can be transmitted to the peer via the AP during setup - + * the STA is set as a TDLS peer, but is not authorized. + * During teardown, we prevent direct transmissions by stopping the + * queues and flushing all direct packets. 
+ */ + ieee80211_stop_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_TDLS_TEARDOWN); + ieee80211_flush_queues(local, sdata, false); + + ret = ieee80211_tdls_prep_mgmt_packet(wiphy, dev, peer, action_code, + dialog_token, status_code, + peer_capability, initiator, + extra_ies, extra_ies_len, 0, + NULL); + if (ret < 0) + sdata_err(sdata, "Failed sending TDLS teardown packet %d\n", + ret); + + /* + * Remove the STA AUTH flag to force further traffic through the AP. If + * the STA was unreachable, it was already removed. + */ + rcu_read_lock(); + sta = sta_info_get(sdata, peer); + if (sta) + clear_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH); + rcu_read_unlock(); + + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_TDLS_TEARDOWN); + + return 0; +} + +int ieee80211_tdls_mgmt(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *extra_ies, + size_t extra_ies_len) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + int ret; + + if (!(wiphy->flags & WIPHY_FLAG_SUPPORTS_TDLS)) + return -ENOTSUPP; + + /* make sure we are in managed mode, and associated */ + if (sdata->vif.type != NL80211_IFTYPE_STATION || + !sdata->u.mgd.associated) + return -EINVAL; + + switch (action_code) { + case WLAN_TDLS_SETUP_REQUEST: + case WLAN_TDLS_SETUP_RESPONSE: + ret = ieee80211_tdls_mgmt_setup(wiphy, dev, peer, action_code, + dialog_token, status_code, + peer_capability, initiator, + extra_ies, extra_ies_len); + break; + case WLAN_TDLS_TEARDOWN: + ret = ieee80211_tdls_mgmt_teardown(wiphy, dev, peer, + action_code, dialog_token, + status_code, + peer_capability, initiator, + extra_ies, extra_ies_len); + break; + case WLAN_TDLS_DISCOVERY_REQUEST: + /* + * Protect the discovery so we can hear the TDLS discovery + * response frame. It is transmitted directly and not buffered + * by the AP. 
+ */ + drv_mgd_protect_tdls_discover(sdata->local, sdata); + fallthrough; + case WLAN_TDLS_SETUP_CONFIRM: + case WLAN_PUB_ACTION_TDLS_DISCOVER_RES: + /* no special handling */ + ret = ieee80211_tdls_prep_mgmt_packet(wiphy, dev, peer, + action_code, + dialog_token, + status_code, + peer_capability, + initiator, extra_ies, + extra_ies_len, 0, NULL); + break; + default: + ret = -EOPNOTSUPP; + break; + } + + tdls_dbg(sdata, "TDLS mgmt action %d peer %pM status %d\n", + action_code, peer, ret); + return ret; +} + +static void iee80211_tdls_recalc_chanctx(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *ctx; + enum nl80211_chan_width width; + struct ieee80211_supported_band *sband; + + mutex_lock(&local->chanctx_mtx); + conf = rcu_dereference_protected(sdata->vif.bss_conf.chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (conf) { + width = conf->def.width; + sband = local->hw.wiphy->bands[conf->def.chan->band]; + ctx = container_of(conf, struct ieee80211_chanctx, conf); + ieee80211_recalc_chanctx_chantype(local, ctx); + + /* if width changed and a peer is given, update its BW */ + if (width != conf->def.width && sta && + test_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW)) { + enum ieee80211_sta_rx_bandwidth bw; + + bw = ieee80211_chan_width_to_rx_bw(conf->def.width); + bw = min(bw, ieee80211_sta_cap_rx_bw(&sta->deflink)); + if (bw != sta->sta.deflink.bandwidth) { + sta->sta.deflink.bandwidth = bw; + rate_control_rate_update(local, sband, sta, 0, + IEEE80211_RC_BW_CHANGED); + /* + * if a TDLS peer BW was updated, we need to + * recalc the chandef width again, to get the + * correct chanctx min_def + */ + ieee80211_recalc_chanctx_chantype(local, ctx); + } + } + + } + mutex_unlock(&local->chanctx_mtx); +} + +static int iee80211_tdls_have_ht_peers(struct ieee80211_sub_if_data *sdata) +{ + struct sta_info *sta; + bool result = false; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) { + if (!sta->sta.tdls || sta->sdata != sdata || !sta->uploaded || + !test_sta_flag(sta, WLAN_STA_AUTHORIZED) || + !test_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH) || + !sta->sta.deflink.ht_cap.ht_supported) + continue; + result = true; + break; + } + rcu_read_unlock(); + + return result; +} + +static void +iee80211_tdls_recalc_ht_protection(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta) +{ + bool tdls_ht; + u16 protection = IEEE80211_HT_OP_MODE_PROTECTION_NONHT_MIXED | + IEEE80211_HT_OP_MODE_NON_GF_STA_PRSNT | + IEEE80211_HT_OP_MODE_NON_HT_STA_PRSNT; + u16 opmode; + + /* Nothing to do if the BSS connection uses HT */ + if (!(sdata->deflink.u.mgd.conn_flags & IEEE80211_CONN_DISABLE_HT)) + return; + + tdls_ht = (sta && sta->sta.deflink.ht_cap.ht_supported) || + iee80211_tdls_have_ht_peers(sdata); + + opmode = sdata->vif.bss_conf.ht_operation_mode; + + if (tdls_ht) + opmode |= protection; + else + opmode &= ~protection; + + if (opmode == sdata->vif.bss_conf.ht_operation_mode) + return; + + sdata->vif.bss_conf.ht_operation_mode = opmode; + ieee80211_link_info_change_notify(sdata, &sdata->deflink, + BSS_CHANGED_HT); +} + +int ieee80211_tdls_oper(struct wiphy *wiphy, struct net_device *dev, + const u8 *peer, enum nl80211_tdls_operation oper) +{ + struct sta_info *sta; + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + int ret; + + if (!(wiphy->flags & WIPHY_FLAG_SUPPORTS_TDLS)) + return 
-ENOTSUPP; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return -EINVAL; + + switch (oper) { + case NL80211_TDLS_ENABLE_LINK: + case NL80211_TDLS_DISABLE_LINK: + break; + case NL80211_TDLS_TEARDOWN: + case NL80211_TDLS_SETUP: + case NL80211_TDLS_DISCOVERY_REQ: + /* We don't support in-driver setup/teardown/discovery */ + return -ENOTSUPP; + } + + /* protect possible bss_conf changes and avoid concurrency in + * ieee80211_bss_info_change_notify() + */ + sdata_lock(sdata); + mutex_lock(&local->mtx); + tdls_dbg(sdata, "TDLS oper %d peer %pM\n", oper, peer); + + switch (oper) { + case NL80211_TDLS_ENABLE_LINK: + if (sdata->vif.bss_conf.csa_active) { + tdls_dbg(sdata, "TDLS: disallow link during CSA\n"); + ret = -EBUSY; + break; + } + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, peer); + if (!sta) { + mutex_unlock(&local->sta_mtx); + ret = -ENOLINK; + break; + } + + iee80211_tdls_recalc_chanctx(sdata, sta); + iee80211_tdls_recalc_ht_protection(sdata, sta); + + set_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH); + mutex_unlock(&local->sta_mtx); + + WARN_ON_ONCE(is_zero_ether_addr(sdata->u.mgd.tdls_peer) || + !ether_addr_equal(sdata->u.mgd.tdls_peer, peer)); + ret = 0; + break; + case NL80211_TDLS_DISABLE_LINK: + /* + * The teardown message in ieee80211_tdls_mgmt_teardown() was + * created while the queues were stopped, so it might still be + * pending. Before flushing the queues we need to be sure the + * message is handled by the tasklet handling pending messages, + * otherwise we might start destroying the station before + * sending the teardown packet. + * Note that this only forces the tasklet to flush pendings - + * not to stop the tasklet from rescheduling itself. + */ + tasklet_kill(&local->tx_pending_tasklet); + /* flush a potentially queued teardown packet */ + ieee80211_flush_queues(local, sdata, false); + + ret = sta_info_destroy_addr(sdata, peer); + + mutex_lock(&local->sta_mtx); + iee80211_tdls_recalc_ht_protection(sdata, NULL); + mutex_unlock(&local->sta_mtx); + + iee80211_tdls_recalc_chanctx(sdata, NULL); + break; + default: + ret = -ENOTSUPP; + break; + } + + if (ret == 0 && ether_addr_equal(sdata->u.mgd.tdls_peer, peer)) { + cancel_delayed_work(&sdata->u.mgd.tdls_peer_del_work); + eth_zero_addr(sdata->u.mgd.tdls_peer); + } + + if (ret == 0) + ieee80211_queue_work(&sdata->local->hw, + &sdata->deflink.u.mgd.request_smps_work); + + mutex_unlock(&local->mtx); + sdata_unlock(sdata); + return ret; +} + +void ieee80211_tdls_oper_request(struct ieee80211_vif *vif, const u8 *peer, + enum nl80211_tdls_operation oper, + u16 reason_code, gfp_t gfp) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (vif->type != NL80211_IFTYPE_STATION || !vif->cfg.assoc) { + sdata_err(sdata, "Discarding TDLS oper %d - not STA or disconnected\n", + oper); + return; + } + + cfg80211_tdls_oper_request(sdata->dev, peer, oper, reason_code, gfp); +} +EXPORT_SYMBOL(ieee80211_tdls_oper_request); + +static void +iee80211_tdls_add_ch_switch_timing(u8 *buf, u16 switch_time, u16 switch_timeout) +{ + struct ieee80211_ch_switch_timing *ch_sw; + + *buf++ = WLAN_EID_CHAN_SWITCH_TIMING; + *buf++ = sizeof(struct ieee80211_ch_switch_timing); + + ch_sw = (void *)buf; + ch_sw->switch_time = cpu_to_le16(switch_time); + ch_sw->switch_timeout = cpu_to_le16(switch_timeout); +} + +/* find switch timing IE in SKB ready for Tx */ +static const u8 *ieee80211_tdls_find_sw_timing_ie(struct sk_buff *skb) +{ + struct ieee80211_tdls_data *tf; + const u8 *ie_start; + + /* + * Get the offset for the new location 
of the switch timing IE. + * The SKB network header will now point to the "payload_type" + * element of the TDLS data frame struct. + */ + tf = container_of(skb->data + skb_network_offset(skb), + struct ieee80211_tdls_data, payload_type); + ie_start = tf->u.chan_switch_req.variable; + return cfg80211_find_ie(WLAN_EID_CHAN_SWITCH_TIMING, ie_start, + skb->len - (ie_start - skb->data)); +} + +static struct sk_buff * +ieee80211_tdls_ch_sw_tmpl_get(struct sta_info *sta, u8 oper_class, + struct cfg80211_chan_def *chandef, + u32 *ch_sw_tm_ie_offset) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + u8 extra_ies[2 + sizeof(struct ieee80211_sec_chan_offs_ie) + + 2 + sizeof(struct ieee80211_ch_switch_timing)]; + int extra_ies_len = 2 + sizeof(struct ieee80211_ch_switch_timing); + u8 *pos = extra_ies; + struct sk_buff *skb; + + /* + * if chandef points to a wide channel add a Secondary-Channel + * Offset information element + */ + if (chandef->width == NL80211_CHAN_WIDTH_40) { + struct ieee80211_sec_chan_offs_ie *sec_chan_ie; + bool ht40plus; + + *pos++ = WLAN_EID_SECONDARY_CHANNEL_OFFSET; + *pos++ = sizeof(*sec_chan_ie); + sec_chan_ie = (void *)pos; + + ht40plus = cfg80211_get_chandef_type(chandef) == + NL80211_CHAN_HT40PLUS; + sec_chan_ie->sec_chan_offs = ht40plus ? + IEEE80211_HT_PARAM_CHA_SEC_ABOVE : + IEEE80211_HT_PARAM_CHA_SEC_BELOW; + pos += sizeof(*sec_chan_ie); + + extra_ies_len += 2 + sizeof(struct ieee80211_sec_chan_offs_ie); + } + + /* just set the values to 0, this is a template */ + iee80211_tdls_add_ch_switch_timing(pos, 0, 0); + + skb = ieee80211_tdls_build_mgmt_packet_data(sdata, sta->sta.addr, + WLAN_TDLS_CHANNEL_SWITCH_REQUEST, + 0, 0, !sta->sta.tdls_initiator, + extra_ies, extra_ies_len, + oper_class, chandef); + if (!skb) + return NULL; + + skb = ieee80211_build_data_template(sdata, skb, 0); + if (IS_ERR(skb)) { + tdls_dbg(sdata, "Failed building TDLS channel switch frame\n"); + return NULL; + } + + if (ch_sw_tm_ie_offset) { + const u8 *tm_ie = ieee80211_tdls_find_sw_timing_ie(skb); + + if (!tm_ie) { + tdls_dbg(sdata, "No switch timing IE in TDLS switch\n"); + dev_kfree_skb_any(skb); + return NULL; + } + + *ch_sw_tm_ie_offset = tm_ie - skb->data; + } + + tdls_dbg(sdata, + "TDLS channel switch request template for %pM ch %d width %d\n", + sta->sta.addr, chandef->chan->center_freq, chandef->width); + return skb; +} + +int +ieee80211_tdls_channel_switch(struct wiphy *wiphy, struct net_device *dev, + const u8 *addr, u8 oper_class, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct sk_buff *skb = NULL; + u32 ch_sw_tm_ie; + int ret; + + if (chandef->chan->freq_offset) + /* this may work, but is untested */ + return -EOPNOTSUPP; + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, addr); + if (!sta) { + tdls_dbg(sdata, + "Invalid TDLS peer %pM for channel switch request\n", + addr); + ret = -ENOENT; + goto out; + } + + if (!test_sta_flag(sta, WLAN_STA_TDLS_CHAN_SWITCH)) { + tdls_dbg(sdata, "TDLS channel switch unsupported by %pM\n", + addr); + ret = -ENOTSUPP; + goto out; + } + + skb = ieee80211_tdls_ch_sw_tmpl_get(sta, oper_class, chandef, + &ch_sw_tm_ie); + if (!skb) { + ret = -ENOENT; + goto out; + } + + ret = drv_tdls_channel_switch(local, sdata, &sta->sta, oper_class, + chandef, skb, ch_sw_tm_ie); + if (!ret) + set_sta_flag(sta, WLAN_STA_TDLS_OFF_CHANNEL); + +out: + mutex_unlock(&local->sta_mtx); + dev_kfree_skb_any(skb); + 
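In the channel-switch template builder above, the 40 MHz branch derives the Secondary Channel Offset value from cfg80211_get_chandef_type(). The underlying geometry is simple: the 40 MHz segment center sits 10 MHz above the primary channel for HT40+ and 10 MHz below it for HT40-. A hedged standalone sketch using a toy struct, not the cfg80211 API:

/*
 * Illustrative HT40+/- check: compare the 40 MHz segment center with
 * the primary 20 MHz channel's center frequency.
 */
#include <stdio.h>

struct toy_chandef {
	int control_freq;	/* primary 20 MHz channel center, MHz */
	int center_freq1;	/* 40 MHz segment center, MHz */
};

static int sec_chan_above(const struct toy_chandef *c)
{
	return c->center_freq1 > c->control_freq;
}

int main(void)
{
	/* e.g. primary channel 36 (5180 MHz) with the 40 MHz center at 5190 */
	struct toy_chandef c = { .control_freq = 5180, .center_freq1 = 5190 };

	printf("HT40%c\n", sec_chan_above(&c) ? '+' : '-');
	return 0;
}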
return ret; +} + +void +ieee80211_tdls_cancel_channel_switch(struct wiphy *wiphy, + struct net_device *dev, + const u8 *addr) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, addr); + if (!sta) { + tdls_dbg(sdata, + "Invalid TDLS peer %pM for channel switch cancel\n", + addr); + goto out; + } + + if (!test_sta_flag(sta, WLAN_STA_TDLS_OFF_CHANNEL)) { + tdls_dbg(sdata, "TDLS channel switch not initiated by %pM\n", + addr); + goto out; + } + + drv_tdls_cancel_channel_switch(local, sdata, &sta->sta); + clear_sta_flag(sta, WLAN_STA_TDLS_OFF_CHANNEL); + +out: + mutex_unlock(&local->sta_mtx); +} + +static struct sk_buff * +ieee80211_tdls_ch_sw_resp_tmpl_get(struct sta_info *sta, + u32 *ch_sw_tm_ie_offset) +{ + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct sk_buff *skb; + u8 extra_ies[2 + sizeof(struct ieee80211_ch_switch_timing)]; + + /* initial timing are always zero in the template */ + iee80211_tdls_add_ch_switch_timing(extra_ies, 0, 0); + + skb = ieee80211_tdls_build_mgmt_packet_data(sdata, sta->sta.addr, + WLAN_TDLS_CHANNEL_SWITCH_RESPONSE, + 0, 0, !sta->sta.tdls_initiator, + extra_ies, sizeof(extra_ies), 0, NULL); + if (!skb) + return NULL; + + skb = ieee80211_build_data_template(sdata, skb, 0); + if (IS_ERR(skb)) { + tdls_dbg(sdata, + "Failed building TDLS channel switch resp frame\n"); + return NULL; + } + + if (ch_sw_tm_ie_offset) { + const u8 *tm_ie = ieee80211_tdls_find_sw_timing_ie(skb); + + if (!tm_ie) { + tdls_dbg(sdata, + "No switch timing IE in TDLS switch resp\n"); + dev_kfree_skb_any(skb); + return NULL; + } + + *ch_sw_tm_ie_offset = tm_ie - skb->data; + } + + tdls_dbg(sdata, "TDLS get channel switch response template for %pM\n", + sta->sta.addr); + return skb; +} + +static int +ieee80211_process_tdls_channel_switch_resp(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee802_11_elems *elems = NULL; + struct sta_info *sta; + struct ieee80211_tdls_data *tf = (void *)skb->data; + bool local_initiator; + struct ieee80211_rx_status *rx_status = IEEE80211_SKB_RXCB(skb); + int baselen = offsetof(typeof(*tf), u.chan_switch_resp.variable); + struct ieee80211_tdls_ch_sw_params params = {}; + int ret; + + params.action_code = WLAN_TDLS_CHANNEL_SWITCH_RESPONSE; + params.timestamp = rx_status->device_timestamp; + + if (skb->len < baselen) { + tdls_dbg(sdata, "TDLS channel switch resp too short: %d\n", + skb->len); + return -EINVAL; + } + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, tf->sa); + if (!sta || !test_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH)) { + tdls_dbg(sdata, "TDLS chan switch from non-peer sta %pM\n", + tf->sa); + ret = -EINVAL; + goto out; + } + + params.sta = &sta->sta; + params.status = le16_to_cpu(tf->u.chan_switch_resp.status_code); + if (params.status != 0) { + ret = 0; + goto call_drv; + } + + elems = ieee802_11_parse_elems(tf->u.chan_switch_resp.variable, + skb->len - baselen, false, NULL); + if (!elems) { + ret = -ENOMEM; + goto out; + } + + if (elems->parse_error) { + tdls_dbg(sdata, "Invalid IEs in TDLS channel switch resp\n"); + ret = -EINVAL; + goto out; + } + + if (!elems->ch_sw_timing || !elems->lnk_id) { + tdls_dbg(sdata, "TDLS channel switch resp - missing IEs\n"); + ret = -EINVAL; + goto out; + } + + /* validate the initiator is set correctly */ + local_initiator = + !memcmp(elems->lnk_id->init_sta, 
sdata->vif.addr, ETH_ALEN); + if (local_initiator == sta->sta.tdls_initiator) { + tdls_dbg(sdata, "TDLS chan switch invalid lnk-id initiator\n"); + ret = -EINVAL; + goto out; + } + + params.switch_time = le16_to_cpu(elems->ch_sw_timing->switch_time); + params.switch_timeout = le16_to_cpu(elems->ch_sw_timing->switch_timeout); + + params.tmpl_skb = + ieee80211_tdls_ch_sw_resp_tmpl_get(sta, ¶ms.ch_sw_tm_ie); + if (!params.tmpl_skb) { + ret = -ENOENT; + goto out; + } + + ret = 0; +call_drv: + drv_tdls_recv_channel_switch(sdata->local, sdata, ¶ms); + + tdls_dbg(sdata, + "TDLS channel switch response received from %pM status %d\n", + tf->sa, params.status); + +out: + mutex_unlock(&local->sta_mtx); + dev_kfree_skb_any(params.tmpl_skb); + kfree(elems); + return ret; +} + +static int +ieee80211_process_tdls_channel_switch_req(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee802_11_elems *elems; + struct cfg80211_chan_def chandef; + struct ieee80211_channel *chan; + enum nl80211_channel_type chan_type; + int freq; + u8 target_channel, oper_class; + bool local_initiator; + struct sta_info *sta; + enum nl80211_band band; + struct ieee80211_tdls_data *tf = (void *)skb->data; + struct ieee80211_rx_status *rx_status = IEEE80211_SKB_RXCB(skb); + int baselen = offsetof(typeof(*tf), u.chan_switch_req.variable); + struct ieee80211_tdls_ch_sw_params params = {}; + int ret = 0; + + params.action_code = WLAN_TDLS_CHANNEL_SWITCH_REQUEST; + params.timestamp = rx_status->device_timestamp; + + if (skb->len < baselen) { + tdls_dbg(sdata, "TDLS channel switch req too short: %d\n", + skb->len); + return -EINVAL; + } + + target_channel = tf->u.chan_switch_req.target_channel; + oper_class = tf->u.chan_switch_req.oper_class; + + /* + * We can't easily infer the channel band. The operating class is + * ambiguous - there are multiple tables (US/Europe/JP/Global). The + * solution here is to treat channels with number >14 as 5GHz ones, + * and specifically check for the (oper_class, channel) combinations + * where this doesn't hold. These are thankfully unique according to + * IEEE802.11-2012. + * We consider only the 2GHz and 5GHz bands and 20MHz+ channels as + * valid here. + */ + if ((oper_class == 112 || oper_class == 2 || oper_class == 3 || + oper_class == 4 || oper_class == 5 || oper_class == 6) && + target_channel < 14) + band = NL80211_BAND_5GHZ; + else + band = target_channel < 14 ? 
NL80211_BAND_2GHZ : + NL80211_BAND_5GHZ; + + freq = ieee80211_channel_to_frequency(target_channel, band); + if (freq == 0) { + tdls_dbg(sdata, "Invalid channel in TDLS chan switch: %d\n", + target_channel); + return -EINVAL; + } + + chan = ieee80211_get_channel(sdata->local->hw.wiphy, freq); + if (!chan) { + tdls_dbg(sdata, + "Unsupported channel for TDLS chan switch: %d\n", + target_channel); + return -EINVAL; + } + + elems = ieee802_11_parse_elems(tf->u.chan_switch_req.variable, + skb->len - baselen, false, NULL); + if (!elems) + return -ENOMEM; + + if (elems->parse_error) { + tdls_dbg(sdata, "Invalid IEs in TDLS channel switch req\n"); + ret = -EINVAL; + goto free; + } + + if (!elems->ch_sw_timing || !elems->lnk_id) { + tdls_dbg(sdata, "TDLS channel switch req - missing IEs\n"); + ret = -EINVAL; + goto free; + } + + if (!elems->sec_chan_offs) { + chan_type = NL80211_CHAN_HT20; + } else { + switch (elems->sec_chan_offs->sec_chan_offs) { + case IEEE80211_HT_PARAM_CHA_SEC_ABOVE: + chan_type = NL80211_CHAN_HT40PLUS; + break; + case IEEE80211_HT_PARAM_CHA_SEC_BELOW: + chan_type = NL80211_CHAN_HT40MINUS; + break; + default: + chan_type = NL80211_CHAN_HT20; + break; + } + } + + cfg80211_chandef_create(&chandef, chan, chan_type); + + /* we will be active on the TDLS link */ + if (!cfg80211_reg_can_beacon_relax(sdata->local->hw.wiphy, &chandef, + sdata->wdev.iftype)) { + tdls_dbg(sdata, "TDLS chan switch to forbidden channel\n"); + ret = -EINVAL; + goto free; + } + + mutex_lock(&local->sta_mtx); + sta = sta_info_get(sdata, tf->sa); + if (!sta || !test_sta_flag(sta, WLAN_STA_TDLS_PEER_AUTH)) { + tdls_dbg(sdata, "TDLS chan switch from non-peer sta %pM\n", + tf->sa); + ret = -EINVAL; + goto out; + } + + params.sta = &sta->sta; + + /* validate the initiator is set correctly */ + local_initiator = + !memcmp(elems->lnk_id->init_sta, sdata->vif.addr, ETH_ALEN); + if (local_initiator == sta->sta.tdls_initiator) { + tdls_dbg(sdata, "TDLS chan switch invalid lnk-id initiator\n"); + ret = -EINVAL; + goto out; + } + + /* peer should have known better */ + if (!sta->sta.deflink.ht_cap.ht_supported && elems->sec_chan_offs && + elems->sec_chan_offs->sec_chan_offs) { + tdls_dbg(sdata, "TDLS chan switch - wide chan unsupported\n"); + ret = -ENOTSUPP; + goto out; + } + + params.chandef = &chandef; + params.switch_time = le16_to_cpu(elems->ch_sw_timing->switch_time); + params.switch_timeout = le16_to_cpu(elems->ch_sw_timing->switch_timeout); + + params.tmpl_skb = + ieee80211_tdls_ch_sw_resp_tmpl_get(sta, + ¶ms.ch_sw_tm_ie); + if (!params.tmpl_skb) { + ret = -ENOENT; + goto out; + } + + drv_tdls_recv_channel_switch(sdata->local, sdata, ¶ms); + + tdls_dbg(sdata, + "TDLS ch switch request received from %pM ch %d width %d\n", + tf->sa, params.chandef->chan->center_freq, + params.chandef->width); +out: + mutex_unlock(&local->sta_mtx); + dev_kfree_skb_any(params.tmpl_skb); +free: + kfree(elems); + return ret; +} + +void +ieee80211_process_tdls_channel_switch(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_tdls_data *tf = (void *)skb->data; + struct wiphy *wiphy = sdata->local->hw.wiphy; + + lockdep_assert_wiphy(wiphy); + + /* make sure the driver supports it */ + if (!(wiphy->features & NL80211_FEATURE_TDLS_CHANNEL_SWITCH)) + return; + + /* we want to access the entire packet */ + if (skb_linearize(skb)) + return; + /* + * The packet/size was already validated by mac80211 Rx path, only look + * at the action type. 
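Once the band has been disambiguated above, the target channel number is turned into a center frequency. A minimal sketch of the usual 2.4/5 GHz mapping that ieee80211_channel_to_frequency() is being relied on for, covering only the two bands this path considers (rule-of-thumb formulas, not the full kernel helper):

/*
 * Illustrative channel number -> center frequency mapping for the
 * 2.4 GHz and 5 GHz bands; returns 0 for inputs this sketch ignores.
 */
#include <stdio.h>

enum toy_band { BAND_2GHZ, BAND_5GHZ };

static int chan_to_freq(int chan, enum toy_band band)
{
	if (band == BAND_2GHZ) {
		if (chan == 14)
			return 2484;
		if (chan >= 1 && chan <= 13)
			return 2407 + chan * 5;
		return 0;
	}
	/* 5 GHz: channels spaced 5 MHz apart, offset from 5000 MHz */
	if (chan >= 1 && chan <= 177)
		return 5000 + chan * 5;
	return 0;
}

int main(void)
{
	printf("ch 6 (2.4 GHz) -> %d MHz, ch 36 (5 GHz) -> %d MHz\n",
	       chan_to_freq(6, BAND_2GHZ), chan_to_freq(36, BAND_5GHZ));
	return 0;
}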
+ */ + switch (tf->action_code) { + case WLAN_TDLS_CHANNEL_SWITCH_REQUEST: + ieee80211_process_tdls_channel_switch_req(sdata, skb); + break; + case WLAN_TDLS_CHANNEL_SWITCH_RESPONSE: + ieee80211_process_tdls_channel_switch_resp(sdata, skb); + break; + default: + WARN_ON_ONCE(1); + return; + } +} + +void ieee80211_teardown_tdls_peers(struct ieee80211_sub_if_data *sdata) +{ + struct sta_info *sta; + u16 reason = WLAN_REASON_TDLS_TEARDOWN_UNSPECIFIED; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &sdata->local->sta_list, list) { + if (!sta->sta.tdls || sta->sdata != sdata || !sta->uploaded || + !test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + continue; + + ieee80211_tdls_oper_request(&sdata->vif, sta->sta.addr, + NL80211_TDLS_TEARDOWN, reason, + GFP_ATOMIC); + } + rcu_read_unlock(); +} + +void ieee80211_tdls_handle_disconnect(struct ieee80211_sub_if_data *sdata, + const u8 *peer, u16 reason) +{ + struct ieee80211_sta *sta; + + rcu_read_lock(); + sta = ieee80211_find_sta(&sdata->vif, peer); + if (!sta || !sta->tdls) { + rcu_read_unlock(); + return; + } + rcu_read_unlock(); + + tdls_dbg(sdata, "disconnected from TDLS peer %pM (Reason: %u=%s)\n", + peer, reason, + ieee80211_get_reason_code_string(reason)); + + ieee80211_tdls_oper_request(&sdata->vif, peer, + NL80211_TDLS_TEARDOWN, + WLAN_REASON_TDLS_TEARDOWN_UNREACHABLE, + GFP_ATOMIC); +} diff --git a/net/mac80211/tkip.c b/net/mac80211/tkip.c new file mode 100644 index 000000000..e7f57bb18 --- /dev/null +++ b/net/mac80211/tkip.c @@ -0,0 +1,323 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. + * Copyright (C) 2016 Intel Deutschland GmbH + */ +#include +#include +#include +#include +#include +#include + +#include +#include "driver-ops.h" +#include "key.h" +#include "tkip.h" +#include "wep.h" + +#define PHASE1_LOOP_COUNT 8 + +/* + * 2-byte by 2-byte subset of the full AES S-box table; second part of this + * table is identical to first part but byte-swapped + */ +static const u16 tkip_sbox[256] = +{ + 0xC6A5, 0xF884, 0xEE99, 0xF68D, 0xFF0D, 0xD6BD, 0xDEB1, 0x9154, + 0x6050, 0x0203, 0xCEA9, 0x567D, 0xE719, 0xB562, 0x4DE6, 0xEC9A, + 0x8F45, 0x1F9D, 0x8940, 0xFA87, 0xEF15, 0xB2EB, 0x8EC9, 0xFB0B, + 0x41EC, 0xB367, 0x5FFD, 0x45EA, 0x23BF, 0x53F7, 0xE496, 0x9B5B, + 0x75C2, 0xE11C, 0x3DAE, 0x4C6A, 0x6C5A, 0x7E41, 0xF502, 0x834F, + 0x685C, 0x51F4, 0xD134, 0xF908, 0xE293, 0xAB73, 0x6253, 0x2A3F, + 0x080C, 0x9552, 0x4665, 0x9D5E, 0x3028, 0x37A1, 0x0A0F, 0x2FB5, + 0x0E09, 0x2436, 0x1B9B, 0xDF3D, 0xCD26, 0x4E69, 0x7FCD, 0xEA9F, + 0x121B, 0x1D9E, 0x5874, 0x342E, 0x362D, 0xDCB2, 0xB4EE, 0x5BFB, + 0xA4F6, 0x764D, 0xB761, 0x7DCE, 0x527B, 0xDD3E, 0x5E71, 0x1397, + 0xA6F5, 0xB968, 0x0000, 0xC12C, 0x4060, 0xE31F, 0x79C8, 0xB6ED, + 0xD4BE, 0x8D46, 0x67D9, 0x724B, 0x94DE, 0x98D4, 0xB0E8, 0x854A, + 0xBB6B, 0xC52A, 0x4FE5, 0xED16, 0x86C5, 0x9AD7, 0x6655, 0x1194, + 0x8ACF, 0xE910, 0x0406, 0xFE81, 0xA0F0, 0x7844, 0x25BA, 0x4BE3, + 0xA2F3, 0x5DFE, 0x80C0, 0x058A, 0x3FAD, 0x21BC, 0x7048, 0xF104, + 0x63DF, 0x77C1, 0xAF75, 0x4263, 0x2030, 0xE51A, 0xFD0E, 0xBF6D, + 0x814C, 0x1814, 0x2635, 0xC32F, 0xBEE1, 0x35A2, 0x88CC, 0x2E39, + 0x9357, 0x55F2, 0xFC82, 0x7A47, 0xC8AC, 0xBAE7, 0x322B, 0xE695, + 0xC0A0, 0x1998, 0x9ED1, 0xA37F, 0x4466, 0x547E, 0x3BAB, 0x0B83, + 0x8CCA, 0xC729, 0x6BD3, 0x283C, 0xA779, 0xBCE2, 0x161D, 0xAD76, + 0xDB3B, 0x6456, 0x744E, 0x141E, 0x92DB, 0x0C0A, 0x486C, 0xB8E4, + 0x9F5D, 0xBD6E, 0x43EF, 0xC4A6, 0x39A8, 0x31A4, 0xD337, 0xF28B, + 0xD532, 0x8B43, 0x6E59, 
0xDAB7, 0x018C, 0xB164, 0x9CD2, 0x49E0, + 0xD8B4, 0xACFA, 0xF307, 0xCF25, 0xCAAF, 0xF48E, 0x47E9, 0x1018, + 0x6FD5, 0xF088, 0x4A6F, 0x5C72, 0x3824, 0x57F1, 0x73C7, 0x9751, + 0xCB23, 0xA17C, 0xE89C, 0x3E21, 0x96DD, 0x61DC, 0x0D86, 0x0F85, + 0xE090, 0x7C42, 0x71C4, 0xCCAA, 0x90D8, 0x0605, 0xF701, 0x1C12, + 0xC2A3, 0x6A5F, 0xAEF9, 0x69D0, 0x1791, 0x9958, 0x3A27, 0x27B9, + 0xD938, 0xEB13, 0x2BB3, 0x2233, 0xD2BB, 0xA970, 0x0789, 0x33A7, + 0x2DB6, 0x3C22, 0x1592, 0xC920, 0x8749, 0xAAFF, 0x5078, 0xA57A, + 0x038F, 0x59F8, 0x0980, 0x1A17, 0x65DA, 0xD731, 0x84C6, 0xD0B8, + 0x82C3, 0x29B0, 0x5A77, 0x1E11, 0x7BCB, 0xA8FC, 0x6DD6, 0x2C3A, +}; + +static u16 tkipS(u16 val) +{ + return tkip_sbox[val & 0xff] ^ swab16(tkip_sbox[val >> 8]); +} + +static u8 *write_tkip_iv(u8 *pos, u16 iv16) +{ + *pos++ = iv16 >> 8; + *pos++ = ((iv16 >> 8) | 0x20) & 0x7f; + *pos++ = iv16 & 0xFF; + return pos; +} + +/* + * P1K := Phase1(TA, TK, TSC) + * TA = transmitter address (48 bits) + * TK = dot11DefaultKeyValue or dot11KeyMappingValue (128 bits) + * TSC = TKIP sequence counter (48 bits, only 32 msb bits used) + * P1K: 80 bits + */ +static void tkip_mixing_phase1(const u8 *tk, struct tkip_ctx *ctx, + const u8 *ta, u32 tsc_IV32) +{ + int i, j; + u16 *p1k = ctx->p1k; + + p1k[0] = tsc_IV32 & 0xFFFF; + p1k[1] = tsc_IV32 >> 16; + p1k[2] = get_unaligned_le16(ta + 0); + p1k[3] = get_unaligned_le16(ta + 2); + p1k[4] = get_unaligned_le16(ta + 4); + + for (i = 0; i < PHASE1_LOOP_COUNT; i++) { + j = 2 * (i & 1); + p1k[0] += tkipS(p1k[4] ^ get_unaligned_le16(tk + 0 + j)); + p1k[1] += tkipS(p1k[0] ^ get_unaligned_le16(tk + 4 + j)); + p1k[2] += tkipS(p1k[1] ^ get_unaligned_le16(tk + 8 + j)); + p1k[3] += tkipS(p1k[2] ^ get_unaligned_le16(tk + 12 + j)); + p1k[4] += tkipS(p1k[3] ^ get_unaligned_le16(tk + 0 + j)) + i; + } + ctx->state = TKIP_STATE_PHASE1_DONE; + ctx->p1k_iv32 = tsc_IV32; +} + +static void tkip_mixing_phase2(const u8 *tk, struct tkip_ctx *ctx, + u16 tsc_IV16, u8 *rc4key) +{ + u16 ppk[6]; + const u16 *p1k = ctx->p1k; + int i; + + ppk[0] = p1k[0]; + ppk[1] = p1k[1]; + ppk[2] = p1k[2]; + ppk[3] = p1k[3]; + ppk[4] = p1k[4]; + ppk[5] = p1k[4] + tsc_IV16; + + ppk[0] += tkipS(ppk[5] ^ get_unaligned_le16(tk + 0)); + ppk[1] += tkipS(ppk[0] ^ get_unaligned_le16(tk + 2)); + ppk[2] += tkipS(ppk[1] ^ get_unaligned_le16(tk + 4)); + ppk[3] += tkipS(ppk[2] ^ get_unaligned_le16(tk + 6)); + ppk[4] += tkipS(ppk[3] ^ get_unaligned_le16(tk + 8)); + ppk[5] += tkipS(ppk[4] ^ get_unaligned_le16(tk + 10)); + ppk[0] += ror16(ppk[5] ^ get_unaligned_le16(tk + 12), 1); + ppk[1] += ror16(ppk[0] ^ get_unaligned_le16(tk + 14), 1); + ppk[2] += ror16(ppk[1], 1); + ppk[3] += ror16(ppk[2], 1); + ppk[4] += ror16(ppk[3], 1); + ppk[5] += ror16(ppk[4], 1); + + rc4key = write_tkip_iv(rc4key, tsc_IV16); + *rc4key++ = ((ppk[5] ^ get_unaligned_le16(tk)) >> 1) & 0xFF; + + for (i = 0; i < 6; i++) + put_unaligned_le16(ppk[i], rc4key + 2 * i); +} + +/* Add TKIP IV and Ext. IV at @pos. @iv0, @iv1, and @iv2 are the first octets + * of the IV. Returns pointer to the octet following IVs (i.e., beginning of + * the packet payload). 
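The tkip_sbox table above stores only one of the two 16-bit TKIP S-box halves defined by the standard; the second half is the byte-swapped first, which is exactly what tkipS() reconstructs on the fly. A small standalone check of that identity, with the first few table values copied from the listing above and helper names local to the sketch:

/*
 * Demonstrates how one 16-bit table serves as both TKIP S-box halves:
 * S(v) = T[low byte of v] ^ swap16(T[high byte of v]).
 */
#include <stdint.h>
#include <stdio.h>

static uint16_t swap16(uint16_t v)
{
	return (uint16_t)((v << 8) | (v >> 8));
}

/* first few entries of the patch's table, enough for this demo only */
static const uint16_t tkip_sbox_head[] = {
	0xC6A5, 0xF884, 0xEE99, 0xF68D,
};

static uint16_t tkip_s(uint16_t val)
{
	/* the real code indexes a full 256-entry table with each byte */
	return tkip_sbox_head[val & 0xff] ^ swap16(tkip_sbox_head[val >> 8]);
}

int main(void)
{
	/* 0x0100: low byte 0x00, high byte 0x01, both within the demo table */
	printf("S(0x0100) = 0x%04X\n", tkip_s(0x0100));
	return 0;
}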
*/ +u8 *ieee80211_tkip_add_iv(u8 *pos, struct ieee80211_key_conf *keyconf, u64 pn) +{ + pos = write_tkip_iv(pos, TKIP_PN_TO_IV16(pn)); + *pos++ = (keyconf->keyidx << 6) | (1 << 5) /* Ext IV */; + put_unaligned_le32(TKIP_PN_TO_IV32(pn), pos); + return pos + 4; +} +EXPORT_SYMBOL_GPL(ieee80211_tkip_add_iv); + +static void ieee80211_compute_tkip_p1k(struct ieee80211_key *key, u32 iv32) +{ + struct ieee80211_sub_if_data *sdata = key->sdata; + struct tkip_ctx *ctx = &key->u.tkip.tx; + const u8 *tk = &key->conf.key[NL80211_TKIP_DATA_OFFSET_ENCR_KEY]; + + lockdep_assert_held(&key->u.tkip.txlock); + + /* + * Update the P1K when the IV32 is different from the value it + * had when we last computed it (or when not initialised yet). + * This might flip-flop back and forth if packets are processed + * out-of-order due to the different ACs, but then we have to + * just compute the P1K more often. + */ + if (ctx->p1k_iv32 != iv32 || ctx->state == TKIP_STATE_NOT_INIT) + tkip_mixing_phase1(tk, ctx, sdata->vif.addr, iv32); +} + +void ieee80211_get_tkip_p1k_iv(struct ieee80211_key_conf *keyconf, + u32 iv32, u16 *p1k) +{ + struct ieee80211_key *key = (struct ieee80211_key *) + container_of(keyconf, struct ieee80211_key, conf); + struct tkip_ctx *ctx = &key->u.tkip.tx; + + spin_lock_bh(&key->u.tkip.txlock); + ieee80211_compute_tkip_p1k(key, iv32); + memcpy(p1k, ctx->p1k, sizeof(ctx->p1k)); + spin_unlock_bh(&key->u.tkip.txlock); +} +EXPORT_SYMBOL(ieee80211_get_tkip_p1k_iv); + +void ieee80211_get_tkip_rx_p1k(struct ieee80211_key_conf *keyconf, + const u8 *ta, u32 iv32, u16 *p1k) +{ + const u8 *tk = &keyconf->key[NL80211_TKIP_DATA_OFFSET_ENCR_KEY]; + struct tkip_ctx ctx; + + tkip_mixing_phase1(tk, &ctx, ta, iv32); + memcpy(p1k, ctx.p1k, sizeof(ctx.p1k)); +} +EXPORT_SYMBOL(ieee80211_get_tkip_rx_p1k); + +void ieee80211_get_tkip_p2k(struct ieee80211_key_conf *keyconf, + struct sk_buff *skb, u8 *p2k) +{ + struct ieee80211_key *key = (struct ieee80211_key *) + container_of(keyconf, struct ieee80211_key, conf); + const u8 *tk = &key->conf.key[NL80211_TKIP_DATA_OFFSET_ENCR_KEY]; + struct tkip_ctx *ctx = &key->u.tkip.tx; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + const u8 *data = (u8 *)hdr + ieee80211_hdrlen(hdr->frame_control); + u32 iv32 = get_unaligned_le32(&data[4]); + u16 iv16 = data[2] | (data[0] << 8); + + spin_lock(&key->u.tkip.txlock); + ieee80211_compute_tkip_p1k(key, iv32); + tkip_mixing_phase2(tk, ctx, iv16, p2k); + spin_unlock(&key->u.tkip.txlock); +} +EXPORT_SYMBOL(ieee80211_get_tkip_p2k); + +/* + * Encrypt packet payload with TKIP using @key. @pos is a pointer to the + * beginning of the buffer containing payload. This payload must include + * the IV/Ext.IV and space for (taildroom) four octets for ICV. + * @payload_len is the length of payload (_not_ including IV/ICV length). + * @ta is the transmitter addresses. + */ +int ieee80211_tkip_encrypt_data(struct arc4_ctx *ctx, + struct ieee80211_key *key, + struct sk_buff *skb, + u8 *payload, size_t payload_len) +{ + u8 rc4key[16]; + + ieee80211_get_tkip_p2k(&key->conf, skb, rc4key); + + return ieee80211_wep_encrypt_data(ctx, rc4key, 16, + payload, payload_len); +} + +/* Decrypt packet payload with TKIP using @key. @pos is a pointer to the + * beginning of the buffer containing IEEE 802.11 header payload, i.e., + * including IV, Ext. IV, real data, Michael MIC, ICV. @payload_len is the + * length of payload, including IV, Ext. IV, MIC, ICV. 
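For orientation, the 8-octet TKIP IV / Extended IV header produced by write_tkip_iv() plus ieee80211_tkip_add_iv() above carries TSC1, a dummy byte derived from TSC1 (the (x | 0x20) & 0x7f masking avoids known-weak RC4 key prefixes), TSC0, a key-ID octet with the Ext IV bit set, and then TSC2..TSC5 little endian; the decrypt path reads iv16/iv32 straight back out of those positions. A minimal sketch that mirrors the byte order of the helpers in this hunk, with made-up demo values:

/* Illustrative layout of the 8-octet TKIP IV / Extended IV header. */
#include <stdint.h>
#include <stdio.h>

static void build_tkip_iv(uint8_t iv[8], uint64_t pn, int keyidx)
{
	uint16_t iv16 = (uint16_t)(pn & 0xffff);	/* TSC0..TSC1 */
	uint32_t iv32 = (uint32_t)(pn >> 16);		/* TSC2..TSC5 */

	iv[0] = iv16 >> 8;
	iv[1] = ((iv16 >> 8) | 0x20) & 0x7f;	/* weak-key avoidance byte */
	iv[2] = iv16 & 0xff;
	iv[3] = (uint8_t)((keyidx << 6) | (1 << 5));	/* Ext IV present */
	iv[4] = iv32 & 0xff;
	iv[5] = (iv32 >> 8) & 0xff;
	iv[6] = (iv32 >> 16) & 0xff;
	iv[7] = (iv32 >> 24) & 0xff;
}

int main(void)
{
	uint8_t iv[8];
	int i;

	build_tkip_iv(iv, 0x0000000102ULL, 1);	/* PN = 0x102, key index 1 */
	for (i = 0; i < 8; i++)
		printf("%02x ", iv[i]);
	printf("\n");
	return 0;
}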
*/ +int ieee80211_tkip_decrypt_data(struct arc4_ctx *ctx, + struct ieee80211_key *key, + u8 *payload, size_t payload_len, u8 *ta, + u8 *ra, int only_iv, int queue, + u32 *out_iv32, u16 *out_iv16) +{ + u32 iv32; + u32 iv16; + u8 rc4key[16], keyid, *pos = payload; + int res; + const u8 *tk = &key->conf.key[NL80211_TKIP_DATA_OFFSET_ENCR_KEY]; + struct tkip_ctx_rx *rx_ctx = &key->u.tkip.rx[queue]; + + if (payload_len < 12) + return -1; + + iv16 = (pos[0] << 8) | pos[2]; + keyid = pos[3]; + iv32 = get_unaligned_le32(pos + 4); + pos += 8; + + if (!(keyid & (1 << 5))) + return TKIP_DECRYPT_NO_EXT_IV; + + if ((keyid >> 6) != key->conf.keyidx) + return TKIP_DECRYPT_INVALID_KEYIDX; + + /* Reject replays if the received TSC is smaller than or equal to the + * last received value in a valid message, but with an exception for + * the case where a new key has been set and no valid frame using that + * key has yet received and the local RSC was initialized to 0. This + * exception allows the very first frame sent by the transmitter to be + * accepted even if that transmitter were to use TSC 0 (IEEE 802.11 + * described TSC to be initialized to 1 whenever a new key is taken into + * use). + */ + if (iv32 < rx_ctx->iv32 || + (iv32 == rx_ctx->iv32 && + (iv16 < rx_ctx->iv16 || + (iv16 == rx_ctx->iv16 && + (rx_ctx->iv32 || rx_ctx->iv16 || + rx_ctx->ctx.state != TKIP_STATE_NOT_INIT))))) + return TKIP_DECRYPT_REPLAY; + + if (only_iv) { + res = TKIP_DECRYPT_OK; + rx_ctx->ctx.state = TKIP_STATE_PHASE1_HW_UPLOADED; + goto done; + } + + if (rx_ctx->ctx.state == TKIP_STATE_NOT_INIT || + rx_ctx->iv32 != iv32) { + /* IV16 wrapped around - perform TKIP phase 1 */ + tkip_mixing_phase1(tk, &rx_ctx->ctx, ta, iv32); + } + if (key->local->ops->update_tkip_key && + key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE && + rx_ctx->ctx.state != TKIP_STATE_PHASE1_HW_UPLOADED) { + struct ieee80211_sub_if_data *sdata = key->sdata; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(key->sdata->bss, + struct ieee80211_sub_if_data, u.ap); + drv_update_tkip_key(key->local, sdata, &key->conf, key->sta, + iv32, rx_ctx->ctx.p1k); + rx_ctx->ctx.state = TKIP_STATE_PHASE1_HW_UPLOADED; + } + + tkip_mixing_phase2(tk, &rx_ctx->ctx, iv16, rc4key); + + res = ieee80211_wep_decrypt_data(ctx, rc4key, 16, pos, payload_len - 12); + done: + if (res == TKIP_DECRYPT_OK) { + /* + * Record previously received IV, will be copied into the + * key information after MIC verification. It is possible + * that we don't catch replays of fragments but that's ok + * because the Michael MIC verication will then fail. + */ + *out_iv32 = iv32; + *out_iv16 = iv16; + } + + return res; +} diff --git a/net/mac80211/tkip.h b/net/mac80211/tkip.h new file mode 100644 index 000000000..9d2f8bd36 --- /dev/null +++ b/net/mac80211/tkip.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2004, Instant802 Networks, Inc. 
+ */ + +#ifndef TKIP_H +#define TKIP_H + +#include +#include +#include "key.h" + +int ieee80211_tkip_encrypt_data(struct arc4_ctx *ctx, + struct ieee80211_key *key, + struct sk_buff *skb, + u8 *payload, size_t payload_len); + +enum { + TKIP_DECRYPT_OK = 0, + TKIP_DECRYPT_NO_EXT_IV = -1, + TKIP_DECRYPT_INVALID_KEYIDX = -2, + TKIP_DECRYPT_REPLAY = -3, +}; +int ieee80211_tkip_decrypt_data(struct arc4_ctx *ctx, + struct ieee80211_key *key, + u8 *payload, size_t payload_len, u8 *ta, + u8 *ra, int only_iv, int queue, + u32 *out_iv32, u16 *out_iv16); + +#endif /* TKIP_H */ diff --git a/net/mac80211/trace.c b/net/mac80211/trace.c new file mode 100644 index 000000000..837857261 --- /dev/null +++ b/net/mac80211/trace.c @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: GPL-2.0 +/* bug in tracepoint.h, it should include this */ +#include + +/* sparse isn't too happy with all macros... */ +#ifndef __CHECKER__ +#include +#include "driver-ops.h" +#include "debug.h" +#define CREATE_TRACE_POINTS +#include "trace.h" +#include "trace_msg.h" + +#ifdef CONFIG_MAC80211_MESSAGE_TRACING +void __sdata_info(const char *fmt, ...) +{ + struct va_format vaf = { + .fmt = fmt, + }; + va_list args; + + va_start(args, fmt); + vaf.va = &args; + + pr_info("%pV", &vaf); + trace_mac80211_info(&vaf); + va_end(args); +} + +void __sdata_dbg(bool print, const char *fmt, ...) +{ + struct va_format vaf = { + .fmt = fmt, + }; + va_list args; + + va_start(args, fmt); + vaf.va = &args; + + if (print) + pr_debug("%pV", &vaf); + trace_mac80211_dbg(&vaf); + va_end(args); +} + +void __sdata_err(const char *fmt, ...) +{ + struct va_format vaf = { + .fmt = fmt, + }; + va_list args; + + va_start(args, fmt); + vaf.va = &args; + + pr_err("%pV", &vaf); + trace_mac80211_err(&vaf); + va_end(args); +} + +void __wiphy_dbg(struct wiphy *wiphy, bool print, const char *fmt, ...) +{ + struct va_format vaf = { + .fmt = fmt, + }; + va_list args; + + va_start(args, fmt); + vaf.va = &args; + + if (print) + wiphy_dbg(wiphy, "%pV", &vaf); + trace_mac80211_dbg(&vaf); + va_end(args); +} +#endif +#endif diff --git a/net/mac80211/trace.h b/net/mac80211/trace.h new file mode 100644 index 000000000..c85367a47 --- /dev/null +++ b/net/mac80211/trace.h @@ -0,0 +1,3035 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Portions of this file + * Copyright(c) 2016-2017 Intel Deutschland GmbH + * Copyright (C) 2018 - 2022 Intel Corporation + */ + +#if !defined(__MAC80211_DRIVER_TRACE) || defined(TRACE_HEADER_MULTI_READ) +#define __MAC80211_DRIVER_TRACE + +#include +#include +#include "ieee80211_i.h" + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM mac80211 + +#define MAXNAME 32 +#define LOCAL_ENTRY __array(char, wiphy_name, 32) +#define LOCAL_ASSIGN strlcpy(__entry->wiphy_name, wiphy_name(local->hw.wiphy), MAXNAME) +#define LOCAL_PR_FMT "%s" +#define LOCAL_PR_ARG __entry->wiphy_name + +#define STA_ENTRY __array(char, sta_addr, ETH_ALEN) +#define STA_ASSIGN (sta ? 
memcpy(__entry->sta_addr, sta->addr, ETH_ALEN) : \ + eth_zero_addr(__entry->sta_addr)) +#define STA_NAMED_ASSIGN(s) memcpy(__entry->sta_addr, (s)->addr, ETH_ALEN) +#define STA_PR_FMT " sta:%pM" +#define STA_PR_ARG __entry->sta_addr + +#define VIF_ENTRY __field(enum nl80211_iftype, vif_type) __field(void *, sdata) \ + __field(bool, p2p) \ + __string(vif_name, sdata->name) +#define VIF_ASSIGN __entry->vif_type = sdata->vif.type; __entry->sdata = sdata; \ + __entry->p2p = sdata->vif.p2p; \ + __assign_str(vif_name, sdata->name) +#define VIF_PR_FMT " vif:%s(%d%s)" +#define VIF_PR_ARG __get_str(vif_name), __entry->vif_type, __entry->p2p ? "/p2p" : "" + +#define CHANDEF_ENTRY __field(u32, control_freq) \ + __field(u32, freq_offset) \ + __field(u32, chan_width) \ + __field(u32, center_freq1) \ + __field(u32, freq1_offset) \ + __field(u32, center_freq2) +#define CHANDEF_ASSIGN(c) \ + __entry->control_freq = (c) ? ((c)->chan ? (c)->chan->center_freq : 0) : 0; \ + __entry->freq_offset = (c) ? ((c)->chan ? (c)->chan->freq_offset : 0) : 0; \ + __entry->chan_width = (c) ? (c)->width : 0; \ + __entry->center_freq1 = (c) ? (c)->center_freq1 : 0; \ + __entry->freq1_offset = (c) ? (c)->freq1_offset : 0; \ + __entry->center_freq2 = (c) ? (c)->center_freq2 : 0; +#define CHANDEF_PR_FMT " control:%d.%03d MHz width:%d center: %d.%03d/%d MHz" +#define CHANDEF_PR_ARG __entry->control_freq, __entry->freq_offset, __entry->chan_width, \ + __entry->center_freq1, __entry->freq1_offset, __entry->center_freq2 + +#define MIN_CHANDEF_ENTRY \ + __field(u32, min_control_freq) \ + __field(u32, min_freq_offset) \ + __field(u32, min_chan_width) \ + __field(u32, min_center_freq1) \ + __field(u32, min_freq1_offset) \ + __field(u32, min_center_freq2) + +#define MIN_CHANDEF_ASSIGN(c) \ + __entry->min_control_freq = (c)->chan ? (c)->chan->center_freq : 0; \ + __entry->min_freq_offset = (c)->chan ? 
(c)->chan->freq_offset : 0; \ + __entry->min_chan_width = (c)->width; \ + __entry->min_center_freq1 = (c)->center_freq1; \ + __entry->min_freq1_offset = (c)->freq1_offset; \ + __entry->min_center_freq2 = (c)->center_freq2; +#define MIN_CHANDEF_PR_FMT " min_control:%d.%03d MHz min_width:%d min_center: %d.%03d/%d MHz" +#define MIN_CHANDEF_PR_ARG __entry->min_control_freq, __entry->min_freq_offset, \ + __entry->min_chan_width, \ + __entry->min_center_freq1, __entry->min_freq1_offset, \ + __entry->min_center_freq2 + +#define CHANCTX_ENTRY CHANDEF_ENTRY \ + MIN_CHANDEF_ENTRY \ + __field(u8, rx_chains_static) \ + __field(u8, rx_chains_dynamic) +#define CHANCTX_ASSIGN CHANDEF_ASSIGN(&ctx->conf.def) \ + MIN_CHANDEF_ASSIGN(&ctx->conf.min_def) \ + __entry->rx_chains_static = ctx->conf.rx_chains_static; \ + __entry->rx_chains_dynamic = ctx->conf.rx_chains_dynamic +#define CHANCTX_PR_FMT CHANDEF_PR_FMT MIN_CHANDEF_PR_FMT " chains:%d/%d" +#define CHANCTX_PR_ARG CHANDEF_PR_ARG, MIN_CHANDEF_PR_ARG, \ + __entry->rx_chains_static, __entry->rx_chains_dynamic + +#define KEY_ENTRY __field(u32, cipher) \ + __field(u8, hw_key_idx) \ + __field(u8, flags) \ + __field(s8, keyidx) +#define KEY_ASSIGN(k) __entry->cipher = (k)->cipher; \ + __entry->flags = (k)->flags; \ + __entry->keyidx = (k)->keyidx; \ + __entry->hw_key_idx = (k)->hw_key_idx; +#define KEY_PR_FMT " cipher:0x%x, flags=%#x, keyidx=%d, hw_key_idx=%d" +#define KEY_PR_ARG __entry->cipher, __entry->flags, __entry->keyidx, __entry->hw_key_idx + +#define AMPDU_ACTION_ENTRY __field(enum ieee80211_ampdu_mlme_action, \ + ieee80211_ampdu_mlme_action) \ + STA_ENTRY \ + __field(u16, tid) \ + __field(u16, ssn) \ + __field(u16, buf_size) \ + __field(bool, amsdu) \ + __field(u16, timeout) \ + __field(u16, action) +#define AMPDU_ACTION_ASSIGN STA_NAMED_ASSIGN(params->sta); \ + __entry->tid = params->tid; \ + __entry->ssn = params->ssn; \ + __entry->buf_size = params->buf_size; \ + __entry->amsdu = params->amsdu; \ + __entry->timeout = params->timeout; \ + __entry->action = params->action; +#define AMPDU_ACTION_PR_FMT STA_PR_FMT " tid %d, ssn %d, buf_size %u, amsdu %d, timeout %d action %d" +#define AMPDU_ACTION_PR_ARG STA_PR_ARG, __entry->tid, __entry->ssn, \ + __entry->buf_size, __entry->amsdu, __entry->timeout, \ + __entry->action + +/* + * Tracing for driver callbacks. 
+ */ + +DECLARE_EVENT_CLASS(local_only_evt, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local), + TP_STRUCT__entry( + LOCAL_ENTRY + ), + TP_fast_assign( + LOCAL_ASSIGN; + ), + TP_printk(LOCAL_PR_FMT, LOCAL_PR_ARG) +); + +DECLARE_EVENT_CLASS(local_sdata_addr_evt, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __array(char, addr, ETH_ALEN) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + memcpy(__entry->addr, sdata->vif.addr, ETH_ALEN); + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " addr:%pM", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->addr + ) +); + +DECLARE_EVENT_CLASS(local_u32_evt, + TP_PROTO(struct ieee80211_local *local, u32 value), + TP_ARGS(local, value), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, value) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->value = value; + ), + + TP_printk( + LOCAL_PR_FMT " value:%d", + LOCAL_PR_ARG, __entry->value + ) +); + +DECLARE_EVENT_CLASS(local_sdata_evt, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG + ) +); + +DEFINE_EVENT(local_only_evt, drv_return_void, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(drv_return_int, + TP_PROTO(struct ieee80211_local *local, int ret), + TP_ARGS(local, ret), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, ret) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + ), + TP_printk(LOCAL_PR_FMT " - %d", LOCAL_PR_ARG, __entry->ret) +); + +TRACE_EVENT(drv_return_bool, + TP_PROTO(struct ieee80211_local *local, bool ret), + TP_ARGS(local, ret), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, ret) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + ), + TP_printk(LOCAL_PR_FMT " - %s", LOCAL_PR_ARG, (__entry->ret) ? 
+ "true" : "false") +); + +TRACE_EVENT(drv_return_u32, + TP_PROTO(struct ieee80211_local *local, u32 ret), + TP_ARGS(local, ret), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, ret) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + ), + TP_printk(LOCAL_PR_FMT " - %u", LOCAL_PR_ARG, __entry->ret) +); + +TRACE_EVENT(drv_return_u64, + TP_PROTO(struct ieee80211_local *local, u64 ret), + TP_ARGS(local, ret), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u64, ret) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + ), + TP_printk(LOCAL_PR_FMT " - %llu", LOCAL_PR_ARG, __entry->ret) +); + +DEFINE_EVENT(local_only_evt, drv_start, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_u32_evt, drv_get_et_strings, + TP_PROTO(struct ieee80211_local *local, u32 sset), + TP_ARGS(local, sset) +); + +DEFINE_EVENT(local_u32_evt, drv_get_et_sset_count, + TP_PROTO(struct ieee80211_local *local, u32 sset), + TP_ARGS(local, sset) +); + +DEFINE_EVENT(local_only_evt, drv_get_et_stats, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_only_evt, drv_suspend, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_only_evt, drv_resume, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(drv_set_wakeup, + TP_PROTO(struct ieee80211_local *local, bool enabled), + TP_ARGS(local, enabled), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, enabled) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->enabled = enabled; + ), + TP_printk(LOCAL_PR_FMT " enabled:%d", LOCAL_PR_ARG, __entry->enabled) +); + +DEFINE_EVENT(local_only_evt, drv_stop, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_sdata_addr_evt, drv_add_interface, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_change_interface, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum nl80211_iftype type, bool p2p), + + TP_ARGS(local, sdata, type, p2p), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, new_type) + __field(bool, new_p2p) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->new_type = type; + __entry->new_p2p = p2p; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " new type:%d%s", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->new_type, + __entry->new_p2p ? 
"/p2p" : "" + ) +); + +DEFINE_EVENT(local_sdata_addr_evt, drv_remove_interface, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_config, + TP_PROTO(struct ieee80211_local *local, + u32 changed), + + TP_ARGS(local, changed), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, changed) + __field(u32, flags) + __field(int, power_level) + __field(int, dynamic_ps_timeout) + __field(u16, listen_interval) + __field(u8, long_frame_max_tx_count) + __field(u8, short_frame_max_tx_count) + CHANDEF_ENTRY + __field(int, smps) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->changed = changed; + __entry->flags = local->hw.conf.flags; + __entry->power_level = local->hw.conf.power_level; + __entry->dynamic_ps_timeout = local->hw.conf.dynamic_ps_timeout; + __entry->listen_interval = local->hw.conf.listen_interval; + __entry->long_frame_max_tx_count = + local->hw.conf.long_frame_max_tx_count; + __entry->short_frame_max_tx_count = + local->hw.conf.short_frame_max_tx_count; + CHANDEF_ASSIGN(&local->hw.conf.chandef) + __entry->smps = local->hw.conf.smps_mode; + ), + + TP_printk( + LOCAL_PR_FMT " ch:%#x" CHANDEF_PR_FMT, + LOCAL_PR_ARG, __entry->changed, CHANDEF_PR_ARG + ) +); + +TRACE_EVENT(drv_vif_cfg_changed, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u64 changed), + + TP_ARGS(local, sdata, changed), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u64, changed) + __field(bool, assoc) + __field(bool, ibss_joined) + __field(bool, ibss_creator) + __field(u16, aid) + __dynamic_array(u32, arp_addr_list, + sdata->vif.cfg.arp_addr_cnt > IEEE80211_BSS_ARP_ADDR_LIST_LEN ? + IEEE80211_BSS_ARP_ADDR_LIST_LEN : + sdata->vif.cfg.arp_addr_cnt) + __field(int, arp_addr_cnt) + __dynamic_array(u8, ssid, sdata->vif.cfg.ssid_len) + __field(int, s1g) + __field(bool, idle) + __field(bool, ps) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->changed = changed; + __entry->aid = sdata->vif.cfg.aid; + __entry->assoc = sdata->vif.cfg.assoc; + __entry->ibss_joined = sdata->vif.cfg.ibss_joined; + __entry->ibss_creator = sdata->vif.cfg.ibss_creator; + __entry->ps = sdata->vif.cfg.ps; + + __entry->arp_addr_cnt = sdata->vif.cfg.arp_addr_cnt; + memcpy(__get_dynamic_array(arp_addr_list), + sdata->vif.cfg.arp_addr_list, + sizeof(u32) * (sdata->vif.cfg.arp_addr_cnt > IEEE80211_BSS_ARP_ADDR_LIST_LEN ? 
+ IEEE80211_BSS_ARP_ADDR_LIST_LEN : + sdata->vif.cfg.arp_addr_cnt)); + memcpy(__get_dynamic_array(ssid), + sdata->vif.cfg.ssid, + sdata->vif.cfg.ssid_len); + __entry->s1g = sdata->vif.cfg.s1g; + __entry->idle = sdata->vif.cfg.idle; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " changed:%#llx", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->changed + ) +); + +TRACE_EVENT(drv_link_info_changed, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + u64 changed), + + TP_ARGS(local, sdata, link_conf, changed), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u64, changed) + __field(int, link_id) + __field(bool, cts) + __field(bool, shortpre) + __field(bool, shortslot) + __field(bool, enable_beacon) + __field(u8, dtimper) + __field(u16, bcnint) + __field(u16, assoc_cap) + __field(u64, sync_tsf) + __field(u32, sync_device_ts) + __field(u8, sync_dtim_count) + __field(u32, basic_rates) + __array(int, mcast_rate, NUM_NL80211_BANDS) + __field(u16, ht_operation_mode) + __field(s32, cqm_rssi_thold) + __field(s32, cqm_rssi_hyst) + __field(u32, channel_width) + __field(u32, channel_cfreq1) + __field(u32, channel_cfreq1_offset) + __field(bool, qos) + __field(bool, hidden_ssid) + __field(int, txpower) + __field(u8, p2p_oppps_ctwindow) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->changed = changed; + __entry->link_id = link_conf->link_id; + __entry->shortpre = link_conf->use_short_preamble; + __entry->cts = link_conf->use_cts_prot; + __entry->shortslot = link_conf->use_short_slot; + __entry->enable_beacon = link_conf->enable_beacon; + __entry->dtimper = link_conf->dtim_period; + __entry->bcnint = link_conf->beacon_int; + __entry->assoc_cap = link_conf->assoc_capability; + __entry->sync_tsf = link_conf->sync_tsf; + __entry->sync_device_ts = link_conf->sync_device_ts; + __entry->sync_dtim_count = link_conf->sync_dtim_count; + __entry->basic_rates = link_conf->basic_rates; + memcpy(__entry->mcast_rate, link_conf->mcast_rate, + sizeof(__entry->mcast_rate)); + __entry->ht_operation_mode = link_conf->ht_operation_mode; + __entry->cqm_rssi_thold = link_conf->cqm_rssi_thold; + __entry->cqm_rssi_hyst = link_conf->cqm_rssi_hyst; + __entry->channel_width = link_conf->chandef.width; + __entry->channel_cfreq1 = link_conf->chandef.center_freq1; + __entry->channel_cfreq1_offset = link_conf->chandef.freq1_offset; + __entry->qos = link_conf->qos; + __entry->hidden_ssid = link_conf->hidden_ssid; + __entry->txpower = link_conf->txpower; + __entry->p2p_oppps_ctwindow = link_conf->p2p_noa_attr.oppps_ctwindow; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " link_id:%d, changed:%#llx", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->link_id, + __entry->changed + ) +); + +TRACE_EVENT(drv_prepare_multicast, + TP_PROTO(struct ieee80211_local *local, int mc_count), + + TP_ARGS(local, mc_count), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, mc_count) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->mc_count = mc_count; + ), + + TP_printk( + LOCAL_PR_FMT " prepare mc (%d)", + LOCAL_PR_ARG, __entry->mc_count + ) +); + +TRACE_EVENT(drv_configure_filter, + TP_PROTO(struct ieee80211_local *local, + unsigned int changed_flags, + unsigned int *total_flags, + u64 multicast), + + TP_ARGS(local, changed_flags, total_flags, multicast), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(unsigned int, changed) + __field(unsigned int, total) + __field(u64, multicast) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->changed = changed_flags; + __entry->total = 
*total_flags; + __entry->multicast = multicast; + ), + + TP_printk( + LOCAL_PR_FMT " changed:%#x total:%#x", + LOCAL_PR_ARG, __entry->changed, __entry->total + ) +); + +TRACE_EVENT(drv_config_iface_filter, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + unsigned int filter_flags, + unsigned int changed_flags), + + TP_ARGS(local, sdata, filter_flags, changed_flags), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(unsigned int, filter_flags) + __field(unsigned int, changed_flags) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->filter_flags = filter_flags; + __entry->changed_flags = changed_flags; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + " filter_flags: %#x changed_flags: %#x", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->filter_flags, + __entry->changed_flags + ) +); + +TRACE_EVENT(drv_set_tim, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, bool set), + + TP_ARGS(local, sta, set), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(bool, set) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->set = set; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT " set:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->set + ) +); + +TRACE_EVENT(drv_set_key, + TP_PROTO(struct ieee80211_local *local, + enum set_key_cmd cmd, struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + struct ieee80211_key_conf *key), + + TP_ARGS(local, cmd, sdata, sta, key), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + KEY_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + KEY_ASSIGN(key); + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT KEY_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, KEY_PR_ARG + ) +); + +TRACE_EVENT(drv_update_tkip_key, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_key_conf *conf, + struct ieee80211_sta *sta, u32 iv32), + + TP_ARGS(local, sdata, conf, sta, iv32), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u32, iv32) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->iv32 = iv32; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " iv32:%#x", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, __entry->iv32 + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_hw_scan, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_sdata_evt, drv_cancel_hw_scan, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_sdata_evt, drv_sched_scan_start, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_sdata_evt, drv_sched_scan_stop, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_sw_scan_start, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const u8 *mac_addr), + + TP_ARGS(local, sdata, mac_addr), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __array(char, mac_addr, ETH_ALEN) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + memcpy(__entry->mac_addr, mac_addr, ETH_ALEN); + ), + + TP_printk(LOCAL_PR_FMT ", " VIF_PR_FMT ", addr:%pM", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->mac_addr) +); + +DEFINE_EVENT(local_sdata_evt, drv_sw_scan_complete, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data 
*sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_get_stats, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_low_level_stats *stats, + int ret), + + TP_ARGS(local, stats, ret), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, ret) + __field(unsigned int, ackfail) + __field(unsigned int, rtsfail) + __field(unsigned int, fcserr) + __field(unsigned int, rtssucc) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + __entry->ackfail = stats->dot11ACKFailureCount; + __entry->rtsfail = stats->dot11RTSFailureCount; + __entry->fcserr = stats->dot11FCSErrorCount; + __entry->rtssucc = stats->dot11RTSSuccessCount; + ), + + TP_printk( + LOCAL_PR_FMT " ret:%d", + LOCAL_PR_ARG, __entry->ret + ) +); + +TRACE_EVENT(drv_get_key_seq, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_key_conf *key), + + TP_ARGS(local, key), + + TP_STRUCT__entry( + LOCAL_ENTRY + KEY_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + KEY_ASSIGN(key); + ), + + TP_printk( + LOCAL_PR_FMT KEY_PR_FMT, + LOCAL_PR_ARG, KEY_PR_ARG + ) +); + +DEFINE_EVENT(local_u32_evt, drv_set_frag_threshold, + TP_PROTO(struct ieee80211_local *local, u32 value), + TP_ARGS(local, value) +); + +DEFINE_EVENT(local_u32_evt, drv_set_rts_threshold, + TP_PROTO(struct ieee80211_local *local, u32 value), + TP_ARGS(local, value) +); + +TRACE_EVENT(drv_set_coverage_class, + TP_PROTO(struct ieee80211_local *local, s16 value), + + TP_ARGS(local, value), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(s16, value) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->value = value; + ), + + TP_printk( + LOCAL_PR_FMT " value:%d", + LOCAL_PR_ARG, __entry->value + ) +); + +TRACE_EVENT(drv_sta_notify, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum sta_notify_cmd cmd, + struct ieee80211_sta *sta), + + TP_ARGS(local, sdata, cmd, sta), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u32, cmd) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->cmd = cmd; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " cmd:%d", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, __entry->cmd + ) +); + +TRACE_EVENT(drv_sta_state, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + enum ieee80211_sta_state old_state, + enum ieee80211_sta_state new_state), + + TP_ARGS(local, sdata, sta, old_state, new_state), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u32, old_state) + __field(u32, new_state) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->old_state = old_state; + __entry->new_state = new_state; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " state: %d->%d", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, + __entry->old_state, __entry->new_state + ) +); + +TRACE_EVENT(drv_sta_set_txpwr, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + + TP_ARGS(local, sdata, sta), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(s16, txpwr) + __field(u8, type) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->txpwr = sta->deflink.txpwr.power; + __entry->type = sta->deflink.txpwr.type; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " txpwr: %d type %d", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, + __entry->txpwr, __entry->type + ) +); + +TRACE_EVENT(drv_sta_rc_update, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, 
+ struct ieee80211_sta *sta, + u32 changed), + + TP_ARGS(local, sdata, sta, changed), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u32, changed) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->changed = changed; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " changed: 0x%x", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, __entry->changed + ) +); + +DECLARE_EVENT_CLASS(sta_event, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + + TP_ARGS(local, sdata, sta), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG + ) +); + +DEFINE_EVENT(sta_event, drv_sta_statistics, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +DEFINE_EVENT(sta_event, drv_sta_add, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +DEFINE_EVENT(sta_event, drv_sta_remove, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +DEFINE_EVENT(sta_event, drv_sta_pre_rcu_remove, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +DEFINE_EVENT(sta_event, drv_sync_rx_queues, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +DEFINE_EVENT(sta_event, drv_sta_rate_tbl_update, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +TRACE_EVENT(drv_conf_tx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + unsigned int link_id, + u16 ac, const struct ieee80211_tx_queue_params *params), + + TP_ARGS(local, sdata, link_id, ac, params), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(unsigned int, link_id) + __field(u16, ac) + __field(u16, txop) + __field(u16, cw_min) + __field(u16, cw_max) + __field(u8, aifs) + __field(bool, uapsd) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->link_id = link_id; + __entry->ac = ac; + __entry->txop = params->txop; + __entry->cw_max = params->cw_max; + __entry->cw_min = params->cw_min; + __entry->aifs = params->aifs; + __entry->uapsd = params->uapsd; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " link_id: %d, AC:%d", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->link_id, __entry->ac + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_get_tsf, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_set_tsf, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u64 tsf), + + TP_ARGS(local, sdata, tsf), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u64, tsf) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->tsf = tsf; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " tsf:%llu", + LOCAL_PR_ARG, VIF_PR_ARG, (unsigned long long)__entry->tsf + ) +); + +TRACE_EVENT(drv_offset_tsf, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + s64 offset), + + TP_ARGS(local, 
sdata, offset), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(s64, tsf_offset) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->tsf_offset = offset; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " tsf offset:%lld", + LOCAL_PR_ARG, VIF_PR_ARG, + (unsigned long long)__entry->tsf_offset + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_reset_tsf, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_only_evt, drv_tx_last_beacon, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(drv_ampdu_action, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_ampdu_params *params), + + TP_ARGS(local, sdata, params), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + AMPDU_ACTION_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + AMPDU_ACTION_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT AMPDU_ACTION_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, AMPDU_ACTION_PR_ARG + ) +); + +TRACE_EVENT(drv_get_survey, + TP_PROTO(struct ieee80211_local *local, int _idx, + struct survey_info *survey), + + TP_ARGS(local, _idx, survey), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, idx) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->idx = _idx; + ), + + TP_printk( + LOCAL_PR_FMT " idx:%d", + LOCAL_PR_ARG, __entry->idx + ) +); + +TRACE_EVENT(drv_flush, + TP_PROTO(struct ieee80211_local *local, + u32 queues, bool drop), + + TP_ARGS(local, queues, drop), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, drop) + __field(u32, queues) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->drop = drop; + __entry->queues = queues; + ), + + TP_printk( + LOCAL_PR_FMT " queues:0x%x drop:%d", + LOCAL_PR_ARG, __entry->queues, __entry->drop + ) +); + +TRACE_EVENT(drv_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch), + + TP_ARGS(local, sdata, ch_switch), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANDEF_ENTRY + __field(u64, timestamp) + __field(u32, device_timestamp) + __field(bool, block_tx) + __field(u8, count) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANDEF_ASSIGN(&ch_switch->chandef) + __entry->timestamp = ch_switch->timestamp; + __entry->device_timestamp = ch_switch->device_timestamp; + __entry->block_tx = ch_switch->block_tx; + __entry->count = ch_switch->count; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " new " CHANDEF_PR_FMT " count:%d", + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG, __entry->count + ) +); + +TRACE_EVENT(drv_set_antenna, + TP_PROTO(struct ieee80211_local *local, u32 tx_ant, u32 rx_ant, int ret), + + TP_ARGS(local, tx_ant, rx_ant, ret), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, tx_ant) + __field(u32, rx_ant) + __field(int, ret) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->tx_ant = tx_ant; + __entry->rx_ant = rx_ant; + __entry->ret = ret; + ), + + TP_printk( + LOCAL_PR_FMT " tx_ant:%d rx_ant:%d ret:%d", + LOCAL_PR_ARG, __entry->tx_ant, __entry->rx_ant, __entry->ret + ) +); + +TRACE_EVENT(drv_get_antenna, + TP_PROTO(struct ieee80211_local *local, u32 tx_ant, u32 rx_ant, int ret), + + TP_ARGS(local, tx_ant, rx_ant, ret), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, tx_ant) + __field(u32, rx_ant) + __field(int, ret) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->tx_ant = tx_ant; + __entry->rx_ant = rx_ant; + __entry->ret = ret; + ), + + TP_printk( + LOCAL_PR_FMT " tx_ant:%d rx_ant:%d 
ret:%d", + LOCAL_PR_ARG, __entry->tx_ant, __entry->rx_ant, __entry->ret + ) +); + +TRACE_EVENT(drv_remain_on_channel, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel *chan, + unsigned int duration, + enum ieee80211_roc_type type), + + TP_ARGS(local, sdata, chan, duration, type), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(int, center_freq) + __field(int, freq_offset) + __field(unsigned int, duration) + __field(u32, type) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->center_freq = chan->center_freq; + __entry->freq_offset = chan->freq_offset; + __entry->duration = duration; + __entry->type = type; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " freq:%d.%03dMHz duration:%dms type=%d", + LOCAL_PR_ARG, VIF_PR_ARG, + __entry->center_freq, __entry->freq_offset, + __entry->duration, __entry->type + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_cancel_remain_on_channel, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_set_ringparam, + TP_PROTO(struct ieee80211_local *local, u32 tx, u32 rx), + + TP_ARGS(local, tx, rx), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, tx) + __field(u32, rx) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->tx = tx; + __entry->rx = rx; + ), + + TP_printk( + LOCAL_PR_FMT " tx:%d rx %d", + LOCAL_PR_ARG, __entry->tx, __entry->rx + ) +); + +TRACE_EVENT(drv_get_ringparam, + TP_PROTO(struct ieee80211_local *local, u32 *tx, u32 *tx_max, + u32 *rx, u32 *rx_max), + + TP_ARGS(local, tx, tx_max, rx, rx_max), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u32, tx) + __field(u32, tx_max) + __field(u32, rx) + __field(u32, rx_max) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->tx = *tx; + __entry->tx_max = *tx_max; + __entry->rx = *rx; + __entry->rx_max = *rx_max; + ), + + TP_printk( + LOCAL_PR_FMT " tx:%d tx_max %d rx %d rx_max %d", + LOCAL_PR_ARG, + __entry->tx, __entry->tx_max, __entry->rx, __entry->rx_max + ) +); + +DEFINE_EVENT(local_only_evt, drv_tx_frames_pending, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_only_evt, drv_offchannel_tx_cancel_wait, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(drv_set_bitrate_mask, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct cfg80211_bitrate_mask *mask), + + TP_ARGS(local, sdata, mask), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, legacy_2g) + __field(u32, legacy_5g) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->legacy_2g = mask->control[NL80211_BAND_2GHZ].legacy; + __entry->legacy_5g = mask->control[NL80211_BAND_5GHZ].legacy; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " 2G Mask:0x%x 5G Mask:0x%x", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->legacy_2g, __entry->legacy_5g + ) +); + +TRACE_EVENT(drv_set_rekey_data, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_gtk_rekey_data *data), + + TP_ARGS(local, sdata, data), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __array(u8, kek, NL80211_KEK_LEN) + __array(u8, kck, NL80211_KCK_LEN) + __array(u8, replay_ctr, NL80211_REPLAY_CTR_LEN) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + memcpy(__entry->kek, data->kek, NL80211_KEK_LEN); + memcpy(__entry->kck, data->kck, NL80211_KCK_LEN); + memcpy(__entry->replay_ctr, data->replay_ctr, + NL80211_REPLAY_CTR_LEN); + ), + + TP_printk(LOCAL_PR_FMT VIF_PR_FMT, + 
LOCAL_PR_ARG, VIF_PR_ARG) +); + +TRACE_EVENT(drv_event_callback, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct ieee80211_event *_event), + + TP_ARGS(local, sdata, _event), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, type) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->type = _event->type; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " event:%d", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->type + ) +); + +DECLARE_EVENT_CLASS(release_evt, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + u16 tids, int num_frames, + enum ieee80211_frame_release_type reason, + bool more_data), + + TP_ARGS(local, sta, tids, num_frames, reason, more_data), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(u16, tids) + __field(int, num_frames) + __field(int, reason) + __field(bool, more_data) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->tids = tids; + __entry->num_frames = num_frames; + __entry->reason = reason; + __entry->more_data = more_data; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT + " TIDs:0x%.4x frames:%d reason:%d more:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->tids, __entry->num_frames, + __entry->reason, __entry->more_data + ) +); + +DEFINE_EVENT(release_evt, drv_release_buffered_frames, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + u16 tids, int num_frames, + enum ieee80211_frame_release_type reason, + bool more_data), + + TP_ARGS(local, sta, tids, num_frames, reason, more_data) +); + +DEFINE_EVENT(release_evt, drv_allow_buffered_frames, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + u16 tids, int num_frames, + enum ieee80211_frame_release_type reason, + bool more_data), + + TP_ARGS(local, sta, tids, num_frames, reason, more_data) +); + +DECLARE_EVENT_CLASS(mgd_prepare_complete_tx_evt, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 duration, u16 subtype, bool success), + + TP_ARGS(local, sdata, duration, subtype, success), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, duration) + __field(u16, subtype) + __field(u8, success) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->duration = duration; + __entry->subtype = subtype; + __entry->success = success; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " duration: %u, subtype:0x%x, success:%d", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->duration, + __entry->subtype, __entry->success + ) +); + +DEFINE_EVENT(mgd_prepare_complete_tx_evt, drv_mgd_prepare_tx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 duration, u16 subtype, bool success), + + TP_ARGS(local, sdata, duration, subtype, success) +); + +DEFINE_EVENT(mgd_prepare_complete_tx_evt, drv_mgd_complete_tx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 duration, u16 subtype, bool success), + + TP_ARGS(local, sdata, duration, subtype, success) +); + +DEFINE_EVENT(local_sdata_evt, drv_mgd_protect_tdls_discover, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + + TP_ARGS(local, sdata) +); + +DECLARE_EVENT_CLASS(local_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx), + + TP_ARGS(local, ctx), + + TP_STRUCT__entry( + LOCAL_ENTRY + CHANCTX_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + CHANCTX_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT CHANCTX_PR_FMT, + LOCAL_PR_ARG, CHANCTX_PR_ARG + ) +); + 
+DEFINE_EVENT(local_chanctx, drv_add_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx), + TP_ARGS(local, ctx) +); + +DEFINE_EVENT(local_chanctx, drv_remove_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx), + TP_ARGS(local, ctx) +); + +TRACE_EVENT(drv_change_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx, + u32 changed), + + TP_ARGS(local, ctx, changed), + + TP_STRUCT__entry( + LOCAL_ENTRY + CHANCTX_ENTRY + __field(u32, changed) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + CHANCTX_ASSIGN; + __entry->changed = changed; + ), + + TP_printk( + LOCAL_PR_FMT CHANCTX_PR_FMT " changed:%#x", + LOCAL_PR_ARG, CHANCTX_PR_ARG, __entry->changed + ) +); + +#if !defined(__TRACE_VIF_ENTRY) +#define __TRACE_VIF_ENTRY +struct trace_vif_entry { + enum nl80211_iftype vif_type; + bool p2p; + char vif_name[IFNAMSIZ]; +} __packed; + +struct trace_chandef_entry { + u32 control_freq; + u32 freq_offset; + u32 chan_width; + u32 center_freq1; + u32 freq1_offset; + u32 center_freq2; +} __packed; + +struct trace_switch_entry { + struct trace_vif_entry vif; + unsigned int link_id; + struct trace_chandef_entry old_chandef; + struct trace_chandef_entry new_chandef; +} __packed; + +#define SWITCH_ENTRY_ASSIGN(to, from) local_vifs[i].to = vifs[i].from +#endif + +TRACE_EVENT(drv_switch_vif_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_vif_chanctx_switch *vifs, + int n_vifs, enum ieee80211_chanctx_switch_mode mode), + TP_ARGS(local, vifs, n_vifs, mode), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, n_vifs) + __field(u32, mode) + __dynamic_array(u8, vifs, + sizeof(struct trace_switch_entry) * n_vifs) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->n_vifs = n_vifs; + __entry->mode = mode; + { + struct trace_switch_entry *local_vifs = + __get_dynamic_array(vifs); + int i; + + for (i = 0; i < n_vifs; i++) { + struct ieee80211_sub_if_data *sdata; + + sdata = container_of(vifs[i].vif, + struct ieee80211_sub_if_data, + vif); + + SWITCH_ENTRY_ASSIGN(vif.vif_type, vif->type); + SWITCH_ENTRY_ASSIGN(vif.p2p, vif->p2p); + SWITCH_ENTRY_ASSIGN(link_id, link_conf->link_id); + strncpy(local_vifs[i].vif.vif_name, + sdata->name, + sizeof(local_vifs[i].vif.vif_name)); + SWITCH_ENTRY_ASSIGN(old_chandef.control_freq, + old_ctx->def.chan->center_freq); + SWITCH_ENTRY_ASSIGN(old_chandef.freq_offset, + old_ctx->def.chan->freq_offset); + SWITCH_ENTRY_ASSIGN(old_chandef.chan_width, + old_ctx->def.width); + SWITCH_ENTRY_ASSIGN(old_chandef.center_freq1, + old_ctx->def.center_freq1); + SWITCH_ENTRY_ASSIGN(old_chandef.freq1_offset, + old_ctx->def.freq1_offset); + SWITCH_ENTRY_ASSIGN(old_chandef.center_freq2, + old_ctx->def.center_freq2); + SWITCH_ENTRY_ASSIGN(new_chandef.control_freq, + new_ctx->def.chan->center_freq); + SWITCH_ENTRY_ASSIGN(new_chandef.freq_offset, + new_ctx->def.chan->freq_offset); + SWITCH_ENTRY_ASSIGN(new_chandef.chan_width, + new_ctx->def.width); + SWITCH_ENTRY_ASSIGN(new_chandef.center_freq1, + new_ctx->def.center_freq1); + SWITCH_ENTRY_ASSIGN(new_chandef.freq1_offset, + new_ctx->def.freq1_offset); + SWITCH_ENTRY_ASSIGN(new_chandef.center_freq2, + new_ctx->def.center_freq2); + } + } + ), + + TP_printk( + LOCAL_PR_FMT " n_vifs:%d mode:%d", + LOCAL_PR_ARG, __entry->n_vifs, __entry->mode + ) +); + +DECLARE_EVENT_CLASS(local_sdata_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx), + + 
TP_ARGS(local, sdata, link_conf, ctx), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANCTX_ENTRY + __field(unsigned int, link_id) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANCTX_ASSIGN; + __entry->link_id = link_conf->link_id; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " link_id:%d" CHANCTX_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, __entry->link_id, CHANCTX_PR_ARG + ) +); + +DEFINE_EVENT(local_sdata_chanctx, drv_assign_vif_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx), + TP_ARGS(local, sdata, link_conf, ctx) +); + +DEFINE_EVENT(local_sdata_chanctx, drv_unassign_vif_chanctx, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf, + struct ieee80211_chanctx *ctx), + TP_ARGS(local, sdata, link_conf, ctx) +); + +TRACE_EVENT(drv_start_ap, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf), + + TP_ARGS(local, sdata, link_conf), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, link_id) + __field(u8, dtimper) + __field(u16, bcnint) + __dynamic_array(u8, ssid, sdata->vif.cfg.ssid_len) + __field(bool, hidden_ssid) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->link_id = link_conf->link_id; + __entry->dtimper = link_conf->dtim_period; + __entry->bcnint = link_conf->beacon_int; + __entry->hidden_ssid = link_conf->hidden_ssid; + memcpy(__get_dynamic_array(ssid), + sdata->vif.cfg.ssid, + sdata->vif.cfg.ssid_len); + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " link id %u", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->link_id + ) +); + +TRACE_EVENT(drv_stop_ap, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *link_conf), + + TP_ARGS(local, sdata, link_conf), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u32, link_id) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->link_id = link_conf->link_id; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " link id %u", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->link_id + ) +); + +TRACE_EVENT(drv_reconfig_complete, + TP_PROTO(struct ieee80211_local *local, + enum ieee80211_reconfig_type reconfig_type), + TP_ARGS(local, reconfig_type), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u8, reconfig_type) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->reconfig_type = reconfig_type; + ), + + TP_printk( + LOCAL_PR_FMT " reconfig_type:%d", + LOCAL_PR_ARG, __entry->reconfig_type + ) + +); + +#if IS_ENABLED(CONFIG_IPV6) +DEFINE_EVENT(local_sdata_evt, drv_ipv6_addr_change, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); +#endif + +TRACE_EVENT(drv_join_ibss, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_bss_conf *info), + + TP_ARGS(local, sdata, info), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, dtimper) + __field(u16, bcnint) + __dynamic_array(u8, ssid, sdata->vif.cfg.ssid_len) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->dtimper = info->dtim_period; + __entry->bcnint = info->beacon_int; + memcpy(__get_dynamic_array(ssid), + sdata->vif.cfg.ssid, + sdata->vif.cfg.ssid_len); + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_leave_ibss, + TP_PROTO(struct ieee80211_local *local, + struct 
ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_get_expected_throughput, + TP_PROTO(struct ieee80211_sta *sta), + + TP_ARGS(sta), + + TP_STRUCT__entry( + STA_ENTRY + ), + + TP_fast_assign( + STA_ASSIGN; + ), + + TP_printk( + STA_PR_FMT, STA_PR_ARG + ) +); + +TRACE_EVENT(drv_start_nan, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_nan_conf *conf), + + TP_ARGS(local, sdata, conf), + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, master_pref) + __field(u8, bands) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->master_pref = conf->master_pref; + __entry->bands = conf->bands; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + ", master preference: %u, bands: 0x%0x", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->master_pref, + __entry->bands + ) +); + +TRACE_EVENT(drv_stop_nan, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + + TP_ARGS(local, sdata), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG + ) +); + +TRACE_EVENT(drv_nan_change_conf, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_nan_conf *conf, + u32 changes), + + TP_ARGS(local, sdata, conf, changes), + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, master_pref) + __field(u8, bands) + __field(u32, changes) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->master_pref = conf->master_pref; + __entry->bands = conf->bands; + __entry->changes = changes; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + ", master preference: %u, bands: 0x%0x, changes: 0x%x", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->master_pref, + __entry->bands, __entry->changes + ) +); + +TRACE_EVENT(drv_add_nan_func, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + const struct cfg80211_nan_func *func), + + TP_ARGS(local, sdata, func), + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, type) + __field(u8, inst_id) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->type = func->type; + __entry->inst_id = func->instance_id; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + ", type: %u, inst_id: %u", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->type, __entry->inst_id + ) +); + +TRACE_EVENT(drv_del_nan_func, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u8 instance_id), + + TP_ARGS(local, sdata, instance_id), + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, instance_id) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->instance_id = instance_id; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + ", instance_id: %u", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->instance_id + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_start_pmsr, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_sdata_evt, drv_abort_pmsr, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_set_default_unicast_key, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + int key_idx), + + TP_ARGS(local, sdata, key_idx), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(int, key_idx) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->key_idx = key_idx; + ), + + 
TP_printk(LOCAL_PR_FMT VIF_PR_FMT " key_idx:%d", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->key_idx) +); + +TRACE_EVENT(drv_channel_switch_beacon, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_chan_def *chandef), + + TP_ARGS(local, sdata, chandef), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANDEF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANDEF_ASSIGN(chandef); + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " channel switch to " CHANDEF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG + ) +); + +TRACE_EVENT(drv_pre_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch), + + TP_ARGS(local, sdata, ch_switch), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANDEF_ENTRY + __field(u64, timestamp) + __field(u32, device_timestamp) + __field(bool, block_tx) + __field(u8, count) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANDEF_ASSIGN(&ch_switch->chandef) + __entry->timestamp = ch_switch->timestamp; + __entry->device_timestamp = ch_switch->device_timestamp; + __entry->block_tx = ch_switch->block_tx; + __entry->count = ch_switch->count; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " prepare channel switch to " + CHANDEF_PR_FMT " count:%d block_tx:%d timestamp:%llu", + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG, __entry->count, + __entry->block_tx, __entry->timestamp + ) +); + +DEFINE_EVENT(local_sdata_evt, drv_post_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DEFINE_EVENT(local_sdata_evt, drv_abort_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(drv_channel_switch_rx_beacon, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_channel_switch *ch_switch), + + TP_ARGS(local, sdata, ch_switch), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + CHANDEF_ENTRY + __field(u64, timestamp) + __field(u32, device_timestamp) + __field(bool, block_tx) + __field(u8, count) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + CHANDEF_ASSIGN(&ch_switch->chandef) + __entry->timestamp = ch_switch->timestamp; + __entry->device_timestamp = ch_switch->device_timestamp; + __entry->block_tx = ch_switch->block_tx; + __entry->count = ch_switch->count; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + " received a channel switch beacon to " + CHANDEF_PR_FMT " count:%d block_tx:%d timestamp:%llu", + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG, __entry->count, + __entry->block_tx, __entry->timestamp + ) +); + +TRACE_EVENT(drv_get_txpower, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + int dbm, int ret), + + TP_ARGS(local, sdata, dbm, ret), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(int, dbm) + __field(int, ret) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->dbm = dbm; + __entry->ret = ret; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " dbm:%d ret:%d", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->dbm, __entry->ret + ) +); + +TRACE_EVENT(drv_tdls_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, u8 oper_class, + struct cfg80211_chan_def *chandef), + + TP_ARGS(local, sdata, sta, oper_class, chandef), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u8, oper_class) + 
CHANDEF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->oper_class = oper_class; + CHANDEF_ASSIGN(chandef) + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " tdls channel switch to" + CHANDEF_PR_FMT " oper_class:%d " STA_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, CHANDEF_PR_ARG, __entry->oper_class, + STA_PR_ARG + ) +); + +TRACE_EVENT(drv_tdls_cancel_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + + TP_ARGS(local, sdata, sta), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT + " tdls cancel channel switch with " STA_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG + ) +); + +TRACE_EVENT(drv_tdls_recv_channel_switch, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_tdls_ch_sw_params *params), + + TP_ARGS(local, sdata, params), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u8, action_code) + STA_ENTRY + CHANDEF_ENTRY + __field(u32, status) + __field(bool, peer_initiator) + __field(u32, timestamp) + __field(u16, switch_time) + __field(u16, switch_timeout) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_NAMED_ASSIGN(params->sta); + CHANDEF_ASSIGN(params->chandef) + __entry->peer_initiator = params->sta->tdls_initiator; + __entry->action_code = params->action_code; + __entry->status = params->status; + __entry->timestamp = params->timestamp; + __entry->switch_time = params->switch_time; + __entry->switch_timeout = params->switch_timeout; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " received tdls channel switch packet" + " action:%d status:%d time:%d switch time:%d switch" + " timeout:%d initiator: %d chan:" CHANDEF_PR_FMT STA_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG, __entry->action_code, __entry->status, + __entry->timestamp, __entry->switch_time, + __entry->switch_timeout, __entry->peer_initiator, + CHANDEF_PR_ARG, STA_PR_ARG + ) +); + +TRACE_EVENT(drv_wake_tx_queue, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct txq_info *txq), + + TP_ARGS(local, sdata, txq), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u8, ac) + __field(u8, tid) + ), + + TP_fast_assign( + struct ieee80211_sta *sta = txq->txq.sta; + + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->ac = txq->txq.ac; + __entry->tid = txq->txq.tid; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " ac:%d tid:%d", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, __entry->ac, __entry->tid + ) +); + +TRACE_EVENT(drv_get_ftm_responder_stats, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct cfg80211_ftm_responder_stats *ftm_stats), + + TP_ARGS(local, sdata, ftm_stats), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT, + LOCAL_PR_ARG, VIF_PR_ARG + ) +); + +DEFINE_EVENT(local_sdata_addr_evt, drv_update_vif_offload, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +DECLARE_EVENT_CLASS(sta_flag_evt, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, bool enabled), + + TP_ARGS(local, sdata, sta, enabled), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(bool, enabled) + ), + + TP_fast_assign( + 
LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->enabled = enabled; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " enabled:%d", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, __entry->enabled + ) +); + +DEFINE_EVENT(sta_flag_evt, drv_sta_set_4addr, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, bool enabled), + + TP_ARGS(local, sdata, sta, enabled) +); + +DEFINE_EVENT(sta_flag_evt, drv_sta_set_decap_offload, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, bool enabled), + + TP_ARGS(local, sdata, sta, enabled) +); + +TRACE_EVENT(drv_add_twt_setup, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + struct ieee80211_twt_setup *twt, + struct ieee80211_twt_params *twt_agrt), + + TP_ARGS(local, sta, twt, twt_agrt), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(u8, dialog_token) + __field(u8, control) + __field(__le16, req_type) + __field(__le64, twt) + __field(u8, duration) + __field(__le16, mantissa) + __field(u8, channel) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->dialog_token = twt->dialog_token; + __entry->control = twt->control; + __entry->req_type = twt_agrt->req_type; + __entry->twt = twt_agrt->twt; + __entry->duration = twt_agrt->min_twt_dur; + __entry->mantissa = twt_agrt->mantissa; + __entry->channel = twt_agrt->channel; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT + " token:%d control:0x%02x req_type:0x%04x" + " twt:%llu duration:%d mantissa:%d channel:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->dialog_token, + __entry->control, le16_to_cpu(__entry->req_type), + le64_to_cpu(__entry->twt), __entry->duration, + le16_to_cpu(__entry->mantissa), __entry->channel + ) +); + +TRACE_EVENT(drv_twt_teardown_request, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, u8 flowid), + + TP_ARGS(local, sta, flowid), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(u8, flowid) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->flowid = flowid; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT " flowid:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->flowid + ) +); + +DEFINE_EVENT(sta_event, drv_net_fill_forward_path, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta), + TP_ARGS(local, sdata, sta) +); + +TRACE_EVENT(drv_change_vif_links, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + u16 old_links, u16 new_links), + + TP_ARGS(local, sdata, old_links, new_links), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + __field(u16, old_links) + __field(u16, new_links) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + __entry->old_links = old_links; + __entry->new_links = new_links; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT " old_links:0x%04x, new_links:0x%04x\n", + LOCAL_PR_ARG, VIF_PR_ARG, __entry->old_links, __entry->new_links + ) +); + +TRACE_EVENT(drv_change_sta_links, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta *sta, + u16 old_links, u16 new_links), + + TP_ARGS(local, sdata, sta, old_links, new_links), + + TP_STRUCT__entry( + LOCAL_ENTRY + VIF_ENTRY + STA_ENTRY + __field(u16, old_links) + __field(u16, new_links) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + VIF_ASSIGN; + STA_ASSIGN; + __entry->old_links = old_links; + __entry->new_links = new_links; + ), + + TP_printk( + LOCAL_PR_FMT VIF_PR_FMT STA_PR_FMT " old_links:0x%04x, 
new_links:0x%04x\n", + LOCAL_PR_ARG, VIF_PR_ARG, STA_PR_ARG, + __entry->old_links, __entry->new_links + ) +); + +/* + * Tracing for API calls that drivers call. + */ + +TRACE_EVENT(api_start_tx_ba_session, + TP_PROTO(struct ieee80211_sta *sta, u16 tid), + + TP_ARGS(sta, tid), + + TP_STRUCT__entry( + STA_ENTRY + __field(u16, tid) + ), + + TP_fast_assign( + STA_ASSIGN; + __entry->tid = tid; + ), + + TP_printk( + STA_PR_FMT " tid:%d", + STA_PR_ARG, __entry->tid + ) +); + +TRACE_EVENT(api_start_tx_ba_cb, + TP_PROTO(struct ieee80211_sub_if_data *sdata, const u8 *ra, u16 tid), + + TP_ARGS(sdata, ra, tid), + + TP_STRUCT__entry( + VIF_ENTRY + __array(u8, ra, ETH_ALEN) + __field(u16, tid) + ), + + TP_fast_assign( + VIF_ASSIGN; + memcpy(__entry->ra, ra, ETH_ALEN); + __entry->tid = tid; + ), + + TP_printk( + VIF_PR_FMT " ra:%pM tid:%d", + VIF_PR_ARG, __entry->ra, __entry->tid + ) +); + +TRACE_EVENT(api_stop_tx_ba_session, + TP_PROTO(struct ieee80211_sta *sta, u16 tid), + + TP_ARGS(sta, tid), + + TP_STRUCT__entry( + STA_ENTRY + __field(u16, tid) + ), + + TP_fast_assign( + STA_ASSIGN; + __entry->tid = tid; + ), + + TP_printk( + STA_PR_FMT " tid:%d", + STA_PR_ARG, __entry->tid + ) +); + +TRACE_EVENT(api_stop_tx_ba_cb, + TP_PROTO(struct ieee80211_sub_if_data *sdata, const u8 *ra, u16 tid), + + TP_ARGS(sdata, ra, tid), + + TP_STRUCT__entry( + VIF_ENTRY + __array(u8, ra, ETH_ALEN) + __field(u16, tid) + ), + + TP_fast_assign( + VIF_ASSIGN; + memcpy(__entry->ra, ra, ETH_ALEN); + __entry->tid = tid; + ), + + TP_printk( + VIF_PR_FMT " ra:%pM tid:%d", + VIF_PR_ARG, __entry->ra, __entry->tid + ) +); + +DEFINE_EVENT(local_only_evt, api_restart_hw, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(api_beacon_loss, + TP_PROTO(struct ieee80211_sub_if_data *sdata), + + TP_ARGS(sdata), + + TP_STRUCT__entry( + VIF_ENTRY + ), + + TP_fast_assign( + VIF_ASSIGN; + ), + + TP_printk( + VIF_PR_FMT, + VIF_PR_ARG + ) +); + +TRACE_EVENT(api_connection_loss, + TP_PROTO(struct ieee80211_sub_if_data *sdata), + + TP_ARGS(sdata), + + TP_STRUCT__entry( + VIF_ENTRY + ), + + TP_fast_assign( + VIF_ASSIGN; + ), + + TP_printk( + VIF_PR_FMT, + VIF_PR_ARG + ) +); + +TRACE_EVENT(api_disconnect, + TP_PROTO(struct ieee80211_sub_if_data *sdata, bool reconnect), + + TP_ARGS(sdata, reconnect), + + TP_STRUCT__entry( + VIF_ENTRY + __field(int, reconnect) + ), + + TP_fast_assign( + VIF_ASSIGN; + __entry->reconnect = reconnect; + ), + + TP_printk( + VIF_PR_FMT " reconnect:%d", + VIF_PR_ARG, __entry->reconnect + ) +); + +TRACE_EVENT(api_cqm_rssi_notify, + TP_PROTO(struct ieee80211_sub_if_data *sdata, + enum nl80211_cqm_rssi_threshold_event rssi_event, + s32 rssi_level), + + TP_ARGS(sdata, rssi_event, rssi_level), + + TP_STRUCT__entry( + VIF_ENTRY + __field(u32, rssi_event) + __field(s32, rssi_level) + ), + + TP_fast_assign( + VIF_ASSIGN; + __entry->rssi_event = rssi_event; + __entry->rssi_level = rssi_level; + ), + + TP_printk( + VIF_PR_FMT " event:%d rssi:%d", + VIF_PR_ARG, __entry->rssi_event, __entry->rssi_level + ) +); + +DEFINE_EVENT(local_sdata_evt, api_cqm_beacon_loss_notify, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata), + TP_ARGS(local, sdata) +); + +TRACE_EVENT(api_scan_completed, + TP_PROTO(struct ieee80211_local *local, bool aborted), + + TP_ARGS(local, aborted), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, aborted) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->aborted = aborted; + ), + + TP_printk( + LOCAL_PR_FMT " aborted:%d", + LOCAL_PR_ARG, __entry->aborted + ) 
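Note: the old_links/new_links pair traced by drv_change_vif_links and drv_change_sta_links above are u16 link-ID bitmaps; the quantity a driver usually cares about is which links appeared or disappeared. A self-contained model of that derivation (illustrative only, not mac80211 code):

#include <stdint.h>

struct link_delta {
	uint16_t added;		/* set in new_links but not in old_links */
	uint16_t removed;	/* set in old_links but not in new_links */
};

static struct link_delta link_changes(uint16_t old_links, uint16_t new_links)
{
	uint16_t changed = old_links ^ new_links;

	return (struct link_delta){
		.added   = changed & new_links,
		.removed = changed & old_links,
	};
}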
+); + +TRACE_EVENT(api_sched_scan_results, + TP_PROTO(struct ieee80211_local *local), + + TP_ARGS(local), + + TP_STRUCT__entry( + LOCAL_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT, LOCAL_PR_ARG + ) +); + +TRACE_EVENT(api_sched_scan_stopped, + TP_PROTO(struct ieee80211_local *local), + + TP_ARGS(local), + + TP_STRUCT__entry( + LOCAL_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT, LOCAL_PR_ARG + ) +); + +TRACE_EVENT(api_sta_block_awake, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, bool block), + + TP_ARGS(local, sta, block), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(bool, block) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->block = block; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT " block:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->block + ) +); + +TRACE_EVENT(api_chswitch_done, + TP_PROTO(struct ieee80211_sub_if_data *sdata, bool success), + + TP_ARGS(sdata, success), + + TP_STRUCT__entry( + VIF_ENTRY + __field(bool, success) + ), + + TP_fast_assign( + VIF_ASSIGN; + __entry->success = success; + ), + + TP_printk( + VIF_PR_FMT " success=%d", + VIF_PR_ARG, __entry->success + ) +); + +DEFINE_EVENT(local_only_evt, api_ready_on_channel, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_only_evt, api_remain_on_channel_expired, + TP_PROTO(struct ieee80211_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(api_gtk_rekey_notify, + TP_PROTO(struct ieee80211_sub_if_data *sdata, + const u8 *bssid, const u8 *replay_ctr), + + TP_ARGS(sdata, bssid, replay_ctr), + + TP_STRUCT__entry( + VIF_ENTRY + __array(u8, bssid, ETH_ALEN) + __array(u8, replay_ctr, NL80211_REPLAY_CTR_LEN) + ), + + TP_fast_assign( + VIF_ASSIGN; + memcpy(__entry->bssid, bssid, ETH_ALEN); + memcpy(__entry->replay_ctr, replay_ctr, NL80211_REPLAY_CTR_LEN); + ), + + TP_printk(VIF_PR_FMT, VIF_PR_ARG) +); + +TRACE_EVENT(api_enable_rssi_reports, + TP_PROTO(struct ieee80211_sub_if_data *sdata, + int rssi_min_thold, int rssi_max_thold), + + TP_ARGS(sdata, rssi_min_thold, rssi_max_thold), + + TP_STRUCT__entry( + VIF_ENTRY + __field(int, rssi_min_thold) + __field(int, rssi_max_thold) + ), + + TP_fast_assign( + VIF_ASSIGN; + __entry->rssi_min_thold = rssi_min_thold; + __entry->rssi_max_thold = rssi_max_thold; + ), + + TP_printk( + VIF_PR_FMT " rssi_min_thold =%d, rssi_max_thold = %d", + VIF_PR_ARG, __entry->rssi_min_thold, __entry->rssi_max_thold + ) +); + +TRACE_EVENT(api_eosp, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta), + + TP_ARGS(local, sta), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT, + LOCAL_PR_ARG, STA_PR_ARG + ) +); + +TRACE_EVENT(api_send_eosp_nullfunc, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + u8 tid), + + TP_ARGS(local, sta, tid), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(u8, tid) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->tid = tid; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT " tid:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->tid + ) +); + +TRACE_EVENT(api_sta_set_buffered, + TP_PROTO(struct ieee80211_local *local, + struct ieee80211_sta *sta, + u8 tid, bool buffered), + + TP_ARGS(local, sta, tid, buffered), + + TP_STRUCT__entry( + LOCAL_ENTRY + STA_ENTRY + __field(u8, tid) + __field(bool, buffered) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + STA_ASSIGN; + __entry->tid = 
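Note: api_gtk_rekey_notify above copies the raw NL80211_REPLAY_CTR_LEN (8-byte) replay counter into the trace record without interpreting it. When post-processing logs it is often rendered as a number; a small model of that, assuming the usual big-endian layout of the EAPOL-Key Replay Counter field (an assumption of this sketch, not something the tracepoint itself defines):

#include <stdint.h>

#define REPLAY_CTR_LEN 8 /* NL80211_REPLAY_CTR_LEN */

static uint64_t replay_ctr_to_u64(const uint8_t ctr[REPLAY_CTR_LEN])
{
	uint64_t v = 0;

	for (int i = 0; i < REPLAY_CTR_LEN; i++)
		v = (v << 8) | ctr[i]; /* big-endian byte order assumed */

	return v;
}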
tid; + __entry->buffered = buffered; + ), + + TP_printk( + LOCAL_PR_FMT STA_PR_FMT " tid:%d buffered:%d", + LOCAL_PR_ARG, STA_PR_ARG, __entry->tid, __entry->buffered + ) +); + +TRACE_EVENT(api_radar_detected, + TP_PROTO(struct ieee80211_local *local), + + TP_ARGS(local), + + TP_STRUCT__entry( + LOCAL_ENTRY + ), + + TP_fast_assign( + LOCAL_ASSIGN; + ), + + TP_printk( + LOCAL_PR_FMT " radar detected", + LOCAL_PR_ARG + ) +); + +/* + * Tracing for internal functions + * (which may also be called in response to driver calls) + */ + +TRACE_EVENT(wake_queue, + TP_PROTO(struct ieee80211_local *local, u16 queue, + enum queue_stop_reason reason), + + TP_ARGS(local, queue, reason), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u16, queue) + __field(u32, reason) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->queue = queue; + __entry->reason = reason; + ), + + TP_printk( + LOCAL_PR_FMT " queue:%d, reason:%d", + LOCAL_PR_ARG, __entry->queue, __entry->reason + ) +); + +TRACE_EVENT(stop_queue, + TP_PROTO(struct ieee80211_local *local, u16 queue, + enum queue_stop_reason reason), + + TP_ARGS(local, queue, reason), + + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u16, queue) + __field(u32, reason) + ), + + TP_fast_assign( + LOCAL_ASSIGN; + __entry->queue = queue; + __entry->reason = reason; + ), + + TP_printk( + LOCAL_PR_FMT " queue:%d, reason:%d", + LOCAL_PR_ARG, __entry->queue, __entry->reason + ) +); + +#endif /* !__MAC80211_DRIVER_TRACE || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/mac80211/trace_msg.h b/net/mac80211/trace_msg.h new file mode 100644 index 000000000..c9dbe9aab --- /dev/null +++ b/net/mac80211/trace_msg.h @@ -0,0 +1,57 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Portions of this file + * Copyright (C) 2019 Intel Corporation + */ + +#ifdef CONFIG_MAC80211_MESSAGE_TRACING + +#if !defined(__MAC80211_MSG_DRIVER_TRACE) || defined(TRACE_HEADER_MULTI_READ) +#define __MAC80211_MSG_DRIVER_TRACE + +#include +#include +#include "ieee80211_i.h" + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM mac80211_msg + +#define MAX_MSG_LEN 120 + +DECLARE_EVENT_CLASS(mac80211_msg_event, + TP_PROTO(struct va_format *vaf), + + TP_ARGS(vaf), + + TP_STRUCT__entry( + __vstring(msg, vaf->fmt, vaf->va) + ), + + TP_fast_assign( + __assign_vstr(msg, vaf->fmt, vaf->va); + ), + + TP_printk("%s", __get_str(msg)) +); + +DEFINE_EVENT(mac80211_msg_event, mac80211_info, + TP_PROTO(struct va_format *vaf), + TP_ARGS(vaf) +); +DEFINE_EVENT(mac80211_msg_event, mac80211_dbg, + TP_PROTO(struct va_format *vaf), + TP_ARGS(vaf) +); +DEFINE_EVENT(mac80211_msg_event, mac80211_err, + TP_PROTO(struct va_format *vaf), + TP_ARGS(vaf) +); +#endif /* !__MAC80211_MSG_DRIVER_TRACE || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace_msg +#include + +#endif diff --git a/net/mac80211/tx.c b/net/mac80211/tx.c new file mode 100644 index 000000000..2db103a56 --- /dev/null +++ b/net/mac80211/tx.c @@ -0,0 +1,6021 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2018-2022 Intel Corporation + * + * Transmit and frame generation functions. 
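Note: the mac80211_msg_event class in trace_msg.h above records an already-formatted string via __vstring/__assign_vstr. Callers typically package a printf-style call into a struct va_format and hand that to the generated tracepoint; a sketch of that common kernel pattern (the helper name here is made up, this is not the caller this patch adds elsewhere):

static __printf(1, 2) void example_msg_trace(const char *fmt, ...)
{
	struct va_format vaf = { .fmt = fmt };
	va_list args;

	va_start(args, fmt);
	vaf.va = &args;
	trace_mac80211_dbg(&vaf);	/* generated by DEFINE_EVENT above */
	va_end(args);
}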
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "led.h" +#include "mesh.h" +#include "wep.h" +#include "wpa.h" +#include "wme.h" +#include "rate.h" + +/* misc utils */ + +static __le16 ieee80211_duration(struct ieee80211_tx_data *tx, + struct sk_buff *skb, int group_addr, + int next_frag_len) +{ + int rate, mrate, erp, dur, i, shift = 0; + struct ieee80211_rate *txrate; + struct ieee80211_local *local = tx->local; + struct ieee80211_supported_band *sband; + struct ieee80211_hdr *hdr; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_chanctx_conf *chanctx_conf; + u32 rate_flags = 0; + + /* assume HW handles this */ + if (tx->rate.flags & (IEEE80211_TX_RC_MCS | IEEE80211_TX_RC_VHT_MCS)) + return 0; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(tx->sdata->vif.bss_conf.chanctx_conf); + if (chanctx_conf) { + shift = ieee80211_chandef_get_shift(&chanctx_conf->def); + rate_flags = ieee80211_chandef_rate_flags(&chanctx_conf->def); + } + rcu_read_unlock(); + + /* uh huh? */ + if (WARN_ON_ONCE(tx->rate.idx < 0)) + return 0; + + sband = local->hw.wiphy->bands[info->band]; + txrate = &sband->bitrates[tx->rate.idx]; + + erp = txrate->flags & IEEE80211_RATE_ERP_G; + + /* device is expected to do this */ + if (sband->band == NL80211_BAND_S1GHZ) + return 0; + + /* + * data and mgmt (except PS Poll): + * - during CFP: 32768 + * - during contention period: + * if addr1 is group address: 0 + * if more fragments = 0 and addr1 is individual address: time to + * transmit one ACK plus SIFS + * if more fragments = 1 and addr1 is individual address: time to + * transmit next fragment plus 2 x ACK plus 3 x SIFS + * + * IEEE 802.11, 9.6: + * - control response frame (CTS or ACK) shall be transmitted using the + * same rate as the immediately previous frame in the frame exchange + * sequence, if this rate belongs to the PHY mandatory rates, or else + * at the highest possible rate belonging to the PHY rates in the + * BSSBasicRateSet + */ + hdr = (struct ieee80211_hdr *)skb->data; + if (ieee80211_is_ctl(hdr->frame_control)) { + /* TODO: These control frames are not currently sent by + * mac80211, but should they be implemented, this function + * needs to be updated to support duration field calculation. + * + * RTS: time needed to transmit pending data/mgmt frame plus + * one CTS frame plus one ACK frame plus 3 x SIFS + * CTS: duration of immediately previous RTS minus time + * required to transmit CTS and its SIFS + * ACK: 0 if immediately previous directed data/mgmt had + * more=0, with more=1 duration in ACK frame is duration + * from previous frame minus time needed to transmit ACK + * and its SIFS + * PS Poll: BIT(15) | BIT(14) | aid + */ + return 0; + } + + /* data/mgmt */ + if (0 /* FIX: data/mgmt during CFP */) + return cpu_to_le16(32768); + + if (group_addr) /* Group address as the destination - no ACK */ + return 0; + + /* Individual destination address: + * IEEE 802.11, Ch. 9.6 (after IEEE 802.11g changes) + * CTS and ACK frames shall be transmitted using the highest rate in + * basic rate set that is less than or equal to the rate of the + * immediately previous frame and that is using the same modulation + * (CCK or OFDM). If no basic rate set matches with these requirements, + * the highest mandatory rate of the PHY that is less than or equal to + * the rate of the previous frame is used. 
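Note: ieee80211_duration above scales bitrates by 1 << shift, where the shift comes from ieee80211_chandef_get_shift() (0 for normal channels, 1 for 10 MHz, 2 for 5 MHz halved/quartered rates) and bitrates are kept in 100 kbit/s units. A self-contained model of that rounding (illustrative only):

/* DIV_ROUND_UP(bitrate, 1 << shift), as used when picking the response rate */
static int scaled_bitrate(int bitrate_100kbps, int shift)
{
	int div = 1 << shift;

	return (bitrate_100kbps + div - 1) / div;
}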
+ * Mandatory rates for IEEE 802.11g PHY: 1, 2, 5.5, 11, 6, 12, 24 Mbps + */ + rate = -1; + /* use lowest available if everything fails */ + mrate = sband->bitrates[0].bitrate; + for (i = 0; i < sband->n_bitrates; i++) { + struct ieee80211_rate *r = &sband->bitrates[i]; + + if (r->bitrate > txrate->bitrate) + break; + + if ((rate_flags & r->flags) != rate_flags) + continue; + + if (tx->sdata->vif.bss_conf.basic_rates & BIT(i)) + rate = DIV_ROUND_UP(r->bitrate, 1 << shift); + + switch (sband->band) { + case NL80211_BAND_2GHZ: + case NL80211_BAND_LC: { + u32 flag; + if (tx->sdata->deflink.operating_11g_mode) + flag = IEEE80211_RATE_MANDATORY_G; + else + flag = IEEE80211_RATE_MANDATORY_B; + if (r->flags & flag) + mrate = r->bitrate; + break; + } + case NL80211_BAND_5GHZ: + case NL80211_BAND_6GHZ: + if (r->flags & IEEE80211_RATE_MANDATORY_A) + mrate = r->bitrate; + break; + case NL80211_BAND_S1GHZ: + case NL80211_BAND_60GHZ: + /* TODO, for now fall through */ + case NUM_NL80211_BANDS: + WARN_ON(1); + break; + } + } + if (rate == -1) { + /* No matching basic rate found; use highest suitable mandatory + * PHY rate */ + rate = DIV_ROUND_UP(mrate, 1 << shift); + } + + /* Don't calculate ACKs for QoS Frames with NoAck Policy set */ + if (ieee80211_is_data_qos(hdr->frame_control) && + *(ieee80211_get_qos_ctl(hdr)) & IEEE80211_QOS_CTL_ACK_POLICY_NOACK) + dur = 0; + else + /* Time needed to transmit ACK + * (10 bytes + 4-byte FCS = 112 bits) plus SIFS; rounded up + * to closest integer */ + dur = ieee80211_frame_duration(sband->band, 10, rate, erp, + tx->sdata->vif.bss_conf.use_short_preamble, + shift); + + if (next_frag_len) { + /* Frame is fragmented: duration increases with time needed to + * transmit next fragment plus ACK and 2 x SIFS. */ + dur *= 2; /* ACK + SIFS */ + /* next fragment */ + dur += ieee80211_frame_duration(sband->band, next_frag_len, + txrate->bitrate, erp, + tx->sdata->vif.bss_conf.use_short_preamble, + shift); + } + + return cpu_to_le16(dur); +} + +/* tx handlers */ +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_dynamic_ps(struct ieee80211_tx_data *tx) +{ + struct ieee80211_local *local = tx->local; + struct ieee80211_if_managed *ifmgd; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + + /* driver doesn't support power save */ + if (!ieee80211_hw_check(&local->hw, SUPPORTS_PS)) + return TX_CONTINUE; + + /* hardware does dynamic power save */ + if (ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS)) + return TX_CONTINUE; + + /* dynamic power save disabled */ + if (local->hw.conf.dynamic_ps_timeout <= 0) + return TX_CONTINUE; + + /* we are scanning, don't enable power save */ + if (local->scanning) + return TX_CONTINUE; + + if (!local->ps_sdata) + return TX_CONTINUE; + + /* No point if we're going to suspend */ + if (local->quiescing) + return TX_CONTINUE; + + /* dynamic ps is supported only in managed mode */ + if (tx->sdata->vif.type != NL80211_IFTYPE_STATION) + return TX_CONTINUE; + + if (unlikely(info->flags & IEEE80211_TX_INTFL_OFFCHAN_TX_OK)) + return TX_CONTINUE; + + ifmgd = &tx->sdata->u.mgd; + + /* + * Don't wakeup from power save if u-apsd is enabled, voip ac has + * u-apsd enabled and the frame is in voip class. This effectively + * means that even if all access categories have u-apsd enabled, in + * practise u-apsd is only used with the voip ac. This is a + * workaround for the case when received voip class packets do not + * have correct qos tag for some reason, due the network or the + * peer application. 
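Note: the basic-rate scan above keeps the highest basic rate that does not exceed the TX rate and, as a fallback, the highest mandatory rate that does not exceed it. A compact userspace model of the same selection (rates sorted ascending, 100 kbit/s units; illustrative only):

#include <stdbool.h>

static int pick_response_rate(const int *bitrate, const bool *basic,
			      const bool *mandatory, int n, int txrate)
{
	int rate = -1, mrate = bitrate[0];

	for (int i = 0; i < n; i++) {
		if (bitrate[i] > txrate)
			break;
		if (basic[i])
			rate = bitrate[i];
		if (mandatory[i])
			mrate = bitrate[i];
	}

	return rate >= 0 ? rate : mrate;
}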
+ * + * Note: ifmgd->uapsd_queues access is racy here. If the value is + * changed via debugfs, user needs to reassociate manually to have + * everything in sync. + */ + if ((ifmgd->flags & IEEE80211_STA_UAPSD_ENABLED) && + (ifmgd->uapsd_queues & IEEE80211_WMM_IE_STA_QOSINFO_AC_VO) && + skb_get_queue_mapping(tx->skb) == IEEE80211_AC_VO) + return TX_CONTINUE; + + if (local->hw.conf.flags & IEEE80211_CONF_PS) { + ieee80211_stop_queues_by_reason(&local->hw, + IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_PS, + false); + ifmgd->flags &= ~IEEE80211_STA_NULLFUNC_ACKED; + ieee80211_queue_work(&local->hw, + &local->dynamic_ps_disable_work); + } + + /* Don't restart the timer if we're not disassociated */ + if (!ifmgd->associated) + return TX_CONTINUE; + + mod_timer(&local->dynamic_ps_timer, jiffies + + msecs_to_jiffies(local->hw.conf.dynamic_ps_timeout)); + + return TX_CONTINUE; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_check_assoc(struct ieee80211_tx_data *tx) +{ + + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + bool assoc = false; + + if (unlikely(info->flags & IEEE80211_TX_CTL_INJECTED)) + return TX_CONTINUE; + + if (unlikely(test_bit(SCAN_SW_SCANNING, &tx->local->scanning)) && + test_bit(SDATA_STATE_OFFCHANNEL, &tx->sdata->state) && + !ieee80211_is_probe_req(hdr->frame_control) && + !ieee80211_is_any_nullfunc(hdr->frame_control)) + /* + * When software scanning only nullfunc frames (to notify + * the sleep state to the AP) and probe requests (for the + * active scan) are allowed, all other frames should not be + * sent and we should not get here, but if we do + * nonetheless, drop them to avoid sending them + * off-channel. See the link below and + * ieee80211_start_scan() for more. + * + * http://article.gmane.org/gmane.linux.kernel.wireless.general/30089 + */ + return TX_DROP; + + if (tx->sdata->vif.type == NL80211_IFTYPE_OCB) + return TX_CONTINUE; + + if (tx->flags & IEEE80211_TX_PS_BUFFERED) + return TX_CONTINUE; + + if (tx->sta) + assoc = test_sta_flag(tx->sta, WLAN_STA_ASSOC); + + if (likely(tx->flags & IEEE80211_TX_UNICAST)) { + if (unlikely(!assoc && + ieee80211_is_data(hdr->frame_control))) { +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG + sdata_info(tx->sdata, + "dropped data frame to not associated station %pM\n", + hdr->addr1); +#endif + I802_DEBUG_INC(tx->local->tx_handlers_drop_not_assoc); + return TX_DROP; + } + } else if (unlikely(ieee80211_is_data(hdr->frame_control) && + ieee80211_vif_get_num_mcast_if(tx->sdata) == 0)) { + /* + * No associated STAs - no need to send multicast + * frames. + */ + return TX_DROP; + } + + return TX_CONTINUE; +} + +/* This function is called whenever the AP is about to exceed the maximum limit + * of buffered frames for power saving STAs. This situation should not really + * happen often during normal operation, so dropping the oldest buffered packet + * from each queue should be OK to make some room for new frames. 
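Note: the U-APSD exception in ieee80211_tx_h_dynamic_ps above reduces to a three-way condition: U-APSD negotiated, the VO access category delivery-enabled, and the outgoing frame mapped to the VO queue. A simplified predicate with the constants inlined for the sketch (the real code uses IEEE80211_WMM_IE_STA_QOSINFO_AC_VO and IEEE80211_AC_VO):

#include <stdbool.h>

/* mac80211 queue numbering puts VO at index 0; QoS-info bit 0 is the VO AC */
static bool stay_in_ps_for_uapsd_vo(bool uapsd_enabled,
				    unsigned int uapsd_queues,
				    unsigned int frame_queue)
{
	return uapsd_enabled && (uapsd_queues & 0x1) && frame_queue == 0;
}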
*/ +static void purge_old_ps_buffers(struct ieee80211_local *local) +{ + int total = 0, purged = 0; + struct sk_buff *skb; + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + struct ps_data *ps; + + if (sdata->vif.type == NL80211_IFTYPE_AP) + ps = &sdata->u.ap.ps; + else if (ieee80211_vif_is_mesh(&sdata->vif)) + ps = &sdata->u.mesh.ps; + else + continue; + + skb = skb_dequeue(&ps->bc_buf); + if (skb) { + purged++; + ieee80211_free_txskb(&local->hw, skb); + } + total += skb_queue_len(&ps->bc_buf); + } + + /* + * Drop one frame from each station from the lowest-priority + * AC that has frames at all. + */ + list_for_each_entry_rcu(sta, &local->sta_list, list) { + int ac; + + for (ac = IEEE80211_AC_BK; ac >= IEEE80211_AC_VO; ac--) { + skb = skb_dequeue(&sta->ps_tx_buf[ac]); + total += skb_queue_len(&sta->ps_tx_buf[ac]); + if (skb) { + purged++; + ieee80211_free_txskb(&local->hw, skb); + break; + } + } + } + + local->total_ps_buffered = total; + ps_dbg_hw(&local->hw, "PS buffers full - purged %d frames\n", purged); +} + +static ieee80211_tx_result +ieee80211_tx_h_multicast_ps_buf(struct ieee80211_tx_data *tx) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + struct ps_data *ps; + + /* + * broadcast/multicast frame + * + * If any of the associated/peer stations is in power save mode, + * the frame is buffered to be sent after DTIM beacon frame. + * This is done either by the hardware or us. + */ + + /* powersaving STAs currently only in AP/VLAN/mesh mode */ + if (tx->sdata->vif.type == NL80211_IFTYPE_AP || + tx->sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + if (!tx->sdata->bss) + return TX_CONTINUE; + + ps = &tx->sdata->bss->ps; + } else if (ieee80211_vif_is_mesh(&tx->sdata->vif)) { + ps = &tx->sdata->u.mesh.ps; + } else { + return TX_CONTINUE; + } + + + /* no buffering for ordered frames */ + if (ieee80211_has_order(hdr->frame_control)) + return TX_CONTINUE; + + if (ieee80211_is_probe_req(hdr->frame_control)) + return TX_CONTINUE; + + if (ieee80211_hw_check(&tx->local->hw, QUEUE_CONTROL)) + info->hw_queue = tx->sdata->vif.cab_queue; + + /* no stations in PS mode and no buffered packets */ + if (!atomic_read(&ps->num_sta_ps) && skb_queue_empty(&ps->bc_buf)) + return TX_CONTINUE; + + info->flags |= IEEE80211_TX_CTL_SEND_AFTER_DTIM; + + /* device releases frame after DTIM beacon */ + if (!ieee80211_hw_check(&tx->local->hw, HOST_BROADCAST_PS_BUFFERING)) + return TX_CONTINUE; + + /* buffered in mac80211 */ + if (tx->local->total_ps_buffered >= TOTAL_MAX_TX_BUFFER) + purge_old_ps_buffers(tx->local); + + if (skb_queue_len(&ps->bc_buf) >= AP_MAX_BC_BUFFER) { + ps_dbg(tx->sdata, + "BC TX buffer full - dropping the oldest frame\n"); + ieee80211_free_txskb(&tx->local->hw, skb_dequeue(&ps->bc_buf)); + } else + tx->local->total_ps_buffered++; + + skb_queue_tail(&ps->bc_buf, tx->skb); + + return TX_QUEUED; +} + +static int ieee80211_use_mfp(__le16 fc, struct sta_info *sta, + struct sk_buff *skb) +{ + if (!ieee80211_is_mgmt(fc)) + return 0; + + if (sta == NULL || !test_sta_flag(sta, WLAN_STA_MFP)) + return 0; + + if (!ieee80211_is_robust_mgmt_frame(skb)) + return 0; + + return 1; +} + +static ieee80211_tx_result +ieee80211_tx_h_unicast_ps_buf(struct ieee80211_tx_data *tx) +{ + struct sta_info *sta = tx->sta; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + struct 
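Note: purge_old_ps_buffers above walks each station's buffers from the lowest-priority AC upwards and drops a single frame from the first non-empty queue. With mac80211's AC numbering (VO=0, VI=1, BE=2, BK=3) that scan looks like this stand-alone model:

/* Returns the AC to drop one frame from, or -1 if nothing is buffered. */
static int ac_to_purge(const int qlen[4])
{
	for (int ac = 3 /* BK */; ac >= 0 /* VO */; ac--)
		if (qlen[ac] > 0)
			return ac;

	return -1;
}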
ieee80211_local *local = tx->local; + + if (unlikely(!sta)) + return TX_CONTINUE; + + if (unlikely((test_sta_flag(sta, WLAN_STA_PS_STA) || + test_sta_flag(sta, WLAN_STA_PS_DRIVER) || + test_sta_flag(sta, WLAN_STA_PS_DELIVER)) && + !(info->flags & IEEE80211_TX_CTL_NO_PS_BUFFER))) { + int ac = skb_get_queue_mapping(tx->skb); + + if (ieee80211_is_mgmt(hdr->frame_control) && + !ieee80211_is_bufferable_mmpdu(hdr->frame_control)) { + info->flags |= IEEE80211_TX_CTL_NO_PS_BUFFER; + return TX_CONTINUE; + } + + ps_dbg(sta->sdata, "STA %pM aid %d: PS buffer for AC %d\n", + sta->sta.addr, sta->sta.aid, ac); + if (tx->local->total_ps_buffered >= TOTAL_MAX_TX_BUFFER) + purge_old_ps_buffers(tx->local); + + /* sync with ieee80211_sta_ps_deliver_wakeup */ + spin_lock(&sta->ps_lock); + /* + * STA woke up the meantime and all the frames on ps_tx_buf have + * been queued to pending queue. No reordering can happen, go + * ahead and Tx the packet. + */ + if (!test_sta_flag(sta, WLAN_STA_PS_STA) && + !test_sta_flag(sta, WLAN_STA_PS_DRIVER) && + !test_sta_flag(sta, WLAN_STA_PS_DELIVER)) { + spin_unlock(&sta->ps_lock); + return TX_CONTINUE; + } + + if (skb_queue_len(&sta->ps_tx_buf[ac]) >= STA_MAX_TX_BUFFER) { + struct sk_buff *old = skb_dequeue(&sta->ps_tx_buf[ac]); + ps_dbg(tx->sdata, + "STA %pM TX buffer for AC %d full - dropping oldest frame\n", + sta->sta.addr, ac); + ieee80211_free_txskb(&local->hw, old); + } else + tx->local->total_ps_buffered++; + + info->control.jiffies = jiffies; + info->control.vif = &tx->sdata->vif; + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + info->flags &= ~IEEE80211_TX_TEMPORARY_FLAGS; + skb_queue_tail(&sta->ps_tx_buf[ac], tx->skb); + spin_unlock(&sta->ps_lock); + + if (!timer_pending(&local->sta_cleanup)) + mod_timer(&local->sta_cleanup, + round_jiffies(jiffies + + STA_INFO_CLEANUP_INTERVAL)); + + /* + * We queued up some frames, so the TIM bit might + * need to be set, recalculate it. 
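Note: the spin_lock(&sta->ps_lock) section above is a double-check: the PS flags were first tested without the lock and are re-tested under it, so a station that woke up in the meantime has its frame sent immediately instead of buffered. A userspace analogue of that pattern, with a pthread mutex standing in for the spinlock (illustrative only):

#include <pthread.h>
#include <stdbool.h>

struct ps_state {
	pthread_mutex_t lock;
	bool asleep;		/* analogue of the WLAN_STA_PS_* flags */
};

/* Returns true if the frame was buffered, false if it must be sent now. */
static bool buffer_if_still_asleep(struct ps_state *ps)
{
	bool buffered = false;

	pthread_mutex_lock(&ps->lock);
	if (ps->asleep) {
		/* ...queue the frame on the PS buffer here... */
		buffered = true;
	}
	pthread_mutex_unlock(&ps->lock);

	return buffered;
}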
+ */ + sta_info_recalc_tim(sta); + + return TX_QUEUED; + } else if (unlikely(test_sta_flag(sta, WLAN_STA_PS_STA))) { + ps_dbg(tx->sdata, + "STA %pM in PS mode, but polling/in SP -> send frame\n", + sta->sta.addr); + } + + return TX_CONTINUE; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_ps_buf(struct ieee80211_tx_data *tx) +{ + if (unlikely(tx->flags & IEEE80211_TX_PS_BUFFERED)) + return TX_CONTINUE; + + if (tx->flags & IEEE80211_TX_UNICAST) + return ieee80211_tx_h_unicast_ps_buf(tx); + else + return ieee80211_tx_h_multicast_ps_buf(tx); +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_check_control_port_protocol(struct ieee80211_tx_data *tx) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + + if (unlikely(tx->sdata->control_port_protocol == tx->skb->protocol)) { + if (tx->sdata->control_port_no_encrypt) + info->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + info->control.flags |= IEEE80211_TX_CTRL_PORT_CTRL_PROTO; + info->flags |= IEEE80211_TX_CTL_USE_MINRATE; + } + + return TX_CONTINUE; +} + +static struct ieee80211_key * +ieee80211_select_link_key(struct ieee80211_tx_data *tx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + enum { + USE_NONE, + USE_MGMT_KEY, + USE_MCAST_KEY, + } which_key = USE_NONE; + struct ieee80211_link_data *link; + unsigned int link_id; + + if (ieee80211_is_group_privacy_action(tx->skb)) + which_key = USE_MCAST_KEY; + else if (ieee80211_is_mgmt(hdr->frame_control) && + is_multicast_ether_addr(hdr->addr1) && + ieee80211_is_robust_mgmt_frame(tx->skb)) + which_key = USE_MGMT_KEY; + else if (is_multicast_ether_addr(hdr->addr1)) + which_key = USE_MCAST_KEY; + else + return NULL; + + link_id = u32_get_bits(info->control.flags, IEEE80211_TX_CTRL_MLO_LINK); + if (link_id == IEEE80211_LINK_UNSPECIFIED) { + link = &tx->sdata->deflink; + } else { + link = rcu_dereference(tx->sdata->link[link_id]); + if (!link) + return NULL; + } + + switch (which_key) { + case USE_NONE: + break; + case USE_MGMT_KEY: + return rcu_dereference(link->default_mgmt_key); + case USE_MCAST_KEY: + return rcu_dereference(link->default_multicast_key); + } + + return NULL; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_select_key(struct ieee80211_tx_data *tx) +{ + struct ieee80211_key *key; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + + if (unlikely(info->flags & IEEE80211_TX_INTFL_DONT_ENCRYPT)) { + tx->key = NULL; + return TX_CONTINUE; + } + + if (tx->sta && + (key = rcu_dereference(tx->sta->ptk[tx->sta->ptk_idx]))) + tx->key = key; + else if ((key = ieee80211_select_link_key(tx))) + tx->key = key; + else if (!is_multicast_ether_addr(hdr->addr1) && + (key = rcu_dereference(tx->sdata->default_unicast_key))) + tx->key = key; + else + tx->key = NULL; + + if (tx->key) { + bool skip_hw = false; + + /* TODO: add threshold stuff again */ + + switch (tx->key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + case WLAN_CIPHER_SUITE_TKIP: + if (!ieee80211_is_data_present(hdr->frame_control)) + tx->key = NULL; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + if (!ieee80211_is_data_present(hdr->frame_control) && + !ieee80211_use_mfp(hdr->frame_control, tx->sta, + tx->skb) && + !ieee80211_is_group_privacy_action(tx->skb)) + tx->key = NULL; + else + skip_hw 
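Note: ieee80211_tx_h_select_key above applies a fixed precedence: a pairwise key for the peer wins, then the per-link management or multicast default key chosen by frame type, then the interface's default unicast key for unicast frames, and finally no key at all. A condensed model of that ordering (illustrative, the decisions flattened into booleans):

static const char *tx_key_class(bool have_ptk, bool mcast_robust_mgmt,
				bool mcast, bool have_default_unicast)
{
	if (have_ptk)
		return "pairwise key (PTK)";
	if (mcast_robust_mgmt)
		return "per-link default management key";
	if (mcast)
		return "per-link default multicast key";
	if (have_default_unicast)
		return "default unicast key";

	return "no key";
}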
= (tx->key->conf.flags & + IEEE80211_KEY_FLAG_SW_MGMT_TX) && + ieee80211_is_mgmt(hdr->frame_control); + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + if (!ieee80211_is_mgmt(hdr->frame_control)) + tx->key = NULL; + break; + } + + if (unlikely(tx->key && tx->key->flags & KEY_FLAG_TAINTED && + !ieee80211_is_deauth(hdr->frame_control)) && + tx->skb->protocol != tx->sdata->control_port_protocol) + return TX_DROP; + + if (!skip_hw && tx->key && + tx->key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE) + info->control.hw_key = &tx->key->conf; + } else if (ieee80211_is_data_present(hdr->frame_control) && tx->sta && + test_sta_flag(tx->sta, WLAN_STA_USES_ENCRYPTION)) { + return TX_DROP; + } + + return TX_CONTINUE; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_rate_ctrl(struct ieee80211_tx_data *tx) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + struct ieee80211_hdr *hdr = (void *)tx->skb->data; + struct ieee80211_supported_band *sband; + u32 len; + struct ieee80211_tx_rate_control txrc; + struct ieee80211_sta_rates *ratetbl = NULL; + bool encap = info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP; + bool assoc = false; + + memset(&txrc, 0, sizeof(txrc)); + + sband = tx->local->hw.wiphy->bands[info->band]; + + len = min_t(u32, tx->skb->len + FCS_LEN, + tx->local->hw.wiphy->frag_threshold); + + /* set up the tx rate control struct we give the RC algo */ + txrc.hw = &tx->local->hw; + txrc.sband = sband; + txrc.bss_conf = &tx->sdata->vif.bss_conf; + txrc.skb = tx->skb; + txrc.reported_rate.idx = -1; + txrc.rate_idx_mask = tx->sdata->rc_rateidx_mask[info->band]; + + if (tx->sdata->rc_has_mcs_mask[info->band]) + txrc.rate_idx_mcs_mask = + tx->sdata->rc_rateidx_mcs_mask[info->band]; + + txrc.bss = (tx->sdata->vif.type == NL80211_IFTYPE_AP || + tx->sdata->vif.type == NL80211_IFTYPE_MESH_POINT || + tx->sdata->vif.type == NL80211_IFTYPE_ADHOC || + tx->sdata->vif.type == NL80211_IFTYPE_OCB); + + /* set up RTS protection if desired */ + if (len > tx->local->hw.wiphy->rts_threshold) { + txrc.rts = true; + } + + info->control.use_rts = txrc.rts; + info->control.use_cts_prot = tx->sdata->vif.bss_conf.use_cts_prot; + + /* + * Use short preamble if the BSS can handle it, but not for + * management frames unless we know the receiver can handle + * that -- the management frame might be to a station that + * just wants a probe response. + */ + if (tx->sdata->vif.bss_conf.use_short_preamble && + (ieee80211_is_tx_data(tx->skb) || + (tx->sta && test_sta_flag(tx->sta, WLAN_STA_SHORT_PREAMBLE)))) + txrc.short_preamble = true; + + info->control.short_preamble = txrc.short_preamble; + + /* don't ask rate control when rate already injected via radiotap */ + if (info->control.flags & IEEE80211_TX_CTRL_RATE_INJECT) + return TX_CONTINUE; + + if (tx->sta) + assoc = test_sta_flag(tx->sta, WLAN_STA_ASSOC); + + /* + * Lets not bother rate control if we're associated and cannot + * talk to the sta. This should not happen. + */ + if (WARN(test_bit(SCAN_SW_SCANNING, &tx->local->scanning) && assoc && + !rate_usable_index_exists(sband, &tx->sta->sta), + "%s: Dropped data frame as no usable bitrate found while " + "scanning and associated. Target station: " + "%pM on %d GHz band\n", + tx->sdata->name, + encap ? ((struct ethhdr *)hdr)->h_dest : hdr->addr1, + info->band ? 
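Note: the protection decision in ieee80211_tx_h_rate_ctrl above compares the frame length, plus the 4-byte FCS and capped at the fragmentation threshold, against the RTS threshold. As a stand-alone check (illustrative only):

static bool needs_rts_protection(unsigned int skb_len,
				 unsigned int frag_threshold,
				 unsigned int rts_threshold)
{
	unsigned int len = skb_len + 4;	/* FCS_LEN */

	if (len > frag_threshold)
		len = frag_threshold;

	return len > rts_threshold;
}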
5 : 2)) + return TX_DROP; + + /* + * If we're associated with the sta at this point we know we can at + * least send the frame at the lowest bit rate. + */ + rate_control_get_rate(tx->sdata, tx->sta, &txrc); + + if (tx->sta && !info->control.skip_table) + ratetbl = rcu_dereference(tx->sta->sta.rates); + + if (unlikely(info->control.rates[0].idx < 0)) { + if (ratetbl) { + struct ieee80211_tx_rate rate = { + .idx = ratetbl->rate[0].idx, + .flags = ratetbl->rate[0].flags, + .count = ratetbl->rate[0].count + }; + + if (ratetbl->rate[0].idx < 0) + return TX_DROP; + + tx->rate = rate; + } else { + return TX_DROP; + } + } else { + tx->rate = info->control.rates[0]; + } + + if (txrc.reported_rate.idx < 0) { + txrc.reported_rate = tx->rate; + if (tx->sta && ieee80211_is_tx_data(tx->skb)) + tx->sta->deflink.tx_stats.last_rate = txrc.reported_rate; + } else if (tx->sta) + tx->sta->deflink.tx_stats.last_rate = txrc.reported_rate; + + if (ratetbl) + return TX_CONTINUE; + + if (unlikely(!info->control.rates[0].count)) + info->control.rates[0].count = 1; + + if (WARN_ON_ONCE((info->control.rates[0].count > 1) && + (info->flags & IEEE80211_TX_CTL_NO_ACK))) + info->control.rates[0].count = 1; + + return TX_CONTINUE; +} + +static __le16 ieee80211_tx_next_seq(struct sta_info *sta, int tid) +{ + u16 *seq = &sta->tid_seq[tid]; + __le16 ret = cpu_to_le16(*seq); + + /* Increase the sequence number. */ + *seq = (*seq + 0x10) & IEEE80211_SCTL_SEQ; + + return ret; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_sequence(struct ieee80211_tx_data *tx) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)tx->skb->data; + int tid; + + /* + * Packet injection may want to control the sequence + * number, if we have no matching interface then we + * neither assign one ourselves nor ask the driver to. + */ + if (unlikely(info->control.vif->type == NL80211_IFTYPE_MONITOR)) + return TX_CONTINUE; + + if (unlikely(ieee80211_is_ctl(hdr->frame_control))) + return TX_CONTINUE; + + if (ieee80211_hdrlen(hdr->frame_control) < 24) + return TX_CONTINUE; + + if (ieee80211_is_qos_nullfunc(hdr->frame_control)) + return TX_CONTINUE; + + if (info->control.flags & IEEE80211_TX_CTRL_NO_SEQNO) + return TX_CONTINUE; + + /* SNS11 from 802.11be 10.3.2.14 */ + if (unlikely(is_multicast_ether_addr(hdr->addr1) && + info->control.vif->valid_links && + info->control.vif->type == NL80211_IFTYPE_AP)) { + if (info->control.flags & IEEE80211_TX_CTRL_MCAST_MLO_FIRST_TX) + tx->sdata->mld_mcast_seq += 0x10; + hdr->seq_ctrl = cpu_to_le16(tx->sdata->mld_mcast_seq); + return TX_CONTINUE; + } + + /* + * Anything but QoS data that has a sequence number field + * (is long enough) gets a sequence number from the global + * counter. QoS data frames with a multicast destination + * also use the global counter (802.11-2012 9.3.2.10). + */ + if (!ieee80211_is_data_qos(hdr->frame_control) || + is_multicast_ether_addr(hdr->addr1)) { + /* driver should assign sequence number */ + info->flags |= IEEE80211_TX_CTL_ASSIGN_SEQ; + /* for pure STA mode without beacons, we can do it */ + hdr->seq_ctrl = cpu_to_le16(tx->sdata->sequence_number); + tx->sdata->sequence_number += 0x10; + if (tx->sta) + tx->sta->deflink.tx_stats.msdu[IEEE80211_NUM_TIDS]++; + return TX_CONTINUE; + } + + /* + * This should be true for injected/management frames only, for + * management frames we have set the IEEE80211_TX_CTL_ASSIGN_SEQ + * above since they are not QoS-data frames. 
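Note: ieee80211_tx_next_seq above advances the counter in steps of 0x10 because the 16-bit Sequence Control field keeps the fragment number in bits 0-3 and the sequence number in bits 4-15. A self-contained model of the arithmetic (masks match IEEE80211_SCTL_SEQ/IEEE80211_SCTL_FRAG; here the fragment number is folded in directly, whereas the code above ORs it in later from ieee80211_tx_h_fragment):

#include <stdint.h>

#define SCTL_SEQ	0xfff0
#define SCTL_FRAG	0x000f

static uint16_t next_seq_ctrl(uint16_t *counter, unsigned int fragnum)
{
	uint16_t ret = (*counter & SCTL_SEQ) | (fragnum & SCTL_FRAG);

	*counter = (*counter + 0x10) & SCTL_SEQ;
	return ret;
}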
+ */ + if (!tx->sta) + return TX_CONTINUE; + + /* include per-STA, per-TID sequence counter */ + tid = ieee80211_get_tid(hdr); + tx->sta->deflink.tx_stats.msdu[tid]++; + + hdr->seq_ctrl = ieee80211_tx_next_seq(tx->sta, tid); + + return TX_CONTINUE; +} + +static int ieee80211_fragment(struct ieee80211_tx_data *tx, + struct sk_buff *skb, int hdrlen, + int frag_threshold) +{ + struct ieee80211_local *local = tx->local; + struct ieee80211_tx_info *info; + struct sk_buff *tmp; + int per_fragm = frag_threshold - hdrlen - FCS_LEN; + int pos = hdrlen + per_fragm; + int rem = skb->len - hdrlen - per_fragm; + + if (WARN_ON(rem < 0)) + return -EINVAL; + + /* first fragment was already added to queue by caller */ + + while (rem) { + int fraglen = per_fragm; + + if (fraglen > rem) + fraglen = rem; + rem -= fraglen; + tmp = dev_alloc_skb(local->tx_headroom + + frag_threshold + + IEEE80211_ENCRYPT_HEADROOM + + IEEE80211_ENCRYPT_TAILROOM); + if (!tmp) + return -ENOMEM; + + __skb_queue_tail(&tx->skbs, tmp); + + skb_reserve(tmp, + local->tx_headroom + IEEE80211_ENCRYPT_HEADROOM); + + /* copy control information */ + memcpy(tmp->cb, skb->cb, sizeof(tmp->cb)); + + info = IEEE80211_SKB_CB(tmp); + info->flags &= ~(IEEE80211_TX_CTL_CLEAR_PS_FILT | + IEEE80211_TX_CTL_FIRST_FRAGMENT); + + if (rem) + info->flags |= IEEE80211_TX_CTL_MORE_FRAMES; + + skb_copy_queue_mapping(tmp, skb); + tmp->priority = skb->priority; + tmp->dev = skb->dev; + + /* copy header and data */ + skb_put_data(tmp, skb->data, hdrlen); + skb_put_data(tmp, skb->data + pos, fraglen); + + pos += fraglen; + } + + /* adjust first fragment's length */ + skb_trim(skb, hdrlen + per_fragm); + return 0; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_fragment(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb = tx->skb; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (void *)skb->data; + int frag_threshold = tx->local->hw.wiphy->frag_threshold; + int hdrlen; + int fragnum; + + /* no matter what happens, tx->skb moves to tx->skbs */ + __skb_queue_tail(&tx->skbs, skb); + tx->skb = NULL; + + if (info->flags & IEEE80211_TX_CTL_DONTFRAG) + return TX_CONTINUE; + + if (ieee80211_hw_check(&tx->local->hw, SUPPORTS_TX_FRAG)) + return TX_CONTINUE; + + /* + * Warn when submitting a fragmented A-MPDU frame and drop it. + * This scenario is handled in ieee80211_tx_prepare but extra + * caution taken here as fragmented ampdu may cause Tx stop. + */ + if (WARN_ON(info->flags & IEEE80211_TX_CTL_AMPDU)) + return TX_DROP; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + /* internal error, why isn't DONTFRAG set? */ + if (WARN_ON(skb->len + FCS_LEN <= frag_threshold)) + return TX_DROP; + + /* + * Now fragment the frame. This will allocate all the fragments and + * chain them (using skb as the first fragment) to skb->next. + * During transmission, we will remove the successfully transmitted + * fragments from this list. When the low-level driver rejects one + * of the fragments then we will simply pretend to accept the skb + * but store it away as pending. 
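Note: ieee80211_fragment above gives every fragment room for at most frag_threshold - hdrlen - FCS_LEN bytes of payload, so the fragment count follows directly from the MSDU size. A small model of that arithmetic (FCS_LEN is 4):

static int num_fragments(int frame_len, int hdrlen, int frag_threshold)
{
	int payload = frame_len - hdrlen;
	int per_frag = frag_threshold - hdrlen - 4;	/* FCS_LEN */

	return (payload + per_frag - 1) / per_frag;	/* DIV_ROUND_UP */
}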
+ */ + if (ieee80211_fragment(tx, skb, hdrlen, frag_threshold)) + return TX_DROP; + + /* update duration/seq/flags of fragments */ + fragnum = 0; + + skb_queue_walk(&tx->skbs, skb) { + const __le16 morefrags = cpu_to_le16(IEEE80211_FCTL_MOREFRAGS); + + hdr = (void *)skb->data; + info = IEEE80211_SKB_CB(skb); + + if (!skb_queue_is_last(&tx->skbs, skb)) { + hdr->frame_control |= morefrags; + /* + * No multi-rate retries for fragmented frames, that + * would completely throw off the NAV at other STAs. + */ + info->control.rates[1].idx = -1; + info->control.rates[2].idx = -1; + info->control.rates[3].idx = -1; + BUILD_BUG_ON(IEEE80211_TX_MAX_RATES != 4); + info->flags &= ~IEEE80211_TX_CTL_RATE_CTRL_PROBE; + } else { + hdr->frame_control &= ~morefrags; + } + hdr->seq_ctrl |= cpu_to_le16(fragnum & IEEE80211_SCTL_FRAG); + fragnum++; + } + + return TX_CONTINUE; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_stats(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + int ac = -1; + + if (!tx->sta) + return TX_CONTINUE; + + skb_queue_walk(&tx->skbs, skb) { + ac = skb_get_queue_mapping(skb); + tx->sta->deflink.tx_stats.bytes[ac] += skb->len; + } + if (ac >= 0) + tx->sta->deflink.tx_stats.packets[ac]++; + + return TX_CONTINUE; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_encrypt(struct ieee80211_tx_data *tx) +{ + if (!tx->key) + return TX_CONTINUE; + + switch (tx->key->conf.cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + return ieee80211_crypto_wep_encrypt(tx); + case WLAN_CIPHER_SUITE_TKIP: + return ieee80211_crypto_tkip_encrypt(tx); + case WLAN_CIPHER_SUITE_CCMP: + return ieee80211_crypto_ccmp_encrypt( + tx, IEEE80211_CCMP_MIC_LEN); + case WLAN_CIPHER_SUITE_CCMP_256: + return ieee80211_crypto_ccmp_encrypt( + tx, IEEE80211_CCMP_256_MIC_LEN); + case WLAN_CIPHER_SUITE_AES_CMAC: + return ieee80211_crypto_aes_cmac_encrypt(tx); + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + return ieee80211_crypto_aes_cmac_256_encrypt(tx); + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + return ieee80211_crypto_aes_gmac_encrypt(tx); + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + return ieee80211_crypto_gcmp_encrypt(tx); + } + + return TX_DROP; +} + +static ieee80211_tx_result debug_noinline +ieee80211_tx_h_calculate_duration(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + struct ieee80211_hdr *hdr; + int next_len; + bool group_addr; + + skb_queue_walk(&tx->skbs, skb) { + hdr = (void *) skb->data; + if (unlikely(ieee80211_is_pspoll(hdr->frame_control))) + break; /* must not overwrite AID */ + if (!skb_queue_is_last(&tx->skbs, skb)) { + struct sk_buff *next = skb_queue_next(&tx->skbs, skb); + next_len = next->len; + } else + next_len = 0; + group_addr = is_multicast_ether_addr(hdr->addr1); + + hdr->duration_id = + ieee80211_duration(tx, skb, group_addr, next_len); + } + + return TX_CONTINUE; +} + +/* actual transmit path */ + +static bool ieee80211_tx_prep_agg(struct ieee80211_tx_data *tx, + struct sk_buff *skb, + struct ieee80211_tx_info *info, + struct tid_ampdu_tx *tid_tx, + int tid) +{ + bool queued = false; + bool reset_agg_timer = false; + struct sk_buff *purge_skb = NULL; + + if (test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) { + info->flags |= IEEE80211_TX_CTL_AMPDU; + reset_agg_timer = true; + } else if (test_bit(HT_AGG_STATE_WANT_START, &tid_tx->state)) { + /* + * nothing -- this aggregation session is being started + * but that might still fail with the driver + */ + } else if 
(!tx->sta->sta.txq[tid]) { + spin_lock(&tx->sta->lock); + /* + * Need to re-check now, because we may get here + * + * 1) in the window during which the setup is actually + * already done, but not marked yet because not all + * packets are spliced over to the driver pending + * queue yet -- if this happened we acquire the lock + * either before or after the splice happens, but + * need to recheck which of these cases happened. + * + * 2) during session teardown, if the OPERATIONAL bit + * was cleared due to the teardown but the pointer + * hasn't been assigned NULL yet (or we loaded it + * before it was assigned) -- in this case it may + * now be NULL which means we should just let the + * packet pass through because splicing the frames + * back is already done. + */ + tid_tx = rcu_dereference_protected_tid_tx(tx->sta, tid); + + if (!tid_tx) { + /* do nothing, let packet pass through */ + } else if (test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) { + info->flags |= IEEE80211_TX_CTL_AMPDU; + reset_agg_timer = true; + } else { + queued = true; + if (info->flags & IEEE80211_TX_CTL_NO_PS_BUFFER) { + clear_sta_flag(tx->sta, WLAN_STA_SP); + ps_dbg(tx->sta->sdata, + "STA %pM aid %d: SP frame queued, close the SP w/o telling the peer\n", + tx->sta->sta.addr, tx->sta->sta.aid); + } + info->control.vif = &tx->sdata->vif; + info->control.flags |= IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + info->flags &= ~IEEE80211_TX_TEMPORARY_FLAGS; + __skb_queue_tail(&tid_tx->pending, skb); + if (skb_queue_len(&tid_tx->pending) > STA_MAX_TX_BUFFER) + purge_skb = __skb_dequeue(&tid_tx->pending); + } + spin_unlock(&tx->sta->lock); + + if (purge_skb) + ieee80211_free_txskb(&tx->local->hw, purge_skb); + } + + /* reset session timer */ + if (reset_agg_timer) + tid_tx->last_tx = jiffies; + + return queued; +} + +static void +ieee80211_aggr_check(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct sk_buff *skb) +{ + struct rate_control_ref *ref = sdata->local->rate_ctrl; + u16 tid; + + if (!ref || !(ref->ops->capa & RATE_CTRL_CAPA_AMPDU_TRIGGER)) + return; + + if (!sta || !sta->sta.deflink.ht_cap.ht_supported || + !sta->sta.wme || skb_get_queue_mapping(skb) == IEEE80211_AC_VO || + skb->protocol == sdata->control_port_protocol) + return; + + tid = skb->priority & IEEE80211_QOS_CTL_TID_MASK; + if (likely(sta->ampdu_mlme.tid_tx[tid])) + return; + + ieee80211_start_tx_ba_session(&sta->sta, tid, 0); +} + +/* + * initialises @tx + * pass %NULL for the station if unknown, a valid pointer if known + * or an ERR_PTR() if the station is known not to exist + */ +static ieee80211_tx_result +ieee80211_tx_prepare(struct ieee80211_sub_if_data *sdata, + struct ieee80211_tx_data *tx, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_hdr *hdr; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + bool aggr_check = false; + int tid; + + memset(tx, 0, sizeof(*tx)); + tx->skb = skb; + tx->local = local; + tx->sdata = sdata; + __skb_queue_head_init(&tx->skbs); + + /* + * If this flag is set to true anywhere, and we get here, + * we are doing the needed processing, so remove the flag + * now. 
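Note: ieee80211_aggr_check above only auto-starts a block-ack session when the rate-control algorithm asks for A-MPDU triggering and the peer looks like a sensible aggregation target: HT-capable, WMM, not on the voice queue, and not a control-port (EAPOL) frame. A simplified predicate (illustrative only):

#include <stdbool.h>

static bool should_try_ampdu(bool rc_wants_ampdu_trigger, bool ht_supported,
			     bool wme, bool on_vo_queue, bool is_control_port,
			     bool session_already_set_up)
{
	if (!rc_wants_ampdu_trigger)
		return false;
	if (!ht_supported || !wme || on_vo_queue || is_control_port)
		return false;

	return !session_already_set_up;
}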
+ */ + info->control.flags &= ~IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + + hdr = (struct ieee80211_hdr *) skb->data; + + if (likely(sta)) { + if (!IS_ERR(sta)) + tx->sta = sta; + } else { + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + tx->sta = rcu_dereference(sdata->u.vlan.sta); + if (!tx->sta && sdata->wdev.use_4addr) + return TX_DROP; + } else if (tx->sdata->control_port_protocol == tx->skb->protocol) { + tx->sta = sta_info_get_bss(sdata, hdr->addr1); + } + if (!tx->sta && !is_multicast_ether_addr(hdr->addr1)) { + tx->sta = sta_info_get(sdata, hdr->addr1); + aggr_check = true; + } + } + + if (tx->sta && ieee80211_is_data_qos(hdr->frame_control) && + !ieee80211_is_qos_nullfunc(hdr->frame_control) && + ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION) && + !ieee80211_hw_check(&local->hw, TX_AMPDU_SETUP_IN_HW)) { + struct tid_ampdu_tx *tid_tx; + + tid = ieee80211_get_tid(hdr); + tid_tx = rcu_dereference(tx->sta->ampdu_mlme.tid_tx[tid]); + if (!tid_tx && aggr_check) { + ieee80211_aggr_check(sdata, tx->sta, skb); + tid_tx = rcu_dereference(tx->sta->ampdu_mlme.tid_tx[tid]); + } + + if (tid_tx) { + bool queued; + + queued = ieee80211_tx_prep_agg(tx, skb, info, + tid_tx, tid); + + if (unlikely(queued)) + return TX_QUEUED; + } + } + + if (is_multicast_ether_addr(hdr->addr1)) { + tx->flags &= ~IEEE80211_TX_UNICAST; + info->flags |= IEEE80211_TX_CTL_NO_ACK; + } else + tx->flags |= IEEE80211_TX_UNICAST; + + if (!(info->flags & IEEE80211_TX_CTL_DONTFRAG)) { + if (!(tx->flags & IEEE80211_TX_UNICAST) || + skb->len + FCS_LEN <= local->hw.wiphy->frag_threshold || + (info->flags & IEEE80211_TX_CTL_AMPDU && + !local->ops->wake_tx_queue)) + info->flags |= IEEE80211_TX_CTL_DONTFRAG; + } + + if (!tx->sta) + info->flags |= IEEE80211_TX_CTL_CLEAR_PS_FILT; + else if (test_and_clear_sta_flag(tx->sta, WLAN_STA_CLEAR_PS_FILT)) { + info->flags |= IEEE80211_TX_CTL_CLEAR_PS_FILT; + ieee80211_check_fast_xmit(tx->sta); + } + + info->flags |= IEEE80211_TX_CTL_FIRST_FRAGMENT; + + return TX_CONTINUE; +} + +static struct txq_info *ieee80211_get_txq(struct ieee80211_local *local, + struct ieee80211_vif *vif, + struct sta_info *sta, + struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_txq *txq = NULL; + + if ((info->flags & IEEE80211_TX_CTL_SEND_AFTER_DTIM) || + (info->control.flags & IEEE80211_TX_CTRL_PS_RESPONSE)) + return NULL; + + if (!(info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP) && + unlikely(!ieee80211_is_data_present(hdr->frame_control))) { + if ((!ieee80211_is_mgmt(hdr->frame_control) || + ieee80211_is_bufferable_mmpdu(hdr->frame_control) || + vif->type == NL80211_IFTYPE_STATION) && + sta && sta->uploaded) { + /* + * This will be NULL if the driver didn't set the + * opt-in hardware flag. 
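Note: ieee80211_tx_prepare above sets IEEE80211_TX_CTL_DONTFRAG whenever fragmentation can never apply: multicast frames, frames already at or below the fragmentation threshold (FCS included), or A-MPDU traffic on drivers without wake_tx_queue. As a stand-alone predicate (illustrative only):

#include <stdbool.h>

static bool dont_fragment(bool unicast, unsigned int frame_len,
			  unsigned int frag_threshold, bool ampdu,
			  bool driver_has_wake_tx_queue)
{
	return !unicast ||
	       frame_len + 4 /* FCS_LEN */ <= frag_threshold ||
	       (ampdu && !driver_has_wake_tx_queue);
}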
+ */ + txq = sta->sta.txq[IEEE80211_NUM_TIDS]; + } + } else if (sta) { + u8 tid = skb->priority & IEEE80211_QOS_CTL_TID_MASK; + + if (!sta->uploaded) + return NULL; + + txq = sta->sta.txq[tid]; + } else if (vif) { + txq = vif->txq; + } + + if (!txq) + return NULL; + + return to_txq_info(txq); +} + +static void ieee80211_set_skb_enqueue_time(struct sk_buff *skb) +{ + IEEE80211_SKB_CB(skb)->control.enqueue_time = codel_get_time(); +} + +static u32 codel_skb_len_func(const struct sk_buff *skb) +{ + return skb->len; +} + +static codel_time_t codel_skb_time_func(const struct sk_buff *skb) +{ + const struct ieee80211_tx_info *info; + + info = (const struct ieee80211_tx_info *)skb->cb; + return info->control.enqueue_time; +} + +static struct sk_buff *codel_dequeue_func(struct codel_vars *cvars, + void *ctx) +{ + struct ieee80211_local *local; + struct txq_info *txqi; + struct fq *fq; + struct fq_flow *flow; + + txqi = ctx; + local = vif_to_sdata(txqi->txq.vif)->local; + fq = &local->fq; + + if (cvars == &txqi->def_cvars) + flow = &txqi->tin.default_flow; + else + flow = &fq->flows[cvars - local->cvars]; + + return fq_flow_dequeue(fq, flow); +} + +static void codel_drop_func(struct sk_buff *skb, + void *ctx) +{ + struct ieee80211_local *local; + struct ieee80211_hw *hw; + struct txq_info *txqi; + + txqi = ctx; + local = vif_to_sdata(txqi->txq.vif)->local; + hw = &local->hw; + + ieee80211_free_txskb(hw, skb); +} + +static struct sk_buff *fq_tin_dequeue_func(struct fq *fq, + struct fq_tin *tin, + struct fq_flow *flow) +{ + struct ieee80211_local *local; + struct txq_info *txqi; + struct codel_vars *cvars; + struct codel_params *cparams; + struct codel_stats *cstats; + + local = container_of(fq, struct ieee80211_local, fq); + txqi = container_of(tin, struct txq_info, tin); + cstats = &txqi->cstats; + + if (txqi->txq.sta) { + struct sta_info *sta = container_of(txqi->txq.sta, + struct sta_info, sta); + cparams = &sta->cparams; + } else { + cparams = &local->cparams; + } + + if (flow == &tin->default_flow) + cvars = &txqi->def_cvars; + else + cvars = &local->cvars[flow - fq->flows]; + + return codel_dequeue(txqi, + &flow->backlog, + cparams, + cvars, + cstats, + codel_skb_len_func, + codel_skb_time_func, + codel_drop_func, + codel_dequeue_func); +} + +static void fq_skb_free_func(struct fq *fq, + struct fq_tin *tin, + struct fq_flow *flow, + struct sk_buff *skb) +{ + struct ieee80211_local *local; + + local = container_of(fq, struct ieee80211_local, fq); + ieee80211_free_txskb(&local->hw, skb); +} + +static void ieee80211_txq_enqueue(struct ieee80211_local *local, + struct txq_info *txqi, + struct sk_buff *skb) +{ + struct fq *fq = &local->fq; + struct fq_tin *tin = &txqi->tin; + u32 flow_idx = fq_flow_idx(fq, skb); + + ieee80211_set_skb_enqueue_time(skb); + + spin_lock_bh(&fq->lock); + /* + * For management frames, don't really apply codel etc., + * we don't want to apply any shaping or anything we just + * want to simplify the driver API by having them on the + * txqi. 
+ */ + if (unlikely(txqi->txq.tid == IEEE80211_NUM_TIDS)) { + IEEE80211_SKB_CB(skb)->control.flags |= + IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + __skb_queue_tail(&txqi->frags, skb); + } else { + fq_tin_enqueue(fq, tin, flow_idx, skb, + fq_skb_free_func); + } + spin_unlock_bh(&fq->lock); +} + +static bool fq_vlan_filter_func(struct fq *fq, struct fq_tin *tin, + struct fq_flow *flow, struct sk_buff *skb, + void *data) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + + return info->control.vif == data; +} + +void ieee80211_txq_remove_vlan(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + struct fq *fq = &local->fq; + struct txq_info *txqi; + struct fq_tin *tin; + struct ieee80211_sub_if_data *ap; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_AP_VLAN)) + return; + + ap = container_of(sdata->bss, struct ieee80211_sub_if_data, u.ap); + + if (!ap->vif.txq) + return; + + txqi = to_txq_info(ap->vif.txq); + tin = &txqi->tin; + + spin_lock_bh(&fq->lock); + fq_tin_filter(fq, tin, fq_vlan_filter_func, &sdata->vif, + fq_skb_free_func); + spin_unlock_bh(&fq->lock); +} + +void ieee80211_txq_init(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct txq_info *txqi, int tid) +{ + fq_tin_init(&txqi->tin); + codel_vars_init(&txqi->def_cvars); + codel_stats_init(&txqi->cstats); + __skb_queue_head_init(&txqi->frags); + INIT_LIST_HEAD(&txqi->schedule_order); + + txqi->txq.vif = &sdata->vif; + + if (!sta) { + sdata->vif.txq = &txqi->txq; + txqi->txq.tid = 0; + txqi->txq.ac = IEEE80211_AC_BE; + + return; + } + + if (tid == IEEE80211_NUM_TIDS) { + if (sdata->vif.type == NL80211_IFTYPE_STATION) { + /* Drivers need to opt in to the management MPDU TXQ */ + if (!ieee80211_hw_check(&sdata->local->hw, + STA_MMPDU_TXQ)) + return; + } else if (!ieee80211_hw_check(&sdata->local->hw, + BUFF_MMPDU_TXQ)) { + /* Drivers need to opt in to the bufferable MMPDU TXQ */ + return; + } + txqi->txq.ac = IEEE80211_AC_VO; + } else { + txqi->txq.ac = ieee80211_ac_from_tid(tid); + } + + txqi->txq.sta = &sta->sta; + txqi->txq.tid = tid; + sta->sta.txq[tid] = &txqi->txq; +} + +void ieee80211_txq_purge(struct ieee80211_local *local, + struct txq_info *txqi) +{ + struct fq *fq = &local->fq; + struct fq_tin *tin = &txqi->tin; + + spin_lock_bh(&fq->lock); + fq_tin_reset(fq, tin, fq_skb_free_func); + ieee80211_purge_tx_queue(&local->hw, &txqi->frags); + spin_unlock_bh(&fq->lock); + + spin_lock_bh(&local->active_txq_lock[txqi->txq.ac]); + list_del_init(&txqi->schedule_order); + spin_unlock_bh(&local->active_txq_lock[txqi->txq.ac]); +} + +void ieee80211_txq_set_params(struct ieee80211_local *local) +{ + if (local->hw.wiphy->txq_limit) + local->fq.limit = local->hw.wiphy->txq_limit; + else + local->hw.wiphy->txq_limit = local->fq.limit; + + if (local->hw.wiphy->txq_memory_limit) + local->fq.memory_limit = local->hw.wiphy->txq_memory_limit; + else + local->hw.wiphy->txq_memory_limit = local->fq.memory_limit; + + if (local->hw.wiphy->txq_quantum) + local->fq.quantum = local->hw.wiphy->txq_quantum; + else + local->hw.wiphy->txq_quantum = local->fq.quantum; +} + +int ieee80211_txq_setup_flows(struct ieee80211_local *local) +{ + struct fq *fq = &local->fq; + int ret; + int i; + bool supp_vht = false; + enum nl80211_band band; + + if (!local->ops->wake_tx_queue) + return 0; + + ret = fq_init(fq, 4096); + if (ret) + return ret; + + /* + * If the hardware doesn't support VHT, it is safe to limit the maximum + * queue size. 4 Mbytes is 64 max-size aggregates in 802.11n. 
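Note: ieee80211_txq_init above derives a data TXQ's AC from its TID via ieee80211_ac_from_tid(), the standard 802.1D user-priority to WMM access-category mapping. Spelled out with mac80211's AC numbering (VO=0, VI=1, BE=2, BK=3), as a reference table for the sketch:

/* index = TID (0..7), value = AC */
static const int tid_to_ac[8] = {
	2,	/* TID 0 -> BE */
	3,	/* TID 1 -> BK */
	3,	/* TID 2 -> BK */
	2,	/* TID 3 -> BE */
	1,	/* TID 4 -> VI */
	1,	/* TID 5 -> VI */
	0,	/* TID 6 -> VO */
	0,	/* TID 7 -> VO */
};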
+ */ + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband; + + sband = local->hw.wiphy->bands[band]; + if (!sband) + continue; + + supp_vht = supp_vht || sband->vht_cap.vht_supported; + } + + if (!supp_vht) + fq->memory_limit = 4 << 20; /* 4 Mbytes */ + + codel_params_init(&local->cparams); + local->cparams.interval = MS2TIME(100); + local->cparams.target = MS2TIME(20); + local->cparams.ecn = true; + + local->cvars = kcalloc(fq->flows_cnt, sizeof(local->cvars[0]), + GFP_KERNEL); + if (!local->cvars) { + spin_lock_bh(&fq->lock); + fq_reset(fq, fq_skb_free_func); + spin_unlock_bh(&fq->lock); + return -ENOMEM; + } + + for (i = 0; i < fq->flows_cnt; i++) + codel_vars_init(&local->cvars[i]); + + ieee80211_txq_set_params(local); + + return 0; +} + +void ieee80211_txq_teardown_flows(struct ieee80211_local *local) +{ + struct fq *fq = &local->fq; + + if (!local->ops->wake_tx_queue) + return; + + kfree(local->cvars); + local->cvars = NULL; + + spin_lock_bh(&fq->lock); + fq_reset(fq, fq_skb_free_func); + spin_unlock_bh(&fq->lock); +} + +static bool ieee80211_queue_skb(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct sk_buff *skb) +{ + struct ieee80211_vif *vif; + struct txq_info *txqi; + + if (!local->ops->wake_tx_queue || + sdata->vif.type == NL80211_IFTYPE_MONITOR) + return false; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(sdata->bss, + struct ieee80211_sub_if_data, u.ap); + + vif = &sdata->vif; + txqi = ieee80211_get_txq(local, vif, sta, skb); + + if (!txqi) + return false; + + ieee80211_txq_enqueue(local, txqi, skb); + + schedule_and_wake_txq(local, txqi); + + return true; +} + +static bool ieee80211_tx_frags(struct ieee80211_local *local, + struct ieee80211_vif *vif, + struct sta_info *sta, + struct sk_buff_head *skbs, + bool txpending) +{ + struct ieee80211_tx_control control = {}; + struct sk_buff *skb, *tmp; + unsigned long flags; + + skb_queue_walk_safe(skbs, skb, tmp) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int q = info->hw_queue; + +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG + if (WARN_ON_ONCE(q >= local->hw.queues)) { + __skb_unlink(skb, skbs); + ieee80211_free_txskb(&local->hw, skb); + continue; + } +#endif + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + if (local->queue_stop_reasons[q] || + (!txpending && !skb_queue_empty(&local->pending[q]))) { + if (unlikely(info->flags & + IEEE80211_TX_INTFL_OFFCHAN_TX_OK)) { + if (local->queue_stop_reasons[q] & + ~BIT(IEEE80211_QUEUE_STOP_REASON_OFFCHANNEL)) { + /* + * Drop off-channel frames if queues + * are stopped for any reason other + * than off-channel operation. Never + * queue them. + */ + spin_unlock_irqrestore( + &local->queue_stop_reason_lock, + flags); + ieee80211_purge_tx_queue(&local->hw, + skbs); + return true; + } + } else { + + /* + * Since queue is stopped, queue up frames for + * later transmission from the tx-pending + * tasklet when the queue is woken again. + */ + if (txpending) + skb_queue_splice_init(skbs, + &local->pending[q]); + else + skb_queue_splice_tail_init(skbs, + &local->pending[q]); + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, + flags); + return false; + } + } + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + info->control.vif = vif; + control.sta = sta ? 
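Note: ieee80211_txq_setup_flows above configures fq_codel-style queueing with a 20 ms target and a 100 ms interval (ECN enabled): a flow only starts seeing drops or marks once its queueing delay has stayed above the target for a full interval. A highly simplified model of that drop condition, not the actual CoDel control law (which also tightens the drop schedule over time):

#include <stdbool.h>
#include <stdint.h>

#define CODEL_TARGET_US		20000	/* MS2TIME(20) in the code above */
#define CODEL_INTERVAL_US	100000	/* MS2TIME(100) */

static bool codel_would_drop(uint64_t sojourn_us, uint64_t first_above_us,
			     uint64_t now_us)
{
	if (sojourn_us < CODEL_TARGET_US)
		return false;			/* below target: never drop */

	return now_us - first_above_us >= CODEL_INTERVAL_US;
}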
&sta->sta : NULL; + + __skb_unlink(skb, skbs); + drv_tx(local, &control, skb); + } + + return true; +} + +/* + * Returns false if the frame couldn't be transmitted but was queued instead. + */ +static bool __ieee80211_tx(struct ieee80211_local *local, + struct sk_buff_head *skbs, struct sta_info *sta, + bool txpending) +{ + struct ieee80211_tx_info *info; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_vif *vif; + struct sk_buff *skb; + bool result; + + if (WARN_ON(skb_queue_empty(skbs))) + return true; + + skb = skb_peek(skbs); + info = IEEE80211_SKB_CB(skb); + sdata = vif_to_sdata(info->control.vif); + if (sta && !sta->uploaded) + sta = NULL; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_MONITOR: + if (sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE) { + vif = &sdata->vif; + break; + } + sdata = rcu_dereference(local->monitor_sdata); + if (sdata) { + vif = &sdata->vif; + info->hw_queue = + vif->hw_queue[skb_get_queue_mapping(skb)]; + } else if (ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) { + ieee80211_purge_tx_queue(&local->hw, skbs); + return true; + } else + vif = NULL; + break; + case NL80211_IFTYPE_AP_VLAN: + sdata = container_of(sdata->bss, + struct ieee80211_sub_if_data, u.ap); + fallthrough; + default: + vif = &sdata->vif; + break; + } + + result = ieee80211_tx_frags(local, vif, sta, skbs, txpending); + + WARN_ON_ONCE(!skb_queue_empty(skbs)); + + return result; +} + +/* + * Invoke TX handlers, return 0 on success and non-zero if the + * frame was dropped or queued. + * + * The handlers are split into an early and late part. The latter is everything + * that can be sensitive to reordering, and will be deferred to after packets + * are dequeued from the intermediate queues (when they are enabled). + */ +static int invoke_tx_handlers_early(struct ieee80211_tx_data *tx) +{ + ieee80211_tx_result res = TX_DROP; + +#define CALL_TXH(txh) \ + do { \ + res = txh(tx); \ + if (res != TX_CONTINUE) \ + goto txh_done; \ + } while (0) + + CALL_TXH(ieee80211_tx_h_dynamic_ps); + CALL_TXH(ieee80211_tx_h_check_assoc); + CALL_TXH(ieee80211_tx_h_ps_buf); + CALL_TXH(ieee80211_tx_h_check_control_port_protocol); + CALL_TXH(ieee80211_tx_h_select_key); + + txh_done: + if (unlikely(res == TX_DROP)) { + I802_DEBUG_INC(tx->local->tx_handlers_drop); + if (tx->skb) + ieee80211_free_txskb(&tx->local->hw, tx->skb); + else + ieee80211_purge_tx_queue(&tx->local->hw, &tx->skbs); + return -1; + } else if (unlikely(res == TX_QUEUED)) { + I802_DEBUG_INC(tx->local->tx_handlers_queued); + return -1; + } + + return 0; +} + +/* + * Late handlers can be called while the sta lock is held. Handlers that can + * cause packets to be generated will cause deadlock! + */ +static int invoke_tx_handlers_late(struct ieee80211_tx_data *tx) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(tx->skb); + ieee80211_tx_result res = TX_CONTINUE; + + if (!ieee80211_hw_check(&tx->local->hw, HAS_RATE_CONTROL)) + CALL_TXH(ieee80211_tx_h_rate_ctrl); + + if (unlikely(info->flags & IEEE80211_TX_INTFL_RETRANSMISSION)) { + __skb_queue_tail(&tx->skbs, tx->skb); + tx->skb = NULL; + goto txh_done; + } + + CALL_TXH(ieee80211_tx_h_michael_mic_add); + CALL_TXH(ieee80211_tx_h_sequence); + CALL_TXH(ieee80211_tx_h_fragment); + /* handlers after fragment must be aware of tx info fragmentation! 
*/ + CALL_TXH(ieee80211_tx_h_stats); + CALL_TXH(ieee80211_tx_h_encrypt); + if (!ieee80211_hw_check(&tx->local->hw, HAS_RATE_CONTROL)) + CALL_TXH(ieee80211_tx_h_calculate_duration); +#undef CALL_TXH + + txh_done: + if (unlikely(res == TX_DROP)) { + I802_DEBUG_INC(tx->local->tx_handlers_drop); + if (tx->skb) + ieee80211_free_txskb(&tx->local->hw, tx->skb); + else + ieee80211_purge_tx_queue(&tx->local->hw, &tx->skbs); + return -1; + } else if (unlikely(res == TX_QUEUED)) { + I802_DEBUG_INC(tx->local->tx_handlers_queued); + return -1; + } + + return 0; +} + +static int invoke_tx_handlers(struct ieee80211_tx_data *tx) +{ + int r = invoke_tx_handlers_early(tx); + + if (r) + return r; + return invoke_tx_handlers_late(tx); +} + +bool ieee80211_tx_prepare_skb(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, struct sk_buff *skb, + int band, struct ieee80211_sta **sta) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_tx_data tx; + struct sk_buff *skb2; + + if (ieee80211_tx_prepare(sdata, &tx, NULL, skb) == TX_DROP) + return false; + + info->band = band; + info->control.vif = vif; + info->hw_queue = vif->hw_queue[skb_get_queue_mapping(skb)]; + + if (invoke_tx_handlers(&tx)) + return false; + + if (sta) { + if (tx.sta) + *sta = &tx.sta->sta; + else + *sta = NULL; + } + + /* this function isn't suitable for fragmented data frames */ + skb2 = __skb_dequeue(&tx.skbs); + if (WARN_ON(skb2 != skb || !skb_queue_empty(&tx.skbs))) { + ieee80211_free_txskb(hw, skb2); + ieee80211_purge_tx_queue(hw, &tx.skbs); + return false; + } + + return true; +} +EXPORT_SYMBOL(ieee80211_tx_prepare_skb); + +/* + * Returns false if the frame couldn't be transmitted but was queued instead. + */ +static bool ieee80211_tx(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb, + bool txpending) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_data tx; + ieee80211_tx_result res_prepare; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + bool result = true; + + if (unlikely(skb->len < 10)) { + dev_kfree_skb(skb); + return true; + } + + /* initialises tx */ + res_prepare = ieee80211_tx_prepare(sdata, &tx, sta, skb); + + if (unlikely(res_prepare == TX_DROP)) { + ieee80211_free_txskb(&local->hw, skb); + return true; + } else if (unlikely(res_prepare == TX_QUEUED)) { + return true; + } + + /* set up hw_queue value early */ + if (!(info->flags & IEEE80211_TX_CTL_TX_OFFCHAN) || + !ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) + info->hw_queue = + sdata->vif.hw_queue[skb_get_queue_mapping(skb)]; + + if (invoke_tx_handlers_early(&tx)) + return true; + + if (ieee80211_queue_skb(local, sdata, tx.sta, tx.skb)) + return true; + + if (!invoke_tx_handlers_late(&tx)) + result = __ieee80211_tx(local, &tx.skbs, tx.sta, txpending); + + return result; +} + +/* device xmit handlers */ + +enum ieee80211_encrypt { + ENCRYPT_NO, + ENCRYPT_MGMT, + ENCRYPT_DATA, +}; + +static int ieee80211_skb_resize(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + int head_need, + enum ieee80211_encrypt encrypt) +{ + struct ieee80211_local *local = sdata->local; + bool enc_tailroom; + int tail_need = 0; + + enc_tailroom = encrypt == ENCRYPT_MGMT || + (encrypt == ENCRYPT_DATA && + sdata->crypto_tx_tailroom_needed_cnt); + + if (enc_tailroom) { + tail_need = IEEE80211_ENCRYPT_TAILROOM; + tail_need -= skb_tailroom(skb); + tail_need = max_t(int, tail_need, 0); + } + + if (skb_cloned(skb) && + 
(!ieee80211_hw_check(&local->hw, SUPPORTS_CLONED_SKBS) || + !skb_clone_writable(skb, ETH_HLEN) || enc_tailroom)) + I802_DEBUG_INC(local->tx_expand_skb_head_cloned); + else if (head_need || tail_need) + I802_DEBUG_INC(local->tx_expand_skb_head); + else + return 0; + + if (pskb_expand_head(skb, head_need, tail_need, GFP_ATOMIC)) { + wiphy_debug(local->hw.wiphy, + "failed to reallocate TX buffer\n"); + return -ENOMEM; + } + + return 0; +} + +void ieee80211_xmit(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + int headroom; + enum ieee80211_encrypt encrypt; + + if (info->flags & IEEE80211_TX_INTFL_DONT_ENCRYPT) + encrypt = ENCRYPT_NO; + else if (ieee80211_is_mgmt(hdr->frame_control)) + encrypt = ENCRYPT_MGMT; + else + encrypt = ENCRYPT_DATA; + + headroom = local->tx_headroom; + if (encrypt != ENCRYPT_NO) + headroom += IEEE80211_ENCRYPT_HEADROOM; + headroom -= skb_headroom(skb); + headroom = max_t(int, 0, headroom); + + if (ieee80211_skb_resize(sdata, skb, headroom, encrypt)) { + ieee80211_free_txskb(&local->hw, skb); + return; + } + + /* reload after potential resize */ + hdr = (struct ieee80211_hdr *) skb->data; + info->control.vif = &sdata->vif; + + if (ieee80211_vif_is_mesh(&sdata->vif)) { + if (ieee80211_is_data(hdr->frame_control) && + is_unicast_ether_addr(hdr->addr1)) { + if (mesh_nexthop_resolve(sdata, skb)) + return; /* skb queued: don't free */ + } else { + ieee80211_mps_set_frame_flags(sdata, NULL, hdr); + } + } + + ieee80211_set_qos_hdr(sdata, skb); + ieee80211_tx(sdata, sta, skb, false); +} + +static bool ieee80211_validate_radiotap_len(struct sk_buff *skb) +{ + struct ieee80211_radiotap_header *rthdr = + (struct ieee80211_radiotap_header *)skb->data; + + /* check for not even having the fixed radiotap header part */ + if (unlikely(skb->len < sizeof(struct ieee80211_radiotap_header))) + return false; /* too short to be possibly valid */ + + /* is it a header version we can trust to find length from? */ + if (unlikely(rthdr->it_version)) + return false; /* only version 0 is supported */ + + /* does the skb contain enough to deliver on the alleged length? 
*/ + if (unlikely(skb->len < ieee80211_get_radiotap_len(skb->data))) + return false; /* skb too short for claimed rt header extent */ + + return true; +} + +bool ieee80211_parse_tx_radiotap(struct sk_buff *skb, + struct net_device *dev) +{ + struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr); + struct ieee80211_radiotap_iterator iterator; + struct ieee80211_radiotap_header *rthdr = + (struct ieee80211_radiotap_header *) skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int ret = ieee80211_radiotap_iterator_init(&iterator, rthdr, skb->len, + NULL); + u16 txflags; + u16 rate = 0; + bool rate_found = false; + u8 rate_retries = 0; + u16 rate_flags = 0; + u8 mcs_known, mcs_flags, mcs_bw; + u16 vht_known; + u8 vht_mcs = 0, vht_nss = 0; + int i; + + if (!ieee80211_validate_radiotap_len(skb)) + return false; + + info->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + IEEE80211_TX_CTL_DONTFRAG; + + /* + * for every radiotap entry that is present + * (ieee80211_radiotap_iterator_next returns -ENOENT when no more + * entries present, or -EINVAL on error) + */ + + while (!ret) { + ret = ieee80211_radiotap_iterator_next(&iterator); + + if (ret) + continue; + + /* see if this argument is something we can use */ + switch (iterator.this_arg_index) { + /* + * You must take care when dereferencing iterator.this_arg + * for multibyte types... the pointer is not aligned. Use + * get_unaligned((type *)iterator.this_arg) to dereference + * iterator.this_arg for type "type" safely on all arches. + */ + case IEEE80211_RADIOTAP_FLAGS: + if (*iterator.this_arg & IEEE80211_RADIOTAP_F_FCS) { + /* + * this indicates that the skb we have been + * handed has the 32-bit FCS CRC at the end... + * we should react to that by snipping it off + * because it will be recomputed and added + * on transmission + */ + if (skb->len < (iterator._max_length + FCS_LEN)) + return false; + + skb_trim(skb, skb->len - FCS_LEN); + } + if (*iterator.this_arg & IEEE80211_RADIOTAP_F_WEP) + info->flags &= ~IEEE80211_TX_INTFL_DONT_ENCRYPT; + if (*iterator.this_arg & IEEE80211_RADIOTAP_F_FRAG) + info->flags &= ~IEEE80211_TX_CTL_DONTFRAG; + break; + + case IEEE80211_RADIOTAP_TX_FLAGS: + txflags = get_unaligned_le16(iterator.this_arg); + if (txflags & IEEE80211_RADIOTAP_F_TX_NOACK) + info->flags |= IEEE80211_TX_CTL_NO_ACK; + if (txflags & IEEE80211_RADIOTAP_F_TX_NOSEQNO) + info->control.flags |= IEEE80211_TX_CTRL_NO_SEQNO; + if (txflags & IEEE80211_RADIOTAP_F_TX_ORDER) + info->control.flags |= + IEEE80211_TX_CTRL_DONT_REORDER; + break; + + case IEEE80211_RADIOTAP_RATE: + rate = *iterator.this_arg; + rate_flags = 0; + rate_found = true; + break; + + case IEEE80211_RADIOTAP_DATA_RETRIES: + rate_retries = *iterator.this_arg; + break; + + case IEEE80211_RADIOTAP_MCS: + mcs_known = iterator.this_arg[0]; + mcs_flags = iterator.this_arg[1]; + if (!(mcs_known & IEEE80211_RADIOTAP_MCS_HAVE_MCS)) + break; + + rate_found = true; + rate = iterator.this_arg[2]; + rate_flags = IEEE80211_TX_RC_MCS; + + if (mcs_known & IEEE80211_RADIOTAP_MCS_HAVE_GI && + mcs_flags & IEEE80211_RADIOTAP_MCS_SGI) + rate_flags |= IEEE80211_TX_RC_SHORT_GI; + + mcs_bw = mcs_flags & IEEE80211_RADIOTAP_MCS_BW_MASK; + if (mcs_known & IEEE80211_RADIOTAP_MCS_HAVE_BW && + mcs_bw == IEEE80211_RADIOTAP_MCS_BW_40) + rate_flags |= IEEE80211_TX_RC_40_MHZ_WIDTH; + + if (mcs_known & IEEE80211_RADIOTAP_MCS_HAVE_FEC && + mcs_flags & IEEE80211_RADIOTAP_MCS_FEC_LDPC) + info->flags |= IEEE80211_TX_CTL_LDPC; + + if (mcs_known & IEEE80211_RADIOTAP_MCS_HAVE_STBC) { + u8 stbc = 
u8_get_bits(mcs_flags, + IEEE80211_RADIOTAP_MCS_STBC_MASK); + + info->flags |= + u32_encode_bits(stbc, + IEEE80211_TX_CTL_STBC); + } + break; + + case IEEE80211_RADIOTAP_VHT: + vht_known = get_unaligned_le16(iterator.this_arg); + rate_found = true; + + rate_flags = IEEE80211_TX_RC_VHT_MCS; + if ((vht_known & IEEE80211_RADIOTAP_VHT_KNOWN_GI) && + (iterator.this_arg[2] & + IEEE80211_RADIOTAP_VHT_FLAG_SGI)) + rate_flags |= IEEE80211_TX_RC_SHORT_GI; + if (vht_known & + IEEE80211_RADIOTAP_VHT_KNOWN_BANDWIDTH) { + if (iterator.this_arg[3] == 1) + rate_flags |= + IEEE80211_TX_RC_40_MHZ_WIDTH; + else if (iterator.this_arg[3] == 4) + rate_flags |= + IEEE80211_TX_RC_80_MHZ_WIDTH; + else if (iterator.this_arg[3] == 11) + rate_flags |= + IEEE80211_TX_RC_160_MHZ_WIDTH; + } + + vht_mcs = iterator.this_arg[4] >> 4; + if (vht_mcs > 11) + vht_mcs = 0; + vht_nss = iterator.this_arg[4] & 0xF; + if (!vht_nss || vht_nss > 8) + vht_nss = 1; + break; + + /* + * Please update the file + * Documentation/networking/mac80211-injection.rst + * when parsing new fields here. + */ + + default: + break; + } + } + + if (ret != -ENOENT) /* ie, if we didn't simply run out of fields */ + return false; + + if (rate_found) { + struct ieee80211_supported_band *sband = + local->hw.wiphy->bands[info->band]; + + info->control.flags |= IEEE80211_TX_CTRL_RATE_INJECT; + + for (i = 0; i < IEEE80211_TX_MAX_RATES; i++) { + info->control.rates[i].idx = -1; + info->control.rates[i].flags = 0; + info->control.rates[i].count = 0; + } + + if (rate_flags & IEEE80211_TX_RC_MCS) { + info->control.rates[0].idx = rate; + } else if (rate_flags & IEEE80211_TX_RC_VHT_MCS) { + ieee80211_rate_set_vht(info->control.rates, vht_mcs, + vht_nss); + } else if (sband) { + for (i = 0; i < sband->n_bitrates; i++) { + if (rate * 5 != sband->bitrates[i].bitrate) + continue; + + info->control.rates[0].idx = i; + break; + } + } + + if (info->control.rates[0].idx < 0) + info->control.flags &= ~IEEE80211_TX_CTRL_RATE_INJECT; + + info->control.rates[0].flags = rate_flags; + info->control.rates[0].count = min_t(u8, rate_retries + 1, + local->hw.max_rate_tries); + } + + return true; +} + +netdev_tx_t ieee80211_monitor_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ieee80211_local *local = wdev_priv(dev->ieee80211_ptr); + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr; + struct ieee80211_sub_if_data *tmp_sdata, *sdata; + struct cfg80211_chan_def *chandef; + u16 len_rthdr; + int hdrlen; + + sdata = IEEE80211_DEV_TO_SUB_IF(dev); + if (unlikely(!ieee80211_sdata_running(sdata))) + goto fail; + + memset(info, 0, sizeof(*info)); + info->flags = IEEE80211_TX_CTL_REQ_TX_STATUS | + IEEE80211_TX_CTL_INJECTED; + + /* Sanity-check the length of the radiotap header */ + if (!ieee80211_validate_radiotap_len(skb)) + goto fail; + + /* we now know there is a radiotap header with a length we can use */ + len_rthdr = ieee80211_get_radiotap_len(skb->data); + + /* + * fix up the pointers accounting for the radiotap + * header still being in there. 
We are being given + * a precooked IEEE80211 header so no need for + * normal processing + */ + skb_set_mac_header(skb, len_rthdr); + /* + * these are just fixed to the end of the rt area since we + * don't have any better information and at this point, nobody cares + */ + skb_set_network_header(skb, len_rthdr); + skb_set_transport_header(skb, len_rthdr); + + if (skb->len < len_rthdr + 2) + goto fail; + + hdr = (struct ieee80211_hdr *)(skb->data + len_rthdr); + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + if (skb->len < len_rthdr + hdrlen) + goto fail; + + /* + * Initialize skb->protocol if the injected frame is a data frame + * carrying a rfc1042 header + */ + if (ieee80211_is_data(hdr->frame_control) && + skb->len >= len_rthdr + hdrlen + sizeof(rfc1042_header) + 2) { + u8 *payload = (u8 *)hdr + hdrlen; + + if (ether_addr_equal(payload, rfc1042_header)) + skb->protocol = cpu_to_be16((payload[6] << 8) | + payload[7]); + } + + rcu_read_lock(); + + /* + * We process outgoing injected frames that have a local address + * we handle as though they are non-injected frames. + * This code here isn't entirely correct, the local MAC address + * isn't always enough to find the interface to use; for proper + * VLAN support we have an nl80211-based mechanism. + * + * This is necessary, for example, for old hostapd versions that + * don't use nl80211-based management TX/RX. + */ + list_for_each_entry_rcu(tmp_sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(tmp_sdata)) + continue; + if (tmp_sdata->vif.type == NL80211_IFTYPE_MONITOR || + tmp_sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + continue; + if (ether_addr_equal(tmp_sdata->vif.addr, hdr->addr2)) { + sdata = tmp_sdata; + break; + } + } + + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (!chanctx_conf) { + tmp_sdata = rcu_dereference(local->monitor_sdata); + if (tmp_sdata) + chanctx_conf = + rcu_dereference(tmp_sdata->vif.bss_conf.chanctx_conf); + } + + if (chanctx_conf) + chandef = &chanctx_conf->def; + else if (!local->use_chanctx) + chandef = &local->_oper_chandef; + else + goto fail_rcu; + + /* + * Frame injection is not allowed if beaconing is not allowed + * or if we need radar detection. Beaconing is usually not allowed when + * the mode or operation (Adhoc, AP, Mesh) does not support DFS. + * Passive scan is also used in world regulatory domains where + * your country is not known and as such it should be treated as + * NO TX unless the channel is explicitly allowed in which case + * your current regulatory domain would not have the passive scan + * flag. + * + * Since AP mode uses monitor interfaces to inject/TX management + * frames we can make AP mode the exception to this rule once it + * supports radar detection as its implementation can deal with + * radar detection by itself. We can do that later by adding a + * monitor flag interfaces used for AP support. + */ + if (!cfg80211_reg_can_beacon(local->hw.wiphy, chandef, + sdata->vif.type)) + goto fail_rcu; + + info->band = chandef->chan->band; + + /* Initialize skb->priority according to frame type and TID class, + * with respect to the sub interface that the frame will actually + * be transmitted on. If the DONT_REORDER flag is set, the original + * skb-priority is preserved to assure frames injected with this + * flag are not reordered relative to each other. + */ + ieee80211_select_queue_80211(sdata, skb, hdr); + skb_set_queue_mapping(skb, ieee80211_ac_from_tid(skb->priority)); + + /* + * Process the radiotap header. 
This will now take into account the + * selected chandef above to accurately set injection rates and + * retransmissions. + */ + if (!ieee80211_parse_tx_radiotap(skb, dev)) + goto fail_rcu; + + /* remove the injection radiotap header */ + skb_pull(skb, len_rthdr); + + ieee80211_xmit(sdata, NULL, skb); + rcu_read_unlock(); + + return NETDEV_TX_OK; + +fail_rcu: + rcu_read_unlock(); +fail: + dev_kfree_skb(skb); + return NETDEV_TX_OK; /* meaning, we dealt with the skb */ +} + +static inline bool ieee80211_is_tdls_setup(struct sk_buff *skb) +{ + u16 ethertype = (skb->data[12] << 8) | skb->data[13]; + + return ethertype == ETH_P_TDLS && + skb->len > 14 && + skb->data[14] == WLAN_TDLS_SNAP_RFTYPE; +} + +int ieee80211_lookup_ra_sta(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct sta_info **sta_out) +{ + struct sta_info *sta; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + sta = rcu_dereference(sdata->u.vlan.sta); + if (sta) { + *sta_out = sta; + return 0; + } else if (sdata->wdev.use_4addr) { + return -ENOLINK; + } + fallthrough; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_ADHOC: + if (is_multicast_ether_addr(skb->data)) { + *sta_out = ERR_PTR(-ENOENT); + return 0; + } + sta = sta_info_get_bss(sdata, skb->data); + break; +#ifdef CONFIG_MAC80211_MESH + case NL80211_IFTYPE_MESH_POINT: + /* determined much later */ + *sta_out = NULL; + return 0; +#endif + case NL80211_IFTYPE_STATION: + if (sdata->wdev.wiphy->flags & WIPHY_FLAG_SUPPORTS_TDLS) { + sta = sta_info_get(sdata, skb->data); + if (sta && test_sta_flag(sta, WLAN_STA_TDLS_PEER)) { + if (test_sta_flag(sta, + WLAN_STA_TDLS_PEER_AUTH)) { + *sta_out = sta; + return 0; + } + + /* + * TDLS link during setup - throw out frames to + * peer. Allow TDLS-setup frames to unauthorized + * peers for the special case of a link teardown + * after a TDLS sta is removed due to being + * unreachable. + */ + if (!ieee80211_is_tdls_setup(skb)) + return -EINVAL; + } + + } + + sta = sta_info_get(sdata, sdata->vif.cfg.ap_addr); + if (!sta) + return -ENOLINK; + break; + default: + return -EINVAL; + } + + *sta_out = sta ?: ERR_PTR(-ENOENT); + return 0; +} + +static u16 ieee80211_store_ack_skb(struct ieee80211_local *local, + struct sk_buff *skb, + u32 *info_flags, + u64 *cookie) +{ + struct sk_buff *ack_skb; + u16 info_id = 0; + + if (skb->sk) + ack_skb = skb_clone_sk(skb); + else + ack_skb = skb_clone(skb, GFP_ATOMIC); + + if (ack_skb) { + unsigned long flags; + int id; + + spin_lock_irqsave(&local->ack_status_lock, flags); + id = idr_alloc(&local->ack_status_frames, ack_skb, + 1, 0x2000, GFP_ATOMIC); + spin_unlock_irqrestore(&local->ack_status_lock, flags); + + if (id >= 0) { + info_id = id; + *info_flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; + if (cookie) { + *cookie = ieee80211_mgmt_tx_cookie(local); + IEEE80211_SKB_CB(ack_skb)->ack.cookie = *cookie; + } + } else { + kfree_skb(ack_skb); + } + } + + return info_id; +} + +/** + * ieee80211_build_hdr - build 802.11 header in the given frame + * @sdata: virtual interface to build the header for + * @skb: the skb to build the header in + * @info_flags: skb flags to set + * @sta: the station pointer + * @ctrl_flags: info control flags to set + * @cookie: cookie pointer to fill (if not %NULL) + * + * This function takes the skb with 802.3 header and reformats the header to + * the appropriate IEEE 802.11 header based on which interface the packet is + * being transmitted on. 
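+ * For example, on an AP interface the Ethernet DA/SA pair becomes a
+ * three-address 802.11 header with FromDS set (addr1 = DA, addr2 = BSSID,
+ * addr3 = SA).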
+ * + * Note that this function also takes care of the TX status request and + * potential unsharing of the SKB - this needs to be interleaved with the + * header building. + * + * The function requires the read-side RCU lock held + * + * Returns: the (possibly reallocated) skb or an ERR_PTR() code + */ +static struct sk_buff *ieee80211_build_hdr(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u32 info_flags, + struct sta_info *sta, u32 ctrl_flags, + u64 *cookie) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_info *info; + int head_need; + u16 ethertype, hdrlen, meshhdrlen = 0; + __le16 fc; + struct ieee80211_hdr hdr; + struct ieee80211s_hdr mesh_hdr __maybe_unused; + struct mesh_path __maybe_unused *mppath = NULL, *mpath = NULL; + const u8 *encaps_data; + int encaps_len, skip_header_bytes; + bool wme_sta = false, authorized = false; + bool tdls_peer; + bool multicast; + u16 info_id = 0; + struct ieee80211_chanctx_conf *chanctx_conf = NULL; + enum nl80211_band band; + int ret; + u8 link_id = u32_get_bits(ctrl_flags, IEEE80211_TX_CTRL_MLO_LINK); + + if (IS_ERR(sta)) + sta = NULL; + +#ifdef CONFIG_MAC80211_DEBUGFS + if (local->force_tx_status) + info_flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; +#endif + + /* convert Ethernet header to proper 802.11 header (based on + * operation mode) */ + ethertype = (skb->data[12] << 8) | skb->data[13]; + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | IEEE80211_STYPE_DATA); + + if (!sdata->vif.valid_links) + chanctx_conf = + rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + if (sdata->wdev.use_4addr) { + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS | IEEE80211_FCTL_TODS); + /* RA TA DA SA */ + memcpy(hdr.addr1, sta->sta.addr, ETH_ALEN); + memcpy(hdr.addr2, sdata->vif.addr, ETH_ALEN); + memcpy(hdr.addr3, skb->data, ETH_ALEN); + memcpy(hdr.addr4, skb->data + ETH_ALEN, ETH_ALEN); + hdrlen = 30; + authorized = test_sta_flag(sta, WLAN_STA_AUTHORIZED); + wme_sta = sta->sta.wme; + } + if (!sdata->vif.valid_links) { + struct ieee80211_sub_if_data *ap_sdata; + + /* override chanctx_conf from AP (we don't have one) */ + ap_sdata = container_of(sdata->bss, + struct ieee80211_sub_if_data, + u.ap); + chanctx_conf = + rcu_dereference(ap_sdata->vif.bss_conf.chanctx_conf); + } + if (sdata->wdev.use_4addr) + break; + fallthrough; + case NL80211_IFTYPE_AP: + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS); + /* DA BSSID SA */ + memcpy(hdr.addr1, skb->data, ETH_ALEN); + + if (sdata->vif.valid_links && sta && !sta->sta.mlo) { + struct ieee80211_link_data *link; + + link_id = sta->deflink.link_id; + link = rcu_dereference(sdata->link[link_id]); + if (WARN_ON(!link)) { + ret = -ENOLINK; + goto free; + } + memcpy(hdr.addr2, link->conf->addr, ETH_ALEN); + } else if (link_id == IEEE80211_LINK_UNSPECIFIED || + (sta && sta->sta.mlo)) { + memcpy(hdr.addr2, sdata->vif.addr, ETH_ALEN); + } else { + struct ieee80211_bss_conf *conf; + + conf = rcu_dereference(sdata->vif.link_conf[link_id]); + if (unlikely(!conf)) { + ret = -ENOLINK; + goto free; + } + + memcpy(hdr.addr2, conf->addr, ETH_ALEN); + } + + memcpy(hdr.addr3, skb->data + ETH_ALEN, ETH_ALEN); + hdrlen = 24; + break; +#ifdef CONFIG_MAC80211_MESH + case NL80211_IFTYPE_MESH_POINT: + if (!is_multicast_ether_addr(skb->data)) { + struct sta_info *next_hop; + bool mpp_lookup = true; + + mpath = mesh_path_lookup(sdata, skb->data); + if (mpath) { + mpp_lookup = false; + next_hop = rcu_dereference(mpath->next_hop); + if (!next_hop || + !(mpath->flags & 
(MESH_PATH_ACTIVE | + MESH_PATH_RESOLVING))) + mpp_lookup = true; + } + + if (mpp_lookup) { + mppath = mpp_path_lookup(sdata, skb->data); + if (mppath) + mppath->exp_time = jiffies; + } + + if (mppath && mpath) + mesh_path_del(sdata, mpath->dst); + } + + /* + * Use address extension if it is a packet from + * another interface or if we know the destination + * is being proxied by a portal (i.e. portal address + * differs from proxied address) + */ + if (ether_addr_equal(sdata->vif.addr, skb->data + ETH_ALEN) && + !(mppath && !ether_addr_equal(mppath->mpp, skb->data))) { + hdrlen = ieee80211_fill_mesh_addresses(&hdr, &fc, + skb->data, skb->data + ETH_ALEN); + meshhdrlen = ieee80211_new_mesh_header(sdata, &mesh_hdr, + NULL, NULL); + } else { + /* DS -> MBSS (802.11-2012 13.11.3.3). + * For unicast with unknown forwarding information, + * destination might be in the MBSS or if that fails + * forwarded to another mesh gate. In either case + * resolution will be handled in ieee80211_xmit(), so + * leave the original DA. This also works for mcast */ + const u8 *mesh_da = skb->data; + + if (mppath) + mesh_da = mppath->mpp; + else if (mpath) + mesh_da = mpath->dst; + + hdrlen = ieee80211_fill_mesh_addresses(&hdr, &fc, + mesh_da, sdata->vif.addr); + if (is_multicast_ether_addr(mesh_da)) + /* DA TA mSA AE:SA */ + meshhdrlen = ieee80211_new_mesh_header( + sdata, &mesh_hdr, + skb->data + ETH_ALEN, NULL); + else + /* RA TA mDA mSA AE:DA SA */ + meshhdrlen = ieee80211_new_mesh_header( + sdata, &mesh_hdr, skb->data, + skb->data + ETH_ALEN); + + } + + /* For injected frames, fill RA right away as nexthop lookup + * will be skipped. + */ + if ((ctrl_flags & IEEE80211_TX_CTRL_SKIP_MPATH_LOOKUP) && + is_zero_ether_addr(hdr.addr1)) + memcpy(hdr.addr1, skb->data, ETH_ALEN); + break; +#endif + case NL80211_IFTYPE_STATION: + /* we already did checks when looking up the RA STA */ + tdls_peer = test_sta_flag(sta, WLAN_STA_TDLS_PEER); + + if (tdls_peer) { + /* DA SA BSSID */ + memcpy(hdr.addr1, skb->data, ETH_ALEN); + memcpy(hdr.addr2, skb->data + ETH_ALEN, ETH_ALEN); + memcpy(hdr.addr3, sdata->deflink.u.mgd.bssid, ETH_ALEN); + hdrlen = 24; + } else if (sdata->u.mgd.use_4addr && + cpu_to_be16(ethertype) != sdata->control_port_protocol) { + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS | + IEEE80211_FCTL_TODS); + /* RA TA DA SA */ + memcpy(hdr.addr1, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(hdr.addr2, sdata->vif.addr, ETH_ALEN); + memcpy(hdr.addr3, skb->data, ETH_ALEN); + memcpy(hdr.addr4, skb->data + ETH_ALEN, ETH_ALEN); + hdrlen = 30; + } else { + fc |= cpu_to_le16(IEEE80211_FCTL_TODS); + /* BSSID SA DA */ + memcpy(hdr.addr1, sdata->vif.cfg.ap_addr, ETH_ALEN); + memcpy(hdr.addr2, skb->data + ETH_ALEN, ETH_ALEN); + memcpy(hdr.addr3, skb->data, ETH_ALEN); + hdrlen = 24; + } + break; + case NL80211_IFTYPE_OCB: + /* DA SA BSSID */ + memcpy(hdr.addr1, skb->data, ETH_ALEN); + memcpy(hdr.addr2, skb->data + ETH_ALEN, ETH_ALEN); + eth_broadcast_addr(hdr.addr3); + hdrlen = 24; + break; + case NL80211_IFTYPE_ADHOC: + /* DA SA BSSID */ + memcpy(hdr.addr1, skb->data, ETH_ALEN); + memcpy(hdr.addr2, skb->data + ETH_ALEN, ETH_ALEN); + memcpy(hdr.addr3, sdata->u.ibss.bssid, ETH_ALEN); + hdrlen = 24; + break; + default: + ret = -EINVAL; + goto free; + } + + if (!chanctx_conf) { + if (!sdata->vif.valid_links) { + ret = -ENOTCONN; + goto free; + } + /* MLD transmissions must not rely on the band */ + band = 0; + } else { + band = chanctx_conf->def.chan->band; + } + + multicast = is_multicast_ether_addr(hdr.addr1); + + /* sta is 
always NULL for mesh */ + if (sta) { + authorized = test_sta_flag(sta, WLAN_STA_AUTHORIZED); + wme_sta = sta->sta.wme; + } else if (ieee80211_vif_is_mesh(&sdata->vif)) { + /* For mesh, the use of the QoS header is mandatory */ + wme_sta = true; + } + + /* receiver does QoS (which also means we do) use it */ + if (wme_sta) { + fc |= cpu_to_le16(IEEE80211_STYPE_QOS_DATA); + hdrlen += 2; + } + + /* + * Drop unicast frames to unauthorised stations unless they are + * EAPOL frames from the local station. + */ + if (unlikely(!ieee80211_vif_is_mesh(&sdata->vif) && + (sdata->vif.type != NL80211_IFTYPE_OCB) && + !multicast && !authorized && + (cpu_to_be16(ethertype) != sdata->control_port_protocol || + !ieee80211_is_our_addr(sdata, skb->data + ETH_ALEN, NULL)))) { +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG + net_info_ratelimited("%s: dropped frame to %pM (unauthorized port)\n", + sdata->name, hdr.addr1); +#endif + + I802_DEBUG_INC(local->tx_handlers_drop_unauth_port); + + ret = -EPERM; + goto free; + } + + if (unlikely(!multicast && ((skb->sk && + skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS) || + ctrl_flags & IEEE80211_TX_CTL_REQ_TX_STATUS))) + info_id = ieee80211_store_ack_skb(local, skb, &info_flags, + cookie); + + /* + * If the skb is shared we need to obtain our own copy. + */ + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) { + ret = -ENOMEM; + goto free; + } + + hdr.frame_control = fc; + hdr.duration_id = 0; + hdr.seq_ctrl = 0; + + skip_header_bytes = ETH_HLEN; + if (ethertype == ETH_P_AARP || ethertype == ETH_P_IPX) { + encaps_data = bridge_tunnel_header; + encaps_len = sizeof(bridge_tunnel_header); + skip_header_bytes -= 2; + } else if (ethertype >= ETH_P_802_3_MIN) { + encaps_data = rfc1042_header; + encaps_len = sizeof(rfc1042_header); + skip_header_bytes -= 2; + } else { + encaps_data = NULL; + encaps_len = 0; + } + + skb_pull(skb, skip_header_bytes); + head_need = hdrlen + encaps_len + meshhdrlen - skb_headroom(skb); + + /* + * So we need to modify the skb header and hence need a copy of + * that. The head_need variable above doesn't, so far, include + * the needed header space that we don't need right away. If we + * can, then we don't reallocate right now but only after the + * frame arrives at the master device (if it does...) + * + * If we cannot, however, then we will reallocate to include all + * the ever needed space. Also, if we need to reallocate it anyway, + * make it big enough for everything we may ever need. + */ + + if (head_need > 0 || skb_cloned(skb)) { + head_need += IEEE80211_ENCRYPT_HEADROOM; + head_need += local->tx_headroom; + head_need = max_t(int, 0, head_need); + if (ieee80211_skb_resize(sdata, skb, head_need, ENCRYPT_DATA)) { + ieee80211_free_txskb(&local->hw, skb); + skb = NULL; + return ERR_PTR(-ENOMEM); + } + } + + if (encaps_data) + memcpy(skb_push(skb, encaps_len), encaps_data, encaps_len); + +#ifdef CONFIG_MAC80211_MESH + if (meshhdrlen > 0) + memcpy(skb_push(skb, meshhdrlen), &mesh_hdr, meshhdrlen); +#endif + + if (ieee80211_is_data_qos(fc)) { + __le16 *qos_control; + + qos_control = skb_push(skb, 2); + memcpy(skb_push(skb, hdrlen - 2), &hdr, hdrlen - 2); + /* + * Maybe we could actually set some fields here, for now just + * initialise to zero to indicate no special operation. 
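+ * (The TID and ack policy are filled in later, e.g. by
+ * ieee80211_set_qos_hdr() on the normal data-xmit path.)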
+ */ + *qos_control = 0; + } else + memcpy(skb_push(skb, hdrlen), &hdr, hdrlen); + + skb_reset_mac_header(skb); + + info = IEEE80211_SKB_CB(skb); + memset(info, 0, sizeof(*info)); + + info->flags = info_flags; + info->ack_frame_id = info_id; + info->band = band; + + if (likely(!cookie)) { + ctrl_flags |= u32_encode_bits(link_id, + IEEE80211_TX_CTRL_MLO_LINK); + } else { + unsigned int pre_conf_link_id; + + /* + * ctrl_flags already have been set by + * ieee80211_tx_control_port(), here + * we just sanity check that + */ + + pre_conf_link_id = u32_get_bits(ctrl_flags, + IEEE80211_TX_CTRL_MLO_LINK); + + if (pre_conf_link_id != link_id && + link_id != IEEE80211_LINK_UNSPECIFIED) { +#ifdef CONFIG_MAC80211_VERBOSE_DEBUG + net_info_ratelimited("%s: dropped frame to %pM with bad link ID request (%d vs. %d)\n", + sdata->name, hdr.addr1, + pre_conf_link_id, link_id); +#endif + ret = -EINVAL; + goto free; + } + } + + info->control.flags = ctrl_flags; + + return skb; + free: + kfree_skb(skb); + return ERR_PTR(ret); +} + +/* + * fast-xmit overview + * + * The core idea of this fast-xmit is to remove per-packet checks by checking + * them out of band. ieee80211_check_fast_xmit() implements the out-of-band + * checks that are needed to get the sta->fast_tx pointer assigned, after which + * much less work can be done per packet. For example, fragmentation must be + * disabled or the fast_tx pointer will not be set. All the conditions are seen + * in the code here. + * + * Once assigned, the fast_tx data structure also caches the per-packet 802.11 + * header and other data to aid packet processing in ieee80211_xmit_fast(). + * + * The most difficult part of this is that when any of these assumptions + * change, an external trigger (i.e. a call to ieee80211_clear_fast_xmit(), + * ieee80211_check_fast_xmit() or friends) is required to reset the data, + * since the per-packet code no longer checks the conditions. This is reflected + * by the calls to these functions throughout the rest of the code, and must be + * maintained if any of the TX path checks change. + */ + +void ieee80211_check_fast_xmit(struct sta_info *sta) +{ + struct ieee80211_fast_tx build = {}, *fast_tx = NULL, *old; + struct ieee80211_local *local = sta->local; + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_hdr *hdr = (void *)build.hdr; + struct ieee80211_chanctx_conf *chanctx_conf; + __le16 fc; + + if (!ieee80211_hw_check(&local->hw, SUPPORT_FAST_XMIT)) + return; + + /* Locking here protects both the pointer itself, and against concurrent + * invocations winning data access races to, e.g., the key pointer that + * is used. + * Without it, the invocation of this function right after the key + * pointer changes wouldn't be sufficient, as another CPU could access + * the pointer, then stall, and then do the cache update after the CPU + * that invalidated the key. + * With the locking, such scenarios cannot happen as the check for the + * key and the fast-tx assignment are done atomically, so the CPU that + * modifies the key will either wait or other one will see the key + * cleared/changed already. 
+ */ + spin_lock_bh(&sta->lock); + if (ieee80211_hw_check(&local->hw, SUPPORTS_PS) && + !ieee80211_hw_check(&local->hw, SUPPORTS_DYNAMIC_PS) && + sdata->vif.type == NL80211_IFTYPE_STATION) + goto out; + + if (!test_sta_flag(sta, WLAN_STA_AUTHORIZED)) + goto out; + + if (test_sta_flag(sta, WLAN_STA_PS_STA) || + test_sta_flag(sta, WLAN_STA_PS_DRIVER) || + test_sta_flag(sta, WLAN_STA_PS_DELIVER) || + test_sta_flag(sta, WLAN_STA_CLEAR_PS_FILT)) + goto out; + + if (sdata->noack_map) + goto out; + + /* fast-xmit doesn't handle fragmentation at all */ + if (local->hw.wiphy->frag_threshold != (u32)-1 && + !ieee80211_hw_check(&local->hw, SUPPORTS_TX_FRAG)) + goto out; + + if (!sdata->vif.valid_links) { + rcu_read_lock(); + chanctx_conf = + rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (!chanctx_conf) { + rcu_read_unlock(); + goto out; + } + build.band = chanctx_conf->def.chan->band; + rcu_read_unlock(); + } else { + /* MLD transmissions must not rely on the band */ + build.band = 0; + } + + fc = cpu_to_le16(IEEE80211_FTYPE_DATA | IEEE80211_STYPE_DATA); + + switch (sdata->vif.type) { + case NL80211_IFTYPE_ADHOC: + /* DA SA BSSID */ + build.da_offs = offsetof(struct ieee80211_hdr, addr1); + build.sa_offs = offsetof(struct ieee80211_hdr, addr2); + memcpy(hdr->addr3, sdata->u.ibss.bssid, ETH_ALEN); + build.hdr_len = 24; + break; + case NL80211_IFTYPE_STATION: + if (test_sta_flag(sta, WLAN_STA_TDLS_PEER)) { + /* DA SA BSSID */ + build.da_offs = offsetof(struct ieee80211_hdr, addr1); + build.sa_offs = offsetof(struct ieee80211_hdr, addr2); + memcpy(hdr->addr3, sdata->deflink.u.mgd.bssid, ETH_ALEN); + build.hdr_len = 24; + break; + } + + if (sdata->u.mgd.use_4addr) { + /* non-regular ethertype cannot use the fastpath */ + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS | + IEEE80211_FCTL_TODS); + /* RA TA DA SA */ + memcpy(hdr->addr1, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + build.da_offs = offsetof(struct ieee80211_hdr, addr3); + build.sa_offs = offsetof(struct ieee80211_hdr, addr4); + build.hdr_len = 30; + break; + } + fc |= cpu_to_le16(IEEE80211_FCTL_TODS); + /* BSSID SA DA */ + memcpy(hdr->addr1, sdata->vif.cfg.ap_addr, ETH_ALEN); + build.da_offs = offsetof(struct ieee80211_hdr, addr3); + build.sa_offs = offsetof(struct ieee80211_hdr, addr2); + build.hdr_len = 24; + break; + case NL80211_IFTYPE_AP_VLAN: + if (sdata->wdev.use_4addr) { + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS | + IEEE80211_FCTL_TODS); + /* RA TA DA SA */ + memcpy(hdr->addr1, sta->sta.addr, ETH_ALEN); + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + build.da_offs = offsetof(struct ieee80211_hdr, addr3); + build.sa_offs = offsetof(struct ieee80211_hdr, addr4); + build.hdr_len = 30; + break; + } + fallthrough; + case NL80211_IFTYPE_AP: + fc |= cpu_to_le16(IEEE80211_FCTL_FROMDS); + /* DA BSSID SA */ + build.da_offs = offsetof(struct ieee80211_hdr, addr1); + if (sta->sta.mlo || !sdata->vif.valid_links) { + memcpy(hdr->addr2, sdata->vif.addr, ETH_ALEN); + } else { + unsigned int link_id = sta->deflink.link_id; + struct ieee80211_link_data *link; + + rcu_read_lock(); + link = rcu_dereference(sdata->link[link_id]); + if (WARN_ON(!link)) { + rcu_read_unlock(); + goto out; + } + memcpy(hdr->addr2, link->conf->addr, ETH_ALEN); + rcu_read_unlock(); + } + build.sa_offs = offsetof(struct ieee80211_hdr, addr3); + build.hdr_len = 24; + break; + default: + /* not handled on fast-xmit */ + goto out; + } + + if (sta->sta.wme) { + build.hdr_len += 2; + fc |= cpu_to_le16(IEEE80211_STYPE_QOS_DATA); + } + 
+ /* We store the key here so there's no point in using rcu_dereference() + * but that's fine because the code that changes the pointers will call + * this function after doing so. For a single CPU that would be enough, + * for multiple see the comment above. + */ + build.key = rcu_access_pointer(sta->ptk[sta->ptk_idx]); + if (!build.key) + build.key = rcu_access_pointer(sdata->default_unicast_key); + if (build.key) { + bool gen_iv, iv_spc, mmic; + + gen_iv = build.key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV; + iv_spc = build.key->conf.flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE; + mmic = build.key->conf.flags & + (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE); + + /* don't handle software crypto */ + if (!(build.key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE)) + goto out; + + /* Key is being removed */ + if (build.key->flags & KEY_FLAG_TAINTED) + goto out; + + switch (build.key->conf.cipher) { + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + if (gen_iv) + build.pn_offs = build.hdr_len; + if (gen_iv || iv_spc) + build.hdr_len += IEEE80211_CCMP_HDR_LEN; + break; + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + if (gen_iv) + build.pn_offs = build.hdr_len; + if (gen_iv || iv_spc) + build.hdr_len += IEEE80211_GCMP_HDR_LEN; + break; + case WLAN_CIPHER_SUITE_TKIP: + /* cannot handle MMIC or IV generation in xmit-fast */ + if (mmic || gen_iv) + goto out; + if (iv_spc) + build.hdr_len += IEEE80211_TKIP_IV_LEN; + break; + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + /* cannot handle IV generation in fast-xmit */ + if (gen_iv) + goto out; + if (iv_spc) + build.hdr_len += IEEE80211_WEP_IV_LEN; + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + WARN(1, + "management cipher suite 0x%x enabled for data\n", + build.key->conf.cipher); + goto out; + default: + /* we don't know how to generate IVs for this at all */ + if (WARN_ON(gen_iv)) + goto out; + } + + fc |= cpu_to_le16(IEEE80211_FCTL_PROTECTED); + } + + hdr->frame_control = fc; + + memcpy(build.hdr + build.hdr_len, + rfc1042_header, sizeof(rfc1042_header)); + build.hdr_len += sizeof(rfc1042_header); + + fast_tx = kmemdup(&build, sizeof(build), GFP_ATOMIC); + /* if the kmemdup fails, continue w/o fast_tx */ + + out: + /* we might have raced against another call to this function */ + old = rcu_dereference_protected(sta->fast_tx, + lockdep_is_held(&sta->lock)); + rcu_assign_pointer(sta->fast_tx, fast_tx); + if (old) + kfree_rcu(old, rcu_head); + spin_unlock_bh(&sta->lock); +} + +void ieee80211_check_fast_xmit_all(struct ieee80211_local *local) +{ + struct sta_info *sta; + + rcu_read_lock(); + list_for_each_entry_rcu(sta, &local->sta_list, list) + ieee80211_check_fast_xmit(sta); + rcu_read_unlock(); +} + +void ieee80211_check_fast_xmit_iface(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + rcu_read_lock(); + + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata && + (!sta->sdata->bss || sta->sdata->bss != sdata->bss)) + continue; + ieee80211_check_fast_xmit(sta); + } + + rcu_read_unlock(); +} + +void ieee80211_clear_fast_xmit(struct sta_info *sta) +{ + struct ieee80211_fast_tx *fast_tx; + + spin_lock_bh(&sta->lock); + fast_tx = rcu_dereference_protected(sta->fast_tx, + lockdep_is_held(&sta->lock)); + RCU_INIT_POINTER(sta->fast_tx, NULL); + spin_unlock_bh(&sta->lock); + + 
if (fast_tx) + kfree_rcu(fast_tx, rcu_head); +} + +static bool ieee80211_amsdu_realloc_pad(struct ieee80211_local *local, + struct sk_buff *skb, int headroom) +{ + if (skb_headroom(skb) < headroom) { + I802_DEBUG_INC(local->tx_expand_skb_head); + + if (pskb_expand_head(skb, headroom, 0, GFP_ATOMIC)) { + wiphy_debug(local->hw.wiphy, + "failed to reallocate TX buffer\n"); + return false; + } + } + + return true; +} + +static bool ieee80211_amsdu_prepare_head(struct ieee80211_sub_if_data *sdata, + struct ieee80211_fast_tx *fast_tx, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr; + struct ethhdr *amsdu_hdr; + int hdr_len = fast_tx->hdr_len - sizeof(rfc1042_header); + int subframe_len = skb->len - hdr_len; + void *data; + u8 *qc, *h_80211_src, *h_80211_dst; + const u8 *bssid; + + if (info->flags & IEEE80211_TX_CTL_RATE_CTRL_PROBE) + return false; + + if (info->control.flags & IEEE80211_TX_CTRL_AMSDU) + return true; + + if (!ieee80211_amsdu_realloc_pad(local, skb, + sizeof(*amsdu_hdr) + + local->hw.extra_tx_headroom)) + return false; + + data = skb_push(skb, sizeof(*amsdu_hdr)); + memmove(data, data + sizeof(*amsdu_hdr), hdr_len); + hdr = data; + amsdu_hdr = data + hdr_len; + /* h_80211_src/dst is addr* field within hdr */ + h_80211_src = data + fast_tx->sa_offs; + h_80211_dst = data + fast_tx->da_offs; + + amsdu_hdr->h_proto = cpu_to_be16(subframe_len); + ether_addr_copy(amsdu_hdr->h_source, h_80211_src); + ether_addr_copy(amsdu_hdr->h_dest, h_80211_dst); + + /* according to IEEE 802.11-2012 8.3.2 table 8-19, the outer SA/DA + * fields needs to be changed to BSSID for A-MSDU frames depending + * on FromDS/ToDS values. + */ + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + bssid = sdata->vif.cfg.ap_addr; + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + bssid = sdata->vif.addr; + break; + default: + bssid = NULL; + } + + if (bssid && ieee80211_has_fromds(hdr->frame_control)) + ether_addr_copy(h_80211_src, bssid); + + if (bssid && ieee80211_has_tods(hdr->frame_control)) + ether_addr_copy(h_80211_dst, bssid); + + qc = ieee80211_get_qos_ctl(hdr); + *qc |= IEEE80211_QOS_CTL_A_MSDU_PRESENT; + + info->control.flags |= IEEE80211_TX_CTRL_AMSDU; + + return true; +} + +static bool ieee80211_amsdu_aggregate(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_fast_tx *fast_tx, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct fq *fq = &local->fq; + struct fq_tin *tin; + struct fq_flow *flow; + u8 tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + struct ieee80211_txq *txq = sta->sta.txq[tid]; + struct txq_info *txqi; + struct sk_buff **frag_tail, *head; + int subframe_len = skb->len - ETH_ALEN; + u8 max_subframes = sta->sta.max_amsdu_subframes; + int max_frags = local->hw.max_tx_fragments; + int max_amsdu_len = sta->sta.cur->max_amsdu_len; + int orig_truesize; + u32 flow_idx; + __be16 len; + void *data; + bool ret = false; + unsigned int orig_len; + int n = 2, nfrags, pad = 0; + u16 hdrlen; + + if (!ieee80211_hw_check(&local->hw, TX_AMSDU)) + return false; + + if (sdata->vif.offload_flags & IEEE80211_OFFLOAD_ENCAP_ENABLED) + return false; + + if (skb_is_gso(skb)) + return false; + + if (!txq) + return false; + + txqi = to_txq_info(txq); + if (test_bit(IEEE80211_TXQ_NO_AMSDU, &txqi->flags)) + return false; + + if (sta->sta.cur->max_rc_amsdu_len) + max_amsdu_len = min_t(int, max_amsdu_len, + 
sta->sta.cur->max_rc_amsdu_len); + + if (sta->sta.cur->max_tid_amsdu_len[tid]) + max_amsdu_len = min_t(int, max_amsdu_len, + sta->sta.cur->max_tid_amsdu_len[tid]); + + flow_idx = fq_flow_idx(fq, skb); + + spin_lock_bh(&fq->lock); + + /* TODO: Ideally aggregation should be done on dequeue to remain + * responsive to environment changes. + */ + + tin = &txqi->tin; + flow = fq_flow_classify(fq, tin, flow_idx, skb); + head = skb_peek_tail(&flow->queue); + if (!head || skb_is_gso(head)) + goto out; + + orig_truesize = head->truesize; + orig_len = head->len; + + if (skb->len + head->len > max_amsdu_len) + goto out; + + nfrags = 1 + skb_shinfo(skb)->nr_frags; + nfrags += 1 + skb_shinfo(head)->nr_frags; + frag_tail = &skb_shinfo(head)->frag_list; + while (*frag_tail) { + nfrags += 1 + skb_shinfo(*frag_tail)->nr_frags; + frag_tail = &(*frag_tail)->next; + n++; + } + + if (max_subframes && n > max_subframes) + goto out; + + if (max_frags && nfrags > max_frags) + goto out; + + if (!drv_can_aggregate_in_amsdu(local, head, skb)) + goto out; + + if (!ieee80211_amsdu_prepare_head(sdata, fast_tx, head)) + goto out; + + /* If n == 2, the "while (*frag_tail)" loop above didn't execute + * and frag_tail should be &skb_shinfo(head)->frag_list. + * However, ieee80211_amsdu_prepare_head() can reallocate it. + * Reload frag_tail to have it pointing to the correct place. + */ + if (n == 2) + frag_tail = &skb_shinfo(head)->frag_list; + + /* + * Pad out the previous subframe to a multiple of 4 by adding the + * padding to the next one, that's being added. Note that head->len + * is the length of the full A-MSDU, but that works since each time + * we add a new subframe we pad out the previous one to a multiple + * of 4 and thus it no longer matters in the next round. + */ + hdrlen = fast_tx->hdr_len - sizeof(rfc1042_header); + if ((head->len - hdrlen) & 3) + pad = 4 - ((head->len - hdrlen) & 3); + + if (!ieee80211_amsdu_realloc_pad(local, skb, sizeof(rfc1042_header) + + 2 + pad)) + goto out_recalc; + + ret = true; + data = skb_push(skb, ETH_ALEN + 2); + memmove(data, data + ETH_ALEN + 2, 2 * ETH_ALEN); + + data += 2 * ETH_ALEN; + len = cpu_to_be16(subframe_len); + memcpy(data, &len, 2); + memcpy(data + 2, rfc1042_header, sizeof(rfc1042_header)); + + memset(skb_push(skb, pad), 0, pad); + + head->len += skb->len; + head->data_len += skb->len; + *frag_tail = skb; + +out_recalc: + fq->memory_usage += head->truesize - orig_truesize; + if (head->len != orig_len) { + flow->backlog += head->len - orig_len; + tin->backlog_bytes += head->len - orig_len; + } +out: + spin_unlock_bh(&fq->lock); + + return ret; +} + +/* + * Can be called while the sta lock is held. Anything that can cause packets to + * be generated will cause deadlock! 
+ */ +static ieee80211_tx_result +ieee80211_xmit_fast_finish(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, u8 pn_offs, + struct ieee80211_key *key, + struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb = tx->skb; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_hdr *hdr = (void *)skb->data; + u8 tid = IEEE80211_NUM_TIDS; + + if (!ieee80211_hw_check(&tx->local->hw, HAS_RATE_CONTROL) && + ieee80211_tx_h_rate_ctrl(tx) != TX_CONTINUE) + return TX_DROP; + + if (key) + info->control.hw_key = &key->conf; + + dev_sw_netstats_tx_add(skb->dev, 1, skb->len); + + if (hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) { + tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + hdr->seq_ctrl = ieee80211_tx_next_seq(sta, tid); + } else { + info->flags |= IEEE80211_TX_CTL_ASSIGN_SEQ; + hdr->seq_ctrl = cpu_to_le16(sdata->sequence_number); + sdata->sequence_number += 0x10; + } + + if (skb_shinfo(skb)->gso_size) + sta->deflink.tx_stats.msdu[tid] += + DIV_ROUND_UP(skb->len, skb_shinfo(skb)->gso_size); + else + sta->deflink.tx_stats.msdu[tid]++; + + info->hw_queue = sdata->vif.hw_queue[skb_get_queue_mapping(skb)]; + + /* statistics normally done by ieee80211_tx_h_stats (but that + * has to consider fragmentation, so is more complex) + */ + sta->deflink.tx_stats.bytes[skb_get_queue_mapping(skb)] += skb->len; + sta->deflink.tx_stats.packets[skb_get_queue_mapping(skb)]++; + + if (pn_offs) { + u64 pn; + u8 *crypto_hdr = skb->data + pn_offs; + + switch (key->conf.cipher) { + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + pn = atomic64_inc_return(&key->conf.tx_pn); + crypto_hdr[0] = pn; + crypto_hdr[1] = pn >> 8; + crypto_hdr[3] = 0x20 | (key->conf.keyidx << 6); + crypto_hdr[4] = pn >> 16; + crypto_hdr[5] = pn >> 24; + crypto_hdr[6] = pn >> 32; + crypto_hdr[7] = pn >> 40; + break; + } + } + + return TX_CONTINUE; +} + +static bool ieee80211_xmit_fast(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, + struct ieee80211_fast_tx *fast_tx, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + u16 ethertype = (skb->data[12] << 8) | skb->data[13]; + int extra_head = fast_tx->hdr_len - (ETH_HLEN - 2); + int hw_headroom = sdata->local->hw.extra_tx_headroom; + struct ethhdr eth; + struct ieee80211_tx_info *info; + struct ieee80211_hdr *hdr = (void *)fast_tx->hdr; + struct ieee80211_tx_data tx; + ieee80211_tx_result r; + struct tid_ampdu_tx *tid_tx = NULL; + u8 tid = IEEE80211_NUM_TIDS; + + /* control port protocol needs a lot of special handling */ + if (cpu_to_be16(ethertype) == sdata->control_port_protocol) + return false; + + /* only RFC 1042 SNAP */ + if (ethertype < ETH_P_802_3_MIN) + return false; + + /* don't handle TX status request here either */ + if (skb->sk && skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS) + return false; + + if (hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) { + tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); + if (tid_tx) { + if (!test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) + return false; + if (tid_tx->timeout) + tid_tx->last_tx = jiffies; + } + } + + /* after this point (skb is modified) we cannot return false */ + + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + return true; + + if ((hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) && + ieee80211_amsdu_aggregate(sdata, sta, fast_tx, skb)) + return true; + + /* will 
not be crypto-handled beyond what we do here, so use false + * as the may-encrypt argument for the resize to not account for + * more room than we already have in 'extra_head' + */ + if (unlikely(ieee80211_skb_resize(sdata, skb, + max_t(int, extra_head + hw_headroom - + skb_headroom(skb), 0), + ENCRYPT_NO))) { + kfree_skb(skb); + return true; + } + + memcpy(ð, skb->data, ETH_HLEN - 2); + hdr = skb_push(skb, extra_head); + memcpy(skb->data, fast_tx->hdr, fast_tx->hdr_len); + memcpy(skb->data + fast_tx->da_offs, eth.h_dest, ETH_ALEN); + memcpy(skb->data + fast_tx->sa_offs, eth.h_source, ETH_ALEN); + + info = IEEE80211_SKB_CB(skb); + memset(info, 0, sizeof(*info)); + info->band = fast_tx->band; + info->control.vif = &sdata->vif; + info->flags = IEEE80211_TX_CTL_FIRST_FRAGMENT | + IEEE80211_TX_CTL_DONTFRAG | + (tid_tx ? IEEE80211_TX_CTL_AMPDU : 0); + info->control.flags = IEEE80211_TX_CTRL_FAST_XMIT | + u32_encode_bits(IEEE80211_LINK_UNSPECIFIED, + IEEE80211_TX_CTRL_MLO_LINK); + +#ifdef CONFIG_MAC80211_DEBUGFS + if (local->force_tx_status) + info->flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; +#endif + + if (hdr->frame_control & cpu_to_le16(IEEE80211_STYPE_QOS_DATA)) { + tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + *ieee80211_get_qos_ctl(hdr) = tid; + } + + __skb_queue_head_init(&tx.skbs); + + tx.flags = IEEE80211_TX_UNICAST; + tx.local = local; + tx.sdata = sdata; + tx.sta = sta; + tx.key = fast_tx->key; + + if (ieee80211_queue_skb(local, sdata, sta, skb)) + return true; + + tx.skb = skb; + r = ieee80211_xmit_fast_finish(sdata, sta, fast_tx->pn_offs, + fast_tx->key, &tx); + tx.skb = NULL; + if (r == TX_DROP) { + kfree_skb(skb); + return true; + } + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(sdata->bss, + struct ieee80211_sub_if_data, u.ap); + + __skb_queue_tail(&tx.skbs, skb); + ieee80211_tx_frags(local, &sdata->vif, sta, &tx.skbs, false); + return true; +} + +struct sk_buff *ieee80211_tx_dequeue(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *txqi = container_of(txq, struct txq_info, txq); + struct ieee80211_hdr *hdr; + struct fq *fq = &local->fq; + struct fq_tin *tin = &txqi->tin; + struct ieee80211_tx_info *info; + struct ieee80211_tx_data tx; + struct sk_buff *skb; + ieee80211_tx_result r; + struct ieee80211_vif *vif = txq->vif; + int q = vif->hw_queue[txq->ac]; + unsigned long flags; + bool q_stopped; + + WARN_ON_ONCE(softirq_count() == 0); + + if (!ieee80211_txq_airtime_check(hw, txq)) + return NULL; + +begin: + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + q_stopped = local->queue_stop_reasons[q]; + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + if (unlikely(q_stopped)) { + /* mark for waking later */ + set_bit(IEEE80211_TXQ_DIRTY, &txqi->flags); + return NULL; + } + + spin_lock_bh(&fq->lock); + + /* Make sure fragments stay together. 
*/ + skb = __skb_dequeue(&txqi->frags); + if (unlikely(skb)) { + if (!(IEEE80211_SKB_CB(skb)->control.flags & + IEEE80211_TX_INTCFL_NEED_TXPROCESSING)) + goto out; + IEEE80211_SKB_CB(skb)->control.flags &= + ~IEEE80211_TX_INTCFL_NEED_TXPROCESSING; + } else { + if (unlikely(test_bit(IEEE80211_TXQ_STOP, &txqi->flags))) + goto out; + + skb = fq_tin_dequeue(fq, tin, fq_tin_dequeue_func); + } + + if (!skb) + goto out; + + spin_unlock_bh(&fq->lock); + + hdr = (struct ieee80211_hdr *)skb->data; + info = IEEE80211_SKB_CB(skb); + + memset(&tx, 0, sizeof(tx)); + __skb_queue_head_init(&tx.skbs); + tx.local = local; + tx.skb = skb; + tx.sdata = vif_to_sdata(info->control.vif); + + if (txq->sta) { + tx.sta = container_of(txq->sta, struct sta_info, sta); + /* + * Drop unicast frames to unauthorised stations unless they are + * injected frames or EAPOL frames from the local station. + */ + if (unlikely(!(info->flags & IEEE80211_TX_CTL_INJECTED) && + ieee80211_is_data(hdr->frame_control) && + !ieee80211_vif_is_mesh(&tx.sdata->vif) && + tx.sdata->vif.type != NL80211_IFTYPE_OCB && + !is_multicast_ether_addr(hdr->addr1) && + !test_sta_flag(tx.sta, WLAN_STA_AUTHORIZED) && + (!(info->control.flags & + IEEE80211_TX_CTRL_PORT_CTRL_PROTO) || + !ieee80211_is_our_addr(tx.sdata, hdr->addr2, + NULL)))) { + I802_DEBUG_INC(local->tx_handlers_drop_unauth_port); + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } + } + + /* + * The key can be removed while the packet was queued, so need to call + * this here to get the current key. + */ + r = ieee80211_tx_h_select_key(&tx); + if (r != TX_CONTINUE) { + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } + + if (test_bit(IEEE80211_TXQ_AMPDU, &txqi->flags)) + info->flags |= (IEEE80211_TX_CTL_AMPDU | + IEEE80211_TX_CTL_DONTFRAG); + else + info->flags &= ~IEEE80211_TX_CTL_AMPDU; + + if (info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP) { + if (!ieee80211_hw_check(&local->hw, HAS_RATE_CONTROL)) { + r = ieee80211_tx_h_rate_ctrl(&tx); + if (r != TX_CONTINUE) { + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } + } + goto encap_out; + } + + if (info->control.flags & IEEE80211_TX_CTRL_FAST_XMIT) { + struct sta_info *sta = container_of(txq->sta, struct sta_info, + sta); + u8 pn_offs = 0; + + if (tx.key && + (tx.key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_IV)) + pn_offs = ieee80211_hdrlen(hdr->frame_control); + + r = ieee80211_xmit_fast_finish(sta->sdata, sta, pn_offs, + tx.key, &tx); + if (r != TX_CONTINUE) { + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } + } else { + if (invoke_tx_handlers_late(&tx)) + goto begin; + + skb = __skb_dequeue(&tx.skbs); + + if (!skb_queue_empty(&tx.skbs)) { + spin_lock_bh(&fq->lock); + skb_queue_splice_tail(&tx.skbs, &txqi->frags); + spin_unlock_bh(&fq->lock); + } + } + + if (skb_has_frag_list(skb) && + !ieee80211_hw_check(&local->hw, TX_FRAG_LIST)) { + if (skb_linearize(skb)) { + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } + } + + switch (tx.sdata->vif.type) { + case NL80211_IFTYPE_MONITOR: + if (tx.sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE) { + vif = &tx.sdata->vif; + break; + } + tx.sdata = rcu_dereference(local->monitor_sdata); + if (tx.sdata) { + vif = &tx.sdata->vif; + info->hw_queue = + vif->hw_queue[skb_get_queue_mapping(skb)]; + } else if (ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) { + ieee80211_free_txskb(&local->hw, skb); + goto begin; + } else { + vif = NULL; + } + break; + case NL80211_IFTYPE_AP_VLAN: + tx.sdata = container_of(tx.sdata->bss, + struct ieee80211_sub_if_data, u.ap); + 
fallthrough; + default: + vif = &tx.sdata->vif; + break; + } + +encap_out: + IEEE80211_SKB_CB(skb)->control.vif = vif; + + if (tx.sta && + wiphy_ext_feature_isset(local->hw.wiphy, NL80211_EXT_FEATURE_AQL)) { + bool ampdu = txq->ac != IEEE80211_AC_VO; + u32 airtime; + + airtime = ieee80211_calc_expected_tx_airtime(hw, vif, txq->sta, + skb->len, ampdu); + if (airtime) { + airtime = ieee80211_info_set_tx_time_est(info, airtime); + ieee80211_sta_update_pending_airtime(local, tx.sta, + txq->ac, + airtime, + false); + } + } + + return skb; + +out: + spin_unlock_bh(&fq->lock); + + return skb; +} +EXPORT_SYMBOL(ieee80211_tx_dequeue); + +static inline s32 ieee80211_sta_deficit(struct sta_info *sta, u8 ac) +{ + struct airtime_info *air_info = &sta->airtime[ac]; + + return air_info->deficit - atomic_read(&air_info->aql_tx_pending); +} + +static void +ieee80211_txq_set_active(struct txq_info *txqi) +{ + struct sta_info *sta; + + if (!txqi->txq.sta) + return; + + sta = container_of(txqi->txq.sta, struct sta_info, sta); + sta->airtime[txqi->txq.ac].last_active = (u32)jiffies; +} + +static bool +ieee80211_txq_keep_active(struct txq_info *txqi) +{ + struct sta_info *sta; + u32 diff; + + if (!txqi->txq.sta) + return false; + + sta = container_of(txqi->txq.sta, struct sta_info, sta); + if (ieee80211_sta_deficit(sta, txqi->txq.ac) >= 0) + return false; + + diff = (u32)jiffies - sta->airtime[txqi->txq.ac].last_active; + + return diff <= AIRTIME_ACTIVE_DURATION; +} + +struct ieee80211_txq *ieee80211_next_txq(struct ieee80211_hw *hw, u8 ac) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_txq *ret = NULL; + struct txq_info *txqi = NULL, *head = NULL; + bool found_eligible_txq = false; + + spin_lock_bh(&local->active_txq_lock[ac]); + + if (!local->schedule_round[ac]) + goto out; + + begin: + txqi = list_first_entry_or_null(&local->active_txqs[ac], + struct txq_info, + schedule_order); + if (!txqi) + goto out; + + if (txqi == head) { + if (!found_eligible_txq) + goto out; + else + found_eligible_txq = false; + } + + if (!head) + head = txqi; + + if (txqi->txq.sta) { + struct sta_info *sta = container_of(txqi->txq.sta, + struct sta_info, sta); + bool aql_check = ieee80211_txq_airtime_check(hw, &txqi->txq); + s32 deficit = ieee80211_sta_deficit(sta, txqi->txq.ac); + + if (aql_check) + found_eligible_txq = true; + + if (deficit < 0) + sta->airtime[txqi->txq.ac].deficit += + sta->airtime_weight; + + if (deficit < 0 || !aql_check) { + list_move_tail(&txqi->schedule_order, + &local->active_txqs[txqi->txq.ac]); + goto begin; + } + } + + if (txqi->schedule_round == local->schedule_round[ac]) + goto out; + + list_del_init(&txqi->schedule_order); + txqi->schedule_round = local->schedule_round[ac]; + ret = &txqi->txq; + +out: + spin_unlock_bh(&local->active_txq_lock[ac]); + return ret; +} +EXPORT_SYMBOL(ieee80211_next_txq); + +void __ieee80211_schedule_txq(struct ieee80211_hw *hw, + struct ieee80211_txq *txq, + bool force) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *txqi = to_txq_info(txq); + bool has_queue; + + spin_lock_bh(&local->active_txq_lock[txq->ac]); + + has_queue = force || txq_has_queue(txq); + if (list_empty(&txqi->schedule_order) && + (has_queue || ieee80211_txq_keep_active(txqi))) { + /* If airtime accounting is active, always enqueue STAs at the + * head of the list to ensure that they only get moved to the + * back by the airtime DRR scheduler once they have a negative + * deficit. 
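+ * (Illustrative numbers: a station sitting at a deficit of -44 with
+ * an airtime_weight of 256 is topped back up to 212 when
+ * ieee80211_next_txq() rotates it to the tail, so the other active
+ * TXQs are considered first.)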
A station that already has a negative deficit will + * get immediately moved to the back of the list on the next + * call to ieee80211_next_txq(). + */ + if (txqi->txq.sta && local->airtime_flags && has_queue && + wiphy_ext_feature_isset(local->hw.wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + list_add(&txqi->schedule_order, + &local->active_txqs[txq->ac]); + else + list_add_tail(&txqi->schedule_order, + &local->active_txqs[txq->ac]); + if (has_queue) + ieee80211_txq_set_active(txqi); + } + + spin_unlock_bh(&local->active_txq_lock[txq->ac]); +} +EXPORT_SYMBOL(__ieee80211_schedule_txq); + +DEFINE_STATIC_KEY_FALSE(aql_disable); + +bool ieee80211_txq_airtime_check(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct sta_info *sta; + struct ieee80211_local *local = hw_to_local(hw); + + if (!wiphy_ext_feature_isset(local->hw.wiphy, NL80211_EXT_FEATURE_AQL)) + return true; + + if (static_branch_unlikely(&aql_disable)) + return true; + + if (!txq->sta) + return true; + + if (unlikely(txq->tid == IEEE80211_NUM_TIDS)) + return true; + + sta = container_of(txq->sta, struct sta_info, sta); + if (atomic_read(&sta->airtime[txq->ac].aql_tx_pending) < + sta->airtime[txq->ac].aql_limit_low) + return true; + + if (atomic_read(&local->aql_total_pending_airtime) < + local->aql_threshold && + atomic_read(&sta->airtime[txq->ac].aql_tx_pending) < + sta->airtime[txq->ac].aql_limit_high) + return true; + + return false; +} +EXPORT_SYMBOL(ieee80211_txq_airtime_check); + +static bool +ieee80211_txq_schedule_airtime_check(struct ieee80211_local *local, u8 ac) +{ + unsigned int num_txq = 0; + struct txq_info *txq; + u32 aql_limit; + + if (!wiphy_ext_feature_isset(local->hw.wiphy, NL80211_EXT_FEATURE_AQL)) + return true; + + list_for_each_entry(txq, &local->active_txqs[ac], schedule_order) + num_txq++; + + aql_limit = (num_txq - 1) * local->aql_txq_limit_low[ac] / 2 + + local->aql_txq_limit_high[ac]; + + return atomic_read(&local->aql_ac_pending_airtime[ac]) < aql_limit; +} + +bool ieee80211_txq_may_transmit(struct ieee80211_hw *hw, + struct ieee80211_txq *txq) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct txq_info *iter, *tmp, *txqi = to_txq_info(txq); + struct sta_info *sta; + u8 ac = txq->ac; + + spin_lock_bh(&local->active_txq_lock[ac]); + + if (!txqi->txq.sta) + goto out; + + if (list_empty(&txqi->schedule_order)) + goto out; + + if (!ieee80211_txq_schedule_airtime_check(local, ac)) + goto out; + + list_for_each_entry_safe(iter, tmp, &local->active_txqs[ac], + schedule_order) { + if (iter == txqi) + break; + + if (!iter->txq.sta) { + list_move_tail(&iter->schedule_order, + &local->active_txqs[ac]); + continue; + } + sta = container_of(iter->txq.sta, struct sta_info, sta); + if (ieee80211_sta_deficit(sta, ac) < 0) + sta->airtime[ac].deficit += sta->airtime_weight; + list_move_tail(&iter->schedule_order, &local->active_txqs[ac]); + } + + sta = container_of(txqi->txq.sta, struct sta_info, sta); + if (sta->airtime[ac].deficit >= 0) + goto out; + + sta->airtime[ac].deficit += sta->airtime_weight; + list_move_tail(&txqi->schedule_order, &local->active_txqs[ac]); + spin_unlock_bh(&local->active_txq_lock[ac]); + + return false; +out: + if (!list_empty(&txqi->schedule_order)) + list_del_init(&txqi->schedule_order); + spin_unlock_bh(&local->active_txq_lock[ac]); + + return true; +} +EXPORT_SYMBOL(ieee80211_txq_may_transmit); + +void ieee80211_txq_schedule_start(struct ieee80211_hw *hw, u8 ac) +{ + struct ieee80211_local *local = hw_to_local(hw); + + 
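+ /*
+ * A new scheduling round is only started while the outstanding
+ * airtime on this AC is below the aggregate limit computed by
+ * ieee80211_txq_schedule_airtime_check(); with three active TXQs
+ * and illustrative per-TXQ limits of 5000/12000 usec that cap works
+ * out to (3 - 1) * 5000 / 2 + 12000 = 17000 usec. A schedule_round
+ * value of zero means "no round in progress" and makes
+ * ieee80211_next_txq() return NULL, which is why the counter skips
+ * zero when it wraps.
+ */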
spin_lock_bh(&local->active_txq_lock[ac]); + + if (ieee80211_txq_schedule_airtime_check(local, ac)) { + local->schedule_round[ac]++; + if (!local->schedule_round[ac]) + local->schedule_round[ac]++; + } else { + local->schedule_round[ac] = 0; + } + + spin_unlock_bh(&local->active_txq_lock[ac]); +} +EXPORT_SYMBOL(ieee80211_txq_schedule_start); + +void __ieee80211_subif_start_xmit(struct sk_buff *skb, + struct net_device *dev, + u32 info_flags, + u32 ctrl_flags, + u64 *cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct sk_buff *next; + int len = skb->len; + + if (unlikely(!ieee80211_sdata_running(sdata) || skb->len < ETH_HLEN)) { + kfree_skb(skb); + return; + } + + rcu_read_lock(); + + if (ieee80211_lookup_ra_sta(sdata, skb, &sta)) + goto out_free; + + if (IS_ERR(sta)) + sta = NULL; + + if (local->ops->wake_tx_queue) { + u16 queue = __ieee80211_select_queue(sdata, sta, skb); + skb_set_queue_mapping(skb, queue); + skb_get_hash(skb); + } + + ieee80211_aggr_check(sdata, sta, skb); + + sk_pacing_shift_update(skb->sk, sdata->local->hw.tx_sk_pacing_shift); + + if (sta) { + struct ieee80211_fast_tx *fast_tx; + + fast_tx = rcu_dereference(sta->fast_tx); + + if (fast_tx && + ieee80211_xmit_fast(sdata, sta, fast_tx, skb)) + goto out; + } + + if (skb_is_gso(skb)) { + struct sk_buff *segs; + + segs = skb_gso_segment(skb, 0); + if (IS_ERR(segs)) { + goto out_free; + } else if (segs) { + consume_skb(skb); + skb = segs; + } + } else { + /* we cannot process non-linear frames on this path */ + if (skb_linearize(skb)) + goto out_free; + + /* the frame could be fragmented, software-encrypted, and other + * things so we cannot really handle checksum offload with it - + * fix it up in software before we handle anything else. 
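+ * (Concretely, CHECKSUM_PARTIAL is resolved with skb_checksum_help()
+ * just below, before ieee80211_build_hdr() runs.)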
+ */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + skb_set_transport_header(skb, + skb_checksum_start_offset(skb)); + if (skb_checksum_help(skb)) + goto out_free; + } + } + + skb_list_walk_safe(skb, skb, next) { + skb_mark_not_on_list(skb); + + if (skb->protocol == sdata->control_port_protocol) + ctrl_flags |= IEEE80211_TX_CTRL_SKIP_MPATH_LOOKUP; + + skb = ieee80211_build_hdr(sdata, skb, info_flags, + sta, ctrl_flags, cookie); + if (IS_ERR(skb)) { + kfree_skb_list(next); + goto out; + } + + dev_sw_netstats_tx_add(dev, 1, skb->len); + + ieee80211_xmit(sdata, sta, skb); + } + goto out; + out_free: + kfree_skb(skb); + len = 0; + out: + if (len) + ieee80211_tpt_led_trig_tx(local, len); + rcu_read_unlock(); +} + +static int ieee80211_change_da(struct sk_buff *skb, struct sta_info *sta) +{ + struct ethhdr *eth; + int err; + + err = skb_ensure_writable(skb, ETH_HLEN); + if (unlikely(err)) + return err; + + eth = (void *)skb->data; + ether_addr_copy(eth->h_dest, sta->sta.addr); + + return 0; +} + +static bool ieee80211_multicast_to_unicast(struct sk_buff *skb, + struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + const struct ethhdr *eth = (void *)skb->data; + const struct vlan_ethhdr *ethvlan = (void *)skb->data; + __be16 ethertype; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + if (sdata->u.vlan.sta) + return false; + if (sdata->wdev.use_4addr) + return false; + fallthrough; + case NL80211_IFTYPE_AP: + /* check runtime toggle for this bss */ + if (!sdata->bss->multicast_to_unicast) + return false; + break; + default: + return false; + } + + /* multicast to unicast conversion only for some payload */ + ethertype = eth->h_proto; + if (ethertype == htons(ETH_P_8021Q) && skb->len >= VLAN_ETH_HLEN) + ethertype = ethvlan->h_vlan_encapsulated_proto; + switch (ethertype) { + case htons(ETH_P_ARP): + case htons(ETH_P_IP): + case htons(ETH_P_IPV6): + break; + default: + return false; + } + + return true; +} + +static void +ieee80211_convert_to_unicast(struct sk_buff *skb, struct net_device *dev, + struct sk_buff_head *queue) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + const struct ethhdr *eth = (struct ethhdr *)skb->data; + struct sta_info *sta, *first = NULL; + struct sk_buff *cloned_skb; + + rcu_read_lock(); + + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata) + /* AP-VLAN mismatch */ + continue; + if (unlikely(ether_addr_equal(eth->h_source, sta->sta.addr))) + /* do not send back to source */ + continue; + if (!first) { + first = sta; + continue; + } + cloned_skb = skb_clone(skb, GFP_ATOMIC); + if (!cloned_skb) + goto multicast; + if (unlikely(ieee80211_change_da(cloned_skb, sta))) { + dev_kfree_skb(cloned_skb); + goto multicast; + } + __skb_queue_tail(queue, cloned_skb); + } + + if (likely(first)) { + if (unlikely(ieee80211_change_da(skb, first))) + goto multicast; + __skb_queue_tail(queue, skb); + } else { + /* no STA connected, drop */ + kfree_skb(skb); + skb = NULL; + } + + goto out; +multicast: + __skb_queue_purge(queue); + __skb_queue_tail(queue, skb); +out: + rcu_read_unlock(); +} + +static void ieee80211_mlo_multicast_tx_one(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u32 ctrl_flags, + unsigned int link_id) +{ + struct sk_buff *out; + + out = skb_copy(skb, GFP_ATOMIC); + if (!out) + return; + + ctrl_flags |= u32_encode_bits(link_id, IEEE80211_TX_CTRL_MLO_LINK); + __ieee80211_subif_start_xmit(out, 
sdata->dev, 0, ctrl_flags, NULL); +} + +static void ieee80211_mlo_multicast_tx(struct net_device *dev, + struct sk_buff *skb) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + unsigned long links = sdata->vif.active_links; + unsigned int link; + u32 ctrl_flags = IEEE80211_TX_CTRL_MCAST_MLO_FIRST_TX; + + if (hweight16(links) == 1) { + ctrl_flags |= u32_encode_bits(__ffs(links), + IEEE80211_TX_CTRL_MLO_LINK); + + __ieee80211_subif_start_xmit(skb, sdata->dev, 0, ctrl_flags, + NULL); + return; + } + + for_each_set_bit(link, &links, IEEE80211_MLD_MAX_NUM_LINKS) { + ieee80211_mlo_multicast_tx_one(sdata, skb, ctrl_flags, link); + ctrl_flags = 0; + } + kfree_skb(skb); +} + +/** + * ieee80211_subif_start_xmit - netif start_xmit function for 802.3 vifs + * @skb: packet to be sent + * @dev: incoming interface + * + * On failure skb will be freed. + */ +netdev_tx_t ieee80211_subif_start_xmit(struct sk_buff *skb, + struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + const struct ethhdr *eth = (void *)skb->data; + + if (likely(!is_multicast_ether_addr(eth->h_dest))) + goto normal; + + if (unlikely(!ieee80211_sdata_running(sdata))) { + kfree_skb(skb); + return NETDEV_TX_OK; + } + + if (unlikely(ieee80211_multicast_to_unicast(skb, dev))) { + struct sk_buff_head queue; + + __skb_queue_head_init(&queue); + ieee80211_convert_to_unicast(skb, dev, &queue); + while ((skb = __skb_dequeue(&queue))) + __ieee80211_subif_start_xmit(skb, dev, 0, + IEEE80211_TX_CTRL_MLO_LINK_UNSPEC, + NULL); + } else if (sdata->vif.valid_links && + sdata->vif.type == NL80211_IFTYPE_AP && + !ieee80211_hw_check(&sdata->local->hw, MLO_MCAST_MULTI_LINK_TX)) { + ieee80211_mlo_multicast_tx(dev, skb); + } else { +normal: + __ieee80211_subif_start_xmit(skb, dev, 0, + IEEE80211_TX_CTRL_MLO_LINK_UNSPEC, + NULL); + } + + return NETDEV_TX_OK; +} + +static bool ieee80211_tx_8023(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, struct sta_info *sta, + bool txpending) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_control control = {}; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_sta *pubsta = NULL; + unsigned long flags; + int q = info->hw_queue; + + if (sta) + sk_pacing_shift_update(skb->sk, local->hw.tx_sk_pacing_shift); + + ieee80211_tpt_led_trig_tx(local, skb->len); + + if (ieee80211_queue_skb(local, sdata, sta, skb)) + return true; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + + if (local->queue_stop_reasons[q] || + (!txpending && !skb_queue_empty(&local->pending[q]))) { + if (txpending) + skb_queue_head(&local->pending[q], skb); + else + skb_queue_tail(&local->pending[q], skb); + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + return false; + } + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + if (sta && sta->uploaded) + pubsta = &sta->sta; + + control.sta = pubsta; + + drv_tx(local, &control, skb); + + return true; +} + +static void ieee80211_8023_xmit(struct ieee80211_sub_if_data *sdata, + struct net_device *dev, struct sta_info *sta, + struct ieee80211_key *key, struct sk_buff *skb) +{ + struct ieee80211_tx_info *info; + struct ieee80211_local *local = sdata->local; + struct tid_ampdu_tx *tid_tx; + u8 tid; + + if (local->ops->wake_tx_queue) { + u16 queue = __ieee80211_select_queue(sdata, sta, skb); + skb_set_queue_mapping(skb, queue); + skb_get_hash(skb); + } + + if (unlikely(test_bit(SCAN_SW_SCANNING, &local->scanning)) && + 
test_bit(SDATA_STATE_OFFCHANNEL, &sdata->state)) + goto out_free; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + return; + + info = IEEE80211_SKB_CB(skb); + memset(info, 0, sizeof(*info)); + + ieee80211_aggr_check(sdata, sta, skb); + + tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + tid_tx = rcu_dereference(sta->ampdu_mlme.tid_tx[tid]); + if (tid_tx) { + if (!test_bit(HT_AGG_STATE_OPERATIONAL, &tid_tx->state)) { + /* fall back to non-offload slow path */ + __ieee80211_subif_start_xmit(skb, dev, 0, + IEEE80211_TX_CTRL_MLO_LINK_UNSPEC, + NULL); + return; + } + + info->flags |= IEEE80211_TX_CTL_AMPDU; + if (tid_tx->timeout) + tid_tx->last_tx = jiffies; + } + + if (unlikely(skb->sk && + skb_shinfo(skb)->tx_flags & SKBTX_WIFI_STATUS)) + info->ack_frame_id = ieee80211_store_ack_skb(local, skb, + &info->flags, NULL); + + info->hw_queue = sdata->vif.hw_queue[skb_get_queue_mapping(skb)]; + + dev_sw_netstats_tx_add(dev, 1, skb->len); + + sta->deflink.tx_stats.bytes[skb_get_queue_mapping(skb)] += skb->len; + sta->deflink.tx_stats.packets[skb_get_queue_mapping(skb)]++; + + if (sdata->vif.type == NL80211_IFTYPE_AP_VLAN) + sdata = container_of(sdata->bss, + struct ieee80211_sub_if_data, u.ap); + + info->flags |= IEEE80211_TX_CTL_HW_80211_ENCAP; + info->control.vif = &sdata->vif; + + if (key) + info->control.hw_key = &key->conf; + + ieee80211_tx_8023(sdata, skb, sta, false); + + return; + +out_free: + kfree_skb(skb); +} + +netdev_tx_t ieee80211_subif_start_xmit_8023(struct sk_buff *skb, + struct net_device *dev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ethhdr *ehdr = (struct ethhdr *)skb->data; + struct ieee80211_key *key; + struct sta_info *sta; + + if (unlikely(!ieee80211_sdata_running(sdata) || skb->len < ETH_HLEN)) { + kfree_skb(skb); + return NETDEV_TX_OK; + } + + rcu_read_lock(); + + if (ieee80211_lookup_ra_sta(sdata, skb, &sta)) { + kfree_skb(skb); + goto out; + } + + if (unlikely(IS_ERR_OR_NULL(sta) || !sta->uploaded || + !test_sta_flag(sta, WLAN_STA_AUTHORIZED) || + sdata->control_port_protocol == ehdr->h_proto)) + goto skip_offload; + + key = rcu_dereference(sta->ptk[sta->ptk_idx]); + if (!key) + key = rcu_dereference(sdata->default_unicast_key); + + if (key && (!(key->flags & KEY_FLAG_UPLOADED_TO_HARDWARE) || + key->conf.cipher == WLAN_CIPHER_SUITE_TKIP)) + goto skip_offload; + + ieee80211_8023_xmit(sdata, dev, sta, key, skb); + goto out; + +skip_offload: + ieee80211_subif_start_xmit(skb, dev); +out: + rcu_read_unlock(); + + return NETDEV_TX_OK; +} + +struct sk_buff * +ieee80211_build_data_template(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, u32 info_flags) +{ + struct ieee80211_hdr *hdr; + struct ieee80211_tx_data tx = { + .local = sdata->local, + .sdata = sdata, + }; + struct sta_info *sta; + + rcu_read_lock(); + + if (ieee80211_lookup_ra_sta(sdata, skb, &sta)) { + kfree_skb(skb); + skb = ERR_PTR(-EINVAL); + goto out; + } + + skb = ieee80211_build_hdr(sdata, skb, info_flags, sta, + IEEE80211_TX_CTRL_MLO_LINK_UNSPEC, NULL); + if (IS_ERR(skb)) + goto out; + + hdr = (void *)skb->data; + tx.sta = sta_info_get(sdata, hdr->addr1); + tx.skb = skb; + + if (ieee80211_tx_h_select_key(&tx) != TX_CONTINUE) { + rcu_read_unlock(); + kfree_skb(skb); + return ERR_PTR(-EINVAL); + } + +out: + rcu_read_unlock(); + return skb; +} + +/* + * ieee80211_clear_tx_pending may not be called in a context where + * it is possible that it packets could come in again. 
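+ * Anything still sitting on the per-queue pending lists is simply
+ * freed with ieee80211_free_txskb() here, e.g. when the hardware is
+ * being shut down.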
+ */ +void ieee80211_clear_tx_pending(struct ieee80211_local *local) +{ + struct sk_buff *skb; + int i; + + for (i = 0; i < local->hw.queues; i++) { + while ((skb = skb_dequeue(&local->pending[i])) != NULL) + ieee80211_free_txskb(&local->hw, skb); + } +} + +/* + * Returns false if the frame couldn't be transmitted but was queued instead, + * which in this case means re-queued -- take as an indication to stop sending + * more pending frames. + */ +static bool ieee80211_tx_pending_skb(struct ieee80211_local *local, + struct sk_buff *skb) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_sub_if_data *sdata; + struct sta_info *sta; + struct ieee80211_hdr *hdr; + bool result; + struct ieee80211_chanctx_conf *chanctx_conf; + + sdata = vif_to_sdata(info->control.vif); + + if (info->control.flags & IEEE80211_TX_INTCFL_NEED_TXPROCESSING) { + /* update band only for non-MLD */ + if (!sdata->vif.valid_links) { + chanctx_conf = + rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (unlikely(!chanctx_conf)) { + dev_kfree_skb(skb); + return true; + } + info->band = chanctx_conf->def.chan->band; + } + result = ieee80211_tx(sdata, NULL, skb, true); + } else if (info->flags & IEEE80211_TX_CTL_HW_80211_ENCAP) { + if (ieee80211_lookup_ra_sta(sdata, skb, &sta)) { + dev_kfree_skb(skb); + return true; + } + + if (IS_ERR(sta) || (sta && !sta->uploaded)) + sta = NULL; + + result = ieee80211_tx_8023(sdata, skb, sta, true); + } else { + struct sk_buff_head skbs; + + __skb_queue_head_init(&skbs); + __skb_queue_tail(&skbs, skb); + + hdr = (struct ieee80211_hdr *)skb->data; + sta = sta_info_get(sdata, hdr->addr1); + + result = __ieee80211_tx(local, &skbs, sta, true); + } + + return result; +} + +/* + * Transmit all pending packets. Called from tasklet. + */ +void ieee80211_tx_pending(struct tasklet_struct *t) +{ + struct ieee80211_local *local = from_tasklet(local, t, + tx_pending_tasklet); + unsigned long flags; + int i; + bool txok; + + rcu_read_lock(); + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + for (i = 0; i < local->hw.queues; i++) { + /* + * If queue is stopped by something other than due to pending + * frames, or we have no pending frames, proceed to next queue. + */ + if (local->queue_stop_reasons[i] || + skb_queue_empty(&local->pending[i])) + continue; + + while (!skb_queue_empty(&local->pending[i])) { + struct sk_buff *skb = __skb_dequeue(&local->pending[i]); + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + + if (WARN_ON(!info->control.vif)) { + ieee80211_free_txskb(&local->hw, skb); + continue; + } + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, + flags); + + txok = ieee80211_tx_pending_skb(local, skb); + spin_lock_irqsave(&local->queue_stop_reason_lock, + flags); + if (!txok) + break; + } + + if (skb_queue_empty(&local->pending[i])) + ieee80211_propagate_queue_wake(local, i); + } + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + + rcu_read_unlock(); +} + +/* functions for drivers to get certain frames */ + +static void __ieee80211_beacon_add_tim(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct ps_data *ps, struct sk_buff *skb, + bool is_template) +{ + u8 *pos, *tim; + int aid0 = 0; + int i, have_bits = 0, n1, n2; + struct ieee80211_bss_conf *link_conf = link->conf; + + /* Generate bitmap for TIM only if there are any STAs in power save + * mode. 
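+ * The partial virtual bitmap built below trims leading and trailing
+ * all-zero octets (the start rounded down to an even offset). As an
+ * illustrative example, if only the station with AID 13 has buffered
+ * frames, byte 1 of ps->tim is the only non-zero byte, so N1 = 0,
+ * N2 = 1, the element length becomes N2 - N1 + 4 = 5 and octets 0-1
+ * of the bitmap are carried in the element.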
*/ + if (atomic_read(&ps->num_sta_ps) > 0) + /* in the hope that this is faster than + * checking byte-for-byte */ + have_bits = !bitmap_empty((unsigned long *)ps->tim, + IEEE80211_MAX_AID+1); + if (!is_template) { + if (ps->dtim_count == 0) + ps->dtim_count = link_conf->dtim_period - 1; + else + ps->dtim_count--; + } + + tim = pos = skb_put(skb, 6); + *pos++ = WLAN_EID_TIM; + *pos++ = 4; + *pos++ = ps->dtim_count; + *pos++ = link_conf->dtim_period; + + if (ps->dtim_count == 0 && !skb_queue_empty(&ps->bc_buf)) + aid0 = 1; + + ps->dtim_bc_mc = aid0 == 1; + + if (have_bits) { + /* Find largest even number N1 so that bits numbered 1 through + * (N1 x 8) - 1 in the bitmap are 0 and number N2 so that bits + * (N2 + 1) x 8 through 2007 are 0. */ + n1 = 0; + for (i = 0; i < IEEE80211_MAX_TIM_LEN; i++) { + if (ps->tim[i]) { + n1 = i & 0xfe; + break; + } + } + n2 = n1; + for (i = IEEE80211_MAX_TIM_LEN - 1; i >= n1; i--) { + if (ps->tim[i]) { + n2 = i; + break; + } + } + + /* Bitmap control */ + *pos++ = n1 | aid0; + /* Part Virt Bitmap */ + skb_put(skb, n2 - n1); + memcpy(pos, ps->tim + n1, n2 - n1 + 1); + + tim[1] = n2 - n1 + 4; + } else { + *pos++ = aid0; /* Bitmap control */ + *pos++ = 0; /* Part Virt Bitmap */ + } +} + +static int ieee80211_beacon_add_tim(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct ps_data *ps, struct sk_buff *skb, + bool is_template) +{ + struct ieee80211_local *local = sdata->local; + + /* + * Not very nice, but we want to allow the driver to call + * ieee80211_beacon_get() as a response to the set_tim() + * callback. That, however, is already invoked under the + * sta_lock to guarantee consistent and race-free update + * of the tim bitmap in mac80211 and the driver. + */ + if (local->tim_in_locked_section) { + __ieee80211_beacon_add_tim(sdata, link, ps, skb, is_template); + } else { + spin_lock_bh(&local->tim_lock); + __ieee80211_beacon_add_tim(sdata, link, ps, skb, is_template); + spin_unlock_bh(&local->tim_lock); + } + + return 0; +} + +static void ieee80211_set_beacon_cntdwn(struct ieee80211_sub_if_data *sdata, + struct beacon_data *beacon, + struct ieee80211_link_data *link) +{ + u8 *beacon_data, count, max_count = 1; + struct probe_resp *resp; + size_t beacon_data_len; + u16 *bcn_offsets; + int i; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP: + beacon_data = beacon->tail; + beacon_data_len = beacon->tail_len; + break; + case NL80211_IFTYPE_ADHOC: + beacon_data = beacon->head; + beacon_data_len = beacon->head_len; + break; + case NL80211_IFTYPE_MESH_POINT: + beacon_data = beacon->head; + beacon_data_len = beacon->head_len; + break; + default: + return; + } + + resp = rcu_dereference(link->u.ap.probe_resp); + + bcn_offsets = beacon->cntdwn_counter_offsets; + count = beacon->cntdwn_current_counter; + if (link->conf->csa_active) + max_count = IEEE80211_MAX_CNTDWN_COUNTERS_NUM; + + for (i = 0; i < max_count; ++i) { + if (bcn_offsets[i]) { + if (WARN_ON_ONCE(bcn_offsets[i] >= beacon_data_len)) + return; + beacon_data[bcn_offsets[i]] = count; + } + + if (sdata->vif.type == NL80211_IFTYPE_AP && resp) { + u16 *resp_offsets = resp->cntdwn_counter_offsets; + + resp->data[resp_offsets[i]] = count; + } + } +} + +static u8 __ieee80211_beacon_update_cntdwn(struct beacon_data *beacon) +{ + beacon->cntdwn_current_counter--; + + /* the counter should never reach 0 */ + WARN_ON_ONCE(!beacon->cntdwn_current_counter); + + return beacon->cntdwn_current_counter; +} + +u8 ieee80211_beacon_update_cntdwn(struct ieee80211_vif *vif) +{ + struct 
ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct beacon_data *beacon = NULL; + u8 count = 0; + + rcu_read_lock(); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + beacon = rcu_dereference(sdata->deflink.u.ap.beacon); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + beacon = rcu_dereference(sdata->u.ibss.presp); + else if (ieee80211_vif_is_mesh(&sdata->vif)) + beacon = rcu_dereference(sdata->u.mesh.beacon); + + if (!beacon) + goto unlock; + + count = __ieee80211_beacon_update_cntdwn(beacon); + +unlock: + rcu_read_unlock(); + return count; +} +EXPORT_SYMBOL(ieee80211_beacon_update_cntdwn); + +void ieee80211_beacon_set_cntdwn(struct ieee80211_vif *vif, u8 counter) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct beacon_data *beacon = NULL; + + rcu_read_lock(); + + if (sdata->vif.type == NL80211_IFTYPE_AP) + beacon = rcu_dereference(sdata->deflink.u.ap.beacon); + else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) + beacon = rcu_dereference(sdata->u.ibss.presp); + else if (ieee80211_vif_is_mesh(&sdata->vif)) + beacon = rcu_dereference(sdata->u.mesh.beacon); + + if (!beacon) + goto unlock; + + if (counter < beacon->cntdwn_current_counter) + beacon->cntdwn_current_counter = counter; + +unlock: + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee80211_beacon_set_cntdwn); + +bool ieee80211_beacon_cntdwn_is_complete(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct beacon_data *beacon = NULL; + u8 *beacon_data; + size_t beacon_data_len; + int ret = false; + + if (!ieee80211_sdata_running(sdata)) + return false; + + rcu_read_lock(); + if (vif->type == NL80211_IFTYPE_AP) { + beacon = rcu_dereference(sdata->deflink.u.ap.beacon); + if (WARN_ON(!beacon || !beacon->tail)) + goto out; + beacon_data = beacon->tail; + beacon_data_len = beacon->tail_len; + } else if (vif->type == NL80211_IFTYPE_ADHOC) { + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + + beacon = rcu_dereference(ifibss->presp); + if (!beacon) + goto out; + + beacon_data = beacon->head; + beacon_data_len = beacon->head_len; + } else if (vif->type == NL80211_IFTYPE_MESH_POINT) { + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + beacon = rcu_dereference(ifmsh->beacon); + if (!beacon) + goto out; + + beacon_data = beacon->head; + beacon_data_len = beacon->head_len; + } else { + WARN_ON(1); + goto out; + } + + if (!beacon->cntdwn_counter_offsets[0]) + goto out; + + if (WARN_ON_ONCE(beacon->cntdwn_counter_offsets[0] > beacon_data_len)) + goto out; + + if (beacon_data[beacon->cntdwn_counter_offsets[0]] == 1) + ret = true; + + out: + rcu_read_unlock(); + + return ret; +} +EXPORT_SYMBOL(ieee80211_beacon_cntdwn_is_complete); + +static int ieee80211_beacon_protect(struct sk_buff *skb, + struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link) +{ + ieee80211_tx_result res; + struct ieee80211_tx_data tx; + struct sk_buff *check_skb; + + memset(&tx, 0, sizeof(tx)); + tx.key = rcu_dereference(link->default_beacon_key); + if (!tx.key) + return 0; + tx.local = local; + tx.sdata = sdata; + __skb_queue_head_init(&tx.skbs); + __skb_queue_tail(&tx.skbs, skb); + res = ieee80211_tx_h_encrypt(&tx); + check_skb = __skb_dequeue(&tx.skbs); + /* we may crash after this, but it'd be a bug in crypto */ + WARN_ON(check_skb != skb); + if (WARN_ON_ONCE(res != TX_CONTINUE)) + return -EINVAL; + + return 0; +} + +static void +ieee80211_beacon_get_finish(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_link_data *link, + 
struct ieee80211_mutable_offsets *offs, + struct beacon_data *beacon, + struct sk_buff *skb, + struct ieee80211_chanctx_conf *chanctx_conf, + u16 csa_off_base) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_tx_info *info; + enum nl80211_band band; + struct ieee80211_tx_rate_control txrc; + + /* CSA offsets */ + if (offs && beacon) { + u16 i; + + for (i = 0; i < IEEE80211_MAX_CNTDWN_COUNTERS_NUM; i++) { + u16 csa_off = beacon->cntdwn_counter_offsets[i]; + + if (!csa_off) + continue; + + offs->cntdwn_counter_offs[i] = csa_off_base + csa_off; + } + } + + band = chanctx_conf->def.chan->band; + info = IEEE80211_SKB_CB(skb); + info->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + info->flags |= IEEE80211_TX_CTL_NO_ACK; + info->band = band; + + memset(&txrc, 0, sizeof(txrc)); + txrc.hw = hw; + txrc.sband = local->hw.wiphy->bands[band]; + txrc.bss_conf = link->conf; + txrc.skb = skb; + txrc.reported_rate.idx = -1; + if (sdata->beacon_rate_set && sdata->beacon_rateidx_mask[band]) + txrc.rate_idx_mask = sdata->beacon_rateidx_mask[band]; + else + txrc.rate_idx_mask = sdata->rc_rateidx_mask[band]; + txrc.bss = true; + rate_control_get_rate(sdata, NULL, &txrc); + + info->control.vif = vif; + info->control.flags |= u32_encode_bits(link->link_id, + IEEE80211_TX_CTRL_MLO_LINK); + info->flags |= IEEE80211_TX_CTL_CLEAR_PS_FILT | + IEEE80211_TX_CTL_ASSIGN_SEQ | + IEEE80211_TX_CTL_FIRST_FRAGMENT; +} + +static void +ieee80211_beacon_add_mbssid(struct sk_buff *skb, struct beacon_data *beacon) +{ + int i; + + if (!beacon->mbssid_ies) + return; + + for (i = 0; i < beacon->mbssid_ies->cnt; i++) + skb_put_data(skb, beacon->mbssid_ies->elem[i].data, + beacon->mbssid_ies->elem[i].len); +} + +static struct sk_buff * +ieee80211_beacon_get_ap(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_link_data *link, + struct ieee80211_mutable_offsets *offs, + bool is_template, + struct beacon_data *beacon, + struct ieee80211_chanctx_conf *chanctx_conf) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_if_ap *ap = &sdata->u.ap; + struct sk_buff *skb = NULL; + u16 csa_off_base = 0; + int mbssid_len; + + if (beacon->cntdwn_counter_offsets[0]) { + if (!is_template) + ieee80211_beacon_update_cntdwn(vif); + + ieee80211_set_beacon_cntdwn(sdata, beacon, link); + } + + /* headroom, head length, + * tail length, maximum TIM length and multiple BSSID length + */ + mbssid_len = ieee80211_get_mbssid_beacon_len(beacon->mbssid_ies); + skb = dev_alloc_skb(local->tx_headroom + beacon->head_len + + beacon->tail_len + 256 + + local->hw.extra_beacon_tailroom + mbssid_len); + if (!skb) + return NULL; + + skb_reserve(skb, local->tx_headroom); + skb_put_data(skb, beacon->head, beacon->head_len); + + ieee80211_beacon_add_tim(sdata, link, &ap->ps, skb, is_template); + + if (offs) { + offs->tim_offset = beacon->head_len; + offs->tim_length = skb->len - beacon->head_len; + offs->cntdwn_counter_offs[0] = beacon->cntdwn_counter_offsets[0]; + + if (mbssid_len) { + ieee80211_beacon_add_mbssid(skb, beacon); + offs->mbssid_off = skb->len - mbssid_len; + } + + /* for AP the csa offsets are from tail */ + csa_off_base = skb->len; + } + + if (beacon->tail) + skb_put_data(skb, beacon->tail, beacon->tail_len); + + if (ieee80211_beacon_protect(skb, local, sdata, link) < 0) + return NULL; + + ieee80211_beacon_get_finish(hw, vif, link, offs, beacon, skb, + chanctx_conf, csa_off_base); 
+ return skb; +} + +static struct sk_buff * +__ieee80211_beacon_get(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_mutable_offsets *offs, + bool is_template, + unsigned int link_id) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct beacon_data *beacon = NULL; + struct sk_buff *skb = NULL; + struct ieee80211_sub_if_data *sdata = NULL; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_link_data *link; + + rcu_read_lock(); + + sdata = vif_to_sdata(vif); + link = rcu_dereference(sdata->link[link_id]); + if (!link) + goto out; + chanctx_conf = + rcu_dereference(link->conf->chanctx_conf); + + if (!ieee80211_sdata_running(sdata) || !chanctx_conf) + goto out; + + if (offs) + memset(offs, 0, sizeof(*offs)); + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + beacon = rcu_dereference(link->u.ap.beacon); + if (!beacon) + goto out; + + skb = ieee80211_beacon_get_ap(hw, vif, link, offs, is_template, + beacon, chanctx_conf); + } else if (sdata->vif.type == NL80211_IFTYPE_ADHOC) { + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + struct ieee80211_hdr *hdr; + + beacon = rcu_dereference(ifibss->presp); + if (!beacon) + goto out; + + if (beacon->cntdwn_counter_offsets[0]) { + if (!is_template) + __ieee80211_beacon_update_cntdwn(beacon); + + ieee80211_set_beacon_cntdwn(sdata, beacon, link); + } + + skb = dev_alloc_skb(local->tx_headroom + beacon->head_len + + local->hw.extra_beacon_tailroom); + if (!skb) + goto out; + skb_reserve(skb, local->tx_headroom); + skb_put_data(skb, beacon->head, beacon->head_len); + + hdr = (struct ieee80211_hdr *) skb->data; + hdr->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_BEACON); + + ieee80211_beacon_get_finish(hw, vif, link, offs, beacon, skb, + chanctx_conf, 0); + } else if (ieee80211_vif_is_mesh(&sdata->vif)) { + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + beacon = rcu_dereference(ifmsh->beacon); + if (!beacon) + goto out; + + if (beacon->cntdwn_counter_offsets[0]) { + if (!is_template) + /* TODO: For mesh csa_counter is in TU, so + * decrementing it by one isn't correct, but + * for now we leave it consistent with overall + * mac80211's behavior. 
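+ * (One TU is 1024 microseconds and a beacon interval is typically
+ * on the order of 100 TU, so decrementing once per beacon does not
+ * track a TU-based counter.)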
+ */ + __ieee80211_beacon_update_cntdwn(beacon); + + ieee80211_set_beacon_cntdwn(sdata, beacon, link); + } + + if (ifmsh->sync_ops) + ifmsh->sync_ops->adjust_tsf(sdata, beacon); + + skb = dev_alloc_skb(local->tx_headroom + + beacon->head_len + + 256 + /* TIM IE */ + beacon->tail_len + + local->hw.extra_beacon_tailroom); + if (!skb) + goto out; + skb_reserve(skb, local->tx_headroom); + skb_put_data(skb, beacon->head, beacon->head_len); + ieee80211_beacon_add_tim(sdata, link, &ifmsh->ps, skb, + is_template); + + if (offs) { + offs->tim_offset = beacon->head_len; + offs->tim_length = skb->len - beacon->head_len; + } + + skb_put_data(skb, beacon->tail, beacon->tail_len); + ieee80211_beacon_get_finish(hw, vif, link, offs, beacon, skb, + chanctx_conf, 0); + } else { + WARN_ON(1); + goto out; + } + + out: + rcu_read_unlock(); + return skb; + +} + +struct sk_buff * +ieee80211_beacon_get_template(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + struct ieee80211_mutable_offsets *offs, + unsigned int link_id) +{ + return __ieee80211_beacon_get(hw, vif, offs, true, link_id); +} +EXPORT_SYMBOL(ieee80211_beacon_get_template); + +struct sk_buff *ieee80211_beacon_get_tim(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + u16 *tim_offset, u16 *tim_length, + unsigned int link_id) +{ + struct ieee80211_mutable_offsets offs = {}; + struct sk_buff *bcn = __ieee80211_beacon_get(hw, vif, &offs, false, + link_id); + struct sk_buff *copy; + int shift; + + if (!bcn) + return bcn; + + if (tim_offset) + *tim_offset = offs.tim_offset; + + if (tim_length) + *tim_length = offs.tim_length; + + if (ieee80211_hw_check(hw, BEACON_TX_STATUS) || + !hw_to_local(hw)->monitors) + return bcn; + + /* send a copy to monitor interfaces */ + copy = skb_copy(bcn, GFP_ATOMIC); + if (!copy) + return bcn; + + shift = ieee80211_vif_get_shift(vif); + ieee80211_tx_monitor(hw_to_local(hw), copy, 1, shift, false, NULL); + + return bcn; +} +EXPORT_SYMBOL(ieee80211_beacon_get_tim); + +struct sk_buff *ieee80211_proberesp_get(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct sk_buff *skb = NULL; + struct probe_resp *presp = NULL; + struct ieee80211_hdr *hdr; + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (sdata->vif.type != NL80211_IFTYPE_AP) + return NULL; + + rcu_read_lock(); + presp = rcu_dereference(sdata->deflink.u.ap.probe_resp); + if (!presp) + goto out; + + skb = dev_alloc_skb(presp->len); + if (!skb) + goto out; + + skb_put_data(skb, presp->data, presp->len); + + hdr = (struct ieee80211_hdr *) skb->data; + memset(hdr->addr1, 0, sizeof(hdr->addr1)); + +out: + rcu_read_unlock(); + return skb; +} +EXPORT_SYMBOL(ieee80211_proberesp_get); + +struct sk_buff *ieee80211_get_fils_discovery_tmpl(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct sk_buff *skb = NULL; + struct fils_discovery_data *tmpl = NULL; + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (sdata->vif.type != NL80211_IFTYPE_AP) + return NULL; + + rcu_read_lock(); + tmpl = rcu_dereference(sdata->deflink.u.ap.fils_discovery); + if (!tmpl) { + rcu_read_unlock(); + return NULL; + } + + skb = dev_alloc_skb(sdata->local->hw.extra_tx_headroom + tmpl->len); + if (skb) { + skb_reserve(skb, sdata->local->hw.extra_tx_headroom); + skb_put_data(skb, tmpl->data, tmpl->len); + } + + rcu_read_unlock(); + return skb; +} +EXPORT_SYMBOL(ieee80211_get_fils_discovery_tmpl); + +struct sk_buff * +ieee80211_get_unsol_bcast_probe_resp_tmpl(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct sk_buff *skb = NULL; + 
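+ /* Template for the unsolicited broadcast probe responses a 6 GHz AP
+ * transmits for in-band discovery; drivers fetch it the same way as
+ * the FILS discovery template above. */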
struct unsol_bcast_probe_resp_data *tmpl = NULL; + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (sdata->vif.type != NL80211_IFTYPE_AP) + return NULL; + + rcu_read_lock(); + tmpl = rcu_dereference(sdata->deflink.u.ap.unsol_bcast_probe_resp); + if (!tmpl) { + rcu_read_unlock(); + return NULL; + } + + skb = dev_alloc_skb(sdata->local->hw.extra_tx_headroom + tmpl->len); + if (skb) { + skb_reserve(skb, sdata->local->hw.extra_tx_headroom); + skb_put_data(skb, tmpl->data, tmpl->len); + } + + rcu_read_unlock(); + return skb; +} +EXPORT_SYMBOL(ieee80211_get_unsol_bcast_probe_resp_tmpl); + +struct sk_buff *ieee80211_pspoll_get(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_pspoll *pspoll; + struct ieee80211_local *local; + struct sk_buff *skb; + + if (WARN_ON(vif->type != NL80211_IFTYPE_STATION)) + return NULL; + + sdata = vif_to_sdata(vif); + local = sdata->local; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + sizeof(*pspoll)); + if (!skb) + return NULL; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + pspoll = skb_put_zero(skb, sizeof(*pspoll)); + pspoll->frame_control = cpu_to_le16(IEEE80211_FTYPE_CTL | + IEEE80211_STYPE_PSPOLL); + pspoll->aid = cpu_to_le16(sdata->vif.cfg.aid); + + /* aid in PS-Poll has its two MSBs each set to 1 */ + pspoll->aid |= cpu_to_le16(1 << 15 | 1 << 14); + + memcpy(pspoll->bssid, sdata->deflink.u.mgd.bssid, ETH_ALEN); + memcpy(pspoll->ta, vif->addr, ETH_ALEN); + + return skb; +} +EXPORT_SYMBOL(ieee80211_pspoll_get); + +struct sk_buff *ieee80211_nullfunc_get(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + int link_id, bool qos_ok) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + struct ieee80211_local *local = sdata->local; + struct ieee80211_link_data *link = NULL; + struct ieee80211_hdr_3addr *nullfunc; + struct sk_buff *skb; + bool qos = false; + + if (WARN_ON(vif->type != NL80211_IFTYPE_STATION)) + return NULL; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + + sizeof(*nullfunc) + 2); + if (!skb) + return NULL; + + rcu_read_lock(); + if (qos_ok) { + struct sta_info *sta; + + sta = sta_info_get(sdata, vif->cfg.ap_addr); + qos = sta && sta->sta.wme; + } + + if (link_id >= 0) { + link = rcu_dereference(sdata->link[link_id]); + if (WARN_ON_ONCE(!link)) { + rcu_read_unlock(); + kfree_skb(skb); + return NULL; + } + } + + skb_reserve(skb, local->hw.extra_tx_headroom); + + nullfunc = skb_put_zero(skb, sizeof(*nullfunc)); + nullfunc->frame_control = cpu_to_le16(IEEE80211_FTYPE_DATA | + IEEE80211_STYPE_NULLFUNC | + IEEE80211_FCTL_TODS); + if (qos) { + __le16 qoshdr = cpu_to_le16(7); + + BUILD_BUG_ON((IEEE80211_STYPE_QOS_NULLFUNC | + IEEE80211_STYPE_NULLFUNC) != + IEEE80211_STYPE_QOS_NULLFUNC); + nullfunc->frame_control |= + cpu_to_le16(IEEE80211_STYPE_QOS_NULLFUNC); + skb->priority = 7; + skb_set_queue_mapping(skb, IEEE80211_AC_VO); + skb_put_data(skb, &qoshdr, sizeof(qoshdr)); + } + + if (link) { + memcpy(nullfunc->addr1, link->conf->bssid, ETH_ALEN); + memcpy(nullfunc->addr2, link->conf->addr, ETH_ALEN); + memcpy(nullfunc->addr3, link->conf->bssid, ETH_ALEN); + } else { + memcpy(nullfunc->addr1, vif->cfg.ap_addr, ETH_ALEN); + memcpy(nullfunc->addr2, vif->addr, ETH_ALEN); + memcpy(nullfunc->addr3, vif->cfg.ap_addr, ETH_ALEN); + } + rcu_read_unlock(); + + return skb; +} +EXPORT_SYMBOL(ieee80211_nullfunc_get); + +struct sk_buff *ieee80211_probereq_get(struct ieee80211_hw *hw, + const u8 *src_addr, + const u8 *ssid, size_t ssid_len, + size_t tailroom) 
+{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_hdr_3addr *hdr; + struct sk_buff *skb; + size_t ie_ssid_len; + u8 *pos; + + ie_ssid_len = 2 + ssid_len; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + sizeof(*hdr) + + ie_ssid_len + tailroom); + if (!skb) + return NULL; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + hdr = skb_put_zero(skb, sizeof(*hdr)); + hdr->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_PROBE_REQ); + eth_broadcast_addr(hdr->addr1); + memcpy(hdr->addr2, src_addr, ETH_ALEN); + eth_broadcast_addr(hdr->addr3); + + pos = skb_put(skb, ie_ssid_len); + *pos++ = WLAN_EID_SSID; + *pos++ = ssid_len; + if (ssid_len) + memcpy(pos, ssid, ssid_len); + pos += ssid_len; + + return skb; +} +EXPORT_SYMBOL(ieee80211_probereq_get); + +void ieee80211_rts_get(struct ieee80211_hw *hw, struct ieee80211_vif *vif, + const void *frame, size_t frame_len, + const struct ieee80211_tx_info *frame_txctl, + struct ieee80211_rts *rts) +{ + const struct ieee80211_hdr *hdr = frame; + + rts->frame_control = + cpu_to_le16(IEEE80211_FTYPE_CTL | IEEE80211_STYPE_RTS); + rts->duration = ieee80211_rts_duration(hw, vif, frame_len, + frame_txctl); + memcpy(rts->ra, hdr->addr1, sizeof(rts->ra)); + memcpy(rts->ta, hdr->addr2, sizeof(rts->ta)); +} +EXPORT_SYMBOL(ieee80211_rts_get); + +void ieee80211_ctstoself_get(struct ieee80211_hw *hw, struct ieee80211_vif *vif, + const void *frame, size_t frame_len, + const struct ieee80211_tx_info *frame_txctl, + struct ieee80211_cts *cts) +{ + const struct ieee80211_hdr *hdr = frame; + + cts->frame_control = + cpu_to_le16(IEEE80211_FTYPE_CTL | IEEE80211_STYPE_CTS); + cts->duration = ieee80211_ctstoself_duration(hw, vif, + frame_len, frame_txctl); + memcpy(cts->ra, hdr->addr1, sizeof(cts->ra)); +} +EXPORT_SYMBOL(ieee80211_ctstoself_get); + +struct sk_buff * +ieee80211_get_buffered_bc(struct ieee80211_hw *hw, + struct ieee80211_vif *vif) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct sk_buff *skb = NULL; + struct ieee80211_tx_data tx; + struct ieee80211_sub_if_data *sdata; + struct ps_data *ps; + struct ieee80211_tx_info *info; + struct ieee80211_chanctx_conf *chanctx_conf; + + sdata = vif_to_sdata(vif); + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + + if (!chanctx_conf) + goto out; + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + struct beacon_data *beacon = + rcu_dereference(sdata->deflink.u.ap.beacon); + + if (!beacon || !beacon->head) + goto out; + + ps = &sdata->u.ap.ps; + } else if (ieee80211_vif_is_mesh(&sdata->vif)) { + ps = &sdata->u.mesh.ps; + } else { + goto out; + } + + if (ps->dtim_count != 0 || !ps->dtim_bc_mc) + goto out; /* send buffered bc/mc only after DTIM beacon */ + + while (1) { + skb = skb_dequeue(&ps->bc_buf); + if (!skb) + goto out; + local->total_ps_buffered--; + + if (!skb_queue_empty(&ps->bc_buf) && skb->len >= 2) { + struct ieee80211_hdr *hdr = + (struct ieee80211_hdr *) skb->data; + /* more buffered multicast/broadcast frames ==> set + * MoreData flag in IEEE 802.11 header to inform PS + * STAs */ + hdr->frame_control |= + cpu_to_le16(IEEE80211_FCTL_MOREDATA); + } + + if (sdata->vif.type == NL80211_IFTYPE_AP) + sdata = IEEE80211_DEV_TO_SUB_IF(skb->dev); + if (!ieee80211_tx_prepare(sdata, &tx, NULL, skb)) + break; + ieee80211_free_txskb(hw, skb); + } + + info = IEEE80211_SKB_CB(skb); + + tx.flags |= IEEE80211_TX_PS_BUFFERED; + info->band = chanctx_conf->def.chan->band; + + if (invoke_tx_handlers(&tx)) + skb = NULL; + out: + 
rcu_read_unlock(); + + return skb; +} +EXPORT_SYMBOL(ieee80211_get_buffered_bc); + +int ieee80211_reserve_tid(struct ieee80211_sta *pubsta, u8 tid) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_sub_if_data *sdata = sta->sdata; + struct ieee80211_local *local = sdata->local; + int ret; + u32 queues; + + lockdep_assert_held(&local->sta_mtx); + + /* only some cases are supported right now */ + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + break; + default: + WARN_ON(1); + return -EINVAL; + } + + if (WARN_ON(tid >= IEEE80211_NUM_UPS)) + return -EINVAL; + + if (sta->reserved_tid == tid) { + ret = 0; + goto out; + } + + if (sta->reserved_tid != IEEE80211_TID_UNRESERVED) { + sdata_err(sdata, "TID reservation already active\n"); + ret = -EALREADY; + goto out; + } + + ieee80211_stop_vif_queues(sdata->local, sdata, + IEEE80211_QUEUE_STOP_REASON_RESERVE_TID); + + synchronize_net(); + + /* Tear down BA sessions so we stop aggregating on this TID */ + if (ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION)) { + set_sta_flag(sta, WLAN_STA_BLOCK_BA); + __ieee80211_stop_tx_ba_session(sta, tid, + AGG_STOP_LOCAL_REQUEST); + } + + queues = BIT(sdata->vif.hw_queue[ieee802_1d_to_ac[tid]]); + __ieee80211_flush_queues(local, sdata, queues, false); + + sta->reserved_tid = tid; + + ieee80211_wake_vif_queues(local, sdata, + IEEE80211_QUEUE_STOP_REASON_RESERVE_TID); + + if (ieee80211_hw_check(&local->hw, AMPDU_AGGREGATION)) + clear_sta_flag(sta, WLAN_STA_BLOCK_BA); + + ret = 0; + out: + return ret; +} +EXPORT_SYMBOL(ieee80211_reserve_tid); + +void ieee80211_unreserve_tid(struct ieee80211_sta *pubsta, u8 tid) +{ + struct sta_info *sta = container_of(pubsta, struct sta_info, sta); + struct ieee80211_sub_if_data *sdata = sta->sdata; + + lockdep_assert_held(&sdata->local->sta_mtx); + + /* only some cases are supported right now */ + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + break; + default: + WARN_ON(1); + return; + } + + if (tid != sta->reserved_tid) { + sdata_err(sdata, "TID to unreserve (%d) isn't reserved\n", tid); + return; + } + + sta->reserved_tid = IEEE80211_TID_UNRESERVED; +} +EXPORT_SYMBOL(ieee80211_unreserve_tid); + +void __ieee80211_tx_skb_tid_band(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, int tid, int link_id, + enum nl80211_band band) +{ + const struct ieee80211_hdr *hdr = (void *)skb->data; + int ac = ieee80211_ac_from_tid(tid); + unsigned int link; + + skb_reset_mac_header(skb); + skb_set_queue_mapping(skb, ac); + skb->priority = tid; + + skb->dev = sdata->dev; + + BUILD_BUG_ON(IEEE80211_LINK_UNSPECIFIED < IEEE80211_MLD_MAX_NUM_LINKS); + BUILD_BUG_ON(!FIELD_FIT(IEEE80211_TX_CTRL_MLO_LINK, + IEEE80211_LINK_UNSPECIFIED)); + + if (!sdata->vif.valid_links) { + link = 0; + } else if (link_id >= 0) { + link = link_id; + } else if (memcmp(sdata->vif.addr, hdr->addr2, ETH_ALEN) == 0) { + /* address from the MLD */ + link = IEEE80211_LINK_UNSPECIFIED; + } else { + /* otherwise must be addressed from a link */ + rcu_read_lock(); + for (link = 0; link < ARRAY_SIZE(sdata->vif.link_conf); link++) { + struct ieee80211_bss_conf *link_conf; + + link_conf = rcu_dereference(sdata->vif.link_conf[link]); + if (!link_conf) + continue; + if (memcmp(link_conf->addr, hdr->addr2, ETH_ALEN) == 0) + break; + } + rcu_read_unlock(); + + if (WARN_ON_ONCE(link == ARRAY_SIZE(sdata->vif.link_conf))) + link = 
ffs(sdata->vif.active_links) - 1; + } + + IEEE80211_SKB_CB(skb)->control.flags |= + u32_encode_bits(link, IEEE80211_TX_CTRL_MLO_LINK); + + /* + * The other path calling ieee80211_xmit is from the tasklet, + * and while we can handle concurrent transmissions locking + * requirements are that we do not come into tx with bhs on. + */ + local_bh_disable(); + IEEE80211_SKB_CB(skb)->band = band; + ieee80211_xmit(sdata, NULL, skb); + local_bh_enable(); +} + +void ieee80211_tx_skb_tid(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, int tid, int link_id) +{ + struct ieee80211_chanctx_conf *chanctx_conf; + enum nl80211_band band; + + rcu_read_lock(); + if (!sdata->vif.valid_links) { + WARN_ON(link_id >= 0); + chanctx_conf = + rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (WARN_ON(!chanctx_conf)) { + rcu_read_unlock(); + kfree_skb(skb); + return; + } + band = chanctx_conf->def.chan->band; + } else { + WARN_ON(link_id >= 0 && + !(sdata->vif.active_links & BIT(link_id))); + /* MLD transmissions must not rely on the band */ + band = 0; + } + + __ieee80211_tx_skb_tid_band(sdata, skb, tid, link_id, band); + rcu_read_unlock(); +} + +int ieee80211_tx_control_port(struct wiphy *wiphy, struct net_device *dev, + const u8 *buf, size_t len, + const u8 *dest, __be16 proto, bool unencrypted, + int link_id, u64 *cookie) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + struct sk_buff *skb; + struct ethhdr *ehdr; + u32 ctrl_flags = 0; + u32 flags = 0; + int err; + + /* Only accept CONTROL_PORT_PROTOCOL configured in CONNECT/ASSOCIATE + * or Pre-Authentication + */ + if (proto != sdata->control_port_protocol && + proto != cpu_to_be16(ETH_P_PREAUTH)) + return -EINVAL; + + if (proto == sdata->control_port_protocol) + ctrl_flags |= IEEE80211_TX_CTRL_PORT_CTRL_PROTO | + IEEE80211_TX_CTRL_SKIP_MPATH_LOOKUP; + + if (unencrypted) + flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + + if (cookie) + ctrl_flags |= IEEE80211_TX_CTL_REQ_TX_STATUS; + + flags |= IEEE80211_TX_INTFL_NL80211_FRAME_TX; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + + sizeof(struct ethhdr) + len); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, local->hw.extra_tx_headroom + sizeof(struct ethhdr)); + + skb_put_data(skb, buf, len); + + ehdr = skb_push(skb, sizeof(struct ethhdr)); + memcpy(ehdr->h_dest, dest, ETH_ALEN); + + /* we may override the SA for MLO STA later */ + if (link_id < 0) { + ctrl_flags |= u32_encode_bits(IEEE80211_LINK_UNSPECIFIED, + IEEE80211_TX_CTRL_MLO_LINK); + memcpy(ehdr->h_source, sdata->vif.addr, ETH_ALEN); + } else { + struct ieee80211_bss_conf *link_conf; + + ctrl_flags |= u32_encode_bits(link_id, + IEEE80211_TX_CTRL_MLO_LINK); + + rcu_read_lock(); + link_conf = rcu_dereference(sdata->vif.link_conf[link_id]); + if (!link_conf) { + dev_kfree_skb(skb); + rcu_read_unlock(); + return -ENOLINK; + } + memcpy(ehdr->h_source, link_conf->addr, ETH_ALEN); + rcu_read_unlock(); + } + + ehdr->h_proto = proto; + + skb->dev = dev; + skb->protocol = proto; + skb_reset_network_header(skb); + skb_reset_mac_header(skb); + + if (local->hw.queues < IEEE80211_NUM_ACS) + goto start_xmit; + + /* update QoS header to prioritize control port frames if possible, + * priorization also happens for control port frames send over + * AF_PACKET + */ + rcu_read_lock(); + err = ieee80211_lookup_ra_sta(sdata, skb, &sta); + if (err) { + dev_kfree_skb(skb); + rcu_read_unlock(); + return err; + } + + if (!IS_ERR(sta)) { + u16 queue = 
__ieee80211_select_queue(sdata, sta, skb); + + skb_set_queue_mapping(skb, queue); + skb_get_hash(skb); + + /* + * for MLO STA, the SA should be the AP MLD address, but + * the link ID has been selected already + */ + if (sta && sta->sta.mlo) + memcpy(ehdr->h_source, sdata->vif.addr, ETH_ALEN); + } + rcu_read_unlock(); + +start_xmit: + /* mutex lock is only needed for incrementing the cookie counter */ + mutex_lock(&local->mtx); + + local_bh_disable(); + __ieee80211_subif_start_xmit(skb, skb->dev, flags, ctrl_flags, cookie); + local_bh_enable(); + + mutex_unlock(&local->mtx); + + return 0; +} + +int ieee80211_probe_mesh_link(struct wiphy *wiphy, struct net_device *dev, + const u8 *buf, size_t len) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_DEV_TO_SUB_IF(dev); + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + + skb = dev_alloc_skb(local->hw.extra_tx_headroom + len + + 30 + /* header size */ + 18); /* 11s header size */ + if (!skb) + return -ENOMEM; + + skb_reserve(skb, local->hw.extra_tx_headroom); + skb_put_data(skb, buf, len); + + skb->dev = dev; + skb->protocol = htons(ETH_P_802_3); + skb_reset_network_header(skb); + skb_reset_mac_header(skb); + + local_bh_disable(); + __ieee80211_subif_start_xmit(skb, skb->dev, 0, + IEEE80211_TX_CTRL_SKIP_MPATH_LOOKUP, + NULL); + local_bh_enable(); + + return 0; +} diff --git a/net/mac80211/util.c b/net/mac80211/util.c new file mode 100644 index 000000000..1088d90e3 --- /dev/null +++ b/net/mac80211/util.c @@ -0,0 +1,4860 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2006-2007 Jiri Benc + * Copyright 2007 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + * + * utilities for mac80211 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "driver-ops.h" +#include "rate.h" +#include "mesh.h" +#include "wme.h" +#include "led.h" +#include "wep.h" + +/* privid for wiphys to determine whether they belong to us or not */ +const void *const mac80211_wiphy_privid = &mac80211_wiphy_privid; + +struct ieee80211_hw *wiphy_to_ieee80211_hw(struct wiphy *wiphy) +{ + struct ieee80211_local *local; + + local = wiphy_priv(wiphy); + return &local->hw; +} +EXPORT_SYMBOL(wiphy_to_ieee80211_hw); + +u8 *ieee80211_get_bssid(struct ieee80211_hdr *hdr, size_t len, + enum nl80211_iftype type) +{ + __le16 fc = hdr->frame_control; + + if (ieee80211_is_data(fc)) { + if (len < 24) /* drop incorrect hdr len (data) */ + return NULL; + + if (ieee80211_has_a4(fc)) + return NULL; + if (ieee80211_has_tods(fc)) + return hdr->addr1; + if (ieee80211_has_fromds(fc)) + return hdr->addr2; + + return hdr->addr3; + } + + if (ieee80211_is_s1g_beacon(fc)) { + struct ieee80211_ext *ext = (void *) hdr; + + return ext->u.s1g_beacon.sa; + } + + if (ieee80211_is_mgmt(fc)) { + if (len < 24) /* drop incorrect hdr len (mgmt) */ + return NULL; + return hdr->addr3; + } + + if (ieee80211_is_ctl(fc)) { + if (ieee80211_is_pspoll(fc)) + return hdr->addr1; + + if (ieee80211_is_back_req(fc)) { + switch (type) { + case NL80211_IFTYPE_STATION: + return hdr->addr2; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + return hdr->addr1; + default: + break; /* fall through to the return */ + } + } + } + + return NULL; +} 
+EXPORT_SYMBOL(ieee80211_get_bssid); + +void ieee80211_tx_set_protected(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + struct ieee80211_hdr *hdr; + + skb_queue_walk(&tx->skbs, skb) { + hdr = (struct ieee80211_hdr *) skb->data; + hdr->frame_control |= cpu_to_le16(IEEE80211_FCTL_PROTECTED); + } +} + +int ieee80211_frame_duration(enum nl80211_band band, size_t len, + int rate, int erp, int short_preamble, + int shift) +{ + int dur; + + /* calculate duration (in microseconds, rounded up to next higher + * integer if it includes a fractional microsecond) to send frame of + * len bytes (does not include FCS) at the given rate. Duration will + * also include SIFS. + * + * rate is in 100 kbps, so divident is multiplied by 10 in the + * DIV_ROUND_UP() operations. + * + * shift may be 2 for 5 MHz channels or 1 for 10 MHz channels, and + * is assumed to be 0 otherwise. + */ + + if (band == NL80211_BAND_5GHZ || erp) { + /* + * OFDM: + * + * N_DBPS = DATARATE x 4 + * N_SYM = Ceiling((16+8xLENGTH+6) / N_DBPS) + * (16 = SIGNAL time, 6 = tail bits) + * TXTIME = T_PREAMBLE + T_SIGNAL + T_SYM x N_SYM + Signal Ext + * + * T_SYM = 4 usec + * 802.11a - 18.5.2: aSIFSTime = 16 usec + * 802.11g - 19.8.4: aSIFSTime = 10 usec + + * signal ext = 6 usec + */ + dur = 16; /* SIFS + signal ext */ + dur += 16; /* IEEE 802.11-2012 18.3.2.4: T_PREAMBLE = 16 usec */ + dur += 4; /* IEEE 802.11-2012 18.3.2.4: T_SIGNAL = 4 usec */ + + /* IEEE 802.11-2012 18.3.2.4: all values above are: + * * times 4 for 5 MHz + * * times 2 for 10 MHz + */ + dur *= 1 << shift; + + /* rates should already consider the channel bandwidth, + * don't apply divisor again. + */ + dur += 4 * DIV_ROUND_UP((16 + 8 * (len + 4) + 6) * 10, + 4 * rate); /* T_SYM x N_SYM */ + } else { + /* + * 802.11b or 802.11g with 802.11b compatibility: + * 18.3.4: TXTIME = PreambleLength + PLCPHeaderTime + + * Ceiling(((LENGTH+PBCC)x8)/DATARATE). PBCC=0. + * + * 802.11 (DS): 15.3.3, 802.11b: 18.3.4 + * aSIFSTime = 10 usec + * aPreambleLength = 144 usec or 72 usec with short preamble + * aPLCPHeaderLength = 48 usec or 24 usec with short preamble + */ + dur = 10; /* aSIFSTime = 10 usec */ + dur += short_preamble ? 
(72 + 24) : (144 + 48); + + dur += DIV_ROUND_UP(8 * (len + 4) * 10, rate); + } + + return dur; +} + +/* Exported duration function for driver use */ +__le16 ieee80211_generic_frame_duration(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + enum nl80211_band band, + size_t frame_len, + struct ieee80211_rate *rate) +{ + struct ieee80211_sub_if_data *sdata; + u16 dur; + int erp, shift = 0; + bool short_preamble = false; + + erp = 0; + if (vif) { + sdata = vif_to_sdata(vif); + short_preamble = sdata->vif.bss_conf.use_short_preamble; + if (sdata->deflink.operating_11g_mode) + erp = rate->flags & IEEE80211_RATE_ERP_G; + shift = ieee80211_vif_get_shift(vif); + } + + dur = ieee80211_frame_duration(band, frame_len, rate->bitrate, erp, + short_preamble, shift); + + return cpu_to_le16(dur); +} +EXPORT_SYMBOL(ieee80211_generic_frame_duration); + +__le16 ieee80211_rts_duration(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, size_t frame_len, + const struct ieee80211_tx_info *frame_txctl) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_rate *rate; + struct ieee80211_sub_if_data *sdata; + bool short_preamble; + int erp, shift = 0, bitrate; + u16 dur; + struct ieee80211_supported_band *sband; + + sband = local->hw.wiphy->bands[frame_txctl->band]; + + short_preamble = false; + + rate = &sband->bitrates[frame_txctl->control.rts_cts_rate_idx]; + + erp = 0; + if (vif) { + sdata = vif_to_sdata(vif); + short_preamble = sdata->vif.bss_conf.use_short_preamble; + if (sdata->deflink.operating_11g_mode) + erp = rate->flags & IEEE80211_RATE_ERP_G; + shift = ieee80211_vif_get_shift(vif); + } + + bitrate = DIV_ROUND_UP(rate->bitrate, 1 << shift); + + /* CTS duration */ + dur = ieee80211_frame_duration(sband->band, 10, bitrate, + erp, short_preamble, shift); + /* Data frame duration */ + dur += ieee80211_frame_duration(sband->band, frame_len, bitrate, + erp, short_preamble, shift); + /* ACK duration */ + dur += ieee80211_frame_duration(sband->band, 10, bitrate, + erp, short_preamble, shift); + + return cpu_to_le16(dur); +} +EXPORT_SYMBOL(ieee80211_rts_duration); + +__le16 ieee80211_ctstoself_duration(struct ieee80211_hw *hw, + struct ieee80211_vif *vif, + size_t frame_len, + const struct ieee80211_tx_info *frame_txctl) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_rate *rate; + struct ieee80211_sub_if_data *sdata; + bool short_preamble; + int erp, shift = 0, bitrate; + u16 dur; + struct ieee80211_supported_band *sband; + + sband = local->hw.wiphy->bands[frame_txctl->band]; + + short_preamble = false; + + rate = &sband->bitrates[frame_txctl->control.rts_cts_rate_idx]; + erp = 0; + if (vif) { + sdata = vif_to_sdata(vif); + short_preamble = sdata->vif.bss_conf.use_short_preamble; + if (sdata->deflink.operating_11g_mode) + erp = rate->flags & IEEE80211_RATE_ERP_G; + shift = ieee80211_vif_get_shift(vif); + } + + bitrate = DIV_ROUND_UP(rate->bitrate, 1 << shift); + + /* Data frame duration */ + dur = ieee80211_frame_duration(sband->band, frame_len, bitrate, + erp, short_preamble, shift); + if (!(frame_txctl->flags & IEEE80211_TX_CTL_NO_ACK)) { + /* ACK duration */ + dur += ieee80211_frame_duration(sband->band, 10, bitrate, + erp, short_preamble, shift); + } + + return cpu_to_le16(dur); +} +EXPORT_SYMBOL(ieee80211_ctstoself_duration); + +static void __ieee80211_wake_txqs(struct ieee80211_sub_if_data *sdata, int ac) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_vif *vif = &sdata->vif; + struct fq *fq = &local->fq; + struct ps_data *ps = 
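At shift 0, the OFDM branch of ieee80211_frame_duration() above reduces to TXTIME = 16 (SIFS plus signal extension) + 16 (preamble) + 4 (SIGNAL) + 4 * ceil(service + payload + tail bits / (4 * rate in Mbps)) microseconds. A standalone sketch of that arithmetic, with the rate in 100 kbps units as in the kernel helper; a 10-byte ACK body at 24 Mbps comes out to 44 us.

/* Stand-alone model of the OFDM/ERP duration arithmetic used by
 * ieee80211_frame_duration() (shift == 0, i.e. 20 MHz channels). rate is in
 * 100 kbps units, len excludes the FCS, the result is in usec and includes
 * SIFS, matching the kernel helper.
 */
#include <stdio.h>

#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

static int ofdm_frame_duration(int len, int rate)
{
    int dur = 16;   /* SIFS + signal ext (2.4 GHz ERP) or 16 us aSIFSTime */

    dur += 16;      /* T_PREAMBLE */
    dur += 4;       /* T_SIGNAL */
    /* 16 service bits + 6 tail bits + payload incl. 4-byte FCS */
    dur += 4 * DIV_ROUND_UP((16 + 8 * (len + 4) + 6) * 10, 4 * rate);

    return dur;
}

int main(void)
{
    /* ACK-sized body (10 bytes + FCS) at 24 Mbps -> 44 us */
    printf("ACK @ 24 Mbps: %d us\n", ofdm_frame_duration(10, 240));
    printf("1500 B @ 54 Mbps: %d us\n", ofdm_frame_duration(1500, 540));
    return 0;
}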
NULL; + struct txq_info *txqi; + struct sta_info *sta; + int i; + + local_bh_disable(); + spin_lock(&fq->lock); + + if (!test_bit(SDATA_STATE_RUNNING, &sdata->state)) + goto out; + + if (sdata->vif.type == NL80211_IFTYPE_AP) + ps = &sdata->bss->ps; + + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (sdata != sta->sdata) + continue; + + for (i = 0; i < ARRAY_SIZE(sta->sta.txq); i++) { + struct ieee80211_txq *txq = sta->sta.txq[i]; + + if (!txq) + continue; + + txqi = to_txq_info(txq); + + if (ac != txq->ac) + continue; + + if (!test_and_clear_bit(IEEE80211_TXQ_DIRTY, + &txqi->flags)) + continue; + + spin_unlock(&fq->lock); + drv_wake_tx_queue(local, txqi); + spin_lock(&fq->lock); + } + } + + if (!vif->txq) + goto out; + + txqi = to_txq_info(vif->txq); + + if (!test_and_clear_bit(IEEE80211_TXQ_DIRTY, &txqi->flags) || + (ps && atomic_read(&ps->num_sta_ps)) || ac != vif->txq->ac) + goto out; + + spin_unlock(&fq->lock); + + drv_wake_tx_queue(local, txqi); + local_bh_enable(); + return; +out: + spin_unlock(&fq->lock); + local_bh_enable(); +} + +static void +__releases(&local->queue_stop_reason_lock) +__acquires(&local->queue_stop_reason_lock) +_ieee80211_wake_txqs(struct ieee80211_local *local, unsigned long *flags) +{ + struct ieee80211_sub_if_data *sdata; + int n_acs = IEEE80211_NUM_ACS; + int i; + + rcu_read_lock(); + + if (local->hw.queues < IEEE80211_NUM_ACS) + n_acs = 1; + + for (i = 0; i < local->hw.queues; i++) { + if (local->queue_stop_reasons[i]) + continue; + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, *flags); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + int ac; + + for (ac = 0; ac < n_acs; ac++) { + int ac_queue = sdata->vif.hw_queue[ac]; + + if (ac_queue == i || + sdata->vif.cab_queue == i) + __ieee80211_wake_txqs(sdata, ac); + } + } + spin_lock_irqsave(&local->queue_stop_reason_lock, *flags); + } + + rcu_read_unlock(); +} + +void ieee80211_wake_txqs(struct tasklet_struct *t) +{ + struct ieee80211_local *local = from_tasklet(local, t, + wake_txqs_tasklet); + unsigned long flags; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + _ieee80211_wake_txqs(local, &flags); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_propagate_queue_wake(struct ieee80211_local *local, int queue) +{ + struct ieee80211_sub_if_data *sdata; + int n_acs = IEEE80211_NUM_ACS; + + if (local->ops->wake_tx_queue) + return; + + if (local->hw.queues < IEEE80211_NUM_ACS) + n_acs = 1; + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + int ac; + + if (!sdata->dev) + continue; + + if (sdata->vif.cab_queue != IEEE80211_INVAL_HW_QUEUE && + local->queue_stop_reasons[sdata->vif.cab_queue] != 0) + continue; + + for (ac = 0; ac < n_acs; ac++) { + int ac_queue = sdata->vif.hw_queue[ac]; + + if (ac_queue == queue || + (sdata->vif.cab_queue == queue && + local->queue_stop_reasons[ac_queue] == 0 && + skb_queue_empty(&local->pending[ac_queue]))) + netif_wake_subqueue(sdata->dev, ac); + } + } +} + +static void __ieee80211_wake_queue(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted, + unsigned long *flags) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_wake_queue(local, queue, reason); + + if (WARN_ON(queue >= hw->queues)) + return; + + if (!test_bit(reason, &local->queue_stop_reasons[queue])) + return; + + if (!refcounted) { + local->q_stop_reasons[queue][reason] = 0; + } else { + local->q_stop_reasons[queue][reason]--; + if (WARN_ON(local->q_stop_reasons[queue][reason] 
< 0)) + local->q_stop_reasons[queue][reason] = 0; + } + + if (local->q_stop_reasons[queue][reason] == 0) + __clear_bit(reason, &local->queue_stop_reasons[queue]); + + if (local->queue_stop_reasons[queue] != 0) + /* someone still has this queue stopped */ + return; + + if (skb_queue_empty(&local->pending[queue])) { + rcu_read_lock(); + ieee80211_propagate_queue_wake(local, queue); + rcu_read_unlock(); + } else + tasklet_schedule(&local->tx_pending_tasklet); + + /* + * Calling _ieee80211_wake_txqs here can be a problem because it may + * release queue_stop_reason_lock which has been taken by + * __ieee80211_wake_queue's caller. It is certainly not very nice to + * release someone's lock, but it is fine because all the callers of + * __ieee80211_wake_queue call it right before releasing the lock. + */ + if (local->ops->wake_tx_queue) { + if (reason == IEEE80211_QUEUE_STOP_REASON_DRIVER) + tasklet_schedule(&local->wake_txqs_tasklet); + else + _ieee80211_wake_txqs(local, flags); + } +} + +void ieee80211_wake_queue_by_reason(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted) +{ + struct ieee80211_local *local = hw_to_local(hw); + unsigned long flags; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + __ieee80211_wake_queue(hw, queue, reason, refcounted, &flags); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_wake_queue(struct ieee80211_hw *hw, int queue) +{ + ieee80211_wake_queue_by_reason(hw, queue, + IEEE80211_QUEUE_STOP_REASON_DRIVER, + false); +} +EXPORT_SYMBOL(ieee80211_wake_queue); + +static void __ieee80211_stop_queue(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted) +{ + struct ieee80211_local *local = hw_to_local(hw); + struct ieee80211_sub_if_data *sdata; + int n_acs = IEEE80211_NUM_ACS; + + trace_stop_queue(local, queue, reason); + + if (WARN_ON(queue >= hw->queues)) + return; + + if (!refcounted) + local->q_stop_reasons[queue][reason] = 1; + else + local->q_stop_reasons[queue][reason]++; + + if (__test_and_set_bit(reason, &local->queue_stop_reasons[queue])) + return; + + if (local->hw.queues < IEEE80211_NUM_ACS) + n_acs = 1; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + int ac; + + if (!sdata->dev) + continue; + + for (ac = 0; ac < n_acs; ac++) { + if (!local->ops->wake_tx_queue && + (sdata->vif.hw_queue[ac] == queue || + sdata->vif.cab_queue == queue)) + netif_stop_subqueue(sdata->dev, ac); + } + } + rcu_read_unlock(); +} + +void ieee80211_stop_queue_by_reason(struct ieee80211_hw *hw, int queue, + enum queue_stop_reason reason, + bool refcounted) +{ + struct ieee80211_local *local = hw_to_local(hw); + unsigned long flags; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + __ieee80211_stop_queue(hw, queue, reason, refcounted); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_stop_queue(struct ieee80211_hw *hw, int queue) +{ + ieee80211_stop_queue_by_reason(hw, queue, + IEEE80211_QUEUE_STOP_REASON_DRIVER, + false); +} +EXPORT_SYMBOL(ieee80211_stop_queue); + +void ieee80211_add_pending_skb(struct ieee80211_local *local, + struct sk_buff *skb) +{ + struct ieee80211_hw *hw = &local->hw; + unsigned long flags; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int queue = info->hw_queue; + + if (WARN_ON(!info->control.vif)) { + ieee80211_free_txskb(&local->hw, skb); + return; + } + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + __ieee80211_stop_queue(hw, 
queue, IEEE80211_QUEUE_STOP_REASON_SKB_ADD, + false); + __skb_queue_tail(&local->pending[queue], skb); + __ieee80211_wake_queue(hw, queue, IEEE80211_QUEUE_STOP_REASON_SKB_ADD, + false, &flags); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_add_pending_skbs(struct ieee80211_local *local, + struct sk_buff_head *skbs) +{ + struct ieee80211_hw *hw = &local->hw; + struct sk_buff *skb; + unsigned long flags; + int queue, i; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + while ((skb = skb_dequeue(skbs))) { + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + + if (WARN_ON(!info->control.vif)) { + ieee80211_free_txskb(&local->hw, skb); + continue; + } + + queue = info->hw_queue; + + __ieee80211_stop_queue(hw, queue, + IEEE80211_QUEUE_STOP_REASON_SKB_ADD, + false); + + __skb_queue_tail(&local->pending[queue], skb); + } + + for (i = 0; i < hw->queues; i++) + __ieee80211_wake_queue(hw, i, + IEEE80211_QUEUE_STOP_REASON_SKB_ADD, + false, &flags); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_stop_queues_by_reason(struct ieee80211_hw *hw, + unsigned long queues, + enum queue_stop_reason reason, + bool refcounted) +{ + struct ieee80211_local *local = hw_to_local(hw); + unsigned long flags; + int i; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + + for_each_set_bit(i, &queues, hw->queues) + __ieee80211_stop_queue(hw, i, reason, refcounted); + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_stop_queues(struct ieee80211_hw *hw) +{ + ieee80211_stop_queues_by_reason(hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_DRIVER, + false); +} +EXPORT_SYMBOL(ieee80211_stop_queues); + +int ieee80211_queue_stopped(struct ieee80211_hw *hw, int queue) +{ + struct ieee80211_local *local = hw_to_local(hw); + unsigned long flags; + int ret; + + if (WARN_ON(queue >= hw->queues)) + return true; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + ret = test_bit(IEEE80211_QUEUE_STOP_REASON_DRIVER, + &local->queue_stop_reasons[queue]); + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); + return ret; +} +EXPORT_SYMBOL(ieee80211_queue_stopped); + +void ieee80211_wake_queues_by_reason(struct ieee80211_hw *hw, + unsigned long queues, + enum queue_stop_reason reason, + bool refcounted) +{ + struct ieee80211_local *local = hw_to_local(hw); + unsigned long flags; + int i; + + spin_lock_irqsave(&local->queue_stop_reason_lock, flags); + + for_each_set_bit(i, &queues, hw->queues) + __ieee80211_wake_queue(hw, i, reason, refcounted, &flags); + + spin_unlock_irqrestore(&local->queue_stop_reason_lock, flags); +} + +void ieee80211_wake_queues(struct ieee80211_hw *hw) +{ + ieee80211_wake_queues_by_reason(hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_DRIVER, + false); +} +EXPORT_SYMBOL(ieee80211_wake_queues); + +static unsigned int +ieee80211_get_vif_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + unsigned int queues; + + if (sdata && ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) { + int ac; + + queues = 0; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) + queues |= BIT(sdata->vif.hw_queue[ac]); + if (sdata->vif.cab_queue != IEEE80211_INVAL_HW_QUEUE) + queues |= BIT(sdata->vif.cab_queue); + } else { + /* all queues */ + queues = BIT(local->hw.queues) - 1; + } + + return queues; +} + +void __ieee80211_flush_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + unsigned int queues, bool 
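The stop/wake pairs above keep, per hardware queue, a counter for every stop reason plus a bitmap of the currently active reasons; the queue may transmit only when the bitmap is zero (the kernel also supports a non-refcounted mode where the counter is simply set or cleared). A small user-space model of that bookkeeping with invented names, assuming the same refcounted semantics as __ieee80211_stop_queue()/__ieee80211_wake_queue().

/* Toy model of refcounted queue stop reasons: stop() bumps the per-reason
 * counter and sets the reason bit, wake() drops it and clears the bit at
 * zero; the queue is runnable only when no reason bit remains set.
 * All names here are illustrative, not the kernel's.
 */
#include <stdio.h>
#include <stdbool.h>

#define NUM_REASONS 8

static int q_stop_count[NUM_REASONS];
static unsigned long q_stop_reasons;    /* bitmap of active reasons */

static void queue_stop(int reason)
{
    q_stop_count[reason]++;
    q_stop_reasons |= 1UL << reason;
}

static void queue_wake(int reason)
{
    if (q_stop_count[reason] > 0 && --q_stop_count[reason] == 0)
        q_stop_reasons &= ~(1UL << reason);
}

static bool queue_runnable(void)
{
    return q_stop_reasons == 0;
}

int main(void)
{
    enum { REASON_DRIVER, REASON_FLUSH };

    queue_stop(REASON_DRIVER);
    queue_stop(REASON_FLUSH);
    queue_wake(REASON_DRIVER);
    printf("runnable after one wake: %d\n", queue_runnable()); /* 0 */
    queue_wake(REASON_FLUSH);
    printf("runnable after both:     %d\n", queue_runnable()); /* 1 */
    return 0;
}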
drop) +{ + if (!local->ops->flush) + return; + + /* + * If no queue was set, or if the HW doesn't support + * IEEE80211_HW_QUEUE_CONTROL - flush all queues + */ + if (!queues || !ieee80211_hw_check(&local->hw, QUEUE_CONTROL)) + queues = ieee80211_get_vif_queues(local, sdata); + + ieee80211_stop_queues_by_reason(&local->hw, queues, + IEEE80211_QUEUE_STOP_REASON_FLUSH, + false); + + drv_flush(local, sdata, queues, drop); + + ieee80211_wake_queues_by_reason(&local->hw, queues, + IEEE80211_QUEUE_STOP_REASON_FLUSH, + false); +} + +void ieee80211_flush_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, bool drop) +{ + __ieee80211_flush_queues(local, sdata, 0, drop); +} + +void ieee80211_stop_vif_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum queue_stop_reason reason) +{ + ieee80211_stop_queues_by_reason(&local->hw, + ieee80211_get_vif_queues(local, sdata), + reason, true); +} + +void ieee80211_wake_vif_queues(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + enum queue_stop_reason reason) +{ + ieee80211_wake_queues_by_reason(&local->hw, + ieee80211_get_vif_queues(local, sdata), + reason, true); +} + +static void __iterate_interfaces(struct ieee80211_local *local, + u32 iter_flags, + void (*iterator)(void *data, u8 *mac, + struct ieee80211_vif *vif), + void *data) +{ + struct ieee80211_sub_if_data *sdata; + bool active_only = iter_flags & IEEE80211_IFACE_ITER_ACTIVE; + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + switch (sdata->vif.type) { + case NL80211_IFTYPE_MONITOR: + if (!(sdata->u.mntr.flags & MONITOR_FLAG_ACTIVE)) + continue; + break; + case NL80211_IFTYPE_AP_VLAN: + continue; + default: + break; + } + if (!(iter_flags & IEEE80211_IFACE_ITER_RESUME_ALL) && + active_only && !(sdata->flags & IEEE80211_SDATA_IN_DRIVER)) + continue; + if ((iter_flags & IEEE80211_IFACE_SKIP_SDATA_NOT_IN_DRIVER) && + !(sdata->flags & IEEE80211_SDATA_IN_DRIVER)) + continue; + if (ieee80211_sdata_running(sdata) || !active_only) + iterator(data, sdata->vif.addr, + &sdata->vif); + } + + sdata = rcu_dereference_check(local->monitor_sdata, + lockdep_is_held(&local->iflist_mtx) || + lockdep_is_held(&local->hw.wiphy->mtx)); + if (sdata && + (iter_flags & IEEE80211_IFACE_ITER_RESUME_ALL || !active_only || + sdata->flags & IEEE80211_SDATA_IN_DRIVER)) + iterator(data, sdata->vif.addr, &sdata->vif); +} + +void ieee80211_iterate_interfaces( + struct ieee80211_hw *hw, u32 iter_flags, + void (*iterator)(void *data, u8 *mac, + struct ieee80211_vif *vif), + void *data) +{ + struct ieee80211_local *local = hw_to_local(hw); + + mutex_lock(&local->iflist_mtx); + __iterate_interfaces(local, iter_flags, iterator, data); + mutex_unlock(&local->iflist_mtx); +} +EXPORT_SYMBOL_GPL(ieee80211_iterate_interfaces); + +void ieee80211_iterate_active_interfaces_atomic( + struct ieee80211_hw *hw, u32 iter_flags, + void (*iterator)(void *data, u8 *mac, + struct ieee80211_vif *vif), + void *data) +{ + struct ieee80211_local *local = hw_to_local(hw); + + rcu_read_lock(); + __iterate_interfaces(local, iter_flags | IEEE80211_IFACE_ITER_ACTIVE, + iterator, data); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(ieee80211_iterate_active_interfaces_atomic); + +void ieee80211_iterate_active_interfaces_mtx( + struct ieee80211_hw *hw, u32 iter_flags, + void (*iterator)(void *data, u8 *mac, + struct ieee80211_vif *vif), + void *data) +{ + struct ieee80211_local *local = hw_to_local(hw); + + lockdep_assert_wiphy(hw->wiphy); + + __iterate_interfaces(local, 
iter_flags | IEEE80211_IFACE_ITER_ACTIVE, + iterator, data); +} +EXPORT_SYMBOL_GPL(ieee80211_iterate_active_interfaces_mtx); + +static void __iterate_stations(struct ieee80211_local *local, + void (*iterator)(void *data, + struct ieee80211_sta *sta), + void *data) +{ + struct sta_info *sta; + + list_for_each_entry_rcu(sta, &local->sta_list, list) { + if (!sta->uploaded) + continue; + + iterator(data, &sta->sta); + } +} + +void ieee80211_iterate_stations(struct ieee80211_hw *hw, + void (*iterator)(void *data, + struct ieee80211_sta *sta), + void *data) +{ + struct ieee80211_local *local = hw_to_local(hw); + + mutex_lock(&local->sta_mtx); + __iterate_stations(local, iterator, data); + mutex_unlock(&local->sta_mtx); +} +EXPORT_SYMBOL_GPL(ieee80211_iterate_stations); + +void ieee80211_iterate_stations_atomic(struct ieee80211_hw *hw, + void (*iterator)(void *data, + struct ieee80211_sta *sta), + void *data) +{ + struct ieee80211_local *local = hw_to_local(hw); + + rcu_read_lock(); + __iterate_stations(local, iterator, data); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(ieee80211_iterate_stations_atomic); + +struct ieee80211_vif *wdev_to_ieee80211_vif(struct wireless_dev *wdev) +{ + struct ieee80211_sub_if_data *sdata = IEEE80211_WDEV_TO_SUB_IF(wdev); + + if (!ieee80211_sdata_running(sdata) || + !(sdata->flags & IEEE80211_SDATA_IN_DRIVER)) + return NULL; + return &sdata->vif; +} +EXPORT_SYMBOL_GPL(wdev_to_ieee80211_vif); + +struct wireless_dev *ieee80211_vif_to_wdev(struct ieee80211_vif *vif) +{ + if (!vif) + return NULL; + + return &vif_to_sdata(vif)->wdev; +} +EXPORT_SYMBOL_GPL(ieee80211_vif_to_wdev); + +/* + * Nothing should have been stuffed into the workqueue during + * the suspend->resume cycle. Since we can't check each caller + * of this function if we are already quiescing / suspended, + * check here and don't WARN since this can actually happen when + * the rx path (for example) is racing against __ieee80211_suspend + * and suspending / quiescing was set after the rx path checked + * them. 
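Drivers typically consume the interface iterators above from atomic context. A hedged usage sketch, assuming a hypothetical driver-side caller and counter; only the callback signature, the IEEE80211_IFACE_ITER_NORMAL flag and ieee80211_iterate_active_interfaces_atomic() itself are mac80211 API, everything else is illustrative.

/* Hypothetical driver-side use of the atomic interface iterator declared
 * above: count the running station interfaces. Runs under RCU, so it is
 * safe from atomic context, matching the _atomic variant's contract.
 */
#include <net/mac80211.h>

static void count_sta_vifs(void *data, u8 *mac, struct ieee80211_vif *vif)
{
    int *n_stations = data;

    if (vif->type == NL80211_IFTYPE_STATION)
        (*n_stations)++;
}

static int example_count_station_vifs(struct ieee80211_hw *hw)
{
    int n_stations = 0;

    ieee80211_iterate_active_interfaces_atomic(hw,
                                               IEEE80211_IFACE_ITER_NORMAL,
                                               count_sta_vifs, &n_stations);
    return n_stations;
}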
+ */ +static bool ieee80211_can_queue_work(struct ieee80211_local *local) +{ + if (local->quiescing || (local->suspended && !local->resuming)) { + pr_warn("queueing ieee80211 work while going to suspend\n"); + return false; + } + + return true; +} + +void ieee80211_queue_work(struct ieee80211_hw *hw, struct work_struct *work) +{ + struct ieee80211_local *local = hw_to_local(hw); + + if (!ieee80211_can_queue_work(local)) + return; + + queue_work(local->workqueue, work); +} +EXPORT_SYMBOL(ieee80211_queue_work); + +void ieee80211_queue_delayed_work(struct ieee80211_hw *hw, + struct delayed_work *dwork, + unsigned long delay) +{ + struct ieee80211_local *local = hw_to_local(hw); + + if (!ieee80211_can_queue_work(local)) + return; + + queue_delayed_work(local->workqueue, dwork, delay); +} +EXPORT_SYMBOL(ieee80211_queue_delayed_work); + +static void +ieee80211_parse_extension_element(u32 *crc, + const struct element *elem, + struct ieee802_11_elems *elems, + struct ieee80211_elems_parse_params *params) +{ + const void *data = elem->data + 1; + u8 len; + + if (!elem->datalen) + return; + + len = elem->datalen - 1; + + switch (elem->data[0]) { + case WLAN_EID_EXT_HE_MU_EDCA: + if (len >= sizeof(*elems->mu_edca_param_set)) { + elems->mu_edca_param_set = data; + if (crc) + *crc = crc32_be(*crc, (void *)elem, + elem->datalen + 2); + } + break; + case WLAN_EID_EXT_HE_CAPABILITY: + if (ieee80211_he_capa_size_ok(data, len)) { + elems->he_cap = data; + elems->he_cap_len = len; + } + break; + case WLAN_EID_EXT_HE_OPERATION: + if (len >= sizeof(*elems->he_operation) && + len >= ieee80211_he_oper_size(data) - 1) { + if (crc) + *crc = crc32_be(*crc, (void *)elem, + elem->datalen + 2); + elems->he_operation = data; + } + break; + case WLAN_EID_EXT_UORA: + if (len >= 1) + elems->uora_element = data; + break; + case WLAN_EID_EXT_MAX_CHANNEL_SWITCH_TIME: + if (len == 3) + elems->max_channel_switch_time = data; + break; + case WLAN_EID_EXT_MULTIPLE_BSSID_CONFIGURATION: + if (len >= sizeof(*elems->mbssid_config_ie)) + elems->mbssid_config_ie = data; + break; + case WLAN_EID_EXT_HE_SPR: + if (len >= sizeof(*elems->he_spr) && + len >= ieee80211_he_spr_size(data)) + elems->he_spr = data; + break; + case WLAN_EID_EXT_HE_6GHZ_CAPA: + if (len >= sizeof(*elems->he_6ghz_capa)) + elems->he_6ghz_capa = data; + break; + case WLAN_EID_EXT_EHT_CAPABILITY: + if (ieee80211_eht_capa_size_ok(elems->he_cap, + data, len, + params->from_ap)) { + elems->eht_cap = data; + elems->eht_cap_len = len; + } + break; + case WLAN_EID_EXT_EHT_OPERATION: + if (ieee80211_eht_oper_size_ok(data, len)) + elems->eht_operation = data; + break; + case WLAN_EID_EXT_EHT_MULTI_LINK: + if (ieee80211_mle_size_ok(data, len)) + elems->multi_link = (void *)data; + break; + } +} + +static u32 +_ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params, + struct ieee802_11_elems *elems, + const struct element *check_inherit) +{ + const struct element *elem; + bool calc_crc = params->filter != 0; + DECLARE_BITMAP(seen_elems, 256); + u32 crc = params->crc; + const u8 *ie; + + bitmap_zero(seen_elems, 256); + + for_each_element(elem, params->start, params->len) { + bool elem_parse_failed; + u8 id = elem->id; + u8 elen = elem->datalen; + const u8 *pos = elem->data; + + if (check_inherit && + !cfg80211_is_element_inherited(elem, + check_inherit)) + continue; + + switch (id) { + case WLAN_EID_SSID: + case WLAN_EID_SUPP_RATES: + case WLAN_EID_FH_PARAMS: + case WLAN_EID_DS_PARAMS: + case WLAN_EID_CF_PARAMS: + case WLAN_EID_TIM: + case WLAN_EID_IBSS_PARAMS: 
+ case WLAN_EID_CHALLENGE: + case WLAN_EID_RSN: + case WLAN_EID_ERP_INFO: + case WLAN_EID_EXT_SUPP_RATES: + case WLAN_EID_HT_CAPABILITY: + case WLAN_EID_HT_OPERATION: + case WLAN_EID_VHT_CAPABILITY: + case WLAN_EID_VHT_OPERATION: + case WLAN_EID_MESH_ID: + case WLAN_EID_MESH_CONFIG: + case WLAN_EID_PEER_MGMT: + case WLAN_EID_PREQ: + case WLAN_EID_PREP: + case WLAN_EID_PERR: + case WLAN_EID_RANN: + case WLAN_EID_CHANNEL_SWITCH: + case WLAN_EID_EXT_CHANSWITCH_ANN: + case WLAN_EID_COUNTRY: + case WLAN_EID_PWR_CONSTRAINT: + case WLAN_EID_TIMEOUT_INTERVAL: + case WLAN_EID_SECONDARY_CHANNEL_OFFSET: + case WLAN_EID_WIDE_BW_CHANNEL_SWITCH: + case WLAN_EID_CHAN_SWITCH_PARAM: + case WLAN_EID_EXT_CAPABILITY: + case WLAN_EID_CHAN_SWITCH_TIMING: + case WLAN_EID_LINK_ID: + case WLAN_EID_BSS_MAX_IDLE_PERIOD: + case WLAN_EID_RSNX: + case WLAN_EID_S1G_BCN_COMPAT: + case WLAN_EID_S1G_CAPABILITIES: + case WLAN_EID_S1G_OPERATION: + case WLAN_EID_AID_RESPONSE: + case WLAN_EID_S1G_SHORT_BCN_INTERVAL: + /* + * not listing WLAN_EID_CHANNEL_SWITCH_WRAPPER -- it seems possible + * that if the content gets bigger it might be needed more than once + */ + if (test_bit(id, seen_elems)) { + elems->parse_error = true; + continue; + } + break; + } + + if (calc_crc && id < 64 && (params->filter & (1ULL << id))) + crc = crc32_be(crc, pos - 2, elen + 2); + + elem_parse_failed = false; + + switch (id) { + case WLAN_EID_LINK_ID: + if (elen + 2 < sizeof(struct ieee80211_tdls_lnkie)) { + elem_parse_failed = true; + break; + } + elems->lnk_id = (void *)(pos - 2); + break; + case WLAN_EID_CHAN_SWITCH_TIMING: + if (elen < sizeof(struct ieee80211_ch_switch_timing)) { + elem_parse_failed = true; + break; + } + elems->ch_sw_timing = (void *)pos; + break; + case WLAN_EID_EXT_CAPABILITY: + elems->ext_capab = pos; + elems->ext_capab_len = elen; + break; + case WLAN_EID_SSID: + elems->ssid = pos; + elems->ssid_len = elen; + break; + case WLAN_EID_SUPP_RATES: + elems->supp_rates = pos; + elems->supp_rates_len = elen; + break; + case WLAN_EID_DS_PARAMS: + if (elen >= 1) + elems->ds_params = pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_TIM: + if (elen >= sizeof(struct ieee80211_tim_ie)) { + elems->tim = (void *)pos; + elems->tim_len = elen; + } else + elem_parse_failed = true; + break; + case WLAN_EID_VENDOR_SPECIFIC: + if (elen >= 4 && pos[0] == 0x00 && pos[1] == 0x50 && + pos[2] == 0xf2) { + /* Microsoft OUI (00:50:F2) */ + + if (calc_crc) + crc = crc32_be(crc, pos - 2, elen + 2); + + if (elen >= 5 && pos[3] == 2) { + /* OUI Type 2 - WMM IE */ + if (pos[4] == 0) { + elems->wmm_info = pos; + elems->wmm_info_len = elen; + } else if (pos[4] == 1) { + elems->wmm_param = pos; + elems->wmm_param_len = elen; + } + } + } + break; + case WLAN_EID_RSN: + elems->rsn = pos; + elems->rsn_len = elen; + break; + case WLAN_EID_ERP_INFO: + if (elen >= 1) + elems->erp_info = pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_EXT_SUPP_RATES: + elems->ext_supp_rates = pos; + elems->ext_supp_rates_len = elen; + break; + case WLAN_EID_HT_CAPABILITY: + if (elen >= sizeof(struct ieee80211_ht_cap)) + elems->ht_cap_elem = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_HT_OPERATION: + if (elen >= sizeof(struct ieee80211_ht_operation)) + elems->ht_operation = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_VHT_CAPABILITY: + if (elen >= sizeof(struct ieee80211_vht_cap)) + elems->vht_cap_elem = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_VHT_OPERATION: + if 
(elen >= sizeof(struct ieee80211_vht_operation)) { + elems->vht_operation = (void *)pos; + if (calc_crc) + crc = crc32_be(crc, pos - 2, elen + 2); + break; + } + elem_parse_failed = true; + break; + case WLAN_EID_OPMODE_NOTIF: + if (elen > 0) { + elems->opmode_notif = pos; + if (calc_crc) + crc = crc32_be(crc, pos - 2, elen + 2); + break; + } + elem_parse_failed = true; + break; + case WLAN_EID_MESH_ID: + elems->mesh_id = pos; + elems->mesh_id_len = elen; + break; + case WLAN_EID_MESH_CONFIG: + if (elen >= sizeof(struct ieee80211_meshconf_ie)) + elems->mesh_config = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_PEER_MGMT: + elems->peering = pos; + elems->peering_len = elen; + break; + case WLAN_EID_MESH_AWAKE_WINDOW: + if (elen >= 2) + elems->awake_window = (void *)pos; + break; + case WLAN_EID_PREQ: + elems->preq = pos; + elems->preq_len = elen; + break; + case WLAN_EID_PREP: + elems->prep = pos; + elems->prep_len = elen; + break; + case WLAN_EID_PERR: + elems->perr = pos; + elems->perr_len = elen; + break; + case WLAN_EID_RANN: + if (elen >= sizeof(struct ieee80211_rann_ie)) + elems->rann = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_CHANNEL_SWITCH: + if (elen != sizeof(struct ieee80211_channel_sw_ie)) { + elem_parse_failed = true; + break; + } + elems->ch_switch_ie = (void *)pos; + break; + case WLAN_EID_EXT_CHANSWITCH_ANN: + if (elen != sizeof(struct ieee80211_ext_chansw_ie)) { + elem_parse_failed = true; + break; + } + elems->ext_chansw_ie = (void *)pos; + break; + case WLAN_EID_SECONDARY_CHANNEL_OFFSET: + if (elen != sizeof(struct ieee80211_sec_chan_offs_ie)) { + elem_parse_failed = true; + break; + } + elems->sec_chan_offs = (void *)pos; + break; + case WLAN_EID_CHAN_SWITCH_PARAM: + if (elen < + sizeof(*elems->mesh_chansw_params_ie)) { + elem_parse_failed = true; + break; + } + elems->mesh_chansw_params_ie = (void *)pos; + break; + case WLAN_EID_WIDE_BW_CHANNEL_SWITCH: + if (!params->action || + elen < sizeof(*elems->wide_bw_chansw_ie)) { + elem_parse_failed = true; + break; + } + elems->wide_bw_chansw_ie = (void *)pos; + break; + case WLAN_EID_CHANNEL_SWITCH_WRAPPER: + if (params->action) { + elem_parse_failed = true; + break; + } + /* + * This is a bit tricky, but as we only care about + * the wide bandwidth channel switch element, so + * just parse it out manually. + */ + ie = cfg80211_find_ie(WLAN_EID_WIDE_BW_CHANNEL_SWITCH, + pos, elen); + if (ie) { + if (ie[1] >= sizeof(*elems->wide_bw_chansw_ie)) + elems->wide_bw_chansw_ie = + (void *)(ie + 2); + else + elem_parse_failed = true; + } + break; + case WLAN_EID_COUNTRY: + elems->country_elem = pos; + elems->country_elem_len = elen; + break; + case WLAN_EID_PWR_CONSTRAINT: + if (elen != 1) { + elem_parse_failed = true; + break; + } + elems->pwr_constr_elem = pos; + break; + case WLAN_EID_CISCO_VENDOR_SPECIFIC: + /* Lots of different options exist, but we only care + * about the Dynamic Transmit Power Control element. + * First check for the Cisco OUI, then for the DTPC + * tag (0x00). 
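The parser in _ieee802_11_parse_elems_full() above records the IDs of single-instance elements in a 256-bit bitmap and flags a repeated occurrence as a parse error, while folding selected elements into a CRC so callers can cheaply detect beacon changes. A standalone sketch of just the duplicate-tracking idea, using plain C bit operations instead of the kernel's DECLARE_BITMAP/test_bit helpers and simplified to set the bit on first sight.

/* Stand-alone model of the seen-element tracking: single-instance element
 * IDs go into a 256-bit map and a repeat marks the whole parse as suspect.
 */
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>

static uint32_t seen[256 / 32];

static bool test_and_set_id(uint8_t id)
{
    uint32_t bit = 1u << (id % 32);
    bool was_set = seen[id / 32] & bit;

    seen[id / 32] |= bit;
    return was_set;
}

int main(void)
{
    const uint8_t ids[] = { 0 /* SSID */, 1 /* Supp Rates */, 0 /* dup */ };
    bool parse_error = false;

    for (unsigned i = 0; i < sizeof(ids); i++)
        if (test_and_set_id(ids[i]))
            parse_error = true;

    printf("parse_error = %d\n", parse_error);  /* 1: SSID appeared twice */
    return 0;
}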
+ */ + if (elen < 4) { + elem_parse_failed = true; + break; + } + + if (pos[0] != 0x00 || pos[1] != 0x40 || + pos[2] != 0x96 || pos[3] != 0x00) + break; + + if (elen != 6) { + elem_parse_failed = true; + break; + } + + if (calc_crc) + crc = crc32_be(crc, pos - 2, elen + 2); + + elems->cisco_dtpc_elem = pos; + break; + case WLAN_EID_ADDBA_EXT: + if (elen < sizeof(struct ieee80211_addba_ext_ie)) { + elem_parse_failed = true; + break; + } + elems->addba_ext_ie = (void *)pos; + break; + case WLAN_EID_TIMEOUT_INTERVAL: + if (elen >= sizeof(struct ieee80211_timeout_interval_ie)) + elems->timeout_int = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_BSS_MAX_IDLE_PERIOD: + if (elen >= sizeof(*elems->max_idle_period_ie)) + elems->max_idle_period_ie = (void *)pos; + break; + case WLAN_EID_RSNX: + elems->rsnx = pos; + elems->rsnx_len = elen; + break; + case WLAN_EID_TX_POWER_ENVELOPE: + if (elen < 1 || + elen > sizeof(struct ieee80211_tx_pwr_env)) + break; + + if (elems->tx_pwr_env_num >= ARRAY_SIZE(elems->tx_pwr_env)) + break; + + elems->tx_pwr_env[elems->tx_pwr_env_num] = (void *)pos; + elems->tx_pwr_env_len[elems->tx_pwr_env_num] = elen; + elems->tx_pwr_env_num++; + break; + case WLAN_EID_EXTENSION: + ieee80211_parse_extension_element(calc_crc ? + &crc : NULL, + elem, elems, params); + break; + case WLAN_EID_S1G_CAPABILITIES: + if (elen >= sizeof(*elems->s1g_capab)) + elems->s1g_capab = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_S1G_OPERATION: + if (elen == sizeof(*elems->s1g_oper)) + elems->s1g_oper = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_S1G_BCN_COMPAT: + if (elen == sizeof(*elems->s1g_bcn_compat)) + elems->s1g_bcn_compat = (void *)pos; + else + elem_parse_failed = true; + break; + case WLAN_EID_AID_RESPONSE: + if (elen == sizeof(struct ieee80211_aid_response_ie)) + elems->aid_resp = (void *)pos; + else + elem_parse_failed = true; + break; + default: + break; + } + + if (elem_parse_failed) + elems->parse_error = true; + else + __set_bit(id, seen_elems); + } + + if (!for_each_element_completed(elem, params->start, params->len)) + elems->parse_error = true; + + return crc; +} + +static size_t ieee802_11_find_bssid_profile(const u8 *start, size_t len, + struct ieee802_11_elems *elems, + struct cfg80211_bss *bss, + u8 *nontransmitted_profile) +{ + const struct element *elem, *sub; + size_t profile_len = 0; + bool found = false; + + if (!bss || !bss->transmitted_bss) + return profile_len; + + for_each_element_id(elem, WLAN_EID_MULTIPLE_BSSID, start, len) { + if (elem->datalen < 2) + continue; + if (elem->data[0] < 1 || elem->data[0] > 8) + continue; + + for_each_element(sub, elem->data + 1, elem->datalen - 1) { + u8 new_bssid[ETH_ALEN]; + const u8 *index; + + if (sub->id != 0 || sub->datalen < 4) { + /* not a valid BSS profile */ + continue; + } + + if (sub->data[0] != WLAN_EID_NON_TX_BSSID_CAP || + sub->data[1] != 2) { + /* The first element of the + * Nontransmitted BSSID Profile is not + * the Nontransmitted BSSID Capability + * element. 
+ */ + continue; + } + + memset(nontransmitted_profile, 0, len); + profile_len = cfg80211_merge_profile(start, len, + elem, + sub, + nontransmitted_profile, + len); + + /* found a Nontransmitted BSSID Profile */ + index = cfg80211_find_ie(WLAN_EID_MULTI_BSSID_IDX, + nontransmitted_profile, + profile_len); + if (!index || index[1] < 1 || index[2] == 0) { + /* Invalid MBSSID Index element */ + continue; + } + + cfg80211_gen_new_bssid(bss->transmitted_bss->bssid, + elem->data[0], + index[2], + new_bssid); + if (ether_addr_equal(new_bssid, bss->bssid)) { + found = true; + elems->bssid_index_len = index[1]; + elems->bssid_index = (void *)&index[2]; + break; + } + } + } + + return found ? profile_len : 0; +} + +struct ieee802_11_elems * +ieee802_11_parse_elems_full(struct ieee80211_elems_parse_params *params) +{ + struct ieee802_11_elems *elems; + const struct element *non_inherit = NULL; + u8 *nontransmitted_profile; + int nontransmitted_profile_len = 0; + size_t scratch_len = params->len; + + elems = kzalloc(sizeof(*elems) + scratch_len, GFP_ATOMIC); + if (!elems) + return NULL; + elems->ie_start = params->start; + elems->total_len = params->len; + elems->scratch_len = scratch_len; + elems->scratch_pos = elems->scratch; + + nontransmitted_profile = elems->scratch_pos; + nontransmitted_profile_len = + ieee802_11_find_bssid_profile(params->start, params->len, + elems, params->bss, + nontransmitted_profile); + elems->scratch_pos += nontransmitted_profile_len; + elems->scratch_len -= nontransmitted_profile_len; + non_inherit = cfg80211_find_ext_elem(WLAN_EID_EXT_NON_INHERITANCE, + nontransmitted_profile, + nontransmitted_profile_len); + + elems->crc = _ieee802_11_parse_elems_full(params, elems, non_inherit); + + /* Override with nontransmitted profile, if found */ + if (nontransmitted_profile_len) { + struct ieee80211_elems_parse_params sub = { + .start = nontransmitted_profile, + .len = nontransmitted_profile_len, + .action = params->action, + .link_id = params->link_id, + }; + + _ieee802_11_parse_elems_full(&sub, elems, NULL); + } + + if (elems->tim && !elems->parse_error) { + const struct ieee80211_tim_ie *tim_ie = elems->tim; + + elems->dtim_period = tim_ie->dtim_period; + elems->dtim_count = tim_ie->dtim_count; + } + + /* Override DTIM period and count if needed */ + if (elems->bssid_index && + elems->bssid_index_len >= + offsetofend(struct ieee80211_bssid_index, dtim_period)) + elems->dtim_period = elems->bssid_index->dtim_period; + + if (elems->bssid_index && + elems->bssid_index_len >= + offsetofend(struct ieee80211_bssid_index, dtim_count)) + elems->dtim_count = elems->bssid_index->dtim_count; + + return elems; +} + +void ieee80211_regulatory_limit_wmm_params(struct ieee80211_sub_if_data *sdata, + struct ieee80211_tx_queue_params + *qparam, int ac) +{ + struct ieee80211_chanctx_conf *chanctx_conf; + const struct ieee80211_reg_rule *rrule; + const struct ieee80211_wmm_ac *wmm_ac; + u16 center_freq = 0; + + if (sdata->vif.type != NL80211_IFTYPE_AP && + sdata->vif.type != NL80211_IFTYPE_STATION) + return; + + rcu_read_lock(); + chanctx_conf = rcu_dereference(sdata->vif.bss_conf.chanctx_conf); + if (chanctx_conf) + center_freq = chanctx_conf->def.chan->center_freq; + + if (!center_freq) { + rcu_read_unlock(); + return; + } + + rrule = freq_reg_info(sdata->wdev.wiphy, MHZ_TO_KHZ(center_freq)); + + if (IS_ERR_OR_NULL(rrule) || !rrule->has_wmm) { + rcu_read_unlock(); + return; + } + + if (sdata->vif.type == NL80211_IFTYPE_AP) + wmm_ac = &rrule->wmm_rule.ap[ac]; + else + wmm_ac = 
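cfg80211_gen_new_bssid(), called in ieee802_11_find_bssid_profile() above, derives a nontransmitted BSSID from the transmitted one: with a MaxBSSID indicator of n, the upper 48-n address bits are kept and the low n bits become (low n bits + BSSID index) mod 2^n. A standalone sketch of that addressing rule; the helper name gen_new_bssid_model is invented, the real implementation lives in cfg80211.

/* Stand-alone sketch of the Multiple BSSID address derivation: keep the
 * upper 48-n bits of the transmitted BSSID and add the BSSID index to the
 * low n bits modulo 2^n, n being the MaxBSSID indicator.
 */
#include <stdio.h>
#include <stdint.h>

static uint64_t addr_to_u64(const uint8_t a[6])
{
    uint64_t v = 0;

    for (int i = 0; i < 6; i++)
        v = (v << 8) | a[i];
    return v;
}

static void u64_to_addr(uint64_t v, uint8_t a[6])
{
    for (int i = 5; i >= 0; i--) {
        a[i] = v & 0xff;
        v >>= 8;
    }
}

static void gen_new_bssid_model(const uint8_t *bssid, uint8_t max_bssid,
                                uint8_t index, uint8_t *new_bssid)
{
    uint64_t b = addr_to_u64(bssid);
    uint64_t mask = (1ULL << max_bssid) - 1;

    u64_to_addr((b & ~mask) | (((b & mask) + index) & mask), new_bssid);
}

int main(void)
{
    const uint8_t tx_bssid[6] = { 0x00, 0x11, 0x22, 0x33, 0x44, 0x50 };
    uint8_t derived[6];

    gen_new_bssid_model(tx_bssid, 4 /* up to 16 BSSIDs */, 3, derived);
    printf("derived BSSID ends in %02x\n", derived[5]);    /* 0x53 */
    return 0;
}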
&rrule->wmm_rule.client[ac]; + qparam->cw_min = max_t(u16, qparam->cw_min, wmm_ac->cw_min); + qparam->cw_max = max_t(u16, qparam->cw_max, wmm_ac->cw_max); + qparam->aifs = max_t(u8, qparam->aifs, wmm_ac->aifsn); + qparam->txop = min_t(u16, qparam->txop, wmm_ac->cot / 32); + rcu_read_unlock(); +} + +void ieee80211_set_wmm_default(struct ieee80211_link_data *link, + bool bss_notify, bool enable_qos) +{ + struct ieee80211_sub_if_data *sdata = link->sdata; + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_queue_params qparam; + struct ieee80211_chanctx_conf *chanctx_conf; + int ac; + bool use_11b; + bool is_ocb; /* Use another EDCA parameters if dot11OCBActivated=true */ + int aCWmin, aCWmax; + + if (!local->ops->conf_tx) + return; + + if (local->hw.queues < IEEE80211_NUM_ACS) + return; + + memset(&qparam, 0, sizeof(qparam)); + + rcu_read_lock(); + chanctx_conf = rcu_dereference(link->conf->chanctx_conf); + use_11b = (chanctx_conf && + chanctx_conf->def.chan->band == NL80211_BAND_2GHZ) && + !link->operating_11g_mode; + rcu_read_unlock(); + + is_ocb = (sdata->vif.type == NL80211_IFTYPE_OCB); + + /* Set defaults according to 802.11-2007 Table 7-37 */ + aCWmax = 1023; + if (use_11b) + aCWmin = 31; + else + aCWmin = 15; + + /* Confiure old 802.11b/g medium access rules. */ + qparam.cw_max = aCWmax; + qparam.cw_min = aCWmin; + qparam.txop = 0; + qparam.aifs = 2; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + /* Update if QoS is enabled. */ + if (enable_qos) { + switch (ac) { + case IEEE80211_AC_BK: + qparam.cw_max = aCWmax; + qparam.cw_min = aCWmin; + qparam.txop = 0; + if (is_ocb) + qparam.aifs = 9; + else + qparam.aifs = 7; + break; + /* never happens but let's not leave undefined */ + default: + case IEEE80211_AC_BE: + qparam.cw_max = aCWmax; + qparam.cw_min = aCWmin; + qparam.txop = 0; + if (is_ocb) + qparam.aifs = 6; + else + qparam.aifs = 3; + break; + case IEEE80211_AC_VI: + qparam.cw_max = aCWmin; + qparam.cw_min = (aCWmin + 1) / 2 - 1; + if (is_ocb) + qparam.txop = 0; + else if (use_11b) + qparam.txop = 6016/32; + else + qparam.txop = 3008/32; + + if (is_ocb) + qparam.aifs = 3; + else + qparam.aifs = 2; + break; + case IEEE80211_AC_VO: + qparam.cw_max = (aCWmin + 1) / 2 - 1; + qparam.cw_min = (aCWmin + 1) / 4 - 1; + if (is_ocb) + qparam.txop = 0; + else if (use_11b) + qparam.txop = 3264/32; + else + qparam.txop = 1504/32; + qparam.aifs = 2; + break; + } + } + ieee80211_regulatory_limit_wmm_params(sdata, &qparam, ac); + + qparam.uapsd = false; + + link->tx_conf[ac] = qparam; + drv_conf_tx(local, link, ac, &qparam); + } + + if (sdata->vif.type != NL80211_IFTYPE_MONITOR && + sdata->vif.type != NL80211_IFTYPE_P2P_DEVICE && + sdata->vif.type != NL80211_IFTYPE_NAN) { + link->conf->qos = enable_qos; + if (bss_notify) + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_QOS); + } +} + +void ieee80211_send_auth(struct ieee80211_sub_if_data *sdata, + u16 transaction, u16 auth_alg, u16 status, + const u8 *extra, size_t extra_len, const u8 *da, + const u8 *bssid, const u8 *key, u8 key_len, u8 key_idx, + u32 tx_flags) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + bool multi_link = sdata->vif.valid_links; + struct { + u8 id; + u8 len; + u8 ext_id; + struct ieee80211_multi_link_elem ml; + struct ieee80211_mle_basic_common_info basic; + } __packed mle = { + .id = WLAN_EID_EXTENSION, + .len = sizeof(mle) - 2, + .ext_id = WLAN_EID_EXT_EHT_MULTI_LINK, + .ml.control = cpu_to_le16(IEEE80211_ML_CONTROL_TYPE_BASIC), + 
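For an OFDM link that is neither 11b-only nor OCB, the loop in ieee80211_set_wmm_default() above resolves to the familiar EDCA defaults: with aCWmin = 15 and aCWmax = 1023, BK gets CW 15/1023 with AIFSN 7, BE gets CW 15/1023 with AIFSN 3, VI gets CW 7/15 with a 3.008 ms TXOP, and VO gets CW 3/7 with a 1.504 ms TXOP (TXOP is stored in 32 us units). A short standalone sketch reproducing that computation.

/* Stand-alone recomputation of the QoS defaults set up by
 * ieee80211_set_wmm_default() for an OFDM (ERP/11a) link, neither 11b-only
 * nor OCB: aCWmin = 15, aCWmax = 1023, TXOP kept in 32 us units.
 */
#include <stdio.h>

struct edca { const char *ac; int cw_min, cw_max, aifs, txop_32us; };

int main(void)
{
    const int aCWmin = 15, aCWmax = 1023;
    const struct edca tab[] = {
        { "BK", aCWmin, aCWmax, 7, 0 },
        { "BE", aCWmin, aCWmax, 3, 0 },
        { "VI", (aCWmin + 1) / 2 - 1, aCWmin, 2, 3008 / 32 },
        { "VO", (aCWmin + 1) / 4 - 1, (aCWmin + 1) / 2 - 1, 2, 1504 / 32 },
    };

    for (unsigned i = 0; i < sizeof(tab) / sizeof(tab[0]); i++)
        printf("%s: cw %d..%d aifs %d txop %d us\n",
               tab[i].ac, tab[i].cw_min, tab[i].cw_max,
               tab[i].aifs, tab[i].txop_32us * 32);
    return 0;
}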
.basic.len = sizeof(mle.basic), + }; + int err; + + memcpy(mle.basic.mld_mac_addr, sdata->vif.addr, ETH_ALEN); + + /* 24 + 6 = header + auth_algo + auth_transaction + status_code */ + skb = dev_alloc_skb(local->hw.extra_tx_headroom + IEEE80211_WEP_IV_LEN + + 24 + 6 + extra_len + IEEE80211_WEP_ICV_LEN + + multi_link * sizeof(mle)); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom + IEEE80211_WEP_IV_LEN); + + mgmt = skb_put_zero(skb, 24 + 6); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_AUTH); + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, bssid, ETH_ALEN); + mgmt->u.auth.auth_alg = cpu_to_le16(auth_alg); + mgmt->u.auth.auth_transaction = cpu_to_le16(transaction); + mgmt->u.auth.status_code = cpu_to_le16(status); + if (extra) + skb_put_data(skb, extra, extra_len); + if (multi_link) + skb_put_data(skb, &mle, sizeof(mle)); + + if (auth_alg == WLAN_AUTH_SHARED_KEY && transaction == 3) { + mgmt->frame_control |= cpu_to_le16(IEEE80211_FCTL_PROTECTED); + err = ieee80211_wep_encrypt(local, skb, key, key_len, key_idx); + if (WARN_ON(err)) { + kfree_skb(skb); + return; + } + } + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT | + tx_flags; + ieee80211_tx_skb(sdata, skb); +} + +void ieee80211_send_deauth_disassoc(struct ieee80211_sub_if_data *sdata, + const u8 *da, const u8 *bssid, + u16 stype, u16 reason, + bool send_frame, u8 *frame_buf) +{ + struct ieee80211_local *local = sdata->local; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt = (void *)frame_buf; + + /* build frame */ + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | stype); + mgmt->duration = 0; /* initialize only */ + mgmt->seq_ctrl = 0; /* initialize only */ + memcpy(mgmt->da, da, ETH_ALEN); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + memcpy(mgmt->bssid, bssid, ETH_ALEN); + /* u.deauth.reason_code == u.disassoc.reason_code */ + mgmt->u.deauth.reason_code = cpu_to_le16(reason); + + if (send_frame) { + skb = dev_alloc_skb(local->hw.extra_tx_headroom + + IEEE80211_DEAUTH_FRAME_LEN); + if (!skb) + return; + + skb_reserve(skb, local->hw.extra_tx_headroom); + + /* copy in frame */ + skb_put_data(skb, mgmt, IEEE80211_DEAUTH_FRAME_LEN); + + if (sdata->vif.type != NL80211_IFTYPE_STATION || + !(sdata->u.mgd.flags & IEEE80211_STA_MFP_ENABLED)) + IEEE80211_SKB_CB(skb)->flags |= + IEEE80211_TX_INTFL_DONT_ENCRYPT; + + ieee80211_tx_skb(sdata, skb); + } +} + +static u8 *ieee80211_write_he_6ghz_cap(u8 *pos, __le16 cap, u8 *end) +{ + if ((end - pos) < 5) + return pos; + + *pos++ = WLAN_EID_EXTENSION; + *pos++ = 1 + sizeof(cap); + *pos++ = WLAN_EID_EXT_HE_6GHZ_CAPA; + memcpy(pos, &cap, sizeof(cap)); + + return pos + 2; +} + +static int ieee80211_build_preq_ies_band(struct ieee80211_sub_if_data *sdata, + u8 *buffer, size_t buffer_len, + const u8 *ie, size_t ie_len, + enum nl80211_band band, + u32 rate_mask, + struct cfg80211_chan_def *chandef, + size_t *offset, u32 flags) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + const struct ieee80211_sta_he_cap *he_cap; + const struct ieee80211_sta_eht_cap *eht_cap; + u8 *pos = buffer, *end = buffer + buffer_len; + size_t noffset; + int supp_rates_len, i; + u8 rates[32]; + int num_rates; + int ext_rates_len; + int shift; + u32 rate_flags; + bool have_80mhz = false; + + *offset = 0; + + sband = local->hw.wiphy->bands[band]; + if (WARN_ON_ONCE(!sband)) + return 0; + + rate_flags = ieee80211_chandef_rate_flags(chandef); + shift 
= ieee80211_chandef_get_shift(chandef); + + num_rates = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if ((BIT(i) & rate_mask) == 0) + continue; /* skip rate */ + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + + rates[num_rates++] = + (u8) DIV_ROUND_UP(sband->bitrates[i].bitrate, + (1 << shift) * 5); + } + + supp_rates_len = min_t(int, num_rates, 8); + + if (end - pos < 2 + supp_rates_len) + goto out_err; + *pos++ = WLAN_EID_SUPP_RATES; + *pos++ = supp_rates_len; + memcpy(pos, rates, supp_rates_len); + pos += supp_rates_len; + + /* insert "request information" if in custom IEs */ + if (ie && ie_len) { + static const u8 before_extrates[] = { + WLAN_EID_SSID, + WLAN_EID_SUPP_RATES, + WLAN_EID_REQUEST, + }; + noffset = ieee80211_ie_split(ie, ie_len, + before_extrates, + ARRAY_SIZE(before_extrates), + *offset); + if (end - pos < noffset - *offset) + goto out_err; + memcpy(pos, ie + *offset, noffset - *offset); + pos += noffset - *offset; + *offset = noffset; + } + + ext_rates_len = num_rates - supp_rates_len; + if (ext_rates_len > 0) { + if (end - pos < 2 + ext_rates_len) + goto out_err; + *pos++ = WLAN_EID_EXT_SUPP_RATES; + *pos++ = ext_rates_len; + memcpy(pos, rates + supp_rates_len, ext_rates_len); + pos += ext_rates_len; + } + + if (chandef->chan && sband->band == NL80211_BAND_2GHZ) { + if (end - pos < 3) + goto out_err; + *pos++ = WLAN_EID_DS_PARAMS; + *pos++ = 1; + *pos++ = ieee80211_frequency_to_channel( + chandef->chan->center_freq); + } + + if (flags & IEEE80211_PROBE_FLAG_MIN_CONTENT) + goto done; + + /* insert custom IEs that go before HT */ + if (ie && ie_len) { + static const u8 before_ht[] = { + /* + * no need to list the ones split off already + * (or generated here) + */ + WLAN_EID_DS_PARAMS, + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + }; + noffset = ieee80211_ie_split(ie, ie_len, + before_ht, ARRAY_SIZE(before_ht), + *offset); + if (end - pos < noffset - *offset) + goto out_err; + memcpy(pos, ie + *offset, noffset - *offset); + pos += noffset - *offset; + *offset = noffset; + } + + if (sband->ht_cap.ht_supported) { + if (end - pos < 2 + sizeof(struct ieee80211_ht_cap)) + goto out_err; + pos = ieee80211_ie_build_ht_cap(pos, &sband->ht_cap, + sband->ht_cap.cap); + } + + /* insert custom IEs that go before VHT */ + if (ie && ie_len) { + static const u8 before_vht[] = { + /* + * no need to list the ones split off already + * (or generated here) + */ + WLAN_EID_BSS_COEX_2040, + WLAN_EID_EXT_CAPABILITY, + WLAN_EID_SSID_LIST, + WLAN_EID_CHANNEL_USAGE, + WLAN_EID_INTERWORKING, + WLAN_EID_MESH_ID, + /* 60 GHz (Multi-band, DMG, MMS) can't happen */ + }; + noffset = ieee80211_ie_split(ie, ie_len, + before_vht, ARRAY_SIZE(before_vht), + *offset); + if (end - pos < noffset - *offset) + goto out_err; + memcpy(pos, ie + *offset, noffset - *offset); + pos += noffset - *offset; + *offset = noffset; + } + + /* Check if any channel in this sband supports at least 80 MHz */ + for (i = 0; i < sband->n_channels; i++) { + if (sband->channels[i].flags & (IEEE80211_CHAN_DISABLED | + IEEE80211_CHAN_NO_80MHZ)) + continue; + + have_80mhz = true; + break; + } + + if (sband->vht_cap.vht_supported && have_80mhz) { + if (end - pos < 2 + sizeof(struct ieee80211_vht_cap)) + goto out_err; + pos = ieee80211_ie_build_vht_cap(pos, &sband->vht_cap, + sband->vht_cap.cap); + } + + /* insert custom IEs that go before HE */ + if (ie && ie_len) { + static const u8 before_he[] = { + /* + * no need to list the ones split off before VHT + * or generated here + */ + WLAN_EID_EXTENSION, 
WLAN_EID_EXT_FILS_REQ_PARAMS, + WLAN_EID_AP_CSN, + /* TODO: add 11ah/11aj/11ak elements */ + }; + noffset = ieee80211_ie_split(ie, ie_len, + before_he, ARRAY_SIZE(before_he), + *offset); + if (end - pos < noffset - *offset) + goto out_err; + memcpy(pos, ie + *offset, noffset - *offset); + pos += noffset - *offset; + *offset = noffset; + } + + he_cap = ieee80211_get_he_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + if (he_cap && + cfg80211_any_usable_channels(local->hw.wiphy, BIT(sband->band), + IEEE80211_CHAN_NO_HE)) { + pos = ieee80211_ie_build_he_cap(0, pos, he_cap, end); + if (!pos) + goto out_err; + } + + eht_cap = ieee80211_get_eht_iftype_cap(sband, + ieee80211_vif_type_p2p(&sdata->vif)); + + if (eht_cap && + cfg80211_any_usable_channels(local->hw.wiphy, BIT(sband->band), + IEEE80211_CHAN_NO_HE | + IEEE80211_CHAN_NO_EHT)) { + pos = ieee80211_ie_build_eht_cap(pos, he_cap, eht_cap, end, + sdata->vif.type == NL80211_IFTYPE_AP); + if (!pos) + goto out_err; + } + + if (cfg80211_any_usable_channels(local->hw.wiphy, + BIT(NL80211_BAND_6GHZ), + IEEE80211_CHAN_NO_HE)) { + struct ieee80211_supported_band *sband6; + + sband6 = local->hw.wiphy->bands[NL80211_BAND_6GHZ]; + he_cap = ieee80211_get_he_iftype_cap(sband6, + ieee80211_vif_type_p2p(&sdata->vif)); + + if (he_cap) { + enum nl80211_iftype iftype = + ieee80211_vif_type_p2p(&sdata->vif); + __le16 cap = ieee80211_get_he_6ghz_capa(sband6, iftype); + + pos = ieee80211_write_he_6ghz_cap(pos, cap, end); + } + } + + /* + * If adding more here, adjust code in main.c + * that calculates local->scan_ies_len. + */ + + return pos - buffer; + out_err: + WARN_ONCE(1, "not enough space for preq IEs\n"); + done: + return pos - buffer; +} + +int ieee80211_build_preq_ies(struct ieee80211_sub_if_data *sdata, u8 *buffer, + size_t buffer_len, + struct ieee80211_scan_ies *ie_desc, + const u8 *ie, size_t ie_len, + u8 bands_used, u32 *rate_masks, + struct cfg80211_chan_def *chandef, + u32 flags) +{ + size_t pos = 0, old_pos = 0, custom_ie_offset = 0; + int i; + + memset(ie_desc, 0, sizeof(*ie_desc)); + + for (i = 0; i < NUM_NL80211_BANDS; i++) { + if (bands_used & BIT(i)) { + pos += ieee80211_build_preq_ies_band(sdata, + buffer + pos, + buffer_len - pos, + ie, ie_len, i, + rate_masks[i], + chandef, + &custom_ie_offset, + flags); + ie_desc->ies[i] = buffer + old_pos; + ie_desc->len[i] = pos - old_pos; + old_pos = pos; + } + } + + /* add any remaining custom IEs */ + if (ie && ie_len) { + if (WARN_ONCE(buffer_len - pos < ie_len - custom_ie_offset, + "not enough space for preq custom IEs\n")) + return pos; + memcpy(buffer + pos, ie + custom_ie_offset, + ie_len - custom_ie_offset); + ie_desc->common_ies = buffer + pos; + ie_desc->common_ie_len = ie_len - custom_ie_offset; + pos += ie_len - custom_ie_offset; + } + + return pos; +}; + +struct sk_buff *ieee80211_build_probe_req(struct ieee80211_sub_if_data *sdata, + const u8 *src, const u8 *dst, + u32 ratemask, + struct ieee80211_channel *chan, + const u8 *ssid, size_t ssid_len, + const u8 *ie, size_t ie_len, + u32 flags) +{ + struct ieee80211_local *local = sdata->local; + struct cfg80211_chan_def chandef; + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + int ies_len; + u32 rate_masks[NUM_NL80211_BANDS] = {}; + struct ieee80211_scan_ies dummy_ie_desc; + + /* + * Do not send DS Channel parameter for directed probe requests + * in order to maximize the chance that we get a response. Some + * badly-behaved APs don't respond when this parameter is included. 
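The probe-request builder above encodes each bitrate in 500 kbps units (DIV_ROUND_UP of the 100 kbps value by 5), places at most eight of them in the Supported Rates element and spills the remainder into Extended Supported Rates. A standalone sketch of that encoding; element IDs 1 and 50 are the standard WLAN_EID_SUPP_RATES / WLAN_EID_EXT_SUPP_RATES values, and the basic-rate bit handling done elsewhere is omitted.

/* Stand-alone model of the rate encoding in ieee80211_build_preq_ies_band():
 * bitrates arrive in 100 kbps units, are stored in 500 kbps units, and are
 * split into Supported Rates (max 8) plus Extended Supported Rates.
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>

#define EID_SUPP_RATES     1
#define EID_EXT_SUPP_RATES 50
#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

static int build_rate_ies(const int *bitrates, int n, uint8_t *buf)
{
    uint8_t rates[32];
    int supp, i, pos = 0;

    for (i = 0; i < n; i++)
        rates[i] = DIV_ROUND_UP(bitrates[i], 5);    /* 500 kbps units */

    supp = n > 8 ? 8 : n;
    buf[pos++] = EID_SUPP_RATES;
    buf[pos++] = supp;
    memcpy(buf + pos, rates, supp);
    pos += supp;

    if (n > supp) {
        buf[pos++] = EID_EXT_SUPP_RATES;
        buf[pos++] = n - supp;
        memcpy(buf + pos, rates + supp, n - supp);
        pos += n - supp;
    }
    return pos;
}

int main(void)
{
    /* 11g rate set in 100 kbps units */
    const int g_rates[] = { 10, 20, 55, 110, 60, 90, 120, 180,
                            240, 360, 480, 540 };
    uint8_t buf[64];
    int len = build_rate_ies(g_rates, 12, buf);

    printf("built %d bytes of rate IEs, 6 Mbps encodes as %u\n", len, buf[6]);
    return 0;
}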
+ */ + chandef.width = sdata->vif.bss_conf.chandef.width; + if (flags & IEEE80211_PROBE_FLAG_DIRECTED) + chandef.chan = NULL; + else + chandef.chan = chan; + + skb = ieee80211_probereq_get(&local->hw, src, ssid, ssid_len, + local->scan_ies_len + ie_len); + if (!skb) + return NULL; + + rate_masks[chan->band] = ratemask; + ies_len = ieee80211_build_preq_ies(sdata, skb_tail_pointer(skb), + skb_tailroom(skb), &dummy_ie_desc, + ie, ie_len, BIT(chan->band), + rate_masks, &chandef, flags); + skb_put(skb, ies_len); + + if (dst) { + mgmt = (struct ieee80211_mgmt *) skb->data; + memcpy(mgmt->da, dst, ETH_ALEN); + memcpy(mgmt->bssid, dst, ETH_ALEN); + } + + IEEE80211_SKB_CB(skb)->flags |= IEEE80211_TX_INTFL_DONT_ENCRYPT; + + return skb; +} + +u32 ieee80211_sta_get_rates(struct ieee80211_sub_if_data *sdata, + struct ieee802_11_elems *elems, + enum nl80211_band band, u32 *basic_rates) +{ + struct ieee80211_supported_band *sband; + size_t num_rates; + u32 supp_rates, rate_flags; + int i, j, shift; + + sband = sdata->local->hw.wiphy->bands[band]; + if (WARN_ON(!sband)) + return 1; + + rate_flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chandef); + shift = ieee80211_vif_get_shift(&sdata->vif); + + num_rates = sband->n_bitrates; + supp_rates = 0; + for (i = 0; i < elems->supp_rates_len + + elems->ext_supp_rates_len; i++) { + u8 rate = 0; + int own_rate; + bool is_basic; + if (i < elems->supp_rates_len) + rate = elems->supp_rates[i]; + else if (elems->ext_supp_rates) + rate = elems->ext_supp_rates + [i - elems->supp_rates_len]; + own_rate = 5 * (rate & 0x7f); + is_basic = !!(rate & 0x80); + + if (is_basic && (rate & 0x7f) == BSS_MEMBERSHIP_SELECTOR_HT_PHY) + continue; + + for (j = 0; j < num_rates; j++) { + int brate; + if ((rate_flags & sband->bitrates[j].flags) + != rate_flags) + continue; + + brate = DIV_ROUND_UP(sband->bitrates[j].bitrate, + 1 << shift); + + if (brate == own_rate) { + supp_rates |= BIT(j); + if (basic_rates && is_basic) + *basic_rates |= BIT(j); + } + } + } + return supp_rates; +} + +void ieee80211_stop_device(struct ieee80211_local *local) +{ + ieee80211_led_radio(local, false); + ieee80211_mod_tpt_led_trig(local, 0, IEEE80211_TPT_LEDTRIG_FL_RADIO); + + cancel_work_sync(&local->reconfig_filter); + + flush_workqueue(local->workqueue); + drv_stop(local); +} + +static void ieee80211_flush_completed_scan(struct ieee80211_local *local, + bool aborted) +{ + /* It's possible that we don't handle the scan completion in + * time during suspend, so if it's still marked as completed + * here, queue the work and flush it to clean things up. + * Instead of calling the worker function directly here, we + * really queue it to avoid potential races with other flows + * scheduling the same work. + */ + if (test_bit(SCAN_COMPLETED, &local->scanning)) { + /* If coming from reconfiguration failure, abort the scan so + * we don't attempt to continue a partial HW scan - which is + * possible otherwise if (e.g.) the 2.4 GHz portion was the + * completed scan, and a 5 GHz portion is still pending. + */ + if (aborted) + set_bit(SCAN_ABORTED, &local->scanning); + wiphy_delayed_work_queue(local->hw.wiphy, &local->scan_work, 0); + wiphy_delayed_work_flush(local->hw.wiphy, &local->scan_work); + } +} + +static void ieee80211_handle_reconfig_failure(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_chanctx *ctx; + + /* + * We get here if during resume the device can't be restarted properly. 
+ * We might also get here if this happens during HW reset, which is a + * slightly different situation and we need to drop all connections in + * the latter case. + * + * Ask cfg80211 to turn off all interfaces, this will result in more + * warnings but at least we'll then get into a clean stopped state. + */ + + local->resuming = false; + local->suspended = false; + local->in_reconfig = false; + + ieee80211_flush_completed_scan(local, true); + + /* scheduled scan clearly can't be running any more, but tell + * cfg80211 and clear local state + */ + ieee80211_sched_scan_end(local); + + list_for_each_entry(sdata, &local->interfaces, list) + sdata->flags &= ~IEEE80211_SDATA_IN_DRIVER; + + /* Mark channel contexts as not being in the driver any more to avoid + * removing them from the driver during the shutdown process... + */ + mutex_lock(&local->chanctx_mtx); + list_for_each_entry(ctx, &local->chanctx_list, list) + ctx->driver_present = false; + mutex_unlock(&local->chanctx_mtx); +} + +static void ieee80211_assign_chanctx(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link) +{ + struct ieee80211_chanctx_conf *conf; + struct ieee80211_chanctx *ctx; + + if (!local->use_chanctx) + return; + + mutex_lock(&local->chanctx_mtx); + conf = rcu_dereference_protected(link->conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + if (conf) { + ctx = container_of(conf, struct ieee80211_chanctx, conf); + drv_assign_vif_chanctx(local, sdata, link->conf, ctx); + } + mutex_unlock(&local->chanctx_mtx); +} + +static void ieee80211_reconfig_stations(struct ieee80211_sub_if_data *sdata) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta; + + /* add STAs back */ + mutex_lock(&local->sta_mtx); + list_for_each_entry(sta, &local->sta_list, list) { + enum ieee80211_sta_state state; + + if (!sta->uploaded || sta->sdata != sdata) + continue; + + for (state = IEEE80211_STA_NOTEXIST; + state < sta->sta_state; state++) + WARN_ON(drv_sta_state(local, sta->sdata, sta, state, + state + 1)); + } + mutex_unlock(&local->sta_mtx); +} + +static int ieee80211_reconfig_nan(struct ieee80211_sub_if_data *sdata) +{ + struct cfg80211_nan_func *func, **funcs; + int res, id, i = 0; + + res = drv_start_nan(sdata->local, sdata, + &sdata->u.nan.conf); + if (WARN_ON(res)) + return res; + + funcs = kcalloc(sdata->local->hw.max_nan_de_entries + 1, + sizeof(*funcs), + GFP_KERNEL); + if (!funcs) + return -ENOMEM; + + /* Add all the functions: + * This is a little bit ugly. We need to call a potentially sleeping + * callback for each NAN function, so we can't hold the spinlock. 
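ieee80211_reconfig_stations() above replays each uploaded station into the driver by stepping drv_sta_state() one transition at a time, from NOTEXIST up to the station's current state, never jumping. A minimal standalone model of that stepping, assuming the usual mac80211 ordering NOTEXIST -> NONE -> AUTH -> ASSOC -> AUTHORIZED; the function names are illustrative.

/* Toy model of the state replay in ieee80211_reconfig_stations(): the driver
 * sees every intermediate transition, so a station that was AUTHORIZED
 * before the restart is re-added via four steps.
 */
#include <stdio.h>

enum sta_state { NOTEXIST, NONE, AUTH, ASSOC, AUTHORIZED };

static const char * const names[] = {
    "NOTEXIST", "NONE", "AUTH", "ASSOC", "AUTHORIZED"
};

static void drv_sta_state_model(enum sta_state old, enum sta_state new)
{
    printf("driver: %s -> %s\n", names[old], names[new]);
}

static void replay_station(enum sta_state current)
{
    for (enum sta_state s = NOTEXIST; s < current; s++)
        drv_sta_state_model(s, s + 1);
}

int main(void)
{
    replay_station(AUTHORIZED);
    return 0;
}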
+ */ + spin_lock_bh(&sdata->u.nan.func_lock); + + idr_for_each_entry(&sdata->u.nan.function_inst_ids, func, id) + funcs[i++] = func; + + spin_unlock_bh(&sdata->u.nan.func_lock); + + for (i = 0; funcs[i]; i++) { + res = drv_add_nan_func(sdata->local, sdata, funcs[i]); + if (WARN_ON(res)) + ieee80211_nan_func_terminated(&sdata->vif, + funcs[i]->instance_id, + NL80211_NAN_FUNC_TERM_REASON_ERROR, + GFP_KERNEL); + } + + kfree(funcs); + + return 0; +} + +int ieee80211_reconfig(struct ieee80211_local *local) +{ + struct ieee80211_hw *hw = &local->hw; + struct ieee80211_sub_if_data *sdata; + struct ieee80211_chanctx *ctx; + struct sta_info *sta; + int res, i; + bool reconfig_due_to_wowlan = false; + struct ieee80211_sub_if_data *sched_scan_sdata; + struct cfg80211_sched_scan_request *sched_scan_req; + bool sched_scan_stopped = false; + bool suspended = local->suspended; + bool in_reconfig = false; + + /* nothing to do if HW shouldn't run */ + if (!local->open_count) + goto wake_up; + +#ifdef CONFIG_PM + if (suspended) + local->resuming = true; + + if (local->wowlan) { + /* + * In the wowlan case, both mac80211 and the device + * are functional when the resume op is called, so + * clear local->suspended so the device could operate + * normally (e.g. pass rx frames). + */ + local->suspended = false; + res = drv_resume(local); + local->wowlan = false; + if (res < 0) { + local->resuming = false; + return res; + } + if (res == 0) + goto wake_up; + WARN_ON(res > 1); + /* + * res is 1, which means the driver requested + * to go through a regular reset on wakeup. + * restore local->suspended in this case. + */ + reconfig_due_to_wowlan = true; + local->suspended = true; + } +#endif + + /* + * In case of hw_restart during suspend (without wowlan), + * cancel restart work, as we are reconfiguring the device + * anyway. + * Note that restart_work is scheduled on a frozen workqueue, + * so we can't deadlock in this case. + */ + if (suspended && local->in_reconfig && !reconfig_due_to_wowlan) + cancel_work_sync(&local->restart_work); + + local->started = false; + + /* + * Upon resume hardware can sometimes be goofy due to + * various platform / driver / bus issues, so restarting + * the device may at times not work immediately. Propagate + * the error. + */ + res = drv_start(local); + if (res) { + if (suspended) + WARN(1, "Hardware became unavailable upon resume. 
This could be a software issue prior to suspend or a hardware issue.\n"); + else + WARN(1, "Hardware became unavailable during restart.\n"); + ieee80211_handle_reconfig_failure(local); + return res; + } + + /* setup fragmentation threshold */ + drv_set_frag_threshold(local, hw->wiphy->frag_threshold); + + /* setup RTS threshold */ + drv_set_rts_threshold(local, hw->wiphy->rts_threshold); + + /* reset coverage class */ + drv_set_coverage_class(local, hw->wiphy->coverage_class); + + ieee80211_led_radio(local, true); + ieee80211_mod_tpt_led_trig(local, + IEEE80211_TPT_LEDTRIG_FL_RADIO, 0); + + /* add interfaces */ + sdata = wiphy_dereference(local->hw.wiphy, local->monitor_sdata); + if (sdata) { + /* in HW restart it exists already */ + WARN_ON(local->resuming); + res = drv_add_interface(local, sdata); + if (WARN_ON(res)) { + RCU_INIT_POINTER(local->monitor_sdata, NULL); + synchronize_net(); + kfree(sdata); + } + } + + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_MONITOR && + ieee80211_sdata_running(sdata)) { + res = drv_add_interface(local, sdata); + if (WARN_ON(res)) + break; + } + } + + /* If adding any of the interfaces failed above, roll back and + * report failure. + */ + if (res) { + list_for_each_entry_continue_reverse(sdata, &local->interfaces, + list) + if (sdata->vif.type != NL80211_IFTYPE_AP_VLAN && + sdata->vif.type != NL80211_IFTYPE_MONITOR && + ieee80211_sdata_running(sdata)) + drv_remove_interface(local, sdata); + ieee80211_handle_reconfig_failure(local); + return res; + } + + /* add channel contexts */ + if (local->use_chanctx) { + mutex_lock(&local->chanctx_mtx); + list_for_each_entry(ctx, &local->chanctx_list, list) + if (ctx->replace_state != + IEEE80211_CHANCTX_REPLACES_OTHER) + WARN_ON(drv_add_chanctx(local, ctx)); + mutex_unlock(&local->chanctx_mtx); + + sdata = wiphy_dereference(local->hw.wiphy, + local->monitor_sdata); + if (sdata && ieee80211_sdata_running(sdata)) + ieee80211_assign_chanctx(local, sdata, &sdata->deflink); + } + + /* reconfigure hardware */ + ieee80211_hw_config(local, ~0); + + ieee80211_configure_filter(local); + + /* Finally also reconfigure all the BSS information */ + list_for_each_entry(sdata, &local->interfaces, list) { + unsigned int link_id; + u32 changed; + + if (!ieee80211_sdata_running(sdata)) + continue; + + sdata_lock(sdata); + for (link_id = 0; + link_id < ARRAY_SIZE(sdata->vif.link_conf); + link_id++) { + struct ieee80211_link_data *link; + + link = sdata_dereference(sdata->link[link_id], sdata); + if (link) + ieee80211_assign_chanctx(local, sdata, link); + } + + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MONITOR: + break; + case NL80211_IFTYPE_ADHOC: + if (sdata->vif.cfg.ibss_joined) + WARN_ON(drv_join_ibss(local, sdata)); + fallthrough; + default: + ieee80211_reconfig_stations(sdata); + fallthrough; + case NL80211_IFTYPE_AP: /* AP stations are handled later */ + for (i = 0; i < IEEE80211_NUM_ACS; i++) + drv_conf_tx(local, &sdata->deflink, i, + &sdata->deflink.tx_conf[i]); + break; + } + sdata_unlock(sdata); + + /* common change flags for all interface types */ + changed = BSS_CHANGED_ERP_CTS_PROT | + BSS_CHANGED_ERP_PREAMBLE | + BSS_CHANGED_ERP_SLOT | + BSS_CHANGED_HT | + BSS_CHANGED_BASIC_RATES | + BSS_CHANGED_BEACON_INT | + BSS_CHANGED_BSSID | + BSS_CHANGED_CQM | + BSS_CHANGED_QOS | + BSS_CHANGED_IDLE | + BSS_CHANGED_TXPOWER | + BSS_CHANGED_MCAST_RATE; + + if (sdata->vif.bss_conf.mu_mimo_owner) + 
changed |= BSS_CHANGED_MU_GROUPS; + + switch (sdata->vif.type) { + case NL80211_IFTYPE_STATION: + changed |= BSS_CHANGED_ASSOC | + BSS_CHANGED_ARP_FILTER | + BSS_CHANGED_PS; + + /* Re-send beacon info report to the driver */ + if (sdata->deflink.u.mgd.have_beacon) + changed |= BSS_CHANGED_BEACON_INFO; + + if (sdata->vif.bss_conf.max_idle_period || + sdata->vif.bss_conf.protected_keep_alive) + changed |= BSS_CHANGED_KEEP_ALIVE; + + sdata_lock(sdata); + ieee80211_bss_info_change_notify(sdata, changed); + sdata_unlock(sdata); + break; + case NL80211_IFTYPE_OCB: + changed |= BSS_CHANGED_OCB; + ieee80211_bss_info_change_notify(sdata, changed); + break; + case NL80211_IFTYPE_ADHOC: + changed |= BSS_CHANGED_IBSS; + fallthrough; + case NL80211_IFTYPE_AP: + changed |= BSS_CHANGED_SSID | BSS_CHANGED_P2P_PS; + + if (sdata->vif.bss_conf.ftm_responder == 1 && + wiphy_ext_feature_isset(sdata->local->hw.wiphy, + NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER)) + changed |= BSS_CHANGED_FTM_RESPONDER; + + if (sdata->vif.type == NL80211_IFTYPE_AP) { + changed |= BSS_CHANGED_AP_PROBE_RESP; + + if (rcu_access_pointer(sdata->deflink.u.ap.beacon)) + drv_start_ap(local, sdata, + sdata->deflink.conf); + } + fallthrough; + case NL80211_IFTYPE_MESH_POINT: + if (sdata->vif.bss_conf.enable_beacon) { + changed |= BSS_CHANGED_BEACON | + BSS_CHANGED_BEACON_ENABLED; + ieee80211_bss_info_change_notify(sdata, changed); + } + break; + case NL80211_IFTYPE_NAN: + res = ieee80211_reconfig_nan(sdata); + if (res < 0) { + ieee80211_handle_reconfig_failure(local); + return res; + } + break; + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_P2P_DEVICE: + /* nothing to do */ + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_WDS: + WARN_ON(1); + break; + } + } + + ieee80211_recalc_ps(local); + + /* + * The sta might be in psm against the ap (e.g. because + * this was the state before a hw restart), so we + * explicitly send a null packet in order to make sure + * it'll sync against the ap (and get out of psm). + */ + if (!(local->hw.conf.flags & IEEE80211_CONF_PS)) { + list_for_each_entry(sdata, &local->interfaces, list) { + if (sdata->vif.type != NL80211_IFTYPE_STATION) + continue; + if (!sdata->u.mgd.associated) + continue; + + ieee80211_send_nullfunc(local, sdata, false); + } + } + + /* APs are now beaconing, add back stations */ + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + + sdata_lock(sdata); + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_AP: + ieee80211_reconfig_stations(sdata); + break; + default: + break; + } + sdata_unlock(sdata); + } + + /* add back keys */ + list_for_each_entry(sdata, &local->interfaces, list) + ieee80211_reenable_keys(sdata); + + /* Reconfigure sched scan if it was interrupted by FW restart */ + mutex_lock(&local->mtx); + sched_scan_sdata = rcu_dereference_protected(local->sched_scan_sdata, + lockdep_is_held(&local->mtx)); + sched_scan_req = rcu_dereference_protected(local->sched_scan_req, + lockdep_is_held(&local->mtx)); + if (sched_scan_sdata && sched_scan_req) + /* + * Sched scan stopped, but we don't want to report it. Instead, + * we're trying to reschedule. However, if more than one scan + * plan was set, we cannot reschedule since we don't know which + * scan plan was currently running (and some scan plans may have + * already finished). 
+ */ + if (sched_scan_req->n_scan_plans > 1 || + __ieee80211_request_sched_scan_start(sched_scan_sdata, + sched_scan_req)) { + RCU_INIT_POINTER(local->sched_scan_sdata, NULL); + RCU_INIT_POINTER(local->sched_scan_req, NULL); + sched_scan_stopped = true; + } + mutex_unlock(&local->mtx); + + if (sched_scan_stopped) + cfg80211_sched_scan_stopped_locked(local->hw.wiphy, 0); + + wake_up: + + if (local->monitors == local->open_count && local->monitors > 0) + ieee80211_add_virtual_monitor(local); + + /* + * Clear the WLAN_STA_BLOCK_BA flag so new aggregation + * sessions can be established after a resume. + * + * Also tear down aggregation sessions since reconfiguring + * them in a hardware restart scenario is not easily done + * right now, and the hardware will have lost information + * about the sessions, but we and the AP still think they + * are active. This is really a workaround though. + */ + if (ieee80211_hw_check(hw, AMPDU_AGGREGATION)) { + mutex_lock(&local->sta_mtx); + + list_for_each_entry(sta, &local->sta_list, list) { + if (!local->resuming) + ieee80211_sta_tear_down_BA_sessions( + sta, AGG_STOP_LOCAL_REQUEST); + clear_sta_flag(sta, WLAN_STA_BLOCK_BA); + } + + mutex_unlock(&local->sta_mtx); + } + + /* + * If this is for hw restart things are still running. + * We may want to change that later, however. + */ + if (local->open_count && (!suspended || reconfig_due_to_wowlan)) + drv_reconfig_complete(local, IEEE80211_RECONFIG_TYPE_RESTART); + + if (local->in_reconfig) { + in_reconfig = local->in_reconfig; + local->in_reconfig = false; + barrier(); + + /* Restart deferred ROCs */ + mutex_lock(&local->mtx); + ieee80211_start_next_roc(local); + mutex_unlock(&local->mtx); + + /* Requeue all works */ + list_for_each_entry(sdata, &local->interfaces, list) + ieee80211_queue_work(&local->hw, &sdata->work); + } + + ieee80211_wake_queues_by_reason(hw, IEEE80211_MAX_QUEUE_MAP, + IEEE80211_QUEUE_STOP_REASON_SUSPEND, + false); + + if (in_reconfig) { + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + if (sdata->vif.type == NL80211_IFTYPE_STATION) + ieee80211_sta_restart(sdata); + } + } + + if (!suspended) + return 0; + +#ifdef CONFIG_PM + /* first set suspended false, then resuming */ + local->suspended = false; + mb(); + local->resuming = false; + + ieee80211_flush_completed_scan(local, false); + + if (local->open_count && !reconfig_due_to_wowlan) + drv_reconfig_complete(local, IEEE80211_RECONFIG_TYPE_SUSPEND); + + list_for_each_entry(sdata, &local->interfaces, list) { + if (!ieee80211_sdata_running(sdata)) + continue; + if (sdata->vif.type == NL80211_IFTYPE_STATION) + ieee80211_sta_restart(sdata); + } + + mod_timer(&local->sta_cleanup, jiffies + 1); +#else + WARN_ON(1); +#endif + + return 0; +} + +static void ieee80211_reconfig_disconnect(struct ieee80211_vif *vif, u8 flag) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_local *local; + struct ieee80211_key *key; + + if (WARN_ON(!vif)) + return; + + sdata = vif_to_sdata(vif); + local = sdata->local; + + if (WARN_ON(flag & IEEE80211_SDATA_DISCONNECT_RESUME && + !local->resuming)) + return; + + if (WARN_ON(flag & IEEE80211_SDATA_DISCONNECT_HW_RESTART && + !local->in_reconfig)) + return; + + if (WARN_ON(vif->type != NL80211_IFTYPE_STATION)) + return; + + sdata->flags |= flag; + + mutex_lock(&local->key_mtx); + list_for_each_entry(key, &sdata->key_list, list) + key->flags |= KEY_FLAG_TAINTED; + mutex_unlock(&local->key_mtx); +} + +void ieee80211_hw_restart_disconnect(struct 
ieee80211_vif *vif) +{ + ieee80211_reconfig_disconnect(vif, IEEE80211_SDATA_DISCONNECT_HW_RESTART); +} +EXPORT_SYMBOL_GPL(ieee80211_hw_restart_disconnect); + +void ieee80211_resume_disconnect(struct ieee80211_vif *vif) +{ + ieee80211_reconfig_disconnect(vif, IEEE80211_SDATA_DISCONNECT_RESUME); +} +EXPORT_SYMBOL_GPL(ieee80211_resume_disconnect); + +void ieee80211_recalc_smps(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_chanctx *chanctx; + + mutex_lock(&local->chanctx_mtx); + + chanctx_conf = rcu_dereference_protected(link->conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + + /* + * This function can be called from a work, thus it may be possible + * that the chanctx_conf is removed (due to a disconnection, for + * example). + * So nothing should be done in such case. + */ + if (!chanctx_conf) + goto unlock; + + chanctx = container_of(chanctx_conf, struct ieee80211_chanctx, conf); + ieee80211_recalc_smps_chanctx(local, chanctx); + unlock: + mutex_unlock(&local->chanctx_mtx); +} + +void ieee80211_recalc_min_chandef(struct ieee80211_sub_if_data *sdata, + int link_id) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_chanctx_conf *chanctx_conf; + struct ieee80211_chanctx *chanctx; + int i; + + mutex_lock(&local->chanctx_mtx); + + for (i = 0; i < ARRAY_SIZE(sdata->vif.link_conf); i++) { + struct ieee80211_bss_conf *bss_conf; + + if (link_id >= 0 && link_id != i) + continue; + + rcu_read_lock(); + bss_conf = rcu_dereference(sdata->vif.link_conf[i]); + if (!bss_conf) { + rcu_read_unlock(); + continue; + } + + chanctx_conf = rcu_dereference_protected(bss_conf->chanctx_conf, + lockdep_is_held(&local->chanctx_mtx)); + /* + * Since we hold the chanctx_mtx (checked above) + * we can take the chanctx_conf pointer out of the + * RCU critical section, it cannot go away without + * the mutex. Just the way we reached it could - in + * theory - go away, but we don't really care and + * it really shouldn't happen anyway. 
+ */ + rcu_read_unlock(); + + if (!chanctx_conf) + goto unlock; + + chanctx = container_of(chanctx_conf, struct ieee80211_chanctx, + conf); + ieee80211_recalc_chanctx_min_def(local, chanctx, NULL); + } + unlock: + mutex_unlock(&local->chanctx_mtx); +} + +size_t ieee80211_ie_split_vendor(const u8 *ies, size_t ielen, size_t offset) +{ + size_t pos = offset; + + while (pos < ielen && ies[pos] != WLAN_EID_VENDOR_SPECIFIC) + pos += 2 + ies[pos + 1]; + + return pos; +} + +u8 *ieee80211_ie_build_ht_cap(u8 *pos, struct ieee80211_sta_ht_cap *ht_cap, + u16 cap) +{ + __le16 tmp; + + *pos++ = WLAN_EID_HT_CAPABILITY; + *pos++ = sizeof(struct ieee80211_ht_cap); + memset(pos, 0, sizeof(struct ieee80211_ht_cap)); + + /* capability flags */ + tmp = cpu_to_le16(cap); + memcpy(pos, &tmp, sizeof(u16)); + pos += sizeof(u16); + + /* AMPDU parameters */ + *pos++ = ht_cap->ampdu_factor | + (ht_cap->ampdu_density << + IEEE80211_HT_AMPDU_PARM_DENSITY_SHIFT); + + /* MCS set */ + memcpy(pos, &ht_cap->mcs, sizeof(ht_cap->mcs)); + pos += sizeof(ht_cap->mcs); + + /* extended capabilities */ + pos += sizeof(__le16); + + /* BF capabilities */ + pos += sizeof(__le32); + + /* antenna selection */ + pos += sizeof(u8); + + return pos; +} + +u8 *ieee80211_ie_build_vht_cap(u8 *pos, struct ieee80211_sta_vht_cap *vht_cap, + u32 cap) +{ + __le32 tmp; + + *pos++ = WLAN_EID_VHT_CAPABILITY; + *pos++ = sizeof(struct ieee80211_vht_cap); + memset(pos, 0, sizeof(struct ieee80211_vht_cap)); + + /* capability flags */ + tmp = cpu_to_le32(cap); + memcpy(pos, &tmp, sizeof(u32)); + pos += sizeof(u32); + + /* VHT MCS set */ + memcpy(pos, &vht_cap->vht_mcs, sizeof(vht_cap->vht_mcs)); + pos += sizeof(vht_cap->vht_mcs); + + return pos; +} + +u8 ieee80211_ie_len_he_cap(struct ieee80211_sub_if_data *sdata, u8 iftype) +{ + const struct ieee80211_sta_he_cap *he_cap; + struct ieee80211_supported_band *sband; + u8 n; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return 0; + + he_cap = ieee80211_get_he_iftype_cap(sband, iftype); + if (!he_cap) + return 0; + + n = ieee80211_he_mcs_nss_size(&he_cap->he_cap_elem); + return 2 + 1 + + sizeof(he_cap->he_cap_elem) + n + + ieee80211_he_ppe_size(he_cap->ppe_thres[0], + he_cap->he_cap_elem.phy_cap_info); +} + +u8 *ieee80211_ie_build_he_cap(ieee80211_conn_flags_t disable_flags, u8 *pos, + const struct ieee80211_sta_he_cap *he_cap, + u8 *end) +{ + struct ieee80211_he_cap_elem elem; + u8 n; + u8 ie_len; + u8 *orig_pos = pos; + + /* Make sure we have place for the IE */ + /* + * TODO: the 1 added is because this temporarily is under the EXTENSION + * IE. Get rid of it when it moves. 
+ */ + if (!he_cap) + return orig_pos; + + /* modify on stack first to calculate 'n' and 'ie_len' correctly */ + elem = he_cap->he_cap_elem; + + if (disable_flags & IEEE80211_CONN_DISABLE_40MHZ) + elem.phy_cap_info[0] &= + ~(IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_40MHZ_80MHZ_IN_5G | + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_40MHZ_IN_2G); + + if (disable_flags & IEEE80211_CONN_DISABLE_160MHZ) + elem.phy_cap_info[0] &= + ~IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + + if (disable_flags & IEEE80211_CONN_DISABLE_80P80MHZ) + elem.phy_cap_info[0] &= + ~IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G; + + n = ieee80211_he_mcs_nss_size(&elem); + ie_len = 2 + 1 + + sizeof(he_cap->he_cap_elem) + n + + ieee80211_he_ppe_size(he_cap->ppe_thres[0], + he_cap->he_cap_elem.phy_cap_info); + + if ((end - pos) < ie_len) + return orig_pos; + + *pos++ = WLAN_EID_EXTENSION; + pos++; /* We'll set the size later below */ + *pos++ = WLAN_EID_EXT_HE_CAPABILITY; + + /* Fixed data */ + memcpy(pos, &elem, sizeof(elem)); + pos += sizeof(elem); + + memcpy(pos, &he_cap->he_mcs_nss_supp, n); + pos += n; + + /* Check if PPE Threshold should be present */ + if ((he_cap->he_cap_elem.phy_cap_info[6] & + IEEE80211_HE_PHY_CAP6_PPE_THRESHOLD_PRESENT) == 0) + goto end; + + /* + * Calculate how many PPET16/PPET8 pairs are to come. Algorithm: + * (NSS_M1 + 1) x (num of 1 bits in RU_INDEX_BITMASK) + */ + n = hweight8(he_cap->ppe_thres[0] & + IEEE80211_PPE_THRES_RU_INDEX_BITMASK_MASK); + n *= (1 + ((he_cap->ppe_thres[0] & IEEE80211_PPE_THRES_NSS_MASK) >> + IEEE80211_PPE_THRES_NSS_POS)); + + /* + * Each pair is 6 bits, and we need to add the 7 "header" bits to the + * total size. + */ + n = (n * IEEE80211_PPE_THRES_INFO_PPET_SIZE * 2) + 7; + n = DIV_ROUND_UP(n, 8); + + /* Copy PPE Thresholds */ + memcpy(pos, &he_cap->ppe_thres, n); + pos += n; + +end: + orig_pos[1] = (pos - orig_pos) - 2; + return pos; +} + +void ieee80211_ie_build_he_6ghz_cap(struct ieee80211_sub_if_data *sdata, + enum ieee80211_smps_mode smps_mode, + struct sk_buff *skb) +{ + struct ieee80211_supported_band *sband; + const struct ieee80211_sband_iftype_data *iftd; + enum nl80211_iftype iftype = ieee80211_vif_type_p2p(&sdata->vif); + u8 *pos; + u16 cap; + + if (!cfg80211_any_usable_channels(sdata->local->hw.wiphy, + BIT(NL80211_BAND_6GHZ), + IEEE80211_CHAN_NO_HE)) + return; + + sband = sdata->local->hw.wiphy->bands[NL80211_BAND_6GHZ]; + + iftd = ieee80211_get_sband_iftype_data(sband, iftype); + if (!iftd) + return; + + /* Check for device HE 6 GHz capability before adding element */ + if (!iftd->he_6ghz_capa.capa) + return; + + cap = le16_to_cpu(iftd->he_6ghz_capa.capa); + cap &= ~IEEE80211_HE_6GHZ_CAP_SM_PS; + + switch (smps_mode) { + case IEEE80211_SMPS_AUTOMATIC: + case IEEE80211_SMPS_NUM_MODES: + WARN_ON(1); + fallthrough; + case IEEE80211_SMPS_OFF: + cap |= u16_encode_bits(WLAN_HT_CAP_SM_PS_DISABLED, + IEEE80211_HE_6GHZ_CAP_SM_PS); + break; + case IEEE80211_SMPS_STATIC: + cap |= u16_encode_bits(WLAN_HT_CAP_SM_PS_STATIC, + IEEE80211_HE_6GHZ_CAP_SM_PS); + break; + case IEEE80211_SMPS_DYNAMIC: + cap |= u16_encode_bits(WLAN_HT_CAP_SM_PS_DYNAMIC, + IEEE80211_HE_6GHZ_CAP_SM_PS); + break; + } + + pos = skb_put(skb, 2 + 1 + sizeof(cap)); + ieee80211_write_he_6ghz_cap(pos, cpu_to_le16(cap), + pos + 2 + 1 + sizeof(cap)); +} + +u8 *ieee80211_ie_build_ht_oper(u8 *pos, struct ieee80211_sta_ht_cap *ht_cap, + const struct cfg80211_chan_def *chandef, + u16 prot_mode, bool rifs_mode) +{ + struct ieee80211_ht_operation *ht_oper; + /* Build HT Information */ 
+ *pos++ = WLAN_EID_HT_OPERATION; + *pos++ = sizeof(struct ieee80211_ht_operation); + ht_oper = (struct ieee80211_ht_operation *)pos; + ht_oper->primary_chan = ieee80211_frequency_to_channel( + chandef->chan->center_freq); + switch (chandef->width) { + case NL80211_CHAN_WIDTH_160: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_40: + if (chandef->center_freq1 > chandef->chan->center_freq) + ht_oper->ht_param = IEEE80211_HT_PARAM_CHA_SEC_ABOVE; + else + ht_oper->ht_param = IEEE80211_HT_PARAM_CHA_SEC_BELOW; + break; + case NL80211_CHAN_WIDTH_320: + /* HT information element should not be included on 6GHz */ + WARN_ON(1); + return pos; + default: + ht_oper->ht_param = IEEE80211_HT_PARAM_CHA_SEC_NONE; + break; + } + if (ht_cap->cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 && + chandef->width != NL80211_CHAN_WIDTH_20_NOHT && + chandef->width != NL80211_CHAN_WIDTH_20) + ht_oper->ht_param |= IEEE80211_HT_PARAM_CHAN_WIDTH_ANY; + + if (rifs_mode) + ht_oper->ht_param |= IEEE80211_HT_PARAM_RIFS_MODE; + + ht_oper->operation_mode = cpu_to_le16(prot_mode); + ht_oper->stbc_param = 0x0000; + + /* It seems that Basic MCS set and Supported MCS set + are identical for the first 10 bytes */ + memset(&ht_oper->basic_set, 0, 16); + memcpy(&ht_oper->basic_set, &ht_cap->mcs, 10); + + return pos + sizeof(struct ieee80211_ht_operation); +} + +void ieee80211_ie_build_wide_bw_cs(u8 *pos, + const struct cfg80211_chan_def *chandef) +{ + *pos++ = WLAN_EID_WIDE_BW_CHANNEL_SWITCH; /* EID */ + *pos++ = 3; /* IE length */ + /* New channel width */ + switch (chandef->width) { + case NL80211_CHAN_WIDTH_80: + *pos++ = IEEE80211_VHT_CHANWIDTH_80MHZ; + break; + case NL80211_CHAN_WIDTH_160: + *pos++ = IEEE80211_VHT_CHANWIDTH_160MHZ; + break; + case NL80211_CHAN_WIDTH_80P80: + *pos++ = IEEE80211_VHT_CHANWIDTH_80P80MHZ; + break; + case NL80211_CHAN_WIDTH_320: + /* The behavior is not defined for 320 MHz channels */ + WARN_ON(1); + fallthrough; + default: + *pos++ = IEEE80211_VHT_CHANWIDTH_USE_HT; + } + + /* new center frequency segment 0 */ + *pos++ = ieee80211_frequency_to_channel(chandef->center_freq1); + /* new center frequency segment 1 */ + if (chandef->center_freq2) + *pos++ = ieee80211_frequency_to_channel(chandef->center_freq2); + else + *pos++ = 0; +} + +u8 *ieee80211_ie_build_vht_oper(u8 *pos, struct ieee80211_sta_vht_cap *vht_cap, + const struct cfg80211_chan_def *chandef) +{ + struct ieee80211_vht_operation *vht_oper; + + *pos++ = WLAN_EID_VHT_OPERATION; + *pos++ = sizeof(struct ieee80211_vht_operation); + vht_oper = (struct ieee80211_vht_operation *)pos; + vht_oper->center_freq_seg0_idx = ieee80211_frequency_to_channel( + chandef->center_freq1); + if (chandef->center_freq2) + vht_oper->center_freq_seg1_idx = + ieee80211_frequency_to_channel(chandef->center_freq2); + else + vht_oper->center_freq_seg1_idx = 0x00; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_160: + /* + * Convert 160 MHz channel width to new style as interop + * workaround. + */ + vht_oper->chan_width = IEEE80211_VHT_CHANWIDTH_80MHZ; + vht_oper->center_freq_seg1_idx = vht_oper->center_freq_seg0_idx; + if (chandef->chan->center_freq < chandef->center_freq1) + vht_oper->center_freq_seg0_idx -= 8; + else + vht_oper->center_freq_seg0_idx += 8; + break; + case NL80211_CHAN_WIDTH_80P80: + /* + * Convert 80+80 MHz channel width to new style as interop + * workaround. 
+ */ + vht_oper->chan_width = IEEE80211_VHT_CHANWIDTH_80MHZ; + break; + case NL80211_CHAN_WIDTH_80: + vht_oper->chan_width = IEEE80211_VHT_CHANWIDTH_80MHZ; + break; + case NL80211_CHAN_WIDTH_320: + /* VHT information element should not be included on 6GHz */ + WARN_ON(1); + return pos; + default: + vht_oper->chan_width = IEEE80211_VHT_CHANWIDTH_USE_HT; + break; + } + + /* don't require special VHT peer rates */ + vht_oper->basic_mcs_set = cpu_to_le16(0xffff); + + return pos + sizeof(struct ieee80211_vht_operation); +} + +u8 *ieee80211_ie_build_he_oper(u8 *pos, struct cfg80211_chan_def *chandef) +{ + struct ieee80211_he_operation *he_oper; + struct ieee80211_he_6ghz_oper *he_6ghz_op; + u32 he_oper_params; + u8 ie_len = 1 + sizeof(struct ieee80211_he_operation); + + if (chandef->chan->band == NL80211_BAND_6GHZ) + ie_len += sizeof(struct ieee80211_he_6ghz_oper); + + *pos++ = WLAN_EID_EXTENSION; + *pos++ = ie_len; + *pos++ = WLAN_EID_EXT_HE_OPERATION; + + he_oper_params = 0; + he_oper_params |= u32_encode_bits(1023, /* disabled */ + IEEE80211_HE_OPERATION_RTS_THRESHOLD_MASK); + he_oper_params |= u32_encode_bits(1, + IEEE80211_HE_OPERATION_ER_SU_DISABLE); + he_oper_params |= u32_encode_bits(1, + IEEE80211_HE_OPERATION_BSS_COLOR_DISABLED); + if (chandef->chan->band == NL80211_BAND_6GHZ) + he_oper_params |= u32_encode_bits(1, + IEEE80211_HE_OPERATION_6GHZ_OP_INFO); + + he_oper = (struct ieee80211_he_operation *)pos; + he_oper->he_oper_params = cpu_to_le32(he_oper_params); + + /* don't require special HE peer rates */ + he_oper->he_mcs_nss_set = cpu_to_le16(0xffff); + pos += sizeof(struct ieee80211_he_operation); + + if (chandef->chan->band != NL80211_BAND_6GHZ) + goto out; + + /* TODO add VHT operational */ + he_6ghz_op = (struct ieee80211_he_6ghz_oper *)pos; + he_6ghz_op->minrate = 6; /* 6 Mbps */ + he_6ghz_op->primary = + ieee80211_frequency_to_channel(chandef->chan->center_freq); + he_6ghz_op->ccfs0 = + ieee80211_frequency_to_channel(chandef->center_freq1); + if (chandef->center_freq2) + he_6ghz_op->ccfs1 = + ieee80211_frequency_to_channel(chandef->center_freq2); + else + he_6ghz_op->ccfs1 = 0; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_320: + /* + * TODO: mesh operation is not defined over 6GHz 320 MHz + * channels. + */ + WARN_ON(1); + break; + case NL80211_CHAN_WIDTH_160: + /* Convert 160 MHz channel width to new style as interop + * workaround. 
+ */ + he_6ghz_op->control = + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_160MHZ; + he_6ghz_op->ccfs1 = he_6ghz_op->ccfs0; + if (chandef->chan->center_freq < chandef->center_freq1) + he_6ghz_op->ccfs0 -= 8; + else + he_6ghz_op->ccfs0 += 8; + fallthrough; + case NL80211_CHAN_WIDTH_80P80: + he_6ghz_op->control = + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_160MHZ; + break; + case NL80211_CHAN_WIDTH_80: + he_6ghz_op->control = + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_80MHZ; + break; + case NL80211_CHAN_WIDTH_40: + he_6ghz_op->control = + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_40MHZ; + break; + default: + he_6ghz_op->control = + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_20MHZ; + break; + } + + pos += sizeof(struct ieee80211_he_6ghz_oper); + +out: + return pos; +} + +bool ieee80211_chandef_ht_oper(const struct ieee80211_ht_operation *ht_oper, + struct cfg80211_chan_def *chandef) +{ + enum nl80211_channel_type channel_type; + + if (!ht_oper) + return false; + + switch (ht_oper->ht_param & IEEE80211_HT_PARAM_CHA_SEC_OFFSET) { + case IEEE80211_HT_PARAM_CHA_SEC_NONE: + channel_type = NL80211_CHAN_HT20; + break; + case IEEE80211_HT_PARAM_CHA_SEC_ABOVE: + channel_type = NL80211_CHAN_HT40PLUS; + break; + case IEEE80211_HT_PARAM_CHA_SEC_BELOW: + channel_type = NL80211_CHAN_HT40MINUS; + break; + default: + return false; + } + + cfg80211_chandef_create(chandef, chandef->chan, channel_type); + return true; +} + +bool ieee80211_chandef_vht_oper(struct ieee80211_hw *hw, u32 vht_cap_info, + const struct ieee80211_vht_operation *oper, + const struct ieee80211_ht_operation *htop, + struct cfg80211_chan_def *chandef) +{ + struct cfg80211_chan_def new = *chandef; + int cf0, cf1; + int ccfs0, ccfs1, ccfs2; + int ccf0, ccf1; + u32 vht_cap; + bool support_80_80 = false; + bool support_160 = false; + u8 ext_nss_bw_supp = u32_get_bits(vht_cap_info, + IEEE80211_VHT_CAP_EXT_NSS_BW_MASK); + u8 supp_chwidth = u32_get_bits(vht_cap_info, + IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK); + + if (!oper || !htop) + return false; + + vht_cap = hw->wiphy->bands[chandef->chan->band]->vht_cap.cap; + support_160 = (vht_cap & (IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK | + IEEE80211_VHT_CAP_EXT_NSS_BW_MASK)); + support_80_80 = ((vht_cap & + IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ) || + (vht_cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ && + vht_cap & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) || + ((vht_cap & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) >> + IEEE80211_VHT_CAP_EXT_NSS_BW_SHIFT > 1)); + ccfs0 = oper->center_freq_seg0_idx; + ccfs1 = oper->center_freq_seg1_idx; + ccfs2 = (le16_to_cpu(htop->operation_mode) & + IEEE80211_HT_OP_MODE_CCFS2_MASK) + >> IEEE80211_HT_OP_MODE_CCFS2_SHIFT; + + ccf0 = ccfs0; + + /* if not supported, parse as though we didn't understand it */ + if (!ieee80211_hw_check(hw, SUPPORTS_VHT_EXT_NSS_BW)) + ext_nss_bw_supp = 0; + + /* + * Cf. IEEE 802.11 Table 9-250 + * + * We really just consider that because it's inefficient to connect + * at a higher bandwidth than we'll actually be able to use. 
+ */ + switch ((supp_chwidth << 4) | ext_nss_bw_supp) { + default: + case 0x00: + ccf1 = 0; + support_160 = false; + support_80_80 = false; + break; + case 0x01: + support_80_80 = false; + fallthrough; + case 0x02: + case 0x03: + ccf1 = ccfs2; + break; + case 0x10: + ccf1 = ccfs1; + break; + case 0x11: + case 0x12: + if (!ccfs1) + ccf1 = ccfs2; + else + ccf1 = ccfs1; + break; + case 0x13: + case 0x20: + case 0x23: + ccf1 = ccfs1; + break; + } + + cf0 = ieee80211_channel_to_frequency(ccf0, chandef->chan->band); + cf1 = ieee80211_channel_to_frequency(ccf1, chandef->chan->band); + + switch (oper->chan_width) { + case IEEE80211_VHT_CHANWIDTH_USE_HT: + /* just use HT information directly */ + break; + case IEEE80211_VHT_CHANWIDTH_80MHZ: + new.width = NL80211_CHAN_WIDTH_80; + new.center_freq1 = cf0; + /* If needed, adjust based on the newer interop workaround. */ + if (ccf1) { + unsigned int diff; + + diff = abs(ccf1 - ccf0); + if ((diff == 8) && support_160) { + new.width = NL80211_CHAN_WIDTH_160; + new.center_freq1 = cf1; + } else if ((diff > 8) && support_80_80) { + new.width = NL80211_CHAN_WIDTH_80P80; + new.center_freq2 = cf1; + } + } + break; + case IEEE80211_VHT_CHANWIDTH_160MHZ: + /* deprecated encoding */ + new.width = NL80211_CHAN_WIDTH_160; + new.center_freq1 = cf0; + break; + case IEEE80211_VHT_CHANWIDTH_80P80MHZ: + /* deprecated encoding */ + new.width = NL80211_CHAN_WIDTH_80P80; + new.center_freq1 = cf0; + new.center_freq2 = cf1; + break; + default: + return false; + } + + if (!cfg80211_chandef_valid(&new)) + return false; + + *chandef = new; + return true; +} + +void ieee80211_chandef_eht_oper(const struct ieee80211_eht_operation *eht_oper, + bool support_160, bool support_320, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_eht_operation_info *info = (void *)eht_oper->optional; + + chandef->center_freq1 = + ieee80211_channel_to_frequency(info->ccfs0, + chandef->chan->band); + + switch (u8_get_bits(info->control, + IEEE80211_EHT_OPER_CHAN_WIDTH)) { + case IEEE80211_EHT_OPER_CHAN_WIDTH_20MHZ: + chandef->width = NL80211_CHAN_WIDTH_20; + break; + case IEEE80211_EHT_OPER_CHAN_WIDTH_40MHZ: + chandef->width = NL80211_CHAN_WIDTH_40; + break; + case IEEE80211_EHT_OPER_CHAN_WIDTH_80MHZ: + chandef->width = NL80211_CHAN_WIDTH_80; + break; + case IEEE80211_EHT_OPER_CHAN_WIDTH_160MHZ: + if (support_160) { + chandef->width = NL80211_CHAN_WIDTH_160; + chandef->center_freq1 = + ieee80211_channel_to_frequency(info->ccfs1, + chandef->chan->band); + } else { + chandef->width = NL80211_CHAN_WIDTH_80; + } + break; + case IEEE80211_EHT_OPER_CHAN_WIDTH_320MHZ: + if (support_320) { + chandef->width = NL80211_CHAN_WIDTH_320; + chandef->center_freq1 = + ieee80211_channel_to_frequency(info->ccfs1, + chandef->chan->band); + } else if (support_160) { + chandef->width = NL80211_CHAN_WIDTH_160; + } else { + chandef->width = NL80211_CHAN_WIDTH_80; + + if (chandef->center_freq1 > chandef->chan->center_freq) + chandef->center_freq1 -= 40; + else + chandef->center_freq1 += 40; + } + break; + } +} + +bool ieee80211_chandef_he_6ghz_oper(struct ieee80211_sub_if_data *sdata, + const struct ieee80211_he_operation *he_oper, + const struct ieee80211_eht_operation *eht_oper, + struct cfg80211_chan_def *chandef) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + enum nl80211_iftype iftype = ieee80211_vif_type_p2p(&sdata->vif); + const struct ieee80211_sta_he_cap *he_cap; + const struct ieee80211_sta_eht_cap *eht_cap; + struct cfg80211_chan_def he_chandef = *chandef; + 
const struct ieee80211_he_6ghz_oper *he_6ghz_oper; + struct ieee80211_bss_conf *bss_conf = &sdata->vif.bss_conf; + bool support_80_80, support_160, support_320; + u8 he_phy_cap, eht_phy_cap; + u32 freq; + + if (chandef->chan->band != NL80211_BAND_6GHZ) + return true; + + sband = local->hw.wiphy->bands[NL80211_BAND_6GHZ]; + + he_cap = ieee80211_get_he_iftype_cap(sband, iftype); + if (!he_cap) { + sdata_info(sdata, "Missing iftype sband data/HE cap"); + return false; + } + + he_phy_cap = he_cap->he_cap_elem.phy_cap_info[0]; + support_160 = + he_phy_cap & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + support_80_80 = + he_phy_cap & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G; + + if (!he_oper) { + sdata_info(sdata, + "HE is not advertised on (on %d MHz), expect issues\n", + chandef->chan->center_freq); + return false; + } + + eht_cap = ieee80211_get_eht_iftype_cap(sband, iftype); + if (!eht_cap) + eht_oper = NULL; + + he_6ghz_oper = ieee80211_he_6ghz_oper(he_oper); + + if (!he_6ghz_oper) { + sdata_info(sdata, + "HE 6GHz operation missing (on %d MHz), expect issues\n", + chandef->chan->center_freq); + return false; + } + + /* + * The EHT operation IE does not contain the primary channel so the + * primary channel frequency should be taken from the 6 GHz operation + * information. + */ + freq = ieee80211_channel_to_frequency(he_6ghz_oper->primary, + NL80211_BAND_6GHZ); + he_chandef.chan = ieee80211_get_channel(sdata->local->hw.wiphy, freq); + + switch (u8_get_bits(he_6ghz_oper->control, + IEEE80211_HE_6GHZ_OPER_CTRL_REG_INFO)) { + case IEEE80211_6GHZ_CTRL_REG_LPI_AP: + bss_conf->power_type = IEEE80211_REG_LPI_AP; + break; + case IEEE80211_6GHZ_CTRL_REG_SP_AP: + bss_conf->power_type = IEEE80211_REG_SP_AP; + break; + default: + bss_conf->power_type = IEEE80211_REG_UNSET_AP; + break; + } + + if (!eht_oper || + !(eht_oper->params & IEEE80211_EHT_OPER_INFO_PRESENT)) { + switch (u8_get_bits(he_6ghz_oper->control, + IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH)) { + case IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_20MHZ: + he_chandef.width = NL80211_CHAN_WIDTH_20; + break; + case IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_40MHZ: + he_chandef.width = NL80211_CHAN_WIDTH_40; + break; + case IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_80MHZ: + he_chandef.width = NL80211_CHAN_WIDTH_80; + break; + case IEEE80211_HE_6GHZ_OPER_CTRL_CHANWIDTH_160MHZ: + he_chandef.width = NL80211_CHAN_WIDTH_80; + if (!he_6ghz_oper->ccfs1) + break; + if (abs(he_6ghz_oper->ccfs1 - he_6ghz_oper->ccfs0) == 8) { + if (support_160) + he_chandef.width = NL80211_CHAN_WIDTH_160; + } else { + if (support_80_80) + he_chandef.width = NL80211_CHAN_WIDTH_80P80; + } + break; + } + + if (he_chandef.width == NL80211_CHAN_WIDTH_160) { + he_chandef.center_freq1 = + ieee80211_channel_to_frequency(he_6ghz_oper->ccfs1, + NL80211_BAND_6GHZ); + } else { + he_chandef.center_freq1 = + ieee80211_channel_to_frequency(he_6ghz_oper->ccfs0, + NL80211_BAND_6GHZ); + if (support_80_80 || support_160) + he_chandef.center_freq2 = + ieee80211_channel_to_frequency(he_6ghz_oper->ccfs1, + NL80211_BAND_6GHZ); + } + } else { + eht_phy_cap = eht_cap->eht_cap_elem.phy_cap_info[0]; + support_320 = + eht_phy_cap & IEEE80211_EHT_PHY_CAP0_320MHZ_IN_6GHZ; + + ieee80211_chandef_eht_oper(eht_oper, support_160, + support_320, &he_chandef); + } + + if (!cfg80211_chandef_valid(&he_chandef)) { + sdata_info(sdata, + "HE 6GHz operation resulted in invalid chandef: %d MHz/%d/%d MHz/%d MHz\n", + he_chandef.chan ? 
he_chandef.chan->center_freq : 0, + he_chandef.width, + he_chandef.center_freq1, + he_chandef.center_freq2); + return false; + } + + *chandef = he_chandef; + + return true; +} + +bool ieee80211_chandef_s1g_oper(const struct ieee80211_s1g_oper_ie *oper, + struct cfg80211_chan_def *chandef) +{ + u32 oper_freq; + + if (!oper) + return false; + + switch (FIELD_GET(S1G_OPER_CH_WIDTH_OPER, oper->ch_width)) { + case IEEE80211_S1G_CHANWIDTH_1MHZ: + chandef->width = NL80211_CHAN_WIDTH_1; + break; + case IEEE80211_S1G_CHANWIDTH_2MHZ: + chandef->width = NL80211_CHAN_WIDTH_2; + break; + case IEEE80211_S1G_CHANWIDTH_4MHZ: + chandef->width = NL80211_CHAN_WIDTH_4; + break; + case IEEE80211_S1G_CHANWIDTH_8MHZ: + chandef->width = NL80211_CHAN_WIDTH_8; + break; + case IEEE80211_S1G_CHANWIDTH_16MHZ: + chandef->width = NL80211_CHAN_WIDTH_16; + break; + default: + return false; + } + + oper_freq = ieee80211_channel_to_freq_khz(oper->oper_ch, + NL80211_BAND_S1GHZ); + chandef->center_freq1 = KHZ_TO_MHZ(oper_freq); + chandef->freq1_offset = oper_freq % 1000; + + return true; +} + +int ieee80211_parse_bitrates(enum nl80211_chan_width width, + const struct ieee80211_supported_band *sband, + const u8 *srates, int srates_len, u32 *rates) +{ + u32 rate_flags = ieee80211_chanwidth_rate_flags(width); + int shift = ieee80211_chanwidth_get_shift(width); + struct ieee80211_rate *br; + int brate, rate, i, j, count = 0; + + *rates = 0; + + for (i = 0; i < srates_len; i++) { + rate = srates[i] & 0x7f; + + for (j = 0; j < sband->n_bitrates; j++) { + br = &sband->bitrates[j]; + if ((rate_flags & br->flags) != rate_flags) + continue; + + brate = DIV_ROUND_UP(br->bitrate, (1 << shift) * 5); + if (brate == rate) { + *rates |= BIT(j); + count++; + break; + } + } + } + return count; +} + +int ieee80211_add_srates_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, bool need_basic, + enum nl80211_band band) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + int rate, shift; + u8 i, rates, *pos; + u32 basic_rates = sdata->vif.bss_conf.basic_rates; + u32 rate_flags; + + shift = ieee80211_vif_get_shift(&sdata->vif); + rate_flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chandef); + sband = local->hw.wiphy->bands[band]; + rates = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + rates++; + } + if (rates > 8) + rates = 8; + + if (skb_tailroom(skb) < rates + 2) + return -ENOMEM; + + pos = skb_put(skb, rates + 2); + *pos++ = WLAN_EID_SUPP_RATES; + *pos++ = rates; + for (i = 0; i < rates; i++) { + u8 basic = 0; + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + + if (need_basic && basic_rates & BIT(i)) + basic = 0x80; + rate = DIV_ROUND_UP(sband->bitrates[i].bitrate, + 5 * (1 << shift)); + *pos++ = basic | (u8) rate; + } + + return 0; +} + +int ieee80211_add_ext_srates_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, bool need_basic, + enum nl80211_band band) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband; + int rate, shift; + u8 i, exrates, *pos; + u32 basic_rates = sdata->vif.bss_conf.basic_rates; + u32 rate_flags; + + rate_flags = ieee80211_chandef_rate_flags(&sdata->vif.bss_conf.chandef); + shift = ieee80211_vif_get_shift(&sdata->vif); + + sband = local->hw.wiphy->bands[band]; + exrates = 0; + for (i = 0; i < sband->n_bitrates; i++) { + if ((rate_flags & sband->bitrates[i].flags) != rate_flags) + continue; + exrates++; + 
} + + if (exrates > 8) + exrates -= 8; + else + exrates = 0; + + if (skb_tailroom(skb) < exrates + 2) + return -ENOMEM; + + if (exrates) { + pos = skb_put(skb, exrates + 2); + *pos++ = WLAN_EID_EXT_SUPP_RATES; + *pos++ = exrates; + for (i = 8; i < sband->n_bitrates; i++) { + u8 basic = 0; + if ((rate_flags & sband->bitrates[i].flags) + != rate_flags) + continue; + if (need_basic && basic_rates & BIT(i)) + basic = 0x80; + rate = DIV_ROUND_UP(sband->bitrates[i].bitrate, + 5 * (1 << shift)); + *pos++ = basic | (u8) rate; + } + } + return 0; +} + +int ieee80211_ave_rssi(struct ieee80211_vif *vif) +{ + struct ieee80211_sub_if_data *sdata = vif_to_sdata(vif); + + if (WARN_ON_ONCE(sdata->vif.type != NL80211_IFTYPE_STATION)) + return 0; + + return -ewma_beacon_signal_read(&sdata->deflink.u.mgd.ave_beacon_signal); +} +EXPORT_SYMBOL_GPL(ieee80211_ave_rssi); + +u8 ieee80211_mcs_to_chains(const struct ieee80211_mcs_info *mcs) +{ + if (!mcs) + return 1; + + /* TODO: consider rx_highest */ + + if (mcs->rx_mask[3]) + return 4; + if (mcs->rx_mask[2]) + return 3; + if (mcs->rx_mask[1]) + return 2; + return 1; +} + +/** + * ieee80211_calculate_rx_timestamp - calculate timestamp in frame + * @local: mac80211 hw info struct + * @status: RX status + * @mpdu_len: total MPDU length (including FCS) + * @mpdu_offset: offset into MPDU to calculate timestamp at + * + * This function calculates the RX timestamp at the given MPDU offset, taking + * into account what the RX timestamp was. An offset of 0 will just normalize + * the timestamp to TSF at beginning of MPDU reception. + */ +u64 ieee80211_calculate_rx_timestamp(struct ieee80211_local *local, + struct ieee80211_rx_status *status, + unsigned int mpdu_len, + unsigned int mpdu_offset) +{ + u64 ts = status->mactime; + struct rate_info ri; + u16 rate; + u8 n_ltf; + + if (WARN_ON(!ieee80211_have_rx_timestamp(status))) + return 0; + + memset(&ri, 0, sizeof(ri)); + + ri.bw = status->bw; + + /* Fill cfg80211 rate info */ + switch (status->encoding) { + case RX_ENC_HE: + ri.flags |= RATE_INFO_FLAGS_HE_MCS; + ri.mcs = status->rate_idx; + ri.nss = status->nss; + ri.he_ru_alloc = status->he_ru; + if (status->enc_flags & RX_ENC_FLAG_SHORT_GI) + ri.flags |= RATE_INFO_FLAGS_SHORT_GI; + + /* + * See P802.11ax_D6.0, section 27.3.4 for + * VHT PPDU format. + */ + if (status->flag & RX_FLAG_MACTIME_PLCP_START) { + mpdu_offset += 2; + ts += 36; + + /* + * TODO: + * For HE MU PPDU, add the HE-SIG-B. + * For HE ER PPDU, add 8us for the HE-SIG-A. + * For HE TB PPDU, add 4us for the HE-STF. + * Add the HE-LTF durations - variable. + */ + } + + break; + case RX_ENC_HT: + ri.mcs = status->rate_idx; + ri.flags |= RATE_INFO_FLAGS_MCS; + if (status->enc_flags & RX_ENC_FLAG_SHORT_GI) + ri.flags |= RATE_INFO_FLAGS_SHORT_GI; + + /* + * See P802.11REVmd_D3.0, section 19.3.2 for + * HT PPDU format. + */ + if (status->flag & RX_FLAG_MACTIME_PLCP_START) { + mpdu_offset += 2; + if (status->enc_flags & RX_ENC_FLAG_HT_GF) + ts += 24; + else + ts += 32; + + /* + * Add Data HT-LTFs per streams + * TODO: add Extension HT-LTFs, 4us per LTF + */ + n_ltf = ((ri.mcs >> 3) & 3) + 1; + n_ltf = n_ltf == 3 ? 4 : n_ltf; + ts += n_ltf * 4; + } + + break; + case RX_ENC_VHT: + ri.flags |= RATE_INFO_FLAGS_VHT_MCS; + ri.mcs = status->rate_idx; + ri.nss = status->nss; + if (status->enc_flags & RX_ENC_FLAG_SHORT_GI) + ri.flags |= RATE_INFO_FLAGS_SHORT_GI; + + /* + * See P802.11REVmd_D3.0, section 21.3.2 for + * VHT PPDU format. 
+ */ + if (status->flag & RX_FLAG_MACTIME_PLCP_START) { + mpdu_offset += 2; + ts += 36; + + /* + * Add VHT-LTFs per streams + */ + n_ltf = (ri.nss != 1) && (ri.nss % 2) ? + ri.nss + 1 : ri.nss; + ts += 4 * n_ltf; + } + + break; + default: + WARN_ON(1); + fallthrough; + case RX_ENC_LEGACY: { + struct ieee80211_supported_band *sband; + int shift = 0; + int bitrate; + + switch (status->bw) { + case RATE_INFO_BW_10: + shift = 1; + break; + case RATE_INFO_BW_5: + shift = 2; + break; + } + + sband = local->hw.wiphy->bands[status->band]; + bitrate = sband->bitrates[status->rate_idx].bitrate; + ri.legacy = DIV_ROUND_UP(bitrate, (1 << shift)); + + if (status->flag & RX_FLAG_MACTIME_PLCP_START) { + if (status->band == NL80211_BAND_5GHZ) { + ts += 20 << shift; + mpdu_offset += 2; + } else if (status->enc_flags & RX_ENC_FLAG_SHORTPRE) { + ts += 96; + } else { + ts += 192; + } + } + break; + } + } + + rate = cfg80211_calculate_bitrate(&ri); + if (WARN_ONCE(!rate, + "Invalid bitrate: flags=0x%llx, idx=%d, vht_nss=%d\n", + (unsigned long long)status->flag, status->rate_idx, + status->nss)) + return 0; + + /* rewind from end of MPDU */ + if (status->flag & RX_FLAG_MACTIME_END) + ts -= mpdu_len * 8 * 10 / rate; + + ts += mpdu_offset * 8 * 10 / rate; + + return ts; +} + +void ieee80211_dfs_cac_cancel(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + struct cfg80211_chan_def chandef; + + /* for interface list, to avoid linking iflist_mtx and chanctx_mtx */ + lockdep_assert_wiphy(local->hw.wiphy); + + mutex_lock(&local->mtx); + list_for_each_entry(sdata, &local->interfaces, list) { + /* it might be waiting for the local->mtx, but then + * by the time it gets it, sdata->wdev.cac_started + * will no longer be true + */ + cancel_delayed_work(&sdata->deflink.dfs_cac_timer_work); + + if (sdata->wdev.cac_started) { + chandef = sdata->vif.bss_conf.chandef; + ieee80211_link_release_channel(&sdata->deflink); + cfg80211_cac_event(sdata->dev, + &chandef, + NL80211_RADAR_CAC_ABORTED, + GFP_KERNEL); + } + } + mutex_unlock(&local->mtx); +} + +void ieee80211_dfs_radar_detected_work(struct wiphy *wiphy, + struct wiphy_work *work) +{ + struct ieee80211_local *local = + container_of(work, struct ieee80211_local, radar_detected_work); + struct cfg80211_chan_def chandef = local->hw.conf.chandef; + struct ieee80211_chanctx *ctx; + int num_chanctx = 0; + + mutex_lock(&local->chanctx_mtx); + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state == IEEE80211_CHANCTX_REPLACES_OTHER) + continue; + + num_chanctx++; + chandef = ctx->conf.def; + } + mutex_unlock(&local->chanctx_mtx); + + ieee80211_dfs_cac_cancel(local); + + if (num_chanctx > 1) + /* XXX: multi-channel is not supported yet */ + WARN_ON(1); + else + cfg80211_radar_event(local->hw.wiphy, &chandef, GFP_KERNEL); +} + +void ieee80211_radar_detected(struct ieee80211_hw *hw) +{ + struct ieee80211_local *local = hw_to_local(hw); + + trace_api_radar_detected(local); + + wiphy_work_queue(hw->wiphy, &local->radar_detected_work); +} +EXPORT_SYMBOL(ieee80211_radar_detected); + +ieee80211_conn_flags_t ieee80211_chandef_downgrade(struct cfg80211_chan_def *c) +{ + ieee80211_conn_flags_t ret; + int tmp; + + switch (c->width) { + case NL80211_CHAN_WIDTH_20: + c->width = NL80211_CHAN_WIDTH_20_NOHT; + ret = IEEE80211_CONN_DISABLE_HT | IEEE80211_CONN_DISABLE_VHT; + break; + case NL80211_CHAN_WIDTH_40: + c->width = NL80211_CHAN_WIDTH_20; + c->center_freq1 = c->chan->center_freq; + ret = IEEE80211_CONN_DISABLE_40MHZ | + 
IEEE80211_CONN_DISABLE_VHT; + break; + case NL80211_CHAN_WIDTH_80: + tmp = (30 + c->chan->center_freq - c->center_freq1)/20; + /* n_P40 */ + tmp /= 2; + /* freq_P40 */ + c->center_freq1 = c->center_freq1 - 20 + 40 * tmp; + c->width = NL80211_CHAN_WIDTH_40; + ret = IEEE80211_CONN_DISABLE_VHT; + break; + case NL80211_CHAN_WIDTH_80P80: + c->center_freq2 = 0; + c->width = NL80211_CHAN_WIDTH_80; + ret = IEEE80211_CONN_DISABLE_80P80MHZ | + IEEE80211_CONN_DISABLE_160MHZ; + break; + case NL80211_CHAN_WIDTH_160: + /* n_P20 */ + tmp = (70 + c->chan->center_freq - c->center_freq1)/20; + /* n_P80 */ + tmp /= 4; + c->center_freq1 = c->center_freq1 - 40 + 80 * tmp; + c->width = NL80211_CHAN_WIDTH_80; + ret = IEEE80211_CONN_DISABLE_80P80MHZ | + IEEE80211_CONN_DISABLE_160MHZ; + break; + case NL80211_CHAN_WIDTH_320: + /* n_P20 */ + tmp = (150 + c->chan->center_freq - c->center_freq1) / 20; + /* n_P160 */ + tmp /= 8; + c->center_freq1 = c->center_freq1 - 80 + 160 * tmp; + c->width = NL80211_CHAN_WIDTH_160; + ret = IEEE80211_CONN_DISABLE_320MHZ; + break; + default: + case NL80211_CHAN_WIDTH_20_NOHT: + WARN_ON_ONCE(1); + c->width = NL80211_CHAN_WIDTH_20_NOHT; + ret = IEEE80211_CONN_DISABLE_HT | IEEE80211_CONN_DISABLE_VHT; + break; + case NL80211_CHAN_WIDTH_1: + case NL80211_CHAN_WIDTH_2: + case NL80211_CHAN_WIDTH_4: + case NL80211_CHAN_WIDTH_8: + case NL80211_CHAN_WIDTH_16: + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + WARN_ON_ONCE(1); + /* keep c->width */ + ret = IEEE80211_CONN_DISABLE_HT | IEEE80211_CONN_DISABLE_VHT; + break; + } + + WARN_ON_ONCE(!cfg80211_chandef_valid(c)); + + return ret; +} + +/* + * Returns true if smps_mode_new is strictly more restrictive than + * smps_mode_old. + */ +bool ieee80211_smps_is_restrictive(enum ieee80211_smps_mode smps_mode_old, + enum ieee80211_smps_mode smps_mode_new) +{ + if (WARN_ON_ONCE(smps_mode_old == IEEE80211_SMPS_AUTOMATIC || + smps_mode_new == IEEE80211_SMPS_AUTOMATIC)) + return false; + + switch (smps_mode_old) { + case IEEE80211_SMPS_STATIC: + return false; + case IEEE80211_SMPS_DYNAMIC: + return smps_mode_new == IEEE80211_SMPS_STATIC; + case IEEE80211_SMPS_OFF: + return smps_mode_new != IEEE80211_SMPS_OFF; + default: + WARN_ON(1); + } + + return false; +} + +int ieee80211_send_action_csa(struct ieee80211_sub_if_data *sdata, + struct cfg80211_csa_settings *csa_settings) +{ + struct sk_buff *skb; + struct ieee80211_mgmt *mgmt; + struct ieee80211_local *local = sdata->local; + int freq; + int hdr_len = offsetofend(struct ieee80211_mgmt, + u.action.u.chan_switch); + u8 *pos; + + if (sdata->vif.type != NL80211_IFTYPE_ADHOC && + sdata->vif.type != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + skb = dev_alloc_skb(local->tx_headroom + hdr_len + + 5 + /* channel switch announcement element */ + 3 + /* secondary channel offset element */ + 5 + /* wide bandwidth channel switch announcement */ + 8); /* mesh channel switch parameters element */ + if (!skb) + return -ENOMEM; + + skb_reserve(skb, local->tx_headroom); + mgmt = skb_put_zero(skb, hdr_len); + mgmt->frame_control = cpu_to_le16(IEEE80211_FTYPE_MGMT | + IEEE80211_STYPE_ACTION); + + eth_broadcast_addr(mgmt->da); + memcpy(mgmt->sa, sdata->vif.addr, ETH_ALEN); + if (ieee80211_vif_is_mesh(&sdata->vif)) { + memcpy(mgmt->bssid, sdata->vif.addr, ETH_ALEN); + } else { + struct ieee80211_if_ibss *ifibss = &sdata->u.ibss; + memcpy(mgmt->bssid, ifibss->bssid, ETH_ALEN); + } + mgmt->u.action.category = WLAN_CATEGORY_SPECTRUM_MGMT; + mgmt->u.action.u.chan_switch.action_code = 
WLAN_ACTION_SPCT_CHL_SWITCH; + pos = skb_put(skb, 5); + *pos++ = WLAN_EID_CHANNEL_SWITCH; /* EID */ + *pos++ = 3; /* IE length */ + *pos++ = csa_settings->block_tx ? 1 : 0; /* CSA mode */ + freq = csa_settings->chandef.chan->center_freq; + *pos++ = ieee80211_frequency_to_channel(freq); /* channel */ + *pos++ = csa_settings->count; /* count */ + + if (csa_settings->chandef.width == NL80211_CHAN_WIDTH_40) { + enum nl80211_channel_type ch_type; + + skb_put(skb, 3); + *pos++ = WLAN_EID_SECONDARY_CHANNEL_OFFSET; /* EID */ + *pos++ = 1; /* IE length */ + ch_type = cfg80211_get_chandef_type(&csa_settings->chandef); + if (ch_type == NL80211_CHAN_HT40PLUS) + *pos++ = IEEE80211_HT_PARAM_CHA_SEC_ABOVE; + else + *pos++ = IEEE80211_HT_PARAM_CHA_SEC_BELOW; + } + + if (ieee80211_vif_is_mesh(&sdata->vif)) { + struct ieee80211_if_mesh *ifmsh = &sdata->u.mesh; + + skb_put(skb, 8); + *pos++ = WLAN_EID_CHAN_SWITCH_PARAM; /* EID */ + *pos++ = 6; /* IE length */ + *pos++ = sdata->u.mesh.mshcfg.dot11MeshTTL; /* Mesh TTL */ + *pos = 0x00; /* Mesh Flag: Tx Restrict, Initiator, Reason */ + *pos |= WLAN_EID_CHAN_SWITCH_PARAM_INITIATOR; + *pos++ |= csa_settings->block_tx ? + WLAN_EID_CHAN_SWITCH_PARAM_TX_RESTRICT : 0x00; + put_unaligned_le16(WLAN_REASON_MESH_CHAN, pos); /* Reason Cd */ + pos += 2; + put_unaligned_le16(ifmsh->pre_value, pos);/* Precedence Value */ + pos += 2; + } + + if (csa_settings->chandef.width == NL80211_CHAN_WIDTH_80 || + csa_settings->chandef.width == NL80211_CHAN_WIDTH_80P80 || + csa_settings->chandef.width == NL80211_CHAN_WIDTH_160) { + skb_put(skb, 5); + ieee80211_ie_build_wide_bw_cs(pos, &csa_settings->chandef); + } + + ieee80211_tx_skb(sdata, skb); + return 0; +} + +static bool +ieee80211_extend_noa_desc(struct ieee80211_noa_data *data, u32 tsf, int i) +{ + s32 end = data->desc[i].start + data->desc[i].duration - (tsf + 1); + int skip; + + if (end > 0) + return false; + + /* One shot NOA */ + if (data->count[i] == 1) + return false; + + if (data->desc[i].interval == 0) + return false; + + /* End time is in the past, check for repetitions */ + skip = DIV_ROUND_UP(-end, data->desc[i].interval); + if (data->count[i] < 255) { + if (data->count[i] <= skip) { + data->count[i] = 0; + return false; + } + + data->count[i] -= skip; + } + + data->desc[i].start += skip * data->desc[i].interval; + + return true; +} + +static bool +ieee80211_extend_absent_time(struct ieee80211_noa_data *data, u32 tsf, + s32 *offset) +{ + bool ret = false; + int i; + + for (i = 0; i < IEEE80211_P2P_NOA_DESC_MAX; i++) { + s32 cur; + + if (!data->count[i]) + continue; + + if (ieee80211_extend_noa_desc(data, tsf + *offset, i)) + ret = true; + + cur = data->desc[i].start - tsf; + if (cur > *offset) + continue; + + cur = data->desc[i].start + data->desc[i].duration - tsf; + if (cur > *offset) + *offset = cur; + } + + return ret; +} + +static u32 +ieee80211_get_noa_absent_time(struct ieee80211_noa_data *data, u32 tsf) +{ + s32 offset = 0; + int tries = 0; + /* + * arbitrary limit, used to avoid infinite loops when combined NoA + * descriptors cover the full time period. 
+ */ + int max_tries = 5; + + ieee80211_extend_absent_time(data, tsf, &offset); + do { + if (!ieee80211_extend_absent_time(data, tsf, &offset)) + break; + + tries++; + } while (tries < max_tries); + + return offset; +} + +void ieee80211_update_p2p_noa(struct ieee80211_noa_data *data, u32 tsf) +{ + u32 next_offset = BIT(31) - 1; + int i; + + data->absent = 0; + data->has_next_tsf = false; + for (i = 0; i < IEEE80211_P2P_NOA_DESC_MAX; i++) { + s32 start; + + if (!data->count[i]) + continue; + + ieee80211_extend_noa_desc(data, tsf, i); + start = data->desc[i].start - tsf; + if (start <= 0) + data->absent |= BIT(i); + + if (next_offset > start) + next_offset = start; + + data->has_next_tsf = true; + } + + if (data->absent) + next_offset = ieee80211_get_noa_absent_time(data, tsf); + + data->next_tsf = tsf + next_offset; +} +EXPORT_SYMBOL(ieee80211_update_p2p_noa); + +int ieee80211_parse_p2p_noa(const struct ieee80211_p2p_noa_attr *attr, + struct ieee80211_noa_data *data, u32 tsf) +{ + int ret = 0; + int i; + + memset(data, 0, sizeof(*data)); + + for (i = 0; i < IEEE80211_P2P_NOA_DESC_MAX; i++) { + const struct ieee80211_p2p_noa_desc *desc = &attr->desc[i]; + + if (!desc->count || !desc->duration) + continue; + + data->count[i] = desc->count; + data->desc[i].start = le32_to_cpu(desc->start_time); + data->desc[i].duration = le32_to_cpu(desc->duration); + data->desc[i].interval = le32_to_cpu(desc->interval); + + if (data->count[i] > 1 && + data->desc[i].interval < data->desc[i].duration) + continue; + + ieee80211_extend_noa_desc(data, tsf, i); + ret++; + } + + if (ret) + ieee80211_update_p2p_noa(data, tsf); + + return ret; +} +EXPORT_SYMBOL(ieee80211_parse_p2p_noa); + +void ieee80211_recalc_dtim(struct ieee80211_local *local, + struct ieee80211_sub_if_data *sdata) +{ + u64 tsf = drv_get_tsf(local, sdata); + u64 dtim_count = 0; + u16 beacon_int = sdata->vif.bss_conf.beacon_int * 1024; + u8 dtim_period = sdata->vif.bss_conf.dtim_period; + struct ps_data *ps; + u8 bcns_from_dtim; + + if (tsf == -1ULL || !beacon_int || !dtim_period) + return; + + if (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_AP_VLAN) { + if (!sdata->bss) + return; + + ps = &sdata->bss->ps; + } else if (ieee80211_vif_is_mesh(&sdata->vif)) { + ps = &sdata->u.mesh.ps; + } else { + return; + } + + /* + * actually finds last dtim_count, mac80211 will update in + * __beacon_add_tim(). + * dtim_count = dtim_period - (tsf / bcn_int) % dtim_period + */ + do_div(tsf, beacon_int); + bcns_from_dtim = do_div(tsf, dtim_period); + /* just had a DTIM */ + if (!bcns_from_dtim) + dtim_count = 0; + else + dtim_count = dtim_period - bcns_from_dtim; + + ps->dtim_count = dtim_count; +} + +static u8 ieee80211_chanctx_radar_detect(struct ieee80211_local *local, + struct ieee80211_chanctx *ctx) +{ + struct ieee80211_link_data *link; + u8 radar_detect = 0; + + lockdep_assert_held(&local->chanctx_mtx); + + if (WARN_ON(ctx->replace_state == IEEE80211_CHANCTX_WILL_BE_REPLACED)) + return 0; + + list_for_each_entry(link, &ctx->reserved_links, reserved_chanctx_list) + if (link->reserved_radar_required) + radar_detect |= BIT(link->reserved_chandef.width); + + /* + * An in-place reservation context should not have any assigned vifs + * until it replaces the other context. 
+ */ + WARN_ON(ctx->replace_state == IEEE80211_CHANCTX_REPLACES_OTHER && + !list_empty(&ctx->assigned_links)); + + list_for_each_entry(link, &ctx->assigned_links, assigned_chanctx_list) { + if (!link->radar_required) + continue; + + radar_detect |= + BIT(link->conf->chandef.width); + } + + return radar_detect; +} + +int ieee80211_check_combinations(struct ieee80211_sub_if_data *sdata, + const struct cfg80211_chan_def *chandef, + enum ieee80211_chanctx_mode chanmode, + u8 radar_detect) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_sub_if_data *sdata_iter; + enum nl80211_iftype iftype = sdata->wdev.iftype; + struct ieee80211_chanctx *ctx; + int total = 1; + struct iface_combination_params params = { + .radar_detect = radar_detect, + }; + + lockdep_assert_held(&local->chanctx_mtx); + + if (WARN_ON(hweight32(radar_detect) > 1)) + return -EINVAL; + + if (WARN_ON(chandef && chanmode == IEEE80211_CHANCTX_SHARED && + !chandef->chan)) + return -EINVAL; + + if (WARN_ON(iftype >= NUM_NL80211_IFTYPES)) + return -EINVAL; + + if (sdata->vif.type == NL80211_IFTYPE_AP || + sdata->vif.type == NL80211_IFTYPE_MESH_POINT) { + /* + * always passing this is harmless, since it'll be the + * same value that cfg80211 finds if it finds the same + * interface ... and that's always allowed + */ + params.new_beacon_int = sdata->vif.bss_conf.beacon_int; + } + + /* Always allow software iftypes */ + if (cfg80211_iftype_allowed(local->hw.wiphy, iftype, 0, 1)) { + if (radar_detect) + return -EINVAL; + return 0; + } + + if (chandef) + params.num_different_channels = 1; + + if (iftype != NL80211_IFTYPE_UNSPECIFIED) + params.iftype_num[iftype] = 1; + + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state == IEEE80211_CHANCTX_WILL_BE_REPLACED) + continue; + params.radar_detect |= + ieee80211_chanctx_radar_detect(local, ctx); + if (ctx->mode == IEEE80211_CHANCTX_EXCLUSIVE) { + params.num_different_channels++; + continue; + } + if (chandef && chanmode == IEEE80211_CHANCTX_SHARED && + cfg80211_chandef_compatible(chandef, + &ctx->conf.def)) + continue; + params.num_different_channels++; + } + + list_for_each_entry_rcu(sdata_iter, &local->interfaces, list) { + struct wireless_dev *wdev_iter; + + wdev_iter = &sdata_iter->wdev; + + if (sdata_iter == sdata || + !ieee80211_sdata_running(sdata_iter) || + cfg80211_iftype_allowed(local->hw.wiphy, + wdev_iter->iftype, 0, 1)) + continue; + + params.iftype_num[wdev_iter->iftype]++; + total++; + } + + if (total == 1 && !params.radar_detect) + return 0; + + return cfg80211_check_combinations(local->hw.wiphy, ¶ms); +} + +static void +ieee80211_iter_max_chans(const struct ieee80211_iface_combination *c, + void *data) +{ + u32 *max_num_different_channels = data; + + *max_num_different_channels = max(*max_num_different_channels, + c->num_different_channels); +} + +int ieee80211_max_num_channels(struct ieee80211_local *local) +{ + struct ieee80211_sub_if_data *sdata; + struct ieee80211_chanctx *ctx; + u32 max_num_different_channels = 1; + int err; + struct iface_combination_params params = {0}; + + lockdep_assert_held(&local->chanctx_mtx); + + list_for_each_entry(ctx, &local->chanctx_list, list) { + if (ctx->replace_state == IEEE80211_CHANCTX_WILL_BE_REPLACED) + continue; + + params.num_different_channels++; + + params.radar_detect |= + ieee80211_chanctx_radar_detect(local, ctx); + } + + list_for_each_entry_rcu(sdata, &local->interfaces, list) + params.iftype_num[sdata->wdev.iftype]++; + + err = cfg80211_iter_combinations(local->hw.wiphy, ¶ms, + 
ieee80211_iter_max_chans, + &max_num_different_channels); + if (err < 0) + return err; + + return max_num_different_channels; +} + +void ieee80211_add_s1g_capab_ie(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_s1g_cap *caps, + struct sk_buff *skb) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + struct ieee80211_s1g_cap s1g_capab; + u8 *pos; + int i; + + if (WARN_ON(sdata->vif.type != NL80211_IFTYPE_STATION)) + return; + + if (!caps->s1g) + return; + + memcpy(s1g_capab.capab_info, caps->cap, sizeof(caps->cap)); + memcpy(s1g_capab.supp_mcs_nss, caps->nss_mcs, sizeof(caps->nss_mcs)); + + /* override the capability info */ + for (i = 0; i < sizeof(ifmgd->s1g_capa.capab_info); i++) { + u8 mask = ifmgd->s1g_capa_mask.capab_info[i]; + + s1g_capab.capab_info[i] &= ~mask; + s1g_capab.capab_info[i] |= ifmgd->s1g_capa.capab_info[i] & mask; + } + + /* then MCS and NSS set */ + for (i = 0; i < sizeof(ifmgd->s1g_capa.supp_mcs_nss); i++) { + u8 mask = ifmgd->s1g_capa_mask.supp_mcs_nss[i]; + + s1g_capab.supp_mcs_nss[i] &= ~mask; + s1g_capab.supp_mcs_nss[i] |= + ifmgd->s1g_capa.supp_mcs_nss[i] & mask; + } + + pos = skb_put(skb, 2 + sizeof(s1g_capab)); + *pos++ = WLAN_EID_S1G_CAPABILITIES; + *pos++ = sizeof(s1g_capab); + + memcpy(pos, &s1g_capab, sizeof(s1g_capab)); +} + +void ieee80211_add_aid_request_ie(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + u8 *pos = skb_put(skb, 3); + + *pos++ = WLAN_EID_AID_REQUEST; + *pos++ = 1; + *pos++ = 0; +} + +u8 *ieee80211_add_wmm_info_ie(u8 *buf, u8 qosinfo) +{ + *buf++ = WLAN_EID_VENDOR_SPECIFIC; + *buf++ = 7; /* len */ + *buf++ = 0x00; /* Microsoft OUI 00:50:F2 */ + *buf++ = 0x50; + *buf++ = 0xf2; + *buf++ = 2; /* WME */ + *buf++ = 0; /* WME info */ + *buf++ = 1; /* WME ver */ + *buf++ = qosinfo; /* U-APSD no in use */ + + return buf; +} + +void ieee80211_txq_get_depth(struct ieee80211_txq *txq, + unsigned long *frame_cnt, + unsigned long *byte_cnt) +{ + struct txq_info *txqi = to_txq_info(txq); + u32 frag_cnt = 0, frag_bytes = 0; + struct sk_buff *skb; + + skb_queue_walk(&txqi->frags, skb) { + frag_cnt++; + frag_bytes += skb->len; + } + + if (frame_cnt) + *frame_cnt = txqi->tin.backlog_packets + frag_cnt; + + if (byte_cnt) + *byte_cnt = txqi->tin.backlog_bytes + frag_bytes; +} +EXPORT_SYMBOL(ieee80211_txq_get_depth); + +const u8 ieee80211_ac_to_qos_mask[IEEE80211_NUM_ACS] = { + IEEE80211_WMM_IE_STA_QOSINFO_AC_VO, + IEEE80211_WMM_IE_STA_QOSINFO_AC_VI, + IEEE80211_WMM_IE_STA_QOSINFO_AC_BE, + IEEE80211_WMM_IE_STA_QOSINFO_AC_BK +}; + +u16 ieee80211_encode_usf(int listen_interval) +{ + static const int listen_int_usf[] = { 1, 10, 1000, 10000 }; + u16 ui, usf = 0; + + /* find greatest USF */ + while (usf < IEEE80211_MAX_USF) { + if (listen_interval % listen_int_usf[usf + 1]) + break; + usf += 1; + } + ui = listen_interval / listen_int_usf[usf]; + + /* error if there is a remainder. 
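A standalone sketch of the listen-interval scaling done by ieee80211_encode_usf(): choose the largest Unified Scaling Factor that still divides the interval evenly, then pack the USF and the Unscaled Interval into one 16-bit field. The bit layout assumed here (USF in the top two bits, UI in the low 14) mirrors the LISTEN_INT_USF/LISTEN_INT_UI masks; the value in main() is hypothetical.

#include <stdint.h>
#include <stdio.h>

/*
 * Pick the largest scale (x1, x10, x1000, x10000) that divides the listen
 * interval evenly, so the 14-bit Unscaled Interval field can carry it.
 */
static uint16_t encode_usf(unsigned int listen_interval)
{
	static const unsigned int scale[] = { 1, 10, 1000, 10000 };
	unsigned int usf = 0, ui;

	while (usf < 3 && !(listen_interval % scale[usf + 1]))
		usf++;

	ui = listen_interval / scale[usf];	/* caller keeps this <= 0x3fff */

	return (uint16_t)((usf << 14) | (ui & 0x3fff));
}

int main(void)
{
	/* hypothetical: 3000 beacon intervals -> USF 2 (x1000), UI 3 */
	printf("0x%04x\n", (unsigned)encode_usf(3000));	/* prints 0x8003 */
	return 0;
}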
Should've been checked by user */ + WARN_ON_ONCE(ui > IEEE80211_MAX_UI); + listen_interval = FIELD_PREP(LISTEN_INT_USF, usf) | + FIELD_PREP(LISTEN_INT_UI, ui); + + return (u16) listen_interval; +} + +u8 ieee80211_ie_len_eht_cap(struct ieee80211_sub_if_data *sdata, u8 iftype) +{ + const struct ieee80211_sta_he_cap *he_cap; + const struct ieee80211_sta_eht_cap *eht_cap; + struct ieee80211_supported_band *sband; + bool is_ap; + u8 n; + + sband = ieee80211_get_sband(sdata); + if (!sband) + return 0; + + he_cap = ieee80211_get_he_iftype_cap(sband, iftype); + eht_cap = ieee80211_get_eht_iftype_cap(sband, iftype); + if (!he_cap || !eht_cap) + return 0; + + is_ap = iftype == NL80211_IFTYPE_AP || + iftype == NL80211_IFTYPE_P2P_GO; + + n = ieee80211_eht_mcs_nss_size(&he_cap->he_cap_elem, + &eht_cap->eht_cap_elem, + is_ap); + return 2 + 1 + + sizeof(eht_cap->eht_cap_elem) + n + + ieee80211_eht_ppe_size(eht_cap->eht_ppe_thres[0], + eht_cap->eht_cap_elem.phy_cap_info); + return 0; +} + +u8 *ieee80211_ie_build_eht_cap(u8 *pos, + const struct ieee80211_sta_he_cap *he_cap, + const struct ieee80211_sta_eht_cap *eht_cap, + u8 *end, + bool for_ap) +{ + u8 mcs_nss_len, ppet_len; + u8 ie_len; + u8 *orig_pos = pos; + + /* Make sure we have place for the IE */ + if (!he_cap || !eht_cap) + return orig_pos; + + mcs_nss_len = ieee80211_eht_mcs_nss_size(&he_cap->he_cap_elem, + &eht_cap->eht_cap_elem, + for_ap); + ppet_len = ieee80211_eht_ppe_size(eht_cap->eht_ppe_thres[0], + eht_cap->eht_cap_elem.phy_cap_info); + + ie_len = 2 + 1 + sizeof(eht_cap->eht_cap_elem) + mcs_nss_len + ppet_len; + if ((end - pos) < ie_len) + return orig_pos; + + *pos++ = WLAN_EID_EXTENSION; + *pos++ = ie_len - 2; + *pos++ = WLAN_EID_EXT_EHT_CAPABILITY; + + /* Fixed data */ + memcpy(pos, &eht_cap->eht_cap_elem, sizeof(eht_cap->eht_cap_elem)); + pos += sizeof(eht_cap->eht_cap_elem); + + memcpy(pos, &eht_cap->eht_mcs_nss_supp, mcs_nss_len); + pos += mcs_nss_len; + + if (ppet_len) { + memcpy(pos, &eht_cap->eht_ppe_thres, ppet_len); + pos += ppet_len; + } + + return pos; +} + +void ieee80211_fragment_element(struct sk_buff *skb, u8 *len_pos) +{ + unsigned int elem_len; + + if (!len_pos) + return; + + elem_len = skb->data + skb->len - len_pos - 1; + + while (elem_len > 255) { + /* this one is 255 */ + *len_pos = 255; + /* remaining data gets smaller */ + elem_len -= 255; + /* make space for the fragment ID/len in SKB */ + skb_put(skb, 2); + /* shift back the remaining data to place fragment ID/len */ + memmove(len_pos + 255 + 3, len_pos + 255 + 1, elem_len); + /* place the fragment ID */ + len_pos += 255 + 1; + *len_pos = WLAN_EID_FRAGMENT; + /* and point to fragment length to update later */ + len_pos++; + } + + *len_pos = elem_len; +} diff --git a/net/mac80211/vht.c b/net/mac80211/vht.c new file mode 100644 index 000000000..f7526be8a --- /dev/null +++ b/net/mac80211/vht.c @@ -0,0 +1,778 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VHT handling + * + * Portions of this file + * Copyright(c) 2015 - 2016 Intel Deutschland GmbH + * Copyright (C) 2018 - 2023 Intel Corporation + */ + +#include +#include +#include +#include "ieee80211_i.h" +#include "rate.h" + + +static void __check_vhtcap_disable(struct ieee80211_sub_if_data *sdata, + struct ieee80211_sta_vht_cap *vht_cap, + u32 flag) +{ + __le32 le_flag = cpu_to_le32(flag); + + if (sdata->u.mgd.vht_capa_mask.vht_cap_info & le_flag && + !(sdata->u.mgd.vht_capa.vht_cap_info & le_flag)) + vht_cap->cap &= ~flag; +} + +void ieee80211_apply_vhtcap_overrides(struct ieee80211_sub_if_data *sdata, + 
struct ieee80211_sta_vht_cap *vht_cap) +{ + int i; + u16 rxmcs_mask, rxmcs_cap, rxmcs_n, txmcs_mask, txmcs_cap, txmcs_n; + + if (!vht_cap->vht_supported) + return; + + if (sdata->vif.type != NL80211_IFTYPE_STATION) + return; + + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_RXLDPC); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_SHORT_GI_80); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_SHORT_GI_160); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_TXSTBC); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_SU_BEAMFORMER_CAPABLE); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_SU_BEAMFORMEE_CAPABLE); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_RX_ANTENNA_PATTERN); + __check_vhtcap_disable(sdata, vht_cap, + IEEE80211_VHT_CAP_TX_ANTENNA_PATTERN); + + /* Allow user to decrease AMPDU length exponent */ + if (sdata->u.mgd.vht_capa_mask.vht_cap_info & + cpu_to_le32(IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK)) { + u32 cap, n; + + n = le32_to_cpu(sdata->u.mgd.vht_capa.vht_cap_info) & + IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK; + n >>= IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_SHIFT; + cap = vht_cap->cap & IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK; + cap >>= IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_SHIFT; + + if (n < cap) { + vht_cap->cap &= + ~IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK; + vht_cap->cap |= + n << IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_SHIFT; + } + } + + /* Allow the user to decrease MCSes */ + rxmcs_mask = + le16_to_cpu(sdata->u.mgd.vht_capa_mask.supp_mcs.rx_mcs_map); + rxmcs_n = le16_to_cpu(sdata->u.mgd.vht_capa.supp_mcs.rx_mcs_map); + rxmcs_n &= rxmcs_mask; + rxmcs_cap = le16_to_cpu(vht_cap->vht_mcs.rx_mcs_map); + + txmcs_mask = + le16_to_cpu(sdata->u.mgd.vht_capa_mask.supp_mcs.tx_mcs_map); + txmcs_n = le16_to_cpu(sdata->u.mgd.vht_capa.supp_mcs.tx_mcs_map); + txmcs_n &= txmcs_mask; + txmcs_cap = le16_to_cpu(vht_cap->vht_mcs.tx_mcs_map); + for (i = 0; i < 8; i++) { + u8 m, n, c; + + m = (rxmcs_mask >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + n = (rxmcs_n >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + c = (rxmcs_cap >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + if (m && ((c != IEEE80211_VHT_MCS_NOT_SUPPORTED && n < c) || + n == IEEE80211_VHT_MCS_NOT_SUPPORTED)) { + rxmcs_cap &= ~(3 << 2*i); + rxmcs_cap |= (rxmcs_n & (3 << 2*i)); + } + + m = (txmcs_mask >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + n = (txmcs_n >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + c = (txmcs_cap >> 2*i) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + if (m && ((c != IEEE80211_VHT_MCS_NOT_SUPPORTED && n < c) || + n == IEEE80211_VHT_MCS_NOT_SUPPORTED)) { + txmcs_cap &= ~(3 << 2*i); + txmcs_cap |= (txmcs_n & (3 << 2*i)); + } + } + vht_cap->vht_mcs.rx_mcs_map = cpu_to_le16(rxmcs_cap); + vht_cap->vht_mcs.tx_mcs_map = cpu_to_le16(txmcs_cap); +} + +void +ieee80211_vht_cap_ie_to_sta_vht_cap(struct ieee80211_sub_if_data *sdata, + struct ieee80211_supported_band *sband, + const struct ieee80211_vht_cap *vht_cap_ie, + const struct ieee80211_vht_cap *vht_cap_ie2, + struct link_sta_info *link_sta) +{ + struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap; + struct ieee80211_sta_vht_cap own_cap; + u32 cap_info, i; + bool have_80mhz; + u32 mpdu_len; + + memset(vht_cap, 0, sizeof(*vht_cap)); + + if (!link_sta->pub->ht_cap.ht_supported) + return; + + if (!vht_cap_ie || !sband->vht_cap.vht_supported) + return; + + /* Allow VHT if at least one channel on the sband supports 80 MHz */ + 
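The per-stream masking in the loops above operates on the VHT MCS map encoding: 16 bits, two bits per spatial stream, where 0/1/2 mean "MCS 0-7/0-8/0-9" and 3 means "not supported". A small sketch that decodes such a map; the local VHT_MCS_NOT_SUPPORTED constant mirrors IEEE80211_VHT_MCS_NOT_SUPPORTED, and the 0xfffa example is a hypothetical 2x2 device.

#include <stdint.h>
#include <stdio.h>

#define VHT_MCS_NOT_SUPPORTED	3	/* local mirror of IEEE80211_VHT_MCS_NOT_SUPPORTED */

/* Print the per-stream limits packed into a VHT MCS map (2 bits per NSS). */
static void dump_vht_mcs_map(uint16_t mcs_map)
{
	int nss;

	for (nss = 1; nss <= 8; nss++) {
		unsigned int val = (mcs_map >> (2 * (nss - 1))) & 3;

		if (val == VHT_MCS_NOT_SUPPORTED)
			printf("%d stream(s): not supported\n", nss);
		else
			printf("%d stream(s): MCS 0-%u\n", nss, 7 + val);
	}
}

int main(void)
{
	/* hypothetical 2x2 device: MCS 0-9 on streams 1 and 2, rest unsupported */
	dump_vht_mcs_map(0xfffa);
	return 0;
}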
have_80mhz = false; + for (i = 0; i < sband->n_channels; i++) { + if (sband->channels[i].flags & (IEEE80211_CHAN_DISABLED | + IEEE80211_CHAN_NO_80MHZ)) + continue; + + have_80mhz = true; + break; + } + + if (!have_80mhz) + return; + + /* + * A VHT STA must support 40 MHz, but if we verify that here + * then we break a few things - some APs (e.g. Netgear R6300v2 + * and others based on the BCM4360 chipset) will unset this + * capability bit when operating in 20 MHz. + */ + + vht_cap->vht_supported = true; + + own_cap = sband->vht_cap; + /* + * If user has specified capability overrides, take care + * of that if the station we're setting up is the AP that + * we advertised a restricted capability set to. Override + * our own capabilities and then use those below. + */ + if (sdata->vif.type == NL80211_IFTYPE_STATION && + !test_sta_flag(link_sta->sta, WLAN_STA_TDLS_PEER)) + ieee80211_apply_vhtcap_overrides(sdata, &own_cap); + + /* take some capabilities as-is */ + cap_info = le32_to_cpu(vht_cap_ie->vht_cap_info); + vht_cap->cap = cap_info; + vht_cap->cap &= IEEE80211_VHT_CAP_RXLDPC | + IEEE80211_VHT_CAP_VHT_TXOP_PS | + IEEE80211_VHT_CAP_HTC_VHT | + IEEE80211_VHT_CAP_MAX_A_MPDU_LENGTH_EXPONENT_MASK | + IEEE80211_VHT_CAP_VHT_LINK_ADAPTATION_VHT_UNSOL_MFB | + IEEE80211_VHT_CAP_VHT_LINK_ADAPTATION_VHT_MRQ_MFB | + IEEE80211_VHT_CAP_RX_ANTENNA_PATTERN | + IEEE80211_VHT_CAP_TX_ANTENNA_PATTERN; + + vht_cap->cap |= min_t(u32, cap_info & IEEE80211_VHT_CAP_MAX_MPDU_MASK, + own_cap.cap & IEEE80211_VHT_CAP_MAX_MPDU_MASK); + + /* and some based on our own capabilities */ + switch (own_cap.cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK) { + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ: + vht_cap->cap |= cap_info & + IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ; + break; + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ: + vht_cap->cap |= cap_info & + IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + break; + default: + /* nothing */ + break; + } + + /* symmetric capabilities */ + vht_cap->cap |= cap_info & own_cap.cap & + (IEEE80211_VHT_CAP_SHORT_GI_80 | + IEEE80211_VHT_CAP_SHORT_GI_160); + + /* remaining ones */ + if (own_cap.cap & IEEE80211_VHT_CAP_SU_BEAMFORMEE_CAPABLE) + vht_cap->cap |= cap_info & + (IEEE80211_VHT_CAP_SU_BEAMFORMER_CAPABLE | + IEEE80211_VHT_CAP_SOUNDING_DIMENSIONS_MASK); + + if (own_cap.cap & IEEE80211_VHT_CAP_SU_BEAMFORMER_CAPABLE) + vht_cap->cap |= cap_info & + (IEEE80211_VHT_CAP_SU_BEAMFORMEE_CAPABLE | + IEEE80211_VHT_CAP_BEAMFORMEE_STS_MASK); + + if (own_cap.cap & IEEE80211_VHT_CAP_MU_BEAMFORMER_CAPABLE) + vht_cap->cap |= cap_info & + IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE; + + if (own_cap.cap & IEEE80211_VHT_CAP_MU_BEAMFORMEE_CAPABLE) + vht_cap->cap |= cap_info & + IEEE80211_VHT_CAP_MU_BEAMFORMER_CAPABLE; + + if (own_cap.cap & IEEE80211_VHT_CAP_TXSTBC) + vht_cap->cap |= cap_info & IEEE80211_VHT_CAP_RXSTBC_MASK; + + if (own_cap.cap & IEEE80211_VHT_CAP_RXSTBC_MASK) + vht_cap->cap |= cap_info & IEEE80211_VHT_CAP_TXSTBC; + + /* Copy peer MCS info, the driver might need them. 
*/ + memcpy(&vht_cap->vht_mcs, &vht_cap_ie->supp_mcs, + sizeof(struct ieee80211_vht_mcs_info)); + + /* copy EXT_NSS_BW Support value or remove the capability */ + if (ieee80211_hw_check(&sdata->local->hw, SUPPORTS_VHT_EXT_NSS_BW)) + vht_cap->cap |= (cap_info & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK); + else + vht_cap->vht_mcs.tx_highest &= + ~cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE); + + /* but also restrict MCSes */ + for (i = 0; i < 8; i++) { + u16 own_rx, own_tx, peer_rx, peer_tx; + + own_rx = le16_to_cpu(own_cap.vht_mcs.rx_mcs_map); + own_rx = (own_rx >> i * 2) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + own_tx = le16_to_cpu(own_cap.vht_mcs.tx_mcs_map); + own_tx = (own_tx >> i * 2) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + peer_rx = le16_to_cpu(vht_cap->vht_mcs.rx_mcs_map); + peer_rx = (peer_rx >> i * 2) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + peer_tx = le16_to_cpu(vht_cap->vht_mcs.tx_mcs_map); + peer_tx = (peer_tx >> i * 2) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + + if (peer_tx != IEEE80211_VHT_MCS_NOT_SUPPORTED) { + if (own_rx == IEEE80211_VHT_MCS_NOT_SUPPORTED) + peer_tx = IEEE80211_VHT_MCS_NOT_SUPPORTED; + else if (own_rx < peer_tx) + peer_tx = own_rx; + } + + if (peer_rx != IEEE80211_VHT_MCS_NOT_SUPPORTED) { + if (own_tx == IEEE80211_VHT_MCS_NOT_SUPPORTED) + peer_rx = IEEE80211_VHT_MCS_NOT_SUPPORTED; + else if (own_tx < peer_rx) + peer_rx = own_tx; + } + + vht_cap->vht_mcs.rx_mcs_map &= + ~cpu_to_le16(IEEE80211_VHT_MCS_NOT_SUPPORTED << i * 2); + vht_cap->vht_mcs.rx_mcs_map |= cpu_to_le16(peer_rx << i * 2); + + vht_cap->vht_mcs.tx_mcs_map &= + ~cpu_to_le16(IEEE80211_VHT_MCS_NOT_SUPPORTED << i * 2); + vht_cap->vht_mcs.tx_mcs_map |= cpu_to_le16(peer_tx << i * 2); + } + + /* + * This is a workaround for VHT-enabled STAs which break the spec + * and have the VHT-MCS Rx map filled in with value 3 for all eight + * spacial streams, an example is AR9462. + * + * As per spec, in section 22.1.1 Introduction to the VHT PHY + * A VHT STA shall support at least single spactial stream VHT-MCSs + * 0 to 7 (transmit and receive) in all supported channel widths. + */ + if (vht_cap->vht_mcs.rx_mcs_map == cpu_to_le16(0xFFFF)) { + vht_cap->vht_supported = false; + sdata_info(sdata, + "Ignoring VHT IE from %pM (link:%pM) due to invalid rx_mcs_map\n", + link_sta->sta->addr, link_sta->addr); + return; + } + + /* finally set up the bandwidth */ + switch (vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK) { + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ: + case IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ: + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160; + break; + default: + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_80; + + if (!(vht_cap->vht_mcs.tx_highest & + cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE))) + break; + + /* + * If this is non-zero, then it does support 160 MHz after all, + * in one form or the other. We don't distinguish here (or even + * above) between 160 and 80+80 yet. + */ + if (cap_info & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) + link_sta->cur_max_bandwidth = + IEEE80211_STA_RX_BW_160; + } + + link_sta->pub->bandwidth = ieee80211_sta_cur_vht_bw(link_sta); + + /* + * Work around the Cisco 9115 FW 17.3 bug by taking the min of + * both reported MPDU lengths. + */ + mpdu_len = vht_cap->cap & IEEE80211_VHT_CAP_MAX_MPDU_MASK; + if (vht_cap_ie2) + mpdu_len = min_t(u32, mpdu_len, + le32_get_bits(vht_cap_ie2->vht_cap_info, + IEEE80211_VHT_CAP_MAX_MPDU_MASK)); + + /* + * FIXME - should the amsdu len be per link? store per link + * and maintain a minimum? 
+ */ + switch (mpdu_len) { + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_11454: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_11454; + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_7991: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_7991; + break; + case IEEE80211_VHT_CAP_MAX_MPDU_LENGTH_3895: + default: + link_sta->pub->agg.max_amsdu_len = IEEE80211_MAX_MPDU_LEN_VHT_3895; + break; + } + + ieee80211_sta_recalc_aggregates(&link_sta->sta->sta); +} + +/* FIXME: move this to some better location - parses HE/EHT now */ +enum ieee80211_sta_rx_bandwidth +ieee80211_sta_cap_rx_bw(struct link_sta_info *link_sta) +{ + unsigned int link_id = link_sta->link_id; + struct ieee80211_sub_if_data *sdata = link_sta->sta->sdata; + struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap; + struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap; + struct ieee80211_sta_eht_cap *eht_cap = &link_sta->pub->eht_cap; + u32 cap_width; + + if (he_cap->has_he) { + struct ieee80211_bss_conf *link_conf; + enum ieee80211_sta_rx_bandwidth ret; + u8 info; + + rcu_read_lock(); + link_conf = rcu_dereference(sdata->vif.link_conf[link_id]); + + if (eht_cap->has_eht && + link_conf->chandef.chan->band == NL80211_BAND_6GHZ) { + info = eht_cap->eht_cap_elem.phy_cap_info[0]; + + if (info & IEEE80211_EHT_PHY_CAP0_320MHZ_IN_6GHZ) { + ret = IEEE80211_STA_RX_BW_320; + goto out; + } + } + + info = he_cap->he_cap_elem.phy_cap_info[0]; + + if (link_conf->chandef.chan->band == NL80211_BAND_2GHZ) { + if (info & IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_40MHZ_IN_2G) + ret = IEEE80211_STA_RX_BW_40; + else + ret = IEEE80211_STA_RX_BW_20; + goto out; + } + + if (info & IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G || + info & IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_80PLUS80_MHZ_IN_5G) + ret = IEEE80211_STA_RX_BW_160; + else if (info & IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_40MHZ_80MHZ_IN_5G) + ret = IEEE80211_STA_RX_BW_80; + else + ret = IEEE80211_STA_RX_BW_20; +out: + rcu_read_unlock(); + + return ret; + } + + if (!vht_cap->vht_supported) + return link_sta->pub->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ? + IEEE80211_STA_RX_BW_40 : + IEEE80211_STA_RX_BW_20; + + cap_width = vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + + if (cap_width == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ || + cap_width == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ) + return IEEE80211_STA_RX_BW_160; + + /* + * If this is non-zero, then it does support 160 MHz after all, + * in one form or the other. We don't distinguish here (or even + * above) between 160 and 80+80 yet. + */ + if (vht_cap->cap & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) + return IEEE80211_STA_RX_BW_160; + + return IEEE80211_STA_RX_BW_80; +} + +enum nl80211_chan_width +ieee80211_sta_cap_chan_bw(struct link_sta_info *link_sta) +{ + struct ieee80211_sta_vht_cap *vht_cap = &link_sta->pub->vht_cap; + u32 cap_width; + + if (!vht_cap->vht_supported) { + if (!link_sta->pub->ht_cap.ht_supported) + return NL80211_CHAN_WIDTH_20_NOHT; + + return link_sta->pub->ht_cap.cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40 ? 
+ NL80211_CHAN_WIDTH_40 : NL80211_CHAN_WIDTH_20; + } + + cap_width = vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + + if (cap_width == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ) + return NL80211_CHAN_WIDTH_160; + else if (cap_width == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ) + return NL80211_CHAN_WIDTH_80P80; + + return NL80211_CHAN_WIDTH_80; +} + +enum nl80211_chan_width +ieee80211_sta_rx_bw_to_chan_width(struct link_sta_info *link_sta) +{ + enum ieee80211_sta_rx_bandwidth cur_bw = + link_sta->pub->bandwidth; + struct ieee80211_sta_vht_cap *vht_cap = + &link_sta->pub->vht_cap; + u32 cap_width; + + switch (cur_bw) { + case IEEE80211_STA_RX_BW_20: + if (!link_sta->pub->ht_cap.ht_supported) + return NL80211_CHAN_WIDTH_20_NOHT; + else + return NL80211_CHAN_WIDTH_20; + case IEEE80211_STA_RX_BW_40: + return NL80211_CHAN_WIDTH_40; + case IEEE80211_STA_RX_BW_80: + return NL80211_CHAN_WIDTH_80; + case IEEE80211_STA_RX_BW_160: + cap_width = + vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + + if (cap_width == IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ) + return NL80211_CHAN_WIDTH_160; + + return NL80211_CHAN_WIDTH_80P80; + default: + return NL80211_CHAN_WIDTH_20; + } +} + +enum ieee80211_sta_rx_bandwidth +ieee80211_chan_width_to_rx_bw(enum nl80211_chan_width width) +{ + switch (width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + return IEEE80211_STA_RX_BW_20; + case NL80211_CHAN_WIDTH_40: + return IEEE80211_STA_RX_BW_40; + case NL80211_CHAN_WIDTH_80: + return IEEE80211_STA_RX_BW_80; + case NL80211_CHAN_WIDTH_160: + case NL80211_CHAN_WIDTH_80P80: + return IEEE80211_STA_RX_BW_160; + case NL80211_CHAN_WIDTH_320: + return IEEE80211_STA_RX_BW_320; + default: + WARN_ON_ONCE(1); + return IEEE80211_STA_RX_BW_20; + } +} + +/* FIXME: rename/move - this deals with everything not just VHT */ +enum ieee80211_sta_rx_bandwidth +ieee80211_sta_cur_vht_bw(struct link_sta_info *link_sta) +{ + struct sta_info *sta = link_sta->sta; + struct ieee80211_bss_conf *link_conf; + enum nl80211_chan_width bss_width; + enum ieee80211_sta_rx_bandwidth bw; + + rcu_read_lock(); + link_conf = rcu_dereference(sta->sdata->vif.link_conf[link_sta->link_id]); + if (WARN_ON(!link_conf)) + bss_width = NL80211_CHAN_WIDTH_20_NOHT; + else + bss_width = link_conf->chandef.width; + rcu_read_unlock(); + + bw = ieee80211_sta_cap_rx_bw(link_sta); + bw = min(bw, link_sta->cur_max_bandwidth); + + /* Don't consider AP's bandwidth for TDLS peers, section 11.23.1 of + * IEEE80211-2016 specification makes higher bandwidth operation + * possible on the TDLS link if the peers have wider bandwidth + * capability. + * + * However, in this case, and only if the TDLS peer is authorized, + * limit to the tdls_chandef so that the configuration here isn't + * wider than what's actually requested on the channel context. 
+ */ + if (test_sta_flag(sta, WLAN_STA_TDLS_PEER) && + test_sta_flag(sta, WLAN_STA_TDLS_WIDER_BW) && + test_sta_flag(sta, WLAN_STA_AUTHORIZED) && + sta->tdls_chandef.chan) + bw = min(bw, ieee80211_chan_width_to_rx_bw(sta->tdls_chandef.width)); + else + bw = min(bw, ieee80211_chan_width_to_rx_bw(bss_width)); + + return bw; +} + +void ieee80211_sta_set_rx_nss(struct link_sta_info *link_sta) +{ + u8 ht_rx_nss = 0, vht_rx_nss = 0, he_rx_nss = 0, eht_rx_nss = 0, rx_nss; + bool support_160; + + /* if we received a notification already don't overwrite it */ + if (link_sta->pub->rx_nss) + return; + + if (link_sta->pub->eht_cap.has_eht) { + int i; + const u8 *rx_nss_mcs = (void *)&link_sta->pub->eht_cap.eht_mcs_nss_supp; + + /* get the max nss for EHT over all possible bandwidths and mcs */ + for (i = 0; i < sizeof(struct ieee80211_eht_mcs_nss_supp); i++) + eht_rx_nss = max_t(u8, eht_rx_nss, + u8_get_bits(rx_nss_mcs[i], + IEEE80211_EHT_MCS_NSS_RX)); + } + + if (link_sta->pub->he_cap.has_he) { + int i; + u8 rx_mcs_80 = 0, rx_mcs_160 = 0; + const struct ieee80211_sta_he_cap *he_cap = &link_sta->pub->he_cap; + u16 mcs_160_map = + le16_to_cpu(he_cap->he_mcs_nss_supp.rx_mcs_160); + u16 mcs_80_map = le16_to_cpu(he_cap->he_mcs_nss_supp.rx_mcs_80); + + for (i = 7; i >= 0; i--) { + u8 mcs_160 = (mcs_160_map >> (2 * i)) & 3; + + if (mcs_160 != IEEE80211_HE_MCS_NOT_SUPPORTED) { + rx_mcs_160 = i + 1; + break; + } + } + for (i = 7; i >= 0; i--) { + u8 mcs_80 = (mcs_80_map >> (2 * i)) & 3; + + if (mcs_80 != IEEE80211_HE_MCS_NOT_SUPPORTED) { + rx_mcs_80 = i + 1; + break; + } + } + + support_160 = he_cap->he_cap_elem.phy_cap_info[0] & + IEEE80211_HE_PHY_CAP0_CHANNEL_WIDTH_SET_160MHZ_IN_5G; + + if (support_160) + he_rx_nss = min(rx_mcs_80, rx_mcs_160); + else + he_rx_nss = rx_mcs_80; + } + + if (link_sta->pub->ht_cap.ht_supported) { + if (link_sta->pub->ht_cap.mcs.rx_mask[0]) + ht_rx_nss++; + if (link_sta->pub->ht_cap.mcs.rx_mask[1]) + ht_rx_nss++; + if (link_sta->pub->ht_cap.mcs.rx_mask[2]) + ht_rx_nss++; + if (link_sta->pub->ht_cap.mcs.rx_mask[3]) + ht_rx_nss++; + /* FIXME: consider rx_highest? */ + } + + if (link_sta->pub->vht_cap.vht_supported) { + int i; + u16 rx_mcs_map; + + rx_mcs_map = le16_to_cpu(link_sta->pub->vht_cap.vht_mcs.rx_mcs_map); + + for (i = 7; i >= 0; i--) { + u8 mcs = (rx_mcs_map >> (2 * i)) & 3; + + if (mcs != IEEE80211_VHT_MCS_NOT_SUPPORTED) { + vht_rx_nss = i + 1; + break; + } + } + /* FIXME: consider rx_highest? 
*/ + } + + rx_nss = max(vht_rx_nss, ht_rx_nss); + rx_nss = max(he_rx_nss, rx_nss); + rx_nss = max(eht_rx_nss, rx_nss); + link_sta->pub->rx_nss = max_t(u8, 1, rx_nss); +} + +u32 __ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata, + struct link_sta_info *link_sta, + u8 opmode, enum nl80211_band band) +{ + enum ieee80211_sta_rx_bandwidth new_bw; + struct sta_opmode_info sta_opmode = {}; + u32 changed = 0; + u8 nss; + + /* ignore - no support for BF yet */ + if (opmode & IEEE80211_OPMODE_NOTIF_RX_NSS_TYPE_BF) + return 0; + + nss = opmode & IEEE80211_OPMODE_NOTIF_RX_NSS_MASK; + nss >>= IEEE80211_OPMODE_NOTIF_RX_NSS_SHIFT; + nss += 1; + + if (link_sta->pub->rx_nss != nss) { + link_sta->pub->rx_nss = nss; + sta_opmode.rx_nss = nss; + changed |= IEEE80211_RC_NSS_CHANGED; + sta_opmode.changed |= STA_OPMODE_N_SS_CHANGED; + } + + switch (opmode & IEEE80211_OPMODE_NOTIF_CHANWIDTH_MASK) { + case IEEE80211_OPMODE_NOTIF_CHANWIDTH_20MHZ: + /* ignore IEEE80211_OPMODE_NOTIF_BW_160_80P80 must not be set */ + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_20; + break; + case IEEE80211_OPMODE_NOTIF_CHANWIDTH_40MHZ: + /* ignore IEEE80211_OPMODE_NOTIF_BW_160_80P80 must not be set */ + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_40; + break; + case IEEE80211_OPMODE_NOTIF_CHANWIDTH_80MHZ: + if (opmode & IEEE80211_OPMODE_NOTIF_BW_160_80P80) + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160; + else + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_80; + break; + case IEEE80211_OPMODE_NOTIF_CHANWIDTH_160MHZ: + /* legacy only, no longer used by newer spec */ + link_sta->cur_max_bandwidth = IEEE80211_STA_RX_BW_160; + break; + } + + new_bw = ieee80211_sta_cur_vht_bw(link_sta); + if (new_bw != link_sta->pub->bandwidth) { + link_sta->pub->bandwidth = new_bw; + sta_opmode.bw = ieee80211_sta_rx_bw_to_chan_width(link_sta); + changed |= IEEE80211_RC_BW_CHANGED; + sta_opmode.changed |= STA_OPMODE_MAX_BW_CHANGED; + } + + if (sta_opmode.changed) + cfg80211_sta_opmode_change_notify(sdata->dev, link_sta->addr, + &sta_opmode, GFP_KERNEL); + + return changed; +} + +void ieee80211_process_mu_groups(struct ieee80211_sub_if_data *sdata, + struct ieee80211_link_data *link, + struct ieee80211_mgmt *mgmt) +{ + struct ieee80211_bss_conf *link_conf = link->conf; + + if (!link_conf->mu_mimo_owner) + return; + + if (!memcmp(mgmt->u.action.u.vht_group_notif.position, + link_conf->mu_group.position, WLAN_USER_POSITION_LEN) && + !memcmp(mgmt->u.action.u.vht_group_notif.membership, + link_conf->mu_group.membership, WLAN_MEMBERSHIP_LEN)) + return; + + memcpy(link_conf->mu_group.membership, + mgmt->u.action.u.vht_group_notif.membership, + WLAN_MEMBERSHIP_LEN); + memcpy(link_conf->mu_group.position, + mgmt->u.action.u.vht_group_notif.position, + WLAN_USER_POSITION_LEN); + + ieee80211_link_info_change_notify(sdata, link, + BSS_CHANGED_MU_GROUPS); +} + +void ieee80211_update_mu_groups(struct ieee80211_vif *vif, unsigned int link_id, + const u8 *membership, const u8 *position) +{ + struct ieee80211_bss_conf *link_conf; + + rcu_read_lock(); + link_conf = rcu_dereference(vif->link_conf[link_id]); + + if (!WARN_ON_ONCE(!link_conf || !link_conf->mu_mimo_owner)) { + memcpy(link_conf->mu_group.membership, membership, + WLAN_MEMBERSHIP_LEN); + memcpy(link_conf->mu_group.position, position, + WLAN_USER_POSITION_LEN); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(ieee80211_update_mu_groups); + +void ieee80211_vht_handle_opmode(struct ieee80211_sub_if_data *sdata, + struct link_sta_info *link_sta, + u8 opmode, enum nl80211_band 
band) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_supported_band *sband = local->hw.wiphy->bands[band]; + + u32 changed = __ieee80211_vht_handle_opmode(sdata, link_sta, + opmode, band); + + if (changed > 0) { + ieee80211_recalc_min_chandef(sdata, link_sta->link_id); + rate_control_rate_update(local, sband, link_sta->sta, + link_sta->link_id, changed); + } +} + +void ieee80211_get_vht_mask_from_cap(__le16 vht_cap, + u16 vht_mask[NL80211_VHT_NSS_MAX]) +{ + int i; + u16 mask, cap = le16_to_cpu(vht_cap); + + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) { + mask = (cap >> i * 2) & IEEE80211_VHT_MCS_NOT_SUPPORTED; + switch (mask) { + case IEEE80211_VHT_MCS_SUPPORT_0_7: + vht_mask[i] = 0x00FF; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_8: + vht_mask[i] = 0x01FF; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_9: + vht_mask[i] = 0x03FF; + break; + case IEEE80211_VHT_MCS_NOT_SUPPORTED: + default: + vht_mask[i] = 0; + break; + } + } +} diff --git a/net/mac80211/wep.c b/net/mac80211/wep.c new file mode 100644 index 000000000..9a6e11d7b --- /dev/null +++ b/net/mac80211/wep.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Software WEP encryption implementation + * Copyright 2002, Jouni Malinen + * Copyright 2003, Instant802 Networks, Inc. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "ieee80211_i.h" +#include "wep.h" + + +void ieee80211_wep_init(struct ieee80211_local *local) +{ + /* start WEP IV from a random value */ + get_random_bytes(&local->wep_iv, IEEE80211_WEP_IV_LEN); +} + +static inline bool ieee80211_wep_weak_iv(u32 iv, int keylen) +{ + /* + * Fluhrer, Mantin, and Shamir have reported weaknesses in the + * key scheduling algorithm of RC4. At least IVs (KeyByte + 3, + * 0xff, N) can be used to speedup attacks, so avoid using them. 
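A standalone restatement of the FMS weak-IV test described in the comment above, with two hypothetical IV values so the (B + 3, 0xff, N) pattern is easy to see; the helper name is local to this sketch and not part of the code being added.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/*
 * The 24-bit IV is kept as a counter with the first transmitted byte in
 * bits 23..16; the weak pattern has the middle byte 0xff and a first byte
 * B with 3 <= B < 3 + keylen.
 */
static bool wep_weak_iv(uint32_t iv, int keylen)
{
	if ((iv & 0xff00) == 0xff00) {
		uint8_t b = (iv >> 16) & 0xff;

		if (b >= 3 && b < 3 + keylen)
			return true;
	}
	return false;
}

int main(void)
{
	/* 104-bit WEP -> keylen 13: 0x04ff20 is weak, 0x04fe20 is not */
	printf("%d %d\n", wep_weak_iv(0x04ff20, 13), wep_weak_iv(0x04fe20, 13));
	return 0;
}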
+ */ + if ((iv & 0xff00) == 0xff00) { + u8 B = (iv >> 16) & 0xff; + if (B >= 3 && B < 3 + keylen) + return true; + } + return false; +} + + +static void ieee80211_wep_get_iv(struct ieee80211_local *local, + int keylen, int keyidx, u8 *iv) +{ + local->wep_iv++; + if (ieee80211_wep_weak_iv(local->wep_iv, keylen)) + local->wep_iv += 0x0100; + + if (!iv) + return; + + *iv++ = (local->wep_iv >> 16) & 0xff; + *iv++ = (local->wep_iv >> 8) & 0xff; + *iv++ = local->wep_iv & 0xff; + *iv++ = keyidx << 6; +} + + +static u8 *ieee80211_wep_add_iv(struct ieee80211_local *local, + struct sk_buff *skb, + int keylen, int keyidx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + unsigned int hdrlen; + u8 *newhdr; + + hdr->frame_control |= cpu_to_le16(IEEE80211_FCTL_PROTECTED); + + if (WARN_ON(skb_headroom(skb) < IEEE80211_WEP_IV_LEN)) + return NULL; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + newhdr = skb_push(skb, IEEE80211_WEP_IV_LEN); + memmove(newhdr, newhdr + IEEE80211_WEP_IV_LEN, hdrlen); + + /* the HW only needs room for the IV, but not the actual IV */ + if (info->control.hw_key && + (info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) + return newhdr + hdrlen; + + ieee80211_wep_get_iv(local, keylen, keyidx, newhdr + hdrlen); + return newhdr + hdrlen; +} + + +static void ieee80211_wep_remove_iv(struct ieee80211_local *local, + struct sk_buff *skb, + struct ieee80211_key *key) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + unsigned int hdrlen; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + memmove(skb->data + IEEE80211_WEP_IV_LEN, skb->data, hdrlen); + skb_pull(skb, IEEE80211_WEP_IV_LEN); +} + + +/* Perform WEP encryption using given key. data buffer must have tailroom + * for 4-byte ICV. data_len must not include this ICV. Note: this function + * does _not_ add IV. data = RC4(data | CRC32(data)) */ +int ieee80211_wep_encrypt_data(struct arc4_ctx *ctx, u8 *rc4key, + size_t klen, u8 *data, size_t data_len) +{ + __le32 icv; + + icv = cpu_to_le32(~crc32_le(~0, data, data_len)); + put_unaligned(icv, (__le32 *)(data + data_len)); + + arc4_setkey(ctx, rc4key, klen); + arc4_crypt(ctx, data, data, data_len + IEEE80211_WEP_ICV_LEN); + memzero_explicit(ctx, sizeof(*ctx)); + + return 0; +} + + +/* Perform WEP encryption on given skb. 4 bytes of extra space (IV) in the + * beginning of the buffer 4 bytes of extra space (ICV) in the end of the + * buffer will be added. Both IV and ICV will be transmitted, so the + * payload length increases with 8 bytes. + * + * WEP frame payload: IV + TX key idx, RC4(data), ICV = RC4(CRC32(data)) + */ +int ieee80211_wep_encrypt(struct ieee80211_local *local, + struct sk_buff *skb, + const u8 *key, int keylen, int keyidx) +{ + u8 *iv; + size_t len; + u8 rc4key[3 + WLAN_KEY_LEN_WEP104]; + + if (WARN_ON(skb_tailroom(skb) < IEEE80211_WEP_ICV_LEN)) + return -1; + + iv = ieee80211_wep_add_iv(local, skb, keylen, keyidx); + if (!iv) + return -1; + + len = skb->len - (iv + IEEE80211_WEP_IV_LEN - skb->data); + + /* Prepend 24-bit IV to RC4 key */ + memcpy(rc4key, iv, 3); + + /* Copy rest of the WEP key (the secret part) */ + memcpy(rc4key + 3, key, keylen); + + /* Add room for ICV */ + skb_put(skb, IEEE80211_WEP_ICV_LEN); + + return ieee80211_wep_encrypt_data(&local->wep_tx_ctx, rc4key, keylen + 3, + iv + IEEE80211_WEP_IV_LEN, len); +} + + +/* Perform WEP decryption using given key. data buffer includes encrypted + * payload, including 4-byte ICV, but _not_ IV. 
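To illustrate the "ICV = CRC32(data), appended little-endian, then RC4 over data + ICV" layout these WEP helpers rely on, here is a small sketch that computes the same CRC-32 value the kernel obtains via ~crc32_le(~0, data, len) and appends it. The RC4 pass and the kernel's table-driven CRC are omitted, and the buffer contents are hypothetical.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Bitwise CRC-32 (IEEE 802.3 polynomial, reflected); the same value the
 * kernel obtains via ~crc32_le(~0, data, len). */
static uint32_t crc32_ieee(const uint8_t *p, size_t len)
{
	uint32_t crc = 0xffffffff;

	while (len--) {
		crc ^= *p++;
		for (int i = 0; i < 8; i++)
			crc = (crc >> 1) ^ (0xedb88320u & -(crc & 1));
	}
	return ~crc;
}

int main(void)
{
	/* WEP ICV: CRC-32 of the plaintext, appended little-endian; RC4 then
	 * covers data + ICV (the cipher pass is omitted here). */
	uint8_t frame[13] = "123456789";	/* 9 data bytes + room for ICV */
	size_t data_len = 9;
	uint32_t icv = crc32_ieee(frame, data_len);

	for (int i = 0; i < 4; i++)
		frame[data_len + i] = icv >> (8 * i);	/* little-endian */

	printf("ICV = 0x%08x\n", (unsigned)icv);	/* 0xcbf43926 */
	return 0;
}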
data_len must not include ICV. + * Return 0 on success and -1 on ICV mismatch. */ +int ieee80211_wep_decrypt_data(struct arc4_ctx *ctx, u8 *rc4key, + size_t klen, u8 *data, size_t data_len) +{ + __le32 crc; + + arc4_setkey(ctx, rc4key, klen); + arc4_crypt(ctx, data, data, data_len + IEEE80211_WEP_ICV_LEN); + memzero_explicit(ctx, sizeof(*ctx)); + + crc = cpu_to_le32(~crc32_le(~0, data, data_len)); + if (memcmp(&crc, data + data_len, IEEE80211_WEP_ICV_LEN) != 0) + /* ICV mismatch */ + return -1; + + return 0; +} + + +/* Perform WEP decryption on given skb. Buffer includes whole WEP part of + * the frame: IV (4 bytes), encrypted payload (including SNAP header), + * ICV (4 bytes). skb->len includes both IV and ICV. + * + * Returns 0 if frame was decrypted successfully and ICV was correct and -1 on + * failure. If frame is OK, IV and ICV will be removed, i.e., decrypted payload + * is moved to the beginning of the skb and skb length will be reduced. + */ +static int ieee80211_wep_decrypt(struct ieee80211_local *local, + struct sk_buff *skb, + struct ieee80211_key *key) +{ + u32 klen; + u8 rc4key[3 + WLAN_KEY_LEN_WEP104]; + u8 keyidx; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + unsigned int hdrlen; + size_t len; + int ret = 0; + + if (!ieee80211_has_protected(hdr->frame_control)) + return -1; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + if (skb->len < hdrlen + IEEE80211_WEP_IV_LEN + IEEE80211_WEP_ICV_LEN) + return -1; + + len = skb->len - hdrlen - IEEE80211_WEP_IV_LEN - IEEE80211_WEP_ICV_LEN; + + keyidx = skb->data[hdrlen + 3] >> 6; + + if (!key || keyidx != key->conf.keyidx) + return -1; + + klen = 3 + key->conf.keylen; + + /* Prepend 24-bit IV to RC4 key */ + memcpy(rc4key, skb->data + hdrlen, 3); + + /* Copy rest of the WEP key (the secret part) */ + memcpy(rc4key + 3, key->conf.key, key->conf.keylen); + + if (ieee80211_wep_decrypt_data(&local->wep_rx_ctx, rc4key, klen, + skb->data + hdrlen + + IEEE80211_WEP_IV_LEN, len)) + ret = -1; + + /* Trim ICV */ + skb_trim(skb, skb->len - IEEE80211_WEP_ICV_LEN); + + /* Remove IV */ + memmove(skb->data + IEEE80211_WEP_IV_LEN, skb->data, hdrlen); + skb_pull(skb, IEEE80211_WEP_IV_LEN); + + return ret; +} + +ieee80211_rx_result +ieee80211_crypto_wep_decrypt(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + __le16 fc = hdr->frame_control; + + if (!ieee80211_is_data(fc) && !ieee80211_is_auth(fc)) + return RX_CONTINUE; + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + if (ieee80211_wep_decrypt(rx->local, rx->skb, rx->key)) + return RX_DROP_UNUSABLE; + } else if (!(status->flag & RX_FLAG_IV_STRIPPED)) { + if (!pskb_may_pull(rx->skb, ieee80211_hdrlen(fc) + + IEEE80211_WEP_IV_LEN)) + return RX_DROP_UNUSABLE; + ieee80211_wep_remove_iv(rx->local, rx->skb, rx->key); + /* remove ICV */ + if (!(status->flag & RX_FLAG_ICV_STRIPPED) && + pskb_trim(rx->skb, rx->skb->len - IEEE80211_WEP_ICV_LEN)) + return RX_DROP_UNUSABLE; + } + + return RX_CONTINUE; +} + +static int wep_encrypt_skb(struct ieee80211_tx_data *tx, struct sk_buff *skb) +{ + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + struct ieee80211_key_conf *hw_key = info->control.hw_key; + + if (!hw_key) { + if (ieee80211_wep_encrypt(tx->local, skb, tx->key->conf.key, + tx->key->conf.keylen, + tx->key->conf.keyidx)) + return -1; + } else if ((hw_key->flags & 
IEEE80211_KEY_FLAG_GENERATE_IV) || + (hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) { + if (!ieee80211_wep_add_iv(tx->local, skb, + tx->key->conf.keylen, + tx->key->conf.keyidx)) + return -1; + } + + return 0; +} + +ieee80211_tx_result +ieee80211_crypto_wep_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + + ieee80211_tx_set_protected(tx); + + skb_queue_walk(&tx->skbs, skb) { + if (wep_encrypt_skb(tx, skb) < 0) { + I802_DEBUG_INC(tx->local->tx_handlers_drop_wep); + return TX_DROP; + } + } + + return TX_CONTINUE; +} diff --git a/net/mac80211/wep.h b/net/mac80211/wep.h new file mode 100644 index 000000000..4ffe83554 --- /dev/null +++ b/net/mac80211/wep.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Software WEP encryption implementation + * Copyright 2002, Jouni Malinen + * Copyright 2003, Instant802 Networks, Inc. + */ + +#ifndef WEP_H +#define WEP_H + +#include +#include +#include "ieee80211_i.h" +#include "key.h" + +void ieee80211_wep_init(struct ieee80211_local *local); +int ieee80211_wep_encrypt_data(struct arc4_ctx *ctx, u8 *rc4key, + size_t klen, u8 *data, size_t data_len); +int ieee80211_wep_encrypt(struct ieee80211_local *local, + struct sk_buff *skb, + const u8 *key, int keylen, int keyidx); +int ieee80211_wep_decrypt_data(struct arc4_ctx *ctx, u8 *rc4key, + size_t klen, u8 *data, size_t data_len); + +ieee80211_rx_result +ieee80211_crypto_wep_decrypt(struct ieee80211_rx_data *rx); +ieee80211_tx_result +ieee80211_crypto_wep_encrypt(struct ieee80211_tx_data *tx); + +#endif /* WEP_H */ diff --git a/net/mac80211/wme.c b/net/mac80211/wme.c new file mode 100644 index 000000000..64430949f --- /dev/null +++ b/net/mac80211/wme.c @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright (C) 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "ieee80211_i.h" +#include "wme.h" + +/* Default mapping in classifier to work with default + * queue setup. 
+ */ +const int ieee802_1d_to_ac[8] = { + IEEE80211_AC_BE, + IEEE80211_AC_BK, + IEEE80211_AC_BK, + IEEE80211_AC_BE, + IEEE80211_AC_VI, + IEEE80211_AC_VI, + IEEE80211_AC_VO, + IEEE80211_AC_VO +}; + +static int wme_downgrade_ac(struct sk_buff *skb) +{ + switch (skb->priority) { + case 6: + case 7: + skb->priority = 5; /* VO -> VI */ + return 0; + case 4: + case 5: + skb->priority = 3; /* VI -> BE */ + return 0; + case 0: + case 3: + skb->priority = 2; /* BE -> BK */ + return 0; + default: + return -1; + } +} + +/** + * ieee80211_fix_reserved_tid - return the TID to use if this one is reserved + * @tid: the assumed-reserved TID + * + * Returns: the alternative TID to use, or 0 on error + */ +static inline u8 ieee80211_fix_reserved_tid(u8 tid) +{ + switch (tid) { + case 0: + return 3; + case 1: + return 2; + case 2: + return 1; + case 3: + return 0; + case 4: + return 5; + case 5: + return 4; + case 6: + return 7; + case 7: + return 6; + } + + return 0; +} + +static u16 ieee80211_downgrade_queue(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + struct ieee80211_if_managed *ifmgd = &sdata->u.mgd; + + /* in case we are a client verify acm is not set for this ac */ + while (sdata->wmm_acm & BIT(skb->priority)) { + int ac = ieee802_1d_to_ac[skb->priority]; + + if (ifmgd->tx_tspec[ac].admitted_time && + skb->priority == ifmgd->tx_tspec[ac].up) + return ac; + + if (wme_downgrade_ac(skb)) { + /* + * This should not really happen. The AP has marked all + * lower ACs to require admission control which is not + * a reasonable configuration. Allow the frame to be + * transmitted using AC_BK as a workaround. + */ + break; + } + } + + /* Check to see if this is a reserved TID */ + if (sta && sta->reserved_tid == skb->priority) + skb->priority = ieee80211_fix_reserved_tid(skb->priority); + + /* look up which queue to use for frames with this 1d tag */ + return ieee802_1d_to_ac[skb->priority]; +} + +/* Indicate which queue to use for this fully formed 802.11 frame */ +u16 ieee80211_select_queue_80211(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct ieee80211_hdr *hdr) +{ + struct ieee80211_local *local = sdata->local; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + u8 *p; + + if ((info->control.flags & IEEE80211_TX_CTRL_DONT_REORDER) || + local->hw.queues < IEEE80211_NUM_ACS) + return 0; + + if (!ieee80211_is_data(hdr->frame_control)) { + skb->priority = 7; + return ieee802_1d_to_ac[skb->priority]; + } + if (!ieee80211_is_data_qos(hdr->frame_control)) { + skb->priority = 0; + return ieee802_1d_to_ac[skb->priority]; + } + + p = ieee80211_get_qos_ctl(hdr); + skb->priority = *p & IEEE80211_QOS_CTL_TAG1D_MASK; + + return ieee80211_downgrade_queue(sdata, NULL, skb); +} + +u16 __ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb) +{ + const struct ethhdr *eth = (void *)skb->data; + struct mac80211_qos_map *qos_map; + bool qos; + + /* all mesh/ocb stations are required to support WME */ + if ((sdata->vif.type == NL80211_IFTYPE_MESH_POINT && + !is_multicast_ether_addr(eth->h_dest)) || + (sdata->vif.type == NL80211_IFTYPE_OCB && sta)) + qos = true; + else if (sta) + qos = sta->sta.wme; + else + qos = false; + + if (!qos) { + skb->priority = 0; /* required for correct WPA/11i MIC */ + return IEEE80211_AC_BE; + } + + if (skb->protocol == sdata->control_port_protocol) { + skb->priority = 7; + goto downgrade; + } + + /* use the data classifier to determine what 802.1d tag the + * data frame has */ + 
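A compact sketch of the UP-to-AC selection and the ACM downgrade chain implemented by ieee802_1d_to_ac[], wme_downgrade_ac() and ieee80211_downgrade_queue() above. The enum and the 0xf0 admission-control mask are local, hypothetical stand-ins for IEEE80211_AC_* and sdata->wmm_acm.

#include <stdio.h>

enum ac { AC_VO, AC_VI, AC_BE, AC_BK };	/* same order as IEEE80211_AC_* */

/* Same table shape as ieee802_1d_to_ac[]: index = 802.1d UP, value = AC. */
static const enum ac up_to_ac[8] = {
	AC_BE, AC_BK, AC_BK, AC_BE, AC_VI, AC_VI, AC_VO, AC_VO
};

/* One VO -> VI -> BE -> BK downgrade step, as in wme_downgrade_ac(). */
static int downgrade(int *up)
{
	switch (*up) {
	case 6: case 7: *up = 5; return 0;	/* VO -> VI */
	case 4: case 5: *up = 3; return 0;	/* VI -> BE */
	case 0: case 3: *up = 2; return 0;	/* BE -> BK */
	default:	return -1;		/* already BK */
	}
}

int main(void)
{
	int up = 7;			/* voice traffic */
	unsigned int acm_mask = 0xf0;	/* hypothetical: admission control on UP 4-7 */

	while ((acm_mask & (1u << up)) && !downgrade(&up))
		;
	printf("UP %d -> AC %d\n", up, (int)up_to_ac[up]);	/* UP 3 -> AC 2 (BE) */
	return 0;
}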
qos_map = rcu_dereference(sdata->qos_map); + skb->priority = cfg80211_classify8021d(skb, qos_map ? + &qos_map->qos_map : NULL); + + downgrade: + return ieee80211_downgrade_queue(sdata, sta, skb); +} + + +/* Indicate which queue to use. */ +u16 ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_local *local = sdata->local; + struct sta_info *sta = NULL; + const u8 *ra = NULL; + u16 ret; + + /* when using iTXQ, we can do this later */ + if (local->ops->wake_tx_queue) + return 0; + + if (local->hw.queues < IEEE80211_NUM_ACS || skb->len < 6) { + skb->priority = 0; /* required for correct WPA/11i MIC */ + return 0; + } + + rcu_read_lock(); + switch (sdata->vif.type) { + case NL80211_IFTYPE_AP_VLAN: + sta = rcu_dereference(sdata->u.vlan.sta); + if (sta) + break; + fallthrough; + case NL80211_IFTYPE_AP: + ra = skb->data; + break; + case NL80211_IFTYPE_STATION: + /* might be a TDLS station */ + sta = sta_info_get(sdata, skb->data); + if (sta) + break; + + ra = sdata->deflink.u.mgd.bssid; + break; + case NL80211_IFTYPE_ADHOC: + ra = skb->data; + break; + default: + break; + } + + if (!sta && ra && !is_multicast_ether_addr(ra)) + sta = sta_info_get(sdata, ra); + + ret = __ieee80211_select_queue(sdata, sta, skb); + + rcu_read_unlock(); + return ret; +} + +/** + * ieee80211_set_qos_hdr - Fill in the QoS header if there is one. + * + * @sdata: local subif + * @skb: packet to be updated + */ +void ieee80211_set_qos_hdr(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (void *)skb->data; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + u8 tid = skb->priority & IEEE80211_QOS_CTL_TAG1D_MASK; + u8 flags; + u8 *p; + + if (!ieee80211_is_data_qos(hdr->frame_control)) + return; + + p = ieee80211_get_qos_ctl(hdr); + + /* don't overwrite the QoS field of injected frames */ + if (info->flags & IEEE80211_TX_CTL_INJECTED) { + /* do take into account Ack policy of injected frames */ + if (*p & IEEE80211_QOS_CTL_ACK_POLICY_NOACK) + info->flags |= IEEE80211_TX_CTL_NO_ACK; + return; + } + + /* set up the first byte */ + + /* + * preserve everything but the TID and ACK policy + * (which we both write here) + */ + flags = *p & ~(IEEE80211_QOS_CTL_TID_MASK | + IEEE80211_QOS_CTL_ACK_POLICY_MASK); + + if (is_multicast_ether_addr(hdr->addr1) || + sdata->noack_map & BIT(tid)) { + flags |= IEEE80211_QOS_CTL_ACK_POLICY_NOACK; + info->flags |= IEEE80211_TX_CTL_NO_ACK; + } + + *p = flags | tid; + + /* set up the second byte */ + p++; + + if (ieee80211_vif_is_mesh(&sdata->vif)) { + /* preserve RSPI and Mesh PS Level bit */ + *p &= ((IEEE80211_QOS_CTL_RSPI | + IEEE80211_QOS_CTL_MESH_PS_LEVEL) >> 8); + + /* Nulls don't have a mesh header (frame body) */ + if (!ieee80211_is_qos_nullfunc(hdr->frame_control)) + *p |= (IEEE80211_QOS_CTL_MESH_CONTROL_PRESENT >> 8); + } else { + *p = 0; + } +} diff --git a/net/mac80211/wme.h b/net/mac80211/wme.h new file mode 100644 index 000000000..2e3dec0b6 --- /dev/null +++ b/net/mac80211/wme.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2004, Instant802 Networks, Inc. + * Copyright 2005, Devicescape Software, Inc. 
+ */ + +#ifndef _WME_H +#define _WME_H + +#include +#include "ieee80211_i.h" + +u16 ieee80211_select_queue_80211(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb, + struct ieee80211_hdr *hdr); +u16 __ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, + struct sta_info *sta, struct sk_buff *skb); +u16 ieee80211_select_queue(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); +void ieee80211_set_qos_hdr(struct ieee80211_sub_if_data *sdata, + struct sk_buff *skb); + +#endif /* _WME_H */ diff --git a/net/mac80211/wpa.c b/net/mac80211/wpa.c new file mode 100644 index 000000000..20f742b55 --- /dev/null +++ b/net/mac80211/wpa.c @@ -0,0 +1,1118 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2002-2004, Instant802 Networks, Inc. + * Copyright 2008, Jouni Malinen + * Copyright (C) 2016-2017 Intel Deutschland GmbH + * Copyright (C) 2020-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ieee80211_i.h" +#include "michael.h" +#include "tkip.h" +#include "aes_ccm.h" +#include "aes_cmac.h" +#include "aes_gmac.h" +#include "aes_gcm.h" +#include "wpa.h" + +ieee80211_tx_result +ieee80211_tx_h_michael_mic_add(struct ieee80211_tx_data *tx) +{ + u8 *data, *key, *mic; + size_t data_len; + unsigned int hdrlen; + struct ieee80211_hdr *hdr; + struct sk_buff *skb = tx->skb; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int tail; + + hdr = (struct ieee80211_hdr *)skb->data; + if (!tx->key || tx->key->conf.cipher != WLAN_CIPHER_SUITE_TKIP || + skb->len < 24 || !ieee80211_is_data_present(hdr->frame_control)) + return TX_CONTINUE; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + if (skb->len < hdrlen) + return TX_DROP; + + data = skb->data + hdrlen; + data_len = skb->len - hdrlen; + + if (unlikely(info->flags & IEEE80211_TX_INTFL_TKIP_MIC_FAILURE)) { + /* Need to use software crypto for the test */ + info->control.hw_key = NULL; + } + + if (info->control.hw_key && + (info->flags & IEEE80211_TX_CTL_DONTFRAG || + ieee80211_hw_check(&tx->local->hw, SUPPORTS_TX_FRAG)) && + !(tx->key->conf.flags & (IEEE80211_KEY_FLAG_GENERATE_MMIC | + IEEE80211_KEY_FLAG_PUT_MIC_SPACE))) { + /* hwaccel - with no need for SW-generated MMIC or MIC space */ + return TX_CONTINUE; + } + + tail = MICHAEL_MIC_LEN; + if (!info->control.hw_key) + tail += IEEE80211_TKIP_ICV_LEN; + + if (WARN(skb_tailroom(skb) < tail || + skb_headroom(skb) < IEEE80211_TKIP_IV_LEN, + "mmic: not enough head/tail (%d/%d,%d/%d)\n", + skb_headroom(skb), IEEE80211_TKIP_IV_LEN, + skb_tailroom(skb), tail)) + return TX_DROP; + + mic = skb_put(skb, MICHAEL_MIC_LEN); + + if (tx->key->conf.flags & IEEE80211_KEY_FLAG_PUT_MIC_SPACE) { + /* Zeroed MIC can help with debug */ + memset(mic, 0, MICHAEL_MIC_LEN); + return TX_CONTINUE; + } + + key = &tx->key->conf.key[NL80211_TKIP_DATA_OFFSET_TX_MIC_KEY]; + michael_mic(key, hdr, data, data_len, mic); + if (unlikely(info->flags & IEEE80211_TX_INTFL_TKIP_MIC_FAILURE)) + mic[0]++; + + return TX_CONTINUE; +} + + +ieee80211_rx_result +ieee80211_rx_h_michael_mic_verify(struct ieee80211_rx_data *rx) +{ + u8 *data, *key = NULL; + size_t data_len; + unsigned int hdrlen; + u8 mic[MICHAEL_MIC_LEN]; + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + + /* + * it makes no sense to check for MIC errors on anything other + * than data frames. 
+ */ + if (!ieee80211_is_data_present(hdr->frame_control)) + return RX_CONTINUE; + + /* + * No way to verify the MIC if the hardware stripped it or + * the IV with the key index. In this case we have solely rely + * on the driver to set RX_FLAG_MMIC_ERROR in the event of a + * MIC failure report. + */ + if (status->flag & (RX_FLAG_MMIC_STRIPPED | RX_FLAG_IV_STRIPPED)) { + if (status->flag & RX_FLAG_MMIC_ERROR) + goto mic_fail_no_key; + + if (!(status->flag & RX_FLAG_IV_STRIPPED) && rx->key && + rx->key->conf.cipher == WLAN_CIPHER_SUITE_TKIP) + goto update_iv; + + return RX_CONTINUE; + } + + /* + * Some hardware seems to generate Michael MIC failure reports; even + * though, the frame was not encrypted with TKIP and therefore has no + * MIC. Ignore the flag them to avoid triggering countermeasures. + */ + if (!rx->key || rx->key->conf.cipher != WLAN_CIPHER_SUITE_TKIP || + !(status->flag & RX_FLAG_DECRYPTED)) + return RX_CONTINUE; + + if (rx->sdata->vif.type == NL80211_IFTYPE_AP && rx->key->conf.keyidx) { + /* + * APs with pairwise keys should never receive Michael MIC + * errors for non-zero keyidx because these are reserved for + * group keys and only the AP is sending real multicast + * frames in the BSS. + */ + return RX_DROP_UNUSABLE; + } + + if (status->flag & RX_FLAG_MMIC_ERROR) + goto mic_fail; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + if (skb->len < hdrlen + MICHAEL_MIC_LEN) + return RX_DROP_UNUSABLE; + + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + hdr = (void *)skb->data; + + data = skb->data + hdrlen; + data_len = skb->len - hdrlen - MICHAEL_MIC_LEN; + key = &rx->key->conf.key[NL80211_TKIP_DATA_OFFSET_RX_MIC_KEY]; + michael_mic(key, hdr, data, data_len, mic); + if (crypto_memneq(mic, data + data_len, MICHAEL_MIC_LEN)) + goto mic_fail; + + /* remove Michael MIC from payload */ + skb_trim(skb, skb->len - MICHAEL_MIC_LEN); + +update_iv: + /* update IV in key information to be able to detect replays */ + rx->key->u.tkip.rx[rx->security_idx].iv32 = rx->tkip.iv32; + rx->key->u.tkip.rx[rx->security_idx].iv16 = rx->tkip.iv16; + + return RX_CONTINUE; + +mic_fail: + rx->key->u.tkip.mic_failures++; + +mic_fail_no_key: + /* + * In some cases the key can be unset - e.g. a multicast packet, in + * a driver that supports HW encryption. Send up the key idx only if + * the key is set. + */ + cfg80211_michael_mic_failure(rx->sdata->dev, hdr->addr2, + is_multicast_ether_addr(hdr->addr1) ? + NL80211_KEYTYPE_GROUP : + NL80211_KEYTYPE_PAIRWISE, + rx->key ? 
rx->key->conf.keyidx : -1, + NULL, GFP_ATOMIC); + return RX_DROP_UNUSABLE; +} + +static int tkip_encrypt_skb(struct ieee80211_tx_data *tx, struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_key *key = tx->key; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + unsigned int hdrlen; + int len, tail; + u64 pn; + u8 *pos; + + if (info->control.hw_key && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_GENERATE_IV) && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) { + /* hwaccel - with no need for software-generated IV */ + return 0; + } + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + len = skb->len - hdrlen; + + if (info->control.hw_key) + tail = 0; + else + tail = IEEE80211_TKIP_ICV_LEN; + + if (WARN_ON(skb_tailroom(skb) < tail || + skb_headroom(skb) < IEEE80211_TKIP_IV_LEN)) + return -1; + + pos = skb_push(skb, IEEE80211_TKIP_IV_LEN); + memmove(pos, pos + IEEE80211_TKIP_IV_LEN, hdrlen); + pos += hdrlen; + + /* the HW only needs room for the IV, but not the actual IV */ + if (info->control.hw_key && + (info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) + return 0; + + /* Increase IV for the frame */ + pn = atomic64_inc_return(&key->conf.tx_pn); + pos = ieee80211_tkip_add_iv(pos, &key->conf, pn); + + /* hwaccel - with software IV */ + if (info->control.hw_key) + return 0; + + /* Add room for ICV */ + skb_put(skb, IEEE80211_TKIP_ICV_LEN); + + return ieee80211_tkip_encrypt_data(&tx->local->wep_tx_ctx, + key, skb, pos, len); +} + + +ieee80211_tx_result +ieee80211_crypto_tkip_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + + ieee80211_tx_set_protected(tx); + + skb_queue_walk(&tx->skbs, skb) { + if (tkip_encrypt_skb(tx, skb) < 0) + return TX_DROP; + } + + return TX_CONTINUE; +} + + +ieee80211_rx_result +ieee80211_crypto_tkip_decrypt(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) rx->skb->data; + int hdrlen, res, hwaccel = 0; + struct ieee80211_key *key = rx->key; + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + if (!ieee80211_is_data(hdr->frame_control)) + return RX_CONTINUE; + + if (!rx->sta || skb->len - hdrlen < 12) + return RX_DROP_UNUSABLE; + + /* it may be possible to optimize this a bit more */ + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + hdr = (void *)skb->data; + + /* + * Let TKIP code verify IV, but skip decryption. + * In the case where hardware checks the IV as well, + * we don't even get here, see ieee80211_rx_h_decrypt() + */ + if (status->flag & RX_FLAG_DECRYPTED) + hwaccel = 1; + + res = ieee80211_tkip_decrypt_data(&rx->local->wep_rx_ctx, + key, skb->data + hdrlen, + skb->len - hdrlen, rx->sta->sta.addr, + hdr->addr1, hwaccel, rx->security_idx, + &rx->tkip.iv32, + &rx->tkip.iv16); + if (res != TKIP_DECRYPT_OK) + return RX_DROP_UNUSABLE; + + /* Trim ICV */ + if (!(status->flag & RX_FLAG_ICV_STRIPPED)) + skb_trim(skb, skb->len - IEEE80211_TKIP_ICV_LEN); + + /* Remove IV */ + memmove(skb->data + IEEE80211_TKIP_IV_LEN, skb->data, hdrlen); + skb_pull(skb, IEEE80211_TKIP_IV_LEN); + + return RX_CONTINUE; +} + +/* + * Calculate AAD for CCMP/GCMP, returning qos_tid since we + * need that in CCMP also for b_0. 
+ */ +static u8 ccmp_gcmp_aad(struct sk_buff *skb, u8 *aad) +{ + struct ieee80211_hdr *hdr = (void *)skb->data; + __le16 mask_fc; + int a4_included, mgmt; + u8 qos_tid; + u16 len_a = 22; + + /* + * Mask FC: zero subtype b4 b5 b6 (if not mgmt) + * Retry, PwrMgt, MoreData, Order (if Qos Data); set Protected + */ + mgmt = ieee80211_is_mgmt(hdr->frame_control); + mask_fc = hdr->frame_control; + mask_fc &= ~cpu_to_le16(IEEE80211_FCTL_RETRY | + IEEE80211_FCTL_PM | IEEE80211_FCTL_MOREDATA); + if (!mgmt) + mask_fc &= ~cpu_to_le16(0x0070); + mask_fc |= cpu_to_le16(IEEE80211_FCTL_PROTECTED); + + a4_included = ieee80211_has_a4(hdr->frame_control); + if (a4_included) + len_a += 6; + + if (ieee80211_is_data_qos(hdr->frame_control)) { + qos_tid = ieee80211_get_tid(hdr); + mask_fc &= ~cpu_to_le16(IEEE80211_FCTL_ORDER); + len_a += 2; + } else { + qos_tid = 0; + } + + /* AAD (extra authenticate-only data) / masked 802.11 header + * FC | A1 | A2 | A3 | SC | [A4] | [QC] */ + put_unaligned_be16(len_a, &aad[0]); + put_unaligned(mask_fc, (__le16 *)&aad[2]); + memcpy(&aad[4], &hdr->addrs, 3 * ETH_ALEN); + + /* Mask Seq#, leave Frag# */ + aad[22] = *((u8 *) &hdr->seq_ctrl) & 0x0f; + aad[23] = 0; + + if (a4_included) { + memcpy(&aad[24], hdr->addr4, ETH_ALEN); + aad[30] = qos_tid; + aad[31] = 0; + } else { + memset(&aad[24], 0, ETH_ALEN + IEEE80211_QOS_CTL_LEN); + aad[24] = qos_tid; + } + + return qos_tid; +} + +static void ccmp_special_blocks(struct sk_buff *skb, u8 *pn, u8 *b_0, u8 *aad) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + u8 qos_tid = ccmp_gcmp_aad(skb, aad); + + /* In CCM, the initial vectors (IV) used for CTR mode encryption and CBC + * mode authentication are not allowed to collide, yet both are derived + * from this vector b_0. We only set L := 1 here to indicate that the + * data size can be represented in (L+1) bytes. The CCM layer will take + * care of storing the data length in the top (L+1) bytes and setting + * and clearing the other bits as is required to derive the two IVs. 
+ */ + b_0[0] = 0x1; + + /* Nonce: Nonce Flags | A2 | PN + * Nonce Flags: Priority (b0..b3) | Management (b4) | Reserved (b5..b7) + */ + b_0[1] = qos_tid | (ieee80211_is_mgmt(hdr->frame_control) << 4); + memcpy(&b_0[2], hdr->addr2, ETH_ALEN); + memcpy(&b_0[8], pn, IEEE80211_CCMP_PN_LEN); +} + +static inline void ccmp_pn2hdr(u8 *hdr, u8 *pn, int key_id) +{ + hdr[0] = pn[5]; + hdr[1] = pn[4]; + hdr[2] = 0; + hdr[3] = 0x20 | (key_id << 6); + hdr[4] = pn[3]; + hdr[5] = pn[2]; + hdr[6] = pn[1]; + hdr[7] = pn[0]; +} + + +static inline void ccmp_hdr2pn(u8 *pn, u8 *hdr) +{ + pn[0] = hdr[7]; + pn[1] = hdr[6]; + pn[2] = hdr[5]; + pn[3] = hdr[4]; + pn[4] = hdr[1]; + pn[5] = hdr[0]; +} + + +static int ccmp_encrypt_skb(struct ieee80211_tx_data *tx, struct sk_buff *skb, + unsigned int mic_len) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct ieee80211_key *key = tx->key; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int hdrlen, len, tail; + u8 *pos; + u8 pn[6]; + u64 pn64; + u8 aad[CCM_AAD_LEN]; + u8 b_0[AES_BLOCK_SIZE]; + + if (info->control.hw_key && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_GENERATE_IV) && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE) && + !((info->control.hw_key->flags & + IEEE80211_KEY_FLAG_GENERATE_IV_MGMT) && + ieee80211_is_mgmt(hdr->frame_control))) { + /* + * hwaccel has no need for preallocated room for CCMP + * header or MIC fields + */ + return 0; + } + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + len = skb->len - hdrlen; + + if (info->control.hw_key) + tail = 0; + else + tail = mic_len; + + if (WARN_ON(skb_tailroom(skb) < tail || + skb_headroom(skb) < IEEE80211_CCMP_HDR_LEN)) + return -1; + + pos = skb_push(skb, IEEE80211_CCMP_HDR_LEN); + memmove(pos, pos + IEEE80211_CCMP_HDR_LEN, hdrlen); + + /* the HW only needs room for the IV, but not the actual IV */ + if (info->control.hw_key && + (info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) + return 0; + + pos += hdrlen; + + pn64 = atomic64_inc_return(&key->conf.tx_pn); + + pn[5] = pn64; + pn[4] = pn64 >> 8; + pn[3] = pn64 >> 16; + pn[2] = pn64 >> 24; + pn[1] = pn64 >> 32; + pn[0] = pn64 >> 40; + + ccmp_pn2hdr(pos, pn, key->conf.keyidx); + + /* hwaccel - with software CCMP header */ + if (info->control.hw_key) + return 0; + + pos += IEEE80211_CCMP_HDR_LEN; + ccmp_special_blocks(skb, pn, b_0, aad); + return ieee80211_aes_ccm_encrypt(key->u.ccmp.tfm, b_0, aad, pos, len, + skb_put(skb, mic_len)); +} + + +ieee80211_tx_result +ieee80211_crypto_ccmp_encrypt(struct ieee80211_tx_data *tx, + unsigned int mic_len) +{ + struct sk_buff *skb; + + ieee80211_tx_set_protected(tx); + + skb_queue_walk(&tx->skbs, skb) { + if (ccmp_encrypt_skb(tx, skb, mic_len) < 0) + return TX_DROP; + } + + return TX_CONTINUE; +} + + +ieee80211_rx_result +ieee80211_crypto_ccmp_decrypt(struct ieee80211_rx_data *rx, + unsigned int mic_len) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + int hdrlen; + struct ieee80211_key *key = rx->key; + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + u8 pn[IEEE80211_CCMP_PN_LEN]; + int data_len; + int queue; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + if (!ieee80211_is_data(hdr->frame_control) && + !ieee80211_is_robust_mgmt_frame(skb)) + return RX_CONTINUE; + + if (status->flag & RX_FLAG_DECRYPTED) { + if (!pskb_may_pull(rx->skb, hdrlen + IEEE80211_CCMP_HDR_LEN)) + return RX_DROP_UNUSABLE; + if (status->flag & RX_FLAG_MIC_STRIPPED) + mic_len = 0; + } 
else { + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + } + + /* reload hdr - skb might have been reallocated */ + hdr = (void *)rx->skb->data; + + data_len = skb->len - hdrlen - IEEE80211_CCMP_HDR_LEN - mic_len; + if (!rx->sta || data_len < 0) + return RX_DROP_UNUSABLE; + + if (!(status->flag & RX_FLAG_PN_VALIDATED)) { + int res; + + ccmp_hdr2pn(pn, skb->data + hdrlen); + + queue = rx->security_idx; + + res = memcmp(pn, key->u.ccmp.rx_pn[queue], + IEEE80211_CCMP_PN_LEN); + if (res < 0 || + (!res && !(status->flag & RX_FLAG_ALLOW_SAME_PN))) { + key->u.ccmp.replays++; + return RX_DROP_UNUSABLE; + } + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + u8 aad[2 * AES_BLOCK_SIZE]; + u8 b_0[AES_BLOCK_SIZE]; + /* hardware didn't decrypt/verify MIC */ + ccmp_special_blocks(skb, pn, b_0, aad); + + if (ieee80211_aes_ccm_decrypt( + key->u.ccmp.tfm, b_0, aad, + skb->data + hdrlen + IEEE80211_CCMP_HDR_LEN, + data_len, + skb->data + skb->len - mic_len)) + return RX_DROP_UNUSABLE; + } + + memcpy(key->u.ccmp.rx_pn[queue], pn, IEEE80211_CCMP_PN_LEN); + if (unlikely(ieee80211_is_frag(hdr))) + memcpy(rx->ccm_gcm.pn, pn, IEEE80211_CCMP_PN_LEN); + } + + /* Remove CCMP header and MIC */ + if (pskb_trim(skb, skb->len - mic_len)) + return RX_DROP_UNUSABLE; + memmove(skb->data + IEEE80211_CCMP_HDR_LEN, skb->data, hdrlen); + skb_pull(skb, IEEE80211_CCMP_HDR_LEN); + + return RX_CONTINUE; +} + +static void gcmp_special_blocks(struct sk_buff *skb, u8 *pn, u8 *j_0, u8 *aad) +{ + struct ieee80211_hdr *hdr = (void *)skb->data; + + memcpy(j_0, hdr->addr2, ETH_ALEN); + memcpy(&j_0[ETH_ALEN], pn, IEEE80211_GCMP_PN_LEN); + j_0[13] = 0; + j_0[14] = 0; + j_0[AES_BLOCK_SIZE - 1] = 0x01; + + ccmp_gcmp_aad(skb, aad); +} + +static inline void gcmp_pn2hdr(u8 *hdr, const u8 *pn, int key_id) +{ + hdr[0] = pn[5]; + hdr[1] = pn[4]; + hdr[2] = 0; + hdr[3] = 0x20 | (key_id << 6); + hdr[4] = pn[3]; + hdr[5] = pn[2]; + hdr[6] = pn[1]; + hdr[7] = pn[0]; +} + +static inline void gcmp_hdr2pn(u8 *pn, const u8 *hdr) +{ + pn[0] = hdr[7]; + pn[1] = hdr[6]; + pn[2] = hdr[5]; + pn[3] = hdr[4]; + pn[4] = hdr[1]; + pn[5] = hdr[0]; +} + +static int gcmp_encrypt_skb(struct ieee80211_tx_data *tx, struct sk_buff *skb) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + struct ieee80211_key *key = tx->key; + struct ieee80211_tx_info *info = IEEE80211_SKB_CB(skb); + int hdrlen, len, tail; + u8 *pos; + u8 pn[6]; + u64 pn64; + u8 aad[GCM_AAD_LEN]; + u8 j_0[AES_BLOCK_SIZE]; + + if (info->control.hw_key && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_GENERATE_IV) && + !(info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE) && + !((info->control.hw_key->flags & + IEEE80211_KEY_FLAG_GENERATE_IV_MGMT) && + ieee80211_is_mgmt(hdr->frame_control))) { + /* hwaccel has no need for preallocated room for GCMP + * header or MIC fields + */ + return 0; + } + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + len = skb->len - hdrlen; + + if (info->control.hw_key) + tail = 0; + else + tail = IEEE80211_GCMP_MIC_LEN; + + if (WARN_ON(skb_tailroom(skb) < tail || + skb_headroom(skb) < IEEE80211_GCMP_HDR_LEN)) + return -1; + + pos = skb_push(skb, IEEE80211_GCMP_HDR_LEN); + memmove(pos, pos + IEEE80211_GCMP_HDR_LEN, hdrlen); + skb_set_network_header(skb, skb_network_offset(skb) + + IEEE80211_GCMP_HDR_LEN); + + /* the HW only needs room for the IV, but not the actual IV */ + if (info->control.hw_key && + (info->control.hw_key->flags & IEEE80211_KEY_FLAG_PUT_IV_SPACE)) + return 0; + + pos += hdrlen; + + pn64 = 
atomic64_inc_return(&key->conf.tx_pn); + + pn[5] = pn64; + pn[4] = pn64 >> 8; + pn[3] = pn64 >> 16; + pn[2] = pn64 >> 24; + pn[1] = pn64 >> 32; + pn[0] = pn64 >> 40; + + gcmp_pn2hdr(pos, pn, key->conf.keyidx); + + /* hwaccel - with software GCMP header */ + if (info->control.hw_key) + return 0; + + pos += IEEE80211_GCMP_HDR_LEN; + gcmp_special_blocks(skb, pn, j_0, aad); + return ieee80211_aes_gcm_encrypt(key->u.gcmp.tfm, j_0, aad, pos, len, + skb_put(skb, IEEE80211_GCMP_MIC_LEN)); +} + +ieee80211_tx_result +ieee80211_crypto_gcmp_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + + ieee80211_tx_set_protected(tx); + + skb_queue_walk(&tx->skbs, skb) { + if (gcmp_encrypt_skb(tx, skb) < 0) + return TX_DROP; + } + + return TX_CONTINUE; +} + +ieee80211_rx_result +ieee80211_crypto_gcmp_decrypt(struct ieee80211_rx_data *rx) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)rx->skb->data; + int hdrlen; + struct ieee80211_key *key = rx->key; + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + u8 pn[IEEE80211_GCMP_PN_LEN]; + int data_len, queue, mic_len = IEEE80211_GCMP_MIC_LEN; + + hdrlen = ieee80211_hdrlen(hdr->frame_control); + + if (!ieee80211_is_data(hdr->frame_control) && + !ieee80211_is_robust_mgmt_frame(skb)) + return RX_CONTINUE; + + if (status->flag & RX_FLAG_DECRYPTED) { + if (!pskb_may_pull(rx->skb, hdrlen + IEEE80211_GCMP_HDR_LEN)) + return RX_DROP_UNUSABLE; + if (status->flag & RX_FLAG_MIC_STRIPPED) + mic_len = 0; + } else { + if (skb_linearize(rx->skb)) + return RX_DROP_UNUSABLE; + } + + /* reload hdr - skb might have been reallocated */ + hdr = (void *)rx->skb->data; + + data_len = skb->len - hdrlen - IEEE80211_GCMP_HDR_LEN - mic_len; + if (!rx->sta || data_len < 0) + return RX_DROP_UNUSABLE; + + if (!(status->flag & RX_FLAG_PN_VALIDATED)) { + int res; + + gcmp_hdr2pn(pn, skb->data + hdrlen); + + queue = rx->security_idx; + + res = memcmp(pn, key->u.gcmp.rx_pn[queue], + IEEE80211_GCMP_PN_LEN); + if (res < 0 || + (!res && !(status->flag & RX_FLAG_ALLOW_SAME_PN))) { + key->u.gcmp.replays++; + return RX_DROP_UNUSABLE; + } + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + u8 aad[2 * AES_BLOCK_SIZE]; + u8 j_0[AES_BLOCK_SIZE]; + /* hardware didn't decrypt/verify MIC */ + gcmp_special_blocks(skb, pn, j_0, aad); + + if (ieee80211_aes_gcm_decrypt( + key->u.gcmp.tfm, j_0, aad, + skb->data + hdrlen + IEEE80211_GCMP_HDR_LEN, + data_len, + skb->data + skb->len - + IEEE80211_GCMP_MIC_LEN)) + return RX_DROP_UNUSABLE; + } + + memcpy(key->u.gcmp.rx_pn[queue], pn, IEEE80211_GCMP_PN_LEN); + if (unlikely(ieee80211_is_frag(hdr))) + memcpy(rx->ccm_gcm.pn, pn, IEEE80211_CCMP_PN_LEN); + } + + /* Remove GCMP header and MIC */ + if (pskb_trim(skb, skb->len - mic_len)) + return RX_DROP_UNUSABLE; + memmove(skb->data + IEEE80211_GCMP_HDR_LEN, skb->data, hdrlen); + skb_pull(skb, IEEE80211_GCMP_HDR_LEN); + + return RX_CONTINUE; +} + +static void bip_aad(struct sk_buff *skb, u8 *aad) +{ + __le16 mask_fc; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + + /* BIP AAD: FC(masked) || A1 || A2 || A3 */ + + /* FC type/subtype */ + /* Mask FC Retry, PwrMgt, MoreData flags to zero */ + mask_fc = hdr->frame_control; + mask_fc &= ~cpu_to_le16(IEEE80211_FCTL_RETRY | IEEE80211_FCTL_PM | + IEEE80211_FCTL_MOREDATA); + put_unaligned(mask_fc, (__le16 *) &aad[0]); + /* A1 || A2 || A3 */ + memcpy(aad + 2, &hdr->addrs, 3 * ETH_ALEN); +} + + +static inline void bip_ipn_set64(u8 *d, u64 pn) +{ + *d++ = pn; + *d++ = pn >> 8; + *d++ = pn >> 16; + *d++ = 
pn >> 24; + *d++ = pn >> 32; + *d = pn >> 40; +} + +static inline void bip_ipn_swap(u8 *d, const u8 *s) +{ + *d++ = s[5]; + *d++ = s[4]; + *d++ = s[3]; + *d++ = s[2]; + *d++ = s[1]; + *d = s[0]; +} + + +ieee80211_tx_result +ieee80211_crypto_aes_cmac_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + struct ieee80211_tx_info *info; + struct ieee80211_key *key = tx->key; + struct ieee80211_mmie *mmie; + u8 aad[20]; + u64 pn64; + + if (WARN_ON(skb_queue_len(&tx->skbs) != 1)) + return TX_DROP; + + skb = skb_peek(&tx->skbs); + + info = IEEE80211_SKB_CB(skb); + + if (info->control.hw_key && + !(key->conf.flags & IEEE80211_KEY_FLAG_GENERATE_MMIE)) + return TX_CONTINUE; + + if (WARN_ON(skb_tailroom(skb) < sizeof(*mmie))) + return TX_DROP; + + mmie = skb_put(skb, sizeof(*mmie)); + mmie->element_id = WLAN_EID_MMIE; + mmie->length = sizeof(*mmie) - 2; + mmie->key_id = cpu_to_le16(key->conf.keyidx); + + /* PN = PN + 1 */ + pn64 = atomic64_inc_return(&key->conf.tx_pn); + + bip_ipn_set64(mmie->sequence_number, pn64); + + if (info->control.hw_key) + return TX_CONTINUE; + + bip_aad(skb, aad); + + /* + * MIC = AES-128-CMAC(IGTK, AAD || Management Frame Body || MMIE, 64) + */ + ieee80211_aes_cmac(key->u.aes_cmac.tfm, aad, + skb->data + 24, skb->len - 24, mmie->mic); + + return TX_CONTINUE; +} + +ieee80211_tx_result +ieee80211_crypto_aes_cmac_256_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + struct ieee80211_tx_info *info; + struct ieee80211_key *key = tx->key; + struct ieee80211_mmie_16 *mmie; + u8 aad[20]; + u64 pn64; + + if (WARN_ON(skb_queue_len(&tx->skbs) != 1)) + return TX_DROP; + + skb = skb_peek(&tx->skbs); + + info = IEEE80211_SKB_CB(skb); + + if (info->control.hw_key) + return TX_CONTINUE; + + if (WARN_ON(skb_tailroom(skb) < sizeof(*mmie))) + return TX_DROP; + + mmie = skb_put(skb, sizeof(*mmie)); + mmie->element_id = WLAN_EID_MMIE; + mmie->length = sizeof(*mmie) - 2; + mmie->key_id = cpu_to_le16(key->conf.keyidx); + + /* PN = PN + 1 */ + pn64 = atomic64_inc_return(&key->conf.tx_pn); + + bip_ipn_set64(mmie->sequence_number, pn64); + + bip_aad(skb, aad); + + /* MIC = AES-256-CMAC(IGTK, AAD || Management Frame Body || MMIE, 128) + */ + ieee80211_aes_cmac_256(key->u.aes_cmac.tfm, aad, + skb->data + 24, skb->len - 24, mmie->mic); + + return TX_CONTINUE; +} + +ieee80211_rx_result +ieee80211_crypto_aes_cmac_decrypt(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_key *key = rx->key; + struct ieee80211_mmie *mmie; + u8 aad[20], mic[8], ipn[6]; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + + if (!ieee80211_is_mgmt(hdr->frame_control)) + return RX_CONTINUE; + + /* management frames are already linear */ + + if (skb->len < 24 + sizeof(*mmie)) + return RX_DROP_UNUSABLE; + + mmie = (struct ieee80211_mmie *) + (skb->data + skb->len - sizeof(*mmie)); + if (mmie->element_id != WLAN_EID_MMIE || + mmie->length != sizeof(*mmie) - 2) + return RX_DROP_UNUSABLE; /* Invalid MMIE */ + + bip_ipn_swap(ipn, mmie->sequence_number); + + if (memcmp(ipn, key->u.aes_cmac.rx_pn, 6) <= 0) { + key->u.aes_cmac.replays++; + return RX_DROP_UNUSABLE; + } + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + /* hardware didn't decrypt/verify MIC */ + bip_aad(skb, aad); + ieee80211_aes_cmac(key->u.aes_cmac.tfm, aad, + skb->data + 24, skb->len - 24, mic); + if (crypto_memneq(mic, mmie->mic, sizeof(mmie->mic))) { + key->u.aes_cmac.icverrors++; + return RX_DROP_UNUSABLE; + } + } + + 
memcpy(key->u.aes_cmac.rx_pn, ipn, 6); + + /* Remove MMIE */ + skb_trim(skb, skb->len - sizeof(*mmie)); + + return RX_CONTINUE; +} + +ieee80211_rx_result +ieee80211_crypto_aes_cmac_256_decrypt(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_key *key = rx->key; + struct ieee80211_mmie_16 *mmie; + u8 aad[20], mic[16], ipn[6]; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + + if (!ieee80211_is_mgmt(hdr->frame_control)) + return RX_CONTINUE; + + /* management frames are already linear */ + + if (skb->len < 24 + sizeof(*mmie)) + return RX_DROP_UNUSABLE; + + mmie = (struct ieee80211_mmie_16 *) + (skb->data + skb->len - sizeof(*mmie)); + if (mmie->element_id != WLAN_EID_MMIE || + mmie->length != sizeof(*mmie) - 2) + return RX_DROP_UNUSABLE; /* Invalid MMIE */ + + bip_ipn_swap(ipn, mmie->sequence_number); + + if (memcmp(ipn, key->u.aes_cmac.rx_pn, 6) <= 0) { + key->u.aes_cmac.replays++; + return RX_DROP_UNUSABLE; + } + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + /* hardware didn't decrypt/verify MIC */ + bip_aad(skb, aad); + ieee80211_aes_cmac_256(key->u.aes_cmac.tfm, aad, + skb->data + 24, skb->len - 24, mic); + if (crypto_memneq(mic, mmie->mic, sizeof(mmie->mic))) { + key->u.aes_cmac.icverrors++; + return RX_DROP_UNUSABLE; + } + } + + memcpy(key->u.aes_cmac.rx_pn, ipn, 6); + + /* Remove MMIE */ + skb_trim(skb, skb->len - sizeof(*mmie)); + + return RX_CONTINUE; +} + +ieee80211_tx_result +ieee80211_crypto_aes_gmac_encrypt(struct ieee80211_tx_data *tx) +{ + struct sk_buff *skb; + struct ieee80211_tx_info *info; + struct ieee80211_key *key = tx->key; + struct ieee80211_mmie_16 *mmie; + struct ieee80211_hdr *hdr; + u8 aad[GMAC_AAD_LEN]; + u64 pn64; + u8 nonce[GMAC_NONCE_LEN]; + + if (WARN_ON(skb_queue_len(&tx->skbs) != 1)) + return TX_DROP; + + skb = skb_peek(&tx->skbs); + + info = IEEE80211_SKB_CB(skb); + + if (info->control.hw_key) + return TX_CONTINUE; + + if (WARN_ON(skb_tailroom(skb) < sizeof(*mmie))) + return TX_DROP; + + mmie = skb_put(skb, sizeof(*mmie)); + mmie->element_id = WLAN_EID_MMIE; + mmie->length = sizeof(*mmie) - 2; + mmie->key_id = cpu_to_le16(key->conf.keyidx); + + /* PN = PN + 1 */ + pn64 = atomic64_inc_return(&key->conf.tx_pn); + + bip_ipn_set64(mmie->sequence_number, pn64); + + bip_aad(skb, aad); + + hdr = (struct ieee80211_hdr *)skb->data; + memcpy(nonce, hdr->addr2, ETH_ALEN); + bip_ipn_swap(nonce + ETH_ALEN, mmie->sequence_number); + + /* MIC = AES-GMAC(IGTK, AAD || Management Frame Body || MMIE, 128) */ + if (ieee80211_aes_gmac(key->u.aes_gmac.tfm, aad, nonce, + skb->data + 24, skb->len - 24, mmie->mic) < 0) + return TX_DROP; + + return TX_CONTINUE; +} + +ieee80211_rx_result +ieee80211_crypto_aes_gmac_decrypt(struct ieee80211_rx_data *rx) +{ + struct sk_buff *skb = rx->skb; + struct ieee80211_rx_status *status = IEEE80211_SKB_RXCB(skb); + struct ieee80211_key *key = rx->key; + struct ieee80211_mmie_16 *mmie; + u8 aad[GMAC_AAD_LEN], *mic, ipn[6], nonce[GMAC_NONCE_LEN]; + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + + if (!ieee80211_is_mgmt(hdr->frame_control)) + return RX_CONTINUE; + + /* management frames are already linear */ + + if (skb->len < 24 + sizeof(*mmie)) + return RX_DROP_UNUSABLE; + + mmie = (struct ieee80211_mmie_16 *) + (skb->data + skb->len - sizeof(*mmie)); + if (mmie->element_id != WLAN_EID_MMIE || + mmie->length != sizeof(*mmie) - 2) + return RX_DROP_UNUSABLE; /* Invalid MMIE */ + + bip_ipn_swap(ipn, 
mmie->sequence_number); + + if (memcmp(ipn, key->u.aes_gmac.rx_pn, 6) <= 0) { + key->u.aes_gmac.replays++; + return RX_DROP_UNUSABLE; + } + + if (!(status->flag & RX_FLAG_DECRYPTED)) { + /* hardware didn't decrypt/verify MIC */ + bip_aad(skb, aad); + + memcpy(nonce, hdr->addr2, ETH_ALEN); + memcpy(nonce + ETH_ALEN, ipn, 6); + + mic = kmalloc(GMAC_MIC_LEN, GFP_ATOMIC); + if (!mic) + return RX_DROP_UNUSABLE; + if (ieee80211_aes_gmac(key->u.aes_gmac.tfm, aad, nonce, + skb->data + 24, skb->len - 24, + mic) < 0 || + crypto_memneq(mic, mmie->mic, sizeof(mmie->mic))) { + key->u.aes_gmac.icverrors++; + kfree(mic); + return RX_DROP_UNUSABLE; + } + kfree(mic); + } + + memcpy(key->u.aes_gmac.rx_pn, ipn, 6); + + /* Remove MMIE */ + skb_trim(skb, skb->len - sizeof(*mmie)); + + return RX_CONTINUE; +} diff --git a/net/mac80211/wpa.h b/net/mac80211/wpa.h new file mode 100644 index 000000000..a9a81abb5 --- /dev/null +++ b/net/mac80211/wpa.h @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright 2002-2004, Instant802 Networks, Inc. + * Copyright (C) 2022 Intel Corporation + */ + +#ifndef WPA_H +#define WPA_H + +#include +#include +#include "ieee80211_i.h" + +ieee80211_tx_result +ieee80211_tx_h_michael_mic_add(struct ieee80211_tx_data *tx); +ieee80211_rx_result +ieee80211_rx_h_michael_mic_verify(struct ieee80211_rx_data *rx); + +ieee80211_tx_result +ieee80211_crypto_tkip_encrypt(struct ieee80211_tx_data *tx); +ieee80211_rx_result +ieee80211_crypto_tkip_decrypt(struct ieee80211_rx_data *rx); + +ieee80211_tx_result +ieee80211_crypto_ccmp_encrypt(struct ieee80211_tx_data *tx, + unsigned int mic_len); +ieee80211_rx_result +ieee80211_crypto_ccmp_decrypt(struct ieee80211_rx_data *rx, + unsigned int mic_len); + +ieee80211_tx_result +ieee80211_crypto_aes_cmac_encrypt(struct ieee80211_tx_data *tx); +ieee80211_tx_result +ieee80211_crypto_aes_cmac_256_encrypt(struct ieee80211_tx_data *tx); +ieee80211_rx_result +ieee80211_crypto_aes_cmac_decrypt(struct ieee80211_rx_data *rx); +ieee80211_rx_result +ieee80211_crypto_aes_cmac_256_decrypt(struct ieee80211_rx_data *rx); +ieee80211_tx_result +ieee80211_crypto_aes_gmac_encrypt(struct ieee80211_tx_data *tx); +ieee80211_rx_result +ieee80211_crypto_aes_gmac_decrypt(struct ieee80211_rx_data *rx); + +ieee80211_tx_result +ieee80211_crypto_gcmp_encrypt(struct ieee80211_tx_data *tx); +ieee80211_rx_result +ieee80211_crypto_gcmp_decrypt(struct ieee80211_rx_data *rx); + +#endif /* WPA_H */ diff --git a/net/mac802154/Kconfig b/net/mac802154/Kconfig new file mode 100644 index 000000000..901167b1e --- /dev/null +++ b/net/mac802154/Kconfig @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0-only +config MAC802154 + tristate "Generic IEEE 802.15.4 Soft Networking Stack (mac802154)" + depends on IEEE802154 + select CRC_CCITT + select CRYPTO + select CRYPTO_AUTHENC + select CRYPTO_CCM + select CRYPTO_CTR + select CRYPTO_AES + help + This option enables the hardware independent IEEE 802.15.4 + networking stack for SoftMAC devices (the ones implementing + only PHY level of IEEE 802.15.4 standard). + + Note: this implementation is neither certified, nor feature + complete! Compatibility with other implementations hasn't + been tested yet! + + If you plan to use HardMAC IEEE 802.15.4 devices, you can + say N here. Alternatively you can say M to compile it as + module. 
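Illustrative aside (not part of this patch): the mac80211 CCMP/GCMP paths above pack the 48-bit packet number into the 8-byte CCMP/GCMP header and rely on a lexicographic memcmp() over the MSB-first PN for replay detection. The standalone userspace C sketch below mirrors that byte layout and monotonicity check; the helper names pn_to_hdr()/hdr_to_pn() are hypothetical stand-ins for ccmp_pn2hdr()/ccmp_hdr2pn(), and the snippet is a simplified model, not kernel code.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define PN_LEN 6

/* pn[0] is the most significant byte (PN5), matching how the TX path
 * derives pn[] from the 64-bit tx_pn counter above. */
static void pn_to_hdr(uint8_t *hdr, const uint8_t *pn, int key_id)
{
	hdr[0] = pn[5];			/* PN0 */
	hdr[1] = pn[4];			/* PN1 */
	hdr[2] = 0;			/* reserved */
	hdr[3] = 0x20 | (key_id << 6);	/* ExtIV bit | key index */
	hdr[4] = pn[3];			/* PN2 */
	hdr[5] = pn[2];			/* PN3 */
	hdr[6] = pn[1];			/* PN4 */
	hdr[7] = pn[0];			/* PN5 */
}

static void hdr_to_pn(uint8_t *pn, const uint8_t *hdr)
{
	pn[0] = hdr[7];
	pn[1] = hdr[6];
	pn[2] = hdr[5];
	pn[3] = hdr[4];
	pn[4] = hdr[1];
	pn[5] = hdr[0];
}

int main(void)
{
	uint8_t last_pn[PN_LEN] = { 0, 0, 0, 0, 0, 5 };	/* last accepted PN */
	uint8_t tx_pn[PN_LEN]   = { 0, 0, 0, 0, 0, 6 };	/* next PN on TX */
	uint8_t hdr[8], rx_pn[PN_LEN];

	pn_to_hdr(hdr, tx_pn, 1);
	hdr_to_pn(rx_pn, hdr);

	/* Because pn[] is MSB-first, memcmp() gives a lexicographic compare:
	 * a result <= 0 means the received PN is not strictly larger than the
	 * stored one and the frame counts as a replay, essentially as in the
	 * CCMP/GCMP/BIP decrypt paths above (which additionally honour
	 * RX_FLAG_ALLOW_SAME_PN for the equal case). */
	if (memcmp(rx_pn, last_pn, PN_LEN) <= 0)
		printf("replay detected\n");
	else
		printf("PN ok, accept and store as new rx_pn\n");

	return 0;
}

Running the sketch prints "PN ok" for PN 6 against a stored PN of 5; swapping the two values triggers the replay branch, which is the same ordering property the kernel exploits by storing rx_pn in big-endian byte order.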
diff --git a/net/mac802154/Makefile b/net/mac802154/Makefile new file mode 100644 index 000000000..4059295fd --- /dev/null +++ b/net/mac802154/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_MAC802154) += mac802154.o +mac802154-objs := main.o rx.o tx.o mac_cmd.o mib.o \ + iface.o llsec.o util.o cfg.o trace.o + +CFLAGS_trace.o := -I$(src) diff --git a/net/mac802154/cfg.c b/net/mac802154/cfg.c new file mode 100644 index 000000000..1e4a9f74e --- /dev/null +++ b/net/mac802154/cfg.c @@ -0,0 +1,487 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * Alexander Aring + * + * Based on: net/mac80211/cfg.c + */ + +#include +#include + +#include "ieee802154_i.h" +#include "driver-ops.h" +#include "cfg.h" + +static struct net_device * +ieee802154_add_iface_deprecated(struct wpan_phy *wpan_phy, + const char *name, + unsigned char name_assign_type, int type) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + struct net_device *dev; + + rtnl_lock(); + dev = ieee802154_if_add(local, name, name_assign_type, type, + cpu_to_le64(0x0000000000000000ULL)); + rtnl_unlock(); + + return dev; +} + +static void ieee802154_del_iface_deprecated(struct wpan_phy *wpan_phy, + struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + ieee802154_if_remove(sdata); +} + +#ifdef CONFIG_PM +static int ieee802154_suspend(struct wpan_phy *wpan_phy) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + + if (!local->open_count) + goto suspend; + + ieee802154_stop_queue(&local->hw); + synchronize_net(); + + /* stop hardware - this must stop RX */ + ieee802154_stop_device(local); + +suspend: + local->suspended = true; + return 0; +} + +static int ieee802154_resume(struct wpan_phy *wpan_phy) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + int ret; + + /* nothing to do if HW shouldn't run */ + if (!local->open_count) + goto wake_up; + + /* restart hardware */ + ret = drv_start(local); + if (ret) + return ret; + +wake_up: + ieee802154_wake_queue(&local->hw); + local->suspended = false; + return 0; +} +#else +#define ieee802154_suspend NULL +#define ieee802154_resume NULL +#endif + +static int +ieee802154_add_iface(struct wpan_phy *phy, const char *name, + unsigned char name_assign_type, + enum nl802154_iftype type, __le64 extended_addr) +{ + struct ieee802154_local *local = wpan_phy_priv(phy); + struct net_device *err; + + err = ieee802154_if_add(local, name, name_assign_type, type, + extended_addr); + return PTR_ERR_OR_ZERO(err); +} + +static int +ieee802154_del_iface(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev) +{ + ieee802154_if_remove(IEEE802154_WPAN_DEV_TO_SUB_IF(wpan_dev)); + + return 0; +} + +static int +ieee802154_set_channel(struct wpan_phy *wpan_phy, u8 page, u8 channel) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + int ret; + + ASSERT_RTNL(); + + if (wpan_phy->current_page == page && + wpan_phy->current_channel == channel) + return 0; + + ret = drv_set_channel(local, page, channel); + if (!ret) { + wpan_phy->current_page = page; + wpan_phy->current_channel = channel; + ieee802154_configure_durations(wpan_phy); + } + + return ret; +} + +static int +ieee802154_set_cca_mode(struct wpan_phy *wpan_phy, + const struct wpan_phy_cca *cca) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + int ret; + + ASSERT_RTNL(); + + if (wpan_phy_cca_cmp(&wpan_phy->cca, cca)) + return 0; + + ret = drv_set_cca_mode(local, cca); + if (!ret) + wpan_phy->cca = *cca; + + 
return ret; +} + +static int +ieee802154_set_cca_ed_level(struct wpan_phy *wpan_phy, s32 ed_level) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + int ret; + + ASSERT_RTNL(); + + if (wpan_phy->cca_ed_level == ed_level) + return 0; + + ret = drv_set_cca_ed_level(local, ed_level); + if (!ret) + wpan_phy->cca_ed_level = ed_level; + + return ret; +} + +static int +ieee802154_set_tx_power(struct wpan_phy *wpan_phy, s32 power) +{ + struct ieee802154_local *local = wpan_phy_priv(wpan_phy); + int ret; + + ASSERT_RTNL(); + + if (wpan_phy->transmit_power == power) + return 0; + + ret = drv_set_tx_power(local, power); + if (!ret) + wpan_phy->transmit_power = power; + + return ret; +} + +static int +ieee802154_set_pan_id(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le16 pan_id) +{ + int ret; + + ASSERT_RTNL(); + + if (wpan_dev->pan_id == pan_id) + return 0; + + ret = mac802154_wpan_update_llsec(wpan_dev->netdev); + if (!ret) + wpan_dev->pan_id = pan_id; + + return ret; +} + +static int +ieee802154_set_backoff_exponent(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, + u8 min_be, u8 max_be) +{ + ASSERT_RTNL(); + + wpan_dev->min_be = min_be; + wpan_dev->max_be = max_be; + return 0; +} + +static int +ieee802154_set_short_addr(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le16 short_addr) +{ + ASSERT_RTNL(); + + wpan_dev->short_addr = short_addr; + return 0; +} + +static int +ieee802154_set_max_csma_backoffs(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, + u8 max_csma_backoffs) +{ + ASSERT_RTNL(); + + wpan_dev->csma_retries = max_csma_backoffs; + return 0; +} + +static int +ieee802154_set_max_frame_retries(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, + s8 max_frame_retries) +{ + ASSERT_RTNL(); + + wpan_dev->frame_retries = max_frame_retries; + return 0; +} + +static int +ieee802154_set_lbt_mode(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + bool mode) +{ + ASSERT_RTNL(); + + wpan_dev->lbt = mode; + return 0; +} + +static int +ieee802154_set_ackreq_default(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, bool ackreq) +{ + ASSERT_RTNL(); + + wpan_dev->ackreq = ackreq; + return 0; +} + +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL +static void +ieee802154_get_llsec_table(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, + struct ieee802154_llsec_table **table) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + *table = &sdata->sec.table; +} + +static void +ieee802154_lock_llsec_table(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + mutex_lock(&sdata->sec_mtx); +} + +static void +ieee802154_unlock_llsec_table(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + mutex_unlock(&sdata->sec_mtx); +} + +static int +ieee802154_set_llsec_params(struct wpan_phy *wpan_phy, + struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_params *params, + int changed) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_set_params(&sdata->sec, params, changed); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_get_llsec_params(struct wpan_phy 
*wpan_phy, + struct wpan_dev *wpan_dev, + struct ieee802154_llsec_params *params) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_get_params(&sdata->sec, params); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_add_llsec_key(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_key_add(&sdata->sec, id, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_del_llsec_key(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_key_id *id) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_key_del(&sdata->sec, id); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_add_seclevel(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_seclevel *sl) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_seclevel_add(&sdata->sec, sl); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_del_seclevel(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_seclevel *sl) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_seclevel_del(&sdata->sec, sl); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_add_device(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + const struct ieee802154_llsec_device *dev_desc) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_dev_add(&sdata->sec, dev_desc); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_del_device(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le64 extended_addr) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_dev_del(&sdata->sec, extended_addr); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_add_devkey(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le64 extended_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_devkey_add(&sdata->sec, extended_addr, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +static int +ieee802154_del_devkey(struct wpan_phy *wpan_phy, struct wpan_dev *wpan_dev, + __le64 extended_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct net_device *dev = wpan_dev->netdev; + struct ieee802154_sub_if_data *sdata = 
IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_devkey_del(&sdata->sec, extended_addr, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ + +const struct cfg802154_ops mac802154_config_ops = { + .add_virtual_intf_deprecated = ieee802154_add_iface_deprecated, + .del_virtual_intf_deprecated = ieee802154_del_iface_deprecated, + .suspend = ieee802154_suspend, + .resume = ieee802154_resume, + .add_virtual_intf = ieee802154_add_iface, + .del_virtual_intf = ieee802154_del_iface, + .set_channel = ieee802154_set_channel, + .set_cca_mode = ieee802154_set_cca_mode, + .set_cca_ed_level = ieee802154_set_cca_ed_level, + .set_tx_power = ieee802154_set_tx_power, + .set_pan_id = ieee802154_set_pan_id, + .set_short_addr = ieee802154_set_short_addr, + .set_backoff_exponent = ieee802154_set_backoff_exponent, + .set_max_csma_backoffs = ieee802154_set_max_csma_backoffs, + .set_max_frame_retries = ieee802154_set_max_frame_retries, + .set_lbt_mode = ieee802154_set_lbt_mode, + .set_ackreq_default = ieee802154_set_ackreq_default, +#ifdef CONFIG_IEEE802154_NL802154_EXPERIMENTAL + .get_llsec_table = ieee802154_get_llsec_table, + .lock_llsec_table = ieee802154_lock_llsec_table, + .unlock_llsec_table = ieee802154_unlock_llsec_table, + /* TODO above */ + .set_llsec_params = ieee802154_set_llsec_params, + .get_llsec_params = ieee802154_get_llsec_params, + .add_llsec_key = ieee802154_add_llsec_key, + .del_llsec_key = ieee802154_del_llsec_key, + .add_seclevel = ieee802154_add_seclevel, + .del_seclevel = ieee802154_del_seclevel, + .add_device = ieee802154_add_device, + .del_device = ieee802154_del_device, + .add_devkey = ieee802154_add_devkey, + .del_devkey = ieee802154_del_devkey, +#endif /* CONFIG_IEEE802154_NL802154_EXPERIMENTAL */ +}; diff --git a/net/mac802154/cfg.h b/net/mac802154/cfg.h new file mode 100644 index 000000000..3bb089685 --- /dev/null +++ b/net/mac802154/cfg.h @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* mac802154 configuration hooks for cfg802154 + */ + +#ifndef __CFG_H +#define __CFG_H + +extern const struct cfg802154_ops mac802154_config_ops; + +#endif /* __CFG_H */ diff --git a/net/mac802154/driver-ops.h b/net/mac802154/driver-ops.h new file mode 100644 index 000000000..d23f0db98 --- /dev/null +++ b/net/mac802154/driver-ops.h @@ -0,0 +1,285 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __MAC802154_DRIVER_OPS +#define __MAC802154_DRIVER_OPS + +#include +#include + +#include + +#include "ieee802154_i.h" +#include "trace.h" + +static inline int +drv_xmit_async(struct ieee802154_local *local, struct sk_buff *skb) +{ + return local->ops->xmit_async(&local->hw, skb); +} + +static inline int +drv_xmit_sync(struct ieee802154_local *local, struct sk_buff *skb) +{ + might_sleep(); + + return local->ops->xmit_sync(&local->hw, skb); +} + +static inline int drv_start(struct ieee802154_local *local) +{ + int ret; + + might_sleep(); + + trace_802154_drv_start(local); + local->started = true; + smp_mb(); + ret = local->ops->start(&local->hw); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline void drv_stop(struct ieee802154_local *local) +{ + might_sleep(); + + trace_802154_drv_stop(local); + local->ops->stop(&local->hw); + trace_802154_drv_return_void(local); + + /* sync away all work on the tasklet before clearing started */ + tasklet_disable(&local->tasklet); + tasklet_enable(&local->tasklet); + + barrier(); + + local->started = false; +} + +static inline 
int +drv_set_channel(struct ieee802154_local *local, u8 page, u8 channel) +{ + int ret; + + might_sleep(); + + trace_802154_drv_set_channel(local, page, channel); + ret = local->ops->set_channel(&local->hw, page, channel); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_tx_power(struct ieee802154_local *local, s32 mbm) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_txpower) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_tx_power(local, mbm); + ret = local->ops->set_txpower(&local->hw, mbm); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_cca_mode(struct ieee802154_local *local, + const struct wpan_phy_cca *cca) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_cca_mode) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_cca_mode(local, cca); + ret = local->ops->set_cca_mode(&local->hw, cca); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_lbt_mode(struct ieee802154_local *local, bool mode) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_lbt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_lbt_mode(local, mode); + ret = local->ops->set_lbt(&local->hw, mode); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_cca_ed_level(struct ieee802154_local *local, s32 mbm) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_cca_ed_level) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_cca_ed_level(local, mbm); + ret = local->ops->set_cca_ed_level(&local->hw, mbm); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int drv_set_pan_id(struct ieee802154_local *local, __le16 pan_id) +{ + struct ieee802154_hw_addr_filt filt; + int ret; + + might_sleep(); + + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.pan_id = pan_id; + + trace_802154_drv_set_pan_id(local, pan_id); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_PANID_CHANGED); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_extended_addr(struct ieee802154_local *local, __le64 extended_addr) +{ + struct ieee802154_hw_addr_filt filt; + int ret; + + might_sleep(); + + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.ieee_addr = extended_addr; + + trace_802154_drv_set_extended_addr(local, extended_addr); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_IEEEADDR_CHANGED); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_short_addr(struct ieee802154_local *local, __le16 short_addr) +{ + struct ieee802154_hw_addr_filt filt; + int ret; + + might_sleep(); + + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.short_addr = short_addr; + + trace_802154_drv_set_short_addr(local, short_addr); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + IEEE802154_AFILT_SADDR_CHANGED); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_pan_coord(struct ieee802154_local *local, bool is_coord) +{ + struct ieee802154_hw_addr_filt filt; + int ret; + + might_sleep(); + + if (!local->ops->set_hw_addr_filt) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + filt.pan_coord = is_coord; + + trace_802154_drv_set_pan_coord(local, is_coord); + ret = local->ops->set_hw_addr_filt(&local->hw, &filt, + 
IEEE802154_AFILT_PANC_CHANGED); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_csma_params(struct ieee802154_local *local, u8 min_be, u8 max_be, + u8 max_csma_backoffs) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_csma_params) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_csma_params(local, min_be, max_be, + max_csma_backoffs); + ret = local->ops->set_csma_params(&local->hw, min_be, max_be, + max_csma_backoffs); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_max_frame_retries(struct ieee802154_local *local, s8 max_frame_retries) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_frame_retries) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_max_frame_retries(local, max_frame_retries); + ret = local->ops->set_frame_retries(&local->hw, max_frame_retries); + trace_802154_drv_return_int(local, ret); + return ret; +} + +static inline int +drv_set_promiscuous_mode(struct ieee802154_local *local, bool on) +{ + int ret; + + might_sleep(); + + if (!local->ops->set_promiscuous_mode) { + WARN_ON(1); + return -EOPNOTSUPP; + } + + trace_802154_drv_set_promiscuous_mode(local, on); + ret = local->ops->set_promiscuous_mode(&local->hw, on); + trace_802154_drv_return_int(local, ret); + return ret; +} + +#endif /* __MAC802154_DRIVER_OPS */ diff --git a/net/mac802154/ieee802154_i.h b/net/mac802154/ieee802154_i.h new file mode 100644 index 000000000..1381e6a5e --- /dev/null +++ b/net/mac802154/ieee802154_i.h @@ -0,0 +1,182 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (C) 2007-2012 Siemens AG + * + * Written by: + * Pavel Smolenskiy + * Maxim Gorbachyov + * Dmitry Eremin-Solenikov + * Alexander Smirnov + */ +#ifndef __IEEE802154_I_H +#define __IEEE802154_I_H + +#include +#include +#include +#include +#include +#include +#include + +#include "llsec.h" + +/* mac802154 device private data */ +struct ieee802154_local { + struct ieee802154_hw hw; + const struct ieee802154_ops *ops; + + /* ieee802154 phy */ + struct wpan_phy *phy; + + int open_count; + + /* As in mac80211 slaves list is modified: + * 1) under the RTNL + * 2) protected by slaves_mtx; + * 3) in an RCU manner + * + * So atomic readers can use any of this protection methods. + */ + struct list_head interfaces; + struct mutex iflist_mtx; + + /* This one is used for scanning and other jobs not to be interfered + * with serial driver. + */ + struct workqueue_struct *workqueue; + + struct hrtimer ifs_timer; + + bool started; + bool suspended; + + struct tasklet_struct tasklet; + struct sk_buff_head skb_queue; + + struct sk_buff *tx_skb; + struct work_struct tx_work; + /* A negative Linux error code or a null/positive MLME error status */ + int tx_result; +}; + +enum { + IEEE802154_RX_MSG = 1, +}; + +enum ieee802154_sdata_state_bits { + SDATA_STATE_RUNNING, +}; + +/* Slave interface definition. + * + * Slaves represent typical network interfaces available from userspace. + * Each ieee802154 device/transceiver may have several slaves and able + * to be associated with several networks at the same time. + */ +struct ieee802154_sub_if_data { + struct list_head list; /* the ieee802154_priv->slaves list */ + + struct wpan_dev wpan_dev; + + struct ieee802154_local *local; + struct net_device *dev; + + unsigned long state; + char name[IFNAMSIZ]; + + /* protects sec from concurrent access by netlink. access by + * encrypt/decrypt/header_create safe without additional protection. 
+ */ + struct mutex sec_mtx; + + struct mac802154_llsec sec; +}; + +/* utility functions/constants */ +extern const void *const mac802154_wpan_phy_privid; /* for wpan_phy privid */ + +static inline struct ieee802154_local * +hw_to_local(struct ieee802154_hw *hw) +{ + return container_of(hw, struct ieee802154_local, hw); +} + +static inline struct ieee802154_sub_if_data * +IEEE802154_DEV_TO_SUB_IF(const struct net_device *dev) +{ + return netdev_priv(dev); +} + +static inline struct ieee802154_sub_if_data * +IEEE802154_WPAN_DEV_TO_SUB_IF(struct wpan_dev *wpan_dev) +{ + return container_of(wpan_dev, struct ieee802154_sub_if_data, wpan_dev); +} + +static inline bool +ieee802154_sdata_running(struct ieee802154_sub_if_data *sdata) +{ + return test_bit(SDATA_STATE_RUNNING, &sdata->state); +} + +extern struct ieee802154_mlme_ops mac802154_mlme_wpan; + +void ieee802154_rx(struct ieee802154_local *local, struct sk_buff *skb); +void ieee802154_xmit_worker(struct work_struct *work); +netdev_tx_t +ieee802154_monitor_start_xmit(struct sk_buff *skb, struct net_device *dev); +netdev_tx_t +ieee802154_subif_start_xmit(struct sk_buff *skb, struct net_device *dev); +enum hrtimer_restart ieee802154_xmit_ifs_timer(struct hrtimer *timer); + +/* MIB callbacks */ +void mac802154_dev_set_page_channel(struct net_device *dev, u8 page, u8 chan); + +int mac802154_get_params(struct net_device *dev, + struct ieee802154_llsec_params *params); +int mac802154_set_params(struct net_device *dev, + const struct ieee802154_llsec_params *params, + int changed); + +int mac802154_add_key(struct net_device *dev, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key); +int mac802154_del_key(struct net_device *dev, + const struct ieee802154_llsec_key_id *id); + +int mac802154_add_dev(struct net_device *dev, + const struct ieee802154_llsec_device *llsec_dev); +int mac802154_del_dev(struct net_device *dev, __le64 dev_addr); + +int mac802154_add_devkey(struct net_device *dev, + __le64 device_addr, + const struct ieee802154_llsec_device_key *key); +int mac802154_del_devkey(struct net_device *dev, + __le64 device_addr, + const struct ieee802154_llsec_device_key *key); + +int mac802154_add_seclevel(struct net_device *dev, + const struct ieee802154_llsec_seclevel *sl); +int mac802154_del_seclevel(struct net_device *dev, + const struct ieee802154_llsec_seclevel *sl); + +void mac802154_lock_table(struct net_device *dev); +void mac802154_get_table(struct net_device *dev, + struct ieee802154_llsec_table **t); +void mac802154_unlock_table(struct net_device *dev); + +int mac802154_wpan_update_llsec(struct net_device *dev); + +/* interface handling */ +int ieee802154_iface_init(void); +void ieee802154_iface_exit(void); +void ieee802154_if_remove(struct ieee802154_sub_if_data *sdata); +struct net_device * +ieee802154_if_add(struct ieee802154_local *local, const char *name, + unsigned char name_assign_type, enum nl802154_iftype type, + __le64 extended_addr); +void ieee802154_remove_interfaces(struct ieee802154_local *local); +void ieee802154_stop_device(struct ieee802154_local *local); + +#endif /* __IEEE802154_I_H */ diff --git a/net/mac802154/iface.c b/net/mac802154/iface.c new file mode 100644 index 000000000..7e2065e72 --- /dev/null +++ b/net/mac802154/iface.c @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2007-2012 Siemens AG + * + * Written by: + * Dmitry Eremin-Solenikov + * Sergey Lapin + * Maxim Gorbachyov + * Alexander Smirnov + */ + +#include +#include +#include +#include + 
+#include +#include +#include +#include + +#include "ieee802154_i.h" +#include "driver-ops.h" + +int mac802154_wpan_update_llsec(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct ieee802154_mlme_ops *ops = ieee802154_mlme_ops(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + int rc = 0; + + if (ops->llsec) { + struct ieee802154_llsec_params params; + int changed = 0; + + params.pan_id = wpan_dev->pan_id; + changed |= IEEE802154_LLSEC_PARAM_PAN_ID; + + params.hwaddr = wpan_dev->extended_addr; + changed |= IEEE802154_LLSEC_PARAM_HWADDR; + + rc = ops->llsec->set_params(dev, ¶ms, changed); + } + + return rc; +} + +static int +mac802154_wpan_ioctl(struct net_device *dev, struct ifreq *ifr, int cmd) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + struct sockaddr_ieee802154 *sa = + (struct sockaddr_ieee802154 *)&ifr->ifr_addr; + int err = -ENOIOCTLCMD; + + if (cmd != SIOCGIFADDR && cmd != SIOCSIFADDR) + return err; + + rtnl_lock(); + + switch (cmd) { + case SIOCGIFADDR: + { + u16 pan_id, short_addr; + + pan_id = le16_to_cpu(wpan_dev->pan_id); + short_addr = le16_to_cpu(wpan_dev->short_addr); + if (pan_id == IEEE802154_PANID_BROADCAST || + short_addr == IEEE802154_ADDR_BROADCAST) { + err = -EADDRNOTAVAIL; + break; + } + + sa->family = AF_IEEE802154; + sa->addr.addr_type = IEEE802154_ADDR_SHORT; + sa->addr.pan_id = pan_id; + sa->addr.short_addr = short_addr; + + err = 0; + break; + } + case SIOCSIFADDR: + if (netif_running(dev)) { + rtnl_unlock(); + return -EBUSY; + } + + dev_warn(&dev->dev, + "Using DEBUGing ioctl SIOCSIFADDR isn't recommended!\n"); + if (sa->family != AF_IEEE802154 || + sa->addr.addr_type != IEEE802154_ADDR_SHORT || + sa->addr.pan_id == IEEE802154_PANID_BROADCAST || + sa->addr.short_addr == IEEE802154_ADDR_BROADCAST || + sa->addr.short_addr == IEEE802154_ADDR_UNDEF) { + err = -EINVAL; + break; + } + + wpan_dev->pan_id = cpu_to_le16(sa->addr.pan_id); + wpan_dev->short_addr = cpu_to_le16(sa->addr.short_addr); + + err = mac802154_wpan_update_llsec(dev); + break; + } + + rtnl_unlock(); + return err; +} + +static int mac802154_wpan_mac_addr(struct net_device *dev, void *p) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct sockaddr *addr = p; + __le64 extended_addr; + + if (netif_running(dev)) + return -EBUSY; + + /* lowpan need to be down for update + * SLAAC address after ifup + */ + if (sdata->wpan_dev.lowpan_dev) { + if (netif_running(sdata->wpan_dev.lowpan_dev)) + return -EBUSY; + } + + ieee802154_be64_to_le64(&extended_addr, addr->sa_data); + if (!ieee802154_is_valid_extended_unicast_addr(extended_addr)) + return -EINVAL; + + dev_addr_set(dev, addr->sa_data); + sdata->wpan_dev.extended_addr = extended_addr; + + /* update lowpan interface mac address when + * wpan mac has been changed + */ + if (sdata->wpan_dev.lowpan_dev) + dev_addr_set(sdata->wpan_dev.lowpan_dev, dev->dev_addr); + + return mac802154_wpan_update_llsec(dev); +} + +static int ieee802154_setup_hw(struct ieee802154_sub_if_data *sdata) +{ + struct ieee802154_local *local = sdata->local; + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + int ret; + + if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { + ret = drv_set_promiscuous_mode(local, + wpan_dev->promiscuous_mode); + if (ret < 0) + return ret; + } + + if (local->hw.flags & IEEE802154_HW_AFILT) { + ret = drv_set_pan_id(local, wpan_dev->pan_id); + if (ret < 0) + return ret; + + ret = 
drv_set_extended_addr(local, wpan_dev->extended_addr); + if (ret < 0) + return ret; + + ret = drv_set_short_addr(local, wpan_dev->short_addr); + if (ret < 0) + return ret; + } + + if (local->hw.flags & IEEE802154_HW_LBT) { + ret = drv_set_lbt_mode(local, wpan_dev->lbt); + if (ret < 0) + return ret; + } + + if (local->hw.flags & IEEE802154_HW_CSMA_PARAMS) { + ret = drv_set_csma_params(local, wpan_dev->min_be, + wpan_dev->max_be, + wpan_dev->csma_retries); + if (ret < 0) + return ret; + } + + if (local->hw.flags & IEEE802154_HW_FRAME_RETRIES) { + ret = drv_set_max_frame_retries(local, wpan_dev->frame_retries); + if (ret < 0) + return ret; + } + + return 0; +} + +static int mac802154_slave_open(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct ieee802154_local *local = sdata->local; + int res; + + ASSERT_RTNL(); + + set_bit(SDATA_STATE_RUNNING, &sdata->state); + + if (!local->open_count) { + res = ieee802154_setup_hw(sdata); + if (res) + goto err; + + res = drv_start(local); + if (res) + goto err; + } + + local->open_count++; + netif_start_queue(dev); + return 0; +err: + /* might already be clear but that doesn't matter */ + clear_bit(SDATA_STATE_RUNNING, &sdata->state); + + return res; +} + +static int +ieee802154_check_mac_settings(struct ieee802154_local *local, + struct wpan_dev *wpan_dev, + struct wpan_dev *nwpan_dev) +{ + ASSERT_RTNL(); + + if (local->hw.flags & IEEE802154_HW_PROMISCUOUS) { + if (wpan_dev->promiscuous_mode != nwpan_dev->promiscuous_mode) + return -EBUSY; + } + + if (local->hw.flags & IEEE802154_HW_AFILT) { + if (wpan_dev->pan_id != nwpan_dev->pan_id || + wpan_dev->short_addr != nwpan_dev->short_addr || + wpan_dev->extended_addr != nwpan_dev->extended_addr) + return -EBUSY; + } + + if (local->hw.flags & IEEE802154_HW_CSMA_PARAMS) { + if (wpan_dev->min_be != nwpan_dev->min_be || + wpan_dev->max_be != nwpan_dev->max_be || + wpan_dev->csma_retries != nwpan_dev->csma_retries) + return -EBUSY; + } + + if (local->hw.flags & IEEE802154_HW_FRAME_RETRIES) { + if (wpan_dev->frame_retries != nwpan_dev->frame_retries) + return -EBUSY; + } + + if (local->hw.flags & IEEE802154_HW_LBT) { + if (wpan_dev->lbt != nwpan_dev->lbt) + return -EBUSY; + } + + return 0; +} + +static int +ieee802154_check_concurrent_iface(struct ieee802154_sub_if_data *sdata, + enum nl802154_iftype iftype) +{ + struct ieee802154_local *local = sdata->local; + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + struct ieee802154_sub_if_data *nsdata; + + /* we hold the RTNL here so can safely walk the list */ + list_for_each_entry(nsdata, &local->interfaces, list) { + if (nsdata != sdata && ieee802154_sdata_running(nsdata)) { + int ret; + + /* TODO currently we don't support multiple node types + * we need to run skb_clone at rx path. Check if there + * exist really an use case if we need to support + * multiple node types at the same time. + */ + if (wpan_dev->iftype == NL802154_IFTYPE_NODE && + nsdata->wpan_dev.iftype == NL802154_IFTYPE_NODE) + return -EBUSY; + + /* check all phy mac sublayer settings are the same. + * We have only one phy, different values makes trouble. 
+ */ + ret = ieee802154_check_mac_settings(local, wpan_dev, + &nsdata->wpan_dev); + if (ret < 0) + return ret; + } + } + + return 0; +} + +static int mac802154_wpan_open(struct net_device *dev) +{ + int rc; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + + rc = ieee802154_check_concurrent_iface(sdata, wpan_dev->iftype); + if (rc < 0) + return rc; + + return mac802154_slave_open(dev); +} + +static int mac802154_slave_close(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct ieee802154_local *local = sdata->local; + + ASSERT_RTNL(); + + netif_stop_queue(dev); + local->open_count--; + + clear_bit(SDATA_STATE_RUNNING, &sdata->state); + + if (!local->open_count) + ieee802154_stop_device(local); + + return 0; +} + +static int mac802154_set_header_security(struct ieee802154_sub_if_data *sdata, + struct ieee802154_hdr *hdr, + const struct ieee802154_mac_cb *cb) +{ + struct ieee802154_llsec_params params; + u8 level; + + mac802154_llsec_get_params(&sdata->sec, ¶ms); + + if (!params.enabled && cb->secen_override && cb->secen) + return -EINVAL; + if (!params.enabled || + (cb->secen_override && !cb->secen) || + !params.out_level) + return 0; + if (cb->seclevel_override && !cb->seclevel) + return -EINVAL; + + level = cb->seclevel_override ? cb->seclevel : params.out_level; + + hdr->fc.security_enabled = 1; + hdr->sec.level = level; + hdr->sec.key_id_mode = params.out_key.mode; + if (params.out_key.mode == IEEE802154_SCF_KEY_SHORT_INDEX) + hdr->sec.short_src = params.out_key.short_source; + else if (params.out_key.mode == IEEE802154_SCF_KEY_HW_INDEX) + hdr->sec.extended_src = params.out_key.extended_source; + hdr->sec.key_id = params.out_key.id; + + return 0; +} + +static int ieee802154_header_create(struct sk_buff *skb, + struct net_device *dev, + const struct ieee802154_addr *daddr, + const struct ieee802154_addr *saddr, + unsigned len) +{ + struct ieee802154_hdr hdr; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + struct ieee802154_mac_cb *cb = mac_cb(skb); + int hlen; + + if (!daddr) + return -EINVAL; + + memset(&hdr.fc, 0, sizeof(hdr.fc)); + hdr.fc.type = cb->type; + hdr.fc.security_enabled = cb->secen; + hdr.fc.ack_request = cb->ackreq; + hdr.seq = atomic_inc_return(&dev->ieee802154_ptr->dsn) & 0xFF; + + if (mac802154_set_header_security(sdata, &hdr, cb) < 0) + return -EINVAL; + + if (!saddr) { + if (wpan_dev->short_addr == cpu_to_le16(IEEE802154_ADDR_BROADCAST) || + wpan_dev->short_addr == cpu_to_le16(IEEE802154_ADDR_UNDEF) || + wpan_dev->pan_id == cpu_to_le16(IEEE802154_PANID_BROADCAST)) { + hdr.source.mode = IEEE802154_ADDR_LONG; + hdr.source.extended_addr = wpan_dev->extended_addr; + } else { + hdr.source.mode = IEEE802154_ADDR_SHORT; + hdr.source.short_addr = wpan_dev->short_addr; + } + + hdr.source.pan_id = wpan_dev->pan_id; + } else { + hdr.source = *(const struct ieee802154_addr *)saddr; + } + + hdr.dest = *(const struct ieee802154_addr *)daddr; + + hlen = ieee802154_hdr_push(skb, &hdr); + if (hlen < 0) + return -EINVAL; + + skb_reset_mac_header(skb); + skb->mac_len = hlen; + + if (len > ieee802154_max_payload(&hdr)) + return -EMSGSIZE; + + return hlen; +} + +static const struct wpan_dev_header_ops ieee802154_header_ops = { + .create = ieee802154_header_create, +}; + +/* This header create functionality assumes a 8 byte array for + * source and destination pointer at maximum. 
To adapt this for + * the 802.15.4 dataframe header we use extended address handling + * here only and intra pan connection. fc fields are mostly fallback + * handling. For provide dev_hard_header for dgram sockets. + */ +static int mac802154_header_create(struct sk_buff *skb, + struct net_device *dev, + unsigned short type, + const void *daddr, + const void *saddr, + unsigned len) +{ + struct ieee802154_hdr hdr; + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + struct ieee802154_mac_cb cb = { }; + int hlen; + + if (!daddr) + return -EINVAL; + + memset(&hdr.fc, 0, sizeof(hdr.fc)); + hdr.fc.type = IEEE802154_FC_TYPE_DATA; + hdr.fc.ack_request = wpan_dev->ackreq; + hdr.seq = atomic_inc_return(&dev->ieee802154_ptr->dsn) & 0xFF; + + /* TODO currently a workaround to give zero cb block to set + * security parameters defaults according MIB. + */ + if (mac802154_set_header_security(sdata, &hdr, &cb) < 0) + return -EINVAL; + + hdr.dest.pan_id = wpan_dev->pan_id; + hdr.dest.mode = IEEE802154_ADDR_LONG; + ieee802154_be64_to_le64(&hdr.dest.extended_addr, daddr); + + hdr.source.pan_id = hdr.dest.pan_id; + hdr.source.mode = IEEE802154_ADDR_LONG; + + if (!saddr) + hdr.source.extended_addr = wpan_dev->extended_addr; + else + ieee802154_be64_to_le64(&hdr.source.extended_addr, saddr); + + hlen = ieee802154_hdr_push(skb, &hdr); + if (hlen < 0) + return -EINVAL; + + skb_reset_mac_header(skb); + skb->mac_len = hlen; + + if (len > ieee802154_max_payload(&hdr)) + return -EMSGSIZE; + + return hlen; +} + +static int +mac802154_header_parse(const struct sk_buff *skb, unsigned char *haddr) +{ + struct ieee802154_hdr hdr; + + if (ieee802154_hdr_peek_addrs(skb, &hdr) < 0) { + pr_debug("malformed packet\n"); + return 0; + } + + if (hdr.source.mode == IEEE802154_ADDR_LONG) { + ieee802154_le64_to_be64(haddr, &hdr.source.extended_addr); + return IEEE802154_EXTENDED_ADDR_LEN; + } + + return 0; +} + +static const struct header_ops mac802154_header_ops = { + .create = mac802154_header_create, + .parse = mac802154_header_parse, +}; + +static const struct net_device_ops mac802154_wpan_ops = { + .ndo_open = mac802154_wpan_open, + .ndo_stop = mac802154_slave_close, + .ndo_start_xmit = ieee802154_subif_start_xmit, + .ndo_do_ioctl = mac802154_wpan_ioctl, + .ndo_set_mac_address = mac802154_wpan_mac_addr, +}; + +static const struct net_device_ops mac802154_monitor_ops = { + .ndo_open = mac802154_wpan_open, + .ndo_stop = mac802154_slave_close, + .ndo_start_xmit = ieee802154_monitor_start_xmit, +}; + +static void mac802154_wpan_free(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + mac802154_llsec_destroy(&sdata->sec); +} + +static void ieee802154_if_setup(struct net_device *dev) +{ + dev->addr_len = IEEE802154_EXTENDED_ADDR_LEN; + memset(dev->broadcast, 0xff, IEEE802154_EXTENDED_ADDR_LEN); + + /* Let hard_header_len set to IEEE802154_MIN_HEADER_LEN. AF_PACKET + * will not send frames without any payload, but ack frames + * has no payload, so substract one that we can send a 3 bytes + * frame. The xmit callback assumes at least a hard header where two + * bytes fc and sequence field are set. + */ + dev->hard_header_len = IEEE802154_MIN_HEADER_LEN - 1; + /* The auth_tag header is for security and places in private payload + * room of mac frame which stucks between payload and FCS field. 
+ */ + dev->needed_tailroom = IEEE802154_MAX_AUTH_TAG_LEN + + IEEE802154_FCS_LEN; + /* The mtu size is the payload without mac header in this case. + * We have a dynamic length header with a minimum header length + * which is hard_header_len. In this case we let mtu to the size + * of maximum payload which is IEEE802154_MTU - IEEE802154_FCS_LEN - + * hard_header_len. The FCS which is set by hardware or ndo_start_xmit + * and the minimum mac header which can be evaluated inside driver + * layer. The rest of mac header will be part of payload if greater + * than hard_header_len. + */ + dev->mtu = IEEE802154_MTU - IEEE802154_FCS_LEN - + dev->hard_header_len; + dev->tx_queue_len = 300; + dev->flags = IFF_NOARP | IFF_BROADCAST; +} + +static int +ieee802154_setup_sdata(struct ieee802154_sub_if_data *sdata, + enum nl802154_iftype type) +{ + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + int ret; + u8 tmp; + + /* set some type-dependent values */ + sdata->wpan_dev.iftype = type; + + get_random_bytes(&tmp, sizeof(tmp)); + atomic_set(&wpan_dev->bsn, tmp); + get_random_bytes(&tmp, sizeof(tmp)); + atomic_set(&wpan_dev->dsn, tmp); + + /* defaults per 802.15.4-2011 */ + wpan_dev->min_be = 3; + wpan_dev->max_be = 5; + wpan_dev->csma_retries = 4; + wpan_dev->frame_retries = 3; + + wpan_dev->pan_id = cpu_to_le16(IEEE802154_PANID_BROADCAST); + wpan_dev->short_addr = cpu_to_le16(IEEE802154_ADDR_BROADCAST); + + switch (type) { + case NL802154_IFTYPE_NODE: + ieee802154_be64_to_le64(&wpan_dev->extended_addr, + sdata->dev->dev_addr); + + sdata->dev->header_ops = &mac802154_header_ops; + sdata->dev->needs_free_netdev = true; + sdata->dev->priv_destructor = mac802154_wpan_free; + sdata->dev->netdev_ops = &mac802154_wpan_ops; + sdata->dev->ml_priv = &mac802154_mlme_wpan; + wpan_dev->promiscuous_mode = false; + wpan_dev->header_ops = &ieee802154_header_ops; + + mutex_init(&sdata->sec_mtx); + + mac802154_llsec_init(&sdata->sec); + ret = mac802154_wpan_update_llsec(sdata->dev); + if (ret < 0) + return ret; + + break; + case NL802154_IFTYPE_MONITOR: + sdata->dev->needs_free_netdev = true; + sdata->dev->netdev_ops = &mac802154_monitor_ops; + wpan_dev->promiscuous_mode = true; + break; + default: + BUG(); + } + + return 0; +} + +struct net_device * +ieee802154_if_add(struct ieee802154_local *local, const char *name, + unsigned char name_assign_type, enum nl802154_iftype type, + __le64 extended_addr) +{ + u8 addr[IEEE802154_EXTENDED_ADDR_LEN]; + struct net_device *ndev = NULL; + struct ieee802154_sub_if_data *sdata = NULL; + int ret; + + ASSERT_RTNL(); + + ndev = alloc_netdev(sizeof(*sdata), name, + name_assign_type, ieee802154_if_setup); + if (!ndev) + return ERR_PTR(-ENOMEM); + + ndev->needed_headroom = local->hw.extra_tx_headroom + + IEEE802154_MAX_HEADER_LEN; + + ret = dev_alloc_name(ndev, ndev->name); + if (ret < 0) + goto err; + + ieee802154_le64_to_be64(ndev->perm_addr, + &local->hw.phy->perm_extended_addr); + switch (type) { + case NL802154_IFTYPE_NODE: + ndev->type = ARPHRD_IEEE802154; + if (ieee802154_is_valid_extended_unicast_addr(extended_addr)) { + ieee802154_le64_to_be64(addr, &extended_addr); + dev_addr_set(ndev, addr); + } else { + dev_addr_set(ndev, ndev->perm_addr); + } + break; + case NL802154_IFTYPE_MONITOR: + ndev->type = ARPHRD_IEEE802154_MONITOR; + break; + default: + ret = -EINVAL; + goto err; + } + + /* TODO check this */ + SET_NETDEV_DEV(ndev, &local->phy->dev); + dev_net_set(ndev, wpan_phy_net(local->hw.phy)); + sdata = netdev_priv(ndev); + ndev->ieee802154_ptr = &sdata->wpan_dev; + 
memcpy(sdata->name, ndev->name, IFNAMSIZ); + sdata->dev = ndev; + sdata->wpan_dev.wpan_phy = local->hw.phy; + sdata->local = local; + INIT_LIST_HEAD(&sdata->wpan_dev.list); + + /* setup type-dependent data */ + ret = ieee802154_setup_sdata(sdata, type); + if (ret) + goto err; + + ret = register_netdevice(ndev); + if (ret < 0) + goto err; + + mutex_lock(&local->iflist_mtx); + list_add_tail_rcu(&sdata->list, &local->interfaces); + mutex_unlock(&local->iflist_mtx); + + return ndev; + +err: + free_netdev(ndev); + return ERR_PTR(ret); +} + +void ieee802154_if_remove(struct ieee802154_sub_if_data *sdata) +{ + ASSERT_RTNL(); + + mutex_lock(&sdata->local->iflist_mtx); + list_del_rcu(&sdata->list); + mutex_unlock(&sdata->local->iflist_mtx); + + synchronize_rcu(); + unregister_netdevice(sdata->dev); +} + +void ieee802154_remove_interfaces(struct ieee802154_local *local) +{ + struct ieee802154_sub_if_data *sdata, *tmp; + + mutex_lock(&local->iflist_mtx); + list_for_each_entry_safe(sdata, tmp, &local->interfaces, list) { + list_del(&sdata->list); + + unregister_netdevice(sdata->dev); + } + mutex_unlock(&local->iflist_mtx); +} + +static int netdev_notify(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct ieee802154_sub_if_data *sdata; + + if (state != NETDEV_CHANGENAME) + return NOTIFY_DONE; + + if (!dev->ieee802154_ptr || !dev->ieee802154_ptr->wpan_phy) + return NOTIFY_DONE; + + if (dev->ieee802154_ptr->wpan_phy->privid != mac802154_wpan_phy_privid) + return NOTIFY_DONE; + + sdata = IEEE802154_DEV_TO_SUB_IF(dev); + memcpy(sdata->name, dev->name, IFNAMSIZ); + + return NOTIFY_OK; +} + +static struct notifier_block mac802154_netdev_notifier = { + .notifier_call = netdev_notify, +}; + +int ieee802154_iface_init(void) +{ + return register_netdevice_notifier(&mac802154_netdev_notifier); +} + +void ieee802154_iface_exit(void) +{ + unregister_netdevice_notifier(&mac802154_netdev_notifier); +} diff --git a/net/mac802154/llsec.c b/net/mac802154/llsec.c new file mode 100644 index 000000000..55550ead2 --- /dev/null +++ b/net/mac802154/llsec.c @@ -0,0 +1,1050 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2014 Fraunhofer ITWM + * + * Written by: + * Phoebe Buckheister + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "ieee802154_i.h" +#include "llsec.h" + +static void llsec_key_put(struct mac802154_llsec_key *key); +static bool llsec_key_id_equal(const struct ieee802154_llsec_key_id *a, + const struct ieee802154_llsec_key_id *b); + +static void llsec_dev_free(struct mac802154_llsec_device *dev); + +void mac802154_llsec_init(struct mac802154_llsec *sec) +{ + memset(sec, 0, sizeof(*sec)); + + memset(&sec->params.default_key_source, 0xFF, IEEE802154_ADDR_LEN); + + INIT_LIST_HEAD(&sec->table.security_levels); + INIT_LIST_HEAD(&sec->table.devices); + INIT_LIST_HEAD(&sec->table.keys); + hash_init(sec->devices_short); + hash_init(sec->devices_hw); + rwlock_init(&sec->lock); +} + +void mac802154_llsec_destroy(struct mac802154_llsec *sec) +{ + struct ieee802154_llsec_seclevel *sl, *sn; + struct ieee802154_llsec_device *dev, *dn; + struct ieee802154_llsec_key_entry *key, *kn; + + list_for_each_entry_safe(sl, sn, &sec->table.security_levels, list) { + struct mac802154_llsec_seclevel *msl; + + msl = container_of(sl, struct mac802154_llsec_seclevel, level); + list_del(&sl->list); + kfree_sensitive(msl); + } + + list_for_each_entry_safe(dev, dn, &sec->table.devices, list) { + struct 
mac802154_llsec_device *mdev; + + mdev = container_of(dev, struct mac802154_llsec_device, dev); + list_del(&dev->list); + llsec_dev_free(mdev); + } + + list_for_each_entry_safe(key, kn, &sec->table.keys, list) { + struct mac802154_llsec_key *mkey; + + mkey = container_of(key->key, struct mac802154_llsec_key, key); + list_del(&key->list); + llsec_key_put(mkey); + kfree_sensitive(key); + } +} + +int mac802154_llsec_get_params(struct mac802154_llsec *sec, + struct ieee802154_llsec_params *params) +{ + read_lock_bh(&sec->lock); + *params = sec->params; + read_unlock_bh(&sec->lock); + + return 0; +} + +int mac802154_llsec_set_params(struct mac802154_llsec *sec, + const struct ieee802154_llsec_params *params, + int changed) +{ + write_lock_bh(&sec->lock); + + if (changed & IEEE802154_LLSEC_PARAM_ENABLED) + sec->params.enabled = params->enabled; + if (changed & IEEE802154_LLSEC_PARAM_FRAME_COUNTER) + sec->params.frame_counter = params->frame_counter; + if (changed & IEEE802154_LLSEC_PARAM_OUT_LEVEL) + sec->params.out_level = params->out_level; + if (changed & IEEE802154_LLSEC_PARAM_OUT_KEY) + sec->params.out_key = params->out_key; + if (changed & IEEE802154_LLSEC_PARAM_KEY_SOURCE) + sec->params.default_key_source = params->default_key_source; + if (changed & IEEE802154_LLSEC_PARAM_PAN_ID) + sec->params.pan_id = params->pan_id; + if (changed & IEEE802154_LLSEC_PARAM_HWADDR) + sec->params.hwaddr = params->hwaddr; + if (changed & IEEE802154_LLSEC_PARAM_COORD_HWADDR) + sec->params.coord_hwaddr = params->coord_hwaddr; + if (changed & IEEE802154_LLSEC_PARAM_COORD_SHORTADDR) + sec->params.coord_shortaddr = params->coord_shortaddr; + + write_unlock_bh(&sec->lock); + + return 0; +} + +static struct mac802154_llsec_key* +llsec_key_alloc(const struct ieee802154_llsec_key *template) +{ + const int authsizes[3] = { 4, 8, 16 }; + struct mac802154_llsec_key *key; + int i; + + key = kzalloc(sizeof(*key), GFP_KERNEL); + if (!key) + return NULL; + + kref_init(&key->ref); + key->key = *template; + + BUILD_BUG_ON(ARRAY_SIZE(authsizes) != ARRAY_SIZE(key->tfm)); + + for (i = 0; i < ARRAY_SIZE(key->tfm); i++) { + key->tfm[i] = crypto_alloc_aead("ccm(aes)", 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(key->tfm[i])) + goto err_tfm; + if (crypto_aead_setkey(key->tfm[i], template->key, + IEEE802154_LLSEC_KEY_SIZE)) + goto err_tfm; + if (crypto_aead_setauthsize(key->tfm[i], authsizes[i])) + goto err_tfm; + } + + key->tfm0 = crypto_alloc_sync_skcipher("ctr(aes)", 0, 0); + if (IS_ERR(key->tfm0)) + goto err_tfm; + + if (crypto_sync_skcipher_setkey(key->tfm0, template->key, + IEEE802154_LLSEC_KEY_SIZE)) + goto err_tfm0; + + return key; + +err_tfm0: + crypto_free_sync_skcipher(key->tfm0); +err_tfm: + for (i = 0; i < ARRAY_SIZE(key->tfm); i++) + if (!IS_ERR_OR_NULL(key->tfm[i])) + crypto_free_aead(key->tfm[i]); + + kfree_sensitive(key); + return NULL; +} + +static void llsec_key_release(struct kref *ref) +{ + struct mac802154_llsec_key *key; + int i; + + key = container_of(ref, struct mac802154_llsec_key, ref); + + for (i = 0; i < ARRAY_SIZE(key->tfm); i++) + crypto_free_aead(key->tfm[i]); + + crypto_free_sync_skcipher(key->tfm0); + kfree_sensitive(key); +} + +static struct mac802154_llsec_key* +llsec_key_get(struct mac802154_llsec_key *key) +{ + kref_get(&key->ref); + return key; +} + +static void llsec_key_put(struct mac802154_llsec_key *key) +{ + kref_put(&key->ref, llsec_key_release); +} + +static bool llsec_key_id_equal(const struct ieee802154_llsec_key_id *a, + const struct ieee802154_llsec_key_id *b) +{ + if (a->mode != b->mode) + 
return false; + + if (a->mode == IEEE802154_SCF_KEY_IMPLICIT) + return ieee802154_addr_equal(&a->device_addr, &b->device_addr); + + if (a->id != b->id) + return false; + + switch (a->mode) { + case IEEE802154_SCF_KEY_INDEX: + return true; + case IEEE802154_SCF_KEY_SHORT_INDEX: + return a->short_source == b->short_source; + case IEEE802154_SCF_KEY_HW_INDEX: + return a->extended_source == b->extended_source; + } + + return false; +} + +int mac802154_llsec_key_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key) +{ + struct mac802154_llsec_key *mkey = NULL; + struct ieee802154_llsec_key_entry *pos, *new; + + if (!(key->frame_types & (1 << IEEE802154_FC_TYPE_MAC_CMD)) && + key->cmd_frame_ids) + return -EINVAL; + + list_for_each_entry(pos, &sec->table.keys, list) { + if (llsec_key_id_equal(&pos->id, id)) + return -EEXIST; + + if (memcmp(pos->key->key, key->key, + IEEE802154_LLSEC_KEY_SIZE)) + continue; + + mkey = container_of(pos->key, struct mac802154_llsec_key, key); + + /* Don't allow multiple instances of the same AES key to have + * different allowed frame types/command frame ids, as this is + * not possible in the 802.15.4 PIB. + */ + if (pos->key->frame_types != key->frame_types || + pos->key->cmd_frame_ids != key->cmd_frame_ids) + return -EEXIST; + + break; + } + + new = kzalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOMEM; + + if (!mkey) + mkey = llsec_key_alloc(key); + else + mkey = llsec_key_get(mkey); + + if (!mkey) + goto fail; + + new->id = *id; + new->key = &mkey->key; + + list_add_rcu(&new->list, &sec->table.keys); + + return 0; + +fail: + kfree_sensitive(new); + return -ENOMEM; +} + +int mac802154_llsec_key_del(struct mac802154_llsec *sec, + const struct ieee802154_llsec_key_id *key) +{ + struct ieee802154_llsec_key_entry *pos; + + list_for_each_entry(pos, &sec->table.keys, list) { + struct mac802154_llsec_key *mkey; + + mkey = container_of(pos->key, struct mac802154_llsec_key, key); + + if (llsec_key_id_equal(&pos->id, key)) { + list_del_rcu(&pos->list); + llsec_key_put(mkey); + return 0; + } + } + + return -ENOENT; +} + +static bool llsec_dev_use_shortaddr(__le16 short_addr) +{ + return short_addr != cpu_to_le16(IEEE802154_ADDR_UNDEF) && + short_addr != cpu_to_le16(0xffff); +} + +static u32 llsec_dev_hash_short(__le16 short_addr, __le16 pan_id) +{ + return ((__force u16)short_addr) << 16 | (__force u16)pan_id; +} + +static u64 llsec_dev_hash_long(__le64 hwaddr) +{ + return (__force u64)hwaddr; +} + +static struct mac802154_llsec_device* +llsec_dev_find_short(struct mac802154_llsec *sec, __le16 short_addr, + __le16 pan_id) +{ + struct mac802154_llsec_device *dev; + u32 key = llsec_dev_hash_short(short_addr, pan_id); + + hash_for_each_possible_rcu(sec->devices_short, dev, bucket_s, key) { + if (dev->dev.short_addr == short_addr && + dev->dev.pan_id == pan_id) + return dev; + } + + return NULL; +} + +static struct mac802154_llsec_device* +llsec_dev_find_long(struct mac802154_llsec *sec, __le64 hwaddr) +{ + struct mac802154_llsec_device *dev; + u64 key = llsec_dev_hash_long(hwaddr); + + hash_for_each_possible_rcu(sec->devices_hw, dev, bucket_hw, key) { + if (dev->dev.hwaddr == hwaddr) + return dev; + } + + return NULL; +} + +static void llsec_dev_free(struct mac802154_llsec_device *dev) +{ + struct ieee802154_llsec_device_key *pos, *pn; + struct mac802154_llsec_device_key *devkey; + + list_for_each_entry_safe(pos, pn, &dev->dev.keys, list) { + devkey = container_of(pos, struct mac802154_llsec_device_key, 
+ devkey); + + list_del(&pos->list); + kfree_sensitive(devkey); + } + + kfree_sensitive(dev); +} + +int mac802154_llsec_dev_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_device *dev) +{ + struct mac802154_llsec_device *entry; + u32 skey = llsec_dev_hash_short(dev->short_addr, dev->pan_id); + u64 hwkey = llsec_dev_hash_long(dev->hwaddr); + + BUILD_BUG_ON(sizeof(hwkey) != IEEE802154_ADDR_LEN); + + if ((llsec_dev_use_shortaddr(dev->short_addr) && + llsec_dev_find_short(sec, dev->short_addr, dev->pan_id)) || + llsec_dev_find_long(sec, dev->hwaddr)) + return -EEXIST; + + entry = kmalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) + return -ENOMEM; + + entry->dev = *dev; + spin_lock_init(&entry->lock); + INIT_LIST_HEAD(&entry->dev.keys); + + if (llsec_dev_use_shortaddr(dev->short_addr)) + hash_add_rcu(sec->devices_short, &entry->bucket_s, skey); + else + INIT_HLIST_NODE(&entry->bucket_s); + + hash_add_rcu(sec->devices_hw, &entry->bucket_hw, hwkey); + list_add_tail_rcu(&entry->dev.list, &sec->table.devices); + + return 0; +} + +static void llsec_dev_free_rcu(struct rcu_head *rcu) +{ + llsec_dev_free(container_of(rcu, struct mac802154_llsec_device, rcu)); +} + +int mac802154_llsec_dev_del(struct mac802154_llsec *sec, __le64 device_addr) +{ + struct mac802154_llsec_device *pos; + + pos = llsec_dev_find_long(sec, device_addr); + if (!pos) + return -ENOENT; + + hash_del_rcu(&pos->bucket_s); + hash_del_rcu(&pos->bucket_hw); + list_del_rcu(&pos->dev.list); + call_rcu(&pos->rcu, llsec_dev_free_rcu); + + return 0; +} + +static struct mac802154_llsec_device_key* +llsec_devkey_find(struct mac802154_llsec_device *dev, + const struct ieee802154_llsec_key_id *key) +{ + struct ieee802154_llsec_device_key *devkey; + + list_for_each_entry_rcu(devkey, &dev->dev.keys, list) { + if (!llsec_key_id_equal(key, &devkey->key_id)) + continue; + + return container_of(devkey, struct mac802154_llsec_device_key, + devkey); + } + + return NULL; +} + +int mac802154_llsec_devkey_add(struct mac802154_llsec *sec, + __le64 dev_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct mac802154_llsec_device *dev; + struct mac802154_llsec_device_key *devkey; + + dev = llsec_dev_find_long(sec, dev_addr); + + if (!dev) + return -ENOENT; + + if (llsec_devkey_find(dev, &key->key_id)) + return -EEXIST; + + devkey = kmalloc(sizeof(*devkey), GFP_KERNEL); + if (!devkey) + return -ENOMEM; + + devkey->devkey = *key; + list_add_tail_rcu(&devkey->devkey.list, &dev->dev.keys); + return 0; +} + +int mac802154_llsec_devkey_del(struct mac802154_llsec *sec, + __le64 dev_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct mac802154_llsec_device *dev; + struct mac802154_llsec_device_key *devkey; + + dev = llsec_dev_find_long(sec, dev_addr); + + if (!dev) + return -ENOENT; + + devkey = llsec_devkey_find(dev, &key->key_id); + if (!devkey) + return -ENOENT; + + list_del_rcu(&devkey->devkey.list); + kfree_rcu(devkey, rcu); + return 0; +} + +static struct mac802154_llsec_seclevel* +llsec_find_seclevel(const struct mac802154_llsec *sec, + const struct ieee802154_llsec_seclevel *sl) +{ + struct ieee802154_llsec_seclevel *pos; + + list_for_each_entry(pos, &sec->table.security_levels, list) { + if (pos->frame_type != sl->frame_type || + (pos->frame_type == IEEE802154_FC_TYPE_MAC_CMD && + pos->cmd_frame_id != sl->cmd_frame_id) || + pos->device_override != sl->device_override || + pos->sec_levels != sl->sec_levels) + continue; + + return container_of(pos, struct mac802154_llsec_seclevel, + level); + } + + return NULL; +} 
+ +int mac802154_llsec_seclevel_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_seclevel *sl) +{ + struct mac802154_llsec_seclevel *entry; + + if (llsec_find_seclevel(sec, sl)) + return -EEXIST; + + entry = kmalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) + return -ENOMEM; + + entry->level = *sl; + + list_add_tail_rcu(&entry->level.list, &sec->table.security_levels); + + return 0; +} + +int mac802154_llsec_seclevel_del(struct mac802154_llsec *sec, + const struct ieee802154_llsec_seclevel *sl) +{ + struct mac802154_llsec_seclevel *pos; + + pos = llsec_find_seclevel(sec, sl); + if (!pos) + return -ENOENT; + + list_del_rcu(&pos->level.list); + kfree_rcu(pos, rcu); + + return 0; +} + +static int llsec_recover_addr(struct mac802154_llsec *sec, + struct ieee802154_addr *addr) +{ + __le16 caddr = sec->params.coord_shortaddr; + + addr->pan_id = sec->params.pan_id; + + if (caddr == cpu_to_le16(IEEE802154_ADDR_BROADCAST)) { + return -EINVAL; + } else if (caddr == cpu_to_le16(IEEE802154_ADDR_UNDEF)) { + addr->extended_addr = sec->params.coord_hwaddr; + addr->mode = IEEE802154_ADDR_LONG; + } else { + addr->short_addr = sec->params.coord_shortaddr; + addr->mode = IEEE802154_ADDR_SHORT; + } + + return 0; +} + +static struct mac802154_llsec_key* +llsec_lookup_key(struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + const struct ieee802154_addr *addr, + struct ieee802154_llsec_key_id *key_id) +{ + struct ieee802154_addr devaddr = *addr; + u8 key_id_mode = hdr->sec.key_id_mode; + struct ieee802154_llsec_key_entry *key_entry; + struct mac802154_llsec_key *key; + + if (key_id_mode == IEEE802154_SCF_KEY_IMPLICIT && + devaddr.mode == IEEE802154_ADDR_NONE) { + if (hdr->fc.type == IEEE802154_FC_TYPE_BEACON) { + devaddr.extended_addr = sec->params.coord_hwaddr; + devaddr.mode = IEEE802154_ADDR_LONG; + } else if (llsec_recover_addr(sec, &devaddr) < 0) { + return NULL; + } + } + + list_for_each_entry_rcu(key_entry, &sec->table.keys, list) { + const struct ieee802154_llsec_key_id *id = &key_entry->id; + + if (!(key_entry->key->frame_types & BIT(hdr->fc.type))) + continue; + + if (id->mode != key_id_mode) + continue; + + if (key_id_mode == IEEE802154_SCF_KEY_IMPLICIT) { + if (ieee802154_addr_equal(&devaddr, &id->device_addr)) + goto found; + } else { + if (id->id != hdr->sec.key_id) + continue; + + if ((key_id_mode == IEEE802154_SCF_KEY_INDEX) || + (key_id_mode == IEEE802154_SCF_KEY_SHORT_INDEX && + id->short_source == hdr->sec.short_src) || + (key_id_mode == IEEE802154_SCF_KEY_HW_INDEX && + id->extended_source == hdr->sec.extended_src)) + goto found; + } + } + + return NULL; + +found: + key = container_of(key_entry->key, struct mac802154_llsec_key, key); + if (key_id) + *key_id = key_entry->id; + return llsec_key_get(key); +} + +static void llsec_geniv(u8 iv[16], __le64 addr, + const struct ieee802154_sechdr *sec) +{ + __be64 addr_bytes = (__force __be64) swab64((__force u64) addr); + __be32 frame_counter = (__force __be32) swab32((__force u32) sec->frame_counter); + + iv[0] = 1; /* L' = L - 1 = 1 */ + memcpy(iv + 1, &addr_bytes, sizeof(addr_bytes)); + memcpy(iv + 9, &frame_counter, sizeof(frame_counter)); + iv[13] = sec->level; + iv[14] = 0; + iv[15] = 1; +} + +static int +llsec_do_encrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key) +{ + u8 iv[16]; + struct scatterlist src; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); + int err, datalen; + unsigned char *data; + + llsec_geniv(iv, 
sec->params.hwaddr, &hdr->sec); + /* Compute data payload offset and data length */ + data = skb_mac_header(skb) + skb->mac_len; + datalen = skb_tail_pointer(skb) - data; + sg_init_one(&src, data, datalen); + + skcipher_request_set_sync_tfm(req, key->tfm0); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &src, &src, datalen, iv); + err = crypto_skcipher_encrypt(req); + skcipher_request_zero(req); + return err; +} + +static struct crypto_aead* +llsec_tfm_by_len(struct mac802154_llsec_key *key, int authlen) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(key->tfm); i++) + if (crypto_aead_authsize(key->tfm[i]) == authlen) + return key->tfm[i]; + + BUG(); +} + +static int +llsec_do_encrypt_auth(struct sk_buff *skb, const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key) +{ + u8 iv[16]; + unsigned char *data; + int authlen, assoclen, datalen, rc; + struct scatterlist sg; + struct aead_request *req; + + authlen = ieee802154_sechdr_authtag_len(&hdr->sec); + llsec_geniv(iv, sec->params.hwaddr, &hdr->sec); + + req = aead_request_alloc(llsec_tfm_by_len(key, authlen), GFP_ATOMIC); + if (!req) + return -ENOMEM; + + assoclen = skb->mac_len; + + data = skb_mac_header(skb) + skb->mac_len; + datalen = skb_tail_pointer(skb) - data; + + skb_put(skb, authlen); + + sg_init_one(&sg, skb_mac_header(skb), assoclen + datalen + authlen); + + if (!(hdr->sec.level & IEEE802154_SCF_SECLEVEL_ENC)) { + assoclen += datalen; + datalen = 0; + } + + aead_request_set_callback(req, 0, NULL, NULL); + aead_request_set_crypt(req, &sg, &sg, datalen, iv); + aead_request_set_ad(req, assoclen); + + rc = crypto_aead_encrypt(req); + + kfree_sensitive(req); + + return rc; +} + +static int llsec_do_encrypt(struct sk_buff *skb, + const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key) +{ + if (hdr->sec.level == IEEE802154_SCF_SECLEVEL_ENC) + return llsec_do_encrypt_unauth(skb, sec, hdr, key); + else + return llsec_do_encrypt_auth(skb, sec, hdr, key); +} + +int mac802154_llsec_encrypt(struct mac802154_llsec *sec, struct sk_buff *skb) +{ + struct ieee802154_hdr hdr; + int rc, authlen, hlen; + struct mac802154_llsec_key *key; + u32 frame_ctr; + + hlen = ieee802154_hdr_pull(skb, &hdr); + + if (hlen < 0 || hdr.fc.type != IEEE802154_FC_TYPE_DATA) + return -EINVAL; + + if (!hdr.fc.security_enabled || + (hdr.sec.level == IEEE802154_SCF_SECLEVEL_NONE)) { + skb_push(skb, hlen); + return 0; + } + + authlen = ieee802154_sechdr_authtag_len(&hdr.sec); + + if (skb->len + hlen + authlen + IEEE802154_MFR_SIZE > IEEE802154_MTU) + return -EMSGSIZE; + + rcu_read_lock(); + + read_lock_bh(&sec->lock); + + if (!sec->params.enabled) { + rc = -EINVAL; + goto fail_read; + } + + key = llsec_lookup_key(sec, &hdr, &hdr.dest, NULL); + if (!key) { + rc = -ENOKEY; + goto fail_read; + } + + read_unlock_bh(&sec->lock); + + write_lock_bh(&sec->lock); + + frame_ctr = be32_to_cpu(sec->params.frame_counter); + hdr.sec.frame_counter = cpu_to_le32(frame_ctr); + if (frame_ctr == 0xFFFFFFFF) { + write_unlock_bh(&sec->lock); + llsec_key_put(key); + rc = -EOVERFLOW; + goto fail; + } + + sec->params.frame_counter = cpu_to_be32(frame_ctr + 1); + + write_unlock_bh(&sec->lock); + + rcu_read_unlock(); + + skb->mac_len = ieee802154_hdr_push(skb, &hdr); + skb_reset_mac_header(skb); + + rc = llsec_do_encrypt(skb, sec, &hdr, key); + llsec_key_put(key); + + return rc; + +fail_read: + read_unlock_bh(&sec->lock); +fail: + rcu_read_unlock(); + return rc; +} + 
+static struct mac802154_llsec_device* +llsec_lookup_dev(struct mac802154_llsec *sec, + const struct ieee802154_addr *addr) +{ + struct ieee802154_addr devaddr = *addr; + struct mac802154_llsec_device *dev = NULL; + + if (devaddr.mode == IEEE802154_ADDR_NONE && + llsec_recover_addr(sec, &devaddr) < 0) + return NULL; + + if (devaddr.mode == IEEE802154_ADDR_SHORT) { + u32 key = llsec_dev_hash_short(devaddr.short_addr, + devaddr.pan_id); + + hash_for_each_possible_rcu(sec->devices_short, dev, + bucket_s, key) { + if (dev->dev.pan_id == devaddr.pan_id && + dev->dev.short_addr == devaddr.short_addr) + return dev; + } + } else { + u64 key = llsec_dev_hash_long(devaddr.extended_addr); + + hash_for_each_possible_rcu(sec->devices_hw, dev, + bucket_hw, key) { + if (dev->dev.hwaddr == devaddr.extended_addr) + return dev; + } + } + + return NULL; +} + +static int +llsec_lookup_seclevel(const struct mac802154_llsec *sec, + u8 frame_type, u8 cmd_frame_id, + struct ieee802154_llsec_seclevel *rlevel) +{ + struct ieee802154_llsec_seclevel *level; + + list_for_each_entry_rcu(level, &sec->table.security_levels, list) { + if (level->frame_type == frame_type && + (frame_type != IEEE802154_FC_TYPE_MAC_CMD || + level->cmd_frame_id == cmd_frame_id)) { + *rlevel = *level; + return 0; + } + } + + return -EINVAL; +} + +static int +llsec_do_decrypt_unauth(struct sk_buff *skb, const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key, __le64 dev_addr) +{ + u8 iv[16]; + unsigned char *data; + int datalen; + struct scatterlist src; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, key->tfm0); + int err; + + llsec_geniv(iv, dev_addr, &hdr->sec); + data = skb_mac_header(skb) + skb->mac_len; + datalen = skb_tail_pointer(skb) - data; + + sg_init_one(&src, data, datalen); + + skcipher_request_set_sync_tfm(req, key->tfm0); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &src, &src, datalen, iv); + + err = crypto_skcipher_decrypt(req); + skcipher_request_zero(req); + return err; +} + +static int +llsec_do_decrypt_auth(struct sk_buff *skb, const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key, __le64 dev_addr) +{ + u8 iv[16]; + unsigned char *data; + int authlen, datalen, assoclen, rc; + struct scatterlist sg; + struct aead_request *req; + + authlen = ieee802154_sechdr_authtag_len(&hdr->sec); + llsec_geniv(iv, dev_addr, &hdr->sec); + + req = aead_request_alloc(llsec_tfm_by_len(key, authlen), GFP_ATOMIC); + if (!req) + return -ENOMEM; + + assoclen = skb->mac_len; + + data = skb_mac_header(skb) + skb->mac_len; + datalen = skb_tail_pointer(skb) - data; + + sg_init_one(&sg, skb_mac_header(skb), assoclen + datalen); + + if (!(hdr->sec.level & IEEE802154_SCF_SECLEVEL_ENC)) { + assoclen += datalen - authlen; + datalen = authlen; + } + + aead_request_set_callback(req, 0, NULL, NULL); + aead_request_set_crypt(req, &sg, &sg, datalen, iv); + aead_request_set_ad(req, assoclen); + + rc = crypto_aead_decrypt(req); + + kfree_sensitive(req); + skb_trim(skb, skb->len - authlen); + + return rc; +} + +static int +llsec_do_decrypt(struct sk_buff *skb, const struct mac802154_llsec *sec, + const struct ieee802154_hdr *hdr, + struct mac802154_llsec_key *key, __le64 dev_addr) +{ + if (hdr->sec.level == IEEE802154_SCF_SECLEVEL_ENC) + return llsec_do_decrypt_unauth(skb, sec, hdr, key, dev_addr); + else + return llsec_do_decrypt_auth(skb, sec, hdr, key, dev_addr); +} + +static int +llsec_update_devkey_record(struct 
mac802154_llsec_device *dev, + const struct ieee802154_llsec_key_id *in_key) +{ + struct mac802154_llsec_device_key *devkey; + + devkey = llsec_devkey_find(dev, in_key); + + if (!devkey) { + struct mac802154_llsec_device_key *next; + + next = kzalloc(sizeof(*devkey), GFP_ATOMIC); + if (!next) + return -ENOMEM; + + next->devkey.key_id = *in_key; + + spin_lock_bh(&dev->lock); + + devkey = llsec_devkey_find(dev, in_key); + if (!devkey) + list_add_rcu(&next->devkey.list, &dev->dev.keys); + else + kfree_sensitive(next); + + spin_unlock_bh(&dev->lock); + } + + return 0; +} + +static int +llsec_update_devkey_info(struct mac802154_llsec_device *dev, + const struct ieee802154_llsec_key_id *in_key, + u32 frame_counter) +{ + struct mac802154_llsec_device_key *devkey = NULL; + + if (dev->dev.key_mode == IEEE802154_LLSEC_DEVKEY_RESTRICT) { + devkey = llsec_devkey_find(dev, in_key); + if (!devkey) + return -ENOENT; + } + + if (dev->dev.key_mode == IEEE802154_LLSEC_DEVKEY_RECORD) { + int rc = llsec_update_devkey_record(dev, in_key); + + if (rc < 0) + return rc; + } + + spin_lock_bh(&dev->lock); + + if ((!devkey && frame_counter < dev->dev.frame_counter) || + (devkey && frame_counter < devkey->devkey.frame_counter)) { + spin_unlock_bh(&dev->lock); + return -EINVAL; + } + + if (devkey) + devkey->devkey.frame_counter = frame_counter + 1; + else + dev->dev.frame_counter = frame_counter + 1; + + spin_unlock_bh(&dev->lock); + + return 0; +} + +int mac802154_llsec_decrypt(struct mac802154_llsec *sec, struct sk_buff *skb) +{ + struct ieee802154_hdr hdr; + struct mac802154_llsec_key *key; + struct ieee802154_llsec_key_id key_id; + struct mac802154_llsec_device *dev; + struct ieee802154_llsec_seclevel seclevel; + int err; + __le64 dev_addr; + u32 frame_ctr; + + if (ieee802154_hdr_peek(skb, &hdr) < 0) + return -EINVAL; + if (!hdr.fc.security_enabled) + return 0; + if (hdr.fc.version == 0) + return -EINVAL; + + read_lock_bh(&sec->lock); + if (!sec->params.enabled) { + read_unlock_bh(&sec->lock); + return -EINVAL; + } + read_unlock_bh(&sec->lock); + + rcu_read_lock(); + + key = llsec_lookup_key(sec, &hdr, &hdr.source, &key_id); + if (!key) { + err = -ENOKEY; + goto fail; + } + + dev = llsec_lookup_dev(sec, &hdr.source); + if (!dev) { + err = -EINVAL; + goto fail_dev; + } + + if (llsec_lookup_seclevel(sec, hdr.fc.type, 0, &seclevel) < 0) { + err = -EINVAL; + goto fail_dev; + } + + if (!(seclevel.sec_levels & BIT(hdr.sec.level)) && + (hdr.sec.level == 0 && seclevel.device_override && + !dev->dev.seclevel_exempt)) { + err = -EINVAL; + goto fail_dev; + } + + frame_ctr = le32_to_cpu(hdr.sec.frame_counter); + + if (frame_ctr == 0xffffffff) { + err = -EOVERFLOW; + goto fail_dev; + } + + err = llsec_update_devkey_info(dev, &key_id, frame_ctr); + if (err) + goto fail_dev; + + dev_addr = dev->dev.hwaddr; + + rcu_read_unlock(); + + err = llsec_do_decrypt(skb, sec, &hdr, key, dev_addr); + llsec_key_put(key); + return err; + +fail_dev: + llsec_key_put(key); +fail: + rcu_read_unlock(); + return err; +} diff --git a/net/mac802154/llsec.h b/net/mac802154/llsec.h new file mode 100644 index 000000000..ddeb79228 --- /dev/null +++ b/net/mac802154/llsec.h @@ -0,0 +1,99 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (C) 2014 Fraunhofer ITWM + * + * Written by: + * Phoebe Buckheister + */ + +#ifndef MAC802154_LLSEC_H +#define MAC802154_LLSEC_H + +#include +#include +#include +#include +#include +#include + +struct mac802154_llsec_key { + struct ieee802154_llsec_key key; + + /* one tfm for each authsize (4/8/16) */ + 
struct crypto_aead *tfm[3]; + struct crypto_sync_skcipher *tfm0; + + struct kref ref; +}; + +struct mac802154_llsec_device_key { + struct ieee802154_llsec_device_key devkey; + + struct rcu_head rcu; +}; + +struct mac802154_llsec_device { + struct ieee802154_llsec_device dev; + + struct hlist_node bucket_s; + struct hlist_node bucket_hw; + + /* protects dev.frame_counter and the elements of dev.keys */ + spinlock_t lock; + + struct rcu_head rcu; +}; + +struct mac802154_llsec_seclevel { + struct ieee802154_llsec_seclevel level; + + struct rcu_head rcu; +}; + +struct mac802154_llsec { + struct ieee802154_llsec_params params; + struct ieee802154_llsec_table table; + + DECLARE_HASHTABLE(devices_short, 6); + DECLARE_HASHTABLE(devices_hw, 6); + + /* protects params, all other fields are fine with RCU */ + rwlock_t lock; +}; + +void mac802154_llsec_init(struct mac802154_llsec *sec); +void mac802154_llsec_destroy(struct mac802154_llsec *sec); + +int mac802154_llsec_get_params(struct mac802154_llsec *sec, + struct ieee802154_llsec_params *params); +int mac802154_llsec_set_params(struct mac802154_llsec *sec, + const struct ieee802154_llsec_params *params, + int changed); + +int mac802154_llsec_key_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key); +int mac802154_llsec_key_del(struct mac802154_llsec *sec, + const struct ieee802154_llsec_key_id *key); + +int mac802154_llsec_dev_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_device *dev); +int mac802154_llsec_dev_del(struct mac802154_llsec *sec, + __le64 device_addr); + +int mac802154_llsec_devkey_add(struct mac802154_llsec *sec, + __le64 dev_addr, + const struct ieee802154_llsec_device_key *key); +int mac802154_llsec_devkey_del(struct mac802154_llsec *sec, + __le64 dev_addr, + const struct ieee802154_llsec_device_key *key); + +int mac802154_llsec_seclevel_add(struct mac802154_llsec *sec, + const struct ieee802154_llsec_seclevel *sl); +int mac802154_llsec_seclevel_del(struct mac802154_llsec *sec, + const struct ieee802154_llsec_seclevel *sl); + +int mac802154_llsec_encrypt(struct mac802154_llsec *sec, struct sk_buff *skb); +int mac802154_llsec_decrypt(struct mac802154_llsec *sec, struct sk_buff *skb); + +#endif /* MAC802154_LLSEC_H */ diff --git a/net/mac802154/mac_cmd.c b/net/mac802154/mac_cmd.c new file mode 100644 index 000000000..8ea5b6402 --- /dev/null +++ b/net/mac802154/mac_cmd.c @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * MAC commands interface + * + * Copyright 2007-2012 Siemens AG + * + * Written by: + * Sergey Lapin + * Dmitry Eremin-Solenikov + * Alexander Smirnov + */ + +#include +#include +#include + +#include +#include +#include + +#include "ieee802154_i.h" +#include "driver-ops.h" + +static int mac802154_mlme_start_req(struct net_device *dev, + struct ieee802154_addr *addr, + u8 channel, u8 page, + u8 bcn_ord, u8 sf_ord, + u8 pan_coord, u8 blx, + u8 coord_realign) +{ + struct ieee802154_llsec_params params; + int changed = 0; + + ASSERT_RTNL(); + + BUG_ON(addr->mode != IEEE802154_ADDR_SHORT); + + dev->ieee802154_ptr->pan_id = addr->pan_id; + dev->ieee802154_ptr->short_addr = addr->short_addr; + mac802154_dev_set_page_channel(dev, page, channel); + + params.pan_id = addr->pan_id; + changed |= IEEE802154_LLSEC_PARAM_PAN_ID; + + params.hwaddr = ieee802154_devaddr_from_raw(dev->dev_addr); + changed |= IEEE802154_LLSEC_PARAM_HWADDR; + + params.coord_hwaddr = params.hwaddr; + changed |= IEEE802154_LLSEC_PARAM_COORD_HWADDR; + + 
params.coord_shortaddr = addr->short_addr; + changed |= IEEE802154_LLSEC_PARAM_COORD_SHORTADDR; + + return mac802154_set_params(dev, ¶ms, changed); +} + +static int mac802154_set_mac_params(struct net_device *dev, + const struct ieee802154_mac_params *params) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct ieee802154_local *local = sdata->local; + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + int ret; + + ASSERT_RTNL(); + + /* PHY */ + wpan_dev->wpan_phy->transmit_power = params->transmit_power; + wpan_dev->wpan_phy->cca = params->cca; + wpan_dev->wpan_phy->cca_ed_level = params->cca_ed_level; + + /* MAC */ + wpan_dev->min_be = params->min_be; + wpan_dev->max_be = params->max_be; + wpan_dev->csma_retries = params->csma_retries; + wpan_dev->frame_retries = params->frame_retries; + wpan_dev->lbt = params->lbt; + + if (local->hw.phy->flags & WPAN_PHY_FLAG_TXPOWER) { + ret = drv_set_tx_power(local, params->transmit_power); + if (ret < 0) + return ret; + } + + if (local->hw.phy->flags & WPAN_PHY_FLAG_CCA_MODE) { + ret = drv_set_cca_mode(local, ¶ms->cca); + if (ret < 0) + return ret; + } + + if (local->hw.phy->flags & WPAN_PHY_FLAG_CCA_ED_LEVEL) { + ret = drv_set_cca_ed_level(local, params->cca_ed_level); + if (ret < 0) + return ret; + } + + return 0; +} + +static void mac802154_get_mac_params(struct net_device *dev, + struct ieee802154_mac_params *params) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + + ASSERT_RTNL(); + + /* PHY */ + params->transmit_power = wpan_dev->wpan_phy->transmit_power; + params->cca = wpan_dev->wpan_phy->cca; + params->cca_ed_level = wpan_dev->wpan_phy->cca_ed_level; + + /* MAC */ + params->min_be = wpan_dev->min_be; + params->max_be = wpan_dev->max_be; + params->csma_retries = wpan_dev->csma_retries; + params->frame_retries = wpan_dev->frame_retries; + params->lbt = wpan_dev->lbt; +} + +static const struct ieee802154_llsec_ops mac802154_llsec_ops = { + .get_params = mac802154_get_params, + .set_params = mac802154_set_params, + .add_key = mac802154_add_key, + .del_key = mac802154_del_key, + .add_dev = mac802154_add_dev, + .del_dev = mac802154_del_dev, + .add_devkey = mac802154_add_devkey, + .del_devkey = mac802154_del_devkey, + .add_seclevel = mac802154_add_seclevel, + .del_seclevel = mac802154_del_seclevel, + .lock_table = mac802154_lock_table, + .get_table = mac802154_get_table, + .unlock_table = mac802154_unlock_table, +}; + +struct ieee802154_mlme_ops mac802154_mlme_wpan = { + .start_req = mac802154_mlme_start_req, + + .llsec = &mac802154_llsec_ops, + + .set_mac_params = mac802154_set_mac_params, + .get_mac_params = mac802154_get_mac_params, +}; diff --git a/net/mac802154/main.c b/net/mac802154/main.c new file mode 100644 index 000000000..bd7bdb121 --- /dev/null +++ b/net/mac802154/main.c @@ -0,0 +1,285 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007-2012 Siemens AG + * + * Written by: + * Alexander Smirnov + */ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "ieee802154_i.h" +#include "cfg.h" + +static void ieee802154_tasklet_handler(struct tasklet_struct *t) +{ + struct ieee802154_local *local = from_tasklet(local, t, tasklet); + struct sk_buff *skb; + + while ((skb = skb_dequeue(&local->skb_queue))) { + switch (skb->pkt_type) { + case IEEE802154_RX_MSG: + /* Clear skb->pkt_type in order to not confuse kernel + * netstack. 
+ */ + skb->pkt_type = 0; + ieee802154_rx(local, skb); + break; + default: + WARN(1, "mac802154: Packet is of unknown type %d\n", + skb->pkt_type); + kfree_skb(skb); + break; + } + } +} + +struct ieee802154_hw * +ieee802154_alloc_hw(size_t priv_data_len, const struct ieee802154_ops *ops) +{ + struct wpan_phy *phy; + struct ieee802154_local *local; + size_t priv_size; + + if (WARN_ON(!ops || !(ops->xmit_async || ops->xmit_sync) || !ops->ed || + !ops->start || !ops->stop || !ops->set_channel)) + return NULL; + + /* Ensure 32-byte alignment of our private data and hw private data. + * We use the wpan_phy priv data for both our ieee802154_local and for + * the driver's private data + * + * in memory it'll be like this: + * + * +-------------------------+ + * | struct wpan_phy | + * +-------------------------+ + * | struct ieee802154_local | + * +-------------------------+ + * | driver's private data | + * +-------------------------+ + * + * Due to ieee802154 layer isn't aware of driver and MAC structures, + * so lets align them here. + */ + + priv_size = ALIGN(sizeof(*local), NETDEV_ALIGN) + priv_data_len; + + phy = wpan_phy_new(&mac802154_config_ops, priv_size); + if (!phy) { + pr_err("failure to allocate master IEEE802.15.4 device\n"); + return NULL; + } + + phy->privid = mac802154_wpan_phy_privid; + + local = wpan_phy_priv(phy); + local->phy = phy; + local->hw.phy = local->phy; + local->hw.priv = (char *)local + ALIGN(sizeof(*local), NETDEV_ALIGN); + local->ops = ops; + + INIT_LIST_HEAD(&local->interfaces); + mutex_init(&local->iflist_mtx); + + tasklet_setup(&local->tasklet, ieee802154_tasklet_handler); + + skb_queue_head_init(&local->skb_queue); + + INIT_WORK(&local->tx_work, ieee802154_xmit_worker); + + /* init supported flags with 802.15.4 default ranges */ + phy->supported.max_minbe = 8; + phy->supported.min_maxbe = 3; + phy->supported.max_maxbe = 8; + phy->supported.min_frame_retries = 0; + phy->supported.max_frame_retries = 7; + phy->supported.max_csma_backoffs = 5; + phy->supported.lbt = NL802154_SUPPORTED_BOOL_FALSE; + + /* always supported */ + phy->supported.iftypes = BIT(NL802154_IFTYPE_NODE); + + return &local->hw; +} +EXPORT_SYMBOL(ieee802154_alloc_hw); + +void ieee802154_configure_durations(struct wpan_phy *phy) +{ + u32 duration = 0; + + switch (phy->current_page) { + case 0: + if (BIT(phy->current_channel) & 0x1) + /* 868 MHz BPSK 802.15.4-2003: 20 ksym/s */ + duration = 50 * NSEC_PER_USEC; + else if (BIT(phy->current_channel) & 0x7FE) + /* 915 MHz BPSK 802.15.4-2003: 40 ksym/s */ + duration = 25 * NSEC_PER_USEC; + else if (BIT(phy->current_channel) & 0x7FFF800) + /* 2400 MHz O-QPSK 802.15.4-2006: 62.5 ksym/s */ + duration = 16 * NSEC_PER_USEC; + break; + case 2: + if (BIT(phy->current_channel) & 0x1) + /* 868 MHz O-QPSK 802.15.4-2006: 25 ksym/s */ + duration = 40 * NSEC_PER_USEC; + else if (BIT(phy->current_channel) & 0x7FE) + /* 915 MHz O-QPSK 802.15.4-2006: 62.5 ksym/s */ + duration = 16 * NSEC_PER_USEC; + break; + case 3: + if (BIT(phy->current_channel) & 0x3FFF) + /* 2.4 GHz CSS 802.15.4a-2007: 1/6 Msym/s */ + duration = 6 * NSEC_PER_USEC; + break; + default: + break; + } + + if (!duration) { + pr_debug("Unknown PHY symbol duration\n"); + return; + } + + phy->symbol_duration = duration; + phy->lifs_period = (IEEE802154_LIFS_PERIOD * phy->symbol_duration) / NSEC_PER_SEC; + phy->sifs_period = (IEEE802154_SIFS_PERIOD * phy->symbol_duration) / NSEC_PER_SEC; +} +EXPORT_SYMBOL(ieee802154_configure_durations); + +void ieee802154_free_hw(struct ieee802154_hw *hw) +{ + struct 
ieee802154_local *local = hw_to_local(hw); + + BUG_ON(!list_empty(&local->interfaces)); + + mutex_destroy(&local->iflist_mtx); + + wpan_phy_free(local->phy); +} +EXPORT_SYMBOL(ieee802154_free_hw); + +static void ieee802154_setup_wpan_phy_pib(struct wpan_phy *wpan_phy) +{ + /* TODO warn on empty symbol_duration + * Should be done when all drivers sets this value. + */ + + wpan_phy->lifs_period = + (IEEE802154_LIFS_PERIOD * wpan_phy->symbol_duration) / 1000; + wpan_phy->sifs_period = + (IEEE802154_SIFS_PERIOD * wpan_phy->symbol_duration) / 1000; +} + +int ieee802154_register_hw(struct ieee802154_hw *hw) +{ + struct ieee802154_local *local = hw_to_local(hw); + struct net_device *dev; + int rc = -ENOSYS; + + local->workqueue = + create_singlethread_workqueue(wpan_phy_name(local->phy)); + if (!local->workqueue) { + rc = -ENOMEM; + goto out; + } + + hrtimer_init(&local->ifs_timer, CLOCK_MONOTONIC, HRTIMER_MODE_REL); + local->ifs_timer.function = ieee802154_xmit_ifs_timer; + + wpan_phy_set_dev(local->phy, local->hw.parent); + + ieee802154_setup_wpan_phy_pib(local->phy); + + ieee802154_configure_durations(local->phy); + + if (!(hw->flags & IEEE802154_HW_CSMA_PARAMS)) { + local->phy->supported.min_csma_backoffs = 4; + local->phy->supported.max_csma_backoffs = 4; + local->phy->supported.min_maxbe = 5; + local->phy->supported.max_maxbe = 5; + local->phy->supported.min_minbe = 3; + local->phy->supported.max_minbe = 3; + } + + if (!(hw->flags & IEEE802154_HW_FRAME_RETRIES)) { + local->phy->supported.min_frame_retries = 3; + local->phy->supported.max_frame_retries = 3; + } + + if (hw->flags & IEEE802154_HW_PROMISCUOUS) + local->phy->supported.iftypes |= BIT(NL802154_IFTYPE_MONITOR); + + rc = wpan_phy_register(local->phy); + if (rc < 0) + goto out_wq; + + rtnl_lock(); + + dev = ieee802154_if_add(local, "wpan%d", NET_NAME_ENUM, + NL802154_IFTYPE_NODE, + cpu_to_le64(0x0000000000000000ULL)); + if (IS_ERR(dev)) { + rtnl_unlock(); + rc = PTR_ERR(dev); + goto out_phy; + } + + rtnl_unlock(); + + return 0; + +out_phy: + wpan_phy_unregister(local->phy); +out_wq: + destroy_workqueue(local->workqueue); +out: + return rc; +} +EXPORT_SYMBOL(ieee802154_register_hw); + +void ieee802154_unregister_hw(struct ieee802154_hw *hw) +{ + struct ieee802154_local *local = hw_to_local(hw); + + tasklet_kill(&local->tasklet); + flush_workqueue(local->workqueue); + + rtnl_lock(); + + ieee802154_remove_interfaces(local); + + rtnl_unlock(); + + destroy_workqueue(local->workqueue); + wpan_phy_unregister(local->phy); +} +EXPORT_SYMBOL(ieee802154_unregister_hw); + +static int __init ieee802154_init(void) +{ + return ieee802154_iface_init(); +} + +static void __exit ieee802154_exit(void) +{ + ieee802154_iface_exit(); + + rcu_barrier(); +} + +subsys_initcall(ieee802154_init); +module_exit(ieee802154_exit); + +MODULE_DESCRIPTION("IEEE 802.15.4 subsystem"); +MODULE_LICENSE("GPL v2"); diff --git a/net/mac802154/mib.c b/net/mac802154/mib.c new file mode 100644 index 000000000..81666e1d7 --- /dev/null +++ b/net/mac802154/mib.c @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2007-2012 Siemens AG + * + * Written by: + * Dmitry Eremin-Solenikov + * Sergey Lapin + * Maxim Gorbachyov + * Alexander Smirnov + */ + +#include + +#include +#include +#include + +#include "ieee802154_i.h" +#include "driver-ops.h" + +void mac802154_dev_set_page_channel(struct net_device *dev, u8 page, u8 chan) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + struct ieee802154_local *local = sdata->local; + int res; + 
+ ASSERT_RTNL(); + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + res = drv_set_channel(local, page, chan); + if (res) { + pr_debug("set_channel failed\n"); + } else { + local->phy->current_channel = chan; + local->phy->current_page = page; + } +} + +int mac802154_get_params(struct net_device *dev, + struct ieee802154_llsec_params *params) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_get_params(&sdata->sec, params); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_set_params(struct net_device *dev, + const struct ieee802154_llsec_params *params, + int changed) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_set_params(&sdata->sec, params, changed); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_add_key(struct net_device *dev, + const struct ieee802154_llsec_key_id *id, + const struct ieee802154_llsec_key *key) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_key_add(&sdata->sec, id, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_del_key(struct net_device *dev, + const struct ieee802154_llsec_key_id *id) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_key_del(&sdata->sec, id); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_add_dev(struct net_device *dev, + const struct ieee802154_llsec_device *llsec_dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_dev_add(&sdata->sec, llsec_dev); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_del_dev(struct net_device *dev, __le64 dev_addr) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_dev_del(&sdata->sec, dev_addr); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_add_devkey(struct net_device *dev, + __le64 device_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_devkey_add(&sdata->sec, device_addr, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_del_devkey(struct net_device *dev, + __le64 device_addr, + const struct ieee802154_llsec_device_key *key) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_devkey_del(&sdata->sec, device_addr, key); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_add_seclevel(struct net_device *dev, + const struct ieee802154_llsec_seclevel *sl) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); 
+ res = mac802154_llsec_seclevel_add(&sdata->sec, sl); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +int mac802154_del_seclevel(struct net_device *dev, + const struct ieee802154_llsec_seclevel *sl) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int res; + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); + res = mac802154_llsec_seclevel_del(&sdata->sec, sl); + mutex_unlock(&sdata->sec_mtx); + + return res; +} + +void mac802154_lock_table(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_lock(&sdata->sec_mtx); +} + +void mac802154_get_table(struct net_device *dev, + struct ieee802154_llsec_table **t) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + *t = &sdata->sec.table; +} + +void mac802154_unlock_table(struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + BUG_ON(dev->type != ARPHRD_IEEE802154); + + mutex_unlock(&sdata->sec_mtx); +} diff --git a/net/mac802154/rx.c b/net/mac802154/rx.c new file mode 100644 index 000000000..726b47a46 --- /dev/null +++ b/net/mac802154/rx.c @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2007-2012 Siemens AG + * + * Written by: + * Pavel Smolenskiy + * Maxim Gorbachyov + * Dmitry Eremin-Solenikov + * Alexander Smirnov + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ieee802154_i.h" + +static int ieee802154_deliver_skb(struct sk_buff *skb) +{ + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->protocol = htons(ETH_P_IEEE802154); + + return netif_receive_skb(skb); +} + +static int +ieee802154_subif_frame(struct ieee802154_sub_if_data *sdata, + struct sk_buff *skb, const struct ieee802154_hdr *hdr) +{ + struct wpan_dev *wpan_dev = &sdata->wpan_dev; + __le16 span, sshort; + int rc; + + pr_debug("getting packet via slave interface %s\n", sdata->dev->name); + + span = wpan_dev->pan_id; + sshort = wpan_dev->short_addr; + + switch (mac_cb(skb)->dest.mode) { + case IEEE802154_ADDR_NONE: + if (hdr->source.mode != IEEE802154_ADDR_NONE) + /* FIXME: check if we are PAN coordinator */ + skb->pkt_type = PACKET_OTHERHOST; + else + /* ACK comes with both addresses empty */ + skb->pkt_type = PACKET_HOST; + break; + case IEEE802154_ADDR_LONG: + if (mac_cb(skb)->dest.pan_id != span && + mac_cb(skb)->dest.pan_id != cpu_to_le16(IEEE802154_PANID_BROADCAST)) + skb->pkt_type = PACKET_OTHERHOST; + else if (mac_cb(skb)->dest.extended_addr == wpan_dev->extended_addr) + skb->pkt_type = PACKET_HOST; + else + skb->pkt_type = PACKET_OTHERHOST; + break; + case IEEE802154_ADDR_SHORT: + if (mac_cb(skb)->dest.pan_id != span && + mac_cb(skb)->dest.pan_id != cpu_to_le16(IEEE802154_PANID_BROADCAST)) + skb->pkt_type = PACKET_OTHERHOST; + else if (mac_cb(skb)->dest.short_addr == sshort) + skb->pkt_type = PACKET_HOST; + else if (mac_cb(skb)->dest.short_addr == + cpu_to_le16(IEEE802154_ADDR_BROADCAST)) + skb->pkt_type = PACKET_BROADCAST; + else + skb->pkt_type = PACKET_OTHERHOST; + break; + default: + pr_debug("invalid dest mode\n"); + goto fail; + } + + skb->dev = sdata->dev; + + /* TODO this should be moved after netif_receive_skb call, otherwise + * wireshark will show a mac header with security fields and the + * payload is already decrypted. 
+ */ + rc = mac802154_llsec_decrypt(&sdata->sec, skb); + if (rc) { + pr_debug("decryption failed: %i\n", rc); + goto fail; + } + + sdata->dev->stats.rx_packets++; + sdata->dev->stats.rx_bytes += skb->len; + + switch (mac_cb(skb)->type) { + case IEEE802154_FC_TYPE_BEACON: + case IEEE802154_FC_TYPE_ACK: + case IEEE802154_FC_TYPE_MAC_CMD: + goto fail; + + case IEEE802154_FC_TYPE_DATA: + return ieee802154_deliver_skb(skb); + default: + pr_warn_ratelimited("ieee802154: bad frame received " + "(type = %d)\n", mac_cb(skb)->type); + goto fail; + } + +fail: + kfree_skb(skb); + return NET_RX_DROP; +} + +static void +ieee802154_print_addr(const char *name, const struct ieee802154_addr *addr) +{ + if (addr->mode == IEEE802154_ADDR_NONE) + pr_debug("%s not present\n", name); + + pr_debug("%s PAN ID: %04x\n", name, le16_to_cpu(addr->pan_id)); + if (addr->mode == IEEE802154_ADDR_SHORT) { + pr_debug("%s is short: %04x\n", name, + le16_to_cpu(addr->short_addr)); + } else { + u64 hw = swab64((__force u64)addr->extended_addr); + + pr_debug("%s is hardware: %8phC\n", name, &hw); + } +} + +static int +ieee802154_parse_frame_start(struct sk_buff *skb, struct ieee802154_hdr *hdr) +{ + int hlen; + struct ieee802154_mac_cb *cb = mac_cb(skb); + + skb_reset_mac_header(skb); + + hlen = ieee802154_hdr_pull(skb, hdr); + if (hlen < 0) + return -EINVAL; + + skb->mac_len = hlen; + + pr_debug("fc: %04x dsn: %02x\n", le16_to_cpup((__le16 *)&hdr->fc), + hdr->seq); + + cb->type = hdr->fc.type; + cb->ackreq = hdr->fc.ack_request; + cb->secen = hdr->fc.security_enabled; + + ieee802154_print_addr("destination", &hdr->dest); + ieee802154_print_addr("source", &hdr->source); + + cb->source = hdr->source; + cb->dest = hdr->dest; + + if (hdr->fc.security_enabled) { + u64 key; + + pr_debug("seclevel %i\n", hdr->sec.level); + + switch (hdr->sec.key_id_mode) { + case IEEE802154_SCF_KEY_IMPLICIT: + pr_debug("implicit key\n"); + break; + + case IEEE802154_SCF_KEY_INDEX: + pr_debug("key %02x\n", hdr->sec.key_id); + break; + + case IEEE802154_SCF_KEY_SHORT_INDEX: + pr_debug("key %04x:%04x %02x\n", + le32_to_cpu(hdr->sec.short_src) >> 16, + le32_to_cpu(hdr->sec.short_src) & 0xffff, + hdr->sec.key_id); + break; + + case IEEE802154_SCF_KEY_HW_INDEX: + key = swab64((__force u64)hdr->sec.extended_src); + pr_debug("key source %8phC %02x\n", &key, + hdr->sec.key_id); + break; + } + } + + return 0; +} + +static void +__ieee802154_rx_handle_packet(struct ieee802154_local *local, + struct sk_buff *skb) +{ + int ret; + struct ieee802154_sub_if_data *sdata; + struct ieee802154_hdr hdr; + + ret = ieee802154_parse_frame_start(skb, &hdr); + if (ret) { + pr_debug("got invalid frame\n"); + kfree_skb(skb); + return; + } + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (sdata->wpan_dev.iftype != NL802154_IFTYPE_NODE) + continue; + + if (!ieee802154_sdata_running(sdata)) + continue; + + ieee802154_subif_frame(sdata, skb, &hdr); + skb = NULL; + break; + } + + kfree_skb(skb); +} + +static void +ieee802154_monitors_rx(struct ieee802154_local *local, struct sk_buff *skb) +{ + struct sk_buff *skb2; + struct ieee802154_sub_if_data *sdata; + + skb_reset_mac_header(skb); + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->pkt_type = PACKET_OTHERHOST; + skb->protocol = htons(ETH_P_IEEE802154); + + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (sdata->wpan_dev.iftype != NL802154_IFTYPE_MONITOR) + continue; + + if (!ieee802154_sdata_running(sdata)) + continue; + + skb2 = skb_clone(skb, GFP_ATOMIC); + if (skb2) { + skb2->dev = sdata->dev; 
+ ieee802154_deliver_skb(skb2); + + sdata->dev->stats.rx_packets++; + sdata->dev->stats.rx_bytes += skb->len; + } + } +} + +void ieee802154_rx(struct ieee802154_local *local, struct sk_buff *skb) +{ + u16 crc; + + WARN_ON_ONCE(softirq_count() == 0); + + if (local->suspended) + goto drop; + + /* TODO: When a transceiver omits the checksum here, we + * add an own calculated one. This is currently an ugly + * solution because the monitor needs a crc here. + */ + if (local->hw.flags & IEEE802154_HW_RX_OMIT_CKSUM) { + crc = crc_ccitt(0, skb->data, skb->len); + put_unaligned_le16(crc, skb_put(skb, 2)); + } + + rcu_read_lock(); + + ieee802154_monitors_rx(local, skb); + + /* Check if transceiver doesn't validate the checksum. + * If not we validate the checksum here. + */ + if (local->hw.flags & IEEE802154_HW_RX_DROP_BAD_CKSUM) { + crc = crc_ccitt(0, skb->data, skb->len); + if (crc) { + rcu_read_unlock(); + goto drop; + } + } + /* remove crc */ + skb_trim(skb, skb->len - 2); + + __ieee802154_rx_handle_packet(local, skb); + + rcu_read_unlock(); + + return; +drop: + kfree_skb(skb); +} + +void +ieee802154_rx_irqsafe(struct ieee802154_hw *hw, struct sk_buff *skb, u8 lqi) +{ + struct ieee802154_local *local = hw_to_local(hw); + struct ieee802154_mac_cb *cb = mac_cb_init(skb); + + cb->lqi = lqi; + skb->pkt_type = IEEE802154_RX_MSG; + skb_queue_tail(&local->skb_queue, skb); + tasklet_schedule(&local->tasklet); +} +EXPORT_SYMBOL(ieee802154_rx_irqsafe); diff --git a/net/mac802154/trace.c b/net/mac802154/trace.c new file mode 100644 index 000000000..c36e3d541 --- /dev/null +++ b/net/mac802154/trace.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#ifndef __CHECKER__ +#include +#include "driver-ops.h" +#define CREATE_TRACE_POINTS +#include "trace.h" + +#endif diff --git a/net/mac802154/trace.h b/net/mac802154/trace.h new file mode 100644 index 000000000..df855c33d --- /dev/null +++ b/net/mac802154/trace.h @@ -0,0 +1,273 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Based on net/mac80211/trace.h */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM mac802154 + +#if !defined(__MAC802154_DRIVER_TRACE) || defined(TRACE_HEADER_MULTI_READ) +#define __MAC802154_DRIVER_TRACE + +#include + +#include +#include "ieee802154_i.h" + +#define MAXNAME 32 +#define LOCAL_ENTRY __array(char, wpan_phy_name, MAXNAME) +#define LOCAL_ASSIGN strlcpy(__entry->wpan_phy_name, \ + wpan_phy_name(local->hw.phy), MAXNAME) +#define LOCAL_PR_FMT "%s" +#define LOCAL_PR_ARG __entry->wpan_phy_name + +#define CCA_ENTRY __field(enum nl802154_cca_modes, cca_mode) \ + __field(enum nl802154_cca_opts, cca_opt) +#define CCA_ASSIGN \ + do { \ + (__entry->cca_mode) = cca->mode; \ + (__entry->cca_opt) = cca->opt; \ + } while (0) +#define CCA_PR_FMT "cca_mode: %d, cca_opt: %d" +#define CCA_PR_ARG __entry->cca_mode, __entry->cca_opt + +#define BOOL_TO_STR(bo) (bo) ? 
"true" : "false" + +/* Tracing for driver callbacks */ + +DECLARE_EVENT_CLASS(local_only_evt4, + TP_PROTO(struct ieee802154_local *local), + TP_ARGS(local), + TP_STRUCT__entry( + LOCAL_ENTRY + ), + TP_fast_assign( + LOCAL_ASSIGN; + ), + TP_printk(LOCAL_PR_FMT, LOCAL_PR_ARG) +); + +DEFINE_EVENT(local_only_evt4, 802154_drv_return_void, + TP_PROTO(struct ieee802154_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(802154_drv_return_int, + TP_PROTO(struct ieee802154_local *local, int ret), + TP_ARGS(local, ret), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(int, ret) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->ret = ret; + ), + TP_printk(LOCAL_PR_FMT ", returned: %d", LOCAL_PR_ARG, + __entry->ret) +); + +DEFINE_EVENT(local_only_evt4, 802154_drv_start, + TP_PROTO(struct ieee802154_local *local), + TP_ARGS(local) +); + +DEFINE_EVENT(local_only_evt4, 802154_drv_stop, + TP_PROTO(struct ieee802154_local *local), + TP_ARGS(local) +); + +TRACE_EVENT(802154_drv_set_channel, + TP_PROTO(struct ieee802154_local *local, u8 page, u8 channel), + TP_ARGS(local, page, channel), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u8, page) + __field(u8, channel) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->page = page; + __entry->channel = channel; + ), + TP_printk(LOCAL_PR_FMT ", page: %d, channel: %d", LOCAL_PR_ARG, + __entry->page, __entry->channel) +); + +TRACE_EVENT(802154_drv_set_cca_mode, + TP_PROTO(struct ieee802154_local *local, + const struct wpan_phy_cca *cca), + TP_ARGS(local, cca), + TP_STRUCT__entry( + LOCAL_ENTRY + CCA_ENTRY + ), + TP_fast_assign( + LOCAL_ASSIGN; + CCA_ASSIGN; + ), + TP_printk(LOCAL_PR_FMT ", " CCA_PR_FMT, LOCAL_PR_ARG, + CCA_PR_ARG) +); + +TRACE_EVENT(802154_drv_set_cca_ed_level, + TP_PROTO(struct ieee802154_local *local, s32 mbm), + TP_ARGS(local, mbm), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(s32, mbm) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->mbm = mbm; + ), + TP_printk(LOCAL_PR_FMT ", ed level: %d", LOCAL_PR_ARG, + __entry->mbm) +); + +TRACE_EVENT(802154_drv_set_tx_power, + TP_PROTO(struct ieee802154_local *local, s32 power), + TP_ARGS(local, power), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(s32, power) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->power = power; + ), + TP_printk(LOCAL_PR_FMT ", mbm: %d", LOCAL_PR_ARG, + __entry->power) +); + +TRACE_EVENT(802154_drv_set_lbt_mode, + TP_PROTO(struct ieee802154_local *local, bool mode), + TP_ARGS(local, mode), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, mode) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->mode = mode; + ), + TP_printk(LOCAL_PR_FMT ", lbt mode: %s", LOCAL_PR_ARG, + BOOL_TO_STR(__entry->mode)) +); + +TRACE_EVENT(802154_drv_set_short_addr, + TP_PROTO(struct ieee802154_local *local, __le16 short_addr), + TP_ARGS(local, short_addr), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(__le16, short_addr) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->short_addr = short_addr; + ), + TP_printk(LOCAL_PR_FMT ", short addr: 0x%04x", LOCAL_PR_ARG, + le16_to_cpu(__entry->short_addr)) +); + +TRACE_EVENT(802154_drv_set_pan_id, + TP_PROTO(struct ieee802154_local *local, __le16 pan_id), + TP_ARGS(local, pan_id), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(__le16, pan_id) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->pan_id = pan_id; + ), + TP_printk(LOCAL_PR_FMT ", pan id: 0x%04x", LOCAL_PR_ARG, + le16_to_cpu(__entry->pan_id)) +); + +TRACE_EVENT(802154_drv_set_extended_addr, + TP_PROTO(struct ieee802154_local *local, __le64 extended_addr), + TP_ARGS(local, extended_addr), + TP_STRUCT__entry( + 
LOCAL_ENTRY + __field(__le64, extended_addr) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->extended_addr = extended_addr; + ), + TP_printk(LOCAL_PR_FMT ", extended addr: 0x%llx", LOCAL_PR_ARG, + le64_to_cpu(__entry->extended_addr)) +); + +TRACE_EVENT(802154_drv_set_pan_coord, + TP_PROTO(struct ieee802154_local *local, bool is_coord), + TP_ARGS(local, is_coord), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, is_coord) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->is_coord = is_coord; + ), + TP_printk(LOCAL_PR_FMT ", is_coord: %s", LOCAL_PR_ARG, + BOOL_TO_STR(__entry->is_coord)) +); + +TRACE_EVENT(802154_drv_set_csma_params, + TP_PROTO(struct ieee802154_local *local, u8 min_be, u8 max_be, + u8 max_csma_backoffs), + TP_ARGS(local, min_be, max_be, max_csma_backoffs), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(u8, min_be) + __field(u8, max_be) + __field(u8, max_csma_backoffs) + ), + TP_fast_assign( + LOCAL_ASSIGN, + __entry->min_be = min_be; + __entry->max_be = max_be; + __entry->max_csma_backoffs = max_csma_backoffs; + ), + TP_printk(LOCAL_PR_FMT ", min be: %d, max be: %d, max csma backoffs: %d", + LOCAL_PR_ARG, __entry->min_be, __entry->max_be, + __entry->max_csma_backoffs) +); + +TRACE_EVENT(802154_drv_set_max_frame_retries, + TP_PROTO(struct ieee802154_local *local, s8 max_frame_retries), + TP_ARGS(local, max_frame_retries), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(s8, max_frame_retries) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->max_frame_retries = max_frame_retries; + ), + TP_printk(LOCAL_PR_FMT ", max frame retries: %d", LOCAL_PR_ARG, + __entry->max_frame_retries) +); + +TRACE_EVENT(802154_drv_set_promiscuous_mode, + TP_PROTO(struct ieee802154_local *local, bool on), + TP_ARGS(local, on), + TP_STRUCT__entry( + LOCAL_ENTRY + __field(bool, on) + ), + TP_fast_assign( + LOCAL_ASSIGN; + __entry->on = on; + ), + TP_printk(LOCAL_PR_FMT ", promiscuous mode: %s", LOCAL_PR_ARG, + BOOL_TO_STR(__entry->on)) +); + +#endif /* !__MAC802154_DRIVER_TRACE || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/mac802154/tx.c b/net/mac802154/tx.c new file mode 100644 index 000000000..c829e4a75 --- /dev/null +++ b/net/mac802154/tx.c @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2007-2012 Siemens AG + * + * Written by: + * Dmitry Eremin-Solenikov + * Sergey Lapin + * Maxim Gorbachyov + * Alexander Smirnov + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "ieee802154_i.h" +#include "driver-ops.h" + +void ieee802154_xmit_worker(struct work_struct *work) +{ + struct ieee802154_local *local = + container_of(work, struct ieee802154_local, tx_work); + struct sk_buff *skb = local->tx_skb; + struct net_device *dev = skb->dev; + int res; + + res = drv_xmit_sync(local, skb); + if (res) + goto err_tx; + + dev->stats.tx_packets++; + dev->stats.tx_bytes += skb->len; + + ieee802154_xmit_complete(&local->hw, skb, false); + + return; + +err_tx: + /* Restart the netif queue on each sub_if_data object. 
*/ + ieee802154_wake_queue(&local->hw); + kfree_skb(skb); + netdev_dbg(dev, "transmission failed\n"); +} + +static netdev_tx_t +ieee802154_tx(struct ieee802154_local *local, struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + int ret; + + if (!(local->hw.flags & IEEE802154_HW_TX_OMIT_CKSUM)) { + struct sk_buff *nskb; + u16 crc; + + if (unlikely(skb_tailroom(skb) < IEEE802154_FCS_LEN)) { + nskb = skb_copy_expand(skb, 0, IEEE802154_FCS_LEN, + GFP_ATOMIC); + if (likely(nskb)) { + consume_skb(skb); + skb = nskb; + } else { + goto err_tx; + } + } + + crc = crc_ccitt(0, skb->data, skb->len); + put_unaligned_le16(crc, skb_put(skb, 2)); + } + + /* Stop the netif queue on each sub_if_data object. */ + ieee802154_stop_queue(&local->hw); + + /* async is priority, otherwise sync is fallback */ + if (local->ops->xmit_async) { + unsigned int len = skb->len; + + ret = drv_xmit_async(local, skb); + if (ret) { + ieee802154_wake_queue(&local->hw); + goto err_tx; + } + + dev->stats.tx_packets++; + dev->stats.tx_bytes += len; + } else { + local->tx_skb = skb; + queue_work(local->workqueue, &local->tx_work); + } + + return NETDEV_TX_OK; + +err_tx: + kfree_skb(skb); + return NETDEV_TX_OK; +} + +netdev_tx_t +ieee802154_monitor_start_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + + skb->skb_iif = dev->ifindex; + + return ieee802154_tx(sdata->local, skb); +} + +netdev_tx_t +ieee802154_subif_start_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct ieee802154_sub_if_data *sdata = IEEE802154_DEV_TO_SUB_IF(dev); + int rc; + + /* TODO we should move it to wpan_dev_hard_header and dev_hard_header + * functions. The reason is wireshark will show a mac header which is + * with security fields but the payload is not encrypted. 
+ */ + rc = mac802154_llsec_encrypt(&sdata->sec, skb); + if (rc) { + netdev_warn(dev, "encryption failed: %i\n", rc); + kfree_skb(skb); + return NETDEV_TX_OK; + } + + skb->skb_iif = dev->ifindex; + + return ieee802154_tx(sdata->local, skb); +} diff --git a/net/mac802154/util.c b/net/mac802154/util.c new file mode 100644 index 000000000..9f024d855 --- /dev/null +++ b/net/mac802154/util.c @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Authors: + * Alexander Aring + * + * Based on: net/mac80211/util.c + */ + +#include "ieee802154_i.h" +#include "driver-ops.h" + +/* privid for wpan_phys to determine whether they belong to us or not */ +const void *const mac802154_wpan_phy_privid = &mac802154_wpan_phy_privid; + +void ieee802154_wake_queue(struct ieee802154_hw *hw) +{ + struct ieee802154_local *local = hw_to_local(hw); + struct ieee802154_sub_if_data *sdata; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!sdata->dev) + continue; + + netif_wake_queue(sdata->dev); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee802154_wake_queue); + +void ieee802154_stop_queue(struct ieee802154_hw *hw) +{ + struct ieee802154_local *local = hw_to_local(hw); + struct ieee802154_sub_if_data *sdata; + + rcu_read_lock(); + list_for_each_entry_rcu(sdata, &local->interfaces, list) { + if (!sdata->dev) + continue; + + netif_stop_queue(sdata->dev); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(ieee802154_stop_queue); + +enum hrtimer_restart ieee802154_xmit_ifs_timer(struct hrtimer *timer) +{ + struct ieee802154_local *local = + container_of(timer, struct ieee802154_local, ifs_timer); + + ieee802154_wake_queue(&local->hw); + + return HRTIMER_NORESTART; +} + +void ieee802154_xmit_complete(struct ieee802154_hw *hw, struct sk_buff *skb, + bool ifs_handling) +{ + struct ieee802154_local *local = hw_to_local(hw); + + local->tx_result = IEEE802154_SUCCESS; + + if (ifs_handling) { + u8 max_sifs_size; + + /* If transceiver sets CRC on his own we need to use lifs + * threshold len above 16 otherwise 18, because it's not + * part of skb->len. 
+ */ + if (hw->flags & IEEE802154_HW_TX_OMIT_CKSUM) + max_sifs_size = IEEE802154_MAX_SIFS_FRAME_SIZE - + IEEE802154_FCS_LEN; + else + max_sifs_size = IEEE802154_MAX_SIFS_FRAME_SIZE; + + if (skb->len > max_sifs_size) + hrtimer_start(&local->ifs_timer, + hw->phy->lifs_period * NSEC_PER_USEC, + HRTIMER_MODE_REL); + else + hrtimer_start(&local->ifs_timer, + hw->phy->sifs_period * NSEC_PER_USEC, + HRTIMER_MODE_REL); + } else { + ieee802154_wake_queue(hw); + } + + dev_consume_skb_any(skb); +} +EXPORT_SYMBOL(ieee802154_xmit_complete); + +void ieee802154_xmit_error(struct ieee802154_hw *hw, struct sk_buff *skb, + int reason) +{ + struct ieee802154_local *local = hw_to_local(hw); + + local->tx_result = reason; + ieee802154_wake_queue(hw); + dev_kfree_skb_any(skb); +} +EXPORT_SYMBOL(ieee802154_xmit_error); + +void ieee802154_xmit_hw_error(struct ieee802154_hw *hw, struct sk_buff *skb) +{ + ieee802154_xmit_error(hw, skb, IEEE802154_SYSTEM_ERROR); +} +EXPORT_SYMBOL(ieee802154_xmit_hw_error); + +void ieee802154_stop_device(struct ieee802154_local *local) +{ + flush_workqueue(local->workqueue); + hrtimer_cancel(&local->ifs_timer); + drv_stop(local); +} diff --git a/net/mctp/Kconfig b/net/mctp/Kconfig new file mode 100644 index 000000000..3a5c0e70d --- /dev/null +++ b/net/mctp/Kconfig @@ -0,0 +1,23 @@ + +menuconfig MCTP + depends on NET + bool "MCTP core protocol support" + help + Management Component Transport Protocol (MCTP) is an in-system + protocol for communicating between management controllers and + their managed devices (peripherals, host processors, etc.). The + protocol is defined by DMTF specification DSP0236. + + This option enables core MCTP support. For communicating with other + devices, you'll want to enable a driver for a specific hardware + channel. 
+ +config MCTP_TEST + bool "MCTP core tests" if !KUNIT_ALL_TESTS + depends on MCTP=y && KUNIT=y + default KUNIT_ALL_TESTS + +config MCTP_FLOWS + bool + depends on MCTP + select SKB_EXTENSIONS diff --git a/net/mctp/Makefile b/net/mctp/Makefile new file mode 100644 index 000000000..6cd55233e --- /dev/null +++ b/net/mctp/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_MCTP) += mctp.o +mctp-objs := af_mctp.o device.o route.o neigh.o + +# tests +obj-$(CONFIG_MCTP_TEST) += test/utils.o diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c new file mode 100644 index 000000000..3150f3f0c --- /dev/null +++ b/net/mctp/af_mctp.c @@ -0,0 +1,710 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define CREATE_TRACE_POINTS +#include + +/* socket implementation */ + +static void mctp_sk_expire_keys(struct timer_list *timer); + +static int mctp_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + sock->sk = NULL; + sk->sk_prot->close(sk, 0); + } + + return 0; +} + +/* Generic sockaddr checks, padding checks only so far */ +static bool mctp_sockaddr_is_ok(const struct sockaddr_mctp *addr) +{ + return !addr->__smctp_pad0 && !addr->__smctp_pad1; +} + +static bool mctp_sockaddr_ext_is_ok(const struct sockaddr_mctp_ext *addr) +{ + return !addr->__smctp_pad0[0] && + !addr->__smctp_pad0[1] && + !addr->__smctp_pad0[2]; +} + +static int mctp_bind(struct socket *sock, struct sockaddr *addr, int addrlen) +{ + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct sockaddr_mctp *smctp; + int rc; + + if (addrlen < sizeof(*smctp)) + return -EINVAL; + + if (addr->sa_family != AF_MCTP) + return -EAFNOSUPPORT; + + if (!capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + /* it's a valid sockaddr for MCTP, cast and do protocol checks */ + smctp = (struct sockaddr_mctp *)addr; + + if (!mctp_sockaddr_is_ok(smctp)) + return -EINVAL; + + lock_sock(sk); + + /* TODO: allow rebind */ + if (sk_hashed(sk)) { + rc = -EADDRINUSE; + goto out_release; + } + msk->bind_net = smctp->smctp_network; + msk->bind_addr = smctp->smctp_addr.s_addr; + msk->bind_type = smctp->smctp_type & 0x7f; /* ignore the IC bit */ + + rc = sk->sk_prot->hash(sk); + +out_release: + release_sock(sk); + + return rc; +} + +static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); + int rc, addrlen = msg->msg_namelen; + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct mctp_skb_cb *cb; + struct mctp_route *rt; + struct sk_buff *skb = NULL; + int hlen; + + if (addr) { + const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER | + MCTP_TAG_PREALLOC; + + if (addrlen < sizeof(struct sockaddr_mctp)) + return -EINVAL; + if (addr->smctp_family != AF_MCTP) + return -EINVAL; + if (!mctp_sockaddr_is_ok(addr)) + return -EINVAL; + if (addr->smctp_tag & ~tagbits) + return -EINVAL; + /* can't preallocate a non-owned tag */ + if (addr->smctp_tag & MCTP_TAG_PREALLOC && + !(addr->smctp_tag & MCTP_TAG_OWNER)) + return -EINVAL; + + } else { + /* TODO: connect()ed sockets */ + return -EDESTADDRREQ; + } + + if (!capable(CAP_NET_RAW)) + return -EACCES; + + if (addr->smctp_network == MCTP_NET_ANY) + addr->smctp_network = 
mctp_default_net(sock_net(sk)); + + /* direct addressing */ + if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) { + DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, + extaddr, msg->msg_name); + struct net_device *dev; + + rc = -EINVAL; + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), extaddr->smctp_ifindex); + /* check for correct halen */ + if (dev && extaddr->smctp_halen == dev->addr_len) { + hlen = LL_RESERVED_SPACE(dev) + sizeof(struct mctp_hdr); + rc = 0; + } + rcu_read_unlock(); + if (rc) + goto err_free; + rt = NULL; + } else { + rt = mctp_route_lookup(sock_net(sk), addr->smctp_network, + addr->smctp_addr.s_addr); + if (!rt) { + rc = -EHOSTUNREACH; + goto err_free; + } + hlen = LL_RESERVED_SPACE(rt->dev->dev) + sizeof(struct mctp_hdr); + } + + skb = sock_alloc_send_skb(sk, hlen + 1 + len, + msg->msg_flags & MSG_DONTWAIT, &rc); + if (!skb) + return rc; + + skb_reserve(skb, hlen); + + /* set type as fist byte in payload */ + *(u8 *)skb_put(skb, 1) = addr->smctp_type; + + rc = memcpy_from_msg((void *)skb_put(skb, len), msg, len); + if (rc < 0) + goto err_free; + + /* set up cb */ + cb = __mctp_cb(skb); + cb->net = addr->smctp_network; + + if (!rt) { + /* fill extended address in cb */ + DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, + extaddr, msg->msg_name); + + if (!mctp_sockaddr_ext_is_ok(extaddr) || + extaddr->smctp_halen > sizeof(cb->haddr)) { + rc = -EINVAL; + goto err_free; + } + + cb->ifindex = extaddr->smctp_ifindex; + /* smctp_halen is checked above */ + cb->halen = extaddr->smctp_halen; + memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen); + } + + rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr, + addr->smctp_tag); + + return rc ? : len; + +err_free: + kfree_skb(skb); + return rc; +} + +static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name); + struct sock *sk = sock->sk; + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct sk_buff *skb; + size_t msglen; + u8 type; + int rc; + + if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK)) + return -EOPNOTSUPP; + + skb = skb_recv_datagram(sk, flags, &rc); + if (!skb) + return rc; + + if (!skb->len) { + rc = 0; + goto out_free; + } + + /* extract message type, remove from data */ + type = *((u8 *)skb->data); + msglen = skb->len - 1; + + if (len < msglen) + msg->msg_flags |= MSG_TRUNC; + else + len = msglen; + + rc = skb_copy_datagram_msg(skb, 1, msg, len); + if (rc < 0) + goto out_free; + + sock_recv_cmsgs(msg, sk, skb); + + if (addr) { + struct mctp_skb_cb *cb = mctp_cb(skb); + /* TODO: expand mctp_skb_cb for header fields? 
*/ + struct mctp_hdr *hdr = mctp_hdr(skb); + + addr = msg->msg_name; + addr->smctp_family = AF_MCTP; + addr->__smctp_pad0 = 0; + addr->smctp_network = cb->net; + addr->smctp_addr.s_addr = hdr->src; + addr->smctp_type = type; + addr->smctp_tag = hdr->flags_seq_tag & + (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + addr->__smctp_pad1 = 0; + msg->msg_namelen = sizeof(*addr); + + if (msk->addr_ext) { + DECLARE_SOCKADDR(struct sockaddr_mctp_ext *, ae, + msg->msg_name); + msg->msg_namelen = sizeof(*ae); + ae->smctp_ifindex = cb->ifindex; + ae->smctp_halen = cb->halen; + memset(ae->__smctp_pad0, 0x0, sizeof(ae->__smctp_pad0)); + memset(ae->smctp_haddr, 0x0, sizeof(ae->smctp_haddr)); + memcpy(ae->smctp_haddr, cb->haddr, cb->halen); + } + } + + rc = len; + + if (flags & MSG_TRUNC) + rc = msglen; + +out_free: + skb_free_datagram(sk, skb); + return rc; +} + +/* We're done with the key; invalidate, stop reassembly, and remove from lists. + */ +static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net, + unsigned long flags, unsigned long reason) +__releases(&key->lock) +__must_hold(&net->mctp.keys_lock) +{ + struct sk_buff *skb; + + trace_mctp_key_release(key, reason); + skb = key->reasm_head; + key->reasm_head = NULL; + key->reasm_dead = true; + key->valid = false; + mctp_dev_release_key(key->dev, key); + spin_unlock_irqrestore(&key->lock, flags); + + if (!hlist_unhashed(&key->hlist)) { + hlist_del_init(&key->hlist); + hlist_del_init(&key->sklist); + /* unref for the lists */ + mctp_key_unref(key); + } + + kfree_skb(skb); +} + +static int mctp_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + int val; + + if (level != SOL_MCTP) + return -EINVAL; + + if (optname == MCTP_OPT_ADDR_EXT) { + if (optlen != sizeof(int)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + msk->addr_ext = val; + return 0; + } + + return -ENOPROTOOPT; +} + +static int mctp_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + int len, val; + + if (level != SOL_MCTP) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + + if (optname == MCTP_OPT_ADDR_EXT) { + if (len != sizeof(int)) + return -EINVAL; + val = !!msk->addr_ext; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; + } + + return -EINVAL; +} + +static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_sk_key *key = NULL; + struct mctp_ioc_tag_ctl ctl; + unsigned long flags; + u8 tag; + + if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl))) + return -EFAULT; + + if (ctl.tag) + return -EINVAL; + + if (ctl.flags) + return -EINVAL; + + key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY, + true, &tag); + if (IS_ERR(key)) + return PTR_ERR(key); + + ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC; + if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) { + unsigned long fl2; + /* Unwind our key allocation: the keys list lock needs to be + * taken before the individual key locks, and we need a valid + * flags value (fl2) to pass to __mctp_key_remove, hence the + * second spin_lock_irqsave() rather than a plain spin_lock(). 
+ */ + spin_lock_irqsave(&net->mctp.keys_lock, flags); + spin_lock_irqsave(&key->lock, fl2); + __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_DROPPED); + mctp_key_unref(key); + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + return -EFAULT; + } + + mctp_key_unref(key); + return 0; +} + +static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_ioc_tag_ctl ctl; + unsigned long flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + int rc; + u8 tag; + + if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl))) + return -EFAULT; + + if (ctl.flags) + return -EINVAL; + + /* Must be a local tag, TO set, preallocated */ + if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC)) + return -EINVAL; + + tag = ctl.tag & MCTP_TAG_MASK; + rc = -EINVAL; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + /* we do an irqsave here, even though we know the irq state, + * so we have the flags to pass to __mctp_key_remove + */ + spin_lock_irqsave(&key->lock, fl2); + if (key->manual_alloc && + ctl.peer_addr == key->peer_addr && + tag == key->tag) { + __mctp_key_remove(key, net, fl2, + MCTP_TRACE_KEY_DROPPED); + rc = 0; + } else { + spin_unlock_irqrestore(&key->lock, fl2); + } + } + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + return rc; +} + +static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk); + + switch (cmd) { + case SIOCMCTPALLOCTAG: + return mctp_ioctl_alloctag(msk, arg); + case SIOCMCTPDROPTAG: + return mctp_ioctl_droptag(msk, arg); + } + + return -EINVAL; +} + +#ifdef CONFIG_COMPAT +static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + + switch (cmd) { + /* These have compatible ptr layouts */ + case SIOCMCTPALLOCTAG: + case SIOCMCTPDROPTAG: + return mctp_ioctl(sock, cmd, (unsigned long)argp); + } + + return -ENOIOCTLCMD; +} +#endif + +static const struct proto_ops mctp_dgram_ops = { + .family = PF_MCTP, + .release = mctp_release, + .bind = mctp_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = mctp_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = mctp_setsockopt, + .getsockopt = mctp_getsockopt, + .sendmsg = mctp_sendmsg, + .recvmsg = mctp_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = mctp_compat_ioctl, +#endif +}; + +static void mctp_sk_expire_keys(struct timer_list *timer) +{ + struct mctp_sock *msk = container_of(timer, struct mctp_sock, + key_expiry); + struct net *net = sock_net(&msk->sk); + unsigned long next_expiry, flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + bool next_expiry_valid = false; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + /* don't expire. manual_alloc is immutable, no locking + * required. 
+ */ + if (key->manual_alloc) + continue; + + spin_lock_irqsave(&key->lock, fl2); + if (!time_after_eq(key->expiry, jiffies)) { + __mctp_key_remove(key, net, fl2, + MCTP_TRACE_KEY_TIMEOUT); + continue; + } + + if (next_expiry_valid) { + if (time_before(key->expiry, next_expiry)) + next_expiry = key->expiry; + } else { + next_expiry = key->expiry; + next_expiry_valid = true; + } + spin_unlock_irqrestore(&key->lock, fl2); + } + + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + if (next_expiry_valid) + mod_timer(timer, next_expiry); +} + +static int mctp_sk_init(struct sock *sk) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + INIT_HLIST_HEAD(&msk->keys); + timer_setup(&msk->key_expiry, mctp_sk_expire_keys, 0); + return 0; +} + +static void mctp_sk_close(struct sock *sk, long timeout) +{ + sk_common_release(sk); +} + +static int mctp_sk_hash(struct sock *sk) +{ + struct net *net = sock_net(sk); + + mutex_lock(&net->mctp.bind_lock); + sk_add_node_rcu(sk, &net->mctp.binds); + mutex_unlock(&net->mctp.bind_lock); + + return 0; +} + +static void mctp_sk_unhash(struct sock *sk) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct net *net = sock_net(sk); + unsigned long flags, fl2; + struct mctp_sk_key *key; + struct hlist_node *tmp; + + /* remove from any type-based binds */ + mutex_lock(&net->mctp.bind_lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&net->mctp.bind_lock); + + /* remove tag allocations */ + spin_lock_irqsave(&net->mctp.keys_lock, flags); + hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) { + spin_lock_irqsave(&key->lock, fl2); + __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED); + } + sock_set_flag(sk, SOCK_DEAD); + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + /* Since there are no more tag allocations (we have removed all of the + * keys), stop any pending expiry events. 
the timer cannot be re-queued + * as the sk is no longer observable + */ + del_timer_sync(&msk->key_expiry); +} + +static void mctp_sk_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); +} + +static struct proto mctp_proto = { + .name = "MCTP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct mctp_sock), + .init = mctp_sk_init, + .close = mctp_sk_close, + .hash = mctp_sk_hash, + .unhash = mctp_sk_unhash, +}; + +static int mctp_pf_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + const struct proto_ops *ops; + struct proto *proto; + struct sock *sk; + int rc; + + if (protocol) + return -EPROTONOSUPPORT; + + /* only datagram sockets are supported */ + if (sock->type != SOCK_DGRAM) + return -ESOCKTNOSUPPORT; + + proto = &mctp_proto; + ops = &mctp_dgram_ops; + + sock->state = SS_UNCONNECTED; + sock->ops = ops; + + sk = sk_alloc(net, PF_MCTP, GFP_KERNEL, proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + sk->sk_destruct = mctp_sk_destruct; + + rc = 0; + if (sk->sk_prot->init) + rc = sk->sk_prot->init(sk); + + if (rc) + goto err_sk_put; + + return 0; + +err_sk_put: + sock_orphan(sk); + sock_put(sk); + return rc; +} + +static struct net_proto_family mctp_pf = { + .family = PF_MCTP, + .create = mctp_pf_create, + .owner = THIS_MODULE, +}; + +static __init int mctp_init(void) +{ + int rc; + + /* ensure our uapi tag definitions match the header format */ + BUILD_BUG_ON(MCTP_TAG_OWNER != MCTP_HDR_FLAG_TO); + BUILD_BUG_ON(MCTP_TAG_MASK != MCTP_HDR_TAG_MASK); + + pr_info("mctp: management component transport protocol core\n"); + + rc = sock_register(&mctp_pf); + if (rc) + return rc; + + rc = proto_register(&mctp_proto, 0); + if (rc) + goto err_unreg_sock; + + rc = mctp_routes_init(); + if (rc) + goto err_unreg_proto; + + rc = mctp_neigh_init(); + if (rc) + goto err_unreg_routes; + + mctp_device_init(); + + return 0; + +err_unreg_routes: + mctp_routes_exit(); +err_unreg_proto: + proto_unregister(&mctp_proto); +err_unreg_sock: + sock_unregister(PF_MCTP); + + return rc; +} + +static __exit void mctp_exit(void) +{ + mctp_device_exit(); + mctp_neigh_exit(); + mctp_routes_exit(); + proto_unregister(&mctp_proto); + sock_unregister(PF_MCTP); +} + +subsys_initcall(mctp_init); +module_exit(mctp_exit); + +MODULE_DESCRIPTION("MCTP core"); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Jeremy Kerr "); + +MODULE_ALIAS_NETPROTO(PF_MCTP); diff --git a/net/mctp/device.c b/net/mctp/device.c new file mode 100644 index 000000000..acb97b257 --- /dev/null +++ b/net/mctp/device.c @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - device implementation. + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +struct mctp_dump_cb { + int h; + int idx; + size_t a_idx; +}; + +/* unlocked: caller must hold rcu_read_lock. + * Returned mctp_dev has its refcount incremented, or NULL if unset. + */ +struct mctp_dev *__mctp_dev_get(const struct net_device *dev) +{ + struct mctp_dev *mdev = rcu_dereference(dev->mctp_ptr); + + /* RCU guarantees that any mdev is still live. + * Zero refcount implies a pending free, return NULL. + */ + if (mdev) + if (!refcount_inc_not_zero(&mdev->refs)) + return NULL; + return mdev; +} + +/* Returned mctp_dev does not have refcount incremented. 
The returned pointer + * remains live while rtnl_lock is held, as that prevents mctp_unregister() + */ +struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev) +{ + return rtnl_dereference(dev->mctp_ptr); +} + +static int mctp_addrinfo_size(void) +{ + return NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(1) // IFA_LOCAL + + nla_total_size(1) // IFA_ADDRESS + ; +} + +/* flag should be NLM_F_MULTI for dump calls */ +static int mctp_fill_addrinfo(struct sk_buff *skb, + struct mctp_dev *mdev, mctp_eid_t eid, + int msg_type, u32 portid, u32 seq, int flag) +{ + struct ifaddrmsg *hdr; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, + msg_type, sizeof(*hdr), flag); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ifa_family = AF_MCTP; + hdr->ifa_prefixlen = 0; + hdr->ifa_flags = 0; + hdr->ifa_scope = 0; + hdr->ifa_index = mdev->dev->ifindex; + + if (nla_put_u8(skb, IFA_LOCAL, eid)) + goto cancel; + + if (nla_put_u8(skb, IFA_ADDRESS, eid)) + goto cancel; + + nlmsg_end(skb, nlh); + + return 0; + +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_dump_dev_addrinfo(struct mctp_dev *mdev, struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct mctp_dump_cb *mcb = (void *)cb->ctx; + u32 portid, seq; + int rc = 0; + + portid = NETLINK_CB(cb->skb).portid; + seq = cb->nlh->nlmsg_seq; + for (; mcb->a_idx < mdev->num_addrs; mcb->a_idx++) { + rc = mctp_fill_addrinfo(skb, mdev, mdev->addrs[mcb->a_idx], + RTM_NEWADDR, portid, seq, NLM_F_MULTI); + if (rc < 0) + break; + } + + return rc; +} + +static int mctp_dump_addrinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct mctp_dump_cb *mcb = (void *)cb->ctx; + struct net *net = sock_net(skb->sk); + struct hlist_head *head; + struct net_device *dev; + struct ifaddrmsg *hdr; + struct mctp_dev *mdev; + int ifindex; + int idx = 0, rc; + + hdr = nlmsg_data(cb->nlh); + // filter by ifindex if requested + ifindex = hdr->ifa_index; + + rcu_read_lock(); + for (; mcb->h < NETDEV_HASHENTRIES; mcb->h++, mcb->idx = 0) { + idx = 0; + head = &net->dev_index_head[mcb->h]; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx >= mcb->idx && + (ifindex == 0 || ifindex == dev->ifindex)) { + mdev = __mctp_dev_get(dev); + if (mdev) { + rc = mctp_dump_dev_addrinfo(mdev, + skb, cb); + mctp_dev_put(mdev); + // Error indicates full buffer, this + // callback will get retried. 
+ if (rc < 0) + goto out; + } + } + idx++; + // reset for next iteration + mcb->a_idx = 0; + } + } +out: + rcu_read_unlock(); + mcb->idx = idx; + + return skb->len; +} + +static void mctp_addr_notify(struct mctp_dev *mdev, mctp_eid_t eid, int msg_type, + struct sk_buff *req_skb, struct nlmsghdr *req_nlh) +{ + u32 portid = NETLINK_CB(req_skb).portid; + struct net *net = dev_net(mdev->dev); + struct sk_buff *skb; + int rc = -ENOBUFS; + + skb = nlmsg_new(mctp_addrinfo_size(), GFP_KERNEL); + if (!skb) + goto out; + + rc = mctp_fill_addrinfo(skb, mdev, eid, msg_type, + portid, req_nlh->nlmsg_seq, 0); + if (rc < 0) { + WARN_ON_ONCE(rc == -EMSGSIZE); + goto out; + } + + rtnl_notify(skb, net, portid, RTNLGRP_MCTP_IFADDR, req_nlh, GFP_KERNEL); + return; +out: + kfree_skb(skb); + rtnl_set_sk_err(net, RTNLGRP_MCTP_IFADDR, rc); +} + +static const struct nla_policy ifa_mctp_policy[IFA_MAX + 1] = { + [IFA_ADDRESS] = { .type = NLA_U8 }, + [IFA_LOCAL] = { .type = NLA_U8 }, +}; + +static int mctp_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFA_MAX + 1]; + struct net_device *dev; + struct mctp_addr *addr; + struct mctp_dev *mdev; + struct ifaddrmsg *ifm; + unsigned long flags; + u8 *tmp_addrs; + int rc; + + rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy, + extack); + if (rc < 0) + return rc; + + ifm = nlmsg_data(nlh); + + if (tb[IFA_LOCAL]) + addr = nla_data(tb[IFA_LOCAL]); + else if (tb[IFA_ADDRESS]) + addr = nla_data(tb[IFA_ADDRESS]); + else + return -EINVAL; + + /* find device */ + dev = __dev_get_by_index(net, ifm->ifa_index); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + if (!mctp_address_unicast(addr->s_addr)) + return -EINVAL; + + /* Prevent duplicates. 
Under RTNL so don't need to lock for reading */ + if (memchr(mdev->addrs, addr->s_addr, mdev->num_addrs)) + return -EEXIST; + + tmp_addrs = kmalloc(mdev->num_addrs + 1, GFP_KERNEL); + if (!tmp_addrs) + return -ENOMEM; + memcpy(tmp_addrs, mdev->addrs, mdev->num_addrs); + tmp_addrs[mdev->num_addrs] = addr->s_addr; + + /* Lock to write */ + spin_lock_irqsave(&mdev->addrs_lock, flags); + mdev->num_addrs++; + swap(mdev->addrs, tmp_addrs); + spin_unlock_irqrestore(&mdev->addrs_lock, flags); + + kfree(tmp_addrs); + + mctp_addr_notify(mdev, addr->s_addr, RTM_NEWADDR, skb, nlh); + mctp_route_add_local(mdev, addr->s_addr); + + return 0; +} + +static int mctp_rtm_deladdr(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFA_MAX + 1]; + struct net_device *dev; + struct mctp_addr *addr; + struct mctp_dev *mdev; + struct ifaddrmsg *ifm; + unsigned long flags; + u8 *pos; + int rc; + + rc = nlmsg_parse(nlh, sizeof(*ifm), tb, IFA_MAX, ifa_mctp_policy, + extack); + if (rc < 0) + return rc; + + ifm = nlmsg_data(nlh); + + if (tb[IFA_LOCAL]) + addr = nla_data(tb[IFA_LOCAL]); + else if (tb[IFA_ADDRESS]) + addr = nla_data(tb[IFA_ADDRESS]); + else + return -EINVAL; + + /* find device */ + dev = __dev_get_by_index(net, ifm->ifa_index); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + pos = memchr(mdev->addrs, addr->s_addr, mdev->num_addrs); + if (!pos) + return -ENOENT; + + rc = mctp_route_remove_local(mdev, addr->s_addr); + // we can ignore -ENOENT in the case a route was already removed + if (rc < 0 && rc != -ENOENT) + return rc; + + spin_lock_irqsave(&mdev->addrs_lock, flags); + memmove(pos, pos + 1, mdev->num_addrs - 1 - (pos - mdev->addrs)); + mdev->num_addrs--; + spin_unlock_irqrestore(&mdev->addrs_lock, flags); + + mctp_addr_notify(mdev, addr->s_addr, RTM_DELADDR, skb, nlh); + + return 0; +} + +void mctp_dev_hold(struct mctp_dev *mdev) +{ + refcount_inc(&mdev->refs); +} + +void mctp_dev_put(struct mctp_dev *mdev) +{ + if (mdev && refcount_dec_and_test(&mdev->refs)) { + kfree(mdev->addrs); + dev_put(mdev->dev); + kfree_rcu(mdev, rcu); + } +} + +void mctp_dev_release_key(struct mctp_dev *dev, struct mctp_sk_key *key) + __must_hold(&key->lock) +{ + if (!dev) + return; + if (dev->ops && dev->ops->release_flow) + dev->ops->release_flow(dev, key); + key->dev = NULL; + mctp_dev_put(dev); +} + +void mctp_dev_set_key(struct mctp_dev *dev, struct mctp_sk_key *key) + __must_hold(&key->lock) +{ + mctp_dev_hold(dev); + key->dev = dev; +} + +static struct mctp_dev *mctp_add_dev(struct net_device *dev) +{ + struct mctp_dev *mdev; + + ASSERT_RTNL(); + + mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); + if (!mdev) + return ERR_PTR(-ENOMEM); + + spin_lock_init(&mdev->addrs_lock); + + mdev->net = mctp_default_net(dev_net(dev)); + + /* associate to net_device */ + refcount_set(&mdev->refs, 1); + rcu_assign_pointer(dev->mctp_ptr, mdev); + + dev_hold(dev); + mdev->dev = dev; + + return mdev; +} + +static int mctp_fill_link_af(struct sk_buff *skb, + const struct net_device *dev, u32 ext_filter_mask) +{ + struct mctp_dev *mdev; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODATA; + if (nla_put_u32(skb, IFLA_MCTP_NET, mdev->net)) + return -EMSGSIZE; + return 0; +} + +static size_t mctp_get_link_af_size(const struct net_device *dev, + u32 ext_filter_mask) +{ + struct mctp_dev *mdev; + unsigned int ret; + + /* caller holds RCU */ + mdev = __mctp_dev_get(dev); + if (!mdev) + return 0; + 
ret = nla_total_size(4); /* IFLA_MCTP_NET */ + mctp_dev_put(mdev); + return ret; +} + +static const struct nla_policy ifla_af_mctp_policy[IFLA_MCTP_MAX + 1] = { + [IFLA_MCTP_NET] = { .type = NLA_U32 }, +}; + +static int mctp_set_link_af(struct net_device *dev, const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[IFLA_MCTP_MAX + 1]; + struct mctp_dev *mdev; + int rc; + + rc = nla_parse_nested(tb, IFLA_MCTP_MAX, attr, ifla_af_mctp_policy, + NULL); + if (rc) + return rc; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return 0; + + if (tb[IFLA_MCTP_NET]) + WRITE_ONCE(mdev->net, nla_get_u32(tb[IFLA_MCTP_NET])); + + return 0; +} + +/* Matches netdev types that should have MCTP handling */ +static bool mctp_known(struct net_device *dev) +{ + /* only register specific types (inc. NONE for TUN devices) */ + return dev->type == ARPHRD_MCTP || + dev->type == ARPHRD_LOOPBACK || + dev->type == ARPHRD_NONE; +} + +static void mctp_unregister(struct net_device *dev) +{ + struct mctp_dev *mdev; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return; + + RCU_INIT_POINTER(mdev->dev->mctp_ptr, NULL); + + mctp_route_remove_dev(mdev); + mctp_neigh_remove_dev(mdev); + + mctp_dev_put(mdev); +} + +static int mctp_register(struct net_device *dev) +{ + struct mctp_dev *mdev; + + /* Already registered? */ + if (rtnl_dereference(dev->mctp_ptr)) + return 0; + + /* only register specific types */ + if (!mctp_known(dev)) + return 0; + + mdev = mctp_add_dev(dev); + if (IS_ERR(mdev)) + return PTR_ERR(mdev); + + return 0; +} + +static int mctp_dev_notify(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + int rc; + + switch (event) { + case NETDEV_REGISTER: + rc = mctp_register(dev); + if (rc) + return notifier_from_errno(rc); + break; + case NETDEV_UNREGISTER: + mctp_unregister(dev); + break; + } + + return NOTIFY_OK; +} + +static int mctp_register_netdevice(struct net_device *dev, + const struct mctp_netdev_ops *ops) +{ + struct mctp_dev *mdev; + + mdev = mctp_add_dev(dev); + if (IS_ERR(mdev)) + return PTR_ERR(mdev); + + mdev->ops = ops; + + return register_netdevice(dev); +} + +int mctp_register_netdev(struct net_device *dev, + const struct mctp_netdev_ops *ops) +{ + int rc; + + rtnl_lock(); + rc = mctp_register_netdevice(dev, ops); + rtnl_unlock(); + + return rc; +} +EXPORT_SYMBOL_GPL(mctp_register_netdev); + +void mctp_unregister_netdev(struct net_device *dev) +{ + unregister_netdev(dev); +} +EXPORT_SYMBOL_GPL(mctp_unregister_netdev); + +static struct rtnl_af_ops mctp_af_ops = { + .family = AF_MCTP, + .fill_link_af = mctp_fill_link_af, + .get_link_af_size = mctp_get_link_af_size, + .set_link_af = mctp_set_link_af, +}; + +static struct notifier_block mctp_dev_nb = { + .notifier_call = mctp_dev_notify, + .priority = ADDRCONF_NOTIFY_PRIORITY, +}; + +void __init mctp_device_init(void) +{ + register_netdevice_notifier(&mctp_dev_nb); + + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETADDR, + NULL, mctp_dump_addrinfo, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWADDR, + mctp_rtm_newaddr, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELADDR, + mctp_rtm_deladdr, NULL, 0); + rtnl_af_register(&mctp_af_ops); +} + +void __exit mctp_device_exit(void) +{ + rtnl_af_unregister(&mctp_af_ops); + rtnl_unregister(PF_MCTP, RTM_DELADDR); + rtnl_unregister(PF_MCTP, RTM_NEWADDR); + rtnl_unregister(PF_MCTP, RTM_GETADDR); + + unregister_netdevice_notifier(&mctp_dev_nb); +} diff --git a/net/mctp/neigh.c 
b/net/mctp/neigh.c new file mode 100644 index 000000000..ffa0f9e09 --- /dev/null +++ b/net/mctp/neigh.c @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - routing + * implementation. + * + * This is currently based on a simple routing table, with no dst cache. The + * number of routes should stay fairly small, so the lookup cost is small. + * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static int mctp_neigh_add(struct mctp_dev *mdev, mctp_eid_t eid, + enum mctp_neigh_source source, + size_t lladdr_len, const void *lladdr) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh; + int rc; + + mutex_lock(&net->mctp.neigh_lock); + if (mctp_neigh_lookup(mdev, eid, NULL) == 0) { + rc = -EEXIST; + goto out; + } + + if (lladdr_len > sizeof(neigh->ha)) { + rc = -EINVAL; + goto out; + } + + neigh = kzalloc(sizeof(*neigh), GFP_KERNEL); + if (!neigh) { + rc = -ENOMEM; + goto out; + } + INIT_LIST_HEAD(&neigh->list); + neigh->dev = mdev; + mctp_dev_hold(neigh->dev); + neigh->eid = eid; + neigh->source = source; + memcpy(neigh->ha, lladdr, lladdr_len); + + list_add_rcu(&neigh->list, &net->mctp.neighbours); + rc = 0; +out: + mutex_unlock(&net->mctp.neigh_lock); + return rc; +} + +static void __mctp_neigh_free(struct rcu_head *rcu) +{ + struct mctp_neigh *neigh = container_of(rcu, struct mctp_neigh, rcu); + + mctp_dev_put(neigh->dev); + kfree(neigh); +} + +/* Removes all neighbour entries referring to a device */ +void mctp_neigh_remove_dev(struct mctp_dev *mdev) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh, *tmp; + + mutex_lock(&net->mctp.neigh_lock); + list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) { + if (neigh->dev == mdev) { + list_del_rcu(&neigh->list); + /* TODO: immediate RTM_DELNEIGH */ + call_rcu(&neigh->rcu, __mctp_neigh_free); + } + } + + mutex_unlock(&net->mctp.neigh_lock); +} + +static int mctp_neigh_remove(struct mctp_dev *mdev, mctp_eid_t eid, + enum mctp_neigh_source source) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh, *tmp; + bool dropped = false; + + mutex_lock(&net->mctp.neigh_lock); + list_for_each_entry_safe(neigh, tmp, &net->mctp.neighbours, list) { + if (neigh->dev == mdev && neigh->eid == eid && + neigh->source == source) { + list_del_rcu(&neigh->list); + /* TODO: immediate RTM_DELNEIGH */ + call_rcu(&neigh->rcu, __mctp_neigh_free); + dropped = true; + } + } + + mutex_unlock(&net->mctp.neigh_lock); + return dropped ? 
0 : -ENOENT; +} + +static const struct nla_policy nd_mctp_policy[NDA_MAX + 1] = { + [NDA_DST] = { .type = NLA_U8 }, + [NDA_LLADDR] = { .type = NLA_BINARY, .len = MAX_ADDR_LEN }, +}; + +static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct net_device *dev; + struct mctp_dev *mdev; + struct ndmsg *ndm; + struct nlattr *tb[NDA_MAX + 1]; + int rc; + mctp_eid_t eid; + void *lladdr; + int lladdr_len; + + rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy, + extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "lladdr too large?"); + return rc; + } + + if (!tb[NDA_DST]) { + NL_SET_ERR_MSG(extack, "Neighbour EID must be specified"); + return -EINVAL; + } + + if (!tb[NDA_LLADDR]) { + NL_SET_ERR_MSG(extack, "Neighbour lladdr must be specified"); + return -EINVAL; + } + + eid = nla_get_u8(tb[NDA_DST]); + if (!mctp_address_unicast(eid)) { + NL_SET_ERR_MSG(extack, "Invalid neighbour EID"); + return -EINVAL; + } + + lladdr = nla_data(tb[NDA_LLADDR]); + lladdr_len = nla_len(tb[NDA_LLADDR]); + + ndm = nlmsg_data(nlh); + + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + if (lladdr_len != dev->addr_len) { + NL_SET_ERR_MSG(extack, "Wrong lladdr length"); + return -EINVAL; + } + + return mctp_neigh_add(mdev, eid, MCTP_NEIGH_STATIC, + lladdr_len, lladdr); +} + +static int mctp_rtm_delneigh(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[NDA_MAX + 1]; + struct net_device *dev; + struct mctp_dev *mdev; + struct ndmsg *ndm; + int rc; + mctp_eid_t eid; + + rc = nlmsg_parse(nlh, sizeof(*ndm), tb, NDA_MAX, nd_mctp_policy, + extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "incorrect format"); + return rc; + } + + if (!tb[NDA_DST]) { + NL_SET_ERR_MSG(extack, "Neighbour EID must be specified"); + return -EINVAL; + } + eid = nla_get_u8(tb[NDA_DST]); + + ndm = nlmsg_data(nlh); + dev = __dev_get_by_index(net, ndm->ndm_ifindex); + if (!dev) + return -ENODEV; + + mdev = mctp_dev_get_rtnl(dev); + if (!mdev) + return -ENODEV; + + return mctp_neigh_remove(mdev, eid, MCTP_NEIGH_STATIC); +} + +static int mctp_fill_neigh(struct sk_buff *skb, u32 portid, u32 seq, int event, + unsigned int flags, struct mctp_neigh *neigh) +{ + struct net_device *dev = neigh->dev->dev; + struct nlmsghdr *nlh; + struct ndmsg *hdr; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->ndm_family = AF_MCTP; + hdr->ndm_ifindex = dev->ifindex; + hdr->ndm_state = 0; // TODO other state bits? + if (neigh->source == MCTP_NEIGH_STATIC) + hdr->ndm_state |= NUD_PERMANENT; + hdr->ndm_flags = 0; + hdr->ndm_type = RTN_UNICAST; // TODO: is loopback RTN_LOCAL? 
+ + if (nla_put_u8(skb, NDA_DST, neigh->eid)) + goto cancel; + + if (nla_put(skb, NDA_LLADDR, dev->addr_len, neigh->ha)) + goto cancel; + + nlmsg_end(skb, nlh); + + return 0; +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_rtm_getneigh(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int rc, idx, req_ifindex; + struct mctp_neigh *neigh; + struct ndmsg *ndmsg; + struct { + int idx; + } *cbctx = (void *)cb->ctx; + + ndmsg = nlmsg_data(cb->nlh); + req_ifindex = ndmsg->ndm_ifindex; + + idx = 0; + rcu_read_lock(); + list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) { + if (idx < cbctx->idx) + goto cont; + + rc = 0; + if (req_ifindex == 0 || req_ifindex == neigh->dev->dev->ifindex) + rc = mctp_fill_neigh(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWNEIGH, NLM_F_MULTI, neigh); + + if (rc) + break; +cont: + idx++; + } + rcu_read_unlock(); + + cbctx->idx = idx; + return skb->len; +} + +int mctp_neigh_lookup(struct mctp_dev *mdev, mctp_eid_t eid, void *ret_hwaddr) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_neigh *neigh; + int rc = -EHOSTUNREACH; // TODO: or ENOENT? + + rcu_read_lock(); + list_for_each_entry_rcu(neigh, &net->mctp.neighbours, list) { + if (mdev == neigh->dev && eid == neigh->eid) { + if (ret_hwaddr) + memcpy(ret_hwaddr, neigh->ha, + sizeof(neigh->ha)); + rc = 0; + break; + } + } + rcu_read_unlock(); + return rc; +} + +/* namespace registration */ +static int __net_init mctp_neigh_net_init(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + + INIT_LIST_HEAD(&ns->neighbours); + mutex_init(&ns->neigh_lock); + return 0; +} + +static void __net_exit mctp_neigh_net_exit(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + struct mctp_neigh *neigh; + + list_for_each_entry(neigh, &ns->neighbours, list) + call_rcu(&neigh->rcu, __mctp_neigh_free); +} + +/* net namespace implementation */ + +static struct pernet_operations mctp_net_ops = { + .init = mctp_neigh_net_init, + .exit = mctp_neigh_net_exit, +}; + +int __init mctp_neigh_init(void) +{ + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWNEIGH, + mctp_rtm_newneigh, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELNEIGH, + mctp_rtm_delneigh, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETNEIGH, + NULL, mctp_rtm_getneigh, 0); + + return register_pernet_subsys(&mctp_net_ops); +} + +void __exit mctp_neigh_exit(void) +{ + unregister_pernet_subsys(&mctp_net_ops); + rtnl_unregister(PF_MCTP, RTM_GETNEIGH); + rtnl_unregister(PF_MCTP, RTM_DELNEIGH); + rtnl_unregister(PF_MCTP, RTM_NEWNEIGH); +} diff --git a/net/mctp/route.c b/net/mctp/route.c new file mode 100644 index 000000000..68be8f2b6 --- /dev/null +++ b/net/mctp/route.c @@ -0,0 +1,1432 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Management Component Transport Protocol (MCTP) - routing + * implementation. + * + * This is currently based on a simple routing table, with no dst cache. The + * number of routes should stay fairly small, so the lookup cost is small. 
+ * + * Copyright (c) 2021 Code Construct + * Copyright (c) 2021 Google + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include + +static const unsigned int mctp_message_maxlen = 64 * 1024; +static const unsigned long mctp_key_lifetime = 6 * CONFIG_HZ; + +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev); + +/* route output callbacks */ +static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb) +{ + kfree_skb(skb); + return 0; +} + +static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb) +{ + struct mctp_skb_cb *cb = mctp_cb(skb); + struct mctp_hdr *mh; + struct sock *sk; + u8 type; + + WARN_ON(!rcu_read_lock_held()); + + /* TODO: look up in skb->cb? */ + mh = mctp_hdr(skb); + + if (!skb_headlen(skb)) + return NULL; + + type = (*(u8 *)skb->data) & 0x7f; + + sk_for_each_rcu(sk, &net->mctp.binds) { + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + + if (msk->bind_net != MCTP_NET_ANY && msk->bind_net != cb->net) + continue; + + if (msk->bind_type != type) + continue; + + if (!mctp_address_matches(msk->bind_addr, mh->dest)) + continue; + + return msk; + } + + return NULL; +} + +static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local, + mctp_eid_t peer, u8 tag) +{ + if (!mctp_address_matches(key->local_addr, local)) + return false; + + if (key->peer_addr != peer) + return false; + + if (key->tag != tag) + return false; + + return true; +} + +/* returns a key (with key->lock held, and refcounted), or NULL if no such + * key exists. + */ +static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb, + mctp_eid_t peer, + unsigned long *irqflags) + __acquires(&key->lock) +{ + struct mctp_sk_key *key, *ret; + unsigned long flags; + struct mctp_hdr *mh; + u8 tag; + + mh = mctp_hdr(skb); + tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + ret = NULL; + spin_lock_irqsave(&net->mctp.keys_lock, flags); + + hlist_for_each_entry(key, &net->mctp.keys, hlist) { + if (!mctp_key_match(key, mh->dest, peer, tag)) + continue; + + spin_lock(&key->lock); + if (key->valid) { + refcount_inc(&key->refs); + ret = key; + break; + } + spin_unlock(&key->lock); + } + + if (ret) { + spin_unlock(&net->mctp.keys_lock); + *irqflags = flags; + } else { + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + } + + return ret; +} + +static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk, + mctp_eid_t local, mctp_eid_t peer, + u8 tag, gfp_t gfp) +{ + struct mctp_sk_key *key; + + key = kzalloc(sizeof(*key), gfp); + if (!key) + return NULL; + + key->peer_addr = peer; + key->local_addr = local; + key->tag = tag; + key->sk = &msk->sk; + key->valid = true; + spin_lock_init(&key->lock); + refcount_set(&key->refs, 1); + sock_hold(key->sk); + + return key; +} + +void mctp_key_unref(struct mctp_sk_key *key) +{ + unsigned long flags; + + if (!refcount_dec_and_test(&key->refs)) + return; + + /* even though no refs exist here, the lock allows us to stay + * consistent with the locking requirement of mctp_dev_release_key + */ + spin_lock_irqsave(&key->lock, flags); + mctp_dev_release_key(key->dev, key); + spin_unlock_irqrestore(&key->lock, flags); + + sock_put(key->sk); + kfree(key); +} + +static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk) +{ + struct net *net = sock_net(&msk->sk); + struct mctp_sk_key *tmp; + unsigned long flags; + int rc = 0; + + spin_lock_irqsave(&net->mctp.keys_lock, flags); + 
+ if (sock_flag(&msk->sk, SOCK_DEAD)) { + rc = -EINVAL; + goto out_unlock; + } + + hlist_for_each_entry(tmp, &net->mctp.keys, hlist) { + if (mctp_key_match(tmp, key->local_addr, key->peer_addr, + key->tag)) { + spin_lock(&tmp->lock); + if (tmp->valid) + rc = -EEXIST; + spin_unlock(&tmp->lock); + if (rc) + break; + } + } + + if (!rc) { + refcount_inc(&key->refs); + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + + hlist_add_head(&key->hlist, &net->mctp.keys); + hlist_add_head(&key->sklist, &msk->keys); + } + +out_unlock: + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + + return rc; +} + +/* Helper for mctp_route_input(). + * We're done with the key; unlock and unref the key. + * For the usual case of automatic expiry we remove the key from lists. + * In the case that manual allocation is set on a key we release the lock + * and local ref, reset reassembly, but don't remove from lists. + */ +static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net, + unsigned long flags, unsigned long reason) +__releases(&key->lock) +{ + struct sk_buff *skb; + + trace_mctp_key_release(key, reason); + skb = key->reasm_head; + key->reasm_head = NULL; + + if (!key->manual_alloc) { + key->reasm_dead = true; + key->valid = false; + mctp_dev_release_key(key->dev, key); + } + spin_unlock_irqrestore(&key->lock, flags); + + if (!key->manual_alloc) { + spin_lock_irqsave(&net->mctp.keys_lock, flags); + if (!hlist_unhashed(&key->hlist)) { + hlist_del_init(&key->hlist); + hlist_del_init(&key->sklist); + mctp_key_unref(key); + } + spin_unlock_irqrestore(&net->mctp.keys_lock, flags); + } + + /* and one for the local reference */ + mctp_key_unref(key); + + kfree_skb(skb); +} + +#ifdef CONFIG_MCTP_FLOWS +static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key) +{ + struct mctp_flow *flow; + + flow = skb_ext_add(skb, SKB_EXT_MCTP); + if (!flow) + return; + + refcount_inc(&key->refs); + flow->key = key; +} + +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev) +{ + struct mctp_sk_key *key; + struct mctp_flow *flow; + + flow = skb_ext_find(skb, SKB_EXT_MCTP); + if (!flow) + return; + + key = flow->key; + + if (WARN_ON(key->dev && key->dev != dev)) + return; + + mctp_dev_set_key(dev, key); +} +#else +static void mctp_skb_set_flow(struct sk_buff *skb, struct mctp_sk_key *key) {} +static void mctp_flow_prepare_output(struct sk_buff *skb, struct mctp_dev *dev) {} +#endif + +static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb) +{ + struct mctp_hdr *hdr = mctp_hdr(skb); + u8 exp_seq, this_seq; + + this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) + & MCTP_HDR_SEQ_MASK; + + if (!key->reasm_head) { + key->reasm_head = skb; + key->reasm_tailp = &(skb_shinfo(skb)->frag_list); + key->last_seq = this_seq; + return 0; + } + + exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK; + + if (this_seq != exp_seq) + return -EINVAL; + + if (key->reasm_head->len + skb->len > mctp_message_maxlen) + return -EINVAL; + + skb->next = NULL; + skb->sk = NULL; + *key->reasm_tailp = skb; + key->reasm_tailp = &skb->next; + + key->last_seq = this_seq; + + key->reasm_head->data_len += skb->len; + key->reasm_head->len += skb->len; + key->reasm_head->truesize += skb->truesize; + + return 0; +} + +static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb) +{ + struct mctp_sk_key *key, *any_key = NULL; + struct net *net = dev_net(skb->dev); + struct mctp_sock *msk; + struct mctp_hdr *mh; + unsigned long f; + u8 tag, 
flags; + int rc; + + msk = NULL; + rc = -EINVAL; + + /* we may be receiving a locally-routed packet; drop source sk + * accounting + */ + skb_orphan(skb); + + /* ensure we have enough data for a header and a type */ + if (skb->len < sizeof(struct mctp_hdr) + 1) + goto out; + + /* grab header, advance data ptr */ + mh = mctp_hdr(skb); + skb_pull(skb, sizeof(struct mctp_hdr)); + + if (mh->ver != 1) + goto out; + + flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM); + tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + rcu_read_lock(); + + /* lookup socket / reasm context, exactly matching (src,dest,tag). + * we hold a ref on the key, and key->lock held. + */ + key = mctp_lookup_key(net, skb, mh->src, &f); + + if (flags & MCTP_HDR_FLAG_SOM) { + if (key) { + msk = container_of(key->sk, struct mctp_sock, sk); + } else { + /* first response to a broadcast? do a more general + * key lookup to find the socket, but don't use this + * key for reassembly - we'll create a more specific + * one for future packets if required (ie, !EOM). + */ + any_key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f); + if (any_key) { + msk = container_of(any_key->sk, + struct mctp_sock, sk); + spin_unlock_irqrestore(&any_key->lock, f); + } + } + + if (!key && !msk && (tag & MCTP_HDR_FLAG_TO)) + msk = mctp_lookup_bind(net, skb); + + if (!msk) { + rc = -ENOENT; + goto out_unlock; + } + + /* single-packet message? deliver to socket, clean up any + * pending key. + */ + if (flags & MCTP_HDR_FLAG_EOM) { + sock_queue_rcv_skb(&msk->sk, skb); + if (key) { + /* we've hit a pending reassembly; not much we + * can do but drop it + */ + __mctp_key_done_in(key, net, f, + MCTP_TRACE_KEY_REPLIED); + key = NULL; + } + rc = 0; + goto out_unlock; + } + + /* broadcast response or a bind() - create a key for further + * packets for this message + */ + if (!key) { + key = mctp_key_alloc(msk, mh->dest, mh->src, + tag, GFP_ATOMIC); + if (!key) { + rc = -ENOMEM; + goto out_unlock; + } + + /* we can queue without the key lock here, as the + * key isn't observable yet + */ + mctp_frag_queue(key, skb); + + /* if the key_add fails, we've raced with another + * SOM packet with the same src, dest and tag. There's + * no way to distinguish future packets, so all we + * can do is drop; we'll free the skb on exit from + * this function. + */ + rc = mctp_key_add(key, msk); + if (!rc) + trace_mctp_key_acquire(key); + + /* we don't need to release key->lock on exit, so + * clean up here and suppress the unlock via + * setting to NULL + */ + mctp_key_unref(key); + key = NULL; + + } else { + if (key->reasm_head || key->reasm_dead) { + /* duplicate start? drop everything */ + __mctp_key_done_in(key, net, f, + MCTP_TRACE_KEY_INVALIDATED); + rc = -EEXIST; + key = NULL; + } else { + rc = mctp_frag_queue(key, skb); + } + } + + } else if (key) { + /* this packet continues a previous message; reassemble + * using the message-specific key + */ + + /* we need to be continuing an existing reassembly... */ + if (!key->reasm_head) + rc = -EINVAL; + else + rc = mctp_frag_queue(key, skb); + + /* end of message? 
deliver to socket, and we're done with + * the reassembly/response key + */ + if (!rc && flags & MCTP_HDR_FLAG_EOM) { + sock_queue_rcv_skb(key->sk, key->reasm_head); + key->reasm_head = NULL; + __mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED); + key = NULL; + } + + } else { + /* not a start, no matching key */ + rc = -ENOENT; + } + +out_unlock: + rcu_read_unlock(); + if (key) { + spin_unlock_irqrestore(&key->lock, f); + mctp_key_unref(key); + } + if (any_key) + mctp_key_unref(any_key); +out: + if (rc) + kfree_skb(skb); + return rc; +} + +static unsigned int mctp_route_mtu(struct mctp_route *rt) +{ + return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu); +} + +static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb) +{ + struct mctp_skb_cb *cb = mctp_cb(skb); + struct mctp_hdr *hdr = mctp_hdr(skb); + char daddr_buf[MAX_ADDR_LEN]; + char *daddr = NULL; + unsigned int mtu; + int rc; + + skb->protocol = htons(ETH_P_MCTP); + + mtu = READ_ONCE(skb->dev->mtu); + if (skb->len > mtu) { + kfree_skb(skb); + return -EMSGSIZE; + } + + if (cb->ifindex) { + /* direct route; use the hwaddr we stashed in sendmsg */ + if (cb->halen != skb->dev->addr_len) { + /* sanity check, sendmsg should have already caught this */ + kfree_skb(skb); + return -EMSGSIZE; + } + daddr = cb->haddr; + } else { + /* If lookup fails let the device handle daddr==NULL */ + if (mctp_neigh_lookup(route->dev, hdr->dest, daddr_buf) == 0) + daddr = daddr_buf; + } + + rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol), + daddr, skb->dev->dev_addr, skb->len); + if (rc < 0) { + kfree_skb(skb); + return -EHOSTUNREACH; + } + + mctp_flow_prepare_output(skb, route->dev); + + rc = dev_queue_xmit(skb); + if (rc) + rc = net_xmit_errno(rc); + + return rc; +} + +/* route alloc/release */ +static void mctp_route_release(struct mctp_route *rt) +{ + if (refcount_dec_and_test(&rt->refs)) { + mctp_dev_put(rt->dev); + kfree_rcu(rt, rcu); + } +} + +/* returns a route with the refcount at 1 */ +static struct mctp_route *mctp_route_alloc(void) +{ + struct mctp_route *rt; + + rt = kzalloc(sizeof(*rt), GFP_KERNEL); + if (!rt) + return NULL; + + INIT_LIST_HEAD(&rt->list); + refcount_set(&rt->refs, 1); + rt->output = mctp_route_discard; + + return rt; +} + +unsigned int mctp_default_net(struct net *net) +{ + return READ_ONCE(net->mctp.default_net); +} + +int mctp_default_net_set(struct net *net, unsigned int index) +{ + if (index == 0) + return -EINVAL; + WRITE_ONCE(net->mctp.default_net, index); + return 0; +} + +/* tag management */ +static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key, + struct mctp_sock *msk) +{ + struct netns_mctp *mns = &net->mctp; + + lockdep_assert_held(&mns->keys_lock); + + key->expiry = jiffies + mctp_key_lifetime; + timer_reduce(&msk->key_expiry, key->expiry); + + /* we hold the net->key_lock here, allowing updates to both + * then net and sk + */ + hlist_add_head_rcu(&key->hlist, &mns->keys); + hlist_add_head_rcu(&key->sklist, &msk->keys); + refcount_inc(&key->refs); +} + +/* Allocate a locally-owned tag value for (saddr, daddr), and reserve + * it for the socket msk + */ +struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk, + mctp_eid_t daddr, mctp_eid_t saddr, + bool manual, u8 *tagp) +{ + struct net *net = sock_net(&msk->sk); + struct netns_mctp *mns = &net->mctp; + struct mctp_sk_key *key, *tmp; + unsigned long flags; + u8 tagbits; + + /* for NULL destination EIDs, we may get a response from any peer */ + if (daddr == MCTP_ADDR_NULL) + daddr = MCTP_ADDR_ANY; + + /* be 
optimistic, alloc now */ + key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL); + if (!key) + return ERR_PTR(-ENOMEM); + + /* 8 possible tag values */ + tagbits = 0xff; + + spin_lock_irqsave(&mns->keys_lock, flags); + + /* Walk through the existing keys, looking for potential conflicting + * tags. If we find a conflict, clear that bit from tagbits + */ + hlist_for_each_entry(tmp, &mns->keys, hlist) { + /* We can check the lookup fields (*_addr, tag) without the + * lock held, they don't change over the lifetime of the key. + */ + + /* if we don't own the tag, it can't conflict */ + if (tmp->tag & MCTP_HDR_FLAG_TO) + continue; + + if (!(mctp_address_matches(tmp->peer_addr, daddr) && + mctp_address_matches(tmp->local_addr, saddr))) + continue; + + spin_lock(&tmp->lock); + /* key must still be valid. If we find a match, clear the + * potential tag value + */ + if (tmp->valid) + tagbits &= ~(1 << tmp->tag); + spin_unlock(&tmp->lock); + + if (!tagbits) + break; + } + + if (tagbits) { + key->tag = __ffs(tagbits); + mctp_reserve_tag(net, key, msk); + trace_mctp_key_acquire(key); + + key->manual_alloc = manual; + *tagp = key->tag; + } + + spin_unlock_irqrestore(&mns->keys_lock, flags); + + if (!tagbits) { + kfree(key); + return ERR_PTR(-EBUSY); + } + + return key; +} + +static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk, + mctp_eid_t daddr, + u8 req_tag, u8 *tagp) +{ + struct net *net = sock_net(&msk->sk); + struct netns_mctp *mns = &net->mctp; + struct mctp_sk_key *key, *tmp; + unsigned long flags; + + req_tag &= ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER); + key = NULL; + + spin_lock_irqsave(&mns->keys_lock, flags); + + hlist_for_each_entry(tmp, &mns->keys, hlist) { + if (tmp->tag != req_tag) + continue; + + if (!mctp_address_matches(tmp->peer_addr, daddr)) + continue; + + if (!tmp->manual_alloc) + continue; + + spin_lock(&tmp->lock); + if (tmp->valid) { + key = tmp; + refcount_inc(&key->refs); + spin_unlock(&tmp->lock); + break; + } + spin_unlock(&tmp->lock); + } + spin_unlock_irqrestore(&mns->keys_lock, flags); + + if (!key) + return ERR_PTR(-ENOENT); + + if (tagp) + *tagp = key->tag; + + return key; +} + +/* routing lookups */ +static bool mctp_rt_match_eid(struct mctp_route *rt, + unsigned int net, mctp_eid_t eid) +{ + return READ_ONCE(rt->dev->net) == net && + rt->min <= eid && rt->max >= eid; +} + +/* compares match, used for duplicate prevention */ +static bool mctp_rt_compare_exact(struct mctp_route *rt1, + struct mctp_route *rt2) +{ + ASSERT_RTNL(); + return rt1->dev->net == rt2->dev->net && + rt1->min == rt2->min && + rt1->max == rt2->max; +} + +struct mctp_route *mctp_route_lookup(struct net *net, unsigned int dnet, + mctp_eid_t daddr) +{ + struct mctp_route *tmp, *rt = NULL; + + rcu_read_lock(); + + list_for_each_entry_rcu(tmp, &net->mctp.routes, list) { + /* TODO: add metrics */ + if (mctp_rt_match_eid(tmp, dnet, daddr)) { + if (refcount_inc_not_zero(&tmp->refs)) { + rt = tmp; + break; + } + } + } + + rcu_read_unlock(); + + return rt; +} + +static struct mctp_route *mctp_route_lookup_null(struct net *net, + struct net_device *dev) +{ + struct mctp_route *tmp, *rt = NULL; + + rcu_read_lock(); + + list_for_each_entry_rcu(tmp, &net->mctp.routes, list) { + if (tmp->dev->dev == dev && tmp->type == RTN_LOCAL && + refcount_inc_not_zero(&tmp->refs)) { + rt = tmp; + break; + } + } + + rcu_read_unlock(); + + return rt; +} + +static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb, + unsigned int mtu, u8 tag) +{ + const unsigned int hlen = sizeof(struct 
mctp_hdr); + struct mctp_hdr *hdr, *hdr2; + unsigned int pos, size, headroom; + struct sk_buff *skb2; + int rc; + u8 seq; + + hdr = mctp_hdr(skb); + seq = 0; + rc = 0; + + if (mtu < hlen + 1) { + kfree_skb(skb); + return -EMSGSIZE; + } + + /* keep same headroom as the original skb */ + headroom = skb_headroom(skb); + + /* we've got the header */ + skb_pull(skb, hlen); + + for (pos = 0; pos < skb->len;) { + /* size of message payload */ + size = min(mtu - hlen, skb->len - pos); + + skb2 = alloc_skb(headroom + hlen + size, GFP_KERNEL); + if (!skb2) { + rc = -ENOMEM; + break; + } + + /* generic skb copy */ + skb2->protocol = skb->protocol; + skb2->priority = skb->priority; + skb2->dev = skb->dev; + memcpy(skb2->cb, skb->cb, sizeof(skb2->cb)); + + if (skb->sk) + skb_set_owner_w(skb2, skb->sk); + + /* establish packet */ + skb_reserve(skb2, headroom); + skb_reset_network_header(skb2); + skb_put(skb2, hlen + size); + skb2->transport_header = skb2->network_header + hlen; + + /* copy header fields, calculate SOM/EOM flags & seq */ + hdr2 = mctp_hdr(skb2); + hdr2->ver = hdr->ver; + hdr2->dest = hdr->dest; + hdr2->src = hdr->src; + hdr2->flags_seq_tag = tag & + (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO); + + if (pos == 0) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM; + + if (pos + size == skb->len) + hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM; + + hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT; + + /* copy message payload */ + skb_copy_bits(skb, pos, skb_transport_header(skb2), size); + + /* do route */ + rc = rt->output(rt, skb2); + if (rc) + break; + + seq = (seq + 1) & MCTP_HDR_SEQ_MASK; + pos += size; + } + + consume_skb(skb); + return rc; +} + +int mctp_local_output(struct sock *sk, struct mctp_route *rt, + struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag) +{ + struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); + struct mctp_skb_cb *cb = mctp_cb(skb); + struct mctp_route tmp_rt = {0}; + struct mctp_sk_key *key; + struct mctp_hdr *hdr; + unsigned long flags; + unsigned int mtu; + mctp_eid_t saddr; + bool ext_rt; + int rc; + u8 tag; + + rc = -ENODEV; + + if (rt) { + ext_rt = false; + if (WARN_ON(!rt->dev)) + goto out_release; + + } else if (cb->ifindex) { + struct net_device *dev; + + ext_rt = true; + rt = &tmp_rt; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), cb->ifindex); + if (!dev) { + rcu_read_unlock(); + return rc; + } + rt->dev = __mctp_dev_get(dev); + rcu_read_unlock(); + + if (!rt->dev) + goto out_release; + + /* establish temporary route - we set up enough to keep + * mctp_route_output happy + */ + rt->output = mctp_route_output; + rt->mtu = 0; + + } else { + return -EINVAL; + } + + spin_lock_irqsave(&rt->dev->addrs_lock, flags); + if (rt->dev->num_addrs == 0) { + rc = -EHOSTUNREACH; + } else { + /* use the outbound interface's first address as our source */ + saddr = rt->dev->addrs[0]; + rc = 0; + } + spin_unlock_irqrestore(&rt->dev->addrs_lock, flags); + + if (rc) + goto out_release; + + if (req_tag & MCTP_TAG_OWNER) { + if (req_tag & MCTP_TAG_PREALLOC) + key = mctp_lookup_prealloc_tag(msk, daddr, + req_tag, &tag); + else + key = mctp_alloc_local_tag(msk, daddr, saddr, + false, &tag); + + if (IS_ERR(key)) { + rc = PTR_ERR(key); + goto out_release; + } + mctp_skb_set_flow(skb, key); + /* done with the key in this scope */ + mctp_key_unref(key); + tag |= MCTP_HDR_FLAG_TO; + } else { + key = NULL; + tag = req_tag & MCTP_TAG_MASK; + } + + skb->protocol = htons(ETH_P_MCTP); + skb->priority = 0; + skb_reset_transport_header(skb); + skb_push(skb, sizeof(struct 
mctp_hdr)); + skb_reset_network_header(skb); + skb->dev = rt->dev->dev; + + /* cb->net will have been set on initial ingress */ + cb->src = saddr; + + /* set up common header fields */ + hdr = mctp_hdr(skb); + hdr->ver = 1; + hdr->dest = daddr; + hdr->src = saddr; + + mtu = mctp_route_mtu(rt); + + if (skb->len + sizeof(struct mctp_hdr) <= mtu) { + hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | + MCTP_HDR_FLAG_EOM | tag; + rc = rt->output(rt, skb); + } else { + rc = mctp_do_fragment_route(rt, skb, mtu, tag); + } + +out_release: + if (!ext_rt) + mctp_route_release(rt); + + mctp_dev_put(tmp_rt.dev); + + return rc; +} + +/* route management */ +static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start, + unsigned int daddr_extent, unsigned int mtu, + unsigned char type) +{ + int (*rtfn)(struct mctp_route *rt, struct sk_buff *skb); + struct net *net = dev_net(mdev->dev); + struct mctp_route *rt, *ert; + + if (!mctp_address_unicast(daddr_start)) + return -EINVAL; + + if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255) + return -EINVAL; + + switch (type) { + case RTN_LOCAL: + rtfn = mctp_route_input; + break; + case RTN_UNICAST: + rtfn = mctp_route_output; + break; + default: + return -EINVAL; + } + + rt = mctp_route_alloc(); + if (!rt) + return -ENOMEM; + + rt->min = daddr_start; + rt->max = daddr_start + daddr_extent; + rt->mtu = mtu; + rt->dev = mdev; + mctp_dev_hold(rt->dev); + rt->type = type; + rt->output = rtfn; + + ASSERT_RTNL(); + /* Prevent duplicate identical routes. */ + list_for_each_entry(ert, &net->mctp.routes, list) { + if (mctp_rt_compare_exact(rt, ert)) { + mctp_route_release(rt); + return -EEXIST; + } + } + + list_add_rcu(&rt->list, &net->mctp.routes); + + return 0; +} + +static int mctp_route_remove(struct mctp_dev *mdev, mctp_eid_t daddr_start, + unsigned int daddr_extent, unsigned char type) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_route *rt, *tmp; + mctp_eid_t daddr_end; + bool dropped; + + if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255) + return -EINVAL; + + daddr_end = daddr_start + daddr_extent; + dropped = false; + + ASSERT_RTNL(); + + list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) { + if (rt->dev == mdev && + rt->min == daddr_start && rt->max == daddr_end && + rt->type == type) { + list_del_rcu(&rt->list); + /* TODO: immediate RTM_DELROUTE */ + mctp_route_release(rt); + dropped = true; + } + } + + return dropped ? 
0 : -ENOENT; +} + +int mctp_route_add_local(struct mctp_dev *mdev, mctp_eid_t addr) +{ + return mctp_route_add(mdev, addr, 0, 0, RTN_LOCAL); +} + +int mctp_route_remove_local(struct mctp_dev *mdev, mctp_eid_t addr) +{ + return mctp_route_remove(mdev, addr, 0, RTN_LOCAL); +} + +/* removes all entries for a given device */ +void mctp_route_remove_dev(struct mctp_dev *mdev) +{ + struct net *net = dev_net(mdev->dev); + struct mctp_route *rt, *tmp; + + ASSERT_RTNL(); + list_for_each_entry_safe(rt, tmp, &net->mctp.routes, list) { + if (rt->dev == mdev) { + list_del_rcu(&rt->list); + /* TODO: immediate RTM_DELROUTE */ + mctp_route_release(rt); + } + } +} + +/* Incoming packet-handling */ + +static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, + struct net_device *orig_dev) +{ + struct net *net = dev_net(dev); + struct mctp_dev *mdev; + struct mctp_skb_cb *cb; + struct mctp_route *rt; + struct mctp_hdr *mh; + + rcu_read_lock(); + mdev = __mctp_dev_get(dev); + rcu_read_unlock(); + if (!mdev) { + /* basic non-data sanity checks */ + goto err_drop; + } + + if (!pskb_may_pull(skb, sizeof(struct mctp_hdr))) + goto err_drop; + + skb_reset_transport_header(skb); + skb_reset_network_header(skb); + + /* We have enough for a header; decode and route */ + mh = mctp_hdr(skb); + if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX) + goto err_drop; + + /* source must be valid unicast or null; drop reserved ranges and + * broadcast + */ + if (!(mctp_address_unicast(mh->src) || mctp_address_null(mh->src))) + goto err_drop; + + /* dest address: as above, but allow broadcast */ + if (!(mctp_address_unicast(mh->dest) || mctp_address_null(mh->dest) || + mctp_address_broadcast(mh->dest))) + goto err_drop; + + /* MCTP drivers must populate halen/haddr */ + if (dev->type == ARPHRD_MCTP) { + cb = mctp_cb(skb); + } else { + cb = __mctp_cb(skb); + cb->halen = 0; + } + cb->net = READ_ONCE(mdev->net); + cb->ifindex = dev->ifindex; + + rt = mctp_route_lookup(net, cb->net, mh->dest); + + /* NULL EID, but addressed to our physical address */ + if (!rt && mh->dest == MCTP_ADDR_NULL && skb->pkt_type == PACKET_HOST) + rt = mctp_route_lookup_null(net, dev); + + if (!rt) + goto err_drop; + + rt->output(rt, skb); + mctp_route_release(rt); + mctp_dev_put(mdev); + + return NET_RX_SUCCESS; + +err_drop: + kfree_skb(skb); + mctp_dev_put(mdev); + return NET_RX_DROP; +} + +static struct packet_type mctp_packet_type = { + .type = cpu_to_be16(ETH_P_MCTP), + .func = mctp_pkttype_receive, +}; + +/* netlink interface */ + +static const struct nla_policy rta_mctp_policy[RTA_MAX + 1] = { + [RTA_DST] = { .type = NLA_U8 }, + [RTA_METRICS] = { .type = NLA_NESTED }, + [RTA_OIF] = { .type = NLA_U32 }, +}; + +/* Common part for RTM_NEWROUTE and RTM_DELROUTE parsing. + * tb must hold RTA_MAX+1 elements. 
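+ * On success, *rtm, *mdev and *daddr_start are filled in from the
+ * message; RTA_DST and RTA_OIF are mandatory, and loopback devices
+ * are rejected.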
+ */ +static int mctp_route_nlparse(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + struct nlattr **tb, struct rtmsg **rtm, + struct mctp_dev **mdev, mctp_eid_t *daddr_start) +{ + struct net *net = sock_net(skb->sk); + struct net_device *dev; + unsigned int ifindex; + int rc; + + rc = nlmsg_parse(nlh, sizeof(struct rtmsg), tb, RTA_MAX, + rta_mctp_policy, extack); + if (rc < 0) { + NL_SET_ERR_MSG(extack, "incorrect format"); + return rc; + } + + if (!tb[RTA_DST]) { + NL_SET_ERR_MSG(extack, "dst EID missing"); + return -EINVAL; + } + *daddr_start = nla_get_u8(tb[RTA_DST]); + + if (!tb[RTA_OIF]) { + NL_SET_ERR_MSG(extack, "ifindex missing"); + return -EINVAL; + } + ifindex = nla_get_u32(tb[RTA_OIF]); + + *rtm = nlmsg_data(nlh); + if ((*rtm)->rtm_family != AF_MCTP) { + NL_SET_ERR_MSG(extack, "route family must be AF_MCTP"); + return -EINVAL; + } + + dev = __dev_get_by_index(net, ifindex); + if (!dev) { + NL_SET_ERR_MSG(extack, "bad ifindex"); + return -ENODEV; + } + *mdev = mctp_dev_get_rtnl(dev); + if (!*mdev) + return -ENODEV; + + if (dev->flags & IFF_LOOPBACK) { + NL_SET_ERR_MSG(extack, "no routes to loopback"); + return -EINVAL; + } + + return 0; +} + +static const struct nla_policy rta_metrics_policy[RTAX_MAX + 1] = { + [RTAX_MTU] = { .type = NLA_U32 }, +}; + +static int mctp_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[RTA_MAX + 1]; + struct nlattr *tbx[RTAX_MAX + 1]; + mctp_eid_t daddr_start; + struct mctp_dev *mdev; + struct rtmsg *rtm; + unsigned int mtu; + int rc; + + rc = mctp_route_nlparse(skb, nlh, extack, tb, + &rtm, &mdev, &daddr_start); + if (rc < 0) + return rc; + + if (rtm->rtm_type != RTN_UNICAST) { + NL_SET_ERR_MSG(extack, "rtm_type must be RTN_UNICAST"); + return -EINVAL; + } + + mtu = 0; + if (tb[RTA_METRICS]) { + rc = nla_parse_nested(tbx, RTAX_MAX, tb[RTA_METRICS], + rta_metrics_policy, NULL); + if (rc < 0) + return rc; + if (tbx[RTAX_MTU]) + mtu = nla_get_u32(tbx[RTAX_MTU]); + } + + if (rtm->rtm_type != RTN_UNICAST) + return -EINVAL; + + rc = mctp_route_add(mdev, daddr_start, rtm->rtm_dst_len, mtu, + rtm->rtm_type); + return rc; +} + +static int mctp_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[RTA_MAX + 1]; + mctp_eid_t daddr_start; + struct mctp_dev *mdev; + struct rtmsg *rtm; + int rc; + + rc = mctp_route_nlparse(skb, nlh, extack, tb, + &rtm, &mdev, &daddr_start); + if (rc < 0) + return rc; + + /* we only have unicast routes */ + if (rtm->rtm_type != RTN_UNICAST) + return -EINVAL; + + rc = mctp_route_remove(mdev, daddr_start, rtm->rtm_dst_len, RTN_UNICAST); + return rc; +} + +static int mctp_fill_rtinfo(struct sk_buff *skb, struct mctp_route *rt, + u32 portid, u32 seq, int event, unsigned int flags) +{ + struct nlmsghdr *nlh; + struct rtmsg *hdr; + void *metrics; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*hdr), flags); + if (!nlh) + return -EMSGSIZE; + + hdr = nlmsg_data(nlh); + hdr->rtm_family = AF_MCTP; + + /* we use the _len fields as a number of EIDs, rather than + * a number of bits in the address + */ + hdr->rtm_dst_len = rt->max - rt->min; + hdr->rtm_src_len = 0; + hdr->rtm_tos = 0; + hdr->rtm_table = RT_TABLE_DEFAULT; + hdr->rtm_protocol = RTPROT_STATIC; /* everything is user-defined */ + hdr->rtm_scope = RT_SCOPE_LINK; /* TODO: scope in mctp_route? 
*/ + hdr->rtm_type = rt->type; + + if (nla_put_u8(skb, RTA_DST, rt->min)) + goto cancel; + + metrics = nla_nest_start_noflag(skb, RTA_METRICS); + if (!metrics) + goto cancel; + + if (rt->mtu) { + if (nla_put_u32(skb, RTAX_MTU, rt->mtu)) + goto cancel; + } + + nla_nest_end(skb, metrics); + + if (rt->dev) { + if (nla_put_u32(skb, RTA_OIF, rt->dev->dev->ifindex)) + goto cancel; + } + + /* TODO: conditional neighbour physaddr? */ + + nlmsg_end(skb, nlh); + + return 0; + +cancel: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mctp_dump_rtinfo(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct mctp_route *rt; + int s_idx, idx; + + /* TODO: allow filtering on route data, possibly under + * cb->strict_check + */ + + /* TODO: change to struct overlay */ + s_idx = cb->args[0]; + idx = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(rt, &net->mctp.routes, list) { + if (idx++ < s_idx) + continue; + if (mctp_fill_rtinfo(skb, rt, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + RTM_NEWROUTE, NLM_F_MULTI) < 0) + break; + } + + rcu_read_unlock(); + cb->args[0] = idx; + + return skb->len; +} + +/* net namespace implementation */ +static int __net_init mctp_routes_net_init(struct net *net) +{ + struct netns_mctp *ns = &net->mctp; + + INIT_LIST_HEAD(&ns->routes); + INIT_HLIST_HEAD(&ns->binds); + mutex_init(&ns->bind_lock); + INIT_HLIST_HEAD(&ns->keys); + spin_lock_init(&ns->keys_lock); + WARN_ON(mctp_default_net_set(net, MCTP_INITIAL_DEFAULT_NET)); + return 0; +} + +static void __net_exit mctp_routes_net_exit(struct net *net) +{ + struct mctp_route *rt; + + rcu_read_lock(); + list_for_each_entry_rcu(rt, &net->mctp.routes, list) + mctp_route_release(rt); + rcu_read_unlock(); +} + +static struct pernet_operations mctp_net_ops = { + .init = mctp_routes_net_init, + .exit = mctp_routes_net_exit, +}; + +int __init mctp_routes_init(void) +{ + dev_add_pack(&mctp_packet_type); + + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_GETROUTE, + NULL, mctp_dump_rtinfo, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_NEWROUTE, + mctp_newroute, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MCTP, RTM_DELROUTE, + mctp_delroute, NULL, 0); + + return register_pernet_subsys(&mctp_net_ops); +} + +void mctp_routes_exit(void) +{ + unregister_pernet_subsys(&mctp_net_ops); + rtnl_unregister(PF_MCTP, RTM_DELROUTE); + rtnl_unregister(PF_MCTP, RTM_NEWROUTE); + rtnl_unregister(PF_MCTP, RTM_GETROUTE); + dev_remove_pack(&mctp_packet_type); +} + +#if IS_ENABLED(CONFIG_MCTP_TEST) +#include "test/route-test.c" +#endif diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c new file mode 100644 index 000000000..92ea4158f --- /dev/null +++ b/net/mctp/test/route-test.c @@ -0,0 +1,684 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include + +#include "utils.h" + +struct mctp_test_route { + struct mctp_route rt; + struct sk_buff_head pkts; +}; + +static int mctp_test_route_output(struct mctp_route *rt, struct sk_buff *skb) +{ + struct mctp_test_route *test_rt = container_of(rt, struct mctp_test_route, rt); + + skb_queue_tail(&test_rt->pkts, skb); + + return 0; +} + +/* local version of mctp_route_alloc() */ +static struct mctp_test_route *mctp_route_test_alloc(void) +{ + struct mctp_test_route *rt; + + rt = kzalloc(sizeof(*rt), GFP_KERNEL); + if (!rt) + return NULL; + + INIT_LIST_HEAD(&rt->rt.list); + refcount_set(&rt->rt.refs, 1); + rt->rt.output = mctp_test_route_output; + + skb_queue_head_init(&rt->pkts); + + return rt; +} + +static struct mctp_test_route 
*mctp_test_create_route(struct net *net, + struct mctp_dev *dev, + mctp_eid_t eid, + unsigned int mtu) +{ + struct mctp_test_route *rt; + + rt = mctp_route_test_alloc(); + if (!rt) + return NULL; + + rt->rt.min = eid; + rt->rt.max = eid; + rt->rt.mtu = mtu; + rt->rt.type = RTN_UNSPEC; + if (dev) + mctp_dev_hold(dev); + rt->rt.dev = dev; + + list_add_rcu(&rt->rt.list, &net->mctp.routes); + + return rt; +} + +static void mctp_test_route_destroy(struct kunit *test, + struct mctp_test_route *rt) +{ + unsigned int refs; + + rtnl_lock(); + list_del_rcu(&rt->rt.list); + rtnl_unlock(); + + skb_queue_purge(&rt->pkts); + if (rt->rt.dev) + mctp_dev_put(rt->rt.dev); + + refs = refcount_read(&rt->rt.refs); + KUNIT_ASSERT_EQ_MSG(test, refs, 1, "route ref imbalance"); + + kfree_rcu(&rt->rt, rcu); +} + +static struct sk_buff *mctp_test_create_skb(const struct mctp_hdr *hdr, + unsigned int data_len) +{ + size_t hdr_len = sizeof(*hdr); + struct sk_buff *skb; + unsigned int i; + u8 *buf; + + skb = alloc_skb(hdr_len + data_len, GFP_KERNEL); + if (!skb) + return NULL; + + memcpy(skb_put(skb, hdr_len), hdr, hdr_len); + + buf = skb_put(skb, data_len); + for (i = 0; i < data_len; i++) + buf[i] = i & 0xff; + + return skb; +} + +static struct sk_buff *__mctp_test_create_skb_data(const struct mctp_hdr *hdr, + const void *data, + size_t data_len) +{ + size_t hdr_len = sizeof(*hdr); + struct sk_buff *skb; + + skb = alloc_skb(hdr_len + data_len, GFP_KERNEL); + if (!skb) + return NULL; + + memcpy(skb_put(skb, hdr_len), hdr, hdr_len); + memcpy(skb_put(skb, data_len), data, data_len); + + return skb; +} + +#define mctp_test_create_skb_data(h, d) \ + __mctp_test_create_skb_data(h, d, sizeof(*d)) + +struct mctp_frag_test { + unsigned int mtu; + unsigned int msgsize; + unsigned int n_frags; +}; + +static void mctp_test_fragment(struct kunit *test) +{ + const struct mctp_frag_test *params; + int rc, i, n, mtu, msgsize; + struct mctp_test_route *rt; + struct sk_buff *skb; + struct mctp_hdr hdr; + u8 seq; + + params = test->param_value; + mtu = params->mtu; + msgsize = params->msgsize; + + hdr.ver = 1; + hdr.src = 8; + hdr.dest = 10; + hdr.flags_seq_tag = MCTP_HDR_FLAG_TO; + + skb = mctp_test_create_skb(&hdr, msgsize); + KUNIT_ASSERT_TRUE(test, skb); + + rt = mctp_test_create_route(&init_net, NULL, 10, mtu); + KUNIT_ASSERT_TRUE(test, rt); + + rc = mctp_do_fragment_route(&rt->rt, skb, mtu, MCTP_TAG_OWNER); + KUNIT_EXPECT_FALSE(test, rc); + + n = rt->pkts.qlen; + + KUNIT_EXPECT_EQ(test, n, params->n_frags); + + for (i = 0;; i++) { + struct mctp_hdr *hdr2; + struct sk_buff *skb2; + u8 tag_mask, seq2; + bool first, last; + + first = i == 0; + last = i == (n - 1); + + skb2 = skb_dequeue(&rt->pkts); + + if (!skb2) + break; + + hdr2 = mctp_hdr(skb2); + + tag_mask = MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO; + + KUNIT_EXPECT_EQ(test, hdr2->ver, hdr.ver); + KUNIT_EXPECT_EQ(test, hdr2->src, hdr.src); + KUNIT_EXPECT_EQ(test, hdr2->dest, hdr.dest); + KUNIT_EXPECT_EQ(test, hdr2->flags_seq_tag & tag_mask, + hdr.flags_seq_tag & tag_mask); + + KUNIT_EXPECT_EQ(test, + !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_SOM), first); + KUNIT_EXPECT_EQ(test, + !!(hdr2->flags_seq_tag & MCTP_HDR_FLAG_EOM), last); + + seq2 = (hdr2->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT) & + MCTP_HDR_SEQ_MASK; + + if (first) { + seq = seq2; + } else { + seq++; + KUNIT_EXPECT_EQ(test, seq2, seq & MCTP_HDR_SEQ_MASK); + } + + if (!last) + KUNIT_EXPECT_EQ(test, skb2->len, mtu); + else + KUNIT_EXPECT_LE(test, skb2->len, mtu); + + kfree_skb(skb2); + } + + mctp_test_route_destroy(test, rt); +} 
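+
+/* The expected fragment counts below follow from mctp_do_fragment_route():
+ * each fragment carries at most mtu - sizeof(struct mctp_hdr) bytes of
+ * payload, so with mtu = 68 and the 4-byte MCTP header that is 64 bytes
+ * per fragment, i.e. n_frags = DIV_ROUND_UP(msgsize, mtu - 4):
+ * msgsize 64 -> 1 fragment, 65 -> 2, 129 -> 3.
+ */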
+ +static const struct mctp_frag_test mctp_frag_tests[] = { + {.mtu = 68, .msgsize = 63, .n_frags = 1}, + {.mtu = 68, .msgsize = 64, .n_frags = 1}, + {.mtu = 68, .msgsize = 65, .n_frags = 2}, + {.mtu = 68, .msgsize = 66, .n_frags = 2}, + {.mtu = 68, .msgsize = 127, .n_frags = 2}, + {.mtu = 68, .msgsize = 128, .n_frags = 2}, + {.mtu = 68, .msgsize = 129, .n_frags = 3}, + {.mtu = 68, .msgsize = 130, .n_frags = 3}, +}; + +static void mctp_frag_test_to_desc(const struct mctp_frag_test *t, char *desc) +{ + sprintf(desc, "mtu %d len %d -> %d frags", + t->msgsize, t->mtu, t->n_frags); +} + +KUNIT_ARRAY_PARAM(mctp_frag, mctp_frag_tests, mctp_frag_test_to_desc); + +struct mctp_rx_input_test { + struct mctp_hdr hdr; + bool input; +}; + +static void mctp_test_rx_input(struct kunit *test) +{ + const struct mctp_rx_input_test *params; + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct sk_buff *skb; + + params = test->param_value; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt); + + skb = mctp_test_create_skb(¶ms->hdr, 1); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + __mctp_cb(skb); + + mctp_pkttype_receive(skb, dev->ndev, &mctp_packet_type, NULL); + + KUNIT_EXPECT_EQ(test, !!rt->pkts.qlen, params->input); + + mctp_test_route_destroy(test, rt); + mctp_test_destroy_dev(dev); +} + +#define RX_HDR(_ver, _src, _dest, _fst) \ + { .ver = _ver, .src = _src, .dest = _dest, .flags_seq_tag = _fst } + +/* we have a route for EID 8 only */ +static const struct mctp_rx_input_test mctp_rx_input_tests[] = { + { .hdr = RX_HDR(1, 10, 8, 0), .input = true }, + { .hdr = RX_HDR(1, 10, 9, 0), .input = false }, /* no input route */ + { .hdr = RX_HDR(2, 10, 8, 0), .input = false }, /* invalid version */ +}; + +static void mctp_rx_input_test_to_desc(const struct mctp_rx_input_test *t, + char *desc) +{ + sprintf(desc, "{%x,%x,%x,%x}", t->hdr.ver, t->hdr.src, t->hdr.dest, + t->hdr.flags_seq_tag); +} + +KUNIT_ARRAY_PARAM(mctp_rx_input, mctp_rx_input_tests, + mctp_rx_input_test_to_desc); + +/* set up a local dev, route on EID 8, and a socket listening on type 0 */ +static void __mctp_route_test_init(struct kunit *test, + struct mctp_test_dev **devp, + struct mctp_test_route **rtp, + struct socket **sockp) +{ + struct sockaddr_mctp addr = {0}; + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct socket *sock; + int rc; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + addr.smctp_family = AF_MCTP; + addr.smctp_network = MCTP_NET_ANY; + addr.smctp_addr.s_addr = 8; + addr.smctp_type = 0; + rc = kernel_bind(sock, (struct sockaddr *)&addr, sizeof(addr)); + KUNIT_ASSERT_EQ(test, rc, 0); + + *rtp = rt; + *devp = dev; + *sockp = sock; +} + +static void __mctp_route_test_fini(struct kunit *test, + struct mctp_test_dev *dev, + struct mctp_test_route *rt, + struct socket *sock) +{ + sock_release(sock); + mctp_test_route_destroy(test, rt); + mctp_test_destroy_dev(dev); +} + +struct mctp_route_input_sk_test { + struct mctp_hdr hdr; + u8 type; + bool deliver; +}; + +static void mctp_test_route_input_sk(struct kunit *test) +{ + const struct mctp_route_input_sk_test *params; + struct sk_buff *skb, *skb2; + struct mctp_test_route *rt; + struct 
mctp_test_dev *dev; + struct socket *sock; + int rc; + + params = test->param_value; + + __mctp_route_test_init(test, &dev, &rt, &sock); + + skb = mctp_test_create_skb_data(¶ms->hdr, ¶ms->type); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + skb->dev = dev->ndev; + __mctp_cb(skb); + + rc = mctp_route_input(&rt->rt, skb); + + if (params->deliver) { + KUNIT_EXPECT_EQ(test, rc, 0); + + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + KUNIT_EXPECT_EQ(test, skb->len, 1); + + skb_free_datagram(sock->sk, skb2); + + } else { + KUNIT_EXPECT_NE(test, rc, 0); + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + KUNIT_EXPECT_NULL(test, skb2); + } + + __mctp_route_test_fini(test, dev, rt, sock); +} + +#define FL_S (MCTP_HDR_FLAG_SOM) +#define FL_E (MCTP_HDR_FLAG_EOM) +#define FL_TO (MCTP_HDR_FLAG_TO) +#define FL_T(t) ((t) & MCTP_HDR_TAG_MASK) + +static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = { + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true }, + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false }, + { .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false }, +}; + +static void mctp_route_input_sk_to_desc(const struct mctp_route_input_sk_test *t, + char *desc) +{ + sprintf(desc, "{%x,%x,%x,%x} type %d", t->hdr.ver, t->hdr.src, + t->hdr.dest, t->hdr.flags_seq_tag, t->type); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk, mctp_route_input_sk_tests, + mctp_route_input_sk_to_desc); + +struct mctp_route_input_sk_reasm_test { + const char *name; + struct mctp_hdr hdrs[4]; + int n_hdrs; + int rx_len; +}; + +static void mctp_test_route_input_sk_reasm(struct kunit *test) +{ + const struct mctp_route_input_sk_reasm_test *params; + struct sk_buff *skb, *skb2; + struct mctp_test_route *rt; + struct mctp_test_dev *dev; + struct socket *sock; + int i, rc; + u8 c; + + params = test->param_value; + + __mctp_route_test_init(test, &dev, &rt, &sock); + + for (i = 0; i < params->n_hdrs; i++) { + c = i; + skb = mctp_test_create_skb_data(¶ms->hdrs[i], &c); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + skb->dev = dev->ndev; + __mctp_cb(skb); + + rc = mctp_route_input(&rt->rt, skb); + } + + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + + if (params->rx_len) { + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + KUNIT_EXPECT_EQ(test, skb2->len, params->rx_len); + skb_free_datagram(sock->sk, skb2); + + } else { + KUNIT_EXPECT_NULL(test, skb2); + } + + __mctp_route_test_fini(test, dev, rt, sock); +} + +#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT)) + +static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = { + { + .name = "single packet", + .hdrs = { + RX_FRAG(FL_S | FL_E, 0), + }, + .n_hdrs = 1, + .rx_len = 1, + }, + { + .name = "single packet, offset seq", + .hdrs = { + RX_FRAG(FL_S | FL_E, 1), + }, + .n_hdrs = 1, + .rx_len = 1, + }, + { + .name = "start & end packets", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(FL_E, 1), + }, + .n_hdrs = 2, + .rx_len = 2, + }, + { + .name = "start & end packets, offset seq", + .hdrs = { + RX_FRAG(FL_S, 1), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 2, + .rx_len = 2, + }, + { + .name = "start & end packets, out of order", + .hdrs = { + RX_FRAG(FL_E, 1), + RX_FRAG(FL_S, 0), + }, + 
.n_hdrs = 2, + .rx_len = 0, + }, + { + .name = "start, middle & end packets", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(0, 1), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 3, + .rx_len = 3, + }, + { + .name = "missing seq", + .hdrs = { + RX_FRAG(FL_S, 0), + RX_FRAG(FL_E, 2), + }, + .n_hdrs = 2, + .rx_len = 0, + }, + { + .name = "seq wrap", + .hdrs = { + RX_FRAG(FL_S, 3), + RX_FRAG(FL_E, 0), + }, + .n_hdrs = 2, + .rx_len = 2, + }, +}; + +static void mctp_route_input_sk_reasm_to_desc( + const struct mctp_route_input_sk_reasm_test *t, + char *desc) +{ + sprintf(desc, "%s", t->name); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests, + mctp_route_input_sk_reasm_to_desc); + +struct mctp_route_input_sk_keys_test { + const char *name; + mctp_eid_t key_peer_addr; + mctp_eid_t key_local_addr; + u8 key_tag; + struct mctp_hdr hdr; + bool deliver; +}; + +/* test packet rx in the presence of various key configurations */ +static void mctp_test_route_input_sk_keys(struct kunit *test) +{ + const struct mctp_route_input_sk_keys_test *params; + struct mctp_test_route *rt; + struct sk_buff *skb, *skb2; + struct mctp_test_dev *dev; + struct mctp_sk_key *key; + struct netns_mctp *mns; + struct mctp_sock *msk; + struct socket *sock; + unsigned long flags; + int rc; + u8 c; + + params = test->param_value; + + dev = mctp_test_create_dev(); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev); + + rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt); + + rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock); + KUNIT_ASSERT_EQ(test, rc, 0); + + msk = container_of(sock->sk, struct mctp_sock, sk); + mns = &sock_net(sock->sk)->mctp; + + /* set the incoming tag according to test params */ + key = mctp_key_alloc(msk, params->key_local_addr, params->key_peer_addr, + params->key_tag, GFP_KERNEL); + + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key); + + spin_lock_irqsave(&mns->keys_lock, flags); + mctp_reserve_tag(&init_net, key, msk); + spin_unlock_irqrestore(&mns->keys_lock, flags); + + /* create packet and route */ + c = 0; + skb = mctp_test_create_skb_data(¶ms->hdr, &c); + KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb); + + skb->dev = dev->ndev; + __mctp_cb(skb); + + rc = mctp_route_input(&rt->rt, skb); + + /* (potentially) receive message */ + skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc); + + if (params->deliver) + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2); + else + KUNIT_EXPECT_PTR_EQ(test, skb2, NULL); + + if (skb2) + skb_free_datagram(sock->sk, skb2); + + mctp_key_unref(key); + __mctp_route_test_fini(test, dev, rt, sock); +} + +static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = { + { + .name = "direct match", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, + { + .name = "flipped src/dest", + .key_peer_addr = 8, + .key_local_addr = 9, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)), + .deliver = false, + }, + { + .name = "peer addr mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)), + .deliver = false, + }, + { + .name = "tag value mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)), + .deliver = false, + }, + { + .name = "TO mismatch", + .key_peer_addr = 9, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO), + .deliver = false, + }, + { + 
.name = "broadcast response", + .key_peer_addr = MCTP_ADDR_ANY, + .key_local_addr = 8, + .key_tag = 1, + .hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, + { + .name = "any local match", + .key_peer_addr = 12, + .key_local_addr = MCTP_ADDR_ANY, + .key_tag = 1, + .hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)), + .deliver = true, + }, +}; + +static void mctp_route_input_sk_keys_to_desc( + const struct mctp_route_input_sk_keys_test *t, + char *desc) +{ + sprintf(desc, "%s", t->name); +} + +KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests, + mctp_route_input_sk_keys_to_desc); + +static struct kunit_case mctp_test_cases[] = { + KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params), + KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm, + mctp_route_input_sk_reasm_gen_params), + KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys, + mctp_route_input_sk_keys_gen_params), + {} +}; + +static struct kunit_suite mctp_test_suite = { + .name = "mctp", + .test_cases = mctp_test_cases, +}; + +kunit_test_suite(mctp_test_suite); diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c new file mode 100644 index 000000000..e03ba66bb --- /dev/null +++ b/net/mctp/test/utils.c @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include + +#include +#include + +#include "utils.h" + +static netdev_tx_t mctp_test_dev_tx(struct sk_buff *skb, + struct net_device *ndev) +{ + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static const struct net_device_ops mctp_test_netdev_ops = { + .ndo_start_xmit = mctp_test_dev_tx, +}; + +static void mctp_test_dev_setup(struct net_device *ndev) +{ + ndev->type = ARPHRD_MCTP; + ndev->mtu = MCTP_DEV_TEST_MTU; + ndev->hard_header_len = 0; + ndev->addr_len = 0; + ndev->tx_queue_len = DEFAULT_TX_QUEUE_LEN; + ndev->flags = IFF_NOARP; + ndev->netdev_ops = &mctp_test_netdev_ops; + ndev->needs_free_netdev = true; +} + +struct mctp_test_dev *mctp_test_create_dev(void) +{ + struct mctp_test_dev *dev; + struct net_device *ndev; + int rc; + + ndev = alloc_netdev(sizeof(*dev), "mctptest%d", NET_NAME_ENUM, + mctp_test_dev_setup); + if (!ndev) + return NULL; + + dev = netdev_priv(ndev); + dev->ndev = ndev; + + rc = register_netdev(ndev); + if (rc) { + free_netdev(ndev); + return NULL; + } + + rcu_read_lock(); + dev->mdev = __mctp_dev_get(ndev); + rcu_read_unlock(); + + return dev; +} + +void mctp_test_destroy_dev(struct mctp_test_dev *dev) +{ + mctp_dev_put(dev->mdev); + unregister_netdev(dev->ndev); +} diff --git a/net/mctp/test/utils.h b/net/mctp/test/utils.h new file mode 100644 index 000000000..df6aa1c03 --- /dev/null +++ b/net/mctp/test/utils.h @@ -0,0 +1,20 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#ifndef __NET_MCTP_TEST_UTILS_H +#define __NET_MCTP_TEST_UTILS_H + +#include + +#define MCTP_DEV_TEST_MTU 68 + +struct mctp_test_dev { + struct net_device *ndev; + struct mctp_dev *mdev; +}; + +struct mctp_test_dev; + +struct mctp_test_dev *mctp_test_create_dev(void); +void mctp_test_destroy_dev(struct mctp_test_dev *dev); + +#endif /* __NET_MCTP_TEST_UTILS_H */ diff --git a/net/mpls/Kconfig b/net/mpls/Kconfig new file mode 100644 index 000000000..d672ab72a --- /dev/null +++ b/net/mpls/Kconfig @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# MPLS configuration +# + +menuconfig MPLS + bool "MultiProtocol Label Switching" + default n + help + MultiProtocol Label 
Switching routes packets through logical + circuits. Originally conceived as a way of routing packets at + hardware speeds (before hardware was capable of routing ipv4 packets), + MPLS remains a simple way of making tunnels. + + If you have not heard of MPLS you probably want to say N here. + +if MPLS + +config NET_MPLS_GSO + tristate "MPLS: GSO support" + help + This is helper module to allow segmentation of non-MPLS GSO packets + that have had MPLS stack entries pushed onto them and thus + become MPLS GSO packets. + +config MPLS_ROUTING + tristate "MPLS: routing support" + depends on NET_IP_TUNNEL || NET_IP_TUNNEL=n + depends on PROC_SYSCTL + help + Add support for forwarding of mpls packets. + +config MPLS_IPTUNNEL + tristate "MPLS: IP over MPLS tunnel support" + depends on LWTUNNEL && MPLS_ROUTING + help + mpls ip tunnel support. + +endif # MPLS diff --git a/net/mpls/Makefile b/net/mpls/Makefile new file mode 100644 index 000000000..53e33b6c7 --- /dev/null +++ b/net/mpls/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for MPLS. +# +obj-$(CONFIG_NET_MPLS_GSO) += mpls_gso.o +obj-$(CONFIG_MPLS_ROUTING) += mpls_router.o +obj-$(CONFIG_MPLS_IPTUNNEL) += mpls_iptunnel.o + +mpls_router-y := af_mpls.o diff --git a/net/mpls/af_mpls.c b/net/mpls/af_mpls.c new file mode 100644 index 000000000..f1f43894e --- /dev/null +++ b/net/mpls/af_mpls.c @@ -0,0 +1,2799 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif +#include +#include +#include "internal.h" + +/* max memory we will use for mpls_route */ +#define MAX_MPLS_ROUTE_MEM 4096 + +/* Maximum number of labels to look ahead at when selecting a path of + * a multipath route + */ +#define MAX_MP_SELECT_LABELS 4 + +#define MPLS_NEIGH_TABLE_UNSPEC (NEIGH_LINK_TABLE + 1) + +static int label_limit = (1 << 20) - 1; +static int ttl_max = 255; + +#if IS_ENABLED(CONFIG_NET_IP_TUNNEL) +static size_t ipgre_mpls_encap_hlen(struct ip_tunnel_encap *e) +{ + return sizeof(struct mpls_shim_hdr); +} + +static const struct ip_tunnel_encap_ops mpls_iptun_ops = { + .encap_hlen = ipgre_mpls_encap_hlen, +}; + +static int ipgre_tunnel_encap_add_mpls_ops(void) +{ + return ip_tunnel_encap_add_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS); +} + +static void ipgre_tunnel_encap_del_mpls_ops(void) +{ + ip_tunnel_encap_del_ops(&mpls_iptun_ops, TUNNEL_ENCAP_MPLS); +} +#else +static int ipgre_tunnel_encap_add_mpls_ops(void) +{ + return 0; +} + +static void ipgre_tunnel_encap_del_mpls_ops(void) +{ +} +#endif + +static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt, + struct nlmsghdr *nlh, struct net *net, u32 portid, + unsigned int nlm_flags); + +static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index) +{ + struct mpls_route *rt = NULL; + + if (index < net->mpls.platform_labels) { + struct mpls_route __rcu **platform_label = + rcu_dereference(net->mpls.platform_label); + rt = rcu_dereference(platform_label[index]); + } + return rt; +} + +bool mpls_output_possible(const struct net_device *dev) +{ + return dev && (dev->flags & IFF_UP) && netif_carrier_ok(dev); +} +EXPORT_SYMBOL_GPL(mpls_output_possible); + +static u8 *__mpls_nh_via(struct mpls_route *rt, struct mpls_nh *nh) +{ + return (u8 *)nh + rt->rt_via_offset; +} + +static const u8 *mpls_nh_via(const struct mpls_route *rt, 
+ const struct mpls_nh *nh) +{ + return __mpls_nh_via((struct mpls_route *)rt, (struct mpls_nh *)nh); +} + +static unsigned int mpls_nh_header_size(const struct mpls_nh *nh) +{ + /* The size of the layer 2.5 labels to be added for this route */ + return nh->nh_labels * sizeof(struct mpls_shim_hdr); +} + +unsigned int mpls_dev_mtu(const struct net_device *dev) +{ + /* The amount of data the layer 2 frame can hold */ + return dev->mtu; +} +EXPORT_SYMBOL_GPL(mpls_dev_mtu); + +bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu) +{ + if (skb->len <= mtu) + return false; + + if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) + return false; + + return true; +} +EXPORT_SYMBOL_GPL(mpls_pkt_too_big); + +void mpls_stats_inc_outucastpkts(struct net_device *dev, + const struct sk_buff *skb) +{ + struct mpls_dev *mdev; + + if (skb->protocol == htons(ETH_P_MPLS_UC)) { + mdev = mpls_dev_get(dev); + if (mdev) + MPLS_INC_STATS_LEN(mdev, skb->len, + tx_packets, + tx_bytes); + } else if (skb->protocol == htons(ETH_P_IP)) { + IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len); +#if IS_ENABLED(CONFIG_IPV6) + } else if (skb->protocol == htons(ETH_P_IPV6)) { + struct inet6_dev *in6dev = __in6_dev_get(dev); + + if (in6dev) + IP6_UPD_PO_STATS(dev_net(dev), in6dev, + IPSTATS_MIB_OUT, skb->len); +#endif + } +} +EXPORT_SYMBOL_GPL(mpls_stats_inc_outucastpkts); + +static u32 mpls_multipath_hash(struct mpls_route *rt, struct sk_buff *skb) +{ + struct mpls_entry_decoded dec; + unsigned int mpls_hdr_len = 0; + struct mpls_shim_hdr *hdr; + bool eli_seen = false; + int label_index; + u32 hash = 0; + + for (label_index = 0; label_index < MAX_MP_SELECT_LABELS; + label_index++) { + mpls_hdr_len += sizeof(*hdr); + if (!pskb_may_pull(skb, mpls_hdr_len)) + break; + + /* Read and decode the current label */ + hdr = mpls_hdr(skb) + label_index; + dec = mpls_entry_decode(hdr); + + /* RFC6790 - reserved labels MUST NOT be used as keys + * for the load-balancing function + */ + if (likely(dec.label >= MPLS_LABEL_FIRST_UNRESERVED)) { + hash = jhash_1word(dec.label, hash); + + /* The entropy label follows the entropy label + * indicator, so this means that the entropy + * label was just added to the hash - no need to + * go any deeper either in the label stack or in the + * payload + */ + if (eli_seen) + break; + } else if (dec.label == MPLS_LABEL_ENTROPY) { + eli_seen = true; + } + + if (!dec.bos) + continue; + + /* found bottom label; does skb have room for a header? */ + if (pskb_may_pull(skb, mpls_hdr_len + sizeof(struct iphdr))) { + const struct iphdr *v4hdr; + + v4hdr = (const struct iphdr *)(hdr + 1); + if (v4hdr->version == 4) { + hash = jhash_3words(ntohl(v4hdr->saddr), + ntohl(v4hdr->daddr), + v4hdr->protocol, hash); + } else if (v4hdr->version == 6 && + pskb_may_pull(skb, mpls_hdr_len + + sizeof(struct ipv6hdr))) { + const struct ipv6hdr *v6hdr; + + v6hdr = (const struct ipv6hdr *)(hdr + 1); + hash = __ipv6_addr_jhash(&v6hdr->saddr, hash); + hash = __ipv6_addr_jhash(&v6hdr->daddr, hash); + hash = jhash_1word(v6hdr->nexthdr, hash); + } + } + + break; + } + + return hash; +} + +static struct mpls_nh *mpls_get_nexthop(struct mpls_route *rt, u8 index) +{ + return (struct mpls_nh *)((u8 *)rt->rt_nh + index * rt->rt_nh_size); +} + +/* number of alive nexthops (rt->rt_nhn_alive) and the flags for + * a next hop (nh->nh_flags) are modified by netdev event handlers. + * Since those fields can change at any moment, use READ_ONCE to + * access both. 
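+ *
+ * mpls_select_multipath() reduces the flow hash modulo the number of
+ * alive nexthops; when some nexthops are down it walks the nexthop
+ * array, skipping RTNH_F_DEAD/RTNH_F_LINKDOWN entries, so the selected
+ * index still lands on a live hop.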
+ */ +static const struct mpls_nh *mpls_select_multipath(struct mpls_route *rt, + struct sk_buff *skb) +{ + u32 hash = 0; + int nh_index = 0; + int n = 0; + u8 alive; + + /* No need to look further into packet if there's only + * one path + */ + if (rt->rt_nhn == 1) + return rt->rt_nh; + + alive = READ_ONCE(rt->rt_nhn_alive); + if (alive == 0) + return NULL; + + hash = mpls_multipath_hash(rt, skb); + nh_index = hash % alive; + if (alive == rt->rt_nhn) + goto out; + for_nexthops(rt) { + unsigned int nh_flags = READ_ONCE(nh->nh_flags); + + if (nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)) + continue; + if (n == nh_index) + return nh; + n++; + } endfor_nexthops(rt); + +out: + return mpls_get_nexthop(rt, nh_index); +} + +static bool mpls_egress(struct net *net, struct mpls_route *rt, + struct sk_buff *skb, struct mpls_entry_decoded dec) +{ + enum mpls_payload_type payload_type; + bool success = false; + + /* The IPv4 code below accesses through the IPv4 header + * checksum, which is 12 bytes into the packet. + * The IPv6 code below accesses through the IPv6 hop limit + * which is 8 bytes into the packet. + * + * For all supported cases there should always be at least 12 + * bytes of packet data present. The IPv4 header is 20 bytes + * without options and the IPv6 header is always 40 bytes + * long. + */ + if (!pskb_may_pull(skb, 12)) + return false; + + payload_type = rt->rt_payload_type; + if (payload_type == MPT_UNSPEC) + payload_type = ip_hdr(skb)->version; + + switch (payload_type) { + case MPT_IPV4: { + struct iphdr *hdr4 = ip_hdr(skb); + u8 new_ttl; + skb->protocol = htons(ETH_P_IP); + + /* If propagating TTL, take the decremented TTL from + * the incoming MPLS header, otherwise decrement the + * TTL, but only if not 0 to avoid underflow. + */ + if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED || + (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT && + net->mpls.ip_ttl_propagate)) + new_ttl = dec.ttl; + else + new_ttl = hdr4->ttl ? hdr4->ttl - 1 : 0; + + csum_replace2(&hdr4->check, + htons(hdr4->ttl << 8), + htons(new_ttl << 8)); + hdr4->ttl = new_ttl; + success = true; + break; + } + case MPT_IPV6: { + struct ipv6hdr *hdr6 = ipv6_hdr(skb); + skb->protocol = htons(ETH_P_IPV6); + + /* If propagating TTL, take the decremented TTL from + * the incoming MPLS header, otherwise decrement the + * hop limit, but only if not 0 to avoid underflow. 
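+ * Unlike the IPv4 case above, there is no header checksum to update
+ * here.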
+ */ + if (rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED || + (rt->rt_ttl_propagate == MPLS_TTL_PROP_DEFAULT && + net->mpls.ip_ttl_propagate)) + hdr6->hop_limit = dec.ttl; + else if (hdr6->hop_limit) + hdr6->hop_limit = hdr6->hop_limit - 1; + success = true; + break; + } + case MPT_UNSPEC: + /* Should have decided which protocol it is by now */ + break; + } + + return success; +} + +static int mpls_forward(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct net *net = dev_net(dev); + struct mpls_shim_hdr *hdr; + const struct mpls_nh *nh; + struct mpls_route *rt; + struct mpls_entry_decoded dec; + struct net_device *out_dev; + struct mpls_dev *out_mdev; + struct mpls_dev *mdev; + unsigned int hh_len; + unsigned int new_header_size; + unsigned int mtu; + int err; + + /* Careful this entire function runs inside of an rcu critical section */ + + mdev = mpls_dev_get(dev); + if (!mdev) + goto drop; + + MPLS_INC_STATS_LEN(mdev, skb->len, rx_packets, + rx_bytes); + + if (!mdev->input_enabled) { + MPLS_INC_STATS(mdev, rx_dropped); + goto drop; + } + + if (skb->pkt_type != PACKET_HOST) + goto err; + + if ((skb = skb_share_check(skb, GFP_ATOMIC)) == NULL) + goto err; + + if (!pskb_may_pull(skb, sizeof(*hdr))) + goto err; + + skb_dst_drop(skb); + + /* Read and decode the label */ + hdr = mpls_hdr(skb); + dec = mpls_entry_decode(hdr); + + rt = mpls_route_input_rcu(net, dec.label); + if (!rt) { + MPLS_INC_STATS(mdev, rx_noroute); + goto drop; + } + + nh = mpls_select_multipath(rt, skb); + if (!nh) + goto err; + + /* Pop the label */ + skb_pull(skb, sizeof(*hdr)); + skb_reset_network_header(skb); + + skb_orphan(skb); + + if (skb_warn_if_lro(skb)) + goto err; + + skb_forward_csum(skb); + + /* Verify ttl is valid */ + if (dec.ttl <= 1) + goto err; + + /* Find the output device */ + out_dev = nh->nh_dev; + if (!mpls_output_possible(out_dev)) + goto tx_err; + + /* Verify the destination can hold the packet */ + new_header_size = mpls_nh_header_size(nh); + mtu = mpls_dev_mtu(out_dev); + if (mpls_pkt_too_big(skb, mtu - new_header_size)) + goto tx_err; + + hh_len = LL_RESERVED_SPACE(out_dev); + if (!out_dev->header_ops) + hh_len = 0; + + /* Ensure there is enough space for the headers in the skb */ + if (skb_cow(skb, hh_len + new_header_size)) + goto tx_err; + + skb->dev = out_dev; + skb->protocol = htons(ETH_P_MPLS_UC); + + dec.ttl -= 1; + if (unlikely(!new_header_size && dec.bos)) { + /* Penultimate hop popping */ + if (!mpls_egress(dev_net(out_dev), rt, skb, dec)) + goto err; + } else { + bool bos; + int i; + skb_push(skb, new_header_size); + skb_reset_network_header(skb); + /* Push the new labels */ + hdr = mpls_hdr(skb); + bos = dec.bos; + for (i = nh->nh_labels - 1; i >= 0; i--) { + hdr[i] = mpls_entry_encode(nh->nh_label[i], + dec.ttl, 0, bos); + bos = false; + } + } + + mpls_stats_inc_outucastpkts(out_dev, skb); + + /* If via wasn't specified then send out using device address */ + if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC) + err = neigh_xmit(NEIGH_LINK_TABLE, out_dev, + out_dev->dev_addr, skb); + else + err = neigh_xmit(nh->nh_via_table, out_dev, + mpls_nh_via(rt, nh), skb); + if (err) + net_dbg_ratelimited("%s: packet transmission failed: %d\n", + __func__, err); + return 0; + +tx_err: + out_mdev = out_dev ? 
mpls_dev_get(out_dev) : NULL; + if (out_mdev) + MPLS_INC_STATS(out_mdev, tx_errors); + goto drop; +err: + MPLS_INC_STATS(mdev, rx_errors); +drop: + kfree_skb(skb); + return NET_RX_DROP; +} + +static struct packet_type mpls_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_MPLS_UC), + .func = mpls_forward, +}; + +static const struct nla_policy rtm_mpls_policy[RTA_MAX+1] = { + [RTA_DST] = { .type = NLA_U32 }, + [RTA_OIF] = { .type = NLA_U32 }, + [RTA_TTL_PROPAGATE] = { .type = NLA_U8 }, +}; + +struct mpls_route_config { + u32 rc_protocol; + u32 rc_ifindex; + u8 rc_via_table; + u8 rc_via_alen; + u8 rc_via[MAX_VIA_ALEN]; + u32 rc_label; + u8 rc_ttl_propagate; + u8 rc_output_labels; + u32 rc_output_label[MAX_NEW_LABELS]; + u32 rc_nlflags; + enum mpls_payload_type rc_payload_type; + struct nl_info rc_nlinfo; + struct rtnexthop *rc_mp; + int rc_mp_len; +}; + +/* all nexthops within a route have the same size based on max + * number of labels and max via length for a hop + */ +static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels) +{ + u8 nh_size = MPLS_NH_SIZE(max_labels, max_alen); + struct mpls_route *rt; + size_t size; + + size = sizeof(*rt) + num_nh * nh_size; + if (size > MAX_MPLS_ROUTE_MEM) + return ERR_PTR(-EINVAL); + + rt = kzalloc(size, GFP_KERNEL); + if (!rt) + return ERR_PTR(-ENOMEM); + + rt->rt_nhn = num_nh; + rt->rt_nhn_alive = num_nh; + rt->rt_nh_size = nh_size; + rt->rt_via_offset = MPLS_NH_VIA_OFF(max_labels); + + return rt; +} + +static void mpls_rt_free(struct mpls_route *rt) +{ + if (rt) + kfree_rcu(rt, rt_rcu); +} + +static void mpls_notify_route(struct net *net, unsigned index, + struct mpls_route *old, struct mpls_route *new, + const struct nl_info *info) +{ + struct nlmsghdr *nlh = info ? info->nlh : NULL; + unsigned portid = info ? info->portid : 0; + int event = new ? RTM_NEWROUTE : RTM_DELROUTE; + struct mpls_route *rt = new ? new : old; + unsigned nlm_flags = (old && new) ? 
NLM_F_REPLACE : 0; + /* Ignore reserved labels for now */ + if (rt && (index >= MPLS_LABEL_FIRST_UNRESERVED)) + rtmsg_lfib(event, index, rt, nlh, net, portid, nlm_flags); +} + +static void mpls_route_update(struct net *net, unsigned index, + struct mpls_route *new, + const struct nl_info *info) +{ + struct mpls_route __rcu **platform_label; + struct mpls_route *rt; + + ASSERT_RTNL(); + + platform_label = rtnl_dereference(net->mpls.platform_label); + rt = rtnl_dereference(platform_label[index]); + rcu_assign_pointer(platform_label[index], new); + + mpls_notify_route(net, index, rt, new, info); + + /* If we removed a route free it now */ + mpls_rt_free(rt); +} + +static unsigned find_free_label(struct net *net) +{ + struct mpls_route __rcu **platform_label; + size_t platform_labels; + unsigned index; + + platform_label = rtnl_dereference(net->mpls.platform_label); + platform_labels = net->mpls.platform_labels; + for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels; + index++) { + if (!rtnl_dereference(platform_label[index])) + return index; + } + return LABEL_NOT_SPECIFIED; +} + +#if IS_ENABLED(CONFIG_INET) +static struct net_device *inet_fib_lookup_dev(struct net *net, + const void *addr) +{ + struct net_device *dev; + struct rtable *rt; + struct in_addr daddr; + + memcpy(&daddr, addr, sizeof(struct in_addr)); + rt = ip_route_output(net, daddr.s_addr, 0, 0, 0); + if (IS_ERR(rt)) + return ERR_CAST(rt); + + dev = rt->dst.dev; + dev_hold(dev); + + ip_rt_put(rt); + + return dev; +} +#else +static struct net_device *inet_fib_lookup_dev(struct net *net, + const void *addr) +{ + return ERR_PTR(-EAFNOSUPPORT); +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +static struct net_device *inet6_fib_lookup_dev(struct net *net, + const void *addr) +{ + struct net_device *dev; + struct dst_entry *dst; + struct flowi6 fl6; + + if (!ipv6_stub) + return ERR_PTR(-EAFNOSUPPORT); + + memset(&fl6, 0, sizeof(fl6)); + memcpy(&fl6.daddr, addr, sizeof(struct in6_addr)); + dst = ipv6_stub->ipv6_dst_lookup_flow(net, NULL, &fl6, NULL); + if (IS_ERR(dst)) + return ERR_CAST(dst); + + dev = dst->dev; + dev_hold(dev); + dst_release(dst); + + return dev; +} +#else +static struct net_device *inet6_fib_lookup_dev(struct net *net, + const void *addr) +{ + return ERR_PTR(-EAFNOSUPPORT); +} +#endif + +static struct net_device *find_outdev(struct net *net, + struct mpls_route *rt, + struct mpls_nh *nh, int oif) +{ + struct net_device *dev = NULL; + + if (!oif) { + switch (nh->nh_via_table) { + case NEIGH_ARP_TABLE: + dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh)); + break; + case NEIGH_ND_TABLE: + dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh)); + break; + case NEIGH_LINK_TABLE: + break; + } + } else { + dev = dev_get_by_index(net, oif); + } + + if (!dev) + return ERR_PTR(-ENODEV); + + if (IS_ERR(dev)) + return dev; + + /* The caller is holding rtnl anyways, so release the dev reference */ + dev_put(dev); + + return dev; +} + +static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt, + struct mpls_nh *nh, int oif) +{ + struct net_device *dev = NULL; + int err = -ENODEV; + + dev = find_outdev(net, rt, nh, oif); + if (IS_ERR(dev)) { + err = PTR_ERR(dev); + dev = NULL; + goto errout; + } + + /* Ensure this is a supported device */ + err = -EINVAL; + if (!mpls_dev_get(dev)) + goto errout; + + if ((nh->nh_via_table == NEIGH_LINK_TABLE) && + (dev->addr_len != nh->nh_via_alen)) + goto errout; + + nh->nh_dev = dev; + + if (!(dev->flags & IFF_UP)) { + nh->nh_flags |= RTNH_F_DEAD; + } else { + unsigned int 
flags; + + flags = dev_get_flags(dev); + if (!(flags & (IFF_RUNNING | IFF_LOWER_UP))) + nh->nh_flags |= RTNH_F_LINKDOWN; + } + + return 0; + +errout: + return err; +} + +static int nla_get_via(const struct nlattr *nla, u8 *via_alen, u8 *via_table, + u8 via_addr[], struct netlink_ext_ack *extack) +{ + struct rtvia *via = nla_data(nla); + int err = -EINVAL; + int alen; + + if (nla_len(nla) < offsetof(struct rtvia, rtvia_addr)) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Invalid attribute length for RTA_VIA"); + goto errout; + } + alen = nla_len(nla) - + offsetof(struct rtvia, rtvia_addr); + if (alen > MAX_VIA_ALEN) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Invalid address length for RTA_VIA"); + goto errout; + } + + /* Validate the address family */ + switch (via->rtvia_family) { + case AF_PACKET: + *via_table = NEIGH_LINK_TABLE; + break; + case AF_INET: + *via_table = NEIGH_ARP_TABLE; + if (alen != 4) + goto errout; + break; + case AF_INET6: + *via_table = NEIGH_ND_TABLE; + if (alen != 16) + goto errout; + break; + default: + /* Unsupported address family */ + goto errout; + } + + memcpy(via_addr, via->rtvia_addr, alen); + *via_alen = alen; + err = 0; + +errout: + return err; +} + +static int mpls_nh_build_from_cfg(struct mpls_route_config *cfg, + struct mpls_route *rt) +{ + struct net *net = cfg->rc_nlinfo.nl_net; + struct mpls_nh *nh = rt->rt_nh; + int err; + int i; + + if (!nh) + return -ENOMEM; + + nh->nh_labels = cfg->rc_output_labels; + for (i = 0; i < nh->nh_labels; i++) + nh->nh_label[i] = cfg->rc_output_label[i]; + + nh->nh_via_table = cfg->rc_via_table; + memcpy(__mpls_nh_via(rt, nh), cfg->rc_via, cfg->rc_via_alen); + nh->nh_via_alen = cfg->rc_via_alen; + + err = mpls_nh_assign_dev(net, rt, nh, cfg->rc_ifindex); + if (err) + goto errout; + + if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)) + rt->rt_nhn_alive--; + + return 0; + +errout: + return err; +} + +static int mpls_nh_build(struct net *net, struct mpls_route *rt, + struct mpls_nh *nh, int oif, struct nlattr *via, + struct nlattr *newdst, u8 max_labels, + struct netlink_ext_ack *extack) +{ + int err = -ENOMEM; + + if (!nh) + goto errout; + + if (newdst) { + err = nla_get_labels(newdst, max_labels, &nh->nh_labels, + nh->nh_label, extack); + if (err) + goto errout; + } + + if (via) { + err = nla_get_via(via, &nh->nh_via_alen, &nh->nh_via_table, + __mpls_nh_via(rt, nh), extack); + if (err) + goto errout; + } else { + nh->nh_via_table = MPLS_NEIGH_TABLE_UNSPEC; + } + + err = mpls_nh_assign_dev(net, rt, nh, oif); + if (err) + goto errout; + + return 0; + +errout: + return err; +} + +static u8 mpls_count_nexthops(struct rtnexthop *rtnh, int len, + u8 cfg_via_alen, u8 *max_via_alen, + u8 *max_labels) +{ + int remaining = len; + u8 nhs = 0; + + *max_via_alen = 0; + *max_labels = 0; + + while (rtnh_ok(rtnh, remaining)) { + struct nlattr *nla, *attrs = rtnh_attrs(rtnh); + int attrlen; + u8 n_labels = 0; + + attrlen = rtnh_attrlen(rtnh); + nla = nla_find(attrs, attrlen, RTA_VIA); + if (nla && nla_len(nla) >= + offsetof(struct rtvia, rtvia_addr)) { + int via_alen = nla_len(nla) - + offsetof(struct rtvia, rtvia_addr); + + if (via_alen <= MAX_VIA_ALEN) + *max_via_alen = max_t(u16, *max_via_alen, + via_alen); + } + + nla = nla_find(attrs, attrlen, RTA_NEWDST); + if (nla && + nla_get_labels(nla, MAX_NEW_LABELS, &n_labels, + NULL, NULL) != 0) + return 0; + + *max_labels = max_t(u8, *max_labels, n_labels); + + /* number of nexthops is tracked by a u8. + * Check for overflow. 
+ */ + if (nhs == 255) + return 0; + nhs++; + + rtnh = rtnh_next(rtnh, &remaining); + } + + /* leftover implies invalid nexthop configuration, discard it */ + return remaining > 0 ? 0 : nhs; +} + +static int mpls_nh_build_multi(struct mpls_route_config *cfg, + struct mpls_route *rt, u8 max_labels, + struct netlink_ext_ack *extack) +{ + struct rtnexthop *rtnh = cfg->rc_mp; + struct nlattr *nla_via, *nla_newdst; + int remaining = cfg->rc_mp_len; + int err = 0; + u8 nhs = 0; + + change_nexthops(rt) { + int attrlen; + + nla_via = NULL; + nla_newdst = NULL; + + err = -EINVAL; + if (!rtnh_ok(rtnh, remaining)) + goto errout; + + /* neither weighted multipath nor any flags + * are supported + */ + if (rtnh->rtnh_hops || rtnh->rtnh_flags) + goto errout; + + attrlen = rtnh_attrlen(rtnh); + if (attrlen > 0) { + struct nlattr *attrs = rtnh_attrs(rtnh); + + nla_via = nla_find(attrs, attrlen, RTA_VIA); + nla_newdst = nla_find(attrs, attrlen, RTA_NEWDST); + } + + err = mpls_nh_build(cfg->rc_nlinfo.nl_net, rt, nh, + rtnh->rtnh_ifindex, nla_via, nla_newdst, + max_labels, extack); + if (err) + goto errout; + + if (nh->nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN)) + rt->rt_nhn_alive--; + + rtnh = rtnh_next(rtnh, &remaining); + nhs++; + } endfor_nexthops(rt); + + rt->rt_nhn = nhs; + + return 0; + +errout: + return err; +} + +static bool mpls_label_ok(struct net *net, unsigned int *index, + struct netlink_ext_ack *extack) +{ + bool is_ok = true; + + /* Reserved labels may not be set */ + if (*index < MPLS_LABEL_FIRST_UNRESERVED) { + NL_SET_ERR_MSG(extack, + "Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher"); + is_ok = false; + } + + /* The full 20 bit range may not be supported. */ + if (is_ok && *index >= net->mpls.platform_labels) { + NL_SET_ERR_MSG(extack, + "Label >= configured maximum in platform_labels"); + is_ok = false; + } + + *index = array_index_nospec(*index, net->mpls.platform_labels); + return is_ok; +} + +static int mpls_route_add(struct mpls_route_config *cfg, + struct netlink_ext_ack *extack) +{ + struct mpls_route __rcu **platform_label; + struct net *net = cfg->rc_nlinfo.nl_net; + struct mpls_route *rt, *old; + int err = -EINVAL; + u8 max_via_alen; + unsigned index; + u8 max_labels; + u8 nhs; + + index = cfg->rc_label; + + /* If a label was not specified during insert pick one */ + if ((index == LABEL_NOT_SPECIFIED) && + (cfg->rc_nlflags & NLM_F_CREATE)) { + index = find_free_label(net); + } + + if (!mpls_label_ok(net, &index, extack)) + goto errout; + + /* Append makes no sense with mpls */ + err = -EOPNOTSUPP; + if (cfg->rc_nlflags & NLM_F_APPEND) { + NL_SET_ERR_MSG(extack, "MPLS does not support route append"); + goto errout; + } + + err = -EEXIST; + platform_label = rtnl_dereference(net->mpls.platform_label); + old = rtnl_dereference(platform_label[index]); + if ((cfg->rc_nlflags & NLM_F_EXCL) && old) + goto errout; + + err = -EEXIST; + if (!(cfg->rc_nlflags & NLM_F_REPLACE) && old) + goto errout; + + err = -ENOENT; + if (!(cfg->rc_nlflags & NLM_F_CREATE) && !old) + goto errout; + + err = -EINVAL; + if (cfg->rc_mp) { + nhs = mpls_count_nexthops(cfg->rc_mp, cfg->rc_mp_len, + cfg->rc_via_alen, &max_via_alen, + &max_labels); + } else { + max_via_alen = cfg->rc_via_alen; + max_labels = cfg->rc_output_labels; + nhs = 1; + } + + if (nhs == 0) { + NL_SET_ERR_MSG(extack, "Route does not contain a nexthop"); + goto errout; + } + + rt = mpls_rt_alloc(nhs, max_via_alen, max_labels); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + goto errout; + } + + rt->rt_protocol = cfg->rc_protocol; + 
rt->rt_payload_type = cfg->rc_payload_type; + rt->rt_ttl_propagate = cfg->rc_ttl_propagate; + + if (cfg->rc_mp) + err = mpls_nh_build_multi(cfg, rt, max_labels, extack); + else + err = mpls_nh_build_from_cfg(cfg, rt); + if (err) + goto freert; + + mpls_route_update(net, index, rt, &cfg->rc_nlinfo); + + return 0; + +freert: + mpls_rt_free(rt); +errout: + return err; +} + +static int mpls_route_del(struct mpls_route_config *cfg, + struct netlink_ext_ack *extack) +{ + struct net *net = cfg->rc_nlinfo.nl_net; + unsigned index; + int err = -EINVAL; + + index = cfg->rc_label; + + if (!mpls_label_ok(net, &index, extack)) + goto errout; + + mpls_route_update(net, index, NULL, &cfg->rc_nlinfo); + + err = 0; +errout: + return err; +} + +static void mpls_get_stats(struct mpls_dev *mdev, + struct mpls_link_stats *stats) +{ + struct mpls_pcpu_stats *p; + int i; + + memset(stats, 0, sizeof(*stats)); + + for_each_possible_cpu(i) { + struct mpls_link_stats local; + unsigned int start; + + p = per_cpu_ptr(mdev->stats, i); + do { + start = u64_stats_fetch_begin_irq(&p->syncp); + local = p->stats; + } while (u64_stats_fetch_retry_irq(&p->syncp, start)); + + stats->rx_packets += local.rx_packets; + stats->rx_bytes += local.rx_bytes; + stats->tx_packets += local.tx_packets; + stats->tx_bytes += local.tx_bytes; + stats->rx_errors += local.rx_errors; + stats->tx_errors += local.tx_errors; + stats->rx_dropped += local.rx_dropped; + stats->tx_dropped += local.tx_dropped; + stats->rx_noroute += local.rx_noroute; + } +} + +static int mpls_fill_stats_af(struct sk_buff *skb, + const struct net_device *dev) +{ + struct mpls_link_stats *stats; + struct mpls_dev *mdev; + struct nlattr *nla; + + mdev = mpls_dev_get(dev); + if (!mdev) + return -ENODATA; + + nla = nla_reserve_64bit(skb, MPLS_STATS_LINK, + sizeof(struct mpls_link_stats), + MPLS_STATS_UNSPEC); + if (!nla) + return -EMSGSIZE; + + stats = nla_data(nla); + mpls_get_stats(mdev, stats); + + return 0; +} + +static size_t mpls_get_stats_af_size(const struct net_device *dev) +{ + struct mpls_dev *mdev; + + mdev = mpls_dev_get(dev); + if (!mdev) + return 0; + + return nla_total_size_64bit(sizeof(struct mpls_link_stats)); +} + +static int mpls_netconf_fill_devconf(struct sk_buff *skb, struct mpls_dev *mdev, + u32 portid, u32 seq, int event, + unsigned int flags, int type) +{ + struct nlmsghdr *nlh; + struct netconfmsg *ncm; + bool all = false; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(struct netconfmsg), + flags); + if (!nlh) + return -EMSGSIZE; + + if (type == NETCONFA_ALL) + all = true; + + ncm = nlmsg_data(nlh); + ncm->ncm_family = AF_MPLS; + + if (nla_put_s32(skb, NETCONFA_IFINDEX, mdev->dev->ifindex) < 0) + goto nla_put_failure; + + if ((all || type == NETCONFA_INPUT) && + nla_put_s32(skb, NETCONFA_INPUT, + mdev->input_enabled) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int mpls_netconf_msgsize_devconf(int type) +{ + int size = NLMSG_ALIGN(sizeof(struct netconfmsg)) + + nla_total_size(4); /* NETCONFA_IFINDEX */ + bool all = false; + + if (type == NETCONFA_ALL) + all = true; + + if (all || type == NETCONFA_INPUT) + size += nla_total_size(4); + + return size; +} + +static void mpls_netconf_notify_devconf(struct net *net, int event, + int type, struct mpls_dev *mdev) +{ + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(mpls_netconf_msgsize_devconf(type), GFP_KERNEL); + if (!skb) + goto errout; + + err = mpls_netconf_fill_devconf(skb, mdev, 0, 
0, event, 0, type); + if (err < 0) { + /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + + rtnl_notify(skb, net, 0, RTNLGRP_MPLS_NETCONF, NULL, GFP_KERNEL); + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_MPLS_NETCONF, err); +} + +static const struct nla_policy devconf_mpls_policy[NETCONFA_MAX + 1] = { + [NETCONFA_IFINDEX] = { .len = sizeof(int) }, +}; + +static int mpls_netconf_valid_get_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(struct netconfmsg))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for netconf get request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_mpls_policy, extack); + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(struct netconfmsg), + tb, NETCONFA_MAX, + devconf_mpls_policy, extack); + if (err) + return err; + + for (i = 0; i <= NETCONFA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case NETCONFA_IFINDEX: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in netconf get request"); + return -EINVAL; + } + } + + return 0; +} + +static int mpls_netconf_get_devconf(struct sk_buff *in_skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + struct nlattr *tb[NETCONFA_MAX + 1]; + struct net_device *dev; + struct mpls_dev *mdev; + struct sk_buff *skb; + int ifindex; + int err; + + err = mpls_netconf_valid_get_req(in_skb, nlh, tb, extack); + if (err < 0) + goto errout; + + err = -EINVAL; + if (!tb[NETCONFA_IFINDEX]) + goto errout; + + ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); + dev = __dev_get_by_index(net, ifindex); + if (!dev) + goto errout; + + mdev = mpls_dev_get(dev); + if (!mdev) + goto errout; + + err = -ENOBUFS; + skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL); + if (!skb) + goto errout; + + err = mpls_netconf_fill_devconf(skb, mdev, + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, RTM_NEWNETCONF, 0, + NETCONFA_ALL); + if (err < 0) { + /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); +errout: + return err; +} + +static int mpls_netconf_dump_devconf(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct hlist_head *head; + struct net_device *dev; + struct mpls_dev *mdev; + int idx, s_idx; + int h, s_h; + + if (cb->strict_check) { + struct netlink_ext_ack *extack = cb->extack; + struct netconfmsg *ncm; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for netconf dump request"); + return -EINVAL; + } + + if (nlmsg_attrlen(nlh, sizeof(*ncm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid data after header in netconf dump request"); + return -EINVAL; + } + } + + s_h = cb->args[0]; + s_idx = idx = cb->args[1]; + + for (h = s_h; h < NETDEV_HASHENTRIES; h++, s_idx = 0) { + idx = 0; + head = &net->dev_index_head[h]; + rcu_read_lock(); + cb->seq = net->dev_base_seq; + hlist_for_each_entry_rcu(dev, head, index_hlist) { + if (idx < s_idx) + goto cont; + mdev = mpls_dev_get(dev); + if (!mdev) + goto cont; + if (mpls_netconf_fill_devconf(skb, mdev, + 
NETLINK_CB(cb->skb).portid, + nlh->nlmsg_seq, + RTM_NEWNETCONF, + NLM_F_MULTI, + NETCONFA_ALL) < 0) { + rcu_read_unlock(); + goto done; + } + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + rcu_read_unlock(); + } +done: + cb->args[0] = h; + cb->args[1] = idx; + + return skb->len; +} + +#define MPLS_PERDEV_SYSCTL_OFFSET(field) \ + (&((struct mpls_dev *)0)->field) + +static int mpls_conf_proc(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int oval = *(int *)ctl->data; + int ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + + if (write) { + struct mpls_dev *mdev = ctl->extra1; + int i = (int *)ctl->data - (int *)mdev; + struct net *net = ctl->extra2; + int val = *(int *)ctl->data; + + if (i == offsetof(struct mpls_dev, input_enabled) && + val != oval) { + mpls_netconf_notify_devconf(net, RTM_NEWNETCONF, + NETCONFA_INPUT, mdev); + } + } + + return ret; +} + +static const struct ctl_table mpls_dev_table[] = { + { + .procname = "input", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = mpls_conf_proc, + .data = MPLS_PERDEV_SYSCTL_OFFSET(input_enabled), + }, + { } +}; + +static int mpls_dev_sysctl_register(struct net_device *dev, + struct mpls_dev *mdev) +{ + char path[sizeof("net/mpls/conf/") + IFNAMSIZ]; + struct net *net = dev_net(dev); + struct ctl_table *table; + int i; + + table = kmemdup(&mpls_dev_table, sizeof(mpls_dev_table), GFP_KERNEL); + if (!table) + goto out; + + /* Table data contains only offsets relative to the base of + * the mdev at this point, so make them absolute. + */ + for (i = 0; i < ARRAY_SIZE(mpls_dev_table); i++) { + table[i].data = (char *)mdev + (uintptr_t)table[i].data; + table[i].extra1 = mdev; + table[i].extra2 = net; + } + + snprintf(path, sizeof(path), "net/mpls/conf/%s", dev->name); + + mdev->sysctl = register_net_sysctl(net, path, table); + if (!mdev->sysctl) + goto free; + + mpls_netconf_notify_devconf(net, RTM_NEWNETCONF, NETCONFA_ALL, mdev); + return 0; + +free: + kfree(table); +out: + mdev->sysctl = NULL; + return -ENOBUFS; +} + +static void mpls_dev_sysctl_unregister(struct net_device *dev, + struct mpls_dev *mdev) +{ + struct net *net = dev_net(dev); + struct ctl_table *table; + + if (!mdev->sysctl) + return; + + table = mdev->sysctl->ctl_table_arg; + unregister_net_sysctl_table(mdev->sysctl); + kfree(table); + + mpls_netconf_notify_devconf(net, RTM_DELNETCONF, 0, mdev); +} + +static struct mpls_dev *mpls_add_dev(struct net_device *dev) +{ + struct mpls_dev *mdev; + int err = -ENOMEM; + int i; + + ASSERT_RTNL(); + + mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); + if (!mdev) + return ERR_PTR(err); + + mdev->stats = alloc_percpu(struct mpls_pcpu_stats); + if (!mdev->stats) + goto free; + + for_each_possible_cpu(i) { + struct mpls_pcpu_stats *mpls_stats; + + mpls_stats = per_cpu_ptr(mdev->stats, i); + u64_stats_init(&mpls_stats->syncp); + } + + mdev->dev = dev; + + err = mpls_dev_sysctl_register(dev, mdev); + if (err) + goto free; + + rcu_assign_pointer(dev->mpls_ptr, mdev); + + return mdev; + +free: + free_percpu(mdev->stats); + kfree(mdev); + return ERR_PTR(err); +} + +static void mpls_dev_destroy_rcu(struct rcu_head *head) +{ + struct mpls_dev *mdev = container_of(head, struct mpls_dev, rcu); + + free_percpu(mdev->stats); + kfree(mdev); +} + +static int mpls_ifdown(struct net_device *dev, int event) +{ + struct mpls_route __rcu **platform_label; + struct net *net = dev_net(dev); + unsigned index; + + platform_label = rtnl_dereference(net->mpls.platform_label); + for (index = 0; index < 
net->mpls.platform_labels; index++) { + struct mpls_route *rt = rtnl_dereference(platform_label[index]); + bool nh_del = false; + u8 alive = 0; + + if (!rt) + continue; + + if (event == NETDEV_UNREGISTER) { + u8 deleted = 0; + + for_nexthops(rt) { + if (!nh->nh_dev || nh->nh_dev == dev) + deleted++; + if (nh->nh_dev == dev) + nh_del = true; + } endfor_nexthops(rt); + + /* if there are no more nexthops, delete the route */ + if (deleted == rt->rt_nhn) { + mpls_route_update(net, index, NULL, NULL); + continue; + } + + if (nh_del) { + size_t size = sizeof(*rt) + rt->rt_nhn * + rt->rt_nh_size; + struct mpls_route *orig = rt; + + rt = kmemdup(orig, size, GFP_KERNEL); + if (!rt) + return -ENOMEM; + } + } + + change_nexthops(rt) { + unsigned int nh_flags = nh->nh_flags; + + if (nh->nh_dev != dev) + goto next; + + switch (event) { + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + nh_flags |= RTNH_F_DEAD; + fallthrough; + case NETDEV_CHANGE: + nh_flags |= RTNH_F_LINKDOWN; + break; + } + if (event == NETDEV_UNREGISTER) + nh->nh_dev = NULL; + + if (nh->nh_flags != nh_flags) + WRITE_ONCE(nh->nh_flags, nh_flags); +next: + if (!(nh_flags & (RTNH_F_DEAD | RTNH_F_LINKDOWN))) + alive++; + } endfor_nexthops(rt); + + WRITE_ONCE(rt->rt_nhn_alive, alive); + + if (nh_del) + mpls_route_update(net, index, rt, NULL); + } + + return 0; +} + +static void mpls_ifup(struct net_device *dev, unsigned int flags) +{ + struct mpls_route __rcu **platform_label; + struct net *net = dev_net(dev); + unsigned index; + u8 alive; + + platform_label = rtnl_dereference(net->mpls.platform_label); + for (index = 0; index < net->mpls.platform_labels; index++) { + struct mpls_route *rt = rtnl_dereference(platform_label[index]); + + if (!rt) + continue; + + alive = 0; + change_nexthops(rt) { + unsigned int nh_flags = nh->nh_flags; + + if (!(nh_flags & flags)) { + alive++; + continue; + } + if (nh->nh_dev != dev) + continue; + alive++; + nh_flags &= ~flags; + WRITE_ONCE(nh->nh_flags, nh_flags); + } endfor_nexthops(rt); + + WRITE_ONCE(rt->rt_nhn_alive, alive); + } +} + +static int mpls_dev_notify(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct mpls_dev *mdev; + unsigned int flags; + int err; + + if (event == NETDEV_REGISTER) { + mdev = mpls_add_dev(dev); + if (IS_ERR(mdev)) + return notifier_from_errno(PTR_ERR(mdev)); + + return NOTIFY_OK; + } + + mdev = mpls_dev_get(dev); + if (!mdev) + return NOTIFY_OK; + + switch (event) { + + case NETDEV_DOWN: + err = mpls_ifdown(dev, event); + if (err) + return notifier_from_errno(err); + break; + case NETDEV_UP: + flags = dev_get_flags(dev); + if (flags & (IFF_RUNNING | IFF_LOWER_UP)) + mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN); + else + mpls_ifup(dev, RTNH_F_DEAD); + break; + case NETDEV_CHANGE: + flags = dev_get_flags(dev); + if (flags & (IFF_RUNNING | IFF_LOWER_UP)) { + mpls_ifup(dev, RTNH_F_DEAD | RTNH_F_LINKDOWN); + } else { + err = mpls_ifdown(dev, event); + if (err) + return notifier_from_errno(err); + } + break; + case NETDEV_UNREGISTER: + err = mpls_ifdown(dev, event); + if (err) + return notifier_from_errno(err); + mdev = mpls_dev_get(dev); + if (mdev) { + mpls_dev_sysctl_unregister(dev, mdev); + RCU_INIT_POINTER(dev->mpls_ptr, NULL); + call_rcu(&mdev->rcu, mpls_dev_destroy_rcu); + } + break; + case NETDEV_CHANGENAME: + mdev = mpls_dev_get(dev); + if (mdev) { + mpls_dev_sysctl_unregister(dev, mdev); + err = mpls_dev_sysctl_register(dev, mdev); + if (err) + return notifier_from_errno(err); + } + break; 
+ } + return NOTIFY_OK; +} + +static struct notifier_block mpls_dev_notifier = { + .notifier_call = mpls_dev_notify, +}; + +static int nla_put_via(struct sk_buff *skb, + u8 table, const void *addr, int alen) +{ + static const int table_to_family[NEIGH_NR_TABLES + 1] = { + AF_INET, AF_INET6, AF_DECnet, AF_PACKET, + }; + struct nlattr *nla; + struct rtvia *via; + int family = AF_UNSPEC; + + nla = nla_reserve(skb, RTA_VIA, alen + 2); + if (!nla) + return -EMSGSIZE; + + if (table <= NEIGH_NR_TABLES) + family = table_to_family[table]; + + via = nla_data(nla); + via->rtvia_family = family; + memcpy(via->rtvia_addr, addr, alen); + return 0; +} + +int nla_put_labels(struct sk_buff *skb, int attrtype, + u8 labels, const u32 label[]) +{ + struct nlattr *nla; + struct mpls_shim_hdr *nla_label; + bool bos; + int i; + nla = nla_reserve(skb, attrtype, labels*4); + if (!nla) + return -EMSGSIZE; + + nla_label = nla_data(nla); + bos = true; + for (i = labels - 1; i >= 0; i--) { + nla_label[i] = mpls_entry_encode(label[i], 0, 0, bos); + bos = false; + } + + return 0; +} +EXPORT_SYMBOL_GPL(nla_put_labels); + +int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels, + u32 label[], struct netlink_ext_ack *extack) +{ + unsigned len = nla_len(nla); + struct mpls_shim_hdr *nla_label; + u8 nla_labels; + bool bos; + int i; + + /* len needs to be an even multiple of 4 (the label size). Number + * of labels is a u8 so check for overflow. + */ + if (len & 3 || len / 4 > 255) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Invalid length for labels attribute"); + return -EINVAL; + } + + /* Limit the number of new labels allowed */ + nla_labels = len/4; + if (nla_labels > max_labels) { + NL_SET_ERR_MSG(extack, "Too many labels"); + return -EINVAL; + } + + /* when label == NULL, caller wants number of labels */ + if (!label) + goto out; + + nla_label = nla_data(nla); + bos = true; + for (i = nla_labels - 1; i >= 0; i--, bos = false) { + struct mpls_entry_decoded dec; + dec = mpls_entry_decode(nla_label + i); + + /* Ensure the bottom of stack flag is properly set + * and ttl and tc are both clear. + */ + if (dec.ttl) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "TTL in label must be 0"); + return -EINVAL; + } + + if (dec.tc) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "Traffic class in label must be 0"); + return -EINVAL; + } + + if (dec.bos != bos) { + NL_SET_BAD_ATTR(extack, nla); + if (bos) { + NL_SET_ERR_MSG(extack, + "BOS bit must be set in first label"); + } else { + NL_SET_ERR_MSG(extack, + "BOS bit can only be set in first label"); + } + return -EINVAL; + } + + switch (dec.label) { + case MPLS_LABEL_IMPLNULL: + /* RFC3032: This is a label that an LSR may + * assign and distribute, but which never + * actually appears in the encapsulation. 
+ */ + NL_SET_ERR_MSG_ATTR(extack, nla, + "Implicit NULL Label (3) can not be used in encapsulation"); + return -EINVAL; + } + + label[i] = dec.label; + } +out: + *labels = nla_labels; + return 0; +} +EXPORT_SYMBOL_GPL(nla_get_labels); + +static int rtm_to_route_config(struct sk_buff *skb, + struct nlmsghdr *nlh, + struct mpls_route_config *cfg, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + struct nlattr *tb[RTA_MAX+1]; + int index; + int err; + + err = nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + if (err < 0) + goto errout; + + err = -EINVAL; + rtm = nlmsg_data(nlh); + + if (rtm->rtm_family != AF_MPLS) { + NL_SET_ERR_MSG(extack, "Invalid address family in rtmsg"); + goto errout; + } + if (rtm->rtm_dst_len != 20) { + NL_SET_ERR_MSG(extack, "rtm_dst_len must be 20 for MPLS"); + goto errout; + } + if (rtm->rtm_src_len != 0) { + NL_SET_ERR_MSG(extack, "rtm_src_len must be 0 for MPLS"); + goto errout; + } + if (rtm->rtm_tos != 0) { + NL_SET_ERR_MSG(extack, "rtm_tos must be 0 for MPLS"); + goto errout; + } + if (rtm->rtm_table != RT_TABLE_MAIN) { + NL_SET_ERR_MSG(extack, + "MPLS only supports the main route table"); + goto errout; + } + /* Any value is acceptable for rtm_protocol */ + + /* As mpls uses destination specific addresses + * (or source specific address in the case of multicast) + * all addresses have universal scope. + */ + if (rtm->rtm_scope != RT_SCOPE_UNIVERSE) { + NL_SET_ERR_MSG(extack, + "Invalid route scope - MPLS only supports UNIVERSE"); + goto errout; + } + if (rtm->rtm_type != RTN_UNICAST) { + NL_SET_ERR_MSG(extack, + "Invalid route type - MPLS only supports UNICAST"); + goto errout; + } + if (rtm->rtm_flags != 0) { + NL_SET_ERR_MSG(extack, "rtm_flags must be 0 for MPLS"); + goto errout; + } + + cfg->rc_label = LABEL_NOT_SPECIFIED; + cfg->rc_protocol = rtm->rtm_protocol; + cfg->rc_via_table = MPLS_NEIGH_TABLE_UNSPEC; + cfg->rc_ttl_propagate = MPLS_TTL_PROP_DEFAULT; + cfg->rc_nlflags = nlh->nlmsg_flags; + cfg->rc_nlinfo.portid = NETLINK_CB(skb).portid; + cfg->rc_nlinfo.nlh = nlh; + cfg->rc_nlinfo.nl_net = sock_net(skb->sk); + + for (index = 0; index <= RTA_MAX; index++) { + struct nlattr *nla = tb[index]; + if (!nla) + continue; + + switch (index) { + case RTA_OIF: + cfg->rc_ifindex = nla_get_u32(nla); + break; + case RTA_NEWDST: + if (nla_get_labels(nla, MAX_NEW_LABELS, + &cfg->rc_output_labels, + cfg->rc_output_label, extack)) + goto errout; + break; + case RTA_DST: + { + u8 label_count; + if (nla_get_labels(nla, 1, &label_count, + &cfg->rc_label, extack)) + goto errout; + + if (!mpls_label_ok(cfg->rc_nlinfo.nl_net, + &cfg->rc_label, extack)) + goto errout; + break; + } + case RTA_GATEWAY: + NL_SET_ERR_MSG(extack, "MPLS does not support RTA_GATEWAY attribute"); + goto errout; + case RTA_VIA: + { + if (nla_get_via(nla, &cfg->rc_via_alen, + &cfg->rc_via_table, cfg->rc_via, + extack)) + goto errout; + break; + } + case RTA_MULTIPATH: + { + cfg->rc_mp = nla_data(nla); + cfg->rc_mp_len = nla_len(nla); + break; + } + case RTA_TTL_PROPAGATE: + { + u8 ttl_propagate = nla_get_u8(nla); + + if (ttl_propagate > 1) { + NL_SET_ERR_MSG_ATTR(extack, nla, + "RTA_TTL_PROPAGATE can only be 0 or 1"); + goto errout; + } + cfg->rc_ttl_propagate = ttl_propagate ? 
+ MPLS_TTL_PROP_ENABLED : + MPLS_TTL_PROP_DISABLED; + break; + } + default: + NL_SET_ERR_MSG_ATTR(extack, nla, "Unknown attribute"); + /* Unsupported attribute */ + goto errout; + } + } + + err = 0; +errout: + return err; +} + +static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct mpls_route_config *cfg; + int err; + + cfg = kzalloc(sizeof(*cfg), GFP_KERNEL); + if (!cfg) + return -ENOMEM; + + err = rtm_to_route_config(skb, nlh, cfg, extack); + if (err < 0) + goto out; + + err = mpls_route_del(cfg, extack); +out: + kfree(cfg); + + return err; +} + + +static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct mpls_route_config *cfg; + int err; + + cfg = kzalloc(sizeof(*cfg), GFP_KERNEL); + if (!cfg) + return -ENOMEM; + + err = rtm_to_route_config(skb, nlh, cfg, extack); + if (err < 0) + goto out; + + err = mpls_route_add(cfg, extack); +out: + kfree(cfg); + + return err; +} + +static int mpls_dump_route(struct sk_buff *skb, u32 portid, u32 seq, int event, + u32 label, struct mpls_route *rt, int flags) +{ + struct net_device *dev; + struct nlmsghdr *nlh; + struct rtmsg *rtm; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), flags); + if (nlh == NULL) + return -EMSGSIZE; + + rtm = nlmsg_data(nlh); + rtm->rtm_family = AF_MPLS; + rtm->rtm_dst_len = 20; + rtm->rtm_src_len = 0; + rtm->rtm_tos = 0; + rtm->rtm_table = RT_TABLE_MAIN; + rtm->rtm_protocol = rt->rt_protocol; + rtm->rtm_scope = RT_SCOPE_UNIVERSE; + rtm->rtm_type = RTN_UNICAST; + rtm->rtm_flags = 0; + + if (nla_put_labels(skb, RTA_DST, 1, &label)) + goto nla_put_failure; + + if (rt->rt_ttl_propagate != MPLS_TTL_PROP_DEFAULT) { + bool ttl_propagate = + rt->rt_ttl_propagate == MPLS_TTL_PROP_ENABLED; + + if (nla_put_u8(skb, RTA_TTL_PROPAGATE, + ttl_propagate)) + goto nla_put_failure; + } + if (rt->rt_nhn == 1) { + const struct mpls_nh *nh = rt->rt_nh; + + if (nh->nh_labels && + nla_put_labels(skb, RTA_NEWDST, nh->nh_labels, + nh->nh_label)) + goto nla_put_failure; + if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC && + nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh), + nh->nh_via_alen)) + goto nla_put_failure; + dev = nh->nh_dev; + if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex)) + goto nla_put_failure; + if (nh->nh_flags & RTNH_F_LINKDOWN) + rtm->rtm_flags |= RTNH_F_LINKDOWN; + if (nh->nh_flags & RTNH_F_DEAD) + rtm->rtm_flags |= RTNH_F_DEAD; + } else { + struct rtnexthop *rtnh; + struct nlattr *mp; + u8 linkdown = 0; + u8 dead = 0; + + mp = nla_nest_start_noflag(skb, RTA_MULTIPATH); + if (!mp) + goto nla_put_failure; + + for_nexthops(rt) { + dev = nh->nh_dev; + if (!dev) + continue; + + rtnh = nla_reserve_nohdr(skb, sizeof(*rtnh)); + if (!rtnh) + goto nla_put_failure; + + rtnh->rtnh_ifindex = dev->ifindex; + if (nh->nh_flags & RTNH_F_LINKDOWN) { + rtnh->rtnh_flags |= RTNH_F_LINKDOWN; + linkdown++; + } + if (nh->nh_flags & RTNH_F_DEAD) { + rtnh->rtnh_flags |= RTNH_F_DEAD; + dead++; + } + + if (nh->nh_labels && nla_put_labels(skb, RTA_NEWDST, + nh->nh_labels, + nh->nh_label)) + goto nla_put_failure; + if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC && + nla_put_via(skb, nh->nh_via_table, + mpls_nh_via(rt, nh), + nh->nh_via_alen)) + goto nla_put_failure; + + /* length of rtnetlink header + attributes */ + rtnh->rtnh_len = nlmsg_get_pos(skb) - (void *)rtnh; + } endfor_nexthops(rt); + + if (linkdown == rt->rt_nhn) + rtm->rtm_flags |= RTNH_F_LINKDOWN; + if (dead == rt->rt_nhn) + rtm->rtm_flags |= RTNH_F_DEAD; 
+ + nla_nest_end(skb, mp); + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +#if IS_ENABLED(CONFIG_INET) +static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + return ip_valid_fib_dump_req(net, nlh, filter, cb); +} +#else +static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh, + struct fib_dump_filter *filter, + struct netlink_callback *cb) +{ + struct netlink_ext_ack *extack = cb->extack; + struct nlattr *tb[RTA_MAX + 1]; + struct rtmsg *rtm; + int err, i; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid header for FIB dump request"); + return -EINVAL; + } + + rtm = nlmsg_data(nlh); + if (rtm->rtm_dst_len || rtm->rtm_src_len || rtm->rtm_tos || + rtm->rtm_table || rtm->rtm_scope || rtm->rtm_type || + rtm->rtm_flags) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for FIB dump request"); + return -EINVAL; + } + + if (rtm->rtm_protocol) { + filter->protocol = rtm->rtm_protocol; + filter->filter_set = 1; + cb->answer_flags = NLM_F_DUMP_FILTERED; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + if (err < 0) + return err; + + for (i = 0; i <= RTA_MAX; ++i) { + int ifindex; + + if (i == RTA_OIF) { + ifindex = nla_get_u32(tb[i]); + filter->dev = __dev_get_by_index(net, ifindex); + if (!filter->dev) + return -ENODEV; + filter->filter_set = 1; + } else if (tb[i]) { + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in dump request"); + return -EINVAL; + } + } + + return 0; +} +#endif + +static bool mpls_rt_uses_dev(struct mpls_route *rt, + const struct net_device *dev) +{ + if (rt->rt_nhn == 1) { + struct mpls_nh *nh = rt->rt_nh; + + if (nh->nh_dev == dev) + return true; + } else { + for_nexthops(rt) { + if (nh->nh_dev == dev) + return true; + } endfor_nexthops(rt); + } + + return false; +} + +static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nlmsghdr *nlh = cb->nlh; + struct net *net = sock_net(skb->sk); + struct mpls_route __rcu **platform_label; + struct fib_dump_filter filter = {}; + unsigned int flags = NLM_F_MULTI; + size_t platform_labels; + unsigned int index; + + ASSERT_RTNL(); + + if (cb->strict_check) { + int err; + + err = mpls_valid_fib_dump_req(net, nlh, &filter, cb); + if (err < 0) + return err; + + /* for MPLS, there is only 1 table with fixed type and flags. + * If either are set in the filter then return nothing. 
+ */ + if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) || + (filter.rt_type && filter.rt_type != RTN_UNICAST) || + filter.flags) + return skb->len; + } + + index = cb->args[0]; + if (index < MPLS_LABEL_FIRST_UNRESERVED) + index = MPLS_LABEL_FIRST_UNRESERVED; + + platform_label = rtnl_dereference(net->mpls.platform_label); + platform_labels = net->mpls.platform_labels; + + if (filter.filter_set) + flags |= NLM_F_DUMP_FILTERED; + + for (; index < platform_labels; index++) { + struct mpls_route *rt; + + rt = rtnl_dereference(platform_label[index]); + if (!rt) + continue; + + if ((filter.dev && !mpls_rt_uses_dev(rt, filter.dev)) || + (filter.protocol && rt->rt_protocol != filter.protocol)) + continue; + + if (mpls_dump_route(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_NEWROUTE, + index, rt, flags) < 0) + break; + } + cb->args[0] = index; + + return skb->len; +} + +static inline size_t lfib_nlmsg_size(struct mpls_route *rt) +{ + size_t payload = + NLMSG_ALIGN(sizeof(struct rtmsg)) + + nla_total_size(4) /* RTA_DST */ + + nla_total_size(1); /* RTA_TTL_PROPAGATE */ + + if (rt->rt_nhn == 1) { + struct mpls_nh *nh = rt->rt_nh; + + if (nh->nh_dev) + payload += nla_total_size(4); /* RTA_OIF */ + if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) /* RTA_VIA */ + payload += nla_total_size(2 + nh->nh_via_alen); + if (nh->nh_labels) /* RTA_NEWDST */ + payload += nla_total_size(nh->nh_labels * 4); + } else { + /* each nexthop is packed in an attribute */ + size_t nhsize = 0; + + for_nexthops(rt) { + if (!nh->nh_dev) + continue; + nhsize += nla_total_size(sizeof(struct rtnexthop)); + /* RTA_VIA */ + if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC) + nhsize += nla_total_size(2 + nh->nh_via_alen); + if (nh->nh_labels) + nhsize += nla_total_size(nh->nh_labels * 4); + } endfor_nexthops(rt); + /* nested attribute */ + payload += nla_total_size(nhsize); + } + + return payload; +} + +static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt, + struct nlmsghdr *nlh, struct net *net, u32 portid, + unsigned int nlm_flags) +{ + struct sk_buff *skb; + u32 seq = nlh ? 
nlh->nlmsg_seq : 0; + int err = -ENOBUFS; + + skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL); + if (skb == NULL) + goto errout; + + err = mpls_dump_route(skb, portid, seq, event, label, rt, nlm_flags); + if (err < 0) { + /* -EMSGSIZE implies BUG in lfib_nlmsg_size */ + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, net, portid, RTNLGRP_MPLS_ROUTE, nlh, GFP_KERNEL); + + return; +errout: + if (err < 0) + rtnl_set_sk_err(net, RTNLGRP_MPLS_ROUTE, err); +} + +static int mpls_valid_getroute_req(struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct rtmsg *rtm; + int i, err; + + if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*rtm))) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid header for get route request"); + return -EINVAL; + } + + if (!netlink_strict_get_check(skb)) + return nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + + rtm = nlmsg_data(nlh); + if ((rtm->rtm_dst_len && rtm->rtm_dst_len != 20) || + rtm->rtm_src_len || rtm->rtm_tos || rtm->rtm_table || + rtm->rtm_protocol || rtm->rtm_scope || rtm->rtm_type) { + NL_SET_ERR_MSG_MOD(extack, "Invalid values in header for get route request"); + return -EINVAL; + } + if (rtm->rtm_flags & ~RTM_F_FIB_MATCH) { + NL_SET_ERR_MSG_MOD(extack, + "Invalid flags for get route request"); + return -EINVAL; + } + + err = nlmsg_parse_deprecated_strict(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_mpls_policy, extack); + if (err) + return err; + + if ((tb[RTA_DST] || tb[RTA_NEWDST]) && !rtm->rtm_dst_len) { + NL_SET_ERR_MSG_MOD(extack, "rtm_dst_len must be 20 for MPLS"); + return -EINVAL; + } + + for (i = 0; i <= RTA_MAX; i++) { + if (!tb[i]) + continue; + + switch (i) { + case RTA_DST: + case RTA_NEWDST: + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported attribute in get route request"); + return -EINVAL; + } + } + + return 0; +} + +static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(in_skb->sk); + u32 portid = NETLINK_CB(in_skb).portid; + u32 in_label = LABEL_NOT_SPECIFIED; + struct nlattr *tb[RTA_MAX + 1]; + u32 labels[MAX_NEW_LABELS]; + struct mpls_shim_hdr *hdr; + unsigned int hdr_size = 0; + const struct mpls_nh *nh; + struct net_device *dev; + struct mpls_route *rt; + struct rtmsg *rtm, *r; + struct nlmsghdr *nlh; + struct sk_buff *skb; + u8 n_labels; + int err; + + err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack); + if (err < 0) + goto errout; + + rtm = nlmsg_data(in_nlh); + + if (tb[RTA_DST]) { + u8 label_count; + + if (nla_get_labels(tb[RTA_DST], 1, &label_count, + &in_label, extack)) { + err = -EINVAL; + goto errout; + } + + if (!mpls_label_ok(net, &in_label, extack)) { + err = -EINVAL; + goto errout; + } + } + + rt = mpls_route_input_rcu(net, in_label); + if (!rt) { + err = -ENETUNREACH; + goto errout; + } + + if (rtm->rtm_flags & RTM_F_FIB_MATCH) { + skb = nlmsg_new(lfib_nlmsg_size(rt), GFP_KERNEL); + if (!skb) { + err = -ENOBUFS; + goto errout; + } + + err = mpls_dump_route(skb, portid, in_nlh->nlmsg_seq, + RTM_NEWROUTE, in_label, rt, 0); + if (err < 0) { + /* -EMSGSIZE implies BUG in lfib_nlmsg_size */ + WARN_ON(err == -EMSGSIZE); + goto errout_free; + } + + return rtnl_unicast(skb, net, portid); + } + + if (tb[RTA_NEWDST]) { + if (nla_get_labels(tb[RTA_NEWDST], MAX_NEW_LABELS, &n_labels, + labels, extack) != 0) { + err = -EINVAL; + goto errout; + } + + hdr_size = n_labels * sizeof(struct mpls_shim_hdr); + } + + skb = 
alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) { + err = -ENOBUFS; + goto errout; + } + + skb->protocol = htons(ETH_P_MPLS_UC); + + if (hdr_size) { + bool bos; + int i; + + if (skb_cow(skb, hdr_size)) { + err = -ENOBUFS; + goto errout_free; + } + + skb_reserve(skb, hdr_size); + skb_push(skb, hdr_size); + skb_reset_network_header(skb); + + /* Push new labels */ + hdr = mpls_hdr(skb); + bos = true; + for (i = n_labels - 1; i >= 0; i--) { + hdr[i] = mpls_entry_encode(labels[i], + 1, 0, bos); + bos = false; + } + } + + nh = mpls_select_multipath(rt, skb); + if (!nh) { + err = -ENETUNREACH; + goto errout_free; + } + + if (hdr_size) { + skb_pull(skb, hdr_size); + skb_reset_network_header(skb); + } + + nlh = nlmsg_put(skb, portid, in_nlh->nlmsg_seq, + RTM_NEWROUTE, sizeof(*r), 0); + if (!nlh) { + err = -EMSGSIZE; + goto errout_free; + } + + r = nlmsg_data(nlh); + r->rtm_family = AF_MPLS; + r->rtm_dst_len = 20; + r->rtm_src_len = 0; + r->rtm_table = RT_TABLE_MAIN; + r->rtm_type = RTN_UNICAST; + r->rtm_scope = RT_SCOPE_UNIVERSE; + r->rtm_protocol = rt->rt_protocol; + r->rtm_flags = 0; + + if (nla_put_labels(skb, RTA_DST, 1, &in_label)) + goto nla_put_failure; + + if (nh->nh_labels && + nla_put_labels(skb, RTA_NEWDST, nh->nh_labels, + nh->nh_label)) + goto nla_put_failure; + + if (nh->nh_via_table != MPLS_NEIGH_TABLE_UNSPEC && + nla_put_via(skb, nh->nh_via_table, mpls_nh_via(rt, nh), + nh->nh_via_alen)) + goto nla_put_failure; + dev = nh->nh_dev; + if (dev && nla_put_u32(skb, RTA_OIF, dev->ifindex)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + + err = rtnl_unicast(skb, net, portid); +errout: + return err; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + err = -EMSGSIZE; +errout_free: + kfree_skb(skb); + return err; +} + +static int resize_platform_label_table(struct net *net, size_t limit) +{ + size_t size = sizeof(struct mpls_route *) * limit; + size_t old_limit; + size_t cp_size; + struct mpls_route __rcu **labels = NULL, **old; + struct mpls_route *rt0 = NULL, *rt2 = NULL; + unsigned index; + + if (size) { + labels = kvzalloc(size, GFP_KERNEL); + if (!labels) + goto nolabels; + } + + /* In case the predefined labels need to be populated */ + if (limit > MPLS_LABEL_IPV4NULL) { + struct net_device *lo = net->loopback_dev; + rt0 = mpls_rt_alloc(1, lo->addr_len, 0); + if (IS_ERR(rt0)) + goto nort0; + rt0->rt_nh->nh_dev = lo; + rt0->rt_protocol = RTPROT_KERNEL; + rt0->rt_payload_type = MPT_IPV4; + rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; + rt0->rt_nh->nh_via_table = NEIGH_LINK_TABLE; + rt0->rt_nh->nh_via_alen = lo->addr_len; + memcpy(__mpls_nh_via(rt0, rt0->rt_nh), lo->dev_addr, + lo->addr_len); + } + if (limit > MPLS_LABEL_IPV6NULL) { + struct net_device *lo = net->loopback_dev; + rt2 = mpls_rt_alloc(1, lo->addr_len, 0); + if (IS_ERR(rt2)) + goto nort2; + rt2->rt_nh->nh_dev = lo; + rt2->rt_protocol = RTPROT_KERNEL; + rt2->rt_payload_type = MPT_IPV6; + rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; + rt2->rt_nh->nh_via_table = NEIGH_LINK_TABLE; + rt2->rt_nh->nh_via_alen = lo->addr_len; + memcpy(__mpls_nh_via(rt2, rt2->rt_nh), lo->dev_addr, + lo->addr_len); + } + + rtnl_lock(); + /* Remember the original table */ + old = rtnl_dereference(net->mpls.platform_label); + old_limit = net->mpls.platform_labels; + + /* Free any labels beyond the new table */ + for (index = limit; index < old_limit; index++) + mpls_route_update(net, index, NULL, NULL); + + /* Copy over the old labels */ + cp_size = size; + if (old_limit < limit) + cp_size = old_limit * sizeof(struct mpls_route *); + + 
memcpy(labels, old, cp_size); + + /* If needed set the predefined labels */ + if ((old_limit <= MPLS_LABEL_IPV6NULL) && + (limit > MPLS_LABEL_IPV6NULL)) { + RCU_INIT_POINTER(labels[MPLS_LABEL_IPV6NULL], rt2); + rt2 = NULL; + } + + if ((old_limit <= MPLS_LABEL_IPV4NULL) && + (limit > MPLS_LABEL_IPV4NULL)) { + RCU_INIT_POINTER(labels[MPLS_LABEL_IPV4NULL], rt0); + rt0 = NULL; + } + + /* Update the global pointers */ + net->mpls.platform_labels = limit; + rcu_assign_pointer(net->mpls.platform_label, labels); + + rtnl_unlock(); + + mpls_rt_free(rt2); + mpls_rt_free(rt0); + + if (old) { + synchronize_rcu(); + kvfree(old); + } + return 0; + +nort2: + mpls_rt_free(rt0); +nort0: + kvfree(labels); +nolabels: + return -ENOMEM; +} + +static int mpls_platform_labels(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = table->data; + int platform_labels = net->mpls.platform_labels; + int ret; + struct ctl_table tmp = { + .procname = table->procname, + .data = &platform_labels, + .maxlen = sizeof(int), + .mode = table->mode, + .extra1 = SYSCTL_ZERO, + .extra2 = &label_limit, + }; + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) + ret = resize_platform_label_table(net, platform_labels); + + return ret; +} + +#define MPLS_NS_SYSCTL_OFFSET(field) \ + (&((struct net *)0)->field) + +static const struct ctl_table mpls_table[] = { + { + .procname = "platform_labels", + .data = NULL, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = mpls_platform_labels, + }, + { + .procname = "ip_ttl_propagate", + .data = MPLS_NS_SYSCTL_OFFSET(mpls.ip_ttl_propagate), + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "default_ttl", + .data = MPLS_NS_SYSCTL_OFFSET(mpls.default_ttl), + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &ttl_max, + }, + { } +}; + +static int mpls_net_init(struct net *net) +{ + struct ctl_table *table; + int i; + + net->mpls.platform_labels = 0; + net->mpls.platform_label = NULL; + net->mpls.ip_ttl_propagate = 1; + net->mpls.default_ttl = 255; + + table = kmemdup(mpls_table, sizeof(mpls_table), GFP_KERNEL); + if (table == NULL) + return -ENOMEM; + + /* Table data contains only offsets relative to the base of + * the mdev at this point, so make them absolute. + */ + for (i = 0; i < ARRAY_SIZE(mpls_table) - 1; i++) + table[i].data = (char *)net + (uintptr_t)table[i].data; + + net->mpls.ctl = register_net_sysctl(net, "net/mpls", table); + if (net->mpls.ctl == NULL) { + kfree(table); + return -ENOMEM; + } + + return 0; +} + +static void mpls_net_exit(struct net *net) +{ + struct mpls_route __rcu **platform_label; + size_t platform_labels; + struct ctl_table *table; + unsigned int index; + + table = net->mpls.ctl->ctl_table_arg; + unregister_net_sysctl_table(net->mpls.ctl); + kfree(table); + + /* An rcu grace period has passed since there was a device in + * the network namespace (and thus the last in flight packet) + * left this network namespace. This is because + * unregister_netdevice_many and netdev_run_todo has completed + * for each network device that was in this network namespace. + * + * As such no additional rcu synchronization is necessary when + * freeing the platform_label table. 
+ */ + rtnl_lock(); + platform_label = rtnl_dereference(net->mpls.platform_label); + platform_labels = net->mpls.platform_labels; + for (index = 0; index < platform_labels; index++) { + struct mpls_route *rt = rtnl_dereference(platform_label[index]); + RCU_INIT_POINTER(platform_label[index], NULL); + mpls_notify_route(net, index, rt, NULL, NULL); + mpls_rt_free(rt); + } + rtnl_unlock(); + + kvfree(platform_label); +} + +static struct pernet_operations mpls_net_ops = { + .init = mpls_net_init, + .exit = mpls_net_exit, +}; + +static struct rtnl_af_ops mpls_af_ops __read_mostly = { + .family = AF_MPLS, + .fill_stats_af = mpls_fill_stats_af, + .get_stats_af_size = mpls_get_stats_af_size, +}; + +static int __init mpls_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct mpls_shim_hdr) != 4); + + err = register_pernet_subsys(&mpls_net_ops); + if (err) + goto out; + + err = register_netdevice_notifier(&mpls_dev_notifier); + if (err) + goto out_unregister_pernet; + + dev_add_pack(&mpls_packet_type); + + rtnl_af_register(&mpls_af_ops); + + rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_NEWROUTE, + mpls_rtm_newroute, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_DELROUTE, + mpls_rtm_delroute, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETROUTE, + mpls_getroute, mpls_dump_routes, 0); + rtnl_register_module(THIS_MODULE, PF_MPLS, RTM_GETNETCONF, + mpls_netconf_get_devconf, + mpls_netconf_dump_devconf, 0); + err = ipgre_tunnel_encap_add_mpls_ops(); + if (err) + pr_err("Can't add mpls over gre tunnel ops\n"); + + err = 0; +out: + return err; + +out_unregister_pernet: + unregister_pernet_subsys(&mpls_net_ops); + goto out; +} +module_init(mpls_init); + +static void __exit mpls_exit(void) +{ + rtnl_unregister_all(PF_MPLS); + rtnl_af_unregister(&mpls_af_ops); + dev_remove_pack(&mpls_packet_type); + unregister_netdevice_notifier(&mpls_dev_notifier); + unregister_pernet_subsys(&mpls_net_ops); + ipgre_tunnel_encap_del_mpls_ops(); +} +module_exit(mpls_exit); + +MODULE_DESCRIPTION("MultiProtocol Label Switching"); +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS_NETPROTO(PF_MPLS); diff --git a/net/mpls/internal.h b/net/mpls/internal.h new file mode 100644 index 000000000..b9f492ddf --- /dev/null +++ b/net/mpls/internal.h @@ -0,0 +1,202 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef MPLS_INTERNAL_H +#define MPLS_INTERNAL_H +#include + +/* put a reasonable limit on the number of labels + * we will accept from userspace + */ +#define MAX_NEW_LABELS 30 + +struct mpls_entry_decoded { + u32 label; + u8 ttl; + u8 tc; + u8 bos; +}; + +struct mpls_pcpu_stats { + struct mpls_link_stats stats; + struct u64_stats_sync syncp; +}; + +struct mpls_dev { + int input_enabled; + struct net_device *dev; + struct mpls_pcpu_stats __percpu *stats; + + struct ctl_table_header *sysctl; + struct rcu_head rcu; +}; + +#if BITS_PER_LONG == 32 + +#define MPLS_INC_STATS_LEN(mdev, len, pkts_field, bytes_field) \ + do { \ + __typeof__(*(mdev)->stats) *ptr = \ + raw_cpu_ptr((mdev)->stats); \ + local_bh_disable(); \ + u64_stats_update_begin(&ptr->syncp); \ + ptr->stats.pkts_field++; \ + ptr->stats.bytes_field += (len); \ + u64_stats_update_end(&ptr->syncp); \ + local_bh_enable(); \ + } while (0) + +#define MPLS_INC_STATS(mdev, field) \ + do { \ + __typeof__(*(mdev)->stats) *ptr = \ + raw_cpu_ptr((mdev)->stats); \ + local_bh_disable(); \ + u64_stats_update_begin(&ptr->syncp); \ + ptr->stats.field++; \ + u64_stats_update_end(&ptr->syncp); \ + local_bh_enable(); \ + } while (0) + +#else + +#define MPLS_INC_STATS_LEN(mdev, 
len, pkts_field, bytes_field) \ + do { \ + this_cpu_inc((mdev)->stats->stats.pkts_field); \ + this_cpu_add((mdev)->stats->stats.bytes_field, (len)); \ + } while (0) + +#define MPLS_INC_STATS(mdev, field) \ + this_cpu_inc((mdev)->stats->stats.field) + +#endif + +struct sk_buff; + +#define LABEL_NOT_SPECIFIED (1 << 20) + +/* This maximum ha length copied from the definition of struct neighbour */ +#define VIA_ALEN_ALIGN sizeof(unsigned long) +#define MAX_VIA_ALEN (ALIGN(MAX_ADDR_LEN, VIA_ALEN_ALIGN)) + +enum mpls_payload_type { + MPT_UNSPEC, /* IPv4 or IPv6 */ + MPT_IPV4 = 4, + MPT_IPV6 = 6, + + /* Other types not implemented: + * - Pseudo-wire with or without control word (RFC4385) + * - GAL (RFC5586) + */ +}; + +struct mpls_nh { /* next hop label forwarding entry */ + struct net_device *nh_dev; + + /* nh_flags is accessed under RCU in the packet path; it is + * modified handling netdev events with rtnl lock held + */ + unsigned int nh_flags; + u8 nh_labels; + u8 nh_via_alen; + u8 nh_via_table; + u8 nh_reserved1; + + u32 nh_label[]; +}; + +/* offset of via from beginning of mpls_nh */ +#define MPLS_NH_VIA_OFF(num_labels) \ + ALIGN(sizeof(struct mpls_nh) + (num_labels) * sizeof(u32), \ + VIA_ALEN_ALIGN) + +/* all nexthops within a route have the same size based on the + * max number of labels and max via length across all nexthops + */ +#define MPLS_NH_SIZE(num_labels, max_via_alen) \ + (MPLS_NH_VIA_OFF((num_labels)) + \ + ALIGN((max_via_alen), VIA_ALEN_ALIGN)) + +enum mpls_ttl_propagation { + MPLS_TTL_PROP_DEFAULT, + MPLS_TTL_PROP_ENABLED, + MPLS_TTL_PROP_DISABLED, +}; + +/* The route, nexthops and vias are stored together in the same memory + * block: + * + * +----------------------+ + * | mpls_route | + * +----------------------+ + * | mpls_nh 0 | + * +----------------------+ + * | alignment padding | 4 bytes for odd number of labels + * +----------------------+ + * | via[rt_max_alen] 0 | + * +----------------------+ + * | alignment padding | via's aligned on sizeof(unsigned long) + * +----------------------+ + * | ... 
| + * +----------------------+ + * | mpls_nh n-1 | + * +----------------------+ + * | via[rt_max_alen] n-1 | + * +----------------------+ + */ +struct mpls_route { /* next hop label forwarding entry */ + struct rcu_head rt_rcu; + u8 rt_protocol; + u8 rt_payload_type; + u8 rt_max_alen; + u8 rt_ttl_propagate; + u8 rt_nhn; + /* rt_nhn_alive is accessed under RCU in the packet path; it + * is modified handling netdev events with rtnl lock held + */ + u8 rt_nhn_alive; + u8 rt_nh_size; + u8 rt_via_offset; + u8 rt_reserved1; + struct mpls_nh rt_nh[]; +}; + +#define for_nexthops(rt) { \ + int nhsel; const struct mpls_nh *nh; \ + for (nhsel = 0, nh = (rt)->rt_nh; \ + nhsel < (rt)->rt_nhn; \ + nh = (void *)nh + (rt)->rt_nh_size, nhsel++) + +#define change_nexthops(rt) { \ + int nhsel; struct mpls_nh *nh; \ + for (nhsel = 0, nh = (rt)->rt_nh; \ + nhsel < (rt)->rt_nhn; \ + nh = (void *)nh + (rt)->rt_nh_size, nhsel++) + +#define endfor_nexthops(rt) } + +static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr *hdr) +{ + struct mpls_entry_decoded result; + unsigned entry = be32_to_cpu(hdr->label_stack_entry); + + result.label = (entry & MPLS_LS_LABEL_MASK) >> MPLS_LS_LABEL_SHIFT; + result.ttl = (entry & MPLS_LS_TTL_MASK) >> MPLS_LS_TTL_SHIFT; + result.tc = (entry & MPLS_LS_TC_MASK) >> MPLS_LS_TC_SHIFT; + result.bos = (entry & MPLS_LS_S_MASK) >> MPLS_LS_S_SHIFT; + + return result; +} + +static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev) +{ + return rcu_dereference_rtnl(dev->mpls_ptr); +} + +int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels, + const u32 label[]); +int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels, + u32 label[], struct netlink_ext_ack *extack); +bool mpls_output_possible(const struct net_device *dev); +unsigned int mpls_dev_mtu(const struct net_device *dev); +bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu); +void mpls_stats_inc_outucastpkts(struct net_device *dev, + const struct sk_buff *skb); + +#endif /* MPLS_INTERNAL_H */ diff --git a/net/mpls/mpls_gso.c b/net/mpls/mpls_gso.c new file mode 100644 index 000000000..1482259de --- /dev/null +++ b/net/mpls/mpls_gso.c @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * MPLS GSO Support + * + * Authors: Simon Horman (horms@verge.net.au) + * + * Based on: GSO portions of net/ipv4/gre.c + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include + +static struct sk_buff *mpls_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + u16 mac_offset = skb->mac_header; + netdev_features_t mpls_features; + u16 mac_len = skb->mac_len; + __be16 mpls_protocol; + unsigned int mpls_hlen; + + skb_reset_network_header(skb); + mpls_hlen = skb_inner_network_header(skb) - skb_network_header(skb); + if (unlikely(!mpls_hlen || mpls_hlen % MPLS_HLEN)) + goto out; + if (unlikely(!pskb_may_pull(skb, mpls_hlen))) + goto out; + + /* Setup inner SKB. */ + mpls_protocol = skb->protocol; + skb->protocol = skb->inner_protocol; + + __skb_pull(skb, mpls_hlen); + + skb->mac_len = 0; + skb_reset_mac_header(skb); + + /* Segment inner packet. 
*/ + mpls_features = skb->dev->mpls_features & features; + segs = skb_mac_gso_segment(skb, mpls_features); + if (IS_ERR_OR_NULL(segs)) { + skb_gso_error_unwind(skb, mpls_protocol, mpls_hlen, mac_offset, + mac_len); + goto out; + } + skb = segs; + + mpls_hlen += mac_len; + do { + skb->mac_len = mac_len; + skb->protocol = mpls_protocol; + + skb_reset_inner_network_header(skb); + + __skb_push(skb, mpls_hlen); + + skb_reset_mac_header(skb); + skb_set_network_header(skb, mac_len); + } while ((skb = skb->next)); + +out: + return segs; +} + +static struct packet_offload mpls_mc_offload __read_mostly = { + .type = cpu_to_be16(ETH_P_MPLS_MC), + .priority = 15, + .callbacks = { + .gso_segment = mpls_gso_segment, + }, +}; + +static struct packet_offload mpls_uc_offload __read_mostly = { + .type = cpu_to_be16(ETH_P_MPLS_UC), + .priority = 15, + .callbacks = { + .gso_segment = mpls_gso_segment, + }, +}; + +static int __init mpls_gso_init(void) +{ + pr_info("MPLS GSO support\n"); + + dev_add_offload(&mpls_uc_offload); + dev_add_offload(&mpls_mc_offload); + + return 0; +} + +static void __exit mpls_gso_exit(void) +{ + dev_remove_offload(&mpls_uc_offload); + dev_remove_offload(&mpls_mc_offload); +} + +module_init(mpls_gso_init); +module_exit(mpls_gso_exit); + +MODULE_DESCRIPTION("MPLS GSO support"); +MODULE_AUTHOR("Simon Horman (horms@verge.net.au)"); +MODULE_LICENSE("GPL"); diff --git a/net/mpls/mpls_iptunnel.c b/net/mpls/mpls_iptunnel.c new file mode 100644 index 000000000..ef59e25dc --- /dev/null +++ b/net/mpls/mpls_iptunnel.c @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * mpls tunnels An implementation mpls tunnels using the light weight tunnel + * infrastructure + * + * Authors: Roopa Prabhu, + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "internal.h" + +static const struct nla_policy mpls_iptunnel_policy[MPLS_IPTUNNEL_MAX + 1] = { + [MPLS_IPTUNNEL_DST] = { .len = sizeof(u32) }, + [MPLS_IPTUNNEL_TTL] = { .type = NLA_U8 }, +}; + +static unsigned int mpls_encap_size(struct mpls_iptunnel_encap *en) +{ + /* The size of the layer 2.5 labels to be added for this route */ + return en->labels * sizeof(struct mpls_shim_hdr); +} + +static int mpls_xmit(struct sk_buff *skb) +{ + struct mpls_iptunnel_encap *tun_encap_info; + struct mpls_shim_hdr *hdr; + struct net_device *out_dev; + unsigned int hh_len; + unsigned int new_header_size; + unsigned int mtu; + struct dst_entry *dst = skb_dst(skb); + struct rtable *rt = NULL; + struct rt6_info *rt6 = NULL; + struct mpls_dev *out_mdev; + struct net *net; + int err = 0; + bool bos; + int i; + unsigned int ttl; + + /* Find the output device */ + out_dev = dst->dev; + net = dev_net(out_dev); + + skb_orphan(skb); + + if (!mpls_output_possible(out_dev) || + !dst->lwtstate || skb_warn_if_lro(skb)) + goto drop; + + skb_forward_csum(skb); + + tun_encap_info = mpls_lwtunnel_encap(dst->lwtstate); + + /* Obtain the ttl using the following set of rules. 
+ * + * LWT ttl propagation setting: + * - disabled => use default TTL value from LWT + * - enabled => use TTL value from IPv4/IPv6 header + * - default => + * Global ttl propagation setting: + * - disabled => use default TTL value from global setting + * - enabled => use TTL value from IPv4/IPv6 header + */ + if (dst->ops->family == AF_INET) { + if (tun_encap_info->ttl_propagate == MPLS_TTL_PROP_DISABLED) + ttl = tun_encap_info->default_ttl; + else if (tun_encap_info->ttl_propagate == MPLS_TTL_PROP_DEFAULT && + !net->mpls.ip_ttl_propagate) + ttl = net->mpls.default_ttl; + else + ttl = ip_hdr(skb)->ttl; + rt = (struct rtable *)dst; + } else if (dst->ops->family == AF_INET6) { + if (tun_encap_info->ttl_propagate == MPLS_TTL_PROP_DISABLED) + ttl = tun_encap_info->default_ttl; + else if (tun_encap_info->ttl_propagate == MPLS_TTL_PROP_DEFAULT && + !net->mpls.ip_ttl_propagate) + ttl = net->mpls.default_ttl; + else + ttl = ipv6_hdr(skb)->hop_limit; + rt6 = (struct rt6_info *)dst; + } else { + goto drop; + } + + /* Verify the destination can hold the packet */ + new_header_size = mpls_encap_size(tun_encap_info); + mtu = mpls_dev_mtu(out_dev); + if (mpls_pkt_too_big(skb, mtu - new_header_size)) + goto drop; + + hh_len = LL_RESERVED_SPACE(out_dev); + if (!out_dev->header_ops) + hh_len = 0; + + /* Ensure there is enough space for the headers in the skb */ + if (skb_cow(skb, hh_len + new_header_size)) + goto drop; + + skb_set_inner_protocol(skb, skb->protocol); + skb_reset_inner_network_header(skb); + + skb_push(skb, new_header_size); + + skb_reset_network_header(skb); + + skb->dev = out_dev; + skb->protocol = htons(ETH_P_MPLS_UC); + + /* Push the new labels */ + hdr = mpls_hdr(skb); + bos = true; + for (i = tun_encap_info->labels - 1; i >= 0; i--) { + hdr[i] = mpls_entry_encode(tun_encap_info->label[i], + ttl, 0, bos); + bos = false; + } + + mpls_stats_inc_outucastpkts(out_dev, skb); + + if (rt) { + if (rt->rt_gw_family == AF_INET6) + err = neigh_xmit(NEIGH_ND_TABLE, out_dev, &rt->rt_gw6, + skb); + else + err = neigh_xmit(NEIGH_ARP_TABLE, out_dev, &rt->rt_gw4, + skb); + } else if (rt6) { + if (ipv6_addr_v4mapped(&rt6->rt6i_gateway)) { + /* 6PE (RFC 4798) */ + err = neigh_xmit(NEIGH_ARP_TABLE, out_dev, &rt6->rt6i_gateway.s6_addr32[3], + skb); + } else + err = neigh_xmit(NEIGH_ND_TABLE, out_dev, &rt6->rt6i_gateway, + skb); + } + if (err) + net_dbg_ratelimited("%s: packet transmission failed: %d\n", + __func__, err); + + return LWTUNNEL_XMIT_DONE; + +drop: + out_mdev = out_dev ? 
mpls_dev_get(out_dev) : NULL; + if (out_mdev) + MPLS_INC_STATS(out_mdev, tx_errors); + kfree_skb(skb); + return -EINVAL; +} + +static int mpls_build_state(struct net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct mpls_iptunnel_encap *tun_encap_info; + struct nlattr *tb[MPLS_IPTUNNEL_MAX + 1]; + struct lwtunnel_state *newts; + u8 n_labels; + int ret; + + ret = nla_parse_nested_deprecated(tb, MPLS_IPTUNNEL_MAX, nla, + mpls_iptunnel_policy, extack); + if (ret < 0) + return ret; + + if (!tb[MPLS_IPTUNNEL_DST]) { + NL_SET_ERR_MSG(extack, "MPLS_IPTUNNEL_DST attribute is missing"); + return -EINVAL; + } + + /* determine number of labels */ + if (nla_get_labels(tb[MPLS_IPTUNNEL_DST], MAX_NEW_LABELS, + &n_labels, NULL, extack)) + return -EINVAL; + + newts = lwtunnel_state_alloc(struct_size(tun_encap_info, label, + n_labels)); + if (!newts) + return -ENOMEM; + + tun_encap_info = mpls_lwtunnel_encap(newts); + ret = nla_get_labels(tb[MPLS_IPTUNNEL_DST], n_labels, + &tun_encap_info->labels, tun_encap_info->label, + extack); + if (ret) + goto errout; + + tun_encap_info->ttl_propagate = MPLS_TTL_PROP_DEFAULT; + + if (tb[MPLS_IPTUNNEL_TTL]) { + tun_encap_info->default_ttl = nla_get_u8(tb[MPLS_IPTUNNEL_TTL]); + /* TTL 0 implies propagate from IP header */ + tun_encap_info->ttl_propagate = tun_encap_info->default_ttl ? + MPLS_TTL_PROP_DISABLED : + MPLS_TTL_PROP_ENABLED; + } + + newts->type = LWTUNNEL_ENCAP_MPLS; + newts->flags |= LWTUNNEL_STATE_XMIT_REDIRECT; + newts->headroom = mpls_encap_size(tun_encap_info); + + *ts = newts; + + return 0; + +errout: + kfree(newts); + *ts = NULL; + + return ret; +} + +static int mpls_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwtstate) +{ + struct mpls_iptunnel_encap *tun_encap_info; + + tun_encap_info = mpls_lwtunnel_encap(lwtstate); + + if (nla_put_labels(skb, MPLS_IPTUNNEL_DST, tun_encap_info->labels, + tun_encap_info->label)) + goto nla_put_failure; + + if (tun_encap_info->ttl_propagate != MPLS_TTL_PROP_DEFAULT && + nla_put_u8(skb, MPLS_IPTUNNEL_TTL, tun_encap_info->default_ttl)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int mpls_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + struct mpls_iptunnel_encap *tun_encap_info; + int nlsize; + + tun_encap_info = mpls_lwtunnel_encap(lwtstate); + + nlsize = nla_total_size(tun_encap_info->labels * 4); + + if (tun_encap_info->ttl_propagate != MPLS_TTL_PROP_DEFAULT) + nlsize += nla_total_size(1); + + return nlsize; +} + +static int mpls_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct mpls_iptunnel_encap *a_hdr = mpls_lwtunnel_encap(a); + struct mpls_iptunnel_encap *b_hdr = mpls_lwtunnel_encap(b); + int l; + + if (a_hdr->labels != b_hdr->labels || + a_hdr->ttl_propagate != b_hdr->ttl_propagate || + a_hdr->default_ttl != b_hdr->default_ttl) + return 1; + + for (l = 0; l < a_hdr->labels; l++) + if (a_hdr->label[l] != b_hdr->label[l]) + return 1; + return 0; +} + +static const struct lwtunnel_encap_ops mpls_iptun_ops = { + .build_state = mpls_build_state, + .xmit = mpls_xmit, + .fill_encap = mpls_fill_encap_info, + .get_encap_size = mpls_encap_nlsize, + .cmp_encap = mpls_encap_cmp, + .owner = THIS_MODULE, +}; + +static int __init mpls_iptunnel_init(void) +{ + return lwtunnel_encap_add_ops(&mpls_iptun_ops, LWTUNNEL_ENCAP_MPLS); +} +module_init(mpls_iptunnel_init); + +static void __exit mpls_iptunnel_exit(void) +{ + 
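+ /* Mirror of mpls_iptunnel_init(): dropping the lwtunnel encap ops
+ * keeps new MPLS-encapsulated routes (e.g. one installed with
+ * iproute2 as "ip route add 10.0.0.0/24 encap mpls 200 via 10.1.1.2",
+ * shown here only as an illustration) from being built against this
+ * module once it is unloaded.
+ */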
lwtunnel_encap_del_ops(&mpls_iptun_ops, LWTUNNEL_ENCAP_MPLS); +} +module_exit(mpls_iptunnel_exit); + +MODULE_ALIAS_RTNL_LWT(MPLS); +MODULE_SOFTDEP("post: mpls_gso"); +MODULE_DESCRIPTION("MultiProtocol Label Switching IP Tunnels"); +MODULE_LICENSE("GPL v2"); diff --git a/net/mptcp/Kconfig b/net/mptcp/Kconfig new file mode 100644 index 000000000..20328920f --- /dev/null +++ b/net/mptcp/Kconfig @@ -0,0 +1,39 @@ + +config MPTCP + bool "MPTCP: Multipath TCP" + depends on INET + select SKB_EXTENSIONS + select CRYPTO_LIB_SHA256 + select CRYPTO + help + Multipath TCP (MPTCP) connections send and receive data over multiple + subflows in order to utilize multiple network paths. Each subflow + uses the TCP protocol, and TCP options carry header information for + MPTCP. + +if MPTCP + +config INET_MPTCP_DIAG + depends on INET_DIAG + def_tristate INET_DIAG + +config MPTCP_IPV6 + bool "MPTCP: IPv6 support for Multipath TCP" + depends on IPV6=y + default y + +config MPTCP_KUNIT_TEST + tristate "This builds the MPTCP KUnit tests" if !KUNIT_ALL_TESTS + depends on KUNIT + default KUNIT_ALL_TESTS + help + Currently covers the MPTCP crypto and token helpers. + Only useful for kernel devs running KUnit test harness and are not + for inclusion into a production build. + + For more information on KUnit and unit tests in general please refer + to the KUnit documentation in Documentation/dev-tools/kunit/. + + If unsure, say N. + +endif diff --git a/net/mptcp/Makefile b/net/mptcp/Makefile new file mode 100644 index 000000000..6e7df47c9 --- /dev/null +++ b/net/mptcp/Makefile @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_MPTCP) += mptcp.o + +mptcp-y := protocol.o subflow.o options.o token.o crypto.o ctrl.o pm.o diag.o \ + mib.o pm_netlink.o sockopt.o pm_userspace.o + +obj-$(CONFIG_SYN_COOKIES) += syncookies.o +obj-$(CONFIG_INET_MPTCP_DIAG) += mptcp_diag.o + +mptcp_crypto_test-objs := crypto_test.o +mptcp_token_test-objs := token_test.o +obj-$(CONFIG_MPTCP_KUNIT_TEST) += mptcp_crypto_test.o mptcp_token_test.o + +obj-$(CONFIG_BPF_SYSCALL) += bpf.o diff --git a/net/mptcp/bpf.c b/net/mptcp/bpf.c new file mode 100644 index 000000000..5a0a84ad9 --- /dev/null +++ b/net/mptcp/bpf.c @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2020, Tessares SA. + * Copyright (c) 2022, SUSE. + * + * Author: Nicolas Rybowski + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include "protocol.h" + +struct mptcp_sock *bpf_mptcp_sock_from_subflow(struct sock *sk) +{ + if (sk && sk_fullsock(sk) && sk->sk_protocol == IPPROTO_TCP && sk_is_mptcp(sk)) + return mptcp_sk(mptcp_subflow_ctx(sk)->conn); + + return NULL; +} diff --git a/net/mptcp/crypto.c b/net/mptcp/crypto.c new file mode 100644 index 000000000..a89313499 --- /dev/null +++ b/net/mptcp/crypto.c @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP cryptographic functions + * Copyright (c) 2017 - 2019, Intel Corporation. 
+ * + * Note: This code is based on mptcp_ctrl.c, mptcp_ipv4.c, and + * mptcp_ipv6 from multipath-tcp.org, authored by: + * + * Sébastien Barré + * Christoph Paasch + * Jaakko Korkeaniemi + * Gregory Detal + * Fabien Duchêne + * Andreas Seelinger + * Lavkesh Lahngir + * Andreas Ripke + * Vlad Dogaru + * Octavian Purdila + * John Ronan + * Catalin Nicutar + * Brandon Heller + */ + +#include +#include +#include + +#include "protocol.h" + +#define SHA256_DIGEST_WORDS (SHA256_DIGEST_SIZE / 4) + +void mptcp_crypto_key_sha(u64 key, u32 *token, u64 *idsn) +{ + __be32 mptcp_hashed_key[SHA256_DIGEST_WORDS]; + __be64 input = cpu_to_be64(key); + + sha256((__force u8 *)&input, sizeof(input), (u8 *)mptcp_hashed_key); + + if (token) + *token = be32_to_cpu(mptcp_hashed_key[0]); + if (idsn) + *idsn = be64_to_cpu(*((__be64 *)&mptcp_hashed_key[6])); +} + +void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac) +{ + u8 input[SHA256_BLOCK_SIZE + SHA256_DIGEST_SIZE]; + u8 key1be[8]; + u8 key2be[8]; + int i; + + if (WARN_ON_ONCE(len > SHA256_DIGEST_SIZE)) + len = SHA256_DIGEST_SIZE; + + put_unaligned_be64(key1, key1be); + put_unaligned_be64(key2, key2be); + + /* Generate key xored with ipad */ + memset(input, 0x36, SHA256_BLOCK_SIZE); + for (i = 0; i < 8; i++) + input[i] ^= key1be[i]; + for (i = 0; i < 8; i++) + input[i + 8] ^= key2be[i]; + + memcpy(&input[SHA256_BLOCK_SIZE], msg, len); + + /* emit sha256(K1 || msg) on the second input block, so we can + * reuse 'input' for the last hashing + */ + sha256(input, SHA256_BLOCK_SIZE + len, &input[SHA256_BLOCK_SIZE]); + + /* Prepare second part of hmac */ + memset(input, 0x5C, SHA256_BLOCK_SIZE); + for (i = 0; i < 8; i++) + input[i] ^= key1be[i]; + for (i = 0; i < 8; i++) + input[i + 8] ^= key2be[i]; + + sha256(input, SHA256_BLOCK_SIZE + SHA256_DIGEST_SIZE, hmac); +} + +#if IS_MODULE(CONFIG_MPTCP_KUNIT_TEST) +EXPORT_SYMBOL_GPL(mptcp_crypto_hmac_sha); +#endif diff --git a/net/mptcp/crypto_test.c b/net/mptcp/crypto_test.c new file mode 100644 index 000000000..017248dea --- /dev/null +++ b/net/mptcp/crypto_test.c @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include "protocol.h" + +struct test_case { + char *key; + char *msg; + char *result; +}; + +/* we can't reuse RFC 4231 test vectors, as we have constraint on the + * input and key size. 
+ */ +static struct test_case tests[] = { + { + .key = "0b0b0b0b0b0b0b0b", + .msg = "48692054", + .result = "8385e24fb4235ac37556b6b886db106284a1da671699f46db1f235ec622dcafa", + }, + { + .key = "aaaaaaaaaaaaaaaa", + .msg = "dddddddd", + .result = "2c5e219164ff1dca1c4a92318d847bb6b9d44492984e1eb71aff9022f71046e9", + }, + { + .key = "0102030405060708", + .msg = "cdcdcdcd", + .result = "e73b9ba9969969cefb04aa0d6df18ec2fcc075b6f23b4d8c4da736a5dbbc6e7d", + }, +}; + +static void mptcp_crypto_test_basic(struct kunit *test) +{ + char hmac[32], hmac_hex[65]; + u32 nonce1, nonce2; + u64 key1, key2; + u8 msg[8]; + int i, j; + + for (i = 0; i < ARRAY_SIZE(tests); ++i) { + /* mptcp hmap will convert to be before computing the hmac */ + key1 = be64_to_cpu(*((__be64 *)&tests[i].key[0])); + key2 = be64_to_cpu(*((__be64 *)&tests[i].key[8])); + nonce1 = be32_to_cpu(*((__be32 *)&tests[i].msg[0])); + nonce2 = be32_to_cpu(*((__be32 *)&tests[i].msg[4])); + + put_unaligned_be32(nonce1, &msg[0]); + put_unaligned_be32(nonce2, &msg[4]); + + mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac); + for (j = 0; j < 32; ++j) + sprintf(&hmac_hex[j << 1], "%02x", hmac[j] & 0xff); + hmac_hex[64] = 0; + + KUNIT_EXPECT_STREQ(test, &hmac_hex[0], tests[i].result); + } +} + +static struct kunit_case mptcp_crypto_test_cases[] = { + KUNIT_CASE(mptcp_crypto_test_basic), + {} +}; + +static struct kunit_suite mptcp_crypto_suite = { + .name = "mptcp-crypto", + .test_cases = mptcp_crypto_test_cases, +}; + +kunit_test_suite(mptcp_crypto_suite); + +MODULE_LICENSE("GPL"); diff --git a/net/mptcp/ctrl.c b/net/mptcp/ctrl.c new file mode 100644 index 000000000..ae20b7d92 --- /dev/null +++ b/net/mptcp/ctrl.c @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2019, Tessares SA. 
+ */ + +#ifdef CONFIG_SYSCTL +#include +#endif + +#include +#include + +#include "protocol.h" + +#define MPTCP_SYSCTL_PATH "net/mptcp" + +static int mptcp_pernet_id; + +#ifdef CONFIG_SYSCTL +static int mptcp_pm_type_max = __MPTCP_PM_TYPE_MAX; +#endif + +struct mptcp_pernet { +#ifdef CONFIG_SYSCTL + struct ctl_table_header *ctl_table_hdr; +#endif + + unsigned int add_addr_timeout; + unsigned int stale_loss_cnt; + u8 mptcp_enabled; + u8 checksum_enabled; + u8 allow_join_initial_addr_port; + u8 pm_type; +}; + +static struct mptcp_pernet *mptcp_get_pernet(const struct net *net) +{ + return net_generic(net, mptcp_pernet_id); +} + +int mptcp_is_enabled(const struct net *net) +{ + return mptcp_get_pernet(net)->mptcp_enabled; +} + +unsigned int mptcp_get_add_addr_timeout(const struct net *net) +{ + return mptcp_get_pernet(net)->add_addr_timeout; +} + +int mptcp_is_checksum_enabled(const struct net *net) +{ + return mptcp_get_pernet(net)->checksum_enabled; +} + +int mptcp_allow_join_id0(const struct net *net) +{ + return mptcp_get_pernet(net)->allow_join_initial_addr_port; +} + +unsigned int mptcp_stale_loss_cnt(const struct net *net) +{ + return mptcp_get_pernet(net)->stale_loss_cnt; +} + +int mptcp_get_pm_type(const struct net *net) +{ + return mptcp_get_pernet(net)->pm_type; +} + +static void mptcp_pernet_set_defaults(struct mptcp_pernet *pernet) +{ + pernet->mptcp_enabled = 1; + pernet->add_addr_timeout = TCP_RTO_MAX; + pernet->checksum_enabled = 0; + pernet->allow_join_initial_addr_port = 1; + pernet->stale_loss_cnt = 4; + pernet->pm_type = MPTCP_PM_TYPE_KERNEL; +} + +#ifdef CONFIG_SYSCTL +static struct ctl_table mptcp_sysctl_table[] = { + { + .procname = "enabled", + .maxlen = sizeof(u8), + .mode = 0644, + /* users with CAP_NET_ADMIN or root (not and) can change this + * value, same as other sysctl or the 'net' tree. 
+ */ + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "add_addr_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "checksum_enabled", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "allow_join_initial_addr_port", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE + }, + { + .procname = "stale_loss_cnt", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + }, + { + .procname = "pm_type", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &mptcp_pm_type_max + }, + {} +}; + +static int mptcp_pernet_new_table(struct net *net, struct mptcp_pernet *pernet) +{ + struct ctl_table_header *hdr; + struct ctl_table *table; + + table = mptcp_sysctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(table, sizeof(mptcp_sysctl_table), GFP_KERNEL); + if (!table) + goto err_alloc; + } + + table[0].data = &pernet->mptcp_enabled; + table[1].data = &pernet->add_addr_timeout; + table[2].data = &pernet->checksum_enabled; + table[3].data = &pernet->allow_join_initial_addr_port; + table[4].data = &pernet->stale_loss_cnt; + table[5].data = &pernet->pm_type; + + hdr = register_net_sysctl(net, MPTCP_SYSCTL_PATH, table); + if (!hdr) + goto err_reg; + + pernet->ctl_table_hdr = hdr; + + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +static void mptcp_pernet_del_table(struct mptcp_pernet *pernet) +{ + struct ctl_table *table = pernet->ctl_table_hdr->ctl_table_arg; + + unregister_net_sysctl_table(pernet->ctl_table_hdr); + + kfree(table); +} + +#else + +static int mptcp_pernet_new_table(struct net *net, struct mptcp_pernet *pernet) +{ + return 0; +} + +static void mptcp_pernet_del_table(struct mptcp_pernet *pernet) {} + +#endif /* CONFIG_SYSCTL */ + +static int __net_init mptcp_net_init(struct net *net) +{ + struct mptcp_pernet *pernet = mptcp_get_pernet(net); + + mptcp_pernet_set_defaults(pernet); + + return mptcp_pernet_new_table(net, pernet); +} + +/* Note: the callback will only be called per extra netns */ +static void __net_exit mptcp_net_exit(struct net *net) +{ + struct mptcp_pernet *pernet = mptcp_get_pernet(net); + + mptcp_pernet_del_table(pernet); +} + +static struct pernet_operations mptcp_pernet_ops = { + .init = mptcp_net_init, + .exit = mptcp_net_exit, + .id = &mptcp_pernet_id, + .size = sizeof(struct mptcp_pernet), +}; + +void __init mptcp_init(void) +{ + mptcp_join_cookie_init(); + mptcp_proto_init(); + + if (register_pernet_subsys(&mptcp_pernet_ops) < 0) + panic("Failed to register MPTCP pernet subsystem.\n"); +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +int __init mptcpv6_init(void) +{ + int err; + + err = mptcp_proto_v6_init(); + + return err; +} +#endif diff --git a/net/mptcp/diag.c b/net/mptcp/diag.c new file mode 100644 index 000000000..a53658674 --- /dev/null +++ b/net/mptcp/diag.c @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: GPL-2.0 +/* MPTCP socket monitoring support + * + * Copyright (c) 2019 Red Hat + * + * Author: Davide Caratti + */ + +#include +#include +#include +#include +#include +#include "protocol.h" + +static int subflow_get_info(const struct sock *sk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *sf; + 
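+ /* Fills the INET_ULP_INFO_MPTCP nest consumed by inet_diag users
+ * such as ss: the subflow context is fetched under RCU because the
+ * ULP data can be cleared concurrently, and the nest is cancelled
+ * if the context is gone or the attributes do not fit.
+ */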
struct nlattr *start; + u32 flags = 0; + int err; + + start = nla_nest_start_noflag(skb, INET_ULP_INFO_MPTCP); + if (!start) + return -EMSGSIZE; + + rcu_read_lock(); + sf = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + if (!sf) { + err = 0; + goto nla_failure; + } + + if (sf->mp_capable) + flags |= MPTCP_SUBFLOW_FLAG_MCAP_REM; + if (sf->request_mptcp) + flags |= MPTCP_SUBFLOW_FLAG_MCAP_LOC; + if (sf->mp_join) + flags |= MPTCP_SUBFLOW_FLAG_JOIN_REM; + if (sf->request_join) + flags |= MPTCP_SUBFLOW_FLAG_JOIN_LOC; + if (sf->backup) + flags |= MPTCP_SUBFLOW_FLAG_BKUP_REM; + if (sf->request_bkup) + flags |= MPTCP_SUBFLOW_FLAG_BKUP_LOC; + if (sf->fully_established) + flags |= MPTCP_SUBFLOW_FLAG_FULLY_ESTABLISHED; + if (sf->conn_finished) + flags |= MPTCP_SUBFLOW_FLAG_CONNECTED; + if (sf->map_valid) + flags |= MPTCP_SUBFLOW_FLAG_MAPVALID; + + if (nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_TOKEN_REM, sf->remote_token) || + nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_TOKEN_LOC, sf->token) || + nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_RELWRITE_SEQ, + sf->rel_write_seq) || + nla_put_u64_64bit(skb, MPTCP_SUBFLOW_ATTR_MAP_SEQ, sf->map_seq, + MPTCP_SUBFLOW_ATTR_PAD) || + nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_MAP_SFSEQ, + sf->map_subflow_seq) || + nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_SSN_OFFSET, sf->ssn_offset) || + nla_put_u16(skb, MPTCP_SUBFLOW_ATTR_MAP_DATALEN, + sf->map_data_len) || + nla_put_u32(skb, MPTCP_SUBFLOW_ATTR_FLAGS, flags) || + nla_put_u8(skb, MPTCP_SUBFLOW_ATTR_ID_REM, sf->remote_id) || + nla_put_u8(skb, MPTCP_SUBFLOW_ATTR_ID_LOC, sf->local_id)) { + err = -EMSGSIZE; + goto nla_failure; + } + + rcu_read_unlock(); + nla_nest_end(skb, start); + return 0; + +nla_failure: + rcu_read_unlock(); + nla_nest_cancel(skb, start); + return err; +} + +static size_t subflow_get_info_size(const struct sock *sk) +{ + size_t size = 0; + + size += nla_total_size(0) + /* INET_ULP_INFO_MPTCP */ + nla_total_size(4) + /* MPTCP_SUBFLOW_ATTR_TOKEN_REM */ + nla_total_size(4) + /* MPTCP_SUBFLOW_ATTR_TOKEN_LOC */ + nla_total_size(4) + /* MPTCP_SUBFLOW_ATTR_RELWRITE_SEQ */ + nla_total_size_64bit(8) + /* MPTCP_SUBFLOW_ATTR_MAP_SEQ */ + nla_total_size(4) + /* MPTCP_SUBFLOW_ATTR_MAP_SFSEQ */ + nla_total_size(2) + /* MPTCP_SUBFLOW_ATTR_SSN_OFFSET */ + nla_total_size(2) + /* MPTCP_SUBFLOW_ATTR_MAP_DATALEN */ + nla_total_size(4) + /* MPTCP_SUBFLOW_ATTR_FLAGS */ + nla_total_size(1) + /* MPTCP_SUBFLOW_ATTR_ID_REM */ + nla_total_size(1) + /* MPTCP_SUBFLOW_ATTR_ID_LOC */ + 0; + return size; +} + +void mptcp_diag_subflow_init(struct tcp_ulp_ops *ops) +{ + ops->get_info = subflow_get_info; + ops->get_info_size = subflow_get_info_size; +} diff --git a/net/mptcp/mib.c b/net/mptcp/mib.c new file mode 100644 index 000000000..0dac2863c --- /dev/null +++ b/net/mptcp/mib.c @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: GPL-2.0-or-later + +#include +#include +#include +#include +#include + +#include "mib.h" + +static const struct snmp_mib mptcp_snmp_list[] = { + SNMP_MIB_ITEM("MPCapableSYNRX", MPTCP_MIB_MPCAPABLEPASSIVE), + SNMP_MIB_ITEM("MPCapableSYNTX", MPTCP_MIB_MPCAPABLEACTIVE), + SNMP_MIB_ITEM("MPCapableSYNACKRX", MPTCP_MIB_MPCAPABLEACTIVEACK), + SNMP_MIB_ITEM("MPCapableACKRX", MPTCP_MIB_MPCAPABLEPASSIVEACK), + SNMP_MIB_ITEM("MPCapableFallbackACK", MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK), + SNMP_MIB_ITEM("MPCapableFallbackSYNACK", MPTCP_MIB_MPCAPABLEACTIVEFALLBACK), + SNMP_MIB_ITEM("MPFallbackTokenInit", MPTCP_MIB_TOKENFALLBACKINIT), + SNMP_MIB_ITEM("MPTCPRetrans", MPTCP_MIB_RETRANSSEGS), + SNMP_MIB_ITEM("MPJoinNoTokenFound", MPTCP_MIB_JOINNOTOKEN), + 
SNMP_MIB_ITEM("MPJoinSynRx", MPTCP_MIB_JOINSYNRX), + SNMP_MIB_ITEM("MPJoinSynAckRx", MPTCP_MIB_JOINSYNACKRX), + SNMP_MIB_ITEM("MPJoinSynAckHMacFailure", MPTCP_MIB_JOINSYNACKMAC), + SNMP_MIB_ITEM("MPJoinAckRx", MPTCP_MIB_JOINACKRX), + SNMP_MIB_ITEM("MPJoinAckHMacFailure", MPTCP_MIB_JOINACKMAC), + SNMP_MIB_ITEM("DSSNotMatching", MPTCP_MIB_DSSNOMATCH), + SNMP_MIB_ITEM("InfiniteMapTx", MPTCP_MIB_INFINITEMAPTX), + SNMP_MIB_ITEM("InfiniteMapRx", MPTCP_MIB_INFINITEMAPRX), + SNMP_MIB_ITEM("DSSNoMatchTCP", MPTCP_MIB_DSSTCPMISMATCH), + SNMP_MIB_ITEM("DataCsumErr", MPTCP_MIB_DATACSUMERR), + SNMP_MIB_ITEM("OFOQueueTail", MPTCP_MIB_OFOQUEUETAIL), + SNMP_MIB_ITEM("OFOQueue", MPTCP_MIB_OFOQUEUE), + SNMP_MIB_ITEM("OFOMerge", MPTCP_MIB_OFOMERGE), + SNMP_MIB_ITEM("NoDSSInWindow", MPTCP_MIB_NODSSWINDOW), + SNMP_MIB_ITEM("DuplicateData", MPTCP_MIB_DUPDATA), + SNMP_MIB_ITEM("AddAddr", MPTCP_MIB_ADDADDR), + SNMP_MIB_ITEM("EchoAdd", MPTCP_MIB_ECHOADD), + SNMP_MIB_ITEM("PortAdd", MPTCP_MIB_PORTADD), + SNMP_MIB_ITEM("AddAddrDrop", MPTCP_MIB_ADDADDRDROP), + SNMP_MIB_ITEM("MPJoinPortSynRx", MPTCP_MIB_JOINPORTSYNRX), + SNMP_MIB_ITEM("MPJoinPortSynAckRx", MPTCP_MIB_JOINPORTSYNACKRX), + SNMP_MIB_ITEM("MPJoinPortAckRx", MPTCP_MIB_JOINPORTACKRX), + SNMP_MIB_ITEM("MismatchPortSynRx", MPTCP_MIB_MISMATCHPORTSYNRX), + SNMP_MIB_ITEM("MismatchPortAckRx", MPTCP_MIB_MISMATCHPORTACKRX), + SNMP_MIB_ITEM("RmAddr", MPTCP_MIB_RMADDR), + SNMP_MIB_ITEM("RmAddrDrop", MPTCP_MIB_RMADDRDROP), + SNMP_MIB_ITEM("RmSubflow", MPTCP_MIB_RMSUBFLOW), + SNMP_MIB_ITEM("MPPrioTx", MPTCP_MIB_MPPRIOTX), + SNMP_MIB_ITEM("MPPrioRx", MPTCP_MIB_MPPRIORX), + SNMP_MIB_ITEM("MPFailTx", MPTCP_MIB_MPFAILTX), + SNMP_MIB_ITEM("MPFailRx", MPTCP_MIB_MPFAILRX), + SNMP_MIB_ITEM("MPFastcloseTx", MPTCP_MIB_MPFASTCLOSETX), + SNMP_MIB_ITEM("MPFastcloseRx", MPTCP_MIB_MPFASTCLOSERX), + SNMP_MIB_ITEM("MPRstTx", MPTCP_MIB_MPRSTTX), + SNMP_MIB_ITEM("MPRstRx", MPTCP_MIB_MPRSTRX), + SNMP_MIB_ITEM("RcvPruned", MPTCP_MIB_RCVPRUNED), + SNMP_MIB_ITEM("SubflowStale", MPTCP_MIB_SUBFLOWSTALE), + SNMP_MIB_ITEM("SubflowRecover", MPTCP_MIB_SUBFLOWRECOVER), + SNMP_MIB_ITEM("SndWndShared", MPTCP_MIB_SNDWNDSHARED), + SNMP_MIB_ITEM("RcvWndShared", MPTCP_MIB_RCVWNDSHARED), + SNMP_MIB_ITEM("RcvWndConflictUpdate", MPTCP_MIB_RCVWNDCONFLICTUPDATE), + SNMP_MIB_ITEM("RcvWndConflict", MPTCP_MIB_RCVWNDCONFLICT), + SNMP_MIB_SENTINEL +}; + +/* mptcp_mib_alloc - allocate percpu mib counters + * + * These are allocated when the first mptcp socket is created so + * we do not waste percpu memory if mptcp isn't in use. 
+ */ +bool mptcp_mib_alloc(struct net *net) +{ + struct mptcp_mib __percpu *mib = alloc_percpu(struct mptcp_mib); + + if (!mib) + return false; + + if (cmpxchg(&net->mib.mptcp_statistics, NULL, mib)) + free_percpu(mib); + + return true; +} + +void mptcp_seq_show(struct seq_file *seq) +{ + unsigned long sum[ARRAY_SIZE(mptcp_snmp_list) - 1]; + struct net *net = seq->private; + int i; + + seq_puts(seq, "MPTcpExt:"); + for (i = 0; mptcp_snmp_list[i].name; i++) + seq_printf(seq, " %s", mptcp_snmp_list[i].name); + + seq_puts(seq, "\nMPTcpExt:"); + + memset(sum, 0, sizeof(sum)); + if (net->mib.mptcp_statistics) + snmp_get_cpu_field_batch(sum, mptcp_snmp_list, + net->mib.mptcp_statistics); + + for (i = 0; mptcp_snmp_list[i].name; i++) + seq_printf(seq, " %lu", sum[i]); + + seq_putc(seq, '\n'); +} diff --git a/net/mptcp/mib.h b/net/mptcp/mib.h new file mode 100644 index 000000000..2be359637 --- /dev/null +++ b/net/mptcp/mib.h @@ -0,0 +1,80 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ + +enum linux_mptcp_mib_field { + MPTCP_MIB_NUM = 0, + MPTCP_MIB_MPCAPABLEPASSIVE, /* Received SYN with MP_CAPABLE */ + MPTCP_MIB_MPCAPABLEACTIVE, /* Sent SYN with MP_CAPABLE */ + MPTCP_MIB_MPCAPABLEACTIVEACK, /* Received SYN/ACK with MP_CAPABLE */ + MPTCP_MIB_MPCAPABLEPASSIVEACK, /* Received third ACK with MP_CAPABLE */ + MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK,/* Server-side fallback during 3-way handshake */ + MPTCP_MIB_MPCAPABLEACTIVEFALLBACK, /* Client-side fallback during 3-way handshake */ + MPTCP_MIB_TOKENFALLBACKINIT, /* Could not init/allocate token */ + MPTCP_MIB_RETRANSSEGS, /* Segments retransmitted at the MPTCP-level */ + MPTCP_MIB_JOINNOTOKEN, /* Received MP_JOIN but the token was not found */ + MPTCP_MIB_JOINSYNRX, /* Received a SYN + MP_JOIN */ + MPTCP_MIB_JOINSYNACKRX, /* Received a SYN/ACK + MP_JOIN */ + MPTCP_MIB_JOINSYNACKMAC, /* HMAC was wrong on SYN/ACK + MP_JOIN */ + MPTCP_MIB_JOINACKRX, /* Received an ACK + MP_JOIN */ + MPTCP_MIB_JOINACKMAC, /* HMAC was wrong on ACK + MP_JOIN */ + MPTCP_MIB_DSSNOMATCH, /* Received a new mapping that did not match the previous one */ + MPTCP_MIB_INFINITEMAPTX, /* Sent an infinite mapping */ + MPTCP_MIB_INFINITEMAPRX, /* Received an infinite mapping */ + MPTCP_MIB_DSSTCPMISMATCH, /* DSS-mapping did not map with TCP's sequence numbers */ + MPTCP_MIB_DATACSUMERR, /* The data checksum fail */ + MPTCP_MIB_OFOQUEUETAIL, /* Segments inserted into OoO queue tail */ + MPTCP_MIB_OFOQUEUE, /* Segments inserted into OoO queue */ + MPTCP_MIB_OFOMERGE, /* Segments merged in OoO queue */ + MPTCP_MIB_NODSSWINDOW, /* Segments not in MPTCP windows */ + MPTCP_MIB_DUPDATA, /* Segments discarded due to duplicate DSS */ + MPTCP_MIB_ADDADDR, /* Received ADD_ADDR with echo-flag=0 */ + MPTCP_MIB_ECHOADD, /* Received ADD_ADDR with echo-flag=1 */ + MPTCP_MIB_PORTADD, /* Received ADD_ADDR with a port-number */ + MPTCP_MIB_ADDADDRDROP, /* Dropped incoming ADD_ADDR */ + MPTCP_MIB_JOINPORTSYNRX, /* Received a SYN MP_JOIN with a different port-number */ + MPTCP_MIB_JOINPORTSYNACKRX, /* Received a SYNACK MP_JOIN with a different port-number */ + MPTCP_MIB_JOINPORTACKRX, /* Received an ACK MP_JOIN with a different port-number */ + MPTCP_MIB_MISMATCHPORTSYNRX, /* Received a SYN MP_JOIN with a mismatched port-number */ + MPTCP_MIB_MISMATCHPORTACKRX, /* Received an ACK MP_JOIN with a mismatched port-number */ + MPTCP_MIB_RMADDR, /* Received RM_ADDR */ + MPTCP_MIB_RMADDRDROP, /* Dropped incoming RM_ADDR */ + MPTCP_MIB_RMSUBFLOW, /* Remove a subflow */ + MPTCP_MIB_MPPRIOTX, /* Transmit a MP_PRIO */ 
+ MPTCP_MIB_MPPRIORX, /* Received a MP_PRIO */ + MPTCP_MIB_MPFAILTX, /* Transmit a MP_FAIL */ + MPTCP_MIB_MPFAILRX, /* Received a MP_FAIL */ + MPTCP_MIB_MPFASTCLOSETX, /* Transmit a MP_FASTCLOSE */ + MPTCP_MIB_MPFASTCLOSERX, /* Received a MP_FASTCLOSE */ + MPTCP_MIB_MPRSTTX, /* Transmit a MP_RST */ + MPTCP_MIB_MPRSTRX, /* Received a MP_RST */ + MPTCP_MIB_RCVPRUNED, /* Incoming packet dropped due to memory limit */ + MPTCP_MIB_SUBFLOWSTALE, /* Subflows entered 'stale' status */ + MPTCP_MIB_SUBFLOWRECOVER, /* Subflows returned to active status after being stale */ + MPTCP_MIB_SNDWNDSHARED, /* Subflow snd wnd is overridden by msk's one */ + MPTCP_MIB_RCVWNDSHARED, /* Subflow rcv wnd is overridden by msk's one */ + MPTCP_MIB_RCVWNDCONFLICTUPDATE, /* subflow rcv wnd is overridden by msk's one due to + * conflict with another subflow while updating msk rcv wnd + */ + MPTCP_MIB_RCVWNDCONFLICT, /* Conflict with while updating msk rcv wnd */ + __MPTCP_MIB_MAX +}; + +#define LINUX_MIB_MPTCP_MAX __MPTCP_MIB_MAX +struct mptcp_mib { + unsigned long mibs[LINUX_MIB_MPTCP_MAX]; +}; + +static inline void MPTCP_INC_STATS(struct net *net, + enum linux_mptcp_mib_field field) +{ + if (likely(net->mib.mptcp_statistics)) + SNMP_INC_STATS(net->mib.mptcp_statistics, field); +} + +static inline void __MPTCP_INC_STATS(struct net *net, + enum linux_mptcp_mib_field field) +{ + if (likely(net->mib.mptcp_statistics)) + __SNMP_INC_STATS(net->mib.mptcp_statistics, field); +} + +bool mptcp_mib_alloc(struct net *net); diff --git a/net/mptcp/mptcp_diag.c b/net/mptcp/mptcp_diag.c new file mode 100644 index 000000000..8df1bdb64 --- /dev/null +++ b/net/mptcp/mptcp_diag.c @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: GPL-2.0 +/* MPTCP socket monitoring support + * + * Copyright (c) 2020 Red Hat + * + * Author: Paolo Abeni + */ + +#include +#include +#include +#include +#include +#include "protocol.h" + +static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct inet_diag_req_v2 *req, + struct nlattr *bc, bool net_admin) +{ + if (!inet_diag_bc_sk(bc, sk)) + return 0; + + return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, req, NLM_F_MULTI, + net_admin); +} + +static int mptcp_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + struct sk_buff *in_skb = cb->skb; + struct mptcp_sock *msk = NULL; + struct sk_buff *rep; + int err = -ENOENT; + struct net *net; + struct sock *sk; + + net = sock_net(in_skb->sk); + msk = mptcp_token_get_sock(net, req->id.idiag_cookie[0]); + if (!msk) + goto out_nosk; + + err = -ENOMEM; + sk = (struct sock *)msk; + rep = nlmsg_new(nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct mptcp_info)) + + nla_total_size(sizeof(struct inet_diag_meminfo)) + 64, + GFP_KERNEL); + if (!rep) + goto out; + + err = inet_sk_diag_fill(sk, inet_csk(sk), rep, cb, req, 0, + netlink_net_capable(in_skb, CAP_NET_ADMIN)); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + kfree_skb(rep); + goto out; + } + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + +out: + sock_put(sk); + +out_nosk: + return err; +} + +struct mptcp_diag_ctx { + long s_slot; + long s_num; + unsigned int l_slot; + unsigned int l_num; +}; + +static void mptcp_diag_dump_listeners(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r, + bool net_admin) +{ + struct inet_diag_dump_data *cb_data = cb->data; + struct mptcp_diag_ctx *diag_ctx = (void *)cb->ctx; + struct nlattr *bc 
= cb_data->inet_diag_nla_bc; + struct net *net = sock_net(skb->sk); + struct inet_hashinfo *hinfo; + int i; + + hinfo = net->ipv4.tcp_death_row.hashinfo; + + for (i = diag_ctx->l_slot; i <= hinfo->lhash2_mask; i++) { + struct inet_listen_hashbucket *ilb; + struct hlist_nulls_node *node; + struct sock *sk; + int num = 0; + + ilb = &hinfo->lhash2[i]; + + rcu_read_lock(); + spin_lock(&ilb->lock); + sk_nulls_for_each(sk, node, &ilb->nulls_head) { + const struct mptcp_subflow_context *ctx = mptcp_subflow_ctx(sk); + struct inet_sock *inet = inet_sk(sk); + int ret; + + if (num < diag_ctx->l_num) + goto next_listen; + + if (!ctx || strcmp(inet_csk(sk)->icsk_ulp_ops->name, "mptcp")) + goto next_listen; + + sk = ctx->conn; + if (!sk || !net_eq(sock_net(sk), net)) + goto next_listen; + + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next_listen; + + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next_listen; + + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + goto next_listen; + + ret = sk_diag_dump(sk, skb, cb, r, bc, net_admin); + + sock_put(sk); + + if (ret < 0) { + spin_unlock(&ilb->lock); + rcu_read_unlock(); + diag_ctx->l_slot = i; + diag_ctx->l_num = num; + return; + } + diag_ctx->l_num = num + 1; + num = 0; +next_listen: + ++num; + } + spin_unlock(&ilb->lock); + rcu_read_unlock(); + + cond_resched(); + diag_ctx->l_num = 0; + } + + diag_ctx->l_num = 0; + diag_ctx->l_slot = i; +} + +static void mptcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + struct mptcp_diag_ctx *diag_ctx = (void *)cb->ctx; + struct net *net = sock_net(skb->sk); + struct inet_diag_dump_data *cb_data; + struct mptcp_sock *msk; + struct nlattr *bc; + + BUILD_BUG_ON(sizeof(cb->ctx) < sizeof(*diag_ctx)); + + cb_data = cb->data; + bc = cb_data->inet_diag_nla_bc; + + while ((msk = mptcp_token_iter_next(net, &diag_ctx->s_slot, + &diag_ctx->s_num)) != NULL) { + struct inet_sock *inet = (struct inet_sock *)msk; + struct sock *sk = (struct sock *)msk; + int ret = 0; + + if (!(r->idiag_states & (1 << sk->sk_state))) + goto next; + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next; + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next; + if (r->id.idiag_dport != inet->inet_dport && + r->id.idiag_dport) + goto next; + + ret = sk_diag_dump(sk, skb, cb, r, bc, net_admin); +next: + sock_put(sk); + if (ret < 0) { + /* will retry on the same position */ + diag_ctx->s_num--; + break; + } + cond_resched(); + } + + if ((r->idiag_states & TCPF_LISTEN) && r->id.idiag_dport == 0) + mptcp_diag_dump_listeners(skb, cb, r, net_admin); +} + +static void mptcp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *_info) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_info *info = _info; + + r->idiag_rqueue = sk_rmem_alloc_get(sk); + r->idiag_wqueue = sk_wmem_alloc_get(sk); + + if (inet_sk_state_load(sk) == TCP_LISTEN) { + struct sock *lsk = READ_ONCE(msk->first); + + if (lsk) { + /* override with settings from tcp listener, + * so Send-Q will show accept queue. 
+ */ + r->idiag_rqueue = READ_ONCE(lsk->sk_ack_backlog); + r->idiag_wqueue = READ_ONCE(lsk->sk_max_ack_backlog); + } + } + + if (!info) + return; + + mptcp_diag_fill_info(msk, info); +} + +static const struct inet_diag_handler mptcp_diag_handler = { + .dump = mptcp_diag_dump, + .dump_one = mptcp_diag_dump_one, + .idiag_get_info = mptcp_diag_get_info, + .idiag_type = IPPROTO_MPTCP, + .idiag_info_size = sizeof(struct mptcp_info), +}; + +static int __init mptcp_diag_init(void) +{ + return inet_diag_register(&mptcp_diag_handler); +} + +static void __exit mptcp_diag_exit(void) +{ + inet_diag_unregister(&mptcp_diag_handler); +} + +module_init(mptcp_diag_init); +module_exit(mptcp_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-262 /* AF_INET - IPPROTO_MPTCP */); diff --git a/net/mptcp/options.c b/net/mptcp/options.c new file mode 100644 index 000000000..a718ebcb5 --- /dev/null +++ b/net/mptcp/options.c @@ -0,0 +1,1633 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2017 - 2019, Intel Corporation. + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include "protocol.h" +#include "mib.h" + +#include + +static bool mptcp_cap_flag_sha256(u8 flags) +{ + return (flags & MPTCP_CAP_FLAG_MASK) == MPTCP_CAP_HMAC_SHA256; +} + +static void mptcp_parse_option(const struct sk_buff *skb, + const unsigned char *ptr, int opsize, + struct mptcp_options_received *mp_opt) +{ + u8 subtype = *ptr >> 4; + int expected_opsize; + u8 version; + u8 flags; + u8 i; + + switch (subtype) { + case MPTCPOPT_MP_CAPABLE: + /* strict size checking */ + if (!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_SYN)) { + if (skb->len > tcp_hdr(skb)->doff << 2) + expected_opsize = TCPOLEN_MPTCP_MPC_ACK_DATA; + else + expected_opsize = TCPOLEN_MPTCP_MPC_ACK; + } else { + if (TCP_SKB_CB(skb)->tcp_flags & TCPHDR_ACK) + expected_opsize = TCPOLEN_MPTCP_MPC_SYNACK; + else + expected_opsize = TCPOLEN_MPTCP_MPC_SYN; + } + + /* Cfr RFC 8684 Section 3.3.0: + * If a checksum is present but its use had + * not been negotiated in the MP_CAPABLE handshake, the receiver MUST + * close the subflow with a RST, as it is not behaving as negotiated. + * If a checksum is not present when its use has been negotiated, the + * receiver MUST close the subflow with a RST, as it is considered + * broken + * We parse even option with mismatching csum presence, so that + * later in subflow_data_ready we can trigger the reset. + */ + if (opsize != expected_opsize && + (expected_opsize != TCPOLEN_MPTCP_MPC_ACK_DATA || + opsize != TCPOLEN_MPTCP_MPC_ACK_DATA_CSUM)) + break; + + /* try to be gentle vs future versions on the initial syn */ + version = *ptr++ & MPTCP_VERSION_MASK; + if (opsize != TCPOLEN_MPTCP_MPC_SYN) { + if (version != MPTCP_SUPPORTED_VERSION) + break; + } else if (version < MPTCP_SUPPORTED_VERSION) { + break; + } + + flags = *ptr++; + if (!mptcp_cap_flag_sha256(flags) || + (flags & MPTCP_CAP_EXTENSIBILITY)) + break; + + /* RFC 6824, Section 3.1: + * "For the Checksum Required bit (labeled "A"), if either + * host requires the use of checksums, checksums MUST be used. + * In other words, the only way for checksums not to be used + * is if both hosts in their SYNs set A=0." 
+ */ + if (flags & MPTCP_CAP_CHECKSUM_REQD) + mp_opt->suboptions |= OPTION_MPTCP_CSUMREQD; + + mp_opt->deny_join_id0 = !!(flags & MPTCP_CAP_DENY_JOIN_ID0); + + mp_opt->suboptions |= OPTIONS_MPTCP_MPC; + if (opsize >= TCPOLEN_MPTCP_MPC_SYNACK) { + mp_opt->sndr_key = get_unaligned_be64(ptr); + ptr += 8; + } + if (opsize >= TCPOLEN_MPTCP_MPC_ACK) { + mp_opt->rcvr_key = get_unaligned_be64(ptr); + ptr += 8; + } + if (opsize >= TCPOLEN_MPTCP_MPC_ACK_DATA) { + /* Section 3.1.: + * "the data parameters in a MP_CAPABLE are semantically + * equivalent to those in a DSS option and can be used + * interchangeably." + */ + mp_opt->suboptions |= OPTION_MPTCP_DSS; + mp_opt->use_map = 1; + mp_opt->mpc_map = 1; + mp_opt->use_ack = 0; + mp_opt->data_len = get_unaligned_be16(ptr); + ptr += 2; + } + if (opsize == TCPOLEN_MPTCP_MPC_ACK_DATA_CSUM) { + mp_opt->csum = get_unaligned((__force __sum16 *)ptr); + mp_opt->suboptions |= OPTION_MPTCP_CSUMREQD; + ptr += 2; + } + pr_debug("MP_CAPABLE version=%x, flags=%x, optlen=%d sndr=%llu, rcvr=%llu len=%d csum=%u", + version, flags, opsize, mp_opt->sndr_key, + mp_opt->rcvr_key, mp_opt->data_len, mp_opt->csum); + break; + + case MPTCPOPT_MP_JOIN: + if (opsize == TCPOLEN_MPTCP_MPJ_SYN) { + mp_opt->suboptions |= OPTION_MPTCP_MPJ_SYN; + mp_opt->backup = *ptr++ & MPTCPOPT_BACKUP; + mp_opt->join_id = *ptr++; + mp_opt->token = get_unaligned_be32(ptr); + ptr += 4; + mp_opt->nonce = get_unaligned_be32(ptr); + ptr += 4; + pr_debug("MP_JOIN bkup=%u, id=%u, token=%u, nonce=%u", + mp_opt->backup, mp_opt->join_id, + mp_opt->token, mp_opt->nonce); + } else if (opsize == TCPOLEN_MPTCP_MPJ_SYNACK) { + mp_opt->suboptions |= OPTION_MPTCP_MPJ_SYNACK; + mp_opt->backup = *ptr++ & MPTCPOPT_BACKUP; + mp_opt->join_id = *ptr++; + mp_opt->thmac = get_unaligned_be64(ptr); + ptr += 8; + mp_opt->nonce = get_unaligned_be32(ptr); + ptr += 4; + pr_debug("MP_JOIN bkup=%u, id=%u, thmac=%llu, nonce=%u", + mp_opt->backup, mp_opt->join_id, + mp_opt->thmac, mp_opt->nonce); + } else if (opsize == TCPOLEN_MPTCP_MPJ_ACK) { + mp_opt->suboptions |= OPTION_MPTCP_MPJ_ACK; + ptr += 2; + memcpy(mp_opt->hmac, ptr, MPTCPOPT_HMAC_LEN); + pr_debug("MP_JOIN hmac"); + } + break; + + case MPTCPOPT_DSS: + pr_debug("DSS"); + ptr++; + + /* we must clear 'mpc_map' be able to detect MP_CAPABLE + * map vs DSS map in mptcp_incoming_options(), and reconstruct + * map info accordingly + */ + mp_opt->mpc_map = 0; + flags = (*ptr++) & MPTCP_DSS_FLAG_MASK; + mp_opt->data_fin = (flags & MPTCP_DSS_DATA_FIN) != 0; + mp_opt->dsn64 = (flags & MPTCP_DSS_DSN64) != 0; + mp_opt->use_map = (flags & MPTCP_DSS_HAS_MAP) != 0; + mp_opt->ack64 = (flags & MPTCP_DSS_ACK64) != 0; + mp_opt->use_ack = (flags & MPTCP_DSS_HAS_ACK); + + pr_debug("data_fin=%d dsn64=%d use_map=%d ack64=%d use_ack=%d", + mp_opt->data_fin, mp_opt->dsn64, + mp_opt->use_map, mp_opt->ack64, + mp_opt->use_ack); + + expected_opsize = TCPOLEN_MPTCP_DSS_BASE; + + if (mp_opt->use_ack) { + if (mp_opt->ack64) + expected_opsize += TCPOLEN_MPTCP_DSS_ACK64; + else + expected_opsize += TCPOLEN_MPTCP_DSS_ACK32; + } + + if (mp_opt->use_map) { + if (mp_opt->dsn64) + expected_opsize += TCPOLEN_MPTCP_DSS_MAP64; + else + expected_opsize += TCPOLEN_MPTCP_DSS_MAP32; + } + + /* Always parse any csum presence combination, we will enforce + * RFC 8684 Section 3.3.0 checks later in subflow_data_ready + */ + if (opsize != expected_opsize && + opsize != expected_opsize + TCPOLEN_MPTCP_DSS_CHECKSUM) + break; + + mp_opt->suboptions |= OPTION_MPTCP_DSS; + if (mp_opt->use_ack) { + if (mp_opt->ack64) { + 
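+ /* 64-bit data-level ACK: read with an unaligned load, as TCP
+ * options are not naturally aligned; the 32-bit form in the else
+ * branch is zero-extended into the same data_ack field.
+ */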
mp_opt->data_ack = get_unaligned_be64(ptr); + ptr += 8; + } else { + mp_opt->data_ack = get_unaligned_be32(ptr); + ptr += 4; + } + + pr_debug("data_ack=%llu", mp_opt->data_ack); + } + + if (mp_opt->use_map) { + if (mp_opt->dsn64) { + mp_opt->data_seq = get_unaligned_be64(ptr); + ptr += 8; + } else { + mp_opt->data_seq = get_unaligned_be32(ptr); + ptr += 4; + } + + mp_opt->subflow_seq = get_unaligned_be32(ptr); + ptr += 4; + + mp_opt->data_len = get_unaligned_be16(ptr); + ptr += 2; + + if (opsize == expected_opsize + TCPOLEN_MPTCP_DSS_CHECKSUM) { + mp_opt->suboptions |= OPTION_MPTCP_CSUMREQD; + mp_opt->csum = get_unaligned((__force __sum16 *)ptr); + ptr += 2; + } + + pr_debug("data_seq=%llu subflow_seq=%u data_len=%u csum=%d:%u", + mp_opt->data_seq, mp_opt->subflow_seq, + mp_opt->data_len, !!(mp_opt->suboptions & OPTION_MPTCP_CSUMREQD), + mp_opt->csum); + } + + break; + + case MPTCPOPT_ADD_ADDR: + mp_opt->echo = (*ptr++) & MPTCP_ADDR_ECHO; + if (!mp_opt->echo) { + if (opsize == TCPOLEN_MPTCP_ADD_ADDR || + opsize == TCPOLEN_MPTCP_ADD_ADDR_PORT) + mp_opt->addr.family = AF_INET; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (opsize == TCPOLEN_MPTCP_ADD_ADDR6 || + opsize == TCPOLEN_MPTCP_ADD_ADDR6_PORT) + mp_opt->addr.family = AF_INET6; +#endif + else + break; + } else { + if (opsize == TCPOLEN_MPTCP_ADD_ADDR_BASE || + opsize == TCPOLEN_MPTCP_ADD_ADDR_BASE_PORT) + mp_opt->addr.family = AF_INET; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (opsize == TCPOLEN_MPTCP_ADD_ADDR6_BASE || + opsize == TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT) + mp_opt->addr.family = AF_INET6; +#endif + else + break; + } + + mp_opt->suboptions |= OPTION_MPTCP_ADD_ADDR; + mp_opt->addr.id = *ptr++; + mp_opt->addr.port = 0; + mp_opt->ahmac = 0; + if (mp_opt->addr.family == AF_INET) { + memcpy((u8 *)&mp_opt->addr.addr.s_addr, (u8 *)ptr, 4); + ptr += 4; + if (opsize == TCPOLEN_MPTCP_ADD_ADDR_PORT || + opsize == TCPOLEN_MPTCP_ADD_ADDR_BASE_PORT) { + mp_opt->addr.port = htons(get_unaligned_be16(ptr)); + ptr += 2; + } + } +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else { + memcpy(mp_opt->addr.addr6.s6_addr, (u8 *)ptr, 16); + ptr += 16; + if (opsize == TCPOLEN_MPTCP_ADD_ADDR6_PORT || + opsize == TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT) { + mp_opt->addr.port = htons(get_unaligned_be16(ptr)); + ptr += 2; + } + } +#endif + if (!mp_opt->echo) { + mp_opt->ahmac = get_unaligned_be64(ptr); + ptr += 8; + } + pr_debug("ADD_ADDR%s: id=%d, ahmac=%llu, echo=%d, port=%d", + (mp_opt->addr.family == AF_INET6) ? 
"6" : "", + mp_opt->addr.id, mp_opt->ahmac, mp_opt->echo, ntohs(mp_opt->addr.port)); + break; + + case MPTCPOPT_RM_ADDR: + if (opsize < TCPOLEN_MPTCP_RM_ADDR_BASE + 1 || + opsize > TCPOLEN_MPTCP_RM_ADDR_BASE + MPTCP_RM_IDS_MAX) + break; + + ptr++; + + mp_opt->suboptions |= OPTION_MPTCP_RM_ADDR; + mp_opt->rm_list.nr = opsize - TCPOLEN_MPTCP_RM_ADDR_BASE; + for (i = 0; i < mp_opt->rm_list.nr; i++) + mp_opt->rm_list.ids[i] = *ptr++; + pr_debug("RM_ADDR: rm_list_nr=%d", mp_opt->rm_list.nr); + break; + + case MPTCPOPT_MP_PRIO: + if (opsize != TCPOLEN_MPTCP_PRIO) + break; + + mp_opt->suboptions |= OPTION_MPTCP_PRIO; + mp_opt->backup = *ptr++ & MPTCP_PRIO_BKUP; + pr_debug("MP_PRIO: prio=%d", mp_opt->backup); + break; + + case MPTCPOPT_MP_FASTCLOSE: + if (opsize != TCPOLEN_MPTCP_FASTCLOSE) + break; + + ptr += 2; + mp_opt->rcvr_key = get_unaligned_be64(ptr); + ptr += 8; + mp_opt->suboptions |= OPTION_MPTCP_FASTCLOSE; + pr_debug("MP_FASTCLOSE: recv_key=%llu", mp_opt->rcvr_key); + break; + + case MPTCPOPT_RST: + if (opsize != TCPOLEN_MPTCP_RST) + break; + + if (!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_RST)) + break; + + mp_opt->suboptions |= OPTION_MPTCP_RST; + flags = *ptr++; + mp_opt->reset_transient = flags & MPTCP_RST_TRANSIENT; + mp_opt->reset_reason = *ptr; + pr_debug("MP_RST: transient=%u reason=%u", + mp_opt->reset_transient, mp_opt->reset_reason); + break; + + case MPTCPOPT_MP_FAIL: + if (opsize != TCPOLEN_MPTCP_FAIL) + break; + + ptr += 2; + mp_opt->suboptions |= OPTION_MPTCP_FAIL; + mp_opt->fail_seq = get_unaligned_be64(ptr); + pr_debug("MP_FAIL: data_seq=%llu", mp_opt->fail_seq); + break; + + default: + break; + } +} + +void mptcp_get_options(const struct sk_buff *skb, + struct mptcp_options_received *mp_opt) +{ + const struct tcphdr *th = tcp_hdr(skb); + const unsigned char *ptr; + int length; + + /* initialize option status */ + mp_opt->suboptions = 0; + + length = (th->doff * 4) - sizeof(struct tcphdr); + ptr = (const unsigned char *)(th + 1); + + while (length > 0) { + int opcode = *ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return; + case TCPOPT_NOP: /* Ref: RFC 793 section 3.1 */ + length--; + continue; + default: + if (length < 2) + return; + opsize = *ptr++; + if (opsize < 2) /* "silly options" */ + return; + if (opsize > length) + return; /* don't parse partial options */ + if (opcode == TCPOPT_MPTCP) + mptcp_parse_option(skb, ptr, opsize, mp_opt); + ptr += opsize - 2; + length -= opsize; + } + } +} + +bool mptcp_syn_options(struct sock *sk, const struct sk_buff *skb, + unsigned int *size, struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + /* we will use snd_isn to detect first pkt [re]transmission + * in mptcp_established_options_mp() + */ + subflow->snd_isn = TCP_SKB_CB(skb)->end_seq; + if (subflow->request_mptcp) { + opts->suboptions = OPTION_MPTCP_MPC_SYN; + opts->csum_reqd = mptcp_is_checksum_enabled(sock_net(sk)); + opts->allow_join_id0 = mptcp_allow_join_id0(sock_net(sk)); + *size = TCPOLEN_MPTCP_MPC_SYN; + return true; + } else if (subflow->request_join) { + pr_debug("remote_token=%u, nonce=%u", subflow->remote_token, + subflow->local_nonce); + opts->suboptions = OPTION_MPTCP_MPJ_SYN; + opts->join_id = subflow->local_id; + opts->token = subflow->remote_token; + opts->nonce = subflow->local_nonce; + opts->backup = subflow->request_bkup; + *size = TCPOLEN_MPTCP_MPJ_SYN; + return true; + } + return false; +} + +static void clear_3rdack_retransmission(struct sock *sk) +{ + struct inet_connection_sock *icsk 
= inet_csk(sk); + + sk_stop_timer(sk, &icsk->icsk_delack_timer); + icsk->icsk_ack.timeout = 0; + icsk->icsk_ack.ato = 0; + icsk->icsk_ack.pending &= ~(ICSK_ACK_SCHED | ICSK_ACK_TIMER); +} + +static bool mptcp_established_options_mp(struct sock *sk, struct sk_buff *skb, + bool snd_data_fin_enable, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct mptcp_ext *mpext; + unsigned int data_len; + u8 len; + + /* When skb is not available, we better over-estimate the emitted + * options len. A full DSS option (28 bytes) is longer than + * TCPOLEN_MPTCP_MPC_ACK_DATA(22) or TCPOLEN_MPTCP_MPJ_ACK(24), so + * tell the caller to defer the estimate to + * mptcp_established_options_dss(), which will reserve enough space. + */ + if (!skb) + return false; + + /* MPC/MPJ needed only on 3rd ack packet, DATA_FIN and TCP shutdown take precedence */ + if (subflow->fully_established || snd_data_fin_enable || + subflow->snd_isn != TCP_SKB_CB(skb)->seq || + sk->sk_state != TCP_ESTABLISHED) + return false; + + if (subflow->mp_capable) { + mpext = mptcp_get_ext(skb); + data_len = mpext ? mpext->data_len : 0; + + /* we will check ops->data_len in mptcp_write_options() to + * discriminate between TCPOLEN_MPTCP_MPC_ACK_DATA and + * TCPOLEN_MPTCP_MPC_ACK + */ + opts->data_len = data_len; + opts->suboptions = OPTION_MPTCP_MPC_ACK; + opts->sndr_key = subflow->local_key; + opts->rcvr_key = subflow->remote_key; + opts->csum_reqd = READ_ONCE(msk->csum_enabled); + opts->allow_join_id0 = mptcp_allow_join_id0(sock_net(sk)); + + /* Section 3.1. + * The MP_CAPABLE option is carried on the SYN, SYN/ACK, and ACK + * packets that start the first subflow of an MPTCP connection, + * as well as the first packet that carries data + */ + if (data_len > 0) { + len = TCPOLEN_MPTCP_MPC_ACK_DATA; + if (opts->csum_reqd) { + /* we need to propagate more info to csum the pseudo hdr */ + opts->data_seq = mpext->data_seq; + opts->subflow_seq = mpext->subflow_seq; + opts->csum = mpext->csum; + len += TCPOLEN_MPTCP_DSS_CHECKSUM; + } + *size = ALIGN(len, 4); + } else { + *size = TCPOLEN_MPTCP_MPC_ACK; + } + + pr_debug("subflow=%p, local_key=%llu, remote_key=%llu map_len=%d", + subflow, subflow->local_key, subflow->remote_key, + data_len); + + return true; + } else if (subflow->mp_join) { + opts->suboptions = OPTION_MPTCP_MPJ_ACK; + memcpy(opts->hmac, subflow->hmac, MPTCPOPT_HMAC_LEN); + *size = TCPOLEN_MPTCP_MPJ_ACK; + pr_debug("subflow=%p", subflow); + + /* we can use the full delegate action helper only from BH context + * If we are in process context - sk is flushing the backlog at + * socket lock release time - just set the appropriate flag, will + * be handled by the release callback + */ + if (sock_owned_by_user(sk)) + set_bit(MPTCP_DELEGATE_ACK, &subflow->delegated_status); + else + mptcp_subflow_delegate(subflow, MPTCP_DELEGATE_ACK); + return true; + } + return false; +} + +static void mptcp_write_data_fin(struct mptcp_subflow_context *subflow, + struct sk_buff *skb, struct mptcp_ext *ext) +{ + /* The write_seq value has already been incremented, so the actual + * sequence number for the DATA_FIN is one less. 
+ */ + u64 data_fin_tx_seq = READ_ONCE(mptcp_sk(subflow->conn)->write_seq) - 1; + + if (!ext->use_map || !skb->len) { + /* RFC6824 requires a DSS mapping with specific values + * if DATA_FIN is set but no data payload is mapped + */ + ext->data_fin = 1; + ext->use_map = 1; + ext->dsn64 = 1; + ext->data_seq = data_fin_tx_seq; + ext->subflow_seq = 0; + ext->data_len = 1; + } else if (ext->data_seq + ext->data_len == data_fin_tx_seq) { + /* If there's an existing DSS mapping and it is the + * final mapping, DATA_FIN consumes 1 additional byte of + * mapping space. + */ + ext->data_fin = 1; + ext->data_len++; + } +} + +static bool mptcp_established_options_dss(struct sock *sk, struct sk_buff *skb, + bool snd_data_fin_enable, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + unsigned int dss_size = 0; + struct mptcp_ext *mpext; + unsigned int ack_size; + bool ret = false; + u64 ack_seq; + + opts->csum_reqd = READ_ONCE(msk->csum_enabled); + mpext = skb ? mptcp_get_ext(skb) : NULL; + + if (!skb || (mpext && mpext->use_map) || snd_data_fin_enable) { + unsigned int map_size = TCPOLEN_MPTCP_DSS_BASE + TCPOLEN_MPTCP_DSS_MAP64; + + if (mpext) { + if (opts->csum_reqd) + map_size += TCPOLEN_MPTCP_DSS_CHECKSUM; + + opts->ext_copy = *mpext; + } + + remaining -= map_size; + dss_size = map_size; + if (skb && snd_data_fin_enable) + mptcp_write_data_fin(subflow, skb, &opts->ext_copy); + opts->suboptions = OPTION_MPTCP_DSS; + ret = true; + } + + /* passive sockets msk will set the 'can_ack' after accept(), even + * if the first subflow may have the already the remote key handy + */ + opts->ext_copy.use_ack = 0; + if (!READ_ONCE(msk->can_ack)) { + *size = ALIGN(dss_size, 4); + return ret; + } + + ack_seq = READ_ONCE(msk->ack_seq); + if (READ_ONCE(msk->use_64bit_ack)) { + ack_size = TCPOLEN_MPTCP_DSS_ACK64; + opts->ext_copy.data_ack = ack_seq; + opts->ext_copy.ack64 = 1; + } else { + ack_size = TCPOLEN_MPTCP_DSS_ACK32; + opts->ext_copy.data_ack32 = (uint32_t)ack_seq; + opts->ext_copy.ack64 = 0; + } + opts->ext_copy.use_ack = 1; + opts->suboptions = OPTION_MPTCP_DSS; + WRITE_ONCE(msk->old_wspace, __mptcp_space((struct sock *)msk)); + + /* Add kind/length/subtype/flag overhead if mapping is not populated */ + if (dss_size == 0) + ack_size += TCPOLEN_MPTCP_DSS_BASE; + + dss_size += ack_size; + + *size = ALIGN(dss_size, 4); + return true; +} + +static u64 add_addr_generate_hmac(u64 key1, u64 key2, + struct mptcp_addr_info *addr) +{ + u16 port = ntohs(addr->port); + u8 hmac[SHA256_DIGEST_SIZE]; + u8 msg[19]; + int i = 0; + + msg[i++] = addr->id; + if (addr->family == AF_INET) { + memcpy(&msg[i], &addr->addr.s_addr, 4); + i += 4; + } +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (addr->family == AF_INET6) { + memcpy(&msg[i], &addr->addr6.s6_addr, 16); + i += 16; + } +#endif + msg[i++] = port >> 8; + msg[i++] = port & 0xFF; + + mptcp_crypto_hmac_sha(key1, key2, msg, i, hmac); + + return get_unaligned_be64(&hmac[SHA256_DIGEST_SIZE - sizeof(u64)]); +} + +static bool mptcp_established_options_add_addr(struct sock *sk, struct sk_buff *skb, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + bool drop_other_suboptions = false; + unsigned int opt_size = *size; + bool echo; + int len; + + /* add addr will strip the existing 
options, be sure to avoid breaking + * MPC/MPJ handshakes + */ + if (!mptcp_pm_should_add_signal(msk) || + (opts->suboptions & (OPTION_MPTCP_MPJ_ACK | OPTION_MPTCP_MPC_ACK)) || + !mptcp_pm_add_addr_signal(msk, skb, opt_size, remaining, &opts->addr, + &echo, &drop_other_suboptions)) + return false; + + if (drop_other_suboptions) + remaining += opt_size; + len = mptcp_add_addr_len(opts->addr.family, echo, !!opts->addr.port); + if (remaining < len) + return false; + + *size = len; + if (drop_other_suboptions) { + pr_debug("drop other suboptions"); + opts->suboptions = 0; + + /* note that e.g. DSS could have written into the memory + * aliased by ahmac, we must reset the field here + * to avoid appending the hmac even for ADD_ADDR echo + * options + */ + opts->ahmac = 0; + *size -= opt_size; + } + opts->suboptions |= OPTION_MPTCP_ADD_ADDR; + if (!echo) { + opts->ahmac = add_addr_generate_hmac(msk->local_key, + msk->remote_key, + &opts->addr); + } + pr_debug("addr_id=%d, ahmac=%llu, echo=%d, port=%d", + opts->addr.id, opts->ahmac, echo, ntohs(opts->addr.port)); + + return true; +} + +static bool mptcp_established_options_rm_addr(struct sock *sk, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct mptcp_rm_list rm_list; + int i, len; + + if (!mptcp_pm_should_rm_signal(msk) || + !(mptcp_pm_rm_addr_signal(msk, remaining, &rm_list))) + return false; + + len = mptcp_rm_addr_len(&rm_list); + if (len < 0) + return false; + if (remaining < len) + return false; + + *size = len; + opts->suboptions |= OPTION_MPTCP_RM_ADDR; + opts->rm_list = rm_list; + + for (i = 0; i < opts->rm_list.nr; i++) + pr_debug("rm_list_ids[%d]=%d", i, opts->rm_list.ids[i]); + + return true; +} + +static bool mptcp_established_options_mp_prio(struct sock *sk, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + /* can't send MP_PRIO with MPC, as they share the same option space: + * 'backup'. 
Also it makes no sense at all + */ + if (!subflow->send_mp_prio || (opts->suboptions & OPTIONS_MPTCP_MPC)) + return false; + + /* account for the trailing 'nop' option */ + if (remaining < TCPOLEN_MPTCP_PRIO_ALIGN) + return false; + + *size = TCPOLEN_MPTCP_PRIO_ALIGN; + opts->suboptions |= OPTION_MPTCP_PRIO; + opts->backup = subflow->request_bkup; + + pr_debug("prio=%d", opts->backup); + + return true; +} + +static noinline bool mptcp_established_options_rst(struct sock *sk, struct sk_buff *skb, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + const struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + if (remaining < TCPOLEN_MPTCP_RST) + return false; + + *size = TCPOLEN_MPTCP_RST; + opts->suboptions |= OPTION_MPTCP_RST; + opts->reset_transient = subflow->reset_transient; + opts->reset_reason = subflow->reset_reason; + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPRSTTX); + + return true; +} + +static bool mptcp_established_options_fastclose(struct sock *sk, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + + if (likely(!subflow->send_fastclose)) + return false; + + if (remaining < TCPOLEN_MPTCP_FASTCLOSE) + return false; + + *size = TCPOLEN_MPTCP_FASTCLOSE; + opts->suboptions |= OPTION_MPTCP_FASTCLOSE; + opts->rcvr_key = msk->remote_key; + + pr_debug("FASTCLOSE key=%llu", opts->rcvr_key); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFASTCLOSETX); + return true; +} + +static bool mptcp_established_options_mp_fail(struct sock *sk, + unsigned int *size, + unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + if (likely(!subflow->send_mp_fail)) + return false; + + if (remaining < TCPOLEN_MPTCP_FAIL) + return false; + + *size = TCPOLEN_MPTCP_FAIL; + opts->suboptions |= OPTION_MPTCP_FAIL; + opts->fail_seq = subflow->map_seq; + + pr_debug("MP_FAIL fail_seq=%llu", opts->fail_seq); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFAILTX); + + return true; +} + +bool mptcp_established_options(struct sock *sk, struct sk_buff *skb, + unsigned int *size, unsigned int remaining, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + unsigned int opt_size = 0; + bool snd_data_fin; + bool ret = false; + + opts->suboptions = 0; + + if (unlikely(__mptcp_check_fallback(msk) && !mptcp_check_infinite_map(skb))) + return false; + + if (unlikely(skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_RST)) { + if (mptcp_established_options_fastclose(sk, &opt_size, remaining, opts) || + mptcp_established_options_mp_fail(sk, &opt_size, remaining, opts)) { + *size += opt_size; + remaining -= opt_size; + } + /* MP_RST can be used with MP_FASTCLOSE and MP_FAIL if there is room */ + if (mptcp_established_options_rst(sk, skb, &opt_size, remaining, opts)) { + *size += opt_size; + remaining -= opt_size; + } + return true; + } + + snd_data_fin = mptcp_data_fin_enabled(msk); + if (mptcp_established_options_mp(sk, skb, snd_data_fin, &opt_size, remaining, opts)) + ret = true; + else if (mptcp_established_options_dss(sk, skb, snd_data_fin, &opt_size, remaining, opts)) { + unsigned int mp_fail_size; + + ret = true; + if (mptcp_established_options_mp_fail(sk, &mp_fail_size, + remaining - opt_size, opts)) { + *size += opt_size + mp_fail_size; + remaining -= opt_size - 
mp_fail_size; + return true; + } + } + + /* we reserved enough space for the above options, and exceeding the + * TCP option space would be fatal + */ + if (WARN_ON_ONCE(opt_size > remaining)) + return false; + + *size += opt_size; + remaining -= opt_size; + if (mptcp_established_options_add_addr(sk, skb, &opt_size, remaining, opts)) { + *size += opt_size; + remaining -= opt_size; + ret = true; + } else if (mptcp_established_options_rm_addr(sk, &opt_size, remaining, opts)) { + *size += opt_size; + remaining -= opt_size; + ret = true; + } + + if (mptcp_established_options_mp_prio(sk, &opt_size, remaining, opts)) { + *size += opt_size; + remaining -= opt_size; + ret = true; + } + + return ret; +} + +bool mptcp_synack_options(const struct request_sock *req, unsigned int *size, + struct mptcp_out_options *opts) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + + if (subflow_req->mp_capable) { + opts->suboptions = OPTION_MPTCP_MPC_SYNACK; + opts->sndr_key = subflow_req->local_key; + opts->csum_reqd = subflow_req->csum_reqd; + opts->allow_join_id0 = subflow_req->allow_join_id0; + *size = TCPOLEN_MPTCP_MPC_SYNACK; + pr_debug("subflow_req=%p, local_key=%llu", + subflow_req, subflow_req->local_key); + return true; + } else if (subflow_req->mp_join) { + opts->suboptions = OPTION_MPTCP_MPJ_SYNACK; + opts->backup = subflow_req->backup; + opts->join_id = subflow_req->local_id; + opts->thmac = subflow_req->thmac; + opts->nonce = subflow_req->local_nonce; + pr_debug("req=%p, bkup=%u, id=%u, thmac=%llu, nonce=%u", + subflow_req, opts->backup, opts->join_id, + opts->thmac, opts->nonce); + *size = TCPOLEN_MPTCP_MPJ_SYNACK; + return true; + } + return false; +} + +static bool check_fully_established(struct mptcp_sock *msk, struct sock *ssk, + struct mptcp_subflow_context *subflow, + struct sk_buff *skb, + struct mptcp_options_received *mp_opt) +{ + /* here we can process OoO, in-window pkts, only in-sequence 4th ack + * will make the subflow fully established + */ + if (likely(subflow->fully_established)) { + /* on passive sockets, check for 3rd ack retransmission + * note that msk is always set by subflow_syn_recv_sock() + * for mp_join subflows + */ + if (TCP_SKB_CB(skb)->seq == subflow->ssn_offset + 1 && + TCP_SKB_CB(skb)->end_seq == TCP_SKB_CB(skb)->seq && + subflow->mp_join && (mp_opt->suboptions & OPTIONS_MPTCP_MPJ) && + !subflow->request_join) + tcp_send_ack(ssk); + goto fully_established; + } + + /* we must process OoO packets before the first subflow is fully + * established. OoO packets are instead a protocol violation + * for MP_JOIN subflows as the peer must not send any data + * before receiving the forth ack - cfr. RFC 8684 section 3.2. + */ + if (TCP_SKB_CB(skb)->seq != subflow->ssn_offset + 1) { + if (subflow->mp_join) + goto reset; + return subflow->mp_capable; + } + + if (((mp_opt->suboptions & OPTION_MPTCP_DSS) && mp_opt->use_ack) || + ((mp_opt->suboptions & OPTION_MPTCP_ADD_ADDR) && !mp_opt->echo)) { + /* subflows are fully established as soon as we get any + * additional ack, including ADD_ADDR. + */ + subflow->fully_established = 1; + WRITE_ONCE(msk->fully_established, true); + goto fully_established; + } + + /* If the first established packet does not contain MP_CAPABLE + data + * then fallback to TCP. Fallback scenarios requires a reset for + * MP_JOIN subflows. 
+ */ + if (!(mp_opt->suboptions & OPTIONS_MPTCP_MPC)) { + if (subflow->mp_join) + goto reset; + subflow->mp_capable = 0; + pr_fallback(msk); + mptcp_do_fallback(ssk); + return false; + } + + if (mp_opt->deny_join_id0) + WRITE_ONCE(msk->pm.remote_deny_join_id0, true); + + if (unlikely(!READ_ONCE(msk->pm.server_side))) + pr_warn_once("bogus mpc option on established client sk"); + mptcp_subflow_fully_established(subflow, mp_opt); + +fully_established: + /* if the subflow is not already linked into the conn_list, we can't + * notify the PM: this subflow is still on the listener queue + * and the PM possibly acquiring the subflow lock could race with + * the listener close + */ + if (likely(subflow->pm_notified) || list_empty(&subflow->node)) + return true; + + subflow->pm_notified = 1; + if (subflow->mp_join) { + clear_3rdack_retransmission(ssk); + mptcp_pm_subflow_established(msk); + } else { + mptcp_pm_fully_established(msk, ssk, GFP_ATOMIC); + } + return true; + +reset: + mptcp_subflow_reset(ssk); + return false; +} + +u64 __mptcp_expand_seq(u64 old_seq, u64 cur_seq) +{ + u32 old_seq32, cur_seq32; + + old_seq32 = (u32)old_seq; + cur_seq32 = (u32)cur_seq; + cur_seq = (old_seq & GENMASK_ULL(63, 32)) + cur_seq32; + if (unlikely(cur_seq32 < old_seq32 && before(old_seq32, cur_seq32))) + return cur_seq + (1LL << 32); + + /* reverse wrap could happen, too */ + if (unlikely(cur_seq32 > old_seq32 && after(old_seq32, cur_seq32))) + return cur_seq - (1LL << 32); + return cur_seq; +} + +static void ack_update_msk(struct mptcp_sock *msk, + struct sock *ssk, + struct mptcp_options_received *mp_opt) +{ + u64 new_wnd_end, new_snd_una, snd_nxt = READ_ONCE(msk->snd_nxt); + struct sock *sk = (struct sock *)msk; + u64 old_snd_una; + + mptcp_data_lock(sk); + + /* avoid ack expansion on update conflict, to reduce the risk of + * wrongly expanding to a future ack sequence number, which is way + * more dangerous than missing an ack + */ + old_snd_una = msk->snd_una; + new_snd_una = mptcp_expand_seq(old_snd_una, mp_opt->data_ack, mp_opt->ack64); + + /* ACK for data not even sent yet? Ignore.*/ + if (unlikely(after64(new_snd_una, snd_nxt))) + new_snd_una = old_snd_una; + + new_wnd_end = new_snd_una + tcp_sk(ssk)->snd_wnd; + + if (after64(new_wnd_end, msk->wnd_end)) + msk->wnd_end = new_wnd_end; + + /* this assumes mptcp_incoming_options() is invoked after tcp_ack() */ + if (after64(msk->wnd_end, READ_ONCE(msk->snd_nxt))) + __mptcp_check_push(sk, ssk); + + if (after64(new_snd_una, old_snd_una)) { + msk->snd_una = new_snd_una; + __mptcp_data_acked(sk); + } + mptcp_data_unlock(sk); + + trace_ack_update_msk(mp_opt->data_ack, + old_snd_una, new_snd_una, + new_wnd_end, msk->wnd_end); +} + +bool mptcp_update_rcv_data_fin(struct mptcp_sock *msk, u64 data_fin_seq, bool use_64bit) +{ + /* Skip if DATA_FIN was already received. + * If updating simultaneously with the recvmsg loop, values + * should match. If they mismatch, the peer is misbehaving and + * we will prefer the most recent information. 
+ */ + if (READ_ONCE(msk->rcv_data_fin)) + return false; + + WRITE_ONCE(msk->rcv_data_fin_seq, + mptcp_expand_seq(READ_ONCE(msk->ack_seq), data_fin_seq, use_64bit)); + WRITE_ONCE(msk->rcv_data_fin, 1); + + return true; +} + +static bool add_addr_hmac_valid(struct mptcp_sock *msk, + struct mptcp_options_received *mp_opt) +{ + u64 hmac = 0; + + if (mp_opt->echo) + return true; + + hmac = add_addr_generate_hmac(msk->remote_key, + msk->local_key, + &mp_opt->addr); + + pr_debug("msk=%p, ahmac=%llu, mp_opt->ahmac=%llu\n", + msk, hmac, mp_opt->ahmac); + + return hmac == mp_opt->ahmac; +} + +/* Return false if a subflow has been reset, else return true */ +bool mptcp_incoming_options(struct sock *sk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct mptcp_options_received mp_opt; + struct mptcp_ext *mpext; + + if (__mptcp_check_fallback(msk)) { + /* Keep it simple and unconditionally trigger send data cleanup and + * pending queue spooling. We will need to acquire the data lock + * for more accurate checks, and once the lock is acquired, such + * helpers are cheap. + */ + mptcp_data_lock(subflow->conn); + if (sk_stream_memory_free(sk)) + __mptcp_check_push(subflow->conn, sk); + __mptcp_data_acked(subflow->conn); + mptcp_data_unlock(subflow->conn); + return true; + } + + mptcp_get_options(skb, &mp_opt); + + /* The subflow can be in close state only if check_fully_established() + * just sent a reset. If so, tell the caller to ignore the current packet. + */ + if (!check_fully_established(msk, sk, subflow, skb, &mp_opt)) + return sk->sk_state != TCP_CLOSE; + + if (unlikely(mp_opt.suboptions != OPTION_MPTCP_DSS)) { + if ((mp_opt.suboptions & OPTION_MPTCP_FASTCLOSE) && + msk->local_key == mp_opt.rcvr_key) { + WRITE_ONCE(msk->rcv_fastclose, true); + mptcp_schedule_work((struct sock *)msk); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFASTCLOSERX); + } + + if ((mp_opt.suboptions & OPTION_MPTCP_ADD_ADDR) && + add_addr_hmac_valid(msk, &mp_opt)) { + if (!mp_opt.echo) { + mptcp_pm_add_addr_received(sk, &mp_opt.addr); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ADDADDR); + } else { + mptcp_pm_add_addr_echoed(msk, &mp_opt.addr); + mptcp_pm_del_add_timer(msk, &mp_opt.addr, true); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_ECHOADD); + } + + if (mp_opt.addr.port) + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_PORTADD); + } + + if (mp_opt.suboptions & OPTION_MPTCP_RM_ADDR) + mptcp_pm_rm_addr_received(msk, &mp_opt.rm_list); + + if (mp_opt.suboptions & OPTION_MPTCP_PRIO) { + mptcp_pm_mp_prio_received(sk, mp_opt.backup); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIORX); + } + + if (mp_opt.suboptions & OPTION_MPTCP_FAIL) { + mptcp_pm_mp_fail_received(sk, mp_opt.fail_seq); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPFAILRX); + } + + if (mp_opt.suboptions & OPTION_MPTCP_RST) { + subflow->reset_seen = 1; + subflow->reset_reason = mp_opt.reset_reason; + subflow->reset_transient = mp_opt.reset_transient; + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPRSTRX); + } + + if (!(mp_opt.suboptions & OPTION_MPTCP_DSS)) + return true; + } + + /* we can't wait for recvmsg() to update the ack_seq, otherwise + * monodirectional flows will stuck + */ + if (mp_opt.use_ack) + ack_update_msk(msk, sk, &mp_opt); + + /* Zero-data-length packets are dropped by the caller and not + * propagated to the MPTCP layer, so the skb extension does not + * need to be allocated or populated. 
DATA_FIN information, if + * present, needs to be updated here before the skb is freed. + */ + if (TCP_SKB_CB(skb)->seq == TCP_SKB_CB(skb)->end_seq) { + if (mp_opt.data_fin && mp_opt.data_len == 1 && + mptcp_update_rcv_data_fin(msk, mp_opt.data_seq, mp_opt.dsn64)) + mptcp_schedule_work((struct sock *)msk); + + return true; + } + + mpext = skb_ext_add(skb, SKB_EXT_MPTCP); + if (!mpext) + return true; + + memset(mpext, 0, sizeof(*mpext)); + + if (likely(mp_opt.use_map)) { + if (mp_opt.mpc_map) { + /* this is an MP_CAPABLE carrying MPTCP data + * we know this map the first chunk of data + */ + mptcp_crypto_key_sha(subflow->remote_key, NULL, + &mpext->data_seq); + mpext->data_seq++; + mpext->subflow_seq = 1; + mpext->dsn64 = 1; + mpext->mpc_map = 1; + mpext->data_fin = 0; + } else { + mpext->data_seq = mp_opt.data_seq; + mpext->subflow_seq = mp_opt.subflow_seq; + mpext->dsn64 = mp_opt.dsn64; + mpext->data_fin = mp_opt.data_fin; + } + mpext->data_len = mp_opt.data_len; + mpext->use_map = 1; + mpext->csum_reqd = !!(mp_opt.suboptions & OPTION_MPTCP_CSUMREQD); + + if (mpext->csum_reqd) + mpext->csum = mp_opt.csum; + } + + return true; +} + +static void mptcp_set_rwin(struct tcp_sock *tp, struct tcphdr *th) +{ + const struct sock *ssk = (const struct sock *)tp; + struct mptcp_subflow_context *subflow; + u64 ack_seq, rcv_wnd_old, rcv_wnd_new; + struct mptcp_sock *msk; + u32 new_win; + u64 win; + + subflow = mptcp_subflow_ctx(ssk); + msk = mptcp_sk(subflow->conn); + + ack_seq = READ_ONCE(msk->ack_seq); + rcv_wnd_new = ack_seq + tp->rcv_wnd; + + rcv_wnd_old = atomic64_read(&msk->rcv_wnd_sent); + if (after64(rcv_wnd_new, rcv_wnd_old)) { + u64 rcv_wnd; + + for (;;) { + rcv_wnd = atomic64_cmpxchg(&msk->rcv_wnd_sent, rcv_wnd_old, rcv_wnd_new); + + if (rcv_wnd == rcv_wnd_old) + break; + + rcv_wnd_old = rcv_wnd; + if (before64(rcv_wnd_new, rcv_wnd_old)) { + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_RCVWNDCONFLICTUPDATE); + goto raise_win; + } + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_RCVWNDCONFLICT); + } + return; + } + + if (rcv_wnd_new != rcv_wnd_old) { +raise_win: + win = rcv_wnd_old - ack_seq; + tp->rcv_wnd = min_t(u64, win, U32_MAX); + new_win = tp->rcv_wnd; + + /* Make sure we do not exceed the maximum possible + * scaled window. + */ + if (unlikely(th->syn)) + new_win = min(new_win, 65535U) << tp->rx_opt.rcv_wscale; + if (!tp->rx_opt.rcv_wscale && + READ_ONCE(sock_net(ssk)->ipv4.sysctl_tcp_workaround_signed_windows)) + new_win = min(new_win, MAX_TCP_WINDOW); + else + new_win = min(new_win, (65535U << tp->rx_opt.rcv_wscale)); + + /* RFC1323 scaling applied */ + new_win >>= tp->rx_opt.rcv_wscale; + th->window = htons(new_win); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_RCVWNDSHARED); + } +} + +__sum16 __mptcp_make_csum(u64 data_seq, u32 subflow_seq, u16 data_len, __wsum sum) +{ + struct csum_pseudo_header header; + __wsum csum; + + /* cfr RFC 8684 3.3.1.: + * the data sequence number used in the pseudo-header is + * always the 64-bit value, irrespective of what length is used in the + * DSS option itself. 
+ */ + header.data_seq = cpu_to_be64(data_seq); + header.subflow_seq = htonl(subflow_seq); + header.data_len = htons(data_len); + header.csum = 0; + + csum = csum_partial(&header, sizeof(header), sum); + return csum_fold(csum); +} + +static __sum16 mptcp_make_csum(const struct mptcp_ext *mpext) +{ + return __mptcp_make_csum(mpext->data_seq, mpext->subflow_seq, mpext->data_len, + ~csum_unfold(mpext->csum)); +} + +static void put_len_csum(u16 len, __sum16 csum, void *data) +{ + __sum16 *sumptr = data + 2; + __be16 *ptr = data; + + put_unaligned_be16(len, ptr); + + put_unaligned(csum, sumptr); +} + +void mptcp_write_options(struct tcphdr *th, __be32 *ptr, struct tcp_sock *tp, + struct mptcp_out_options *opts) +{ + const struct sock *ssk = (const struct sock *)tp; + struct mptcp_subflow_context *subflow; + + /* Which options can be used together? + * + * X: mutually exclusive + * O: often used together + * C: can be used together in some cases + * P: could be used together but we prefer not to (optimisations) + * + * Opt: | MPC | MPJ | DSS | ADD | RM | PRIO | FAIL | FC | + * ------|------|------|------|------|------|------|------|------| + * MPC |------|------|------|------|------|------|------|------| + * MPJ | X |------|------|------|------|------|------|------| + * DSS | X | X |------|------|------|------|------|------| + * ADD | X | X | P |------|------|------|------|------| + * RM | C | C | C | P |------|------|------|------| + * PRIO | X | C | C | C | C |------|------|------| + * FAIL | X | X | C | X | X | X |------|------| + * FC | X | X | X | X | X | X | X |------| + * RST | X | X | X | X | X | X | O | O | + * ------|------|------|------|------|------|------|------|------| + * + * The same applies in mptcp_established_options() function. + */ + if (likely(OPTION_MPTCP_DSS & opts->suboptions)) { + struct mptcp_ext *mpext = &opts->ext_copy; + u8 len = TCPOLEN_MPTCP_DSS_BASE; + u8 flags = 0; + + if (mpext->use_ack) { + flags = MPTCP_DSS_HAS_ACK; + if (mpext->ack64) { + len += TCPOLEN_MPTCP_DSS_ACK64; + flags |= MPTCP_DSS_ACK64; + } else { + len += TCPOLEN_MPTCP_DSS_ACK32; + } + } + + if (mpext->use_map) { + len += TCPOLEN_MPTCP_DSS_MAP64; + + /* Use only 64-bit mapping flags for now, add + * support for optional 32-bit mappings later. + */ + flags |= MPTCP_DSS_HAS_MAP | MPTCP_DSS_DSN64; + if (mpext->data_fin) + flags |= MPTCP_DSS_DATA_FIN; + + if (opts->csum_reqd) + len += TCPOLEN_MPTCP_DSS_CHECKSUM; + } + + *ptr++ = mptcp_option(MPTCPOPT_DSS, len, 0, flags); + + if (mpext->use_ack) { + if (mpext->ack64) { + put_unaligned_be64(mpext->data_ack, ptr); + ptr += 2; + } else { + put_unaligned_be32(mpext->data_ack32, ptr); + ptr += 1; + } + } + + if (mpext->use_map) { + put_unaligned_be64(mpext->data_seq, ptr); + ptr += 2; + put_unaligned_be32(mpext->subflow_seq, ptr); + ptr += 1; + if (opts->csum_reqd) { + /* data_len == 0 is reserved for the infinite mapping, + * the checksum will also be set to 0. + */ + put_len_csum(mpext->data_len, + (mpext->data_len ? 
mptcp_make_csum(mpext) : 0), + ptr); + } else { + put_unaligned_be32(mpext->data_len << 16 | + TCPOPT_NOP << 8 | TCPOPT_NOP, ptr); + } + ptr += 1; + } + + /* We might need to add MP_FAIL options in rare cases */ + if (unlikely(OPTION_MPTCP_FAIL & opts->suboptions)) + goto mp_fail; + } else if (OPTIONS_MPTCP_MPC & opts->suboptions) { + u8 len, flag = MPTCP_CAP_HMAC_SHA256; + + if (OPTION_MPTCP_MPC_SYN & opts->suboptions) { + len = TCPOLEN_MPTCP_MPC_SYN; + } else if (OPTION_MPTCP_MPC_SYNACK & opts->suboptions) { + len = TCPOLEN_MPTCP_MPC_SYNACK; + } else if (opts->data_len) { + len = TCPOLEN_MPTCP_MPC_ACK_DATA; + if (opts->csum_reqd) + len += TCPOLEN_MPTCP_DSS_CHECKSUM; + } else { + len = TCPOLEN_MPTCP_MPC_ACK; + } + + if (opts->csum_reqd) + flag |= MPTCP_CAP_CHECKSUM_REQD; + + if (!opts->allow_join_id0) + flag |= MPTCP_CAP_DENY_JOIN_ID0; + + *ptr++ = mptcp_option(MPTCPOPT_MP_CAPABLE, len, + MPTCP_SUPPORTED_VERSION, + flag); + + if (!((OPTION_MPTCP_MPC_SYNACK | OPTION_MPTCP_MPC_ACK) & + opts->suboptions)) + goto mp_capable_done; + + put_unaligned_be64(opts->sndr_key, ptr); + ptr += 2; + if (!((OPTION_MPTCP_MPC_ACK) & opts->suboptions)) + goto mp_capable_done; + + put_unaligned_be64(opts->rcvr_key, ptr); + ptr += 2; + if (!opts->data_len) + goto mp_capable_done; + + if (opts->csum_reqd) { + put_len_csum(opts->data_len, + __mptcp_make_csum(opts->data_seq, + opts->subflow_seq, + opts->data_len, + ~csum_unfold(opts->csum)), + ptr); + } else { + put_unaligned_be32(opts->data_len << 16 | + TCPOPT_NOP << 8 | TCPOPT_NOP, ptr); + } + ptr += 1; + + /* MPC is additionally mutually exclusive with MP_PRIO */ + goto mp_capable_done; + } else if (OPTIONS_MPTCP_MPJ & opts->suboptions) { + if (OPTION_MPTCP_MPJ_SYN & opts->suboptions) { + *ptr++ = mptcp_option(MPTCPOPT_MP_JOIN, + TCPOLEN_MPTCP_MPJ_SYN, + opts->backup, opts->join_id); + put_unaligned_be32(opts->token, ptr); + ptr += 1; + put_unaligned_be32(opts->nonce, ptr); + ptr += 1; + } else if (OPTION_MPTCP_MPJ_SYNACK & opts->suboptions) { + *ptr++ = mptcp_option(MPTCPOPT_MP_JOIN, + TCPOLEN_MPTCP_MPJ_SYNACK, + opts->backup, opts->join_id); + put_unaligned_be64(opts->thmac, ptr); + ptr += 2; + put_unaligned_be32(opts->nonce, ptr); + ptr += 1; + } else { + *ptr++ = mptcp_option(MPTCPOPT_MP_JOIN, + TCPOLEN_MPTCP_MPJ_ACK, 0, 0); + memcpy(ptr, opts->hmac, MPTCPOPT_HMAC_LEN); + ptr += 5; + } + } else if (OPTION_MPTCP_ADD_ADDR & opts->suboptions) { + u8 len = TCPOLEN_MPTCP_ADD_ADDR_BASE; + u8 echo = MPTCP_ADDR_ECHO; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (opts->addr.family == AF_INET6) + len = TCPOLEN_MPTCP_ADD_ADDR6_BASE; +#endif + + if (opts->addr.port) + len += TCPOLEN_MPTCP_PORT_LEN; + + if (opts->ahmac) { + len += sizeof(opts->ahmac); + echo = 0; + } + + *ptr++ = mptcp_option(MPTCPOPT_ADD_ADDR, + len, echo, opts->addr.id); + if (opts->addr.family == AF_INET) { + memcpy((u8 *)ptr, (u8 *)&opts->addr.addr.s_addr, 4); + ptr += 1; + } +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (opts->addr.family == AF_INET6) { + memcpy((u8 *)ptr, opts->addr.addr6.s6_addr, 16); + ptr += 4; + } +#endif + + if (!opts->addr.port) { + if (opts->ahmac) { + put_unaligned_be64(opts->ahmac, ptr); + ptr += 2; + } + } else { + u16 port = ntohs(opts->addr.port); + + if (opts->ahmac) { + u8 *bptr = (u8 *)ptr; + + put_unaligned_be16(port, bptr); + bptr += 2; + put_unaligned_be64(opts->ahmac, bptr); + bptr += 8; + put_unaligned_be16(TCPOPT_NOP << 8 | + TCPOPT_NOP, bptr); + + ptr += 3; + } else { + put_unaligned_be32(port << 16 | + TCPOPT_NOP << 8 | + TCPOPT_NOP, ptr); + ptr += 1; + } + } 
+ } else if (unlikely(OPTION_MPTCP_FASTCLOSE & opts->suboptions)) { + /* FASTCLOSE is mutually exclusive with others except RST */ + *ptr++ = mptcp_option(MPTCPOPT_MP_FASTCLOSE, + TCPOLEN_MPTCP_FASTCLOSE, + 0, 0); + put_unaligned_be64(opts->rcvr_key, ptr); + ptr += 2; + + if (OPTION_MPTCP_RST & opts->suboptions) + goto mp_rst; + return; + } else if (unlikely(OPTION_MPTCP_FAIL & opts->suboptions)) { +mp_fail: + /* MP_FAIL is mutually exclusive with others except RST */ + subflow = mptcp_subflow_ctx(ssk); + subflow->send_mp_fail = 0; + + *ptr++ = mptcp_option(MPTCPOPT_MP_FAIL, + TCPOLEN_MPTCP_FAIL, + 0, 0); + put_unaligned_be64(opts->fail_seq, ptr); + ptr += 2; + + if (OPTION_MPTCP_RST & opts->suboptions) + goto mp_rst; + return; + } else if (unlikely(OPTION_MPTCP_RST & opts->suboptions)) { +mp_rst: + *ptr++ = mptcp_option(MPTCPOPT_RST, + TCPOLEN_MPTCP_RST, + opts->reset_transient, + opts->reset_reason); + return; + } + + if (OPTION_MPTCP_PRIO & opts->suboptions) { + subflow = mptcp_subflow_ctx(ssk); + subflow->send_mp_prio = 0; + + *ptr++ = mptcp_option(MPTCPOPT_MP_PRIO, + TCPOLEN_MPTCP_PRIO, + opts->backup, TCPOPT_NOP); + + MPTCP_INC_STATS(sock_net((const struct sock *)tp), + MPTCP_MIB_MPPRIOTX); + } + +mp_capable_done: + if (OPTION_MPTCP_RM_ADDR & opts->suboptions) { + u8 i = 1; + + *ptr++ = mptcp_option(MPTCPOPT_RM_ADDR, + TCPOLEN_MPTCP_RM_ADDR_BASE + opts->rm_list.nr, + 0, opts->rm_list.ids[0]); + + while (i < opts->rm_list.nr) { + u8 id1, id2, id3, id4; + + id1 = opts->rm_list.ids[i]; + id2 = i + 1 < opts->rm_list.nr ? opts->rm_list.ids[i + 1] : TCPOPT_NOP; + id3 = i + 2 < opts->rm_list.nr ? opts->rm_list.ids[i + 2] : TCPOPT_NOP; + id4 = i + 3 < opts->rm_list.nr ? opts->rm_list.ids[i + 3] : TCPOPT_NOP; + put_unaligned_be32(id1 << 24 | id2 << 16 | id3 << 8 | id4, ptr); + ptr += 1; + i += 4; + } + } + + if (tp) + mptcp_set_rwin(tp, th); +} + +__be32 mptcp_get_reset_option(const struct sk_buff *skb) +{ + const struct mptcp_ext *ext = mptcp_get_ext(skb); + u8 flags, reason; + + if (ext) { + flags = ext->reset_transient; + reason = ext->reset_reason; + + return mptcp_option(MPTCPOPT_RST, TCPOLEN_MPTCP_RST, + flags, reason); + } + + return htonl(0u); +} +EXPORT_SYMBOL_GPL(mptcp_get_reset_option); diff --git a/net/mptcp/pm.c b/net/mptcp/pm.c new file mode 100644 index 000000000..10c288a0c --- /dev/null +++ b/net/mptcp/pm.c @@ -0,0 +1,513 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2019, Intel Corporation. + */ +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include "protocol.h" + +#include "mib.h" + +/* path manager command handlers */ + +int mptcp_pm_announce_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr, + bool echo) +{ + u8 add_addr = READ_ONCE(msk->pm.addr_signal); + + pr_debug("msk=%p, local_id=%d, echo=%d", msk, addr->id, echo); + + lockdep_assert_held(&msk->pm.lock); + + if (add_addr & + (echo ? 
BIT(MPTCP_ADD_ADDR_ECHO) : BIT(MPTCP_ADD_ADDR_SIGNAL))) { + pr_warn("addr_signal error, add_addr=%d, echo=%d", add_addr, echo); + return -EINVAL; + } + + if (echo) { + msk->pm.remote = *addr; + add_addr |= BIT(MPTCP_ADD_ADDR_ECHO); + } else { + msk->pm.local = *addr; + add_addr |= BIT(MPTCP_ADD_ADDR_SIGNAL); + } + WRITE_ONCE(msk->pm.addr_signal, add_addr); + return 0; +} + +int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list) +{ + u8 rm_addr = READ_ONCE(msk->pm.addr_signal); + + pr_debug("msk=%p, rm_list_nr=%d", msk, rm_list->nr); + + if (rm_addr) { + pr_warn("addr_signal error, rm_addr=%d", rm_addr); + return -EINVAL; + } + + msk->pm.rm_list_tx = *rm_list; + rm_addr |= BIT(MPTCP_RM_ADDR_SIGNAL); + WRITE_ONCE(msk->pm.addr_signal, rm_addr); + mptcp_pm_nl_addr_send_ack(msk); + return 0; +} + +int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list) +{ + pr_debug("msk=%p, rm_list_nr=%d", msk, rm_list->nr); + + spin_lock_bh(&msk->pm.lock); + mptcp_pm_nl_rm_subflow_received(msk, rm_list); + spin_unlock_bh(&msk->pm.lock); + return 0; +} + +/* path manager event handlers */ + +void mptcp_pm_new_connection(struct mptcp_sock *msk, const struct sock *ssk, int server_side) +{ + struct mptcp_pm_data *pm = &msk->pm; + + pr_debug("msk=%p, token=%u side=%d", msk, msk->token, server_side); + + WRITE_ONCE(pm->server_side, server_side); + mptcp_event(MPTCP_EVENT_CREATED, msk, ssk, GFP_ATOMIC); +} + +bool mptcp_pm_allow_new_subflow(struct mptcp_sock *msk) +{ + struct mptcp_pm_data *pm = &msk->pm; + unsigned int subflows_max; + int ret = 0; + + if (mptcp_pm_is_userspace(msk)) { + if (mptcp_userspace_pm_active(msk)) { + spin_lock_bh(&pm->lock); + pm->subflows++; + spin_unlock_bh(&pm->lock); + return true; + } + return false; + } + + subflows_max = mptcp_pm_get_subflows_max(msk); + + pr_debug("msk=%p subflows=%d max=%d allow=%d", msk, pm->subflows, + subflows_max, READ_ONCE(pm->accept_subflow)); + + /* try to avoid acquiring the lock below */ + if (!READ_ONCE(pm->accept_subflow)) + return false; + + spin_lock_bh(&pm->lock); + if (READ_ONCE(pm->accept_subflow)) { + ret = pm->subflows < subflows_max; + if (ret && ++pm->subflows == subflows_max) + WRITE_ONCE(pm->accept_subflow, false); + } + spin_unlock_bh(&pm->lock); + + return ret; +} + +/* return true if the new status bit is currently cleared, that is, this event + * can be server, eventually by an already scheduled work + */ +static bool mptcp_pm_schedule_work(struct mptcp_sock *msk, + enum mptcp_pm_status new_status) +{ + pr_debug("msk=%p status=%x new=%lx", msk, msk->pm.status, + BIT(new_status)); + if (msk->pm.status & BIT(new_status)) + return false; + + msk->pm.status |= BIT(new_status); + mptcp_schedule_work((struct sock *)msk); + return true; +} + +void mptcp_pm_fully_established(struct mptcp_sock *msk, const struct sock *ssk, gfp_t gfp) +{ + struct mptcp_pm_data *pm = &msk->pm; + bool announce = false; + + pr_debug("msk=%p", msk); + + spin_lock_bh(&pm->lock); + + /* mptcp_pm_fully_established() can be invoked by multiple + * racing paths - accept() and check_fully_established() + * be sure to serve this event only once. 
+ */ + if (READ_ONCE(pm->work_pending) && + !(msk->pm.status & BIT(MPTCP_PM_ALREADY_ESTABLISHED))) + mptcp_pm_schedule_work(msk, MPTCP_PM_ESTABLISHED); + + if ((msk->pm.status & BIT(MPTCP_PM_ALREADY_ESTABLISHED)) == 0) + announce = true; + + msk->pm.status |= BIT(MPTCP_PM_ALREADY_ESTABLISHED); + spin_unlock_bh(&pm->lock); + + if (announce) + mptcp_event(MPTCP_EVENT_ESTABLISHED, msk, ssk, gfp); +} + +void mptcp_pm_connection_closed(struct mptcp_sock *msk) +{ + pr_debug("msk=%p", msk); +} + +void mptcp_pm_subflow_established(struct mptcp_sock *msk) +{ + struct mptcp_pm_data *pm = &msk->pm; + + pr_debug("msk=%p", msk); + + if (!READ_ONCE(pm->work_pending)) + return; + + spin_lock_bh(&pm->lock); + + if (READ_ONCE(pm->work_pending)) + mptcp_pm_schedule_work(msk, MPTCP_PM_SUBFLOW_ESTABLISHED); + + spin_unlock_bh(&pm->lock); +} + +void mptcp_pm_subflow_check_next(struct mptcp_sock *msk, const struct sock *ssk, + const struct mptcp_subflow_context *subflow) +{ + struct mptcp_pm_data *pm = &msk->pm; + bool update_subflows; + + update_subflows = subflow->request_join || subflow->mp_join; + if (mptcp_pm_is_userspace(msk)) { + if (update_subflows) { + spin_lock_bh(&pm->lock); + pm->subflows--; + spin_unlock_bh(&pm->lock); + } + return; + } + + if (!READ_ONCE(pm->work_pending) && !update_subflows) + return; + + spin_lock_bh(&pm->lock); + if (update_subflows) + __mptcp_pm_close_subflow(msk); + + /* Even if this subflow is not really established, tell the PM to try + * to pick the next ones, if possible. + */ + if (mptcp_pm_nl_check_work_pending(msk)) + mptcp_pm_schedule_work(msk, MPTCP_PM_SUBFLOW_ESTABLISHED); + + spin_unlock_bh(&pm->lock); +} + +void mptcp_pm_add_addr_received(const struct sock *ssk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct mptcp_pm_data *pm = &msk->pm; + + pr_debug("msk=%p remote_id=%d accept=%d", msk, addr->id, + READ_ONCE(pm->accept_addr)); + + mptcp_event_addr_announced(ssk, addr); + + spin_lock_bh(&pm->lock); + + if (mptcp_pm_is_userspace(msk)) { + if (mptcp_userspace_pm_active(msk)) { + mptcp_pm_announce_addr(msk, addr, true); + mptcp_pm_add_addr_send_ack(msk); + } else { + __MPTCP_INC_STATS(sock_net((struct sock *)msk), MPTCP_MIB_ADDADDRDROP); + } + } else if (!READ_ONCE(pm->accept_addr)) { + mptcp_pm_announce_addr(msk, addr, true); + mptcp_pm_add_addr_send_ack(msk); + } else if (mptcp_pm_schedule_work(msk, MPTCP_PM_ADD_ADDR_RECEIVED)) { + pm->remote = *addr; + } else { + __MPTCP_INC_STATS(sock_net((struct sock *)msk), MPTCP_MIB_ADDADDRDROP); + } + + spin_unlock_bh(&pm->lock); +} + +void mptcp_pm_add_addr_echoed(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_pm_data *pm = &msk->pm; + + pr_debug("msk=%p", msk); + + spin_lock_bh(&pm->lock); + + if (mptcp_lookup_anno_list_by_saddr(msk, addr) && READ_ONCE(pm->work_pending)) + mptcp_pm_schedule_work(msk, MPTCP_PM_SUBFLOW_ESTABLISHED); + + spin_unlock_bh(&pm->lock); +} + +void mptcp_pm_add_addr_send_ack(struct mptcp_sock *msk) +{ + if (!mptcp_pm_should_add_signal(msk)) + return; + + mptcp_pm_schedule_work(msk, MPTCP_PM_ADD_ADDR_SEND_ACK); +} + +void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, + const struct mptcp_rm_list *rm_list) +{ + struct mptcp_pm_data *pm = &msk->pm; + u8 i; + + pr_debug("msk=%p remote_ids_nr=%d", msk, rm_list->nr); + + for (i = 0; i < rm_list->nr; i++) + mptcp_event_addr_removed(msk, rm_list->ids[i]); + + spin_lock_bh(&pm->lock); + if 
(mptcp_pm_schedule_work(msk, MPTCP_PM_RM_ADDR_RECEIVED)) + pm->rm_list_rx = *rm_list; + else + __MPTCP_INC_STATS(sock_net((struct sock *)msk), MPTCP_MIB_RMADDRDROP); + spin_unlock_bh(&pm->lock); +} + +void mptcp_pm_mp_prio_received(struct sock *ssk, u8 bkup) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = subflow->conn; + struct mptcp_sock *msk; + + pr_debug("subflow->backup=%d, bkup=%d\n", subflow->backup, bkup); + msk = mptcp_sk(sk); + if (subflow->backup != bkup) { + subflow->backup = bkup; + mptcp_data_lock(sk); + if (!sock_owned_by_user(sk)) + msk->last_snd = NULL; + else + __set_bit(MPTCP_RESET_SCHEDULER, &msk->cb_flags); + mptcp_data_unlock(sk); + } + + mptcp_event(MPTCP_EVENT_SUB_PRIORITY, msk, ssk, GFP_ATOMIC); +} + +void mptcp_pm_mp_fail_received(struct sock *sk, u64 fail_seq) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + + pr_debug("fail_seq=%llu", fail_seq); + + if (!READ_ONCE(msk->allow_infinite_fallback)) + return; + + if (!subflow->fail_tout) { + pr_debug("send MP_FAIL response and infinite map"); + + subflow->send_mp_fail = 1; + subflow->send_infinite_map = 1; + tcp_send_ack(sk); + } else { + pr_debug("MP_FAIL response received"); + WRITE_ONCE(subflow->fail_tout, 0); + } +} + +/* path manager helpers */ + +bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, const struct sk_buff *skb, + unsigned int opt_size, unsigned int remaining, + struct mptcp_addr_info *addr, bool *echo, + bool *drop_other_suboptions) +{ + int ret = false; + u8 add_addr; + u8 family; + bool port; + + spin_lock_bh(&msk->pm.lock); + + /* double check after the lock is acquired */ + if (!mptcp_pm_should_add_signal(msk)) + goto out_unlock; + + /* always drop every other options for pure ack ADD_ADDR; this is a + * plain dup-ack from TCP perspective. The other MPTCP-relevant info, + * if any, will be carried by the 'original' TCP ack + */ + if (skb && skb_is_tcp_pure_ack(skb)) { + remaining += opt_size; + *drop_other_suboptions = true; + } + + *echo = mptcp_pm_should_add_signal_echo(msk); + port = !!(*echo ? msk->pm.remote.port : msk->pm.local.port); + + family = *echo ? 
msk->pm.remote.family : msk->pm.local.family; + if (remaining < mptcp_add_addr_len(family, *echo, port)) + goto out_unlock; + + if (*echo) { + *addr = msk->pm.remote; + add_addr = msk->pm.addr_signal & ~BIT(MPTCP_ADD_ADDR_ECHO); + } else { + *addr = msk->pm.local; + add_addr = msk->pm.addr_signal & ~BIT(MPTCP_ADD_ADDR_SIGNAL); + } + WRITE_ONCE(msk->pm.addr_signal, add_addr); + ret = true; + +out_unlock: + spin_unlock_bh(&msk->pm.lock); + return ret; +} + +bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, + struct mptcp_rm_list *rm_list) +{ + int ret = false, len; + u8 rm_addr; + + spin_lock_bh(&msk->pm.lock); + + /* double check after the lock is acquired */ + if (!mptcp_pm_should_rm_signal(msk)) + goto out_unlock; + + rm_addr = msk->pm.addr_signal & ~BIT(MPTCP_RM_ADDR_SIGNAL); + len = mptcp_rm_addr_len(&msk->pm.rm_list_tx); + if (len < 0) { + WRITE_ONCE(msk->pm.addr_signal, rm_addr); + goto out_unlock; + } + if (remaining < len) + goto out_unlock; + + *rm_list = msk->pm.rm_list_tx; + WRITE_ONCE(msk->pm.addr_signal, rm_addr); + ret = true; + +out_unlock: + spin_unlock_bh(&msk->pm.lock); + return ret; +} + +int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) +{ + return mptcp_pm_nl_get_local_id(msk, skc); +} + +void mptcp_pm_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + u32 rcv_tstamp = READ_ONCE(tcp_sk(ssk)->rcv_tstamp); + + /* keep track of rtx periods with no progress */ + if (!subflow->stale_count) { + subflow->stale_rcv_tstamp = rcv_tstamp; + subflow->stale_count++; + } else if (subflow->stale_rcv_tstamp == rcv_tstamp) { + if (subflow->stale_count < U8_MAX) + subflow->stale_count++; + mptcp_pm_nl_subflow_chk_stale(msk, ssk); + } else { + subflow->stale_count = 0; + mptcp_subflow_set_active(subflow); + } +} + +/* if sk is ipv4 or ipv6_only allows only same-family local and remote addresses, + * otherwise allow any matching local/remote pair + */ +bool mptcp_pm_addr_families_match(const struct sock *sk, + const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *rem) +{ + bool mptcp_is_v4 = sk->sk_family == AF_INET; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + bool loc_is_v4 = loc->family == AF_INET || ipv6_addr_v4mapped(&loc->addr6); + bool rem_is_v4 = rem->family == AF_INET || ipv6_addr_v4mapped(&rem->addr6); + + if (mptcp_is_v4) + return loc_is_v4 && rem_is_v4; + + if (ipv6_only_sock(sk)) + return !loc_is_v4 && !rem_is_v4; + + return loc_is_v4 == rem_is_v4; +#else + return mptcp_is_v4 && loc->family == AF_INET && rem->family == AF_INET; +#endif +} + +void mptcp_pm_data_reset(struct mptcp_sock *msk) +{ + u8 pm_type = mptcp_get_pm_type(sock_net((struct sock *)msk)); + struct mptcp_pm_data *pm = &msk->pm; + + pm->add_addr_signaled = 0; + pm->add_addr_accepted = 0; + pm->local_addr_used = 0; + pm->subflows = 0; + pm->rm_list_tx.nr = 0; + pm->rm_list_rx.nr = 0; + WRITE_ONCE(pm->pm_type, pm_type); + + if (pm_type == MPTCP_PM_TYPE_KERNEL) { + bool subflows_allowed = !!mptcp_pm_get_subflows_max(msk); + + /* pm->work_pending must be only be set to 'true' when + * pm->pm_type is set to MPTCP_PM_TYPE_KERNEL + */ + WRITE_ONCE(pm->work_pending, + (!!mptcp_pm_get_local_addr_max(msk) && + subflows_allowed) || + !!mptcp_pm_get_add_addr_signal_max(msk)); + WRITE_ONCE(pm->accept_addr, + !!mptcp_pm_get_add_addr_accept_max(msk) && + subflows_allowed); + WRITE_ONCE(pm->accept_subflow, subflows_allowed); + } else { + WRITE_ONCE(pm->work_pending, 0); + 
WRITE_ONCE(pm->accept_addr, 0); + WRITE_ONCE(pm->accept_subflow, 0); + } + + WRITE_ONCE(pm->addr_signal, 0); + WRITE_ONCE(pm->remote_deny_join_id0, false); + pm->status = 0; + bitmap_fill(msk->pm.id_avail_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); +} + +void mptcp_pm_data_init(struct mptcp_sock *msk) +{ + spin_lock_init(&msk->pm.lock); + INIT_LIST_HEAD(&msk->pm.anno_list); + INIT_LIST_HEAD(&msk->pm.userspace_pm_local_addr_list); + mptcp_pm_data_reset(msk); +} + +void __init mptcp_pm_init(void) +{ + mptcp_pm_nl_init(); +} diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c new file mode 100644 index 000000000..980050f6b --- /dev/null +++ b/net/mptcp/pm_netlink.c @@ -0,0 +1,2365 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2020, Red Hat, Inc. + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include "protocol.h" +#include "mib.h" + +/* forward declaration */ +static struct genl_family mptcp_genl_family; + +static int pm_nl_pernet_id; + +struct mptcp_pm_add_entry { + struct list_head list; + struct mptcp_addr_info addr; + struct timer_list add_timer; + struct mptcp_sock *sock; + u8 retrans_times; +}; + +struct pm_nl_pernet { + /* protects pernet updates */ + spinlock_t lock; + struct list_head local_addr_list; + unsigned int addrs; + unsigned int stale_loss_cnt; + unsigned int add_addr_signal_max; + unsigned int add_addr_accept_max; + unsigned int local_addr_max; + unsigned int subflows_max; + unsigned int next_id; + DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); +}; + +#define MPTCP_PM_ADDR_MAX 8 +#define ADD_ADDR_RETRANS_MAX 3 + +static struct pm_nl_pernet *pm_nl_get_pernet(const struct net *net) +{ + return net_generic(net, pm_nl_pernet_id); +} + +static struct pm_nl_pernet * +pm_nl_get_pernet_from_msk(const struct mptcp_sock *msk) +{ + return pm_nl_get_pernet(sock_net((struct sock *)msk)); +} + +bool mptcp_addresses_equal(const struct mptcp_addr_info *a, + const struct mptcp_addr_info *b, bool use_port) +{ + bool addr_equals = false; + + if (a->family == b->family) { + if (a->family == AF_INET) + addr_equals = a->addr.s_addr == b->addr.s_addr; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else + addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6); + } else if (a->family == AF_INET) { + if (ipv6_addr_v4mapped(&b->addr6)) + addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3]; + } else if (b->family == AF_INET) { + if (ipv6_addr_v4mapped(&a->addr6)) + addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr; +#endif + } + + if (!addr_equals) + return false; + if (!use_port) + return true; + + return a->port == b->port; +} + +static void local_address(const struct sock_common *skc, + struct mptcp_addr_info *addr) +{ + addr->family = skc->skc_family; + addr->port = htons(skc->skc_num); + if (addr->family == AF_INET) + addr->addr.s_addr = skc->skc_rcv_saddr; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (addr->family == AF_INET6) + addr->addr6 = skc->skc_v6_rcv_saddr; +#endif +} + +static void remote_address(const struct sock_common *skc, + struct mptcp_addr_info *addr) +{ + addr->family = skc->skc_family; + addr->port = skc->skc_dport; + if (addr->family == AF_INET) + addr->addr.s_addr = skc->skc_daddr; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (addr->family == AF_INET6) + addr->addr6 = skc->skc_v6_daddr; +#endif +} + +static bool lookup_subflow_by_saddr(const struct list_head *list, + const struct mptcp_addr_info *saddr) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_addr_info cur; 
+ struct sock_common *skc; + + list_for_each_entry(subflow, list, node) { + skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); + + local_address(skc, &cur); + if (mptcp_addresses_equal(&cur, saddr, saddr->port)) + return true; + } + + return false; +} + +static bool lookup_subflow_by_daddr(const struct list_head *list, + const struct mptcp_addr_info *daddr) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_addr_info cur; + struct sock_common *skc; + + list_for_each_entry(subflow, list, node) { + skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); + + remote_address(skc, &cur); + if (mptcp_addresses_equal(&cur, daddr, daddr->port)) + return true; + } + + return false; +} + +static struct mptcp_pm_addr_entry * +select_local_address(const struct pm_nl_pernet *pernet, + const struct mptcp_sock *msk) +{ + const struct sock *sk = (const struct sock *)msk; + struct mptcp_pm_addr_entry *entry, *ret = NULL; + + msk_owned_by_me(msk); + + rcu_read_lock(); + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)) + continue; + + if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap)) + continue; + + if (entry->addr.family != sk->sk_family) { +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if ((entry->addr.family == AF_INET && + !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) || + (sk->sk_family == AF_INET && + !ipv6_addr_v4mapped(&entry->addr.addr6))) +#endif + continue; + } + + ret = entry; + break; + } + rcu_read_unlock(); + return ret; +} + +static struct mptcp_pm_addr_entry * +select_signal_address(struct pm_nl_pernet *pernet, const struct mptcp_sock *msk) +{ + struct mptcp_pm_addr_entry *entry, *ret = NULL; + + rcu_read_lock(); + /* do not keep any additional per socket state, just signal + * the address list in order. + * Note: removal from the local address list during the msk life-cycle + * can lead to additional addresses not being announced. 
+ */ + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (!test_bit(entry->addr.id, msk->pm.id_avail_bitmap)) + continue; + + if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) + continue; + + ret = entry; + break; + } + rcu_read_unlock(); + return ret; +} + +unsigned int mptcp_pm_get_add_addr_signal_max(const struct mptcp_sock *msk) +{ + const struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); + + return READ_ONCE(pernet->add_addr_signal_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max); + +unsigned int mptcp_pm_get_add_addr_accept_max(const struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); + + return READ_ONCE(pernet->add_addr_accept_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max); + +unsigned int mptcp_pm_get_subflows_max(const struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); + + return READ_ONCE(pernet->subflows_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max); + +unsigned int mptcp_pm_get_local_addr_max(const struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); + + return READ_ONCE(pernet->local_addr_max); +} +EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max); + +bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk) +{ + struct pm_nl_pernet *pernet = pm_nl_get_pernet_from_msk(msk); + + if (msk->pm.subflows == mptcp_pm_get_subflows_max(msk) || + (find_next_and_bit(pernet->id_bitmap, msk->pm.id_avail_bitmap, + MPTCP_PM_MAX_ADDR_ID + 1, 0) == MPTCP_PM_MAX_ADDR_ID + 1)) { + WRITE_ONCE(msk->pm.work_pending, false); + return false; + } + return true; +} + +struct mptcp_pm_add_entry * +mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_pm_add_entry *entry; + + lockdep_assert_held(&msk->pm.lock); + + list_for_each_entry(entry, &msk->pm.anno_list, list) { + if (mptcp_addresses_equal(&entry->addr, addr, true)) + return entry; + } + + return NULL; +} + +bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk) +{ + struct mptcp_pm_add_entry *entry; + struct mptcp_addr_info saddr; + bool ret = false; + + local_address((struct sock_common *)sk, &saddr); + + spin_lock_bh(&msk->pm.lock); + list_for_each_entry(entry, &msk->pm.anno_list, list) { + if (mptcp_addresses_equal(&entry->addr, &saddr, true)) { + ret = true; + goto out; + } + } + +out: + spin_unlock_bh(&msk->pm.lock); + return ret; +} + +static void mptcp_pm_add_timer(struct timer_list *timer) +{ + struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer); + struct mptcp_sock *msk = entry->sock; + struct sock *sk = (struct sock *)msk; + + pr_debug("msk=%p", msk); + + if (!msk) + return; + + if (inet_sk_state_load(sk) == TCP_CLOSE) + return; + + if (!entry->addr.id) + return; + + if (mptcp_pm_should_add_signal_addr(msk)) { + sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8); + goto out; + } + + spin_lock_bh(&msk->pm.lock); + + if (!mptcp_pm_should_add_signal_addr(msk)) { + pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id); + mptcp_pm_announce_addr(msk, &entry->addr, false); + mptcp_pm_add_addr_send_ack(msk); + entry->retrans_times++; + } + + if (entry->retrans_times < ADD_ADDR_RETRANS_MAX) + sk_reset_timer(sk, timer, + jiffies + mptcp_get_add_addr_timeout(sock_net(sk))); + + spin_unlock_bh(&msk->pm.lock); + + if (entry->retrans_times == ADD_ADDR_RETRANS_MAX) + mptcp_pm_subflow_established(msk); + +out: + __sock_put(sk); +} + +struct mptcp_pm_add_entry * 
+mptcp_pm_del_add_timer(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr, bool check_id) +{ + struct mptcp_pm_add_entry *entry; + struct sock *sk = (struct sock *)msk; + + spin_lock_bh(&msk->pm.lock); + entry = mptcp_lookup_anno_list_by_saddr(msk, addr); + if (entry && (!check_id || entry->addr.id == addr->id)) + entry->retrans_times = ADD_ADDR_RETRANS_MAX; + spin_unlock_bh(&msk->pm.lock); + + if (entry && (!check_id || entry->addr.id == addr->id)) + sk_stop_timer_sync(sk, &entry->add_timer); + + return entry; +} + +bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk, + const struct mptcp_pm_addr_entry *entry) +{ + struct mptcp_pm_add_entry *add_entry = NULL; + struct sock *sk = (struct sock *)msk; + struct net *net = sock_net(sk); + + lockdep_assert_held(&msk->pm.lock); + + add_entry = mptcp_lookup_anno_list_by_saddr(msk, &entry->addr); + + if (add_entry) { + if (mptcp_pm_is_kernel(msk)) + return false; + + sk_reset_timer(sk, &add_entry->add_timer, + jiffies + mptcp_get_add_addr_timeout(net)); + return true; + } + + add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC); + if (!add_entry) + return false; + + list_add(&add_entry->list, &msk->pm.anno_list); + + add_entry->addr = entry->addr; + add_entry->sock = msk; + add_entry->retrans_times = 0; + + timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0); + sk_reset_timer(sk, &add_entry->add_timer, + jiffies + mptcp_get_add_addr_timeout(net)); + + return true; +} + +void mptcp_pm_free_anno_list(struct mptcp_sock *msk) +{ + struct mptcp_pm_add_entry *entry, *tmp; + struct sock *sk = (struct sock *)msk; + LIST_HEAD(free_list); + + pr_debug("msk=%p", msk); + + spin_lock_bh(&msk->pm.lock); + list_splice_init(&msk->pm.anno_list, &free_list); + spin_unlock_bh(&msk->pm.lock); + + list_for_each_entry_safe(entry, tmp, &free_list, list) { + sk_stop_timer_sync(sk, &entry->add_timer); + kfree(entry); + } +} + +static bool lookup_address_in_vec(const struct mptcp_addr_info *addrs, unsigned int nr, + const struct mptcp_addr_info *addr) +{ + int i; + + for (i = 0; i < nr; i++) { + if (addrs[i].id == addr->id) + return true; + } + + return false; +} + +/* Fill all the remote addresses into the array addrs[], + * and return the array size. + */ +static unsigned int fill_remote_addresses_vec(struct mptcp_sock *msk, bool fullmesh, + struct mptcp_addr_info *addrs) +{ + bool deny_id0 = READ_ONCE(msk->pm.remote_deny_join_id0); + struct sock *sk = (struct sock *)msk, *ssk; + struct mptcp_subflow_context *subflow; + struct mptcp_addr_info remote = { 0 }; + unsigned int subflows_max; + int i = 0; + + subflows_max = mptcp_pm_get_subflows_max(msk); + remote_address((struct sock_common *)sk, &remote); + + /* Non-fullmesh endpoint, fill in the single entry + * corresponding to the primary MPC subflow remote address + */ + if (!fullmesh) { + if (deny_id0) + return 0; + + msk->pm.subflows++; + addrs[i++] = remote; + } else { + mptcp_for_each_subflow(msk, subflow) { + ssk = mptcp_subflow_tcp_sock(subflow); + remote_address((struct sock_common *)ssk, &addrs[i]); + addrs[i].id = subflow->remote_id; + if (deny_id0 && !addrs[i].id) + continue; + + if (!lookup_address_in_vec(addrs, i, &addrs[i]) && + msk->pm.subflows < subflows_max) { + msk->pm.subflows++; + i++; + } + } + } + + return i; +} + +static void __mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow, + bool prio, bool backup) +{ + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow; + + pr_debug("send ack for %s", + prio ? 
"mp_prio" : (mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr")); + + slow = lock_sock_fast(ssk); + if (prio) { + if (subflow->backup != backup) + msk->last_snd = NULL; + + subflow->send_mp_prio = 1; + subflow->backup = backup; + subflow->request_bkup = backup; + } + + __mptcp_subflow_send_ack(ssk); + unlock_sock_fast(ssk, slow); +} + +static void mptcp_pm_send_ack(struct mptcp_sock *msk, struct mptcp_subflow_context *subflow, + bool prio, bool backup) +{ + spin_unlock_bh(&msk->pm.lock); + __mptcp_pm_send_ack(msk, subflow, prio, backup); + spin_lock_bh(&msk->pm.lock); +} + +static struct mptcp_pm_addr_entry * +__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id) +{ + struct mptcp_pm_addr_entry *entry; + + list_for_each_entry(entry, &pernet->local_addr_list, list) { + if (entry->addr.id == id) + return entry; + } + return NULL; +} + +static struct mptcp_pm_addr_entry * +__lookup_addr(struct pm_nl_pernet *pernet, const struct mptcp_addr_info *info, + bool lookup_by_id) +{ + struct mptcp_pm_addr_entry *entry; + + list_for_each_entry(entry, &pernet->local_addr_list, list) { + if ((!lookup_by_id && + mptcp_addresses_equal(&entry->addr, info, entry->addr.port)) || + (lookup_by_id && entry->addr.id == info->id)) + return entry; + } + return NULL; +} + +static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) +{ + struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *local; + unsigned int add_addr_signal_max; + unsigned int local_addr_max; + struct pm_nl_pernet *pernet; + unsigned int subflows_max; + + pernet = pm_nl_get_pernet(sock_net(sk)); + + add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk); + local_addr_max = mptcp_pm_get_local_addr_max(msk); + subflows_max = mptcp_pm_get_subflows_max(msk); + + /* do lazy endpoint usage accounting for the MPC subflows */ + if (unlikely(!(msk->pm.status & BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED))) && msk->first) { + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(msk->first); + struct mptcp_pm_addr_entry *entry; + struct mptcp_addr_info mpc_addr; + bool backup = false; + + local_address((struct sock_common *)msk->first, &mpc_addr); + rcu_read_lock(); + entry = __lookup_addr(pernet, &mpc_addr, false); + if (entry) { + __clear_bit(entry->addr.id, msk->pm.id_avail_bitmap); + msk->mpc_endpoint_id = entry->addr.id; + backup = !!(entry->flags & MPTCP_PM_ADDR_FLAG_BACKUP); + } + rcu_read_unlock(); + + if (backup) + mptcp_pm_send_ack(msk, subflow, true, backup); + + msk->pm.status |= BIT(MPTCP_PM_MPC_ENDPOINT_ACCOUNTED); + } + + pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", + msk->pm.local_addr_used, local_addr_max, + msk->pm.add_addr_signaled, add_addr_signal_max, + msk->pm.subflows, subflows_max); + + /* check first for announce */ + if (msk->pm.add_addr_signaled < add_addr_signal_max) { + local = select_signal_address(pernet, msk); + + /* due to racing events on both ends we can reach here while + * previous add address is still running: if we invoke now + * mptcp_pm_announce_addr(), that will fail and the + * corresponding id will be marked as used. + * Instead let the PM machinery reschedule us when the + * current address announce will be completed. 
+ */ + if (msk->pm.addr_signal & BIT(MPTCP_ADD_ADDR_SIGNAL)) + return; + + if (local) { + if (mptcp_pm_alloc_anno_list(msk, local)) { + __clear_bit(local->addr.id, msk->pm.id_avail_bitmap); + msk->pm.add_addr_signaled++; + mptcp_pm_announce_addr(msk, &local->addr, false); + mptcp_pm_nl_addr_send_ack(msk); + } + } + } + + /* check if should create a new subflow */ + while (msk->pm.local_addr_used < local_addr_max && + msk->pm.subflows < subflows_max) { + struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX]; + bool fullmesh; + int i, nr; + + local = select_local_address(pernet, msk); + if (!local) + break; + + fullmesh = !!(local->flags & MPTCP_PM_ADDR_FLAG_FULLMESH); + + msk->pm.local_addr_used++; + nr = fill_remote_addresses_vec(msk, fullmesh, addrs); + if (nr) + __clear_bit(local->addr.id, msk->pm.id_avail_bitmap); + spin_unlock_bh(&msk->pm.lock); + for (i = 0; i < nr; i++) + __mptcp_subflow_connect(sk, &local->addr, &addrs[i]); + spin_lock_bh(&msk->pm.lock); + } + mptcp_pm_nl_check_work_pending(msk); +} + +static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk) +{ + mptcp_pm_create_subflow_or_signal_addr(msk); +} + +static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) +{ + mptcp_pm_create_subflow_or_signal_addr(msk); +} + +/* Fill all the local addresses into the array addrs[], + * and return the array size. + */ +static unsigned int fill_local_addresses_vec(struct mptcp_sock *msk, + struct mptcp_addr_info *addrs) +{ + struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *entry; + struct mptcp_addr_info local; + struct pm_nl_pernet *pernet; + unsigned int subflows_max; + int i = 0; + + pernet = pm_nl_get_pernet_from_msk(msk); + subflows_max = mptcp_pm_get_subflows_max(msk); + + rcu_read_lock(); + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (!(entry->flags & MPTCP_PM_ADDR_FLAG_FULLMESH)) + continue; + + if (entry->addr.family != sk->sk_family) { +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if ((entry->addr.family == AF_INET && + !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) || + (sk->sk_family == AF_INET && + !ipv6_addr_v4mapped(&entry->addr.addr6))) +#endif + continue; + } + + if (msk->pm.subflows < subflows_max) { + msk->pm.subflows++; + addrs[i++] = entry->addr; + } + } + rcu_read_unlock(); + + /* If the array is empty, fill in the single + * 'IPADDRANY' local address + */ + if (!i) { + memset(&local, 0, sizeof(local)); + local.family = msk->pm.remote.family; + + msk->pm.subflows++; + addrs[i++] = local; + } + + return i; +} + +static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) +{ + struct mptcp_addr_info addrs[MPTCP_PM_ADDR_MAX]; + struct sock *sk = (struct sock *)msk; + unsigned int add_addr_accept_max; + struct mptcp_addr_info remote; + unsigned int subflows_max; + int i, nr; + + add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk); + subflows_max = mptcp_pm_get_subflows_max(msk); + + pr_debug("accepted %d:%d remote family %d", + msk->pm.add_addr_accepted, add_addr_accept_max, + msk->pm.remote.family); + + remote = msk->pm.remote; + mptcp_pm_announce_addr(msk, &remote, true); + mptcp_pm_nl_addr_send_ack(msk); + + if (lookup_subflow_by_daddr(&msk->conn_list, &remote)) + return; + + /* pick id 0 port, if none is provided the remote address */ + if (!remote.port) + remote.port = sk->sk_dport; + + /* connect to the specified remote address, using whatever + * local address the routing configuration will pick. 
+ */ + nr = fill_local_addresses_vec(msk, addrs); + + msk->pm.add_addr_accepted++; + if (msk->pm.add_addr_accepted >= add_addr_accept_max || + msk->pm.subflows >= subflows_max) + WRITE_ONCE(msk->pm.accept_addr, false); + + spin_unlock_bh(&msk->pm.lock); + for (i = 0; i < nr; i++) + __mptcp_subflow_connect(sk, &addrs[i], &remote); + spin_lock_bh(&msk->pm.lock); +} + +void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + + msk_owned_by_me(msk); + lockdep_assert_held(&msk->pm.lock); + + if (!mptcp_pm_should_add_signal(msk) && + !mptcp_pm_should_rm_signal(msk)) + return; + + subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node); + if (subflow) + mptcp_pm_send_ack(msk, subflow, false, false); +} + +int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, + struct mptcp_addr_info *addr, + struct mptcp_addr_info *rem, + u8 bkup) +{ + struct mptcp_subflow_context *subflow; + + pr_debug("bkup=%d", bkup); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + struct mptcp_addr_info local, remote; + + local_address((struct sock_common *)ssk, &local); + if (!mptcp_addresses_equal(&local, addr, addr->port)) + continue; + + if (rem && rem->family != AF_UNSPEC) { + remote_address((struct sock_common *)ssk, &remote); + if (!mptcp_addresses_equal(&remote, rem, rem->port)) + continue; + } + + __mptcp_pm_send_ack(msk, subflow, true, bkup); + return 0; + } + + return -EINVAL; +} + +static bool mptcp_local_id_match(const struct mptcp_sock *msk, u8 local_id, u8 id) +{ + return local_id == id || (!local_id && msk->mpc_endpoint_id == id); +} + +static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk, + const struct mptcp_rm_list *rm_list, + enum linux_mptcp_mib_field rm_type) +{ + struct mptcp_subflow_context *subflow, *tmp; + struct sock *sk = (struct sock *)msk; + u8 i; + + pr_debug("%s rm_list_nr %d", + rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", rm_list->nr); + + msk_owned_by_me(msk); + + if (sk->sk_state == TCP_LISTEN) + return; + + if (!rm_list->nr) + return; + + if (list_empty(&msk->conn_list)) + return; + + for (i = 0; i < rm_list->nr; i++) { + u8 rm_id = rm_list->ids[i]; + bool removed = false; + + mptcp_for_each_subflow_safe(msk, subflow, tmp) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + int how = RCV_SHUTDOWN | SEND_SHUTDOWN; + u8 id = subflow->local_id; + + if (rm_type == MPTCP_MIB_RMADDR && subflow->remote_id != rm_id) + continue; + if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id)) + continue; + + pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u", + rm_type == MPTCP_MIB_RMADDR ? "address" : "subflow", + i, rm_id, subflow->local_id, subflow->remote_id, + msk->mpc_endpoint_id); + spin_unlock_bh(&msk->pm.lock); + mptcp_subflow_shutdown(sk, ssk, how); + + /* the following takes care of updating the subflows counter */ + mptcp_close_ssk(sk, ssk, subflow); + spin_lock_bh(&msk->pm.lock); + + removed = true; + __MPTCP_INC_STATS(sock_net(sk), rm_type); + } + if (rm_type == MPTCP_MIB_RMSUBFLOW) + __set_bit(rm_id ? 
rm_id : msk->mpc_endpoint_id, msk->pm.id_avail_bitmap); + if (!removed) + continue; + + if (!mptcp_pm_is_kernel(msk)) + continue; + + if (rm_type == MPTCP_MIB_RMADDR) { + msk->pm.add_addr_accepted--; + WRITE_ONCE(msk->pm.accept_addr, true); + } else if (rm_type == MPTCP_MIB_RMSUBFLOW) { + msk->pm.local_addr_used--; + } + } +} + +static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk) +{ + mptcp_pm_nl_rm_addr_or_subflow(msk, &msk->pm.rm_list_rx, MPTCP_MIB_RMADDR); +} + +void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, + const struct mptcp_rm_list *rm_list) +{ + mptcp_pm_nl_rm_addr_or_subflow(msk, rm_list, MPTCP_MIB_RMSUBFLOW); +} + +void mptcp_pm_nl_work(struct mptcp_sock *msk) +{ + struct mptcp_pm_data *pm = &msk->pm; + + msk_owned_by_me(msk); + + if (!(pm->status & MPTCP_PM_WORK_MASK)) + return; + + spin_lock_bh(&msk->pm.lock); + + pr_debug("msk=%p status=%x", msk, pm->status); + if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) { + pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED); + mptcp_pm_nl_add_addr_received(msk); + } + if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) { + pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK); + mptcp_pm_nl_addr_send_ack(msk); + } + if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) { + pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED); + mptcp_pm_nl_rm_addr_received(msk); + } + if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) { + pm->status &= ~BIT(MPTCP_PM_ESTABLISHED); + mptcp_pm_nl_fully_established(msk); + } + if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) { + pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED); + mptcp_pm_nl_subflow_established(msk); + } + + spin_unlock_bh(&msk->pm.lock); +} + +static bool address_use_port(struct mptcp_pm_addr_entry *entry) +{ + return (entry->flags & + (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) == + MPTCP_PM_ADDR_FLAG_SIGNAL; +} + +/* caller must ensure the RCU grace period is already elapsed */ +static void __mptcp_pm_release_addr_entry(struct mptcp_pm_addr_entry *entry) +{ + if (entry->lsk) + sock_release(entry->lsk); + kfree(entry); +} + +static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, + struct mptcp_pm_addr_entry *entry) +{ + struct mptcp_pm_addr_entry *cur, *del_entry = NULL; + unsigned int addr_max; + int ret = -EINVAL; + + spin_lock_bh(&pernet->lock); + /* to keep the code simple, don't do IDR-like allocation for address ID, + * just bail when we exceed limits + */ + if (pernet->next_id == MPTCP_PM_MAX_ADDR_ID) + pernet->next_id = 1; + if (pernet->addrs >= MPTCP_PM_ADDR_MAX) + goto out; + if (test_bit(entry->addr.id, pernet->id_bitmap)) + goto out; + + /* do not insert duplicate address, differentiate on port only + * singled addresses + */ + if (!address_use_port(entry)) + entry->addr.port = 0; + list_for_each_entry(cur, &pernet->local_addr_list, list) { + if (mptcp_addresses_equal(&cur->addr, &entry->addr, + cur->addr.port || entry->addr.port)) { + /* allow replacing the existing endpoint only if such + * endpoint is an implicit one and the user-space + * did not provide an endpoint id + */ + if (!(cur->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)) + goto out; + if (entry->addr.id) + goto out; + + pernet->addrs--; + entry->addr.id = cur->addr.id; + list_del_rcu(&cur->list); + del_entry = cur; + break; + } + } + + if (!entry->addr.id) { +find_next: + entry->addr.id = find_next_zero_bit(pernet->id_bitmap, + MPTCP_PM_MAX_ADDR_ID + 1, + pernet->next_id); + if (!entry->addr.id && pernet->next_id != 1) { + pernet->next_id = 1; + goto find_next; + } + } + + if 
(!entry->addr.id) + goto out; + + __set_bit(entry->addr.id, pernet->id_bitmap); + if (entry->addr.id > pernet->next_id) + pernet->next_id = entry->addr.id; + + if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + addr_max = pernet->add_addr_signal_max; + WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1); + } + if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { + addr_max = pernet->local_addr_max; + WRITE_ONCE(pernet->local_addr_max, addr_max + 1); + } + + pernet->addrs++; + if (!entry->addr.port) + list_add_tail_rcu(&entry->list, &pernet->local_addr_list); + else + list_add_rcu(&entry->list, &pernet->local_addr_list); + ret = entry->addr.id; + +out: + spin_unlock_bh(&pernet->lock); + + /* just replaced an existing entry, free it */ + if (del_entry) { + synchronize_rcu(); + __mptcp_pm_release_addr_entry(del_entry); + } + return ret; +} + +static struct lock_class_key mptcp_slock_keys[2]; +static struct lock_class_key mptcp_keys[2]; + +static int mptcp_pm_nl_create_listen_socket(struct sock *sk, + struct mptcp_pm_addr_entry *entry) +{ + bool is_ipv6 = sk->sk_family == AF_INET6; + int addrlen = sizeof(struct sockaddr_in); + struct sockaddr_storage addr; + struct socket *ssock; + struct sock *newsk; + int backlog = 1024; + int err; + + err = sock_create_kern(sock_net(sk), entry->addr.family, + SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk); + if (err) + return err; + + newsk = entry->lsk->sk; + if (!newsk) + return -EINVAL; + + /* The subflow socket lock is acquired in a nested to the msk one + * in several places, even by the TCP stack, and this msk is a kernel + * socket: lockdep complains. Instead of propagating the _nested + * modifiers in several places, re-init the lock class for the msk + * socket to an mptcp specific one. + */ + sock_lock_init_class_and_name(newsk, + is_ipv6 ? "mlock-AF_INET6" : "mlock-AF_INET", + &mptcp_slock_keys[is_ipv6], + is_ipv6 ? 
"msk_lock-AF_INET6" : "msk_lock-AF_INET", + &mptcp_keys[is_ipv6]); + + lock_sock(newsk); + ssock = __mptcp_nmpc_socket(mptcp_sk(newsk)); + release_sock(newsk); + if (!ssock) + return -EINVAL; + + mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (entry->addr.family == AF_INET6) + addrlen = sizeof(struct sockaddr_in6); +#endif + err = kernel_bind(ssock, (struct sockaddr *)&addr, addrlen); + if (err) { + pr_warn("kernel_bind error, err=%d", err); + return err; + } + + inet_sk_state_store(newsk, TCP_LISTEN); + err = kernel_listen(ssock, backlog); + if (err) { + pr_warn("kernel_listen error, err=%d", err); + return err; + } + + return 0; +} + +int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) +{ + struct mptcp_pm_addr_entry *entry; + struct mptcp_addr_info skc_local; + struct mptcp_addr_info msk_local; + struct pm_nl_pernet *pernet; + int ret = -1; + + if (WARN_ON_ONCE(!msk)) + return -1; + + /* The 0 ID mapping is defined by the first subflow, copied into the msk + * addr + */ + local_address((struct sock_common *)msk, &msk_local); + local_address((struct sock_common *)skc, &skc_local); + if (mptcp_addresses_equal(&msk_local, &skc_local, false)) + return 0; + + if (mptcp_pm_is_userspace(msk)) + return mptcp_userspace_pm_get_local_id(msk, &skc_local); + + pernet = pm_nl_get_pernet_from_msk(msk); + + rcu_read_lock(); + list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { + if (mptcp_addresses_equal(&entry->addr, &skc_local, entry->addr.port)) { + ret = entry->addr.id; + break; + } + } + rcu_read_unlock(); + if (ret >= 0) + return ret; + + /* address not found, add to local list */ + entry = kmalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) + return -ENOMEM; + + entry->addr = skc_local; + entry->addr.id = 0; + entry->addr.port = 0; + entry->ifindex = 0; + entry->flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; + entry->lsk = NULL; + ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); + if (ret < 0) + kfree(entry); + + return ret; +} + +#define MPTCP_PM_CMD_GRP_OFFSET 0 +#define MPTCP_PM_EV_GRP_OFFSET 1 + +static const struct genl_multicast_group mptcp_pm_mcgrps[] = { + [MPTCP_PM_CMD_GRP_OFFSET] = { .name = MPTCP_PM_CMD_GRP_NAME, }, + [MPTCP_PM_EV_GRP_OFFSET] = { .name = MPTCP_PM_EV_GRP_NAME, + .flags = GENL_UNS_ADMIN_PERM, + }, +}; + +static const struct nla_policy +mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = { + [MPTCP_PM_ADDR_ATTR_FAMILY] = { .type = NLA_U16, }, + [MPTCP_PM_ADDR_ATTR_ID] = { .type = NLA_U8, }, + [MPTCP_PM_ADDR_ATTR_ADDR4] = { .type = NLA_U32, }, + [MPTCP_PM_ADDR_ATTR_ADDR6] = + NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [MPTCP_PM_ADDR_ATTR_PORT] = { .type = NLA_U16 }, + [MPTCP_PM_ADDR_ATTR_FLAGS] = { .type = NLA_U32 }, + [MPTCP_PM_ADDR_ATTR_IF_IDX] = { .type = NLA_S32 }, +}; + +static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = { + [MPTCP_PM_ATTR_ADDR] = + NLA_POLICY_NESTED(mptcp_pm_addr_policy), + [MPTCP_PM_ATTR_RCV_ADD_ADDRS] = { .type = NLA_U32, }, + [MPTCP_PM_ATTR_SUBFLOWS] = { .type = NLA_U32, }, + [MPTCP_PM_ATTR_TOKEN] = { .type = NLA_U32, }, + [MPTCP_PM_ATTR_LOC_ID] = { .type = NLA_U8, }, + [MPTCP_PM_ATTR_ADDR_REMOTE] = + NLA_POLICY_NESTED(mptcp_pm_addr_policy), +}; + +void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk) +{ + struct mptcp_subflow_context *iter, *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = (struct sock *)msk; + unsigned int active_max_loss_cnt; + struct net *net = sock_net(sk); + unsigned int 
stale_loss_cnt; + bool slow; + + stale_loss_cnt = mptcp_stale_loss_cnt(net); + if (subflow->stale || !stale_loss_cnt || subflow->stale_count <= stale_loss_cnt) + return; + + /* look for another available subflow not in loss state */ + active_max_loss_cnt = max_t(int, stale_loss_cnt - 1, 1); + mptcp_for_each_subflow(msk, iter) { + if (iter != subflow && mptcp_subflow_active(iter) && + iter->stale_count < active_max_loss_cnt) { + /* we have some alternatives, try to mark this subflow as idle ...*/ + slow = lock_sock_fast(ssk); + if (!tcp_rtx_and_write_queues_empty(ssk)) { + subflow->stale = 1; + __mptcp_retransmit_pending_data(sk); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_SUBFLOWSTALE); + } + unlock_sock_fast(ssk, slow); + + /* always try to push the pending data regardless of re-injections: + * we can possibly use backup subflows now, and subflow selection + * is cheap under the msk socket lock + */ + __mptcp_push_pending(sk, 0); + return; + } + } +} + +static int mptcp_pm_family_to_addr(int family) +{ +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (family == AF_INET6) + return MPTCP_PM_ADDR_ATTR_ADDR6; +#endif + return MPTCP_PM_ADDR_ATTR_ADDR4; +} + +static int mptcp_pm_parse_pm_addr_attr(struct nlattr *tb[], + const struct nlattr *attr, + struct genl_info *info, + struct mptcp_addr_info *addr, + bool require_family) +{ + int err, addr_addr; + + if (!attr) { + GENL_SET_ERR_MSG(info, "missing address info"); + return -EINVAL; + } + + /* no validation needed - was already done via nested policy */ + err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr, + mptcp_pm_addr_policy, info->extack); + if (err) + return err; + + if (tb[MPTCP_PM_ADDR_ATTR_ID]) + addr->id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]); + + if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) { + if (!require_family) + return err; + + NL_SET_ERR_MSG_ATTR(info->extack, attr, + "missing family"); + return -EINVAL; + } + + addr->family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]); + if (addr->family != AF_INET +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + && addr->family != AF_INET6 +#endif + ) { + NL_SET_ERR_MSG_ATTR(info->extack, attr, + "unknown address family"); + return -EINVAL; + } + addr_addr = mptcp_pm_family_to_addr(addr->family); + if (!tb[addr_addr]) { + NL_SET_ERR_MSG_ATTR(info->extack, attr, + "missing address data"); + return -EINVAL; + } + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (addr->family == AF_INET6) + addr->addr6 = nla_get_in6_addr(tb[addr_addr]); + else +#endif + addr->addr.s_addr = nla_get_in_addr(tb[addr_addr]); + + if (tb[MPTCP_PM_ADDR_ATTR_PORT]) + addr->port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); + + return err; +} + +int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, + struct mptcp_addr_info *addr) +{ + struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1]; + + memset(addr, 0, sizeof(*addr)); + + return mptcp_pm_parse_pm_addr_attr(tb, attr, info, addr, true); +} + +int mptcp_pm_parse_entry(struct nlattr *attr, struct genl_info *info, + bool require_family, + struct mptcp_pm_addr_entry *entry) +{ + struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1]; + int err; + + memset(entry, 0, sizeof(*entry)); + + err = mptcp_pm_parse_pm_addr_attr(tb, attr, info, &entry->addr, require_family); + if (err) + return err; + + if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) { + u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]); + + entry->ifindex = val; + } + + if (tb[MPTCP_PM_ADDR_ATTR_FLAGS]) + entry->flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]); + + if (tb[MPTCP_PM_ADDR_ATTR_PORT]) + entry->addr.port = 
htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); + + return 0; +} + +static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info) +{ + return pm_nl_get_pernet(genl_info_net(info)); +} + +static int mptcp_nl_add_subflow_or_signal_addr(struct net *net) +{ + struct mptcp_sock *msk; + long s_slot = 0, s_num = 0; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + + if (!READ_ONCE(msk->fully_established) || + mptcp_pm_is_userspace(msk)) + goto next; + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + mptcp_pm_create_subflow_or_signal_addr(msk); + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + +next: + sock_put(sk); + cond_resched(); + } + + return 0; +} + +static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + struct mptcp_pm_addr_entry addr, *entry; + int ret; + + ret = mptcp_pm_parse_entry(attr, info, true, &addr); + if (ret < 0) + return ret; + + if (addr.addr.port && !(addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + GENL_SET_ERR_MSG(info, "flags must have signal when using port"); + return -EINVAL; + } + + if (addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL && + addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) { + GENL_SET_ERR_MSG(info, "flags mustn't have both signal and fullmesh"); + return -EINVAL; + } + + if (addr.flags & MPTCP_PM_ADDR_FLAG_IMPLICIT) { + GENL_SET_ERR_MSG(info, "can't create IMPLICIT endpoint"); + return -EINVAL; + } + + entry = kzalloc(sizeof(*entry), GFP_KERNEL_ACCOUNT); + if (!entry) { + GENL_SET_ERR_MSG(info, "can't allocate addr"); + return -ENOMEM; + } + + *entry = addr; + if (entry->addr.port) { + ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry); + if (ret) { + GENL_SET_ERR_MSG(info, "create listen socket error"); + goto out_free; + } + } + ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); + if (ret < 0) { + GENL_SET_ERR_MSG(info, "too many addresses or duplicate one"); + goto out_free; + } + + mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); + return 0; + +out_free: + __mptcp_pm_release_addr_entry(entry); + return ret; +} + +int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, unsigned int id, + u8 *flags, int *ifindex) +{ + struct mptcp_pm_addr_entry *entry; + struct sock *sk = (struct sock *)msk; + struct net *net = sock_net(sk); + + *flags = 0; + *ifindex = 0; + + if (id) { + if (mptcp_pm_is_userspace(msk)) + return mptcp_userspace_pm_get_flags_and_ifindex_by_id(msk, + id, + flags, + ifindex); + + rcu_read_lock(); + entry = __lookup_addr_by_id(pm_nl_get_pernet(net), id); + if (entry) { + *flags = entry->flags; + *ifindex = entry->ifindex; + } + rcu_read_unlock(); + } + + return 0; +} + +static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr) +{ + struct mptcp_pm_add_entry *entry; + + entry = mptcp_pm_del_add_timer(msk, addr, false); + if (entry) { + list_del(&entry->list); + kfree(entry); + return true; + } + + return false; +} + +static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr, + bool force) +{ + struct mptcp_rm_list list = { .nr = 0 }; + bool ret; + + list.ids[list.nr++] = addr->id; + + ret = remove_anno_list_by_saddr(msk, addr); + if (ret || force) { + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &list); + spin_unlock_bh(&msk->pm.lock); + } + return ret; +} + +static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, 
+ const struct mptcp_pm_addr_entry *entry) +{ + const struct mptcp_addr_info *addr = &entry->addr; + struct mptcp_rm_list list = { .nr = 0 }; + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; + + pr_debug("remove_id=%d", addr->id); + + list.ids[list.nr++] = addr->id; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + bool remove_subflow; + + if (mptcp_pm_is_userspace(msk)) + goto next; + + if (list_empty(&msk->conn_list)) { + mptcp_pm_remove_anno_addr(msk, addr, false); + goto next; + } + + lock_sock(sk); + remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); + mptcp_pm_remove_anno_addr(msk, addr, remove_subflow && + !(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT)); + if (remove_subflow) + mptcp_pm_remove_subflow(msk, &list); + release_sock(sk); + +next: + sock_put(sk); + cond_resched(); + } + + return 0; +} + +static int mptcp_nl_remove_id_zero_address(struct net *net, + struct mptcp_addr_info *addr) +{ + struct mptcp_rm_list list = { .nr = 0 }; + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; + + list.ids[list.nr++] = 0; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + struct mptcp_addr_info msk_local; + + if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk)) + goto next; + + local_address((struct sock_common *)msk, &msk_local); + if (!mptcp_addresses_equal(&msk_local, addr, addr->port)) + goto next; + + lock_sock(sk); + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &list); + mptcp_pm_nl_rm_subflow_received(msk, &list); + spin_unlock_bh(&msk->pm.lock); + release_sock(sk); + +next: + sock_put(sk); + cond_resched(); + } + + return 0; +} + +static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + struct mptcp_pm_addr_entry addr, *entry; + unsigned int addr_max; + int ret; + + ret = mptcp_pm_parse_entry(attr, info, false, &addr); + if (ret < 0) + return ret; + + /* the zero id address is special: the first address used by the msk + * always gets such an id, so different subflows can have different zero + * id addresses. Additionally zero id is not accounted for in id_bitmap. + * Let's use an 'mptcp_rm_list' instead of the common remove code. 
+ */ + if (addr.addr.id == 0) + return mptcp_nl_remove_id_zero_address(sock_net(skb->sk), &addr.addr); + + spin_lock_bh(&pernet->lock); + entry = __lookup_addr_by_id(pernet, addr.addr.id); + if (!entry) { + GENL_SET_ERR_MSG(info, "address not found"); + spin_unlock_bh(&pernet->lock); + return -EINVAL; + } + if (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { + addr_max = pernet->add_addr_signal_max; + WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1); + } + if (entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { + addr_max = pernet->local_addr_max; + WRITE_ONCE(pernet->local_addr_max, addr_max - 1); + } + + pernet->addrs--; + list_del_rcu(&entry->list); + __clear_bit(entry->addr.id, pernet->id_bitmap); + spin_unlock_bh(&pernet->lock); + + mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), entry); + synchronize_rcu(); + __mptcp_pm_release_addr_entry(entry); + + return ret; +} + +void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list) +{ + struct mptcp_rm_list alist = { .nr = 0 }; + struct mptcp_pm_addr_entry *entry; + + list_for_each_entry(entry, rm_list, list) { + if ((remove_anno_list_by_saddr(msk, &entry->addr) || + lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) && + alist.nr < MPTCP_RM_IDS_MAX) + alist.ids[alist.nr++] = entry->addr.id; + } + + if (alist.nr) { + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &alist); + spin_unlock_bh(&msk->pm.lock); + } +} + +void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, + struct list_head *rm_list) +{ + struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 }; + struct mptcp_pm_addr_entry *entry; + + list_for_each_entry(entry, rm_list, list) { + if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) && + slist.nr < MPTCP_RM_IDS_MAX) + slist.ids[slist.nr++] = entry->addr.id; + + if (remove_anno_list_by_saddr(msk, &entry->addr) && + alist.nr < MPTCP_RM_IDS_MAX) + alist.ids[alist.nr++] = entry->addr.id; + } + + if (alist.nr) { + spin_lock_bh(&msk->pm.lock); + mptcp_pm_remove_addr(msk, &alist); + spin_unlock_bh(&msk->pm.lock); + } + if (slist.nr) + mptcp_pm_remove_subflow(msk, &slist); +} + +static void mptcp_nl_remove_addrs_list(struct net *net, + struct list_head *rm_list) +{ + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; + + if (list_empty(rm_list)) + return; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + + if (!mptcp_pm_is_userspace(msk)) { + lock_sock(sk); + mptcp_pm_remove_addrs_and_subflows(msk, rm_list); + release_sock(sk); + } + + sock_put(sk); + cond_resched(); + } +} + +/* caller must ensure the RCU grace period is already elapsed */ +static void __flush_addrs(struct list_head *list) +{ + while (!list_empty(list)) { + struct mptcp_pm_addr_entry *cur; + + cur = list_entry(list->next, + struct mptcp_pm_addr_entry, list); + list_del_rcu(&cur->list); + __mptcp_pm_release_addr_entry(cur); + } +} + +static void __reset_counters(struct pm_nl_pernet *pernet) +{ + WRITE_ONCE(pernet->add_addr_signal_max, 0); + WRITE_ONCE(pernet->add_addr_accept_max, 0); + WRITE_ONCE(pernet->local_addr_max, 0); + pernet->addrs = 0; +} + +static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info) +{ + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + LIST_HEAD(free_list); + + spin_lock_bh(&pernet->lock); + list_splice_init(&pernet->local_addr_list, &free_list); + __reset_counters(pernet); + pernet->next_id = 1; + bitmap_zero(pernet->id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + spin_unlock_bh(&pernet->lock); + 
mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list); + synchronize_rcu(); + __flush_addrs(&free_list); + return 0; +} + +static int mptcp_nl_fill_addr(struct sk_buff *skb, + struct mptcp_pm_addr_entry *entry) +{ + struct mptcp_addr_info *addr = &entry->addr; + struct nlattr *attr; + + attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR); + if (!attr) + return -EMSGSIZE; + + if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) + goto nla_put_failure; + if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port))) + goto nla_put_failure; + if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) + goto nla_put_failure; + if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->flags)) + goto nla_put_failure; + if (entry->ifindex && + nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->ifindex)) + goto nla_put_failure; + + if (addr->family == AF_INET && + nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4, + addr->addr.s_addr)) + goto nla_put_failure; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (addr->family == AF_INET6 && + nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6)) + goto nla_put_failure; +#endif + nla_nest_end(skb, attr); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, attr); + return -EMSGSIZE; +} + +static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + struct mptcp_pm_addr_entry addr, *entry; + struct sk_buff *msg; + void *reply; + int ret; + + ret = mptcp_pm_parse_entry(attr, info, false, &addr); + if (ret < 0) + return ret; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, + info->genlhdr->cmd); + if (!reply) { + GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); + ret = -EMSGSIZE; + goto fail; + } + + spin_lock_bh(&pernet->lock); + entry = __lookup_addr_by_id(pernet, addr.addr.id); + if (!entry) { + GENL_SET_ERR_MSG(info, "address not found"); + ret = -EINVAL; + goto unlock_fail; + } + + ret = mptcp_nl_fill_addr(msg, entry); + if (ret) + goto unlock_fail; + + genlmsg_end(msg, reply); + ret = genlmsg_reply(msg, info); + spin_unlock_bh(&pernet->lock); + return ret; + +unlock_fail: + spin_unlock_bh(&pernet->lock); + +fail: + nlmsg_free(msg); + return ret; +} + +static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct net *net = sock_net(msg->sk); + struct mptcp_pm_addr_entry *entry; + struct pm_nl_pernet *pernet; + int id = cb->args[0]; + void *hdr; + int i; + + pernet = pm_nl_get_pernet(net); + + spin_lock_bh(&pernet->lock); + for (i = id; i < MPTCP_PM_MAX_ADDR_ID + 1; i++) { + if (test_bit(i, pernet->id_bitmap)) { + entry = __lookup_addr_by_id(pernet, i); + if (!entry) + break; + + if (entry->addr.id <= id) + continue; + + hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, &mptcp_genl_family, + NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); + if (!hdr) + break; + + if (mptcp_nl_fill_addr(msg, entry) < 0) { + genlmsg_cancel(msg, hdr); + break; + } + + id = entry->addr.id; + genlmsg_end(msg, hdr); + } + } + spin_unlock_bh(&pernet->lock); + + cb->args[0] = id; + return msg->len; +} + +static int parse_limit(struct genl_info *info, int id, unsigned int *limit) +{ + struct nlattr *attr = info->attrs[id]; + + if (!attr) + return 0; + + *limit = nla_get_u32(attr); + if (*limit > MPTCP_PM_ADDR_MAX) { + GENL_SET_ERR_MSG(info, "limit greater than maximum"); + return 
-EINVAL; + } + return 0; +} + +static int +mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) +{ + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + unsigned int rcv_addrs, subflows; + int ret; + + spin_lock_bh(&pernet->lock); + rcv_addrs = pernet->add_addr_accept_max; + ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); + if (ret) + goto unlock; + + subflows = pernet->subflows_max; + ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); + if (ret) + goto unlock; + + WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); + WRITE_ONCE(pernet->subflows_max, subflows); + +unlock: + spin_unlock_bh(&pernet->lock); + return ret; +} + +static int +mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info) +{ + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + struct sk_buff *msg; + void *reply; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, + MPTCP_PM_CMD_GET_LIMITS); + if (!reply) + goto fail; + + if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS, + READ_ONCE(pernet->add_addr_accept_max))) + goto fail; + + if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS, + READ_ONCE(pernet->subflows_max))) + goto fail; + + genlmsg_end(msg, reply); + return genlmsg_reply(msg, info); + +fail: + GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); + nlmsg_free(msg); + return -EMSGSIZE; +} + +static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk, + struct mptcp_addr_info *addr) +{ + struct mptcp_rm_list list = { .nr = 0 }; + + list.ids[list.nr++] = addr->id; + + spin_lock_bh(&msk->pm.lock); + mptcp_pm_nl_rm_subflow_received(msk, &list); + mptcp_pm_create_subflow_or_signal_addr(msk); + spin_unlock_bh(&msk->pm.lock); +} + +static int mptcp_nl_set_flags(struct net *net, + struct mptcp_addr_info *addr, + u8 bkup, u8 changed) +{ + long s_slot = 0, s_num = 0; + struct mptcp_sock *msk; + int ret = -EINVAL; + + while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { + struct sock *sk = (struct sock *)msk; + + if (list_empty(&msk->conn_list) || mptcp_pm_is_userspace(msk)) + goto next; + + lock_sock(sk); + if (changed & MPTCP_PM_ADDR_FLAG_BACKUP) + ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, NULL, bkup); + if (changed & MPTCP_PM_ADDR_FLAG_FULLMESH) + mptcp_pm_nl_fullmesh(msk, addr); + release_sock(sk); + +next: + sock_put(sk); + cond_resched(); + } + + return ret; +} + +static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) +{ + struct mptcp_pm_addr_entry addr = { .addr = { .family = AF_UNSPEC }, }, *entry; + struct mptcp_pm_addr_entry remote = { .addr = { .family = AF_UNSPEC }, }; + struct nlattr *attr_rem = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct pm_nl_pernet *pernet = genl_info_pm_nl(info); + u8 changed, mask = MPTCP_PM_ADDR_FLAG_BACKUP | + MPTCP_PM_ADDR_FLAG_FULLMESH; + struct net *net = sock_net(skb->sk); + u8 bkup = 0, lookup_by_id = 0; + int ret; + + ret = mptcp_pm_parse_entry(attr, info, false, &addr); + if (ret < 0) + return ret; + + if (attr_rem) { + ret = mptcp_pm_parse_entry(attr_rem, info, false, &remote); + if (ret < 0) + return ret; + } + + if (addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) + bkup = 1; + if (addr.addr.family == AF_UNSPEC) { + lookup_by_id = 1; + if (!addr.addr.id) + return -EOPNOTSUPP; + } + + if (token) + return mptcp_userspace_pm_set_flags(sock_net(skb->sk), + token, &addr, 
&remote, bkup); + + spin_lock_bh(&pernet->lock); + entry = __lookup_addr(pernet, &addr.addr, lookup_by_id); + if (!entry) { + spin_unlock_bh(&pernet->lock); + return -EINVAL; + } + if ((addr.flags & MPTCP_PM_ADDR_FLAG_FULLMESH) && + (entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + spin_unlock_bh(&pernet->lock); + return -EINVAL; + } + + changed = (addr.flags ^ entry->flags) & mask; + entry->flags = (entry->flags & ~mask) | (addr.flags & mask); + addr = *entry; + spin_unlock_bh(&pernet->lock); + + mptcp_nl_set_flags(net, &addr.addr, bkup, changed); + return 0; +} + +static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp) +{ + genlmsg_multicast_netns(&mptcp_genl_family, net, + nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp); +} + +bool mptcp_userspace_pm_active(const struct mptcp_sock *msk) +{ + return genl_has_listeners(&mptcp_genl_family, + sock_net((const struct sock *)msk), + MPTCP_PM_EV_GRP_OFFSET); +} + +static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk) +{ + const struct inet_sock *issk = inet_sk(ssk); + const struct mptcp_subflow_context *sf; + + if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family)) + return -EMSGSIZE; + + switch (ssk->sk_family) { + case AF_INET: + if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr)) + return -EMSGSIZE; + if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr)) + return -EMSGSIZE; + break; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + case AF_INET6: { + const struct ipv6_pinfo *np = inet6_sk(ssk); + + if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr)) + return -EMSGSIZE; + if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr)) + return -EMSGSIZE; + break; + } +#endif + default: + WARN_ON_ONCE(1); + return -EMSGSIZE; + } + + if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport)) + return -EMSGSIZE; + if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport)) + return -EMSGSIZE; + + sf = mptcp_subflow_ctx(ssk); + if (WARN_ON_ONCE(!sf)) + return -EINVAL; + + if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id)) + return -EMSGSIZE; + + if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id)) + return -EMSGSIZE; + + return 0; +} + +static int mptcp_event_put_token_and_ssk(struct sk_buff *skb, + const struct mptcp_sock *msk, + const struct sock *ssk) +{ + const struct sock *sk = (const struct sock *)msk; + const struct mptcp_subflow_context *sf; + u8 sk_err; + + if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) + return -EMSGSIZE; + + if (mptcp_event_add_subflow(skb, ssk)) + return -EMSGSIZE; + + sf = mptcp_subflow_ctx(ssk); + if (WARN_ON_ONCE(!sf)) + return -EINVAL; + + if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup)) + return -EMSGSIZE; + + if (ssk->sk_bound_dev_if && + nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if)) + return -EMSGSIZE; + + sk_err = READ_ONCE(ssk->sk_err); + if (sk_err && sk->sk_state == TCP_ESTABLISHED && + nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err)) + return -EMSGSIZE; + + return 0; +} + +static int mptcp_event_sub_established(struct sk_buff *skb, + const struct mptcp_sock *msk, + const struct sock *ssk) +{ + return mptcp_event_put_token_and_ssk(skb, msk, ssk); +} + +static int mptcp_event_sub_closed(struct sk_buff *skb, + const struct mptcp_sock *msk, + const struct sock *ssk) +{ + const struct mptcp_subflow_context *sf; + + if (mptcp_event_put_token_and_ssk(skb, msk, ssk)) + return -EMSGSIZE; + + sf = mptcp_subflow_ctx(ssk); + if (!sf->reset_seen) + return 0; + + if (nla_put_u32(skb, MPTCP_ATTR_RESET_REASON, sf->reset_reason)) + return 
-EMSGSIZE; + + if (nla_put_u32(skb, MPTCP_ATTR_RESET_FLAGS, sf->reset_transient)) + return -EMSGSIZE; + + return 0; +} + +static int mptcp_event_created(struct sk_buff *skb, + const struct mptcp_sock *msk, + const struct sock *ssk) +{ + int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token); + + if (err) + return err; + + if (nla_put_u8(skb, MPTCP_ATTR_SERVER_SIDE, READ_ONCE(msk->pm.server_side))) + return -EMSGSIZE; + + return mptcp_event_add_subflow(skb, ssk); +} + +void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id) +{ + struct net *net = sock_net((const struct sock *)msk); + struct nlmsghdr *nlh; + struct sk_buff *skb; + + if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) + return; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return; + + nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED); + if (!nlh) + goto nla_put_failure; + + if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) + goto nla_put_failure; + + if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id)) + goto nla_put_failure; + + genlmsg_end(skb, nlh); + mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); + return; + +nla_put_failure: + kfree_skb(skb); +} + +void mptcp_event_addr_announced(const struct sock *ssk, + const struct mptcp_addr_info *info) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct net *net = sock_net(ssk); + struct nlmsghdr *nlh; + struct sk_buff *skb; + + if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) + return; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return; + + nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, + MPTCP_EVENT_ANNOUNCED); + if (!nlh) + goto nla_put_failure; + + if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) + goto nla_put_failure; + + if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id)) + goto nla_put_failure; + + if (nla_put_be16(skb, MPTCP_ATTR_DPORT, + info->port == 0 ? 
+ inet_sk(ssk)->inet_dport : + info->port)) + goto nla_put_failure; + + switch (info->family) { + case AF_INET: + if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr)) + goto nla_put_failure; + break; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + case AF_INET6: + if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6)) + goto nla_put_failure; + break; +#endif + default: + WARN_ON_ONCE(1); + goto nla_put_failure; + } + + genlmsg_end(skb, nlh); + mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); + return; + +nla_put_failure: + kfree_skb(skb); +} + +void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, + const struct sock *ssk, gfp_t gfp) +{ + struct net *net = sock_net((const struct sock *)msk); + struct nlmsghdr *nlh; + struct sk_buff *skb; + + if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) + return; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!skb) + return; + + nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type); + if (!nlh) + goto nla_put_failure; + + switch (type) { + case MPTCP_EVENT_UNSPEC: + WARN_ON_ONCE(1); + break; + case MPTCP_EVENT_CREATED: + case MPTCP_EVENT_ESTABLISHED: + if (mptcp_event_created(skb, msk, ssk) < 0) + goto nla_put_failure; + break; + case MPTCP_EVENT_CLOSED: + if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0) + goto nla_put_failure; + break; + case MPTCP_EVENT_ANNOUNCED: + case MPTCP_EVENT_REMOVED: + /* call mptcp_event_addr_announced()/removed instead */ + WARN_ON_ONCE(1); + break; + case MPTCP_EVENT_SUB_ESTABLISHED: + case MPTCP_EVENT_SUB_PRIORITY: + if (mptcp_event_sub_established(skb, msk, ssk) < 0) + goto nla_put_failure; + break; + case MPTCP_EVENT_SUB_CLOSED: + if (mptcp_event_sub_closed(skb, msk, ssk) < 0) + goto nla_put_failure; + break; + } + + genlmsg_end(skb, nlh); + mptcp_nl_mcast_send(net, skb, gfp); + return; + +nla_put_failure: + kfree_skb(skb); +} + +static const struct genl_small_ops mptcp_pm_ops[] = { + { + .cmd = MPTCP_PM_CMD_ADD_ADDR, + .doit = mptcp_nl_cmd_add_addr, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_DEL_ADDR, + .doit = mptcp_nl_cmd_del_addr, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_FLUSH_ADDRS, + .doit = mptcp_nl_cmd_flush_addrs, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_GET_ADDR, + .doit = mptcp_nl_cmd_get_addr, + .dumpit = mptcp_nl_cmd_dump_addrs, + }, + { + .cmd = MPTCP_PM_CMD_SET_LIMITS, + .doit = mptcp_nl_cmd_set_limits, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_GET_LIMITS, + .doit = mptcp_nl_cmd_get_limits, + }, + { + .cmd = MPTCP_PM_CMD_SET_FLAGS, + .doit = mptcp_nl_cmd_set_flags, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_ANNOUNCE, + .doit = mptcp_nl_cmd_announce, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_REMOVE, + .doit = mptcp_nl_cmd_remove, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_SUBFLOW_CREATE, + .doit = mptcp_nl_cmd_sf_create, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = MPTCP_PM_CMD_SUBFLOW_DESTROY, + .doit = mptcp_nl_cmd_sf_destroy, + .flags = GENL_UNS_ADMIN_PERM, + }, +}; + +static struct genl_family mptcp_genl_family __ro_after_init = { + .name = MPTCP_PM_NAME, + .version = MPTCP_PM_VER, + .maxattr = MPTCP_PM_ATTR_MAX, + .policy = mptcp_pm_policy, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = mptcp_pm_ops, + .n_small_ops = ARRAY_SIZE(mptcp_pm_ops), + .resv_start_op = MPTCP_PM_CMD_SUBFLOW_DESTROY + 1, + .mcgrps = mptcp_pm_mcgrps, + .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps), +}; + +static int 
__net_init pm_nl_init_net(struct net *net) +{ + struct pm_nl_pernet *pernet = pm_nl_get_pernet(net); + + INIT_LIST_HEAD_RCU(&pernet->local_addr_list); + + /* Cit. 2 subflows ought to be enough for anybody. */ + pernet->subflows_max = 2; + pernet->next_id = 1; + pernet->stale_loss_cnt = 4; + spin_lock_init(&pernet->lock); + + /* No need to initialize other pernet fields, the struct is zeroed at + * allocation time. + */ + + return 0; +} + +static void __net_exit pm_nl_exit_net(struct list_head *net_list) +{ + struct net *net; + + list_for_each_entry(net, net_list, exit_list) { + struct pm_nl_pernet *pernet = pm_nl_get_pernet(net); + + /* net is removed from namespace list, can't race with + * other modifiers, also netns core already waited for a + * RCU grace period. + */ + __flush_addrs(&pernet->local_addr_list); + } +} + +static struct pernet_operations mptcp_pm_pernet_ops = { + .init = pm_nl_init_net, + .exit_batch = pm_nl_exit_net, + .id = &pm_nl_pernet_id, + .size = sizeof(struct pm_nl_pernet), +}; + +void __init mptcp_pm_nl_init(void) +{ + if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) + panic("Failed to register MPTCP PM pernet subsystem.\n"); + + if (genl_register_family(&mptcp_genl_family)) + panic("Failed to register MPTCP PM netlink family\n"); +} diff --git a/net/mptcp/pm_userspace.c b/net/mptcp/pm_userspace.c new file mode 100644 index 000000000..38cbdc66d --- /dev/null +++ b/net/mptcp/pm_userspace.c @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2022, Intel Corporation. + */ + +#include "protocol.h" +#include "mib.h" + +void mptcp_free_local_addr_list(struct mptcp_sock *msk) +{ + struct mptcp_pm_addr_entry *entry, *tmp; + struct sock *sk = (struct sock *)msk; + LIST_HEAD(free_list); + + if (!mptcp_pm_is_userspace(msk)) + return; + + spin_lock_bh(&msk->pm.lock); + list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list); + spin_unlock_bh(&msk->pm.lock); + + list_for_each_entry_safe(entry, tmp, &free_list, list) { + sock_kfree_s(sk, entry, sizeof(*entry)); + } +} + +int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry) +{ + DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + struct mptcp_pm_addr_entry *match = NULL; + struct sock *sk = (struct sock *)msk; + struct mptcp_pm_addr_entry *e; + bool addr_match = false; + bool id_match = false; + int ret = -EINVAL; + + bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + + spin_lock_bh(&msk->pm.lock); + list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) { + addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true); + if (addr_match && entry->addr.id == 0) + entry->addr.id = e->addr.id; + id_match = (e->addr.id == entry->addr.id); + if (addr_match && id_match) { + match = e; + break; + } else if (addr_match || id_match) { + break; + } + __set_bit(e->addr.id, id_bitmap); + } + + if (!match && !addr_match && !id_match) { + /* Memory for the entry is allocated from the + * sock option buffer. 
+ */ + e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC); + if (!e) { + spin_unlock_bh(&msk->pm.lock); + return -ENOMEM; + } + + *e = *entry; + if (!e->addr.id) + e->addr.id = find_next_zero_bit(id_bitmap, + MPTCP_PM_MAX_ADDR_ID + 1, + 1); + list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list); + msk->pm.local_addr_used++; + ret = e->addr.id; + } else if (match) { + ret = entry->addr.id; + } + + spin_unlock_bh(&msk->pm.lock); + return ret; +} + +/* If the subflow is closed from the other peer (not via a + * subflow destroy command then), we want to keep the entry + * not to assign the same ID to another address and to be + * able to send RM_ADDR after the removal of the subflow. + */ +static int mptcp_userspace_pm_delete_local_addr(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *addr) +{ + struct mptcp_pm_addr_entry *entry, *tmp; + + list_for_each_entry_safe(entry, tmp, &msk->pm.userspace_pm_local_addr_list, list) { + if (mptcp_addresses_equal(&entry->addr, &addr->addr, false)) { + /* TODO: a refcount is needed because the entry can + * be used multiple times (e.g. fullmesh mode). + */ + list_del_rcu(&entry->list); + kfree(entry); + msk->pm.local_addr_used--; + return 0; + } + } + + return -EINVAL; +} + +int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, + unsigned int id, + u8 *flags, int *ifindex) +{ + struct mptcp_pm_addr_entry *entry, *match = NULL; + + *flags = 0; + *ifindex = 0; + + spin_lock_bh(&msk->pm.lock); + list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { + if (id == entry->addr.id) { + match = entry; + break; + } + } + spin_unlock_bh(&msk->pm.lock); + if (match) { + *flags = match->flags; + *ifindex = match->ifindex; + } + + return 0; +} + +int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, + struct mptcp_addr_info *skc) +{ + struct mptcp_pm_addr_entry new_entry; + __be16 msk_sport = ((struct inet_sock *) + inet_sk((struct sock *)msk))->inet_sport; + + memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry)); + new_entry.addr = *skc; + new_entry.addr.id = 0; + new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT; + + if (new_entry.addr.port == msk_sport) + new_entry.addr.port = 0; + + return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry); +} + +int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct mptcp_pm_addr_entry addr_val; + struct mptcp_sock *msk; + int err = -EINVAL; + u32 token_val; + + if (!addr || !token) { + GENL_SET_ERR_MSG(info, "missing required inputs"); + return err; + } + + token_val = nla_get_u32(token); + + msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return err; + } + + if (!mptcp_pm_is_userspace(msk)) { + GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); + goto announce_err; + } + + err = mptcp_pm_parse_entry(addr, info, true, &addr_val); + if (err < 0) { + GENL_SET_ERR_MSG(info, "error parsing local address"); + goto announce_err; + } + + if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) { + GENL_SET_ERR_MSG(info, "invalid addr id or flags"); + err = -EINVAL; + goto announce_err; + } + + err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val); + if (err < 0) { + GENL_SET_ERR_MSG(info, "did not match address and id"); + goto announce_err; + } + + lock_sock((struct sock *)msk); + 
spin_lock_bh(&msk->pm.lock); + + if (mptcp_pm_alloc_anno_list(msk, &addr_val)) { + msk->pm.add_addr_signaled++; + mptcp_pm_announce_addr(msk, &addr_val.addr, false); + mptcp_pm_nl_addr_send_ack(msk); + } + + spin_unlock_bh(&msk->pm.lock); + release_sock((struct sock *)msk); + + err = 0; + announce_err: + sock_put((struct sock *)msk); + return err; +} + +int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID]; + struct mptcp_pm_addr_entry *match = NULL; + struct mptcp_pm_addr_entry *entry; + struct mptcp_sock *msk; + LIST_HEAD(free_list); + int err = -EINVAL; + u32 token_val; + u8 id_val; + + if (!id || !token) { + GENL_SET_ERR_MSG(info, "missing required inputs"); + return err; + } + + id_val = nla_get_u8(id); + token_val = nla_get_u32(token); + + msk = mptcp_token_get_sock(sock_net(skb->sk), token_val); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return err; + } + + if (!mptcp_pm_is_userspace(msk)) { + GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); + goto remove_err; + } + + lock_sock((struct sock *)msk); + + list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) { + if (entry->addr.id == id_val) { + match = entry; + break; + } + } + + if (!match) { + GENL_SET_ERR_MSG(info, "address with specified id not found"); + release_sock((struct sock *)msk); + goto remove_err; + } + + list_move(&match->list, &free_list); + + mptcp_pm_remove_addrs(msk, &free_list); + + release_sock((struct sock *)msk); + + list_for_each_entry_safe(match, entry, &free_list, list) { + sock_kfree_s((struct sock *)msk, match, sizeof(*match)); + } + + err = 0; + remove_err: + sock_put((struct sock *)msk); + return err; +} + +int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct mptcp_pm_addr_entry local = { 0 }; + struct mptcp_addr_info addr_r; + struct mptcp_addr_info addr_l; + struct mptcp_sock *msk; + int err = -EINVAL; + struct sock *sk; + u32 token_val; + + if (!laddr || !raddr || !token) { + GENL_SET_ERR_MSG(info, "missing required inputs"); + return err; + } + + token_val = nla_get_u32(token); + + msk = mptcp_token_get_sock(genl_info_net(info), token_val); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return err; + } + + if (!mptcp_pm_is_userspace(msk)) { + GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); + goto create_err; + } + + err = mptcp_pm_parse_addr(laddr, info, &addr_l); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); + goto create_err; + } + + err = mptcp_pm_parse_addr(raddr, info, &addr_r); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr"); + goto create_err; + } + + sk = &msk->sk.icsk_inet.sk; + + if (!mptcp_pm_addr_families_match(sk, &addr_l, &addr_r)) { + GENL_SET_ERR_MSG(info, "families mismatch"); + err = -EINVAL; + goto create_err; + } + + local.addr = addr_l; + err = mptcp_userspace_pm_append_new_local_addr(msk, &local); + if (err < 0) { + GENL_SET_ERR_MSG(info, "did not match address and id"); + goto create_err; + } + + lock_sock(sk); + + err = __mptcp_subflow_connect(sk, &addr_l, &addr_r); + + release_sock(sk); + + spin_lock_bh(&msk->pm.lock); + if (err) + 
mptcp_userspace_pm_delete_local_addr(msk, &local); + else + msk->pm.subflows++; + spin_unlock_bh(&msk->pm.lock); + + create_err: + sock_put((struct sock *)msk); + return err; +} + +static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk, + const struct mptcp_addr_info *local, + const struct mptcp_addr_info *remote) +{ + struct mptcp_subflow_context *subflow; + + if (local->family != remote->family) + return NULL; + + mptcp_for_each_subflow(msk, subflow) { + const struct inet_sock *issk; + struct sock *ssk; + + ssk = mptcp_subflow_tcp_sock(subflow); + + if (local->family != ssk->sk_family) + continue; + + issk = inet_sk(ssk); + + switch (ssk->sk_family) { + case AF_INET: + if (issk->inet_saddr != local->addr.s_addr || + issk->inet_daddr != remote->addr.s_addr) + continue; + break; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + case AF_INET6: { + const struct ipv6_pinfo *pinfo = inet6_sk(ssk); + + if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) || + !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr)) + continue; + break; + } +#endif + default: + continue; + } + + if (issk->inet_sport == local->port && + issk->inet_dport == remote->port) + return ssk; + } + + return NULL; +} + +int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE]; + struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN]; + struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR]; + struct mptcp_addr_info addr_l; + struct mptcp_addr_info addr_r; + struct mptcp_sock *msk; + struct sock *sk, *ssk; + int err = -EINVAL; + u32 token_val; + + if (!laddr || !raddr || !token) { + GENL_SET_ERR_MSG(info, "missing required inputs"); + return err; + } + + token_val = nla_get_u32(token); + + msk = mptcp_token_get_sock(genl_info_net(info), token_val); + if (!msk) { + NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token"); + return err; + } + + if (!mptcp_pm_is_userspace(msk)) { + GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected"); + goto destroy_err; + } + + err = mptcp_pm_parse_addr(laddr, info, &addr_l); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr"); + goto destroy_err; + } + + err = mptcp_pm_parse_addr(raddr, info, &addr_r); + if (err < 0) { + NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr"); + goto destroy_err; + } + + if (addr_l.family != addr_r.family) { + GENL_SET_ERR_MSG(info, "address families do not match"); + err = -EINVAL; + goto destroy_err; + } + + if (!addr_l.port || !addr_r.port) { + GENL_SET_ERR_MSG(info, "missing local or remote port"); + err = -EINVAL; + goto destroy_err; + } + + sk = &msk->sk.icsk_inet.sk; + lock_sock(sk); + ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r); + if (ssk) { + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct mptcp_pm_addr_entry entry = { .addr = addr_l }; + + spin_lock_bh(&msk->pm.lock); + mptcp_userspace_pm_delete_local_addr(msk, &entry); + spin_unlock_bh(&msk->pm.lock); + mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN); + mptcp_close_ssk(sk, ssk, subflow); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); + err = 0; + } else { + err = -ESRCH; + } + release_sock(sk); + +destroy_err: + sock_put((struct sock *)msk); + return err; +} + +int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token, + struct mptcp_pm_addr_entry *loc, + struct mptcp_pm_addr_entry *rem, u8 bkup) +{ + struct mptcp_sock *msk; + int ret = -EINVAL; + u32 token_val; + + token_val = nla_get_u32(token); + + 
msk = mptcp_token_get_sock(net, token_val); + if (!msk) + return ret; + + if (!mptcp_pm_is_userspace(msk)) + goto set_flags_err; + + if (loc->addr.family == AF_UNSPEC || + rem->addr.family == AF_UNSPEC) + goto set_flags_err; + + lock_sock((struct sock *)msk); + ret = mptcp_pm_nl_mp_prio_send_ack(msk, &loc->addr, &rem->addr, bkup); + release_sock((struct sock *)msk); + +set_flags_err: + sock_put((struct sock *)msk); + return ret; +} diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c new file mode 100644 index 000000000..76539d100 --- /dev/null +++ b/net/mptcp/protocol.c @@ -0,0 +1,4099 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2017 - 2019, Intel Corporation. + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +#include +#endif +#include +#include +#include +#include "protocol.h" +#include "mib.h" + +#define CREATE_TRACE_POINTS +#include + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +struct mptcp6_sock { + struct mptcp_sock msk; + struct ipv6_pinfo np; +}; +#endif + +struct mptcp_skb_cb { + u64 map_seq; + u64 end_seq; + u32 offset; + u8 has_rxtstamp:1; +}; + +#define MPTCP_SKB_CB(__skb) ((struct mptcp_skb_cb *)&((__skb)->cb[0])) + +enum { + MPTCP_CMSG_TS = BIT(0), + MPTCP_CMSG_INQ = BIT(1), +}; + +static struct percpu_counter mptcp_sockets_allocated ____cacheline_aligned_in_smp; + +static void __mptcp_destroy_sock(struct sock *sk); +static void mptcp_check_send_data_fin(struct sock *sk); + +DEFINE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions); +static struct net_device mptcp_napi_dev; + +/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not + * completed yet or has failed, return the subflow socket. + * Otherwise return NULL. + */ +struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) +{ + if (!msk->subflow || READ_ONCE(msk->can_ack)) + return NULL; + + return msk->subflow; +} + +/* Returns end sequence number of the receiver's advertised window */ +static u64 mptcp_wnd_end(const struct mptcp_sock *msk) +{ + return READ_ONCE(msk->wnd_end); +} + +static bool mptcp_is_tcpsk(struct sock *sk) +{ + struct socket *sock = sk->sk_socket; + + if (unlikely(sk->sk_prot == &tcp_prot)) { + /* we are being invoked after mptcp_accept() has + * accepted a non-mp-capable flow: sk is a tcp_sk, + * not an mptcp one. + * + * Hand the socket over to tcp so all further socket ops + * bypass mptcp. 
+ */ + sock->ops = &inet_stream_ops; + return true; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + } else if (unlikely(sk->sk_prot == &tcpv6_prot)) { + sock->ops = &inet6_stream_ops; + return true; +#endif + } + + return false; +} + +static int __mptcp_socket_create(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + struct socket *ssock; + int err; + + err = mptcp_subflow_create_socket(sk, sk->sk_family, &ssock); + if (err) + return err; + + WRITE_ONCE(msk->first, ssock->sk); + WRITE_ONCE(msk->subflow, ssock); + subflow = mptcp_subflow_ctx(ssock->sk); + list_add(&subflow->node, &msk->conn_list); + sock_hold(ssock->sk); + subflow->request_mptcp = 1; + + /* This is the first subflow, always with id 0 */ + subflow->local_id_valid = 1; + mptcp_sock_graft(msk->first, sk->sk_socket); + + return 0; +} + +static void mptcp_drop(struct sock *sk, struct sk_buff *skb) +{ + sk_drops_add(sk, skb); + __kfree_skb(skb); +} + +static void mptcp_rmem_fwd_alloc_add(struct sock *sk, int size) +{ + WRITE_ONCE(mptcp_sk(sk)->rmem_fwd_alloc, + mptcp_sk(sk)->rmem_fwd_alloc + size); +} + +static void mptcp_rmem_charge(struct sock *sk, int size) +{ + mptcp_rmem_fwd_alloc_add(sk, -size); +} + +static bool mptcp_try_coalesce(struct sock *sk, struct sk_buff *to, + struct sk_buff *from) +{ + bool fragstolen; + int delta; + + if (MPTCP_SKB_CB(from)->offset || + !skb_try_coalesce(to, from, &fragstolen, &delta)) + return false; + + pr_debug("colesced seq %llx into %llx new len %d new end seq %llx", + MPTCP_SKB_CB(from)->map_seq, MPTCP_SKB_CB(to)->map_seq, + to->len, MPTCP_SKB_CB(from)->end_seq); + MPTCP_SKB_CB(to)->end_seq = MPTCP_SKB_CB(from)->end_seq; + + /* note the fwd memory can reach a negative value after accounting + * for the delta, but the later skb free will restore a non + * negative one + */ + atomic_add(delta, &sk->sk_rmem_alloc); + mptcp_rmem_charge(sk, delta); + kfree_skb_partial(from, fragstolen); + + return true; +} + +static bool mptcp_ooo_try_coalesce(struct mptcp_sock *msk, struct sk_buff *to, + struct sk_buff *from) +{ + if (MPTCP_SKB_CB(from)->map_seq != MPTCP_SKB_CB(to)->end_seq) + return false; + + return mptcp_try_coalesce((struct sock *)msk, to, from); +} + +static void __mptcp_rmem_reclaim(struct sock *sk, int amount) +{ + amount >>= PAGE_SHIFT; + mptcp_rmem_charge(sk, amount << PAGE_SHIFT); + __sk_mem_reduce_allocated(sk, amount); +} + +static void mptcp_rmem_uncharge(struct sock *sk, int size) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + int reclaimable; + + mptcp_rmem_fwd_alloc_add(sk, size); + reclaimable = msk->rmem_fwd_alloc - sk_unused_reserved_mem(sk); + + /* see sk_mem_uncharge() for the rationale behind the following schema */ + if (unlikely(reclaimable >= PAGE_SIZE)) + __mptcp_rmem_reclaim(sk, reclaimable); +} + +static void mptcp_rfree(struct sk_buff *skb) +{ + unsigned int len = skb->truesize; + struct sock *sk = skb->sk; + + atomic_sub(len, &sk->sk_rmem_alloc); + mptcp_rmem_uncharge(sk, len); +} + +static void mptcp_set_owner_r(struct sk_buff *skb, struct sock *sk) +{ + skb_orphan(skb); + skb->sk = sk; + skb->destructor = mptcp_rfree; + atomic_add(skb->truesize, &sk->sk_rmem_alloc); + mptcp_rmem_charge(sk, skb->truesize); +} + +/* "inspired" by tcp_data_queue_ofo(), main differences: + * - use mptcp seqs + * - don't cope with sacks + */ +static void mptcp_data_queue_ofo(struct mptcp_sock *msk, struct sk_buff *skb) +{ + struct sock *sk = (struct sock *)msk; + struct rb_node **p, *parent; + u64 seq, end_seq, max_seq; + struct sk_buff 
*skb1; + + seq = MPTCP_SKB_CB(skb)->map_seq; + end_seq = MPTCP_SKB_CB(skb)->end_seq; + max_seq = atomic64_read(&msk->rcv_wnd_sent); + + pr_debug("msk=%p seq=%llx limit=%llx empty=%d", msk, seq, max_seq, + RB_EMPTY_ROOT(&msk->out_of_order_queue)); + if (after64(end_seq, max_seq)) { + /* out of window */ + mptcp_drop(sk, skb); + pr_debug("oow by %lld, rcv_wnd_sent %llu\n", + (unsigned long long)end_seq - (unsigned long)max_seq, + (unsigned long long)atomic64_read(&msk->rcv_wnd_sent)); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_NODSSWINDOW); + return; + } + + p = &msk->out_of_order_queue.rb_node; + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUE); + if (RB_EMPTY_ROOT(&msk->out_of_order_queue)) { + rb_link_node(&skb->rbnode, NULL, p); + rb_insert_color(&skb->rbnode, &msk->out_of_order_queue); + msk->ooo_last_skb = skb; + goto end; + } + + /* with 2 subflows, adding at end of ooo queue is quite likely + * Use of ooo_last_skb avoids the O(Log(N)) rbtree lookup. + */ + if (mptcp_ooo_try_coalesce(msk, msk->ooo_last_skb, skb)) { + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOMERGE); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUETAIL); + return; + } + + /* Can avoid an rbtree lookup if we are adding skb after ooo_last_skb */ + if (!before64(seq, MPTCP_SKB_CB(msk->ooo_last_skb)->end_seq)) { + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOQUEUETAIL); + parent = &msk->ooo_last_skb->rbnode; + p = &parent->rb_right; + goto insert; + } + + /* Find place to insert this segment. Handle overlaps on the way. */ + parent = NULL; + while (*p) { + parent = *p; + skb1 = rb_to_skb(parent); + if (before64(seq, MPTCP_SKB_CB(skb1)->map_seq)) { + p = &parent->rb_left; + continue; + } + if (before64(seq, MPTCP_SKB_CB(skb1)->end_seq)) { + if (!after64(end_seq, MPTCP_SKB_CB(skb1)->end_seq)) { + /* All the bits are present. Drop. */ + mptcp_drop(sk, skb); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA); + return; + } + if (after64(seq, MPTCP_SKB_CB(skb1)->map_seq)) { + /* partial overlap: + * | skb | + * | skb1 | + * continue traversing + */ + } else { + /* skb's seq == skb1's seq and skb covers skb1. + * Replace skb1 with skb. + */ + rb_replace_node(&skb1->rbnode, &skb->rbnode, + &msk->out_of_order_queue); + mptcp_drop(sk, skb1); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA); + goto merge_right; + } + } else if (mptcp_ooo_try_coalesce(msk, skb1, skb)) { + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_OFOMERGE); + return; + } + p = &parent->rb_right; + } + +insert: + /* Insert segment into RB tree. */ + rb_link_node(&skb->rbnode, parent, p); + rb_insert_color(&skb->rbnode, &msk->out_of_order_queue); + +merge_right: + /* Remove other segments covered by skb. */ + while ((skb1 = skb_rb_next(skb)) != NULL) { + if (before64(end_seq, MPTCP_SKB_CB(skb1)->end_seq)) + break; + rb_erase(&skb1->rbnode, &msk->out_of_order_queue); + mptcp_drop(sk, skb1); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA); + } + /* If there is no skb after us, we are the last_skb ! 
*/ + if (!skb1) + msk->ooo_last_skb = skb; + +end: + skb_condense(skb); + mptcp_set_owner_r(skb, sk); +} + +static bool mptcp_rmem_schedule(struct sock *sk, struct sock *ssk, int size) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + int amt, amount; + + if (size <= msk->rmem_fwd_alloc) + return true; + + size -= msk->rmem_fwd_alloc; + amt = sk_mem_pages(size); + amount = amt << PAGE_SHIFT; + if (!__sk_mem_raise_allocated(sk, size, amt, SK_MEM_RECV)) + return false; + + mptcp_rmem_fwd_alloc_add(sk, amount); + return true; +} + +static bool __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk, + struct sk_buff *skb, unsigned int offset, + size_t copy_len) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = (struct sock *)msk; + struct sk_buff *tail; + bool has_rxtstamp; + + __skb_unlink(skb, &ssk->sk_receive_queue); + + skb_ext_reset(skb); + skb_orphan(skb); + + /* try to fetch required memory from subflow */ + if (!mptcp_rmem_schedule(sk, ssk, skb->truesize)) + goto drop; + + has_rxtstamp = TCP_SKB_CB(skb)->has_rxtstamp; + + /* the skb map_seq accounts for the skb offset: + * mptcp_subflow_get_mapped_dsn() is based on the current tp->copied_seq + * value + */ + MPTCP_SKB_CB(skb)->map_seq = mptcp_subflow_get_mapped_dsn(subflow); + MPTCP_SKB_CB(skb)->end_seq = MPTCP_SKB_CB(skb)->map_seq + copy_len; + MPTCP_SKB_CB(skb)->offset = offset; + MPTCP_SKB_CB(skb)->has_rxtstamp = has_rxtstamp; + + if (MPTCP_SKB_CB(skb)->map_seq == msk->ack_seq) { + /* in sequence */ + WRITE_ONCE(msk->ack_seq, msk->ack_seq + copy_len); + tail = skb_peek_tail(&sk->sk_receive_queue); + if (tail && mptcp_try_coalesce(sk, tail, skb)) + return true; + + mptcp_set_owner_r(skb, sk); + __skb_queue_tail(&sk->sk_receive_queue, skb); + return true; + } else if (after64(MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq)) { + mptcp_data_queue_ofo(msk, skb); + return false; + } + + /* old data, keep it simple and drop the whole pkt, sender + * will retransmit as needed, if needed. 
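The checks just above sort every incoming mapping into one of three buckets: exactly in sequence (appended to the receive queue), ahead of the expected sequence (parked in the out-of-order tree), or entirely old (dropped, leaving recovery to the peer's retransmission). A minimal standalone sketch of that decision, with a locally defined wrap-safe comparison in the spirit of the kernel's after64(); the enum and function names are made up for the example.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Wrap-safe "a is after b" on 64-bit sequence numbers, mirroring the
 * semantics of the kernel's after64()/before64() helpers.
 */
static bool seq64_after(uint64_t a, uint64_t b)
{
	return (int64_t)(a - b) > 0;
}

enum rx_verdict { RX_IN_SEQUENCE, RX_OUT_OF_ORDER, RX_STALE_DROP };

/* Classify a mapping against the next expected data sequence number. */
static enum rx_verdict classify(uint64_t map_seq, uint64_t ack_seq)
{
	if (map_seq == ack_seq)
		return RX_IN_SEQUENCE;	/* append to the receive queue */
	if (seq64_after(map_seq, ack_seq))
		return RX_OUT_OF_ORDER;	/* park it in the OoO rbtree */
	return RX_STALE_DROP;		/* old data, the sender retransmits */
}

int main(void)
{
	printf("%d %d %d\n",
	       classify(1000, 1000),	/* in sequence */
	       classify(2000, 1000),	/* out of order */
	       classify(500, 1000));	/* stale */
	return 0;
}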
+ */ + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA); +drop: + mptcp_drop(sk, skb); + return false; +} + +static void mptcp_stop_rtx_timer(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + sk_stop_timer(sk, &icsk->icsk_retransmit_timer); + mptcp_sk(sk)->timer_ival = 0; +} + +static void mptcp_close_wake_up(struct sock *sk) +{ + if (sock_flag(sk, SOCK_DEAD)) + return; + + sk->sk_state_change(sk); + if (sk->sk_shutdown == SHUTDOWN_MASK || + sk->sk_state == TCP_CLOSE) + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP); + else + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN); +} + +static bool mptcp_pending_data_fin_ack(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + return ((1 << sk->sk_state) & + (TCPF_FIN_WAIT1 | TCPF_CLOSING | TCPF_LAST_ACK)) && + msk->write_seq == READ_ONCE(msk->snd_una); +} + +static void mptcp_check_data_fin_ack(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + /* Look for an acknowledged DATA_FIN */ + if (mptcp_pending_data_fin_ack(sk)) { + WRITE_ONCE(msk->snd_data_fin_enable, 0); + + switch (sk->sk_state) { + case TCP_FIN_WAIT1: + inet_sk_state_store(sk, TCP_FIN_WAIT2); + break; + case TCP_CLOSING: + case TCP_LAST_ACK: + inet_sk_state_store(sk, TCP_CLOSE); + break; + } + + mptcp_close_wake_up(sk); + } +} + +static bool mptcp_pending_data_fin(struct sock *sk, u64 *seq) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + if (READ_ONCE(msk->rcv_data_fin) && + ((1 << sk->sk_state) & + (TCPF_ESTABLISHED | TCPF_FIN_WAIT1 | TCPF_FIN_WAIT2))) { + u64 rcv_data_fin_seq = READ_ONCE(msk->rcv_data_fin_seq); + + if (msk->ack_seq == rcv_data_fin_seq) { + if (seq) + *seq = rcv_data_fin_seq; + + return true; + } + } + + return false; +} + +static void mptcp_set_datafin_timeout(const struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + u32 retransmits; + + retransmits = min_t(u32, icsk->icsk_retransmits, + ilog2(TCP_RTO_MAX / TCP_RTO_MIN)); + + mptcp_sk(sk)->timer_ival = TCP_RTO_MIN << retransmits; +} + +static void __mptcp_set_timeout(struct sock *sk, long tout) +{ + mptcp_sk(sk)->timer_ival = tout > 0 ? tout : TCP_RTO_MIN; +} + +static long mptcp_timeout_from_subflow(const struct mptcp_subflow_context *subflow) +{ + const struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + return inet_csk(ssk)->icsk_pending && !subflow->stale_count ? 
+ inet_csk(ssk)->icsk_timeout - jiffies : 0; +} + +static void mptcp_set_timeout(struct sock *sk) +{ + struct mptcp_subflow_context *subflow; + long tout = 0; + + mptcp_for_each_subflow(mptcp_sk(sk), subflow) + tout = max(tout, mptcp_timeout_from_subflow(subflow)); + __mptcp_set_timeout(sk, tout); +} + +static inline bool tcp_can_send_ack(const struct sock *ssk) +{ + return !((1 << inet_sk_state_load(ssk)) & + (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_TIME_WAIT | TCPF_CLOSE | TCPF_LISTEN)); +} + +void __mptcp_subflow_send_ack(struct sock *ssk) +{ + if (tcp_can_send_ack(ssk)) + tcp_send_ack(ssk); +} + +static void mptcp_subflow_send_ack(struct sock *ssk) +{ + bool slow; + + slow = lock_sock_fast(ssk); + __mptcp_subflow_send_ack(ssk); + unlock_sock_fast(ssk, slow); +} + +static void mptcp_send_ack(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + + mptcp_for_each_subflow(msk, subflow) + mptcp_subflow_send_ack(mptcp_subflow_tcp_sock(subflow)); +} + +static void mptcp_subflow_cleanup_rbuf(struct sock *ssk) +{ + bool slow; + + slow = lock_sock_fast(ssk); + if (tcp_can_send_ack(ssk)) + tcp_cleanup_rbuf(ssk, 1); + unlock_sock_fast(ssk, slow); +} + +static bool mptcp_subflow_could_cleanup(const struct sock *ssk, bool rx_empty) +{ + const struct inet_connection_sock *icsk = inet_csk(ssk); + u8 ack_pending = READ_ONCE(icsk->icsk_ack.pending); + const struct tcp_sock *tp = tcp_sk(ssk); + + return (ack_pending & ICSK_ACK_SCHED) && + ((READ_ONCE(tp->rcv_nxt) - READ_ONCE(tp->rcv_wup) > + READ_ONCE(icsk->icsk_ack.rcv_mss)) || + (rx_empty && ack_pending & + (ICSK_ACK_PUSHED2 | ICSK_ACK_PUSHED))); +} + +static void mptcp_cleanup_rbuf(struct mptcp_sock *msk) +{ + int old_space = READ_ONCE(msk->old_wspace); + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int space = __mptcp_space(sk); + bool cleanup, rx_empty; + + cleanup = (space > 0) && (space >= (old_space << 1)); + rx_empty = !__mptcp_rmem(sk); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + if (cleanup || mptcp_subflow_could_cleanup(ssk, rx_empty)) + mptcp_subflow_cleanup_rbuf(ssk); + } +} + +static bool mptcp_check_data_fin(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + u64 rcv_data_fin_seq; + bool ret = false; + + /* Need to ack a DATA_FIN received from a peer while this side + * of the connection is in ESTABLISHED, FIN_WAIT1, or FIN_WAIT2. + * msk->rcv_data_fin was set when parsing the incoming options + * at the subflow level and the msk lock was not held, so this + * is the first opportunity to act on the DATA_FIN and change + * the msk state. + * + * If we are caught up to the sequence number of the incoming + * DATA_FIN, send the DATA_ACK now and do state transition. If + * not caught up, do nothing and let the recv code send DATA_ACK + * when catching up. 
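The switch that follows applies TCP-style FIN transitions at the MPTCP level, and mptcp_check_data_fin_ack() earlier does the same once our own DATA_FIN is acknowledged. A compact sketch of both transition tables, assuming a local enum rather than the kernel's TCP_* state constants.

#include <stdio.h>

enum conn_state {
	S_ESTABLISHED, S_FIN_WAIT1, S_FIN_WAIT2,
	S_CLOSING, S_CLOSE_WAIT, S_LAST_ACK, S_CLOSE,
};

/* State after the peer's DATA_FIN has been received and can be acked,
 * mirroring the switch in mptcp_check_data_fin() below.
 */
static enum conn_state on_data_fin_rcvd(enum conn_state s)
{
	switch (s) {
	case S_ESTABLISHED:	return S_CLOSE_WAIT;
	case S_FIN_WAIT1:	return S_CLOSING;
	case S_FIN_WAIT2:	return S_CLOSE;
	default:		return s;	/* not expected here */
	}
}

/* State once our own DATA_FIN has been acknowledged, mirroring
 * mptcp_check_data_fin_ack() above.
 */
static enum conn_state on_data_fin_acked(enum conn_state s)
{
	switch (s) {
	case S_FIN_WAIT1:	return S_FIN_WAIT2;
	case S_CLOSING:
	case S_LAST_ACK:	return S_CLOSE;
	default:		return s;
	}
}

int main(void)
{
	printf("%d %d\n", on_data_fin_rcvd(S_ESTABLISHED),	/* -> CLOSE_WAIT */
	       on_data_fin_acked(S_FIN_WAIT1));			/* -> FIN_WAIT2 */
	return 0;
}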
+ */ + + if (mptcp_pending_data_fin(sk, &rcv_data_fin_seq)) { + WRITE_ONCE(msk->ack_seq, msk->ack_seq + 1); + WRITE_ONCE(msk->rcv_data_fin, 0); + + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); + smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ + + switch (sk->sk_state) { + case TCP_ESTABLISHED: + inet_sk_state_store(sk, TCP_CLOSE_WAIT); + break; + case TCP_FIN_WAIT1: + inet_sk_state_store(sk, TCP_CLOSING); + break; + case TCP_FIN_WAIT2: + inet_sk_state_store(sk, TCP_CLOSE); + break; + default: + /* Other states not expected */ + WARN_ON_ONCE(1); + break; + } + + ret = true; + if (!__mptcp_check_fallback(msk)) + mptcp_send_ack(msk); + mptcp_close_wake_up(sk); + } + return ret; +} + +static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk, + struct sock *ssk, + unsigned int *bytes) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = (struct sock *)msk; + unsigned int moved = 0; + bool more_data_avail; + struct tcp_sock *tp; + bool done = false; + int sk_rbuf; + + sk_rbuf = READ_ONCE(sk->sk_rcvbuf); + + if (!(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { + int ssk_rbuf = READ_ONCE(ssk->sk_rcvbuf); + + if (unlikely(ssk_rbuf > sk_rbuf)) { + WRITE_ONCE(sk->sk_rcvbuf, ssk_rbuf); + sk_rbuf = ssk_rbuf; + } + } + + pr_debug("msk=%p ssk=%p", msk, ssk); + tp = tcp_sk(ssk); + do { + u32 map_remaining, offset; + u32 seq = tp->copied_seq; + struct sk_buff *skb; + bool fin; + + /* try to move as much data as available */ + map_remaining = subflow->map_data_len - + mptcp_subflow_get_map_offset(subflow); + + skb = skb_peek(&ssk->sk_receive_queue); + if (!skb) { + /* With racing move_skbs_to_msk() and __mptcp_move_skbs(), + * a different CPU can have already processed the pending + * data, stop here or we can enter an infinite loop + */ + if (!moved) + done = true; + break; + } + + if (__mptcp_check_fallback(msk)) { + /* Under fallback skbs have no MPTCP extension and TCP could + * collapse them between the dummy map creation and the + * current dequeue. Be sure to adjust the map size. 
+ */ + map_remaining = skb->len; + subflow->map_data_len = skb->len; + } + + offset = seq - TCP_SKB_CB(skb)->seq; + fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; + if (fin) { + done = true; + seq++; + } + + if (offset < skb->len) { + size_t len = skb->len - offset; + + if (tp->urg_data) + done = true; + + if (__mptcp_move_skb(msk, ssk, skb, offset, len)) + moved += len; + seq += len; + + if (WARN_ON_ONCE(map_remaining < len)) + break; + } else { + WARN_ON_ONCE(!fin); + sk_eat_skb(ssk, skb); + done = true; + } + + WRITE_ONCE(tp->copied_seq, seq); + more_data_avail = mptcp_subflow_data_available(ssk); + + if (atomic_read(&sk->sk_rmem_alloc) > sk_rbuf) { + done = true; + break; + } + } while (more_data_avail); + + *bytes += moved; + return done; +} + +static bool __mptcp_ofo_queue(struct mptcp_sock *msk) +{ + struct sock *sk = (struct sock *)msk; + struct sk_buff *skb, *tail; + bool moved = false; + struct rb_node *p; + u64 end_seq; + + p = rb_first(&msk->out_of_order_queue); + pr_debug("msk=%p empty=%d", msk, RB_EMPTY_ROOT(&msk->out_of_order_queue)); + while (p) { + skb = rb_to_skb(p); + if (after64(MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq)) + break; + + p = rb_next(p); + rb_erase(&skb->rbnode, &msk->out_of_order_queue); + + if (unlikely(!after64(MPTCP_SKB_CB(skb)->end_seq, + msk->ack_seq))) { + mptcp_drop(sk, skb); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_DUPDATA); + continue; + } + + end_seq = MPTCP_SKB_CB(skb)->end_seq; + tail = skb_peek_tail(&sk->sk_receive_queue); + if (!tail || !mptcp_ooo_try_coalesce(msk, tail, skb)) { + int delta = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq; + + /* skip overlapping data, if any */ + pr_debug("uncoalesced seq=%llx ack seq=%llx delta=%d", + MPTCP_SKB_CB(skb)->map_seq, msk->ack_seq, + delta); + MPTCP_SKB_CB(skb)->offset += delta; + MPTCP_SKB_CB(skb)->map_seq += delta; + __skb_queue_tail(&sk->sk_receive_queue, skb); + } + msk->ack_seq = end_seq; + moved = true; + } + return moved; +} + +static bool __mptcp_subflow_error_report(struct sock *sk, struct sock *ssk) +{ + int err = sock_error(ssk); + int ssk_state; + + if (!err) + return false; + + /* only propagate errors on fallen-back sockets or + * on MPC connect + */ + if (sk->sk_state != TCP_SYN_SENT && !__mptcp_check_fallback(mptcp_sk(sk))) + return false; + + /* We need to propagate only transition to CLOSE state. + * Orphaned socket will see such state change via + * subflow_sched_work_if_closed() and that path will properly + * destroy the msk as needed. + */ + ssk_state = inet_sk_state_load(ssk); + if (ssk_state == TCP_CLOSE && !sock_flag(sk, SOCK_DEAD)) + inet_sk_state_store(sk, ssk_state); + WRITE_ONCE(sk->sk_err, -err); + + /* This barrier is coupled with smp_rmb() in mptcp_poll() */ + smp_wmb(); + sk_error_report(sk); + return true; +} + +void __mptcp_error_report(struct sock *sk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + + mptcp_for_each_subflow(msk, subflow) + if (__mptcp_subflow_error_report(sk, mptcp_subflow_tcp_sock(subflow))) + break; +} + +/* In most cases we will be able to lock the mptcp socket. If its already + * owned, we need to defer to the work queue to avoid ABBA deadlock. 
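A toy, single-threaded illustration of the deferral pattern this comment describes: try to take the owner lock, run the work directly if that succeeds, otherwise leave a note for the lock owner to run it on release. A pthread mutex and a plain flag stand in for the msk socket lock and msk->cb_flags, and the spinlock protection the kernel applies to that flag is deliberately omitted to keep the sketch short.

#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>

struct owner {
	pthread_mutex_t lock;
	bool deferred_work;
};

static void process(void)
{
	printf("work processed\n");
}

static void handle_event(struct owner *o)
{
	if (pthread_mutex_trylock(&o->lock) == 0) {
		process();			/* fast path: we got the lock */
		pthread_mutex_unlock(&o->lock);
	} else {
		o->deferred_work = true;	/* owner will run it on release */
	}
}

static void release_owner(struct owner *o)
{
	if (o->deferred_work) {			/* run whatever was deferred */
		o->deferred_work = false;
		process();
	}
	pthread_mutex_unlock(&o->lock);
}

int main(void)
{
	struct owner o = { PTHREAD_MUTEX_INITIALIZER, false };

	pthread_mutex_lock(&o.lock);	/* lock busy, e.g. owned by user context */
	handle_event(&o);		/* cannot take it: defer */
	release_owner(&o);		/* deferred work runs here */
	return 0;
}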
+ */ +static bool move_skbs_to_msk(struct mptcp_sock *msk, struct sock *ssk) +{ + struct sock *sk = (struct sock *)msk; + unsigned int moved = 0; + + __mptcp_move_skbs_from_subflow(msk, ssk, &moved); + __mptcp_ofo_queue(msk); + if (unlikely(ssk->sk_err)) { + if (!sock_owned_by_user(sk)) + __mptcp_error_report(sk); + else + __set_bit(MPTCP_ERROR_REPORT, &msk->cb_flags); + } + + /* If the moves have caught up with the DATA_FIN sequence number + * it's time to ack the DATA_FIN and change socket state, but + * this is not a good place to change state. Let the workqueue + * do it. + */ + if (mptcp_pending_data_fin(sk, NULL)) + mptcp_schedule_work(sk); + return moved > 0; +} + +void mptcp_data_ready(struct sock *sk, struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct mptcp_sock *msk = mptcp_sk(sk); + int sk_rbuf, ssk_rbuf; + + /* The peer can send data while we are shutting down this + * subflow at msk destruction time, but we must avoid enqueuing + * more data to the msk receive queue + */ + if (unlikely(subflow->disposable)) + return; + + ssk_rbuf = READ_ONCE(ssk->sk_rcvbuf); + sk_rbuf = READ_ONCE(sk->sk_rcvbuf); + if (unlikely(ssk_rbuf > sk_rbuf)) + sk_rbuf = ssk_rbuf; + + /* over limit? can't append more skbs to msk, Also, no need to wake-up*/ + if (__mptcp_rmem(sk) > sk_rbuf) { + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RCVPRUNED); + return; + } + + /* Wake-up the reader only for in-sequence data */ + mptcp_data_lock(sk); + if (move_skbs_to_msk(msk, ssk)) + sk->sk_data_ready(sk); + + mptcp_data_unlock(sk); +} + +static void mptcp_subflow_joined(struct mptcp_sock *msk, struct sock *ssk) +{ + mptcp_subflow_ctx(ssk)->map_seq = READ_ONCE(msk->ack_seq); + WRITE_ONCE(msk->allow_infinite_fallback, false); + mptcp_event(MPTCP_EVENT_SUB_ESTABLISHED, msk, ssk, GFP_ATOMIC); +} + +static bool __mptcp_finish_join(struct mptcp_sock *msk, struct sock *ssk) +{ + struct sock *sk = (struct sock *)msk; + + if (sk->sk_state != TCP_ESTABLISHED) + return false; + + /* attach to msk socket only after we are sure we will deal with it + * at close time + */ + if (sk->sk_socket && !ssk->sk_socket) + mptcp_sock_graft(ssk, sk->sk_socket); + + mptcp_sockopt_sync_locked(msk, ssk); + mptcp_subflow_joined(msk, ssk); + mptcp_stop_tout_timer(sk); + return true; +} + +static void __mptcp_flush_join_list(struct sock *sk, struct list_head *join_list) +{ + struct mptcp_subflow_context *tmp, *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + + list_for_each_entry_safe(subflow, tmp, join_list, node) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast(ssk); + + list_move_tail(&subflow->node, &msk->conn_list); + if (!__mptcp_finish_join(msk, ssk)) + mptcp_subflow_reset(ssk); + unlock_sock_fast(ssk, slow); + } +} + +static bool mptcp_rtx_timer_pending(struct sock *sk) +{ + return timer_pending(&inet_csk(sk)->icsk_retransmit_timer); +} + +static void mptcp_reset_rtx_timer(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + unsigned long tout; + + /* prevent rescheduling on close */ + if (unlikely(inet_sk_state_load(sk) == TCP_CLOSE)) + return; + + tout = mptcp_sk(sk)->timer_ival; + sk_reset_timer(sk, &icsk->icsk_retransmit_timer, jiffies + tout); +} + +bool mptcp_schedule_work(struct sock *sk) +{ + if (inet_sk_state_load(sk) != TCP_CLOSE && + schedule_work(&mptcp_sk(sk)->work)) { + /* each subflow already holds a reference to the sk, and the + * workqueue is invoked by a subflow, so sk can't go away here. 
+ */ + sock_hold(sk); + return true; + } + return false; +} + +void mptcp_subflow_eof(struct sock *sk) +{ + if (!test_and_set_bit(MPTCP_WORK_EOF, &mptcp_sk(sk)->flags)) + mptcp_schedule_work(sk); +} + +static void mptcp_check_for_eof(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int receivers = 0; + + mptcp_for_each_subflow(msk, subflow) + receivers += !subflow->rx_eof; + if (receivers) + return; + + if (!(sk->sk_shutdown & RCV_SHUTDOWN)) { + /* hopefully temporary hack: propagate shutdown status + * to msk, when all subflows agree on it + */ + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | RCV_SHUTDOWN); + + smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ + sk->sk_data_ready(sk); + } + + switch (sk->sk_state) { + case TCP_ESTABLISHED: + inet_sk_state_store(sk, TCP_CLOSE_WAIT); + break; + case TCP_FIN_WAIT1: + inet_sk_state_store(sk, TCP_CLOSING); + break; + case TCP_FIN_WAIT2: + inet_sk_state_store(sk, TCP_CLOSE); + break; + default: + return; + } + mptcp_close_wake_up(sk); +} + +static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + + sock_owned_by_me(sk); + + mptcp_for_each_subflow(msk, subflow) { + if (READ_ONCE(subflow->data_avail)) + return mptcp_subflow_tcp_sock(subflow); + } + + return NULL; +} + +static bool mptcp_skb_can_collapse_to(u64 write_seq, + const struct sk_buff *skb, + const struct mptcp_ext *mpext) +{ + if (!tcp_skb_can_collapse_to(skb)) + return false; + + /* can collapse only if MPTCP level sequence is in order and this + * mapping has not been xmitted yet + */ + return mpext && mpext->data_seq + mpext->data_len == write_seq && + !mpext->frozen; +} + +/* we can append data to the given data frag if: + * - there is space available in the backing page_frag + * - the data frag tail matches the current page_frag free offset + * - the data frag end sequence number matches the current write seq + */ +static bool mptcp_frag_can_collapse_to(const struct mptcp_sock *msk, + const struct page_frag *pfrag, + const struct mptcp_data_frag *df) +{ + return df && pfrag->page == df->page && + pfrag->size - pfrag->offset > 0 && + pfrag->offset == (df->offset + df->data_len) && + df->data_seq + df->data_len == msk->write_seq; +} + +static void dfrag_uncharge(struct sock *sk, int len) +{ + sk_mem_uncharge(sk, len); + sk_wmem_queued_add(sk, -len); +} + +static void dfrag_clear(struct sock *sk, struct mptcp_data_frag *dfrag) +{ + int len = dfrag->data_len + dfrag->overhead; + + list_del(&dfrag->list); + dfrag_uncharge(sk, len); + put_page(dfrag->page); +} + +static void __mptcp_clean_una(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_data_frag *dtmp, *dfrag; + u64 snd_una; + + /* on fallback we just need to ignore snd_una, as this is really + * plain TCP + */ + if (__mptcp_check_fallback(msk)) + msk->snd_una = READ_ONCE(msk->snd_nxt); + + snd_una = msk->snd_una; + list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) { + if (after64(dfrag->data_seq + dfrag->data_len, snd_una)) + break; + + if (unlikely(dfrag == msk->first_pending)) { + /* in recovery mode can see ack after the current snd head */ + if (WARN_ON_ONCE(!msk->recovery)) + break; + + WRITE_ONCE(msk->first_pending, mptcp_send_next(sk)); + } + + dfrag_clear(sk, dfrag); + } + + dfrag = mptcp_rtx_head(sk); + if (dfrag && after64(snd_una, dfrag->data_seq)) { + u64 delta = snd_una - dfrag->data_seq; + + /* prevent wrap 
around in recovery mode */ + if (unlikely(delta > dfrag->already_sent)) { + if (WARN_ON_ONCE(!msk->recovery)) + goto out; + if (WARN_ON_ONCE(delta > dfrag->data_len)) + goto out; + dfrag->already_sent += delta - dfrag->already_sent; + } + + dfrag->data_seq += delta; + dfrag->offset += delta; + dfrag->data_len -= delta; + dfrag->already_sent -= delta; + + dfrag_uncharge(sk, delta); + } + + /* all retransmitted data acked, recovery completed */ + if (unlikely(msk->recovery) && after64(msk->snd_una, msk->recovery_snd_nxt)) + msk->recovery = false; + +out: + if (snd_una == READ_ONCE(msk->snd_nxt) && + snd_una == READ_ONCE(msk->write_seq)) { + if (mptcp_rtx_timer_pending(sk) && !mptcp_data_fin_enabled(msk)) + mptcp_stop_rtx_timer(sk); + } else { + mptcp_reset_rtx_timer(sk); + } +} + +static void __mptcp_clean_una_wakeup(struct sock *sk) +{ + lockdep_assert_held_once(&sk->sk_lock.slock); + + __mptcp_clean_una(sk); + mptcp_write_space(sk); +} + +static void mptcp_clean_una_wakeup(struct sock *sk) +{ + mptcp_data_lock(sk); + __mptcp_clean_una_wakeup(sk); + mptcp_data_unlock(sk); +} + +static void mptcp_enter_memory_pressure(struct sock *sk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + bool first = true; + + sk_stream_moderate_sndbuf(sk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + if (first) + tcp_enter_memory_pressure(ssk); + sk_stream_moderate_sndbuf(ssk); + first = false; + } +} + +/* ensure we get enough memory for the frag hdr, beyond some minimal amount of + * data + */ +static bool mptcp_page_frag_refill(struct sock *sk, struct page_frag *pfrag) +{ + if (likely(skb_page_frag_refill(32U + sizeof(struct mptcp_data_frag), + pfrag, sk->sk_allocation))) + return true; + + mptcp_enter_memory_pressure(sk); + return false; +} + +static struct mptcp_data_frag * +mptcp_carve_data_frag(const struct mptcp_sock *msk, struct page_frag *pfrag, + int orig_offset) +{ + int offset = ALIGN(orig_offset, sizeof(long)); + struct mptcp_data_frag *dfrag; + + dfrag = (struct mptcp_data_frag *)(page_to_virt(pfrag->page) + offset); + dfrag->data_len = 0; + dfrag->data_seq = msk->write_seq; + dfrag->overhead = offset - orig_offset + sizeof(struct mptcp_data_frag); + dfrag->offset = offset + sizeof(struct mptcp_data_frag); + dfrag->already_sent = 0; + dfrag->page = pfrag->page; + + return dfrag; +} + +struct mptcp_sendmsg_info { + int mss_now; + int size_goal; + u16 limit; + u16 sent; + unsigned int flags; + bool data_lock_held; +}; + +static int mptcp_check_allowed_size(const struct mptcp_sock *msk, struct sock *ssk, + u64 data_seq, int avail_size) +{ + u64 window_end = mptcp_wnd_end(msk); + u64 mptcp_snd_wnd; + + if (__mptcp_check_fallback(msk)) + return avail_size; + + mptcp_snd_wnd = window_end - data_seq; + avail_size = min_t(unsigned int, mptcp_snd_wnd, avail_size); + + if (unlikely(tcp_sk(ssk)->snd_wnd < mptcp_snd_wnd)) { + tcp_sk(ssk)->snd_wnd = min_t(u64, U32_MAX, mptcp_snd_wnd); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_SNDWNDSHARED); + } + + return avail_size; +} + +static bool __mptcp_add_ext(struct sk_buff *skb, gfp_t gfp) +{ + struct skb_ext *mpext = __skb_ext_alloc(gfp); + + if (!mpext) + return false; + __skb_ext_set(skb, SKB_EXT_MPTCP, mpext); + return true; +} + +static struct sk_buff *__mptcp_do_alloc_tx_skb(struct sock *sk, gfp_t gfp) +{ + struct sk_buff *skb; + + skb = alloc_skb_fclone(MAX_TCP_HEADER, gfp); + if (likely(skb)) { + if (likely(__mptcp_add_ext(skb, gfp))) { + skb_reserve(skb, 
MAX_TCP_HEADER); + skb->ip_summed = CHECKSUM_PARTIAL; + INIT_LIST_HEAD(&skb->tcp_tsorted_anchor); + return skb; + } + __kfree_skb(skb); + } else { + mptcp_enter_memory_pressure(sk); + } + return NULL; +} + +static struct sk_buff *__mptcp_alloc_tx_skb(struct sock *sk, struct sock *ssk, gfp_t gfp) +{ + struct sk_buff *skb; + + skb = __mptcp_do_alloc_tx_skb(sk, gfp); + if (!skb) + return NULL; + + if (likely(sk_wmem_schedule(ssk, skb->truesize))) { + tcp_skb_entail(ssk, skb); + return skb; + } + tcp_skb_tsorted_anchor_cleanup(skb); + kfree_skb(skb); + return NULL; +} + +static struct sk_buff *mptcp_alloc_tx_skb(struct sock *sk, struct sock *ssk, bool data_lock_held) +{ + gfp_t gfp = data_lock_held ? GFP_ATOMIC : sk->sk_allocation; + + return __mptcp_alloc_tx_skb(sk, ssk, gfp); +} + +/* note: this always recompute the csum on the whole skb, even + * if we just appended a single frag. More status info needed + */ +static void mptcp_update_data_checksum(struct sk_buff *skb, int added) +{ + struct mptcp_ext *mpext = mptcp_get_ext(skb); + __wsum csum = ~csum_unfold(mpext->csum); + int offset = skb->len - added; + + mpext->csum = csum_fold(csum_block_add(csum, skb_checksum(skb, offset, added, 0), offset)); +} + +static void mptcp_update_infinite_map(struct mptcp_sock *msk, + struct sock *ssk, + struct mptcp_ext *mpext) +{ + if (!mpext) + return; + + mpext->infinite_map = 1; + mpext->data_len = 0; + + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_INFINITEMAPTX); + mptcp_subflow_ctx(ssk)->send_infinite_map = 0; + pr_fallback(msk); + mptcp_do_fallback(ssk); +} + +#define MPTCP_MAX_GSO_SIZE (GSO_LEGACY_MAX_SIZE - (MAX_TCP_HEADER + 1)) + +static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk, + struct mptcp_data_frag *dfrag, + struct mptcp_sendmsg_info *info) +{ + u64 data_seq = dfrag->data_seq + info->sent; + int offset = dfrag->offset + info->sent; + struct mptcp_sock *msk = mptcp_sk(sk); + bool zero_window_probe = false; + struct mptcp_ext *mpext = NULL; + bool can_coalesce = false; + bool reuse_skb = true; + struct sk_buff *skb; + size_t copy; + int i; + + pr_debug("msk=%p ssk=%p sending dfrag at seq=%llu len=%u already sent=%u", + msk, ssk, dfrag->data_seq, dfrag->data_len, info->sent); + + if (WARN_ON_ONCE(info->sent > info->limit || + info->limit > dfrag->data_len)) + return 0; + + if (unlikely(!__tcp_can_send(ssk))) + return -EAGAIN; + + /* compute send limit */ + if (unlikely(ssk->sk_gso_max_size > MPTCP_MAX_GSO_SIZE)) + ssk->sk_gso_max_size = MPTCP_MAX_GSO_SIZE; + info->mss_now = tcp_send_mss(ssk, &info->size_goal, info->flags); + copy = info->size_goal; + + skb = tcp_write_queue_tail(ssk); + if (skb && copy > skb->len) { + /* Limit the write to the size available in the + * current skb, if any, so that we create at most a new skb. 
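mptcp_update_data_checksum() above keeps the DSS checksum current by checksumming only the bytes just appended and combining the result with the previous value through csum_block_add(). A much simplified userspace sketch of the same incremental one's-complement (RFC 1071) idea; unlike the kernel helper it assumes every append starts at an even offset.

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

/* Running one's-complement sum, a simplified stand-in for the kernel's
 * csum_partial()/csum_block_add(): appends are assumed to start at even
 * offsets, so no byte rotation is needed.
 */
static uint32_t csum_add_bytes(uint32_t sum, const uint8_t *data, size_t len)
{
	size_t i;

	for (i = 0; i + 1 < len; i += 2)
		sum += (uint32_t)data[i] << 8 | data[i + 1];
	if (len & 1)
		sum += (uint32_t)data[len - 1] << 8;	/* pad the last byte */
	return sum;
}

/* Fold the 32-bit accumulator into the final 16-bit checksum. */
static uint16_t csum_fold16(uint32_t sum)
{
	while (sum >> 16)
		sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)~sum;
}

int main(void)
{
	const uint8_t part1[] = { 0x45, 0x00, 0x00, 0x54 };
	const uint8_t part2[] = { 0x12, 0x34, 0x56, 0x78 };
	uint32_t sum = 0;

	/* checksum maintained incrementally, as data is appended */
	sum = csum_add_bytes(sum, part1, sizeof(part1));
	sum = csum_add_bytes(sum, part2, sizeof(part2));
	printf("csum=0x%04x\n", csum_fold16(sum));
	return 0;
}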
+ * Explicitly tells TCP internals to avoid collapsing on later + * queue management operation, to avoid breaking the ext <-> + * SSN association set here + */ + mpext = skb_ext_find(skb, SKB_EXT_MPTCP); + if (!mptcp_skb_can_collapse_to(data_seq, skb, mpext)) { + TCP_SKB_CB(skb)->eor = 1; + goto alloc_skb; + } + + i = skb_shinfo(skb)->nr_frags; + can_coalesce = skb_can_coalesce(skb, i, dfrag->page, offset); + if (!can_coalesce && i >= READ_ONCE(sysctl_max_skb_frags)) { + tcp_mark_push(tcp_sk(ssk), skb); + goto alloc_skb; + } + + copy -= skb->len; + } else { +alloc_skb: + skb = mptcp_alloc_tx_skb(sk, ssk, info->data_lock_held); + if (!skb) + return -ENOMEM; + + i = skb_shinfo(skb)->nr_frags; + reuse_skb = false; + mpext = skb_ext_find(skb, SKB_EXT_MPTCP); + } + + /* Zero window and all data acked? Probe. */ + copy = mptcp_check_allowed_size(msk, ssk, data_seq, copy); + if (copy == 0) { + u64 snd_una = READ_ONCE(msk->snd_una); + + if (snd_una != msk->snd_nxt || tcp_write_queue_tail(ssk)) { + tcp_remove_empty_skb(ssk); + return 0; + } + + zero_window_probe = true; + data_seq = snd_una - 1; + copy = 1; + } + + copy = min_t(size_t, copy, info->limit - info->sent); + if (!sk_wmem_schedule(ssk, copy)) { + tcp_remove_empty_skb(ssk); + return -ENOMEM; + } + + if (can_coalesce) { + skb_frag_size_add(&skb_shinfo(skb)->frags[i - 1], copy); + } else { + get_page(dfrag->page); + skb_fill_page_desc(skb, i, dfrag->page, offset, copy); + } + + skb->len += copy; + skb->data_len += copy; + skb->truesize += copy; + sk_wmem_queued_add(ssk, copy); + sk_mem_charge(ssk, copy); + WRITE_ONCE(tcp_sk(ssk)->write_seq, tcp_sk(ssk)->write_seq + copy); + TCP_SKB_CB(skb)->end_seq += copy; + tcp_skb_pcount_set(skb, 0); + + /* on skb reuse we just need to update the DSS len */ + if (reuse_skb) { + TCP_SKB_CB(skb)->tcp_flags &= ~TCPHDR_PSH; + mpext->data_len += copy; + goto out; + } + + memset(mpext, 0, sizeof(*mpext)); + mpext->data_seq = data_seq; + mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq; + mpext->data_len = copy; + mpext->use_map = 1; + mpext->dsn64 = 1; + + pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d", + mpext->data_seq, mpext->subflow_seq, mpext->data_len, + mpext->dsn64); + + if (zero_window_probe) { + mptcp_subflow_ctx(ssk)->rel_write_seq += copy; + mpext->frozen = 1; + if (READ_ONCE(msk->csum_enabled)) + mptcp_update_data_checksum(skb, copy); + tcp_push_pending_frames(ssk); + return 0; + } +out: + if (READ_ONCE(msk->csum_enabled)) + mptcp_update_data_checksum(skb, copy); + if (mptcp_subflow_ctx(ssk)->send_infinite_map) + mptcp_update_infinite_map(msk, ssk, mpext); + trace_mptcp_sendmsg_frag(mpext); + mptcp_subflow_ctx(ssk)->rel_write_seq += copy; + return copy; +} + +#define MPTCP_SEND_BURST_SIZE ((1 << 16) - \ + sizeof(struct tcphdr) - \ + MAX_TCP_OPTION_SPACE - \ + sizeof(struct ipv6hdr) - \ + sizeof(struct frag_hdr)) + +struct subflow_send_info { + struct sock *ssk; + u64 linger_time; +}; + +void mptcp_subflow_set_active(struct mptcp_subflow_context *subflow) +{ + if (!subflow->stale) + return; + + subflow->stale = 0; + MPTCP_INC_STATS(sock_net(mptcp_subflow_tcp_sock(subflow)), MPTCP_MIB_SUBFLOWRECOVER); +} + +bool mptcp_subflow_active(struct mptcp_subflow_context *subflow) +{ + if (unlikely(subflow->stale)) { + u32 rcv_tstamp = READ_ONCE(tcp_sk(mptcp_subflow_tcp_sock(subflow))->rcv_tstamp); + + if (subflow->stale_rcv_tstamp == rcv_tstamp) + return false; + + mptcp_subflow_set_active(subflow); + } + return __mptcp_subflow_active(subflow); +} + +#define SSK_MODE_ACTIVE 0 +#define 
SSK_MODE_BACKUP 1 +#define SSK_MODE_MAX 2 + +/* implement the mptcp packet scheduler; + * returns the subflow that will transmit the next DSS + * additionally updates the rtx timeout + */ +static struct sock *mptcp_subflow_get_send(struct mptcp_sock *msk) +{ + struct subflow_send_info send_info[SSK_MODE_MAX]; + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + u32 pace, burst, wmem; + int i, nr_active = 0; + struct sock *ssk; + u64 linger_time; + long tout = 0; + + sock_owned_by_me(sk); + + if (__mptcp_check_fallback(msk)) { + if (!msk->first) + return NULL; + return __tcp_can_send(msk->first) && + sk_stream_memory_free(msk->first) ? msk->first : NULL; + } + + /* re-use last subflow, if the burst allow that */ + if (msk->last_snd && msk->snd_burst > 0 && + sk_stream_memory_free(msk->last_snd) && + mptcp_subflow_active(mptcp_subflow_ctx(msk->last_snd))) { + mptcp_set_timeout(sk); + return msk->last_snd; + } + + /* pick the subflow with the lower wmem/wspace ratio */ + for (i = 0; i < SSK_MODE_MAX; ++i) { + send_info[i].ssk = NULL; + send_info[i].linger_time = -1; + } + + mptcp_for_each_subflow(msk, subflow) { + trace_mptcp_subflow_get_send(subflow); + ssk = mptcp_subflow_tcp_sock(subflow); + if (!mptcp_subflow_active(subflow)) + continue; + + tout = max(tout, mptcp_timeout_from_subflow(subflow)); + nr_active += !subflow->backup; + pace = subflow->avg_pacing_rate; + if (unlikely(!pace)) { + /* init pacing rate from socket */ + subflow->avg_pacing_rate = READ_ONCE(ssk->sk_pacing_rate); + pace = subflow->avg_pacing_rate; + if (!pace) + continue; + } + + linger_time = div_u64((u64)READ_ONCE(ssk->sk_wmem_queued) << 32, pace); + if (linger_time < send_info[subflow->backup].linger_time) { + send_info[subflow->backup].ssk = ssk; + send_info[subflow->backup].linger_time = linger_time; + } + } + __mptcp_set_timeout(sk, tout); + + /* pick the best backup if no other subflow is active */ + if (!nr_active) + send_info[SSK_MODE_ACTIVE].ssk = send_info[SSK_MODE_BACKUP].ssk; + + /* According to the blest algorithm, to avoid HoL blocking for the + * faster flow, we need to: + * - estimate the faster flow linger time + * - use the above to estimate the amount of byte transferred + * by the faster flow + * - check that the amount of queued data is greter than the above, + * otherwise do not use the picked, slower, subflow + * We select the subflow with the shorter estimated time to flush + * the queued mem, which basically ensure the above. We just need + * to check that subflow has a non empty cwin. 
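A standalone sketch of the selection rule described above: estimate, for each usable subflow, how long its already-queued bytes would take to drain at the current pacing rate, keep the best regular and best backup candidate separately, and fall back to the backup only when no regular subflow is available. struct sf and the sample numbers are illustrative only.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct sf {
	const char *name;
	bool backup;		/* backup-priority subflow */
	bool active;		/* usable for transmission */
	uint64_t wmem_queued;	/* bytes already queued at TCP level */
	uint64_t pacing_rate;	/* bytes per second */
};

/* Pick the subflow with the lowest "linger time" (queued bytes scaled by
 * the pacing rate), tracking the best active and best backup candidates
 * separately, as mptcp_subflow_get_send() does.
 */
static const struct sf *pick_subflow(const struct sf *sfs, int n)
{
	const struct sf *best[2] = { NULL, NULL };	/* [0]=active, [1]=backup */
	uint64_t best_linger[2] = { UINT64_MAX, UINT64_MAX };
	int i;

	for (i = 0; i < n; i++) {
		uint64_t linger;

		if (!sfs[i].active || !sfs[i].pacing_rate)
			continue;

		linger = (sfs[i].wmem_queued << 32) / sfs[i].pacing_rate;
		if (linger < best_linger[sfs[i].backup]) {
			best_linger[sfs[i].backup] = linger;
			best[sfs[i].backup] = &sfs[i];
		}
	}
	return best[0] ? best[0] : best[1];
}

int main(void)
{
	struct sf sfs[] = {
		{ "wifi", false, true, 64 * 1024, 2 * 1000 * 1000 },
		{ "lte",  false, true, 16 * 1024, 4 * 1000 * 1000 },
		{ "dsl",  true,  true,  4 * 1024, 1 * 1000 * 1000 },
	};

	printf("send on %s\n", pick_subflow(sfs, 3)->name);
	return 0;
}

Running it prints "send on lte", the regular subflow with the shorter estimated drain time.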
+ */ + ssk = send_info[SSK_MODE_ACTIVE].ssk; + if (!ssk || !sk_stream_memory_free(ssk)) + return NULL; + + burst = min_t(int, MPTCP_SEND_BURST_SIZE, mptcp_wnd_end(msk) - msk->snd_nxt); + wmem = READ_ONCE(ssk->sk_wmem_queued); + if (!burst) { + msk->last_snd = NULL; + return ssk; + } + + subflow = mptcp_subflow_ctx(ssk); + subflow->avg_pacing_rate = div_u64((u64)subflow->avg_pacing_rate * wmem + + READ_ONCE(ssk->sk_pacing_rate) * burst, + burst + wmem); + msk->last_snd = ssk; + msk->snd_burst = burst; + return ssk; +} + +static void mptcp_push_release(struct sock *ssk, struct mptcp_sendmsg_info *info) +{ + tcp_push(ssk, 0, info->mss_now, tcp_sk(ssk)->nonagle, info->size_goal); + release_sock(ssk); +} + +static void mptcp_update_post_push(struct mptcp_sock *msk, + struct mptcp_data_frag *dfrag, + u32 sent) +{ + u64 snd_nxt_new = dfrag->data_seq; + + dfrag->already_sent += sent; + + msk->snd_burst -= sent; + + snd_nxt_new += dfrag->already_sent; + + /* snd_nxt_new can be smaller than snd_nxt in case mptcp + * is recovering after a failover. In that event, this re-sends + * old segments. + * + * Thus compute snd_nxt_new candidate based on + * the dfrag->data_seq that was sent and the data + * that has been handed to the subflow for transmission + * and skip update in case it was old dfrag. + */ + if (likely(after64(snd_nxt_new, msk->snd_nxt))) + msk->snd_nxt = snd_nxt_new; +} + +void mptcp_check_and_set_pending(struct sock *sk) +{ + if (mptcp_send_head(sk)) + mptcp_sk(sk)->push_pending |= BIT(MPTCP_PUSH_PENDING); +} + +void __mptcp_push_pending(struct sock *sk, unsigned int flags) +{ + struct sock *prev_ssk = NULL, *ssk = NULL; + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_sendmsg_info info = { + .flags = flags, + }; + bool do_check_data_fin = false; + struct mptcp_data_frag *dfrag; + int len; + + while ((dfrag = mptcp_send_head(sk))) { + info.sent = dfrag->already_sent; + info.limit = dfrag->data_len; + len = dfrag->data_len - dfrag->already_sent; + while (len > 0) { + int ret = 0; + + prev_ssk = ssk; + ssk = mptcp_subflow_get_send(msk); + + /* First check. 
If the ssk has changed since + * the last round, release prev_ssk + */ + if (ssk != prev_ssk && prev_ssk) + mptcp_push_release(prev_ssk, &info); + if (!ssk) + goto out; + + /* Need to lock the new subflow only if different + * from the previous one, otherwise we are still + * helding the relevant lock + */ + if (ssk != prev_ssk) + lock_sock(ssk); + + ret = mptcp_sendmsg_frag(sk, ssk, dfrag, &info); + if (ret <= 0) { + if (ret == -EAGAIN) + continue; + mptcp_push_release(ssk, &info); + goto out; + } + + do_check_data_fin = true; + info.sent += ret; + len -= ret; + + mptcp_update_post_push(msk, dfrag, ret); + } + WRITE_ONCE(msk->first_pending, mptcp_send_next(sk)); + } + + /* at this point we held the socket lock for the last subflow we used */ + if (ssk) + mptcp_push_release(ssk, &info); + +out: + /* ensure the rtx timer is running */ + if (!mptcp_rtx_timer_pending(sk)) + mptcp_reset_rtx_timer(sk); + if (do_check_data_fin) + mptcp_check_send_data_fin(sk); +} + +static void __mptcp_subflow_push_pending(struct sock *sk, struct sock *ssk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_sendmsg_info info = { + .data_lock_held = true, + }; + struct mptcp_data_frag *dfrag; + struct sock *xmit_ssk; + int len, copied = 0; + bool first = true; + + info.flags = 0; + while ((dfrag = mptcp_send_head(sk))) { + info.sent = dfrag->already_sent; + info.limit = dfrag->data_len; + len = dfrag->data_len - dfrag->already_sent; + while (len > 0) { + int ret = 0; + + /* the caller already invoked the packet scheduler, + * check for a different subflow usage only after + * spooling the first chunk of data + */ + xmit_ssk = first ? ssk : mptcp_subflow_get_send(mptcp_sk(sk)); + if (!xmit_ssk) + goto out; + if (xmit_ssk != ssk) { + mptcp_subflow_delegate(mptcp_subflow_ctx(xmit_ssk), + MPTCP_DELEGATE_SEND); + goto out; + } + + ret = mptcp_sendmsg_frag(sk, ssk, dfrag, &info); + if (ret <= 0) + goto out; + + info.sent += ret; + copied += ret; + len -= ret; + first = false; + + mptcp_update_post_push(msk, dfrag, ret); + } + WRITE_ONCE(msk->first_pending, mptcp_send_next(sk)); + } + +out: + /* __mptcp_alloc_tx_skb could have released some wmem and we are + * not going to flush it via release_sock() + */ + if (copied) { + tcp_push(ssk, 0, info.mss_now, tcp_sk(ssk)->nonagle, + info.size_goal); + if (!mptcp_rtx_timer_pending(sk)) + mptcp_reset_rtx_timer(sk); + + if (msk->snd_data_fin_enable && + msk->snd_nxt + 1 == msk->write_seq) + mptcp_schedule_work(sk); + } +} + +static void mptcp_set_nospace(struct sock *sk) +{ + /* enable autotune */ + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + + /* will be cleared on avail space */ + set_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags); +} + +static int mptcp_disconnect(struct sock *sk, int flags); + +static int mptcp_sendmsg_fastopen(struct sock *sk, struct sock *ssk, struct msghdr *msg, + size_t len, int *copied_syn) +{ + unsigned int saved_flags = msg->msg_flags; + struct mptcp_sock *msk = mptcp_sk(sk); + int ret; + + lock_sock(ssk); + msg->msg_flags |= MSG_DONTWAIT; + msk->fastopening = 1; + ret = tcp_sendmsg_fastopen(ssk, msg, copied_syn, len, NULL); + msk->fastopening = 0; + msg->msg_flags = saved_flags; + release_sock(ssk); + + /* do the blocking bits of inet_stream_connect outside the ssk socket lock */ + if (ret == -EINPROGRESS && !(msg->msg_flags & MSG_DONTWAIT)) { + ret = __inet_stream_connect(sk->sk_socket, msg->msg_name, + msg->msg_namelen, msg->msg_flags, 1); + + /* Keep the same behaviour of plain TCP: zero the copied bytes in + * case of any error, except 
timeout or signal + */ + if (ret && ret != -EINPROGRESS && ret != -ERESTARTSYS && ret != -EINTR) + *copied_syn = 0; + } else if (ret && ret != -EINPROGRESS) { + /* The disconnect() op called by tcp_sendmsg_fastopen()/ + * __inet_stream_connect() can fail, due to looking check, + * see mptcp_disconnect(). + * Attempt it again outside the problematic scope. + */ + if (!mptcp_disconnect(sk, 0)) + sk->sk_socket->state = SS_UNCONNECTED; + } + + return ret; +} + +static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct page_frag *pfrag; + struct socket *ssock; + size_t copied = 0; + int ret = 0; + long timeo; + + /* we don't support FASTOPEN yet */ + if (msg->msg_flags & MSG_FASTOPEN) + return -EOPNOTSUPP; + + /* silently ignore everything else */ + msg->msg_flags &= MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL; + + lock_sock(sk); + + ssock = __mptcp_nmpc_socket(msk); + if (unlikely(ssock && inet_sk(ssock->sk)->defer_connect)) { + int copied_syn = 0; + + ret = mptcp_sendmsg_fastopen(sk, ssock->sk, msg, len, &copied_syn); + copied += copied_syn; + if (ret == -EINPROGRESS && copied_syn > 0) + goto out; + else if (ret) + goto do_error; + } + + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + + if ((1 << sk->sk_state) & ~(TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)) { + ret = sk_stream_wait_connect(sk, &timeo); + if (ret) + goto do_error; + } + + ret = -EPIPE; + if (unlikely(sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN))) + goto do_error; + + pfrag = sk_page_frag(sk); + + while (msg_data_left(msg)) { + int total_ts, frag_truesize = 0; + struct mptcp_data_frag *dfrag; + bool dfrag_collapsed; + size_t psize, offset; + + /* reuse tail pfrag, if possible, or carve a new one from the + * page allocator + */ + dfrag = mptcp_pending_tail(sk); + dfrag_collapsed = mptcp_frag_can_collapse_to(msk, pfrag, dfrag); + if (!dfrag_collapsed) { + if (!sk_stream_memory_free(sk)) + goto wait_for_memory; + + if (!mptcp_page_frag_refill(sk, pfrag)) + goto wait_for_memory; + + dfrag = mptcp_carve_data_frag(msk, pfrag, pfrag->offset); + frag_truesize = dfrag->overhead; + } + + /* we do not bound vs wspace, to allow a single packet. 
+ * memory accounting will prevent execessive memory usage + * anyway + */ + offset = dfrag->offset + dfrag->data_len; + psize = pfrag->size - offset; + psize = min_t(size_t, psize, msg_data_left(msg)); + total_ts = psize + frag_truesize; + + if (!sk_wmem_schedule(sk, total_ts)) + goto wait_for_memory; + + if (copy_page_from_iter(dfrag->page, offset, psize, + &msg->msg_iter) != psize) { + ret = -EFAULT; + goto do_error; + } + + /* data successfully copied into the write queue */ + sk_forward_alloc_add(sk, -total_ts); + copied += psize; + dfrag->data_len += psize; + frag_truesize += psize; + pfrag->offset += frag_truesize; + WRITE_ONCE(msk->write_seq, msk->write_seq + psize); + + /* charge data on mptcp pending queue to the msk socket + * Note: we charge such data both to sk and ssk + */ + sk_wmem_queued_add(sk, frag_truesize); + if (!dfrag_collapsed) { + get_page(dfrag->page); + list_add_tail(&dfrag->list, &msk->rtx_queue); + if (!msk->first_pending) + WRITE_ONCE(msk->first_pending, dfrag); + } + pr_debug("msk=%p dfrag at seq=%llu len=%u sent=%u new=%d", msk, + dfrag->data_seq, dfrag->data_len, dfrag->already_sent, + !dfrag_collapsed); + + continue; + +wait_for_memory: + mptcp_set_nospace(sk); + __mptcp_push_pending(sk, msg->msg_flags); + ret = sk_stream_wait_memory(sk, &timeo); + if (ret) + goto do_error; + } + + if (copied) + __mptcp_push_pending(sk, msg->msg_flags); + +out: + release_sock(sk); + return copied; + +do_error: + if (copied) + goto out; + + copied = sk_stream_error(sk, msg->msg_flags, ret); + goto out; +} + +static int __mptcp_recvmsg_mskq(struct mptcp_sock *msk, + struct msghdr *msg, + size_t len, int flags, + struct scm_timestamping_internal *tss, + int *cmsg_flags) +{ + struct sk_buff *skb, *tmp; + int copied = 0; + + skb_queue_walk_safe(&msk->receive_queue, skb, tmp) { + u32 offset = MPTCP_SKB_CB(skb)->offset; + u32 data_len = skb->len - offset; + u32 count = min_t(size_t, len - copied, data_len); + int err; + + if (!(flags & MSG_TRUNC)) { + err = skb_copy_datagram_msg(skb, offset, msg, count); + if (unlikely(err < 0)) { + if (!copied) + return err; + break; + } + } + + if (MPTCP_SKB_CB(skb)->has_rxtstamp) { + tcp_update_recv_tstamps(skb, tss); + *cmsg_flags |= MPTCP_CMSG_TS; + } + + copied += count; + + if (count < data_len) { + if (!(flags & MSG_PEEK)) { + MPTCP_SKB_CB(skb)->offset += count; + MPTCP_SKB_CB(skb)->map_seq += count; + } + break; + } + + if (!(flags & MSG_PEEK)) { + /* we will bulk release the skb memory later */ + skb->destructor = NULL; + WRITE_ONCE(msk->rmem_released, msk->rmem_released + skb->truesize); + __skb_unlink(skb, &msk->receive_queue); + __kfree_skb(skb); + } + + if (copied >= len) + break; + } + + return copied; +} + +/* receive buffer autotuning. See tcp_rcv_space_adjust for more information. + * + * Only difference: Use highest rtt estimate of the subflows in use. 
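mptcp_rcv_space_adjust() below grows the receive buffer when the amount of data copied per round-trip keeps increasing. A simplified userspace rendering of that arithmetic; the fixed per-chunk overhead and the clamp stand in for SKB_TRUESIZE()/tcp_win_from_space() and tcp_rmem[2], so treat the result as an approximation of the kernel computation rather than a copy of it.

#include <stdint.h>
#include <stdio.h>

static uint64_t autotune_rcvbuf(uint64_t copied,   /* bytes copied in the last RTT */
				uint64_t space,    /* previous measurement */
				uint64_t advmss,   /* largest advertised MSS */
				uint64_t overhead, /* truesize per advmss-sized chunk */
				uint64_t rmem_max) /* upper clamp */
{
	uint64_t rcvwin, grow, rcvbuf;

	if (!space || copied <= space)
		return 0;	/* no growth this round */

	/* start from twice the data received in the last RTT */
	rcvwin = (copied << 1) + 16 * advmss;

	/* grow proportionally to how much the measurement increased */
	grow = rcvwin * (copied - space) / space;
	rcvwin += grow << 1;

	/* convert the window into buffer space, then clamp it */
	rcvbuf = rcvwin / advmss * overhead;
	return rcvbuf < rmem_max ? rcvbuf : rmem_max;
}

int main(void)
{
	/* e.g. 256 KiB copied in one RTT against a previous measure of 128 KiB */
	printf("new rcvbuf: %llu bytes\n",
	       (unsigned long long)autotune_rcvbuf(256 * 1024, 128 * 1024,
						   1460, 2048, 6 * 1024 * 1024));
	return 0;
}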
+ */ +static void mptcp_rcv_space_adjust(struct mptcp_sock *msk, int copied) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + u32 time, advmss = 1; + u64 rtt_us, mstamp; + + sock_owned_by_me(sk); + + if (copied <= 0) + return; + + msk->rcvq_space.copied += copied; + + mstamp = div_u64(tcp_clock_ns(), NSEC_PER_USEC); + time = tcp_stamp_us_delta(mstamp, msk->rcvq_space.time); + + rtt_us = msk->rcvq_space.rtt_us; + if (rtt_us && time < (rtt_us >> 3)) + return; + + rtt_us = 0; + mptcp_for_each_subflow(msk, subflow) { + const struct tcp_sock *tp; + u64 sf_rtt_us; + u32 sf_advmss; + + tp = tcp_sk(mptcp_subflow_tcp_sock(subflow)); + + sf_rtt_us = READ_ONCE(tp->rcv_rtt_est.rtt_us); + sf_advmss = READ_ONCE(tp->advmss); + + rtt_us = max(sf_rtt_us, rtt_us); + advmss = max(sf_advmss, advmss); + } + + msk->rcvq_space.rtt_us = rtt_us; + if (time < (rtt_us >> 3) || rtt_us == 0) + return; + + if (msk->rcvq_space.copied <= msk->rcvq_space.space) + goto new_measure; + + if (READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_moderate_rcvbuf) && + !(sk->sk_userlocks & SOCK_RCVBUF_LOCK)) { + int rcvmem, rcvbuf; + u64 rcvwin, grow; + + rcvwin = ((u64)msk->rcvq_space.copied << 1) + 16 * advmss; + + grow = rcvwin * (msk->rcvq_space.copied - msk->rcvq_space.space); + + do_div(grow, msk->rcvq_space.space); + rcvwin += (grow << 1); + + rcvmem = SKB_TRUESIZE(advmss + MAX_TCP_HEADER); + while (tcp_win_from_space(sk, rcvmem) < advmss) + rcvmem += 128; + + do_div(rcvwin, advmss); + rcvbuf = min_t(u64, rcvwin * rcvmem, + READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[2])); + + if (rcvbuf > sk->sk_rcvbuf) { + u32 window_clamp; + + window_clamp = tcp_win_from_space(sk, rcvbuf); + WRITE_ONCE(sk->sk_rcvbuf, rcvbuf); + + /* Make subflows follow along. If we do not do this, we + * get drops at subflow level if skbs can't be moved to + * the mptcp rx queue fast enough (announced rcv_win can + * exceed ssk->sk_rcvbuf). + */ + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk; + bool slow; + + ssk = mptcp_subflow_tcp_sock(subflow); + slow = lock_sock_fast(ssk); + WRITE_ONCE(ssk->sk_rcvbuf, rcvbuf); + tcp_sk(ssk)->window_clamp = window_clamp; + tcp_cleanup_rbuf(ssk, 1); + unlock_sock_fast(ssk, slow); + } + } + } + + msk->rcvq_space.space = msk->rcvq_space.copied; +new_measure: + msk->rcvq_space.copied = 0; + msk->rcvq_space.time = mstamp; +} + +static void __mptcp_update_rmem(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + if (!msk->rmem_released) + return; + + atomic_sub(msk->rmem_released, &sk->sk_rmem_alloc); + mptcp_rmem_uncharge(sk, msk->rmem_released); + WRITE_ONCE(msk->rmem_released, 0); +} + +static void __mptcp_splice_receive_queue(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + skb_queue_splice_tail_init(&sk->sk_receive_queue, &msk->receive_queue); +} + +static bool __mptcp_move_skbs(struct mptcp_sock *msk) +{ + struct sock *sk = (struct sock *)msk; + unsigned int moved = 0; + bool ret, done; + + do { + struct sock *ssk = mptcp_subflow_recv_lookup(msk); + bool slowpath; + + /* we can have data pending in the subflows only if the msk + * receive buffer was full at subflow_data_ready() time, + * that is an unlikely slow path. 
+ */ + if (likely(!ssk)) + break; + + slowpath = lock_sock_fast(ssk); + mptcp_data_lock(sk); + __mptcp_update_rmem(sk); + done = __mptcp_move_skbs_from_subflow(msk, ssk, &moved); + mptcp_data_unlock(sk); + + if (unlikely(ssk->sk_err)) + __mptcp_error_report(sk); + unlock_sock_fast(ssk, slowpath); + } while (!done); + + /* acquire the data lock only if some input data is pending */ + ret = moved > 0; + if (!RB_EMPTY_ROOT(&msk->out_of_order_queue) || + !skb_queue_empty_lockless(&sk->sk_receive_queue)) { + mptcp_data_lock(sk); + __mptcp_update_rmem(sk); + ret |= __mptcp_ofo_queue(msk); + __mptcp_splice_receive_queue(sk); + mptcp_data_unlock(sk); + } + if (ret) + mptcp_check_data_fin((struct sock *)msk); + return !skb_queue_empty(&msk->receive_queue); +} + +static unsigned int mptcp_inq_hint(const struct sock *sk) +{ + const struct mptcp_sock *msk = mptcp_sk(sk); + const struct sk_buff *skb; + + skb = skb_peek(&msk->receive_queue); + if (skb) { + u64 hint_val = msk->ack_seq - MPTCP_SKB_CB(skb)->map_seq; + + if (hint_val >= INT_MAX) + return INT_MAX; + + return (unsigned int)hint_val; + } + + if (sk->sk_state == TCP_CLOSE || (sk->sk_shutdown & RCV_SHUTDOWN)) + return 1; + + return 0; +} + +static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct scm_timestamping_internal tss; + int copied = 0, cmsg_flags = 0; + int target; + long timeo; + + /* MSG_ERRQUEUE is really a no-op till we support IP_RECVERR */ + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + lock_sock(sk); + if (unlikely(sk->sk_state == TCP_LISTEN)) { + copied = -ENOTCONN; + goto out_err; + } + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + len = min_t(size_t, len, INT_MAX); + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + + if (unlikely(msk->recvmsg_inq)) + cmsg_flags = MPTCP_CMSG_INQ; + + while (copied < len) { + int bytes_read; + + bytes_read = __mptcp_recvmsg_mskq(msk, msg, len - copied, flags, &tss, &cmsg_flags); + if (unlikely(bytes_read < 0)) { + if (!copied) + copied = bytes_read; + goto out_err; + } + + copied += bytes_read; + + /* be sure to advertise window change */ + mptcp_cleanup_rbuf(msk); + + if (skb_queue_empty(&msk->receive_queue) && __mptcp_move_skbs(msk)) + continue; + + /* only the master socket status is relevant here. 
The exit + * conditions mirror closely tcp_recvmsg() + */ + if (copied >= target) + break; + + if (copied) { + if (sk->sk_err || + sk->sk_state == TCP_CLOSE || + (sk->sk_shutdown & RCV_SHUTDOWN) || + !timeo || + signal_pending(current)) + break; + } else { + if (sk->sk_err) { + copied = sock_error(sk); + break; + } + + if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags)) + mptcp_check_for_eof(msk); + + if (sk->sk_shutdown & RCV_SHUTDOWN) { + /* race breaker: the shutdown could be after the + * previous receive queue check + */ + if (__mptcp_move_skbs(msk)) + continue; + break; + } + + if (sk->sk_state == TCP_CLOSE) { + copied = -ENOTCONN; + break; + } + + if (!timeo) { + copied = -EAGAIN; + break; + } + + if (signal_pending(current)) { + copied = sock_intr_errno(timeo); + break; + } + } + + pr_debug("block timeout %ld", timeo); + sk_wait_data(sk, &timeo, NULL); + } + +out_err: + if (cmsg_flags && copied >= 0) { + if (cmsg_flags & MPTCP_CMSG_TS) + tcp_recv_timestamp(msg, sk, &tss); + + if (cmsg_flags & MPTCP_CMSG_INQ) { + unsigned int inq = mptcp_inq_hint(sk); + + put_cmsg(msg, SOL_TCP, TCP_CM_INQ, sizeof(inq), &inq); + } + } + + pr_debug("msk=%p rx queue empty=%d:%d copied=%d", + msk, skb_queue_empty_lockless(&sk->sk_receive_queue), + skb_queue_empty(&msk->receive_queue), copied); + if (!(flags & MSG_PEEK)) + mptcp_rcv_space_adjust(msk, copied); + + release_sock(sk); + return copied; +} + +static void mptcp_retransmit_timer(struct timer_list *t) +{ + struct inet_connection_sock *icsk = from_timer(icsk, t, + icsk_retransmit_timer); + struct sock *sk = &icsk->icsk_inet.sk; + struct mptcp_sock *msk = mptcp_sk(sk); + + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + /* we need a process context to retransmit */ + if (!test_and_set_bit(MPTCP_WORK_RTX, &msk->flags)) + mptcp_schedule_work(sk); + } else { + /* delegate our work to tcp_release_cb() */ + __set_bit(MPTCP_RETRANSMIT, &msk->cb_flags); + } + bh_unlock_sock(sk); + sock_put(sk); +} + +static void mptcp_tout_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + mptcp_schedule_work(sk); + sock_put(sk); +} + +/* Find an idle subflow. Return NULL if there is unacked data at tcp + * level. + * + * A backup subflow is returned only if that is the only kind available. + */ +static struct sock *mptcp_subflow_get_retrans(struct mptcp_sock *msk) +{ + struct sock *backup = NULL, *pick = NULL; + struct mptcp_subflow_context *subflow; + int min_stale_count = INT_MAX; + + sock_owned_by_me((const struct sock *)msk); + + if (__mptcp_check_fallback(msk)) + return NULL; + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + if (!__mptcp_subflow_active(subflow)) + continue; + + /* still data outstanding at TCP level? skip this */ + if (!tcp_rtx_and_write_queues_empty(ssk)) { + mptcp_pm_subflow_chk_stale(msk, ssk); + min_stale_count = min_t(int, min_stale_count, subflow->stale_count); + continue; + } + + if (subflow->backup) { + if (!backup) + backup = ssk; + continue; + } + + if (!pick) + pick = ssk; + } + + if (pick) + return pick; + + /* use backup only if there are no progresses anywhere */ + return min_stale_count > 1 ? 
backup : NULL; +} + +static void mptcp_dispose_initial_subflow(struct mptcp_sock *msk) +{ + if (msk->subflow) { + iput(SOCK_INODE(msk->subflow)); + WRITE_ONCE(msk->subflow, NULL); + } +} + +bool __mptcp_retransmit_pending_data(struct sock *sk) +{ + struct mptcp_data_frag *cur, *rtx_head; + struct mptcp_sock *msk = mptcp_sk(sk); + + if (__mptcp_check_fallback(mptcp_sk(sk))) + return false; + + if (tcp_rtx_and_write_queues_empty(sk)) + return false; + + /* the closing socket has some data untransmitted and/or unacked: + * some data in the mptcp rtx queue has not really xmitted yet. + * keep it simple and re-inject the whole mptcp level rtx queue + */ + mptcp_data_lock(sk); + __mptcp_clean_una_wakeup(sk); + rtx_head = mptcp_rtx_head(sk); + if (!rtx_head) { + mptcp_data_unlock(sk); + return false; + } + + msk->recovery_snd_nxt = msk->snd_nxt; + msk->recovery = true; + mptcp_data_unlock(sk); + + msk->first_pending = rtx_head; + msk->snd_burst = 0; + + /* be sure to clear the "sent status" on all re-injected fragments */ + list_for_each_entry(cur, &msk->rtx_queue, list) { + if (!cur->already_sent) + break; + cur->already_sent = 0; + } + + return true; +} + +/* flags for __mptcp_close_ssk() */ +#define MPTCP_CF_PUSH BIT(1) +#define MPTCP_CF_FASTCLOSE BIT(2) + +/* be sure to send a reset only if the caller asked for it, also + * clean completely the subflow status when the subflow reaches + * TCP_CLOSE state + */ +static void __mptcp_subflow_disconnect(struct sock *ssk, + struct mptcp_subflow_context *subflow, + unsigned int flags) +{ + if (((1 << ssk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)) || + (flags & MPTCP_CF_FASTCLOSE)) { + /* The MPTCP code never wait on the subflow sockets, TCP-level + * disconnect should never fail + */ + WARN_ON_ONCE(tcp_disconnect(ssk, 0)); + mptcp_subflow_ctx_reset(subflow); + } else { + tcp_shutdown(ssk, SEND_SHUTDOWN); + } +} + +/* subflow sockets can be either outgoing (connect) or incoming + * (accept). + * + * Outgoing subflows use in-kernel sockets. + * Incoming subflows do not have their own 'struct socket' allocated, + * so we need to use tcp_close() after detaching them from the mptcp + * parent socket. + */ +static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, + struct mptcp_subflow_context *subflow, + unsigned int flags) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + bool dispose_it, need_push = false; + + /* If the first subflow moved to a close state before accept, e.g. due + * to an incoming reset or listener shutdown, the subflow socket is + * already deleted by inet_child_forget() and the mptcp socket can't + * survive too. 
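+ * Mark the msk as dead and let mptcp_worker() dispose of it.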
+ */ + if (msk->in_accept_queue && msk->first == ssk && + (sock_flag(sk, SOCK_DEAD) || sock_flag(ssk, SOCK_DEAD))) { + /* ensure later check in mptcp_worker() will dispose the msk */ + mptcp_set_close_tout(sk, tcp_jiffies32 - (TCP_TIMEWAIT_LEN + 1)); + sock_set_flag(sk, SOCK_DEAD); + lock_sock_nested(ssk, SINGLE_DEPTH_NESTING); + mptcp_subflow_drop_ctx(ssk); + goto out_release; + } + + dispose_it = !msk->subflow || ssk != msk->subflow->sk; + if (dispose_it) + list_del(&subflow->node); + + lock_sock_nested(ssk, SINGLE_DEPTH_NESTING); + + if ((flags & MPTCP_CF_FASTCLOSE) && !__mptcp_check_fallback(msk)) { + /* be sure to force the tcp_close path + * to generate the egress reset + */ + ssk->sk_lingertime = 0; + sock_set_flag(ssk, SOCK_LINGER); + subflow->send_fastclose = 1; + } + + need_push = (flags & MPTCP_CF_PUSH) && __mptcp_retransmit_pending_data(sk); + if (!dispose_it) { + __mptcp_subflow_disconnect(ssk, subflow, flags); + msk->subflow->state = SS_UNCONNECTED; + release_sock(ssk); + + goto out; + } + + subflow->disposable = 1; + + /* if ssk hit tcp_done(), tcp_cleanup_ulp() cleared the related ops + * the ssk has been already destroyed, we just need to release the + * reference owned by msk; + */ + if (!inet_csk(ssk)->icsk_ulp_ops) { + WARN_ON_ONCE(!sock_flag(ssk, SOCK_DEAD)); + kfree_rcu(subflow, rcu); + } else { + /* otherwise tcp will dispose of the ssk and subflow ctx */ + __tcp_close(ssk, 0); + + /* close acquired an extra ref */ + __sock_put(ssk); + } + +out_release: + __mptcp_subflow_error_report(sk, ssk); + release_sock(ssk); + + sock_put(ssk); + + if (ssk == msk->first) + WRITE_ONCE(msk->first, NULL); + +out: + if (ssk == msk->last_snd) + msk->last_snd = NULL; + + if (need_push) + __mptcp_push_pending(sk, 0); + + /* Catch every 'all subflows closed' scenario, including peers silently + * closing them, e.g. due to timeout. + * For established sockets, allow an additional timeout before closing, + * as the protocol can still create more subflows. 
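+ * Such timeout is armed below via mptcp_start_tout_timer().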
+ */ + if (list_is_singular(&msk->conn_list) && msk->first && + inet_sk_state_load(msk->first) == TCP_CLOSE) { + if (sk->sk_state != TCP_ESTABLISHED || + msk->in_accept_queue || sock_flag(sk, SOCK_DEAD)) { + inet_sk_state_store(sk, TCP_CLOSE); + mptcp_close_wake_up(sk); + } else { + mptcp_start_tout_timer(sk); + } + } +} + +void mptcp_close_ssk(struct sock *sk, struct sock *ssk, + struct mptcp_subflow_context *subflow) +{ + if (sk->sk_state == TCP_ESTABLISHED) + mptcp_event(MPTCP_EVENT_SUB_CLOSED, mptcp_sk(sk), ssk, GFP_KERNEL); + + /* subflow aborted before reaching the fully_established status + * attempt the creation of the next subflow + */ + mptcp_pm_subflow_check_next(mptcp_sk(sk), ssk, subflow); + + __mptcp_close_ssk(sk, ssk, subflow, MPTCP_CF_PUSH); +} + +static unsigned int mptcp_sync_mss(struct sock *sk, u32 pmtu) +{ + return 0; +} + +static void __mptcp_close_subflow(struct sock *sk) +{ + struct mptcp_subflow_context *subflow, *tmp; + struct mptcp_sock *msk = mptcp_sk(sk); + + might_sleep(); + + mptcp_for_each_subflow_safe(msk, subflow, tmp) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + if (inet_sk_state_load(ssk) != TCP_CLOSE) + continue; + + /* 'subflow_data_ready' will re-sched once rx queue is empty */ + if (!skb_queue_empty_lockless(&ssk->sk_receive_queue)) + continue; + + mptcp_close_ssk(sk, ssk, subflow); + } + +} + +static bool mptcp_close_tout_expired(const struct sock *sk) +{ + if (!inet_csk(sk)->icsk_mtup.probe_timestamp || + sk->sk_state == TCP_CLOSE) + return false; + + return time_after32(tcp_jiffies32, + inet_csk(sk)->icsk_mtup.probe_timestamp + TCP_TIMEWAIT_LEN); +} + +static void mptcp_check_fastclose(struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow, *tmp; + struct sock *sk = &msk->sk.icsk_inet.sk; + + if (likely(!READ_ONCE(msk->rcv_fastclose))) + return; + + mptcp_token_destroy(msk); + + mptcp_for_each_subflow_safe(msk, subflow, tmp) { + struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow); + bool slow; + + slow = lock_sock_fast(tcp_sk); + if (tcp_sk->sk_state != TCP_CLOSE) { + tcp_send_active_reset(tcp_sk, GFP_ATOMIC); + tcp_set_state(tcp_sk, TCP_CLOSE); + } + unlock_sock_fast(tcp_sk, slow); + } + + /* Mirror the tcp_reset() error propagation */ + switch (sk->sk_state) { + case TCP_SYN_SENT: + WRITE_ONCE(sk->sk_err, ECONNREFUSED); + break; + case TCP_CLOSE_WAIT: + WRITE_ONCE(sk->sk_err, EPIPE); + break; + case TCP_CLOSE: + return; + default: + WRITE_ONCE(sk->sk_err, ECONNRESET); + } + + inet_sk_state_store(sk, TCP_CLOSE); + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); + smp_mb__before_atomic(); /* SHUTDOWN must be visible first */ + set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags); + + /* the calling mptcp_worker will properly destroy the socket */ + if (sock_flag(sk, SOCK_DEAD)) + return; + + sk->sk_state_change(sk); + sk_error_report(sk); +} + +static void __mptcp_retrans(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_sendmsg_info info = {}; + struct mptcp_data_frag *dfrag; + size_t copied = 0; + struct sock *ssk; + int ret; + + mptcp_clean_una_wakeup(sk); + + /* first check ssk: need to kick "stale" logic */ + ssk = mptcp_subflow_get_retrans(msk); + dfrag = mptcp_rtx_head(sk); + if (!dfrag) { + if (mptcp_data_fin_enabled(msk)) { + struct inet_connection_sock *icsk = inet_csk(sk); + + icsk->icsk_retransmits++; + mptcp_set_datafin_timeout(sk); + mptcp_send_ack(msk); + + goto reset_timer; + } + + if (!mptcp_send_head(sk)) + return; + + goto reset_timer; + } + + if (!ssk) + goto reset_timer; + + 
lock_sock(ssk); + + /* limit retransmission to the bytes already sent on some subflows */ + info.sent = 0; + info.limit = READ_ONCE(msk->csum_enabled) ? dfrag->data_len : dfrag->already_sent; + while (info.sent < info.limit) { + ret = mptcp_sendmsg_frag(sk, ssk, dfrag, &info); + if (ret <= 0) + break; + + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RETRANSSEGS); + copied += ret; + info.sent += ret; + } + if (copied) { + dfrag->already_sent = max(dfrag->already_sent, info.sent); + tcp_push(ssk, 0, info.mss_now, tcp_sk(ssk)->nonagle, + info.size_goal); + WRITE_ONCE(msk->allow_infinite_fallback, false); + } + + release_sock(ssk); + +reset_timer: + mptcp_check_and_set_pending(sk); + + if (!mptcp_rtx_timer_pending(sk)) + mptcp_reset_rtx_timer(sk); +} + +/* schedule the timeout timer for the relevant event: either close timeout + * or mp_fail timeout. The close timeout takes precedence on the mp_fail one + */ +void mptcp_reset_tout_timer(struct mptcp_sock *msk, unsigned long fail_tout) +{ + struct sock *sk = (struct sock *)msk; + unsigned long timeout, close_timeout; + + if (!fail_tout && !inet_csk(sk)->icsk_mtup.probe_timestamp) + return; + + close_timeout = inet_csk(sk)->icsk_mtup.probe_timestamp - tcp_jiffies32 + jiffies + + TCP_TIMEWAIT_LEN; + + /* the close timeout takes precedence on the fail one, and here at least one of + * them is active + */ + timeout = inet_csk(sk)->icsk_mtup.probe_timestamp ? close_timeout : fail_tout; + + sk_reset_timer(sk, &sk->sk_timer, timeout); +} + +static void mptcp_mp_fail_no_response(struct mptcp_sock *msk) +{ + struct sock *ssk = msk->first; + bool slow; + + if (!ssk) + return; + + pr_debug("MP_FAIL doesn't respond, reset the subflow"); + + slow = lock_sock_fast(ssk); + mptcp_subflow_reset(ssk); + WRITE_ONCE(mptcp_subflow_ctx(ssk)->fail_tout, 0); + unlock_sock_fast(ssk, slow); +} + +static void mptcp_do_fastclose(struct sock *sk) +{ + struct mptcp_subflow_context *subflow, *tmp; + struct mptcp_sock *msk = mptcp_sk(sk); + + mptcp_for_each_subflow_safe(msk, subflow, tmp) + __mptcp_close_ssk(sk, mptcp_subflow_tcp_sock(subflow), + subflow, MPTCP_CF_FASTCLOSE); +} + +static void mptcp_worker(struct work_struct *work) +{ + struct mptcp_sock *msk = container_of(work, struct mptcp_sock, work); + struct sock *sk = &msk->sk.icsk_inet.sk; + unsigned long fail_tout; + int state; + + lock_sock(sk); + state = sk->sk_state; + if (unlikely((1 << state) & (TCPF_CLOSE | TCPF_LISTEN))) + goto unlock; + + mptcp_check_fastclose(msk); + + mptcp_pm_nl_work(msk); + + if (test_and_clear_bit(MPTCP_WORK_EOF, &msk->flags)) + mptcp_check_for_eof(msk); + + mptcp_check_send_data_fin(sk); + mptcp_check_data_fin_ack(sk); + mptcp_check_data_fin(sk); + + if (test_and_clear_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags)) + __mptcp_close_subflow(sk); + + if (mptcp_close_tout_expired(sk)) { + inet_sk_state_store(sk, TCP_CLOSE); + mptcp_do_fastclose(sk); + mptcp_close_wake_up(sk); + } + + if (sock_flag(sk, SOCK_DEAD) && sk->sk_state == TCP_CLOSE) { + __mptcp_destroy_sock(sk); + goto unlock; + } + + if (test_and_clear_bit(MPTCP_WORK_RTX, &msk->flags)) + __mptcp_retrans(sk); + + fail_tout = msk->first ? 
READ_ONCE(mptcp_subflow_ctx(msk->first)->fail_tout) : 0; + if (fail_tout && time_after(jiffies, fail_tout)) + mptcp_mp_fail_no_response(msk); + +unlock: + release_sock(sk); + sock_put(sk); +} + +static int __mptcp_init_sock(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + INIT_LIST_HEAD(&msk->conn_list); + INIT_LIST_HEAD(&msk->join_list); + INIT_LIST_HEAD(&msk->rtx_queue); + INIT_WORK(&msk->work, mptcp_worker); + __skb_queue_head_init(&msk->receive_queue); + msk->out_of_order_queue = RB_ROOT; + msk->first_pending = NULL; + msk->rmem_fwd_alloc = 0; + WRITE_ONCE(msk->rmem_released, 0); + msk->timer_ival = TCP_RTO_MIN; + + WRITE_ONCE(msk->first, NULL); + inet_csk(sk)->icsk_sync_mss = mptcp_sync_mss; + WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk))); + WRITE_ONCE(msk->allow_infinite_fallback, true); + msk->recovery = false; + + mptcp_pm_data_init(msk); + + /* re-use the csk retrans timer for MPTCP-level retrans */ + timer_setup(&msk->sk.icsk_retransmit_timer, mptcp_retransmit_timer, 0); + timer_setup(&sk->sk_timer, mptcp_tout_timer, 0); + + return 0; +} + +static void mptcp_ca_reset(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + tcp_assign_congestion_control(sk); + strcpy(mptcp_sk(sk)->ca_name, icsk->icsk_ca_ops->name); + + /* no need to keep a reference to the ops, the name will suffice */ + tcp_cleanup_congestion_control(sk); + icsk->icsk_ca_ops = NULL; +} + +static int mptcp_init_sock(struct sock *sk) +{ + struct net *net = sock_net(sk); + int ret; + + ret = __mptcp_init_sock(sk); + if (ret) + return ret; + + if (!mptcp_is_enabled(net)) + return -ENOPROTOOPT; + + if (unlikely(!net->mib.mptcp_statistics) && !mptcp_mib_alloc(net)) + return -ENOMEM; + + ret = __mptcp_socket_create(mptcp_sk(sk)); + if (ret) + return ret; + + /* fetch the ca name; do it outside __mptcp_init_sock(), so that clone will + * propagate the correct value + */ + mptcp_ca_reset(sk); + + sk_sockets_allocated_inc(sk); + sk->sk_rcvbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_rmem[1]); + sk->sk_sndbuf = READ_ONCE(sock_net(sk)->ipv4.sysctl_tcp_wmem[1]); + + return 0; +} + +static void __mptcp_clear_xmit(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_data_frag *dtmp, *dfrag; + + WRITE_ONCE(msk->first_pending, NULL); + list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) + dfrag_clear(sk, dfrag); +} + +void mptcp_cancel_work(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + if (cancel_work_sync(&msk->work)) + __sock_put(sk); +} + +void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how) +{ + lock_sock(ssk); + + switch (ssk->sk_state) { + case TCP_LISTEN: + if (!(how & RCV_SHUTDOWN)) + break; + fallthrough; + case TCP_SYN_SENT: + WARN_ON_ONCE(tcp_disconnect(ssk, O_NONBLOCK)); + break; + default: + if (__mptcp_check_fallback(mptcp_sk(sk))) { + pr_debug("Fallback"); + ssk->sk_shutdown |= how; + tcp_shutdown(ssk, how); + + /* simulate the data_fin ack reception to let the state + * machine move forward + */ + WRITE_ONCE(mptcp_sk(sk)->snd_una, mptcp_sk(sk)->snd_nxt); + mptcp_schedule_work(sk); + } else { + pr_debug("Sending DATA_FIN on subflow %p", ssk); + tcp_send_ack(ssk); + if (!mptcp_rtx_timer_pending(sk)) + mptcp_reset_rtx_timer(sk); + } + break; + } + + release_sock(ssk); +} + +static const unsigned char new_state[16] = { + /* current state: new state: action: */ + [0 /* (Invalid) */] = TCP_CLOSE, + [TCP_ESTABLISHED] = TCP_FIN_WAIT1 | TCP_ACTION_FIN, + [TCP_SYN_SENT] = TCP_CLOSE, + 
[TCP_SYN_RECV] = TCP_FIN_WAIT1 | TCP_ACTION_FIN, + [TCP_FIN_WAIT1] = TCP_FIN_WAIT1, + [TCP_FIN_WAIT2] = TCP_FIN_WAIT2, + [TCP_TIME_WAIT] = TCP_CLOSE, /* should not happen ! */ + [TCP_CLOSE] = TCP_CLOSE, + [TCP_CLOSE_WAIT] = TCP_LAST_ACK | TCP_ACTION_FIN, + [TCP_LAST_ACK] = TCP_LAST_ACK, + [TCP_LISTEN] = TCP_CLOSE, + [TCP_CLOSING] = TCP_CLOSING, + [TCP_NEW_SYN_RECV] = TCP_CLOSE, /* should not happen ! */ +}; + +static int mptcp_close_state(struct sock *sk) +{ + int next = (int)new_state[sk->sk_state]; + int ns = next & TCP_STATE_MASK; + + inet_sk_state_store(sk, ns); + + return next & TCP_ACTION_FIN; +} + +static void mptcp_check_send_data_fin(struct sock *sk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + + pr_debug("msk=%p snd_data_fin_enable=%d pending=%d snd_nxt=%llu write_seq=%llu", + msk, msk->snd_data_fin_enable, !!mptcp_send_head(sk), + msk->snd_nxt, msk->write_seq); + + /* we still need to enqueue subflows or not really shutting down, + * skip this + */ + if (!msk->snd_data_fin_enable || msk->snd_nxt + 1 != msk->write_seq || + mptcp_send_head(sk)) + return; + + WRITE_ONCE(msk->snd_nxt, msk->write_seq); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow); + + mptcp_subflow_shutdown(sk, tcp_sk, SEND_SHUTDOWN); + } +} + +static void __mptcp_wr_shutdown(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + pr_debug("msk=%p snd_data_fin_enable=%d shutdown=%x state=%d pending=%d", + msk, msk->snd_data_fin_enable, sk->sk_shutdown, sk->sk_state, + !!mptcp_send_head(sk)); + + /* will be ignored by fallback sockets */ + WRITE_ONCE(msk->write_seq, msk->write_seq + 1); + WRITE_ONCE(msk->snd_data_fin_enable, 1); + + mptcp_check_send_data_fin(sk); +} + +static void __mptcp_destroy_sock(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + pr_debug("msk=%p", msk); + + might_sleep(); + + mptcp_stop_rtx_timer(sk); + sk_stop_timer(sk, &sk->sk_timer); + msk->pm.status = 0; + + sk->sk_prot->destroy(sk); + + WARN_ON_ONCE(msk->rmem_fwd_alloc); + WARN_ON_ONCE(msk->rmem_released); + sk_stream_kill_queues(sk); + xfrm_sk_free_policy(sk); + + sk_refcnt_debug_release(sk); + sock_put(sk); +} + +void __mptcp_unaccepted_force_close(struct sock *sk) +{ + sock_set_flag(sk, SOCK_DEAD); + inet_sk_state_store(sk, TCP_CLOSE); + mptcp_do_fastclose(sk); + __mptcp_destroy_sock(sk); +} + +static __poll_t mptcp_check_readable(struct mptcp_sock *msk) +{ + /* Concurrent splices from sk_receive_queue into receive_queue will + * always show at least one non-empty queue when checked in this order. 
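+ * (the splice itself runs under the msk data lock, see
+ * __mptcp_splice_receive_queue()).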
+ */ + if (skb_queue_empty_lockless(&((struct sock *)msk)->sk_receive_queue) && + skb_queue_empty_lockless(&msk->receive_queue)) + return 0; + + return EPOLLIN | EPOLLRDNORM; +} + +static void mptcp_check_listen_stop(struct sock *sk) +{ + struct sock *ssk; + + if (inet_sk_state_load(sk) != TCP_LISTEN) + return; + + ssk = mptcp_sk(sk)->first; + if (WARN_ON_ONCE(!ssk || inet_sk_state_load(ssk) != TCP_LISTEN)) + return; + + lock_sock_nested(ssk, SINGLE_DEPTH_NESTING); + tcp_set_state(ssk, TCP_CLOSE); + mptcp_subflow_queue_clean(sk, ssk); + inet_csk_listen_stop(ssk); + release_sock(ssk); +} + +bool __mptcp_close(struct sock *sk, long timeout) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + bool do_cancel_work = false; + int subflows_alive = 0; + + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); + + if ((1 << sk->sk_state) & (TCPF_LISTEN | TCPF_CLOSE)) { + mptcp_check_listen_stop(sk); + inet_sk_state_store(sk, TCP_CLOSE); + goto cleanup; + } + + if (mptcp_check_readable(msk)) { + /* the msk has read data, do the MPTCP equivalent of TCP reset */ + inet_sk_state_store(sk, TCP_CLOSE); + mptcp_do_fastclose(sk); + } else if (mptcp_close_state(sk)) { + __mptcp_wr_shutdown(sk); + } + + sk_stream_wait_close(sk, timeout); + +cleanup: + /* orphan all the subflows */ + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast_nested(ssk); + + subflows_alive += ssk->sk_state != TCP_CLOSE; + + /* since the close timeout takes precedence on the fail one, + * cancel the latter + */ + if (ssk == msk->first) + subflow->fail_tout = 0; + + /* detach from the parent socket, but allow data_ready to + * push incoming data into the mptcp stack, to properly ack it + */ + ssk->sk_socket = NULL; + ssk->sk_wq = NULL; + unlock_sock_fast(ssk, slow); + } + sock_orphan(sk); + + /* all the subflows are closed, only timeout can change the msk + * state, let's not keep resources busy for no reasons + */ + if (subflows_alive == 0) + inet_sk_state_store(sk, TCP_CLOSE); + + sock_hold(sk); + pr_debug("msk=%p state=%d", sk, sk->sk_state); + if (mptcp_sk(sk)->token) + mptcp_event(MPTCP_EVENT_CLOSED, msk, NULL, GFP_KERNEL); + + if (sk->sk_state == TCP_CLOSE) { + __mptcp_destroy_sock(sk); + do_cancel_work = true; + } else { + mptcp_start_tout_timer(sk); + } + + return do_cancel_work; +} + +static void mptcp_close(struct sock *sk, long timeout) +{ + bool do_cancel_work; + + lock_sock(sk); + + do_cancel_work = __mptcp_close(sk, timeout); + release_sock(sk); + if (do_cancel_work) + mptcp_cancel_work(sk); + + sock_put(sk); +} + +static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk) +{ +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + const struct ipv6_pinfo *ssk6 = inet6_sk(ssk); + struct ipv6_pinfo *msk6 = inet6_sk(msk); + + msk->sk_v6_daddr = ssk->sk_v6_daddr; + msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr; + + if (msk6 && ssk6) { + msk6->saddr = ssk6->saddr; + msk6->flow_label = ssk6->flow_label; + } +#endif + + inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num; + inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport; + inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport; + inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr; + inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr; + inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr; +} + +static int mptcp_disconnect(struct sock *sk, int flags) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + /* We are on the fastopen error path. 
We can't call straight into the + * subflows cleanup code due to lock nesting (we are already under + * msk->firstsocket lock). + */ + if (msk->fastopening) + return -EBUSY; + + mptcp_check_listen_stop(sk); + inet_sk_state_store(sk, TCP_CLOSE); + + mptcp_stop_rtx_timer(sk); + mptcp_stop_tout_timer(sk); + + if (mptcp_sk(sk)->token) + mptcp_event(MPTCP_EVENT_CLOSED, mptcp_sk(sk), NULL, GFP_KERNEL); + + /* msk->subflow is still intact, the following will not free the first + * subflow + */ + mptcp_destroy_common(msk, MPTCP_CF_FASTCLOSE); + msk->last_snd = NULL; + WRITE_ONCE(msk->flags, 0); + msk->cb_flags = 0; + msk->push_pending = 0; + msk->recovery = false; + msk->can_ack = false; + msk->fully_established = false; + msk->rcv_data_fin = false; + msk->snd_data_fin_enable = false; + msk->rcv_fastclose = false; + msk->use_64bit_ack = false; + WRITE_ONCE(msk->csum_enabled, mptcp_is_checksum_enabled(sock_net(sk))); + mptcp_pm_data_reset(msk); + mptcp_ca_reset(sk); + + WRITE_ONCE(sk->sk_shutdown, 0); + sk_error_report(sk); + return 0; +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +static struct ipv6_pinfo *mptcp_inet6_sk(const struct sock *sk) +{ + unsigned int offset = sizeof(struct mptcp6_sock) - sizeof(struct ipv6_pinfo); + + return (struct ipv6_pinfo *)(((u8 *)sk) + offset); +} +#endif + +struct sock *mptcp_sk_clone_init(const struct sock *sk, + const struct mptcp_options_received *mp_opt, + struct sock *ssk, + struct request_sock *req) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct sock *nsk = sk_clone_lock(sk, GFP_ATOMIC); + struct mptcp_sock *msk; + u64 ack_seq; + + if (!nsk) + return NULL; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (nsk->sk_family == AF_INET6) + inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk); +#endif + + __mptcp_init_sock(nsk); + + msk = mptcp_sk(nsk); + msk->local_key = subflow_req->local_key; + msk->token = subflow_req->token; + WRITE_ONCE(msk->subflow, NULL); + msk->in_accept_queue = 1; + WRITE_ONCE(msk->fully_established, false); + if (mp_opt->suboptions & OPTION_MPTCP_CSUMREQD) + WRITE_ONCE(msk->csum_enabled, true); + + msk->write_seq = subflow_req->idsn + 1; + msk->snd_nxt = msk->write_seq; + msk->snd_una = msk->write_seq; + msk->wnd_end = msk->snd_nxt + req->rsk_rcv_wnd; + msk->setsockopt_seq = mptcp_sk(sk)->setsockopt_seq; + + if (mp_opt->suboptions & OPTIONS_MPTCP_MPC) { + msk->can_ack = true; + msk->remote_key = mp_opt->sndr_key; + mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq); + ack_seq++; + WRITE_ONCE(msk->ack_seq, ack_seq); + atomic64_set(&msk->rcv_wnd_sent, ack_seq); + } + + sock_reset_flag(nsk, SOCK_RCU_FREE); + security_inet_csk_clone(nsk, req); + + /* this can't race with mptcp_close(), as the msk is + * not yet exposted to user-space + */ + inet_sk_state_store(nsk, TCP_ESTABLISHED); + + /* The msk maintain a ref to each subflow in the connections list */ + WRITE_ONCE(msk->first, ssk); + list_add(&mptcp_subflow_ctx(ssk)->node, &msk->conn_list); + sock_hold(ssk); + + /* new mpc subflow takes ownership of the newly + * created mptcp socket + */ + mptcp_token_accept(subflow_req, msk); + + /* set msk addresses early to ensure mptcp_pm_get_local_id() + * uses the correct data + */ + mptcp_copy_inaddrs(nsk, ssk); + mptcp_propagate_sndbuf(nsk, ssk); + + mptcp_rcv_space_init(msk, ssk); + bh_unlock_sock(nsk); + + /* note: the newly allocated socket refcount is 2 now */ + return nsk; +} + +void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk) +{ + const struct tcp_sock *tp = tcp_sk(ssk); + + 
msk->rcvq_space.copied = 0; + msk->rcvq_space.rtt_us = 0; + + msk->rcvq_space.time = tp->tcp_mstamp; + + /* initial rcv_space offering made to peer */ + msk->rcvq_space.space = min_t(u32, tp->rcv_wnd, + TCP_INIT_CWND * tp->advmss); + if (msk->rcvq_space.space == 0) + msk->rcvq_space.space = TCP_INIT_CWND * TCP_MSS_DEFAULT; + + WRITE_ONCE(msk->wnd_end, msk->snd_nxt + tcp_sk(ssk)->snd_wnd); +} + +static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, + bool kern) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct socket *listener; + struct sock *newsk; + + listener = READ_ONCE(msk->subflow); + if (WARN_ON_ONCE(!listener)) { + *err = -EINVAL; + return NULL; + } + + pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk)); + newsk = inet_csk_accept(listener->sk, flags, err, kern); + if (!newsk) + return NULL; + + pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk)); + if (sk_is_mptcp(newsk)) { + struct mptcp_subflow_context *subflow; + struct sock *new_mptcp_sock; + + subflow = mptcp_subflow_ctx(newsk); + new_mptcp_sock = subflow->conn; + + /* is_mptcp should be false if subflow->conn is missing, see + * subflow_syn_recv_sock() + */ + if (WARN_ON_ONCE(!new_mptcp_sock)) { + tcp_sk(newsk)->is_mptcp = 0; + goto out; + } + + newsk = new_mptcp_sock; + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEPASSIVEACK); + } else { + MPTCP_INC_STATS(sock_net(sk), + MPTCP_MIB_MPCAPABLEPASSIVEFALLBACK); + } + +out: + newsk->sk_kern_sock = kern; + return newsk; +} + +void mptcp_destroy_common(struct mptcp_sock *msk, unsigned int flags) +{ + struct mptcp_subflow_context *subflow, *tmp; + struct sock *sk = (struct sock *)msk; + + __mptcp_clear_xmit(sk); + + /* join list will be eventually flushed (with rst) at sock lock release time */ + mptcp_for_each_subflow_safe(msk, subflow, tmp) + __mptcp_close_ssk(sk, mptcp_subflow_tcp_sock(subflow), subflow, flags); + + /* move to sk_receive_queue, sk_stream_kill_queues will purge it */ + mptcp_data_lock(sk); + skb_queue_splice_tail_init(&msk->receive_queue, &sk->sk_receive_queue); + __skb_queue_purge(&sk->sk_receive_queue); + skb_rbtree_purge(&msk->out_of_order_queue); + mptcp_data_unlock(sk); + + /* move all the rx fwd alloc into the sk_mem_reclaim_final in + * inet_sock_destruct() will dispose it + */ + sk_forward_alloc_add(sk, msk->rmem_fwd_alloc); + WRITE_ONCE(msk->rmem_fwd_alloc, 0); + mptcp_token_destroy(msk); + mptcp_pm_free_anno_list(msk); + mptcp_free_local_addr_list(msk); +} + +static void mptcp_destroy(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + /* clears msk->subflow, allowing the following to close + * even the initial subflow + */ + mptcp_dispose_initial_subflow(msk); + mptcp_destroy_common(msk, 0); + sk_sockets_allocated_dec(sk); +} + +void __mptcp_data_acked(struct sock *sk) +{ + if (!sock_owned_by_user(sk)) + __mptcp_clean_una(sk); + else + __set_bit(MPTCP_CLEAN_UNA, &mptcp_sk(sk)->cb_flags); + + if (mptcp_pending_data_fin_ack(sk)) + mptcp_schedule_work(sk); +} + +void __mptcp_check_push(struct sock *sk, struct sock *ssk) +{ + if (!mptcp_send_head(sk)) + return; + + if (!sock_owned_by_user(sk)) { + struct sock *xmit_ssk = mptcp_subflow_get_send(mptcp_sk(sk)); + + if (xmit_ssk == ssk) + __mptcp_subflow_push_pending(sk, ssk); + else if (xmit_ssk) + mptcp_subflow_delegate(mptcp_subflow_ctx(xmit_ssk), MPTCP_DELEGATE_SEND); + } else { + __set_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->cb_flags); + } +} + +#define MPTCP_FLAGS_PROCESS_CTX_NEED (BIT(MPTCP_PUSH_PENDING) | \ + BIT(MPTCP_RETRANSMIT) | \ + 
BIT(MPTCP_FLUSH_JOIN_LIST)) + +/* processes deferred events and flush wmem */ +static void mptcp_release_cb(struct sock *sk) + __must_hold(&sk->sk_lock.slock) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + for (;;) { + unsigned long flags = (msk->cb_flags & MPTCP_FLAGS_PROCESS_CTX_NEED) | + msk->push_pending; + struct list_head join_list; + + if (!flags) + break; + + INIT_LIST_HEAD(&join_list); + list_splice_init(&msk->join_list, &join_list); + + /* the following actions acquire the subflow socket lock + * + * 1) can't be invoked in atomic scope + * 2) must avoid ABBA deadlock with msk socket spinlock: the RX + * datapath acquires the msk socket spinlock while helding + * the subflow socket lock + */ + msk->push_pending = 0; + msk->cb_flags &= ~flags; + spin_unlock_bh(&sk->sk_lock.slock); + + if (flags & BIT(MPTCP_FLUSH_JOIN_LIST)) + __mptcp_flush_join_list(sk, &join_list); + if (flags & BIT(MPTCP_PUSH_PENDING)) + __mptcp_push_pending(sk, 0); + if (flags & BIT(MPTCP_RETRANSMIT)) + __mptcp_retrans(sk); + + cond_resched(); + spin_lock_bh(&sk->sk_lock.slock); + } + + if (__test_and_clear_bit(MPTCP_CLEAN_UNA, &msk->cb_flags)) + __mptcp_clean_una_wakeup(sk); + if (unlikely(msk->cb_flags)) { + /* be sure to set the current sk state before tacking actions + * depending on sk_state, that is processing MPTCP_ERROR_REPORT + */ + if (__test_and_clear_bit(MPTCP_CONNECTED, &msk->cb_flags)) + __mptcp_set_connected(sk); + if (__test_and_clear_bit(MPTCP_ERROR_REPORT, &msk->cb_flags)) + __mptcp_error_report(sk); + if (__test_and_clear_bit(MPTCP_RESET_SCHEDULER, &msk->cb_flags)) + msk->last_snd = NULL; + } + + __mptcp_update_rmem(sk); +} + +/* MP_JOIN client subflow must wait for 4th ack before sending any data: + * TCP can't schedule delack timer before the subflow is fully established. 
+ * MPTCP uses the delack timer to do 3rd ack retransmissions + */ +static void schedule_3rdack_retransmission(struct sock *ssk) +{ + struct inet_connection_sock *icsk = inet_csk(ssk); + struct tcp_sock *tp = tcp_sk(ssk); + unsigned long timeout; + + if (mptcp_subflow_ctx(ssk)->fully_established) + return; + + /* reschedule with a timeout above RTT, as we must look only for drop */ + if (tp->srtt_us) + timeout = usecs_to_jiffies(tp->srtt_us >> (3 - 1)); + else + timeout = TCP_TIMEOUT_INIT; + timeout += jiffies; + + WARN_ON_ONCE(icsk->icsk_ack.pending & ICSK_ACK_TIMER); + icsk->icsk_ack.pending |= ICSK_ACK_SCHED | ICSK_ACK_TIMER; + icsk->icsk_ack.timeout = timeout; + sk_reset_timer(ssk, &icsk->icsk_delack_timer, timeout); +} + +void mptcp_subflow_process_delegated(struct sock *ssk, long status) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = subflow->conn; + + if (status & BIT(MPTCP_DELEGATE_SEND)) { + mptcp_data_lock(sk); + if (!sock_owned_by_user(sk)) + __mptcp_subflow_push_pending(sk, ssk); + else + __set_bit(MPTCP_PUSH_PENDING, &mptcp_sk(sk)->cb_flags); + mptcp_data_unlock(sk); + } + if (status & BIT(MPTCP_DELEGATE_ACK)) + schedule_3rdack_retransmission(ssk); +} + +static int mptcp_hash(struct sock *sk) +{ + /* should never be called, + * we hash the TCP subflows not the master socket + */ + WARN_ON_ONCE(1); + return 0; +} + +static void mptcp_unhash(struct sock *sk) +{ + /* called from sk_common_release(), but nothing to do here */ +} + +static int mptcp_get_port(struct sock *sk, unsigned short snum) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct socket *ssock; + + ssock = msk->subflow; + pr_debug("msk=%p, subflow=%p", msk, ssock); + if (WARN_ON_ONCE(!ssock)) + return -EINVAL; + + return inet_csk_get_port(ssock->sk, snum); +} + +void mptcp_finish_connect(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk; + struct sock *sk; + u64 ack_seq; + + subflow = mptcp_subflow_ctx(ssk); + sk = subflow->conn; + msk = mptcp_sk(sk); + + pr_debug("msk=%p, token=%u", sk, subflow->token); + + mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq); + ack_seq++; + subflow->map_seq = ack_seq; + subflow->map_subflow_seq = 1; + + /* the socket is not connected yet, no msk/subflow ops can access/race + * accessing the field below + */ + WRITE_ONCE(msk->remote_key, subflow->remote_key); + WRITE_ONCE(msk->local_key, subflow->local_key); + WRITE_ONCE(msk->write_seq, subflow->idsn + 1); + WRITE_ONCE(msk->snd_nxt, msk->write_seq); + WRITE_ONCE(msk->ack_seq, ack_seq); + WRITE_ONCE(msk->can_ack, 1); + WRITE_ONCE(msk->snd_una, msk->write_seq); + atomic64_set(&msk->rcv_wnd_sent, ack_seq); + + mptcp_pm_new_connection(msk, ssk, 0); + + mptcp_rcv_space_init(msk, ssk); +} + +void mptcp_sock_graft(struct sock *sk, struct socket *parent) +{ + write_lock_bh(&sk->sk_callback_lock); + rcu_assign_pointer(sk->sk_wq, &parent->wq); + sk_set_socket(sk, parent); + sk->sk_uid = SOCK_INODE(parent)->i_uid; + write_unlock_bh(&sk->sk_callback_lock); +} + +bool mptcp_finish_join(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + struct sock *parent = (void *)msk; + bool ret = true; + + pr_debug("msk=%p, subflow=%p", msk, subflow); + + /* mptcp socket already closing? 
*/ + if (!mptcp_is_fully_established(parent)) { + subflow->reset_reason = MPTCP_RST_EMPTCP; + return false; + } + + /* active subflow, already present inside the conn_list */ + if (!list_empty(&subflow->node)) { + mptcp_subflow_joined(msk, ssk); + return true; + } + + if (!mptcp_pm_allow_new_subflow(msk)) + goto err_prohibited; + + /* If we can't acquire msk socket lock here, let the release callback + * handle it + */ + mptcp_data_lock(parent); + if (!sock_owned_by_user(parent)) { + ret = __mptcp_finish_join(msk, ssk); + if (ret) { + sock_hold(ssk); + list_add_tail(&subflow->node, &msk->conn_list); + } + } else { + sock_hold(ssk); + list_add_tail(&subflow->node, &msk->join_list); + __set_bit(MPTCP_FLUSH_JOIN_LIST, &msk->cb_flags); + } + mptcp_data_unlock(parent); + + if (!ret) { +err_prohibited: + subflow->reset_reason = MPTCP_RST_EPROHIBIT; + return false; + } + + return true; +} + +static void mptcp_shutdown(struct sock *sk, int how) +{ + pr_debug("sk=%p, how=%d", sk, how); + + if ((how & SEND_SHUTDOWN) && mptcp_close_state(sk)) + __mptcp_wr_shutdown(sk); +} + +static int mptcp_forward_alloc_get(const struct sock *sk) +{ + return READ_ONCE(sk->sk_forward_alloc) + + READ_ONCE(mptcp_sk(sk)->rmem_fwd_alloc); +} + +static int mptcp_ioctl_outq(const struct mptcp_sock *msk, u64 v) +{ + const struct sock *sk = (void *)msk; + u64 delta; + + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + if ((1 << sk->sk_state) & (TCPF_SYN_SENT | TCPF_SYN_RECV)) + return 0; + + delta = msk->write_seq - v; + if (__mptcp_check_fallback(msk) && msk->first) { + struct tcp_sock *tp = tcp_sk(msk->first); + + /* the first subflow is disconnected after close - see + * __mptcp_close_ssk(). tcp_disconnect() moves the write_seq + * so ignore that status, too. + */ + if (!((1 << msk->first->sk_state) & + (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))) + delta += READ_ONCE(tp->write_seq) - tp->snd_una; + } + if (delta > INT_MAX) + delta = INT_MAX; + + return (int)delta; +} + +static int mptcp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + bool slow; + int answ; + + switch (cmd) { + case SIOCINQ: + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + lock_sock(sk); + __mptcp_move_skbs(msk); + answ = mptcp_inq_hint(sk); + release_sock(sk); + break; + case SIOCOUTQ: + slow = lock_sock_fast(sk); + answ = mptcp_ioctl_outq(msk, READ_ONCE(msk->snd_una)); + unlock_sock_fast(sk, slow); + break; + case SIOCOUTQNSD: + slow = lock_sock_fast(sk); + answ = mptcp_ioctl_outq(msk, msk->snd_nxt); + unlock_sock_fast(sk, slow); + break; + default: + return -ENOIOCTLCMD; + } + + return put_user(answ, (int __user *)arg); +} + +static void mptcp_subflow_early_fallback(struct mptcp_sock *msk, + struct mptcp_subflow_context *subflow) +{ + subflow->request_mptcp = 0; + __mptcp_do_fallback(msk); +} + +static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) +{ + struct mptcp_subflow_context *subflow; + struct mptcp_sock *msk = mptcp_sk(sk); + struct socket *ssock; + int err = -EINVAL; + + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) + return -EINVAL; + + mptcp_token_destroy(msk); + inet_sk_state_store(sk, TCP_SYN_SENT); + subflow = mptcp_subflow_ctx(ssock->sk); +#ifdef CONFIG_TCP_MD5SIG + /* no MPTCP if MD5SIG is enabled on this socket or we may run out of + * TCP option space. 
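+ * Fall back to plain TCP in that case.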
+ */ + if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info)) + mptcp_subflow_early_fallback(msk, subflow); +#endif + if (subflow->request_mptcp && mptcp_token_new_connect(ssock->sk)) { + MPTCP_INC_STATS(sock_net(ssock->sk), MPTCP_MIB_TOKENFALLBACKINIT); + mptcp_subflow_early_fallback(msk, subflow); + } + if (likely(!__mptcp_check_fallback(msk))) + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEACTIVE); + + /* if reaching here via the fastopen/sendmsg path, the caller already + * acquired the subflow socket lock, too. + */ + if (msk->fastopening) + err = __inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK, 1); + else + err = inet_stream_connect(ssock, uaddr, addr_len, O_NONBLOCK); + inet_sk(sk)->defer_connect = inet_sk(ssock->sk)->defer_connect; + + /* on successful connect, the msk state will be moved to established by + * subflow_finish_connect() + */ + if (unlikely(err && err != -EINPROGRESS)) { + inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); + return err; + } + + mptcp_copy_inaddrs(sk, ssock->sk); + + /* silence EINPROGRESS and let the caller inet_stream_connect + * handle the connection in progress + */ + return 0; +} + +static struct proto mptcp_prot = { + .name = "MPTCP", + .owner = THIS_MODULE, + .init = mptcp_init_sock, + .connect = mptcp_connect, + .disconnect = mptcp_disconnect, + .close = mptcp_close, + .accept = mptcp_accept, + .setsockopt = mptcp_setsockopt, + .getsockopt = mptcp_getsockopt, + .shutdown = mptcp_shutdown, + .destroy = mptcp_destroy, + .sendmsg = mptcp_sendmsg, + .ioctl = mptcp_ioctl, + .recvmsg = mptcp_recvmsg, + .release_cb = mptcp_release_cb, + .hash = mptcp_hash, + .unhash = mptcp_unhash, + .get_port = mptcp_get_port, + .forward_alloc_get = mptcp_forward_alloc_get, + .sockets_allocated = &mptcp_sockets_allocated, + + .memory_allocated = &tcp_memory_allocated, + .per_cpu_fw_alloc = &tcp_memory_per_cpu_fw_alloc, + + .memory_pressure = &tcp_memory_pressure, + .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_tcp_wmem), + .sysctl_rmem_offset = offsetof(struct net, ipv4.sysctl_tcp_rmem), + .sysctl_mem = sysctl_tcp_mem, + .obj_size = sizeof(struct mptcp_sock), + .slab_flags = SLAB_TYPESAFE_BY_RCU, + .no_autobind = true, +}; + +static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct socket *ssock; + int err; + + lock_sock(sock->sk); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; + goto unlock; + } + + err = ssock->ops->bind(ssock, uaddr, addr_len); + if (!err) + mptcp_copy_inaddrs(sock->sk, ssock->sk); + +unlock: + release_sock(sock->sk); + return err; +} + +static int mptcp_listen(struct socket *sock, int backlog) +{ + struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct sock *sk = sock->sk; + struct socket *ssock; + int err; + + pr_debug("msk=%p", msk); + + lock_sock(sk); + + err = -EINVAL; + if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) + goto unlock; + + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + err = -EINVAL; + goto unlock; + } + + mptcp_token_destroy(msk); + inet_sk_state_store(sk, TCP_LISTEN); + sock_set_flag(sk, SOCK_RCU_FREE); + + err = ssock->ops->listen(ssock, backlog); + inet_sk_state_store(sk, inet_sk_state_load(ssock->sk)); + if (!err) + mptcp_copy_inaddrs(sk, ssock->sk); + +unlock: + release_sock(sk); + return err; +} + +static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + struct mptcp_sock *msk = mptcp_sk(sock->sk); + struct socket *ssock; + 
int err; + + pr_debug("msk=%p", msk); + + /* Buggy applications can call accept on socket states other then LISTEN + * but no need to allocate the first subflow just to error out. + */ + ssock = READ_ONCE(msk->subflow); + if (!ssock) + return -EINVAL; + + err = ssock->ops->accept(sock, newsock, flags, kern); + if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) { + struct mptcp_sock *msk = mptcp_sk(newsock->sk); + struct mptcp_subflow_context *subflow; + struct sock *newsk = newsock->sk; + + msk->in_accept_queue = 0; + + lock_sock(newsk); + + /* set ssk->sk_socket of accept()ed flows to mptcp socket. + * This is needed so NOSPACE flag can be set from tcp stack. + */ + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + if (!ssk->sk_socket) + mptcp_sock_graft(ssk, newsock); + } + + /* Do late cleanup for the first subflow as necessary. Also + * deal with bad peers not doing a complete shutdown. + */ + if (msk->first && + unlikely(inet_sk_state_load(msk->first) == TCP_CLOSE)) { + __mptcp_close_ssk(newsk, msk->first, + mptcp_subflow_ctx(msk->first), 0); + if (unlikely(list_empty(&msk->conn_list))) + inet_sk_state_store(newsk, TCP_CLOSE); + } + + release_sock(newsk); + } + + return err; +} + +static __poll_t mptcp_check_writeable(struct mptcp_sock *msk) +{ + struct sock *sk = (struct sock *)msk; + + if (sk_stream_is_writeable(sk)) + return EPOLLOUT | EPOLLWRNORM; + + mptcp_set_nospace(sk); + smp_mb__after_atomic(); /* msk->flags is changed by write_space cb */ + if (sk_stream_is_writeable(sk)) + return EPOLLOUT | EPOLLWRNORM; + + return 0; +} + +static __poll_t mptcp_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) +{ + struct sock *sk = sock->sk; + struct mptcp_sock *msk; + __poll_t mask = 0; + u8 shutdown; + int state; + + msk = mptcp_sk(sk); + sock_poll_wait(file, sock, wait); + + state = inet_sk_state_load(sk); + pr_debug("msk=%p state=%d flags=%lx", msk, state, msk->flags); + if (state == TCP_LISTEN) { + struct socket *ssock = READ_ONCE(msk->subflow); + + if (WARN_ON_ONCE(!ssock || !ssock->sk)) + return 0; + + return inet_csk_listen_poll(ssock->sk); + } + + shutdown = READ_ONCE(sk->sk_shutdown); + if (shutdown == SHUTDOWN_MASK || state == TCP_CLOSE) + mask |= EPOLLHUP; + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; + + if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) { + mask |= mptcp_check_readable(msk); + if (shutdown & SEND_SHUTDOWN) + mask |= EPOLLOUT | EPOLLWRNORM; + else + mask |= mptcp_check_writeable(msk); + } else if (state == TCP_SYN_SENT && inet_sk(sk)->defer_connect) { + /* cf tcp_poll() note about TFO */ + mask |= EPOLLOUT | EPOLLWRNORM; + } + + /* This barrier is coupled with smp_wmb() in __mptcp_error_report() */ + smp_rmb(); + if (READ_ONCE(sk->sk_err)) + mask |= EPOLLERR; + + return mask; +} + +static const struct proto_ops mptcp_stream_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, + .bind = mptcp_bind, + .connect = inet_stream_connect, + .socketpair = sock_no_socketpair, + .accept = mptcp_stream_accept, + .getname = inet_getname, + .poll = mptcp_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = mptcp_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = inet_recvmsg, + .mmap = sock_no_mmap, + .sendpage = inet_sendpage, +}; + +static struct inet_protosw mptcp_protosw = { + .type = SOCK_STREAM, + .protocol = 
IPPROTO_MPTCP, + .prot = &mptcp_prot, + .ops = &mptcp_stream_ops, + .flags = INET_PROTOSW_ICSK, +}; + +static int mptcp_napi_poll(struct napi_struct *napi, int budget) +{ + struct mptcp_delegated_action *delegated; + struct mptcp_subflow_context *subflow; + int work_done = 0; + + delegated = container_of(napi, struct mptcp_delegated_action, napi); + while ((subflow = mptcp_subflow_delegated_next(delegated)) != NULL) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + bh_lock_sock_nested(ssk); + if (!sock_owned_by_user(ssk)) { + mptcp_subflow_process_delegated(ssk, xchg(&subflow->delegated_status, 0)); + } else { + /* tcp_release_cb_override already processed + * the action or will do at next release_sock(). + * In both case must dequeue the subflow here - on the same + * CPU that scheduled it. + */ + smp_wmb(); + clear_bit(MPTCP_DELEGATE_SCHEDULED, &subflow->delegated_status); + } + bh_unlock_sock(ssk); + sock_put(ssk); + + if (++work_done == budget) + return budget; + } + + /* always provide a 0 'work_done' argument, so that napi_complete_done + * will not try accessing the NULL napi->dev ptr + */ + napi_complete_done(napi, 0); + return work_done; +} + +void __init mptcp_proto_init(void) +{ + struct mptcp_delegated_action *delegated; + int cpu; + + mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo; + + if (percpu_counter_init(&mptcp_sockets_allocated, 0, GFP_KERNEL)) + panic("Failed to allocate MPTCP pcpu counter\n"); + + init_dummy_netdev(&mptcp_napi_dev); + for_each_possible_cpu(cpu) { + delegated = per_cpu_ptr(&mptcp_delegated_actions, cpu); + INIT_LIST_HEAD(&delegated->head); + netif_napi_add_tx(&mptcp_napi_dev, &delegated->napi, + mptcp_napi_poll); + napi_enable(&delegated->napi); + } + + mptcp_subflow_init(); + mptcp_pm_init(); + mptcp_token_init(); + + if (proto_register(&mptcp_prot, 1) != 0) + panic("Failed to register MPTCP proto.\n"); + + inet_register_protosw(&mptcp_protosw); + + BUILD_BUG_ON(sizeof(struct mptcp_skb_cb) > sizeof_field(struct sk_buff, cb)); +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +static const struct proto_ops mptcp_v6_stream_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = mptcp_bind, + .connect = inet_stream_connect, + .socketpair = sock_no_socketpair, + .accept = mptcp_stream_accept, + .getname = inet6_getname, + .poll = mptcp_poll, + .ioctl = inet6_ioctl, + .gettstamp = sock_gettstamp, + .listen = mptcp_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet6_sendmsg, + .recvmsg = inet6_recvmsg, + .mmap = sock_no_mmap, + .sendpage = inet_sendpage, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static struct proto mptcp_v6_prot; + +static struct inet_protosw mptcp_v6_protosw = { + .type = SOCK_STREAM, + .protocol = IPPROTO_MPTCP, + .prot = &mptcp_v6_prot, + .ops = &mptcp_v6_stream_ops, + .flags = INET_PROTOSW_ICSK, +}; + +int __init mptcp_proto_v6_init(void) +{ + int err; + + mptcp_v6_prot = mptcp_prot; + strcpy(mptcp_v6_prot.name, "MPTCPv6"); + mptcp_v6_prot.slab = NULL; + mptcp_v6_prot.obj_size = sizeof(struct mptcp6_sock); + + err = proto_register(&mptcp_v6_prot, 1); + if (err) + return err; + + err = inet6_register_protosw(&mptcp_v6_protosw); + if (err) + proto_unregister(&mptcp_v6_prot); + + return err; +} +#endif diff --git a/net/mptcp/protocol.h b/net/mptcp/protocol.h new file mode 100644 index 000000000..4ec8e0a81 --- /dev/null +++ b/net/mptcp/protocol.h @@ -0,0 +1,1038 @@ +/* SPDX-License-Identifier: 
GPL-2.0 */ +/* Multipath TCP + * + * Copyright (c) 2017 - 2019, Intel Corporation. + */ + +#ifndef __MPTCP_PROTOCOL_H +#define __MPTCP_PROTOCOL_H + +#include +#include +#include +#include +#include + +#define MPTCP_SUPPORTED_VERSION 1 + +/* MPTCP option bits */ +#define OPTION_MPTCP_MPC_SYN BIT(0) +#define OPTION_MPTCP_MPC_SYNACK BIT(1) +#define OPTION_MPTCP_MPC_ACK BIT(2) +#define OPTION_MPTCP_MPJ_SYN BIT(3) +#define OPTION_MPTCP_MPJ_SYNACK BIT(4) +#define OPTION_MPTCP_MPJ_ACK BIT(5) +#define OPTION_MPTCP_ADD_ADDR BIT(6) +#define OPTION_MPTCP_RM_ADDR BIT(7) +#define OPTION_MPTCP_FASTCLOSE BIT(8) +#define OPTION_MPTCP_PRIO BIT(9) +#define OPTION_MPTCP_RST BIT(10) +#define OPTION_MPTCP_DSS BIT(11) +#define OPTION_MPTCP_FAIL BIT(12) + +#define OPTION_MPTCP_CSUMREQD BIT(13) + +#define OPTIONS_MPTCP_MPC (OPTION_MPTCP_MPC_SYN | OPTION_MPTCP_MPC_SYNACK | \ + OPTION_MPTCP_MPC_ACK) +#define OPTIONS_MPTCP_MPJ (OPTION_MPTCP_MPJ_SYN | OPTION_MPTCP_MPJ_SYNACK | \ + OPTION_MPTCP_MPJ_ACK) + +/* MPTCP option subtypes */ +#define MPTCPOPT_MP_CAPABLE 0 +#define MPTCPOPT_MP_JOIN 1 +#define MPTCPOPT_DSS 2 +#define MPTCPOPT_ADD_ADDR 3 +#define MPTCPOPT_RM_ADDR 4 +#define MPTCPOPT_MP_PRIO 5 +#define MPTCPOPT_MP_FAIL 6 +#define MPTCPOPT_MP_FASTCLOSE 7 +#define MPTCPOPT_RST 8 + +/* MPTCP suboption lengths */ +#define TCPOLEN_MPTCP_MPC_SYN 4 +#define TCPOLEN_MPTCP_MPC_SYNACK 12 +#define TCPOLEN_MPTCP_MPC_ACK 20 +#define TCPOLEN_MPTCP_MPC_ACK_DATA 22 +#define TCPOLEN_MPTCP_MPJ_SYN 12 +#define TCPOLEN_MPTCP_MPJ_SYNACK 16 +#define TCPOLEN_MPTCP_MPJ_ACK 24 +#define TCPOLEN_MPTCP_DSS_BASE 4 +#define TCPOLEN_MPTCP_DSS_ACK32 4 +#define TCPOLEN_MPTCP_DSS_ACK64 8 +#define TCPOLEN_MPTCP_DSS_MAP32 10 +#define TCPOLEN_MPTCP_DSS_MAP64 14 +#define TCPOLEN_MPTCP_DSS_CHECKSUM 2 +#define TCPOLEN_MPTCP_ADD_ADDR 16 +#define TCPOLEN_MPTCP_ADD_ADDR_PORT 18 +#define TCPOLEN_MPTCP_ADD_ADDR_BASE 8 +#define TCPOLEN_MPTCP_ADD_ADDR_BASE_PORT 10 +#define TCPOLEN_MPTCP_ADD_ADDR6 28 +#define TCPOLEN_MPTCP_ADD_ADDR6_PORT 30 +#define TCPOLEN_MPTCP_ADD_ADDR6_BASE 20 +#define TCPOLEN_MPTCP_ADD_ADDR6_BASE_PORT 22 +#define TCPOLEN_MPTCP_PORT_LEN 2 +#define TCPOLEN_MPTCP_PORT_ALIGN 2 +#define TCPOLEN_MPTCP_RM_ADDR_BASE 3 +#define TCPOLEN_MPTCP_PRIO 3 +#define TCPOLEN_MPTCP_PRIO_ALIGN 4 +#define TCPOLEN_MPTCP_FASTCLOSE 12 +#define TCPOLEN_MPTCP_RST 4 +#define TCPOLEN_MPTCP_FAIL 12 + +#define TCPOLEN_MPTCP_MPC_ACK_DATA_CSUM (TCPOLEN_MPTCP_DSS_CHECKSUM + TCPOLEN_MPTCP_MPC_ACK_DATA) + +/* MPTCP MP_JOIN flags */ +#define MPTCPOPT_BACKUP BIT(0) +#define MPTCPOPT_THMAC_LEN 8 + +/* MPTCP MP_CAPABLE flags */ +#define MPTCP_VERSION_MASK (0x0F) +#define MPTCP_CAP_CHECKSUM_REQD BIT(7) +#define MPTCP_CAP_EXTENSIBILITY BIT(6) +#define MPTCP_CAP_DENY_JOIN_ID0 BIT(5) +#define MPTCP_CAP_HMAC_SHA256 BIT(0) +#define MPTCP_CAP_FLAG_MASK (0x1F) + +/* MPTCP DSS flags */ +#define MPTCP_DSS_DATA_FIN BIT(4) +#define MPTCP_DSS_DSN64 BIT(3) +#define MPTCP_DSS_HAS_MAP BIT(2) +#define MPTCP_DSS_ACK64 BIT(1) +#define MPTCP_DSS_HAS_ACK BIT(0) +#define MPTCP_DSS_FLAG_MASK (0x1F) + +/* MPTCP ADD_ADDR flags */ +#define MPTCP_ADDR_ECHO BIT(0) + +/* MPTCP MP_PRIO flags */ +#define MPTCP_PRIO_BKUP BIT(0) + +/* MPTCP TCPRST flags */ +#define MPTCP_RST_TRANSIENT BIT(0) + +/* MPTCP socket atomic flags */ +#define MPTCP_NOSPACE 1 +#define MPTCP_WORK_RTX 2 +#define MPTCP_WORK_EOF 3 +#define MPTCP_FALLBACK_DONE 4 +#define MPTCP_WORK_CLOSE_SUBFLOW 5 + +/* MPTCP socket release cb flags */ +#define MPTCP_PUSH_PENDING 1 +#define MPTCP_CLEAN_UNA 2 +#define MPTCP_ERROR_REPORT 3 +#define 
MPTCP_RETRANSMIT 4 +#define MPTCP_FLUSH_JOIN_LIST 5 +#define MPTCP_CONNECTED 6 +#define MPTCP_RESET_SCHEDULER 7 + +static inline bool before64(__u64 seq1, __u64 seq2) +{ + return (__s64)(seq1 - seq2) < 0; +} + +#define after64(seq2, seq1) before64(seq1, seq2) + +struct mptcp_options_received { + u64 sndr_key; + u64 rcvr_key; + u64 data_ack; + u64 data_seq; + u32 subflow_seq; + u16 data_len; + __sum16 csum; + u16 suboptions; + u32 token; + u32 nonce; + u16 use_map:1, + dsn64:1, + data_fin:1, + use_ack:1, + ack64:1, + mpc_map:1, + reset_reason:4, + reset_transient:1, + echo:1, + backup:1, + deny_join_id0:1, + __unused:2; + u8 join_id; + u64 thmac; + u8 hmac[MPTCPOPT_HMAC_LEN]; + struct mptcp_addr_info addr; + struct mptcp_rm_list rm_list; + u64 ahmac; + u64 fail_seq; +}; + +static inline __be32 mptcp_option(u8 subopt, u8 len, u8 nib, u8 field) +{ + return htonl((TCPOPT_MPTCP << 24) | (len << 16) | (subopt << 12) | + ((nib & 0xF) << 8) | field); +} + +enum mptcp_pm_status { + MPTCP_PM_ADD_ADDR_RECEIVED, + MPTCP_PM_ADD_ADDR_SEND_ACK, + MPTCP_PM_RM_ADDR_RECEIVED, + MPTCP_PM_ESTABLISHED, + MPTCP_PM_SUBFLOW_ESTABLISHED, + MPTCP_PM_ALREADY_ESTABLISHED, /* persistent status, set after ESTABLISHED event */ + MPTCP_PM_MPC_ENDPOINT_ACCOUNTED /* persistent status, set after MPC local address is + * accounted int id_avail_bitmap + */ +}; + +enum mptcp_pm_type { + MPTCP_PM_TYPE_KERNEL = 0, + MPTCP_PM_TYPE_USERSPACE, + + __MPTCP_PM_TYPE_NR, + __MPTCP_PM_TYPE_MAX = __MPTCP_PM_TYPE_NR - 1, +}; + +/* Status bits below MPTCP_PM_ALREADY_ESTABLISHED need pm worker actions */ +#define MPTCP_PM_WORK_MASK ((1 << MPTCP_PM_ALREADY_ESTABLISHED) - 1) + +enum mptcp_addr_signal_status { + MPTCP_ADD_ADDR_SIGNAL, + MPTCP_ADD_ADDR_ECHO, + MPTCP_RM_ADDR_SIGNAL, +}; + +/* max value of mptcp_addr_info.id */ +#define MPTCP_PM_MAX_ADDR_ID U8_MAX + +struct mptcp_pm_data { + struct mptcp_addr_info local; + struct mptcp_addr_info remote; + struct list_head anno_list; + struct list_head userspace_pm_local_addr_list; + + spinlock_t lock; /*protects the whole PM data */ + + u8 addr_signal; + bool server_side; + bool work_pending; + bool accept_addr; + bool accept_subflow; + bool remote_deny_join_id0; + u8 add_addr_signaled; + u8 add_addr_accepted; + u8 local_addr_used; + u8 pm_type; + u8 subflows; + u8 status; + DECLARE_BITMAP(id_avail_bitmap, MPTCP_PM_MAX_ADDR_ID + 1); + struct mptcp_rm_list rm_list_tx; + struct mptcp_rm_list rm_list_rx; +}; + +struct mptcp_pm_addr_entry { + struct list_head list; + struct mptcp_addr_info addr; + u8 flags; + int ifindex; + struct socket *lsk; +}; + +struct mptcp_data_frag { + struct list_head list; + u64 data_seq; + u16 data_len; + u16 offset; + u16 overhead; + u16 already_sent; + struct page *page; +}; + +/* MPTCP connection sock */ +struct mptcp_sock { + /* inet_connection_sock must be the first member */ + struct inet_connection_sock sk; + u64 local_key; + u64 remote_key; + u64 write_seq; + u64 snd_nxt; + u64 ack_seq; + atomic64_t rcv_wnd_sent; + u64 rcv_data_fin_seq; + int rmem_fwd_alloc; + struct sock *last_snd; + int snd_burst; + int old_wspace; + u64 recovery_snd_nxt; /* in recovery mode accept up to this seq; + * recovery related fields are under data_lock + * protection + */ + u64 snd_una; + u64 wnd_end; + unsigned long timer_ival; + u32 token; + int rmem_released; + unsigned long flags; + unsigned long cb_flags; + unsigned long push_pending; + bool recovery; /* closing subflow write queue reinjected */ + bool can_ack; + bool fully_established; + bool rcv_data_fin; + bool 
snd_data_fin_enable; + bool rcv_fastclose; + bool use_64bit_ack; /* Set when we received a 64-bit DSN */ + bool csum_enabled; + bool allow_infinite_fallback; + u8 mpc_endpoint_id; + u8 recvmsg_inq:1, + cork:1, + nodelay:1, + fastopening:1, + in_accept_queue:1; + struct work_struct work; + struct sk_buff *ooo_last_skb; + struct rb_root out_of_order_queue; + struct sk_buff_head receive_queue; + struct list_head conn_list; + struct list_head rtx_queue; + struct mptcp_data_frag *first_pending; + struct list_head join_list; + struct socket *subflow; /* outgoing connect/listener/!mp_capable + * The mptcp ops can safely dereference, using suitable + * ONCE annotation, the subflow outside the socket + * lock as such sock is freed after close(). + */ + struct sock *first; + struct mptcp_pm_data pm; + struct { + u32 space; /* bytes copied in last measurement window */ + u32 copied; /* bytes copied in this measurement window */ + u64 time; /* start time of measurement window */ + u64 rtt_us; /* last maximum rtt of subflows */ + } rcvq_space; + + u32 setsockopt_seq; + char ca_name[TCP_CA_NAME_MAX]; +}; + +#define mptcp_data_lock(sk) spin_lock_bh(&(sk)->sk_lock.slock) +#define mptcp_data_unlock(sk) spin_unlock_bh(&(sk)->sk_lock.slock) + +#define mptcp_for_each_subflow(__msk, __subflow) \ + list_for_each_entry(__subflow, &((__msk)->conn_list), node) +#define mptcp_for_each_subflow_safe(__msk, __subflow, __tmp) \ + list_for_each_entry_safe(__subflow, __tmp, &((__msk)->conn_list), node) + +static inline void msk_owned_by_me(const struct mptcp_sock *msk) +{ + sock_owned_by_me((const struct sock *)msk); +} + +static inline struct mptcp_sock *mptcp_sk(const struct sock *sk) +{ + return (struct mptcp_sock *)sk; +} + +/* the msk socket don't use the backlog, also account for the bulk + * free memory + */ +static inline int __mptcp_rmem(const struct sock *sk) +{ + return atomic_read(&sk->sk_rmem_alloc) - READ_ONCE(mptcp_sk(sk)->rmem_released); +} + +static inline int __mptcp_space(const struct sock *sk) +{ + return tcp_win_from_space(sk, READ_ONCE(sk->sk_rcvbuf) - __mptcp_rmem(sk)); +} + +static inline struct mptcp_data_frag *mptcp_send_head(const struct sock *sk) +{ + const struct mptcp_sock *msk = mptcp_sk(sk); + + return READ_ONCE(msk->first_pending); +} + +static inline struct mptcp_data_frag *mptcp_send_next(struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_data_frag *cur; + + cur = msk->first_pending; + return list_is_last(&cur->list, &msk->rtx_queue) ? 
NULL : + list_next_entry(cur, list); +} + +static inline struct mptcp_data_frag *mptcp_pending_tail(const struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + if (!msk->first_pending) + return NULL; + + if (WARN_ON_ONCE(list_empty(&msk->rtx_queue))) + return NULL; + + return list_last_entry(&msk->rtx_queue, struct mptcp_data_frag, list); +} + +static inline struct mptcp_data_frag *mptcp_rtx_head(const struct sock *sk) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + + if (msk->snd_una == READ_ONCE(msk->snd_nxt)) + return NULL; + + return list_first_entry_or_null(&msk->rtx_queue, struct mptcp_data_frag, list); +} + +struct csum_pseudo_header { + __be64 data_seq; + __be32 subflow_seq; + __be16 data_len; + __sum16 csum; +}; + +struct mptcp_subflow_request_sock { + struct tcp_request_sock sk; + u16 mp_capable : 1, + mp_join : 1, + backup : 1, + csum_reqd : 1, + allow_join_id0 : 1; + u8 local_id; + u8 remote_id; + u64 local_key; + u64 idsn; + u32 token; + u32 ssn_offset; + u64 thmac; + u32 local_nonce; + u32 remote_nonce; + struct mptcp_sock *msk; + struct hlist_nulls_node token_node; +}; + +static inline struct mptcp_subflow_request_sock * +mptcp_subflow_rsk(const struct request_sock *rsk) +{ + return (struct mptcp_subflow_request_sock *)rsk; +} + +enum mptcp_data_avail { + MPTCP_SUBFLOW_NODATA, + MPTCP_SUBFLOW_DATA_AVAIL, +}; + +struct mptcp_delegated_action { + struct napi_struct napi; + struct list_head head; +}; + +DECLARE_PER_CPU(struct mptcp_delegated_action, mptcp_delegated_actions); + +#define MPTCP_DELEGATE_SCHEDULED 0 +#define MPTCP_DELEGATE_SEND 1 +#define MPTCP_DELEGATE_ACK 2 + +#define MPTCP_DELEGATE_ACTIONS_MASK (~BIT(MPTCP_DELEGATE_SCHEDULED)) +/* MPTCP subflow context */ +struct mptcp_subflow_context { + struct list_head node;/* conn_list of subflows */ + + struct_group(reset, + + unsigned long avg_pacing_rate; /* protected by msk socket lock */ + u64 local_key; + u64 remote_key; + u64 idsn; + u64 map_seq; + u32 snd_isn; + u32 token; + u32 rel_write_seq; + u32 map_subflow_seq; + u32 ssn_offset; + u32 map_data_len; + __wsum map_data_csum; + u32 map_csum_len; + u32 request_mptcp : 1, /* send MP_CAPABLE */ + request_join : 1, /* send MP_JOIN */ + request_bkup : 1, + mp_capable : 1, /* remote is MPTCP capable */ + mp_join : 1, /* remote is JOINing */ + fully_established : 1, /* path validated */ + pm_notified : 1, /* PM hook called for established status */ + conn_finished : 1, + map_valid : 1, + map_csum_reqd : 1, + map_data_fin : 1, + mpc_map : 1, + backup : 1, + send_mp_prio : 1, + send_mp_fail : 1, + send_fastclose : 1, + send_infinite_map : 1, + rx_eof : 1, + can_ack : 1, /* only after processing the remote a key */ + disposable : 1, /* ctx can be free at ulp release time */ + stale : 1, /* unable to snd/rcv data, do not use for xmit */ + local_id_valid : 1, /* local_id is correctly initialized */ + valid_csum_seen : 1; /* at least one csum validated */ + enum mptcp_data_avail data_avail; + u32 remote_nonce; + u64 thmac; + u32 local_nonce; + u32 remote_token; + u8 hmac[MPTCPOPT_HMAC_LEN]; + u8 local_id; + u8 remote_id; + u8 reset_seen:1; + u8 reset_transient:1; + u8 reset_reason:4; + u8 stale_count; + + long delegated_status; + unsigned long fail_tout; + + ); + + struct list_head delegated_node; /* link into delegated_action, protected by local BH */ + + u32 setsockopt_seq; + u32 stale_rcv_tstamp; + + struct sock *tcp_sock; /* tcp sk backpointer */ + struct sock *conn; /* parent mptcp_sock */ + const struct inet_connection_sock_af_ops *icsk_af_ops; + void 
(*tcp_state_change)(struct sock *sk); + void (*tcp_error_report)(struct sock *sk); + + struct rcu_head rcu; +}; + +static inline struct mptcp_subflow_context * +mptcp_subflow_ctx(const struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + /* Use RCU on icsk_ulp_data only for sock diag code */ + return (__force struct mptcp_subflow_context *)icsk->icsk_ulp_data; +} + +static inline struct sock * +mptcp_subflow_tcp_sock(const struct mptcp_subflow_context *subflow) +{ + return subflow->tcp_sock; +} + +static inline void +mptcp_subflow_ctx_reset(struct mptcp_subflow_context *subflow) +{ + memset(&subflow->reset, 0, sizeof(subflow->reset)); + subflow->request_mptcp = 1; +} + +static inline u64 +mptcp_subflow_get_map_offset(const struct mptcp_subflow_context *subflow) +{ + return tcp_sk(mptcp_subflow_tcp_sock(subflow))->copied_seq - + subflow->ssn_offset - + subflow->map_subflow_seq; +} + +static inline u64 +mptcp_subflow_get_mapped_dsn(const struct mptcp_subflow_context *subflow) +{ + return subflow->map_seq + mptcp_subflow_get_map_offset(subflow); +} + +void mptcp_subflow_process_delegated(struct sock *ssk, long actions); + +static inline void mptcp_subflow_delegate(struct mptcp_subflow_context *subflow, int action) +{ + long old, set_bits = BIT(MPTCP_DELEGATE_SCHEDULED) | BIT(action); + struct mptcp_delegated_action *delegated; + bool schedule; + + /* the caller held the subflow bh socket lock */ + lockdep_assert_in_softirq(); + + /* The implied barrier pairs with tcp_release_cb_override() + * mptcp_napi_poll(), and ensures the below list check sees list + * updates done prior to delegated status bits changes + */ + old = set_mask_bits(&subflow->delegated_status, 0, set_bits); + if (!(old & BIT(MPTCP_DELEGATE_SCHEDULED))) { + if (WARN_ON_ONCE(!list_empty(&subflow->delegated_node))) + return; + + delegated = this_cpu_ptr(&mptcp_delegated_actions); + schedule = list_empty(&delegated->head); + list_add_tail(&subflow->delegated_node, &delegated->head); + sock_hold(mptcp_subflow_tcp_sock(subflow)); + if (schedule) + napi_schedule(&delegated->napi); + } +} + +static inline struct mptcp_subflow_context * +mptcp_subflow_delegated_next(struct mptcp_delegated_action *delegated) +{ + struct mptcp_subflow_context *ret; + + if (list_empty(&delegated->head)) + return NULL; + + ret = list_first_entry(&delegated->head, struct mptcp_subflow_context, delegated_node); + list_del_init(&ret->delegated_node); + return ret; +} + +int mptcp_is_enabled(const struct net *net); +unsigned int mptcp_get_add_addr_timeout(const struct net *net); +int mptcp_is_checksum_enabled(const struct net *net); +int mptcp_allow_join_id0(const struct net *net); +unsigned int mptcp_stale_loss_cnt(const struct net *net); +int mptcp_get_pm_type(const struct net *net); +void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow, + struct mptcp_options_received *mp_opt); +bool __mptcp_retransmit_pending_data(struct sock *sk); +void mptcp_check_and_set_pending(struct sock *sk); +void __mptcp_push_pending(struct sock *sk, unsigned int flags); +bool mptcp_subflow_data_available(struct sock *sk); +void __init mptcp_subflow_init(void); +void mptcp_subflow_shutdown(struct sock *sk, struct sock *ssk, int how); +void mptcp_close_ssk(struct sock *sk, struct sock *ssk, + struct mptcp_subflow_context *subflow); +void __mptcp_subflow_send_ack(struct sock *ssk); +void mptcp_subflow_reset(struct sock *ssk); +void mptcp_subflow_queue_clean(struct sock *sk, struct sock *ssk); +void mptcp_sock_graft(struct sock *sk, 
struct socket *parent); +struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk); +bool __mptcp_close(struct sock *sk, long timeout); +void mptcp_cancel_work(struct sock *sk); +void __mptcp_unaccepted_force_close(struct sock *sk); + +bool mptcp_addresses_equal(const struct mptcp_addr_info *a, + const struct mptcp_addr_info *b, bool use_port); + +/* called with sk socket lock held */ +int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *remote); +int mptcp_subflow_create_socket(struct sock *sk, unsigned short family, + struct socket **new_sock); +void mptcp_info2sockaddr(const struct mptcp_addr_info *info, + struct sockaddr_storage *addr, + unsigned short family); + +static inline bool __tcp_can_send(const struct sock *ssk) +{ + /* only send if our side has not closed yet */ + return ((1 << inet_sk_state_load(ssk)) & (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT)); +} + +static inline bool __mptcp_subflow_active(struct mptcp_subflow_context *subflow) +{ + /* can't send if JOIN hasn't completed yet (i.e. is usable for mptcp) */ + if (subflow->request_join && !subflow->fully_established) + return false; + + return __tcp_can_send(mptcp_subflow_tcp_sock(subflow)); +} + +void mptcp_subflow_set_active(struct mptcp_subflow_context *subflow); + +bool mptcp_subflow_active(struct mptcp_subflow_context *subflow); + +void mptcp_subflow_drop_ctx(struct sock *ssk); + +static inline void mptcp_subflow_tcp_fallback(struct sock *sk, + struct mptcp_subflow_context *ctx) +{ + sk->sk_data_ready = sock_def_readable; + sk->sk_state_change = ctx->tcp_state_change; + sk->sk_write_space = sk_stream_write_space; + sk->sk_error_report = ctx->tcp_error_report; + + inet_csk(sk)->icsk_af_ops = ctx->icsk_af_ops; +} + +void __init mptcp_proto_init(void); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +int __init mptcp_proto_v6_init(void); +#endif + +struct sock *mptcp_sk_clone_init(const struct sock *sk, + const struct mptcp_options_received *mp_opt, + struct sock *ssk, + struct request_sock *req); +void mptcp_get_options(const struct sk_buff *skb, + struct mptcp_options_received *mp_opt); + +void mptcp_finish_connect(struct sock *sk); +void __mptcp_set_connected(struct sock *sk); +void mptcp_reset_tout_timer(struct mptcp_sock *msk, unsigned long fail_tout); + +static inline void mptcp_stop_tout_timer(struct sock *sk) +{ + if (!inet_csk(sk)->icsk_mtup.probe_timestamp) + return; + + sk_stop_timer(sk, &sk->sk_timer); + inet_csk(sk)->icsk_mtup.probe_timestamp = 0; +} + +static inline void mptcp_set_close_tout(struct sock *sk, unsigned long tout) +{ + /* avoid 0 timestamp, as that means no close timeout */ + inet_csk(sk)->icsk_mtup.probe_timestamp = tout ? 
: 1; +} + +static inline void mptcp_start_tout_timer(struct sock *sk) +{ + mptcp_set_close_tout(sk, tcp_jiffies32); + mptcp_reset_tout_timer(mptcp_sk(sk), 0); +} + +static inline bool mptcp_is_fully_established(struct sock *sk) +{ + return inet_sk_state_load(sk) == TCP_ESTABLISHED && + READ_ONCE(mptcp_sk(sk)->fully_established); +} +void mptcp_rcv_space_init(struct mptcp_sock *msk, const struct sock *ssk); +void mptcp_data_ready(struct sock *sk, struct sock *ssk); +bool mptcp_finish_join(struct sock *sk); +bool mptcp_schedule_work(struct sock *sk); +int mptcp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen); +int mptcp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *option); + +u64 __mptcp_expand_seq(u64 old_seq, u64 cur_seq); +static inline u64 mptcp_expand_seq(u64 old_seq, u64 cur_seq, bool use_64bit) +{ + if (use_64bit) + return cur_seq; + + return __mptcp_expand_seq(old_seq, cur_seq); +} +void __mptcp_check_push(struct sock *sk, struct sock *ssk); +void __mptcp_data_acked(struct sock *sk); +void __mptcp_error_report(struct sock *sk); +void mptcp_subflow_eof(struct sock *sk); +bool mptcp_update_rcv_data_fin(struct mptcp_sock *msk, u64 data_fin_seq, bool use_64bit); +static inline bool mptcp_data_fin_enabled(const struct mptcp_sock *msk) +{ + return READ_ONCE(msk->snd_data_fin_enable) && + READ_ONCE(msk->write_seq) == READ_ONCE(msk->snd_nxt); +} + +static inline bool mptcp_propagate_sndbuf(struct sock *sk, struct sock *ssk) +{ + if ((sk->sk_userlocks & SOCK_SNDBUF_LOCK) || ssk->sk_sndbuf <= READ_ONCE(sk->sk_sndbuf)) + return false; + + WRITE_ONCE(sk->sk_sndbuf, ssk->sk_sndbuf); + return true; +} + +static inline void mptcp_write_space(struct sock *sk) +{ + if (sk_stream_is_writeable(sk)) { + /* pairs with memory barrier in mptcp_poll */ + smp_mb(); + if (test_and_clear_bit(MPTCP_NOSPACE, &mptcp_sk(sk)->flags)) + sk_stream_write_space(sk); + } +} + +void mptcp_destroy_common(struct mptcp_sock *msk, unsigned int flags); + +#define MPTCP_TOKEN_MAX_RETRIES 4 + +void __init mptcp_token_init(void); +static inline void mptcp_token_init_request(struct request_sock *req) +{ + mptcp_subflow_rsk(req)->token_node.pprev = NULL; +} + +int mptcp_token_new_request(struct request_sock *req); +void mptcp_token_destroy_request(struct request_sock *req); +int mptcp_token_new_connect(struct sock *sk); +void mptcp_token_accept(struct mptcp_subflow_request_sock *r, + struct mptcp_sock *msk); +bool mptcp_token_exists(u32 token); +struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token); +struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, + long *s_num); +void mptcp_token_destroy(struct mptcp_sock *msk); + +void mptcp_crypto_key_sha(u64 key, u32 *token, u64 *idsn); + +void mptcp_crypto_hmac_sha(u64 key1, u64 key2, u8 *msg, int len, void *hmac); +__sum16 __mptcp_make_csum(u64 data_seq, u32 subflow_seq, u16 data_len, __wsum sum); + +void __init mptcp_pm_init(void); +void mptcp_pm_data_init(struct mptcp_sock *msk); +void mptcp_pm_data_reset(struct mptcp_sock *msk); +int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, + struct mptcp_addr_info *addr); +int mptcp_pm_parse_entry(struct nlattr *attr, struct genl_info *info, + bool require_family, + struct mptcp_pm_addr_entry *entry); +bool mptcp_pm_addr_families_match(const struct sock *sk, + const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *rem); +void mptcp_pm_subflow_chk_stale(const struct mptcp_sock *msk, struct 
sock *ssk); +void mptcp_pm_nl_subflow_chk_stale(const struct mptcp_sock *msk, struct sock *ssk); +void mptcp_pm_new_connection(struct mptcp_sock *msk, const struct sock *ssk, int server_side); +void mptcp_pm_fully_established(struct mptcp_sock *msk, const struct sock *ssk, gfp_t gfp); +bool mptcp_pm_allow_new_subflow(struct mptcp_sock *msk); +void mptcp_pm_connection_closed(struct mptcp_sock *msk); +void mptcp_pm_subflow_established(struct mptcp_sock *msk); +bool mptcp_pm_nl_check_work_pending(struct mptcp_sock *msk); +void mptcp_pm_subflow_check_next(struct mptcp_sock *msk, const struct sock *ssk, + const struct mptcp_subflow_context *subflow); +void mptcp_pm_add_addr_received(const struct sock *ssk, + const struct mptcp_addr_info *addr); +void mptcp_pm_add_addr_echoed(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr); +void mptcp_pm_add_addr_send_ack(struct mptcp_sock *msk); +void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk); +void mptcp_pm_rm_addr_received(struct mptcp_sock *msk, + const struct mptcp_rm_list *rm_list); +void mptcp_pm_mp_prio_received(struct sock *sk, u8 bkup); +void mptcp_pm_mp_fail_received(struct sock *sk, u64 fail_seq); +int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, + struct mptcp_addr_info *addr, + struct mptcp_addr_info *rem, + u8 bkup); +bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk, + const struct mptcp_pm_addr_entry *entry); +void mptcp_pm_free_anno_list(struct mptcp_sock *msk); +bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk); +struct mptcp_pm_add_entry * +mptcp_pm_del_add_timer(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr, bool check_id); +struct mptcp_pm_add_entry * +mptcp_lookup_anno_list_by_saddr(const struct mptcp_sock *msk, + const struct mptcp_addr_info *addr); +int mptcp_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, + unsigned int id, + u8 *flags, int *ifindex); +int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk, + unsigned int id, + u8 *flags, int *ifindex); +int mptcp_userspace_pm_set_flags(struct net *net, struct nlattr *token, + struct mptcp_pm_addr_entry *loc, + struct mptcp_pm_addr_entry *rem, u8 bkup); +int mptcp_pm_announce_addr(struct mptcp_sock *msk, + const struct mptcp_addr_info *addr, + bool echo); +int mptcp_pm_remove_addr(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list); +int mptcp_pm_remove_subflow(struct mptcp_sock *msk, const struct mptcp_rm_list *rm_list); +void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list); +void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, + struct list_head *rm_list); + +int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk, + struct mptcp_pm_addr_entry *entry); +void mptcp_free_local_addr_list(struct mptcp_sock *msk); +int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info); +int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info); +int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info); +int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info); + +void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, + const struct sock *ssk, gfp_t gfp); +void mptcp_event_addr_announced(const struct sock *ssk, const struct mptcp_addr_info *info); +void mptcp_event_addr_removed(const struct mptcp_sock *msk, u8 id); +bool mptcp_userspace_pm_active(const struct mptcp_sock *msk); + +static inline bool mptcp_pm_should_add_signal(struct mptcp_sock *msk) +{ + return 
READ_ONCE(msk->pm.addr_signal) & + (BIT(MPTCP_ADD_ADDR_SIGNAL) | BIT(MPTCP_ADD_ADDR_ECHO)); +} + +static inline bool mptcp_pm_should_add_signal_addr(struct mptcp_sock *msk) +{ + return READ_ONCE(msk->pm.addr_signal) & BIT(MPTCP_ADD_ADDR_SIGNAL); +} + +static inline bool mptcp_pm_should_add_signal_echo(struct mptcp_sock *msk) +{ + return READ_ONCE(msk->pm.addr_signal) & BIT(MPTCP_ADD_ADDR_ECHO); +} + +static inline bool mptcp_pm_should_rm_signal(struct mptcp_sock *msk) +{ + return READ_ONCE(msk->pm.addr_signal) & BIT(MPTCP_RM_ADDR_SIGNAL); +} + +static inline bool mptcp_pm_is_userspace(const struct mptcp_sock *msk) +{ + return READ_ONCE(msk->pm.pm_type) == MPTCP_PM_TYPE_USERSPACE; +} + +static inline bool mptcp_pm_is_kernel(const struct mptcp_sock *msk) +{ + return READ_ONCE(msk->pm.pm_type) == MPTCP_PM_TYPE_KERNEL; +} + +static inline unsigned int mptcp_add_addr_len(int family, bool echo, bool port) +{ + u8 len = TCPOLEN_MPTCP_ADD_ADDR_BASE; + + if (family == AF_INET6) + len = TCPOLEN_MPTCP_ADD_ADDR6_BASE; + if (!echo) + len += MPTCPOPT_THMAC_LEN; + /* account for 2 trailing 'nop' options */ + if (port) + len += TCPOLEN_MPTCP_PORT_LEN + TCPOLEN_MPTCP_PORT_ALIGN; + + return len; +} + +static inline int mptcp_rm_addr_len(const struct mptcp_rm_list *rm_list) +{ + if (rm_list->nr == 0 || rm_list->nr > MPTCP_RM_IDS_MAX) + return -EINVAL; + + return TCPOLEN_MPTCP_RM_ADDR_BASE + roundup(rm_list->nr - 1, 4) + 1; +} + +bool mptcp_pm_add_addr_signal(struct mptcp_sock *msk, const struct sk_buff *skb, + unsigned int opt_size, unsigned int remaining, + struct mptcp_addr_info *addr, bool *echo, + bool *drop_other_suboptions); +bool mptcp_pm_rm_addr_signal(struct mptcp_sock *msk, unsigned int remaining, + struct mptcp_rm_list *rm_list); +int mptcp_pm_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); +int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk, struct mptcp_addr_info *skc); + +void __init mptcp_pm_nl_init(void); +void mptcp_pm_nl_work(struct mptcp_sock *msk); +void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, + const struct mptcp_rm_list *rm_list); +int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc); +unsigned int mptcp_pm_get_add_addr_signal_max(const struct mptcp_sock *msk); +unsigned int mptcp_pm_get_add_addr_accept_max(const struct mptcp_sock *msk); +unsigned int mptcp_pm_get_subflows_max(const struct mptcp_sock *msk); +unsigned int mptcp_pm_get_local_addr_max(const struct mptcp_sock *msk); + +/* called under PM lock */ +static inline void __mptcp_pm_close_subflow(struct mptcp_sock *msk) +{ + if (--msk->pm.subflows < mptcp_pm_get_subflows_max(msk)) + WRITE_ONCE(msk->pm.accept_subflow, true); +} + +static inline void mptcp_pm_close_subflow(struct mptcp_sock *msk) +{ + spin_lock_bh(&msk->pm.lock); + __mptcp_pm_close_subflow(msk); + spin_unlock_bh(&msk->pm.lock); +} + +void mptcp_sockopt_sync(struct mptcp_sock *msk, struct sock *ssk); +void mptcp_sockopt_sync_locked(struct mptcp_sock *msk, struct sock *ssk); + +static inline struct mptcp_ext *mptcp_get_ext(const struct sk_buff *skb) +{ + return (struct mptcp_ext *)skb_ext_find(skb, SKB_EXT_MPTCP); +} + +void mptcp_diag_subflow_init(struct tcp_ulp_ops *ops); + +static inline bool __mptcp_check_fallback(const struct mptcp_sock *msk) +{ + return test_bit(MPTCP_FALLBACK_DONE, &msk->flags); +} + +static inline bool mptcp_check_fallback(const struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + + return 
__mptcp_check_fallback(msk); +} + +static inline void __mptcp_do_fallback(struct mptcp_sock *msk) +{ + if (test_bit(MPTCP_FALLBACK_DONE, &msk->flags)) { + pr_debug("TCP fallback already done (msk=%p)", msk); + return; + } + set_bit(MPTCP_FALLBACK_DONE, &msk->flags); +} + +static inline void mptcp_do_fallback(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = subflow->conn; + struct mptcp_sock *msk; + + msk = mptcp_sk(sk); + __mptcp_do_fallback(msk); + if (READ_ONCE(msk->snd_data_fin_enable) && !(ssk->sk_shutdown & SEND_SHUTDOWN)) { + gfp_t saved_allocation = ssk->sk_allocation; + + /* we are in a atomic (BH) scope, override ssk default for data + * fin allocation + */ + ssk->sk_allocation = GFP_ATOMIC; + ssk->sk_shutdown |= SEND_SHUTDOWN; + tcp_shutdown(ssk, SEND_SHUTDOWN); + ssk->sk_allocation = saved_allocation; + } +} + +#define pr_fallback(a) pr_debug("%s:fallback to TCP (msk=%p)", __func__, a) + +static inline bool mptcp_check_infinite_map(struct sk_buff *skb) +{ + struct mptcp_ext *mpext; + + mpext = skb ? mptcp_get_ext(skb) : NULL; + if (mpext && mpext->infinite_map) + return true; + + return false; +} + +static inline bool is_active_ssk(struct mptcp_subflow_context *subflow) +{ + return (subflow->request_mptcp || subflow->request_join); +} + +static inline bool subflow_simultaneous_connect(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + return sk->sk_state == TCP_ESTABLISHED && + is_active_ssk(subflow) && + !subflow->conn_finished; +} + +#ifdef CONFIG_SYN_COOKIES +void subflow_init_req_cookie_join_save(const struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb); +bool mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb); +void __init mptcp_join_cookie_init(void); +#else +static inline void +subflow_init_req_cookie_join_save(const struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb) {} +static inline bool +mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb) +{ + return false; +} + +static inline void mptcp_join_cookie_init(void) {} +#endif + +#endif /* __MPTCP_PROTOCOL_H */ diff --git a/net/mptcp/sockopt.c b/net/mptcp/sockopt.c new file mode 100644 index 000000000..30374fd44 --- /dev/null +++ b/net/mptcp/sockopt.c @@ -0,0 +1,1338 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2021, Red Hat. + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include +#include +#include "protocol.h" + +#define MIN_INFO_OPTLEN_SIZE 16 + +static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk) +{ + sock_owned_by_me((const struct sock *)msk); + + if (likely(!__mptcp_check_fallback(msk))) + return NULL; + + return msk->first; +} + +static u32 sockopt_seq_reset(const struct sock *sk) +{ + sock_owned_by_me(sk); + + /* Highbits contain state. Allows to distinguish sockopt_seq + * of listener and established: + * s0 = new_listener() + * sockopt(s0) - seq is 1 + * s1 = accept(s0) - s1 inherits seq 1 if listener sk (s0) + * sockopt(s0) - seq increments to 2 on s0 + * sockopt(s1) // seq increments to 2 on s1 (different option) + * new ssk completes join, inherits options from s0 // seq 2 + * Needs sync from mptcp join logic, but ssk->seq == msk->seq + * + * Set High order bits to sk_state so ssk->seq == msk->seq test + * will fail. 
+ */ + + return (u32)sk->sk_state << 24u; +} + +static void sockopt_seq_inc(struct mptcp_sock *msk) +{ + u32 seq = (msk->setsockopt_seq + 1) & 0x00ffffff; + + msk->setsockopt_seq = sockopt_seq_reset((struct sock *)msk) + seq; +} + +static int mptcp_get_int_option(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen, int *val) +{ + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(val, optval, sizeof(*val))) + return -EFAULT; + + return 0; +} + +static void mptcp_sol_socket_sync_intval(struct mptcp_sock *msk, int optname, int val) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + + lock_sock(sk); + sockopt_seq_inc(msk); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast(ssk); + + switch (optname) { + case SO_DEBUG: + sock_valbool_flag(ssk, SOCK_DBG, !!val); + break; + case SO_KEEPALIVE: + if (ssk->sk_prot->keepalive) + ssk->sk_prot->keepalive(ssk, !!val); + sock_valbool_flag(ssk, SOCK_KEEPOPEN, !!val); + break; + case SO_PRIORITY: + ssk->sk_priority = val; + break; + case SO_SNDBUF: + case SO_SNDBUFFORCE: + ssk->sk_userlocks |= SOCK_SNDBUF_LOCK; + WRITE_ONCE(ssk->sk_sndbuf, sk->sk_sndbuf); + break; + case SO_RCVBUF: + case SO_RCVBUFFORCE: + ssk->sk_userlocks |= SOCK_RCVBUF_LOCK; + WRITE_ONCE(ssk->sk_rcvbuf, sk->sk_rcvbuf); + break; + case SO_MARK: + if (READ_ONCE(ssk->sk_mark) != sk->sk_mark) { + WRITE_ONCE(ssk->sk_mark, sk->sk_mark); + sk_dst_reset(ssk); + } + break; + case SO_INCOMING_CPU: + WRITE_ONCE(ssk->sk_incoming_cpu, val); + break; + } + + subflow->setsockopt_seq = msk->setsockopt_seq; + unlock_sock_fast(ssk, slow); + } + + release_sock(sk); +} + +static int mptcp_sol_socket_intval(struct mptcp_sock *msk, int optname, int val) +{ + sockptr_t optval = KERNEL_SOCKPTR(&val); + struct sock *sk = (struct sock *)msk; + int ret; + + ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + optval, sizeof(val)); + if (ret) + return ret; + + mptcp_sol_socket_sync_intval(msk, optname, val); + return 0; +} + +static void mptcp_so_incoming_cpu(struct mptcp_sock *msk, int val) +{ + struct sock *sk = (struct sock *)msk; + + WRITE_ONCE(sk->sk_incoming_cpu, val); + + mptcp_sol_socket_sync_intval(msk, SO_INCOMING_CPU, val); +} + +static int mptcp_setsockopt_sol_socket_tstamp(struct mptcp_sock *msk, int optname, int val) +{ + sockptr_t optval = KERNEL_SOCKPTR(&val); + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int ret; + + ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + optval, sizeof(val)); + if (ret) + return ret; + + lock_sock(sk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast(ssk); + + sock_set_timestamp(sk, optname, !!val); + unlock_sock_fast(ssk, slow); + } + + release_sock(sk); + return 0; +} + +static int mptcp_setsockopt_sol_socket_int(struct mptcp_sock *msk, int optname, + sockptr_t optval, + unsigned int optlen) +{ + int val, ret; + + ret = mptcp_get_int_option(msk, optval, optlen, &val); + if (ret) + return ret; + + switch (optname) { + case SO_KEEPALIVE: + mptcp_sol_socket_sync_intval(msk, optname, val); + return 0; + case SO_DEBUG: + case SO_MARK: + case SO_PRIORITY: + case SO_SNDBUF: + case SO_SNDBUFFORCE: + case SO_RCVBUF: + case SO_RCVBUFFORCE: + return mptcp_sol_socket_intval(msk, optname, val); + case SO_INCOMING_CPU: + mptcp_so_incoming_cpu(msk, val); + return 0; + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: 
+ case SO_TIMESTAMPNS_OLD: + case SO_TIMESTAMPNS_NEW: + return mptcp_setsockopt_sol_socket_tstamp(msk, optname, val); + } + + return -ENOPROTOOPT; +} + +static int mptcp_setsockopt_sol_socket_timestamping(struct mptcp_sock *msk, + int optname, + sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + struct so_timestamping timestamping; + int ret; + + if (optlen == sizeof(timestamping)) { + if (copy_from_sockptr(&timestamping, optval, + sizeof(timestamping))) + return -EFAULT; + } else if (optlen == sizeof(int)) { + memset(&timestamping, 0, sizeof(timestamping)); + + if (copy_from_sockptr(&timestamping.flags, optval, sizeof(int))) + return -EFAULT; + } else { + return -EINVAL; + } + + ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, + KERNEL_SOCKPTR(&timestamping), + sizeof(timestamping)); + if (ret) + return ret; + + lock_sock(sk); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast(ssk); + + sock_set_timestamping(sk, optname, timestamping); + unlock_sock_fast(ssk, slow); + } + + release_sock(sk); + + return 0; +} + +static int mptcp_setsockopt_sol_socket_linger(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + struct linger ling; + sockptr_t kopt; + int ret; + + if (optlen < sizeof(ling)) + return -EINVAL; + + if (copy_from_sockptr(&ling, optval, sizeof(ling))) + return -EFAULT; + + kopt = KERNEL_SOCKPTR(&ling); + ret = sock_setsockopt(sk->sk_socket, SOL_SOCKET, SO_LINGER, kopt, sizeof(ling)); + if (ret) + return ret; + + lock_sock(sk); + sockopt_seq_inc(msk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow = lock_sock_fast(ssk); + + if (!ling.l_onoff) { + sock_reset_flag(ssk, SOCK_LINGER); + } else { + ssk->sk_lingertime = sk->sk_lingertime; + sock_set_flag(ssk, SOCK_LINGER); + } + + subflow->setsockopt_seq = msk->setsockopt_seq; + unlock_sock_fast(ssk, slow); + } + + release_sock(sk); + return 0; +} + +static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = (struct sock *)msk; + struct socket *ssock; + int ret; + + switch (optname) { + case SO_REUSEPORT: + case SO_REUSEADDR: + case SO_BINDTODEVICE: + case SO_BINDTOIFINDEX: + lock_sock(sk); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + release_sock(sk); + return -EINVAL; + } + + ret = sock_setsockopt(ssock, SOL_SOCKET, optname, optval, optlen); + if (ret == 0) { + if (optname == SO_REUSEPORT) + sk->sk_reuseport = ssock->sk->sk_reuseport; + else if (optname == SO_REUSEADDR) + sk->sk_reuse = ssock->sk->sk_reuse; + else if (optname == SO_BINDTODEVICE) + sk->sk_bound_dev_if = ssock->sk->sk_bound_dev_if; + else if (optname == SO_BINDTOIFINDEX) + sk->sk_bound_dev_if = ssock->sk->sk_bound_dev_if; + } + release_sock(sk); + return ret; + case SO_KEEPALIVE: + case SO_PRIORITY: + case SO_SNDBUF: + case SO_SNDBUFFORCE: + case SO_RCVBUF: + case SO_RCVBUFFORCE: + case SO_MARK: + case SO_INCOMING_CPU: + case SO_DEBUG: + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: + case SO_TIMESTAMPNS_OLD: + case SO_TIMESTAMPNS_NEW: + return mptcp_setsockopt_sol_socket_int(msk, optname, optval, + optlen); + case SO_TIMESTAMPING_OLD: + case SO_TIMESTAMPING_NEW: + return mptcp_setsockopt_sol_socket_timestamping(msk, optname, + optval, optlen); + case SO_LINGER: + return
mptcp_setsockopt_sol_socket_linger(msk, optval, optlen); + case SO_RCVLOWAT: + case SO_RCVTIMEO_OLD: + case SO_RCVTIMEO_NEW: + case SO_SNDTIMEO_OLD: + case SO_SNDTIMEO_NEW: + case SO_BUSY_POLL: + case SO_PREFER_BUSY_POLL: + case SO_BUSY_POLL_BUDGET: + /* No need to copy: only relevant for msk */ + return sock_setsockopt(sk->sk_socket, SOL_SOCKET, optname, optval, optlen); + case SO_NO_CHECK: + case SO_DONTROUTE: + case SO_BROADCAST: + case SO_BSDCOMPAT: + case SO_PASSCRED: + case SO_PASSSEC: + case SO_RXQ_OVFL: + case SO_WIFI_STATUS: + case SO_NOFCS: + case SO_SELECT_ERR_QUEUE: + return 0; + } + + /* SO_OOBINLINE is not supported, let's avoid the related mess + * SO_ATTACH_FILTER, SO_ATTACH_BPF, SO_ATTACH_REUSEPORT_CBPF, + * SO_DETACH_REUSEPORT_BPF, SO_DETACH_FILTER, SO_LOCK_FILTER, + * we must be careful with subflows + * + * SO_ATTACH_REUSEPORT_EBPF is not supported, at it checks + * explicitly the sk_protocol field + * + * SO_PEEK_OFF is unsupported, as it is for plain TCP + * SO_MAX_PACING_RATE is unsupported, we must be careful with subflows + * SO_CNX_ADVICE is currently unsupported, could possibly be relevant, + * but likely needs careful design + * + * SO_ZEROCOPY is currently unsupported, TODO in sndmsg + * SO_TXTIME is currently unsupported + */ + + return -EOPNOTSUPP; +} + +static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = (struct sock *)msk; + int ret = -EOPNOTSUPP; + struct socket *ssock; + + switch (optname) { + case IPV6_V6ONLY: + case IPV6_TRANSPARENT: + case IPV6_FREEBIND: + lock_sock(sk); + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + release_sock(sk); + return -EINVAL; + } + + ret = tcp_setsockopt(ssock->sk, SOL_IPV6, optname, optval, optlen); + if (ret != 0) { + release_sock(sk); + return ret; + } + + sockopt_seq_inc(msk); + + switch (optname) { + case IPV6_V6ONLY: + sk->sk_ipv6only = ssock->sk->sk_ipv6only; + break; + case IPV6_TRANSPARENT: + inet_sk(sk)->transparent = inet_sk(ssock->sk)->transparent; + break; + case IPV6_FREEBIND: + inet_sk(sk)->freebind = inet_sk(ssock->sk)->freebind; + break; + } + + release_sock(sk); + break; + } + + return ret; +} + +static bool mptcp_supported_sockopt(int level, int optname) +{ + if (level == SOL_IP) { + switch (optname) { + /* should work fine */ + case IP_FREEBIND: + case IP_TRANSPARENT: + + /* the following are control cmsg related */ + case IP_PKTINFO: + case IP_RECVTTL: + case IP_RECVTOS: + case IP_RECVOPTS: + case IP_RETOPTS: + case IP_PASSSEC: + case IP_RECVORIGDSTADDR: + case IP_CHECKSUM: + case IP_RECVFRAGSIZE: + + /* common stuff that need some love */ + case IP_TOS: + case IP_TTL: + case IP_BIND_ADDRESS_NO_PORT: + case IP_MTU_DISCOVER: + case IP_RECVERR: + + /* possibly less common may deserve some love */ + case IP_MINTTL: + + /* the following is apparently a no-op for plain TCP */ + case IP_RECVERR_RFC4884: + return true; + } + + /* IP_OPTIONS is not supported, needs subflow care */ + /* IP_HDRINCL, IP_NODEFRAG are not supported, RAW specific */ + /* IP_MULTICAST_TTL, IP_MULTICAST_LOOP, IP_UNICAST_IF, + * IP_ADD_MEMBERSHIP, IP_ADD_SOURCE_MEMBERSHIP, IP_DROP_MEMBERSHIP, + * IP_DROP_SOURCE_MEMBERSHIP, IP_BLOCK_SOURCE, IP_UNBLOCK_SOURCE, + * MCAST_JOIN_GROUP, MCAST_LEAVE_GROUP MCAST_JOIN_SOURCE_GROUP, + * MCAST_LEAVE_SOURCE_GROUP, MCAST_BLOCK_SOURCE, MCAST_UNBLOCK_SOURCE, + * MCAST_MSFILTER, IP_MULTICAST_ALL are not supported, better not deal + * with mcast stuff + */ + /* IP_IPSEC_POLICY, IP_XFRM_POLICY are nut supported, unrelated 
here */ + return false; + } + if (level == SOL_IPV6) { + switch (optname) { + case IPV6_V6ONLY: + + /* the following are control cmsg related */ + case IPV6_RECVPKTINFO: + case IPV6_2292PKTINFO: + case IPV6_RECVHOPLIMIT: + case IPV6_2292HOPLIMIT: + case IPV6_RECVRTHDR: + case IPV6_2292RTHDR: + case IPV6_RECVHOPOPTS: + case IPV6_2292HOPOPTS: + case IPV6_RECVDSTOPTS: + case IPV6_2292DSTOPTS: + case IPV6_RECVTCLASS: + case IPV6_FLOWINFO: + case IPV6_RECVPATHMTU: + case IPV6_RECVORIGDSTADDR: + case IPV6_RECVFRAGSIZE: + + /* the following ones need some love but are quite common */ + case IPV6_TCLASS: + case IPV6_TRANSPARENT: + case IPV6_FREEBIND: + case IPV6_PKTINFO: + case IPV6_2292PKTOPTIONS: + case IPV6_UNICAST_HOPS: + case IPV6_MTU_DISCOVER: + case IPV6_MTU: + case IPV6_RECVERR: + case IPV6_FLOWINFO_SEND: + case IPV6_FLOWLABEL_MGR: + case IPV6_MINHOPCOUNT: + case IPV6_DONTFRAG: + case IPV6_AUTOFLOWLABEL: + + /* the following one is a no-op for plain TCP */ + case IPV6_RECVERR_RFC4884: + return true; + } + + /* IPV6_HOPOPTS, IPV6_RTHDRDSTOPTS, IPV6_RTHDR, IPV6_DSTOPTS are + * not supported + */ + /* IPV6_MULTICAST_HOPS, IPV6_MULTICAST_LOOP, IPV6_UNICAST_IF, + * IPV6_MULTICAST_IF, IPV6_ADDRFORM, + * IPV6_ADD_MEMBERSHIP, IPV6_DROP_MEMBERSHIP, IPV6_JOIN_ANYCAST, + * IPV6_LEAVE_ANYCAST, IPV6_MULTICAST_ALL, MCAST_JOIN_GROUP, MCAST_LEAVE_GROUP, + * MCAST_JOIN_SOURCE_GROUP, MCAST_LEAVE_SOURCE_GROUP, + * MCAST_BLOCK_SOURCE, MCAST_UNBLOCK_SOURCE, MCAST_MSFILTER + * are not supported better not deal with mcast + */ + /* IPV6_ROUTER_ALERT, IPV6_ROUTER_ALERT_ISOLATE are not supported, since are evil */ + + /* IPV6_IPSEC_POLICY, IPV6_XFRM_POLICY are not supported */ + /* IPV6_ADDR_PREFERENCES is not supported, we must be careful with subflows */ + return false; + } + if (level == SOL_TCP) { + switch (optname) { + /* the following are no-op or should work just fine */ + case TCP_THIN_DUPACK: + case TCP_DEFER_ACCEPT: + + /* the following need some love */ + case TCP_MAXSEG: + case TCP_NODELAY: + case TCP_THIN_LINEAR_TIMEOUTS: + case TCP_CONGESTION: + case TCP_CORK: + case TCP_KEEPIDLE: + case TCP_KEEPINTVL: + case TCP_KEEPCNT: + case TCP_SYNCNT: + case TCP_SAVE_SYN: + case TCP_LINGER2: + case TCP_WINDOW_CLAMP: + case TCP_QUICKACK: + case TCP_USER_TIMEOUT: + case TCP_TIMESTAMP: + case TCP_NOTSENT_LOWAT: + case TCP_TX_DELAY: + case TCP_INQ: + case TCP_FASTOPEN_CONNECT: + return true; + } + + /* TCP_MD5SIG, TCP_MD5SIG_EXT are not supported, MD5 is not compatible with MPTCP */ + + /* TCP_REPAIR, TCP_REPAIR_QUEUE, TCP_QUEUE_SEQ, TCP_REPAIR_OPTIONS, + * TCP_REPAIR_WINDOW are not supported, better avoid this mess + */ + /* TCP_FASTOPEN_KEY, TCP_FASTOPEN, TCP_FASTOPEN_NO_COOKIE, + * are not supported fastopen is currently unsupported + */ + } + return false; +} + +static int mptcp_setsockopt_sol_tcp_congestion(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + char name[TCP_CA_NAME_MAX]; + bool cap_net_admin; + int ret; + + if (optlen < 1) + return -EINVAL; + + ret = strncpy_from_sockptr(name, optval, + min_t(long, TCP_CA_NAME_MAX - 1, optlen)); + if (ret < 0) + return -EFAULT; + + name[ret] = 0; + + cap_net_admin = ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN); + + ret = 0; + lock_sock(sk); + sockopt_seq_inc(msk); + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + int err; + + lock_sock(ssk); + err = tcp_set_congestion_control(ssk, name, true, cap_net_admin); + if 
(err < 0 && ret == 0) + ret = err; + subflow->setsockopt_seq = msk->setsockopt_seq; + release_sock(ssk); + } + + if (ret == 0) + strcpy(msk->ca_name, name); + + release_sock(sk); + return ret; +} + +static int mptcp_setsockopt_sol_tcp_cork(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int val; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + sockopt_seq_inc(msk); + msk->cork = !!val; + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + lock_sock(ssk); + __tcp_sock_set_cork(ssk, !!val); + release_sock(ssk); + } + if (!val) + mptcp_check_and_set_pending(sk); + release_sock(sk); + + return 0; +} + +static int mptcp_setsockopt_sol_tcp_nodelay(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int val; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + sockopt_seq_inc(msk); + msk->nodelay = !!val; + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + lock_sock(ssk); + __tcp_sock_set_nodelay(ssk, !!val); + release_sock(ssk); + } + if (val) + mptcp_check_and_set_pending(sk); + release_sock(sk); + + return 0; +} + +static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = (struct sock *)msk; + struct inet_sock *issk; + struct socket *ssock; + int err; + + err = ip_setsockopt(sk, SOL_IP, optname, optval, optlen); + if (err != 0) + return err; + + lock_sock(sk); + + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) { + release_sock(sk); + return -EINVAL; + } + + issk = inet_sk(ssock->sk); + + switch (optname) { + case IP_FREEBIND: + issk->freebind = inet_sk(sk)->freebind; + break; + case IP_TRANSPARENT: + issk->transparent = inet_sk(sk)->transparent; + break; + default: + release_sock(sk); + WARN_ON_ONCE(1); + return -EOPNOTSUPP; + } + + sockopt_seq_inc(msk); + release_sock(sk); + return 0; +} + +static int mptcp_setsockopt_v4_set_tos(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = (struct sock *)msk; + int err, val; + + err = ip_setsockopt(sk, SOL_IP, optname, optval, optlen); + + if (err != 0) + return err; + + lock_sock(sk); + sockopt_seq_inc(msk); + val = inet_sk(sk)->tos; + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + bool slow; + + slow = lock_sock_fast(ssk); + __ip_sock_set_tos(ssk, val); + unlock_sock_fast(ssk, slow); + } + release_sock(sk); + + return err; +} + +static int mptcp_setsockopt_v4(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + switch (optname) { + case IP_FREEBIND: + case IP_TRANSPARENT: + return mptcp_setsockopt_sol_ip_set_transparent(msk, optname, optval, optlen); + case IP_TOS: + return mptcp_setsockopt_v4_set_tos(msk, optname, optval, optlen); + } + + return -EOPNOTSUPP; +} + +static int mptcp_setsockopt_sol_tcp_defer(struct mptcp_sock *msk, sockptr_t optval, + unsigned int optlen) +{ + struct socket *listener; + + listener = __mptcp_nmpc_socket(msk); + if (!listener) + return 0; /* TCP_DEFER_ACCEPT does not fail */ + + return 
tcp_setsockopt(listener->sk, SOL_TCP, TCP_DEFER_ACCEPT, optval, optlen); +} + +static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = (struct sock *)msk; + struct socket *sock; + int ret = -EINVAL; + + /* Limit to first subflow, before the connection establishment */ + lock_sock(sk); + sock = __mptcp_nmpc_socket(msk); + if (!sock) + goto unlock; + + ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen); + +unlock: + release_sock(sk); + return ret; +} + +static int mptcp_setsockopt_sol_tcp(struct mptcp_sock *msk, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = (void *)msk; + int ret, val; + + switch (optname) { + case TCP_INQ: + ret = mptcp_get_int_option(msk, optval, optlen, &val); + if (ret) + return ret; + if (val < 0 || val > 1) + return -EINVAL; + + lock_sock(sk); + msk->recvmsg_inq = !!val; + release_sock(sk); + return 0; + case TCP_ULP: + return -EOPNOTSUPP; + case TCP_CONGESTION: + return mptcp_setsockopt_sol_tcp_congestion(msk, optval, optlen); + case TCP_CORK: + return mptcp_setsockopt_sol_tcp_cork(msk, optval, optlen); + case TCP_NODELAY: + return mptcp_setsockopt_sol_tcp_nodelay(msk, optval, optlen); + case TCP_DEFER_ACCEPT: + return mptcp_setsockopt_sol_tcp_defer(msk, optval, optlen); + case TCP_FASTOPEN_CONNECT: + return mptcp_setsockopt_first_sf_only(msk, SOL_TCP, optname, + optval, optlen); + } + + return -EOPNOTSUPP; +} + +int mptcp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct sock *ssk; + + pr_debug("msk=%p", msk); + + if (level == SOL_SOCKET) + return mptcp_setsockopt_sol_socket(msk, optname, optval, optlen); + + if (!mptcp_supported_sockopt(level, optname)) + return -ENOPROTOOPT; + + /* @@ the meaning of setsockopt() when the socket is connected and + * there are multiple subflows is not yet defined. It is up to the + * MPTCP-level socket to configure the subflows until the subflow + * is in TCP fallback, when TCP socket options are passed through + * to the one remaining subflow. 
+ */ + lock_sock(sk); + ssk = __mptcp_tcp_fallback(msk); + release_sock(sk); + if (ssk) + return tcp_setsockopt(ssk, level, optname, optval, optlen); + + if (level == SOL_IP) + return mptcp_setsockopt_v4(msk, optname, optval, optlen); + + if (level == SOL_IPV6) + return mptcp_setsockopt_v6(msk, optname, optval, optlen); + + if (level == SOL_TCP) + return mptcp_setsockopt_sol_tcp(msk, optname, optval, optlen); + + return -EOPNOTSUPP; +} + +static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = (struct sock *)msk; + struct socket *ssock; + int ret = -EINVAL; + struct sock *ssk; + + lock_sock(sk); + ssk = msk->first; + if (ssk) { + ret = tcp_getsockopt(ssk, level, optname, optval, optlen); + goto out; + } + + ssock = __mptcp_nmpc_socket(msk); + if (!ssock) + goto out; + + ret = tcp_getsockopt(ssock->sk, level, optname, optval, optlen); + +out: + release_sock(sk); + return ret; +} + +void mptcp_diag_fill_info(struct mptcp_sock *msk, struct mptcp_info *info) +{ + u32 flags = 0; + u8 val; + + memset(info, 0, sizeof(*info)); + + info->mptcpi_subflows = READ_ONCE(msk->pm.subflows); + info->mptcpi_add_addr_signal = READ_ONCE(msk->pm.add_addr_signaled); + info->mptcpi_add_addr_accepted = READ_ONCE(msk->pm.add_addr_accepted); + info->mptcpi_local_addr_used = READ_ONCE(msk->pm.local_addr_used); + info->mptcpi_subflows_max = mptcp_pm_get_subflows_max(msk); + val = mptcp_pm_get_add_addr_signal_max(msk); + info->mptcpi_add_addr_signal_max = val; + val = mptcp_pm_get_add_addr_accept_max(msk); + info->mptcpi_add_addr_accepted_max = val; + info->mptcpi_local_addr_max = mptcp_pm_get_local_addr_max(msk); + if (test_bit(MPTCP_FALLBACK_DONE, &msk->flags)) + flags |= MPTCP_INFO_FLAG_FALLBACK; + if (READ_ONCE(msk->can_ack)) + flags |= MPTCP_INFO_FLAG_REMOTE_KEY_RECEIVED; + info->mptcpi_flags = flags; + info->mptcpi_token = READ_ONCE(msk->token); + info->mptcpi_write_seq = READ_ONCE(msk->write_seq); + info->mptcpi_snd_una = READ_ONCE(msk->snd_una); + info->mptcpi_rcv_nxt = READ_ONCE(msk->ack_seq); + info->mptcpi_csum_enabled = READ_ONCE(msk->csum_enabled); +} +EXPORT_SYMBOL_GPL(mptcp_diag_fill_info); + +static int mptcp_getsockopt_info(struct mptcp_sock *msk, char __user *optval, int __user *optlen) +{ + struct mptcp_info m_info; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(unsigned int, len, sizeof(struct mptcp_info)); + + mptcp_diag_fill_info(msk, &m_info); + + if (put_user(len, optlen)) + return -EFAULT; + + if (copy_to_user(optval, &m_info, len)) + return -EFAULT; + + return 0; +} + +static int mptcp_put_subflow_data(struct mptcp_subflow_data *sfd, + char __user *optval, + u32 copied, + int __user *optlen) +{ + u32 copylen = min_t(u32, sfd->size_subflow_data, sizeof(*sfd)); + + if (copied) + copied += sfd->size_subflow_data; + else + copied = copylen; + + if (put_user(copied, optlen)) + return -EFAULT; + + if (copy_to_user(optval, sfd, copylen)) + return -EFAULT; + + return 0; +} + +static int mptcp_get_subflow_data(struct mptcp_subflow_data *sfd, + char __user *optval, int __user *optlen) +{ + int len, copylen; + + if (get_user(len, optlen)) + return -EFAULT; + + /* if mptcp_subflow_data size is changed, need to adjust + * this function to deal with programs using old version. 
+ */ + BUILD_BUG_ON(sizeof(*sfd) != MIN_INFO_OPTLEN_SIZE); + + if (len < MIN_INFO_OPTLEN_SIZE) + return -EINVAL; + + memset(sfd, 0, sizeof(*sfd)); + + copylen = min_t(unsigned int, len, sizeof(*sfd)); + if (copy_from_user(sfd, optval, copylen)) + return -EFAULT; + + /* size_subflow_data is u32, but len is signed */ + if (sfd->size_subflow_data > INT_MAX || + sfd->size_user > INT_MAX) + return -EINVAL; + + if (sfd->size_subflow_data < MIN_INFO_OPTLEN_SIZE || + sfd->size_subflow_data > len) + return -EINVAL; + + if (sfd->num_subflows || sfd->size_kernel) + return -EINVAL; + + return len - sfd->size_subflow_data; +} + +static int mptcp_getsockopt_tcpinfo(struct mptcp_sock *msk, char __user *optval, + int __user *optlen) +{ + struct mptcp_subflow_context *subflow; + struct sock *sk = &msk->sk.icsk_inet.sk; + unsigned int sfcount = 0, copied = 0; + struct mptcp_subflow_data sfd; + char __user *infoptr; + int len; + + len = mptcp_get_subflow_data(&sfd, optval, optlen); + if (len < 0) + return len; + + sfd.size_kernel = sizeof(struct tcp_info); + sfd.size_user = min_t(unsigned int, sfd.size_user, + sizeof(struct tcp_info)); + + infoptr = optval + sfd.size_subflow_data; + + lock_sock(sk); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + ++sfcount; + + if (len && len >= sfd.size_user) { + struct tcp_info info; + + tcp_get_info(ssk, &info); + + if (copy_to_user(infoptr, &info, sfd.size_user)) { + release_sock(sk); + return -EFAULT; + } + + infoptr += sfd.size_user; + copied += sfd.size_user; + len -= sfd.size_user; + } + } + + release_sock(sk); + + sfd.num_subflows = sfcount; + + if (mptcp_put_subflow_data(&sfd, optval, copied, optlen)) + return -EFAULT; + + return 0; +} + +static void mptcp_get_sub_addrs(const struct sock *sk, struct mptcp_subflow_addrs *a) +{ + struct inet_sock *inet = inet_sk(sk); + + memset(a, 0, sizeof(*a)); + + if (sk->sk_family == AF_INET) { + a->sin_local.sin_family = AF_INET; + a->sin_local.sin_port = inet->inet_sport; + a->sin_local.sin_addr.s_addr = inet->inet_rcv_saddr; + + if (!a->sin_local.sin_addr.s_addr) + a->sin_local.sin_addr.s_addr = inet->inet_saddr; + + a->sin_remote.sin_family = AF_INET; + a->sin_remote.sin_port = inet->inet_dport; + a->sin_remote.sin_addr.s_addr = inet->inet_daddr; +#if IS_ENABLED(CONFIG_IPV6) + } else if (sk->sk_family == AF_INET6) { + const struct ipv6_pinfo *np = inet6_sk(sk); + + if (WARN_ON_ONCE(!np)) + return; + + a->sin6_local.sin6_family = AF_INET6; + a->sin6_local.sin6_port = inet->inet_sport; + + if (ipv6_addr_any(&sk->sk_v6_rcv_saddr)) + a->sin6_local.sin6_addr = np->saddr; + else + a->sin6_local.sin6_addr = sk->sk_v6_rcv_saddr; + + a->sin6_remote.sin6_family = AF_INET6; + a->sin6_remote.sin6_port = inet->inet_dport; + a->sin6_remote.sin6_addr = sk->sk_v6_daddr; +#endif + } +} + +static int mptcp_getsockopt_subflow_addrs(struct mptcp_sock *msk, char __user *optval, + int __user *optlen) +{ + struct sock *sk = &msk->sk.icsk_inet.sk; + struct mptcp_subflow_context *subflow; + unsigned int sfcount = 0, copied = 0; + struct mptcp_subflow_data sfd; + char __user *addrptr; + int len; + + len = mptcp_get_subflow_data(&sfd, optval, optlen); + if (len < 0) + return len; + + sfd.size_kernel = sizeof(struct mptcp_subflow_addrs); + sfd.size_user = min_t(unsigned int, sfd.size_user, + sizeof(struct mptcp_subflow_addrs)); + + addrptr = optval + sfd.size_subflow_data; + + lock_sock(sk); + + mptcp_for_each_subflow(msk, subflow) { + struct sock *ssk = mptcp_subflow_tcp_sock(subflow); + + ++sfcount; 
+ + if (len && len >= sfd.size_user) { + struct mptcp_subflow_addrs a; + + mptcp_get_sub_addrs(ssk, &a); + + if (copy_to_user(addrptr, &a, sfd.size_user)) { + release_sock(sk); + return -EFAULT; + } + + addrptr += sfd.size_user; + copied += sfd.size_user; + len -= sfd.size_user; + } + } + + release_sock(sk); + + sfd.num_subflows = sfcount; + + if (mptcp_put_subflow_data(&sfd, optval, copied, optlen)) + return -EFAULT; + + return 0; +} + +static int mptcp_put_int_option(struct mptcp_sock *msk, char __user *optval, + int __user *optlen, int val) +{ + int len; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + if (len < sizeof(int) && len > 0 && val >= 0 && val <= 255) { + unsigned char ucval = (unsigned char)val; + + len = 1; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &ucval, 1)) + return -EFAULT; + } else { + len = min_t(unsigned int, len, sizeof(int)); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + } + + return 0; +} + +static int mptcp_getsockopt_sol_tcp(struct mptcp_sock *msk, int optname, + char __user *optval, int __user *optlen) +{ + switch (optname) { + case TCP_ULP: + case TCP_CONGESTION: + case TCP_INFO: + case TCP_CC_INFO: + case TCP_DEFER_ACCEPT: + case TCP_FASTOPEN_CONNECT: + return mptcp_getsockopt_first_sf_only(msk, SOL_TCP, optname, + optval, optlen); + case TCP_INQ: + return mptcp_put_int_option(msk, optval, optlen, msk->recvmsg_inq); + case TCP_CORK: + return mptcp_put_int_option(msk, optval, optlen, msk->cork); + case TCP_NODELAY: + return mptcp_put_int_option(msk, optval, optlen, msk->nodelay); + } + return -EOPNOTSUPP; +} + +static int mptcp_getsockopt_v4(struct mptcp_sock *msk, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = (void *)msk; + + switch (optname) { + case IP_TOS: + return mptcp_put_int_option(msk, optval, optlen, inet_sk(sk)->tos); + } + + return -EOPNOTSUPP; +} + +static int mptcp_getsockopt_sol_mptcp(struct mptcp_sock *msk, int optname, + char __user *optval, int __user *optlen) +{ + switch (optname) { + case MPTCP_INFO: + return mptcp_getsockopt_info(msk, optval, optlen); + case MPTCP_TCPINFO: + return mptcp_getsockopt_tcpinfo(msk, optval, optlen); + case MPTCP_SUBFLOW_ADDRS: + return mptcp_getsockopt_subflow_addrs(msk, optval, optlen); + } + + return -EOPNOTSUPP; +} + +int mptcp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *option) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct sock *ssk; + + pr_debug("msk=%p", msk); + + /* @@ the meaning of setsockopt() when the socket is connected and + * there are multiple subflows is not yet defined. It is up to the + * MPTCP-level socket to configure the subflows until the subflow + * is in TCP fallback, when socket options are passed through + * to the one remaining subflow. 
+ */ + lock_sock(sk); + ssk = __mptcp_tcp_fallback(msk); + release_sock(sk); + if (ssk) + return tcp_getsockopt(ssk, level, optname, optval, option); + + if (level == SOL_IP) + return mptcp_getsockopt_v4(msk, optname, optval, option); + if (level == SOL_TCP) + return mptcp_getsockopt_sol_tcp(msk, optname, optval, option); + if (level == SOL_MPTCP) + return mptcp_getsockopt_sol_mptcp(msk, optname, optval, option); + return -EOPNOTSUPP; +} + +static void sync_socket_options(struct mptcp_sock *msk, struct sock *ssk) +{ + static const unsigned int tx_rx_locks = SOCK_RCVBUF_LOCK | SOCK_SNDBUF_LOCK; + struct sock *sk = (struct sock *)msk; + + if (ssk->sk_prot->keepalive) { + if (sock_flag(sk, SOCK_KEEPOPEN)) + ssk->sk_prot->keepalive(ssk, 1); + else + ssk->sk_prot->keepalive(ssk, 0); + } + + ssk->sk_priority = sk->sk_priority; + ssk->sk_bound_dev_if = sk->sk_bound_dev_if; + ssk->sk_incoming_cpu = sk->sk_incoming_cpu; + __ip_sock_set_tos(ssk, inet_sk(sk)->tos); + + if (sk->sk_userlocks & tx_rx_locks) { + ssk->sk_userlocks |= sk->sk_userlocks & tx_rx_locks; + if (sk->sk_userlocks & SOCK_SNDBUF_LOCK) + WRITE_ONCE(ssk->sk_sndbuf, sk->sk_sndbuf); + if (sk->sk_userlocks & SOCK_RCVBUF_LOCK) + WRITE_ONCE(ssk->sk_rcvbuf, sk->sk_rcvbuf); + } + + if (sock_flag(sk, SOCK_LINGER)) { + ssk->sk_lingertime = sk->sk_lingertime; + sock_set_flag(ssk, SOCK_LINGER); + } else { + sock_reset_flag(ssk, SOCK_LINGER); + } + + if (sk->sk_mark != ssk->sk_mark) { + ssk->sk_mark = sk->sk_mark; + sk_dst_reset(ssk); + } + + sock_valbool_flag(ssk, SOCK_DBG, sock_flag(sk, SOCK_DBG)); + + if (inet_csk(sk)->icsk_ca_ops != inet_csk(ssk)->icsk_ca_ops) + tcp_set_congestion_control(ssk, msk->ca_name, false, true); + __tcp_sock_set_cork(ssk, !!msk->cork); + __tcp_sock_set_nodelay(ssk, !!msk->nodelay); + + inet_sk(ssk)->transparent = inet_sk(sk)->transparent; + inet_sk(ssk)->freebind = inet_sk(sk)->freebind; +} + +static void __mptcp_sockopt_sync(struct mptcp_sock *msk, struct sock *ssk) +{ + bool slow = lock_sock_fast(ssk); + + sync_socket_options(msk, ssk); + + unlock_sock_fast(ssk, slow); +} + +void mptcp_sockopt_sync(struct mptcp_sock *msk, struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + + msk_owned_by_me(msk); + + if (READ_ONCE(subflow->setsockopt_seq) != msk->setsockopt_seq) { + __mptcp_sockopt_sync(msk, ssk); + + subflow->setsockopt_seq = msk->setsockopt_seq; + } +} + +void mptcp_sockopt_sync_locked(struct mptcp_sock *msk, struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + + msk_owned_by_me(msk); + + if (READ_ONCE(subflow->setsockopt_seq) != msk->setsockopt_seq) { + sync_socket_options(msk, ssk); + + subflow->setsockopt_seq = msk->setsockopt_seq; + } +} diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c new file mode 100644 index 000000000..45d20e20c --- /dev/null +++ b/net/mptcp/subflow.c @@ -0,0 +1,2000 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP + * + * Copyright (c) 2017 - 2019, Intel Corporation. 
+ */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +#include +#include +#endif +#include +#include +#include "protocol.h" +#include "mib.h" + +#include + +static void mptcp_subflow_ops_undo_override(struct sock *ssk); + +static void SUBFLOW_REQ_INC_STATS(struct request_sock *req, + enum linux_mptcp_mib_field field) +{ + MPTCP_INC_STATS(sock_net(req_to_sk(req)), field); +} + +static void subflow_req_destructor(struct request_sock *req) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + + pr_debug("subflow_req=%p", subflow_req); + + if (subflow_req->msk) + sock_put((struct sock *)subflow_req->msk); + + mptcp_token_destroy_request(req); +} + +static void subflow_generate_hmac(u64 key1, u64 key2, u32 nonce1, u32 nonce2, + void *hmac) +{ + u8 msg[8]; + + put_unaligned_be32(nonce1, &msg[0]); + put_unaligned_be32(nonce2, &msg[4]); + + mptcp_crypto_hmac_sha(key1, key2, msg, 8, hmac); +} + +static bool mptcp_can_accept_new_subflow(const struct mptcp_sock *msk) +{ + return mptcp_is_fully_established((void *)msk) && + ((mptcp_pm_is_userspace(msk) && + mptcp_userspace_pm_active(msk)) || + READ_ONCE(msk->pm.accept_subflow)); +} + +/* validate received token and create truncated hmac and nonce for SYN-ACK */ +static void subflow_req_create_thmac(struct mptcp_subflow_request_sock *subflow_req) +{ + struct mptcp_sock *msk = subflow_req->msk; + u8 hmac[SHA256_DIGEST_SIZE]; + + get_random_bytes(&subflow_req->local_nonce, sizeof(u32)); + + subflow_generate_hmac(msk->local_key, msk->remote_key, + subflow_req->local_nonce, + subflow_req->remote_nonce, hmac); + + subflow_req->thmac = get_unaligned_be64(hmac); +} + +static struct mptcp_sock *subflow_token_join_request(struct request_sock *req) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct mptcp_sock *msk; + int local_id; + + msk = mptcp_token_get_sock(sock_net(req_to_sk(req)), subflow_req->token); + if (!msk) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINNOTOKEN); + return NULL; + } + + local_id = mptcp_pm_get_local_id(msk, (struct sock_common *)req); + if (local_id < 0) { + sock_put((struct sock *)msk); + return NULL; + } + subflow_req->local_id = local_id; + + return msk; +} + +static void subflow_init_req(struct request_sock *req, const struct sock *sk_listener) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + + subflow_req->mp_capable = 0; + subflow_req->mp_join = 0; + subflow_req->csum_reqd = mptcp_is_checksum_enabled(sock_net(sk_listener)); + subflow_req->allow_join_id0 = mptcp_allow_join_id0(sock_net(sk_listener)); + subflow_req->msk = NULL; + mptcp_token_init_request(req); +} + +static bool subflow_use_different_sport(struct mptcp_sock *msk, const struct sock *sk) +{ + return inet_sk(sk)->inet_sport != inet_sk((struct sock *)msk)->inet_sport; +} + +static void subflow_add_reset_reason(struct sk_buff *skb, u8 reason) +{ + struct mptcp_ext *mpext = skb_ext_add(skb, SKB_EXT_MPTCP); + + if (mpext) { + memset(mpext, 0, sizeof(*mpext)); + mpext->reset_reason = reason; + } +} + +/* Init mptcp request socket. + * + * Returns an error code if a JOIN has failed and a TCP reset + * should be sent. 
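The SYN-ACK thmac built by subflow_req_create_thmac() above follows RFC 8684: an HMAC-SHA256 keyed with local_key||remote_key over the two 32-bit nonces, truncated to its 64 most significant bits. A standalone sketch of that derivation follows; hmac_sha256() is an assumed helper (for example backed by OpenSSL) and is not part of this patch:

#include <stdint.h>
#include <stddef.h>

/* assumed helper: HMAC-SHA256 producing a 32-byte tag (not in this patch) */
void hmac_sha256(const uint8_t *key, size_t key_len,
                 const uint8_t *msg, size_t msg_len, uint8_t tag[32]);

static void put_be32(uint8_t *p, uint32_t v)
{
        p[0] = v >> 24; p[1] = v >> 16; p[2] = v >> 8; p[3] = v;
}

static void put_be64(uint8_t *p, uint64_t v)
{
        put_be32(p, v >> 32);
        put_be32(p + 4, (uint32_t)v);
}

/* mirrors subflow_generate_hmac() plus the get_unaligned_be64() truncation */
uint64_t mptcp_thmac(uint64_t local_key, uint64_t remote_key,
                     uint32_t local_nonce, uint32_t remote_nonce)
{
        uint8_t key[16], msg[8], tag[32];
        uint64_t thmac = 0;
        int i;

        put_be64(key, local_key);          /* HMAC key = key1 || key2 */
        put_be64(key + 8, remote_key);
        put_be32(msg, local_nonce);        /* msg = nonce1 || nonce2 */
        put_be32(msg + 4, remote_nonce);

        hmac_sha256(key, sizeof(key), msg, sizeof(msg), tag);

        for (i = 0; i < 8; i++)            /* keep the 64 most significant bits */
                thmac = (thmac << 8) | tag[i];
        return thmac;
}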
+ */ +static int subflow_check_req(struct request_sock *req, + const struct sock *sk_listener, + struct sk_buff *skb) +{ + struct mptcp_subflow_context *listener = mptcp_subflow_ctx(sk_listener); + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct mptcp_options_received mp_opt; + bool opt_mp_capable, opt_mp_join; + + pr_debug("subflow_req=%p, listener=%p", subflow_req, listener); + +#ifdef CONFIG_TCP_MD5SIG + /* no MPTCP if MD5SIG is enabled on this socket or we may run out of + * TCP option space. + */ + if (rcu_access_pointer(tcp_sk(sk_listener)->md5sig_info)) + return -EINVAL; +#endif + + mptcp_get_options(skb, &mp_opt); + + opt_mp_capable = !!(mp_opt.suboptions & OPTION_MPTCP_MPC_SYN); + opt_mp_join = !!(mp_opt.suboptions & OPTION_MPTCP_MPJ_SYN); + if (opt_mp_capable) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_MPCAPABLEPASSIVE); + + if (opt_mp_join) + return 0; + } else if (opt_mp_join) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINSYNRX); + } + + if (opt_mp_capable && listener->request_mptcp) { + int err, retries = MPTCP_TOKEN_MAX_RETRIES; + + subflow_req->ssn_offset = TCP_SKB_CB(skb)->seq; +again: + do { + get_random_bytes(&subflow_req->local_key, sizeof(subflow_req->local_key)); + } while (subflow_req->local_key == 0); + + if (unlikely(req->syncookie)) { + mptcp_crypto_key_sha(subflow_req->local_key, + &subflow_req->token, + &subflow_req->idsn); + if (mptcp_token_exists(subflow_req->token)) { + if (retries-- > 0) + goto again; + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_TOKENFALLBACKINIT); + } else { + subflow_req->mp_capable = 1; + } + return 0; + } + + err = mptcp_token_new_request(req); + if (err == 0) + subflow_req->mp_capable = 1; + else if (retries-- > 0) + goto again; + else + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_TOKENFALLBACKINIT); + + } else if (opt_mp_join && listener->request_mptcp) { + subflow_req->ssn_offset = TCP_SKB_CB(skb)->seq; + subflow_req->mp_join = 1; + subflow_req->backup = mp_opt.backup; + subflow_req->remote_id = mp_opt.join_id; + subflow_req->token = mp_opt.token; + subflow_req->remote_nonce = mp_opt.nonce; + subflow_req->msk = subflow_token_join_request(req); + + /* Can't fall back to TCP in this case. 
*/ + if (!subflow_req->msk) { + subflow_add_reset_reason(skb, MPTCP_RST_EMPTCP); + return -EPERM; + } + + if (subflow_use_different_sport(subflow_req->msk, sk_listener)) { + pr_debug("syn inet_sport=%d %d", + ntohs(inet_sk(sk_listener)->inet_sport), + ntohs(inet_sk((struct sock *)subflow_req->msk)->inet_sport)); + if (!mptcp_pm_sport_in_anno_list(subflow_req->msk, sk_listener)) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_MISMATCHPORTSYNRX); + return -EPERM; + } + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINPORTSYNRX); + } + + subflow_req_create_thmac(subflow_req); + + if (unlikely(req->syncookie)) { + if (mptcp_can_accept_new_subflow(subflow_req->msk)) + subflow_init_req_cookie_join_save(subflow_req, skb); + else + return -EPERM; + } + + pr_debug("token=%u, remote_nonce=%u msk=%p", subflow_req->token, + subflow_req->remote_nonce, subflow_req->msk); + } + + return 0; +} + +int mptcp_subflow_init_cookie_req(struct request_sock *req, + const struct sock *sk_listener, + struct sk_buff *skb) +{ + struct mptcp_subflow_context *listener = mptcp_subflow_ctx(sk_listener); + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct mptcp_options_received mp_opt; + bool opt_mp_capable, opt_mp_join; + int err; + + subflow_init_req(req, sk_listener); + mptcp_get_options(skb, &mp_opt); + + opt_mp_capable = !!(mp_opt.suboptions & OPTION_MPTCP_MPC_ACK); + opt_mp_join = !!(mp_opt.suboptions & OPTION_MPTCP_MPJ_ACK); + if (opt_mp_capable && opt_mp_join) + return -EINVAL; + + if (opt_mp_capable && listener->request_mptcp) { + if (mp_opt.sndr_key == 0) + return -EINVAL; + + subflow_req->local_key = mp_opt.rcvr_key; + err = mptcp_token_new_request(req); + if (err) + return err; + + subflow_req->mp_capable = 1; + subflow_req->ssn_offset = TCP_SKB_CB(skb)->seq - 1; + } else if (opt_mp_join && listener->request_mptcp) { + if (!mptcp_token_join_cookie_init_state(subflow_req, skb)) + return -EINVAL; + + subflow_req->mp_join = 1; + subflow_req->ssn_offset = TCP_SKB_CB(skb)->seq - 1; + } + + return 0; +} +EXPORT_SYMBOL_GPL(mptcp_subflow_init_cookie_req); + +static struct dst_entry *subflow_v4_route_req(const struct sock *sk, + struct sk_buff *skb, + struct flowi *fl, + struct request_sock *req) +{ + struct dst_entry *dst; + int err; + + tcp_rsk(req)->is_mptcp = 1; + subflow_init_req(req, sk); + + dst = tcp_request_sock_ipv4_ops.route_req(sk, skb, fl, req); + if (!dst) + return NULL; + + err = subflow_check_req(req, sk, skb); + if (err == 0) + return dst; + + dst_release(dst); + if (!req->syncookie) + tcp_request_sock_ops.send_reset(sk, skb); + return NULL; +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +static struct dst_entry *subflow_v6_route_req(const struct sock *sk, + struct sk_buff *skb, + struct flowi *fl, + struct request_sock *req) +{ + struct dst_entry *dst; + int err; + + tcp_rsk(req)->is_mptcp = 1; + subflow_init_req(req, sk); + + dst = tcp_request_sock_ipv6_ops.route_req(sk, skb, fl, req); + if (!dst) + return NULL; + + err = subflow_check_req(req, sk, skb); + if (err == 0) + return dst; + + dst_release(dst); + if (!req->syncookie) + tcp6_request_sock_ops.send_reset(sk, skb); + return NULL; +} +#endif + +/* validate received truncated hmac and create hmac for third ACK */ +static bool subflow_thmac_valid(struct mptcp_subflow_context *subflow) +{ + u8 hmac[SHA256_DIGEST_SIZE]; + u64 thmac; + + subflow_generate_hmac(subflow->remote_key, subflow->local_key, + subflow->remote_nonce, subflow->local_nonce, + hmac); + + thmac = get_unaligned_be64(hmac); + pr_debug("subflow=%p, token=%u, thmac=%llu, 
subflow->thmac=%llu\n", + subflow, subflow->token, thmac, subflow->thmac); + + return thmac == subflow->thmac; +} + +void mptcp_subflow_reset(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + struct sock *sk = subflow->conn; + + /* mptcp_mp_fail_no_response() can reach here on an already closed + * socket + */ + if (ssk->sk_state == TCP_CLOSE) + return; + + /* must hold: tcp_done() could drop last reference on parent */ + sock_hold(sk); + + tcp_send_active_reset(ssk, GFP_ATOMIC); + tcp_done(ssk); + if (!test_and_set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &mptcp_sk(sk)->flags)) + mptcp_schedule_work(sk); + + sock_put(sk); +} + +static bool subflow_use_different_dport(struct mptcp_sock *msk, const struct sock *sk) +{ + return inet_sk(sk)->inet_dport != inet_sk((struct sock *)msk)->inet_dport; +} + +void __mptcp_set_connected(struct sock *sk) +{ + if (sk->sk_state == TCP_SYN_SENT) { + inet_sk_state_store(sk, TCP_ESTABLISHED); + sk->sk_state_change(sk); + } +} + +static void mptcp_set_connected(struct sock *sk) +{ + mptcp_data_lock(sk); + if (!sock_owned_by_user(sk)) + __mptcp_set_connected(sk); + else + __set_bit(MPTCP_CONNECTED, &mptcp_sk(sk)->cb_flags); + mptcp_data_unlock(sk); +} + +static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_options_received mp_opt; + struct sock *parent = subflow->conn; + + subflow->icsk_af_ops->sk_rx_dst_set(sk, skb); + + /* be sure no special action on any packet other than syn-ack */ + if (subflow->conn_finished) + return; + + mptcp_propagate_sndbuf(parent, sk); + subflow->rel_write_seq = 1; + subflow->conn_finished = 1; + subflow->ssn_offset = TCP_SKB_CB(skb)->seq; + pr_debug("subflow=%p synack seq=%x", subflow, subflow->ssn_offset); + + mptcp_get_options(skb, &mp_opt); + if (subflow->request_mptcp) { + if (!(mp_opt.suboptions & OPTION_MPTCP_MPC_SYNACK)) { + MPTCP_INC_STATS(sock_net(sk), + MPTCP_MIB_MPCAPABLEACTIVEFALLBACK); + mptcp_do_fallback(sk); + pr_fallback(mptcp_sk(subflow->conn)); + goto fallback; + } + + if (mp_opt.suboptions & OPTION_MPTCP_CSUMREQD) + WRITE_ONCE(mptcp_sk(parent)->csum_enabled, true); + if (mp_opt.deny_join_id0) + WRITE_ONCE(mptcp_sk(parent)->pm.remote_deny_join_id0, true); + subflow->mp_capable = 1; + subflow->can_ack = 1; + subflow->remote_key = mp_opt.sndr_key; + pr_debug("subflow=%p, remote_key=%llu", subflow, + subflow->remote_key); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPCAPABLEACTIVEACK); + mptcp_finish_connect(sk); + mptcp_set_connected(parent); + } else if (subflow->request_join) { + u8 hmac[SHA256_DIGEST_SIZE]; + + if (!(mp_opt.suboptions & OPTION_MPTCP_MPJ_SYNACK)) { + subflow->reset_reason = MPTCP_RST_EMPTCP; + goto do_reset; + } + + subflow->backup = mp_opt.backup; + subflow->thmac = mp_opt.thmac; + subflow->remote_nonce = mp_opt.nonce; + subflow->remote_id = mp_opt.join_id; + pr_debug("subflow=%p, thmac=%llu, remote_nonce=%u backup=%d", + subflow, subflow->thmac, subflow->remote_nonce, + subflow->backup); + + if (!subflow_thmac_valid(subflow)) { + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_JOINACKMAC); + subflow->reset_reason = MPTCP_RST_EMPTCP; + goto do_reset; + } + + if (!mptcp_finish_join(sk)) + goto do_reset; + + subflow_generate_hmac(subflow->local_key, subflow->remote_key, + subflow->local_nonce, + subflow->remote_nonce, + hmac); + memcpy(subflow->hmac, hmac, MPTCPOPT_HMAC_LEN); + + subflow->mp_join = 1; + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_JOINSYNACKRX); + + if 
(subflow_use_different_dport(mptcp_sk(parent), sk)) { + pr_debug("synack inet_dport=%d %d", + ntohs(inet_sk(sk)->inet_dport), + ntohs(inet_sk(parent)->inet_dport)); + MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_JOINPORTSYNACKRX); + } + } else if (mptcp_check_fallback(sk)) { +fallback: + mptcp_rcv_space_init(mptcp_sk(parent), sk); + mptcp_set_connected(parent); + } + return; + +do_reset: + subflow->reset_transient = 0; + mptcp_subflow_reset(sk); +} + +static void subflow_set_local_id(struct mptcp_subflow_context *subflow, int local_id) +{ + subflow->local_id = local_id; + subflow->local_id_valid = 1; +} + +static int subflow_chk_local_id(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + int err; + + if (likely(subflow->local_id_valid)) + return 0; + + err = mptcp_pm_get_local_id(msk, (struct sock_common *)sk); + if (err < 0) + return err; + + subflow_set_local_id(subflow, err); + return 0; +} + +static int subflow_rebuild_header(struct sock *sk) +{ + int err = subflow_chk_local_id(sk); + + if (unlikely(err < 0)) + return err; + + return inet_sk_rebuild_header(sk); +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +static int subflow_v6_rebuild_header(struct sock *sk) +{ + int err = subflow_chk_local_id(sk); + + if (unlikely(err < 0)) + return err; + + return inet6_sk_rebuild_header(sk); +} +#endif + +static struct request_sock_ops mptcp_subflow_v4_request_sock_ops __ro_after_init; +static struct tcp_request_sock_ops subflow_request_sock_ipv4_ops __ro_after_init; + +static int subflow_v4_conn_request(struct sock *sk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + pr_debug("subflow=%p", subflow); + + /* Never answer to SYNs sent to broadcast or multicast */ + if (skb_rtable(skb)->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST)) + goto drop; + + return tcp_conn_request(&mptcp_subflow_v4_request_sock_ops, + &subflow_request_sock_ipv4_ops, + sk, skb); +drop: + tcp_listendrop(sk); + return 0; +} + +static void subflow_v4_req_destructor(struct request_sock *req) +{ + subflow_req_destructor(req); + tcp_request_sock_ops.destructor(req); +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +static struct request_sock_ops mptcp_subflow_v6_request_sock_ops __ro_after_init; +static struct tcp_request_sock_ops subflow_request_sock_ipv6_ops __ro_after_init; +static struct inet_connection_sock_af_ops subflow_v6_specific __ro_after_init; +static struct inet_connection_sock_af_ops subflow_v6m_specific __ro_after_init; +static struct proto tcpv6_prot_override __ro_after_init; + +static int subflow_v6_conn_request(struct sock *sk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + pr_debug("subflow=%p", subflow); + + if (skb->protocol == htons(ETH_P_IP)) + return subflow_v4_conn_request(sk, skb); + + if (!ipv6_unicast_destination(skb)) + goto drop; + + if (ipv6_addr_v4mapped(&ipv6_hdr(skb)->saddr)) { + __IP6_INC_STATS(sock_net(sk), NULL, IPSTATS_MIB_INHDRERRORS); + return 0; + } + + return tcp_conn_request(&mptcp_subflow_v6_request_sock_ops, + &subflow_request_sock_ipv6_ops, sk, skb); + +drop: + tcp_listendrop(sk); + return 0; /* don't send reset */ +} + +static void subflow_v6_req_destructor(struct request_sock *req) +{ + subflow_req_destructor(req); + tcp6_request_sock_ops.destructor(req); +} +#endif + +struct request_sock *mptcp_subflow_reqsk_alloc(const struct request_sock_ops *ops, + struct sock *sk_listener, + bool attach_listener) +{ + if (ops->family == 
AF_INET) + ops = &mptcp_subflow_v4_request_sock_ops; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (ops->family == AF_INET6) + ops = &mptcp_subflow_v6_request_sock_ops; +#endif + + return inet_reqsk_alloc(ops, sk_listener, attach_listener); +} +EXPORT_SYMBOL(mptcp_subflow_reqsk_alloc); + +/* validate hmac received in third ACK */ +static bool subflow_hmac_valid(const struct request_sock *req, + const struct mptcp_options_received *mp_opt) +{ + const struct mptcp_subflow_request_sock *subflow_req; + u8 hmac[SHA256_DIGEST_SIZE]; + struct mptcp_sock *msk; + + subflow_req = mptcp_subflow_rsk(req); + msk = subflow_req->msk; + if (!msk) + return false; + + subflow_generate_hmac(msk->remote_key, msk->local_key, + subflow_req->remote_nonce, + subflow_req->local_nonce, hmac); + + return !crypto_memneq(hmac, mp_opt->hmac, MPTCPOPT_HMAC_LEN); +} + +static void subflow_ulp_fallback(struct sock *sk, + struct mptcp_subflow_context *old_ctx) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + + mptcp_subflow_tcp_fallback(sk, old_ctx); + icsk->icsk_ulp_ops = NULL; + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + tcp_sk(sk)->is_mptcp = 0; + + mptcp_subflow_ops_undo_override(sk); +} + +void mptcp_subflow_drop_ctx(struct sock *ssk) +{ + struct mptcp_subflow_context *ctx = mptcp_subflow_ctx(ssk); + + if (!ctx) + return; + + list_del(&mptcp_subflow_ctx(ssk)->node); + if (inet_csk(ssk)->icsk_ulp_ops) { + subflow_ulp_fallback(ssk, ctx); + if (ctx->conn) + sock_put(ctx->conn); + } + + kfree_rcu(ctx, rcu); +} + +void mptcp_subflow_fully_established(struct mptcp_subflow_context *subflow, + struct mptcp_options_received *mp_opt) +{ + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + + subflow->remote_key = mp_opt->sndr_key; + subflow->fully_established = 1; + subflow->can_ack = 1; + WRITE_ONCE(msk->fully_established, true); +} + +static struct sock *subflow_syn_recv_sock(const struct sock *sk, + struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool *own_req) +{ + struct mptcp_subflow_context *listener = mptcp_subflow_ctx(sk); + struct mptcp_subflow_request_sock *subflow_req; + struct mptcp_options_received mp_opt; + bool fallback, fallback_is_fatal; + struct mptcp_sock *owner; + struct sock *child; + + pr_debug("listener=%p, req=%p, conn=%p", listener, req, listener->conn); + + /* After child creation we must look for MPC even when options + * are not parsed + */ + mp_opt.suboptions = 0; + + /* hopefully temporary handling for MP_JOIN+syncookie */ + subflow_req = mptcp_subflow_rsk(req); + fallback_is_fatal = tcp_rsk(req)->is_mptcp && subflow_req->mp_join; + fallback = !tcp_rsk(req)->is_mptcp; + if (fallback) + goto create_child; + + /* if the sk is MP_CAPABLE, we try to fetch the client key */ + if (subflow_req->mp_capable) { + /* we can receive and accept an in-window, out-of-order pkt, + * which may not carry the MP_CAPABLE opt even on mptcp enabled + * paths: always try to extract the peer key, and fallback + * for packets missing it. + * Even OoO DSS packets coming legitly after dropped or + * reordered MPC will cause fallback, but we don't have other + * options. 
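subflow_hmac_valid() above deliberately compares the peer's HMAC with crypto_memneq() rather than memcmp(): the comparison must not reveal, through its timing, how many leading bytes of a guessed HMAC were correct. A minimal illustration of such a constant-time comparison in generic C (this is not the kernel implementation, only the idea):

#include <stddef.h>
#include <stdint.h>

/* Constant-time inequality test in the spirit of crypto_memneq(): the
 * run time does not depend on where the first differing byte is, so an
 * attacker cannot learn a valid HMAC prefix byte by byte from timing.
 */
static int ct_memneq(const void *a, const void *b, size_t len)
{
        const uint8_t *pa = a, *pb = b;
        uint8_t diff = 0;
        size_t i;

        for (i = 0; i < len; i++)
                diff |= pa[i] ^ pb[i];

        return diff != 0;
}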
+ */ + mptcp_get_options(skb, &mp_opt); + if (!(mp_opt.suboptions & + (OPTION_MPTCP_MPC_SYN | OPTION_MPTCP_MPC_ACK))) + fallback = true; + + } else if (subflow_req->mp_join) { + mptcp_get_options(skb, &mp_opt); + if (!(mp_opt.suboptions & OPTION_MPTCP_MPJ_ACK) || + !subflow_hmac_valid(req, &mp_opt) || + !mptcp_can_accept_new_subflow(subflow_req->msk)) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINACKMAC); + fallback = true; + } + } + +create_child: + child = listener->icsk_af_ops->syn_recv_sock(sk, skb, req, dst, + req_unhash, own_req); + + if (child && *own_req) { + struct mptcp_subflow_context *ctx = mptcp_subflow_ctx(child); + + tcp_rsk(req)->drop_req = false; + + /* we need to fallback on ctx allocation failure and on pre-reqs + * checking above. In the latter scenario we additionally need + * to reset the context to non MPTCP status. + */ + if (!ctx || fallback) { + if (fallback_is_fatal) { + subflow_add_reset_reason(skb, MPTCP_RST_EMPTCP); + goto dispose_child; + } + goto fallback; + } + + /* ssk inherits options of listener sk */ + ctx->setsockopt_seq = listener->setsockopt_seq; + + if (ctx->mp_capable) { + ctx->conn = mptcp_sk_clone_init(listener->conn, &mp_opt, child, req); + if (!ctx->conn) + goto fallback; + + owner = mptcp_sk(ctx->conn); + mptcp_pm_new_connection(owner, child, 1); + + /* with OoO packets we can reach here without ingress + * mpc option + */ + if (mp_opt.suboptions & OPTIONS_MPTCP_MPC) { + mptcp_subflow_fully_established(ctx, &mp_opt); + mptcp_pm_fully_established(owner, child, GFP_ATOMIC); + ctx->pm_notified = 1; + } + } else if (ctx->mp_join) { + owner = subflow_req->msk; + if (!owner) { + subflow_add_reset_reason(skb, MPTCP_RST_EPROHIBIT); + goto dispose_child; + } + + /* move the msk reference ownership to the subflow */ + subflow_req->msk = NULL; + ctx->conn = (struct sock *)owner; + + if (subflow_use_different_sport(owner, sk)) { + pr_debug("ack inet_sport=%d %d", + ntohs(inet_sk(sk)->inet_sport), + ntohs(inet_sk((struct sock *)owner)->inet_sport)); + if (!mptcp_pm_sport_in_anno_list(owner, sk)) { + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_MISMATCHPORTACKRX); + goto dispose_child; + } + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINPORTACKRX); + } + + if (!mptcp_finish_join(child)) + goto dispose_child; + + SUBFLOW_REQ_INC_STATS(req, MPTCP_MIB_JOINACKRX); + tcp_rsk(req)->drop_req = true; + } + } + + /* check for expected invariant - should never trigger, just help + * catching eariler subtle bugs + */ + WARN_ON_ONCE(child && *own_req && tcp_sk(child)->is_mptcp && + (!mptcp_subflow_ctx(child) || + !mptcp_subflow_ctx(child)->conn)); + return child; + +dispose_child: + mptcp_subflow_drop_ctx(child); + tcp_rsk(req)->drop_req = true; + inet_csk_prepare_for_destroy_sock(child); + tcp_done(child); + req->rsk_ops->send_reset(sk, skb); + + /* The last child reference will be released by the caller */ + return child; + +fallback: + mptcp_subflow_drop_ctx(child); + return child; +} + +static struct inet_connection_sock_af_ops subflow_specific __ro_after_init; +static struct proto tcp_prot_override __ro_after_init; + +enum mapping_status { + MAPPING_OK, + MAPPING_INVALID, + MAPPING_EMPTY, + MAPPING_DATA_FIN, + MAPPING_DUMMY, + MAPPING_BAD_CSUM +}; + +static void dbg_bad_map(struct mptcp_subflow_context *subflow, u32 ssn) +{ + pr_debug("Bad mapping: ssn=%d map_seq=%d map_data_len=%d", + ssn, subflow->map_subflow_seq, subflow->map_data_len); +} + +static bool skb_is_fully_mapped(struct sock *ssk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = 
mptcp_subflow_ctx(ssk); + unsigned int skb_consumed; + + skb_consumed = tcp_sk(ssk)->copied_seq - TCP_SKB_CB(skb)->seq; + if (WARN_ON_ONCE(skb_consumed >= skb->len)) + return true; + + return skb->len - skb_consumed <= subflow->map_data_len - + mptcp_subflow_get_map_offset(subflow); +} + +static bool validate_mapping(struct sock *ssk, struct sk_buff *skb) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + u32 ssn = tcp_sk(ssk)->copied_seq - subflow->ssn_offset; + + if (unlikely(before(ssn, subflow->map_subflow_seq))) { + /* Mapping covers data later in the subflow stream, + * currently unsupported. + */ + dbg_bad_map(subflow, ssn); + return false; + } + if (unlikely(!before(ssn, subflow->map_subflow_seq + + subflow->map_data_len))) { + /* Mapping does covers past subflow data, invalid */ + dbg_bad_map(subflow, ssn); + return false; + } + return true; +} + +static enum mapping_status validate_data_csum(struct sock *ssk, struct sk_buff *skb, + bool csum_reqd) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + u32 offset, seq, delta; + __sum16 csum; + int len; + + if (!csum_reqd) + return MAPPING_OK; + + /* mapping already validated on previous traversal */ + if (subflow->map_csum_len == subflow->map_data_len) + return MAPPING_OK; + + /* traverse the receive queue, ensuring it contains a full + * DSS mapping and accumulating the related csum. + * Preserve the accoumlate csum across multiple calls, to compute + * the csum only once + */ + delta = subflow->map_data_len - subflow->map_csum_len; + for (;;) { + seq = tcp_sk(ssk)->copied_seq + subflow->map_csum_len; + offset = seq - TCP_SKB_CB(skb)->seq; + + /* if the current skb has not been accounted yet, csum its contents + * up to the amount covered by the current DSS + */ + if (offset < skb->len) { + __wsum csum; + + len = min(skb->len - offset, delta); + csum = skb_checksum(skb, offset, len, 0); + subflow->map_data_csum = csum_block_add(subflow->map_data_csum, csum, + subflow->map_csum_len); + + delta -= len; + subflow->map_csum_len += len; + } + if (delta == 0) + break; + + if (skb_queue_is_last(&ssk->sk_receive_queue, skb)) { + /* if this subflow is closed, the partial mapping + * will be never completed; flush the pending skbs, so + * that subflow_sched_work_if_closed() can kick in + */ + if (unlikely(ssk->sk_state == TCP_CLOSE)) + while ((skb = skb_peek(&ssk->sk_receive_queue))) + sk_eat_skb(ssk, skb); + + /* not enough data to validate the csum */ + return MAPPING_EMPTY; + } + + /* the DSS mapping for next skbs will be validated later, + * when a get_mapping_status call will process such skb + */ + skb = skb->next; + } + + /* note that 'map_data_len' accounts only for the carried data, does + * not include the eventual seq increment due to the data fin, + * while the pseudo header requires the original DSS data len, + * including that + */ + csum = __mptcp_make_csum(subflow->map_seq, + subflow->map_subflow_seq, + subflow->map_data_len + subflow->map_data_fin, + subflow->map_data_csum); + if (unlikely(csum)) { + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_DATACSUMERR); + return MAPPING_BAD_CSUM; + } + + subflow->valid_csum_seen = 1; + return MAPPING_OK; +} + +static enum mapping_status get_mapping_status(struct sock *ssk, + struct mptcp_sock *msk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + bool csum_reqd = READ_ONCE(msk->csum_enabled); + struct mptcp_ext *mpext; + struct sk_buff *skb; + u16 data_len; + u64 map_seq; + + skb = skb_peek(&ssk->sk_receive_queue); + if (!skb) + 
return MAPPING_EMPTY; + + if (mptcp_check_fallback(ssk)) + return MAPPING_DUMMY; + + mpext = mptcp_get_ext(skb); + if (!mpext || !mpext->use_map) { + if (!subflow->map_valid && !skb->len) { + /* the TCP stack deliver 0 len FIN pkt to the receive + * queue, that is the only 0len pkts ever expected here, + * and we can admit no mapping only for 0 len pkts + */ + if (!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)) + WARN_ONCE(1, "0len seq %d:%d flags %x", + TCP_SKB_CB(skb)->seq, + TCP_SKB_CB(skb)->end_seq, + TCP_SKB_CB(skb)->tcp_flags); + sk_eat_skb(ssk, skb); + return MAPPING_EMPTY; + } + + if (!subflow->map_valid) + return MAPPING_INVALID; + + goto validate_seq; + } + + trace_get_mapping_status(mpext); + + data_len = mpext->data_len; + if (data_len == 0) { + pr_debug("infinite mapping received"); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_INFINITEMAPRX); + subflow->map_data_len = 0; + return MAPPING_INVALID; + } + + if (mpext->data_fin == 1) { + if (data_len == 1) { + bool updated = mptcp_update_rcv_data_fin(msk, mpext->data_seq, + mpext->dsn64); + pr_debug("DATA_FIN with no payload seq=%llu", mpext->data_seq); + if (subflow->map_valid) { + /* A DATA_FIN might arrive in a DSS + * option before the previous mapping + * has been fully consumed. Continue + * handling the existing mapping. + */ + skb_ext_del(skb, SKB_EXT_MPTCP); + return MAPPING_OK; + } else { + if (updated) + mptcp_schedule_work((struct sock *)msk); + + return MAPPING_DATA_FIN; + } + } else { + u64 data_fin_seq = mpext->data_seq + data_len - 1; + + /* If mpext->data_seq is a 32-bit value, data_fin_seq + * must also be limited to 32 bits. + */ + if (!mpext->dsn64) + data_fin_seq &= GENMASK_ULL(31, 0); + + mptcp_update_rcv_data_fin(msk, data_fin_seq, mpext->dsn64); + pr_debug("DATA_FIN with mapping seq=%llu dsn64=%d", + data_fin_seq, mpext->dsn64); + } + + /* Adjust for DATA_FIN using 1 byte of sequence space */ + data_len--; + } + + map_seq = mptcp_expand_seq(READ_ONCE(msk->ack_seq), mpext->data_seq, mpext->dsn64); + WRITE_ONCE(mptcp_sk(subflow->conn)->use_64bit_ack, !!mpext->dsn64); + + if (subflow->map_valid) { + /* Allow replacing only with an identical map */ + if (subflow->map_seq == map_seq && + subflow->map_subflow_seq == mpext->subflow_seq && + subflow->map_data_len == data_len && + subflow->map_csum_reqd == mpext->csum_reqd) { + skb_ext_del(skb, SKB_EXT_MPTCP); + goto validate_csum; + } + + /* If this skb data are fully covered by the current mapping, + * the new map would need caching, which is not supported + */ + if (skb_is_fully_mapped(ssk, skb)) { + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_DSSNOMATCH); + return MAPPING_INVALID; + } + + /* will validate the next map after consuming the current one */ + goto validate_csum; + } + + subflow->map_seq = map_seq; + subflow->map_subflow_seq = mpext->subflow_seq; + subflow->map_data_len = data_len; + subflow->map_valid = 1; + subflow->map_data_fin = mpext->data_fin; + subflow->mpc_map = mpext->mpc_map; + subflow->map_csum_reqd = mpext->csum_reqd; + subflow->map_csum_len = 0; + subflow->map_data_csum = csum_unfold(mpext->csum); + + /* Cfr RFC 8684 Section 3.3.0 */ + if (unlikely(subflow->map_csum_reqd != csum_reqd)) + return MAPPING_INVALID; + + pr_debug("new map seq=%llu subflow_seq=%u data_len=%u csum=%d:%u", + subflow->map_seq, subflow->map_subflow_seq, + subflow->map_data_len, subflow->map_csum_reqd, + subflow->map_data_csum); + +validate_seq: + /* we revalidate valid mapping on new skb, because we must ensure + * the current skb is completely covered by the available 
mapping + */ + if (!validate_mapping(ssk, skb)) { + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_DSSTCPMISMATCH); + return MAPPING_INVALID; + } + + skb_ext_del(skb, SKB_EXT_MPTCP); + +validate_csum: + return validate_data_csum(ssk, skb, csum_reqd); +} + +static void mptcp_subflow_discard_data(struct sock *ssk, struct sk_buff *skb, + u64 limit) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + bool fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; + u32 incr; + + incr = limit >= skb->len ? skb->len + fin : limit; + + pr_debug("discarding=%d len=%d seq=%d", incr, skb->len, + subflow->map_subflow_seq); + MPTCP_INC_STATS(sock_net(ssk), MPTCP_MIB_DUPDATA); + tcp_sk(ssk)->copied_seq += incr; + if (!before(tcp_sk(ssk)->copied_seq, TCP_SKB_CB(skb)->end_seq)) + sk_eat_skb(ssk, skb); + if (mptcp_subflow_get_map_offset(subflow) >= subflow->map_data_len) + subflow->map_valid = 0; +} + +/* sched mptcp worker to remove the subflow if no more data is pending */ +static void subflow_sched_work_if_closed(struct mptcp_sock *msk, struct sock *ssk) +{ + if (likely(ssk->sk_state != TCP_CLOSE)) + return; + + if (skb_queue_empty(&ssk->sk_receive_queue) && + !test_and_set_bit(MPTCP_WORK_CLOSE_SUBFLOW, &msk->flags)) + mptcp_schedule_work((struct sock *)msk); +} + +static bool subflow_can_fallback(struct mptcp_subflow_context *subflow) +{ + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + + if (subflow->mp_join) + return false; + else if (READ_ONCE(msk->csum_enabled)) + return !subflow->valid_csum_seen; + else + return !subflow->fully_established; +} + +static void mptcp_subflow_fail(struct mptcp_sock *msk, struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + unsigned long fail_tout; + + /* greceful failure can happen only on the MPC subflow */ + if (WARN_ON_ONCE(ssk != READ_ONCE(msk->first))) + return; + + /* since the close timeout take precedence on the fail one, + * no need to start the latter when the first is already set + */ + if (sock_flag((struct sock *)msk, SOCK_DEAD)) + return; + + /* we don't need extreme accuracy here, use a zero fail_tout as special + * value meaning no fail timeout at all; + */ + fail_tout = jiffies + TCP_RTO_MAX; + if (!fail_tout) + fail_tout = 1; + WRITE_ONCE(subflow->fail_tout, fail_tout); + tcp_send_ack(ssk); + + mptcp_reset_tout_timer(msk, subflow->fail_tout); +} + +static bool subflow_check_data_avail(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + enum mapping_status status; + struct mptcp_sock *msk; + struct sk_buff *skb; + + if (!skb_peek(&ssk->sk_receive_queue)) + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_NODATA); + if (subflow->data_avail) + return true; + + msk = mptcp_sk(subflow->conn); + for (;;) { + u64 ack_seq; + u64 old_ack; + + status = get_mapping_status(ssk, msk); + trace_subflow_check_data_avail(status, skb_peek(&ssk->sk_receive_queue)); + if (unlikely(status == MAPPING_INVALID || status == MAPPING_DUMMY || + status == MAPPING_BAD_CSUM)) + goto fallback; + + if (status != MAPPING_OK) + goto no_data; + + skb = skb_peek(&ssk->sk_receive_queue); + if (WARN_ON_ONCE(!skb)) + goto no_data; + + /* if msk lacks the remote key, this subflow must provide an + * MP_CAPABLE-based mapping + */ + if (unlikely(!READ_ONCE(msk->can_ack))) { + if (!subflow->mpc_map) + goto fallback; + WRITE_ONCE(msk->remote_key, subflow->remote_key); + WRITE_ONCE(msk->ack_seq, subflow->map_seq); + WRITE_ONCE(msk->can_ack, true); + } + + old_ack = READ_ONCE(msk->ack_seq); + ack_seq = 
mptcp_subflow_get_mapped_dsn(subflow); + pr_debug("msk ack_seq=%llx subflow ack_seq=%llx", old_ack, + ack_seq); + if (unlikely(before64(ack_seq, old_ack))) { + mptcp_subflow_discard_data(ssk, skb, old_ack - ack_seq); + continue; + } + + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_DATA_AVAIL); + break; + } + return true; + +no_data: + subflow_sched_work_if_closed(msk, ssk); + return false; + +fallback: + if (!__mptcp_check_fallback(msk)) { + /* RFC 8684 section 3.7. */ + if (status == MAPPING_BAD_CSUM && + (subflow->mp_join || subflow->valid_csum_seen)) { + subflow->send_mp_fail = 1; + + if (!READ_ONCE(msk->allow_infinite_fallback)) { + subflow->reset_transient = 0; + subflow->reset_reason = MPTCP_RST_EMIDDLEBOX; + goto reset; + } + mptcp_subflow_fail(msk, ssk); + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_DATA_AVAIL); + return true; + } + + if (!subflow_can_fallback(subflow) && subflow->map_data_len) { + /* fatal protocol error, close the socket. + * subflow_error_report() will introduce the appropriate barriers + */ + subflow->reset_transient = 0; + subflow->reset_reason = MPTCP_RST_EMPTCP; + +reset: + WRITE_ONCE(ssk->sk_err, EBADMSG); + tcp_set_state(ssk, TCP_CLOSE); + while ((skb = skb_peek(&ssk->sk_receive_queue))) + sk_eat_skb(ssk, skb); + tcp_send_active_reset(ssk, GFP_ATOMIC); + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_NODATA); + return false; + } + + mptcp_do_fallback(ssk); + } + + skb = skb_peek(&ssk->sk_receive_queue); + subflow->map_valid = 1; + subflow->map_seq = READ_ONCE(msk->ack_seq); + subflow->map_data_len = skb->len; + subflow->map_subflow_seq = tcp_sk(ssk)->copied_seq - subflow->ssn_offset; + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_DATA_AVAIL); + return true; +} + +bool mptcp_subflow_data_available(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + + /* check if current mapping is still valid */ + if (subflow->map_valid && + mptcp_subflow_get_map_offset(subflow) >= subflow->map_data_len) { + subflow->map_valid = 0; + WRITE_ONCE(subflow->data_avail, MPTCP_SUBFLOW_NODATA); + + pr_debug("Done with mapping: seq=%u data_len=%u", + subflow->map_subflow_seq, + subflow->map_data_len); + } + + return subflow_check_data_avail(sk); +} + +/* If ssk has an mptcp parent socket, use the mptcp rcvbuf occupancy, + * not the ssk one. + * + * In mptcp, rwin is about the mptcp-level connection data. + * + * Data that is still on the ssk rx queue can thus be ignored, + * as far as mptcp peer is concerned that data is still inflight. + * DSS ACK is updated when skb is moved to the mptcp rx queue. 
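The MAPPING_BAD_CSUM handling above (RFC 8684 section 3.7) reacts to a failed DSS checksum. The checksum itself, accumulated in validate_data_csum() and finished by __mptcp_make_csum() earlier in this file, is the plain Internet checksum over a pseudo-header (data sequence number, subflow sequence number, data-level length, zero checksum field) followed by the mapped payload. A rough standalone sketch of the sender-side computation, with illustrative names rather than the kernel helpers:

#include <stdint.h>
#include <stddef.h>

static uint32_t csum_accumulate(uint32_t sum, const uint8_t *p, size_t len)
{
        size_t i;

        for (i = 0; i + 1 < len; i += 2)
                sum += (p[i] << 8) | p[i + 1];
        if (len & 1)
                sum += p[len - 1] << 8;     /* pad the odd tail byte */
        return sum;
}

static uint16_t csum_finish(uint32_t sum)
{
        while (sum >> 16)
                sum = (sum & 0xffff) + (sum >> 16);
        return (uint16_t)~sum;
}

/* value carried in the DSS checksum field for one mapping */
uint16_t mptcp_dss_csum(uint64_t data_seq, uint32_t subflow_seq,
                        uint16_t data_len,
                        const uint8_t *payload, size_t payload_len)
{
        uint8_t hdr[16] = { 0 };
        uint32_t sum;
        int i;

        for (i = 0; i < 8; i++)
                hdr[i] = data_seq >> (56 - 8 * i);
        for (i = 0; i < 4; i++)
                hdr[8 + i] = subflow_seq >> (24 - 8 * i);
        hdr[12] = data_len >> 8;
        hdr[13] = data_len & 0xff;
        /* hdr[14..15]: checksum field counts as zero while computing */

        sum = csum_accumulate(0, hdr, sizeof(hdr));
        sum = csum_accumulate(sum, payload, payload_len);
        return csum_finish(sum);
}

Note that data_len is the data-level length from the mapping (it includes the DATA_FIN byte when present), which is why the kernel passes map_data_len + map_data_fin to __mptcp_make_csum().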
+ */ +void mptcp_space(const struct sock *ssk, int *space, int *full_space) +{ + const struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + const struct sock *sk = subflow->conn; + + *space = __mptcp_space(sk); + *full_space = tcp_full_space(sk); +} + +static void subflow_error_report(struct sock *ssk) +{ + struct sock *sk = mptcp_subflow_ctx(ssk)->conn; + + /* bail early if this is a no-op, so that we avoid introducing a + * problematic lockdep dependency between TCP accept queue lock + * and msk socket spinlock + */ + if (!sk->sk_socket) + return; + + mptcp_data_lock(sk); + if (!sock_owned_by_user(sk)) + __mptcp_error_report(sk); + else + __set_bit(MPTCP_ERROR_REPORT, &mptcp_sk(sk)->cb_flags); + mptcp_data_unlock(sk); +} + +static void subflow_data_ready(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + u16 state = 1 << inet_sk_state_load(sk); + struct sock *parent = subflow->conn; + struct mptcp_sock *msk; + + msk = mptcp_sk(parent); + if (state & TCPF_LISTEN) { + /* MPJ subflow are removed from accept queue before reaching here, + * avoid stray wakeups + */ + if (reqsk_queue_empty(&inet_csk(sk)->icsk_accept_queue)) + return; + + parent->sk_data_ready(parent); + return; + } + + WARN_ON_ONCE(!__mptcp_check_fallback(msk) && !subflow->mp_capable && + !subflow->mp_join && !(state & TCPF_CLOSE)); + + if (mptcp_subflow_data_available(sk)) + mptcp_data_ready(parent, sk); + else if (unlikely(sk->sk_err)) + subflow_error_report(sk); +} + +static void subflow_write_space(struct sock *ssk) +{ + struct sock *sk = mptcp_subflow_ctx(ssk)->conn; + + mptcp_propagate_sndbuf(sk, ssk); + mptcp_write_space(sk); +} + +static const struct inet_connection_sock_af_ops * +subflow_default_af_ops(struct sock *sk) +{ +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (sk->sk_family == AF_INET6) + return &subflow_v6_specific; +#endif + return &subflow_specific; +} + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) +void mptcpv6_handle_mapped(struct sock *sk, bool mapped) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct inet_connection_sock *icsk = inet_csk(sk); + const struct inet_connection_sock_af_ops *target; + + target = mapped ? 
&subflow_v6m_specific : subflow_default_af_ops(sk); + + pr_debug("subflow=%p family=%d ops=%p target=%p mapped=%d", + subflow, sk->sk_family, icsk->icsk_af_ops, target, mapped); + + if (likely(icsk->icsk_af_ops == target)) + return; + + subflow->icsk_af_ops = icsk->icsk_af_ops; + icsk->icsk_af_ops = target; +} +#endif + +void mptcp_info2sockaddr(const struct mptcp_addr_info *info, + struct sockaddr_storage *addr, + unsigned short family) +{ + memset(addr, 0, sizeof(*addr)); + addr->ss_family = family; + if (addr->ss_family == AF_INET) { + struct sockaddr_in *in_addr = (struct sockaddr_in *)addr; + + if (info->family == AF_INET) + in_addr->sin_addr = info->addr; +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (ipv6_addr_v4mapped(&info->addr6)) + in_addr->sin_addr.s_addr = info->addr6.s6_addr32[3]; +#endif + in_addr->sin_port = info->port; + } +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + else if (addr->ss_family == AF_INET6) { + struct sockaddr_in6 *in6_addr = (struct sockaddr_in6 *)addr; + + if (info->family == AF_INET) + ipv6_addr_set_v4mapped(info->addr.s_addr, + &in6_addr->sin6_addr); + else + in6_addr->sin6_addr = info->addr6; + in6_addr->sin6_port = info->port; + } +#endif +} + +int __mptcp_subflow_connect(struct sock *sk, const struct mptcp_addr_info *loc, + const struct mptcp_addr_info *remote) +{ + struct mptcp_sock *msk = mptcp_sk(sk); + struct mptcp_subflow_context *subflow; + struct sockaddr_storage addr; + int remote_id = remote->id; + int local_id = loc->id; + int err = -ENOTCONN; + struct socket *sf; + struct sock *ssk; + u32 remote_token; + int addrlen; + int ifindex; + u8 flags; + + if (!mptcp_is_fully_established(sk)) + goto err_out; + + err = mptcp_subflow_create_socket(sk, loc->family, &sf); + if (err) + goto err_out; + + ssk = sf->sk; + subflow = mptcp_subflow_ctx(ssk); + do { + get_random_bytes(&subflow->local_nonce, sizeof(u32)); + } while (!subflow->local_nonce); + + if (local_id) + subflow_set_local_id(subflow, local_id); + + mptcp_pm_get_flags_and_ifindex_by_id(msk, local_id, + &flags, &ifindex); + subflow->remote_key = msk->remote_key; + subflow->local_key = msk->local_key; + subflow->token = msk->token; + mptcp_info2sockaddr(loc, &addr, ssk->sk_family); + + addrlen = sizeof(struct sockaddr_in); +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (addr.ss_family == AF_INET6) + addrlen = sizeof(struct sockaddr_in6); +#endif + mptcp_sockopt_sync(msk, ssk); + + ssk->sk_bound_dev_if = ifindex; + err = kernel_bind(sf, (struct sockaddr *)&addr, addrlen); + if (err) + goto failed; + + mptcp_crypto_key_sha(subflow->remote_key, &remote_token, NULL); + pr_debug("msk=%p remote_token=%u local_id=%d remote_id=%d", msk, + remote_token, local_id, remote_id); + subflow->remote_token = remote_token; + subflow->remote_id = remote_id; + subflow->request_join = 1; + subflow->request_bkup = !!(flags & MPTCP_PM_ADDR_FLAG_BACKUP); + mptcp_info2sockaddr(remote, &addr, ssk->sk_family); + + sock_hold(ssk); + list_add_tail(&subflow->node, &msk->conn_list); + err = kernel_connect(sf, (struct sockaddr *)&addr, addrlen, O_NONBLOCK); + if (err && err != -EINPROGRESS) + goto failed_unlink; + + /* discard the subflow socket */ + mptcp_sock_graft(ssk, sk->sk_socket); + iput(SOCK_INODE(sf)); + WRITE_ONCE(msk->allow_infinite_fallback, false); + mptcp_stop_tout_timer(sk); + return 0; + +failed_unlink: + list_del(&subflow->node); + sock_put(mptcp_subflow_tcp_sock(subflow)); + +failed: + subflow->disposable = 1; + sock_release(sf); + +err_out: + /* we account subflows before the creation, and this failures will not + * be 
caught by sk_state_change() + */ + mptcp_pm_close_subflow(msk); + return err; +} + +static void mptcp_attach_cgroup(struct sock *parent, struct sock *child) +{ +#ifdef CONFIG_SOCK_CGROUP_DATA + struct sock_cgroup_data *parent_skcd = &parent->sk_cgrp_data, + *child_skcd = &child->sk_cgrp_data; + + /* only the additional subflows created by kworkers have to be modified */ + if (cgroup_id(sock_cgroup_ptr(parent_skcd)) != + cgroup_id(sock_cgroup_ptr(child_skcd))) { +#ifdef CONFIG_MEMCG + struct mem_cgroup *memcg = parent->sk_memcg; + + mem_cgroup_sk_free(child); + if (memcg && css_tryget(&memcg->css)) + child->sk_memcg = memcg; +#endif /* CONFIG_MEMCG */ + + cgroup_sk_free(child_skcd); + *child_skcd = *parent_skcd; + cgroup_sk_clone(child_skcd); + } +#endif /* CONFIG_SOCK_CGROUP_DATA */ +} + +static void mptcp_subflow_ops_override(struct sock *ssk) +{ +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (ssk->sk_prot == &tcpv6_prot) + ssk->sk_prot = &tcpv6_prot_override; + else +#endif + ssk->sk_prot = &tcp_prot_override; +} + +static void mptcp_subflow_ops_undo_override(struct sock *ssk) +{ +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + if (ssk->sk_prot == &tcpv6_prot_override) + ssk->sk_prot = &tcpv6_prot; + else +#endif + ssk->sk_prot = &tcp_prot; +} + +int mptcp_subflow_create_socket(struct sock *sk, unsigned short family, + struct socket **new_sock) +{ + struct mptcp_subflow_context *subflow; + struct net *net = sock_net(sk); + struct socket *sf; + int err; + + /* un-accepted server sockets can reach here - on bad configuration + * bail early to avoid greater trouble later + */ + if (unlikely(!sk->sk_socket)) + return -EINVAL; + + err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf); + if (err) + return err; + + lock_sock_nested(sf->sk, SINGLE_DEPTH_NESTING); + + /* the newly created socket has to be in the same cgroup as its parent */ + mptcp_attach_cgroup(sk, sf->sk); + + /* kernel sockets do not by default acquire net ref, but TCP timer + * needs it. + */ + sf->sk->sk_net_refcnt = 1; + get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL); + sock_inuse_add(net, 1); + err = tcp_set_ulp(sf->sk, "mptcp"); + release_sock(sf->sk); + + if (err) { + sock_release(sf); + return err; + } + + /* the newly created socket really belongs to the owning MPTCP master + * socket, even if for additional subflows the allocation is performed + * by a kernel workqueue. Adjust inode references, so that the + * procfs/diag interfaces really show this one belonging to the correct + * user. 
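mptcp_subflow_create_socket() above attaches the "mptcp" ULP with tcp_set_ulp() on a kernel-created socket only; subflow_ulp_init(), further below, rejects any socket that lacks sk_kern_sock. The practical consequence is that the same attachment attempted from userspace is expected to fail, typically with EOPNOTSUPP. A small sketch of that negative test (TCP_ULP value taken from the uapi headers; names otherwise illustrative):

#include <stdio.h>
#include <errno.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/tcp.h>

#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* tcp_fd: an ordinary userspace TCP socket */
static int try_attach_mptcp_ulp(int tcp_fd)
{
        if (setsockopt(tcp_fd, IPPROTO_TCP, TCP_ULP, "mptcp",
                       sizeof("mptcp")) < 0) {
                /* expected path: subflow_ulp_init() refuses non-kernel sockets */
                fprintf(stderr, "TCP_ULP \"mptcp\": %s\n", strerror(errno));
                return -1;
        }
        return 0;
}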
+ */ + SOCK_INODE(sf)->i_ino = SOCK_INODE(sk->sk_socket)->i_ino; + SOCK_INODE(sf)->i_uid = SOCK_INODE(sk->sk_socket)->i_uid; + SOCK_INODE(sf)->i_gid = SOCK_INODE(sk->sk_socket)->i_gid; + + subflow = mptcp_subflow_ctx(sf->sk); + pr_debug("subflow=%p", subflow); + + *new_sock = sf; + sock_hold(sk); + subflow->conn = sk; + mptcp_subflow_ops_override(sf->sk); + + return 0; +} + +static struct mptcp_subflow_context *subflow_create_ctx(struct sock *sk, + gfp_t priority) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct mptcp_subflow_context *ctx; + + ctx = kzalloc(sizeof(*ctx), priority); + if (!ctx) + return NULL; + + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + INIT_LIST_HEAD(&ctx->node); + INIT_LIST_HEAD(&ctx->delegated_node); + + pr_debug("subflow=%p", ctx); + + ctx->tcp_sock = sk; + + return ctx; +} + +static void __subflow_state_change(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_all(&wq->wait); + rcu_read_unlock(); +} + +static bool subflow_is_done(const struct sock *sk) +{ + return sk->sk_shutdown & RCV_SHUTDOWN || sk->sk_state == TCP_CLOSE; +} + +static void subflow_state_change(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct sock *parent = subflow->conn; + struct mptcp_sock *msk; + + __subflow_state_change(sk); + + msk = mptcp_sk(parent); + if (subflow_simultaneous_connect(sk)) { + mptcp_propagate_sndbuf(parent, sk); + mptcp_do_fallback(sk); + mptcp_rcv_space_init(msk, sk); + pr_fallback(msk); + subflow->conn_finished = 1; + mptcp_set_connected(parent); + } + + /* as recvmsg() does not acquire the subflow socket for ssk selection + * a fin packet carrying a DSS can be unnoticed if we don't trigger + * the data available machinery here. + */ + if (mptcp_subflow_data_available(sk)) + mptcp_data_ready(parent, sk); + else if (unlikely(sk->sk_err)) + subflow_error_report(sk); + + subflow_sched_work_if_closed(mptcp_sk(parent), sk); + + /* when the fallback subflow closes the rx side, trigger a 'dummy' + * ingress data fin, so that the msk state will follow along + */ + if (__mptcp_check_fallback(msk) && subflow_is_done(sk) && msk->first == sk && + mptcp_update_rcv_data_fin(msk, READ_ONCE(msk->ack_seq), true)) + mptcp_schedule_work(parent); +} + +void mptcp_subflow_queue_clean(struct sock *listener_sk, struct sock *listener_ssk) +{ + struct request_sock_queue *queue = &inet_csk(listener_ssk)->icsk_accept_queue; + struct request_sock *req, *head, *tail; + struct mptcp_subflow_context *subflow; + struct sock *sk, *ssk; + + /* Due to lock dependencies no relevant lock can be acquired under rskq_lock. + * Splice the req list, so that accept() can not reach the pending ssk after + * the listener socket is released below. 
+ */ + spin_lock_bh(&queue->rskq_lock); + head = queue->rskq_accept_head; + tail = queue->rskq_accept_tail; + queue->rskq_accept_head = NULL; + queue->rskq_accept_tail = NULL; + spin_unlock_bh(&queue->rskq_lock); + + if (!head) + return; + + /* can't acquire the msk socket lock under the subflow one, + * or will cause ABBA deadlock + */ + release_sock(listener_ssk); + + for (req = head; req; req = req->dl_next) { + ssk = req->sk; + if (!sk_is_mptcp(ssk)) + continue; + + subflow = mptcp_subflow_ctx(ssk); + if (!subflow || !subflow->conn) + continue; + + sk = subflow->conn; + sock_hold(sk); + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + __mptcp_unaccepted_force_close(sk); + release_sock(sk); + + /* lockdep will report a false positive ABBA deadlock + * between cancel_work_sync and the listener socket. + * The involved locks belong to different sockets WRT + * the existing AB chain. + * Using a per socket key is problematic as key + * deregistration requires process context and must be + * performed at socket disposal time, in atomic + * context. + * Just tell lockdep to consider the listener socket + * released here. + */ + mutex_release(&listener_sk->sk_lock.dep_map, _RET_IP_); + mptcp_cancel_work(sk); + mutex_acquire(&listener_sk->sk_lock.dep_map, 0, 0, _RET_IP_); + + sock_put(sk); + } + + /* we are still under the listener msk socket lock */ + lock_sock_nested(listener_ssk, SINGLE_DEPTH_NESTING); + + /* restore the listener queue, to let the TCP code clean it up */ + spin_lock_bh(&queue->rskq_lock); + WARN_ON_ONCE(queue->rskq_accept_head); + queue->rskq_accept_head = head; + queue->rskq_accept_tail = tail; + spin_unlock_bh(&queue->rskq_lock); +} + +static int subflow_ulp_init(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct mptcp_subflow_context *ctx; + struct tcp_sock *tp = tcp_sk(sk); + int err = 0; + + /* disallow attaching ULP to a socket unless it has been + * created with sock_create_kern() + */ + if (!sk->sk_kern_sock) { + err = -EOPNOTSUPP; + goto out; + } + + ctx = subflow_create_ctx(sk, GFP_KERNEL); + if (!ctx) { + err = -ENOMEM; + goto out; + } + + pr_debug("subflow=%p, family=%d", ctx, sk->sk_family); + + tp->is_mptcp = 1; + ctx->icsk_af_ops = icsk->icsk_af_ops; + icsk->icsk_af_ops = subflow_default_af_ops(sk); + ctx->tcp_state_change = sk->sk_state_change; + ctx->tcp_error_report = sk->sk_error_report; + + WARN_ON_ONCE(sk->sk_data_ready != sock_def_readable); + WARN_ON_ONCE(sk->sk_write_space != sk_stream_write_space); + + sk->sk_data_ready = subflow_data_ready; + sk->sk_write_space = subflow_write_space; + sk->sk_state_change = subflow_state_change; + sk->sk_error_report = subflow_error_report; +out: + return err; +} + +static void subflow_ulp_release(struct sock *ssk) +{ + struct mptcp_subflow_context *ctx = mptcp_subflow_ctx(ssk); + bool release = true; + struct sock *sk; + + if (!ctx) + return; + + sk = ctx->conn; + if (sk) { + /* if the msk has been orphaned, keep the ctx + * alive, will be freed by __mptcp_close_ssk(), + * when the subflow is still unaccepted + */ + release = ctx->disposable || list_empty(&ctx->node); + + /* inet_child_forget() does not call sk_state_change(), + * explicitly trigger the socket close machinery + */ + if (!release && !test_and_set_bit(MPTCP_WORK_CLOSE_SUBFLOW, + &mptcp_sk(sk)->flags)) + mptcp_schedule_work(sk); + sock_put(sk); + } + + mptcp_subflow_ops_undo_override(ssk); + if (release) + kfree_rcu(ctx, rcu); +} + +static void subflow_ulp_clone(const struct request_sock *req, + struct sock *newsk, + 
const gfp_t priority) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct mptcp_subflow_context *old_ctx = mptcp_subflow_ctx(newsk); + struct mptcp_subflow_context *new_ctx; + + if (!tcp_rsk(req)->is_mptcp || + (!subflow_req->mp_capable && !subflow_req->mp_join)) { + subflow_ulp_fallback(newsk, old_ctx); + return; + } + + new_ctx = subflow_create_ctx(newsk, priority); + if (!new_ctx) { + subflow_ulp_fallback(newsk, old_ctx); + return; + } + + new_ctx->conn_finished = 1; + new_ctx->icsk_af_ops = old_ctx->icsk_af_ops; + new_ctx->tcp_state_change = old_ctx->tcp_state_change; + new_ctx->tcp_error_report = old_ctx->tcp_error_report; + new_ctx->rel_write_seq = 1; + new_ctx->tcp_sock = newsk; + + if (subflow_req->mp_capable) { + /* see comments in subflow_syn_recv_sock(), MPTCP connection + * is fully established only after we receive the remote key + */ + new_ctx->mp_capable = 1; + new_ctx->local_key = subflow_req->local_key; + new_ctx->token = subflow_req->token; + new_ctx->ssn_offset = subflow_req->ssn_offset; + new_ctx->idsn = subflow_req->idsn; + + /* this is the first subflow, id is always 0 */ + new_ctx->local_id_valid = 1; + } else if (subflow_req->mp_join) { + new_ctx->ssn_offset = subflow_req->ssn_offset; + new_ctx->mp_join = 1; + new_ctx->fully_established = 1; + new_ctx->backup = subflow_req->backup; + new_ctx->remote_id = subflow_req->remote_id; + new_ctx->token = subflow_req->token; + new_ctx->thmac = subflow_req->thmac; + + /* the subflow req id is valid, fetched via subflow_check_req() + * and subflow_token_join_request() + */ + subflow_set_local_id(new_ctx, subflow_req->local_id); + } +} + +static void tcp_release_cb_override(struct sock *ssk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk); + long status; + + /* process and clear all the pending actions, but leave the subflow into + * the napi queue. To respect locking, only the same CPU that originated + * the action can touch the list. mptcp_napi_poll will take care of it. + */ + status = set_mask_bits(&subflow->delegated_status, MPTCP_DELEGATE_ACTIONS_MASK, 0); + if (status) + mptcp_subflow_process_delegated(ssk, status); + + tcp_release_cb(ssk); +} + +static int tcp_abort_override(struct sock *ssk, int err) +{ + /* closing a listener subflow requires a great deal of care. 
+ * keep it simple and just prevent such operation + */ + if (inet_sk_state_load(ssk) == TCP_LISTEN) + return -EINVAL; + + return tcp_abort(ssk, err); +} + +static struct tcp_ulp_ops subflow_ulp_ops __read_mostly = { + .name = "mptcp", + .owner = THIS_MODULE, + .init = subflow_ulp_init, + .release = subflow_ulp_release, + .clone = subflow_ulp_clone, +}; + +static int subflow_ops_init(struct request_sock_ops *subflow_ops) +{ + subflow_ops->obj_size = sizeof(struct mptcp_subflow_request_sock); + + subflow_ops->slab = kmem_cache_create(subflow_ops->slab_name, + subflow_ops->obj_size, 0, + SLAB_ACCOUNT | + SLAB_TYPESAFE_BY_RCU, + NULL); + if (!subflow_ops->slab) + return -ENOMEM; + + return 0; +} + +void __init mptcp_subflow_init(void) +{ + mptcp_subflow_v4_request_sock_ops = tcp_request_sock_ops; + mptcp_subflow_v4_request_sock_ops.slab_name = "request_sock_subflow_v4"; + mptcp_subflow_v4_request_sock_ops.destructor = subflow_v4_req_destructor; + + if (subflow_ops_init(&mptcp_subflow_v4_request_sock_ops) != 0) + panic("MPTCP: failed to init subflow v4 request sock ops\n"); + + subflow_request_sock_ipv4_ops = tcp_request_sock_ipv4_ops; + subflow_request_sock_ipv4_ops.route_req = subflow_v4_route_req; + + subflow_specific = ipv4_specific; + subflow_specific.conn_request = subflow_v4_conn_request; + subflow_specific.syn_recv_sock = subflow_syn_recv_sock; + subflow_specific.sk_rx_dst_set = subflow_finish_connect; + subflow_specific.rebuild_header = subflow_rebuild_header; + + tcp_prot_override = tcp_prot; + tcp_prot_override.release_cb = tcp_release_cb_override; + tcp_prot_override.diag_destroy = tcp_abort_override; + +#if IS_ENABLED(CONFIG_MPTCP_IPV6) + /* In struct mptcp_subflow_request_sock, we assume the TCP request sock + * structures for v4 and v6 have the same size. It should not changed in + * the future but better to make sure to be warned if it is no longer + * the case. 
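The comment above describes a compile-time guard; the BUILD_BUG_ON() just below enforces it. In plain C11 the same idea can be expressed with static_assert, as in this small illustration (the struct names here are placeholders, the kernel compares tcp_request_sock against tcp6_request_sock):

#include <assert.h>

struct req_v4 { unsigned long common; unsigned int extra; };
struct req_v6 { unsigned long common; unsigned int extra; };

/* fail the build, not the boot, if the two layouts ever diverge */
static_assert(sizeof(struct req_v4) == sizeof(struct req_v6),
              "v4 and v6 request sock layouts must stay the same size");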
+ */ + BUILD_BUG_ON(sizeof(struct tcp_request_sock) != sizeof(struct tcp6_request_sock)); + + mptcp_subflow_v6_request_sock_ops = tcp6_request_sock_ops; + mptcp_subflow_v6_request_sock_ops.slab_name = "request_sock_subflow_v6"; + mptcp_subflow_v6_request_sock_ops.destructor = subflow_v6_req_destructor; + + if (subflow_ops_init(&mptcp_subflow_v6_request_sock_ops) != 0) + panic("MPTCP: failed to init subflow v6 request sock ops\n"); + + subflow_request_sock_ipv6_ops = tcp_request_sock_ipv6_ops; + subflow_request_sock_ipv6_ops.route_req = subflow_v6_route_req; + + subflow_v6_specific = ipv6_specific; + subflow_v6_specific.conn_request = subflow_v6_conn_request; + subflow_v6_specific.syn_recv_sock = subflow_syn_recv_sock; + subflow_v6_specific.sk_rx_dst_set = subflow_finish_connect; + subflow_v6_specific.rebuild_header = subflow_v6_rebuild_header; + + subflow_v6m_specific = subflow_v6_specific; + subflow_v6m_specific.queue_xmit = ipv4_specific.queue_xmit; + subflow_v6m_specific.send_check = ipv4_specific.send_check; + subflow_v6m_specific.net_header_len = ipv4_specific.net_header_len; + subflow_v6m_specific.mtu_reduced = ipv4_specific.mtu_reduced; + subflow_v6m_specific.net_frag_header_len = 0; + subflow_v6m_specific.rebuild_header = subflow_rebuild_header; + + tcpv6_prot_override = tcpv6_prot; + tcpv6_prot_override.release_cb = tcp_release_cb_override; + tcpv6_prot_override.diag_destroy = tcp_abort_override; +#endif + + mptcp_diag_subflow_init(&subflow_ulp_ops); + + if (tcp_register_ulp(&subflow_ulp_ops) != 0) + panic("MPTCP: failed to register subflows to ULP\n"); +} diff --git a/net/mptcp/syncookies.c b/net/mptcp/syncookies.c new file mode 100644 index 000000000..7f2252634 --- /dev/null +++ b/net/mptcp/syncookies.c @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include "protocol.h" + +/* Syncookies do not work for JOIN requests. + * + * Unlike MP_CAPABLE, where the ACK cookie contains the needed MPTCP + * options to reconstruct the initial syn state, MP_JOIN does not contain + * the token to obtain the mptcp socket nor the server-generated nonce + * that was used in the cookie SYN/ACK response. + * + * Keep a small best effort state table to store the syn/synack data, + * indexed by skb hash. + * + * A MP_JOIN SYN packet handled by syn cookies is only stored if the 32bit + * token matches a known mptcp connection that can still accept more subflows. + * + * There is no timeout handling -- state is only re-constructed + * when the TCP ACK passed the cookie validation check. 
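The state table described above, and implemented just below as join_entries[], is deliberately simple: a fixed number of slots, one lock per slot, writers overwrite whatever is in the slot, and a successful lookup consumes the entry. A userspace sketch of that shape, with pthread mutexes standing in for the spinlocks and illustrative names:

#include <pthread.h>
#include <stdbool.h>
#include <stdint.h>

#define SLOTS 1024

struct join_state {
        uint32_t token;
        uint32_t local_nonce;
        uint32_t remote_nonce;
        bool valid;
};

static struct join_state table[SLOTS];
static pthread_mutex_t locks[SLOTS];

static void table_init(void)                    /* cf. mptcp_join_cookie_init() */
{
        int i;

        for (i = 0; i < SLOTS; i++)
                pthread_mutex_init(&locks[i], NULL);
}

static void table_store(uint32_t hash, const struct join_state *s)
{
        uint32_t i = hash % SLOTS;

        pthread_mutex_lock(&locks[i]);
        table[i] = *s;                          /* best effort: overwrite */
        table[i].valid = true;
        pthread_mutex_unlock(&locks[i]);
}

static bool table_consume(uint32_t hash, struct join_state *out)
{
        uint32_t i = hash % SLOTS;
        bool hit;

        pthread_mutex_lock(&locks[i]);
        hit = table[i].valid;
        if (hit) {
                *out = table[i];
                table[i].valid = false;         /* single use, like e->valid = 0 */
        }
        pthread_mutex_unlock(&locks[i]);
        return hit;
}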
+ */ + +struct join_entry { + u32 token; + u32 remote_nonce; + u32 local_nonce; + u8 join_id; + u8 local_id; + u8 backup; + u8 valid; +}; + +#define COOKIE_JOIN_SLOTS 1024 + +static struct join_entry join_entries[COOKIE_JOIN_SLOTS] __cacheline_aligned_in_smp; +static spinlock_t join_entry_locks[COOKIE_JOIN_SLOTS] __cacheline_aligned_in_smp; + +static u32 mptcp_join_entry_hash(struct sk_buff *skb, struct net *net) +{ + static u32 mptcp_join_hash_secret __read_mostly; + struct tcphdr *th = tcp_hdr(skb); + u32 seq, i; + + net_get_random_once(&mptcp_join_hash_secret, + sizeof(mptcp_join_hash_secret)); + + if (th->syn) + seq = TCP_SKB_CB(skb)->seq; + else + seq = TCP_SKB_CB(skb)->seq - 1; + + i = jhash_3words(seq, net_hash_mix(net), + (__force __u32)th->source << 16 | (__force __u32)th->dest, + mptcp_join_hash_secret); + + return i % ARRAY_SIZE(join_entries); +} + +static void mptcp_join_store_state(struct join_entry *entry, + const struct mptcp_subflow_request_sock *subflow_req) +{ + entry->token = subflow_req->token; + entry->remote_nonce = subflow_req->remote_nonce; + entry->local_nonce = subflow_req->local_nonce; + entry->backup = subflow_req->backup; + entry->join_id = subflow_req->remote_id; + entry->local_id = subflow_req->local_id; + entry->valid = 1; +} + +void subflow_init_req_cookie_join_save(const struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb) +{ + struct net *net = read_pnet(&subflow_req->sk.req.ireq_net); + u32 i = mptcp_join_entry_hash(skb, net); + + /* No use in waiting if other cpu is already using this slot -- + * would overwrite the data that got stored. + */ + spin_lock_bh(&join_entry_locks[i]); + mptcp_join_store_state(&join_entries[i], subflow_req); + spin_unlock_bh(&join_entry_locks[i]); +} + +/* Called for a cookie-ack with MP_JOIN option present. + * Look up the saved state based on skb hash & check token matches msk + * in same netns. + * + * Caller will check msk can still accept another subflow. The hmac + * present in the cookie ACK mptcp option space will be checked later. + */ +bool mptcp_token_join_cookie_init_state(struct mptcp_subflow_request_sock *subflow_req, + struct sk_buff *skb) +{ + struct net *net = read_pnet(&subflow_req->sk.req.ireq_net); + u32 i = mptcp_join_entry_hash(skb, net); + struct mptcp_sock *msk; + struct join_entry *e; + + e = &join_entries[i]; + + spin_lock_bh(&join_entry_locks[i]); + + if (e->valid == 0) { + spin_unlock_bh(&join_entry_locks[i]); + return false; + } + + e->valid = 0; + + msk = mptcp_token_get_sock(net, e->token); + if (!msk) { + spin_unlock_bh(&join_entry_locks[i]); + return false; + } + + subflow_req->remote_nonce = e->remote_nonce; + subflow_req->local_nonce = e->local_nonce; + subflow_req->backup = e->backup; + subflow_req->remote_id = e->join_id; + subflow_req->token = e->token; + subflow_req->msk = msk; + spin_unlock_bh(&join_entry_locks[i]); + return true; +} + +void __init mptcp_join_cookie_init(void) +{ + int i; + + for (i = 0; i < COOKIE_JOIN_SLOTS; i++) + spin_lock_init(&join_entry_locks[i]); +} diff --git a/net/mptcp/token.c b/net/mptcp/token.c new file mode 100644 index 000000000..f52ee7b26 --- /dev/null +++ b/net/mptcp/token.c @@ -0,0 +1,416 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Multipath TCP token management + * Copyright (c) 2017 - 2019, Intel Corporation. 
+ * + * Note: This code is based on mptcp_ctrl.c from multipath-tcp.org, + * authored by: + * + * Sébastien Barré + * Christoph Paasch + * Jaakko Korkeaniemi + * Gregory Detal + * Fabien Duchêne + * Andreas Seelinger + * Lavkesh Lahngir + * Andreas Ripke + * Vlad Dogaru + * Octavian Purdila + * John Ronan + * Catalin Nicutar + * Brandon Heller + */ + +#define pr_fmt(fmt) "MPTCP: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "protocol.h" + +#define TOKEN_MAX_CHAIN_LEN 4 + +struct token_bucket { + spinlock_t lock; + int chain_len; + struct hlist_nulls_head req_chain; + struct hlist_nulls_head msk_chain; +}; + +static struct token_bucket *token_hash __read_mostly; +static unsigned int token_mask __read_mostly; + +static struct token_bucket *token_bucket(u32 token) +{ + return &token_hash[token & token_mask]; +} + +/* called with bucket lock held */ +static struct mptcp_subflow_request_sock * +__token_lookup_req(struct token_bucket *t, u32 token) +{ + struct mptcp_subflow_request_sock *req; + struct hlist_nulls_node *pos; + + hlist_nulls_for_each_entry_rcu(req, pos, &t->req_chain, token_node) + if (req->token == token) + return req; + return NULL; +} + +/* called with bucket lock held */ +static struct mptcp_sock * +__token_lookup_msk(struct token_bucket *t, u32 token) +{ + struct hlist_nulls_node *pos; + struct sock *sk; + + sk_nulls_for_each_rcu(sk, pos, &t->msk_chain) + if (mptcp_sk(sk)->token == token) + return mptcp_sk(sk); + return NULL; +} + +static bool __token_bucket_busy(struct token_bucket *t, u32 token) +{ + return !token || t->chain_len >= TOKEN_MAX_CHAIN_LEN || + __token_lookup_req(t, token) || __token_lookup_msk(t, token); +} + +static void mptcp_crypto_key_gen_sha(u64 *key, u32 *token, u64 *idsn) +{ + /* we might consider a faster version that computes the key as a + * hash of some information available in the MPTCP socket. Use + * random data at the moment, as it's probably the safest option + * in case multiple sockets are opened in different namespaces at + * the same time. + */ + get_random_bytes(key, sizeof(u64)); + mptcp_crypto_key_sha(*key, token, idsn); +} + +/** + * mptcp_token_new_request - create new key/idsn/token for subflow_request + * @req: the request socket + * + * This function is called when a new mptcp connection is coming in. + * + * It creates a unique token to identify the new mptcp connection, + * a secret local key and the initial data sequence number (idsn). + * + * Returns 0 on success. + */ +int mptcp_token_new_request(struct request_sock *req) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct token_bucket *bucket; + u32 token; + + mptcp_crypto_key_sha(subflow_req->local_key, + &subflow_req->token, + &subflow_req->idsn); + pr_debug("req=%p local_key=%llu, token=%u, idsn=%llu\n", + req, subflow_req->local_key, subflow_req->token, + subflow_req->idsn); + + token = subflow_req->token; + bucket = token_bucket(token); + spin_lock_bh(&bucket->lock); + if (__token_bucket_busy(bucket, token)) { + spin_unlock_bh(&bucket->lock); + return -EBUSY; + } + + hlist_nulls_add_head_rcu(&subflow_req->token_node, &bucket->req_chain); + bucket->chain_len++; + spin_unlock_bh(&bucket->lock); + return 0; +} + +/** + * mptcp_token_new_connect - create new key/idsn/token for subflow + * @sk: the socket that will initiate a connection + * + * This function is called when a new outgoing mptcp connection is + * initiated. 
+ * + * It creates a unique token to identify the new mptcp connection, + * a secret local key and the initial data sequence number (idsn). + * + * On success, the mptcp connection can be found again using + * the computed token at a later time, this is needed to process + * join requests. + * + * returns 0 on success. + */ +int mptcp_token_new_connect(struct sock *sk) +{ + struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); + struct mptcp_sock *msk = mptcp_sk(subflow->conn); + int retries = MPTCP_TOKEN_MAX_RETRIES; + struct token_bucket *bucket; + +again: + mptcp_crypto_key_gen_sha(&subflow->local_key, &subflow->token, + &subflow->idsn); + + bucket = token_bucket(subflow->token); + spin_lock_bh(&bucket->lock); + if (__token_bucket_busy(bucket, subflow->token)) { + spin_unlock_bh(&bucket->lock); + if (!--retries) + return -EBUSY; + goto again; + } + + pr_debug("ssk=%p, local_key=%llu, token=%u, idsn=%llu\n", + sk, subflow->local_key, subflow->token, subflow->idsn); + + WRITE_ONCE(msk->token, subflow->token); + __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); + bucket->chain_len++; + spin_unlock_bh(&bucket->lock); + return 0; +} + +/** + * mptcp_token_accept - replace a req sk with full sock in token hash + * @req: the request socket to be removed + * @msk: the just cloned socket linked to the new connection + * + * Called when a SYN packet creates a new logical connection, i.e. + * is not a join request. + */ +void mptcp_token_accept(struct mptcp_subflow_request_sock *req, + struct mptcp_sock *msk) +{ + struct mptcp_subflow_request_sock *pos; + struct token_bucket *bucket; + + bucket = token_bucket(req->token); + spin_lock_bh(&bucket->lock); + + /* pedantic lookup check for the moved token */ + pos = __token_lookup_req(bucket, req->token); + if (!WARN_ON_ONCE(pos != req)) + hlist_nulls_del_init_rcu(&req->token_node); + __sk_nulls_add_node_rcu((struct sock *)msk, &bucket->msk_chain); + spin_unlock_bh(&bucket->lock); +} + +bool mptcp_token_exists(u32 token) +{ + struct hlist_nulls_node *pos; + struct token_bucket *bucket; + struct mptcp_sock *msk; + struct sock *sk; + + rcu_read_lock(); + bucket = token_bucket(token); + +again: + sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { + msk = mptcp_sk(sk); + if (READ_ONCE(msk->token) == token) + goto found; + } + if (get_nulls_value(pos) != (token & token_mask)) + goto again; + + rcu_read_unlock(); + return false; +found: + rcu_read_unlock(); + return true; +} + +/** + * mptcp_token_get_sock - retrieve mptcp connection sock using its token + * @net: restrict to this namespace + * @token: token of the mptcp connection to retrieve + * + * This function returns the mptcp connection structure with the given token. + * A reference count on the mptcp socket returned is taken. + * + * returns NULL if no connection with the given token value exists. 
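The lookup implemented just below follows a "pin the object only if it is still live, then re-validate" pattern: refcount_inc_not_zero() may fail for a dying socket, and after taking the reference the token and netns are checked again. A simplified user-space sketch of that control flow using C11 atomics (no RCU, no netns check, hypothetical demo_* names):

#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

struct demo_sock {
    atomic_int refcnt;      /* 0 means the object is going away */
    unsigned int token;
};

/* Analogue of refcount_inc_not_zero(): only pin the object while it is live. */
static bool get_ref_not_zero(struct demo_sock *s)
{
    int old = atomic_load(&s->refcnt);

    while (old != 0) {
        if (atomic_compare_exchange_weak(&s->refcnt, &old, old + 1))
            return true;
    }
    return false;
}

static void put_ref(struct demo_sock *s)
{
    atomic_fetch_sub(&s->refcnt, 1);
}

/* Find @token in a tiny table, pin it, then re-check it still matches. */
static struct demo_sock *lookup(struct demo_sock *tbl, int n, unsigned int token)
{
    for (int i = 0; i < n; i++) {
        struct demo_sock *s = &tbl[i];

        if (s->token != token)
            continue;
        if (!get_ref_not_zero(s))
            return NULL;            /* found, but already being destroyed */
        if (s->token != token) {    /* changed while we pinned it: drop and retry */
            put_ref(s);
            continue;
        }
        return s;                   /* caller must put_ref() when done */
    }
    return NULL;
}

int main(void)
{
    struct demo_sock tbl[2] = { { .refcnt = 1, .token = 0x11 },
                                { .refcnt = 1, .token = 0x22 } };
    struct demo_sock *s = lookup(tbl, 2, 0x22);

    printf("found=%d refcnt=%d\n", s != NULL, s ? atomic_load(&s->refcnt) : 0);
    if (s)
        put_ref(s);
    return 0;
}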
+ */ +struct mptcp_sock *mptcp_token_get_sock(struct net *net, u32 token) +{ + struct hlist_nulls_node *pos; + struct token_bucket *bucket; + struct mptcp_sock *msk; + struct sock *sk; + + rcu_read_lock(); + bucket = token_bucket(token); + +again: + sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { + msk = mptcp_sk(sk); + if (READ_ONCE(msk->token) != token || + !net_eq(sock_net(sk), net)) + continue; + + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + goto not_found; + + if (READ_ONCE(msk->token) != token || + !net_eq(sock_net(sk), net)) { + sock_put(sk); + goto again; + } + goto found; + } + if (get_nulls_value(pos) != (token & token_mask)) + goto again; + +not_found: + msk = NULL; + +found: + rcu_read_unlock(); + return msk; +} +EXPORT_SYMBOL_GPL(mptcp_token_get_sock); + +/** + * mptcp_token_iter_next - iterate over the token container from given pos + * @net: namespace to be iterated + * @s_slot: start slot number + * @s_num: start number inside the given lock + * + * This function returns the first mptcp connection structure found inside the + * token container starting from the specified position, or NULL. + * + * On successful iteration, the iterator is move to the next position and the + * the acquires a reference to the returned socket. + */ +struct mptcp_sock *mptcp_token_iter_next(const struct net *net, long *s_slot, + long *s_num) +{ + struct mptcp_sock *ret = NULL; + struct hlist_nulls_node *pos; + int slot, num = 0; + + for (slot = *s_slot; slot <= token_mask; *s_num = 0, slot++) { + struct token_bucket *bucket = &token_hash[slot]; + struct sock *sk; + + num = 0; + + if (hlist_nulls_empty(&bucket->msk_chain)) + continue; + + rcu_read_lock(); + sk_nulls_for_each_rcu(sk, pos, &bucket->msk_chain) { + ++num; + if (!net_eq(sock_net(sk), net)) + continue; + + if (num <= *s_num) + continue; + + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + continue; + + if (!net_eq(sock_net(sk), net)) { + sock_put(sk); + continue; + } + + ret = mptcp_sk(sk); + rcu_read_unlock(); + goto out; + } + rcu_read_unlock(); + } + +out: + *s_slot = slot; + *s_num = num; + return ret; +} +EXPORT_SYMBOL_GPL(mptcp_token_iter_next); + +/** + * mptcp_token_destroy_request - remove mptcp connection/token + * @req: mptcp request socket dropping the token + * + * Remove the token associated to @req. 
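For reference, all the helpers in this file index the table through token_bucket(), which relies on token_mask being one less than a power-of-two table size so that "token & token_mask" is a cheap modulo, and a bucket refuses new entries once TOKEN_MAX_CHAIN_LEN (4) chains accumulate. A compact user-space illustration of just that admission logic (locking and the nulls lists omitted; demo_* names are invented):

#include <stdint.h>
#include <stdio.h>

#define DEMO_BUCKETS 16             /* must be a power of two */
#define DEMO_MAX_CHAIN 4            /* mirrors TOKEN_MAX_CHAIN_LEN */

struct demo_bucket {
    int chain_len;
};

static struct demo_bucket buckets[DEMO_BUCKETS];
static const uint32_t demo_mask = DEMO_BUCKETS - 1;

/* token_bucket(): index by masking, constant cost for any token value. */
static struct demo_bucket *bucket_of(uint32_t token)
{
    return &buckets[token & demo_mask];
}

/* __token_bucket_busy()-style check: reject token 0 or an overlong chain. */
static int bucket_busy(struct demo_bucket *b, uint32_t token)
{
    return !token || b->chain_len >= DEMO_MAX_CHAIN;
}

int main(void)
{
    uint32_t token = 0xdeadbeef;
    struct demo_bucket *b = bucket_of(token);

    if (!bucket_busy(b, token))
        b->chain_len++;             /* "insert" into the bucket */
    printf("bucket %u chain_len %d\n", (unsigned)(token & demo_mask), b->chain_len);
    return 0;
}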
+ */ +void mptcp_token_destroy_request(struct request_sock *req) +{ + struct mptcp_subflow_request_sock *subflow_req = mptcp_subflow_rsk(req); + struct mptcp_subflow_request_sock *pos; + struct token_bucket *bucket; + + if (hlist_nulls_unhashed(&subflow_req->token_node)) + return; + + bucket = token_bucket(subflow_req->token); + spin_lock_bh(&bucket->lock); + pos = __token_lookup_req(bucket, subflow_req->token); + if (!WARN_ON_ONCE(pos != subflow_req)) { + hlist_nulls_del_init_rcu(&pos->token_node); + bucket->chain_len--; + } + spin_unlock_bh(&bucket->lock); +} + +/** + * mptcp_token_destroy - remove mptcp connection/token + * @msk: mptcp connection dropping the token + * + * Remove the token associated to @msk + */ +void mptcp_token_destroy(struct mptcp_sock *msk) +{ + struct token_bucket *bucket; + struct mptcp_sock *pos; + + if (sk_unhashed((struct sock *)msk)) + return; + + bucket = token_bucket(msk->token); + spin_lock_bh(&bucket->lock); + pos = __token_lookup_msk(bucket, msk->token); + if (!WARN_ON_ONCE(pos != msk)) { + __sk_nulls_del_node_init_rcu((struct sock *)pos); + bucket->chain_len--; + } + spin_unlock_bh(&bucket->lock); + WRITE_ONCE(msk->token, 0); +} + +void __init mptcp_token_init(void) +{ + int i; + + token_hash = alloc_large_system_hash("MPTCP token", + sizeof(struct token_bucket), + 0, + 20,/* one slot per 1MB of memory */ + HASH_ZERO, + NULL, + &token_mask, + 0, + 64 * 1024); + for (i = 0; i < token_mask + 1; ++i) { + INIT_HLIST_NULLS_HEAD(&token_hash[i].req_chain, i); + INIT_HLIST_NULLS_HEAD(&token_hash[i].msk_chain, i); + spin_lock_init(&token_hash[i].lock); + } +} + +#if IS_MODULE(CONFIG_MPTCP_KUNIT_TEST) +EXPORT_SYMBOL_GPL(mptcp_token_new_request); +EXPORT_SYMBOL_GPL(mptcp_token_new_connect); +EXPORT_SYMBOL_GPL(mptcp_token_accept); +EXPORT_SYMBOL_GPL(mptcp_token_destroy_request); +EXPORT_SYMBOL_GPL(mptcp_token_destroy); +#endif diff --git a/net/mptcp/token_test.c b/net/mptcp/token_test.c new file mode 100644 index 000000000..5d984bec1 --- /dev/null +++ b/net/mptcp/token_test.c @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: GPL-2.0 +#include + +#include "protocol.h" + +static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test) +{ + struct mptcp_subflow_request_sock *req; + + req = kunit_kzalloc(test, sizeof(struct mptcp_subflow_request_sock), + GFP_USER); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req); + mptcp_token_init_request((struct request_sock *)req); + sock_net_set((struct sock *)req, &init_net); + return req; +} + +static void mptcp_token_test_req_basic(struct kunit *test) +{ + struct mptcp_subflow_request_sock *req = build_req_sock(test); + struct mptcp_sock *null_msk = NULL; + + KUNIT_ASSERT_EQ(test, 0, + mptcp_token_new_request((struct request_sock *)req)); + KUNIT_EXPECT_NE(test, 0, (int)req->token); + KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, req->token)); + + /* cleanup */ + mptcp_token_destroy_request((struct request_sock *)req); +} + +static struct inet_connection_sock *build_icsk(struct kunit *test) +{ + struct inet_connection_sock *icsk; + + icsk = kunit_kzalloc(test, sizeof(struct inet_connection_sock), + GFP_USER); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, icsk); + return icsk; +} + +static struct mptcp_subflow_context *build_ctx(struct kunit *test) +{ + struct mptcp_subflow_context *ctx; + + ctx = kunit_kzalloc(test, sizeof(struct mptcp_subflow_context), + GFP_USER); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, ctx); + return ctx; +} + +static struct mptcp_sock *build_msk(struct kunit *test) +{ + struct mptcp_sock *msk; 
+ + msk = kunit_kzalloc(test, sizeof(struct mptcp_sock), GFP_USER); + KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk); + refcount_set(&((struct sock *)msk)->sk_refcnt, 1); + sock_net_set((struct sock *)msk, &init_net); + return msk; +} + +static void mptcp_token_test_msk_basic(struct kunit *test) +{ + struct inet_connection_sock *icsk = build_icsk(test); + struct mptcp_subflow_context *ctx = build_ctx(test); + struct mptcp_sock *msk = build_msk(test); + struct mptcp_sock *null_msk = NULL; + struct sock *sk; + + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + ctx->conn = (struct sock *)msk; + sk = (struct sock *)msk; + + KUNIT_ASSERT_EQ(test, 0, + mptcp_token_new_connect((struct sock *)icsk)); + KUNIT_EXPECT_NE(test, 0, (int)ctx->token); + KUNIT_EXPECT_EQ(test, ctx->token, msk->token); + KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, ctx->token)); + KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt)); + + mptcp_token_destroy(msk); + KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, ctx->token)); +} + +static void mptcp_token_test_accept(struct kunit *test) +{ + struct mptcp_subflow_request_sock *req = build_req_sock(test); + struct mptcp_sock *msk = build_msk(test); + + KUNIT_ASSERT_EQ(test, 0, + mptcp_token_new_request((struct request_sock *)req)); + msk->token = req->token; + mptcp_token_accept(req, msk); + KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token)); + + /* this is now a no-op */ + mptcp_token_destroy_request((struct request_sock *)req); + KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token)); + + /* cleanup */ + mptcp_token_destroy(msk); +} + +static void mptcp_token_test_destroyed(struct kunit *test) +{ + struct mptcp_subflow_request_sock *req = build_req_sock(test); + struct mptcp_sock *msk = build_msk(test); + struct mptcp_sock *null_msk = NULL; + struct sock *sk; + + sk = (struct sock *)msk; + + KUNIT_ASSERT_EQ(test, 0, + mptcp_token_new_request((struct request_sock *)req)); + msk->token = req->token; + mptcp_token_accept(req, msk); + + /* simulate race on removal */ + refcount_set(&sk->sk_refcnt, 0); + KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, msk->token)); + + /* cleanup */ + mptcp_token_destroy(msk); +} + +static struct kunit_case mptcp_token_test_cases[] = { + KUNIT_CASE(mptcp_token_test_req_basic), + KUNIT_CASE(mptcp_token_test_msk_basic), + KUNIT_CASE(mptcp_token_test_accept), + KUNIT_CASE(mptcp_token_test_destroyed), + {} +}; + +static struct kunit_suite mptcp_token_suite = { + .name = "mptcp-token", + .test_cases = mptcp_token_test_cases, +}; + +kunit_test_suite(mptcp_token_suite); + +MODULE_LICENSE("GPL"); diff --git a/net/ncsi/Kconfig b/net/ncsi/Kconfig new file mode 100644 index 000000000..ea1dd32b6 --- /dev/null +++ b/net/ncsi/Kconfig @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Configuration for NCSI support +# + +config NET_NCSI + bool "NCSI interface support" + depends on INET + help + This module provides NCSI (Network Controller Sideband Interface) + support. Enable this only if your system connects to a network + device via NCSI and the ethernet driver you're using supports + the protocol explicitly. +config NCSI_OEM_CMD_GET_MAC + bool "Get NCSI OEM MAC Address" + depends on NET_NCSI + help + This allows to get MAC address from NCSI firmware and set them back to + controller. 
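For context, a platform that wants the OEM MAC helper above (and the keep-PHY option that follows) would end up with a kernel configuration fragment along these lines; this is an illustrative fragment, not taken from any specific defconfig:

CONFIG_NET_NCSI=y
CONFIG_NCSI_OEM_CMD_GET_MAC=y
CONFIG_NCSI_OEM_CMD_KEEP_PHY=y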
+config NCSI_OEM_CMD_KEEP_PHY + bool "Keep PHY Link up" + depends on NET_NCSI + help + This allows to keep PHY link up and prevents any channel resets during + the host load. diff --git a/net/ncsi/Makefile b/net/ncsi/Makefile new file mode 100644 index 000000000..e205f3bf1 --- /dev/null +++ b/net/ncsi/Makefile @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for NCSI API +# +obj-$(CONFIG_NET_NCSI) += ncsi-cmd.o ncsi-rsp.o ncsi-aen.o ncsi-manage.o ncsi-netlink.o diff --git a/net/ncsi/internal.h b/net/ncsi/internal.h new file mode 100644 index 000000000..374412ed7 --- /dev/null +++ b/net/ncsi/internal.h @@ -0,0 +1,414 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright Gavin Shan, IBM Corporation 2016. + */ + +#ifndef __NCSI_INTERNAL_H__ +#define __NCSI_INTERNAL_H__ + +enum { + NCSI_CAP_BASE = 0, + NCSI_CAP_GENERIC = 0, + NCSI_CAP_BC, + NCSI_CAP_MC, + NCSI_CAP_BUFFER, + NCSI_CAP_AEN, + NCSI_CAP_VLAN, + NCSI_CAP_MAX +}; + +enum { + NCSI_CAP_GENERIC_HWA = 0x01, /* HW arbitration */ + NCSI_CAP_GENERIC_HDS = 0x02, /* HNC driver status change */ + NCSI_CAP_GENERIC_FC = 0x04, /* HNC to MC flow control */ + NCSI_CAP_GENERIC_FC1 = 0x08, /* MC to HNC flow control */ + NCSI_CAP_GENERIC_MC = 0x10, /* Global MC filtering */ + NCSI_CAP_GENERIC_HWA_UNKNOWN = 0x00, /* Unknown HW arbitration */ + NCSI_CAP_GENERIC_HWA_SUPPORT = 0x20, /* Supported HW arbitration */ + NCSI_CAP_GENERIC_HWA_NOT_SUPPORT = 0x40, /* No HW arbitration */ + NCSI_CAP_GENERIC_HWA_RESERVED = 0x60, /* Reserved HW arbitration */ + NCSI_CAP_GENERIC_HWA_MASK = 0x60, /* Mask for HW arbitration */ + NCSI_CAP_GENERIC_MASK = 0x7f, + NCSI_CAP_BC_ARP = 0x01, /* ARP packet filtering */ + NCSI_CAP_BC_DHCPC = 0x02, /* DHCP client filtering */ + NCSI_CAP_BC_DHCPS = 0x04, /* DHCP server filtering */ + NCSI_CAP_BC_NETBIOS = 0x08, /* NetBIOS packet filtering */ + NCSI_CAP_BC_MASK = 0x0f, + NCSI_CAP_MC_IPV6_NEIGHBOR = 0x01, /* IPv6 neighbor filtering */ + NCSI_CAP_MC_IPV6_ROUTER = 0x02, /* IPv6 router filering */ + NCSI_CAP_MC_DHCPV6_RELAY = 0x04, /* DHCPv6 relay / server MC */ + NCSI_CAP_MC_DHCPV6_WELL_KNOWN = 0x08, /* DHCPv6 well-known MC */ + NCSI_CAP_MC_IPV6_MLD = 0x10, /* IPv6 MLD filtering */ + NCSI_CAP_MC_IPV6_NEIGHBOR_S = 0x20, /* IPv6 neighbour filtering */ + NCSI_CAP_MC_MASK = 0x3f, + NCSI_CAP_AEN_LSC = 0x01, /* Link status change */ + NCSI_CAP_AEN_CR = 0x02, /* Configuration required */ + NCSI_CAP_AEN_HDS = 0x04, /* HNC driver status */ + NCSI_CAP_AEN_MASK = 0x07, + NCSI_CAP_VLAN_ONLY = 0x01, /* Filter VLAN packet only */ + NCSI_CAP_VLAN_NO = 0x02, /* Filter VLAN and non-VLAN */ + NCSI_CAP_VLAN_ANY = 0x04, /* Filter Any-and-non-VLAN */ + NCSI_CAP_VLAN_MASK = 0x07 +}; + +enum { + NCSI_MODE_BASE = 0, + NCSI_MODE_ENABLE = 0, + NCSI_MODE_TX_ENABLE, + NCSI_MODE_LINK, + NCSI_MODE_VLAN, + NCSI_MODE_BC, + NCSI_MODE_MC, + NCSI_MODE_AEN, + NCSI_MODE_FC, + NCSI_MODE_MAX +}; + +/* Supported media status bits for Mellanox Mac affinity command. + * Bit (0-2) for different protocol support; Bit 1 for RBT support, + * bit 1 for SMBUS support and bit 2 for PCIE support. Bit (3-5) + * for different protocol availability. Bit 4 for RBT, bit 4 for + * SMBUS and bit 5 for PCIE. 
+ */ +enum { + MLX_MC_RBT_SUPPORT = 0x01, /* MC supports RBT */ + MLX_MC_RBT_AVL = 0x08, /* RBT medium is available */ +}; + +/* OEM Vendor Manufacture ID */ +#define NCSI_OEM_MFR_MLX_ID 0x8119 +#define NCSI_OEM_MFR_BCM_ID 0x113d +#define NCSI_OEM_MFR_INTEL_ID 0x157 +/* Intel specific OEM command */ +#define NCSI_OEM_INTEL_CMD_GMA 0x06 /* CMD ID for Get MAC */ +#define NCSI_OEM_INTEL_CMD_KEEP_PHY 0x20 /* CMD ID for Keep PHY up */ +/* Broadcom specific OEM Command */ +#define NCSI_OEM_BCM_CMD_GMA 0x01 /* CMD ID for Get MAC */ +/* Mellanox specific OEM Command */ +#define NCSI_OEM_MLX_CMD_GMA 0x00 /* CMD ID for Get MAC */ +#define NCSI_OEM_MLX_CMD_GMA_PARAM 0x1b /* Parameter for GMA */ +#define NCSI_OEM_MLX_CMD_SMAF 0x01 /* CMD ID for Set MC Affinity */ +#define NCSI_OEM_MLX_CMD_SMAF_PARAM 0x07 /* Parameter for SMAF */ +/* OEM Command payload lengths*/ +#define NCSI_OEM_INTEL_CMD_GMA_LEN 5 +#define NCSI_OEM_INTEL_CMD_KEEP_PHY_LEN 7 +#define NCSI_OEM_BCM_CMD_GMA_LEN 12 +#define NCSI_OEM_MLX_CMD_GMA_LEN 8 +#define NCSI_OEM_MLX_CMD_SMAF_LEN 60 +/* Offset in OEM request */ +#define MLX_SMAF_MAC_ADDR_OFFSET 8 /* Offset for MAC in SMAF */ +#define MLX_SMAF_MED_SUPPORT_OFFSET 14 /* Offset for medium in SMAF */ +/* Mac address offset in OEM response */ +#define BCM_MAC_ADDR_OFFSET 28 +#define MLX_MAC_ADDR_OFFSET 8 +#define INTEL_MAC_ADDR_OFFSET 1 + + +struct ncsi_channel_version { + u8 major; /* NCSI version major */ + u8 minor; /* NCSI version minor */ + u8 update; /* NCSI version update */ + char alpha1; /* NCSI version alpha1 */ + char alpha2; /* NCSI version alpha2 */ + u8 fw_name[12]; /* Firmware name string */ + u32 fw_version; /* Firmware version */ + u16 pci_ids[4]; /* PCI identification */ + u32 mf_id; /* Manufacture ID */ +}; + +struct ncsi_channel_cap { + u32 index; /* Index of channel capabilities */ + u32 cap; /* NCSI channel capability */ +}; + +struct ncsi_channel_mode { + u32 index; /* Index of channel modes */ + u32 enable; /* Enabled or disabled */ + u32 size; /* Valid entries in ncm_data[] */ + u32 data[8]; /* Data entries */ +}; + +struct ncsi_channel_mac_filter { + u8 n_uc; + u8 n_mc; + u8 n_mixed; + u64 bitmap; + unsigned char *addrs; +}; + +struct ncsi_channel_vlan_filter { + u8 n_vids; + u64 bitmap; + u16 *vids; +}; + +struct ncsi_channel_stats { + u32 hnc_cnt_hi; /* Counter cleared */ + u32 hnc_cnt_lo; /* Counter cleared */ + u32 hnc_rx_bytes; /* Rx bytes */ + u32 hnc_tx_bytes; /* Tx bytes */ + u32 hnc_rx_uc_pkts; /* Rx UC packets */ + u32 hnc_rx_mc_pkts; /* Rx MC packets */ + u32 hnc_rx_bc_pkts; /* Rx BC packets */ + u32 hnc_tx_uc_pkts; /* Tx UC packets */ + u32 hnc_tx_mc_pkts; /* Tx MC packets */ + u32 hnc_tx_bc_pkts; /* Tx BC packets */ + u32 hnc_fcs_err; /* FCS errors */ + u32 hnc_align_err; /* Alignment errors */ + u32 hnc_false_carrier; /* False carrier detection */ + u32 hnc_runt_pkts; /* Rx runt packets */ + u32 hnc_jabber_pkts; /* Rx jabber packets */ + u32 hnc_rx_pause_xon; /* Rx pause XON frames */ + u32 hnc_rx_pause_xoff; /* Rx XOFF frames */ + u32 hnc_tx_pause_xon; /* Tx XON frames */ + u32 hnc_tx_pause_xoff; /* Tx XOFF frames */ + u32 hnc_tx_s_collision; /* Single collision frames */ + u32 hnc_tx_m_collision; /* Multiple collision frames */ + u32 hnc_l_collision; /* Late collision frames */ + u32 hnc_e_collision; /* Excessive collision frames */ + u32 hnc_rx_ctl_frames; /* Rx control frames */ + u32 hnc_rx_64_frames; /* Rx 64-bytes frames */ + u32 hnc_rx_127_frames; /* Rx 65-127 bytes frames */ + u32 hnc_rx_255_frames; /* Rx 128-255 bytes frames */ + u32 
hnc_rx_511_frames; /* Rx 256-511 bytes frames */ + u32 hnc_rx_1023_frames; /* Rx 512-1023 bytes frames */ + u32 hnc_rx_1522_frames; /* Rx 1024-1522 bytes frames */ + u32 hnc_rx_9022_frames; /* Rx 1523-9022 bytes frames */ + u32 hnc_tx_64_frames; /* Tx 64-bytes frames */ + u32 hnc_tx_127_frames; /* Tx 65-127 bytes frames */ + u32 hnc_tx_255_frames; /* Tx 128-255 bytes frames */ + u32 hnc_tx_511_frames; /* Tx 256-511 bytes frames */ + u32 hnc_tx_1023_frames; /* Tx 512-1023 bytes frames */ + u32 hnc_tx_1522_frames; /* Tx 1024-1522 bytes frames */ + u32 hnc_tx_9022_frames; /* Tx 1523-9022 bytes frames */ + u32 hnc_rx_valid_bytes; /* Rx valid bytes */ + u32 hnc_rx_runt_pkts; /* Rx error runt packets */ + u32 hnc_rx_jabber_pkts; /* Rx error jabber packets */ + u32 ncsi_rx_cmds; /* Rx NCSI commands */ + u32 ncsi_dropped_cmds; /* Dropped commands */ + u32 ncsi_cmd_type_errs; /* Command type errors */ + u32 ncsi_cmd_csum_errs; /* Command checksum errors */ + u32 ncsi_rx_pkts; /* Rx NCSI packets */ + u32 ncsi_tx_pkts; /* Tx NCSI packets */ + u32 ncsi_tx_aen_pkts; /* Tx AEN packets */ + u32 pt_tx_pkts; /* Tx packets */ + u32 pt_tx_dropped; /* Tx dropped packets */ + u32 pt_tx_channel_err; /* Tx channel errors */ + u32 pt_tx_us_err; /* Tx undersize errors */ + u32 pt_rx_pkts; /* Rx packets */ + u32 pt_rx_dropped; /* Rx dropped packets */ + u32 pt_rx_channel_err; /* Rx channel errors */ + u32 pt_rx_us_err; /* Rx undersize errors */ + u32 pt_rx_os_err; /* Rx oversize errors */ +}; + +struct ncsi_dev_priv; +struct ncsi_package; + +#define NCSI_PACKAGE_SHIFT 5 +#define NCSI_PACKAGE_INDEX(c) (((c) >> NCSI_PACKAGE_SHIFT) & 0x7) +#define NCSI_RESERVED_CHANNEL 0x1f +#define NCSI_CHANNEL_INDEX(c) ((c) & ((1 << NCSI_PACKAGE_SHIFT) - 1)) +#define NCSI_TO_CHANNEL(p, c) (((p) << NCSI_PACKAGE_SHIFT) | (c)) +#define NCSI_MAX_PACKAGE 8 +#define NCSI_MAX_CHANNEL 32 + +struct ncsi_channel { + unsigned char id; + int state; +#define NCSI_CHANNEL_INACTIVE 1 +#define NCSI_CHANNEL_ACTIVE 2 +#define NCSI_CHANNEL_INVISIBLE 3 + bool reconfigure_needed; + spinlock_t lock; /* Protect filters etc */ + struct ncsi_package *package; + struct ncsi_channel_version version; + struct ncsi_channel_cap caps[NCSI_CAP_MAX]; + struct ncsi_channel_mode modes[NCSI_MODE_MAX]; + /* Filtering Settings */ + struct ncsi_channel_mac_filter mac_filter; + struct ncsi_channel_vlan_filter vlan_filter; + struct ncsi_channel_stats stats; + struct { + struct timer_list timer; + bool enabled; + unsigned int state; +#define NCSI_CHANNEL_MONITOR_START 0 +#define NCSI_CHANNEL_MONITOR_RETRY 1 +#define NCSI_CHANNEL_MONITOR_WAIT 2 +#define NCSI_CHANNEL_MONITOR_WAIT_MAX 5 + } monitor; + struct list_head node; + struct list_head link; +}; + +struct ncsi_package { + unsigned char id; /* NCSI 3-bits package ID */ + unsigned char uuid[16]; /* UUID */ + struct ncsi_dev_priv *ndp; /* NCSI device */ + spinlock_t lock; /* Protect the package */ + unsigned int channel_num; /* Number of channels */ + struct list_head channels; /* List of channels */ + struct list_head node; /* Form list of packages */ + + bool multi_channel; /* Enable multiple channels */ + u32 channel_whitelist; /* Channels to configure */ + struct ncsi_channel *preferred_channel; /* Primary channel */ +}; + +struct ncsi_request { + unsigned char id; /* Request ID - 0 to 255 */ + bool used; /* Request that has been assigned */ + unsigned int flags; /* NCSI request property */ +#define NCSI_REQ_FLAG_EVENT_DRIVEN 1 +#define NCSI_REQ_FLAG_NETLINK_DRIVEN 2 + struct ncsi_dev_priv *ndp; /* Associated NCSI 
device */ + struct sk_buff *cmd; /* Associated NCSI command packet */ + struct sk_buff *rsp; /* Associated NCSI response packet */ + struct timer_list timer; /* Timer on waiting for response */ + bool enabled; /* Time has been enabled or not */ + u32 snd_seq; /* netlink sending sequence number */ + u32 snd_portid; /* netlink portid of sender */ + struct nlmsghdr nlhdr; /* netlink message header */ +}; + +enum { + ncsi_dev_state_major = 0xff00, + ncsi_dev_state_minor = 0x00ff, + ncsi_dev_state_probe_deselect = 0x0201, + ncsi_dev_state_probe_package, + ncsi_dev_state_probe_channel, + ncsi_dev_state_probe_mlx_gma, + ncsi_dev_state_probe_mlx_smaf, + ncsi_dev_state_probe_cis, + ncsi_dev_state_probe_keep_phy, + ncsi_dev_state_probe_gvi, + ncsi_dev_state_probe_gc, + ncsi_dev_state_probe_gls, + ncsi_dev_state_probe_dp, + ncsi_dev_state_config_sp = 0x0301, + ncsi_dev_state_config_cis, + ncsi_dev_state_config_oem_gma, + ncsi_dev_state_config_clear_vids, + ncsi_dev_state_config_svf, + ncsi_dev_state_config_ev, + ncsi_dev_state_config_sma, + ncsi_dev_state_config_ebf, + ncsi_dev_state_config_dgmf, + ncsi_dev_state_config_ecnt, + ncsi_dev_state_config_ec, + ncsi_dev_state_config_ae, + ncsi_dev_state_config_gls, + ncsi_dev_state_config_done, + ncsi_dev_state_suspend_select = 0x0401, + ncsi_dev_state_suspend_gls, + ncsi_dev_state_suspend_dcnt, + ncsi_dev_state_suspend_dc, + ncsi_dev_state_suspend_deselect, + ncsi_dev_state_suspend_done +}; + +struct vlan_vid { + struct list_head list; + __be16 proto; + u16 vid; +}; + +struct ncsi_dev_priv { + struct ncsi_dev ndev; /* Associated NCSI device */ + unsigned int flags; /* NCSI device flags */ +#define NCSI_DEV_PROBED 1 /* Finalized NCSI topology */ +#define NCSI_DEV_HWA 2 /* Enabled HW arbitration */ +#define NCSI_DEV_RESHUFFLE 4 +#define NCSI_DEV_RESET 8 /* Reset state of NC */ + unsigned int gma_flag; /* OEM GMA flag */ + spinlock_t lock; /* Protect the NCSI device */ + unsigned int package_probe_id;/* Current ID during probe */ + unsigned int package_num; /* Number of packages */ + struct list_head packages; /* List of packages */ + struct ncsi_channel *hot_channel; /* Channel was ever active */ + struct ncsi_request requests[256]; /* Request table */ + unsigned int request_id; /* Last used request ID */ +#define NCSI_REQ_START_IDX 1 + unsigned int pending_req_num; /* Number of pending requests */ + struct ncsi_package *active_package; /* Currently handled package */ + struct ncsi_channel *active_channel; /* Currently handled channel */ + struct list_head channel_queue; /* Config queue of channels */ + struct work_struct work; /* For channel management */ + struct packet_type ptype; /* NCSI packet Rx handler */ + struct list_head node; /* Form NCSI device list */ +#define NCSI_MAX_VLAN_VIDS 15 + struct list_head vlan_vids; /* List of active VLAN IDs */ + + bool multi_package; /* Enable multiple packages */ + bool mlx_multi_host; /* Enable multi host Mellanox */ + u32 package_whitelist; /* Packages to configure */ +}; + +struct ncsi_cmd_arg { + struct ncsi_dev_priv *ndp; /* Associated NCSI device */ + unsigned char type; /* Command in the NCSI packet */ + unsigned char id; /* Request ID (sequence number) */ + unsigned char package; /* Destination package ID */ + unsigned char channel; /* Destination channel ID or 0x1f */ + unsigned short payload; /* Command packet payload length */ + unsigned int req_flags; /* NCSI request properties */ + union { + unsigned char bytes[16]; /* Command packet specific data */ + unsigned short words[8]; + unsigned int 
dwords[4]; + }; + unsigned char *data; /* NCSI OEM data */ + struct genl_info *info; /* Netlink information */ +}; + +extern struct list_head ncsi_dev_list; +extern spinlock_t ncsi_dev_lock; + +#define TO_NCSI_DEV_PRIV(nd) \ + container_of(nd, struct ncsi_dev_priv, ndev) +#define NCSI_FOR_EACH_DEV(ndp) \ + list_for_each_entry_rcu(ndp, &ncsi_dev_list, node) +#define NCSI_FOR_EACH_PACKAGE(ndp, np) \ + list_for_each_entry_rcu(np, &ndp->packages, node) +#define NCSI_FOR_EACH_CHANNEL(np, nc) \ + list_for_each_entry_rcu(nc, &np->channels, node) + +/* Resources */ +int ncsi_reset_dev(struct ncsi_dev *nd); +void ncsi_start_channel_monitor(struct ncsi_channel *nc); +void ncsi_stop_channel_monitor(struct ncsi_channel *nc); +struct ncsi_channel *ncsi_find_channel(struct ncsi_package *np, + unsigned char id); +struct ncsi_channel *ncsi_add_channel(struct ncsi_package *np, + unsigned char id); +struct ncsi_package *ncsi_find_package(struct ncsi_dev_priv *ndp, + unsigned char id); +struct ncsi_package *ncsi_add_package(struct ncsi_dev_priv *ndp, + unsigned char id); +void ncsi_remove_package(struct ncsi_package *np); +void ncsi_find_package_and_channel(struct ncsi_dev_priv *ndp, + unsigned char id, + struct ncsi_package **np, + struct ncsi_channel **nc); +struct ncsi_request *ncsi_alloc_request(struct ncsi_dev_priv *ndp, + unsigned int req_flags); +void ncsi_free_request(struct ncsi_request *nr); +struct ncsi_dev *ncsi_find_dev(struct net_device *dev); +int ncsi_process_next_channel(struct ncsi_dev_priv *ndp); +bool ncsi_channel_has_link(struct ncsi_channel *channel); +bool ncsi_channel_is_last(struct ncsi_dev_priv *ndp, + struct ncsi_channel *channel); +int ncsi_update_tx_channel(struct ncsi_dev_priv *ndp, + struct ncsi_package *np, + struct ncsi_channel *disable, + struct ncsi_channel *enable); + +/* Packet handlers */ +u32 ncsi_calculate_checksum(unsigned char *data, int len); +int ncsi_xmit_cmd(struct ncsi_cmd_arg *nca); +int ncsi_rcv_rsp(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev); +int ncsi_aen_handler(struct ncsi_dev_priv *ndp, struct sk_buff *skb); + +#endif /* __NCSI_INTERNAL_H__ */ diff --git a/net/ncsi/ncsi-aen.c b/net/ncsi/ncsi-aen.c new file mode 100644 index 000000000..62fb10317 --- /dev/null +++ b/net/ncsi/ncsi-aen.c @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright Gavin Shan, IBM Corporation 2016. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "internal.h" +#include "ncsi-pkt.h" + +static int ncsi_validate_aen_pkt(struct ncsi_aen_pkt_hdr *h, + const unsigned short payload) +{ + u32 checksum; + __be32 *pchecksum; + + if (h->common.revision != NCSI_PKT_REVISION) + return -EINVAL; + if (ntohs(h->common.length) != payload) + return -EINVAL; + + /* Validate checksum, which might be zeroes if the + * sender doesn't support checksum according to NCSI + * specification. 
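The checksum being validated here is the simple NCSI 2's-complement sum that ncsi_calculate_checksum() implements further down in this patch: big-endian 16-bit words are summed and the result is negated. A user-space rendition with a toy buffer, assuming an even length as in the real packets (the buffer contents are arbitrary):

#include <stdint.h>
#include <stdio.h>

/* Same arithmetic as ncsi_calculate_checksum(): sum 16-bit BE words, negate. */
static uint32_t ncsi_checksum(const unsigned char *data, int len)
{
    uint32_t sum = 0;

    for (int i = 0; i < len; i += 2)
        sum += ((uint32_t)data[i] << 8) | data[i + 1];

    return ~sum + 1;                /* 2's complement of the running sum */
}

int main(void)
{
    /* Toy "header + payload" bytes; a real packet is at least 16 + 4 bytes. */
    unsigned char pkt[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07 };
    uint32_t csum = ncsi_checksum(pkt, sizeof(pkt));

    /* A receiver recomputes over header + payload and compares with the
     * trailing 32-bit field; an all-zero field means "no checksum supplied". */
    printf("checksum: 0x%08x\n", csum);
    return 0;
}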
+ */ + pchecksum = (__be32 *)((void *)(h + 1) + payload - 4); + if (ntohl(*pchecksum) == 0) + return 0; + + checksum = ncsi_calculate_checksum((unsigned char *)h, + sizeof(*h) + payload - 4); + if (*pchecksum != htonl(checksum)) + return -EINVAL; + + return 0; +} + +static int ncsi_aen_handler_lsc(struct ncsi_dev_priv *ndp, + struct ncsi_aen_pkt_hdr *h) +{ + struct ncsi_channel *nc, *tmp; + struct ncsi_channel_mode *ncm; + unsigned long old_data, data; + struct ncsi_aen_lsc_pkt *lsc; + struct ncsi_package *np; + bool had_link, has_link; + unsigned long flags; + bool chained; + int state; + + /* Find the NCSI channel */ + ncsi_find_package_and_channel(ndp, h->common.channel, NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update the link status */ + lsc = (struct ncsi_aen_lsc_pkt *)h; + + spin_lock_irqsave(&nc->lock, flags); + ncm = &nc->modes[NCSI_MODE_LINK]; + old_data = ncm->data[2]; + data = ntohl(lsc->status); + ncm->data[2] = data; + ncm->data[4] = ntohl(lsc->oem_status); + + had_link = !!(old_data & 0x1); + has_link = !!(data & 0x1); + + netdev_dbg(ndp->ndev.dev, "NCSI: LSC AEN - channel %u state %s\n", + nc->id, data & 0x1 ? "up" : "down"); + + chained = !list_empty(&nc->link); + state = nc->state; + spin_unlock_irqrestore(&nc->lock, flags); + + if (state == NCSI_CHANNEL_INACTIVE) + netdev_warn(ndp->ndev.dev, + "NCSI: Inactive channel %u received AEN!\n", + nc->id); + + if ((had_link == has_link) || chained) + return 0; + + if (!ndp->multi_package && !nc->package->multi_channel) { + if (had_link) { + ndp->flags |= NCSI_DEV_RESHUFFLE; + ncsi_stop_channel_monitor(nc); + spin_lock_irqsave(&ndp->lock, flags); + list_add_tail_rcu(&nc->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + return ncsi_process_next_channel(ndp); + } + /* Configured channel came up */ + return 0; + } + + if (had_link) { + ncm = &nc->modes[NCSI_MODE_TX_ENABLE]; + if (ncsi_channel_is_last(ndp, nc)) { + /* No channels left, reconfigure */ + return ncsi_reset_dev(&ndp->ndev); + } else if (ncm->enable) { + /* Need to failover Tx channel */ + ncsi_update_tx_channel(ndp, nc->package, nc, NULL); + } + } else if (has_link && nc->package->preferred_channel == nc) { + /* Return Tx to preferred channel */ + ncsi_update_tx_channel(ndp, nc->package, NULL, nc); + } else if (has_link) { + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, tmp) { + /* Enable Tx on this channel if the current Tx + * channel is down. + */ + ncm = &tmp->modes[NCSI_MODE_TX_ENABLE]; + if (ncm->enable && + !ncsi_channel_has_link(tmp)) { + ncsi_update_tx_channel(ndp, nc->package, + tmp, nc); + break; + } + } + } + } + + /* Leave configured channels active in a multi-channel scenario so + * AEN events are still received. 
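The three AEN types are dispatched through the small static table (ncsi_aen_handlers[]) defined just below: each entry pairs a packet type with its expected payload length and a handler. A stripped-down user-space sketch of the same table-driven lookup; the handlers and the numeric type values here are placeholders, not the real NCSI_PKT_AEN_* constants:

#include <stdio.h>

struct demo_aen_handler {
    unsigned char type;
    int payload;
    int (*handler)(const void *pkt);
};

static int handle_lsc(const void *pkt) { (void)pkt; puts("link state change"); return 0; }
static int handle_cr(const void *pkt)  { (void)pkt; puts("configuration required"); return 0; }

static const struct demo_aen_handler handlers[] = {
    { 0x00, 12, handle_lsc },       /* stand-in for NCSI_PKT_AEN_LSC */
    { 0x01,  4, handle_cr },        /* stand-in for NCSI_PKT_AEN_CR */
};

/* Mirrors the loop in ncsi_aen_handler(): find the entry or report "unknown". */
static int dispatch(unsigned char type, const void *pkt)
{
    for (unsigned int i = 0; i < sizeof(handlers) / sizeof(handlers[0]); i++)
        if (handlers[i].type == type)
            return handlers[i].handler(pkt);
    return -1;                      /* unknown AEN type */
}

int main(void)
{
    dispatch(0x00, NULL);
    printf("unknown -> %d\n", dispatch(0x7f, NULL));
    return 0;
}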
+ */ + return 0; +} + +static int ncsi_aen_handler_cr(struct ncsi_dev_priv *ndp, + struct ncsi_aen_pkt_hdr *h) +{ + struct ncsi_channel *nc; + unsigned long flags; + + /* Find the NCSI channel */ + ncsi_find_package_and_channel(ndp, h->common.channel, NULL, &nc); + if (!nc) + return -ENODEV; + + spin_lock_irqsave(&nc->lock, flags); + if (!list_empty(&nc->link) || + nc->state != NCSI_CHANNEL_ACTIVE) { + spin_unlock_irqrestore(&nc->lock, flags); + return 0; + } + spin_unlock_irqrestore(&nc->lock, flags); + + ncsi_stop_channel_monitor(nc); + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_INVISIBLE; + spin_unlock_irqrestore(&nc->lock, flags); + + spin_lock_irqsave(&ndp->lock, flags); + nc->state = NCSI_CHANNEL_INACTIVE; + list_add_tail_rcu(&nc->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + nc->modes[NCSI_MODE_TX_ENABLE].enable = 0; + + return ncsi_process_next_channel(ndp); +} + +static int ncsi_aen_handler_hncdsc(struct ncsi_dev_priv *ndp, + struct ncsi_aen_pkt_hdr *h) +{ + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + struct ncsi_aen_hncdsc_pkt *hncdsc; + unsigned long flags; + + /* Find the NCSI channel */ + ncsi_find_package_and_channel(ndp, h->common.channel, NULL, &nc); + if (!nc) + return -ENODEV; + + spin_lock_irqsave(&nc->lock, flags); + ncm = &nc->modes[NCSI_MODE_LINK]; + hncdsc = (struct ncsi_aen_hncdsc_pkt *)h; + ncm->data[3] = ntohl(hncdsc->status); + spin_unlock_irqrestore(&nc->lock, flags); + netdev_dbg(ndp->ndev.dev, + "NCSI: host driver %srunning on channel %u\n", + ncm->data[3] & 0x1 ? "" : "not ", nc->id); + + return 0; +} + +static struct ncsi_aen_handler { + unsigned char type; + int payload; + int (*handler)(struct ncsi_dev_priv *ndp, + struct ncsi_aen_pkt_hdr *h); +} ncsi_aen_handlers[] = { + { NCSI_PKT_AEN_LSC, 12, ncsi_aen_handler_lsc }, + { NCSI_PKT_AEN_CR, 4, ncsi_aen_handler_cr }, + { NCSI_PKT_AEN_HNCDSC, 8, ncsi_aen_handler_hncdsc } +}; + +int ncsi_aen_handler(struct ncsi_dev_priv *ndp, struct sk_buff *skb) +{ + struct ncsi_aen_pkt_hdr *h; + struct ncsi_aen_handler *nah = NULL; + int i, ret; + + /* Find the handler */ + h = (struct ncsi_aen_pkt_hdr *)skb_network_header(skb); + for (i = 0; i < ARRAY_SIZE(ncsi_aen_handlers); i++) { + if (ncsi_aen_handlers[i].type == h->type) { + nah = &ncsi_aen_handlers[i]; + break; + } + } + + if (!nah) { + netdev_warn(ndp->ndev.dev, "Invalid AEN (0x%x) received\n", + h->type); + return -ENOENT; + } + + ret = ncsi_validate_aen_pkt(h, nah->payload); + if (ret) { + netdev_warn(ndp->ndev.dev, + "NCSI: 'bad' packet ignored for AEN type 0x%x\n", + h->type); + goto out; + } + + ret = nah->handler(ndp, h); + if (ret) + netdev_err(ndp->ndev.dev, + "NCSI: Handler for AEN type 0x%x returned %d\n", + h->type, ret); +out: + consume_skb(skb); + return ret; +} diff --git a/net/ncsi/ncsi-cmd.c b/net/ncsi/ncsi-cmd.c new file mode 100644 index 000000000..dda8b76b7 --- /dev/null +++ b/net/ncsi/ncsi-cmd.c @@ -0,0 +1,406 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright Gavin Shan, IBM Corporation 2016. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "internal.h" +#include "ncsi-pkt.h" + +static const int padding_bytes = 26; + +u32 ncsi_calculate_checksum(unsigned char *data, int len) +{ + u32 checksum = 0; + int i; + + for (i = 0; i < len; i += 2) + checksum += (((u32)data[i] << 8) | data[i + 1]); + + checksum = (~checksum + 1); + return checksum; +} + +/* This function should be called after the data area has been + * populated completely. + */ +static void ncsi_cmd_build_header(struct ncsi_pkt_hdr *h, + struct ncsi_cmd_arg *nca) +{ + u32 checksum; + __be32 *pchecksum; + + h->mc_id = 0; + h->revision = NCSI_PKT_REVISION; + h->reserved = 0; + h->id = nca->id; + h->type = nca->type; + h->channel = NCSI_TO_CHANNEL(nca->package, + nca->channel); + h->length = htons(nca->payload); + h->reserved1[0] = 0; + h->reserved1[1] = 0; + + /* Fill with calculated checksum */ + checksum = ncsi_calculate_checksum((unsigned char *)h, + sizeof(*h) + nca->payload); + pchecksum = (__be32 *)((void *)h + sizeof(struct ncsi_pkt_hdr) + + ALIGN(nca->payload, 4)); + *pchecksum = htonl(checksum); +} + +static int ncsi_cmd_handler_default(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_sp(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_sp_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->hw_arbitration = nca->bytes[0]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_dc(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_dc_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->ald = nca->bytes[0]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_rc(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_rc_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_ae(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_ae_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mc_id = nca->bytes[0]; + cmd->mode = htonl(nca->dwords[1]); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_sl(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_sl_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mode = htonl(nca->dwords[0]); + cmd->oem_mode = htonl(nca->dwords[1]); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_svf(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_svf_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->vlan = htons(nca->words[1]); + cmd->index = nca->bytes[6]; + cmd->enable = nca->bytes[7]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_ev(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_ev_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mode = nca->bytes[3]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_sma(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_sma_pkt *cmd; + int i; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + for (i = 0; i < 6; i++) + cmd->mac[i] = nca->bytes[i]; + cmd->index = 
nca->bytes[6]; + cmd->at_e = nca->bytes[7]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_ebf(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_ebf_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mode = htonl(nca->dwords[0]); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_egmf(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_egmf_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mode = htonl(nca->dwords[0]); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_snfc(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_snfc_pkt *cmd; + + cmd = skb_put_zero(skb, sizeof(*cmd)); + cmd->mode = nca->bytes[0]; + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static int ncsi_cmd_handler_oem(struct sk_buff *skb, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_oem_pkt *cmd; + unsigned int len; + int payload; + /* NC-SI spec DSP_0222_1.2.0, section 8.2.2.2 + * requires payload to be padded with 0 to + * 32-bit boundary before the checksum field. + * Ensure the padding bytes are accounted for in + * skb allocation + */ + + payload = ALIGN(nca->payload, 4); + len = sizeof(struct ncsi_cmd_pkt_hdr) + 4; + len += max(payload, padding_bytes); + + cmd = skb_put_zero(skb, len); + memcpy(&cmd->mfr_id, nca->data, nca->payload); + ncsi_cmd_build_header(&cmd->cmd.common, nca); + + return 0; +} + +static struct ncsi_cmd_handler { + unsigned char type; + int payload; + int (*handler)(struct sk_buff *skb, + struct ncsi_cmd_arg *nca); +} ncsi_cmd_handlers[] = { + { NCSI_PKT_CMD_CIS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_SP, 4, ncsi_cmd_handler_sp }, + { NCSI_PKT_CMD_DP, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_EC, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_DC, 4, ncsi_cmd_handler_dc }, + { NCSI_PKT_CMD_RC, 4, ncsi_cmd_handler_rc }, + { NCSI_PKT_CMD_ECNT, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_DCNT, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_AE, 8, ncsi_cmd_handler_ae }, + { NCSI_PKT_CMD_SL, 8, ncsi_cmd_handler_sl }, + { NCSI_PKT_CMD_GLS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_SVF, 8, ncsi_cmd_handler_svf }, + { NCSI_PKT_CMD_EV, 4, ncsi_cmd_handler_ev }, + { NCSI_PKT_CMD_DV, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_SMA, 8, ncsi_cmd_handler_sma }, + { NCSI_PKT_CMD_EBF, 4, ncsi_cmd_handler_ebf }, + { NCSI_PKT_CMD_DBF, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_EGMF, 4, ncsi_cmd_handler_egmf }, + { NCSI_PKT_CMD_DGMF, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_SNFC, 4, ncsi_cmd_handler_snfc }, + { NCSI_PKT_CMD_GVI, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GC, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GP, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GCPS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GNS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GNPTS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_GPS, 0, ncsi_cmd_handler_default }, + { NCSI_PKT_CMD_OEM, -1, ncsi_cmd_handler_oem }, + { NCSI_PKT_CMD_PLDM, 0, NULL }, + { NCSI_PKT_CMD_GPUUID, 0, ncsi_cmd_handler_default } +}; + +static struct ncsi_request *ncsi_alloc_command(struct ncsi_cmd_arg *nca) +{ + struct ncsi_dev_priv *ndp = nca->ndp; + struct ncsi_dev *nd = &ndp->ndev; + struct net_device *dev = nd->dev; + int hlen = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + int payload; + int len = hlen + tlen; + struct sk_buff *skb; + 
struct ncsi_request *nr; + + nr = ncsi_alloc_request(ndp, nca->req_flags); + if (!nr) + return NULL; + + /* NCSI command packet has 16-bytes header, payload, 4 bytes checksum. + * Payload needs padding so that the checksum field following payload is + * aligned to 32-bit boundary. + * The packet needs padding if its payload is less than 26 bytes to + * meet 64 bytes minimal ethernet frame length. + */ + len += sizeof(struct ncsi_cmd_pkt_hdr) + 4; + payload = ALIGN(nca->payload, 4); + len += max(payload, padding_bytes); + + /* Allocate skb */ + skb = alloc_skb(len, GFP_ATOMIC); + if (!skb) { + ncsi_free_request(nr); + return NULL; + } + + nr->cmd = skb; + skb_reserve(skb, hlen); + skb_reset_network_header(skb); + + skb->dev = dev; + skb->protocol = htons(ETH_P_NCSI); + + return nr; +} + +int ncsi_xmit_cmd(struct ncsi_cmd_arg *nca) +{ + struct ncsi_cmd_handler *nch = NULL; + struct ncsi_request *nr; + unsigned char type; + struct ethhdr *eh; + int i, ret; + + /* Use OEM generic handler for Netlink request */ + if (nca->req_flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) + type = NCSI_PKT_CMD_OEM; + else + type = nca->type; + + /* Search for the handler */ + for (i = 0; i < ARRAY_SIZE(ncsi_cmd_handlers); i++) { + if (ncsi_cmd_handlers[i].type == type) { + if (ncsi_cmd_handlers[i].handler) + nch = &ncsi_cmd_handlers[i]; + else + nch = NULL; + + break; + } + } + + if (!nch) { + netdev_err(nca->ndp->ndev.dev, + "Cannot send packet with type 0x%02x\n", nca->type); + return -ENOENT; + } + + /* Get packet payload length and allocate the request + * It is expected that if length set as negative in + * handler structure means caller is initializing it + * and setting length in nca before calling xmit function + */ + if (nch->payload >= 0) + nca->payload = nch->payload; + nr = ncsi_alloc_command(nca); + if (!nr) + return -ENOMEM; + + /* track netlink information */ + if (nca->req_flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + nr->snd_seq = nca->info->snd_seq; + nr->snd_portid = nca->info->snd_portid; + nr->nlhdr = *nca->info->nlhdr; + } + + /* Prepare the packet */ + nca->id = nr->id; + ret = nch->handler(nr->cmd, nca); + if (ret) { + ncsi_free_request(nr); + return ret; + } + + /* Fill the ethernet header */ + eh = skb_push(nr->cmd, sizeof(*eh)); + eh->h_proto = htons(ETH_P_NCSI); + eth_broadcast_addr(eh->h_dest); + + /* If mac address received from device then use it for + * source address as unicast address else use broadcast + * address as source address + */ + if (nca->ndp->gma_flag == 1) + memcpy(eh->h_source, nca->ndp->ndev.dev->dev_addr, ETH_ALEN); + else + eth_broadcast_addr(eh->h_source); + + /* Start the timer for the request that might not have + * corresponding response. Given NCSI is an internal + * connection a 1 second delay should be sufficient. + */ + nr->enabled = true; + mod_timer(&nr->timer, jiffies + 1 * HZ); + + /* Send NCSI packet */ + skb_get(nr->cmd); + ret = dev_queue_xmit(nr->cmd); + if (ret < 0) { + ncsi_free_request(nr); + return ret; + } + + return 0; +} diff --git a/net/ncsi/ncsi-manage.c b/net/ncsi/ncsi-manage.c new file mode 100644 index 000000000..80713febf --- /dev/null +++ b/net/ncsi/ncsi-manage.c @@ -0,0 +1,1969 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright Gavin Shan, IBM Corporation 2016. 
+ */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "internal.h" +#include "ncsi-pkt.h" +#include "ncsi-netlink.h" + +LIST_HEAD(ncsi_dev_list); +DEFINE_SPINLOCK(ncsi_dev_lock); + +bool ncsi_channel_has_link(struct ncsi_channel *channel) +{ + return !!(channel->modes[NCSI_MODE_LINK].data[2] & 0x1); +} + +bool ncsi_channel_is_last(struct ncsi_dev_priv *ndp, + struct ncsi_channel *channel) +{ + struct ncsi_package *np; + struct ncsi_channel *nc; + + NCSI_FOR_EACH_PACKAGE(ndp, np) + NCSI_FOR_EACH_CHANNEL(np, nc) { + if (nc == channel) + continue; + if (nc->state == NCSI_CHANNEL_ACTIVE && + ncsi_channel_has_link(nc)) + return false; + } + + return true; +} + +static void ncsi_report_link(struct ncsi_dev_priv *ndp, bool force_down) +{ + struct ncsi_dev *nd = &ndp->ndev; + struct ncsi_package *np; + struct ncsi_channel *nc; + unsigned long flags; + + nd->state = ncsi_dev_state_functional; + if (force_down) { + nd->link_up = 0; + goto report; + } + + nd->link_up = 0; + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, nc) { + spin_lock_irqsave(&nc->lock, flags); + + if (!list_empty(&nc->link) || + nc->state != NCSI_CHANNEL_ACTIVE) { + spin_unlock_irqrestore(&nc->lock, flags); + continue; + } + + if (ncsi_channel_has_link(nc)) { + spin_unlock_irqrestore(&nc->lock, flags); + nd->link_up = 1; + goto report; + } + + spin_unlock_irqrestore(&nc->lock, flags); + } + } + +report: + nd->handler(nd); +} + +static void ncsi_channel_monitor(struct timer_list *t) +{ + struct ncsi_channel *nc = from_timer(nc, t, monitor.timer); + struct ncsi_package *np = nc->package; + struct ncsi_dev_priv *ndp = np->ndp; + struct ncsi_channel_mode *ncm; + struct ncsi_cmd_arg nca; + bool enabled, chained; + unsigned int monitor_state; + unsigned long flags; + int state, ret; + + spin_lock_irqsave(&nc->lock, flags); + state = nc->state; + chained = !list_empty(&nc->link); + enabled = nc->monitor.enabled; + monitor_state = nc->monitor.state; + spin_unlock_irqrestore(&nc->lock, flags); + + if (!enabled) + return; /* expected race disabling timer */ + if (WARN_ON_ONCE(chained)) + goto bad_state; + + if (state != NCSI_CHANNEL_INACTIVE && + state != NCSI_CHANNEL_ACTIVE) { +bad_state: + netdev_warn(ndp->ndev.dev, + "Bad NCSI monitor state channel %d 0x%x %s queue\n", + nc->id, state, chained ? "on" : "off"); + spin_lock_irqsave(&nc->lock, flags); + nc->monitor.enabled = false; + spin_unlock_irqrestore(&nc->lock, flags); + return; + } + + switch (monitor_state) { + case NCSI_CHANNEL_MONITOR_START: + case NCSI_CHANNEL_MONITOR_RETRY: + nca.ndp = ndp; + nca.package = np->id; + nca.channel = nc->id; + nca.type = NCSI_PKT_CMD_GLS; + nca.req_flags = 0; + ret = ncsi_xmit_cmd(&nca); + if (ret) + netdev_err(ndp->ndev.dev, "Error %d sending GLS\n", + ret); + break; + case NCSI_CHANNEL_MONITOR_WAIT ... 
NCSI_CHANNEL_MONITOR_WAIT_MAX: + break; + default: + netdev_err(ndp->ndev.dev, "NCSI Channel %d timed out!\n", + nc->id); + ncsi_report_link(ndp, true); + ndp->flags |= NCSI_DEV_RESHUFFLE; + + ncm = &nc->modes[NCSI_MODE_LINK]; + spin_lock_irqsave(&nc->lock, flags); + nc->monitor.enabled = false; + nc->state = NCSI_CHANNEL_INVISIBLE; + ncm->data[2] &= ~0x1; + spin_unlock_irqrestore(&nc->lock, flags); + + spin_lock_irqsave(&ndp->lock, flags); + nc->state = NCSI_CHANNEL_ACTIVE; + list_add_tail_rcu(&nc->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + ncsi_process_next_channel(ndp); + return; + } + + spin_lock_irqsave(&nc->lock, flags); + nc->monitor.state++; + spin_unlock_irqrestore(&nc->lock, flags); + mod_timer(&nc->monitor.timer, jiffies + HZ); +} + +void ncsi_start_channel_monitor(struct ncsi_channel *nc) +{ + unsigned long flags; + + spin_lock_irqsave(&nc->lock, flags); + WARN_ON_ONCE(nc->monitor.enabled); + nc->monitor.enabled = true; + nc->monitor.state = NCSI_CHANNEL_MONITOR_START; + spin_unlock_irqrestore(&nc->lock, flags); + + mod_timer(&nc->monitor.timer, jiffies + HZ); +} + +void ncsi_stop_channel_monitor(struct ncsi_channel *nc) +{ + unsigned long flags; + + spin_lock_irqsave(&nc->lock, flags); + if (!nc->monitor.enabled) { + spin_unlock_irqrestore(&nc->lock, flags); + return; + } + nc->monitor.enabled = false; + spin_unlock_irqrestore(&nc->lock, flags); + + del_timer_sync(&nc->monitor.timer); +} + +struct ncsi_channel *ncsi_find_channel(struct ncsi_package *np, + unsigned char id) +{ + struct ncsi_channel *nc; + + NCSI_FOR_EACH_CHANNEL(np, nc) { + if (nc->id == id) + return nc; + } + + return NULL; +} + +struct ncsi_channel *ncsi_add_channel(struct ncsi_package *np, unsigned char id) +{ + struct ncsi_channel *nc, *tmp; + int index; + unsigned long flags; + + nc = kzalloc(sizeof(*nc), GFP_ATOMIC); + if (!nc) + return NULL; + + nc->id = id; + nc->package = np; + nc->state = NCSI_CHANNEL_INACTIVE; + nc->monitor.enabled = false; + timer_setup(&nc->monitor.timer, ncsi_channel_monitor, 0); + spin_lock_init(&nc->lock); + INIT_LIST_HEAD(&nc->link); + for (index = 0; index < NCSI_CAP_MAX; index++) + nc->caps[index].index = index; + for (index = 0; index < NCSI_MODE_MAX; index++) + nc->modes[index].index = index; + + spin_lock_irqsave(&np->lock, flags); + tmp = ncsi_find_channel(np, id); + if (tmp) { + spin_unlock_irqrestore(&np->lock, flags); + kfree(nc); + return tmp; + } + + list_add_tail_rcu(&nc->node, &np->channels); + np->channel_num++; + spin_unlock_irqrestore(&np->lock, flags); + + return nc; +} + +static void ncsi_remove_channel(struct ncsi_channel *nc) +{ + struct ncsi_package *np = nc->package; + unsigned long flags; + + spin_lock_irqsave(&nc->lock, flags); + + /* Release filters */ + kfree(nc->mac_filter.addrs); + kfree(nc->vlan_filter.vids); + + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + ncsi_stop_channel_monitor(nc); + + /* Remove and free channel */ + spin_lock_irqsave(&np->lock, flags); + list_del_rcu(&nc->node); + np->channel_num--; + spin_unlock_irqrestore(&np->lock, flags); + + kfree(nc); +} + +struct ncsi_package *ncsi_find_package(struct ncsi_dev_priv *ndp, + unsigned char id) +{ + struct ncsi_package *np; + + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (np->id == id) + return np; + } + + return NULL; +} + +struct ncsi_package *ncsi_add_package(struct ncsi_dev_priv *ndp, + unsigned char id) +{ + struct ncsi_package *np, *tmp; + unsigned long flags; + + np = kzalloc(sizeof(*np), GFP_ATOMIC); + if (!np) + return 
NULL; + + np->id = id; + np->ndp = ndp; + spin_lock_init(&np->lock); + INIT_LIST_HEAD(&np->channels); + np->channel_whitelist = UINT_MAX; + + spin_lock_irqsave(&ndp->lock, flags); + tmp = ncsi_find_package(ndp, id); + if (tmp) { + spin_unlock_irqrestore(&ndp->lock, flags); + kfree(np); + return tmp; + } + + list_add_tail_rcu(&np->node, &ndp->packages); + ndp->package_num++; + spin_unlock_irqrestore(&ndp->lock, flags); + + return np; +} + +void ncsi_remove_package(struct ncsi_package *np) +{ + struct ncsi_dev_priv *ndp = np->ndp; + struct ncsi_channel *nc, *tmp; + unsigned long flags; + + /* Release all child channels */ + list_for_each_entry_safe(nc, tmp, &np->channels, node) + ncsi_remove_channel(nc); + + /* Remove and free package */ + spin_lock_irqsave(&ndp->lock, flags); + list_del_rcu(&np->node); + ndp->package_num--; + spin_unlock_irqrestore(&ndp->lock, flags); + + kfree(np); +} + +void ncsi_find_package_and_channel(struct ncsi_dev_priv *ndp, + unsigned char id, + struct ncsi_package **np, + struct ncsi_channel **nc) +{ + struct ncsi_package *p; + struct ncsi_channel *c; + + p = ncsi_find_package(ndp, NCSI_PACKAGE_INDEX(id)); + c = p ? ncsi_find_channel(p, NCSI_CHANNEL_INDEX(id)) : NULL; + + if (np) + *np = p; + if (nc) + *nc = c; +} + +/* For two consecutive NCSI commands, the packet IDs shouldn't + * be same. Otherwise, the bogus response might be replied. So + * the available IDs are allocated in round-robin fashion. + */ +struct ncsi_request *ncsi_alloc_request(struct ncsi_dev_priv *ndp, + unsigned int req_flags) +{ + struct ncsi_request *nr = NULL; + int i, limit = ARRAY_SIZE(ndp->requests); + unsigned long flags; + + /* Check if there is one available request until the ceiling */ + spin_lock_irqsave(&ndp->lock, flags); + for (i = ndp->request_id; i < limit; i++) { + if (ndp->requests[i].used) + continue; + + nr = &ndp->requests[i]; + nr->used = true; + nr->flags = req_flags; + ndp->request_id = i + 1; + goto found; + } + + /* Fail back to check from the starting cursor */ + for (i = NCSI_REQ_START_IDX; i < ndp->request_id; i++) { + if (ndp->requests[i].used) + continue; + + nr = &ndp->requests[i]; + nr->used = true; + nr->flags = req_flags; + ndp->request_id = i + 1; + goto found; + } + +found: + spin_unlock_irqrestore(&ndp->lock, flags); + return nr; +} + +void ncsi_free_request(struct ncsi_request *nr) +{ + struct ncsi_dev_priv *ndp = nr->ndp; + struct sk_buff *cmd, *rsp; + unsigned long flags; + bool driven; + + if (nr->enabled) { + nr->enabled = false; + del_timer_sync(&nr->timer); + } + + spin_lock_irqsave(&ndp->lock, flags); + cmd = nr->cmd; + rsp = nr->rsp; + nr->cmd = NULL; + nr->rsp = NULL; + nr->used = false; + driven = !!(nr->flags & NCSI_REQ_FLAG_EVENT_DRIVEN); + spin_unlock_irqrestore(&ndp->lock, flags); + + if (driven && cmd && --ndp->pending_req_num == 0) + schedule_work(&ndp->work); + + /* Release command and response */ + consume_skb(cmd); + consume_skb(rsp); +} + +struct ncsi_dev *ncsi_find_dev(struct net_device *dev) +{ + struct ncsi_dev_priv *ndp; + + NCSI_FOR_EACH_DEV(ndp) { + if (ndp->ndev.dev == dev) + return &ndp->ndev; + } + + return NULL; +} + +static void ncsi_request_timeout(struct timer_list *t) +{ + struct ncsi_request *nr = from_timer(nr, t, timer); + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_cmd_pkt *cmd; + struct ncsi_package *np; + struct ncsi_channel *nc; + unsigned long flags; + + /* If the request already had associated response, + * let the response handler to release it. 
+ */ + spin_lock_irqsave(&ndp->lock, flags); + nr->enabled = false; + if (nr->rsp || !nr->cmd) { + spin_unlock_irqrestore(&ndp->lock, flags); + return; + } + spin_unlock_irqrestore(&ndp->lock, flags); + + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + if (nr->cmd) { + /* Find the package */ + cmd = (struct ncsi_cmd_pkt *) + skb_network_header(nr->cmd); + ncsi_find_package_and_channel(ndp, + cmd->cmd.common.channel, + &np, &nc); + ncsi_send_netlink_timeout(nr, np, nc); + } + } + + /* Release the request */ + ncsi_free_request(nr); +} + +static void ncsi_suspend_channel(struct ncsi_dev_priv *ndp) +{ + struct ncsi_dev *nd = &ndp->ndev; + struct ncsi_package *np; + struct ncsi_channel *nc, *tmp; + struct ncsi_cmd_arg nca; + unsigned long flags; + int ret; + + np = ndp->active_package; + nc = ndp->active_channel; + nca.ndp = ndp; + nca.req_flags = NCSI_REQ_FLAG_EVENT_DRIVEN; + switch (nd->state) { + case ncsi_dev_state_suspend: + nd->state = ncsi_dev_state_suspend_select; + fallthrough; + case ncsi_dev_state_suspend_select: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_SP; + nca.package = np->id; + nca.channel = NCSI_RESERVED_CHANNEL; + if (ndp->flags & NCSI_DEV_HWA) + nca.bytes[0] = 0; + else + nca.bytes[0] = 1; + + /* To retrieve the last link states of channels in current + * package when current active channel needs fail over to + * another one. It means we will possibly select another + * channel as next active one. The link states of channels + * are most important factor of the selection. So we need + * accurate link states. Unfortunately, the link states on + * inactive channels can't be updated with LSC AEN in time. + */ + if (ndp->flags & NCSI_DEV_RESHUFFLE) + nd->state = ncsi_dev_state_suspend_gls; + else + nd->state = ncsi_dev_state_suspend_dcnt; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + + break; + case ncsi_dev_state_suspend_gls: + ndp->pending_req_num = np->channel_num; + + nca.type = NCSI_PKT_CMD_GLS; + nca.package = np->id; + + nd->state = ncsi_dev_state_suspend_dcnt; + NCSI_FOR_EACH_CHANNEL(np, nc) { + nca.channel = nc->id; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + } + + break; + case ncsi_dev_state_suspend_dcnt: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_DCNT; + nca.package = np->id; + nca.channel = nc->id; + + nd->state = ncsi_dev_state_suspend_dc; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + + break; + case ncsi_dev_state_suspend_dc: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_DC; + nca.package = np->id; + nca.channel = nc->id; + nca.bytes[0] = 1; + + nd->state = ncsi_dev_state_suspend_deselect; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + + NCSI_FOR_EACH_CHANNEL(np, tmp) { + /* If there is another channel active on this package + * do not deselect the package. 
+ */ + if (tmp != nc && tmp->state == NCSI_CHANNEL_ACTIVE) { + nd->state = ncsi_dev_state_suspend_done; + break; + } + } + break; + case ncsi_dev_state_suspend_deselect: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_DP; + nca.package = np->id; + nca.channel = NCSI_RESERVED_CHANNEL; + + nd->state = ncsi_dev_state_suspend_done; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + + break; + case ncsi_dev_state_suspend_done: + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + if (ndp->flags & NCSI_DEV_RESET) + ncsi_reset_dev(nd); + else + ncsi_process_next_channel(ndp); + break; + default: + netdev_warn(nd->dev, "Wrong NCSI state 0x%x in suspend\n", + nd->state); + } + + return; +error: + nd->state = ncsi_dev_state_functional; +} + +/* Check the VLAN filter bitmap for a set filter, and construct a + * "Set VLAN Filter - Disable" packet if found. + */ +static int clear_one_vid(struct ncsi_dev_priv *ndp, struct ncsi_channel *nc, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_channel_vlan_filter *ncf; + unsigned long flags; + void *bitmap; + int index; + u16 vid; + + ncf = &nc->vlan_filter; + bitmap = &ncf->bitmap; + + spin_lock_irqsave(&nc->lock, flags); + index = find_first_bit(bitmap, ncf->n_vids); + if (index >= ncf->n_vids) { + spin_unlock_irqrestore(&nc->lock, flags); + return -1; + } + vid = ncf->vids[index]; + + clear_bit(index, bitmap); + ncf->vids[index] = 0; + spin_unlock_irqrestore(&nc->lock, flags); + + nca->type = NCSI_PKT_CMD_SVF; + nca->words[1] = vid; + /* HW filter index starts at 1 */ + nca->bytes[6] = index + 1; + nca->bytes[7] = 0x00; + return 0; +} + +/* Find an outstanding VLAN tag and construct a "Set VLAN Filter - Enable" + * packet. + */ +static int set_one_vid(struct ncsi_dev_priv *ndp, struct ncsi_channel *nc, + struct ncsi_cmd_arg *nca) +{ + struct ncsi_channel_vlan_filter *ncf; + struct vlan_vid *vlan = NULL; + unsigned long flags; + int i, index; + void *bitmap; + u16 vid; + + if (list_empty(&ndp->vlan_vids)) + return -1; + + ncf = &nc->vlan_filter; + bitmap = &ncf->bitmap; + + spin_lock_irqsave(&nc->lock, flags); + + rcu_read_lock(); + list_for_each_entry_rcu(vlan, &ndp->vlan_vids, list) { + vid = vlan->vid; + for (i = 0; i < ncf->n_vids; i++) + if (ncf->vids[i] == vid) { + vid = 0; + break; + } + if (vid) + break; + } + rcu_read_unlock(); + + if (!vid) { + /* No VLAN ID is not set */ + spin_unlock_irqrestore(&nc->lock, flags); + return -1; + } + + index = find_first_zero_bit(bitmap, ncf->n_vids); + if (index < 0 || index >= ncf->n_vids) { + netdev_err(ndp->ndev.dev, + "Channel %u already has all VLAN filters set\n", + nc->id); + spin_unlock_irqrestore(&nc->lock, flags); + return -1; + } + + ncf->vids[index] = vid; + set_bit(index, bitmap); + spin_unlock_irqrestore(&nc->lock, flags); + + nca->type = NCSI_PKT_CMD_SVF; + nca->words[1] = vid; + /* HW filter index starts at 1 */ + nca->bytes[6] = index + 1; + nca->bytes[7] = 0x01; + + return 0; +} + +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_KEEP_PHY) + +static int ncsi_oem_keep_phy_intel(struct ncsi_cmd_arg *nca) +{ + unsigned char data[NCSI_OEM_INTEL_CMD_KEEP_PHY_LEN]; + int ret = 0; + + nca->payload = NCSI_OEM_INTEL_CMD_KEEP_PHY_LEN; + + memset(data, 0, NCSI_OEM_INTEL_CMD_KEEP_PHY_LEN); + *(unsigned int *)data = ntohl((__force __be32)NCSI_OEM_MFR_INTEL_ID); + + data[4] = NCSI_OEM_INTEL_CMD_KEEP_PHY; + + /* PHY Link up attribute */ + data[6] = 0x1; + + nca->data = data; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, 
+ "NCSI: Failed to transmit cmd 0x%x during configure\n", + nca->type); + return ret; +} + +#endif + +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) + +/* NCSI OEM Command APIs */ +static int ncsi_oem_gma_handler_bcm(struct ncsi_cmd_arg *nca) +{ + unsigned char data[NCSI_OEM_BCM_CMD_GMA_LEN]; + int ret = 0; + + nca->payload = NCSI_OEM_BCM_CMD_GMA_LEN; + + memset(data, 0, NCSI_OEM_BCM_CMD_GMA_LEN); + *(unsigned int *)data = ntohl((__force __be32)NCSI_OEM_MFR_BCM_ID); + data[5] = NCSI_OEM_BCM_CMD_GMA; + + nca->data = data; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during configure\n", + nca->type); + return ret; +} + +static int ncsi_oem_gma_handler_mlx(struct ncsi_cmd_arg *nca) +{ + union { + u8 data_u8[NCSI_OEM_MLX_CMD_GMA_LEN]; + u32 data_u32[NCSI_OEM_MLX_CMD_GMA_LEN / sizeof(u32)]; + } u; + int ret = 0; + + nca->payload = NCSI_OEM_MLX_CMD_GMA_LEN; + + memset(&u, 0, sizeof(u)); + u.data_u32[0] = ntohl((__force __be32)NCSI_OEM_MFR_MLX_ID); + u.data_u8[5] = NCSI_OEM_MLX_CMD_GMA; + u.data_u8[6] = NCSI_OEM_MLX_CMD_GMA_PARAM; + + nca->data = u.data_u8; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during configure\n", + nca->type); + return ret; +} + +static int ncsi_oem_smaf_mlx(struct ncsi_cmd_arg *nca) +{ + union { + u8 data_u8[NCSI_OEM_MLX_CMD_SMAF_LEN]; + u32 data_u32[NCSI_OEM_MLX_CMD_SMAF_LEN / sizeof(u32)]; + } u; + int ret = 0; + + memset(&u, 0, sizeof(u)); + u.data_u32[0] = ntohl((__force __be32)NCSI_OEM_MFR_MLX_ID); + u.data_u8[5] = NCSI_OEM_MLX_CMD_SMAF; + u.data_u8[6] = NCSI_OEM_MLX_CMD_SMAF_PARAM; + memcpy(&u.data_u8[MLX_SMAF_MAC_ADDR_OFFSET], + nca->ndp->ndev.dev->dev_addr, ETH_ALEN); + u.data_u8[MLX_SMAF_MED_SUPPORT_OFFSET] = + (MLX_MC_RBT_AVL | MLX_MC_RBT_SUPPORT); + + nca->payload = NCSI_OEM_MLX_CMD_SMAF_LEN; + nca->data = u.data_u8; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during probe\n", + nca->type); + return ret; +} + +static int ncsi_oem_gma_handler_intel(struct ncsi_cmd_arg *nca) +{ + unsigned char data[NCSI_OEM_INTEL_CMD_GMA_LEN]; + int ret = 0; + + nca->payload = NCSI_OEM_INTEL_CMD_GMA_LEN; + + memset(data, 0, NCSI_OEM_INTEL_CMD_GMA_LEN); + *(unsigned int *)data = ntohl((__force __be32)NCSI_OEM_MFR_INTEL_ID); + data[4] = NCSI_OEM_INTEL_CMD_GMA; + + nca->data = data; + + ret = ncsi_xmit_cmd(nca); + if (ret) + netdev_err(nca->ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during configure\n", + nca->type); + + return ret; +} + +/* OEM Command handlers initialization */ +static struct ncsi_oem_gma_handler { + unsigned int mfr_id; + int (*handler)(struct ncsi_cmd_arg *nca); +} ncsi_oem_gma_handlers[] = { + { NCSI_OEM_MFR_BCM_ID, ncsi_oem_gma_handler_bcm }, + { NCSI_OEM_MFR_MLX_ID, ncsi_oem_gma_handler_mlx }, + { NCSI_OEM_MFR_INTEL_ID, ncsi_oem_gma_handler_intel } +}; + +static int ncsi_gma_handler(struct ncsi_cmd_arg *nca, unsigned int mf_id) +{ + struct ncsi_oem_gma_handler *nch = NULL; + int i; + + /* This function should only be called once, return if flag set */ + if (nca->ndp->gma_flag == 1) + return -1; + + /* Find gma handler for given manufacturer id */ + for (i = 0; i < ARRAY_SIZE(ncsi_oem_gma_handlers); i++) { + if (ncsi_oem_gma_handlers[i].mfr_id == mf_id) { + if (ncsi_oem_gma_handlers[i].handler) + nch = &ncsi_oem_gma_handlers[i]; + break; + } + } + + if (!nch) { + netdev_err(nca->ndp->ndev.dev, + "NCSI: No GMA handler available for MFR-ID (0x%x)\n", + 
mf_id); + return -1; + } + + /* Get Mac address from NCSI device */ + return nch->handler(nca); +} + +#endif /* CONFIG_NCSI_OEM_CMD_GET_MAC */ + +/* Determine if a given channel from the channel_queue should be used for Tx */ +static bool ncsi_channel_is_tx(struct ncsi_dev_priv *ndp, + struct ncsi_channel *nc) +{ + struct ncsi_channel_mode *ncm; + struct ncsi_channel *channel; + struct ncsi_package *np; + + /* Check if any other channel has Tx enabled; a channel may have already + * been configured and removed from the channel queue. + */ + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (!ndp->multi_package && np != nc->package) + continue; + NCSI_FOR_EACH_CHANNEL(np, channel) { + ncm = &channel->modes[NCSI_MODE_TX_ENABLE]; + if (ncm->enable) + return false; + } + } + + /* This channel is the preferred channel and has link */ + list_for_each_entry_rcu(channel, &ndp->channel_queue, link) { + np = channel->package; + if (np->preferred_channel && + ncsi_channel_has_link(np->preferred_channel)) { + return np->preferred_channel == nc; + } + } + + /* This channel has link */ + if (ncsi_channel_has_link(nc)) + return true; + + list_for_each_entry_rcu(channel, &ndp->channel_queue, link) + if (ncsi_channel_has_link(channel)) + return false; + + /* No other channel has link; default to this one */ + return true; +} + +/* Change the active Tx channel in a multi-channel setup */ +int ncsi_update_tx_channel(struct ncsi_dev_priv *ndp, + struct ncsi_package *package, + struct ncsi_channel *disable, + struct ncsi_channel *enable) +{ + struct ncsi_cmd_arg nca; + struct ncsi_channel *nc; + struct ncsi_package *np; + int ret = 0; + + if (!package->multi_channel && !ndp->multi_package) + netdev_warn(ndp->ndev.dev, + "NCSI: Trying to update Tx channel in single-channel mode\n"); + nca.ndp = ndp; + nca.req_flags = 0; + + /* Find current channel with Tx enabled */ + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (disable) + break; + if (!ndp->multi_package && np != package) + continue; + + NCSI_FOR_EACH_CHANNEL(np, nc) + if (nc->modes[NCSI_MODE_TX_ENABLE].enable) { + disable = nc; + break; + } + } + + /* Find a suitable channel for Tx */ + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (enable) + break; + if (!ndp->multi_package && np != package) + continue; + if (!(ndp->package_whitelist & (0x1 << np->id))) + continue; + + if (np->preferred_channel && + ncsi_channel_has_link(np->preferred_channel)) { + enable = np->preferred_channel; + break; + } + + NCSI_FOR_EACH_CHANNEL(np, nc) { + if (!(np->channel_whitelist & 0x1 << nc->id)) + continue; + if (nc->state != NCSI_CHANNEL_ACTIVE) + continue; + if (ncsi_channel_has_link(nc)) { + enable = nc; + break; + } + } + } + + if (disable == enable) + return -1; + + if (!enable) + return -1; + + if (disable) { + nca.channel = disable->id; + nca.package = disable->package->id; + nca.type = NCSI_PKT_CMD_DCNT; + ret = ncsi_xmit_cmd(&nca); + if (ret) + netdev_err(ndp->ndev.dev, + "Error %d sending DCNT\n", + ret); + } + + netdev_info(ndp->ndev.dev, "NCSI: channel %u enables Tx\n", enable->id); + + nca.channel = enable->id; + nca.package = enable->package->id; + nca.type = NCSI_PKT_CMD_ECNT; + ret = ncsi_xmit_cmd(&nca); + if (ret) + netdev_err(ndp->ndev.dev, + "Error %d sending ECNT\n", + ret); + + return ret; +} + +static void ncsi_configure_channel(struct ncsi_dev_priv *ndp) +{ + struct ncsi_package *np = ndp->active_package; + struct ncsi_channel *nc = ndp->active_channel; + struct ncsi_channel *hot_nc = NULL; + struct ncsi_dev *nd = &ndp->ndev; + struct net_device *dev = nd->dev; + struct 
ncsi_cmd_arg nca; + unsigned char index; + unsigned long flags; + int ret; + + nca.ndp = ndp; + nca.req_flags = NCSI_REQ_FLAG_EVENT_DRIVEN; + switch (nd->state) { + case ncsi_dev_state_config: + case ncsi_dev_state_config_sp: + ndp->pending_req_num = 1; + + /* Select the specific package */ + nca.type = NCSI_PKT_CMD_SP; + if (ndp->flags & NCSI_DEV_HWA) + nca.bytes[0] = 0; + else + nca.bytes[0] = 1; + nca.package = np->id; + nca.channel = NCSI_RESERVED_CHANNEL; + ret = ncsi_xmit_cmd(&nca); + if (ret) { + netdev_err(ndp->ndev.dev, + "NCSI: Failed to transmit CMD_SP\n"); + goto error; + } + + nd->state = ncsi_dev_state_config_cis; + break; + case ncsi_dev_state_config_cis: + ndp->pending_req_num = 1; + + /* Clear initial state */ + nca.type = NCSI_PKT_CMD_CIS; + nca.package = np->id; + nca.channel = nc->id; + ret = ncsi_xmit_cmd(&nca); + if (ret) { + netdev_err(ndp->ndev.dev, + "NCSI: Failed to transmit CMD_CIS\n"); + goto error; + } + + nd->state = ncsi_dev_state_config_oem_gma; + break; + case ncsi_dev_state_config_oem_gma: + nd->state = ncsi_dev_state_config_clear_vids; + ret = -1; + +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) + nca.type = NCSI_PKT_CMD_OEM; + nca.package = np->id; + nca.channel = nc->id; + ndp->pending_req_num = 1; + ret = ncsi_gma_handler(&nca, nc->version.mf_id); +#endif /* CONFIG_NCSI_OEM_CMD_GET_MAC */ + + if (ret < 0) + schedule_work(&ndp->work); + + break; + case ncsi_dev_state_config_clear_vids: + case ncsi_dev_state_config_svf: + case ncsi_dev_state_config_ev: + case ncsi_dev_state_config_sma: + case ncsi_dev_state_config_ebf: + case ncsi_dev_state_config_dgmf: + case ncsi_dev_state_config_ecnt: + case ncsi_dev_state_config_ec: + case ncsi_dev_state_config_ae: + case ncsi_dev_state_config_gls: + ndp->pending_req_num = 1; + + nca.package = np->id; + nca.channel = nc->id; + + /* Clear any active filters on the channel before setting */ + if (nd->state == ncsi_dev_state_config_clear_vids) { + ret = clear_one_vid(ndp, nc, &nca); + if (ret) { + nd->state = ncsi_dev_state_config_svf; + schedule_work(&ndp->work); + break; + } + /* Repeat */ + nd->state = ncsi_dev_state_config_clear_vids; + /* Add known VLAN tags to the filter */ + } else if (nd->state == ncsi_dev_state_config_svf) { + ret = set_one_vid(ndp, nc, &nca); + if (ret) { + nd->state = ncsi_dev_state_config_ev; + schedule_work(&ndp->work); + break; + } + /* Repeat */ + nd->state = ncsi_dev_state_config_svf; + /* Enable/Disable the VLAN filter */ + } else if (nd->state == ncsi_dev_state_config_ev) { + if (list_empty(&ndp->vlan_vids)) { + nca.type = NCSI_PKT_CMD_DV; + } else { + nca.type = NCSI_PKT_CMD_EV; + nca.bytes[3] = NCSI_CAP_VLAN_NO; + } + nd->state = ncsi_dev_state_config_sma; + } else if (nd->state == ncsi_dev_state_config_sma) { + /* Use first entry in unicast filter table. Note that + * the MAC filter table starts from entry 1 instead of + * 0. 
+ */ + nca.type = NCSI_PKT_CMD_SMA; + for (index = 0; index < 6; index++) + nca.bytes[index] = dev->dev_addr[index]; + nca.bytes[6] = 0x1; + nca.bytes[7] = 0x1; + nd->state = ncsi_dev_state_config_ebf; + } else if (nd->state == ncsi_dev_state_config_ebf) { + nca.type = NCSI_PKT_CMD_EBF; + nca.dwords[0] = nc->caps[NCSI_CAP_BC].cap; + /* if multicast global filtering is supported then + * disable it so that all multicast packet will be + * forwarded to management controller + */ + if (nc->caps[NCSI_CAP_GENERIC].cap & + NCSI_CAP_GENERIC_MC) + nd->state = ncsi_dev_state_config_dgmf; + else if (ncsi_channel_is_tx(ndp, nc)) + nd->state = ncsi_dev_state_config_ecnt; + else + nd->state = ncsi_dev_state_config_ec; + } else if (nd->state == ncsi_dev_state_config_dgmf) { + nca.type = NCSI_PKT_CMD_DGMF; + if (ncsi_channel_is_tx(ndp, nc)) + nd->state = ncsi_dev_state_config_ecnt; + else + nd->state = ncsi_dev_state_config_ec; + } else if (nd->state == ncsi_dev_state_config_ecnt) { + if (np->preferred_channel && + nc != np->preferred_channel) + netdev_info(ndp->ndev.dev, + "NCSI: Tx failed over to channel %u\n", + nc->id); + nca.type = NCSI_PKT_CMD_ECNT; + nd->state = ncsi_dev_state_config_ec; + } else if (nd->state == ncsi_dev_state_config_ec) { + /* Enable AEN if it's supported */ + nca.type = NCSI_PKT_CMD_EC; + nd->state = ncsi_dev_state_config_ae; + if (!(nc->caps[NCSI_CAP_AEN].cap & NCSI_CAP_AEN_MASK)) + nd->state = ncsi_dev_state_config_gls; + } else if (nd->state == ncsi_dev_state_config_ae) { + nca.type = NCSI_PKT_CMD_AE; + nca.bytes[0] = 0; + nca.dwords[1] = nc->caps[NCSI_CAP_AEN].cap; + nd->state = ncsi_dev_state_config_gls; + } else if (nd->state == ncsi_dev_state_config_gls) { + nca.type = NCSI_PKT_CMD_GLS; + nd->state = ncsi_dev_state_config_done; + } + + ret = ncsi_xmit_cmd(&nca); + if (ret) { + netdev_err(ndp->ndev.dev, + "NCSI: Failed to transmit CMD %x\n", + nca.type); + goto error; + } + break; + case ncsi_dev_state_config_done: + netdev_dbg(ndp->ndev.dev, "NCSI: channel %u config done\n", + nc->id); + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_ACTIVE; + + if (ndp->flags & NCSI_DEV_RESET) { + /* A reset event happened during config, start it now */ + nc->reconfigure_needed = false; + spin_unlock_irqrestore(&nc->lock, flags); + ncsi_reset_dev(nd); + break; + } + + if (nc->reconfigure_needed) { + /* This channel's configuration has been updated + * part-way during the config state - start the + * channel configuration over + */ + nc->reconfigure_needed = false; + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + + spin_lock_irqsave(&ndp->lock, flags); + list_add_tail_rcu(&nc->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + + netdev_dbg(dev, "Dirty NCSI channel state reset\n"); + ncsi_process_next_channel(ndp); + break; + } + + if (nc->modes[NCSI_MODE_LINK].data[2] & 0x1) { + hot_nc = nc; + } else { + hot_nc = NULL; + netdev_dbg(ndp->ndev.dev, + "NCSI: channel %u link down after config\n", + nc->id); + } + spin_unlock_irqrestore(&nc->lock, flags); + + /* Update the hot channel */ + spin_lock_irqsave(&ndp->lock, flags); + ndp->hot_channel = hot_nc; + spin_unlock_irqrestore(&ndp->lock, flags); + + ncsi_start_channel_monitor(nc); + ncsi_process_next_channel(ndp); + break; + default: + netdev_alert(dev, "Wrong NCSI state 0x%x in config\n", + nd->state); + } + + return; + +error: + ncsi_report_link(ndp, true); +} + +static int ncsi_choose_active_channel(struct ncsi_dev_priv *ndp) +{ + struct ncsi_channel *nc, 
*found, *hot_nc; + struct ncsi_channel_mode *ncm; + unsigned long flags, cflags; + struct ncsi_package *np; + bool with_link; + + spin_lock_irqsave(&ndp->lock, flags); + hot_nc = ndp->hot_channel; + spin_unlock_irqrestore(&ndp->lock, flags); + + /* By default the search is done once an inactive channel with up + * link is found, unless a preferred channel is set. + * If multi_package or multi_channel are configured all channels in the + * whitelist are added to the channel queue. + */ + found = NULL; + with_link = false; + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (!(ndp->package_whitelist & (0x1 << np->id))) + continue; + NCSI_FOR_EACH_CHANNEL(np, nc) { + if (!(np->channel_whitelist & (0x1 << nc->id))) + continue; + + spin_lock_irqsave(&nc->lock, cflags); + + if (!list_empty(&nc->link) || + nc->state != NCSI_CHANNEL_INACTIVE) { + spin_unlock_irqrestore(&nc->lock, cflags); + continue; + } + + if (!found) + found = nc; + + if (nc == hot_nc) + found = nc; + + ncm = &nc->modes[NCSI_MODE_LINK]; + if (ncm->data[2] & 0x1) { + found = nc; + with_link = true; + } + + /* If multi_channel is enabled configure all valid + * channels whether or not they currently have link + * so they will have AENs enabled. + */ + if (with_link || np->multi_channel) { + spin_lock_irqsave(&ndp->lock, flags); + list_add_tail_rcu(&nc->link, + &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + + netdev_dbg(ndp->ndev.dev, + "NCSI: Channel %u added to queue (link %s)\n", + nc->id, + ncm->data[2] & 0x1 ? "up" : "down"); + } + + spin_unlock_irqrestore(&nc->lock, cflags); + + if (with_link && !np->multi_channel) + break; + } + if (with_link && !ndp->multi_package) + break; + } + + if (list_empty(&ndp->channel_queue) && found) { + netdev_info(ndp->ndev.dev, + "NCSI: No channel with link found, configuring channel %u\n", + found->id); + spin_lock_irqsave(&ndp->lock, flags); + list_add_tail_rcu(&found->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + } else if (!found) { + netdev_warn(ndp->ndev.dev, + "NCSI: No channel found to configure!\n"); + ncsi_report_link(ndp, true); + return -ENODEV; + } + + return ncsi_process_next_channel(ndp); +} + +static bool ncsi_check_hwa(struct ncsi_dev_priv *ndp) +{ + struct ncsi_package *np; + struct ncsi_channel *nc; + unsigned int cap; + bool has_channel = false; + + /* The hardware arbitration is disabled if any one channel + * doesn't support explicitly. 
+ */ + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, nc) { + has_channel = true; + + cap = nc->caps[NCSI_CAP_GENERIC].cap; + if (!(cap & NCSI_CAP_GENERIC_HWA) || + (cap & NCSI_CAP_GENERIC_HWA_MASK) != + NCSI_CAP_GENERIC_HWA_SUPPORT) { + ndp->flags &= ~NCSI_DEV_HWA; + return false; + } + } + } + + if (has_channel) { + ndp->flags |= NCSI_DEV_HWA; + return true; + } + + ndp->flags &= ~NCSI_DEV_HWA; + return false; +} + +static void ncsi_probe_channel(struct ncsi_dev_priv *ndp) +{ + struct ncsi_dev *nd = &ndp->ndev; + struct ncsi_package *np; + struct ncsi_channel *nc; + struct ncsi_cmd_arg nca; + unsigned char index; + int ret; + + nca.ndp = ndp; + nca.req_flags = NCSI_REQ_FLAG_EVENT_DRIVEN; + switch (nd->state) { + case ncsi_dev_state_probe: + nd->state = ncsi_dev_state_probe_deselect; + fallthrough; + case ncsi_dev_state_probe_deselect: + ndp->pending_req_num = 8; + + /* Deselect all possible packages */ + nca.type = NCSI_PKT_CMD_DP; + nca.channel = NCSI_RESERVED_CHANNEL; + for (index = 0; index < 8; index++) { + nca.package = index; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + } + + nd->state = ncsi_dev_state_probe_package; + break; + case ncsi_dev_state_probe_package: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_SP; + nca.bytes[0] = 1; + nca.package = ndp->package_probe_id; + nca.channel = NCSI_RESERVED_CHANNEL; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + nd->state = ncsi_dev_state_probe_channel; + break; + case ncsi_dev_state_probe_channel: + ndp->active_package = ncsi_find_package(ndp, + ndp->package_probe_id); + if (!ndp->active_package) { + /* No response */ + nd->state = ncsi_dev_state_probe_dp; + schedule_work(&ndp->work); + break; + } + nd->state = ncsi_dev_state_probe_cis; + if (IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) && + ndp->mlx_multi_host) + nd->state = ncsi_dev_state_probe_mlx_gma; + + schedule_work(&ndp->work); + break; +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_GET_MAC) + case ncsi_dev_state_probe_mlx_gma: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_OEM; + nca.package = ndp->active_package->id; + nca.channel = 0; + ret = ncsi_oem_gma_handler_mlx(&nca); + if (ret) + goto error; + + nd->state = ncsi_dev_state_probe_mlx_smaf; + break; + case ncsi_dev_state_probe_mlx_smaf: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_OEM; + nca.package = ndp->active_package->id; + nca.channel = 0; + ret = ncsi_oem_smaf_mlx(&nca); + if (ret) + goto error; + + nd->state = ncsi_dev_state_probe_cis; + break; +#endif /* CONFIG_NCSI_OEM_CMD_GET_MAC */ + case ncsi_dev_state_probe_cis: + ndp->pending_req_num = NCSI_RESERVED_CHANNEL; + + /* Clear initial state */ + nca.type = NCSI_PKT_CMD_CIS; + nca.package = ndp->active_package->id; + for (index = 0; index < NCSI_RESERVED_CHANNEL; index++) { + nca.channel = index; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + } + + nd->state = ncsi_dev_state_probe_gvi; + if (IS_ENABLED(CONFIG_NCSI_OEM_CMD_KEEP_PHY)) + nd->state = ncsi_dev_state_probe_keep_phy; + break; +#if IS_ENABLED(CONFIG_NCSI_OEM_CMD_KEEP_PHY) + case ncsi_dev_state_probe_keep_phy: + ndp->pending_req_num = 1; + + nca.type = NCSI_PKT_CMD_OEM; + nca.package = ndp->active_package->id; + nca.channel = 0; + ret = ncsi_oem_keep_phy_intel(&nca); + if (ret) + goto error; + + nd->state = ncsi_dev_state_probe_gvi; + break; +#endif /* CONFIG_NCSI_OEM_CMD_KEEP_PHY */ + case ncsi_dev_state_probe_gvi: + case ncsi_dev_state_probe_gc: + case ncsi_dev_state_probe_gls: + np = ndp->active_package; + ndp->pending_req_num = np->channel_num; + + 
/* Retrieve version, capability or link status */ + if (nd->state == ncsi_dev_state_probe_gvi) + nca.type = NCSI_PKT_CMD_GVI; + else if (nd->state == ncsi_dev_state_probe_gc) + nca.type = NCSI_PKT_CMD_GC; + else + nca.type = NCSI_PKT_CMD_GLS; + + nca.package = np->id; + NCSI_FOR_EACH_CHANNEL(np, nc) { + nca.channel = nc->id; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + } + + if (nd->state == ncsi_dev_state_probe_gvi) + nd->state = ncsi_dev_state_probe_gc; + else if (nd->state == ncsi_dev_state_probe_gc) + nd->state = ncsi_dev_state_probe_gls; + else + nd->state = ncsi_dev_state_probe_dp; + break; + case ncsi_dev_state_probe_dp: + ndp->pending_req_num = 1; + + /* Deselect the current package */ + nca.type = NCSI_PKT_CMD_DP; + nca.package = ndp->package_probe_id; + nca.channel = NCSI_RESERVED_CHANNEL; + ret = ncsi_xmit_cmd(&nca); + if (ret) + goto error; + + /* Probe next package */ + ndp->package_probe_id++; + if (ndp->package_probe_id >= 8) { + /* Probe finished */ + ndp->flags |= NCSI_DEV_PROBED; + break; + } + nd->state = ncsi_dev_state_probe_package; + ndp->active_package = NULL; + break; + default: + netdev_warn(nd->dev, "Wrong NCSI state 0x%0x in enumeration\n", + nd->state); + } + + if (ndp->flags & NCSI_DEV_PROBED) { + /* Check if all packages have HWA support */ + ncsi_check_hwa(ndp); + ncsi_choose_active_channel(ndp); + } + + return; +error: + netdev_err(ndp->ndev.dev, + "NCSI: Failed to transmit cmd 0x%x during probe\n", + nca.type); + ncsi_report_link(ndp, true); +} + +static void ncsi_dev_work(struct work_struct *work) +{ + struct ncsi_dev_priv *ndp = container_of(work, + struct ncsi_dev_priv, work); + struct ncsi_dev *nd = &ndp->ndev; + + switch (nd->state & ncsi_dev_state_major) { + case ncsi_dev_state_probe: + ncsi_probe_channel(ndp); + break; + case ncsi_dev_state_suspend: + ncsi_suspend_channel(ndp); + break; + case ncsi_dev_state_config: + ncsi_configure_channel(ndp); + break; + default: + netdev_warn(nd->dev, "Wrong NCSI state 0x%x in workqueue\n", + nd->state); + } +} + +int ncsi_process_next_channel(struct ncsi_dev_priv *ndp) +{ + struct ncsi_channel *nc; + int old_state; + unsigned long flags; + + spin_lock_irqsave(&ndp->lock, flags); + nc = list_first_or_null_rcu(&ndp->channel_queue, + struct ncsi_channel, link); + if (!nc) { + spin_unlock_irqrestore(&ndp->lock, flags); + goto out; + } + + list_del_init(&nc->link); + spin_unlock_irqrestore(&ndp->lock, flags); + + spin_lock_irqsave(&nc->lock, flags); + old_state = nc->state; + nc->state = NCSI_CHANNEL_INVISIBLE; + spin_unlock_irqrestore(&nc->lock, flags); + + ndp->active_channel = nc; + ndp->active_package = nc->package; + + switch (old_state) { + case NCSI_CHANNEL_INACTIVE: + ndp->ndev.state = ncsi_dev_state_config; + netdev_dbg(ndp->ndev.dev, "NCSI: configuring channel %u\n", + nc->id); + ncsi_configure_channel(ndp); + break; + case NCSI_CHANNEL_ACTIVE: + ndp->ndev.state = ncsi_dev_state_suspend; + netdev_dbg(ndp->ndev.dev, "NCSI: suspending channel %u\n", + nc->id); + ncsi_suspend_channel(ndp); + break; + default: + netdev_err(ndp->ndev.dev, "Invalid state 0x%x on %d:%d\n", + old_state, nc->package->id, nc->id); + ncsi_report_link(ndp, false); + return -EINVAL; + } + + return 0; + +out: + ndp->active_channel = NULL; + ndp->active_package = NULL; + if (ndp->flags & NCSI_DEV_RESHUFFLE) { + ndp->flags &= ~NCSI_DEV_RESHUFFLE; + return ncsi_choose_active_channel(ndp); + } + + ncsi_report_link(ndp, false); + return -ENODEV; +} + +static int ncsi_kick_channels(struct ncsi_dev_priv *ndp) +{ + struct ncsi_dev *nd 
= &ndp->ndev; + struct ncsi_channel *nc; + struct ncsi_package *np; + unsigned long flags; + unsigned int n = 0; + + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, nc) { + spin_lock_irqsave(&nc->lock, flags); + + /* Channels may be busy, mark dirty instead of + * kicking if; + * a) not ACTIVE (configured) + * b) in the channel_queue (to be configured) + * c) it's ndev is in the config state + */ + if (nc->state != NCSI_CHANNEL_ACTIVE) { + if ((ndp->ndev.state & 0xff00) == + ncsi_dev_state_config || + !list_empty(&nc->link)) { + netdev_dbg(nd->dev, + "NCSI: channel %p marked dirty\n", + nc); + nc->reconfigure_needed = true; + } + spin_unlock_irqrestore(&nc->lock, flags); + continue; + } + + spin_unlock_irqrestore(&nc->lock, flags); + + ncsi_stop_channel_monitor(nc); + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + + spin_lock_irqsave(&ndp->lock, flags); + list_add_tail_rcu(&nc->link, &ndp->channel_queue); + spin_unlock_irqrestore(&ndp->lock, flags); + + netdev_dbg(nd->dev, "NCSI: kicked channel %p\n", nc); + n++; + } + } + + return n; +} + +int ncsi_vlan_rx_add_vid(struct net_device *dev, __be16 proto, u16 vid) +{ + struct ncsi_dev_priv *ndp; + unsigned int n_vids = 0; + struct vlan_vid *vlan; + struct ncsi_dev *nd; + bool found = false; + + if (vid == 0) + return 0; + + nd = ncsi_find_dev(dev); + if (!nd) { + netdev_warn(dev, "NCSI: No net_device?\n"); + return 0; + } + + ndp = TO_NCSI_DEV_PRIV(nd); + + /* Add the VLAN id to our internal list */ + list_for_each_entry_rcu(vlan, &ndp->vlan_vids, list) { + n_vids++; + if (vlan->vid == vid) { + netdev_dbg(dev, "NCSI: vid %u already registered\n", + vid); + return 0; + } + } + if (n_vids >= NCSI_MAX_VLAN_VIDS) { + netdev_warn(dev, + "tried to add vlan id %u but NCSI max already registered (%u)\n", + vid, NCSI_MAX_VLAN_VIDS); + return -ENOSPC; + } + + vlan = kzalloc(sizeof(*vlan), GFP_KERNEL); + if (!vlan) + return -ENOMEM; + + vlan->proto = proto; + vlan->vid = vid; + list_add_rcu(&vlan->list, &ndp->vlan_vids); + + netdev_dbg(dev, "NCSI: Added new vid %u\n", vid); + + found = ncsi_kick_channels(ndp) != 0; + + return found ? ncsi_process_next_channel(ndp) : 0; +} +EXPORT_SYMBOL_GPL(ncsi_vlan_rx_add_vid); + +int ncsi_vlan_rx_kill_vid(struct net_device *dev, __be16 proto, u16 vid) +{ + struct vlan_vid *vlan, *tmp; + struct ncsi_dev_priv *ndp; + struct ncsi_dev *nd; + bool found = false; + + if (vid == 0) + return 0; + + nd = ncsi_find_dev(dev); + if (!nd) { + netdev_warn(dev, "NCSI: no net_device?\n"); + return 0; + } + + ndp = TO_NCSI_DEV_PRIV(nd); + + /* Remove the VLAN id from our internal list */ + list_for_each_entry_safe(vlan, tmp, &ndp->vlan_vids, list) + if (vlan->vid == vid) { + netdev_dbg(dev, "NCSI: vid %u found, removing\n", vid); + list_del_rcu(&vlan->list); + found = true; + kfree(vlan); + } + + if (!found) { + netdev_err(dev, "NCSI: vid %u wasn't registered!\n", vid); + return -EINVAL; + } + + found = ncsi_kick_channels(ndp) != 0; + + return found ? 
ncsi_process_next_channel(ndp) : 0; +} +EXPORT_SYMBOL_GPL(ncsi_vlan_rx_kill_vid); + +struct ncsi_dev *ncsi_register_dev(struct net_device *dev, + void (*handler)(struct ncsi_dev *ndev)) +{ + struct ncsi_dev_priv *ndp; + struct ncsi_dev *nd; + struct platform_device *pdev; + struct device_node *np; + unsigned long flags; + int i; + + /* Check if the device has been registered or not */ + nd = ncsi_find_dev(dev); + if (nd) + return nd; + + /* Create NCSI device */ + ndp = kzalloc(sizeof(*ndp), GFP_ATOMIC); + if (!ndp) + return NULL; + + nd = &ndp->ndev; + nd->state = ncsi_dev_state_registered; + nd->dev = dev; + nd->handler = handler; + ndp->pending_req_num = 0; + INIT_LIST_HEAD(&ndp->channel_queue); + INIT_LIST_HEAD(&ndp->vlan_vids); + INIT_WORK(&ndp->work, ncsi_dev_work); + ndp->package_whitelist = UINT_MAX; + + /* Initialize private NCSI device */ + spin_lock_init(&ndp->lock); + INIT_LIST_HEAD(&ndp->packages); + ndp->request_id = NCSI_REQ_START_IDX; + for (i = 0; i < ARRAY_SIZE(ndp->requests); i++) { + ndp->requests[i].id = i; + ndp->requests[i].ndp = ndp; + timer_setup(&ndp->requests[i].timer, ncsi_request_timeout, 0); + } + + spin_lock_irqsave(&ncsi_dev_lock, flags); + list_add_tail_rcu(&ndp->node, &ncsi_dev_list); + spin_unlock_irqrestore(&ncsi_dev_lock, flags); + + /* Register NCSI packet Rx handler */ + ndp->ptype.type = cpu_to_be16(ETH_P_NCSI); + ndp->ptype.func = ncsi_rcv_rsp; + ndp->ptype.dev = dev; + dev_add_pack(&ndp->ptype); + + pdev = to_platform_device(dev->dev.parent); + if (pdev) { + np = pdev->dev.of_node; + if (np && (of_get_property(np, "mellanox,multi-host", NULL) || + of_get_property(np, "mlx,multi-host", NULL))) + ndp->mlx_multi_host = true; + } + + return nd; +} +EXPORT_SYMBOL_GPL(ncsi_register_dev); + +int ncsi_start_dev(struct ncsi_dev *nd) +{ + struct ncsi_dev_priv *ndp = TO_NCSI_DEV_PRIV(nd); + + if (nd->state != ncsi_dev_state_registered && + nd->state != ncsi_dev_state_functional) + return -ENOTTY; + + if (!(ndp->flags & NCSI_DEV_PROBED)) { + ndp->package_probe_id = 0; + nd->state = ncsi_dev_state_probe; + schedule_work(&ndp->work); + return 0; + } + + return ncsi_reset_dev(nd); +} +EXPORT_SYMBOL_GPL(ncsi_start_dev); + +void ncsi_stop_dev(struct ncsi_dev *nd) +{ + struct ncsi_dev_priv *ndp = TO_NCSI_DEV_PRIV(nd); + struct ncsi_package *np; + struct ncsi_channel *nc; + bool chained; + int old_state; + unsigned long flags; + + /* Stop the channel monitor on any active channels. Don't reset the + * channel state so we know which were active when ncsi_start_dev() + * is next called. 
+ */ + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, nc) { + ncsi_stop_channel_monitor(nc); + + spin_lock_irqsave(&nc->lock, flags); + chained = !list_empty(&nc->link); + old_state = nc->state; + spin_unlock_irqrestore(&nc->lock, flags); + + WARN_ON_ONCE(chained || + old_state == NCSI_CHANNEL_INVISIBLE); + } + } + + netdev_dbg(ndp->ndev.dev, "NCSI: Stopping device\n"); + ncsi_report_link(ndp, true); +} +EXPORT_SYMBOL_GPL(ncsi_stop_dev); + +int ncsi_reset_dev(struct ncsi_dev *nd) +{ + struct ncsi_dev_priv *ndp = TO_NCSI_DEV_PRIV(nd); + struct ncsi_channel *nc, *active, *tmp; + struct ncsi_package *np; + unsigned long flags; + + spin_lock_irqsave(&ndp->lock, flags); + + if (!(ndp->flags & NCSI_DEV_RESET)) { + /* Haven't been called yet, check states */ + switch (nd->state & ncsi_dev_state_major) { + case ncsi_dev_state_registered: + case ncsi_dev_state_probe: + /* Not even probed yet - do nothing */ + spin_unlock_irqrestore(&ndp->lock, flags); + return 0; + case ncsi_dev_state_suspend: + case ncsi_dev_state_config: + /* Wait for the channel to finish its suspend/config + * operation; once it finishes it will check for + * NCSI_DEV_RESET and reset the state. + */ + ndp->flags |= NCSI_DEV_RESET; + spin_unlock_irqrestore(&ndp->lock, flags); + return 0; + } + } else { + switch (nd->state) { + case ncsi_dev_state_suspend_done: + case ncsi_dev_state_config_done: + case ncsi_dev_state_functional: + /* Ok */ + break; + default: + /* Current reset operation happening */ + spin_unlock_irqrestore(&ndp->lock, flags); + return 0; + } + } + + if (!list_empty(&ndp->channel_queue)) { + /* Clear any channel queue we may have interrupted */ + list_for_each_entry_safe(nc, tmp, &ndp->channel_queue, link) + list_del_init(&nc->link); + } + spin_unlock_irqrestore(&ndp->lock, flags); + + active = NULL; + NCSI_FOR_EACH_PACKAGE(ndp, np) { + NCSI_FOR_EACH_CHANNEL(np, nc) { + spin_lock_irqsave(&nc->lock, flags); + + if (nc->state == NCSI_CHANNEL_ACTIVE) { + active = nc; + nc->state = NCSI_CHANNEL_INVISIBLE; + spin_unlock_irqrestore(&nc->lock, flags); + ncsi_stop_channel_monitor(nc); + break; + } + + spin_unlock_irqrestore(&nc->lock, flags); + } + if (active) + break; + } + + if (!active) { + /* Done */ + spin_lock_irqsave(&ndp->lock, flags); + ndp->flags &= ~NCSI_DEV_RESET; + spin_unlock_irqrestore(&ndp->lock, flags); + return ncsi_choose_active_channel(ndp); + } + + spin_lock_irqsave(&ndp->lock, flags); + ndp->flags |= NCSI_DEV_RESET; + ndp->active_channel = active; + ndp->active_package = active->package; + spin_unlock_irqrestore(&ndp->lock, flags); + + nd->state = ncsi_dev_state_suspend; + schedule_work(&ndp->work); + return 0; +} + +void ncsi_unregister_dev(struct ncsi_dev *nd) +{ + struct ncsi_dev_priv *ndp = TO_NCSI_DEV_PRIV(nd); + struct ncsi_package *np, *tmp; + unsigned long flags; + + dev_remove_pack(&ndp->ptype); + + list_for_each_entry_safe(np, tmp, &ndp->packages, node) + ncsi_remove_package(np); + + spin_lock_irqsave(&ncsi_dev_lock, flags); + list_del_rcu(&ndp->node); + spin_unlock_irqrestore(&ncsi_dev_lock, flags); + + kfree(ndp); +} +EXPORT_SYMBOL_GPL(ncsi_unregister_dev); diff --git a/net/ncsi/ncsi-netlink.c b/net/ncsi/ncsi-netlink.c new file mode 100644 index 000000000..fe681680b --- /dev/null +++ b/net/ncsi/ncsi-netlink.c @@ -0,0 +1,778 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright Samuel Mendoza-Jonas, IBM Corporation 2018. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "internal.h" +#include "ncsi-pkt.h" +#include "ncsi-netlink.h" + +static struct genl_family ncsi_genl_family; + +static const struct nla_policy ncsi_genl_policy[NCSI_ATTR_MAX + 1] = { + [NCSI_ATTR_IFINDEX] = { .type = NLA_U32 }, + [NCSI_ATTR_PACKAGE_LIST] = { .type = NLA_NESTED }, + [NCSI_ATTR_PACKAGE_ID] = { .type = NLA_U32 }, + [NCSI_ATTR_CHANNEL_ID] = { .type = NLA_U32 }, + [NCSI_ATTR_DATA] = { .type = NLA_BINARY, .len = 2048 }, + [NCSI_ATTR_MULTI_FLAG] = { .type = NLA_FLAG }, + [NCSI_ATTR_PACKAGE_MASK] = { .type = NLA_U32 }, + [NCSI_ATTR_CHANNEL_MASK] = { .type = NLA_U32 }, +}; + +static struct ncsi_dev_priv *ndp_from_ifindex(struct net *net, u32 ifindex) +{ + struct ncsi_dev_priv *ndp; + struct net_device *dev; + struct ncsi_dev *nd; + struct ncsi_dev; + + if (!net) + return NULL; + + dev = dev_get_by_index(net, ifindex); + if (!dev) { + pr_err("NCSI netlink: No device for ifindex %u\n", ifindex); + return NULL; + } + + nd = ncsi_find_dev(dev); + ndp = nd ? TO_NCSI_DEV_PRIV(nd) : NULL; + + dev_put(dev); + return ndp; +} + +static int ncsi_write_channel_info(struct sk_buff *skb, + struct ncsi_dev_priv *ndp, + struct ncsi_channel *nc) +{ + struct ncsi_channel_vlan_filter *ncf; + struct ncsi_channel_mode *m; + struct nlattr *vid_nest; + int i; + + nla_put_u32(skb, NCSI_CHANNEL_ATTR_ID, nc->id); + m = &nc->modes[NCSI_MODE_LINK]; + nla_put_u32(skb, NCSI_CHANNEL_ATTR_LINK_STATE, m->data[2]); + if (nc->state == NCSI_CHANNEL_ACTIVE) + nla_put_flag(skb, NCSI_CHANNEL_ATTR_ACTIVE); + if (nc == nc->package->preferred_channel) + nla_put_flag(skb, NCSI_CHANNEL_ATTR_FORCED); + + nla_put_u32(skb, NCSI_CHANNEL_ATTR_VERSION_MAJOR, nc->version.major); + nla_put_u32(skb, NCSI_CHANNEL_ATTR_VERSION_MINOR, nc->version.minor); + nla_put_string(skb, NCSI_CHANNEL_ATTR_VERSION_STR, nc->version.fw_name); + + vid_nest = nla_nest_start_noflag(skb, NCSI_CHANNEL_ATTR_VLAN_LIST); + if (!vid_nest) + return -ENOMEM; + ncf = &nc->vlan_filter; + i = -1; + while ((i = find_next_bit((void *)&ncf->bitmap, ncf->n_vids, + i + 1)) < ncf->n_vids) { + if (ncf->vids[i]) + nla_put_u16(skb, NCSI_CHANNEL_ATTR_VLAN_ID, + ncf->vids[i]); + } + nla_nest_end(skb, vid_nest); + + return 0; +} + +static int ncsi_write_package_info(struct sk_buff *skb, + struct ncsi_dev_priv *ndp, unsigned int id) +{ + struct nlattr *pnest, *cnest, *nest; + struct ncsi_package *np; + struct ncsi_channel *nc; + bool found; + int rc; + + if (id > ndp->package_num - 1) { + netdev_info(ndp->ndev.dev, "NCSI: No package with id %u\n", id); + return -ENODEV; + } + + found = false; + NCSI_FOR_EACH_PACKAGE(ndp, np) { + if (np->id != id) + continue; + pnest = nla_nest_start_noflag(skb, NCSI_PKG_ATTR); + if (!pnest) + return -ENOMEM; + rc = nla_put_u32(skb, NCSI_PKG_ATTR_ID, np->id); + if (rc) { + nla_nest_cancel(skb, pnest); + return rc; + } + if ((0x1 << np->id) == ndp->package_whitelist) + nla_put_flag(skb, NCSI_PKG_ATTR_FORCED); + cnest = nla_nest_start_noflag(skb, NCSI_PKG_ATTR_CHANNEL_LIST); + if (!cnest) { + nla_nest_cancel(skb, pnest); + return -ENOMEM; + } + NCSI_FOR_EACH_CHANNEL(np, nc) { + nest = nla_nest_start_noflag(skb, NCSI_CHANNEL_ATTR); + if (!nest) { + nla_nest_cancel(skb, cnest); + nla_nest_cancel(skb, pnest); + return -ENOMEM; + } + rc = ncsi_write_channel_info(skb, ndp, nc); + if (rc) { + nla_nest_cancel(skb, nest); + nla_nest_cancel(skb, cnest); + nla_nest_cancel(skb, pnest); + return rc; + } + nla_nest_end(skb, nest); + } + 
nla_nest_end(skb, cnest); + nla_nest_end(skb, pnest); + found = true; + } + + if (!found) + return -ENODEV; + + return 0; +} + +static int ncsi_pkg_info_nl(struct sk_buff *msg, struct genl_info *info) +{ + struct ncsi_dev_priv *ndp; + unsigned int package_id; + struct sk_buff *skb; + struct nlattr *attr; + void *hdr; + int rc; + + if (!info || !info->attrs) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_PACKAGE_ID]) + return -EINVAL; + + ndp = ndp_from_ifindex(genl_info_net(info), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) + return -ENODEV; + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, info->snd_portid, info->snd_seq, + &ncsi_genl_family, 0, NCSI_CMD_PKG_INFO); + if (!hdr) { + kfree_skb(skb); + return -EMSGSIZE; + } + + package_id = nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_ID]); + + attr = nla_nest_start_noflag(skb, NCSI_ATTR_PACKAGE_LIST); + if (!attr) { + kfree_skb(skb); + return -EMSGSIZE; + } + rc = ncsi_write_package_info(skb, ndp, package_id); + + if (rc) { + nla_nest_cancel(skb, attr); + goto err; + } + + nla_nest_end(skb, attr); + + genlmsg_end(skb, hdr); + return genlmsg_reply(skb, info); + +err: + kfree_skb(skb); + return rc; +} + +static int ncsi_pkg_info_all_nl(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct nlattr *attrs[NCSI_ATTR_MAX + 1]; + struct ncsi_package *np, *package; + struct ncsi_dev_priv *ndp; + unsigned int package_id; + struct nlattr *attr; + void *hdr; + int rc; + + rc = genlmsg_parse_deprecated(cb->nlh, &ncsi_genl_family, attrs, NCSI_ATTR_MAX, + ncsi_genl_policy, NULL); + if (rc) + return rc; + + if (!attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + ndp = ndp_from_ifindex(get_net(sock_net(skb->sk)), + nla_get_u32(attrs[NCSI_ATTR_IFINDEX])); + + if (!ndp) + return -ENODEV; + + package_id = cb->args[0]; + package = NULL; + NCSI_FOR_EACH_PACKAGE(ndp, np) + if (np->id == package_id) + package = np; + + if (!package) + return 0; /* done */ + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &ncsi_genl_family, NLM_F_MULTI, NCSI_CMD_PKG_INFO); + if (!hdr) { + rc = -EMSGSIZE; + goto err; + } + + attr = nla_nest_start_noflag(skb, NCSI_ATTR_PACKAGE_LIST); + if (!attr) { + rc = -EMSGSIZE; + goto err; + } + rc = ncsi_write_package_info(skb, ndp, package->id); + if (rc) { + nla_nest_cancel(skb, attr); + goto err; + } + + nla_nest_end(skb, attr); + genlmsg_end(skb, hdr); + + cb->args[0] = package_id + 1; + + return skb->len; +err: + genlmsg_cancel(skb, hdr); + return rc; +} + +static int ncsi_set_interface_nl(struct sk_buff *msg, struct genl_info *info) +{ + struct ncsi_package *np, *package; + struct ncsi_channel *nc, *channel; + u32 package_id, channel_id; + struct ncsi_dev_priv *ndp; + unsigned long flags; + + if (!info || !info->attrs) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_PACKAGE_ID]) + return -EINVAL; + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) + return -ENODEV; + + package_id = nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_ID]); + package = NULL; + + NCSI_FOR_EACH_PACKAGE(ndp, np) + if (np->id == package_id) + package = np; + if (!package) { + /* The user has set a package that does not exist */ + return -ERANGE; + } + + channel = NULL; + if (info->attrs[NCSI_ATTR_CHANNEL_ID]) { + channel_id = nla_get_u32(info->attrs[NCSI_ATTR_CHANNEL_ID]); + 
NCSI_FOR_EACH_CHANNEL(package, nc) + if (nc->id == channel_id) { + channel = nc; + break; + } + if (!channel) { + netdev_info(ndp->ndev.dev, + "NCSI: Channel %u does not exist!\n", + channel_id); + return -ERANGE; + } + } + + spin_lock_irqsave(&ndp->lock, flags); + ndp->package_whitelist = 0x1 << package->id; + ndp->multi_package = false; + spin_unlock_irqrestore(&ndp->lock, flags); + + spin_lock_irqsave(&package->lock, flags); + package->multi_channel = false; + if (channel) { + package->channel_whitelist = 0x1 << channel->id; + package->preferred_channel = channel; + } else { + /* Allow any channel */ + package->channel_whitelist = UINT_MAX; + package->preferred_channel = NULL; + } + spin_unlock_irqrestore(&package->lock, flags); + + if (channel) + netdev_info(ndp->ndev.dev, + "Set package 0x%x, channel 0x%x as preferred\n", + package_id, channel_id); + else + netdev_info(ndp->ndev.dev, "Set package 0x%x as preferred\n", + package_id); + + /* Update channel configuration */ + if (!(ndp->flags & NCSI_DEV_RESET)) + ncsi_reset_dev(&ndp->ndev); + + return 0; +} + +static int ncsi_clear_interface_nl(struct sk_buff *msg, struct genl_info *info) +{ + struct ncsi_dev_priv *ndp; + struct ncsi_package *np; + unsigned long flags; + + if (!info || !info->attrs) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) + return -ENODEV; + + /* Reset any whitelists and disable multi mode */ + spin_lock_irqsave(&ndp->lock, flags); + ndp->package_whitelist = UINT_MAX; + ndp->multi_package = false; + spin_unlock_irqrestore(&ndp->lock, flags); + + NCSI_FOR_EACH_PACKAGE(ndp, np) { + spin_lock_irqsave(&np->lock, flags); + np->multi_channel = false; + np->channel_whitelist = UINT_MAX; + np->preferred_channel = NULL; + spin_unlock_irqrestore(&np->lock, flags); + } + netdev_info(ndp->ndev.dev, "NCSI: Cleared preferred package/channel\n"); + + /* Update channel configuration */ + if (!(ndp->flags & NCSI_DEV_RESET)) + ncsi_reset_dev(&ndp->ndev); + + return 0; +} + +static int ncsi_send_cmd_nl(struct sk_buff *msg, struct genl_info *info) +{ + struct ncsi_dev_priv *ndp; + struct ncsi_pkt_hdr *hdr; + struct ncsi_cmd_arg nca; + unsigned char *data; + u32 package_id; + u32 channel_id; + int len, ret; + + if (!info || !info->attrs) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_IFINDEX]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_PACKAGE_ID]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_CHANNEL_ID]) { + ret = -EINVAL; + goto out; + } + + if (!info->attrs[NCSI_ATTR_DATA]) { + ret = -EINVAL; + goto out; + } + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) { + ret = -ENODEV; + goto out; + } + + package_id = nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_ID]); + channel_id = nla_get_u32(info->attrs[NCSI_ATTR_CHANNEL_ID]); + + if (package_id >= NCSI_MAX_PACKAGE || channel_id >= NCSI_MAX_CHANNEL) { + ret = -ERANGE; + goto out_netlink; + } + + len = nla_len(info->attrs[NCSI_ATTR_DATA]); + if (len < sizeof(struct ncsi_pkt_hdr)) { + netdev_info(ndp->ndev.dev, "NCSI: no command to send %u\n", + package_id); + ret = -EINVAL; + goto out_netlink; + } else { + data = (unsigned char *)nla_data(info->attrs[NCSI_ATTR_DATA]); + } + + hdr = (struct ncsi_pkt_hdr *)data; + + nca.ndp = ndp; + nca.package = (unsigned char)package_id; + nca.channel = (unsigned char)channel_id; + nca.type 
= hdr->type; + nca.req_flags = NCSI_REQ_FLAG_NETLINK_DRIVEN; + nca.info = info; + nca.payload = ntohs(hdr->length); + nca.data = data + sizeof(*hdr); + + ret = ncsi_xmit_cmd(&nca); +out_netlink: + if (ret != 0) { + netdev_err(ndp->ndev.dev, + "NCSI: Error %d sending command\n", + ret); + ncsi_send_netlink_err(ndp->ndev.dev, + info->snd_seq, + info->snd_portid, + info->nlhdr, + ret); + } +out: + return ret; +} + +int ncsi_send_netlink_rsp(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc) +{ + struct sk_buff *skb; + struct net *net; + void *hdr; + int rc; + + net = dev_net(nr->rsp->dev); + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, nr->snd_portid, nr->snd_seq, + &ncsi_genl_family, 0, NCSI_CMD_SEND_CMD); + if (!hdr) { + kfree_skb(skb); + return -EMSGSIZE; + } + + nla_put_u32(skb, NCSI_ATTR_IFINDEX, nr->rsp->dev->ifindex); + if (np) + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, np->id); + if (nc) + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, nc->id); + else + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, NCSI_RESERVED_CHANNEL); + + rc = nla_put(skb, NCSI_ATTR_DATA, nr->rsp->len, (void *)nr->rsp->data); + if (rc) + goto err; + + genlmsg_end(skb, hdr); + return genlmsg_unicast(net, skb, nr->snd_portid); + +err: + kfree_skb(skb); + return rc; +} + +int ncsi_send_netlink_timeout(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc) +{ + struct sk_buff *skb; + struct net *net; + void *hdr; + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hdr = genlmsg_put(skb, nr->snd_portid, nr->snd_seq, + &ncsi_genl_family, 0, NCSI_CMD_SEND_CMD); + if (!hdr) { + kfree_skb(skb); + return -EMSGSIZE; + } + + net = dev_net(nr->cmd->dev); + + nla_put_u32(skb, NCSI_ATTR_IFINDEX, nr->cmd->dev->ifindex); + + if (np) + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, np->id); + else + nla_put_u32(skb, NCSI_ATTR_PACKAGE_ID, + NCSI_PACKAGE_INDEX((((struct ncsi_pkt_hdr *) + nr->cmd->data)->channel))); + + if (nc) + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, nc->id); + else + nla_put_u32(skb, NCSI_ATTR_CHANNEL_ID, NCSI_RESERVED_CHANNEL); + + genlmsg_end(skb, hdr); + return genlmsg_unicast(net, skb, nr->snd_portid); +} + +int ncsi_send_netlink_err(struct net_device *dev, + u32 snd_seq, + u32 snd_portid, + struct nlmsghdr *nlhdr, + int err) +{ + struct nlmsghdr *nlh; + struct nlmsgerr *nle; + struct sk_buff *skb; + struct net *net; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + net = dev_net(dev); + + nlh = nlmsg_put(skb, snd_portid, snd_seq, + NLMSG_ERROR, sizeof(*nle), 0); + nle = (struct nlmsgerr *)nlmsg_data(nlh); + nle->error = err; + memcpy(&nle->msg, nlhdr, sizeof(*nlh)); + + nlmsg_end(skb, nlh); + + return nlmsg_unicast(net->genl_sock, skb, snd_portid); +} + +static int ncsi_set_package_mask_nl(struct sk_buff *msg, + struct genl_info *info) +{ + struct ncsi_dev_priv *ndp; + unsigned long flags; + int rc; + + if (!info || !info->attrs) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_PACKAGE_MASK]) + return -EINVAL; + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) + return -ENODEV; + + spin_lock_irqsave(&ndp->lock, flags); + if (nla_get_flag(info->attrs[NCSI_ATTR_MULTI_FLAG])) { + if (ndp->flags & NCSI_DEV_HWA) { + ndp->multi_package = true; + rc = 0; + } else { + netdev_err(ndp->ndev.dev, + "NCSI: Can't use multiple packages 
without HWA\n"); + rc = -EPERM; + } + } else { + ndp->multi_package = false; + rc = 0; + } + + if (!rc) + ndp->package_whitelist = + nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_MASK]); + spin_unlock_irqrestore(&ndp->lock, flags); + + if (!rc) { + /* Update channel configuration */ + if (!(ndp->flags & NCSI_DEV_RESET)) + ncsi_reset_dev(&ndp->ndev); + } + + return rc; +} + +static int ncsi_set_channel_mask_nl(struct sk_buff *msg, + struct genl_info *info) +{ + struct ncsi_package *np, *package; + struct ncsi_channel *nc, *channel; + u32 package_id, channel_id; + struct ncsi_dev_priv *ndp; + unsigned long flags; + + if (!info || !info->attrs) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_IFINDEX]) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_PACKAGE_ID]) + return -EINVAL; + + if (!info->attrs[NCSI_ATTR_CHANNEL_MASK]) + return -EINVAL; + + ndp = ndp_from_ifindex(get_net(sock_net(msg->sk)), + nla_get_u32(info->attrs[NCSI_ATTR_IFINDEX])); + if (!ndp) + return -ENODEV; + + package_id = nla_get_u32(info->attrs[NCSI_ATTR_PACKAGE_ID]); + package = NULL; + NCSI_FOR_EACH_PACKAGE(ndp, np) + if (np->id == package_id) { + package = np; + break; + } + if (!package) + return -ERANGE; + + spin_lock_irqsave(&package->lock, flags); + + channel = NULL; + if (info->attrs[NCSI_ATTR_CHANNEL_ID]) { + channel_id = nla_get_u32(info->attrs[NCSI_ATTR_CHANNEL_ID]); + NCSI_FOR_EACH_CHANNEL(np, nc) + if (nc->id == channel_id) { + channel = nc; + break; + } + if (!channel) { + spin_unlock_irqrestore(&package->lock, flags); + return -ERANGE; + } + netdev_dbg(ndp->ndev.dev, + "NCSI: Channel %u set as preferred channel\n", + channel->id); + } + + package->channel_whitelist = + nla_get_u32(info->attrs[NCSI_ATTR_CHANNEL_MASK]); + if (package->channel_whitelist == 0) + netdev_dbg(ndp->ndev.dev, + "NCSI: Package %u set to all channels disabled\n", + package->id); + + package->preferred_channel = channel; + + if (nla_get_flag(info->attrs[NCSI_ATTR_MULTI_FLAG])) { + package->multi_channel = true; + netdev_info(ndp->ndev.dev, + "NCSI: Multi-channel enabled on package %u\n", + package_id); + } else { + package->multi_channel = false; + } + + spin_unlock_irqrestore(&package->lock, flags); + + /* Update channel configuration */ + if (!(ndp->flags & NCSI_DEV_RESET)) + ncsi_reset_dev(&ndp->ndev); + + return 0; +} + +static const struct genl_small_ops ncsi_ops[] = { + { + .cmd = NCSI_CMD_PKG_INFO, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_pkg_info_nl, + .dumpit = ncsi_pkg_info_all_nl, + .flags = 0, + }, + { + .cmd = NCSI_CMD_SET_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_set_interface_nl, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NCSI_CMD_CLEAR_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_clear_interface_nl, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NCSI_CMD_SEND_CMD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_send_cmd_nl, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NCSI_CMD_SET_PACKAGE_MASK, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_set_package_mask_nl, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NCSI_CMD_SET_CHANNEL_MASK, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = ncsi_set_channel_mask_nl, + .flags = GENL_ADMIN_PERM, + }, +}; + +static struct genl_family ncsi_genl_family __ro_after_init = { + .name = "NCSI", + .version = 0, + .maxattr = NCSI_ATTR_MAX, + .policy = 
ncsi_genl_policy, + .module = THIS_MODULE, + .small_ops = ncsi_ops, + .n_small_ops = ARRAY_SIZE(ncsi_ops), + .resv_start_op = NCSI_CMD_SET_CHANNEL_MASK + 1, +}; + +static int __init ncsi_init_netlink(void) +{ + return genl_register_family(&ncsi_genl_family); +} +subsys_initcall(ncsi_init_netlink); diff --git a/net/ncsi/ncsi-netlink.h b/net/ncsi/ncsi-netlink.h new file mode 100644 index 000000000..39a1a9d7b --- /dev/null +++ b/net/ncsi/ncsi-netlink.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright Samuel Mendoza-Jonas, IBM Corporation 2018. + */ + +#ifndef __NCSI_NETLINK_H__ +#define __NCSI_NETLINK_H__ + +#include + +#include "internal.h" + +int ncsi_send_netlink_rsp(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc); +int ncsi_send_netlink_timeout(struct ncsi_request *nr, + struct ncsi_package *np, + struct ncsi_channel *nc); +int ncsi_send_netlink_err(struct net_device *dev, + u32 snd_seq, + u32 snd_portid, + struct nlmsghdr *nlhdr, + int err); + +#endif /* __NCSI_NETLINK_H__ */ diff --git a/net/ncsi/ncsi-pkt.h b/net/ncsi/ncsi-pkt.h new file mode 100644 index 000000000..c9d1da34d --- /dev/null +++ b/net/ncsi/ncsi-pkt.h @@ -0,0 +1,456 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright Gavin Shan, IBM Corporation 2016. + */ + +#ifndef __NCSI_PKT_H__ +#define __NCSI_PKT_H__ + +struct ncsi_pkt_hdr { + unsigned char mc_id; /* Management controller ID */ + unsigned char revision; /* NCSI version - 0x01 */ + unsigned char reserved; /* Reserved */ + unsigned char id; /* Packet sequence number */ + unsigned char type; /* Packet type */ + unsigned char channel; /* Network controller ID */ + __be16 length; /* Payload length */ + __be32 reserved1[2]; /* Reserved */ +}; + +struct ncsi_cmd_pkt_hdr { + struct ncsi_pkt_hdr common; /* Common NCSI packet header */ +}; + +struct ncsi_rsp_pkt_hdr { + struct ncsi_pkt_hdr common; /* Common NCSI packet header */ + __be16 code; /* Response code */ + __be16 reason; /* Response reason */ +}; + +struct ncsi_aen_pkt_hdr { + struct ncsi_pkt_hdr common; /* Common NCSI packet header */ + unsigned char reserved2[3]; /* Reserved */ + unsigned char type; /* AEN packet type */ +}; + +/* NCSI common command packet */ +struct ncsi_cmd_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 checksum; /* Checksum */ + unsigned char pad[26]; +}; + +struct ncsi_rsp_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Select Package */ +struct ncsi_cmd_sp_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char reserved[3]; /* Reserved */ + unsigned char hw_arbitration; /* HW arbitration */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Disable Channel */ +struct ncsi_cmd_dc_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char reserved[3]; /* Reserved */ + unsigned char ald; /* Allow link down */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Reset Channel */ +struct ncsi_cmd_rc_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 reserved; /* Reserved */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* AEN Enable */ +struct ncsi_cmd_ae_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char reserved[3]; /* Reserved */ + unsigned char mc_id; /* MC ID */ + __be32 mode; /* AEN working mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[18]; +}; + +/* Set Link */ +struct 
ncsi_cmd_sl_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 mode; /* Link working mode */ + __be32 oem_mode; /* OEM link mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[18]; +}; + +/* Set VLAN Filter */ +struct ncsi_cmd_svf_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be16 reserved; /* Reserved */ + __be16 vlan; /* VLAN ID */ + __be16 reserved1; /* Reserved */ + unsigned char index; /* VLAN table index */ + unsigned char enable; /* Enable or disable */ + __be32 checksum; /* Checksum */ + unsigned char pad[18]; +}; + +/* Enable VLAN */ +struct ncsi_cmd_ev_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char reserved[3]; /* Reserved */ + unsigned char mode; /* VLAN filter mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Set MAC Address */ +struct ncsi_cmd_sma_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char mac[6]; /* MAC address */ + unsigned char index; /* MAC table index */ + unsigned char at_e; /* Addr type and operation */ + __be32 checksum; /* Checksum */ + unsigned char pad[18]; +}; + +/* Enable Broadcast Filter */ +struct ncsi_cmd_ebf_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 mode; /* Filter mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Enable Global Multicast Filter */ +struct ncsi_cmd_egmf_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 mode; /* Global MC mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* Set NCSI Flow Control */ +struct ncsi_cmd_snfc_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + unsigned char reserved[3]; /* Reserved */ + unsigned char mode; /* Flow control mode */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* OEM Request Command as per NCSI Specification */ +struct ncsi_cmd_oem_pkt { + struct ncsi_cmd_pkt_hdr cmd; /* Command header */ + __be32 mfr_id; /* Manufacture ID */ + unsigned char data[]; /* OEM Payload Data */ +}; + +/* OEM Response Packet as per NCSI Specification */ +struct ncsi_rsp_oem_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Command header */ + __be32 mfr_id; /* Manufacture ID */ + unsigned char data[]; /* Payload data */ +}; + +/* Mellanox Response Data */ +struct ncsi_rsp_oem_mlx_pkt { + unsigned char cmd_rev; /* Command Revision */ + unsigned char cmd; /* Command ID */ + unsigned char param; /* Parameter */ + unsigned char optional; /* Optional data */ + unsigned char data[]; /* Data */ +}; + +/* Broadcom Response Data */ +struct ncsi_rsp_oem_bcm_pkt { + unsigned char ver; /* Payload Version */ + unsigned char type; /* OEM Command type */ + __be16 len; /* Payload Length */ + unsigned char data[]; /* Cmd specific Data */ +}; + +/* Intel Response Data */ +struct ncsi_rsp_oem_intel_pkt { + unsigned char cmd; /* OEM Command ID */ + unsigned char data[]; /* Cmd specific Data */ +}; + +/* Get Link Status */ +struct ncsi_rsp_gls_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 status; /* Link status */ + __be32 other; /* Other indications */ + __be32 oem_status; /* OEM link status */ + __be32 checksum; + unsigned char pad[10]; +}; + +/* Get Version ID */ +struct ncsi_rsp_gvi_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + unsigned char major; /* NCSI version major */ + unsigned char minor; /* NCSI version minor */ + unsigned char update; /* NCSI version update */ + unsigned char alpha1; /* NCSI version alpha1 */ + unsigned char reserved[3]; /* Reserved */ + unsigned 
char alpha2; /* NCSI version alpha2 */ + unsigned char fw_name[12]; /* f/w name string */ + __be32 fw_version; /* f/w version */ + __be16 pci_ids[4]; /* PCI IDs */ + __be32 mf_id; /* Manufacture ID */ + __be32 checksum; +}; + +/* Get Capabilities */ +struct ncsi_rsp_gc_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 cap; /* Capabilities */ + __be32 bc_cap; /* Broadcast cap */ + __be32 mc_cap; /* Multicast cap */ + __be32 buf_cap; /* Buffering cap */ + __be32 aen_cap; /* AEN cap */ + unsigned char vlan_cnt; /* VLAN filter count */ + unsigned char mixed_cnt; /* Mix filter count */ + unsigned char mc_cnt; /* MC filter count */ + unsigned char uc_cnt; /* UC filter count */ + unsigned char reserved[2]; /* Reserved */ + unsigned char vlan_mode; /* VLAN mode */ + unsigned char channel_cnt; /* Channel count */ + __be32 checksum; /* Checksum */ +}; + +/* Get Parameters */ +struct ncsi_rsp_gp_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + unsigned char mac_cnt; /* Number of MAC addr */ + unsigned char reserved[2]; /* Reserved */ + unsigned char mac_enable; /* MAC addr enable flags */ + unsigned char vlan_cnt; /* VLAN tag count */ + unsigned char reserved1; /* Reserved */ + __be16 vlan_enable; /* VLAN tag enable flags */ + __be32 link_mode; /* Link setting */ + __be32 bc_mode; /* BC filter mode */ + __be32 valid_modes; /* Valid mode parameters */ + unsigned char vlan_mode; /* VLAN mode */ + unsigned char fc_mode; /* Flow control mode */ + unsigned char reserved2[2]; /* Reserved */ + __be32 aen_mode; /* AEN mode */ + unsigned char mac[6]; /* Supported MAC addr */ + __be16 vlan; /* Supported VLAN tags */ + __be32 checksum; /* Checksum */ +}; + +/* Get Controller Packet Statistics */ +struct ncsi_rsp_gcps_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 cnt_hi; /* Counter cleared */ + __be32 cnt_lo; /* Counter cleared */ + __be32 rx_bytes; /* Rx bytes */ + __be32 tx_bytes; /* Tx bytes */ + __be32 rx_uc_pkts; /* Rx UC packets */ + __be32 rx_mc_pkts; /* Rx MC packets */ + __be32 rx_bc_pkts; /* Rx BC packets */ + __be32 tx_uc_pkts; /* Tx UC packets */ + __be32 tx_mc_pkts; /* Tx MC packets */ + __be32 tx_bc_pkts; /* Tx BC packets */ + __be32 fcs_err; /* FCS errors */ + __be32 align_err; /* Alignment errors */ + __be32 false_carrier; /* False carrier detection */ + __be32 runt_pkts; /* Rx runt packets */ + __be32 jabber_pkts; /* Rx jabber packets */ + __be32 rx_pause_xon; /* Rx pause XON frames */ + __be32 rx_pause_xoff; /* Rx XOFF frames */ + __be32 tx_pause_xon; /* Tx XON frames */ + __be32 tx_pause_xoff; /* Tx XOFF frames */ + __be32 tx_s_collision; /* Single collision frames */ + __be32 tx_m_collision; /* Multiple collision frames */ + __be32 l_collision; /* Late collision frames */ + __be32 e_collision; /* Excessive collision frames */ + __be32 rx_ctl_frames; /* Rx control frames */ + __be32 rx_64_frames; /* Rx 64-bytes frames */ + __be32 rx_127_frames; /* Rx 65-127 bytes frames */ + __be32 rx_255_frames; /* Rx 128-255 bytes frames */ + __be32 rx_511_frames; /* Rx 256-511 bytes frames */ + __be32 rx_1023_frames; /* Rx 512-1023 bytes frames */ + __be32 rx_1522_frames; /* Rx 1024-1522 bytes frames */ + __be32 rx_9022_frames; /* Rx 1523-9022 bytes frames */ + __be32 tx_64_frames; /* Tx 64-bytes frames */ + __be32 tx_127_frames; /* Tx 65-127 bytes frames */ + __be32 tx_255_frames; /* Tx 128-255 bytes frames */ + __be32 tx_511_frames; /* Tx 256-511 bytes frames */ + __be32 tx_1023_frames; /* Tx 512-1023 bytes frames */ + __be32 tx_1522_frames; /* Tx 
1024-1522 bytes frames */ + __be32 tx_9022_frames; /* Tx 1523-9022 bytes frames */ + __be32 rx_valid_bytes; /* Rx valid bytes */ + __be32 rx_runt_pkts; /* Rx error runt packets */ + __be32 rx_jabber_pkts; /* Rx error jabber packets */ + __be32 checksum; /* Checksum */ +}; + +/* Get NCSI Statistics */ +struct ncsi_rsp_gns_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 rx_cmds; /* Rx NCSI commands */ + __be32 dropped_cmds; /* Dropped commands */ + __be32 cmd_type_errs; /* Command type errors */ + __be32 cmd_csum_errs; /* Command checksum errors */ + __be32 rx_pkts; /* Rx NCSI packets */ + __be32 tx_pkts; /* Tx NCSI packets */ + __be32 tx_aen_pkts; /* Tx AEN packets */ + __be32 checksum; /* Checksum */ +}; + +/* Get NCSI Pass-through Statistics */ +struct ncsi_rsp_gnpts_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 tx_pkts; /* Tx packets */ + __be32 tx_dropped; /* Tx dropped packets */ + __be32 tx_channel_err; /* Tx channel errors */ + __be32 tx_us_err; /* Tx undersize errors */ + __be32 rx_pkts; /* Rx packets */ + __be32 rx_dropped; /* Rx dropped packets */ + __be32 rx_channel_err; /* Rx channel errors */ + __be32 rx_us_err; /* Rx undersize errors */ + __be32 rx_os_err; /* Rx oversize errors */ + __be32 checksum; /* Checksum */ +}; + +/* Get package status */ +struct ncsi_rsp_gps_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + __be32 status; /* Hardware arbitration status */ + __be32 checksum; +}; + +/* Get package UUID */ +struct ncsi_rsp_gpuuid_pkt { + struct ncsi_rsp_pkt_hdr rsp; /* Response header */ + unsigned char uuid[16]; /* UUID */ + __be32 checksum; +}; + +/* AEN: Link State Change */ +struct ncsi_aen_lsc_pkt { + struct ncsi_aen_pkt_hdr aen; /* AEN header */ + __be32 status; /* Link status */ + __be32 oem_status; /* OEM link status */ + __be32 checksum; /* Checksum */ + unsigned char pad[14]; +}; + +/* AEN: Configuration Required */ +struct ncsi_aen_cr_pkt { + struct ncsi_aen_pkt_hdr aen; /* AEN header */ + __be32 checksum; /* Checksum */ + unsigned char pad[22]; +}; + +/* AEN: Host Network Controller Driver Status Change */ +struct ncsi_aen_hncdsc_pkt { + struct ncsi_aen_pkt_hdr aen; /* AEN header */ + __be32 status; /* Status */ + __be32 checksum; /* Checksum */ + unsigned char pad[18]; +}; + +/* NCSI packet revision */ +#define NCSI_PKT_REVISION 0x01 + +/* NCSI packet commands */ +#define NCSI_PKT_CMD_CIS 0x00 /* Clear Initial State */ +#define NCSI_PKT_CMD_SP 0x01 /* Select Package */ +#define NCSI_PKT_CMD_DP 0x02 /* Deselect Package */ +#define NCSI_PKT_CMD_EC 0x03 /* Enable Channel */ +#define NCSI_PKT_CMD_DC 0x04 /* Disable Channel */ +#define NCSI_PKT_CMD_RC 0x05 /* Reset Channel */ +#define NCSI_PKT_CMD_ECNT 0x06 /* Enable Channel Network Tx */ +#define NCSI_PKT_CMD_DCNT 0x07 /* Disable Channel Network Tx */ +#define NCSI_PKT_CMD_AE 0x08 /* AEN Enable */ +#define NCSI_PKT_CMD_SL 0x09 /* Set Link */ +#define NCSI_PKT_CMD_GLS 0x0a /* Get Link */ +#define NCSI_PKT_CMD_SVF 0x0b /* Set VLAN Filter */ +#define NCSI_PKT_CMD_EV 0x0c /* Enable VLAN */ +#define NCSI_PKT_CMD_DV 0x0d /* Disable VLAN */ +#define NCSI_PKT_CMD_SMA 0x0e /* Set MAC address */ +#define NCSI_PKT_CMD_EBF 0x10 /* Enable Broadcast Filter */ +#define NCSI_PKT_CMD_DBF 0x11 /* Disable Broadcast Filter */ +#define NCSI_PKT_CMD_EGMF 0x12 /* Enable Global Multicast Filter */ +#define NCSI_PKT_CMD_DGMF 0x13 /* Disable Global Multicast Filter */ +#define NCSI_PKT_CMD_SNFC 0x14 /* Set NCSI Flow Control */ +#define NCSI_PKT_CMD_GVI 0x15 /* Get Version ID */ +#define 
NCSI_PKT_CMD_GC 0x16 /* Get Capabilities */ +#define NCSI_PKT_CMD_GP 0x17 /* Get Parameters */ +#define NCSI_PKT_CMD_GCPS 0x18 /* Get Controller Packet Statistics */ +#define NCSI_PKT_CMD_GNS 0x19 /* Get NCSI Statistics */ +#define NCSI_PKT_CMD_GNPTS 0x1a /* Get NCSI Pass-throu Statistics */ +#define NCSI_PKT_CMD_GPS 0x1b /* Get package status */ +#define NCSI_PKT_CMD_OEM 0x50 /* OEM */ +#define NCSI_PKT_CMD_PLDM 0x51 /* PLDM request over NCSI over RBT */ +#define NCSI_PKT_CMD_GPUUID 0x52 /* Get package UUID */ +#define NCSI_PKT_CMD_QPNPR 0x56 /* Query Pending NC PLDM request */ +#define NCSI_PKT_CMD_SNPR 0x57 /* Send NC PLDM Reply */ + + +/* NCSI packet responses */ +#define NCSI_PKT_RSP_CIS (NCSI_PKT_CMD_CIS + 0x80) +#define NCSI_PKT_RSP_SP (NCSI_PKT_CMD_SP + 0x80) +#define NCSI_PKT_RSP_DP (NCSI_PKT_CMD_DP + 0x80) +#define NCSI_PKT_RSP_EC (NCSI_PKT_CMD_EC + 0x80) +#define NCSI_PKT_RSP_DC (NCSI_PKT_CMD_DC + 0x80) +#define NCSI_PKT_RSP_RC (NCSI_PKT_CMD_RC + 0x80) +#define NCSI_PKT_RSP_ECNT (NCSI_PKT_CMD_ECNT + 0x80) +#define NCSI_PKT_RSP_DCNT (NCSI_PKT_CMD_DCNT + 0x80) +#define NCSI_PKT_RSP_AE (NCSI_PKT_CMD_AE + 0x80) +#define NCSI_PKT_RSP_SL (NCSI_PKT_CMD_SL + 0x80) +#define NCSI_PKT_RSP_GLS (NCSI_PKT_CMD_GLS + 0x80) +#define NCSI_PKT_RSP_SVF (NCSI_PKT_CMD_SVF + 0x80) +#define NCSI_PKT_RSP_EV (NCSI_PKT_CMD_EV + 0x80) +#define NCSI_PKT_RSP_DV (NCSI_PKT_CMD_DV + 0x80) +#define NCSI_PKT_RSP_SMA (NCSI_PKT_CMD_SMA + 0x80) +#define NCSI_PKT_RSP_EBF (NCSI_PKT_CMD_EBF + 0x80) +#define NCSI_PKT_RSP_DBF (NCSI_PKT_CMD_DBF + 0x80) +#define NCSI_PKT_RSP_EGMF (NCSI_PKT_CMD_EGMF + 0x80) +#define NCSI_PKT_RSP_DGMF (NCSI_PKT_CMD_DGMF + 0x80) +#define NCSI_PKT_RSP_SNFC (NCSI_PKT_CMD_SNFC + 0x80) +#define NCSI_PKT_RSP_GVI (NCSI_PKT_CMD_GVI + 0x80) +#define NCSI_PKT_RSP_GC (NCSI_PKT_CMD_GC + 0x80) +#define NCSI_PKT_RSP_GP (NCSI_PKT_CMD_GP + 0x80) +#define NCSI_PKT_RSP_GCPS (NCSI_PKT_CMD_GCPS + 0x80) +#define NCSI_PKT_RSP_GNS (NCSI_PKT_CMD_GNS + 0x80) +#define NCSI_PKT_RSP_GNPTS (NCSI_PKT_CMD_GNPTS + 0x80) +#define NCSI_PKT_RSP_GPS (NCSI_PKT_CMD_GPS + 0x80) +#define NCSI_PKT_RSP_OEM (NCSI_PKT_CMD_OEM + 0x80) +#define NCSI_PKT_RSP_PLDM (NCSI_PKT_CMD_PLDM + 0x80) +#define NCSI_PKT_RSP_GPUUID (NCSI_PKT_CMD_GPUUID + 0x80) +#define NCSI_PKT_RSP_QPNPR (NCSI_PKT_CMD_QPNPR + 0x80) +#define NCSI_PKT_RSP_SNPR (NCSI_PKT_CMD_SNPR + 0x80) + +/* NCSI response code/reason */ +#define NCSI_PKT_RSP_C_COMPLETED 0x0000 /* Command Completed */ +#define NCSI_PKT_RSP_C_FAILED 0x0001 /* Command Failed */ +#define NCSI_PKT_RSP_C_UNAVAILABLE 0x0002 /* Command Unavailable */ +#define NCSI_PKT_RSP_C_UNSUPPORTED 0x0003 /* Command Unsupported */ +#define NCSI_PKT_RSP_R_NO_ERROR 0x0000 /* No Error */ +#define NCSI_PKT_RSP_R_INTERFACE 0x0001 /* Interface not ready */ +#define NCSI_PKT_RSP_R_PARAM 0x0002 /* Invalid Parameter */ +#define NCSI_PKT_RSP_R_CHANNEL 0x0003 /* Channel not Ready */ +#define NCSI_PKT_RSP_R_PACKAGE 0x0004 /* Package not Ready */ +#define NCSI_PKT_RSP_R_LENGTH 0x0005 /* Invalid payload length */ +#define NCSI_PKT_RSP_R_UNKNOWN 0x7fff /* Command type unsupported */ + +/* NCSI AEN packet type */ +#define NCSI_PKT_AEN 0xFF /* AEN Packet */ +#define NCSI_PKT_AEN_LSC 0x00 /* Link status change */ +#define NCSI_PKT_AEN_CR 0x01 /* Configuration required */ +#define NCSI_PKT_AEN_HNCDSC 0x02 /* HNC driver status change */ + +#endif /* __NCSI_PKT_H__ */ diff --git a/net/ncsi/ncsi-rsp.c b/net/ncsi/ncsi-rsp.c new file mode 100644 index 000000000..480e80e3c --- /dev/null +++ b/net/ncsi/ncsi-rsp.c @@ -0,0 +1,1232 @@ +// 
SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright Gavin Shan, IBM Corporation 2016. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "internal.h" +#include "ncsi-pkt.h" +#include "ncsi-netlink.h" + +/* Nibbles within [0xA, 0xF] add zero "0" to the returned value. + * Optional fields (encoded as 0xFF) will default to zero. + */ +static u8 decode_bcd_u8(u8 x) +{ + int lo = x & 0xF; + int hi = x >> 4; + + lo = lo < 0xA ? lo : 0; + hi = hi < 0xA ? hi : 0; + return lo + hi * 10; +} + +static int ncsi_validate_rsp_pkt(struct ncsi_request *nr, + unsigned short payload) +{ + struct ncsi_rsp_pkt_hdr *h; + u32 checksum; + __be32 *pchecksum; + + /* Check NCSI packet header. We don't need validate + * the packet type, which should have been checked + * before calling this function. + */ + h = (struct ncsi_rsp_pkt_hdr *)skb_network_header(nr->rsp); + + if (h->common.revision != NCSI_PKT_REVISION) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: unsupported header revision\n"); + return -EINVAL; + } + if (ntohs(h->common.length) != payload) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: payload length mismatched\n"); + return -EINVAL; + } + + /* Check on code and reason */ + if (ntohs(h->code) != NCSI_PKT_RSP_C_COMPLETED || + ntohs(h->reason) != NCSI_PKT_RSP_R_NO_ERROR) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: non zero response/reason code %04xh, %04xh\n", + ntohs(h->code), ntohs(h->reason)); + return -EPERM; + } + + /* Validate checksum, which might be zeroes if the + * sender doesn't support checksum according to NCSI + * specification. + */ + pchecksum = (__be32 *)((void *)(h + 1) + ALIGN(payload, 4) - 4); + if (ntohl(*pchecksum) == 0) + return 0; + + checksum = ncsi_calculate_checksum((unsigned char *)h, + sizeof(*h) + payload - 4); + + if (*pchecksum != htonl(checksum)) { + netdev_dbg(nr->ndp->ndev.dev, + "NCSI: checksum mismatched; recd: %08x calc: %08x\n", + *pchecksum, htonl(checksum)); + return -EINVAL; + } + + return 0; +} + +static int ncsi_rsp_handler_cis(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_package *np; + struct ncsi_channel *nc; + unsigned char id; + + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, &np, &nc); + if (!nc) { + if (ndp->flags & NCSI_DEV_PROBED) + return -ENXIO; + + id = NCSI_CHANNEL_INDEX(rsp->rsp.common.channel); + nc = ncsi_add_channel(np, id); + } + + return nc ? 0 : -ENODEV; +} + +static int ncsi_rsp_handler_sp(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_package *np; + unsigned char id; + + /* Add the package if it's not existing. Otherwise, + * to change the state of its child channels. 
+ */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, NULL); + if (!np) { + if (ndp->flags & NCSI_DEV_PROBED) + return -ENXIO; + + id = NCSI_PACKAGE_INDEX(rsp->rsp.common.channel); + np = ncsi_add_package(ndp, id); + if (!np) + return -ENODEV; + } + + return 0; +} + +static int ncsi_rsp_handler_dp(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_package *np; + struct ncsi_channel *nc; + unsigned long flags; + + /* Find the package */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, NULL); + if (!np) + return -ENODEV; + + /* Change state of all channels attached to the package */ + NCSI_FOR_EACH_CHANNEL(np, nc) { + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + } + + return 0; +} + +static int ncsi_rsp_handler_ec(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + ncm = &nc->modes[NCSI_MODE_ENABLE]; + if (ncm->enable) + return 0; + + ncm->enable = 1; + return 0; +} + +static int ncsi_rsp_handler_dc(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + int ret; + + ret = ncsi_validate_rsp_pkt(nr, 4); + if (ret) + return ret; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + ncm = &nc->modes[NCSI_MODE_ENABLE]; + if (!ncm->enable) + return 0; + + ncm->enable = 0; + return 0; +} + +static int ncsi_rsp_handler_rc(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + unsigned long flags; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update state for the specified channel */ + spin_lock_irqsave(&nc->lock, flags); + nc->state = NCSI_CHANNEL_INACTIVE; + spin_unlock_irqrestore(&nc->lock, flags); + + return 0; +} + +static int ncsi_rsp_handler_ecnt(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + ncm = &nc->modes[NCSI_MODE_TX_ENABLE]; + if (ncm->enable) + return 0; + + ncm->enable = 1; + return 0; +} + +static int ncsi_rsp_handler_dcnt(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + ncm = 
&nc->modes[NCSI_MODE_TX_ENABLE]; + if (!ncm->enable) + return 0; + + ncm->enable = 0; + return 0; +} + +static int ncsi_rsp_handler_ae(struct ncsi_request *nr) +{ + struct ncsi_cmd_ae_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if the AEN has been enabled */ + ncm = &nc->modes[NCSI_MODE_AEN]; + if (ncm->enable) + return 0; + + /* Update to AEN configuration */ + cmd = (struct ncsi_cmd_ae_pkt *)skb_network_header(nr->cmd); + ncm->enable = 1; + ncm->data[0] = cmd->mc_id; + ncm->data[1] = ntohl(cmd->mode); + + return 0; +} + +static int ncsi_rsp_handler_sl(struct ncsi_request *nr) +{ + struct ncsi_cmd_sl_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + cmd = (struct ncsi_cmd_sl_pkt *)skb_network_header(nr->cmd); + ncm = &nc->modes[NCSI_MODE_LINK]; + ncm->data[0] = ntohl(cmd->mode); + ncm->data[1] = ntohl(cmd->oem_mode); + + return 0; +} + +static int ncsi_rsp_handler_gls(struct ncsi_request *nr) +{ + struct ncsi_rsp_gls_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + unsigned long flags; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_gls_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + ncm = &nc->modes[NCSI_MODE_LINK]; + ncm->data[2] = ntohl(rsp->status); + ncm->data[3] = ntohl(rsp->other); + ncm->data[4] = ntohl(rsp->oem_status); + + if (nr->flags & NCSI_REQ_FLAG_EVENT_DRIVEN) + return 0; + + /* Reset the channel monitor if it has been enabled */ + spin_lock_irqsave(&nc->lock, flags); + nc->monitor.state = NCSI_CHANNEL_MONITOR_START; + spin_unlock_irqrestore(&nc->lock, flags); + + return 0; +} + +static int ncsi_rsp_handler_svf(struct ncsi_request *nr) +{ + struct ncsi_cmd_svf_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_vlan_filter *ncf; + unsigned long flags; + void *bitmap; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + cmd = (struct ncsi_cmd_svf_pkt *)skb_network_header(nr->cmd); + ncf = &nc->vlan_filter; + if (cmd->index == 0 || cmd->index > ncf->n_vids) + return -ERANGE; + + /* Add or remove the VLAN filter. 
Remember HW indexes from 1 */ + spin_lock_irqsave(&nc->lock, flags); + bitmap = &ncf->bitmap; + if (!(cmd->enable & 0x1)) { + if (test_and_clear_bit(cmd->index - 1, bitmap)) + ncf->vids[cmd->index - 1] = 0; + } else { + set_bit(cmd->index - 1, bitmap); + ncf->vids[cmd->index - 1] = ntohs(cmd->vlan); + } + spin_unlock_irqrestore(&nc->lock, flags); + + return 0; +} + +static int ncsi_rsp_handler_ev(struct ncsi_request *nr) +{ + struct ncsi_cmd_ev_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if VLAN mode has been enabled */ + ncm = &nc->modes[NCSI_MODE_VLAN]; + if (ncm->enable) + return 0; + + /* Update to VLAN mode */ + cmd = (struct ncsi_cmd_ev_pkt *)skb_network_header(nr->cmd); + ncm->enable = 1; + ncm->data[0] = ntohl((__force __be32)cmd->mode); + + return 0; +} + +static int ncsi_rsp_handler_dv(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if VLAN mode has been enabled */ + ncm = &nc->modes[NCSI_MODE_VLAN]; + if (!ncm->enable) + return 0; + + /* Update to VLAN mode */ + ncm->enable = 0; + return 0; +} + +static int ncsi_rsp_handler_sma(struct ncsi_request *nr) +{ + struct ncsi_cmd_sma_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mac_filter *ncf; + unsigned long flags; + void *bitmap; + bool enabled; + int index; + + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* According to NCSI spec 1.01, the mixed filter table + * isn't supported yet. 
+ */ + cmd = (struct ncsi_cmd_sma_pkt *)skb_network_header(nr->cmd); + enabled = cmd->at_e & 0x1; + ncf = &nc->mac_filter; + bitmap = &ncf->bitmap; + + if (cmd->index == 0 || + cmd->index > ncf->n_uc + ncf->n_mc + ncf->n_mixed) + return -ERANGE; + + index = (cmd->index - 1) * ETH_ALEN; + spin_lock_irqsave(&nc->lock, flags); + if (enabled) { + set_bit(cmd->index - 1, bitmap); + memcpy(&ncf->addrs[index], cmd->mac, ETH_ALEN); + } else { + clear_bit(cmd->index - 1, bitmap); + eth_zero_addr(&ncf->addrs[index]); + } + spin_unlock_irqrestore(&nc->lock, flags); + + return 0; +} + +static int ncsi_rsp_handler_ebf(struct ncsi_request *nr) +{ + struct ncsi_cmd_ebf_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the package and channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if broadcast filter has been enabled */ + ncm = &nc->modes[NCSI_MODE_BC]; + if (ncm->enable) + return 0; + + /* Update to broadcast filter mode */ + cmd = (struct ncsi_cmd_ebf_pkt *)skb_network_header(nr->cmd); + ncm->enable = 1; + ncm->data[0] = ntohl(cmd->mode); + + return 0; +} + +static int ncsi_rsp_handler_dbf(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if broadcast filter isn't enabled */ + ncm = &nc->modes[NCSI_MODE_BC]; + if (!ncm->enable) + return 0; + + /* Update to broadcast filter mode */ + ncm->enable = 0; + ncm->data[0] = 0; + + return 0; +} + +static int ncsi_rsp_handler_egmf(struct ncsi_request *nr) +{ + struct ncsi_cmd_egmf_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if multicast filter has been enabled */ + ncm = &nc->modes[NCSI_MODE_MC]; + if (ncm->enable) + return 0; + + /* Update to multicast filter mode */ + cmd = (struct ncsi_cmd_egmf_pkt *)skb_network_header(nr->cmd); + ncm->enable = 1; + ncm->data[0] = ntohl(cmd->mode); + + return 0; +} + +static int ncsi_rsp_handler_dgmf(struct ncsi_request *nr) +{ + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Check if multicast filter has been enabled */ + ncm = &nc->modes[NCSI_MODE_MC]; + if (!ncm->enable) + return 0; + + /* Update to multicast filter mode */ + ncm->enable = 0; + ncm->data[0] = 0; + + return 0; +} + +static int ncsi_rsp_handler_snfc(struct ncsi_request *nr) +{ + struct ncsi_cmd_snfc_pkt *cmd; + struct ncsi_rsp_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_mode *ncm; + + /* Find the channel */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + 
return -ENODEV; + + /* Check if flow control has been enabled */ + ncm = &nc->modes[NCSI_MODE_FC]; + if (ncm->enable) + return 0; + + /* Update to flow control mode */ + cmd = (struct ncsi_cmd_snfc_pkt *)skb_network_header(nr->cmd); + ncm->enable = 1; + ncm->data[0] = cmd->mode; + + return 0; +} + +/* Response handler for Get Mac Address command */ +static int ncsi_rsp_handler_oem_gma(struct ncsi_request *nr, int mfr_id) +{ + struct ncsi_dev_priv *ndp = nr->ndp; + struct net_device *ndev = ndp->ndev.dev; + struct ncsi_rsp_oem_pkt *rsp; + struct sockaddr saddr; + u32 mac_addr_off = 0; + int ret = 0; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + + saddr.sa_family = ndev->type; + ndev->priv_flags |= IFF_LIVE_ADDR_CHANGE; + if (mfr_id == NCSI_OEM_MFR_BCM_ID) + mac_addr_off = BCM_MAC_ADDR_OFFSET; + else if (mfr_id == NCSI_OEM_MFR_MLX_ID) + mac_addr_off = MLX_MAC_ADDR_OFFSET; + else if (mfr_id == NCSI_OEM_MFR_INTEL_ID) + mac_addr_off = INTEL_MAC_ADDR_OFFSET; + + memcpy(saddr.sa_data, &rsp->data[mac_addr_off], ETH_ALEN); + if (mfr_id == NCSI_OEM_MFR_BCM_ID || mfr_id == NCSI_OEM_MFR_INTEL_ID) + eth_addr_inc((u8 *)saddr.sa_data); + if (!is_valid_ether_addr((const u8 *)saddr.sa_data)) + return -ENXIO; + + /* Set the flag for GMA command which should only be called once */ + ndp->gma_flag = 1; + + rtnl_lock(); + ret = dev_set_mac_address(ndev, &saddr, NULL); + rtnl_unlock(); + if (ret < 0) + netdev_warn(ndev, "NCSI: 'Writing mac address to device failed\n"); + + return ret; +} + +/* Response handler for Mellanox card */ +static int ncsi_rsp_handler_oem_mlx(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_mlx_pkt *mlx; + struct ncsi_rsp_oem_pkt *rsp; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + mlx = (struct ncsi_rsp_oem_mlx_pkt *)(rsp->data); + + if (mlx->cmd == NCSI_OEM_MLX_CMD_GMA && + mlx->param == NCSI_OEM_MLX_CMD_GMA_PARAM) + return ncsi_rsp_handler_oem_gma(nr, NCSI_OEM_MFR_MLX_ID); + return 0; +} + +/* Response handler for Broadcom card */ +static int ncsi_rsp_handler_oem_bcm(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_bcm_pkt *bcm; + struct ncsi_rsp_oem_pkt *rsp; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + bcm = (struct ncsi_rsp_oem_bcm_pkt *)(rsp->data); + + if (bcm->type == NCSI_OEM_BCM_CMD_GMA) + return ncsi_rsp_handler_oem_gma(nr, NCSI_OEM_MFR_BCM_ID); + return 0; +} + +/* Response handler for Intel card */ +static int ncsi_rsp_handler_oem_intel(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_intel_pkt *intel; + struct ncsi_rsp_oem_pkt *rsp; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + intel = (struct ncsi_rsp_oem_intel_pkt *)(rsp->data); + + if (intel->cmd == NCSI_OEM_INTEL_CMD_GMA) + return ncsi_rsp_handler_oem_gma(nr, NCSI_OEM_MFR_INTEL_ID); + + return 0; +} + +static struct ncsi_rsp_oem_handler { + unsigned int mfr_id; + int (*handler)(struct ncsi_request *nr); +} ncsi_rsp_oem_handlers[] = { + { NCSI_OEM_MFR_MLX_ID, ncsi_rsp_handler_oem_mlx }, + { NCSI_OEM_MFR_BCM_ID, ncsi_rsp_handler_oem_bcm }, + { NCSI_OEM_MFR_INTEL_ID, ncsi_rsp_handler_oem_intel } +}; + +/* Response handler for OEM command */ +static int ncsi_rsp_handler_oem(struct ncsi_request *nr) +{ + struct ncsi_rsp_oem_handler *nrh = NULL; + struct ncsi_rsp_oem_pkt *rsp; + unsigned int mfr_id, i; + + /* Get the response header */ + rsp = (struct ncsi_rsp_oem_pkt *)skb_network_header(nr->rsp); + mfr_id = 
ntohl(rsp->mfr_id); + + /* Check for manufacturer id and Find the handler */ + for (i = 0; i < ARRAY_SIZE(ncsi_rsp_oem_handlers); i++) { + if (ncsi_rsp_oem_handlers[i].mfr_id == mfr_id) { + if (ncsi_rsp_oem_handlers[i].handler) + nrh = &ncsi_rsp_oem_handlers[i]; + else + nrh = NULL; + + break; + } + } + + if (!nrh) { + netdev_err(nr->ndp->ndev.dev, "Received unrecognized OEM packet with MFR-ID (0x%x)\n", + mfr_id); + return -ENOENT; + } + + /* Process the packet */ + return nrh->handler(nr); +} + +static int ncsi_rsp_handler_gvi(struct ncsi_request *nr) +{ + struct ncsi_rsp_gvi_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_version *ncv; + int i; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gvi_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update channel's version info + * + * Major, minor, and update fields are supposed to be + * unsigned integers encoded as packed BCD. + * + * Alpha1 and alpha2 are ISO/IEC 8859-1 characters. + */ + ncv = &nc->version; + ncv->major = decode_bcd_u8(rsp->major); + ncv->minor = decode_bcd_u8(rsp->minor); + ncv->update = decode_bcd_u8(rsp->update); + ncv->alpha1 = rsp->alpha1; + ncv->alpha2 = rsp->alpha2; + memcpy(ncv->fw_name, rsp->fw_name, 12); + ncv->fw_version = ntohl(rsp->fw_version); + for (i = 0; i < ARRAY_SIZE(ncv->pci_ids); i++) + ncv->pci_ids[i] = ntohs(rsp->pci_ids[i]); + ncv->mf_id = ntohl(rsp->mf_id); + + return 0; +} + +static int ncsi_rsp_handler_gc(struct ncsi_request *nr) +{ + struct ncsi_rsp_gc_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + size_t size; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gc_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update channel's capabilities */ + nc->caps[NCSI_CAP_GENERIC].cap = ntohl(rsp->cap) & + NCSI_CAP_GENERIC_MASK; + nc->caps[NCSI_CAP_BC].cap = ntohl(rsp->bc_cap) & + NCSI_CAP_BC_MASK; + nc->caps[NCSI_CAP_MC].cap = ntohl(rsp->mc_cap) & + NCSI_CAP_MC_MASK; + nc->caps[NCSI_CAP_BUFFER].cap = ntohl(rsp->buf_cap); + nc->caps[NCSI_CAP_AEN].cap = ntohl(rsp->aen_cap) & + NCSI_CAP_AEN_MASK; + nc->caps[NCSI_CAP_VLAN].cap = rsp->vlan_mode & + NCSI_CAP_VLAN_MASK; + + size = (rsp->uc_cnt + rsp->mc_cnt + rsp->mixed_cnt) * ETH_ALEN; + nc->mac_filter.addrs = kzalloc(size, GFP_ATOMIC); + if (!nc->mac_filter.addrs) + return -ENOMEM; + nc->mac_filter.n_uc = rsp->uc_cnt; + nc->mac_filter.n_mc = rsp->mc_cnt; + nc->mac_filter.n_mixed = rsp->mixed_cnt; + + nc->vlan_filter.vids = kcalloc(rsp->vlan_cnt, + sizeof(*nc->vlan_filter.vids), + GFP_ATOMIC); + if (!nc->vlan_filter.vids) + return -ENOMEM; + /* Set VLAN filters active so they are cleared in the first + * configuration state + */ + nc->vlan_filter.bitmap = U64_MAX; + nc->vlan_filter.n_vids = rsp->vlan_cnt; + + return 0; +} + +static int ncsi_rsp_handler_gp(struct ncsi_request *nr) +{ + struct ncsi_channel_vlan_filter *ncvf; + struct ncsi_channel_mac_filter *ncmf; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_rsp_gp_pkt *rsp; + struct ncsi_channel *nc; + unsigned short enable; + unsigned char *pdata; + unsigned long flags; + void *bitmap; + int i; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Modes with explicit enabled 
indications */ + if (ntohl(rsp->valid_modes) & 0x1) { /* BC filter mode */ + nc->modes[NCSI_MODE_BC].enable = 1; + nc->modes[NCSI_MODE_BC].data[0] = ntohl(rsp->bc_mode); + } + if (ntohl(rsp->valid_modes) & 0x2) /* Channel enabled */ + nc->modes[NCSI_MODE_ENABLE].enable = 1; + if (ntohl(rsp->valid_modes) & 0x4) /* Channel Tx enabled */ + nc->modes[NCSI_MODE_TX_ENABLE].enable = 1; + if (ntohl(rsp->valid_modes) & 0x8) /* MC filter mode */ + nc->modes[NCSI_MODE_MC].enable = 1; + + /* Modes without explicit enabled indications */ + nc->modes[NCSI_MODE_LINK].enable = 1; + nc->modes[NCSI_MODE_LINK].data[0] = ntohl(rsp->link_mode); + nc->modes[NCSI_MODE_VLAN].enable = 1; + nc->modes[NCSI_MODE_VLAN].data[0] = rsp->vlan_mode; + nc->modes[NCSI_MODE_FC].enable = 1; + nc->modes[NCSI_MODE_FC].data[0] = rsp->fc_mode; + nc->modes[NCSI_MODE_AEN].enable = 1; + nc->modes[NCSI_MODE_AEN].data[0] = ntohl(rsp->aen_mode); + + /* MAC addresses filter table */ + pdata = (unsigned char *)rsp + 48; + enable = rsp->mac_enable; + ncmf = &nc->mac_filter; + spin_lock_irqsave(&nc->lock, flags); + bitmap = &ncmf->bitmap; + for (i = 0; i < rsp->mac_cnt; i++, pdata += 6) { + if (!(enable & (0x1 << i))) + clear_bit(i, bitmap); + else + set_bit(i, bitmap); + + memcpy(&ncmf->addrs[i * ETH_ALEN], pdata, ETH_ALEN); + } + spin_unlock_irqrestore(&nc->lock, flags); + + /* VLAN filter table */ + enable = ntohs(rsp->vlan_enable); + ncvf = &nc->vlan_filter; + bitmap = &ncvf->bitmap; + spin_lock_irqsave(&nc->lock, flags); + for (i = 0; i < rsp->vlan_cnt; i++, pdata += 2) { + if (!(enable & (0x1 << i))) + clear_bit(i, bitmap); + else + set_bit(i, bitmap); + + ncvf->vids[i] = ntohs(*(__be16 *)pdata); + } + spin_unlock_irqrestore(&nc->lock, flags); + + return 0; +} + +static int ncsi_rsp_handler_gcps(struct ncsi_request *nr) +{ + struct ncsi_rsp_gcps_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_stats *ncs; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gcps_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update HNC's statistics */ + ncs = &nc->stats; + ncs->hnc_cnt_hi = ntohl(rsp->cnt_hi); + ncs->hnc_cnt_lo = ntohl(rsp->cnt_lo); + ncs->hnc_rx_bytes = ntohl(rsp->rx_bytes); + ncs->hnc_tx_bytes = ntohl(rsp->tx_bytes); + ncs->hnc_rx_uc_pkts = ntohl(rsp->rx_uc_pkts); + ncs->hnc_rx_mc_pkts = ntohl(rsp->rx_mc_pkts); + ncs->hnc_rx_bc_pkts = ntohl(rsp->rx_bc_pkts); + ncs->hnc_tx_uc_pkts = ntohl(rsp->tx_uc_pkts); + ncs->hnc_tx_mc_pkts = ntohl(rsp->tx_mc_pkts); + ncs->hnc_tx_bc_pkts = ntohl(rsp->tx_bc_pkts); + ncs->hnc_fcs_err = ntohl(rsp->fcs_err); + ncs->hnc_align_err = ntohl(rsp->align_err); + ncs->hnc_false_carrier = ntohl(rsp->false_carrier); + ncs->hnc_runt_pkts = ntohl(rsp->runt_pkts); + ncs->hnc_jabber_pkts = ntohl(rsp->jabber_pkts); + ncs->hnc_rx_pause_xon = ntohl(rsp->rx_pause_xon); + ncs->hnc_rx_pause_xoff = ntohl(rsp->rx_pause_xoff); + ncs->hnc_tx_pause_xon = ntohl(rsp->tx_pause_xon); + ncs->hnc_tx_pause_xoff = ntohl(rsp->tx_pause_xoff); + ncs->hnc_tx_s_collision = ntohl(rsp->tx_s_collision); + ncs->hnc_tx_m_collision = ntohl(rsp->tx_m_collision); + ncs->hnc_l_collision = ntohl(rsp->l_collision); + ncs->hnc_e_collision = ntohl(rsp->e_collision); + ncs->hnc_rx_ctl_frames = ntohl(rsp->rx_ctl_frames); + ncs->hnc_rx_64_frames = ntohl(rsp->rx_64_frames); + ncs->hnc_rx_127_frames = ntohl(rsp->rx_127_frames); + ncs->hnc_rx_255_frames = ntohl(rsp->rx_255_frames); + ncs->hnc_rx_511_frames = 
ntohl(rsp->rx_511_frames); + ncs->hnc_rx_1023_frames = ntohl(rsp->rx_1023_frames); + ncs->hnc_rx_1522_frames = ntohl(rsp->rx_1522_frames); + ncs->hnc_rx_9022_frames = ntohl(rsp->rx_9022_frames); + ncs->hnc_tx_64_frames = ntohl(rsp->tx_64_frames); + ncs->hnc_tx_127_frames = ntohl(rsp->tx_127_frames); + ncs->hnc_tx_255_frames = ntohl(rsp->tx_255_frames); + ncs->hnc_tx_511_frames = ntohl(rsp->tx_511_frames); + ncs->hnc_tx_1023_frames = ntohl(rsp->tx_1023_frames); + ncs->hnc_tx_1522_frames = ntohl(rsp->tx_1522_frames); + ncs->hnc_tx_9022_frames = ntohl(rsp->tx_9022_frames); + ncs->hnc_rx_valid_bytes = ntohl(rsp->rx_valid_bytes); + ncs->hnc_rx_runt_pkts = ntohl(rsp->rx_runt_pkts); + ncs->hnc_rx_jabber_pkts = ntohl(rsp->rx_jabber_pkts); + + return 0; +} + +static int ncsi_rsp_handler_gns(struct ncsi_request *nr) +{ + struct ncsi_rsp_gns_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_stats *ncs; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gns_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update HNC's statistics */ + ncs = &nc->stats; + ncs->ncsi_rx_cmds = ntohl(rsp->rx_cmds); + ncs->ncsi_dropped_cmds = ntohl(rsp->dropped_cmds); + ncs->ncsi_cmd_type_errs = ntohl(rsp->cmd_type_errs); + ncs->ncsi_cmd_csum_errs = ntohl(rsp->cmd_csum_errs); + ncs->ncsi_rx_pkts = ntohl(rsp->rx_pkts); + ncs->ncsi_tx_pkts = ntohl(rsp->tx_pkts); + ncs->ncsi_tx_aen_pkts = ntohl(rsp->tx_aen_pkts); + + return 0; +} + +static int ncsi_rsp_handler_gnpts(struct ncsi_request *nr) +{ + struct ncsi_rsp_gnpts_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_channel *nc; + struct ncsi_channel_stats *ncs; + + /* Find the channel */ + rsp = (struct ncsi_rsp_gnpts_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + NULL, &nc); + if (!nc) + return -ENODEV; + + /* Update HNC's statistics */ + ncs = &nc->stats; + ncs->pt_tx_pkts = ntohl(rsp->tx_pkts); + ncs->pt_tx_dropped = ntohl(rsp->tx_dropped); + ncs->pt_tx_channel_err = ntohl(rsp->tx_channel_err); + ncs->pt_tx_us_err = ntohl(rsp->tx_us_err); + ncs->pt_rx_pkts = ntohl(rsp->rx_pkts); + ncs->pt_rx_dropped = ntohl(rsp->rx_dropped); + ncs->pt_rx_channel_err = ntohl(rsp->rx_channel_err); + ncs->pt_rx_us_err = ntohl(rsp->rx_us_err); + ncs->pt_rx_os_err = ntohl(rsp->rx_os_err); + + return 0; +} + +static int ncsi_rsp_handler_gps(struct ncsi_request *nr) +{ + struct ncsi_rsp_gps_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_package *np; + + /* Find the package */ + rsp = (struct ncsi_rsp_gps_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, NULL); + if (!np) + return -ENODEV; + + return 0; +} + +static int ncsi_rsp_handler_gpuuid(struct ncsi_request *nr) +{ + struct ncsi_rsp_gpuuid_pkt *rsp; + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_package *np; + + /* Find the package */ + rsp = (struct ncsi_rsp_gpuuid_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, NULL); + if (!np) + return -ENODEV; + + memcpy(np->uuid, rsp->uuid, sizeof(rsp->uuid)); + + return 0; +} + +static int ncsi_rsp_handler_pldm(struct ncsi_request *nr) +{ + return 0; +} + +static int ncsi_rsp_handler_netlink(struct ncsi_request *nr) +{ + struct ncsi_dev_priv *ndp = nr->ndp; + struct ncsi_rsp_pkt *rsp; + struct ncsi_package *np; + struct ncsi_channel *nc; + int ret; + + /* Find the 
package */ + rsp = (struct ncsi_rsp_pkt *)skb_network_header(nr->rsp); + ncsi_find_package_and_channel(ndp, rsp->rsp.common.channel, + &np, &nc); + if (!np) + return -ENODEV; + + ret = ncsi_send_netlink_rsp(nr, np, nc); + + return ret; +} + +static struct ncsi_rsp_handler { + unsigned char type; + int payload; + int (*handler)(struct ncsi_request *nr); +} ncsi_rsp_handlers[] = { + { NCSI_PKT_RSP_CIS, 4, ncsi_rsp_handler_cis }, + { NCSI_PKT_RSP_SP, 4, ncsi_rsp_handler_sp }, + { NCSI_PKT_RSP_DP, 4, ncsi_rsp_handler_dp }, + { NCSI_PKT_RSP_EC, 4, ncsi_rsp_handler_ec }, + { NCSI_PKT_RSP_DC, 4, ncsi_rsp_handler_dc }, + { NCSI_PKT_RSP_RC, 4, ncsi_rsp_handler_rc }, + { NCSI_PKT_RSP_ECNT, 4, ncsi_rsp_handler_ecnt }, + { NCSI_PKT_RSP_DCNT, 4, ncsi_rsp_handler_dcnt }, + { NCSI_PKT_RSP_AE, 4, ncsi_rsp_handler_ae }, + { NCSI_PKT_RSP_SL, 4, ncsi_rsp_handler_sl }, + { NCSI_PKT_RSP_GLS, 16, ncsi_rsp_handler_gls }, + { NCSI_PKT_RSP_SVF, 4, ncsi_rsp_handler_svf }, + { NCSI_PKT_RSP_EV, 4, ncsi_rsp_handler_ev }, + { NCSI_PKT_RSP_DV, 4, ncsi_rsp_handler_dv }, + { NCSI_PKT_RSP_SMA, 4, ncsi_rsp_handler_sma }, + { NCSI_PKT_RSP_EBF, 4, ncsi_rsp_handler_ebf }, + { NCSI_PKT_RSP_DBF, 4, ncsi_rsp_handler_dbf }, + { NCSI_PKT_RSP_EGMF, 4, ncsi_rsp_handler_egmf }, + { NCSI_PKT_RSP_DGMF, 4, ncsi_rsp_handler_dgmf }, + { NCSI_PKT_RSP_SNFC, 4, ncsi_rsp_handler_snfc }, + { NCSI_PKT_RSP_GVI, 40, ncsi_rsp_handler_gvi }, + { NCSI_PKT_RSP_GC, 32, ncsi_rsp_handler_gc }, + { NCSI_PKT_RSP_GP, -1, ncsi_rsp_handler_gp }, + { NCSI_PKT_RSP_GCPS, 204, ncsi_rsp_handler_gcps }, + { NCSI_PKT_RSP_GNS, 32, ncsi_rsp_handler_gns }, + { NCSI_PKT_RSP_GNPTS, 48, ncsi_rsp_handler_gnpts }, + { NCSI_PKT_RSP_GPS, 8, ncsi_rsp_handler_gps }, + { NCSI_PKT_RSP_OEM, -1, ncsi_rsp_handler_oem }, + { NCSI_PKT_RSP_PLDM, -1, ncsi_rsp_handler_pldm }, + { NCSI_PKT_RSP_GPUUID, 20, ncsi_rsp_handler_gpuuid }, + { NCSI_PKT_RSP_QPNPR, -1, ncsi_rsp_handler_pldm }, + { NCSI_PKT_RSP_SNPR, -1, ncsi_rsp_handler_pldm } +}; + +int ncsi_rcv_rsp(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct ncsi_rsp_handler *nrh = NULL; + struct ncsi_dev *nd; + struct ncsi_dev_priv *ndp; + struct ncsi_request *nr; + struct ncsi_pkt_hdr *hdr; + unsigned long flags; + int payload, i, ret; + + /* Find the NCSI device */ + nd = ncsi_find_dev(orig_dev); + ndp = nd ? 
TO_NCSI_DEV_PRIV(nd) : NULL; + if (!ndp) + return -ENODEV; + + /* Check if it is AEN packet */ + hdr = (struct ncsi_pkt_hdr *)skb_network_header(skb); + if (hdr->type == NCSI_PKT_AEN) + return ncsi_aen_handler(ndp, skb); + + /* Find the handler */ + for (i = 0; i < ARRAY_SIZE(ncsi_rsp_handlers); i++) { + if (ncsi_rsp_handlers[i].type == hdr->type) { + if (ncsi_rsp_handlers[i].handler) + nrh = &ncsi_rsp_handlers[i]; + else + nrh = NULL; + + break; + } + } + + if (!nrh) { + netdev_err(nd->dev, "Received unrecognized packet (0x%x)\n", + hdr->type); + return -ENOENT; + } + + /* Associate with the request */ + spin_lock_irqsave(&ndp->lock, flags); + nr = &ndp->requests[hdr->id]; + if (!nr->used) { + spin_unlock_irqrestore(&ndp->lock, flags); + return -ENODEV; + } + + nr->rsp = skb; + if (!nr->enabled) { + spin_unlock_irqrestore(&ndp->lock, flags); + ret = -ENOENT; + goto out; + } + + /* Validate the packet */ + spin_unlock_irqrestore(&ndp->lock, flags); + payload = nrh->payload; + if (payload < 0) + payload = ntohs(hdr->length); + ret = ncsi_validate_rsp_pkt(nr, payload); + if (ret) { + netdev_warn(ndp->ndev.dev, + "NCSI: 'bad' packet ignored for type 0x%x\n", + hdr->type); + + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + if (ret == -EPERM) + goto out_netlink; + else + ncsi_send_netlink_err(ndp->ndev.dev, + nr->snd_seq, + nr->snd_portid, + &nr->nlhdr, + ret); + } + goto out; + } + + /* Process the packet */ + ret = nrh->handler(nr); + if (ret) + netdev_err(ndp->ndev.dev, + "NCSI: Handler for packet type 0x%x returned %d\n", + hdr->type, ret); + +out_netlink: + if (nr->flags == NCSI_REQ_FLAG_NETLINK_DRIVEN) { + ret = ncsi_rsp_handler_netlink(nr); + if (ret) { + netdev_err(ndp->ndev.dev, + "NCSI: Netlink handler for packet type 0x%x returned %d\n", + hdr->type, ret); + } + } + +out: + ncsi_free_request(nr); + return ret; +} diff --git a/net/netfilter/Kconfig b/net/netfilter/Kconfig new file mode 100644 index 000000000..4b8d04640 --- /dev/null +++ b/net/netfilter/Kconfig @@ -0,0 +1,1666 @@ +# SPDX-License-Identifier: GPL-2.0-only +menu "Core Netfilter Configuration" + depends on INET && NETFILTER + +config NETFILTER_INGRESS + bool "Netfilter ingress support" + default y + select NET_INGRESS + help + This allows you to classify packets from ingress using the Netfilter + infrastructure. + +config NETFILTER_EGRESS + bool "Netfilter egress support" + default y + select NET_EGRESS + help + This allows you to classify packets before transmission using the + Netfilter infrastructure. + +config NETFILTER_SKIP_EGRESS + def_bool NETFILTER_EGRESS && (NET_CLS_ACT || IFB) + +config NETFILTER_NETLINK + tristate + +config NETFILTER_FAMILY_BRIDGE + bool + +config NETFILTER_FAMILY_ARP + bool + +config NETFILTER_NETLINK_HOOK + tristate "Netfilter base hook dump support" + depends on NETFILTER_ADVANCED + depends on NF_TABLES + select NETFILTER_NETLINK + help + If this option is enabled, the kernel will include support + to list the base netfilter hooks via NFNETLINK. + This is helpful for debugging. + +config NETFILTER_NETLINK_ACCT + tristate "Netfilter NFACCT over NFNETLINK interface" + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK + help + If this option is enabled, the kernel will include support + for extended accounting via NFNETLINK. + +config NETFILTER_NETLINK_QUEUE + tristate "Netfilter NFQUEUE over NFNETLINK interface" + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK + help + If this option is enabled, the kernel will include support + for queueing packets via NFNETLINK. 
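For illustration, the NFQUEUE facility described above hands queued packets to a userspace program, which must return a verdict for each one. A minimal consumer sketch (not part of the kernel sources being added here) could look roughly like the following, assuming the userspace libnetfilter_queue library is installed and a rule such as "iptables -A INPUT -j NFQUEUE --queue-num 0" steers traffic to queue 0:

    #include <stdio.h>
    #include <stdlib.h>
    #include <stdint.h>
    #include <arpa/inet.h>
    #include <sys/socket.h>
    #include <linux/netfilter.h>                    /* NF_ACCEPT */
    #include <libnetfilter_queue/libnetfilter_queue.h>

    /* Verdict callback: accept every queued packet. A real program would
     * inspect the payload (nfq_get_payload()) before deciding. */
    static int queue_cb(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg,
                        struct nfq_data *nfa, void *data)
    {
            struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr(nfa);
            uint32_t id = ph ? ntohl(ph->packet_id) : 0;

            return nfq_set_verdict(qh, id, NF_ACCEPT, 0, NULL);
    }

    int main(void)
    {
            struct nfq_handle *h = nfq_open();
            struct nfq_q_handle *qh;
            char buf[4096];
            int fd, rv;

            if (!h)
                    exit(EXIT_FAILURE);

            qh = nfq_create_queue(h, 0, &queue_cb, NULL);   /* queue number 0 */
            if (!qh)
                    exit(EXIT_FAILURE);

            /* Have the kernel copy the full packet to userspace */
            nfq_set_mode(qh, NFQNL_COPY_PACKET, 0xffff);

            fd = nfq_fd(h);
            while ((rv = recv(fd, buf, sizeof(buf), 0)) >= 0)
                    nfq_handle_packet(h, buf, rv);

            nfq_destroy_queue(qh);
            nfq_close(h);
            return 0;
    }

The queue number and the iptables rule above are placeholders chosen for the sketch; any queue number agreed between the ruleset and the program works the same way.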
+
+config NETFILTER_NETLINK_LOG
+	tristate "Netfilter LOG over NFNETLINK interface"
+	default m if NETFILTER_ADVANCED=n
+	select NETFILTER_NETLINK
+	help
+	  If this option is enabled, the kernel will include support
+	  for logging packets via NFNETLINK.
+
+	  This obsoletes the existing ipt_ULOG and ebt_ulog mechanisms,
+	  and is also scheduled to replace the old syslog-based ipt_LOG
+	  and ip6t_LOG modules.
+
+config NETFILTER_NETLINK_OSF
+	tristate "Netfilter OSF over NFNETLINK interface"
+	depends on NETFILTER_ADVANCED
+	select NETFILTER_NETLINK
+	help
+	  If this option is enabled, the kernel will include support
+	  for passive OS fingerprinting via NFNETLINK.
+
+config NF_CONNTRACK
+	tristate "Netfilter connection tracking support"
+	default m if NETFILTER_ADVANCED=n
+	select NF_DEFRAG_IPV4
+	select NF_DEFRAG_IPV6 if IPV6 != n
+	help
+	  Connection tracking keeps a record of what packets have passed
+	  through your machine, in order to figure out how they are related
+	  into connections.
+
+	  This is required to do Masquerading or other kinds of Network
+	  Address Translation. It can also be used to enhance packet
+	  filtering (see `Connection state match support' below).
+
+	  To compile it as a module, choose M here. If unsure, say N.
+
+config NF_LOG_SYSLOG
+	tristate "Syslog packet logging"
+	default m if NETFILTER_ADVANCED=n
+	help
+	  This option enables support for packet logging via syslog.
+	  It supports IPv4, IPv6, ARP and common transport protocols such
+	  as TCP and UDP.
+	  This is a simpler but less flexible logging method compared to
+	  CONFIG_NETFILTER_NETLINK_LOG.
+	  If both are enabled the backend to use can be configured at run-time
+	  by means of per-address-family sysctl tunables.
+
+if NF_CONNTRACK
+config NETFILTER_CONNCOUNT
+	tristate
+
+config NF_CONNTRACK_MARK
+	bool 'Connection mark tracking support'
+	depends on NETFILTER_ADVANCED
+	help
+	  This option enables support for connection marks, used by the
+	  `CONNMARK' target and `connmark' match. Similar to the mark value
+	  of packets, but this mark value is kept in the conntrack session
+	  instead of the individual packets.
+
+config NF_CONNTRACK_SECMARK
+	bool 'Connection tracking security mark support'
+	depends on NETWORK_SECMARK
+	default y if NETFILTER_ADVANCED=n
+	help
+	  This option enables security markings to be applied to
+	  connections. Typically they are copied to connections from
+	  packets using the CONNSECMARK target and copied back from
+	  connections to packets with the same target, with the packets
+	  being originally labeled via SECMARK.
+
+	  If unsure, say 'N'.
+
+config NF_CONNTRACK_ZONES
+	bool 'Connection tracking zones'
+	depends on NETFILTER_ADVANCED
+	help
+	  This option enables support for connection tracking zones.
+	  Normally, each connection needs to have a unique system wide
+	  identity. Connection tracking zones allow multiple connections
+	  to use the same identity, as long as they are contained in
+	  different zones.
+
+	  If unsure, say `N'.
+
+config NF_CONNTRACK_PROCFS
+	bool "Supply CT list in procfs (OBSOLETE)"
+	depends on PROC_FS
+	help
+	  This option enables the list of known conntrack entries
+	  to be shown in procfs under net/netfilter/nf_conntrack. This
+	  is considered obsolete in favor of using the conntrack(8)
+	  tool which uses Netlink.
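For illustration, the conntrack(8) tool mentioned above, like the NF_CONNTRACK_EVENTS and NF_CT_NETLINK options that follow, relies on the netlink interface to connection tracking. A rough userspace event-listener sketch (not part of the kernel sources being added here), assuming the libnetfilter_conntrack library is available, might look like this:

    #include <stdio.h>
    #include <stdlib.h>
    #include <libnetfilter_conntrack/libnetfilter_conntrack.h>

    /* Print one line per conntrack event (new/update/destroy) */
    static int event_cb(enum nf_conntrack_msg_type type,
                        struct nf_conntrack *ct, void *data)
    {
            char buf[1024];

            nfct_snprintf(buf, sizeof(buf), ct, type, NFCT_O_DEFAULT, 0);
            puts(buf);
            return NFCT_CB_CONTINUE;
    }

    int main(void)
    {
            /* Subscribe to all conntrack event groups */
            struct nfct_handle *h = nfct_open(CONNTRACK, NFCT_ALL_CT_GROUPS);

            if (!h) {
                    perror("nfct_open");
                    return EXIT_FAILURE;
            }

            nfct_callback_register(h, NFCT_T_ALL, event_cb, NULL);
            nfct_catch(h);          /* blocks; invokes event_cb per event */

            nfct_close(h);
            return 0;
    }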
+ +config NF_CONNTRACK_EVENTS + bool "Connection tracking events" + depends on NETFILTER_ADVANCED + help + If this option is enabled, the connection tracking code will + provide a notifier chain that can be used by other kernel code + to get notified about changes in the connection tracking state. + + If unsure, say `N'. + +config NF_CONNTRACK_TIMEOUT + bool 'Connection tracking timeout' + depends on NETFILTER_ADVANCED + help + This option enables support for connection tracking timeout + extension. This allows you to attach timeout policies to flow + via the CT target. + + If unsure, say `N'. + +config NF_CONNTRACK_TIMESTAMP + bool 'Connection tracking timestamping' + depends on NETFILTER_ADVANCED + help + This option enables support for connection tracking timestamping. + This allows you to store the flow start-time and to obtain + the flow-stop time (once it has been destroyed) via Connection + tracking events. + + If unsure, say `N'. + +config NF_CONNTRACK_LABELS + bool "Connection tracking labels" + help + This option enables support for assigning user-defined flag bits + to connection tracking entries. It can be used with xtables connlabel + match and the nftables ct expression. + +config NF_CT_PROTO_DCCP + bool 'DCCP protocol connection tracking support' + depends on NETFILTER_ADVANCED + default y + help + With this option enabled, the layer 3 independent connection + tracking code will be able to do state tracking on DCCP connections. + + If unsure, say Y. + +config NF_CT_PROTO_GRE + bool + +config NF_CT_PROTO_SCTP + bool 'SCTP protocol connection tracking support' + depends on NETFILTER_ADVANCED + default y + select LIBCRC32C + help + With this option enabled, the layer 3 independent connection + tracking code will be able to do state tracking on SCTP connections. + + If unsure, say Y. + +config NF_CT_PROTO_UDPLITE + bool 'UDP-Lite protocol connection tracking support' + depends on NETFILTER_ADVANCED + default y + help + With this option enabled, the layer 3 independent connection + tracking code will be able to do state tracking on UDP-Lite + connections. + + If unsure, say Y. + +config NF_CONNTRACK_AMANDA + tristate "Amanda backup protocol support" + depends on NETFILTER_ADVANCED + select TEXTSEARCH + select TEXTSEARCH_KMP + help + If you are running the Amanda backup package + on this machine or machines that will be MASQUERADED through this + machine, then you may want to enable this feature. This allows the + connection tracking and natting code to allow the sub-channels that + Amanda requires for communication of the backup data, messages and + index. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_FTP + tristate "FTP protocol support" + default m if NETFILTER_ADVANCED=n + help + Tracking FTP connections is problematic: special helpers are + required for tracking them, and doing masquerading and other forms + of Network Address Translation on them. + + This is FTP support on Layer 3 independent connection tracking. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_H323 + tristate "H.323 protocol support" + depends on IPV6 || IPV6=n + depends on NETFILTER_ADVANCED + help + H.323 is a VoIP signalling protocol from ITU-T. As one of the most + important VoIP protocols, it is widely used by voice hardware and + software including voice gateways, IP phones, Netmeeting, OpenPhone, + Gnomemeeting, etc. + + With this module you can support H.323 on a connection tracking/NAT + firewall. 
+ + This module supports RAS, Fast Start, H.245 Tunnelling, Call + Forwarding, RTP/RTCP and T.120 based audio, video, fax, chat, + whiteboard, file transfer, etc. For more information, please + visit http://nath323.sourceforge.net/. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_IRC + tristate "IRC protocol support" + default m if NETFILTER_ADVANCED=n + help + There is a commonly-used extension to IRC called + Direct Client-to-Client Protocol (DCC). This enables users to send + files to each other, and also chat to each other without the need + of a server. DCC Sending is used anywhere you send files over IRC, + and DCC Chat is most commonly used by Eggdrop bots. If you are + using NAT, this extension will enable you to send files and initiate + chats. Note that you do NOT need this extension to get files or + have others initiate chats, or everything else in IRC. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_BROADCAST + tristate + +config NF_CONNTRACK_NETBIOS_NS + tristate "NetBIOS name service protocol support" + select NF_CONNTRACK_BROADCAST + help + NetBIOS name service requests are sent as broadcast messages from an + unprivileged port and responded to with unicast messages to the + same port. This make them hard to firewall properly because connection + tracking doesn't deal with broadcasts. This helper tracks locally + originating NetBIOS name service requests and the corresponding + responses. It relies on correct IP address configuration, specifically + netmask and broadcast address. When properly configured, the output + of "ip address show" should look similar to this: + + $ ip -4 address show eth0 + 4: eth0: mtu 1500 qdisc pfifo_fast qlen 1000 + inet 172.16.2.252/24 brd 172.16.2.255 scope global eth0 + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_SNMP + tristate "SNMP service protocol support" + depends on NETFILTER_ADVANCED + select NF_CONNTRACK_BROADCAST + help + SNMP service requests are sent as broadcast messages from an + unprivileged port and responded to with unicast messages to the + same port. This make them hard to firewall properly because connection + tracking doesn't deal with broadcasts. This helper tracks locally + originating SNMP service requests and the corresponding + responses. It relies on correct IP address configuration, specifically + netmask and broadcast address. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_PPTP + tristate "PPtP protocol support" + depends on NETFILTER_ADVANCED + select NF_CT_PROTO_GRE + help + This module adds support for PPTP (Point to Point Tunnelling + Protocol, RFC2637) connection tracking and NAT. + + If you are running PPTP sessions over a stateful firewall or NAT + box, you may want to enable this feature. + + Please note that not all PPTP modes of operation are supported yet. + Specifically these limitations exist: + - Blindly assumes that control connections are always established + in PNS->PAC direction. This is a violation of RFC2637. + - Only supports a single call within each session + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_SANE + tristate "SANE protocol support" + depends on NETFILTER_ADVANCED + help + SANE is a protocol for remote access to scanners as implemented + by the 'saned' daemon. Like FTP, it uses separate control and + data connections. 
+ + With this module you can support SANE on a connection tracking + firewall. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_SIP + tristate "SIP protocol support" + default m if NETFILTER_ADVANCED=n + help + SIP is an application-layer control protocol that can establish, + modify, and terminate multimedia sessions (conferences) such as + Internet telephony calls. With the nf_conntrack_sip and + the nf_nat_sip modules you can support the protocol on a connection + tracking/NATing firewall. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CONNTRACK_TFTP + tristate "TFTP protocol support" + depends on NETFILTER_ADVANCED + help + TFTP connection tracking helper, this is required depending + on how restrictive your ruleset is. + If you are using a tftp client behind -j SNAT or -j MASQUERADING + you will need this. + + To compile it as a module, choose M here. If unsure, say N. + +config NF_CT_NETLINK + tristate 'Connection tracking netlink interface' + select NETFILTER_NETLINK + default m if NETFILTER_ADVANCED=n + help + This option enables support for a netlink-based userspace interface + +config NF_CT_NETLINK_TIMEOUT + tristate 'Connection tracking timeout tuning via Netlink' + select NETFILTER_NETLINK + depends on NETFILTER_ADVANCED + depends on NF_CONNTRACK_TIMEOUT + help + This option enables support for connection tracking timeout + fine-grain tuning. This allows you to attach specific timeout + policies to flows, instead of using the global timeout policy. + + If unsure, say `N'. + +config NF_CT_NETLINK_HELPER + tristate 'Connection tracking helpers in user-space via Netlink' + select NETFILTER_NETLINK + depends on NF_CT_NETLINK + depends on NETFILTER_NETLINK_QUEUE + depends on NETFILTER_NETLINK_GLUE_CT + depends on NETFILTER_ADVANCED + help + This option enables the user-space connection tracking helpers + infrastructure. + + If unsure, say `N'. + +config NETFILTER_NETLINK_GLUE_CT + bool "NFQUEUE and NFLOG integration with Connection Tracking" + default n + depends on (NETFILTER_NETLINK_QUEUE || NETFILTER_NETLINK_LOG) && NF_CT_NETLINK + help + If this option is enabled, NFQUEUE and NFLOG can include + Connection Tracking information together with the packet is + the enqueued via NFNETLINK. + +config NF_NAT + tristate "Network Address Translation support" + depends on NF_CONNTRACK + default m if NETFILTER_ADVANCED=n + help + The NAT option allows masquerading, port forwarding and other + forms of full Network Address Port Translation. This can be + controlled by iptables, ip6tables or nft. + +config NF_NAT_AMANDA + tristate + depends on NF_CONNTRACK && NF_NAT + default NF_NAT && NF_CONNTRACK_AMANDA + +config NF_NAT_FTP + tristate + depends on NF_CONNTRACK && NF_NAT + default NF_NAT && NF_CONNTRACK_FTP + +config NF_NAT_IRC + tristate + depends on NF_CONNTRACK && NF_NAT + default NF_NAT && NF_CONNTRACK_IRC + +config NF_NAT_SIP + tristate + depends on NF_CONNTRACK && NF_NAT + default NF_NAT && NF_CONNTRACK_SIP + +config NF_NAT_TFTP + tristate + depends on NF_CONNTRACK && NF_NAT + default NF_NAT && NF_CONNTRACK_TFTP + +config NF_NAT_REDIRECT + bool + +config NF_NAT_MASQUERADE + bool + +config NETFILTER_SYNPROXY + tristate + +endif # NF_CONNTRACK + +config NF_TABLES + select NETFILTER_NETLINK + select LIBCRC32C + tristate "Netfilter nf_tables support" + help + nftables is the new packet classification framework that intends to + replace the existing {ip,ip6,arp,eb}_tables infrastructure. 
It + provides a pseudo-state machine with an extensible instruction-set + (also known as expressions) that the userspace 'nft' utility + (https://www.netfilter.org/projects/nftables) uses to build the + rule-set. It also comes with the generic set infrastructure that + allows you to construct mappings between matchings and actions + for performance lookups. + + To compile it as a module, choose M here. + +if NF_TABLES +config NF_TABLES_INET + depends on IPV6 + select NF_TABLES_IPV4 + select NF_TABLES_IPV6 + bool "Netfilter nf_tables mixed IPv4/IPv6 tables support" + help + This option enables support for a mixed IPv4/IPv6 "inet" table. + +config NF_TABLES_NETDEV + bool "Netfilter nf_tables netdev tables support" + help + This option enables support for the "netdev" table. + +config NFT_NUMGEN + tristate "Netfilter nf_tables number generator module" + help + This option adds the number generator expression used to perform + incremental counting and random numbers bound to a upper limit. + +config NFT_CT + depends on NF_CONNTRACK + tristate "Netfilter nf_tables conntrack module" + help + This option adds the "ct" expression that you can use to match + connection tracking information such as the flow state. + +config NFT_FLOW_OFFLOAD + depends on NF_CONNTRACK && NF_FLOW_TABLE + tristate "Netfilter nf_tables hardware flow offload module" + help + This option adds the "flow_offload" expression that you can use to + choose what flows are placed into the hardware. + +config NFT_CONNLIMIT + tristate "Netfilter nf_tables connlimit module" + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NETFILTER_CONNCOUNT + help + This option adds the "connlimit" expression that you can use to + ratelimit rule matchings per connections. + +config NFT_LOG + tristate "Netfilter nf_tables log module" + help + This option adds the "log" expression that you can use to log + packets matching some criteria. + +config NFT_LIMIT + tristate "Netfilter nf_tables limit module" + help + This option adds the "limit" expression that you can use to + ratelimit rule matchings. + +config NFT_MASQ + depends on NF_CONNTRACK + depends on NF_NAT + select NF_NAT_MASQUERADE + tristate "Netfilter nf_tables masquerade support" + help + This option adds the "masquerade" expression that you can use + to perform NAT in the masquerade flavour. + +config NFT_REDIR + depends on NF_CONNTRACK + depends on NF_NAT + tristate "Netfilter nf_tables redirect support" + select NF_NAT_REDIRECT + help + This options adds the "redirect" expression that you can use + to perform NAT in the redirect flavour. + +config NFT_NAT + depends on NF_CONNTRACK + select NF_NAT + depends on NF_TABLES_IPV4 || NF_TABLES_IPV6 + tristate "Netfilter nf_tables nat module" + help + This option adds the "nat" expression that you can use to perform + typical Network Address Translation (NAT) packet transformations. + +config NFT_TUNNEL + tristate "Netfilter nf_tables tunnel module" + help + This option adds the "tunnel" expression that you can use to set + tunneling policies. + +config NFT_OBJREF + tristate "Netfilter nf_tables stateful object reference module" + help + This option adds the "objref" expression that allows you to refer to + stateful objects, such as counters and quotas. + +config NFT_QUEUE + depends on NETFILTER_NETLINK_QUEUE + tristate "Netfilter nf_tables queue module" + help + This is required if you intend to use the userspace queueing + infrastructure (also known as NFQUEUE) from nftables. 
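For illustration only (not part of this patch): a minimal userspace consumer for the queueing infrastructure that NFT_QUEUE hands packets to. It is written against libnetfilter_queue, a separate userspace library assumed to be installed, and it accepts everything queued to queue number 0, e.g. by an nftables rule using "queue num 0".

    /* Minimal NFQUEUE consumer sketch (assumes libnetfilter_queue). */
    #include <arpa/inet.h>
    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <sys/socket.h>
    #include <linux/netfilter.h>                    /* NF_ACCEPT */
    #include <libnetfilter_queue/libnetfilter_queue.h>

    static int cb(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg,
                  struct nfq_data *nfa, void *data)
    {
            struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr(nfa);
            uint32_t id = ph ? ntohl(ph->packet_id) : 0;

            /* a real program would inspect the payload before deciding */
            return nfq_set_verdict(qh, id, NF_ACCEPT, 0, NULL);
    }

    int main(void)
    {
            char buf[4096];
            struct nfq_handle *h = nfq_open();
            struct nfq_q_handle *qh;
            int fd, rv;

            if (!h)
                    exit(1);
            qh = nfq_create_queue(h, 0, &cb, NULL); /* queue number 0 */
            if (!qh)
                    exit(1);
            nfq_set_mode(qh, NFQNL_COPY_PACKET, 0xffff);

            fd = nfq_fd(h);
            while ((rv = recv(fd, buf, sizeof(buf), 0)) >= 0)
                    nfq_handle_packet(h, buf, rv);  /* dispatches to cb() */

            nfq_destroy_queue(qh);
            nfq_close(h);
            return 0;
    }

Build with something like "gcc nfq.c -lnetfilter_queue"; the program must run with sufficient privileges to bind the queue.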
+ +config NFT_QUOTA + tristate "Netfilter nf_tables quota module" + help + This option adds the "quota" expression that you can use to match + enforce bytes quotas. + +config NFT_REJECT + default m if NETFILTER_ADVANCED=n + tristate "Netfilter nf_tables reject support" + depends on !NF_TABLES_INET || (IPV6!=m || m) + help + This option adds the "reject" expression that you can use to + explicitly deny and notify via TCP reset/ICMP informational errors + unallowed traffic. + +config NFT_REJECT_INET + depends on NF_TABLES_INET + default NFT_REJECT + tristate + +config NFT_COMPAT + depends on NETFILTER_XTABLES + tristate "Netfilter x_tables over nf_tables module" + help + This is required if you intend to use any of existing + x_tables match/target extensions over the nf_tables + framework. + +config NFT_HASH + tristate "Netfilter nf_tables hash module" + help + This option adds the "hash" expression that you can use to perform + a hash operation on registers. + +config NFT_FIB + tristate + +config NFT_FIB_INET + depends on NF_TABLES_INET + depends on NFT_FIB_IPV4 + depends on NFT_FIB_IPV6 + tristate "Netfilter nf_tables fib inet support" + help + This option allows using the FIB expression from the inet table. + The lookup will be delegated to the IPv4 or IPv6 FIB depending + on the protocol of the packet. + +config NFT_XFRM + tristate "Netfilter nf_tables xfrm/IPSec security association matching" + depends on XFRM + help + This option adds an expression that you can use to extract properties + of a packets security association. + +config NFT_SOCKET + tristate "Netfilter nf_tables socket match support" + depends on IPV6 || IPV6=n + select NF_SOCKET_IPV4 + select NF_SOCKET_IPV6 if NF_TABLES_IPV6 + help + This option allows matching for the presence or absence of a + corresponding socket and its attributes. + +config NFT_OSF + tristate "Netfilter nf_tables passive OS fingerprint support" + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK_OSF + help + This option allows matching packets from an specific OS. + +config NFT_TPROXY + tristate "Netfilter nf_tables tproxy support" + depends on IPV6 || IPV6=n + select NF_DEFRAG_IPV4 + select NF_DEFRAG_IPV6 if NF_TABLES_IPV6 + select NF_TPROXY_IPV4 + select NF_TPROXY_IPV6 if NF_TABLES_IPV6 + help + This makes transparent proxy support available in nftables. + +config NFT_SYNPROXY + tristate "Netfilter nf_tables SYNPROXY expression support" + depends on NF_CONNTRACK && NETFILTER_ADVANCED + select NETFILTER_SYNPROXY + select SYN_COOKIES + help + The SYNPROXY expression allows you to intercept TCP connections and + establish them using syncookies before they are passed on to the + server. This allows to avoid conntrack and server resource usage + during SYN-flood attacks. + +if NF_TABLES_NETDEV + +config NF_DUP_NETDEV + tristate "Netfilter packet duplication support" + help + This option enables the generic packet duplication infrastructure + for Netfilter. + +config NFT_DUP_NETDEV + tristate "Netfilter nf_tables netdev packet duplication support" + select NF_DUP_NETDEV + help + This option enables packet duplication for the "netdev" family. + +config NFT_FWD_NETDEV + tristate "Netfilter nf_tables netdev packet forwarding support" + select NF_DUP_NETDEV + help + This option enables packet forwarding for the "netdev" family. + +config NFT_FIB_NETDEV + depends on NFT_FIB_IPV4 + depends on NFT_FIB_IPV6 + tristate "Netfilter nf_tables netdev fib lookups support" + help + This option allows using the FIB expression from the netdev table. 
+ The lookup will be delegated to the IPv4 or IPv6 FIB depending + on the protocol of the packet. + +config NFT_REJECT_NETDEV + depends on NFT_REJECT_IPV4 + depends on NFT_REJECT_IPV6 + tristate "Netfilter nf_tables netdev REJECT support" + help + This option enables the REJECT support from the netdev table. + The return packet generation will be delegated to the IPv4 + or IPv6 ICMP or TCP RST implementation depending on the + protocol of the packet. + +endif # NF_TABLES_NETDEV + +endif # NF_TABLES + +config NF_FLOW_TABLE_INET + tristate "Netfilter flow table mixed IPv4/IPv6 module" + depends on NF_FLOW_TABLE + help + This option adds the flow table mixed IPv4/IPv6 support. + + To compile it as a module, choose M here. + +config NF_FLOW_TABLE + tristate "Netfilter flow table module" + depends on NETFILTER_INGRESS + depends on NF_CONNTRACK + depends on NF_TABLES + help + This option adds the flow table core infrastructure. + + To compile it as a module, choose M here. + +config NF_FLOW_TABLE_PROCFS + bool "Supply flow table statistics in procfs" + depends on NF_FLOW_TABLE + depends on PROC_FS + help + This option enables for the flow table offload statistics + to be shown in procfs under net/netfilter/nf_flowtable. + +config NETFILTER_XTABLES + tristate "Netfilter Xtables support (required for ip_tables)" + default m if NETFILTER_ADVANCED=n + help + This is required if you intend to use any of ip_tables, + ip6_tables or arp_tables. + +if NETFILTER_XTABLES + +config NETFILTER_XTABLES_COMPAT + bool "Netfilter Xtables 32bit support" + depends on COMPAT + default y + help + This option provides a translation layer to run 32bit arp,ip(6),ebtables + binaries on 64bit kernels. + + If unsure, say N. + +comment "Xtables combined modules" + +config NETFILTER_XT_MARK + tristate 'nfmark target and match support' + default m if NETFILTER_ADVANCED=n + help + This option adds the "MARK" target and "mark" match. + + Netfilter mark matching allows you to match packets based on the + "nfmark" value in the packet. + The target allows you to create rules in the "mangle" table which alter + the netfilter mark (nfmark) field associated with the packet. + + Prior to routing, the nfmark can influence the routing method and can + also be used by other subsystems to change their behavior. + +config NETFILTER_XT_CONNMARK + tristate 'ctmark target and match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NF_CONNTRACK_MARK + help + This option adds the "CONNMARK" target and "connmark" match. + + Netfilter allows you to store a mark value per connection (a.k.a. + ctmark), similarly to the packet mark (nfmark). Using this + target and match, you can set and match on this mark. + +config NETFILTER_XT_SET + tristate 'set target and match support' + depends on IP_SET + depends on NETFILTER_ADVANCED + help + This option adds the "SET" target and "set" match. + + Using this target and match, you can add/delete and match + elements in the sets created by ipset(8). + + To compile it as a module, choose M here. If unsure, say N. + +# alphabetically ordered list of targets + +comment "Xtables targets" + +config NETFILTER_XT_TARGET_AUDIT + tristate "AUDIT target support" + depends on AUDIT + depends on NETFILTER_ADVANCED + help + This option adds a 'AUDIT' target, which can be used to create + audit records for packets dropped/accepted. + + To compileit as a module, choose M here. If unsure, say N. 
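For illustration only (not part of this patch): the nfmark matched by NETFILTER_XT_MARK is not confined to rules in the mangle table; locally generated traffic can carry one directly by setting SO_MARK on the socket (typically requires CAP_NET_ADMIN), after which "-m mark" rules or fwmark-based policy routing can act on it.

    /* Tag a socket's traffic with nfmark 7. Illustrative sketch only. */
    #include <stdio.h>
    #include <sys/socket.h>
    #include <unistd.h>

    int main(void)
    {
            unsigned int mark = 7;          /* arbitrary example value */
            int fd = socket(AF_INET, SOCK_DGRAM, 0);

            if (fd < 0 ||
                setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) < 0) {
                    perror("SO_MARK");
                    return 1;
            }
            /* packets sent on fd now carry nfmark 7 */
            close(fd);
            return 0;
    }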
+ +config NETFILTER_XT_TARGET_CHECKSUM + tristate "CHECKSUM target support" + depends on IP_NF_MANGLE || IP6_NF_MANGLE + depends on NETFILTER_ADVANCED + help + This option adds a `CHECKSUM' target, which can be used in the iptables mangle + table to work around buggy DHCP clients in virtualized environments. + + Some old DHCP clients drop packets because they are not aware + that the checksum would normally be offloaded to hardware and + thus should be considered valid. + This target can be used to fill in the checksum using iptables + when such packets are sent via a virtual network device. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_CLASSIFY + tristate '"CLASSIFY" target support' + depends on NETFILTER_ADVANCED + help + This option adds a `CLASSIFY' target, which enables the user to set + the priority of a packet. Some qdiscs can use this value for + classification, among these are: + + atm, cbq, dsmark, pfifo_fast, htb, prio + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_CONNMARK + tristate '"CONNMARK" target support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NETFILTER_XT_CONNMARK + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_CONNMARK (combined connmark/CONNMARK module). + +config NETFILTER_XT_TARGET_CONNSECMARK + tristate '"CONNSECMARK" target support' + depends on NF_CONNTRACK && NF_CONNTRACK_SECMARK + default m if NETFILTER_ADVANCED=n + help + The CONNSECMARK target copies security markings from packets + to connections, and restores security markings from connections + to packets (if the packets are not already marked). This would + normally be used in conjunction with the SECMARK target. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_CT + tristate '"CT" target support' + depends on NF_CONNTRACK + depends on IP_NF_RAW || IP6_NF_RAW + depends on NETFILTER_ADVANCED + help + This options adds a `CT' target, which allows to specify initial + connection tracking parameters like events to be delivered and + the helper to be used. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_DSCP + tristate '"DSCP" and "TOS" target support' + depends on IP_NF_MANGLE || IP6_NF_MANGLE + depends on NETFILTER_ADVANCED + help + This option adds a `DSCP' target, which allows you to manipulate + the IPv4/IPv6 header DSCP field (differentiated services codepoint). + + The DSCP field can have any value between 0x0 and 0x3f inclusive. + + It also adds the "TOS" target, which allows you to create rules in + the "mangle" table which alter the Type Of Service field of an IPv4 + or the Priority field of an IPv6 packet, prior to routing. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_HL + tristate '"HL" hoplimit target support' + depends on IP_NF_MANGLE || IP6_NF_MANGLE + depends on NETFILTER_ADVANCED + help + This option adds the "HL" (for IPv6) and "TTL" (for IPv4) + targets, which enable the user to change the + hoplimit/time-to-live value of the IP header. + + While it is safe to decrement the hoplimit/TTL value, the + modules also allow to increment and set the hoplimit value of + the header to arbitrary values. This is EXTREMELY DANGEROUS + since you can easily create immortal packets that loop + forever on the network. 
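For illustration only (not part of this patch): the 0x0-0x3f range quoted for the DSCP target follows from DSCP occupying the upper six bits of the IPv4 TOS byte. A userspace process can request the same marking per socket via IP_TOS, shifting the DSCP value left by two.

    /* Set DSCP 0x2e (EF) on a socket via IP_TOS. Illustrative sketch only. */
    #include <netinet/in.h>
    #include <stdio.h>
    #include <sys/socket.h>
    #include <unistd.h>

    int main(void)
    {
            int dscp = 0x2e;                /* EF, for example */
            int tos = dscp << 2;            /* DSCP lives in bits 7..2 of TOS */
            int fd = socket(AF_INET, SOCK_DGRAM, 0);

            if (fd < 0 ||
                setsockopt(fd, IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) < 0) {
                    perror("IP_TOS");
                    return 1;
            }
            printf("TOS byte 0x%02x carries DSCP 0x%02x\n", tos, dscp);
            close(fd);
            return 0;
    }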
+ +config NETFILTER_XT_TARGET_HMARK + tristate '"HMARK" target support' + depends on IP6_NF_IPTABLES || IP6_NF_IPTABLES=n + depends on NETFILTER_ADVANCED + help + This option adds the "HMARK" target. + + The target allows you to create rules in the "raw" and "mangle" tables + which set the skbuff mark by means of hash calculation within a given + range. The nfmark can influence the routing method and can also be used + by other subsystems to change their behaviour. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_IDLETIMER + tristate "IDLETIMER target support" + depends on NETFILTER_ADVANCED + help + + This option adds the `IDLETIMER' target. Each matching packet + resets the timer associated with label specified when the rule is + added. When the timer expires, it triggers a sysfs notification. + The remaining time for expiration can be read via sysfs. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_LED + tristate '"LED" target support' + depends on LEDS_CLASS && LEDS_TRIGGERS + depends on NETFILTER_ADVANCED + help + This option adds a `LED' target, which allows you to blink LEDs in + response to particular packets passing through your machine. + + This can be used to turn a spare LED into a network activity LED, + which only flashes in response to FTP transfers, for example. Or + you could have an LED which lights up for a minute or two every time + somebody connects to your machine via SSH. + + You will need support for the "led" class to make this work. + + To create an LED trigger for incoming SSH traffic: + iptables -A INPUT -p tcp --dport 22 -j LED --led-trigger-id ssh --led-delay 1000 + + Then attach the new trigger to an LED on your system: + echo netfilter-ssh > /sys/class/leds//trigger + + For more information on the LEDs available on your system, see + Documentation/leds/leds-class.rst + +config NETFILTER_XT_TARGET_LOG + tristate "LOG target support" + select NF_LOG_SYSLOG + select NF_LOG_IPV6 if IP6_NF_IPTABLES + default m if NETFILTER_ADVANCED=n + help + This option adds a `LOG' target, which allows you to create rules in + any iptables table which records the packet header to the syslog. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_MARK + tristate '"MARK" target support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_MARK + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_MARK (combined mark/MARK module). + +config NETFILTER_XT_NAT + tristate '"SNAT and DNAT" targets support' + depends on NF_NAT + help + This option enables the SNAT and DNAT targets. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_NETMAP + tristate '"NETMAP" target support' + depends on NF_NAT + help + NETMAP is an implementation of static 1:1 NAT mapping of network + addresses. It maps the network address part, while keeping the host + address part intact. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_NFLOG + tristate '"NFLOG" target support' + default m if NETFILTER_ADVANCED=n + select NETFILTER_NETLINK_LOG + help + This option enables the NFLOG target, which allows to LOG + messages through nfnetlink_log. + + To compile it as a module, choose M here. If unsure, say N. 
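For illustration only (not part of this patch): the static 1:1 mapping done by the NETMAP target reduces to simple address arithmetic, replacing the network part while preserving the host part. The sketch below shows that arithmetic in userspace; it is a conceptual model, not the in-kernel code.

    /* NETMAP-style 1:1 address mapping, conceptually. */
    #include <arpa/inet.h>
    #include <stdint.h>
    #include <stdio.h>

    static uint32_t netmap(uint32_t addr, uint32_t new_net, uint32_t mask)
    {
            /* keep host bits, swap in the new network bits */
            return (new_net & mask) | (addr & ~mask);
    }

    int main(void)
    {
            uint32_t mask = 0xffffff00;     /* /24 */
            uint32_t addr = ntohl(inet_addr("192.168.1.42"));
            uint32_t to   = ntohl(inet_addr("10.0.0.0"));
            struct in_addr out = { .s_addr = htonl(netmap(addr, to, mask)) };

            printf("192.168.1.42 -> %s\n", inet_ntoa(out));  /* 10.0.0.42 */
            return 0;
    }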
+ +config NETFILTER_XT_TARGET_NFQUEUE + tristate '"NFQUEUE" target Support' + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK_QUEUE + help + This target replaced the old obsolete QUEUE target. + + As opposed to QUEUE, it supports 65535 different queues, + not just one. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_NOTRACK + tristate '"NOTRACK" target support (DEPRECATED)' + depends on NF_CONNTRACK + depends on IP_NF_RAW || IP6_NF_RAW + depends on NETFILTER_ADVANCED + select NETFILTER_XT_TARGET_CT + +config NETFILTER_XT_TARGET_RATEEST + tristate '"RATEEST" target support' + depends on NETFILTER_ADVANCED + help + This option adds a `RATEEST' target, which allows to measure + rates similar to TC estimators. The `rateest' match can be + used to match on the measured rates. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_REDIRECT + tristate "REDIRECT target support" + depends on NF_NAT + select NF_NAT_REDIRECT + help + REDIRECT is a special case of NAT: all incoming connections are + mapped onto the incoming interface's address, causing the packets to + come to the local machine instead of passing through. This is + useful for transparent proxies. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_MASQUERADE + tristate "MASQUERADE target support" + depends on NF_NAT + default m if NETFILTER_ADVANCED=n + select NF_NAT_MASQUERADE + help + Masquerading is a special case of NAT: all outgoing connections are + changed to seem to come from a particular interface's address, and + if the interface goes down, those connections are lost. This is + only useful for dialup accounts with dynamic IP address (ie. your IP + address will be different on next dialup). + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_TEE + tristate '"TEE" - packet cloning to alternate destination' + depends on NETFILTER_ADVANCED + depends on IPV6 || IPV6=n + depends on !NF_CONNTRACK || NF_CONNTRACK + depends on IP6_NF_IPTABLES || !IP6_NF_IPTABLES + select NF_DUP_IPV4 + select NF_DUP_IPV6 if IP6_NF_IPTABLES + help + This option adds a "TEE" target with which a packet can be cloned and + this clone be rerouted to another nexthop. + +config NETFILTER_XT_TARGET_TPROXY + tristate '"TPROXY" target transparent proxying support' + depends on NETFILTER_XTABLES + depends on NETFILTER_ADVANCED + depends on IPV6 || IPV6=n + depends on IP6_NF_IPTABLES || IP6_NF_IPTABLES=n + depends on IP_NF_MANGLE + select NF_DEFRAG_IPV4 + select NF_DEFRAG_IPV6 if IP6_NF_IPTABLES != n + select NF_TPROXY_IPV4 + select NF_TPROXY_IPV6 if IP6_NF_IPTABLES + help + This option adds a `TPROXY' target, which is somewhat similar to + REDIRECT. It can only be used in the mangle table and is useful + to redirect traffic to a transparent proxy. It does _not_ depend + on Netfilter connection tracking and NAT, unlike REDIRECT. + For it to work you will have to configure certain iptables rules + and use policy routing. For more information on how to set it up + see Documentation/networking/tproxy.rst. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_TRACE + tristate '"TRACE" target support' + depends on IP_NF_RAW || IP6_NF_RAW + depends on NETFILTER_ADVANCED + help + The TRACE target allows you to mark packets so that the kernel + will log every rule which match the packets as those traverse + the tables, chains, rules. 
+ + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_TARGET_SECMARK + tristate '"SECMARK" target support' + depends on NETWORK_SECMARK + default m if NETFILTER_ADVANCED=n + help + The SECMARK target allows security marking of network + packets, for use with security subsystems. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_TCPMSS + tristate '"TCPMSS" target support' + depends on IPV6 || IPV6=n + default m if NETFILTER_ADVANCED=n + help + This option adds a `TCPMSS' target, which allows you to alter the + MSS value of TCP SYN packets, to control the maximum size for that + connection (usually limiting it to your outgoing interface's MTU + minus 40). + + This is used to overcome criminally braindead ISPs or servers which + block ICMP Fragmentation Needed packets. The symptoms of this + problem are that everything works fine from your Linux + firewall/router, but machines behind it can never exchange large + packets: + 1) Web browsers connect, then hang with no data received. + 2) Small mail works fine, but large emails hang. + 3) ssh works fine, but scp hangs after initial handshaking. + + Workaround: activate this option and add a rule to your firewall + configuration like: + + iptables -A FORWARD -p tcp --tcp-flags SYN,RST SYN \ + -j TCPMSS --clamp-mss-to-pmtu + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_TARGET_TCPOPTSTRIP + tristate '"TCPOPTSTRIP" target support' + depends on IP_NF_MANGLE || IP6_NF_MANGLE + depends on NETFILTER_ADVANCED + help + This option adds a "TCPOPTSTRIP" target, which allows you to strip + TCP options from TCP packets. + +# alphabetically ordered list of matches + +comment "Xtables matches" + +config NETFILTER_XT_MATCH_ADDRTYPE + tristate '"addrtype" address type match support' + default m if NETFILTER_ADVANCED=n + help + This option allows you to match what routing thinks of an address, + eg. UNICAST, LOCAL, BROADCAST, ... + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_BPF + tristate '"bpf" match support' + depends on NETFILTER_ADVANCED + help + BPF matching applies a linux socket filter to each packet and + accepts those for which the filter returns non-zero. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_CGROUP + tristate '"control group" match support' + depends on NETFILTER_ADVANCED + depends on CGROUPS + select CGROUP_NET_CLASSID + help + Socket/process control group matching allows you to match locally + generated packets based on which net_cls control group processes + belong to. + +config NETFILTER_XT_MATCH_CLUSTER + tristate '"cluster" match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + help + This option allows you to build work-load-sharing clusters of + network servers/stateful firewalls without having a dedicated + load-balancing router/server/switch. Basically, this match returns + true when the packet must be handled by this cluster node. Thus, + all nodes see all packets and this match decides which node handles + what packets. The work-load sharing algorithm is based on source + address hashing. + + If you say Y or M here, try `iptables -m cluster --help` for + more information. 
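For illustration only (not part of this patch): the arithmetic behind the TCPMSS target's "--clamp-mss-to-pmtu" described above. The advertised MSS is capped at the path MTU minus the IP and TCP header sizes (40 bytes for IPv4, 60 for IPv6), so full-sized segments never need fragmentation on the narrow link.

    /* MSS clamping arithmetic, conceptually. */
    #include <stdio.h>

    static int clamp_mss(int mss, int pmtu, int is_ipv6)
    {
            int overhead = is_ipv6 ? 60 : 40;       /* IP hdr + TCP hdr, no options */
            int max_mss = pmtu - overhead;

            return mss > max_mss ? max_mss : mss;
    }

    int main(void)
    {
            /* SYN advertises 1460, but a PPPoE link on the path has MTU 1492 */
            printf("clamped MSS: %d\n", clamp_mss(1460, 1492, 0));  /* 1452 */
            return 0;
    }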
+ +config NETFILTER_XT_MATCH_COMMENT + tristate '"comment" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `comment' dummy-match, which allows you to put + comments in your iptables ruleset. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_CONNBYTES + tristate '"connbytes" per-connection counter match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + help + This option adds a `connbytes' match, which allows you to match the + number of bytes and/or packets for each direction within a connection. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_CONNLABEL + tristate '"connlabel" match support' + select NF_CONNTRACK_LABELS + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + help + This match allows you to test and assign userspace-defined labels names + to a connection. The kernel only stores bit values - mapping + names to bits is done by userspace. + + Unlike connmark, more than 32 flag bits may be assigned to a + connection simultaneously. + +config NETFILTER_XT_MATCH_CONNLIMIT + tristate '"connlimit" match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NETFILTER_CONNCOUNT + help + This match allows you to match against the number of parallel + connections to a server per client IP address (or address block). + +config NETFILTER_XT_MATCH_CONNMARK + tristate '"connmark" connection mark match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + select NETFILTER_XT_CONNMARK + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_CONNMARK (combined connmark/CONNMARK module). + +config NETFILTER_XT_MATCH_CONNTRACK + tristate '"conntrack" connection tracking match support' + depends on NF_CONNTRACK + default m if NETFILTER_ADVANCED=n + help + This is a general conntrack match module, a superset of the state match. + + It allows matching on additional conntrack information, which is + useful in complex configurations, such as NAT gateways with multiple + internet links or tunnels. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_CPU + tristate '"cpu" match support' + depends on NETFILTER_ADVANCED + help + CPU matching allows you to match packets based on the CPU + currently handling the packet. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_DCCP + tristate '"dccp" protocol match support' + depends on NETFILTER_ADVANCED + default IP_DCCP + help + With this option enabled, you will be able to use the iptables + `dccp' match in order to match on DCCP source/destination ports + and DCCP flags. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_DEVGROUP + tristate '"devgroup" match support' + depends on NETFILTER_ADVANCED + help + This options adds a `devgroup' match, which allows to match on the + device group a network device is assigned to. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_DSCP + tristate '"dscp" and "tos" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `DSCP' match, which allows you to match against + the IPv4/IPv6 header DSCP field (differentiated services codepoint). + + The DSCP field can have any value between 0x0 and 0x3f inclusive. 
+ + It will also add a "tos" match, which allows you to match packets + based on the Type Of Service fields of the IPv4 packet (which share + the same bits as DSCP). + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_ECN + tristate '"ecn" match support' + depends on NETFILTER_ADVANCED + help + This option adds an "ECN" match, which allows you to match against + the IPv4 and TCP header ECN fields. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_ESP + tristate '"esp" match support' + depends on NETFILTER_ADVANCED + help + This match extension allows you to match a range of SPIs + inside ESP header of IPSec packets. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_HASHLIMIT + tristate '"hashlimit" match support' + depends on IP6_NF_IPTABLES || IP6_NF_IPTABLES=n + depends on NETFILTER_ADVANCED + help + This option adds a `hashlimit' match. + + As opposed to `limit', this match dynamically creates a hash table + of limit buckets, based on your selection of source/destination + addresses and/or ports. + + It enables you to express policies like `10kpps for any given + destination address' or `500pps from any given source address' + with a single rule. + +config NETFILTER_XT_MATCH_HELPER + tristate '"helper" match support' + depends on NF_CONNTRACK + depends on NETFILTER_ADVANCED + help + Helper matching allows you to match packets in dynamic connections + tracked by a conntrack-helper, ie. nf_conntrack_ftp + + To compile it as a module, choose M here. If unsure, say Y. + +config NETFILTER_XT_MATCH_HL + tristate '"hl" hoplimit/TTL match support' + depends on NETFILTER_ADVANCED + help + HL matching allows you to match packets based on the hoplimit + in the IPv6 header, or the time-to-live field in the IPv4 + header of the packet. + +config NETFILTER_XT_MATCH_IPCOMP + tristate '"ipcomp" match support' + depends on NETFILTER_ADVANCED + help + This match extension allows you to match a range of CPIs(16 bits) + inside IPComp header of IPSec packets. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_IPRANGE + tristate '"iprange" address range match support' + depends on NETFILTER_ADVANCED + help + This option adds a "iprange" match, which allows you to match based on + an IP address range. (Normal iptables only matches on single addresses + with an optional mask.) + + If unsure, say M. + +config NETFILTER_XT_MATCH_IPVS + tristate '"ipvs" match support' + depends on IP_VS + depends on NETFILTER_ADVANCED + depends on NF_CONNTRACK + help + This option allows you to match against IPVS properties of a packet. + + If unsure, say N. + +config NETFILTER_XT_MATCH_L2TP + tristate '"l2tp" match support' + depends on NETFILTER_ADVANCED + default L2TP + help + This option adds an "L2TP" match, which allows you to match against + L2TP protocol header fields. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_LENGTH + tristate '"length" match support' + depends on NETFILTER_ADVANCED + help + This option allows you to match the length of a packet against a + specific value or range of values. + + To compile it as a module, choose M here. If unsure, say N. 
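For illustration only (not part of this patch): what the "iprange" match adds over a plain address-plus-mask test is an inclusive comparison between two arbitrary endpoints, done on addresses in host byte order. A userspace sketch of that check:

    /* Inclusive IPv4 range test, as the "iprange" match does conceptually. */
    #include <arpa/inet.h>
    #include <stdint.h>
    #include <stdio.h>

    static int in_range(const char *ip, const char *lo, const char *hi)
    {
            uint32_t a = ntohl(inet_addr(ip));
            uint32_t l = ntohl(inet_addr(lo));
            uint32_t h = ntohl(inet_addr(hi));

            return a >= l && a <= h;
    }

    int main(void)
    {
            printf("%d\n", in_range("10.0.0.200", "10.0.0.100", "10.0.1.50")); /* 1 */
            printf("%d\n", in_range("10.0.2.1",   "10.0.0.100", "10.0.1.50")); /* 0 */
            return 0;
    }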
+ +config NETFILTER_XT_MATCH_LIMIT + tristate '"limit" match support' + depends on NETFILTER_ADVANCED + help + limit matching allows you to control the rate at which a rule can be + matched: mainly useful in combination with the LOG target ("LOG + target support", below) and to avoid some Denial of Service attacks. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_MAC + tristate '"mac" address match support' + depends on NETFILTER_ADVANCED + help + MAC matching allows you to match packets based on the source + Ethernet address of the packet. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_MARK + tristate '"mark" match support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_MARK + help + This is a backwards-compat option for the user's convenience + (e.g. when running oldconfig). It selects + CONFIG_NETFILTER_XT_MARK (combined mark/MARK module). + +config NETFILTER_XT_MATCH_MULTIPORT + tristate '"multiport" Multiple port match support' + depends on NETFILTER_ADVANCED + help + Multiport matching allows you to match TCP or UDP packets based on + a series of source or destination ports: normally a rule can only + match a single range of ports. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_NFACCT + tristate '"nfacct" match support' + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK_ACCT + help + This option allows you to use the extended accounting through + nfnetlink_acct. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_OSF + tristate '"osf" Passive OS fingerprint match' + depends on NETFILTER_ADVANCED + select NETFILTER_NETLINK_OSF + help + This option selects the Passive OS Fingerprinting match module + that allows to passively match the remote operating system by + analyzing incoming TCP SYN packets. + + Rules and loading software can be downloaded from + http://www.ioremap.net/projects/osf + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_OWNER + tristate '"owner" match support' + depends on NETFILTER_ADVANCED + help + Socket owner matching allows you to match locally-generated packets + based on who created the socket: the user or group. It is also + possible to check whether a socket actually exists. + +config NETFILTER_XT_MATCH_POLICY + tristate 'IPsec "policy" match support' + depends on XFRM + default m if NETFILTER_ADVANCED=n + help + Policy matching allows you to match packets based on the + IPsec policy that was used during decapsulation/will + be used during encapsulation. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_PHYSDEV + tristate '"physdev" match support' + depends on BRIDGE && BRIDGE_NETFILTER + depends on NETFILTER_ADVANCED + help + Physdev packet matching matches against the physical bridge ports + the IP packet arrived on or will leave by. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_PKTTYPE + tristate '"pkttype" packet type match support' + depends on NETFILTER_ADVANCED + help + Packet type matching allows you to match a packet by + its "class", eg. BROADCAST, MULTICAST, ... + + Typical usage: + iptables -A INPUT -m pkttype --pkt-type broadcast -j LOG + + To compile it as a module, choose M here. If unsure, say N. 
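For illustration only (not part of this patch): a rate limit such as "--limit 10/second --limit-burst 5" behaves like a token bucket, which the userspace sketch below models. This is a conceptual approximation; the in-kernel limit match implementation differs in detail.

    /* Token-bucket model of the "limit" match. Conceptual sketch only. */
    #include <stdio.h>

    struct bucket {
            double tokens;          /* currently available "matches" */
            double rate;            /* tokens added per second */
            double burst;           /* bucket capacity */
            double last;            /* time of last update, in seconds */
    };

    static int limit_match(struct bucket *b, double now)
    {
            b->tokens += (now - b->last) * b->rate;
            if (b->tokens > b->burst)
                    b->tokens = b->burst;
            b->last = now;
            if (b->tokens < 1.0)
                    return 0;       /* rate exceeded, rule does not match */
            b->tokens -= 1.0;
            return 1;               /* rule matches */
    }

    int main(void)
    {
            struct bucket b = { .tokens = 5, .rate = 10, .burst = 5, .last = 0 };
            double t;
            int matched = 0;

            /* 100 packets in the first 0.1s: only the burst plus refill pass */
            for (t = 0; t < 0.1; t += 0.001)
                    matched += limit_match(&b, t);
            printf("matched %d of 100 packets\n", matched);
            return 0;
    }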
+ +config NETFILTER_XT_MATCH_QUOTA + tristate '"quota" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `quota' match, which allows to match on a + byte counter. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_RATEEST + tristate '"rateest" match support' + depends on NETFILTER_ADVANCED + select NETFILTER_XT_TARGET_RATEEST + help + This option adds a `rateest' match, which allows to match on the + rate estimated by the RATEEST target. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_REALM + tristate '"realm" match support' + depends on NETFILTER_ADVANCED + select IP_ROUTE_CLASSID + help + This option adds a `realm' match, which allows you to use the realm + key from the routing subsystem inside iptables. + + This match pretty much resembles the CONFIG_NET_CLS_ROUTE4 option + in tc world. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_RECENT + tristate '"recent" match support' + depends on NETFILTER_ADVANCED + help + This match is used for creating one or many lists of recently + used addresses and then matching against that/those list(s). + + Short options are available by using 'iptables -m recent -h' + Official Website: + +config NETFILTER_XT_MATCH_SCTP + tristate '"sctp" protocol match support' + depends on NETFILTER_ADVANCED + default IP_SCTP + help + With this option enabled, you will be able to use the + `sctp' match in order to match on SCTP source/destination ports + and SCTP chunk types. + + If you want to compile it as a module, say M here and read + . If unsure, say `N'. + +config NETFILTER_XT_MATCH_SOCKET + tristate '"socket" match support' + depends on NETFILTER_XTABLES + depends on NETFILTER_ADVANCED + depends on IPV6 || IPV6=n + depends on IP6_NF_IPTABLES || IP6_NF_IPTABLES=n + select NF_SOCKET_IPV4 + select NF_SOCKET_IPV6 if IP6_NF_IPTABLES + select NF_DEFRAG_IPV4 + select NF_DEFRAG_IPV6 if IP6_NF_IPTABLES != n + help + This option adds a `socket' match, which can be used to match + packets for which a TCP or UDP socket lookup finds a valid socket. + It can be used in combination with the MARK target and policy + routing to implement full featured non-locally bound sockets. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_STATE + tristate '"state" match support' + depends on NF_CONNTRACK + default m if NETFILTER_ADVANCED=n + help + Connection state matching allows you to match packets based on their + relationship to a tracked connection (ie. previous packets). This + is a powerful tool for packet classification. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_STATISTIC + tristate '"statistic" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `statistic' match, which allows you to match + on packets periodically or randomly with a given percentage. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_STRING + tristate '"string" match support' + depends on NETFILTER_ADVANCED + select TEXTSEARCH + select TEXTSEARCH_KMP + select TEXTSEARCH_BM + select TEXTSEARCH_FSM + help + This option adds a `string' match, which allows you to look for + pattern matchings in packets. + + To compile it as a module, choose M here. If unsure, say N. 
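For illustration only (not part of this patch): the two modes of the "statistic" match described below in this chunk, sketched in userspace. "random" mode matches each packet with a fixed probability, "nth" mode matches every Nth packet; the kernel implementation is equivalent in behaviour but not in code.

    /* "statistic" match modes, conceptually. */
    #include <stdio.h>
    #include <stdlib.h>

    static int match_random(double probability)
    {
            return ((double)rand() / RAND_MAX) < probability;
    }

    static int match_nth(unsigned int *counter, unsigned int every)
    {
            if (++(*counter) < every)
                    return 0;
            *counter = 0;
            return 1;               /* every "every"-th packet matches */
    }

    int main(void)
    {
            unsigned int ctr = 0, i, hits = 0;

            srand(1);
            for (i = 0; i < 1000; i++)
                    hits += match_random(0.1);
            printf("random mode: ~%u of 1000 matched\n", hits);

            for (hits = 0, i = 0; i < 1000; i++)
                    hits += match_nth(&ctr, 10);
            printf("nth mode: %u of 1000 matched\n", hits);  /* exactly 100 */
            return 0;
    }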
+ +config NETFILTER_XT_MATCH_TCPMSS + tristate '"tcpmss" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `tcpmss' match, which allows you to examine the + MSS value of TCP SYN packets, which control the maximum packet size + for that connection. + + To compile it as a module, choose M here. If unsure, say N. + +config NETFILTER_XT_MATCH_TIME + tristate '"time" match support' + depends on NETFILTER_ADVANCED + help + This option adds a "time" match, which allows you to match based on + the packet arrival time (at the machine which netfilter is running) + on) or departure time/date (for locally generated packets). + + If you say Y here, try `iptables -m time --help` for + more information. + + If you want to compile it as a module, say M here. + If unsure, say N. + +config NETFILTER_XT_MATCH_U32 + tristate '"u32" match support' + depends on NETFILTER_ADVANCED + help + u32 allows you to extract quantities of up to 4 bytes from a packet, + AND them with specified masks, shift them by specified amounts and + test whether the results are in any of a set of specified ranges. + The specification of what to extract is general enough to skip over + headers with lengths stored in the packet, as in IP or TCP header + lengths. + + Details and examples are in the kernel module source. + +endif # NETFILTER_XTABLES + +endmenu + +source "net/netfilter/ipset/Kconfig" + +source "net/netfilter/ipvs/Kconfig" diff --git a/net/netfilter/Makefile b/net/netfilter/Makefile new file mode 100644 index 000000000..0f060d100 --- /dev/null +++ b/net/netfilter/Makefile @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: GPL-2.0 +netfilter-objs := core.o nf_log.o nf_queue.o nf_sockopt.o utils.o + +nf_conntrack-y := nf_conntrack_core.o nf_conntrack_standalone.o nf_conntrack_expect.o nf_conntrack_helper.o \ + nf_conntrack_proto.o nf_conntrack_proto_generic.o nf_conntrack_proto_tcp.o nf_conntrack_proto_udp.o \ + nf_conntrack_proto_icmp.o \ + nf_conntrack_extend.o nf_conntrack_acct.o nf_conntrack_seqadj.o + +nf_conntrack-$(subst m,y,$(CONFIG_IPV6)) += nf_conntrack_proto_icmpv6.o +nf_conntrack-$(CONFIG_NF_CONNTRACK_TIMEOUT) += nf_conntrack_timeout.o +nf_conntrack-$(CONFIG_NF_CONNTRACK_TIMESTAMP) += nf_conntrack_timestamp.o +nf_conntrack-$(CONFIG_NF_CONNTRACK_EVENTS) += nf_conntrack_ecache.o +nf_conntrack-$(CONFIG_NF_CONNTRACK_LABELS) += nf_conntrack_labels.o +nf_conntrack-$(CONFIG_NF_CT_PROTO_DCCP) += nf_conntrack_proto_dccp.o +nf_conntrack-$(CONFIG_NF_CT_PROTO_SCTP) += nf_conntrack_proto_sctp.o +nf_conntrack-$(CONFIG_NF_CT_PROTO_GRE) += nf_conntrack_proto_gre.o +ifeq ($(CONFIG_NF_CONNTRACK),m) +nf_conntrack-$(CONFIG_DEBUG_INFO_BTF_MODULES) += nf_conntrack_bpf.o +else ifeq ($(CONFIG_NF_CONNTRACK),y) +nf_conntrack-$(CONFIG_DEBUG_INFO_BTF) += nf_conntrack_bpf.o +endif + +obj-$(CONFIG_NETFILTER) = netfilter.o + +obj-$(CONFIG_NETFILTER_NETLINK) += nfnetlink.o +obj-$(CONFIG_NETFILTER_NETLINK_ACCT) += nfnetlink_acct.o +obj-$(CONFIG_NETFILTER_NETLINK_QUEUE) += nfnetlink_queue.o +obj-$(CONFIG_NETFILTER_NETLINK_LOG) += nfnetlink_log.o +obj-$(CONFIG_NETFILTER_NETLINK_OSF) += nfnetlink_osf.o +obj-$(CONFIG_NETFILTER_NETLINK_HOOK) += nfnetlink_hook.o + +# connection tracking +obj-$(CONFIG_NF_CONNTRACK) += nf_conntrack.o + +# netlink interface for nf_conntrack +obj-$(CONFIG_NF_CT_NETLINK) += nf_conntrack_netlink.o +obj-$(CONFIG_NF_CT_NETLINK_TIMEOUT) += nfnetlink_cttimeout.o +obj-$(CONFIG_NF_CT_NETLINK_HELPER) += nfnetlink_cthelper.o + +# connection tracking helpers +nf_conntrack_h323-objs := nf_conntrack_h323_main.o 
nf_conntrack_h323_asn1.o + +obj-$(CONFIG_NF_CONNTRACK_AMANDA) += nf_conntrack_amanda.o +obj-$(CONFIG_NF_CONNTRACK_FTP) += nf_conntrack_ftp.o +obj-$(CONFIG_NF_CONNTRACK_H323) += nf_conntrack_h323.o +obj-$(CONFIG_NF_CONNTRACK_IRC) += nf_conntrack_irc.o +obj-$(CONFIG_NF_CONNTRACK_BROADCAST) += nf_conntrack_broadcast.o +obj-$(CONFIG_NF_CONNTRACK_NETBIOS_NS) += nf_conntrack_netbios_ns.o +obj-$(CONFIG_NF_CONNTRACK_SNMP) += nf_conntrack_snmp.o +obj-$(CONFIG_NF_CONNTRACK_PPTP) += nf_conntrack_pptp.o +obj-$(CONFIG_NF_CONNTRACK_SANE) += nf_conntrack_sane.o +obj-$(CONFIG_NF_CONNTRACK_SIP) += nf_conntrack_sip.o +obj-$(CONFIG_NF_CONNTRACK_TFTP) += nf_conntrack_tftp.o + +nf_nat-y := nf_nat_core.o nf_nat_proto.o nf_nat_helper.o + +obj-$(CONFIG_NF_LOG_SYSLOG) += nf_log_syslog.o + +obj-$(CONFIG_NF_NAT) += nf_nat.o +nf_nat-$(CONFIG_NF_NAT_REDIRECT) += nf_nat_redirect.o +nf_nat-$(CONFIG_NF_NAT_MASQUERADE) += nf_nat_masquerade.o + +ifeq ($(CONFIG_NF_NAT),m) +nf_nat-$(CONFIG_DEBUG_INFO_BTF_MODULES) += nf_nat_bpf.o +else ifeq ($(CONFIG_NF_NAT),y) +nf_nat-$(CONFIG_DEBUG_INFO_BTF) += nf_nat_bpf.o +endif + +# NAT helpers +obj-$(CONFIG_NF_NAT_AMANDA) += nf_nat_amanda.o +obj-$(CONFIG_NF_NAT_FTP) += nf_nat_ftp.o +obj-$(CONFIG_NF_NAT_IRC) += nf_nat_irc.o +obj-$(CONFIG_NF_NAT_SIP) += nf_nat_sip.o +obj-$(CONFIG_NF_NAT_TFTP) += nf_nat_tftp.o + +# SYNPROXY +obj-$(CONFIG_NETFILTER_SYNPROXY) += nf_synproxy_core.o + +obj-$(CONFIG_NETFILTER_CONNCOUNT) += nf_conncount.o + +# generic packet duplication from netdev family +obj-$(CONFIG_NF_DUP_NETDEV) += nf_dup_netdev.o + +# nf_tables +nf_tables-objs := nf_tables_core.o nf_tables_api.o nft_chain_filter.o \ + nf_tables_trace.o nft_immediate.o nft_cmp.o nft_range.o \ + nft_bitwise.o nft_byteorder.o nft_payload.o nft_lookup.o \ + nft_dynset.o nft_meta.o nft_rt.o nft_exthdr.o nft_last.o \ + nft_counter.o nft_chain_route.o nf_tables_offload.o \ + nft_set_hash.o nft_set_bitmap.o nft_set_rbtree.o \ + nft_set_pipapo.o + +ifdef CONFIG_X86_64 +ifndef CONFIG_UML +nf_tables-objs += nft_set_pipapo_avx2.o +endif +endif + +obj-$(CONFIG_NF_TABLES) += nf_tables.o +obj-$(CONFIG_NFT_COMPAT) += nft_compat.o +obj-$(CONFIG_NFT_CONNLIMIT) += nft_connlimit.o +obj-$(CONFIG_NFT_NUMGEN) += nft_numgen.o +obj-$(CONFIG_NFT_CT) += nft_ct.o +obj-$(CONFIG_NFT_FLOW_OFFLOAD) += nft_flow_offload.o +obj-$(CONFIG_NFT_LIMIT) += nft_limit.o +obj-$(CONFIG_NFT_NAT) += nft_nat.o +obj-$(CONFIG_NFT_OBJREF) += nft_objref.o +obj-$(CONFIG_NFT_QUEUE) += nft_queue.o +obj-$(CONFIG_NFT_QUOTA) += nft_quota.o +obj-$(CONFIG_NFT_REJECT) += nft_reject.o +obj-$(CONFIG_NFT_REJECT_INET) += nft_reject_inet.o +obj-$(CONFIG_NFT_REJECT_NETDEV) += nft_reject_netdev.o +obj-$(CONFIG_NFT_TUNNEL) += nft_tunnel.o +obj-$(CONFIG_NFT_LOG) += nft_log.o +obj-$(CONFIG_NFT_MASQ) += nft_masq.o +obj-$(CONFIG_NFT_REDIR) += nft_redir.o +obj-$(CONFIG_NFT_HASH) += nft_hash.o +obj-$(CONFIG_NFT_FIB) += nft_fib.o +obj-$(CONFIG_NFT_FIB_INET) += nft_fib_inet.o +obj-$(CONFIG_NFT_FIB_NETDEV) += nft_fib_netdev.o +obj-$(CONFIG_NFT_SOCKET) += nft_socket.o +obj-$(CONFIG_NFT_OSF) += nft_osf.o +obj-$(CONFIG_NFT_TPROXY) += nft_tproxy.o +obj-$(CONFIG_NFT_XFRM) += nft_xfrm.o +obj-$(CONFIG_NFT_SYNPROXY) += nft_synproxy.o + +obj-$(CONFIG_NFT_NAT) += nft_chain_nat.o + +# nf_tables netdev +obj-$(CONFIG_NFT_DUP_NETDEV) += nft_dup_netdev.o +obj-$(CONFIG_NFT_FWD_NETDEV) += nft_fwd_netdev.o + +# flow table infrastructure +obj-$(CONFIG_NF_FLOW_TABLE) += nf_flow_table.o +nf_flow_table-objs := nf_flow_table_core.o nf_flow_table_ip.o \ + nf_flow_table_offload.o 
+nf_flow_table-$(CONFIG_NF_FLOW_TABLE_PROCFS) += nf_flow_table_procfs.o + +obj-$(CONFIG_NF_FLOW_TABLE_INET) += nf_flow_table_inet.o + +# generic X tables +obj-$(CONFIG_NETFILTER_XTABLES) += x_tables.o xt_tcpudp.o + +# combos +obj-$(CONFIG_NETFILTER_XT_MARK) += xt_mark.o +obj-$(CONFIG_NETFILTER_XT_CONNMARK) += xt_connmark.o +obj-$(CONFIG_NETFILTER_XT_SET) += xt_set.o +obj-$(CONFIG_NETFILTER_XT_NAT) += xt_nat.o + +# targets +obj-$(CONFIG_NETFILTER_XT_TARGET_AUDIT) += xt_AUDIT.o +obj-$(CONFIG_NETFILTER_XT_TARGET_CHECKSUM) += xt_CHECKSUM.o +obj-$(CONFIG_NETFILTER_XT_TARGET_CLASSIFY) += xt_CLASSIFY.o +obj-$(CONFIG_NETFILTER_XT_TARGET_CONNSECMARK) += xt_CONNSECMARK.o +obj-$(CONFIG_NETFILTER_XT_TARGET_CT) += xt_CT.o +obj-$(CONFIG_NETFILTER_XT_TARGET_DSCP) += xt_DSCP.o +obj-$(CONFIG_NETFILTER_XT_TARGET_HL) += xt_HL.o +obj-$(CONFIG_NETFILTER_XT_TARGET_HMARK) += xt_HMARK.o +obj-$(CONFIG_NETFILTER_XT_TARGET_LED) += xt_LED.o +obj-$(CONFIG_NETFILTER_XT_TARGET_LOG) += xt_LOG.o +obj-$(CONFIG_NETFILTER_XT_TARGET_NETMAP) += xt_NETMAP.o +obj-$(CONFIG_NETFILTER_XT_TARGET_NFLOG) += xt_NFLOG.o +obj-$(CONFIG_NETFILTER_XT_TARGET_NFQUEUE) += xt_NFQUEUE.o +obj-$(CONFIG_NETFILTER_XT_TARGET_RATEEST) += xt_RATEEST.o +obj-$(CONFIG_NETFILTER_XT_TARGET_REDIRECT) += xt_REDIRECT.o +obj-$(CONFIG_NETFILTER_XT_TARGET_MASQUERADE) += xt_MASQUERADE.o +obj-$(CONFIG_NETFILTER_XT_TARGET_SECMARK) += xt_SECMARK.o +obj-$(CONFIG_NETFILTER_XT_TARGET_TPROXY) += xt_TPROXY.o +obj-$(CONFIG_NETFILTER_XT_TARGET_TCPMSS) += xt_TCPMSS.o +obj-$(CONFIG_NETFILTER_XT_TARGET_TCPOPTSTRIP) += xt_TCPOPTSTRIP.o +obj-$(CONFIG_NETFILTER_XT_TARGET_TEE) += xt_TEE.o +obj-$(CONFIG_NETFILTER_XT_TARGET_TRACE) += xt_TRACE.o +obj-$(CONFIG_NETFILTER_XT_TARGET_IDLETIMER) += xt_IDLETIMER.o + +# matches +obj-$(CONFIG_NETFILTER_XT_MATCH_ADDRTYPE) += xt_addrtype.o +obj-$(CONFIG_NETFILTER_XT_MATCH_BPF) += xt_bpf.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CLUSTER) += xt_cluster.o +obj-$(CONFIG_NETFILTER_XT_MATCH_COMMENT) += xt_comment.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CONNBYTES) += xt_connbytes.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CONNLABEL) += xt_connlabel.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CONNLIMIT) += xt_connlimit.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CONNTRACK) += xt_conntrack.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CPU) += xt_cpu.o +obj-$(CONFIG_NETFILTER_XT_MATCH_DCCP) += xt_dccp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_DEVGROUP) += xt_devgroup.o +obj-$(CONFIG_NETFILTER_XT_MATCH_DSCP) += xt_dscp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_ECN) += xt_ecn.o +obj-$(CONFIG_NETFILTER_XT_MATCH_ESP) += xt_esp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_HASHLIMIT) += xt_hashlimit.o +obj-$(CONFIG_NETFILTER_XT_MATCH_HELPER) += xt_helper.o +obj-$(CONFIG_NETFILTER_XT_MATCH_HL) += xt_hl.o +obj-$(CONFIG_NETFILTER_XT_MATCH_IPCOMP) += xt_ipcomp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_IPRANGE) += xt_iprange.o +obj-$(CONFIG_NETFILTER_XT_MATCH_IPVS) += xt_ipvs.o +obj-$(CONFIG_NETFILTER_XT_MATCH_L2TP) += xt_l2tp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_LENGTH) += xt_length.o +obj-$(CONFIG_NETFILTER_XT_MATCH_LIMIT) += xt_limit.o +obj-$(CONFIG_NETFILTER_XT_MATCH_MAC) += xt_mac.o +obj-$(CONFIG_NETFILTER_XT_MATCH_MULTIPORT) += xt_multiport.o +obj-$(CONFIG_NETFILTER_XT_MATCH_NFACCT) += xt_nfacct.o +obj-$(CONFIG_NETFILTER_XT_MATCH_OSF) += xt_osf.o +obj-$(CONFIG_NETFILTER_XT_MATCH_OWNER) += xt_owner.o +obj-$(CONFIG_NETFILTER_XT_MATCH_CGROUP) += xt_cgroup.o +obj-$(CONFIG_NETFILTER_XT_MATCH_PHYSDEV) += xt_physdev.o +obj-$(CONFIG_NETFILTER_XT_MATCH_PKTTYPE) += xt_pkttype.o +obj-$(CONFIG_NETFILTER_XT_MATCH_POLICY) += xt_policy.o 
+obj-$(CONFIG_NETFILTER_XT_MATCH_QUOTA) += xt_quota.o +obj-$(CONFIG_NETFILTER_XT_MATCH_RATEEST) += xt_rateest.o +obj-$(CONFIG_NETFILTER_XT_MATCH_REALM) += xt_realm.o +obj-$(CONFIG_NETFILTER_XT_MATCH_RECENT) += xt_recent.o +obj-$(CONFIG_NETFILTER_XT_MATCH_SCTP) += xt_sctp.o +obj-$(CONFIG_NETFILTER_XT_MATCH_SOCKET) += xt_socket.o +obj-$(CONFIG_NETFILTER_XT_MATCH_STATE) += xt_state.o +obj-$(CONFIG_NETFILTER_XT_MATCH_STATISTIC) += xt_statistic.o +obj-$(CONFIG_NETFILTER_XT_MATCH_STRING) += xt_string.o +obj-$(CONFIG_NETFILTER_XT_MATCH_TCPMSS) += xt_tcpmss.o +obj-$(CONFIG_NETFILTER_XT_MATCH_TIME) += xt_time.o +obj-$(CONFIG_NETFILTER_XT_MATCH_U32) += xt_u32.o + +# ipset +obj-$(CONFIG_IP_SET) += ipset/ + +# IPVS +obj-$(CONFIG_IP_VS) += ipvs/ + +# lwtunnel +obj-$(CONFIG_LWTUNNEL) += nf_hooks_lwtunnel.o diff --git a/net/netfilter/core.c b/net/netfilter/core.c new file mode 100644 index 000000000..55a7f72d5 --- /dev/null +++ b/net/netfilter/core.c @@ -0,0 +1,793 @@ +/* netfilter.c: look after the filters for various protocols. + * Heavily influenced by the old firewall.c by David Bonn and Alan Cox. + * + * Thanks to Rob `CmdrTaco' Malda for not influencing this code in any + * way. + * + * This code is GPL. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +const struct nf_ipv6_ops __rcu *nf_ipv6_ops __read_mostly; +EXPORT_SYMBOL_GPL(nf_ipv6_ops); + +DEFINE_PER_CPU(bool, nf_skb_duplicated); +EXPORT_SYMBOL_GPL(nf_skb_duplicated); + +#ifdef CONFIG_JUMP_LABEL +struct static_key nf_hooks_needed[NFPROTO_NUMPROTO][NF_MAX_HOOKS]; +EXPORT_SYMBOL(nf_hooks_needed); +#endif + +static DEFINE_MUTEX(nf_hook_mutex); + +/* max hooks per family/hooknum */ +#define MAX_HOOK_COUNT 1024 + +#define nf_entry_dereference(e) \ + rcu_dereference_protected(e, lockdep_is_held(&nf_hook_mutex)) + +static struct nf_hook_entries *allocate_hook_entries_size(u16 num) +{ + struct nf_hook_entries *e; + size_t alloc = sizeof(*e) + + sizeof(struct nf_hook_entry) * num + + sizeof(struct nf_hook_ops *) * num + + sizeof(struct nf_hook_entries_rcu_head); + + if (num == 0) + return NULL; + + e = kvzalloc(alloc, GFP_KERNEL_ACCOUNT); + if (e) + e->num_hook_entries = num; + return e; +} + +static void __nf_hook_entries_free(struct rcu_head *h) +{ + struct nf_hook_entries_rcu_head *head; + + head = container_of(h, struct nf_hook_entries_rcu_head, head); + kvfree(head->allocation); +} + +static void nf_hook_entries_free(struct nf_hook_entries *e) +{ + struct nf_hook_entries_rcu_head *head; + struct nf_hook_ops **ops; + unsigned int num; + + if (!e) + return; + + num = e->num_hook_entries; + ops = nf_hook_entries_get_hook_ops(e); + head = (void *)&ops[num]; + head->allocation = e; + call_rcu(&head->head, __nf_hook_entries_free); +} + +static unsigned int accept_all(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + return NF_ACCEPT; /* ACCEPT makes nf_hook_slow call next hook */ +} + +static const struct nf_hook_ops dummy_ops = { + .hook = accept_all, + .priority = INT_MIN, +}; + +static struct nf_hook_entries * +nf_hook_entries_grow(const struct nf_hook_entries *old, + const struct nf_hook_ops *reg) +{ + unsigned int i, alloc_entries, nhooks, old_entries; + struct nf_hook_ops **orig_ops = NULL; + struct nf_hook_ops **new_ops; + struct nf_hook_entries *new; + bool inserted = false; + + alloc_entries = 1; + old_entries = old ? 
old->num_hook_entries : 0; + + if (old) { + orig_ops = nf_hook_entries_get_hook_ops(old); + + for (i = 0; i < old_entries; i++) { + if (orig_ops[i] != &dummy_ops) + alloc_entries++; + } + } + + if (alloc_entries > MAX_HOOK_COUNT) + return ERR_PTR(-E2BIG); + + new = allocate_hook_entries_size(alloc_entries); + if (!new) + return ERR_PTR(-ENOMEM); + + new_ops = nf_hook_entries_get_hook_ops(new); + + i = 0; + nhooks = 0; + while (i < old_entries) { + if (orig_ops[i] == &dummy_ops) { + ++i; + continue; + } + + if (inserted || reg->priority > orig_ops[i]->priority) { + new_ops[nhooks] = (void *)orig_ops[i]; + new->hooks[nhooks] = old->hooks[i]; + i++; + } else { + new_ops[nhooks] = (void *)reg; + new->hooks[nhooks].hook = reg->hook; + new->hooks[nhooks].priv = reg->priv; + inserted = true; + } + nhooks++; + } + + if (!inserted) { + new_ops[nhooks] = (void *)reg; + new->hooks[nhooks].hook = reg->hook; + new->hooks[nhooks].priv = reg->priv; + } + + return new; +} + +static void hooks_validate(const struct nf_hook_entries *hooks) +{ +#ifdef CONFIG_DEBUG_MISC + struct nf_hook_ops **orig_ops; + int prio = INT_MIN; + size_t i = 0; + + orig_ops = nf_hook_entries_get_hook_ops(hooks); + + for (i = 0; i < hooks->num_hook_entries; i++) { + if (orig_ops[i] == &dummy_ops) + continue; + + WARN_ON(orig_ops[i]->priority < prio); + + if (orig_ops[i]->priority > prio) + prio = orig_ops[i]->priority; + } +#endif +} + +int nf_hook_entries_insert_raw(struct nf_hook_entries __rcu **pp, + const struct nf_hook_ops *reg) +{ + struct nf_hook_entries *new_hooks; + struct nf_hook_entries *p; + + p = rcu_dereference_raw(*pp); + new_hooks = nf_hook_entries_grow(p, reg); + if (IS_ERR(new_hooks)) + return PTR_ERR(new_hooks); + + hooks_validate(new_hooks); + + rcu_assign_pointer(*pp, new_hooks); + + BUG_ON(p == new_hooks); + nf_hook_entries_free(p); + return 0; +} +EXPORT_SYMBOL_GPL(nf_hook_entries_insert_raw); + +/* + * __nf_hook_entries_try_shrink - try to shrink hook array + * + * @old -- current hook blob at @pp + * @pp -- location of hook blob + * + * Hook unregistration must always succeed, so to-be-removed hooks + * are replaced by a dummy one that will just move to next hook. + * + * This counts the current dummy hooks, attempts to allocate new blob, + * copies the live hooks, then replaces and discards old one. + * + * return values: + * + * Returns address to free, or NULL. 
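 *
 * The caller must pass the returned blob to nf_hook_entries_free(), which
 * defers the underlying kvfree() until an RCU grace period has elapsed;
 * see __nf_unregister_net_hook() and nf_hook_entries_delete_raw() below.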
+ */ +static void *__nf_hook_entries_try_shrink(struct nf_hook_entries *old, + struct nf_hook_entries __rcu **pp) +{ + unsigned int i, j, skip = 0, hook_entries; + struct nf_hook_entries *new = NULL; + struct nf_hook_ops **orig_ops; + struct nf_hook_ops **new_ops; + + if (WARN_ON_ONCE(!old)) + return NULL; + + orig_ops = nf_hook_entries_get_hook_ops(old); + for (i = 0; i < old->num_hook_entries; i++) { + if (orig_ops[i] == &dummy_ops) + skip++; + } + + /* if skip == hook_entries all hooks have been removed */ + hook_entries = old->num_hook_entries; + if (skip == hook_entries) + goto out_assign; + + if (skip == 0) + return NULL; + + hook_entries -= skip; + new = allocate_hook_entries_size(hook_entries); + if (!new) + return NULL; + + new_ops = nf_hook_entries_get_hook_ops(new); + for (i = 0, j = 0; i < old->num_hook_entries; i++) { + if (orig_ops[i] == &dummy_ops) + continue; + new->hooks[j] = old->hooks[i]; + new_ops[j] = (void *)orig_ops[i]; + j++; + } + hooks_validate(new); +out_assign: + rcu_assign_pointer(*pp, new); + return old; +} + +static struct nf_hook_entries __rcu ** +nf_hook_entry_head(struct net *net, int pf, unsigned int hooknum, + struct net_device *dev) +{ + switch (pf) { + case NFPROTO_NETDEV: + break; +#ifdef CONFIG_NETFILTER_FAMILY_ARP + case NFPROTO_ARP: + if (WARN_ON_ONCE(ARRAY_SIZE(net->nf.hooks_arp) <= hooknum)) + return NULL; + return net->nf.hooks_arp + hooknum; +#endif +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + case NFPROTO_BRIDGE: + if (WARN_ON_ONCE(ARRAY_SIZE(net->nf.hooks_bridge) <= hooknum)) + return NULL; + return net->nf.hooks_bridge + hooknum; +#endif +#ifdef CONFIG_NETFILTER_INGRESS + case NFPROTO_INET: + if (WARN_ON_ONCE(hooknum != NF_INET_INGRESS)) + return NULL; + if (!dev || dev_net(dev) != net) { + WARN_ON_ONCE(1); + return NULL; + } + return &dev->nf_hooks_ingress; +#endif + case NFPROTO_IPV4: + if (WARN_ON_ONCE(ARRAY_SIZE(net->nf.hooks_ipv4) <= hooknum)) + return NULL; + return net->nf.hooks_ipv4 + hooknum; + case NFPROTO_IPV6: + if (WARN_ON_ONCE(ARRAY_SIZE(net->nf.hooks_ipv6) <= hooknum)) + return NULL; + return net->nf.hooks_ipv6 + hooknum; + default: + WARN_ON_ONCE(1); + return NULL; + } + +#ifdef CONFIG_NETFILTER_INGRESS + if (hooknum == NF_NETDEV_INGRESS) { + if (dev && dev_net(dev) == net) + return &dev->nf_hooks_ingress; + } +#endif +#ifdef CONFIG_NETFILTER_EGRESS + if (hooknum == NF_NETDEV_EGRESS) { + if (dev && dev_net(dev) == net) + return &dev->nf_hooks_egress; + } +#endif + WARN_ON_ONCE(1); + return NULL; +} + +static int nf_ingress_check(struct net *net, const struct nf_hook_ops *reg, + int hooknum) +{ +#ifndef CONFIG_NETFILTER_INGRESS + if (reg->hooknum == hooknum) + return -EOPNOTSUPP; +#endif + if (reg->hooknum != hooknum || + !reg->dev || dev_net(reg->dev) != net) + return -EINVAL; + + return 0; +} + +static inline bool __maybe_unused nf_ingress_hook(const struct nf_hook_ops *reg, + int pf) +{ + if ((pf == NFPROTO_NETDEV && reg->hooknum == NF_NETDEV_INGRESS) || + (pf == NFPROTO_INET && reg->hooknum == NF_INET_INGRESS)) + return true; + + return false; +} + +static inline bool __maybe_unused nf_egress_hook(const struct nf_hook_ops *reg, + int pf) +{ + return pf == NFPROTO_NETDEV && reg->hooknum == NF_NETDEV_EGRESS; +} + +static void nf_static_key_inc(const struct nf_hook_ops *reg, int pf) +{ +#ifdef CONFIG_JUMP_LABEL + int hooknum; + + if (pf == NFPROTO_INET && reg->hooknum == NF_INET_INGRESS) { + pf = NFPROTO_NETDEV; + hooknum = NF_NETDEV_INGRESS; + } else { + hooknum = reg->hooknum; + } + 
static_key_slow_inc(&nf_hooks_needed[pf][hooknum]); +#endif +} + +static void nf_static_key_dec(const struct nf_hook_ops *reg, int pf) +{ +#ifdef CONFIG_JUMP_LABEL + int hooknum; + + if (pf == NFPROTO_INET && reg->hooknum == NF_INET_INGRESS) { + pf = NFPROTO_NETDEV; + hooknum = NF_NETDEV_INGRESS; + } else { + hooknum = reg->hooknum; + } + static_key_slow_dec(&nf_hooks_needed[pf][hooknum]); +#endif +} + +static int __nf_register_net_hook(struct net *net, int pf, + const struct nf_hook_ops *reg) +{ + struct nf_hook_entries *p, *new_hooks; + struct nf_hook_entries __rcu **pp; + int err; + + switch (pf) { + case NFPROTO_NETDEV: +#ifndef CONFIG_NETFILTER_INGRESS + if (reg->hooknum == NF_NETDEV_INGRESS) + return -EOPNOTSUPP; +#endif +#ifndef CONFIG_NETFILTER_EGRESS + if (reg->hooknum == NF_NETDEV_EGRESS) + return -EOPNOTSUPP; +#endif + if ((reg->hooknum != NF_NETDEV_INGRESS && + reg->hooknum != NF_NETDEV_EGRESS) || + !reg->dev || dev_net(reg->dev) != net) + return -EINVAL; + break; + case NFPROTO_INET: + if (reg->hooknum != NF_INET_INGRESS) + break; + + err = nf_ingress_check(net, reg, NF_INET_INGRESS); + if (err < 0) + return err; + break; + } + + pp = nf_hook_entry_head(net, pf, reg->hooknum, reg->dev); + if (!pp) + return -EINVAL; + + mutex_lock(&nf_hook_mutex); + + p = nf_entry_dereference(*pp); + new_hooks = nf_hook_entries_grow(p, reg); + + if (!IS_ERR(new_hooks)) { + hooks_validate(new_hooks); + rcu_assign_pointer(*pp, new_hooks); + } + + mutex_unlock(&nf_hook_mutex); + if (IS_ERR(new_hooks)) + return PTR_ERR(new_hooks); + +#ifdef CONFIG_NETFILTER_INGRESS + if (nf_ingress_hook(reg, pf)) + net_inc_ingress_queue(); +#endif +#ifdef CONFIG_NETFILTER_EGRESS + if (nf_egress_hook(reg, pf)) + net_inc_egress_queue(); +#endif + nf_static_key_inc(reg, pf); + + BUG_ON(p == new_hooks); + nf_hook_entries_free(p); + return 0; +} + +/* + * nf_remove_net_hook - remove a hook from blob + * + * @oldp: current address of hook blob + * @unreg: hook to unregister + * + * This cannot fail, hook unregistration must always succeed. + * Therefore replace the to-be-removed hook with a dummy hook. 
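 *
 * The dummy is the accept_all/dummy_ops pair defined above, so packets
 * traversing the blob under RCU simply fall through to the next entry;
 * the resulting hole is compacted later by __nf_hook_entries_try_shrink().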
+ */ +static bool nf_remove_net_hook(struct nf_hook_entries *old, + const struct nf_hook_ops *unreg) +{ + struct nf_hook_ops **orig_ops; + unsigned int i; + + orig_ops = nf_hook_entries_get_hook_ops(old); + for (i = 0; i < old->num_hook_entries; i++) { + if (orig_ops[i] != unreg) + continue; + WRITE_ONCE(old->hooks[i].hook, accept_all); + WRITE_ONCE(orig_ops[i], (void *)&dummy_ops); + return true; + } + + return false; +} + +static void __nf_unregister_net_hook(struct net *net, int pf, + const struct nf_hook_ops *reg) +{ + struct nf_hook_entries __rcu **pp; + struct nf_hook_entries *p; + + pp = nf_hook_entry_head(net, pf, reg->hooknum, reg->dev); + if (!pp) + return; + + mutex_lock(&nf_hook_mutex); + + p = nf_entry_dereference(*pp); + if (WARN_ON_ONCE(!p)) { + mutex_unlock(&nf_hook_mutex); + return; + } + + if (nf_remove_net_hook(p, reg)) { +#ifdef CONFIG_NETFILTER_INGRESS + if (nf_ingress_hook(reg, pf)) + net_dec_ingress_queue(); +#endif +#ifdef CONFIG_NETFILTER_EGRESS + if (nf_egress_hook(reg, pf)) + net_dec_egress_queue(); +#endif + nf_static_key_dec(reg, pf); + } else { + WARN_ONCE(1, "hook not found, pf %d num %d", pf, reg->hooknum); + } + + p = __nf_hook_entries_try_shrink(p, pp); + mutex_unlock(&nf_hook_mutex); + if (!p) + return; + + nf_queue_nf_hook_drop(net); + nf_hook_entries_free(p); +} + +void nf_unregister_net_hook(struct net *net, const struct nf_hook_ops *reg) +{ + if (reg->pf == NFPROTO_INET) { + if (reg->hooknum == NF_INET_INGRESS) { + __nf_unregister_net_hook(net, NFPROTO_INET, reg); + } else { + __nf_unregister_net_hook(net, NFPROTO_IPV4, reg); + __nf_unregister_net_hook(net, NFPROTO_IPV6, reg); + } + } else { + __nf_unregister_net_hook(net, reg->pf, reg); + } +} +EXPORT_SYMBOL(nf_unregister_net_hook); + +void nf_hook_entries_delete_raw(struct nf_hook_entries __rcu **pp, + const struct nf_hook_ops *reg) +{ + struct nf_hook_entries *p; + + p = rcu_dereference_raw(*pp); + if (nf_remove_net_hook(p, reg)) { + p = __nf_hook_entries_try_shrink(p, pp); + nf_hook_entries_free(p); + } +} +EXPORT_SYMBOL_GPL(nf_hook_entries_delete_raw); + +int nf_register_net_hook(struct net *net, const struct nf_hook_ops *reg) +{ + int err; + + if (reg->pf == NFPROTO_INET) { + if (reg->hooknum == NF_INET_INGRESS) { + err = __nf_register_net_hook(net, NFPROTO_INET, reg); + if (err < 0) + return err; + } else { + err = __nf_register_net_hook(net, NFPROTO_IPV4, reg); + if (err < 0) + return err; + + err = __nf_register_net_hook(net, NFPROTO_IPV6, reg); + if (err < 0) { + __nf_unregister_net_hook(net, NFPROTO_IPV4, reg); + return err; + } + } + } else { + err = __nf_register_net_hook(net, reg->pf, reg); + if (err < 0) + return err; + } + + return 0; +} +EXPORT_SYMBOL(nf_register_net_hook); + +int nf_register_net_hooks(struct net *net, const struct nf_hook_ops *reg, + unsigned int n) +{ + unsigned int i; + int err = 0; + + for (i = 0; i < n; i++) { + err = nf_register_net_hook(net, ®[i]); + if (err) + goto err; + } + return err; + +err: + if (i > 0) + nf_unregister_net_hooks(net, reg, i); + return err; +} +EXPORT_SYMBOL(nf_register_net_hooks); + +void nf_unregister_net_hooks(struct net *net, const struct nf_hook_ops *reg, + unsigned int hookcount) +{ + unsigned int i; + + for (i = 0; i < hookcount; i++) + nf_unregister_net_hook(net, ®[i]); +} +EXPORT_SYMBOL(nf_unregister_net_hooks); + +/* Returns 1 if okfn() needs to be executed by the caller, + * -EPERM for NF_DROP, 0 otherwise. Caller must hold rcu_read_lock. 
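 *
 * An illustrative caller sketch (for documentation only, not code from this
 * file) of how an NF_HOOK-style wrapper consumes this return value:
 *
 *	ret = nf_hook_slow(skb, &state, hook_head, 0);
 *	if (ret == 1)
 *		ret = state.okfn(state.net, state.sk, skb);
 *	return ret;
 *
 * i.e. only a return of 1 passes the packet on to okfn(); 0 means a hook
 * took ownership of it (NF_STOLEN or successfully queued) and a negative
 * value (typically -EPERM for NF_DROP) means it was freed.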
*/ +int nf_hook_slow(struct sk_buff *skb, struct nf_hook_state *state, + const struct nf_hook_entries *e, unsigned int s) +{ + unsigned int verdict; + int ret; + + for (; s < e->num_hook_entries; s++) { + verdict = nf_hook_entry_hookfn(&e->hooks[s], skb, state); + switch (verdict & NF_VERDICT_MASK) { + case NF_ACCEPT: + break; + case NF_DROP: + kfree_skb_reason(skb, + SKB_DROP_REASON_NETFILTER_DROP); + ret = NF_DROP_GETERR(verdict); + if (ret == 0) + ret = -EPERM; + return ret; + case NF_QUEUE: + ret = nf_queue(skb, state, s, verdict); + if (ret == 1) + continue; + return ret; + default: + /* Implicit handling for NF_STOLEN, as well as any other + * non conventional verdicts. + */ + return 0; + } + } + + return 1; +} +EXPORT_SYMBOL(nf_hook_slow); + +void nf_hook_slow_list(struct list_head *head, struct nf_hook_state *state, + const struct nf_hook_entries *e) +{ + struct sk_buff *skb, *next; + struct list_head sublist; + int ret; + + INIT_LIST_HEAD(&sublist); + + list_for_each_entry_safe(skb, next, head, list) { + skb_list_del_init(skb); + ret = nf_hook_slow(skb, state, e, 0); + if (ret == 1) + list_add_tail(&skb->list, &sublist); + } + /* Put passed packets back on main list */ + list_splice(&sublist, head); +} +EXPORT_SYMBOL(nf_hook_slow_list); + +/* This needs to be compiled in any case to avoid dependencies between the + * nfnetlink_queue code and nf_conntrack. + */ +const struct nfnl_ct_hook __rcu *nfnl_ct_hook __read_mostly; +EXPORT_SYMBOL_GPL(nfnl_ct_hook); + +const struct nf_ct_hook __rcu *nf_ct_hook __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_hook); + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +u8 nf_ctnetlink_has_listener; +EXPORT_SYMBOL_GPL(nf_ctnetlink_has_listener); + +const struct nf_nat_hook __rcu *nf_nat_hook __read_mostly; +EXPORT_SYMBOL_GPL(nf_nat_hook); + +/* This does not belong here, but locally generated errors need it if connection + * tracking in use: without this, connection may not be in hash table, and hence + * manufactured ICMP or RST packets will not be associated with it. + */ +void nf_ct_attach(struct sk_buff *new, const struct sk_buff *skb) +{ + const struct nf_ct_hook *ct_hook; + + if (skb->_nfct) { + rcu_read_lock(); + ct_hook = rcu_dereference(nf_ct_hook); + if (ct_hook) + ct_hook->attach(new, skb); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL(nf_ct_attach); + +void nf_conntrack_destroy(struct nf_conntrack *nfct) +{ + const struct nf_ct_hook *ct_hook; + + rcu_read_lock(); + ct_hook = rcu_dereference(nf_ct_hook); + if (ct_hook) + ct_hook->destroy(nfct); + rcu_read_unlock(); + + WARN_ON(!ct_hook); +} +EXPORT_SYMBOL(nf_conntrack_destroy); + +bool nf_ct_get_tuple_skb(struct nf_conntrack_tuple *dst_tuple, + const struct sk_buff *skb) +{ + const struct nf_ct_hook *ct_hook; + bool ret = false; + + rcu_read_lock(); + ct_hook = rcu_dereference(nf_ct_hook); + if (ct_hook) + ret = ct_hook->get_tuple_skb(dst_tuple, skb); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL(nf_ct_get_tuple_skb); + +/* Built-in default zone used e.g. by modules. 
*/ +const struct nf_conntrack_zone nf_ct_zone_dflt = { + .id = NF_CT_DEFAULT_ZONE_ID, + .dir = NF_CT_DEFAULT_ZONE_DIR, +}; +EXPORT_SYMBOL_GPL(nf_ct_zone_dflt); +#endif /* CONFIG_NF_CONNTRACK */ + +static void __net_init +__netfilter_net_init(struct nf_hook_entries __rcu **e, int max) +{ + int h; + + for (h = 0; h < max; h++) + RCU_INIT_POINTER(e[h], NULL); +} + +static int __net_init netfilter_net_init(struct net *net) +{ + __netfilter_net_init(net->nf.hooks_ipv4, ARRAY_SIZE(net->nf.hooks_ipv4)); + __netfilter_net_init(net->nf.hooks_ipv6, ARRAY_SIZE(net->nf.hooks_ipv6)); +#ifdef CONFIG_NETFILTER_FAMILY_ARP + __netfilter_net_init(net->nf.hooks_arp, ARRAY_SIZE(net->nf.hooks_arp)); +#endif +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + __netfilter_net_init(net->nf.hooks_bridge, ARRAY_SIZE(net->nf.hooks_bridge)); +#endif +#ifdef CONFIG_PROC_FS + net->nf.proc_netfilter = proc_net_mkdir(net, "netfilter", + net->proc_net); + if (!net->nf.proc_netfilter) { + if (!net_eq(net, &init_net)) + pr_err("cannot create netfilter proc entry"); + + return -ENOMEM; + } +#endif + + return 0; +} + +static void __net_exit netfilter_net_exit(struct net *net) +{ + remove_proc_entry("netfilter", net->proc_net); +} + +static struct pernet_operations netfilter_net_ops = { + .init = netfilter_net_init, + .exit = netfilter_net_exit, +}; + +int __init netfilter_init(void) +{ + int ret; + + ret = register_pernet_subsys(&netfilter_net_ops); + if (ret < 0) + goto err; + + ret = netfilter_log_init(); + if (ret < 0) + goto err_pernet; + + return 0; +err_pernet: + unregister_pernet_subsys(&netfilter_net_ops); +err: + return ret; +} diff --git a/net/netfilter/ipset/Kconfig b/net/netfilter/ipset/Kconfig new file mode 100644 index 000000000..3c273483d --- /dev/null +++ b/net/netfilter/ipset/Kconfig @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig IP_SET + tristate "IP set support" + depends on INET && NETFILTER + select NETFILTER_NETLINK + help + This option adds IP set support to the kernel. + In order to define and use the sets, you need the userspace utility + ipset(8). You can use the sets in netfilter via the "set" match + and "SET" target. + + To compile it as a module, choose M here. If unsure, say N. + +if IP_SET + +config IP_SET_MAX + int "Maximum number of IP sets" + default 256 + range 2 65534 + depends on IP_SET + help + You can define here default value of the maximum number + of IP sets for the kernel. + + The value can be overridden by the 'max_sets' module + parameter of the 'ip_set' module. + +config IP_SET_BITMAP_IP + tristate "bitmap:ip set support" + depends on IP_SET + help + This option adds the bitmap:ip set type support, by which one + can store IPv4 addresses (or network addresse) from a range. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_BITMAP_IPMAC + tristate "bitmap:ip,mac set support" + depends on IP_SET + help + This option adds the bitmap:ip,mac set type support, by which one + can store IPv4 address and (source) MAC address pairs from a range. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_BITMAP_PORT + tristate "bitmap:port set support" + depends on IP_SET + help + This option adds the bitmap:port set type support, by which one + can store TCP/UDP port numbers from a range. + + To compile it as a module, choose M here. If unsure, say N. 
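Each of the bitmap types above stores an element as a single bit whose index is derived from the configured range. The sketch below is a standalone, userspace restatement of the ip_to_id() helper found in ip_set_bitmap_ip.c further down in this patch; the function signature and the sample values are illustrative assumptions, not part of the patch itself.

#include <stdint.h>
#include <stdio.h>

/* Mirror of bitmap:ip's index calculation: mask the address with the set's
 * netmask, subtract the first address of the range, then divide by the
 * number of hosts covered per bit.
 */
static uint32_t ip_to_id(uint32_t first_ip, uint32_t hosts,
			 uint8_t netmask, uint32_t ip)
{
	uint32_t hostmask = netmask ? ~0U << (32 - netmask) : 0;

	return ((ip & hostmask) - first_ip) / hosts;
}

int main(void)
{
	uint32_t first_ip = 0xC0A80000;	/* 192.168.0.0 */

	/* bitmap:ip range 192.168.0.0-192.168.3.255, netmask 32, hosts = 1:
	 * 192.168.1.5 lands on bit 261.
	 */
	printf("%u\n", ip_to_id(first_ip, 1, 32, 0xC0A80105));
	return 0;
}

The hash and list types that follow do not use this fixed mapping; they size and resize their storage dynamically.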
+ +config IP_SET_HASH_IP + tristate "hash:ip set support" + depends on IP_SET + help + This option adds the hash:ip set type support, by which one + can store arbitrary IPv4 or IPv6 addresses (or network addresses) + in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_IPMARK + tristate "hash:ip,mark set support" + depends on IP_SET + help + This option adds the hash:ip,mark set type support, by which one + can store IPv4/IPv6 address and mark pairs. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_IPPORT + tristate "hash:ip,port set support" + depends on IP_SET + help + This option adds the hash:ip,port set type support, by which one + can store IPv4/IPv6 address and protocol/port pairs. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_IPPORTIP + tristate "hash:ip,port,ip set support" + depends on IP_SET + help + This option adds the hash:ip,port,ip set type support, by which + one can store IPv4/IPv6 address, protocol/port, and IPv4/IPv6 + address triples in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_IPPORTNET + tristate "hash:ip,port,net set support" + depends on IP_SET + help + This option adds the hash:ip,port,net set type support, by which + one can store IPv4/IPv6 address, protocol/port, and IPv4/IPv6 + network address/prefix triples in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_IPMAC + tristate "hash:ip,mac set support" + depends on IP_SET + help + This option adds the hash:ip,mac set type support, by which + one can store IPv4/IPv6 address and MAC (ethernet address) pairs in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_MAC + tristate "hash:mac set support" + depends on IP_SET + help + This option adds the hash:mac set type support, by which + one can store MAC (ethernet address) elements in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_NETPORTNET + tristate "hash:net,port,net set support" + depends on IP_SET + help + This option adds the hash:net,port,net set type support, by which + one can store two IPv4/IPv6 subnets, and a protocol/port in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_NET + tristate "hash:net set support" + depends on IP_SET + help + This option adds the hash:net set type support, by which + one can store IPv4/IPv6 network address/prefix elements in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_NETNET + tristate "hash:net,net set support" + depends on IP_SET + help + This option adds the hash:net,net set type support, by which + one can store IPv4/IPv6 network address/prefix pairs in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_NETPORT + tristate "hash:net,port set support" + depends on IP_SET + help + This option adds the hash:net,port set type support, by which + one can store IPv4/IPv6 network address/prefix and + protocol/port pairs as elements in a set. + + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_HASH_NETIFACE + tristate "hash:net,iface set support" + depends on IP_SET + help + This option adds the hash:net,iface set type support, by which + one can store IPv4/IPv6 network address/prefix and + interface name pairs as elements in a set. 
+ + To compile it as a module, choose M here. If unsure, say N. + +config IP_SET_LIST_SET + tristate "list:set set support" + depends on IP_SET + help + This option adds the list:set set type support. In this + kind of set one can store the name of other sets and it forms + an ordered union of the member sets. + + To compile it as a module, choose M here. If unsure, say N. + +endif # IP_SET diff --git a/net/netfilter/ipset/Makefile b/net/netfilter/ipset/Makefile new file mode 100644 index 000000000..a445a6bf4 --- /dev/null +++ b/net/netfilter/ipset/Makefile @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the ipset modules +# + +ip_set-y := ip_set_core.o ip_set_getport.o pfxlen.o + +# ipset core +obj-$(CONFIG_IP_SET) += ip_set.o + +# bitmap types +obj-$(CONFIG_IP_SET_BITMAP_IP) += ip_set_bitmap_ip.o +obj-$(CONFIG_IP_SET_BITMAP_IPMAC) += ip_set_bitmap_ipmac.o +obj-$(CONFIG_IP_SET_BITMAP_PORT) += ip_set_bitmap_port.o + +# hash types +obj-$(CONFIG_IP_SET_HASH_IP) += ip_set_hash_ip.o +obj-$(CONFIG_IP_SET_HASH_IPMAC) += ip_set_hash_ipmac.o +obj-$(CONFIG_IP_SET_HASH_IPMARK) += ip_set_hash_ipmark.o +obj-$(CONFIG_IP_SET_HASH_IPPORT) += ip_set_hash_ipport.o +obj-$(CONFIG_IP_SET_HASH_IPPORTIP) += ip_set_hash_ipportip.o +obj-$(CONFIG_IP_SET_HASH_IPPORTNET) += ip_set_hash_ipportnet.o +obj-$(CONFIG_IP_SET_HASH_MAC) += ip_set_hash_mac.o +obj-$(CONFIG_IP_SET_HASH_NET) += ip_set_hash_net.o +obj-$(CONFIG_IP_SET_HASH_NETPORT) += ip_set_hash_netport.o +obj-$(CONFIG_IP_SET_HASH_NETIFACE) += ip_set_hash_netiface.o +obj-$(CONFIG_IP_SET_HASH_NETNET) += ip_set_hash_netnet.o +obj-$(CONFIG_IP_SET_HASH_NETPORTNET) += ip_set_hash_netportnet.o + +# list types +obj-$(CONFIG_IP_SET_LIST_SET) += ip_set_list_set.o diff --git a/net/netfilter/ipset/ip_set_bitmap_gen.h b/net/netfilter/ipset/ip_set_bitmap_gen.h new file mode 100644 index 000000000..26ab0e961 --- /dev/null +++ b/net/netfilter/ipset/ip_set_bitmap_gen.h @@ -0,0 +1,306 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* Copyright (C) 2013 Jozsef Kadlecsik */ + +#ifndef __IP_SET_BITMAP_IP_GEN_H +#define __IP_SET_BITMAP_IP_GEN_H + +#define mtype_do_test IPSET_TOKEN(MTYPE, _do_test) +#define mtype_gc_test IPSET_TOKEN(MTYPE, _gc_test) +#define mtype_is_filled IPSET_TOKEN(MTYPE, _is_filled) +#define mtype_do_add IPSET_TOKEN(MTYPE, _do_add) +#define mtype_ext_cleanup IPSET_TOKEN(MTYPE, _ext_cleanup) +#define mtype_do_del IPSET_TOKEN(MTYPE, _do_del) +#define mtype_do_list IPSET_TOKEN(MTYPE, _do_list) +#define mtype_do_head IPSET_TOKEN(MTYPE, _do_head) +#define mtype_adt_elem IPSET_TOKEN(MTYPE, _adt_elem) +#define mtype_add_timeout IPSET_TOKEN(MTYPE, _add_timeout) +#define mtype_gc_init IPSET_TOKEN(MTYPE, _gc_init) +#define mtype_kadt IPSET_TOKEN(MTYPE, _kadt) +#define mtype_uadt IPSET_TOKEN(MTYPE, _uadt) +#define mtype_destroy IPSET_TOKEN(MTYPE, _destroy) +#define mtype_memsize IPSET_TOKEN(MTYPE, _memsize) +#define mtype_flush IPSET_TOKEN(MTYPE, _flush) +#define mtype_head IPSET_TOKEN(MTYPE, _head) +#define mtype_same_set IPSET_TOKEN(MTYPE, _same_set) +#define mtype_elem IPSET_TOKEN(MTYPE, _elem) +#define mtype_test IPSET_TOKEN(MTYPE, _test) +#define mtype_add IPSET_TOKEN(MTYPE, _add) +#define mtype_del IPSET_TOKEN(MTYPE, _del) +#define mtype_list IPSET_TOKEN(MTYPE, _list) +#define mtype_gc IPSET_TOKEN(MTYPE, _gc) +#define mtype MTYPE + +#define get_ext(set, map, id) ((map)->extensions + ((set)->dsize * (id))) + +static void +mtype_gc_init(struct ip_set *set, void (*gc)(struct timer_list *t)) +{ + struct mtype *map = set->data; + + 
timer_setup(&map->gc, gc, 0); + mod_timer(&map->gc, jiffies + IPSET_GC_PERIOD(set->timeout) * HZ); +} + +static void +mtype_ext_cleanup(struct ip_set *set) +{ + struct mtype *map = set->data; + u32 id; + + for (id = 0; id < map->elements; id++) + if (test_bit(id, map->members)) + ip_set_ext_destroy(set, get_ext(set, map, id)); +} + +static void +mtype_destroy(struct ip_set *set) +{ + struct mtype *map = set->data; + + if (SET_WITH_TIMEOUT(set)) + del_timer_sync(&map->gc); + + if (set->dsize && set->extensions & IPSET_EXT_DESTROY) + mtype_ext_cleanup(set); + ip_set_free(map->members); + ip_set_free(map); + + set->data = NULL; +} + +static void +mtype_flush(struct ip_set *set) +{ + struct mtype *map = set->data; + + if (set->extensions & IPSET_EXT_DESTROY) + mtype_ext_cleanup(set); + bitmap_zero(map->members, map->elements); + set->elements = 0; + set->ext_size = 0; +} + +/* Calculate the actual memory size of the set data */ +static size_t +mtype_memsize(const struct mtype *map, size_t dsize) +{ + return sizeof(*map) + map->memsize + + map->elements * dsize; +} + +static int +mtype_head(struct ip_set *set, struct sk_buff *skb) +{ + const struct mtype *map = set->data; + struct nlattr *nested; + size_t memsize = mtype_memsize(map, set->dsize) + set->ext_size; + + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) + goto nla_put_failure; + if (mtype_do_head(skb, map) || + nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) || + nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) || + nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements))) + goto nla_put_failure; + if (unlikely(ip_set_put_flags(skb, set))) + goto nla_put_failure; + nla_nest_end(skb, nested); + + return 0; +nla_put_failure: + return -EMSGSIZE; +} + +static int +mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct mtype *map = set->data; + const struct mtype_adt_elem *e = value; + void *x = get_ext(set, map, e->id); + int ret = mtype_do_test(e, map, set->dsize); + + if (ret <= 0) + return ret; + return ip_set_match_extensions(set, ext, mext, flags, x); +} + +static int +mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct mtype *map = set->data; + const struct mtype_adt_elem *e = value; + void *x = get_ext(set, map, e->id); + int ret = mtype_do_add(e, map, flags, set->dsize); + + if (ret == IPSET_ADD_FAILED) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(x, set))) { + set->elements--; + ret = 0; + } else if (!(flags & IPSET_FLAG_EXIST)) { + set_bit(e->id, map->members); + return -IPSET_ERR_EXIST; + } + /* Element is re-added, cleanup extensions */ + ip_set_ext_destroy(set, x); + } + if (ret > 0) + set->elements--; + + if (SET_WITH_TIMEOUT(set)) +#ifdef IP_SET_BITMAP_STORED_TIMEOUT + mtype_add_timeout(ext_timeout(x, set), e, ext, set, map, ret); +#else + ip_set_timeout_set(ext_timeout(x, set), ext->timeout); +#endif + + if (SET_WITH_COUNTER(set)) + ip_set_init_counter(ext_counter(x, set), ext); + if (SET_WITH_COMMENT(set)) + ip_set_init_comment(set, ext_comment(x, set), ext); + if (SET_WITH_SKBINFO(set)) + ip_set_init_skbinfo(ext_skbinfo(x, set), ext); + + /* Activate element */ + set_bit(e->id, map->members); + set->elements++; + + return 0; +} + +static int +mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct mtype *map = set->data; + const struct mtype_adt_elem *e = 
value; + void *x = get_ext(set, map, e->id); + + if (mtype_do_del(e, map)) + return -IPSET_ERR_EXIST; + + ip_set_ext_destroy(set, x); + set->elements--; + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(x, set))) + return -IPSET_ERR_EXIST; + + return 0; +} + +#ifndef IP_SET_BITMAP_STORED_TIMEOUT +static bool +mtype_is_filled(const struct mtype_elem *x) +{ + return true; +} +#endif + +static int +mtype_list(const struct ip_set *set, + struct sk_buff *skb, struct netlink_callback *cb) +{ + struct mtype *map = set->data; + struct nlattr *adt, *nested; + void *x; + u32 id, first = cb->args[IPSET_CB_ARG0]; + int ret = 0; + + adt = nla_nest_start(skb, IPSET_ATTR_ADT); + if (!adt) + return -EMSGSIZE; + /* Extensions may be replaced */ + rcu_read_lock(); + for (; cb->args[IPSET_CB_ARG0] < map->elements; + cb->args[IPSET_CB_ARG0]++) { + cond_resched_rcu(); + id = cb->args[IPSET_CB_ARG0]; + x = get_ext(set, map, id); + if (!test_bit(id, map->members) || + (SET_WITH_TIMEOUT(set) && +#ifdef IP_SET_BITMAP_STORED_TIMEOUT + mtype_is_filled(x) && +#endif + ip_set_timeout_expired(ext_timeout(x, set)))) + continue; + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) { + if (id == first) { + nla_nest_cancel(skb, adt); + ret = -EMSGSIZE; + goto out; + } + + goto nla_put_failure; + } + if (mtype_do_list(skb, map, id, set->dsize)) + goto nla_put_failure; + if (ip_set_put_extensions(skb, set, x, mtype_is_filled(x))) + goto nla_put_failure; + nla_nest_end(skb, nested); + } + nla_nest_end(skb, adt); + + /* Set listing finished */ + cb->args[IPSET_CB_ARG0] = 0; + + goto out; + +nla_put_failure: + nla_nest_cancel(skb, nested); + if (unlikely(id == first)) { + cb->args[IPSET_CB_ARG0] = 0; + ret = -EMSGSIZE; + } + nla_nest_end(skb, adt); +out: + rcu_read_unlock(); + return ret; +} + +static void +mtype_gc(struct timer_list *t) +{ + struct mtype *map = from_timer(map, t, gc); + struct ip_set *set = map->set; + void *x; + u32 id; + + /* We run parallel with other readers (test element) + * but adding/deleting new entries is locked out + */ + spin_lock_bh(&set->lock); + for (id = 0; id < map->elements; id++) + if (mtype_gc_test(id, map, set->dsize)) { + x = get_ext(set, map, id); + if (ip_set_timeout_expired(ext_timeout(x, set))) { + clear_bit(id, map->members); + ip_set_ext_destroy(set, x); + set->elements--; + } + } + spin_unlock_bh(&set->lock); + + map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ; + add_timer(&map->gc); +} + +static const struct ip_set_type_variant mtype = { + .kadt = mtype_kadt, + .uadt = mtype_uadt, + .adt = { + [IPSET_ADD] = mtype_add, + [IPSET_DEL] = mtype_del, + [IPSET_TEST] = mtype_test, + }, + .destroy = mtype_destroy, + .flush = mtype_flush, + .head = mtype_head, + .list = mtype_list, + .same_set = mtype_same_set, +}; + +#endif /* __IP_SET_BITMAP_IP_GEN_H */ diff --git a/net/netfilter/ipset/ip_set_bitmap_ip.c b/net/netfilter/ipset/ip_set_bitmap_ip.c new file mode 100644 index 000000000..e4fa00abd --- /dev/null +++ b/net/netfilter/ipset/ip_set_bitmap_ip.c @@ -0,0 +1,387 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2000-2002 Joakim Axelsson + * Patrick Schaaf + * Copyright (C) 2003-2013 Jozsef Kadlecsik + */ + +/* Kernel module implementing an IP set type: the bitmap:ip type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Counter support added */ +/* 2 Comment support added */ +#define 
IPSET_TYPE_REV_MAX 3 /* skbinfo support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("bitmap:ip", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_bitmap:ip"); + +#define MTYPE bitmap_ip +#define HOST_MASK 32 + +/* Type structure */ +struct bitmap_ip { + unsigned long *members; /* the set members */ + u32 first_ip; /* host byte order, included in range */ + u32 last_ip; /* host byte order, included in range */ + u32 elements; /* number of max elements in the set */ + u32 hosts; /* number of hosts in a subnet */ + size_t memsize; /* members size */ + u8 netmask; /* subnet netmask */ + struct timer_list gc; /* garbage collection */ + struct ip_set *set; /* attached to this ip_set */ + unsigned char extensions[] /* data extensions */ + __aligned(__alignof__(u64)); +}; + +/* ADT structure for generic function args */ +struct bitmap_ip_adt_elem { + u16 id; +}; + +static u32 +ip_to_id(const struct bitmap_ip *m, u32 ip) +{ + return ((ip & ip_set_hostmask(m->netmask)) - m->first_ip) / m->hosts; +} + +/* Common functions */ + +static int +bitmap_ip_do_test(const struct bitmap_ip_adt_elem *e, + struct bitmap_ip *map, size_t dsize) +{ + return !!test_bit(e->id, map->members); +} + +static int +bitmap_ip_gc_test(u16 id, const struct bitmap_ip *map, size_t dsize) +{ + return !!test_bit(id, map->members); +} + +static int +bitmap_ip_do_add(const struct bitmap_ip_adt_elem *e, struct bitmap_ip *map, + u32 flags, size_t dsize) +{ + return !!test_bit(e->id, map->members); +} + +static int +bitmap_ip_do_del(const struct bitmap_ip_adt_elem *e, struct bitmap_ip *map) +{ + return !test_and_clear_bit(e->id, map->members); +} + +static int +bitmap_ip_do_list(struct sk_buff *skb, const struct bitmap_ip *map, u32 id, + size_t dsize) +{ + return nla_put_ipaddr4(skb, IPSET_ATTR_IP, + htonl(map->first_ip + id * map->hosts)); +} + +static int +bitmap_ip_do_head(struct sk_buff *skb, const struct bitmap_ip *map) +{ + return nla_put_ipaddr4(skb, IPSET_ATTR_IP, htonl(map->first_ip)) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP_TO, htonl(map->last_ip)) || + (map->netmask != 32 && + nla_put_u8(skb, IPSET_ATTR_NETMASK, map->netmask)); +} + +static int +bitmap_ip_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + struct bitmap_ip *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct bitmap_ip_adt_elem e = { .id = 0 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + u32 ip; + + ip = ntohl(ip4addr(skb, opt->flags & IPSET_DIM_ONE_SRC)); + if (ip < map->first_ip || ip > map->last_ip) + return -IPSET_ERR_BITMAP_RANGE; + + e.id = ip_to_id(map, ip); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +bitmap_ip_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct bitmap_ip *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + u32 ip = 0, ip_to = 0; + struct bitmap_ip_adt_elem e = { .id = 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret = 0; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP])) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (ip < map->first_ip || ip > map->last_ip) + return -IPSET_ERR_BITMAP_RANGE; + + if (adt == 
IPSET_TEST) { + e.id = ip_to_id(map, ip); + return adtfn(set, &e, &ext, &ext, flags); + } + + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) { + swap(ip, ip_to); + if (ip < map->first_ip) + return -IPSET_ERR_BITMAP_RANGE; + } + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } else { + ip_to = ip; + } + + if (ip_to > map->last_ip) + return -IPSET_ERR_BITMAP_RANGE; + + for (; !before(ip_to, ip); ip += map->hosts) { + e.id = ip_to_id(map, ip); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static bool +bitmap_ip_same_set(const struct ip_set *a, const struct ip_set *b) +{ + const struct bitmap_ip *x = a->data; + const struct bitmap_ip *y = b->data; + + return x->first_ip == y->first_ip && + x->last_ip == y->last_ip && + x->netmask == y->netmask && + a->timeout == b->timeout && + a->extensions == b->extensions; +} + +/* Plain variant */ + +struct bitmap_ip_elem { +}; + +#include "ip_set_bitmap_gen.h" + +/* Create bitmap:ip type of sets */ + +static bool +init_map_ip(struct ip_set *set, struct bitmap_ip *map, + u32 first_ip, u32 last_ip, + u32 elements, u32 hosts, u8 netmask) +{ + map->members = bitmap_zalloc(elements, GFP_KERNEL | __GFP_NOWARN); + if (!map->members) + return false; + map->first_ip = first_ip; + map->last_ip = last_ip; + map->elements = elements; + map->hosts = hosts; + map->netmask = netmask; + set->timeout = IPSET_NO_TIMEOUT; + + map->set = set; + set->data = map; + set->family = NFPROTO_IPV4; + + return true; +} + +static u32 +range_to_mask(u32 from, u32 to, u8 *bits) +{ + u32 mask = 0xFFFFFFFE; + + *bits = 32; + while (--(*bits) > 0 && mask && (to & mask) != from) + mask <<= 1; + + return mask; +} + +static int +bitmap_ip_create(struct net *net, struct ip_set *set, struct nlattr *tb[], + u32 flags) +{ + struct bitmap_ip *map; + u32 first_ip = 0, last_ip = 0, hosts; + u64 elements; + u8 netmask = 32; + int ret; + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &first_ip); + if (ret) + return ret; + + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &last_ip); + if (ret) + return ret; + if (first_ip > last_ip) + swap(first_ip, last_ip); + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr >= HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(first_ip, last_ip, cidr); + } else { + return -IPSET_ERR_PROTOCOL; + } + + if (tb[IPSET_ATTR_NETMASK]) { + netmask = nla_get_u8(tb[IPSET_ATTR_NETMASK]); + + if (netmask > HOST_MASK) + return -IPSET_ERR_INVALID_NETMASK; + + first_ip &= ip_set_hostmask(netmask); + last_ip |= ~ip_set_hostmask(netmask); + } + + if (netmask == 32) { + hosts = 1; + elements = (u64)last_ip - first_ip + 1; + } else { + u8 mask_bits; + u32 mask; + + mask = range_to_mask(first_ip, last_ip, &mask_bits); + + if ((!mask && (first_ip || last_ip != 0xFFFFFFFF)) || + netmask <= mask_bits) + return -IPSET_ERR_BITMAP_RANGE; + + pr_debug("mask_bits %u, netmask %u\n", mask_bits, netmask); + hosts = 2U << (32 - netmask - 1); + elements = 2UL << (netmask - mask_bits - 1); + } + if (elements > 
IPSET_BITMAP_MAX_RANGE + 1) + return -IPSET_ERR_BITMAP_RANGE_SIZE; + + pr_debug("hosts %u, elements %llu\n", + hosts, (unsigned long long)elements); + + set->dsize = ip_set_elem_len(set, tb, 0, 0); + map = ip_set_alloc(sizeof(*map) + elements * set->dsize); + if (!map) + return -ENOMEM; + + map->memsize = BITS_TO_LONGS(elements) * sizeof(unsigned long); + set->variant = &bitmap_ip; + if (!init_map_ip(set, map, first_ip, last_ip, + elements, hosts, netmask)) { + ip_set_free(map); + return -ENOMEM; + } + if (tb[IPSET_ATTR_TIMEOUT]) { + set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); + bitmap_ip_gc_init(set, bitmap_ip_gc); + } + return 0; +} + +static struct ip_set_type bitmap_ip_type __read_mostly = { + .name = "bitmap:ip", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_IPV4, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create = bitmap_ip_create, + .create_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_NETMASK] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +bitmap_ip_init(void) +{ + return ip_set_type_register(&bitmap_ip_type); +} + +static void __exit +bitmap_ip_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&bitmap_ip_type); +} + +module_init(bitmap_ip_init); +module_exit(bitmap_ip_fini); diff --git a/net/netfilter/ipset/ip_set_bitmap_ipmac.c b/net/netfilter/ipset/ip_set_bitmap_ipmac.c new file mode 100644 index 000000000..2c625e0f4 --- /dev/null +++ b/net/netfilter/ipset/ip_set_bitmap_ipmac.c @@ -0,0 +1,423 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2000-2002 Joakim Axelsson + * Patrick Schaaf + * Martin Josefsson + */ + +/* Kernel module implementing an IP set type: the bitmap:ip,mac type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Counter support added */ +/* 2 Comment support added */ +#define IPSET_TYPE_REV_MAX 3 /* skbinfo support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("bitmap:ip,mac", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_bitmap:ip,mac"); + +#define MTYPE bitmap_ipmac +#define HOST_MASK 32 +#define IP_SET_BITMAP_STORED_TIMEOUT + +enum { + MAC_UNSET, /* element is set, without MAC */ + MAC_FILLED, /* element is set with MAC */ +}; + +/* Type structure */ +struct bitmap_ipmac { + unsigned long *members; /* the set members */ + u32 first_ip; /* host byte order, included in range */ + u32 last_ip; /* host byte order, included in range */ + u32 elements; /* number of max elements in the set */ + size_t memsize; /* members size */ + struct timer_list gc; /* garbage 
collector */ + struct ip_set *set; /* attached to this ip_set */ + unsigned char extensions[] /* MAC + data extensions */ + __aligned(__alignof__(u64)); +}; + +/* ADT structure for generic function args */ +struct bitmap_ipmac_adt_elem { + unsigned char ether[ETH_ALEN] __aligned(2); + u16 id; + u16 add_mac; +}; + +struct bitmap_ipmac_elem { + unsigned char ether[ETH_ALEN]; + unsigned char filled; +} __aligned(__alignof__(u64)); + +static u32 +ip_to_id(const struct bitmap_ipmac *m, u32 ip) +{ + return ip - m->first_ip; +} + +#define get_elem(extensions, id, dsize) \ + (struct bitmap_ipmac_elem *)(extensions + (id) * (dsize)) + +#define get_const_elem(extensions, id, dsize) \ + (const struct bitmap_ipmac_elem *)(extensions + (id) * (dsize)) + +/* Common functions */ + +static int +bitmap_ipmac_do_test(const struct bitmap_ipmac_adt_elem *e, + const struct bitmap_ipmac *map, size_t dsize) +{ + const struct bitmap_ipmac_elem *elem; + + if (!test_bit(e->id, map->members)) + return 0; + elem = get_const_elem(map->extensions, e->id, dsize); + if (e->add_mac && elem->filled == MAC_FILLED) + return ether_addr_equal(e->ether, elem->ether); + /* Trigger kernel to fill out the ethernet address */ + return -EAGAIN; +} + +static int +bitmap_ipmac_gc_test(u16 id, const struct bitmap_ipmac *map, size_t dsize) +{ + const struct bitmap_ipmac_elem *elem; + + if (!test_bit(id, map->members)) + return 0; + elem = get_const_elem(map->extensions, id, dsize); + /* Timer not started for the incomplete elements */ + return elem->filled == MAC_FILLED; +} + +static int +bitmap_ipmac_is_filled(const struct bitmap_ipmac_elem *elem) +{ + return elem->filled == MAC_FILLED; +} + +static int +bitmap_ipmac_add_timeout(unsigned long *timeout, + const struct bitmap_ipmac_adt_elem *e, + const struct ip_set_ext *ext, struct ip_set *set, + struct bitmap_ipmac *map, int mode) +{ + u32 t = ext->timeout; + + if (mode == IPSET_ADD_START_STORED_TIMEOUT) { + if (t == set->timeout) + /* Timeout was not specified, get stored one */ + t = *timeout; + ip_set_timeout_set(timeout, t); + } else { + /* If MAC is unset yet, we store plain timeout value + * because the timer is not activated yet + * and we can reuse it later when MAC is filled out, + * possibly by the kernel + */ + if (e->add_mac) + ip_set_timeout_set(timeout, t); + else + *timeout = t; + } + return 0; +} + +static int +bitmap_ipmac_do_add(const struct bitmap_ipmac_adt_elem *e, + struct bitmap_ipmac *map, u32 flags, size_t dsize) +{ + struct bitmap_ipmac_elem *elem; + + elem = get_elem(map->extensions, e->id, dsize); + if (test_bit(e->id, map->members)) { + if (elem->filled == MAC_FILLED) { + if (e->add_mac && + (flags & IPSET_FLAG_EXIST) && + !ether_addr_equal(e->ether, elem->ether)) { + /* memcpy isn't atomic */ + clear_bit(e->id, map->members); + smp_mb__after_atomic(); + ether_addr_copy(elem->ether, e->ether); + } + return IPSET_ADD_FAILED; + } else if (!e->add_mac) + /* Already added without ethernet address */ + return IPSET_ADD_FAILED; + /* Fill the MAC address and trigger the timer activation */ + clear_bit(e->id, map->members); + smp_mb__after_atomic(); + ether_addr_copy(elem->ether, e->ether); + elem->filled = MAC_FILLED; + return IPSET_ADD_START_STORED_TIMEOUT; + } else if (e->add_mac) { + /* We can store MAC too */ + ether_addr_copy(elem->ether, e->ether); + elem->filled = MAC_FILLED; + return 0; + } + elem->filled = MAC_UNSET; + /* MAC is not stored yet, don't start timer */ + return IPSET_ADD_STORE_PLAIN_TIMEOUT; +} + +static int +bitmap_ipmac_do_del(const struct 
bitmap_ipmac_adt_elem *e, + struct bitmap_ipmac *map) +{ + return !test_and_clear_bit(e->id, map->members); +} + +static int +bitmap_ipmac_do_list(struct sk_buff *skb, const struct bitmap_ipmac *map, + u32 id, size_t dsize) +{ + const struct bitmap_ipmac_elem *elem = + get_const_elem(map->extensions, id, dsize); + + return nla_put_ipaddr4(skb, IPSET_ATTR_IP, + htonl(map->first_ip + id)) || + (elem->filled == MAC_FILLED && + nla_put(skb, IPSET_ATTR_ETHER, ETH_ALEN, elem->ether)); +} + +static int +bitmap_ipmac_do_head(struct sk_buff *skb, const struct bitmap_ipmac *map) +{ + return nla_put_ipaddr4(skb, IPSET_ATTR_IP, htonl(map->first_ip)) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP_TO, htonl(map->last_ip)); +} + +static int +bitmap_ipmac_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + struct bitmap_ipmac *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct bitmap_ipmac_adt_elem e = { .id = 0, .add_mac = 1 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + u32 ip; + + ip = ntohl(ip4addr(skb, opt->flags & IPSET_DIM_ONE_SRC)); + if (ip < map->first_ip || ip > map->last_ip) + return -IPSET_ERR_BITMAP_RANGE; + + /* Backward compatibility: we don't check the second flag */ + if (skb_mac_header(skb) < skb->head || + (skb_mac_header(skb) + ETH_HLEN) > skb->data) + return -EINVAL; + + e.id = ip_to_id(map, ip); + + if (opt->flags & IPSET_DIM_TWO_SRC) + ether_addr_copy(e.ether, eth_hdr(skb)->h_source); + else + ether_addr_copy(e.ether, eth_hdr(skb)->h_dest); + + if (is_zero_ether_addr(e.ether)) + return -EINVAL; + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +bitmap_ipmac_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct bitmap_ipmac *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct bitmap_ipmac_adt_elem e = { .id = 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0; + int ret = 0; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP])) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (ip < map->first_ip || ip > map->last_ip) + return -IPSET_ERR_BITMAP_RANGE; + + e.id = ip_to_id(map, ip); + if (tb[IPSET_ATTR_ETHER]) { + if (nla_len(tb[IPSET_ATTR_ETHER]) != ETH_ALEN) + return -IPSET_ERR_PROTOCOL; + memcpy(e.ether, nla_data(tb[IPSET_ATTR_ETHER]), ETH_ALEN); + e.add_mac = 1; + } + ret = adtfn(set, &e, &ext, &ext, flags); + + return ip_set_eexist(ret, flags) ? 
0 : ret; +} + +static bool +bitmap_ipmac_same_set(const struct ip_set *a, const struct ip_set *b) +{ + const struct bitmap_ipmac *x = a->data; + const struct bitmap_ipmac *y = b->data; + + return x->first_ip == y->first_ip && + x->last_ip == y->last_ip && + a->timeout == b->timeout && + a->extensions == b->extensions; +} + +/* Plain variant */ + +#include "ip_set_bitmap_gen.h" + +/* Create bitmap:ip,mac type of sets */ + +static bool +init_map_ipmac(struct ip_set *set, struct bitmap_ipmac *map, + u32 first_ip, u32 last_ip, u32 elements) +{ + map->members = bitmap_zalloc(elements, GFP_KERNEL | __GFP_NOWARN); + if (!map->members) + return false; + map->first_ip = first_ip; + map->last_ip = last_ip; + map->elements = elements; + set->timeout = IPSET_NO_TIMEOUT; + + map->set = set; + set->data = map; + set->family = NFPROTO_IPV4; + + return true; +} + +static int +bitmap_ipmac_create(struct net *net, struct ip_set *set, struct nlattr *tb[], + u32 flags) +{ + u32 first_ip = 0, last_ip = 0; + u64 elements; + struct bitmap_ipmac *map; + int ret; + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &first_ip); + if (ret) + return ret; + + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &last_ip); + if (ret) + return ret; + if (first_ip > last_ip) + swap(first_ip, last_ip); + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr >= HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(first_ip, last_ip, cidr); + } else { + return -IPSET_ERR_PROTOCOL; + } + + elements = (u64)last_ip - first_ip + 1; + + if (elements > IPSET_BITMAP_MAX_RANGE + 1) + return -IPSET_ERR_BITMAP_RANGE_SIZE; + + set->dsize = ip_set_elem_len(set, tb, + sizeof(struct bitmap_ipmac_elem), + __alignof__(struct bitmap_ipmac_elem)); + map = ip_set_alloc(sizeof(*map) + elements * set->dsize); + if (!map) + return -ENOMEM; + + map->memsize = BITS_TO_LONGS(elements) * sizeof(unsigned long); + set->variant = &bitmap_ipmac; + if (!init_map_ipmac(set, map, first_ip, last_ip, elements)) { + ip_set_free(map); + return -ENOMEM; + } + if (tb[IPSET_ATTR_TIMEOUT]) { + set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); + bitmap_ipmac_gc_init(set, bitmap_ipmac_gc); + } + return 0; +} + +static struct ip_set_type bitmap_ipmac_type = { + .name = "bitmap:ip,mac", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_MAC, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_IPV4, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create = bitmap_ipmac_create, + .create_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_ETHER] = { .type = NLA_BINARY, + .len = ETH_ALEN }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me 
= THIS_MODULE, +}; + +static int __init +bitmap_ipmac_init(void) +{ + return ip_set_type_register(&bitmap_ipmac_type); +} + +static void __exit +bitmap_ipmac_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&bitmap_ipmac_type); +} + +module_init(bitmap_ipmac_init); +module_exit(bitmap_ipmac_fini); diff --git a/net/netfilter/ipset/ip_set_bitmap_port.c b/net/netfilter/ipset/ip_set_bitmap_port.c new file mode 100644 index 000000000..7138e080d --- /dev/null +++ b/net/netfilter/ipset/ip_set_bitmap_port.c @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the bitmap:port type */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Counter support added */ +/* 2 Comment support added */ +#define IPSET_TYPE_REV_MAX 3 /* skbinfo support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("bitmap:port", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_bitmap:port"); + +#define MTYPE bitmap_port + +/* Type structure */ +struct bitmap_port { + unsigned long *members; /* the set members */ + u16 first_port; /* host byte order, included in range */ + u16 last_port; /* host byte order, included in range */ + u32 elements; /* number of max elements in the set */ + size_t memsize; /* members size */ + struct timer_list gc; /* garbage collection */ + struct ip_set *set; /* attached to this ip_set */ + unsigned char extensions[] /* data extensions */ + __aligned(__alignof__(u64)); +}; + +/* ADT structure for generic function args */ +struct bitmap_port_adt_elem { + u16 id; +}; + +static u16 +port_to_id(const struct bitmap_port *m, u16 port) +{ + return port - m->first_port; +} + +/* Common functions */ + +static int +bitmap_port_do_test(const struct bitmap_port_adt_elem *e, + const struct bitmap_port *map, size_t dsize) +{ + return !!test_bit(e->id, map->members); +} + +static int +bitmap_port_gc_test(u16 id, const struct bitmap_port *map, size_t dsize) +{ + return !!test_bit(id, map->members); +} + +static int +bitmap_port_do_add(const struct bitmap_port_adt_elem *e, + struct bitmap_port *map, u32 flags, size_t dsize) +{ + return !!test_bit(e->id, map->members); +} + +static int +bitmap_port_do_del(const struct bitmap_port_adt_elem *e, + struct bitmap_port *map) +{ + return !test_and_clear_bit(e->id, map->members); +} + +static int +bitmap_port_do_list(struct sk_buff *skb, const struct bitmap_port *map, u32 id, + size_t dsize) +{ + return nla_put_net16(skb, IPSET_ATTR_PORT, + htons(map->first_port + id)); +} + +static int +bitmap_port_do_head(struct sk_buff *skb, const struct bitmap_port *map) +{ + return nla_put_net16(skb, IPSET_ATTR_PORT, htons(map->first_port)) || + nla_put_net16(skb, IPSET_ATTR_PORT_TO, htons(map->last_port)); +} + +static bool +ip_set_get_ip_port(const struct sk_buff *skb, u8 pf, bool src, __be16 *port) +{ + bool ret; + u8 proto; + + switch (pf) { + case NFPROTO_IPV4: + ret = ip_set_get_ip4_port(skb, src, port, &proto); + break; + case NFPROTO_IPV6: + ret = ip_set_get_ip6_port(skb, src, port, &proto); + break; + default: + return false; + } + if (!ret) + return ret; + switch (proto) { + case IPPROTO_TCP: + case IPPROTO_UDP: + return true; + default: + return false; + } +} + +static int +bitmap_port_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct 
ip_set_adt_opt *opt) +{ + struct bitmap_port *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct bitmap_port_adt_elem e = { .id = 0 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + __be16 __port; + u16 port = 0; + + if (!ip_set_get_ip_port(skb, opt->family, + opt->flags & IPSET_DIM_ONE_SRC, &__port)) + return -EINVAL; + + port = ntohs(__port); + + if (port < map->first_port || port > map->last_port) + return -IPSET_ERR_BITMAP_RANGE; + + e.id = port_to_id(map, port); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +bitmap_port_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct bitmap_port *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct bitmap_port_adt_elem e = { .id = 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port; /* wraparound */ + u16 port_to; + int ret = 0; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO))) + return -IPSET_ERR_PROTOCOL; + + port = ip_set_get_h16(tb[IPSET_ATTR_PORT]); + if (port < map->first_port || port > map->last_port) + return -IPSET_ERR_BITMAP_RANGE; + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (adt == IPSET_TEST) { + e.id = port_to_id(map, port); + return adtfn(set, &e, &ext, &ext, flags); + } + + if (tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) { + swap(port, port_to); + if (port < map->first_port) + return -IPSET_ERR_BITMAP_RANGE; + } + } else { + port_to = port; + } + + if (port_to > map->last_port) + return -IPSET_ERR_BITMAP_RANGE; + + for (; port <= port_to; port++) { + e.id = port_to_id(map, port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static bool +bitmap_port_same_set(const struct ip_set *a, const struct ip_set *b) +{ + const struct bitmap_port *x = a->data; + const struct bitmap_port *y = b->data; + + return x->first_port == y->first_port && + x->last_port == y->last_port && + a->timeout == b->timeout && + a->extensions == b->extensions; +} + +/* Plain variant */ + +struct bitmap_port_elem { +}; + +#include "ip_set_bitmap_gen.h" + +/* Create bitmap:ip type of sets */ + +static bool +init_map_port(struct ip_set *set, struct bitmap_port *map, + u16 first_port, u16 last_port) +{ + map->members = bitmap_zalloc(map->elements, GFP_KERNEL | __GFP_NOWARN); + if (!map->members) + return false; + map->first_port = first_port; + map->last_port = last_port; + set->timeout = IPSET_NO_TIMEOUT; + + map->set = set; + set->data = map; + set->family = NFPROTO_UNSPEC; + + return true; +} + +static int +bitmap_port_create(struct net *net, struct ip_set *set, struct nlattr *tb[], + u32 flags) +{ + struct bitmap_port *map; + u16 first_port, last_port; + u32 elements; + + if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + first_port = ip_set_get_h16(tb[IPSET_ATTR_PORT]); + last_port = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (first_port > last_port) + swap(first_port, last_port); + + elements = last_port - first_port + 1; + set->dsize = ip_set_elem_len(set, tb, 0, 0); + map = 
ip_set_alloc(sizeof(*map) + elements * set->dsize); + if (!map) + return -ENOMEM; + + map->elements = elements; + map->memsize = BITS_TO_LONGS(elements) * sizeof(unsigned long); + set->variant = &bitmap_port; + if (!init_map_port(set, map, first_port, last_port)) { + ip_set_free(map); + return -ENOMEM; + } + if (tb[IPSET_ATTR_TIMEOUT]) { + set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); + bitmap_port_gc_init(set, bitmap_port_gc); + } + return 0; +} + +static struct ip_set_type bitmap_port_type = { + .name = "bitmap:port", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_PORT, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create = bitmap_port_create, + .create_policy = { + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +bitmap_port_init(void) +{ + return ip_set_type_register(&bitmap_port_type); +} + +static void __exit +bitmap_port_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&bitmap_port_type); +} + +module_init(bitmap_port_init); +module_exit(bitmap_port_fini); diff --git a/net/netfilter/ipset/ip_set_core.c b/net/netfilter/ipset/ip_set_core.c new file mode 100644 index 000000000..d47dfdcb8 --- /dev/null +++ b/net/netfilter/ipset/ip_set_core.c @@ -0,0 +1,2422 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2000-2002 Joakim Axelsson + * Patrick Schaaf + * Copyright (C) 2003-2013 Jozsef Kadlecsik + */ + +/* Kernel module for IP set management */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static LIST_HEAD(ip_set_type_list); /* all registered set types */ +static DEFINE_MUTEX(ip_set_type_mutex); /* protects ip_set_type_list */ +static DEFINE_RWLOCK(ip_set_ref_lock); /* protects the set refs */ + +struct ip_set_net { + struct ip_set * __rcu *ip_set_list; /* all individual sets */ + ip_set_id_t ip_set_max; /* max number of sets */ + bool is_deleted; /* deleted by ip_set_net_exit */ + bool is_destroyed; /* all sets are destroyed */ +}; + +static unsigned int ip_set_net_id __read_mostly; + +static struct ip_set_net *ip_set_pernet(struct net *net) +{ + return net_generic(net, ip_set_net_id); +} + +#define IP_SET_INC 64 +#define STRNCMP(a, b) (strncmp(a, b, IPSET_MAXNAMELEN) == 0) + +static unsigned int max_sets; + +module_param(max_sets, int, 0600); +MODULE_PARM_DESC(max_sets, "maximal number of sets"); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +MODULE_DESCRIPTION("core IP set support"); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET); + +/* When the nfnl mutex or ip_set_ref_lock is held: */ +#define ip_set_dereference(p) \ + rcu_dereference_protected(p, \ + lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \ + lockdep_is_held(&ip_set_ref_lock)) +#define 
ip_set(inst, id) \ + ip_set_dereference((inst)->ip_set_list)[id] +#define ip_set_ref_netlink(inst,id) \ + rcu_dereference_raw((inst)->ip_set_list)[id] +#define ip_set_dereference_nfnl(p) \ + rcu_dereference_check(p, lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET)) + +/* The set types are implemented in modules and registered set types + * can be found in ip_set_type_list. Adding/deleting types is + * serialized by ip_set_type_mutex. + */ + +static void +ip_set_type_lock(void) +{ + mutex_lock(&ip_set_type_mutex); +} + +static void +ip_set_type_unlock(void) +{ + mutex_unlock(&ip_set_type_mutex); +} + +/* Register and deregister settype */ + +static struct ip_set_type * +find_set_type(const char *name, u8 family, u8 revision) +{ + struct ip_set_type *type; + + list_for_each_entry_rcu(type, &ip_set_type_list, list, + lockdep_is_held(&ip_set_type_mutex)) + if (STRNCMP(type->name, name) && + (type->family == family || + type->family == NFPROTO_UNSPEC) && + revision >= type->revision_min && + revision <= type->revision_max) + return type; + return NULL; +} + +/* Unlock, try to load a set type module and lock again */ +static bool +load_settype(const char *name) +{ + nfnl_unlock(NFNL_SUBSYS_IPSET); + pr_debug("try to load ip_set_%s\n", name); + if (request_module("ip_set_%s", name) < 0) { + pr_warn("Can't find ip_set type %s\n", name); + nfnl_lock(NFNL_SUBSYS_IPSET); + return false; + } + nfnl_lock(NFNL_SUBSYS_IPSET); + return true; +} + +/* Find a set type and reference it */ +#define find_set_type_get(name, family, revision, found) \ + __find_set_type_get(name, family, revision, found, false) + +static int +__find_set_type_get(const char *name, u8 family, u8 revision, + struct ip_set_type **found, bool retry) +{ + struct ip_set_type *type; + int err; + + if (retry && !load_settype(name)) + return -IPSET_ERR_FIND_TYPE; + + rcu_read_lock(); + *found = find_set_type(name, family, revision); + if (*found) { + err = !try_module_get((*found)->me) ? -EFAULT : 0; + goto unlock; + } + /* Make sure the type is already loaded + * but we don't support the revision + */ + list_for_each_entry_rcu(type, &ip_set_type_list, list) + if (STRNCMP(type->name, name)) { + err = -IPSET_ERR_FIND_TYPE; + goto unlock; + } + rcu_read_unlock(); + + return retry ? -IPSET_ERR_FIND_TYPE : + __find_set_type_get(name, family, revision, found, true); + +unlock: + rcu_read_unlock(); + return err; +} + +/* Find a given set type by name and family. + * If we succeeded, the supported minimal and maximum revisions are + * filled out. + */ +#define find_set_type_minmax(name, family, min, max) \ + __find_set_type_minmax(name, family, min, max, false) + +static int +__find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max, + bool retry) +{ + struct ip_set_type *type; + bool found = false; + + if (retry && !load_settype(name)) + return -IPSET_ERR_FIND_TYPE; + + *min = 255; *max = 0; + rcu_read_lock(); + list_for_each_entry_rcu(type, &ip_set_type_list, list) + if (STRNCMP(type->name, name) && + (type->family == family || + type->family == NFPROTO_UNSPEC)) { + found = true; + if (type->revision_min < *min) + *min = type->revision_min; + if (type->revision_max > *max) + *max = type->revision_max; + } + rcu_read_unlock(); + if (found) + return 0; + + return retry ? -IPSET_ERR_FIND_TYPE : + __find_set_type_minmax(name, family, min, max, true); +} + +#define family_name(f) ((f) == NFPROTO_IPV4 ? "inet" : \ + (f) == NFPROTO_IPV6 ? "inet6" : "any") + +/* Register a set type structure. 
The type is identified by + * the unique triple of name, family and revision. + */ +int +ip_set_type_register(struct ip_set_type *type) +{ + int ret = 0; + + if (type->protocol != IPSET_PROTOCOL) { + pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n", + type->name, family_name(type->family), + type->revision_min, type->revision_max, + type->protocol, IPSET_PROTOCOL); + return -EINVAL; + } + + ip_set_type_lock(); + if (find_set_type(type->name, type->family, type->revision_min)) { + /* Duplicate! */ + pr_warn("ip_set type %s, family %s with revision min %u already registered!\n", + type->name, family_name(type->family), + type->revision_min); + ip_set_type_unlock(); + return -EINVAL; + } + list_add_rcu(&type->list, &ip_set_type_list); + pr_debug("type %s, family %s, revision %u:%u registered.\n", + type->name, family_name(type->family), + type->revision_min, type->revision_max); + ip_set_type_unlock(); + + return ret; +} +EXPORT_SYMBOL_GPL(ip_set_type_register); + +/* Unregister a set type. There's a small race with ip_set_create */ +void +ip_set_type_unregister(struct ip_set_type *type) +{ + ip_set_type_lock(); + if (!find_set_type(type->name, type->family, type->revision_min)) { + pr_warn("ip_set type %s, family %s with revision min %u not registered\n", + type->name, family_name(type->family), + type->revision_min); + ip_set_type_unlock(); + return; + } + list_del_rcu(&type->list); + pr_debug("type %s, family %s with revision min %u unregistered.\n", + type->name, family_name(type->family), type->revision_min); + ip_set_type_unlock(); + + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(ip_set_type_unregister); + +/* Utility functions */ +void * +ip_set_alloc(size_t size) +{ + return kvzalloc(size, GFP_KERNEL_ACCOUNT); +} +EXPORT_SYMBOL_GPL(ip_set_alloc); + +void +ip_set_free(void *members) +{ + pr_debug("%p: free with %s\n", members, + is_vmalloc_addr(members) ? 
"vfree" : "kfree"); + kvfree(members); +} +EXPORT_SYMBOL_GPL(ip_set_free); + +static bool +flag_nested(const struct nlattr *nla) +{ + return nla->nla_type & NLA_F_NESTED; +} + +static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = { + [IPSET_ATTR_IPADDR_IPV4] = { .type = NLA_U32 }, + [IPSET_ATTR_IPADDR_IPV6] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), +}; + +int +ip_set_get_ipaddr4(struct nlattr *nla, __be32 *ipaddr) +{ + struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1]; + + if (unlikely(!flag_nested(nla))) + return -IPSET_ERR_PROTOCOL; + if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, + ipaddr_policy, NULL)) + return -IPSET_ERR_PROTOCOL; + if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4))) + return -IPSET_ERR_PROTOCOL; + + *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]); + return 0; +} +EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4); + +int +ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr) +{ + struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1]; + + if (unlikely(!flag_nested(nla))) + return -IPSET_ERR_PROTOCOL; + + if (nla_parse_nested(tb, IPSET_ATTR_IPADDR_MAX, nla, + ipaddr_policy, NULL)) + return -IPSET_ERR_PROTOCOL; + if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6))) + return -IPSET_ERR_PROTOCOL; + + memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]), + sizeof(struct in6_addr)); + return 0; +} +EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6); + +static u32 +ip_set_timeout_get(const unsigned long *timeout) +{ + u32 t; + + if (*timeout == IPSET_ELEM_PERMANENT) + return 0; + + t = jiffies_to_msecs(*timeout - jiffies) / MSEC_PER_SEC; + /* Zero value in userspace means no timeout */ + return t == 0 ? 1 : t; +} + +static char * +ip_set_comment_uget(struct nlattr *tb) +{ + return nla_data(tb); +} + +/* Called from uadd only, protected by the set spinlock. + * The kadt functions don't use the comment extensions in any way. + */ +void +ip_set_init_comment(struct ip_set *set, struct ip_set_comment *comment, + const struct ip_set_ext *ext) +{ + struct ip_set_comment_rcu *c = rcu_dereference_protected(comment->c, 1); + size_t len = ext->comment ? strlen(ext->comment) : 0; + + if (unlikely(c)) { + set->ext_size -= sizeof(*c) + strlen(c->str) + 1; + kfree_rcu(c, rcu); + rcu_assign_pointer(comment->c, NULL); + } + if (!len) + return; + if (unlikely(len > IPSET_MAX_COMMENT_SIZE)) + len = IPSET_MAX_COMMENT_SIZE; + c = kmalloc(sizeof(*c) + len + 1, GFP_ATOMIC); + if (unlikely(!c)) + return; + strscpy(c->str, ext->comment, len + 1); + set->ext_size += sizeof(*c) + strlen(c->str) + 1; + rcu_assign_pointer(comment->c, c); +} +EXPORT_SYMBOL_GPL(ip_set_init_comment); + +/* Used only when dumping a set, protected by rcu_read_lock() */ +static int +ip_set_put_comment(struct sk_buff *skb, const struct ip_set_comment *comment) +{ + struct ip_set_comment_rcu *c = rcu_dereference(comment->c); + + if (!c) + return 0; + return nla_put_string(skb, IPSET_ATTR_COMMENT, c->str); +} + +/* Called from uadd/udel, flush or the garbage collectors protected + * by the set spinlock. + * Called when the set is destroyed and when there can't be any user + * of the set data anymore. 
+ */ +static void +ip_set_comment_free(struct ip_set *set, void *ptr) +{ + struct ip_set_comment *comment = ptr; + struct ip_set_comment_rcu *c; + + c = rcu_dereference_protected(comment->c, 1); + if (unlikely(!c)) + return; + set->ext_size -= sizeof(*c) + strlen(c->str) + 1; + kfree_rcu(c, rcu); + rcu_assign_pointer(comment->c, NULL); +} + +typedef void (*destroyer)(struct ip_set *, void *); +/* ipset data extension types, in size order */ + +const struct ip_set_ext_type ip_set_extensions[] = { + [IPSET_EXT_ID_COUNTER] = { + .type = IPSET_EXT_COUNTER, + .flag = IPSET_FLAG_WITH_COUNTERS, + .len = sizeof(struct ip_set_counter), + .align = __alignof__(struct ip_set_counter), + }, + [IPSET_EXT_ID_TIMEOUT] = { + .type = IPSET_EXT_TIMEOUT, + .len = sizeof(unsigned long), + .align = __alignof__(unsigned long), + }, + [IPSET_EXT_ID_SKBINFO] = { + .type = IPSET_EXT_SKBINFO, + .flag = IPSET_FLAG_WITH_SKBINFO, + .len = sizeof(struct ip_set_skbinfo), + .align = __alignof__(struct ip_set_skbinfo), + }, + [IPSET_EXT_ID_COMMENT] = { + .type = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY, + .flag = IPSET_FLAG_WITH_COMMENT, + .len = sizeof(struct ip_set_comment), + .align = __alignof__(struct ip_set_comment), + .destroy = ip_set_comment_free, + }, +}; +EXPORT_SYMBOL_GPL(ip_set_extensions); + +static bool +add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[]) +{ + return ip_set_extensions[id].flag ? + (flags & ip_set_extensions[id].flag) : + !!tb[IPSET_ATTR_TIMEOUT]; +} + +size_t +ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len, + size_t align) +{ + enum ip_set_ext_id id; + u32 cadt_flags = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) + cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + if (cadt_flags & IPSET_FLAG_WITH_FORCEADD) + set->flags |= IPSET_CREATE_FLAG_FORCEADD; + if (!align) + align = 1; + for (id = 0; id < IPSET_EXT_ID_MAX; id++) { + if (!add_extension(id, cadt_flags, tb)) + continue; + if (align < ip_set_extensions[id].align) + align = ip_set_extensions[id].align; + len = ALIGN(len, ip_set_extensions[id].align); + set->offset[id] = len; + set->extensions |= ip_set_extensions[id].type; + len += ip_set_extensions[id].len; + } + return ALIGN(len, align); +} +EXPORT_SYMBOL_GPL(ip_set_elem_len); + +int +ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[], + struct ip_set_ext *ext) +{ + u64 fullmark; + + if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE))) + return -IPSET_ERR_PROTOCOL; + + if (tb[IPSET_ATTR_TIMEOUT]) { + if (!SET_WITH_TIMEOUT(set)) + return -IPSET_ERR_TIMEOUT; + ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); + } + if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) { + if (!SET_WITH_COUNTER(set)) + return -IPSET_ERR_COUNTER; + if (tb[IPSET_ATTR_BYTES]) + ext->bytes = be64_to_cpu(nla_get_be64( + tb[IPSET_ATTR_BYTES])); + if (tb[IPSET_ATTR_PACKETS]) + ext->packets = be64_to_cpu(nla_get_be64( + tb[IPSET_ATTR_PACKETS])); + } + if (tb[IPSET_ATTR_COMMENT]) { + if (!SET_WITH_COMMENT(set)) + return -IPSET_ERR_COMMENT; + ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]); + } + if (tb[IPSET_ATTR_SKBMARK]) { + if (!SET_WITH_SKBINFO(set)) + return -IPSET_ERR_SKBINFO; + fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK])); + ext->skbinfo.skbmark = 
fullmark >> 32; + ext->skbinfo.skbmarkmask = fullmark & 0xffffffff; + } + if (tb[IPSET_ATTR_SKBPRIO]) { + if (!SET_WITH_SKBINFO(set)) + return -IPSET_ERR_SKBINFO; + ext->skbinfo.skbprio = + be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO])); + } + if (tb[IPSET_ATTR_SKBQUEUE]) { + if (!SET_WITH_SKBINFO(set)) + return -IPSET_ERR_SKBINFO; + ext->skbinfo.skbqueue = + be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE])); + } + return 0; +} +EXPORT_SYMBOL_GPL(ip_set_get_extensions); + +static u64 +ip_set_get_bytes(const struct ip_set_counter *counter) +{ + return (u64)atomic64_read(&(counter)->bytes); +} + +static u64 +ip_set_get_packets(const struct ip_set_counter *counter) +{ + return (u64)atomic64_read(&(counter)->packets); +} + +static bool +ip_set_put_counter(struct sk_buff *skb, const struct ip_set_counter *counter) +{ + return nla_put_net64(skb, IPSET_ATTR_BYTES, + cpu_to_be64(ip_set_get_bytes(counter)), + IPSET_ATTR_PAD) || + nla_put_net64(skb, IPSET_ATTR_PACKETS, + cpu_to_be64(ip_set_get_packets(counter)), + IPSET_ATTR_PAD); +} + +static bool +ip_set_put_skbinfo(struct sk_buff *skb, const struct ip_set_skbinfo *skbinfo) +{ + /* Send nonzero parameters only */ + return ((skbinfo->skbmark || skbinfo->skbmarkmask) && + nla_put_net64(skb, IPSET_ATTR_SKBMARK, + cpu_to_be64((u64)skbinfo->skbmark << 32 | + skbinfo->skbmarkmask), + IPSET_ATTR_PAD)) || + (skbinfo->skbprio && + nla_put_net32(skb, IPSET_ATTR_SKBPRIO, + cpu_to_be32(skbinfo->skbprio))) || + (skbinfo->skbqueue && + nla_put_net16(skb, IPSET_ATTR_SKBQUEUE, + cpu_to_be16(skbinfo->skbqueue))); +} + +int +ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set, + const void *e, bool active) +{ + if (SET_WITH_TIMEOUT(set)) { + unsigned long *timeout = ext_timeout(e, set); + + if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT, + htonl(active ? 
ip_set_timeout_get(timeout) + : *timeout))) + return -EMSGSIZE; + } + if (SET_WITH_COUNTER(set) && + ip_set_put_counter(skb, ext_counter(e, set))) + return -EMSGSIZE; + if (SET_WITH_COMMENT(set) && + ip_set_put_comment(skb, ext_comment(e, set))) + return -EMSGSIZE; + if (SET_WITH_SKBINFO(set) && + ip_set_put_skbinfo(skb, ext_skbinfo(e, set))) + return -EMSGSIZE; + return 0; +} +EXPORT_SYMBOL_GPL(ip_set_put_extensions); + +static bool +ip_set_match_counter(u64 counter, u64 match, u8 op) +{ + switch (op) { + case IPSET_COUNTER_NONE: + return true; + case IPSET_COUNTER_EQ: + return counter == match; + case IPSET_COUNTER_NE: + return counter != match; + case IPSET_COUNTER_LT: + return counter < match; + case IPSET_COUNTER_GT: + return counter > match; + } + return false; +} + +static void +ip_set_add_bytes(u64 bytes, struct ip_set_counter *counter) +{ + atomic64_add((long long)bytes, &(counter)->bytes); +} + +static void +ip_set_add_packets(u64 packets, struct ip_set_counter *counter) +{ + atomic64_add((long long)packets, &(counter)->packets); +} + +static void +ip_set_update_counter(struct ip_set_counter *counter, + const struct ip_set_ext *ext, u32 flags) +{ + if (ext->packets != ULLONG_MAX && + !(flags & IPSET_FLAG_SKIP_COUNTER_UPDATE)) { + ip_set_add_bytes(ext->bytes, counter); + ip_set_add_packets(ext->packets, counter); + } +} + +static void +ip_set_get_skbinfo(struct ip_set_skbinfo *skbinfo, + const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + mext->skbinfo = *skbinfo; +} + +bool +ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags, void *data) +{ + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(data, set))) + return false; + if (SET_WITH_COUNTER(set)) { + struct ip_set_counter *counter = ext_counter(data, set); + + ip_set_update_counter(counter, ext, flags); + + if (flags & IPSET_FLAG_MATCH_COUNTERS && + !(ip_set_match_counter(ip_set_get_packets(counter), + mext->packets, mext->packets_op) && + ip_set_match_counter(ip_set_get_bytes(counter), + mext->bytes, mext->bytes_op))) + return false; + } + if (SET_WITH_SKBINFO(set)) + ip_set_get_skbinfo(ext_skbinfo(data, set), + ext, mext, flags); + return true; +} +EXPORT_SYMBOL_GPL(ip_set_match_extensions); + +/* Creating/destroying/renaming/swapping affect the existence and + * the properties of a set. All of these can be executed from userspace + * only and serialized by the nfnl mutex indirectly from nfnetlink. + * + * Sets are identified by their index in ip_set_list and the index + * is used by the external references (set/SET netfilter modules). + * + * The set behind an index may change by swapping only, from userspace. + */ + +static void +__ip_set_get(struct ip_set *set) +{ + write_lock_bh(&ip_set_ref_lock); + set->ref++; + write_unlock_bh(&ip_set_ref_lock); +} + +static void +__ip_set_put(struct ip_set *set) +{ + write_lock_bh(&ip_set_ref_lock); + BUG_ON(set->ref == 0); + set->ref--; + write_unlock_bh(&ip_set_ref_lock); +} + +/* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need + * a separate reference counter + */ +static void +__ip_set_get_netlink(struct ip_set *set) +{ + write_lock_bh(&ip_set_ref_lock); + set->ref_netlink++; + write_unlock_bh(&ip_set_ref_lock); +} + +static void +__ip_set_put_netlink(struct ip_set *set) +{ + write_lock_bh(&ip_set_ref_lock); + BUG_ON(set->ref_netlink == 0); + set->ref_netlink--; + write_unlock_bh(&ip_set_ref_lock); +} + +/* Add, del and test set entries from kernel. 
+ * + * The set behind the index must exist and must be referenced + * so it can't be destroyed (or changed) under our foot. + */ + +static struct ip_set * +ip_set_rcu_get(struct net *net, ip_set_id_t index) +{ + struct ip_set_net *inst = ip_set_pernet(net); + + /* ip_set_list and the set pointer need to be protected */ + return ip_set_dereference_nfnl(inst->ip_set_list)[index]; +} + +static inline void +ip_set_lock(struct ip_set *set) +{ + if (!set->variant->region_lock) + spin_lock_bh(&set->lock); +} + +static inline void +ip_set_unlock(struct ip_set *set) +{ + if (!set->variant->region_lock) + spin_unlock_bh(&set->lock); +} + +int +ip_set_test(ip_set_id_t index, const struct sk_buff *skb, + const struct xt_action_param *par, struct ip_set_adt_opt *opt) +{ + struct ip_set *set = ip_set_rcu_get(xt_net(par), index); + int ret = 0; + + BUG_ON(!set); + pr_debug("set %s, index %u\n", set->name, index); + + if (opt->dim < set->type->dimension || + !(opt->family == set->family || set->family == NFPROTO_UNSPEC)) + return 0; + + rcu_read_lock_bh(); + ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt); + rcu_read_unlock_bh(); + + if (ret == -EAGAIN) { + /* Type requests element to be completed */ + pr_debug("element must be completed, ADD is triggered\n"); + ip_set_lock(set); + set->variant->kadt(set, skb, par, IPSET_ADD, opt); + ip_set_unlock(set); + ret = 1; + } else { + /* --return-nomatch: invert matched element */ + if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) && + (set->type->features & IPSET_TYPE_NOMATCH) && + (ret > 0 || ret == -ENOTEMPTY)) + ret = -ret; + } + + /* Convert error codes to nomatch */ + return (ret < 0 ? 0 : ret); +} +EXPORT_SYMBOL_GPL(ip_set_test); + +int +ip_set_add(ip_set_id_t index, const struct sk_buff *skb, + const struct xt_action_param *par, struct ip_set_adt_opt *opt) +{ + struct ip_set *set = ip_set_rcu_get(xt_net(par), index); + int ret; + + BUG_ON(!set); + pr_debug("set %s, index %u\n", set->name, index); + + if (opt->dim < set->type->dimension || + !(opt->family == set->family || set->family == NFPROTO_UNSPEC)) + return -IPSET_ERR_TYPE_MISMATCH; + + ip_set_lock(set); + ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt); + ip_set_unlock(set); + + return ret; +} +EXPORT_SYMBOL_GPL(ip_set_add); + +int +ip_set_del(ip_set_id_t index, const struct sk_buff *skb, + const struct xt_action_param *par, struct ip_set_adt_opt *opt) +{ + struct ip_set *set = ip_set_rcu_get(xt_net(par), index); + int ret = 0; + + BUG_ON(!set); + pr_debug("set %s, index %u\n", set->name, index); + + if (opt->dim < set->type->dimension || + !(opt->family == set->family || set->family == NFPROTO_UNSPEC)) + return -IPSET_ERR_TYPE_MISMATCH; + + ip_set_lock(set); + ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt); + ip_set_unlock(set); + + return ret; +} +EXPORT_SYMBOL_GPL(ip_set_del); + +/* Find set by name, reference it once. The reference makes sure the + * thing pointed to, does not go away under our feet. 
+ * + */ +ip_set_id_t +ip_set_get_byname(struct net *net, const char *name, struct ip_set **set) +{ + ip_set_id_t i, index = IPSET_INVALID_ID; + struct ip_set *s; + struct ip_set_net *inst = ip_set_pernet(net); + + rcu_read_lock(); + for (i = 0; i < inst->ip_set_max; i++) { + s = rcu_dereference(inst->ip_set_list)[i]; + if (s && STRNCMP(s->name, name)) { + __ip_set_get(s); + index = i; + *set = s; + break; + } + } + rcu_read_unlock(); + + return index; +} +EXPORT_SYMBOL_GPL(ip_set_get_byname); + +/* If the given set pointer points to a valid set, decrement + * reference count by 1. The caller shall not assume the index + * to be valid, after calling this function. + * + */ + +static void +__ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index) +{ + struct ip_set *set; + + rcu_read_lock(); + set = rcu_dereference(inst->ip_set_list)[index]; + if (set) + __ip_set_put(set); + rcu_read_unlock(); +} + +void +ip_set_put_byindex(struct net *net, ip_set_id_t index) +{ + struct ip_set_net *inst = ip_set_pernet(net); + + __ip_set_put_byindex(inst, index); +} +EXPORT_SYMBOL_GPL(ip_set_put_byindex); + +/* Get the name of a set behind a set index. + * Set itself is protected by RCU, but its name isn't: to protect against + * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the + * name. + */ +void +ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name) +{ + struct ip_set *set = ip_set_rcu_get(net, index); + + BUG_ON(!set); + + read_lock_bh(&ip_set_ref_lock); + strncpy(name, set->name, IPSET_MAXNAMELEN); + read_unlock_bh(&ip_set_ref_lock); +} +EXPORT_SYMBOL_GPL(ip_set_name_byindex); + +/* Routines to call by external subsystems, which do not + * call nfnl_lock for us. + */ + +/* Find set by index, reference it once. The reference makes sure the + * thing pointed to, does not go away under our feet. + * + * The nfnl mutex is used in the function. + */ +ip_set_id_t +ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index) +{ + struct ip_set *set; + struct ip_set_net *inst = ip_set_pernet(net); + + if (index >= inst->ip_set_max) + return IPSET_INVALID_ID; + + nfnl_lock(NFNL_SUBSYS_IPSET); + set = ip_set(inst, index); + if (set) + __ip_set_get(set); + else + index = IPSET_INVALID_ID; + nfnl_unlock(NFNL_SUBSYS_IPSET); + + return index; +} +EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex); + +/* If the given set pointer points to a valid set, decrement + * reference count by 1. The caller shall not assume the index + * to be valid, after calling this function. + * + * The nfnl mutex is used in the function. + */ +void +ip_set_nfnl_put(struct net *net, ip_set_id_t index) +{ + struct ip_set *set; + struct ip_set_net *inst = ip_set_pernet(net); + + nfnl_lock(NFNL_SUBSYS_IPSET); + if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */ + set = ip_set(inst, index); + if (set) + __ip_set_put(set); + } + nfnl_unlock(NFNL_SUBSYS_IPSET); +} +EXPORT_SYMBOL_GPL(ip_set_nfnl_put); + +/* Communication protocol with userspace over netlink. + * + * The commands are serialized by the nfnl mutex. 
+ */ + +static inline u8 protocol(const struct nlattr * const tb[]) +{ + return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]); +} + +static inline bool +protocol_failed(const struct nlattr * const tb[]) +{ + return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL; +} + +static inline bool +protocol_min_failed(const struct nlattr * const tb[]) +{ + return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN; +} + +static inline u32 +flag_exist(const struct nlmsghdr *nlh) +{ + return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST; +} + +static struct nlmsghdr * +start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags, + enum ipset_cmd cmd) +{ + return nfnl_msg_put(skb, portid, seq, + nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd), flags, + NFPROTO_IPV4, NFNETLINK_V0, 0); +} + +/* Create a set */ + +static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_SETNAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, + [IPSET_ATTR_TYPENAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1}, + [IPSET_ATTR_REVISION] = { .type = NLA_U8 }, + [IPSET_ATTR_FAMILY] = { .type = NLA_U8 }, + [IPSET_ATTR_DATA] = { .type = NLA_NESTED }, +}; + +static struct ip_set * +find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id) +{ + struct ip_set *set = NULL; + ip_set_id_t i; + + *id = IPSET_INVALID_ID; + for (i = 0; i < inst->ip_set_max; i++) { + set = ip_set(inst, i); + if (set && STRNCMP(set->name, name)) { + *id = i; + break; + } + } + return (*id == IPSET_INVALID_ID ? NULL : set); +} + +static inline struct ip_set * +find_set(struct ip_set_net *inst, const char *name) +{ + ip_set_id_t id; + + return find_set_and_id(inst, name, &id); +} + +static int +find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index, + struct ip_set **set) +{ + struct ip_set *s; + ip_set_id_t i; + + *index = IPSET_INVALID_ID; + for (i = 0; i < inst->ip_set_max; i++) { + s = ip_set(inst, i); + if (!s) { + if (*index == IPSET_INVALID_ID) + *index = i; + } else if (STRNCMP(name, s->name)) { + /* Name clash */ + *set = s; + return -EEXIST; + } + } + if (*index == IPSET_INVALID_ID) + /* No free slot remained */ + return -IPSET_ERR_MAX_SETS; + return 0; +} + +static int ip_set_none(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + return -EOPNOTSUPP; +} + +static int ip_set_create(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *set, *clash = NULL; + ip_set_id_t index = IPSET_INVALID_ID; + struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {}; + const char *name, *typename; + u8 family, revision; + u32 flags = flag_exist(info->nlh); + int ret = 0; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME] || + !attr[IPSET_ATTR_TYPENAME] || + !attr[IPSET_ATTR_REVISION] || + !attr[IPSET_ATTR_FAMILY] || + (attr[IPSET_ATTR_DATA] && + !flag_nested(attr[IPSET_ATTR_DATA])))) + return -IPSET_ERR_PROTOCOL; + + name = nla_data(attr[IPSET_ATTR_SETNAME]); + typename = nla_data(attr[IPSET_ATTR_TYPENAME]); + family = nla_get_u8(attr[IPSET_ATTR_FAMILY]); + revision = nla_get_u8(attr[IPSET_ATTR_REVISION]); + pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n", + name, typename, family_name(family), revision); + + /* First, and without any locks, allocate and initialize + * a normal base set structure. 
+ */ + set = kzalloc(sizeof(*set), GFP_KERNEL); + if (!set) + return -ENOMEM; + spin_lock_init(&set->lock); + strscpy(set->name, name, IPSET_MAXNAMELEN); + set->family = family; + set->revision = revision; + + /* Next, check that we know the type, and take + * a reference on the type, to make sure it stays available + * while constructing our new set. + * + * After referencing the type, we try to create the type + * specific part of the set without holding any locks. + */ + ret = find_set_type_get(typename, family, revision, &set->type); + if (ret) + goto out; + + /* Without holding any locks, create private part. */ + if (attr[IPSET_ATTR_DATA] && + nla_parse_nested(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA], + set->type->create_policy, NULL)) { + ret = -IPSET_ERR_PROTOCOL; + goto put_out; + } + /* Set create flags depending on the type revision */ + set->flags |= set->type->create_flags[revision]; + + ret = set->type->create(info->net, set, tb, flags); + if (ret != 0) + goto put_out; + + /* BTW, ret==0 here. */ + + /* Here, we have a valid, constructed set and we are protected + * by the nfnl mutex. Find the first free index in ip_set_list + * and check clashing. + */ + ret = find_free_id(inst, set->name, &index, &clash); + if (ret == -EEXIST) { + /* If this is the same set and requested, ignore error */ + if ((flags & IPSET_FLAG_EXIST) && + STRNCMP(set->type->name, clash->type->name) && + set->type->family == clash->type->family && + set->type->revision_min == clash->type->revision_min && + set->type->revision_max == clash->type->revision_max && + set->variant->same_set(set, clash)) + ret = 0; + goto cleanup; + } else if (ret == -IPSET_ERR_MAX_SETS) { + struct ip_set **list, **tmp; + ip_set_id_t i = inst->ip_set_max + IP_SET_INC; + + if (i < inst->ip_set_max || i == IPSET_INVALID_ID) + /* Wraparound */ + goto cleanup; + + list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL); + if (!list) + goto cleanup; + /* nfnl mutex is held, both lists are valid */ + tmp = ip_set_dereference(inst->ip_set_list); + memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max); + rcu_assign_pointer(inst->ip_set_list, list); + /* Make sure all current packets have passed through */ + synchronize_net(); + /* Use new list */ + index = inst->ip_set_max; + inst->ip_set_max = i; + kvfree(tmp); + ret = 0; + } else if (ret) { + goto cleanup; + } + + /* Finally! Add our shiny new set to the list, and be done. 
*/ + pr_debug("create: '%s' created with index %u!\n", set->name, index); + ip_set(inst, index) = set; + + return ret; + +cleanup: + set->variant->destroy(set); +put_out: + module_put(set->type->me); +out: + kfree(set); + return ret; +} + +/* Destroy sets */ + +static const struct nla_policy +ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_SETNAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, +}; + +static void +ip_set_destroy_set(struct ip_set *set) +{ + pr_debug("set: %s\n", set->name); + + /* Must call it without holding any lock */ + set->variant->destroy(set); + module_put(set->type->me); + kfree(set); +} + +static int ip_set_destroy(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *s; + ip_set_id_t i; + int ret = 0; + + if (unlikely(protocol_min_failed(attr))) + return -IPSET_ERR_PROTOCOL; + + /* Must wait for flush to be really finished in list:set */ + rcu_barrier(); + + /* Commands are serialized and references are + * protected by the ip_set_ref_lock. + * External systems (i.e. xt_set) must call + * ip_set_put|get_nfnl_* functions, that way we + * can safely check references here. + * + * list:set timer can only decrement the reference + * counter, so if it's already zero, we can proceed + * without holding the lock. + */ + read_lock_bh(&ip_set_ref_lock); + if (!attr[IPSET_ATTR_SETNAME]) { + for (i = 0; i < inst->ip_set_max; i++) { + s = ip_set(inst, i); + if (s && (s->ref || s->ref_netlink)) { + ret = -IPSET_ERR_BUSY; + goto out; + } + } + inst->is_destroyed = true; + read_unlock_bh(&ip_set_ref_lock); + for (i = 0; i < inst->ip_set_max; i++) { + s = ip_set(inst, i); + if (s) { + ip_set(inst, i) = NULL; + ip_set_destroy_set(s); + } + } + /* Modified by ip_set_destroy() only, which is serialized */ + inst->is_destroyed = false; + } else { + u32 flags = flag_exist(info->nlh); + s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), + &i); + if (!s) { + if (!(flags & IPSET_FLAG_EXIST)) + ret = -ENOENT; + goto out; + } else if (s->ref || s->ref_netlink) { + ret = -IPSET_ERR_BUSY; + goto out; + } + ip_set(inst, i) = NULL; + read_unlock_bh(&ip_set_ref_lock); + + ip_set_destroy_set(s); + } + return 0; +out: + read_unlock_bh(&ip_set_ref_lock); + return ret; +} + +/* Flush sets */ + +static void +ip_set_flush_set(struct ip_set *set) +{ + pr_debug("set: %s\n", set->name); + + ip_set_lock(set); + set->variant->flush(set); + ip_set_unlock(set); +} + +static int ip_set_flush(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *s; + ip_set_id_t i; + + if (unlikely(protocol_min_failed(attr))) + return -IPSET_ERR_PROTOCOL; + + if (!attr[IPSET_ATTR_SETNAME]) { + for (i = 0; i < inst->ip_set_max; i++) { + s = ip_set(inst, i); + if (s) + ip_set_flush_set(s); + } + } else { + s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME])); + if (!s) + return -ENOENT; + + ip_set_flush_set(s); + } + + return 0; +} + +/* Rename a set */ + +static const struct nla_policy +ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_SETNAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, + [IPSET_ATTR_SETNAME2] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, +}; + +static int ip_set_rename(struct sk_buff *skb, const struct nfnl_info 
*info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *set, *s; + const char *name2; + ip_set_id_t i; + int ret = 0; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME] || + !attr[IPSET_ATTR_SETNAME2])) + return -IPSET_ERR_PROTOCOL; + + set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME])); + if (!set) + return -ENOENT; + + write_lock_bh(&ip_set_ref_lock); + if (set->ref != 0 || set->ref_netlink != 0) { + ret = -IPSET_ERR_REFERENCED; + goto out; + } + + name2 = nla_data(attr[IPSET_ATTR_SETNAME2]); + for (i = 0; i < inst->ip_set_max; i++) { + s = ip_set(inst, i); + if (s && STRNCMP(s->name, name2)) { + ret = -IPSET_ERR_EXIST_SETNAME2; + goto out; + } + } + strncpy(set->name, name2, IPSET_MAXNAMELEN); + +out: + write_unlock_bh(&ip_set_ref_lock); + return ret; +} + +/* Swap two sets so that name/index points to the other. + * References and set names are also swapped. + * + * The commands are serialized by the nfnl mutex and references are + * protected by the ip_set_ref_lock. The kernel interfaces + * do not hold the mutex but the pointer settings are atomic + * so the ip_set_list always contains valid pointers to the sets. + */ + +static int ip_set_swap(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *from, *to; + ip_set_id_t from_id, to_id; + char from_name[IPSET_MAXNAMELEN]; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME] || + !attr[IPSET_ATTR_SETNAME2])) + return -IPSET_ERR_PROTOCOL; + + from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), + &from_id); + if (!from) + return -ENOENT; + + to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]), + &to_id); + if (!to) + return -IPSET_ERR_EXIST_SETNAME2; + + /* Features must not change. + * Not an artifical restriction anymore, as we must prevent + * possible loops created by swapping in setlist type of sets. + */ + if (!(from->type->features == to->type->features && + from->family == to->family)) + return -IPSET_ERR_TYPE_MISMATCH; + + write_lock_bh(&ip_set_ref_lock); + + if (from->ref_netlink || to->ref_netlink) { + write_unlock_bh(&ip_set_ref_lock); + return -EBUSY; + } + + strncpy(from_name, from->name, IPSET_MAXNAMELEN); + strncpy(from->name, to->name, IPSET_MAXNAMELEN); + strncpy(to->name, from_name, IPSET_MAXNAMELEN); + + swap(from->ref, to->ref); + ip_set(inst, from_id) = to; + ip_set(inst, to_id) = from; + write_unlock_bh(&ip_set_ref_lock); + + /* Make sure all readers of the old set pointers are completed. 
*/ + synchronize_rcu(); + + return 0; +} + +/* List/save set data */ + +#define DUMP_INIT 0 +#define DUMP_ALL 1 +#define DUMP_ONE 2 +#define DUMP_LAST 3 + +#define DUMP_TYPE(arg) (((u32)(arg)) & 0x0000FFFF) +#define DUMP_FLAGS(arg) (((u32)(arg)) >> 16) + +int +ip_set_put_flags(struct sk_buff *skb, struct ip_set *set) +{ + u32 cadt_flags = 0; + + if (SET_WITH_TIMEOUT(set)) + if (unlikely(nla_put_net32(skb, IPSET_ATTR_TIMEOUT, + htonl(set->timeout)))) + return -EMSGSIZE; + if (SET_WITH_COUNTER(set)) + cadt_flags |= IPSET_FLAG_WITH_COUNTERS; + if (SET_WITH_COMMENT(set)) + cadt_flags |= IPSET_FLAG_WITH_COMMENT; + if (SET_WITH_SKBINFO(set)) + cadt_flags |= IPSET_FLAG_WITH_SKBINFO; + if (SET_WITH_FORCEADD(set)) + cadt_flags |= IPSET_FLAG_WITH_FORCEADD; + + if (!cadt_flags) + return 0; + return nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(cadt_flags)); +} +EXPORT_SYMBOL_GPL(ip_set_put_flags); + +static int +ip_set_dump_done(struct netlink_callback *cb) +{ + if (cb->args[IPSET_CB_ARG0]) { + struct ip_set_net *inst = + (struct ip_set_net *)cb->args[IPSET_CB_NET]; + ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX]; + struct ip_set *set = ip_set_ref_netlink(inst, index); + + if (set->variant->uref) + set->variant->uref(set, cb, false); + pr_debug("release set %s\n", set->name); + __ip_set_put_netlink(set); + } + return 0; +} + +static inline void +dump_attrs(struct nlmsghdr *nlh) +{ + const struct nlattr *attr; + int rem; + + pr_debug("dump nlmsg\n"); + nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) { + pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len); + } +} + +static const struct nla_policy +ip_set_dump_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_SETNAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, + [IPSET_ATTR_FLAGS] = { .type = NLA_U32 }, +}; + +static int +ip_set_dump_start(struct netlink_callback *cb) +{ + struct nlmsghdr *nlh = nlmsg_hdr(cb->skb); + int min_len = nlmsg_total_size(sizeof(struct nfgenmsg)); + struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1]; + struct nlattr *attr = (void *)nlh + min_len; + struct sk_buff *skb = cb->skb; + struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk)); + u32 dump_type; + int ret; + + ret = nla_parse(cda, IPSET_ATTR_CMD_MAX, attr, + nlh->nlmsg_len - min_len, + ip_set_dump_policy, NULL); + if (ret) + goto error; + + cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]); + if (cda[IPSET_ATTR_SETNAME]) { + ip_set_id_t index; + struct ip_set *set; + + set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]), + &index); + if (!set) { + ret = -ENOENT; + goto error; + } + dump_type = DUMP_ONE; + cb->args[IPSET_CB_INDEX] = index; + } else { + dump_type = DUMP_ALL; + } + + if (cda[IPSET_ATTR_FLAGS]) { + u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]); + + dump_type |= (f << 16); + } + cb->args[IPSET_CB_NET] = (unsigned long)inst; + cb->args[IPSET_CB_DUMP] = dump_type; + + return 0; + +error: + /* We have to create and send the error message manually :-( */ + if (nlh->nlmsg_flags & NLM_F_ACK) { + netlink_ack(cb->skb, nlh, ret, NULL); + } + return ret; +} + +static int +ip_set_dump_do(struct sk_buff *skb, struct netlink_callback *cb) +{ + ip_set_id_t index = IPSET_INVALID_ID, max; + struct ip_set *set = NULL; + struct nlmsghdr *nlh = NULL; + unsigned int flags = NETLINK_CB(cb->skb).portid ? 
NLM_F_MULTI : 0; + struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk)); + u32 dump_type, dump_flags; + bool is_destroyed; + int ret = 0; + + if (!cb->args[IPSET_CB_DUMP]) + return -EINVAL; + + if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max) + goto out; + + dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]); + dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]); + max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1 + : inst->ip_set_max; +dump_last: + pr_debug("dump type, flag: %u %u index: %ld\n", + dump_type, dump_flags, cb->args[IPSET_CB_INDEX]); + for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) { + index = (ip_set_id_t)cb->args[IPSET_CB_INDEX]; + write_lock_bh(&ip_set_ref_lock); + set = ip_set(inst, index); + is_destroyed = inst->is_destroyed; + if (!set || is_destroyed) { + write_unlock_bh(&ip_set_ref_lock); + if (dump_type == DUMP_ONE) { + ret = -ENOENT; + goto out; + } + if (is_destroyed) { + /* All sets are just being destroyed */ + ret = 0; + goto out; + } + continue; + } + /* When dumping all sets, we must dump "sorted" + * so that lists (unions of sets) are dumped last. + */ + if (dump_type != DUMP_ONE && + ((dump_type == DUMP_ALL) == + !!(set->type->features & IPSET_DUMP_LAST))) { + write_unlock_bh(&ip_set_ref_lock); + continue; + } + pr_debug("List set: %s\n", set->name); + if (!cb->args[IPSET_CB_ARG0]) { + /* Start listing: make sure set won't be destroyed */ + pr_debug("reference set\n"); + set->ref_netlink++; + } + write_unlock_bh(&ip_set_ref_lock); + nlh = start_msg(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, flags, + IPSET_CMD_LIST); + if (!nlh) { + ret = -EMSGSIZE; + goto release_refcount; + } + if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL, + cb->args[IPSET_CB_PROTO]) || + nla_put_string(skb, IPSET_ATTR_SETNAME, set->name)) + goto nla_put_failure; + if (dump_flags & IPSET_FLAG_LIST_SETNAME) + goto next_set; + switch (cb->args[IPSET_CB_ARG0]) { + case 0: + /* Core header data */ + if (nla_put_string(skb, IPSET_ATTR_TYPENAME, + set->type->name) || + nla_put_u8(skb, IPSET_ATTR_FAMILY, + set->family) || + nla_put_u8(skb, IPSET_ATTR_REVISION, + set->revision)) + goto nla_put_failure; + if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN && + nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index))) + goto nla_put_failure; + ret = set->variant->head(set, skb); + if (ret < 0) + goto release_refcount; + if (dump_flags & IPSET_FLAG_LIST_HEADER) + goto next_set; + if (set->variant->uref) + set->variant->uref(set, cb, true); + fallthrough; + default: + ret = set->variant->list(set, skb, cb); + if (!cb->args[IPSET_CB_ARG0]) + /* Set is done, proceed with next one */ + goto next_set; + goto release_refcount; + } + } + /* If we dump all sets, continue with dumping last ones */ + if (dump_type == DUMP_ALL) { + dump_type = DUMP_LAST; + cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16); + cb->args[IPSET_CB_INDEX] = 0; + if (set && set->variant->uref) + set->variant->uref(set, cb, false); + goto dump_last; + } + goto out; + +nla_put_failure: + ret = -EFAULT; +next_set: + if (dump_type == DUMP_ONE) + cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID; + else + cb->args[IPSET_CB_INDEX]++; +release_refcount: + /* If there was an error or set is done, release set */ + if (ret || !cb->args[IPSET_CB_ARG0]) { + set = ip_set_ref_netlink(inst, index); + if (set->variant->uref) + set->variant->uref(set, cb, false); + pr_debug("release set %s\n", set->name); + __ip_set_put_netlink(set); + cb->args[IPSET_CB_ARG0] = 0; + } +out: + if (nlh) { + nlmsg_end(skb, nlh); + 
pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len); + dump_attrs(nlh); + } + + return ret < 0 ? ret : skb->len; +} + +static int ip_set_dump(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + if (unlikely(protocol_min_failed(attr))) + return -IPSET_ERR_PROTOCOL; + + { + struct netlink_dump_control c = { + .start = ip_set_dump_start, + .dump = ip_set_dump_do, + .done = ip_set_dump_done, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } +} + +/* Add, del and test */ + +static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_SETNAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_DATA] = { .type = NLA_NESTED }, + [IPSET_ATTR_ADT] = { .type = NLA_NESTED }, +}; + +static int +call_ad(struct net *net, struct sock *ctnl, struct sk_buff *skb, + struct ip_set *set, struct nlattr *tb[], enum ipset_adt adt, + u32 flags, bool use_lineno) +{ + int ret; + u32 lineno = 0; + bool eexist = flags & IPSET_FLAG_EXIST, retried = false; + + do { + if (retried) { + __ip_set_get_netlink(set); + nfnl_unlock(NFNL_SUBSYS_IPSET); + cond_resched(); + nfnl_lock(NFNL_SUBSYS_IPSET); + __ip_set_put_netlink(set); + } + + ip_set_lock(set); + ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried); + ip_set_unlock(set); + retried = true; + } while (ret == -ERANGE || + (ret == -EAGAIN && + set->variant->resize && + (ret = set->variant->resize(set, retried)) == 0)); + + if (!ret || (ret == -IPSET_ERR_EXIST && eexist)) + return 0; + if (lineno && use_lineno) { + /* Error in restore/batch mode: send back lineno */ + struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb); + struct sk_buff *skb2; + struct nlmsgerr *errmsg; + size_t payload = min(SIZE_MAX, + sizeof(*errmsg) + nlmsg_len(nlh)); + int min_len = nlmsg_total_size(sizeof(struct nfgenmsg)); + struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1]; + struct nlattr *cmdattr; + u32 *errline; + + skb2 = nlmsg_new(payload, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + rep = nlmsg_put(skb2, NETLINK_CB(skb).portid, + nlh->nlmsg_seq, NLMSG_ERROR, payload, 0); + errmsg = nlmsg_data(rep); + errmsg->error = ret; + unsafe_memcpy(&errmsg->msg, nlh, nlh->nlmsg_len, + /* Bounds checked by the skb layer. */); + + cmdattr = (void *)&errmsg->msg + min_len; + + ret = nla_parse(cda, IPSET_ATTR_CMD_MAX, cmdattr, + nlh->nlmsg_len - min_len, ip_set_adt_policy, + NULL); + + if (ret) { + nlmsg_free(skb2); + return ret; + } + errline = nla_data(cda[IPSET_ATTR_LINENO]); + + *errline = lineno; + + nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + /* Signal netlink not to send its ACK/errmsg. 
*/ + return -EINTR; + } + + return ret; +} + +static int ip_set_ad(struct net *net, struct sock *ctnl, + struct sk_buff *skb, + enum ipset_adt adt, + const struct nlmsghdr *nlh, + const struct nlattr * const attr[], + struct netlink_ext_ack *extack) +{ + struct ip_set_net *inst = ip_set_pernet(net); + struct ip_set *set; + struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {}; + const struct nlattr *nla; + u32 flags = flag_exist(nlh); + bool use_lineno; + int ret = 0; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME] || + !((attr[IPSET_ATTR_DATA] != NULL) ^ + (attr[IPSET_ATTR_ADT] != NULL)) || + (attr[IPSET_ATTR_DATA] && + !flag_nested(attr[IPSET_ATTR_DATA])) || + (attr[IPSET_ATTR_ADT] && + (!flag_nested(attr[IPSET_ATTR_ADT]) || + !attr[IPSET_ATTR_LINENO])))) + return -IPSET_ERR_PROTOCOL; + + set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME])); + if (!set) + return -ENOENT; + + use_lineno = !!attr[IPSET_ATTR_LINENO]; + if (attr[IPSET_ATTR_DATA]) { + if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, + attr[IPSET_ATTR_DATA], + set->type->adt_policy, NULL)) + return -IPSET_ERR_PROTOCOL; + ret = call_ad(net, ctnl, skb, set, tb, adt, flags, + use_lineno); + } else { + int nla_rem; + + nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) { + if (nla_type(nla) != IPSET_ATTR_DATA || + !flag_nested(nla) || + nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, nla, + set->type->adt_policy, NULL)) + return -IPSET_ERR_PROTOCOL; + ret = call_ad(net, ctnl, skb, set, tb, adt, + flags, use_lineno); + if (ret < 0) + return ret; + } + } + return ret; +} + +static int ip_set_uadd(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + return ip_set_ad(info->net, info->sk, skb, + IPSET_ADD, info->nlh, attr, info->extack); +} + +static int ip_set_udel(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + return ip_set_ad(info->net, info->sk, skb, + IPSET_DEL, info->nlh, attr, info->extack); +} + +static int ip_set_utest(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct ip_set *set; + struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {}; + int ret = 0; + u32 lineno; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME] || + !attr[IPSET_ATTR_DATA] || + !flag_nested(attr[IPSET_ATTR_DATA]))) + return -IPSET_ERR_PROTOCOL; + + set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME])); + if (!set) + return -ENOENT; + + if (nla_parse_nested(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA], + set->type->adt_policy, NULL)) + return -IPSET_ERR_PROTOCOL; + + rcu_read_lock_bh(); + ret = set->variant->uadt(set, tb, IPSET_TEST, &lineno, 0, 0); + rcu_read_unlock_bh(); + /* Userspace can't trigger element to be re-added */ + if (ret == -EAGAIN) + ret = 1; + + return ret > 0 ? 
0 : -IPSET_ERR_EXIST; +} + +/* Get headed data of a set */ + +static int ip_set_header(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + const struct ip_set *set; + struct sk_buff *skb2; + struct nlmsghdr *nlh2; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_SETNAME])) + return -IPSET_ERR_PROTOCOL; + + set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME])); + if (!set) + return -ENOENT; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0, + IPSET_CMD_HEADER); + if (!nlh2) + goto nlmsg_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) || + nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) || + nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) || + nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) || + nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision)) + goto nla_put_failure; + nlmsg_end(skb2, nlh2); + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +nla_put_failure: + nlmsg_cancel(skb2, nlh2); +nlmsg_failure: + kfree_skb(skb2); + return -EMSGSIZE; +} + +/* Get type data */ + +static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_TYPENAME] = { .type = NLA_NUL_STRING, + .len = IPSET_MAXNAMELEN - 1 }, + [IPSET_ATTR_FAMILY] = { .type = NLA_U8 }, +}; + +static int ip_set_type(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct sk_buff *skb2; + struct nlmsghdr *nlh2; + u8 family, min, max; + const char *typename; + int ret = 0; + + if (unlikely(protocol_min_failed(attr) || + !attr[IPSET_ATTR_TYPENAME] || + !attr[IPSET_ATTR_FAMILY])) + return -IPSET_ERR_PROTOCOL; + + family = nla_get_u8(attr[IPSET_ATTR_FAMILY]); + typename = nla_data(attr[IPSET_ATTR_TYPENAME]); + ret = find_set_type_minmax(typename, family, &min, &max); + if (ret) + return ret; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0, + IPSET_CMD_TYPE); + if (!nlh2) + goto nlmsg_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) || + nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) || + nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) || + nla_put_u8(skb2, IPSET_ATTR_REVISION, max) || + nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min)) + goto nla_put_failure; + nlmsg_end(skb2, nlh2); + + pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len); + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +nla_put_failure: + nlmsg_cancel(skb2, nlh2); +nlmsg_failure: + kfree_skb(skb2); + return -EMSGSIZE; +} + +/* Get protocol version */ + +static const struct nla_policy +ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, +}; + +static int ip_set_protocol(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct sk_buff *skb2; + struct nlmsghdr *nlh2; + + if (unlikely(!attr[IPSET_ATTR_PROTOCOL])) + return -IPSET_ERR_PROTOCOL; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0, + IPSET_CMD_PROTOCOL); + if (!nlh2) + goto nlmsg_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL)) + goto 
nla_put_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN)) + goto nla_put_failure; + nlmsg_end(skb2, nlh2); + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +nla_put_failure: + nlmsg_cancel(skb2, nlh2); +nlmsg_failure: + kfree_skb(skb2); + return -EMSGSIZE; +} + +/* Get set by name or index, from userspace */ + +static int ip_set_byname(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct sk_buff *skb2; + struct nlmsghdr *nlh2; + ip_set_id_t id = IPSET_INVALID_ID; + const struct ip_set *set; + + if (unlikely(protocol_failed(attr) || + !attr[IPSET_ATTR_SETNAME])) + return -IPSET_ERR_PROTOCOL; + + set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id); + if (id == IPSET_INVALID_ID) + return -ENOENT; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0, + IPSET_CMD_GET_BYNAME); + if (!nlh2) + goto nlmsg_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) || + nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) || + nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id))) + goto nla_put_failure; + nlmsg_end(skb2, nlh2); + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +nla_put_failure: + nlmsg_cancel(skb2, nlh2); +nlmsg_failure: + kfree_skb(skb2); + return -EMSGSIZE; +} + +static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = { + [IPSET_ATTR_PROTOCOL] = { .type = NLA_U8 }, + [IPSET_ATTR_INDEX] = { .type = NLA_U16 }, +}; + +static int ip_set_byindex(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const attr[]) +{ + struct ip_set_net *inst = ip_set_pernet(info->net); + struct sk_buff *skb2; + struct nlmsghdr *nlh2; + ip_set_id_t id = IPSET_INVALID_ID; + const struct ip_set *set; + + if (unlikely(protocol_failed(attr) || + !attr[IPSET_ATTR_INDEX])) + return -IPSET_ERR_PROTOCOL; + + id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]); + if (id >= inst->ip_set_max) + return -ENOENT; + set = ip_set(inst, id); + if (set == NULL) + return -ENOENT; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + nlh2 = start_msg(skb2, NETLINK_CB(skb).portid, info->nlh->nlmsg_seq, 0, + IPSET_CMD_GET_BYINDEX); + if (!nlh2) + goto nlmsg_failure; + if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) || + nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name)) + goto nla_put_failure; + nlmsg_end(skb2, nlh2); + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +nla_put_failure: + nlmsg_cancel(skb2, nlh2); +nlmsg_failure: + kfree_skb(skb2); + return -EMSGSIZE; +} + +static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = { + [IPSET_CMD_NONE] = { + .call = ip_set_none, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + }, + [IPSET_CMD_CREATE] = { + .call = ip_set_create, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_create_policy, + }, + [IPSET_CMD_DESTROY] = { + .call = ip_set_destroy, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname_policy, + }, + [IPSET_CMD_FLUSH] = { + .call = ip_set_flush, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname_policy, + }, + [IPSET_CMD_RENAME] = { + .call = ip_set_rename, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = 
ip_set_setname2_policy, + }, + [IPSET_CMD_SWAP] = { + .call = ip_set_swap, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname2_policy, + }, + [IPSET_CMD_LIST] = { + .call = ip_set_dump, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_dump_policy, + }, + [IPSET_CMD_SAVE] = { + .call = ip_set_dump, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname_policy, + }, + [IPSET_CMD_ADD] = { + .call = ip_set_uadd, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_adt_policy, + }, + [IPSET_CMD_DEL] = { + .call = ip_set_udel, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_adt_policy, + }, + [IPSET_CMD_TEST] = { + .call = ip_set_utest, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_adt_policy, + }, + [IPSET_CMD_HEADER] = { + .call = ip_set_header, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname_policy, + }, + [IPSET_CMD_TYPE] = { + .call = ip_set_type, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_type_policy, + }, + [IPSET_CMD_PROTOCOL] = { + .call = ip_set_protocol, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_protocol_policy, + }, + [IPSET_CMD_GET_BYNAME] = { + .call = ip_set_byname, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_setname_policy, + }, + [IPSET_CMD_GET_BYINDEX] = { + .call = ip_set_byindex, + .type = NFNL_CB_MUTEX, + .attr_count = IPSET_ATTR_CMD_MAX, + .policy = ip_set_index_policy, + }, +}; + +static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = { + .name = "ip_set", + .subsys_id = NFNL_SUBSYS_IPSET, + .cb_count = IPSET_MSG_MAX, + .cb = ip_set_netlink_subsys_cb, +}; + +/* Interface to iptables/ip6tables */ + +static int +ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len) +{ + unsigned int *op; + void *data; + int copylen = *len, ret = 0; + struct net *net = sock_net(sk); + struct ip_set_net *inst = ip_set_pernet(net); + + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + if (optval != SO_IP_SET) + return -EBADF; + if (*len < sizeof(unsigned int)) + return -EINVAL; + + data = vmalloc(*len); + if (!data) + return -ENOMEM; + if (copy_from_user(data, user, *len) != 0) { + ret = -EFAULT; + goto done; + } + op = data; + + if (*op < IP_SET_OP_VERSION) { + /* Check the version at the beginning of operations */ + struct ip_set_req_version *req_version = data; + + if (*len < sizeof(struct ip_set_req_version)) { + ret = -EINVAL; + goto done; + } + + if (req_version->version < IPSET_PROTOCOL_MIN) { + ret = -EPROTO; + goto done; + } + } + + switch (*op) { + case IP_SET_OP_VERSION: { + struct ip_set_req_version *req_version = data; + + if (*len != sizeof(struct ip_set_req_version)) { + ret = -EINVAL; + goto done; + } + + req_version->version = IPSET_PROTOCOL; + if (copy_to_user(user, req_version, + sizeof(struct ip_set_req_version))) + ret = -EFAULT; + goto done; + } + case IP_SET_OP_GET_BYNAME: { + struct ip_set_req_get_set *req_get = data; + ip_set_id_t id; + + if (*len != sizeof(struct ip_set_req_get_set)) { + ret = -EINVAL; + goto done; + } + req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0'; + nfnl_lock(NFNL_SUBSYS_IPSET); + find_set_and_id(inst, req_get->set.name, &id); + req_get->set.index = id; + nfnl_unlock(NFNL_SUBSYS_IPSET); + goto copy; + } + case IP_SET_OP_GET_FNAME: { + struct 
ip_set_req_get_set_family *req_get = data; + ip_set_id_t id; + + if (*len != sizeof(struct ip_set_req_get_set_family)) { + ret = -EINVAL; + goto done; + } + req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0'; + nfnl_lock(NFNL_SUBSYS_IPSET); + find_set_and_id(inst, req_get->set.name, &id); + req_get->set.index = id; + if (id != IPSET_INVALID_ID) + req_get->family = ip_set(inst, id)->family; + nfnl_unlock(NFNL_SUBSYS_IPSET); + goto copy; + } + case IP_SET_OP_GET_BYINDEX: { + struct ip_set_req_get_set *req_get = data; + struct ip_set *set; + + if (*len != sizeof(struct ip_set_req_get_set) || + req_get->set.index >= inst->ip_set_max) { + ret = -EINVAL; + goto done; + } + nfnl_lock(NFNL_SUBSYS_IPSET); + set = ip_set(inst, req_get->set.index); + ret = strscpy(req_get->set.name, set ? set->name : "", + IPSET_MAXNAMELEN); + nfnl_unlock(NFNL_SUBSYS_IPSET); + if (ret < 0) + goto done; + goto copy; + } + default: + ret = -EBADMSG; + goto done; + } /* end of switch(op) */ + +copy: + if (copy_to_user(user, data, copylen)) + ret = -EFAULT; + +done: + vfree(data); + if (ret > 0) + ret = 0; + return ret; +} + +static struct nf_sockopt_ops so_set __read_mostly = { + .pf = PF_INET, + .get_optmin = SO_IP_SET, + .get_optmax = SO_IP_SET + 1, + .get = ip_set_sockfn_get, + .owner = THIS_MODULE, +}; + +static int __net_init +ip_set_net_init(struct net *net) +{ + struct ip_set_net *inst = ip_set_pernet(net); + struct ip_set **list; + + inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX; + if (inst->ip_set_max >= IPSET_INVALID_ID) + inst->ip_set_max = IPSET_INVALID_ID - 1; + + list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL); + if (!list) + return -ENOMEM; + inst->is_deleted = false; + inst->is_destroyed = false; + rcu_assign_pointer(inst->ip_set_list, list); + return 0; +} + +static void __net_exit +ip_set_net_exit(struct net *net) +{ + struct ip_set_net *inst = ip_set_pernet(net); + + struct ip_set *set = NULL; + ip_set_id_t i; + + inst->is_deleted = true; /* flag for ip_set_nfnl_put */ + + nfnl_lock(NFNL_SUBSYS_IPSET); + for (i = 0; i < inst->ip_set_max; i++) { + set = ip_set(inst, i); + if (set) { + ip_set(inst, i) = NULL; + ip_set_destroy_set(set); + } + } + nfnl_unlock(NFNL_SUBSYS_IPSET); + kvfree(rcu_dereference_protected(inst->ip_set_list, 1)); +} + +static struct pernet_operations ip_set_net_ops = { + .init = ip_set_net_init, + .exit = ip_set_net_exit, + .id = &ip_set_net_id, + .size = sizeof(struct ip_set_net), +}; + +static int __init +ip_set_init(void) +{ + int ret = register_pernet_subsys(&ip_set_net_ops); + + if (ret) { + pr_err("ip_set: cannot register pernet_subsys.\n"); + return ret; + } + + ret = nfnetlink_subsys_register(&ip_set_netlink_subsys); + if (ret != 0) { + pr_err("ip_set: cannot register with nfnetlink.\n"); + unregister_pernet_subsys(&ip_set_net_ops); + return ret; + } + + ret = nf_register_sockopt(&so_set); + if (ret != 0) { + pr_err("SO_SET registry failed: %d\n", ret); + nfnetlink_subsys_unregister(&ip_set_netlink_subsys); + unregister_pernet_subsys(&ip_set_net_ops); + return ret; + } + + return 0; +} + +static void __exit +ip_set_fini(void) +{ + nf_unregister_sockopt(&so_set); + nfnetlink_subsys_unregister(&ip_set_netlink_subsys); + + unregister_pernet_subsys(&ip_set_net_ops); + pr_debug("these are the famous last words\n"); +} + +module_init(ip_set_init); +module_exit(ip_set_fini); + +MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL)); diff --git a/net/netfilter/ipset/ip_set_getport.c b/net/netfilter/ipset/ip_set_getport.c new file mode 
100644 index 000000000..36615eb3e --- /dev/null +++ b/net/netfilter/ipset/ip_set_getport.c @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2011 Jozsef Kadlecsik + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ + +/* Get Layer-4 data from the packets */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +/* We must handle non-linear skbs */ +static bool +get_port(const struct sk_buff *skb, int protocol, unsigned int protooff, + bool src, __be16 *port, u8 *proto) +{ + switch (protocol) { + case IPPROTO_TCP: { + struct tcphdr _tcph; + const struct tcphdr *th; + + th = skb_header_pointer(skb, protooff, sizeof(_tcph), &_tcph); + if (!th) + /* No choice either */ + return false; + + *port = src ? th->source : th->dest; + break; + } + case IPPROTO_SCTP: { + struct sctphdr _sh; + const struct sctphdr *sh; + + sh = skb_header_pointer(skb, protooff, sizeof(_sh), &_sh); + if (!sh) + /* No choice either */ + return false; + + *port = src ? sh->source : sh->dest; + break; + } + case IPPROTO_UDP: + case IPPROTO_UDPLITE: { + struct udphdr _udph; + const struct udphdr *uh; + + uh = skb_header_pointer(skb, protooff, sizeof(_udph), &_udph); + if (!uh) + /* No choice either */ + return false; + + *port = src ? uh->source : uh->dest; + break; + } + case IPPROTO_ICMP: { + struct icmphdr _ich; + const struct icmphdr *ic; + + ic = skb_header_pointer(skb, protooff, sizeof(_ich), &_ich); + if (!ic) + return false; + + *port = (__force __be16)htons((ic->type << 8) | ic->code); + break; + } + case IPPROTO_ICMPV6: { + struct icmp6hdr _ich; + const struct icmp6hdr *ic; + + ic = skb_header_pointer(skb, protooff, sizeof(_ich), &_ich); + if (!ic) + return false; + + *port = (__force __be16) + htons((ic->icmp6_type << 8) | ic->icmp6_code); + break; + } + default: + break; + } + *proto = protocol; + + return true; +} + +bool +ip_set_get_ip4_port(const struct sk_buff *skb, bool src, + __be16 *port, u8 *proto) +{ + const struct iphdr *iph = ip_hdr(skb); + unsigned int protooff = skb_network_offset(skb) + ip_hdrlen(skb); + int protocol = iph->protocol; + + /* See comments at tcp_match in ip_tables.c */ + if (protocol <= 0) + return false; + + if (ntohs(iph->frag_off) & IP_OFFSET) + switch (protocol) { + case IPPROTO_TCP: + case IPPROTO_SCTP: + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + case IPPROTO_ICMP: + /* Port info not available for fragment offset > 0 */ + return false; + default: + /* Other protocols doesn't have ports, + * so we can match fragments. 
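For ICMP and ICMPv6 the code above has no real layer-4 ports to read, so it folds the 8-bit type and code into the 16-bit port slot; those entries can then share the same hash layout as TCP/UDP/SCTP ports. A small stand-alone C illustration of that packing, assuming nothing beyond standard headers (icmp_fake_port is a made-up name for the sketch, not a kernel helper):

#include <stdint.h>
#include <stdio.h>
#include <arpa/inet.h>	/* htons(), ntohs() */

/* Fold ICMP type/code into the 16-bit "port" slot, as get_port() above does. */
static uint16_t icmp_fake_port(uint8_t type, uint8_t code)
{
	return htons((uint16_t)(((uint16_t)type << 8) | code));
}

int main(void)
{
	uint16_t port = icmp_fake_port(8, 0);	/* ICMP echo request */

	printf("stored port, host order: 0x%04x\n", ntohs(port));	/* 0x0800 */
	return 0;
}

The fragment check in ip_set_get_ip4_port() works the same way for every port-carrying protocol: only the first fragment carries the layer-4 header, so a non-zero fragment offset means no port can be extracted.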
+ */ + *proto = protocol; + return true; + } + + return get_port(skb, protocol, protooff, src, port, proto); +} +EXPORT_SYMBOL_GPL(ip_set_get_ip4_port); + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +bool +ip_set_get_ip6_port(const struct sk_buff *skb, bool src, + __be16 *port, u8 *proto) +{ + int protoff; + u8 nexthdr; + __be16 frag_off = 0; + + nexthdr = ipv6_hdr(skb)->nexthdr; + protoff = ipv6_skip_exthdr(skb, + skb_network_offset(skb) + + sizeof(struct ipv6hdr), &nexthdr, + &frag_off); + if (protoff < 0 || (frag_off & htons(~0x7)) != 0) + return false; + + return get_port(skb, nexthdr, protoff, src, port, proto); +} +EXPORT_SYMBOL_GPL(ip_set_get_ip6_port); +#endif diff --git a/net/netfilter/ipset/ip_set_hash_gen.h b/net/netfilter/ipset/ip_set_hash_gen.h new file mode 100644 index 000000000..7499192af --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_gen.h @@ -0,0 +1,1582 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* Copyright (C) 2013 Jozsef Kadlecsik */ + +#ifndef _IP_SET_HASH_GEN_H +#define _IP_SET_HASH_GEN_H + +#include +#include +#include +#include +#include + +#define __ipset_dereference(p) \ + rcu_dereference_protected(p, 1) +#define ipset_dereference_nfnl(p) \ + rcu_dereference_protected(p, \ + lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET)) +#define ipset_dereference_set(p, set) \ + rcu_dereference_protected(p, \ + lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \ + lockdep_is_held(&(set)->lock)) +#define ipset_dereference_bh_nfnl(p) \ + rcu_dereference_bh_check(p, \ + lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET)) + +/* Hashing which uses arrays to resolve clashing. The hash table is resized + * (doubled) when searching becomes too long. + * Internally jhash is used with the assumption that the size of the + * stored data is a multiple of sizeof(u32). + * + * Readers and resizing + * + * Resizing can be triggered by userspace command only, and those + * are serialized by the nfnl mutex. During resizing the set is + * read-locked, so the only possible concurrent operations are + * the kernel side readers. Those must be protected by proper RCU locking. + */ + +/* Number of elements to store in an initial array block */ +#define AHASH_INIT_SIZE 2 +/* Max number of elements to store in an array block */ +#define AHASH_MAX_SIZE (6 * AHASH_INIT_SIZE) +/* Max muber of elements in the array block when tuned */ +#define AHASH_MAX_TUNED 64 +#define AHASH_MAX(h) ((h)->bucketsize) + +/* A hash bucket */ +struct hbucket { + struct rcu_head rcu; /* for call_rcu */ + /* Which positions are used in the array */ + DECLARE_BITMAP(used, AHASH_MAX_TUNED); + u8 size; /* size of the array */ + u8 pos; /* position of the first free entry */ + unsigned char value[] /* the array of the values */ + __aligned(__alignof__(u64)); +}; + +/* Region size for locking == 2^HTABLE_REGION_BITS */ +#define HTABLE_REGION_BITS 10 +#define ahash_numof_locks(htable_bits) \ + ((htable_bits) < HTABLE_REGION_BITS ? 1 \ + : jhash_size((htable_bits) - HTABLE_REGION_BITS)) +#define ahash_sizeof_regions(htable_bits) \ + (ahash_numof_locks(htable_bits) * sizeof(struct ip_set_region)) +#define ahash_region(n, htable_bits) \ + ((n) % ahash_numof_locks(htable_bits)) +#define ahash_bucket_start(h, htable_bits) \ + ((htable_bits) < HTABLE_REGION_BITS ? 0 \ + : (h) * jhash_size(HTABLE_REGION_BITS)) +#define ahash_bucket_end(h, htable_bits) \ + ((htable_bits) < HTABLE_REGION_BITS ? 
jhash_size(htable_bits) \ + : ((h) + 1) * jhash_size(HTABLE_REGION_BITS)) + +struct htable_gc { + struct delayed_work dwork; + struct ip_set *set; /* Set the gc belongs to */ + u32 region; /* Last gc run position */ +}; + +/* The hash table: the table size stored here in order to make resizing easy */ +struct htable { + atomic_t ref; /* References for resizing */ + atomic_t uref; /* References for dumping and gc */ + u8 htable_bits; /* size of hash table == 2^htable_bits */ + u32 maxelem; /* Maxelem per region */ + struct ip_set_region *hregion; /* Region locks and ext sizes */ + struct hbucket __rcu *bucket[]; /* hashtable buckets */ +}; + +#define hbucket(h, i) ((h)->bucket[i]) +#define ext_size(n, dsize) \ + (sizeof(struct hbucket) + (n) * (dsize)) + +#ifndef IPSET_NET_COUNT +#define IPSET_NET_COUNT 1 +#endif + +/* Book-keeping of the prefixes added to the set */ +struct net_prefixes { + u32 nets[IPSET_NET_COUNT]; /* number of elements for this cidr */ + u8 cidr[IPSET_NET_COUNT]; /* the cidr value */ +}; + +/* Compute the hash table size */ +static size_t +htable_size(u8 hbits) +{ + size_t hsize; + + /* We must fit both into u32 in jhash and INT_MAX in kvmalloc_node() */ + if (hbits > 31) + return 0; + hsize = jhash_size(hbits); + if ((INT_MAX - sizeof(struct htable)) / sizeof(struct hbucket *) + < hsize) + return 0; + + return hsize * sizeof(struct hbucket *) + sizeof(struct htable); +} + +#ifdef IP_SET_HASH_WITH_NETS +#if IPSET_NET_COUNT > 1 +#define __CIDR(cidr, i) (cidr[i]) +#else +#define __CIDR(cidr, i) (cidr) +#endif + +/* cidr + 1 is stored in net_prefixes to support /0 */ +#define NCIDR_PUT(cidr) ((cidr) + 1) +#define NCIDR_GET(cidr) ((cidr) - 1) + +#ifdef IP_SET_HASH_WITH_NETS_PACKED +/* When cidr is packed with nomatch, cidr - 1 is stored in the data entry */ +#define DCIDR_PUT(cidr) ((cidr) - 1) +#define DCIDR_GET(cidr, i) (__CIDR(cidr, i) + 1) +#else +#define DCIDR_PUT(cidr) (cidr) +#define DCIDR_GET(cidr, i) __CIDR(cidr, i) +#endif + +#define INIT_CIDR(cidr, host_mask) \ + DCIDR_PUT(((cidr) ? NCIDR_GET(cidr) : host_mask)) + +#ifdef IP_SET_HASH_WITH_NET0 +/* cidr from 0 to HOST_MASK value and c = cidr + 1 */ +#define NLEN (HOST_MASK + 1) +#define CIDR_POS(c) ((c) - 1) +#else +/* cidr from 1 to HOST_MASK value and c = cidr + 1 */ +#define NLEN HOST_MASK +#define CIDR_POS(c) ((c) - 2) +#endif + +#else +#define NLEN 0 +#endif /* IP_SET_HASH_WITH_NETS */ + +#define SET_ELEM_EXPIRED(set, d) \ + (SET_WITH_TIMEOUT(set) && \ + ip_set_timeout_expired(ext_timeout(d, set))) + +#endif /* _IP_SET_HASH_GEN_H */ + +#ifndef MTYPE +#error "MTYPE is not defined!" +#endif + +#ifndef HTYPE +#error "HTYPE is not defined!" +#endif + +#ifndef HOST_MASK +#error "HOST_MASK is not defined!" 
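The region macros above carve the bucket array into lock regions of 2^HTABLE_REGION_BITS buckets each, so a table of up to 1024 buckets gets a single lock and larger tables get one spinlock per 1024 buckets. A stand-alone recomputation of that arithmetic, assuming only that jhash_size(n) is 1U << n (the other names below are illustrative):

#include <stdio.h>

#define HTABLE_REGION_BITS	10
#define JHASH_SIZE(n)		(1U << (n))

/* Mirrors ahash_numof_locks(): one region below 2^10 buckets, else 2^(bits - 10). */
static unsigned int numof_locks(unsigned int htable_bits)
{
	return htable_bits < HTABLE_REGION_BITS ?
		1 : JHASH_SIZE(htable_bits - HTABLE_REGION_BITS);
}

int main(void)
{
	unsigned int bits;

	for (bits = 8; bits <= 20; bits += 4) {
		unsigned int locks = numof_locks(bits);

		printf("htable_bits=%2u buckets=%7u regions=%4u buckets/region=%u\n",
		       bits, JHASH_SIZE(bits), locks, JHASH_SIZE(bits) / locks);
	}
	return 0;
}

The per-region counters kept in t->hregion (struct ip_set_region: lock, elements, ext_size) then let add, delete and garbage collection take only the lock of the region a bucket hashes into.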
+#endif + +/* Family dependent templates */ + +#undef ahash_data +#undef mtype_data_equal +#undef mtype_do_data_match +#undef mtype_data_set_flags +#undef mtype_data_reset_elem +#undef mtype_data_reset_flags +#undef mtype_data_netmask +#undef mtype_data_list +#undef mtype_data_next +#undef mtype_elem + +#undef mtype_ahash_destroy +#undef mtype_ext_cleanup +#undef mtype_add_cidr +#undef mtype_del_cidr +#undef mtype_ahash_memsize +#undef mtype_flush +#undef mtype_destroy +#undef mtype_same_set +#undef mtype_kadt +#undef mtype_uadt + +#undef mtype_add +#undef mtype_del +#undef mtype_test_cidrs +#undef mtype_test +#undef mtype_uref +#undef mtype_resize +#undef mtype_ext_size +#undef mtype_resize_ad +#undef mtype_head +#undef mtype_list +#undef mtype_gc_do +#undef mtype_gc +#undef mtype_gc_init +#undef mtype_variant +#undef mtype_data_match + +#undef htype +#undef HKEY + +#define mtype_data_equal IPSET_TOKEN(MTYPE, _data_equal) +#ifdef IP_SET_HASH_WITH_NETS +#define mtype_do_data_match IPSET_TOKEN(MTYPE, _do_data_match) +#else +#define mtype_do_data_match(d) 1 +#endif +#define mtype_data_set_flags IPSET_TOKEN(MTYPE, _data_set_flags) +#define mtype_data_reset_elem IPSET_TOKEN(MTYPE, _data_reset_elem) +#define mtype_data_reset_flags IPSET_TOKEN(MTYPE, _data_reset_flags) +#define mtype_data_netmask IPSET_TOKEN(MTYPE, _data_netmask) +#define mtype_data_list IPSET_TOKEN(MTYPE, _data_list) +#define mtype_data_next IPSET_TOKEN(MTYPE, _data_next) +#define mtype_elem IPSET_TOKEN(MTYPE, _elem) + +#define mtype_ahash_destroy IPSET_TOKEN(MTYPE, _ahash_destroy) +#define mtype_ext_cleanup IPSET_TOKEN(MTYPE, _ext_cleanup) +#define mtype_add_cidr IPSET_TOKEN(MTYPE, _add_cidr) +#define mtype_del_cidr IPSET_TOKEN(MTYPE, _del_cidr) +#define mtype_ahash_memsize IPSET_TOKEN(MTYPE, _ahash_memsize) +#define mtype_flush IPSET_TOKEN(MTYPE, _flush) +#define mtype_destroy IPSET_TOKEN(MTYPE, _destroy) +#define mtype_same_set IPSET_TOKEN(MTYPE, _same_set) +#define mtype_kadt IPSET_TOKEN(MTYPE, _kadt) +#define mtype_uadt IPSET_TOKEN(MTYPE, _uadt) + +#define mtype_add IPSET_TOKEN(MTYPE, _add) +#define mtype_del IPSET_TOKEN(MTYPE, _del) +#define mtype_test_cidrs IPSET_TOKEN(MTYPE, _test_cidrs) +#define mtype_test IPSET_TOKEN(MTYPE, _test) +#define mtype_uref IPSET_TOKEN(MTYPE, _uref) +#define mtype_resize IPSET_TOKEN(MTYPE, _resize) +#define mtype_ext_size IPSET_TOKEN(MTYPE, _ext_size) +#define mtype_resize_ad IPSET_TOKEN(MTYPE, _resize_ad) +#define mtype_head IPSET_TOKEN(MTYPE, _head) +#define mtype_list IPSET_TOKEN(MTYPE, _list) +#define mtype_gc_do IPSET_TOKEN(MTYPE, _gc_do) +#define mtype_gc IPSET_TOKEN(MTYPE, _gc) +#define mtype_gc_init IPSET_TOKEN(MTYPE, _gc_init) +#define mtype_variant IPSET_TOKEN(MTYPE, _variant) +#define mtype_data_match IPSET_TOKEN(MTYPE, _data_match) + +#ifndef HKEY_DATALEN +#define HKEY_DATALEN sizeof(struct mtype_elem) +#endif + +#define htype MTYPE + +#define HKEY(data, initval, htable_bits) \ +({ \ + const u32 *__k = (const u32 *)data; \ + u32 __l = HKEY_DATALEN / sizeof(u32); \ + \ + BUILD_BUG_ON(HKEY_DATALEN % sizeof(u32) != 0); \ + \ + jhash2(__k, __l, initval) & jhash_mask(htable_bits); \ +}) + +/* The generic hash structure */ +struct htype { + struct htable __rcu *table; /* the hash table */ + struct htable_gc gc; /* gc workqueue */ + u32 maxelem; /* max elements in the hash */ + u32 initval; /* random jhash init value */ +#ifdef IP_SET_HASH_WITH_MARKMASK + u32 markmask; /* markmask value for mark mask to store */ +#endif + u8 bucketsize; /* max elements in an array block */ +#ifdef 
IP_SET_HASH_WITH_NETMASK + u8 netmask; /* netmask value for subnets to store */ +#endif + struct list_head ad; /* Resize add|del backlist */ + struct mtype_elem next; /* temporary storage for uadd */ +#ifdef IP_SET_HASH_WITH_NETS + struct net_prefixes nets[NLEN]; /* book-keeping of prefixes */ +#endif +}; + +/* ADD|DEL entries saved during resize */ +struct mtype_resize_ad { + struct list_head list; + enum ipset_adt ad; /* ADD|DEL element */ + struct mtype_elem d; /* Element value */ + struct ip_set_ext ext; /* Extensions for ADD */ + struct ip_set_ext mext; /* Target extensions for ADD */ + u32 flags; /* Flags for ADD */ +}; + +#ifdef IP_SET_HASH_WITH_NETS +/* Network cidr size book keeping when the hash stores different + * sized networks. cidr == real cidr + 1 to support /0. + */ +static void +mtype_add_cidr(struct ip_set *set, struct htype *h, u8 cidr, u8 n) +{ + int i, j; + + spin_lock_bh(&set->lock); + /* Add in increasing prefix order, so larger cidr first */ + for (i = 0, j = -1; i < NLEN && h->nets[i].cidr[n]; i++) { + if (j != -1) { + continue; + } else if (h->nets[i].cidr[n] < cidr) { + j = i; + } else if (h->nets[i].cidr[n] == cidr) { + h->nets[CIDR_POS(cidr)].nets[n]++; + goto unlock; + } + } + if (j != -1) { + for (; i > j; i--) + h->nets[i].cidr[n] = h->nets[i - 1].cidr[n]; + } + h->nets[i].cidr[n] = cidr; + h->nets[CIDR_POS(cidr)].nets[n] = 1; +unlock: + spin_unlock_bh(&set->lock); +} + +static void +mtype_del_cidr(struct ip_set *set, struct htype *h, u8 cidr, u8 n) +{ + u8 i, j, net_end = NLEN - 1; + + spin_lock_bh(&set->lock); + for (i = 0; i < NLEN; i++) { + if (h->nets[i].cidr[n] != cidr) + continue; + h->nets[CIDR_POS(cidr)].nets[n]--; + if (h->nets[CIDR_POS(cidr)].nets[n] > 0) + goto unlock; + for (j = i; j < net_end && h->nets[j].cidr[n]; j++) + h->nets[j].cidr[n] = h->nets[j + 1].cidr[n]; + h->nets[j].cidr[n] = 0; + goto unlock; + } +unlock: + spin_unlock_bh(&set->lock); +} +#endif + +/* Calculate the actual memory size of the set data */ +static size_t +mtype_ahash_memsize(const struct htype *h, const struct htable *t) +{ + return sizeof(*h) + sizeof(*t) + ahash_sizeof_regions(t->htable_bits); +} + +/* Get the ith element from the array block n */ +#define ahash_data(n, i, dsize) \ + ((struct mtype_elem *)((n)->value + ((i) * (dsize)))) + +static void +mtype_ext_cleanup(struct ip_set *set, struct hbucket *n) +{ + int i; + + for (i = 0; i < n->pos; i++) + if (test_bit(i, n->used)) + ip_set_ext_destroy(set, ahash_data(n, i, set->dsize)); +} + +/* Flush a hash type of set: destroy all elements */ +static void +mtype_flush(struct ip_set *set) +{ + struct htype *h = set->data; + struct htable *t; + struct hbucket *n; + u32 r, i; + + t = ipset_dereference_nfnl(h->table); + for (r = 0; r < ahash_numof_locks(t->htable_bits); r++) { + spin_lock_bh(&t->hregion[r].lock); + for (i = ahash_bucket_start(r, t->htable_bits); + i < ahash_bucket_end(r, t->htable_bits); i++) { + n = __ipset_dereference(hbucket(t, i)); + if (!n) + continue; + if (set->extensions & IPSET_EXT_DESTROY) + mtype_ext_cleanup(set, n); + /* FIXME: use slab cache */ + rcu_assign_pointer(hbucket(t, i), NULL); + kfree_rcu(n, rcu); + } + t->hregion[r].ext_size = 0; + t->hregion[r].elements = 0; + spin_unlock_bh(&t->hregion[r].lock); + } +#ifdef IP_SET_HASH_WITH_NETS + memset(h->nets, 0, sizeof(h->nets)); +#endif +} + +/* Destroy the hashtable part of the set */ +static void +mtype_ahash_destroy(struct ip_set *set, struct htable *t, bool ext_destroy) +{ + struct hbucket *n; + u32 i; + + for (i = 0; i < 
jhash_size(t->htable_bits); i++) { + n = __ipset_dereference(hbucket(t, i)); + if (!n) + continue; + if (set->extensions & IPSET_EXT_DESTROY && ext_destroy) + mtype_ext_cleanup(set, n); + /* FIXME: use slab cache */ + kfree(n); + } + + ip_set_free(t->hregion); + ip_set_free(t); +} + +/* Destroy a hash type of set */ +static void +mtype_destroy(struct ip_set *set) +{ + struct htype *h = set->data; + struct list_head *l, *lt; + + if (SET_WITH_TIMEOUT(set)) + cancel_delayed_work_sync(&h->gc.dwork); + + mtype_ahash_destroy(set, ipset_dereference_nfnl(h->table), true); + list_for_each_safe(l, lt, &h->ad) { + list_del(l); + kfree(l); + } + kfree(h); + + set->data = NULL; +} + +static bool +mtype_same_set(const struct ip_set *a, const struct ip_set *b) +{ + const struct htype *x = a->data; + const struct htype *y = b->data; + + /* Resizing changes htable_bits, so we ignore it */ + return x->maxelem == y->maxelem && + a->timeout == b->timeout && +#ifdef IP_SET_HASH_WITH_NETMASK + x->netmask == y->netmask && +#endif +#ifdef IP_SET_HASH_WITH_MARKMASK + x->markmask == y->markmask && +#endif + a->extensions == b->extensions; +} + +static void +mtype_gc_do(struct ip_set *set, struct htype *h, struct htable *t, u32 r) +{ + struct hbucket *n, *tmp; + struct mtype_elem *data; + u32 i, j, d; + size_t dsize = set->dsize; +#ifdef IP_SET_HASH_WITH_NETS + u8 k; +#endif + u8 htable_bits = t->htable_bits; + + spin_lock_bh(&t->hregion[r].lock); + for (i = ahash_bucket_start(r, htable_bits); + i < ahash_bucket_end(r, htable_bits); i++) { + n = __ipset_dereference(hbucket(t, i)); + if (!n) + continue; + for (j = 0, d = 0; j < n->pos; j++) { + if (!test_bit(j, n->used)) { + d++; + continue; + } + data = ahash_data(n, j, dsize); + if (!ip_set_timeout_expired(ext_timeout(data, set))) + continue; + pr_debug("expired %u/%u\n", i, j); + clear_bit(j, n->used); + smp_mb__after_atomic(); +#ifdef IP_SET_HASH_WITH_NETS + for (k = 0; k < IPSET_NET_COUNT; k++) + mtype_del_cidr(set, h, + NCIDR_PUT(DCIDR_GET(data->cidr, k)), + k); +#endif + t->hregion[r].elements--; + ip_set_ext_destroy(set, data); + d++; + } + if (d >= AHASH_INIT_SIZE) { + if (d >= n->size) { + t->hregion[r].ext_size -= + ext_size(n->size, dsize); + rcu_assign_pointer(hbucket(t, i), NULL); + kfree_rcu(n, rcu); + continue; + } + tmp = kzalloc(sizeof(*tmp) + + (n->size - AHASH_INIT_SIZE) * dsize, + GFP_ATOMIC); + if (!tmp) + /* Still try to delete expired elements. 
*/ + continue; + tmp->size = n->size - AHASH_INIT_SIZE; + for (j = 0, d = 0; j < n->pos; j++) { + if (!test_bit(j, n->used)) + continue; + data = ahash_data(n, j, dsize); + memcpy(tmp->value + d * dsize, + data, dsize); + set_bit(d, tmp->used); + d++; + } + tmp->pos = d; + t->hregion[r].ext_size -= + ext_size(AHASH_INIT_SIZE, dsize); + rcu_assign_pointer(hbucket(t, i), tmp); + kfree_rcu(n, rcu); + } + } + spin_unlock_bh(&t->hregion[r].lock); +} + +static void +mtype_gc(struct work_struct *work) +{ + struct htable_gc *gc; + struct ip_set *set; + struct htype *h; + struct htable *t; + u32 r, numof_locks; + unsigned int next_run; + + gc = container_of(work, struct htable_gc, dwork.work); + set = gc->set; + h = set->data; + + spin_lock_bh(&set->lock); + t = ipset_dereference_set(h->table, set); + atomic_inc(&t->uref); + numof_locks = ahash_numof_locks(t->htable_bits); + r = gc->region++; + if (r >= numof_locks) { + r = gc->region = 0; + } + next_run = (IPSET_GC_PERIOD(set->timeout) * HZ) / numof_locks; + if (next_run < HZ/10) + next_run = HZ/10; + spin_unlock_bh(&set->lock); + + mtype_gc_do(set, h, t, r); + + if (atomic_dec_and_test(&t->uref) && atomic_read(&t->ref)) { + pr_debug("Table destroy after resize by expire: %p\n", t); + mtype_ahash_destroy(set, t, false); + } + + queue_delayed_work(system_power_efficient_wq, &gc->dwork, next_run); + +} + +static void +mtype_gc_init(struct htable_gc *gc) +{ + INIT_DEFERRABLE_WORK(&gc->dwork, mtype_gc); + queue_delayed_work(system_power_efficient_wq, &gc->dwork, HZ); +} + +static int +mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags); +static int +mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags); + +/* Resize a hash: create a new hash table with doubling the hashsize + * and inserting the elements to it. Repeat until we succeed or + * fail due to memory pressures. 
+ */ +static int +mtype_resize(struct ip_set *set, bool retried) +{ + struct htype *h = set->data; + struct htable *t, *orig; + u8 htable_bits; + size_t hsize, dsize = set->dsize; +#ifdef IP_SET_HASH_WITH_NETS + u8 flags; + struct mtype_elem *tmp; +#endif + struct mtype_elem *data; + struct mtype_elem *d; + struct hbucket *n, *m; + struct list_head *l, *lt; + struct mtype_resize_ad *x; + u32 i, j, r, nr, key; + int ret; + +#ifdef IP_SET_HASH_WITH_NETS + tmp = kmalloc(dsize, GFP_KERNEL); + if (!tmp) + return -ENOMEM; +#endif + orig = ipset_dereference_bh_nfnl(h->table); + htable_bits = orig->htable_bits; + +retry: + ret = 0; + htable_bits++; + if (!htable_bits) + goto hbwarn; + hsize = htable_size(htable_bits); + if (!hsize) + goto hbwarn; + t = ip_set_alloc(hsize); + if (!t) { + ret = -ENOMEM; + goto out; + } + t->hregion = ip_set_alloc(ahash_sizeof_regions(htable_bits)); + if (!t->hregion) { + ip_set_free(t); + ret = -ENOMEM; + goto out; + } + t->htable_bits = htable_bits; + t->maxelem = h->maxelem / ahash_numof_locks(htable_bits); + for (i = 0; i < ahash_numof_locks(htable_bits); i++) + spin_lock_init(&t->hregion[i].lock); + + /* There can't be another parallel resizing, + * but dumping, gc, kernel side add/del are possible + */ + orig = ipset_dereference_bh_nfnl(h->table); + atomic_set(&orig->ref, 1); + atomic_inc(&orig->uref); + pr_debug("attempt to resize set %s from %u to %u, t %p\n", + set->name, orig->htable_bits, htable_bits, orig); + for (r = 0; r < ahash_numof_locks(orig->htable_bits); r++) { + /* Expire may replace a hbucket with another one */ + rcu_read_lock_bh(); + for (i = ahash_bucket_start(r, orig->htable_bits); + i < ahash_bucket_end(r, orig->htable_bits); i++) { + n = __ipset_dereference(hbucket(orig, i)); + if (!n) + continue; + for (j = 0; j < n->pos; j++) { + if (!test_bit(j, n->used)) + continue; + data = ahash_data(n, j, dsize); + if (SET_ELEM_EXPIRED(set, data)) + continue; +#ifdef IP_SET_HASH_WITH_NETS + /* We have readers running parallel with us, + * so the live data cannot be modified. + */ + flags = 0; + memcpy(tmp, data, dsize); + data = tmp; + mtype_data_reset_flags(data, &flags); +#endif + key = HKEY(data, h->initval, htable_bits); + m = __ipset_dereference(hbucket(t, key)); + nr = ahash_region(key, htable_bits); + if (!m) { + m = kzalloc(sizeof(*m) + + AHASH_INIT_SIZE * dsize, + GFP_ATOMIC); + if (!m) { + ret = -ENOMEM; + goto cleanup; + } + m->size = AHASH_INIT_SIZE; + t->hregion[nr].ext_size += + ext_size(AHASH_INIT_SIZE, + dsize); + RCU_INIT_POINTER(hbucket(t, key), m); + } else if (m->pos >= m->size) { + struct hbucket *ht; + + if (m->size >= AHASH_MAX(h)) { + ret = -EAGAIN; + } else { + ht = kzalloc(sizeof(*ht) + + (m->size + AHASH_INIT_SIZE) + * dsize, + GFP_ATOMIC); + if (!ht) + ret = -ENOMEM; + } + if (ret < 0) + goto cleanup; + memcpy(ht, m, sizeof(struct hbucket) + + m->size * dsize); + ht->size = m->size + AHASH_INIT_SIZE; + t->hregion[nr].ext_size += + ext_size(AHASH_INIT_SIZE, + dsize); + kfree(m); + m = ht; + RCU_INIT_POINTER(hbucket(t, key), ht); + } + d = ahash_data(m, m->pos, dsize); + memcpy(d, data, dsize); + set_bit(m->pos++, m->used); + t->hregion[nr].elements++; +#ifdef IP_SET_HASH_WITH_NETS + mtype_data_reset_flags(d, &flags); +#endif + } + } + rcu_read_unlock_bh(); + } + + /* There can't be any other writer. 
*/ + rcu_assign_pointer(h->table, t); + + /* Give time to other readers of the set */ + synchronize_rcu(); + + pr_debug("set %s resized from %u (%p) to %u (%p)\n", set->name, + orig->htable_bits, orig, t->htable_bits, t); + /* Add/delete elements processed by the SET target during resize. + * Kernel-side add cannot trigger a resize and userspace actions + * are serialized by the mutex. + */ + list_for_each_safe(l, lt, &h->ad) { + x = list_entry(l, struct mtype_resize_ad, list); + if (x->ad == IPSET_ADD) { + mtype_add(set, &x->d, &x->ext, &x->mext, x->flags); + } else { + mtype_del(set, &x->d, NULL, NULL, 0); + } + list_del(l); + kfree(l); + } + /* If there's nobody else using the table, destroy it */ + if (atomic_dec_and_test(&orig->uref)) { + pr_debug("Table destroy by resize %p\n", orig); + mtype_ahash_destroy(set, orig, false); + } + +out: +#ifdef IP_SET_HASH_WITH_NETS + kfree(tmp); +#endif + return ret; + +cleanup: + rcu_read_unlock_bh(); + atomic_set(&orig->ref, 0); + atomic_dec(&orig->uref); + mtype_ahash_destroy(set, t, false); + if (ret == -EAGAIN) + goto retry; + goto out; + +hbwarn: + /* In case we have plenty of memory :-) */ + pr_warn("Cannot increase the hashsize of set %s further\n", set->name); + ret = -IPSET_ERR_HASH_FULL; + goto out; +} + +/* Get the current number of elements and ext_size in the set */ +static void +mtype_ext_size(struct ip_set *set, u32 *elements, size_t *ext_size) +{ + struct htype *h = set->data; + const struct htable *t; + u32 i, j, r; + struct hbucket *n; + struct mtype_elem *data; + + t = rcu_dereference_bh(h->table); + for (r = 0; r < ahash_numof_locks(t->htable_bits); r++) { + for (i = ahash_bucket_start(r, t->htable_bits); + i < ahash_bucket_end(r, t->htable_bits); i++) { + n = rcu_dereference_bh(hbucket(t, i)); + if (!n) + continue; + for (j = 0; j < n->pos; j++) { + if (!test_bit(j, n->used)) + continue; + data = ahash_data(n, j, set->dsize); + if (!SET_ELEM_EXPIRED(set, data)) + (*elements)++; + } + } + *ext_size += t->hregion[r].ext_size; + } +} + +/* Add an element to a hash and update the internal counters when succeeded, + * otherwise report the proper error code. 
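The add path that follows never shifts live entries around: a slot whose bit is clear in the bucket's used bitmap (deleted, or expired and reclaimed) is simply reused, and only when no slot is free does the bucket array grow by AHASH_INIT_SIZE up to AHASH_MAX(h), at which point a resize is requested via -EAGAIN. A much simplified, single-threaded model of that slot handling (fixed-size bucket, no RCU, no extensions; every name here is made up for the sketch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define BUCKET_MAX	12	/* stands in for AHASH_MAX() */

struct bucket {
	uint64_t used;			/* bitmap of live slots, like n->used */
	unsigned int pos;		/* first never-used slot, like n->pos */
	uint32_t value[BUCKET_MAX];
};

/* Returns false when the bucket is full; the kernel would grow or resize then. */
static bool bucket_add(struct bucket *b, uint32_t v)
{
	unsigned int i;

	for (i = 0; i < b->pos; i++) {
		if (!(b->used & (1ULL << i))) {	/* reuse a cleared slot */
			b->value[i] = v;
			b->used |= 1ULL << i;
			return true;
		}
	}
	if (b->pos >= BUCKET_MAX)
		return false;
	b->value[b->pos] = v;
	b->used |= 1ULL << b->pos++;
	return true;
}

int main(void)
{
	struct bucket b;

	memset(&b, 0, sizeof(b));
	bucket_add(&b, 1);
	bucket_add(&b, 2);
	b.used &= ~2ULL;	/* pretend slot 1 was deleted */
	bucket_add(&b, 3);	/* reuses slot 1 instead of growing */
	printf("pos=%u slot1=%u\n", b.pos, b.value[1]);	/* pos=2 slot1=3 */
	return 0;
}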
+ */ +static int +mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct htype *h = set->data; + struct htable *t; + const struct mtype_elem *d = value; + struct mtype_elem *data; + struct hbucket *n, *old = ERR_PTR(-ENOENT); + int i, j = -1, ret; + bool flag_exist = flags & IPSET_FLAG_EXIST; + bool deleted = false, forceadd = false, reuse = false; + u32 r, key, multi = 0, elements, maxelem; + + rcu_read_lock_bh(); + t = rcu_dereference_bh(h->table); + key = HKEY(value, h->initval, t->htable_bits); + r = ahash_region(key, t->htable_bits); + atomic_inc(&t->uref); + elements = t->hregion[r].elements; + maxelem = t->maxelem; + if (elements >= maxelem) { + u32 e; + if (SET_WITH_TIMEOUT(set)) { + rcu_read_unlock_bh(); + mtype_gc_do(set, h, t, r); + rcu_read_lock_bh(); + } + maxelem = h->maxelem; + elements = 0; + for (e = 0; e < ahash_numof_locks(t->htable_bits); e++) + elements += t->hregion[e].elements; + if (elements >= maxelem && SET_WITH_FORCEADD(set)) + forceadd = true; + } + rcu_read_unlock_bh(); + + spin_lock_bh(&t->hregion[r].lock); + n = rcu_dereference_bh(hbucket(t, key)); + if (!n) { + if (forceadd || elements >= maxelem) + goto set_full; + old = NULL; + n = kzalloc(sizeof(*n) + AHASH_INIT_SIZE * set->dsize, + GFP_ATOMIC); + if (!n) { + ret = -ENOMEM; + goto unlock; + } + n->size = AHASH_INIT_SIZE; + t->hregion[r].ext_size += + ext_size(AHASH_INIT_SIZE, set->dsize); + goto copy_elem; + } + for (i = 0; i < n->pos; i++) { + if (!test_bit(i, n->used)) { + /* Reuse first deleted entry */ + if (j == -1) { + deleted = reuse = true; + j = i; + } + continue; + } + data = ahash_data(n, i, set->dsize); + if (mtype_data_equal(data, d, &multi)) { + if (flag_exist || SET_ELEM_EXPIRED(set, data)) { + /* Just the extensions could be overwritten */ + j = i; + goto overwrite_extensions; + } + ret = -IPSET_ERR_EXIST; + goto unlock; + } + /* Reuse first timed out entry */ + if (SET_ELEM_EXPIRED(set, data) && j == -1) { + j = i; + reuse = true; + } + } + if (reuse || forceadd) { + if (j == -1) + j = 0; + data = ahash_data(n, j, set->dsize); + if (!deleted) { +#ifdef IP_SET_HASH_WITH_NETS + for (i = 0; i < IPSET_NET_COUNT; i++) + mtype_del_cidr(set, h, + NCIDR_PUT(DCIDR_GET(data->cidr, i)), + i); +#endif + ip_set_ext_destroy(set, data); + t->hregion[r].elements--; + } + goto copy_data; + } + if (elements >= maxelem) + goto set_full; + /* Create a new slot */ + if (n->pos >= n->size) { +#ifdef IP_SET_HASH_WITH_MULTI + if (h->bucketsize >= AHASH_MAX_TUNED) + goto set_full; + else if (h->bucketsize <= multi) + h->bucketsize += AHASH_INIT_SIZE; +#endif + if (n->size >= AHASH_MAX(h)) { + /* Trigger rehashing */ + mtype_data_next(&h->next, d); + ret = -EAGAIN; + goto resize; + } + old = n; + n = kzalloc(sizeof(*n) + + (old->size + AHASH_INIT_SIZE) * set->dsize, + GFP_ATOMIC); + if (!n) { + ret = -ENOMEM; + goto unlock; + } + memcpy(n, old, sizeof(struct hbucket) + + old->size * set->dsize); + n->size = old->size + AHASH_INIT_SIZE; + t->hregion[r].ext_size += + ext_size(AHASH_INIT_SIZE, set->dsize); + } + +copy_elem: + j = n->pos++; + data = ahash_data(n, j, set->dsize); +copy_data: + t->hregion[r].elements++; +#ifdef IP_SET_HASH_WITH_NETS + for (i = 0; i < IPSET_NET_COUNT; i++) + mtype_add_cidr(set, h, NCIDR_PUT(DCIDR_GET(d->cidr, i)), i); +#endif + memcpy(data, d, sizeof(struct mtype_elem)); +overwrite_extensions: +#ifdef IP_SET_HASH_WITH_NETS + mtype_data_set_flags(data, flags); +#endif + if (SET_WITH_COUNTER(set)) + 
ip_set_init_counter(ext_counter(data, set), ext); + if (SET_WITH_COMMENT(set)) + ip_set_init_comment(set, ext_comment(data, set), ext); + if (SET_WITH_SKBINFO(set)) + ip_set_init_skbinfo(ext_skbinfo(data, set), ext); + /* Must come last for the case when timed out entry is reused */ + if (SET_WITH_TIMEOUT(set)) + ip_set_timeout_set(ext_timeout(data, set), ext->timeout); + smp_mb__before_atomic(); + set_bit(j, n->used); + if (old != ERR_PTR(-ENOENT)) { + rcu_assign_pointer(hbucket(t, key), n); + if (old) + kfree_rcu(old, rcu); + } + ret = 0; +resize: + spin_unlock_bh(&t->hregion[r].lock); + if (atomic_read(&t->ref) && ext->target) { + /* Resize is in process and kernel side add, save values */ + struct mtype_resize_ad *x; + + x = kzalloc(sizeof(struct mtype_resize_ad), GFP_ATOMIC); + if (!x) + /* Don't bother */ + goto out; + x->ad = IPSET_ADD; + memcpy(&x->d, value, sizeof(struct mtype_elem)); + memcpy(&x->ext, ext, sizeof(struct ip_set_ext)); + memcpy(&x->mext, mext, sizeof(struct ip_set_ext)); + x->flags = flags; + spin_lock_bh(&set->lock); + list_add_tail(&x->list, &h->ad); + spin_unlock_bh(&set->lock); + } + goto out; + +set_full: + if (net_ratelimit()) + pr_warn("Set %s is full, maxelem %u reached\n", + set->name, maxelem); + ret = -IPSET_ERR_HASH_FULL; +unlock: + spin_unlock_bh(&t->hregion[r].lock); +out: + if (atomic_dec_and_test(&t->uref) && atomic_read(&t->ref)) { + pr_debug("Table destroy after resize by add: %p\n", t); + mtype_ahash_destroy(set, t, false); + } + return ret; +} + +/* Delete an element from the hash and free up space if possible. + */ +static int +mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct htype *h = set->data; + struct htable *t; + const struct mtype_elem *d = value; + struct mtype_elem *data; + struct hbucket *n; + struct mtype_resize_ad *x = NULL; + int i, j, k, r, ret = -IPSET_ERR_EXIST; + u32 key, multi = 0; + size_t dsize = set->dsize; + + /* Userspace add and resize is excluded by the mutex. + * Kernespace add does not trigger resize. 
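Element accounting in mtype_add() is two-level: the region a key hashes into is first checked against its share, t->maxelem = h->maxelem / nr_regions, and only when that local limit is reached (and after an optional expiry pass) are all regions summed against the global maxelem. A quick stand-alone illustration of the check, with arbitrary numbers:

#include <stdio.h>

#define NR_REGIONS	64

int main(void)
{
	unsigned int maxelem = 65536;			/* h->maxelem, global */
	unsigned int per_region = maxelem / NR_REGIONS;	/* t->maxelem */
	unsigned int elems[NR_REGIONS] = { 0 };
	unsigned int r, total = 0;

	elems[3] = per_region;		/* the region we hash into looks full */

	if (elems[3] >= per_region) {
		/* Slow path: compare the whole set against the global limit. */
		for (r = 0; r < NR_REGIONS; r++)
			total += elems[r];
	}
	printf("region limit %u, region 3 holds %u, set holds %u of %u: %s\n",
	       per_region, elems[3], total, maxelem,
	       total >= maxelem ? "set full" : "add allowed");
	return 0;
}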
+ */ + rcu_read_lock_bh(); + t = rcu_dereference_bh(h->table); + key = HKEY(value, h->initval, t->htable_bits); + r = ahash_region(key, t->htable_bits); + atomic_inc(&t->uref); + rcu_read_unlock_bh(); + + spin_lock_bh(&t->hregion[r].lock); + n = rcu_dereference_bh(hbucket(t, key)); + if (!n) + goto out; + for (i = 0, k = 0; i < n->pos; i++) { + if (!test_bit(i, n->used)) { + k++; + continue; + } + data = ahash_data(n, i, dsize); + if (!mtype_data_equal(data, d, &multi)) + continue; + if (SET_ELEM_EXPIRED(set, data)) + goto out; + + ret = 0; + clear_bit(i, n->used); + smp_mb__after_atomic(); + if (i + 1 == n->pos) + n->pos--; + t->hregion[r].elements--; +#ifdef IP_SET_HASH_WITH_NETS + for (j = 0; j < IPSET_NET_COUNT; j++) + mtype_del_cidr(set, h, + NCIDR_PUT(DCIDR_GET(d->cidr, j)), j); +#endif + ip_set_ext_destroy(set, data); + + if (atomic_read(&t->ref) && ext->target) { + /* Resize is in process and kernel side del, + * save values + */ + x = kzalloc(sizeof(struct mtype_resize_ad), + GFP_ATOMIC); + if (x) { + x->ad = IPSET_DEL; + memcpy(&x->d, value, + sizeof(struct mtype_elem)); + x->flags = flags; + } + } + for (; i < n->pos; i++) { + if (!test_bit(i, n->used)) + k++; + } + if (n->pos == 0 && k == 0) { + t->hregion[r].ext_size -= ext_size(n->size, dsize); + rcu_assign_pointer(hbucket(t, key), NULL); + kfree_rcu(n, rcu); + } else if (k >= AHASH_INIT_SIZE) { + struct hbucket *tmp = kzalloc(sizeof(*tmp) + + (n->size - AHASH_INIT_SIZE) * dsize, + GFP_ATOMIC); + if (!tmp) + goto out; + tmp->size = n->size - AHASH_INIT_SIZE; + for (j = 0, k = 0; j < n->pos; j++) { + if (!test_bit(j, n->used)) + continue; + data = ahash_data(n, j, dsize); + memcpy(tmp->value + k * dsize, data, dsize); + set_bit(k, tmp->used); + k++; + } + tmp->pos = k; + t->hregion[r].ext_size -= + ext_size(AHASH_INIT_SIZE, dsize); + rcu_assign_pointer(hbucket(t, key), tmp); + kfree_rcu(n, rcu); + } + goto out; + } + +out: + spin_unlock_bh(&t->hregion[r].lock); + if (x) { + spin_lock_bh(&set->lock); + list_add(&x->list, &h->ad); + spin_unlock_bh(&set->lock); + } + if (atomic_dec_and_test(&t->uref) && atomic_read(&t->ref)) { + pr_debug("Table destroy after resize by del: %p\n", t); + mtype_ahash_destroy(set, t, false); + } + return ret; +} + +static int +mtype_data_match(struct mtype_elem *data, const struct ip_set_ext *ext, + struct ip_set_ext *mext, struct ip_set *set, u32 flags) +{ + if (!ip_set_match_extensions(set, ext, mext, flags, data)) + return 0; + /* nomatch entries return -ENOTEMPTY */ + return mtype_do_data_match(data); +} + +#ifdef IP_SET_HASH_WITH_NETS +/* Special test function which takes into account the different network + * sizes added to the set + */ +static int +mtype_test_cidrs(struct ip_set *set, struct mtype_elem *d, + const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct htype *h = set->data; + struct htable *t = rcu_dereference_bh(h->table); + struct hbucket *n; + struct mtype_elem *data; +#if IPSET_NET_COUNT == 2 + struct mtype_elem orig = *d; + int ret, i, j = 0, k; +#else + int ret, i, j = 0; +#endif + u32 key, multi = 0; + + pr_debug("test by nets\n"); + for (; j < NLEN && h->nets[j].cidr[0] && !multi; j++) { +#if IPSET_NET_COUNT == 2 + mtype_data_reset_elem(d, &orig); + mtype_data_netmask(d, NCIDR_GET(h->nets[j].cidr[0]), false); + for (k = 0; k < NLEN && h->nets[k].cidr[1] && !multi; + k++) { + mtype_data_netmask(d, NCIDR_GET(h->nets[k].cidr[1]), + true); +#else + mtype_data_netmask(d, NCIDR_GET(h->nets[j].cidr[0])); +#endif + key = HKEY(d, h->initval, t->htable_bits); + 
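mtype_test_cidrs() here walks the prefix lengths recorded in h->nets[] (kept longest-prefix-first by mtype_add_cidr()), masks the queried address with each one in turn and probes the hash with the masked value. A stand-alone sketch of that masking step for an IPv4 query; the prefix list and the address are arbitrary examples:

#include <stdint.h>
#include <stdio.h>
#include <arpa/inet.h>

/* Host-order netmask for a /cidr prefix, cidr in 1..32. */
static uint32_t cidr_mask(unsigned int cidr)
{
	return ~0U << (32 - cidr);
}

int main(void)
{
	unsigned int nets[] = { 24, 16 };	/* prefixes currently in the set */
	uint32_t ip = 0xc0a8014d;		/* 192.168.1.77 */
	unsigned int i;

	for (i = 0; i < sizeof(nets) / sizeof(nets[0]); i++) {
		struct in_addr a = { .s_addr = htonl(ip & cidr_mask(nets[i])) };

		/* Each masked value is what gets hashed and looked up. */
		printf("/%u -> %s\n", nets[i], inet_ntoa(a));
	}
	return 0;
}

The loop returns as soon as one masked lookup matches, so the cost scales with the number of distinct prefix lengths stored in the set rather than with the number of elements.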
n = rcu_dereference_bh(hbucket(t, key)); + if (!n) + continue; + for (i = 0; i < n->pos; i++) { + if (!test_bit(i, n->used)) + continue; + data = ahash_data(n, i, set->dsize); + if (!mtype_data_equal(data, d, &multi)) + continue; + ret = mtype_data_match(data, ext, mext, set, flags); + if (ret != 0) + return ret; +#ifdef IP_SET_HASH_WITH_MULTI + /* No match, reset multiple match flag */ + multi = 0; +#endif + } +#if IPSET_NET_COUNT == 2 + } +#endif + } + return 0; +} +#endif + +/* Test whether the element is added to the set */ +static int +mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct htype *h = set->data; + struct htable *t; + struct mtype_elem *d = value; + struct hbucket *n; + struct mtype_elem *data; + int i, ret = 0; + u32 key, multi = 0; + + rcu_read_lock_bh(); + t = rcu_dereference_bh(h->table); +#ifdef IP_SET_HASH_WITH_NETS + /* If we test an IP address and not a network address, + * try all possible network sizes + */ + for (i = 0; i < IPSET_NET_COUNT; i++) + if (DCIDR_GET(d->cidr, i) != HOST_MASK) + break; + if (i == IPSET_NET_COUNT) { + ret = mtype_test_cidrs(set, d, ext, mext, flags); + goto out; + } +#endif + + key = HKEY(d, h->initval, t->htable_bits); + n = rcu_dereference_bh(hbucket(t, key)); + if (!n) { + ret = 0; + goto out; + } + for (i = 0; i < n->pos; i++) { + if (!test_bit(i, n->used)) + continue; + data = ahash_data(n, i, set->dsize); + if (!mtype_data_equal(data, d, &multi)) + continue; + ret = mtype_data_match(data, ext, mext, set, flags); + if (ret != 0) + goto out; + } +out: + rcu_read_unlock_bh(); + return ret; +} + +/* Reply a HEADER request: fill out the header part of the set */ +static int +mtype_head(struct ip_set *set, struct sk_buff *skb) +{ + struct htype *h = set->data; + const struct htable *t; + struct nlattr *nested; + size_t memsize; + u32 elements = 0; + size_t ext_size = 0; + u8 htable_bits; + + rcu_read_lock_bh(); + t = rcu_dereference_bh(h->table); + mtype_ext_size(set, &elements, &ext_size); + memsize = mtype_ahash_memsize(h, t) + ext_size + set->ext_size; + htable_bits = t->htable_bits; + rcu_read_unlock_bh(); + + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) + goto nla_put_failure; + if (nla_put_net32(skb, IPSET_ATTR_HASHSIZE, + htonl(jhash_size(htable_bits))) || + nla_put_net32(skb, IPSET_ATTR_MAXELEM, htonl(h->maxelem))) + goto nla_put_failure; +#ifdef IP_SET_HASH_WITH_NETMASK + if (h->netmask != HOST_MASK && + nla_put_u8(skb, IPSET_ATTR_NETMASK, h->netmask)) + goto nla_put_failure; +#endif +#ifdef IP_SET_HASH_WITH_MARKMASK + if (nla_put_u32(skb, IPSET_ATTR_MARKMASK, h->markmask)) + goto nla_put_failure; +#endif + if (set->flags & IPSET_CREATE_FLAG_BUCKETSIZE) { + if (nla_put_u8(skb, IPSET_ATTR_BUCKETSIZE, h->bucketsize) || + nla_put_net32(skb, IPSET_ATTR_INITVAL, htonl(h->initval))) + goto nla_put_failure; + } + if (nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) || + nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) || + nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(elements))) + goto nla_put_failure; + if (unlikely(ip_set_put_flags(skb, set))) + goto nla_put_failure; + nla_nest_end(skb, nested); + + return 0; +nla_put_failure: + return -EMSGSIZE; +} + +/* Make possible to run dumping parallel with resizing */ +static void +mtype_uref(struct ip_set *set, struct netlink_callback *cb, bool start) +{ + struct htype *h = set->data; + struct htable *t; + + if (start) { + rcu_read_lock_bh(); + t = ipset_dereference_bh_nfnl(h->table); + 
atomic_inc(&t->uref); + cb->args[IPSET_CB_PRIVATE] = (unsigned long)t; + rcu_read_unlock_bh(); + } else if (cb->args[IPSET_CB_PRIVATE]) { + t = (struct htable *)cb->args[IPSET_CB_PRIVATE]; + if (atomic_dec_and_test(&t->uref) && atomic_read(&t->ref)) { + pr_debug("Table destroy after resize " + " by dump: %p\n", t); + mtype_ahash_destroy(set, t, false); + } + cb->args[IPSET_CB_PRIVATE] = 0; + } +} + +/* Reply a LIST/SAVE request: dump the elements of the specified set */ +static int +mtype_list(const struct ip_set *set, + struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct htable *t; + struct nlattr *atd, *nested; + const struct hbucket *n; + const struct mtype_elem *e; + u32 first = cb->args[IPSET_CB_ARG0]; + /* We assume that one hash bucket fills into one page */ + void *incomplete; + int i, ret = 0; + + atd = nla_nest_start(skb, IPSET_ATTR_ADT); + if (!atd) + return -EMSGSIZE; + + pr_debug("list hash set %s\n", set->name); + t = (const struct htable *)cb->args[IPSET_CB_PRIVATE]; + /* Expire may replace a hbucket with another one */ + rcu_read_lock(); + for (; cb->args[IPSET_CB_ARG0] < jhash_size(t->htable_bits); + cb->args[IPSET_CB_ARG0]++) { + cond_resched_rcu(); + incomplete = skb_tail_pointer(skb); + n = rcu_dereference(hbucket(t, cb->args[IPSET_CB_ARG0])); + pr_debug("cb->arg bucket: %lu, t %p n %p\n", + cb->args[IPSET_CB_ARG0], t, n); + if (!n) + continue; + for (i = 0; i < n->pos; i++) { + if (!test_bit(i, n->used)) + continue; + e = ahash_data(n, i, set->dsize); + if (SET_ELEM_EXPIRED(set, e)) + continue; + pr_debug("list hash %lu hbucket %p i %u, data %p\n", + cb->args[IPSET_CB_ARG0], n, i, e); + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) { + if (cb->args[IPSET_CB_ARG0] == first) { + nla_nest_cancel(skb, atd); + ret = -EMSGSIZE; + goto out; + } + goto nla_put_failure; + } + if (mtype_data_list(skb, e)) + goto nla_put_failure; + if (ip_set_put_extensions(skb, set, e, true)) + goto nla_put_failure; + nla_nest_end(skb, nested); + } + } + nla_nest_end(skb, atd); + /* Set listing finished */ + cb->args[IPSET_CB_ARG0] = 0; + + goto out; + +nla_put_failure: + nlmsg_trim(skb, incomplete); + if (unlikely(first == cb->args[IPSET_CB_ARG0])) { + pr_warn("Can't list set %s: one bucket does not fit into a message. 
Please report it!\n", + set->name); + cb->args[IPSET_CB_ARG0] = 0; + ret = -EMSGSIZE; + } else { + nla_nest_end(skb, atd); + } +out: + rcu_read_unlock(); + return ret; +} + +static int +IPSET_TOKEN(MTYPE, _kadt)(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt); + +static int +IPSET_TOKEN(MTYPE, _uadt)(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, + bool retried); + +static const struct ip_set_type_variant mtype_variant = { + .kadt = mtype_kadt, + .uadt = mtype_uadt, + .adt = { + [IPSET_ADD] = mtype_add, + [IPSET_DEL] = mtype_del, + [IPSET_TEST] = mtype_test, + }, + .destroy = mtype_destroy, + .flush = mtype_flush, + .head = mtype_head, + .list = mtype_list, + .uref = mtype_uref, + .resize = mtype_resize, + .same_set = mtype_same_set, + .region_lock = true, +}; + +#ifdef IP_SET_EMIT_CREATE +static int +IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set, + struct nlattr *tb[], u32 flags) +{ + u32 hashsize = IPSET_DEFAULT_HASHSIZE, maxelem = IPSET_DEFAULT_MAXELEM; +#ifdef IP_SET_HASH_WITH_MARKMASK + u32 markmask; +#endif + u8 hbits; +#ifdef IP_SET_HASH_WITH_NETMASK + u8 netmask; +#endif + size_t hsize; + struct htype *h; + struct htable *t; + u32 i; + + pr_debug("Create set %s with family %s\n", + set->name, set->family == NFPROTO_IPV4 ? "inet" : "inet6"); + +#ifdef IP_SET_PROTO_UNDEF + if (set->family != NFPROTO_UNSPEC) + return -IPSET_ERR_INVALID_FAMILY; +#else + if (!(set->family == NFPROTO_IPV4 || set->family == NFPROTO_IPV6)) + return -IPSET_ERR_INVALID_FAMILY; +#endif + + if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_HASHSIZE) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_MAXELEM) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + +#ifdef IP_SET_HASH_WITH_MARKMASK + /* Separated condition in order to avoid directive in argument list */ + if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_MARKMASK))) + return -IPSET_ERR_PROTOCOL; + + markmask = 0xffffffff; + if (tb[IPSET_ATTR_MARKMASK]) { + markmask = ntohl(nla_get_be32(tb[IPSET_ATTR_MARKMASK])); + if (markmask == 0) + return -IPSET_ERR_INVALID_MARKMASK; + } +#endif + +#ifdef IP_SET_HASH_WITH_NETMASK + netmask = set->family == NFPROTO_IPV4 ? 32 : 128; + if (tb[IPSET_ATTR_NETMASK]) { + netmask = nla_get_u8(tb[IPSET_ATTR_NETMASK]); + + if ((set->family == NFPROTO_IPV4 && netmask > 32) || + (set->family == NFPROTO_IPV6 && netmask > 128) || + netmask == 0) + return -IPSET_ERR_INVALID_NETMASK; + } +#endif + + if (tb[IPSET_ATTR_HASHSIZE]) { + hashsize = ip_set_get_h32(tb[IPSET_ATTR_HASHSIZE]); + if (hashsize < IPSET_MIMINAL_HASHSIZE) + hashsize = IPSET_MIMINAL_HASHSIZE; + } + + if (tb[IPSET_ATTR_MAXELEM]) + maxelem = ip_set_get_h32(tb[IPSET_ATTR_MAXELEM]); + + hsize = sizeof(*h); + h = kzalloc(hsize, GFP_KERNEL); + if (!h) + return -ENOMEM; + + /* Compute htable_bits from the user input parameter hashsize. + * Assume that hashsize == 2^htable_bits, + * otherwise round up to the first 2^n value. 
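The create path right after this comment derives the table order as hbits = fls(hashsize - 1), which rounds the requested hashsize up to a power of two (and keeps it unchanged when it already is one). A tiny stand-alone check of that rounding; fls_u32() mirrors the kernel's fls(), returning the 1-based index of the highest set bit:

#include <stdio.h>

/* 1-based index of the highest set bit, 0 for 0: same contract as fls(). */
static int fls_u32(unsigned int x)
{
	int r = 0;

	while (x) {
		x >>= 1;
		r++;
	}
	return r;
}

int main(void)
{
	unsigned int sizes[] = { 1024, 1025, 4096, 65000 };
	unsigned int i;

	for (i = 0; i < sizeof(sizes) / sizeof(sizes[0]); i++) {
		int hbits = fls_u32(sizes[i] - 1);

		printf("hashsize %5u -> htable_bits %2d (%u buckets)\n",
		       sizes[i], hbits, 1U << hbits);
	}
	return 0;
}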
+ */ + hbits = fls(hashsize - 1); + hsize = htable_size(hbits); + if (hsize == 0) { + kfree(h); + return -ENOMEM; + } + t = ip_set_alloc(hsize); + if (!t) { + kfree(h); + return -ENOMEM; + } + t->hregion = ip_set_alloc(ahash_sizeof_regions(hbits)); + if (!t->hregion) { + ip_set_free(t); + kfree(h); + return -ENOMEM; + } + h->gc.set = set; + for (i = 0; i < ahash_numof_locks(hbits); i++) + spin_lock_init(&t->hregion[i].lock); + h->maxelem = maxelem; +#ifdef IP_SET_HASH_WITH_NETMASK + h->netmask = netmask; +#endif +#ifdef IP_SET_HASH_WITH_MARKMASK + h->markmask = markmask; +#endif + if (tb[IPSET_ATTR_INITVAL]) + h->initval = ntohl(nla_get_be32(tb[IPSET_ATTR_INITVAL])); + else + get_random_bytes(&h->initval, sizeof(h->initval)); + h->bucketsize = AHASH_MAX_SIZE; + if (tb[IPSET_ATTR_BUCKETSIZE]) { + h->bucketsize = nla_get_u8(tb[IPSET_ATTR_BUCKETSIZE]); + if (h->bucketsize < AHASH_INIT_SIZE) + h->bucketsize = AHASH_INIT_SIZE; + else if (h->bucketsize > AHASH_MAX_SIZE) + h->bucketsize = AHASH_MAX_SIZE; + else if (h->bucketsize % 2) + h->bucketsize += 1; + } + t->htable_bits = hbits; + t->maxelem = h->maxelem / ahash_numof_locks(hbits); + RCU_INIT_POINTER(h->table, t); + + INIT_LIST_HEAD(&h->ad); + set->data = h; +#ifndef IP_SET_PROTO_UNDEF + if (set->family == NFPROTO_IPV4) { +#endif + set->variant = &IPSET_TOKEN(HTYPE, 4_variant); + set->dsize = ip_set_elem_len(set, tb, + sizeof(struct IPSET_TOKEN(HTYPE, 4_elem)), + __alignof__(struct IPSET_TOKEN(HTYPE, 4_elem))); +#ifndef IP_SET_PROTO_UNDEF + } else { + set->variant = &IPSET_TOKEN(HTYPE, 6_variant); + set->dsize = ip_set_elem_len(set, tb, + sizeof(struct IPSET_TOKEN(HTYPE, 6_elem)), + __alignof__(struct IPSET_TOKEN(HTYPE, 6_elem))); + } +#endif + set->timeout = IPSET_NO_TIMEOUT; + if (tb[IPSET_ATTR_TIMEOUT]) { + set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); +#ifndef IP_SET_PROTO_UNDEF + if (set->family == NFPROTO_IPV4) +#endif + IPSET_TOKEN(HTYPE, 4_gc_init)(&h->gc); +#ifndef IP_SET_PROTO_UNDEF + else + IPSET_TOKEN(HTYPE, 6_gc_init)(&h->gc); +#endif + } + pr_debug("create %s hashsize %u (%u) maxelem %u: %p(%p)\n", + set->name, jhash_size(t->htable_bits), + t->htable_bits, h->maxelem, set->data, t); + + return 0; +} +#endif /* IP_SET_EMIT_CREATE */ + +#undef HKEY_DATALEN diff --git a/net/netfilter/ipset/ip_set_hash_ip.c b/net/netfilter/ipset/ip_set_hash_ip.c new file mode 100644 index 000000000..24adcdd7a --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ip.c @@ -0,0 +1,329 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Counters support */ +/* 2 Comments support */ +/* 3 Forceadd support */ +/* 4 skbinfo support */ +#define IPSET_TYPE_REV_MAX 5 /* bucketsize, initval support */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:ip", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip"); + +/* Type specific function prefix */ +#define HTYPE hash_ip +#define IP_SET_HASH_WITH_NETMASK + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ip4_elem { + /* Zero valued IP addresses cannot be stored */ + __be32 ip; +}; + +/* Common functions */ + +static bool +hash_ip4_data_equal(const struct hash_ip4_elem *e1, + const struct hash_ip4_elem *e2, + u32 *multi) +{ + return e1->ip == 
e2->ip; +} + +static bool +hash_ip4_data_list(struct sk_buff *skb, const struct hash_ip4_elem *e) +{ + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, e->ip)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ip4_data_next(struct hash_ip4_elem *next, const struct hash_ip4_elem *e) +{ + next->ip = e->ip; +} + +#define MTYPE hash_ip4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_ip4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ip4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ip4_elem e = { 0 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + __be32 ip; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &ip); + ip &= ip_set_netmask(h->netmask); + if (ip == 0) + return -EINVAL; + + e.ip = ip; + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ip4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_ip4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ip4_elem e = { 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0, hosts, i = 0; + int ret = 0; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP])) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + ip &= ip_set_hostmask(h->netmask); + e.ip = htonl(ip); + if (e.ip == 0) + return -IPSET_ERR_HASH_ELEM; + + if (adt == IPSET_TEST) + return adtfn(set, &e, &ext, &ext, flags); + + ip_to = ip; + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) { + if (ip_to == 0) + return -IPSET_ERR_HASH_ELEM; + swap(ip, ip_to); + } + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } + + hosts = h->netmask == 32 ? 
1 : 2 << (32 - h->netmask - 1); + + if (retried) + ip = ntohl(h->next.ip); + for (; ip <= ip_to; i++) { + e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_ip4_data_next(&h->next, &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ip += hosts; + if (ip == 0) + return 0; + + ret = 0; + } + return ret; +} + +/* IPv6 variant */ + +/* Member elements */ +struct hash_ip6_elem { + union nf_inet_addr ip; +}; + +/* Common functions */ + +static bool +hash_ip6_data_equal(const struct hash_ip6_elem *ip1, + const struct hash_ip6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6); +} + +static void +hash_ip6_netmask(union nf_inet_addr *ip, u8 prefix) +{ + ip6_netmask(ip, prefix); +} + +static bool +hash_ip6_data_list(struct sk_buff *skb, const struct hash_ip6_elem *e) +{ + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &e->ip.in6)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ip6_data_next(struct hash_ip6_elem *next, const struct hash_ip6_elem *e) +{ +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_ip6 +#define HOST_MASK 128 + +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ip6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ip6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ip6_elem e = { { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + hash_ip6_netmask(&e.ip, h->netmask); + if (ipv6_addr_any(&e.ip.in6)) + return -EINVAL; + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ip6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_ip6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ip6_elem e = { { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP])) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + if (unlikely(tb[IPSET_ATTR_CIDR])) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr != HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + hash_ip6_netmask(&e.ip, h->netmask); + if (ipv6_addr_any(&e.ip.in6)) + return -IPSET_ERR_HASH_ELEM; + + ret = adtfn(set, &e, &ext, &ext, flags); + + return ip_set_eexist(ret, flags) ? 
0 : ret; +} + +static struct ip_set_type hash_ip_type __read_mostly = { + .name = "hash:ip", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ip_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_NETMASK] = { .type = NLA_U8 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ip_init(void) +{ + return ip_set_type_register(&hash_ip_type); +} + +static void __exit +hash_ip_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_ip_type); +} + +module_init(hash_ip_init); +module_exit(hash_ip_fini); diff --git a/net/netfilter/ipset/ip_set_hash_ipmac.c b/net/netfilter/ipset/ip_set_hash_ipmac.c new file mode 100644 index 000000000..467c59a83 --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ipmac.c @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2016 Tomasz Chilinski + */ + +/* Kernel module implementing an IP set type: the hash:ip,mac type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +#define IPSET_TYPE_REV_MAX 1 /* bucketsize, initval support */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Tomasz Chilinski "); +IP_SET_MODULE_DESC("hash:ip,mac", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip,mac"); + +/* Type specific function prefix */ +#define HTYPE hash_ipmac + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ipmac4_elem { + /* Zero valued IP addresses cannot be stored */ + __be32 ip; + union { + unsigned char ether[ETH_ALEN]; + __be32 foo[2]; + }; +}; + +/* Common functions */ + +static bool +hash_ipmac4_data_equal(const struct hash_ipmac4_elem *e1, + const struct hash_ipmac4_elem *e2, + u32 *multi) +{ + return e1->ip == e2->ip && ether_addr_equal(e1->ether, e2->ether); +} + +static bool +hash_ipmac4_data_list(struct sk_buff *skb, const struct hash_ipmac4_elem *e) +{ + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, e->ip) || + nla_put(skb, IPSET_ATTR_ETHER, ETH_ALEN, e->ether)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipmac4_data_next(struct hash_ipmac4_elem *next, + const struct hash_ipmac4_elem *e) +{ + next->ip = e->ip; +} + +#define MTYPE hash_ipmac4 +#define PF 4 +#define HOST_MASK 32 +#define HKEY_DATALEN sizeof(struct hash_ipmac4_elem) +#include "ip_set_hash_gen.h" + +static int +hash_ipmac4_kadt(struct ip_set *set, const 
struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmac4_elem e = { .ip = 0, { .foo[0] = 0, .foo[1] = 0 } }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (skb_mac_header(skb) < skb->head || + (skb_mac_header(skb) + ETH_HLEN) > skb->data) + return -EINVAL; + + if (opt->flags & IPSET_DIM_TWO_SRC) + ether_addr_copy(e.ether, eth_hdr(skb)->h_source); + else + ether_addr_copy(e.ether, eth_hdr(skb)->h_dest); + + if (is_zero_ether_addr(e.ether)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipmac4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmac4_elem e = { .ip = 0, { .foo[0] = 0, .foo[1] = 0 } }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (unlikely(!tb[IPSET_ATTR_IP] || + !tb[IPSET_ATTR_ETHER] || + nla_len(tb[IPSET_ATTR_ETHER]) != ETH_ALEN || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE))) + return -IPSET_ERR_PROTOCOL; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_IP], &e.ip) || + ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + memcpy(e.ether, nla_data(tb[IPSET_ATTR_ETHER]), ETH_ALEN); + if (is_zero_ether_addr(e.ether)) + return -IPSET_ERR_HASH_ELEM; + + return adtfn(set, &e, &ext, &ext, flags); +} + +/* IPv6 variant */ + +/* Member elements */ +struct hash_ipmac6_elem { + /* Zero valued IP addresses cannot be stored */ + union nf_inet_addr ip; + union { + unsigned char ether[ETH_ALEN]; + __be32 foo[2]; + }; +}; + +/* Common functions */ + +static bool +hash_ipmac6_data_equal(const struct hash_ipmac6_elem *e1, + const struct hash_ipmac6_elem *e2, + u32 *multi) +{ + return ipv6_addr_equal(&e1->ip.in6, &e2->ip.in6) && + ether_addr_equal(e1->ether, e2->ether); +} + +static bool +hash_ipmac6_data_list(struct sk_buff *skb, const struct hash_ipmac6_elem *e) +{ + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &e->ip.in6) || + nla_put(skb, IPSET_ATTR_ETHER, ETH_ALEN, e->ether)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipmac6_data_next(struct hash_ipmac6_elem *next, + const struct hash_ipmac6_elem *e) +{ +} + +#undef MTYPE +#undef PF +#undef HOST_MASK +#undef HKEY_DATALEN + +#define MTYPE hash_ipmac6 +#define PF 6 +#define HOST_MASK 128 +#define HKEY_DATALEN sizeof(struct hash_ipmac6_elem) +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ipmac6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmac6_elem e = { + { .all = { 0 } }, + { .foo[0] = 0, .foo[1] = 0 } + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (skb_mac_header(skb) < skb->head || + (skb_mac_header(skb) + ETH_HLEN) > skb->data) + return -EINVAL; + + if (opt->flags & IPSET_DIM_TWO_SRC) + ether_addr_copy(e.ether, 
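/* the second ("mac") dimension honours the src/dst match flag: source MAC
 * when IPSET_DIM_TWO_SRC is set, destination MAC otherwise */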
eth_hdr(skb)->h_source); + else + ether_addr_copy(e.ether, eth_hdr(skb)->h_dest); + + if (is_zero_ether_addr(e.ether)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipmac6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmac6_elem e = { + { .all = { 0 } }, + { .foo[0] = 0, .foo[1] = 0 } + }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (unlikely(!tb[IPSET_ATTR_IP] || + !tb[IPSET_ATTR_ETHER] || + nla_len(tb[IPSET_ATTR_ETHER]) != ETH_ALEN || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE))) + return -IPSET_ERR_PROTOCOL; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip) || + ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + memcpy(e.ether, nla_data(tb[IPSET_ATTR_ETHER]), ETH_ALEN); + if (is_zero_ether_addr(e.ether)) + return -IPSET_ERR_HASH_ELEM; + + return adtfn(set, &e, &ext, &ext, flags); +} + +static struct ip_set_type hash_ipmac_type __read_mostly = { + .name = "hash:ip,mac", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_MAC, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ipmac_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_ETHER] = { .type = NLA_BINARY, + .len = ETH_ALEN }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ipmac_init(void) +{ + return ip_set_type_register(&hash_ipmac_type); +} + +static void __exit +hash_ipmac_fini(void) +{ + ip_set_type_unregister(&hash_ipmac_type); +} + +module_init(hash_ipmac_init); +module_exit(hash_ipmac_fini); diff --git a/net/netfilter/ipset/ip_set_hash_ipmark.c b/net/netfilter/ipset/ip_set_hash_ipmark.c new file mode 100644 index 000000000..a22ec1a6f --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ipmark.c @@ -0,0 +1,331 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip,mark type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Forceadd 
support */ +/* 2 skbinfo support */ +#define IPSET_TYPE_REV_MAX 3 /* bucketsize, initval support */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Vytas Dauksa "); +IP_SET_MODULE_DESC("hash:ip,mark", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip,mark"); + +/* Type specific function prefix */ +#define HTYPE hash_ipmark +#define IP_SET_HASH_WITH_MARKMASK + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ipmark4_elem { + __be32 ip; + __u32 mark; +}; + +/* Common functions */ + +static bool +hash_ipmark4_data_equal(const struct hash_ipmark4_elem *ip1, + const struct hash_ipmark4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->mark == ip2->mark; +} + +static bool +hash_ipmark4_data_list(struct sk_buff *skb, + const struct hash_ipmark4_elem *data) +{ + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_net32(skb, IPSET_ATTR_MARK, htonl(data->mark))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipmark4_data_next(struct hash_ipmark4_elem *next, + const struct hash_ipmark4_elem *d) +{ + next->ip = d->ip; +} + +#define MTYPE hash_ipmark4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_ipmark4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ipmark4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmark4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.mark = skb->mark; + e.mark &= h->markmask; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipmark4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_ipmark4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmark4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip, ip_to = 0, i = 0; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_MARK))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + e.mark = ntohl(nla_get_be32(tb[IPSET_ATTR_MARK])); + e.mark &= h->markmask; + if (e.mark == 0 && e.ip == 0) + return -IPSET_ERR_HASH_ELEM; + + if (adt == IPSET_TEST || + !(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_CIDR])) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 
0 : ret; + } + + ip_to = ip = ntohl(e.ip); + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) { + if (e.mark == 0 && ip_to == 0) + return -IPSET_ERR_HASH_ELEM; + swap(ip, ip_to); + } + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } + + if (retried) + ip = ntohl(h->next.ip); + for (; ip <= ip_to; ip++, i++) { + e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_ipmark4_data_next(&h->next, &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +/* IPv6 variant */ + +struct hash_ipmark6_elem { + union nf_inet_addr ip; + __u32 mark; +}; + +/* Common functions */ + +static bool +hash_ipmark6_data_equal(const struct hash_ipmark6_elem *ip1, + const struct hash_ipmark6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ip1->mark == ip2->mark; +} + +static bool +hash_ipmark6_data_list(struct sk_buff *skb, + const struct hash_ipmark6_elem *data) +{ + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_net32(skb, IPSET_ATTR_MARK, htonl(data->mark))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipmark6_data_next(struct hash_ipmark6_elem *next, + const struct hash_ipmark6_elem *d) +{ +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_ipmark6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ipmark6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ipmark6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmark6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.mark = skb->mark; + e.mark &= h->markmask; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipmark6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_ipmark6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipmark6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_MARK))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + if (unlikely(tb[IPSET_ATTR_CIDR])) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr != HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + e.mark = ntohl(nla_get_be32(tb[IPSET_ATTR_MARK])); + e.mark &= h->markmask; + + if (adt == IPSET_TEST) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 
0 : ret; + } + + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + return 0; +} + +static struct ip_set_type hash_ipmark_type __read_mostly = { + .name = "hash:ip,mark", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_MARK, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ipmark_create, + .create_policy = { + [IPSET_ATTR_MARKMASK] = { .type = NLA_U32 }, + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_MARK] = { .type = NLA_U32 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ipmark_init(void) +{ + return ip_set_type_register(&hash_ipmark_type); +} + +static void __exit +hash_ipmark_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_ipmark_type); +} + +module_init(hash_ipmark_init); +module_exit(hash_ipmark_fini); diff --git a/net/netfilter/ipset/ip_set_hash_ipport.c b/net/netfilter/ipset/ip_set_hash_ipport.c new file mode 100644 index 000000000..10481760a --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ipport.c @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip,port type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 SCTP and UDPLITE support added */ +/* 2 Counters support added */ +/* 3 Comments support added */ +/* 4 Forceadd support added */ +/* 5 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 6 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:ip,port", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip,port"); + +/* Type specific function prefix */ +#define HTYPE hash_ipport + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ipport4_elem { + __be32 ip; + __be16 port; + u8 proto; + u8 padding; +}; + +/* Common functions */ + +static bool +hash_ipport4_data_equal(const struct hash_ipport4_elem *ip1, + const struct hash_ipport4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static bool +hash_ipport4_data_list(struct sk_buff *skb, + const struct hash_ipport4_elem *data) +{ + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto)) + goto 
nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipport4_data_next(struct hash_ipport4_elem *next, + const struct hash_ipport4_elem *d) +{ + next->ip = d->ip; + next->port = d->port; +} + +#define MTYPE hash_ipport4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_ipport4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipport4_elem e = { .ip = 0 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipport4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_ipport4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipport4_elem e = { .ip = 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip, ip_to = 0, p = 0, port, port_to, i = 0; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMP)) + e.port = 0; + + if (adt == IPSET_TEST || + !(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_CIDR] || + tb[IPSET_ATTR_PORT_TO])) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 0 : ret; + } + + ip_to = ip = ntohl(e.ip); + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) + swap(ip, ip_to); + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } + + port_to = port = ntohs(e.port); + if (with_ports && tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + } + + if (retried) + ip = ntohl(h->next.ip); + for (; ip <= ip_to; ip++) { + p = retried && ip == ntohl(h->next.ip) ? 
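/* when the command is re-executed after an internal resize, resume the
 * port loop at the element recorded in h->next for this IP instead of
 * starting the range over */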
ntohs(h->next.port) + : port; + for (; p <= port_to; p++, i++) { + e.ip = htonl(ip); + e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_ipport4_data_next(&h->next, &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + } + return ret; +} + +/* IPv6 variant */ + +struct hash_ipport6_elem { + union nf_inet_addr ip; + __be16 port; + u8 proto; + u8 padding; +}; + +/* Common functions */ + +static bool +hash_ipport6_data_equal(const struct hash_ipport6_elem *ip1, + const struct hash_ipport6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static bool +hash_ipport6_data_list(struct sk_buff *skb, + const struct hash_ipport6_elem *data) +{ + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipport6_data_next(struct hash_ipport6_elem *next, + const struct hash_ipport6_elem *d) +{ + next->port = d->port; +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_ipport6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ipport6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipport6_elem e = { .ip = { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipport6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_ipport6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipport6_elem e = { .ip = { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + if (unlikely(tb[IPSET_ATTR_CIDR])) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr != HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMPV6)) + e.port = 0; + + if (adt == IPSET_TEST || !with_ports || !tb[IPSET_ATTR_PORT_TO]) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 
0 : ret; + } + + port = ntohs(e.port); + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + + if (retried) + port = ntohs(h->next.port); + for (; port <= port_to; port++) { + e.port = htons(port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static struct ip_set_type hash_ipport_type __read_mostly = { + .name = "hash:ip,port", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_PORT, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ipport_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ipport_init(void) +{ + return ip_set_type_register(&hash_ipport_type); +} + +static void __exit +hash_ipport_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_ipport_type); +} + +module_init(hash_ipport_init); +module_exit(hash_ipport_fini); diff --git a/net/netfilter/ipset/ip_set_hash_ipportip.c b/net/netfilter/ipset/ip_set_hash_ipportip.c new file mode 100644 index 000000000..39a01934b --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ipportip.c @@ -0,0 +1,410 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip,port,ip type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 SCTP and UDPLITE support added */ +/* 2 Counters support added */ +/* 3 Comments support added */ +/* 4 Forceadd support added */ +/* 5 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 6 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:ip,port,ip", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip,port,ip"); + +/* Type specific function prefix */ +#define HTYPE hash_ipportip + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ipportip4_elem { + __be32 ip; + __be32 ip2; + __be16 port; + u8 proto; + u8 padding; +}; + +static bool +hash_ipportip4_data_equal(const struct hash_ipportip4_elem *ip1, + const struct hash_ipportip4_elem *ip2, + u32 *multi) +{ + return ip1->ip == 
ip2->ip && + ip1->ip2 == ip2->ip2 && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static bool +hash_ipportip4_data_list(struct sk_buff *skb, + const struct hash_ipportip4_elem *data) +{ + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP2, data->ip2) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipportip4_data_next(struct hash_ipportip4_elem *next, + const struct hash_ipportip4_elem *d) +{ + next->ip = d->ip; + next->port = d->port; +} + +/* Common functions */ +#define MTYPE hash_ipportip4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_ipportip4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportip4_elem e = { .ip = 0 }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + ip4addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip2); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipportip4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_ipportip4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportip4_elem e = { .ip = 0 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip, ip_to = 0, p = 0, port, port_to, i = 0; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + ret = ip_set_get_ipaddr4(tb[IPSET_ATTR_IP2], &e.ip2); + if (ret) + return ret; + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMP)) + e.port = 0; + + if (adt == IPSET_TEST || + !(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_CIDR] || + tb[IPSET_ATTR_PORT_TO])) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 0 : ret; + } + + ip_to = ip = ntohl(e.ip); + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) + swap(ip, ip_to); + } else if (tb[IPSET_ATTR_CIDR]) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } + + port_to = port = ntohs(e.port); + if (with_ports && tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + } + + if (retried) + ip = ntohl(h->next.ip); + for (; ip <= ip_to; ip++) { + p = retried && ip == ntohl(h->next.ip) ? 
ntohs(h->next.port) + : port; + for (; p <= port_to; p++, i++) { + e.ip = htonl(ip); + e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_ipportip4_data_next(&h->next, &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + } + return ret; +} + +/* IPv6 variant */ + +struct hash_ipportip6_elem { + union nf_inet_addr ip; + union nf_inet_addr ip2; + __be16 port; + u8 proto; + u8 padding; +}; + +/* Common functions */ + +static bool +hash_ipportip6_data_equal(const struct hash_ipportip6_elem *ip1, + const struct hash_ipportip6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ipv6_addr_equal(&ip1->ip2.in6, &ip2->ip2.in6) && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static bool +hash_ipportip6_data_list(struct sk_buff *skb, + const struct hash_ipportip6_elem *data) +{ + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_ipaddr6(skb, IPSET_ATTR_IP2, &data->ip2.in6) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipportip6_data_next(struct hash_ipportip6_elem *next, + const struct hash_ipportip6_elem *d) +{ + next->port = d->port; +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_ipportip6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ipportip6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportip6_elem e = { .ip = { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + ip6addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip2.in6); + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipportip6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_ipportip6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportip6_elem e = { .ip = { .all = { 0 } } }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + if (unlikely(tb[IPSET_ATTR_CIDR])) { + u8 cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr != HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP2], &e.ip2); + if (ret) + return ret; + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return 
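/* the protocol attribute is mandatory here: without it the port value
 * cannot be interpreted */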
-IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMPV6)) + e.port = 0; + + if (adt == IPSET_TEST || !with_ports || !tb[IPSET_ATTR_PORT_TO]) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_eexist(ret, flags) ? 0 : ret; + } + + port = ntohs(e.port); + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + + if (retried) + port = ntohs(h->next.port); + for (; port <= port_to; port++) { + e.port = htons(port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static struct ip_set_type hash_ipportip_type __read_mostly = { + .name = "hash:ip,port,ip", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_TYPE_IP2, + .dimension = IPSET_DIM_THREE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ipportip_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2] = { .type = NLA_NESTED }, + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ipportip_init(void) +{ + return ip_set_type_register(&hash_ipportip_type); +} + +static void __exit +hash_ipportip_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_ipportip_type); +} + +module_init(hash_ipportip_init); +module_exit(hash_ipportip_fini); diff --git a/net/netfilter/ipset/ip_set_hash_ipportnet.c b/net/netfilter/ipset/ip_set_hash_ipportnet.c new file mode 100644 index 000000000..5c6de605a --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_ipportnet.c @@ -0,0 +1,572 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip,port,net type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 SCTP and UDPLITE support added */ +/* 2 Range as input support for IPv4 added */ +/* 3 nomatch flag support added */ +/* 4 Counters support added */ +/* 5 Comments support added */ +/* 6 Forceadd support added */ +/* 7 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 8 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:ip,port,net", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:ip,port,net"); + 
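/*
 * Aside (not part of the upstream patch): the uadt handlers in this file and
 * in hash:net walk an arbitrary IPv4 interval as a sequence of aligned CIDR
 * blocks via ip_set_range_to_cidr(). The standalone user-space sketch below
 * only approximates that helper so the do/while range loops are easier to
 * follow; the function name, demo addresses and printf output are purely
 * illustrative.
 */
#include <stdint.h>
#include <stdio.h>

/* Return the last address of the largest aligned block starting at "from"
 * that still fits inside [from, to]; store its prefix length in *prefix.
 */
static uint32_t range_to_cidr(uint32_t from, uint32_t to, uint8_t *prefix)
{
	uint64_t size = 1;
	uint8_t p = 32;

	/* grow the block while it stays aligned on "from" and inside the range */
	while (p > 0 &&
	       (from & ((size << 1) - 1)) == 0 &&
	       (uint64_t)from + (size << 1) - 1 <= to) {
		size <<= 1;
		p--;
	}
	*prefix = p;
	return (uint32_t)((uint64_t)from + size - 1);	/* last address of the block */
}

int main(void)
{
	uint32_t ip = 0xc0a8000aU, ip_to = 0xc0a80014U;	/* 192.168.0.10 - 192.168.0.20 */
	uint8_t cidr;

	do {	/* same shape as the do/while range loops in the uadt handlers */
		uint32_t last = range_to_cidr(ip, ip_to, &cidr);

		printf("%u.%u.%u.%u/%u\n",
		       (unsigned int)(ip >> 24), (unsigned int)((ip >> 16) & 0xff),
		       (unsigned int)((ip >> 8) & 0xff), (unsigned int)(ip & 0xff),
		       (unsigned int)cidr);
		ip = last;
	} while (ip++ < ip_to);

	return 0;
}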
+/* Type specific function prefix */ +#define HTYPE hash_ipportnet + +/* We squeeze the "nomatch" flag into cidr: we don't support cidr == 0 + * However this way we have to store internally cidr - 1, + * dancing back and forth. + */ +#define IP_SET_HASH_WITH_NETS_PACKED +#define IP_SET_HASH_WITH_PROTO +#define IP_SET_HASH_WITH_NETS + +/* IPv4 variant */ + +/* Member elements */ +struct hash_ipportnet4_elem { + __be32 ip; + __be32 ip2; + __be16 port; + u8 cidr:7; + u8 nomatch:1; + u8 proto; +}; + +/* Common functions */ + +static bool +hash_ipportnet4_data_equal(const struct hash_ipportnet4_elem *ip1, + const struct hash_ipportnet4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->ip2 == ip2->ip2 && + ip1->cidr == ip2->cidr && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static int +hash_ipportnet4_do_data_match(const struct hash_ipportnet4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_ipportnet4_data_set_flags(struct hash_ipportnet4_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_ipportnet4_data_reset_flags(struct hash_ipportnet4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_ipportnet4_data_netmask(struct hash_ipportnet4_elem *elem, u8 cidr) +{ + elem->ip2 &= ip_set_netmask(cidr); + elem->cidr = cidr - 1; +} + +static bool +hash_ipportnet4_data_list(struct sk_buff *skb, + const struct hash_ipportnet4_elem *data) +{ + u32 flags = data->nomatch ? IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP2, data->ip2) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr + 1) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipportnet4_data_next(struct hash_ipportnet4_elem *next, + const struct hash_ipportnet4_elem *d) +{ + next->ip = d->ip; + next->port = d->port; + next->ip2 = d->ip2; +} + +#define MTYPE hash_ipportnet4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_ipportnet4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ipportnet4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportnet4_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK - 1; + + if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + ip4addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip2); + e.ip2 &= ip_set_netmask(e.cidr + 1); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipportnet4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_ipportnet4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportnet4_elem e = { .cidr = HOST_MASK - 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0, p = 0, port, port_to; + u32 ip2_from = 0, ip2_to = 0, ip2, i = 0; + bool with_ports = false; + u8 cidr; + int ret; + + if (tb[IPSET_ATTR_LINENO]) 
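/* remember the input line of a batched "ipset restore" run so that a
 * failure can be reported against the offending line */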
+ *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2], &ip2_from); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR2]) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + e.cidr = cidr - 1; + } + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMP)) + e.port = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + with_ports = with_ports && tb[IPSET_ATTR_PORT_TO]; + if (adt == IPSET_TEST || + !(tb[IPSET_ATTR_CIDR] || tb[IPSET_ATTR_IP_TO] || with_ports || + tb[IPSET_ATTR_IP2_TO])) { + e.ip = htonl(ip); + e.ip2 = htonl(ip2_from & ip_set_hostmask(e.cidr + 1)); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + ip_to = ip; + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) + swap(ip, ip_to); + } else if (tb[IPSET_ATTR_CIDR]) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + ip_set_mask_from_to(ip, ip_to, cidr); + } + + port_to = port = ntohs(e.port); + if (tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + } + + ip2_to = ip2_from; + if (tb[IPSET_ATTR_IP2_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2_TO], &ip2_to); + if (ret) + return ret; + if (ip2_from > ip2_to) + swap(ip2_from, ip2_to); + if (ip2_from + UINT_MAX == ip2_to) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip2_from, ip2_to, e.cidr + 1); + } + + if (retried) { + ip = ntohl(h->next.ip); + p = ntohs(h->next.port); + ip2 = ntohl(h->next.ip2); + } else { + p = port; + ip2 = ip2_from; + } + for (; ip <= ip_to; ip++) { + e.ip = htonl(ip); + for (; p <= port_to; p++) { + e.port = htons(p); + do { + i++; + e.ip2 = htonl(ip2); + ip2 = ip_set_range_to_cidr(ip2, ip2_to, &cidr); + e.cidr = cidr - 1; + if (i > IPSET_MAX_RANGE) { + hash_ipportnet4_data_next(&h->next, + &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } while (ip2++ < ip2_to); + ip2 = ip2_from; + } + p = port; + } + return ret; +} + +/* IPv6 variant */ + +struct hash_ipportnet6_elem { + union nf_inet_addr ip; + union nf_inet_addr ip2; + __be16 port; + u8 cidr:7; + u8 nomatch:1; + u8 proto; +}; + +/* Common functions */ + +static bool +hash_ipportnet6_data_equal(const struct hash_ipportnet6_elem *ip1, + const struct hash_ipportnet6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + 
ipv6_addr_equal(&ip1->ip2.in6, &ip2->ip2.in6) && + ip1->cidr == ip2->cidr && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static int +hash_ipportnet6_do_data_match(const struct hash_ipportnet6_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_ipportnet6_data_set_flags(struct hash_ipportnet6_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_ipportnet6_data_reset_flags(struct hash_ipportnet6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_ipportnet6_data_netmask(struct hash_ipportnet6_elem *elem, u8 cidr) +{ + ip6_netmask(&elem->ip2, cidr); + elem->cidr = cidr - 1; +} + +static bool +hash_ipportnet6_data_list(struct sk_buff *skb, + const struct hash_ipportnet6_elem *data) +{ + u32 flags = data->nomatch ? IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_ipaddr6(skb, IPSET_ATTR_IP2, &data->ip2.in6) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr + 1) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_ipportnet6_data_next(struct hash_ipportnet6_elem *next, + const struct hash_ipportnet6_elem *d) +{ + next->port = d->port; +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_ipportnet6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_ipportnet6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_ipportnet6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportnet6_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK - 1; + + if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + ip6addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip2.in6); + ip6_netmask(&e.ip2, e.cidr + 1); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_ipportnet6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_ipportnet6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_ipportnet6_elem e = { .cidr = HOST_MASK - 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to; + bool with_ports = false; + u8 cidr; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + if (unlikely(tb[IPSET_ATTR_CIDR])) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + + if (cidr != HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + ret = 
ip_set_get_ipaddr6(tb[IPSET_ATTR_IP2], &e.ip2); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR2]) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + e.cidr = cidr - 1; + } + + ip6_netmask(&e.ip2, e.cidr + 1); + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMPV6)) + e.port = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !with_ports || !tb[IPSET_ATTR_PORT_TO]) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + port = ntohs(e.port); + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + + if (retried) + port = ntohs(h->next.port); + for (; port <= port_to; port++) { + e.port = htons(port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static struct ip_set_type hash_ipportnet_type __read_mostly = { + .name = "hash:ip,port,net", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_TYPE_IP2 | + IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_THREE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_ipportnet_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_CIDR2] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_ipportnet_init(void) +{ + return ip_set_type_register(&hash_ipportnet_type); +} + +static void __exit +hash_ipportnet_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_ipportnet_type); +} + +module_init(hash_ipportnet_init); +module_exit(hash_ipportnet_fini); diff --git a/net/netfilter/ipset/ip_set_hash_mac.c b/net/netfilter/ipset/ip_set_hash_mac.c new file mode 100644 index 000000000..718814730 --- /dev/null +++ 
b/net/netfilter/ipset/ip_set_hash_mac.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2014 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:mac type */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +#define IPSET_TYPE_REV_MAX 1 /* bucketsize, initval support */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:mac", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:mac"); + +/* Type specific function prefix */ +#define HTYPE hash_mac + +/* Member elements */ +struct hash_mac4_elem { + /* Zero valued IP addresses cannot be stored */ + union { + unsigned char ether[ETH_ALEN]; + __be32 foo[2]; + }; +}; + +/* Common functions */ + +static bool +hash_mac4_data_equal(const struct hash_mac4_elem *e1, + const struct hash_mac4_elem *e2, + u32 *multi) +{ + return ether_addr_equal(e1->ether, e2->ether); +} + +static bool +hash_mac4_data_list(struct sk_buff *skb, const struct hash_mac4_elem *e) +{ + if (nla_put(skb, IPSET_ATTR_ETHER, ETH_ALEN, e->ether)) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_mac4_data_next(struct hash_mac4_elem *next, + const struct hash_mac4_elem *e) +{ +} + +#define MTYPE hash_mac4 +#define HOST_MASK 32 +#define IP_SET_EMIT_CREATE +#define IP_SET_PROTO_UNDEF +#include "ip_set_hash_gen.h" + +static int +hash_mac4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_mac4_elem e = { { .foo[0] = 0, .foo[1] = 0 } }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (skb_mac_header(skb) < skb->head || + (skb_mac_header(skb) + ETH_HLEN) > skb->data) + return -EINVAL; + + if (opt->flags & IPSET_DIM_ONE_SRC) + ether_addr_copy(e.ether, eth_hdr(skb)->h_source); + else + ether_addr_copy(e.ether, eth_hdr(skb)->h_dest); + + if (is_zero_ether_addr(e.ether)) + return -EINVAL; + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_mac4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_mac4_elem e = { { .foo[0] = 0, .foo[1] = 0 } }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_ETHER] || + nla_len(tb[IPSET_ATTR_ETHER]) != ETH_ALEN)) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + ether_addr_copy(e.ether, nla_data(tb[IPSET_ATTR_ETHER])); + if (is_zero_ether_addr(e.ether)) + return -IPSET_ERR_HASH_ELEM; + + return adtfn(set, &e, &ext, &ext, flags); +} + +static struct ip_set_type hash_mac_type __read_mostly = { + .name = "hash:mac", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_MAC, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_mac_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = 
NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_ETHER] = { .type = NLA_BINARY, + .len = ETH_ALEN }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_mac_init(void) +{ + return ip_set_type_register(&hash_mac_type); +} + +static void __exit +hash_mac_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_mac_type); +} + +module_init(hash_mac_init); +module_exit(hash_mac_fini); diff --git a/net/netfilter/ipset/ip_set_hash_net.c b/net/netfilter/ipset/ip_set_hash_net.c new file mode 100644 index 000000000..ce0a9ce5a --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_net.c @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:net type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Range as input support for IPv4 added */ +/* 2 nomatch flag support added */ +/* 3 Counters support added */ +/* 4 Comments support added */ +/* 5 Forceadd support added */ +/* 6 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 7 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:net", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:net"); + +/* Type specific function prefix */ +#define HTYPE hash_net +#define IP_SET_HASH_WITH_NETS + +/* IPv4 variant */ + +/* Member elements */ +struct hash_net4_elem { + __be32 ip; + u16 padding0; + u8 nomatch; + u8 cidr; +}; + +/* Common functions */ + +static bool +hash_net4_data_equal(const struct hash_net4_elem *ip1, + const struct hash_net4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->cidr == ip2->cidr; +} + +static int +hash_net4_do_data_match(const struct hash_net4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_net4_data_set_flags(struct hash_net4_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_net4_data_reset_flags(struct hash_net4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_net4_data_netmask(struct hash_net4_elem *elem, u8 cidr) +{ + elem->ip &= ip_set_netmask(cidr); + elem->cidr = cidr; +} + +static bool +hash_net4_data_list(struct sk_buff *skb, const struct hash_net4_elem *data) +{ + u32 flags = data->nomatch ? 
IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_net4_data_next(struct hash_net4_elem *next, + const struct hash_net4_elem *d) +{ + next->ip = d->ip; +} + +#define MTYPE hash_net4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_net4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_net4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_net4_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (e.cidr == 0) + return -EINVAL; + if (adt == IPSET_TEST) + e.cidr = HOST_MASK; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + e.ip &= ip_set_netmask(e.cidr); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_net4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_net4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_net4_elem e = { .cidr = HOST_MASK }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0, i = 0; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!e.cidr || e.cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !tb[IPSET_ATTR_IP_TO]) { + e.ip = htonl(ip & ip_set_hostmask(e.cidr)); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + ip_to = ip; + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip_to < ip) + swap(ip, ip_to); + if (ip + UINT_MAX == ip_to) + return -IPSET_ERR_HASH_RANGE; + } + + if (retried) + ip = ntohl(h->next.ip); + do { + i++; + e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_net4_data_next(&h->next, &e); + return -ERANGE; + } + ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr); + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } while (ip++ < ip_to); + return ret; +} + +/* IPv6 variant */ + +struct hash_net6_elem { + union nf_inet_addr ip; + u16 padding0; + u8 nomatch; + u8 cidr; +}; + +/* Common functions */ + +static bool +hash_net6_data_equal(const struct hash_net6_elem *ip1, + const struct hash_net6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ip1->cidr == ip2->cidr; +} + +static int +hash_net6_do_data_match(const struct hash_net6_elem *elem) +{ + return elem->nomatch ? 
-ENOTEMPTY : 1; +} + +static void +hash_net6_data_set_flags(struct hash_net6_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_net6_data_reset_flags(struct hash_net6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_net6_data_netmask(struct hash_net6_elem *elem, u8 cidr) +{ + ip6_netmask(&elem->ip, cidr); + elem->cidr = cidr; +} + +static bool +hash_net6_data_list(struct sk_buff *skb, const struct hash_net6_elem *data) +{ + u32 flags = data->nomatch ? IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_net6_data_next(struct hash_net6_elem *next, + const struct hash_net6_elem *d) +{ +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_net6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_net6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_net6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_net6_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (e.cidr == 0) + return -EINVAL; + if (adt == IPSET_TEST) + e.cidr = HOST_MASK; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + ip6_netmask(&e.ip, e.cidr); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_net6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_net6_elem e = { .cidr = HOST_MASK }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!e.cidr || e.cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ip6_netmask(&e.ip, e.cidr); + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + ret = adtfn(set, &e, &ext, &ext, flags); + + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 
0 : ret; +} + +static struct ip_set_type hash_net_type __read_mostly = { + .name = "hash:net", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_net_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_net_init(void) +{ + return ip_set_type_register(&hash_net_type); +} + +static void __exit +hash_net_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_net_type); +} + +module_init(hash_net_init); +module_exit(hash_net_fini); diff --git a/net/netfilter/ipset/ip_set_hash_netiface.c b/net/netfilter/ipset/ip_set_hash_netiface.c new file mode 100644 index 000000000..bf1a3851b --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_netiface.c @@ -0,0 +1,525 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2011-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:net,iface type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 nomatch flag support added */ +/* 2 /0 support added */ +/* 3 Counters support added */ +/* 4 Comments support added */ +/* 5 Forceadd support added */ +/* 6 skbinfo support added */ +/* 7 interface wildcard support added */ +#define IPSET_TYPE_REV_MAX 8 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:net,iface", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:net,iface"); + +/* Type specific function prefix */ +#define HTYPE hash_netiface +#define IP_SET_HASH_WITH_NETS +#define IP_SET_HASH_WITH_MULTI +#define IP_SET_HASH_WITH_NET0 + +#define STRLCPY(a, b) strlcpy(a, b, IFNAMSIZ) + +/* IPv4 variant */ + +struct hash_netiface4_elem_hashed { + __be32 ip; + u8 physdev; + u8 cidr; + u8 nomatch; + u8 elem; +}; + +/* Member elements */ +struct hash_netiface4_elem { + __be32 ip; + u8 physdev; + u8 cidr; + u8 nomatch; + u8 elem; + u8 wildcard; + char iface[IFNAMSIZ]; +}; + +/* Common functions */ + +static bool +hash_netiface4_data_equal(const struct hash_netiface4_elem *ip1, + const struct hash_netiface4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->cidr == ip2->cidr && + (++*multi) && + ip1->physdev == ip2->physdev && + (ip1->wildcard ? 
+ strncmp(ip1->iface, ip2->iface, strlen(ip1->iface)) == 0 : + strcmp(ip1->iface, ip2->iface) == 0); +} + +static int +hash_netiface4_do_data_match(const struct hash_netiface4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netiface4_data_set_flags(struct hash_netiface4_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_netiface4_data_reset_flags(struct hash_netiface4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netiface4_data_netmask(struct hash_netiface4_elem *elem, u8 cidr) +{ + elem->ip &= ip_set_netmask(cidr); + elem->cidr = cidr; +} + +static bool +hash_netiface4_data_list(struct sk_buff *skb, + const struct hash_netiface4_elem *data) +{ + u32 flags = (data->physdev ? IPSET_FLAG_PHYSDEV : 0) | + (data->wildcard ? IPSET_FLAG_IFACE_WILDCARD : 0); + + if (data->nomatch) + flags |= IPSET_FLAG_NOMATCH; + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr) || + nla_put_string(skb, IPSET_ATTR_IFACE, data->iface) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netiface4_data_next(struct hash_netiface4_elem *next, + const struct hash_netiface4_elem *d) +{ + next->ip = d->ip; +} + +#define MTYPE hash_netiface4 +#define HOST_MASK 32 +#define HKEY_DATALEN sizeof(struct hash_netiface4_elem_hashed) +#include "ip_set_hash_gen.h" + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +static const char *get_physindev_name(const struct sk_buff *skb, struct net *net) +{ + struct net_device *dev = nf_bridge_get_physindev(skb, net); + + return dev ? dev->name : NULL; +} + +static const char *get_physoutdev_name(const struct sk_buff *skb) +{ + struct net_device *dev = nf_bridge_get_physoutdev(skb); + + return dev ? dev->name : NULL; +} +#endif + +static int +hash_netiface4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + struct hash_netiface4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netiface4_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + .elem = 1, + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + e.ip &= ip_set_netmask(e.cidr); + +#define IFACE(dir) (par->state->dir ? par->state->dir->name : "") +#define SRCDIR (opt->flags & IPSET_DIM_TWO_SRC) + + if (opt->cmdflags & IPSET_FLAG_PHYSDEV) { +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + const char *eiface = SRCDIR ? get_physindev_name(skb, xt_net(par)) : + get_physoutdev_name(skb); + + if (!eiface) + return -EINVAL; + STRLCPY(e.iface, eiface); + e.physdev = 1; +#endif + } else { + STRLCPY(e.iface, SRCDIR ? 
IFACE(in) : IFACE(out)); + } + + if (strlen(e.iface) == 0) + return -EINVAL; + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netiface4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_netiface4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netiface4_elem e = { .cidr = HOST_MASK, .elem = 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0, i = 0; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !tb[IPSET_ATTR_IFACE] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (e.cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + nla_strscpy(e.iface, tb[IPSET_ATTR_IFACE], IFNAMSIZ); + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_PHYSDEV) + e.physdev = 1; + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + if (cadt_flags & IPSET_FLAG_IFACE_WILDCARD) + e.wildcard = 1; + } + if (adt == IPSET_TEST || !tb[IPSET_ATTR_IP_TO]) { + e.ip = htonl(ip & ip_set_hostmask(e.cidr)); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip_to < ip) + swap(ip, ip_to); + if (ip + UINT_MAX == ip_to) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip, ip_to, e.cidr); + } + + if (retried) + ip = ntohl(h->next.ip); + do { + i++; + e.ip = htonl(ip); + if (i > IPSET_MAX_RANGE) { + hash_netiface4_data_next(&h->next, &e); + return -ERANGE; + } + ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } while (ip++ < ip_to); + return ret; +} + +/* IPv6 variant */ + +struct hash_netiface6_elem_hashed { + union nf_inet_addr ip; + u8 physdev; + u8 cidr; + u8 nomatch; + u8 elem; +}; + +struct hash_netiface6_elem { + union nf_inet_addr ip; + u8 physdev; + u8 cidr; + u8 nomatch; + u8 elem; + u8 wildcard; + char iface[IFNAMSIZ]; +}; + +/* Common functions */ + +static bool +hash_netiface6_data_equal(const struct hash_netiface6_elem *ip1, + const struct hash_netiface6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ip1->cidr == ip2->cidr && + (++*multi) && + ip1->physdev == ip2->physdev && + (ip1->wildcard ? + strncmp(ip1->iface, ip2->iface, strlen(ip1->iface)) == 0 : + strcmp(ip1->iface, ip2->iface) == 0); +} + +static int +hash_netiface6_do_data_match(const struct hash_netiface6_elem *elem) +{ + return elem->nomatch ? 
-ENOTEMPTY : 1; +} + +static void +hash_netiface6_data_set_flags(struct hash_netiface6_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_netiface6_data_reset_flags(struct hash_netiface6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netiface6_data_netmask(struct hash_netiface6_elem *elem, u8 cidr) +{ + ip6_netmask(&elem->ip, cidr); + elem->cidr = cidr; +} + +static bool +hash_netiface6_data_list(struct sk_buff *skb, + const struct hash_netiface6_elem *data) +{ + u32 flags = (data->physdev ? IPSET_FLAG_PHYSDEV : 0) | + (data->wildcard ? IPSET_FLAG_IFACE_WILDCARD : 0); + + if (data->nomatch) + flags |= IPSET_FLAG_NOMATCH; + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr) || + nla_put_string(skb, IPSET_ATTR_IFACE, data->iface) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netiface6_data_next(struct hash_netiface6_elem *next, + const struct hash_netiface6_elem *d) +{ +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_netiface6 +#define HOST_MASK 128 +#define HKEY_DATALEN sizeof(struct hash_netiface6_elem_hashed) +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_netiface6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + struct hash_netiface6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netiface6_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + .elem = 1, + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + ip6_netmask(&e.ip, e.cidr); + + if (opt->cmdflags & IPSET_FLAG_PHYSDEV) { +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + const char *eiface = SRCDIR ? get_physindev_name(skb, xt_net(par)) : + get_physoutdev_name(skb); + + if (!eiface) + return -EINVAL; + STRLCPY(e.iface, eiface); + e.physdev = 1; +#endif + } else { + STRLCPY(e.iface, SRCDIR ? 
IFACE(in) : IFACE(out)); + } + + if (strlen(e.iface) == 0) + return -EINVAL; + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netiface6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netiface6_elem e = { .cidr = HOST_MASK, .elem = 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !tb[IPSET_ATTR_IFACE] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (e.cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ip6_netmask(&e.ip, e.cidr); + + nla_strscpy(e.iface, tb[IPSET_ATTR_IFACE], IFNAMSIZ); + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_PHYSDEV) + e.physdev = 1; + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + if (cadt_flags & IPSET_FLAG_IFACE_WILDCARD) + e.wildcard = 1; + } + + ret = adtfn(set, &e, &ext, &ext, flags); + + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; +} + +static struct ip_set_type hash_netiface_type __read_mostly = { + .name = "hash:net,iface", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_IFACE | + IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_netiface_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_IFACE] = { .type = NLA_NUL_STRING, + .len = IFNAMSIZ - 1 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_netiface_init(void) +{ + return ip_set_type_register(&hash_netiface_type); +} + +static void __exit +hash_netiface_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_netiface_type); +} + +module_init(hash_netiface_init); +module_exit(hash_netiface_fini); diff --git a/net/netfilter/ipset/ip_set_hash_netnet.c b/net/netfilter/ipset/ip_set_hash_netnet.c new file mode 100644 index 
000000000..c07b70bf3 --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_netnet.c @@ -0,0 +1,514 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik + * Copyright (C) 2013 Oliver Smith + */ + +/* Kernel module implementing an IP set type: the hash:net type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Forceadd support added */ +/* 2 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 3 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Oliver Smith "); +IP_SET_MODULE_DESC("hash:net,net", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:net,net"); + +/* Type specific function prefix */ +#define HTYPE hash_netnet +#define IP_SET_HASH_WITH_NETS +#define IPSET_NET_COUNT 2 + +/* IPv4 variants */ + +/* Member elements */ +struct hash_netnet4_elem { + union { + __be32 ip[2]; + __be64 ipcmp; + }; + u8 nomatch; + u8 padding; + union { + u8 cidr[2]; + u16 ccmp; + }; +}; + +/* Common functions */ + +static bool +hash_netnet4_data_equal(const struct hash_netnet4_elem *ip1, + const struct hash_netnet4_elem *ip2, + u32 *multi) +{ + return ip1->ipcmp == ip2->ipcmp && + ip1->ccmp == ip2->ccmp; +} + +static int +hash_netnet4_do_data_match(const struct hash_netnet4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netnet4_data_set_flags(struct hash_netnet4_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_netnet4_data_reset_flags(struct hash_netnet4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netnet4_data_reset_elem(struct hash_netnet4_elem *elem, + struct hash_netnet4_elem *orig) +{ + elem->ip[1] = orig->ip[1]; +} + +static void +hash_netnet4_data_netmask(struct hash_netnet4_elem *elem, u8 cidr, bool inner) +{ + if (inner) { + elem->ip[1] &= ip_set_netmask(cidr); + elem->cidr[1] = cidr; + } else { + elem->ip[0] &= ip_set_netmask(cidr); + elem->cidr[0] = cidr; + } +} + +static bool +hash_netnet4_data_list(struct sk_buff *skb, + const struct hash_netnet4_elem *data) +{ + u32 flags = data->nomatch ? 
IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip[0]) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP2, data->ip[1]) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr[0]) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr[1]) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netnet4_data_next(struct hash_netnet4_elem *next, + const struct hash_netnet4_elem *d) +{ + next->ipcmp = d->ipcmp; +} + +#define MTYPE hash_netnet4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static void +hash_netnet4_init(struct hash_netnet4_elem *e) +{ + e->cidr[0] = HOST_MASK; + e->cidr[1] = HOST_MASK; +} + +static int +hash_netnet4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netnet4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netnet4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.cidr[0] = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK); + e.cidr[1] = INIT_CIDR(h->nets[0].cidr[1], HOST_MASK); + if (adt == IPSET_TEST) + e.ccmp = (HOST_MASK << (sizeof(e.cidr[0]) * 8)) | HOST_MASK; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip[0]); + ip4addrptr(skb, opt->flags & IPSET_DIM_TWO_SRC, &e.ip[1]); + e.ip[0] &= ip_set_netmask(e.cidr[0]); + e.ip[1] &= ip_set_netmask(e.cidr[1]); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netnet4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_netnet4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netnet4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0; + u32 ip2 = 0, ip2_from = 0, ip2_to = 0, i = 0; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + hash_netnet4_init(&e); + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2], &ip2_from); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr[0] = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!e.cidr[0] || e.cidr[0] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CIDR2]) { + e.cidr[1] = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (!e.cidr[1] || e.cidr[1] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !(tb[IPSET_ATTR_IP_TO] || + tb[IPSET_ATTR_IP2_TO])) { + e.ip[0] = htonl(ip & ip_set_hostmask(e.cidr[0])); + e.ip[1] = htonl(ip2_from & ip_set_hostmask(e.cidr[1])); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 
0 : ret; + } + + ip_to = ip; + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip_to < ip) + swap(ip, ip_to); + if (unlikely(ip + UINT_MAX == ip_to)) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip, ip_to, e.cidr[0]); + } + + ip2_to = ip2_from; + if (tb[IPSET_ATTR_IP2_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2_TO], &ip2_to); + if (ret) + return ret; + if (ip2_to < ip2_from) + swap(ip2_from, ip2_to); + if (unlikely(ip2_from + UINT_MAX == ip2_to)) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip2_from, ip2_to, e.cidr[1]); + } + + if (retried) { + ip = ntohl(h->next.ip[0]); + ip2 = ntohl(h->next.ip[1]); + } else { + ip2 = ip2_from; + } + + do { + e.ip[0] = htonl(ip); + ip = ip_set_range_to_cidr(ip, ip_to, &e.cidr[0]); + do { + i++; + e.ip[1] = htonl(ip2); + if (i > IPSET_MAX_RANGE) { + hash_netnet4_data_next(&h->next, &e); + return -ERANGE; + } + ip2 = ip_set_range_to_cidr(ip2, ip2_to, &e.cidr[1]); + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } while (ip2++ < ip2_to); + ip2 = ip2_from; + } while (ip++ < ip_to); + return ret; +} + +/* IPv6 variants */ + +struct hash_netnet6_elem { + union nf_inet_addr ip[2]; + u8 nomatch; + u8 padding; + union { + u8 cidr[2]; + u16 ccmp; + }; +}; + +/* Common functions */ + +static bool +hash_netnet6_data_equal(const struct hash_netnet6_elem *ip1, + const struct hash_netnet6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip[0].in6, &ip2->ip[0].in6) && + ipv6_addr_equal(&ip1->ip[1].in6, &ip2->ip[1].in6) && + ip1->ccmp == ip2->ccmp; +} + +static int +hash_netnet6_do_data_match(const struct hash_netnet6_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netnet6_data_set_flags(struct hash_netnet6_elem *elem, u32 flags) +{ + elem->nomatch = (flags >> 16) & IPSET_FLAG_NOMATCH; +} + +static void +hash_netnet6_data_reset_flags(struct hash_netnet6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netnet6_data_reset_elem(struct hash_netnet6_elem *elem, + struct hash_netnet6_elem *orig) +{ + elem->ip[1] = orig->ip[1]; +} + +static void +hash_netnet6_data_netmask(struct hash_netnet6_elem *elem, u8 cidr, bool inner) +{ + if (inner) { + ip6_netmask(&elem->ip[1], cidr); + elem->cidr[1] = cidr; + } else { + ip6_netmask(&elem->ip[0], cidr); + elem->cidr[0] = cidr; + } +} + +static bool +hash_netnet6_data_list(struct sk_buff *skb, + const struct hash_netnet6_elem *data) +{ + u32 flags = data->nomatch ? 
IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip[0].in6) || + nla_put_ipaddr6(skb, IPSET_ATTR_IP2, &data->ip[1].in6) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr[0]) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr[1]) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netnet6_data_next(struct hash_netnet6_elem *next, + const struct hash_netnet6_elem *d) +{ +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_netnet6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static void +hash_netnet6_init(struct hash_netnet6_elem *e) +{ + e->cidr[0] = HOST_MASK; + e->cidr[1] = HOST_MASK; +} + +static int +hash_netnet6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netnet6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netnet6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.cidr[0] = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK); + e.cidr[1] = INIT_CIDR(h->nets[0].cidr[1], HOST_MASK); + if (adt == IPSET_TEST) + e.ccmp = (HOST_MASK << (sizeof(u8) * 8)) | HOST_MASK; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip[0].in6); + ip6addrptr(skb, opt->flags & IPSET_DIM_TWO_SRC, &e.ip[1].in6); + ip6_netmask(&e.ip[0], e.cidr[0]); + ip6_netmask(&e.ip[1], e.cidr[1]); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netnet6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netnet6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + hash_netnet6_init(&e); + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_IP2_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip[0]); + if (ret) + return ret; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP2], &e.ip[1]); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr[0] = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!e.cidr[0] || e.cidr[0] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CIDR2]) { + e.cidr[1] = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (!e.cidr[1] || e.cidr[1] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ip6_netmask(&e.ip[0], e.cidr[0]); + ip6_netmask(&e.ip[1], e.cidr[1]); + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + ret = adtfn(set, &e, &ext, &ext, flags); + + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 
0 : ret; +} + +static struct ip_set_type hash_netnet_type __read_mostly = { + .name = "hash:net,net", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_IP2 | IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_netnet_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_CIDR2] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_netnet_init(void) +{ + return ip_set_type_register(&hash_netnet_type); +} + +static void __exit +hash_netnet_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_netnet_type); +} + +module_init(hash_netnet_init); +module_exit(hash_netnet_fini); diff --git a/net/netfilter/ipset/ip_set_hash_netport.c b/net/netfilter/ipset/ip_set_hash_netport.c new file mode 100644 index 000000000..d1a0628df --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_netport.c @@ -0,0 +1,515 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:net,port type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 SCTP and UDPLITE support added */ +/* 2 Range as input support for IPv4 added */ +/* 3 nomatch flag support added */ +/* 4 Counters support added */ +/* 5 Comments support added */ +/* 6 Forceadd support added */ +/* 7 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 8 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("hash:net,port", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:net,port"); + +/* Type specific function prefix */ +#define HTYPE hash_netport +#define IP_SET_HASH_WITH_PROTO +#define IP_SET_HASH_WITH_NETS + +/* We squeeze the "nomatch" flag into cidr: we don't support cidr == 0 + * However this way we have to store internally cidr - 1, + * dancing back and forth. 
+ */ +#define IP_SET_HASH_WITH_NETS_PACKED + +/* IPv4 variant */ + +/* Member elements */ +struct hash_netport4_elem { + __be32 ip; + __be16 port; + u8 proto; + u8 cidr:7; + u8 nomatch:1; +}; + +/* Common functions */ + +static bool +hash_netport4_data_equal(const struct hash_netport4_elem *ip1, + const struct hash_netport4_elem *ip2, + u32 *multi) +{ + return ip1->ip == ip2->ip && + ip1->port == ip2->port && + ip1->proto == ip2->proto && + ip1->cidr == ip2->cidr; +} + +static int +hash_netport4_do_data_match(const struct hash_netport4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netport4_data_set_flags(struct hash_netport4_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_netport4_data_reset_flags(struct hash_netport4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netport4_data_netmask(struct hash_netport4_elem *elem, u8 cidr) +{ + elem->ip &= ip_set_netmask(cidr); + elem->cidr = cidr - 1; +} + +static bool +hash_netport4_data_list(struct sk_buff *skb, + const struct hash_netport4_elem *data) +{ + u32 flags = data->nomatch ? IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr + 1) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netport4_data_next(struct hash_netport4_elem *next, + const struct hash_netport4_elem *d) +{ + next->ip = d->ip; + next->port = d->port; +} + +#define MTYPE hash_netport4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static int +hash_netport4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netport4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netport4_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK - 1; + + if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip); + e.ip &= ip_set_netmask(e.cidr + 1); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netport4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_netport4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netport4_elem e = { .cidr = HOST_MASK - 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to, p = 0, ip = 0, ip_to = 0, i = 0; + bool with_ports = false; + u8 cidr; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!cidr || cidr > HOST_MASK) + return 
-IPSET_ERR_INVALID_CIDR; + e.cidr = cidr - 1; + } + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMP)) + e.port = 0; + + with_ports = with_ports && tb[IPSET_ATTR_PORT_TO]; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !(with_ports || tb[IPSET_ATTR_IP_TO])) { + e.ip = htonl(ip & ip_set_hostmask(e.cidr + 1)); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + port = port_to = ntohs(e.port); + if (tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port_to < port) + swap(port, port_to); + } + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip_to < ip) + swap(ip, ip_to); + if (ip + UINT_MAX == ip_to) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip, ip_to, e.cidr + 1); + } + + if (retried) { + ip = ntohl(h->next.ip); + p = ntohs(h->next.port); + } else { + p = port; + } + do { + e.ip = htonl(ip); + ip = ip_set_range_to_cidr(ip, ip_to, &cidr); + e.cidr = cidr - 1; + for (; p <= port_to; p++, i++) { + e.port = htons(p); + if (i > IPSET_MAX_RANGE) { + hash_netport4_data_next(&h->next, &e); + return -ERANGE; + } + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + p = port; + } while (ip++ < ip_to); + return ret; +} + +/* IPv6 variant */ + +struct hash_netport6_elem { + union nf_inet_addr ip; + __be16 port; + u8 proto; + u8 cidr:7; + u8 nomatch:1; +}; + +/* Common functions */ + +static bool +hash_netport6_data_equal(const struct hash_netport6_elem *ip1, + const struct hash_netport6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip.in6, &ip2->ip.in6) && + ip1->port == ip2->port && + ip1->proto == ip2->proto && + ip1->cidr == ip2->cidr; +} + +static int +hash_netport6_do_data_match(const struct hash_netport6_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netport6_data_set_flags(struct hash_netport6_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_netport6_data_reset_flags(struct hash_netport6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netport6_data_netmask(struct hash_netport6_elem *elem, u8 cidr) +{ + ip6_netmask(&elem->ip, cidr); + elem->cidr = cidr - 1; +} + +static bool +hash_netport6_data_list(struct sk_buff *skb, + const struct hash_netport6_elem *data) +{ + u32 flags = data->nomatch ? 
IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip.in6) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr + 1) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netport6_data_next(struct hash_netport6_elem *next, + const struct hash_netport6_elem *d) +{ + next->port = d->port; +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_netport6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static int +hash_netport6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netport6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netport6_elem e = { + .cidr = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK), + }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + if (adt == IPSET_TEST) + e.cidr = HOST_MASK - 1; + + if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip.in6); + ip6_netmask(&e.ip, e.cidr + 1); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netport6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_netport6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netport6_elem e = { .cidr = HOST_MASK - 1 }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to; + bool with_ports = false; + u8 cidr; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_IP] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + cidr = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (!cidr || cidr > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + e.cidr = cidr - 1; + } + ip6_netmask(&e.ip, e.cidr + 1); + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMPV6)) + e.port = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !with_ports || !tb[IPSET_ATTR_PORT_TO]) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 
0 : ret; + } + + port = ntohs(e.port); + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + + if (retried) + port = ntohs(h->next.port); + for (; port <= port_to; port++) { + e.port = htons(port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static struct ip_set_type hash_netport_type __read_mostly = { + .name = "hash:net,port", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_TWO, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_netport_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_netport_init(void) +{ + return ip_set_type_register(&hash_netport_type); +} + +static void __exit +hash_netport_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_netport_type); +} + +module_init(hash_netport_init); +module_exit(hash_netport_fini); diff --git a/net/netfilter/ipset/ip_set_hash_netportnet.c b/net/netfilter/ipset/ip_set_hash_netportnet.c new file mode 100644 index 000000000..bf4f91b78 --- /dev/null +++ b/net/netfilter/ipset/ip_set_hash_netportnet.c @@ -0,0 +1,628 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2003-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the hash:ip,port,net type */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 0 Comments support added */ +/* 1 Forceadd support added */ +/* 2 skbinfo support added */ +#define IPSET_TYPE_REV_MAX 3 /* bucketsize, initval support added */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Oliver Smith "); +IP_SET_MODULE_DESC("hash:net,port,net", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_hash:net,port,net"); + +/* Type specific function prefix */ +#define HTYPE hash_netportnet +#define IP_SET_HASH_WITH_PROTO +#define IP_SET_HASH_WITH_NETS +#define IPSET_NET_COUNT 2 +#define IP_SET_HASH_WITH_NET0 + +/* IPv4 variant */ + +/* Member elements */ +struct hash_netportnet4_elem { + union { + __be32 ip[2]; + __be64 ipcmp; + }; + __be16 port; + union { + u8 cidr[2]; + u16 
ccmp; + }; + u16 padding; + u8 nomatch; + u8 proto; +}; + +/* Common functions */ + +static bool +hash_netportnet4_data_equal(const struct hash_netportnet4_elem *ip1, + const struct hash_netportnet4_elem *ip2, + u32 *multi) +{ + return ip1->ipcmp == ip2->ipcmp && + ip1->ccmp == ip2->ccmp && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static int +hash_netportnet4_do_data_match(const struct hash_netportnet4_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netportnet4_data_set_flags(struct hash_netportnet4_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_netportnet4_data_reset_flags(struct hash_netportnet4_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netportnet4_data_reset_elem(struct hash_netportnet4_elem *elem, + struct hash_netportnet4_elem *orig) +{ + elem->ip[1] = orig->ip[1]; +} + +static void +hash_netportnet4_data_netmask(struct hash_netportnet4_elem *elem, + u8 cidr, bool inner) +{ + if (inner) { + elem->ip[1] &= ip_set_netmask(cidr); + elem->cidr[1] = cidr; + } else { + elem->ip[0] &= ip_set_netmask(cidr); + elem->cidr[0] = cidr; + } +} + +static bool +hash_netportnet4_data_list(struct sk_buff *skb, + const struct hash_netportnet4_elem *data) +{ + u32 flags = data->nomatch ? IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr4(skb, IPSET_ATTR_IP, data->ip[0]) || + nla_put_ipaddr4(skb, IPSET_ATTR_IP2, data->ip[1]) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr[0]) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr[1]) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netportnet4_data_next(struct hash_netportnet4_elem *next, + const struct hash_netportnet4_elem *d) +{ + next->ipcmp = d->ipcmp; + next->port = d->port; +} + +#define MTYPE hash_netportnet4 +#define HOST_MASK 32 +#include "ip_set_hash_gen.h" + +static void +hash_netportnet4_init(struct hash_netportnet4_elem *e) +{ + e->cidr[0] = HOST_MASK; + e->cidr[1] = HOST_MASK; +} + +static int +hash_netportnet4_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netportnet4 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netportnet4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.cidr[0] = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK); + e.cidr[1] = INIT_CIDR(h->nets[0].cidr[1], HOST_MASK); + if (adt == IPSET_TEST) + e.ccmp = (HOST_MASK << (sizeof(e.cidr[0]) * 8)) | HOST_MASK; + + if (!ip_set_get_ip4_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip4addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip[0]); + ip4addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip[1]); + e.ip[0] &= ip_set_netmask(e.cidr[0]); + e.ip[1] &= ip_set_netmask(e.cidr[1]); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static u32 +hash_netportnet4_range_to_cidr(u32 from, u32 to, u8 *cidr) +{ + if (from == 0 && to == UINT_MAX) { + *cidr = 0; + return to; + } + return ip_set_range_to_cidr(from, to, cidr); +} + +static int +hash_netportnet4_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct hash_netportnet4 *h = 
set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netportnet4_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 ip = 0, ip_to = 0, p = 0, port, port_to; + u32 ip2_from = 0, ip2_to = 0, ip2, i = 0; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + hash_netportnet4_init(&e); + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP], &ip); + if (ret) + return ret; + + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2], &ip2_from); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr[0] = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (e.cidr[0] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CIDR2]) { + e.cidr[1] = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (e.cidr[1] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return -IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMP)) + e.port = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + with_ports = with_ports && tb[IPSET_ATTR_PORT_TO]; + if (adt == IPSET_TEST || + !(tb[IPSET_ATTR_IP_TO] || with_ports || tb[IPSET_ATTR_IP2_TO])) { + e.ip[0] = htonl(ip & ip_set_hostmask(e.cidr[0])); + e.ip[1] = htonl(ip2_from & ip_set_hostmask(e.cidr[1])); + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 
0 : ret; + } + + ip_to = ip; + if (tb[IPSET_ATTR_IP_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP_TO], &ip_to); + if (ret) + return ret; + if (ip > ip_to) + swap(ip, ip_to); + if (unlikely(ip + UINT_MAX == ip_to)) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip, ip_to, e.cidr[0]); + } + + port_to = port = ntohs(e.port); + if (tb[IPSET_ATTR_PORT_TO]) { + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + } + + ip2_to = ip2_from; + if (tb[IPSET_ATTR_IP2_TO]) { + ret = ip_set_get_hostipaddr4(tb[IPSET_ATTR_IP2_TO], &ip2_to); + if (ret) + return ret; + if (ip2_from > ip2_to) + swap(ip2_from, ip2_to); + if (unlikely(ip2_from + UINT_MAX == ip2_to)) + return -IPSET_ERR_HASH_RANGE; + } else { + ip_set_mask_from_to(ip2_from, ip2_to, e.cidr[1]); + } + + if (retried) { + ip = ntohl(h->next.ip[0]); + p = ntohs(h->next.port); + ip2 = ntohl(h->next.ip[1]); + } else { + p = port; + ip2 = ip2_from; + } + + do { + e.ip[0] = htonl(ip); + ip = hash_netportnet4_range_to_cidr(ip, ip_to, &e.cidr[0]); + for (; p <= port_to; p++) { + e.port = htons(p); + do { + i++; + e.ip[1] = htonl(ip2); + if (i > IPSET_MAX_RANGE) { + hash_netportnet4_data_next(&h->next, + &e); + return -ERANGE; + } + ip2 = hash_netportnet4_range_to_cidr(ip2, + ip2_to, &e.cidr[1]); + ret = adtfn(set, &e, &ext, &ext, flags); + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } while (ip2++ < ip2_to); + ip2 = ip2_from; + } + p = port; + } while (ip++ < ip_to); + return ret; +} + +/* IPv6 variant */ + +struct hash_netportnet6_elem { + union nf_inet_addr ip[2]; + __be16 port; + union { + u8 cidr[2]; + u16 ccmp; + }; + u16 padding; + u8 nomatch; + u8 proto; +}; + +/* Common functions */ + +static bool +hash_netportnet6_data_equal(const struct hash_netportnet6_elem *ip1, + const struct hash_netportnet6_elem *ip2, + u32 *multi) +{ + return ipv6_addr_equal(&ip1->ip[0].in6, &ip2->ip[0].in6) && + ipv6_addr_equal(&ip1->ip[1].in6, &ip2->ip[1].in6) && + ip1->ccmp == ip2->ccmp && + ip1->port == ip2->port && + ip1->proto == ip2->proto; +} + +static int +hash_netportnet6_do_data_match(const struct hash_netportnet6_elem *elem) +{ + return elem->nomatch ? -ENOTEMPTY : 1; +} + +static void +hash_netportnet6_data_set_flags(struct hash_netportnet6_elem *elem, u32 flags) +{ + elem->nomatch = !!((flags >> 16) & IPSET_FLAG_NOMATCH); +} + +static void +hash_netportnet6_data_reset_flags(struct hash_netportnet6_elem *elem, u8 *flags) +{ + swap(*flags, elem->nomatch); +} + +static void +hash_netportnet6_data_reset_elem(struct hash_netportnet6_elem *elem, + struct hash_netportnet6_elem *orig) +{ + elem->ip[1] = orig->ip[1]; +} + +static void +hash_netportnet6_data_netmask(struct hash_netportnet6_elem *elem, + u8 cidr, bool inner) +{ + if (inner) { + ip6_netmask(&elem->ip[1], cidr); + elem->cidr[1] = cidr; + } else { + ip6_netmask(&elem->ip[0], cidr); + elem->cidr[0] = cidr; + } +} + +static bool +hash_netportnet6_data_list(struct sk_buff *skb, + const struct hash_netportnet6_elem *data) +{ + u32 flags = data->nomatch ? 
IPSET_FLAG_NOMATCH : 0; + + if (nla_put_ipaddr6(skb, IPSET_ATTR_IP, &data->ip[0].in6) || + nla_put_ipaddr6(skb, IPSET_ATTR_IP2, &data->ip[1].in6) || + nla_put_net16(skb, IPSET_ATTR_PORT, data->port) || + nla_put_u8(skb, IPSET_ATTR_CIDR, data->cidr[0]) || + nla_put_u8(skb, IPSET_ATTR_CIDR2, data->cidr[1]) || + nla_put_u8(skb, IPSET_ATTR_PROTO, data->proto) || + (flags && + nla_put_net32(skb, IPSET_ATTR_CADT_FLAGS, htonl(flags)))) + goto nla_put_failure; + return false; + +nla_put_failure: + return true; +} + +static void +hash_netportnet6_data_next(struct hash_netportnet6_elem *next, + const struct hash_netportnet6_elem *d) +{ + next->port = d->port; +} + +#undef MTYPE +#undef HOST_MASK + +#define MTYPE hash_netportnet6 +#define HOST_MASK 128 +#define IP_SET_EMIT_CREATE +#include "ip_set_hash_gen.h" + +static void +hash_netportnet6_init(struct hash_netportnet6_elem *e) +{ + e->cidr[0] = HOST_MASK; + e->cidr[1] = HOST_MASK; +} + +static int +hash_netportnet6_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + const struct hash_netportnet6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netportnet6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + + e.cidr[0] = INIT_CIDR(h->nets[0].cidr[0], HOST_MASK); + e.cidr[1] = INIT_CIDR(h->nets[0].cidr[1], HOST_MASK); + if (adt == IPSET_TEST) + e.ccmp = (HOST_MASK << (sizeof(u8) * 8)) | HOST_MASK; + + if (!ip_set_get_ip6_port(skb, opt->flags & IPSET_DIM_TWO_SRC, + &e.port, &e.proto)) + return -EINVAL; + + ip6addrptr(skb, opt->flags & IPSET_DIM_ONE_SRC, &e.ip[0].in6); + ip6addrptr(skb, opt->flags & IPSET_DIM_THREE_SRC, &e.ip[1].in6); + ip6_netmask(&e.ip[0], e.cidr[0]); + ip6_netmask(&e.ip[1], e.cidr[1]); + + return adtfn(set, &e, &ext, &opt->ext, opt->cmdflags); +} + +static int +hash_netportnet6_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + const struct hash_netportnet6 *h = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct hash_netportnet6_elem e = { }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + u32 port, port_to; + bool with_ports = false; + int ret; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + hash_netportnet6_init(&e); + if (unlikely(!tb[IPSET_ATTR_IP] || !tb[IPSET_ATTR_IP2] || + !ip_set_attr_netorder(tb, IPSET_ATTR_PORT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_PORT_TO) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + if (unlikely(tb[IPSET_ATTR_IP_TO] || tb[IPSET_ATTR_IP2_TO])) + return -IPSET_ERR_HASH_RANGE_UNSUPPORTED; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP], &e.ip[0]); + if (ret) + return ret; + + ret = ip_set_get_ipaddr6(tb[IPSET_ATTR_IP2], &e.ip[1]); + if (ret) + return ret; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + + if (tb[IPSET_ATTR_CIDR]) { + e.cidr[0] = nla_get_u8(tb[IPSET_ATTR_CIDR]); + if (e.cidr[0] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + if (tb[IPSET_ATTR_CIDR2]) { + e.cidr[1] = nla_get_u8(tb[IPSET_ATTR_CIDR2]); + if (e.cidr[1] > HOST_MASK) + return -IPSET_ERR_INVALID_CIDR; + } + + ip6_netmask(&e.ip[0], e.cidr[0]); + ip6_netmask(&e.ip[1], e.cidr[1]); + + e.port = nla_get_be16(tb[IPSET_ATTR_PORT]); + + if (tb[IPSET_ATTR_PROTO]) { + e.proto = nla_get_u8(tb[IPSET_ATTR_PROTO]); + with_ports = ip_set_proto_with_ports(e.proto); + + if (e.proto == 0) + return 
-IPSET_ERR_INVALID_PROTO; + } else { + return -IPSET_ERR_MISSING_PROTO; + } + + if (!(with_ports || e.proto == IPPROTO_ICMPV6)) + e.port = 0; + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + if (cadt_flags & IPSET_FLAG_NOMATCH) + flags |= (IPSET_FLAG_NOMATCH << 16); + } + + if (adt == IPSET_TEST || !with_ports || !tb[IPSET_ATTR_PORT_TO]) { + ret = adtfn(set, &e, &ext, &ext, flags); + return ip_set_enomatch(ret, flags, adt, set) ? -ret : + ip_set_eexist(ret, flags) ? 0 : ret; + } + + port = ntohs(e.port); + port_to = ip_set_get_h16(tb[IPSET_ATTR_PORT_TO]); + if (port > port_to) + swap(port, port_to); + + if (retried) + port = ntohs(h->next.port); + for (; port <= port_to; port++) { + e.port = htons(port); + ret = adtfn(set, &e, &ext, &ext, flags); + + if (ret && !ip_set_eexist(ret, flags)) + return ret; + + ret = 0; + } + return ret; +} + +static struct ip_set_type hash_netportnet_type __read_mostly = { + .name = "hash:net,port,net", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_TYPE_IP2 | + IPSET_TYPE_NOMATCH, + .dimension = IPSET_DIM_THREE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create_flags[IPSET_TYPE_REV_MAX] = IPSET_CREATE_FLAG_BUCKETSIZE, + .create = hash_netportnet_create, + .create_policy = { + [IPSET_ATTR_HASHSIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_MAXELEM] = { .type = NLA_U32 }, + [IPSET_ATTR_INITVAL] = { .type = NLA_U32 }, + [IPSET_ATTR_BUCKETSIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_RESIZE] = { .type = NLA_U8 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_IP] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2] = { .type = NLA_NESTED }, + [IPSET_ATTR_IP2_TO] = { .type = NLA_NESTED }, + [IPSET_ATTR_PORT] = { .type = NLA_U16 }, + [IPSET_ATTR_PORT_TO] = { .type = NLA_U16 }, + [IPSET_ATTR_CIDR] = { .type = NLA_U8 }, + [IPSET_ATTR_CIDR2] = { .type = NLA_U8 }, + [IPSET_ATTR_PROTO] = { .type = NLA_U8 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +hash_netportnet_init(void) +{ + return ip_set_type_register(&hash_netportnet_type); +} + +static void __exit +hash_netportnet_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&hash_netportnet_type); +} + +module_init(hash_netportnet_init); +module_exit(hash_netportnet_fini); diff --git a/net/netfilter/ipset/ip_set_list_set.c b/net/netfilter/ipset/ip_set_list_set.c new file mode 100644 index 000000000..5a67f7966 --- /dev/null +++ b/net/netfilter/ipset/ip_set_list_set.c @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2008-2013 Jozsef Kadlecsik */ + +/* Kernel module implementing an IP set type: the list:set type */ + +#include +#include +#include +#include +#include + +#include +#include + +#define IPSET_TYPE_REV_MIN 0 +/* 1 Counters support added */ +/* 2 Comments support added */ +#define IPSET_TYPE_REV_MAX 3 /* skbinfo support added */ + +MODULE_LICENSE("GPL"); 
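
A note on the cidr[2]/ccmp union used by the hash:net,port,net element structures above: packing the two prefix lengths into one 16-bit field lets the equality test and the kadt initialisation (e.ccmp = (HOST_MASK << (sizeof(e.cidr[0]) * 8)) | HOST_MASK) touch both values in a single operation. The following is a minimal userspace C sketch of that trick, not kernel code; the struct and field names are illustrative only.

#include <stdint.h>
#include <stdio.h>

/* Illustrative stand-in for the cidr[2]/ccmp union of the
 * hash:net,port,net elements: two 8-bit prefix lengths that can be
 * read, written and compared as a single 16-bit value.
 */
struct two_cidr {
	union {
		uint8_t cidr[2];	/* [0] = first net, [1] = second net */
		uint16_t ccmp;		/* both prefixes at once */
	};
};

int main(void)
{
	struct two_cidr a, b;

	a.cidr[0] = 24;
	a.cidr[1] = 16;

	/* Same effect as e.ccmp = (HOST_MASK << (sizeof(e.cidr[0]) * 8)) | HOST_MASK
	 * with HOST_MASK == 32: both prefixes become /32 in one store.
	 * (Both bytes carry the same value, so host endianness does not matter.)
	 */
	b.ccmp = (32 << 8) | 32;

	printf("a = /%u,/%u  b = /%u,/%u  equal: %d\n",
	       a.cidr[0], a.cidr[1], b.cidr[0], b.cidr[1], a.ccmp == b.ccmp);
	return 0;
}
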
+MODULE_AUTHOR("Jozsef Kadlecsik "); +IP_SET_MODULE_DESC("list:set", IPSET_TYPE_REV_MIN, IPSET_TYPE_REV_MAX); +MODULE_ALIAS("ip_set_list:set"); + +/* Member elements */ +struct set_elem { + struct rcu_head rcu; + struct list_head list; + struct ip_set *set; /* Sigh, in order to cleanup reference */ + ip_set_id_t id; +} __aligned(__alignof__(u64)); + +struct set_adt_elem { + ip_set_id_t id; + ip_set_id_t refid; + int before; +}; + +/* Type structure */ +struct list_set { + u32 size; /* size of set list array */ + struct timer_list gc; /* garbage collection */ + struct ip_set *set; /* attached to this ip_set */ + struct net *net; /* namespace */ + struct list_head members; /* the set members */ +}; + +static int +list_set_ktest(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + struct ip_set_adt_opt *opt, const struct ip_set_ext *ext) +{ + struct list_set *map = set->data; + struct ip_set_ext *mext = &opt->ext; + struct set_elem *e; + u32 flags = opt->cmdflags; + int ret; + + /* Don't lookup sub-counters at all */ + opt->cmdflags &= ~IPSET_FLAG_MATCH_COUNTERS; + if (opt->cmdflags & IPSET_FLAG_SKIP_SUBCOUNTER_UPDATE) + opt->cmdflags |= IPSET_FLAG_SKIP_COUNTER_UPDATE; + list_for_each_entry_rcu(e, &map->members, list) { + ret = ip_set_test(e->id, skb, par, opt); + if (ret <= 0) + continue; + if (ip_set_match_extensions(set, ext, mext, flags, e)) + return 1; + } + return 0; +} + +static int +list_set_kadd(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + struct ip_set_adt_opt *opt, const struct ip_set_ext *ext) +{ + struct list_set *map = set->data; + struct set_elem *e; + int ret; + + list_for_each_entry(e, &map->members, list) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set))) + continue; + ret = ip_set_add(e->id, skb, par, opt); + if (ret == 0) + return ret; + } + return 0; +} + +static int +list_set_kdel(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + struct ip_set_adt_opt *opt, const struct ip_set_ext *ext) +{ + struct list_set *map = set->data; + struct set_elem *e; + int ret; + + list_for_each_entry(e, &map->members, list) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set))) + continue; + ret = ip_set_del(e->id, skb, par, opt); + if (ret == 0) + return ret; + } + return 0; +} + +static int +list_set_kadt(struct ip_set *set, const struct sk_buff *skb, + const struct xt_action_param *par, + enum ipset_adt adt, struct ip_set_adt_opt *opt) +{ + struct ip_set_ext ext = IP_SET_INIT_KEXT(skb, opt, set); + int ret = -EINVAL; + + rcu_read_lock(); + switch (adt) { + case IPSET_TEST: + ret = list_set_ktest(set, skb, par, opt, &ext); + break; + case IPSET_ADD: + ret = list_set_kadd(set, skb, par, opt, &ext); + break; + case IPSET_DEL: + ret = list_set_kdel(set, skb, par, opt, &ext); + break; + default: + break; + } + rcu_read_unlock(); + + return ret; +} + +/* Userspace interfaces: we are protected by the nfnl mutex */ + +static void +__list_set_del_rcu(struct rcu_head * rcu) +{ + struct set_elem *e = container_of(rcu, struct set_elem, rcu); + struct ip_set *set = e->set; + + ip_set_ext_destroy(set, e); + kfree(e); +} + +static void +list_set_del(struct ip_set *set, struct set_elem *e) +{ + struct list_set *map = set->data; + + set->elements--; + list_del_rcu(&e->list); + ip_set_put_byindex(map->net, e->id); + call_rcu(&e->rcu, __list_set_del_rcu); +} + +static void +list_set_replace(struct ip_set *set, struct set_elem *e, struct set_elem 
*old) +{ + struct list_set *map = set->data; + + list_replace_rcu(&old->list, &e->list); + ip_set_put_byindex(map->net, old->id); + call_rcu(&old->rcu, __list_set_del_rcu); +} + +static void +set_cleanup_entries(struct ip_set *set) +{ + struct list_set *map = set->data; + struct set_elem *e, *n; + + list_for_each_entry_safe(e, n, &map->members, list) + if (ip_set_timeout_expired(ext_timeout(e, set))) + list_set_del(set, e); +} + +static int +list_set_utest(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct list_set *map = set->data; + struct set_adt_elem *d = value; + struct set_elem *e, *next, *prev = NULL; + int ret; + + list_for_each_entry(e, &map->members, list) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set))) + continue; + else if (e->id != d->id) { + prev = e; + continue; + } + + if (d->before == 0) { + ret = 1; + } else if (d->before > 0) { + next = list_next_entry(e, list); + ret = !list_is_last(&e->list, &map->members) && + next->id == d->refid; + } else { + ret = prev && prev->id == d->refid; + } + return ret; + } + return 0; +} + +static void +list_set_init_extensions(struct ip_set *set, const struct ip_set_ext *ext, + struct set_elem *e) +{ + if (SET_WITH_COUNTER(set)) + ip_set_init_counter(ext_counter(e, set), ext); + if (SET_WITH_COMMENT(set)) + ip_set_init_comment(set, ext_comment(e, set), ext); + if (SET_WITH_SKBINFO(set)) + ip_set_init_skbinfo(ext_skbinfo(e, set), ext); + /* Update timeout last */ + if (SET_WITH_TIMEOUT(set)) + ip_set_timeout_set(ext_timeout(e, set), ext->timeout); +} + +static int +list_set_uadd(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct list_set *map = set->data; + struct set_adt_elem *d = value; + struct set_elem *e, *n, *prev, *next; + bool flag_exist = flags & IPSET_FLAG_EXIST; + + /* Find where to add the new entry */ + n = prev = next = NULL; + list_for_each_entry(e, &map->members, list) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set))) + continue; + else if (d->id == e->id) + n = e; + else if (d->before == 0 || e->id != d->refid) + continue; + else if (d->before > 0) + next = e; + else + prev = e; + } + + /* If before/after is used on an empty set */ + if ((d->before > 0 && !next) || + (d->before < 0 && !prev)) + return -IPSET_ERR_REF_EXIST; + + /* Re-add already existing element */ + if (n) { + if (!flag_exist) + return -IPSET_ERR_EXIST; + /* Update extensions */ + ip_set_ext_destroy(set, n); + list_set_init_extensions(set, ext, n); + + /* Set is already added to the list */ + ip_set_put_byindex(map->net, d->id); + return 0; + } + /* Add new entry */ + if (d->before == 0) { + /* Append */ + n = list_empty(&map->members) ? NULL : + list_last_entry(&map->members, struct set_elem, list); + } else if (d->before > 0) { + /* Insert after next element */ + if (!list_is_last(&next->list, &map->members)) + n = list_next_entry(next, list); + } else { + /* Insert before prev element */ + if (prev->list.prev != &map->members) + n = list_prev_entry(prev, list); + } + /* Can we replace a timed out entry? 
*/ + if (n && + !(SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(n, set)))) + n = NULL; + + e = kzalloc(set->dsize, GFP_ATOMIC); + if (!e) + return -ENOMEM; + e->id = d->id; + e->set = set; + INIT_LIST_HEAD(&e->list); + list_set_init_extensions(set, ext, e); + if (n) + list_set_replace(set, e, n); + else if (next) + list_add_tail_rcu(&e->list, &next->list); + else if (prev) + list_add_rcu(&e->list, &prev->list); + else + list_add_tail_rcu(&e->list, &map->members); + set->elements++; + + return 0; +} + +static int +list_set_udel(struct ip_set *set, void *value, const struct ip_set_ext *ext, + struct ip_set_ext *mext, u32 flags) +{ + struct list_set *map = set->data; + struct set_adt_elem *d = value; + struct set_elem *e, *next, *prev = NULL; + + list_for_each_entry(e, &map->members, list) { + if (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set))) + continue; + else if (e->id != d->id) { + prev = e; + continue; + } + + if (d->before > 0) { + next = list_next_entry(e, list); + if (list_is_last(&e->list, &map->members) || + next->id != d->refid) + return -IPSET_ERR_REF_EXIST; + } else if (d->before < 0) { + if (!prev || prev->id != d->refid) + return -IPSET_ERR_REF_EXIST; + } + list_set_del(set, e); + return 0; + } + return d->before != 0 ? -IPSET_ERR_REF_EXIST : -IPSET_ERR_EXIST; +} + +static int +list_set_uadt(struct ip_set *set, struct nlattr *tb[], + enum ipset_adt adt, u32 *lineno, u32 flags, bool retried) +{ + struct list_set *map = set->data; + ipset_adtfn adtfn = set->variant->adt[adt]; + struct set_adt_elem e = { .refid = IPSET_INVALID_ID }; + struct ip_set_ext ext = IP_SET_INIT_UEXT(set); + struct ip_set *s; + int ret = 0; + + if (tb[IPSET_ATTR_LINENO]) + *lineno = nla_get_u32(tb[IPSET_ATTR_LINENO]); + + if (unlikely(!tb[IPSET_ATTR_NAME] || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + ret = ip_set_get_extensions(set, tb, &ext); + if (ret) + return ret; + e.id = ip_set_get_byname(map->net, nla_data(tb[IPSET_ATTR_NAME]), &s); + if (e.id == IPSET_INVALID_ID) + return -IPSET_ERR_NAME; + /* "Loop detection" */ + if (s->type->features & IPSET_TYPE_NAME) { + ret = -IPSET_ERR_LOOP; + goto finish; + } + + if (tb[IPSET_ATTR_CADT_FLAGS]) { + u32 f = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]); + + e.before = f & IPSET_FLAG_BEFORE; + } + + if (e.before && !tb[IPSET_ATTR_NAMEREF]) { + ret = -IPSET_ERR_BEFORE; + goto finish; + } + + if (tb[IPSET_ATTR_NAMEREF]) { + e.refid = ip_set_get_byname(map->net, + nla_data(tb[IPSET_ATTR_NAMEREF]), + &s); + if (e.refid == IPSET_INVALID_ID) { + ret = -IPSET_ERR_NAMEREF; + goto finish; + } + if (!e.before) + e.before = -1; + } + if (adt != IPSET_TEST && SET_WITH_TIMEOUT(set)) + set_cleanup_entries(set); + + ret = adtfn(set, &e, &ext, &ext, flags); + +finish: + if (e.refid != IPSET_INVALID_ID) + ip_set_put_byindex(map->net, e.refid); + if (adt != IPSET_ADD || ret) + ip_set_put_byindex(map->net, e.id); + + return ip_set_eexist(ret, flags) ? 
0 : ret; +} + +static void +list_set_flush(struct ip_set *set) +{ + struct list_set *map = set->data; + struct set_elem *e, *n; + + list_for_each_entry_safe(e, n, &map->members, list) + list_set_del(set, e); + set->elements = 0; + set->ext_size = 0; +} + +static void +list_set_destroy(struct ip_set *set) +{ + struct list_set *map = set->data; + struct set_elem *e, *n; + + if (SET_WITH_TIMEOUT(set)) + del_timer_sync(&map->gc); + + list_for_each_entry_safe(e, n, &map->members, list) { + list_del(&e->list); + ip_set_put_byindex(map->net, e->id); + ip_set_ext_destroy(set, e); + kfree(e); + } + kfree(map); + + set->data = NULL; +} + +/* Calculate the actual memory size of the set data */ +static size_t +list_set_memsize(const struct list_set *map, size_t dsize) +{ + struct set_elem *e; + u32 n = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(e, &map->members, list) + n++; + rcu_read_unlock(); + + return (sizeof(*map) + n * dsize); +} + +static int +list_set_head(struct ip_set *set, struct sk_buff *skb) +{ + const struct list_set *map = set->data; + struct nlattr *nested; + size_t memsize = list_set_memsize(map, set->dsize) + set->ext_size; + + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) + goto nla_put_failure; + if (nla_put_net32(skb, IPSET_ATTR_SIZE, htonl(map->size)) || + nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) || + nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) || + nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements))) + goto nla_put_failure; + if (unlikely(ip_set_put_flags(skb, set))) + goto nla_put_failure; + nla_nest_end(skb, nested); + + return 0; +nla_put_failure: + return -EMSGSIZE; +} + +static int +list_set_list(const struct ip_set *set, + struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct list_set *map = set->data; + struct nlattr *atd, *nested; + u32 i = 0, first = cb->args[IPSET_CB_ARG0]; + char name[IPSET_MAXNAMELEN]; + struct set_elem *e; + int ret = 0; + + atd = nla_nest_start(skb, IPSET_ATTR_ADT); + if (!atd) + return -EMSGSIZE; + + rcu_read_lock(); + list_for_each_entry_rcu(e, &map->members, list) { + if (i < first || + (SET_WITH_TIMEOUT(set) && + ip_set_timeout_expired(ext_timeout(e, set)))) { + i++; + continue; + } + nested = nla_nest_start(skb, IPSET_ATTR_DATA); + if (!nested) + goto nla_put_failure; + ip_set_name_byindex(map->net, e->id, name); + if (nla_put_string(skb, IPSET_ATTR_NAME, name)) + goto nla_put_failure; + if (ip_set_put_extensions(skb, set, e, true)) + goto nla_put_failure; + nla_nest_end(skb, nested); + i++; + } + + nla_nest_end(skb, atd); + /* Set listing finished */ + cb->args[IPSET_CB_ARG0] = 0; + goto out; + +nla_put_failure: + nla_nest_cancel(skb, nested); + if (unlikely(i == first)) { + nla_nest_cancel(skb, atd); + cb->args[IPSET_CB_ARG0] = 0; + ret = -EMSGSIZE; + } else { + cb->args[IPSET_CB_ARG0] = i; + nla_nest_end(skb, atd); + } +out: + rcu_read_unlock(); + return ret; +} + +static bool +list_set_same_set(const struct ip_set *a, const struct ip_set *b) +{ + const struct list_set *x = a->data; + const struct list_set *y = b->data; + + return x->size == y->size && + a->timeout == b->timeout && + a->extensions == b->extensions; +} + +static const struct ip_set_type_variant set_variant = { + .kadt = list_set_kadt, + .uadt = list_set_uadt, + .adt = { + [IPSET_ADD] = list_set_uadd, + [IPSET_DEL] = list_set_udel, + [IPSET_TEST] = list_set_utest, + }, + .destroy = list_set_destroy, + .flush = list_set_flush, + .head = list_set_head, + .list = list_set_list, + .same_set = 
list_set_same_set, +}; + +static void +list_set_gc(struct timer_list *t) +{ + struct list_set *map = from_timer(map, t, gc); + struct ip_set *set = map->set; + + spin_lock_bh(&set->lock); + set_cleanup_entries(set); + spin_unlock_bh(&set->lock); + + map->gc.expires = jiffies + IPSET_GC_PERIOD(set->timeout) * HZ; + add_timer(&map->gc); +} + +static void +list_set_gc_init(struct ip_set *set, void (*gc)(struct timer_list *t)) +{ + struct list_set *map = set->data; + + timer_setup(&map->gc, gc, 0); + mod_timer(&map->gc, jiffies + IPSET_GC_PERIOD(set->timeout) * HZ); +} + +/* Create list:set type of sets */ + +static bool +init_list_set(struct net *net, struct ip_set *set, u32 size) +{ + struct list_set *map; + + map = kzalloc(sizeof(*map), GFP_KERNEL); + if (!map) + return false; + + map->size = size; + map->net = net; + map->set = set; + INIT_LIST_HEAD(&map->members); + set->data = map; + + return true; +} + +static int +list_set_create(struct net *net, struct ip_set *set, struct nlattr *tb[], + u32 flags) +{ + u32 size = IP_SET_LIST_DEFAULT_SIZE; + + if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_SIZE) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) || + !ip_set_optattr_netorder(tb, IPSET_ATTR_CADT_FLAGS))) + return -IPSET_ERR_PROTOCOL; + + if (tb[IPSET_ATTR_SIZE]) + size = ip_set_get_h32(tb[IPSET_ATTR_SIZE]); + if (size < IP_SET_LIST_MIN_SIZE) + size = IP_SET_LIST_MIN_SIZE; + + set->variant = &set_variant; + set->dsize = ip_set_elem_len(set, tb, sizeof(struct set_elem), + __alignof__(struct set_elem)); + if (!init_list_set(net, set, size)) + return -ENOMEM; + if (tb[IPSET_ATTR_TIMEOUT]) { + set->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]); + list_set_gc_init(set, list_set_gc); + } + return 0; +} + +static struct ip_set_type list_set_type __read_mostly = { + .name = "list:set", + .protocol = IPSET_PROTOCOL, + .features = IPSET_TYPE_NAME | IPSET_DUMP_LAST, + .dimension = IPSET_DIM_ONE, + .family = NFPROTO_UNSPEC, + .revision_min = IPSET_TYPE_REV_MIN, + .revision_max = IPSET_TYPE_REV_MAX, + .create = list_set_create, + .create_policy = { + [IPSET_ATTR_SIZE] = { .type = NLA_U32 }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + }, + .adt_policy = { + [IPSET_ATTR_NAME] = { .type = NLA_STRING, + .len = IPSET_MAXNAMELEN }, + [IPSET_ATTR_NAMEREF] = { .type = NLA_STRING, + .len = IPSET_MAXNAMELEN }, + [IPSET_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPSET_ATTR_LINENO] = { .type = NLA_U32 }, + [IPSET_ATTR_CADT_FLAGS] = { .type = NLA_U32 }, + [IPSET_ATTR_BYTES] = { .type = NLA_U64 }, + [IPSET_ATTR_PACKETS] = { .type = NLA_U64 }, + [IPSET_ATTR_COMMENT] = { .type = NLA_NUL_STRING, + .len = IPSET_MAX_COMMENT_SIZE }, + [IPSET_ATTR_SKBMARK] = { .type = NLA_U64 }, + [IPSET_ATTR_SKBPRIO] = { .type = NLA_U32 }, + [IPSET_ATTR_SKBQUEUE] = { .type = NLA_U16 }, + }, + .me = THIS_MODULE, +}; + +static int __init +list_set_init(void) +{ + return ip_set_type_register(&list_set_type); +} + +static void __exit +list_set_fini(void) +{ + rcu_barrier(); + ip_set_type_unregister(&list_set_type); +} + +module_init(list_set_init); +module_exit(list_set_fini); diff --git a/net/netfilter/ipset/pfxlen.c b/net/netfilter/ipset/pfxlen.c new file mode 100644 index 000000000..ff570bff9 --- /dev/null +++ b/net/netfilter/ipset/pfxlen.c @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include + +/* Prefixlen maps for fast conversions, by Jan Engelhardt. 
*/ + +#ifdef E +#undef E +#endif + +#define PREFIXES_MAP \ + E(0x00000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0x80000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xC0000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xE0000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xF0000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xF8000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFC000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFE000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFF000000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFF800000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFC00000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFE00000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFF00000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFF80000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFC0000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFE0000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFF0000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFF8000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFC000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFE000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFF000, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFF800, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFC00, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFE00, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFF00, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFF80, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFC0, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFE0, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFF0, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFF8, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFC, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFE, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0x80000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xC0000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xE0000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xF0000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xF8000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFC000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFE000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFF000000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFF800000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFC00000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFE00000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFF00000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFF80000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFC0000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFE0000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFF0000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFF8000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFC000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFE000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFF000, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFF800, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFC00, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFE00, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFF00, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFF80, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFC0, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFE0, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFF0, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFF8, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFC, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFE, 
0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0x80000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xC0000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xE0000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xF0000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xF8000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFC000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFE000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFF000000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFF800000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFC00000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFE00000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFF00000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFF80000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFC0000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFE0000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFF0000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFF8000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFC000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFE000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFF000, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFF800, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFC00, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFE00, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFF00, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFF80, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFC0, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFE0, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFF0, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFF8, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFC, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFE, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x80000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xC0000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xE0000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xF0000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xF8000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFC000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFE000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFF000000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFF800000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFC00000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFE00000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFF00000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFF80000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFC0000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFE0000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFF0000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFF8000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFC000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFE000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFF000), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFF800), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFC00), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFE00), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFF00), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFF80), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFC0), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFE0), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFF0), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFF8), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFC), \ + E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFE), \ + 
E(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF), + +#define E(a, b, c, d) \ + {.ip6 = { \ + htonl(a), htonl(b), \ + htonl(c), htonl(d), \ + } } + +/* This table works for both IPv4 and IPv6; + * just use prefixlen_netmask_map[prefixlength].ip. + */ +const union nf_inet_addr ip_set_netmask_map[] = { + PREFIXES_MAP +}; +EXPORT_SYMBOL_GPL(ip_set_netmask_map); + +#undef E +#define E(a, b, c, d) \ + {.ip6 = { (__force __be32)a, (__force __be32)b, \ + (__force __be32)c, (__force __be32)d, \ + } } + +/* This table works for both IPv4 and IPv6; + * just use prefixlen_hostmask_map[prefixlength].ip. + */ +const union nf_inet_addr ip_set_hostmask_map[] = { + PREFIXES_MAP +}; +EXPORT_SYMBOL_GPL(ip_set_hostmask_map); + +/* Find the largest network which matches the range from left, in host order. */ +u32 +ip_set_range_to_cidr(u32 from, u32 to, u8 *cidr) +{ + u32 last; + u8 i; + + for (i = 1; i < 32; i++) { + if ((from & ip_set_hostmask(i)) != from) + continue; + last = from | ~ip_set_hostmask(i); + if (!after(last, to)) { + *cidr = i; + return last; + } + } + *cidr = 32; + return from; +} +EXPORT_SYMBOL_GPL(ip_set_range_to_cidr); diff --git a/net/netfilter/ipvs/Kconfig b/net/netfilter/ipvs/Kconfig new file mode 100644 index 000000000..2a3017b9c --- /dev/null +++ b/net/netfilter/ipvs/Kconfig @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# IP Virtual Server configuration +# +menuconfig IP_VS + tristate "IP virtual server support" + depends on INET && NETFILTER + depends on (NF_CONNTRACK || NF_CONNTRACK=n) + help + IP Virtual Server support will let you build a high-performance + virtual server based on cluster of two or more real servers. This + option must be enabled for at least one of the clustered computers + that will take care of intercepting incoming connections to a + single IP address and scheduling them to real servers. + + Three request dispatching techniques are implemented, they are + virtual server via NAT, virtual server via tunneling and virtual + server via direct routing. The several scheduling algorithms can + be used to choose which server the connection is directed to, + thus load balancing can be achieved among the servers. For more + information and its administration program, please visit the + following URL: . + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +if IP_VS + +config IP_VS_IPV6 + bool "IPv6 support for IPVS" + depends on IPV6 = y || IP_VS = IPV6 + select NF_DEFRAG_IPV6 + help + Add IPv6 support to IPVS. + + Say Y if unsure. + +config IP_VS_DEBUG + bool "IP virtual server debugging" + help + Say Y here if you want to get additional messages useful in + debugging the IP virtual server code. You can change the debug + level in /proc/sys/net/ipv4/vs/debug_level + +config IP_VS_TAB_BITS + int "IPVS connection table size (the Nth power of 2)" + range 8 20 if !64BIT + range 8 27 if 64BIT + default 12 + help + The IPVS connection hash table uses the chaining scheme to handle + hash collisions. Using a big IPVS connection hash table will greatly + reduce conflicts when there are hundreds of thousands of connections + in the hash table. + + Note the table size must be power of 2. The table size will be the + value of 2 to the your input number power. The number to choose is + from 8 to 27 for 64BIT(20 otherwise), the default number is 12, + which means the table size is 4096. Don't input the number too + small, otherwise you will lose performance on it. 
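
To make the sizing guidance in this help text concrete, here is a small userspace C sketch (not part of the patch) that turns an expected connection rate and average connection lifetime into a conn_tab_bits value following the rule of thumb described here; the 8..27 clamp mirrors the option's range on 64-bit builds.

#include <stdio.h>

/* Rough conn_tab_bits estimator, following the rule of thumb in the
 * IP_VS_TAB_BITS help text: the hash table should not be much smaller
 * than (connections per second) * (average connection lifetime).
 * Illustrative only; the kernel option is simply an integer you choose.
 */
static int conn_tab_bits(unsigned long conns_per_sec, unsigned long avg_secs)
{
	unsigned long want = conns_per_sec * avg_secs;
	int bits = 8;				/* lower bound of the option */

	while (bits < 27 && (1UL << (bits + 1)) <= want)
		bits++;				/* largest power of two <= want */
	return bits;
}

int main(void)
{
	/* The example from the help text: 200 conn/s lasting about 200 s each. */
	int bits = conn_tab_bits(200, 200);

	printf("suggested conn_tab_bits = %d (table size %lu)\n",
	       bits, 1UL << bits);	/* prints 15 and 32768 */
	return 0;
}
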
You can adapt the + table size yourself, according to your virtual server application. + It is good to set the table size not far less than the number of + connections per second multiplying average lasting time of + connection in the table. For example, your virtual server gets 200 + connections per second, the connection lasts for 200 seconds in + average in the connection table, the table size should be not far + less than 200x200, it is good to set the table size 32768 (2**15). + + Another note that each connection occupies 128 bytes effectively and + each hash entry uses 8 bytes, so you can estimate how much memory is + needed for your box. + + You can overwrite this number setting conn_tab_bits module parameter + or by appending ip_vs.conn_tab_bits=? to the kernel command line if + IP VS was compiled built-in. + +comment "IPVS transport protocol load balancing support" + +config IP_VS_PROTO_TCP + bool "TCP load balancing support" + help + This option enables support for load balancing TCP transport + protocol. Say Y if unsure. + +config IP_VS_PROTO_UDP + bool "UDP load balancing support" + help + This option enables support for load balancing UDP transport + protocol. Say Y if unsure. + +config IP_VS_PROTO_AH_ESP + def_bool IP_VS_PROTO_ESP || IP_VS_PROTO_AH + +config IP_VS_PROTO_ESP + bool "ESP load balancing support" + help + This option enables support for load balancing ESP (Encapsulation + Security Payload) transport protocol. Say Y if unsure. + +config IP_VS_PROTO_AH + bool "AH load balancing support" + help + This option enables support for load balancing AH (Authentication + Header) transport protocol. Say Y if unsure. + +config IP_VS_PROTO_SCTP + bool "SCTP load balancing support" + select LIBCRC32C + help + This option enables support for load balancing SCTP transport + protocol. Say Y if unsure. + +comment "IPVS scheduler" + +config IP_VS_RR + tristate "round-robin scheduling" + help + The robin-robin scheduling algorithm simply directs network + connections to different real servers in a round-robin manner. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_WRR + tristate "weighted round-robin scheduling" + help + The weighted robin-robin scheduling algorithm directs network + connections to different real servers based on server weights + in a round-robin manner. Servers with higher weights receive + new connections first than those with less weights, and servers + with higher weights get more connections than those with less + weights and servers with equal weights get equal connections. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_LC + tristate "least-connection scheduling" + help + The least-connection scheduling algorithm directs network + connections to the server with the least number of active + connections. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_WLC + tristate "weighted least-connection scheduling" + help + The weighted least-connection scheduling algorithm directs network + connections to the server with the least active connections + normalized by the server weight. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. 
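
As a concrete illustration of the weighted least-connection rule described above, the sketch below picks the server that minimizes activeconns/weight. The comparison is done with integer cross-products so no division or floating point is needed; this is a hedged userspace example, and the pool contents are made up for the demonstration.

#include <stdio.h>

/* Toy model of a real server as seen by a WLC-style scheduler. */
struct dest {
	const char *name;
	unsigned int activeconns;
	unsigned int weight;		/* weight 0 means "do not use" */
};

/* Pick the destination with the smallest activeconns/weight ratio.
 * a.conns/a.weight < b.conns/b.weight is tested as
 * a.conns * b.weight < b.conns * a.weight to stay in integers.
 */
static const struct dest *wlc_pick(const struct dest *d, int n)
{
	const struct dest *least = NULL;
	int i;

	for (i = 0; i < n; i++) {
		if (d[i].weight == 0)
			continue;
		if (!least ||
		    (unsigned long long)d[i].activeconns * least->weight <
		    (unsigned long long)least->activeconns * d[i].weight)
			least = &d[i];
	}
	return least;
}

int main(void)
{
	struct dest pool[] = {
		{ "rs1", 30, 1 },
		{ "rs2", 50, 3 },	/* 50/3 is the smallest ratio */
		{ "rs3", 40, 2 },
	};
	const struct dest *pick = wlc_pick(pool, 3);

	printf("next connection goes to %s\n", pick ? pick->name : "(none)");
	return 0;
}
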
+ +config IP_VS_FO + tristate "weighted failover scheduling" + help + The weighted failover scheduling algorithm directs network + connections to the server with the highest weight that is + currently available. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_OVF + tristate "weighted overflow scheduling" + help + The weighted overflow scheduling algorithm directs network + connections to the server with the highest weight that is + currently available and overflows to the next when active + connections exceed the node's weight. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_LBLC + tristate "locality-based least-connection scheduling" + help + The locality-based least-connection scheduling algorithm is for + destination IP load balancing. It is usually used in cache cluster. + This algorithm usually directs packet destined for an IP address to + its server if the server is alive and under load. If the server is + overloaded (its active connection numbers is larger than its weight) + and there is a server in its half load, then allocate the weighted + least-connection server to this IP address. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_LBLCR + tristate "locality-based least-connection with replication scheduling" + help + The locality-based least-connection with replication scheduling + algorithm is also for destination IP load balancing. It is + usually used in cache cluster. It differs from the LBLC scheduling + as follows: the load balancer maintains mappings from a target + to a set of server nodes that can serve the target. Requests for + a target are assigned to the least-connection node in the target's + server set. If all the node in the server set are over loaded, + it picks up a least-connection node in the cluster and adds it + in the sever set for the target. If the server set has not been + modified for the specified time, the most loaded node is removed + from the server set, in order to avoid high degree of replication. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_DH + tristate "destination hashing scheduling" + help + The destination hashing scheduling algorithm assigns network + connections to the servers through looking up a statically assigned + hash table by their destination IP addresses. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_SH + tristate "source hashing scheduling" + help + The source hashing scheduling algorithm assigns network + connections to the servers through looking up a statically assigned + hash table by their source IP addresses. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_MH + tristate "maglev hashing scheduling" + help + The maglev consistent hashing scheduling algorithm provides the + Google's Maglev hashing algorithm as a IPVS scheduler. It assigns + network connections to the servers through looking up a statically + assigned special hash table called the lookup table. Maglev hashing + is to assign a preference list of all the lookup table positions + to each destination. 
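
The lookup-table population described here can be sketched in a few lines of userspace C. This is the published Maglev construction in miniature, not the module's implementation: each destination gets a permutation of table positions from an (offset, skip) pair, and destinations take turns claiming their next unfilled preferred slot until the table is full. The hash functions and the tiny table size below are arbitrary choices for the example.

#include <stdio.h>

#define M 13	/* table size; must be prime (the option below uses 251..131071) */

/* Any simple string hash will do for the sketch. */
static unsigned long djb2(const char *s, unsigned long seed)
{
	unsigned long h = seed;

	while (*s)
		h = h * 33 + (unsigned char)*s++;
	return h;
}

/* Maglev-style population: destination i prefers positions
 * (offset + j * skip) % M; destinations take turns claiming the first
 * still-empty preferred slot until every slot is owned.
 */
static void populate(const char *const *dest, int n, int table[M])
{
	unsigned long offset[n], skip[n], next[n];
	int filled = 0, i, j;

	for (i = 0; i < n; i++) {
		offset[i] = djb2(dest[i], 5381) % M;
		skip[i] = djb2(dest[i], 0) % (M - 1) + 1;
		next[i] = 0;
	}
	for (j = 0; j < M; j++)
		table[j] = -1;

	while (filled < M) {
		for (i = 0; i < n && filled < M; i++) {
			unsigned long c;

			do {	/* skip over slots someone already took */
				c = (offset[i] + next[i] * skip[i]) % M;
				next[i]++;
			} while (table[c] != -1);
			table[c] = i;
			filled++;
		}
	}
}

int main(void)
{
	const char *dest[] = { "10.0.0.1", "10.0.0.2", "10.0.0.3" };
	int table[M], j;

	populate(dest, 3, table);
	for (j = 0; j < M; j++)
		printf("slot %2d -> %s\n", j, dest[table[j]]);
	return 0;
}
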
+ + Through this operation, The maglev hashing gives an almost equal + share of the lookup table to each of the destinations and provides + minimal disruption by using the lookup table. When the set of + destinations changes, a connection will likely be sent to the same + destination as it was before. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_SED + tristate "shortest expected delay scheduling" + help + The shortest expected delay scheduling algorithm assigns network + connections to the server with the shortest expected delay. The + expected delay that the job will experience is (Ci + 1) / Ui if + sent to the ith server, in which Ci is the number of connections + on the ith server and Ui is the fixed service rate (weight) + of the ith server. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_NQ + tristate "never queue scheduling" + help + The never queue scheduling algorithm adopts a two-speed model. + When there is an idle server available, the job will be sent to + the idle server, instead of waiting for a fast one. When there + is no idle server available, the job will be sent to the server + that minimize its expected delay (The Shortest Expected Delay + scheduling algorithm). + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_TWOS + tristate "weighted random twos choice least-connection scheduling" + help + The weighted random twos choice least-connection scheduling + algorithm picks two random real servers and directs network + connections to the server with the least active connections + normalized by the server weight. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +comment 'IPVS SH scheduler' + +config IP_VS_SH_TAB_BITS + int "IPVS source hashing table size (the Nth power of 2)" + range 4 20 + default 8 + help + The source hashing scheduler maps source IPs to destinations + stored in a hash table. This table is tiled by each destination + until all slots in the table are filled. When using weights to + allow destinations to receive more connections, the table is + tiled an amount proportional to the weights specified. The table + needs to be large enough to effectively fit all the destinations + multiplied by their respective weights. + +comment 'IPVS MH scheduler' + +config IP_VS_MH_TAB_INDEX + int "IPVS maglev hashing table index of size (the prime numbers)" + range 8 17 + default 12 + help + The maglev hashing scheduler maps source IPs to destinations + stored in a hash table. This table is assigned by a preference + list of the positions to each destination until all slots in + the table are filled. The index determines the prime for size of + the table as 251, 509, 1021, 2039, 4093, 8191, 16381, 32749, + 65521 or 131071. When using weights to allow destinations to + receive more connections, the table is assigned an amount + proportional to the weights specified. The table needs to be large + enough to effectively fit all the destinations multiplied by their + respective weights. + +comment 'IPVS application helper' + +config IP_VS_FTP + tristate "FTP protocol helper" + depends on IP_VS_PROTO_TCP && NF_CONNTRACK && NF_NAT && \ + NF_CONNTRACK_FTP + select IP_VS_NFCT + help + FTP is a protocol that transfers IP address and/or port number in + the payload. 
In the virtual server via Network Address Translation, + the IP address and port number of real servers cannot be sent to + clients in ftp connections directly, so FTP protocol helper is + required for tracking the connection and mangling it back to that of + virtual service. + + If you want to compile it in kernel, say Y. To compile it as a + module, choose M here. If unsure, say N. + +config IP_VS_NFCT + bool "Netfilter connection tracking" + depends on NF_CONNTRACK + help + The Netfilter connection tracking support allows the IPVS + connection state to be exported to the Netfilter framework + for filtering purposes. + +config IP_VS_PE_SIP + tristate "SIP persistence engine" + depends on IP_VS_PROTO_UDP + depends on NF_CONNTRACK_SIP + help + Allow persistence based on the SIP Call-ID + +endif # IP_VS diff --git a/net/netfilter/ipvs/Makefile b/net/netfilter/ipvs/Makefile new file mode 100644 index 000000000..bb5d8125c --- /dev/null +++ b/net/netfilter/ipvs/Makefile @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the IPVS modules on top of IPv4. +# + +# IPVS transport protocol load balancing support +ip_vs_proto-objs-y := +ip_vs_proto-objs-$(CONFIG_IP_VS_PROTO_TCP) += ip_vs_proto_tcp.o +ip_vs_proto-objs-$(CONFIG_IP_VS_PROTO_UDP) += ip_vs_proto_udp.o +ip_vs_proto-objs-$(CONFIG_IP_VS_PROTO_AH_ESP) += ip_vs_proto_ah_esp.o +ip_vs_proto-objs-$(CONFIG_IP_VS_PROTO_SCTP) += ip_vs_proto_sctp.o + +ip_vs-extra_objs-y := +ip_vs-extra_objs-$(CONFIG_IP_VS_NFCT) += ip_vs_nfct.o + +ip_vs-objs := ip_vs_conn.o ip_vs_core.o ip_vs_ctl.o ip_vs_sched.o \ + ip_vs_xmit.o ip_vs_app.o ip_vs_sync.o \ + ip_vs_est.o ip_vs_proto.o ip_vs_pe.o \ + $(ip_vs_proto-objs-y) $(ip_vs-extra_objs-y) + + +# IPVS core +obj-$(CONFIG_IP_VS) += ip_vs.o + +# IPVS schedulers +obj-$(CONFIG_IP_VS_RR) += ip_vs_rr.o +obj-$(CONFIG_IP_VS_WRR) += ip_vs_wrr.o +obj-$(CONFIG_IP_VS_LC) += ip_vs_lc.o +obj-$(CONFIG_IP_VS_WLC) += ip_vs_wlc.o +obj-$(CONFIG_IP_VS_FO) += ip_vs_fo.o +obj-$(CONFIG_IP_VS_OVF) += ip_vs_ovf.o +obj-$(CONFIG_IP_VS_LBLC) += ip_vs_lblc.o +obj-$(CONFIG_IP_VS_LBLCR) += ip_vs_lblcr.o +obj-$(CONFIG_IP_VS_DH) += ip_vs_dh.o +obj-$(CONFIG_IP_VS_SH) += ip_vs_sh.o +obj-$(CONFIG_IP_VS_MH) += ip_vs_mh.o +obj-$(CONFIG_IP_VS_SED) += ip_vs_sed.o +obj-$(CONFIG_IP_VS_NQ) += ip_vs_nq.o +obj-$(CONFIG_IP_VS_TWOS) += ip_vs_twos.o + +# IPVS application helpers +obj-$(CONFIG_IP_VS_FTP) += ip_vs_ftp.o + +# IPVS connection template retrievers +obj-$(CONFIG_IP_VS_PE_SIP) += ip_vs_pe_sip.o diff --git a/net/netfilter/ipvs/ip_vs_app.c b/net/netfilter/ipvs/ip_vs_app.c new file mode 100644 index 000000000..fdacbc3c1 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_app.c @@ -0,0 +1,617 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_app.c: Application module support for IPVS + * + * Authors: Wensong Zhang + * + * Most code here is taken from ip_masq_app.c in kernel 2.2. The difference + * is that ip_vs_app module handles the reverse direction (incoming requests + * and outgoing responses). 
+ * + * IP_MASQ_APP application masquerading module + * + * Author: Juan Jose Ciarlante, + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +EXPORT_SYMBOL(register_ip_vs_app); +EXPORT_SYMBOL(unregister_ip_vs_app); +EXPORT_SYMBOL(register_ip_vs_app_inc); + +static DEFINE_MUTEX(__ip_vs_app_mutex); + +/* + * Get an ip_vs_app object + */ +static inline int ip_vs_app_get(struct ip_vs_app *app) +{ + return try_module_get(app->module); +} + + +static inline void ip_vs_app_put(struct ip_vs_app *app) +{ + module_put(app->module); +} + +static void ip_vs_app_inc_destroy(struct ip_vs_app *inc) +{ + kfree(inc->timeout_table); + kfree(inc); +} + +static void ip_vs_app_inc_rcu_free(struct rcu_head *head) +{ + struct ip_vs_app *inc = container_of(head, struct ip_vs_app, rcu_head); + + ip_vs_app_inc_destroy(inc); +} + +/* + * Allocate/initialize app incarnation and register it in proto apps. + */ +static int +ip_vs_app_inc_new(struct netns_ipvs *ipvs, struct ip_vs_app *app, __u16 proto, + __u16 port) +{ + struct ip_vs_protocol *pp; + struct ip_vs_app *inc; + int ret; + + if (!(pp = ip_vs_proto_get(proto))) + return -EPROTONOSUPPORT; + + if (!pp->unregister_app) + return -EOPNOTSUPP; + + inc = kmemdup(app, sizeof(*inc), GFP_KERNEL); + if (!inc) + return -ENOMEM; + INIT_LIST_HEAD(&inc->p_list); + INIT_LIST_HEAD(&inc->incs_list); + inc->app = app; + inc->port = htons(port); + atomic_set(&inc->usecnt, 0); + + if (app->timeouts) { + inc->timeout_table = + ip_vs_create_timeout_table(app->timeouts, + app->timeouts_size); + if (!inc->timeout_table) { + ret = -ENOMEM; + goto out; + } + } + + ret = pp->register_app(ipvs, inc); + if (ret) + goto out; + + list_add(&inc->a_list, &app->incs_list); + IP_VS_DBG(9, "%s App %s:%u registered\n", + pp->name, inc->name, ntohs(inc->port)); + + return 0; + + out: + ip_vs_app_inc_destroy(inc); + return ret; +} + + +/* + * Release app incarnation + */ +static void +ip_vs_app_inc_release(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_protocol *pp; + + if (!(pp = ip_vs_proto_get(inc->protocol))) + return; + + if (pp->unregister_app) + pp->unregister_app(ipvs, inc); + + IP_VS_DBG(9, "%s App %s:%u unregistered\n", + pp->name, inc->name, ntohs(inc->port)); + + list_del(&inc->a_list); + + call_rcu(&inc->rcu_head, ip_vs_app_inc_rcu_free); +} + + +/* + * Get reference to app inc (only called from softirq) + * + */ +int ip_vs_app_inc_get(struct ip_vs_app *inc) +{ + int result; + + result = ip_vs_app_get(inc->app); + if (result) + atomic_inc(&inc->usecnt); + return result; +} + + +/* + * Put the app inc (only called from timer or net softirq) + */ +void ip_vs_app_inc_put(struct ip_vs_app *inc) +{ + atomic_dec(&inc->usecnt); + ip_vs_app_put(inc->app); +} + + +/* + * Register an application incarnation in protocol applications + */ +int +register_ip_vs_app_inc(struct netns_ipvs *ipvs, struct ip_vs_app *app, __u16 proto, + __u16 port) +{ + int result; + + mutex_lock(&__ip_vs_app_mutex); + + result = ip_vs_app_inc_new(ipvs, app, proto, port); + + mutex_unlock(&__ip_vs_app_mutex); + + return result; +} + + +/* Register application for netns */ +struct ip_vs_app *register_ip_vs_app(struct netns_ipvs *ipvs, struct ip_vs_app *app) +{ + struct ip_vs_app *a; + int err = 0; + + mutex_lock(&__ip_vs_app_mutex); + + /* increase the module use count */ + if (!ip_vs_use_count_inc()) { + err = -ENOENT; 
+ goto out_unlock; + } + + list_for_each_entry(a, &ipvs->app_list, a_list) { + if (!strcmp(app->name, a->name)) { + err = -EEXIST; + /* decrease the module use count */ + ip_vs_use_count_dec(); + goto out_unlock; + } + } + a = kmemdup(app, sizeof(*app), GFP_KERNEL); + if (!a) { + err = -ENOMEM; + /* decrease the module use count */ + ip_vs_use_count_dec(); + goto out_unlock; + } + INIT_LIST_HEAD(&a->incs_list); + list_add(&a->a_list, &ipvs->app_list); + +out_unlock: + mutex_unlock(&__ip_vs_app_mutex); + + return err ? ERR_PTR(err) : a; +} + + +/* + * ip_vs_app unregistration routine + * We are sure there are no app incarnations attached to services + * Caller should use synchronize_rcu() or rcu_barrier() + */ +void unregister_ip_vs_app(struct netns_ipvs *ipvs, struct ip_vs_app *app) +{ + struct ip_vs_app *a, *anxt, *inc, *nxt; + + mutex_lock(&__ip_vs_app_mutex); + + list_for_each_entry_safe(a, anxt, &ipvs->app_list, a_list) { + if (app && strcmp(app->name, a->name)) + continue; + list_for_each_entry_safe(inc, nxt, &a->incs_list, a_list) { + ip_vs_app_inc_release(ipvs, inc); + } + + list_del(&a->a_list); + kfree(a); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + } + + mutex_unlock(&__ip_vs_app_mutex); +} + + +/* + * Bind ip_vs_conn to its ip_vs_app (called by cp constructor) + */ +int ip_vs_bind_app(struct ip_vs_conn *cp, + struct ip_vs_protocol *pp) +{ + return pp->app_conn_bind(cp); +} + + +/* + * Unbind cp from application incarnation (called by cp destructor) + */ +void ip_vs_unbind_app(struct ip_vs_conn *cp) +{ + struct ip_vs_app *inc = cp->app; + + if (!inc) + return; + + if (inc->unbind_conn) + inc->unbind_conn(inc, cp); + if (inc->done_conn) + inc->done_conn(inc, cp); + ip_vs_app_inc_put(inc); + cp->app = NULL; +} + + +/* + * Fixes th->seq based on ip_vs_seq info. + */ +static inline void vs_fix_seq(const struct ip_vs_seq *vseq, struct tcphdr *th) +{ + __u32 seq = ntohl(th->seq); + + /* + * Adjust seq with delta-offset for all packets after + * the most recent resized pkt seq and with previous_delta offset + * for all packets before most recent resized pkt seq. + */ + if (vseq->delta || vseq->previous_delta) { + if(after(seq, vseq->init_seq)) { + th->seq = htonl(seq + vseq->delta); + IP_VS_DBG(9, "%s(): added delta (%d) to seq\n", + __func__, vseq->delta); + } else { + th->seq = htonl(seq + vseq->previous_delta); + IP_VS_DBG(9, "%s(): added previous_delta (%d) to seq\n", + __func__, vseq->previous_delta); + } + } +} + + +/* + * Fixes th->ack_seq based on ip_vs_seq info. + */ +static inline void +vs_fix_ack_seq(const struct ip_vs_seq *vseq, struct tcphdr *th) +{ + __u32 ack_seq = ntohl(th->ack_seq); + + /* + * Adjust ack_seq with delta-offset for + * the packets AFTER most recent resized pkt has caused a shift + * for packets before most recent resized pkt, use previous_delta + */ + if (vseq->delta || vseq->previous_delta) { + /* since ack_seq is the number of octet that is expected + to receive next, so compare it with init_seq+delta */ + if(after(ack_seq, vseq->init_seq+vseq->delta)) { + th->ack_seq = htonl(ack_seq - vseq->delta); + IP_VS_DBG(9, "%s(): subtracted delta " + "(%d) from ack_seq\n", __func__, vseq->delta); + + } else { + th->ack_seq = htonl(ack_seq - vseq->previous_delta); + IP_VS_DBG(9, "%s(): subtracted " + "previous_delta (%d) from ack_seq\n", + __func__, vseq->previous_delta); + } + } +} + + +/* + * Updates ip_vs_seq if pkt has been resized + * Assumes already checked proto==IPPROTO_TCP and diff!=0. 
+ */ +static inline void vs_seq_update(struct ip_vs_conn *cp, struct ip_vs_seq *vseq, + unsigned int flag, __u32 seq, int diff) +{ + /* spinlock is to keep updating cp->flags atomic */ + spin_lock_bh(&cp->lock); + if (!(cp->flags & flag) || after(seq, vseq->init_seq)) { + vseq->previous_delta = vseq->delta; + vseq->delta += diff; + vseq->init_seq = seq; + cp->flags |= flag; + } + spin_unlock_bh(&cp->lock); +} + +static inline int app_tcp_pkt_out(struct ip_vs_conn *cp, struct sk_buff *skb, + struct ip_vs_app *app, + struct ip_vs_iphdr *ipvsh) +{ + int diff; + const unsigned int tcp_offset = ip_hdrlen(skb); + struct tcphdr *th; + __u32 seq; + + if (skb_ensure_writable(skb, tcp_offset + sizeof(*th))) + return 0; + + th = (struct tcphdr *)(skb_network_header(skb) + tcp_offset); + + /* + * Remember seq number in case this pkt gets resized + */ + seq = ntohl(th->seq); + + /* + * Fix seq stuff if flagged as so. + */ + if (cp->flags & IP_VS_CONN_F_OUT_SEQ) + vs_fix_seq(&cp->out_seq, th); + if (cp->flags & IP_VS_CONN_F_IN_SEQ) + vs_fix_ack_seq(&cp->in_seq, th); + + /* + * Call private output hook function + */ + if (app->pkt_out == NULL) + return 1; + + if (!app->pkt_out(app, cp, skb, &diff, ipvsh)) + return 0; + + /* + * Update ip_vs seq stuff if len has changed. + */ + if (diff != 0) + vs_seq_update(cp, &cp->out_seq, + IP_VS_CONN_F_OUT_SEQ, seq, diff); + + return 1; +} + +/* + * Output pkt hook. Will call bound ip_vs_app specific function + * called by ipvs packet handler, assumes previously checked cp!=NULL + * returns false if it can't handle packet (oom) + */ +int ip_vs_app_pkt_out(struct ip_vs_conn *cp, struct sk_buff *skb, + struct ip_vs_iphdr *ipvsh) +{ + struct ip_vs_app *app; + + /* + * check if application module is bound to + * this ip_vs_conn. + */ + if ((app = cp->app) == NULL) + return 1; + + /* TCP is complicated */ + if (cp->protocol == IPPROTO_TCP) + return app_tcp_pkt_out(cp, skb, app, ipvsh); + + /* + * Call private output hook function + */ + if (app->pkt_out == NULL) + return 1; + + return app->pkt_out(app, cp, skb, NULL, ipvsh); +} + + +static inline int app_tcp_pkt_in(struct ip_vs_conn *cp, struct sk_buff *skb, + struct ip_vs_app *app, + struct ip_vs_iphdr *ipvsh) +{ + int diff; + const unsigned int tcp_offset = ip_hdrlen(skb); + struct tcphdr *th; + __u32 seq; + + if (skb_ensure_writable(skb, tcp_offset + sizeof(*th))) + return 0; + + th = (struct tcphdr *)(skb_network_header(skb) + tcp_offset); + + /* + * Remember seq number in case this pkt gets resized + */ + seq = ntohl(th->seq); + + /* + * Fix seq stuff if flagged as so. + */ + if (cp->flags & IP_VS_CONN_F_IN_SEQ) + vs_fix_seq(&cp->in_seq, th); + if (cp->flags & IP_VS_CONN_F_OUT_SEQ) + vs_fix_ack_seq(&cp->out_seq, th); + + /* + * Call private input hook function + */ + if (app->pkt_in == NULL) + return 1; + + if (!app->pkt_in(app, cp, skb, &diff, ipvsh)) + return 0; + + /* + * Update ip_vs seq stuff if len has changed. + */ + if (diff != 0) + vs_seq_update(cp, &cp->in_seq, + IP_VS_CONN_F_IN_SEQ, seq, diff); + + return 1; +} + +/* + * Input pkt hook. Will call bound ip_vs_app specific function + * called by ipvs packet handler, assumes previously checked cp!=NULL. + * returns false if can't handle packet (oom). + */ +int ip_vs_app_pkt_in(struct ip_vs_conn *cp, struct sk_buff *skb, + struct ip_vs_iphdr *ipvsh) +{ + struct ip_vs_app *app; + + /* + * check if application module is bound to + * this ip_vs_conn. 
+ */ + if ((app = cp->app) == NULL) + return 1; + + /* TCP is complicated */ + if (cp->protocol == IPPROTO_TCP) + return app_tcp_pkt_in(cp, skb, app, ipvsh); + + /* + * Call private input hook function + */ + if (app->pkt_in == NULL) + return 1; + + return app->pkt_in(app, cp, skb, NULL, ipvsh); +} + + +#ifdef CONFIG_PROC_FS +/* + * /proc/net/ip_vs_app entry function + */ + +static struct ip_vs_app *ip_vs_app_idx(struct netns_ipvs *ipvs, loff_t pos) +{ + struct ip_vs_app *app, *inc; + + list_for_each_entry(app, &ipvs->app_list, a_list) { + list_for_each_entry(inc, &app->incs_list, a_list) { + if (pos-- == 0) + return inc; + } + } + return NULL; + +} + +static void *ip_vs_app_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + struct netns_ipvs *ipvs = net_ipvs(net); + + mutex_lock(&__ip_vs_app_mutex); + + return *pos ? ip_vs_app_idx(ipvs, *pos - 1) : SEQ_START_TOKEN; +} + +static void *ip_vs_app_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip_vs_app *inc, *app; + struct list_head *e; + struct net *net = seq_file_net(seq); + struct netns_ipvs *ipvs = net_ipvs(net); + + ++*pos; + if (v == SEQ_START_TOKEN) + return ip_vs_app_idx(ipvs, 0); + + inc = v; + app = inc->app; + + if ((e = inc->a_list.next) != &app->incs_list) + return list_entry(e, struct ip_vs_app, a_list); + + /* go on to next application */ + for (e = app->a_list.next; e != &ipvs->app_list; e = e->next) { + app = list_entry(e, struct ip_vs_app, a_list); + list_for_each_entry(inc, &app->incs_list, a_list) { + return inc; + } + } + return NULL; +} + +static void ip_vs_app_seq_stop(struct seq_file *seq, void *v) +{ + mutex_unlock(&__ip_vs_app_mutex); +} + +static int ip_vs_app_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_puts(seq, "prot port usecnt name\n"); + else { + const struct ip_vs_app *inc = v; + + seq_printf(seq, "%-3s %-7u %-6d %-17s\n", + ip_vs_proto_name(inc->protocol), + ntohs(inc->port), + atomic_read(&inc->usecnt), + inc->name); + } + return 0; +} + +static const struct seq_operations ip_vs_app_seq_ops = { + .start = ip_vs_app_seq_start, + .next = ip_vs_app_seq_next, + .stop = ip_vs_app_seq_stop, + .show = ip_vs_app_seq_show, +}; +#endif + +int __net_init ip_vs_app_net_init(struct netns_ipvs *ipvs) +{ + INIT_LIST_HEAD(&ipvs->app_list); +#ifdef CONFIG_PROC_FS + if (!proc_create_net("ip_vs_app", 0, ipvs->net->proc_net, + &ip_vs_app_seq_ops, + sizeof(struct seq_net_private))) + return -ENOMEM; +#endif + return 0; +} + +void __net_exit ip_vs_app_net_cleanup(struct netns_ipvs *ipvs) +{ + unregister_ip_vs_app(ipvs, NULL /* all */); +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip_vs_app", ipvs->net->proc_net); +#endif +} diff --git a/net/netfilter/ipvs/ip_vs_conn.c b/net/netfilter/ipvs/ip_vs_conn.c new file mode 100644 index 000000000..e1b9b5290 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_conn.c @@ -0,0 +1,1538 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS An implementation of the IP virtual server support for the + * LINUX operating system. IPVS is now implemented as a module + * over the Netfilter framework. IPVS can be used to build a + * high-performance and highly available server based on a + * cluster of servers. + * + * Authors: Wensong Zhang + * Peter Kese + * Julian Anastasov + * + * The IPVS code for kernel 2.2 was done by Wensong Zhang and Peter Kese, + * with changes/fixes from Julian Anastasov, Lars Marowsky-Bree, Horms + * and others. Many code here is taken from IP MASQ code of kernel 2.2. 
+ *
+ * Changes:
+ */
+
+#define KMSG_COMPONENT "IPVS"
+#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
+
+#include <linux/interrupt.h>
+#include <linux/in.h>
+#include <linux/inet.h>
+#include <linux/net.h>
+#include <linux/kernel.h>
+#include <linux/module.h>
+#include <linux/vmalloc.h>
+#include <linux/proc_fs.h>		/* for proc_net_* */
+#include <linux/slab.h>
+#include <linux/seq_file.h>
+#include <linux/jhash.h>
+#include <linux/random.h>
+
+#include <net/net_namespace.h>
+#include <net/ip_vs.h>
+
+
+#ifndef CONFIG_IP_VS_TAB_BITS
+#define CONFIG_IP_VS_TAB_BITS	12
+#endif
+
+/*
+ * Connection hash size. Default is what was selected at compile time.
+*/
+static int ip_vs_conn_tab_bits = CONFIG_IP_VS_TAB_BITS;
+module_param_named(conn_tab_bits, ip_vs_conn_tab_bits, int, 0444);
+MODULE_PARM_DESC(conn_tab_bits, "Set connections' hash size");
+
+/* size and mask values */
+int ip_vs_conn_tab_size __read_mostly;
+static int ip_vs_conn_tab_mask __read_mostly;
+
+/*
+ * Connection hash table: for input and output packets lookups of IPVS
+ */
+static struct hlist_head *ip_vs_conn_tab __read_mostly;
+
+/* SLAB cache for IPVS connections */
+static struct kmem_cache *ip_vs_conn_cachep __read_mostly;
+
+/* counter for no client port connections */
+static atomic_t ip_vs_conn_no_cport_cnt = ATOMIC_INIT(0);
+
+/* random value for IPVS connection hash */
+static unsigned int ip_vs_conn_rnd __read_mostly;
+
+/*
+ * Fine locking granularity for big connection hash table
+ */
+#define CT_LOCKARRAY_BITS	5
+#define CT_LOCKARRAY_SIZE	(1<<CT_LOCKARRAY_BITS)
+#define CT_LOCKARRAY_MASK	(CT_LOCKARRAY_SIZE-1)
+
+/* We need an addrstrlen that works with or without v6 */
+#ifdef CONFIG_IP_VS_IPV6
+#define IP_VS_ADDRSTRLEN INET6_ADDRSTRLEN
+#else
+#define IP_VS_ADDRSTRLEN (8+1)
+#endif
+
+struct ip_vs_aligned_lock
+{
+	spinlock_t	l;
+} __attribute__((__aligned__(SMP_CACHE_BYTES)));
+
+/* lock array for conn table */
+static struct ip_vs_aligned_lock
+__ip_vs_conntbl_lock_array[CT_LOCKARRAY_SIZE] __cacheline_aligned;
+
+static inline void ct_write_lock_bh(unsigned int key)
+{
+	spin_lock_bh(&__ip_vs_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
+}
+
+static inline void ct_write_unlock_bh(unsigned int key)
+{
+	spin_unlock_bh(&__ip_vs_conntbl_lock_array[key&CT_LOCKARRAY_MASK].l);
+}
+
+static void ip_vs_conn_expire(struct timer_list *t);
+
+/*
+ * Returns hash value for IPVS connection entry
+ */
+static unsigned int ip_vs_conn_hashkey(struct netns_ipvs *ipvs, int af, unsigned int proto,
+				       const union nf_inet_addr *addr,
+				       __be16 port)
+{
+#ifdef CONFIG_IP_VS_IPV6
+	if (af == AF_INET6)
+		return (jhash_3words(jhash(addr, 16, ip_vs_conn_rnd),
+				    (__force u32)port, proto, ip_vs_conn_rnd) ^
+			((size_t)ipvs>>8)) & ip_vs_conn_tab_mask;
+#endif
+	return (jhash_3words((__force u32)addr->ip, (__force u32)port, proto,
+			    ip_vs_conn_rnd) ^
+		((size_t)ipvs>>8)) & ip_vs_conn_tab_mask;
+}
+
+static unsigned int ip_vs_conn_hashkey_param(const struct ip_vs_conn_param *p,
+					     bool inverse)
+{
+	const union nf_inet_addr *addr;
+	__be16 port;
+
+	if (p->pe_data && p->pe->hashkey_raw)
+		return p->pe->hashkey_raw(p, ip_vs_conn_rnd, inverse) &
+			ip_vs_conn_tab_mask;
+
+	if (likely(!inverse)) {
+		addr = p->caddr;
+		port = p->cport;
+	} else {
+		addr = p->vaddr;
+		port = p->vport;
+	}
+
+	return ip_vs_conn_hashkey(p->ipvs, p->af, p->protocol, addr, port);
+}
+
+static unsigned int ip_vs_conn_hashkey_conn(const struct ip_vs_conn *cp)
+{
+	struct ip_vs_conn_param p;
+
+	ip_vs_conn_fill_param(cp->ipvs, cp->af, cp->protocol,
+			      &cp->caddr, cp->cport, NULL, 0, &p);
+
+	if (cp->pe) {
+		p.pe = cp->pe;
+		p.pe_data = cp->pe_data;
+		p.pe_data_len = cp->pe_data_len;
+	}
+
+	return ip_vs_conn_hashkey_param(&p, false);
+}
+
+/*
+ * Hashes ip_vs_conn in ip_vs_conn_tab by netns,proto,addr,port.
+ * returns bool success.
+ */
+static inline int ip_vs_conn_hash(struct ip_vs_conn *cp)
+{
+	unsigned int hash;
+	int ret;
+
+	if (cp->flags & IP_VS_CONN_F_ONE_PACKET)
+		return 0;
+
+	/* Hash by protocol, client address and port */
+	hash = ip_vs_conn_hashkey_conn(cp);
+
+	ct_write_lock_bh(hash);
+	spin_lock(&cp->lock);
+
+	if (!(cp->flags & IP_VS_CONN_F_HASHED)) {
+		cp->flags |= IP_VS_CONN_F_HASHED;
+		refcount_inc(&cp->refcnt);
+		hlist_add_head_rcu(&cp->c_list, &ip_vs_conn_tab[hash]);
+		ret = 1;
+	} else {
+		pr_err("%s(): request for already hashed, called from %pS\n",
+		       __func__, __builtin_return_address(0));
+		ret = 0;
+	}
+
+	spin_unlock(&cp->lock);
+	ct_write_unlock_bh(hash);
+
+	return ret;
+}
+
+
+/*
+ * UNhashes ip_vs_conn from ip_vs_conn_tab.
+ * returns bool success. Caller should hold conn reference.
+ */ +static inline int ip_vs_conn_unhash(struct ip_vs_conn *cp) +{ + unsigned int hash; + int ret; + + /* unhash it and decrease its reference counter */ + hash = ip_vs_conn_hashkey_conn(cp); + + ct_write_lock_bh(hash); + spin_lock(&cp->lock); + + if (cp->flags & IP_VS_CONN_F_HASHED) { + hlist_del_rcu(&cp->c_list); + cp->flags &= ~IP_VS_CONN_F_HASHED; + refcount_dec(&cp->refcnt); + ret = 1; + } else + ret = 0; + + spin_unlock(&cp->lock); + ct_write_unlock_bh(hash); + + return ret; +} + +/* Try to unlink ip_vs_conn from ip_vs_conn_tab. + * returns bool success. + */ +static inline bool ip_vs_conn_unlink(struct ip_vs_conn *cp) +{ + unsigned int hash; + bool ret = false; + + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + return refcount_dec_if_one(&cp->refcnt); + + hash = ip_vs_conn_hashkey_conn(cp); + + ct_write_lock_bh(hash); + spin_lock(&cp->lock); + + if (cp->flags & IP_VS_CONN_F_HASHED) { + /* Decrease refcnt and unlink conn only if we are last user */ + if (refcount_dec_if_one(&cp->refcnt)) { + hlist_del_rcu(&cp->c_list); + cp->flags &= ~IP_VS_CONN_F_HASHED; + ret = true; + } + } + + spin_unlock(&cp->lock); + ct_write_unlock_bh(hash); + + return ret; +} + + +/* + * Gets ip_vs_conn associated with supplied parameters in the ip_vs_conn_tab. + * Called for pkts coming from OUTside-to-INside. + * p->caddr, p->cport: pkt source address (foreign host) + * p->vaddr, p->vport: pkt dest address (load balancer) + */ +static inline struct ip_vs_conn * +__ip_vs_conn_in_get(const struct ip_vs_conn_param *p) +{ + unsigned int hash; + struct ip_vs_conn *cp; + + hash = ip_vs_conn_hashkey_param(p, false); + + rcu_read_lock(); + + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[hash], c_list) { + if (p->cport == cp->cport && p->vport == cp->vport && + cp->af == p->af && + ip_vs_addr_equal(p->af, p->caddr, &cp->caddr) && + ip_vs_addr_equal(p->af, p->vaddr, &cp->vaddr) && + ((!p->cport) ^ (!(cp->flags & IP_VS_CONN_F_NO_CPORT))) && + p->protocol == cp->protocol && + cp->ipvs == p->ipvs) { + if (!__ip_vs_conn_get(cp)) + continue; + /* HIT */ + rcu_read_unlock(); + return cp; + } + } + + rcu_read_unlock(); + + return NULL; +} + +struct ip_vs_conn *ip_vs_conn_in_get(const struct ip_vs_conn_param *p) +{ + struct ip_vs_conn *cp; + + cp = __ip_vs_conn_in_get(p); + if (!cp && atomic_read(&ip_vs_conn_no_cport_cnt)) { + struct ip_vs_conn_param cport_zero_p = *p; + cport_zero_p.cport = 0; + cp = __ip_vs_conn_in_get(&cport_zero_p); + } + + IP_VS_DBG_BUF(9, "lookup/in %s %s:%d->%s:%d %s\n", + ip_vs_proto_name(p->protocol), + IP_VS_DBG_ADDR(p->af, p->caddr), ntohs(p->cport), + IP_VS_DBG_ADDR(p->af, p->vaddr), ntohs(p->vport), + cp ? 
"hit" : "not hit"); + + return cp; +} + +static int +ip_vs_conn_fill_param_proto(struct netns_ipvs *ipvs, + int af, const struct sk_buff *skb, + const struct ip_vs_iphdr *iph, + struct ip_vs_conn_param *p) +{ + __be16 _ports[2], *pptr; + + pptr = frag_safe_skb_hp(skb, iph->len, sizeof(_ports), _ports); + if (pptr == NULL) + return 1; + + if (likely(!ip_vs_iph_inverse(iph))) + ip_vs_conn_fill_param(ipvs, af, iph->protocol, &iph->saddr, + pptr[0], &iph->daddr, pptr[1], p); + else + ip_vs_conn_fill_param(ipvs, af, iph->protocol, &iph->daddr, + pptr[1], &iph->saddr, pptr[0], p); + return 0; +} + +struct ip_vs_conn * +ip_vs_conn_in_get_proto(struct netns_ipvs *ipvs, int af, + const struct sk_buff *skb, + const struct ip_vs_iphdr *iph) +{ + struct ip_vs_conn_param p; + + if (ip_vs_conn_fill_param_proto(ipvs, af, skb, iph, &p)) + return NULL; + + return ip_vs_conn_in_get(&p); +} +EXPORT_SYMBOL_GPL(ip_vs_conn_in_get_proto); + +/* Get reference to connection template */ +struct ip_vs_conn *ip_vs_ct_in_get(const struct ip_vs_conn_param *p) +{ + unsigned int hash; + struct ip_vs_conn *cp; + + hash = ip_vs_conn_hashkey_param(p, false); + + rcu_read_lock(); + + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[hash], c_list) { + if (unlikely(p->pe_data && p->pe->ct_match)) { + if (cp->ipvs != p->ipvs) + continue; + if (p->pe == cp->pe && p->pe->ct_match(p, cp)) { + if (__ip_vs_conn_get(cp)) + goto out; + } + continue; + } + + if (cp->af == p->af && + ip_vs_addr_equal(p->af, p->caddr, &cp->caddr) && + /* protocol should only be IPPROTO_IP if + * p->vaddr is a fwmark */ + ip_vs_addr_equal(p->protocol == IPPROTO_IP ? AF_UNSPEC : + p->af, p->vaddr, &cp->vaddr) && + p->vport == cp->vport && p->cport == cp->cport && + cp->flags & IP_VS_CONN_F_TEMPLATE && + p->protocol == cp->protocol && + cp->ipvs == p->ipvs) { + if (__ip_vs_conn_get(cp)) + goto out; + } + } + cp = NULL; + + out: + rcu_read_unlock(); + + IP_VS_DBG_BUF(9, "template lookup/in %s %s:%d->%s:%d %s\n", + ip_vs_proto_name(p->protocol), + IP_VS_DBG_ADDR(p->af, p->caddr), ntohs(p->cport), + IP_VS_DBG_ADDR(p->af, p->vaddr), ntohs(p->vport), + cp ? "hit" : "not hit"); + + return cp; +} + +/* Gets ip_vs_conn associated with supplied parameters in the ip_vs_conn_tab. + * Called for pkts coming from inside-to-OUTside. + * p->caddr, p->cport: pkt source address (inside host) + * p->vaddr, p->vport: pkt dest address (foreign host) */ +struct ip_vs_conn *ip_vs_conn_out_get(const struct ip_vs_conn_param *p) +{ + unsigned int hash; + struct ip_vs_conn *cp, *ret=NULL; + const union nf_inet_addr *saddr; + __be16 sport; + + /* + * Check for "full" addressed entries + */ + hash = ip_vs_conn_hashkey_param(p, true); + + rcu_read_lock(); + + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[hash], c_list) { + if (p->vport != cp->cport) + continue; + + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) { + sport = cp->vport; + saddr = &cp->vaddr; + } else { + sport = cp->dport; + saddr = &cp->daddr; + } + + if (p->cport == sport && cp->af == p->af && + ip_vs_addr_equal(p->af, p->vaddr, &cp->caddr) && + ip_vs_addr_equal(p->af, p->caddr, saddr) && + p->protocol == cp->protocol && + cp->ipvs == p->ipvs) { + if (!__ip_vs_conn_get(cp)) + continue; + /* HIT */ + ret = cp; + break; + } + } + + rcu_read_unlock(); + + IP_VS_DBG_BUF(9, "lookup/out %s %s:%d->%s:%d %s\n", + ip_vs_proto_name(p->protocol), + IP_VS_DBG_ADDR(p->af, p->caddr), ntohs(p->cport), + IP_VS_DBG_ADDR(p->af, p->vaddr), ntohs(p->vport), + ret ? 
"hit" : "not hit"); + + return ret; +} + +struct ip_vs_conn * +ip_vs_conn_out_get_proto(struct netns_ipvs *ipvs, int af, + const struct sk_buff *skb, + const struct ip_vs_iphdr *iph) +{ + struct ip_vs_conn_param p; + + if (ip_vs_conn_fill_param_proto(ipvs, af, skb, iph, &p)) + return NULL; + + return ip_vs_conn_out_get(&p); +} +EXPORT_SYMBOL_GPL(ip_vs_conn_out_get_proto); + +/* + * Put back the conn and restart its timer with its timeout + */ +static void __ip_vs_conn_put_timer(struct ip_vs_conn *cp) +{ + unsigned long t = (cp->flags & IP_VS_CONN_F_ONE_PACKET) ? + 0 : cp->timeout; + mod_timer(&cp->timer, jiffies+t); + + __ip_vs_conn_put(cp); +} + +void ip_vs_conn_put(struct ip_vs_conn *cp) +{ + if ((cp->flags & IP_VS_CONN_F_ONE_PACKET) && + (refcount_read(&cp->refcnt) == 1) && + !timer_pending(&cp->timer)) + /* expire connection immediately */ + ip_vs_conn_expire(&cp->timer); + else + __ip_vs_conn_put_timer(cp); +} + +/* + * Fill a no_client_port connection with a client port number + */ +void ip_vs_conn_fill_cport(struct ip_vs_conn *cp, __be16 cport) +{ + if (ip_vs_conn_unhash(cp)) { + spin_lock_bh(&cp->lock); + if (cp->flags & IP_VS_CONN_F_NO_CPORT) { + atomic_dec(&ip_vs_conn_no_cport_cnt); + cp->flags &= ~IP_VS_CONN_F_NO_CPORT; + cp->cport = cport; + } + spin_unlock_bh(&cp->lock); + + /* hash on new dport */ + ip_vs_conn_hash(cp); + } +} + + +/* + * Bind a connection entry with the corresponding packet_xmit. + * Called by ip_vs_conn_new. + */ +static inline void ip_vs_bind_xmit(struct ip_vs_conn *cp) +{ + switch (IP_VS_FWD_METHOD(cp)) { + case IP_VS_CONN_F_MASQ: + cp->packet_xmit = ip_vs_nat_xmit; + break; + + case IP_VS_CONN_F_TUNNEL: +#ifdef CONFIG_IP_VS_IPV6 + if (cp->daf == AF_INET6) + cp->packet_xmit = ip_vs_tunnel_xmit_v6; + else +#endif + cp->packet_xmit = ip_vs_tunnel_xmit; + break; + + case IP_VS_CONN_F_DROUTE: + cp->packet_xmit = ip_vs_dr_xmit; + break; + + case IP_VS_CONN_F_LOCALNODE: + cp->packet_xmit = ip_vs_null_xmit; + break; + + case IP_VS_CONN_F_BYPASS: + cp->packet_xmit = ip_vs_bypass_xmit; + break; + } +} + +#ifdef CONFIG_IP_VS_IPV6 +static inline void ip_vs_bind_xmit_v6(struct ip_vs_conn *cp) +{ + switch (IP_VS_FWD_METHOD(cp)) { + case IP_VS_CONN_F_MASQ: + cp->packet_xmit = ip_vs_nat_xmit_v6; + break; + + case IP_VS_CONN_F_TUNNEL: + if (cp->daf == AF_INET6) + cp->packet_xmit = ip_vs_tunnel_xmit_v6; + else + cp->packet_xmit = ip_vs_tunnel_xmit; + break; + + case IP_VS_CONN_F_DROUTE: + cp->packet_xmit = ip_vs_dr_xmit_v6; + break; + + case IP_VS_CONN_F_LOCALNODE: + cp->packet_xmit = ip_vs_null_xmit; + break; + + case IP_VS_CONN_F_BYPASS: + cp->packet_xmit = ip_vs_bypass_xmit_v6; + break; + } +} +#endif + + +static inline int ip_vs_dest_totalconns(struct ip_vs_dest *dest) +{ + return atomic_read(&dest->activeconns) + + atomic_read(&dest->inactconns); +} + +/* + * Bind a connection entry with a virtual service destination + * Called just after a new connection entry is created. 
+ */ +static inline void +ip_vs_bind_dest(struct ip_vs_conn *cp, struct ip_vs_dest *dest) +{ + unsigned int conn_flags; + __u32 flags; + + /* if dest is NULL, then return directly */ + if (!dest) + return; + + /* Increase the refcnt counter of the dest */ + ip_vs_dest_hold(dest); + + conn_flags = atomic_read(&dest->conn_flags); + if (cp->protocol != IPPROTO_UDP) + conn_flags &= ~IP_VS_CONN_F_ONE_PACKET; + flags = cp->flags; + /* Bind with the destination and its corresponding transmitter */ + if (flags & IP_VS_CONN_F_SYNC) { + /* if the connection is not template and is created + * by sync, preserve the activity flag. + */ + if (!(flags & IP_VS_CONN_F_TEMPLATE)) + conn_flags &= ~IP_VS_CONN_F_INACTIVE; + /* connections inherit forwarding method from dest */ + flags &= ~(IP_VS_CONN_F_FWD_MASK | IP_VS_CONN_F_NOOUTPUT); + } + flags |= conn_flags; + cp->flags = flags; + cp->dest = dest; + + IP_VS_DBG_BUF(7, "Bind-dest %s c:%s:%d v:%s:%d " + "d:%s:%d fwd:%c s:%u conn->flags:%X conn->refcnt:%d " + "dest->refcnt:%d\n", + ip_vs_proto_name(cp->protocol), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), ntohs(cp->vport), + IP_VS_DBG_ADDR(cp->daf, &cp->daddr), ntohs(cp->dport), + ip_vs_fwd_tag(cp), cp->state, + cp->flags, refcount_read(&cp->refcnt), + refcount_read(&dest->refcnt)); + + /* Update the connection counters */ + if (!(flags & IP_VS_CONN_F_TEMPLATE)) { + /* It is a normal connection, so modify the counters + * according to the flags, later the protocol can + * update them on state change + */ + if (!(flags & IP_VS_CONN_F_INACTIVE)) + atomic_inc(&dest->activeconns); + else + atomic_inc(&dest->inactconns); + } else { + /* It is a persistent connection/template, so increase + the persistent connection counter */ + atomic_inc(&dest->persistconns); + } + + if (dest->u_threshold != 0 && + ip_vs_dest_totalconns(dest) >= dest->u_threshold) + dest->flags |= IP_VS_DEST_F_OVERLOAD; +} + + +/* + * Check if there is a destination for the connection, if so + * bind the connection to the destination. + */ +void ip_vs_try_bind_dest(struct ip_vs_conn *cp) +{ + struct ip_vs_dest *dest; + + rcu_read_lock(); + + /* This function is only invoked by the synchronization code. We do + * not currently support heterogeneous pools with synchronization, + * so we can make the assumption that the svc_af is the same as the + * dest_af + */ + dest = ip_vs_find_dest(cp->ipvs, cp->af, cp->af, &cp->daddr, + cp->dport, &cp->vaddr, cp->vport, + cp->protocol, cp->fwmark, cp->flags); + if (dest) { + struct ip_vs_proto_data *pd; + + spin_lock_bh(&cp->lock); + if (cp->dest) { + spin_unlock_bh(&cp->lock); + rcu_read_unlock(); + return; + } + + /* Applications work depending on the forwarding method + * but better to reassign them always when binding dest */ + if (cp->app) + ip_vs_unbind_app(cp); + + ip_vs_bind_dest(cp, dest); + spin_unlock_bh(&cp->lock); + + /* Update its packet transmitter */ + cp->packet_xmit = NULL; +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + ip_vs_bind_xmit_v6(cp); + else +#endif + ip_vs_bind_xmit(cp); + + pd = ip_vs_proto_data_get(cp->ipvs, cp->protocol); + if (pd && atomic_read(&pd->appcnt)) + ip_vs_bind_app(cp, pd->pp); + } + rcu_read_unlock(); +} + + +/* + * Unbind a connection entry with its VS destination + * Called by the ip_vs_conn_expire function. 
+ */ +static inline void ip_vs_unbind_dest(struct ip_vs_conn *cp) +{ + struct ip_vs_dest *dest = cp->dest; + + if (!dest) + return; + + IP_VS_DBG_BUF(7, "Unbind-dest %s c:%s:%d v:%s:%d " + "d:%s:%d fwd:%c s:%u conn->flags:%X conn->refcnt:%d " + "dest->refcnt:%d\n", + ip_vs_proto_name(cp->protocol), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), ntohs(cp->vport), + IP_VS_DBG_ADDR(cp->daf, &cp->daddr), ntohs(cp->dport), + ip_vs_fwd_tag(cp), cp->state, + cp->flags, refcount_read(&cp->refcnt), + refcount_read(&dest->refcnt)); + + /* Update the connection counters */ + if (!(cp->flags & IP_VS_CONN_F_TEMPLATE)) { + /* It is a normal connection, so decrease the inactconns + or activeconns counter */ + if (cp->flags & IP_VS_CONN_F_INACTIVE) { + atomic_dec(&dest->inactconns); + } else { + atomic_dec(&dest->activeconns); + } + } else { + /* It is a persistent connection/template, so decrease + the persistent connection counter */ + atomic_dec(&dest->persistconns); + } + + if (dest->l_threshold != 0) { + if (ip_vs_dest_totalconns(dest) < dest->l_threshold) + dest->flags &= ~IP_VS_DEST_F_OVERLOAD; + } else if (dest->u_threshold != 0) { + if (ip_vs_dest_totalconns(dest) * 4 < dest->u_threshold * 3) + dest->flags &= ~IP_VS_DEST_F_OVERLOAD; + } else { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + dest->flags &= ~IP_VS_DEST_F_OVERLOAD; + } + + ip_vs_dest_put(dest); +} + +static int expire_quiescent_template(struct netns_ipvs *ipvs, + struct ip_vs_dest *dest) +{ +#ifdef CONFIG_SYSCTL + return ipvs->sysctl_expire_quiescent_template && + (atomic_read(&dest->weight) == 0); +#else + return 0; +#endif +} + +/* + * Checking if the destination of a connection template is available. + * If available, return 1, otherwise invalidate this connection + * template and return 0. + */ +int ip_vs_check_template(struct ip_vs_conn *ct, struct ip_vs_dest *cdest) +{ + struct ip_vs_dest *dest = ct->dest; + struct netns_ipvs *ipvs = ct->ipvs; + + /* + * Checking the dest server status. + */ + if ((dest == NULL) || + !(dest->flags & IP_VS_DEST_F_AVAILABLE) || + expire_quiescent_template(ipvs, dest) || + (cdest && (dest != cdest))) { + IP_VS_DBG_BUF(9, "check_template: dest not available for " + "protocol %s s:%s:%d v:%s:%d " + "-> d:%s:%d\n", + ip_vs_proto_name(ct->protocol), + IP_VS_DBG_ADDR(ct->af, &ct->caddr), + ntohs(ct->cport), + IP_VS_DBG_ADDR(ct->af, &ct->vaddr), + ntohs(ct->vport), + IP_VS_DBG_ADDR(ct->daf, &ct->daddr), + ntohs(ct->dport)); + + /* + * Invalidate the connection template + */ + if (ct->vport != htons(0xffff)) { + if (ip_vs_conn_unhash(ct)) { + ct->dport = htons(0xffff); + ct->vport = htons(0xffff); + ct->cport = 0; + ip_vs_conn_hash(ct); + } + } + + /* + * Simply decrease the refcnt of the template, + * don't restart its timer. 
+ */ + __ip_vs_conn_put(ct); + return 0; + } + return 1; +} + +static void ip_vs_conn_rcu_free(struct rcu_head *head) +{ + struct ip_vs_conn *cp = container_of(head, struct ip_vs_conn, + rcu_head); + + ip_vs_pe_put(cp->pe); + kfree(cp->pe_data); + kmem_cache_free(ip_vs_conn_cachep, cp); +} + +/* Try to delete connection while not holding reference */ +static void ip_vs_conn_del(struct ip_vs_conn *cp) +{ + if (del_timer(&cp->timer)) { + /* Drop cp->control chain too */ + if (cp->control) + cp->timeout = 0; + ip_vs_conn_expire(&cp->timer); + } +} + +/* Try to delete connection while holding reference */ +static void ip_vs_conn_del_put(struct ip_vs_conn *cp) +{ + if (del_timer(&cp->timer)) { + /* Drop cp->control chain too */ + if (cp->control) + cp->timeout = 0; + __ip_vs_conn_put(cp); + ip_vs_conn_expire(&cp->timer); + } else { + __ip_vs_conn_put(cp); + } +} + +static void ip_vs_conn_expire(struct timer_list *t) +{ + struct ip_vs_conn *cp = from_timer(cp, t, timer); + struct netns_ipvs *ipvs = cp->ipvs; + + /* + * do I control anybody? + */ + if (atomic_read(&cp->n_control)) + goto expire_later; + + /* Unlink conn if not referenced anymore */ + if (likely(ip_vs_conn_unlink(cp))) { + struct ip_vs_conn *ct = cp->control; + + /* delete the timer if it is activated by other users */ + del_timer(&cp->timer); + + /* does anybody control me? */ + if (ct) { + bool has_ref = !cp->timeout && __ip_vs_conn_get(ct); + + ip_vs_control_del(cp); + /* Drop CTL or non-assured TPL if not used anymore */ + if (has_ref && !atomic_read(&ct->n_control) && + (!(ct->flags & IP_VS_CONN_F_TEMPLATE) || + !(ct->state & IP_VS_CTPL_S_ASSURED))) { + IP_VS_DBG(4, "drop controlling connection\n"); + ip_vs_conn_del_put(ct); + } else if (has_ref) { + __ip_vs_conn_put(ct); + } + } + + if ((cp->flags & IP_VS_CONN_F_NFCT) && + !(cp->flags & IP_VS_CONN_F_ONE_PACKET)) { + /* Do not access conntracks during subsys cleanup + * because nf_conntrack_find_get can not be used after + * conntrack cleanup for the net. + */ + smp_rmb(); + if (ipvs->enable) + ip_vs_conn_drop_conntrack(cp); + } + + if (unlikely(cp->app != NULL)) + ip_vs_unbind_app(cp); + ip_vs_unbind_dest(cp); + if (cp->flags & IP_VS_CONN_F_NO_CPORT) + atomic_dec(&ip_vs_conn_no_cport_cnt); + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + ip_vs_conn_rcu_free(&cp->rcu_head); + else + call_rcu(&cp->rcu_head, ip_vs_conn_rcu_free); + atomic_dec(&ipvs->conn_count); + return; + } + + expire_later: + IP_VS_DBG(7, "delayed: conn->refcnt=%d conn->n_control=%d\n", + refcount_read(&cp->refcnt), + atomic_read(&cp->n_control)); + + refcount_inc(&cp->refcnt); + cp->timeout = 60*HZ; + + if (ipvs->sync_state & IP_VS_STATE_MASTER) + ip_vs_sync_conn(ipvs, cp, sysctl_sync_threshold(ipvs)); + + __ip_vs_conn_put_timer(cp); +} + +/* Modify timer, so that it expires as soon as possible. + * Can be called without reference only if under RCU lock. + * We can have such chain of conns linked with ->control: DATA->CTL->TPL + * - DATA (eg. FTP) and TPL (persistence) can be present depending on setup + * - cp->timeout=0 indicates all conns from chain should be dropped but + * TPL is not dropped if in assured state + */ +void ip_vs_conn_expire_now(struct ip_vs_conn *cp) +{ + /* Using mod_timer_pending will ensure the timer is not + * modified after the final del_timer in ip_vs_conn_expire. 
+ */ + if (timer_pending(&cp->timer) && + time_after(cp->timer.expires, jiffies)) + mod_timer_pending(&cp->timer, jiffies); +} + + +/* + * Create a new connection entry and hash it into the ip_vs_conn_tab + */ +struct ip_vs_conn * +ip_vs_conn_new(const struct ip_vs_conn_param *p, int dest_af, + const union nf_inet_addr *daddr, __be16 dport, unsigned int flags, + struct ip_vs_dest *dest, __u32 fwmark) +{ + struct ip_vs_conn *cp; + struct netns_ipvs *ipvs = p->ipvs; + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(p->ipvs, + p->protocol); + + cp = kmem_cache_alloc(ip_vs_conn_cachep, GFP_ATOMIC); + if (cp == NULL) { + IP_VS_ERR_RL("%s(): no memory\n", __func__); + return NULL; + } + + INIT_HLIST_NODE(&cp->c_list); + timer_setup(&cp->timer, ip_vs_conn_expire, 0); + cp->ipvs = ipvs; + cp->af = p->af; + cp->daf = dest_af; + cp->protocol = p->protocol; + ip_vs_addr_set(p->af, &cp->caddr, p->caddr); + cp->cport = p->cport; + /* proto should only be IPPROTO_IP if p->vaddr is a fwmark */ + ip_vs_addr_set(p->protocol == IPPROTO_IP ? AF_UNSPEC : p->af, + &cp->vaddr, p->vaddr); + cp->vport = p->vport; + ip_vs_addr_set(cp->daf, &cp->daddr, daddr); + cp->dport = dport; + cp->flags = flags; + cp->fwmark = fwmark; + if (flags & IP_VS_CONN_F_TEMPLATE && p->pe) { + ip_vs_pe_get(p->pe); + cp->pe = p->pe; + cp->pe_data = p->pe_data; + cp->pe_data_len = p->pe_data_len; + } else { + cp->pe = NULL; + cp->pe_data = NULL; + cp->pe_data_len = 0; + } + spin_lock_init(&cp->lock); + + /* + * Set the entry is referenced by the current thread before hashing + * it in the table, so that other thread run ip_vs_random_dropentry + * but cannot drop this entry. + */ + refcount_set(&cp->refcnt, 1); + + cp->control = NULL; + atomic_set(&cp->n_control, 0); + atomic_set(&cp->in_pkts, 0); + + cp->packet_xmit = NULL; + cp->app = NULL; + cp->app_data = NULL; + /* reset struct ip_vs_seq */ + cp->in_seq.delta = 0; + cp->out_seq.delta = 0; + + atomic_inc(&ipvs->conn_count); + if (flags & IP_VS_CONN_F_NO_CPORT) + atomic_inc(&ip_vs_conn_no_cport_cnt); + + /* Bind the connection with a destination server */ + cp->dest = NULL; + ip_vs_bind_dest(cp, dest); + + /* Set its state and timeout */ + cp->state = 0; + cp->old_state = 0; + cp->timeout = 3*HZ; + cp->sync_endtime = jiffies & ~3UL; + + /* Bind its packet transmitter */ +#ifdef CONFIG_IP_VS_IPV6 + if (p->af == AF_INET6) + ip_vs_bind_xmit_v6(cp); + else +#endif + ip_vs_bind_xmit(cp); + + if (unlikely(pd && atomic_read(&pd->appcnt))) + ip_vs_bind_app(cp, pd->pp); + + /* + * Allow conntrack to be preserved. By default, conntrack + * is created and destroyed for every packet. + * Sometimes keeping conntrack can be useful for + * IP_VS_CONN_F_ONE_PACKET too. 
+ */ + + if (ip_vs_conntrack_enabled(ipvs)) + cp->flags |= IP_VS_CONN_F_NFCT; + + /* Hash it in the ip_vs_conn_tab finally */ + ip_vs_conn_hash(cp); + + return cp; +} + +/* + * /proc/net/ip_vs_conn entries + */ +#ifdef CONFIG_PROC_FS +struct ip_vs_iter_state { + struct seq_net_private p; + struct hlist_head *l; +}; + +static void *ip_vs_conn_array(struct seq_file *seq, loff_t pos) +{ + int idx; + struct ip_vs_conn *cp; + struct ip_vs_iter_state *iter = seq->private; + + for (idx = 0; idx < ip_vs_conn_tab_size; idx++) { + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[idx], c_list) { + /* __ip_vs_conn_get() is not needed by + * ip_vs_conn_seq_show and ip_vs_conn_sync_seq_show + */ + if (pos-- == 0) { + iter->l = &ip_vs_conn_tab[idx]; + return cp; + } + } + cond_resched_rcu(); + } + + return NULL; +} + +static void *ip_vs_conn_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct ip_vs_iter_state *iter = seq->private; + + iter->l = NULL; + rcu_read_lock(); + return *pos ? ip_vs_conn_array(seq, *pos - 1) :SEQ_START_TOKEN; +} + +static void *ip_vs_conn_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct ip_vs_conn *cp = v; + struct ip_vs_iter_state *iter = seq->private; + struct hlist_node *e; + struct hlist_head *l = iter->l; + int idx; + + ++*pos; + if (v == SEQ_START_TOKEN) + return ip_vs_conn_array(seq, 0); + + /* more on same hash chain? */ + e = rcu_dereference(hlist_next_rcu(&cp->c_list)); + if (e) + return hlist_entry(e, struct ip_vs_conn, c_list); + + idx = l - ip_vs_conn_tab; + while (++idx < ip_vs_conn_tab_size) { + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[idx], c_list) { + iter->l = &ip_vs_conn_tab[idx]; + return cp; + } + cond_resched_rcu(); + } + iter->l = NULL; + return NULL; +} + +static void ip_vs_conn_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int ip_vs_conn_seq_show(struct seq_file *seq, void *v) +{ + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "Pro FromIP FPrt ToIP TPrt DestIP DPrt State Expires PEName PEData\n"); + else { + const struct ip_vs_conn *cp = v; + struct net *net = seq_file_net(seq); + char pe_data[IP_VS_PENAME_MAXLEN + IP_VS_PEDATA_MAXLEN + 3]; + size_t len = 0; + char dbuf[IP_VS_ADDRSTRLEN]; + + if (!net_eq(cp->ipvs->net, net)) + return 0; + if (cp->pe_data) { + pe_data[0] = ' '; + len = strlen(cp->pe->name); + memcpy(pe_data + 1, cp->pe->name, len); + pe_data[len + 1] = ' '; + len += 2; + len += cp->pe->show_pe_data(cp, pe_data + len); + } + pe_data[len] = '\0'; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->daf == AF_INET6) + snprintf(dbuf, sizeof(dbuf), "%pI6", &cp->daddr.in6); + else +#endif + snprintf(dbuf, sizeof(dbuf), "%08X", + ntohl(cp->daddr.ip)); + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + seq_printf(seq, "%-3s %pI6 %04X %pI6 %04X " + "%s %04X %-11s %7u%s\n", + ip_vs_proto_name(cp->protocol), + &cp->caddr.in6, ntohs(cp->cport), + &cp->vaddr.in6, ntohs(cp->vport), + dbuf, ntohs(cp->dport), + ip_vs_state_name(cp), + jiffies_delta_to_msecs(cp->timer.expires - + jiffies) / 1000, + pe_data); + else +#endif + seq_printf(seq, + "%-3s %08X %04X %08X %04X" + " %s %04X %-11s %7u%s\n", + ip_vs_proto_name(cp->protocol), + ntohl(cp->caddr.ip), ntohs(cp->cport), + ntohl(cp->vaddr.ip), ntohs(cp->vport), + dbuf, ntohs(cp->dport), + ip_vs_state_name(cp), + jiffies_delta_to_msecs(cp->timer.expires - + jiffies) / 1000, + pe_data); + } + return 0; +} + +static const struct seq_operations ip_vs_conn_seq_ops = { + .start = ip_vs_conn_seq_start, + .next = ip_vs_conn_seq_next, + 
.stop = ip_vs_conn_seq_stop, + .show = ip_vs_conn_seq_show, +}; + +static const char *ip_vs_origin_name(unsigned int flags) +{ + if (flags & IP_VS_CONN_F_SYNC) + return "SYNC"; + else + return "LOCAL"; +} + +static int ip_vs_conn_sync_seq_show(struct seq_file *seq, void *v) +{ + char dbuf[IP_VS_ADDRSTRLEN]; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "Pro FromIP FPrt ToIP TPrt DestIP DPrt State Origin Expires\n"); + else { + const struct ip_vs_conn *cp = v; + struct net *net = seq_file_net(seq); + + if (!net_eq(cp->ipvs->net, net)) + return 0; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->daf == AF_INET6) + snprintf(dbuf, sizeof(dbuf), "%pI6", &cp->daddr.in6); + else +#endif + snprintf(dbuf, sizeof(dbuf), "%08X", + ntohl(cp->daddr.ip)); + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + seq_printf(seq, "%-3s %pI6 %04X %pI6 %04X " + "%s %04X %-11s %-6s %7u\n", + ip_vs_proto_name(cp->protocol), + &cp->caddr.in6, ntohs(cp->cport), + &cp->vaddr.in6, ntohs(cp->vport), + dbuf, ntohs(cp->dport), + ip_vs_state_name(cp), + ip_vs_origin_name(cp->flags), + jiffies_delta_to_msecs(cp->timer.expires - + jiffies) / 1000); + else +#endif + seq_printf(seq, + "%-3s %08X %04X %08X %04X " + "%s %04X %-11s %-6s %7u\n", + ip_vs_proto_name(cp->protocol), + ntohl(cp->caddr.ip), ntohs(cp->cport), + ntohl(cp->vaddr.ip), ntohs(cp->vport), + dbuf, ntohs(cp->dport), + ip_vs_state_name(cp), + ip_vs_origin_name(cp->flags), + jiffies_delta_to_msecs(cp->timer.expires - + jiffies) / 1000); + } + return 0; +} + +static const struct seq_operations ip_vs_conn_sync_seq_ops = { + .start = ip_vs_conn_seq_start, + .next = ip_vs_conn_seq_next, + .stop = ip_vs_conn_seq_stop, + .show = ip_vs_conn_sync_seq_show, +}; +#endif + + +/* Randomly drop connection entries before running out of memory + * Can be used for DATA and CTL conns. For TPL conns there are exceptions: + * - traffic for services in OPS mode increases ct->in_pkts, so it is supported + * - traffic for services not in OPS mode does not increase ct->in_pkts in + * all cases, so it is not supported + */ +static inline int todrop_entry(struct ip_vs_conn *cp) +{ + /* + * The drop rate array needs tuning for real environments. + * Called from timer bh only => no locking + */ + static const signed char todrop_rate[9] = {0, 1, 2, 3, 4, 5, 6, 7, 8}; + static signed char todrop_counter[9] = {0}; + int i; + + /* if the conn entry hasn't lasted for 60 seconds, don't drop it. + This will leave enough time for normal connection to get + through. 
*/ + if (time_before(cp->timeout + jiffies, cp->timer.expires + 60*HZ)) + return 0; + + /* Don't drop the entry if its number of incoming packets is not + located in [0, 8] */ + i = atomic_read(&cp->in_pkts); + if (i > 8 || i < 0) return 0; + + if (!todrop_rate[i]) return 0; + if (--todrop_counter[i] > 0) return 0; + + todrop_counter[i] = todrop_rate[i]; + return 1; +} + +static inline bool ip_vs_conn_ops_mode(struct ip_vs_conn *cp) +{ + struct ip_vs_service *svc; + + if (!cp->dest) + return false; + svc = rcu_dereference(cp->dest->svc); + return svc && (svc->flags & IP_VS_SVC_F_ONEPACKET); +} + +/* Called from keventd and must protect itself from softirqs */ +void ip_vs_random_dropentry(struct netns_ipvs *ipvs) +{ + int idx; + struct ip_vs_conn *cp; + + rcu_read_lock(); + /* + * Randomly scan 1/32 of the whole table every second + */ + for (idx = 0; idx < (ip_vs_conn_tab_size>>5); idx++) { + unsigned int hash = get_random_u32() & ip_vs_conn_tab_mask; + + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[hash], c_list) { + if (cp->ipvs != ipvs) + continue; + if (atomic_read(&cp->n_control)) + continue; + if (cp->flags & IP_VS_CONN_F_TEMPLATE) { + /* connection template of OPS */ + if (ip_vs_conn_ops_mode(cp)) + goto try_drop; + if (!(cp->state & IP_VS_CTPL_S_ASSURED)) + goto drop; + continue; + } + if (cp->protocol == IPPROTO_TCP) { + switch(cp->state) { + case IP_VS_TCP_S_SYN_RECV: + case IP_VS_TCP_S_SYNACK: + break; + + case IP_VS_TCP_S_ESTABLISHED: + if (todrop_entry(cp)) + break; + continue; + + default: + continue; + } + } else if (cp->protocol == IPPROTO_SCTP) { + switch (cp->state) { + case IP_VS_SCTP_S_INIT1: + case IP_VS_SCTP_S_INIT: + break; + case IP_VS_SCTP_S_ESTABLISHED: + if (todrop_entry(cp)) + break; + continue; + default: + continue; + } + } else { +try_drop: + if (!todrop_entry(cp)) + continue; + } + +drop: + IP_VS_DBG(4, "drop connection\n"); + ip_vs_conn_del(cp); + } + cond_resched_rcu(); + } + rcu_read_unlock(); +} + + +/* + * Flush all the connection entries in the ip_vs_conn_tab + */ +static void ip_vs_conn_flush(struct netns_ipvs *ipvs) +{ + int idx; + struct ip_vs_conn *cp, *cp_c; + +flush_again: + rcu_read_lock(); + for (idx = 0; idx < ip_vs_conn_tab_size; idx++) { + + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[idx], c_list) { + if (cp->ipvs != ipvs) + continue; + if (atomic_read(&cp->n_control)) + continue; + cp_c = cp->control; + IP_VS_DBG(4, "del connection\n"); + ip_vs_conn_del(cp); + if (cp_c && !atomic_read(&cp_c->n_control)) { + IP_VS_DBG(4, "del controlling connection\n"); + ip_vs_conn_del(cp_c); + } + } + cond_resched_rcu(); + } + rcu_read_unlock(); + + /* the counter may be not NULL, because maybe some conn entries + are run by slow timer handler or unhashed but still referred */ + if (atomic_read(&ipvs->conn_count) != 0) { + schedule(); + goto flush_again; + } +} + +#ifdef CONFIG_SYSCTL +void ip_vs_expire_nodest_conn_flush(struct netns_ipvs *ipvs) +{ + int idx; + struct ip_vs_conn *cp, *cp_c; + struct ip_vs_dest *dest; + + rcu_read_lock(); + for (idx = 0; idx < ip_vs_conn_tab_size; idx++) { + hlist_for_each_entry_rcu(cp, &ip_vs_conn_tab[idx], c_list) { + if (cp->ipvs != ipvs) + continue; + + dest = cp->dest; + if (!dest || (dest->flags & IP_VS_DEST_F_AVAILABLE)) + continue; + + if (atomic_read(&cp->n_control)) + continue; + + cp_c = cp->control; + IP_VS_DBG(4, "del connection\n"); + ip_vs_conn_del(cp); + if (cp_c && !atomic_read(&cp_c->n_control)) { + IP_VS_DBG(4, "del controlling connection\n"); + ip_vs_conn_del(cp_c); + } + } + cond_resched_rcu(); + + /* 
netns clean up started, abort delayed work */ + if (!ipvs->enable) + break; + } + rcu_read_unlock(); +} +#endif + +/* + * per netns init and exit + */ +int __net_init ip_vs_conn_net_init(struct netns_ipvs *ipvs) +{ + atomic_set(&ipvs->conn_count, 0); + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("ip_vs_conn", 0, ipvs->net->proc_net, + &ip_vs_conn_seq_ops, + sizeof(struct ip_vs_iter_state))) + goto err_conn; + + if (!proc_create_net("ip_vs_conn_sync", 0, ipvs->net->proc_net, + &ip_vs_conn_sync_seq_ops, + sizeof(struct ip_vs_iter_state))) + goto err_conn_sync; +#endif + + return 0; + +#ifdef CONFIG_PROC_FS +err_conn_sync: + remove_proc_entry("ip_vs_conn", ipvs->net->proc_net); +err_conn: + return -ENOMEM; +#endif +} + +void __net_exit ip_vs_conn_net_cleanup(struct netns_ipvs *ipvs) +{ + /* flush all the connection entries first */ + ip_vs_conn_flush(ipvs); +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip_vs_conn", ipvs->net->proc_net); + remove_proc_entry("ip_vs_conn_sync", ipvs->net->proc_net); +#endif +} + +int __init ip_vs_conn_init(void) +{ + int idx; + + /* Compute size and mask */ + if (ip_vs_conn_tab_bits < 8 || ip_vs_conn_tab_bits > 27) { + pr_info("conn_tab_bits not in [8, 27]. Using default value\n"); + ip_vs_conn_tab_bits = CONFIG_IP_VS_TAB_BITS; + } + ip_vs_conn_tab_size = 1 << ip_vs_conn_tab_bits; + ip_vs_conn_tab_mask = ip_vs_conn_tab_size - 1; + + /* + * Allocate the connection hash table and initialize its list heads + */ + ip_vs_conn_tab = vmalloc(array_size(ip_vs_conn_tab_size, + sizeof(*ip_vs_conn_tab))); + if (!ip_vs_conn_tab) + return -ENOMEM; + + /* Allocate ip_vs_conn slab cache */ + ip_vs_conn_cachep = kmem_cache_create("ip_vs_conn", + sizeof(struct ip_vs_conn), 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!ip_vs_conn_cachep) { + vfree(ip_vs_conn_tab); + return -ENOMEM; + } + + pr_info("Connection hash table configured " + "(size=%d, memory=%ldKbytes)\n", + ip_vs_conn_tab_size, + (long)(ip_vs_conn_tab_size*sizeof(*ip_vs_conn_tab))/1024); + IP_VS_DBG(0, "Each connection entry needs %zd bytes at least\n", + sizeof(struct ip_vs_conn)); + + for (idx = 0; idx < ip_vs_conn_tab_size; idx++) + INIT_HLIST_HEAD(&ip_vs_conn_tab[idx]); + + for (idx = 0; idx < CT_LOCKARRAY_SIZE; idx++) { + spin_lock_init(&__ip_vs_conntbl_lock_array[idx].l); + } + + /* calculate the random value for connection hash */ + get_random_bytes(&ip_vs_conn_rnd, sizeof(ip_vs_conn_rnd)); + + return 0; +} + +void ip_vs_conn_cleanup(void) +{ + /* Wait all ip_vs_conn_rcu_free() callbacks to complete */ + rcu_barrier(); + /* Release the empty cache */ + kmem_cache_destroy(ip_vs_conn_cachep); + vfree(ip_vs_conn_tab); +} diff --git a/net/netfilter/ipvs/ip_vs_core.c b/net/netfilter/ipvs/ip_vs_core.c new file mode 100644 index 000000000..b5ae41966 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_core.c @@ -0,0 +1,2456 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS An implementation of the IP virtual server support for the + * LINUX operating system. IPVS is now implemented as a module + * over the Netfilter framework. IPVS can be used to build a + * high-performance and highly available server based on a + * cluster of servers. + * + * Authors: Wensong Zhang + * Peter Kese + * Julian Anastasov + * + * The IPVS code for kernel 2.2 was done by Wensong Zhang and Peter Kese, + * with changes/fixes from Julian Anastasov, Lars Marowsky-Bree, Horms + * and others. 
+ * + * Changes: + * Paul `Rusty' Russell properly handle non-linear skbs + * Harald Welte don't use nfcache + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include /* for icmp_send */ +#include +#include +#include +#include +#include /* net_generic() */ + +#include +#include + +#ifdef CONFIG_IP_VS_IPV6 +#include +#include +#include +#endif + +#include +#include + + +EXPORT_SYMBOL(register_ip_vs_scheduler); +EXPORT_SYMBOL(unregister_ip_vs_scheduler); +EXPORT_SYMBOL(ip_vs_proto_name); +EXPORT_SYMBOL(ip_vs_conn_new); +EXPORT_SYMBOL(ip_vs_conn_in_get); +EXPORT_SYMBOL(ip_vs_conn_out_get); +#ifdef CONFIG_IP_VS_PROTO_TCP +EXPORT_SYMBOL(ip_vs_tcp_conn_listen); +#endif +EXPORT_SYMBOL(ip_vs_conn_put); +#ifdef CONFIG_IP_VS_DEBUG +EXPORT_SYMBOL(ip_vs_get_debug_level); +#endif +EXPORT_SYMBOL(ip_vs_new_conn_out); + +#if defined(CONFIG_IP_VS_PROTO_TCP) && defined(CONFIG_IP_VS_PROTO_UDP) +#define SNAT_CALL(f, ...) \ + INDIRECT_CALL_2(f, tcp_snat_handler, udp_snat_handler, __VA_ARGS__) +#elif defined(CONFIG_IP_VS_PROTO_TCP) +#define SNAT_CALL(f, ...) INDIRECT_CALL_1(f, tcp_snat_handler, __VA_ARGS__) +#elif defined(CONFIG_IP_VS_PROTO_UDP) +#define SNAT_CALL(f, ...) INDIRECT_CALL_1(f, udp_snat_handler, __VA_ARGS__) +#else +#define SNAT_CALL(f, ...) f(__VA_ARGS__) +#endif + +static unsigned int ip_vs_net_id __read_mostly; +/* netns cnt used for uniqueness */ +static atomic_t ipvs_netns_cnt = ATOMIC_INIT(0); + +/* ID used in ICMP lookups */ +#define icmp_id(icmph) (((icmph)->un).echo.id) +#define icmpv6_id(icmph) (icmph->icmp6_dataun.u_echo.identifier) + +const char *ip_vs_proto_name(unsigned int proto) +{ + static char buf[20]; + + switch (proto) { + case IPPROTO_IP: + return "IP"; + case IPPROTO_UDP: + return "UDP"; + case IPPROTO_TCP: + return "TCP"; + case IPPROTO_SCTP: + return "SCTP"; + case IPPROTO_ICMP: + return "ICMP"; +#ifdef CONFIG_IP_VS_IPV6 + case IPPROTO_ICMPV6: + return "ICMPv6"; +#endif + default: + sprintf(buf, "IP_%u", proto); + return buf; + } +} + +void ip_vs_init_hash_table(struct list_head *table, int rows) +{ + while (--rows >= 0) + INIT_LIST_HEAD(&table[rows]); +} + +static inline void +ip_vs_in_stats(struct ip_vs_conn *cp, struct sk_buff *skb) +{ + struct ip_vs_dest *dest = cp->dest; + struct netns_ipvs *ipvs = cp->ipvs; + + if (dest && (dest->flags & IP_VS_DEST_F_AVAILABLE)) { + struct ip_vs_cpu_stats *s; + struct ip_vs_service *svc; + + local_bh_disable(); + + s = this_cpu_ptr(dest->stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); + u64_stats_update_end(&s->syncp); + + svc = rcu_dereference(dest->svc); + s = this_cpu_ptr(svc->stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); + u64_stats_update_end(&s->syncp); + + s = this_cpu_ptr(ipvs->tot_stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.inpkts); + u64_stats_add(&s->cnt.inbytes, skb->len); + u64_stats_update_end(&s->syncp); + + local_bh_enable(); + } +} + + +static inline void +ip_vs_out_stats(struct ip_vs_conn *cp, struct sk_buff *skb) +{ + struct ip_vs_dest *dest = cp->dest; + struct netns_ipvs *ipvs = cp->ipvs; + + if (dest && (dest->flags & IP_VS_DEST_F_AVAILABLE)) { + struct ip_vs_cpu_stats *s; + struct ip_vs_service *svc; + + local_bh_disable(); + + s = this_cpu_ptr(dest->stats.cpustats); + 
u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); + u64_stats_update_end(&s->syncp); + + svc = rcu_dereference(dest->svc); + s = this_cpu_ptr(svc->stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); + u64_stats_update_end(&s->syncp); + + s = this_cpu_ptr(ipvs->tot_stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.outpkts); + u64_stats_add(&s->cnt.outbytes, skb->len); + u64_stats_update_end(&s->syncp); + + local_bh_enable(); + } +} + + +static inline void +ip_vs_conn_stats(struct ip_vs_conn *cp, struct ip_vs_service *svc) +{ + struct netns_ipvs *ipvs = svc->ipvs; + struct ip_vs_cpu_stats *s; + + local_bh_disable(); + + s = this_cpu_ptr(cp->dest->stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.conns); + u64_stats_update_end(&s->syncp); + + s = this_cpu_ptr(svc->stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.conns); + u64_stats_update_end(&s->syncp); + + s = this_cpu_ptr(ipvs->tot_stats.cpustats); + u64_stats_update_begin(&s->syncp); + u64_stats_inc(&s->cnt.conns); + u64_stats_update_end(&s->syncp); + + local_bh_enable(); +} + + +static inline void +ip_vs_set_state(struct ip_vs_conn *cp, int direction, + const struct sk_buff *skb, + struct ip_vs_proto_data *pd) +{ + if (likely(pd->pp->state_transition)) + pd->pp->state_transition(cp, direction, skb, pd); +} + +static inline int +ip_vs_conn_fill_param_persist(const struct ip_vs_service *svc, + struct sk_buff *skb, int protocol, + const union nf_inet_addr *caddr, __be16 cport, + const union nf_inet_addr *vaddr, __be16 vport, + struct ip_vs_conn_param *p) +{ + ip_vs_conn_fill_param(svc->ipvs, svc->af, protocol, caddr, cport, vaddr, + vport, p); + p->pe = rcu_dereference(svc->pe); + if (p->pe && p->pe->fill_param) + return p->pe->fill_param(p, skb); + + return 0; +} + +/* + * IPVS persistent scheduling function + * It creates a connection entry according to its template if exists, + * or selects a server and creates a connection entry plus a template. 
+ * Locking: we are svc user (svc->refcnt), so we hold all dests too + * Protocols supported: TCP, UDP + */ +static struct ip_vs_conn * +ip_vs_sched_persist(struct ip_vs_service *svc, + struct sk_buff *skb, __be16 src_port, __be16 dst_port, + int *ignored, struct ip_vs_iphdr *iph) +{ + struct ip_vs_conn *cp = NULL; + struct ip_vs_dest *dest; + struct ip_vs_conn *ct; + __be16 dport = 0; /* destination port to forward */ + unsigned int flags; + struct ip_vs_conn_param param; + const union nf_inet_addr fwmark = { .ip = htonl(svc->fwmark) }; + union nf_inet_addr snet; /* source network of the client, + after masking */ + const union nf_inet_addr *src_addr, *dst_addr; + + if (likely(!ip_vs_iph_inverse(iph))) { + src_addr = &iph->saddr; + dst_addr = &iph->daddr; + } else { + src_addr = &iph->daddr; + dst_addr = &iph->saddr; + } + + + /* Mask saddr with the netmask to adjust template granularity */ +#ifdef CONFIG_IP_VS_IPV6 + if (svc->af == AF_INET6) + ipv6_addr_prefix(&snet.in6, &src_addr->in6, + (__force __u32) svc->netmask); + else +#endif + snet.ip = src_addr->ip & svc->netmask; + + IP_VS_DBG_BUF(6, "p-schedule: src %s:%u dest %s:%u " + "mnet %s\n", + IP_VS_DBG_ADDR(svc->af, src_addr), ntohs(src_port), + IP_VS_DBG_ADDR(svc->af, dst_addr), ntohs(dst_port), + IP_VS_DBG_ADDR(svc->af, &snet)); + + /* + * As far as we know, FTP is a very complicated network protocol, and + * it uses control connection and data connections. For active FTP, + * FTP server initialize data connection to the client, its source port + * is often 20. For passive FTP, FTP server tells the clients the port + * that it passively listens to, and the client issues the data + * connection. In the tunneling or direct routing mode, the load + * balancer is on the client-to-server half of connection, the port + * number is unknown to the load balancer. So, a conn template like + * is created for persistent FTP + * service, and a template like + * is created for other persistent services. + */ + { + int protocol = iph->protocol; + const union nf_inet_addr *vaddr = dst_addr; + __be16 vport = 0; + + if (dst_port == svc->port) { + /* non-FTP template: + * + * FTP template: + * + */ + if (svc->port != FTPPORT) + vport = dst_port; + } else { + /* Note: persistent fwmark-based services and + * persistent port zero service are handled here. + * fwmark template: + * + * port zero template: + * + */ + if (svc->fwmark) { + protocol = IPPROTO_IP; + vaddr = &fwmark; + } + } + /* return *ignored = -1 so NF_DROP can be used */ + if (ip_vs_conn_fill_param_persist(svc, skb, protocol, &snet, 0, + vaddr, vport, ¶m) < 0) { + *ignored = -1; + return NULL; + } + } + + /* Check if a template already exists */ + ct = ip_vs_ct_in_get(¶m); + if (!ct || !ip_vs_check_template(ct, NULL)) { + struct ip_vs_scheduler *sched; + + /* + * No template found or the dest of the connection + * template is not available. + * return *ignored=0 i.e. 
ICMP and NF_DROP + */ + sched = rcu_dereference(svc->scheduler); + if (sched) { + /* read svc->sched_data after svc->scheduler */ + smp_rmb(); + dest = sched->schedule(svc, skb, iph); + } else { + dest = NULL; + } + if (!dest) { + IP_VS_DBG(1, "p-schedule: no dest found.\n"); + kfree(param.pe_data); + *ignored = 0; + return NULL; + } + + if (dst_port == svc->port && svc->port != FTPPORT) + dport = dest->port; + + /* Create a template + * This adds param.pe_data to the template, + * and thus param.pe_data will be destroyed + * when the template expires */ + ct = ip_vs_conn_new(¶m, dest->af, &dest->addr, dport, + IP_VS_CONN_F_TEMPLATE, dest, skb->mark); + if (ct == NULL) { + kfree(param.pe_data); + *ignored = -1; + return NULL; + } + + ct->timeout = svc->timeout; + } else { + /* set destination with the found template */ + dest = ct->dest; + kfree(param.pe_data); + } + + dport = dst_port; + if (dport == svc->port && dest->port) + dport = dest->port; + + flags = (svc->flags & IP_VS_SVC_F_ONEPACKET + && iph->protocol == IPPROTO_UDP) ? + IP_VS_CONN_F_ONE_PACKET : 0; + + /* + * Create a new connection according to the template + */ + ip_vs_conn_fill_param(svc->ipvs, svc->af, iph->protocol, src_addr, + src_port, dst_addr, dst_port, ¶m); + + cp = ip_vs_conn_new(¶m, dest->af, &dest->addr, dport, flags, dest, + skb->mark); + if (cp == NULL) { + ip_vs_conn_put(ct); + *ignored = -1; + return NULL; + } + + /* + * Add its control + */ + ip_vs_control_add(cp, ct); + ip_vs_conn_put(ct); + + ip_vs_conn_stats(cp, svc); + return cp; +} + + +/* + * IPVS main scheduling function + * It selects a server according to the virtual service, and + * creates a connection entry. + * Protocols supported: TCP, UDP + * + * Usage of *ignored + * + * 1 : protocol tried to schedule (eg. on SYN), found svc but the + * svc/scheduler decides that this packet should be accepted with + * NF_ACCEPT because it must not be scheduled. + * + * 0 : scheduler can not find destination, so try bypass or + * return ICMP and then NF_DROP (ip_vs_leave). + * + * -1 : scheduler tried to schedule but fatal error occurred, eg. + * ip_vs_conn_new failure (ENOMEM) or ip_vs_sip_fill_param + * failure such as missing Call-ID, ENOMEM on skb_linearize + * or pe_data. In this case we should return NF_DROP without + * any attempts to send ICMP with ip_vs_leave. + */ +struct ip_vs_conn * +ip_vs_schedule(struct ip_vs_service *svc, struct sk_buff *skb, + struct ip_vs_proto_data *pd, int *ignored, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_protocol *pp = pd->pp; + struct ip_vs_conn *cp = NULL; + struct ip_vs_scheduler *sched; + struct ip_vs_dest *dest; + __be16 _ports[2], *pptr, cport, vport; + const void *caddr, *vaddr; + unsigned int flags; + + *ignored = 1; + /* + * IPv6 frags, only the first hit here. + */ + pptr = frag_safe_skb_hp(skb, iph->len, sizeof(_ports), _ports); + if (pptr == NULL) + return NULL; + + if (likely(!ip_vs_iph_inverse(iph))) { + cport = pptr[0]; + caddr = &iph->saddr; + vport = pptr[1]; + vaddr = &iph->daddr; + } else { + cport = pptr[1]; + caddr = &iph->daddr; + vport = pptr[0]; + vaddr = &iph->saddr; + } + + /* + * FTPDATA needs this check when using local real server. + * Never schedule Active FTPDATA connections from real server. + * For LVS-NAT they must be already created. For other methods + * with persistence the connection is created on SYN+ACK. 
+ */ + if (cport == FTPDATA) { + IP_VS_DBG_PKT(12, svc->af, pp, skb, iph->off, + "Not scheduling FTPDATA"); + return NULL; + } + + /* + * Do not schedule replies from local real server. + */ + if ((!skb->dev || skb->dev->flags & IFF_LOOPBACK)) { + iph->hdr_flags ^= IP_VS_HDR_INVERSE; + cp = INDIRECT_CALL_1(pp->conn_in_get, + ip_vs_conn_in_get_proto, svc->ipvs, + svc->af, skb, iph); + iph->hdr_flags ^= IP_VS_HDR_INVERSE; + + if (cp) { + IP_VS_DBG_PKT(12, svc->af, pp, skb, iph->off, + "Not scheduling reply for existing" + " connection"); + __ip_vs_conn_put(cp); + return NULL; + } + } + + /* + * Persistent service + */ + if (svc->flags & IP_VS_SVC_F_PERSISTENT) + return ip_vs_sched_persist(svc, skb, cport, vport, ignored, + iph); + + *ignored = 0; + + /* + * Non-persistent service + */ + if (!svc->fwmark && vport != svc->port) { + if (!svc->port) + pr_err("Schedule: port zero only supported " + "in persistent services, " + "check your ipvs configuration\n"); + return NULL; + } + + sched = rcu_dereference(svc->scheduler); + if (sched) { + /* read svc->sched_data after svc->scheduler */ + smp_rmb(); + dest = sched->schedule(svc, skb, iph); + } else { + dest = NULL; + } + if (dest == NULL) { + IP_VS_DBG(1, "Schedule: no dest found.\n"); + return NULL; + } + + flags = (svc->flags & IP_VS_SVC_F_ONEPACKET + && iph->protocol == IPPROTO_UDP) ? + IP_VS_CONN_F_ONE_PACKET : 0; + + /* + * Create a connection entry. + */ + { + struct ip_vs_conn_param p; + + ip_vs_conn_fill_param(svc->ipvs, svc->af, iph->protocol, + caddr, cport, vaddr, vport, &p); + cp = ip_vs_conn_new(&p, dest->af, &dest->addr, + dest->port ? dest->port : vport, + flags, dest, skb->mark); + if (!cp) { + *ignored = -1; + return NULL; + } + } + + IP_VS_DBG_BUF(6, "Schedule fwd:%c c:%s:%u v:%s:%u " + "d:%s:%u conn->flags:%X conn->refcnt:%d\n", + ip_vs_fwd_tag(cp), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), ntohs(cp->vport), + IP_VS_DBG_ADDR(cp->daf, &cp->daddr), ntohs(cp->dport), + cp->flags, refcount_read(&cp->refcnt)); + + ip_vs_conn_stats(cp, svc); + return cp; +} + +static inline int ip_vs_addr_is_unicast(struct net *net, int af, + union nf_inet_addr *addr) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + return ipv6_addr_type(&addr->in6) & IPV6_ADDR_UNICAST; +#endif + return (inet_addr_type(net, addr->ip) == RTN_UNICAST); +} + +/* + * Pass or drop the packet. + * Called by ip_vs_in, when the virtual service is available but + * no destination is available for a new connection. + */ +int ip_vs_leave(struct ip_vs_service *svc, struct sk_buff *skb, + struct ip_vs_proto_data *pd, struct ip_vs_iphdr *iph) +{ + __be16 _ports[2], *pptr, dport; + struct netns_ipvs *ipvs = svc->ipvs; + struct net *net = ipvs->net; + + pptr = frag_safe_skb_hp(skb, iph->len, sizeof(_ports), _ports); + if (!pptr) + return NF_DROP; + dport = likely(!ip_vs_iph_inverse(iph)) ? pptr[1] : pptr[0]; + + /* if it is fwmark-based service, the cache_bypass sysctl is up + and the destination is a non-local unicast, then create + a cache_bypass connection entry */ + if (sysctl_cache_bypass(ipvs) && svc->fwmark && + !(iph->hdr_flags & (IP_VS_HDR_INVERSE | IP_VS_HDR_ICMP)) && + ip_vs_addr_is_unicast(net, svc->af, &iph->daddr)) { + int ret; + struct ip_vs_conn *cp; + unsigned int flags = (svc->flags & IP_VS_SVC_F_ONEPACKET && + iph->protocol == IPPROTO_UDP) ? 
+ IP_VS_CONN_F_ONE_PACKET : 0; + union nf_inet_addr daddr = { .all = { 0, 0, 0, 0 } }; + + /* create a new connection entry */ + IP_VS_DBG(6, "%s(): create a cache_bypass entry\n", __func__); + { + struct ip_vs_conn_param p; + ip_vs_conn_fill_param(svc->ipvs, svc->af, iph->protocol, + &iph->saddr, pptr[0], + &iph->daddr, pptr[1], &p); + cp = ip_vs_conn_new(&p, svc->af, &daddr, 0, + IP_VS_CONN_F_BYPASS | flags, + NULL, skb->mark); + if (!cp) + return NF_DROP; + } + + /* statistics */ + ip_vs_in_stats(cp, skb); + + /* set state */ + ip_vs_set_state(cp, IP_VS_DIR_INPUT, skb, pd); + + /* transmit the first SYN packet */ + ret = cp->packet_xmit(skb, cp, pd->pp, iph); + /* do not touch skb anymore */ + + if ((cp->flags & IP_VS_CONN_F_ONE_PACKET) && cp->control) + atomic_inc(&cp->control->in_pkts); + else + atomic_inc(&cp->in_pkts); + ip_vs_conn_put(cp); + return ret; + } + + /* + * When the virtual ftp service is presented, packets destined + * for other services on the VIP may get here (except services + * listed in the ipvs table), pass the packets, because it is + * not ipvs job to decide to drop the packets. + */ + if (svc->port == FTPPORT && dport != FTPPORT) + return NF_ACCEPT; + + if (unlikely(ip_vs_iph_icmp(iph))) + return NF_DROP; + + /* + * Notify the client that the destination is unreachable, and + * release the socket buffer. + * Since it is in IP layer, the TCP socket is not actually + * created, the TCP RST packet cannot be sent, instead that + * ICMP_PORT_UNREACH is sent here no matter it is TCP/UDP. --WZ + */ +#ifdef CONFIG_IP_VS_IPV6 + if (svc->af == AF_INET6) { + if (!skb->dev) + skb->dev = net->loopback_dev; + icmpv6_send(skb, ICMPV6_DEST_UNREACH, ICMPV6_PORT_UNREACH, 0); + } else +#endif + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_PORT_UNREACH, 0); + + return NF_DROP; +} + +#ifdef CONFIG_SYSCTL + +static int sysctl_snat_reroute(struct netns_ipvs *ipvs) +{ + return ipvs->sysctl_snat_reroute; +} + +static int sysctl_nat_icmp_send(struct netns_ipvs *ipvs) +{ + return ipvs->sysctl_nat_icmp_send; +} + +#else + +static int sysctl_snat_reroute(struct netns_ipvs *ipvs) { return 0; } +static int sysctl_nat_icmp_send(struct netns_ipvs *ipvs) { return 0; } + +#endif + +__sum16 ip_vs_checksum_complete(struct sk_buff *skb, int offset) +{ + return csum_fold(skb_checksum(skb, offset, skb->len - offset, 0)); +} + +static inline enum ip_defrag_users ip_vs_defrag_user(unsigned int hooknum) +{ + if (NF_INET_LOCAL_IN == hooknum) + return IP_DEFRAG_VS_IN; + if (NF_INET_FORWARD == hooknum) + return IP_DEFRAG_VS_FWD; + return IP_DEFRAG_VS_OUT; +} + +static inline int ip_vs_gather_frags(struct netns_ipvs *ipvs, + struct sk_buff *skb, u_int32_t user) +{ + int err; + + local_bh_disable(); + err = ip_defrag(ipvs->net, skb, user); + local_bh_enable(); + if (!err) + ip_send_check(ip_hdr(skb)); + + return err; +} + +static int ip_vs_route_me_harder(struct netns_ipvs *ipvs, int af, + struct sk_buff *skb, unsigned int hooknum) +{ + if (!sysctl_snat_reroute(ipvs)) + return 0; + /* Reroute replies only to remote clients (FORWARD and LOCAL_OUT) */ + if (NF_INET_LOCAL_IN == hooknum) + return 0; +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + struct dst_entry *dst = skb_dst(skb); + + if (dst->dev && !(dst->dev->flags & IFF_LOOPBACK) && + ip6_route_me_harder(ipvs->net, skb->sk, skb) != 0) + return 1; + } else +#endif + if (!(skb_rtable(skb)->rt_flags & RTCF_LOCAL) && + ip_route_me_harder(ipvs->net, skb->sk, skb, RTN_LOCAL) != 0) + return 1; + + return 0; +} + +/* + * Packet has been made sufficiently writable 
in caller + * - inout: 1=in->out, 0=out->in + */ +void ip_vs_nat_icmp(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, int inout) +{ + struct iphdr *iph = ip_hdr(skb); + unsigned int icmp_offset = iph->ihl*4; + struct icmphdr *icmph = (struct icmphdr *)(skb_network_header(skb) + + icmp_offset); + struct iphdr *ciph = (struct iphdr *)(icmph + 1); + + if (inout) { + iph->saddr = cp->vaddr.ip; + ip_send_check(iph); + ciph->daddr = cp->vaddr.ip; + ip_send_check(ciph); + } else { + iph->daddr = cp->daddr.ip; + ip_send_check(iph); + ciph->saddr = cp->daddr.ip; + ip_send_check(ciph); + } + + /* the TCP/UDP/SCTP port */ + if (IPPROTO_TCP == ciph->protocol || IPPROTO_UDP == ciph->protocol || + IPPROTO_SCTP == ciph->protocol) { + __be16 *ports = (void *)ciph + ciph->ihl*4; + + if (inout) + ports[1] = cp->vport; + else + ports[0] = cp->dport; + } + + /* And finally the ICMP checksum */ + icmph->checksum = 0; + icmph->checksum = ip_vs_checksum_complete(skb, icmp_offset); + skb->ip_summed = CHECKSUM_UNNECESSARY; + + if (inout) + IP_VS_DBG_PKT(11, AF_INET, pp, skb, (void *)ciph - (void *)iph, + "Forwarding altered outgoing ICMP"); + else + IP_VS_DBG_PKT(11, AF_INET, pp, skb, (void *)ciph - (void *)iph, + "Forwarding altered incoming ICMP"); +} + +#ifdef CONFIG_IP_VS_IPV6 +void ip_vs_nat_icmp_v6(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, int inout) +{ + struct ipv6hdr *iph = ipv6_hdr(skb); + unsigned int icmp_offset = 0; + unsigned int offs = 0; /* header offset*/ + int protocol; + struct icmp6hdr *icmph; + struct ipv6hdr *ciph; + unsigned short fragoffs; + + ipv6_find_hdr(skb, &icmp_offset, IPPROTO_ICMPV6, &fragoffs, NULL); + icmph = (struct icmp6hdr *)(skb_network_header(skb) + icmp_offset); + offs = icmp_offset + sizeof(struct icmp6hdr); + ciph = (struct ipv6hdr *)(skb_network_header(skb) + offs); + + protocol = ipv6_find_hdr(skb, &offs, -1, &fragoffs, NULL); + + if (inout) { + iph->saddr = cp->vaddr.in6; + ciph->daddr = cp->vaddr.in6; + } else { + iph->daddr = cp->daddr.in6; + ciph->saddr = cp->daddr.in6; + } + + /* the TCP/UDP/SCTP port */ + if (!fragoffs && (IPPROTO_TCP == protocol || IPPROTO_UDP == protocol || + IPPROTO_SCTP == protocol)) { + __be16 *ports = (void *)(skb_network_header(skb) + offs); + + IP_VS_DBG(11, "%s() changed port %d to %d\n", __func__, + ntohs(inout ? ports[1] : ports[0]), + ntohs(inout ? cp->vport : cp->dport)); + if (inout) + ports[1] = cp->vport; + else + ports[0] = cp->dport; + } + + /* And finally the ICMP checksum */ + icmph->icmp6_cksum = ~csum_ipv6_magic(&iph->saddr, &iph->daddr, + skb->len - icmp_offset, + IPPROTO_ICMPV6, 0); + skb->csum_start = skb_network_header(skb) - skb->head + icmp_offset; + skb->csum_offset = offsetof(struct icmp6hdr, icmp6_cksum); + skb->ip_summed = CHECKSUM_PARTIAL; + + if (inout) + IP_VS_DBG_PKT(11, AF_INET6, pp, skb, + (void *)ciph - (void *)iph, + "Forwarding altered outgoing ICMPv6"); + else + IP_VS_DBG_PKT(11, AF_INET6, pp, skb, + (void *)ciph - (void *)iph, + "Forwarding altered incoming ICMPv6"); +} +#endif + +/* Handle relevant response ICMP messages - forward to the right + * destination host. 
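+ * Only NAT (masquerading) connections are mangled here: after the
+ * checksum check, ip_vs_nat_icmp()/ip_vs_nat_icmp_v6() rewrite the outer
+ * source and the embedded destination (plus the embedded TCP/UDP/SCTP
+ * port) back to the virtual service and the packet may be re-routed;
+ * for the other forwarding methods the reply is only accounted and
+ * accepted.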
+ */ +static int handle_response_icmp(int af, struct sk_buff *skb, + union nf_inet_addr *snet, + __u8 protocol, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, + unsigned int offset, unsigned int ihl, + unsigned int hooknum) +{ + unsigned int verdict = NF_DROP; + + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + goto after_nat; + + /* Ensure the checksum is correct */ + if (!skb_csum_unnecessary(skb) && ip_vs_checksum_complete(skb, ihl)) { + /* Failed checksum! */ + IP_VS_DBG_BUF(1, "Forward ICMP: failed checksum from %s!\n", + IP_VS_DBG_ADDR(af, snet)); + goto out; + } + + if (IPPROTO_TCP == protocol || IPPROTO_UDP == protocol || + IPPROTO_SCTP == protocol) + offset += 2 * sizeof(__u16); + if (skb_ensure_writable(skb, offset)) + goto out; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + ip_vs_nat_icmp_v6(skb, pp, cp, 1); + else +#endif + ip_vs_nat_icmp(skb, pp, cp, 1); + + if (ip_vs_route_me_harder(cp->ipvs, af, skb, hooknum)) + goto out; + +after_nat: + /* do the statistics and put it back */ + ip_vs_out_stats(cp, skb); + + skb->ipvs_property = 1; + if (!(cp->flags & IP_VS_CONN_F_NFCT)) + ip_vs_notrack(skb); + else + ip_vs_update_conntrack(skb, cp, 0); + verdict = NF_ACCEPT; + +out: + __ip_vs_conn_put(cp); + + return verdict; +} + +/* + * Handle ICMP messages in the inside-to-outside direction (outgoing). + * Find any that might be relevant, check against existing connections. + * Currently handles error types - unreachable, quench, ttl exceeded. + */ +static int ip_vs_out_icmp(struct netns_ipvs *ipvs, struct sk_buff *skb, + int *related, unsigned int hooknum) +{ + struct iphdr *iph; + struct icmphdr _icmph, *ic; + struct iphdr _ciph, *cih; /* The ip header contained within the ICMP */ + struct ip_vs_iphdr ciph; + struct ip_vs_conn *cp; + struct ip_vs_protocol *pp; + unsigned int offset, ihl; + union nf_inet_addr snet; + + *related = 1; + + /* reassemble IP fragments */ + if (ip_is_fragment(ip_hdr(skb))) { + if (ip_vs_gather_frags(ipvs, skb, ip_vs_defrag_user(hooknum))) + return NF_STOLEN; + } + + iph = ip_hdr(skb); + offset = ihl = iph->ihl * 4; + ic = skb_header_pointer(skb, offset, sizeof(_icmph), &_icmph); + if (ic == NULL) + return NF_DROP; + + IP_VS_DBG(12, "Outgoing ICMP (%d,%d) %pI4->%pI4\n", + ic->type, ntohs(icmp_id(ic)), + &iph->saddr, &iph->daddr); + + /* + * Work through seeing if this is for us. + * These checks are supposed to be in an order that means easy + * things are checked first to speed up processing.... however + * this means that some packets will manage to get a long way + * down this stack and then be rejected, but that's life. + */ + if ((ic->type != ICMP_DEST_UNREACH) && + (ic->type != ICMP_SOURCE_QUENCH) && + (ic->type != ICMP_TIME_EXCEEDED)) { + *related = 0; + return NF_ACCEPT; + } + + /* Now find the contained IP header */ + offset += sizeof(_icmph); + cih = skb_header_pointer(skb, offset, sizeof(_ciph), &_ciph); + if (cih == NULL) + return NF_ACCEPT; /* The packet looks wrong, ignore */ + + pp = ip_vs_proto_get(cih->protocol); + if (!pp) + return NF_ACCEPT; + + /* Is the embedded protocol header present? 
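+ * (a non-first fragment carries no transport header, so protocols with
+ * dont_defrag set are passed through untouched)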
*/ + if (unlikely(cih->frag_off & htons(IP_OFFSET) && + pp->dont_defrag)) + return NF_ACCEPT; + + IP_VS_DBG_PKT(11, AF_INET, pp, skb, offset, + "Checking outgoing ICMP for"); + + ip_vs_fill_iph_skb_icmp(AF_INET, skb, offset, true, &ciph); + + /* The embedded headers contain source and dest in reverse order */ + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, AF_INET, skb, &ciph); + if (!cp) + return NF_ACCEPT; + + snet.ip = iph->saddr; + return handle_response_icmp(AF_INET, skb, &snet, cih->protocol, cp, + pp, ciph.len, ihl, hooknum); +} + +#ifdef CONFIG_IP_VS_IPV6 +static int ip_vs_out_icmp_v6(struct netns_ipvs *ipvs, struct sk_buff *skb, + int *related, unsigned int hooknum, + struct ip_vs_iphdr *ipvsh) +{ + struct icmp6hdr _icmph, *ic; + struct ip_vs_iphdr ciph = {.flags = 0, .fragoffs = 0};/*Contained IP */ + struct ip_vs_conn *cp; + struct ip_vs_protocol *pp; + union nf_inet_addr snet; + unsigned int offset; + + *related = 1; + ic = frag_safe_skb_hp(skb, ipvsh->len, sizeof(_icmph), &_icmph); + if (ic == NULL) + return NF_DROP; + + /* + * Work through seeing if this is for us. + * These checks are supposed to be in an order that means easy + * things are checked first to speed up processing.... however + * this means that some packets will manage to get a long way + * down this stack and then be rejected, but that's life. + */ + if (ic->icmp6_type & ICMPV6_INFOMSG_MASK) { + *related = 0; + return NF_ACCEPT; + } + /* Fragment header that is before ICMP header tells us that: + * it's not an error message since they can't be fragmented. + */ + if (ipvsh->flags & IP6_FH_F_FRAG) + return NF_DROP; + + IP_VS_DBG(8, "Outgoing ICMPv6 (%d,%d) %pI6c->%pI6c\n", + ic->icmp6_type, ntohs(icmpv6_id(ic)), + &ipvsh->saddr, &ipvsh->daddr); + + if (!ip_vs_fill_iph_skb_icmp(AF_INET6, skb, ipvsh->len + sizeof(_icmph), + true, &ciph)) + return NF_ACCEPT; /* The packet looks wrong, ignore */ + + pp = ip_vs_proto_get(ciph.protocol); + if (!pp) + return NF_ACCEPT; + + /* The embedded headers contain source and dest in reverse order */ + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, AF_INET6, skb, &ciph); + if (!cp) + return NF_ACCEPT; + + snet.in6 = ciph.saddr.in6; + offset = ciph.len; + return handle_response_icmp(AF_INET6, skb, &snet, ciph.protocol, cp, + pp, offset, sizeof(struct ipv6hdr), + hooknum); +} +#endif + +/* + * Check if sctp chunc is ABORT chunk + */ +static inline int is_sctp_abort(const struct sk_buff *skb, int nh_len) +{ + struct sctp_chunkhdr *sch, schunk; + sch = skb_header_pointer(skb, nh_len + sizeof(struct sctphdr), + sizeof(schunk), &schunk); + if (sch == NULL) + return 0; + if (sch->type == SCTP_CID_ABORT) + return 1; + return 0; +} + +static inline int is_tcp_reset(const struct sk_buff *skb, int nh_len) +{ + struct tcphdr _tcph, *th; + + th = skb_header_pointer(skb, nh_len, sizeof(_tcph), &_tcph); + if (th == NULL) + return 0; + return th->rst; +} + +static inline bool is_new_conn(const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + switch (iph->protocol) { + case IPPROTO_TCP: { + struct tcphdr _tcph, *th; + + th = skb_header_pointer(skb, iph->len, sizeof(_tcph), &_tcph); + if (th == NULL) + return false; + return th->syn; + } + case IPPROTO_SCTP: { + struct sctp_chunkhdr *sch, schunk; + + sch = skb_header_pointer(skb, iph->len + sizeof(struct sctphdr), + sizeof(schunk), &schunk); + if (sch == NULL) + return false; + return sch->type == SCTP_CID_INIT; + } + default: + return false; + } +} + +static inline bool 
is_new_conn_expected(const struct ip_vs_conn *cp, + int conn_reuse_mode) +{ + /* Controlled (FTP DATA or persistence)? */ + if (cp->control) + return false; + + switch (cp->protocol) { + case IPPROTO_TCP: + return (cp->state == IP_VS_TCP_S_TIME_WAIT) || + (cp->state == IP_VS_TCP_S_CLOSE) || + ((conn_reuse_mode & 2) && + (cp->state == IP_VS_TCP_S_FIN_WAIT) && + (cp->flags & IP_VS_CONN_F_NOOUTPUT)); + case IPPROTO_SCTP: + return cp->state == IP_VS_SCTP_S_CLOSED; + default: + return false; + } +} + +/* Generic function to create new connections for outgoing RS packets + * + * Pre-requisites for successful connection creation: + * 1) Virtual Service is NOT fwmark based: + * In fwmark-VS actual vaddr and vport are unknown to IPVS + * 2) Real Server and Virtual Service were NOT configured without port: + * This is to allow match of different VS to the same RS ip-addr + */ +struct ip_vs_conn *ip_vs_new_conn_out(struct ip_vs_service *svc, + struct ip_vs_dest *dest, + struct sk_buff *skb, + const struct ip_vs_iphdr *iph, + __be16 dport, + __be16 cport) +{ + struct ip_vs_conn_param param; + struct ip_vs_conn *ct = NULL, *cp = NULL; + const union nf_inet_addr *vaddr, *daddr, *caddr; + union nf_inet_addr snet; + __be16 vport; + unsigned int flags; + + EnterFunction(12); + vaddr = &svc->addr; + vport = svc->port; + daddr = &iph->saddr; + caddr = &iph->daddr; + + /* check pre-requisites are satisfied */ + if (svc->fwmark) + return NULL; + if (!vport || !dport) + return NULL; + + /* for persistent service first create connection template */ + if (svc->flags & IP_VS_SVC_F_PERSISTENT) { + /* apply netmask the same way ingress-side does */ +#ifdef CONFIG_IP_VS_IPV6 + if (svc->af == AF_INET6) + ipv6_addr_prefix(&snet.in6, &caddr->in6, + (__force __u32)svc->netmask); + else +#endif + snet.ip = caddr->ip & svc->netmask; + /* fill params and create template if not existent */ + if (ip_vs_conn_fill_param_persist(svc, skb, iph->protocol, + &snet, 0, vaddr, + vport, ¶m) < 0) + return NULL; + ct = ip_vs_ct_in_get(¶m); + /* check if template exists and points to the same dest */ + if (!ct || !ip_vs_check_template(ct, dest)) { + ct = ip_vs_conn_new(¶m, dest->af, daddr, dport, + IP_VS_CONN_F_TEMPLATE, dest, 0); + if (!ct) { + kfree(param.pe_data); + return NULL; + } + ct->timeout = svc->timeout; + } else { + kfree(param.pe_data); + } + } + + /* connection flags */ + flags = ((svc->flags & IP_VS_SVC_F_ONEPACKET) && + iph->protocol == IPPROTO_UDP) ? IP_VS_CONN_F_ONE_PACKET : 0; + /* create connection */ + ip_vs_conn_fill_param(svc->ipvs, svc->af, iph->protocol, + caddr, cport, vaddr, vport, ¶m); + cp = ip_vs_conn_new(¶m, dest->af, daddr, dport, flags, dest, 0); + if (!cp) { + if (ct) + ip_vs_conn_put(ct); + return NULL; + } + if (ct) { + ip_vs_control_add(cp, ct); + ip_vs_conn_put(ct); + } + ip_vs_conn_stats(cp, svc); + + /* return connection (will be used to handle outgoing packet) */ + IP_VS_DBG_BUF(6, "New connection RS-initiated:%c c:%s:%u v:%s:%u " + "d:%s:%u conn->flags:%X conn->refcnt:%d\n", + ip_vs_fwd_tag(cp), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), ntohs(cp->vport), + IP_VS_DBG_ADDR(cp->af, &cp->daddr), ntohs(cp->dport), + cp->flags, refcount_read(&cp->refcnt)); + LeaveFunction(12); + return cp; +} + +/* Handle outgoing packets which are considered requests initiated by + * real servers, so that subsequent responses from external client can be + * routed to the right real server. + * Used also for outgoing responses in OPS mode. 
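+ * ip_vs_out_hook() only takes this path for UDP and only while at least
+ * one service has enabled it (conn_out_counter); packets seen at
+ * LOCAL_IN are skipped.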
+ * + * Connection management is handled by persistent-engine specific callback. + */ +static struct ip_vs_conn *__ip_vs_rs_conn_out(unsigned int hooknum, + struct netns_ipvs *ipvs, + int af, struct sk_buff *skb, + const struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest; + struct ip_vs_conn *cp = NULL; + __be16 _ports[2], *pptr; + + if (hooknum == NF_INET_LOCAL_IN) + return NULL; + + pptr = frag_safe_skb_hp(skb, iph->len, + sizeof(_ports), _ports); + if (!pptr) + return NULL; + + dest = ip_vs_find_real_service(ipvs, af, iph->protocol, + &iph->saddr, pptr[0]); + if (dest) { + struct ip_vs_service *svc; + struct ip_vs_pe *pe; + + svc = rcu_dereference(dest->svc); + if (svc) { + pe = rcu_dereference(svc->pe); + if (pe && pe->conn_out) + cp = pe->conn_out(svc, dest, skb, iph, + pptr[0], pptr[1]); + } + } + + return cp; +} + +/* Handle response packets: rewrite addresses and send away... + */ +static unsigned int +handle_response(int af, struct sk_buff *skb, struct ip_vs_proto_data *pd, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph, + unsigned int hooknum) +{ + struct ip_vs_protocol *pp = pd->pp; + + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + goto after_nat; + + IP_VS_DBG_PKT(11, af, pp, skb, iph->off, "Outgoing packet"); + + if (skb_ensure_writable(skb, iph->len)) + goto drop; + + /* mangle the packet */ + if (pp->snat_handler && + !SNAT_CALL(pp->snat_handler, skb, pp, cp, iph)) + goto drop; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + ipv6_hdr(skb)->saddr = cp->vaddr.in6; + else +#endif + { + ip_hdr(skb)->saddr = cp->vaddr.ip; + ip_send_check(ip_hdr(skb)); + } + + /* + * nf_iterate does not expect change in the skb->dst->dev. + * It looks like it is not fatal to enable this code for hooks + * where our handlers are at the end of the chain list and + * when all next handlers use skb->dst->dev and not outdev. + * It will definitely route properly the inout NAT traffic + * when multiple paths are used. + */ + + /* For policy routing, packets originating from this + * machine itself may be routed differently to packets + * passing through. We want this packet to be routed as + * if it came from this machine itself. So re-compute + * the routing information. + */ + if (ip_vs_route_me_harder(cp->ipvs, af, skb, hooknum)) + goto drop; + + IP_VS_DBG_PKT(10, af, pp, skb, iph->off, "After SNAT"); + +after_nat: + ip_vs_out_stats(cp, skb); + ip_vs_set_state(cp, IP_VS_DIR_OUTPUT, skb, pd); + skb->ipvs_property = 1; + if (!(cp->flags & IP_VS_CONN_F_NFCT)) + ip_vs_notrack(skb); + else + ip_vs_update_conntrack(skb, cp, 0); + ip_vs_conn_put(cp); + + LeaveFunction(11); + return NF_ACCEPT; + +drop: + ip_vs_conn_put(cp); + kfree_skb(skb); + LeaveFunction(11); + return NF_STOLEN; +} + +/* + * Check if outgoing packet belongs to the established ip_vs_conn. + */ +static unsigned int +ip_vs_out_hook(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) +{ + struct netns_ipvs *ipvs = net_ipvs(state->net); + unsigned int hooknum = state->hook; + struct ip_vs_iphdr iph; + struct ip_vs_protocol *pp; + struct ip_vs_proto_data *pd; + struct ip_vs_conn *cp; + int af = state->pf; + struct sock *sk; + + EnterFunction(11); + + /* Already marked as IPVS request or reply? */ + if (skb->ipvs_property) + return NF_ACCEPT; + + sk = skb_to_full_sk(skb); + /* Bad... 
Do not break raw sockets */ + if (unlikely(sk && hooknum == NF_INET_LOCAL_OUT && + af == AF_INET)) { + + if (sk->sk_family == PF_INET && inet_sk(sk)->nodefrag) + return NF_ACCEPT; + } + + if (unlikely(!skb_dst(skb))) + return NF_ACCEPT; + + if (!ipvs->enable) + return NF_ACCEPT; + + ip_vs_fill_iph_skb(af, skb, false, &iph); +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + if (unlikely(iph.protocol == IPPROTO_ICMPV6)) { + int related; + int verdict = ip_vs_out_icmp_v6(ipvs, skb, &related, + hooknum, &iph); + + if (related) + return verdict; + } + } else +#endif + if (unlikely(iph.protocol == IPPROTO_ICMP)) { + int related; + int verdict = ip_vs_out_icmp(ipvs, skb, &related, hooknum); + + if (related) + return verdict; + } + + pd = ip_vs_proto_data_get(ipvs, iph.protocol); + if (unlikely(!pd)) + return NF_ACCEPT; + pp = pd->pp; + + /* reassemble IP fragments */ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET) +#endif + if (unlikely(ip_is_fragment(ip_hdr(skb)) && !pp->dont_defrag)) { + if (ip_vs_gather_frags(ipvs, skb, + ip_vs_defrag_user(hooknum))) + return NF_STOLEN; + + ip_vs_fill_iph_skb(AF_INET, skb, false, &iph); + } + + /* + * Check if the packet belongs to an existing entry + */ + cp = INDIRECT_CALL_1(pp->conn_out_get, ip_vs_conn_out_get_proto, + ipvs, af, skb, &iph); + + if (likely(cp)) + return handle_response(af, skb, pd, cp, &iph, hooknum); + + /* Check for real-server-started requests */ + if (atomic_read(&ipvs->conn_out_counter)) { + /* Currently only for UDP: + * connection oriented protocols typically use + * ephemeral ports for outgoing connections, so + * related incoming responses would not match any VS + */ + if (pp->protocol == IPPROTO_UDP) { + cp = __ip_vs_rs_conn_out(hooknum, ipvs, af, skb, &iph); + if (likely(cp)) + return handle_response(af, skb, pd, cp, &iph, + hooknum); + } + } + + if (sysctl_nat_icmp_send(ipvs) && + (pp->protocol == IPPROTO_TCP || + pp->protocol == IPPROTO_UDP || + pp->protocol == IPPROTO_SCTP)) { + __be16 _ports[2], *pptr; + + pptr = frag_safe_skb_hp(skb, iph.len, + sizeof(_ports), _ports); + if (pptr == NULL) + return NF_ACCEPT; /* Not for me */ + if (ip_vs_has_real_service(ipvs, af, iph.protocol, &iph.saddr, + pptr[0])) { + /* + * Notify the real server: there is no + * existing entry if it is not RST + * packet or not TCP packet. 
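+ * In practice: UDP packets, TCP packets that are not RST and SCTP
+ * packets that are not ABORT trigger an ICMP/ICMPv6 port-unreachable
+ * and are then dropped, while RST and ABORT segments simply continue
+ * traversal.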
+ */ + if ((iph.protocol != IPPROTO_TCP && + iph.protocol != IPPROTO_SCTP) + || ((iph.protocol == IPPROTO_TCP + && !is_tcp_reset(skb, iph.len)) + || (iph.protocol == IPPROTO_SCTP + && !is_sctp_abort(skb, + iph.len)))) { +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + if (!skb->dev) + skb->dev = ipvs->net->loopback_dev; + icmpv6_send(skb, + ICMPV6_DEST_UNREACH, + ICMPV6_PORT_UNREACH, + 0); + } else +#endif + icmp_send(skb, + ICMP_DEST_UNREACH, + ICMP_PORT_UNREACH, 0); + return NF_DROP; + } + } + } + + IP_VS_DBG_PKT(12, af, pp, skb, iph.off, + "ip_vs_out: packet continues traversal as normal"); + return NF_ACCEPT; +} + +static unsigned int +ip_vs_try_to_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, + struct ip_vs_proto_data *pd, + int *verdict, struct ip_vs_conn **cpp, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_protocol *pp = pd->pp; + + if (!iph->fragoffs) { + /* No (second) fragments need to enter here, as nf_defrag_ipv6 + * replayed fragment zero will already have created the cp + */ + + /* Schedule and create new connection entry into cpp */ + if (!pp->conn_schedule(ipvs, af, skb, pd, verdict, cpp, iph)) + return 0; + } + + if (unlikely(!*cpp)) { + /* sorry, all this trouble for a no-hit :) */ + IP_VS_DBG_PKT(12, af, pp, skb, iph->off, + "ip_vs_in: packet continues traversal as normal"); + + /* Fragment couldn't be mapped to a conn entry */ + if (iph->fragoffs) + IP_VS_DBG_PKT(7, af, pp, skb, iph->off, + "unhandled fragment"); + + *verdict = NF_ACCEPT; + return 0; + } + + return 1; +} + +/* Check the UDP tunnel and return its header length */ +static int ipvs_udp_decap(struct netns_ipvs *ipvs, struct sk_buff *skb, + unsigned int offset, __u16 af, + const union nf_inet_addr *daddr, __u8 *proto) +{ + struct udphdr _udph, *udph; + struct ip_vs_dest *dest; + + udph = skb_header_pointer(skb, offset, sizeof(_udph), &_udph); + if (!udph) + goto unk; + offset += sizeof(struct udphdr); + dest = ip_vs_find_tunnel(ipvs, af, daddr, udph->dest); + if (!dest) + goto unk; + if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + struct guehdr _gueh, *gueh; + + gueh = skb_header_pointer(skb, offset, sizeof(_gueh), &_gueh); + if (!gueh) + goto unk; + if (gueh->control != 0 || gueh->version != 0) + goto unk; + /* Later we can support also IPPROTO_IPV6 */ + if (gueh->proto_ctype != IPPROTO_IPIP) + goto unk; + *proto = gueh->proto_ctype; + return sizeof(struct udphdr) + sizeof(struct guehdr) + + (gueh->hlen << 2); + } + +unk: + return 0; +} + +/* Check the GRE tunnel and return its header length */ +static int ipvs_gre_decap(struct netns_ipvs *ipvs, struct sk_buff *skb, + unsigned int offset, __u16 af, + const union nf_inet_addr *daddr, __u8 *proto) +{ + struct gre_base_hdr _greh, *greh; + struct ip_vs_dest *dest; + + greh = skb_header_pointer(skb, offset, sizeof(_greh), &_greh); + if (!greh) + goto unk; + dest = ip_vs_find_tunnel(ipvs, af, daddr, 0); + if (!dest) + goto unk; + if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + __be16 type; + + /* Only support version 0 and C (csum) */ + if ((greh->flags & ~GRE_CSUM) != 0) + goto unk; + type = greh->protocol; + /* Later we can support also IPPROTO_IPV6 */ + if (type != htons(ETH_P_IP)) + goto unk; + *proto = IPPROTO_IPIP; + return gre_calc_hlen(gre_flags_to_tnl_flags(greh->flags)); + } + +unk: + return 0; +} + +/* + * Handle ICMP messages in the outside-to-inside direction (incoming). + * Find any that might be relevant, check against existing connections, + * forward to the right destination host if relevant. 
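+ * Errors that embed one of our IPIP/UDP(GUE)/GRE tunnel encapsulations
+ * are unwrapped first (ipvs_udp_decap()/ipvs_gre_decap() above), so that
+ * e.g. an ICMP_FRAG_NEEDED can update the path MTU and be relayed to the
+ * tunnelled client.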
+ * Currently handles error types - unreachable, quench, ttl exceeded. + */ +static int +ip_vs_in_icmp(struct netns_ipvs *ipvs, struct sk_buff *skb, int *related, + unsigned int hooknum) +{ + struct iphdr *iph; + struct icmphdr _icmph, *ic; + struct iphdr _ciph, *cih; /* The ip header contained within the ICMP */ + struct ip_vs_iphdr ciph; + struct ip_vs_conn *cp; + struct ip_vs_protocol *pp; + struct ip_vs_proto_data *pd; + unsigned int offset, offset2, ihl, verdict; + bool tunnel, new_cp = false; + union nf_inet_addr *raddr; + char *outer_proto = "IPIP"; + + *related = 1; + + /* reassemble IP fragments */ + if (ip_is_fragment(ip_hdr(skb))) { + if (ip_vs_gather_frags(ipvs, skb, ip_vs_defrag_user(hooknum))) + return NF_STOLEN; + } + + iph = ip_hdr(skb); + offset = ihl = iph->ihl * 4; + ic = skb_header_pointer(skb, offset, sizeof(_icmph), &_icmph); + if (ic == NULL) + return NF_DROP; + + IP_VS_DBG(12, "Incoming ICMP (%d,%d) %pI4->%pI4\n", + ic->type, ntohs(icmp_id(ic)), + &iph->saddr, &iph->daddr); + + /* + * Work through seeing if this is for us. + * These checks are supposed to be in an order that means easy + * things are checked first to speed up processing.... however + * this means that some packets will manage to get a long way + * down this stack and then be rejected, but that's life. + */ + if ((ic->type != ICMP_DEST_UNREACH) && + (ic->type != ICMP_SOURCE_QUENCH) && + (ic->type != ICMP_TIME_EXCEEDED)) { + *related = 0; + return NF_ACCEPT; + } + + /* Now find the contained IP header */ + offset += sizeof(_icmph); + cih = skb_header_pointer(skb, offset, sizeof(_ciph), &_ciph); + if (cih == NULL) + return NF_ACCEPT; /* The packet looks wrong, ignore */ + raddr = (union nf_inet_addr *)&cih->daddr; + + /* Special case for errors for IPIP/UDP/GRE tunnel packets */ + tunnel = false; + if (cih->protocol == IPPROTO_IPIP) { + struct ip_vs_dest *dest; + + if (unlikely(cih->frag_off & htons(IP_OFFSET))) + return NF_ACCEPT; + /* Error for our IPIP must arrive at LOCAL_IN */ + if (!(skb_rtable(skb)->rt_flags & RTCF_LOCAL)) + return NF_ACCEPT; + dest = ip_vs_find_tunnel(ipvs, AF_INET, raddr, 0); + /* Only for known tunnel */ + if (!dest || dest->tun_type != IP_VS_CONN_F_TUNNEL_TYPE_IPIP) + return NF_ACCEPT; + offset += cih->ihl * 4; + cih = skb_header_pointer(skb, offset, sizeof(_ciph), &_ciph); + if (cih == NULL) + return NF_ACCEPT; /* The packet looks wrong, ignore */ + tunnel = true; + } else if ((cih->protocol == IPPROTO_UDP || /* Can be UDP encap */ + cih->protocol == IPPROTO_GRE) && /* Can be GRE encap */ + /* Error for our tunnel must arrive at LOCAL_IN */ + (skb_rtable(skb)->rt_flags & RTCF_LOCAL)) { + __u8 iproto; + int ulen; + + /* Non-first fragment has no UDP/GRE header */ + if (unlikely(cih->frag_off & htons(IP_OFFSET))) + return NF_ACCEPT; + offset2 = offset + cih->ihl * 4; + if (cih->protocol == IPPROTO_UDP) { + ulen = ipvs_udp_decap(ipvs, skb, offset2, AF_INET, + raddr, &iproto); + outer_proto = "UDP"; + } else { + ulen = ipvs_gre_decap(ipvs, skb, offset2, AF_INET, + raddr, &iproto); + outer_proto = "GRE"; + } + if (ulen > 0) { + /* Skip IP and UDP/GRE tunnel headers */ + offset = offset2 + ulen; + /* Now we should be at the original IP header */ + cih = skb_header_pointer(skb, offset, sizeof(_ciph), + &_ciph); + if (cih && cih->version == 4 && cih->ihl >= 5 && + iproto == IPPROTO_IPIP) + tunnel = true; + else + return NF_ACCEPT; + } + } + + pd = ip_vs_proto_data_get(ipvs, cih->protocol); + if (!pd) + return NF_ACCEPT; + pp = pd->pp; + + /* Is the embedded protocol header present? 
*/ + if (unlikely(cih->frag_off & htons(IP_OFFSET) && + pp->dont_defrag)) + return NF_ACCEPT; + + IP_VS_DBG_PKT(11, AF_INET, pp, skb, offset, + "Checking incoming ICMP for"); + + offset2 = offset; + ip_vs_fill_iph_skb_icmp(AF_INET, skb, offset, !tunnel, &ciph); + offset = ciph.len; + + /* The embedded headers contain source and dest in reverse order. + * For IPIP/UDP/GRE tunnel this is error for request, not for reply. + */ + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, AF_INET, skb, &ciph); + + if (!cp) { + int v; + + if (tunnel || !sysctl_schedule_icmp(ipvs)) + return NF_ACCEPT; + + if (!ip_vs_try_to_schedule(ipvs, AF_INET, skb, pd, &v, &cp, &ciph)) + return v; + new_cp = true; + } + + verdict = NF_DROP; + + /* Ensure the checksum is correct */ + if (!skb_csum_unnecessary(skb) && ip_vs_checksum_complete(skb, ihl)) { + /* Failed checksum! */ + IP_VS_DBG(1, "Incoming ICMP: failed checksum from %pI4!\n", + &iph->saddr); + goto out; + } + + if (tunnel) { + __be32 info = ic->un.gateway; + __u8 type = ic->type; + __u8 code = ic->code; + + /* Update the MTU */ + if (ic->type == ICMP_DEST_UNREACH && + ic->code == ICMP_FRAG_NEEDED) { + struct ip_vs_dest *dest = cp->dest; + u32 mtu = ntohs(ic->un.frag.mtu); + __be16 frag_off = cih->frag_off; + + /* Strip outer IP and ICMP, go to IPIP/UDP/GRE header */ + if (pskb_pull(skb, ihl + sizeof(_icmph)) == NULL) + goto ignore_tunnel; + offset2 -= ihl + sizeof(_icmph); + skb_reset_network_header(skb); + IP_VS_DBG(12, "ICMP for %s %pI4->%pI4: mtu=%u\n", + outer_proto, &ip_hdr(skb)->saddr, + &ip_hdr(skb)->daddr, mtu); + ipv4_update_pmtu(skb, ipvs->net, mtu, 0, 0); + /* Client uses PMTUD? */ + if (!(frag_off & htons(IP_DF))) + goto ignore_tunnel; + /* Prefer the resulting PMTU */ + if (dest) { + struct ip_vs_dest_dst *dest_dst; + + dest_dst = rcu_dereference(dest->dest_dst); + if (dest_dst) + mtu = dst_mtu(dest_dst->dst_cache); + } + if (mtu > 68 + sizeof(struct iphdr)) + mtu -= sizeof(struct iphdr); + info = htonl(mtu); + } + /* Strip outer IP, ICMP and IPIP/UDP/GRE, go to IP header of + * original request. + */ + if (pskb_pull(skb, offset2) == NULL) + goto ignore_tunnel; + skb_reset_network_header(skb); + IP_VS_DBG(12, "Sending ICMP for %pI4->%pI4: t=%u, c=%u, i=%u\n", + &ip_hdr(skb)->saddr, &ip_hdr(skb)->daddr, + type, code, ntohl(info)); + icmp_send(skb, type, code, info); + /* ICMP can be shorter but anyways, account it */ + ip_vs_out_stats(cp, skb); + +ignore_tunnel: + consume_skb(skb); + verdict = NF_STOLEN; + goto out; + } + + /* do the statistics and put it back */ + ip_vs_in_stats(cp, skb); + if (IPPROTO_TCP == cih->protocol || IPPROTO_UDP == cih->protocol || + IPPROTO_SCTP == cih->protocol) + offset += 2 * sizeof(__u16); + verdict = ip_vs_icmp_xmit(skb, cp, pp, offset, hooknum, &ciph); + +out: + if (likely(!new_cp)) + __ip_vs_conn_put(cp); + else + ip_vs_conn_put(cp); + + return verdict; +} + +#ifdef CONFIG_IP_VS_IPV6 +static int ip_vs_in_icmp_v6(struct netns_ipvs *ipvs, struct sk_buff *skb, + int *related, unsigned int hooknum, + struct ip_vs_iphdr *iph) +{ + struct icmp6hdr _icmph, *ic; + struct ip_vs_iphdr ciph = {.flags = 0, .fragoffs = 0};/*Contained IP */ + struct ip_vs_conn *cp; + struct ip_vs_protocol *pp; + struct ip_vs_proto_data *pd; + unsigned int offset, verdict; + bool new_cp = false; + + *related = 1; + + ic = frag_safe_skb_hp(skb, iph->len, sizeof(_icmph), &_icmph); + if (ic == NULL) + return NF_DROP; + + /* + * Work through seeing if this is for us. 
+ * These checks are supposed to be in an order that means easy + * things are checked first to speed up processing.... however + * this means that some packets will manage to get a long way + * down this stack and then be rejected, but that's life. + */ + if (ic->icmp6_type & ICMPV6_INFOMSG_MASK) { + *related = 0; + return NF_ACCEPT; + } + /* Fragment header that is before ICMP header tells us that: + * it's not an error message since they can't be fragmented. + */ + if (iph->flags & IP6_FH_F_FRAG) + return NF_DROP; + + IP_VS_DBG(8, "Incoming ICMPv6 (%d,%d) %pI6c->%pI6c\n", + ic->icmp6_type, ntohs(icmpv6_id(ic)), + &iph->saddr, &iph->daddr); + + offset = iph->len + sizeof(_icmph); + if (!ip_vs_fill_iph_skb_icmp(AF_INET6, skb, offset, true, &ciph)) + return NF_ACCEPT; + + pd = ip_vs_proto_data_get(ipvs, ciph.protocol); + if (!pd) + return NF_ACCEPT; + pp = pd->pp; + + /* Cannot handle fragmented embedded protocol */ + if (ciph.fragoffs) + return NF_ACCEPT; + + IP_VS_DBG_PKT(11, AF_INET6, pp, skb, offset, + "Checking incoming ICMPv6 for"); + + /* The embedded headers contain source and dest in reverse order + * if not from localhost + */ + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, AF_INET6, skb, &ciph); + + if (!cp) { + int v; + + if (!sysctl_schedule_icmp(ipvs)) + return NF_ACCEPT; + + if (!ip_vs_try_to_schedule(ipvs, AF_INET6, skb, pd, &v, &cp, &ciph)) + return v; + + new_cp = true; + } + + /* VS/TUN, VS/DR and LOCALNODE just let it go */ + if ((hooknum == NF_INET_LOCAL_OUT) && + (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ)) { + verdict = NF_ACCEPT; + goto out; + } + + /* do the statistics and put it back */ + ip_vs_in_stats(cp, skb); + + /* Need to mangle contained IPv6 header in ICMPv6 packet */ + offset = ciph.len; + if (IPPROTO_TCP == ciph.protocol || IPPROTO_UDP == ciph.protocol || + IPPROTO_SCTP == ciph.protocol) + offset += 2 * sizeof(__u16); /* Also mangle ports */ + + verdict = ip_vs_icmp_xmit_v6(skb, cp, pp, offset, hooknum, &ciph); + +out: + if (likely(!new_cp)) + __ip_vs_conn_put(cp); + else + ip_vs_conn_put(cp); + + return verdict; +} +#endif + + +/* + * Check if it's for virtual services, look it up, + * and send it on its way... + */ +static unsigned int +ip_vs_in_hook(void *priv, struct sk_buff *skb, const struct nf_hook_state *state) +{ + struct netns_ipvs *ipvs = net_ipvs(state->net); + unsigned int hooknum = state->hook; + struct ip_vs_iphdr iph; + struct ip_vs_protocol *pp; + struct ip_vs_proto_data *pd; + struct ip_vs_conn *cp; + int ret, pkts; + struct sock *sk; + int af = state->pf; + + /* Already marked as IPVS request or reply? */ + if (skb->ipvs_property) + return NF_ACCEPT; + + /* + * Big tappo: + * - remote client: only PACKET_HOST + * - route: used for struct net when skb->dev is unset + */ + if (unlikely((skb->pkt_type != PACKET_HOST && + hooknum != NF_INET_LOCAL_OUT) || + !skb_dst(skb))) { + ip_vs_fill_iph_skb(af, skb, false, &iph); + IP_VS_DBG_BUF(12, "packet type=%d proto=%d daddr=%s" + " ignored in hook %u\n", + skb->pkt_type, iph.protocol, + IP_VS_DBG_ADDR(af, &iph.daddr), hooknum); + return NF_ACCEPT; + } + /* ipvs enabled in this netns ? */ + if (unlikely(sysctl_backup_only(ipvs) || !ipvs->enable)) + return NF_ACCEPT; + + ip_vs_fill_iph_skb(af, skb, false, &iph); + + /* Bad... 
Do not break raw sockets */ + sk = skb_to_full_sk(skb); + if (unlikely(sk && hooknum == NF_INET_LOCAL_OUT && + af == AF_INET)) { + + if (sk->sk_family == PF_INET && inet_sk(sk)->nodefrag) + return NF_ACCEPT; + } + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + if (unlikely(iph.protocol == IPPROTO_ICMPV6)) { + int related; + int verdict = ip_vs_in_icmp_v6(ipvs, skb, &related, + hooknum, &iph); + + if (related) + return verdict; + } + } else +#endif + if (unlikely(iph.protocol == IPPROTO_ICMP)) { + int related; + int verdict = ip_vs_in_icmp(ipvs, skb, &related, + hooknum); + + if (related) + return verdict; + } + + /* Protocol supported? */ + pd = ip_vs_proto_data_get(ipvs, iph.protocol); + if (unlikely(!pd)) { + /* The only way we'll see this packet again is if it's + * encapsulated, so mark it with ipvs_property=1 so we + * skip it if we're ignoring tunneled packets + */ + if (sysctl_ignore_tunneled(ipvs)) + skb->ipvs_property = 1; + + return NF_ACCEPT; + } + pp = pd->pp; + /* + * Check if the packet belongs to an existing connection entry + */ + cp = INDIRECT_CALL_1(pp->conn_in_get, ip_vs_conn_in_get_proto, + ipvs, af, skb, &iph); + + if (!iph.fragoffs && is_new_conn(skb, &iph) && cp) { + int conn_reuse_mode = sysctl_conn_reuse_mode(ipvs); + bool old_ct = false, resched = false; + + if (unlikely(sysctl_expire_nodest_conn(ipvs)) && cp->dest && + unlikely(!atomic_read(&cp->dest->weight))) { + resched = true; + old_ct = ip_vs_conn_uses_old_conntrack(cp, skb); + } else if (conn_reuse_mode && + is_new_conn_expected(cp, conn_reuse_mode)) { + old_ct = ip_vs_conn_uses_old_conntrack(cp, skb); + if (!atomic_read(&cp->n_control)) { + resched = true; + } else { + /* Do not reschedule controlling connection + * that uses conntrack while it is still + * referenced by controlled connection(s). + */ + resched = !old_ct; + } + } + + if (resched) { + if (!old_ct) + cp->flags &= ~IP_VS_CONN_F_NFCT; + if (!atomic_read(&cp->n_control)) + ip_vs_conn_expire_now(cp); + __ip_vs_conn_put(cp); + if (old_ct) + return NF_DROP; + cp = NULL; + } + } + + /* Check the server status */ + if (cp && cp->dest && !(cp->dest->flags & IP_VS_DEST_F_AVAILABLE)) { + /* the destination server is not available */ + if (sysctl_expire_nodest_conn(ipvs)) { + bool old_ct = ip_vs_conn_uses_old_conntrack(cp, skb); + + if (!old_ct) + cp->flags &= ~IP_VS_CONN_F_NFCT; + + ip_vs_conn_expire_now(cp); + __ip_vs_conn_put(cp); + if (old_ct) + return NF_DROP; + cp = NULL; + } else { + __ip_vs_conn_put(cp); + return NF_DROP; + } + } + + if (unlikely(!cp)) { + int v; + + if (!ip_vs_try_to_schedule(ipvs, af, skb, pd, &v, &cp, &iph)) + return v; + } + + IP_VS_DBG_PKT(11, af, pp, skb, iph.off, "Incoming packet"); + + ip_vs_in_stats(cp, skb); + ip_vs_set_state(cp, IP_VS_DIR_INPUT, skb, pd); + if (cp->packet_xmit) + ret = cp->packet_xmit(skb, cp, pp, &iph); + /* do not touch skb anymore */ + else { + IP_VS_DBG_RL("warning: packet_xmit is null"); + ret = NF_ACCEPT; + } + + /* Increase its packet counter and check if it is needed + * to be synchronized + * + * Sync connection if it is about to close to + * encorage the standby servers to update the connections timeout + * + * For ONE_PKT let ip_vs_sync_conn() do the filter work. 
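+ * For one-packet (UDP OPS) connections pkts is primed with
+ * sysctl_sync_threshold() and ip_vs_sync_conn() does its own filtering;
+ * otherwise the connection's in_pkts counter is incremented and passed
+ * along while this node acts as sync master.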
+ */ + + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + pkts = sysctl_sync_threshold(ipvs); + else + pkts = atomic_inc_return(&cp->in_pkts); + + if (ipvs->sync_state & IP_VS_STATE_MASTER) + ip_vs_sync_conn(ipvs, cp, pkts); + else if ((cp->flags & IP_VS_CONN_F_ONE_PACKET) && cp->control) + /* increment is done inside ip_vs_sync_conn too */ + atomic_inc(&cp->control->in_pkts); + + ip_vs_conn_put(cp); + return ret; +} + +/* + * It is hooked at the NF_INET_FORWARD chain, in order to catch ICMP + * related packets destined for 0.0.0.0/0. + * When fwmark-based virtual service is used, such as transparent + * cache cluster, TCP packets can be marked and routed to ip_vs_in, + * but ICMP destined for 0.0.0.0/0 cannot not be easily marked and + * sent to ip_vs_in_icmp. So, catch them at the NF_INET_FORWARD chain + * and send them to ip_vs_in_icmp. + */ +static unsigned int +ip_vs_forward_icmp(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct netns_ipvs *ipvs = net_ipvs(state->net); + int r; + + /* ipvs enabled in this netns ? */ + if (unlikely(sysctl_backup_only(ipvs) || !ipvs->enable)) + return NF_ACCEPT; + + if (state->pf == NFPROTO_IPV4) { + if (ip_hdr(skb)->protocol != IPPROTO_ICMP) + return NF_ACCEPT; +#ifdef CONFIG_IP_VS_IPV6 + } else { + struct ip_vs_iphdr iphdr; + + ip_vs_fill_iph_skb(AF_INET6, skb, false, &iphdr); + + if (iphdr.protocol != IPPROTO_ICMPV6) + return NF_ACCEPT; + + return ip_vs_in_icmp_v6(ipvs, skb, &r, state->hook, &iphdr); +#endif + } + + return ip_vs_in_icmp(ipvs, skb, &r, state->hook); +} + +static const struct nf_hook_ops ip_vs_ops4[] = { + /* After packet filtering, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_NAT_SRC - 2, + }, + /* After packet filtering, forward packet through VS/DR, VS/TUN, + * or VS/NAT(change destination), so that filtering rules can be + * applied to IPVS. */ + { + .hook = ip_vs_in_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_NAT_SRC - 1, + }, + /* Before ip_vs_in, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_NAT_DST + 1, + }, + /* After mangle, schedule and forward local requests */ + { + .hook = ip_vs_in_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_NAT_DST + 2, + }, + /* After packet filtering (but before ip_vs_out_icmp), catch icmp + * destined for 0.0.0.0/0, which is for incoming IPVS connections */ + { + .hook = ip_vs_forward_icmp, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_FORWARD, + .priority = 99, + }, + /* After packet filtering, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_FORWARD, + .priority = 100, + }, +}; + +#ifdef CONFIG_IP_VS_IPV6 +static const struct nf_hook_ops ip_vs_ops6[] = { + /* After packet filtering, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_NAT_SRC - 2, + }, + /* After packet filtering, forward packet through VS/DR, VS/TUN, + * or VS/NAT(change destination), so that filtering rules can be + * applied to IPVS. 
*/ + { + .hook = ip_vs_in_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_NAT_SRC - 1, + }, + /* Before ip_vs_in, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_NAT_DST + 1, + }, + /* After mangle, schedule and forward local requests */ + { + .hook = ip_vs_in_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_NAT_DST + 2, + }, + /* After packet filtering (but before ip_vs_out_icmp), catch icmp + * destined for 0.0.0.0/0, which is for incoming IPVS connections */ + { + .hook = ip_vs_forward_icmp, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_FORWARD, + .priority = 99, + }, + /* After packet filtering, change source only for VS/NAT */ + { + .hook = ip_vs_out_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_FORWARD, + .priority = 100, + }, +}; +#endif + +int ip_vs_register_hooks(struct netns_ipvs *ipvs, unsigned int af) +{ + const struct nf_hook_ops *ops; + unsigned int count; + unsigned int afmask; + int ret = 0; + + if (af == AF_INET6) { +#ifdef CONFIG_IP_VS_IPV6 + ops = ip_vs_ops6; + count = ARRAY_SIZE(ip_vs_ops6); + afmask = 2; +#else + return -EINVAL; +#endif + } else { + ops = ip_vs_ops4; + count = ARRAY_SIZE(ip_vs_ops4); + afmask = 1; + } + + if (!(ipvs->hooks_afmask & afmask)) { + ret = nf_register_net_hooks(ipvs->net, ops, count); + if (ret >= 0) + ipvs->hooks_afmask |= afmask; + } + return ret; +} + +void ip_vs_unregister_hooks(struct netns_ipvs *ipvs, unsigned int af) +{ + const struct nf_hook_ops *ops; + unsigned int count; + unsigned int afmask; + + if (af == AF_INET6) { +#ifdef CONFIG_IP_VS_IPV6 + ops = ip_vs_ops6; + count = ARRAY_SIZE(ip_vs_ops6); + afmask = 2; +#else + return; +#endif + } else { + ops = ip_vs_ops4; + count = ARRAY_SIZE(ip_vs_ops4); + afmask = 1; + } + + if (ipvs->hooks_afmask & afmask) { + nf_unregister_net_hooks(ipvs->net, ops, count); + ipvs->hooks_afmask &= ~afmask; + } +} + +/* + * Initialize IP Virtual Server netns mem. 
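+ * Per-netns state is brought up in a fixed order - estimator, control,
+ * protocols, applications, connection table, sync - and the error labels
+ * below unwind exactly the steps already completed, in reverse order.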
+ */ +static int __net_init __ip_vs_init(struct net *net) +{ + struct netns_ipvs *ipvs; + + ipvs = net_generic(net, ip_vs_net_id); + if (ipvs == NULL) + return -ENOMEM; + + /* Hold the beast until a service is registered */ + ipvs->enable = 0; + ipvs->net = net; + /* Counters used for creating unique names */ + ipvs->gen = atomic_read(&ipvs_netns_cnt); + atomic_inc(&ipvs_netns_cnt); + net->ipvs = ipvs; + + if (ip_vs_estimator_net_init(ipvs) < 0) + goto estimator_fail; + + if (ip_vs_control_net_init(ipvs) < 0) + goto control_fail; + + if (ip_vs_protocol_net_init(ipvs) < 0) + goto protocol_fail; + + if (ip_vs_app_net_init(ipvs) < 0) + goto app_fail; + + if (ip_vs_conn_net_init(ipvs) < 0) + goto conn_fail; + + if (ip_vs_sync_net_init(ipvs) < 0) + goto sync_fail; + + return 0; +/* + * Error handling + */ + +sync_fail: + ip_vs_conn_net_cleanup(ipvs); +conn_fail: + ip_vs_app_net_cleanup(ipvs); +app_fail: + ip_vs_protocol_net_cleanup(ipvs); +protocol_fail: + ip_vs_control_net_cleanup(ipvs); +control_fail: + ip_vs_estimator_net_cleanup(ipvs); +estimator_fail: + net->ipvs = NULL; + return -ENOMEM; +} + +static void __net_exit __ip_vs_cleanup_batch(struct list_head *net_list) +{ + struct netns_ipvs *ipvs; + struct net *net; + + ip_vs_service_nets_cleanup(net_list); /* ip_vs_flush() with locks */ + list_for_each_entry(net, net_list, exit_list) { + ipvs = net_ipvs(net); + ip_vs_conn_net_cleanup(ipvs); + ip_vs_app_net_cleanup(ipvs); + ip_vs_protocol_net_cleanup(ipvs); + ip_vs_control_net_cleanup(ipvs); + ip_vs_estimator_net_cleanup(ipvs); + IP_VS_DBG(2, "ipvs netns %d released\n", ipvs->gen); + net->ipvs = NULL; + } +} + +static void __net_exit __ip_vs_dev_cleanup_batch(struct list_head *net_list) +{ + struct netns_ipvs *ipvs; + struct net *net; + + EnterFunction(2); + list_for_each_entry(net, net_list, exit_list) { + ipvs = net_ipvs(net); + ip_vs_unregister_hooks(ipvs, AF_INET); + ip_vs_unregister_hooks(ipvs, AF_INET6); + ipvs->enable = 0; /* Disable packet reception */ + smp_wmb(); + ip_vs_sync_net_cleanup(ipvs); + } + LeaveFunction(2); +} + +static struct pernet_operations ipvs_core_ops = { + .init = __ip_vs_init, + .exit_batch = __ip_vs_cleanup_batch, + .id = &ip_vs_net_id, + .size = sizeof(struct netns_ipvs), +}; + +static struct pernet_operations ipvs_core_dev_ops = { + .exit_batch = __ip_vs_dev_cleanup_batch, +}; + +/* + * Initialize IP Virtual Server + */ +static int __init ip_vs_init(void) +{ + int ret; + + ret = ip_vs_control_init(); + if (ret < 0) { + pr_err("can't setup control.\n"); + goto exit; + } + + ip_vs_protocol_init(); + + ret = ip_vs_conn_init(); + if (ret < 0) { + pr_err("can't setup connection table.\n"); + goto cleanup_protocol; + } + + ret = register_pernet_subsys(&ipvs_core_ops); /* Alloc ip_vs struct */ + if (ret < 0) + goto cleanup_conn; + + ret = register_pernet_device(&ipvs_core_dev_ops); + if (ret < 0) + goto cleanup_sub; + + ret = ip_vs_register_nl_ioctl(); + if (ret < 0) { + pr_err("can't register netlink/ioctl.\n"); + goto cleanup_dev; + } + + pr_info("ipvs loaded.\n"); + + return ret; + +cleanup_dev: + unregister_pernet_device(&ipvs_core_dev_ops); +cleanup_sub: + unregister_pernet_subsys(&ipvs_core_ops); +cleanup_conn: + ip_vs_conn_cleanup(); +cleanup_protocol: + ip_vs_protocol_cleanup(); + ip_vs_control_cleanup(); +exit: + return ret; +} + +static void __exit ip_vs_cleanup(void) +{ + ip_vs_unregister_nl_ioctl(); + unregister_pernet_device(&ipvs_core_dev_ops); + unregister_pernet_subsys(&ipvs_core_ops); /* free ip_vs struct */ + ip_vs_conn_cleanup(); + 
ip_vs_protocol_cleanup(); + ip_vs_control_cleanup(); + pr_info("ipvs unloaded.\n"); +} + +module_init(ip_vs_init); +module_exit(ip_vs_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_ctl.c b/net/netfilter/ipvs/ip_vs_ctl.c new file mode 100644 index 000000000..17a1b731a --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -0,0 +1,4288 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS An implementation of the IP virtual server support for the + * LINUX operating system. IPVS is now implemented as a module + * over the NetFilter framework. IPVS can be used to build a + * high-performance and highly available server based on a + * cluster of servers. + * + * Authors: Wensong Zhang + * Peter Kese + * Julian Anastasov + * + * Changes: + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#ifdef CONFIG_IP_VS_IPV6 +#include +#include +#include +#endif +#include +#include +#include + +#include + +#include + +MODULE_ALIAS_GENL_FAMILY(IPVS_GENL_NAME); + +/* semaphore for IPVS sockopts. And, [gs]etsockopt may sleep. */ +static DEFINE_MUTEX(__ip_vs_mutex); + +/* sysctl variables */ + +#ifdef CONFIG_IP_VS_DEBUG +static int sysctl_ip_vs_debug_level = 0; + +int ip_vs_get_debug_level(void) +{ + return sysctl_ip_vs_debug_level; +} +#endif + + +/* Protos */ +static void __ip_vs_del_service(struct ip_vs_service *svc, bool cleanup); + + +#ifdef CONFIG_IP_VS_IPV6 +/* Taken from rt6_fill_node() in net/ipv6/route.c, is there a better way? */ +static bool __ip_vs_addr_is_local_v6(struct net *net, + const struct in6_addr *addr) +{ + struct flowi6 fl6 = { + .daddr = *addr, + }; + struct dst_entry *dst = ip6_route_output(net, NULL, &fl6); + bool is_local; + + is_local = !dst->error && dst->dev && (dst->dev->flags & IFF_LOOPBACK); + + dst_release(dst); + return is_local; +} +#endif + +#ifdef CONFIG_SYSCTL +/* + * update_defense_level is called from keventd and from sysctl, + * so it needs to protect itself from softirqs + */ +static void update_defense_level(struct netns_ipvs *ipvs) +{ + struct sysinfo i; + int availmem; + int nomem; + int to_change = -1; + + /* we only count free and buffered memory (in pages) */ + si_meminfo(&i); + availmem = i.freeram + i.bufferram; + /* however in linux 2.5 the i.bufferram is total page cache size, + we need adjust it */ + /* si_swapinfo(&i); */ + /* availmem = availmem - (i.totalswap - i.freeswap); */ + + nomem = (availmem < ipvs->sysctl_amemthresh); + + local_bh_disable(); + + /* drop_entry */ + spin_lock(&ipvs->dropentry_lock); + switch (ipvs->sysctl_drop_entry) { + case 0: + atomic_set(&ipvs->dropentry, 0); + break; + case 1: + if (nomem) { + atomic_set(&ipvs->dropentry, 1); + ipvs->sysctl_drop_entry = 2; + } else { + atomic_set(&ipvs->dropentry, 0); + } + break; + case 2: + if (nomem) { + atomic_set(&ipvs->dropentry, 1); + } else { + atomic_set(&ipvs->dropentry, 0); + ipvs->sysctl_drop_entry = 1; + } + break; + case 3: + atomic_set(&ipvs->dropentry, 1); + break; + } + spin_unlock(&ipvs->dropentry_lock); + + /* drop_packet */ + spin_lock(&ipvs->droppacket_lock); + switch (ipvs->sysctl_drop_packet) { + case 0: + ipvs->drop_rate = 0; + break; + case 1: + if (nomem) { + ipvs->drop_rate = ipvs->drop_counter + = ipvs->sysctl_amemthresh / + (ipvs->sysctl_amemthresh-availmem); + ipvs->sysctl_drop_packet = 2; + } else { + ipvs->drop_rate = 0; + } + break; + 
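+	/*
+	 * For modes 1 and 2 the drop rate is
+	 * amemthresh / (amemthresh - availmem): roughly one of every
+	 * drop_rate packets is dropped, so the rate tends towards
+	 * "drop everything" as available memory approaches zero.
+	 */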
case 2: + if (nomem) { + ipvs->drop_rate = ipvs->drop_counter + = ipvs->sysctl_amemthresh / + (ipvs->sysctl_amemthresh-availmem); + } else { + ipvs->drop_rate = 0; + ipvs->sysctl_drop_packet = 1; + } + break; + case 3: + ipvs->drop_rate = ipvs->sysctl_am_droprate; + break; + } + spin_unlock(&ipvs->droppacket_lock); + + /* secure_tcp */ + spin_lock(&ipvs->securetcp_lock); + switch (ipvs->sysctl_secure_tcp) { + case 0: + if (ipvs->old_secure_tcp >= 2) + to_change = 0; + break; + case 1: + if (nomem) { + if (ipvs->old_secure_tcp < 2) + to_change = 1; + ipvs->sysctl_secure_tcp = 2; + } else { + if (ipvs->old_secure_tcp >= 2) + to_change = 0; + } + break; + case 2: + if (nomem) { + if (ipvs->old_secure_tcp < 2) + to_change = 1; + } else { + if (ipvs->old_secure_tcp >= 2) + to_change = 0; + ipvs->sysctl_secure_tcp = 1; + } + break; + case 3: + if (ipvs->old_secure_tcp < 2) + to_change = 1; + break; + } + ipvs->old_secure_tcp = ipvs->sysctl_secure_tcp; + if (to_change >= 0) + ip_vs_protocol_timeout_change(ipvs, + ipvs->sysctl_secure_tcp > 1); + spin_unlock(&ipvs->securetcp_lock); + + local_bh_enable(); +} + +/* Handler for delayed work for expiring no + * destination connections + */ +static void expire_nodest_conn_handler(struct work_struct *work) +{ + struct netns_ipvs *ipvs; + + ipvs = container_of(work, struct netns_ipvs, + expire_nodest_conn_work.work); + ip_vs_expire_nodest_conn_flush(ipvs); +} + +/* + * Timer for checking the defense + */ +#define DEFENSE_TIMER_PERIOD 1*HZ + +static void defense_work_handler(struct work_struct *work) +{ + struct netns_ipvs *ipvs = + container_of(work, struct netns_ipvs, defense_work.work); + + update_defense_level(ipvs); + if (atomic_read(&ipvs->dropentry)) + ip_vs_random_dropentry(ipvs); + queue_delayed_work(system_long_wq, &ipvs->defense_work, + DEFENSE_TIMER_PERIOD); +} +#endif + +int +ip_vs_use_count_inc(void) +{ + return try_module_get(THIS_MODULE); +} + +void +ip_vs_use_count_dec(void) +{ + module_put(THIS_MODULE); +} + + +/* + * Hash table: for virtual service lookups + */ +#define IP_VS_SVC_TAB_BITS 8 +#define IP_VS_SVC_TAB_SIZE (1 << IP_VS_SVC_TAB_BITS) +#define IP_VS_SVC_TAB_MASK (IP_VS_SVC_TAB_SIZE - 1) + +/* the service table hashed by */ +static struct hlist_head ip_vs_svc_table[IP_VS_SVC_TAB_SIZE]; +/* the service table hashed by fwmark */ +static struct hlist_head ip_vs_svc_fwm_table[IP_VS_SVC_TAB_SIZE]; + + +/* + * Returns hash value for virtual service + */ +static inline unsigned int +ip_vs_svc_hashkey(struct netns_ipvs *ipvs, int af, unsigned int proto, + const union nf_inet_addr *addr, __be16 port) +{ + unsigned int porth = ntohs(port); + __be32 addr_fold = addr->ip; + __u32 ahash; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0]^addr->ip6[1]^ + addr->ip6[2]^addr->ip6[3]; +#endif + ahash = ntohl(addr_fold); + ahash ^= ((size_t) ipvs >> 8); + + return (proto ^ ahash ^ (porth >> IP_VS_SVC_TAB_BITS) ^ porth) & + IP_VS_SVC_TAB_MASK; +} + +/* + * Returns hash value of fwmark for virtual service lookup + */ +static inline unsigned int ip_vs_svc_fwm_hashkey(struct netns_ipvs *ipvs, __u32 fwmark) +{ + return (((size_t)ipvs>>8) ^ fwmark) & IP_VS_SVC_TAB_MASK; +} + +/* + * Hashes a service in the ip_vs_svc_table by + * or in the ip_vs_svc_fwm_table by fwmark. + * Should be called with locked tables. 
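+ * The non-fwmark hash key comes from ip_vs_svc_hashkey() above, which
+ * folds protocol, address and port together and also mixes in the
+ * per-netns ipvs pointer, so identical services in different namespaces
+ * land in different buckets of the shared tables.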
+ */ +static int ip_vs_svc_hash(struct ip_vs_service *svc) +{ + unsigned int hash; + + if (svc->flags & IP_VS_SVC_F_HASHED) { + pr_err("%s(): request for already hashed, called from %pS\n", + __func__, __builtin_return_address(0)); + return 0; + } + + if (svc->fwmark == 0) { + /* + * Hash it by in ip_vs_svc_table + */ + hash = ip_vs_svc_hashkey(svc->ipvs, svc->af, svc->protocol, + &svc->addr, svc->port); + hlist_add_head_rcu(&svc->s_list, &ip_vs_svc_table[hash]); + } else { + /* + * Hash it by fwmark in svc_fwm_table + */ + hash = ip_vs_svc_fwm_hashkey(svc->ipvs, svc->fwmark); + hlist_add_head_rcu(&svc->f_list, &ip_vs_svc_fwm_table[hash]); + } + + svc->flags |= IP_VS_SVC_F_HASHED; + /* increase its refcnt because it is referenced by the svc table */ + atomic_inc(&svc->refcnt); + return 1; +} + + +/* + * Unhashes a service from svc_table / svc_fwm_table. + * Should be called with locked tables. + */ +static int ip_vs_svc_unhash(struct ip_vs_service *svc) +{ + if (!(svc->flags & IP_VS_SVC_F_HASHED)) { + pr_err("%s(): request for unhash flagged, called from %pS\n", + __func__, __builtin_return_address(0)); + return 0; + } + + if (svc->fwmark == 0) { + /* Remove it from the svc_table table */ + hlist_del_rcu(&svc->s_list); + } else { + /* Remove it from the svc_fwm_table table */ + hlist_del_rcu(&svc->f_list); + } + + svc->flags &= ~IP_VS_SVC_F_HASHED; + atomic_dec(&svc->refcnt); + return 1; +} + + +/* + * Get service by {netns, proto,addr,port} in the service table. + */ +static inline struct ip_vs_service * +__ip_vs_service_find(struct netns_ipvs *ipvs, int af, __u16 protocol, + const union nf_inet_addr *vaddr, __be16 vport) +{ + unsigned int hash; + struct ip_vs_service *svc; + + /* Check for "full" addressed entries */ + hash = ip_vs_svc_hashkey(ipvs, af, protocol, vaddr, vport); + + hlist_for_each_entry_rcu(svc, &ip_vs_svc_table[hash], s_list) { + if ((svc->af == af) + && ip_vs_addr_equal(af, &svc->addr, vaddr) + && (svc->port == vport) + && (svc->protocol == protocol) + && (svc->ipvs == ipvs)) { + /* HIT */ + return svc; + } + } + + return NULL; +} + + +/* + * Get service by {fwmark} in the service table. + */ +static inline struct ip_vs_service * +__ip_vs_svc_fwm_find(struct netns_ipvs *ipvs, int af, __u32 fwmark) +{ + unsigned int hash; + struct ip_vs_service *svc; + + /* Check for fwmark addressed entries */ + hash = ip_vs_svc_fwm_hashkey(ipvs, fwmark); + + hlist_for_each_entry_rcu(svc, &ip_vs_svc_fwm_table[hash], f_list) { + if (svc->fwmark == fwmark && svc->af == af + && (svc->ipvs == ipvs)) { + /* HIT */ + return svc; + } + } + + return NULL; +} + +/* Find service, called under RCU lock */ +struct ip_vs_service * +ip_vs_service_find(struct netns_ipvs *ipvs, int af, __u32 fwmark, __u16 protocol, + const union nf_inet_addr *vaddr, __be16 vport) +{ + struct ip_vs_service *svc; + + /* + * Check the table hashed by fwmark first + */ + if (fwmark) { + svc = __ip_vs_svc_fwm_find(ipvs, af, fwmark); + if (svc) + goto out; + } + + /* + * Check the table hashed by + * for "full" addressed entries + */ + svc = __ip_vs_service_find(ipvs, af, protocol, vaddr, vport); + + if (!svc && protocol == IPPROTO_TCP && + atomic_read(&ipvs->ftpsvc_counter) && + (vport == FTPDATA || !inet_port_requires_bind_service(ipvs->net, ntohs(vport)))) { + /* + * Check if ftp service entry exists, the packet + * might belong to FTP data connections. 
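+ * This fallback is tried only for TCP, only while an FTP virtual service
+ * exists (ftpsvc_counter), and only when the destination port is FTPDATA
+ * or a port that needs no bind privilege, so passive-mode data
+ * connections are matched to the service registered on FTPPORT (21/tcp).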
+ */ + svc = __ip_vs_service_find(ipvs, af, protocol, vaddr, FTPPORT); + } + + if (svc == NULL + && atomic_read(&ipvs->nullsvc_counter)) { + /* + * Check if the catch-all port (port zero) exists + */ + svc = __ip_vs_service_find(ipvs, af, protocol, vaddr, 0); + } + + out: + IP_VS_DBG_BUF(9, "lookup service: fwm %u %s %s:%u %s\n", + fwmark, ip_vs_proto_name(protocol), + IP_VS_DBG_ADDR(af, vaddr), ntohs(vport), + svc ? "hit" : "not hit"); + + return svc; +} + + +static inline void +__ip_vs_bind_svc(struct ip_vs_dest *dest, struct ip_vs_service *svc) +{ + atomic_inc(&svc->refcnt); + rcu_assign_pointer(dest->svc, svc); +} + +static void ip_vs_service_free(struct ip_vs_service *svc) +{ + free_percpu(svc->stats.cpustats); + kfree(svc); +} + +static void ip_vs_service_rcu_free(struct rcu_head *head) +{ + struct ip_vs_service *svc; + + svc = container_of(head, struct ip_vs_service, rcu_head); + ip_vs_service_free(svc); +} + +static void __ip_vs_svc_put(struct ip_vs_service *svc, bool do_delay) +{ + if (atomic_dec_and_test(&svc->refcnt)) { + IP_VS_DBG_BUF(3, "Removing service %u/%s:%u\n", + svc->fwmark, + IP_VS_DBG_ADDR(svc->af, &svc->addr), + ntohs(svc->port)); + if (do_delay) + call_rcu(&svc->rcu_head, ip_vs_service_rcu_free); + else + ip_vs_service_free(svc); + } +} + + +/* + * Returns hash value for real service + */ +static inline unsigned int ip_vs_rs_hashkey(int af, + const union nf_inet_addr *addr, + __be16 port) +{ + unsigned int porth = ntohs(port); + __be32 addr_fold = addr->ip; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0]^addr->ip6[1]^ + addr->ip6[2]^addr->ip6[3]; +#endif + + return (ntohl(addr_fold)^(porth>>IP_VS_RTAB_BITS)^porth) + & IP_VS_RTAB_MASK; +} + +/* Hash ip_vs_dest in rs_table by . */ +static void ip_vs_rs_hash(struct netns_ipvs *ipvs, struct ip_vs_dest *dest) +{ + unsigned int hash; + __be16 port; + + if (dest->in_rs_table) + return; + + switch (IP_VS_DFWD_METHOD(dest)) { + case IP_VS_CONN_F_MASQ: + port = dest->port; + break; + case IP_VS_CONN_F_TUNNEL: + switch (dest->tun_type) { + case IP_VS_CONN_F_TUNNEL_TYPE_GUE: + port = dest->tun_port; + break; + case IP_VS_CONN_F_TUNNEL_TYPE_IPIP: + case IP_VS_CONN_F_TUNNEL_TYPE_GRE: + port = 0; + break; + default: + return; + } + break; + default: + return; + } + + /* + * Hash by proto,addr,port, + * which are the parameters of the real service. + */ + hash = ip_vs_rs_hashkey(dest->af, &dest->addr, port); + + hlist_add_head_rcu(&dest->d_list, &ipvs->rs_table[hash]); + dest->in_rs_table = 1; +} + +/* Unhash ip_vs_dest from rs_table. */ +static void ip_vs_rs_unhash(struct ip_vs_dest *dest) +{ + /* + * Remove it from the rs_table table. + */ + if (dest->in_rs_table) { + hlist_del_rcu(&dest->d_list); + dest->in_rs_table = 0; + } +} + +/* Check if real service by is present */ +bool ip_vs_has_real_service(struct netns_ipvs *ipvs, int af, __u16 protocol, + const union nf_inet_addr *daddr, __be16 dport) +{ + unsigned int hash; + struct ip_vs_dest *dest; + + /* Check for "full" addressed entries */ + hash = ip_vs_rs_hashkey(af, daddr, dport); + + hlist_for_each_entry_rcu(dest, &ipvs->rs_table[hash], d_list) { + if (dest->port == dport && + dest->af == af && + ip_vs_addr_equal(af, &dest->addr, daddr) && + (dest->protocol == protocol || dest->vfwmark) && + IP_VS_DFWD_METHOD(dest) == IP_VS_CONN_F_MASQ) { + /* HIT */ + return true; + } + } + + return false; +} + +/* Find real service record by . + * In case of multiple records with the same , only + * the first found record is returned. 
+ * + * To be called under RCU lock. + */ +struct ip_vs_dest *ip_vs_find_real_service(struct netns_ipvs *ipvs, int af, + __u16 protocol, + const union nf_inet_addr *daddr, + __be16 dport) +{ + unsigned int hash; + struct ip_vs_dest *dest; + + /* Check for "full" addressed entries */ + hash = ip_vs_rs_hashkey(af, daddr, dport); + + hlist_for_each_entry_rcu(dest, &ipvs->rs_table[hash], d_list) { + if (dest->port == dport && + dest->af == af && + ip_vs_addr_equal(af, &dest->addr, daddr) && + (dest->protocol == protocol || dest->vfwmark) && + IP_VS_DFWD_METHOD(dest) == IP_VS_CONN_F_MASQ) { + /* HIT */ + return dest; + } + } + + return NULL; +} + +/* Find real service record by . + * In case of multiple records with the same , only + * the first found record is returned. + * + * To be called under RCU lock. + */ +struct ip_vs_dest *ip_vs_find_tunnel(struct netns_ipvs *ipvs, int af, + const union nf_inet_addr *daddr, + __be16 tun_port) +{ + struct ip_vs_dest *dest; + unsigned int hash; + + /* Check for "full" addressed entries */ + hash = ip_vs_rs_hashkey(af, daddr, tun_port); + + hlist_for_each_entry_rcu(dest, &ipvs->rs_table[hash], d_list) { + if (dest->tun_port == tun_port && + dest->af == af && + ip_vs_addr_equal(af, &dest->addr, daddr) && + IP_VS_DFWD_METHOD(dest) == IP_VS_CONN_F_TUNNEL) { + /* HIT */ + return dest; + } + } + + return NULL; +} + +/* Lookup destination by {addr,port} in the given service + * Called under RCU lock. + */ +static struct ip_vs_dest * +ip_vs_lookup_dest(struct ip_vs_service *svc, int dest_af, + const union nf_inet_addr *daddr, __be16 dport) +{ + struct ip_vs_dest *dest; + + /* + * Find the destination for the given service + */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if ((dest->af == dest_af) && + ip_vs_addr_equal(dest_af, &dest->addr, daddr) && + (dest->port == dport)) { + /* HIT */ + return dest; + } + } + + return NULL; +} + +/* + * Find destination by {daddr,dport,vaddr,protocol} + * Created to be used in ip_vs_process_message() in + * the backup synchronization daemon. It finds the + * destination to be bound to the received connection + * on the backup. + * Called under RCU lock, no refcnt is returned. + */ +struct ip_vs_dest *ip_vs_find_dest(struct netns_ipvs *ipvs, int svc_af, int dest_af, + const union nf_inet_addr *daddr, + __be16 dport, + const union nf_inet_addr *vaddr, + __be16 vport, __u16 protocol, __u32 fwmark, + __u32 flags) +{ + struct ip_vs_dest *dest; + struct ip_vs_service *svc; + __be16 port = dport; + + svc = ip_vs_service_find(ipvs, svc_af, fwmark, protocol, vaddr, vport); + if (!svc) + return NULL; + if (fwmark && (flags & IP_VS_CONN_F_FWD_MASK) != IP_VS_CONN_F_MASQ) + port = 0; + dest = ip_vs_lookup_dest(svc, dest_af, daddr, port); + if (!dest) + dest = ip_vs_lookup_dest(svc, dest_af, daddr, port ^ dport); + return dest; +} + +void ip_vs_dest_dst_rcu_free(struct rcu_head *head) +{ + struct ip_vs_dest_dst *dest_dst = container_of(head, + struct ip_vs_dest_dst, + rcu_head); + + dst_release(dest_dst->dst_cache); + kfree(dest_dst); +} + +/* Release dest_dst and dst_cache for dest in user context */ +static void __ip_vs_dst_cache_reset(struct ip_vs_dest *dest) +{ + struct ip_vs_dest_dst *old; + + old = rcu_dereference_protected(dest->dest_dst, 1); + if (old) { + RCU_INIT_POINTER(dest->dest_dst, NULL); + call_rcu(&old->rcu_head, ip_vs_dest_dst_rcu_free); + } +} + +/* + * Lookup dest by {svc,addr,port} in the destination trash. 
+ * The destination trash is used to hold the destinations that are removed + * from the service table but are still referenced by some conn entries. + * The reason to add the destination trash is when the dest is temporary + * down (either by administrator or by monitor program), the dest can be + * picked back from the trash, the remaining connections to the dest can + * continue, and the counting information of the dest is also useful for + * scheduling. + */ +static struct ip_vs_dest * +ip_vs_trash_get_dest(struct ip_vs_service *svc, int dest_af, + const union nf_inet_addr *daddr, __be16 dport) +{ + struct ip_vs_dest *dest; + struct netns_ipvs *ipvs = svc->ipvs; + + /* + * Find the destination in trash + */ + spin_lock_bh(&ipvs->dest_trash_lock); + list_for_each_entry(dest, &ipvs->dest_trash, t_list) { + IP_VS_DBG_BUF(3, "Destination %u/%s:%u still in trash, " + "dest->refcnt=%d\n", + dest->vfwmark, + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port), + refcount_read(&dest->refcnt)); + if (dest->af == dest_af && + ip_vs_addr_equal(dest_af, &dest->addr, daddr) && + dest->port == dport && + dest->vfwmark == svc->fwmark && + dest->protocol == svc->protocol && + (svc->fwmark || + (ip_vs_addr_equal(svc->af, &dest->vaddr, &svc->addr) && + dest->vport == svc->port))) { + /* HIT */ + list_del(&dest->t_list); + goto out; + } + } + + dest = NULL; + +out: + spin_unlock_bh(&ipvs->dest_trash_lock); + + return dest; +} + +static void ip_vs_dest_free(struct ip_vs_dest *dest) +{ + struct ip_vs_service *svc = rcu_dereference_protected(dest->svc, 1); + + __ip_vs_dst_cache_reset(dest); + __ip_vs_svc_put(svc, false); + free_percpu(dest->stats.cpustats); + ip_vs_dest_put_and_free(dest); +} + +/* + * Clean up all the destinations in the trash + * Called by the ip_vs_control_cleanup() + * + * When the ip_vs_control_clearup is activated by ipvs module exit, + * the service tables must have been flushed and all the connections + * are expired, and the refcnt of each destination in the trash must + * be 1, so we simply release them here. 
+ */ +static void ip_vs_trash_cleanup(struct netns_ipvs *ipvs) +{ + struct ip_vs_dest *dest, *nxt; + + del_timer_sync(&ipvs->dest_trash_timer); + /* No need to use dest_trash_lock */ + list_for_each_entry_safe(dest, nxt, &ipvs->dest_trash, t_list) { + list_del(&dest->t_list); + ip_vs_dest_free(dest); + } +} + +static void +ip_vs_copy_stats(struct ip_vs_kstats *dst, struct ip_vs_stats *src) +{ +#define IP_VS_SHOW_STATS_COUNTER(c) dst->c = src->kstats.c - src->kstats0.c + + spin_lock_bh(&src->lock); + + IP_VS_SHOW_STATS_COUNTER(conns); + IP_VS_SHOW_STATS_COUNTER(inpkts); + IP_VS_SHOW_STATS_COUNTER(outpkts); + IP_VS_SHOW_STATS_COUNTER(inbytes); + IP_VS_SHOW_STATS_COUNTER(outbytes); + + ip_vs_read_estimator(dst, src); + + spin_unlock_bh(&src->lock); +} + +static void +ip_vs_export_stats_user(struct ip_vs_stats_user *dst, struct ip_vs_kstats *src) +{ + dst->conns = (u32)src->conns; + dst->inpkts = (u32)src->inpkts; + dst->outpkts = (u32)src->outpkts; + dst->inbytes = src->inbytes; + dst->outbytes = src->outbytes; + dst->cps = (u32)src->cps; + dst->inpps = (u32)src->inpps; + dst->outpps = (u32)src->outpps; + dst->inbps = (u32)src->inbps; + dst->outbps = (u32)src->outbps; +} + +static void +ip_vs_zero_stats(struct ip_vs_stats *stats) +{ + spin_lock_bh(&stats->lock); + + /* get current counters as zero point, rates are zeroed */ + +#define IP_VS_ZERO_STATS_COUNTER(c) stats->kstats0.c = stats->kstats.c + + IP_VS_ZERO_STATS_COUNTER(conns); + IP_VS_ZERO_STATS_COUNTER(inpkts); + IP_VS_ZERO_STATS_COUNTER(outpkts); + IP_VS_ZERO_STATS_COUNTER(inbytes); + IP_VS_ZERO_STATS_COUNTER(outbytes); + + ip_vs_zero_estimator(stats); + + spin_unlock_bh(&stats->lock); +} + +/* + * Update a destination in the given service + */ +static void +__ip_vs_update_dest(struct ip_vs_service *svc, struct ip_vs_dest *dest, + struct ip_vs_dest_user_kern *udest, int add) +{ + struct netns_ipvs *ipvs = svc->ipvs; + struct ip_vs_service *old_svc; + struct ip_vs_scheduler *sched; + int conn_flags; + + /* We cannot modify an address and change the address family */ + BUG_ON(!add && udest->af != dest->af); + + if (add && udest->af != svc->af) + ipvs->mixed_address_family_dests++; + + /* keep the last_weight with latest non-0 weight */ + if (add || udest->weight != 0) + atomic_set(&dest->last_weight, udest->weight); + + /* set the weight and the flags */ + atomic_set(&dest->weight, udest->weight); + conn_flags = udest->conn_flags & IP_VS_CONN_F_DEST_MASK; + conn_flags |= IP_VS_CONN_F_INACTIVE; + + /* Need to rehash? */ + if ((udest->conn_flags & IP_VS_CONN_F_FWD_MASK) != + IP_VS_DFWD_METHOD(dest) || + udest->tun_type != dest->tun_type || + udest->tun_port != dest->tun_port) + ip_vs_rs_unhash(dest); + + /* set the tunnel info */ + dest->tun_type = udest->tun_type; + dest->tun_port = udest->tun_port; + dest->tun_flags = udest->tun_flags; + + /* set the IP_VS_CONN_F_NOOUTPUT flag if not masquerading/NAT */ + if ((conn_flags & IP_VS_CONN_F_FWD_MASK) != IP_VS_CONN_F_MASQ) { + conn_flags |= IP_VS_CONN_F_NOOUTPUT; + } else { + /* FTP-NAT requires conntrack for mangling */ + if (svc->port == FTPPORT) + ip_vs_register_conntrack(svc); + } + atomic_set(&dest->conn_flags, conn_flags); + /* Put the real service in rs_table if not present. 
*/ + ip_vs_rs_hash(ipvs, dest); + + /* bind the service */ + old_svc = rcu_dereference_protected(dest->svc, 1); + if (!old_svc) { + __ip_vs_bind_svc(dest, svc); + } else { + if (old_svc != svc) { + ip_vs_zero_stats(&dest->stats); + __ip_vs_bind_svc(dest, svc); + __ip_vs_svc_put(old_svc, true); + } + } + + /* set the dest status flags */ + dest->flags |= IP_VS_DEST_F_AVAILABLE; + + if (udest->u_threshold == 0 || udest->u_threshold > dest->u_threshold) + dest->flags &= ~IP_VS_DEST_F_OVERLOAD; + dest->u_threshold = udest->u_threshold; + dest->l_threshold = udest->l_threshold; + + dest->af = udest->af; + + spin_lock_bh(&dest->dst_lock); + __ip_vs_dst_cache_reset(dest); + spin_unlock_bh(&dest->dst_lock); + + if (add) { + ip_vs_start_estimator(svc->ipvs, &dest->stats); + list_add_rcu(&dest->n_list, &svc->destinations); + svc->num_dests++; + sched = rcu_dereference_protected(svc->scheduler, 1); + if (sched && sched->add_dest) + sched->add_dest(svc, dest); + } else { + sched = rcu_dereference_protected(svc->scheduler, 1); + if (sched && sched->upd_dest) + sched->upd_dest(svc, dest); + } +} + + +/* + * Create a destination for the given service + */ +static int +ip_vs_new_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) +{ + struct ip_vs_dest *dest; + unsigned int atype, i; + + EnterFunction(2); + +#ifdef CONFIG_IP_VS_IPV6 + if (udest->af == AF_INET6) { + int ret; + + atype = ipv6_addr_type(&udest->addr.in6); + if ((!(atype & IPV6_ADDR_UNICAST) || + atype & IPV6_ADDR_LINKLOCAL) && + !__ip_vs_addr_is_local_v6(svc->ipvs->net, &udest->addr.in6)) + return -EINVAL; + + ret = nf_defrag_ipv6_enable(svc->ipvs->net); + if (ret) + return ret; + } else +#endif + { + atype = inet_addr_type(svc->ipvs->net, udest->addr.ip); + if (atype != RTN_LOCAL && atype != RTN_UNICAST) + return -EINVAL; + } + + dest = kzalloc(sizeof(struct ip_vs_dest), GFP_KERNEL); + if (dest == NULL) + return -ENOMEM; + + dest->stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); + if (!dest->stats.cpustats) + goto err_alloc; + + for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *ip_vs_dest_stats; + ip_vs_dest_stats = per_cpu_ptr(dest->stats.cpustats, i); + u64_stats_init(&ip_vs_dest_stats->syncp); + } + + dest->af = udest->af; + dest->protocol = svc->protocol; + dest->vaddr = svc->addr; + dest->vport = svc->port; + dest->vfwmark = svc->fwmark; + ip_vs_addr_copy(udest->af, &dest->addr, &udest->addr); + dest->port = udest->port; + + atomic_set(&dest->activeconns, 0); + atomic_set(&dest->inactconns, 0); + atomic_set(&dest->persistconns, 0); + refcount_set(&dest->refcnt, 1); + + INIT_HLIST_NODE(&dest->d_list); + spin_lock_init(&dest->dst_lock); + spin_lock_init(&dest->stats.lock); + __ip_vs_update_dest(svc, dest, udest, 1); + + LeaveFunction(2); + return 0; + +err_alloc: + kfree(dest); + return -ENOMEM; +} + + +/* + * Add a destination into an existing service + */ +static int +ip_vs_add_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) +{ + struct ip_vs_dest *dest; + union nf_inet_addr daddr; + __be16 dport = udest->port; + int ret; + + EnterFunction(2); + + if (udest->weight < 0) { + pr_err("%s(): server weight less than zero\n", __func__); + return -ERANGE; + } + + if (udest->l_threshold > udest->u_threshold) { + pr_err("%s(): lower threshold is higher than upper threshold\n", + __func__); + return -ERANGE; + } + + if (udest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + if (udest->tun_port == 0) { + pr_err("%s(): tunnel port is zero\n", __func__); + return -EINVAL; + } + } + + 
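+	/*
+	 * Illustrative sketch only (example values, not from a real setup):
+	 * a request adding a NAT real server 192.168.1.10:80 with weight 1
+	 * reaches this point roughly as
+	 *
+	 *	udest->af         = AF_INET;
+	 *	udest->addr.ip    = htonl(0xc0a8010a);	(192.168.1.10)
+	 *	udest->port       = htons(80);
+	 *	udest->conn_flags = IP_VS_CONN_F_MASQ;
+	 *	udest->weight     = 1;
+	 *
+	 * __ip_vs_update_dest() then masks conn_flags with
+	 * IP_VS_CONN_F_DEST_MASK, ORs in IP_VS_CONN_F_INACTIVE and hashes
+	 * the destination into rs_table.
+	 */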
ip_vs_addr_copy(udest->af, &daddr, &udest->addr); + + /* We use function that requires RCU lock */ + rcu_read_lock(); + dest = ip_vs_lookup_dest(svc, udest->af, &daddr, dport); + rcu_read_unlock(); + + if (dest != NULL) { + IP_VS_DBG(1, "%s(): dest already exists\n", __func__); + return -EEXIST; + } + + /* + * Check if the dest already exists in the trash and + * is from the same service + */ + dest = ip_vs_trash_get_dest(svc, udest->af, &daddr, dport); + + if (dest != NULL) { + IP_VS_DBG_BUF(3, "Get destination %s:%u from trash, " + "dest->refcnt=%d, service %u/%s:%u\n", + IP_VS_DBG_ADDR(udest->af, &daddr), ntohs(dport), + refcount_read(&dest->refcnt), + dest->vfwmark, + IP_VS_DBG_ADDR(svc->af, &dest->vaddr), + ntohs(dest->vport)); + + __ip_vs_update_dest(svc, dest, udest, 1); + ret = 0; + } else { + /* + * Allocate and initialize the dest structure + */ + ret = ip_vs_new_dest(svc, udest); + } + LeaveFunction(2); + + return ret; +} + + +/* + * Edit a destination in the given service + */ +static int +ip_vs_edit_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) +{ + struct ip_vs_dest *dest; + union nf_inet_addr daddr; + __be16 dport = udest->port; + + EnterFunction(2); + + if (udest->weight < 0) { + pr_err("%s(): server weight less than zero\n", __func__); + return -ERANGE; + } + + if (udest->l_threshold > udest->u_threshold) { + pr_err("%s(): lower threshold is higher than upper threshold\n", + __func__); + return -ERANGE; + } + + if (udest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + if (udest->tun_port == 0) { + pr_err("%s(): tunnel port is zero\n", __func__); + return -EINVAL; + } + } + + ip_vs_addr_copy(udest->af, &daddr, &udest->addr); + + /* We use function that requires RCU lock */ + rcu_read_lock(); + dest = ip_vs_lookup_dest(svc, udest->af, &daddr, dport); + rcu_read_unlock(); + + if (dest == NULL) { + IP_VS_DBG(1, "%s(): dest doesn't exist\n", __func__); + return -ENOENT; + } + + __ip_vs_update_dest(svc, dest, udest, 0); + LeaveFunction(2); + + return 0; +} + +/* + * Delete a destination (must be already unlinked from the service) + */ +static void __ip_vs_del_dest(struct netns_ipvs *ipvs, struct ip_vs_dest *dest, + bool cleanup) +{ + ip_vs_stop_estimator(ipvs, &dest->stats); + + /* + * Remove it from the d-linked list with the real services. + */ + ip_vs_rs_unhash(dest); + + spin_lock_bh(&ipvs->dest_trash_lock); + IP_VS_DBG_BUF(3, "Moving dest %s:%u into trash, dest->refcnt=%d\n", + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port), + refcount_read(&dest->refcnt)); + if (list_empty(&ipvs->dest_trash) && !cleanup) + mod_timer(&ipvs->dest_trash_timer, + jiffies + (IP_VS_DEST_TRASH_PERIOD >> 1)); + /* dest lives in trash with reference */ + list_add(&dest->t_list, &ipvs->dest_trash); + dest->idle_start = 0; + spin_unlock_bh(&ipvs->dest_trash_lock); + + /* Queue up delayed work to expire all no destination connections. + * No-op when CONFIG_SYSCTL is disabled. + */ + if (!cleanup) + ip_vs_enqueue_expire_nodest_conns(ipvs); +} + + +/* + * Unlink a destination from the given service + */ +static void __ip_vs_unlink_dest(struct ip_vs_service *svc, + struct ip_vs_dest *dest, + int svcupd) +{ + dest->flags &= ~IP_VS_DEST_F_AVAILABLE; + + /* + * Remove it from the d-linked destination list. 
+ */ + list_del_rcu(&dest->n_list); + svc->num_dests--; + + if (dest->af != svc->af) + svc->ipvs->mixed_address_family_dests--; + + if (svcupd) { + struct ip_vs_scheduler *sched; + + sched = rcu_dereference_protected(svc->scheduler, 1); + if (sched && sched->del_dest) + sched->del_dest(svc, dest); + } +} + + +/* + * Delete a destination server in the given service + */ +static int +ip_vs_del_dest(struct ip_vs_service *svc, struct ip_vs_dest_user_kern *udest) +{ + struct ip_vs_dest *dest; + __be16 dport = udest->port; + + EnterFunction(2); + + /* We use function that requires RCU lock */ + rcu_read_lock(); + dest = ip_vs_lookup_dest(svc, udest->af, &udest->addr, dport); + rcu_read_unlock(); + + if (dest == NULL) { + IP_VS_DBG(1, "%s(): destination not found!\n", __func__); + return -ENOENT; + } + + /* + * Unlink dest from the service + */ + __ip_vs_unlink_dest(svc, dest, 1); + + /* + * Delete the destination + */ + __ip_vs_del_dest(svc->ipvs, dest, false); + + LeaveFunction(2); + + return 0; +} + +static void ip_vs_dest_trash_expire(struct timer_list *t) +{ + struct netns_ipvs *ipvs = from_timer(ipvs, t, dest_trash_timer); + struct ip_vs_dest *dest, *next; + unsigned long now = jiffies; + + spin_lock(&ipvs->dest_trash_lock); + list_for_each_entry_safe(dest, next, &ipvs->dest_trash, t_list) { + if (refcount_read(&dest->refcnt) > 1) + continue; + if (dest->idle_start) { + if (time_before(now, dest->idle_start + + IP_VS_DEST_TRASH_PERIOD)) + continue; + } else { + dest->idle_start = max(1UL, now); + continue; + } + IP_VS_DBG_BUF(3, "Removing destination %u/%s:%u from trash\n", + dest->vfwmark, + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port)); + list_del(&dest->t_list); + ip_vs_dest_free(dest); + } + if (!list_empty(&ipvs->dest_trash)) + mod_timer(&ipvs->dest_trash_timer, + jiffies + (IP_VS_DEST_TRASH_PERIOD >> 1)); + spin_unlock(&ipvs->dest_trash_lock); +} + +/* + * Add a service into the service hash table + */ +static int +ip_vs_add_service(struct netns_ipvs *ipvs, struct ip_vs_service_user_kern *u, + struct ip_vs_service **svc_p) +{ + int ret = 0, i; + struct ip_vs_scheduler *sched = NULL; + struct ip_vs_pe *pe = NULL; + struct ip_vs_service *svc = NULL; + int ret_hooks = -1; + + /* increase the module use count */ + if (!ip_vs_use_count_inc()) + return -ENOPROTOOPT; + + /* Lookup the scheduler by 'u->sched_name' */ + if (strcmp(u->sched_name, "none")) { + sched = ip_vs_scheduler_get(u->sched_name); + if (!sched) { + pr_info("Scheduler module ip_vs_%s not found\n", + u->sched_name); + ret = -ENOENT; + goto out_err; + } + } + + if (u->pe_name && *u->pe_name) { + pe = ip_vs_pe_getbyname(u->pe_name); + if (pe == NULL) { + pr_info("persistence engine module ip_vs_pe_%s " + "not found\n", u->pe_name); + ret = -ENOENT; + goto out_err; + } + } + +#ifdef CONFIG_IP_VS_IPV6 + if (u->af == AF_INET6) { + __u32 plen = (__force __u32) u->netmask; + + if (plen < 1 || plen > 128) { + ret = -EINVAL; + goto out_err; + } + + ret = nf_defrag_ipv6_enable(ipvs->net); + if (ret) + goto out_err; + } +#endif + + if ((u->af == AF_INET && !ipvs->num_services) || + (u->af == AF_INET6 && !ipvs->num_services6)) { + ret = ip_vs_register_hooks(ipvs, u->af); + if (ret < 0) + goto out_err; + ret_hooks = ret; + } + + svc = kzalloc(sizeof(struct ip_vs_service), GFP_KERNEL); + if (svc == NULL) { + IP_VS_DBG(1, "%s(): no memory\n", __func__); + ret = -ENOMEM; + goto out_err; + } + svc->stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); + if (!svc->stats.cpustats) { + ret = -ENOMEM; + goto out_err; + } + + 
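+	/*
+	 * Rough sketch of the per-CPU stats scheme initialised just below:
+	 * every possible CPU owns an ip_vs_cpu_stats slot whose
+	 * u64_stats_sync seqcount is set up here; writers update their local
+	 * counters, and readers use the u64_stats_fetch_begin_irq()/
+	 * u64_stats_fetch_retry_irq() pattern, as ip_vs_stats_percpu_show()
+	 * does for the namespace-wide tot_stats.
+	 */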
for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *ip_vs_stats; + ip_vs_stats = per_cpu_ptr(svc->stats.cpustats, i); + u64_stats_init(&ip_vs_stats->syncp); + } + + + /* I'm the first user of the service */ + atomic_set(&svc->refcnt, 0); + + svc->af = u->af; + svc->protocol = u->protocol; + ip_vs_addr_copy(svc->af, &svc->addr, &u->addr); + svc->port = u->port; + svc->fwmark = u->fwmark; + svc->flags = u->flags & ~IP_VS_SVC_F_HASHED; + svc->timeout = u->timeout * HZ; + svc->netmask = u->netmask; + svc->ipvs = ipvs; + + INIT_LIST_HEAD(&svc->destinations); + spin_lock_init(&svc->sched_lock); + spin_lock_init(&svc->stats.lock); + + /* Bind the scheduler */ + if (sched) { + ret = ip_vs_bind_scheduler(svc, sched); + if (ret) + goto out_err; + sched = NULL; + } + + /* Bind the ct retriever */ + RCU_INIT_POINTER(svc->pe, pe); + pe = NULL; + + /* Update the virtual service counters */ + if (svc->port == FTPPORT) + atomic_inc(&ipvs->ftpsvc_counter); + else if (svc->port == 0) + atomic_inc(&ipvs->nullsvc_counter); + if (svc->pe && svc->pe->conn_out) + atomic_inc(&ipvs->conn_out_counter); + + ip_vs_start_estimator(ipvs, &svc->stats); + + /* Count only IPv4 services for old get/setsockopt interface */ + if (svc->af == AF_INET) + ipvs->num_services++; + else if (svc->af == AF_INET6) + ipvs->num_services6++; + + /* Hash the service into the service table */ + ip_vs_svc_hash(svc); + + *svc_p = svc; + /* Now there is a service - full throttle */ + ipvs->enable = 1; + return 0; + + + out_err: + if (ret_hooks >= 0) + ip_vs_unregister_hooks(ipvs, u->af); + if (svc != NULL) { + ip_vs_unbind_scheduler(svc, sched); + ip_vs_service_free(svc); + } + ip_vs_scheduler_put(sched); + ip_vs_pe_put(pe); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + + return ret; +} + + +/* + * Edit a service and bind it with a new scheduler + */ +static int +ip_vs_edit_service(struct ip_vs_service *svc, struct ip_vs_service_user_kern *u) +{ + struct ip_vs_scheduler *sched = NULL, *old_sched; + struct ip_vs_pe *pe = NULL, *old_pe = NULL; + int ret = 0; + bool new_pe_conn_out, old_pe_conn_out; + + /* + * Lookup the scheduler, by 'u->sched_name' + */ + if (strcmp(u->sched_name, "none")) { + sched = ip_vs_scheduler_get(u->sched_name); + if (!sched) { + pr_info("Scheduler module ip_vs_%s not found\n", + u->sched_name); + return -ENOENT; + } + } + old_sched = sched; + + if (u->pe_name && *u->pe_name) { + pe = ip_vs_pe_getbyname(u->pe_name); + if (pe == NULL) { + pr_info("persistence engine module ip_vs_pe_%s " + "not found\n", u->pe_name); + ret = -ENOENT; + goto out; + } + old_pe = pe; + } + +#ifdef CONFIG_IP_VS_IPV6 + if (u->af == AF_INET6) { + __u32 plen = (__force __u32) u->netmask; + + if (plen < 1 || plen > 128) { + ret = -EINVAL; + goto out; + } + } +#endif + + old_sched = rcu_dereference_protected(svc->scheduler, 1); + if (sched != old_sched) { + if (old_sched) { + ip_vs_unbind_scheduler(svc, old_sched); + RCU_INIT_POINTER(svc->scheduler, NULL); + /* Wait all svc->sched_data users */ + synchronize_rcu(); + } + /* Bind the new scheduler */ + if (sched) { + ret = ip_vs_bind_scheduler(svc, sched); + if (ret) { + ip_vs_scheduler_put(sched); + goto out; + } + } + } + + /* + * Set the flags and timeout value + */ + svc->flags = u->flags | IP_VS_SVC_F_HASHED; + svc->timeout = u->timeout * HZ; + svc->netmask = u->netmask; + + old_pe = rcu_dereference_protected(svc->pe, 1); + if (pe != old_pe) { + rcu_assign_pointer(svc->pe, pe); + /* check for optional methods in new pe */ + new_pe_conn_out = (pe && pe->conn_out) ? 
true : false; + old_pe_conn_out = (old_pe && old_pe->conn_out) ? true : false; + if (new_pe_conn_out && !old_pe_conn_out) + atomic_inc(&svc->ipvs->conn_out_counter); + if (old_pe_conn_out && !new_pe_conn_out) + atomic_dec(&svc->ipvs->conn_out_counter); + } + +out: + ip_vs_scheduler_put(old_sched); + ip_vs_pe_put(old_pe); + return ret; +} + +/* + * Delete a service from the service list + * - The service must be unlinked, unlocked and not referenced! + * - We are called under _bh lock + */ +static void __ip_vs_del_service(struct ip_vs_service *svc, bool cleanup) +{ + struct ip_vs_dest *dest, *nxt; + struct ip_vs_scheduler *old_sched; + struct ip_vs_pe *old_pe; + struct netns_ipvs *ipvs = svc->ipvs; + + if (svc->af == AF_INET) { + ipvs->num_services--; + if (!ipvs->num_services) + ip_vs_unregister_hooks(ipvs, svc->af); + } else if (svc->af == AF_INET6) { + ipvs->num_services6--; + if (!ipvs->num_services6) + ip_vs_unregister_hooks(ipvs, svc->af); + } + + ip_vs_stop_estimator(svc->ipvs, &svc->stats); + + /* Unbind scheduler */ + old_sched = rcu_dereference_protected(svc->scheduler, 1); + ip_vs_unbind_scheduler(svc, old_sched); + ip_vs_scheduler_put(old_sched); + + /* Unbind persistence engine, keep svc->pe */ + old_pe = rcu_dereference_protected(svc->pe, 1); + if (old_pe && old_pe->conn_out) + atomic_dec(&ipvs->conn_out_counter); + ip_vs_pe_put(old_pe); + + /* + * Unlink the whole destination list + */ + list_for_each_entry_safe(dest, nxt, &svc->destinations, n_list) { + __ip_vs_unlink_dest(svc, dest, 0); + __ip_vs_del_dest(svc->ipvs, dest, cleanup); + } + + /* + * Update the virtual service counters + */ + if (svc->port == FTPPORT) + atomic_dec(&ipvs->ftpsvc_counter); + else if (svc->port == 0) + atomic_dec(&ipvs->nullsvc_counter); + + /* + * Free the service if nobody refers to it + */ + __ip_vs_svc_put(svc, true); + + /* decrease the module use count */ + ip_vs_use_count_dec(); +} + +/* + * Unlink a service from list and try to delete it if its refcnt reached 0 + */ +static void ip_vs_unlink_service(struct ip_vs_service *svc, bool cleanup) +{ + ip_vs_unregister_conntrack(svc); + /* Hold svc to avoid double release from dest_trash */ + atomic_inc(&svc->refcnt); + /* + * Unhash it from the service table + */ + ip_vs_svc_unhash(svc); + + __ip_vs_del_service(svc, cleanup); +} + +/* + * Delete a service from the service list + */ +static int ip_vs_del_service(struct ip_vs_service *svc) +{ + if (svc == NULL) + return -EEXIST; + ip_vs_unlink_service(svc, false); + + return 0; +} + + +/* + * Flush all the virtual services + */ +static int ip_vs_flush(struct netns_ipvs *ipvs, bool cleanup) +{ + int idx; + struct ip_vs_service *svc; + struct hlist_node *n; + + /* + * Flush the service table hashed by + */ + for(idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry_safe(svc, n, &ip_vs_svc_table[idx], + s_list) { + if (svc->ipvs == ipvs) + ip_vs_unlink_service(svc, cleanup); + } + } + + /* + * Flush the service table hashed by fwmark + */ + for(idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry_safe(svc, n, &ip_vs_svc_fwm_table[idx], + f_list) { + if (svc->ipvs == ipvs) + ip_vs_unlink_service(svc, cleanup); + } + } + + return 0; +} + +/* + * Delete service by {netns} in the service table. 
+ * Called by __ip_vs_batch_cleanup() + */ +void ip_vs_service_nets_cleanup(struct list_head *net_list) +{ + struct netns_ipvs *ipvs; + struct net *net; + + EnterFunction(2); + /* Check for "full" addressed entries */ + mutex_lock(&__ip_vs_mutex); + list_for_each_entry(net, net_list, exit_list) { + ipvs = net_ipvs(net); + ip_vs_flush(ipvs, true); + } + mutex_unlock(&__ip_vs_mutex); + LeaveFunction(2); +} + +/* Put all references for device (dst_cache) */ +static inline void +ip_vs_forget_dev(struct ip_vs_dest *dest, struct net_device *dev) +{ + struct ip_vs_dest_dst *dest_dst; + + spin_lock_bh(&dest->dst_lock); + dest_dst = rcu_dereference_protected(dest->dest_dst, 1); + if (dest_dst && dest_dst->dst_cache->dev == dev) { + IP_VS_DBG_BUF(3, "Reset dev:%s dest %s:%u ,dest->refcnt=%d\n", + dev->name, + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port), + refcount_read(&dest->refcnt)); + __ip_vs_dst_cache_reset(dest); + } + spin_unlock_bh(&dest->dst_lock); + +} +/* Netdev event receiver + * Currently only NETDEV_DOWN is handled to release refs to cached dsts + */ +static int ip_vs_dst_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct netns_ipvs *ipvs = net_ipvs(net); + struct ip_vs_service *svc; + struct ip_vs_dest *dest; + unsigned int idx; + + if (event != NETDEV_DOWN || !ipvs) + return NOTIFY_DONE; + IP_VS_DBG(3, "%s() dev=%s\n", __func__, dev->name); + EnterFunction(2); + mutex_lock(&__ip_vs_mutex); + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry(svc, &ip_vs_svc_table[idx], s_list) { + if (svc->ipvs == ipvs) { + list_for_each_entry(dest, &svc->destinations, + n_list) { + ip_vs_forget_dev(dest, dev); + } + } + } + + hlist_for_each_entry(svc, &ip_vs_svc_fwm_table[idx], f_list) { + if (svc->ipvs == ipvs) { + list_for_each_entry(dest, &svc->destinations, + n_list) { + ip_vs_forget_dev(dest, dev); + } + } + + } + } + + spin_lock_bh(&ipvs->dest_trash_lock); + list_for_each_entry(dest, &ipvs->dest_trash, t_list) { + ip_vs_forget_dev(dest, dev); + } + spin_unlock_bh(&ipvs->dest_trash_lock); + mutex_unlock(&__ip_vs_mutex); + LeaveFunction(2); + return NOTIFY_DONE; +} + +/* + * Zero counters in a service or all services + */ +static int ip_vs_zero_service(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest; + + list_for_each_entry(dest, &svc->destinations, n_list) { + ip_vs_zero_stats(&dest->stats); + } + ip_vs_zero_stats(&svc->stats); + return 0; +} + +static int ip_vs_zero_all(struct netns_ipvs *ipvs) +{ + int idx; + struct ip_vs_service *svc; + + for(idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry(svc, &ip_vs_svc_table[idx], s_list) { + if (svc->ipvs == ipvs) + ip_vs_zero_service(svc); + } + } + + for(idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry(svc, &ip_vs_svc_fwm_table[idx], f_list) { + if (svc->ipvs == ipvs) + ip_vs_zero_service(svc); + } + } + + ip_vs_zero_stats(&ipvs->tot_stats); + return 0; +} + +#ifdef CONFIG_SYSCTL + +static int +proc_do_defense_mode(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct netns_ipvs *ipvs = table->extra2; + int *valp = table->data; + int val = *valp; + int rc; + + struct ctl_table tmp = { + .data = &val, + .maxlen = sizeof(int), + .mode = table->mode, + }; + + rc = proc_dointvec(&tmp, write, buffer, lenp, ppos); + if (write && (*valp != val)) { + if (val < 0 || val > 3) { + rc = -EINVAL; + } else { + *valp = val; + 
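+			/*
+			 * Sketch of the accepted range for the defense
+			 * sysctls served by this handler (drop_entry,
+			 * drop_packet, secure_tcp): 0 turns the defense off,
+			 * 3 forces it on unconditionally, and 1/2 tie it to
+			 * the low-memory condition derived from amemthresh;
+			 * update_defense_level() below applies the new mode
+			 * immediately.
+			 */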
update_defense_level(ipvs); + } + } + return rc; +} + +static int +proc_do_sync_threshold(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct netns_ipvs *ipvs = table->extra2; + int *valp = table->data; + int val[2]; + int rc; + struct ctl_table tmp = { + .data = &val, + .maxlen = table->maxlen, + .mode = table->mode, + }; + + mutex_lock(&ipvs->sync_mutex); + memcpy(val, valp, sizeof(val)); + rc = proc_dointvec(&tmp, write, buffer, lenp, ppos); + if (write) { + if (val[0] < 0 || val[1] < 0 || + (val[0] >= val[1] && val[1])) + rc = -EINVAL; + else + memcpy(valp, val, sizeof(val)); + } + mutex_unlock(&ipvs->sync_mutex); + return rc; +} + +static int +proc_do_sync_ports(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int *valp = table->data; + int val = *valp; + int rc; + + struct ctl_table tmp = { + .data = &val, + .maxlen = sizeof(int), + .mode = table->mode, + }; + + rc = proc_dointvec(&tmp, write, buffer, lenp, ppos); + if (write && (*valp != val)) { + if (val < 1 || !is_power_of_2(val)) + rc = -EINVAL; + else + *valp = val; + } + return rc; +} + +/* + * IPVS sysctl table (under the /proc/sys/net/ipv4/vs/) + * Do not change order or insert new entries without + * align with netns init in ip_vs_control_net_init() + */ + +static struct ctl_table vs_vars[] = { + { + .procname = "amemthresh", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "am_droprate", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "drop_entry", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_defense_mode, + }, + { + .procname = "drop_packet", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_defense_mode, + }, +#ifdef CONFIG_IP_VS_NFCT + { + .procname = "conntrack", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dointvec, + }, +#endif + { + .procname = "secure_tcp", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_defense_mode, + }, + { + .procname = "snat_reroute", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dointvec, + }, + { + .procname = "sync_version", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + { + .procname = "sync_ports", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_do_sync_ports, + }, + { + .procname = "sync_persist_mode", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "sync_qlen_max", + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "sync_sock_size", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "cache_bypass", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "expire_nodest_conn", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "sloppy_tcp", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "sloppy_sctp", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "expire_quiescent_template", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "sync_threshold", + .maxlen = + sizeof(((struct netns_ipvs *)0)->sysctl_sync_threshold), + .mode = 0644, + 
.proc_handler = proc_do_sync_threshold, + }, + { + .procname = "sync_refresh_period", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "sync_retries", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_THREE, + }, + { + .procname = "nat_icmp_send", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "pmtu_disc", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "backup_only", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "conn_reuse_mode", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "schedule_icmp", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ignore_tunneled", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "run_estimation", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#ifdef CONFIG_IP_VS_DEBUG + { + .procname = "debug_level", + .data = &sysctl_ip_vs_debug_level, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, +#endif + { } +}; + +#endif + +#ifdef CONFIG_PROC_FS + +struct ip_vs_iter { + struct seq_net_private p; /* Do not move this, netns depends upon it*/ + struct hlist_head *table; + int bucket; +}; + +/* + * Write the contents of the VS rule table to a PROCfs file. + * (It is kept just for backward compatibility) + */ +static inline const char *ip_vs_fwd_name(unsigned int flags) +{ + switch (flags & IP_VS_CONN_F_FWD_MASK) { + case IP_VS_CONN_F_LOCALNODE: + return "Local"; + case IP_VS_CONN_F_TUNNEL: + return "Tunnel"; + case IP_VS_CONN_F_DROUTE: + return "Route"; + default: + return "Masq"; + } +} + + +/* Get the Nth entry in the two lists */ +static struct ip_vs_service *ip_vs_info_array(struct seq_file *seq, loff_t pos) +{ + struct net *net = seq_file_net(seq); + struct netns_ipvs *ipvs = net_ipvs(net); + struct ip_vs_iter *iter = seq->private; + int idx; + struct ip_vs_service *svc; + + /* look in hash by protocol */ + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry_rcu(svc, &ip_vs_svc_table[idx], s_list) { + if ((svc->ipvs == ipvs) && pos-- == 0) { + iter->table = ip_vs_svc_table; + iter->bucket = idx; + return svc; + } + } + } + + /* keep looking in fwmark */ + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry_rcu(svc, &ip_vs_svc_fwm_table[idx], + f_list) { + if ((svc->ipvs == ipvs) && pos-- == 0) { + iter->table = ip_vs_svc_fwm_table; + iter->bucket = idx; + return svc; + } + } + } + + return NULL; +} + +static void *ip_vs_info_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return *pos ? 
ip_vs_info_array(seq, *pos - 1) : SEQ_START_TOKEN; +} + + +static void *ip_vs_info_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct hlist_node *e; + struct ip_vs_iter *iter; + struct ip_vs_service *svc; + + ++*pos; + if (v == SEQ_START_TOKEN) + return ip_vs_info_array(seq,0); + + svc = v; + iter = seq->private; + + if (iter->table == ip_vs_svc_table) { + /* next service in table hashed by protocol */ + e = rcu_dereference(hlist_next_rcu(&svc->s_list)); + if (e) + return hlist_entry(e, struct ip_vs_service, s_list); + + while (++iter->bucket < IP_VS_SVC_TAB_SIZE) { + hlist_for_each_entry_rcu(svc, + &ip_vs_svc_table[iter->bucket], + s_list) { + return svc; + } + } + + iter->table = ip_vs_svc_fwm_table; + iter->bucket = -1; + goto scan_fwmark; + } + + /* next service in hashed by fwmark */ + e = rcu_dereference(hlist_next_rcu(&svc->f_list)); + if (e) + return hlist_entry(e, struct ip_vs_service, f_list); + + scan_fwmark: + while (++iter->bucket < IP_VS_SVC_TAB_SIZE) { + hlist_for_each_entry_rcu(svc, + &ip_vs_svc_fwm_table[iter->bucket], + f_list) + return svc; + } + + return NULL; +} + +static void ip_vs_info_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + + +static int ip_vs_info_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_printf(seq, + "IP Virtual Server version %d.%d.%d (size=%d)\n", + NVERSION(IP_VS_VERSION_CODE), ip_vs_conn_tab_size); + seq_puts(seq, + "Prot LocalAddress:Port Scheduler Flags\n"); + seq_puts(seq, + " -> RemoteAddress:Port Forward Weight ActiveConn InActConn\n"); + } else { + struct net *net = seq_file_net(seq); + struct netns_ipvs *ipvs = net_ipvs(net); + const struct ip_vs_service *svc = v; + const struct ip_vs_iter *iter = seq->private; + const struct ip_vs_dest *dest; + struct ip_vs_scheduler *sched = rcu_dereference(svc->scheduler); + char *sched_name = sched ? 
sched->name : "none"; + + if (svc->ipvs != ipvs) + return 0; + if (iter->table == ip_vs_svc_table) { +#ifdef CONFIG_IP_VS_IPV6 + if (svc->af == AF_INET6) + seq_printf(seq, "%s [%pI6]:%04X %s ", + ip_vs_proto_name(svc->protocol), + &svc->addr.in6, + ntohs(svc->port), + sched_name); + else +#endif + seq_printf(seq, "%s %08X:%04X %s %s ", + ip_vs_proto_name(svc->protocol), + ntohl(svc->addr.ip), + ntohs(svc->port), + sched_name, + (svc->flags & IP_VS_SVC_F_ONEPACKET)?"ops ":""); + } else { + seq_printf(seq, "FWM %08X %s %s", + svc->fwmark, sched_name, + (svc->flags & IP_VS_SVC_F_ONEPACKET)?"ops ":""); + } + + if (svc->flags & IP_VS_SVC_F_PERSISTENT) + seq_printf(seq, "persistent %d %08X\n", + svc->timeout, + ntohl(svc->netmask)); + else + seq_putc(seq, '\n'); + + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { +#ifdef CONFIG_IP_VS_IPV6 + if (dest->af == AF_INET6) + seq_printf(seq, + " -> [%pI6]:%04X" + " %-7s %-6d %-10d %-10d\n", + &dest->addr.in6, + ntohs(dest->port), + ip_vs_fwd_name(atomic_read(&dest->conn_flags)), + atomic_read(&dest->weight), + atomic_read(&dest->activeconns), + atomic_read(&dest->inactconns)); + else +#endif + seq_printf(seq, + " -> %08X:%04X " + "%-7s %-6d %-10d %-10d\n", + ntohl(dest->addr.ip), + ntohs(dest->port), + ip_vs_fwd_name(atomic_read(&dest->conn_flags)), + atomic_read(&dest->weight), + atomic_read(&dest->activeconns), + atomic_read(&dest->inactconns)); + + } + } + return 0; +} + +static const struct seq_operations ip_vs_info_seq_ops = { + .start = ip_vs_info_seq_start, + .next = ip_vs_info_seq_next, + .stop = ip_vs_info_seq_stop, + .show = ip_vs_info_seq_show, +}; + +static int ip_vs_stats_show(struct seq_file *seq, void *v) +{ + struct net *net = seq_file_single_net(seq); + struct ip_vs_kstats show; + +/* 01234567 01234567 01234567 0123456701234567 0123456701234567 */ + seq_puts(seq, + " Total Incoming Outgoing Incoming Outgoing\n"); + seq_puts(seq, + " Conns Packets Packets Bytes Bytes\n"); + + ip_vs_copy_stats(&show, &net_ipvs(net)->tot_stats); + seq_printf(seq, "%8LX %8LX %8LX %16LX %16LX\n\n", + (unsigned long long)show.conns, + (unsigned long long)show.inpkts, + (unsigned long long)show.outpkts, + (unsigned long long)show.inbytes, + (unsigned long long)show.outbytes); + +/* 01234567 01234567 01234567 0123456701234567 0123456701234567*/ + seq_puts(seq, + " Conns/s Pkts/s Pkts/s Bytes/s Bytes/s\n"); + seq_printf(seq, "%8LX %8LX %8LX %16LX %16LX\n", + (unsigned long long)show.cps, + (unsigned long long)show.inpps, + (unsigned long long)show.outpps, + (unsigned long long)show.inbps, + (unsigned long long)show.outbps); + + return 0; +} + +static int ip_vs_stats_percpu_show(struct seq_file *seq, void *v) +{ + struct net *net = seq_file_single_net(seq); + struct ip_vs_stats *tot_stats = &net_ipvs(net)->tot_stats; + struct ip_vs_cpu_stats __percpu *cpustats = tot_stats->cpustats; + struct ip_vs_kstats kstats; + int i; + +/* 01234567 01234567 01234567 0123456701234567 0123456701234567 */ + seq_puts(seq, + " Total Incoming Outgoing Incoming Outgoing\n"); + seq_puts(seq, + "CPU Conns Packets Packets Bytes Bytes\n"); + + for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *u = per_cpu_ptr(cpustats, i); + unsigned int start; + u64 conns, inpkts, outpkts, inbytes, outbytes; + + do { + start = u64_stats_fetch_begin_irq(&u->syncp); + conns = u64_stats_read(&u->cnt.conns); + inpkts = u64_stats_read(&u->cnt.inpkts); + outpkts = u64_stats_read(&u->cnt.outpkts); + inbytes = u64_stats_read(&u->cnt.inbytes); + outbytes = u64_stats_read(&u->cnt.outbytes); + } 
while (u64_stats_fetch_retry_irq(&u->syncp, start)); + + seq_printf(seq, "%3X %8LX %8LX %8LX %16LX %16LX\n", + i, (u64)conns, (u64)inpkts, + (u64)outpkts, (u64)inbytes, + (u64)outbytes); + } + + ip_vs_copy_stats(&kstats, tot_stats); + + seq_printf(seq, " ~ %8LX %8LX %8LX %16LX %16LX\n\n", + (unsigned long long)kstats.conns, + (unsigned long long)kstats.inpkts, + (unsigned long long)kstats.outpkts, + (unsigned long long)kstats.inbytes, + (unsigned long long)kstats.outbytes); + +/* ... 01234567 01234567 01234567 0123456701234567 0123456701234567 */ + seq_puts(seq, + " Conns/s Pkts/s Pkts/s Bytes/s Bytes/s\n"); + seq_printf(seq, " %8LX %8LX %8LX %16LX %16LX\n", + kstats.cps, + kstats.inpps, + kstats.outpps, + kstats.inbps, + kstats.outbps); + + return 0; +} +#endif + +/* + * Set timeout values for tcp tcpfin udp in the timeout_table. + */ +static int ip_vs_set_timeout(struct netns_ipvs *ipvs, struct ip_vs_timeout_user *u) +{ +#if defined(CONFIG_IP_VS_PROTO_TCP) || defined(CONFIG_IP_VS_PROTO_UDP) + struct ip_vs_proto_data *pd; +#endif + + IP_VS_DBG(2, "Setting timeout tcp:%d tcpfin:%d udp:%d\n", + u->tcp_timeout, + u->tcp_fin_timeout, + u->udp_timeout); + +#ifdef CONFIG_IP_VS_PROTO_TCP + if (u->tcp_timeout < 0 || u->tcp_timeout > (INT_MAX / HZ) || + u->tcp_fin_timeout < 0 || u->tcp_fin_timeout > (INT_MAX / HZ)) { + return -EINVAL; + } +#endif + +#ifdef CONFIG_IP_VS_PROTO_UDP + if (u->udp_timeout < 0 || u->udp_timeout > (INT_MAX / HZ)) + return -EINVAL; +#endif + +#ifdef CONFIG_IP_VS_PROTO_TCP + if (u->tcp_timeout) { + pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); + pd->timeout_table[IP_VS_TCP_S_ESTABLISHED] + = u->tcp_timeout * HZ; + } + + if (u->tcp_fin_timeout) { + pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); + pd->timeout_table[IP_VS_TCP_S_FIN_WAIT] + = u->tcp_fin_timeout * HZ; + } +#endif + +#ifdef CONFIG_IP_VS_PROTO_UDP + if (u->udp_timeout) { + pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP); + pd->timeout_table[IP_VS_UDP_S_NORMAL] + = u->udp_timeout * HZ; + } +#endif + return 0; +} + +#define CMDID(cmd) (cmd - IP_VS_BASE_CTL) + +struct ip_vs_svcdest_user { + struct ip_vs_service_user s; + struct ip_vs_dest_user d; +}; + +static const unsigned char set_arglen[CMDID(IP_VS_SO_SET_MAX) + 1] = { + [CMDID(IP_VS_SO_SET_ADD)] = sizeof(struct ip_vs_service_user), + [CMDID(IP_VS_SO_SET_EDIT)] = sizeof(struct ip_vs_service_user), + [CMDID(IP_VS_SO_SET_DEL)] = sizeof(struct ip_vs_service_user), + [CMDID(IP_VS_SO_SET_ADDDEST)] = sizeof(struct ip_vs_svcdest_user), + [CMDID(IP_VS_SO_SET_DELDEST)] = sizeof(struct ip_vs_svcdest_user), + [CMDID(IP_VS_SO_SET_EDITDEST)] = sizeof(struct ip_vs_svcdest_user), + [CMDID(IP_VS_SO_SET_TIMEOUT)] = sizeof(struct ip_vs_timeout_user), + [CMDID(IP_VS_SO_SET_STARTDAEMON)] = sizeof(struct ip_vs_daemon_user), + [CMDID(IP_VS_SO_SET_STOPDAEMON)] = sizeof(struct ip_vs_daemon_user), + [CMDID(IP_VS_SO_SET_ZERO)] = sizeof(struct ip_vs_service_user), +}; + +union ip_vs_set_arglen { + struct ip_vs_service_user field_IP_VS_SO_SET_ADD; + struct ip_vs_service_user field_IP_VS_SO_SET_EDIT; + struct ip_vs_service_user field_IP_VS_SO_SET_DEL; + struct ip_vs_svcdest_user field_IP_VS_SO_SET_ADDDEST; + struct ip_vs_svcdest_user field_IP_VS_SO_SET_DELDEST; + struct ip_vs_svcdest_user field_IP_VS_SO_SET_EDITDEST; + struct ip_vs_timeout_user field_IP_VS_SO_SET_TIMEOUT; + struct ip_vs_daemon_user field_IP_VS_SO_SET_STARTDAEMON; + struct ip_vs_daemon_user field_IP_VS_SO_SET_STOPDAEMON; + struct ip_vs_service_user field_IP_VS_SO_SET_ZERO; +}; + +#define MAX_SET_ARGLEN sizeof(union 
ip_vs_set_arglen) + +static void ip_vs_copy_usvc_compat(struct ip_vs_service_user_kern *usvc, + struct ip_vs_service_user *usvc_compat) +{ + memset(usvc, 0, sizeof(*usvc)); + + usvc->af = AF_INET; + usvc->protocol = usvc_compat->protocol; + usvc->addr.ip = usvc_compat->addr; + usvc->port = usvc_compat->port; + usvc->fwmark = usvc_compat->fwmark; + + /* Deep copy of sched_name is not needed here */ + usvc->sched_name = usvc_compat->sched_name; + + usvc->flags = usvc_compat->flags; + usvc->timeout = usvc_compat->timeout; + usvc->netmask = usvc_compat->netmask; +} + +static void ip_vs_copy_udest_compat(struct ip_vs_dest_user_kern *udest, + struct ip_vs_dest_user *udest_compat) +{ + memset(udest, 0, sizeof(*udest)); + + udest->addr.ip = udest_compat->addr; + udest->port = udest_compat->port; + udest->conn_flags = udest_compat->conn_flags; + udest->weight = udest_compat->weight; + udest->u_threshold = udest_compat->u_threshold; + udest->l_threshold = udest_compat->l_threshold; + udest->af = AF_INET; + udest->tun_type = IP_VS_CONN_F_TUNNEL_TYPE_IPIP; +} + +static int +do_ip_vs_set_ctl(struct sock *sk, int cmd, sockptr_t ptr, unsigned int len) +{ + struct net *net = sock_net(sk); + int ret; + unsigned char arg[MAX_SET_ARGLEN]; + struct ip_vs_service_user *usvc_compat; + struct ip_vs_service_user_kern usvc; + struct ip_vs_service *svc; + struct ip_vs_dest_user *udest_compat; + struct ip_vs_dest_user_kern udest; + struct netns_ipvs *ipvs = net_ipvs(net); + + BUILD_BUG_ON(sizeof(arg) > 255); + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_SET_MAX) + return -EINVAL; + if (len != set_arglen[CMDID(cmd)]) { + IP_VS_DBG(1, "set_ctl: len %u != %u\n", + len, set_arglen[CMDID(cmd)]); + return -EINVAL; + } + + if (copy_from_sockptr(arg, ptr, len) != 0) + return -EFAULT; + + /* Handle daemons since they have another lock */ + if (cmd == IP_VS_SO_SET_STARTDAEMON || + cmd == IP_VS_SO_SET_STOPDAEMON) { + struct ip_vs_daemon_user *dm = (struct ip_vs_daemon_user *)arg; + + if (cmd == IP_VS_SO_SET_STARTDAEMON) { + struct ipvs_sync_daemon_cfg cfg; + + memset(&cfg, 0, sizeof(cfg)); + ret = -EINVAL; + if (strscpy(cfg.mcast_ifn, dm->mcast_ifn, + sizeof(cfg.mcast_ifn)) <= 0) + return ret; + cfg.syncid = dm->syncid; + ret = start_sync_thread(ipvs, &cfg, dm->state); + } else { + ret = stop_sync_thread(ipvs, dm->state); + } + return ret; + } + + mutex_lock(&__ip_vs_mutex); + if (cmd == IP_VS_SO_SET_FLUSH) { + /* Flush the virtual service */ + ret = ip_vs_flush(ipvs, false); + goto out_unlock; + } else if (cmd == IP_VS_SO_SET_TIMEOUT) { + /* Set timeout values for (tcp tcpfin udp) */ + ret = ip_vs_set_timeout(ipvs, (struct ip_vs_timeout_user *)arg); + goto out_unlock; + } else if (!len) { + /* No more commands with len == 0 below */ + ret = -EINVAL; + goto out_unlock; + } + + usvc_compat = (struct ip_vs_service_user *)arg; + udest_compat = (struct ip_vs_dest_user *)(usvc_compat + 1); + + /* We only use the new structs internally, so copy userspace compat + * structs to extended internal versions */ + ip_vs_copy_usvc_compat(&usvc, usvc_compat); + ip_vs_copy_udest_compat(&udest, udest_compat); + + if (cmd == IP_VS_SO_SET_ZERO) { + /* if no service address is set, zero counters in all */ + if (!usvc.fwmark && !usvc.addr.ip && !usvc.port) { + ret = ip_vs_zero_all(ipvs); + goto out_unlock; + } + } + + if ((cmd == IP_VS_SO_SET_ADD || cmd == IP_VS_SO_SET_EDIT) && + strnlen(usvc.sched_name, IP_VS_SCHEDNAME_MAXLEN) == + IP_VS_SCHEDNAME_MAXLEN) { + ret = 
-EINVAL; + goto out_unlock; + } + + /* Check for valid protocol: TCP or UDP or SCTP, even for fwmark!=0 */ + if (usvc.protocol != IPPROTO_TCP && usvc.protocol != IPPROTO_UDP && + usvc.protocol != IPPROTO_SCTP) { + pr_err("set_ctl: invalid protocol: %d %pI4:%d\n", + usvc.protocol, &usvc.addr.ip, + ntohs(usvc.port)); + ret = -EFAULT; + goto out_unlock; + } + + /* Lookup the exact service by or fwmark */ + rcu_read_lock(); + if (usvc.fwmark == 0) + svc = __ip_vs_service_find(ipvs, usvc.af, usvc.protocol, + &usvc.addr, usvc.port); + else + svc = __ip_vs_svc_fwm_find(ipvs, usvc.af, usvc.fwmark); + rcu_read_unlock(); + + if (cmd != IP_VS_SO_SET_ADD + && (svc == NULL || svc->protocol != usvc.protocol)) { + ret = -ESRCH; + goto out_unlock; + } + + switch (cmd) { + case IP_VS_SO_SET_ADD: + if (svc != NULL) + ret = -EEXIST; + else + ret = ip_vs_add_service(ipvs, &usvc, &svc); + break; + case IP_VS_SO_SET_EDIT: + ret = ip_vs_edit_service(svc, &usvc); + break; + case IP_VS_SO_SET_DEL: + ret = ip_vs_del_service(svc); + if (!ret) + goto out_unlock; + break; + case IP_VS_SO_SET_ZERO: + ret = ip_vs_zero_service(svc); + break; + case IP_VS_SO_SET_ADDDEST: + ret = ip_vs_add_dest(svc, &udest); + break; + case IP_VS_SO_SET_EDITDEST: + ret = ip_vs_edit_dest(svc, &udest); + break; + case IP_VS_SO_SET_DELDEST: + ret = ip_vs_del_dest(svc, &udest); + } + + out_unlock: + mutex_unlock(&__ip_vs_mutex); + return ret; +} + + +static void +ip_vs_copy_service(struct ip_vs_service_entry *dst, struct ip_vs_service *src) +{ + struct ip_vs_scheduler *sched; + struct ip_vs_kstats kstats; + char *sched_name; + + sched = rcu_dereference_protected(src->scheduler, 1); + sched_name = sched ? sched->name : "none"; + dst->protocol = src->protocol; + dst->addr = src->addr.ip; + dst->port = src->port; + dst->fwmark = src->fwmark; + strscpy(dst->sched_name, sched_name, sizeof(dst->sched_name)); + dst->flags = src->flags; + dst->timeout = src->timeout / HZ; + dst->netmask = src->netmask; + dst->num_dests = src->num_dests; + ip_vs_copy_stats(&kstats, &src->stats); + ip_vs_export_stats_user(&dst->stats, &kstats); +} + +static inline int +__ip_vs_get_service_entries(struct netns_ipvs *ipvs, + const struct ip_vs_get_services *get, + struct ip_vs_get_services __user *uptr) +{ + int idx, count=0; + struct ip_vs_service *svc; + struct ip_vs_service_entry entry; + int ret = 0; + + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry(svc, &ip_vs_svc_table[idx], s_list) { + /* Only expose IPv4 entries to old interface */ + if (svc->af != AF_INET || (svc->ipvs != ipvs)) + continue; + + if (count >= get->num_services) + goto out; + memset(&entry, 0, sizeof(entry)); + ip_vs_copy_service(&entry, svc); + if (copy_to_user(&uptr->entrytable[count], + &entry, sizeof(entry))) { + ret = -EFAULT; + goto out; + } + count++; + } + } + + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + hlist_for_each_entry(svc, &ip_vs_svc_fwm_table[idx], f_list) { + /* Only expose IPv4 entries to old interface */ + if (svc->af != AF_INET || (svc->ipvs != ipvs)) + continue; + + if (count >= get->num_services) + goto out; + memset(&entry, 0, sizeof(entry)); + ip_vs_copy_service(&entry, svc); + if (copy_to_user(&uptr->entrytable[count], + &entry, sizeof(entry))) { + ret = -EFAULT; + goto out; + } + count++; + } + } +out: + return ret; +} + +static inline int +__ip_vs_get_dest_entries(struct netns_ipvs *ipvs, const struct ip_vs_get_dests *get, + struct ip_vs_get_dests __user *uptr) +{ + struct ip_vs_service *svc; + union nf_inet_addr addr = { .ip = get->addr 
}; + int ret = 0; + + rcu_read_lock(); + if (get->fwmark) + svc = __ip_vs_svc_fwm_find(ipvs, AF_INET, get->fwmark); + else + svc = __ip_vs_service_find(ipvs, AF_INET, get->protocol, &addr, + get->port); + rcu_read_unlock(); + + if (svc) { + int count = 0; + struct ip_vs_dest *dest; + struct ip_vs_dest_entry entry; + struct ip_vs_kstats kstats; + + memset(&entry, 0, sizeof(entry)); + list_for_each_entry(dest, &svc->destinations, n_list) { + if (count >= get->num_dests) + break; + + /* Cannot expose heterogeneous members via sockopt + * interface + */ + if (dest->af != svc->af) + continue; + + entry.addr = dest->addr.ip; + entry.port = dest->port; + entry.conn_flags = atomic_read(&dest->conn_flags); + entry.weight = atomic_read(&dest->weight); + entry.u_threshold = dest->u_threshold; + entry.l_threshold = dest->l_threshold; + entry.activeconns = atomic_read(&dest->activeconns); + entry.inactconns = atomic_read(&dest->inactconns); + entry.persistconns = atomic_read(&dest->persistconns); + ip_vs_copy_stats(&kstats, &dest->stats); + ip_vs_export_stats_user(&entry.stats, &kstats); + if (copy_to_user(&uptr->entrytable[count], + &entry, sizeof(entry))) { + ret = -EFAULT; + break; + } + count++; + } + } else + ret = -ESRCH; + return ret; +} + +static inline void +__ip_vs_get_timeouts(struct netns_ipvs *ipvs, struct ip_vs_timeout_user *u) +{ +#if defined(CONFIG_IP_VS_PROTO_TCP) || defined(CONFIG_IP_VS_PROTO_UDP) + struct ip_vs_proto_data *pd; +#endif + + memset(u, 0, sizeof (*u)); + +#ifdef CONFIG_IP_VS_PROTO_TCP + pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); + u->tcp_timeout = pd->timeout_table[IP_VS_TCP_S_ESTABLISHED] / HZ; + u->tcp_fin_timeout = pd->timeout_table[IP_VS_TCP_S_FIN_WAIT] / HZ; +#endif +#ifdef CONFIG_IP_VS_PROTO_UDP + pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP); + u->udp_timeout = + pd->timeout_table[IP_VS_UDP_S_NORMAL] / HZ; +#endif +} + +static const unsigned char get_arglen[CMDID(IP_VS_SO_GET_MAX) + 1] = { + [CMDID(IP_VS_SO_GET_VERSION)] = 64, + [CMDID(IP_VS_SO_GET_INFO)] = sizeof(struct ip_vs_getinfo), + [CMDID(IP_VS_SO_GET_SERVICES)] = sizeof(struct ip_vs_get_services), + [CMDID(IP_VS_SO_GET_SERVICE)] = sizeof(struct ip_vs_service_entry), + [CMDID(IP_VS_SO_GET_DESTS)] = sizeof(struct ip_vs_get_dests), + [CMDID(IP_VS_SO_GET_TIMEOUT)] = sizeof(struct ip_vs_timeout_user), + [CMDID(IP_VS_SO_GET_DAEMON)] = 2 * sizeof(struct ip_vs_daemon_user), +}; + +union ip_vs_get_arglen { + char field_IP_VS_SO_GET_VERSION[64]; + struct ip_vs_getinfo field_IP_VS_SO_GET_INFO; + struct ip_vs_get_services field_IP_VS_SO_GET_SERVICES; + struct ip_vs_service_entry field_IP_VS_SO_GET_SERVICE; + struct ip_vs_get_dests field_IP_VS_SO_GET_DESTS; + struct ip_vs_timeout_user field_IP_VS_SO_GET_TIMEOUT; + struct ip_vs_daemon_user field_IP_VS_SO_GET_DAEMON[2]; +}; + +#define MAX_GET_ARGLEN sizeof(union ip_vs_get_arglen) + +static int +do_ip_vs_get_ctl(struct sock *sk, int cmd, void __user *user, int *len) +{ + unsigned char arg[MAX_GET_ARGLEN]; + int ret = 0; + unsigned int copylen; + struct net *net = sock_net(sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + BUG_ON(!net); + BUILD_BUG_ON(sizeof(arg) > 255); + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (cmd < IP_VS_BASE_CTL || cmd > IP_VS_SO_GET_MAX) + return -EINVAL; + + copylen = get_arglen[CMDID(cmd)]; + if (*len < (int) copylen) { + IP_VS_DBG(1, "get_ctl: len %d < %u\n", *len, copylen); + return -EINVAL; + } + + if (copy_from_user(arg, user, copylen) != 0) + return -EFAULT; + /* + * Handle daemons first since it has 
its own locking + */ + if (cmd == IP_VS_SO_GET_DAEMON) { + struct ip_vs_daemon_user d[2]; + + memset(&d, 0, sizeof(d)); + mutex_lock(&ipvs->sync_mutex); + if (ipvs->sync_state & IP_VS_STATE_MASTER) { + d[0].state = IP_VS_STATE_MASTER; + strscpy(d[0].mcast_ifn, ipvs->mcfg.mcast_ifn, + sizeof(d[0].mcast_ifn)); + d[0].syncid = ipvs->mcfg.syncid; + } + if (ipvs->sync_state & IP_VS_STATE_BACKUP) { + d[1].state = IP_VS_STATE_BACKUP; + strscpy(d[1].mcast_ifn, ipvs->bcfg.mcast_ifn, + sizeof(d[1].mcast_ifn)); + d[1].syncid = ipvs->bcfg.syncid; + } + if (copy_to_user(user, &d, sizeof(d)) != 0) + ret = -EFAULT; + mutex_unlock(&ipvs->sync_mutex); + return ret; + } + + mutex_lock(&__ip_vs_mutex); + switch (cmd) { + case IP_VS_SO_GET_VERSION: + { + char buf[64]; + + sprintf(buf, "IP Virtual Server version %d.%d.%d (size=%d)", + NVERSION(IP_VS_VERSION_CODE), ip_vs_conn_tab_size); + if (copy_to_user(user, buf, strlen(buf)+1) != 0) { + ret = -EFAULT; + goto out; + } + *len = strlen(buf)+1; + } + break; + + case IP_VS_SO_GET_INFO: + { + struct ip_vs_getinfo info; + info.version = IP_VS_VERSION_CODE; + info.size = ip_vs_conn_tab_size; + info.num_services = ipvs->num_services; + if (copy_to_user(user, &info, sizeof(info)) != 0) + ret = -EFAULT; + } + break; + + case IP_VS_SO_GET_SERVICES: + { + struct ip_vs_get_services *get; + int size; + + get = (struct ip_vs_get_services *)arg; + size = struct_size(get, entrytable, get->num_services); + if (*len != size) { + pr_err("length: %u != %u\n", *len, size); + ret = -EINVAL; + goto out; + } + ret = __ip_vs_get_service_entries(ipvs, get, user); + } + break; + + case IP_VS_SO_GET_SERVICE: + { + struct ip_vs_service_entry *entry; + struct ip_vs_service *svc; + union nf_inet_addr addr; + + entry = (struct ip_vs_service_entry *)arg; + addr.ip = entry->addr; + rcu_read_lock(); + if (entry->fwmark) + svc = __ip_vs_svc_fwm_find(ipvs, AF_INET, entry->fwmark); + else + svc = __ip_vs_service_find(ipvs, AF_INET, + entry->protocol, &addr, + entry->port); + rcu_read_unlock(); + if (svc) { + ip_vs_copy_service(entry, svc); + if (copy_to_user(user, entry, sizeof(*entry)) != 0) + ret = -EFAULT; + } else + ret = -ESRCH; + } + break; + + case IP_VS_SO_GET_DESTS: + { + struct ip_vs_get_dests *get; + int size; + + get = (struct ip_vs_get_dests *)arg; + size = struct_size(get, entrytable, get->num_dests); + if (*len != size) { + pr_err("length: %u != %u\n", *len, size); + ret = -EINVAL; + goto out; + } + ret = __ip_vs_get_dest_entries(ipvs, get, user); + } + break; + + case IP_VS_SO_GET_TIMEOUT: + { + struct ip_vs_timeout_user t; + + __ip_vs_get_timeouts(ipvs, &t); + if (copy_to_user(user, &t, sizeof(t)) != 0) + ret = -EFAULT; + } + break; + + default: + ret = -EINVAL; + } + +out: + mutex_unlock(&__ip_vs_mutex); + return ret; +} + + +static struct nf_sockopt_ops ip_vs_sockopts = { + .pf = PF_INET, + .set_optmin = IP_VS_BASE_CTL, + .set_optmax = IP_VS_SO_SET_MAX+1, + .set = do_ip_vs_set_ctl, + .get_optmin = IP_VS_BASE_CTL, + .get_optmax = IP_VS_SO_GET_MAX+1, + .get = do_ip_vs_get_ctl, + .owner = THIS_MODULE, +}; + +/* + * Generic Netlink interface + */ + +/* IPVS genetlink family */ +static struct genl_family ip_vs_genl_family; + +/* Policy used for first-level command attributes */ +static const struct nla_policy ip_vs_cmd_policy[IPVS_CMD_ATTR_MAX + 1] = { + [IPVS_CMD_ATTR_SERVICE] = { .type = NLA_NESTED }, + [IPVS_CMD_ATTR_DEST] = { .type = NLA_NESTED }, + [IPVS_CMD_ATTR_DAEMON] = { .type = NLA_NESTED }, + [IPVS_CMD_ATTR_TIMEOUT_TCP] = { .type = NLA_U32 }, + 
[IPVS_CMD_ATTR_TIMEOUT_TCP_FIN] = { .type = NLA_U32 }, + [IPVS_CMD_ATTR_TIMEOUT_UDP] = { .type = NLA_U32 }, +}; + +/* Policy used for attributes in nested attribute IPVS_CMD_ATTR_DAEMON */ +static const struct nla_policy ip_vs_daemon_policy[IPVS_DAEMON_ATTR_MAX + 1] = { + [IPVS_DAEMON_ATTR_STATE] = { .type = NLA_U32 }, + [IPVS_DAEMON_ATTR_MCAST_IFN] = { .type = NLA_NUL_STRING, + .len = IP_VS_IFNAME_MAXLEN - 1 }, + [IPVS_DAEMON_ATTR_SYNC_ID] = { .type = NLA_U32 }, + [IPVS_DAEMON_ATTR_SYNC_MAXLEN] = { .type = NLA_U16 }, + [IPVS_DAEMON_ATTR_MCAST_GROUP] = { .type = NLA_U32 }, + [IPVS_DAEMON_ATTR_MCAST_GROUP6] = { .len = sizeof(struct in6_addr) }, + [IPVS_DAEMON_ATTR_MCAST_PORT] = { .type = NLA_U16 }, + [IPVS_DAEMON_ATTR_MCAST_TTL] = { .type = NLA_U8 }, +}; + +/* Policy used for attributes in nested attribute IPVS_CMD_ATTR_SERVICE */ +static const struct nla_policy ip_vs_svc_policy[IPVS_SVC_ATTR_MAX + 1] = { + [IPVS_SVC_ATTR_AF] = { .type = NLA_U16 }, + [IPVS_SVC_ATTR_PROTOCOL] = { .type = NLA_U16 }, + [IPVS_SVC_ATTR_ADDR] = { .type = NLA_BINARY, + .len = sizeof(union nf_inet_addr) }, + [IPVS_SVC_ATTR_PORT] = { .type = NLA_U16 }, + [IPVS_SVC_ATTR_FWMARK] = { .type = NLA_U32 }, + [IPVS_SVC_ATTR_SCHED_NAME] = { .type = NLA_NUL_STRING, + .len = IP_VS_SCHEDNAME_MAXLEN - 1 }, + [IPVS_SVC_ATTR_PE_NAME] = { .type = NLA_NUL_STRING, + .len = IP_VS_PENAME_MAXLEN }, + [IPVS_SVC_ATTR_FLAGS] = { .type = NLA_BINARY, + .len = sizeof(struct ip_vs_flags) }, + [IPVS_SVC_ATTR_TIMEOUT] = { .type = NLA_U32 }, + [IPVS_SVC_ATTR_NETMASK] = { .type = NLA_U32 }, + [IPVS_SVC_ATTR_STATS] = { .type = NLA_NESTED }, +}; + +/* Policy used for attributes in nested attribute IPVS_CMD_ATTR_DEST */ +static const struct nla_policy ip_vs_dest_policy[IPVS_DEST_ATTR_MAX + 1] = { + [IPVS_DEST_ATTR_ADDR] = { .type = NLA_BINARY, + .len = sizeof(union nf_inet_addr) }, + [IPVS_DEST_ATTR_PORT] = { .type = NLA_U16 }, + [IPVS_DEST_ATTR_FWD_METHOD] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_WEIGHT] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_U_THRESH] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_L_THRESH] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_ACTIVE_CONNS] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_INACT_CONNS] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_PERSIST_CONNS] = { .type = NLA_U32 }, + [IPVS_DEST_ATTR_STATS] = { .type = NLA_NESTED }, + [IPVS_DEST_ATTR_ADDR_FAMILY] = { .type = NLA_U16 }, + [IPVS_DEST_ATTR_TUN_TYPE] = { .type = NLA_U8 }, + [IPVS_DEST_ATTR_TUN_PORT] = { .type = NLA_U16 }, + [IPVS_DEST_ATTR_TUN_FLAGS] = { .type = NLA_U16 }, +}; + +static int ip_vs_genl_fill_stats(struct sk_buff *skb, int container_type, + struct ip_vs_kstats *kstats) +{ + struct nlattr *nl_stats = nla_nest_start_noflag(skb, container_type); + + if (!nl_stats) + return -EMSGSIZE; + + if (nla_put_u32(skb, IPVS_STATS_ATTR_CONNS, (u32)kstats->conns) || + nla_put_u32(skb, IPVS_STATS_ATTR_INPKTS, (u32)kstats->inpkts) || + nla_put_u32(skb, IPVS_STATS_ATTR_OUTPKTS, (u32)kstats->outpkts) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_INBYTES, kstats->inbytes, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_OUTBYTES, kstats->outbytes, + IPVS_STATS_ATTR_PAD) || + nla_put_u32(skb, IPVS_STATS_ATTR_CPS, (u32)kstats->cps) || + nla_put_u32(skb, IPVS_STATS_ATTR_INPPS, (u32)kstats->inpps) || + nla_put_u32(skb, IPVS_STATS_ATTR_OUTPPS, (u32)kstats->outpps) || + nla_put_u32(skb, IPVS_STATS_ATTR_INBPS, (u32)kstats->inbps) || + nla_put_u32(skb, IPVS_STATS_ATTR_OUTBPS, (u32)kstats->outbps)) + goto nla_put_failure; + nla_nest_end(skb, nl_stats); + + return 0; + 
+nla_put_failure: + nla_nest_cancel(skb, nl_stats); + return -EMSGSIZE; +} + +static int ip_vs_genl_fill_stats64(struct sk_buff *skb, int container_type, + struct ip_vs_kstats *kstats) +{ + struct nlattr *nl_stats = nla_nest_start_noflag(skb, container_type); + + if (!nl_stats) + return -EMSGSIZE; + + if (nla_put_u64_64bit(skb, IPVS_STATS_ATTR_CONNS, kstats->conns, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_INPKTS, kstats->inpkts, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_OUTPKTS, kstats->outpkts, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_INBYTES, kstats->inbytes, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_OUTBYTES, kstats->outbytes, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_CPS, kstats->cps, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_INPPS, kstats->inpps, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_OUTPPS, kstats->outpps, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_INBPS, kstats->inbps, + IPVS_STATS_ATTR_PAD) || + nla_put_u64_64bit(skb, IPVS_STATS_ATTR_OUTBPS, kstats->outbps, + IPVS_STATS_ATTR_PAD)) + goto nla_put_failure; + nla_nest_end(skb, nl_stats); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nl_stats); + return -EMSGSIZE; +} + +static int ip_vs_genl_fill_service(struct sk_buff *skb, + struct ip_vs_service *svc) +{ + struct ip_vs_scheduler *sched; + struct ip_vs_pe *pe; + struct nlattr *nl_service; + struct ip_vs_flags flags = { .flags = svc->flags, + .mask = ~0 }; + struct ip_vs_kstats kstats; + char *sched_name; + + nl_service = nla_nest_start_noflag(skb, IPVS_CMD_ATTR_SERVICE); + if (!nl_service) + return -EMSGSIZE; + + if (nla_put_u16(skb, IPVS_SVC_ATTR_AF, svc->af)) + goto nla_put_failure; + if (svc->fwmark) { + if (nla_put_u32(skb, IPVS_SVC_ATTR_FWMARK, svc->fwmark)) + goto nla_put_failure; + } else { + if (nla_put_u16(skb, IPVS_SVC_ATTR_PROTOCOL, svc->protocol) || + nla_put(skb, IPVS_SVC_ATTR_ADDR, sizeof(svc->addr), &svc->addr) || + nla_put_be16(skb, IPVS_SVC_ATTR_PORT, svc->port)) + goto nla_put_failure; + } + + sched = rcu_dereference_protected(svc->scheduler, 1); + sched_name = sched ? 
sched->name : "none"; + pe = rcu_dereference_protected(svc->pe, 1); + if (nla_put_string(skb, IPVS_SVC_ATTR_SCHED_NAME, sched_name) || + (pe && nla_put_string(skb, IPVS_SVC_ATTR_PE_NAME, pe->name)) || + nla_put(skb, IPVS_SVC_ATTR_FLAGS, sizeof(flags), &flags) || + nla_put_u32(skb, IPVS_SVC_ATTR_TIMEOUT, svc->timeout / HZ) || + nla_put_be32(skb, IPVS_SVC_ATTR_NETMASK, svc->netmask)) + goto nla_put_failure; + ip_vs_copy_stats(&kstats, &svc->stats); + if (ip_vs_genl_fill_stats(skb, IPVS_SVC_ATTR_STATS, &kstats)) + goto nla_put_failure; + if (ip_vs_genl_fill_stats64(skb, IPVS_SVC_ATTR_STATS64, &kstats)) + goto nla_put_failure; + + nla_nest_end(skb, nl_service); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nl_service); + return -EMSGSIZE; +} + +static int ip_vs_genl_dump_service(struct sk_buff *skb, + struct ip_vs_service *svc, + struct netlink_callback *cb) +{ + void *hdr; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &ip_vs_genl_family, NLM_F_MULTI, + IPVS_CMD_NEW_SERVICE); + if (!hdr) + return -EMSGSIZE; + + if (ip_vs_genl_fill_service(skb, svc) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ip_vs_genl_dump_services(struct sk_buff *skb, + struct netlink_callback *cb) +{ + int idx = 0, i; + int start = cb->args[0]; + struct ip_vs_service *svc; + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + mutex_lock(&__ip_vs_mutex); + for (i = 0; i < IP_VS_SVC_TAB_SIZE; i++) { + hlist_for_each_entry(svc, &ip_vs_svc_table[i], s_list) { + if (++idx <= start || (svc->ipvs != ipvs)) + continue; + if (ip_vs_genl_dump_service(skb, svc, cb) < 0) { + idx--; + goto nla_put_failure; + } + } + } + + for (i = 0; i < IP_VS_SVC_TAB_SIZE; i++) { + hlist_for_each_entry(svc, &ip_vs_svc_fwm_table[i], f_list) { + if (++idx <= start || (svc->ipvs != ipvs)) + continue; + if (ip_vs_genl_dump_service(skb, svc, cb) < 0) { + idx--; + goto nla_put_failure; + } + } + } + +nla_put_failure: + mutex_unlock(&__ip_vs_mutex); + cb->args[0] = idx; + + return skb->len; +} + +static bool ip_vs_is_af_valid(int af) +{ + if (af == AF_INET) + return true; +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6 && ipv6_mod_enabled()) + return true; +#endif + return false; +} + +static int ip_vs_genl_parse_service(struct netns_ipvs *ipvs, + struct ip_vs_service_user_kern *usvc, + struct nlattr *nla, bool full_entry, + struct ip_vs_service **ret_svc) +{ + struct nlattr *attrs[IPVS_SVC_ATTR_MAX + 1]; + struct nlattr *nla_af, *nla_port, *nla_fwmark, *nla_protocol, *nla_addr; + struct ip_vs_service *svc; + + /* Parse mandatory identifying service fields first */ + if (nla == NULL || + nla_parse_nested_deprecated(attrs, IPVS_SVC_ATTR_MAX, nla, ip_vs_svc_policy, NULL)) + return -EINVAL; + + nla_af = attrs[IPVS_SVC_ATTR_AF]; + nla_protocol = attrs[IPVS_SVC_ATTR_PROTOCOL]; + nla_addr = attrs[IPVS_SVC_ATTR_ADDR]; + nla_port = attrs[IPVS_SVC_ATTR_PORT]; + nla_fwmark = attrs[IPVS_SVC_ATTR_FWMARK]; + + if (!(nla_af && (nla_fwmark || (nla_port && nla_protocol && nla_addr)))) + return -EINVAL; + + memset(usvc, 0, sizeof(*usvc)); + + usvc->af = nla_get_u16(nla_af); + if (!ip_vs_is_af_valid(usvc->af)) + return -EAFNOSUPPORT; + + if (nla_fwmark) { + usvc->protocol = IPPROTO_TCP; + usvc->fwmark = nla_get_u32(nla_fwmark); + } else { + usvc->protocol = nla_get_u16(nla_protocol); + nla_memcpy(&usvc->addr, nla_addr, sizeof(usvc->addr)); + usvc->port = nla_get_be16(nla_port); + usvc->fwmark = 0; 
+ } + + rcu_read_lock(); + if (usvc->fwmark) + svc = __ip_vs_svc_fwm_find(ipvs, usvc->af, usvc->fwmark); + else + svc = __ip_vs_service_find(ipvs, usvc->af, usvc->protocol, + &usvc->addr, usvc->port); + rcu_read_unlock(); + *ret_svc = svc; + + /* If a full entry was requested, check for the additional fields */ + if (full_entry) { + struct nlattr *nla_sched, *nla_flags, *nla_pe, *nla_timeout, + *nla_netmask; + struct ip_vs_flags flags; + + nla_sched = attrs[IPVS_SVC_ATTR_SCHED_NAME]; + nla_pe = attrs[IPVS_SVC_ATTR_PE_NAME]; + nla_flags = attrs[IPVS_SVC_ATTR_FLAGS]; + nla_timeout = attrs[IPVS_SVC_ATTR_TIMEOUT]; + nla_netmask = attrs[IPVS_SVC_ATTR_NETMASK]; + + if (!(nla_sched && nla_flags && nla_timeout && nla_netmask)) + return -EINVAL; + + nla_memcpy(&flags, nla_flags, sizeof(flags)); + + /* prefill flags from service if it already exists */ + if (svc) + usvc->flags = svc->flags; + + /* set new flags from userland */ + usvc->flags = (usvc->flags & ~flags.mask) | + (flags.flags & flags.mask); + usvc->sched_name = nla_data(nla_sched); + usvc->pe_name = nla_pe ? nla_data(nla_pe) : NULL; + usvc->timeout = nla_get_u32(nla_timeout); + usvc->netmask = nla_get_be32(nla_netmask); + } + + return 0; +} + +static struct ip_vs_service *ip_vs_genl_find_service(struct netns_ipvs *ipvs, + struct nlattr *nla) +{ + struct ip_vs_service_user_kern usvc; + struct ip_vs_service *svc; + int ret; + + ret = ip_vs_genl_parse_service(ipvs, &usvc, nla, false, &svc); + return ret ? ERR_PTR(ret) : svc; +} + +static int ip_vs_genl_fill_dest(struct sk_buff *skb, struct ip_vs_dest *dest) +{ + struct nlattr *nl_dest; + struct ip_vs_kstats kstats; + + nl_dest = nla_nest_start_noflag(skb, IPVS_CMD_ATTR_DEST); + if (!nl_dest) + return -EMSGSIZE; + + if (nla_put(skb, IPVS_DEST_ATTR_ADDR, sizeof(dest->addr), &dest->addr) || + nla_put_be16(skb, IPVS_DEST_ATTR_PORT, dest->port) || + nla_put_u32(skb, IPVS_DEST_ATTR_FWD_METHOD, + (atomic_read(&dest->conn_flags) & + IP_VS_CONN_F_FWD_MASK)) || + nla_put_u32(skb, IPVS_DEST_ATTR_WEIGHT, + atomic_read(&dest->weight)) || + nla_put_u8(skb, IPVS_DEST_ATTR_TUN_TYPE, + dest->tun_type) || + nla_put_be16(skb, IPVS_DEST_ATTR_TUN_PORT, + dest->tun_port) || + nla_put_u16(skb, IPVS_DEST_ATTR_TUN_FLAGS, + dest->tun_flags) || + nla_put_u32(skb, IPVS_DEST_ATTR_U_THRESH, dest->u_threshold) || + nla_put_u32(skb, IPVS_DEST_ATTR_L_THRESH, dest->l_threshold) || + nla_put_u32(skb, IPVS_DEST_ATTR_ACTIVE_CONNS, + atomic_read(&dest->activeconns)) || + nla_put_u32(skb, IPVS_DEST_ATTR_INACT_CONNS, + atomic_read(&dest->inactconns)) || + nla_put_u32(skb, IPVS_DEST_ATTR_PERSIST_CONNS, + atomic_read(&dest->persistconns)) || + nla_put_u16(skb, IPVS_DEST_ATTR_ADDR_FAMILY, dest->af)) + goto nla_put_failure; + ip_vs_copy_stats(&kstats, &dest->stats); + if (ip_vs_genl_fill_stats(skb, IPVS_DEST_ATTR_STATS, &kstats)) + goto nla_put_failure; + if (ip_vs_genl_fill_stats64(skb, IPVS_DEST_ATTR_STATS64, &kstats)) + goto nla_put_failure; + + nla_nest_end(skb, nl_dest); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nl_dest); + return -EMSGSIZE; +} + +static int ip_vs_genl_dump_dest(struct sk_buff *skb, struct ip_vs_dest *dest, + struct netlink_callback *cb) +{ + void *hdr; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &ip_vs_genl_family, NLM_F_MULTI, + IPVS_CMD_NEW_DEST); + if (!hdr) + return -EMSGSIZE; + + if (ip_vs_genl_fill_dest(skb, dest) < 0) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + 
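/*
 * Illustrative sketch, not part of the upstream patch: the fill helpers
 * above (ip_vs_genl_fill_stats(), ip_vs_genl_fill_service(),
 * ip_vs_genl_fill_dest()) all follow the same netlink nesting pattern --
 * open a nested container, emit the member attributes, close the nest,
 * and cancel the whole container if any put runs out of skb room.  The
 * helper name below is made up for illustration; the nla_* calls and
 * IPVS_STATS_ATTR_CONNS are the real interfaces already used above, and
 * the sketch assumes the same includes as this file.
 */
static int example_fill_nested(struct sk_buff *skb, int container_type,
			       u32 value)
{
	struct nlattr *nest = nla_nest_start_noflag(skb, container_type);

	if (!nest)
		return -EMSGSIZE;

	/* Emit one attribute; a real fill helper emits many. */
	if (nla_put_u32(skb, IPVS_STATS_ATTR_CONNS, value))
		goto nla_put_failure;

	nla_nest_end(skb, nest);	/* commit the container */
	return 0;

nla_put_failure:
	nla_nest_cancel(skb, nest);	/* roll back the partial container */
	return -EMSGSIZE;
}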
+static int ip_vs_genl_dump_dests(struct sk_buff *skb, + struct netlink_callback *cb) +{ + int idx = 0; + int start = cb->args[0]; + struct ip_vs_service *svc; + struct ip_vs_dest *dest; + struct nlattr *attrs[IPVS_CMD_ATTR_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + mutex_lock(&__ip_vs_mutex); + + /* Try to find the service for which to dump destinations */ + if (nlmsg_parse_deprecated(cb->nlh, GENL_HDRLEN, attrs, IPVS_CMD_ATTR_MAX, ip_vs_cmd_policy, cb->extack)) + goto out_err; + + + svc = ip_vs_genl_find_service(ipvs, attrs[IPVS_CMD_ATTR_SERVICE]); + if (IS_ERR_OR_NULL(svc)) + goto out_err; + + /* Dump the destinations */ + list_for_each_entry(dest, &svc->destinations, n_list) { + if (++idx <= start) + continue; + if (ip_vs_genl_dump_dest(skb, dest, cb) < 0) { + idx--; + goto nla_put_failure; + } + } + +nla_put_failure: + cb->args[0] = idx; + +out_err: + mutex_unlock(&__ip_vs_mutex); + + return skb->len; +} + +static int ip_vs_genl_parse_dest(struct ip_vs_dest_user_kern *udest, + struct nlattr *nla, bool full_entry) +{ + struct nlattr *attrs[IPVS_DEST_ATTR_MAX + 1]; + struct nlattr *nla_addr, *nla_port; + struct nlattr *nla_addr_family; + + /* Parse mandatory identifying destination fields first */ + if (nla == NULL || + nla_parse_nested_deprecated(attrs, IPVS_DEST_ATTR_MAX, nla, ip_vs_dest_policy, NULL)) + return -EINVAL; + + nla_addr = attrs[IPVS_DEST_ATTR_ADDR]; + nla_port = attrs[IPVS_DEST_ATTR_PORT]; + nla_addr_family = attrs[IPVS_DEST_ATTR_ADDR_FAMILY]; + + if (!(nla_addr && nla_port)) + return -EINVAL; + + memset(udest, 0, sizeof(*udest)); + + nla_memcpy(&udest->addr, nla_addr, sizeof(udest->addr)); + udest->port = nla_get_be16(nla_port); + + if (nla_addr_family) + udest->af = nla_get_u16(nla_addr_family); + else + udest->af = 0; + + /* If a full entry was requested, check for the additional fields */ + if (full_entry) { + struct nlattr *nla_fwd, *nla_weight, *nla_u_thresh, + *nla_l_thresh, *nla_tun_type, *nla_tun_port, + *nla_tun_flags; + + nla_fwd = attrs[IPVS_DEST_ATTR_FWD_METHOD]; + nla_weight = attrs[IPVS_DEST_ATTR_WEIGHT]; + nla_u_thresh = attrs[IPVS_DEST_ATTR_U_THRESH]; + nla_l_thresh = attrs[IPVS_DEST_ATTR_L_THRESH]; + nla_tun_type = attrs[IPVS_DEST_ATTR_TUN_TYPE]; + nla_tun_port = attrs[IPVS_DEST_ATTR_TUN_PORT]; + nla_tun_flags = attrs[IPVS_DEST_ATTR_TUN_FLAGS]; + + if (!(nla_fwd && nla_weight && nla_u_thresh && nla_l_thresh)) + return -EINVAL; + + udest->conn_flags = nla_get_u32(nla_fwd) + & IP_VS_CONN_F_FWD_MASK; + udest->weight = nla_get_u32(nla_weight); + udest->u_threshold = nla_get_u32(nla_u_thresh); + udest->l_threshold = nla_get_u32(nla_l_thresh); + + if (nla_tun_type) + udest->tun_type = nla_get_u8(nla_tun_type); + + if (nla_tun_port) + udest->tun_port = nla_get_be16(nla_tun_port); + + if (nla_tun_flags) + udest->tun_flags = nla_get_u16(nla_tun_flags); + } + + return 0; +} + +static int ip_vs_genl_fill_daemon(struct sk_buff *skb, __u32 state, + struct ipvs_sync_daemon_cfg *c) +{ + struct nlattr *nl_daemon; + + nl_daemon = nla_nest_start_noflag(skb, IPVS_CMD_ATTR_DAEMON); + if (!nl_daemon) + return -EMSGSIZE; + + if (nla_put_u32(skb, IPVS_DAEMON_ATTR_STATE, state) || + nla_put_string(skb, IPVS_DAEMON_ATTR_MCAST_IFN, c->mcast_ifn) || + nla_put_u32(skb, IPVS_DAEMON_ATTR_SYNC_ID, c->syncid) || + nla_put_u16(skb, IPVS_DAEMON_ATTR_SYNC_MAXLEN, c->sync_maxlen) || + nla_put_u16(skb, IPVS_DAEMON_ATTR_MCAST_PORT, c->mcast_port) || + nla_put_u8(skb, IPVS_DAEMON_ATTR_MCAST_TTL, c->mcast_ttl)) + goto nla_put_failure; 
+#ifdef CONFIG_IP_VS_IPV6 + if (c->mcast_af == AF_INET6) { + if (nla_put_in6_addr(skb, IPVS_DAEMON_ATTR_MCAST_GROUP6, + &c->mcast_group.in6)) + goto nla_put_failure; + } else +#endif + if (c->mcast_af == AF_INET && + nla_put_in_addr(skb, IPVS_DAEMON_ATTR_MCAST_GROUP, + c->mcast_group.ip)) + goto nla_put_failure; + nla_nest_end(skb, nl_daemon); + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nl_daemon); + return -EMSGSIZE; +} + +static int ip_vs_genl_dump_daemon(struct sk_buff *skb, __u32 state, + struct ipvs_sync_daemon_cfg *c, + struct netlink_callback *cb) +{ + void *hdr; + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &ip_vs_genl_family, NLM_F_MULTI, + IPVS_CMD_NEW_DAEMON); + if (!hdr) + return -EMSGSIZE; + + if (ip_vs_genl_fill_daemon(skb, state, c)) + goto nla_put_failure; + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ip_vs_genl_dump_daemons(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + mutex_lock(&ipvs->sync_mutex); + if ((ipvs->sync_state & IP_VS_STATE_MASTER) && !cb->args[0]) { + if (ip_vs_genl_dump_daemon(skb, IP_VS_STATE_MASTER, + &ipvs->mcfg, cb) < 0) + goto nla_put_failure; + + cb->args[0] = 1; + } + + if ((ipvs->sync_state & IP_VS_STATE_BACKUP) && !cb->args[1]) { + if (ip_vs_genl_dump_daemon(skb, IP_VS_STATE_BACKUP, + &ipvs->bcfg, cb) < 0) + goto nla_put_failure; + + cb->args[1] = 1; + } + +nla_put_failure: + mutex_unlock(&ipvs->sync_mutex); + + return skb->len; +} + +static int ip_vs_genl_new_daemon(struct netns_ipvs *ipvs, struct nlattr **attrs) +{ + struct ipvs_sync_daemon_cfg c; + struct nlattr *a; + int ret; + + memset(&c, 0, sizeof(c)); + if (!(attrs[IPVS_DAEMON_ATTR_STATE] && + attrs[IPVS_DAEMON_ATTR_MCAST_IFN] && + attrs[IPVS_DAEMON_ATTR_SYNC_ID])) + return -EINVAL; + strscpy(c.mcast_ifn, nla_data(attrs[IPVS_DAEMON_ATTR_MCAST_IFN]), + sizeof(c.mcast_ifn)); + c.syncid = nla_get_u32(attrs[IPVS_DAEMON_ATTR_SYNC_ID]); + + a = attrs[IPVS_DAEMON_ATTR_SYNC_MAXLEN]; + if (a) + c.sync_maxlen = nla_get_u16(a); + + a = attrs[IPVS_DAEMON_ATTR_MCAST_GROUP]; + if (a) { + c.mcast_af = AF_INET; + c.mcast_group.ip = nla_get_in_addr(a); + if (!ipv4_is_multicast(c.mcast_group.ip)) + return -EINVAL; + } else { + a = attrs[IPVS_DAEMON_ATTR_MCAST_GROUP6]; + if (a) { +#ifdef CONFIG_IP_VS_IPV6 + int addr_type; + + c.mcast_af = AF_INET6; + c.mcast_group.in6 = nla_get_in6_addr(a); + addr_type = ipv6_addr_type(&c.mcast_group.in6); + if (!(addr_type & IPV6_ADDR_MULTICAST)) + return -EINVAL; +#else + return -EAFNOSUPPORT; +#endif + } + } + + a = attrs[IPVS_DAEMON_ATTR_MCAST_PORT]; + if (a) + c.mcast_port = nla_get_u16(a); + + a = attrs[IPVS_DAEMON_ATTR_MCAST_TTL]; + if (a) + c.mcast_ttl = nla_get_u8(a); + + /* The synchronization protocol is incompatible with mixed family + * services + */ + if (ipvs->mixed_address_family_dests > 0) + return -EINVAL; + + ret = start_sync_thread(ipvs, &c, + nla_get_u32(attrs[IPVS_DAEMON_ATTR_STATE])); + return ret; +} + +static int ip_vs_genl_del_daemon(struct netns_ipvs *ipvs, struct nlattr **attrs) +{ + int ret; + + if (!attrs[IPVS_DAEMON_ATTR_STATE]) + return -EINVAL; + + ret = stop_sync_thread(ipvs, + nla_get_u32(attrs[IPVS_DAEMON_ATTR_STATE])); + return ret; +} + +static int ip_vs_genl_set_config(struct netns_ipvs *ipvs, struct nlattr **attrs) +{ + struct ip_vs_timeout_user t; + + __ip_vs_get_timeouts(ipvs, &t); + + if (attrs[IPVS_CMD_ATTR_TIMEOUT_TCP]) + 
t.tcp_timeout = nla_get_u32(attrs[IPVS_CMD_ATTR_TIMEOUT_TCP]); + + if (attrs[IPVS_CMD_ATTR_TIMEOUT_TCP_FIN]) + t.tcp_fin_timeout = + nla_get_u32(attrs[IPVS_CMD_ATTR_TIMEOUT_TCP_FIN]); + + if (attrs[IPVS_CMD_ATTR_TIMEOUT_UDP]) + t.udp_timeout = nla_get_u32(attrs[IPVS_CMD_ATTR_TIMEOUT_UDP]); + + return ip_vs_set_timeout(ipvs, &t); +} + +static int ip_vs_genl_set_daemon(struct sk_buff *skb, struct genl_info *info) +{ + int ret = -EINVAL, cmd; + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + cmd = info->genlhdr->cmd; + + if (cmd == IPVS_CMD_NEW_DAEMON || cmd == IPVS_CMD_DEL_DAEMON) { + struct nlattr *daemon_attrs[IPVS_DAEMON_ATTR_MAX + 1]; + + if (!info->attrs[IPVS_CMD_ATTR_DAEMON] || + nla_parse_nested_deprecated(daemon_attrs, IPVS_DAEMON_ATTR_MAX, info->attrs[IPVS_CMD_ATTR_DAEMON], ip_vs_daemon_policy, info->extack)) + goto out; + + if (cmd == IPVS_CMD_NEW_DAEMON) + ret = ip_vs_genl_new_daemon(ipvs, daemon_attrs); + else + ret = ip_vs_genl_del_daemon(ipvs, daemon_attrs); + } + +out: + return ret; +} + +static int ip_vs_genl_set_cmd(struct sk_buff *skb, struct genl_info *info) +{ + bool need_full_svc = false, need_full_dest = false; + struct ip_vs_service *svc = NULL; + struct ip_vs_service_user_kern usvc; + struct ip_vs_dest_user_kern udest; + int ret = 0, cmd; + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + cmd = info->genlhdr->cmd; + + mutex_lock(&__ip_vs_mutex); + + if (cmd == IPVS_CMD_FLUSH) { + ret = ip_vs_flush(ipvs, false); + goto out; + } else if (cmd == IPVS_CMD_SET_CONFIG) { + ret = ip_vs_genl_set_config(ipvs, info->attrs); + goto out; + } else if (cmd == IPVS_CMD_ZERO && + !info->attrs[IPVS_CMD_ATTR_SERVICE]) { + ret = ip_vs_zero_all(ipvs); + goto out; + } + + /* All following commands require a service argument, so check if we + * received a valid one. We need a full service specification when + * adding / editing a service. Only identifying members otherwise. */ + if (cmd == IPVS_CMD_NEW_SERVICE || cmd == IPVS_CMD_SET_SERVICE) + need_full_svc = true; + + ret = ip_vs_genl_parse_service(ipvs, &usvc, + info->attrs[IPVS_CMD_ATTR_SERVICE], + need_full_svc, &svc); + if (ret) + goto out; + + /* Unless we're adding a new service, the service must already exist */ + if ((cmd != IPVS_CMD_NEW_SERVICE) && (svc == NULL)) { + ret = -ESRCH; + goto out; + } + + /* Destination commands require a valid destination argument. For + * adding / editing a destination, we need a full destination + * specification. */ + if (cmd == IPVS_CMD_NEW_DEST || cmd == IPVS_CMD_SET_DEST || + cmd == IPVS_CMD_DEL_DEST) { + if (cmd != IPVS_CMD_DEL_DEST) + need_full_dest = true; + + ret = ip_vs_genl_parse_dest(&udest, + info->attrs[IPVS_CMD_ATTR_DEST], + need_full_dest); + if (ret) + goto out; + + /* Old protocols did not allow the user to specify address + * family, so we set it to zero instead. We also didn't + * allow heterogeneous pools in the old code, so it's safe + * to assume that this will have the same address family as + * the service. + */ + if (udest.af == 0) + udest.af = svc->af; + + if (!ip_vs_is_af_valid(udest.af)) { + ret = -EAFNOSUPPORT; + goto out; + } + + if (udest.af != svc->af && cmd != IPVS_CMD_DEL_DEST) { + /* The synchronization protocol is incompatible + * with mixed family services + */ + if (ipvs->sync_state) { + ret = -EINVAL; + goto out; + } + + /* Which connection types do we support? 
*/ + switch (udest.conn_flags) { + case IP_VS_CONN_F_TUNNEL: + /* We are able to forward this */ + break; + default: + ret = -EINVAL; + goto out; + } + } + } + + switch (cmd) { + case IPVS_CMD_NEW_SERVICE: + if (svc == NULL) + ret = ip_vs_add_service(ipvs, &usvc, &svc); + else + ret = -EEXIST; + break; + case IPVS_CMD_SET_SERVICE: + ret = ip_vs_edit_service(svc, &usvc); + break; + case IPVS_CMD_DEL_SERVICE: + ret = ip_vs_del_service(svc); + /* do not use svc, it can be freed */ + break; + case IPVS_CMD_NEW_DEST: + ret = ip_vs_add_dest(svc, &udest); + break; + case IPVS_CMD_SET_DEST: + ret = ip_vs_edit_dest(svc, &udest); + break; + case IPVS_CMD_DEL_DEST: + ret = ip_vs_del_dest(svc, &udest); + break; + case IPVS_CMD_ZERO: + ret = ip_vs_zero_service(svc); + break; + default: + ret = -EINVAL; + } + +out: + mutex_unlock(&__ip_vs_mutex); + + return ret; +} + +static int ip_vs_genl_get_cmd(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + void *reply; + int ret, cmd, reply_cmd; + struct net *net = sock_net(skb->sk); + struct netns_ipvs *ipvs = net_ipvs(net); + + cmd = info->genlhdr->cmd; + + if (cmd == IPVS_CMD_GET_SERVICE) + reply_cmd = IPVS_CMD_NEW_SERVICE; + else if (cmd == IPVS_CMD_GET_INFO) + reply_cmd = IPVS_CMD_SET_INFO; + else if (cmd == IPVS_CMD_GET_CONFIG) + reply_cmd = IPVS_CMD_SET_CONFIG; + else { + pr_err("unknown Generic Netlink command\n"); + return -EINVAL; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + mutex_lock(&__ip_vs_mutex); + + reply = genlmsg_put_reply(msg, info, &ip_vs_genl_family, 0, reply_cmd); + if (reply == NULL) + goto nla_put_failure; + + switch (cmd) { + case IPVS_CMD_GET_SERVICE: + { + struct ip_vs_service *svc; + + svc = ip_vs_genl_find_service(ipvs, + info->attrs[IPVS_CMD_ATTR_SERVICE]); + if (IS_ERR(svc)) { + ret = PTR_ERR(svc); + goto out_err; + } else if (svc) { + ret = ip_vs_genl_fill_service(msg, svc); + if (ret) + goto nla_put_failure; + } else { + ret = -ESRCH; + goto out_err; + } + + break; + } + + case IPVS_CMD_GET_CONFIG: + { + struct ip_vs_timeout_user t; + + __ip_vs_get_timeouts(ipvs, &t); +#ifdef CONFIG_IP_VS_PROTO_TCP + if (nla_put_u32(msg, IPVS_CMD_ATTR_TIMEOUT_TCP, + t.tcp_timeout) || + nla_put_u32(msg, IPVS_CMD_ATTR_TIMEOUT_TCP_FIN, + t.tcp_fin_timeout)) + goto nla_put_failure; +#endif +#ifdef CONFIG_IP_VS_PROTO_UDP + if (nla_put_u32(msg, IPVS_CMD_ATTR_TIMEOUT_UDP, t.udp_timeout)) + goto nla_put_failure; +#endif + + break; + } + + case IPVS_CMD_GET_INFO: + if (nla_put_u32(msg, IPVS_INFO_ATTR_VERSION, + IP_VS_VERSION_CODE) || + nla_put_u32(msg, IPVS_INFO_ATTR_CONN_TAB_SIZE, + ip_vs_conn_tab_size)) + goto nla_put_failure; + break; + } + + genlmsg_end(msg, reply); + ret = genlmsg_reply(msg, info); + goto out; + +nla_put_failure: + pr_err("not enough space in Netlink message\n"); + ret = -EMSGSIZE; + +out_err: + nlmsg_free(msg); +out: + mutex_unlock(&__ip_vs_mutex); + + return ret; +} + + +static const struct genl_small_ops ip_vs_genl_ops[] = { + { + .cmd = IPVS_CMD_NEW_SERVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_SET_SERVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_DEL_SERVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_GET_SERVICE, + .validate = 
GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_get_cmd, + .dumpit = ip_vs_genl_dump_services, + }, + { + .cmd = IPVS_CMD_NEW_DEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_SET_DEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_DEL_DEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_GET_DEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .dumpit = ip_vs_genl_dump_dests, + }, + { + .cmd = IPVS_CMD_NEW_DAEMON, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_daemon, + }, + { + .cmd = IPVS_CMD_DEL_DAEMON, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_daemon, + }, + { + .cmd = IPVS_CMD_GET_DAEMON, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .dumpit = ip_vs_genl_dump_daemons, + }, + { + .cmd = IPVS_CMD_SET_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_GET_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_get_cmd, + }, + { + .cmd = IPVS_CMD_GET_INFO, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_get_cmd, + }, + { + .cmd = IPVS_CMD_ZERO, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, + { + .cmd = IPVS_CMD_FLUSH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = ip_vs_genl_set_cmd, + }, +}; + +static struct genl_family ip_vs_genl_family __ro_after_init = { + .hdrsize = 0, + .name = IPVS_GENL_NAME, + .version = IPVS_GENL_VERSION, + .maxattr = IPVS_CMD_ATTR_MAX, + .policy = ip_vs_cmd_policy, + .netnsok = true, /* Make ipvsadm to work on netns */ + .module = THIS_MODULE, + .small_ops = ip_vs_genl_ops, + .n_small_ops = ARRAY_SIZE(ip_vs_genl_ops), + .resv_start_op = IPVS_CMD_FLUSH + 1, +}; + +static int __init ip_vs_genl_register(void) +{ + return genl_register_family(&ip_vs_genl_family); +} + +static void ip_vs_genl_unregister(void) +{ + genl_unregister_family(&ip_vs_genl_family); +} + +/* End of Generic Netlink interface definitions */ + +/* + * per netns intit/exit func. 
+ */ +#ifdef CONFIG_SYSCTL +static int __net_init ip_vs_control_net_init_sysctl(struct netns_ipvs *ipvs) +{ + struct net *net = ipvs->net; + int idx; + struct ctl_table *tbl; + + atomic_set(&ipvs->dropentry, 0); + spin_lock_init(&ipvs->dropentry_lock); + spin_lock_init(&ipvs->droppacket_lock); + spin_lock_init(&ipvs->securetcp_lock); + + if (!net_eq(net, &init_net)) { + tbl = kmemdup(vs_vars, sizeof(vs_vars), GFP_KERNEL); + if (tbl == NULL) + return -ENOMEM; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + tbl[0].procname = NULL; + } else + tbl = vs_vars; + /* Initialize sysctl defaults */ + for (idx = 0; idx < ARRAY_SIZE(vs_vars); idx++) { + if (tbl[idx].proc_handler == proc_do_defense_mode) + tbl[idx].extra2 = ipvs; + } + idx = 0; + ipvs->sysctl_amemthresh = 1024; + tbl[idx++].data = &ipvs->sysctl_amemthresh; + ipvs->sysctl_am_droprate = 10; + tbl[idx++].data = &ipvs->sysctl_am_droprate; + tbl[idx++].data = &ipvs->sysctl_drop_entry; + tbl[idx++].data = &ipvs->sysctl_drop_packet; +#ifdef CONFIG_IP_VS_NFCT + tbl[idx++].data = &ipvs->sysctl_conntrack; +#endif + tbl[idx++].data = &ipvs->sysctl_secure_tcp; + ipvs->sysctl_snat_reroute = 1; + tbl[idx++].data = &ipvs->sysctl_snat_reroute; + ipvs->sysctl_sync_ver = 1; + tbl[idx++].data = &ipvs->sysctl_sync_ver; + ipvs->sysctl_sync_ports = 1; + tbl[idx++].data = &ipvs->sysctl_sync_ports; + tbl[idx++].data = &ipvs->sysctl_sync_persist_mode; + ipvs->sysctl_sync_qlen_max = nr_free_buffer_pages() / 32; + tbl[idx++].data = &ipvs->sysctl_sync_qlen_max; + ipvs->sysctl_sync_sock_size = 0; + tbl[idx++].data = &ipvs->sysctl_sync_sock_size; + tbl[idx++].data = &ipvs->sysctl_cache_bypass; + tbl[idx++].data = &ipvs->sysctl_expire_nodest_conn; + tbl[idx++].data = &ipvs->sysctl_sloppy_tcp; + tbl[idx++].data = &ipvs->sysctl_sloppy_sctp; + tbl[idx++].data = &ipvs->sysctl_expire_quiescent_template; + ipvs->sysctl_sync_threshold[0] = DEFAULT_SYNC_THRESHOLD; + ipvs->sysctl_sync_threshold[1] = DEFAULT_SYNC_PERIOD; + tbl[idx].data = &ipvs->sysctl_sync_threshold; + tbl[idx].extra2 = ipvs; + tbl[idx++].maxlen = sizeof(ipvs->sysctl_sync_threshold); + ipvs->sysctl_sync_refresh_period = DEFAULT_SYNC_REFRESH_PERIOD; + tbl[idx++].data = &ipvs->sysctl_sync_refresh_period; + ipvs->sysctl_sync_retries = clamp_t(int, DEFAULT_SYNC_RETRIES, 0, 3); + tbl[idx++].data = &ipvs->sysctl_sync_retries; + tbl[idx++].data = &ipvs->sysctl_nat_icmp_send; + ipvs->sysctl_pmtu_disc = 1; + tbl[idx++].data = &ipvs->sysctl_pmtu_disc; + tbl[idx++].data = &ipvs->sysctl_backup_only; + ipvs->sysctl_conn_reuse_mode = 1; + tbl[idx++].data = &ipvs->sysctl_conn_reuse_mode; + tbl[idx++].data = &ipvs->sysctl_schedule_icmp; + tbl[idx++].data = &ipvs->sysctl_ignore_tunneled; + ipvs->sysctl_run_estimation = 1; + tbl[idx++].data = &ipvs->sysctl_run_estimation; +#ifdef CONFIG_IP_VS_DEBUG + /* Global sysctls must be ro in non-init netns */ + if (!net_eq(net, &init_net)) + tbl[idx++].mode = 0444; +#endif + + ipvs->sysctl_hdr = register_net_sysctl(net, "net/ipv4/vs", tbl); + if (ipvs->sysctl_hdr == NULL) { + if (!net_eq(net, &init_net)) + kfree(tbl); + return -ENOMEM; + } + ip_vs_start_estimator(ipvs, &ipvs->tot_stats); + ipvs->sysctl_tbl = tbl; + /* Schedule defense work */ + INIT_DELAYED_WORK(&ipvs->defense_work, defense_work_handler); + queue_delayed_work(system_long_wq, &ipvs->defense_work, + DEFENSE_TIMER_PERIOD); + + /* Init delayed work for expiring no dest conn */ + INIT_DELAYED_WORK(&ipvs->expire_nodest_conn_work, + expire_nodest_conn_handler); + + return 0; +} + 
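/*
 * Illustrative sketch, not part of the upstream patch: once
 * register_net_sysctl(net, "net/ipv4/vs", tbl) succeeds in
 * ip_vs_control_net_init_sysctl() above, the tunables wired into the
 * table surface per network namespace under /proc/sys/net/ipv4/vs/.
 * A minimal userspace reader, assuming the conn_reuse_mode entry
 * (backed by sysctl_conn_reuse_mode, initialized above) is exported:
 */
#include <stdio.h>

int main(void)
{
	char buf[32];
	FILE *f = fopen("/proc/sys/net/ipv4/vs/conn_reuse_mode", "r");

	if (!f)
		return 1;	/* IPVS not loaded or sysctl not exported */
	if (fgets(buf, sizeof(buf), f))
		printf("conn_reuse_mode = %s", buf);
	fclose(f);
	return 0;
}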
+static void __net_exit ip_vs_control_net_cleanup_sysctl(struct netns_ipvs *ipvs) +{ + struct net *net = ipvs->net; + + cancel_delayed_work_sync(&ipvs->expire_nodest_conn_work); + cancel_delayed_work_sync(&ipvs->defense_work); + cancel_work_sync(&ipvs->defense_work.work); + unregister_net_sysctl_table(ipvs->sysctl_hdr); + ip_vs_stop_estimator(ipvs, &ipvs->tot_stats); + + if (!net_eq(net, &init_net)) + kfree(ipvs->sysctl_tbl); +} + +#else + +static int __net_init ip_vs_control_net_init_sysctl(struct netns_ipvs *ipvs) { return 0; } +static void __net_exit ip_vs_control_net_cleanup_sysctl(struct netns_ipvs *ipvs) { } + +#endif + +static struct notifier_block ip_vs_dst_notifier = { + .notifier_call = ip_vs_dst_event, +#ifdef CONFIG_IP_VS_IPV6 + .priority = ADDRCONF_NOTIFY_PRIORITY + 5, +#endif +}; + +int __net_init ip_vs_control_net_init(struct netns_ipvs *ipvs) +{ + int i, idx; + + /* Initialize rs_table */ + for (idx = 0; idx < IP_VS_RTAB_SIZE; idx++) + INIT_HLIST_HEAD(&ipvs->rs_table[idx]); + + INIT_LIST_HEAD(&ipvs->dest_trash); + spin_lock_init(&ipvs->dest_trash_lock); + timer_setup(&ipvs->dest_trash_timer, ip_vs_dest_trash_expire, 0); + atomic_set(&ipvs->ftpsvc_counter, 0); + atomic_set(&ipvs->nullsvc_counter, 0); + atomic_set(&ipvs->conn_out_counter, 0); + + /* procfs stats */ + ipvs->tot_stats.cpustats = alloc_percpu(struct ip_vs_cpu_stats); + if (!ipvs->tot_stats.cpustats) + return -ENOMEM; + + for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *ipvs_tot_stats; + ipvs_tot_stats = per_cpu_ptr(ipvs->tot_stats.cpustats, i); + u64_stats_init(&ipvs_tot_stats->syncp); + } + + spin_lock_init(&ipvs->tot_stats.lock); + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("ip_vs", 0, ipvs->net->proc_net, + &ip_vs_info_seq_ops, sizeof(struct ip_vs_iter))) + goto err_vs; + if (!proc_create_net_single("ip_vs_stats", 0, ipvs->net->proc_net, + ip_vs_stats_show, NULL)) + goto err_stats; + if (!proc_create_net_single("ip_vs_stats_percpu", 0, + ipvs->net->proc_net, + ip_vs_stats_percpu_show, NULL)) + goto err_percpu; +#endif + + if (ip_vs_control_net_init_sysctl(ipvs)) + goto err; + + return 0; + +err: +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip_vs_stats_percpu", ipvs->net->proc_net); + +err_percpu: + remove_proc_entry("ip_vs_stats", ipvs->net->proc_net); + +err_stats: + remove_proc_entry("ip_vs", ipvs->net->proc_net); + +err_vs: +#endif + free_percpu(ipvs->tot_stats.cpustats); + return -ENOMEM; +} + +void __net_exit ip_vs_control_net_cleanup(struct netns_ipvs *ipvs) +{ + ip_vs_trash_cleanup(ipvs); + ip_vs_control_net_cleanup_sysctl(ipvs); +#ifdef CONFIG_PROC_FS + remove_proc_entry("ip_vs_stats_percpu", ipvs->net->proc_net); + remove_proc_entry("ip_vs_stats", ipvs->net->proc_net); + remove_proc_entry("ip_vs", ipvs->net->proc_net); +#endif + free_percpu(ipvs->tot_stats.cpustats); +} + +int __init ip_vs_register_nl_ioctl(void) +{ + int ret; + + ret = nf_register_sockopt(&ip_vs_sockopts); + if (ret) { + pr_err("cannot register sockopt.\n"); + goto err_sock; + } + + ret = ip_vs_genl_register(); + if (ret) { + pr_err("cannot register Generic Netlink interface.\n"); + goto err_genl; + } + return 0; + +err_genl: + nf_unregister_sockopt(&ip_vs_sockopts); +err_sock: + return ret; +} + +void ip_vs_unregister_nl_ioctl(void) +{ + ip_vs_genl_unregister(); + nf_unregister_sockopt(&ip_vs_sockopts); +} + +int __init ip_vs_control_init(void) +{ + int idx; + int ret; + + EnterFunction(2); + + /* Initialize svc_table, ip_vs_svc_fwm_table */ + for (idx = 0; idx < IP_VS_SVC_TAB_SIZE; idx++) { + 
+		INIT_HLIST_HEAD(&ip_vs_svc_table[idx]);
+		INIT_HLIST_HEAD(&ip_vs_svc_fwm_table[idx]);
+	}
+
+	smp_wmb();	/* Do we really need it now ? */
+
+	ret = register_netdevice_notifier(&ip_vs_dst_notifier);
+	if (ret < 0)
+		return ret;
+
+	LeaveFunction(2);
+	return 0;
+}
+
+
+void ip_vs_control_cleanup(void)
+{
+	EnterFunction(2);
+	unregister_netdevice_notifier(&ip_vs_dst_notifier);
+	LeaveFunction(2);
+}
diff --git a/net/netfilter/ipvs/ip_vs_dh.c b/net/netfilter/ipvs/ip_vs_dh.c
new file mode 100644
index 000000000..5e6ec32af
--- /dev/null
+++ b/net/netfilter/ipvs/ip_vs_dh.c
@@ -0,0 +1,272 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * IPVS: Destination Hashing scheduling module
+ *
+ * Authors: Wensong Zhang
+ *
+ * Inspired by the consistent hashing scheduler patch from
+ * Thomas Proell
+ *
+ * Changes:
+ */
+
+/*
+ * The dh algorithm is to select server by the hash key of destination IP
+ * address. The pseudo code is as follows:
+ *
+ * n <- servernode[dest_ip];
+ * if (n is dead) OR
+ * (n is overloaded) OR (n.weight <= 0) then
+ * return NULL;
+ *
+ * return n;
+ *
+ * Notes that servernode is a 256-bucket hash table that maps the hash
+ * index derived from packet destination IP address to the current server
+ * array. If the dh scheduler is used in cache cluster, it is good to
+ * combine it with cache_bypass feature. When the statically assigned
+ * server is dead or overloaded, the load balancer can bypass the cache
+ * server and send requests to the original server directly.
+ *
+ */
+
+#define KMSG_COMPONENT "IPVS"
+#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt
+
+#include <linux/ip.h>
+#include <linux/slab.h>
+#include <linux/module.h>
+#include <linux/kernel.h>
+#include <linux/skbuff.h>
+#include <linux/hash.h>
+
+#include <net/ip_vs.h>
+
+
+/*
+ * IPVS DH bucket
+ */
+struct ip_vs_dh_bucket {
+	struct ip_vs_dest __rcu	*dest;	/* real server (cache) */
+};
+
+/*
+ * for IPVS DH entry hash table
+ */
+#ifndef CONFIG_IP_VS_DH_TAB_BITS
+#define CONFIG_IP_VS_DH_TAB_BITS 8
+#endif
+#define IP_VS_DH_TAB_BITS CONFIG_IP_VS_DH_TAB_BITS
+#define IP_VS_DH_TAB_SIZE (1 << IP_VS_DH_TAB_BITS)
+#define IP_VS_DH_TAB_MASK (IP_VS_DH_TAB_SIZE - 1)
+
+struct ip_vs_dh_state {
+	struct ip_vs_dh_bucket	buckets[IP_VS_DH_TAB_SIZE];
+	struct rcu_head		rcu_head;
+};
+
+/*
+ * Returns hash value for IPVS DH entry
+ */
+static inline unsigned int ip_vs_dh_hashkey(int af, const union nf_inet_addr *addr)
+{
+	__be32 addr_fold = addr->ip;
+
+#ifdef CONFIG_IP_VS_IPV6
+	if (af == AF_INET6)
+		addr_fold = addr->ip6[0]^addr->ip6[1]^
+			    addr->ip6[2]^addr->ip6[3];
+#endif
+	return hash_32(ntohl(addr_fold), IP_VS_DH_TAB_BITS);
+}
+
+
+/*
+ * Get ip_vs_dest associated with supplied parameters.
+ */
+static inline struct ip_vs_dest *
+ip_vs_dh_get(int af, struct ip_vs_dh_state *s, const union nf_inet_addr *addr)
+{
+	return rcu_dereference(s->buckets[ip_vs_dh_hashkey(af, addr)].dest);
+}
+
+
+/*
+ * Assign all the hash buckets of the specified table with the service.
+ */
+static int
+ip_vs_dh_reassign(struct ip_vs_dh_state *s, struct ip_vs_service *svc)
+{
+	int i;
+	struct ip_vs_dh_bucket *b;
+	struct list_head *p;
+	struct ip_vs_dest *dest;
+	bool empty;
+
+	b = &s->buckets[0];
+	p = &svc->destinations;
+	empty = list_empty(p);
+	for (i=0; i<IP_VS_DH_TAB_SIZE; i++) {
+		dest = rcu_dereference_protected(b->dest, 1);
+		if (dest)
+			ip_vs_dest_put(dest);
+		if (empty)
+			RCU_INIT_POINTER(b->dest, NULL);
+		else {
+			if (p == &svc->destinations)
+				p = p->next;
+
+			dest = list_entry(p, struct ip_vs_dest, n_list);
+			ip_vs_dest_hold(dest);
+			RCU_INIT_POINTER(b->dest, dest);
+
+			p = p->next;
+		}
+		b++;
+	}
+	return 0;
+}
+
+
+/*
+ * Flush all the hash buckets of the specified table.
+ */
+static void ip_vs_dh_flush(struct ip_vs_dh_state *s)
+{
+	int i;
+	struct ip_vs_dh_bucket *b;
+	struct ip_vs_dest *dest;
+
+	b = &s->buckets[0];
+	for (i=0; i<IP_VS_DH_TAB_SIZE; i++) {
+		dest = rcu_dereference_protected(b->dest, 1);
+		if (dest) {
+			ip_vs_dest_put(dest);
+			RCU_INIT_POINTER(b->dest, NULL);
+		}
+		b++;
+	}
+}
+
+
+static int ip_vs_dh_init_svc(struct ip_vs_service *svc)
+{
+	struct ip_vs_dh_state *s;
+
+	/* allocate the DH table for this service */
+	s = kzalloc(sizeof(struct ip_vs_dh_state), GFP_KERNEL);
+	if (s == NULL)
+		return -ENOMEM;
+
+	svc->sched_data = s;
+	IP_VS_DBG(6, "DH hash table (memory=%zdbytes) allocated for "
+		  "current service\n",
+		  sizeof(struct ip_vs_dh_bucket)*IP_VS_DH_TAB_SIZE);
+
+	/* assign the hash buckets with current dests */
+	ip_vs_dh_reassign(s, svc);
+
+	return 0;
+}
+
+
+static void ip_vs_dh_done_svc(struct ip_vs_service *svc)
+{
+	struct ip_vs_dh_state *s = svc->sched_data;
+
+	/* got to clean up hash buckets here */
+	ip_vs_dh_flush(s);
+
+	/* release the table itself */
+	kfree_rcu(s, rcu_head);
+	IP_VS_DBG(6, "DH hash table (memory=%zdbytes) released\n",
+		  sizeof(struct ip_vs_dh_bucket)*IP_VS_DH_TAB_SIZE);
+}
+
+
+static int ip_vs_dh_dest_changed(struct ip_vs_service *svc,
+				 struct ip_vs_dest *dest)
+{
+	struct ip_vs_dh_state *s = svc->sched_data;
+
+	/* assign the hash buckets with the updated service */
+	ip_vs_dh_reassign(s, svc);
+
+	return 0;
+}
+
+
+/*
+ * If the dest flags is set with IP_VS_DEST_F_OVERLOAD,
+ * consider that the server is overloaded here.
+ */
+static inline int is_overloaded(struct ip_vs_dest *dest)
+{
+	return dest->flags & IP_VS_DEST_F_OVERLOAD;
+}
+
+
+/*
+ * Destination hashing scheduling
+ */
+static struct ip_vs_dest *
+ip_vs_dh_schedule(struct ip_vs_service *svc, const struct sk_buff *skb,
+		  struct ip_vs_iphdr *iph)
+{
+	struct ip_vs_dest *dest;
+	struct ip_vs_dh_state *s;
+
+	IP_VS_DBG(6, "%s(): Scheduling...\n", __func__);
+
+	s = (struct ip_vs_dh_state *) svc->sched_data;
+	dest = ip_vs_dh_get(svc->af, s, &iph->daddr);
+	if (!dest
+	    || !(dest->flags & IP_VS_DEST_F_AVAILABLE)
+	    || atomic_read(&dest->weight) <= 0
+	    || is_overloaded(dest)) {
+		ip_vs_scheduler_err(svc, "no destination available");
+		return NULL;
+	}
+
+	IP_VS_DBG_BUF(6, "DH: destination IP address %s --> server %s:%d\n",
+		      IP_VS_DBG_ADDR(svc->af, &iph->daddr),
+		      IP_VS_DBG_ADDR(dest->af, &dest->addr),
+		      ntohs(dest->port));
+
+	return dest;
+}
+
+
+/*
+ * IPVS DH Scheduler structure
+ */
+static struct ip_vs_scheduler ip_vs_dh_scheduler =
+{
+	.name = "dh",
+	.refcnt = ATOMIC_INIT(0),
+	.module = THIS_MODULE,
+	.n_list = LIST_HEAD_INIT(ip_vs_dh_scheduler.n_list),
+	.init_service = ip_vs_dh_init_svc,
+	.done_service = ip_vs_dh_done_svc,
+	.add_dest = ip_vs_dh_dest_changed,
+	.del_dest = ip_vs_dh_dest_changed,
+	.schedule = ip_vs_dh_schedule,
+};
+
+
+static int __init ip_vs_dh_init(void)
+{
+	return register_ip_vs_scheduler(&ip_vs_dh_scheduler);
+}
+
+
+static void __exit ip_vs_dh_cleanup(void)
+{
+	unregister_ip_vs_scheduler(&ip_vs_dh_scheduler);
+	synchronize_rcu();
+}
+
+
+module_init(ip_vs_dh_init);
+module_exit(ip_vs_dh_cleanup);
+MODULE_LICENSE("GPL");
diff --git a/net/netfilter/ipvs/ip_vs_est.c b/net/netfilter/ipvs/ip_vs_est.c
new file mode 100644
index 000000000..f53150d82
--- /dev/null
+++ b/net/netfilter/ipvs/ip_vs_est.c
@@ -0,0 +1,204 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * ip_vs_est.c: simple rate estimator for IPVS
+ *
+ * Authors: Wensong Zhang
+ *
+ * Changes: Hans Schillstrom
+ * Network name space (netns) aware.
+ * Global data moved to netns i.e struct netns_ipvs + * Affected data: est_list and est_lock. + * estimation_timer() runs with timer per netns. + * get_stats()) do the per cpu summing. + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include + +#include + +/* + This code is to estimate rate in a shorter interval (such as 8 + seconds) for virtual services and real servers. For measure rate in a + long interval, it is easy to implement a user level daemon which + periodically reads those statistical counters and measure rate. + + Currently, the measurement is activated by slow timer handler. Hope + this measurement will not introduce too much load. + + We measure rate during the last 8 seconds every 2 seconds: + + avgrate = avgrate*(1-W) + rate*W + + where W = 2^(-2) + + NOTES. + + * Average bps is scaled by 2^5, while average pps and cps are scaled by 2^10. + + * Netlink users can see 64-bit values but sockopt users are restricted + to 32-bit values for conns, packets, bps, cps and pps. + + * A lot of code is taken from net/core/gen_estimator.c + */ + + +/* + * Make a summary from each cpu + */ +static void ip_vs_read_cpu_stats(struct ip_vs_kstats *sum, + struct ip_vs_cpu_stats __percpu *stats) +{ + int i; + bool add = false; + + for_each_possible_cpu(i) { + struct ip_vs_cpu_stats *s = per_cpu_ptr(stats, i); + unsigned int start; + u64 conns, inpkts, outpkts, inbytes, outbytes; + + if (add) { + do { + start = u64_stats_fetch_begin(&s->syncp); + conns = u64_stats_read(&s->cnt.conns); + inpkts = u64_stats_read(&s->cnt.inpkts); + outpkts = u64_stats_read(&s->cnt.outpkts); + inbytes = u64_stats_read(&s->cnt.inbytes); + outbytes = u64_stats_read(&s->cnt.outbytes); + } while (u64_stats_fetch_retry(&s->syncp, start)); + sum->conns += conns; + sum->inpkts += inpkts; + sum->outpkts += outpkts; + sum->inbytes += inbytes; + sum->outbytes += outbytes; + } else { + add = true; + do { + start = u64_stats_fetch_begin(&s->syncp); + sum->conns = u64_stats_read(&s->cnt.conns); + sum->inpkts = u64_stats_read(&s->cnt.inpkts); + sum->outpkts = u64_stats_read(&s->cnt.outpkts); + sum->inbytes = u64_stats_read(&s->cnt.inbytes); + sum->outbytes = u64_stats_read(&s->cnt.outbytes); + } while (u64_stats_fetch_retry(&s->syncp, start)); + } + } +} + + +static void estimation_timer(struct timer_list *t) +{ + struct ip_vs_estimator *e; + struct ip_vs_stats *s; + u64 rate; + struct netns_ipvs *ipvs = from_timer(ipvs, t, est_timer); + + if (!sysctl_run_estimation(ipvs)) + goto skip; + + spin_lock(&ipvs->est_lock); + list_for_each_entry(e, &ipvs->est_list, list) { + s = container_of(e, struct ip_vs_stats, est); + + spin_lock(&s->lock); + ip_vs_read_cpu_stats(&s->kstats, s->cpustats); + + /* scaled by 2^10, but divided 2 seconds */ + rate = (s->kstats.conns - e->last_conns) << 9; + e->last_conns = s->kstats.conns; + e->cps += ((s64)rate - (s64)e->cps) >> 2; + + rate = (s->kstats.inpkts - e->last_inpkts) << 9; + e->last_inpkts = s->kstats.inpkts; + e->inpps += ((s64)rate - (s64)e->inpps) >> 2; + + rate = (s->kstats.outpkts - e->last_outpkts) << 9; + e->last_outpkts = s->kstats.outpkts; + e->outpps += ((s64)rate - (s64)e->outpps) >> 2; + + /* scaled by 2^5, but divided 2 seconds */ + rate = (s->kstats.inbytes - e->last_inbytes) << 4; + e->last_inbytes = s->kstats.inbytes; + e->inbps += ((s64)rate - (s64)e->inbps) >> 2; + + rate = (s->kstats.outbytes - e->last_outbytes) << 4; + e->last_outbytes = s->kstats.outbytes; + e->outbps += ((s64)rate 
- (s64)e->outbps) >> 2; + spin_unlock(&s->lock); + } + spin_unlock(&ipvs->est_lock); + +skip: + mod_timer(&ipvs->est_timer, jiffies + 2*HZ); +} + +void ip_vs_start_estimator(struct netns_ipvs *ipvs, struct ip_vs_stats *stats) +{ + struct ip_vs_estimator *est = &stats->est; + + INIT_LIST_HEAD(&est->list); + + spin_lock_bh(&ipvs->est_lock); + list_add(&est->list, &ipvs->est_list); + spin_unlock_bh(&ipvs->est_lock); +} + +void ip_vs_stop_estimator(struct netns_ipvs *ipvs, struct ip_vs_stats *stats) +{ + struct ip_vs_estimator *est = &stats->est; + + spin_lock_bh(&ipvs->est_lock); + list_del(&est->list); + spin_unlock_bh(&ipvs->est_lock); +} + +void ip_vs_zero_estimator(struct ip_vs_stats *stats) +{ + struct ip_vs_estimator *est = &stats->est; + struct ip_vs_kstats *k = &stats->kstats; + + /* reset counters, caller must hold the stats->lock lock */ + est->last_inbytes = k->inbytes; + est->last_outbytes = k->outbytes; + est->last_conns = k->conns; + est->last_inpkts = k->inpkts; + est->last_outpkts = k->outpkts; + est->cps = 0; + est->inpps = 0; + est->outpps = 0; + est->inbps = 0; + est->outbps = 0; +} + +/* Get decoded rates */ +void ip_vs_read_estimator(struct ip_vs_kstats *dst, struct ip_vs_stats *stats) +{ + struct ip_vs_estimator *e = &stats->est; + + dst->cps = (e->cps + 0x1FF) >> 10; + dst->inpps = (e->inpps + 0x1FF) >> 10; + dst->outpps = (e->outpps + 0x1FF) >> 10; + dst->inbps = (e->inbps + 0xF) >> 5; + dst->outbps = (e->outbps + 0xF) >> 5; +} + +int __net_init ip_vs_estimator_net_init(struct netns_ipvs *ipvs) +{ + INIT_LIST_HEAD(&ipvs->est_list); + spin_lock_init(&ipvs->est_lock); + timer_setup(&ipvs->est_timer, estimation_timer, 0); + mod_timer(&ipvs->est_timer, jiffies + 2 * HZ); + return 0; +} + +void __net_exit ip_vs_estimator_net_cleanup(struct netns_ipvs *ipvs) +{ + del_timer_sync(&ipvs->est_timer); +} diff --git a/net/netfilter/ipvs/ip_vs_fo.c b/net/netfilter/ipvs/ip_vs_fo.c new file mode 100644 index 000000000..b846cc385 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_fo.c @@ -0,0 +1,74 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Weighted Fail Over module + * + * Authors: Kenny Mathis + * + * Changes: + * Kenny Mathis : added initial functionality based on weight + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + +/* Weighted Fail Over Module */ +static struct ip_vs_dest * +ip_vs_fo_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *hweight = NULL; + int hw = 0; /* Track highest weight */ + + IP_VS_DBG(6, "ip_vs_fo_schedule(): Scheduling...\n"); + + /* Basic failover functionality + * Find virtual server with highest weight and send it traffic + */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD) && + atomic_read(&dest->weight) > hw) { + hweight = dest; + hw = atomic_read(&dest->weight); + } + } + + if (hweight) { + IP_VS_DBG_BUF(6, "FO: server %s:%u activeconns %d weight %d\n", + IP_VS_DBG_ADDR(hweight->af, &hweight->addr), + ntohs(hweight->port), + atomic_read(&hweight->activeconns), + atomic_read(&hweight->weight)); + return hweight; + } + + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; +} + +static struct ip_vs_scheduler ip_vs_fo_scheduler = { + .name = "fo", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_fo_scheduler.n_list), + .schedule = ip_vs_fo_schedule, +}; + +static int __init ip_vs_fo_init(void) 
+{ + return register_ip_vs_scheduler(&ip_vs_fo_scheduler); +} + +static void __exit ip_vs_fo_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_fo_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_fo_init); +module_exit(ip_vs_fo_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_ftp.c b/net/netfilter/ipvs/ip_vs_ftp.c new file mode 100644 index 000000000..ef1f45e43 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_ftp.c @@ -0,0 +1,637 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_ftp.c: IPVS ftp application module + * + * Authors: Wensong Zhang + * + * Changes: + * + * Most code here is taken from ip_masq_ftp.c in kernel 2.2. The difference + * is that ip_vs_ftp module handles the reverse direction to ip_masq_ftp. + * + * IP_MASQ_FTP ftp masquerading module + * + * Version: @(#)ip_masq_ftp.c 0.04 02/05/96 + * + * Author: Wouter Gadeyne + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +#define SERVER_STRING_PASV "227 " +#define CLIENT_STRING_PORT "PORT" +#define SERVER_STRING_EPSV "229 " +#define CLIENT_STRING_EPRT "EPRT" + +enum { + IP_VS_FTP_ACTIVE = 0, + IP_VS_FTP_PORT = 0, + IP_VS_FTP_PASV, + IP_VS_FTP_EPRT, + IP_VS_FTP_EPSV, +}; + +/* + * List of ports (up to IP_VS_APP_MAX_PORTS) to be handled by helper + * First port is set to the default port. + */ +static unsigned int ports_count = 1; +static unsigned short ports[IP_VS_APP_MAX_PORTS] = {21, 0}; +module_param_array(ports, ushort, &ports_count, 0444); +MODULE_PARM_DESC(ports, "Ports to monitor for FTP control commands"); + + +static char *ip_vs_ftp_data_ptr(struct sk_buff *skb, struct ip_vs_iphdr *ipvsh) +{ + struct tcphdr *th = (struct tcphdr *)((char *)skb->data + ipvsh->len); + + if ((th->doff << 2) < sizeof(struct tcphdr)) + return NULL; + + return (char *)th + (th->doff << 2); +} + +static int +ip_vs_ftp_init_conn(struct ip_vs_app *app, struct ip_vs_conn *cp) +{ + /* We use connection tracking for the command connection */ + cp->flags |= IP_VS_CONN_F_NFCT; + return 0; +} + + +static int +ip_vs_ftp_done_conn(struct ip_vs_app *app, struct ip_vs_conn *cp) +{ + return 0; +} + + +/* Get from the string "xxx.xxx.xxx.xxx,ppp,ppp", started + * with the "pattern". is in network order. + * Parse extended format depending on ext. In this case addr can be pre-set. + */ +static int ip_vs_ftp_get_addrport(char *data, char *data_limit, + const char *pattern, size_t plen, + char skip, bool ext, int mode, + union nf_inet_addr *addr, __be16 *port, + __u16 af, char **start, char **end) +{ + char *s, c; + unsigned char p[6]; + char edelim; + __u16 hport; + int i = 0; + + if (data_limit - data < plen) { + /* check if there is partial match */ + if (strncasecmp(data, pattern, data_limit - data) == 0) + return -1; + else + return 0; + } + + if (strncasecmp(data, pattern, plen) != 0) { + return 0; + } + s = data + plen; + if (skip) { + bool found = false; + + for (;; s++) { + if (s == data_limit) + return -1; + if (!found) { + /* "(" is optional for non-extended format, + * so catch the start of IPv4 address + */ + if (!ext && isdigit(*s)) + break; + if (*s == skip) + found = true; + } else if (*s != skip) { + break; + } + } + } + /* Old IPv4-only format? 
*/ + if (!ext) { + p[0] = 0; + for (data = s; ; data++) { + if (data == data_limit) + return -1; + c = *data; + if (isdigit(c)) { + p[i] = p[i]*10 + c - '0'; + } else if (c == ',' && i < 5) { + i++; + p[i] = 0; + } else { + /* unexpected character or terminator */ + break; + } + } + + if (i != 5) + return -1; + + *start = s; + *end = data; + addr->ip = get_unaligned((__be32 *) p); + *port = get_unaligned((__be16 *) (p + 4)); + return 1; + } + if (s == data_limit) + return -1; + *start = s; + edelim = *s++; + if (edelim < 33 || edelim > 126) + return -1; + if (s == data_limit) + return -1; + if (*s == edelim) { + /* Address family is usually missing for EPSV response */ + if (mode != IP_VS_FTP_EPSV) + return -1; + s++; + if (s == data_limit) + return -1; + /* Then address should be missing too */ + if (*s != edelim) + return -1; + /* Caller can pre-set addr, if needed */ + s++; + } else { + const char *ep; + + /* We allow address only from same family */ + if (af == AF_INET6 && *s != '2') + return -1; + if (af == AF_INET && *s != '1') + return -1; + s++; + if (s == data_limit) + return -1; + if (*s != edelim) + return -1; + s++; + if (s == data_limit) + return -1; + if (af == AF_INET6) { + if (in6_pton(s, data_limit - s, (u8 *)addr, edelim, + &ep) <= 0) + return -1; + } else { + if (in4_pton(s, data_limit - s, (u8 *)addr, edelim, + &ep) <= 0) + return -1; + } + s = (char *) ep; + if (s == data_limit) + return -1; + if (*s != edelim) + return -1; + s++; + } + for (hport = 0; ; s++) + { + if (s == data_limit) + return -1; + if (!isdigit(*s)) + break; + hport = hport * 10 + *s - '0'; + } + if (s == data_limit || !hport || *s != edelim) + return -1; + s++; + *end = s; + *port = htons(hport); + return 1; +} + +/* Look at outgoing ftp packets to catch the response to a PASV/EPSV command + * from the server (inside-to-outside). + * When we see one, we build a connection entry with the client address, + * client port 0 (unknown at the moment), the server address and the + * server port. Mark the current connection entry as a control channel + * of the new entry. All this work is just to make the data connection + * can be scheduled to the right server later. + * + * The outgoing packet should be something like + * "227 Entering Passive Mode (xxx,xxx,xxx,xxx,ppp,ppp)". + * xxx,xxx,xxx,xxx is the server address, ppp,ppp is the server port number. + * The extended format for EPSV response provides usually only port: + * "229 Entering Extended Passive Mode (|||ppp|)" + */ +static int ip_vs_ftp_out(struct ip_vs_app *app, struct ip_vs_conn *cp, + struct sk_buff *skb, int *diff, + struct ip_vs_iphdr *ipvsh) +{ + char *data, *data_limit; + char *start, *end; + union nf_inet_addr from; + __be16 port; + struct ip_vs_conn *n_cp; + char buf[24]; /* xxx.xxx.xxx.xxx,ppp,ppp\000 */ + unsigned int buf_len; + int ret = 0; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + *diff = 0; + + /* Only useful for established sessions */ + if (cp->state != IP_VS_TCP_S_ESTABLISHED) + return 1; + + /* Linear packets are much easier to deal with. 
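The extended (EPRT/EPSV) branch instead splits fields on a caller-chosen printable delimiter, "<d><proto><d><address><d><port><d>", where an EPSV reply usually leaves the protocol and address fields empty. A standalone sketch of that splitting, with minimal error handling and hypothetical inputs:

#include <stdio.h>
#include <string.h>
#include <stdlib.h>

static int parse_extended(const char *s, int *proto, char *addr, size_t alen,
                          unsigned short *port)
{
        char d = s[0];
        char *end;
        const char *f1, *f2, *f3, *f4;

        if (d < 33 || d > 126)
                return -1;
        f1 = s + 1;                     /* protocol: "1" (IPv4) or "2" (IPv6) */
        f2 = strchr(f1, d);             /* delimiter ending the protocol field */
        if (!f2)
                return -1;
        f3 = strchr(f2 + 1, d);         /* delimiter ending the address field */
        if (!f3)
                return -1;
        f4 = strchr(f3 + 1, d);         /* trailing delimiter after the port */
        if (!f4)
                return -1;

        *proto = atoi(f1);              /* 0 if the field is empty (EPSV reply) */
        if ((size_t)(f3 - f2 - 1) >= alen)
                return -1;
        memcpy(addr, f2 + 1, f3 - f2 - 1);
        addr[f3 - f2 - 1] = '\0';
        *port = (unsigned short)strtoul(f3 + 1, &end, 10);
        return 0;
}

int main(void)
{
        int proto;
        char addr[64];
        unsigned short port;

        if (parse_extended("|1|192.168.0.10|8000|", &proto, addr, sizeof(addr), &port) == 0)
                printf("proto %d addr %s port %u\n", proto, addr, port);
        if (parse_extended("|||8021|", &proto, addr, sizeof(addr), &port) == 0)
                printf("EPSV-style reply: port %u\n", port);
        return 0;
}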
*/ + if (skb_ensure_writable(skb, skb->len)) + return 0; + + if (cp->app_data == (void *) IP_VS_FTP_PASV) { + data = ip_vs_ftp_data_ptr(skb, ipvsh); + data_limit = skb_tail_pointer(skb); + + if (!data || data >= data_limit) + return 1; + + if (ip_vs_ftp_get_addrport(data, data_limit, + SERVER_STRING_PASV, + sizeof(SERVER_STRING_PASV)-1, + '(', false, IP_VS_FTP_PASV, + &from, &port, cp->af, + &start, &end) != 1) + return 1; + + IP_VS_DBG(7, "PASV response (%pI4:%u) -> %pI4:%u detected\n", + &from.ip, ntohs(port), &cp->caddr.ip, 0); + } else if (cp->app_data == (void *) IP_VS_FTP_EPSV) { + data = ip_vs_ftp_data_ptr(skb, ipvsh); + data_limit = skb_tail_pointer(skb); + + if (!data || data >= data_limit) + return 1; + + /* Usually, data address is not specified but + * we support different address, so pre-set it. + */ + from = cp->daddr; + if (ip_vs_ftp_get_addrport(data, data_limit, + SERVER_STRING_EPSV, + sizeof(SERVER_STRING_EPSV)-1, + '(', true, IP_VS_FTP_EPSV, + &from, &port, cp->af, + &start, &end) != 1) + return 1; + + IP_VS_DBG_BUF(7, "EPSV response (%s:%u) -> %s:%u detected\n", + IP_VS_DBG_ADDR(cp->af, &from), ntohs(port), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), 0); + } else { + return 1; + } + + /* Now update or create a connection entry for it */ + { + struct ip_vs_conn_param p; + + ip_vs_conn_fill_param(cp->ipvs, cp->af, + ipvsh->protocol, &from, port, + &cp->caddr, 0, &p); + n_cp = ip_vs_conn_out_get(&p); + } + if (!n_cp) { + struct ip_vs_conn_param p; + + ip_vs_conn_fill_param(cp->ipvs, + cp->af, ipvsh->protocol, &cp->caddr, + 0, &cp->vaddr, port, &p); + n_cp = ip_vs_conn_new(&p, cp->af, &from, port, + IP_VS_CONN_F_NO_CPORT | + IP_VS_CONN_F_NFCT, + cp->dest, skb->mark); + if (!n_cp) + return 0; + + /* add its controller */ + ip_vs_control_add(n_cp, cp); + } + + /* Replace the old passive address with the new one */ + if (cp->app_data == (void *) IP_VS_FTP_PASV) { + from.ip = n_cp->vaddr.ip; + port = n_cp->vport; + snprintf(buf, sizeof(buf), "%u,%u,%u,%u,%u,%u", + ((unsigned char *)&from.ip)[0], + ((unsigned char *)&from.ip)[1], + ((unsigned char *)&from.ip)[2], + ((unsigned char *)&from.ip)[3], + ntohs(port) >> 8, + ntohs(port) & 0xFF); + } else if (cp->app_data == (void *) IP_VS_FTP_EPSV) { + from = n_cp->vaddr; + port = n_cp->vport; + /* Only port, client will use VIP for the data connection */ + snprintf(buf, sizeof(buf), "|||%u|", + ntohs(port)); + } else { + *buf = 0; + } + buf_len = strlen(buf); + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + bool mangled; + + /* If mangling fails this function will return 0 + * which will cause the packet to be dropped. + * Mangling can only fail under memory pressure, + * hopefully it will succeed on the retransmitted + * packet. + */ + mangled = nf_nat_mangle_tcp_packet(skb, ct, ctinfo, + ipvsh->len, + start - data, + end - start, + buf, buf_len); + if (mangled) { + ip_vs_nfct_expect_related(skb, ct, n_cp, + ipvsh->protocol, 0, 0); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = CHECKSUM_UNNECESSARY; + /* csum is updated */ + ret = 1; + } + } + + /* Not setting 'diff' is intentional, otherwise the sequence + * would be adjusted twice. + */ + + cp->app_data = (void *) IP_VS_FTP_ACTIVE; + ip_vs_tcp_conn_listen(n_cp); + ip_vs_conn_put(n_cp); + return ret; +} + + +/* Look at incoming ftp packets to catch the PASV/PORT/EPRT/EPSV command + * (outside-to-inside). + * + * The incoming packet having the PORT command should be something like + * "PORT xxx,xxx,xxx,xxx,ppp,ppp\n". 
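When substituting the virtual address into a PASV reply, ip_vs_ftp_out() rebuilds the comma tuple from the network-order VIP and port, splitting the port back into its high and low bytes. A standalone sketch of that encoding step (the address and port values are examples):

#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

static int format_pasv_tuple(char *buf, size_t len, uint32_t addr_be, uint16_t port_be)
{
        const unsigned char *a = (const unsigned char *)&addr_be;
        uint16_t port = ntohs(port_be);

        /* address bytes are already in wire order, port is split MSB,LSB */
        return snprintf(buf, len, "%u,%u,%u,%u,%u,%u",
                        a[0], a[1], a[2], a[3], port >> 8, port & 0xFF);
}

int main(void)
{
        char buf[24];   /* "xxx,xxx,xxx,xxx,ppp,ppp" + NUL, as in the kernel code */
        struct in_addr vip;

        inet_pton(AF_INET, "10.0.0.1", &vip);
        format_pasv_tuple(buf, sizeof(buf), vip.s_addr, htons(8000));
        printf("227 Entering Passive Mode (%s)\n", buf);
        return 0;
}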
+ * xxx,xxx,xxx,xxx is the client address, ppp,ppp is the client port number. + * In this case, we create a connection entry using the client address and + * port, so that the active ftp data connection from the server can reach + * the client. + * Extended format: + * "EPSV\r\n" when client requests server address from same family + * "EPSV 1\r\n" when client requests IPv4 server address + * "EPSV 2\r\n" when client requests IPv6 server address + * "EPSV ALL\r\n" - not supported + * EPRT with specified delimiter (ASCII 33..126), "|" by default: + * "EPRT |1|IPv4ADDR|PORT|\r\n" when client provides IPv4 addrport + * "EPRT |2|IPv6ADDR|PORT|\r\n" when client provides IPv6 addrport + */ +static int ip_vs_ftp_in(struct ip_vs_app *app, struct ip_vs_conn *cp, + struct sk_buff *skb, int *diff, + struct ip_vs_iphdr *ipvsh) +{ + char *data, *data_start, *data_limit; + char *start, *end; + union nf_inet_addr to; + __be16 port; + struct ip_vs_conn *n_cp; + + /* no diff required for incoming packets */ + *diff = 0; + + /* Only useful for established sessions */ + if (cp->state != IP_VS_TCP_S_ESTABLISHED) + return 1; + + /* Linear packets are much easier to deal with. */ + if (skb_ensure_writable(skb, skb->len)) + return 0; + + data = data_start = ip_vs_ftp_data_ptr(skb, ipvsh); + data_limit = skb_tail_pointer(skb); + if (!data || data >= data_limit) + return 1; + + while (data <= data_limit - 6) { + if (cp->af == AF_INET && + strncasecmp(data, "PASV\r\n", 6) == 0) { + /* Passive mode on */ + IP_VS_DBG(7, "got PASV at %td of %td\n", + data - data_start, + data_limit - data_start); + cp->app_data = (void *) IP_VS_FTP_PASV; + return 1; + } + + /* EPSV or EPSV */ + if (strncasecmp(data, "EPSV", 4) == 0 && + (data[4] == ' ' || data[4] == '\r')) { + if (data[4] == ' ') { + char proto = data[5]; + + if (data > data_limit - 7 || data[6] != '\r') + return 1; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && proto == '2') { + } else +#endif + if (cp->af == AF_INET && proto == '1') { + } else { + return 1; + } + } + /* Extended Passive mode on */ + IP_VS_DBG(7, "got EPSV at %td of %td\n", + data - data_start, + data_limit - data_start); + cp->app_data = (void *) IP_VS_FTP_EPSV; + return 1; + } + + data++; + } + + /* + * To support virtual FTP server, the scenerio is as follows: + * FTP client ----> Load Balancer ----> FTP server + * First detect the port number in the application data, + * then create a new connection entry for the coming data + * connection. 
+ */ + if (cp->af == AF_INET && + ip_vs_ftp_get_addrport(data_start, data_limit, + CLIENT_STRING_PORT, + sizeof(CLIENT_STRING_PORT)-1, + ' ', false, IP_VS_FTP_PORT, + &to, &port, cp->af, + &start, &end) == 1) { + + IP_VS_DBG(7, "PORT %pI4:%u detected\n", &to.ip, ntohs(port)); + + /* Now update or create a connection entry for it */ + IP_VS_DBG(7, "protocol %s %pI4:%u %pI4:%u\n", + ip_vs_proto_name(ipvsh->protocol), + &to.ip, ntohs(port), &cp->vaddr.ip, + ntohs(cp->vport)-1); + } else if (ip_vs_ftp_get_addrport(data_start, data_limit, + CLIENT_STRING_EPRT, + sizeof(CLIENT_STRING_EPRT)-1, + ' ', true, IP_VS_FTP_EPRT, + &to, &port, cp->af, + &start, &end) == 1) { + + IP_VS_DBG_BUF(7, "EPRT %s:%u detected\n", + IP_VS_DBG_ADDR(cp->af, &to), ntohs(port)); + + /* Now update or create a connection entry for it */ + IP_VS_DBG_BUF(7, "protocol %s %s:%u %s:%u\n", + ip_vs_proto_name(ipvsh->protocol), + IP_VS_DBG_ADDR(cp->af, &to), ntohs(port), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), + ntohs(cp->vport)-1); + } else { + return 1; + } + + /* Passive mode off */ + cp->app_data = (void *) IP_VS_FTP_ACTIVE; + + { + struct ip_vs_conn_param p; + ip_vs_conn_fill_param(cp->ipvs, cp->af, + ipvsh->protocol, &to, port, &cp->vaddr, + htons(ntohs(cp->vport)-1), &p); + n_cp = ip_vs_conn_in_get(&p); + if (!n_cp) { + n_cp = ip_vs_conn_new(&p, cp->af, &cp->daddr, + htons(ntohs(cp->dport)-1), + IP_VS_CONN_F_NFCT, cp->dest, + skb->mark); + if (!n_cp) + return 0; + + /* add its controller */ + ip_vs_control_add(n_cp, cp); + } + } + + /* + * Move tunnel to listen state + */ + ip_vs_tcp_conn_listen(n_cp); + ip_vs_conn_put(n_cp); + + return 1; +} + + +static struct ip_vs_app ip_vs_ftp = { + .name = "ftp", + .type = IP_VS_APP_TYPE_FTP, + .protocol = IPPROTO_TCP, + .module = THIS_MODULE, + .incs_list = LIST_HEAD_INIT(ip_vs_ftp.incs_list), + .init_conn = ip_vs_ftp_init_conn, + .done_conn = ip_vs_ftp_done_conn, + .bind_conn = NULL, + .unbind_conn = NULL, + .pkt_out = ip_vs_ftp_out, + .pkt_in = ip_vs_ftp_in, +}; + +/* + * per netns ip_vs_ftp initialization + */ +static int __net_init __ip_vs_ftp_init(struct net *net) +{ + int i, ret; + struct ip_vs_app *app; + struct netns_ipvs *ipvs = net_ipvs(net); + + if (!ipvs) + return -ENOENT; + + app = register_ip_vs_app(ipvs, &ip_vs_ftp); + if (IS_ERR(app)) + return PTR_ERR(app); + + for (i = 0; i < ports_count; i++) { + if (!ports[i]) + continue; + ret = register_ip_vs_app_inc(ipvs, app, app->protocol, ports[i]); + if (ret) + goto err_unreg; + } + return 0; + +err_unreg: + unregister_ip_vs_app(ipvs, &ip_vs_ftp); + return ret; +} +/* + * netns exit + */ +static void __ip_vs_ftp_exit(struct net *net) +{ + struct netns_ipvs *ipvs = net_ipvs(net); + + if (!ipvs) + return; + + unregister_ip_vs_app(ipvs, &ip_vs_ftp); +} + +static struct pernet_operations ip_vs_ftp_ops = { + .init = __ip_vs_ftp_init, + .exit = __ip_vs_ftp_exit, +}; + +static int __init ip_vs_ftp_init(void) +{ + /* rcu_barrier() is called by netns on error */ + return register_pernet_subsys(&ip_vs_ftp_ops); +} + +/* + * ip_vs_ftp finish. 
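For active transfers the data connection is expected on the port just below the control port (ftp-data 20 next to ftp 21), which is why the code above creates the new connection with htons(ntohs(port) - 1). A tiny standalone sketch of that byte-order-aware arithmetic:

#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

static uint16_t data_port_be(uint16_t ctrl_port_be)
{
        /* convert out of network order, subtract one, convert back */
        return htons(ntohs(ctrl_port_be) - 1);
}

int main(void)
{
        uint16_t ctrl = htons(21);

        printf("control %u -> data %u\n", ntohs(ctrl), ntohs(data_port_be(ctrl)));
        return 0;
}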
+ */ +static void __exit ip_vs_ftp_exit(void) +{ + unregister_pernet_subsys(&ip_vs_ftp_ops); + /* rcu_barrier() is called by netns */ +} + + +module_init(ip_vs_ftp_init); +module_exit(ip_vs_ftp_exit); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_lblc.c b/net/netfilter/ipvs/ip_vs_lblc.c new file mode 100644 index 000000000..7ac7473e3 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_lblc.c @@ -0,0 +1,630 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Locality-Based Least-Connection scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + * Martin Hamilton : fixed the terrible locking bugs + * *lock(tbl->lock) ==> *lock(&tbl->lock) + * Wensong Zhang : fixed the uninitialized tbl->lock bug + * Wensong Zhang : added doing full expiration check to + * collect stale entries of 24+ hours when + * no partial expire check in a half hour + * Julian Anastasov : replaced del_timer call with del_timer_sync + * to avoid the possible race between timer + * handler and del_timer thread in SMP + */ + +/* + * The lblc algorithm is as follows (pseudo code): + * + * if cachenode[dest_ip] is null then + * n, cachenode[dest_ip] <- {weighted least-conn node}; + * else + * n <- cachenode[dest_ip]; + * if (n is dead) OR + * (n.conns>n.weight AND + * there is a node m with m.conns +#include +#include +#include +#include +#include +#include + +/* for sysctl */ +#include +#include + +#include + + +/* + * It is for garbage collection of stale IPVS lblc entries, + * when the table is full. + */ +#define CHECK_EXPIRE_INTERVAL (60*HZ) +#define ENTRY_TIMEOUT (6*60*HZ) + +#define DEFAULT_EXPIRATION (24*60*60*HZ) + +/* + * It is for full expiration check. + * When there is no partial expiration check (garbage collection) + * in a half hour, do a full expiration check to collect stale + * entries that haven't been touched for a day. 
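Putting the lblc pseudo code above into runnable form: keep the cached server for a destination IP unless it is dead, or it is past its weight while some other server sits below half of its own. A standalone sketch with hypothetical types and toy numbers (the kernel works on struct ip_vs_dest and its atomic counters):

#include <stdio.h>
#include <stdbool.h>

struct server {
        const char *name;
        int conns;
        int weight;
        bool alive;
};

static struct server *pick_weighted_least_conn(struct server *srv, int n)
{
        struct server *best = NULL;

        for (int i = 0; i < n; i++) {
                if (!srv[i].alive || srv[i].weight <= 0)
                        continue;
                /* conns/weight comparison done by cross-multiplication */
                if (!best || (long long)srv[i].conns * best->weight <
                             (long long)best->conns * srv[i].weight)
                        best = &srv[i];
        }
        return best;
}

static struct server *lblc_pick(struct server *cached, struct server *srv, int n)
{
        bool someone_underloaded = false;

        for (int i = 0; i < n; i++)
                if (srv[i].alive && srv[i].conns * 2 < srv[i].weight)
                        someone_underloaded = true;

        if (cached && cached->alive &&
            !(cached->conns > cached->weight && someone_underloaded))
                return cached;                          /* keep locality */
        return pick_weighted_least_conn(srv, n);        /* rebalance */
}

int main(void)
{
        struct server srv[] = {
                { "rs1", 90, 50, true },        /* overloaded cache entry */
                { "rs2", 10, 50, true },        /* lightly loaded */
        };
        struct server *n = lblc_pick(&srv[0], srv, 2);

        printf("scheduled to %s\n", n->name);
        return 0;
}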
+ */ +#define COUNT_FOR_FULL_EXPIRATION 30 + + +/* + * for IPVS lblc entry hash table + */ +#ifndef CONFIG_IP_VS_LBLC_TAB_BITS +#define CONFIG_IP_VS_LBLC_TAB_BITS 10 +#endif +#define IP_VS_LBLC_TAB_BITS CONFIG_IP_VS_LBLC_TAB_BITS +#define IP_VS_LBLC_TAB_SIZE (1 << IP_VS_LBLC_TAB_BITS) +#define IP_VS_LBLC_TAB_MASK (IP_VS_LBLC_TAB_SIZE - 1) + + +/* + * IPVS lblc entry represents an association between destination + * IP address and its destination server + */ +struct ip_vs_lblc_entry { + struct hlist_node list; + int af; /* address family */ + union nf_inet_addr addr; /* destination IP address */ + struct ip_vs_dest *dest; /* real server (cache) */ + unsigned long lastuse; /* last used time */ + struct rcu_head rcu_head; +}; + + +/* + * IPVS lblc hash table + */ +struct ip_vs_lblc_table { + struct rcu_head rcu_head; + struct hlist_head bucket[IP_VS_LBLC_TAB_SIZE]; /* hash bucket */ + struct timer_list periodic_timer; /* collect stale entries */ + struct ip_vs_service *svc; /* pointer back to service */ + atomic_t entries; /* number of entries */ + int max_size; /* maximum size of entries */ + int rover; /* rover for expire check */ + int counter; /* counter for no expire */ + bool dead; +}; + + +/* + * IPVS LBLC sysctl table + */ +#ifdef CONFIG_SYSCTL +static struct ctl_table vs_vars_table[] = { + { + .procname = "lblc_expiration", + .data = NULL, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; +#endif + +static void ip_vs_lblc_rcu_free(struct rcu_head *head) +{ + struct ip_vs_lblc_entry *en = container_of(head, + struct ip_vs_lblc_entry, + rcu_head); + + ip_vs_dest_put_and_free(en->dest); + kfree(en); +} + +static inline void ip_vs_lblc_del(struct ip_vs_lblc_entry *en) +{ + hlist_del_rcu(&en->list); + call_rcu(&en->rcu_head, ip_vs_lblc_rcu_free); +} + +/* + * Returns hash value for IPVS LBLC entry + */ +static inline unsigned int +ip_vs_lblc_hashkey(int af, const union nf_inet_addr *addr) +{ + __be32 addr_fold = addr->ip; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0]^addr->ip6[1]^ + addr->ip6[2]^addr->ip6[3]; +#endif + return hash_32(ntohl(addr_fold), IP_VS_LBLC_TAB_BITS); +} + + +/* + * Hash an entry in the ip_vs_lblc_table. + * returns bool success. + */ +static void +ip_vs_lblc_hash(struct ip_vs_lblc_table *tbl, struct ip_vs_lblc_entry *en) +{ + unsigned int hash = ip_vs_lblc_hashkey(en->af, &en->addr); + + hlist_add_head_rcu(&en->list, &tbl->bucket[hash]); + atomic_inc(&tbl->entries); +} + + +/* Get ip_vs_lblc_entry associated with supplied parameters. */ +static inline struct ip_vs_lblc_entry * +ip_vs_lblc_get(int af, struct ip_vs_lblc_table *tbl, + const union nf_inet_addr *addr) +{ + unsigned int hash = ip_vs_lblc_hashkey(af, addr); + struct ip_vs_lblc_entry *en; + + hlist_for_each_entry_rcu(en, &tbl->bucket[hash], list) + if (ip_vs_addr_equal(af, &en->addr, addr)) + return en; + + return NULL; +} + + +/* + * Create or update an ip_vs_lblc_entry, which is a mapping of a destination IP + * address to a server. Called under spin lock. 
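ip_vs_lblc_hashkey() folds an IPv6 address down to 32 bits by XOR-ing its four words and then keeps the top TAB_BITS bits of a multiplicative hash. A standalone sketch that mimics the kernel's hash_32() with the same golden-ratio constant (the example address is arbitrary):

#include <stdio.h>
#include <stdint.h>

#define TAB_BITS 10
#define GOLDEN_RATIO_32 0x61C88647U     /* constant used by the kernel's hash_32() */

static uint32_t hash32(uint32_t val, unsigned int bits)
{
        /* multiply, then keep the highest 'bits' bits as the bucket index */
        return (val * GOLDEN_RATIO_32) >> (32 - bits);
}

static uint32_t fold_ipv6(const uint32_t a[4])
{
        return a[0] ^ a[1] ^ a[2] ^ a[3];
}

int main(void)
{
        uint32_t v6[4] = { 0x20010db8, 0x00000000, 0x00000000, 0x00000001 };

        printf("bucket = %u (of %u)\n", hash32(fold_ipv6(v6), TAB_BITS), 1U << TAB_BITS);
        return 0;
}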
+ */ +static inline struct ip_vs_lblc_entry * +ip_vs_lblc_new(struct ip_vs_lblc_table *tbl, const union nf_inet_addr *daddr, + u16 af, struct ip_vs_dest *dest) +{ + struct ip_vs_lblc_entry *en; + + en = ip_vs_lblc_get(af, tbl, daddr); + if (en) { + if (en->dest == dest) + return en; + ip_vs_lblc_del(en); + } + en = kmalloc(sizeof(*en), GFP_ATOMIC); + if (!en) + return NULL; + + en->af = af; + ip_vs_addr_copy(af, &en->addr, daddr); + en->lastuse = jiffies; + + ip_vs_dest_hold(dest); + en->dest = dest; + + ip_vs_lblc_hash(tbl, en); + + return en; +} + + +/* + * Flush all the entries of the specified table. + */ +static void ip_vs_lblc_flush(struct ip_vs_service *svc) +{ + struct ip_vs_lblc_table *tbl = svc->sched_data; + struct ip_vs_lblc_entry *en; + struct hlist_node *next; + int i; + + spin_lock_bh(&svc->sched_lock); + tbl->dead = true; + for (i = 0; i < IP_VS_LBLC_TAB_SIZE; i++) { + hlist_for_each_entry_safe(en, next, &tbl->bucket[i], list) { + ip_vs_lblc_del(en); + atomic_dec(&tbl->entries); + } + } + spin_unlock_bh(&svc->sched_lock); +} + +static int sysctl_lblc_expiration(struct ip_vs_service *svc) +{ +#ifdef CONFIG_SYSCTL + return svc->ipvs->sysctl_lblc_expiration; +#else + return DEFAULT_EXPIRATION; +#endif +} + +static inline void ip_vs_lblc_full_check(struct ip_vs_service *svc) +{ + struct ip_vs_lblc_table *tbl = svc->sched_data; + struct ip_vs_lblc_entry *en; + struct hlist_node *next; + unsigned long now = jiffies; + int i, j; + + for (i = 0, j = tbl->rover; i < IP_VS_LBLC_TAB_SIZE; i++) { + j = (j + 1) & IP_VS_LBLC_TAB_MASK; + + spin_lock(&svc->sched_lock); + hlist_for_each_entry_safe(en, next, &tbl->bucket[j], list) { + if (time_before(now, + en->lastuse + + sysctl_lblc_expiration(svc))) + continue; + + ip_vs_lblc_del(en); + atomic_dec(&tbl->entries); + } + spin_unlock(&svc->sched_lock); + } + tbl->rover = j; +} + + +/* + * Periodical timer handler for IPVS lblc table + * It is used to collect stale entries when the number of entries + * exceeds the maximum size of the table. + * + * Fixme: we probably need more complicated algorithm to collect + * entries that have not been used for a long time even + * if the number of entries doesn't exceed the maximum size + * of the table. + * The full expiration check is for this purpose now. 
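The expiry tests above compare jiffies with time_before()/time_after(), which stay correct even when the counter wraps because the difference is evaluated as a signed number. A standalone sketch of the same trick on an ordinary unsigned counter:

#include <stdio.h>
#include <stdbool.h>

typedef unsigned long jiffies_t;

/* true if a is before b, even if the counter has wrapped around */
static bool time_before(jiffies_t a, jiffies_t b)
{
        return (long)(a - b) < 0;
}

int main(void)
{
        jiffies_t now = (jiffies_t)-100;        /* counter close to wrapping */
        jiffies_t lastuse = now - 500;
        jiffies_t timeout = 1000;

        if (time_before(now, lastuse + timeout))
                printf("entry still fresh, keep it\n");
        else
                printf("entry expired, drop it\n");
        return 0;
}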
+ */ +static void ip_vs_lblc_check_expire(struct timer_list *t) +{ + struct ip_vs_lblc_table *tbl = from_timer(tbl, t, periodic_timer); + struct ip_vs_service *svc = tbl->svc; + unsigned long now = jiffies; + int goal; + int i, j; + struct ip_vs_lblc_entry *en; + struct hlist_node *next; + + if ((tbl->counter % COUNT_FOR_FULL_EXPIRATION) == 0) { + /* do full expiration check */ + ip_vs_lblc_full_check(svc); + tbl->counter = 1; + goto out; + } + + if (atomic_read(&tbl->entries) <= tbl->max_size) { + tbl->counter++; + goto out; + } + + goal = (atomic_read(&tbl->entries) - tbl->max_size)*4/3; + if (goal > tbl->max_size/2) + goal = tbl->max_size/2; + + for (i = 0, j = tbl->rover; i < IP_VS_LBLC_TAB_SIZE; i++) { + j = (j + 1) & IP_VS_LBLC_TAB_MASK; + + spin_lock(&svc->sched_lock); + hlist_for_each_entry_safe(en, next, &tbl->bucket[j], list) { + if (time_before(now, en->lastuse + ENTRY_TIMEOUT)) + continue; + + ip_vs_lblc_del(en); + atomic_dec(&tbl->entries); + goal--; + } + spin_unlock(&svc->sched_lock); + if (goal <= 0) + break; + } + tbl->rover = j; + + out: + mod_timer(&tbl->periodic_timer, jiffies + CHECK_EXPIRE_INTERVAL); +} + + +static int ip_vs_lblc_init_svc(struct ip_vs_service *svc) +{ + int i; + struct ip_vs_lblc_table *tbl; + + /* + * Allocate the ip_vs_lblc_table for this service + */ + tbl = kmalloc(sizeof(*tbl), GFP_KERNEL); + if (tbl == NULL) + return -ENOMEM; + + svc->sched_data = tbl; + IP_VS_DBG(6, "LBLC hash table (memory=%zdbytes) allocated for " + "current service\n", sizeof(*tbl)); + + /* + * Initialize the hash buckets + */ + for (i = 0; i < IP_VS_LBLC_TAB_SIZE; i++) { + INIT_HLIST_HEAD(&tbl->bucket[i]); + } + tbl->max_size = IP_VS_LBLC_TAB_SIZE*16; + tbl->rover = 0; + tbl->counter = 1; + tbl->dead = false; + tbl->svc = svc; + atomic_set(&tbl->entries, 0); + + /* + * Hook periodic timer for garbage collection + */ + timer_setup(&tbl->periodic_timer, ip_vs_lblc_check_expire, 0); + mod_timer(&tbl->periodic_timer, jiffies + CHECK_EXPIRE_INTERVAL); + + return 0; +} + + +static void ip_vs_lblc_done_svc(struct ip_vs_service *svc) +{ + struct ip_vs_lblc_table *tbl = svc->sched_data; + + /* remove periodic timer */ + del_timer_sync(&tbl->periodic_timer); + + /* got to clean up table entries here */ + ip_vs_lblc_flush(svc); + + /* release the table itself */ + kfree_rcu(tbl, rcu_head); + IP_VS_DBG(6, "LBLC hash table (memory=%zdbytes) released\n", + sizeof(*tbl)); +} + + +static inline struct ip_vs_dest * +__ip_vs_lblc_schedule(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest, *least; + int loh, doh; + + /* + * We use the following formula to estimate the load: + * (dest overhead) / dest->weight + * + * Remember -- no floats in kernel mode!!! + * The comparison of h1*w2 > h2*w1 is equivalent to that of + * h1/w1 > h2/w2 + * if every weight is larger than zero. + * + * The server with weight=0 is quiesced and will not receive any + * new connection. + */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + if (atomic_read(&dest->weight) > 0) { + least = dest; + loh = ip_vs_dest_conn_overhead(least); + goto nextstage; + } + } + return NULL; + + /* + * Find the destination with the least load. 
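The scheduler comment above spells out why no floating point is needed: picking the smallest overhead/weight ratio can be done by cross-multiplying, since h1/w1 > h2/w2 is equivalent to h1*w2 > h2*w1 for positive weights. A standalone sketch of that selection loop with made-up overhead numbers:

#include <stdio.h>
#include <stdint.h>

struct rs {
        const char *name;
        int overhead;   /* e.g. a weighted count of active and inactive conns */
        int weight;     /* > 0; weight 0 means quiesced and is skipped */
};

static const struct rs *least_loaded(const struct rs *rs, int n)
{
        const struct rs *least = NULL;

        for (int i = 0; i < n; i++) {
                if (rs[i].weight <= 0)
                        continue;
                /* oh_i / w_i < oh_least / w_least  <=>  oh_i * w_least < oh_least * w_i */
                if (!least || (int64_t)rs[i].overhead * least->weight <
                              (int64_t)least->overhead * rs[i].weight)
                        least = &rs[i];
        }
        return least;
}

int main(void)
{
        struct rs rs[] = {
                { "rs1", 300, 1 },      /* ratio 300 */
                { "rs2", 500, 3 },      /* ratio ~167, least loaded */
                { "rs3", 400, 2 },      /* ratio 200 */
        };
        printf("least loaded: %s\n", least_loaded(rs, 3)->name);
        return 0;
}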
+ */ + nextstage: + list_for_each_entry_continue_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + doh = ip_vs_dest_conn_overhead(dest); + if ((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight)) { + least = dest; + loh = doh; + } + } + + IP_VS_DBG_BUF(6, "LBLC: server %s:%d " + "activeconns %d refcnt %d weight %d overhead %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + + return least; +} + + +/* + * If this destination server is overloaded and there is a less loaded + * server, then return true. + */ +static inline int +is_overloaded(struct ip_vs_dest *dest, struct ip_vs_service *svc) +{ + if (atomic_read(&dest->activeconns) > atomic_read(&dest->weight)) { + struct ip_vs_dest *d; + + list_for_each_entry_rcu(d, &svc->destinations, n_list) { + if (atomic_read(&d->activeconns)*2 + < atomic_read(&d->weight)) { + return 1; + } + } + } + return 0; +} + + +/* + * Locality-Based (weighted) Least-Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_lblc_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_lblc_table *tbl = svc->sched_data; + struct ip_vs_dest *dest = NULL; + struct ip_vs_lblc_entry *en; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* First look in our cache */ + en = ip_vs_lblc_get(svc->af, tbl, &iph->daddr); + if (en) { + /* We only hold a read lock, but this is atomic */ + en->lastuse = jiffies; + + /* + * If the destination is not available, i.e. it's in the trash, + * we must ignore it, as it may be removed from under our feet, + * if someone drops our reference count. Our caller only makes + * sure that destinations, that are not in the trash, are not + * moved to the trash, while we are scheduling. But anyone can + * free up entries from the trash at any time. + */ + + dest = en->dest; + if ((dest->flags & IP_VS_DEST_F_AVAILABLE) && + atomic_read(&dest->weight) > 0 && !is_overloaded(dest, svc)) + goto out; + } + + /* No cache entry or it is invalid, time to schedule */ + dest = __ip_vs_lblc_schedule(svc); + if (!dest) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + /* If we fail to create a cache entry, we'll just use the valid dest */ + spin_lock_bh(&svc->sched_lock); + if (!tbl->dead) + ip_vs_lblc_new(tbl, &iph->daddr, svc->af, dest); + spin_unlock_bh(&svc->sched_lock); + +out: + IP_VS_DBG_BUF(6, "LBLC: destination IP address %s --> server %s:%d\n", + IP_VS_DBG_ADDR(svc->af, &iph->daddr), + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port)); + + return dest; +} + + +/* + * IPVS LBLC Scheduler structure + */ +static struct ip_vs_scheduler ip_vs_lblc_scheduler = { + .name = "lblc", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_lblc_scheduler.n_list), + .init_service = ip_vs_lblc_init_svc, + .done_service = ip_vs_lblc_done_svc, + .schedule = ip_vs_lblc_schedule, +}; + +/* + * per netns init. 
+ */ +#ifdef CONFIG_SYSCTL +static int __net_init __ip_vs_lblc_init(struct net *net) +{ + struct netns_ipvs *ipvs = net_ipvs(net); + + if (!ipvs) + return -ENOENT; + + if (!net_eq(net, &init_net)) { + ipvs->lblc_ctl_table = kmemdup(vs_vars_table, + sizeof(vs_vars_table), + GFP_KERNEL); + if (ipvs->lblc_ctl_table == NULL) + return -ENOMEM; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + ipvs->lblc_ctl_table[0].procname = NULL; + + } else + ipvs->lblc_ctl_table = vs_vars_table; + ipvs->sysctl_lblc_expiration = DEFAULT_EXPIRATION; + ipvs->lblc_ctl_table[0].data = &ipvs->sysctl_lblc_expiration; + + ipvs->lblc_ctl_header = + register_net_sysctl(net, "net/ipv4/vs", ipvs->lblc_ctl_table); + if (!ipvs->lblc_ctl_header) { + if (!net_eq(net, &init_net)) + kfree(ipvs->lblc_ctl_table); + return -ENOMEM; + } + + return 0; +} + +static void __net_exit __ip_vs_lblc_exit(struct net *net) +{ + struct netns_ipvs *ipvs = net_ipvs(net); + + unregister_net_sysctl_table(ipvs->lblc_ctl_header); + + if (!net_eq(net, &init_net)) + kfree(ipvs->lblc_ctl_table); +} + +#else + +static int __net_init __ip_vs_lblc_init(struct net *net) { return 0; } +static void __net_exit __ip_vs_lblc_exit(struct net *net) { } + +#endif + +static struct pernet_operations ip_vs_lblc_ops = { + .init = __ip_vs_lblc_init, + .exit = __ip_vs_lblc_exit, +}; + +static int __init ip_vs_lblc_init(void) +{ + int ret; + + ret = register_pernet_subsys(&ip_vs_lblc_ops); + if (ret) + return ret; + + ret = register_ip_vs_scheduler(&ip_vs_lblc_scheduler); + if (ret) + unregister_pernet_subsys(&ip_vs_lblc_ops); + return ret; +} + +static void __exit ip_vs_lblc_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_lblc_scheduler); + unregister_pernet_subsys(&ip_vs_lblc_ops); + rcu_barrier(); +} + + +module_init(ip_vs_lblc_init); +module_exit(ip_vs_lblc_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_lblcr.c b/net/netfilter/ipvs/ip_vs_lblcr.c new file mode 100644 index 000000000..77c323c36 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_lblcr.c @@ -0,0 +1,815 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Locality-Based Least-Connection with Replication scheduler + * + * Authors: Wensong Zhang + * + * Changes: + * Julian Anastasov : Added the missing (dest->weight>0) + * condition in the ip_vs_dest_set_max. + */ + +/* + * The lblc/r algorithm is as follows (pseudo code): + * + * if serverSet[dest_ip] is null then + * n, serverSet[dest_ip] <- {weighted least-conn node}; + * else + * n <- {least-conn (alive) node in serverSet[dest_ip]}; + * if (n is null) OR + * (n.conns>n.weight AND + * there is a node m with m.conns 1 AND + * now - serverSet[dest_ip].lastMod > T then + * m <- {most conn node in serverSet[dest_ip]}; + * remove m from serverSet[dest_ip]; + * if serverSet[dest_ip] changed then + * serverSet[dest_ip].lastMod <- now; + * + * return n; + * + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +/* for sysctl */ +#include +#include +#include + +#include + + +/* + * It is for garbage collection of stale IPVS lblcr entries, + * when the table is full. + */ +#define CHECK_EXPIRE_INTERVAL (60*HZ) +#define ENTRY_TIMEOUT (6*60*HZ) + +#define DEFAULT_EXPIRATION (24*60*60*HZ) + +/* + * It is for full expiration check. 
+ * When there is no partial expiration check (garbage collection) + * in a half hour, do a full expiration check to collect stale + * entries that haven't been touched for a day. + */ +#define COUNT_FOR_FULL_EXPIRATION 30 + +/* + * for IPVS lblcr entry hash table + */ +#ifndef CONFIG_IP_VS_LBLCR_TAB_BITS +#define CONFIG_IP_VS_LBLCR_TAB_BITS 10 +#endif +#define IP_VS_LBLCR_TAB_BITS CONFIG_IP_VS_LBLCR_TAB_BITS +#define IP_VS_LBLCR_TAB_SIZE (1 << IP_VS_LBLCR_TAB_BITS) +#define IP_VS_LBLCR_TAB_MASK (IP_VS_LBLCR_TAB_SIZE - 1) + + +/* + * IPVS destination set structure and operations + */ +struct ip_vs_dest_set_elem { + struct list_head list; /* list link */ + struct ip_vs_dest *dest; /* destination server */ + struct rcu_head rcu_head; +}; + +struct ip_vs_dest_set { + atomic_t size; /* set size */ + unsigned long lastmod; /* last modified time */ + struct list_head list; /* destination list */ +}; + + +static void ip_vs_dest_set_insert(struct ip_vs_dest_set *set, + struct ip_vs_dest *dest, bool check) +{ + struct ip_vs_dest_set_elem *e; + + if (check) { + list_for_each_entry(e, &set->list, list) { + if (e->dest == dest) + return; + } + } + + e = kmalloc(sizeof(*e), GFP_ATOMIC); + if (e == NULL) + return; + + ip_vs_dest_hold(dest); + e->dest = dest; + + list_add_rcu(&e->list, &set->list); + atomic_inc(&set->size); + + set->lastmod = jiffies; +} + +static void ip_vs_lblcr_elem_rcu_free(struct rcu_head *head) +{ + struct ip_vs_dest_set_elem *e; + + e = container_of(head, struct ip_vs_dest_set_elem, rcu_head); + ip_vs_dest_put_and_free(e->dest); + kfree(e); +} + +static void +ip_vs_dest_set_erase(struct ip_vs_dest_set *set, struct ip_vs_dest *dest) +{ + struct ip_vs_dest_set_elem *e; + + list_for_each_entry(e, &set->list, list) { + if (e->dest == dest) { + /* HIT */ + atomic_dec(&set->size); + set->lastmod = jiffies; + list_del_rcu(&e->list); + call_rcu(&e->rcu_head, ip_vs_lblcr_elem_rcu_free); + break; + } + } +} + +static void ip_vs_dest_set_eraseall(struct ip_vs_dest_set *set) +{ + struct ip_vs_dest_set_elem *e, *ep; + + list_for_each_entry_safe(e, ep, &set->list, list) { + list_del_rcu(&e->list); + call_rcu(&e->rcu_head, ip_vs_lblcr_elem_rcu_free); + } +} + +/* get weighted least-connection node in the destination set */ +static inline struct ip_vs_dest *ip_vs_dest_set_min(struct ip_vs_dest_set *set) +{ + struct ip_vs_dest_set_elem *e; + struct ip_vs_dest *dest, *least; + int loh, doh; + + /* select the first destination server, whose weight > 0 */ + list_for_each_entry_rcu(e, &set->list, list) { + least = e->dest; + if (least->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + if ((atomic_read(&least->weight) > 0) + && (least->flags & IP_VS_DEST_F_AVAILABLE)) { + loh = ip_vs_dest_conn_overhead(least); + goto nextstage; + } + } + return NULL; + + /* find the destination with the weighted least load */ + nextstage: + list_for_each_entry_continue_rcu(e, &set->list, list) { + dest = e->dest; + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + doh = ip_vs_dest_conn_overhead(dest); + if (((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight)) + && (dest->flags & IP_VS_DEST_F_AVAILABLE)) { + least = dest; + loh = doh; + } + } + + IP_VS_DBG_BUF(6, "%s(): server %s:%d " + "activeconns %d refcnt %d weight %d overhead %d\n", + __func__, + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + return least; +} + + +/* get weighted most-connection node in 
the destination set */ +static inline struct ip_vs_dest *ip_vs_dest_set_max(struct ip_vs_dest_set *set) +{ + struct ip_vs_dest_set_elem *e; + struct ip_vs_dest *dest, *most; + int moh, doh; + + if (set == NULL) + return NULL; + + /* select the first destination server, whose weight > 0 */ + list_for_each_entry(e, &set->list, list) { + most = e->dest; + if (atomic_read(&most->weight) > 0) { + moh = ip_vs_dest_conn_overhead(most); + goto nextstage; + } + } + return NULL; + + /* find the destination with the weighted most load */ + nextstage: + list_for_each_entry_continue(e, &set->list, list) { + dest = e->dest; + doh = ip_vs_dest_conn_overhead(dest); + /* moh/mw < doh/dw ==> moh*dw < doh*mw, where mw,dw>0 */ + if (((__s64)moh * atomic_read(&dest->weight) < + (__s64)doh * atomic_read(&most->weight)) + && (atomic_read(&dest->weight) > 0)) { + most = dest; + moh = doh; + } + } + + IP_VS_DBG_BUF(6, "%s(): server %s:%d " + "activeconns %d refcnt %d weight %d overhead %d\n", + __func__, + IP_VS_DBG_ADDR(most->af, &most->addr), ntohs(most->port), + atomic_read(&most->activeconns), + refcount_read(&most->refcnt), + atomic_read(&most->weight), moh); + return most; +} + + +/* + * IPVS lblcr entry represents an association between destination + * IP address and its destination server set + */ +struct ip_vs_lblcr_entry { + struct hlist_node list; + int af; /* address family */ + union nf_inet_addr addr; /* destination IP address */ + struct ip_vs_dest_set set; /* destination server set */ + unsigned long lastuse; /* last used time */ + struct rcu_head rcu_head; +}; + + +/* + * IPVS lblcr hash table + */ +struct ip_vs_lblcr_table { + struct rcu_head rcu_head; + struct hlist_head bucket[IP_VS_LBLCR_TAB_SIZE]; /* hash bucket */ + atomic_t entries; /* number of entries */ + int max_size; /* maximum size of entries */ + struct timer_list periodic_timer; /* collect stale entries */ + struct ip_vs_service *svc; /* pointer back to service */ + int rover; /* rover for expire check */ + int counter; /* counter for no expire */ + bool dead; +}; + + +#ifdef CONFIG_SYSCTL +/* + * IPVS LBLCR sysctl table + */ + +static struct ctl_table vs_vars_table[] = { + { + .procname = "lblcr_expiration", + .data = NULL, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { } +}; +#endif + +static inline void ip_vs_lblcr_free(struct ip_vs_lblcr_entry *en) +{ + hlist_del_rcu(&en->list); + ip_vs_dest_set_eraseall(&en->set); + kfree_rcu(en, rcu_head); +} + + +/* + * Returns hash value for IPVS LBLCR entry + */ +static inline unsigned int +ip_vs_lblcr_hashkey(int af, const union nf_inet_addr *addr) +{ + __be32 addr_fold = addr->ip; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0]^addr->ip6[1]^ + addr->ip6[2]^addr->ip6[3]; +#endif + return hash_32(ntohl(addr_fold), IP_VS_LBLCR_TAB_BITS); +} + + +/* + * Hash an entry in the ip_vs_lblcr_table. + * returns bool success. + */ +static void +ip_vs_lblcr_hash(struct ip_vs_lblcr_table *tbl, struct ip_vs_lblcr_entry *en) +{ + unsigned int hash = ip_vs_lblcr_hashkey(en->af, &en->addr); + + hlist_add_head_rcu(&en->list, &tbl->bucket[hash]); + atomic_inc(&tbl->entries); +} + + +/* Get ip_vs_lblcr_entry associated with supplied parameters. 
*/ +static inline struct ip_vs_lblcr_entry * +ip_vs_lblcr_get(int af, struct ip_vs_lblcr_table *tbl, + const union nf_inet_addr *addr) +{ + unsigned int hash = ip_vs_lblcr_hashkey(af, addr); + struct ip_vs_lblcr_entry *en; + + hlist_for_each_entry_rcu(en, &tbl->bucket[hash], list) + if (ip_vs_addr_equal(af, &en->addr, addr)) + return en; + + return NULL; +} + + +/* + * Create or update an ip_vs_lblcr_entry, which is a mapping of a destination + * IP address to a server. Called under spin lock. + */ +static inline struct ip_vs_lblcr_entry * +ip_vs_lblcr_new(struct ip_vs_lblcr_table *tbl, const union nf_inet_addr *daddr, + u16 af, struct ip_vs_dest *dest) +{ + struct ip_vs_lblcr_entry *en; + + en = ip_vs_lblcr_get(af, tbl, daddr); + if (!en) { + en = kmalloc(sizeof(*en), GFP_ATOMIC); + if (!en) + return NULL; + + en->af = af; + ip_vs_addr_copy(af, &en->addr, daddr); + en->lastuse = jiffies; + + /* initialize its dest set */ + atomic_set(&(en->set.size), 0); + INIT_LIST_HEAD(&en->set.list); + + ip_vs_dest_set_insert(&en->set, dest, false); + + ip_vs_lblcr_hash(tbl, en); + return en; + } + + ip_vs_dest_set_insert(&en->set, dest, true); + + return en; +} + + +/* + * Flush all the entries of the specified table. + */ +static void ip_vs_lblcr_flush(struct ip_vs_service *svc) +{ + struct ip_vs_lblcr_table *tbl = svc->sched_data; + int i; + struct ip_vs_lblcr_entry *en; + struct hlist_node *next; + + spin_lock_bh(&svc->sched_lock); + tbl->dead = true; + for (i = 0; i < IP_VS_LBLCR_TAB_SIZE; i++) { + hlist_for_each_entry_safe(en, next, &tbl->bucket[i], list) { + ip_vs_lblcr_free(en); + } + } + spin_unlock_bh(&svc->sched_lock); +} + +static int sysctl_lblcr_expiration(struct ip_vs_service *svc) +{ +#ifdef CONFIG_SYSCTL + return svc->ipvs->sysctl_lblcr_expiration; +#else + return DEFAULT_EXPIRATION; +#endif +} + +static inline void ip_vs_lblcr_full_check(struct ip_vs_service *svc) +{ + struct ip_vs_lblcr_table *tbl = svc->sched_data; + unsigned long now = jiffies; + int i, j; + struct ip_vs_lblcr_entry *en; + struct hlist_node *next; + + for (i = 0, j = tbl->rover; i < IP_VS_LBLCR_TAB_SIZE; i++) { + j = (j + 1) & IP_VS_LBLCR_TAB_MASK; + + spin_lock(&svc->sched_lock); + hlist_for_each_entry_safe(en, next, &tbl->bucket[j], list) { + if (time_after(en->lastuse + + sysctl_lblcr_expiration(svc), now)) + continue; + + ip_vs_lblcr_free(en); + atomic_dec(&tbl->entries); + } + spin_unlock(&svc->sched_lock); + } + tbl->rover = j; +} + + +/* + * Periodical timer handler for IPVS lblcr table + * It is used to collect stale entries when the number of entries + * exceeds the maximum size of the table. + * + * Fixme: we probably need more complicated algorithm to collect + * entries that have not been used for a long time even + * if the number of entries doesn't exceed the maximum size + * of the table. + * The full expiration check is for this purpose now. 
+ */ +static void ip_vs_lblcr_check_expire(struct timer_list *t) +{ + struct ip_vs_lblcr_table *tbl = from_timer(tbl, t, periodic_timer); + struct ip_vs_service *svc = tbl->svc; + unsigned long now = jiffies; + int goal; + int i, j; + struct ip_vs_lblcr_entry *en; + struct hlist_node *next; + + if ((tbl->counter % COUNT_FOR_FULL_EXPIRATION) == 0) { + /* do full expiration check */ + ip_vs_lblcr_full_check(svc); + tbl->counter = 1; + goto out; + } + + if (atomic_read(&tbl->entries) <= tbl->max_size) { + tbl->counter++; + goto out; + } + + goal = (atomic_read(&tbl->entries) - tbl->max_size)*4/3; + if (goal > tbl->max_size/2) + goal = tbl->max_size/2; + + for (i = 0, j = tbl->rover; i < IP_VS_LBLCR_TAB_SIZE; i++) { + j = (j + 1) & IP_VS_LBLCR_TAB_MASK; + + spin_lock(&svc->sched_lock); + hlist_for_each_entry_safe(en, next, &tbl->bucket[j], list) { + if (time_before(now, en->lastuse+ENTRY_TIMEOUT)) + continue; + + ip_vs_lblcr_free(en); + atomic_dec(&tbl->entries); + goal--; + } + spin_unlock(&svc->sched_lock); + if (goal <= 0) + break; + } + tbl->rover = j; + + out: + mod_timer(&tbl->periodic_timer, jiffies+CHECK_EXPIRE_INTERVAL); +} + +static int ip_vs_lblcr_init_svc(struct ip_vs_service *svc) +{ + int i; + struct ip_vs_lblcr_table *tbl; + + /* + * Allocate the ip_vs_lblcr_table for this service + */ + tbl = kmalloc(sizeof(*tbl), GFP_KERNEL); + if (tbl == NULL) + return -ENOMEM; + + svc->sched_data = tbl; + IP_VS_DBG(6, "LBLCR hash table (memory=%zdbytes) allocated for " + "current service\n", sizeof(*tbl)); + + /* + * Initialize the hash buckets + */ + for (i = 0; i < IP_VS_LBLCR_TAB_SIZE; i++) { + INIT_HLIST_HEAD(&tbl->bucket[i]); + } + tbl->max_size = IP_VS_LBLCR_TAB_SIZE*16; + tbl->rover = 0; + tbl->counter = 1; + tbl->dead = false; + tbl->svc = svc; + atomic_set(&tbl->entries, 0); + + /* + * Hook periodic timer for garbage collection + */ + timer_setup(&tbl->periodic_timer, ip_vs_lblcr_check_expire, 0); + mod_timer(&tbl->periodic_timer, jiffies + CHECK_EXPIRE_INTERVAL); + + return 0; +} + + +static void ip_vs_lblcr_done_svc(struct ip_vs_service *svc) +{ + struct ip_vs_lblcr_table *tbl = svc->sched_data; + + /* remove periodic timer */ + del_timer_sync(&tbl->periodic_timer); + + /* got to clean up table entries here */ + ip_vs_lblcr_flush(svc); + + /* release the table itself */ + kfree_rcu(tbl, rcu_head); + IP_VS_DBG(6, "LBLCR hash table (memory=%zdbytes) released\n", + sizeof(*tbl)); +} + + +static inline struct ip_vs_dest * +__ip_vs_lblcr_schedule(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest, *least; + int loh, doh; + + /* + * We use the following formula to estimate the load: + * (dest overhead) / dest->weight + * + * Remember -- no floats in kernel mode!!! + * The comparison of h1*w2 > h2*w1 is equivalent to that of + * h1/w1 > h2/w2 + * if every weight is larger than zero. + * + * The server with weight=0 is quiesced and will not receive any + * new connection. + */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + if (atomic_read(&dest->weight) > 0) { + least = dest; + loh = ip_vs_dest_conn_overhead(least); + goto nextstage; + } + } + return NULL; + + /* + * Find the destination with the least load. 
+ */ + nextstage: + list_for_each_entry_continue_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + doh = ip_vs_dest_conn_overhead(dest); + if ((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight)) { + least = dest; + loh = doh; + } + } + + IP_VS_DBG_BUF(6, "LBLCR: server %s:%d " + "activeconns %d refcnt %d weight %d overhead %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + + return least; +} + + +/* + * If this destination server is overloaded and there is a less loaded + * server, then return true. + */ +static inline int +is_overloaded(struct ip_vs_dest *dest, struct ip_vs_service *svc) +{ + if (atomic_read(&dest->activeconns) > atomic_read(&dest->weight)) { + struct ip_vs_dest *d; + + list_for_each_entry_rcu(d, &svc->destinations, n_list) { + if (atomic_read(&d->activeconns)*2 + < atomic_read(&d->weight)) { + return 1; + } + } + } + return 0; +} + + +/* + * Locality-Based (weighted) Least-Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_lblcr_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_lblcr_table *tbl = svc->sched_data; + struct ip_vs_dest *dest; + struct ip_vs_lblcr_entry *en; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* First look in our cache */ + en = ip_vs_lblcr_get(svc->af, tbl, &iph->daddr); + if (en) { + en->lastuse = jiffies; + + /* Get the least loaded destination */ + dest = ip_vs_dest_set_min(&en->set); + + /* More than one destination + enough time passed by, cleanup */ + if (atomic_read(&en->set.size) > 1 && + time_after(jiffies, en->set.lastmod + + sysctl_lblcr_expiration(svc))) { + spin_lock_bh(&svc->sched_lock); + if (atomic_read(&en->set.size) > 1) { + struct ip_vs_dest *m; + + m = ip_vs_dest_set_max(&en->set); + if (m) + ip_vs_dest_set_erase(&en->set, m); + } + spin_unlock_bh(&svc->sched_lock); + } + + /* If the destination is not overloaded, use it */ + if (dest && !is_overloaded(dest, svc)) + goto out; + + /* The cache entry is invalid, time to schedule */ + dest = __ip_vs_lblcr_schedule(svc); + if (!dest) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + /* Update our cache entry */ + spin_lock_bh(&svc->sched_lock); + if (!tbl->dead) + ip_vs_dest_set_insert(&en->set, dest, true); + spin_unlock_bh(&svc->sched_lock); + goto out; + } + + /* No cache entry, time to schedule */ + dest = __ip_vs_lblcr_schedule(svc); + if (!dest) { + IP_VS_DBG(1, "no destination available\n"); + return NULL; + } + + /* If we fail to create a cache entry, we'll just use the valid dest */ + spin_lock_bh(&svc->sched_lock); + if (!tbl->dead) + ip_vs_lblcr_new(tbl, &iph->daddr, svc->af, dest); + spin_unlock_bh(&svc->sched_lock); + +out: + IP_VS_DBG_BUF(6, "LBLCR: destination IP address %s --> server %s:%d\n", + IP_VS_DBG_ADDR(svc->af, &iph->daddr), + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port)); + + return dest; +} + + +/* + * IPVS LBLCR Scheduler structure + */ +static struct ip_vs_scheduler ip_vs_lblcr_scheduler = +{ + .name = "lblcr", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_lblcr_scheduler.n_list), + .init_service = ip_vs_lblcr_init_svc, + .done_service = ip_vs_lblcr_done_svc, + .schedule = ip_vs_lblcr_schedule, +}; + +/* + * per netns init. 
+ */ +#ifdef CONFIG_SYSCTL +static int __net_init __ip_vs_lblcr_init(struct net *net) +{ + struct netns_ipvs *ipvs = net_ipvs(net); + + if (!ipvs) + return -ENOENT; + + if (!net_eq(net, &init_net)) { + ipvs->lblcr_ctl_table = kmemdup(vs_vars_table, + sizeof(vs_vars_table), + GFP_KERNEL); + if (ipvs->lblcr_ctl_table == NULL) + return -ENOMEM; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + ipvs->lblcr_ctl_table[0].procname = NULL; + } else + ipvs->lblcr_ctl_table = vs_vars_table; + ipvs->sysctl_lblcr_expiration = DEFAULT_EXPIRATION; + ipvs->lblcr_ctl_table[0].data = &ipvs->sysctl_lblcr_expiration; + + ipvs->lblcr_ctl_header = + register_net_sysctl(net, "net/ipv4/vs", ipvs->lblcr_ctl_table); + if (!ipvs->lblcr_ctl_header) { + if (!net_eq(net, &init_net)) + kfree(ipvs->lblcr_ctl_table); + return -ENOMEM; + } + + return 0; +} + +static void __net_exit __ip_vs_lblcr_exit(struct net *net) +{ + struct netns_ipvs *ipvs = net_ipvs(net); + + unregister_net_sysctl_table(ipvs->lblcr_ctl_header); + + if (!net_eq(net, &init_net)) + kfree(ipvs->lblcr_ctl_table); +} + +#else + +static int __net_init __ip_vs_lblcr_init(struct net *net) { return 0; } +static void __net_exit __ip_vs_lblcr_exit(struct net *net) { } + +#endif + +static struct pernet_operations ip_vs_lblcr_ops = { + .init = __ip_vs_lblcr_init, + .exit = __ip_vs_lblcr_exit, +}; + +static int __init ip_vs_lblcr_init(void) +{ + int ret; + + ret = register_pernet_subsys(&ip_vs_lblcr_ops); + if (ret) + return ret; + + ret = register_ip_vs_scheduler(&ip_vs_lblcr_scheduler); + if (ret) + unregister_pernet_subsys(&ip_vs_lblcr_ops); + return ret; +} + +static void __exit ip_vs_lblcr_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_lblcr_scheduler); + unregister_pernet_subsys(&ip_vs_lblcr_ops); + rcu_barrier(); +} + + +module_init(ip_vs_lblcr_init); +module_exit(ip_vs_lblcr_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_lc.c b/net/netfilter/ipvs/ip_vs_lc.c new file mode 100644 index 000000000..9d34d81fc --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_lc.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Least-Connection Scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + * Wensong Zhang : added the ip_vs_lc_update_svc + * Wensong Zhang : added any dest with weight=0 is quiesced + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + +/* + * Least Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_lc_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *least = NULL; + unsigned int loh = 0, doh; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* + * Simply select the server with the least number of + * (activeconns<<5) + inactconns + * Except whose weight is equal to zero. + * If the weight is equal to zero, it means that the server is + * quiesced, the existing connections to the server still get + * served, but no new connection is assigned to the server. 
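Plain least-connection scheduling, as the comment above describes, ranks servers by an overhead in which active connections count for much more than inactive ones and skips quiesced (weight 0) servers. A standalone sketch using the (activeconns << 5) + inactconns weighting mentioned in the comment:

#include <stdio.h>

struct rs {
        const char *name;
        unsigned int activeconns;
        unsigned int inactconns;
        int weight;
};

static unsigned int conn_overhead(const struct rs *d)
{
        /* active connections count 32 times as much as inactive ones */
        return (d->activeconns << 5) + d->inactconns;
}

static const struct rs *lc_pick(const struct rs *rs, int n)
{
        const struct rs *least = NULL;

        for (int i = 0; i < n; i++) {
                if (rs[i].weight == 0)
                        continue;       /* quiesced: keeps old conns, takes no new ones */
                if (!least || conn_overhead(&rs[i]) < conn_overhead(least))
                        least = &rs[i];
        }
        return least;
}

int main(void)
{
        struct rs rs[] = {
                { "rs1", 10,  5, 1 },   /* overhead 325 */
                { "rs2",  8, 90, 1 },   /* overhead 346 */
                { "rs3",  2,  0, 0 },   /* quiesced */
        };
        printf("pick %s\n", lc_pick(rs, 3)->name);
        return 0;
}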
+ */ + + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if ((dest->flags & IP_VS_DEST_F_OVERLOAD) || + atomic_read(&dest->weight) == 0) + continue; + doh = ip_vs_dest_conn_overhead(dest); + if (!least || doh < loh) { + least = dest; + loh = doh; + } + } + + if (!least) + ip_vs_scheduler_err(svc, "no destination available"); + else + IP_VS_DBG_BUF(6, "LC: server %s:%u activeconns %d " + "inactconns %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + atomic_read(&least->inactconns)); + + return least; +} + + +static struct ip_vs_scheduler ip_vs_lc_scheduler = { + .name = "lc", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_lc_scheduler.n_list), + .schedule = ip_vs_lc_schedule, +}; + + +static int __init ip_vs_lc_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_lc_scheduler) ; +} + +static void __exit ip_vs_lc_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_lc_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_lc_init); +module_exit(ip_vs_lc_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_mh.c b/net/netfilter/ipvs/ip_vs_mh.c new file mode 100644 index 000000000..e3d7f5c87 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_mh.c @@ -0,0 +1,539 @@ +// SPDX-License-Identifier: GPL-2.0 +/* IPVS: Maglev Hashing scheduling module + * + * Authors: Inju Song + * + */ + +/* The mh algorithm is to assign a preference list of all the lookup + * table positions to each destination and populate the table with + * the most-preferred position of destinations. Then it is to select + * destination with the hash key of source IP address through looking + * up a the lookup table. + * + * The algorithm is detailed in: + * [3.4 Consistent Hasing] +https://www.usenix.org/system/files/conference/nsdi16/nsdi16-paper-eisenbud.pdf + * + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#define IP_VS_SVC_F_SCHED_MH_FALLBACK IP_VS_SVC_F_SCHED1 /* MH fallback */ +#define IP_VS_SVC_F_SCHED_MH_PORT IP_VS_SVC_F_SCHED2 /* MH use port */ + +struct ip_vs_mh_lookup { + struct ip_vs_dest __rcu *dest; /* real server (cache) */ +}; + +struct ip_vs_mh_dest_setup { + unsigned int offset; /* starting offset */ + unsigned int skip; /* skip */ + unsigned int perm; /* next_offset */ + int turns; /* weight / gcd() and rshift */ +}; + +/* Available prime numbers for MH table */ +static int primes[] = {251, 509, 1021, 2039, 4093, + 8191, 16381, 32749, 65521, 131071}; + +/* For IPVS MH entry hash table */ +#ifndef CONFIG_IP_VS_MH_TAB_INDEX +#define CONFIG_IP_VS_MH_TAB_INDEX 12 +#endif +#define IP_VS_MH_TAB_BITS (CONFIG_IP_VS_MH_TAB_INDEX / 2) +#define IP_VS_MH_TAB_INDEX (CONFIG_IP_VS_MH_TAB_INDEX - 8) +#define IP_VS_MH_TAB_SIZE primes[IP_VS_MH_TAB_INDEX] + +struct ip_vs_mh_state { + struct rcu_head rcu_head; + struct ip_vs_mh_lookup *lookup; + struct ip_vs_mh_dest_setup *dest_setup; + hsiphash_key_t hash1, hash2; + int gcd; + int rshift; +}; + +static inline void generate_hash_secret(hsiphash_key_t *hash1, + hsiphash_key_t *hash2) +{ + hash1->key[0] = 2654435761UL; + hash1->key[1] = 2654435761UL; + + hash2->key[0] = 2654446892UL; + hash2->key[1] = 2654446892UL; +} + +/* Helper function to determine if server is unavailable */ +static inline bool is_unavailable(struct ip_vs_dest *dest) +{ + return atomic_read(&dest->weight) <= 0 || + dest->flags & IP_VS_DEST_F_OVERLOAD; 
+} + +/* Returns hash value for IPVS MH entry */ +static inline unsigned int +ip_vs_mh_hashkey(int af, const union nf_inet_addr *addr, + __be16 port, hsiphash_key_t *key, unsigned int offset) +{ + unsigned int v; + __be32 addr_fold = addr->ip; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0] ^ addr->ip6[1] ^ + addr->ip6[2] ^ addr->ip6[3]; +#endif + v = (offset + ntohs(port) + ntohl(addr_fold)); + return hsiphash(&v, sizeof(v), key); +} + +/* Reset all the hash buckets of the specified table. */ +static void ip_vs_mh_reset(struct ip_vs_mh_state *s) +{ + int i; + struct ip_vs_mh_lookup *l; + struct ip_vs_dest *dest; + + l = &s->lookup[0]; + for (i = 0; i < IP_VS_MH_TAB_SIZE; i++) { + dest = rcu_dereference_protected(l->dest, 1); + if (dest) { + ip_vs_dest_put(dest); + RCU_INIT_POINTER(l->dest, NULL); + } + l++; + } +} + +static int ip_vs_mh_permutate(struct ip_vs_mh_state *s, + struct ip_vs_service *svc) +{ + struct list_head *p; + struct ip_vs_mh_dest_setup *ds; + struct ip_vs_dest *dest; + int lw; + + /* If gcd is smaller then 1, number of dests or + * all last_weight of dests are zero. So, skip + * permutation for the dests. + */ + if (s->gcd < 1) + return 0; + + /* Set dest_setup for the dests permutation */ + p = &svc->destinations; + ds = &s->dest_setup[0]; + while ((p = p->next) != &svc->destinations) { + dest = list_entry(p, struct ip_vs_dest, n_list); + + ds->offset = ip_vs_mh_hashkey(svc->af, &dest->addr, + dest->port, &s->hash1, 0) % + IP_VS_MH_TAB_SIZE; + ds->skip = ip_vs_mh_hashkey(svc->af, &dest->addr, + dest->port, &s->hash2, 0) % + (IP_VS_MH_TAB_SIZE - 1) + 1; + ds->perm = ds->offset; + + lw = atomic_read(&dest->last_weight); + ds->turns = ((lw / s->gcd) >> s->rshift) ? : (lw != 0); + ds++; + } + + return 0; +} + +static int ip_vs_mh_populate(struct ip_vs_mh_state *s, + struct ip_vs_service *svc) +{ + int n, c, dt_count; + unsigned long *table; + struct list_head *p; + struct ip_vs_mh_dest_setup *ds; + struct ip_vs_dest *dest, *new_dest; + + /* If gcd is smaller then 1, number of dests or + * all last_weight of dests are zero. So, skip + * the population for the dests and reset lookup table. + */ + if (s->gcd < 1) { + ip_vs_mh_reset(s); + return 0; + } + + table = bitmap_zalloc(IP_VS_MH_TAB_SIZE, GFP_KERNEL); + if (!table) + return -ENOMEM; + + p = &svc->destinations; + n = 0; + dt_count = 0; + while (n < IP_VS_MH_TAB_SIZE) { + if (p == &svc->destinations) + p = p->next; + + ds = &s->dest_setup[0]; + while (p != &svc->destinations) { + /* Ignore added server with zero weight */ + if (ds->turns < 1) { + p = p->next; + ds++; + continue; + } + + c = ds->perm; + while (test_bit(c, table)) { + /* Add skip, mod IP_VS_MH_TAB_SIZE */ + ds->perm += ds->skip; + if (ds->perm >= IP_VS_MH_TAB_SIZE) + ds->perm -= IP_VS_MH_TAB_SIZE; + c = ds->perm; + } + + __set_bit(c, table); + + dest = rcu_dereference_protected(s->lookup[c].dest, 1); + new_dest = list_entry(p, struct ip_vs_dest, n_list); + if (dest != new_dest) { + if (dest) + ip_vs_dest_put(dest); + ip_vs_dest_hold(new_dest); + RCU_INIT_POINTER(s->lookup[c].dest, new_dest); + } + + if (++n == IP_VS_MH_TAB_SIZE) + goto out; + + if (++dt_count >= ds->turns) { + dt_count = 0; + p = p->next; + ds++; + } + } + } + +out: + bitmap_free(table); + return 0; +} + +/* Get ip_vs_dest associated with supplied parameters. 
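ip_vs_mh_populate() fills the Maglev lookup table by giving every destination an (offset, skip) permutation of the slots and letting destinations take turns claiming their next preferred free slot. A standalone, unweighted sketch of that population loop; the toy offset/skip formulas stand in for the kernel's hsiphash-based ones, and the table size is a small prime so every permutation covers all slots:

#include <stdio.h>
#include <string.h>

#define TAB_SIZE 13             /* prime, like the kernel's primes[] entries */
#define NDEST 3

int main(void)
{
        const char *dest[NDEST] = { "rs0", "rs1", "rs2" };
        unsigned int offset[NDEST], skip[NDEST], perm[NDEST];
        int table[TAB_SIZE];
        int filled = 0;

        memset(table, -1, sizeof(table));

        for (int i = 0; i < NDEST; i++) {
                /* stand-ins for hashkey(addr, hash1) and hashkey(addr, hash2) */
                offset[i] = (i * 7 + 3) % TAB_SIZE;
                skip[i] = (i * 5 + 1) % (TAB_SIZE - 1) + 1;
                perm[i] = offset[i];
        }

        while (filled < TAB_SIZE) {
                for (int i = 0; i < NDEST && filled < TAB_SIZE; i++) {
                        unsigned int c = perm[i];

                        while (table[c] >= 0) {         /* slot taken, try next preference */
                                perm[i] = (perm[i] + skip[i]) % TAB_SIZE;
                                c = perm[i];
                        }
                        table[c] = i;
                        filled++;
                }
        }

        for (int c = 0; c < TAB_SIZE; c++)
                printf("slot %2d -> %s\n", c, dest[table[c]]);
        return 0;
}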
*/ +static inline struct ip_vs_dest * +ip_vs_mh_get(struct ip_vs_service *svc, struct ip_vs_mh_state *s, + const union nf_inet_addr *addr, __be16 port) +{ + unsigned int hash = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1, 0) + % IP_VS_MH_TAB_SIZE; + struct ip_vs_dest *dest = rcu_dereference(s->lookup[hash].dest); + + return (!dest || is_unavailable(dest)) ? NULL : dest; +} + +/* As ip_vs_mh_get, but with fallback if selected server is unavailable */ +static inline struct ip_vs_dest * +ip_vs_mh_get_fallback(struct ip_vs_service *svc, struct ip_vs_mh_state *s, + const union nf_inet_addr *addr, __be16 port) +{ + unsigned int offset, roffset; + unsigned int hash, ihash; + struct ip_vs_dest *dest; + + /* First try the dest it's supposed to go to */ + ihash = ip_vs_mh_hashkey(svc->af, addr, port, + &s->hash1, 0) % IP_VS_MH_TAB_SIZE; + dest = rcu_dereference(s->lookup[ihash].dest); + if (!dest) + return NULL; + if (!is_unavailable(dest)) + return dest; + + IP_VS_DBG_BUF(6, "MH: selected unavailable server %s:%u, reselecting", + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port)); + + /* If the original dest is unavailable, loop around the table + * starting from ihash to find a new dest + */ + for (offset = 0; offset < IP_VS_MH_TAB_SIZE; offset++) { + roffset = (offset + ihash) % IP_VS_MH_TAB_SIZE; + hash = ip_vs_mh_hashkey(svc->af, addr, port, &s->hash1, + roffset) % IP_VS_MH_TAB_SIZE; + dest = rcu_dereference(s->lookup[hash].dest); + if (!dest) + break; + if (!is_unavailable(dest)) + return dest; + IP_VS_DBG_BUF(6, + "MH: selected unavailable server %s:%u (offset %u), reselecting", + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port), roffset); + } + + return NULL; +} + +/* Assign all the hash buckets of the specified table with the service. */ +static int ip_vs_mh_reassign(struct ip_vs_mh_state *s, + struct ip_vs_service *svc) +{ + int ret; + + if (svc->num_dests > IP_VS_MH_TAB_SIZE) + return -EINVAL; + + if (svc->num_dests >= 1) { + s->dest_setup = kcalloc(svc->num_dests, + sizeof(struct ip_vs_mh_dest_setup), + GFP_KERNEL); + if (!s->dest_setup) + return -ENOMEM; + } + + ip_vs_mh_permutate(s, svc); + + ret = ip_vs_mh_populate(s, svc); + if (ret < 0) + goto out; + + IP_VS_DBG_BUF(6, "MH: reassign lookup table of %s:%u\n", + IP_VS_DBG_ADDR(svc->af, &svc->addr), + ntohs(svc->port)); + +out: + if (svc->num_dests >= 1) { + kfree(s->dest_setup); + s->dest_setup = NULL; + } + return ret; +} + +static int ip_vs_mh_gcd_weight(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest; + int weight; + int g = 0; + + list_for_each_entry(dest, &svc->destinations, n_list) { + weight = atomic_read(&dest->last_weight); + if (weight > 0) { + if (g > 0) + g = gcd(weight, g); + else + g = weight; + } + } + return g; +} + +/* To avoid assigning huge weight for the MH table, + * calculate shift value with gcd. + */ +static int ip_vs_mh_shift_weight(struct ip_vs_service *svc, int gcd) +{ + struct ip_vs_dest *dest; + int new_weight, weight = 0; + int mw, shift; + + /* If gcd is smaller then 1, number of dests or + * all last_weight of dests are zero. So, return + * shift value as zero. 
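Before populating the table, weights are normalized: divide each by their common gcd, then right-shift so the largest per-destination "turns" value fits within the table's bit width, keeping very large configured weights from blowing up the population loop. A standalone sketch of that normalization with hypothetical weights:

#include <stdio.h>

#define MH_TAB_BITS 6   /* example table index width; the kernel derives it from config */

static int gcd(int a, int b)
{
        while (b) {
                int t = a % b;
                a = b;
                b = t;
        }
        return a;
}

static int fls_int(unsigned int x)      /* position of highest set bit, 1-based */
{
        int r = 0;

        while (x) {
                x >>= 1;
                r++;
        }
        return r;
}

int main(void)
{
        int weight[] = { 4000, 8000, 12000 };
        int n = 3, g = 0, maxw = 0;

        for (int i = 0; i < n; i++) {
                g = g ? gcd(weight[i], g) : weight[i];
                if (weight[i] > maxw)
                        maxw = weight[i];
        }

        int shift = fls_int(maxw / g) - MH_TAB_BITS;
        if (shift < 0)
                shift = 0;

        for (int i = 0; i < n; i++) {
                int turns = (weight[i] / g) >> shift;
                if (!turns && weight[i])
                        turns = 1;      /* any non-zero weight keeps at least one turn */
                printf("weight %5d -> turns %d\n", weight[i], turns);
        }
        return 0;
}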
+ */ + if (gcd < 1) + return 0; + + list_for_each_entry(dest, &svc->destinations, n_list) { + new_weight = atomic_read(&dest->last_weight); + if (new_weight > weight) + weight = new_weight; + } + + /* Because gcd is greater than zero, + * the maximum weight and gcd are always greater than zero + */ + mw = weight / gcd; + + /* shift = occupied bits of weight/gcd - MH highest bits */ + shift = fls(mw) - IP_VS_MH_TAB_BITS; + return (shift >= 0) ? shift : 0; +} + +static void ip_vs_mh_state_free(struct rcu_head *head) +{ + struct ip_vs_mh_state *s; + + s = container_of(head, struct ip_vs_mh_state, rcu_head); + kfree(s->lookup); + kfree(s); +} + +static int ip_vs_mh_init_svc(struct ip_vs_service *svc) +{ + int ret; + struct ip_vs_mh_state *s; + + /* Allocate the MH table for this service */ + s = kzalloc(sizeof(*s), GFP_KERNEL); + if (!s) + return -ENOMEM; + + s->lookup = kcalloc(IP_VS_MH_TAB_SIZE, sizeof(struct ip_vs_mh_lookup), + GFP_KERNEL); + if (!s->lookup) { + kfree(s); + return -ENOMEM; + } + + generate_hash_secret(&s->hash1, &s->hash2); + s->gcd = ip_vs_mh_gcd_weight(svc); + s->rshift = ip_vs_mh_shift_weight(svc, s->gcd); + + IP_VS_DBG(6, + "MH lookup table (memory=%zdbytes) allocated for current service\n", + sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE); + + /* Assign the lookup table with current dests */ + ret = ip_vs_mh_reassign(s, svc); + if (ret < 0) { + ip_vs_mh_reset(s); + ip_vs_mh_state_free(&s->rcu_head); + return ret; + } + + /* No more failures, attach state */ + svc->sched_data = s; + return 0; +} + +static void ip_vs_mh_done_svc(struct ip_vs_service *svc) +{ + struct ip_vs_mh_state *s = svc->sched_data; + + /* Got to clean up lookup entry here */ + ip_vs_mh_reset(s); + + call_rcu(&s->rcu_head, ip_vs_mh_state_free); + IP_VS_DBG(6, "MH lookup table (memory=%zdbytes) released\n", + sizeof(struct ip_vs_mh_lookup) * IP_VS_MH_TAB_SIZE); +} + +static int ip_vs_mh_dest_changed(struct ip_vs_service *svc, + struct ip_vs_dest *dest) +{ + struct ip_vs_mh_state *s = svc->sched_data; + + s->gcd = ip_vs_mh_gcd_weight(svc); + s->rshift = ip_vs_mh_shift_weight(svc, s->gcd); + + /* Assign the lookup table with the updated service */ + return ip_vs_mh_reassign(s, svc); +} + +/* Helper function to get port number */ +static inline __be16 +ip_vs_mh_get_port(const struct sk_buff *skb, struct ip_vs_iphdr *iph) +{ + __be16 _ports[2], *ports; + + /* At this point we know that we have a valid packet of some kind. + * Because ICMP packets are only guaranteed to have the first 8 + * bytes, let's just grab the ports. Fortunately they're in the + * same position for all three of the protocols we care about. + */ + switch (iph->protocol) { + case IPPROTO_TCP: + case IPPROTO_UDP: + case IPPROTO_SCTP: + ports = skb_header_pointer(skb, iph->len, sizeof(_ports), + &_ports); + if (unlikely(!ports)) + return 0; + + if (likely(!ip_vs_iph_inverse(iph))) + return ports[0]; + else + return ports[1]; + default: + return 0; + } +} + +/* Maglev Hashing scheduling */ +static struct ip_vs_dest * +ip_vs_mh_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest; + struct ip_vs_mh_state *s; + __be16 port = 0; + const union nf_inet_addr *hash_addr; + + hash_addr = ip_vs_iph_inverse(iph) ? 
&iph->daddr : &iph->saddr; + + IP_VS_DBG(6, "%s : Scheduling...\n", __func__); + + if (svc->flags & IP_VS_SVC_F_SCHED_MH_PORT) + port = ip_vs_mh_get_port(skb, iph); + + s = (struct ip_vs_mh_state *)svc->sched_data; + + if (svc->flags & IP_VS_SVC_F_SCHED_MH_FALLBACK) + dest = ip_vs_mh_get_fallback(svc, s, hash_addr, port); + else + dest = ip_vs_mh_get(svc, s, hash_addr, port); + + if (!dest) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + IP_VS_DBG_BUF(6, "MH: source IP address %s:%u --> server %s:%u\n", + IP_VS_DBG_ADDR(svc->af, hash_addr), + ntohs(port), + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port)); + + return dest; +} + +/* IPVS MH Scheduler structure */ +static struct ip_vs_scheduler ip_vs_mh_scheduler = { + .name = "mh", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_mh_scheduler.n_list), + .init_service = ip_vs_mh_init_svc, + .done_service = ip_vs_mh_done_svc, + .add_dest = ip_vs_mh_dest_changed, + .del_dest = ip_vs_mh_dest_changed, + .upd_dest = ip_vs_mh_dest_changed, + .schedule = ip_vs_mh_schedule, +}; + +static int __init ip_vs_mh_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_mh_scheduler); +} + +static void __exit ip_vs_mh_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_mh_scheduler); + rcu_barrier(); +} + +module_init(ip_vs_mh_init); +module_exit(ip_vs_mh_cleanup); +MODULE_DESCRIPTION("Maglev hashing ipvs scheduler"); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Inju Song "); diff --git a/net/netfilter/ipvs/ip_vs_nfct.c b/net/netfilter/ipvs/ip_vs_nfct.c new file mode 100644 index 000000000..08adcb222 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_nfct.c @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_nfct.c: Netfilter connection tracking support for IPVS + * + * Portions Copyright (C) 2001-2002 + * Antefacto Ltd, 181 Parnell St, Dublin 1, Ireland. + * + * Portions Copyright (C) 2003-2010 + * Julian Anastasov + * + * Authors: + * Ben North + * Julian Anastasov Reorganize and sync with latest kernels + * Hannes Eder Extend NFCT support for FTP, ipvs match + * + * Current status: + * + * - provide conntrack confirmation for new and related connections, by + * this way we can see their proper conntrack state in all hooks + * - support for all forwarding methods, not only NAT + * - FTP support (NAT), ability to support other NAT apps with expectations + * - to correctly create expectations for related NAT connections the proper + * NF conntrack support must be already installed, eg. ip_vs_ftp requires + * nf_conntrack_ftp ... 
iptables_nat for the same ports (but no iptables + * NAT rules are needed) + * - alter reply for NAT when forwarding packet in original direction: + * conntrack from client in NEW or RELATED (Passive FTP DATA) state or + * when RELATED conntrack is created from real server (Active FTP DATA) + * - if iptables_nat is not loaded the Passive FTP will not work (the + * PASV response can not be NAT-ed) but Active FTP should work + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define FMT_TUPLE "%s:%u->%s:%u/%u" +#define ARG_TUPLE(T) IP_VS_DBG_ADDR((T)->src.l3num, &(T)->src.u3), \ + ntohs((T)->src.u.all), \ + IP_VS_DBG_ADDR((T)->src.l3num, &(T)->dst.u3), \ + ntohs((T)->dst.u.all), \ + (T)->dst.protonum + +#define FMT_CONN "%s:%u->%s:%u->%s:%u/%u:%u" +#define ARG_CONN(C) IP_VS_DBG_ADDR((C)->af, &((C)->caddr)), \ + ntohs((C)->cport), \ + IP_VS_DBG_ADDR((C)->af, &((C)->vaddr)), \ + ntohs((C)->vport), \ + IP_VS_DBG_ADDR((C)->daf, &((C)->daddr)), \ + ntohs((C)->dport), \ + (C)->protocol, (C)->state + +void +ip_vs_update_conntrack(struct sk_buff *skb, struct ip_vs_conn *cp, int outin) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_conntrack_tuple new_tuple; + + if (ct == NULL || nf_ct_is_confirmed(ct) || + nf_ct_is_dying(ct)) + return; + + /* Never alter conntrack for non-NAT conns */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + return; + + /* Never alter conntrack for OPS conns (no reply is expected) */ + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + return; + + /* Alter reply only in original direction */ + if (CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) + return; + + /* Applications may adjust TCP seqs */ + if (cp->app && nf_ct_protonum(ct) == IPPROTO_TCP && + !nfct_seqadj(ct) && !nfct_seqadj_ext_add(ct)) + return; + + /* + * The connection is not yet in the hashtable, so we update it. + * CIP->VIP will remain the same, so leave the tuple in + * IP_CT_DIR_ORIGINAL untouched. When the reply comes back from the + * real-server we will see RIP->DIP. + */ + new_tuple = ct->tuplehash[IP_CT_DIR_REPLY].tuple; + /* + * This will also take care of UDP and other protocols. + */ + if (outin) { + new_tuple.src.u3 = cp->daddr; + if (new_tuple.dst.protonum != IPPROTO_ICMP && + new_tuple.dst.protonum != IPPROTO_ICMPV6) + new_tuple.src.u.tcp.port = cp->dport; + } else { + new_tuple.dst.u3 = cp->vaddr; + if (new_tuple.dst.protonum != IPPROTO_ICMP && + new_tuple.dst.protonum != IPPROTO_ICMPV6) + new_tuple.dst.u.tcp.port = cp->vport; + } + IP_VS_DBG_BUF(7, "%s: Updating conntrack ct=%p, status=0x%lX, " + "ctinfo=%d, old reply=" FMT_TUPLE "\n", + __func__, ct, ct->status, ctinfo, + ARG_TUPLE(&ct->tuplehash[IP_CT_DIR_REPLY].tuple)); + IP_VS_DBG_BUF(7, "%s: Updating conntrack ct=%p, status=0x%lX, " + "ctinfo=%d, new reply=" FMT_TUPLE "\n", + __func__, ct, ct->status, ctinfo, + ARG_TUPLE(&new_tuple)); + nf_conntrack_alter_reply(ct, &new_tuple); + IP_VS_DBG_BUF(7, "%s: Updated conntrack ct=%p for cp=" FMT_CONN "\n", + __func__, ct, ARG_CONN(cp)); +} + +int ip_vs_confirm_conntrack(struct sk_buff *skb) +{ + return nf_conntrack_confirm(skb); +} + +/* + * Called from init_conntrack() as expectfn handler. 
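For a connection created from an expectation (e.g. FTP DATA), the callback below looks up the matching IPVS connection by the original-direction tuple and, for NAT forwarding only, rewrites the conntrack reply tuple so that the data connection follows the same translation as its control connection.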
+ */ +static void ip_vs_nfct_expect_callback(struct nf_conn *ct, + struct nf_conntrack_expect *exp) +{ + struct nf_conntrack_tuple *orig, new_reply; + struct ip_vs_conn *cp; + struct ip_vs_conn_param p; + struct net *net = nf_ct_net(ct); + + /* + * We assume that no NF locks are held before this callback. + * ip_vs_conn_out_get and ip_vs_conn_in_get should match their + * expectations even if they use wildcard values, now we provide the + * actual values from the newly created original conntrack direction. + * The conntrack is confirmed when packet reaches IPVS hooks. + */ + + /* RS->CLIENT */ + orig = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + ip_vs_conn_fill_param(net_ipvs(net), exp->tuple.src.l3num, orig->dst.protonum, + &orig->src.u3, orig->src.u.tcp.port, + &orig->dst.u3, orig->dst.u.tcp.port, &p); + cp = ip_vs_conn_out_get(&p); + if (cp) { + /* Change reply CLIENT->RS to CLIENT->VS */ + IP_VS_DBG_BUF(7, "%s: for ct=%p, status=0x%lX found inout cp=" + FMT_CONN "\n", + __func__, ct, ct->status, ARG_CONN(cp)); + new_reply = ct->tuplehash[IP_CT_DIR_REPLY].tuple; + IP_VS_DBG_BUF(7, "%s: ct=%p before alter: reply tuple=" + FMT_TUPLE "\n", + __func__, ct, ARG_TUPLE(&new_reply)); + new_reply.dst.u3 = cp->vaddr; + new_reply.dst.u.tcp.port = cp->vport; + goto alter; + } + + /* CLIENT->VS */ + cp = ip_vs_conn_in_get(&p); + if (cp) { + /* Change reply VS->CLIENT to RS->CLIENT */ + IP_VS_DBG_BUF(7, "%s: for ct=%p, status=0x%lX found outin cp=" + FMT_CONN "\n", + __func__, ct, ct->status, ARG_CONN(cp)); + new_reply = ct->tuplehash[IP_CT_DIR_REPLY].tuple; + IP_VS_DBG_BUF(7, "%s: ct=%p before alter: reply tuple=" + FMT_TUPLE "\n", + __func__, ct, ARG_TUPLE(&new_reply)); + new_reply.src.u3 = cp->daddr; + new_reply.src.u.tcp.port = cp->dport; + goto alter; + } + + IP_VS_DBG_BUF(7, "%s: ct=%p, status=0x%lX, tuple=" FMT_TUPLE + " - unknown expect\n", + __func__, ct, ct->status, ARG_TUPLE(orig)); + return; + +alter: + /* Never alter conntrack for non-NAT conns */ + if (IP_VS_FWD_METHOD(cp) == IP_VS_CONN_F_MASQ) + nf_conntrack_alter_reply(ct, &new_reply); + ip_vs_conn_put(cp); + return; +} + +/* + * Create NF conntrack expectation with wildcard (optional) source port. + * Then the default callback function will alter the reply and will confirm + * the conntrack entry when the first packet comes. + * Use port 0 to expect connection from any port. + */ +void ip_vs_nfct_expect_related(struct sk_buff *skb, struct nf_conn *ct, + struct ip_vs_conn *cp, u_int8_t proto, + const __be16 port, int from_rs) +{ + struct nf_conntrack_expect *exp; + + if (ct == NULL) + return; + + exp = nf_ct_expect_alloc(ct); + if (!exp) + return; + + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + from_rs ? &cp->daddr : &cp->caddr, + from_rs ? &cp->caddr : &cp->vaddr, + proto, port ? &port : NULL, + from_rs ? 
&cp->cport : &cp->vport); + + exp->expectfn = ip_vs_nfct_expect_callback; + + IP_VS_DBG_BUF(7, "%s: ct=%p, expect tuple=" FMT_TUPLE "\n", + __func__, ct, ARG_TUPLE(&exp->tuple)); + nf_ct_expect_related(exp, 0); + nf_ct_expect_put(exp); +} +EXPORT_SYMBOL(ip_vs_nfct_expect_related); + +/* + * Our connection was terminated, try to drop the conntrack immediately + */ +void ip_vs_conn_drop_conntrack(struct ip_vs_conn *cp) +{ + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + struct nf_conntrack_tuple tuple; + + if (!cp->cport) + return; + + tuple = (struct nf_conntrack_tuple) { + .dst = { .protonum = cp->protocol, .dir = IP_CT_DIR_ORIGINAL } }; + tuple.src.u3 = cp->caddr; + tuple.src.u.all = cp->cport; + tuple.src.l3num = cp->af; + tuple.dst.u3 = cp->vaddr; + tuple.dst.u.all = cp->vport; + + IP_VS_DBG_BUF(7, "%s: dropping conntrack for conn " FMT_CONN "\n", + __func__, ARG_CONN(cp)); + + h = nf_conntrack_find_get(cp->ipvs->net, &nf_ct_zone_dflt, &tuple); + if (h) { + ct = nf_ct_tuplehash_to_ctrack(h); + if (nf_ct_kill(ct)) { + IP_VS_DBG_BUF(7, "%s: ct=%p deleted for tuple=" + FMT_TUPLE "\n", + __func__, ct, ARG_TUPLE(&tuple)); + } else { + IP_VS_DBG_BUF(7, "%s: ct=%p, no conntrack for tuple=" + FMT_TUPLE "\n", + __func__, ct, ARG_TUPLE(&tuple)); + } + nf_ct_put(ct); + } else { + IP_VS_DBG_BUF(7, "%s: no conntrack for tuple=" FMT_TUPLE "\n", + __func__, ARG_TUPLE(&tuple)); + } +} + diff --git a/net/netfilter/ipvs/ip_vs_nq.c b/net/netfilter/ipvs/ip_vs_nq.c new file mode 100644 index 000000000..f56862a87 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_nq.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Never Queue scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + */ + +/* + * The NQ algorithm adopts a two-speed model. When there is an idle server + * available, the job will be sent to the idle server, instead of waiting + * for a fast one. When there is no idle server available, the job will be + * sent to the server that minimize its expected delay (The Shortest + * Expected Delay scheduling algorithm). + * + * See the following paper for more information: + * A. Weinrib and S. Shenker, Greed is not enough: Adaptive load sharing + * in large heterogeneous systems. In Proceedings IEEE INFOCOM'88, + * pages 986-994, 1988. + * + * Thanks must go to Marko Buuri for talking NQ to me. + * + * The difference between NQ and SED is that NQ can improve overall + * system utilization. + * + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + + +static inline int +ip_vs_nq_dest_overhead(struct ip_vs_dest *dest) +{ + /* + * We only use the active connection number in the cost + * calculation here. + */ + return atomic_read(&dest->activeconns) + 1; +} + + +/* + * Weighted Least Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_nq_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *least = NULL; + int loh = 0, doh; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* + * We calculate the load of each dest server as follows: + * (server expected overhead) / dest->weight + * + * Remember -- no floats in kernel mode!!! + * The comparison of h1*w2 > h2*w1 is equivalent to that of + * h1/w1 > h2/w2 + * if every weight is larger than zero. + * + * The server with weight=0 is quiesced and will not receive any + * new connections. 
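For illustration (numbers assumed): if the current least has overhead 4 with weight 1 and a candidate has overhead 6 with weight 2, then 4 * 2 > 6 * 1, i.e. 4/1 > 6/2, so the candidate's lower expected delay wins and it becomes the new least.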
+ */ + + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + + if (dest->flags & IP_VS_DEST_F_OVERLOAD || + !atomic_read(&dest->weight)) + continue; + + doh = ip_vs_nq_dest_overhead(dest); + + /* return the server directly if it is idle */ + if (atomic_read(&dest->activeconns) == 0) { + least = dest; + loh = doh; + goto out; + } + + if (!least || + ((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight))) { + least = dest; + loh = doh; + } + } + + if (!least) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + out: + IP_VS_DBG_BUF(6, "NQ: server %s:%u " + "activeconns %d refcnt %d weight %d overhead %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + + return least; +} + + +static struct ip_vs_scheduler ip_vs_nq_scheduler = +{ + .name = "nq", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_nq_scheduler.n_list), + .schedule = ip_vs_nq_schedule, +}; + + +static int __init ip_vs_nq_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_nq_scheduler); +} + +static void __exit ip_vs_nq_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_nq_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_nq_init); +module_exit(ip_vs_nq_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_ovf.c b/net/netfilter/ipvs/ip_vs_ovf.c new file mode 100644 index 000000000..c03066fdd --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_ovf.c @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Overflow-Connection Scheduling module + * + * Authors: Raducu Deaconu + * + * Scheduler implements "overflow" loadbalancing according to number of active + * connections , will keep all connections to the node with the highest weight + * and overflow to the next node if the number of connections exceeds the node's + * weight. 
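For example (weights assumed), with two nodes weighted 100 and 50, new connections go to the weight-100 node until its active-connection count exceeds 100, after which they spill over to the weight-50 node.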
+ * Note that this scheduler might not be suitable for UDP because it only uses + * active connections + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + +/* OVF Connection scheduling */ +static struct ip_vs_dest * +ip_vs_ovf_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *h = NULL; + int hw = 0, w; + + IP_VS_DBG(6, "ip_vs_ovf_schedule(): Scheduling...\n"); + /* select the node with highest weight, go to next in line if active + * connections exceed weight + */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + w = atomic_read(&dest->weight); + if ((dest->flags & IP_VS_DEST_F_OVERLOAD) || + atomic_read(&dest->activeconns) > w || + w == 0) + continue; + if (!h || w > hw) { + h = dest; + hw = w; + } + } + + if (h) { + IP_VS_DBG_BUF(6, "OVF: server %s:%u active %d w %d\n", + IP_VS_DBG_ADDR(h->af, &h->addr), + ntohs(h->port), + atomic_read(&h->activeconns), + atomic_read(&h->weight)); + return h; + } + + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; +} + +static struct ip_vs_scheduler ip_vs_ovf_scheduler = { + .name = "ovf", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_ovf_scheduler.n_list), + .schedule = ip_vs_ovf_schedule, +}; + +static int __init ip_vs_ovf_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_ovf_scheduler); +} + +static void __exit ip_vs_ovf_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_ovf_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_ovf_init); +module_exit(ip_vs_ovf_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_pe.c b/net/netfilter/ipvs/ip_vs_pe.c new file mode 100644 index 000000000..166c669f0 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_pe.c @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include + +#include + +/* IPVS pe list */ +static LIST_HEAD(ip_vs_pe); + +/* semaphore for IPVS PEs. */ +static DEFINE_MUTEX(ip_vs_pe_mutex); + +/* Get pe in the pe list by name */ +struct ip_vs_pe *__ip_vs_pe_getbyname(const char *pe_name) +{ + struct ip_vs_pe *pe; + + IP_VS_DBG(10, "%s(): pe_name \"%s\"\n", __func__, + pe_name); + + rcu_read_lock(); + list_for_each_entry_rcu(pe, &ip_vs_pe, n_list) { + /* Test and get the modules atomically */ + if (pe->module && + !try_module_get(pe->module)) { + /* This pe is just deleted */ + continue; + } + if (strcmp(pe_name, pe->name)==0) { + /* HIT */ + rcu_read_unlock(); + return pe; + } + module_put(pe->module); + } + rcu_read_unlock(); + + return NULL; +} + +/* Lookup pe and try to load it if it doesn't exist */ +struct ip_vs_pe *ip_vs_pe_getbyname(const char *name) +{ + struct ip_vs_pe *pe; + + /* Search for the pe by name */ + pe = __ip_vs_pe_getbyname(name); + + /* If pe not found, load the module and search again */ + if (!pe) { + request_module("ip_vs_pe_%s", name); + pe = __ip_vs_pe_getbyname(name); + } + + return pe; +} + +/* Register a pe in the pe list */ +int register_ip_vs_pe(struct ip_vs_pe *pe) +{ + struct ip_vs_pe *tmp; + + /* increase the module use count */ + if (!ip_vs_use_count_inc()) + return -ENOENT; + + mutex_lock(&ip_vs_pe_mutex); + /* Make sure that the pe with this name doesn't exist + * in the pe list. 
+ */ + list_for_each_entry(tmp, &ip_vs_pe, n_list) { + if (strcmp(tmp->name, pe->name) == 0) { + mutex_unlock(&ip_vs_pe_mutex); + ip_vs_use_count_dec(); + pr_err("%s(): [%s] pe already existed " + "in the system\n", __func__, pe->name); + return -EINVAL; + } + } + /* Add it into the d-linked pe list */ + list_add_rcu(&pe->n_list, &ip_vs_pe); + mutex_unlock(&ip_vs_pe_mutex); + + pr_info("[%s] pe registered.\n", pe->name); + + return 0; +} +EXPORT_SYMBOL_GPL(register_ip_vs_pe); + +/* Unregister a pe from the pe list */ +int unregister_ip_vs_pe(struct ip_vs_pe *pe) +{ + mutex_lock(&ip_vs_pe_mutex); + /* Remove it from the d-linked pe list */ + list_del_rcu(&pe->n_list); + mutex_unlock(&ip_vs_pe_mutex); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + + pr_info("[%s] pe unregistered.\n", pe->name); + + return 0; +} +EXPORT_SYMBOL_GPL(unregister_ip_vs_pe); diff --git a/net/netfilter/ipvs/ip_vs_pe_sip.c b/net/netfilter/ipvs/ip_vs_pe_sip.c new file mode 100644 index 000000000..0ac6705a6 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_pe_sip.c @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include +#include +#include + +#ifdef CONFIG_IP_VS_DEBUG +static const char *ip_vs_dbg_callid(char *buf, size_t buf_len, + const char *callid, size_t callid_len, + int *idx) +{ + size_t max_len = 64; + size_t len = min3(max_len, callid_len, buf_len - *idx - 1); + memcpy(buf + *idx, callid, len); + buf[*idx+len] = '\0'; + *idx += len + 1; + return buf + *idx - len; +} + +#define IP_VS_DEBUG_CALLID(callid, len) \ + ip_vs_dbg_callid(ip_vs_dbg_buf, sizeof(ip_vs_dbg_buf), \ + callid, len, &ip_vs_dbg_idx) +#endif + +static int get_callid(const char *dptr, unsigned int dataoff, + unsigned int datalen, + unsigned int *matchoff, unsigned int *matchlen) +{ + /* Find callid */ + while (1) { + int ret = ct_sip_get_header(NULL, dptr, dataoff, datalen, + SIP_HDR_CALL_ID, matchoff, + matchlen); + if (ret > 0) + break; + if (!ret) + return -EINVAL; + dataoff += *matchoff; + } + + /* Too large is useless */ + if (*matchlen > IP_VS_PEDATA_MAXLEN) + return -EINVAL; + + /* SIP headers are always followed by a line terminator */ + if (*matchoff + *matchlen == datalen) + return -EINVAL; + + /* RFC 2543 allows lines to be terminated with CR, LF or CRLF, + * RFC 3261 allows only CRLF, we support both. */ + if (*(dptr + *matchoff + *matchlen) != '\r' && + *(dptr + *matchoff + *matchlen) != '\n') + return -EINVAL; + + IP_VS_DBG_BUF(9, "SIP callid %s (%d bytes)\n", + IP_VS_DEBUG_CALLID(dptr + *matchoff, *matchlen), + *matchlen); + return 0; +} + +static int +ip_vs_sip_fill_param(struct ip_vs_conn_param *p, struct sk_buff *skb) +{ + struct ip_vs_iphdr iph; + unsigned int dataoff, datalen, matchoff, matchlen; + const char *dptr; + int retc; + + retc = ip_vs_fill_iph_skb(p->af, skb, false, &iph); + + /* Only useful with UDP */ + if (!retc || iph.protocol != IPPROTO_UDP) + return -EINVAL; + /* todo: IPv6 fragments: + * I think this only should be done for the first fragment. 
/HS + */ + dataoff = iph.len + sizeof(struct udphdr); + + if (dataoff >= skb->len) + return -EINVAL; + retc = skb_linearize(skb); + if (retc < 0) + return retc; + dptr = skb->data + dataoff; + datalen = skb->len - dataoff; + + if (get_callid(dptr, 0, datalen, &matchoff, &matchlen)) + return -EINVAL; + + /* N.B: pe_data is only set on success, + * this allows fallback to the default persistence logic on failure + */ + p->pe_data = kmemdup(dptr + matchoff, matchlen, GFP_ATOMIC); + if (!p->pe_data) + return -ENOMEM; + + p->pe_data_len = matchlen; + + return 0; +} + +static bool ip_vs_sip_ct_match(const struct ip_vs_conn_param *p, + struct ip_vs_conn *ct) + +{ + bool ret = false; + + if (ct->af == p->af && + ip_vs_addr_equal(p->af, p->caddr, &ct->caddr) && + /* protocol should only be IPPROTO_IP if + * d_addr is a fwmark */ + ip_vs_addr_equal(p->protocol == IPPROTO_IP ? AF_UNSPEC : p->af, + p->vaddr, &ct->vaddr) && + ct->vport == p->vport && + ct->flags & IP_VS_CONN_F_TEMPLATE && + ct->protocol == p->protocol && + ct->pe_data && ct->pe_data_len == p->pe_data_len && + !memcmp(ct->pe_data, p->pe_data, p->pe_data_len)) + ret = true; + + IP_VS_DBG_BUF(9, "SIP template match %s %s->%s:%d %s\n", + ip_vs_proto_name(p->protocol), + IP_VS_DEBUG_CALLID(p->pe_data, p->pe_data_len), + IP_VS_DBG_ADDR(p->af, p->vaddr), ntohs(p->vport), + ret ? "hit" : "not hit"); + + return ret; +} + +static u32 ip_vs_sip_hashkey_raw(const struct ip_vs_conn_param *p, + u32 initval, bool inverse) +{ + return jhash(p->pe_data, p->pe_data_len, initval); +} + +static int ip_vs_sip_show_pe_data(const struct ip_vs_conn *cp, char *buf) +{ + memcpy(buf, cp->pe_data, cp->pe_data_len); + return cp->pe_data_len; +} + +static struct ip_vs_conn * +ip_vs_sip_conn_out(struct ip_vs_service *svc, + struct ip_vs_dest *dest, + struct sk_buff *skb, + const struct ip_vs_iphdr *iph, + __be16 dport, + __be16 cport) +{ + if (likely(iph->protocol == IPPROTO_UDP)) + return ip_vs_new_conn_out(svc, dest, skb, iph, dport, cport); + /* currently no need to handle other than UDP */ + return NULL; +} + +static struct ip_vs_pe ip_vs_sip_pe = +{ + .name = "sip", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_sip_pe.n_list), + .fill_param = ip_vs_sip_fill_param, + .ct_match = ip_vs_sip_ct_match, + .hashkey_raw = ip_vs_sip_hashkey_raw, + .show_pe_data = ip_vs_sip_show_pe_data, + .conn_out = ip_vs_sip_conn_out, +}; + +static int __init ip_vs_sip_init(void) +{ + return register_ip_vs_pe(&ip_vs_sip_pe); +} + +static void __exit ip_vs_sip_cleanup(void) +{ + unregister_ip_vs_pe(&ip_vs_sip_pe); + synchronize_rcu(); +} + +module_init(ip_vs_sip_init); +module_exit(ip_vs_sip_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_proto.c b/net/netfilter/ipvs/ip_vs_proto.c new file mode 100644 index 000000000..f100da4ba --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_proto.c @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_proto.c: transport protocol load balancing support for IPVS + * + * Authors: Wensong Zhang + * Julian Anastasov + * + * Changes: + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +/* + * IPVS protocols can only be registered/unregistered when the ipvs + * module is loaded/unloaded, so no lock is needed in accessing the + * ipvs protocol table. 
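The table is a small chained hash keyed by protocol number: IP_VS_PROTO_HASH() below is just (proto) & 31, so e.g. TCP (protocol 6) and UDP (protocol 17) land in buckets 6 and 17, and lookups simply walk the short per-bucket chain.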
+ */ + +#define IP_VS_PROTO_TAB_SIZE 32 /* must be power of 2 */ +#define IP_VS_PROTO_HASH(proto) ((proto) & (IP_VS_PROTO_TAB_SIZE-1)) + +static struct ip_vs_protocol *ip_vs_proto_table[IP_VS_PROTO_TAB_SIZE]; + +/* States for conn templates: NONE or words separated with ",", max 15 chars */ +static const char *ip_vs_ctpl_state_name_table[IP_VS_CTPL_S_LAST] = { + [IP_VS_CTPL_S_NONE] = "NONE", + [IP_VS_CTPL_S_ASSURED] = "ASSURED", +}; + +/* + * register an ipvs protocol + */ +static int __used __init register_ip_vs_protocol(struct ip_vs_protocol *pp) +{ + unsigned int hash = IP_VS_PROTO_HASH(pp->protocol); + + pp->next = ip_vs_proto_table[hash]; + ip_vs_proto_table[hash] = pp; + + if (pp->init != NULL) + pp->init(pp); + + return 0; +} + +/* + * register an ipvs protocols netns related data + */ +static int +register_ip_vs_proto_netns(struct netns_ipvs *ipvs, struct ip_vs_protocol *pp) +{ + unsigned int hash = IP_VS_PROTO_HASH(pp->protocol); + struct ip_vs_proto_data *pd = + kzalloc(sizeof(struct ip_vs_proto_data), GFP_KERNEL); + + if (!pd) + return -ENOMEM; + + pd->pp = pp; /* For speed issues */ + pd->next = ipvs->proto_data_table[hash]; + ipvs->proto_data_table[hash] = pd; + atomic_set(&pd->appcnt, 0); /* Init app counter */ + + if (pp->init_netns != NULL) { + int ret = pp->init_netns(ipvs, pd); + if (ret) { + /* unlink an free proto data */ + ipvs->proto_data_table[hash] = pd->next; + kfree(pd); + return ret; + } + } + + return 0; +} + +/* + * unregister an ipvs protocol + */ +static int unregister_ip_vs_protocol(struct ip_vs_protocol *pp) +{ + struct ip_vs_protocol **pp_p; + unsigned int hash = IP_VS_PROTO_HASH(pp->protocol); + + pp_p = &ip_vs_proto_table[hash]; + for (; *pp_p; pp_p = &(*pp_p)->next) { + if (*pp_p == pp) { + *pp_p = pp->next; + if (pp->exit != NULL) + pp->exit(pp); + return 0; + } + } + + return -ESRCH; +} + +/* + * unregister an ipvs protocols netns data + */ +static int +unregister_ip_vs_proto_netns(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + struct ip_vs_proto_data **pd_p; + unsigned int hash = IP_VS_PROTO_HASH(pd->pp->protocol); + + pd_p = &ipvs->proto_data_table[hash]; + for (; *pd_p; pd_p = &(*pd_p)->next) { + if (*pd_p == pd) { + *pd_p = pd->next; + if (pd->pp->exit_netns != NULL) + pd->pp->exit_netns(ipvs, pd); + kfree(pd); + return 0; + } + } + + return -ESRCH; +} + +/* + * get ip_vs_protocol object by its proto. 
+ */ +struct ip_vs_protocol * ip_vs_proto_get(unsigned short proto) +{ + struct ip_vs_protocol *pp; + unsigned int hash = IP_VS_PROTO_HASH(proto); + + for (pp = ip_vs_proto_table[hash]; pp; pp = pp->next) { + if (pp->protocol == proto) + return pp; + } + + return NULL; +} +EXPORT_SYMBOL(ip_vs_proto_get); + +/* + * get ip_vs_protocol object data by netns and proto + */ +struct ip_vs_proto_data * +ip_vs_proto_data_get(struct netns_ipvs *ipvs, unsigned short proto) +{ + struct ip_vs_proto_data *pd; + unsigned int hash = IP_VS_PROTO_HASH(proto); + + for (pd = ipvs->proto_data_table[hash]; pd; pd = pd->next) { + if (pd->pp->protocol == proto) + return pd; + } + + return NULL; +} +EXPORT_SYMBOL(ip_vs_proto_data_get); + +/* + * Propagate event for state change to all protocols + */ +void ip_vs_protocol_timeout_change(struct netns_ipvs *ipvs, int flags) +{ + struct ip_vs_proto_data *pd; + int i; + + for (i = 0; i < IP_VS_PROTO_TAB_SIZE; i++) { + for (pd = ipvs->proto_data_table[i]; pd; pd = pd->next) { + if (pd->pp->timeout_change) + pd->pp->timeout_change(pd, flags); + } + } +} + + +int * +ip_vs_create_timeout_table(int *table, int size) +{ + return kmemdup(table, size, GFP_KERNEL); +} + + +const char *ip_vs_state_name(const struct ip_vs_conn *cp) +{ + unsigned int state = cp->state; + struct ip_vs_protocol *pp; + + if (cp->flags & IP_VS_CONN_F_TEMPLATE) { + + if (state >= IP_VS_CTPL_S_LAST) + return "ERR!"; + return ip_vs_ctpl_state_name_table[state] ? : "?"; + } + pp = ip_vs_proto_get(cp->protocol); + if (pp == NULL || pp->state_name == NULL) + return (cp->protocol == IPPROTO_IP) ? "NONE" : "ERR!"; + return pp->state_name(state); +} + + +static void +ip_vs_tcpudp_debug_packet_v4(struct ip_vs_protocol *pp, + const struct sk_buff *skb, + int offset, + const char *msg) +{ + char buf[128]; + struct iphdr _iph, *ih; + + ih = skb_header_pointer(skb, offset, sizeof(_iph), &_iph); + if (ih == NULL) + sprintf(buf, "TRUNCATED"); + else if (ih->frag_off & htons(IP_OFFSET)) + sprintf(buf, "%pI4->%pI4 frag", &ih->saddr, &ih->daddr); + else { + __be16 _ports[2], *pptr; + + pptr = skb_header_pointer(skb, offset + ih->ihl*4, + sizeof(_ports), _ports); + if (pptr == NULL) + sprintf(buf, "TRUNCATED %pI4->%pI4", + &ih->saddr, &ih->daddr); + else + sprintf(buf, "%pI4:%u->%pI4:%u", + &ih->saddr, ntohs(pptr[0]), + &ih->daddr, ntohs(pptr[1])); + } + + pr_debug("%s: %s %s\n", msg, pp->name, buf); +} + +#ifdef CONFIG_IP_VS_IPV6 +static void +ip_vs_tcpudp_debug_packet_v6(struct ip_vs_protocol *pp, + const struct sk_buff *skb, + int offset, + const char *msg) +{ + char buf[192]; + struct ipv6hdr _iph, *ih; + + ih = skb_header_pointer(skb, offset, sizeof(_iph), &_iph); + if (ih == NULL) + sprintf(buf, "TRUNCATED"); + else if (ih->nexthdr == IPPROTO_FRAGMENT) + sprintf(buf, "%pI6c->%pI6c frag", &ih->saddr, &ih->daddr); + else { + __be16 _ports[2], *pptr; + + pptr = skb_header_pointer(skb, offset + sizeof(struct ipv6hdr), + sizeof(_ports), _ports); + if (pptr == NULL) + sprintf(buf, "TRUNCATED %pI6c->%pI6c", + &ih->saddr, &ih->daddr); + else + sprintf(buf, "%pI6c:%u->%pI6c:%u", + &ih->saddr, ntohs(pptr[0]), + &ih->daddr, ntohs(pptr[1])); + } + + pr_debug("%s: %s %s\n", msg, pp->name, buf); +} +#endif + + +void +ip_vs_tcpudp_debug_packet(int af, struct ip_vs_protocol *pp, + const struct sk_buff *skb, + int offset, + const char *msg) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + ip_vs_tcpudp_debug_packet_v6(pp, skb, offset, msg); + else +#endif + ip_vs_tcpudp_debug_packet_v4(pp, skb, offset, msg); +} + +/* + * per 
network name-space init + */ +int __net_init ip_vs_protocol_net_init(struct netns_ipvs *ipvs) +{ + int i, ret; + static struct ip_vs_protocol *protos[] = { +#ifdef CONFIG_IP_VS_PROTO_TCP + &ip_vs_protocol_tcp, +#endif +#ifdef CONFIG_IP_VS_PROTO_UDP + &ip_vs_protocol_udp, +#endif +#ifdef CONFIG_IP_VS_PROTO_SCTP + &ip_vs_protocol_sctp, +#endif +#ifdef CONFIG_IP_VS_PROTO_AH + &ip_vs_protocol_ah, +#endif +#ifdef CONFIG_IP_VS_PROTO_ESP + &ip_vs_protocol_esp, +#endif + }; + + for (i = 0; i < ARRAY_SIZE(protos); i++) { + ret = register_ip_vs_proto_netns(ipvs, protos[i]); + if (ret < 0) + goto cleanup; + } + return 0; + +cleanup: + ip_vs_protocol_net_cleanup(ipvs); + return ret; +} + +void __net_exit ip_vs_protocol_net_cleanup(struct netns_ipvs *ipvs) +{ + struct ip_vs_proto_data *pd; + int i; + + /* unregister all the ipvs proto data for this netns */ + for (i = 0; i < IP_VS_PROTO_TAB_SIZE; i++) { + while ((pd = ipvs->proto_data_table[i]) != NULL) + unregister_ip_vs_proto_netns(ipvs, pd); + } +} + +int __init ip_vs_protocol_init(void) +{ + char protocols[64]; +#define REGISTER_PROTOCOL(p) \ + do { \ + register_ip_vs_protocol(p); \ + strcat(protocols, ", "); \ + strcat(protocols, (p)->name); \ + } while (0) + + protocols[0] = '\0'; + protocols[2] = '\0'; +#ifdef CONFIG_IP_VS_PROTO_TCP + REGISTER_PROTOCOL(&ip_vs_protocol_tcp); +#endif +#ifdef CONFIG_IP_VS_PROTO_UDP + REGISTER_PROTOCOL(&ip_vs_protocol_udp); +#endif +#ifdef CONFIG_IP_VS_PROTO_SCTP + REGISTER_PROTOCOL(&ip_vs_protocol_sctp); +#endif +#ifdef CONFIG_IP_VS_PROTO_AH + REGISTER_PROTOCOL(&ip_vs_protocol_ah); +#endif +#ifdef CONFIG_IP_VS_PROTO_ESP + REGISTER_PROTOCOL(&ip_vs_protocol_esp); +#endif + pr_info("Registered protocols (%s)\n", &protocols[2]); + + return 0; +} + + +void ip_vs_protocol_cleanup(void) +{ + struct ip_vs_protocol *pp; + int i; + + /* unregister all the ipvs protocols */ + for (i = 0; i < IP_VS_PROTO_TAB_SIZE; i++) { + while ((pp = ip_vs_proto_table[i]) != NULL) + unregister_ip_vs_protocol(pp); + } +} diff --git a/net/netfilter/ipvs/ip_vs_proto_ah_esp.c b/net/netfilter/ipvs/ip_vs_proto_ah_esp.c new file mode 100644 index 000000000..89602c16f --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_proto_ah_esp.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * ip_vs_proto_ah_esp.c: AH/ESP IPSec load balancing support for IPVS + * + * Authors: Julian Anastasov , February 2002 + * Wensong Zhang + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include + +#include + + +/* TODO: + +struct isakmp_hdr { + __u8 icookie[8]; + __u8 rcookie[8]; + __u8 np; + __u8 version; + __u8 xchgtype; + __u8 flags; + __u32 msgid; + __u32 length; +}; + +*/ + +#define PORT_ISAKMP 500 + +static void +ah_esp_conn_fill_param_proto(struct netns_ipvs *ipvs, int af, + const struct ip_vs_iphdr *iph, + struct ip_vs_conn_param *p) +{ + if (likely(!ip_vs_iph_inverse(iph))) + ip_vs_conn_fill_param(ipvs, af, IPPROTO_UDP, + &iph->saddr, htons(PORT_ISAKMP), + &iph->daddr, htons(PORT_ISAKMP), p); + else + ip_vs_conn_fill_param(ipvs, af, IPPROTO_UDP, + &iph->daddr, htons(PORT_ISAKMP), + &iph->saddr, htons(PORT_ISAKMP), p); +} + +static struct ip_vs_conn * +ah_esp_conn_in_get(struct netns_ipvs *ipvs, int af, const struct sk_buff *skb, + const struct ip_vs_iphdr *iph) +{ + struct ip_vs_conn *cp; + struct ip_vs_conn_param p; + + ah_esp_conn_fill_param_proto(ipvs, af, iph, &p); + cp = ip_vs_conn_in_get(&p); + if (!cp) { + /* + * We are not sure if the packet is from our + * 
service, so our conn_schedule hook should return NF_ACCEPT + */ + IP_VS_DBG_BUF(12, "Unknown ISAKMP entry for outin packet " + "%s%s %s->%s\n", + ip_vs_iph_icmp(iph) ? "ICMP+" : "", + ip_vs_proto_get(iph->protocol)->name, + IP_VS_DBG_ADDR(af, &iph->saddr), + IP_VS_DBG_ADDR(af, &iph->daddr)); + } + + return cp; +} + + +static struct ip_vs_conn * +ah_esp_conn_out_get(struct netns_ipvs *ipvs, int af, const struct sk_buff *skb, + const struct ip_vs_iphdr *iph) +{ + struct ip_vs_conn *cp; + struct ip_vs_conn_param p; + + ah_esp_conn_fill_param_proto(ipvs, af, iph, &p); + cp = ip_vs_conn_out_get(&p); + if (!cp) { + IP_VS_DBG_BUF(12, "Unknown ISAKMP entry for inout packet " + "%s%s %s->%s\n", + ip_vs_iph_icmp(iph) ? "ICMP+" : "", + ip_vs_proto_get(iph->protocol)->name, + IP_VS_DBG_ADDR(af, &iph->saddr), + IP_VS_DBG_ADDR(af, &iph->daddr)); + } + + return cp; +} + + +static int +ah_esp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, + struct ip_vs_proto_data *pd, + int *verdict, struct ip_vs_conn **cpp, + struct ip_vs_iphdr *iph) +{ + /* + * AH/ESP is only related traffic. Pass the packet to IP stack. + */ + *verdict = NF_ACCEPT; + return 0; +} + +#ifdef CONFIG_IP_VS_PROTO_AH +struct ip_vs_protocol ip_vs_protocol_ah = { + .name = "AH", + .protocol = IPPROTO_AH, + .num_states = 1, + .dont_defrag = 1, + .init = NULL, + .exit = NULL, + .conn_schedule = ah_esp_conn_schedule, + .conn_in_get = ah_esp_conn_in_get, + .conn_out_get = ah_esp_conn_out_get, + .snat_handler = NULL, + .dnat_handler = NULL, + .state_transition = NULL, + .register_app = NULL, + .unregister_app = NULL, + .app_conn_bind = NULL, + .debug_packet = ip_vs_tcpudp_debug_packet, + .timeout_change = NULL, /* ISAKMP */ +}; +#endif + +#ifdef CONFIG_IP_VS_PROTO_ESP +struct ip_vs_protocol ip_vs_protocol_esp = { + .name = "ESP", + .protocol = IPPROTO_ESP, + .num_states = 1, + .dont_defrag = 1, + .init = NULL, + .exit = NULL, + .conn_schedule = ah_esp_conn_schedule, + .conn_in_get = ah_esp_conn_in_get, + .conn_out_get = ah_esp_conn_out_get, + .snat_handler = NULL, + .dnat_handler = NULL, + .state_transition = NULL, + .register_app = NULL, + .unregister_app = NULL, + .app_conn_bind = NULL, + .debug_packet = ip_vs_tcpudp_debug_packet, + .timeout_change = NULL, /* ISAKMP */ +}; +#endif diff --git a/net/netfilter/ipvs/ip_vs_proto_sctp.c b/net/netfilter/ipvs/ip_vs_proto_sctp.c new file mode 100644 index 000000000..a0921adc3 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_proto_sctp.c @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int +sctp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int +sctp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, + struct ip_vs_proto_data *pd, + int *verdict, struct ip_vs_conn **cpp, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_service *svc; + struct sctp_chunkhdr _schunkh, *sch; + struct sctphdr *sh, _sctph; + __be16 _ports[2], *ports = NULL; + + if (likely(!ip_vs_iph_icmp(iph))) { + sh = skb_header_pointer(skb, iph->len, sizeof(_sctph), &_sctph); + if (sh) { + sch = skb_header_pointer(skb, iph->len + sizeof(_sctph), + sizeof(_schunkh), &_schunkh); + if (sch) { + if (sch->type == SCTP_CID_ABORT || + !(sysctl_sloppy_sctp(ipvs) || + sch->type == SCTP_CID_INIT)) + return 1; + ports = &sh->source; + } + } + } else { + ports = skb_header_pointer( + skb, iph->len, sizeof(_ports), &_ports); + } + + if (!ports) { + *verdict = NF_DROP; + return 0; + } + + 
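/* ports[0]/ports[1] now hold the SCTP source/destination ports (for ICMP errors, the ports from the embedded header). The service lookup below keys on the destination address and port for normal traffic, and on the source pair when the header is flagged as inverse (as can happen for ICMP-embedded headers). */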
if (likely(!ip_vs_iph_inverse(iph))) + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->daddr, ports[1]); + else + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->saddr, ports[0]); + if (svc) { + int ignored; + + if (ip_vs_todrop(ipvs)) { + /* + * It seems that we are very loaded. + * We have to drop this packet :( + */ + *verdict = NF_DROP; + return 0; + } + /* + * Let the virtual server select a real server for the + * incoming connection, and create a connection entry. + */ + *cpp = ip_vs_schedule(svc, skb, pd, &ignored, iph); + if (!*cpp && ignored <= 0) { + if (!ignored) + *verdict = ip_vs_leave(svc, skb, pd, iph); + else + *verdict = NF_DROP; + return 0; + } + } + /* NF_ACCEPT */ + return 1; +} + +static void sctp_nat_csum(struct sk_buff *skb, struct sctphdr *sctph, + unsigned int sctphoff) +{ + sctph->checksum = sctp_compute_cksum(skb, sctphoff); + skb->ip_summed = CHECKSUM_UNNECESSARY; +} + +static int +sctp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct sctphdr *sctph; + unsigned int sctphoff = iph->len; + bool payload_csum = false; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, sctphoff + sizeof(*sctph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!sctp_csum_check(cp->af, skb, pp)) + return 0; + + /* Call application helper if needed */ + ret = ip_vs_app_pkt_out(cp, skb, iph); + if (ret == 0) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 2) + payload_csum = true; + } + + sctph = (void *) skb_network_header(skb) + sctphoff; + + /* Only update csum if we really have to */ + if (sctph->source != cp->vport || payload_csum || + skb->ip_summed == CHECKSUM_PARTIAL) { + sctph->source = cp->vport; + sctp_nat_csum(skb, sctph, sctphoff); + } else { + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + return 1; +} + +static int +sctp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct sctphdr *sctph; + unsigned int sctphoff = iph->len; + bool payload_csum = false; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, sctphoff + sizeof(*sctph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!sctp_csum_check(cp->af, skb, pp)) + return 0; + + /* Call application helper if needed */ + ret = ip_vs_app_pkt_in(cp, skb, iph); + if (ret == 0) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 2) + payload_csum = true; + } + + sctph = (void *) skb_network_header(skb) + sctphoff; + + /* Only update csum if we really have to */ + if (sctph->dest != cp->dport || payload_csum || + (skb->ip_summed == CHECKSUM_PARTIAL && + !(skb_dst(skb)->dev->features & NETIF_F_SCTP_CRC))) { + sctph->dest = cp->dport; + sctp_nat_csum(skb, sctph, sctphoff); + } else if (skb->ip_summed != CHECKSUM_PARTIAL) { + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + return 1; +} + +static int +sctp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp) +{ + unsigned int sctphoff; + struct sctphdr *sh; + __le32 cmp, val; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + sctphoff = sizeof(struct ipv6hdr); + else +#endif + 
sctphoff = ip_hdrlen(skb); + + sh = (struct sctphdr *)(skb->data + sctphoff); + cmp = sh->checksum; + val = sctp_compute_cksum(skb, sctphoff); + + if (val != cmp) { + /* CRC failure, dump it. */ + IP_VS_DBG_RL_PKT(0, af, pp, skb, 0, + "Failed checksum for"); + return 0; + } + return 1; +} + +enum ipvs_sctp_event_t { + IP_VS_SCTP_DATA = 0, /* DATA, SACK, HEARTBEATs */ + IP_VS_SCTP_INIT, + IP_VS_SCTP_INIT_ACK, + IP_VS_SCTP_COOKIE_ECHO, + IP_VS_SCTP_COOKIE_ACK, + IP_VS_SCTP_SHUTDOWN, + IP_VS_SCTP_SHUTDOWN_ACK, + IP_VS_SCTP_SHUTDOWN_COMPLETE, + IP_VS_SCTP_ERROR, + IP_VS_SCTP_ABORT, + IP_VS_SCTP_EVENT_LAST +}; + +/* RFC 2960, 3.2 Chunk Field Descriptions */ +static __u8 sctp_events[] = { + [SCTP_CID_DATA] = IP_VS_SCTP_DATA, + [SCTP_CID_INIT] = IP_VS_SCTP_INIT, + [SCTP_CID_INIT_ACK] = IP_VS_SCTP_INIT_ACK, + [SCTP_CID_SACK] = IP_VS_SCTP_DATA, + [SCTP_CID_HEARTBEAT] = IP_VS_SCTP_DATA, + [SCTP_CID_HEARTBEAT_ACK] = IP_VS_SCTP_DATA, + [SCTP_CID_ABORT] = IP_VS_SCTP_ABORT, + [SCTP_CID_SHUTDOWN] = IP_VS_SCTP_SHUTDOWN, + [SCTP_CID_SHUTDOWN_ACK] = IP_VS_SCTP_SHUTDOWN_ACK, + [SCTP_CID_ERROR] = IP_VS_SCTP_ERROR, + [SCTP_CID_COOKIE_ECHO] = IP_VS_SCTP_COOKIE_ECHO, + [SCTP_CID_COOKIE_ACK] = IP_VS_SCTP_COOKIE_ACK, + [SCTP_CID_ECN_ECNE] = IP_VS_SCTP_DATA, + [SCTP_CID_ECN_CWR] = IP_VS_SCTP_DATA, + [SCTP_CID_SHUTDOWN_COMPLETE] = IP_VS_SCTP_SHUTDOWN_COMPLETE, +}; + +/* SCTP States: + * See RFC 2960, 4. SCTP Association State Diagram + * + * New states (not in diagram): + * - INIT1 state: use shorter timeout for dropped INIT packets + * - REJECTED state: use shorter timeout if INIT is rejected with ABORT + * - INIT, COOKIE_SENT, COOKIE_REPLIED, COOKIE states: for better debugging + * + * The states are as seen in real server. In the diagram, INIT1, INIT, + * COOKIE_SENT and COOKIE_REPLIED processing happens in CLOSED state. 
+ * + * States as per packets from client (C) and server (S): + * + * Setup of client connection: + * IP_VS_SCTP_S_INIT1: First C:INIT sent, wait for S:INIT-ACK + * IP_VS_SCTP_S_INIT: Next C:INIT sent, wait for S:INIT-ACK + * IP_VS_SCTP_S_COOKIE_SENT: S:INIT-ACK sent, wait for C:COOKIE-ECHO + * IP_VS_SCTP_S_COOKIE_REPLIED: C:COOKIE-ECHO sent, wait for S:COOKIE-ACK + * + * Setup of server connection: + * IP_VS_SCTP_S_COOKIE_WAIT: S:INIT sent, wait for C:INIT-ACK + * IP_VS_SCTP_S_COOKIE: C:INIT-ACK sent, wait for S:COOKIE-ECHO + * IP_VS_SCTP_S_COOKIE_ECHOED: S:COOKIE-ECHO sent, wait for C:COOKIE-ACK + */ + +#define sNO IP_VS_SCTP_S_NONE +#define sI1 IP_VS_SCTP_S_INIT1 +#define sIN IP_VS_SCTP_S_INIT +#define sCS IP_VS_SCTP_S_COOKIE_SENT +#define sCR IP_VS_SCTP_S_COOKIE_REPLIED +#define sCW IP_VS_SCTP_S_COOKIE_WAIT +#define sCO IP_VS_SCTP_S_COOKIE +#define sCE IP_VS_SCTP_S_COOKIE_ECHOED +#define sES IP_VS_SCTP_S_ESTABLISHED +#define sSS IP_VS_SCTP_S_SHUTDOWN_SENT +#define sSR IP_VS_SCTP_S_SHUTDOWN_RECEIVED +#define sSA IP_VS_SCTP_S_SHUTDOWN_ACK_SENT +#define sRJ IP_VS_SCTP_S_REJECTED +#define sCL IP_VS_SCTP_S_CLOSED + +static const __u8 sctp_states + [IP_VS_DIR_LAST][IP_VS_SCTP_EVENT_LAST][IP_VS_SCTP_S_LAST] = { + { /* INPUT */ +/* sNO, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL*/ +/* d */{sES, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* i */{sI1, sIN, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sIN, sIN}, +/* i_a */{sCW, sCW, sCW, sCS, sCR, sCO, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_e */{sCR, sIN, sIN, sCR, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_a */{sES, sI1, sIN, sCS, sCR, sCW, sCO, sES, sES, sSS, sSR, sSA, sRJ, sCL}, +/* s */{sSR, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sSR, sSS, sSR, sSA, sRJ, sCL}, +/* s_a */{sCL, sIN, sIN, sCS, sCR, sCW, sCO, sCE, sES, sCL, sSR, sCL, sRJ, sCL}, +/* s_c */{sCL, sCL, sCL, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sCL, sRJ, sCL}, +/* err */{sCL, sI1, sIN, sCS, sCR, sCW, sCO, sCL, sES, sSS, sSR, sSA, sRJ, sCL}, +/* ab */{sCL, sCL, sCL, sCL, sCL, sRJ, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, + }, + { /* OUTPUT */ +/* sNO, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL*/ +/* d */{sES, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* i */{sCW, sCW, sCW, sCW, sCW, sCW, sCW, sCW, sES, sCW, sCW, sCW, sCW, sCW}, +/* i_a */{sCS, sCS, sCS, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_e */{sCE, sCE, sCE, sCE, sCE, sCE, sCE, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_a */{sES, sES, sES, sES, sES, sES, sES, sES, sES, sSS, sSR, sSA, sRJ, sCL}, +/* s */{sSS, sSS, sSS, sSS, sSS, sSS, sSS, sSS, sSS, sSS, sSR, sSA, sRJ, sCL}, +/* s_a */{sSA, sSA, sSA, sSA, sSA, sCW, sCO, sCE, sES, sSA, sSA, sSA, sRJ, sCL}, +/* s_c */{sCL, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* err */{sCL, sCL, sCL, sCL, sCL, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* ab */{sCL, sRJ, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, + }, + { /* INPUT-ONLY */ +/* sNO, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL*/ +/* d */{sES, sI1, sIN, sCS, sCR, sES, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* i */{sI1, sIN, sIN, sIN, sIN, sIN, sCO, sCE, sES, sSS, sSR, sSA, sIN, sIN}, +/* i_a */{sCE, sCE, sCE, sCE, sCE, sCE, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_e */{sES, sES, sES, sES, sES, sES, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* c_a */{sES, sI1, sIN, sES, sES, sCW, sES, sES, sES, sSS, sSR, sSA, sRJ, sCL}, +/* s 
*/{sSR, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sSR, sSS, sSR, sSA, sRJ, sCL}, +/* s_a */{sCL, sIN, sIN, sCS, sCR, sCW, sCO, sCE, sCL, sCL, sSR, sCL, sRJ, sCL}, +/* s_c */{sCL, sCL, sCL, sCL, sCL, sCW, sCO, sCE, sES, sSS, sCL, sCL, sRJ, sCL}, +/* err */{sCL, sI1, sIN, sCS, sCR, sCW, sCO, sCE, sES, sSS, sSR, sSA, sRJ, sCL}, +/* ab */{sCL, sCL, sCL, sCL, sCL, sRJ, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, + }, +}; + +#define IP_VS_SCTP_MAX_RTO ((60 + 1) * HZ) + +/* Timeout table[state] */ +static const int sctp_timeouts[IP_VS_SCTP_S_LAST + 1] = { + [IP_VS_SCTP_S_NONE] = 2 * HZ, + [IP_VS_SCTP_S_INIT1] = (0 + 3 + 1) * HZ, + [IP_VS_SCTP_S_INIT] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_COOKIE_SENT] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_COOKIE_REPLIED] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_COOKIE_WAIT] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_COOKIE] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_COOKIE_ECHOED] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_ESTABLISHED] = 15 * 60 * HZ, + [IP_VS_SCTP_S_SHUTDOWN_SENT] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_SHUTDOWN_RECEIVED] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_SHUTDOWN_ACK_SENT] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_REJECTED] = (0 + 3 + 1) * HZ, + [IP_VS_SCTP_S_CLOSED] = IP_VS_SCTP_MAX_RTO, + [IP_VS_SCTP_S_LAST] = 2 * HZ, +}; + +static const char *sctp_state_name_table[IP_VS_SCTP_S_LAST + 1] = { + [IP_VS_SCTP_S_NONE] = "NONE", + [IP_VS_SCTP_S_INIT1] = "INIT1", + [IP_VS_SCTP_S_INIT] = "INIT", + [IP_VS_SCTP_S_COOKIE_SENT] = "C-SENT", + [IP_VS_SCTP_S_COOKIE_REPLIED] = "C-REPLIED", + [IP_VS_SCTP_S_COOKIE_WAIT] = "C-WAIT", + [IP_VS_SCTP_S_COOKIE] = "COOKIE", + [IP_VS_SCTP_S_COOKIE_ECHOED] = "C-ECHOED", + [IP_VS_SCTP_S_ESTABLISHED] = "ESTABLISHED", + [IP_VS_SCTP_S_SHUTDOWN_SENT] = "S-SENT", + [IP_VS_SCTP_S_SHUTDOWN_RECEIVED] = "S-RECEIVED", + [IP_VS_SCTP_S_SHUTDOWN_ACK_SENT] = "S-ACK-SENT", + [IP_VS_SCTP_S_REJECTED] = "REJECTED", + [IP_VS_SCTP_S_CLOSED] = "CLOSED", + [IP_VS_SCTP_S_LAST] = "BUG!", +}; + + +static const char *sctp_state_name(int state) +{ + if (state >= IP_VS_SCTP_S_LAST) + return "ERR!"; + if (sctp_state_name_table[state]) + return sctp_state_name_table[state]; + return "?"; +} + +static inline void +set_sctp_state(struct ip_vs_proto_data *pd, struct ip_vs_conn *cp, + int direction, const struct sk_buff *skb) +{ + struct sctp_chunkhdr _sctpch, *sch; + unsigned char chunk_type; + int event, next_state; + int ihl, cofs; + +#ifdef CONFIG_IP_VS_IPV6 + ihl = cp->af == AF_INET ? ip_hdrlen(skb) : sizeof(struct ipv6hdr); +#else + ihl = ip_hdrlen(skb); +#endif + + cofs = ihl + sizeof(struct sctphdr); + sch = skb_header_pointer(skb, cofs, sizeof(_sctpch), &_sctpch); + if (sch == NULL) + return; + + chunk_type = sch->type; + /* + * Section 3: Multiple chunks can be bundled into one SCTP packet + * up to the MTU size, except for the INIT, INIT ACK, and + * SHUTDOWN COMPLETE chunks. These chunks MUST NOT be bundled with + * any other chunk in a packet. + * + * Section 3.3.7: DATA chunks MUST NOT be bundled with ABORT. Control + * chunks (except for INIT, INIT ACK, and SHUTDOWN COMPLETE) MAY be + * bundled with an ABORT, but they MUST be placed before the ABORT + * in the SCTP packet or they will be ignored by the receiver. 
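This is why the code below, after seeing a COOKIE ECHO or COOKIE ACK chunk, also peeks at the next bundled chunk and, if that chunk is an ABORT, uses ABORT as the effective event for the state transition.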
+ */ + if ((sch->type == SCTP_CID_COOKIE_ECHO) || + (sch->type == SCTP_CID_COOKIE_ACK)) { + int clen = ntohs(sch->length); + + if (clen >= sizeof(_sctpch)) { + sch = skb_header_pointer(skb, cofs + ALIGN(clen, 4), + sizeof(_sctpch), &_sctpch); + if (sch && sch->type == SCTP_CID_ABORT) + chunk_type = sch->type; + } + } + + event = (chunk_type < sizeof(sctp_events)) ? + sctp_events[chunk_type] : IP_VS_SCTP_DATA; + + /* Update direction to INPUT_ONLY if necessary + * or delete NO_OUTPUT flag if output packet detected + */ + if (cp->flags & IP_VS_CONN_F_NOOUTPUT) { + if (direction == IP_VS_DIR_OUTPUT) + cp->flags &= ~IP_VS_CONN_F_NOOUTPUT; + else + direction = IP_VS_DIR_INPUT_ONLY; + } + + next_state = sctp_states[direction][event][cp->state]; + + if (next_state != cp->state) { + struct ip_vs_dest *dest = cp->dest; + + IP_VS_DBG_BUF(8, "%s %s %s:%d->" + "%s:%d state: %s->%s conn->refcnt:%d\n", + pd->pp->name, + ((direction == IP_VS_DIR_OUTPUT) ? + "output " : "input "), + IP_VS_DBG_ADDR(cp->daf, &cp->daddr), + ntohs(cp->dport), + IP_VS_DBG_ADDR(cp->af, &cp->caddr), + ntohs(cp->cport), + sctp_state_name(cp->state), + sctp_state_name(next_state), + refcount_read(&cp->refcnt)); + if (dest) { + if (!(cp->flags & IP_VS_CONN_F_INACTIVE) && + (next_state != IP_VS_SCTP_S_ESTABLISHED)) { + atomic_dec(&dest->activeconns); + atomic_inc(&dest->inactconns); + cp->flags |= IP_VS_CONN_F_INACTIVE; + } else if ((cp->flags & IP_VS_CONN_F_INACTIVE) && + (next_state == IP_VS_SCTP_S_ESTABLISHED)) { + atomic_inc(&dest->activeconns); + atomic_dec(&dest->inactconns); + cp->flags &= ~IP_VS_CONN_F_INACTIVE; + } + } + if (next_state == IP_VS_SCTP_S_ESTABLISHED) + ip_vs_control_assure_ct(cp); + } + if (likely(pd)) + cp->timeout = pd->timeout_table[cp->state = next_state]; + else /* What to do ? 
*/ + cp->timeout = sctp_timeouts[cp->state = next_state]; +} + +static void +sctp_state_transition(struct ip_vs_conn *cp, int direction, + const struct sk_buff *skb, struct ip_vs_proto_data *pd) +{ + spin_lock_bh(&cp->lock); + set_sctp_state(pd, cp, direction, skb); + spin_unlock_bh(&cp->lock); +} + +static inline __u16 sctp_app_hashkey(__be16 port) +{ + return (((__force u16)port >> SCTP_APP_TAB_BITS) ^ (__force u16)port) + & SCTP_APP_TAB_MASK; +} + +static int sctp_register_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_app *i; + __u16 hash; + __be16 port = inc->port; + int ret = 0; + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_SCTP); + + hash = sctp_app_hashkey(port); + + list_for_each_entry(i, &ipvs->sctp_apps[hash], p_list) { + if (i->port == port) { + ret = -EEXIST; + goto out; + } + } + list_add_rcu(&inc->p_list, &ipvs->sctp_apps[hash]); + atomic_inc(&pd->appcnt); +out: + + return ret; +} + +static void sctp_unregister_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_SCTP); + + atomic_dec(&pd->appcnt); + list_del_rcu(&inc->p_list); +} + +static int sctp_app_conn_bind(struct ip_vs_conn *cp) +{ + struct netns_ipvs *ipvs = cp->ipvs; + int hash; + struct ip_vs_app *inc; + int result = 0; + + /* Default binding: bind app only for NAT */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + return 0; + /* Lookup application incarnations and bind the right one */ + hash = sctp_app_hashkey(cp->vport); + + list_for_each_entry_rcu(inc, &ipvs->sctp_apps[hash], p_list) { + if (inc->port == cp->vport) { + if (unlikely(!ip_vs_app_inc_get(inc))) + break; + + IP_VS_DBG_BUF(9, "%s: Binding conn %s:%u->" + "%s:%u to app %s on port %u\n", + __func__, + IP_VS_DBG_ADDR(cp->af, &cp->caddr), + ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), + ntohs(cp->vport), + inc->name, ntohs(inc->port)); + cp->app = inc; + if (inc->init_conn) + result = inc->init_conn(inc, cp); + break; + } + } + + return result; +} + +/* --------------------------------------------- + * timeouts is netns related now. 
+ * --------------------------------------------- + */ +static int __ip_vs_sctp_init(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + ip_vs_init_hash_table(ipvs->sctp_apps, SCTP_APP_TAB_SIZE); + pd->timeout_table = ip_vs_create_timeout_table((int *)sctp_timeouts, + sizeof(sctp_timeouts)); + if (!pd->timeout_table) + return -ENOMEM; + return 0; +} + +static void __ip_vs_sctp_exit(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + kfree(pd->timeout_table); +} + +struct ip_vs_protocol ip_vs_protocol_sctp = { + .name = "SCTP", + .protocol = IPPROTO_SCTP, + .num_states = IP_VS_SCTP_S_LAST, + .dont_defrag = 0, + .init = NULL, + .exit = NULL, + .init_netns = __ip_vs_sctp_init, + .exit_netns = __ip_vs_sctp_exit, + .register_app = sctp_register_app, + .unregister_app = sctp_unregister_app, + .conn_schedule = sctp_conn_schedule, + .conn_in_get = ip_vs_conn_in_get_proto, + .conn_out_get = ip_vs_conn_out_get_proto, + .snat_handler = sctp_snat_handler, + .dnat_handler = sctp_dnat_handler, + .state_name = sctp_state_name, + .state_transition = sctp_state_transition, + .app_conn_bind = sctp_app_conn_bind, + .debug_packet = ip_vs_tcpudp_debug_packet, + .timeout_change = NULL, +}; diff --git a/net/netfilter/ipvs/ip_vs_proto_tcp.c b/net/netfilter/ipvs/ip_vs_proto_tcp.c new file mode 100644 index 000000000..7da51390c --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_proto_tcp.c @@ -0,0 +1,746 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_proto_tcp.c: TCP load balancing support for IPVS + * + * Authors: Wensong Zhang + * Julian Anastasov + * + * Changes: Hans Schillstrom + * + * Network name space (netns) aware. + * Global data moved to netns i.e struct netns_ipvs + * tcp_timeouts table has copy per netns in a hash table per + * protocol ip_vs_proto_data and is handled by netns + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include /* for tcphdr */ +#include +#include /* for csum_tcpudp_magic */ +#include +#include +#include +#include + +#include + +static int +tcp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int +tcp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, + struct ip_vs_proto_data *pd, + int *verdict, struct ip_vs_conn **cpp, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_service *svc; + struct tcphdr _tcph, *th; + __be16 _ports[2], *ports = NULL; + + /* In the event of icmp, we're only guaranteed to have the first 8 + * bytes of the transport header, so we only check the rest of the + * TCP packet for non-ICMP packets + */ + if (likely(!ip_vs_iph_icmp(iph))) { + th = skb_header_pointer(skb, iph->len, sizeof(_tcph), &_tcph); + if (th) { + if (th->rst || !(sysctl_sloppy_tcp(ipvs) || th->syn)) + return 1; + ports = &th->source; + } + } else { + ports = skb_header_pointer( + skb, iph->len, sizeof(_ports), &_ports); + } + + if (!ports) { + *verdict = NF_DROP; + return 0; + } + + /* No !th->ack check to allow scheduling on SYN+ACK for Active FTP */ + + if (likely(!ip_vs_iph_inverse(iph))) + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->daddr, ports[1]); + else + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->saddr, ports[0]); + + if (svc) { + int ignored; + + if (ip_vs_todrop(ipvs)) { + /* + * It seems that we are very loaded. + * We have to drop this packet :( + */ + *verdict = NF_DROP; + return 0; + } + + /* + * Let the virtual server select a real server for the + * incoming connection, and create a connection entry. 
+ */ + *cpp = ip_vs_schedule(svc, skb, pd, &ignored, iph); + if (!*cpp && ignored <= 0) { + if (!ignored) + *verdict = ip_vs_leave(svc, skb, pd, iph); + else + *verdict = NF_DROP; + return 0; + } + } + /* NF_ACCEPT */ + return 1; +} + + +static inline void +tcp_fast_csum_update(int af, struct tcphdr *tcph, + const union nf_inet_addr *oldip, + const union nf_inet_addr *newip, + __be16 oldport, __be16 newport) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + tcph->check = + csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6, + ip_vs_check_diff2(oldport, newport, + ~csum_unfold(tcph->check)))); + else +#endif + tcph->check = + csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip, + ip_vs_check_diff2(oldport, newport, + ~csum_unfold(tcph->check)))); +} + + +static inline void +tcp_partial_csum_update(int af, struct tcphdr *tcph, + const union nf_inet_addr *oldip, + const union nf_inet_addr *newip, + __be16 oldlen, __be16 newlen) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + tcph->check = + ~csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6, + ip_vs_check_diff2(oldlen, newlen, + csum_unfold(tcph->check)))); + else +#endif + tcph->check = + ~csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip, + ip_vs_check_diff2(oldlen, newlen, + csum_unfold(tcph->check)))); +} + + +INDIRECT_CALLABLE_SCOPE int +tcp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct tcphdr *tcph; + unsigned int tcphoff = iph->len; + bool payload_csum = false; + int oldlen; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + oldlen = skb->len - tcphoff; + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, tcphoff + sizeof(*tcph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!tcp_csum_check(cp->af, skb, pp)) + return 0; + + /* Call application helper if needed */ + if (!(ret = ip_vs_app_pkt_out(cp, skb, iph))) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 1) + oldlen = skb->len - tcphoff; + else + payload_csum = true; + } + + tcph = (void *)skb_network_header(skb) + tcphoff; + tcph->source = cp->vport; + + /* Adjust TCP checksums */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + tcp_partial_csum_update(cp->af, tcph, &cp->daddr, &cp->vaddr, + htons(oldlen), + htons(skb->len - tcphoff)); + } else if (!payload_csum) { + /* Only port and addr are changed, do fast csum update */ + tcp_fast_csum_update(cp->af, tcph, &cp->daddr, &cp->vaddr, + cp->dport, cp->vport); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = cp->app ? 
+ CHECKSUM_UNNECESSARY : CHECKSUM_NONE; + } else { + /* full checksum calculation */ + tcph->check = 0; + skb->csum = skb_checksum(skb, tcphoff, skb->len - tcphoff, 0); +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + tcph->check = csum_ipv6_magic(&cp->vaddr.in6, + &cp->caddr.in6, + skb->len - tcphoff, + cp->protocol, skb->csum); + else +#endif + tcph->check = csum_tcpudp_magic(cp->vaddr.ip, + cp->caddr.ip, + skb->len - tcphoff, + cp->protocol, + skb->csum); + skb->ip_summed = CHECKSUM_UNNECESSARY; + + IP_VS_DBG(11, "O-pkt: %s O-csum=%d (+%zd)\n", + pp->name, tcph->check, + (char*)&(tcph->check) - (char*)tcph); + } + return 1; +} + + +static int +tcp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct tcphdr *tcph; + unsigned int tcphoff = iph->len; + bool payload_csum = false; + int oldlen; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + oldlen = skb->len - tcphoff; + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, tcphoff + sizeof(*tcph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!tcp_csum_check(cp->af, skb, pp)) + return 0; + + /* + * Attempt ip_vs_app call. + * It will fix ip_vs_conn and iph ack_seq stuff + */ + if (!(ret = ip_vs_app_pkt_in(cp, skb, iph))) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 1) + oldlen = skb->len - tcphoff; + else + payload_csum = true; + } + + tcph = (void *)skb_network_header(skb) + tcphoff; + tcph->dest = cp->dport; + + /* + * Adjust TCP checksums + */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + tcp_partial_csum_update(cp->af, tcph, &cp->vaddr, &cp->daddr, + htons(oldlen), + htons(skb->len - tcphoff)); + } else if (!payload_csum) { + /* Only port and addr are changed, do fast csum update */ + tcp_fast_csum_update(cp->af, tcph, &cp->vaddr, &cp->daddr, + cp->vport, cp->dport); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = cp->app ? + CHECKSUM_UNNECESSARY : CHECKSUM_NONE; + } else { + /* full checksum calculation */ + tcph->check = 0; + skb->csum = skb_checksum(skb, tcphoff, skb->len - tcphoff, 0); +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + tcph->check = csum_ipv6_magic(&cp->caddr.in6, + &cp->daddr.in6, + skb->len - tcphoff, + cp->protocol, skb->csum); + else +#endif + tcph->check = csum_tcpudp_magic(cp->caddr.ip, + cp->daddr.ip, + skb->len - tcphoff, + cp->protocol, + skb->csum); + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + return 1; +} + + +static int +tcp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp) +{ + unsigned int tcphoff; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + tcphoff = sizeof(struct ipv6hdr); + else +#endif + tcphoff = ip_hdrlen(skb); + + switch (skb->ip_summed) { + case CHECKSUM_NONE: + skb->csum = skb_checksum(skb, tcphoff, skb->len - tcphoff, 0); + fallthrough; + case CHECKSUM_COMPLETE: +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + if (csum_ipv6_magic(&ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, + skb->len - tcphoff, + ipv6_hdr(skb)->nexthdr, + skb->csum)) { + IP_VS_DBG_RL_PKT(0, af, pp, skb, 0, + "Failed checksum for"); + return 0; + } + } else +#endif + if (csum_tcpudp_magic(ip_hdr(skb)->saddr, + ip_hdr(skb)->daddr, + skb->len - tcphoff, + ip_hdr(skb)->protocol, + skb->csum)) { + IP_VS_DBG_RL_PKT(0, af, pp, skb, 0, + "Failed checksum for"); + return 0; + } + break; + default: + /* No need to checksum. 
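The fast path above (tcp_fast_csum_update() built on ip_vs_check_diff2()/ip_vs_check_diff4()) is essentially the incremental checksum update of RFC 1624: fold the old 16-bit words out of the ones'-complement sum and fold the new ones in, rather than re-summing the whole segment as the payload_csum branch must. A stand-alone userspace sketch of that arithmetic (not the kernel helpers):

#include <stdint.h>

/* HC' = ~(~HC + ~m + m')  -- RFC 1624, Eqn. 3, one 16-bit word at a time */
static uint16_t csum_update16(uint16_t check, uint16_t old_word, uint16_t new_word)
{
	uint32_t sum = (uint16_t)~check;

	sum += (uint16_t)~old_word;
	sum += new_word;
	sum = (sum & 0xffff) + (sum >> 16);	/* fold carries back in */
	sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)~sum;
}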
*/ + break; + } + + return 1; +} + + +#define TCP_DIR_INPUT 0 +#define TCP_DIR_OUTPUT 4 +#define TCP_DIR_INPUT_ONLY 8 + +static const int tcp_state_off[IP_VS_DIR_LAST] = { + [IP_VS_DIR_INPUT] = TCP_DIR_INPUT, + [IP_VS_DIR_OUTPUT] = TCP_DIR_OUTPUT, + [IP_VS_DIR_INPUT_ONLY] = TCP_DIR_INPUT_ONLY, +}; + +/* + * Timeout table[state] + */ +static const int tcp_timeouts[IP_VS_TCP_S_LAST+1] = { + [IP_VS_TCP_S_NONE] = 2*HZ, + [IP_VS_TCP_S_ESTABLISHED] = 15*60*HZ, + [IP_VS_TCP_S_SYN_SENT] = 2*60*HZ, + [IP_VS_TCP_S_SYN_RECV] = 1*60*HZ, + [IP_VS_TCP_S_FIN_WAIT] = 2*60*HZ, + [IP_VS_TCP_S_TIME_WAIT] = 2*60*HZ, + [IP_VS_TCP_S_CLOSE] = 10*HZ, + [IP_VS_TCP_S_CLOSE_WAIT] = 60*HZ, + [IP_VS_TCP_S_LAST_ACK] = 30*HZ, + [IP_VS_TCP_S_LISTEN] = 2*60*HZ, + [IP_VS_TCP_S_SYNACK] = 120*HZ, + [IP_VS_TCP_S_LAST] = 2*HZ, +}; + +static const char *const tcp_state_name_table[IP_VS_TCP_S_LAST+1] = { + [IP_VS_TCP_S_NONE] = "NONE", + [IP_VS_TCP_S_ESTABLISHED] = "ESTABLISHED", + [IP_VS_TCP_S_SYN_SENT] = "SYN_SENT", + [IP_VS_TCP_S_SYN_RECV] = "SYN_RECV", + [IP_VS_TCP_S_FIN_WAIT] = "FIN_WAIT", + [IP_VS_TCP_S_TIME_WAIT] = "TIME_WAIT", + [IP_VS_TCP_S_CLOSE] = "CLOSE", + [IP_VS_TCP_S_CLOSE_WAIT] = "CLOSE_WAIT", + [IP_VS_TCP_S_LAST_ACK] = "LAST_ACK", + [IP_VS_TCP_S_LISTEN] = "LISTEN", + [IP_VS_TCP_S_SYNACK] = "SYNACK", + [IP_VS_TCP_S_LAST] = "BUG!", +}; + +static const bool tcp_state_active_table[IP_VS_TCP_S_LAST] = { + [IP_VS_TCP_S_NONE] = false, + [IP_VS_TCP_S_ESTABLISHED] = true, + [IP_VS_TCP_S_SYN_SENT] = true, + [IP_VS_TCP_S_SYN_RECV] = true, + [IP_VS_TCP_S_FIN_WAIT] = false, + [IP_VS_TCP_S_TIME_WAIT] = false, + [IP_VS_TCP_S_CLOSE] = false, + [IP_VS_TCP_S_CLOSE_WAIT] = false, + [IP_VS_TCP_S_LAST_ACK] = false, + [IP_VS_TCP_S_LISTEN] = false, + [IP_VS_TCP_S_SYNACK] = true, +}; + +#define sNO IP_VS_TCP_S_NONE +#define sES IP_VS_TCP_S_ESTABLISHED +#define sSS IP_VS_TCP_S_SYN_SENT +#define sSR IP_VS_TCP_S_SYN_RECV +#define sFW IP_VS_TCP_S_FIN_WAIT +#define sTW IP_VS_TCP_S_TIME_WAIT +#define sCL IP_VS_TCP_S_CLOSE +#define sCW IP_VS_TCP_S_CLOSE_WAIT +#define sLA IP_VS_TCP_S_LAST_ACK +#define sLI IP_VS_TCP_S_LISTEN +#define sSA IP_VS_TCP_S_SYNACK + +struct tcp_states_t { + int next_state[IP_VS_TCP_S_LAST]; +}; + +static const char * tcp_state_name(int state) +{ + if (state >= IP_VS_TCP_S_LAST) + return "ERR!"; + return tcp_state_name_table[state] ? 
tcp_state_name_table[state] : "?"; +} + +static bool tcp_state_active(int state) +{ + if (state >= IP_VS_TCP_S_LAST) + return false; + return tcp_state_active_table[state]; +} + +static struct tcp_states_t tcp_states[] = { +/* INPUT */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSR, sES, sES, sSR, sSR, sSR, sSR, sSR, sSR, sSR, sSR }}, +/*fin*/ {{sCL, sCW, sSS, sTW, sTW, sTW, sCL, sCW, sLA, sLI, sTW }}, +/*ack*/ {{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sCL, sLI, sES }}, +/*rst*/ {{sCL, sCL, sCL, sSR, sCL, sCL, sCL, sCL, sLA, sLI, sSR }}, + +/* OUTPUT */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSS, sES, sSS, sSR, sSS, sSS, sSS, sSS, sSS, sLI, sSR }}, +/*fin*/ {{sTW, sFW, sSS, sTW, sFW, sTW, sCL, sTW, sLA, sLI, sTW }}, +/*ack*/ {{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sLA, sES, sES }}, +/*rst*/ {{sCL, sCL, sSS, sCL, sCL, sTW, sCL, sCL, sCL, sCL, sCL }}, + +/* INPUT-ONLY */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSR, sES, sES, sSR, sSR, sSR, sSR, sSR, sSR, sSR, sSR }}, +/*fin*/ {{sCL, sFW, sSS, sTW, sFW, sTW, sCL, sCW, sLA, sLI, sTW }}, +/*ack*/ {{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sCL, sLI, sES }}, +/*rst*/ {{sCL, sCL, sCL, sSR, sCL, sCL, sCL, sCL, sLA, sLI, sCL }}, +}; + +static struct tcp_states_t tcp_states_dos[] = { +/* INPUT */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSR, sES, sES, sSR, sSR, sSR, sSR, sSR, sSR, sSR, sSA }}, +/*fin*/ {{sCL, sCW, sSS, sTW, sTW, sTW, sCL, sCW, sLA, sLI, sSA }}, +/*ack*/ {{sES, sES, sSS, sSR, sFW, sTW, sCL, sCW, sCL, sLI, sSA }}, +/*rst*/ {{sCL, sCL, sCL, sSR, sCL, sCL, sCL, sCL, sLA, sLI, sCL }}, + +/* OUTPUT */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSS, sES, sSS, sSA, sSS, sSS, sSS, sSS, sSS, sLI, sSA }}, +/*fin*/ {{sTW, sFW, sSS, sTW, sFW, sTW, sCL, sTW, sLA, sLI, sTW }}, +/*ack*/ {{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sLA, sES, sES }}, +/*rst*/ {{sCL, sCL, sSS, sCL, sCL, sTW, sCL, sCL, sCL, sCL, sCL }}, + +/* INPUT-ONLY */ +/* sNO, sES, sSS, sSR, sFW, sTW, sCL, sCW, sLA, sLI, sSA */ +/*syn*/ {{sSA, sES, sES, sSR, sSA, sSA, sSA, sSA, sSA, sSA, sSA }}, +/*fin*/ {{sCL, sFW, sSS, sTW, sFW, sTW, sCL, sCW, sLA, sLI, sTW }}, +/*ack*/ {{sES, sES, sSS, sES, sFW, sTW, sCL, sCW, sCL, sLI, sES }}, +/*rst*/ {{sCL, sCL, sCL, sSR, sCL, sCL, sCL, sCL, sLA, sLI, sCL }}, +}; + +static void tcp_timeout_change(struct ip_vs_proto_data *pd, int flags) +{ + int on = (flags & 1); /* secure_tcp */ + + /* + ** FIXME: change secure_tcp to independent sysctl var + ** or make it per-service or per-app because it is valid + ** for most if not for all of the applications. Something + ** like "capabilities" (flags) for each object. + */ + pd->tcp_state_table = (on ? 
tcp_states_dos : tcp_states); +} + +static inline int tcp_state_idx(struct tcphdr *th) +{ + if (th->rst) + return 3; + if (th->syn) + return 0; + if (th->fin) + return 1; + if (th->ack) + return 2; + return -1; +} + +static inline void +set_tcp_state(struct ip_vs_proto_data *pd, struct ip_vs_conn *cp, + int direction, struct tcphdr *th) +{ + int state_idx; + int new_state = IP_VS_TCP_S_CLOSE; + int state_off = tcp_state_off[direction]; + + /* + * Update state offset to INPUT_ONLY if necessary + * or delete NO_OUTPUT flag if output packet detected + */ + if (cp->flags & IP_VS_CONN_F_NOOUTPUT) { + if (state_off == TCP_DIR_OUTPUT) + cp->flags &= ~IP_VS_CONN_F_NOOUTPUT; + else + state_off = TCP_DIR_INPUT_ONLY; + } + + if ((state_idx = tcp_state_idx(th)) < 0) { + IP_VS_DBG(8, "tcp_state_idx=%d!!!\n", state_idx); + goto tcp_state_out; + } + + new_state = + pd->tcp_state_table[state_off+state_idx].next_state[cp->state]; + + tcp_state_out: + if (new_state != cp->state) { + struct ip_vs_dest *dest = cp->dest; + + IP_VS_DBG_BUF(8, "%s %s [%c%c%c%c] c:%s:%d v:%s:%d " + "d:%s:%d state: %s->%s conn->refcnt:%d\n", + pd->pp->name, + ((state_off == TCP_DIR_OUTPUT) ? + "output " : "input "), + th->syn ? 'S' : '.', + th->fin ? 'F' : '.', + th->ack ? 'A' : '.', + th->rst ? 'R' : '.', + IP_VS_DBG_ADDR(cp->af, &cp->caddr), + ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), + ntohs(cp->vport), + IP_VS_DBG_ADDR(cp->daf, &cp->daddr), + ntohs(cp->dport), + tcp_state_name(cp->state), + tcp_state_name(new_state), + refcount_read(&cp->refcnt)); + + if (dest) { + if (!(cp->flags & IP_VS_CONN_F_INACTIVE) && + !tcp_state_active(new_state)) { + atomic_dec(&dest->activeconns); + atomic_inc(&dest->inactconns); + cp->flags |= IP_VS_CONN_F_INACTIVE; + } else if ((cp->flags & IP_VS_CONN_F_INACTIVE) && + tcp_state_active(new_state)) { + atomic_inc(&dest->activeconns); + atomic_dec(&dest->inactconns); + cp->flags &= ~IP_VS_CONN_F_INACTIVE; + } + } + if (new_state == IP_VS_TCP_S_ESTABLISHED) + ip_vs_control_assure_ct(cp); + } + + if (likely(pd)) + cp->timeout = pd->timeout_table[cp->state = new_state]; + else /* What to do ? */ + cp->timeout = tcp_timeouts[cp->state = new_state]; +} + +/* + * Handle state transitions + */ +static void +tcp_state_transition(struct ip_vs_conn *cp, int direction, + const struct sk_buff *skb, + struct ip_vs_proto_data *pd) +{ + struct tcphdr _tcph, *th; + +#ifdef CONFIG_IP_VS_IPV6 + int ihl = cp->af == AF_INET ? 
ip_hdrlen(skb) : sizeof(struct ipv6hdr); +#else + int ihl = ip_hdrlen(skb); +#endif + + th = skb_header_pointer(skb, ihl, sizeof(_tcph), &_tcph); + if (th == NULL) + return; + + spin_lock_bh(&cp->lock); + set_tcp_state(pd, cp, direction, th); + spin_unlock_bh(&cp->lock); +} + +static inline __u16 tcp_app_hashkey(__be16 port) +{ + return (((__force u16)port >> TCP_APP_TAB_BITS) ^ (__force u16)port) + & TCP_APP_TAB_MASK; +} + + +static int tcp_register_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_app *i; + __u16 hash; + __be16 port = inc->port; + int ret = 0; + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); + + hash = tcp_app_hashkey(port); + + list_for_each_entry(i, &ipvs->tcp_apps[hash], p_list) { + if (i->port == port) { + ret = -EEXIST; + goto out; + } + } + list_add_rcu(&inc->p_list, &ipvs->tcp_apps[hash]); + atomic_inc(&pd->appcnt); + + out: + return ret; +} + + +static void +tcp_unregister_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_TCP); + + atomic_dec(&pd->appcnt); + list_del_rcu(&inc->p_list); +} + + +static int +tcp_app_conn_bind(struct ip_vs_conn *cp) +{ + struct netns_ipvs *ipvs = cp->ipvs; + int hash; + struct ip_vs_app *inc; + int result = 0; + + /* Default binding: bind app only for NAT */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + return 0; + + /* Lookup application incarnations and bind the right one */ + hash = tcp_app_hashkey(cp->vport); + + list_for_each_entry_rcu(inc, &ipvs->tcp_apps[hash], p_list) { + if (inc->port == cp->vport) { + if (unlikely(!ip_vs_app_inc_get(inc))) + break; + + IP_VS_DBG_BUF(9, "%s(): Binding conn %s:%u->" + "%s:%u to app %s on port %u\n", + __func__, + IP_VS_DBG_ADDR(cp->af, &cp->caddr), + ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), + ntohs(cp->vport), + inc->name, ntohs(inc->port)); + + cp->app = inc; + if (inc->init_conn) + result = inc->init_conn(inc, cp); + break; + } + } + + return result; +} + + +/* + * Set LISTEN timeout. (ip_vs_conn_put will setup timer) + */ +void ip_vs_tcp_conn_listen(struct ip_vs_conn *cp) +{ + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(cp->ipvs, IPPROTO_TCP); + + spin_lock_bh(&cp->lock); + cp->state = IP_VS_TCP_S_LISTEN; + cp->timeout = (pd ? pd->timeout_table[IP_VS_TCP_S_LISTEN] + : tcp_timeouts[IP_VS_TCP_S_LISTEN]); + spin_unlock_bh(&cp->lock); +} + +/* --------------------------------------------- + * timeouts is netns related now. 
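A worked example of the lookup in set_tcp_state() above: an incoming SYN on a brand-new connection uses state_off = TCP_DIR_INPUT and tcp_state_idx() = 0, so the INPUT/syn row of tcp_states[] maps IP_VS_TCP_S_NONE to IP_VS_TCP_S_SYN_RECV and the entry gets the per-netns copy of tcp_timeouts[IP_VS_TCP_S_SYN_RECV] (one minute by default). The real server's SYN+ACK arrives in the OUTPUT direction, still indexes the syn row (tcp_state_idx() tests syn before ack) and leaves the state at SYN_RECV, and the client's final ACK (INPUT/ack row) moves it to ESTABLISHED with the fifteen-minute default timeout.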
+ * --------------------------------------------- + */ +static int __ip_vs_tcp_init(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + ip_vs_init_hash_table(ipvs->tcp_apps, TCP_APP_TAB_SIZE); + pd->timeout_table = ip_vs_create_timeout_table((int *)tcp_timeouts, + sizeof(tcp_timeouts)); + if (!pd->timeout_table) + return -ENOMEM; + pd->tcp_state_table = tcp_states; + return 0; +} + +static void __ip_vs_tcp_exit(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + kfree(pd->timeout_table); +} + + +struct ip_vs_protocol ip_vs_protocol_tcp = { + .name = "TCP", + .protocol = IPPROTO_TCP, + .num_states = IP_VS_TCP_S_LAST, + .dont_defrag = 0, + .init = NULL, + .exit = NULL, + .init_netns = __ip_vs_tcp_init, + .exit_netns = __ip_vs_tcp_exit, + .register_app = tcp_register_app, + .unregister_app = tcp_unregister_app, + .conn_schedule = tcp_conn_schedule, + .conn_in_get = ip_vs_conn_in_get_proto, + .conn_out_get = ip_vs_conn_out_get_proto, + .snat_handler = tcp_snat_handler, + .dnat_handler = tcp_dnat_handler, + .state_name = tcp_state_name, + .state_transition = tcp_state_transition, + .app_conn_bind = tcp_app_conn_bind, + .debug_packet = ip_vs_tcpudp_debug_packet, + .timeout_change = tcp_timeout_change, +}; diff --git a/net/netfilter/ipvs/ip_vs_proto_udp.c b/net/netfilter/ipvs/ip_vs_proto_udp.c new file mode 100644 index 000000000..68260d91c --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_proto_udp.c @@ -0,0 +1,503 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_proto_udp.c: UDP load balancing support for IPVS + * + * Authors: Wensong Zhang + * Julian Anastasov + * + * Changes: Hans Schillstrom + * Network name space (netns) aware. + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +static int +udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp); + +static int +udp_conn_schedule(struct netns_ipvs *ipvs, int af, struct sk_buff *skb, + struct ip_vs_proto_data *pd, + int *verdict, struct ip_vs_conn **cpp, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_service *svc; + struct udphdr _udph, *uh; + __be16 _ports[2], *ports = NULL; + + if (likely(!ip_vs_iph_icmp(iph))) { + /* IPv6 fragments, only first fragment will hit this */ + uh = skb_header_pointer(skb, iph->len, sizeof(_udph), &_udph); + if (uh) + ports = &uh->source; + } else { + ports = skb_header_pointer( + skb, iph->len, sizeof(_ports), &_ports); + } + + if (!ports) { + *verdict = NF_DROP; + return 0; + } + + if (likely(!ip_vs_iph_inverse(iph))) + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->daddr, ports[1]); + else + svc = ip_vs_service_find(ipvs, af, skb->mark, iph->protocol, + &iph->saddr, ports[0]); + + if (svc) { + int ignored; + + if (ip_vs_todrop(ipvs)) { + /* + * It seems that we are very loaded. + * We have to drop this packet :( + */ + *verdict = NF_DROP; + return 0; + } + + /* + * Let the virtual server select a real server for the + * incoming connection, and create a connection entry. 
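The __ip_vs_sctp_init()/__ip_vs_tcp_init() pattern above (and __udp_init() further down) gives every network namespace a private, writable copy of the compiled-in timeout defaults, which is exactly what the matching *_exit() hook later frees. Conceptually the helper is no more than a sized copy; a userspace sketch with plain malloc standing in for the kernel allocator, names hypothetical:

#include <stdlib.h>
#include <string.h>

/* stand-in for ip_vs_create_timeout_table(): duplicate the const template
 * so per-namespace timeout changes never touch the shared defaults */
static int *create_timeout_table(const int *template_tab, size_t size)
{
	int *copy = malloc(size);

	if (copy)
		memcpy(copy, template_tab, size);
	return copy;	/* freed by the *_exit hook */
}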
+ */ + *cpp = ip_vs_schedule(svc, skb, pd, &ignored, iph); + if (!*cpp && ignored <= 0) { + if (!ignored) + *verdict = ip_vs_leave(svc, skb, pd, iph); + else + *verdict = NF_DROP; + return 0; + } + } + /* NF_ACCEPT */ + return 1; +} + + +static inline void +udp_fast_csum_update(int af, struct udphdr *uhdr, + const union nf_inet_addr *oldip, + const union nf_inet_addr *newip, + __be16 oldport, __be16 newport) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + uhdr->check = + csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6, + ip_vs_check_diff2(oldport, newport, + ~csum_unfold(uhdr->check)))); + else +#endif + uhdr->check = + csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip, + ip_vs_check_diff2(oldport, newport, + ~csum_unfold(uhdr->check)))); + if (!uhdr->check) + uhdr->check = CSUM_MANGLED_0; +} + +static inline void +udp_partial_csum_update(int af, struct udphdr *uhdr, + const union nf_inet_addr *oldip, + const union nf_inet_addr *newip, + __be16 oldlen, __be16 newlen) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + uhdr->check = + ~csum_fold(ip_vs_check_diff16(oldip->ip6, newip->ip6, + ip_vs_check_diff2(oldlen, newlen, + csum_unfold(uhdr->check)))); + else +#endif + uhdr->check = + ~csum_fold(ip_vs_check_diff4(oldip->ip, newip->ip, + ip_vs_check_diff2(oldlen, newlen, + csum_unfold(uhdr->check)))); +} + + +INDIRECT_CALLABLE_SCOPE int +udp_snat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct udphdr *udph; + unsigned int udphoff = iph->len; + bool payload_csum = false; + int oldlen; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + oldlen = skb->len - udphoff; + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, udphoff + sizeof(*udph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!udp_csum_check(cp->af, skb, pp)) + return 0; + + /* + * Call application helper if needed + */ + if (!(ret = ip_vs_app_pkt_out(cp, skb, iph))) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 1) + oldlen = skb->len - udphoff; + else + payload_csum = true; + } + + udph = (void *)skb_network_header(skb) + udphoff; + udph->source = cp->vport; + + /* + * Adjust UDP checksums + */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + udp_partial_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr, + htons(oldlen), + htons(skb->len - udphoff)); + } else if (!payload_csum && (udph->check != 0)) { + /* Only port and addr are changed, do fast csum update */ + udp_fast_csum_update(cp->af, udph, &cp->daddr, &cp->vaddr, + cp->dport, cp->vport); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = cp->app ? 
+ CHECKSUM_UNNECESSARY : CHECKSUM_NONE; + } else { + /* full checksum calculation */ + udph->check = 0; + skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0); +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + udph->check = csum_ipv6_magic(&cp->vaddr.in6, + &cp->caddr.in6, + skb->len - udphoff, + cp->protocol, skb->csum); + else +#endif + udph->check = csum_tcpudp_magic(cp->vaddr.ip, + cp->caddr.ip, + skb->len - udphoff, + cp->protocol, + skb->csum); + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + skb->ip_summed = CHECKSUM_UNNECESSARY; + IP_VS_DBG(11, "O-pkt: %s O-csum=%d (+%zd)\n", + pp->name, udph->check, + (char*)&(udph->check) - (char*)udph); + } + return 1; +} + + +static int +udp_dnat_handler(struct sk_buff *skb, struct ip_vs_protocol *pp, + struct ip_vs_conn *cp, struct ip_vs_iphdr *iph) +{ + struct udphdr *udph; + unsigned int udphoff = iph->len; + bool payload_csum = false; + int oldlen; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6 && iph->fragoffs) + return 1; +#endif + oldlen = skb->len - udphoff; + + /* csum_check requires unshared skb */ + if (skb_ensure_writable(skb, udphoff + sizeof(*udph))) + return 0; + + if (unlikely(cp->app != NULL)) { + int ret; + + /* Some checks before mangling */ + if (!udp_csum_check(cp->af, skb, pp)) + return 0; + + /* + * Attempt ip_vs_app call. + * It will fix ip_vs_conn + */ + if (!(ret = ip_vs_app_pkt_in(cp, skb, iph))) + return 0; + /* ret=2: csum update is needed after payload mangling */ + if (ret == 1) + oldlen = skb->len - udphoff; + else + payload_csum = true; + } + + udph = (void *)skb_network_header(skb) + udphoff; + udph->dest = cp->dport; + + /* + * Adjust UDP checksums + */ + if (skb->ip_summed == CHECKSUM_PARTIAL) { + udp_partial_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr, + htons(oldlen), + htons(skb->len - udphoff)); + } else if (!payload_csum && (udph->check != 0)) { + /* Only port and addr are changed, do fast csum update */ + udp_fast_csum_update(cp->af, udph, &cp->vaddr, &cp->daddr, + cp->vport, cp->dport); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->ip_summed = cp->app ? 
+ CHECKSUM_UNNECESSARY : CHECKSUM_NONE; + } else { + /* full checksum calculation */ + udph->check = 0; + skb->csum = skb_checksum(skb, udphoff, skb->len - udphoff, 0); +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + udph->check = csum_ipv6_magic(&cp->caddr.in6, + &cp->daddr.in6, + skb->len - udphoff, + cp->protocol, skb->csum); + else +#endif + udph->check = csum_tcpudp_magic(cp->caddr.ip, + cp->daddr.ip, + skb->len - udphoff, + cp->protocol, + skb->csum); + if (udph->check == 0) + udph->check = CSUM_MANGLED_0; + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + return 1; +} + + +static int +udp_csum_check(int af, struct sk_buff *skb, struct ip_vs_protocol *pp) +{ + struct udphdr _udph, *uh; + unsigned int udphoff; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + udphoff = sizeof(struct ipv6hdr); + else +#endif + udphoff = ip_hdrlen(skb); + + uh = skb_header_pointer(skb, udphoff, sizeof(_udph), &_udph); + if (uh == NULL) + return 0; + + if (uh->check != 0) { + switch (skb->ip_summed) { + case CHECKSUM_NONE: + skb->csum = skb_checksum(skb, udphoff, + skb->len - udphoff, 0); + fallthrough; + case CHECKSUM_COMPLETE: +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) { + if (csum_ipv6_magic(&ipv6_hdr(skb)->saddr, + &ipv6_hdr(skb)->daddr, + skb->len - udphoff, + ipv6_hdr(skb)->nexthdr, + skb->csum)) { + IP_VS_DBG_RL_PKT(0, af, pp, skb, 0, + "Failed checksum for"); + return 0; + } + } else +#endif + if (csum_tcpudp_magic(ip_hdr(skb)->saddr, + ip_hdr(skb)->daddr, + skb->len - udphoff, + ip_hdr(skb)->protocol, + skb->csum)) { + IP_VS_DBG_RL_PKT(0, af, pp, skb, 0, + "Failed checksum for"); + return 0; + } + break; + default: + /* No need to checksum. */ + break; + } + } + return 1; +} + +static inline __u16 udp_app_hashkey(__be16 port) +{ + return (((__force u16)port >> UDP_APP_TAB_BITS) ^ (__force u16)port) + & UDP_APP_TAB_MASK; +} + + +static int udp_register_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_app *i; + __u16 hash; + __be16 port = inc->port; + int ret = 0; + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP); + + hash = udp_app_hashkey(port); + + list_for_each_entry(i, &ipvs->udp_apps[hash], p_list) { + if (i->port == port) { + ret = -EEXIST; + goto out; + } + } + list_add_rcu(&inc->p_list, &ipvs->udp_apps[hash]); + atomic_inc(&pd->appcnt); + + out: + return ret; +} + + +static void +udp_unregister_app(struct netns_ipvs *ipvs, struct ip_vs_app *inc) +{ + struct ip_vs_proto_data *pd = ip_vs_proto_data_get(ipvs, IPPROTO_UDP); + + atomic_dec(&pd->appcnt); + list_del_rcu(&inc->p_list); +} + + +static int udp_app_conn_bind(struct ip_vs_conn *cp) +{ + struct netns_ipvs *ipvs = cp->ipvs; + int hash; + struct ip_vs_app *inc; + int result = 0; + + /* Default binding: bind app only for NAT */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) + return 0; + + /* Lookup application incarnations and bind the right one */ + hash = udp_app_hashkey(cp->vport); + + list_for_each_entry_rcu(inc, &ipvs->udp_apps[hash], p_list) { + if (inc->port == cp->vport) { + if (unlikely(!ip_vs_app_inc_get(inc))) + break; + + IP_VS_DBG_BUF(9, "%s(): Binding conn %s:%u->" + "%s:%u to app %s on port %u\n", + __func__, + IP_VS_DBG_ADDR(cp->af, &cp->caddr), + ntohs(cp->cport), + IP_VS_DBG_ADDR(cp->af, &cp->vaddr), + ntohs(cp->vport), + inc->name, ntohs(inc->port)); + + cp->app = inc; + if (inc->init_conn) + result = inc->init_conn(inc, cp); + break; + } + } + + return result; +} + + +static const int udp_timeouts[IP_VS_UDP_S_LAST+1] = { + [IP_VS_UDP_S_NORMAL] = 5*60*HZ, + 
[IP_VS_UDP_S_LAST] = 2*HZ, +}; + +static const char *const udp_state_name_table[IP_VS_UDP_S_LAST+1] = { + [IP_VS_UDP_S_NORMAL] = "UDP", + [IP_VS_UDP_S_LAST] = "BUG!", +}; + +static const char * udp_state_name(int state) +{ + if (state >= IP_VS_UDP_S_LAST) + return "ERR!"; + return udp_state_name_table[state] ? udp_state_name_table[state] : "?"; +} + +static void +udp_state_transition(struct ip_vs_conn *cp, int direction, + const struct sk_buff *skb, + struct ip_vs_proto_data *pd) +{ + if (unlikely(!pd)) { + pr_err("UDP no ns data\n"); + return; + } + + cp->timeout = pd->timeout_table[IP_VS_UDP_S_NORMAL]; + if (direction == IP_VS_DIR_OUTPUT) + ip_vs_control_assure_ct(cp); +} + +static int __udp_init(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + ip_vs_init_hash_table(ipvs->udp_apps, UDP_APP_TAB_SIZE); + pd->timeout_table = ip_vs_create_timeout_table((int *)udp_timeouts, + sizeof(udp_timeouts)); + if (!pd->timeout_table) + return -ENOMEM; + return 0; +} + +static void __udp_exit(struct netns_ipvs *ipvs, struct ip_vs_proto_data *pd) +{ + kfree(pd->timeout_table); +} + + +struct ip_vs_protocol ip_vs_protocol_udp = { + .name = "UDP", + .protocol = IPPROTO_UDP, + .num_states = IP_VS_UDP_S_LAST, + .dont_defrag = 0, + .init = NULL, + .exit = NULL, + .init_netns = __udp_init, + .exit_netns = __udp_exit, + .conn_schedule = udp_conn_schedule, + .conn_in_get = ip_vs_conn_in_get_proto, + .conn_out_get = ip_vs_conn_out_get_proto, + .snat_handler = udp_snat_handler, + .dnat_handler = udp_dnat_handler, + .state_transition = udp_state_transition, + .state_name = udp_state_name, + .register_app = udp_register_app, + .unregister_app = udp_unregister_app, + .app_conn_bind = udp_app_conn_bind, + .debug_packet = ip_vs_tcpudp_debug_packet, + .timeout_change = NULL, +}; diff --git a/net/netfilter/ipvs/ip_vs_rr.c b/net/netfilter/ipvs/ip_vs_rr.c new file mode 100644 index 000000000..38495c6f6 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_rr.c @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Round-Robin Scheduling module + * + * Authors: Wensong Zhang + * Peter Kese + * + * Fixes/Changes: + * Wensong Zhang : changed the ip_vs_rr_schedule to return dest + * Julian Anastasov : fixed the NULL pointer access bug in debugging + * Wensong Zhang : changed some comestics things for debugging + * Wensong Zhang : changed for the d-linked destination list + * Wensong Zhang : added the ip_vs_rr_update_svc + * Wensong Zhang : added any dest with weight=0 is quiesced + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + + +static int ip_vs_rr_init_svc(struct ip_vs_service *svc) +{ + svc->sched_data = &svc->destinations; + return 0; +} + + +static int ip_vs_rr_del_dest(struct ip_vs_service *svc, struct ip_vs_dest *dest) +{ + struct list_head *p; + + spin_lock_bh(&svc->sched_lock); + p = (struct list_head *) svc->sched_data; + /* dest is already unlinked, so p->prev is not valid but + * p->next is valid, use it to reach previous entry. 
+ */ + if (p == &dest->n_list) + svc->sched_data = p->next->prev; + spin_unlock_bh(&svc->sched_lock); + return 0; +} + + +/* + * Round-Robin Scheduling + */ +static struct ip_vs_dest * +ip_vs_rr_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct list_head *p; + struct ip_vs_dest *dest, *last; + int pass = 0; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + spin_lock_bh(&svc->sched_lock); + p = (struct list_head *) svc->sched_data; + last = dest = list_entry(p, struct ip_vs_dest, n_list); + + do { + list_for_each_entry_continue_rcu(dest, + &svc->destinations, + n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD) && + atomic_read(&dest->weight) > 0) + /* HIT */ + goto out; + if (dest == last) + goto stop; + } + pass++; + /* Previous dest could be unlinked, do not loop forever. + * If we stay at head there is no need for 2nd pass. + */ + } while (pass < 2 && p != &svc->destinations); + +stop: + spin_unlock_bh(&svc->sched_lock); + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + + out: + svc->sched_data = &dest->n_list; + spin_unlock_bh(&svc->sched_lock); + IP_VS_DBG_BUF(6, "RR: server %s:%u " + "activeconns %d refcnt %d weight %d\n", + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port), + atomic_read(&dest->activeconns), + refcount_read(&dest->refcnt), atomic_read(&dest->weight)); + + return dest; +} + + +static struct ip_vs_scheduler ip_vs_rr_scheduler = { + .name = "rr", /* name */ + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_rr_scheduler.n_list), + .init_service = ip_vs_rr_init_svc, + .add_dest = NULL, + .del_dest = ip_vs_rr_del_dest, + .schedule = ip_vs_rr_schedule, +}; + +static int __init ip_vs_rr_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_rr_scheduler); +} + +static void __exit ip_vs_rr_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_rr_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_rr_init); +module_exit(ip_vs_rr_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_sched.c b/net/netfilter/ipvs/ip_vs_sched.c new file mode 100644 index 000000000..d4903723b --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_sched.c @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS An implementation of the IP virtual server support for the + * LINUX operating system. IPVS is now implemented as a module + * over the Netfilter framework. IPVS can be used to build a + * high-performance and highly available server based on a + * cluster of servers. 
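The round-robin scheduler above keeps nothing but a cursor (svc->sched_data) into the destination list: each call resumes just after the previous pick, skips overloaded and zero-weight (quiesced) servers, and allows at most one extra pass before reporting that no destination is available. A minimal stand-alone sketch of the same idea over an array (hypothetical demo, not the kernel's list walk):

#include <stddef.h>

struct dest { int weight; };

/* returns the index of the next usable destination, or (size_t)-1;
 * weight > 0 stands in for "not overloaded, not quiesced" */
static size_t rr_pick(const struct dest *tab, size_t n, size_t *cursor)
{
	for (size_t i = 0; i < n; i++) {
		size_t idx = (*cursor + 1 + i) % n;

		if (tab[idx].weight > 0) {
			*cursor = idx;	/* remember where we stopped */
			return idx;
		}
	}
	return (size_t)-1;	/* nothing usable */
}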
+ * + * Authors: Wensong Zhang + * Peter Kese + * + * Changes: + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include + +#include + +EXPORT_SYMBOL(ip_vs_scheduler_err); +/* + * IPVS scheduler list + */ +static LIST_HEAD(ip_vs_schedulers); + +/* semaphore for schedulers */ +static DEFINE_MUTEX(ip_vs_sched_mutex); + + +/* + * Bind a service with a scheduler + */ +int ip_vs_bind_scheduler(struct ip_vs_service *svc, + struct ip_vs_scheduler *scheduler) +{ + int ret; + + if (scheduler->init_service) { + ret = scheduler->init_service(svc); + if (ret) { + pr_err("%s(): init error\n", __func__); + return ret; + } + } + rcu_assign_pointer(svc->scheduler, scheduler); + return 0; +} + + +/* + * Unbind a service with its scheduler + */ +void ip_vs_unbind_scheduler(struct ip_vs_service *svc, + struct ip_vs_scheduler *sched) +{ + struct ip_vs_scheduler *cur_sched; + + cur_sched = rcu_dereference_protected(svc->scheduler, 1); + /* This check proves that old 'sched' was installed */ + if (!cur_sched) + return; + + if (sched->done_service) + sched->done_service(svc); + /* svc->scheduler can be set to NULL only by caller */ +} + + +/* + * Get scheduler in the scheduler list by name + */ +static struct ip_vs_scheduler *ip_vs_sched_getbyname(const char *sched_name) +{ + struct ip_vs_scheduler *sched; + + IP_VS_DBG(2, "%s(): sched_name \"%s\"\n", __func__, sched_name); + + mutex_lock(&ip_vs_sched_mutex); + + list_for_each_entry(sched, &ip_vs_schedulers, n_list) { + /* + * Test and get the modules atomically + */ + if (sched->module && !try_module_get(sched->module)) { + /* + * This scheduler is just deleted + */ + continue; + } + if (strcmp(sched_name, sched->name)==0) { + /* HIT */ + mutex_unlock(&ip_vs_sched_mutex); + return sched; + } + module_put(sched->module); + } + + mutex_unlock(&ip_vs_sched_mutex); + return NULL; +} + + +/* + * Lookup scheduler and try to load it if it doesn't exist + */ +struct ip_vs_scheduler *ip_vs_scheduler_get(const char *sched_name) +{ + struct ip_vs_scheduler *sched; + + /* + * Search for the scheduler by sched_name + */ + sched = ip_vs_sched_getbyname(sched_name); + + /* + * If scheduler not found, load the module and search again + */ + if (sched == NULL) { + request_module("ip_vs_%s", sched_name); + sched = ip_vs_sched_getbyname(sched_name); + } + + return sched; +} + +void ip_vs_scheduler_put(struct ip_vs_scheduler *scheduler) +{ + if (scheduler) + module_put(scheduler->module); +} + +/* + * Common error output helper for schedulers + */ + +void ip_vs_scheduler_err(struct ip_vs_service *svc, const char *msg) +{ + struct ip_vs_scheduler *sched = rcu_dereference(svc->scheduler); + char *sched_name = sched ? 
sched->name : "none"; + + if (svc->fwmark) { + IP_VS_ERR_RL("%s: FWM %u 0x%08X - %s\n", + sched_name, svc->fwmark, svc->fwmark, msg); +#ifdef CONFIG_IP_VS_IPV6 + } else if (svc->af == AF_INET6) { + IP_VS_ERR_RL("%s: %s [%pI6c]:%d - %s\n", + sched_name, ip_vs_proto_name(svc->protocol), + &svc->addr.in6, ntohs(svc->port), msg); +#endif + } else { + IP_VS_ERR_RL("%s: %s %pI4:%d - %s\n", + sched_name, ip_vs_proto_name(svc->protocol), + &svc->addr.ip, ntohs(svc->port), msg); + } +} + +/* + * Register a scheduler in the scheduler list + */ +int register_ip_vs_scheduler(struct ip_vs_scheduler *scheduler) +{ + struct ip_vs_scheduler *sched; + + if (!scheduler) { + pr_err("%s(): NULL arg\n", __func__); + return -EINVAL; + } + + if (!scheduler->name) { + pr_err("%s(): NULL scheduler_name\n", __func__); + return -EINVAL; + } + + /* increase the module use count */ + if (!ip_vs_use_count_inc()) + return -ENOENT; + + mutex_lock(&ip_vs_sched_mutex); + + if (!list_empty(&scheduler->n_list)) { + mutex_unlock(&ip_vs_sched_mutex); + ip_vs_use_count_dec(); + pr_err("%s(): [%s] scheduler already linked\n", + __func__, scheduler->name); + return -EINVAL; + } + + /* + * Make sure that the scheduler with this name doesn't exist + * in the scheduler list. + */ + list_for_each_entry(sched, &ip_vs_schedulers, n_list) { + if (strcmp(scheduler->name, sched->name) == 0) { + mutex_unlock(&ip_vs_sched_mutex); + ip_vs_use_count_dec(); + pr_err("%s(): [%s] scheduler already existed " + "in the system\n", __func__, scheduler->name); + return -EINVAL; + } + } + /* + * Add it into the d-linked scheduler list + */ + list_add(&scheduler->n_list, &ip_vs_schedulers); + mutex_unlock(&ip_vs_sched_mutex); + + pr_info("[%s] scheduler registered.\n", scheduler->name); + + return 0; +} + + +/* + * Unregister a scheduler from the scheduler list + */ +int unregister_ip_vs_scheduler(struct ip_vs_scheduler *scheduler) +{ + if (!scheduler) { + pr_err("%s(): NULL arg\n", __func__); + return -EINVAL; + } + + mutex_lock(&ip_vs_sched_mutex); + if (list_empty(&scheduler->n_list)) { + mutex_unlock(&ip_vs_sched_mutex); + pr_err("%s(): [%s] scheduler is not in the list. failed\n", + __func__, scheduler->name); + return -EINVAL; + } + + /* + * Remove it from the d-linked scheduler list + */ + list_del(&scheduler->n_list); + mutex_unlock(&ip_vs_sched_mutex); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + + pr_info("[%s] scheduler unregistered.\n", scheduler->name); + + return 0; +} diff --git a/net/netfilter/ipvs/ip_vs_sed.c b/net/netfilter/ipvs/ip_vs_sed.c new file mode 100644 index 000000000..7663288e5 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_sed.c @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Shortest Expected Delay scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + */ + +/* + * The SED algorithm attempts to minimize each job's expected delay until + * completion. The expected delay that the job will experience is + * (Ci + 1) / Ui if sent to the ith server, in which Ci is the number of + * jobs on the ith server and Ui is the fixed service rate (weight) of + * the ith server. The SED algorithm adopts a greedy policy that each does + * what is in its own best interest, i.e. to join the queue which would + * minimize its expected delay of completion. + * + * See the following paper for more information: + * A. Weinrib and S. Shenker, Greed is not enough: Adaptive load sharing + * in large heterogeneous systems. In Proceedings IEEE INFOCOM'88, + * pages 986-994, 1988. 
+ * + * Thanks must go to Marko Buuri for talking SED to me. + * + * The difference between SED and WLC is that SED includes the incoming + * job in the cost function (the increment of 1). SED may outperform + * WLC, while scheduling big jobs under larger heterogeneous systems + * (the server weight varies a lot). + * + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + + +static inline int +ip_vs_sed_dest_overhead(struct ip_vs_dest *dest) +{ + /* + * We only use the active connection number in the cost + * calculation here. + */ + return atomic_read(&dest->activeconns) + 1; +} + + +/* + * Weighted Least Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_sed_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *least; + int loh, doh; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* + * We calculate the load of each dest server as follows: + * (server expected overhead) / dest->weight + * + * Remember -- no floats in kernel mode!!! + * The comparison of h1*w2 > h2*w1 is equivalent to that of + * h1/w1 > h2/w2 + * if every weight is larger than zero. + * + * The server with weight=0 is quiesced and will not receive any + * new connections. + */ + + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD) && + atomic_read(&dest->weight) > 0) { + least = dest; + loh = ip_vs_sed_dest_overhead(least); + goto nextstage; + } + } + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + + /* + * Find the destination with the least load. + */ + nextstage: + list_for_each_entry_continue_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + doh = ip_vs_sed_dest_overhead(dest); + if ((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight)) { + least = dest; + loh = doh; + } + } + + IP_VS_DBG_BUF(6, "SED: server %s:%u " + "activeconns %d refcnt %d weight %d overhead %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + + return least; +} + + +static struct ip_vs_scheduler ip_vs_sed_scheduler = +{ + .name = "sed", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_sed_scheduler.n_list), + .schedule = ip_vs_sed_schedule, +}; + + +static int __init ip_vs_sed_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_sed_scheduler); +} + +static void __exit ip_vs_sed_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_sed_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_sed_init); +module_exit(ip_vs_sed_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_sh.c b/net/netfilter/ipvs/ip_vs_sh.c new file mode 100644 index 000000000..c2028e412 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_sh.c @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Source Hashing scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + */ + +/* + * The sh algorithm is to select server by the hash key of source IP + * address. 
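The comparison inside ip_vs_sed_schedule() above avoids floating point by cross-multiplying: (Ci + 1) / Ui < (Cj + 1) / Uj is decided as (Ci + 1) * Uj < (Cj + 1) * Ui, which is valid as long as both weights are positive (weight 0 servers are already filtered out as quiesced). Restated as a stand-alone helper with hypothetical names, using the same widening to 64 bits as the kernel's __s64 casts:

#include <stdbool.h>
#include <stdint.h>

/* true when destination b promises a shorter expected delay than a */
static bool sed_prefers_b(int32_t a_active, int32_t a_weight,
			  int32_t b_active, int32_t b_weight)
{
	int64_t a_over = (int64_t)a_active + 1;	/* (Ci + 1) for a */
	int64_t b_over = (int64_t)b_active + 1;	/* (Cj + 1) for b */

	/* a_over/a_weight > b_over/b_weight  <=>  a_over*b_weight > b_over*a_weight */
	return a_over * b_weight > b_over * a_weight;
}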
The pseudo code is as follows: + * + * n <- servernode[src_ip]; + * if (n is dead) OR + * (n is overloaded) or (n.weight <= 0) then + * return NULL; + * + * return n; + * + * Notes that servernode is a 256-bucket hash table that maps the hash + * index derived from packet source IP address to the current server + * array. If the sh scheduler is used in cache cluster, it is good to + * combine it with cache_bypass feature. When the statically assigned + * server is dead or overloaded, the load balancer can bypass the cache + * server and send requests to the original server directly. + * + * The weight destination attribute can be used to control the + * distribution of connections to the destinations in servernode. The + * greater the weight, the more connections the destination + * will receive. + * + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + + +/* + * IPVS SH bucket + */ +struct ip_vs_sh_bucket { + struct ip_vs_dest __rcu *dest; /* real server (cache) */ +}; + +/* + * for IPVS SH entry hash table + */ +#ifndef CONFIG_IP_VS_SH_TAB_BITS +#define CONFIG_IP_VS_SH_TAB_BITS 8 +#endif +#define IP_VS_SH_TAB_BITS CONFIG_IP_VS_SH_TAB_BITS +#define IP_VS_SH_TAB_SIZE (1 << IP_VS_SH_TAB_BITS) +#define IP_VS_SH_TAB_MASK (IP_VS_SH_TAB_SIZE - 1) + +struct ip_vs_sh_state { + struct rcu_head rcu_head; + struct ip_vs_sh_bucket buckets[IP_VS_SH_TAB_SIZE]; +}; + +/* Helper function to determine if server is unavailable */ +static inline bool is_unavailable(struct ip_vs_dest *dest) +{ + return atomic_read(&dest->weight) <= 0 || + dest->flags & IP_VS_DEST_F_OVERLOAD; +} + +/* + * Returns hash value for IPVS SH entry + */ +static inline unsigned int +ip_vs_sh_hashkey(int af, const union nf_inet_addr *addr, + __be16 port, unsigned int offset) +{ + __be32 addr_fold = addr->ip; + +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + addr_fold = addr->ip6[0]^addr->ip6[1]^ + addr->ip6[2]^addr->ip6[3]; +#endif + return (offset + hash_32(ntohs(port) + ntohl(addr_fold), + IP_VS_SH_TAB_BITS)) & + IP_VS_SH_TAB_MASK; +} + + +/* + * Get ip_vs_dest associated with supplied parameters. + */ +static inline struct ip_vs_dest * +ip_vs_sh_get(struct ip_vs_service *svc, struct ip_vs_sh_state *s, + const union nf_inet_addr *addr, __be16 port) +{ + unsigned int hash = ip_vs_sh_hashkey(svc->af, addr, port, 0); + struct ip_vs_dest *dest = rcu_dereference(s->buckets[hash].dest); + + return (!dest || is_unavailable(dest)) ? NULL : dest; +} + + +/* As ip_vs_sh_get, but with fallback if selected server is unavailable + * + * The fallback strategy loops around the table starting from a "random" + * point (in fact, it is chosen to be the original hash value to make the + * algorithm deterministic) to find a new server. 
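The bucket selection in ip_vs_sh_hashkey() above reduces the client address (XOR-folded to 32 bits for IPv6) plus an optional port to an index into the fixed 256-entry table, so the same source keeps mapping to the same server until the table is reassigned. A stand-alone approximation; hash32() is a stand-in for the kernel's hash_32() and the multiplicative constant is shown only for illustration:

#include <stdint.h>

#define SH_TAB_BITS 8
#define SH_TAB_SIZE (1u << SH_TAB_BITS)
#define SH_TAB_MASK (SH_TAB_SIZE - 1)

static uint32_t hash32(uint32_t val, unsigned int bits)
{
	return (val * 0x61C88647u) >> (32 - bits);	/* multiplicative hash */
}

/* addr and port are expected in host byte order, as after ntohl()/ntohs() */
static unsigned int sh_bucket(uint32_t saddr, uint16_t port, unsigned int offset)
{
	return (offset + hash32(port + saddr, SH_TAB_BITS)) & SH_TAB_MASK;
}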
+ */ +static inline struct ip_vs_dest * +ip_vs_sh_get_fallback(struct ip_vs_service *svc, struct ip_vs_sh_state *s, + const union nf_inet_addr *addr, __be16 port) +{ + unsigned int offset, roffset; + unsigned int hash, ihash; + struct ip_vs_dest *dest; + + /* first try the dest it's supposed to go to */ + ihash = ip_vs_sh_hashkey(svc->af, addr, port, 0); + dest = rcu_dereference(s->buckets[ihash].dest); + if (!dest) + return NULL; + if (!is_unavailable(dest)) + return dest; + + IP_VS_DBG_BUF(6, "SH: selected unavailable server %s:%d, reselecting", + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port)); + + /* if the original dest is unavailable, loop around the table + * starting from ihash to find a new dest + */ + for (offset = 0; offset < IP_VS_SH_TAB_SIZE; offset++) { + roffset = (offset + ihash) % IP_VS_SH_TAB_SIZE; + hash = ip_vs_sh_hashkey(svc->af, addr, port, roffset); + dest = rcu_dereference(s->buckets[hash].dest); + if (!dest) + break; + if (!is_unavailable(dest)) + return dest; + IP_VS_DBG_BUF(6, "SH: selected unavailable " + "server %s:%d (offset %d), reselecting", + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port), roffset); + } + + return NULL; +} + +/* + * Assign all the hash buckets of the specified table with the service. + */ +static int +ip_vs_sh_reassign(struct ip_vs_sh_state *s, struct ip_vs_service *svc) +{ + int i; + struct ip_vs_sh_bucket *b; + struct list_head *p; + struct ip_vs_dest *dest; + int d_count; + bool empty; + + b = &s->buckets[0]; + p = &svc->destinations; + empty = list_empty(p); + d_count = 0; + for (i=0; idest, 1); + if (dest) + ip_vs_dest_put(dest); + if (empty) + RCU_INIT_POINTER(b->dest, NULL); + else { + if (p == &svc->destinations) + p = p->next; + + dest = list_entry(p, struct ip_vs_dest, n_list); + ip_vs_dest_hold(dest); + RCU_INIT_POINTER(b->dest, dest); + + IP_VS_DBG_BUF(6, "assigned i: %d dest: %s weight: %d\n", + i, IP_VS_DBG_ADDR(dest->af, &dest->addr), + atomic_read(&dest->weight)); + + /* Don't move to next dest until filling weight */ + if (++d_count >= atomic_read(&dest->weight)) { + p = p->next; + d_count = 0; + } + + } + b++; + } + return 0; +} + + +/* + * Flush all the hash buckets of the specified table. 
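The filling rule in ip_vs_sh_reassign() above hands each destination a run of consecutive buckets equal to its weight before the walk moves on, so with weights {2, 1} the table is populated A A B A A B and so on, giving A roughly twice the share of source-hash slots. A compact stand-alone rendering with array indices instead of the kernel's circular destination list (assumes at least one destination):

/* bucket[i] receives the index of the destination that owns it */
static void fill_buckets(int *bucket, int nbuckets, const int *weight, int ndests)
{
	int d = 0, d_count = 0;

	for (int i = 0; i < nbuckets; i++) {
		bucket[i] = d;
		if (++d_count >= weight[d]) {	/* this dest has its share */
			d = (d + 1) % ndests;
			d_count = 0;
		}
	}
}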
+ */ +static void ip_vs_sh_flush(struct ip_vs_sh_state *s) +{ + int i; + struct ip_vs_sh_bucket *b; + struct ip_vs_dest *dest; + + b = &s->buckets[0]; + for (i=0; idest, 1); + if (dest) { + ip_vs_dest_put(dest); + RCU_INIT_POINTER(b->dest, NULL); + } + b++; + } +} + + +static int ip_vs_sh_init_svc(struct ip_vs_service *svc) +{ + struct ip_vs_sh_state *s; + + /* allocate the SH table for this service */ + s = kzalloc(sizeof(struct ip_vs_sh_state), GFP_KERNEL); + if (s == NULL) + return -ENOMEM; + + svc->sched_data = s; + IP_VS_DBG(6, "SH hash table (memory=%zdbytes) allocated for " + "current service\n", + sizeof(struct ip_vs_sh_bucket)*IP_VS_SH_TAB_SIZE); + + /* assign the hash buckets with current dests */ + ip_vs_sh_reassign(s, svc); + + return 0; +} + + +static void ip_vs_sh_done_svc(struct ip_vs_service *svc) +{ + struct ip_vs_sh_state *s = svc->sched_data; + + /* got to clean up hash buckets here */ + ip_vs_sh_flush(s); + + /* release the table itself */ + kfree_rcu(s, rcu_head); + IP_VS_DBG(6, "SH hash table (memory=%zdbytes) released\n", + sizeof(struct ip_vs_sh_bucket)*IP_VS_SH_TAB_SIZE); +} + + +static int ip_vs_sh_dest_changed(struct ip_vs_service *svc, + struct ip_vs_dest *dest) +{ + struct ip_vs_sh_state *s = svc->sched_data; + + /* assign the hash buckets with the updated service */ + ip_vs_sh_reassign(s, svc); + + return 0; +} + + +/* Helper function to get port number */ +static inline __be16 +ip_vs_sh_get_port(const struct sk_buff *skb, struct ip_vs_iphdr *iph) +{ + __be16 _ports[2], *ports; + + /* At this point we know that we have a valid packet of some kind. + * Because ICMP packets are only guaranteed to have the first 8 + * bytes, let's just grab the ports. Fortunately they're in the + * same position for all three of the protocols we care about. + */ + switch (iph->protocol) { + case IPPROTO_TCP: + case IPPROTO_UDP: + case IPPROTO_SCTP: + ports = skb_header_pointer(skb, iph->len, sizeof(_ports), + &_ports); + if (unlikely(!ports)) + return 0; + + if (likely(!ip_vs_iph_inverse(iph))) + return ports[0]; + else + return ports[1]; + default: + return 0; + } +} + + +/* + * Source Hashing scheduling + */ +static struct ip_vs_dest * +ip_vs_sh_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest; + struct ip_vs_sh_state *s; + __be16 port = 0; + const union nf_inet_addr *hash_addr; + + hash_addr = ip_vs_iph_inverse(iph) ? 
&iph->daddr : &iph->saddr; + + IP_VS_DBG(6, "ip_vs_sh_schedule(): Scheduling...\n"); + + if (svc->flags & IP_VS_SVC_F_SCHED_SH_PORT) + port = ip_vs_sh_get_port(skb, iph); + + s = (struct ip_vs_sh_state *) svc->sched_data; + + if (svc->flags & IP_VS_SVC_F_SCHED_SH_FALLBACK) + dest = ip_vs_sh_get_fallback(svc, s, hash_addr, port); + else + dest = ip_vs_sh_get(svc, s, hash_addr, port); + + if (!dest) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + IP_VS_DBG_BUF(6, "SH: source IP address %s --> server %s:%d\n", + IP_VS_DBG_ADDR(svc->af, hash_addr), + IP_VS_DBG_ADDR(dest->af, &dest->addr), + ntohs(dest->port)); + + return dest; +} + + +/* + * IPVS SH Scheduler structure + */ +static struct ip_vs_scheduler ip_vs_sh_scheduler = +{ + .name = "sh", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_sh_scheduler.n_list), + .init_service = ip_vs_sh_init_svc, + .done_service = ip_vs_sh_done_svc, + .add_dest = ip_vs_sh_dest_changed, + .del_dest = ip_vs_sh_dest_changed, + .upd_dest = ip_vs_sh_dest_changed, + .schedule = ip_vs_sh_schedule, +}; + + +static int __init ip_vs_sh_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_sh_scheduler); +} + + +static void __exit ip_vs_sh_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_sh_scheduler); + synchronize_rcu(); +} + + +module_init(ip_vs_sh_init); +module_exit(ip_vs_sh_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_sync.c b/net/netfilter/ipvs/ip_vs_sync.c new file mode 100644 index 000000000..e1dea9a82 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_sync.c @@ -0,0 +1,2051 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * IPVS An implementation of the IP virtual server support for the + * LINUX operating system. IPVS is now implemented as a module + * over the NetFilter framework. IPVS can be used to build a + * high-performance and highly available server based on a + * cluster of servers. + * + * Version 1, is capable of handling both version 0 and 1 messages. + * Version 0 is the plain old format. + * Note Version 0 receivers will just drop Ver 1 messages. + * Version 1 is capable of handle IPv6, Persistence data, + * time-outs, and firewall marks. + * In ver.1 "ip_vs_sync_conn_options" will be sent in netw. order. + * Ver. 0 can be turned on by sysctl -w net.ipv4.vs.sync_version=0 + * + * Definitions Message: is a complete datagram + * Sync_conn: is a part of a Message + * Param Data is an option to a Sync_conn. + * + * Authors: Wensong Zhang + * + * ip_vs_sync: sync connection info from master load balancer to backups + * through multicast + * + * Changes: + * Alexandre Cassen : Added master & backup support at a time. + * Alexandre Cassen : Added SyncID support for incoming sync + * messages filtering. + * Justin Ossevoort : Fix endian problem on sync message size. + * Hans Schillstrom : Added Version 1: i.e. IPv6, + * Persistence support, fwmark and time-out. 
+ */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for ip_mc_join_group */ +#include +#include +#include +#include +#include +#include + +#include /* Used for ntoh_seq and hton_seq */ + +#include +#include + +#include + +#define IP_VS_SYNC_GROUP 0xe0000051 /* multicast addr - 224.0.0.81 */ +#define IP_VS_SYNC_PORT 8848 /* multicast port */ + +#define SYNC_PROTO_VER 1 /* Protocol version in header */ + +static struct lock_class_key __ipvs_sync_key; +/* + * IPVS sync connection entry + * Version 0, i.e. original version. + */ +struct ip_vs_sync_conn_v0 { + __u8 reserved; + + /* Protocol, addresses and port numbers */ + __u8 protocol; /* Which protocol (TCP/UDP) */ + __be16 cport; + __be16 vport; + __be16 dport; + __be32 caddr; /* client address */ + __be32 vaddr; /* virtual address */ + __be32 daddr; /* destination address */ + + /* Flags and state transition */ + __be16 flags; /* status flags */ + __be16 state; /* state info */ + + /* The sequence options start here */ +}; + +struct ip_vs_sync_conn_options { + struct ip_vs_seq in_seq; /* incoming seq. struct */ + struct ip_vs_seq out_seq; /* outgoing seq. struct */ +}; + +/* + Sync Connection format (sync_conn) + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Type | Protocol | Ver. | Size | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Flags | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | State | cport | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | vport | dport | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | fwmark | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | timeout (in sec.) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | ... | + | IP-Addresses (v4 or v6) | + | ... | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + Optional Parameters. + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Param. Type | Param. Length | Param. data | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | + | ... | + | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | Param Type | Param. Length | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Param data | + | Last Param data should be padded for 32 bit alignment | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +/* + * Type 0, IPv4 sync connection format + */ +struct ip_vs_sync_v4 { + __u8 type; + __u8 protocol; /* Which protocol (TCP/UDP) */ + __be16 ver_size; /* Version msb 4 bits */ + /* Flags and state transition */ + __be32 flags; /* status flags */ + __be16 state; /* state info */ + /* Protocol, addresses and port numbers */ + __be16 cport; + __be16 vport; + __be16 dport; + __be32 fwmark; /* Firewall mark from skb */ + __be32 timeout; /* cp timeout */ + __be32 caddr; /* client address */ + __be32 vaddr; /* virtual address */ + __be32 daddr; /* destination address */ + /* The sequence options start here */ + /* PE data padded to 32bit alignment after seq. 
options */ +}; +/* + * Type 2 messages IPv6 + */ +struct ip_vs_sync_v6 { + __u8 type; + __u8 protocol; /* Which protocol (TCP/UDP) */ + __be16 ver_size; /* Version msb 4 bits */ + /* Flags and state transition */ + __be32 flags; /* status flags */ + __be16 state; /* state info */ + /* Protocol, addresses and port numbers */ + __be16 cport; + __be16 vport; + __be16 dport; + __be32 fwmark; /* Firewall mark from skb */ + __be32 timeout; /* cp timeout */ + struct in6_addr caddr; /* client address */ + struct in6_addr vaddr; /* virtual address */ + struct in6_addr daddr; /* destination address */ + /* The sequence options start here */ + /* PE data padded to 32bit alignment after seq. options */ +}; + +union ip_vs_sync_conn { + struct ip_vs_sync_v4 v4; + struct ip_vs_sync_v6 v6; +}; + +/* Bits in Type field in above */ +#define STYPE_INET6 0 +#define STYPE_F_INET6 (1 << STYPE_INET6) + +#define SVER_SHIFT 12 /* Shift to get version */ +#define SVER_MASK 0x0fff /* Mask to strip version */ + +#define IPVS_OPT_SEQ_DATA 1 +#define IPVS_OPT_PE_DATA 2 +#define IPVS_OPT_PE_NAME 3 +#define IPVS_OPT_PARAM 7 + +#define IPVS_OPT_F_SEQ_DATA (1 << (IPVS_OPT_SEQ_DATA-1)) +#define IPVS_OPT_F_PE_DATA (1 << (IPVS_OPT_PE_DATA-1)) +#define IPVS_OPT_F_PE_NAME (1 << (IPVS_OPT_PE_NAME-1)) +#define IPVS_OPT_F_PARAM (1 << (IPVS_OPT_PARAM-1)) + +struct ip_vs_sync_thread_data { + struct task_struct *task; + struct netns_ipvs *ipvs; + struct socket *sock; + char *buf; + int id; +}; + +/* Version 0 definition of packet sizes */ +#define SIMPLE_CONN_SIZE (sizeof(struct ip_vs_sync_conn_v0)) +#define FULL_CONN_SIZE \ +(sizeof(struct ip_vs_sync_conn_v0) + sizeof(struct ip_vs_sync_conn_options)) + + +/* + The master mulitcasts messages (Datagrams) to the backup load balancers + in the following format. + + Version 1: + Note, first byte should be Zero, so ver 0 receivers will drop the packet. + + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | 0 | SyncID | Size | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Count Conns | Version | Reserved, set to Zero | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | IPVS Sync Connection (1) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | . | + ~ . ~ + | . 
| + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | | + | IPVS Sync Connection (n) | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + Version 0 Header + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | Count Conns | SyncID | Size | + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + | IPVS Sync Connection (1) | +*/ + +/* Version 0 header */ +struct ip_vs_sync_mesg_v0 { + __u8 nr_conns; + __u8 syncid; + __be16 size; + + /* ip_vs_sync_conn entries start here */ +}; + +/* Version 1 header */ +struct ip_vs_sync_mesg { + __u8 reserved; /* must be zero */ + __u8 syncid; + __be16 size; + __u8 nr_conns; + __s8 version; /* SYNC_PROTO_VER */ + __u16 spare; + /* ip_vs_sync_conn entries start here */ +}; + +union ipvs_sockaddr { + struct sockaddr_in in; + struct sockaddr_in6 in6; +}; + +struct ip_vs_sync_buff { + struct list_head list; + unsigned long firstuse; + + /* pointers for the message data */ + struct ip_vs_sync_mesg *mesg; + unsigned char *head; + unsigned char *end; +}; + +/* + * Copy of struct ip_vs_seq + * From unaligned network order to aligned host order + */ +static void ntoh_seq(struct ip_vs_seq *no, struct ip_vs_seq *ho) +{ + memset(ho, 0, sizeof(*ho)); + ho->init_seq = get_unaligned_be32(&no->init_seq); + ho->delta = get_unaligned_be32(&no->delta); + ho->previous_delta = get_unaligned_be32(&no->previous_delta); +} + +/* + * Copy of struct ip_vs_seq + * From Aligned host order to unaligned network order + */ +static void hton_seq(struct ip_vs_seq *ho, struct ip_vs_seq *no) +{ + put_unaligned_be32(ho->init_seq, &no->init_seq); + put_unaligned_be32(ho->delta, &no->delta); + put_unaligned_be32(ho->previous_delta, &no->previous_delta); +} + +static inline struct ip_vs_sync_buff * +sb_dequeue(struct netns_ipvs *ipvs, struct ipvs_master_sync_state *ms) +{ + struct ip_vs_sync_buff *sb; + + spin_lock_bh(&ipvs->sync_lock); + if (list_empty(&ms->sync_queue)) { + sb = NULL; + __set_current_state(TASK_INTERRUPTIBLE); + } else { + sb = list_entry(ms->sync_queue.next, struct ip_vs_sync_buff, + list); + list_del(&sb->list); + ms->sync_queue_len--; + if (!ms->sync_queue_len) + ms->sync_queue_delay = 0; + } + spin_unlock_bh(&ipvs->sync_lock); + + return sb; +} + +/* + * Create a new sync buffer for Version 1 proto. + */ +static inline struct ip_vs_sync_buff * +ip_vs_sync_buff_create(struct netns_ipvs *ipvs, unsigned int len) +{ + struct ip_vs_sync_buff *sb; + + if (!(sb=kmalloc(sizeof(struct ip_vs_sync_buff), GFP_ATOMIC))) + return NULL; + + len = max_t(unsigned int, len + sizeof(struct ip_vs_sync_mesg), + ipvs->mcfg.sync_maxlen); + sb->mesg = kmalloc(len, GFP_ATOMIC); + if (!sb->mesg) { + kfree(sb); + return NULL; + } + sb->mesg->reserved = 0; /* old nr_conns i.e. 
must be zero now */ + sb->mesg->version = SYNC_PROTO_VER; + sb->mesg->syncid = ipvs->mcfg.syncid; + sb->mesg->size = htons(sizeof(struct ip_vs_sync_mesg)); + sb->mesg->nr_conns = 0; + sb->mesg->spare = 0; + sb->head = (unsigned char *)sb->mesg + sizeof(struct ip_vs_sync_mesg); + sb->end = (unsigned char *)sb->mesg + len; + + sb->firstuse = jiffies; + return sb; +} + +static inline void ip_vs_sync_buff_release(struct ip_vs_sync_buff *sb) +{ + kfree(sb->mesg); + kfree(sb); +} + +static inline void sb_queue_tail(struct netns_ipvs *ipvs, + struct ipvs_master_sync_state *ms) +{ + struct ip_vs_sync_buff *sb = ms->sync_buff; + + spin_lock(&ipvs->sync_lock); + if (ipvs->sync_state & IP_VS_STATE_MASTER && + ms->sync_queue_len < sysctl_sync_qlen_max(ipvs)) { + if (!ms->sync_queue_len) + schedule_delayed_work(&ms->master_wakeup_work, + max(IPVS_SYNC_SEND_DELAY, 1)); + ms->sync_queue_len++; + list_add_tail(&sb->list, &ms->sync_queue); + if ((++ms->sync_queue_delay) == IPVS_SYNC_WAKEUP_RATE) { + int id = (int)(ms - ipvs->ms); + + wake_up_process(ipvs->master_tinfo[id].task); + } + } else + ip_vs_sync_buff_release(sb); + spin_unlock(&ipvs->sync_lock); +} + +/* + * Get the current sync buffer if it has been created for more + * than the specified time or the specified time is zero. + */ +static inline struct ip_vs_sync_buff * +get_curr_sync_buff(struct netns_ipvs *ipvs, struct ipvs_master_sync_state *ms, + unsigned long time) +{ + struct ip_vs_sync_buff *sb; + + spin_lock_bh(&ipvs->sync_buff_lock); + sb = ms->sync_buff; + if (sb && time_after_eq(jiffies - sb->firstuse, time)) { + ms->sync_buff = NULL; + __set_current_state(TASK_RUNNING); + } else + sb = NULL; + spin_unlock_bh(&ipvs->sync_buff_lock); + return sb; +} + +static inline int +select_master_thread_id(struct netns_ipvs *ipvs, struct ip_vs_conn *cp) +{ + return ((long) cp >> (1 + ilog2(sizeof(*cp)))) & ipvs->threads_mask; +} + +/* + * Create a new sync buffer for Version 0 proto. + */ +static inline struct ip_vs_sync_buff * +ip_vs_sync_buff_create_v0(struct netns_ipvs *ipvs, unsigned int len) +{ + struct ip_vs_sync_buff *sb; + struct ip_vs_sync_mesg_v0 *mesg; + + if (!(sb=kmalloc(sizeof(struct ip_vs_sync_buff), GFP_ATOMIC))) + return NULL; + + len = max_t(unsigned int, len + sizeof(struct ip_vs_sync_mesg_v0), + ipvs->mcfg.sync_maxlen); + sb->mesg = kmalloc(len, GFP_ATOMIC); + if (!sb->mesg) { + kfree(sb); + return NULL; + } + mesg = (struct ip_vs_sync_mesg_v0 *)sb->mesg; + mesg->nr_conns = 0; + mesg->syncid = ipvs->mcfg.syncid; + mesg->size = htons(sizeof(struct ip_vs_sync_mesg_v0)); + sb->head = (unsigned char *)mesg + sizeof(struct ip_vs_sync_mesg_v0); + sb->end = (unsigned char *)mesg + len; + sb->firstuse = jiffies; + return sb; +} + +/* Check if connection is controlled by persistence */ +static inline bool in_persistence(struct ip_vs_conn *cp) +{ + for (cp = cp->control; cp; cp = cp->control) { + if (cp->flags & IP_VS_CONN_F_TEMPLATE) + return true; + } + return false; +} + +/* Check if conn should be synced. + * pkts: conn packets, use sysctl_sync_threshold to avoid packet check + * - (1) sync_refresh_period: reduce sync rate. 
Additionally, retry + * sync_retries times with period of sync_refresh_period/8 + * - (2) if both sync_refresh_period and sync_period are 0 send sync only + * for state changes or only once when pkts matches sync_threshold + * - (3) templates: rate can be reduced only with sync_refresh_period or + * with (2) + */ +static int ip_vs_sync_conn_needed(struct netns_ipvs *ipvs, + struct ip_vs_conn *cp, int pkts) +{ + unsigned long orig = READ_ONCE(cp->sync_endtime); + unsigned long now = jiffies; + unsigned long n = (now + cp->timeout) & ~3UL; + unsigned int sync_refresh_period; + int sync_period; + int force; + + /* Check if we sync in current state */ + if (unlikely(cp->flags & IP_VS_CONN_F_TEMPLATE)) + force = 0; + else if (unlikely(sysctl_sync_persist_mode(ipvs) && in_persistence(cp))) + return 0; + else if (likely(cp->protocol == IPPROTO_TCP)) { + if (!((1 << cp->state) & + ((1 << IP_VS_TCP_S_ESTABLISHED) | + (1 << IP_VS_TCP_S_FIN_WAIT) | + (1 << IP_VS_TCP_S_CLOSE) | + (1 << IP_VS_TCP_S_CLOSE_WAIT) | + (1 << IP_VS_TCP_S_TIME_WAIT)))) + return 0; + force = cp->state != cp->old_state; + if (force && cp->state != IP_VS_TCP_S_ESTABLISHED) + goto set; + } else if (unlikely(cp->protocol == IPPROTO_SCTP)) { + if (!((1 << cp->state) & + ((1 << IP_VS_SCTP_S_ESTABLISHED) | + (1 << IP_VS_SCTP_S_SHUTDOWN_SENT) | + (1 << IP_VS_SCTP_S_SHUTDOWN_RECEIVED) | + (1 << IP_VS_SCTP_S_SHUTDOWN_ACK_SENT) | + (1 << IP_VS_SCTP_S_CLOSED)))) + return 0; + force = cp->state != cp->old_state; + if (force && cp->state != IP_VS_SCTP_S_ESTABLISHED) + goto set; + } else { + /* UDP or another protocol with single state */ + force = 0; + } + + sync_refresh_period = sysctl_sync_refresh_period(ipvs); + if (sync_refresh_period > 0) { + long diff = n - orig; + long min_diff = max(cp->timeout >> 1, 10UL * HZ); + + /* Avoid sync if difference is below sync_refresh_period + * and below the half timeout. + */ + if (abs(diff) < min_t(long, sync_refresh_period, min_diff)) { + int retries = orig & 3; + + if (retries >= sysctl_sync_retries(ipvs)) + return 0; + if (time_before(now, orig - cp->timeout + + (sync_refresh_period >> 3))) + return 0; + n |= retries + 1; + } + } + sync_period = sysctl_sync_period(ipvs); + if (sync_period > 0) { + if (!(cp->flags & IP_VS_CONN_F_TEMPLATE) && + pkts % sync_period != sysctl_sync_threshold(ipvs)) + return 0; + } else if (!sync_refresh_period && + pkts != sysctl_sync_threshold(ipvs)) + return 0; + +set: + cp->old_state = cp->state; + n = cmpxchg(&cp->sync_endtime, orig, n); + return n == orig || force; +} + +/* + * Version 0 , could be switched in by sys_ctl. + * Add an ip_vs_conn information into the current sync_buff. + */ +static void ip_vs_sync_conn_v0(struct netns_ipvs *ipvs, struct ip_vs_conn *cp, + int pkts) +{ + struct ip_vs_sync_mesg_v0 *m; + struct ip_vs_sync_conn_v0 *s; + struct ip_vs_sync_buff *buff; + struct ipvs_master_sync_state *ms; + int id; + unsigned int len; + + if (unlikely(cp->af != AF_INET)) + return; + /* Do not sync ONE PACKET */ + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + return; + + if (!ip_vs_sync_conn_needed(ipvs, cp, pkts)) + return; + + spin_lock_bh(&ipvs->sync_buff_lock); + if (!(ipvs->sync_state & IP_VS_STATE_MASTER)) { + spin_unlock_bh(&ipvs->sync_buff_lock); + return; + } + + id = select_master_thread_id(ipvs, cp); + ms = &ipvs->ms[id]; + buff = ms->sync_buff; + len = (cp->flags & IP_VS_CONN_F_SEQ_MASK) ? 
FULL_CONN_SIZE : + SIMPLE_CONN_SIZE; + if (buff) { + m = (struct ip_vs_sync_mesg_v0 *) buff->mesg; + /* Send buffer if it is for v1 */ + if (buff->head + len > buff->end || !m->nr_conns) { + sb_queue_tail(ipvs, ms); + ms->sync_buff = NULL; + buff = NULL; + } + } + if (!buff) { + buff = ip_vs_sync_buff_create_v0(ipvs, len); + if (!buff) { + spin_unlock_bh(&ipvs->sync_buff_lock); + pr_err("ip_vs_sync_buff_create failed.\n"); + return; + } + ms->sync_buff = buff; + } + + m = (struct ip_vs_sync_mesg_v0 *) buff->mesg; + s = (struct ip_vs_sync_conn_v0 *) buff->head; + + /* copy members */ + s->reserved = 0; + s->protocol = cp->protocol; + s->cport = cp->cport; + s->vport = cp->vport; + s->dport = cp->dport; + s->caddr = cp->caddr.ip; + s->vaddr = cp->vaddr.ip; + s->daddr = cp->daddr.ip; + s->flags = htons(cp->flags & ~IP_VS_CONN_F_HASHED); + s->state = htons(cp->state); + if (cp->flags & IP_VS_CONN_F_SEQ_MASK) { + struct ip_vs_sync_conn_options *opt = + (struct ip_vs_sync_conn_options *)&s[1]; + memcpy(opt, &cp->sync_conn_opt, sizeof(*opt)); + } + + m->nr_conns++; + m->size = htons(ntohs(m->size) + len); + buff->head += len; + spin_unlock_bh(&ipvs->sync_buff_lock); + + /* synchronize its controller if it has */ + cp = cp->control; + if (cp) { + if (cp->flags & IP_VS_CONN_F_TEMPLATE) + pkts = atomic_inc_return(&cp->in_pkts); + else + pkts = sysctl_sync_threshold(ipvs); + ip_vs_sync_conn(ipvs, cp, pkts); + } +} + +/* + * Add an ip_vs_conn information into the current sync_buff. + * Called by ip_vs_in. + * Sending Version 1 messages + */ +void ip_vs_sync_conn(struct netns_ipvs *ipvs, struct ip_vs_conn *cp, int pkts) +{ + struct ip_vs_sync_mesg *m; + union ip_vs_sync_conn *s; + struct ip_vs_sync_buff *buff; + struct ipvs_master_sync_state *ms; + int id; + __u8 *p; + unsigned int len, pe_name_len, pad; + + /* Handle old version of the protocol */ + if (sysctl_sync_ver(ipvs) == 0) { + ip_vs_sync_conn_v0(ipvs, cp, pkts); + return; + } + /* Do not sync ONE PACKET */ + if (cp->flags & IP_VS_CONN_F_ONE_PACKET) + goto control; +sloop: + if (!ip_vs_sync_conn_needed(ipvs, cp, pkts)) + goto control; + + /* Sanity checks */ + pe_name_len = 0; + if (cp->pe_data_len) { + if (!cp->pe_data || !cp->dest) { + IP_VS_ERR_RL("SYNC, connection pe_data invalid\n"); + return; + } + pe_name_len = strnlen(cp->pe->name, IP_VS_PENAME_MAXLEN); + } + + spin_lock_bh(&ipvs->sync_buff_lock); + if (!(ipvs->sync_state & IP_VS_STATE_MASTER)) { + spin_unlock_bh(&ipvs->sync_buff_lock); + return; + } + + id = select_master_thread_id(ipvs, cp); + ms = &ipvs->ms[id]; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) + len = sizeof(struct ip_vs_sync_v6); + else +#endif + len = sizeof(struct ip_vs_sync_v4); + + if (cp->flags & IP_VS_CONN_F_SEQ_MASK) + len += sizeof(struct ip_vs_sync_conn_options) + 2; + + if (cp->pe_data_len) + len += cp->pe_data_len + 2; /* + Param hdr field */ + if (pe_name_len) + len += pe_name_len + 2; + + /* check if there is a space for this one */ + pad = 0; + buff = ms->sync_buff; + if (buff) { + m = buff->mesg; + pad = (4 - (size_t) buff->head) & 3; + /* Send buffer if it is for v0 */ + if (buff->head + len + pad > buff->end || m->reserved) { + sb_queue_tail(ipvs, ms); + ms->sync_buff = NULL; + buff = NULL; + pad = 0; + } + } + + if (!buff) { + buff = ip_vs_sync_buff_create(ipvs, len); + if (!buff) { + spin_unlock_bh(&ipvs->sync_buff_lock); + pr_err("ip_vs_sync_buff_create failed.\n"); + return; + } + ms->sync_buff = buff; + m = buff->mesg; + } + + p = buff->head; + buff->head += pad + len; + m->size = 
htons(ntohs(m->size) + pad + len); + /* Add ev. padding from prev. sync_conn */ + while (pad--) + *(p++) = 0; + + s = (union ip_vs_sync_conn *)p; + + /* Set message type & copy members */ + s->v4.type = (cp->af == AF_INET6 ? STYPE_F_INET6 : 0); + s->v4.ver_size = htons(len & SVER_MASK); /* Version 0 */ + s->v4.flags = htonl(cp->flags & ~IP_VS_CONN_F_HASHED); + s->v4.state = htons(cp->state); + s->v4.protocol = cp->protocol; + s->v4.cport = cp->cport; + s->v4.vport = cp->vport; + s->v4.dport = cp->dport; + s->v4.fwmark = htonl(cp->fwmark); + s->v4.timeout = htonl(cp->timeout / HZ); + m->nr_conns++; + +#ifdef CONFIG_IP_VS_IPV6 + if (cp->af == AF_INET6) { + p += sizeof(struct ip_vs_sync_v6); + s->v6.caddr = cp->caddr.in6; + s->v6.vaddr = cp->vaddr.in6; + s->v6.daddr = cp->daddr.in6; + } else +#endif + { + p += sizeof(struct ip_vs_sync_v4); /* options ptr */ + s->v4.caddr = cp->caddr.ip; + s->v4.vaddr = cp->vaddr.ip; + s->v4.daddr = cp->daddr.ip; + } + if (cp->flags & IP_VS_CONN_F_SEQ_MASK) { + *(p++) = IPVS_OPT_SEQ_DATA; + *(p++) = sizeof(struct ip_vs_sync_conn_options); + hton_seq((struct ip_vs_seq *)p, &cp->in_seq); + p += sizeof(struct ip_vs_seq); + hton_seq((struct ip_vs_seq *)p, &cp->out_seq); + p += sizeof(struct ip_vs_seq); + } + /* Handle pe data */ + if (cp->pe_data_len && cp->pe_data) { + *(p++) = IPVS_OPT_PE_DATA; + *(p++) = cp->pe_data_len; + memcpy(p, cp->pe_data, cp->pe_data_len); + p += cp->pe_data_len; + if (pe_name_len) { + /* Add PE_NAME */ + *(p++) = IPVS_OPT_PE_NAME; + *(p++) = pe_name_len; + memcpy(p, cp->pe->name, pe_name_len); + p += pe_name_len; + } + } + + spin_unlock_bh(&ipvs->sync_buff_lock); + +control: + /* synchronize its controller if it has */ + cp = cp->control; + if (!cp) + return; + if (cp->flags & IP_VS_CONN_F_TEMPLATE) + pkts = atomic_inc_return(&cp->in_pkts); + else + pkts = sysctl_sync_threshold(ipvs); + goto sloop; +} + +/* + * fill_param used by version 1 + */ +static inline int +ip_vs_conn_fill_param_sync(struct netns_ipvs *ipvs, int af, union ip_vs_sync_conn *sc, + struct ip_vs_conn_param *p, + __u8 *pe_data, unsigned int pe_data_len, + __u8 *pe_name, unsigned int pe_name_len) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (af == AF_INET6) + ip_vs_conn_fill_param(ipvs, af, sc->v6.protocol, + (const union nf_inet_addr *)&sc->v6.caddr, + sc->v6.cport, + (const union nf_inet_addr *)&sc->v6.vaddr, + sc->v6.vport, p); + else +#endif + ip_vs_conn_fill_param(ipvs, af, sc->v4.protocol, + (const union nf_inet_addr *)&sc->v4.caddr, + sc->v4.cport, + (const union nf_inet_addr *)&sc->v4.vaddr, + sc->v4.vport, p); + /* Handle pe data */ + if (pe_data_len) { + if (pe_name_len) { + char buff[IP_VS_PENAME_MAXLEN+1]; + + memcpy(buff, pe_name, pe_name_len); + buff[pe_name_len]=0; + p->pe = __ip_vs_pe_getbyname(buff); + if (!p->pe) { + IP_VS_DBG(3, "BACKUP, no %s engine found/loaded\n", + buff); + return 1; + } + } else { + IP_VS_ERR_RL("BACKUP, Invalid PE parameters\n"); + return 1; + } + + p->pe_data = kmemdup(pe_data, pe_data_len, GFP_ATOMIC); + if (!p->pe_data) { + module_put(p->pe->module); + return -ENOMEM; + } + p->pe_data_len = pe_data_len; + } + return 0; +} + +/* + * Connection Add / Update. + * Common for version 0 and 1 reception of backup sync_conns. + * Param: ... + * timeout is in sec. 
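+ *
+ * Rough sketch of the flow implemented below (illustrative only; error
+ * handling and flag/counter updates omitted):
+ *
+ *	cp = ip_vs_conn_in_get(param);	/* or ip_vs_ct_in_get() for templates */
+ *	if (!cp)
+ *		cp = ip_vs_conn_new(param, type, daddr, dport, flags, dest, fwmark);
+ *	cp->state = state;		/* adopt the state reported by the master */
+ *	cp->timeout = timeout * HZ;	/* or a per-protocol/default timeout */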
+ */ +static void ip_vs_proc_conn(struct netns_ipvs *ipvs, struct ip_vs_conn_param *param, + unsigned int flags, unsigned int state, + unsigned int protocol, unsigned int type, + const union nf_inet_addr *daddr, __be16 dport, + unsigned long timeout, __u32 fwmark, + struct ip_vs_sync_conn_options *opt) +{ + struct ip_vs_dest *dest; + struct ip_vs_conn *cp; + + if (!(flags & IP_VS_CONN_F_TEMPLATE)) { + cp = ip_vs_conn_in_get(param); + if (cp && ((cp->dport != dport) || + !ip_vs_addr_equal(cp->daf, &cp->daddr, daddr))) { + if (!(flags & IP_VS_CONN_F_INACTIVE)) { + ip_vs_conn_expire_now(cp); + __ip_vs_conn_put(cp); + cp = NULL; + } else { + /* This is the expiration message for the + * connection that was already replaced, so we + * just ignore it. + */ + __ip_vs_conn_put(cp); + kfree(param->pe_data); + return; + } + } + } else { + cp = ip_vs_ct_in_get(param); + } + + if (cp) { + /* Free pe_data */ + kfree(param->pe_data); + + dest = cp->dest; + spin_lock_bh(&cp->lock); + if ((cp->flags ^ flags) & IP_VS_CONN_F_INACTIVE && + !(flags & IP_VS_CONN_F_TEMPLATE) && dest) { + if (flags & IP_VS_CONN_F_INACTIVE) { + atomic_dec(&dest->activeconns); + atomic_inc(&dest->inactconns); + } else { + atomic_inc(&dest->activeconns); + atomic_dec(&dest->inactconns); + } + } + flags &= IP_VS_CONN_F_BACKUP_UPD_MASK; + flags |= cp->flags & ~IP_VS_CONN_F_BACKUP_UPD_MASK; + cp->flags = flags; + spin_unlock_bh(&cp->lock); + if (!dest) + ip_vs_try_bind_dest(cp); + } else { + /* + * Find the appropriate destination for the connection. + * If it is not found the connection will remain unbound + * but still handled. + */ + rcu_read_lock(); + /* This function is only invoked by the synchronization + * code. We do not currently support heterogeneous pools + * with synchronization, so we can make the assumption that + * the svc_af is the same as the dest_af + */ + dest = ip_vs_find_dest(ipvs, type, type, daddr, dport, + param->vaddr, param->vport, protocol, + fwmark, flags); + + cp = ip_vs_conn_new(param, type, daddr, dport, flags, dest, + fwmark); + rcu_read_unlock(); + if (!cp) { + kfree(param->pe_data); + IP_VS_DBG(2, "BACKUP, add new conn. failed\n"); + return; + } + if (!(flags & IP_VS_CONN_F_TEMPLATE)) + kfree(param->pe_data); + } + + if (opt) { + cp->in_seq = opt->in_seq; + cp->out_seq = opt->out_seq; + } + atomic_set(&cp->in_pkts, sysctl_sync_threshold(ipvs)); + cp->state = state; + cp->old_state = cp->state; + /* + * For Ver 0 messages style + * - Not possible to recover the right timeout for templates + * - can not find the right fwmark + * virtual service. If needed, we can do it for + * non-fwmark persistent services. + * Ver 1 messages style. + * - No problem. 
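+	 *
+	 * Example (assumed values): a Version 0 entry arrives with timeout 0
+	 * for a TCP connection in ESTABLISHED state; the code below then
+	 * falls back to pd->timeout_table[state] for normal connections, and
+	 * to the fixed 3*60*HZ default for templates or when no protocol
+	 * data is available.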
+ */ + if (timeout) { + if (timeout > MAX_SCHEDULE_TIMEOUT / HZ) + timeout = MAX_SCHEDULE_TIMEOUT / HZ; + cp->timeout = timeout*HZ; + } else { + struct ip_vs_proto_data *pd; + + pd = ip_vs_proto_data_get(ipvs, protocol); + if (!(flags & IP_VS_CONN_F_TEMPLATE) && pd && pd->timeout_table) + cp->timeout = pd->timeout_table[state]; + else + cp->timeout = (3*60*HZ); + } + ip_vs_conn_put(cp); +} + +/* + * Process received multicast message for Version 0 + */ +static void ip_vs_process_message_v0(struct netns_ipvs *ipvs, const char *buffer, + const size_t buflen) +{ + struct ip_vs_sync_mesg_v0 *m = (struct ip_vs_sync_mesg_v0 *)buffer; + struct ip_vs_sync_conn_v0 *s; + struct ip_vs_sync_conn_options *opt; + struct ip_vs_protocol *pp; + struct ip_vs_conn_param param; + char *p; + int i; + + p = (char *)buffer + sizeof(struct ip_vs_sync_mesg_v0); + for (i=0; inr_conns; i++) { + unsigned int flags, state; + + if (p + SIMPLE_CONN_SIZE > buffer+buflen) { + IP_VS_ERR_RL("BACKUP v0, bogus conn\n"); + return; + } + s = (struct ip_vs_sync_conn_v0 *) p; + flags = ntohs(s->flags) | IP_VS_CONN_F_SYNC; + flags &= ~IP_VS_CONN_F_HASHED; + if (flags & IP_VS_CONN_F_SEQ_MASK) { + opt = (struct ip_vs_sync_conn_options *)&s[1]; + p += FULL_CONN_SIZE; + if (p > buffer+buflen) { + IP_VS_ERR_RL("BACKUP v0, Dropping buffer bogus conn options\n"); + return; + } + } else { + opt = NULL; + p += SIMPLE_CONN_SIZE; + } + + state = ntohs(s->state); + if (!(flags & IP_VS_CONN_F_TEMPLATE)) { + pp = ip_vs_proto_get(s->protocol); + if (!pp) { + IP_VS_DBG(2, "BACKUP v0, Unsupported protocol %u\n", + s->protocol); + continue; + } + if (state >= pp->num_states) { + IP_VS_DBG(2, "BACKUP v0, Invalid %s state %u\n", + pp->name, state); + continue; + } + } else { + if (state >= IP_VS_CTPL_S_LAST) + IP_VS_DBG(7, "BACKUP v0, Invalid tpl state %u\n", + state); + } + + ip_vs_conn_fill_param(ipvs, AF_INET, s->protocol, + (const union nf_inet_addr *)&s->caddr, + s->cport, + (const union nf_inet_addr *)&s->vaddr, + s->vport, ¶m); + + /* Send timeout as Zero */ + ip_vs_proc_conn(ipvs, ¶m, flags, state, s->protocol, AF_INET, + (union nf_inet_addr *)&s->daddr, s->dport, + 0, 0, opt); + } +} + +/* + * Handle options + */ +static inline int ip_vs_proc_seqopt(__u8 *p, unsigned int plen, + __u32 *opt_flags, + struct ip_vs_sync_conn_options *opt) +{ + struct ip_vs_sync_conn_options *topt; + + topt = (struct ip_vs_sync_conn_options *)p; + + if (plen != sizeof(struct ip_vs_sync_conn_options)) { + IP_VS_DBG(2, "BACKUP, bogus conn options length\n"); + return -EINVAL; + } + if (*opt_flags & IPVS_OPT_F_SEQ_DATA) { + IP_VS_DBG(2, "BACKUP, conn options found twice\n"); + return -EINVAL; + } + ntoh_seq(&topt->in_seq, &opt->in_seq); + ntoh_seq(&topt->out_seq, &opt->out_seq); + *opt_flags |= IPVS_OPT_F_SEQ_DATA; + return 0; +} + +static int ip_vs_proc_str(__u8 *p, unsigned int plen, unsigned int *data_len, + __u8 **data, unsigned int maxlen, + __u32 *opt_flags, __u32 flag) +{ + if (plen > maxlen) { + IP_VS_DBG(2, "BACKUP, bogus par.data len > %d\n", maxlen); + return -EINVAL; + } + if (*opt_flags & flag) { + IP_VS_DBG(2, "BACKUP, Par.data found twice 0x%x\n", flag); + return -EINVAL; + } + *data_len = plen; + *data = p; + *opt_flags |= flag; + return 0; +} +/* + * Process a Version 1 sync. 
connection + */ +static inline int ip_vs_proc_sync_conn(struct netns_ipvs *ipvs, __u8 *p, __u8 *msg_end) +{ + struct ip_vs_sync_conn_options opt; + union ip_vs_sync_conn *s; + struct ip_vs_protocol *pp; + struct ip_vs_conn_param param; + __u32 flags; + unsigned int af, state, pe_data_len=0, pe_name_len=0; + __u8 *pe_data=NULL, *pe_name=NULL; + __u32 opt_flags=0; + int retc=0; + + s = (union ip_vs_sync_conn *) p; + + if (s->v6.type & STYPE_F_INET6) { +#ifdef CONFIG_IP_VS_IPV6 + af = AF_INET6; + p += sizeof(struct ip_vs_sync_v6); +#else + IP_VS_DBG(3,"BACKUP, IPv6 msg received, and IPVS is not compiled for IPv6\n"); + retc = 10; + goto out; +#endif + } else if (!s->v4.type) { + af = AF_INET; + p += sizeof(struct ip_vs_sync_v4); + } else { + return -10; + } + if (p > msg_end) + return -20; + + /* Process optional params check Type & Len. */ + while (p < msg_end) { + int ptype; + int plen; + + if (p+2 > msg_end) + return -30; + ptype = *(p++); + plen = *(p++); + + if (!plen || ((p + plen) > msg_end)) + return -40; + /* Handle seq option p = param data */ + switch (ptype & ~IPVS_OPT_F_PARAM) { + case IPVS_OPT_SEQ_DATA: + if (ip_vs_proc_seqopt(p, plen, &opt_flags, &opt)) + return -50; + break; + + case IPVS_OPT_PE_DATA: + if (ip_vs_proc_str(p, plen, &pe_data_len, &pe_data, + IP_VS_PEDATA_MAXLEN, &opt_flags, + IPVS_OPT_F_PE_DATA)) + return -60; + break; + + case IPVS_OPT_PE_NAME: + if (ip_vs_proc_str(p, plen,&pe_name_len, &pe_name, + IP_VS_PENAME_MAXLEN, &opt_flags, + IPVS_OPT_F_PE_NAME)) + return -70; + break; + + default: + /* Param data mandatory ? */ + if (!(ptype & IPVS_OPT_F_PARAM)) { + IP_VS_DBG(3, "BACKUP, Unknown mandatory param %d found\n", + ptype & ~IPVS_OPT_F_PARAM); + retc = 20; + goto out; + } + } + p += plen; /* Next option */ + } + + /* Get flags and Mask off unsupported */ + flags = ntohl(s->v4.flags) & IP_VS_CONN_F_BACKUP_MASK; + flags |= IP_VS_CONN_F_SYNC; + state = ntohs(s->v4.state); + + if (!(flags & IP_VS_CONN_F_TEMPLATE)) { + pp = ip_vs_proto_get(s->v4.protocol); + if (!pp) { + IP_VS_DBG(3,"BACKUP, Unsupported protocol %u\n", + s->v4.protocol); + retc = 30; + goto out; + } + if (state >= pp->num_states) { + IP_VS_DBG(3, "BACKUP, Invalid %s state %u\n", + pp->name, state); + retc = 40; + goto out; + } + } else { + if (state >= IP_VS_CTPL_S_LAST) + IP_VS_DBG(7, "BACKUP, Invalid tpl state %u\n", + state); + } + if (ip_vs_conn_fill_param_sync(ipvs, af, s, ¶m, pe_data, + pe_data_len, pe_name, pe_name_len)) { + retc = 50; + goto out; + } + /* If only IPv4, just silent skip IPv6 */ + if (af == AF_INET) + ip_vs_proc_conn(ipvs, ¶m, flags, state, s->v4.protocol, af, + (union nf_inet_addr *)&s->v4.daddr, s->v4.dport, + ntohl(s->v4.timeout), ntohl(s->v4.fwmark), + (opt_flags & IPVS_OPT_F_SEQ_DATA ? &opt : NULL) + ); +#ifdef CONFIG_IP_VS_IPV6 + else + ip_vs_proc_conn(ipvs, ¶m, flags, state, s->v6.protocol, af, + (union nf_inet_addr *)&s->v6.daddr, s->v6.dport, + ntohl(s->v6.timeout), ntohl(s->v6.fwmark), + (opt_flags & IPVS_OPT_F_SEQ_DATA ? &opt : NULL) + ); +#endif + ip_vs_pe_put(param.pe); + return 0; + /* Error exit */ +out: + IP_VS_DBG(2, "BACKUP, Single msg dropped err:%d\n", retc); + return retc; + +} +/* + * Process received multicast message and create the corresponding + * ip_vs_conn entries. 
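+ * Every datagram received by the backup threads is fed through here.
+ * Sketch of the Version 1 walk performed below (illustrative only,
+ * sanity checks omitted):
+ *
+ *	p = buffer + sizeof(struct ip_vs_sync_mesg);
+ *	for (i = 0; i < nr_conns; i++) {
+ *		s = (union ip_vs_sync_conn *)p;
+ *		size = ntohs(s->v4.ver_size) & SVER_MASK;
+ *		ip_vs_proc_sync_conn(ipvs, p, p + size);
+ *		p += (size + 3) & ~3;	/* keep 32 bit alignment */
+ *	}
+ *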
+ * Handles Version 0 & 1 + */ +static void ip_vs_process_message(struct netns_ipvs *ipvs, __u8 *buffer, + const size_t buflen) +{ + struct ip_vs_sync_mesg *m2 = (struct ip_vs_sync_mesg *)buffer; + __u8 *p, *msg_end; + int i, nr_conns; + + if (buflen < sizeof(struct ip_vs_sync_mesg_v0)) { + IP_VS_DBG(2, "BACKUP, message header too short\n"); + return; + } + + if (buflen != ntohs(m2->size)) { + IP_VS_DBG(2, "BACKUP, bogus message size\n"); + return; + } + /* SyncID sanity check */ + if (ipvs->bcfg.syncid != 0 && m2->syncid != ipvs->bcfg.syncid) { + IP_VS_DBG(7, "BACKUP, Ignoring syncid = %d\n", m2->syncid); + return; + } + /* Handle version 1 message */ + if ((m2->version == SYNC_PROTO_VER) && (m2->reserved == 0) + && (m2->spare == 0)) { + + msg_end = buffer + sizeof(struct ip_vs_sync_mesg); + nr_conns = m2->nr_conns; + + for (i=0; iv4) > buffer+buflen) { + IP_VS_ERR_RL("BACKUP, Dropping buffer, too small\n"); + return; + } + s = (union ip_vs_sync_conn *)p; + size = ntohs(s->v4.ver_size) & SVER_MASK; + msg_end = p + size; + /* Basic sanity checks */ + if (msg_end > buffer+buflen) { + IP_VS_ERR_RL("BACKUP, Dropping buffer, msg > buffer\n"); + return; + } + if (ntohs(s->v4.ver_size) >> SVER_SHIFT) { + IP_VS_ERR_RL("BACKUP, Dropping buffer, Unknown version %d\n", + ntohs(s->v4.ver_size) >> SVER_SHIFT); + return; + } + /* Process a single sync_conn */ + retc = ip_vs_proc_sync_conn(ipvs, p, msg_end); + if (retc < 0) { + IP_VS_ERR_RL("BACKUP, Dropping buffer, Err: %d in decoding\n", + retc); + return; + } + /* Make sure we have 32 bit alignment */ + msg_end = p + ((size + 3) & ~3); + } + } else { + /* Old type of message */ + ip_vs_process_message_v0(ipvs, buffer, buflen); + return; + } +} + + +/* + * Setup sndbuf (mode=1) or rcvbuf (mode=0) + */ +static void set_sock_size(struct sock *sk, int mode, int val) +{ + /* setsockopt(sock, SOL_SOCKET, SO_SNDBUF, &val, sizeof(val)); */ + /* setsockopt(sock, SOL_SOCKET, SO_RCVBUF, &val, sizeof(val)); */ + lock_sock(sk); + if (mode) { + val = clamp_t(int, val, (SOCK_MIN_SNDBUF + 1) / 2, + READ_ONCE(sysctl_wmem_max)); + sk->sk_sndbuf = val * 2; + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + } else { + val = clamp_t(int, val, (SOCK_MIN_RCVBUF + 1) / 2, + READ_ONCE(sysctl_rmem_max)); + sk->sk_rcvbuf = val * 2; + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + } + release_sock(sk); +} + +/* + * Setup loopback of outgoing multicasts on a sending socket + */ +static void set_mcast_loop(struct sock *sk, u_char loop) +{ + struct inet_sock *inet = inet_sk(sk); + + /* setsockopt(sock, SOL_IP, IP_MULTICAST_LOOP, &loop, sizeof(loop)); */ + lock_sock(sk); + inet->mc_loop = loop ? 1 : 0; +#ifdef CONFIG_IP_VS_IPV6 + if (sk->sk_family == AF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + + /* IPV6_MULTICAST_LOOP */ + np->mc_loop = loop ? 
1 : 0; + } +#endif + release_sock(sk); +} + +/* + * Specify TTL for outgoing multicasts on a sending socket + */ +static void set_mcast_ttl(struct sock *sk, u_char ttl) +{ + struct inet_sock *inet = inet_sk(sk); + + /* setsockopt(sock, SOL_IP, IP_MULTICAST_TTL, &ttl, sizeof(ttl)); */ + lock_sock(sk); + inet->mc_ttl = ttl; +#ifdef CONFIG_IP_VS_IPV6 + if (sk->sk_family == AF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + + /* IPV6_MULTICAST_HOPS */ + np->mcast_hops = ttl; + } +#endif + release_sock(sk); +} + +/* Control fragmentation of messages */ +static void set_mcast_pmtudisc(struct sock *sk, int val) +{ + struct inet_sock *inet = inet_sk(sk); + + /* setsockopt(sock, SOL_IP, IP_MTU_DISCOVER, &val, sizeof(val)); */ + lock_sock(sk); + inet->pmtudisc = val; +#ifdef CONFIG_IP_VS_IPV6 + if (sk->sk_family == AF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + + /* IPV6_MTU_DISCOVER */ + np->pmtudisc = val; + } +#endif + release_sock(sk); +} + +/* + * Specifiy default interface for outgoing multicasts + */ +static int set_mcast_if(struct sock *sk, struct net_device *dev) +{ + struct inet_sock *inet = inet_sk(sk); + + if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) + return -EINVAL; + + lock_sock(sk); + inet->mc_index = dev->ifindex; + /* inet->mc_addr = 0; */ +#ifdef CONFIG_IP_VS_IPV6 + if (sk->sk_family == AF_INET6) { + struct ipv6_pinfo *np = inet6_sk(sk); + + /* IPV6_MULTICAST_IF */ + np->mcast_oif = dev->ifindex; + } +#endif + release_sock(sk); + + return 0; +} + + +/* + * Join a multicast group. + * the group is specified by a class D multicast address 224.0.0.0/8 + * in the in_addr structure passed in as a parameter. + */ +static int +join_mcast_group(struct sock *sk, struct in_addr *addr, struct net_device *dev) +{ + struct ip_mreqn mreq; + int ret; + + memset(&mreq, 0, sizeof(mreq)); + memcpy(&mreq.imr_multiaddr, addr, sizeof(struct in_addr)); + + if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) + return -EINVAL; + + mreq.imr_ifindex = dev->ifindex; + + lock_sock(sk); + ret = ip_mc_join_group(sk, &mreq); + release_sock(sk); + + return ret; +} + +#ifdef CONFIG_IP_VS_IPV6 +static int join_mcast_group6(struct sock *sk, struct in6_addr *addr, + struct net_device *dev) +{ + int ret; + + if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) + return -EINVAL; + + lock_sock(sk); + ret = ipv6_sock_mc_join(sk, dev->ifindex, addr); + release_sock(sk); + + return ret; +} +#endif + +static int bind_mcastif_addr(struct socket *sock, struct net_device *dev) +{ + __be32 addr; + struct sockaddr_in sin; + + addr = inet_select_addr(dev, 0, RT_SCOPE_UNIVERSE); + if (!addr) + pr_err("You probably need to specify IP address on " + "multicast interface.\n"); + + IP_VS_DBG(7, "binding socket with (%s) %pI4\n", + dev->name, &addr); + + /* Now bind the socket with the address of multicast interface */ + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = addr; + sin.sin_port = 0; + + return kernel_bind(sock, (struct sockaddr *)&sin, sizeof(sin)); +} + +static void get_mcast_sockaddr(union ipvs_sockaddr *sa, int *salen, + struct ipvs_sync_daemon_cfg *c, int id) +{ + if (AF_INET6 == c->mcast_af) { + sa->in6 = (struct sockaddr_in6) { + .sin6_family = AF_INET6, + .sin6_port = htons(c->mcast_port + id), + }; + sa->in6.sin6_addr = c->mcast_group.in6; + *salen = sizeof(sa->in6); + } else { + sa->in = (struct sockaddr_in) { + .sin_family = AF_INET, + .sin_port = htons(c->mcast_port + id), + }; + sa->in.sin_addr = c->mcast_group.in; + *salen = sizeof(sa->in); + } +} + +/* + * 
Set up sending multicast socket over UDP + */ +static int make_send_sock(struct netns_ipvs *ipvs, int id, + struct net_device *dev, struct socket **sock_ret) +{ + /* multicast addr */ + union ipvs_sockaddr mcast_addr; + struct socket *sock; + int result, salen; + + /* First create a socket */ + result = sock_create_kern(ipvs->net, ipvs->mcfg.mcast_af, SOCK_DGRAM, + IPPROTO_UDP, &sock); + if (result < 0) { + pr_err("Error during creation of socket; terminating\n"); + goto error; + } + *sock_ret = sock; + result = set_mcast_if(sock->sk, dev); + if (result < 0) { + pr_err("Error setting outbound mcast interface\n"); + goto error; + } + + set_mcast_loop(sock->sk, 0); + set_mcast_ttl(sock->sk, ipvs->mcfg.mcast_ttl); + /* Allow fragmentation if MTU changes */ + set_mcast_pmtudisc(sock->sk, IP_PMTUDISC_DONT); + result = sysctl_sync_sock_size(ipvs); + if (result > 0) + set_sock_size(sock->sk, 1, result); + + if (AF_INET == ipvs->mcfg.mcast_af) + result = bind_mcastif_addr(sock, dev); + else + result = 0; + if (result < 0) { + pr_err("Error binding address of the mcast interface\n"); + goto error; + } + + get_mcast_sockaddr(&mcast_addr, &salen, &ipvs->mcfg, id); + result = kernel_connect(sock, (struct sockaddr *)&mcast_addr, + salen, 0); + if (result < 0) { + pr_err("Error connecting to the multicast addr\n"); + goto error; + } + + return 0; + +error: + return result; +} + + +/* + * Set up receiving multicast socket over UDP + */ +static int make_receive_sock(struct netns_ipvs *ipvs, int id, + struct net_device *dev, struct socket **sock_ret) +{ + /* multicast addr */ + union ipvs_sockaddr mcast_addr; + struct socket *sock; + int result, salen; + + /* First create a socket */ + result = sock_create_kern(ipvs->net, ipvs->bcfg.mcast_af, SOCK_DGRAM, + IPPROTO_UDP, &sock); + if (result < 0) { + pr_err("Error during creation of socket; terminating\n"); + goto error; + } + *sock_ret = sock; + /* it is equivalent to the REUSEADDR option in user-space */ + sock->sk->sk_reuse = SK_CAN_REUSE; + result = sysctl_sync_sock_size(ipvs); + if (result > 0) + set_sock_size(sock->sk, 0, result); + + get_mcast_sockaddr(&mcast_addr, &salen, &ipvs->bcfg, id); + sock->sk->sk_bound_dev_if = dev->ifindex; + result = kernel_bind(sock, (struct sockaddr *)&mcast_addr, salen); + if (result < 0) { + pr_err("Error binding to the multicast addr\n"); + goto error; + } + + /* join the multicast group */ +#ifdef CONFIG_IP_VS_IPV6 + if (ipvs->bcfg.mcast_af == AF_INET6) + result = join_mcast_group6(sock->sk, &mcast_addr.in6.sin6_addr, + dev); + else +#endif + result = join_mcast_group(sock->sk, &mcast_addr.in.sin_addr, + dev); + if (result < 0) { + pr_err("Error joining to the multicast group\n"); + goto error; + } + + return 0; + +error: + return result; +} + + +static int +ip_vs_send_async(struct socket *sock, const char *buffer, const size_t length) +{ + struct msghdr msg = {.msg_flags = MSG_DONTWAIT|MSG_NOSIGNAL}; + struct kvec iov; + int len; + + EnterFunction(7); + iov.iov_base = (void *)buffer; + iov.iov_len = length; + + len = kernel_sendmsg(sock, &msg, &iov, 1, (size_t)(length)); + + LeaveFunction(7); + return len; +} + +static int +ip_vs_send_sync_msg(struct socket *sock, struct ip_vs_sync_mesg *msg) +{ + int msize; + int ret; + + msize = ntohs(msg->size); + + ret = ip_vs_send_async(sock, (char *)msg, msize); + if (ret >= 0 || ret == -EAGAIN) + return ret; + pr_err("ip_vs_send_async error %d\n", ret); + return 0; +} + +static int +ip_vs_receive(struct socket *sock, char *buffer, const size_t buflen) +{ + struct msghdr msg = 
{NULL,}; + struct kvec iov = {buffer, buflen}; + int len; + + EnterFunction(7); + + /* Receive a packet */ + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, buflen); + len = sock_recvmsg(sock, &msg, MSG_DONTWAIT); + if (len < 0) + return len; + + LeaveFunction(7); + return len; +} + +/* Wakeup the master thread for sending */ +static void master_wakeup_work_handler(struct work_struct *work) +{ + struct ipvs_master_sync_state *ms = + container_of(work, struct ipvs_master_sync_state, + master_wakeup_work.work); + struct netns_ipvs *ipvs = ms->ipvs; + + spin_lock_bh(&ipvs->sync_lock); + if (ms->sync_queue_len && + ms->sync_queue_delay < IPVS_SYNC_WAKEUP_RATE) { + int id = (int)(ms - ipvs->ms); + + ms->sync_queue_delay = IPVS_SYNC_WAKEUP_RATE; + wake_up_process(ipvs->master_tinfo[id].task); + } + spin_unlock_bh(&ipvs->sync_lock); +} + +/* Get next buffer to send */ +static inline struct ip_vs_sync_buff * +next_sync_buff(struct netns_ipvs *ipvs, struct ipvs_master_sync_state *ms) +{ + struct ip_vs_sync_buff *sb; + + sb = sb_dequeue(ipvs, ms); + if (sb) + return sb; + /* Do not delay entries in buffer for more than 2 seconds */ + return get_curr_sync_buff(ipvs, ms, IPVS_SYNC_FLUSH_TIME); +} + +static int sync_thread_master(void *data) +{ + struct ip_vs_sync_thread_data *tinfo = data; + struct netns_ipvs *ipvs = tinfo->ipvs; + struct ipvs_master_sync_state *ms = &ipvs->ms[tinfo->id]; + struct sock *sk = tinfo->sock->sk; + struct ip_vs_sync_buff *sb; + + pr_info("sync thread started: state = MASTER, mcast_ifn = %s, " + "syncid = %d, id = %d\n", + ipvs->mcfg.mcast_ifn, ipvs->mcfg.syncid, tinfo->id); + + for (;;) { + sb = next_sync_buff(ipvs, ms); + if (unlikely(kthread_should_stop())) + break; + if (!sb) { + schedule_timeout(IPVS_SYNC_CHECK_PERIOD); + continue; + } + while (ip_vs_send_sync_msg(tinfo->sock, sb->mesg) < 0) { + /* (Ab)use interruptible sleep to avoid increasing + * the load avg. + */ + __wait_event_interruptible(*sk_sleep(sk), + sock_writeable(sk) || + kthread_should_stop()); + if (unlikely(kthread_should_stop())) + goto done; + } + ip_vs_sync_buff_release(sb); + } + +done: + __set_current_state(TASK_RUNNING); + if (sb) + ip_vs_sync_buff_release(sb); + + /* clean up the sync_buff queue */ + while ((sb = sb_dequeue(ipvs, ms))) + ip_vs_sync_buff_release(sb); + __set_current_state(TASK_RUNNING); + + /* clean up the current sync_buff */ + sb = get_curr_sync_buff(ipvs, ms, 0); + if (sb) + ip_vs_sync_buff_release(sb); + + return 0; +} + + +static int sync_thread_backup(void *data) +{ + struct ip_vs_sync_thread_data *tinfo = data; + struct netns_ipvs *ipvs = tinfo->ipvs; + struct sock *sk = tinfo->sock->sk; + struct udp_sock *up = udp_sk(sk); + int len; + + pr_info("sync thread started: state = BACKUP, mcast_ifn = %s, " + "syncid = %d, id = %d\n", + ipvs->bcfg.mcast_ifn, ipvs->bcfg.syncid, tinfo->id); + + while (!kthread_should_stop()) { + wait_event_interruptible(*sk_sleep(sk), + !skb_queue_empty_lockless(&sk->sk_receive_queue) || + !skb_queue_empty_lockless(&up->reader_queue) || + kthread_should_stop()); + + /* do we have data now? 
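+		 * (both sk->sk_receive_queue and the UDP reader_queue are
+		 * drained below, so a datagram queued on either one is
+		 * handled before sleeping again)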
*/ + while (!skb_queue_empty_lockless(&sk->sk_receive_queue) || + !skb_queue_empty_lockless(&up->reader_queue)) { + len = ip_vs_receive(tinfo->sock, tinfo->buf, + ipvs->bcfg.sync_maxlen); + if (len <= 0) { + if (len != -EAGAIN) + pr_err("receiving message error\n"); + break; + } + + ip_vs_process_message(ipvs, tinfo->buf, len); + } + } + + return 0; +} + + +int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, + int state) +{ + struct ip_vs_sync_thread_data *ti = NULL, *tinfo; + struct task_struct *task; + struct net_device *dev; + char *name; + int (*threadfn)(void *data); + int id = 0, count, hlen; + int result = -ENOMEM; + u16 mtu, min_mtu; + + IP_VS_DBG(7, "%s(): pid %d\n", __func__, task_pid_nr(current)); + IP_VS_DBG(7, "Each ip_vs_sync_conn entry needs %zd bytes\n", + sizeof(struct ip_vs_sync_conn_v0)); + + /* increase the module use count */ + if (!ip_vs_use_count_inc()) + return -ENOPROTOOPT; + + /* Do not hold one mutex and then to block on another */ + for (;;) { + rtnl_lock(); + if (mutex_trylock(&ipvs->sync_mutex)) + break; + rtnl_unlock(); + mutex_lock(&ipvs->sync_mutex); + if (rtnl_trylock()) + break; + mutex_unlock(&ipvs->sync_mutex); + } + + if (!ipvs->sync_state) { + count = clamp(sysctl_sync_ports(ipvs), 1, IPVS_SYNC_PORTS_MAX); + ipvs->threads_mask = count - 1; + } else + count = ipvs->threads_mask + 1; + + if (c->mcast_af == AF_UNSPEC) { + c->mcast_af = AF_INET; + c->mcast_group.ip = cpu_to_be32(IP_VS_SYNC_GROUP); + } + if (!c->mcast_port) + c->mcast_port = IP_VS_SYNC_PORT; + if (!c->mcast_ttl) + c->mcast_ttl = 1; + + dev = __dev_get_by_name(ipvs->net, c->mcast_ifn); + if (!dev) { + pr_err("Unknown mcast interface: %s\n", c->mcast_ifn); + result = -ENODEV; + goto out_early; + } + hlen = (AF_INET6 == c->mcast_af) ? + sizeof(struct ipv6hdr) + sizeof(struct udphdr) : + sizeof(struct iphdr) + sizeof(struct udphdr); + mtu = (state == IP_VS_STATE_BACKUP) ? + clamp(dev->mtu, 1500U, 65535U) : 1500U; + min_mtu = (state == IP_VS_STATE_BACKUP) ? 
1024 : 1; + + if (c->sync_maxlen) + c->sync_maxlen = clamp_t(unsigned int, + c->sync_maxlen, min_mtu, + 65535 - hlen); + else + c->sync_maxlen = mtu - hlen; + + if (state == IP_VS_STATE_MASTER) { + result = -EEXIST; + if (ipvs->ms) + goto out_early; + + ipvs->mcfg = *c; + name = "ipvs-m:%d:%d"; + threadfn = sync_thread_master; + } else if (state == IP_VS_STATE_BACKUP) { + result = -EEXIST; + if (ipvs->backup_tinfo) + goto out_early; + + ipvs->bcfg = *c; + name = "ipvs-b:%d:%d"; + threadfn = sync_thread_backup; + } else { + result = -EINVAL; + goto out_early; + } + + if (state == IP_VS_STATE_MASTER) { + struct ipvs_master_sync_state *ms; + + result = -ENOMEM; + ipvs->ms = kcalloc(count, sizeof(ipvs->ms[0]), GFP_KERNEL); + if (!ipvs->ms) + goto out; + ms = ipvs->ms; + for (id = 0; id < count; id++, ms++) { + INIT_LIST_HEAD(&ms->sync_queue); + ms->sync_queue_len = 0; + ms->sync_queue_delay = 0; + INIT_DELAYED_WORK(&ms->master_wakeup_work, + master_wakeup_work_handler); + ms->ipvs = ipvs; + } + } + result = -ENOMEM; + ti = kcalloc(count, sizeof(struct ip_vs_sync_thread_data), + GFP_KERNEL); + if (!ti) + goto out; + + for (id = 0; id < count; id++) { + tinfo = &ti[id]; + tinfo->ipvs = ipvs; + if (state == IP_VS_STATE_BACKUP) { + result = -ENOMEM; + tinfo->buf = kmalloc(ipvs->bcfg.sync_maxlen, + GFP_KERNEL); + if (!tinfo->buf) + goto out; + } + tinfo->id = id; + if (state == IP_VS_STATE_MASTER) + result = make_send_sock(ipvs, id, dev, &tinfo->sock); + else + result = make_receive_sock(ipvs, id, dev, &tinfo->sock); + if (result < 0) + goto out; + + task = kthread_run(threadfn, tinfo, name, ipvs->gen, id); + if (IS_ERR(task)) { + result = PTR_ERR(task); + goto out; + } + tinfo->task = task; + } + + /* mark as active */ + + if (state == IP_VS_STATE_MASTER) + ipvs->master_tinfo = ti; + else + ipvs->backup_tinfo = ti; + spin_lock_bh(&ipvs->sync_buff_lock); + ipvs->sync_state |= state; + spin_unlock_bh(&ipvs->sync_buff_lock); + + mutex_unlock(&ipvs->sync_mutex); + rtnl_unlock(); + + return 0; + +out: + /* We do not need RTNL lock anymore, release it here so that + * sock_release below can use rtnl_lock to leave the mcast group. + */ + rtnl_unlock(); + id = min(id, count - 1); + if (ti) { + for (tinfo = ti + id; tinfo >= ti; tinfo--) { + if (tinfo->task) + kthread_stop(tinfo->task); + } + } + if (!(ipvs->sync_state & IP_VS_STATE_MASTER)) { + kfree(ipvs->ms); + ipvs->ms = NULL; + } + mutex_unlock(&ipvs->sync_mutex); + + /* No more mutexes, release socks */ + if (ti) { + for (tinfo = ti + id; tinfo >= ti; tinfo--) { + if (tinfo->sock) + sock_release(tinfo->sock); + kfree(tinfo->buf); + } + kfree(ti); + } + + /* decrease the module use count */ + ip_vs_use_count_dec(); + return result; + +out_early: + mutex_unlock(&ipvs->sync_mutex); + rtnl_unlock(); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + return result; +} + + +int stop_sync_thread(struct netns_ipvs *ipvs, int state) +{ + struct ip_vs_sync_thread_data *ti, *tinfo; + int id; + int retc = -EINVAL; + + IP_VS_DBG(7, "%s(): pid %d\n", __func__, task_pid_nr(current)); + + mutex_lock(&ipvs->sync_mutex); + if (state == IP_VS_STATE_MASTER) { + retc = -ESRCH; + if (!ipvs->ms) + goto err; + ti = ipvs->master_tinfo; + + /* + * The lock synchronizes with sb_queue_tail(), so that we don't + * add sync buffers to the queue, when we are already in + * progress of stopping the master sync daemon. 
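+	 * In other words: once IP_VS_STATE_MASTER is cleared under both
+	 * locks, sb_queue_tail() only releases newly produced buffers
+	 * instead of queueing them, so the kthread_stop() calls below do
+	 * not race with late producers.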
+ */ + + spin_lock_bh(&ipvs->sync_buff_lock); + spin_lock(&ipvs->sync_lock); + ipvs->sync_state &= ~IP_VS_STATE_MASTER; + spin_unlock(&ipvs->sync_lock); + spin_unlock_bh(&ipvs->sync_buff_lock); + + retc = 0; + for (id = ipvs->threads_mask; id >= 0; id--) { + struct ipvs_master_sync_state *ms = &ipvs->ms[id]; + int ret; + + tinfo = &ti[id]; + pr_info("stopping master sync thread %d ...\n", + task_pid_nr(tinfo->task)); + cancel_delayed_work_sync(&ms->master_wakeup_work); + ret = kthread_stop(tinfo->task); + if (retc >= 0) + retc = ret; + } + kfree(ipvs->ms); + ipvs->ms = NULL; + ipvs->master_tinfo = NULL; + } else if (state == IP_VS_STATE_BACKUP) { + retc = -ESRCH; + if (!ipvs->backup_tinfo) + goto err; + ti = ipvs->backup_tinfo; + + ipvs->sync_state &= ~IP_VS_STATE_BACKUP; + retc = 0; + for (id = ipvs->threads_mask; id >= 0; id--) { + int ret; + + tinfo = &ti[id]; + pr_info("stopping backup sync thread %d ...\n", + task_pid_nr(tinfo->task)); + ret = kthread_stop(tinfo->task); + if (retc >= 0) + retc = ret; + } + ipvs->backup_tinfo = NULL; + } else { + goto err; + } + id = ipvs->threads_mask; + mutex_unlock(&ipvs->sync_mutex); + + /* No more mutexes, release socks */ + for (tinfo = ti + id; tinfo >= ti; tinfo--) { + if (tinfo->sock) + sock_release(tinfo->sock); + kfree(tinfo->buf); + } + kfree(ti); + + /* decrease the module use count */ + ip_vs_use_count_dec(); + return retc; + +err: + mutex_unlock(&ipvs->sync_mutex); + return retc; +} + +/* + * Initialize data struct for each netns + */ +int __net_init ip_vs_sync_net_init(struct netns_ipvs *ipvs) +{ + __mutex_init(&ipvs->sync_mutex, "ipvs->sync_mutex", &__ipvs_sync_key); + spin_lock_init(&ipvs->sync_lock); + spin_lock_init(&ipvs->sync_buff_lock); + return 0; +} + +void ip_vs_sync_net_cleanup(struct netns_ipvs *ipvs) +{ + int retc; + + retc = stop_sync_thread(ipvs, IP_VS_STATE_MASTER); + if (retc && retc != -ESRCH) + pr_err("Failed to stop Master Daemon\n"); + + retc = stop_sync_thread(ipvs, IP_VS_STATE_BACKUP); + if (retc && retc != -ESRCH) + pr_err("Failed to stop Backup Daemon\n"); +} diff --git a/net/netfilter/ipvs/ip_vs_twos.c b/net/netfilter/ipvs/ip_vs_twos.c new file mode 100644 index 000000000..f2579fc9c --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_twos.c @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* IPVS: Power of Twos Choice Scheduling module + * + * Authors: Darby Payne + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include + +#include + +/* Power of Twos Choice scheduling, algorithm originally described by + * Michael Mitzenmacher. + * + * Randomly picks two destinations and picks the one with the least + * amount of connections + * + * The algorithm calculates a few variables + * - total_weight = sum of all weights + * - rweight1 = random number between [0,total_weight] + * - rweight2 = random number between [0,total_weight] + * + * For each destination + * decrement rweight1 and rweight2 by the destination weight + * pick choice1 when rweight1 is <= 0 + * pick choice2 when rweight2 is <= 0 + * + * Return choice2 if choice2 has less connections than choice 1 normalized + * by weight + * + * References + * ---------- + * + * [Mitzenmacher 2016] + * The Power of Two Random Choices: A Survey of Techniques and Results + * Michael Mitzenmacher, Andrea W. 
Richa y, Ramesh Sitaraman + * http://www.eecs.harvard.edu/~michaelm/NEWWORK/postscripts/twosurvey.pdf + * + */ +static struct ip_vs_dest *ip_vs_twos_schedule(struct ip_vs_service *svc, + const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *choice1 = NULL, *choice2 = NULL; + int rweight1, rweight2, weight1 = -1, weight2 = -1, overhead1 = 0; + int overhead2, total_weight = 0, weight; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + /* Generate a random weight between [0,sum of all weights) */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD)) { + weight = atomic_read(&dest->weight); + if (weight > 0) { + total_weight += weight; + choice1 = dest; + } + } + } + + if (!choice1) { + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + } + + /* Add 1 to total_weight so that the random weights are inclusive + * from 0 to total_weight + */ + total_weight += 1; + rweight1 = prandom_u32_max(total_weight); + rweight2 = prandom_u32_max(total_weight); + + /* Pick two weighted servers */ + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + + weight = atomic_read(&dest->weight); + if (weight <= 0) + continue; + + rweight1 -= weight; + rweight2 -= weight; + + if (rweight1 <= 0 && weight1 == -1) { + choice1 = dest; + weight1 = weight; + overhead1 = ip_vs_dest_conn_overhead(dest); + } + + if (rweight2 <= 0 && weight2 == -1) { + choice2 = dest; + weight2 = weight; + overhead2 = ip_vs_dest_conn_overhead(dest); + } + + if (weight1 != -1 && weight2 != -1) + goto nextstage; + } + +nextstage: + if (choice2 && (weight2 * overhead1) > (weight1 * overhead2)) + choice1 = choice2; + + IP_VS_DBG_BUF(6, "twos: server %s:%u conns %d refcnt %d weight %d\n", + IP_VS_DBG_ADDR(choice1->af, &choice1->addr), + ntohs(choice1->port), atomic_read(&choice1->activeconns), + refcount_read(&choice1->refcnt), + atomic_read(&choice1->weight)); + + return choice1; +} + +static struct ip_vs_scheduler ip_vs_twos_scheduler = { + .name = "twos", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_twos_scheduler.n_list), + .schedule = ip_vs_twos_schedule, +}; + +static int __init ip_vs_twos_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_twos_scheduler); +} + +static void __exit ip_vs_twos_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_twos_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_twos_init); +module_exit(ip_vs_twos_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_wlc.c b/net/netfilter/ipvs/ip_vs_wlc.c new file mode 100644 index 000000000..09f584b56 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_wlc.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Weighted Least-Connection Scheduling module + * + * Authors: Wensong Zhang + * Peter Kese + * + * Changes: + * Wensong Zhang : changed the ip_vs_wlc_schedule to return dest + * Wensong Zhang : changed to use the inactconns in scheduling + * Wensong Zhang : changed some comestics things for debugging + * Wensong Zhang : changed for the d-linked destination list + * Wensong Zhang : added the ip_vs_wlc_update_svc + * Wensong Zhang : added any dest with weight=0 is quiesced + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include + +#include + +/* + * Weighted Least Connection scheduling + */ +static struct ip_vs_dest * +ip_vs_wlc_schedule(struct ip_vs_service *svc, 
const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *least; + int loh, doh; + + IP_VS_DBG(6, "ip_vs_wlc_schedule(): Scheduling...\n"); + + /* + * We calculate the load of each dest server as follows: + * (dest overhead) / dest->weight + * + * Remember -- no floats in kernel mode!!! + * The comparison of h1*w2 > h2*w1 is equivalent to that of + * h1/w1 > h2/w2 + * if every weight is larger than zero. + * + * The server with weight=0 is quiesced and will not receive any + * new connections. + */ + + list_for_each_entry_rcu(dest, &svc->destinations, n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD) && + atomic_read(&dest->weight) > 0) { + least = dest; + loh = ip_vs_dest_conn_overhead(least); + goto nextstage; + } + } + ip_vs_scheduler_err(svc, "no destination available"); + return NULL; + + /* + * Find the destination with the least load. + */ + nextstage: + list_for_each_entry_continue_rcu(dest, &svc->destinations, n_list) { + if (dest->flags & IP_VS_DEST_F_OVERLOAD) + continue; + doh = ip_vs_dest_conn_overhead(dest); + if ((__s64)loh * atomic_read(&dest->weight) > + (__s64)doh * atomic_read(&least->weight)) { + least = dest; + loh = doh; + } + } + + IP_VS_DBG_BUF(6, "WLC: server %s:%u " + "activeconns %d refcnt %d weight %d overhead %d\n", + IP_VS_DBG_ADDR(least->af, &least->addr), + ntohs(least->port), + atomic_read(&least->activeconns), + refcount_read(&least->refcnt), + atomic_read(&least->weight), loh); + + return least; +} + + +static struct ip_vs_scheduler ip_vs_wlc_scheduler = +{ + .name = "wlc", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_wlc_scheduler.n_list), + .schedule = ip_vs_wlc_schedule, +}; + + +static int __init ip_vs_wlc_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_wlc_scheduler); +} + +static void __exit ip_vs_wlc_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_wlc_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_wlc_init); +module_exit(ip_vs_wlc_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_wrr.c b/net/netfilter/ipvs/ip_vs_wrr.c new file mode 100644 index 000000000..1bc7a0789 --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_wrr.c @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IPVS: Weighted Round-Robin Scheduling module + * + * Authors: Wensong Zhang + * + * Changes: + * Wensong Zhang : changed the ip_vs_wrr_schedule to return dest + * Wensong Zhang : changed some comestics things for debugging + * Wensong Zhang : changed for the d-linked destination list + * Wensong Zhang : added the ip_vs_wrr_update_svc + * Julian Anastasov : fixed the bug of returning destination + * with weight 0 when all weights are zero + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include + +#include + +/* The WRR algorithm depends on some caclulations: + * - mw: maximum weight + * - di: weight step, greatest common divisor from all weights + * - cw: current required weight + * As result, all weights are in the [di..mw] range with a step=di. + * + * First, we start with cw = mw and select dests with weight >= cw. + * Then cw is reduced with di and all dests are checked again. + * Last pass should be with cw = di. We have mw/di passes in total: + * + * pass 1: cw = max weight + * pass 2: cw = max weight - di + * pass 3: cw = max weight - 2 * di + * ... 
+ * last pass: cw = di + * + * Weights are supposed to be >= di but we run in parallel with + * weight changes, it is possible some dest weight to be reduced + * below di, bad if it is the only available dest. + * + * So, we modify how mw is calculated, now it is reduced with (di - 1), + * so that last cw is 1 to catch such dests with weight below di: + * pass 1: cw = max weight - (di - 1) + * pass 2: cw = max weight - di - (di - 1) + * pass 3: cw = max weight - 2 * di - (di - 1) + * ... + * last pass: cw = 1 + * + */ + +/* + * current destination pointer for weighted round-robin scheduling + */ +struct ip_vs_wrr_mark { + struct ip_vs_dest *cl; /* current dest or head */ + int cw; /* current weight */ + int mw; /* maximum weight */ + int di; /* decreasing interval */ + struct rcu_head rcu_head; +}; + + +static int ip_vs_wrr_gcd_weight(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest; + int weight; + int g = 0; + + list_for_each_entry(dest, &svc->destinations, n_list) { + weight = atomic_read(&dest->weight); + if (weight > 0) { + if (g > 0) + g = gcd(weight, g); + else + g = weight; + } + } + return g ? g : 1; +} + + +/* + * Get the maximum weight of the service destinations. + */ +static int ip_vs_wrr_max_weight(struct ip_vs_service *svc) +{ + struct ip_vs_dest *dest; + int new_weight, weight = 0; + + list_for_each_entry(dest, &svc->destinations, n_list) { + new_weight = atomic_read(&dest->weight); + if (new_weight > weight) + weight = new_weight; + } + + return weight; +} + + +static int ip_vs_wrr_init_svc(struct ip_vs_service *svc) +{ + struct ip_vs_wrr_mark *mark; + + /* + * Allocate the mark variable for WRR scheduling + */ + mark = kmalloc(sizeof(struct ip_vs_wrr_mark), GFP_KERNEL); + if (mark == NULL) + return -ENOMEM; + + mark->cl = list_entry(&svc->destinations, struct ip_vs_dest, n_list); + mark->di = ip_vs_wrr_gcd_weight(svc); + mark->mw = ip_vs_wrr_max_weight(svc) - (mark->di - 1); + mark->cw = mark->mw; + svc->sched_data = mark; + + return 0; +} + + +static void ip_vs_wrr_done_svc(struct ip_vs_service *svc) +{ + struct ip_vs_wrr_mark *mark = svc->sched_data; + + /* + * Release the mark variable + */ + kfree_rcu(mark, rcu_head); +} + + +static int ip_vs_wrr_dest_changed(struct ip_vs_service *svc, + struct ip_vs_dest *dest) +{ + struct ip_vs_wrr_mark *mark = svc->sched_data; + + spin_lock_bh(&svc->sched_lock); + mark->cl = list_entry(&svc->destinations, struct ip_vs_dest, n_list); + mark->di = ip_vs_wrr_gcd_weight(svc); + mark->mw = ip_vs_wrr_max_weight(svc) - (mark->di - 1); + if (mark->cw > mark->mw || !mark->cw) + mark->cw = mark->mw; + else if (mark->di > 1) + mark->cw = (mark->cw / mark->di) * mark->di + 1; + spin_unlock_bh(&svc->sched_lock); + return 0; +} + + +/* + * Weighted Round-Robin Scheduling + */ +static struct ip_vs_dest * +ip_vs_wrr_schedule(struct ip_vs_service *svc, const struct sk_buff *skb, + struct ip_vs_iphdr *iph) +{ + struct ip_vs_dest *dest, *last, *stop = NULL; + struct ip_vs_wrr_mark *mark = svc->sched_data; + bool last_pass = false, restarted = false; + + IP_VS_DBG(6, "%s(): Scheduling...\n", __func__); + + spin_lock_bh(&svc->sched_lock); + dest = mark->cl; + /* No available dests? 
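+	 * (mark->mw == 0 means no destination currently has a weight
+	 * greater than zero, so there is nothing to schedule)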
*/ + if (mark->mw == 0) + goto err_noavail; + last = dest; + /* Stop only after all dests were checked for weight >= 1 (last pass) */ + while (1) { + list_for_each_entry_continue_rcu(dest, + &svc->destinations, + n_list) { + if (!(dest->flags & IP_VS_DEST_F_OVERLOAD) && + atomic_read(&dest->weight) >= mark->cw) + goto found; + if (dest == stop) + goto err_over; + } + mark->cw -= mark->di; + if (mark->cw <= 0) { + mark->cw = mark->mw; + /* Stop if we tried last pass from first dest: + * 1. last_pass: we started checks when cw > di but + * then all dests were checked for w >= 1 + * 2. last was head: the first and only traversal + * was for weight >= 1, for all dests. + */ + if (last_pass || + &last->n_list == &svc->destinations) + goto err_over; + restarted = true; + } + last_pass = mark->cw <= mark->di; + if (last_pass && restarted && + &last->n_list != &svc->destinations) { + /* First traversal was for w >= 1 but only + * for dests after 'last', now do the same + * for all dests up to 'last'. + */ + stop = last; + } + } + +found: + IP_VS_DBG_BUF(6, "WRR: server %s:%u " + "activeconns %d refcnt %d weight %d\n", + IP_VS_DBG_ADDR(dest->af, &dest->addr), ntohs(dest->port), + atomic_read(&dest->activeconns), + refcount_read(&dest->refcnt), + atomic_read(&dest->weight)); + mark->cl = dest; + + out: + spin_unlock_bh(&svc->sched_lock); + return dest; + +err_noavail: + mark->cl = dest; + dest = NULL; + ip_vs_scheduler_err(svc, "no destination available"); + goto out; + +err_over: + mark->cl = dest; + dest = NULL; + ip_vs_scheduler_err(svc, "no destination available: " + "all destinations are overloaded"); + goto out; +} + + +static struct ip_vs_scheduler ip_vs_wrr_scheduler = { + .name = "wrr", + .refcnt = ATOMIC_INIT(0), + .module = THIS_MODULE, + .n_list = LIST_HEAD_INIT(ip_vs_wrr_scheduler.n_list), + .init_service = ip_vs_wrr_init_svc, + .done_service = ip_vs_wrr_done_svc, + .add_dest = ip_vs_wrr_dest_changed, + .del_dest = ip_vs_wrr_dest_changed, + .upd_dest = ip_vs_wrr_dest_changed, + .schedule = ip_vs_wrr_schedule, +}; + +static int __init ip_vs_wrr_init(void) +{ + return register_ip_vs_scheduler(&ip_vs_wrr_scheduler) ; +} + +static void __exit ip_vs_wrr_cleanup(void) +{ + unregister_ip_vs_scheduler(&ip_vs_wrr_scheduler); + synchronize_rcu(); +} + +module_init(ip_vs_wrr_init); +module_exit(ip_vs_wrr_cleanup); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/ipvs/ip_vs_xmit.c b/net/netfilter/ipvs/ip_vs_xmit.c new file mode 100644 index 000000000..d40a4ca2b --- /dev/null +++ b/net/netfilter/ipvs/ip_vs_xmit.c @@ -0,0 +1,1686 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * ip_vs_xmit.c: various packet transmitters for IPVS + * + * Authors: Wensong Zhang + * Julian Anastasov + * + * Changes: + * + * Description of forwarding methods: + * - all transmitters are called from LOCAL_IN (remote clients) and + * LOCAL_OUT (local clients) but for ICMP can be called from FORWARD + * - not all connections have destination server, for example, + * connections in backup server when fwmark is used + * - bypass connections use daddr from packet + * - we can use dst without ref while sending in RCU section, we use + * ref when returning NF_ACCEPT for NAT-ed packet via loopback + * LOCAL_OUT rules: + * - skb->dev is NULL, skb->protocol is not set (both are set in POST_ROUTING) + * - skb->pkt_type is not set yet + * - the only place where we can see skb->sk != NULL + */ + +#define KMSG_COMPONENT "IPVS" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include /* for tcphdr */ 
+#include +#include +#include +#include /* for csum_tcpudp_magic */ +#include +#include /* for icmp_send */ +#include /* for ip_route_output */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +enum { + IP_VS_RT_MODE_LOCAL = 1, /* Allow local dest */ + IP_VS_RT_MODE_NON_LOCAL = 2, /* Allow non-local dest */ + IP_VS_RT_MODE_RDR = 4, /* Allow redirect from remote daddr to + * local + */ + IP_VS_RT_MODE_CONNECT = 8, /* Always bind route to saddr */ + IP_VS_RT_MODE_KNOWN_NH = 16,/* Route via remote addr */ + IP_VS_RT_MODE_TUNNEL = 32,/* Tunnel mode */ +}; + +static inline struct ip_vs_dest_dst *ip_vs_dest_dst_alloc(void) +{ + return kmalloc(sizeof(struct ip_vs_dest_dst), GFP_ATOMIC); +} + +static inline void ip_vs_dest_dst_free(struct ip_vs_dest_dst *dest_dst) +{ + kfree(dest_dst); +} + +/* + * Destination cache to speed up outgoing route lookup + */ +static inline void +__ip_vs_dst_set(struct ip_vs_dest *dest, struct ip_vs_dest_dst *dest_dst, + struct dst_entry *dst, u32 dst_cookie) +{ + struct ip_vs_dest_dst *old; + + old = rcu_dereference_protected(dest->dest_dst, + lockdep_is_held(&dest->dst_lock)); + + if (dest_dst) { + dest_dst->dst_cache = dst; + dest_dst->dst_cookie = dst_cookie; + } + rcu_assign_pointer(dest->dest_dst, dest_dst); + + if (old) + call_rcu(&old->rcu_head, ip_vs_dest_dst_rcu_free); +} + +static inline struct ip_vs_dest_dst * +__ip_vs_dst_check(struct ip_vs_dest *dest) +{ + struct ip_vs_dest_dst *dest_dst = rcu_dereference(dest->dest_dst); + struct dst_entry *dst; + + if (!dest_dst) + return NULL; + dst = dest_dst->dst_cache; + if (dst->obsolete && + dst->ops->check(dst, dest_dst->dst_cookie) == NULL) + return NULL; + return dest_dst; +} + +static inline bool +__mtu_check_toobig_v6(const struct sk_buff *skb, u32 mtu) +{ + if (IP6CB(skb)->frag_max_size) { + /* frag_max_size tell us that, this packet have been + * defragmented by netfilter IPv6 conntrack module. + */ + if (IP6CB(skb)->frag_max_size > mtu) + return true; /* largest fragment violate MTU */ + } + else if (skb->len > mtu && !skb_is_gso(skb)) { + return true; /* Packet size violate MTU size */ + } + return false; +} + +/* Get route to daddr, update *saddr, optionally bind route to saddr */ +static struct rtable *do_output_route4(struct net *net, __be32 daddr, + int rt_mode, __be32 *saddr) +{ + struct flowi4 fl4; + struct rtable *rt; + bool loop = false; + + memset(&fl4, 0, sizeof(fl4)); + fl4.daddr = daddr; + fl4.flowi4_flags = (rt_mode & IP_VS_RT_MODE_KNOWN_NH) ? + FLOWI_FLAG_KNOWN_NH : 0; + +retry: + rt = ip_route_output_key(net, &fl4); + if (IS_ERR(rt)) { + /* Invalid saddr ? 
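+ * A source address bound earlier (IP_VS_RT_MODE_CONNECT) may have
+ * become unusable, so retry the lookup with saddr cleared.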
*/ + if (PTR_ERR(rt) == -EINVAL && *saddr && + rt_mode & IP_VS_RT_MODE_CONNECT && !loop) { + *saddr = 0; + flowi4_update_output(&fl4, 0, 0, daddr, 0); + goto retry; + } + IP_VS_DBG_RL("ip_route_output error, dest: %pI4\n", &daddr); + return NULL; + } else if (!*saddr && rt_mode & IP_VS_RT_MODE_CONNECT && fl4.saddr) { + ip_rt_put(rt); + *saddr = fl4.saddr; + flowi4_update_output(&fl4, 0, 0, daddr, fl4.saddr); + loop = true; + goto retry; + } + *saddr = fl4.saddr; + return rt; +} + +#ifdef CONFIG_IP_VS_IPV6 +static inline int __ip_vs_is_local_route6(struct rt6_info *rt) +{ + return rt->dst.dev && rt->dst.dev->flags & IFF_LOOPBACK; +} +#endif + +static inline bool crosses_local_route_boundary(int skb_af, struct sk_buff *skb, + int rt_mode, + bool new_rt_is_local) +{ + bool rt_mode_allow_local = !!(rt_mode & IP_VS_RT_MODE_LOCAL); + bool rt_mode_allow_non_local = !!(rt_mode & IP_VS_RT_MODE_NON_LOCAL); + bool rt_mode_allow_redirect = !!(rt_mode & IP_VS_RT_MODE_RDR); + bool source_is_loopback; + bool old_rt_is_local; + +#ifdef CONFIG_IP_VS_IPV6 + if (skb_af == AF_INET6) { + int addr_type = ipv6_addr_type(&ipv6_hdr(skb)->saddr); + + source_is_loopback = + (!skb->dev || skb->dev->flags & IFF_LOOPBACK) && + (addr_type & IPV6_ADDR_LOOPBACK); + old_rt_is_local = __ip_vs_is_local_route6( + (struct rt6_info *)skb_dst(skb)); + } else +#endif + { + source_is_loopback = ipv4_is_loopback(ip_hdr(skb)->saddr); + old_rt_is_local = skb_rtable(skb)->rt_flags & RTCF_LOCAL; + } + + if (unlikely(new_rt_is_local)) { + if (!rt_mode_allow_local) + return true; + if (!rt_mode_allow_redirect && !old_rt_is_local) + return true; + } else { + if (!rt_mode_allow_non_local) + return true; + if (source_is_loopback) + return true; + } + return false; +} + +static inline void maybe_update_pmtu(int skb_af, struct sk_buff *skb, int mtu) +{ + struct sock *sk = skb->sk; + struct rtable *ort = skb_rtable(skb); + + if (!skb->dev && sk && sk_fullsock(sk)) + ort->dst.ops->update_pmtu(&ort->dst, sk, NULL, mtu, true); +} + +static inline bool ensure_mtu_is_adequate(struct netns_ipvs *ipvs, int skb_af, + int rt_mode, + struct ip_vs_iphdr *ipvsh, + struct sk_buff *skb, int mtu) +{ +#ifdef CONFIG_IP_VS_IPV6 + if (skb_af == AF_INET6) { + struct net *net = ipvs->net; + + if (unlikely(__mtu_check_toobig_v6(skb, mtu))) { + if (!skb->dev) + skb->dev = net->loopback_dev; + /* only send ICMP too big on first fragment */ + if (!ipvsh->fragoffs && !ip_vs_iph_icmp(ipvsh)) + icmpv6_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + IP_VS_DBG(1, "frag needed for %pI6c\n", + &ipv6_hdr(skb)->saddr); + return false; + } + } else +#endif + { + /* If we're going to tunnel the packet and pmtu discovery + * is disabled, we'll just fragment it anyway + */ + if ((rt_mode & IP_VS_RT_MODE_TUNNEL) && !sysctl_pmtu_disc(ipvs)) + return true; + + if (unlikely(ip_hdr(skb)->frag_off & htons(IP_DF) && + skb->len > mtu && !skb_is_gso(skb) && + !ip_vs_iph_icmp(ipvsh))) { + icmp_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + IP_VS_DBG(1, "frag needed for %pI4\n", + &ip_hdr(skb)->saddr); + return false; + } + } + + return true; +} + +static inline bool decrement_ttl(struct netns_ipvs *ipvs, + int skb_af, + struct sk_buff *skb) +{ + struct net *net = ipvs->net; + +#ifdef CONFIG_IP_VS_IPV6 + if (skb_af == AF_INET6) { + struct dst_entry *dst = skb_dst(skb); + + /* check and decrement ttl */ + if (ipv6_hdr(skb)->hop_limit <= 1) { + struct inet6_dev *idev = __in6_dev_get_safely(skb->dev); + + /* Force OUTPUT device used as source address */ + skb->dev = dst->dev; + 
icmpv6_send(skb, ICMPV6_TIME_EXCEED, + ICMPV6_EXC_HOPLIMIT, 0); + IP6_INC_STATS(net, idev, IPSTATS_MIB_INHDRERRORS); + + return false; + } + + /* don't propagate ttl change to cloned packets */ + if (skb_ensure_writable(skb, sizeof(struct ipv6hdr))) + return false; + + ipv6_hdr(skb)->hop_limit--; + } else +#endif + { + if (ip_hdr(skb)->ttl <= 1) { + /* Tell the sender its packet died... */ + IP_INC_STATS(net, IPSTATS_MIB_INHDRERRORS); + icmp_send(skb, ICMP_TIME_EXCEEDED, ICMP_EXC_TTL, 0); + return false; + } + + /* don't propagate ttl change to cloned packets */ + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + return false; + + /* Decrease ttl */ + ip_decrease_ttl(ip_hdr(skb)); + } + + return true; +} + +/* Get route to destination or remote server */ +static int +__ip_vs_get_out_rt(struct netns_ipvs *ipvs, int skb_af, struct sk_buff *skb, + struct ip_vs_dest *dest, + __be32 daddr, int rt_mode, __be32 *ret_saddr, + struct ip_vs_iphdr *ipvsh) +{ + struct net *net = ipvs->net; + struct ip_vs_dest_dst *dest_dst; + struct rtable *rt; /* Route to the other host */ + int mtu; + int local, noref = 1; + + if (dest) { + dest_dst = __ip_vs_dst_check(dest); + if (likely(dest_dst)) + rt = (struct rtable *) dest_dst->dst_cache; + else { + dest_dst = ip_vs_dest_dst_alloc(); + spin_lock_bh(&dest->dst_lock); + if (!dest_dst) { + __ip_vs_dst_set(dest, NULL, NULL, 0); + spin_unlock_bh(&dest->dst_lock); + goto err_unreach; + } + rt = do_output_route4(net, dest->addr.ip, rt_mode, + &dest_dst->dst_saddr.ip); + if (!rt) { + __ip_vs_dst_set(dest, NULL, NULL, 0); + spin_unlock_bh(&dest->dst_lock); + ip_vs_dest_dst_free(dest_dst); + goto err_unreach; + } + __ip_vs_dst_set(dest, dest_dst, &rt->dst, 0); + spin_unlock_bh(&dest->dst_lock); + IP_VS_DBG(10, "new dst %pI4, src %pI4, refcnt=%d\n", + &dest->addr.ip, &dest_dst->dst_saddr.ip, + atomic_read(&rt->dst.__refcnt)); + } + if (ret_saddr) + *ret_saddr = dest_dst->dst_saddr.ip; + } else { + __be32 saddr = htonl(INADDR_ANY); + + noref = 0; + + /* For such unconfigured boxes avoid many route lookups + * for performance reasons because we do not remember saddr + */ + rt_mode &= ~IP_VS_RT_MODE_CONNECT; + rt = do_output_route4(net, daddr, rt_mode, &saddr); + if (!rt) + goto err_unreach; + if (ret_saddr) + *ret_saddr = saddr; + } + + local = (rt->rt_flags & RTCF_LOCAL) ? 
1 : 0; + if (unlikely(crosses_local_route_boundary(skb_af, skb, rt_mode, + local))) { + IP_VS_DBG_RL("We are crossing local and non-local addresses" + " daddr=%pI4\n", &daddr); + goto err_put; + } + + if (unlikely(local)) { + /* skb to local stack, preserve old route */ + if (!noref) + ip_rt_put(rt); + return local; + } + + if (!decrement_ttl(ipvs, skb_af, skb)) + goto err_put; + + if (likely(!(rt_mode & IP_VS_RT_MODE_TUNNEL))) { + mtu = dst_mtu(&rt->dst); + } else { + mtu = dst_mtu(&rt->dst) - sizeof(struct iphdr); + if (!dest) + goto err_put; + if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + mtu -= sizeof(struct udphdr) + sizeof(struct guehdr); + if ((dest->tun_flags & + IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) + mtu -= GUE_PLEN_REMCSUM + GUE_LEN_PRIV; + } else if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + __be16 tflags = 0; + + if (dest->tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + tflags |= TUNNEL_CSUM; + mtu -= gre_calc_hlen(tflags); + } + if (mtu < 68) { + IP_VS_DBG_RL("%s(): mtu less than 68\n", __func__); + goto err_put; + } + maybe_update_pmtu(skb_af, skb, mtu); + } + + if (!ensure_mtu_is_adequate(ipvs, skb_af, rt_mode, ipvsh, skb, mtu)) + goto err_put; + + skb_dst_drop(skb); + if (noref) + skb_dst_set_noref(skb, &rt->dst); + else + skb_dst_set(skb, &rt->dst); + + return local; + +err_put: + if (!noref) + ip_rt_put(rt); + return -1; + +err_unreach: + dst_link_failure(skb); + return -1; +} + +#ifdef CONFIG_IP_VS_IPV6 +static struct dst_entry * +__ip_vs_route_output_v6(struct net *net, struct in6_addr *daddr, + struct in6_addr *ret_saddr, int do_xfrm, int rt_mode) +{ + struct dst_entry *dst; + struct flowi6 fl6 = { + .daddr = *daddr, + }; + + if (rt_mode & IP_VS_RT_MODE_KNOWN_NH) + fl6.flowi6_flags = FLOWI_FLAG_KNOWN_NH; + + dst = ip6_route_output(net, NULL, &fl6); + if (dst->error) + goto out_err; + if (!ret_saddr) + return dst; + if (ipv6_addr_any(&fl6.saddr) && + ipv6_dev_get_saddr(net, ip6_dst_idev(dst)->dev, + &fl6.daddr, 0, &fl6.saddr) < 0) + goto out_err; + if (do_xfrm) { + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), NULL, 0); + if (IS_ERR(dst)) { + dst = NULL; + goto out_err; + } + } + *ret_saddr = fl6.saddr; + return dst; + +out_err: + dst_release(dst); + IP_VS_DBG_RL("ip6_route_output error, dest: %pI6\n", daddr); + return NULL; +} + +/* + * Get route to destination or remote server + */ +static int +__ip_vs_get_out_rt_v6(struct netns_ipvs *ipvs, int skb_af, struct sk_buff *skb, + struct ip_vs_dest *dest, + struct in6_addr *daddr, struct in6_addr *ret_saddr, + struct ip_vs_iphdr *ipvsh, int do_xfrm, int rt_mode) +{ + struct net *net = ipvs->net; + struct ip_vs_dest_dst *dest_dst; + struct rt6_info *rt; /* Route to the other host */ + struct dst_entry *dst; + int mtu; + int local, noref = 1; + + if (dest) { + dest_dst = __ip_vs_dst_check(dest); + if (likely(dest_dst)) + rt = (struct rt6_info *) dest_dst->dst_cache; + else { + u32 cookie; + + dest_dst = ip_vs_dest_dst_alloc(); + spin_lock_bh(&dest->dst_lock); + if (!dest_dst) { + __ip_vs_dst_set(dest, NULL, NULL, 0); + spin_unlock_bh(&dest->dst_lock); + goto err_unreach; + } + dst = __ip_vs_route_output_v6(net, &dest->addr.in6, + &dest_dst->dst_saddr.in6, + do_xfrm, rt_mode); + if (!dst) { + __ip_vs_dst_set(dest, NULL, NULL, 0); + spin_unlock_bh(&dest->dst_lock); + ip_vs_dest_dst_free(dest_dst); + goto err_unreach; + } + rt = (struct rt6_info *) dst; + cookie = rt6_get_cookie(rt); + __ip_vs_dst_set(dest, dest_dst, &rt->dst, cookie); + spin_unlock_bh(&dest->dst_lock); + 
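+ /* the saved rt6 cookie lets __ip_vs_dst_check() detect a stale cached route */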
IP_VS_DBG(10, "new dst %pI6, src %pI6, refcnt=%d\n", + &dest->addr.in6, &dest_dst->dst_saddr.in6, + atomic_read(&rt->dst.__refcnt)); + } + if (ret_saddr) + *ret_saddr = dest_dst->dst_saddr.in6; + } else { + noref = 0; + dst = __ip_vs_route_output_v6(net, daddr, ret_saddr, do_xfrm, + rt_mode); + if (!dst) + goto err_unreach; + rt = (struct rt6_info *) dst; + } + + local = __ip_vs_is_local_route6(rt); + + if (unlikely(crosses_local_route_boundary(skb_af, skb, rt_mode, + local))) { + IP_VS_DBG_RL("We are crossing local and non-local addresses" + " daddr=%pI6\n", daddr); + goto err_put; + } + + if (unlikely(local)) { + /* skb to local stack, preserve old route */ + if (!noref) + dst_release(&rt->dst); + return local; + } + + if (!decrement_ttl(ipvs, skb_af, skb)) + goto err_put; + + /* MTU checking */ + if (likely(!(rt_mode & IP_VS_RT_MODE_TUNNEL))) + mtu = dst_mtu(&rt->dst); + else { + mtu = dst_mtu(&rt->dst) - sizeof(struct ipv6hdr); + if (!dest) + goto err_put; + if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + mtu -= sizeof(struct udphdr) + sizeof(struct guehdr); + if ((dest->tun_flags & + IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) + mtu -= GUE_PLEN_REMCSUM + GUE_LEN_PRIV; + } else if (dest->tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + __be16 tflags = 0; + + if (dest->tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + tflags |= TUNNEL_CSUM; + mtu -= gre_calc_hlen(tflags); + } + if (mtu < IPV6_MIN_MTU) { + IP_VS_DBG_RL("%s(): mtu less than %d\n", __func__, + IPV6_MIN_MTU); + goto err_put; + } + maybe_update_pmtu(skb_af, skb, mtu); + } + + if (!ensure_mtu_is_adequate(ipvs, skb_af, rt_mode, ipvsh, skb, mtu)) + goto err_put; + + skb_dst_drop(skb); + if (noref) + skb_dst_set_noref(skb, &rt->dst); + else + skb_dst_set(skb, &rt->dst); + + return local; + +err_put: + if (!noref) + dst_release(&rt->dst); + return -1; + +err_unreach: + /* The ip6_link_failure function requires the dev field to be set + * in order to get the net (further for the sake of fwmark + * reflection). + */ + if (!skb->dev) + skb->dev = skb_dst(skb)->dev; + + dst_link_failure(skb); + return -1; +} +#endif + + +/* return NF_ACCEPT to allow forwarding or other NF_xxx on error */ +static inline int ip_vs_tunnel_xmit_prepare(struct sk_buff *skb, + struct ip_vs_conn *cp) +{ + int ret = NF_ACCEPT; + + skb->ipvs_property = 1; + if (unlikely(cp->flags & IP_VS_CONN_F_NFCT)) + ret = ip_vs_confirm_conntrack(skb); + if (ret == NF_ACCEPT) { + nf_reset_ct(skb); + skb_forward_csum(skb); + if (skb->dev) + skb_clear_tstamp(skb); + } + return ret; +} + +/* In the event of a remote destination, it's possible that we would have + * matches against an old socket (particularly a TIME-WAIT socket). This + * causes havoc down the line (ip_local_out et. al. expect regular sockets + * and invalid memory accesses will happen) so simply drop the association + * in this case. +*/ +static inline void ip_vs_drop_early_demux_sk(struct sk_buff *skb) +{ + /* If dev is set, the packet came from the LOCAL_IN callback and + * not from a local TCP socket. 
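+ * Orphaning it below drops that stale early-demux socket reference.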
+ */ + if (skb->dev) + skb_orphan(skb); +} + +/* return NF_STOLEN (sent) or NF_ACCEPT if local=1 (not sent) */ +static inline int ip_vs_nat_send_or_cont(int pf, struct sk_buff *skb, + struct ip_vs_conn *cp, int local) +{ + int ret = NF_STOLEN; + + skb->ipvs_property = 1; + if (likely(!(cp->flags & IP_VS_CONN_F_NFCT))) + ip_vs_notrack(skb); + else + ip_vs_update_conntrack(skb, cp, 1); + + /* Remove the early_demux association unless it's bound for the + * exact same port and address on this host after translation. + */ + if (!local || cp->vport != cp->dport || + !ip_vs_addr_equal(cp->af, &cp->vaddr, &cp->daddr)) + ip_vs_drop_early_demux_sk(skb); + + if (!local) { + skb_forward_csum(skb); + if (skb->dev) + skb_clear_tstamp(skb); + NF_HOOK(pf, NF_INET_LOCAL_OUT, cp->ipvs->net, NULL, skb, + NULL, skb_dst(skb)->dev, dst_output); + } else + ret = NF_ACCEPT; + + return ret; +} + +/* return NF_STOLEN (sent) or NF_ACCEPT if local=1 (not sent) */ +static inline int ip_vs_send_or_cont(int pf, struct sk_buff *skb, + struct ip_vs_conn *cp, int local) +{ + int ret = NF_STOLEN; + + skb->ipvs_property = 1; + if (likely(!(cp->flags & IP_VS_CONN_F_NFCT))) + ip_vs_notrack(skb); + if (!local) { + ip_vs_drop_early_demux_sk(skb); + skb_forward_csum(skb); + if (skb->dev) + skb_clear_tstamp(skb); + NF_HOOK(pf, NF_INET_LOCAL_OUT, cp->ipvs->net, NULL, skb, + NULL, skb_dst(skb)->dev, dst_output); + } else + ret = NF_ACCEPT; + return ret; +} + + +/* + * NULL transmitter (do nothing except return NF_ACCEPT) + */ +int +ip_vs_null_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + /* we do not touch skb and do not need pskb ptr */ + return ip_vs_send_or_cont(NFPROTO_IPV4, skb, cp, 1); +} + + +/* + * Bypass transmitter + * Let packets bypass the destination when the destination is not + * available, it may be only used in transparent cache cluster. 
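+ * The packet is simply routed by its own destination address; no real
+ * server is involved.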
+ */ +int +ip_vs_bypass_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct iphdr *iph = ip_hdr(skb); + + EnterFunction(10); + + if (__ip_vs_get_out_rt(cp->ipvs, cp->af, skb, NULL, iph->daddr, + IP_VS_RT_MODE_NON_LOCAL, NULL, ipvsh) < 0) + goto tx_error; + + ip_send_check(iph); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ip_vs_send_or_cont(NFPROTO_IPV4, skb, cp, 0); + + LeaveFunction(10); + return NF_STOLEN; + + tx_error: + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} + +#ifdef CONFIG_IP_VS_IPV6 +int +ip_vs_bypass_xmit_v6(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct ipv6hdr *iph = ipv6_hdr(skb); + + EnterFunction(10); + + if (__ip_vs_get_out_rt_v6(cp->ipvs, cp->af, skb, NULL, + &iph->daddr, NULL, + ipvsh, 0, IP_VS_RT_MODE_NON_LOCAL) < 0) + goto tx_error; + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ip_vs_send_or_cont(NFPROTO_IPV6, skb, cp, 0); + + LeaveFunction(10); + return NF_STOLEN; + + tx_error: + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} +#endif + +/* + * NAT transmitter (only for outside-to-inside nat forwarding) + * Not used for related ICMP + */ +int +ip_vs_nat_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct rtable *rt; /* Route to the other host */ + int local, rc, was_input; + + EnterFunction(10); + + /* check if it is a connection of no-client-port */ + if (unlikely(cp->flags & IP_VS_CONN_F_NO_CPORT)) { + __be16 _pt, *p; + + p = skb_header_pointer(skb, ipvsh->len, sizeof(_pt), &_pt); + if (p == NULL) + goto tx_error; + ip_vs_conn_fill_cport(cp, *p); + IP_VS_DBG(10, "filled cport=%d\n", ntohs(*p)); + } + + was_input = rt_is_input_route(skb_rtable(skb)); + local = __ip_vs_get_out_rt(cp->ipvs, cp->af, skb, cp->dest, cp->daddr.ip, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_RDR, NULL, ipvsh); + if (local < 0) + goto tx_error; + rt = skb_rtable(skb); + /* + * Avoid duplicate tuple in reply direction for NAT traffic + * to local address when connection is sync-ed + */ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (cp->flags & IP_VS_CONN_F_SYNC && local) { + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct) { + IP_VS_DBG_RL_PKT(10, AF_INET, pp, skb, ipvsh->off, + "ip_vs_nat_xmit(): " + "stopping DNAT to local address"); + goto tx_error; + } + } +#endif + + /* From world but DNAT to loopback address? */ + if (local && ipv4_is_loopback(cp->daddr.ip) && was_input) { + IP_VS_DBG_RL_PKT(1, AF_INET, pp, skb, ipvsh->off, + "ip_vs_nat_xmit(): stopping DNAT to loopback " + "address"); + goto tx_error; + } + + /* copy-on-write the packet before mangling it */ + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + goto tx_error; + + if (skb_cow(skb, rt->dst.dev->hard_header_len)) + goto tx_error; + + /* mangle the packet */ + if (pp->dnat_handler && !pp->dnat_handler(skb, pp, cp, ipvsh)) + goto tx_error; + ip_hdr(skb)->daddr = cp->daddr.ip; + ip_send_check(ip_hdr(skb)); + + IP_VS_DBG_PKT(10, AF_INET, pp, skb, ipvsh->off, "After DNAT"); + + /* FIXME: when application helper enlarges the packet and the length + is larger than the MTU of outgoing device, there will be still + MTU problem. 
*/ + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + rc = ip_vs_nat_send_or_cont(NFPROTO_IPV4, skb, cp, local); + + LeaveFunction(10); + return rc; + + tx_error: + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} + +#ifdef CONFIG_IP_VS_IPV6 +int +ip_vs_nat_xmit_v6(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct rt6_info *rt; /* Route to the other host */ + int local, rc; + + EnterFunction(10); + + /* check if it is a connection of no-client-port */ + if (unlikely(cp->flags & IP_VS_CONN_F_NO_CPORT && !ipvsh->fragoffs)) { + __be16 _pt, *p; + p = skb_header_pointer(skb, ipvsh->len, sizeof(_pt), &_pt); + if (p == NULL) + goto tx_error; + ip_vs_conn_fill_cport(cp, *p); + IP_VS_DBG(10, "filled cport=%d\n", ntohs(*p)); + } + + local = __ip_vs_get_out_rt_v6(cp->ipvs, cp->af, skb, cp->dest, + &cp->daddr.in6, + NULL, ipvsh, 0, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_RDR); + if (local < 0) + goto tx_error; + rt = (struct rt6_info *) skb_dst(skb); + /* + * Avoid duplicate tuple in reply direction for NAT traffic + * to local address when connection is sync-ed + */ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (cp->flags & IP_VS_CONN_F_SYNC && local) { + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct) { + IP_VS_DBG_RL_PKT(10, AF_INET6, pp, skb, ipvsh->off, + "ip_vs_nat_xmit_v6(): " + "stopping DNAT to local address"); + goto tx_error; + } + } +#endif + + /* From world but DNAT to loopback address? */ + if (local && skb->dev && !(skb->dev->flags & IFF_LOOPBACK) && + ipv6_addr_type(&cp->daddr.in6) & IPV6_ADDR_LOOPBACK) { + IP_VS_DBG_RL_PKT(1, AF_INET6, pp, skb, ipvsh->off, + "ip_vs_nat_xmit_v6(): " + "stopping DNAT to loopback address"); + goto tx_error; + } + + /* copy-on-write the packet before mangling it */ + if (skb_ensure_writable(skb, sizeof(struct ipv6hdr))) + goto tx_error; + + if (skb_cow(skb, rt->dst.dev->hard_header_len)) + goto tx_error; + + /* mangle the packet */ + if (pp->dnat_handler && !pp->dnat_handler(skb, pp, cp, ipvsh)) + goto tx_error; + ipv6_hdr(skb)->daddr = cp->daddr.in6; + + IP_VS_DBG_PKT(10, AF_INET6, pp, skb, ipvsh->off, "After DNAT"); + + /* FIXME: when application helper enlarges the packet and the length + is larger than the MTU of outgoing device, there will be still + MTU problem. */ + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + rc = ip_vs_nat_send_or_cont(NFPROTO_IPV6, skb, cp, local); + + LeaveFunction(10); + return rc; + +tx_error: + LeaveFunction(10); + kfree_skb(skb); + return NF_STOLEN; +} +#endif + +/* When forwarding a packet, we must ensure that we've got enough headroom + * for the encapsulation packet in the skb. 
This also gives us an + * opportunity to figure out what the payload_len, dsfield, ttl, and df + * values should be, so that we won't need to look at the old ip header + * again + */ +static struct sk_buff * +ip_vs_prepare_tunneled_skb(struct sk_buff *skb, int skb_af, + unsigned int max_headroom, __u8 *next_protocol, + __u32 *payload_len, __u8 *dsfield, __u8 *ttl, + __be16 *df) +{ + struct sk_buff *new_skb = NULL; + struct iphdr *old_iph = NULL; + __u8 old_dsfield; +#ifdef CONFIG_IP_VS_IPV6 + struct ipv6hdr *old_ipv6h = NULL; +#endif + + ip_vs_drop_early_demux_sk(skb); + + if (skb_headroom(skb) < max_headroom || skb_cloned(skb)) { + new_skb = skb_realloc_headroom(skb, max_headroom); + if (!new_skb) + goto error; + if (skb->sk) + skb_set_owner_w(new_skb, skb->sk); + consume_skb(skb); + skb = new_skb; + } + +#ifdef CONFIG_IP_VS_IPV6 + if (skb_af == AF_INET6) { + old_ipv6h = ipv6_hdr(skb); + *next_protocol = IPPROTO_IPV6; + if (payload_len) + *payload_len = + ntohs(old_ipv6h->payload_len) + + sizeof(*old_ipv6h); + old_dsfield = ipv6_get_dsfield(old_ipv6h); + *ttl = old_ipv6h->hop_limit; + if (df) + *df = 0; + } else +#endif + { + old_iph = ip_hdr(skb); + /* Copy DF, reset fragment offset and MF */ + if (df) + *df = (old_iph->frag_off & htons(IP_DF)); + *next_protocol = IPPROTO_IPIP; + + /* fix old IP header checksum */ + ip_send_check(old_iph); + old_dsfield = ipv4_get_dsfield(old_iph); + *ttl = old_iph->ttl; + if (payload_len) + *payload_len = skb_ip_totlen(skb); + } + + /* Implement full-functionality option for ECN encapsulation */ + *dsfield = INET_ECN_encapsulate(old_dsfield, old_dsfield); + + return skb; +error: + kfree_skb(skb); + return ERR_PTR(-ENOMEM); +} + +static inline int __tun_gso_type_mask(int encaps_af, int orig_af) +{ + switch (encaps_af) { + case AF_INET: + return SKB_GSO_IPXIP4; + case AF_INET6: + return SKB_GSO_IPXIP6; + default: + return 0; + } +} + +static int +ipvs_gue_encap(struct net *net, struct sk_buff *skb, + struct ip_vs_conn *cp, __u8 *next_protocol) +{ + __be16 dport; + __be16 sport = udp_flow_src_port(net, skb, 0, 0, false); + struct udphdr *udph; /* Our new UDP header */ + struct guehdr *gueh; /* Our new GUE header */ + size_t hdrlen, optlen = 0; + void *data; + bool need_priv = false; + + if ((cp->dest->tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + optlen += GUE_PLEN_REMCSUM + GUE_LEN_PRIV; + need_priv = true; + } + + hdrlen = sizeof(struct guehdr) + optlen; + + skb_push(skb, hdrlen); + + gueh = (struct guehdr *)skb->data; + + gueh->control = 0; + gueh->version = 0; + gueh->hlen = optlen >> 2; + gueh->flags = 0; + gueh->proto_ctype = *next_protocol; + + data = &gueh[1]; + + if (need_priv) { + __be32 *flags = data; + u16 csum_start = skb_checksum_start_offset(skb); + __be16 *pd; + + gueh->flags |= GUE_FLAG_PRIV; + *flags = 0; + data += GUE_LEN_PRIV; + + if (csum_start < hdrlen) + return -EINVAL; + + csum_start -= hdrlen; + pd = data; + pd[0] = htons(csum_start); + pd[1] = htons(csum_start + skb->csum_offset); + + if (!skb_is_gso(skb)) { + skb->ip_summed = CHECKSUM_NONE; + skb->encapsulation = 0; + } + + *flags |= GUE_PFLAG_REMCSUM; + data += GUE_PLEN_REMCSUM; + } + + skb_push(skb, sizeof(struct udphdr)); + skb_reset_transport_header(skb); + + udph = udp_hdr(skb); + + dport = cp->dest->tun_port; + udph->dest = dport; + udph->source = sport; + udph->len = htons(skb->len); + udph->check = 0; + + *next_protocol = IPPROTO_UDP; + + return 0; +} + +static void +ipvs_gre_encap(struct net *net, struct sk_buff *skb, + struct 
ip_vs_conn *cp, __u8 *next_protocol) +{ + __be16 proto = *next_protocol == IPPROTO_IPIP ? + htons(ETH_P_IP) : htons(ETH_P_IPV6); + __be16 tflags = 0; + size_t hdrlen; + + if (cp->dest->tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + tflags |= TUNNEL_CSUM; + + hdrlen = gre_calc_hlen(tflags); + gre_build_header(skb, hdrlen, tflags, proto, 0, 0); + + *next_protocol = IPPROTO_GRE; +} + +/* + * IP Tunneling transmitter + * + * This function encapsulates the packet in a new IP packet, its + * destination will be set to cp->daddr. Most code of this function + * is taken from ipip.c. + * + * It is used in VS/TUN cluster. The load balancer selects a real + * server from a cluster based on a scheduling algorithm, + * encapsulates the request packet and forwards it to the selected + * server. For example, all real servers are configured with + * "ifconfig tunl0 up". When the server receives + * the encapsulated packet, it will decapsulate the packet, processe + * the request and return the response packets directly to the client + * without passing the load balancer. This can greatly increase the + * scalability of virtual server. + * + * Used for ANY protocol + */ +int +ip_vs_tunnel_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct netns_ipvs *ipvs = cp->ipvs; + struct net *net = ipvs->net; + struct rtable *rt; /* Route to the other host */ + __be32 saddr; /* Source for tunnel */ + struct net_device *tdev; /* Device to other host */ + __u8 next_protocol = 0; + __u8 dsfield = 0; + __u8 ttl = 0; + __be16 df = 0; + __be16 *dfp = NULL; + struct iphdr *iph; /* Our new IP header */ + unsigned int max_headroom; /* The extra header space needed */ + int ret, local; + int tun_type, gso_type; + int tun_flags; + + EnterFunction(10); + + local = __ip_vs_get_out_rt(ipvs, cp->af, skb, cp->dest, cp->daddr.ip, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_CONNECT | + IP_VS_RT_MODE_TUNNEL, &saddr, ipvsh); + if (local < 0) + goto tx_error; + if (local) + return ip_vs_send_or_cont(NFPROTO_IPV4, skb, cp, 1); + + rt = skb_rtable(skb); + tdev = rt->dst.dev; + + /* + * Okay, now see if we can stuff it in the buffer as-is. + */ + max_headroom = LL_RESERVED_SPACE(tdev) + sizeof(struct iphdr); + + tun_type = cp->dest->tun_type; + tun_flags = cp->dest->tun_flags; + + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + size_t gue_hdrlen, gue_optlen = 0; + + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + gue_optlen += GUE_PLEN_REMCSUM + GUE_LEN_PRIV; + } + gue_hdrlen = sizeof(struct guehdr) + gue_optlen; + + max_headroom += sizeof(struct udphdr) + gue_hdrlen; + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + size_t gre_hdrlen; + __be16 tflags = 0; + + if (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + tflags |= TUNNEL_CSUM; + gre_hdrlen = gre_calc_hlen(tflags); + + max_headroom += gre_hdrlen; + } + + /* We only care about the df field if sysctl_pmtu_disc(ipvs) is set */ + dfp = sysctl_pmtu_disc(ipvs) ? 
&df : NULL; + skb = ip_vs_prepare_tunneled_skb(skb, cp->af, max_headroom, + &next_protocol, NULL, &dsfield, + &ttl, dfp); + if (IS_ERR(skb)) + goto tx_error; + + gso_type = __tun_gso_type_mask(AF_INET, cp->af); + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) || + (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM)) + gso_type |= SKB_GSO_UDP_TUNNEL_CSUM; + else + gso_type |= SKB_GSO_UDP_TUNNEL; + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + gso_type |= SKB_GSO_TUNNEL_REMCSUM; + } + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + if (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + gso_type |= SKB_GSO_GRE_CSUM; + else + gso_type |= SKB_GSO_GRE; + } + + if (iptunnel_handle_offloads(skb, gso_type)) + goto tx_error; + + skb->transport_header = skb->network_header; + + skb_set_inner_ipproto(skb, next_protocol); + skb_set_inner_mac_header(skb, skb_inner_network_offset(skb)); + + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + bool check = false; + + if (ipvs_gue_encap(net, skb, cp, &next_protocol)) + goto tx_error; + + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) || + (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM)) + check = true; + + udp_set_csum(!check, skb, saddr, cp->daddr.ip, skb->len); + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) + ipvs_gre_encap(net, skb, cp, &next_protocol); + + skb_push(skb, sizeof(struct iphdr)); + skb_reset_network_header(skb); + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + + /* + * Push down and install the IPIP header. + */ + iph = ip_hdr(skb); + iph->version = 4; + iph->ihl = sizeof(struct iphdr)>>2; + iph->frag_off = df; + iph->protocol = next_protocol; + iph->tos = dsfield; + iph->daddr = cp->daddr.ip; + iph->saddr = saddr; + iph->ttl = ttl; + ip_select_ident(net, skb, NULL); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ret = ip_vs_tunnel_xmit_prepare(skb, cp); + if (ret == NF_ACCEPT) + ip_local_out(net, skb->sk, skb); + else if (ret == NF_DROP) + kfree_skb(skb); + + LeaveFunction(10); + + return NF_STOLEN; + + tx_error: + if (!IS_ERR(skb)) + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} + +#ifdef CONFIG_IP_VS_IPV6 +int +ip_vs_tunnel_xmit_v6(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + struct netns_ipvs *ipvs = cp->ipvs; + struct net *net = ipvs->net; + struct rt6_info *rt; /* Route to the other host */ + struct in6_addr saddr; /* Source for tunnel */ + struct net_device *tdev; /* Device to other host */ + __u8 next_protocol = 0; + __u32 payload_len = 0; + __u8 dsfield = 0; + __u8 ttl = 0; + struct ipv6hdr *iph; /* Our new IP header */ + unsigned int max_headroom; /* The extra header space needed */ + int ret, local; + int tun_type, gso_type; + int tun_flags; + + EnterFunction(10); + + local = __ip_vs_get_out_rt_v6(ipvs, cp->af, skb, cp->dest, + &cp->daddr.in6, + &saddr, ipvsh, 1, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_TUNNEL); + if (local < 0) + goto tx_error; + if (local) + return ip_vs_send_or_cont(NFPROTO_IPV6, skb, cp, 1); + + rt = (struct rt6_info *) skb_dst(skb); + tdev = rt->dst.dev; + + /* + * Okay, now see if we can stuff it in the buffer as-is. 
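+ * max_headroom must cover the link-layer header, the outer IPv6 header
+ * and any GUE/GRE encapsulation added below.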
+ */ + max_headroom = LL_RESERVED_SPACE(tdev) + sizeof(struct ipv6hdr); + + tun_type = cp->dest->tun_type; + tun_flags = cp->dest->tun_flags; + + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + size_t gue_hdrlen, gue_optlen = 0; + + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + gue_optlen += GUE_PLEN_REMCSUM + GUE_LEN_PRIV; + } + gue_hdrlen = sizeof(struct guehdr) + gue_optlen; + + max_headroom += sizeof(struct udphdr) + gue_hdrlen; + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + size_t gre_hdrlen; + __be16 tflags = 0; + + if (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + tflags |= TUNNEL_CSUM; + gre_hdrlen = gre_calc_hlen(tflags); + + max_headroom += gre_hdrlen; + } + + skb = ip_vs_prepare_tunneled_skb(skb, cp->af, max_headroom, + &next_protocol, &payload_len, + &dsfield, &ttl, NULL); + if (IS_ERR(skb)) + goto tx_error; + + gso_type = __tun_gso_type_mask(AF_INET6, cp->af); + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) || + (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM)) + gso_type |= SKB_GSO_UDP_TUNNEL_CSUM; + else + gso_type |= SKB_GSO_UDP_TUNNEL; + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM) && + skb->ip_summed == CHECKSUM_PARTIAL) { + gso_type |= SKB_GSO_TUNNEL_REMCSUM; + } + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) { + if (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) + gso_type |= SKB_GSO_GRE_CSUM; + else + gso_type |= SKB_GSO_GRE; + } + + if (iptunnel_handle_offloads(skb, gso_type)) + goto tx_error; + + skb->transport_header = skb->network_header; + + skb_set_inner_ipproto(skb, next_protocol); + skb_set_inner_mac_header(skb, skb_inner_network_offset(skb)); + + if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GUE) { + bool check = false; + + if (ipvs_gue_encap(net, skb, cp, &next_protocol)) + goto tx_error; + + if ((tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_CSUM) || + (tun_flags & IP_VS_TUNNEL_ENCAP_FLAG_REMCSUM)) + check = true; + + udp6_set_csum(!check, skb, &saddr, &cp->daddr.in6, skb->len); + } else if (tun_type == IP_VS_CONN_F_TUNNEL_TYPE_GRE) + ipvs_gre_encap(net, skb, cp, &next_protocol); + + skb_push(skb, sizeof(struct ipv6hdr)); + skb_reset_network_header(skb); + memset(&(IPCB(skb)->opt), 0, sizeof(IPCB(skb)->opt)); + + /* + * Push down and install the IPIP header. 
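+ * (for IPv6 the outer header built here is an IPv6 header carrying the
+ * tunnelled payload)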
+ */ + iph = ipv6_hdr(skb); + iph->version = 6; + iph->nexthdr = next_protocol; + iph->payload_len = htons(payload_len); + memset(&iph->flow_lbl, 0, sizeof(iph->flow_lbl)); + ipv6_change_dsfield(iph, 0, dsfield); + iph->daddr = cp->daddr.in6; + iph->saddr = saddr; + iph->hop_limit = ttl; + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ret = ip_vs_tunnel_xmit_prepare(skb, cp); + if (ret == NF_ACCEPT) + ip6_local_out(net, skb->sk, skb); + else if (ret == NF_DROP) + kfree_skb(skb); + + LeaveFunction(10); + + return NF_STOLEN; + +tx_error: + if (!IS_ERR(skb)) + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} +#endif + + +/* + * Direct Routing transmitter + * Used for ANY protocol + */ +int +ip_vs_dr_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + int local; + + EnterFunction(10); + + local = __ip_vs_get_out_rt(cp->ipvs, cp->af, skb, cp->dest, cp->daddr.ip, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_KNOWN_NH, NULL, ipvsh); + if (local < 0) + goto tx_error; + if (local) + return ip_vs_send_or_cont(NFPROTO_IPV4, skb, cp, 1); + + ip_send_check(ip_hdr(skb)); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ip_vs_send_or_cont(NFPROTO_IPV4, skb, cp, 0); + + LeaveFunction(10); + return NF_STOLEN; + + tx_error: + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} + +#ifdef CONFIG_IP_VS_IPV6 +int +ip_vs_dr_xmit_v6(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, struct ip_vs_iphdr *ipvsh) +{ + int local; + + EnterFunction(10); + + local = __ip_vs_get_out_rt_v6(cp->ipvs, cp->af, skb, cp->dest, + &cp->daddr.in6, + NULL, ipvsh, 0, + IP_VS_RT_MODE_LOCAL | + IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_KNOWN_NH); + if (local < 0) + goto tx_error; + if (local) + return ip_vs_send_or_cont(NFPROTO_IPV6, skb, cp, 1); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + ip_vs_send_or_cont(NFPROTO_IPV6, skb, cp, 0); + + LeaveFunction(10); + return NF_STOLEN; + +tx_error: + kfree_skb(skb); + LeaveFunction(10); + return NF_STOLEN; +} +#endif + + +/* + * ICMP packet transmitter + * called by the ip_vs_in_icmp + */ +int +ip_vs_icmp_xmit(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, int offset, unsigned int hooknum, + struct ip_vs_iphdr *iph) +{ + struct rtable *rt; /* Route to the other host */ + int rc; + int local; + int rt_mode, was_input; + + EnterFunction(10); + + /* The ICMP packet for VS/TUN, VS/DR and LOCALNODE will be + forwarded directly here, because there is no need to + translate address/port back */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) { + if (cp->packet_xmit) + rc = cp->packet_xmit(skb, cp, pp, iph); + else + rc = NF_ACCEPT; + /* do not touch skb anymore */ + atomic_inc(&cp->in_pkts); + goto out; + } + + /* + * mangle and send the packet here (only for VS/NAT) + */ + was_input = rt_is_input_route(skb_rtable(skb)); + + /* LOCALNODE from FORWARD hook is not supported */ + rt_mode = (hooknum != NF_INET_FORWARD) ? 
+ IP_VS_RT_MODE_LOCAL | IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_RDR : IP_VS_RT_MODE_NON_LOCAL; + local = __ip_vs_get_out_rt(cp->ipvs, cp->af, skb, cp->dest, cp->daddr.ip, rt_mode, + NULL, iph); + if (local < 0) + goto tx_error; + rt = skb_rtable(skb); + + /* + * Avoid duplicate tuple in reply direction for NAT traffic + * to local address when connection is sync-ed + */ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (cp->flags & IP_VS_CONN_F_SYNC && local) { + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct) { + IP_VS_DBG(10, "%s(): " + "stopping DNAT to local address %pI4\n", + __func__, &cp->daddr.ip); + goto tx_error; + } + } +#endif + + /* From world but DNAT to loopback address? */ + if (local && ipv4_is_loopback(cp->daddr.ip) && was_input) { + IP_VS_DBG(1, "%s(): " + "stopping DNAT to loopback %pI4\n", + __func__, &cp->daddr.ip); + goto tx_error; + } + + /* copy-on-write the packet before mangling it */ + if (skb_ensure_writable(skb, offset)) + goto tx_error; + + if (skb_cow(skb, rt->dst.dev->hard_header_len)) + goto tx_error; + + ip_vs_nat_icmp(skb, pp, cp, 0); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + rc = ip_vs_nat_send_or_cont(NFPROTO_IPV4, skb, cp, local); + goto out; + + tx_error: + kfree_skb(skb); + rc = NF_STOLEN; + out: + LeaveFunction(10); + return rc; +} + +#ifdef CONFIG_IP_VS_IPV6 +int +ip_vs_icmp_xmit_v6(struct sk_buff *skb, struct ip_vs_conn *cp, + struct ip_vs_protocol *pp, int offset, unsigned int hooknum, + struct ip_vs_iphdr *ipvsh) +{ + struct rt6_info *rt; /* Route to the other host */ + int rc; + int local; + int rt_mode; + + EnterFunction(10); + + /* The ICMP packet for VS/TUN, VS/DR and LOCALNODE will be + forwarded directly here, because there is no need to + translate address/port back */ + if (IP_VS_FWD_METHOD(cp) != IP_VS_CONN_F_MASQ) { + if (cp->packet_xmit) + rc = cp->packet_xmit(skb, cp, pp, ipvsh); + else + rc = NF_ACCEPT; + /* do not touch skb anymore */ + atomic_inc(&cp->in_pkts); + goto out; + } + + /* + * mangle and send the packet here (only for VS/NAT) + */ + + /* LOCALNODE from FORWARD hook is not supported */ + rt_mode = (hooknum != NF_INET_FORWARD) ? + IP_VS_RT_MODE_LOCAL | IP_VS_RT_MODE_NON_LOCAL | + IP_VS_RT_MODE_RDR : IP_VS_RT_MODE_NON_LOCAL; + local = __ip_vs_get_out_rt_v6(cp->ipvs, cp->af, skb, cp->dest, + &cp->daddr.in6, NULL, ipvsh, 0, rt_mode); + if (local < 0) + goto tx_error; + rt = (struct rt6_info *) skb_dst(skb); + /* + * Avoid duplicate tuple in reply direction for NAT traffic + * to local address when connection is sync-ed + */ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (cp->flags & IP_VS_CONN_F_SYNC && local) { + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct) { + IP_VS_DBG(10, "%s(): " + "stopping DNAT to local address %pI6\n", + __func__, &cp->daddr.in6); + goto tx_error; + } + } +#endif + + /* From world but DNAT to loopback address? 
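+ * Packets that arrived on a non-loopback device must not be DNATed to
+ * an IPv6 loopback destination.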
*/ + if (local && skb->dev && !(skb->dev->flags & IFF_LOOPBACK) && + ipv6_addr_type(&cp->daddr.in6) & IPV6_ADDR_LOOPBACK) { + IP_VS_DBG(1, "%s(): " + "stopping DNAT to loopback %pI6\n", + __func__, &cp->daddr.in6); + goto tx_error; + } + + /* copy-on-write the packet before mangling it */ + if (skb_ensure_writable(skb, offset)) + goto tx_error; + + if (skb_cow(skb, rt->dst.dev->hard_header_len)) + goto tx_error; + + ip_vs_nat_icmp_v6(skb, pp, cp, 0); + + /* Another hack: avoid icmp_send in ip_fragment */ + skb->ignore_df = 1; + + rc = ip_vs_nat_send_or_cont(NFPROTO_IPV6, skb, cp, local); + goto out; + +tx_error: + kfree_skb(skb); + rc = NF_STOLEN; +out: + LeaveFunction(10); + return rc; +} +#endif diff --git a/net/netfilter/nf_conncount.c b/net/netfilter/nf_conncount.c new file mode 100644 index 000000000..5d8ed6c90 --- /dev/null +++ b/net/netfilter/nf_conncount.c @@ -0,0 +1,636 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * count the number of connections matching an arbitrary key. + * + * (C) 2017 Red Hat GmbH + * Author: Florian Westphal + * + * split from xt_connlimit.c: + * (c) 2000 Gerd Knorr + * Nov 2002: Martin Bene : + * only ignore TIME_WAIT or gone connections + * (C) CC Computer Consultants GmbH, 2007 + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CONNCOUNT_SLOTS 256U + +#define CONNCOUNT_GC_MAX_NODES 8 +#define MAX_KEYLEN 5 + +/* we will save the tuples of all connections we care about */ +struct nf_conncount_tuple { + struct list_head node; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_zone zone; + int cpu; + u32 jiffies32; +}; + +struct nf_conncount_rb { + struct rb_node node; + struct nf_conncount_list list; + u32 key[MAX_KEYLEN]; + struct rcu_head rcu_head; +}; + +static spinlock_t nf_conncount_locks[CONNCOUNT_SLOTS] __cacheline_aligned_in_smp; + +struct nf_conncount_data { + unsigned int keylen; + struct rb_root root[CONNCOUNT_SLOTS]; + struct net *net; + struct work_struct gc_work; + unsigned long pending_trees[BITS_TO_LONGS(CONNCOUNT_SLOTS)]; + unsigned int gc_tree; +}; + +static u_int32_t conncount_rnd __read_mostly; +static struct kmem_cache *conncount_rb_cachep __read_mostly; +static struct kmem_cache *conncount_conn_cachep __read_mostly; + +static inline bool already_closed(const struct nf_conn *conn) +{ + if (nf_ct_protonum(conn) == IPPROTO_TCP) + return conn->proto.tcp.state == TCP_CONNTRACK_TIME_WAIT || + conn->proto.tcp.state == TCP_CONNTRACK_CLOSE; + else + return false; +} + +static int key_diff(const u32 *a, const u32 *b, unsigned int klen) +{ + return memcmp(a, b, klen * sizeof(u32)); +} + +static void conn_free(struct nf_conncount_list *list, + struct nf_conncount_tuple *conn) +{ + lockdep_assert_held(&list->list_lock); + + list->count--; + list_del(&conn->node); + + kmem_cache_free(conncount_conn_cachep, conn); +} + +static const struct nf_conntrack_tuple_hash * +find_or_evict(struct net *net, struct nf_conncount_list *list, + struct nf_conncount_tuple *conn) +{ + const struct nf_conntrack_tuple_hash *found; + unsigned long a, b; + int cpu = raw_smp_processor_id(); + u32 age; + + found = nf_conntrack_find_get(net, &conn->zone, &conn->tuple); + if (found) + return found; + b = conn->jiffies32; + a = (u32)jiffies; + + /* conn might have been added just before by another cpu and + * might still be unconfirmed. 
In this case, nf_conntrack_find() + * returns no result. Thus only evict if this cpu added the + * stale entry or if the entry is older than two jiffies. + */ + age = a - b; + if (conn->cpu == cpu || age >= 2) { + conn_free(list, conn); + return ERR_PTR(-ENOENT); + } + + return ERR_PTR(-EAGAIN); +} + +static int __nf_conncount_add(struct net *net, + struct nf_conncount_list *list, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone) +{ + const struct nf_conntrack_tuple_hash *found; + struct nf_conncount_tuple *conn, *conn_n; + struct nf_conn *found_ct; + unsigned int collect = 0; + + if (time_is_after_eq_jiffies((unsigned long)list->last_gc)) + goto add_new_node; + + /* check the saved connections */ + list_for_each_entry_safe(conn, conn_n, &list->head, node) { + if (collect > CONNCOUNT_GC_MAX_NODES) + break; + + found = find_or_evict(net, list, conn); + if (IS_ERR(found)) { + /* Not found, but might be about to be confirmed */ + if (PTR_ERR(found) == -EAGAIN) { + if (nf_ct_tuple_equal(&conn->tuple, tuple) && + nf_ct_zone_id(&conn->zone, conn->zone.dir) == + nf_ct_zone_id(zone, zone->dir)) + return 0; /* already exists */ + } else { + collect++; + } + continue; + } + + found_ct = nf_ct_tuplehash_to_ctrack(found); + + if (nf_ct_tuple_equal(&conn->tuple, tuple) && + nf_ct_zone_equal(found_ct, zone, zone->dir)) { + /* + * We should not see tuples twice unless someone hooks + * this into a table without "-p tcp --syn". + * + * Attempt to avoid a re-add in this case. + */ + nf_ct_put(found_ct); + return 0; + } else if (already_closed(found_ct)) { + /* + * we do not care about connections which are + * closed already -> ditch it + */ + nf_ct_put(found_ct); + conn_free(list, conn); + collect++; + continue; + } + + nf_ct_put(found_ct); + } + +add_new_node: + if (WARN_ON_ONCE(list->count > INT_MAX)) + return -EOVERFLOW; + + conn = kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC); + if (conn == NULL) + return -ENOMEM; + + conn->tuple = *tuple; + conn->zone = *zone; + conn->cpu = raw_smp_processor_id(); + conn->jiffies32 = (u32)jiffies; + list_add_tail(&conn->node, &list->head); + list->count++; + list->last_gc = (u32)jiffies; + return 0; +} + +int nf_conncount_add(struct net *net, + struct nf_conncount_list *list, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone) +{ + int ret; + + /* check the saved connections */ + spin_lock_bh(&list->list_lock); + ret = __nf_conncount_add(net, list, tuple, zone); + spin_unlock_bh(&list->list_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(nf_conncount_add); + +void nf_conncount_list_init(struct nf_conncount_list *list) +{ + spin_lock_init(&list->list_lock); + INIT_LIST_HEAD(&list->head); + list->count = 0; + list->last_gc = (u32)jiffies; +} +EXPORT_SYMBOL_GPL(nf_conncount_list_init); + +/* Return true if the list is empty. Must be called with BH disabled. 
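+ * Stale and already-closed conntrack entries are dropped from the list
+ * as a side effect.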
*/ +bool nf_conncount_gc_list(struct net *net, + struct nf_conncount_list *list) +{ + const struct nf_conntrack_tuple_hash *found; + struct nf_conncount_tuple *conn, *conn_n; + struct nf_conn *found_ct; + unsigned int collected = 0; + bool ret = false; + + /* don't bother if we just did GC */ + if (time_is_after_eq_jiffies((unsigned long)READ_ONCE(list->last_gc))) + return false; + + /* don't bother if other cpu is already doing GC */ + if (!spin_trylock(&list->list_lock)) + return false; + + list_for_each_entry_safe(conn, conn_n, &list->head, node) { + found = find_or_evict(net, list, conn); + if (IS_ERR(found)) { + if (PTR_ERR(found) == -ENOENT) + collected++; + continue; + } + + found_ct = nf_ct_tuplehash_to_ctrack(found); + if (already_closed(found_ct)) { + /* + * we do not care about connections which are + * closed already -> ditch it + */ + nf_ct_put(found_ct); + conn_free(list, conn); + collected++; + continue; + } + + nf_ct_put(found_ct); + if (collected > CONNCOUNT_GC_MAX_NODES) + break; + } + + if (!list->count) + ret = true; + list->last_gc = (u32)jiffies; + spin_unlock(&list->list_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(nf_conncount_gc_list); + +static void __tree_nodes_free(struct rcu_head *h) +{ + struct nf_conncount_rb *rbconn; + + rbconn = container_of(h, struct nf_conncount_rb, rcu_head); + kmem_cache_free(conncount_rb_cachep, rbconn); +} + +/* caller must hold tree nf_conncount_locks[] lock */ +static void tree_nodes_free(struct rb_root *root, + struct nf_conncount_rb *gc_nodes[], + unsigned int gc_count) +{ + struct nf_conncount_rb *rbconn; + + while (gc_count) { + rbconn = gc_nodes[--gc_count]; + spin_lock(&rbconn->list.list_lock); + if (!rbconn->list.count) { + rb_erase(&rbconn->node, root); + call_rcu(&rbconn->rcu_head, __tree_nodes_free); + } + spin_unlock(&rbconn->list.list_lock); + } +} + +static void schedule_gc_worker(struct nf_conncount_data *data, int tree) +{ + set_bit(tree, data->pending_trees); + schedule_work(&data->gc_work); +} + +static unsigned int +insert_tree(struct net *net, + struct nf_conncount_data *data, + struct rb_root *root, + unsigned int hash, + const u32 *key, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone) +{ + struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES]; + struct rb_node **rbnode, *parent; + struct nf_conncount_rb *rbconn; + struct nf_conncount_tuple *conn; + unsigned int count = 0, gc_count = 0; + u8 keylen = data->keylen; + bool do_gc = true; + + spin_lock_bh(&nf_conncount_locks[hash]); +restart: + parent = NULL; + rbnode = &(root->rb_node); + while (*rbnode) { + int diff; + rbconn = rb_entry(*rbnode, struct nf_conncount_rb, node); + + parent = *rbnode; + diff = key_diff(key, rbconn->key, keylen); + if (diff < 0) { + rbnode = &((*rbnode)->rb_left); + } else if (diff > 0) { + rbnode = &((*rbnode)->rb_right); + } else { + int ret; + + ret = nf_conncount_add(net, &rbconn->list, tuple, zone); + if (ret) + count = 0; /* hotdrop */ + else + count = rbconn->list.count; + tree_nodes_free(root, gc_nodes, gc_count); + goto out_unlock; + } + + if (gc_count >= ARRAY_SIZE(gc_nodes)) + continue; + + if (do_gc && nf_conncount_gc_list(net, &rbconn->list)) + gc_nodes[gc_count++] = rbconn; + } + + if (gc_count) { + tree_nodes_free(root, gc_nodes, gc_count); + schedule_gc_worker(data, hash); + gc_count = 0; + do_gc = false; + goto restart; + } + + /* expected case: match, insert new node */ + rbconn = kmem_cache_alloc(conncount_rb_cachep, GFP_ATOMIC); + if (rbconn == NULL) + goto out_unlock; + + conn = 
kmem_cache_alloc(conncount_conn_cachep, GFP_ATOMIC); + if (conn == NULL) { + kmem_cache_free(conncount_rb_cachep, rbconn); + goto out_unlock; + } + + conn->tuple = *tuple; + conn->zone = *zone; + memcpy(rbconn->key, key, sizeof(u32) * keylen); + + nf_conncount_list_init(&rbconn->list); + list_add(&conn->node, &rbconn->list.head); + count = 1; + rbconn->list.count = count; + + rb_link_node_rcu(&rbconn->node, parent, rbnode); + rb_insert_color(&rbconn->node, root); +out_unlock: + spin_unlock_bh(&nf_conncount_locks[hash]); + return count; +} + +static unsigned int +count_tree(struct net *net, + struct nf_conncount_data *data, + const u32 *key, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone) +{ + struct rb_root *root; + struct rb_node *parent; + struct nf_conncount_rb *rbconn; + unsigned int hash; + u8 keylen = data->keylen; + + hash = jhash2(key, data->keylen, conncount_rnd) % CONNCOUNT_SLOTS; + root = &data->root[hash]; + + parent = rcu_dereference_raw(root->rb_node); + while (parent) { + int diff; + + rbconn = rb_entry(parent, struct nf_conncount_rb, node); + + diff = key_diff(key, rbconn->key, keylen); + if (diff < 0) { + parent = rcu_dereference_raw(parent->rb_left); + } else if (diff > 0) { + parent = rcu_dereference_raw(parent->rb_right); + } else { + int ret; + + if (!tuple) { + nf_conncount_gc_list(net, &rbconn->list); + return rbconn->list.count; + } + + spin_lock_bh(&rbconn->list.list_lock); + /* Node might be about to be free'd. + * We need to defer to insert_tree() in this case. + */ + if (rbconn->list.count == 0) { + spin_unlock_bh(&rbconn->list.list_lock); + break; + } + + /* same source network -> be counted! */ + ret = __nf_conncount_add(net, &rbconn->list, tuple, zone); + spin_unlock_bh(&rbconn->list.list_lock); + if (ret) + return 0; /* hotdrop */ + else + return rbconn->list.count; + } + } + + if (!tuple) + return 0; + + return insert_tree(net, data, root, hash, key, tuple, zone); +} + +static void tree_gc_worker(struct work_struct *work) +{ + struct nf_conncount_data *data = container_of(work, struct nf_conncount_data, gc_work); + struct nf_conncount_rb *gc_nodes[CONNCOUNT_GC_MAX_NODES], *rbconn; + struct rb_root *root; + struct rb_node *node; + unsigned int tree, next_tree, gc_count = 0; + + tree = data->gc_tree % CONNCOUNT_SLOTS; + root = &data->root[tree]; + + local_bh_disable(); + rcu_read_lock(); + for (node = rb_first(root); node != NULL; node = rb_next(node)) { + rbconn = rb_entry(node, struct nf_conncount_rb, node); + if (nf_conncount_gc_list(data->net, &rbconn->list)) + gc_count++; + } + rcu_read_unlock(); + local_bh_enable(); + + cond_resched(); + + spin_lock_bh(&nf_conncount_locks[tree]); + if (gc_count < ARRAY_SIZE(gc_nodes)) + goto next; /* do not bother */ + + gc_count = 0; + node = rb_first(root); + while (node != NULL) { + rbconn = rb_entry(node, struct nf_conncount_rb, node); + node = rb_next(node); + + if (rbconn->list.count > 0) + continue; + + gc_nodes[gc_count++] = rbconn; + if (gc_count >= ARRAY_SIZE(gc_nodes)) { + tree_nodes_free(root, gc_nodes, gc_count); + gc_count = 0; + } + } + + tree_nodes_free(root, gc_nodes, gc_count); +next: + clear_bit(tree, data->pending_trees); + + next_tree = (tree + 1) % CONNCOUNT_SLOTS; + next_tree = find_next_bit(data->pending_trees, CONNCOUNT_SLOTS, next_tree); + + if (next_tree < CONNCOUNT_SLOTS) { + data->gc_tree = next_tree; + schedule_work(work); + } + + spin_unlock_bh(&nf_conncount_locks[tree]); +} + +/* Count and return number of conntrack entries in 'net' with particular 'key'. 
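+ * The key is the u32 array set up by the caller, typically a masked
+ * address or network.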
+ * If 'tuple' is not null, insert it into the accounting data structure. + * Call with RCU read lock. + */ +unsigned int nf_conncount_count(struct net *net, + struct nf_conncount_data *data, + const u32 *key, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone) +{ + return count_tree(net, data, key, tuple, zone); +} +EXPORT_SYMBOL_GPL(nf_conncount_count); + +struct nf_conncount_data *nf_conncount_init(struct net *net, unsigned int family, + unsigned int keylen) +{ + struct nf_conncount_data *data; + int ret, i; + + if (keylen % sizeof(u32) || + keylen / sizeof(u32) > MAX_KEYLEN || + keylen == 0) + return ERR_PTR(-EINVAL); + + net_get_random_once(&conncount_rnd, sizeof(conncount_rnd)); + + data = kmalloc(sizeof(*data), GFP_KERNEL); + if (!data) + return ERR_PTR(-ENOMEM); + + ret = nf_ct_netns_get(net, family); + if (ret < 0) { + kfree(data); + return ERR_PTR(ret); + } + + for (i = 0; i < ARRAY_SIZE(data->root); ++i) + data->root[i] = RB_ROOT; + + data->keylen = keylen / sizeof(u32); + data->net = net; + INIT_WORK(&data->gc_work, tree_gc_worker); + + return data; +} +EXPORT_SYMBOL_GPL(nf_conncount_init); + +void nf_conncount_cache_free(struct nf_conncount_list *list) +{ + struct nf_conncount_tuple *conn, *conn_n; + + list_for_each_entry_safe(conn, conn_n, &list->head, node) + kmem_cache_free(conncount_conn_cachep, conn); +} +EXPORT_SYMBOL_GPL(nf_conncount_cache_free); + +static void destroy_tree(struct rb_root *r) +{ + struct nf_conncount_rb *rbconn; + struct rb_node *node; + + while ((node = rb_first(r)) != NULL) { + rbconn = rb_entry(node, struct nf_conncount_rb, node); + + rb_erase(node, r); + + nf_conncount_cache_free(&rbconn->list); + + kmem_cache_free(conncount_rb_cachep, rbconn); + } +} + +void nf_conncount_destroy(struct net *net, unsigned int family, + struct nf_conncount_data *data) +{ + unsigned int i; + + cancel_work_sync(&data->gc_work); + nf_ct_netns_put(net, family); + + for (i = 0; i < ARRAY_SIZE(data->root); ++i) + destroy_tree(&data->root[i]); + + kfree(data); +} +EXPORT_SYMBOL_GPL(nf_conncount_destroy); + +static int __init nf_conncount_modinit(void) +{ + int i; + + for (i = 0; i < CONNCOUNT_SLOTS; ++i) + spin_lock_init(&nf_conncount_locks[i]); + + conncount_conn_cachep = kmem_cache_create("nf_conncount_tuple", + sizeof(struct nf_conncount_tuple), + 0, 0, NULL); + if (!conncount_conn_cachep) + return -ENOMEM; + + conncount_rb_cachep = kmem_cache_create("nf_conncount_rb", + sizeof(struct nf_conncount_rb), + 0, 0, NULL); + if (!conncount_rb_cachep) { + kmem_cache_destroy(conncount_conn_cachep); + return -ENOMEM; + } + + return 0; +} + +static void __exit nf_conncount_modexit(void) +{ + kmem_cache_destroy(conncount_conn_cachep); + kmem_cache_destroy(conncount_rb_cachep); +} + +module_init(nf_conncount_modinit); +module_exit(nf_conncount_modexit); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("netfilter: count number of connections matching a key"); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/nf_conntrack_acct.c b/net/netfilter/nf_conntrack_acct.c new file mode 100644 index 000000000..385a5f458 --- /dev/null +++ b/net/netfilter/nf_conntrack_acct.c @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Accounting handling for netfilter. 
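/* Editor's illustrative sketch (not part of this patch): a minimal,
 * hypothetical caller of the nf_conncount API above (nf_conncount_init(),
 * nf_conncount_count(), nf_conncount_destroy()), counting connections per
 * source /24.  The key layout, the example_* names and the IPv4-only setup
 * are assumptions made purely for illustration.
 */
#include <linux/err.h>
#include <linux/in.h>
#include <linux/netfilter.h>
#include <net/netfilter/nf_conncount.h>

static struct nf_conncount_data *example_data;

static int example_setup(struct net *net)
{
        /* keylen is in bytes and must be a non-zero multiple of sizeof(u32) */
        example_data = nf_conncount_init(net, NFPROTO_IPV4, sizeof(u32));
        return PTR_ERR_OR_ZERO(example_data);
}

static unsigned int example_count(struct net *net, __be32 saddr,
                                  const struct nf_conntrack_tuple *tuple,
                                  const struct nf_conntrack_zone *zone)
{
        u32 key = (__force u32)(saddr & htonl(0xffffff00));     /* /24 network */

        /* with a non-NULL tuple the connection is also inserted into the tree */
        return nf_conncount_count(net, example_data, &key, tuple, zone);
}

static void example_teardown(struct net *net)
{
        if (!IS_ERR_OR_NULL(example_data))
                nf_conncount_destroy(net, NFPROTO_IPV4, example_data);
}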
*/ + +/* + * (C) 2008 Krzysztof Piotr Oledzki + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include +#include +#include + +static bool nf_ct_acct __read_mostly; + +module_param_named(acct, nf_ct_acct, bool, 0644); +MODULE_PARM_DESC(acct, "Enable connection tracking flow accounting."); + +void nf_conntrack_acct_pernet_init(struct net *net) +{ + net->ct.sysctl_acct = nf_ct_acct; +} diff --git a/net/netfilter/nf_conntrack_amanda.c b/net/netfilter/nf_conntrack_amanda.c new file mode 100644 index 000000000..d011d2eb0 --- /dev/null +++ b/net/netfilter/nf_conntrack_amanda.c @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Amanda extension for IP connection tracking + * + * (C) 2002 by Brian J. Murrell + * based on HW's ip_conntrack_irc.c as well as other modules + * (C) 2006 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static unsigned int master_timeout __read_mostly = 300; +static char *ts_algo = "kmp"; + +#define HELPER_NAME "amanda" + +MODULE_AUTHOR("Brian J. Murrell "); +MODULE_DESCRIPTION("Amanda connection tracking module"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ip_conntrack_amanda"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +module_param(master_timeout, uint, 0600); +MODULE_PARM_DESC(master_timeout, "timeout for the master connection"); +module_param(ts_algo, charp, 0400); +MODULE_PARM_DESC(ts_algo, "textsearch algorithm to use (default kmp)"); + +unsigned int (*nf_nat_amanda_hook)(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp) + __read_mostly; +EXPORT_SYMBOL_GPL(nf_nat_amanda_hook); + +enum amanda_strings { + SEARCH_CONNECT, + SEARCH_NEWLINE, + SEARCH_DATA, + SEARCH_MESG, + SEARCH_INDEX, + SEARCH_STATE, +}; + +static struct { + const char *string; + size_t len; + struct ts_config *ts; +} search[] __read_mostly = { + [SEARCH_CONNECT] = { + .string = "CONNECT ", + .len = 8, + }, + [SEARCH_NEWLINE] = { + .string = "\n", + .len = 1, + }, + [SEARCH_DATA] = { + .string = "DATA ", + .len = 5, + }, + [SEARCH_MESG] = { + .string = "MESG ", + .len = 5, + }, + [SEARCH_INDEX] = { + .string = "INDEX ", + .len = 6, + }, + [SEARCH_STATE] = { + .string = "STATE ", + .len = 6, + }, +}; + +static int amanda_help(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple *tuple; + unsigned int dataoff, start, stop, off, i; + char pbuf[sizeof("65535")], *tmp; + u_int16_t len; + __be16 port; + int ret = NF_ACCEPT; + typeof(nf_nat_amanda_hook) nf_nat_amanda; + + /* Only look at packets from the Amanda server */ + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) + return NF_ACCEPT; + + /* increase the UDP timeout of the master connection as replies from + * Amanda clients to the server can be quite delayed */ + nf_ct_refresh(ct, skb, master_timeout * HZ); + + /* No data? 
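/* Editor's illustrative sketch (not part of this patch): the amanda helper
 * above leans on the kernel textsearch API ("kmp" by default) to scan the
 * UDP payload.  The prepare/search/destroy lifecycle is isolated below with
 * hypothetical example_* wrappers.
 */
#include <linux/err.h>
#include <linux/kernel.h>
#include <linux/skbuff.h>
#include <linux/textsearch.h>

static struct ts_config *example_ts;

static int example_ts_init(void)
{
        /* "kmp" matches the module's default ts_algo parameter */
        example_ts = textsearch_prepare("kmp", "CONNECT ", 8,
                                        GFP_KERNEL, TS_AUTOLOAD);
        return PTR_ERR_OR_ZERO(example_ts);
}

static bool example_has_connect(struct sk_buff *skb, unsigned int dataoff)
{
        /* skb_find_text() returns UINT_MAX when the pattern is absent */
        return skb_find_text(skb, dataoff, skb->len, example_ts) != UINT_MAX;
}

static void example_ts_fini(void)
{
        textsearch_destroy(example_ts);
}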
*/ + dataoff = protoff + sizeof(struct udphdr); + if (dataoff >= skb->len) { + net_err_ratelimited("amanda_help: skblen = %u\n", skb->len); + return NF_ACCEPT; + } + + start = skb_find_text(skb, dataoff, skb->len, + search[SEARCH_CONNECT].ts); + if (start == UINT_MAX) + goto out; + start += dataoff + search[SEARCH_CONNECT].len; + + stop = skb_find_text(skb, start, skb->len, + search[SEARCH_NEWLINE].ts); + if (stop == UINT_MAX) + goto out; + stop += start; + + for (i = SEARCH_DATA; i <= SEARCH_STATE; i++) { + off = skb_find_text(skb, start, stop, search[i].ts); + if (off == UINT_MAX) + continue; + off += start + search[i].len; + + len = min_t(unsigned int, sizeof(pbuf) - 1, stop - off); + if (skb_copy_bits(skb, off, pbuf, len)) + break; + pbuf[len] = '\0'; + + port = htons(simple_strtoul(pbuf, &tmp, 10)); + len = tmp - pbuf; + if (port == 0 || len > 5) + break; + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + nf_ct_helper_log(skb, ct, "cannot alloc expectation"); + ret = NF_DROP; + goto out; + } + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, + nf_ct_l3num(ct), + &tuple->src.u3, &tuple->dst.u3, + IPPROTO_TCP, NULL, &port); + + nf_nat_amanda = rcu_dereference(nf_nat_amanda_hook); + if (nf_nat_amanda && ct->status & IPS_NAT_MASK) + ret = nf_nat_amanda(skb, ctinfo, protoff, + off - dataoff, len, exp); + else if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, "cannot add expectation"); + ret = NF_DROP; + } + nf_ct_expect_put(exp); + } + +out: + return ret; +} + +static const struct nf_conntrack_expect_policy amanda_exp_policy = { + .max_expected = 4, + .timeout = 180, +}; + +static struct nf_conntrack_helper amanda_helper[2] __read_mostly = { + { + .name = HELPER_NAME, + .me = THIS_MODULE, + .help = amanda_help, + .tuple.src.l3num = AF_INET, + .tuple.src.u.udp.port = cpu_to_be16(10080), + .tuple.dst.protonum = IPPROTO_UDP, + .expect_policy = &amanda_exp_policy, + .nat_mod_name = NF_NAT_HELPER_NAME(HELPER_NAME), + }, + { + .name = "amanda", + .me = THIS_MODULE, + .help = amanda_help, + .tuple.src.l3num = AF_INET6, + .tuple.src.u.udp.port = cpu_to_be16(10080), + .tuple.dst.protonum = IPPROTO_UDP, + .expect_policy = &amanda_exp_policy, + .nat_mod_name = NF_NAT_HELPER_NAME(HELPER_NAME), + }, +}; + +static void __exit nf_conntrack_amanda_fini(void) +{ + int i; + + nf_conntrack_helpers_unregister(amanda_helper, + ARRAY_SIZE(amanda_helper)); + for (i = 0; i < ARRAY_SIZE(search); i++) + textsearch_destroy(search[i].ts); +} + +static int __init nf_conntrack_amanda_init(void) +{ + int ret, i; + + NF_CT_HELPER_BUILD_BUG_ON(0); + + for (i = 0; i < ARRAY_SIZE(search); i++) { + search[i].ts = textsearch_prepare(ts_algo, search[i].string, + search[i].len, + GFP_KERNEL, TS_AUTOLOAD); + if (IS_ERR(search[i].ts)) { + ret = PTR_ERR(search[i].ts); + goto err1; + } + } + ret = nf_conntrack_helpers_register(amanda_helper, + ARRAY_SIZE(amanda_helper)); + if (ret < 0) + goto err1; + return 0; + +err1: + while (--i >= 0) + textsearch_destroy(search[i].ts); + + return ret; +} + +module_init(nf_conntrack_amanda_init); +module_exit(nf_conntrack_amanda_fini); diff --git a/net/netfilter/nf_conntrack_bpf.c b/net/netfilter/nf_conntrack_bpf.c new file mode 100644 index 000000000..816283f0a --- /dev/null +++ b/net/netfilter/nf_conntrack_bpf.c @@ -0,0 +1,515 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Unstable Conntrack Helpers for XDP and TC-BPF hook + * + * These are called from the XDP and SCHED_CLS BPF programs. 
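/* Editor's illustrative sketch (not part of this patch): the core of
 * amanda_help() above is the expectation it registers for the follow-up TCP
 * connection.  Pulled out of the parsing loop, the pattern looks roughly
 * like this; the example_* name and error handling are illustrative.
 */
#include <linux/in.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_expect.h>

static int example_expect_data_conn(struct nf_conn *ct, __be16 port)
{
        struct nf_conntrack_tuple *tuple;
        struct nf_conntrack_expect *exp;
        int ret;

        exp = nf_ct_expect_alloc(ct);
        if (!exp)
                return -ENOMEM;

        /* expect one TCP connection to 'port' between the same two hosts */
        tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
        nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct),
                          &tuple->src.u3, &tuple->dst.u3,
                          IPPROTO_TCP, NULL, &port);

        ret = nf_ct_expect_related(exp, 0);
        nf_ct_expect_put(exp);          /* drop the allocation reference */
        return ret;
}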
Note that it is + * allowed to break compatibility for these functions since the interface they + * are exposed through to BPF programs is explicitly unstable. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* bpf_ct_opts - Options for CT lookup helpers + * + * Members: + * @netns_id - Specify the network namespace for lookup + * Values: + * BPF_F_CURRENT_NETNS (-1) + * Use namespace associated with ctx (xdp_md, __sk_buff) + * [0, S32_MAX] + * Network Namespace ID + * @error - Out parameter, set for any errors encountered + * Values: + * -EINVAL - Passed NULL for bpf_tuple pointer + * -EINVAL - opts->reserved is not 0 + * -EINVAL - netns_id is less than -1 + * -EINVAL - opts__sz isn't NF_BPF_CT_OPTS_SZ (12) + * -EPROTO - l4proto isn't one of IPPROTO_TCP or IPPROTO_UDP + * -ENONET - No network namespace found for netns_id + * -ENOENT - Conntrack lookup could not find entry for tuple + * -EAFNOSUPPORT - tuple__sz isn't one of sizeof(tuple->ipv4) + * or sizeof(tuple->ipv6) + * @l4proto - Layer 4 protocol + * Values: + * IPPROTO_TCP, IPPROTO_UDP + * @dir: - connection tracking tuple direction. + * @reserved - Reserved member, will be reused for more options in future + * Values: + * 0 + */ +struct bpf_ct_opts { + s32 netns_id; + s32 error; + u8 l4proto; + u8 dir; + u8 reserved[2]; +}; + +enum { + NF_BPF_CT_OPTS_SZ = 12, +}; + +static int bpf_nf_ct_tuple_parse(struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, u8 protonum, u8 dir, + struct nf_conntrack_tuple *tuple) +{ + union nf_inet_addr *src = dir ? &tuple->dst.u3 : &tuple->src.u3; + union nf_inet_addr *dst = dir ? &tuple->src.u3 : &tuple->dst.u3; + union nf_conntrack_man_proto *sport = dir ? (void *)&tuple->dst.u + : &tuple->src.u; + union nf_conntrack_man_proto *dport = dir ? 
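/* Editor's illustrative sketch (not part of this patch): the BPF-program
 * side of the options structure documented above.  A local mirror of struct
 * bpf_ct_opts is passed by value together with its size.  The header names,
 * addresses and ports are assumptions for illustration only.
 */
#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>

/* mirrors the kernel's struct bpf_ct_opts above (NF_BPF_CT_OPTS_SZ == 12) */
struct bpf_ct_opts___local {
        __s32 netns_id;
        __s32 error;
        __u8 l4proto;
        __u8 dir;
        __u8 reserved[2];
};

extern struct nf_conn *
bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple,
                  __u32 tuple__sz, struct bpf_ct_opts___local *opts,
                  __u32 opts__sz) __ksym;
extern void bpf_ct_release(struct nf_conn *nfct) __ksym;

SEC("xdp")
int ct_lookup_example(struct xdp_md *ctx)
{
        struct bpf_ct_opts___local opts = {
                .netns_id = -1,                 /* BPF_F_CURRENT_NETNS */
                .l4proto  = 6,                  /* IPPROTO_TCP */
        };
        struct bpf_sock_tuple tup = {
                .ipv4 = {
                        .saddr = bpf_htonl(0x0a000001), /* 10.0.0.1, example */
                        .daddr = bpf_htonl(0x0a000002), /* 10.0.0.2, example */
                        .sport = bpf_htons(12345),
                        .dport = bpf_htons(80),
                },
        };
        struct nf_conn *ct;

        ct = bpf_xdp_ct_lookup(ctx, &tup, sizeof(tup.ipv4),
                               &opts, sizeof(opts));
        if (ct)
                bpf_ct_release(ct);             /* drop the acquired reference */

        return XDP_PASS;
}

char _license[] SEC("license") = "GPL";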
&tuple->src.u + : (void *)&tuple->dst.u; + + if (unlikely(protonum != IPPROTO_TCP && protonum != IPPROTO_UDP)) + return -EPROTO; + + memset(tuple, 0, sizeof(*tuple)); + + switch (tuple_len) { + case sizeof(bpf_tuple->ipv4): + tuple->src.l3num = AF_INET; + src->ip = bpf_tuple->ipv4.saddr; + sport->tcp.port = bpf_tuple->ipv4.sport; + dst->ip = bpf_tuple->ipv4.daddr; + dport->tcp.port = bpf_tuple->ipv4.dport; + break; + case sizeof(bpf_tuple->ipv6): + tuple->src.l3num = AF_INET6; + memcpy(src->ip6, bpf_tuple->ipv6.saddr, sizeof(bpf_tuple->ipv6.saddr)); + sport->tcp.port = bpf_tuple->ipv6.sport; + memcpy(dst->ip6, bpf_tuple->ipv6.daddr, sizeof(bpf_tuple->ipv6.daddr)); + dport->tcp.port = bpf_tuple->ipv6.dport; + break; + default: + return -EAFNOSUPPORT; + } + tuple->dst.protonum = protonum; + tuple->dst.dir = dir; + + return 0; +} + +static struct nf_conn * +__bpf_nf_ct_alloc_entry(struct net *net, struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, struct bpf_ct_opts *opts, u32 opts_len, + u32 timeout) +{ + struct nf_conntrack_tuple otuple, rtuple; + struct nf_conn *ct; + int err; + + if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] || + opts_len != NF_BPF_CT_OPTS_SZ) + return ERR_PTR(-EINVAL); + + if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS)) + return ERR_PTR(-EINVAL); + + err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto, + IP_CT_DIR_ORIGINAL, &otuple); + if (err < 0) + return ERR_PTR(err); + + err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto, + IP_CT_DIR_REPLY, &rtuple); + if (err < 0) + return ERR_PTR(err); + + if (opts->netns_id >= 0) { + net = get_net_ns_by_id(net, opts->netns_id); + if (unlikely(!net)) + return ERR_PTR(-ENONET); + } + + ct = nf_conntrack_alloc(net, &nf_ct_zone_dflt, &otuple, &rtuple, + GFP_ATOMIC); + if (IS_ERR(ct)) + goto out; + + memset(&ct->proto, 0, sizeof(ct->proto)); + __nf_ct_set_timeout(ct, timeout * HZ); + +out: + if (opts->netns_id >= 0) + put_net(net); + + return ct; +} + +static struct nf_conn *__bpf_nf_ct_lookup(struct net *net, + struct bpf_sock_tuple *bpf_tuple, + u32 tuple_len, struct bpf_ct_opts *opts, + u32 opts_len) +{ + struct nf_conntrack_tuple_hash *hash; + struct nf_conntrack_tuple tuple; + struct nf_conn *ct; + int err; + + if (!opts || !bpf_tuple || opts->reserved[0] || opts->reserved[1] || + opts_len != NF_BPF_CT_OPTS_SZ) + return ERR_PTR(-EINVAL); + if (unlikely(opts->l4proto != IPPROTO_TCP && opts->l4proto != IPPROTO_UDP)) + return ERR_PTR(-EPROTO); + if (unlikely(opts->netns_id < BPF_F_CURRENT_NETNS)) + return ERR_PTR(-EINVAL); + + err = bpf_nf_ct_tuple_parse(bpf_tuple, tuple_len, opts->l4proto, + IP_CT_DIR_ORIGINAL, &tuple); + if (err < 0) + return ERR_PTR(err); + + if (opts->netns_id >= 0) { + net = get_net_ns_by_id(net, opts->netns_id); + if (unlikely(!net)) + return ERR_PTR(-ENONET); + } + + hash = nf_conntrack_find_get(net, &nf_ct_zone_dflt, &tuple); + if (opts->netns_id >= 0) + put_net(net); + if (!hash) + return ERR_PTR(-ENOENT); + + ct = nf_ct_tuplehash_to_ctrack(hash); + opts->dir = NF_CT_DIRECTION(hash); + + return ct; +} + +BTF_ID_LIST(btf_nf_conn_ids) +BTF_ID(struct, nf_conn) +BTF_ID(struct, nf_conn___init) + +/* Check writes into `struct nf_conn` */ +static int _nf_conntrack_btf_struct_access(struct bpf_verifier_log *log, + const struct btf *btf, + const struct btf_type *t, int off, + int size, enum bpf_access_type atype, + u32 *next_btf_id, + enum bpf_type_flag *flag) +{ + const struct btf_type *ncit; + const struct btf_type *nct; + size_t end; + + ncit = btf_type_by_id(btf, 
btf_nf_conn_ids[1]); + nct = btf_type_by_id(btf, btf_nf_conn_ids[0]); + + if (t != nct && t != ncit) { + bpf_log(log, "only read is supported\n"); + return -EACCES; + } + + /* `struct nf_conn` and `struct nf_conn___init` have the same layout + * so we are safe to simply merge offset checks here + */ + switch (off) { +#if defined(CONFIG_NF_CONNTRACK_MARK) + case offsetof(struct nf_conn, mark): + end = offsetofend(struct nf_conn, mark); + break; +#endif + default: + bpf_log(log, "no write support to nf_conn at off %d\n", off); + return -EACCES; + } + + if (off + size > end) { + bpf_log(log, + "write access at off %d with size %d beyond the member of nf_conn ended at %zu\n", + off, size, end); + return -EACCES; + } + + return 0; +} + +__diag_push(); +__diag_ignore_all("-Wmissing-prototypes", + "Global functions as their definitions will be in nf_conntrack BTF"); + +/* bpf_xdp_ct_alloc - Allocate a new CT entry + * + * Parameters: + * @xdp_ctx - Pointer to ctx (xdp_md) in XDP program + * Cannot be NULL + * @bpf_tuple - Pointer to memory representing the tuple to look up + * Cannot be NULL + * @tuple__sz - Length of the tuple structure + * Must be one of sizeof(bpf_tuple->ipv4) or + * sizeof(bpf_tuple->ipv6) + * @opts - Additional options for allocation (documented above) + * Cannot be NULL + * @opts__sz - Length of the bpf_ct_opts structure + * Must be NF_BPF_CT_OPTS_SZ (12) + */ +struct nf_conn___init * +bpf_xdp_ct_alloc(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple, + u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz) +{ + struct xdp_buff *ctx = (struct xdp_buff *)xdp_ctx; + struct nf_conn *nfct; + + nfct = __bpf_nf_ct_alloc_entry(dev_net(ctx->rxq->dev), bpf_tuple, tuple__sz, + opts, opts__sz, 10); + if (IS_ERR(nfct)) { + if (opts) + opts->error = PTR_ERR(nfct); + return NULL; + } + + return (struct nf_conn___init *)nfct; +} + +/* bpf_xdp_ct_lookup - Lookup CT entry for the given tuple, and acquire a + * reference to it + * + * Parameters: + * @xdp_ctx - Pointer to ctx (xdp_md) in XDP program + * Cannot be NULL + * @bpf_tuple - Pointer to memory representing the tuple to look up + * Cannot be NULL + * @tuple__sz - Length of the tuple structure + * Must be one of sizeof(bpf_tuple->ipv4) or + * sizeof(bpf_tuple->ipv6) + * @opts - Additional options for lookup (documented above) + * Cannot be NULL + * @opts__sz - Length of the bpf_ct_opts structure + * Must be NF_BPF_CT_OPTS_SZ (12) + */ +struct nf_conn * +bpf_xdp_ct_lookup(struct xdp_md *xdp_ctx, struct bpf_sock_tuple *bpf_tuple, + u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz) +{ + struct xdp_buff *ctx = (struct xdp_buff *)xdp_ctx; + struct net *caller_net; + struct nf_conn *nfct; + + caller_net = dev_net(ctx->rxq->dev); + nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz); + if (IS_ERR(nfct)) { + if (opts) + opts->error = PTR_ERR(nfct); + return NULL; + } + return nfct; +} + +/* bpf_skb_ct_alloc - Allocate a new CT entry + * + * Parameters: + * @skb_ctx - Pointer to ctx (__sk_buff) in TC program + * Cannot be NULL + * @bpf_tuple - Pointer to memory representing the tuple to look up + * Cannot be NULL + * @tuple__sz - Length of the tuple structure + * Must be one of sizeof(bpf_tuple->ipv4) or + * sizeof(bpf_tuple->ipv6) + * @opts - Additional options for allocation (documented above) + * Cannot be NULL + * @opts__sz - Length of the bpf_ct_opts structure + * Must be NF_BPF_CT_OPTS_SZ (12) + */ +struct nf_conn___init * +bpf_skb_ct_alloc(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple, 
+ u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct nf_conn *nfct; + struct net *net; + + net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk); + nfct = __bpf_nf_ct_alloc_entry(net, bpf_tuple, tuple__sz, opts, opts__sz, 10); + if (IS_ERR(nfct)) { + if (opts) + opts->error = PTR_ERR(nfct); + return NULL; + } + + return (struct nf_conn___init *)nfct; +} + +/* bpf_skb_ct_lookup - Lookup CT entry for the given tuple, and acquire a + * reference to it + * + * Parameters: + * @skb_ctx - Pointer to ctx (__sk_buff) in TC program + * Cannot be NULL + * @bpf_tuple - Pointer to memory representing the tuple to look up + * Cannot be NULL + * @tuple__sz - Length of the tuple structure + * Must be one of sizeof(bpf_tuple->ipv4) or + * sizeof(bpf_tuple->ipv6) + * @opts - Additional options for lookup (documented above) + * Cannot be NULL + * @opts__sz - Length of the bpf_ct_opts structure + * Must be NF_BPF_CT_OPTS_SZ (12) + */ +struct nf_conn * +bpf_skb_ct_lookup(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple, + u32 tuple__sz, struct bpf_ct_opts *opts, u32 opts__sz) +{ + struct sk_buff *skb = (struct sk_buff *)skb_ctx; + struct net *caller_net; + struct nf_conn *nfct; + + caller_net = skb->dev ? dev_net(skb->dev) : sock_net(skb->sk); + nfct = __bpf_nf_ct_lookup(caller_net, bpf_tuple, tuple__sz, opts, opts__sz); + if (IS_ERR(nfct)) { + if (opts) + opts->error = PTR_ERR(nfct); + return NULL; + } + return nfct; +} + +/* bpf_ct_insert_entry - Add the provided entry into a CT map + * + * This must be invoked for referenced PTR_TO_BTF_ID. + * + * @nfct - Pointer to referenced nf_conn___init object, obtained + * using bpf_xdp_ct_alloc or bpf_skb_ct_alloc. + */ +struct nf_conn *bpf_ct_insert_entry(struct nf_conn___init *nfct_i) +{ + struct nf_conn *nfct = (struct nf_conn *)nfct_i; + int err; + + if (!nf_ct_is_confirmed(nfct)) + nfct->timeout += nfct_time_stamp; + nfct->status |= IPS_CONFIRMED; + err = nf_conntrack_hash_check_insert(nfct); + if (err < 0) { + nf_conntrack_free(nfct); + return NULL; + } + return nfct; +} + +/* bpf_ct_release - Release acquired nf_conn object + * + * This must be invoked for referenced PTR_TO_BTF_ID, and the verifier rejects + * the program if any references remain in the program in all of the explored + * states. + * + * Parameters: + * @nf_conn - Pointer to referenced nf_conn object, obtained using + * bpf_xdp_ct_lookup or bpf_skb_ct_lookup. + */ +void bpf_ct_release(struct nf_conn *nfct) +{ + if (!nfct) + return; + nf_ct_put(nfct); +} + +/* bpf_ct_set_timeout - Set timeout of allocated nf_conn + * + * Sets the default timeout of newly allocated nf_conn before insertion. + * This helper must be invoked for refcounted pointer to nf_conn___init. + * + * Parameters: + * @nfct - Pointer to referenced nf_conn object, obtained using + * bpf_xdp_ct_alloc or bpf_skb_ct_alloc. + * @timeout - Timeout in msecs. + */ +void bpf_ct_set_timeout(struct nf_conn___init *nfct, u32 timeout) +{ + __nf_ct_set_timeout((struct nf_conn *)nfct, msecs_to_jiffies(timeout)); +} + +/* bpf_ct_change_timeout - Change timeout of inserted nf_conn + * + * Change timeout associated of the inserted or looked up nf_conn. + * This helper must be invoked for refcounted pointer to nf_conn. + * + * Parameters: + * @nfct - Pointer to referenced nf_conn object, obtained using + * bpf_ct_insert_entry, bpf_xdp_ct_lookup, or bpf_skb_ct_lookup. + * @timeout - New timeout in msecs. 
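/* Editor's illustrative sketch (not part of this patch): the allocation
 * path documented above is alloc, set a timeout, insert, then release the
 * reference returned by the insert.  This SCHED_CLS sketch reuses the
 * includes and the struct bpf_ct_opts___local mirror from the earlier XDP
 * sketch; the 30-second timeout, addresses and ports are placeholders.
 */
extern struct nf_conn___init *
bpf_skb_ct_alloc(struct __sk_buff *skb_ctx, struct bpf_sock_tuple *bpf_tuple,
                 __u32 tuple__sz, struct bpf_ct_opts___local *opts,
                 __u32 opts__sz) __ksym;
extern struct nf_conn *bpf_ct_insert_entry(struct nf_conn___init *nfct_i) __ksym;
extern void bpf_ct_set_timeout(struct nf_conn___init *nfct, __u32 timeout) __ksym;
extern void bpf_ct_release(struct nf_conn *nfct) __ksym;

SEC("tc")
int ct_alloc_example(struct __sk_buff *skb)
{
        struct bpf_ct_opts___local opts = {
                .netns_id = -1,                 /* BPF_F_CURRENT_NETNS */
                .l4proto  = 17,                 /* IPPROTO_UDP */
        };
        struct bpf_sock_tuple tup = {
                .ipv4 = {
                        .saddr = bpf_htonl(0x0a000001),
                        .daddr = bpf_htonl(0x0a000002),
                        .sport = bpf_htons(5000),
                        .dport = bpf_htons(5001),
                },
        };
        struct nf_conn___init *ct_i;
        struct nf_conn *ct;

        ct_i = bpf_skb_ct_alloc(skb, &tup, sizeof(tup.ipv4),
                                &opts, sizeof(opts));
        if (!ct_i)
                return 0;                       /* TC_ACT_OK */

        bpf_ct_set_timeout(ct_i, 30000);        /* msecs, per the comment above */
        ct = bpf_ct_insert_entry(ct_i);         /* consumes ct_i */
        if (ct)
                bpf_ct_release(ct);             /* drop the acquired reference */

        return 0;                               /* TC_ACT_OK */
}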
+ */ +int bpf_ct_change_timeout(struct nf_conn *nfct, u32 timeout) +{ + return __nf_ct_change_timeout(nfct, msecs_to_jiffies(timeout)); +} + +/* bpf_ct_set_status - Set status field of allocated nf_conn + * + * Set the status field of the newly allocated nf_conn before insertion. + * This must be invoked for referenced PTR_TO_BTF_ID to nf_conn___init. + * + * Parameters: + * @nfct - Pointer to referenced nf_conn object, obtained using + * bpf_xdp_ct_alloc or bpf_skb_ct_alloc. + * @status - New status value. + */ +int bpf_ct_set_status(const struct nf_conn___init *nfct, u32 status) +{ + return nf_ct_change_status_common((struct nf_conn *)nfct, status); +} + +/* bpf_ct_change_status - Change status of inserted nf_conn + * + * Change the status field of the provided connection tracking entry. + * This must be invoked for referenced PTR_TO_BTF_ID to nf_conn. + * + * Parameters: + * @nfct - Pointer to referenced nf_conn object, obtained using + * bpf_ct_insert_entry, bpf_xdp_ct_lookup or bpf_skb_ct_lookup. + * @status - New status value. + */ +int bpf_ct_change_status(struct nf_conn *nfct, u32 status) +{ + return nf_ct_change_status_common(nfct, status); +} + +__diag_pop() + +BTF_SET8_START(nf_ct_kfunc_set) +BTF_ID_FLAGS(func, bpf_xdp_ct_alloc, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_xdp_ct_lookup, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_skb_ct_alloc, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_skb_ct_lookup, KF_ACQUIRE | KF_RET_NULL) +BTF_ID_FLAGS(func, bpf_ct_insert_entry, KF_ACQUIRE | KF_RET_NULL | KF_RELEASE) +BTF_ID_FLAGS(func, bpf_ct_release, KF_RELEASE) +BTF_ID_FLAGS(func, bpf_ct_set_timeout, KF_TRUSTED_ARGS) +BTF_ID_FLAGS(func, bpf_ct_change_timeout, KF_TRUSTED_ARGS) +BTF_ID_FLAGS(func, bpf_ct_set_status, KF_TRUSTED_ARGS) +BTF_ID_FLAGS(func, bpf_ct_change_status, KF_TRUSTED_ARGS) +BTF_SET8_END(nf_ct_kfunc_set) + +static const struct btf_kfunc_id_set nf_conntrack_kfunc_set = { + .owner = THIS_MODULE, + .set = &nf_ct_kfunc_set, +}; + +int register_nf_conntrack_bpf(void) +{ + int ret; + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP, &nf_conntrack_kfunc_set); + ret = ret ?: register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, &nf_conntrack_kfunc_set); + if (!ret) { + mutex_lock(&nf_conn_btf_access_lock); + nfct_btf_struct_access = _nf_conntrack_btf_struct_access; + mutex_unlock(&nf_conn_btf_access_lock); + } + + return ret; +} + +void cleanup_nf_conntrack_bpf(void) +{ + mutex_lock(&nf_conn_btf_access_lock); + nfct_btf_struct_access = NULL; + mutex_unlock(&nf_conn_btf_access_lock); +} diff --git a/net/netfilter/nf_conntrack_broadcast.c b/net/netfilter/nf_conntrack_broadcast.c new file mode 100644 index 000000000..9fb9b8031 --- /dev/null +++ b/net/netfilter/nf_conntrack_broadcast.c @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * broadcast connection tracking helper + * + * (c) 2005 Patrick McHardy + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +int nf_conntrack_broadcast_help(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int timeout) +{ + const struct nf_conntrack_helper *helper; + struct nf_conntrack_expect *exp; + struct iphdr *iph = ip_hdr(skb); + struct rtable *rt = skb_rtable(skb); + struct in_device *in_dev; + struct nf_conn_help *help = nfct_help(ct); + __be32 mask = 0; + + /* we're only interested in locally generated packets */ + if (skb->sk == NULL || !net_eq(nf_ct_net(ct), sock_net(skb->sk))) + goto out; + if (rt == NULL || !(rt->rt_flags & 
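/* Editor's illustrative sketch (not part of this patch): the registration
 * tail above (BTF_SET8_START ... register_btf_kfunc_id_set()) is the generic
 * recipe for exposing kfuncs.  A stripped-down, hypothetical module that
 * follows the same pattern is sketched below purely to make the moving parts
 * visible; all example_* names are invented.
 */
#include <linux/bpf.h>
#include <linux/btf.h>
#include <linux/btf_ids.h>
#include <linux/module.h>
#include <linux/types.h>

__diag_push();
__diag_ignore_all("-Wmissing-prototypes",
                  "Global kfuncs have their definitions in BTF");

/* the kfunc itself: global and noinline so it lands in BTF */
noinline u32 example_add_one(u32 x)
{
        return x + 1;
}

__diag_pop();

BTF_SET8_START(example_kfunc_set)
BTF_ID_FLAGS(func, example_add_one)
BTF_SET8_END(example_kfunc_set)

static const struct btf_kfunc_id_set example_kfunc_id_set = {
        .owner  = THIS_MODULE,
        .set    = &example_kfunc_set,
};

static int __init example_init(void)
{
        return register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP,
                                         &example_kfunc_id_set);
}
module_init(example_init);
MODULE_LICENSE("GPL");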
RTCF_BROADCAST)) + goto out; + if (CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) + goto out; + + in_dev = __in_dev_get_rcu(rt->dst.dev); + if (in_dev != NULL) { + const struct in_ifaddr *ifa; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (ifa->ifa_flags & IFA_F_SECONDARY) + continue; + + if (ifa->ifa_broadcast == iph->daddr) { + mask = ifa->ifa_mask; + break; + } + } + } + + if (mask == 0) + goto out; + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) + goto out; + + exp->tuple = ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + helper = rcu_dereference(help->helper); + if (helper) + exp->tuple.src.u.udp.port = helper->tuple.src.u.udp.port; + + exp->mask.src.u3.ip = mask; + exp->mask.src.u.udp.port = htons(0xFFFF); + + exp->expectfn = NULL; + exp->flags = NF_CT_EXPECT_PERMANENT; + exp->class = NF_CT_EXPECT_CLASS_DEFAULT; + exp->helper = NULL; + + nf_ct_expect_related(exp, 0); + nf_ct_expect_put(exp); + + nf_ct_refresh(ct, skb, timeout * HZ); +out: + return NF_ACCEPT; +} +EXPORT_SYMBOL_GPL(nf_conntrack_broadcast_help); + +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/nf_conntrack_core.c b/net/netfilter/nf_conntrack_core.c new file mode 100644 index 000000000..796026296 --- /dev/null +++ b/net/netfilter/nf_conntrack_core.c @@ -0,0 +1,2875 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Connection state tracking for netfilter. This is separated from, + but required by, the NAT layer; it can also be used by an iptables + extension. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2003,2004 USAGI/WIDE Project + * (C) 2005-2012 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +__cacheline_aligned_in_smp spinlock_t nf_conntrack_locks[CONNTRACK_LOCKS]; +EXPORT_SYMBOL_GPL(nf_conntrack_locks); + +__cacheline_aligned_in_smp DEFINE_SPINLOCK(nf_conntrack_expect_lock); +EXPORT_SYMBOL_GPL(nf_conntrack_expect_lock); + +struct hlist_nulls_head *nf_conntrack_hash __read_mostly; +EXPORT_SYMBOL_GPL(nf_conntrack_hash); + +struct conntrack_gc_work { + struct delayed_work dwork; + u32 next_bucket; + u32 avg_timeout; + u32 count; + u32 start_time; + bool exiting; + bool early_drop; +}; + +static __read_mostly struct kmem_cache *nf_conntrack_cachep; +static DEFINE_SPINLOCK(nf_conntrack_locks_all_lock); +static __read_mostly bool nf_conntrack_locks_all; + +/* serialize hash resizes and nf_ct_iterate_cleanup */ +static DEFINE_MUTEX(nf_conntrack_mutex); + +#define GC_SCAN_INTERVAL_MAX (60ul * HZ) +#define GC_SCAN_INTERVAL_MIN (1ul * HZ) + +/* clamp timeouts to this value (TCP unacked) */ +#define GC_SCAN_INTERVAL_CLAMP (300ul * HZ) + +/* Initial bias pretending we have 100 entries at the upper bound so we don't + * wakeup often just because we have three entries with a 1s timeout while still + * allowing non-idle machines to wakeup more often when needed. 
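/* Editor's illustrative sketch (not part of this patch):
 * nf_conntrack_broadcast_help() above is meant to be wrapped by tiny
 * per-protocol helpers, in the style of the netbios-ns helper.  A
 * hypothetical consumer registered for UDP port 137 could look roughly like
 * this; the name, port and 3-second timeout are illustrative.
 */
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_conntrack_helper.h>

static int example_bcast_help(struct sk_buff *skb, unsigned int protoff,
                              struct nf_conn *ct,
                              enum ip_conntrack_info ctinfo)
{
        /* expect the unicast reply for three seconds */
        return nf_conntrack_broadcast_help(skb, ct, ctinfo, 3);
}

static const struct nf_conntrack_expect_policy example_exp_policy = {
        .max_expected   = 1,
        .timeout        = 3,
};

static struct nf_conntrack_helper example_helper __read_mostly = {
        .name                   = "example-bcast",
        .me                     = THIS_MODULE,
        .help                   = example_bcast_help,
        .tuple.src.l3num        = AF_INET,
        .tuple.src.u.udp.port   = cpu_to_be16(137),
        .tuple.dst.protonum     = IPPROTO_UDP,
        .expect_policy          = &example_exp_policy,
};

/* registered/unregistered with nf_conntrack_helper_register()/_unregister() */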
+ */ +#define GC_SCAN_INITIAL_COUNT 100 +#define GC_SCAN_INTERVAL_INIT GC_SCAN_INTERVAL_MAX + +#define GC_SCAN_MAX_DURATION msecs_to_jiffies(10) +#define GC_SCAN_EXPIRED_MAX (64000u / HZ) + +#define MIN_CHAINLEN 50u +#define MAX_CHAINLEN (80u - MIN_CHAINLEN) + +static struct conntrack_gc_work conntrack_gc_work; + +void nf_conntrack_lock(spinlock_t *lock) __acquires(lock) +{ + /* 1) Acquire the lock */ + spin_lock(lock); + + /* 2) read nf_conntrack_locks_all, with ACQUIRE semantics + * It pairs with the smp_store_release() in nf_conntrack_all_unlock() + */ + if (likely(smp_load_acquire(&nf_conntrack_locks_all) == false)) + return; + + /* fast path failed, unlock */ + spin_unlock(lock); + + /* Slow path 1) get global lock */ + spin_lock(&nf_conntrack_locks_all_lock); + + /* Slow path 2) get the lock we want */ + spin_lock(lock); + + /* Slow path 3) release the global lock */ + spin_unlock(&nf_conntrack_locks_all_lock); +} +EXPORT_SYMBOL_GPL(nf_conntrack_lock); + +static void nf_conntrack_double_unlock(unsigned int h1, unsigned int h2) +{ + h1 %= CONNTRACK_LOCKS; + h2 %= CONNTRACK_LOCKS; + spin_unlock(&nf_conntrack_locks[h1]); + if (h1 != h2) + spin_unlock(&nf_conntrack_locks[h2]); +} + +/* return true if we need to recompute hashes (in case hash table was resized) */ +static bool nf_conntrack_double_lock(struct net *net, unsigned int h1, + unsigned int h2, unsigned int sequence) +{ + h1 %= CONNTRACK_LOCKS; + h2 %= CONNTRACK_LOCKS; + if (h1 <= h2) { + nf_conntrack_lock(&nf_conntrack_locks[h1]); + if (h1 != h2) + spin_lock_nested(&nf_conntrack_locks[h2], + SINGLE_DEPTH_NESTING); + } else { + nf_conntrack_lock(&nf_conntrack_locks[h2]); + spin_lock_nested(&nf_conntrack_locks[h1], + SINGLE_DEPTH_NESTING); + } + if (read_seqcount_retry(&nf_conntrack_generation, sequence)) { + nf_conntrack_double_unlock(h1, h2); + return true; + } + return false; +} + +static void nf_conntrack_all_lock(void) + __acquires(&nf_conntrack_locks_all_lock) +{ + int i; + + spin_lock(&nf_conntrack_locks_all_lock); + + /* For nf_contrack_locks_all, only the latest time when another + * CPU will see an update is controlled, by the "release" of the + * spin_lock below. + * The earliest time is not controlled, an thus KCSAN could detect + * a race when nf_conntract_lock() reads the variable. + * WRITE_ONCE() is used to ensure the compiler will not + * optimize the write. + */ + WRITE_ONCE(nf_conntrack_locks_all, true); + + for (i = 0; i < CONNTRACK_LOCKS; i++) { + spin_lock(&nf_conntrack_locks[i]); + + /* This spin_unlock provides the "release" to ensure that + * nf_conntrack_locks_all==true is visible to everyone that + * acquired spin_lock(&nf_conntrack_locks[]). + */ + spin_unlock(&nf_conntrack_locks[i]); + } +} + +static void nf_conntrack_all_unlock(void) + __releases(&nf_conntrack_locks_all_lock) +{ + /* All prior stores must be complete before we clear + * 'nf_conntrack_locks_all'. Otherwise nf_conntrack_lock() + * might observe the false value but not the entire + * critical section. 
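/* Editor's illustrative sketch (not part of this patch): the per-bucket
 * locking above stays deadlock-free when two buckets must be held at once by
 * a fixed ordering rule -- always take the lower-numbered lock first.  The
 * same rule, stripped of the seqcount retry and lockdep annotations, looks
 * like this in plain userspace C.
 */
#include <pthread.h>

#define SLOTS 1024
static pthread_mutex_t slot_lock[SLOTS];

static void lock_array_init(void)
{
        for (unsigned int i = 0; i < SLOTS; i++)
                pthread_mutex_init(&slot_lock[i], NULL);
}

static void lock_pair(unsigned int h1, unsigned int h2)
{
        h1 %= SLOTS;
        h2 %= SLOTS;
        if (h1 <= h2) {
                pthread_mutex_lock(&slot_lock[h1]);
                if (h1 != h2)
                        pthread_mutex_lock(&slot_lock[h2]);
        } else {
                /* reverse order requested: still lock the lower index first */
                pthread_mutex_lock(&slot_lock[h2]);
                pthread_mutex_lock(&slot_lock[h1]);
        }
}

static void unlock_pair(unsigned int h1, unsigned int h2)
{
        h1 %= SLOTS;
        h2 %= SLOTS;
        pthread_mutex_unlock(&slot_lock[h1]);
        if (h1 != h2)
                pthread_mutex_unlock(&slot_lock[h2]);
}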
+ * It pairs with the smp_load_acquire() in nf_conntrack_lock() + */ + smp_store_release(&nf_conntrack_locks_all, false); + spin_unlock(&nf_conntrack_locks_all_lock); +} + +unsigned int nf_conntrack_htable_size __read_mostly; +EXPORT_SYMBOL_GPL(nf_conntrack_htable_size); + +unsigned int nf_conntrack_max __read_mostly; +EXPORT_SYMBOL_GPL(nf_conntrack_max); +seqcount_spinlock_t nf_conntrack_generation __read_mostly; +static siphash_aligned_key_t nf_conntrack_hash_rnd; + +static u32 hash_conntrack_raw(const struct nf_conntrack_tuple *tuple, + unsigned int zoneid, + const struct net *net) +{ + struct { + struct nf_conntrack_man src; + union nf_inet_addr dst_addr; + unsigned int zone; + u32 net_mix; + u16 dport; + u16 proto; + } __aligned(SIPHASH_ALIGNMENT) combined; + + get_random_once(&nf_conntrack_hash_rnd, sizeof(nf_conntrack_hash_rnd)); + + memset(&combined, 0, sizeof(combined)); + + /* The direction must be ignored, so handle usable members manually. */ + combined.src = tuple->src; + combined.dst_addr = tuple->dst.u3; + combined.zone = zoneid; + combined.net_mix = net_hash_mix(net); + combined.dport = (__force __u16)tuple->dst.u.all; + combined.proto = tuple->dst.protonum; + + return (u32)siphash(&combined, sizeof(combined), &nf_conntrack_hash_rnd); +} + +static u32 scale_hash(u32 hash) +{ + return reciprocal_scale(hash, nf_conntrack_htable_size); +} + +static u32 __hash_conntrack(const struct net *net, + const struct nf_conntrack_tuple *tuple, + unsigned int zoneid, + unsigned int size) +{ + return reciprocal_scale(hash_conntrack_raw(tuple, zoneid, net), size); +} + +static u32 hash_conntrack(const struct net *net, + const struct nf_conntrack_tuple *tuple, + unsigned int zoneid) +{ + return scale_hash(hash_conntrack_raw(tuple, zoneid, net)); +} + +static bool nf_ct_get_tuple_ports(const struct sk_buff *skb, + unsigned int dataoff, + struct nf_conntrack_tuple *tuple) +{ struct { + __be16 sport; + __be16 dport; + } _inet_hdr, *inet_hdr; + + /* Actually only need first 4 bytes to get ports. 
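/* Editor's illustrative sketch (not part of this patch): scale_hash() and
 * __hash_conntrack() above fold the 32-bit siphash result into a bucket
 * index with reciprocal_scale(), which avoids a modulo.  Written out, the
 * mapping is just a 64-bit multiply and shift.
 */
#include <stdint.h>

static inline uint32_t example_reciprocal_scale(uint32_t val, uint32_t ep_ro)
{
        /* result is always < ep_ro (for ep_ro > 0), for any 32-bit val */
        return (uint32_t)(((uint64_t)val * ep_ro) >> 32);
}

/* e.g. example_reciprocal_scale(hash, table_size) picks the hash bucket
 * without dividing, at the cost of a slightly uneven distribution when
 * table_size is not a power of two. */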
*/ + inet_hdr = skb_header_pointer(skb, dataoff, sizeof(_inet_hdr), &_inet_hdr); + if (!inet_hdr) + return false; + + tuple->src.u.udp.port = inet_hdr->sport; + tuple->dst.u.udp.port = inet_hdr->dport; + return true; +} + +static bool +nf_ct_get_tuple(const struct sk_buff *skb, + unsigned int nhoff, + unsigned int dataoff, + u_int16_t l3num, + u_int8_t protonum, + struct net *net, + struct nf_conntrack_tuple *tuple) +{ + unsigned int size; + const __be32 *ap; + __be32 _addrs[8]; + + memset(tuple, 0, sizeof(*tuple)); + + tuple->src.l3num = l3num; + switch (l3num) { + case NFPROTO_IPV4: + nhoff += offsetof(struct iphdr, saddr); + size = 2 * sizeof(__be32); + break; + case NFPROTO_IPV6: + nhoff += offsetof(struct ipv6hdr, saddr); + size = sizeof(_addrs); + break; + default: + return true; + } + + ap = skb_header_pointer(skb, nhoff, size, _addrs); + if (!ap) + return false; + + switch (l3num) { + case NFPROTO_IPV4: + tuple->src.u3.ip = ap[0]; + tuple->dst.u3.ip = ap[1]; + break; + case NFPROTO_IPV6: + memcpy(tuple->src.u3.ip6, ap, sizeof(tuple->src.u3.ip6)); + memcpy(tuple->dst.u3.ip6, ap + 4, sizeof(tuple->dst.u3.ip6)); + break; + } + + tuple->dst.protonum = protonum; + tuple->dst.dir = IP_CT_DIR_ORIGINAL; + + switch (protonum) { +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return icmpv6_pkt_to_tuple(skb, dataoff, net, tuple); +#endif + case IPPROTO_ICMP: + return icmp_pkt_to_tuple(skb, dataoff, net, tuple); +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + return gre_pkt_to_tuple(skb, dataoff, net, tuple); +#endif + case IPPROTO_TCP: + case IPPROTO_UDP: +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: +#endif + /* fallthrough */ + return nf_ct_get_tuple_ports(skb, dataoff, tuple); + default: + break; + } + + return true; +} + +static int ipv4_get_l4proto(const struct sk_buff *skb, unsigned int nhoff, + u_int8_t *protonum) +{ + int dataoff = -1; + const struct iphdr *iph; + struct iphdr _iph; + + iph = skb_header_pointer(skb, nhoff, sizeof(_iph), &_iph); + if (!iph) + return -1; + + /* Conntrack defragments packets, we might still see fragments + * inside ICMP packets though. 
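/* Editor's illustrative sketch (not part of this patch):
 * nf_ct_get_tuple_ports() and nf_ct_get_tuple() above both use the
 * skb_header_pointer() idiom -- request a small header window and get either
 * a direct pointer into the skb or a copy into a stack buffer.  A
 * hypothetical UDP-only caller:
 */
#include <linux/skbuff.h>
#include <linux/udp.h>

static bool example_get_udp_ports(const struct sk_buff *skb,
                                  unsigned int dataoff,
                                  __be16 *sport, __be16 *dport)
{
        struct udphdr _hdr;
        const struct udphdr *hdr;

        /* returns &_hdr (bytes copied) or a direct pointer into the skb */
        hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr);
        if (!hdr)
                return false;   /* packet too short */

        *sport = hdr->source;
        *dport = hdr->dest;
        return true;
}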
+ */ + if (iph->frag_off & htons(IP_OFFSET)) + return -1; + + dataoff = nhoff + (iph->ihl << 2); + *protonum = iph->protocol; + + /* Check bogus IP headers */ + if (dataoff > skb->len) { + pr_debug("bogus IPv4 packet: nhoff %u, ihl %u, skblen %u\n", + nhoff, iph->ihl << 2, skb->len); + return -1; + } + return dataoff; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int ipv6_get_l4proto(const struct sk_buff *skb, unsigned int nhoff, + u8 *protonum) +{ + int protoff = -1; + unsigned int extoff = nhoff + sizeof(struct ipv6hdr); + __be16 frag_off; + u8 nexthdr; + + if (skb_copy_bits(skb, nhoff + offsetof(struct ipv6hdr, nexthdr), + &nexthdr, sizeof(nexthdr)) != 0) { + pr_debug("can't get nexthdr\n"); + return -1; + } + protoff = ipv6_skip_exthdr(skb, extoff, &nexthdr, &frag_off); + /* + * (protoff == skb->len) means the packet has not data, just + * IPv6 and possibly extensions headers, but it is tracked anyway + */ + if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { + pr_debug("can't find proto in pkt\n"); + return -1; + } + + *protonum = nexthdr; + return protoff; +} +#endif + +static int get_l4proto(const struct sk_buff *skb, + unsigned int nhoff, u8 pf, u8 *l4num) +{ + switch (pf) { + case NFPROTO_IPV4: + return ipv4_get_l4proto(skb, nhoff, l4num); +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: + return ipv6_get_l4proto(skb, nhoff, l4num); +#endif + default: + *l4num = 0; + break; + } + return -1; +} + +bool nf_ct_get_tuplepr(const struct sk_buff *skb, unsigned int nhoff, + u_int16_t l3num, + struct net *net, struct nf_conntrack_tuple *tuple) +{ + u8 protonum; + int protoff; + + protoff = get_l4proto(skb, nhoff, l3num, &protonum); + if (protoff <= 0) + return false; + + return nf_ct_get_tuple(skb, nhoff, protoff, l3num, protonum, net, tuple); +} +EXPORT_SYMBOL_GPL(nf_ct_get_tuplepr); + +bool +nf_ct_invert_tuple(struct nf_conntrack_tuple *inverse, + const struct nf_conntrack_tuple *orig) +{ + memset(inverse, 0, sizeof(*inverse)); + + inverse->src.l3num = orig->src.l3num; + + switch (orig->src.l3num) { + case NFPROTO_IPV4: + inverse->src.u3.ip = orig->dst.u3.ip; + inverse->dst.u3.ip = orig->src.u3.ip; + break; + case NFPROTO_IPV6: + inverse->src.u3.in6 = orig->dst.u3.in6; + inverse->dst.u3.in6 = orig->src.u3.in6; + break; + default: + break; + } + + inverse->dst.dir = !orig->dst.dir; + + inverse->dst.protonum = orig->dst.protonum; + + switch (orig->dst.protonum) { + case IPPROTO_ICMP: + return nf_conntrack_invert_icmp_tuple(inverse, orig); +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return nf_conntrack_invert_icmpv6_tuple(inverse, orig); +#endif + } + + inverse->src.u.all = orig->dst.u.all; + inverse->dst.u.all = orig->src.u.all; + return true; +} +EXPORT_SYMBOL_GPL(nf_ct_invert_tuple); + +/* Generate a almost-unique pseudo-id for a given conntrack. + * + * intentionally doesn't re-use any of the seeds used for hash + * table location, we assume id gets exposed to userspace. + * + * Following nf_conn items do not change throughout lifetime + * of the nf_conn: + * + * 1. nf_conn address + * 2. nf_conn->master address (normally NULL) + * 3. the associated net namespace + * 4. 
the original direction tuple + */ +u32 nf_ct_get_id(const struct nf_conn *ct) +{ + static siphash_aligned_key_t ct_id_seed; + unsigned long a, b, c, d; + + net_get_random_once(&ct_id_seed, sizeof(ct_id_seed)); + + a = (unsigned long)ct; + b = (unsigned long)ct->master; + c = (unsigned long)nf_ct_net(ct); + d = (unsigned long)siphash(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + sizeof(ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple), + &ct_id_seed); +#ifdef CONFIG_64BIT + return siphash_4u64((u64)a, (u64)b, (u64)c, (u64)d, &ct_id_seed); +#else + return siphash_4u32((u32)a, (u32)b, (u32)c, (u32)d, &ct_id_seed); +#endif +} +EXPORT_SYMBOL_GPL(nf_ct_get_id); + +static void +clean_from_lists(struct nf_conn *ct) +{ + pr_debug("clean_from_lists(%p)\n", ct); + hlist_nulls_del_rcu(&ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode); + hlist_nulls_del_rcu(&ct->tuplehash[IP_CT_DIR_REPLY].hnnode); + + /* Destroy all pending expectations */ + nf_ct_remove_expectations(ct); +} + +#define NFCT_ALIGN(len) (((len) + NFCT_INFOMASK) & ~NFCT_INFOMASK) + +/* Released via nf_ct_destroy() */ +struct nf_conn *nf_ct_tmpl_alloc(struct net *net, + const struct nf_conntrack_zone *zone, + gfp_t flags) +{ + struct nf_conn *tmpl, *p; + + if (ARCH_KMALLOC_MINALIGN <= NFCT_INFOMASK) { + tmpl = kzalloc(sizeof(*tmpl) + NFCT_INFOMASK, flags); + if (!tmpl) + return NULL; + + p = tmpl; + tmpl = (struct nf_conn *)NFCT_ALIGN((unsigned long)p); + if (tmpl != p) { + tmpl = (struct nf_conn *)NFCT_ALIGN((unsigned long)p); + tmpl->proto.tmpl_padto = (char *)tmpl - (char *)p; + } + } else { + tmpl = kzalloc(sizeof(*tmpl), flags); + if (!tmpl) + return NULL; + } + + tmpl->status = IPS_TEMPLATE; + write_pnet(&tmpl->ct_net, net); + nf_ct_zone_add(tmpl, zone); + refcount_set(&tmpl->ct_general.use, 1); + + return tmpl; +} +EXPORT_SYMBOL_GPL(nf_ct_tmpl_alloc); + +void nf_ct_tmpl_free(struct nf_conn *tmpl) +{ + kfree(tmpl->ext); + + if (ARCH_KMALLOC_MINALIGN <= NFCT_INFOMASK) + kfree((char *)tmpl - tmpl->proto.tmpl_padto); + else + kfree(tmpl); +} +EXPORT_SYMBOL_GPL(nf_ct_tmpl_free); + +static void destroy_gre_conntrack(struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CT_PROTO_GRE + struct nf_conn *master = ct->master; + + if (master) + nf_ct_gre_keymap_destroy(master); +#endif +} + +void nf_ct_destroy(struct nf_conntrack *nfct) +{ + struct nf_conn *ct = (struct nf_conn *)nfct; + + pr_debug("%s(%p)\n", __func__, ct); + WARN_ON(refcount_read(&nfct->use) != 0); + + if (unlikely(nf_ct_is_template(ct))) { + nf_ct_tmpl_free(ct); + return; + } + + if (unlikely(nf_ct_protonum(ct) == IPPROTO_GRE)) + destroy_gre_conntrack(ct); + + /* Expectations will have been removed in clean_from_lists, + * except TFTP can create an expectation on the first packet, + * before connection is in the list, so we need to clean here, + * too. 
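/* Editor's illustrative sketch (not part of this patch): NFCT_ALIGN() in
 * nf_ct_tmpl_alloc() above is the usual round-up-to-a-power-of-two idiom,
 * used so the low bits of the template pointer stay clear for the ctinfo
 * bits packed into skb->_nfct.  In isolation, with a made-up mask value for
 * the example:
 */
#include <assert.h>
#include <stdint.h>

/* round v up to the next multiple of (mask + 1); mask must be 2^n - 1 */
static inline uintptr_t example_align_up(uintptr_t v, uintptr_t mask)
{
        return (v + mask) & ~mask;
}

int main(void)
{
        assert(example_align_up(0x1003, 7) == 0x1008);
        assert(example_align_up(0x1008, 7) == 0x1008);  /* already aligned */
        return 0;
}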
+ */ + nf_ct_remove_expectations(ct); + + if (ct->master) + nf_ct_put(ct->master); + + pr_debug("%s: returning ct=%p to slab\n", __func__, ct); + nf_conntrack_free(ct); +} +EXPORT_SYMBOL(nf_ct_destroy); + +static void __nf_ct_delete_from_lists(struct nf_conn *ct) +{ + struct net *net = nf_ct_net(ct); + unsigned int hash, reply_hash; + unsigned int sequence; + + do { + sequence = read_seqcount_begin(&nf_conntrack_generation); + hash = hash_conntrack(net, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_ORIGINAL)); + reply_hash = hash_conntrack(net, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple, + nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_REPLY)); + } while (nf_conntrack_double_lock(net, hash, reply_hash, sequence)); + + clean_from_lists(ct); + nf_conntrack_double_unlock(hash, reply_hash); +} + +static void nf_ct_delete_from_lists(struct nf_conn *ct) +{ + nf_ct_helper_destroy(ct); + local_bh_disable(); + + __nf_ct_delete_from_lists(ct); + + local_bh_enable(); +} + +static void nf_ct_add_to_ecache_list(struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CONNTRACK_EVENTS + struct nf_conntrack_net *cnet = nf_ct_pernet(nf_ct_net(ct)); + + spin_lock(&cnet->ecache.dying_lock); + hlist_nulls_add_head_rcu(&ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode, + &cnet->ecache.dying_list); + spin_unlock(&cnet->ecache.dying_lock); +#endif +} + +bool nf_ct_delete(struct nf_conn *ct, u32 portid, int report) +{ + struct nf_conn_tstamp *tstamp; + struct net *net; + + if (test_and_set_bit(IPS_DYING_BIT, &ct->status)) + return false; + + tstamp = nf_conn_tstamp_find(ct); + if (tstamp) { + s32 timeout = READ_ONCE(ct->timeout) - nfct_time_stamp; + + tstamp->stop = ktime_get_real_ns(); + if (timeout < 0) + tstamp->stop -= jiffies_to_nsecs(-timeout); + } + + if (nf_conntrack_event_report(IPCT_DESTROY, ct, + portid, report) < 0) { + /* destroy event was not delivered. nf_ct_put will + * be done by event cache worker on redelivery. 
+ */ + nf_ct_helper_destroy(ct); + local_bh_disable(); + __nf_ct_delete_from_lists(ct); + nf_ct_add_to_ecache_list(ct); + local_bh_enable(); + + nf_conntrack_ecache_work(nf_ct_net(ct), NFCT_ECACHE_DESTROY_FAIL); + return false; + } + + net = nf_ct_net(ct); + if (nf_conntrack_ecache_dwork_pending(net)) + nf_conntrack_ecache_work(net, NFCT_ECACHE_DESTROY_SENT); + nf_ct_delete_from_lists(ct); + nf_ct_put(ct); + return true; +} +EXPORT_SYMBOL_GPL(nf_ct_delete); + +static inline bool +nf_ct_key_equal(struct nf_conntrack_tuple_hash *h, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_zone *zone, + const struct net *net) +{ + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + /* A conntrack can be recreated with the equal tuple, + * so we need to check that the conntrack is confirmed + */ + return nf_ct_tuple_equal(tuple, &h->tuple) && + nf_ct_zone_equal(ct, zone, NF_CT_DIRECTION(h)) && + nf_ct_is_confirmed(ct) && + net_eq(net, nf_ct_net(ct)); +} + +static inline bool +nf_ct_match(const struct nf_conn *ct1, const struct nf_conn *ct2) +{ + return nf_ct_tuple_equal(&ct1->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + &ct2->tuplehash[IP_CT_DIR_ORIGINAL].tuple) && + nf_ct_tuple_equal(&ct1->tuplehash[IP_CT_DIR_REPLY].tuple, + &ct2->tuplehash[IP_CT_DIR_REPLY].tuple) && + nf_ct_zone_equal(ct1, nf_ct_zone(ct2), IP_CT_DIR_ORIGINAL) && + nf_ct_zone_equal(ct1, nf_ct_zone(ct2), IP_CT_DIR_REPLY) && + net_eq(nf_ct_net(ct1), nf_ct_net(ct2)); +} + +/* caller must hold rcu readlock and none of the nf_conntrack_locks */ +static void nf_ct_gc_expired(struct nf_conn *ct) +{ + if (!refcount_inc_not_zero(&ct->ct_general.use)) + return; + + /* load ->status after refcount increase */ + smp_acquire__after_ctrl_dep(); + + if (nf_ct_should_gc(ct)) + nf_ct_kill(ct); + + nf_ct_put(ct); +} + +/* + * Warning : + * - Caller must take a reference on returned object + * and recheck nf_ct_tuple_equal(tuple, &h->tuple) + */ +static struct nf_conntrack_tuple_hash * +____nf_conntrack_find(struct net *net, const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple, u32 hash) +{ + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_head *ct_hash; + struct hlist_nulls_node *n; + unsigned int bucket, hsize; + +begin: + nf_conntrack_get_ht(&ct_hash, &hsize); + bucket = reciprocal_scale(hash, hsize); + + hlist_nulls_for_each_entry_rcu(h, n, &ct_hash[bucket], hnnode) { + struct nf_conn *ct; + + ct = nf_ct_tuplehash_to_ctrack(h); + if (nf_ct_is_expired(ct)) { + nf_ct_gc_expired(ct); + continue; + } + + if (nf_ct_key_equal(h, tuple, zone, net)) + return h; + } + /* + * if the nulls value we got at the end of this lookup is + * not the expected one, we must restart lookup. + * We probably met an item that was moved to another chain. + */ + if (get_nulls_value(n) != bucket) { + NF_CT_STAT_INC_ATOMIC(net, search_restart); + goto begin; + } + + return NULL; +} + +/* Find a connection corresponding to a tuple. 
*/ +static struct nf_conntrack_tuple_hash * +__nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple, u32 hash) +{ + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + + rcu_read_lock(); + + h = ____nf_conntrack_find(net, zone, tuple, hash); + if (h) { + /* We have a candidate that matches the tuple we're interested + * in, try to obtain a reference and re-check tuple + */ + ct = nf_ct_tuplehash_to_ctrack(h); + if (likely(refcount_inc_not_zero(&ct->ct_general.use))) { + /* re-check key after refcount */ + smp_acquire__after_ctrl_dep(); + + if (likely(nf_ct_key_equal(h, tuple, zone, net))) + goto found; + + /* TYPESAFE_BY_RCU recycled the candidate */ + nf_ct_put(ct); + } + + h = NULL; + } +found: + rcu_read_unlock(); + + return h; +} + +struct nf_conntrack_tuple_hash * +nf_conntrack_find_get(struct net *net, const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple) +{ + unsigned int rid, zone_id = nf_ct_zone_id(zone, IP_CT_DIR_ORIGINAL); + struct nf_conntrack_tuple_hash *thash; + + thash = __nf_conntrack_find_get(net, zone, tuple, + hash_conntrack_raw(tuple, zone_id, net)); + + if (thash) + return thash; + + rid = nf_ct_zone_id(zone, IP_CT_DIR_REPLY); + if (rid != zone_id) + return __nf_conntrack_find_get(net, zone, tuple, + hash_conntrack_raw(tuple, rid, net)); + return thash; +} +EXPORT_SYMBOL_GPL(nf_conntrack_find_get); + +static void __nf_conntrack_hash_insert(struct nf_conn *ct, + unsigned int hash, + unsigned int reply_hash) +{ + hlist_nulls_add_head_rcu(&ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode, + &nf_conntrack_hash[hash]); + hlist_nulls_add_head_rcu(&ct->tuplehash[IP_CT_DIR_REPLY].hnnode, + &nf_conntrack_hash[reply_hash]); +} + +static bool nf_ct_ext_valid_pre(const struct nf_ct_ext *ext) +{ + /* if ext->gen_id is not equal to nf_conntrack_ext_genid, some extensions + * may contain stale pointers to e.g. helper that has been removed. + * + * The helper can't clear this because the nf_conn object isn't in + * any hash and synchronize_rcu() isn't enough because associated skb + * might sit in a queue. + */ + return !ext || ext->gen_id == atomic_read(&nf_conntrack_ext_genid); +} + +static bool nf_ct_ext_valid_post(struct nf_ct_ext *ext) +{ + if (!ext) + return true; + + if (ext->gen_id != atomic_read(&nf_conntrack_ext_genid)) + return false; + + /* inserted into conntrack table, nf_ct_iterate_cleanup() + * will find it. Disable nf_ct_ext_find() id check. 
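/* Editor's illustrative sketch (not part of this patch):
 * nf_conntrack_find_get() above hands back a reference-counted tuple hash.
 * A typical consumer converts it to the nf_conn, reads what it needs and
 * drops the reference, much as __bpf_nf_ct_lookup() earlier in this patch
 * does.  Hypothetical example:
 */
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_zones.h>

static bool example_tuple_is_assured(struct net *net,
                                     const struct nf_conntrack_tuple *tuple)
{
        struct nf_conntrack_tuple_hash *h;
        struct nf_conn *ct;
        bool assured;

        h = nf_conntrack_find_get(net, &nf_ct_zone_dflt, tuple);
        if (!h)
                return false;

        ct = nf_ct_tuplehash_to_ctrack(h);
        assured = test_bit(IPS_ASSURED_BIT, &ct->status);
        nf_ct_put(ct);          /* drop the reference taken by find_get */

        return assured;
}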
+ */ + WRITE_ONCE(ext->gen_id, 0); + return true; +} + +int +nf_conntrack_hash_check_insert(struct nf_conn *ct) +{ + const struct nf_conntrack_zone *zone; + struct net *net = nf_ct_net(ct); + unsigned int hash, reply_hash; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; + unsigned int max_chainlen; + unsigned int chainlen = 0; + unsigned int sequence; + int err = -EEXIST; + + zone = nf_ct_zone(ct); + + if (!nf_ct_ext_valid_pre(ct->ext)) + return -EAGAIN; + + local_bh_disable(); + do { + sequence = read_seqcount_begin(&nf_conntrack_generation); + hash = hash_conntrack(net, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_ORIGINAL)); + reply_hash = hash_conntrack(net, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple, + nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_REPLY)); + } while (nf_conntrack_double_lock(net, hash, reply_hash, sequence)); + + max_chainlen = MIN_CHAINLEN + prandom_u32_max(MAX_CHAINLEN); + + /* See if there's one in the list already, including reverse */ + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[hash], hnnode) { + if (nf_ct_key_equal(h, &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + zone, net)) + goto out; + + if (chainlen++ > max_chainlen) + goto chaintoolong; + } + + chainlen = 0; + + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[reply_hash], hnnode) { + if (nf_ct_key_equal(h, &ct->tuplehash[IP_CT_DIR_REPLY].tuple, + zone, net)) + goto out; + if (chainlen++ > max_chainlen) + goto chaintoolong; + } + + /* If genid has changed, we can't insert anymore because ct + * extensions could have stale pointers and nf_ct_iterate_destroy + * might have completed its table scan already. + * + * Increment of the ext genid right after this check is fine: + * nf_ct_iterate_destroy blocks until locks are released. + */ + if (!nf_ct_ext_valid_post(ct->ext)) { + err = -EAGAIN; + goto out; + } + + smp_wmb(); + /* The caller holds a reference to this object */ + refcount_set(&ct->ct_general.use, 2); + __nf_conntrack_hash_insert(ct, hash, reply_hash); + nf_conntrack_double_unlock(hash, reply_hash); + NF_CT_STAT_INC(net, insert); + local_bh_enable(); + + return 0; +chaintoolong: + NF_CT_STAT_INC(net, chaintoolong); + err = -ENOSPC; +out: + nf_conntrack_double_unlock(hash, reply_hash); + local_bh_enable(); + return err; +} +EXPORT_SYMBOL_GPL(nf_conntrack_hash_check_insert); + +void nf_ct_acct_add(struct nf_conn *ct, u32 dir, unsigned int packets, + unsigned int bytes) +{ + struct nf_conn_acct *acct; + + acct = nf_conn_acct_find(ct); + if (acct) { + struct nf_conn_counter *counter = acct->counter; + + atomic64_add(packets, &counter[dir].packets); + atomic64_add(bytes, &counter[dir].bytes); + } +} +EXPORT_SYMBOL_GPL(nf_ct_acct_add); + +static void nf_ct_acct_merge(struct nf_conn *ct, enum ip_conntrack_info ctinfo, + const struct nf_conn *loser_ct) +{ + struct nf_conn_acct *acct; + + acct = nf_conn_acct_find(loser_ct); + if (acct) { + struct nf_conn_counter *counter = acct->counter; + unsigned int bytes; + + /* u32 should be fine since we must have seen one packet. */ + bytes = atomic64_read(&counter[CTINFO2DIR(ctinfo)].bytes); + nf_ct_acct_update(ct, CTINFO2DIR(ctinfo), bytes); + } +} + +static void __nf_conntrack_insert_prepare(struct nf_conn *ct) +{ + struct nf_conn_tstamp *tstamp; + + refcount_inc(&ct->ct_general.use); + + /* set conntrack timestamp, if enabled. 
*/ + tstamp = nf_conn_tstamp_find(ct); + if (tstamp) + tstamp->start = ktime_get_real_ns(); +} + +/* caller must hold locks to prevent concurrent changes */ +static int __nf_ct_resolve_clash(struct sk_buff *skb, + struct nf_conntrack_tuple_hash *h) +{ + /* This is the conntrack entry already in hashes that won race. */ + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + enum ip_conntrack_info ctinfo; + struct nf_conn *loser_ct; + + loser_ct = nf_ct_get(skb, &ctinfo); + + if (nf_ct_is_dying(ct)) + return NF_DROP; + + if (((ct->status & IPS_NAT_DONE_MASK) == 0) || + nf_ct_match(ct, loser_ct)) { + struct net *net = nf_ct_net(ct); + + nf_conntrack_get(&ct->ct_general); + + nf_ct_acct_merge(ct, ctinfo, loser_ct); + nf_ct_put(loser_ct); + nf_ct_set(skb, ct, ctinfo); + + NF_CT_STAT_INC(net, clash_resolve); + return NF_ACCEPT; + } + + return NF_DROP; +} + +/** + * nf_ct_resolve_clash_harder - attempt to insert clashing conntrack entry + * + * @skb: skb that causes the collision + * @repl_idx: hash slot for reply direction + * + * Called when origin or reply direction had a clash. + * The skb can be handled without packet drop provided the reply direction + * is unique or there the existing entry has the identical tuple in both + * directions. + * + * Caller must hold conntrack table locks to prevent concurrent updates. + * + * Returns NF_DROP if the clash could not be handled. + */ +static int nf_ct_resolve_clash_harder(struct sk_buff *skb, u32 repl_idx) +{ + struct nf_conn *loser_ct = (struct nf_conn *)skb_nfct(skb); + const struct nf_conntrack_zone *zone; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; + struct net *net; + + zone = nf_ct_zone(loser_ct); + net = nf_ct_net(loser_ct); + + /* Reply direction must never result in a clash, unless both origin + * and reply tuples are identical. + */ + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[repl_idx], hnnode) { + if (nf_ct_key_equal(h, + &loser_ct->tuplehash[IP_CT_DIR_REPLY].tuple, + zone, net)) + return __nf_ct_resolve_clash(skb, h); + } + + /* We want the clashing entry to go away real soon: 1 second timeout. */ + WRITE_ONCE(loser_ct->timeout, nfct_time_stamp + HZ); + + /* IPS_NAT_CLASH removes the entry automatically on the first + * reply. Also prevents UDP tracker from moving the entry to + * ASSURED state, i.e. the entry can always be evicted under + * pressure. + */ + loser_ct->status |= IPS_FIXED_TIMEOUT | IPS_NAT_CLASH; + + __nf_conntrack_insert_prepare(loser_ct); + + /* fake add for ORIGINAL dir: we want lookups to only find the entry + * already in the table. This also hides the clashing entry from + * ctnetlink iteration, i.e. conntrack -L won't show them. + */ + hlist_nulls_add_fake(&loser_ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode); + + hlist_nulls_add_head_rcu(&loser_ct->tuplehash[IP_CT_DIR_REPLY].hnnode, + &nf_conntrack_hash[repl_idx]); + + NF_CT_STAT_INC(net, clash_resolve); + return NF_ACCEPT; +} + +/** + * nf_ct_resolve_clash - attempt to handle clash without packet drop + * + * @skb: skb that causes the clash + * @h: tuplehash of the clashing entry already in table + * @reply_hash: hash slot for reply direction + * + * A conntrack entry can be inserted to the connection tracking table + * if there is no existing entry with an identical tuple. + * + * If there is one, @skb (and the assocated, unconfirmed conntrack) has + * to be dropped. In case @skb is retransmitted, next conntrack lookup + * will find the already-existing entry. 
+ * + * The major problem with such packet drop is the extra delay added by + * the packet loss -- it will take some time for a retransmit to occur + * (or the sender to time out when waiting for a reply). + * + * This function attempts to handle the situation without packet drop. + * + * If @skb has no NAT transformation or if the colliding entries are + * exactly the same, only the to-be-confirmed conntrack entry is discarded + * and @skb is associated with the conntrack entry already in the table. + * + * Failing that, the new, unconfirmed conntrack is still added to the table + * provided that the collision only occurs in the ORIGINAL direction. + * The new entry will be added only in the non-clashing REPLY direction, + * so packets in the ORIGINAL direction will continue to match the existing + * entry. The new entry will also have a fixed timeout so it expires -- + * due to the collision, it will only see reply traffic. + * + * Returns NF_DROP if the clash could not be resolved. + */ +static __cold noinline int +nf_ct_resolve_clash(struct sk_buff *skb, struct nf_conntrack_tuple_hash *h, + u32 reply_hash) +{ + /* This is the conntrack entry already in hashes that won race. */ + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + const struct nf_conntrack_l4proto *l4proto; + enum ip_conntrack_info ctinfo; + struct nf_conn *loser_ct; + struct net *net; + int ret; + + loser_ct = nf_ct_get(skb, &ctinfo); + net = nf_ct_net(loser_ct); + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + if (!l4proto->allow_clash) + goto drop; + + ret = __nf_ct_resolve_clash(skb, h); + if (ret == NF_ACCEPT) + return ret; + + ret = nf_ct_resolve_clash_harder(skb, reply_hash); + if (ret == NF_ACCEPT) + return ret; + +drop: + NF_CT_STAT_INC(net, drop); + NF_CT_STAT_INC(net, insert_failed); + return NF_DROP; +} + +/* Confirm a connection given skb; places it in hash table */ +int +__nf_conntrack_confirm(struct sk_buff *skb) +{ + unsigned int chainlen = 0, sequence, max_chainlen; + const struct nf_conntrack_zone *zone; + unsigned int hash, reply_hash; + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + struct nf_conn_help *help; + struct hlist_nulls_node *n; + enum ip_conntrack_info ctinfo; + struct net *net; + int ret = NF_DROP; + + ct = nf_ct_get(skb, &ctinfo); + net = nf_ct_net(ct); + + /* ipt_REJECT uses nf_conntrack_attach to attach related + ICMP/TCP RST packets in other direction. Actual packet + which created connection will be IP_CT_NEW or for an + expected connection, IP_CT_RELATED. */ + if (CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) + return NF_ACCEPT; + + zone = nf_ct_zone(ct); + local_bh_disable(); + + do { + sequence = read_seqcount_begin(&nf_conntrack_generation); + /* reuse the hash saved before */ + hash = *(unsigned long *)&ct->tuplehash[IP_CT_DIR_REPLY].hnnode.pprev; + hash = scale_hash(hash); + reply_hash = hash_conntrack(net, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple, + nf_ct_zone_id(nf_ct_zone(ct), IP_CT_DIR_REPLY)); + } while (nf_conntrack_double_lock(net, hash, reply_hash, sequence)); + + /* We're not in hash table, and we refuse to set up related + * connections for unconfirmed conns. But packet copies and + * REJECT will give spurious warnings here. + */ + + /* Another skb with the same unconfirmed conntrack may + * win the race. This may happen for bridge(br_flood) + * or broadcast/multicast packets do skb_clone with + * unconfirmed conntrack. 
+ */ + if (unlikely(nf_ct_is_confirmed(ct))) { + WARN_ON_ONCE(1); + nf_conntrack_double_unlock(hash, reply_hash); + local_bh_enable(); + return NF_DROP; + } + + if (!nf_ct_ext_valid_pre(ct->ext)) { + NF_CT_STAT_INC(net, insert_failed); + goto dying; + } + + pr_debug("Confirming conntrack %p\n", ct); + /* We have to check the DYING flag after unlink to prevent + * a race against nf_ct_get_next_corpse() possibly called from + * user context, else we insert an already 'dead' hash, blocking + * further use of that particular connection -JM. + */ + ct->status |= IPS_CONFIRMED; + + if (unlikely(nf_ct_is_dying(ct))) { + NF_CT_STAT_INC(net, insert_failed); + goto dying; + } + + max_chainlen = MIN_CHAINLEN + prandom_u32_max(MAX_CHAINLEN); + /* See if there's one in the list already, including reverse: + NAT could have grabbed it without realizing, since we're + not in the hash. If there is, we lost race. */ + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[hash], hnnode) { + if (nf_ct_key_equal(h, &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + zone, net)) + goto out; + if (chainlen++ > max_chainlen) + goto chaintoolong; + } + + chainlen = 0; + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[reply_hash], hnnode) { + if (nf_ct_key_equal(h, &ct->tuplehash[IP_CT_DIR_REPLY].tuple, + zone, net)) + goto out; + if (chainlen++ > max_chainlen) { +chaintoolong: + NF_CT_STAT_INC(net, chaintoolong); + NF_CT_STAT_INC(net, insert_failed); + ret = NF_DROP; + goto dying; + } + } + + /* Timer relative to confirmation time, not original + setting time, otherwise we'd get timer wrap in + weird delay cases. */ + ct->timeout += nfct_time_stamp; + + __nf_conntrack_insert_prepare(ct); + + /* Since the lookup is lockless, hash insertion must be done after + * starting the timer and setting the CONFIRMED bit. The RCU barriers + * guarantee that no other CPU can find the conntrack before the above + * stores are visible. + */ + __nf_conntrack_hash_insert(ct, hash, reply_hash); + nf_conntrack_double_unlock(hash, reply_hash); + local_bh_enable(); + + /* ext area is still valid (rcu read lock is held, + * but will go out of scope soon, we need to remove + * this conntrack again. + */ + if (!nf_ct_ext_valid_post(ct->ext)) { + nf_ct_kill(ct); + NF_CT_STAT_INC_ATOMIC(net, drop); + return NF_DROP; + } + + help = nfct_help(ct); + if (help && help->helper) + nf_conntrack_event_cache(IPCT_HELPER, ct); + + nf_conntrack_event_cache(master_ct(ct) ? + IPCT_RELATED : IPCT_NEW, ct); + return NF_ACCEPT; + +out: + ret = nf_ct_resolve_clash(skb, h, reply_hash); +dying: + nf_conntrack_double_unlock(hash, reply_hash); + local_bh_enable(); + return ret; +} +EXPORT_SYMBOL_GPL(__nf_conntrack_confirm); + +/* Returns true if a connection correspondings to the tuple (required + for NAT). 
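+   As the inline comments further down note, a hit means the NAT caller
+   must pick another source port; the only tolerated collision is two
+   conntracks that share the same original tuple, which is left for
+   nf_ct_resolve_clash() to sort out.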
*/ +int +nf_conntrack_tuple_taken(const struct nf_conntrack_tuple *tuple, + const struct nf_conn *ignored_conntrack) +{ + struct net *net = nf_ct_net(ignored_conntrack); + const struct nf_conntrack_zone *zone; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_head *ct_hash; + unsigned int hash, hsize; + struct hlist_nulls_node *n; + struct nf_conn *ct; + + zone = nf_ct_zone(ignored_conntrack); + + rcu_read_lock(); + begin: + nf_conntrack_get_ht(&ct_hash, &hsize); + hash = __hash_conntrack(net, tuple, nf_ct_zone_id(zone, IP_CT_DIR_REPLY), hsize); + + hlist_nulls_for_each_entry_rcu(h, n, &ct_hash[hash], hnnode) { + ct = nf_ct_tuplehash_to_ctrack(h); + + if (ct == ignored_conntrack) + continue; + + if (nf_ct_is_expired(ct)) { + nf_ct_gc_expired(ct); + continue; + } + + if (nf_ct_key_equal(h, tuple, zone, net)) { + /* Tuple is taken already, so caller will need to find + * a new source port to use. + * + * Only exception: + * If the *original tuples* are identical, then both + * conntracks refer to the same flow. + * This is a rare situation, it can occur e.g. when + * more than one UDP packet is sent from same socket + * in different threads. + * + * Let nf_ct_resolve_clash() deal with this later. + */ + if (nf_ct_tuple_equal(&ignored_conntrack->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple) && + nf_ct_zone_equal(ct, zone, IP_CT_DIR_ORIGINAL)) + continue; + + NF_CT_STAT_INC_ATOMIC(net, found); + rcu_read_unlock(); + return 1; + } + } + + if (get_nulls_value(n) != hash) { + NF_CT_STAT_INC_ATOMIC(net, search_restart); + goto begin; + } + + rcu_read_unlock(); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_conntrack_tuple_taken); + +#define NF_CT_EVICTION_RANGE 8 + +/* There's a small race here where we may free a just-assured + connection. Too bad: we're in trouble anyway. */ +static unsigned int early_drop_list(struct net *net, + struct hlist_nulls_head *head) +{ + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; + unsigned int drops = 0; + struct nf_conn *tmp; + + hlist_nulls_for_each_entry_rcu(h, n, head, hnnode) { + tmp = nf_ct_tuplehash_to_ctrack(h); + + if (test_bit(IPS_OFFLOAD_BIT, &tmp->status)) + continue; + + if (nf_ct_is_expired(tmp)) { + nf_ct_gc_expired(tmp); + continue; + } + + if (test_bit(IPS_ASSURED_BIT, &tmp->status) || + !net_eq(nf_ct_net(tmp), net) || + nf_ct_is_dying(tmp)) + continue; + + if (!refcount_inc_not_zero(&tmp->ct_general.use)) + continue; + + /* load ->ct_net and ->status after refcount increase */ + smp_acquire__after_ctrl_dep(); + + /* kill only if still in same netns -- might have moved due to + * SLAB_TYPESAFE_BY_RCU rules. + * + * We steal the timer reference. If that fails timer has + * already fired or someone else deleted it. Just drop ref + * and move to next entry. 
+ */ + if (net_eq(nf_ct_net(tmp), net) && + nf_ct_is_confirmed(tmp) && + nf_ct_delete(tmp, 0, 0)) + drops++; + + nf_ct_put(tmp); + } + + return drops; +} + +static noinline int early_drop(struct net *net, unsigned int hash) +{ + unsigned int i, bucket; + + for (i = 0; i < NF_CT_EVICTION_RANGE; i++) { + struct hlist_nulls_head *ct_hash; + unsigned int hsize, drops; + + rcu_read_lock(); + nf_conntrack_get_ht(&ct_hash, &hsize); + if (!i) + bucket = reciprocal_scale(hash, hsize); + else + bucket = (bucket + 1) % hsize; + + drops = early_drop_list(net, &ct_hash[bucket]); + rcu_read_unlock(); + + if (drops) { + NF_CT_STAT_ADD_ATOMIC(net, early_drop, drops); + return true; + } + } + + return false; +} + +static bool gc_worker_skip_ct(const struct nf_conn *ct) +{ + return !nf_ct_is_confirmed(ct) || nf_ct_is_dying(ct); +} + +static bool gc_worker_can_early_drop(const struct nf_conn *ct) +{ + const struct nf_conntrack_l4proto *l4proto; + + if (!test_bit(IPS_ASSURED_BIT, &ct->status)) + return true; + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + if (l4proto->can_early_drop && l4proto->can_early_drop(ct)) + return true; + + return false; +} + +static void gc_worker(struct work_struct *work) +{ + unsigned int i, hashsz, nf_conntrack_max95 = 0; + u32 end_time, start_time = nfct_time_stamp; + struct conntrack_gc_work *gc_work; + unsigned int expired_count = 0; + unsigned long next_run; + s32 delta_time; + long count; + + gc_work = container_of(work, struct conntrack_gc_work, dwork.work); + + i = gc_work->next_bucket; + if (gc_work->early_drop) + nf_conntrack_max95 = nf_conntrack_max / 100u * 95u; + + if (i == 0) { + gc_work->avg_timeout = GC_SCAN_INTERVAL_INIT; + gc_work->count = GC_SCAN_INITIAL_COUNT; + gc_work->start_time = start_time; + } + + next_run = gc_work->avg_timeout; + count = gc_work->count; + + end_time = start_time + GC_SCAN_MAX_DURATION; + + do { + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_head *ct_hash; + struct hlist_nulls_node *n; + struct nf_conn *tmp; + + rcu_read_lock(); + + nf_conntrack_get_ht(&ct_hash, &hashsz); + if (i >= hashsz) { + rcu_read_unlock(); + break; + } + + hlist_nulls_for_each_entry_rcu(h, n, &ct_hash[i], hnnode) { + struct nf_conntrack_net *cnet; + struct net *net; + long expires; + + tmp = nf_ct_tuplehash_to_ctrack(h); + + if (test_bit(IPS_OFFLOAD_BIT, &tmp->status)) { + nf_ct_offload_timeout(tmp); + continue; + } + + if (expired_count > GC_SCAN_EXPIRED_MAX) { + rcu_read_unlock(); + + gc_work->next_bucket = i; + gc_work->avg_timeout = next_run; + gc_work->count = count; + + delta_time = nfct_time_stamp - gc_work->start_time; + + /* re-sched immediately if total cycle time is exceeded */ + next_run = delta_time < (s32)GC_SCAN_INTERVAL_MAX; + goto early_exit; + } + + if (nf_ct_is_expired(tmp)) { + nf_ct_gc_expired(tmp); + expired_count++; + continue; + } + + expires = clamp(nf_ct_expires(tmp), GC_SCAN_INTERVAL_MIN, GC_SCAN_INTERVAL_CLAMP); + expires = (expires - (long)next_run) / ++count; + next_run += expires; + + if (nf_conntrack_max95 == 0 || gc_worker_skip_ct(tmp)) + continue; + + net = nf_ct_net(tmp); + cnet = nf_ct_pernet(net); + if (atomic_read(&cnet->count) < nf_conntrack_max95) + continue; + + /* need to take reference to avoid possible races */ + if (!refcount_inc_not_zero(&tmp->ct_general.use)) + continue; + + /* load ->status after refcount increase */ + smp_acquire__after_ctrl_dep(); + + if (gc_worker_skip_ct(tmp)) { + nf_ct_put(tmp); + continue; + } + + if (gc_worker_can_early_drop(tmp)) { + nf_ct_kill(tmp); + expired_count++; + } 
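+ /* drop the reference taken for the early-drop check above */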
+ + nf_ct_put(tmp); + } + + /* could check get_nulls_value() here and restart if ct + * was moved to another chain. But given gc is best-effort + * we will just continue with next hash slot. + */ + rcu_read_unlock(); + cond_resched(); + i++; + + delta_time = nfct_time_stamp - end_time; + if (delta_time > 0 && i < hashsz) { + gc_work->avg_timeout = next_run; + gc_work->count = count; + gc_work->next_bucket = i; + next_run = 0; + goto early_exit; + } + } while (i < hashsz); + + gc_work->next_bucket = 0; + + next_run = clamp(next_run, GC_SCAN_INTERVAL_MIN, GC_SCAN_INTERVAL_MAX); + + delta_time = max_t(s32, nfct_time_stamp - gc_work->start_time, 1); + if (next_run > (unsigned long)delta_time) + next_run -= delta_time; + else + next_run = 1; + +early_exit: + if (gc_work->exiting) + return; + + if (next_run) + gc_work->early_drop = false; + + queue_delayed_work(system_power_efficient_wq, &gc_work->dwork, next_run); +} + +static void conntrack_gc_work_init(struct conntrack_gc_work *gc_work) +{ + INIT_DELAYED_WORK(&gc_work->dwork, gc_worker); + gc_work->exiting = false; +} + +static struct nf_conn * +__nf_conntrack_alloc(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *orig, + const struct nf_conntrack_tuple *repl, + gfp_t gfp, u32 hash) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + unsigned int ct_count; + struct nf_conn *ct; + + /* We don't want any race condition at early drop stage */ + ct_count = atomic_inc_return(&cnet->count); + + if (nf_conntrack_max && unlikely(ct_count > nf_conntrack_max)) { + if (!early_drop(net, hash)) { + if (!conntrack_gc_work.early_drop) + conntrack_gc_work.early_drop = true; + atomic_dec(&cnet->count); + net_warn_ratelimited("nf_conntrack: table full, dropping packet\n"); + return ERR_PTR(-ENOMEM); + } + } + + /* + * Do not use kmem_cache_zalloc(), as this cache uses + * SLAB_TYPESAFE_BY_RCU. + */ + ct = kmem_cache_alloc(nf_conntrack_cachep, gfp); + if (ct == NULL) + goto out; + + spin_lock_init(&ct->lock); + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple = *orig; + ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode.pprev = NULL; + ct->tuplehash[IP_CT_DIR_REPLY].tuple = *repl; + /* save hash for reusing when confirming */ + *(unsigned long *)(&ct->tuplehash[IP_CT_DIR_REPLY].hnnode.pprev) = hash; + ct->status = 0; + WRITE_ONCE(ct->timeout, 0); + write_pnet(&ct->ct_net, net); + memset_after(ct, 0, __nfct_init_offset); + + nf_ct_zone_add(ct, zone); + + /* Because we use RCU lookups, we set ct_general.use to zero before + * this is inserted in any list. 
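+ * Another CPU may still hold a pointer to this memory from a previous
+ * incarnation of the object (SLAB_TYPESAFE_BY_RCU); with the refcount
+ * at zero its refcount_inc_not_zero() fails, so the stale reference is
+ * simply skipped.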
+ */ + refcount_set(&ct->ct_general.use, 0); + return ct; +out: + atomic_dec(&cnet->count); + return ERR_PTR(-ENOMEM); +} + +struct nf_conn *nf_conntrack_alloc(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *orig, + const struct nf_conntrack_tuple *repl, + gfp_t gfp) +{ + return __nf_conntrack_alloc(net, zone, orig, repl, gfp, 0); +} +EXPORT_SYMBOL_GPL(nf_conntrack_alloc); + +void nf_conntrack_free(struct nf_conn *ct) +{ + struct net *net = nf_ct_net(ct); + struct nf_conntrack_net *cnet; + + /* A freed object has refcnt == 0, that's + * the golden rule for SLAB_TYPESAFE_BY_RCU + */ + WARN_ON(refcount_read(&ct->ct_general.use) != 0); + + if (ct->status & IPS_SRC_NAT_DONE) { + const struct nf_nat_hook *nat_hook; + + rcu_read_lock(); + nat_hook = rcu_dereference(nf_nat_hook); + if (nat_hook) + nat_hook->remove_nat_bysrc(ct); + rcu_read_unlock(); + } + + kfree(ct->ext); + kmem_cache_free(nf_conntrack_cachep, ct); + cnet = nf_ct_pernet(net); + + smp_mb__before_atomic(); + atomic_dec(&cnet->count); +} +EXPORT_SYMBOL_GPL(nf_conntrack_free); + + +/* Allocate a new conntrack: we return -ENOMEM if classification + failed due to stress. Otherwise it really is unclassifiable. */ +static noinline struct nf_conntrack_tuple_hash * +init_conntrack(struct net *net, struct nf_conn *tmpl, + const struct nf_conntrack_tuple *tuple, + struct sk_buff *skb, + unsigned int dataoff, u32 hash) +{ + struct nf_conn *ct; + struct nf_conn_help *help; + struct nf_conntrack_tuple repl_tuple; +#ifdef CONFIG_NF_CONNTRACK_EVENTS + struct nf_conntrack_ecache *ecache; +#endif + struct nf_conntrack_expect *exp = NULL; + const struct nf_conntrack_zone *zone; + struct nf_conn_timeout *timeout_ext; + struct nf_conntrack_zone tmp; + struct nf_conntrack_net *cnet; + + if (!nf_ct_invert_tuple(&repl_tuple, tuple)) { + pr_debug("Can't invert tuple.\n"); + return NULL; + } + + zone = nf_ct_zone_tmpl(tmpl, skb, &tmp); + ct = __nf_conntrack_alloc(net, zone, tuple, &repl_tuple, GFP_ATOMIC, + hash); + if (IS_ERR(ct)) + return (struct nf_conntrack_tuple_hash *)ct; + + if (!nf_ct_add_synproxy(ct, tmpl)) { + nf_conntrack_free(ct); + return ERR_PTR(-ENOMEM); + } + + timeout_ext = tmpl ? nf_ct_timeout_find(tmpl) : NULL; + + if (timeout_ext) + nf_ct_timeout_ext_add(ct, rcu_dereference(timeout_ext->timeout), + GFP_ATOMIC); + + nf_ct_acct_ext_add(ct, GFP_ATOMIC); + nf_ct_tstamp_ext_add(ct, GFP_ATOMIC); + nf_ct_labels_ext_add(ct); + +#ifdef CONFIG_NF_CONNTRACK_EVENTS + ecache = tmpl ? nf_ct_ecache_find(tmpl) : NULL; + + if ((ecache || net->ct.sysctl_events) && + !nf_ct_ecache_ext_add(ct, ecache ? ecache->ctmask : 0, + ecache ? ecache->expmask : 0, + GFP_ATOMIC)) { + nf_conntrack_free(ct); + return ERR_PTR(-ENOMEM); + } +#endif + + cnet = nf_ct_pernet(net); + if (cnet->expect_count) { + spin_lock_bh(&nf_conntrack_expect_lock); + exp = nf_ct_find_expectation(net, zone, tuple); + if (exp) { + pr_debug("expectation arrives ct=%p exp=%p\n", + ct, exp); + /* Welcome, Mr. Bond. We've been expecting you... 
*/ + __set_bit(IPS_EXPECTED_BIT, &ct->status); + /* exp->master safe, refcnt bumped in nf_ct_find_expectation */ + ct->master = exp->master; + if (exp->helper) { + help = nf_ct_helper_ext_add(ct, GFP_ATOMIC); + if (help) + rcu_assign_pointer(help->helper, exp->helper); + } + +#ifdef CONFIG_NF_CONNTRACK_MARK + ct->mark = READ_ONCE(exp->master->mark); +#endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + ct->secmark = exp->master->secmark; +#endif + NF_CT_STAT_INC(net, expect_new); + } + spin_unlock_bh(&nf_conntrack_expect_lock); + } + if (!exp && tmpl) + __nf_ct_try_assign_helper(ct, tmpl, GFP_ATOMIC); + + /* Other CPU might have obtained a pointer to this object before it was + * released. Because refcount is 0, refcount_inc_not_zero() will fail. + * + * After refcount_set(1) it will succeed; ensure that zeroing of + * ct->status and the correct ct->net pointer are visible; else other + * core might observe CONFIRMED bit which means the entry is valid and + * in the hash table, but its not (anymore). + */ + smp_wmb(); + + /* Now it is going to be associated with an sk_buff, set refcount to 1. */ + refcount_set(&ct->ct_general.use, 1); + + if (exp) { + if (exp->expectfn) + exp->expectfn(ct, exp); + nf_ct_expect_put(exp); + } + + return &ct->tuplehash[IP_CT_DIR_ORIGINAL]; +} + +/* On success, returns 0, sets skb->_nfct | ctinfo */ +static int +resolve_normal_ct(struct nf_conn *tmpl, + struct sk_buff *skb, + unsigned int dataoff, + u_int8_t protonum, + const struct nf_hook_state *state) +{ + const struct nf_conntrack_zone *zone; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_tuple_hash *h; + enum ip_conntrack_info ctinfo; + struct nf_conntrack_zone tmp; + u32 hash, zone_id, rid; + struct nf_conn *ct; + + if (!nf_ct_get_tuple(skb, skb_network_offset(skb), + dataoff, state->pf, protonum, state->net, + &tuple)) { + pr_debug("Can't get tuple\n"); + return 0; + } + + /* look for tuple match */ + zone = nf_ct_zone_tmpl(tmpl, skb, &tmp); + + zone_id = nf_ct_zone_id(zone, IP_CT_DIR_ORIGINAL); + hash = hash_conntrack_raw(&tuple, zone_id, state->net); + h = __nf_conntrack_find_get(state->net, zone, &tuple, hash); + + if (!h) { + rid = nf_ct_zone_id(zone, IP_CT_DIR_REPLY); + if (zone_id != rid) { + u32 tmp = hash_conntrack_raw(&tuple, rid, state->net); + + h = __nf_conntrack_find_get(state->net, zone, &tuple, tmp); + } + } + + if (!h) { + h = init_conntrack(state->net, tmpl, &tuple, + skb, dataoff, hash); + if (!h) + return 0; + if (IS_ERR(h)) + return PTR_ERR(h); + } + ct = nf_ct_tuplehash_to_ctrack(h); + + /* It exists; we have (non-exclusive) reference. */ + if (NF_CT_DIRECTION(h) == IP_CT_DIR_REPLY) { + ctinfo = IP_CT_ESTABLISHED_REPLY; + } else { + /* Once we've had two way comms, always ESTABLISHED. */ + if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + pr_debug("normal packet for %p\n", ct); + ctinfo = IP_CT_ESTABLISHED; + } else if (test_bit(IPS_EXPECTED_BIT, &ct->status)) { + pr_debug("related packet for %p\n", ct); + ctinfo = IP_CT_RELATED; + } else { + pr_debug("new packet for %p\n", ct); + ctinfo = IP_CT_NEW; + } + } + nf_ct_set(skb, ct, ctinfo); + return 0; +} + +/* + * icmp packets need special treatment to handle error messages that are + * related to a connection. + * + * Callers need to check if skb has a conntrack assigned when this + * helper returns; in such case skb belongs to an already known connection. 
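+ *
+ * The caller-side pattern, roughly as used in nf_conntrack_in() below
+ * (statistics and bookkeeping omitted):
+ *
+ *   ret = nf_conntrack_handle_icmp(tmpl, skb, dataoff, protonum, state);
+ *   if (ret <= 0)
+ *           return -ret;
+ *   if (skb->_nfct)
+ *           goto out;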
+ */ +static unsigned int __cold +nf_conntrack_handle_icmp(struct nf_conn *tmpl, + struct sk_buff *skb, + unsigned int dataoff, + u8 protonum, + const struct nf_hook_state *state) +{ + int ret; + + if (state->pf == NFPROTO_IPV4 && protonum == IPPROTO_ICMP) + ret = nf_conntrack_icmpv4_error(tmpl, skb, dataoff, state); +#if IS_ENABLED(CONFIG_IPV6) + else if (state->pf == NFPROTO_IPV6 && protonum == IPPROTO_ICMPV6) + ret = nf_conntrack_icmpv6_error(tmpl, skb, dataoff, state); +#endif + else + return NF_ACCEPT; + + if (ret <= 0) + NF_CT_STAT_INC_ATOMIC(state->net, error); + + return ret; +} + +static int generic_packet(struct nf_conn *ct, struct sk_buff *skb, + enum ip_conntrack_info ctinfo) +{ + const unsigned int *timeout = nf_ct_timeout_lookup(ct); + + if (!timeout) + timeout = &nf_generic_pernet(nf_ct_net(ct))->timeout; + + nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); + return NF_ACCEPT; +} + +/* Returns verdict for packet, or -1 for invalid. */ +static int nf_conntrack_handle_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + switch (nf_ct_protonum(ct)) { + case IPPROTO_TCP: + return nf_conntrack_tcp_packet(ct, skb, dataoff, + ctinfo, state); + case IPPROTO_UDP: + return nf_conntrack_udp_packet(ct, skb, dataoff, + ctinfo, state); + case IPPROTO_ICMP: + return nf_conntrack_icmp_packet(ct, skb, ctinfo, state); +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: + return nf_conntrack_icmpv6_packet(ct, skb, ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: + return nf_conntrack_udplite_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: + return nf_conntrack_sctp_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: + return nf_conntrack_dccp_packet(ct, skb, dataoff, + ctinfo, state); +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + return nf_conntrack_gre_packet(ct, skb, dataoff, + ctinfo, state); +#endif + } + + return generic_packet(ct, skb, ctinfo); +} + +unsigned int +nf_conntrack_in(struct sk_buff *skb, const struct nf_hook_state *state) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct, *tmpl; + u_int8_t protonum; + int dataoff, ret; + + tmpl = nf_ct_get(skb, &ctinfo); + if (tmpl || ctinfo == IP_CT_UNTRACKED) { + /* Previously seen (loopback or untracked)? Ignore. */ + if ((tmpl && !nf_ct_is_template(tmpl)) || + ctinfo == IP_CT_UNTRACKED) + return NF_ACCEPT; + skb->_nfct = 0; + } + + /* rcu_read_lock()ed by nf_hook_thresh */ + dataoff = get_l4proto(skb, skb_network_offset(skb), state->pf, &protonum); + if (dataoff <= 0) { + pr_debug("not prepared to track yet or error occurred\n"); + NF_CT_STAT_INC_ATOMIC(state->net, invalid); + ret = NF_ACCEPT; + goto out; + } + + if (protonum == IPPROTO_ICMP || protonum == IPPROTO_ICMPV6) { + ret = nf_conntrack_handle_icmp(tmpl, skb, dataoff, + protonum, state); + if (ret <= 0) { + ret = -ret; + goto out; + } + /* ICMP[v6] protocol trackers may assign one conntrack. */ + if (skb->_nfct) + goto out; + } +repeat: + ret = resolve_normal_ct(tmpl, skb, dataoff, + protonum, state); + if (ret < 0) { + /* Too stressed to deal. 
*/ + NF_CT_STAT_INC_ATOMIC(state->net, drop); + ret = NF_DROP; + goto out; + } + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) { + /* Not valid part of a connection */ + NF_CT_STAT_INC_ATOMIC(state->net, invalid); + ret = NF_ACCEPT; + goto out; + } + + ret = nf_conntrack_handle_packet(ct, skb, dataoff, ctinfo, state); + if (ret <= 0) { + /* Invalid: inverse of the return code tells + * the netfilter core what to do */ + pr_debug("nf_conntrack_in: Can't track with proto module\n"); + nf_ct_put(ct); + skb->_nfct = 0; + /* Special case: TCP tracker reports an attempt to reopen a + * closed/aborted connection. We have to go back and create a + * fresh conntrack. + */ + if (ret == -NF_REPEAT) + goto repeat; + + NF_CT_STAT_INC_ATOMIC(state->net, invalid); + if (ret == -NF_DROP) + NF_CT_STAT_INC_ATOMIC(state->net, drop); + + ret = -ret; + goto out; + } + + if (ctinfo == IP_CT_ESTABLISHED_REPLY && + !test_and_set_bit(IPS_SEEN_REPLY_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_REPLY, ct); +out: + if (tmpl) + nf_ct_put(tmpl); + + return ret; +} +EXPORT_SYMBOL_GPL(nf_conntrack_in); + +/* Alter reply tuple (maybe alter helper). This is for NAT, and is + implicitly racy: see __nf_conntrack_confirm */ +void nf_conntrack_alter_reply(struct nf_conn *ct, + const struct nf_conntrack_tuple *newreply) +{ + struct nf_conn_help *help = nfct_help(ct); + + /* Should be unconfirmed, so not in hash table yet */ + WARN_ON(nf_ct_is_confirmed(ct)); + + pr_debug("Altering reply tuple of %p to ", ct); + nf_ct_dump_tuple(newreply); + + ct->tuplehash[IP_CT_DIR_REPLY].tuple = *newreply; + if (ct->master || (help && !hlist_empty(&help->expectations))) + return; +} +EXPORT_SYMBOL_GPL(nf_conntrack_alter_reply); + +/* Refresh conntrack for this many jiffies and do accounting if do_acct is 1 */ +void __nf_ct_refresh_acct(struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + const struct sk_buff *skb, + u32 extra_jiffies, + bool do_acct) +{ + /* Only update if this is not a fixed timeout */ + if (test_bit(IPS_FIXED_TIMEOUT_BIT, &ct->status)) + goto acct; + + /* If not in hash table, timer will not be active yet */ + if (nf_ct_is_confirmed(ct)) + extra_jiffies += nfct_time_stamp; + + if (READ_ONCE(ct->timeout) != extra_jiffies) + WRITE_ONCE(ct->timeout, extra_jiffies); +acct: + if (do_acct) + nf_ct_acct_update(ct, CTINFO2DIR(ctinfo), skb->len); +} +EXPORT_SYMBOL_GPL(__nf_ct_refresh_acct); + +bool nf_ct_kill_acct(struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + const struct sk_buff *skb) +{ + nf_ct_acct_update(ct, CTINFO2DIR(ctinfo), skb->len); + + return nf_ct_delete(ct, 0, 0); +} +EXPORT_SYMBOL_GPL(nf_ct_kill_acct); + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include +#include + +/* Generic function for tcp/udp/sctp/dccp and alike. 
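+   They can share a single helper because the per-protocol port fields
+   of the tuple live in one union at the same offset, so the tcp member
+   used below reads the port for any of them.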
*/ +int nf_ct_port_tuple_to_nlattr(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple) +{ + if (nla_put_be16(skb, CTA_PROTO_SRC_PORT, tuple->src.u.tcp.port) || + nla_put_be16(skb, CTA_PROTO_DST_PORT, tuple->dst.u.tcp.port)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} +EXPORT_SYMBOL_GPL(nf_ct_port_tuple_to_nlattr); + +const struct nla_policy nf_ct_port_nla_policy[CTA_PROTO_MAX+1] = { + [CTA_PROTO_SRC_PORT] = { .type = NLA_U16 }, + [CTA_PROTO_DST_PORT] = { .type = NLA_U16 }, +}; +EXPORT_SYMBOL_GPL(nf_ct_port_nla_policy); + +int nf_ct_port_nlattr_to_tuple(struct nlattr *tb[], + struct nf_conntrack_tuple *t, + u_int32_t flags) +{ + if (flags & CTA_FILTER_FLAG(CTA_PROTO_SRC_PORT)) { + if (!tb[CTA_PROTO_SRC_PORT]) + return -EINVAL; + + t->src.u.tcp.port = nla_get_be16(tb[CTA_PROTO_SRC_PORT]); + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_DST_PORT)) { + if (!tb[CTA_PROTO_DST_PORT]) + return -EINVAL; + + t->dst.u.tcp.port = nla_get_be16(tb[CTA_PROTO_DST_PORT]); + } + + return 0; +} +EXPORT_SYMBOL_GPL(nf_ct_port_nlattr_to_tuple); + +unsigned int nf_ct_port_nlattr_tuple_size(void) +{ + static unsigned int size __read_mostly; + + if (!size) + size = nla_policy_len(nf_ct_port_nla_policy, CTA_PROTO_MAX + 1); + + return size; +} +EXPORT_SYMBOL_GPL(nf_ct_port_nlattr_tuple_size); +#endif + +/* Used by ipt_REJECT and ip6t_REJECT. */ +static void nf_conntrack_attach(struct sk_buff *nskb, const struct sk_buff *skb) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + + /* This ICMP is in reverse direction to the packet which caused it */ + ct = nf_ct_get(skb, &ctinfo); + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) + ctinfo = IP_CT_RELATED_REPLY; + else + ctinfo = IP_CT_RELATED; + + /* Attach to new skbuff, and increment count */ + nf_ct_set(nskb, ct, ctinfo); + nf_conntrack_get(skb_nfct(nskb)); +} + +static int __nf_conntrack_update(struct net *net, struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + const struct nf_nat_hook *nat_hook; + struct nf_conntrack_tuple_hash *h; + struct nf_conntrack_tuple tuple; + unsigned int status; + int dataoff; + u16 l3num; + u8 l4num; + + l3num = nf_ct_l3num(ct); + + dataoff = get_l4proto(skb, skb_network_offset(skb), l3num, &l4num); + if (dataoff <= 0) + return -1; + + if (!nf_ct_get_tuple(skb, skb_network_offset(skb), dataoff, l3num, + l4num, net, &tuple)) + return -1; + + if (ct->status & IPS_SRC_NAT) { + memcpy(tuple.src.u3.all, + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.all, + sizeof(tuple.src.u3.all)); + tuple.src.u.all = + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.all; + } + + if (ct->status & IPS_DST_NAT) { + memcpy(tuple.dst.u3.all, + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3.all, + sizeof(tuple.dst.u3.all)); + tuple.dst.u.all = + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u.all; + } + + h = nf_conntrack_find_get(net, nf_ct_zone(ct), &tuple); + if (!h) + return 0; + + /* Store status bits of the conntrack that is clashing to re-do NAT + * mangling according to what it has been done already to this packet. 
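+ * The bits are read before the unconfirmed conntrack is released below,
+ * as the skb is then re-associated with the entry already in the table.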
+ */ + status = ct->status; + + nf_ct_put(ct); + ct = nf_ct_tuplehash_to_ctrack(h); + nf_ct_set(skb, ct, ctinfo); + + nat_hook = rcu_dereference(nf_nat_hook); + if (!nat_hook) + return 0; + + if (status & IPS_SRC_NAT && + nat_hook->manip_pkt(skb, ct, NF_NAT_MANIP_SRC, + IP_CT_DIR_ORIGINAL) == NF_DROP) + return -1; + + if (status & IPS_DST_NAT && + nat_hook->manip_pkt(skb, ct, NF_NAT_MANIP_DST, + IP_CT_DIR_ORIGINAL) == NF_DROP) + return -1; + + return 0; +} + +/* This packet is coming from userspace via nf_queue, complete the packet + * processing after the helper invocation in nf_confirm(). + */ +static int nf_confirm_cthelper(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + const struct nf_conntrack_helper *helper; + const struct nf_conn_help *help; + int protoff; + + help = nfct_help(ct); + if (!help) + return 0; + + helper = rcu_dereference(help->helper); + if (!helper) + return 0; + + if (!(helper->flags & NF_CT_HELPER_F_USERSPACE)) + return 0; + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + protoff = skb_network_offset(skb) + ip_hdrlen(skb); + break; +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: { + __be16 frag_off; + u8 pnum; + + pnum = ipv6_hdr(skb)->nexthdr; + protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, + &frag_off); + if (protoff < 0 || (frag_off & htons(~0x7)) != 0) + return 0; + break; + } +#endif + default: + return 0; + } + + if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && + !nf_is_loopback_packet(skb)) { + if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { + NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); + return -1; + } + } + + /* We've seen it coming out the other side: confirm it */ + return nf_conntrack_confirm(skb) == NF_DROP ? - 1 : 0; +} + +static int nf_conntrack_update(struct net *net, struct sk_buff *skb) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + int err; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return 0; + + if (!nf_ct_is_confirmed(ct)) { + err = __nf_conntrack_update(net, skb, ct, ctinfo); + if (err < 0) + return err; + + ct = nf_ct_get(skb, &ctinfo); + } + + return nf_confirm_cthelper(skb, ct, ctinfo); +} + +static bool nf_conntrack_get_tuple_skb(struct nf_conntrack_tuple *dst_tuple, + const struct sk_buff *skb) +{ + const struct nf_conntrack_tuple *src_tuple; + const struct nf_conntrack_tuple_hash *hash; + struct nf_conntrack_tuple srctuple; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + src_tuple = nf_ct_tuple(ct, CTINFO2DIR(ctinfo)); + memcpy(dst_tuple, src_tuple, sizeof(*dst_tuple)); + return true; + } + + if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), + NFPROTO_IPV4, dev_net(skb->dev), + &srctuple)) + return false; + + hash = nf_conntrack_find_get(dev_net(skb->dev), + &nf_ct_zone_dflt, + &srctuple); + if (!hash) + return false; + + ct = nf_ct_tuplehash_to_ctrack(hash); + src_tuple = nf_ct_tuple(ct, !hash->tuple.dst.dir); + memcpy(dst_tuple, src_tuple, sizeof(*dst_tuple)); + nf_ct_put(ct); + + return true; +} + +/* Bring out ya dead! 
*/ +static struct nf_conn * +get_next_corpse(int (*iter)(struct nf_conn *i, void *data), + const struct nf_ct_iter_data *iter_data, unsigned int *bucket) +{ + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + struct hlist_nulls_node *n; + spinlock_t *lockp; + + for (; *bucket < nf_conntrack_htable_size; (*bucket)++) { + struct hlist_nulls_head *hslot = &nf_conntrack_hash[*bucket]; + + if (hlist_nulls_empty(hslot)) + continue; + + lockp = &nf_conntrack_locks[*bucket % CONNTRACK_LOCKS]; + local_bh_disable(); + nf_conntrack_lock(lockp); + hlist_nulls_for_each_entry(h, n, hslot, hnnode) { + if (NF_CT_DIRECTION(h) != IP_CT_DIR_REPLY) + continue; + /* All nf_conn objects are added to hash table twice, one + * for original direction tuple, once for the reply tuple. + * + * Exception: In the IPS_NAT_CLASH case, only the reply + * tuple is added (the original tuple already existed for + * a different object). + * + * We only need to call the iterator once for each + * conntrack, so we just use the 'reply' direction + * tuple while iterating. + */ + ct = nf_ct_tuplehash_to_ctrack(h); + + if (iter_data->net && + !net_eq(iter_data->net, nf_ct_net(ct))) + continue; + + if (iter(ct, iter_data->data)) + goto found; + } + spin_unlock(lockp); + local_bh_enable(); + cond_resched(); + } + + return NULL; +found: + refcount_inc(&ct->ct_general.use); + spin_unlock(lockp); + local_bh_enable(); + return ct; +} + +static void nf_ct_iterate_cleanup(int (*iter)(struct nf_conn *i, void *data), + const struct nf_ct_iter_data *iter_data) +{ + unsigned int bucket = 0; + struct nf_conn *ct; + + might_sleep(); + + mutex_lock(&nf_conntrack_mutex); + while ((ct = get_next_corpse(iter, iter_data, &bucket)) != NULL) { + /* Time to push up daises... */ + + nf_ct_delete(ct, iter_data->portid, iter_data->report); + nf_ct_put(ct); + cond_resched(); + } + mutex_unlock(&nf_conntrack_mutex); +} + +void nf_ct_iterate_cleanup_net(int (*iter)(struct nf_conn *i, void *data), + const struct nf_ct_iter_data *iter_data) +{ + struct net *net = iter_data->net; + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + might_sleep(); + + if (atomic_read(&cnet->count) == 0) + return; + + nf_ct_iterate_cleanup(iter, iter_data); +} +EXPORT_SYMBOL_GPL(nf_ct_iterate_cleanup_net); + +/** + * nf_ct_iterate_destroy - destroy unconfirmed conntracks and iterate table + * @iter: callback to invoke for each conntrack + * @data: data to pass to @iter + * + * Like nf_ct_iterate_cleanup, but first marks conntracks on the + * unconfirmed list as dying (so they will not be inserted into + * main table). + * + * Can only be called in module exit path. + */ +void +nf_ct_iterate_destroy(int (*iter)(struct nf_conn *i, void *data), void *data) +{ + struct nf_ct_iter_data iter_data = {}; + struct net *net; + + down_read(&net_rwsem); + for_each_net(net) { + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + if (atomic_read(&cnet->count) == 0) + continue; + nf_queue_nf_hook_drop(net); + } + up_read(&net_rwsem); + + /* Need to wait for netns cleanup worker to finish, if its + * running -- it might have deleted a net namespace from + * the global list, so hook drop above might not have + * affected all namespaces. + */ + net_ns_barrier(); + + /* a skb w. unconfirmed conntrack could have been reinjected just + * before we called nf_queue_nf_hook_drop(). + * + * This makes sure its inserted into conntrack table. 
+ */ + synchronize_net(); + + nf_ct_ext_bump_genid(); + iter_data.data = data; + nf_ct_iterate_cleanup(iter, &iter_data); + + /* Another cpu might be in a rcu read section with + * rcu protected pointer cleared in iter callback + * or hidden via nf_ct_ext_bump_genid() above. + * + * Wait until those are done. + */ + synchronize_rcu(); +} +EXPORT_SYMBOL_GPL(nf_ct_iterate_destroy); + +static int kill_all(struct nf_conn *i, void *data) +{ + return 1; +} + +void nf_conntrack_cleanup_start(void) +{ + cleanup_nf_conntrack_bpf(); + conntrack_gc_work.exiting = true; +} + +void nf_conntrack_cleanup_end(void) +{ + RCU_INIT_POINTER(nf_ct_hook, NULL); + cancel_delayed_work_sync(&conntrack_gc_work.dwork); + kvfree(nf_conntrack_hash); + + nf_conntrack_proto_fini(); + nf_conntrack_helper_fini(); + nf_conntrack_expect_fini(); + + kmem_cache_destroy(nf_conntrack_cachep); +} + +/* + * Mishearing the voices in his head, our hero wonders how he's + * supposed to kill the mall. + */ +void nf_conntrack_cleanup_net(struct net *net) +{ + LIST_HEAD(single); + + list_add(&net->exit_list, &single); + nf_conntrack_cleanup_net_list(&single); +} + +void nf_conntrack_cleanup_net_list(struct list_head *net_exit_list) +{ + struct nf_ct_iter_data iter_data = {}; + struct net *net; + int busy; + + /* + * This makes sure all current packets have passed through + * netfilter framework. Roll on, two-stage module + * delete... + */ + synchronize_net(); +i_see_dead_people: + busy = 0; + list_for_each_entry(net, net_exit_list, exit_list) { + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + iter_data.net = net; + nf_ct_iterate_cleanup_net(kill_all, &iter_data); + if (atomic_read(&cnet->count) != 0) + busy = 1; + } + if (busy) { + schedule(); + goto i_see_dead_people; + } + + list_for_each_entry(net, net_exit_list, exit_list) { + nf_conntrack_ecache_pernet_fini(net); + nf_conntrack_expect_pernet_fini(net); + free_percpu(net->ct.stat); + } +} + +void *nf_ct_alloc_hashtable(unsigned int *sizep, int nulls) +{ + struct hlist_nulls_head *hash; + unsigned int nr_slots, i; + + if (*sizep > (UINT_MAX / sizeof(struct hlist_nulls_head))) + return NULL; + + BUILD_BUG_ON(sizeof(struct hlist_nulls_head) != sizeof(struct hlist_head)); + nr_slots = *sizep = roundup(*sizep, PAGE_SIZE / sizeof(struct hlist_nulls_head)); + + hash = kvcalloc(nr_slots, sizeof(struct hlist_nulls_head), GFP_KERNEL); + + if (hash && nulls) + for (i = 0; i < nr_slots; i++) + INIT_HLIST_NULLS_HEAD(&hash[i], i); + + return hash; +} +EXPORT_SYMBOL_GPL(nf_ct_alloc_hashtable); + +int nf_conntrack_hash_resize(unsigned int hashsize) +{ + int i, bucket; + unsigned int old_size; + struct hlist_nulls_head *hash, *old_hash; + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + + if (!hashsize) + return -EINVAL; + + hash = nf_ct_alloc_hashtable(&hashsize, 1); + if (!hash) + return -ENOMEM; + + mutex_lock(&nf_conntrack_mutex); + old_size = nf_conntrack_htable_size; + if (old_size == hashsize) { + mutex_unlock(&nf_conntrack_mutex); + kvfree(hash); + return 0; + } + + local_bh_disable(); + nf_conntrack_all_lock(); + write_seqcount_begin(&nf_conntrack_generation); + + /* Lookups in the old hash might happen in parallel, which means we + * might get false negatives during connection lookup. New connections + * created because of a false negative won't make it into the hash + * though since that required taking the locks. 
+ */ + + for (i = 0; i < nf_conntrack_htable_size; i++) { + while (!hlist_nulls_empty(&nf_conntrack_hash[i])) { + unsigned int zone_id; + + h = hlist_nulls_entry(nf_conntrack_hash[i].first, + struct nf_conntrack_tuple_hash, hnnode); + ct = nf_ct_tuplehash_to_ctrack(h); + hlist_nulls_del_rcu(&h->hnnode); + + zone_id = nf_ct_zone_id(nf_ct_zone(ct), NF_CT_DIRECTION(h)); + bucket = __hash_conntrack(nf_ct_net(ct), + &h->tuple, zone_id, hashsize); + hlist_nulls_add_head_rcu(&h->hnnode, &hash[bucket]); + } + } + old_hash = nf_conntrack_hash; + + nf_conntrack_hash = hash; + nf_conntrack_htable_size = hashsize; + + write_seqcount_end(&nf_conntrack_generation); + nf_conntrack_all_unlock(); + local_bh_enable(); + + mutex_unlock(&nf_conntrack_mutex); + + synchronize_net(); + kvfree(old_hash); + return 0; +} + +int nf_conntrack_set_hashsize(const char *val, const struct kernel_param *kp) +{ + unsigned int hashsize; + int rc; + + if (current->nsproxy->net_ns != &init_net) + return -EOPNOTSUPP; + + /* On boot, we can set this without any fancy locking. */ + if (!nf_conntrack_hash) + return param_set_uint(val, kp); + + rc = kstrtouint(val, 0, &hashsize); + if (rc) + return rc; + + return nf_conntrack_hash_resize(hashsize); +} + +int nf_conntrack_init_start(void) +{ + unsigned long nr_pages = totalram_pages(); + int max_factor = 8; + int ret = -ENOMEM; + int i; + + seqcount_spinlock_init(&nf_conntrack_generation, + &nf_conntrack_locks_all_lock); + + for (i = 0; i < CONNTRACK_LOCKS; i++) + spin_lock_init(&nf_conntrack_locks[i]); + + if (!nf_conntrack_htable_size) { + nf_conntrack_htable_size + = (((nr_pages << PAGE_SHIFT) / 16384) + / sizeof(struct hlist_head)); + if (BITS_PER_LONG >= 64 && + nr_pages > (4 * (1024 * 1024 * 1024 / PAGE_SIZE))) + nf_conntrack_htable_size = 262144; + else if (nr_pages > (1024 * 1024 * 1024 / PAGE_SIZE)) + nf_conntrack_htable_size = 65536; + + if (nf_conntrack_htable_size < 1024) + nf_conntrack_htable_size = 1024; + /* Use a max. factor of one by default to keep the average + * hash chain length at 2 entries. Each entry has to be added + * twice (once for original direction, once for reply). + * When a table size is given we use the old value of 8 to + * avoid implicit reduction of the max entries setting. 
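+ *
+ * Rough worked example: a 64-bit machine with more than 4 GiB of RAM
+ * gets 262144 buckets from the sizing above, and with max_factor == 1
+ * the default nf_conntrack_max comes out as 262144 entries as well.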
+ */ + max_factor = 1; + } + + nf_conntrack_hash = nf_ct_alloc_hashtable(&nf_conntrack_htable_size, 1); + if (!nf_conntrack_hash) + return -ENOMEM; + + nf_conntrack_max = max_factor * nf_conntrack_htable_size; + + nf_conntrack_cachep = kmem_cache_create("nf_conntrack", + sizeof(struct nf_conn), + NFCT_INFOMASK + 1, + SLAB_TYPESAFE_BY_RCU | SLAB_HWCACHE_ALIGN, NULL); + if (!nf_conntrack_cachep) + goto err_cachep; + + ret = nf_conntrack_expect_init(); + if (ret < 0) + goto err_expect; + + ret = nf_conntrack_helper_init(); + if (ret < 0) + goto err_helper; + + ret = nf_conntrack_proto_init(); + if (ret < 0) + goto err_proto; + + conntrack_gc_work_init(&conntrack_gc_work); + queue_delayed_work(system_power_efficient_wq, &conntrack_gc_work.dwork, HZ); + + ret = register_nf_conntrack_bpf(); + if (ret < 0) + goto err_kfunc; + + return 0; + +err_kfunc: + cancel_delayed_work_sync(&conntrack_gc_work.dwork); + nf_conntrack_proto_fini(); +err_proto: + nf_conntrack_helper_fini(); +err_helper: + nf_conntrack_expect_fini(); +err_expect: + kmem_cache_destroy(nf_conntrack_cachep); +err_cachep: + kvfree(nf_conntrack_hash); + return ret; +} + +static const struct nf_ct_hook nf_conntrack_hook = { + .update = nf_conntrack_update, + .destroy = nf_ct_destroy, + .get_tuple_skb = nf_conntrack_get_tuple_skb, + .attach = nf_conntrack_attach, +}; + +void nf_conntrack_init_end(void) +{ + RCU_INIT_POINTER(nf_ct_hook, &nf_conntrack_hook); +} + +/* + * We need to use special "null" values, not used in hash table + */ +#define UNCONFIRMED_NULLS_VAL ((1<<30)+0) + +int nf_conntrack_init_net(struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + int ret = -ENOMEM; + + BUILD_BUG_ON(IP_CT_UNTRACKED == IP_CT_NUMBER); + BUILD_BUG_ON_NOT_POWER_OF_2(CONNTRACK_LOCKS); + atomic_set(&cnet->count, 0); + + net->ct.stat = alloc_percpu(struct ip_conntrack_stat); + if (!net->ct.stat) + return ret; + + ret = nf_conntrack_expect_pernet_init(net); + if (ret < 0) + goto err_expect; + + nf_conntrack_acct_pernet_init(net); + nf_conntrack_tstamp_pernet_init(net); + nf_conntrack_ecache_pernet_init(net); + nf_conntrack_proto_pernet_init(net); + + return 0; + +err_expect: + free_percpu(net->ct.stat); + return ret; +} + +/* ctnetlink code shared by both ctnetlink and nf_conntrack_bpf */ + +int __nf_ct_change_timeout(struct nf_conn *ct, u64 timeout) +{ + if (test_bit(IPS_FIXED_TIMEOUT_BIT, &ct->status)) + return -EPERM; + + __nf_ct_set_timeout(ct, timeout); + + if (test_bit(IPS_DYING_BIT, &ct->status)) + return -ETIME; + + return 0; +} +EXPORT_SYMBOL_GPL(__nf_ct_change_timeout); + +void __nf_ct_change_status(struct nf_conn *ct, unsigned long on, unsigned long off) +{ + unsigned int bit; + + /* Ignore these unchangable bits */ + on &= ~IPS_UNCHANGEABLE_MASK; + off &= ~IPS_UNCHANGEABLE_MASK; + + for (bit = 0; bit < __IPS_MAX_BIT; bit++) { + if (on & (1 << bit)) + set_bit(bit, &ct->status); + else if (off & (1 << bit)) + clear_bit(bit, &ct->status); + } +} +EXPORT_SYMBOL_GPL(__nf_ct_change_status); + +int nf_ct_change_status_common(struct nf_conn *ct, unsigned int status) +{ + unsigned long d; + + d = ct->status ^ status; + + if (d & (IPS_EXPECTED|IPS_CONFIRMED|IPS_DYING)) + /* unchangeable */ + return -EBUSY; + + if (d & IPS_SEEN_REPLY && !(status & IPS_SEEN_REPLY)) + /* SEEN_REPLY bit can only be set */ + return -EBUSY; + + if (d & IPS_ASSURED && !(status & IPS_ASSURED)) + /* ASSURED bit can only be set */ + return -EBUSY; + + __nf_ct_change_status(ct, status, 0); + return 0; +} +EXPORT_SYMBOL_GPL(nf_ct_change_status_common); diff 
--git a/net/netfilter/nf_conntrack_ecache.c b/net/netfilter/nf_conntrack_ecache.c new file mode 100644 index 000000000..69948e1d6 --- /dev/null +++ b/net/netfilter/nf_conntrack_ecache.c @@ -0,0 +1,358 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Event cache for netfilter. */ + +/* + * (C) 2005 Harald Welte + * (C) 2005 Patrick McHardy + * (C) 2005-2006 Netfilter Core Team + * (C) 2005 USAGI/WIDE Project + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static DEFINE_MUTEX(nf_ct_ecache_mutex); + +#define DYING_NULLS_VAL ((1 << 30) + 1) +#define ECACHE_MAX_JIFFIES msecs_to_jiffies(10) +#define ECACHE_RETRY_JIFFIES msecs_to_jiffies(10) + +enum retry_state { + STATE_CONGESTED, + STATE_RESTART, + STATE_DONE, +}; + +struct nf_conntrack_net_ecache *nf_conn_pernet_ecache(const struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + return &cnet->ecache; +} +#if IS_MODULE(CONFIG_NF_CT_NETLINK) +EXPORT_SYMBOL_GPL(nf_conn_pernet_ecache); +#endif + +static enum retry_state ecache_work_evict_list(struct nf_conntrack_net *cnet) +{ + unsigned long stop = jiffies + ECACHE_MAX_JIFFIES; + struct hlist_nulls_head evicted_list; + enum retry_state ret = STATE_DONE; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; + unsigned int sent; + + INIT_HLIST_NULLS_HEAD(&evicted_list, DYING_NULLS_VAL); + +next: + sent = 0; + spin_lock_bh(&cnet->ecache.dying_lock); + + hlist_nulls_for_each_entry_safe(h, n, &cnet->ecache.dying_list, hnnode) { + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + /* The worker owns all entries, ct remains valid until nf_ct_put + * in the loop below. + */ + if (nf_conntrack_event(IPCT_DESTROY, ct)) { + ret = STATE_CONGESTED; + break; + } + + hlist_nulls_del_rcu(&ct->tuplehash[IP_CT_DIR_ORIGINAL].hnnode); + hlist_nulls_add_head(&ct->tuplehash[IP_CT_DIR_REPLY].hnnode, &evicted_list); + + if (time_after(stop, jiffies)) { + ret = STATE_RESTART; + break; + } + + if (sent++ > 16) { + spin_unlock_bh(&cnet->ecache.dying_lock); + cond_resched(); + goto next; + } + } + + spin_unlock_bh(&cnet->ecache.dying_lock); + + hlist_nulls_for_each_entry_safe(h, n, &evicted_list, hnnode) { + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + hlist_nulls_del_rcu(&ct->tuplehash[IP_CT_DIR_REPLY].hnnode); + nf_ct_put(ct); + + cond_resched(); + } + + return ret; +} + +static void ecache_work(struct work_struct *work) +{ + struct nf_conntrack_net *cnet = container_of(work, struct nf_conntrack_net, ecache.dwork.work); + int ret, delay = -1; + + ret = ecache_work_evict_list(cnet); + switch (ret) { + case STATE_CONGESTED: + delay = ECACHE_RETRY_JIFFIES; + break; + case STATE_RESTART: + delay = 0; + break; + case STATE_DONE: + break; + } + + if (delay >= 0) + schedule_delayed_work(&cnet->ecache.dwork, delay); +} + +static int __nf_conntrack_eventmask_report(struct nf_conntrack_ecache *e, + const u32 events, + const u32 missed, + const struct nf_ct_event *item) +{ + struct net *net = nf_ct_net(item->ct); + struct nf_ct_event_notifier *notify; + u32 old, want; + int ret; + + if (!((events | missed) & e->ctmask)) + return 0; + + rcu_read_lock(); + + notify = rcu_dereference(net->ct.nf_conntrack_event_cb); + if (!notify) { + rcu_read_unlock(); + return 0; + } + + ret = notify->ct_event(events | missed, item); + rcu_read_unlock(); + + if (likely(ret >= 0 && missed == 0)) + return 0; + + do { + old = READ_ONCE(e->missed); + if (ret < 
0) + want = old | events; + else + want = old & ~missed; + } while (cmpxchg(&e->missed, old, want) != old); + + return ret; +} + +int nf_conntrack_eventmask_report(unsigned int events, struct nf_conn *ct, + u32 portid, int report) +{ + struct nf_conntrack_ecache *e; + struct nf_ct_event item; + unsigned int missed; + int ret; + + if (!nf_ct_is_confirmed(ct)) + return 0; + + e = nf_ct_ecache_find(ct); + if (!e) + return 0; + + memset(&item, 0, sizeof(item)); + + item.ct = ct; + item.portid = e->portid ? e->portid : portid; + item.report = report; + + /* This is a resent of a destroy event? If so, skip missed */ + missed = e->portid ? 0 : e->missed; + + ret = __nf_conntrack_eventmask_report(e, events, missed, &item); + if (unlikely(ret < 0 && (events & (1 << IPCT_DESTROY)))) { + /* This is a destroy event that has been triggered by a process, + * we store the PORTID to include it in the retransmission. + */ + if (e->portid == 0 && portid != 0) + e->portid = portid; + } + + return ret; +} +EXPORT_SYMBOL_GPL(nf_conntrack_eventmask_report); + +/* deliver cached events and clear cache entry - must be called with locally + * disabled softirqs */ +void nf_ct_deliver_cached_events(struct nf_conn *ct) +{ + struct nf_conntrack_ecache *e; + struct nf_ct_event item; + unsigned int events; + + if (!nf_ct_is_confirmed(ct) || nf_ct_is_dying(ct)) + return; + + e = nf_ct_ecache_find(ct); + if (e == NULL) + return; + + events = xchg(&e->cache, 0); + + item.ct = ct; + item.portid = 0; + item.report = 0; + + /* We make a copy of the missed event cache without taking + * the lock, thus we may send missed events twice. However, + * this does not harm and it happens very rarely. + */ + __nf_conntrack_eventmask_report(e, events, e->missed, &item); +} +EXPORT_SYMBOL_GPL(nf_ct_deliver_cached_events); + +void nf_ct_expect_event_report(enum ip_conntrack_expect_events event, + struct nf_conntrack_expect *exp, + u32 portid, int report) + +{ + struct net *net = nf_ct_exp_net(exp); + struct nf_ct_event_notifier *notify; + struct nf_conntrack_ecache *e; + + rcu_read_lock(); + notify = rcu_dereference(net->ct.nf_conntrack_event_cb); + if (!notify) + goto out_unlock; + + e = nf_ct_ecache_find(exp->master); + if (!e) + goto out_unlock; + + if (e->expmask & (1 << event)) { + struct nf_exp_event item = { + .exp = exp, + .portid = portid, + .report = report + }; + notify->exp_event(1 << event, &item); + } +out_unlock: + rcu_read_unlock(); +} + +void nf_conntrack_register_notifier(struct net *net, + const struct nf_ct_event_notifier *new) +{ + struct nf_ct_event_notifier *notify; + + mutex_lock(&nf_ct_ecache_mutex); + notify = rcu_dereference_protected(net->ct.nf_conntrack_event_cb, + lockdep_is_held(&nf_ct_ecache_mutex)); + WARN_ON_ONCE(notify); + rcu_assign_pointer(net->ct.nf_conntrack_event_cb, new); + mutex_unlock(&nf_ct_ecache_mutex); +} +EXPORT_SYMBOL_GPL(nf_conntrack_register_notifier); + +void nf_conntrack_unregister_notifier(struct net *net) +{ + mutex_lock(&nf_ct_ecache_mutex); + RCU_INIT_POINTER(net->ct.nf_conntrack_event_cb, NULL); + mutex_unlock(&nf_ct_ecache_mutex); + /* synchronize_rcu() is called after netns pre_exit */ +} +EXPORT_SYMBOL_GPL(nf_conntrack_unregister_notifier); + +void nf_conntrack_ecache_work(struct net *net, enum nf_ct_ecache_state state) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + if (state == NFCT_ECACHE_DESTROY_FAIL && + !delayed_work_pending(&cnet->ecache.dwork)) { + schedule_delayed_work(&cnet->ecache.dwork, HZ); + net->ct.ecache_dwork_pending = true; + } else if (state == 
NFCT_ECACHE_DESTROY_SENT) { + if (!hlist_nulls_empty(&cnet->ecache.dying_list)) + mod_delayed_work(system_wq, &cnet->ecache.dwork, 0); + else + net->ct.ecache_dwork_pending = false; + } +} + +bool nf_ct_ecache_ext_add(struct nf_conn *ct, u16 ctmask, u16 expmask, gfp_t gfp) +{ + struct net *net = nf_ct_net(ct); + struct nf_conntrack_ecache *e; + + switch (net->ct.sysctl_events) { + case 0: + /* assignment via template / ruleset? ignore sysctl. */ + if (ctmask || expmask) + break; + return true; + case 2: /* autodetect: no event listener, don't allocate extension. */ + if (!READ_ONCE(nf_ctnetlink_has_listener)) + return true; + fallthrough; + case 1: + /* always allocate an extension. */ + if (!ctmask && !expmask) { + ctmask = ~0; + expmask = ~0; + } + break; + default: + WARN_ON_ONCE(1); + return true; + } + + e = nf_ct_ext_add(ct, NF_CT_EXT_ECACHE, gfp); + if (e) { + e->ctmask = ctmask; + e->expmask = expmask; + } + + return e != NULL; +} +EXPORT_SYMBOL_GPL(nf_ct_ecache_ext_add); + +#define NF_CT_EVENTS_DEFAULT 2 +static int nf_ct_events __read_mostly = NF_CT_EVENTS_DEFAULT; + +void nf_conntrack_ecache_pernet_init(struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + net->ct.sysctl_events = nf_ct_events; + + INIT_DELAYED_WORK(&cnet->ecache.dwork, ecache_work); + INIT_HLIST_NULLS_HEAD(&cnet->ecache.dying_list, DYING_NULLS_VAL); + spin_lock_init(&cnet->ecache.dying_lock); + + BUILD_BUG_ON(__IPCT_MAX >= 16); /* e->ctmask is u16 */ +} + +void nf_conntrack_ecache_pernet_fini(struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + cancel_delayed_work_sync(&cnet->ecache.dwork); +} diff --git a/net/netfilter/nf_conntrack_expect.c b/net/netfilter/nf_conntrack_expect.c new file mode 100644 index 000000000..96948e98e --- /dev/null +++ b/net/netfilter/nf_conntrack_expect.c @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Expectation handling for nf_conntrack. 
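+ *
+ * An expectation describes a related flow announced by a helper (the
+ * classic example being an FTP data channel): when the first packet of
+ * that flow shows up, nf_ct_find_expectation() lets the core attach it
+ * to its master connection instead of treating it as unrelated.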
*/ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2003,2004 USAGI/WIDE Project + * (c) 2005-2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +unsigned int nf_ct_expect_hsize __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_expect_hsize); + +struct hlist_head *nf_ct_expect_hash __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_expect_hash); + +unsigned int nf_ct_expect_max __read_mostly; + +static struct kmem_cache *nf_ct_expect_cachep __read_mostly; +static siphash_aligned_key_t nf_ct_expect_hashrnd; + +/* nf_conntrack_expect helper functions */ +void nf_ct_unlink_expect_report(struct nf_conntrack_expect *exp, + u32 portid, int report) +{ + struct nf_conn_help *master_help = nfct_help(exp->master); + struct net *net = nf_ct_exp_net(exp); + struct nf_conntrack_net *cnet; + + WARN_ON(!master_help); + WARN_ON(timer_pending(&exp->timeout)); + + hlist_del_rcu(&exp->hnode); + + cnet = nf_ct_pernet(net); + cnet->expect_count--; + + hlist_del_rcu(&exp->lnode); + master_help->expecting[exp->class]--; + + nf_ct_expect_event_report(IPEXP_DESTROY, exp, portid, report); + nf_ct_expect_put(exp); + + NF_CT_STAT_INC(net, expect_delete); +} +EXPORT_SYMBOL_GPL(nf_ct_unlink_expect_report); + +static void nf_ct_expectation_timed_out(struct timer_list *t) +{ + struct nf_conntrack_expect *exp = from_timer(exp, t, timeout); + + spin_lock_bh(&nf_conntrack_expect_lock); + nf_ct_unlink_expect(exp); + spin_unlock_bh(&nf_conntrack_expect_lock); + nf_ct_expect_put(exp); +} + +static unsigned int nf_ct_expect_dst_hash(const struct net *n, const struct nf_conntrack_tuple *tuple) +{ + struct { + union nf_inet_addr dst_addr; + u32 net_mix; + u16 dport; + u8 l3num; + u8 protonum; + } __aligned(SIPHASH_ALIGNMENT) combined; + u32 hash; + + get_random_once(&nf_ct_expect_hashrnd, sizeof(nf_ct_expect_hashrnd)); + + memset(&combined, 0, sizeof(combined)); + + combined.dst_addr = tuple->dst.u3; + combined.net_mix = net_hash_mix(n); + combined.dport = (__force __u16)tuple->dst.u.all; + combined.l3num = tuple->src.l3num; + combined.protonum = tuple->dst.protonum; + + hash = siphash(&combined, sizeof(combined), &nf_ct_expect_hashrnd); + + return reciprocal_scale(hash, nf_ct_expect_hsize); +} + +static bool +nf_ct_exp_equal(const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_expect *i, + const struct nf_conntrack_zone *zone, + const struct net *net) +{ + return nf_ct_tuple_mask_cmp(tuple, &i->tuple, &i->mask) && + net_eq(net, nf_ct_net(i->master)) && + nf_ct_zone_equal_any(i->master, zone); +} + +bool nf_ct_remove_expect(struct nf_conntrack_expect *exp) +{ + if (del_timer(&exp->timeout)) { + nf_ct_unlink_expect(exp); + nf_ct_expect_put(exp); + return true; + } + return false; +} +EXPORT_SYMBOL_GPL(nf_ct_remove_expect); + +struct nf_conntrack_expect * +__nf_ct_expect_find(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + struct nf_conntrack_expect *i; + unsigned int h; + + if (!cnet->expect_count) + return NULL; + + h = nf_ct_expect_dst_hash(net, tuple); + hlist_for_each_entry_rcu(i, &nf_ct_expect_hash[h], hnode) { + if (nf_ct_exp_equal(tuple, i, zone, net)) + return i; + } + return NULL; +} +EXPORT_SYMBOL_GPL(__nf_ct_expect_find); + +/* Just find a expectation corresponding to a 
tuple. */ +struct nf_conntrack_expect * +nf_ct_expect_find_get(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple) +{ + struct nf_conntrack_expect *i; + + rcu_read_lock(); + i = __nf_ct_expect_find(net, zone, tuple); + if (i && !refcount_inc_not_zero(&i->use)) + i = NULL; + rcu_read_unlock(); + + return i; +} +EXPORT_SYMBOL_GPL(nf_ct_expect_find_get); + +/* If an expectation for this connection is found, it gets delete from + * global list then returned. */ +struct nf_conntrack_expect * +nf_ct_find_expectation(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + struct nf_conntrack_expect *i, *exp = NULL; + unsigned int h; + + if (!cnet->expect_count) + return NULL; + + h = nf_ct_expect_dst_hash(net, tuple); + hlist_for_each_entry(i, &nf_ct_expect_hash[h], hnode) { + if (!(i->flags & NF_CT_EXPECT_INACTIVE) && + nf_ct_exp_equal(tuple, i, zone, net)) { + exp = i; + break; + } + } + if (!exp) + return NULL; + + /* If master is not in hash table yet (ie. packet hasn't left + this machine yet), how can other end know about expected? + Hence these are not the droids you are looking for (if + master ct never got confirmed, we'd hold a reference to it + and weird things would happen to future packets). */ + if (!nf_ct_is_confirmed(exp->master)) + return NULL; + + /* Avoid race with other CPUs, that for exp->master ct, is + * about to invoke ->destroy(), or nf_ct_delete() via timeout + * or early_drop(). + * + * The refcount_inc_not_zero() check tells: If that fails, we + * know that the ct is being destroyed. If it succeeds, we + * can be sure the ct cannot disappear underneath. + */ + if (unlikely(nf_ct_is_dying(exp->master) || + !refcount_inc_not_zero(&exp->master->ct_general.use))) + return NULL; + + if (exp->flags & NF_CT_EXPECT_PERMANENT) { + refcount_inc(&exp->use); + return exp; + } else if (del_timer(&exp->timeout)) { + nf_ct_unlink_expect(exp); + return exp; + } + /* Undo exp->master refcnt increase, if del_timer() failed */ + nf_ct_put(exp->master); + + return NULL; +} + +/* delete all expectations for this conntrack */ +void nf_ct_remove_expectations(struct nf_conn *ct) +{ + struct nf_conn_help *help = nfct_help(ct); + struct nf_conntrack_expect *exp; + struct hlist_node *next; + + /* Optimization: most connection never expect any others. */ + if (!help) + return; + + spin_lock_bh(&nf_conntrack_expect_lock); + hlist_for_each_entry_safe(exp, next, &help->expectations, lnode) { + nf_ct_remove_expect(exp); + } + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_remove_expectations); + +/* Would two expected things clash? 
*/ +static inline int expect_clash(const struct nf_conntrack_expect *a, + const struct nf_conntrack_expect *b) +{ + /* Part covered by intersection of masks must be unequal, + otherwise they clash */ + struct nf_conntrack_tuple_mask intersect_mask; + int count; + + intersect_mask.src.u.all = a->mask.src.u.all & b->mask.src.u.all; + + for (count = 0; count < NF_CT_TUPLE_L3SIZE; count++){ + intersect_mask.src.u3.all[count] = + a->mask.src.u3.all[count] & b->mask.src.u3.all[count]; + } + + return nf_ct_tuple_mask_cmp(&a->tuple, &b->tuple, &intersect_mask) && + net_eq(nf_ct_net(a->master), nf_ct_net(b->master)) && + nf_ct_zone_equal_any(a->master, nf_ct_zone(b->master)); +} + +static inline int expect_matches(const struct nf_conntrack_expect *a, + const struct nf_conntrack_expect *b) +{ + return nf_ct_tuple_equal(&a->tuple, &b->tuple) && + nf_ct_tuple_mask_equal(&a->mask, &b->mask) && + net_eq(nf_ct_net(a->master), nf_ct_net(b->master)) && + nf_ct_zone_equal_any(a->master, nf_ct_zone(b->master)); +} + +static bool master_matches(const struct nf_conntrack_expect *a, + const struct nf_conntrack_expect *b, + unsigned int flags) +{ + if (flags & NF_CT_EXP_F_SKIP_MASTER) + return true; + + return a->master == b->master; +} + +/* Generally a bad idea to call this: could have matched already. */ +void nf_ct_unexpect_related(struct nf_conntrack_expect *exp) +{ + spin_lock_bh(&nf_conntrack_expect_lock); + nf_ct_remove_expect(exp); + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_unexpect_related); + +/* We don't increase the master conntrack refcount for non-fulfilled + * conntracks. During the conntrack destruction, the expectations are + * always killed before the conntrack itself */ +struct nf_conntrack_expect *nf_ct_expect_alloc(struct nf_conn *me) +{ + struct nf_conntrack_expect *new; + + new = kmem_cache_alloc(nf_ct_expect_cachep, GFP_ATOMIC); + if (!new) + return NULL; + + new->master = me; + refcount_set(&new->use, 1); + return new; +} +EXPORT_SYMBOL_GPL(nf_ct_expect_alloc); + +void nf_ct_expect_init(struct nf_conntrack_expect *exp, unsigned int class, + u_int8_t family, + const union nf_inet_addr *saddr, + const union nf_inet_addr *daddr, + u_int8_t proto, const __be16 *src, const __be16 *dst) +{ + int len; + + if (family == AF_INET) + len = 4; + else + len = 16; + + exp->flags = 0; + exp->class = class; + exp->expectfn = NULL; + exp->helper = NULL; + exp->tuple.src.l3num = family; + exp->tuple.dst.protonum = proto; + + if (saddr) { + memcpy(&exp->tuple.src.u3, saddr, len); + if (sizeof(exp->tuple.src.u3) > len) + /* address needs to be cleared for nf_ct_tuple_equal */ + memset((void *)&exp->tuple.src.u3 + len, 0x00, + sizeof(exp->tuple.src.u3) - len); + memset(&exp->mask.src.u3, 0xFF, len); + if (sizeof(exp->mask.src.u3) > len) + memset((void *)&exp->mask.src.u3 + len, 0x00, + sizeof(exp->mask.src.u3) - len); + } else { + memset(&exp->tuple.src.u3, 0x00, sizeof(exp->tuple.src.u3)); + memset(&exp->mask.src.u3, 0x00, sizeof(exp->mask.src.u3)); + } + + if (src) { + exp->tuple.src.u.all = *src; + exp->mask.src.u.all = htons(0xFFFF); + } else { + exp->tuple.src.u.all = 0; + exp->mask.src.u.all = 0; + } + + memcpy(&exp->tuple.dst.u3, daddr, len); + if (sizeof(exp->tuple.dst.u3) > len) + /* address needs to be cleared for nf_ct_tuple_equal */ + memset((void *)&exp->tuple.dst.u3 + len, 0x00, + sizeof(exp->tuple.dst.u3) - len); + + exp->tuple.dst.u.all = *dst; + +#if IS_ENABLED(CONFIG_NF_NAT) + memset(&exp->saved_addr, 0, sizeof(exp->saved_addr)); + memset(&exp->saved_proto, 0, 
sizeof(exp->saved_proto)); +#endif +} +EXPORT_SYMBOL_GPL(nf_ct_expect_init); + +static void nf_ct_expect_free_rcu(struct rcu_head *head) +{ + struct nf_conntrack_expect *exp; + + exp = container_of(head, struct nf_conntrack_expect, rcu); + kmem_cache_free(nf_ct_expect_cachep, exp); +} + +void nf_ct_expect_put(struct nf_conntrack_expect *exp) +{ + if (refcount_dec_and_test(&exp->use)) + call_rcu(&exp->rcu, nf_ct_expect_free_rcu); +} +EXPORT_SYMBOL_GPL(nf_ct_expect_put); + +static void nf_ct_expect_insert(struct nf_conntrack_expect *exp) +{ + struct nf_conntrack_net *cnet; + struct nf_conn_help *master_help = nfct_help(exp->master); + struct nf_conntrack_helper *helper; + struct net *net = nf_ct_exp_net(exp); + unsigned int h = nf_ct_expect_dst_hash(net, &exp->tuple); + + /* two references : one for hash insert, one for the timer */ + refcount_add(2, &exp->use); + + timer_setup(&exp->timeout, nf_ct_expectation_timed_out, 0); + helper = rcu_dereference_protected(master_help->helper, + lockdep_is_held(&nf_conntrack_expect_lock)); + if (helper) { + exp->timeout.expires = jiffies + + helper->expect_policy[exp->class].timeout * HZ; + } + add_timer(&exp->timeout); + + hlist_add_head_rcu(&exp->lnode, &master_help->expectations); + master_help->expecting[exp->class]++; + + hlist_add_head_rcu(&exp->hnode, &nf_ct_expect_hash[h]); + cnet = nf_ct_pernet(net); + cnet->expect_count++; + + NF_CT_STAT_INC(net, expect_create); +} + +/* Race with expectations being used means we could have none to find; OK. */ +static void evict_oldest_expect(struct nf_conn *master, + struct nf_conntrack_expect *new) +{ + struct nf_conn_help *master_help = nfct_help(master); + struct nf_conntrack_expect *exp, *last = NULL; + + hlist_for_each_entry(exp, &master_help->expectations, lnode) { + if (exp->class == new->class) + last = exp; + } + + if (last) + nf_ct_remove_expect(last); +} + +static inline int __nf_ct_expect_check(struct nf_conntrack_expect *expect, + unsigned int flags) +{ + const struct nf_conntrack_expect_policy *p; + struct nf_conntrack_expect *i; + struct nf_conntrack_net *cnet; + struct nf_conn *master = expect->master; + struct nf_conn_help *master_help = nfct_help(master); + struct nf_conntrack_helper *helper; + struct net *net = nf_ct_exp_net(expect); + struct hlist_node *next; + unsigned int h; + int ret = 0; + + if (!master_help) { + ret = -ESHUTDOWN; + goto out; + } + h = nf_ct_expect_dst_hash(net, &expect->tuple); + hlist_for_each_entry_safe(i, next, &nf_ct_expect_hash[h], hnode) { + if (master_matches(i, expect, flags) && + expect_matches(i, expect)) { + if (i->class != expect->class || + i->master != expect->master) + return -EALREADY; + + if (nf_ct_remove_expect(i)) + break; + } else if (expect_clash(i, expect)) { + ret = -EBUSY; + goto out; + } + } + /* Will be over limit? 
*/ + helper = rcu_dereference_protected(master_help->helper, + lockdep_is_held(&nf_conntrack_expect_lock)); + if (helper) { + p = &helper->expect_policy[expect->class]; + if (p->max_expected && + master_help->expecting[expect->class] >= p->max_expected) { + evict_oldest_expect(master, expect); + if (master_help->expecting[expect->class] + >= p->max_expected) { + ret = -EMFILE; + goto out; + } + } + } + + cnet = nf_ct_pernet(net); + if (cnet->expect_count >= nf_ct_expect_max) { + net_warn_ratelimited("nf_conntrack: expectation table full\n"); + ret = -EMFILE; + } +out: + return ret; +} + +int nf_ct_expect_related_report(struct nf_conntrack_expect *expect, + u32 portid, int report, unsigned int flags) +{ + int ret; + + spin_lock_bh(&nf_conntrack_expect_lock); + ret = __nf_ct_expect_check(expect, flags); + if (ret < 0) + goto out; + + nf_ct_expect_insert(expect); + + spin_unlock_bh(&nf_conntrack_expect_lock); + nf_ct_expect_event_report(IPEXP_NEW, expect, portid, report); + return 0; +out: + spin_unlock_bh(&nf_conntrack_expect_lock); + return ret; +} +EXPORT_SYMBOL_GPL(nf_ct_expect_related_report); + +void nf_ct_expect_iterate_destroy(bool (*iter)(struct nf_conntrack_expect *e, void *data), + void *data) +{ + struct nf_conntrack_expect *exp; + const struct hlist_node *next; + unsigned int i; + + spin_lock_bh(&nf_conntrack_expect_lock); + + for (i = 0; i < nf_ct_expect_hsize; i++) { + hlist_for_each_entry_safe(exp, next, + &nf_ct_expect_hash[i], + hnode) { + if (iter(exp, data) && del_timer(&exp->timeout)) { + nf_ct_unlink_expect(exp); + nf_ct_expect_put(exp); + } + } + } + + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_expect_iterate_destroy); + +void nf_ct_expect_iterate_net(struct net *net, + bool (*iter)(struct nf_conntrack_expect *e, void *data), + void *data, + u32 portid, int report) +{ + struct nf_conntrack_expect *exp; + const struct hlist_node *next; + unsigned int i; + + spin_lock_bh(&nf_conntrack_expect_lock); + + for (i = 0; i < nf_ct_expect_hsize; i++) { + hlist_for_each_entry_safe(exp, next, + &nf_ct_expect_hash[i], + hnode) { + + if (!net_eq(nf_ct_exp_net(exp), net)) + continue; + + if (iter(exp, data) && del_timer(&exp->timeout)) { + nf_ct_unlink_expect_report(exp, portid, report); + nf_ct_expect_put(exp); + } + } + } + + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_expect_iterate_net); + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +struct ct_expect_iter_state { + struct seq_net_private p; + unsigned int bucket; +}; + +static struct hlist_node *ct_expect_get_first(struct seq_file *seq) +{ + struct ct_expect_iter_state *st = seq->private; + struct hlist_node *n; + + for (st->bucket = 0; st->bucket < nf_ct_expect_hsize; st->bucket++) { + n = rcu_dereference(hlist_first_rcu(&nf_ct_expect_hash[st->bucket])); + if (n) + return n; + } + return NULL; +} + +static struct hlist_node *ct_expect_get_next(struct seq_file *seq, + struct hlist_node *head) +{ + struct ct_expect_iter_state *st = seq->private; + + head = rcu_dereference(hlist_next_rcu(head)); + while (head == NULL) { + if (++st->bucket >= nf_ct_expect_hsize) + return NULL; + head = rcu_dereference(hlist_first_rcu(&nf_ct_expect_hash[st->bucket])); + } + return head; +} + +static struct hlist_node *ct_expect_get_idx(struct seq_file *seq, loff_t pos) +{ + struct hlist_node *head = ct_expect_get_first(seq); + + if (head) + while (pos && (head = ct_expect_get_next(seq, head))) + pos--; + return pos ? 
NULL : head; +} + +static void *exp_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return ct_expect_get_idx(seq, *pos); +} + +static void *exp_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + (*pos)++; + return ct_expect_get_next(seq, v); +} + +static void exp_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int exp_seq_show(struct seq_file *s, void *v) +{ + struct nf_conntrack_expect *expect; + struct nf_conntrack_helper *helper; + struct hlist_node *n = v; + char *delim = ""; + + expect = hlist_entry(n, struct nf_conntrack_expect, hnode); + + if (expect->timeout.function) + seq_printf(s, "%ld ", timer_pending(&expect->timeout) + ? (long)(expect->timeout.expires - jiffies)/HZ : 0); + else + seq_puts(s, "- "); + seq_printf(s, "l3proto = %u proto=%u ", + expect->tuple.src.l3num, + expect->tuple.dst.protonum); + print_tuple(s, &expect->tuple, + nf_ct_l4proto_find(expect->tuple.dst.protonum)); + + if (expect->flags & NF_CT_EXPECT_PERMANENT) { + seq_puts(s, "PERMANENT"); + delim = ","; + } + if (expect->flags & NF_CT_EXPECT_INACTIVE) { + seq_printf(s, "%sINACTIVE", delim); + delim = ","; + } + if (expect->flags & NF_CT_EXPECT_USERSPACE) + seq_printf(s, "%sUSERSPACE", delim); + + helper = rcu_dereference(nfct_help(expect->master)->helper); + if (helper) { + seq_printf(s, "%s%s", expect->flags ? " " : "", helper->name); + if (helper->expect_policy[expect->class].name[0]) + seq_printf(s, "/%s", + helper->expect_policy[expect->class].name); + } + + seq_putc(s, '\n'); + + return 0; +} + +static const struct seq_operations exp_seq_ops = { + .start = exp_seq_start, + .next = exp_seq_next, + .stop = exp_seq_stop, + .show = exp_seq_show +}; +#endif /* CONFIG_NF_CONNTRACK_PROCFS */ + +static int exp_proc_init(struct net *net) +{ +#ifdef CONFIG_NF_CONNTRACK_PROCFS + struct proc_dir_entry *proc; + kuid_t root_uid; + kgid_t root_gid; + + proc = proc_create_net("nf_conntrack_expect", 0440, net->proc_net, + &exp_seq_ops, sizeof(struct ct_expect_iter_state)); + if (!proc) + return -ENOMEM; + + root_uid = make_kuid(net->user_ns, 0); + root_gid = make_kgid(net->user_ns, 0); + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(proc, root_uid, root_gid); +#endif /* CONFIG_NF_CONNTRACK_PROCFS */ + return 0; +} + +static void exp_proc_remove(struct net *net) +{ +#ifdef CONFIG_NF_CONNTRACK_PROCFS + remove_proc_entry("nf_conntrack_expect", net->proc_net); +#endif /* CONFIG_NF_CONNTRACK_PROCFS */ +} + +module_param_named(expect_hashsize, nf_ct_expect_hsize, uint, 0400); + +int nf_conntrack_expect_pernet_init(struct net *net) +{ + return exp_proc_init(net); +} + +void nf_conntrack_expect_pernet_fini(struct net *net) +{ + exp_proc_remove(net); +} + +int nf_conntrack_expect_init(void) +{ + if (!nf_ct_expect_hsize) { + nf_ct_expect_hsize = nf_conntrack_htable_size / 256; + if (!nf_ct_expect_hsize) + nf_ct_expect_hsize = 1; + } + nf_ct_expect_max = nf_ct_expect_hsize * 4; + nf_ct_expect_cachep = kmem_cache_create("nf_conntrack_expect", + sizeof(struct nf_conntrack_expect), + 0, 0, NULL); + if (!nf_ct_expect_cachep) + return -ENOMEM; + + nf_ct_expect_hash = nf_ct_alloc_hashtable(&nf_ct_expect_hsize, 0); + if (!nf_ct_expect_hash) { + kmem_cache_destroy(nf_ct_expect_cachep); + return -ENOMEM; + } + + return 0; +} + +void nf_conntrack_expect_fini(void) +{ + rcu_barrier(); /* Wait for call_rcu() before destroy */ + kmem_cache_destroy(nf_ct_expect_cachep); + kvfree(nf_ct_expect_hash); +} diff --git 
a/net/netfilter/nf_conntrack_extend.c b/net/netfilter/nf_conntrack_extend.c new file mode 100644 index 000000000..dd62cc12e --- /dev/null +++ b/net/netfilter/nf_conntrack_extend.c @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Structure dynamic extension infrastructure + * Copyright (C) 2004 Rusty Russell IBM Corporation + * Copyright (C) 2007 Netfilter Core Team + * Copyright (C) 2007 USAGI/WIDE Project + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NF_CT_EXT_PREALLOC 128u /* conntrack events are on by default */ + +atomic_t nf_conntrack_ext_genid __read_mostly = ATOMIC_INIT(1); + +static const u8 nf_ct_ext_type_len[NF_CT_EXT_NUM] = { + [NF_CT_EXT_HELPER] = sizeof(struct nf_conn_help), +#if IS_ENABLED(CONFIG_NF_NAT) + [NF_CT_EXT_NAT] = sizeof(struct nf_conn_nat), +#endif + [NF_CT_EXT_SEQADJ] = sizeof(struct nf_conn_seqadj), + [NF_CT_EXT_ACCT] = sizeof(struct nf_conn_acct), +#ifdef CONFIG_NF_CONNTRACK_EVENTS + [NF_CT_EXT_ECACHE] = sizeof(struct nf_conntrack_ecache), +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + [NF_CT_EXT_TSTAMP] = sizeof(struct nf_conn_tstamp), +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + [NF_CT_EXT_TIMEOUT] = sizeof(struct nf_conn_timeout), +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + [NF_CT_EXT_LABELS] = sizeof(struct nf_conn_labels), +#endif +#if IS_ENABLED(CONFIG_NETFILTER_SYNPROXY) + [NF_CT_EXT_SYNPROXY] = sizeof(struct nf_conn_synproxy), +#endif +#if IS_ENABLED(CONFIG_NET_ACT_CT) + [NF_CT_EXT_ACT_CT] = sizeof(struct nf_conn_act_ct_ext), +#endif +}; + +static __always_inline unsigned int total_extension_size(void) +{ + /* remember to add new extensions below */ + BUILD_BUG_ON(NF_CT_EXT_NUM > 10); + + return sizeof(struct nf_ct_ext) + + sizeof(struct nf_conn_help) +#if IS_ENABLED(CONFIG_NF_NAT) + + sizeof(struct nf_conn_nat) +#endif + + sizeof(struct nf_conn_seqadj) + + sizeof(struct nf_conn_acct) +#ifdef CONFIG_NF_CONNTRACK_EVENTS + + sizeof(struct nf_conntrack_ecache) +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + + sizeof(struct nf_conn_tstamp) +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + + sizeof(struct nf_conn_timeout) +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + + sizeof(struct nf_conn_labels) +#endif +#if IS_ENABLED(CONFIG_NETFILTER_SYNPROXY) + + sizeof(struct nf_conn_synproxy) +#endif +#if IS_ENABLED(CONFIG_NET_ACT_CT) + + sizeof(struct nf_conn_act_ct_ext) +#endif + ; +} + +void *nf_ct_ext_add(struct nf_conn *ct, enum nf_ct_ext_id id, gfp_t gfp) +{ + unsigned int newlen, newoff, oldlen, alloc; + struct nf_ct_ext *new; + + /* Conntrack must not be confirmed to avoid races on reallocation. 
*/ + WARN_ON(nf_ct_is_confirmed(ct)); + + /* struct nf_ct_ext uses u8 to store offsets/size */ + BUILD_BUG_ON(total_extension_size() > 255u); + + if (ct->ext) { + const struct nf_ct_ext *old = ct->ext; + + if (__nf_ct_ext_exist(old, id)) + return NULL; + oldlen = old->len; + } else { + oldlen = sizeof(*new); + } + + newoff = ALIGN(oldlen, __alignof__(struct nf_ct_ext)); + newlen = newoff + nf_ct_ext_type_len[id]; + + alloc = max(newlen, NF_CT_EXT_PREALLOC); + new = krealloc(ct->ext, alloc, gfp); + if (!new) + return NULL; + + if (!ct->ext) { + memset(new->offset, 0, sizeof(new->offset)); + new->gen_id = atomic_read(&nf_conntrack_ext_genid); + } + + new->offset[id] = newoff; + new->len = newlen; + memset((void *)new + newoff, 0, newlen - newoff); + + ct->ext = new; + return (void *)new + newoff; +} +EXPORT_SYMBOL(nf_ct_ext_add); + +/* Use nf_ct_ext_find wrapper. This is only useful for unconfirmed entries. */ +void *__nf_ct_ext_find(const struct nf_ct_ext *ext, u8 id) +{ + unsigned int gen_id = atomic_read(&nf_conntrack_ext_genid); + unsigned int this_id = READ_ONCE(ext->gen_id); + + if (!__nf_ct_ext_exist(ext, id)) + return NULL; + + if (this_id == 0 || ext->gen_id == gen_id) + return (void *)ext + ext->offset[id]; + + return NULL; +} +EXPORT_SYMBOL(__nf_ct_ext_find); + +void nf_ct_ext_bump_genid(void) +{ + unsigned int value = atomic_inc_return(&nf_conntrack_ext_genid); + + if (value == UINT_MAX) + atomic_set(&nf_conntrack_ext_genid, 1); + + msleep(HZ); +} diff --git a/net/netfilter/nf_conntrack_ftp.c b/net/netfilter/nf_conntrack_ftp.c new file mode 100644 index 000000000..617f744a2 --- /dev/null +++ b/net/netfilter/nf_conntrack_ftp.c @@ -0,0 +1,604 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* FTP extension for connection tracking. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + * (C) 2003,2004 USAGI/WIDE Project + * (C) 2006-2012 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define HELPER_NAME "ftp" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Rusty Russell "); +MODULE_DESCRIPTION("ftp connection tracking helper"); +MODULE_ALIAS("ip_conntrack_ftp"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); +static DEFINE_SPINLOCK(nf_ftp_lock); + +#define MAX_PORTS 8 +static u_int16_t ports[MAX_PORTS]; +static unsigned int ports_c; +module_param_array(ports, ushort, &ports_c, 0400); + +static bool loose; +module_param(loose, bool, 0600); + +unsigned int (*nf_nat_ftp_hook)(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + enum nf_ct_ftp_type type, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp); +EXPORT_SYMBOL_GPL(nf_nat_ftp_hook); + +static int try_rfc959(const char *, size_t, struct nf_conntrack_man *, + char, unsigned int *); +static int try_rfc1123(const char *, size_t, struct nf_conntrack_man *, + char, unsigned int *); +static int try_eprt(const char *, size_t, struct nf_conntrack_man *, + char, unsigned int *); +static int try_epsv_response(const char *, size_t, struct nf_conntrack_man *, + char, unsigned int *); + +static struct ftp_search { + const char *pattern; + size_t plen; + char skip; + char term; + enum nf_ct_ftp_type ftptype; + int (*getnum)(const char *, size_t, struct nf_conntrack_man *, char, unsigned int *); +} search[IP_CT_DIR_MAX][2] = { + [IP_CT_DIR_ORIGINAL] = { + { + .pattern = "PORT", + .plen = 
sizeof("PORT") - 1, + .skip = ' ', + .term = '\r', + .ftptype = NF_CT_FTP_PORT, + .getnum = try_rfc959, + }, + { + .pattern = "EPRT", + .plen = sizeof("EPRT") - 1, + .skip = ' ', + .term = '\r', + .ftptype = NF_CT_FTP_EPRT, + .getnum = try_eprt, + }, + }, + [IP_CT_DIR_REPLY] = { + { + .pattern = "227 ", + .plen = sizeof("227 ") - 1, + .ftptype = NF_CT_FTP_PASV, + .getnum = try_rfc1123, + }, + { + .pattern = "229 ", + .plen = sizeof("229 ") - 1, + .skip = '(', + .term = ')', + .ftptype = NF_CT_FTP_EPSV, + .getnum = try_epsv_response, + }, + }, +}; + +static int +get_ipv6_addr(const char *src, size_t dlen, struct in6_addr *dst, u_int8_t term) +{ + const char *end; + int ret = in6_pton(src, min_t(size_t, dlen, 0xffff), (u8 *)dst, term, &end); + if (ret > 0) + return (int)(end - src); + return 0; +} + +static int try_number(const char *data, size_t dlen, u_int32_t array[], + int array_size, char sep, char term) +{ + u_int32_t i, len; + + memset(array, 0, sizeof(array[0])*array_size); + + /* Keep data pointing at next char. */ + for (i = 0, len = 0; len < dlen && i < array_size; len++, data++) { + if (*data >= '0' && *data <= '9') { + array[i] = array[i]*10 + *data - '0'; + } + else if (*data == sep) + i++; + else { + /* Unexpected character; true if it's the + terminator (or we don't care about one) + and we're finished. */ + if ((*data == term || !term) && i == array_size - 1) + return len; + + pr_debug("Char %u (got %u nums) `%u' unexpected\n", + len, i, *data); + return 0; + } + } + pr_debug("Failed to fill %u numbers separated by %c\n", + array_size, sep); + return 0; +} + +/* Returns 0, or length of numbers: 192,168,1,1,5,6 */ +static int try_rfc959(const char *data, size_t dlen, + struct nf_conntrack_man *cmd, char term, + unsigned int *offset) +{ + int length; + u_int32_t array[6]; + + length = try_number(data, dlen, array, 6, ',', term); + if (length == 0) + return 0; + + cmd->u3.ip = htonl((array[0] << 24) | (array[1] << 16) | + (array[2] << 8) | array[3]); + cmd->u.tcp.port = htons((array[4] << 8) | array[5]); + return length; +} + +/* + * From RFC 1123: + * The format of the 227 reply to a PASV command is not + * well standardized. In particular, an FTP client cannot + * assume that the parentheses shown on page 40 of RFC-959 + * will be present (and in fact, Figure 3 on page 43 omits + * them). Therefore, a User-FTP program that interprets + * the PASV reply must scan the reply for the first digit + * of the host and port numbers. + */ +static int try_rfc1123(const char *data, size_t dlen, + struct nf_conntrack_man *cmd, char term, + unsigned int *offset) +{ + int i; + for (i = 0; i < dlen; i++) + if (isdigit(data[i])) + break; + + if (i == dlen) + return 0; + + *offset += i; + + return try_rfc959(data + i, dlen - i, cmd, 0, offset); +} + +/* Grab port: number up to delimiter */ +static int get_port(const char *data, int start, size_t dlen, char delim, + __be16 *port) +{ + u_int16_t tmp_port = 0; + int i; + + for (i = start; i < dlen; i++) { + /* Finished? 
*/ + if (data[i] == delim) { + if (tmp_port == 0) + break; + *port = htons(tmp_port); + pr_debug("get_port: return %d\n", tmp_port); + return i + 1; + } + else if (data[i] >= '0' && data[i] <= '9') + tmp_port = tmp_port*10 + data[i] - '0'; + else { /* Some other crap */ + pr_debug("get_port: invalid char.\n"); + break; + } + } + return 0; +} + +/* Returns 0, or length of numbers: |1|132.235.1.2|6275| or |2|3ffe::1|6275| */ +static int try_eprt(const char *data, size_t dlen, struct nf_conntrack_man *cmd, + char term, unsigned int *offset) +{ + char delim; + int length; + + /* First character is delimiter, then "1" for IPv4 or "2" for IPv6, + then delimiter again. */ + if (dlen <= 3) { + pr_debug("EPRT: too short\n"); + return 0; + } + delim = data[0]; + if (isdigit(delim) || delim < 33 || delim > 126 || data[2] != delim) { + pr_debug("try_eprt: invalid delimiter.\n"); + return 0; + } + + if ((cmd->l3num == PF_INET && data[1] != '1') || + (cmd->l3num == PF_INET6 && data[1] != '2')) { + pr_debug("EPRT: invalid protocol number.\n"); + return 0; + } + + pr_debug("EPRT: Got %c%c%c\n", delim, data[1], delim); + + if (data[1] == '1') { + u_int32_t array[4]; + + /* Now we have IP address. */ + length = try_number(data + 3, dlen - 3, array, 4, '.', delim); + if (length != 0) + cmd->u3.ip = htonl((array[0] << 24) | (array[1] << 16) + | (array[2] << 8) | array[3]); + } else { + /* Now we have IPv6 address. */ + length = get_ipv6_addr(data + 3, dlen - 3, + (struct in6_addr *)cmd->u3.ip6, delim); + } + + if (length == 0) + return 0; + pr_debug("EPRT: Got IP address!\n"); + /* Start offset includes initial "|1|", and trailing delimiter */ + return get_port(data, 3 + length + 1, dlen, delim, &cmd->u.tcp.port); +} + +/* Returns 0, or length of numbers: |||6446| */ +static int try_epsv_response(const char *data, size_t dlen, + struct nf_conntrack_man *cmd, char term, + unsigned int *offset) +{ + char delim; + + /* Three delimiters. */ + if (dlen <= 3) return 0; + delim = data[0]; + if (isdigit(delim) || delim < 33 || delim > 126 || + data[1] != delim || data[2] != delim) + return 0; + + return get_port(data, 3, dlen, delim, &cmd->u.tcp.port); +} + +/* Return 1 for match, 0 for accept, -1 for partial. */ +static int find_pattern(const char *data, size_t dlen, + const char *pattern, size_t plen, + char skip, char term, + unsigned int *numoff, + unsigned int *numlen, + struct nf_conntrack_man *cmd, + int (*getnum)(const char *, size_t, + struct nf_conntrack_man *, char, + unsigned int *)) +{ + size_t i = plen; + + pr_debug("find_pattern `%s': dlen = %zu\n", pattern, dlen); + + if (dlen <= plen) { + /* Short packet: try for partial? */ + if (strncasecmp(data, pattern, dlen) == 0) + return -1; + else return 0; + } + + if (strncasecmp(data, pattern, plen) != 0) + return 0; + + pr_debug("Pattern matches!\n"); + /* Now we've found the constant string, try to skip + to the 'skip' character */ + if (skip) { + for (i = plen; data[i] != skip; i++) + if (i == dlen - 1) return -1; + + /* Skip over the last character */ + i++; + } + + pr_debug("Skipped up to 0x%hhx delimiter!\n", skip); + + *numoff = i; + *numlen = getnum(data + i, dlen - i, cmd, term, numoff); + if (!*numlen) + return -1; + + pr_debug("Match succeeded!\n"); + return 1; +} + +/* Look up to see if we're just after a \n. 
*/ +static int find_nl_seq(u32 seq, const struct nf_ct_ftp_master *info, int dir) +{ + unsigned int i; + + for (i = 0; i < info->seq_aft_nl_num[dir]; i++) + if (info->seq_aft_nl[dir][i] == seq) + return 1; + return 0; +} + +/* We don't update if it's older than what we have. */ +static void update_nl_seq(struct nf_conn *ct, u32 nl_seq, + struct nf_ct_ftp_master *info, int dir, + struct sk_buff *skb) +{ + unsigned int i, oldest; + + /* Look for oldest: if we find exact match, we're done. */ + for (i = 0; i < info->seq_aft_nl_num[dir]; i++) { + if (info->seq_aft_nl[dir][i] == nl_seq) + return; + } + + if (info->seq_aft_nl_num[dir] < NUM_SEQ_TO_REMEMBER) { + info->seq_aft_nl[dir][info->seq_aft_nl_num[dir]++] = nl_seq; + } else { + if (before(info->seq_aft_nl[dir][0], info->seq_aft_nl[dir][1])) + oldest = 0; + else + oldest = 1; + + if (after(nl_seq, info->seq_aft_nl[dir][oldest])) + info->seq_aft_nl[dir][oldest] = nl_seq; + } +} + +static int help(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + unsigned int dataoff, datalen; + const struct tcphdr *th; + struct tcphdr _tcph; + const char *fb_ptr; + int ret; + u32 seq; + int dir = CTINFO2DIR(ctinfo); + unsigned int matchlen, matchoff; + struct nf_ct_ftp_master *ct_ftp_info = nfct_help_data(ct); + struct nf_conntrack_expect *exp; + union nf_inet_addr *daddr; + struct nf_conntrack_man cmd = {}; + unsigned int i; + int found = 0, ends_in_nl; + typeof(nf_nat_ftp_hook) nf_nat_ftp; + + /* Until there's been traffic both ways, don't look in packets. */ + if (ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY) { + pr_debug("ftp: Conntrackinfo = %u\n", ctinfo); + return NF_ACCEPT; + } + + if (unlikely(skb_linearize(skb))) + return NF_DROP; + + th = skb_header_pointer(skb, protoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return NF_ACCEPT; + + dataoff = protoff + th->doff * 4; + /* No data? */ + if (dataoff >= skb->len) { + pr_debug("ftp: dataoff(%u) >= skblen(%u)\n", dataoff, + skb->len); + return NF_ACCEPT; + } + datalen = skb->len - dataoff; + + /* seqadj (nat) uses ct->lock internally, nf_nat_ftp would cause deadlock */ + spin_lock_bh(&nf_ftp_lock); + fb_ptr = skb->data + dataoff; + + ends_in_nl = (fb_ptr[datalen - 1] == '\n'); + seq = ntohl(th->seq) + datalen; + + /* Look up to see if we're just after a \n. */ + if (!find_nl_seq(ntohl(th->seq), ct_ftp_info, dir)) { + /* We're picking up this, clear flags and let it continue */ + if (unlikely(ct_ftp_info->flags[dir] & NF_CT_FTP_SEQ_PICKUP)) { + ct_ftp_info->flags[dir] ^= NF_CT_FTP_SEQ_PICKUP; + goto skip_nl_seq; + } + + /* Now if this ends in \n, update ftp info. */ + pr_debug("nf_conntrack_ftp: wrong seq pos %s(%u) or %s(%u)\n", + ct_ftp_info->seq_aft_nl_num[dir] > 0 ? "" : "(UNSET)", + ct_ftp_info->seq_aft_nl[dir][0], + ct_ftp_info->seq_aft_nl_num[dir] > 1 ? "" : "(UNSET)", + ct_ftp_info->seq_aft_nl[dir][1]); + ret = NF_ACCEPT; + goto out_update_nl; + } + +skip_nl_seq: + /* Initialize IP/IPv6 addr to expected address (it's not mentioned + in EPSV responses) */ + cmd.l3num = nf_ct_l3num(ct); + memcpy(cmd.u3.all, &ct->tuplehash[dir].tuple.src.u3.all, + sizeof(cmd.u3.all)); + + for (i = 0; i < ARRAY_SIZE(search[dir]); i++) { + found = find_pattern(fb_ptr, datalen, + search[dir][i].pattern, + search[dir][i].plen, + search[dir][i].skip, + search[dir][i].term, + &matchoff, &matchlen, + &cmd, + search[dir][i].getnum); + if (found) break; + } + if (found == -1) { + /* We don't usually drop packets. 
After all, this is + connection tracking, not packet filtering. + However, it is necessary for accurate tracking in + this case. */ + nf_ct_helper_log(skb, ct, "partial matching of `%s'", + search[dir][i].pattern); + ret = NF_DROP; + goto out; + } else if (found == 0) { /* No match */ + ret = NF_ACCEPT; + goto out_update_nl; + } + + pr_debug("conntrack_ftp: match `%.*s' (%u bytes at %u)\n", + matchlen, fb_ptr + matchoff, + matchlen, ntohl(th->seq) + matchoff); + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + nf_ct_helper_log(skb, ct, "cannot alloc expectation"); + ret = NF_DROP; + goto out; + } + + /* We refer to the reverse direction ("!dir") tuples here, + * because we're expecting something in the other direction. + * Doesn't matter unless NAT is happening. */ + daddr = &ct->tuplehash[!dir].tuple.dst.u3; + + /* Update the ftp info */ + if ((cmd.l3num == nf_ct_l3num(ct)) && + memcmp(&cmd.u3.all, &ct->tuplehash[dir].tuple.src.u3.all, + sizeof(cmd.u3.all))) { + /* Enrico Scholz's passive FTP to partially RNAT'd ftp + server: it really wants us to connect to a + different IP address. Simply don't record it for + NAT. */ + if (cmd.l3num == PF_INET) { + pr_debug("NOT RECORDING: %pI4 != %pI4\n", + &cmd.u3.ip, + &ct->tuplehash[dir].tuple.src.u3.ip); + } else { + pr_debug("NOT RECORDING: %pI6 != %pI6\n", + cmd.u3.ip6, + ct->tuplehash[dir].tuple.src.u3.ip6); + } + + /* Thanks to Cristiano Lincoln Mattos + for reporting this potential + problem (DMZ machines opening holes to internal + networks, or the packet filter itself). */ + if (!loose) { + ret = NF_ACCEPT; + goto out_put_expect; + } + daddr = &cmd.u3; + } + + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, cmd.l3num, + &ct->tuplehash[!dir].tuple.src.u3, daddr, + IPPROTO_TCP, NULL, &cmd.u.tcp.port); + + /* Now, NAT might want to mangle the packet, and register the + * (possibly changed) expectation itself. */ + nf_nat_ftp = rcu_dereference(nf_nat_ftp_hook); + if (nf_nat_ftp && ct->status & IPS_NAT_MASK) + ret = nf_nat_ftp(skb, ctinfo, search[dir][i].ftptype, + protoff, matchoff, matchlen, exp); + else { + /* Can't expect this? Best to drop packet now. */ + if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, "cannot add expectation"); + ret = NF_DROP; + } else + ret = NF_ACCEPT; + } + +out_put_expect: + nf_ct_expect_put(exp); + +out_update_nl: + /* Now if this ends in \n, update ftp info. Seq may have been + * adjusted by NAT code. */ + if (ends_in_nl) + update_nl_seq(ct, seq, ct_ftp_info, dir, skb); + out: + spin_unlock_bh(&nf_ftp_lock); + return ret; +} + +static int nf_ct_ftp_from_nlattr(struct nlattr *attr, struct nf_conn *ct) +{ + struct nf_ct_ftp_master *ftp = nfct_help_data(ct); + + /* This conntrack has been injected from user-space, always pick up + * sequence tracking. Otherwise, the first FTP command after the + * failover breaks. 
+ */ + ftp->flags[IP_CT_DIR_ORIGINAL] |= NF_CT_FTP_SEQ_PICKUP; + ftp->flags[IP_CT_DIR_REPLY] |= NF_CT_FTP_SEQ_PICKUP; + return 0; +} + +static struct nf_conntrack_helper ftp[MAX_PORTS * 2] __read_mostly; + +static const struct nf_conntrack_expect_policy ftp_exp_policy = { + .max_expected = 1, + .timeout = 5 * 60, +}; + +static void __exit nf_conntrack_ftp_fini(void) +{ + nf_conntrack_helpers_unregister(ftp, ports_c * 2); +} + +static int __init nf_conntrack_ftp_init(void) +{ + int i, ret = 0; + + NF_CT_HELPER_BUILD_BUG_ON(sizeof(struct nf_ct_ftp_master)); + + if (ports_c == 0) + ports[ports_c++] = FTP_PORT; + + /* FIXME should be configurable whether IPv4 and IPv6 FTP connections + are tracked or not - YK */ + for (i = 0; i < ports_c; i++) { + nf_ct_helper_init(&ftp[2 * i], AF_INET, IPPROTO_TCP, + HELPER_NAME, FTP_PORT, ports[i], ports[i], + &ftp_exp_policy, 0, help, + nf_ct_ftp_from_nlattr, THIS_MODULE); + nf_ct_helper_init(&ftp[2 * i + 1], AF_INET6, IPPROTO_TCP, + HELPER_NAME, FTP_PORT, ports[i], ports[i], + &ftp_exp_policy, 0, help, + nf_ct_ftp_from_nlattr, THIS_MODULE); + } + + ret = nf_conntrack_helpers_register(ftp, ports_c * 2); + if (ret < 0) { + pr_err("failed to register helpers\n"); + return ret; + } + + return 0; +} + +module_init(nf_conntrack_ftp_init); +module_exit(nf_conntrack_ftp_fini); diff --git a/net/netfilter/nf_conntrack_h323_asn1.c b/net/netfilter/nf_conntrack_h323_asn1.c new file mode 100644 index 000000000..e697a824b --- /dev/null +++ b/net/netfilter/nf_conntrack_h323_asn1.c @@ -0,0 +1,939 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * BER and PER decoding library for H.323 conntrack/NAT module. + * + * Copyright (c) 2006 by Jing Min Zhao + * + * See nf_conntrack_helper_h323_asn1.h for details. + */ + +#ifdef __KERNEL__ +#include +#else +#include +#endif +#include + +/* Trace Flag */ +#ifndef H323_TRACE +#define H323_TRACE 0 +#endif + +#if H323_TRACE +#define TAB_SIZE 4 +#define IFTHEN(cond, act) if(cond){act;} +#ifdef __KERNEL__ +#define PRINT printk +#else +#define PRINT printf +#endif +#define FNAME(name) name, +#else +#define IFTHEN(cond, act) +#define PRINT(fmt, args...) 
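+/* With H323_TRACE disabled, the IFTHEN(), PRINT() and FNAME() macros above expand to nothing, so the decoders below are built without trace output or field-name strings. */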
+#define FNAME(name) +#endif + +/* ASN.1 Types */ +#define NUL 0 +#define BOOL 1 +#define OID 2 +#define INT 3 +#define ENUM 4 +#define BITSTR 5 +#define NUMSTR 6 +#define NUMDGT 6 +#define TBCDSTR 6 +#define OCTSTR 7 +#define PRTSTR 7 +#define IA5STR 7 +#define GENSTR 7 +#define BMPSTR 8 +#define SEQ 9 +#define SET 9 +#define SEQOF 10 +#define SETOF 10 +#define CHOICE 11 + +/* Constraint Types */ +#define FIXD 0 +/* #define BITS 1-8 */ +#define BYTE 9 +#define WORD 10 +#define CONS 11 +#define SEMI 12 +#define UNCO 13 + +/* ASN.1 Type Attributes */ +#define SKIP 0 +#define STOP 1 +#define DECODE 2 +#define EXT 4 +#define OPEN 8 +#define OPT 16 + + +/* ASN.1 Field Structure */ +typedef struct field_t { +#if H323_TRACE + char *name; +#endif + unsigned char type; + unsigned char sz; + unsigned char lb; + unsigned char ub; + unsigned short attr; + unsigned short offset; + const struct field_t *fields; +} field_t; + +/* Bit Stream */ +struct bitstr { + unsigned char *buf; + unsigned char *beg; + unsigned char *end; + unsigned char *cur; + unsigned int bit; +}; + +/* Tool Functions */ +#define INC_BIT(bs) if((++(bs)->bit)>7){(bs)->cur++;(bs)->bit=0;} +#define INC_BITS(bs,b) if(((bs)->bit+=(b))>7){(bs)->cur+=(bs)->bit>>3;(bs)->bit&=7;} +#define BYTE_ALIGN(bs) if((bs)->bit){(bs)->cur++;(bs)->bit=0;} +static unsigned int get_len(struct bitstr *bs); +static unsigned int get_bit(struct bitstr *bs); +static unsigned int get_bits(struct bitstr *bs, unsigned int b); +static unsigned int get_bitmap(struct bitstr *bs, unsigned int b); +static unsigned int get_uint(struct bitstr *bs, int b); + +/* Decoder Functions */ +static int decode_nul(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_bool(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_oid(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_int(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_enum(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_bitstr(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_numstr(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_octstr(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_bmpstr(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_seq(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_seqof(struct bitstr *bs, const struct field_t *f, char *base, int level); +static int decode_choice(struct bitstr *bs, const struct field_t *f, char *base, int level); + +/* Decoder Functions Vector */ +typedef int (*decoder_t)(struct bitstr *, const struct field_t *, char *, int); +static const decoder_t Decoders[] = { + decode_nul, + decode_bool, + decode_oid, + decode_int, + decode_enum, + decode_bitstr, + decode_numstr, + decode_octstr, + decode_bmpstr, + decode_seq, + decode_seqof, + decode_choice, +}; + +/* + * H.323 Types + */ +#include "nf_conntrack_h323_types.c" + +/* + * Functions + */ + +/* Assume bs is aligned && v < 16384 */ +static unsigned int get_len(struct bitstr *bs) +{ + unsigned int v; + + v = *bs->cur++; + + if (v & 0x80) { + v &= 0x3f; + v <<= 8; + v += *bs->cur++; + } + + return v; +} + +static int nf_h323_error_boundary(struct bitstr *bs, size_t bytes, size_t bits) +{ + bits += bs->bit; + bytes += bits / BITS_PER_BYTE; + if (bits % 
BITS_PER_BYTE > 0) + bytes++; + + if (bs->cur + bytes > bs->end) + return 1; + + return 0; +} + +static unsigned int get_bit(struct bitstr *bs) +{ + unsigned int b = (*bs->cur) & (0x80 >> bs->bit); + + INC_BIT(bs); + + return b; +} + +/* Assume b <= 8 */ +static unsigned int get_bits(struct bitstr *bs, unsigned int b) +{ + unsigned int v, l; + + v = (*bs->cur) & (0xffU >> bs->bit); + l = b + bs->bit; + + if (l < 8) { + v >>= 8 - l; + bs->bit = l; + } else if (l == 8) { + bs->cur++; + bs->bit = 0; + } else { /* l > 8 */ + + v <<= 8; + v += *(++bs->cur); + v >>= 16 - l; + bs->bit = l - 8; + } + + return v; +} + +/* Assume b <= 32 */ +static unsigned int get_bitmap(struct bitstr *bs, unsigned int b) +{ + unsigned int v, l, shift, bytes; + + if (!b) + return 0; + + l = bs->bit + b; + + if (l < 8) { + v = (unsigned int)(*bs->cur) << (bs->bit + 24); + bs->bit = l; + } else if (l == 8) { + v = (unsigned int)(*bs->cur++) << (bs->bit + 24); + bs->bit = 0; + } else { + for (bytes = l >> 3, shift = 24, v = 0; bytes; + bytes--, shift -= 8) + v |= (unsigned int)(*bs->cur++) << shift; + + if (l < 32) { + v |= (unsigned int)(*bs->cur) << shift; + v <<= bs->bit; + } else if (l > 32) { + v <<= bs->bit; + v |= (*bs->cur) >> (8 - bs->bit); + } + + bs->bit = l & 0x7; + } + + v &= 0xffffffff << (32 - b); + + return v; +} + +/* + * Assume bs is aligned and sizeof(unsigned int) == 4 + */ +static unsigned int get_uint(struct bitstr *bs, int b) +{ + unsigned int v = 0; + + switch (b) { + case 4: + v |= *bs->cur++; + v <<= 8; + fallthrough; + case 3: + v |= *bs->cur++; + v <<= 8; + fallthrough; + case 2: + v |= *bs->cur++; + v <<= 8; + fallthrough; + case 1: + v |= *bs->cur++; + break; + } + return v; +} + +static int decode_nul(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + return H323_ERROR_NONE; +} + +static int decode_bool(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + INC_BIT(bs); + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_oid(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + int len; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 1, 0)) + return H323_ERROR_BOUND; + + len = *bs->cur++; + bs->cur += len; + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + + return H323_ERROR_NONE; +} + +static int decode_int(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int len; + + PRINT("%*.s%s", level * TAB_SIZE, " ", f->name); + + switch (f->sz) { + case BYTE: /* Range == 256 */ + BYTE_ALIGN(bs); + bs->cur++; + break; + case WORD: /* 257 <= Range <= 64K */ + BYTE_ALIGN(bs); + bs->cur += 2; + break; + case CONS: /* 64K < Range < 4G */ + if (nf_h323_error_boundary(bs, 0, 2)) + return H323_ERROR_BOUND; + len = get_bits(bs, 2) + 1; + BYTE_ALIGN(bs); + if (base && (f->attr & DECODE)) { /* timeToLive */ + unsigned int v = get_uint(bs, len) + f->lb; + PRINT(" = %u", v); + *((unsigned int *)(base + f->offset)) = v; + } + bs->cur += len; + break; + case UNCO: + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + bs->cur += len; + break; + default: /* 2 <= Range <= 255 */ + INC_BITS(bs, f->sz); + break; + } + + PRINT("\n"); + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return 
H323_ERROR_NONE; +} + +static int decode_enum(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + if ((f->attr & EXT) && get_bit(bs)) { + INC_BITS(bs, 7); + } else { + INC_BITS(bs, f->sz); + } + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_bitstr(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int len; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + BYTE_ALIGN(bs); + switch (f->sz) { + case FIXD: /* fixed length > 16 */ + len = f->lb; + break; + case WORD: /* 2-byte length */ + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = (*bs->cur++) << 8; + len += (*bs->cur++) + f->lb; + break; + case SEMI: + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + break; + default: + len = 0; + break; + } + + bs->cur += len >> 3; + bs->bit = len & 7; + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_numstr(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int len; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + /* 2 <= Range <= 255 */ + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + len = get_bits(bs, f->sz) + f->lb; + + BYTE_ALIGN(bs); + INC_BITS(bs, (len << 2)); + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_octstr(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int len; + + PRINT("%*.s%s", level * TAB_SIZE, " ", f->name); + + switch (f->sz) { + case FIXD: /* Range == 1 */ + if (f->lb > 2) { + BYTE_ALIGN(bs); + if (base && (f->attr & DECODE)) { + /* The IP Address */ + IFTHEN(f->lb == 4, + PRINT(" = %d.%d.%d.%d:%d", + bs->cur[0], bs->cur[1], + bs->cur[2], bs->cur[3], + bs->cur[4] * 256 + bs->cur[5])); + *((unsigned int *)(base + f->offset)) = + bs->cur - bs->buf; + } + } + len = f->lb; + break; + case BYTE: /* Range == 256 */ + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 1, 0)) + return H323_ERROR_BOUND; + len = (*bs->cur++) + f->lb; + break; + case SEMI: + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs) + f->lb; + break; + default: /* 2 <= Range <= 255 */ + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + len = get_bits(bs, f->sz) + f->lb; + BYTE_ALIGN(bs); + break; + } + + bs->cur += len; + + PRINT("\n"); + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_bmpstr(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int len; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + switch (f->sz) { + case BYTE: /* Range == 256 */ + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 1, 0)) + return H323_ERROR_BOUND; + len = (*bs->cur++) + f->lb; + break; + default: /* 2 <= Range <= 255 */ + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + len = get_bits(bs, f->sz) + f->lb; + BYTE_ALIGN(bs); + break; + } + + bs->cur += len << 1; + + if (nf_h323_error_boundary(bs, 0, 0)) + return H323_ERROR_BOUND; + return H323_ERROR_NONE; +} + +static int decode_seq(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int ext, bmp, i, opt, len = 0, bmp2, bmp2_len; + int err; + const struct field_t *son; + 
unsigned char *beg = NULL; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + /* Decode? */ + base = (base && (f->attr & DECODE)) ? base + f->offset : NULL; + + /* Extensible? */ + if (nf_h323_error_boundary(bs, 0, 1)) + return H323_ERROR_BOUND; + ext = (f->attr & EXT) ? get_bit(bs) : 0; + + /* Get fields bitmap */ + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + bmp = get_bitmap(bs, f->sz); + if (base) + *(unsigned int *)base = bmp; + + /* Decode the root components */ + for (i = opt = 0, son = f->fields; i < f->lb; i++, son++) { + if (son->attr & STOP) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, " ", + son->name); + return H323_ERROR_STOP; + } + + if (son->attr & OPT) { /* Optional component */ + if (!((0x80000000U >> (opt++)) & bmp)) /* Not exist */ + continue; + } + + /* Decode */ + if (son->attr & OPEN) { /* Open field */ + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + if (!base || !(son->attr & DECODE)) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, + " ", son->name); + bs->cur += len; + continue; + } + beg = bs->cur; + + /* Decode */ + if ((err = (Decoders[son->type]) (bs, son, base, + level + 1)) < + H323_ERROR_NONE) + return err; + + bs->cur = beg + len; + bs->bit = 0; + } else if ((err = (Decoders[son->type]) (bs, son, base, + level + 1)) < + H323_ERROR_NONE) + return err; + } + + /* No extension? */ + if (!ext) + return H323_ERROR_NONE; + + /* Get the extension bitmap */ + if (nf_h323_error_boundary(bs, 0, 7)) + return H323_ERROR_BOUND; + bmp2_len = get_bits(bs, 7) + 1; + if (nf_h323_error_boundary(bs, 0, bmp2_len)) + return H323_ERROR_BOUND; + bmp2 = get_bitmap(bs, bmp2_len); + bmp |= bmp2 >> f->sz; + if (base) + *(unsigned int *)base = bmp; + BYTE_ALIGN(bs); + + /* Decode the extension components */ + for (opt = 0; opt < bmp2_len; opt++, i++, son++) { + /* Check Range */ + if (i >= f->ub) { /* Newer Version? */ + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + bs->cur += len; + continue; + } + + if (son->attr & STOP) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, " ", + son->name); + return H323_ERROR_STOP; + } + + if (!((0x80000000 >> opt) & bmp2)) /* Not present */ + continue; + + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + if (!base || !(son->attr & DECODE)) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, " ", + son->name); + bs->cur += len; + continue; + } + beg = bs->cur; + + if ((err = (Decoders[son->type]) (bs, son, base, + level + 1)) < + H323_ERROR_NONE) + return err; + + bs->cur = beg + len; + bs->bit = 0; + } + return H323_ERROR_NONE; +} + +static int decode_seqof(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int count, effective_count = 0, i, len = 0; + int err; + const struct field_t *son; + unsigned char *beg = NULL; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + /* Decode? */ + base = (base && (f->attr & DECODE)) ? 
base + f->offset : NULL; + + /* Decode item count */ + switch (f->sz) { + case BYTE: + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 1, 0)) + return H323_ERROR_BOUND; + count = *bs->cur++; + break; + case WORD: + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + count = *bs->cur++; + count <<= 8; + count += *bs->cur++; + break; + case SEMI: + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + count = get_len(bs); + break; + default: + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + count = get_bits(bs, f->sz); + break; + } + count += f->lb; + + /* Write Count */ + if (base) { + effective_count = count > f->ub ? f->ub : count; + *(unsigned int *)base = effective_count; + base += sizeof(unsigned int); + } + + /* Decode nested field */ + son = f->fields; + if (base) + base -= son->offset; + for (i = 0; i < count; i++) { + if (son->attr & OPEN) { + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + if (!base || !(son->attr & DECODE)) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, + " ", son->name); + bs->cur += len; + continue; + } + beg = bs->cur; + + if ((err = (Decoders[son->type]) (bs, son, + i < + effective_count ? + base : NULL, + level + 1)) < + H323_ERROR_NONE) + return err; + + bs->cur = beg + len; + bs->bit = 0; + } else + if ((err = (Decoders[son->type]) (bs, son, + i < + effective_count ? + base : NULL, + level + 1)) < + H323_ERROR_NONE) + return err; + + if (base) + base += son->offset; + } + + return H323_ERROR_NONE; +} + +static int decode_choice(struct bitstr *bs, const struct field_t *f, + char *base, int level) +{ + unsigned int type, ext, len = 0; + int err; + const struct field_t *son; + unsigned char *beg = NULL; + + PRINT("%*.s%s\n", level * TAB_SIZE, " ", f->name); + + /* Decode? */ + base = (base && (f->attr & DECODE)) ? base + f->offset : NULL; + + /* Decode the choice index number */ + if (nf_h323_error_boundary(bs, 0, 1)) + return H323_ERROR_BOUND; + if ((f->attr & EXT) && get_bit(bs)) { + ext = 1; + if (nf_h323_error_boundary(bs, 0, 7)) + return H323_ERROR_BOUND; + type = get_bits(bs, 7) + f->lb; + } else { + ext = 0; + if (nf_h323_error_boundary(bs, 0, f->sz)) + return H323_ERROR_BOUND; + type = get_bits(bs, f->sz); + if (type >= f->lb) + return H323_ERROR_RANGE; + } + + /* Write Type */ + if (base) + *(unsigned int *)base = type; + + /* Check Range */ + if (type >= f->ub) { /* Newer version? 
*/ + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, 2, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + bs->cur += len; + return H323_ERROR_NONE; + } + + /* Transfer to son level */ + son = &f->fields[type]; + if (son->attr & STOP) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, " ", son->name); + return H323_ERROR_STOP; + } + + if (ext || (son->attr & OPEN)) { + BYTE_ALIGN(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + len = get_len(bs); + if (nf_h323_error_boundary(bs, len, 0)) + return H323_ERROR_BOUND; + if (!base || !(son->attr & DECODE)) { + PRINT("%*.s%s\n", (level + 1) * TAB_SIZE, " ", + son->name); + bs->cur += len; + return H323_ERROR_NONE; + } + beg = bs->cur; + + if ((err = (Decoders[son->type]) (bs, son, base, level + 1)) < + H323_ERROR_NONE) + return err; + + bs->cur = beg + len; + bs->bit = 0; + } else if ((err = (Decoders[son->type]) (bs, son, base, level + 1)) < + H323_ERROR_NONE) + return err; + + return H323_ERROR_NONE; +} + +int DecodeRasMessage(unsigned char *buf, size_t sz, RasMessage *ras) +{ + static const struct field_t ras_message = { + FNAME("RasMessage") CHOICE, 5, 24, 32, DECODE | EXT, + 0, _RasMessage + }; + struct bitstr bs; + + bs.buf = bs.beg = bs.cur = buf; + bs.end = buf + sz; + bs.bit = 0; + + return decode_choice(&bs, &ras_message, (char *) ras, 0); +} + +static int DecodeH323_UserInformation(unsigned char *buf, unsigned char *beg, + size_t sz, H323_UserInformation *uuie) +{ + static const struct field_t h323_userinformation = { + FNAME("H323-UserInformation") SEQ, 1, 2, 2, DECODE | EXT, + 0, _H323_UserInformation + }; + struct bitstr bs; + + bs.buf = buf; + bs.beg = bs.cur = beg; + bs.end = beg + sz; + bs.bit = 0; + + return decode_seq(&bs, &h323_userinformation, (char *) uuie, 0); +} + +int DecodeMultimediaSystemControlMessage(unsigned char *buf, size_t sz, + MultimediaSystemControlMessage * + mscm) +{ + static const struct field_t multimediasystemcontrolmessage = { + FNAME("MultimediaSystemControlMessage") CHOICE, 2, 4, 4, + DECODE | EXT, 0, _MultimediaSystemControlMessage + }; + struct bitstr bs; + + bs.buf = bs.beg = bs.cur = buf; + bs.end = buf + sz; + bs.bit = 0; + + return decode_choice(&bs, &multimediasystemcontrolmessage, + (char *) mscm, 0); +} + +int DecodeQ931(unsigned char *buf, size_t sz, Q931 *q931) +{ + unsigned char *p = buf; + int len; + + if (!p || sz < 1) + return H323_ERROR_BOUND; + + /* Protocol Discriminator */ + if (*p != 0x08) { + PRINT("Unknown Protocol Discriminator\n"); + return H323_ERROR_RANGE; + } + p++; + sz--; + + /* CallReferenceValue */ + if (sz < 1) + return H323_ERROR_BOUND; + len = *p++; + sz--; + if (sz < len) + return H323_ERROR_BOUND; + p += len; + sz -= len; + + /* Message Type */ + if (sz < 2) + return H323_ERROR_BOUND; + q931->MessageType = *p++; + sz--; + PRINT("MessageType = %02X\n", q931->MessageType); + if (*p & 0x80) { + p++; + sz--; + } + + /* Decode Information Elements */ + while (sz > 0) { + if (*p == 0x7e) { /* UserUserIE */ + if (sz < 3) + break; + p++; + len = *p++ << 8; + len |= *p++; + sz -= 3; + if (sz < len) + break; + p++; + len--; + return DecodeH323_UserInformation(buf, p, len, + &q931->UUIE); + } + p++; + sz--; + if (sz < 1) + break; + len = *p++; + sz--; + if (sz < len) + break; + p += len; + sz -= len; + } + + PRINT("Q.931 UUIE not found\n"); + + return H323_ERROR_BOUND; +} diff --git a/net/netfilter/nf_conntrack_h323_main.c b/net/netfilter/nf_conntrack_h323_main.c new file mode 100644 
index 000000000..5a9bce24f --- /dev/null +++ b/net/netfilter/nf_conntrack_h323_main.c @@ -0,0 +1,1803 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * H.323 connection tracking helper + * + * Copyright (c) 2006 Jing Min Zhao + * Copyright (c) 2006-2012 Patrick McHardy + * + * Based on the 'brute force' H.323 connection tracking module by + * Jozsef Kadlecsik + * + * For more information, please see http://nath323.sourceforge.net/ + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#define H323_MAX_SIZE 65535 + +/* Parameters */ +static unsigned int default_rrq_ttl __read_mostly = 300; +module_param(default_rrq_ttl, uint, 0600); +MODULE_PARM_DESC(default_rrq_ttl, "use this TTL if it's missing in RRQ"); + +static int gkrouted_only __read_mostly = 1; +module_param(gkrouted_only, int, 0600); +MODULE_PARM_DESC(gkrouted_only, "only accept calls from gatekeeper"); + +static bool callforward_filter __read_mostly = true; +module_param(callforward_filter, bool, 0600); +MODULE_PARM_DESC(callforward_filter, "only create call forwarding expectations " + "if both endpoints are on different sides " + "(determined by routing information)"); + +const struct nfct_h323_nat_hooks __rcu *nfct_h323_nat_hook __read_mostly; +EXPORT_SYMBOL_GPL(nfct_h323_nat_hook); + +static DEFINE_SPINLOCK(nf_h323_lock); +static char *h323_buffer; + +static struct nf_conntrack_helper nf_conntrack_helper_h245; +static struct nf_conntrack_helper nf_conntrack_helper_q931[]; +static struct nf_conntrack_helper nf_conntrack_helper_ras[]; + +static int get_tpkt_data(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo, + unsigned char **data, int *datalen, int *dataoff) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + int dir = CTINFO2DIR(ctinfo); + const struct tcphdr *th; + struct tcphdr _tcph; + int tcpdatalen; + int tcpdataoff; + unsigned char *tpkt; + int tpktlen; + int tpktoff; + + /* Get TCP header */ + th = skb_header_pointer(skb, protoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return 0; + + /* Get TCP data offset */ + tcpdataoff = protoff + th->doff * 4; + + /* Get TCP data length */ + tcpdatalen = skb->len - tcpdataoff; + if (tcpdatalen <= 0) /* No TCP data */ + goto clear_out; + + if (tcpdatalen > H323_MAX_SIZE) + tcpdatalen = H323_MAX_SIZE; + + if (*data == NULL) { /* first TPKT */ + /* Get first TPKT pointer */ + tpkt = skb_header_pointer(skb, tcpdataoff, tcpdatalen, + h323_buffer); + if (!tpkt) + goto clear_out; + + /* Validate TPKT identifier */ + if (tcpdatalen < 4 || tpkt[0] != 0x03 || tpkt[1] != 0) { + /* Netmeeting sends TPKT header and data separately */ + if (info->tpkt_len[dir] > 0) { + pr_debug("nf_ct_h323: previous packet " + "indicated separate TPKT data of %hu " + "bytes\n", info->tpkt_len[dir]); + if (info->tpkt_len[dir] <= tcpdatalen) { + /* Yes, there was a TPKT header + * received */ + *data = tpkt; + *datalen = info->tpkt_len[dir]; + *dataoff = 0; + goto out; + } + + /* Fragmented TPKT */ + pr_debug("nf_ct_h323: fragmented TPKT\n"); + goto clear_out; + } + + /* It is not even a TPKT */ + return 0; + } + tpktoff = 0; + } else { /* Next TPKT */ + tpktoff = *dataoff + *datalen; + tcpdatalen -= tpktoff; + if (tcpdatalen <= 4) /* No more TPKT */ + goto clear_out; + tpkt = *data + *datalen; + + /* Validate TPKT identifier */ + if (tpkt[0] != 0x03 || tpkt[1] != 0) + goto clear_out; + } 
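+ /* TPKT framing (RFC 1006): a 4-byte header of 0x03, 0x00 and a 16-bit big-endian length that includes the header itself precedes each message, so the encapsulated payload starts at tpkt + 4. */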
+ + /* Validate TPKT length */ + tpktlen = tpkt[2] * 256 + tpkt[3]; + if (tpktlen < 4) + goto clear_out; + if (tpktlen > tcpdatalen) { + if (tcpdatalen == 4) { /* Separate TPKT header */ + /* Netmeeting sends TPKT header and data separately */ + pr_debug("nf_ct_h323: separate TPKT header indicates " + "there will be TPKT data of %d bytes\n", + tpktlen - 4); + info->tpkt_len[dir] = tpktlen - 4; + return 0; + } + + pr_debug("nf_ct_h323: incomplete TPKT (fragmented?)\n"); + goto clear_out; + } + + /* This is the encapsulated data */ + *data = tpkt + 4; + *datalen = tpktlen - 4; + *dataoff = tpktoff + 4; + + out: + /* Clear TPKT length */ + info->tpkt_len[dir] = 0; + return 1; + + clear_out: + info->tpkt_len[dir] = 0; + return 0; +} + +static int get_h245_addr(struct nf_conn *ct, const unsigned char *data, + H245_TransportAddress *taddr, + union nf_inet_addr *addr, __be16 *port) +{ + const unsigned char *p; + int len; + + if (taddr->choice != eH245_TransportAddress_unicastAddress) + return 0; + + switch (taddr->unicastAddress.choice) { + case eUnicastAddress_iPAddress: + if (nf_ct_l3num(ct) != AF_INET) + return 0; + p = data + taddr->unicastAddress.iPAddress.network; + len = 4; + break; + case eUnicastAddress_iP6Address: + if (nf_ct_l3num(ct) != AF_INET6) + return 0; + p = data + taddr->unicastAddress.iP6Address.network; + len = 16; + break; + default: + return 0; + } + + memcpy(addr, p, len); + memset((void *)addr + len, 0, sizeof(*addr) - len); + memcpy(port, p + len, sizeof(__be16)); + + return 1; +} + +static int expect_rtp_rtcp(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + H245_TransportAddress *taddr) +{ + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + __be16 rtp_port, rtcp_port; + union nf_inet_addr addr; + struct nf_conntrack_expect *rtp_exp; + struct nf_conntrack_expect *rtcp_exp; + + /* Read RTP or RTCP address */ + if (!get_h245_addr(ct, *data, taddr, &addr, &port) || + memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) || + port == 0) + return 0; + + /* RTP port is even */ + rtp_port = port & ~htons(1); + rtcp_port = port | htons(1); + + /* Create expect for RTP */ + if ((rtp_exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(rtp_exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + IPPROTO_UDP, NULL, &rtp_port); + + /* Create expect for RTCP */ + if ((rtcp_exp = nf_ct_expect_alloc(ct)) == NULL) { + nf_ct_expect_put(rtp_exp); + return -1; + } + nf_ct_expect_init(rtcp_exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + IPPROTO_UDP, NULL, &rtcp_port); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (memcmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + sizeof(ct->tuplehash[dir].tuple.src.u3)) && + nathook && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* NAT needed */ + ret = nathook->nat_rtp_rtcp(skb, ct, ctinfo, protoff, data, dataoff, + taddr, port, rtp_port, rtp_exp, rtcp_exp); + } else { /* Conntrack only */ + if (nf_ct_expect_related(rtp_exp, 0) == 0) { + if (nf_ct_expect_related(rtcp_exp, 0) == 0) { + pr_debug("nf_ct_h323: expect RTP "); + nf_ct_dump_tuple(&rtp_exp->tuple); + pr_debug("nf_ct_h323: expect RTCP "); + nf_ct_dump_tuple(&rtcp_exp->tuple); + } else { + nf_ct_unexpect_related(rtp_exp); 
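+				/* The RTCP expectation could not be
+				 * registered, so the RTP expectation
+				 * created above has just been removed
+				 * again; report failure to the caller.
+				 */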
+ ret = -1; + } + } else + ret = -1; + } + + nf_ct_expect_put(rtp_exp); + nf_ct_expect_put(rtcp_exp); + + return ret; +} + +static int expect_t120(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + H245_TransportAddress *taddr) +{ + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + /* Read T.120 address */ + if (!get_h245_addr(ct, *data, taddr, &addr, &port) || + memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) || + port == 0) + return 0; + + /* Create expect for T.120 connections */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + IPPROTO_TCP, NULL, &port); + exp->flags = NF_CT_EXPECT_PERMANENT; /* Accept multiple channels */ + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (memcmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + sizeof(ct->tuplehash[dir].tuple.src.u3)) && + nathook && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* NAT needed */ + ret = nathook->nat_t120(skb, ct, ctinfo, protoff, data, + dataoff, taddr, port, exp); + } else { /* Conntrack only */ + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_h323: expect T.120 "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + } + + nf_ct_expect_put(exp); + + return ret; +} + +static int process_h245_channel(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + H2250LogicalChannelParameters *channel) +{ + int ret; + + if (channel->options & eH2250LogicalChannelParameters_mediaChannel) { + /* RTP */ + ret = expect_rtp_rtcp(skb, ct, ctinfo, protoff, data, dataoff, + &channel->mediaChannel); + if (ret < 0) + return -1; + } + + if (channel-> + options & eH2250LogicalChannelParameters_mediaControlChannel) { + /* RTCP */ + ret = expect_rtp_rtcp(skb, ct, ctinfo, protoff, data, dataoff, + &channel->mediaControlChannel); + if (ret < 0) + return -1; + } + + return 0; +} + +static int process_olc(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + OpenLogicalChannel *olc) +{ + int ret; + + pr_debug("nf_ct_h323: OpenLogicalChannel\n"); + + if (olc->forwardLogicalChannelParameters.multiplexParameters.choice == + eOpenLogicalChannel_forwardLogicalChannelParameters_multiplexParameters_h2250LogicalChannelParameters) + { + ret = process_h245_channel(skb, ct, ctinfo, + protoff, data, dataoff, + &olc-> + forwardLogicalChannelParameters. + multiplexParameters. + h2250LogicalChannelParameters); + if (ret < 0) + return -1; + } + + if ((olc->options & + eOpenLogicalChannel_reverseLogicalChannelParameters) && + (olc->reverseLogicalChannelParameters.options & + eOpenLogicalChannel_reverseLogicalChannelParameters_multiplexParameters) + && (olc->reverseLogicalChannelParameters.multiplexParameters. + choice == + eOpenLogicalChannel_reverseLogicalChannelParameters_multiplexParameters_h2250LogicalChannelParameters)) + { + ret = + process_h245_channel(skb, ct, ctinfo, + protoff, data, dataoff, + &olc-> + reverseLogicalChannelParameters. + multiplexParameters. 
+ h2250LogicalChannelParameters); + if (ret < 0) + return -1; + } + + if ((olc->options & eOpenLogicalChannel_separateStack) && + olc->forwardLogicalChannelParameters.dataType.choice == + eDataType_data && + olc->forwardLogicalChannelParameters.dataType.data.application. + choice == eDataApplicationCapability_application_t120 && + olc->forwardLogicalChannelParameters.dataType.data.application. + t120.choice == eDataProtocolCapability_separateLANStack && + olc->separateStack.networkAddress.choice == + eNetworkAccessParameters_networkAddress_localAreaAddress) { + ret = expect_t120(skb, ct, ctinfo, protoff, data, dataoff, + &olc->separateStack.networkAddress. + localAreaAddress); + if (ret < 0) + return -1; + } + + return 0; +} + +static int process_olca(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + OpenLogicalChannelAck *olca) +{ + H2250LogicalChannelAckParameters *ack; + int ret; + + pr_debug("nf_ct_h323: OpenLogicalChannelAck\n"); + + if ((olca->options & + eOpenLogicalChannelAck_reverseLogicalChannelParameters) && + (olca->reverseLogicalChannelParameters.options & + eOpenLogicalChannelAck_reverseLogicalChannelParameters_multiplexParameters) + && (olca->reverseLogicalChannelParameters.multiplexParameters. + choice == + eOpenLogicalChannelAck_reverseLogicalChannelParameters_multiplexParameters_h2250LogicalChannelParameters)) + { + ret = process_h245_channel(skb, ct, ctinfo, + protoff, data, dataoff, + &olca-> + reverseLogicalChannelParameters. + multiplexParameters. + h2250LogicalChannelParameters); + if (ret < 0) + return -1; + } + + if ((olca->options & + eOpenLogicalChannelAck_forwardMultiplexAckParameters) && + (olca->forwardMultiplexAckParameters.choice == + eOpenLogicalChannelAck_forwardMultiplexAckParameters_h2250LogicalChannelAckParameters)) + { + ack = &olca->forwardMultiplexAckParameters. + h2250LogicalChannelAckParameters; + if (ack->options & + eH2250LogicalChannelAckParameters_mediaChannel) { + /* RTP */ + ret = expect_rtp_rtcp(skb, ct, ctinfo, + protoff, data, dataoff, + &ack->mediaChannel); + if (ret < 0) + return -1; + } + + if (ack->options & + eH2250LogicalChannelAckParameters_mediaControlChannel) { + /* RTCP */ + ret = expect_rtp_rtcp(skb, ct, ctinfo, + protoff, data, dataoff, + &ack->mediaControlChannel); + if (ret < 0) + return -1; + } + } + + if ((olca->options & eOpenLogicalChannelAck_separateStack) && + olca->separateStack.networkAddress.choice == + eNetworkAccessParameters_networkAddress_localAreaAddress) { + ret = expect_t120(skb, ct, ctinfo, protoff, data, dataoff, + &olca->separateStack.networkAddress. + localAreaAddress); + if (ret < 0) + return -1; + } + + return 0; +} + +static int process_h245(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + MultimediaSystemControlMessage *mscm) +{ + switch (mscm->choice) { + case eMultimediaSystemControlMessage_request: + if (mscm->request.choice == + eRequestMessage_openLogicalChannel) { + return process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &mscm->request.openLogicalChannel); + } + pr_debug("nf_ct_h323: H.245 Request %d\n", + mscm->request.choice); + break; + case eMultimediaSystemControlMessage_response: + if (mscm->response.choice == + eResponseMessage_openLogicalChannelAck) { + return process_olca(skb, ct, ctinfo, + protoff, data, dataoff, + &mscm->response. 
+ openLogicalChannelAck); + } + pr_debug("nf_ct_h323: H.245 Response %d\n", + mscm->response.choice); + break; + default: + pr_debug("nf_ct_h323: H.245 signal %d\n", mscm->choice); + break; + } + + return 0; +} + +static int h245_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + static MultimediaSystemControlMessage mscm; + unsigned char *data = NULL; + int datalen; + int dataoff; + int ret; + + /* Until there's been traffic both ways, don't look in packets. */ + if (ctinfo != IP_CT_ESTABLISHED && ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + pr_debug("nf_ct_h245: skblen = %u\n", skb->len); + + spin_lock_bh(&nf_h323_lock); + + /* Process each TPKT */ + while (get_tpkt_data(skb, protoff, ct, ctinfo, + &data, &datalen, &dataoff)) { + pr_debug("nf_ct_h245: TPKT len=%d ", datalen); + nf_ct_dump_tuple(&ct->tuplehash[CTINFO2DIR(ctinfo)].tuple); + + /* Decode H.245 signal */ + ret = DecodeMultimediaSystemControlMessage(data, datalen, + &mscm); + if (ret < 0) { + pr_debug("nf_ct_h245: decoding error: %s\n", + ret == H323_ERROR_BOUND ? + "out of bound" : "out of range"); + /* We don't drop when decoding error */ + break; + } + + /* Process H.245 signal */ + if (process_h245(skb, ct, ctinfo, protoff, + &data, dataoff, &mscm) < 0) + goto drop; + } + + spin_unlock_bh(&nf_h323_lock); + return NF_ACCEPT; + + drop: + spin_unlock_bh(&nf_h323_lock); + nf_ct_helper_log(skb, ct, "cannot process H.245 message"); + return NF_DROP; +} + +static const struct nf_conntrack_expect_policy h245_exp_policy = { + .max_expected = H323_RTP_CHANNEL_MAX * 4 + 2 /* T.120 */, + .timeout = 240, +}; + +static struct nf_conntrack_helper nf_conntrack_helper_h245 __read_mostly = { + .name = "H.245", + .me = THIS_MODULE, + .tuple.src.l3num = AF_UNSPEC, + .tuple.dst.protonum = IPPROTO_UDP, + .help = h245_help, + .expect_policy = &h245_exp_policy, +}; + +int get_h225_addr(struct nf_conn *ct, unsigned char *data, + TransportAddress *taddr, + union nf_inet_addr *addr, __be16 *port) +{ + const unsigned char *p; + int len; + + switch (taddr->choice) { + case eTransportAddress_ipAddress: + if (nf_ct_l3num(ct) != AF_INET) + return 0; + p = data + taddr->ipAddress.ip; + len = 4; + break; + case eTransportAddress_ip6Address: + if (nf_ct_l3num(ct) != AF_INET6) + return 0; + p = data + taddr->ip6Address.ip; + len = 16; + break; + default: + return 0; + } + + memcpy(addr, p, len); + memset((void *)addr + len, 0, sizeof(*addr) - len); + memcpy(port, p + len, sizeof(__be16)); + + return 1; +} +EXPORT_SYMBOL_GPL(get_h225_addr); + +static int expect_h245(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + TransportAddress *taddr) +{ + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + /* Read h245Address */ + if (!get_h225_addr(ct, *data, taddr, &addr, &port) || + memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) || + port == 0) + return 0; + + /* Create expect for h245 connection */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + IPPROTO_TCP, NULL, &port); + exp->helper = &nf_conntrack_helper_h245; + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (memcmp(&ct->tuplehash[dir].tuple.src.u3, + 
&ct->tuplehash[!dir].tuple.dst.u3, + sizeof(ct->tuplehash[dir].tuple.src.u3)) && + nathook && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* NAT needed */ + ret = nathook->nat_h245(skb, ct, ctinfo, protoff, data, + dataoff, taddr, port, exp); + } else { /* Conntrack only */ + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_q931: expect H.245 "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + } + + nf_ct_expect_put(exp); + + return ret; +} + +/* If the calling party is on the same side of the forward-to party, + * we don't need to track the second call + */ +static int callforward_do_filter(struct net *net, + const union nf_inet_addr *src, + const union nf_inet_addr *dst, + u_int8_t family) +{ + int ret = 0; + + switch (family) { + case AF_INET: { + struct flowi4 fl1, fl2; + struct rtable *rt1, *rt2; + + memset(&fl1, 0, sizeof(fl1)); + fl1.daddr = src->ip; + + memset(&fl2, 0, sizeof(fl2)); + fl2.daddr = dst->ip; + if (!nf_ip_route(net, (struct dst_entry **)&rt1, + flowi4_to_flowi(&fl1), false)) { + if (!nf_ip_route(net, (struct dst_entry **)&rt2, + flowi4_to_flowi(&fl2), false)) { + if (rt_nexthop(rt1, fl1.daddr) == + rt_nexthop(rt2, fl2.daddr) && + rt1->dst.dev == rt2->dst.dev) + ret = 1; + dst_release(&rt2->dst); + } + dst_release(&rt1->dst); + } + break; + } +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: { + struct rt6_info *rt1, *rt2; + struct flowi6 fl1, fl2; + + memset(&fl1, 0, sizeof(fl1)); + fl1.daddr = src->in6; + + memset(&fl2, 0, sizeof(fl2)); + fl2.daddr = dst->in6; + if (!nf_ip6_route(net, (struct dst_entry **)&rt1, + flowi6_to_flowi(&fl1), false)) { + if (!nf_ip6_route(net, (struct dst_entry **)&rt2, + flowi6_to_flowi(&fl2), false)) { + if (ipv6_addr_equal(rt6_nexthop(rt1, &fl1.daddr), + rt6_nexthop(rt2, &fl2.daddr)) && + rt1->dst.dev == rt2->dst.dev) + ret = 1; + dst_release(&rt2->dst); + } + dst_release(&rt1->dst); + } + break; + } +#endif + } + return ret; + +} + +static int expect_callforwarding(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + TransportAddress *taddr) +{ + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + struct net *net = nf_ct_net(ct); + + /* Read alternativeAddress */ + if (!get_h225_addr(ct, *data, taddr, &addr, &port) || port == 0) + return 0; + + /* If the calling party is on the same side of the forward-to party, + * we don't need to track the second call + */ + if (callforward_filter && + callforward_do_filter(net, &addr, &ct->tuplehash[!dir].tuple.src.u3, + nf_ct_l3num(ct))) { + pr_debug("nf_ct_q931: Call Forwarding not tracked\n"); + return 0; + } + + /* Create expect for the second call leg */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, &addr, + IPPROTO_TCP, NULL, &port); + exp->helper = nf_conntrack_helper_q931; + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (memcmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + sizeof(ct->tuplehash[dir].tuple.src.u3)) && + nathook && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* Need NAT */ + ret = nathook->nat_callforwarding(skb, ct, ctinfo, + protoff, data, dataoff, + taddr, port, exp); + } else { /* Conntrack only */ + if (nf_ct_expect_related(exp, 0) == 0) { + 
pr_debug("nf_ct_q931: expect Call Forwarding "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + } + + nf_ct_expect_put(exp); + + return ret; +} + +static int process_setup(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + Setup_UUIE *setup) +{ + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret; + int i; + __be16 port; + union nf_inet_addr addr; + + pr_debug("nf_ct_q931: Setup\n"); + + if (setup->options & eSetup_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &setup->h245Address); + if (ret < 0) + return -1; + } + + nathook = rcu_dereference(nfct_h323_nat_hook); + if ((setup->options & eSetup_UUIE_destCallSignalAddress) && + nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK && + get_h225_addr(ct, *data, &setup->destCallSignalAddress, + &addr, &port) && + memcmp(&addr, &ct->tuplehash[!dir].tuple.src.u3, sizeof(addr))) { + pr_debug("nf_ct_q931: set destCallSignalAddress %pI6:%hu->%pI6:%hu\n", + &addr, ntohs(port), &ct->tuplehash[!dir].tuple.src.u3, + ntohs(ct->tuplehash[!dir].tuple.src.u.tcp.port)); + ret = nathook->set_h225_addr(skb, protoff, data, dataoff, + &setup->destCallSignalAddress, + &ct->tuplehash[!dir].tuple.src.u3, + ct->tuplehash[!dir].tuple.src.u.tcp.port); + if (ret < 0) + return -1; + } + + if ((setup->options & eSetup_UUIE_sourceCallSignalAddress) && + nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK && + get_h225_addr(ct, *data, &setup->sourceCallSignalAddress, + &addr, &port) && + memcmp(&addr, &ct->tuplehash[!dir].tuple.dst.u3, sizeof(addr))) { + pr_debug("nf_ct_q931: set sourceCallSignalAddress %pI6:%hu->%pI6:%hu\n", + &addr, ntohs(port), &ct->tuplehash[!dir].tuple.dst.u3, + ntohs(ct->tuplehash[!dir].tuple.dst.u.tcp.port)); + ret = nathook->set_h225_addr(skb, protoff, data, dataoff, + &setup->sourceCallSignalAddress, + &ct->tuplehash[!dir].tuple.dst.u3, + ct->tuplehash[!dir].tuple.dst.u.tcp.port); + if (ret < 0) + return -1; + } + + if (setup->options & eSetup_UUIE_fastStart) { + for (i = 0; i < setup->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &setup->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_callproceeding(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + CallProceeding_UUIE *callproc) +{ + int ret; + int i; + + pr_debug("nf_ct_q931: CallProceeding\n"); + + if (callproc->options & eCallProceeding_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &callproc->h245Address); + if (ret < 0) + return -1; + } + + if (callproc->options & eCallProceeding_UUIE_fastStart) { + for (i = 0; i < callproc->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &callproc->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_connect(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + Connect_UUIE *connect) +{ + int ret; + int i; + + pr_debug("nf_ct_q931: Connect\n"); + + if (connect->options & eConnect_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &connect->h245Address); + if (ret < 0) + return -1; + } + + if (connect->options & eConnect_UUIE_fastStart) { + for (i 
= 0; i < connect->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &connect->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_alerting(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + Alerting_UUIE *alert) +{ + int ret; + int i; + + pr_debug("nf_ct_q931: Alerting\n"); + + if (alert->options & eAlerting_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &alert->h245Address); + if (ret < 0) + return -1; + } + + if (alert->options & eAlerting_UUIE_fastStart) { + for (i = 0; i < alert->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &alert->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_facility(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + Facility_UUIE *facility) +{ + int ret; + int i; + + pr_debug("nf_ct_q931: Facility\n"); + + if (facility->reason.choice == eFacilityReason_callForwarded) { + if (facility->options & eFacility_UUIE_alternativeAddress) + return expect_callforwarding(skb, ct, ctinfo, + protoff, data, dataoff, + &facility-> + alternativeAddress); + return 0; + } + + if (facility->options & eFacility_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &facility->h245Address); + if (ret < 0) + return -1; + } + + if (facility->options & eFacility_UUIE_fastStart) { + for (i = 0; i < facility->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &facility->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_progress(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, int dataoff, + Progress_UUIE *progress) +{ + int ret; + int i; + + pr_debug("nf_ct_q931: Progress\n"); + + if (progress->options & eProgress_UUIE_h245Address) { + ret = expect_h245(skb, ct, ctinfo, protoff, data, dataoff, + &progress->h245Address); + if (ret < 0) + return -1; + } + + if (progress->options & eProgress_UUIE_fastStart) { + for (i = 0; i < progress->fastStart.count; i++) { + ret = process_olc(skb, ct, ctinfo, + protoff, data, dataoff, + &progress->fastStart.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int process_q931(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, int dataoff, + Q931 *q931) +{ + H323_UU_PDU *pdu = &q931->UUIE.h323_uu_pdu; + int i; + int ret = 0; + + switch (pdu->h323_message_body.choice) { + case eH323_UU_PDU_h323_message_body_setup: + ret = process_setup(skb, ct, ctinfo, protoff, data, dataoff, + &pdu->h323_message_body.setup); + break; + case eH323_UU_PDU_h323_message_body_callProceeding: + ret = process_callproceeding(skb, ct, ctinfo, + protoff, data, dataoff, + &pdu->h323_message_body. 
+ callProceeding); + break; + case eH323_UU_PDU_h323_message_body_connect: + ret = process_connect(skb, ct, ctinfo, protoff, data, dataoff, + &pdu->h323_message_body.connect); + break; + case eH323_UU_PDU_h323_message_body_alerting: + ret = process_alerting(skb, ct, ctinfo, protoff, data, dataoff, + &pdu->h323_message_body.alerting); + break; + case eH323_UU_PDU_h323_message_body_facility: + ret = process_facility(skb, ct, ctinfo, protoff, data, dataoff, + &pdu->h323_message_body.facility); + break; + case eH323_UU_PDU_h323_message_body_progress: + ret = process_progress(skb, ct, ctinfo, protoff, data, dataoff, + &pdu->h323_message_body.progress); + break; + default: + pr_debug("nf_ct_q931: Q.931 signal %d\n", + pdu->h323_message_body.choice); + break; + } + + if (ret < 0) + return -1; + + if (pdu->options & eH323_UU_PDU_h245Control) { + for (i = 0; i < pdu->h245Control.count; i++) { + ret = process_h245(skb, ct, ctinfo, + protoff, data, dataoff, + &pdu->h245Control.item[i]); + if (ret < 0) + return -1; + } + } + + return 0; +} + +static int q931_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + static Q931 q931; + unsigned char *data = NULL; + int datalen; + int dataoff; + int ret; + + /* Until there's been traffic both ways, don't look in packets. */ + if (ctinfo != IP_CT_ESTABLISHED && ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + pr_debug("nf_ct_q931: skblen = %u\n", skb->len); + + spin_lock_bh(&nf_h323_lock); + + /* Process each TPKT */ + while (get_tpkt_data(skb, protoff, ct, ctinfo, + &data, &datalen, &dataoff)) { + pr_debug("nf_ct_q931: TPKT len=%d ", datalen); + nf_ct_dump_tuple(&ct->tuplehash[CTINFO2DIR(ctinfo)].tuple); + + /* Decode Q.931 signal */ + ret = DecodeQ931(data, datalen, &q931); + if (ret < 0) { + pr_debug("nf_ct_q931: decoding error: %s\n", + ret == H323_ERROR_BOUND ? 
+ "out of bound" : "out of range"); + /* We don't drop when decoding error */ + break; + } + + /* Process Q.931 signal */ + if (process_q931(skb, ct, ctinfo, protoff, + &data, dataoff, &q931) < 0) + goto drop; + } + + spin_unlock_bh(&nf_h323_lock); + return NF_ACCEPT; + + drop: + spin_unlock_bh(&nf_h323_lock); + nf_ct_helper_log(skb, ct, "cannot process Q.931 message"); + return NF_DROP; +} + +static const struct nf_conntrack_expect_policy q931_exp_policy = { + /* T.120 and H.245 */ + .max_expected = H323_RTP_CHANNEL_MAX * 4 + 4, + .timeout = 240, +}; + +static struct nf_conntrack_helper nf_conntrack_helper_q931[] __read_mostly = { + { + .name = "Q.931", + .me = THIS_MODULE, + .tuple.src.l3num = AF_INET, + .tuple.src.u.tcp.port = cpu_to_be16(Q931_PORT), + .tuple.dst.protonum = IPPROTO_TCP, + .help = q931_help, + .expect_policy = &q931_exp_policy, + }, + { + .name = "Q.931", + .me = THIS_MODULE, + .tuple.src.l3num = AF_INET6, + .tuple.src.u.tcp.port = cpu_to_be16(Q931_PORT), + .tuple.dst.protonum = IPPROTO_TCP, + .help = q931_help, + .expect_policy = &q931_exp_policy, + }, +}; + +static unsigned char *get_udp_data(struct sk_buff *skb, unsigned int protoff, + int *datalen) +{ + const struct udphdr *uh; + struct udphdr _uh; + int dataoff; + + uh = skb_header_pointer(skb, protoff, sizeof(_uh), &_uh); + if (uh == NULL) + return NULL; + dataoff = protoff + sizeof(_uh); + if (dataoff >= skb->len) + return NULL; + *datalen = skb->len - dataoff; + if (*datalen > H323_MAX_SIZE) + *datalen = H323_MAX_SIZE; + + return skb_header_pointer(skb, dataoff, *datalen, h323_buffer); +} + +static struct nf_conntrack_expect *find_expect(struct nf_conn *ct, + union nf_inet_addr *addr, + __be16 port) +{ + struct net *net = nf_ct_net(ct); + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple tuple; + + memset(&tuple.src.u3, 0, sizeof(tuple.src.u3)); + tuple.src.u.tcp.port = 0; + memcpy(&tuple.dst.u3, addr, sizeof(tuple.dst.u3)); + tuple.dst.u.tcp.port = port; + tuple.dst.protonum = IPPROTO_TCP; + + exp = __nf_ct_expect_find(net, nf_ct_zone(ct), &tuple); + if (exp && exp->master == ct) + return exp; + return NULL; +} + +static int expect_q931(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, unsigned char **data, + TransportAddress *taddr, int count) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + int i; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + /* Look for the first related address */ + for (i = 0; i < count; i++) { + if (get_h225_addr(ct, *data, &taddr[i], &addr, &port) && + memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, + sizeof(addr)) == 0 && port != 0) + break; + } + + if (i >= count) /* Not found */ + return 0; + + /* Create expect for Q.931 */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + gkrouted_only ? /* only accept calls from GK? 
*/ + &ct->tuplehash[!dir].tuple.src.u3 : NULL, + &ct->tuplehash[!dir].tuple.dst.u3, + IPPROTO_TCP, NULL, &port); + exp->helper = nf_conntrack_helper_q931; + exp->flags = NF_CT_EXPECT_PERMANENT; /* Accept multiple calls */ + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { /* Need NAT */ + ret = nathook->nat_q931(skb, ct, ctinfo, protoff, data, + taddr, i, port, exp); + } else { /* Conntrack only */ + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_ras: expect Q.931 "); + nf_ct_dump_tuple(&exp->tuple); + + /* Save port for looking up expect in processing RCF */ + info->sig_port[dir] = port; + } else + ret = -1; + } + + nf_ct_expect_put(exp); + + return ret; +} + +static int process_grq(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, GatekeeperRequest *grq) +{ + const struct nfct_h323_nat_hooks *nathook; + + pr_debug("nf_ct_ras: GRQ\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) /* NATed */ + return nathook->set_ras_addr(skb, ct, ctinfo, protoff, data, + &grq->rasAddress, 1); + return 0; +} + +static int process_gcf(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, GatekeeperConfirm *gcf) +{ + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + pr_debug("nf_ct_ras: GCF\n"); + + if (!get_h225_addr(ct, *data, &gcf->rasAddress, &addr, &port)) + return 0; + + /* Registration port is the same as discovery port */ + if (!memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) && + port == ct->tuplehash[dir].tuple.src.u.udp.port) + return 0; + + /* Avoid RAS expectation loops. A GCF is never expected. 
*/ + if (test_bit(IPS_EXPECTED_BIT, &ct->status)) + return 0; + + /* Need new expect */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, &addr, + IPPROTO_UDP, NULL, &port); + exp->helper = nf_conntrack_helper_ras; + + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_ras: expect RAS "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + + nf_ct_expect_put(exp); + + return ret; +} + +static int process_rrq(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, RegistrationRequest *rrq) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + const struct nfct_h323_nat_hooks *nathook; + int ret; + + pr_debug("nf_ct_ras: RRQ\n"); + + ret = expect_q931(skb, ct, ctinfo, protoff, data, + rrq->callSignalAddress.item, + rrq->callSignalAddress.count); + if (ret < 0) + return -1; + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + ret = nathook->set_ras_addr(skb, ct, ctinfo, protoff, data, + rrq->rasAddress.item, + rrq->rasAddress.count); + if (ret < 0) + return -1; + } + + if (rrq->options & eRegistrationRequest_timeToLive) { + pr_debug("nf_ct_ras: RRQ TTL = %u seconds\n", rrq->timeToLive); + info->timeout = rrq->timeToLive; + } else + info->timeout = default_rrq_ttl; + + return 0; +} + +static int process_rcf(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, RegistrationConfirm *rcf) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret; + struct nf_conntrack_expect *exp; + + pr_debug("nf_ct_ras: RCF\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + ret = nathook->set_sig_addr(skb, ct, ctinfo, protoff, data, + rcf->callSignalAddress.item, + rcf->callSignalAddress.count); + if (ret < 0) + return -1; + } + + if (rcf->options & eRegistrationConfirm_timeToLive) { + pr_debug("nf_ct_ras: RCF TTL = %u seconds\n", rcf->timeToLive); + info->timeout = rcf->timeToLive; + } + + if (info->timeout > 0) { + pr_debug("nf_ct_ras: set RAS connection timeout to " + "%u seconds\n", info->timeout); + nf_ct_refresh(ct, skb, info->timeout * HZ); + + /* Set expect timeout */ + spin_lock_bh(&nf_conntrack_expect_lock); + exp = find_expect(ct, &ct->tuplehash[dir].tuple.dst.u3, + info->sig_port[!dir]); + if (exp) { + pr_debug("nf_ct_ras: set Q.931 expect " + "timeout to %u seconds for", + info->timeout); + nf_ct_dump_tuple(&exp->tuple); + mod_timer_pending(&exp->timeout, + jiffies + info->timeout * HZ); + } + spin_unlock_bh(&nf_conntrack_expect_lock); + } + + return 0; +} + +static int process_urq(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, UnregistrationRequest *urq) +{ + struct nf_ct_h323_master *info = nfct_help_data(ct); + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + int ret; + + pr_debug("nf_ct_ras: URQ\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + ret = nathook->set_sig_addr(skb, ct, ctinfo, protoff, data, + urq->callSignalAddress.item, + urq->callSignalAddress.count); + if (ret < 0) + 
return -1; + } + + /* Clear old expect */ + nf_ct_remove_expectations(ct); + info->sig_port[dir] = 0; + info->sig_port[!dir] = 0; + + /* Give it 30 seconds for UCF or URJ */ + nf_ct_refresh(ct, skb, 30 * HZ); + + return 0; +} + +static int process_arq(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, AdmissionRequest *arq) +{ + const struct nf_ct_h323_master *info = nfct_help_data(ct); + const struct nfct_h323_nat_hooks *nathook; + int dir = CTINFO2DIR(ctinfo); + __be16 port; + union nf_inet_addr addr; + + pr_debug("nf_ct_ras: ARQ\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (!nathook) + return 0; + + if ((arq->options & eAdmissionRequest_destCallSignalAddress) && + get_h225_addr(ct, *data, &arq->destCallSignalAddress, + &addr, &port) && + !memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) && + port == info->sig_port[dir] && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* Answering ARQ */ + return nathook->set_h225_addr(skb, protoff, data, 0, + &arq->destCallSignalAddress, + &ct->tuplehash[!dir].tuple.dst.u3, + info->sig_port[!dir]); + } + + if ((arq->options & eAdmissionRequest_srcCallSignalAddress) && + get_h225_addr(ct, *data, &arq->srcCallSignalAddress, + &addr, &port) && + !memcmp(&addr, &ct->tuplehash[dir].tuple.src.u3, sizeof(addr)) && + nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + /* Calling ARQ */ + return nathook->set_h225_addr(skb, protoff, data, 0, + &arq->srcCallSignalAddress, + &ct->tuplehash[!dir].tuple.dst.u3, + port); + } + + return 0; +} + +static int process_acf(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, AdmissionConfirm *acf) +{ + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + pr_debug("nf_ct_ras: ACF\n"); + + if (!get_h225_addr(ct, *data, &acf->destCallSignalAddress, + &addr, &port)) + return 0; + + if (!memcmp(&addr, &ct->tuplehash[dir].tuple.dst.u3, sizeof(addr))) { + const struct nfct_h323_nat_hooks *nathook; + + /* Answering ACF */ + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) + return nathook->set_sig_addr(skb, ct, ctinfo, protoff, + data, + &acf->destCallSignalAddress, 1); + return 0; + } + + /* Need new expect */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, &addr, + IPPROTO_TCP, NULL, &port); + exp->flags = NF_CT_EXPECT_PERMANENT; + exp->helper = nf_conntrack_helper_q931; + + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_ras: expect Q.931 "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + + nf_ct_expect_put(exp); + + return ret; +} + +static int process_lrq(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, LocationRequest *lrq) +{ + const struct nfct_h323_nat_hooks *nathook; + + pr_debug("nf_ct_ras: LRQ\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) + return nathook->set_ras_addr(skb, ct, ctinfo, protoff, data, + &lrq->replyAddress, 1); + return 0; +} + +static int process_lcf(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + 
unsigned char **data, LocationConfirm *lcf) +{ + int dir = CTINFO2DIR(ctinfo); + int ret = 0; + __be16 port; + union nf_inet_addr addr; + struct nf_conntrack_expect *exp; + + pr_debug("nf_ct_ras: LCF\n"); + + if (!get_h225_addr(ct, *data, &lcf->callSignalAddress, + &addr, &port)) + return 0; + + /* Need new expect for call signal */ + if ((exp = nf_ct_expect_alloc(ct)) == NULL) + return -1; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &ct->tuplehash[!dir].tuple.src.u3, &addr, + IPPROTO_TCP, NULL, &port); + exp->flags = NF_CT_EXPECT_PERMANENT; + exp->helper = nf_conntrack_helper_q931; + + if (nf_ct_expect_related(exp, 0) == 0) { + pr_debug("nf_ct_ras: expect Q.931 "); + nf_ct_dump_tuple(&exp->tuple); + } else + ret = -1; + + nf_ct_expect_put(exp); + + /* Ignore rasAddress */ + + return ret; +} + +static int process_irr(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, InfoRequestResponse *irr) +{ + const struct nfct_h323_nat_hooks *nathook; + int ret; + + pr_debug("nf_ct_ras: IRR\n"); + + nathook = rcu_dereference(nfct_h323_nat_hook); + if (nathook && nf_ct_l3num(ct) == NFPROTO_IPV4 && + ct->status & IPS_NAT_MASK) { + ret = nathook->set_ras_addr(skb, ct, ctinfo, protoff, data, + &irr->rasAddress, 1); + if (ret < 0) + return -1; + + ret = nathook->set_sig_addr(skb, ct, ctinfo, protoff, data, + irr->callSignalAddress.item, + irr->callSignalAddress.count); + if (ret < 0) + return -1; + } + + return 0; +} + +static int process_ras(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned char **data, RasMessage *ras) +{ + switch (ras->choice) { + case eRasMessage_gatekeeperRequest: + return process_grq(skb, ct, ctinfo, protoff, data, + &ras->gatekeeperRequest); + case eRasMessage_gatekeeperConfirm: + return process_gcf(skb, ct, ctinfo, protoff, data, + &ras->gatekeeperConfirm); + case eRasMessage_registrationRequest: + return process_rrq(skb, ct, ctinfo, protoff, data, + &ras->registrationRequest); + case eRasMessage_registrationConfirm: + return process_rcf(skb, ct, ctinfo, protoff, data, + &ras->registrationConfirm); + case eRasMessage_unregistrationRequest: + return process_urq(skb, ct, ctinfo, protoff, data, + &ras->unregistrationRequest); + case eRasMessage_admissionRequest: + return process_arq(skb, ct, ctinfo, protoff, data, + &ras->admissionRequest); + case eRasMessage_admissionConfirm: + return process_acf(skb, ct, ctinfo, protoff, data, + &ras->admissionConfirm); + case eRasMessage_locationRequest: + return process_lrq(skb, ct, ctinfo, protoff, data, + &ras->locationRequest); + case eRasMessage_locationConfirm: + return process_lcf(skb, ct, ctinfo, protoff, data, + &ras->locationConfirm); + case eRasMessage_infoRequestResponse: + return process_irr(skb, ct, ctinfo, protoff, data, + &ras->infoRequestResponse); + default: + pr_debug("nf_ct_ras: RAS message %d\n", ras->choice); + break; + } + + return 0; +} + +static int ras_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + static RasMessage ras; + unsigned char *data; + int datalen = 0; + int ret; + + pr_debug("nf_ct_ras: skblen = %u\n", skb->len); + + spin_lock_bh(&nf_h323_lock); + + /* Get UDP data */ + data = get_udp_data(skb, protoff, &datalen); + if (data == NULL) + goto accept; + pr_debug("nf_ct_ras: RAS message len=%d ", datalen); + nf_ct_dump_tuple(&ct->tuplehash[CTINFO2DIR(ctinfo)].tuple); + + /* Decode RAS message */ + ret = 
DecodeRasMessage(data, datalen, &ras); + if (ret < 0) { + pr_debug("nf_ct_ras: decoding error: %s\n", + ret == H323_ERROR_BOUND ? + "out of bound" : "out of range"); + goto accept; + } + + /* Process RAS message */ + if (process_ras(skb, ct, ctinfo, protoff, &data, &ras) < 0) + goto drop; + + accept: + spin_unlock_bh(&nf_h323_lock); + return NF_ACCEPT; + + drop: + spin_unlock_bh(&nf_h323_lock); + nf_ct_helper_log(skb, ct, "cannot process RAS message"); + return NF_DROP; +} + +static const struct nf_conntrack_expect_policy ras_exp_policy = { + .max_expected = 32, + .timeout = 240, +}; + +static struct nf_conntrack_helper nf_conntrack_helper_ras[] __read_mostly = { + { + .name = "RAS", + .me = THIS_MODULE, + .tuple.src.l3num = AF_INET, + .tuple.src.u.udp.port = cpu_to_be16(RAS_PORT), + .tuple.dst.protonum = IPPROTO_UDP, + .help = ras_help, + .expect_policy = &ras_exp_policy, + }, + { + .name = "RAS", + .me = THIS_MODULE, + .tuple.src.l3num = AF_INET6, + .tuple.src.u.udp.port = cpu_to_be16(RAS_PORT), + .tuple.dst.protonum = IPPROTO_UDP, + .help = ras_help, + .expect_policy = &ras_exp_policy, + }, +}; + +static int __init h323_helper_init(void) +{ + int ret; + + ret = nf_conntrack_helper_register(&nf_conntrack_helper_h245); + if (ret < 0) + return ret; + ret = nf_conntrack_helpers_register(nf_conntrack_helper_q931, + ARRAY_SIZE(nf_conntrack_helper_q931)); + if (ret < 0) + goto err1; + ret = nf_conntrack_helpers_register(nf_conntrack_helper_ras, + ARRAY_SIZE(nf_conntrack_helper_ras)); + if (ret < 0) + goto err2; + + return 0; +err2: + nf_conntrack_helpers_unregister(nf_conntrack_helper_q931, + ARRAY_SIZE(nf_conntrack_helper_q931)); +err1: + nf_conntrack_helper_unregister(&nf_conntrack_helper_h245); + return ret; +} + +static void __exit h323_helper_exit(void) +{ + nf_conntrack_helpers_unregister(nf_conntrack_helper_ras, + ARRAY_SIZE(nf_conntrack_helper_ras)); + nf_conntrack_helpers_unregister(nf_conntrack_helper_q931, + ARRAY_SIZE(nf_conntrack_helper_q931)); + nf_conntrack_helper_unregister(&nf_conntrack_helper_h245); +} + +static void __exit nf_conntrack_h323_fini(void) +{ + h323_helper_exit(); + kfree(h323_buffer); + pr_debug("nf_ct_h323: fini\n"); +} + +static int __init nf_conntrack_h323_init(void) +{ + int ret; + + NF_CT_HELPER_BUILD_BUG_ON(sizeof(struct nf_ct_h323_master)); + + h323_buffer = kmalloc(H323_MAX_SIZE + 1, GFP_KERNEL); + if (!h323_buffer) + return -ENOMEM; + ret = h323_helper_init(); + if (ret < 0) + goto err1; + pr_debug("nf_ct_h323: init success\n"); + return 0; +err1: + kfree(h323_buffer); + return ret; +} + +module_init(nf_conntrack_h323_init); +module_exit(nf_conntrack_h323_fini); + +MODULE_AUTHOR("Jing Min Zhao "); +MODULE_DESCRIPTION("H.323 connection tracking helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ip_conntrack_h323"); +MODULE_ALIAS_NFCT_HELPER("RAS"); +MODULE_ALIAS_NFCT_HELPER("Q.931"); +MODULE_ALIAS_NFCT_HELPER("H.245"); diff --git a/net/netfilter/nf_conntrack_h323_types.c b/net/netfilter/nf_conntrack_h323_types.c new file mode 100644 index 000000000..fb1cb67a5 --- /dev/null +++ b/net/netfilter/nf_conntrack_h323_types.c @@ -0,0 +1,1921 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Generated by Jing Min Zhao's ASN.1 parser, May 16 2007 + * + * Copyright (c) 2006 Jing Min Zhao + */ + +static const struct field_t _TransportAddress_ipAddress[] = { /* SEQUENCE */ + {FNAME("ip") OCTSTR, FIXD, 4, 0, DECODE, + offsetof(TransportAddress_ipAddress, ip), NULL}, + {FNAME("port") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t 
_TransportAddress_ipSourceRoute_route[] = { /* SEQUENCE OF */ + {FNAME("item") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _TransportAddress_ipSourceRoute_routing[] = { /* CHOICE */ + {FNAME("strict") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("loose") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _TransportAddress_ipSourceRoute[] = { /* SEQUENCE */ + {FNAME("ip") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, + {FNAME("port") INT, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("route") SEQOF, SEMI, 0, 0, SKIP, 0, + _TransportAddress_ipSourceRoute_route}, + {FNAME("routing") CHOICE, 1, 2, 2, SKIP | EXT, 0, + _TransportAddress_ipSourceRoute_routing}, +}; + +static const struct field_t _TransportAddress_ipxAddress[] = { /* SEQUENCE */ + {FNAME("node") OCTSTR, FIXD, 6, 0, SKIP, 0, NULL}, + {FNAME("netnum") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, + {FNAME("port") OCTSTR, FIXD, 2, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _TransportAddress_ip6Address[] = { /* SEQUENCE */ + {FNAME("ip") OCTSTR, FIXD, 16, 0, DECODE, + offsetof(TransportAddress_ip6Address, ip), NULL}, + {FNAME("port") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H221NonStandard[] = { /* SEQUENCE */ + {FNAME("t35CountryCode") INT, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("t35Extension") INT, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("manufacturerCode") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _NonStandardIdentifier[] = { /* CHOICE */ + {FNAME("object") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("h221NonStandard") SEQ, 0, 3, 3, SKIP | EXT, 0, + _H221NonStandard}, +}; + +static const struct field_t _NonStandardParameter[] = { /* SEQUENCE */ + {FNAME("nonStandardIdentifier") CHOICE, 1, 2, 2, SKIP | EXT, 0, + _NonStandardIdentifier}, + {FNAME("data") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _TransportAddress[] = { /* CHOICE */ + {FNAME("ipAddress") SEQ, 0, 2, 2, DECODE, + offsetof(TransportAddress, ipAddress), _TransportAddress_ipAddress}, + {FNAME("ipSourceRoute") SEQ, 0, 4, 4, SKIP | EXT, 0, + _TransportAddress_ipSourceRoute}, + {FNAME("ipxAddress") SEQ, 0, 3, 3, SKIP, 0, + _TransportAddress_ipxAddress}, + {FNAME("ip6Address") SEQ, 0, 2, 2, DECODE | EXT, + offsetof(TransportAddress, ip6Address), + _TransportAddress_ip6Address}, + {FNAME("netBios") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, + {FNAME("nsap") OCTSTR, 5, 1, 0, SKIP, 0, NULL}, + {FNAME("nonStandardAddress") SEQ, 0, 2, 2, SKIP, 0, + _NonStandardParameter}, +}; + +static const struct field_t _AliasAddress[] = { /* CHOICE */ + {FNAME("dialedDigits") NUMDGT, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("h323-ID") BMPSTR, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("url-ID") IA5STR, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("transportID") CHOICE, 3, 7, 7, SKIP | EXT, 0, NULL}, + {FNAME("email-ID") IA5STR, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("partyNumber") CHOICE, 3, 5, 5, SKIP | EXT, 0, NULL}, + {FNAME("mobileUIM") CHOICE, 1, 2, 2, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _Setup_UUIE_sourceAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _VendorIdentifier[] = { /* SEQUENCE */ + {FNAME("vendor") SEQ, 0, 3, 3, SKIP | EXT, 0, _H221NonStandard}, + {FNAME("productId") OCTSTR, BYTE, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("versionId") OCTSTR, BYTE, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _GatekeeperInfo[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + 
_NonStandardParameter}, +}; + +static const struct field_t _H310Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H320Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H321Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H322Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H323Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H324Caps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _VoiceCaps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _T120OnlyCaps[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("dataRatesSupported") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _SupportedProtocols[] = { /* CHOICE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP, 0, + _NonStandardParameter}, + {FNAME("h310") SEQ, 1, 1, 3, SKIP | EXT, 0, _H310Caps}, + {FNAME("h320") SEQ, 1, 1, 3, SKIP | EXT, 0, _H320Caps}, + {FNAME("h321") SEQ, 1, 1, 3, SKIP | EXT, 0, _H321Caps}, + {FNAME("h322") SEQ, 1, 1, 3, SKIP | EXT, 0, _H322Caps}, + {FNAME("h323") SEQ, 1, 1, 3, SKIP | EXT, 0, _H323Caps}, + {FNAME("h324") SEQ, 1, 1, 3, SKIP | EXT, 0, _H324Caps}, + {FNAME("voice") SEQ, 1, 1, 3, SKIP | EXT, 0, _VoiceCaps}, + {FNAME("t120-only") SEQ, 1, 1, 3, SKIP | EXT, 0, _T120OnlyCaps}, + {FNAME("nonStandardProtocol") SEQ, 2, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("t38FaxAnnexbOnly") SEQ, 2, 5, 5, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _GatewayInfo_protocol[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 4, 9, 11, SKIP | EXT, 0, _SupportedProtocols}, +}; + +static const struct field_t _GatewayInfo[] = { /* SEQUENCE */ + {FNAME("protocol") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _GatewayInfo_protocol}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, +}; + +static const struct field_t _McuInfo[] = { /* SEQUENCE */ + {FNAME("nonStandardData") 
SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("protocol") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _TerminalInfo[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, +}; + +static const struct field_t _EndpointType[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("vendor") SEQ, 2, 3, 3, SKIP | EXT | OPT, 0, + _VendorIdentifier}, + {FNAME("gatekeeper") SEQ, 1, 1, 1, SKIP | EXT | OPT, 0, + _GatekeeperInfo}, + {FNAME("gateway") SEQ, 2, 2, 2, SKIP | EXT | OPT, 0, _GatewayInfo}, + {FNAME("mcu") SEQ, 1, 1, 2, SKIP | EXT | OPT, 0, _McuInfo}, + {FNAME("terminal") SEQ, 1, 1, 1, SKIP | EXT | OPT, 0, _TerminalInfo}, + {FNAME("mc") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("undefinedNode") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("set") BITSTR, FIXD, 32, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedTunnelledProtocols") SEQOF, SEMI, 0, 0, SKIP | OPT, + 0, NULL}, +}; + +static const struct field_t _Setup_UUIE_destinationAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _Setup_UUIE_destExtraCallInfo[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _Setup_UUIE_destExtraCRV[] = { /* SEQUENCE OF */ + {FNAME("item") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _Setup_UUIE_conferenceGoal[] = { /* CHOICE */ + {FNAME("create") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("join") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("invite") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("capability-negotiation") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("callIndependentSupplementaryService") NUL, FIXD, 0, 0, SKIP, + 0, NULL}, +}; + +static const struct field_t _Q954Details[] = { /* SEQUENCE */ + {FNAME("conferenceCalling") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("threePartyService") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _QseriesOptions[] = { /* SEQUENCE */ + {FNAME("q932Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q951Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q952Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q953Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q955Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q956Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q957Full") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("q954Info") SEQ, 0, 2, 2, SKIP | EXT, 0, _Q954Details}, +}; + +static const struct field_t _CallType[] = { /* CHOICE */ + {FNAME("pointToPoint") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("oneToN") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("nToOne") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("nToN") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H245_NonStandardIdentifier_h221NonStandard[] = { /* SEQUENCE */ + {FNAME("t35CountryCode") INT, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("t35Extension") INT, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("manufacturerCode") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H245_NonStandardIdentifier[] = { /* CHOICE */ + {FNAME("object") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("h221NonStandard") SEQ, 0, 3, 3, SKIP, 0, + _H245_NonStandardIdentifier_h221NonStandard}, +}; + +static const struct field_t _H245_NonStandardParameter[] = { /* SEQUENCE */ + {FNAME("nonStandardIdentifier") CHOICE, 1, 2, 2, SKIP, 0, + _H245_NonStandardIdentifier}, + 
{FNAME("data") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H261VideoCapability[] = { /* SEQUENCE */ + {FNAME("qcifMPI") INT, 2, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("cifMPI") INT, 2, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("temporalSpatialTradeOffCapability") BOOL, FIXD, 0, 0, SKIP, 0, + NULL}, + {FNAME("maxBitRate") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("stillImageTransmission") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("videoBadMBsCap") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H262VideoCapability[] = { /* SEQUENCE */ + {FNAME("profileAndLevel-SPatML") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-MPatLL") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-MPatML") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-MPatH-14") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-MPatHL") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-SNRatLL") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-SNRatML") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-SpatialatH-14") BOOL, FIXD, 0, 0, SKIP, 0, + NULL}, + {FNAME("profileAndLevel-HPatML") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-HPatH-14") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("profileAndLevel-HPatHL") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("videoBitRate") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("vbvBufferSize") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("samplesPerLine") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("linesPerFrame") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("framesPerSecond") INT, 4, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("luminanceSampleRate") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("videoBadMBsCap") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H263VideoCapability[] = { /* SEQUENCE */ + {FNAME("sqcifMPI") INT, 5, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("qcifMPI") INT, 5, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("cifMPI") INT, 5, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("cif4MPI") INT, 5, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("cif16MPI") INT, 5, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("maxBitRate") INT, CONS, 1, 0, SKIP, 0, NULL}, + {FNAME("unrestrictedVector") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("arithmeticCoding") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("advancedPrediction") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("pbFrames") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("temporalSpatialTradeOffCapability") BOOL, FIXD, 0, 0, SKIP, 0, + NULL}, + {FNAME("hrd-B") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("bppMaxKb") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("slowSqcifMPI") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("slowQcifMPI") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("slowCifMPI") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("slowCif4MPI") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("slowCif16MPI") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("errorCompensation") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("enhancementLayerInfo") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("h263Options") SEQ, 5, 29, 31, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _IS11172VideoCapability[] = { /* SEQUENCE */ + {FNAME("constrainedBitstream") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("videoBitRate") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("vbvBufferSize") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("samplesPerLine") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + 
{FNAME("linesPerFrame") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("pictureRate") INT, 4, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("luminanceSampleRate") INT, CONS, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("videoBadMBsCap") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _VideoCapability[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("h261VideoCapability") SEQ, 2, 5, 6, SKIP | EXT, 0, + _H261VideoCapability}, + {FNAME("h262VideoCapability") SEQ, 6, 17, 18, SKIP | EXT, 0, + _H262VideoCapability}, + {FNAME("h263VideoCapability") SEQ, 7, 13, 21, SKIP | EXT, 0, + _H263VideoCapability}, + {FNAME("is11172VideoCapability") SEQ, 6, 7, 8, SKIP | EXT, 0, + _IS11172VideoCapability}, + {FNAME("genericVideoCapability") SEQ, 5, 6, 6, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _AudioCapability_g7231[] = { /* SEQUENCE */ + {FNAME("maxAl-sduAudioFrames") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("silenceSuppression") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _IS11172AudioCapability[] = { /* SEQUENCE */ + {FNAME("audioLayer1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioLayer2") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioLayer3") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling32k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling44k1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling48k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("singleChannel") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("twoChannels") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("bitRate") INT, WORD, 1, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _IS13818AudioCapability[] = { /* SEQUENCE */ + {FNAME("audioLayer1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioLayer2") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioLayer3") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling16k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling22k05") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling24k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling32k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling44k1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("audioSampling48k") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("singleChannel") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("twoChannels") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("threeChannels2-1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("threeChannels3-0") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fourChannels2-0-2-0") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fourChannels2-2") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fourChannels3-1") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fiveChannels3-0-2-0") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fiveChannels3-2") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("lowFrequencyEnhancement") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("multilingual") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("bitRate") INT, WORD, 1, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _AudioCapability[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("g711Alaw64k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g711Alaw56k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g711Ulaw64k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g711Ulaw56k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g722-64k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g722-56k") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g722-48k") INT, BYTE, 1, 
0, SKIP, 0, NULL}, + {FNAME("g7231") SEQ, 0, 2, 2, SKIP, 0, _AudioCapability_g7231}, + {FNAME("g728") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g729") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g729AnnexA") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("is11172AudioCapability") SEQ, 0, 9, 9, SKIP | EXT, 0, + _IS11172AudioCapability}, + {FNAME("is13818AudioCapability") SEQ, 0, 21, 21, SKIP | EXT, 0, + _IS13818AudioCapability}, + {FNAME("g729wAnnexB") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g729AnnexAwAnnexB") INT, BYTE, 1, 0, SKIP, 0, NULL}, + {FNAME("g7231AnnexCCapability") SEQ, 1, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("gsmFullRate") SEQ, 0, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("gsmHalfRate") SEQ, 0, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("gsmEnhancedFullRate") SEQ, 0, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("genericAudioCapability") SEQ, 5, 6, 6, SKIP | EXT, 0, NULL}, + {FNAME("g729Extensions") SEQ, 1, 8, 8, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _DataProtocolCapability[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("v14buffered") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("v42lapm") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("hdlcFrameTunnelling") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("h310SeparateVCStack") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("h310SingleVCStack") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("transparent") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("segmentationAndReassembly") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("hdlcFrameTunnelingwSAR") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("v120") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("separateLANStack") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("v76wCompression") CHOICE, 2, 3, 3, SKIP | EXT, 0, NULL}, + {FNAME("tcp") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("udp") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _T84Profile_t84Restricted[] = { /* SEQUENCE */ + {FNAME("qcif") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("cif") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("ccir601Seq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("ccir601Prog") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("hdtvSeq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("hdtvProg") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("g3FacsMH200x100") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("g3FacsMH200x200") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("g4FacsMMR200x100") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("g4FacsMMR200x200") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("jbig200x200Seq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("jbig200x200Prog") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("jbig300x300Seq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("jbig300x300Prog") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("digPhotoLow") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("digPhotoMedSeq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("digPhotoMedProg") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("digPhotoHighSeq") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("digPhotoHighProg") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _T84Profile[] = { /* CHOICE */ + {FNAME("t84Unrestricted") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("t84Restricted") SEQ, 0, 19, 19, SKIP | EXT, 0, + _T84Profile_t84Restricted}, +}; + +static const struct field_t _DataApplicationCapability_application_t84[] = { /* SEQUENCE */ + {FNAME("t84Protocol") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("t84Profile") CHOICE, 1, 2, 2, SKIP, 0, _T84Profile}, +}; + +static 
const struct field_t _DataApplicationCapability_application_nlpid[] = { /* SEQUENCE */ + {FNAME("nlpidProtocol") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("nlpidData") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _DataApplicationCapability_application[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("t120") CHOICE, 3, 7, 14, DECODE | EXT, + offsetof(DataApplicationCapability_application, t120), + _DataProtocolCapability}, + {FNAME("dsm-cc") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("userData") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("t84") SEQ, 0, 2, 2, SKIP, 0, + _DataApplicationCapability_application_t84}, + {FNAME("t434") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("h224") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("nlpid") SEQ, 0, 2, 2, SKIP, 0, + _DataApplicationCapability_application_nlpid}, + {FNAME("dsvdControl") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("h222DataPartitioning") CHOICE, 3, 7, 14, SKIP | EXT, 0, + _DataProtocolCapability}, + {FNAME("t30fax") CHOICE, 3, 7, 14, SKIP | EXT, 0, NULL}, + {FNAME("t140") CHOICE, 3, 7, 14, SKIP | EXT, 0, NULL}, + {FNAME("t38fax") SEQ, 0, 2, 2, SKIP, 0, NULL}, + {FNAME("genericDataCapability") SEQ, 5, 6, 6, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _DataApplicationCapability[] = { /* SEQUENCE */ + {FNAME("application") CHOICE, 4, 10, 14, DECODE | EXT, + offsetof(DataApplicationCapability, application), + _DataApplicationCapability_application}, + {FNAME("maxBitRate") INT, CONS, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _EncryptionMode[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("h233Encryption") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _DataType[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("nullData") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("videoData") CHOICE, 3, 5, 6, SKIP | EXT, 0, _VideoCapability}, + {FNAME("audioData") CHOICE, 4, 14, 22, SKIP | EXT, 0, + _AudioCapability}, + {FNAME("data") SEQ, 0, 2, 2, DECODE | EXT, offsetof(DataType, data), + _DataApplicationCapability}, + {FNAME("encryptionData") CHOICE, 1, 2, 2, SKIP | EXT, 0, + _EncryptionMode}, + {FNAME("h235Control") SEQ, 0, 2, 2, SKIP, 0, NULL}, + {FNAME("h235Media") SEQ, 0, 2, 2, SKIP | EXT, 0, NULL}, + {FNAME("multiplexedStream") SEQ, 0, 2, 2, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _H222LogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("resourceID") INT, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("subChannelID") INT, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("pcr-pid") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("programDescriptors") OCTSTR, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("streamDescriptors") OCTSTR, SEMI, 0, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _H223LogicalChannelParameters_adaptationLayerType_al3[] = { /* SEQUENCE */ + {FNAME("controlFieldOctets") INT, 2, 0, 0, SKIP, 0, NULL}, + {FNAME("sendBufferSize") INT, CONS, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H223LogicalChannelParameters_adaptationLayerType[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, + _H245_NonStandardParameter}, + {FNAME("al1Framed") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("al1NotFramed") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + 
{FNAME("al2WithoutSequenceNumbers") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("al2WithSequenceNumbers") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("al3") SEQ, 0, 2, 2, SKIP, 0, + _H223LogicalChannelParameters_adaptationLayerType_al3}, + {FNAME("al1M") SEQ, 0, 7, 8, SKIP | EXT, 0, NULL}, + {FNAME("al2M") SEQ, 0, 2, 2, SKIP | EXT, 0, NULL}, + {FNAME("al3M") SEQ, 0, 5, 6, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _H223LogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("adaptationLayerType") CHOICE, 3, 6, 9, SKIP | EXT, 0, + _H223LogicalChannelParameters_adaptationLayerType}, + {FNAME("segmentableFlag") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CRCLength[] = { /* CHOICE */ + {FNAME("crc8bit") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("crc16bit") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("crc32bit") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V76HDLCParameters[] = { /* SEQUENCE */ + {FNAME("crcLength") CHOICE, 2, 3, 3, SKIP | EXT, 0, _CRCLength}, + {FNAME("n401") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("loopbackTestProcedure") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V76LogicalChannelParameters_suspendResume[] = { /* CHOICE */ + {FNAME("noSuspendResume") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("suspendResumewAddress") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("suspendResumewoAddress") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V76LogicalChannelParameters_mode_eRM_recovery[] = { /* CHOICE */ + {FNAME("rej") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("sREJ") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("mSREJ") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V76LogicalChannelParameters_mode_eRM[] = { /* SEQUENCE */ + {FNAME("windowSize") INT, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("recovery") CHOICE, 2, 3, 3, SKIP | EXT, 0, + _V76LogicalChannelParameters_mode_eRM_recovery}, +}; + +static const struct field_t _V76LogicalChannelParameters_mode[] = { /* CHOICE */ + {FNAME("eRM") SEQ, 0, 2, 2, SKIP | EXT, 0, + _V76LogicalChannelParameters_mode_eRM}, + {FNAME("uNERM") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V75Parameters[] = { /* SEQUENCE */ + {FNAME("audioHeaderPresent") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _V76LogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("hdlcParameters") SEQ, 0, 3, 3, SKIP | EXT, 0, + _V76HDLCParameters}, + {FNAME("suspendResume") CHOICE, 2, 3, 3, SKIP | EXT, 0, + _V76LogicalChannelParameters_suspendResume}, + {FNAME("uIH") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("mode") CHOICE, 1, 2, 2, SKIP | EXT, 0, + _V76LogicalChannelParameters_mode}, + {FNAME("v75Parameters") SEQ, 0, 1, 1, SKIP | EXT, 0, _V75Parameters}, +}; + +static const struct field_t _H2250LogicalChannelParameters_nonStandard[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 0, 2, 2, SKIP, 0, _H245_NonStandardParameter}, +}; + +static const struct field_t _UnicastAddress_iPAddress[] = { /* SEQUENCE */ + {FNAME("network") OCTSTR, FIXD, 4, 0, DECODE, + offsetof(UnicastAddress_iPAddress, network), NULL}, + {FNAME("tsapIdentifier") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _UnicastAddress_iPXAddress[] = { /* SEQUENCE */ + {FNAME("node") OCTSTR, FIXD, 6, 0, SKIP, 0, NULL}, + {FNAME("netnum") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, + {FNAME("tsapIdentifier") OCTSTR, FIXD, 2, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _UnicastAddress_iP6Address[] = { /* SEQUENCE */ + {FNAME("network") 
OCTSTR, FIXD, 16, 0, DECODE, + offsetof(UnicastAddress_iP6Address, network), NULL}, + {FNAME("tsapIdentifier") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _UnicastAddress_iPSourceRouteAddress_routing[] = { /* CHOICE */ + {FNAME("strict") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("loose") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _UnicastAddress_iPSourceRouteAddress_route[] = { /* SEQUENCE OF */ + {FNAME("item") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _UnicastAddress_iPSourceRouteAddress[] = { /* SEQUENCE */ + {FNAME("routing") CHOICE, 1, 2, 2, SKIP, 0, + _UnicastAddress_iPSourceRouteAddress_routing}, + {FNAME("network") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, + {FNAME("tsapIdentifier") INT, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("route") SEQOF, SEMI, 0, 0, SKIP, 0, + _UnicastAddress_iPSourceRouteAddress_route}, +}; + +static const struct field_t _UnicastAddress[] = { /* CHOICE */ + {FNAME("iPAddress") SEQ, 0, 2, 2, DECODE | EXT, + offsetof(UnicastAddress, iPAddress), _UnicastAddress_iPAddress}, + {FNAME("iPXAddress") SEQ, 0, 3, 3, SKIP | EXT, 0, + _UnicastAddress_iPXAddress}, + {FNAME("iP6Address") SEQ, 0, 2, 2, DECODE | EXT, + offsetof(UnicastAddress, iP6Address), _UnicastAddress_iP6Address}, + {FNAME("netBios") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, + {FNAME("iPSourceRouteAddress") SEQ, 0, 4, 4, SKIP | EXT, 0, + _UnicastAddress_iPSourceRouteAddress}, + {FNAME("nsap") OCTSTR, 5, 1, 0, SKIP, 0, NULL}, + {FNAME("nonStandardAddress") SEQ, 0, 2, 2, SKIP, 0, NULL}, +}; + +static const struct field_t _MulticastAddress_iPAddress[] = { /* SEQUENCE */ + {FNAME("network") OCTSTR, FIXD, 4, 0, SKIP, 0, NULL}, + {FNAME("tsapIdentifier") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _MulticastAddress_iP6Address[] = { /* SEQUENCE */ + {FNAME("network") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, + {FNAME("tsapIdentifier") INT, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _MulticastAddress[] = { /* CHOICE */ + {FNAME("iPAddress") SEQ, 0, 2, 2, SKIP | EXT, 0, + _MulticastAddress_iPAddress}, + {FNAME("iP6Address") SEQ, 0, 2, 2, SKIP | EXT, 0, + _MulticastAddress_iP6Address}, + {FNAME("nsap") OCTSTR, 5, 1, 0, SKIP, 0, NULL}, + {FNAME("nonStandardAddress") SEQ, 0, 2, 2, SKIP, 0, NULL}, +}; + +static const struct field_t _H245_TransportAddress[] = { /* CHOICE */ + {FNAME("unicastAddress") CHOICE, 3, 5, 7, DECODE | EXT, + offsetof(H245_TransportAddress, unicastAddress), _UnicastAddress}, + {FNAME("multicastAddress") CHOICE, 1, 2, 4, SKIP | EXT, 0, + _MulticastAddress}, +}; + +static const struct field_t _H2250LogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("nonStandard") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _H2250LogicalChannelParameters_nonStandard}, + {FNAME("sessionID") INT, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("associatedSessionID") INT, 8, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("mediaChannel") CHOICE, 1, 2, 2, DECODE | EXT | OPT, + offsetof(H2250LogicalChannelParameters, mediaChannel), + _H245_TransportAddress}, + {FNAME("mediaGuaranteedDelivery") BOOL, FIXD, 0, 0, SKIP | OPT, 0, + NULL}, + {FNAME("mediaControlChannel") CHOICE, 1, 2, 2, DECODE | EXT | OPT, + offsetof(H2250LogicalChannelParameters, mediaControlChannel), + _H245_TransportAddress}, + {FNAME("mediaControlGuaranteedDelivery") BOOL, FIXD, 0, 0, STOP | OPT, + 0, NULL}, + {FNAME("silenceSuppression") BOOL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destination") SEQ, 0, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("dynamicRTPPayloadType") 
INT, 5, 96, 0, STOP | OPT, 0, NULL}, + {FNAME("mediaPacketization") CHOICE, 0, 1, 2, STOP | EXT | OPT, 0, + NULL}, + {FNAME("transportCapability") SEQ, 3, 3, 3, STOP | EXT | OPT, 0, + NULL}, + {FNAME("redundancyEncoding") SEQ, 1, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("source") SEQ, 0, 2, 2, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _OpenLogicalChannel_forwardLogicalChannelParameters_multiplexParameters[] = { /* CHOICE */ + {FNAME("h222LogicalChannelParameters") SEQ, 3, 5, 5, SKIP | EXT, 0, + _H222LogicalChannelParameters}, + {FNAME("h223LogicalChannelParameters") SEQ, 0, 2, 2, SKIP | EXT, 0, + _H223LogicalChannelParameters}, + {FNAME("v76LogicalChannelParameters") SEQ, 0, 5, 5, SKIP | EXT, 0, + _V76LogicalChannelParameters}, + {FNAME("h2250LogicalChannelParameters") SEQ, 10, 11, 14, DECODE | EXT, + offsetof + (OpenLogicalChannel_forwardLogicalChannelParameters_multiplexParameters, + h2250LogicalChannelParameters), _H2250LogicalChannelParameters}, + {FNAME("none") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _OpenLogicalChannel_forwardLogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("portNumber") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("dataType") CHOICE, 3, 6, 9, DECODE | EXT, + offsetof(OpenLogicalChannel_forwardLogicalChannelParameters, + dataType), _DataType}, + {FNAME("multiplexParameters") CHOICE, 2, 3, 5, DECODE | EXT, + offsetof(OpenLogicalChannel_forwardLogicalChannelParameters, + multiplexParameters), + _OpenLogicalChannel_forwardLogicalChannelParameters_multiplexParameters}, + {FNAME("forwardLogicalChannelDependency") INT, WORD, 1, 0, SKIP | OPT, + 0, NULL}, + {FNAME("replacementFor") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _OpenLogicalChannel_reverseLogicalChannelParameters_multiplexParameters[] = { /* CHOICE */ + {FNAME("h223LogicalChannelParameters") SEQ, 0, 2, 2, SKIP | EXT, 0, + _H223LogicalChannelParameters}, + {FNAME("v76LogicalChannelParameters") SEQ, 0, 5, 5, SKIP | EXT, 0, + _V76LogicalChannelParameters}, + {FNAME("h2250LogicalChannelParameters") SEQ, 10, 11, 14, DECODE | EXT, + offsetof + (OpenLogicalChannel_reverseLogicalChannelParameters_multiplexParameters, + h2250LogicalChannelParameters), _H2250LogicalChannelParameters}, +}; + +static const struct field_t _OpenLogicalChannel_reverseLogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("dataType") CHOICE, 3, 6, 9, SKIP | EXT, 0, _DataType}, + {FNAME("multiplexParameters") CHOICE, 1, 2, 3, DECODE | EXT | OPT, + offsetof(OpenLogicalChannel_reverseLogicalChannelParameters, + multiplexParameters), + _OpenLogicalChannel_reverseLogicalChannelParameters_multiplexParameters}, + {FNAME("reverseLogicalChannelDependency") INT, WORD, 1, 0, SKIP | OPT, + 0, NULL}, + {FNAME("replacementFor") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _NetworkAccessParameters_distribution[] = { /* CHOICE */ + {FNAME("unicast") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("multicast") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _Q2931Address_address[] = { /* CHOICE */ + {FNAME("internationalNumber") NUMSTR, 4, 1, 0, SKIP, 0, NULL}, + {FNAME("nsapAddress") OCTSTR, 5, 1, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _Q2931Address[] = { /* SEQUENCE */ + {FNAME("address") CHOICE, 1, 2, 2, SKIP | EXT, 0, + _Q2931Address_address}, + {FNAME("subaddress") OCTSTR, 5, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _NetworkAccessParameters_networkAddress[] = { /* CHOICE */ + 
{FNAME("q2931Address") SEQ, 1, 2, 2, SKIP | EXT, 0, _Q2931Address}, + {FNAME("e164Address") NUMDGT, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("localAreaAddress") CHOICE, 1, 2, 2, DECODE | EXT, + offsetof(NetworkAccessParameters_networkAddress, localAreaAddress), + _H245_TransportAddress}, +}; + +static const struct field_t _NetworkAccessParameters[] = { /* SEQUENCE */ + {FNAME("distribution") CHOICE, 1, 2, 2, SKIP | EXT | OPT, 0, + _NetworkAccessParameters_distribution}, + {FNAME("networkAddress") CHOICE, 2, 3, 3, DECODE | EXT, + offsetof(NetworkAccessParameters, networkAddress), + _NetworkAccessParameters_networkAddress}, + {FNAME("associateConference") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("externalReference") OCTSTR, 8, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("t120SetupProcedure") CHOICE, 2, 3, 3, SKIP | EXT | OPT, 0, + NULL}, +}; + +static const struct field_t _OpenLogicalChannel[] = { /* SEQUENCE */ + {FNAME("forwardLogicalChannelNumber") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("forwardLogicalChannelParameters") SEQ, 1, 3, 5, DECODE | EXT, + offsetof(OpenLogicalChannel, forwardLogicalChannelParameters), + _OpenLogicalChannel_forwardLogicalChannelParameters}, + {FNAME("reverseLogicalChannelParameters") SEQ, 1, 2, 4, + DECODE | EXT | OPT, offsetof(OpenLogicalChannel, + reverseLogicalChannelParameters), + _OpenLogicalChannel_reverseLogicalChannelParameters}, + {FNAME("separateStack") SEQ, 2, 4, 5, DECODE | EXT | OPT, + offsetof(OpenLogicalChannel, separateStack), + _NetworkAccessParameters}, + {FNAME("encryptionSync") SEQ, 2, 4, 4, STOP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _Setup_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _Setup_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("h245Address") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Setup_UUIE, h245Address), _TransportAddress}, + {FNAME("sourceAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Setup_UUIE_sourceAddress}, + {FNAME("sourceInfo") SEQ, 6, 8, 10, SKIP | EXT, 0, _EndpointType}, + {FNAME("destinationAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Setup_UUIE_destinationAddress}, + {FNAME("destCallSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Setup_UUIE, destCallSignalAddress), _TransportAddress}, + {FNAME("destExtraCallInfo") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Setup_UUIE_destExtraCallInfo}, + {FNAME("destExtraCRV") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Setup_UUIE_destExtraCRV}, + {FNAME("activeMC") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("conferenceID") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, + {FNAME("conferenceGoal") CHOICE, 2, 3, 5, SKIP | EXT, 0, + _Setup_UUIE_conferenceGoal}, + {FNAME("callServices") SEQ, 0, 8, 8, SKIP | EXT | OPT, 0, + _QseriesOptions}, + {FNAME("callType") CHOICE, 2, 4, 4, SKIP | EXT, 0, _CallType}, + {FNAME("sourceCallSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Setup_UUIE, sourceCallSignalAddress), _TransportAddress}, + {FNAME("remoteExtensionAddress") CHOICE, 1, 2, 7, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("h245SecurityCapability") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(Setup_UUIE, fastStart), 
_Setup_UUIE_fastStart}, + {FNAME("mediaWaitForConnect") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("canOverlapSend") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("connectionParameters") SEQ, 0, 3, 3, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("language") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("presentationIndicator") CHOICE, 2, 3, 3, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("screeningIndicator") ENUM, 2, 0, 0, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("symmetricOperationRequired") NUL, FIXD, 0, 0, SKIP | OPT, 0, + NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, SKIP | EXT | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, SKIP | EXT | OPT, 0, NULL}, + {FNAME("desiredProtocols") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("neededFeatures") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("desiredFeatures") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("supportedFeatures") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("parallelH245Control") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("additionalSourceAddresses") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + NULL}, +}; + +static const struct field_t _CallProceeding_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _CallProceeding_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("destinationInfo") SEQ, 6, 8, 10, SKIP | EXT, 0, + _EndpointType}, + {FNAME("h245Address") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(CallProceeding_UUIE, h245Address), _TransportAddress}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("h245SecurityMode") CHOICE, 2, 4, 4, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(CallProceeding_UUIE, fastStart), + _CallProceeding_UUIE_fastStart}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _Connect_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _Connect_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("h245Address") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Connect_UUIE, h245Address), _TransportAddress}, + {FNAME("destinationInfo") SEQ, 6, 8, 10, SKIP | EXT, 0, + _EndpointType}, + {FNAME("conferenceID") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("h245SecurityMode") CHOICE, 2, 4, 4, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(Connect_UUIE, fastStart), _Connect_UUIE_fastStart}, + {FNAME("multipleCalls") BOOL, 
FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("language") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("connectedAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("presentationIndicator") CHOICE, 2, 3, 3, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("screeningIndicator") ENUM, 2, 0, 0, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, SKIP | EXT | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _Alerting_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _Alerting_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("destinationInfo") SEQ, 6, 8, 10, SKIP | EXT, 0, + _EndpointType}, + {FNAME("h245Address") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Alerting_UUIE, h245Address), _TransportAddress}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("h245SecurityMode") CHOICE, 2, 4, 4, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(Alerting_UUIE, fastStart), _Alerting_UUIE_fastStart}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("alertingAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("presentationIndicator") CHOICE, 2, 3, 3, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("screeningIndicator") ENUM, 2, 0, 0, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, SKIP | EXT | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _Information_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, SKIP | OPT, 0, NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _ReleaseCompleteReason[] = { /* CHOICE */ + {FNAME("noBandwidth") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("gatekeeperResources") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("unreachableDestination") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("destinationRejection") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("invalidRevision") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("noPermission") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("unreachableGatekeeper") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("gatewayResources") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("badFormatAddress") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("adaptiveBusy") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("inConf") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("undefinedReason") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("facilityCallDeflection") NUL, 
FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("securityDenied") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("calledPartyNotRegistered") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("callerNotRegistered") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("newConnectionNeeded") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("nonStandardReason") SEQ, 0, 2, 2, SKIP, 0, NULL}, + {FNAME("replaceWithConferenceInvite") OCTSTR, FIXD, 16, 0, SKIP, 0, + NULL}, + {FNAME("genericDataReason") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("neededFeatureNotSupported") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("tunnelledSignallingRejected") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _ReleaseComplete_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("reason") CHOICE, 4, 12, 22, SKIP | EXT | OPT, 0, + _ReleaseCompleteReason}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("busyAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("presentationIndicator") CHOICE, 2, 3, 3, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("screeningIndicator") ENUM, 2, 0, 0, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, SKIP | EXT | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _Facility_UUIE_alternativeAliasAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _FacilityReason[] = { /* CHOICE */ + {FNAME("routeCallToGatekeeper") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("callForwarded") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("routeCallToMC") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("undefinedReason") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("conferenceListChoice") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("startH245") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("noH245") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("newTokens") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("featureSetUpdate") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("forwardedElements") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("transportedInformation") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _Facility_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _Facility_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("alternativeAddress") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Facility_UUIE, alternativeAddress), _TransportAddress}, + {FNAME("alternativeAliasAddress") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Facility_UUIE_alternativeAliasAddress}, + {FNAME("conferenceID") OCTSTR, FIXD, 16, 0, SKIP | OPT, 0, NULL}, + {FNAME("reason") CHOICE, 2, 4, 11, DECODE | EXT, + offsetof(Facility_UUIE, reason), _FacilityReason}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, NULL}, + {FNAME("destExtraCallInfo") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("remoteExtensionAddress") CHOICE, 1, 2, 7, SKIP | EXT | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("conferences") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("h245Address") CHOICE, 
3, 7, 7, DECODE | EXT | OPT, + offsetof(Facility_UUIE, h245Address), _TransportAddress}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(Facility_UUIE, fastStart), _Facility_UUIE_fastStart}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, SKIP | EXT | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, SKIP | EXT | OPT, 0, NULL}, + {FNAME("destinationInfo") SEQ, 6, 8, 10, SKIP | EXT | OPT, 0, NULL}, + {FNAME("h245SecurityMode") CHOICE, 2, 4, 4, SKIP | EXT | OPT, 0, + NULL}, +}; + +static const struct field_t _CallIdentifier[] = { /* SEQUENCE */ + {FNAME("guid") OCTSTR, FIXD, 16, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _SecurityServiceMode[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, _NonStandardParameter}, + {FNAME("none") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("default") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _SecurityCapabilities[] = { /* SEQUENCE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("encryption") CHOICE, 2, 3, 3, SKIP | EXT, 0, + _SecurityServiceMode}, + {FNAME("authenticaton") CHOICE, 2, 3, 3, SKIP | EXT, 0, + _SecurityServiceMode}, + {FNAME("integrity") CHOICE, 2, 3, 3, SKIP | EXT, 0, + _SecurityServiceMode}, +}; + +static const struct field_t _H245Security[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP, 0, _NonStandardParameter}, + {FNAME("noSecurity") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("tls") SEQ, 1, 4, 4, SKIP | EXT, 0, _SecurityCapabilities}, + {FNAME("ipsec") SEQ, 1, 4, 4, SKIP | EXT, 0, _SecurityCapabilities}, +}; + +static const struct field_t _DHset[] = { /* SEQUENCE */ + {FNAME("halfkey") BITSTR, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("modSize") BITSTR, WORD, 0, 0, SKIP, 0, NULL}, + {FNAME("generator") BITSTR, WORD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _TypedCertificate[] = { /* SEQUENCE */ + {FNAME("type") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("certificate") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _H235_NonStandardParameter[] = { /* SEQUENCE */ + {FNAME("nonStandardIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("data") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _ClearToken[] = { /* SEQUENCE */ + {FNAME("tokenOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("timeStamp") INT, CONS, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("password") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("dhkey") SEQ, 0, 3, 3, SKIP | EXT | OPT, 0, _DHset}, + {FNAME("challenge") OCTSTR, 7, 8, 0, SKIP | OPT, 0, NULL}, + {FNAME("random") INT, UNCO, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("certificate") SEQ, 0, 2, 2, SKIP | EXT | OPT, 0, + _TypedCertificate}, + {FNAME("generalID") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("nonStandard") SEQ, 0, 2, 2, SKIP | OPT, 0, + _H235_NonStandardParameter}, + {FNAME("eckasdhkey") CHOICE, 1, 2, 2, SKIP | EXT | OPT, 0, NULL}, + {FNAME("sendersID") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _Progress_UUIE_tokens[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 8, 9, 11, SKIP | EXT, 0, _ClearToken}, +}; + +static const struct field_t _Params[] = { /* SEQUENCE */ + {FNAME("ranInt") INT, UNCO, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("iv8") OCTSTR, FIXD, 
8, 0, SKIP | OPT, 0, NULL}, + {FNAME("iv16") OCTSTR, FIXD, 16, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoEPPwdHash_token[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("hash") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoEPPwdHash[] = { /* SEQUENCE */ + {FNAME("alias") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, + {FNAME("timeStamp") INT, CONS, 1, 0, SKIP, 0, NULL}, + {FNAME("token") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoEPPwdHash_token}, +}; + +static const struct field_t _CryptoH323Token_cryptoGKPwdHash_token[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("hash") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoGKPwdHash[] = { /* SEQUENCE */ + {FNAME("gatekeeperId") BMPSTR, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("timeStamp") INT, CONS, 1, 0, SKIP, 0, NULL}, + {FNAME("token") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoGKPwdHash_token}, +}; + +static const struct field_t _CryptoH323Token_cryptoEPPwdEncr[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("encryptedData") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoGKPwdEncr[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("encryptedData") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoEPCert[] = { /* SEQUENCE */ + {FNAME("toBeSigned") SEQ, 8, 9, 11, SKIP | OPEN | EXT, 0, NULL}, + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("signature") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoGKCert[] = { /* SEQUENCE */ + {FNAME("toBeSigned") SEQ, 8, 9, 11, SKIP | OPEN | EXT, 0, NULL}, + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("signature") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoH323Token_cryptoFastStart[] = { /* SEQUENCE */ + {FNAME("toBeSigned") SEQ, 8, 9, 11, SKIP | OPEN | EXT, 0, NULL}, + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("signature") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoToken_cryptoEncryptedToken_token[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("encryptedData") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoToken_cryptoEncryptedToken[] = { /* SEQUENCE */ + {FNAME("tokenOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("token") SEQ, 0, 3, 3, SKIP, 0, + _CryptoToken_cryptoEncryptedToken_token}, +}; + +static const struct field_t _CryptoToken_cryptoSignedToken_token[] = { /* SEQUENCE */ + {FNAME("toBeSigned") SEQ, 8, 9, 11, SKIP | OPEN | EXT, 0, NULL}, + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("signature") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; 
+ +static const struct field_t _CryptoToken_cryptoSignedToken[] = { /* SEQUENCE */ + {FNAME("tokenOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("token") SEQ, 0, 4, 4, SKIP, 0, + _CryptoToken_cryptoSignedToken_token}, +}; + +static const struct field_t _CryptoToken_cryptoHashedToken_token[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("hash") BITSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoToken_cryptoHashedToken[] = { /* SEQUENCE */ + {FNAME("tokenOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("hashedVals") SEQ, 8, 9, 11, SKIP | EXT, 0, _ClearToken}, + {FNAME("token") SEQ, 0, 3, 3, SKIP, 0, + _CryptoToken_cryptoHashedToken_token}, +}; + +static const struct field_t _CryptoToken_cryptoPwdEncr[] = { /* SEQUENCE */ + {FNAME("algorithmOID") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("paramS") SEQ, 2, 2, 3, SKIP | EXT, 0, _Params}, + {FNAME("encryptedData") OCTSTR, SEMI, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _CryptoToken[] = { /* CHOICE */ + {FNAME("cryptoEncryptedToken") SEQ, 0, 2, 2, SKIP, 0, + _CryptoToken_cryptoEncryptedToken}, + {FNAME("cryptoSignedToken") SEQ, 0, 2, 2, SKIP, 0, + _CryptoToken_cryptoSignedToken}, + {FNAME("cryptoHashedToken") SEQ, 0, 3, 3, SKIP, 0, + _CryptoToken_cryptoHashedToken}, + {FNAME("cryptoPwdEncr") SEQ, 0, 3, 3, SKIP, 0, + _CryptoToken_cryptoPwdEncr}, +}; + +static const struct field_t _CryptoH323Token[] = { /* CHOICE */ + {FNAME("cryptoEPPwdHash") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoEPPwdHash}, + {FNAME("cryptoGKPwdHash") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoGKPwdHash}, + {FNAME("cryptoEPPwdEncr") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoEPPwdEncr}, + {FNAME("cryptoGKPwdEncr") SEQ, 0, 3, 3, SKIP, 0, + _CryptoH323Token_cryptoGKPwdEncr}, + {FNAME("cryptoEPCert") SEQ, 0, 4, 4, SKIP, 0, + _CryptoH323Token_cryptoEPCert}, + {FNAME("cryptoGKCert") SEQ, 0, 4, 4, SKIP, 0, + _CryptoH323Token_cryptoGKCert}, + {FNAME("cryptoFastStart") SEQ, 0, 4, 4, SKIP, 0, + _CryptoH323Token_cryptoFastStart}, + {FNAME("nestedcryptoToken") CHOICE, 2, 4, 4, SKIP | EXT, 0, + _CryptoToken}, +}; + +static const struct field_t _Progress_UUIE_cryptoTokens[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 8, 8, SKIP | EXT, 0, _CryptoH323Token}, +}; + +static const struct field_t _Progress_UUIE_fastStart[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 1, 3, 5, DECODE | OPEN | EXT, + sizeof(OpenLogicalChannel), _OpenLogicalChannel} + , +}; + +static const struct field_t _Progress_UUIE[] = { /* SEQUENCE */ + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("destinationInfo") SEQ, 6, 8, 10, SKIP | EXT, 0, + _EndpointType}, + {FNAME("h245Address") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(Progress_UUIE, h245Address), _TransportAddress}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, SKIP | EXT, 0, + _CallIdentifier}, + {FNAME("h245SecurityMode") CHOICE, 2, 4, 4, SKIP | EXT | OPT, 0, + _H245Security}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Progress_UUIE_tokens}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _Progress_UUIE_cryptoTokens}, + {FNAME("fastStart") SEQOF, SEMI, 0, 30, DECODE | OPT, + offsetof(Progress_UUIE, fastStart), _Progress_UUIE_fastStart}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("fastConnectRefused") NUL, FIXD, 0, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct 
field_t _H323_UU_PDU_h323_message_body[] = { /* CHOICE */ + {FNAME("setup") SEQ, 7, 13, 39, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, setup), _Setup_UUIE}, + {FNAME("callProceeding") SEQ, 1, 3, 12, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, callProceeding), + _CallProceeding_UUIE}, + {FNAME("connect") SEQ, 1, 4, 19, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, connect), _Connect_UUIE}, + {FNAME("alerting") SEQ, 1, 3, 17, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, alerting), _Alerting_UUIE}, + {FNAME("information") SEQ, 0, 1, 7, SKIP | EXT, 0, _Information_UUIE}, + {FNAME("releaseComplete") SEQ, 1, 2, 11, SKIP | EXT, 0, + _ReleaseComplete_UUIE}, + {FNAME("facility") SEQ, 3, 5, 21, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, facility), _Facility_UUIE}, + {FNAME("progress") SEQ, 5, 8, 11, DECODE | EXT, + offsetof(H323_UU_PDU_h323_message_body, progress), _Progress_UUIE}, + {FNAME("empty") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("status") SEQ, 2, 4, 4, SKIP | EXT, 0, NULL}, + {FNAME("statusInquiry") SEQ, 2, 4, 4, SKIP | EXT, 0, NULL}, + {FNAME("setupAcknowledge") SEQ, 2, 4, 4, SKIP | EXT, 0, NULL}, + {FNAME("notify") SEQ, 2, 4, 4, SKIP | EXT, 0, NULL}, +}; + +static const struct field_t _RequestMessage[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("masterSlaveDetermination") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("terminalCapabilitySet") SEQ, 3, 5, 5, STOP | EXT, 0, NULL}, + {FNAME("openLogicalChannel") SEQ, 1, 3, 5, DECODE | EXT, + offsetof(RequestMessage, openLogicalChannel), _OpenLogicalChannel}, + {FNAME("closeLogicalChannel") SEQ, 0, 2, 3, STOP | EXT, 0, NULL}, + {FNAME("requestChannelClose") SEQ, 0, 1, 3, STOP | EXT, 0, NULL}, + {FNAME("multiplexEntrySend") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("requestMultiplexEntry") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("requestMode") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("roundTripDelayRequest") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("maintenanceLoopRequest") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("communicationModeRequest") SEQ, 0, 0, 0, STOP | EXT, 0, NULL}, + {FNAME("conferenceRequest") CHOICE, 3, 8, 16, STOP | EXT, 0, NULL}, + {FNAME("multilinkRequest") CHOICE, 3, 5, 5, STOP | EXT, 0, NULL}, + {FNAME("logicalChannelRateRequest") SEQ, 0, 3, 3, STOP | EXT, 0, + NULL}, +}; + +static const struct field_t _OpenLogicalChannelAck_reverseLogicalChannelParameters_multiplexParameters[] = { /* CHOICE */ + {FNAME("h222LogicalChannelParameters") SEQ, 3, 5, 5, SKIP | EXT, 0, + _H222LogicalChannelParameters}, + {FNAME("h2250LogicalChannelParameters") SEQ, 10, 11, 14, DECODE | EXT, + offsetof + (OpenLogicalChannelAck_reverseLogicalChannelParameters_multiplexParameters, + h2250LogicalChannelParameters), _H2250LogicalChannelParameters}, +}; + +static const struct field_t _OpenLogicalChannelAck_reverseLogicalChannelParameters[] = { /* SEQUENCE */ + {FNAME("reverseLogicalChannelNumber") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("portNumber") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("multiplexParameters") CHOICE, 0, 1, 2, DECODE | EXT | OPT, + offsetof(OpenLogicalChannelAck_reverseLogicalChannelParameters, + multiplexParameters), + _OpenLogicalChannelAck_reverseLogicalChannelParameters_multiplexParameters}, + {FNAME("replacementFor") INT, WORD, 1, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _H2250LogicalChannelAckParameters_nonStandard[] = { /* SEQUENCE OF */ + {FNAME("item") SEQ, 0, 2, 2, SKIP, 0, 
_H245_NonStandardParameter}, +}; + +static const struct field_t _H2250LogicalChannelAckParameters[] = { /* SEQUENCE */ + {FNAME("nonStandard") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _H2250LogicalChannelAckParameters_nonStandard}, + {FNAME("sessionID") INT, 8, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("mediaChannel") CHOICE, 1, 2, 2, DECODE | EXT | OPT, + offsetof(H2250LogicalChannelAckParameters, mediaChannel), + _H245_TransportAddress}, + {FNAME("mediaControlChannel") CHOICE, 1, 2, 2, DECODE | EXT | OPT, + offsetof(H2250LogicalChannelAckParameters, mediaControlChannel), + _H245_TransportAddress}, + {FNAME("dynamicRTPPayloadType") INT, 5, 96, 0, SKIP | OPT, 0, NULL}, + {FNAME("flowControlToZero") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("portNumber") INT, WORD, 0, 0, SKIP | OPT, 0, NULL}, +}; + +static const struct field_t _OpenLogicalChannelAck_forwardMultiplexAckParameters[] = { /* CHOICE */ + {FNAME("h2250LogicalChannelAckParameters") SEQ, 5, 5, 7, DECODE | EXT, + offsetof(OpenLogicalChannelAck_forwardMultiplexAckParameters, + h2250LogicalChannelAckParameters), + _H2250LogicalChannelAckParameters}, +}; + +static const struct field_t _OpenLogicalChannelAck[] = { /* SEQUENCE */ + {FNAME("forwardLogicalChannelNumber") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("reverseLogicalChannelParameters") SEQ, 2, 3, 4, + DECODE | EXT | OPT, offsetof(OpenLogicalChannelAck, + reverseLogicalChannelParameters), + _OpenLogicalChannelAck_reverseLogicalChannelParameters}, + {FNAME("separateStack") SEQ, 2, 4, 5, DECODE | EXT | OPT, + offsetof(OpenLogicalChannelAck, separateStack), + _NetworkAccessParameters}, + {FNAME("forwardMultiplexAckParameters") CHOICE, 0, 1, 1, + DECODE | EXT | OPT, offsetof(OpenLogicalChannelAck, + forwardMultiplexAckParameters), + _OpenLogicalChannelAck_forwardMultiplexAckParameters}, + {FNAME("encryptionSync") SEQ, 2, 4, 4, STOP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _ResponseMessage[] = { /* CHOICE */ + {FNAME("nonStandard") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("masterSlaveDeterminationAck") SEQ, 0, 1, 1, STOP | EXT, 0, + NULL}, + {FNAME("masterSlaveDeterminationReject") SEQ, 0, 1, 1, STOP | EXT, 0, + NULL}, + {FNAME("terminalCapabilitySetAck") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("terminalCapabilitySetReject") SEQ, 0, 2, 2, STOP | EXT, 0, + NULL}, + {FNAME("openLogicalChannelAck") SEQ, 1, 2, 5, DECODE | EXT, + offsetof(ResponseMessage, openLogicalChannelAck), + _OpenLogicalChannelAck}, + {FNAME("openLogicalChannelReject") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("closeLogicalChannelAck") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("requestChannelCloseAck") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("requestChannelCloseReject") SEQ, 0, 2, 2, STOP | EXT, 0, + NULL}, + {FNAME("multiplexEntrySendAck") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("multiplexEntrySendReject") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("requestMultiplexEntryAck") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("requestMultiplexEntryReject") SEQ, 0, 2, 2, STOP | EXT, 0, + NULL}, + {FNAME("requestModeAck") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("requestModeReject") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("roundTripDelayResponse") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("maintenanceLoopAck") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("maintenanceLoopReject") SEQ, 0, 2, 2, STOP | EXT, 0, NULL}, + {FNAME("communicationModeResponse") CHOICE, 0, 1, 1, STOP | EXT, 0, + NULL}, + {FNAME("conferenceResponse") CHOICE, 3, 8, 16, STOP | EXT, 0, NULL}, + 
{FNAME("multilinkResponse") CHOICE, 3, 5, 5, STOP | EXT, 0, NULL}, + {FNAME("logicalChannelRateAcknowledge") SEQ, 0, 3, 3, STOP | EXT, 0, + NULL}, + {FNAME("logicalChannelRateReject") SEQ, 1, 4, 4, STOP | EXT, 0, NULL}, +}; + +static const struct field_t _MultimediaSystemControlMessage[] = { /* CHOICE */ + {FNAME("request") CHOICE, 4, 11, 15, DECODE | EXT, + offsetof(MultimediaSystemControlMessage, request), _RequestMessage}, + {FNAME("response") CHOICE, 5, 19, 24, DECODE | EXT, + offsetof(MultimediaSystemControlMessage, response), + _ResponseMessage}, + {FNAME("command") CHOICE, 3, 7, 12, STOP | EXT, 0, NULL}, + {FNAME("indication") CHOICE, 4, 14, 23, STOP | EXT, 0, NULL}, +}; + +static const struct field_t _H323_UU_PDU_h245Control[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 2, 4, 4, DECODE | OPEN | EXT, + sizeof(MultimediaSystemControlMessage), + _MultimediaSystemControlMessage} + , +}; + +static const struct field_t _H323_UU_PDU[] = { /* SEQUENCE */ + {FNAME("h323-message-body") CHOICE, 3, 7, 13, DECODE | EXT, + offsetof(H323_UU_PDU, h323_message_body), + _H323_UU_PDU_h323_message_body}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("h4501SupplementaryService") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + NULL}, + {FNAME("h245Tunneling") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("h245Control") SEQOF, SEMI, 0, 4, DECODE | OPT, + offsetof(H323_UU_PDU, h245Control), _H323_UU_PDU_h245Control}, + {FNAME("nonStandardControl") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("callLinkage") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("tunnelledSignallingMessage") SEQ, 2, 4, 4, STOP | EXT | OPT, + 0, NULL}, + {FNAME("provisionalRespToH245Tunneling") NUL, FIXD, 0, 0, STOP | OPT, + 0, NULL}, + {FNAME("stimulusControl") SEQ, 3, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _H323_UserInformation[] = { /* SEQUENCE */ + {FNAME("h323-uu-pdu") SEQ, 1, 2, 11, DECODE | EXT, + offsetof(H323_UserInformation, h323_uu_pdu), _H323_UU_PDU}, + {FNAME("user-data") SEQ, 0, 2, 2, STOP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _GatekeeperRequest[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("rasAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(GatekeeperRequest, rasAddress), _TransportAddress}, + {FNAME("endpointType") SEQ, 6, 8, 10, STOP | EXT, 0, NULL}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("callServices") SEQ, 0, 8, 8, STOP | EXT | OPT, 0, NULL}, + {FNAME("endpointAlias") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("alternateEndpoints") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("authenticationCapability") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("algorithmOIDs") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrity") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("supportsAltGK") NUL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _GatekeeperConfirm[] = { 
/* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("rasAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(GatekeeperConfirm, rasAddress), _TransportAddress}, + {FNAME("alternateGatekeeper") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("authenticationMode") CHOICE, 3, 7, 8, STOP | EXT | OPT, 0, + NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("algorithmOID") OID, BYTE, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrity") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _RegistrationRequest_callSignalAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 7, 7, DECODE | EXT, + sizeof(TransportAddress), _TransportAddress} + , +}; + +static const struct field_t _RegistrationRequest_rasAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 7, 7, DECODE | EXT, + sizeof(TransportAddress), _TransportAddress} + , +}; + +static const struct field_t _RegistrationRequest_terminalAlias[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _RegistrationRequest[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("discoveryComplete") BOOL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("callSignalAddress") SEQOF, SEMI, 0, 10, DECODE, + offsetof(RegistrationRequest, callSignalAddress), + _RegistrationRequest_callSignalAddress}, + {FNAME("rasAddress") SEQOF, SEMI, 0, 10, DECODE, + offsetof(RegistrationRequest, rasAddress), + _RegistrationRequest_rasAddress}, + {FNAME("terminalType") SEQ, 6, 8, 10, SKIP | EXT, 0, _EndpointType}, + {FNAME("terminalAlias") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _RegistrationRequest_terminalAlias}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("endpointVendor") SEQ, 2, 3, 3, SKIP | EXT, 0, + _VendorIdentifier}, + {FNAME("alternateEndpoints") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("timeToLive") INT, CONS, 1, 0, DECODE | OPT, + offsetof(RegistrationRequest, timeToLive), NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("keepAlive") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("willSupplyUUIEs") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("alternateTransportAddresses") SEQ, 1, 1, 1, STOP | EXT | OPT, + 0, NULL}, + {FNAME("additiveRegistration") NUL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("terminalAliasPattern") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("supportsAltGK") NUL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("usageReportingCapability") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, + NULL}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 
0, STOP | OPT, 0, NULL}, + {FNAME("supportedH248Packages") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("callCreditCapability") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, + NULL}, + {FNAME("capacityReportingCapability") SEQ, 0, 1, 1, STOP | EXT | OPT, + 0, NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _RegistrationConfirm_callSignalAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 7, 7, DECODE | EXT, + sizeof(TransportAddress), _TransportAddress} + , +}; + +static const struct field_t _RegistrationConfirm_terminalAlias[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _RegistrationConfirm[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("protocolIdentifier") OID, BYTE, 0, 0, SKIP, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("callSignalAddress") SEQOF, SEMI, 0, 10, DECODE, + offsetof(RegistrationConfirm, callSignalAddress), + _RegistrationConfirm_callSignalAddress}, + {FNAME("terminalAlias") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _RegistrationConfirm_terminalAlias}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("alternateGatekeeper") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, NULL}, + {FNAME("timeToLive") INT, CONS, 1, 0, DECODE | OPT, + offsetof(RegistrationConfirm, timeToLive), NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("willRespondToIRR") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("preGrantedARQ") SEQ, 0, 4, 8, STOP | EXT | OPT, 0, NULL}, + {FNAME("maintainConnection") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("supportsAdditiveRegistration") NUL, FIXD, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("terminalAliasPattern") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("usageSpec") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("featureServerAlias") CHOICE, 1, 2, 7, STOP | EXT | OPT, 0, + NULL}, + {FNAME("capacityReportingSpec") SEQ, 0, 1, 1, STOP | EXT | OPT, 0, + NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _UnregistrationRequest_callSignalAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 7, 7, DECODE | EXT, + sizeof(TransportAddress), _TransportAddress} + , +}; + +static const struct field_t _UnregistrationRequest[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("callSignalAddress") SEQOF, SEMI, 0, 10, DECODE, + offsetof(UnregistrationRequest, callSignalAddress), + _UnregistrationRequest_callSignalAddress}, + {FNAME("endpointAlias") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("alternateEndpoints") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + 
{FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("reason") CHOICE, 2, 4, 5, STOP | EXT | OPT, 0, NULL}, + {FNAME("endpointAliasPattern") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("supportedPrefixes") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("alternateGatekeeper") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _CallModel[] = { /* CHOICE */ + {FNAME("direct") NUL, FIXD, 0, 0, SKIP, 0, NULL}, + {FNAME("gatekeeperRouted") NUL, FIXD, 0, 0, SKIP, 0, NULL}, +}; + +static const struct field_t _AdmissionRequest_destinationInfo[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _AdmissionRequest_destExtraCallInfo[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _AdmissionRequest_srcInfo[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _AdmissionRequest[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("callType") CHOICE, 2, 4, 4, SKIP | EXT, 0, _CallType}, + {FNAME("callModel") CHOICE, 1, 2, 2, SKIP | EXT | OPT, 0, _CallModel}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("destinationInfo") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _AdmissionRequest_destinationInfo}, + {FNAME("destCallSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(AdmissionRequest, destCallSignalAddress), + _TransportAddress}, + {FNAME("destExtraCallInfo") SEQOF, SEMI, 0, 0, SKIP | OPT, 0, + _AdmissionRequest_destExtraCallInfo}, + {FNAME("srcInfo") SEQOF, SEMI, 0, 0, SKIP, 0, + _AdmissionRequest_srcInfo}, + {FNAME("srcCallSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT | OPT, + offsetof(AdmissionRequest, srcCallSignalAddress), _TransportAddress}, + {FNAME("bandWidth") INT, CONS, 0, 0, STOP, 0, NULL}, + {FNAME("callReferenceValue") INT, WORD, 0, 0, STOP, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("callServices") SEQ, 0, 8, 8, STOP | EXT | OPT, 0, NULL}, + {FNAME("conferenceID") OCTSTR, FIXD, 16, 0, STOP, 0, NULL}, + {FNAME("activeMC") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("answerCall") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("canMapAlias") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("callIdentifier") SEQ, 0, 1, 1, STOP | EXT, 0, NULL}, + {FNAME("srcAlternatives") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destAlternatives") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("transportQOS") CHOICE, 2, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("willSupplyUUIEs") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("callLinkage") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("gatewayDataRate") SEQ, 2, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("desiredProtocols") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("desiredTunnelledProtocol") SEQ, 1, 
2, 2, STOP | EXT | OPT, 0, + NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _AdmissionConfirm[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("bandWidth") INT, CONS, 0, 0, SKIP, 0, NULL}, + {FNAME("callModel") CHOICE, 1, 2, 2, SKIP | EXT, 0, _CallModel}, + {FNAME("destCallSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(AdmissionConfirm, destCallSignalAddress), + _TransportAddress}, + {FNAME("irrFrequency") INT, WORD, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("destinationInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destExtraCallInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destinationType") SEQ, 6, 8, 10, STOP | EXT | OPT, 0, NULL}, + {FNAME("remoteExtensionAddress") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("alternateEndpoints") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("transportQOS") CHOICE, 2, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("willRespondToIRR") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("uuiesRequested") SEQ, 0, 9, 13, STOP | EXT, 0, NULL}, + {FNAME("language") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("alternateTransportAddresses") SEQ, 1, 1, 1, STOP | EXT | OPT, + 0, NULL}, + {FNAME("useSpecifiedTransport") CHOICE, 1, 2, 2, STOP | EXT | OPT, 0, + NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("usageSpec") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("supportedProtocols") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _LocationRequest_destinationInfo[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 1, 2, 7, SKIP | EXT, 0, _AliasAddress}, +}; + +static const struct field_t _LocationRequest[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, SKIP | OPT, 0, NULL}, + {FNAME("destinationInfo") SEQOF, SEMI, 0, 0, SKIP, 0, + _LocationRequest_destinationInfo}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("replyAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(LocationRequest, replyAddress), _TransportAddress}, + {FNAME("sourceInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("canMapAlias") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("gatekeeperIdentifier") BMPSTR, 7, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("desiredProtocols") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("desiredTunnelledProtocol") SEQ, 1, 2, 2, STOP | EXT | OPT, 0, + NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("hopCount") INT, 8, 1, 0, STOP | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, 
STOP | EXT | OPT, 0, NULL}, +}; + +static const struct field_t _LocationConfirm[] = { /* SEQUENCE */ + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("callSignalAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(LocationConfirm, callSignalAddress), _TransportAddress}, + {FNAME("rasAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(LocationConfirm, rasAddress), _TransportAddress}, + {FNAME("nonStandardData") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("destinationInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destExtraCallInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("destinationType") SEQ, 6, 8, 10, STOP | EXT | OPT, 0, NULL}, + {FNAME("remoteExtensionAddress") SEQOF, SEMI, 0, 0, STOP | OPT, 0, + NULL}, + {FNAME("alternateEndpoints") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("alternateTransportAddresses") SEQ, 1, 1, 1, STOP | EXT | OPT, + 0, NULL}, + {FNAME("supportedProtocols") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("multipleCalls") BOOL, FIXD, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("featureSet") SEQ, 3, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("circuitInfo") SEQ, 3, 3, 3, STOP | EXT | OPT, 0, NULL}, + {FNAME("serviceControl") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _InfoRequestResponse_callSignalAddress[] = { /* SEQUENCE OF */ + {FNAME("item") CHOICE, 3, 7, 7, DECODE | EXT, + sizeof(TransportAddress), _TransportAddress} + , +}; + +static const struct field_t _InfoRequestResponse[] = { /* SEQUENCE */ + {FNAME("nonStandardData") SEQ, 0, 2, 2, SKIP | OPT, 0, + _NonStandardParameter}, + {FNAME("requestSeqNum") INT, WORD, 1, 0, SKIP, 0, NULL}, + {FNAME("endpointType") SEQ, 6, 8, 10, SKIP | EXT, 0, _EndpointType}, + {FNAME("endpointIdentifier") BMPSTR, 7, 1, 0, SKIP, 0, NULL}, + {FNAME("rasAddress") CHOICE, 3, 7, 7, DECODE | EXT, + offsetof(InfoRequestResponse, rasAddress), _TransportAddress}, + {FNAME("callSignalAddress") SEQOF, SEMI, 0, 10, DECODE, + offsetof(InfoRequestResponse, callSignalAddress), + _InfoRequestResponse_callSignalAddress}, + {FNAME("endpointAlias") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("perCallInfo") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("tokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("cryptoTokens") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, + {FNAME("integrityCheckValue") SEQ, 0, 2, 2, STOP | OPT, 0, NULL}, + {FNAME("needResponse") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("capacity") SEQ, 2, 2, 2, STOP | EXT | OPT, 0, NULL}, + {FNAME("irrStatus") CHOICE, 2, 4, 4, STOP | EXT | OPT, 0, NULL}, + {FNAME("unsolicited") BOOL, FIXD, 0, 0, STOP, 0, NULL}, + {FNAME("genericData") SEQOF, SEMI, 0, 0, STOP | OPT, 0, NULL}, +}; + +static const struct field_t _RasMessage[] = { /* CHOICE */ + {FNAME("gatekeeperRequest") SEQ, 4, 8, 18, DECODE | EXT, + offsetof(RasMessage, gatekeeperRequest), _GatekeeperRequest}, + {FNAME("gatekeeperConfirm") SEQ, 2, 5, 14, DECODE | EXT, + offsetof(RasMessage, gatekeeperConfirm), _GatekeeperConfirm}, + {FNAME("gatekeeperReject") SEQ, 2, 5, 11, STOP | EXT, 0, NULL}, + {FNAME("registrationRequest") SEQ, 3, 10, 31, DECODE | EXT, + offsetof(RasMessage, registrationRequest), _RegistrationRequest}, + {FNAME("registrationConfirm") SEQ, 3, 7, 24, DECODE | EXT, + 
offsetof(RasMessage, registrationConfirm), _RegistrationConfirm}, + {FNAME("registrationReject") SEQ, 2, 5, 11, STOP | EXT, 0, NULL}, + {FNAME("unregistrationRequest") SEQ, 3, 5, 15, DECODE | EXT, + offsetof(RasMessage, unregistrationRequest), _UnregistrationRequest}, + {FNAME("unregistrationConfirm") SEQ, 1, 2, 6, STOP | EXT, 0, NULL}, + {FNAME("unregistrationReject") SEQ, 1, 3, 8, STOP | EXT, 0, NULL}, + {FNAME("admissionRequest") SEQ, 7, 16, 34, DECODE | EXT, + offsetof(RasMessage, admissionRequest), _AdmissionRequest}, + {FNAME("admissionConfirm") SEQ, 2, 6, 27, DECODE | EXT, + offsetof(RasMessage, admissionConfirm), _AdmissionConfirm}, + {FNAME("admissionReject") SEQ, 1, 3, 11, STOP | EXT, 0, NULL}, + {FNAME("bandwidthRequest") SEQ, 2, 7, 18, STOP | EXT, 0, NULL}, + {FNAME("bandwidthConfirm") SEQ, 1, 3, 8, STOP | EXT, 0, NULL}, + {FNAME("bandwidthReject") SEQ, 1, 4, 9, STOP | EXT, 0, NULL}, + {FNAME("disengageRequest") SEQ, 1, 6, 19, STOP | EXT, 0, NULL}, + {FNAME("disengageConfirm") SEQ, 1, 2, 9, STOP | EXT, 0, NULL}, + {FNAME("disengageReject") SEQ, 1, 3, 8, STOP | EXT, 0, NULL}, + {FNAME("locationRequest") SEQ, 2, 5, 17, DECODE | EXT, + offsetof(RasMessage, locationRequest), _LocationRequest}, + {FNAME("locationConfirm") SEQ, 1, 4, 19, DECODE | EXT, + offsetof(RasMessage, locationConfirm), _LocationConfirm}, + {FNAME("locationReject") SEQ, 1, 3, 10, STOP | EXT, 0, NULL}, + {FNAME("infoRequest") SEQ, 2, 4, 15, STOP | EXT, 0, NULL}, + {FNAME("infoRequestResponse") SEQ, 3, 8, 16, DECODE | EXT, + offsetof(RasMessage, infoRequestResponse), _InfoRequestResponse}, + {FNAME("nonStandardMessage") SEQ, 0, 2, 7, STOP | EXT, 0, NULL}, + {FNAME("unknownMessageResponse") SEQ, 0, 1, 5, STOP | EXT, 0, NULL}, + {FNAME("requestInProgress") SEQ, 4, 6, 6, STOP | EXT, 0, NULL}, + {FNAME("resourcesAvailableIndicate") SEQ, 4, 9, 11, STOP | EXT, 0, + NULL}, + {FNAME("resourcesAvailableConfirm") SEQ, 4, 6, 7, STOP | EXT, 0, + NULL}, + {FNAME("infoRequestAck") SEQ, 4, 5, 5, STOP | EXT, 0, NULL}, + {FNAME("infoRequestNak") SEQ, 5, 7, 7, STOP | EXT, 0, NULL}, + {FNAME("serviceControlIndication") SEQ, 8, 10, 10, STOP | EXT, 0, + NULL}, + {FNAME("serviceControlResponse") SEQ, 7, 8, 8, STOP | EXT, 0, NULL}, +}; diff --git a/net/netfilter/nf_conntrack_helper.c b/net/netfilter/nf_conntrack_helper.c new file mode 100644 index 000000000..bf09a1e06 --- /dev/null +++ b/net/netfilter/nf_conntrack_helper.c @@ -0,0 +1,520 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Helper handling for netfilter. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2003,2004 USAGI/WIDE Project + * (C) 2006-2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_MUTEX(nf_ct_helper_mutex); +struct hlist_head *nf_ct_helper_hash __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_helper_hash); +unsigned int nf_ct_helper_hsize __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_helper_hsize); +static unsigned int nf_ct_helper_count __read_mostly; + +static DEFINE_MUTEX(nf_ct_nat_helpers_mutex); +static struct list_head nf_ct_nat_helpers __read_mostly; + +/* Stupid hash, but collision free for the default registrations of the + * helpers currently in the kernel. 
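(Editorial sketch, not part of the upstream file: the "stupid hash" described in this comment folds the L3 family and the L4 protocol number into one value, XORs in the helper's registered source port, and reduces the result modulo the table size. The stand-alone program below only mirrors that formula from helper_hash(), which follows this comment; the function name and sample values are invented for illustration.)

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

static unsigned int demo_helper_hash(uint16_t l3num, uint8_t protonum,
				     uint16_t src_port_be, unsigned int hsize)
{
	/* same arithmetic as helper_hash() below; the port is taken in
	 * network byte order, as tuple->src.u.all is in the kernel */
	return (((l3num << 8) | protonum) ^ src_port_be) % hsize;
}

int main(void)
{
	/* e.g. an IPv4 (l3num 2), TCP (protonum 6) helper bound to port 6667,
	 * hashed into a 64-bucket table */
	printf("bucket %u\n", demo_helper_hash(2, 6, htons(6667), 64));
	return 0;
}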
*/ +static unsigned int helper_hash(const struct nf_conntrack_tuple *tuple) +{ + return (((tuple->src.l3num << 8) | tuple->dst.protonum) ^ + (__force __u16)tuple->src.u.all) % nf_ct_helper_hsize; +} + +struct nf_conntrack_helper * +__nf_conntrack_helper_find(const char *name, u16 l3num, u8 protonum) +{ + struct nf_conntrack_helper *h; + unsigned int i; + + for (i = 0; i < nf_ct_helper_hsize; i++) { + hlist_for_each_entry_rcu(h, &nf_ct_helper_hash[i], hnode) { + if (strcmp(h->name, name)) + continue; + + if (h->tuple.src.l3num != NFPROTO_UNSPEC && + h->tuple.src.l3num != l3num) + continue; + + if (h->tuple.dst.protonum == protonum) + return h; + } + } + return NULL; +} +EXPORT_SYMBOL_GPL(__nf_conntrack_helper_find); + +struct nf_conntrack_helper * +nf_conntrack_helper_try_module_get(const char *name, u16 l3num, u8 protonum) +{ + struct nf_conntrack_helper *h; + + rcu_read_lock(); + + h = __nf_conntrack_helper_find(name, l3num, protonum); +#ifdef CONFIG_MODULES + if (h == NULL) { + rcu_read_unlock(); + if (request_module("nfct-helper-%s", name) == 0) { + rcu_read_lock(); + h = __nf_conntrack_helper_find(name, l3num, protonum); + } else { + return h; + } + } +#endif + if (h != NULL && !try_module_get(h->me)) + h = NULL; + if (h != NULL && !refcount_inc_not_zero(&h->refcnt)) { + module_put(h->me); + h = NULL; + } + + rcu_read_unlock(); + + return h; +} +EXPORT_SYMBOL_GPL(nf_conntrack_helper_try_module_get); + +void nf_conntrack_helper_put(struct nf_conntrack_helper *helper) +{ + refcount_dec(&helper->refcnt); + module_put(helper->me); +} +EXPORT_SYMBOL_GPL(nf_conntrack_helper_put); + +static struct nf_conntrack_nat_helper * +nf_conntrack_nat_helper_find(const char *mod_name) +{ + struct nf_conntrack_nat_helper *cur; + bool found = false; + + list_for_each_entry_rcu(cur, &nf_ct_nat_helpers, list) { + if (!strcmp(cur->mod_name, mod_name)) { + found = true; + break; + } + } + return found ? 
cur : NULL; +} + +int +nf_nat_helper_try_module_get(const char *name, u16 l3num, u8 protonum) +{ + struct nf_conntrack_helper *h; + struct nf_conntrack_nat_helper *nat; + char mod_name[NF_CT_HELPER_NAME_LEN]; + int ret = 0; + + rcu_read_lock(); + h = __nf_conntrack_helper_find(name, l3num, protonum); + if (!h) { + rcu_read_unlock(); + return -ENOENT; + } + + nat = nf_conntrack_nat_helper_find(h->nat_mod_name); + if (!nat) { + snprintf(mod_name, sizeof(mod_name), "%s", h->nat_mod_name); + rcu_read_unlock(); + request_module("%s", mod_name); + + rcu_read_lock(); + nat = nf_conntrack_nat_helper_find(mod_name); + if (!nat) { + rcu_read_unlock(); + return -ENOENT; + } + } + + if (!try_module_get(nat->module)) + ret = -ENOENT; + + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(nf_nat_helper_try_module_get); + +void nf_nat_helper_put(struct nf_conntrack_helper *helper) +{ + struct nf_conntrack_nat_helper *nat; + + nat = nf_conntrack_nat_helper_find(helper->nat_mod_name); + if (WARN_ON_ONCE(!nat)) + return; + + module_put(nat->module); +} +EXPORT_SYMBOL_GPL(nf_nat_helper_put); + +struct nf_conn_help * +nf_ct_helper_ext_add(struct nf_conn *ct, gfp_t gfp) +{ + struct nf_conn_help *help; + + help = nf_ct_ext_add(ct, NF_CT_EXT_HELPER, gfp); + if (help) + INIT_HLIST_HEAD(&help->expectations); + else + pr_debug("failed to add helper extension area"); + return help; +} +EXPORT_SYMBOL_GPL(nf_ct_helper_ext_add); + +int __nf_ct_try_assign_helper(struct nf_conn *ct, struct nf_conn *tmpl, + gfp_t flags) +{ + struct nf_conntrack_helper *helper = NULL; + struct nf_conn_help *help; + + /* We already got a helper explicitly attached. The function + * nf_conntrack_alter_reply - in case NAT is in use - asks for looking + * the helper up again. Since now the user is in full control of + * making consistent helper configurations, skip this automatic + * re-lookup, otherwise we'll lose the helper. + */ + if (test_bit(IPS_HELPER_BIT, &ct->status)) + return 0; + + if (WARN_ON_ONCE(!tmpl)) + return 0; + + help = nfct_help(tmpl); + if (help != NULL) { + helper = rcu_dereference(help->helper); + set_bit(IPS_HELPER_BIT, &ct->status); + } + + help = nfct_help(ct); + + if (helper == NULL) { + if (help) + RCU_INIT_POINTER(help->helper, NULL); + return 0; + } + + if (help == NULL) { + help = nf_ct_helper_ext_add(ct, flags); + if (help == NULL) + return -ENOMEM; + } else { + /* We only allow helper re-assignment of the same sort since + * we cannot reallocate the helper extension area. + */ + struct nf_conntrack_helper *tmp = rcu_dereference(help->helper); + + if (tmp && tmp->help != helper->help) { + RCU_INIT_POINTER(help->helper, NULL); + return 0; + } + } + + rcu_assign_pointer(help->helper, helper); + + return 0; +} +EXPORT_SYMBOL_GPL(__nf_ct_try_assign_helper); + +/* appropriate ct lock protecting must be taken by caller */ +static int unhelp(struct nf_conn *ct, void *me) +{ + struct nf_conn_help *help = nfct_help(ct); + + if (help && rcu_dereference_raw(help->helper) == me) { + nf_conntrack_event(IPCT_HELPER, ct); + RCU_INIT_POINTER(help->helper, NULL); + } + + /* We are not intended to delete this conntrack. 
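As a usage note (editorial, not part of the patch): the lookup and release functions defined above are meant to be used symmetrically by code that attaches a helper by name, such as ctnetlink or the CT target. A minimal sketch, with "ftp"/IPv4/TCP and the function name chosen purely for illustration:

#include <linux/errno.h>
#include <linux/in.h>
#include <linux/socket.h>
#include <net/netfilter/nf_conntrack_helper.h>

static int example_attach_by_name(void)
{
	struct nf_conntrack_helper *helper;

	/* takes a reference on both the helper and its owning module,
	 * auto-loading "nfct-helper-ftp" first if necessary */
	helper = nf_conntrack_helper_try_module_get("ftp", AF_INET,
						    IPPROTO_TCP);
	if (!helper)
		return -ENOENT;

	/* ... store it in a conntrack template's nf_conn_help ... */

	/* drops the refcount and the module reference again */
	nf_conntrack_helper_put(helper);
	return 0;
}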
*/ + return 0; +} + +void nf_ct_helper_destroy(struct nf_conn *ct) +{ + struct nf_conn_help *help = nfct_help(ct); + struct nf_conntrack_helper *helper; + + if (help) { + rcu_read_lock(); + helper = rcu_dereference(help->helper); + if (helper && helper->destroy) + helper->destroy(ct); + rcu_read_unlock(); + } +} + +static LIST_HEAD(nf_ct_helper_expectfn_list); + +void nf_ct_helper_expectfn_register(struct nf_ct_helper_expectfn *n) +{ + spin_lock_bh(&nf_conntrack_expect_lock); + list_add_rcu(&n->head, &nf_ct_helper_expectfn_list); + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_helper_expectfn_register); + +void nf_ct_helper_expectfn_unregister(struct nf_ct_helper_expectfn *n) +{ + spin_lock_bh(&nf_conntrack_expect_lock); + list_del_rcu(&n->head); + spin_unlock_bh(&nf_conntrack_expect_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_helper_expectfn_unregister); + +/* Caller should hold the rcu lock */ +struct nf_ct_helper_expectfn * +nf_ct_helper_expectfn_find_by_name(const char *name) +{ + struct nf_ct_helper_expectfn *cur; + bool found = false; + + list_for_each_entry_rcu(cur, &nf_ct_helper_expectfn_list, head) { + if (!strcmp(cur->name, name)) { + found = true; + break; + } + } + return found ? cur : NULL; +} +EXPORT_SYMBOL_GPL(nf_ct_helper_expectfn_find_by_name); + +/* Caller should hold the rcu lock */ +struct nf_ct_helper_expectfn * +nf_ct_helper_expectfn_find_by_symbol(const void *symbol) +{ + struct nf_ct_helper_expectfn *cur; + bool found = false; + + list_for_each_entry_rcu(cur, &nf_ct_helper_expectfn_list, head) { + if (cur->expectfn == symbol) { + found = true; + break; + } + } + return found ? cur : NULL; +} +EXPORT_SYMBOL_GPL(nf_ct_helper_expectfn_find_by_symbol); + +__printf(3, 4) +void nf_ct_helper_log(struct sk_buff *skb, const struct nf_conn *ct, + const char *fmt, ...) 
+{ + const struct nf_conn_help *help; + const struct nf_conntrack_helper *helper; + struct va_format vaf; + va_list args; + + va_start(args, fmt); + + vaf.fmt = fmt; + vaf.va = &args; + + /* Called from the helper function, this call never fails */ + help = nfct_help(ct); + + /* rcu_read_lock()ed by nf_hook_thresh */ + helper = rcu_dereference(help->helper); + + nf_log_packet(nf_ct_net(ct), nf_ct_l3num(ct), 0, skb, NULL, NULL, NULL, + "nf_ct_%s: dropping packet: %pV ", helper->name, &vaf); + + va_end(args); +} +EXPORT_SYMBOL_GPL(nf_ct_helper_log); + +int nf_conntrack_helper_register(struct nf_conntrack_helper *me) +{ + struct nf_conntrack_tuple_mask mask = { .src.u.all = htons(0xFFFF) }; + unsigned int h = helper_hash(&me->tuple); + struct nf_conntrack_helper *cur; + int ret = 0, i; + + BUG_ON(me->expect_policy == NULL); + BUG_ON(me->expect_class_max >= NF_CT_MAX_EXPECT_CLASSES); + BUG_ON(strlen(me->name) > NF_CT_HELPER_NAME_LEN - 1); + + if (!nf_ct_helper_hash) + return -ENOENT; + + if (me->expect_policy->max_expected > NF_CT_EXPECT_MAX_CNT) + return -EINVAL; + + mutex_lock(&nf_ct_helper_mutex); + for (i = 0; i < nf_ct_helper_hsize; i++) { + hlist_for_each_entry(cur, &nf_ct_helper_hash[i], hnode) { + if (!strcmp(cur->name, me->name) && + (cur->tuple.src.l3num == NFPROTO_UNSPEC || + cur->tuple.src.l3num == me->tuple.src.l3num) && + cur->tuple.dst.protonum == me->tuple.dst.protonum) { + ret = -EEXIST; + goto out; + } + } + } + + /* avoid unpredictable behaviour for auto_assign_helper */ + if (!(me->flags & NF_CT_HELPER_F_USERSPACE)) { + hlist_for_each_entry(cur, &nf_ct_helper_hash[h], hnode) { + if (nf_ct_tuple_src_mask_cmp(&cur->tuple, &me->tuple, + &mask)) { + ret = -EEXIST; + goto out; + } + } + } + refcount_set(&me->refcnt, 1); + hlist_add_head_rcu(&me->hnode, &nf_ct_helper_hash[h]); + nf_ct_helper_count++; +out: + mutex_unlock(&nf_ct_helper_mutex); + return ret; +} +EXPORT_SYMBOL_GPL(nf_conntrack_helper_register); + +static bool expect_iter_me(struct nf_conntrack_expect *exp, void *data) +{ + struct nf_conn_help *help = nfct_help(exp->master); + const struct nf_conntrack_helper *me = data; + const struct nf_conntrack_helper *this; + + if (exp->helper == me) + return true; + + this = rcu_dereference_protected(help->helper, + lockdep_is_held(&nf_conntrack_expect_lock)); + return this == me; +} + +void nf_conntrack_helper_unregister(struct nf_conntrack_helper *me) +{ + mutex_lock(&nf_ct_helper_mutex); + hlist_del_rcu(&me->hnode); + nf_ct_helper_count--; + mutex_unlock(&nf_ct_helper_mutex); + + /* Make sure every nothing is still using the helper unless its a + * connection in the hash. 
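For orientation (this block is an editorial sketch, not part of the patch): the checks in nf_conntrack_helper_register() above assume a fully initialised helper object. The hypothetical "foo" helper below shows the minimum a module has to provide; the netbios-ns and IRC helpers added later in this same patch follow exactly this pattern.

#include <linux/module.h>
#include <linux/skbuff.h>
#include <linux/netfilter.h>
#include <linux/in.h>
#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_expect.h>
#include <net/netfilter/nf_conntrack_helper.h>

static int foo_help(struct sk_buff *skb, unsigned int protoff,
		    struct nf_conn *ct, enum ip_conntrack_info ctinfo)
{
	/* inspect the payload, set up expectations with nf_ct_expect_related() */
	return NF_ACCEPT;
}

static const struct nf_conntrack_expect_policy foo_exp_policy = {
	.max_expected	= 1,		/* must not exceed NF_CT_EXPECT_MAX_CNT */
	.timeout	= 30,		/* seconds */
};

static struct nf_conntrack_helper foo_helper __read_mostly = {
	.name			= "foo",	/* shorter than NF_CT_HELPER_NAME_LEN */
	.me			= THIS_MODULE,
	.help			= foo_help,
	.expect_policy		= &foo_exp_policy,
	.tuple.src.l3num	= NFPROTO_IPV4,
	.tuple.dst.protonum	= IPPROTO_TCP,
	.tuple.src.u.tcp.port	= cpu_to_be16(12345),
};

/* module init: nf_conntrack_helper_register(&foo_helper);
 * module exit: nf_conntrack_helper_unregister(&foo_helper);
 */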
+ */ + synchronize_rcu(); + + nf_ct_expect_iterate_destroy(expect_iter_me, NULL); + nf_ct_iterate_destroy(unhelp, me); +} +EXPORT_SYMBOL_GPL(nf_conntrack_helper_unregister); + +void nf_ct_helper_init(struct nf_conntrack_helper *helper, + u16 l3num, u16 protonum, const char *name, + u16 default_port, u16 spec_port, u32 id, + const struct nf_conntrack_expect_policy *exp_pol, + u32 expect_class_max, + int (*help)(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo), + int (*from_nlattr)(struct nlattr *attr, + struct nf_conn *ct), + struct module *module) +{ + helper->tuple.src.l3num = l3num; + helper->tuple.dst.protonum = protonum; + helper->tuple.src.u.all = htons(spec_port); + helper->expect_policy = exp_pol; + helper->expect_class_max = expect_class_max; + helper->help = help; + helper->from_nlattr = from_nlattr; + helper->me = module; + snprintf(helper->nat_mod_name, sizeof(helper->nat_mod_name), + NF_NAT_HELPER_PREFIX "%s", name); + + if (spec_port == default_port) + snprintf(helper->name, sizeof(helper->name), "%s", name); + else + snprintf(helper->name, sizeof(helper->name), "%s-%u", name, id); +} +EXPORT_SYMBOL_GPL(nf_ct_helper_init); + +int nf_conntrack_helpers_register(struct nf_conntrack_helper *helper, + unsigned int n) +{ + unsigned int i; + int err = 0; + + for (i = 0; i < n; i++) { + err = nf_conntrack_helper_register(&helper[i]); + if (err < 0) + goto err; + } + + return err; +err: + if (i > 0) + nf_conntrack_helpers_unregister(helper, i); + return err; +} +EXPORT_SYMBOL_GPL(nf_conntrack_helpers_register); + +void nf_conntrack_helpers_unregister(struct nf_conntrack_helper *helper, + unsigned int n) +{ + while (n-- > 0) + nf_conntrack_helper_unregister(&helper[n]); +} +EXPORT_SYMBOL_GPL(nf_conntrack_helpers_unregister); + +void nf_nat_helper_register(struct nf_conntrack_nat_helper *nat) +{ + mutex_lock(&nf_ct_nat_helpers_mutex); + list_add_rcu(&nat->list, &nf_ct_nat_helpers); + mutex_unlock(&nf_ct_nat_helpers_mutex); +} +EXPORT_SYMBOL_GPL(nf_nat_helper_register); + +void nf_nat_helper_unregister(struct nf_conntrack_nat_helper *nat) +{ + mutex_lock(&nf_ct_nat_helpers_mutex); + list_del_rcu(&nat->list); + mutex_unlock(&nf_ct_nat_helpers_mutex); +} +EXPORT_SYMBOL_GPL(nf_nat_helper_unregister); + +int nf_conntrack_helper_init(void) +{ + nf_ct_helper_hsize = 1; /* gets rounded up to use one page */ + nf_ct_helper_hash = + nf_ct_alloc_hashtable(&nf_ct_helper_hsize, 0); + if (!nf_ct_helper_hash) + return -ENOMEM; + + INIT_LIST_HEAD(&nf_ct_nat_helpers); + return 0; +} + +void nf_conntrack_helper_fini(void) +{ + kvfree(nf_ct_helper_hash); + nf_ct_helper_hash = NULL; +} diff --git a/net/netfilter/nf_conntrack_irc.c b/net/netfilter/nf_conntrack_irc.c new file mode 100644 index 000000000..5703846be --- /dev/null +++ b/net/netfilter/nf_conntrack_irc.c @@ -0,0 +1,314 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* IRC extension for IP connection tracking, Version 1.21 + * (C) 2000-2002 by Harald Welte + * based on RR's ip_conntrack_ftp.c + * (C) 2006-2012 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#define MAX_PORTS 8 +static unsigned short ports[MAX_PORTS]; +static unsigned int ports_c; +static unsigned int max_dcc_channels = 8; +static unsigned int dcc_timeout __read_mostly = 300; +/* This is slow, but it's simple. 
--RR */ +static char *irc_buffer; +static DEFINE_SPINLOCK(irc_buffer_lock); + +unsigned int (*nf_nat_irc_hook)(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp) __read_mostly; +EXPORT_SYMBOL_GPL(nf_nat_irc_hook); + +#define HELPER_NAME "irc" +#define MAX_SEARCH_SIZE 4095 + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("IRC (DCC) connection tracking helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ip_conntrack_irc"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +module_param_array(ports, ushort, &ports_c, 0400); +MODULE_PARM_DESC(ports, "port numbers of IRC servers"); +module_param(max_dcc_channels, uint, 0400); +MODULE_PARM_DESC(max_dcc_channels, "max number of expected DCC channels per " + "IRC session"); +module_param(dcc_timeout, uint, 0400); +MODULE_PARM_DESC(dcc_timeout, "timeout on for unestablished DCC channels"); + +static const char *const dccprotos[] = { + "SEND ", "CHAT ", "MOVE ", "TSEND ", "SCHAT " +}; + +#define MINMATCHLEN 5 + +/* tries to get the ip_addr and port out of a dcc command + * return value: -1 on failure, 0 on success + * data pointer to first byte of DCC command data + * data_end pointer to last byte of dcc command data + * ip returns parsed ip of dcc command + * port returns parsed port of dcc command + * ad_beg_p returns pointer to first byte of addr data + * ad_end_p returns pointer to last byte of addr data + */ +static int parse_dcc(char *data, const char *data_end, __be32 *ip, + u_int16_t *port, char **ad_beg_p, char **ad_end_p) +{ + char *tmp; + + /* at least 12: "AAAAAAAA P\1\n" */ + while (*data++ != ' ') + if (data > data_end - 12) + return -1; + + /* Make sure we have a newline character within the packet boundaries + * because simple_strtoul parses until the first invalid character. */ + for (tmp = data; tmp <= data_end; tmp++) + if (*tmp == '\n') + break; + if (tmp > data_end || *tmp != '\n') + return -1; + + *ad_beg_p = data; + *ip = cpu_to_be32(simple_strtoul(data, &data, 10)); + + /* skip blanks between ip and port */ + while (*data == ' ') { + if (data >= data_end) + return -1; + data++; + } + + *port = simple_strtoul(data, &data, 10); + *ad_end_p = data; + + return 0; +} + +static int help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + unsigned int dataoff; + const struct iphdr *iph; + const struct tcphdr *th; + struct tcphdr _tcph; + const char *data_limit; + char *data, *ib_ptr; + int dir = CTINFO2DIR(ctinfo); + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple *tuple; + __be32 dcc_ip; + u_int16_t dcc_port; + __be16 port; + int i, ret = NF_ACCEPT; + char *addr_beg_p, *addr_end_p; + typeof(nf_nat_irc_hook) nf_nat_irc; + unsigned int datalen; + + /* If packet is coming from IRC server */ + if (dir == IP_CT_DIR_REPLY) + return NF_ACCEPT; + + /* Until there's been traffic both ways, don't look in packets. */ + if (ctinfo != IP_CT_ESTABLISHED && ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + /* Not a full tcp header? */ + th = skb_header_pointer(skb, protoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return NF_ACCEPT; + + /* No data? 
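A worked example (editorial, not part of the patch) of the payload that parse_dcc() above handles: inside a CTCP "DCC SEND" request the IPv4 address and the TCP port are plain decimal numbers, so the extraction can be reproduced in user space as below. Note that strtoul() skips the separating blanks by itself, which the kernel code has to do explicitly before calling simple_strtoul().

#include <arpa/inet.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

int main(void)
{
	/* payload seen after "\1DCC SEND ", e.g. from
	 * PRIVMSG peer :\1DCC SEND movie.avi 3232235521 8000\1 */
	const char *args = "movie.avi 3232235521 8000\n";
	char *cur = strchr(args, ' ') + 1;	/* skip the file name */
	unsigned long ip = strtoul(cur, &cur, 10);
	unsigned int port = (unsigned int)strtoul(cur, &cur, 10);
	struct in_addr a = { .s_addr = htonl(ip) };

	printf("expecting DCC connection to %s:%u\n", inet_ntoa(a), port);
	/* prints: expecting DCC connection to 192.168.0.1:8000 */
	return 0;
}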
*/ + dataoff = protoff + th->doff*4; + if (dataoff >= skb->len) + return NF_ACCEPT; + + datalen = skb->len - dataoff; + if (datalen > MAX_SEARCH_SIZE) + datalen = MAX_SEARCH_SIZE; + + spin_lock_bh(&irc_buffer_lock); + ib_ptr = skb_header_pointer(skb, dataoff, datalen, + irc_buffer); + if (!ib_ptr) { + spin_unlock_bh(&irc_buffer_lock); + return NF_ACCEPT; + } + + data = ib_ptr; + data_limit = ib_ptr + datalen; + + /* Skip any whitespace */ + while (data < data_limit - 10) { + if (*data == ' ' || *data == '\r' || *data == '\n') + data++; + else + break; + } + + /* strlen("PRIVMSG x ")=10 */ + if (data < data_limit - 10) { + if (strncasecmp("PRIVMSG ", data, 8)) + goto out; + data += 8; + } + + /* strlen(" :\1DCC SENT t AAAAAAAA P\1\n")=26 + * 7+MINMATCHLEN+strlen("t AAAAAAAA P\1\n")=26 + */ + while (data < data_limit - (21 + MINMATCHLEN)) { + /* Find first " :", the start of message */ + if (memcmp(data, " :", 2)) { + data++; + continue; + } + data += 2; + + /* then check that place only for the DCC command */ + if (memcmp(data, "\1DCC ", 5)) + goto out; + data += 5; + /* we have at least (21+MINMATCHLEN)-(2+5) bytes valid data left */ + + iph = ip_hdr(skb); + pr_debug("DCC found in master %pI4:%u %pI4:%u\n", + &iph->saddr, ntohs(th->source), + &iph->daddr, ntohs(th->dest)); + + for (i = 0; i < ARRAY_SIZE(dccprotos); i++) { + if (memcmp(data, dccprotos[i], strlen(dccprotos[i]))) { + /* no match */ + continue; + } + data += strlen(dccprotos[i]); + pr_debug("DCC %s detected\n", dccprotos[i]); + + /* we have at least + * (21+MINMATCHLEN)-7-dccprotos[i].matchlen bytes valid + * data left (== 14/13 bytes) */ + if (parse_dcc(data, data_limit, &dcc_ip, + &dcc_port, &addr_beg_p, &addr_end_p)) { + pr_debug("unable to parse dcc command\n"); + continue; + } + + pr_debug("DCC bound ip/port: %pI4:%u\n", + &dcc_ip, dcc_port); + + /* dcc_ip can be the internal OR external (NAT'ed) IP */ + tuple = &ct->tuplehash[dir].tuple; + if ((tuple->src.u3.ip != dcc_ip && + ct->tuplehash[!dir].tuple.dst.u3.ip != dcc_ip) || + dcc_port == 0) { + net_warn_ratelimited("Forged DCC command from %pI4: %pI4:%u\n", + &tuple->src.u3.ip, + &dcc_ip, dcc_port); + continue; + } + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + nf_ct_helper_log(skb, ct, + "cannot alloc expectation"); + ret = NF_DROP; + goto out; + } + tuple = &ct->tuplehash[!dir].tuple; + port = htons(dcc_port); + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, + tuple->src.l3num, + NULL, &tuple->dst.u3, + IPPROTO_TCP, NULL, &port); + + nf_nat_irc = rcu_dereference(nf_nat_irc_hook); + if (nf_nat_irc && ct->status & IPS_NAT_MASK) + ret = nf_nat_irc(skb, ctinfo, protoff, + addr_beg_p - ib_ptr, + addr_end_p - addr_beg_p, + exp); + else if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, + "cannot add expectation"); + ret = NF_DROP; + } + nf_ct_expect_put(exp); + goto out; + } + } + out: + spin_unlock_bh(&irc_buffer_lock); + return ret; +} + +static struct nf_conntrack_helper irc[MAX_PORTS] __read_mostly; +static struct nf_conntrack_expect_policy irc_exp_policy; + +static int __init nf_conntrack_irc_init(void) +{ + int i, ret; + + if (max_dcc_channels < 1) { + pr_err("max_dcc_channels must not be zero\n"); + return -EINVAL; + } + + if (max_dcc_channels > NF_CT_EXPECT_MAX_CNT) { + pr_err("max_dcc_channels must not be more than %u\n", + NF_CT_EXPECT_MAX_CNT); + return -EINVAL; + } + + irc_exp_policy.max_expected = max_dcc_channels; + irc_exp_policy.timeout = dcc_timeout; + + irc_buffer = kmalloc(MAX_SEARCH_SIZE + 1, GFP_KERNEL); + if 
(!irc_buffer) + return -ENOMEM; + + /* If no port given, default to standard irc port */ + if (ports_c == 0) + ports[ports_c++] = IRC_PORT; + + for (i = 0; i < ports_c; i++) { + nf_ct_helper_init(&irc[i], AF_INET, IPPROTO_TCP, HELPER_NAME, + IRC_PORT, ports[i], i, &irc_exp_policy, + 0, help, NULL, THIS_MODULE); + } + + ret = nf_conntrack_helpers_register(&irc[0], ports_c); + if (ret) { + pr_err("failed to register helpers\n"); + kfree(irc_buffer); + return ret; + } + + return 0; +} + +static void __exit nf_conntrack_irc_fini(void) +{ + nf_conntrack_helpers_unregister(irc, ports_c); + kfree(irc_buffer); +} + +module_init(nf_conntrack_irc_init); +module_exit(nf_conntrack_irc_fini); diff --git a/net/netfilter/nf_conntrack_labels.c b/net/netfilter/nf_conntrack_labels.c new file mode 100644 index 000000000..6e70e137a --- /dev/null +++ b/net/netfilter/nf_conntrack_labels.c @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * test/set flag bits stored in conntrack extension area. + * + * (C) 2013 Astaro GmbH & Co KG + */ + +#include +#include + +#include +#include + +static DEFINE_SPINLOCK(nf_connlabels_lock); + +static int replace_u32(u32 *address, u32 mask, u32 new) +{ + u32 old, tmp; + + do { + old = *address; + tmp = (old & mask) ^ new; + if (old == tmp) + return 0; + } while (cmpxchg(address, old, tmp) != old); + + return 1; +} + +int nf_connlabels_replace(struct nf_conn *ct, + const u32 *data, + const u32 *mask, unsigned int words32) +{ + struct nf_conn_labels *labels; + unsigned int size, i; + int changed = 0; + u32 *dst; + + labels = nf_ct_labels_find(ct); + if (!labels) + return -ENOSPC; + + size = sizeof(labels->bits); + if (size < (words32 * sizeof(u32))) + words32 = size / sizeof(u32); + + dst = (u32 *) labels->bits; + for (i = 0; i < words32; i++) + changed |= replace_u32(&dst[i], mask ? ~mask[i] : 0, data[i]); + + size /= sizeof(u32); + for (i = words32; i < size; i++) /* pad */ + replace_u32(&dst[i], 0, 0); + + if (changed) + nf_conntrack_event_cache(IPCT_LABEL, ct); + return 0; +} +EXPORT_SYMBOL_GPL(nf_connlabels_replace); + +int nf_connlabels_get(struct net *net, unsigned int bits) +{ + if (BIT_WORD(bits) >= NF_CT_LABELS_MAX_SIZE / sizeof(long)) + return -ERANGE; + + spin_lock(&nf_connlabels_lock); + net->ct.labels_used++; + spin_unlock(&nf_connlabels_lock); + + BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE / sizeof(long) >= U8_MAX); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_connlabels_get); + +void nf_connlabels_put(struct net *net) +{ + spin_lock(&nf_connlabels_lock); + net->ct.labels_used--; + spin_unlock(&nf_connlabels_lock); +} +EXPORT_SYMBOL_GPL(nf_connlabels_put); diff --git a/net/netfilter/nf_conntrack_netbios_ns.c b/net/netfilter/nf_conntrack_netbios_ns.c new file mode 100644 index 000000000..55415f011 --- /dev/null +++ b/net/netfilter/nf_conntrack_netbios_ns.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetBIOS name service broadcast connection tracking helper + * + * (c) 2005 Patrick McHardy + */ +/* + * This helper tracks locally originating NetBIOS name service + * requests by issuing permanent expectations (valid until + * timing out) matching all reply connections from the + * destination network. The only NetBIOS specific thing is + * actually the port number. 
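Looking back at replace_u32() in nf_conntrack_labels.c above (this note and the sketch are editorial, not part of the patch): the caller hands in the inverted netlink mask, so each 32-bit word ends up as (old & ~mask) ^ data, meaning the masked bits take the new value while all other bits are preserved; the kernel version additionally wraps this in a cmpxchg() retry loop so concurrent updates are not lost. A stand-alone illustration:

#include <stdint.h>
#include <stdio.h>

static uint32_t demo_replace_bits(uint32_t old, uint32_t mask, uint32_t data)
{
	/* data is expected to lie within mask, as with the ctnetlink
	 * CTA_LABELS/CTA_LABELS_MASK attributes */
	return (old & ~mask) ^ data;
}

int main(void)
{
	/* keep the upper bits, replace the low byte with 0x0a: prints 0000f00a */
	printf("%08x\n", demo_replace_bits(0xf0f0u, 0x00ffu, 0x000au));
	return 0;
}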
+ */ +#include +#include +#include +#include + +#include +#include +#include + +#define HELPER_NAME "netbios-ns" +#define NMBD_PORT 137 + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("NetBIOS name service broadcast connection tracking helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ip_conntrack_netbios_ns"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +static unsigned int timeout __read_mostly = 3; +module_param(timeout, uint, 0400); +MODULE_PARM_DESC(timeout, "timeout for master connection/replies in seconds"); + +static struct nf_conntrack_expect_policy exp_policy = { + .max_expected = 1, +}; + +static int netbios_ns_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + return nf_conntrack_broadcast_help(skb, ct, ctinfo, timeout); +} + +static struct nf_conntrack_helper helper __read_mostly = { + .name = HELPER_NAME, + .tuple.src.l3num = NFPROTO_IPV4, + .tuple.src.u.udp.port = cpu_to_be16(NMBD_PORT), + .tuple.dst.protonum = IPPROTO_UDP, + .me = THIS_MODULE, + .help = netbios_ns_help, + .expect_policy = &exp_policy, +}; + +static int __init nf_conntrack_netbios_ns_init(void) +{ + NF_CT_HELPER_BUILD_BUG_ON(0); + + exp_policy.timeout = timeout; + return nf_conntrack_helper_register(&helper); +} + +static void __exit nf_conntrack_netbios_ns_fini(void) +{ + nf_conntrack_helper_unregister(&helper); +} + +module_init(nf_conntrack_netbios_ns_init); +module_exit(nf_conntrack_netbios_ns_fini); diff --git a/net/netfilter/nf_conntrack_netlink.c b/net/netfilter/nf_conntrack_netlink.c new file mode 100644 index 000000000..9ee8abd3e --- /dev/null +++ b/net/netfilter/nf_conntrack_netlink.c @@ -0,0 +1,3918 @@ +/* Connection tracking via netlink socket. Allows for user space + * protocol helpers and general trouble making from userspace. + * + * (C) 2001 by Jay Schulist + * (C) 2002-2006 by Harald Welte + * (C) 2003 by Patrick Mchardy + * (C) 2005-2012 by Pablo Neira Ayuso + * + * Initial connection tracking via netlink development funded and + * generally made possible by Network Robots, Inc. (www.networkrobots.com) + * + * Further development of this code funded by Astaro AG (http://www.astaro.com) + * + * This software may be used and distributed according to the terms + * of the GNU General Public License, incorporated herein by reference. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_NAT) +#include +#include +#endif + +#include +#include + +#include "nf_internals.h" + +MODULE_LICENSE("GPL"); + +struct ctnetlink_list_dump_ctx { + struct nf_conn *last; + unsigned int cpu; + bool done; +}; + +static int ctnetlink_dump_tuples_proto(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_l4proto *l4proto) +{ + int ret = 0; + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, CTA_TUPLE_PROTO); + if (!nest_parms) + goto nla_put_failure; + if (nla_put_u8(skb, CTA_PROTO_NUM, tuple->dst.protonum)) + goto nla_put_failure; + + if (likely(l4proto->tuple_to_nlattr)) + ret = l4proto->tuple_to_nlattr(skb, tuple); + + nla_nest_end(skb, nest_parms); + + return ret; + +nla_put_failure: + return -1; +} + +static int ipv4_tuple_to_nlattr(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple) +{ + if (nla_put_in_addr(skb, CTA_IP_V4_SRC, tuple->src.u3.ip) || + nla_put_in_addr(skb, CTA_IP_V4_DST, tuple->dst.u3.ip)) + return -EMSGSIZE; + return 0; +} + +static int ipv6_tuple_to_nlattr(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple) +{ + if (nla_put_in6_addr(skb, CTA_IP_V6_SRC, &tuple->src.u3.in6) || + nla_put_in6_addr(skb, CTA_IP_V6_DST, &tuple->dst.u3.in6)) + return -EMSGSIZE; + return 0; +} + +static int ctnetlink_dump_tuples_ip(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple) +{ + int ret = 0; + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, CTA_TUPLE_IP); + if (!nest_parms) + goto nla_put_failure; + + switch (tuple->src.l3num) { + case NFPROTO_IPV4: + ret = ipv4_tuple_to_nlattr(skb, tuple); + break; + case NFPROTO_IPV6: + ret = ipv6_tuple_to_nlattr(skb, tuple); + break; + } + + nla_nest_end(skb, nest_parms); + + return ret; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_tuples(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple) +{ + const struct nf_conntrack_l4proto *l4proto; + int ret; + + rcu_read_lock(); + ret = ctnetlink_dump_tuples_ip(skb, tuple); + + if (ret >= 0) { + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); + ret = ctnetlink_dump_tuples_proto(skb, tuple, l4proto); + } + rcu_read_unlock(); + return ret; +} + +static int ctnetlink_dump_zone_id(struct sk_buff *skb, int attrtype, + const struct nf_conntrack_zone *zone, int dir) +{ + if (zone->id == NF_CT_DEFAULT_ZONE_ID || zone->dir != dir) + return 0; + if (nla_put_be16(skb, attrtype, htons(zone->id))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_status(struct sk_buff *skb, const struct nf_conn *ct) +{ + if (nla_put_be32(skb, CTA_STATUS, htonl(ct->status))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_timeout(struct sk_buff *skb, const struct nf_conn *ct, + bool skip_zero) +{ + long timeout; + + if (nf_ct_is_confirmed(ct)) + timeout = nf_ct_expires(ct) / HZ; + else + timeout = ct->timeout / HZ; + + if (skip_zero && timeout == 0) + return 0; + + if (nla_put_be32(skb, CTA_TIMEOUT, htonl(timeout))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_protoinfo(struct sk_buff *skb, struct nf_conn *ct, + bool 
destroy) +{ + const struct nf_conntrack_l4proto *l4proto; + struct nlattr *nest_proto; + int ret; + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + if (!l4proto->to_nlattr) + return 0; + + nest_proto = nla_nest_start(skb, CTA_PROTOINFO); + if (!nest_proto) + goto nla_put_failure; + + ret = l4proto->to_nlattr(skb, nest_proto, ct, destroy); + + nla_nest_end(skb, nest_proto); + + return ret; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_helpinfo(struct sk_buff *skb, + const struct nf_conn *ct) +{ + struct nlattr *nest_helper; + const struct nf_conn_help *help = nfct_help(ct); + struct nf_conntrack_helper *helper; + + if (!help) + return 0; + + rcu_read_lock(); + helper = rcu_dereference(help->helper); + if (!helper) + goto out; + + nest_helper = nla_nest_start(skb, CTA_HELP); + if (!nest_helper) + goto nla_put_failure; + if (nla_put_string(skb, CTA_HELP_NAME, helper->name)) + goto nla_put_failure; + + if (helper->to_nlattr) + helper->to_nlattr(skb, ct); + + nla_nest_end(skb, nest_helper); +out: + rcu_read_unlock(); + return 0; + +nla_put_failure: + rcu_read_unlock(); + return -1; +} + +static int +dump_counters(struct sk_buff *skb, struct nf_conn_acct *acct, + enum ip_conntrack_dir dir, int type) +{ + enum ctattr_type attr = dir ? CTA_COUNTERS_REPLY: CTA_COUNTERS_ORIG; + struct nf_conn_counter *counter = acct->counter; + struct nlattr *nest_count; + u64 pkts, bytes; + + if (type == IPCTNL_MSG_CT_GET_CTRZERO) { + pkts = atomic64_xchg(&counter[dir].packets, 0); + bytes = atomic64_xchg(&counter[dir].bytes, 0); + } else { + pkts = atomic64_read(&counter[dir].packets); + bytes = atomic64_read(&counter[dir].bytes); + } + + nest_count = nla_nest_start(skb, attr); + if (!nest_count) + goto nla_put_failure; + + if (nla_put_be64(skb, CTA_COUNTERS_PACKETS, cpu_to_be64(pkts), + CTA_COUNTERS_PAD) || + nla_put_be64(skb, CTA_COUNTERS_BYTES, cpu_to_be64(bytes), + CTA_COUNTERS_PAD)) + goto nla_put_failure; + + nla_nest_end(skb, nest_count); + + return 0; + +nla_put_failure: + return -1; +} + +static int +ctnetlink_dump_acct(struct sk_buff *skb, const struct nf_conn *ct, int type) +{ + struct nf_conn_acct *acct = nf_conn_acct_find(ct); + + if (!acct) + return 0; + + if (dump_counters(skb, acct, IP_CT_DIR_ORIGINAL, type) < 0) + return -1; + if (dump_counters(skb, acct, IP_CT_DIR_REPLY, type) < 0) + return -1; + + return 0; +} + +static int +ctnetlink_dump_timestamp(struct sk_buff *skb, const struct nf_conn *ct) +{ + struct nlattr *nest_count; + const struct nf_conn_tstamp *tstamp; + + tstamp = nf_conn_tstamp_find(ct); + if (!tstamp) + return 0; + + nest_count = nla_nest_start(skb, CTA_TIMESTAMP); + if (!nest_count) + goto nla_put_failure; + + if (nla_put_be64(skb, CTA_TIMESTAMP_START, cpu_to_be64(tstamp->start), + CTA_TIMESTAMP_PAD) || + (tstamp->stop != 0 && nla_put_be64(skb, CTA_TIMESTAMP_STOP, + cpu_to_be64(tstamp->stop), + CTA_TIMESTAMP_PAD))) + goto nla_put_failure; + nla_nest_end(skb, nest_count); + + return 0; + +nla_put_failure: + return -1; +} + +#ifdef CONFIG_NF_CONNTRACK_MARK +static int ctnetlink_dump_mark(struct sk_buff *skb, const struct nf_conn *ct, + bool dump) +{ + u32 mark = READ_ONCE(ct->mark); + + if (!mark && !dump) + return 0; + + if (nla_put_be32(skb, CTA_MARK, htonl(mark))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} +#else +#define ctnetlink_dump_mark(a, b, c) (0) +#endif + +#ifdef CONFIG_NF_CONNTRACK_SECMARK +static int ctnetlink_dump_secctx(struct sk_buff *skb, const struct nf_conn *ct) +{ + struct nlattr *nest_secctx; + int len, 
ret; + char *secctx; + + ret = security_secid_to_secctx(ct->secmark, &secctx, &len); + if (ret) + return 0; + + ret = -1; + nest_secctx = nla_nest_start(skb, CTA_SECCTX); + if (!nest_secctx) + goto nla_put_failure; + + if (nla_put_string(skb, CTA_SECCTX_NAME, secctx)) + goto nla_put_failure; + nla_nest_end(skb, nest_secctx); + + ret = 0; +nla_put_failure: + security_release_secctx(secctx, len); + return ret; +} +#else +#define ctnetlink_dump_secctx(a, b) (0) +#endif + +#ifdef CONFIG_NF_CONNTRACK_LABELS +static inline int ctnetlink_label_size(const struct nf_conn *ct) +{ + struct nf_conn_labels *labels = nf_ct_labels_find(ct); + + if (!labels) + return 0; + return nla_total_size(sizeof(labels->bits)); +} + +static int +ctnetlink_dump_labels(struct sk_buff *skb, const struct nf_conn *ct) +{ + struct nf_conn_labels *labels = nf_ct_labels_find(ct); + unsigned int i; + + if (!labels) + return 0; + + i = 0; + do { + if (labels->bits[i] != 0) + return nla_put(skb, CTA_LABELS, sizeof(labels->bits), + labels->bits); + i++; + } while (i < ARRAY_SIZE(labels->bits)); + + return 0; +} +#else +#define ctnetlink_dump_labels(a, b) (0) +#define ctnetlink_label_size(a) (0) +#endif + +#define master_tuple(ct) &(ct->master->tuplehash[IP_CT_DIR_ORIGINAL].tuple) + +static int ctnetlink_dump_master(struct sk_buff *skb, const struct nf_conn *ct) +{ + struct nlattr *nest_parms; + + if (!(ct->status & IPS_EXPECTED)) + return 0; + + nest_parms = nla_nest_start(skb, CTA_TUPLE_MASTER); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, master_tuple(ct)) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + return -1; +} + +static int +dump_ct_seq_adj(struct sk_buff *skb, const struct nf_ct_seqadj *seq, int type) +{ + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, type); + if (!nest_parms) + goto nla_put_failure; + + if (nla_put_be32(skb, CTA_SEQADJ_CORRECTION_POS, + htonl(seq->correction_pos)) || + nla_put_be32(skb, CTA_SEQADJ_OFFSET_BEFORE, + htonl(seq->offset_before)) || + nla_put_be32(skb, CTA_SEQADJ_OFFSET_AFTER, + htonl(seq->offset_after))) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_ct_seq_adj(struct sk_buff *skb, struct nf_conn *ct) +{ + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + struct nf_ct_seqadj *seq; + + if (!(ct->status & IPS_SEQ_ADJUST) || !seqadj) + return 0; + + spin_lock_bh(&ct->lock); + seq = &seqadj->seq[IP_CT_DIR_ORIGINAL]; + if (dump_ct_seq_adj(skb, seq, CTA_SEQ_ADJ_ORIG) == -1) + goto err; + + seq = &seqadj->seq[IP_CT_DIR_REPLY]; + if (dump_ct_seq_adj(skb, seq, CTA_SEQ_ADJ_REPLY) == -1) + goto err; + + spin_unlock_bh(&ct->lock); + return 0; +err: + spin_unlock_bh(&ct->lock); + return -1; +} + +static int ctnetlink_dump_ct_synproxy(struct sk_buff *skb, struct nf_conn *ct) +{ + struct nf_conn_synproxy *synproxy = nfct_synproxy(ct); + struct nlattr *nest_parms; + + if (!synproxy) + return 0; + + nest_parms = nla_nest_start(skb, CTA_SYNPROXY); + if (!nest_parms) + goto nla_put_failure; + + if (nla_put_be32(skb, CTA_SYNPROXY_ISN, htonl(synproxy->isn)) || + nla_put_be32(skb, CTA_SYNPROXY_ITS, htonl(synproxy->its)) || + nla_put_be32(skb, CTA_SYNPROXY_TSOFF, htonl(synproxy->tsoff))) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_id(struct sk_buff *skb, const struct nf_conn *ct) +{ + __be32 id = (__force __be32)nf_ct_get_id(ct); + 
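	/* Editorial note, not part of the upstream file: every dump helper in
	 * this file follows the same nested-TLV pattern -- open a nest with
	 * nla_nest_start(), emit the member attributes, then close it with
	 * nla_nest_end() (or back it out on error).  With hypothetical
	 * attribute names the skeleton looks like:
	 *
	 *	struct nlattr *nest = nla_nest_start(skb, CTA_SOME_NEST);
	 *
	 *	if (!nest)
	 *		return -1;
	 *	if (nla_put_be32(skb, CTA_SOME_VALUE, htonl(val))) {
	 *		nla_nest_cancel(skb, nest);
	 *		return -1;
	 *	}
	 *	nla_nest_end(skb, nest);
	 *	return 0;
	 */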
+ if (nla_put_be32(skb, CTA_ID, id)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_dump_use(struct sk_buff *skb, const struct nf_conn *ct) +{ + if (nla_put_be32(skb, CTA_USE, htonl(refcount_read(&ct->ct_general.use)))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +/* all these functions access ct->ext. Caller must either hold a reference + * on ct or prevent its deletion by holding either the bucket spinlock or + * pcpu dying list lock. + */ +static int ctnetlink_dump_extinfo(struct sk_buff *skb, + struct nf_conn *ct, u32 type) +{ + if (ctnetlink_dump_acct(skb, ct, type) < 0 || + ctnetlink_dump_timestamp(skb, ct) < 0 || + ctnetlink_dump_helpinfo(skb, ct) < 0 || + ctnetlink_dump_labels(skb, ct) < 0 || + ctnetlink_dump_ct_seq_adj(skb, ct) < 0 || + ctnetlink_dump_ct_synproxy(skb, ct) < 0) + return -1; + + return 0; +} + +static int ctnetlink_dump_info(struct sk_buff *skb, struct nf_conn *ct) +{ + if (ctnetlink_dump_status(skb, ct) < 0 || + ctnetlink_dump_mark(skb, ct, true) < 0 || + ctnetlink_dump_secctx(skb, ct) < 0 || + ctnetlink_dump_id(skb, ct) < 0 || + ctnetlink_dump_use(skb, ct) < 0 || + ctnetlink_dump_master(skb, ct) < 0) + return -1; + + if (!test_bit(IPS_OFFLOAD_BIT, &ct->status) && + (ctnetlink_dump_timeout(skb, ct, false) < 0 || + ctnetlink_dump_protoinfo(skb, ct, false) < 0)) + return -1; + + return 0; +} + +static int +ctnetlink_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + struct nf_conn *ct, bool extinfo, unsigned int flags) +{ + const struct nf_conntrack_zone *zone; + struct nlmsghdr *nlh; + struct nlattr *nest_parms; + unsigned int event; + + if (portid) + flags |= NLM_F_MULTI; + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK, IPCTNL_MSG_CT_NEW); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, nf_ct_l3num(ct), + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + zone = nf_ct_zone(ct); + + nest_parms = nla_nest_start(skb, CTA_TUPLE_ORIG); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_ORIGINAL)) < 0) + goto nla_put_failure; + if (ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_ORIG) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + nest_parms = nla_nest_start(skb, CTA_TUPLE_REPLY); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_REPLY)) < 0) + goto nla_put_failure; + if (ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_REPL) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + if (ctnetlink_dump_zone_id(skb, CTA_ZONE, zone, + NF_CT_DEFAULT_ZONE_DIR) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_info(skb, ct) < 0) + goto nla_put_failure; + if (extinfo && ctnetlink_dump_extinfo(skb, ct, type) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static const struct nla_policy cta_ip_nla_policy[CTA_IP_MAX + 1] = { + [CTA_IP_V4_SRC] = { .type = NLA_U32 }, + [CTA_IP_V4_DST] = { .type = NLA_U32 }, + [CTA_IP_V6_SRC] = { .len = sizeof(__be32) * 4 }, + [CTA_IP_V6_DST] = { .len = sizeof(__be32) * 4 }, +}; + +#if defined(CONFIG_NETFILTER_NETLINK_GLUE_CT) || defined(CONFIG_NF_CONNTRACK_EVENTS) +static size_t ctnetlink_proto_size(const struct nf_conn *ct) +{ + const struct nf_conntrack_l4proto *l4proto; + size_t len, len4 = 0; + + len = nla_policy_len(cta_ip_nla_policy, CTA_IP_MAX + 1); + len *= 3u; /* ORIG, 
REPLY, MASTER */ + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + len += l4proto->nlattr_size; + if (l4proto->nlattr_tuple_size) { + len4 = l4proto->nlattr_tuple_size(); + len4 *= 3u; /* ORIG, REPLY, MASTER */ + } + + return len + len4; +} +#endif + +static inline size_t ctnetlink_acct_size(const struct nf_conn *ct) +{ + if (!nf_ct_ext_exist(ct, NF_CT_EXT_ACCT)) + return 0; + return 2 * nla_total_size(0) /* CTA_COUNTERS_ORIG|REPL */ + + 2 * nla_total_size_64bit(sizeof(uint64_t)) /* CTA_COUNTERS_PACKETS */ + + 2 * nla_total_size_64bit(sizeof(uint64_t)) /* CTA_COUNTERS_BYTES */ + ; +} + +static inline int ctnetlink_secctx_size(const struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CONNTRACK_SECMARK + int len, ret; + + ret = security_secid_to_secctx(ct->secmark, NULL, &len); + if (ret) + return 0; + + return nla_total_size(0) /* CTA_SECCTX */ + + nla_total_size(sizeof(char) * len); /* CTA_SECCTX_NAME */ +#else + return 0; +#endif +} + +static inline size_t ctnetlink_timestamp_size(const struct nf_conn *ct) +{ +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + if (!nf_ct_ext_exist(ct, NF_CT_EXT_TSTAMP)) + return 0; + return nla_total_size(0) + 2 * nla_total_size_64bit(sizeof(uint64_t)); +#else + return 0; +#endif +} + +#ifdef CONFIG_NF_CONNTRACK_EVENTS +static size_t ctnetlink_nlmsg_size(const struct nf_conn *ct) +{ + return NLMSG_ALIGN(sizeof(struct nfgenmsg)) + + 3 * nla_total_size(0) /* CTA_TUPLE_ORIG|REPL|MASTER */ + + 3 * nla_total_size(0) /* CTA_TUPLE_IP */ + + 3 * nla_total_size(0) /* CTA_TUPLE_PROTO */ + + 3 * nla_total_size(sizeof(u_int8_t)) /* CTA_PROTO_NUM */ + + nla_total_size(sizeof(u_int32_t)) /* CTA_ID */ + + nla_total_size(sizeof(u_int32_t)) /* CTA_STATUS */ + + ctnetlink_acct_size(ct) + + ctnetlink_timestamp_size(ct) + + nla_total_size(sizeof(u_int32_t)) /* CTA_TIMEOUT */ + + nla_total_size(0) /* CTA_PROTOINFO */ + + nla_total_size(0) /* CTA_HELP */ + + nla_total_size(NF_CT_HELPER_NAME_LEN) /* CTA_HELP_NAME */ + + ctnetlink_secctx_size(ct) +#if IS_ENABLED(CONFIG_NF_NAT) + + 2 * nla_total_size(0) /* CTA_NAT_SEQ_ADJ_ORIG|REPL */ + + 6 * nla_total_size(sizeof(u_int32_t)) /* CTA_NAT_SEQ_OFFSET */ +#endif +#ifdef CONFIG_NF_CONNTRACK_MARK + + nla_total_size(sizeof(u_int32_t)) /* CTA_MARK */ +#endif +#ifdef CONFIG_NF_CONNTRACK_ZONES + + nla_total_size(sizeof(u_int16_t)) /* CTA_ZONE|CTA_TUPLE_ZONE */ +#endif + + ctnetlink_proto_size(ct) + + ctnetlink_label_size(ct) + ; +} + +static int +ctnetlink_conntrack_event(unsigned int events, const struct nf_ct_event *item) +{ + const struct nf_conntrack_zone *zone; + struct net *net; + struct nlmsghdr *nlh; + struct nlattr *nest_parms; + struct nf_conn *ct = item->ct; + struct sk_buff *skb; + unsigned int type; + unsigned int flags = 0, group; + int err; + + if (events & (1 << IPCT_DESTROY)) { + type = IPCTNL_MSG_CT_DELETE; + group = NFNLGRP_CONNTRACK_DESTROY; + } else if (events & ((1 << IPCT_NEW) | (1 << IPCT_RELATED))) { + type = IPCTNL_MSG_CT_NEW; + flags = NLM_F_CREATE|NLM_F_EXCL; + group = NFNLGRP_CONNTRACK_NEW; + } else if (events) { + type = IPCTNL_MSG_CT_NEW; + group = NFNLGRP_CONNTRACK_UPDATE; + } else + return 0; + + net = nf_ct_net(ct); + if (!item->report && !nfnetlink_has_listeners(net, group)) + return 0; + + skb = nlmsg_new(ctnetlink_nlmsg_size(ct), GFP_ATOMIC); + if (skb == NULL) + goto errout; + + type = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK, type); + nlh = nfnl_msg_put(skb, item->portid, 0, type, flags, nf_ct_l3num(ct), + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + zone = nf_ct_zone(ct); + + nest_parms = nla_nest_start(skb, 
CTA_TUPLE_ORIG); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_ORIGINAL)) < 0) + goto nla_put_failure; + if (ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_ORIG) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + nest_parms = nla_nest_start(skb, CTA_TUPLE_REPLY); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_REPLY)) < 0) + goto nla_put_failure; + if (ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_REPL) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + if (ctnetlink_dump_zone_id(skb, CTA_ZONE, zone, + NF_CT_DEFAULT_ZONE_DIR) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_id(skb, ct) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_status(skb, ct) < 0) + goto nla_put_failure; + + if (events & (1 << IPCT_DESTROY)) { + if (ctnetlink_dump_timeout(skb, ct, true) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_acct(skb, ct, type) < 0 || + ctnetlink_dump_timestamp(skb, ct) < 0 || + ctnetlink_dump_protoinfo(skb, ct, true) < 0) + goto nla_put_failure; + } else { + if (ctnetlink_dump_timeout(skb, ct, false) < 0) + goto nla_put_failure; + + if (events & (1 << IPCT_PROTOINFO) && + ctnetlink_dump_protoinfo(skb, ct, false) < 0) + goto nla_put_failure; + + if ((events & (1 << IPCT_HELPER) || nfct_help(ct)) + && ctnetlink_dump_helpinfo(skb, ct) < 0) + goto nla_put_failure; + +#ifdef CONFIG_NF_CONNTRACK_SECMARK + if ((events & (1 << IPCT_SECMARK) || ct->secmark) + && ctnetlink_dump_secctx(skb, ct) < 0) + goto nla_put_failure; +#endif + if (events & (1 << IPCT_LABEL) && + ctnetlink_dump_labels(skb, ct) < 0) + goto nla_put_failure; + + if (events & (1 << IPCT_RELATED) && + ctnetlink_dump_master(skb, ct) < 0) + goto nla_put_failure; + + if (events & (1 << IPCT_SEQADJ) && + ctnetlink_dump_ct_seq_adj(skb, ct) < 0) + goto nla_put_failure; + + if (events & (1 << IPCT_SYNPROXY) && + ctnetlink_dump_ct_synproxy(skb, ct) < 0) + goto nla_put_failure; + } + +#ifdef CONFIG_NF_CONNTRACK_MARK + if (ctnetlink_dump_mark(skb, ct, events & (1 << IPCT_MARK))) + goto nla_put_failure; +#endif + nlmsg_end(skb, nlh); + err = nfnetlink_send(skb, net, item->portid, group, item->report, + GFP_ATOMIC); + if (err == -ENOBUFS || err == -EAGAIN) + return -ENOBUFS; + + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); +nlmsg_failure: + kfree_skb(skb); +errout: + if (nfnetlink_set_err(net, 0, group, -ENOBUFS) > 0) + return -ENOBUFS; + + return 0; +} +#endif /* CONFIG_NF_CONNTRACK_EVENTS */ + +static int ctnetlink_done(struct netlink_callback *cb) +{ + if (cb->args[1]) + nf_ct_put((struct nf_conn *)cb->args[1]); + kfree(cb->data); + return 0; +} + +struct ctnetlink_filter_u32 { + u32 val; + u32 mask; +}; + +struct ctnetlink_filter { + u8 family; + + u_int32_t orig_flags; + u_int32_t reply_flags; + + struct nf_conntrack_tuple orig; + struct nf_conntrack_tuple reply; + struct nf_conntrack_zone zone; + + struct ctnetlink_filter_u32 mark; + struct ctnetlink_filter_u32 status; +}; + +static const struct nla_policy cta_filter_nla_policy[CTA_FILTER_MAX + 1] = { + [CTA_FILTER_ORIG_FLAGS] = { .type = NLA_U32 }, + [CTA_FILTER_REPLY_FLAGS] = { .type = NLA_U32 }, +}; + +static int ctnetlink_parse_filter(const struct nlattr *attr, + struct ctnetlink_filter *filter) +{ + struct nlattr *tb[CTA_FILTER_MAX + 1]; + int ret = 0; + + ret = nla_parse_nested(tb, CTA_FILTER_MAX, attr, cta_filter_nla_policy, + NULL); + if (ret) + return ret; + + if 
(tb[CTA_FILTER_ORIG_FLAGS]) { + filter->orig_flags = nla_get_u32(tb[CTA_FILTER_ORIG_FLAGS]); + if (filter->orig_flags & ~CTA_FILTER_F_ALL) + return -EOPNOTSUPP; + } + + if (tb[CTA_FILTER_REPLY_FLAGS]) { + filter->reply_flags = nla_get_u32(tb[CTA_FILTER_REPLY_FLAGS]); + if (filter->reply_flags & ~CTA_FILTER_F_ALL) + return -EOPNOTSUPP; + } + + return 0; +} + +static int ctnetlink_parse_zone(const struct nlattr *attr, + struct nf_conntrack_zone *zone); +static int ctnetlink_parse_tuple_filter(const struct nlattr * const cda[], + struct nf_conntrack_tuple *tuple, + u32 type, u_int8_t l3num, + struct nf_conntrack_zone *zone, + u_int32_t flags); + +static int ctnetlink_filter_parse_mark(struct ctnetlink_filter_u32 *mark, + const struct nlattr * const cda[]) +{ +#ifdef CONFIG_NF_CONNTRACK_MARK + if (cda[CTA_MARK]) { + mark->val = ntohl(nla_get_be32(cda[CTA_MARK])); + + if (cda[CTA_MARK_MASK]) + mark->mask = ntohl(nla_get_be32(cda[CTA_MARK_MASK])); + else + mark->mask = 0xffffffff; + } else if (cda[CTA_MARK_MASK]) { + return -EINVAL; + } +#endif + return 0; +} + +static int ctnetlink_filter_parse_status(struct ctnetlink_filter_u32 *status, + const struct nlattr * const cda[]) +{ + if (cda[CTA_STATUS]) { + status->val = ntohl(nla_get_be32(cda[CTA_STATUS])); + if (cda[CTA_STATUS_MASK]) + status->mask = ntohl(nla_get_be32(cda[CTA_STATUS_MASK])); + else + status->mask = status->val; + + /* status->val == 0? always true, else always false. */ + if (status->mask == 0) + return -EINVAL; + } else if (cda[CTA_STATUS_MASK]) { + return -EINVAL; + } + + /* CTA_STATUS is NLA_U32, if this fires UAPI needs to be extended */ + BUILD_BUG_ON(__IPS_MAX_BIT >= 32); + return 0; +} + +static struct ctnetlink_filter * +ctnetlink_alloc_filter(const struct nlattr * const cda[], u8 family) +{ + struct ctnetlink_filter *filter; + int err; + +#ifndef CONFIG_NF_CONNTRACK_MARK + if (cda[CTA_MARK] || cda[CTA_MARK_MASK]) + return ERR_PTR(-EOPNOTSUPP); +#endif + + filter = kzalloc(sizeof(*filter), GFP_KERNEL); + if (filter == NULL) + return ERR_PTR(-ENOMEM); + + filter->family = family; + + err = ctnetlink_filter_parse_mark(&filter->mark, cda); + if (err) + goto err_filter; + + err = ctnetlink_filter_parse_status(&filter->status, cda); + if (err) + goto err_filter; + + if (!cda[CTA_FILTER]) + return filter; + + err = ctnetlink_parse_zone(cda[CTA_ZONE], &filter->zone); + if (err < 0) + goto err_filter; + + err = ctnetlink_parse_filter(cda[CTA_FILTER], filter); + if (err < 0) + goto err_filter; + + if (filter->orig_flags) { + if (!cda[CTA_TUPLE_ORIG]) { + err = -EINVAL; + goto err_filter; + } + + err = ctnetlink_parse_tuple_filter(cda, &filter->orig, + CTA_TUPLE_ORIG, + filter->family, + &filter->zone, + filter->orig_flags); + if (err < 0) + goto err_filter; + } + + if (filter->reply_flags) { + if (!cda[CTA_TUPLE_REPLY]) { + err = -EINVAL; + goto err_filter; + } + + err = ctnetlink_parse_tuple_filter(cda, &filter->reply, + CTA_TUPLE_REPLY, + filter->family, + &filter->zone, + filter->reply_flags); + if (err < 0) + goto err_filter; + } + + return filter; + +err_filter: + kfree(filter); + + return ERR_PTR(err); +} + +static bool ctnetlink_needs_filter(u8 family, const struct nlattr * const *cda) +{ + return family || cda[CTA_MARK] || cda[CTA_FILTER] || cda[CTA_STATUS]; +} + +static int ctnetlink_start(struct netlink_callback *cb) +{ + const struct nlattr * const *cda = cb->data; + struct ctnetlink_filter *filter = NULL; + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + u8 family = nfmsg->nfgen_family; + + if 
(ctnetlink_needs_filter(family, cda)) { + filter = ctnetlink_alloc_filter(cda, family); + if (IS_ERR(filter)) + return PTR_ERR(filter); + } + + cb->data = filter; + return 0; +} + +static int ctnetlink_filter_match_tuple(struct nf_conntrack_tuple *filter_tuple, + struct nf_conntrack_tuple *ct_tuple, + u_int32_t flags, int family) +{ + switch (family) { + case NFPROTO_IPV4: + if ((flags & CTA_FILTER_FLAG(CTA_IP_SRC)) && + filter_tuple->src.u3.ip != ct_tuple->src.u3.ip) + return 0; + + if ((flags & CTA_FILTER_FLAG(CTA_IP_DST)) && + filter_tuple->dst.u3.ip != ct_tuple->dst.u3.ip) + return 0; + break; + case NFPROTO_IPV6: + if ((flags & CTA_FILTER_FLAG(CTA_IP_SRC)) && + !ipv6_addr_cmp(&filter_tuple->src.u3.in6, + &ct_tuple->src.u3.in6)) + return 0; + + if ((flags & CTA_FILTER_FLAG(CTA_IP_DST)) && + !ipv6_addr_cmp(&filter_tuple->dst.u3.in6, + &ct_tuple->dst.u3.in6)) + return 0; + break; + } + + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_NUM)) && + filter_tuple->dst.protonum != ct_tuple->dst.protonum) + return 0; + + switch (ct_tuple->dst.protonum) { + case IPPROTO_TCP: + case IPPROTO_UDP: + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_SRC_PORT)) && + filter_tuple->src.u.tcp.port != ct_tuple->src.u.tcp.port) + return 0; + + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_DST_PORT)) && + filter_tuple->dst.u.tcp.port != ct_tuple->dst.u.tcp.port) + return 0; + break; + case IPPROTO_ICMP: + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_TYPE)) && + filter_tuple->dst.u.icmp.type != ct_tuple->dst.u.icmp.type) + return 0; + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_CODE)) && + filter_tuple->dst.u.icmp.code != ct_tuple->dst.u.icmp.code) + return 0; + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_ID)) && + filter_tuple->src.u.icmp.id != ct_tuple->src.u.icmp.id) + return 0; + break; + case IPPROTO_ICMPV6: + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_TYPE)) && + filter_tuple->dst.u.icmp.type != ct_tuple->dst.u.icmp.type) + return 0; + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_CODE)) && + filter_tuple->dst.u.icmp.code != ct_tuple->dst.u.icmp.code) + return 0; + if ((flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_ID)) && + filter_tuple->src.u.icmp.id != ct_tuple->src.u.icmp.id) + return 0; + break; + } + + return 1; +} + +static int ctnetlink_filter_match(struct nf_conn *ct, void *data) +{ + struct ctnetlink_filter *filter = data; + struct nf_conntrack_tuple *tuple; + u32 status; + + if (filter == NULL) + goto out; + + /* Match entries of a given L3 protocol number. + * If it is not specified, ie. l3proto == 0, + * then match everything. + */ + if (filter->family && nf_ct_l3num(ct) != filter->family) + goto ignore_entry; + + if (filter->orig_flags) { + tuple = nf_ct_tuple(ct, IP_CT_DIR_ORIGINAL); + if (!ctnetlink_filter_match_tuple(&filter->orig, tuple, + filter->orig_flags, + filter->family)) + goto ignore_entry; + } + + if (filter->reply_flags) { + tuple = nf_ct_tuple(ct, IP_CT_DIR_REPLY); + if (!ctnetlink_filter_match_tuple(&filter->reply, tuple, + filter->reply_flags, + filter->family)) + goto ignore_entry; + } + +#ifdef CONFIG_NF_CONNTRACK_MARK + if ((READ_ONCE(ct->mark) & filter->mark.mask) != filter->mark.val) + goto ignore_entry; +#endif + status = (u32)READ_ONCE(ct->status); + if ((status & filter->status.mask) != filter->status.val) + goto ignore_entry; + +out: + return 1; + +ignore_entry: + return 0; +} + +static int +ctnetlink_dump_table(struct sk_buff *skb, struct netlink_callback *cb) +{ + unsigned int flags = cb->data ? 
NLM_F_DUMP_FILTERED : 0; + struct net *net = sock_net(skb->sk); + struct nf_conn *ct, *last; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; + struct nf_conn *nf_ct_evict[8]; + int res, i; + spinlock_t *lockp; + + last = (struct nf_conn *)cb->args[1]; + i = 0; + + local_bh_disable(); + for (; cb->args[0] < nf_conntrack_htable_size; cb->args[0]++) { +restart: + while (i) { + i--; + if (nf_ct_should_gc(nf_ct_evict[i])) + nf_ct_kill(nf_ct_evict[i]); + nf_ct_put(nf_ct_evict[i]); + } + + lockp = &nf_conntrack_locks[cb->args[0] % CONNTRACK_LOCKS]; + nf_conntrack_lock(lockp); + if (cb->args[0] >= nf_conntrack_htable_size) { + spin_unlock(lockp); + goto out; + } + hlist_nulls_for_each_entry(h, n, &nf_conntrack_hash[cb->args[0]], + hnnode) { + ct = nf_ct_tuplehash_to_ctrack(h); + if (nf_ct_is_expired(ct)) { + /* need to defer nf_ct_kill() until lock is released */ + if (i < ARRAY_SIZE(nf_ct_evict) && + refcount_inc_not_zero(&ct->ct_general.use)) + nf_ct_evict[i++] = ct; + continue; + } + + if (!net_eq(net, nf_ct_net(ct))) + continue; + + if (NF_CT_DIRECTION(h) != IP_CT_DIR_ORIGINAL) + continue; + + if (cb->args[1]) { + if (ct != last) + continue; + cb->args[1] = 0; + } + if (!ctnetlink_filter_match(ct, cb->data)) + continue; + + res = + ctnetlink_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFNL_MSG_TYPE(cb->nlh->nlmsg_type), + ct, true, flags); + if (res < 0) { + nf_conntrack_get(&ct->ct_general); + cb->args[1] = (unsigned long)ct; + spin_unlock(lockp); + goto out; + } + } + spin_unlock(lockp); + if (cb->args[1]) { + cb->args[1] = 0; + goto restart; + } + } +out: + local_bh_enable(); + if (last) { + /* nf ct hash resize happened, now clear the leftover. */ + if ((struct nf_conn *)cb->args[1] == last) + cb->args[1] = 0; + + nf_ct_put(last); + } + + while (i) { + i--; + if (nf_ct_should_gc(nf_ct_evict[i])) + nf_ct_kill(nf_ct_evict[i]); + nf_ct_put(nf_ct_evict[i]); + } + + return skb->len; +} + +static int ipv4_nlattr_to_tuple(struct nlattr *tb[], + struct nf_conntrack_tuple *t, + u_int32_t flags) +{ + if (flags & CTA_FILTER_FLAG(CTA_IP_SRC)) { + if (!tb[CTA_IP_V4_SRC]) + return -EINVAL; + + t->src.u3.ip = nla_get_in_addr(tb[CTA_IP_V4_SRC]); + } + + if (flags & CTA_FILTER_FLAG(CTA_IP_DST)) { + if (!tb[CTA_IP_V4_DST]) + return -EINVAL; + + t->dst.u3.ip = nla_get_in_addr(tb[CTA_IP_V4_DST]); + } + + return 0; +} + +static int ipv6_nlattr_to_tuple(struct nlattr *tb[], + struct nf_conntrack_tuple *t, + u_int32_t flags) +{ + if (flags & CTA_FILTER_FLAG(CTA_IP_SRC)) { + if (!tb[CTA_IP_V6_SRC]) + return -EINVAL; + + t->src.u3.in6 = nla_get_in6_addr(tb[CTA_IP_V6_SRC]); + } + + if (flags & CTA_FILTER_FLAG(CTA_IP_DST)) { + if (!tb[CTA_IP_V6_DST]) + return -EINVAL; + + t->dst.u3.in6 = nla_get_in6_addr(tb[CTA_IP_V6_DST]); + } + + return 0; +} + +static int ctnetlink_parse_tuple_ip(struct nlattr *attr, + struct nf_conntrack_tuple *tuple, + u_int32_t flags) +{ + struct nlattr *tb[CTA_IP_MAX+1]; + int ret = 0; + + ret = nla_parse_nested_deprecated(tb, CTA_IP_MAX, attr, NULL, NULL); + if (ret < 0) + return ret; + + ret = nla_validate_nested_deprecated(attr, CTA_IP_MAX, + cta_ip_nla_policy, NULL); + if (ret) + return ret; + + switch (tuple->src.l3num) { + case NFPROTO_IPV4: + ret = ipv4_nlattr_to_tuple(tb, tuple, flags); + break; + case NFPROTO_IPV6: + ret = ipv6_nlattr_to_tuple(tb, tuple, flags); + break; + } + + return ret; +} + +static const struct nla_policy proto_nla_policy[CTA_PROTO_MAX+1] = { + [CTA_PROTO_NUM] = { .type = NLA_U8 }, +}; + +static int 
ctnetlink_parse_tuple_proto(struct nlattr *attr, + struct nf_conntrack_tuple *tuple, + u_int32_t flags) +{ + const struct nf_conntrack_l4proto *l4proto; + struct nlattr *tb[CTA_PROTO_MAX+1]; + int ret = 0; + + ret = nla_parse_nested_deprecated(tb, CTA_PROTO_MAX, attr, + proto_nla_policy, NULL); + if (ret < 0) + return ret; + + if (!(flags & CTA_FILTER_FLAG(CTA_PROTO_NUM))) + return 0; + + if (!tb[CTA_PROTO_NUM]) + return -EINVAL; + + tuple->dst.protonum = nla_get_u8(tb[CTA_PROTO_NUM]); + + rcu_read_lock(); + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); + + if (likely(l4proto->nlattr_to_tuple)) { + ret = nla_validate_nested_deprecated(attr, CTA_PROTO_MAX, + l4proto->nla_policy, + NULL); + if (ret == 0) + ret = l4proto->nlattr_to_tuple(tb, tuple, flags); + } + + rcu_read_unlock(); + + return ret; +} + +static int +ctnetlink_parse_zone(const struct nlattr *attr, + struct nf_conntrack_zone *zone) +{ + nf_ct_zone_init(zone, NF_CT_DEFAULT_ZONE_ID, + NF_CT_DEFAULT_ZONE_DIR, 0); +#ifdef CONFIG_NF_CONNTRACK_ZONES + if (attr) + zone->id = ntohs(nla_get_be16(attr)); +#else + if (attr) + return -EOPNOTSUPP; +#endif + return 0; +} + +static int +ctnetlink_parse_tuple_zone(struct nlattr *attr, enum ctattr_type type, + struct nf_conntrack_zone *zone) +{ + int ret; + + if (zone->id != NF_CT_DEFAULT_ZONE_ID) + return -EINVAL; + + ret = ctnetlink_parse_zone(attr, zone); + if (ret < 0) + return ret; + + if (type == CTA_TUPLE_REPLY) + zone->dir = NF_CT_ZONE_DIR_REPL; + else + zone->dir = NF_CT_ZONE_DIR_ORIG; + + return 0; +} + +static const struct nla_policy tuple_nla_policy[CTA_TUPLE_MAX+1] = { + [CTA_TUPLE_IP] = { .type = NLA_NESTED }, + [CTA_TUPLE_PROTO] = { .type = NLA_NESTED }, + [CTA_TUPLE_ZONE] = { .type = NLA_U16 }, +}; + +#define CTA_FILTER_F_ALL_CTA_PROTO \ + (CTA_FILTER_F_CTA_PROTO_SRC_PORT | \ + CTA_FILTER_F_CTA_PROTO_DST_PORT | \ + CTA_FILTER_F_CTA_PROTO_ICMP_TYPE | \ + CTA_FILTER_F_CTA_PROTO_ICMP_CODE | \ + CTA_FILTER_F_CTA_PROTO_ICMP_ID | \ + CTA_FILTER_F_CTA_PROTO_ICMPV6_TYPE | \ + CTA_FILTER_F_CTA_PROTO_ICMPV6_CODE | \ + CTA_FILTER_F_CTA_PROTO_ICMPV6_ID) + +static int +ctnetlink_parse_tuple_filter(const struct nlattr * const cda[], + struct nf_conntrack_tuple *tuple, u32 type, + u_int8_t l3num, struct nf_conntrack_zone *zone, + u_int32_t flags) +{ + struct nlattr *tb[CTA_TUPLE_MAX+1]; + int err; + + memset(tuple, 0, sizeof(*tuple)); + + err = nla_parse_nested_deprecated(tb, CTA_TUPLE_MAX, cda[type], + tuple_nla_policy, NULL); + if (err < 0) + return err; + + if (l3num != NFPROTO_IPV4 && l3num != NFPROTO_IPV6) + return -EOPNOTSUPP; + tuple->src.l3num = l3num; + + if (flags & CTA_FILTER_FLAG(CTA_IP_DST) || + flags & CTA_FILTER_FLAG(CTA_IP_SRC)) { + if (!tb[CTA_TUPLE_IP]) + return -EINVAL; + + err = ctnetlink_parse_tuple_ip(tb[CTA_TUPLE_IP], tuple, flags); + if (err < 0) + return err; + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_NUM)) { + if (!tb[CTA_TUPLE_PROTO]) + return -EINVAL; + + err = ctnetlink_parse_tuple_proto(tb[CTA_TUPLE_PROTO], tuple, flags); + if (err < 0) + return err; + } else if (flags & CTA_FILTER_FLAG(ALL_CTA_PROTO)) { + /* Can't manage proto flags without a protonum */ + return -EINVAL; + } + + if ((flags & CTA_FILTER_FLAG(CTA_TUPLE_ZONE)) && tb[CTA_TUPLE_ZONE]) { + if (!zone) + return -EINVAL; + + err = ctnetlink_parse_tuple_zone(tb[CTA_TUPLE_ZONE], + type, zone); + if (err < 0) + return err; + } + + /* orig and expect tuples get DIR_ORIGINAL */ + if (type == CTA_TUPLE_REPLY) + tuple->dst.dir = IP_CT_DIR_REPLY; + else + tuple->dst.dir = IP_CT_DIR_ORIGINAL; + + return 
0; +} + +static int +ctnetlink_parse_tuple(const struct nlattr * const cda[], + struct nf_conntrack_tuple *tuple, u32 type, + u_int8_t l3num, struct nf_conntrack_zone *zone) +{ + return ctnetlink_parse_tuple_filter(cda, tuple, type, l3num, zone, + CTA_FILTER_FLAG(ALL)); +} + +static const struct nla_policy help_nla_policy[CTA_HELP_MAX+1] = { + [CTA_HELP_NAME] = { .type = NLA_NUL_STRING, + .len = NF_CT_HELPER_NAME_LEN - 1 }, +}; + +static int ctnetlink_parse_help(const struct nlattr *attr, char **helper_name, + struct nlattr **helpinfo) +{ + int err; + struct nlattr *tb[CTA_HELP_MAX+1]; + + err = nla_parse_nested_deprecated(tb, CTA_HELP_MAX, attr, + help_nla_policy, NULL); + if (err < 0) + return err; + + if (!tb[CTA_HELP_NAME]) + return -EINVAL; + + *helper_name = nla_data(tb[CTA_HELP_NAME]); + + if (tb[CTA_HELP_INFO]) + *helpinfo = tb[CTA_HELP_INFO]; + + return 0; +} + +static const struct nla_policy ct_nla_policy[CTA_MAX+1] = { + [CTA_TUPLE_ORIG] = { .type = NLA_NESTED }, + [CTA_TUPLE_REPLY] = { .type = NLA_NESTED }, + [CTA_STATUS] = { .type = NLA_U32 }, + [CTA_PROTOINFO] = { .type = NLA_NESTED }, + [CTA_HELP] = { .type = NLA_NESTED }, + [CTA_NAT_SRC] = { .type = NLA_NESTED }, + [CTA_TIMEOUT] = { .type = NLA_U32 }, + [CTA_MARK] = { .type = NLA_U32 }, + [CTA_ID] = { .type = NLA_U32 }, + [CTA_NAT_DST] = { .type = NLA_NESTED }, + [CTA_TUPLE_MASTER] = { .type = NLA_NESTED }, + [CTA_NAT_SEQ_ADJ_ORIG] = { .type = NLA_NESTED }, + [CTA_NAT_SEQ_ADJ_REPLY] = { .type = NLA_NESTED }, + [CTA_ZONE] = { .type = NLA_U16 }, + [CTA_MARK_MASK] = { .type = NLA_U32 }, + [CTA_LABELS] = { .type = NLA_BINARY, + .len = NF_CT_LABELS_MAX_SIZE }, + [CTA_LABELS_MASK] = { .type = NLA_BINARY, + .len = NF_CT_LABELS_MAX_SIZE }, + [CTA_FILTER] = { .type = NLA_NESTED }, + [CTA_STATUS_MASK] = { .type = NLA_U32 }, +}; + +static int ctnetlink_flush_iterate(struct nf_conn *ct, void *data) +{ + return ctnetlink_filter_match(ct, data); +} + +static int ctnetlink_flush_conntrack(struct net *net, + const struct nlattr * const cda[], + u32 portid, int report, u8 family) +{ + struct ctnetlink_filter *filter = NULL; + struct nf_ct_iter_data iter = { + .net = net, + .portid = portid, + .report = report, + }; + + if (ctnetlink_needs_filter(family, cda)) { + if (cda[CTA_FILTER]) + return -EOPNOTSUPP; + + filter = ctnetlink_alloc_filter(cda, family); + if (IS_ERR(filter)) + return PTR_ERR(filter); + + iter.data = filter; + } + + nf_ct_iterate_cleanup_net(ctnetlink_flush_iterate, &iter); + kfree(filter); + + return 0; +} + +static int ctnetlink_del_conntrack(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + u8 family = info->nfmsg->nfgen_family; + struct nf_conntrack_tuple_hash *h; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_zone zone; + struct nf_conn *ct; + int err; + + err = ctnetlink_parse_zone(cda[CTA_ZONE], &zone); + if (err < 0) + return err; + + if (cda[CTA_TUPLE_ORIG]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_ORIG, + family, &zone); + else if (cda[CTA_TUPLE_REPLY]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_REPLY, + family, &zone); + else { + u_int8_t u3 = info->nfmsg->version ? 
family : AF_UNSPEC; + + return ctnetlink_flush_conntrack(info->net, cda, + NETLINK_CB(skb).portid, + nlmsg_report(info->nlh), u3); + } + + if (err < 0) + return err; + + h = nf_conntrack_find_get(info->net, &zone, &tuple); + if (!h) + return -ENOENT; + + ct = nf_ct_tuplehash_to_ctrack(h); + + if (cda[CTA_ID]) { + __be32 id = nla_get_be32(cda[CTA_ID]); + + if (id != (__force __be32)nf_ct_get_id(ct)) { + nf_ct_put(ct); + return -ENOENT; + } + } + + nf_ct_delete(ct, NETLINK_CB(skb).portid, nlmsg_report(info->nlh)); + nf_ct_put(ct); + + return 0; +} + +static int ctnetlink_get_conntrack(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + u_int8_t u3 = info->nfmsg->nfgen_family; + struct nf_conntrack_tuple_hash *h; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_zone zone; + struct sk_buff *skb2; + struct nf_conn *ct; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = ctnetlink_start, + .dump = ctnetlink_dump_table, + .done = ctnetlink_done, + .data = (void *)cda, + }; + + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + err = ctnetlink_parse_zone(cda[CTA_ZONE], &zone); + if (err < 0) + return err; + + if (cda[CTA_TUPLE_ORIG]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_ORIG, + u3, &zone); + else if (cda[CTA_TUPLE_REPLY]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_TUPLE_REPLY, + u3, &zone); + else + return -EINVAL; + + if (err < 0) + return err; + + h = nf_conntrack_find_get(info->net, &zone, &tuple); + if (!h) + return -ENOENT; + + ct = nf_ct_tuplehash_to_ctrack(h); + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) { + nf_ct_put(ct); + return -ENOMEM; + } + + err = ctnetlink_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), ct, + true, 0); + nf_ct_put(ct); + if (err <= 0) { + kfree_skb(skb2); + return -ENOMEM; + } + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); +} + +static int ctnetlink_done_list(struct netlink_callback *cb) +{ + struct ctnetlink_list_dump_ctx *ctx = (void *)cb->ctx; + + if (ctx->last) + nf_ct_put(ctx->last); + + return 0; +} + +#ifdef CONFIG_NF_CONNTRACK_EVENTS +static int ctnetlink_dump_one_entry(struct sk_buff *skb, + struct netlink_callback *cb, + struct nf_conn *ct, + bool dying) +{ + struct ctnetlink_list_dump_ctx *ctx = (void *)cb->ctx; + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + u8 l3proto = nfmsg->nfgen_family; + int res; + + if (l3proto && nf_ct_l3num(ct) != l3proto) + return 0; + + if (ctx->last) { + if (ct != ctx->last) + return 0; + + ctx->last = NULL; + } + + /* We can't dump extension info for the unconfirmed + * list because unconfirmed conntracks can have + * ct->ext reallocated (and thus freed). + * + * In the dying list case ct->ext can't be free'd + * until after we drop pcpu->lock. 
+ */ + res = ctnetlink_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFNL_MSG_TYPE(cb->nlh->nlmsg_type), + ct, dying, 0); + if (res < 0) { + if (!refcount_inc_not_zero(&ct->ct_general.use)) + return 0; + + ctx->last = ct; + } + + return res; +} +#endif + +static int +ctnetlink_dump_unconfirmed(struct sk_buff *skb, struct netlink_callback *cb) +{ + return 0; +} + +static int +ctnetlink_dump_dying(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ctnetlink_list_dump_ctx *ctx = (void *)cb->ctx; + struct nf_conn *last = ctx->last; +#ifdef CONFIG_NF_CONNTRACK_EVENTS + const struct net *net = sock_net(skb->sk); + struct nf_conntrack_net_ecache *ecache_net; + struct nf_conntrack_tuple_hash *h; + struct hlist_nulls_node *n; +#endif + + if (ctx->done) + return 0; + + ctx->last = NULL; + +#ifdef CONFIG_NF_CONNTRACK_EVENTS + ecache_net = nf_conn_pernet_ecache(net); + spin_lock_bh(&ecache_net->dying_lock); + + hlist_nulls_for_each_entry(h, n, &ecache_net->dying_list, hnnode) { + struct nf_conn *ct; + int res; + + ct = nf_ct_tuplehash_to_ctrack(h); + if (last && last != ct) + continue; + + res = ctnetlink_dump_one_entry(skb, cb, ct, true); + if (res < 0) { + spin_unlock_bh(&ecache_net->dying_lock); + nf_ct_put(last); + return skb->len; + } + + nf_ct_put(last); + last = NULL; + } + + spin_unlock_bh(&ecache_net->dying_lock); +#endif + ctx->done = true; + nf_ct_put(last); + + return skb->len; +} + +static int ctnetlink_get_ct_dying(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = ctnetlink_dump_dying, + .done = ctnetlink_done_list, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + return -EOPNOTSUPP; +} + +static int ctnetlink_get_ct_unconfirmed(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = ctnetlink_dump_unconfirmed, + .done = ctnetlink_done_list, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + return -EOPNOTSUPP; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static int +ctnetlink_parse_nat_setup(struct nf_conn *ct, + enum nf_nat_manip_type manip, + const struct nlattr *attr) + __must_hold(RCU) +{ + const struct nf_nat_hook *nat_hook; + int err; + + nat_hook = rcu_dereference(nf_nat_hook); + if (!nat_hook) { +#ifdef CONFIG_MODULES + rcu_read_unlock(); + nfnl_unlock(NFNL_SUBSYS_CTNETLINK); + if (request_module("nf-nat") < 0) { + nfnl_lock(NFNL_SUBSYS_CTNETLINK); + rcu_read_lock(); + return -EOPNOTSUPP; + } + nfnl_lock(NFNL_SUBSYS_CTNETLINK); + rcu_read_lock(); + nat_hook = rcu_dereference(nf_nat_hook); + if (nat_hook) + return -EAGAIN; +#endif + return -EOPNOTSUPP; + } + + err = nat_hook->parse_nat_setup(ct, manip, attr); + if (err == -EAGAIN) { +#ifdef CONFIG_MODULES + rcu_read_unlock(); + nfnl_unlock(NFNL_SUBSYS_CTNETLINK); + if (request_module("nf-nat-%u", nf_ct_l3num(ct)) < 0) { + nfnl_lock(NFNL_SUBSYS_CTNETLINK); + rcu_read_lock(); + return -EOPNOTSUPP; + } + nfnl_lock(NFNL_SUBSYS_CTNETLINK); + rcu_read_lock(); +#else + err = -EOPNOTSUPP; +#endif + } + return err; +} +#endif + +static int +ctnetlink_change_status(struct nf_conn *ct, const struct nlattr * const cda[]) +{ + return nf_ct_change_status_common(ct, ntohl(nla_get_be32(cda[CTA_STATUS]))); +} + +static int +ctnetlink_setup_nat(struct nf_conn *ct, const struct nlattr * const cda[]) +{ +#if 
IS_ENABLED(CONFIG_NF_NAT) + int ret; + + if (!cda[CTA_NAT_DST] && !cda[CTA_NAT_SRC]) + return 0; + + ret = ctnetlink_parse_nat_setup(ct, NF_NAT_MANIP_DST, + cda[CTA_NAT_DST]); + if (ret < 0) + return ret; + + return ctnetlink_parse_nat_setup(ct, NF_NAT_MANIP_SRC, + cda[CTA_NAT_SRC]); +#else + if (!cda[CTA_NAT_DST] && !cda[CTA_NAT_SRC]) + return 0; + return -EOPNOTSUPP; +#endif +} + +static int ctnetlink_change_helper(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + struct nf_conntrack_helper *helper; + struct nf_conn_help *help = nfct_help(ct); + char *helpname = NULL; + struct nlattr *helpinfo = NULL; + int err; + + err = ctnetlink_parse_help(cda[CTA_HELP], &helpname, &helpinfo); + if (err < 0) + return err; + + /* don't change helper of sibling connections */ + if (ct->master) { + /* If we try to change the helper to the same thing twice, + * treat the second attempt as a no-op instead of returning + * an error. + */ + err = -EBUSY; + if (help) { + rcu_read_lock(); + helper = rcu_dereference(help->helper); + if (helper && !strcmp(helper->name, helpname)) + err = 0; + rcu_read_unlock(); + } + + return err; + } + + if (!strcmp(helpname, "")) { + if (help && help->helper) { + /* we had a helper before ... */ + nf_ct_remove_expectations(ct); + RCU_INIT_POINTER(help->helper, NULL); + } + + return 0; + } + + rcu_read_lock(); + helper = __nf_conntrack_helper_find(helpname, nf_ct_l3num(ct), + nf_ct_protonum(ct)); + if (helper == NULL) { + rcu_read_unlock(); + return -EOPNOTSUPP; + } + + if (help) { + if (rcu_access_pointer(help->helper) == helper) { + /* update private helper data if allowed. */ + if (helper->from_nlattr) + helper->from_nlattr(helpinfo, ct); + err = 0; + } else + err = -EBUSY; + } else { + /* we cannot set a helper for an existing conntrack */ + err = -EOPNOTSUPP; + } + + rcu_read_unlock(); + return err; +} + +static int ctnetlink_change_timeout(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + return __nf_ct_change_timeout(ct, (u64)ntohl(nla_get_be32(cda[CTA_TIMEOUT])) * HZ); +} + +#if defined(CONFIG_NF_CONNTRACK_MARK) +static void ctnetlink_change_mark(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + u32 mark, newmark, mask = 0; + + if (cda[CTA_MARK_MASK]) + mask = ~ntohl(nla_get_be32(cda[CTA_MARK_MASK])); + + mark = ntohl(nla_get_be32(cda[CTA_MARK])); + newmark = (READ_ONCE(ct->mark) & mask) ^ mark; + if (newmark != READ_ONCE(ct->mark)) + WRITE_ONCE(ct->mark, newmark); +} +#endif + +static const struct nla_policy protoinfo_policy[CTA_PROTOINFO_MAX+1] = { + [CTA_PROTOINFO_TCP] = { .type = NLA_NESTED }, + [CTA_PROTOINFO_DCCP] = { .type = NLA_NESTED }, + [CTA_PROTOINFO_SCTP] = { .type = NLA_NESTED }, +}; + +static int ctnetlink_change_protoinfo(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + const struct nlattr *attr = cda[CTA_PROTOINFO]; + const struct nf_conntrack_l4proto *l4proto; + struct nlattr *tb[CTA_PROTOINFO_MAX+1]; + int err = 0; + + err = nla_parse_nested_deprecated(tb, CTA_PROTOINFO_MAX, attr, + protoinfo_policy, NULL); + if (err < 0) + return err; + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + if (l4proto->from_nlattr) + err = l4proto->from_nlattr(tb, ct); + + return err; +} + +static const struct nla_policy seqadj_policy[CTA_SEQADJ_MAX+1] = { + [CTA_SEQADJ_CORRECTION_POS] = { .type = NLA_U32 }, + [CTA_SEQADJ_OFFSET_BEFORE] = { .type = NLA_U32 }, + [CTA_SEQADJ_OFFSET_AFTER] = { .type = NLA_U32 }, +}; + +static int change_seq_adj(struct nf_ct_seqadj *seq, + const struct nlattr * const attr) +{ + int err; + 
struct nlattr *cda[CTA_SEQADJ_MAX+1]; + + err = nla_parse_nested_deprecated(cda, CTA_SEQADJ_MAX, attr, + seqadj_policy, NULL); + if (err < 0) + return err; + + if (!cda[CTA_SEQADJ_CORRECTION_POS]) + return -EINVAL; + + seq->correction_pos = + ntohl(nla_get_be32(cda[CTA_SEQADJ_CORRECTION_POS])); + + if (!cda[CTA_SEQADJ_OFFSET_BEFORE]) + return -EINVAL; + + seq->offset_before = + ntohl(nla_get_be32(cda[CTA_SEQADJ_OFFSET_BEFORE])); + + if (!cda[CTA_SEQADJ_OFFSET_AFTER]) + return -EINVAL; + + seq->offset_after = + ntohl(nla_get_be32(cda[CTA_SEQADJ_OFFSET_AFTER])); + + return 0; +} + +static int +ctnetlink_change_seq_adj(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + int ret = 0; + + if (!seqadj) + return 0; + + spin_lock_bh(&ct->lock); + if (cda[CTA_SEQ_ADJ_ORIG]) { + ret = change_seq_adj(&seqadj->seq[IP_CT_DIR_ORIGINAL], + cda[CTA_SEQ_ADJ_ORIG]); + if (ret < 0) + goto err; + + set_bit(IPS_SEQ_ADJUST_BIT, &ct->status); + } + + if (cda[CTA_SEQ_ADJ_REPLY]) { + ret = change_seq_adj(&seqadj->seq[IP_CT_DIR_REPLY], + cda[CTA_SEQ_ADJ_REPLY]); + if (ret < 0) + goto err; + + set_bit(IPS_SEQ_ADJUST_BIT, &ct->status); + } + + spin_unlock_bh(&ct->lock); + return 0; +err: + spin_unlock_bh(&ct->lock); + return ret; +} + +static const struct nla_policy synproxy_policy[CTA_SYNPROXY_MAX + 1] = { + [CTA_SYNPROXY_ISN] = { .type = NLA_U32 }, + [CTA_SYNPROXY_ITS] = { .type = NLA_U32 }, + [CTA_SYNPROXY_TSOFF] = { .type = NLA_U32 }, +}; + +static int ctnetlink_change_synproxy(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + struct nf_conn_synproxy *synproxy = nfct_synproxy(ct); + struct nlattr *tb[CTA_SYNPROXY_MAX + 1]; + int err; + + if (!synproxy) + return 0; + + err = nla_parse_nested_deprecated(tb, CTA_SYNPROXY_MAX, + cda[CTA_SYNPROXY], synproxy_policy, + NULL); + if (err < 0) + return err; + + if (!tb[CTA_SYNPROXY_ISN] || + !tb[CTA_SYNPROXY_ITS] || + !tb[CTA_SYNPROXY_TSOFF]) + return -EINVAL; + + synproxy->isn = ntohl(nla_get_be32(tb[CTA_SYNPROXY_ISN])); + synproxy->its = ntohl(nla_get_be32(tb[CTA_SYNPROXY_ITS])); + synproxy->tsoff = ntohl(nla_get_be32(tb[CTA_SYNPROXY_TSOFF])); + + return 0; +} + +static int +ctnetlink_attach_labels(struct nf_conn *ct, const struct nlattr * const cda[]) +{ +#ifdef CONFIG_NF_CONNTRACK_LABELS + size_t len = nla_len(cda[CTA_LABELS]); + const void *mask = cda[CTA_LABELS_MASK]; + + if (len & (sizeof(u32)-1)) /* must be multiple of u32 */ + return -EINVAL; + + if (mask) { + if (nla_len(cda[CTA_LABELS_MASK]) == 0 || + nla_len(cda[CTA_LABELS_MASK]) != len) + return -EINVAL; + mask = nla_data(cda[CTA_LABELS_MASK]); + } + + len /= sizeof(u32); + + return nf_connlabels_replace(ct, nla_data(cda[CTA_LABELS]), mask, len); +#else + return -EOPNOTSUPP; +#endif +} + +static int +ctnetlink_change_conntrack(struct nf_conn *ct, + const struct nlattr * const cda[]) +{ + int err; + + /* only allow NAT changes and master assignation for new conntracks */ + if (cda[CTA_NAT_SRC] || cda[CTA_NAT_DST] || cda[CTA_TUPLE_MASTER]) + return -EOPNOTSUPP; + + if (cda[CTA_HELP]) { + err = ctnetlink_change_helper(ct, cda); + if (err < 0) + return err; + } + + if (cda[CTA_TIMEOUT]) { + err = ctnetlink_change_timeout(ct, cda); + if (err < 0) + return err; + } + + if (cda[CTA_STATUS]) { + err = ctnetlink_change_status(ct, cda); + if (err < 0) + return err; + } + + if (cda[CTA_PROTOINFO]) { + err = ctnetlink_change_protoinfo(ct, cda); + if (err < 0) + return err; + } + +#if defined(CONFIG_NF_CONNTRACK_MARK) + if (cda[CTA_MARK]) + 
ctnetlink_change_mark(ct, cda); +#endif + + if (cda[CTA_SEQ_ADJ_ORIG] || cda[CTA_SEQ_ADJ_REPLY]) { + err = ctnetlink_change_seq_adj(ct, cda); + if (err < 0) + return err; + } + + if (cda[CTA_SYNPROXY]) { + err = ctnetlink_change_synproxy(ct, cda); + if (err < 0) + return err; + } + + if (cda[CTA_LABELS]) { + err = ctnetlink_attach_labels(ct, cda); + if (err < 0) + return err; + } + + return 0; +} + +static struct nf_conn * +ctnetlink_create_conntrack(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nlattr * const cda[], + struct nf_conntrack_tuple *otuple, + struct nf_conntrack_tuple *rtuple, + u8 u3) +{ + struct nf_conn *ct; + int err = -EINVAL; + struct nf_conntrack_helper *helper; + struct nf_conn_tstamp *tstamp; + u64 timeout; + + ct = nf_conntrack_alloc(net, zone, otuple, rtuple, GFP_ATOMIC); + if (IS_ERR(ct)) + return ERR_PTR(-ENOMEM); + + if (!cda[CTA_TIMEOUT]) + goto err1; + + rcu_read_lock(); + if (cda[CTA_HELP]) { + char *helpname = NULL; + struct nlattr *helpinfo = NULL; + + err = ctnetlink_parse_help(cda[CTA_HELP], &helpname, &helpinfo); + if (err < 0) + goto err2; + + helper = __nf_conntrack_helper_find(helpname, nf_ct_l3num(ct), + nf_ct_protonum(ct)); + if (helper == NULL) { + rcu_read_unlock(); +#ifdef CONFIG_MODULES + if (request_module("nfct-helper-%s", helpname) < 0) { + err = -EOPNOTSUPP; + goto err1; + } + + rcu_read_lock(); + helper = __nf_conntrack_helper_find(helpname, + nf_ct_l3num(ct), + nf_ct_protonum(ct)); + if (helper) { + err = -EAGAIN; + goto err2; + } + rcu_read_unlock(); +#endif + err = -EOPNOTSUPP; + goto err1; + } else { + struct nf_conn_help *help; + + help = nf_ct_helper_ext_add(ct, GFP_ATOMIC); + if (help == NULL) { + err = -ENOMEM; + goto err2; + } + /* set private helper data if allowed. */ + if (helper->from_nlattr) + helper->from_nlattr(helpinfo, ct); + + /* disable helper auto-assignment for this entry */ + ct->status |= IPS_HELPER; + RCU_INIT_POINTER(help->helper, helper); + } + } + + err = ctnetlink_setup_nat(ct, cda); + if (err < 0) + goto err2; + + nf_ct_acct_ext_add(ct, GFP_ATOMIC); + nf_ct_tstamp_ext_add(ct, GFP_ATOMIC); + nf_ct_ecache_ext_add(ct, 0, 0, GFP_ATOMIC); + nf_ct_labels_ext_add(ct); + nfct_seqadj_ext_add(ct); + nfct_synproxy_ext_add(ct); + + /* we must add conntrack extensions before confirmation. 
*/ + ct->status |= IPS_CONFIRMED; + + timeout = (u64)ntohl(nla_get_be32(cda[CTA_TIMEOUT])) * HZ; + __nf_ct_set_timeout(ct, timeout); + + if (cda[CTA_STATUS]) { + err = ctnetlink_change_status(ct, cda); + if (err < 0) + goto err2; + } + + if (cda[CTA_SEQ_ADJ_ORIG] || cda[CTA_SEQ_ADJ_REPLY]) { + err = ctnetlink_change_seq_adj(ct, cda); + if (err < 0) + goto err2; + } + + memset(&ct->proto, 0, sizeof(ct->proto)); + if (cda[CTA_PROTOINFO]) { + err = ctnetlink_change_protoinfo(ct, cda); + if (err < 0) + goto err2; + } + + if (cda[CTA_SYNPROXY]) { + err = ctnetlink_change_synproxy(ct, cda); + if (err < 0) + goto err2; + } + +#if defined(CONFIG_NF_CONNTRACK_MARK) + if (cda[CTA_MARK]) + ctnetlink_change_mark(ct, cda); +#endif + + /* setup master conntrack: this is a confirmed expectation */ + if (cda[CTA_TUPLE_MASTER]) { + struct nf_conntrack_tuple master; + struct nf_conntrack_tuple_hash *master_h; + struct nf_conn *master_ct; + + err = ctnetlink_parse_tuple(cda, &master, CTA_TUPLE_MASTER, + u3, NULL); + if (err < 0) + goto err2; + + master_h = nf_conntrack_find_get(net, zone, &master); + if (master_h == NULL) { + err = -ENOENT; + goto err2; + } + master_ct = nf_ct_tuplehash_to_ctrack(master_h); + __set_bit(IPS_EXPECTED_BIT, &ct->status); + ct->master = master_ct; + } + tstamp = nf_conn_tstamp_find(ct); + if (tstamp) + tstamp->start = ktime_get_real_ns(); + + err = nf_conntrack_hash_check_insert(ct); + if (err < 0) + goto err3; + + rcu_read_unlock(); + + return ct; + +err3: + if (ct->master) + nf_ct_put(ct->master); +err2: + rcu_read_unlock(); +err1: + nf_conntrack_free(ct); + return ERR_PTR(err); +} + +static int ctnetlink_new_conntrack(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + struct nf_conntrack_tuple otuple, rtuple; + struct nf_conntrack_tuple_hash *h = NULL; + u_int8_t u3 = info->nfmsg->nfgen_family; + struct nf_conntrack_zone zone; + struct nf_conn *ct; + int err; + + err = ctnetlink_parse_zone(cda[CTA_ZONE], &zone); + if (err < 0) + return err; + + if (cda[CTA_TUPLE_ORIG]) { + err = ctnetlink_parse_tuple(cda, &otuple, CTA_TUPLE_ORIG, + u3, &zone); + if (err < 0) + return err; + } + + if (cda[CTA_TUPLE_REPLY]) { + err = ctnetlink_parse_tuple(cda, &rtuple, CTA_TUPLE_REPLY, + u3, &zone); + if (err < 0) + return err; + } + + if (cda[CTA_TUPLE_ORIG]) + h = nf_conntrack_find_get(info->net, &zone, &otuple); + else if (cda[CTA_TUPLE_REPLY]) + h = nf_conntrack_find_get(info->net, &zone, &rtuple); + + if (h == NULL) { + err = -ENOENT; + if (info->nlh->nlmsg_flags & NLM_F_CREATE) { + enum ip_conntrack_events events; + + if (!cda[CTA_TUPLE_ORIG] || !cda[CTA_TUPLE_REPLY]) + return -EINVAL; + if (otuple.dst.protonum != rtuple.dst.protonum) + return -EINVAL; + + ct = ctnetlink_create_conntrack(info->net, &zone, cda, + &otuple, &rtuple, u3); + if (IS_ERR(ct)) + return PTR_ERR(ct); + + err = 0; + if (test_bit(IPS_EXPECTED_BIT, &ct->status)) + events = 1 << IPCT_RELATED; + else + events = 1 << IPCT_NEW; + + if (cda[CTA_LABELS] && + ctnetlink_attach_labels(ct, cda) == 0) + events |= (1 << IPCT_LABEL); + + nf_conntrack_eventmask_report((1 << IPCT_REPLY) | + (1 << IPCT_ASSURED) | + (1 << IPCT_HELPER) | + (1 << IPCT_PROTOINFO) | + (1 << IPCT_SEQADJ) | + (1 << IPCT_MARK) | + (1 << IPCT_SYNPROXY) | + events, + ct, NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + nf_ct_put(ct); + } + + return err; + } + /* implicit 'else' */ + + err = -EEXIST; + ct = nf_ct_tuplehash_to_ctrack(h); + if (!(info->nlh->nlmsg_flags & NLM_F_EXCL)) { + err = 
ctnetlink_change_conntrack(ct, cda); + if (err == 0) { + nf_conntrack_eventmask_report((1 << IPCT_REPLY) | + (1 << IPCT_ASSURED) | + (1 << IPCT_HELPER) | + (1 << IPCT_LABEL) | + (1 << IPCT_PROTOINFO) | + (1 << IPCT_SEQADJ) | + (1 << IPCT_MARK) | + (1 << IPCT_SYNPROXY), + ct, NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + } + } + + nf_ct_put(ct); + return err; +} + +static int +ctnetlink_ct_stat_cpu_fill_info(struct sk_buff *skb, u32 portid, u32 seq, + __u16 cpu, const struct ip_conntrack_stat *st) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0, event; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK, + IPCTNL_MSG_CT_GET_STATS_CPU); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, htons(cpu)); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_be32(skb, CTA_STATS_FOUND, htonl(st->found)) || + nla_put_be32(skb, CTA_STATS_INVALID, htonl(st->invalid)) || + nla_put_be32(skb, CTA_STATS_INSERT, htonl(st->insert)) || + nla_put_be32(skb, CTA_STATS_INSERT_FAILED, + htonl(st->insert_failed)) || + nla_put_be32(skb, CTA_STATS_DROP, htonl(st->drop)) || + nla_put_be32(skb, CTA_STATS_EARLY_DROP, htonl(st->early_drop)) || + nla_put_be32(skb, CTA_STATS_ERROR, htonl(st->error)) || + nla_put_be32(skb, CTA_STATS_SEARCH_RESTART, + htonl(st->search_restart)) || + nla_put_be32(skb, CTA_STATS_CLASH_RESOLVE, + htonl(st->clash_resolve)) || + nla_put_be32(skb, CTA_STATS_CHAIN_TOOLONG, + htonl(st->chaintoolong))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nla_put_failure: +nlmsg_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int +ctnetlink_ct_stat_cpu_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int cpu; + struct net *net = sock_net(skb->sk); + + if (cb->args[0] == nr_cpu_ids) + return 0; + + for (cpu = cb->args[0]; cpu < nr_cpu_ids; cpu++) { + const struct ip_conntrack_stat *st; + + if (!cpu_possible(cpu)) + continue; + + st = per_cpu_ptr(net->ct.stat, cpu); + if (ctnetlink_ct_stat_cpu_fill_info(skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + cpu, st) < 0) + break; + } + cb->args[0] = cpu; + + return skb->len; +} + +static int ctnetlink_stat_ct_cpu(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = ctnetlink_ct_stat_cpu_dump, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + return 0; +} + +static int +ctnetlink_stat_ct_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + struct net *net) +{ + unsigned int flags = portid ? 
NLM_F_MULTI : 0, event; + unsigned int nr_conntracks; + struct nlmsghdr *nlh; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK, IPCTNL_MSG_CT_GET_STATS); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + nr_conntracks = nf_conntrack_count(net); + if (nla_put_be32(skb, CTA_STATS_GLOBAL_ENTRIES, htonl(nr_conntracks))) + goto nla_put_failure; + + if (nla_put_be32(skb, CTA_STATS_GLOBAL_MAX_ENTRIES, htonl(nf_conntrack_max))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nla_put_failure: +nlmsg_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int ctnetlink_stat_ct(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + struct sk_buff *skb2; + int err; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb2 == NULL) + return -ENOMEM; + + err = ctnetlink_stat_ct_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + sock_net(skb->sk)); + if (err <= 0) { + kfree_skb(skb2); + return -ENOMEM; + } + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); +} + +static const struct nla_policy exp_nla_policy[CTA_EXPECT_MAX+1] = { + [CTA_EXPECT_MASTER] = { .type = NLA_NESTED }, + [CTA_EXPECT_TUPLE] = { .type = NLA_NESTED }, + [CTA_EXPECT_MASK] = { .type = NLA_NESTED }, + [CTA_EXPECT_TIMEOUT] = { .type = NLA_U32 }, + [CTA_EXPECT_ID] = { .type = NLA_U32 }, + [CTA_EXPECT_HELP_NAME] = { .type = NLA_NUL_STRING, + .len = NF_CT_HELPER_NAME_LEN - 1 }, + [CTA_EXPECT_ZONE] = { .type = NLA_U16 }, + [CTA_EXPECT_FLAGS] = { .type = NLA_U32 }, + [CTA_EXPECT_CLASS] = { .type = NLA_U32 }, + [CTA_EXPECT_NAT] = { .type = NLA_NESTED }, + [CTA_EXPECT_FN] = { .type = NLA_NUL_STRING }, +}; + +static struct nf_conntrack_expect * +ctnetlink_alloc_expect(const struct nlattr *const cda[], struct nf_conn *ct, + struct nf_conntrack_helper *helper, + struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple *mask); + +#ifdef CONFIG_NETFILTER_NETLINK_GLUE_CT +static size_t +ctnetlink_glue_build_size(const struct nf_conn *ct) +{ + return 3 * nla_total_size(0) /* CTA_TUPLE_ORIG|REPL|MASTER */ + + 3 * nla_total_size(0) /* CTA_TUPLE_IP */ + + 3 * nla_total_size(0) /* CTA_TUPLE_PROTO */ + + 3 * nla_total_size(sizeof(u_int8_t)) /* CTA_PROTO_NUM */ + + nla_total_size(sizeof(u_int32_t)) /* CTA_ID */ + + nla_total_size(sizeof(u_int32_t)) /* CTA_STATUS */ + + nla_total_size(sizeof(u_int32_t)) /* CTA_TIMEOUT */ + + nla_total_size(0) /* CTA_PROTOINFO */ + + nla_total_size(0) /* CTA_HELP */ + + nla_total_size(NF_CT_HELPER_NAME_LEN) /* CTA_HELP_NAME */ + + ctnetlink_secctx_size(ct) + + ctnetlink_acct_size(ct) + + ctnetlink_timestamp_size(ct) +#if IS_ENABLED(CONFIG_NF_NAT) + + 2 * nla_total_size(0) /* CTA_NAT_SEQ_ADJ_ORIG|REPL */ + + 6 * nla_total_size(sizeof(u_int32_t)) /* CTA_NAT_SEQ_OFFSET */ +#endif +#ifdef CONFIG_NF_CONNTRACK_MARK + + nla_total_size(sizeof(u_int32_t)) /* CTA_MARK */ +#endif +#ifdef CONFIG_NF_CONNTRACK_ZONES + + nla_total_size(sizeof(u_int16_t)) /* CTA_ZONE|CTA_TUPLE_ZONE */ +#endif + + ctnetlink_proto_size(ct) + ; +} + +static int __ctnetlink_glue_build(struct sk_buff *skb, struct nf_conn *ct) +{ + const struct nf_conntrack_zone *zone; + struct nlattr *nest_parms; + + zone = nf_ct_zone(ct); + + nest_parms = nla_nest_start(skb, CTA_TUPLE_ORIG); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_ORIGINAL)) < 0) + goto nla_put_failure; + if 
(ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_ORIG) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + nest_parms = nla_nest_start(skb, CTA_TUPLE_REPLY); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, nf_ct_tuple(ct, IP_CT_DIR_REPLY)) < 0) + goto nla_put_failure; + if (ctnetlink_dump_zone_id(skb, CTA_TUPLE_ZONE, zone, + NF_CT_ZONE_DIR_REPL) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + if (ctnetlink_dump_zone_id(skb, CTA_ZONE, zone, + NF_CT_DEFAULT_ZONE_DIR) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_id(skb, ct) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_status(skb, ct) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_timeout(skb, ct, false) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_protoinfo(skb, ct, false) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_acct(skb, ct, IPCTNL_MSG_CT_GET) < 0 || + ctnetlink_dump_timestamp(skb, ct) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_helpinfo(skb, ct) < 0) + goto nla_put_failure; + +#ifdef CONFIG_NF_CONNTRACK_SECMARK + if (ct->secmark && ctnetlink_dump_secctx(skb, ct) < 0) + goto nla_put_failure; +#endif + if (ct->master && ctnetlink_dump_master(skb, ct) < 0) + goto nla_put_failure; + + if ((ct->status & IPS_SEQ_ADJUST) && + ctnetlink_dump_ct_seq_adj(skb, ct) < 0) + goto nla_put_failure; + + if (ctnetlink_dump_ct_synproxy(skb, ct) < 0) + goto nla_put_failure; + +#ifdef CONFIG_NF_CONNTRACK_MARK + if (ctnetlink_dump_mark(skb, ct, true) < 0) + goto nla_put_failure; +#endif + if (ctnetlink_dump_labels(skb, ct) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static int +ctnetlink_glue_build(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + u_int16_t ct_attr, u_int16_t ct_info_attr) +{ + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, ct_attr); + if (!nest_parms) + goto nla_put_failure; + + if (__ctnetlink_glue_build(skb, ct) < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + if (nla_put_be32(skb, ct_info_attr, htonl(ctinfo))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static int +ctnetlink_update_status(struct nf_conn *ct, const struct nlattr * const cda[]) +{ + unsigned int status = ntohl(nla_get_be32(cda[CTA_STATUS])); + unsigned long d = ct->status ^ status; + + if (d & IPS_SEEN_REPLY && !(status & IPS_SEEN_REPLY)) + /* SEEN_REPLY bit can only be set */ + return -EBUSY; + + if (d & IPS_ASSURED && !(status & IPS_ASSURED)) + /* ASSURED bit can only be set */ + return -EBUSY; + + /* This check is less strict than ctnetlink_change_status() + * because callers often flip IPS_EXPECTED bits when sending + * an NFQA_CT attribute to the kernel. So ignore the + * unchangeable bits but do not error out. Also user programs + * are allowed to clear the bits that they are allowed to change. 
+ */ + __nf_ct_change_status(ct, status, ~status); + return 0; +} + +static int +ctnetlink_glue_parse_ct(const struct nlattr *cda[], struct nf_conn *ct) +{ + int err; + + if (cda[CTA_TIMEOUT]) { + err = ctnetlink_change_timeout(ct, cda); + if (err < 0) + return err; + } + if (cda[CTA_STATUS]) { + err = ctnetlink_update_status(ct, cda); + if (err < 0) + return err; + } + if (cda[CTA_HELP]) { + err = ctnetlink_change_helper(ct, cda); + if (err < 0) + return err; + } + if (cda[CTA_LABELS]) { + err = ctnetlink_attach_labels(ct, cda); + if (err < 0) + return err; + } +#if defined(CONFIG_NF_CONNTRACK_MARK) + if (cda[CTA_MARK]) { + ctnetlink_change_mark(ct, cda); + } +#endif + return 0; +} + +static int +ctnetlink_glue_parse(const struct nlattr *attr, struct nf_conn *ct) +{ + struct nlattr *cda[CTA_MAX+1]; + int ret; + + ret = nla_parse_nested_deprecated(cda, CTA_MAX, attr, ct_nla_policy, + NULL); + if (ret < 0) + return ret; + + return ctnetlink_glue_parse_ct((const struct nlattr **)cda, ct); +} + +static int ctnetlink_glue_exp_parse(const struct nlattr * const *cda, + const struct nf_conn *ct, + struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple *mask) +{ + int err; + + err = ctnetlink_parse_tuple(cda, tuple, CTA_EXPECT_TUPLE, + nf_ct_l3num(ct), NULL); + if (err < 0) + return err; + + return ctnetlink_parse_tuple(cda, mask, CTA_EXPECT_MASK, + nf_ct_l3num(ct), NULL); +} + +static int +ctnetlink_glue_attach_expect(const struct nlattr *attr, struct nf_conn *ct, + u32 portid, u32 report) +{ + struct nlattr *cda[CTA_EXPECT_MAX+1]; + struct nf_conntrack_tuple tuple, mask; + struct nf_conntrack_helper *helper = NULL; + struct nf_conntrack_expect *exp; + int err; + + err = nla_parse_nested_deprecated(cda, CTA_EXPECT_MAX, attr, + exp_nla_policy, NULL); + if (err < 0) + return err; + + err = ctnetlink_glue_exp_parse((const struct nlattr * const *)cda, + ct, &tuple, &mask); + if (err < 0) + return err; + + if (cda[CTA_EXPECT_HELP_NAME]) { + const char *helpname = nla_data(cda[CTA_EXPECT_HELP_NAME]); + + helper = __nf_conntrack_helper_find(helpname, nf_ct_l3num(ct), + nf_ct_protonum(ct)); + if (helper == NULL) + return -EOPNOTSUPP; + } + + exp = ctnetlink_alloc_expect((const struct nlattr * const *)cda, ct, + helper, &tuple, &mask); + if (IS_ERR(exp)) + return PTR_ERR(exp); + + err = nf_ct_expect_related_report(exp, portid, report, 0); + nf_ct_expect_put(exp); + return err; +} + +static void ctnetlink_glue_seqadj(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, int diff) +{ + if (!(ct->status & IPS_NAT_MASK)) + return; + + nf_ct_tcp_seqadj_set(skb, ct, ctinfo, diff); +} + +static const struct nfnl_ct_hook ctnetlink_glue_hook = { + .build_size = ctnetlink_glue_build_size, + .build = ctnetlink_glue_build, + .parse = ctnetlink_glue_parse, + .attach_expect = ctnetlink_glue_attach_expect, + .seq_adjust = ctnetlink_glue_seqadj, +}; +#endif /* CONFIG_NETFILTER_NETLINK_GLUE_CT */ + +/*********************************************************************** + * EXPECT + ***********************************************************************/ + +static int ctnetlink_exp_dump_tuple(struct sk_buff *skb, + const struct nf_conntrack_tuple *tuple, + u32 type) +{ + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, type); + if (!nest_parms) + goto nla_put_failure; + if (ctnetlink_dump_tuples(skb, tuple) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + return -1; +} + +static int ctnetlink_exp_dump_mask(struct sk_buff *skb, + 
const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple_mask *mask) +{ + const struct nf_conntrack_l4proto *l4proto; + struct nf_conntrack_tuple m; + struct nlattr *nest_parms; + int ret; + + memset(&m, 0xFF, sizeof(m)); + memcpy(&m.src.u3, &mask->src.u3, sizeof(m.src.u3)); + m.src.u.all = mask->src.u.all; + m.src.l3num = tuple->src.l3num; + m.dst.protonum = tuple->dst.protonum; + + nest_parms = nla_nest_start(skb, CTA_EXPECT_MASK); + if (!nest_parms) + goto nla_put_failure; + + rcu_read_lock(); + ret = ctnetlink_dump_tuples_ip(skb, &m); + if (ret >= 0) { + l4proto = nf_ct_l4proto_find(tuple->dst.protonum); + ret = ctnetlink_dump_tuples_proto(skb, &m, l4proto); + } + rcu_read_unlock(); + + if (unlikely(ret < 0)) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + return -1; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static const union nf_inet_addr any_addr; +#endif + +static __be32 nf_expect_get_id(const struct nf_conntrack_expect *exp) +{ + static siphash_aligned_key_t exp_id_seed; + unsigned long a, b, c, d; + + net_get_random_once(&exp_id_seed, sizeof(exp_id_seed)); + + a = (unsigned long)exp; + b = (unsigned long)exp->helper; + c = (unsigned long)exp->master; + d = (unsigned long)siphash(&exp->tuple, sizeof(exp->tuple), &exp_id_seed); + +#ifdef CONFIG_64BIT + return (__force __be32)siphash_4u64((u64)a, (u64)b, (u64)c, (u64)d, &exp_id_seed); +#else + return (__force __be32)siphash_4u32((u32)a, (u32)b, (u32)c, (u32)d, &exp_id_seed); +#endif +} + +static int +ctnetlink_exp_dump_expect(struct sk_buff *skb, + const struct nf_conntrack_expect *exp) +{ + struct nf_conn *master = exp->master; + long timeout = ((long)exp->timeout.expires - (long)jiffies) / HZ; + struct nf_conn_help *help; +#if IS_ENABLED(CONFIG_NF_NAT) + struct nlattr *nest_parms; + struct nf_conntrack_tuple nat_tuple = {}; +#endif + struct nf_ct_helper_expectfn *expfn; + + if (timeout < 0) + timeout = 0; + + if (ctnetlink_exp_dump_tuple(skb, &exp->tuple, CTA_EXPECT_TUPLE) < 0) + goto nla_put_failure; + if (ctnetlink_exp_dump_mask(skb, &exp->tuple, &exp->mask) < 0) + goto nla_put_failure; + if (ctnetlink_exp_dump_tuple(skb, + &master->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + CTA_EXPECT_MASTER) < 0) + goto nla_put_failure; + +#if IS_ENABLED(CONFIG_NF_NAT) + if (!nf_inet_addr_cmp(&exp->saved_addr, &any_addr) || + exp->saved_proto.all) { + nest_parms = nla_nest_start(skb, CTA_EXPECT_NAT); + if (!nest_parms) + goto nla_put_failure; + + if (nla_put_be32(skb, CTA_EXPECT_NAT_DIR, htonl(exp->dir))) + goto nla_put_failure; + + nat_tuple.src.l3num = nf_ct_l3num(master); + nat_tuple.src.u3 = exp->saved_addr; + nat_tuple.dst.protonum = nf_ct_protonum(master); + nat_tuple.src.u = exp->saved_proto; + + if (ctnetlink_exp_dump_tuple(skb, &nat_tuple, + CTA_EXPECT_NAT_TUPLE) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest_parms); + } +#endif + if (nla_put_be32(skb, CTA_EXPECT_TIMEOUT, htonl(timeout)) || + nla_put_be32(skb, CTA_EXPECT_ID, nf_expect_get_id(exp)) || + nla_put_be32(skb, CTA_EXPECT_FLAGS, htonl(exp->flags)) || + nla_put_be32(skb, CTA_EXPECT_CLASS, htonl(exp->class))) + goto nla_put_failure; + help = nfct_help(master); + if (help) { + struct nf_conntrack_helper *helper; + + helper = rcu_dereference(help->helper); + if (helper && + nla_put_string(skb, CTA_EXPECT_HELP_NAME, helper->name)) + goto nla_put_failure; + } + expfn = nf_ct_helper_expectfn_find_by_symbol(exp->expectfn); + if (expfn != NULL && + nla_put_string(skb, CTA_EXPECT_FN, expfn->name)) + goto nla_put_failure; + + 
return 0; + +nla_put_failure: + return -1; +} + +static int +ctnetlink_exp_fill_info(struct sk_buff *skb, u32 portid, u32 seq, + int event, const struct nf_conntrack_expect *exp) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_EXP, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, + exp->tuple.src.l3num, NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (ctnetlink_exp_dump_expect(skb, exp) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +#ifdef CONFIG_NF_CONNTRACK_EVENTS +static int +ctnetlink_expect_event(unsigned int events, const struct nf_exp_event *item) +{ + struct nf_conntrack_expect *exp = item->exp; + struct net *net = nf_ct_exp_net(exp); + struct nlmsghdr *nlh; + struct sk_buff *skb; + unsigned int type, group; + int flags = 0; + + if (events & (1 << IPEXP_DESTROY)) { + type = IPCTNL_MSG_EXP_DELETE; + group = NFNLGRP_CONNTRACK_EXP_DESTROY; + } else if (events & (1 << IPEXP_NEW)) { + type = IPCTNL_MSG_EXP_NEW; + flags = NLM_F_CREATE|NLM_F_EXCL; + group = NFNLGRP_CONNTRACK_EXP_NEW; + } else + return 0; + + if (!item->report && !nfnetlink_has_listeners(net, group)) + return 0; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (skb == NULL) + goto errout; + + type = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_EXP, type); + nlh = nfnl_msg_put(skb, item->portid, 0, type, flags, + exp->tuple.src.l3num, NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (ctnetlink_exp_dump_expect(skb, exp) < 0) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + nfnetlink_send(skb, net, item->portid, group, item->report, GFP_ATOMIC); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); +nlmsg_failure: + kfree_skb(skb); +errout: + nfnetlink_set_err(net, 0, 0, -ENOBUFS); + return 0; +} +#endif +static int ctnetlink_exp_done(struct netlink_callback *cb) +{ + if (cb->args[1]) + nf_ct_expect_put((struct nf_conntrack_expect *)cb->args[1]); + return 0; +} + +static int +ctnetlink_exp_dump_table(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nf_conntrack_expect *exp, *last; + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + u_int8_t l3proto = nfmsg->nfgen_family; + + rcu_read_lock(); + last = (struct nf_conntrack_expect *)cb->args[1]; + for (; cb->args[0] < nf_ct_expect_hsize; cb->args[0]++) { +restart: + hlist_for_each_entry_rcu(exp, &nf_ct_expect_hash[cb->args[0]], + hnode) { + if (l3proto && exp->tuple.src.l3num != l3proto) + continue; + + if (!net_eq(nf_ct_net(exp->master), net)) + continue; + + if (cb->args[1]) { + if (exp != last) + continue; + cb->args[1] = 0; + } + if (ctnetlink_exp_fill_info(skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + IPCTNL_MSG_EXP_NEW, + exp) < 0) { + if (!refcount_inc_not_zero(&exp->use)) + continue; + cb->args[1] = (unsigned long)exp; + goto out; + } + } + if (cb->args[1]) { + cb->args[1] = 0; + goto restart; + } + } +out: + rcu_read_unlock(); + if (last) + nf_ct_expect_put(last); + + return skb->len; +} + +static int +ctnetlink_exp_ct_dump_table(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nf_conntrack_expect *exp, *last; + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + struct nf_conn *ct = cb->data; + struct nf_conn_help *help = nfct_help(ct); + u_int8_t l3proto = nfmsg->nfgen_family; + + if (cb->args[0]) + return 0; + + rcu_read_lock(); + last = (struct nf_conntrack_expect 
*)cb->args[1]; +restart: + hlist_for_each_entry_rcu(exp, &help->expectations, lnode) { + if (l3proto && exp->tuple.src.l3num != l3proto) + continue; + if (cb->args[1]) { + if (exp != last) + continue; + cb->args[1] = 0; + } + if (ctnetlink_exp_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + IPCTNL_MSG_EXP_NEW, + exp) < 0) { + if (!refcount_inc_not_zero(&exp->use)) + continue; + cb->args[1] = (unsigned long)exp; + goto out; + } + } + if (cb->args[1]) { + cb->args[1] = 0; + goto restart; + } + cb->args[0] = 1; +out: + rcu_read_unlock(); + if (last) + nf_ct_expect_put(last); + + return skb->len; +} + +static int ctnetlink_dump_exp_ct(struct net *net, struct sock *ctnl, + struct sk_buff *skb, + const struct nlmsghdr *nlh, + const struct nlattr * const cda[], + struct netlink_ext_ack *extack) +{ + int err; + struct nfgenmsg *nfmsg = nlmsg_data(nlh); + u_int8_t u3 = nfmsg->nfgen_family; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + struct nf_conntrack_zone zone; + struct netlink_dump_control c = { + .dump = ctnetlink_exp_ct_dump_table, + .done = ctnetlink_exp_done, + }; + + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_MASTER, + u3, NULL); + if (err < 0) + return err; + + err = ctnetlink_parse_zone(cda[CTA_EXPECT_ZONE], &zone); + if (err < 0) + return err; + + h = nf_conntrack_find_get(net, &zone, &tuple); + if (!h) + return -ENOENT; + + ct = nf_ct_tuplehash_to_ctrack(h); + /* No expectation linked to this connection tracking. */ + if (!nfct_help(ct)) { + nf_ct_put(ct); + return 0; + } + + c.data = ct; + + err = netlink_dump_start(ctnl, skb, nlh, &c); + nf_ct_put(ct); + + return err; +} + +static int ctnetlink_get_expect(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + u_int8_t u3 = info->nfmsg->nfgen_family; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_expect *exp; + struct nf_conntrack_zone zone; + struct sk_buff *skb2; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + if (cda[CTA_EXPECT_MASTER]) + return ctnetlink_dump_exp_ct(info->net, info->sk, skb, + info->nlh, cda, + info->extack); + else { + struct netlink_dump_control c = { + .dump = ctnetlink_exp_dump_table, + .done = ctnetlink_exp_done, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + } + + err = ctnetlink_parse_zone(cda[CTA_EXPECT_ZONE], &zone); + if (err < 0) + return err; + + if (cda[CTA_EXPECT_TUPLE]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, + u3, NULL); + else if (cda[CTA_EXPECT_MASTER]) + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_MASTER, + u3, NULL); + else + return -EINVAL; + + if (err < 0) + return err; + + exp = nf_ct_expect_find_get(info->net, &zone, &tuple); + if (!exp) + return -ENOENT; + + if (cda[CTA_EXPECT_ID]) { + __be32 id = nla_get_be32(cda[CTA_EXPECT_ID]); + + if (id != nf_expect_get_id(exp)) { + nf_ct_expect_put(exp); + return -ENOENT; + } + } + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) { + nf_ct_expect_put(exp); + return -ENOMEM; + } + + rcu_read_lock(); + err = ctnetlink_exp_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, IPCTNL_MSG_EXP_NEW, + exp); + rcu_read_unlock(); + nf_ct_expect_put(exp); + if (err <= 0) { + kfree_skb(skb2); + return -ENOMEM; + } + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); +} + +static bool expect_iter_name(struct nf_conntrack_expect *exp, void *data) +{ + struct nf_conntrack_helper *helper; + const struct nf_conn_help *m_help; + const 
char *name = data; + + m_help = nfct_help(exp->master); + + helper = rcu_dereference(m_help->helper); + if (!helper) + return false; + + return strcmp(helper->name, name) == 0; +} + +static bool expect_iter_all(struct nf_conntrack_expect *exp, void *data) +{ + return true; +} + +static int ctnetlink_del_expect(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + u_int8_t u3 = info->nfmsg->nfgen_family; + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_zone zone; + int err; + + if (cda[CTA_EXPECT_TUPLE]) { + /* delete a single expect by tuple */ + err = ctnetlink_parse_zone(cda[CTA_EXPECT_ZONE], &zone); + if (err < 0) + return err; + + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, + u3, NULL); + if (err < 0) + return err; + + /* bump usage count to 2 */ + exp = nf_ct_expect_find_get(info->net, &zone, &tuple); + if (!exp) + return -ENOENT; + + if (cda[CTA_EXPECT_ID]) { + __be32 id = nla_get_be32(cda[CTA_EXPECT_ID]); + if (ntohl(id) != (u32)(unsigned long)exp) { + nf_ct_expect_put(exp); + return -ENOENT; + } + } + + /* after list removal, usage count == 1 */ + spin_lock_bh(&nf_conntrack_expect_lock); + if (del_timer(&exp->timeout)) { + nf_ct_unlink_expect_report(exp, NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + nf_ct_expect_put(exp); + } + spin_unlock_bh(&nf_conntrack_expect_lock); + /* have to put what we 'get' above. + * after this line usage count == 0 */ + nf_ct_expect_put(exp); + } else if (cda[CTA_EXPECT_HELP_NAME]) { + char *name = nla_data(cda[CTA_EXPECT_HELP_NAME]); + + nf_ct_expect_iterate_net(info->net, expect_iter_name, name, + NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + } else { + /* This basically means we have to flush everything*/ + nf_ct_expect_iterate_net(info->net, expect_iter_all, NULL, + NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + } + + return 0; +} +static int +ctnetlink_change_expect(struct nf_conntrack_expect *x, + const struct nlattr * const cda[]) +{ + if (cda[CTA_EXPECT_TIMEOUT]) { + if (!del_timer(&x->timeout)) + return -ETIME; + + x->timeout.expires = jiffies + + ntohl(nla_get_be32(cda[CTA_EXPECT_TIMEOUT])) * HZ; + add_timer(&x->timeout); + } + return 0; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static const struct nla_policy exp_nat_nla_policy[CTA_EXPECT_NAT_MAX+1] = { + [CTA_EXPECT_NAT_DIR] = { .type = NLA_U32 }, + [CTA_EXPECT_NAT_TUPLE] = { .type = NLA_NESTED }, +}; +#endif + +static int +ctnetlink_parse_expect_nat(const struct nlattr *attr, + struct nf_conntrack_expect *exp, + u_int8_t u3) +{ +#if IS_ENABLED(CONFIG_NF_NAT) + struct nlattr *tb[CTA_EXPECT_NAT_MAX+1]; + struct nf_conntrack_tuple nat_tuple = {}; + int err; + + err = nla_parse_nested_deprecated(tb, CTA_EXPECT_NAT_MAX, attr, + exp_nat_nla_policy, NULL); + if (err < 0) + return err; + + if (!tb[CTA_EXPECT_NAT_DIR] || !tb[CTA_EXPECT_NAT_TUPLE]) + return -EINVAL; + + err = ctnetlink_parse_tuple((const struct nlattr * const *)tb, + &nat_tuple, CTA_EXPECT_NAT_TUPLE, + u3, NULL); + if (err < 0) + return err; + + exp->saved_addr = nat_tuple.src.u3; + exp->saved_proto = nat_tuple.src.u; + exp->dir = ntohl(nla_get_be32(tb[CTA_EXPECT_NAT_DIR])); + + return 0; +#else + return -EOPNOTSUPP; +#endif +} + +static struct nf_conntrack_expect * +ctnetlink_alloc_expect(const struct nlattr * const cda[], struct nf_conn *ct, + struct nf_conntrack_helper *helper, + struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple *mask) +{ + u_int32_t class = 0; + struct nf_conntrack_expect *exp; + struct 
nf_conn_help *help; + int err; + + help = nfct_help(ct); + if (!help) + return ERR_PTR(-EOPNOTSUPP); + + if (cda[CTA_EXPECT_CLASS] && helper) { + class = ntohl(nla_get_be32(cda[CTA_EXPECT_CLASS])); + if (class > helper->expect_class_max) + return ERR_PTR(-EINVAL); + } + exp = nf_ct_expect_alloc(ct); + if (!exp) + return ERR_PTR(-ENOMEM); + + if (cda[CTA_EXPECT_FLAGS]) { + exp->flags = ntohl(nla_get_be32(cda[CTA_EXPECT_FLAGS])); + exp->flags &= ~NF_CT_EXPECT_USERSPACE; + } else { + exp->flags = 0; + } + if (cda[CTA_EXPECT_FN]) { + const char *name = nla_data(cda[CTA_EXPECT_FN]); + struct nf_ct_helper_expectfn *expfn; + + expfn = nf_ct_helper_expectfn_find_by_name(name); + if (expfn == NULL) { + err = -EINVAL; + goto err_out; + } + exp->expectfn = expfn->expectfn; + } else + exp->expectfn = NULL; + + exp->class = class; + exp->master = ct; + exp->helper = helper; + exp->tuple = *tuple; + exp->mask.src.u3 = mask->src.u3; + exp->mask.src.u.all = mask->src.u.all; + + if (cda[CTA_EXPECT_NAT]) { + err = ctnetlink_parse_expect_nat(cda[CTA_EXPECT_NAT], + exp, nf_ct_l3num(ct)); + if (err < 0) + goto err_out; + } + return exp; +err_out: + nf_ct_expect_put(exp); + return ERR_PTR(err); +} + +static int +ctnetlink_create_expect(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nlattr * const cda[], + u_int8_t u3, u32 portid, int report) +{ + struct nf_conntrack_tuple tuple, mask, master_tuple; + struct nf_conntrack_tuple_hash *h = NULL; + struct nf_conntrack_helper *helper = NULL; + struct nf_conntrack_expect *exp; + struct nf_conn *ct; + int err; + + /* caller guarantees that those three CTA_EXPECT_* exist */ + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, + u3, NULL); + if (err < 0) + return err; + err = ctnetlink_parse_tuple(cda, &mask, CTA_EXPECT_MASK, + u3, NULL); + if (err < 0) + return err; + err = ctnetlink_parse_tuple(cda, &master_tuple, CTA_EXPECT_MASTER, + u3, NULL); + if (err < 0) + return err; + + /* Look for master conntrack of this expectation */ + h = nf_conntrack_find_get(net, zone, &master_tuple); + if (!h) + return -ENOENT; + ct = nf_ct_tuplehash_to_ctrack(h); + + rcu_read_lock(); + if (cda[CTA_EXPECT_HELP_NAME]) { + const char *helpname = nla_data(cda[CTA_EXPECT_HELP_NAME]); + + helper = __nf_conntrack_helper_find(helpname, u3, + nf_ct_protonum(ct)); + if (helper == NULL) { + rcu_read_unlock(); +#ifdef CONFIG_MODULES + if (request_module("nfct-helper-%s", helpname) < 0) { + err = -EOPNOTSUPP; + goto err_ct; + } + rcu_read_lock(); + helper = __nf_conntrack_helper_find(helpname, u3, + nf_ct_protonum(ct)); + if (helper) { + err = -EAGAIN; + goto err_rcu; + } + rcu_read_unlock(); +#endif + err = -EOPNOTSUPP; + goto err_ct; + } + } + + exp = ctnetlink_alloc_expect(cda, ct, helper, &tuple, &mask); + if (IS_ERR(exp)) { + err = PTR_ERR(exp); + goto err_rcu; + } + + err = nf_ct_expect_related_report(exp, portid, report, 0); + nf_ct_expect_put(exp); +err_rcu: + rcu_read_unlock(); +err_ct: + nf_ct_put(ct); + return err; +} + +static int ctnetlink_new_expect(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + u_int8_t u3 = info->nfmsg->nfgen_family; + struct nf_conntrack_tuple tuple; + struct nf_conntrack_expect *exp; + struct nf_conntrack_zone zone; + int err; + + if (!cda[CTA_EXPECT_TUPLE] + || !cda[CTA_EXPECT_MASK] + || !cda[CTA_EXPECT_MASTER]) + return -EINVAL; + + err = ctnetlink_parse_zone(cda[CTA_EXPECT_ZONE], &zone); + if (err < 0) + return err; + + err = ctnetlink_parse_tuple(cda, &tuple, CTA_EXPECT_TUPLE, + 
u3, NULL); + if (err < 0) + return err; + + spin_lock_bh(&nf_conntrack_expect_lock); + exp = __nf_ct_expect_find(info->net, &zone, &tuple); + if (!exp) { + spin_unlock_bh(&nf_conntrack_expect_lock); + err = -ENOENT; + if (info->nlh->nlmsg_flags & NLM_F_CREATE) { + err = ctnetlink_create_expect(info->net, &zone, cda, u3, + NETLINK_CB(skb).portid, + nlmsg_report(info->nlh)); + } + return err; + } + + err = -EEXIST; + if (!(info->nlh->nlmsg_flags & NLM_F_EXCL)) + err = ctnetlink_change_expect(exp, cda); + spin_unlock_bh(&nf_conntrack_expect_lock); + + return err; +} + +static int +ctnetlink_exp_stat_fill_info(struct sk_buff *skb, u32 portid, u32 seq, int cpu, + const struct ip_conntrack_stat *st) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0, event; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK, + IPCTNL_MSG_EXP_GET_STATS_CPU); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, htons(cpu)); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_be32(skb, CTA_STATS_EXP_NEW, htonl(st->expect_new)) || + nla_put_be32(skb, CTA_STATS_EXP_CREATE, htonl(st->expect_create)) || + nla_put_be32(skb, CTA_STATS_EXP_DELETE, htonl(st->expect_delete))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nla_put_failure: +nlmsg_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int +ctnetlink_exp_stat_cpu_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int cpu; + struct net *net = sock_net(skb->sk); + + if (cb->args[0] == nr_cpu_ids) + return 0; + + for (cpu = cb->args[0]; cpu < nr_cpu_ids; cpu++) { + const struct ip_conntrack_stat *st; + + if (!cpu_possible(cpu)) + continue; + + st = per_cpu_ptr(net->ct.stat, cpu); + if (ctnetlink_exp_stat_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + cpu, st) < 0) + break; + } + cb->args[0] = cpu; + + return skb->len; +} + +static int ctnetlink_stat_exp_cpu(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = ctnetlink_exp_stat_cpu_dump, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + return 0; +} + +#ifdef CONFIG_NF_CONNTRACK_EVENTS +static struct nf_ct_event_notifier ctnl_notifier = { + .ct_event = ctnetlink_conntrack_event, + .exp_event = ctnetlink_expect_event, +}; +#endif + +static const struct nfnl_callback ctnl_cb[IPCTNL_MSG_MAX] = { + [IPCTNL_MSG_CT_NEW] = { + .call = ctnetlink_new_conntrack, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_MAX, + .policy = ct_nla_policy + }, + [IPCTNL_MSG_CT_GET] = { + .call = ctnetlink_get_conntrack, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_MAX, + .policy = ct_nla_policy + }, + [IPCTNL_MSG_CT_DELETE] = { + .call = ctnetlink_del_conntrack, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_MAX, + .policy = ct_nla_policy + }, + [IPCTNL_MSG_CT_GET_CTRZERO] = { + .call = ctnetlink_get_conntrack, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_MAX, + .policy = ct_nla_policy + }, + [IPCTNL_MSG_CT_GET_STATS_CPU] = { + .call = ctnetlink_stat_ct_cpu, + .type = NFNL_CB_MUTEX, + }, + [IPCTNL_MSG_CT_GET_STATS] = { + .call = ctnetlink_stat_ct, + .type = NFNL_CB_MUTEX, + }, + [IPCTNL_MSG_CT_GET_DYING] = { + .call = ctnetlink_get_ct_dying, + .type = NFNL_CB_MUTEX, + }, + [IPCTNL_MSG_CT_GET_UNCONFIRMED] = { + .call = ctnetlink_get_ct_unconfirmed, + .type = NFNL_CB_MUTEX, + }, +}; + +static const struct nfnl_callback ctnl_exp_cb[IPCTNL_MSG_EXP_MAX] = { + [IPCTNL_MSG_EXP_GET] = { + .call = 
ctnetlink_get_expect, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_EXPECT_MAX, + .policy = exp_nla_policy + }, + [IPCTNL_MSG_EXP_NEW] = { + .call = ctnetlink_new_expect, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_EXPECT_MAX, + .policy = exp_nla_policy + }, + [IPCTNL_MSG_EXP_DELETE] = { + .call = ctnetlink_del_expect, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_EXPECT_MAX, + .policy = exp_nla_policy + }, + [IPCTNL_MSG_EXP_GET_STATS_CPU] = { + .call = ctnetlink_stat_exp_cpu, + .type = NFNL_CB_MUTEX, + }, +}; + +static const struct nfnetlink_subsystem ctnl_subsys = { + .name = "conntrack", + .subsys_id = NFNL_SUBSYS_CTNETLINK, + .cb_count = IPCTNL_MSG_MAX, + .cb = ctnl_cb, +}; + +static const struct nfnetlink_subsystem ctnl_exp_subsys = { + .name = "conntrack_expect", + .subsys_id = NFNL_SUBSYS_CTNETLINK_EXP, + .cb_count = IPCTNL_MSG_EXP_MAX, + .cb = ctnl_exp_cb, +}; + +MODULE_ALIAS("ip_conntrack_netlink"); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_CTNETLINK); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_CTNETLINK_EXP); + +static int __net_init ctnetlink_net_init(struct net *net) +{ +#ifdef CONFIG_NF_CONNTRACK_EVENTS + nf_conntrack_register_notifier(net, &ctnl_notifier); +#endif + return 0; +} + +static void ctnetlink_net_pre_exit(struct net *net) +{ +#ifdef CONFIG_NF_CONNTRACK_EVENTS + nf_conntrack_unregister_notifier(net); +#endif +} + +static struct pernet_operations ctnetlink_net_ops = { + .init = ctnetlink_net_init, + .pre_exit = ctnetlink_net_pre_exit, +}; + +static int __init ctnetlink_init(void) +{ + int ret; + + BUILD_BUG_ON(sizeof(struct ctnetlink_list_dump_ctx) > sizeof_field(struct netlink_callback, ctx)); + + ret = nfnetlink_subsys_register(&ctnl_subsys); + if (ret < 0) { + pr_err("ctnetlink_init: cannot register with nfnetlink.\n"); + goto err_out; + } + + ret = nfnetlink_subsys_register(&ctnl_exp_subsys); + if (ret < 0) { + pr_err("ctnetlink_init: cannot register exp with nfnetlink.\n"); + goto err_unreg_subsys; + } + + ret = register_pernet_subsys(&ctnetlink_net_ops); + if (ret < 0) { + pr_err("ctnetlink_init: cannot register pernet operations\n"); + goto err_unreg_exp_subsys; + } +#ifdef CONFIG_NETFILTER_NETLINK_GLUE_CT + /* setup interaction between nf_queue and nf_conntrack_netlink. */ + RCU_INIT_POINTER(nfnl_ct_hook, &ctnetlink_glue_hook); +#endif + return 0; + +err_unreg_exp_subsys: + nfnetlink_subsys_unregister(&ctnl_exp_subsys); +err_unreg_subsys: + nfnetlink_subsys_unregister(&ctnl_subsys); +err_out: + return ret; +} + +static void __exit ctnetlink_exit(void) +{ + unregister_pernet_subsys(&ctnetlink_net_ops); + nfnetlink_subsys_unregister(&ctnl_exp_subsys); + nfnetlink_subsys_unregister(&ctnl_subsys); +#ifdef CONFIG_NETFILTER_NETLINK_GLUE_CT + RCU_INIT_POINTER(nfnl_ct_hook, NULL); +#endif + synchronize_rcu(); +} + +module_init(ctnetlink_init); +module_exit(ctnetlink_exit); diff --git a/net/netfilter/nf_conntrack_pptp.c b/net/netfilter/nf_conntrack_pptp.c new file mode 100644 index 000000000..4c679638d --- /dev/null +++ b/net/netfilter/nf_conntrack_pptp.c @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Connection tracking support for PPTP (Point to Point Tunneling Protocol). + * PPTP is a protocol for creating virtual private networks. + * It is a specification defined by Microsoft and some vendors + * working with Microsoft. PPTP is built on top of a modified + * version of the Internet Generic Routing Encapsulation Protocol. + * GRE is defined in RFC 1701 and RFC 1702. 
Documentation of + * PPTP can be found in RFC 2637 + * + * (C) 2000-2005 by Harald Welte + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + * + * (C) 2006-2012 Patrick McHardy + * + * Limitations: + * - We blindly assume that control connections are always + * established in PNS->PAC direction. This is a violation + * of RFC 2637 + * - We can only support one single call within each session + * TODO: + * - testing of incoming PPTP calls + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define NF_CT_PPTP_VERSION "3.1" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Netfilter connection tracking helper module for PPTP"); +MODULE_ALIAS("ip_conntrack_pptp"); +MODULE_ALIAS_NFCT_HELPER("pptp"); + +static DEFINE_SPINLOCK(nf_pptp_lock); + +const struct nf_nat_pptp_hook __rcu *nf_nat_pptp_hook; +EXPORT_SYMBOL_GPL(nf_nat_pptp_hook); + +#if defined(DEBUG) || defined(CONFIG_DYNAMIC_DEBUG) +/* PptpControlMessageType names */ +static const char *const pptp_msg_name_array[PPTP_MSG_MAX + 1] = { + [0] = "UNKNOWN_MESSAGE", + [PPTP_START_SESSION_REQUEST] = "START_SESSION_REQUEST", + [PPTP_START_SESSION_REPLY] = "START_SESSION_REPLY", + [PPTP_STOP_SESSION_REQUEST] = "STOP_SESSION_REQUEST", + [PPTP_STOP_SESSION_REPLY] = "STOP_SESSION_REPLY", + [PPTP_ECHO_REQUEST] = "ECHO_REQUEST", + [PPTP_ECHO_REPLY] = "ECHO_REPLY", + [PPTP_OUT_CALL_REQUEST] = "OUT_CALL_REQUEST", + [PPTP_OUT_CALL_REPLY] = "OUT_CALL_REPLY", + [PPTP_IN_CALL_REQUEST] = "IN_CALL_REQUEST", + [PPTP_IN_CALL_REPLY] = "IN_CALL_REPLY", + [PPTP_IN_CALL_CONNECT] = "IN_CALL_CONNECT", + [PPTP_CALL_CLEAR_REQUEST] = "CALL_CLEAR_REQUEST", + [PPTP_CALL_DISCONNECT_NOTIFY] = "CALL_DISCONNECT_NOTIFY", + [PPTP_WAN_ERROR_NOTIFY] = "WAN_ERROR_NOTIFY", + [PPTP_SET_LINK_INFO] = "SET_LINK_INFO" +}; + +const char *pptp_msg_name(u_int16_t msg) +{ + if (msg > PPTP_MSG_MAX) + return pptp_msg_name_array[0]; + + return pptp_msg_name_array[msg]; +} +EXPORT_SYMBOL(pptp_msg_name); +#endif + +#define SECS *HZ +#define MINS * 60 SECS +#define HOURS * 60 MINS + +#define PPTP_GRE_TIMEOUT (10 MINS) +#define PPTP_GRE_STREAM_TIMEOUT (5 HOURS) + +static void pptp_expectfn(struct nf_conn *ct, + struct nf_conntrack_expect *exp) +{ + const struct nf_nat_pptp_hook *hook; + struct net *net = nf_ct_net(ct); + pr_debug("increasing timeouts\n"); + + /* increase timeout of GRE data channel conntrack entry */ + ct->proto.gre.timeout = PPTP_GRE_TIMEOUT; + ct->proto.gre.stream_timeout = PPTP_GRE_STREAM_TIMEOUT; + + /* Can you see how rusty this code is, compared with the pre-2.6.11 + * one? That's what happened to my shiny newnat of 2002 ;( -HW */ + + hook = rcu_dereference(nf_nat_pptp_hook); + if (hook && ct->master->status & IPS_NAT_MASK) + hook->expectfn(ct, exp); + else { + struct nf_conntrack_tuple inv_t; + struct nf_conntrack_expect *exp_other; + + /* obviously this tuple inversion only works until you do NAT */ + nf_ct_invert_tuple(&inv_t, &exp->tuple); + pr_debug("trying to unexpect other dir: "); + nf_ct_dump_tuple(&inv_t); + + exp_other = nf_ct_expect_find_get(net, nf_ct_zone(ct), &inv_t); + if (exp_other) { + /* delete other expectation. 
*/ + pr_debug("found\n"); + nf_ct_unexpect_related(exp_other); + nf_ct_expect_put(exp_other); + } else { + pr_debug("not found\n"); + } + } +} + +static int destroy_sibling_or_exp(struct net *net, struct nf_conn *ct, + const struct nf_conntrack_tuple *t) +{ + const struct nf_conntrack_tuple_hash *h; + const struct nf_conntrack_zone *zone; + struct nf_conntrack_expect *exp; + struct nf_conn *sibling; + + pr_debug("trying to timeout ct or exp for tuple "); + nf_ct_dump_tuple(t); + + zone = nf_ct_zone(ct); + h = nf_conntrack_find_get(net, zone, t); + if (h) { + sibling = nf_ct_tuplehash_to_ctrack(h); + pr_debug("setting timeout of conntrack %p to 0\n", sibling); + sibling->proto.gre.timeout = 0; + sibling->proto.gre.stream_timeout = 0; + nf_ct_kill(sibling); + nf_ct_put(sibling); + return 1; + } else { + exp = nf_ct_expect_find_get(net, zone, t); + if (exp) { + pr_debug("unexpect_related of expect %p\n", exp); + nf_ct_unexpect_related(exp); + nf_ct_expect_put(exp); + return 1; + } + } + return 0; +} + +/* timeout GRE data connections */ +static void pptp_destroy_siblings(struct nf_conn *ct) +{ + struct net *net = nf_ct_net(ct); + const struct nf_ct_pptp_master *ct_pptp_info = nfct_help_data(ct); + struct nf_conntrack_tuple t; + + nf_ct_gre_keymap_destroy(ct); + + /* try original (pns->pac) tuple */ + memcpy(&t, &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, sizeof(t)); + t.dst.protonum = IPPROTO_GRE; + t.src.u.gre.key = ct_pptp_info->pns_call_id; + t.dst.u.gre.key = ct_pptp_info->pac_call_id; + if (!destroy_sibling_or_exp(net, ct, &t)) + pr_debug("failed to timeout original pns->pac ct/exp\n"); + + /* try reply (pac->pns) tuple */ + memcpy(&t, &ct->tuplehash[IP_CT_DIR_REPLY].tuple, sizeof(t)); + t.dst.protonum = IPPROTO_GRE; + t.src.u.gre.key = ct_pptp_info->pac_call_id; + t.dst.u.gre.key = ct_pptp_info->pns_call_id; + if (!destroy_sibling_or_exp(net, ct, &t)) + pr_debug("failed to timeout reply pac->pns ct/exp\n"); +} + +/* expect GRE connections (PNS->PAC and PAC->PNS direction) */ +static int exp_gre(struct nf_conn *ct, __be16 callid, __be16 peer_callid) +{ + struct nf_conntrack_expect *exp_orig, *exp_reply; + const struct nf_nat_pptp_hook *hook; + enum ip_conntrack_dir dir; + int ret = 1; + + exp_orig = nf_ct_expect_alloc(ct); + if (exp_orig == NULL) + goto out; + + exp_reply = nf_ct_expect_alloc(ct); + if (exp_reply == NULL) + goto out_put_orig; + + /* original direction, PNS->PAC */ + dir = IP_CT_DIR_ORIGINAL; + nf_ct_expect_init(exp_orig, NF_CT_EXPECT_CLASS_DEFAULT, + nf_ct_l3num(ct), + &ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[dir].tuple.dst.u3, + IPPROTO_GRE, &peer_callid, &callid); + exp_orig->expectfn = pptp_expectfn; + + /* reply direction, PAC->PNS */ + dir = IP_CT_DIR_REPLY; + nf_ct_expect_init(exp_reply, NF_CT_EXPECT_CLASS_DEFAULT, + nf_ct_l3num(ct), + &ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[dir].tuple.dst.u3, + IPPROTO_GRE, &callid, &peer_callid); + exp_reply->expectfn = pptp_expectfn; + + hook = rcu_dereference(nf_nat_pptp_hook); + if (hook && ct->status & IPS_NAT_MASK) + hook->exp_gre(exp_orig, exp_reply); + if (nf_ct_expect_related(exp_orig, 0) != 0) + goto out_put_both; + if (nf_ct_expect_related(exp_reply, 0) != 0) + goto out_unexpect_orig; + + /* Add GRE keymap entries */ + if (nf_ct_gre_keymap_add(ct, IP_CT_DIR_ORIGINAL, &exp_orig->tuple) != 0) + goto out_unexpect_both; + if (nf_ct_gre_keymap_add(ct, IP_CT_DIR_REPLY, &exp_reply->tuple) != 0) { + nf_ct_gre_keymap_destroy(ct); + goto out_unexpect_both; + } + ret = 0; + +out_put_both: + 
nf_ct_expect_put(exp_reply); +out_put_orig: + nf_ct_expect_put(exp_orig); +out: + return ret; + +out_unexpect_both: + nf_ct_unexpect_related(exp_reply); +out_unexpect_orig: + nf_ct_unexpect_related(exp_orig); + goto out_put_both; +} + +static int +pptp_inbound_pkt(struct sk_buff *skb, unsigned int protoff, + struct PptpControlHeader *ctlh, + union pptp_ctrl_union *pptpReq, + unsigned int reqlen, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + struct nf_ct_pptp_master *info = nfct_help_data(ct); + const struct nf_nat_pptp_hook *hook; + u_int16_t msg; + __be16 cid = 0, pcid = 0; + + msg = ntohs(ctlh->messageType); + pr_debug("inbound control message %s\n", pptp_msg_name(msg)); + + switch (msg) { + case PPTP_START_SESSION_REPLY: + /* server confirms new control session */ + if (info->sstate < PPTP_SESSION_REQUESTED) + goto invalid; + if (pptpReq->srep.resultCode == PPTP_START_OK) + info->sstate = PPTP_SESSION_CONFIRMED; + else + info->sstate = PPTP_SESSION_ERROR; + break; + + case PPTP_STOP_SESSION_REPLY: + /* server confirms end of control session */ + if (info->sstate > PPTP_SESSION_STOPREQ) + goto invalid; + if (pptpReq->strep.resultCode == PPTP_STOP_OK) + info->sstate = PPTP_SESSION_NONE; + else + info->sstate = PPTP_SESSION_ERROR; + break; + + case PPTP_OUT_CALL_REPLY: + /* server accepted call, we now expect GRE frames */ + if (info->sstate != PPTP_SESSION_CONFIRMED) + goto invalid; + if (info->cstate != PPTP_CALL_OUT_REQ && + info->cstate != PPTP_CALL_OUT_CONF) + goto invalid; + + cid = pptpReq->ocack.callID; + pcid = pptpReq->ocack.peersCallID; + if (info->pns_call_id != pcid) + goto invalid; + pr_debug("%s, CID=%X, PCID=%X\n", pptp_msg_name(msg), + ntohs(cid), ntohs(pcid)); + + if (pptpReq->ocack.resultCode == PPTP_OUTCALL_CONNECT) { + info->cstate = PPTP_CALL_OUT_CONF; + info->pac_call_id = cid; + exp_gre(ct, cid, pcid); + } else + info->cstate = PPTP_CALL_NONE; + break; + + case PPTP_IN_CALL_REQUEST: + /* server tells us about incoming call request */ + if (info->sstate != PPTP_SESSION_CONFIRMED) + goto invalid; + + cid = pptpReq->icreq.callID; + pr_debug("%s, CID=%X\n", pptp_msg_name(msg), ntohs(cid)); + info->cstate = PPTP_CALL_IN_REQ; + info->pac_call_id = cid; + break; + + case PPTP_IN_CALL_CONNECT: + /* server tells us about incoming call established */ + if (info->sstate != PPTP_SESSION_CONFIRMED) + goto invalid; + if (info->cstate != PPTP_CALL_IN_REP && + info->cstate != PPTP_CALL_IN_CONF) + goto invalid; + + pcid = pptpReq->iccon.peersCallID; + cid = info->pac_call_id; + + if (info->pns_call_id != pcid) + goto invalid; + + pr_debug("%s, PCID=%X\n", pptp_msg_name(msg), ntohs(pcid)); + info->cstate = PPTP_CALL_IN_CONF; + + /* we expect a GRE connection from PAC to PNS */ + exp_gre(ct, cid, pcid); + break; + + case PPTP_CALL_DISCONNECT_NOTIFY: + /* server confirms disconnect */ + cid = pptpReq->disc.callID; + pr_debug("%s, CID=%X\n", pptp_msg_name(msg), ntohs(cid)); + info->cstate = PPTP_CALL_NONE; + + /* untrack this call id, unexpect GRE packets */ + pptp_destroy_siblings(ct); + break; + + case PPTP_WAN_ERROR_NOTIFY: + case PPTP_SET_LINK_INFO: + case PPTP_ECHO_REQUEST: + case PPTP_ECHO_REPLY: + /* I don't have to explain these ;) */ + break; + + default: + goto invalid; + } + + hook = rcu_dereference(nf_nat_pptp_hook); + if (hook && ct->status & IPS_NAT_MASK) + return hook->inbound(skb, ct, ctinfo, protoff, ctlh, pptpReq); + return NF_ACCEPT; + +invalid: + pr_debug("invalid %s: type=%d cid=%u pcid=%u " + "cstate=%d sstate=%d pns_cid=%u pac_cid=%u\n", + 
pptp_msg_name(msg), + msg, ntohs(cid), ntohs(pcid), info->cstate, info->sstate, + ntohs(info->pns_call_id), ntohs(info->pac_call_id)); + return NF_ACCEPT; +} + +static int +pptp_outbound_pkt(struct sk_buff *skb, unsigned int protoff, + struct PptpControlHeader *ctlh, + union pptp_ctrl_union *pptpReq, + unsigned int reqlen, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + struct nf_ct_pptp_master *info = nfct_help_data(ct); + const struct nf_nat_pptp_hook *hook; + u_int16_t msg; + __be16 cid = 0, pcid = 0; + + msg = ntohs(ctlh->messageType); + pr_debug("outbound control message %s\n", pptp_msg_name(msg)); + + switch (msg) { + case PPTP_START_SESSION_REQUEST: + /* client requests for new control session */ + if (info->sstate != PPTP_SESSION_NONE) + goto invalid; + info->sstate = PPTP_SESSION_REQUESTED; + break; + + case PPTP_STOP_SESSION_REQUEST: + /* client requests end of control session */ + info->sstate = PPTP_SESSION_STOPREQ; + break; + + case PPTP_OUT_CALL_REQUEST: + /* client initiating connection to server */ + if (info->sstate != PPTP_SESSION_CONFIRMED) + goto invalid; + info->cstate = PPTP_CALL_OUT_REQ; + /* track PNS call id */ + cid = pptpReq->ocreq.callID; + pr_debug("%s, CID=%X\n", pptp_msg_name(msg), ntohs(cid)); + info->pns_call_id = cid; + break; + + case PPTP_IN_CALL_REPLY: + /* client answers incoming call */ + if (info->cstate != PPTP_CALL_IN_REQ && + info->cstate != PPTP_CALL_IN_REP) + goto invalid; + + cid = pptpReq->icack.callID; + pcid = pptpReq->icack.peersCallID; + if (info->pac_call_id != pcid) + goto invalid; + pr_debug("%s, CID=%X PCID=%X\n", pptp_msg_name(msg), + ntohs(cid), ntohs(pcid)); + + if (pptpReq->icack.resultCode == PPTP_INCALL_ACCEPT) { + /* part two of the three-way handshake */ + info->cstate = PPTP_CALL_IN_REP; + info->pns_call_id = cid; + } else + info->cstate = PPTP_CALL_NONE; + break; + + case PPTP_CALL_CLEAR_REQUEST: + /* client requests hangup of call */ + if (info->sstate != PPTP_SESSION_CONFIRMED) + goto invalid; + /* FUTURE: iterate over all calls and check if + * call ID is valid. 
We don't do this without newnat, + * because we only know about last call */ + info->cstate = PPTP_CALL_CLEAR_REQ; + break; + + case PPTP_SET_LINK_INFO: + case PPTP_ECHO_REQUEST: + case PPTP_ECHO_REPLY: + /* I don't have to explain these ;) */ + break; + + default: + goto invalid; + } + + hook = rcu_dereference(nf_nat_pptp_hook); + if (hook && ct->status & IPS_NAT_MASK) + return hook->outbound(skb, ct, ctinfo, protoff, ctlh, pptpReq); + return NF_ACCEPT; + +invalid: + pr_debug("invalid %s: type=%d cid=%u pcid=%u " + "cstate=%d sstate=%d pns_cid=%u pac_cid=%u\n", + pptp_msg_name(msg), + msg, ntohs(cid), ntohs(pcid), info->cstate, info->sstate, + ntohs(info->pns_call_id), ntohs(info->pac_call_id)); + return NF_ACCEPT; +} + +static const unsigned int pptp_msg_size[] = { + [PPTP_START_SESSION_REQUEST] = sizeof(struct PptpStartSessionRequest), + [PPTP_START_SESSION_REPLY] = sizeof(struct PptpStartSessionReply), + [PPTP_STOP_SESSION_REQUEST] = sizeof(struct PptpStopSessionRequest), + [PPTP_STOP_SESSION_REPLY] = sizeof(struct PptpStopSessionReply), + [PPTP_OUT_CALL_REQUEST] = sizeof(struct PptpOutCallRequest), + [PPTP_OUT_CALL_REPLY] = sizeof(struct PptpOutCallReply), + [PPTP_IN_CALL_REQUEST] = sizeof(struct PptpInCallRequest), + [PPTP_IN_CALL_REPLY] = sizeof(struct PptpInCallReply), + [PPTP_IN_CALL_CONNECT] = sizeof(struct PptpInCallConnected), + [PPTP_CALL_CLEAR_REQUEST] = sizeof(struct PptpClearCallRequest), + [PPTP_CALL_DISCONNECT_NOTIFY] = sizeof(struct PptpCallDisconnectNotify), + [PPTP_WAN_ERROR_NOTIFY] = sizeof(struct PptpWanErrorNotify), + [PPTP_SET_LINK_INFO] = sizeof(struct PptpSetLinkInfo), +}; + +/* track caller id inside control connection, call expect_related */ +static int +conntrack_pptp_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) + +{ + int dir = CTINFO2DIR(ctinfo); + const struct nf_ct_pptp_master *info = nfct_help_data(ct); + const struct tcphdr *tcph; + struct tcphdr _tcph; + const struct pptp_pkt_hdr *pptph; + struct pptp_pkt_hdr _pptph; + struct PptpControlHeader _ctlh, *ctlh; + union pptp_ctrl_union _pptpReq, *pptpReq; + unsigned int tcplen = skb->len - protoff; + unsigned int datalen, reqlen, nexthdr_off; + int oldsstate, oldcstate; + int ret; + u_int16_t msg; + +#if IS_ENABLED(CONFIG_NF_NAT) + if (!nf_ct_is_confirmed(ct) && (ct->status & IPS_NAT_MASK)) { + struct nf_conn_nat *nat = nf_ct_ext_find(ct, NF_CT_EXT_NAT); + + if (!nat && !nf_ct_ext_add(ct, NF_CT_EXT_NAT, GFP_ATOMIC)) + return NF_DROP; + } +#endif + /* don't do any tracking before tcp handshake complete */ + if (ctinfo != IP_CT_ESTABLISHED && ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + nexthdr_off = protoff; + tcph = skb_header_pointer(skb, nexthdr_off, sizeof(_tcph), &_tcph); + if (!tcph) + return NF_ACCEPT; + + nexthdr_off += tcph->doff * 4; + datalen = tcplen - tcph->doff * 4; + + pptph = skb_header_pointer(skb, nexthdr_off, sizeof(_pptph), &_pptph); + if (!pptph) { + pr_debug("no full PPTP header, can't track\n"); + return NF_ACCEPT; + } + nexthdr_off += sizeof(_pptph); + datalen -= sizeof(_pptph); + + /* if it's not a control message we can't do anything with it */ + if (ntohs(pptph->packetType) != PPTP_PACKET_CONTROL || + ntohl(pptph->magicCookie) != PPTP_MAGIC_COOKIE) { + pr_debug("not a control packet\n"); + return NF_ACCEPT; + } + + ctlh = skb_header_pointer(skb, nexthdr_off, sizeof(_ctlh), &_ctlh); + if (!ctlh) + return NF_ACCEPT; + nexthdr_off += sizeof(_ctlh); + datalen -= sizeof(_ctlh); + + reqlen = datalen; + msg = 
ntohs(ctlh->messageType); + if (msg > 0 && msg <= PPTP_MSG_MAX && reqlen < pptp_msg_size[msg]) + return NF_ACCEPT; + if (reqlen > sizeof(*pptpReq)) + reqlen = sizeof(*pptpReq); + + pptpReq = skb_header_pointer(skb, nexthdr_off, reqlen, &_pptpReq); + if (!pptpReq) + return NF_ACCEPT; + + oldsstate = info->sstate; + oldcstate = info->cstate; + + spin_lock_bh(&nf_pptp_lock); + + /* FIXME: We just blindly assume that the control connection is always + * established from PNS->PAC. However, RFC makes no guarantee */ + if (dir == IP_CT_DIR_ORIGINAL) + /* client -> server (PNS -> PAC) */ + ret = pptp_outbound_pkt(skb, protoff, ctlh, pptpReq, reqlen, ct, + ctinfo); + else + /* server -> client (PAC -> PNS) */ + ret = pptp_inbound_pkt(skb, protoff, ctlh, pptpReq, reqlen, ct, + ctinfo); + pr_debug("sstate: %d->%d, cstate: %d->%d\n", + oldsstate, info->sstate, oldcstate, info->cstate); + spin_unlock_bh(&nf_pptp_lock); + + return ret; +} + +static const struct nf_conntrack_expect_policy pptp_exp_policy = { + .max_expected = 2, + .timeout = 5 * 60, +}; + +/* control protocol helper */ +static struct nf_conntrack_helper pptp __read_mostly = { + .name = "pptp", + .me = THIS_MODULE, + .tuple.src.l3num = AF_INET, + .tuple.src.u.tcp.port = cpu_to_be16(PPTP_CONTROL_PORT), + .tuple.dst.protonum = IPPROTO_TCP, + .help = conntrack_pptp_help, + .destroy = pptp_destroy_siblings, + .expect_policy = &pptp_exp_policy, +}; + +static int __init nf_conntrack_pptp_init(void) +{ + NF_CT_HELPER_BUILD_BUG_ON(sizeof(struct nf_ct_pptp_master)); + + return nf_conntrack_helper_register(&pptp); +} + +static void __exit nf_conntrack_pptp_fini(void) +{ + nf_conntrack_helper_unregister(&pptp); +} + +module_init(nf_conntrack_pptp_init); +module_exit(nf_conntrack_pptp_fini); diff --git a/net/netfilter/nf_conntrack_proto.c b/net/netfilter/nf_conntrack_proto.c new file mode 100644 index 000000000..895b09cbd --- /dev/null +++ b/net/netfilter/nf_conntrack_proto.c @@ -0,0 +1,726 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static DEFINE_MUTEX(nf_ct_proto_mutex); + +#ifdef CONFIG_SYSCTL +__printf(4, 5) +void nf_l4proto_log_invalid(const struct sk_buff *skb, + const struct nf_hook_state *state, + u8 protonum, + const char *fmt, ...) +{ + struct net *net = state->net; + struct va_format vaf; + va_list args; + + if (net->ct.sysctl_log_invalid != protonum && + net->ct.sysctl_log_invalid != IPPROTO_RAW) + return; + + va_start(args, fmt); + vaf.fmt = fmt; + vaf.va = &args; + + nf_log_packet(net, state->pf, 0, skb, state->in, state->out, + NULL, "nf_ct_proto_%d: %pV ", protonum, &vaf); + va_end(args); +} +EXPORT_SYMBOL_GPL(nf_l4proto_log_invalid); + +__printf(4, 5) +void nf_ct_l4proto_log_invalid(const struct sk_buff *skb, + const struct nf_conn *ct, + const struct nf_hook_state *state, + const char *fmt, ...) 
+{ + struct va_format vaf; + struct net *net; + va_list args; + + net = nf_ct_net(ct); + if (likely(net->ct.sysctl_log_invalid == 0)) + return; + + va_start(args, fmt); + vaf.fmt = fmt; + vaf.va = &args; + + nf_l4proto_log_invalid(skb, state, + nf_ct_protonum(ct), "%pV", &vaf); + va_end(args); +} +EXPORT_SYMBOL_GPL(nf_ct_l4proto_log_invalid); +#endif + +const struct nf_conntrack_l4proto *nf_ct_l4proto_find(u8 l4proto) +{ + switch (l4proto) { + case IPPROTO_UDP: return &nf_conntrack_l4proto_udp; + case IPPROTO_TCP: return &nf_conntrack_l4proto_tcp; + case IPPROTO_ICMP: return &nf_conntrack_l4proto_icmp; +#ifdef CONFIG_NF_CT_PROTO_DCCP + case IPPROTO_DCCP: return &nf_conntrack_l4proto_dccp; +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + case IPPROTO_SCTP: return &nf_conntrack_l4proto_sctp; +#endif +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + case IPPROTO_UDPLITE: return &nf_conntrack_l4proto_udplite; +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: return &nf_conntrack_l4proto_gre; +#endif +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_ICMPV6: return &nf_conntrack_l4proto_icmpv6; +#endif /* CONFIG_IPV6 */ + } + + return &nf_conntrack_l4proto_generic; +}; +EXPORT_SYMBOL_GPL(nf_ct_l4proto_find); + +unsigned int nf_confirm(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + const struct nf_conn_help *help; + + help = nfct_help(ct); + if (help) { + const struct nf_conntrack_helper *helper; + int ret; + + /* rcu_read_lock()ed by nf_hook_thresh */ + helper = rcu_dereference(help->helper); + if (helper) { + ret = helper->help(skb, + protoff, + ct, ctinfo); + if (ret != NF_ACCEPT) + return ret; + } + } + + if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && + !nf_is_loopback_packet(skb)) { + if (!nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) { + NF_CT_STAT_INC_ATOMIC(nf_ct_net(ct), drop); + return NF_DROP; + } + } + + /* We've seen it coming out the other side: confirm it */ + return nf_conntrack_confirm(skb); +} +EXPORT_SYMBOL_GPL(nf_confirm); + +static bool in_vrf_postrouting(const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (state->hook == NF_INET_POST_ROUTING && + netif_is_l3_master(state->out)) + return true; +#endif + return false; +} + +static unsigned int ipv4_confirm(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || ctinfo == IP_CT_RELATED_REPLY) + return nf_conntrack_confirm(skb); + + if (in_vrf_postrouting(state)) + return NF_ACCEPT; + + return nf_confirm(skb, + skb_network_offset(skb) + ip_hdrlen(skb), + ct, ctinfo); +} + +static unsigned int ipv4_conntrack_in(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + return nf_conntrack_in(skb, state); +} + +static unsigned int ipv4_conntrack_local(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + if (ip_is_fragment(ip_hdr(skb))) { /* IP_NODEFRAG setsockopt set */ + enum ip_conntrack_info ctinfo; + struct nf_conn *tmpl; + + tmpl = nf_ct_get(skb, &ctinfo); + if (tmpl && nf_ct_is_template(tmpl)) { + /* when skipping ct, clear templates to avoid fooling + * later targets/matches + */ + skb->_nfct = 0; + nf_ct_put(tmpl); + } + return NF_ACCEPT; + } + + return nf_conntrack_in(skb, state); +} + +/* Connection tracking may drop packets, but never alters them, so + * make it the first hook. 
+ */ +static const struct nf_hook_ops ipv4_conntrack_ops[] = { + { + .hook = ipv4_conntrack_in, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_CONNTRACK, + }, + { + .hook = ipv4_conntrack_local, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_CONNTRACK, + }, + { + .hook = ipv4_confirm, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM, + }, + { + .hook = ipv4_confirm, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM, + }, +}; + +/* Fast function for those who don't want to parse /proc (and I don't + * blame them). + * Reversing the socket's dst/src point of view gives us the reply + * mapping. + */ +static int +getorigdst(struct sock *sk, int optval, void __user *user, int *len) +{ + const struct inet_sock *inet = inet_sk(sk); + const struct nf_conntrack_tuple_hash *h; + struct nf_conntrack_tuple tuple; + + memset(&tuple, 0, sizeof(tuple)); + + lock_sock(sk); + tuple.src.u3.ip = inet->inet_rcv_saddr; + tuple.src.u.tcp.port = inet->inet_sport; + tuple.dst.u3.ip = inet->inet_daddr; + tuple.dst.u.tcp.port = inet->inet_dport; + tuple.src.l3num = PF_INET; + tuple.dst.protonum = sk->sk_protocol; + release_sock(sk); + + /* We only do TCP and SCTP at the moment: is there a better way? */ + if (tuple.dst.protonum != IPPROTO_TCP && + tuple.dst.protonum != IPPROTO_SCTP) { + pr_debug("SO_ORIGINAL_DST: Not a TCP/SCTP socket\n"); + return -ENOPROTOOPT; + } + + if ((unsigned int)*len < sizeof(struct sockaddr_in)) { + pr_debug("SO_ORIGINAL_DST: len %d not %zu\n", + *len, sizeof(struct sockaddr_in)); + return -EINVAL; + } + + h = nf_conntrack_find_get(sock_net(sk), &nf_ct_zone_dflt, &tuple); + if (h) { + struct sockaddr_in sin; + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + sin.sin_family = AF_INET; + sin.sin_port = ct->tuplehash[IP_CT_DIR_ORIGINAL] + .tuple.dst.u.tcp.port; + sin.sin_addr.s_addr = ct->tuplehash[IP_CT_DIR_ORIGINAL] + .tuple.dst.u3.ip; + memset(sin.sin_zero, 0, sizeof(sin.sin_zero)); + + pr_debug("SO_ORIGINAL_DST: %pI4 %u\n", + &sin.sin_addr.s_addr, ntohs(sin.sin_port)); + nf_ct_put(ct); + if (copy_to_user(user, &sin, sizeof(sin)) != 0) + return -EFAULT; + else + return 0; + } + pr_debug("SO_ORIGINAL_DST: Can't find %pI4/%u-%pI4/%u.\n", + &tuple.src.u3.ip, ntohs(tuple.src.u.tcp.port), + &tuple.dst.u3.ip, ntohs(tuple.dst.u.tcp.port)); + return -ENOENT; +} + +static struct nf_sockopt_ops so_getorigdst = { + .pf = PF_INET, + .get_optmin = SO_ORIGINAL_DST, + .get_optmax = SO_ORIGINAL_DST + 1, + .get = getorigdst, + .owner = THIS_MODULE, +}; + +#if IS_ENABLED(CONFIG_IPV6) +static int +ipv6_getorigdst(struct sock *sk, int optval, void __user *user, int *len) +{ + struct nf_conntrack_tuple tuple = { .src.l3num = NFPROTO_IPV6 }; + const struct ipv6_pinfo *inet6 = inet6_sk(sk); + const struct inet_sock *inet = inet_sk(sk); + const struct nf_conntrack_tuple_hash *h; + struct sockaddr_in6 sin6; + struct nf_conn *ct; + __be32 flow_label; + int bound_dev_if; + + lock_sock(sk); + tuple.src.u3.in6 = sk->sk_v6_rcv_saddr; + tuple.src.u.tcp.port = inet->inet_sport; + tuple.dst.u3.in6 = sk->sk_v6_daddr; + tuple.dst.u.tcp.port = inet->inet_dport; + tuple.dst.protonum = sk->sk_protocol; + bound_dev_if = sk->sk_bound_dev_if; + flow_label = inet6->flow_label; + release_sock(sk); + + if (tuple.dst.protonum != IPPROTO_TCP && + tuple.dst.protonum != IPPROTO_SCTP) + return -ENOPROTOOPT; + + if (*len < 0 || (unsigned int)*len < sizeof(sin6)) + return 
-EINVAL; + + h = nf_conntrack_find_get(sock_net(sk), &nf_ct_zone_dflt, &tuple); + if (!h) { + pr_debug("IP6T_SO_ORIGINAL_DST: Can't find %pI6c/%u-%pI6c/%u.\n", + &tuple.src.u3.ip6, ntohs(tuple.src.u.tcp.port), + &tuple.dst.u3.ip6, ntohs(tuple.dst.u.tcp.port)); + return -ENOENT; + } + + ct = nf_ct_tuplehash_to_ctrack(h); + + sin6.sin6_family = AF_INET6; + sin6.sin6_port = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u.tcp.port; + sin6.sin6_flowinfo = flow_label & IPV6_FLOWINFO_MASK; + memcpy(&sin6.sin6_addr, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3.in6, + sizeof(sin6.sin6_addr)); + + nf_ct_put(ct); + sin6.sin6_scope_id = ipv6_iface_scope_id(&sin6.sin6_addr, bound_dev_if); + return copy_to_user(user, &sin6, sizeof(sin6)) ? -EFAULT : 0; +} + +static struct nf_sockopt_ops so_getorigdst6 = { + .pf = NFPROTO_IPV6, + .get_optmin = IP6T_SO_ORIGINAL_DST, + .get_optmax = IP6T_SO_ORIGINAL_DST + 1, + .get = ipv6_getorigdst, + .owner = THIS_MODULE, +}; + +static unsigned int ipv6_confirm(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned char pnum = ipv6_hdr(skb)->nexthdr; + __be16 frag_off; + int protoff; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || ctinfo == IP_CT_RELATED_REPLY) + return nf_conntrack_confirm(skb); + + if (in_vrf_postrouting(state)) + return NF_ACCEPT; + + protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &pnum, + &frag_off); + if (protoff < 0 || (frag_off & htons(~0x7)) != 0) { + pr_debug("proto header not found\n"); + return nf_conntrack_confirm(skb); + } + + return nf_confirm(skb, protoff, ct, ctinfo); +} + +static unsigned int ipv6_conntrack_in(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + return nf_conntrack_in(skb, state); +} + +static unsigned int ipv6_conntrack_local(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + return nf_conntrack_in(skb, state); +} + +static const struct nf_hook_ops ipv6_conntrack_ops[] = { + { + .hook = ipv6_conntrack_in, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP6_PRI_CONNTRACK, + }, + { + .hook = ipv6_conntrack_local, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_CONNTRACK, + }, + { + .hook = ipv6_confirm, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP6_PRI_LAST, + }, + { + .hook = ipv6_confirm, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_LAST - 1, + }, +}; +#endif + +static int nf_ct_tcp_fixup(struct nf_conn *ct, void *_nfproto) +{ + u8 nfproto = (unsigned long)_nfproto; + + if (nf_ct_l3num(ct) != nfproto) + return 0; + + if (nf_ct_protonum(ct) == IPPROTO_TCP && + ct->proto.tcp.state == TCP_CONNTRACK_ESTABLISHED) { + ct->proto.tcp.seen[0].td_maxwin = 0; + ct->proto.tcp.seen[1].td_maxwin = 0; + } + + return 0; +} + +static struct nf_ct_bridge_info *nf_ct_bridge_info; + +static int nf_ct_netns_do_get(struct net *net, u8 nfproto) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + bool fixup_needed = false, retry = true; + int err = 0; +retry: + mutex_lock(&nf_ct_proto_mutex); + + switch (nfproto) { + case NFPROTO_IPV4: + cnet->users4++; + if (cnet->users4 > 1) + goto out_unlock; + err = nf_defrag_ipv4_enable(net); + if (err) { + cnet->users4 = 0; + goto out_unlock; + } + + err = nf_register_net_hooks(net, ipv4_conntrack_ops, + ARRAY_SIZE(ipv4_conntrack_ops)); + if (err) + cnet->users4 = 0; + else + fixup_needed = true; + break; +#if IS_ENABLED(CONFIG_IPV6) + 
case NFPROTO_IPV6: + cnet->users6++; + if (cnet->users6 > 1) + goto out_unlock; + err = nf_defrag_ipv6_enable(net); + if (err < 0) { + cnet->users6 = 0; + goto out_unlock; + } + + err = nf_register_net_hooks(net, ipv6_conntrack_ops, + ARRAY_SIZE(ipv6_conntrack_ops)); + if (err) + cnet->users6 = 0; + else + fixup_needed = true; + break; +#endif + case NFPROTO_BRIDGE: + if (!nf_ct_bridge_info) { + if (!retry) { + err = -EPROTO; + goto out_unlock; + } + mutex_unlock(&nf_ct_proto_mutex); + request_module("nf_conntrack_bridge"); + retry = false; + goto retry; + } + if (!try_module_get(nf_ct_bridge_info->me)) { + err = -EPROTO; + goto out_unlock; + } + cnet->users_bridge++; + if (cnet->users_bridge > 1) + goto out_unlock; + + err = nf_register_net_hooks(net, nf_ct_bridge_info->ops, + nf_ct_bridge_info->ops_size); + if (err) + cnet->users_bridge = 0; + else + fixup_needed = true; + break; + default: + err = -EPROTO; + break; + } + out_unlock: + mutex_unlock(&nf_ct_proto_mutex); + + if (fixup_needed) { + struct nf_ct_iter_data iter_data = { + .net = net, + .data = (void *)(unsigned long)nfproto, + }; + nf_ct_iterate_cleanup_net(nf_ct_tcp_fixup, &iter_data); + } + + return err; +} + +static void nf_ct_netns_do_put(struct net *net, u8 nfproto) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + mutex_lock(&nf_ct_proto_mutex); + switch (nfproto) { + case NFPROTO_IPV4: + if (cnet->users4 && (--cnet->users4 == 0)) { + nf_unregister_net_hooks(net, ipv4_conntrack_ops, + ARRAY_SIZE(ipv4_conntrack_ops)); + nf_defrag_ipv4_disable(net); + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: + if (cnet->users6 && (--cnet->users6 == 0)) { + nf_unregister_net_hooks(net, ipv6_conntrack_ops, + ARRAY_SIZE(ipv6_conntrack_ops)); + nf_defrag_ipv6_disable(net); + } + break; +#endif + case NFPROTO_BRIDGE: + if (!nf_ct_bridge_info) + break; + if (cnet->users_bridge && (--cnet->users_bridge == 0)) + nf_unregister_net_hooks(net, nf_ct_bridge_info->ops, + nf_ct_bridge_info->ops_size); + + module_put(nf_ct_bridge_info->me); + break; + } + mutex_unlock(&nf_ct_proto_mutex); +} + +static int nf_ct_netns_inet_get(struct net *net) +{ + int err; + + err = nf_ct_netns_do_get(net, NFPROTO_IPV4); +#if IS_ENABLED(CONFIG_IPV6) + if (err < 0) + goto err1; + err = nf_ct_netns_do_get(net, NFPROTO_IPV6); + if (err < 0) + goto err2; + + return err; +err2: + nf_ct_netns_put(net, NFPROTO_IPV4); +err1: +#endif + return err; +} + +int nf_ct_netns_get(struct net *net, u8 nfproto) +{ + int err; + + switch (nfproto) { + case NFPROTO_INET: + err = nf_ct_netns_inet_get(net); + break; + case NFPROTO_BRIDGE: + err = nf_ct_netns_do_get(net, NFPROTO_BRIDGE); + if (err < 0) + return err; + + err = nf_ct_netns_inet_get(net); + if (err < 0) { + nf_ct_netns_put(net, NFPROTO_BRIDGE); + return err; + } + break; + default: + err = nf_ct_netns_do_get(net, nfproto); + break; + } + return err; +} +EXPORT_SYMBOL_GPL(nf_ct_netns_get); + +void nf_ct_netns_put(struct net *net, uint8_t nfproto) +{ + switch (nfproto) { + case NFPROTO_BRIDGE: + nf_ct_netns_do_put(net, NFPROTO_BRIDGE); + fallthrough; + case NFPROTO_INET: + nf_ct_netns_do_put(net, NFPROTO_IPV4); + nf_ct_netns_do_put(net, NFPROTO_IPV6); + break; + default: + nf_ct_netns_do_put(net, nfproto); + break; + } +} +EXPORT_SYMBOL_GPL(nf_ct_netns_put); + +void nf_ct_bridge_register(struct nf_ct_bridge_info *info) +{ + WARN_ON(nf_ct_bridge_info); + mutex_lock(&nf_ct_proto_mutex); + nf_ct_bridge_info = info; + mutex_unlock(&nf_ct_proto_mutex); +} +EXPORT_SYMBOL_GPL(nf_ct_bridge_register); + +void 
nf_ct_bridge_unregister(struct nf_ct_bridge_info *info) +{ + WARN_ON(!nf_ct_bridge_info); + mutex_lock(&nf_ct_proto_mutex); + nf_ct_bridge_info = NULL; + mutex_unlock(&nf_ct_proto_mutex); +} +EXPORT_SYMBOL_GPL(nf_ct_bridge_unregister); + +int nf_conntrack_proto_init(void) +{ + int ret; + + ret = nf_register_sockopt(&so_getorigdst); + if (ret < 0) + return ret; + +#if IS_ENABLED(CONFIG_IPV6) + ret = nf_register_sockopt(&so_getorigdst6); + if (ret < 0) + goto cleanup_sockopt; +#endif + + return ret; + +#if IS_ENABLED(CONFIG_IPV6) +cleanup_sockopt: + nf_unregister_sockopt(&so_getorigdst); +#endif + return ret; +} + +void nf_conntrack_proto_fini(void) +{ + nf_unregister_sockopt(&so_getorigdst); +#if IS_ENABLED(CONFIG_IPV6) + nf_unregister_sockopt(&so_getorigdst6); +#endif +} + +void nf_conntrack_proto_pernet_init(struct net *net) +{ + nf_conntrack_generic_init_net(net); + nf_conntrack_udp_init_net(net); + nf_conntrack_tcp_init_net(net); + nf_conntrack_icmp_init_net(net); +#if IS_ENABLED(CONFIG_IPV6) + nf_conntrack_icmpv6_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + nf_conntrack_dccp_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_SCTP + nf_conntrack_sctp_init_net(net); +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + nf_conntrack_gre_init_net(net); +#endif +} + +module_param_call(hashsize, nf_conntrack_set_hashsize, param_get_uint, + &nf_conntrack_htable_size, 0600); + +MODULE_ALIAS("ip_conntrack"); +MODULE_ALIAS("nf_conntrack-" __stringify(AF_INET)); +MODULE_ALIAS("nf_conntrack-" __stringify(AF_INET6)); +MODULE_LICENSE("GPL"); diff --git a/net/netfilter/nf_conntrack_proto_dccp.c b/net/netfilter/nf_conntrack_proto_dccp.c new file mode 100644 index 000000000..d4fd626d2 --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_dccp.c @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * DCCP connection tracking protocol helper + * + * Copyright (c) 2005, 2006, 2008 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +/* Timeouts are based on values from RFC4340: + * + * - REQUEST: + * + * 8.1.2. Client Request + * + * A client MAY give up on its DCCP-Requests after some time + * (3 minutes, for example). + * + * - RESPOND: + * + * 8.1.3. Server Response + * + * It MAY also leave the RESPOND state for CLOSED after a timeout of + * not less than 4MSL (8 minutes); + * + * - PARTOPEN: + * + * 8.1.5. Handshake Completion + * + * If the client remains in PARTOPEN for more than 4MSL (8 minutes), + * it SHOULD reset the connection with Reset Code 2, "Aborted". + * + * - OPEN: + * + * The DCCP timestamp overflows after 11.9 hours. If the connection + * stays idle this long the sequence number won't be recognized + * as valid anymore. + * + * - CLOSEREQ/CLOSING: + * + * 8.3. Termination + * + * The retransmission timer should initially be set to go off in two + * round-trip times and should back off to not less than once every + * 64 seconds ... + * + * - TIMEWAIT: + * + * 4.3. States + * + * A server or client socket remains in this state for 2MSL (4 minutes) + * after the connection has been torn down, ... 
+ */ + +#define DCCP_MSL (2 * 60 * HZ) + +static const char * const dccp_state_names[] = { + [CT_DCCP_NONE] = "NONE", + [CT_DCCP_REQUEST] = "REQUEST", + [CT_DCCP_RESPOND] = "RESPOND", + [CT_DCCP_PARTOPEN] = "PARTOPEN", + [CT_DCCP_OPEN] = "OPEN", + [CT_DCCP_CLOSEREQ] = "CLOSEREQ", + [CT_DCCP_CLOSING] = "CLOSING", + [CT_DCCP_TIMEWAIT] = "TIMEWAIT", + [CT_DCCP_IGNORE] = "IGNORE", + [CT_DCCP_INVALID] = "INVALID", +}; + +#define sNO CT_DCCP_NONE +#define sRQ CT_DCCP_REQUEST +#define sRS CT_DCCP_RESPOND +#define sPO CT_DCCP_PARTOPEN +#define sOP CT_DCCP_OPEN +#define sCR CT_DCCP_CLOSEREQ +#define sCG CT_DCCP_CLOSING +#define sTW CT_DCCP_TIMEWAIT +#define sIG CT_DCCP_IGNORE +#define sIV CT_DCCP_INVALID + +/* + * DCCP state transition table + * + * The assumption is the same as for TCP tracking: + * + * We are the man in the middle. All the packets go through us but might + * get lost in transit to the destination. It is assumed that the destination + * can't receive segments we haven't seen. + * + * The following states exist: + * + * NONE: Initial state, expecting Request + * REQUEST: Request seen, waiting for Response from server + * RESPOND: Response from server seen, waiting for Ack from client + * PARTOPEN: Ack after Response seen, waiting for packet other than Response, + * Reset or Sync from server + * OPEN: Packet other than Response, Reset or Sync seen + * CLOSEREQ: CloseReq from server seen, expecting Close from client + * CLOSING: Close seen, expecting Reset + * TIMEWAIT: Reset seen + * IGNORE: Not determinable whether packet is valid + * + * Some states exist only on one side of the connection: REQUEST, RESPOND, + * PARTOPEN, CLOSEREQ. For the other side these states are equivalent to + * the one it was in before. + * + * Packets are marked as ignored (sIG) if we don't know if they're valid + * (for example a reincarnation of a connection we didn't notice is dead + * already) and the server may send back a connection closing Reset or a + * Response. They're also used for Sync/SyncAck packets, which we don't + * care about. + */ +static const u_int8_t +dccp_state_table[CT_DCCP_ROLE_MAX + 1][DCCP_PKT_SYNCACK + 1][CT_DCCP_MAX + 1] = { + [CT_DCCP_ROLE_CLIENT] = { + [DCCP_PKT_REQUEST] = { + /* + * sNO -> sRQ Regular Request + * sRQ -> sRQ Retransmitted Request or reincarnation + * sRS -> sRS Retransmitted Request (apparently Response + * got lost after we saw it) or reincarnation + * sPO -> sIG Ignore, conntrack might be out of sync + * sOP -> sIG Ignore, conntrack might be out of sync + * sCR -> sIG Ignore, conntrack might be out of sync + * sCG -> sIG Ignore, conntrack might be out of sync + * sTW -> sRQ Reincarnation + * + * sNO, sRQ, sRS, sPO. sOP, sCR, sCG, sTW, */ + sRQ, sRQ, sRS, sIG, sIG, sIG, sIG, sRQ, + }, + [DCCP_PKT_RESPONSE] = { + /* + * sNO -> sIV Invalid + * sRQ -> sIG Ignore, might be response to ignored Request + * sRS -> sIG Ignore, might be response to ignored Request + * sPO -> sIG Ignore, might be response to ignored Request + * sOP -> sIG Ignore, might be response to ignored Request + * sCR -> sIG Ignore, might be response to ignored Request + * sCG -> sIG Ignore, might be response to ignored Request + * sTW -> sIV Invalid, reincarnation in reverse direction + * goes through sRQ + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sIV, + }, + [DCCP_PKT_ACK] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sPO Ack for Response, move to PARTOPEN (8.1.5.) 
+ * sPO -> sPO Retransmitted Ack for Response, remain in PARTOPEN + * sOP -> sOP Regular ACK, remain in OPEN + * sCR -> sCR Ack in CLOSEREQ MAY be processed (8.3.) + * sCG -> sCG Ack in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sPO, sPO, sOP, sCR, sCG, sIV + }, + [DCCP_PKT_DATA] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sIV MUST use DataAck in PARTOPEN state (8.1.5.) + * sOP -> sOP Regular Data packet + * sCR -> sCR Data in CLOSEREQ MAY be processed (8.3.) + * sCG -> sCG Data in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sIV, sOP, sCR, sCG, sIV, + }, + [DCCP_PKT_DATAACK] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sPO Ack for Response, move to PARTOPEN (8.1.5.) + * sPO -> sPO Remain in PARTOPEN state + * sOP -> sOP Regular DataAck packet in OPEN state + * sCR -> sCR DataAck in CLOSEREQ MAY be processed (8.3.) + * sCG -> sCG DataAck in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sPO, sPO, sOP, sCR, sCG, sIV + }, + [DCCP_PKT_CLOSEREQ] = { + /* + * CLOSEREQ may only be sent by the server. + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV + }, + [DCCP_PKT_CLOSE] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sCG Client-initiated close + * sOP -> sCG Client-initiated close + * sCR -> sCG Close in response to CloseReq (8.3.) + * sCG -> sCG Retransmit + * sTW -> sIV Late retransmit, already in TIME_WAIT + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sCG, sCG, sCG, sIV, sIV + }, + [DCCP_PKT_RESET] = { + /* + * sNO -> sIV No connection + * sRQ -> sTW Sync received or timeout, SHOULD send Reset (8.1.1.) + * sRS -> sTW Response received without Request + * sPO -> sTW Timeout, SHOULD send Reset (8.1.5.) + * sOP -> sTW Connection reset + * sCR -> sTW Connection reset + * sCG -> sTW Connection reset + * sTW -> sIG Ignore (don't refresh timer) + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sTW, sTW, sTW, sTW, sTW, sTW, sIG + }, + [DCCP_PKT_SYNC] = { + /* + * We currently ignore Sync packets + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sIG, + }, + [DCCP_PKT_SYNCACK] = { + /* + * We currently ignore SyncAck packets + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sIG, + }, + }, + [CT_DCCP_ROLE_SERVER] = { + [DCCP_PKT_REQUEST] = { + /* + * sNO -> sIV Invalid + * sRQ -> sIG Ignore, conntrack might be out of sync + * sRS -> sIG Ignore, conntrack might be out of sync + * sPO -> sIG Ignore, conntrack might be out of sync + * sOP -> sIG Ignore, conntrack might be out of sync + * sCR -> sIG Ignore, conntrack might be out of sync + * sCG -> sIG Ignore, conntrack might be out of sync + * sTW -> sRQ Reincarnation, must reverse roles + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sRQ + }, + [DCCP_PKT_RESPONSE] = { + /* + * sNO -> sIV Response without Request + * sRQ -> sRS Response to clients Request + * sRS -> sRS Retransmitted Response (8.1.3. 
SHOULD NOT) + * sPO -> sIG Response to an ignored Request or late retransmit + * sOP -> sIG Ignore, might be response to ignored Request + * sCR -> sIG Ignore, might be response to ignored Request + * sCG -> sIG Ignore, might be response to ignored Request + * sTW -> sIV Invalid, Request from client in sTW moves to sRQ + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sRS, sRS, sIG, sIG, sIG, sIG, sIV + }, + [DCCP_PKT_ACK] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sOP Enter OPEN state (8.1.5.) + * sOP -> sOP Regular Ack in OPEN state + * sCR -> sIV Waiting for Close from client + * sCG -> sCG Ack in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sOP, sOP, sIV, sCG, sIV + }, + [DCCP_PKT_DATA] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sOP Enter OPEN state (8.1.5.) + * sOP -> sOP Regular Data packet in OPEN state + * sCR -> sIV Waiting for Close from client + * sCG -> sCG Data in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sOP, sOP, sIV, sCG, sIV + }, + [DCCP_PKT_DATAACK] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sOP Enter OPEN state (8.1.5.) + * sOP -> sOP Regular DataAck in OPEN state + * sCR -> sIV Waiting for Close from client + * sCG -> sCG Data in CLOSING MAY be processed (8.3.) + * sTW -> sIV + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sOP, sOP, sIV, sCG, sIV + }, + [DCCP_PKT_CLOSEREQ] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sOP -> sCR Move directly to CLOSEREQ (8.1.5.) + * sOP -> sCR CloseReq in OPEN state + * sCR -> sCR Retransmit + * sCG -> sCR Simultaneous close, client sends another Close + * sTW -> sIV Already closed + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sCR, sCR, sCR, sCR, sIV + }, + [DCCP_PKT_CLOSE] = { + /* + * sNO -> sIV No connection + * sRQ -> sIV No connection + * sRS -> sIV No connection + * sPO -> sOP -> sCG Move direcly to CLOSING + * sOP -> sCG Move to CLOSING + * sCR -> sIV Close after CloseReq is invalid + * sCG -> sCG Retransmit + * sTW -> sIV Already closed + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIV, sIV, sCG, sCG, sIV, sCG, sIV + }, + [DCCP_PKT_RESET] = { + /* + * sNO -> sIV No connection + * sRQ -> sTW Reset in response to Request + * sRS -> sTW Timeout, SHOULD send Reset (8.1.3.) + * sPO -> sTW Timeout, SHOULD send Reset (8.1.3.) 
+ * sOP -> sTW + * sCR -> sTW + * sCG -> sTW + * sTW -> sIG Ignore (don't refresh timer) + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW, sTW */ + sIV, sTW, sTW, sTW, sTW, sTW, sTW, sTW, sIG + }, + [DCCP_PKT_SYNC] = { + /* + * We currently ignore Sync packets + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sIG, + }, + [DCCP_PKT_SYNCACK] = { + /* + * We currently ignore SyncAck packets + * + * sNO, sRQ, sRS, sPO, sOP, sCR, sCG, sTW */ + sIV, sIG, sIG, sIG, sIG, sIG, sIG, sIG, + }, + }, +}; + +static noinline bool +dccp_new(struct nf_conn *ct, const struct sk_buff *skb, + const struct dccp_hdr *dh, + const struct nf_hook_state *hook_state) +{ + struct net *net = nf_ct_net(ct); + struct nf_dccp_net *dn; + const char *msg; + u_int8_t state; + + state = dccp_state_table[CT_DCCP_ROLE_CLIENT][dh->dccph_type][CT_DCCP_NONE]; + switch (state) { + default: + dn = nf_dccp_pernet(net); + if (dn->dccp_loose == 0) { + msg = "not picking up existing connection "; + goto out_invalid; + } + break; + case CT_DCCP_REQUEST: + break; + case CT_DCCP_INVALID: + msg = "invalid state transition "; + goto out_invalid; + } + + ct->proto.dccp.role[IP_CT_DIR_ORIGINAL] = CT_DCCP_ROLE_CLIENT; + ct->proto.dccp.role[IP_CT_DIR_REPLY] = CT_DCCP_ROLE_SERVER; + ct->proto.dccp.state = CT_DCCP_NONE; + ct->proto.dccp.last_pkt = DCCP_PKT_REQUEST; + ct->proto.dccp.last_dir = IP_CT_DIR_ORIGINAL; + ct->proto.dccp.handshake_seq = 0; + return true; + +out_invalid: + nf_ct_l4proto_log_invalid(skb, ct, hook_state, "%s", msg); + return false; +} + +static u64 dccp_ack_seq(const struct dccp_hdr *dh) +{ + const struct dccp_hdr_ack_bits *dhack; + + dhack = (void *)dh + __dccp_basic_hdr_len(dh); + return ((u64)ntohs(dhack->dccph_ack_nr_high) << 32) + + ntohl(dhack->dccph_ack_nr_low); +} + +static bool dccp_error(const struct dccp_hdr *dh, + struct sk_buff *skb, unsigned int dataoff, + const struct nf_hook_state *state) +{ + static const unsigned long require_seq48 = 1 << DCCP_PKT_REQUEST | + 1 << DCCP_PKT_RESPONSE | + 1 << DCCP_PKT_CLOSEREQ | + 1 << DCCP_PKT_CLOSE | + 1 << DCCP_PKT_RESET | + 1 << DCCP_PKT_SYNC | + 1 << DCCP_PKT_SYNCACK; + unsigned int dccp_len = skb->len - dataoff; + unsigned int cscov; + const char *msg; + u8 type; + + BUILD_BUG_ON(DCCP_PKT_INVALID >= BITS_PER_LONG); + + if (dh->dccph_doff * 4 < sizeof(struct dccp_hdr) || + dh->dccph_doff * 4 > dccp_len) { + msg = "nf_ct_dccp: truncated/malformed packet "; + goto out_invalid; + } + + cscov = dccp_len; + if (dh->dccph_cscov) { + cscov = (dh->dccph_cscov - 1) * 4; + if (cscov > dccp_len) { + msg = "nf_ct_dccp: bad checksum coverage "; + goto out_invalid; + } + } + + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum_partial(skb, state->hook, dataoff, cscov, + IPPROTO_DCCP, state->pf)) { + msg = "nf_ct_dccp: bad checksum "; + goto out_invalid; + } + + type = dh->dccph_type; + if (type >= DCCP_PKT_INVALID) { + msg = "nf_ct_dccp: reserved packet type "; + goto out_invalid; + } + + if (test_bit(type, &require_seq48) && !dh->dccph_x) { + msg = "nf_ct_dccp: type lacks 48bit sequence numbers"; + goto out_invalid; + } + + return false; +out_invalid: + nf_l4proto_log_invalid(skb, state, IPPROTO_DCCP, "%s", msg); + return true; +} + +struct nf_conntrack_dccp_buf { + struct dccp_hdr dh; /* generic header part */ + struct dccp_hdr_ext ext; /* optional depending dh->dccph_x */ + union { /* depends on header type */ + struct dccp_hdr_ack_bits ack; + struct dccp_hdr_request req; + struct dccp_hdr_response 
response; + struct dccp_hdr_reset rst; + } u; +}; + +static struct dccp_hdr * +dccp_header_pointer(const struct sk_buff *skb, int offset, const struct dccp_hdr *dh, + struct nf_conntrack_dccp_buf *buf) +{ + unsigned int hdrlen = __dccp_hdr_len(dh); + + if (hdrlen > sizeof(*buf)) + return NULL; + + return skb_header_pointer(skb, offset, hdrlen, buf); +} + +int nf_conntrack_dccp_packet(struct nf_conn *ct, struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_conntrack_dccp_buf _dh; + u_int8_t type, old_state, new_state; + enum ct_dccp_roles role; + unsigned int *timeouts; + struct dccp_hdr *dh; + + dh = skb_header_pointer(skb, dataoff, sizeof(*dh), &_dh.dh); + if (!dh) + return NF_DROP; + + if (dccp_error(dh, skb, dataoff, state)) + return -NF_ACCEPT; + + /* pull again, including possible 48 bit sequences and subtype header */ + dh = dccp_header_pointer(skb, dataoff, dh, &_dh); + if (!dh) + return NF_DROP; + + type = dh->dccph_type; + if (!nf_ct_is_confirmed(ct) && !dccp_new(ct, skb, dh, state)) + return -NF_ACCEPT; + + if (type == DCCP_PKT_RESET && + !test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + /* Tear down connection immediately if only reply is a RESET */ + nf_ct_kill_acct(ct, ctinfo, skb); + return NF_ACCEPT; + } + + spin_lock_bh(&ct->lock); + + role = ct->proto.dccp.role[dir]; + old_state = ct->proto.dccp.state; + new_state = dccp_state_table[role][type][old_state]; + + switch (new_state) { + case CT_DCCP_REQUEST: + if (old_state == CT_DCCP_TIMEWAIT && + role == CT_DCCP_ROLE_SERVER) { + /* Reincarnation in the reverse direction: reopen and + * reverse client/server roles. */ + ct->proto.dccp.role[dir] = CT_DCCP_ROLE_CLIENT; + ct->proto.dccp.role[!dir] = CT_DCCP_ROLE_SERVER; + } + break; + case CT_DCCP_RESPOND: + if (old_state == CT_DCCP_REQUEST) + ct->proto.dccp.handshake_seq = dccp_hdr_seq(dh); + break; + case CT_DCCP_PARTOPEN: + if (old_state == CT_DCCP_RESPOND && + type == DCCP_PKT_ACK && + dccp_ack_seq(dh) == ct->proto.dccp.handshake_seq) + set_bit(IPS_ASSURED_BIT, &ct->status); + break; + case CT_DCCP_IGNORE: + /* + * Connection tracking might be out of sync, so we ignore + * packets that might establish a new connection and resync + * if the server responds with a valid Response. 
+ */ + if (ct->proto.dccp.last_dir == !dir && + ct->proto.dccp.last_pkt == DCCP_PKT_REQUEST && + type == DCCP_PKT_RESPONSE) { + ct->proto.dccp.role[!dir] = CT_DCCP_ROLE_CLIENT; + ct->proto.dccp.role[dir] = CT_DCCP_ROLE_SERVER; + ct->proto.dccp.handshake_seq = dccp_hdr_seq(dh); + new_state = CT_DCCP_RESPOND; + break; + } + ct->proto.dccp.last_dir = dir; + ct->proto.dccp.last_pkt = type; + + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, "%s", "invalid packet"); + return NF_ACCEPT; + case CT_DCCP_INVALID: + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, "%s", "invalid state transition"); + return -NF_ACCEPT; + } + + ct->proto.dccp.last_dir = dir; + ct->proto.dccp.last_pkt = type; + ct->proto.dccp.state = new_state; + spin_unlock_bh(&ct->lock); + + if (new_state != old_state) + nf_conntrack_event_cache(IPCT_PROTOINFO, ct); + + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = nf_dccp_pernet(nf_ct_net(ct))->dccp_timeout; + nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[new_state]); + + return NF_ACCEPT; +} + +static bool dccp_can_early_drop(const struct nf_conn *ct) +{ + switch (ct->proto.dccp.state) { + case CT_DCCP_CLOSEREQ: + case CT_DCCP_CLOSING: + case CT_DCCP_TIMEWAIT: + return true; + default: + break; + } + + return false; +} + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +static void dccp_print_conntrack(struct seq_file *s, struct nf_conn *ct) +{ + seq_printf(s, "%s ", dccp_state_names[ct->proto.dccp.state]); +} +#endif + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) +static int dccp_to_nlattr(struct sk_buff *skb, struct nlattr *nla, + struct nf_conn *ct, bool destroy) +{ + struct nlattr *nest_parms; + + spin_lock_bh(&ct->lock); + nest_parms = nla_nest_start(skb, CTA_PROTOINFO_DCCP); + if (!nest_parms) + goto nla_put_failure; + if (nla_put_u8(skb, CTA_PROTOINFO_DCCP_STATE, ct->proto.dccp.state)) + goto nla_put_failure; + + if (destroy) + goto skip_state; + + if (nla_put_u8(skb, CTA_PROTOINFO_DCCP_ROLE, + ct->proto.dccp.role[IP_CT_DIR_ORIGINAL]) || + nla_put_be64(skb, CTA_PROTOINFO_DCCP_HANDSHAKE_SEQ, + cpu_to_be64(ct->proto.dccp.handshake_seq), + CTA_PROTOINFO_DCCP_PAD)) + goto nla_put_failure; +skip_state: + nla_nest_end(skb, nest_parms); + spin_unlock_bh(&ct->lock); + + return 0; + +nla_put_failure: + spin_unlock_bh(&ct->lock); + return -1; +} + +static const struct nla_policy dccp_nla_policy[CTA_PROTOINFO_DCCP_MAX + 1] = { + [CTA_PROTOINFO_DCCP_STATE] = { .type = NLA_U8 }, + [CTA_PROTOINFO_DCCP_ROLE] = { .type = NLA_U8 }, + [CTA_PROTOINFO_DCCP_HANDSHAKE_SEQ] = { .type = NLA_U64 }, + [CTA_PROTOINFO_DCCP_PAD] = { .type = NLA_UNSPEC }, +}; + +#define DCCP_NLATTR_SIZE ( \ + NLA_ALIGN(NLA_HDRLEN + 1) + \ + NLA_ALIGN(NLA_HDRLEN + 1) + \ + NLA_ALIGN(NLA_HDRLEN + sizeof(u64)) + \ + NLA_ALIGN(NLA_HDRLEN + 0)) + +static int nlattr_to_dccp(struct nlattr *cda[], struct nf_conn *ct) +{ + struct nlattr *attr = cda[CTA_PROTOINFO_DCCP]; + struct nlattr *tb[CTA_PROTOINFO_DCCP_MAX + 1]; + int err; + + if (!attr) + return 0; + + err = nla_parse_nested_deprecated(tb, CTA_PROTOINFO_DCCP_MAX, attr, + dccp_nla_policy, NULL); + if (err < 0) + return err; + + if (!tb[CTA_PROTOINFO_DCCP_STATE] || + !tb[CTA_PROTOINFO_DCCP_ROLE] || + nla_get_u8(tb[CTA_PROTOINFO_DCCP_ROLE]) > CT_DCCP_ROLE_MAX || + nla_get_u8(tb[CTA_PROTOINFO_DCCP_STATE]) >= CT_DCCP_IGNORE) { + return -EINVAL; + } + + spin_lock_bh(&ct->lock); + ct->proto.dccp.state = nla_get_u8(tb[CTA_PROTOINFO_DCCP_STATE]); + if (nla_get_u8(tb[CTA_PROTOINFO_DCCP_ROLE]) == CT_DCCP_ROLE_CLIENT) { + 
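		/* CTA_PROTOINFO_DCCP_ROLE describes the ORIGINAL direction;
		 * the REPLY direction always gets the complementary role.
		 */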
ct->proto.dccp.role[IP_CT_DIR_ORIGINAL] = CT_DCCP_ROLE_CLIENT; + ct->proto.dccp.role[IP_CT_DIR_REPLY] = CT_DCCP_ROLE_SERVER; + } else { + ct->proto.dccp.role[IP_CT_DIR_ORIGINAL] = CT_DCCP_ROLE_SERVER; + ct->proto.dccp.role[IP_CT_DIR_REPLY] = CT_DCCP_ROLE_CLIENT; + } + if (tb[CTA_PROTOINFO_DCCP_HANDSHAKE_SEQ]) { + ct->proto.dccp.handshake_seq = + be64_to_cpu(nla_get_be64(tb[CTA_PROTOINFO_DCCP_HANDSHAKE_SEQ])); + } + spin_unlock_bh(&ct->lock); + return 0; +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int dccp_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + struct nf_dccp_net *dn = nf_dccp_pernet(net); + unsigned int *timeouts = data; + int i; + + if (!timeouts) + timeouts = dn->dccp_timeout; + + /* set default DCCP timeouts. */ + for (i=0; idccp_timeout[i]; + + /* there's a 1:1 mapping between attributes and protocol states. */ + for (i=CTA_TIMEOUT_DCCP_UNSPEC+1; idccp_loose = 1; + dn->dccp_timeout[CT_DCCP_REQUEST] = 2 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_RESPOND] = 4 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_PARTOPEN] = 4 * DCCP_MSL; + dn->dccp_timeout[CT_DCCP_OPEN] = 12 * 3600 * HZ; + dn->dccp_timeout[CT_DCCP_CLOSEREQ] = 64 * HZ; + dn->dccp_timeout[CT_DCCP_CLOSING] = 64 * HZ; + dn->dccp_timeout[CT_DCCP_TIMEWAIT] = 2 * DCCP_MSL; + + /* timeouts[0] is unused, make it same as SYN_SENT so + * ->timeouts[0] contains 'new' timeout, like udp or icmp. + */ + dn->dccp_timeout[CT_DCCP_NONE] = dn->dccp_timeout[CT_DCCP_REQUEST]; +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_dccp = { + .l4proto = IPPROTO_DCCP, + .can_early_drop = dccp_can_early_drop, +#ifdef CONFIG_NF_CONNTRACK_PROCFS + .print_conntrack = dccp_print_conntrack, +#endif +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .nlattr_size = DCCP_NLATTR_SIZE, + .to_nlattr = dccp_to_nlattr, + .from_nlattr = nlattr_to_dccp, + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = dccp_timeout_nlattr_to_obj, + .obj_to_nlattr = dccp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_DCCP_MAX, + .obj_size = sizeof(unsigned int) * CT_DCCP_MAX, + .nla_policy = dccp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_generic.c b/net/netfilter/nf_conntrack_proto_generic.c new file mode 100644 index 000000000..e831637bc --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_generic.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include + +static const unsigned int nf_ct_generic_timeout = 600*HZ; + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int generic_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + struct nf_generic_net *gn = nf_generic_pernet(net); + unsigned int *timeout = data; + + if (!timeout) + timeout = &gn->timeout; + + if (tb[CTA_TIMEOUT_GENERIC_TIMEOUT]) + *timeout = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_GENERIC_TIMEOUT])) * HZ; + else { + /* Set default generic timeout. 
*/ + *timeout = gn->timeout; + } + + return 0; +} + +static int +generic_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data) +{ + const unsigned int *timeout = data; + + if (nla_put_be32(skb, CTA_TIMEOUT_GENERIC_TIMEOUT, htonl(*timeout / HZ))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy +generic_timeout_nla_policy[CTA_TIMEOUT_GENERIC_MAX+1] = { + [CTA_TIMEOUT_GENERIC_TIMEOUT] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_generic_init_net(struct net *net) +{ + struct nf_generic_net *gn = nf_generic_pernet(net); + + gn->timeout = nf_ct_generic_timeout; +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_generic = +{ + .l4proto = 255, +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = generic_timeout_nlattr_to_obj, + .obj_to_nlattr = generic_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_GENERIC_MAX, + .obj_size = sizeof(unsigned int), + .nla_policy = generic_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_gre.c b/net/netfilter/nf_conntrack_proto_gre.c new file mode 100644 index 000000000..728eeb0ae --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_gre.c @@ -0,0 +1,317 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Connection tracking protocol helper module for GRE. + * + * GRE is a generic encapsulation protocol, which is generally not very + * suited for NAT, as it has no protocol-specific part as port numbers. + * + * It has an optional key field, which may help us distinguishing two + * connections between the same two hosts. + * + * GRE is defined in RFC 1701 and RFC 1702, as well as RFC 2784 + * + * PPTP is built on top of a modified version of GRE, and has a mandatory + * field called "CallID", which serves us for the same purpose as the key + * field in plain GRE. 
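 *
 * Conntrack therefore keys each direction on the PPTP call ID carried in
 * the packet (tuple->dst.u.gre.key) and looks up the matching source key
 * from the keymap entries that the PPTP helper adds via
 * nf_ct_gre_keymap_add().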
+ * + * Documentation about PPTP can be found in RFC 2637 + * + * (C) 2000-2005 by Harald Welte + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + * + * (C) 2006-2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const unsigned int gre_timeouts[GRE_CT_MAX] = { + [GRE_CT_UNREPLIED] = 30*HZ, + [GRE_CT_REPLIED] = 180*HZ, +}; + +/* used when expectation is added */ +static DEFINE_SPINLOCK(keymap_lock); + +static inline struct nf_gre_net *gre_pernet(struct net *net) +{ + return &net->ct.nf_ct_proto.gre; +} + +static inline int gre_key_cmpfn(const struct nf_ct_gre_keymap *km, + const struct nf_conntrack_tuple *t) +{ + return km->tuple.src.l3num == t->src.l3num && + !memcmp(&km->tuple.src.u3, &t->src.u3, sizeof(t->src.u3)) && + !memcmp(&km->tuple.dst.u3, &t->dst.u3, sizeof(t->dst.u3)) && + km->tuple.dst.protonum == t->dst.protonum && + km->tuple.dst.u.all == t->dst.u.all; +} + +/* look up the source key for a given tuple */ +static __be16 gre_keymap_lookup(struct net *net, struct nf_conntrack_tuple *t) +{ + struct nf_gre_net *net_gre = gre_pernet(net); + struct nf_ct_gre_keymap *km; + __be16 key = 0; + + list_for_each_entry_rcu(km, &net_gre->keymap_list, list) { + if (gre_key_cmpfn(km, t)) { + key = km->tuple.src.u.gre.key; + break; + } + } + + pr_debug("lookup src key 0x%x for ", key); + nf_ct_dump_tuple(t); + + return key; +} + +/* add a single keymap entry, associate with specified master ct */ +int nf_ct_gre_keymap_add(struct nf_conn *ct, enum ip_conntrack_dir dir, + struct nf_conntrack_tuple *t) +{ + struct net *net = nf_ct_net(ct); + struct nf_gre_net *net_gre = gre_pernet(net); + struct nf_ct_pptp_master *ct_pptp_info = nfct_help_data(ct); + struct nf_ct_gre_keymap **kmp, *km; + + kmp = &ct_pptp_info->keymap[dir]; + if (*kmp) { + /* check whether it's a retransmission */ + list_for_each_entry_rcu(km, &net_gre->keymap_list, list) { + if (gre_key_cmpfn(km, t) && km == *kmp) + return 0; + } + pr_debug("trying to override keymap_%s for ct %p\n", + dir == IP_CT_DIR_REPLY ? 
"reply" : "orig", ct); + return -EEXIST; + } + + km = kmalloc(sizeof(*km), GFP_ATOMIC); + if (!km) + return -ENOMEM; + memcpy(&km->tuple, t, sizeof(*t)); + *kmp = km; + + pr_debug("adding new entry %p: ", km); + nf_ct_dump_tuple(&km->tuple); + + spin_lock_bh(&keymap_lock); + list_add_tail(&km->list, &net_gre->keymap_list); + spin_unlock_bh(&keymap_lock); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_ct_gre_keymap_add); + +/* destroy the keymap entries associated with specified master ct */ +void nf_ct_gre_keymap_destroy(struct nf_conn *ct) +{ + struct nf_ct_pptp_master *ct_pptp_info = nfct_help_data(ct); + enum ip_conntrack_dir dir; + + pr_debug("entering for ct %p\n", ct); + + spin_lock_bh(&keymap_lock); + for (dir = IP_CT_DIR_ORIGINAL; dir < IP_CT_DIR_MAX; dir++) { + if (ct_pptp_info->keymap[dir]) { + pr_debug("removing %p from list\n", + ct_pptp_info->keymap[dir]); + list_del_rcu(&ct_pptp_info->keymap[dir]->list); + kfree_rcu(ct_pptp_info->keymap[dir], rcu); + ct_pptp_info->keymap[dir] = NULL; + } + } + spin_unlock_bh(&keymap_lock); +} +EXPORT_SYMBOL_GPL(nf_ct_gre_keymap_destroy); + +/* PUBLIC CONNTRACK PROTO HELPER FUNCTIONS */ + +/* gre hdr info to tuple */ +bool gre_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, + struct net *net, struct nf_conntrack_tuple *tuple) +{ + const struct pptp_gre_header *pgrehdr; + struct pptp_gre_header _pgrehdr; + __be16 srckey; + const struct gre_base_hdr *grehdr; + struct gre_base_hdr _grehdr; + + /* first only delinearize old RFC1701 GRE header */ + grehdr = skb_header_pointer(skb, dataoff, sizeof(_grehdr), &_grehdr); + if (!grehdr || (grehdr->flags & GRE_VERSION) != GRE_VERSION_1) { + /* try to behave like "nf_conntrack_proto_generic" */ + tuple->src.u.all = 0; + tuple->dst.u.all = 0; + return true; + } + + /* PPTP header is variable length, only need up to the call_id field */ + pgrehdr = skb_header_pointer(skb, dataoff, 8, &_pgrehdr); + if (!pgrehdr) + return true; + + if (grehdr->protocol != GRE_PROTO_PPP) { + pr_debug("Unsupported GRE proto(0x%x)\n", ntohs(grehdr->protocol)); + return false; + } + + tuple->dst.u.gre.key = pgrehdr->call_id; + srckey = gre_keymap_lookup(net, tuple); + tuple->src.u.gre.key = srckey; + + return true; +} + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +/* print private data for conntrack */ +static void gre_print_conntrack(struct seq_file *s, struct nf_conn *ct) +{ + seq_printf(s, "timeout=%u, stream_timeout=%u ", + (ct->proto.gre.timeout / HZ), + (ct->proto.gre.stream_timeout / HZ)); +} +#endif + +static unsigned int *gre_get_timeouts(struct net *net) +{ + return gre_pernet(net)->timeouts; +} + +/* Returns verdict for packet, and may modify conntrack */ +int nf_conntrack_gre_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + if (!nf_ct_is_confirmed(ct)) { + unsigned int *timeouts = nf_ct_timeout_lookup(ct); + + if (!timeouts) + timeouts = gre_get_timeouts(nf_ct_net(ct)); + + /* initialize to sane value. Ideally a conntrack helper + * (e.g. in case of pptp) is increasing them */ + ct->proto.gre.stream_timeout = timeouts[GRE_CT_REPLIED]; + ct->proto.gre.timeout = timeouts[GRE_CT_UNREPLIED]; + } + + /* If we've seen traffic both ways, this is a GRE connection. + * Extend timeout. */ + if (ct->status & IPS_SEEN_REPLY) { + nf_ct_refresh_acct(ct, ctinfo, skb, + ct->proto.gre.stream_timeout); + /* Also, more likely to be important, and not a probe. 
*/ + if (!test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } else + nf_ct_refresh_acct(ct, ctinfo, skb, + ct->proto.gre.timeout); + + return NF_ACCEPT; +} + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int gre_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + unsigned int *timeouts = data; + struct nf_gre_net *net_gre = gre_pernet(net); + + if (!timeouts) + timeouts = gre_get_timeouts(net); + /* set default timeouts for GRE. */ + timeouts[GRE_CT_UNREPLIED] = net_gre->timeouts[GRE_CT_UNREPLIED]; + timeouts[GRE_CT_REPLIED] = net_gre->timeouts[GRE_CT_REPLIED]; + + if (tb[CTA_TIMEOUT_GRE_UNREPLIED]) { + timeouts[GRE_CT_UNREPLIED] = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_GRE_UNREPLIED])) * HZ; + } + if (tb[CTA_TIMEOUT_GRE_REPLIED]) { + timeouts[GRE_CT_REPLIED] = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_GRE_REPLIED])) * HZ; + } + return 0; +} + +static int +gre_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data) +{ + const unsigned int *timeouts = data; + + if (nla_put_be32(skb, CTA_TIMEOUT_GRE_UNREPLIED, + htonl(timeouts[GRE_CT_UNREPLIED] / HZ)) || + nla_put_be32(skb, CTA_TIMEOUT_GRE_REPLIED, + htonl(timeouts[GRE_CT_REPLIED] / HZ))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy +gre_timeout_nla_policy[CTA_TIMEOUT_GRE_MAX+1] = { + [CTA_TIMEOUT_GRE_UNREPLIED] = { .type = NLA_U32 }, + [CTA_TIMEOUT_GRE_REPLIED] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_gre_init_net(struct net *net) +{ + struct nf_gre_net *net_gre = gre_pernet(net); + int i; + + INIT_LIST_HEAD(&net_gre->keymap_list); + for (i = 0; i < GRE_CT_MAX; i++) + net_gre->timeouts[i] = gre_timeouts[i]; +} + +/* protocol helper struct */ +const struct nf_conntrack_l4proto nf_conntrack_l4proto_gre = { + .l4proto = IPPROTO_GRE, +#ifdef CONFIG_NF_CONNTRACK_PROCFS + .print_conntrack = gre_print_conntrack, +#endif +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = gre_timeout_nlattr_to_obj, + .obj_to_nlattr = gre_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_GRE_MAX, + .obj_size = sizeof(unsigned int) * GRE_CT_MAX, + .nla_policy = gre_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_icmp.c b/net/netfilter/nf_conntrack_proto_icmp.c new file mode 100644 index 000000000..b38b7164a --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_icmp.c @@ -0,0 +1,383 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + * (C) 2006-2010 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +static const unsigned int nf_ct_icmp_timeout = 30*HZ; + +bool icmp_pkt_to_tuple(const struct sk_buff *skb, unsigned int dataoff, + struct net *net, struct nf_conntrack_tuple *tuple) +{ + const struct icmphdr *hp; + struct icmphdr _hdr; + + hp = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); + if (hp == NULL) + return false; + + tuple->dst.u.icmp.type = hp->type; + tuple->src.u.icmp.id = hp->un.echo.id; + 
tuple->dst.u.icmp.code = hp->code; + + return true; +} + +/* Add 1; spaces filled with 0. */ +static const u_int8_t invmap[] = { + [ICMP_ECHO] = ICMP_ECHOREPLY + 1, + [ICMP_ECHOREPLY] = ICMP_ECHO + 1, + [ICMP_TIMESTAMP] = ICMP_TIMESTAMPREPLY + 1, + [ICMP_TIMESTAMPREPLY] = ICMP_TIMESTAMP + 1, + [ICMP_INFO_REQUEST] = ICMP_INFO_REPLY + 1, + [ICMP_INFO_REPLY] = ICMP_INFO_REQUEST + 1, + [ICMP_ADDRESS] = ICMP_ADDRESSREPLY + 1, + [ICMP_ADDRESSREPLY] = ICMP_ADDRESS + 1 +}; + +bool nf_conntrack_invert_icmp_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple *orig) +{ + if (orig->dst.u.icmp.type >= sizeof(invmap) || + !invmap[orig->dst.u.icmp.type]) + return false; + + tuple->src.u.icmp.id = orig->src.u.icmp.id; + tuple->dst.u.icmp.type = invmap[orig->dst.u.icmp.type] - 1; + tuple->dst.u.icmp.code = orig->dst.u.icmp.code; + return true; +} + +/* Returns verdict for packet, or -1 for invalid. */ +int nf_conntrack_icmp_packet(struct nf_conn *ct, + struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + /* Do not immediately delete the connection after the first + successful reply to avoid excessive conntrackd traffic + and also to handle correctly ICMP echo reply duplicates. */ + unsigned int *timeout = nf_ct_timeout_lookup(ct); + static const u_int8_t valid_new[] = { + [ICMP_ECHO] = 1, + [ICMP_TIMESTAMP] = 1, + [ICMP_INFO_REQUEST] = 1, + [ICMP_ADDRESS] = 1 + }; + + if (state->pf != NFPROTO_IPV4) + return -NF_ACCEPT; + + if (ct->tuplehash[0].tuple.dst.u.icmp.type >= sizeof(valid_new) || + !valid_new[ct->tuplehash[0].tuple.dst.u.icmp.type]) { + /* Can't create a new ICMP `conn' with this. */ + pr_debug("icmp: can't create new conn with type %u\n", + ct->tuplehash[0].tuple.dst.u.icmp.type); + nf_ct_dump_tuple_ip(&ct->tuplehash[0].tuple); + return -NF_ACCEPT; + } + + if (!timeout) + timeout = &nf_icmp_pernet(nf_ct_net(ct))->timeout; + + nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); + return NF_ACCEPT; +} + +/* Check inner header is related to any of the existing connections */ +int nf_conntrack_inet_error(struct nf_conn *tmpl, struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state, + u8 l4proto, union nf_inet_addr *outer_daddr) +{ + struct nf_conntrack_tuple innertuple, origtuple; + const struct nf_conntrack_tuple_hash *h; + const struct nf_conntrack_zone *zone; + enum ip_conntrack_info ctinfo; + struct nf_conntrack_zone tmp; + union nf_inet_addr *ct_daddr; + enum ip_conntrack_dir dir; + struct nf_conn *ct; + + WARN_ON(skb_nfct(skb)); + zone = nf_ct_zone_tmpl(tmpl, skb, &tmp); + + /* Are they talking about one of our connections? */ + if (!nf_ct_get_tuplepr(skb, dataoff, + state->pf, state->net, &origtuple)) + return -NF_ACCEPT; + + /* Ordinarily, we'd expect the inverted tupleproto, but it's + been preserved inside the ICMP. */ + if (!nf_ct_invert_tuple(&innertuple, &origtuple)) + return -NF_ACCEPT; + + h = nf_conntrack_find_get(state->net, zone, &innertuple); + if (!h) + return -NF_ACCEPT; + + /* Consider: A -> T (=This machine) -> B + * Conntrack entry will look like this: + * Original: A->B + * Reply: B->T (SNAT case) OR A + * + * When this function runs, we got packet that looks like this: + * iphdr|icmphdr|inner_iphdr|l4header (tcp, udp, ..). + * + * Above nf_conntrack_find_get() makes lookup based on inner_hdr, + * so we should expect that destination of the found connection + * matches outer header destination address. + * + * In above example, we can consider these two cases: + * 1. 
Error coming in reply direction from B or M (middle box) to + * T (SNAT case) or A. + * Inner saddr will be B, dst will be T or A. + * The found conntrack will be reply tuple (B->T/A). + * 2. Error coming in original direction from A or M to B. + * Inner saddr will be A, inner daddr will be B. + * The found conntrack will be original tuple (A->B). + * + * In both cases, conntrack[dir].dst == inner.dst. + * + * A bogus packet could look like this: + * Inner: B->T + * Outer: B->X (other machine reachable by T). + * + * In this case, lookup yields connection A->B and will + * set packet from B->X as *RELATED*, even though no connection + * from X was ever seen. + */ + ct = nf_ct_tuplehash_to_ctrack(h); + dir = NF_CT_DIRECTION(h); + ct_daddr = &ct->tuplehash[dir].tuple.dst.u3; + if (!nf_inet_addr_cmp(outer_daddr, ct_daddr)) { + if (state->pf == AF_INET) { + nf_l4proto_log_invalid(skb, state, + l4proto, + "outer daddr %pI4 != inner %pI4", + &outer_daddr->ip, &ct_daddr->ip); + } else if (state->pf == AF_INET6) { + nf_l4proto_log_invalid(skb, state, + l4proto, + "outer daddr %pI6 != inner %pI6", + &outer_daddr->ip6, &ct_daddr->ip6); + } + nf_ct_put(ct); + return -NF_ACCEPT; + } + + ctinfo = IP_CT_RELATED; + if (dir == IP_CT_DIR_REPLY) + ctinfo += IP_CT_IS_REPLY; + + /* Update skb to refer to this connection */ + nf_ct_set(skb, ct, ctinfo); + return NF_ACCEPT; +} + +static void icmp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state, IPPROTO_ICMP, "%s", msg); +} + +/* Small and modified version of icmp_rcv */ +int nf_conntrack_icmpv4_error(struct nf_conn *tmpl, + struct sk_buff *skb, unsigned int dataoff, + const struct nf_hook_state *state) +{ + union nf_inet_addr outer_daddr; + const struct icmphdr *icmph; + struct icmphdr _ih; + + /* Not enough header? */ + icmph = skb_header_pointer(skb, dataoff, sizeof(_ih), &_ih); + if (icmph == NULL) { + icmp_error_log(skb, state, "short packet"); + return -NF_ACCEPT; + } + + /* See nf_conntrack_proto_tcp.c */ + if (state->net->ct.sysctl_checksum && + state->hook == NF_INET_PRE_ROUTING && + nf_ip_checksum(skb, state->hook, dataoff, IPPROTO_ICMP)) { + icmp_error_log(skb, state, "bad hw icmp checksum"); + return -NF_ACCEPT; + } + + /* + * 18 is the highest 'known' ICMP type. Anything else is a mystery + * + * RFC 1122: 3.2.2 Unknown ICMP messages types MUST be silently + * discarded. + */ + if (icmph->type > NR_ICMP_TYPES) { + icmp_error_log(skb, state, "invalid icmp type"); + return -NF_ACCEPT; + } + + /* Need to track icmp error message? 
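 * Only genuine error messages (destination unreachable, source quench,
 * redirect, time exceeded, parameter problem) carry an embedded header;
 * icmp_is_err() filters for exactly those before the inner-tuple lookup.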
*/ + if (!icmp_is_err(icmph->type)) + return NF_ACCEPT; + + memset(&outer_daddr, 0, sizeof(outer_daddr)); + outer_daddr.ip = ip_hdr(skb)->daddr; + + dataoff += sizeof(*icmph); + return nf_conntrack_inet_error(tmpl, skb, dataoff, state, + IPPROTO_ICMP, &outer_daddr); +} + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include + +static int icmp_tuple_to_nlattr(struct sk_buff *skb, + const struct nf_conntrack_tuple *t) +{ + if (nla_put_be16(skb, CTA_PROTO_ICMP_ID, t->src.u.icmp.id) || + nla_put_u8(skb, CTA_PROTO_ICMP_TYPE, t->dst.u.icmp.type) || + nla_put_u8(skb, CTA_PROTO_ICMP_CODE, t->dst.u.icmp.code)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static const struct nla_policy icmp_nla_policy[CTA_PROTO_MAX+1] = { + [CTA_PROTO_ICMP_TYPE] = { .type = NLA_U8 }, + [CTA_PROTO_ICMP_CODE] = { .type = NLA_U8 }, + [CTA_PROTO_ICMP_ID] = { .type = NLA_U16 }, +}; + +static int icmp_nlattr_to_tuple(struct nlattr *tb[], + struct nf_conntrack_tuple *tuple, + u_int32_t flags) +{ + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_TYPE)) { + if (!tb[CTA_PROTO_ICMP_TYPE]) + return -EINVAL; + + tuple->dst.u.icmp.type = nla_get_u8(tb[CTA_PROTO_ICMP_TYPE]); + if (tuple->dst.u.icmp.type >= sizeof(invmap) || + !invmap[tuple->dst.u.icmp.type]) + return -EINVAL; + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_CODE)) { + if (!tb[CTA_PROTO_ICMP_CODE]) + return -EINVAL; + + tuple->dst.u.icmp.code = nla_get_u8(tb[CTA_PROTO_ICMP_CODE]); + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMP_ID)) { + if (!tb[CTA_PROTO_ICMP_ID]) + return -EINVAL; + + tuple->src.u.icmp.id = nla_get_be16(tb[CTA_PROTO_ICMP_ID]); + } + + return 0; +} + +static unsigned int icmp_nlattr_tuple_size(void) +{ + static unsigned int size __read_mostly; + + if (!size) + size = nla_policy_len(icmp_nla_policy, CTA_PROTO_MAX + 1); + + return size; +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int icmp_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + unsigned int *timeout = data; + struct nf_icmp_net *in = nf_icmp_pernet(net); + + if (tb[CTA_TIMEOUT_ICMP_TIMEOUT]) { + if (!timeout) + timeout = &in->timeout; + *timeout = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_ICMP_TIMEOUT])) * HZ; + } else if (timeout) { + /* Set default ICMP timeout. 
*/ + *timeout = in->timeout; + } + return 0; +} + +static int +icmp_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data) +{ + const unsigned int *timeout = data; + + if (nla_put_be32(skb, CTA_TIMEOUT_ICMP_TIMEOUT, htonl(*timeout / HZ))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy +icmp_timeout_nla_policy[CTA_TIMEOUT_ICMP_MAX+1] = { + [CTA_TIMEOUT_ICMP_TIMEOUT] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_icmp_init_net(struct net *net) +{ + struct nf_icmp_net *in = nf_icmp_pernet(net); + + in->timeout = nf_ct_icmp_timeout; +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmp = +{ + .l4proto = IPPROTO_ICMP, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .tuple_to_nlattr = icmp_tuple_to_nlattr, + .nlattr_tuple_size = icmp_nlattr_tuple_size, + .nlattr_to_tuple = icmp_nlattr_to_tuple, + .nla_policy = icmp_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = icmp_timeout_nlattr_to_obj, + .obj_to_nlattr = icmp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_ICMP_MAX, + .obj_size = sizeof(unsigned int), + .nla_policy = icmp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_icmpv6.c b/net/netfilter/nf_conntrack_proto_icmpv6.c new file mode 100644 index 000000000..1020d6760 --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_icmpv6.c @@ -0,0 +1,359 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C)2003,2004 USAGI/WIDE Project + * + * Author: + * Yasuyuki Kozakai @USAGI + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +static const unsigned int nf_ct_icmpv6_timeout = 30*HZ; + +bool icmpv6_pkt_to_tuple(const struct sk_buff *skb, + unsigned int dataoff, + struct net *net, + struct nf_conntrack_tuple *tuple) +{ + const struct icmp6hdr *hp; + struct icmp6hdr _hdr; + + hp = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); + if (hp == NULL) + return false; + tuple->dst.u.icmp.type = hp->icmp6_type; + tuple->src.u.icmp.id = hp->icmp6_identifier; + tuple->dst.u.icmp.code = hp->icmp6_code; + + return true; +} + +/* Add 1; spaces filled with 0. */ +static const u_int8_t invmap[] = { + [ICMPV6_ECHO_REQUEST - 128] = ICMPV6_ECHO_REPLY + 1, + [ICMPV6_ECHO_REPLY - 128] = ICMPV6_ECHO_REQUEST + 1, + [ICMPV6_NI_QUERY - 128] = ICMPV6_NI_REPLY + 1, + [ICMPV6_NI_REPLY - 128] = ICMPV6_NI_QUERY + 1 +}; + +static const u_int8_t noct_valid_new[] = { + [ICMPV6_MGM_QUERY - 130] = 1, + [ICMPV6_MGM_REPORT - 130] = 1, + [ICMPV6_MGM_REDUCTION - 130] = 1, + [NDISC_ROUTER_SOLICITATION - 130] = 1, + [NDISC_ROUTER_ADVERTISEMENT - 130] = 1, + [NDISC_NEIGHBOUR_SOLICITATION - 130] = 1, + [NDISC_NEIGHBOUR_ADVERTISEMENT - 130] = 1, + [ICMPV6_MLD2_REPORT - 130] = 1 +}; + +bool nf_conntrack_invert_icmpv6_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple *orig) +{ + int type = orig->dst.u.icmp.type - 128; + if (type < 0 || type >= sizeof(invmap) || !invmap[type]) + return false; + + tuple->src.u.icmp.id = orig->src.u.icmp.id; + tuple->dst.u.icmp.type = invmap[type] - 1; + tuple->dst.u.icmp.code = orig->dst.u.icmp.code; + return true; +} + +static unsigned int *icmpv6_get_timeouts(struct net *net) +{ + return &nf_icmpv6_pernet(net)->timeout; +} + +/* Returns verdict for packet, or -1 for invalid. 
*/ +int nf_conntrack_icmpv6_packet(struct nf_conn *ct, + struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + unsigned int *timeout = nf_ct_timeout_lookup(ct); + static const u8 valid_new[] = { + [ICMPV6_ECHO_REQUEST - 128] = 1, + [ICMPV6_NI_QUERY - 128] = 1 + }; + + if (state->pf != NFPROTO_IPV6) + return -NF_ACCEPT; + + if (!nf_ct_is_confirmed(ct)) { + int type = ct->tuplehash[0].tuple.dst.u.icmp.type - 128; + + if (type < 0 || type >= sizeof(valid_new) || !valid_new[type]) { + /* Can't create a new ICMPv6 `conn' with this. */ + pr_debug("icmpv6: can't create new conn with type %u\n", + type + 128); + nf_ct_dump_tuple_ipv6(&ct->tuplehash[0].tuple); + return -NF_ACCEPT; + } + } + + if (!timeout) + timeout = icmpv6_get_timeouts(nf_ct_net(ct)); + + /* Do not immediately delete the connection after the first + successful reply to avoid excessive conntrackd traffic + and also to handle correctly ICMP echo reply duplicates. */ + nf_ct_refresh_acct(ct, ctinfo, skb, *timeout); + + return NF_ACCEPT; +} + + +static void icmpv6_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state, IPPROTO_ICMPV6, "%s", msg); +} + +static noinline_for_stack int +nf_conntrack_icmpv6_redirect(struct nf_conn *tmpl, struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + u8 hl = ipv6_hdr(skb)->hop_limit; + union nf_inet_addr outer_daddr; + union { + struct nd_opt_hdr nd_opt; + struct rd_msg rd_msg; + } tmp; + const struct nd_opt_hdr *nd_opt; + const struct rd_msg *rd_msg; + + rd_msg = skb_header_pointer(skb, dataoff, sizeof(*rd_msg), &tmp.rd_msg); + if (!rd_msg) { + icmpv6_error_log(skb, state, "short redirect"); + return -NF_ACCEPT; + } + + if (rd_msg->icmph.icmp6_code != 0) + return NF_ACCEPT; + + if (hl != 255 || !(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) { + icmpv6_error_log(skb, state, "invalid saddr or hoplimit for redirect"); + return -NF_ACCEPT; + } + + dataoff += sizeof(*rd_msg); + + /* warning: rd_msg no longer usable after this call */ + nd_opt = skb_header_pointer(skb, dataoff, sizeof(*nd_opt), &tmp.nd_opt); + if (!nd_opt || nd_opt->nd_opt_len == 0) { + icmpv6_error_log(skb, state, "redirect without options"); + return -NF_ACCEPT; + } + + /* We could call ndisc_parse_options(), but it would need + * skb_linearize() and a bit more work. 
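 * Instead only the first option is examined: anything other than a
 * Redirected Header option is simply accepted, otherwise the embedded
 * packet (8 bytes further in) is handed to nf_conntrack_inet_error().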
+ */ + if (nd_opt->nd_opt_type != ND_OPT_REDIRECT_HDR) + return NF_ACCEPT; + + memcpy(&outer_daddr.ip6, &ipv6_hdr(skb)->daddr, + sizeof(outer_daddr.ip6)); + dataoff += 8; + return nf_conntrack_inet_error(tmpl, skb, dataoff, state, + IPPROTO_ICMPV6, &outer_daddr); +} + +int nf_conntrack_icmpv6_error(struct nf_conn *tmpl, + struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + union nf_inet_addr outer_daddr; + const struct icmp6hdr *icmp6h; + struct icmp6hdr _ih; + int type; + + icmp6h = skb_header_pointer(skb, dataoff, sizeof(_ih), &_ih); + if (icmp6h == NULL) { + icmpv6_error_log(skb, state, "short packet"); + return -NF_ACCEPT; + } + + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_ip6_checksum(skb, state->hook, dataoff, IPPROTO_ICMPV6)) { + icmpv6_error_log(skb, state, "ICMPv6 checksum failed"); + return -NF_ACCEPT; + } + + type = icmp6h->icmp6_type - 130; + if (type >= 0 && type < sizeof(noct_valid_new) && + noct_valid_new[type]) { + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + return NF_ACCEPT; + } + + if (icmp6h->icmp6_type == NDISC_REDIRECT) + return nf_conntrack_icmpv6_redirect(tmpl, skb, dataoff, state); + + /* is not error message ? */ + if (icmp6h->icmp6_type >= 128) + return NF_ACCEPT; + + memcpy(&outer_daddr.ip6, &ipv6_hdr(skb)->daddr, + sizeof(outer_daddr.ip6)); + dataoff += sizeof(*icmp6h); + return nf_conntrack_inet_error(tmpl, skb, dataoff, state, + IPPROTO_ICMPV6, &outer_daddr); +} + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include +static int icmpv6_tuple_to_nlattr(struct sk_buff *skb, + const struct nf_conntrack_tuple *t) +{ + if (nla_put_be16(skb, CTA_PROTO_ICMPV6_ID, t->src.u.icmp.id) || + nla_put_u8(skb, CTA_PROTO_ICMPV6_TYPE, t->dst.u.icmp.type) || + nla_put_u8(skb, CTA_PROTO_ICMPV6_CODE, t->dst.u.icmp.code)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static const struct nla_policy icmpv6_nla_policy[CTA_PROTO_MAX+1] = { + [CTA_PROTO_ICMPV6_TYPE] = { .type = NLA_U8 }, + [CTA_PROTO_ICMPV6_CODE] = { .type = NLA_U8 }, + [CTA_PROTO_ICMPV6_ID] = { .type = NLA_U16 }, +}; + +static int icmpv6_nlattr_to_tuple(struct nlattr *tb[], + struct nf_conntrack_tuple *tuple, + u_int32_t flags) +{ + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_TYPE)) { + if (!tb[CTA_PROTO_ICMPV6_TYPE]) + return -EINVAL; + + tuple->dst.u.icmp.type = nla_get_u8(tb[CTA_PROTO_ICMPV6_TYPE]); + if (tuple->dst.u.icmp.type < 128 || + tuple->dst.u.icmp.type - 128 >= sizeof(invmap) || + !invmap[tuple->dst.u.icmp.type - 128]) + return -EINVAL; + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_CODE)) { + if (!tb[CTA_PROTO_ICMPV6_CODE]) + return -EINVAL; + + tuple->dst.u.icmp.code = nla_get_u8(tb[CTA_PROTO_ICMPV6_CODE]); + } + + if (flags & CTA_FILTER_FLAG(CTA_PROTO_ICMPV6_ID)) { + if (!tb[CTA_PROTO_ICMPV6_ID]) + return -EINVAL; + + tuple->src.u.icmp.id = nla_get_be16(tb[CTA_PROTO_ICMPV6_ID]); + } + + return 0; +} + +static unsigned int icmpv6_nlattr_tuple_size(void) +{ + static unsigned int size __read_mostly; + + if (!size) + size = nla_policy_len(icmpv6_nla_policy, CTA_PROTO_MAX + 1); + + return size; +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int icmpv6_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + unsigned int *timeout = data; + struct nf_icmp_net *in = nf_icmpv6_pernet(net); + + if (!timeout) + timeout = icmpv6_get_timeouts(net); + if (tb[CTA_TIMEOUT_ICMPV6_TIMEOUT]) { + *timeout = + 
ntohl(nla_get_be32(tb[CTA_TIMEOUT_ICMPV6_TIMEOUT])) * HZ; + } else { + /* Set default ICMPv6 timeout. */ + *timeout = in->timeout; + } + return 0; +} + +static int +icmpv6_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data) +{ + const unsigned int *timeout = data; + + if (nla_put_be32(skb, CTA_TIMEOUT_ICMPV6_TIMEOUT, htonl(*timeout / HZ))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy +icmpv6_timeout_nla_policy[CTA_TIMEOUT_ICMPV6_MAX+1] = { + [CTA_TIMEOUT_ICMPV6_TIMEOUT] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_icmpv6_init_net(struct net *net) +{ + struct nf_icmp_net *in = nf_icmpv6_pernet(net); + + in->timeout = nf_ct_icmpv6_timeout; +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_icmpv6 = +{ + .l4proto = IPPROTO_ICMPV6, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .tuple_to_nlattr = icmpv6_tuple_to_nlattr, + .nlattr_tuple_size = icmpv6_nlattr_tuple_size, + .nlattr_to_tuple = icmpv6_nlattr_to_tuple, + .nla_policy = icmpv6_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = icmpv6_timeout_nlattr_to_obj, + .obj_to_nlattr = icmpv6_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_ICMP_MAX, + .obj_size = sizeof(unsigned int), + .nla_policy = icmpv6_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_sctp.c b/net/netfilter/nf_conntrack_proto_sctp.c new file mode 100644 index 000000000..c94a9971d --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_sctp.c @@ -0,0 +1,743 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Connection tracking protocol helper module for SCTP. + * + * Copyright (c) 2004 Kiran Kumar Immidi + * Copyright (c) 2004-2012 Patrick McHardy + * + * SCTP is defined in RFC 2960. References to various sections in this code + * are to this RFC. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static const char *const sctp_conntrack_names[] = { + [SCTP_CONNTRACK_NONE] = "NONE", + [SCTP_CONNTRACK_CLOSED] = "CLOSED", + [SCTP_CONNTRACK_COOKIE_WAIT] = "COOKIE_WAIT", + [SCTP_CONNTRACK_COOKIE_ECHOED] = "COOKIE_ECHOED", + [SCTP_CONNTRACK_ESTABLISHED] = "ESTABLISHED", + [SCTP_CONNTRACK_SHUTDOWN_SENT] = "SHUTDOWN_SENT", + [SCTP_CONNTRACK_SHUTDOWN_RECD] = "SHUTDOWN_RECD", + [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT] = "SHUTDOWN_ACK_SENT", + [SCTP_CONNTRACK_HEARTBEAT_SENT] = "HEARTBEAT_SENT", +}; + +#define SECS * HZ +#define MINS * 60 SECS +#define HOURS * 60 MINS +#define DAYS * 24 HOURS + +static const unsigned int sctp_timeouts[SCTP_CONNTRACK_MAX] = { + [SCTP_CONNTRACK_CLOSED] = 10 SECS, + [SCTP_CONNTRACK_COOKIE_WAIT] = 3 SECS, + [SCTP_CONNTRACK_COOKIE_ECHOED] = 3 SECS, + [SCTP_CONNTRACK_ESTABLISHED] = 210 SECS, + [SCTP_CONNTRACK_SHUTDOWN_SENT] = 3 SECS, + [SCTP_CONNTRACK_SHUTDOWN_RECD] = 3 SECS, + [SCTP_CONNTRACK_SHUTDOWN_ACK_SENT] = 3 SECS, + [SCTP_CONNTRACK_HEARTBEAT_SENT] = 30 SECS, +}; + +#define SCTP_FLAG_HEARTBEAT_VTAG_FAILED 1 + +#define sNO SCTP_CONNTRACK_NONE +#define sCL SCTP_CONNTRACK_CLOSED +#define sCW SCTP_CONNTRACK_COOKIE_WAIT +#define sCE SCTP_CONNTRACK_COOKIE_ECHOED +#define sES SCTP_CONNTRACK_ESTABLISHED +#define sSS SCTP_CONNTRACK_SHUTDOWN_SENT +#define sSR SCTP_CONNTRACK_SHUTDOWN_RECD +#define sSA SCTP_CONNTRACK_SHUTDOWN_ACK_SENT +#define sHS SCTP_CONNTRACK_HEARTBEAT_SENT +#define sIV SCTP_CONNTRACK_MAX + +/* + These are the descriptions of the states: + +NOTE: These state names are tantalizingly similar to the states of an +SCTP endpoint. But the interpretation of the states is a little different, +considering that these are the states of the connection and not of an end +point. Please note the subtleties. -Kiran + +NONE - Nothing so far. +COOKIE WAIT - We have seen an INIT chunk in the original direction, or also + an INIT_ACK chunk in the reply direction. +COOKIE ECHOED - We have seen a COOKIE_ECHO chunk in the original direction. +ESTABLISHED - We have seen a COOKIE_ACK in the reply direction. +SHUTDOWN_SENT - We have seen a SHUTDOWN chunk in the original direction. +SHUTDOWN_RECD - We have seen a SHUTDOWN chunk in the reply directoin. +SHUTDOWN_ACK_SENT - We have seen a SHUTDOWN_ACK chunk in the direction opposite + to that of the SHUTDOWN chunk. +CLOSED - We have seen a SHUTDOWN_COMPLETE chunk in the direction of + the SHUTDOWN chunk. Connection is closed. +HEARTBEAT_SENT - We have seen a HEARTBEAT in a new flow. +*/ + +/* TODO + - I have assumed that the first INIT is in the original direction. + This messes things when an INIT comes in the reply direction in CLOSED + state. + - Check the error type in the reply dir before transitioning from +cookie echoed to closed. + - Sec 5.2.4 of RFC 2960 + - Full Multi Homing support. 
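 *
 * State transitions are looked up through sctp_new_state(), which maps
 * the chunk type to a row index and returns
 * sctp_conntracks[dir][row][current_state].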
+*/ + +/* SCTP conntrack state transitions */ +static const u8 sctp_conntracks[2][11][SCTP_CONNTRACK_MAX] = { + { +/* ORIGINAL */ +/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */ +/* init */ {sCL, sCL, sCW, sCE, sES, sCL, sCL, sSA, sCW}, +/* init_ack */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL}, +/* abort */ {sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL}, +/* shutdown */ {sCL, sCL, sCW, sCE, sSS, sSS, sSR, sSA, sCL}, +/* shutdown_ack */ {sSA, sCL, sCW, sCE, sES, sSA, sSA, sSA, sSA}, +/* error */ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sCL},/* Can't have Stale cookie*/ +/* cookie_echo */ {sCL, sCL, sCE, sCE, sES, sSS, sSR, sSA, sCL},/* 5.2.4 - Big TODO */ +/* cookie_ack */ {sCL, sCL, sCW, sES, sES, sSS, sSR, sSA, sCL},/* Can't come in orig dir */ +/* shutdown_comp*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sCL, sCL}, +/* heartbeat */ {sHS, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, +/* heartbeat_ack*/ {sCL, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, + }, + { +/* REPLY */ +/* sNO, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS */ +/* init */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sIV},/* INIT in sCL Big TODO */ +/* init_ack */ {sIV, sCW, sCW, sCE, sES, sSS, sSR, sSA, sIV}, +/* abort */ {sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sIV}, +/* shutdown */ {sIV, sCL, sCW, sCE, sSR, sSS, sSR, sSA, sIV}, +/* shutdown_ack */ {sIV, sCL, sCW, sCE, sES, sSA, sSA, sSA, sIV}, +/* error */ {sIV, sCL, sCW, sCL, sES, sSS, sSR, sSA, sIV}, +/* cookie_echo */ {sIV, sCL, sCE, sCE, sES, sSS, sSR, sSA, sIV},/* Can't come in reply dir */ +/* cookie_ack */ {sIV, sCL, sCW, sES, sES, sSS, sSR, sSA, sIV}, +/* shutdown_comp*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sCL, sIV}, +/* heartbeat */ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sHS}, +/* heartbeat_ack*/ {sIV, sCL, sCW, sCE, sES, sSS, sSR, sSA, sES}, + } +}; + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +/* Print out the private part of the conntrack. */ +static void sctp_print_conntrack(struct seq_file *s, struct nf_conn *ct) +{ + seq_printf(s, "%s ", sctp_conntrack_names[ct->proto.sctp.state]); +} +#endif + +#define for_each_sctp_chunk(skb, sch, _sch, offset, dataoff, count) \ +for ((offset) = (dataoff) + sizeof(struct sctphdr), (count) = 0; \ + (offset) < (skb)->len && \ + ((sch) = skb_header_pointer((skb), (offset), sizeof(_sch), &(_sch))); \ + (offset) += (ntohs((sch)->length) + 3) & ~3, (count)++) + +/* Some validity checks to make sure the chunks are fine */ +static int do_basic_checks(struct nf_conn *ct, + const struct sk_buff *skb, + unsigned int dataoff, + unsigned long *map) +{ + u_int32_t offset, count; + struct sctp_chunkhdr _sch, *sch; + int flag; + + flag = 0; + + for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) { + pr_debug("Chunk Num: %d Type: %d\n", count, sch->type); + + if (sch->type == SCTP_CID_INIT || + sch->type == SCTP_CID_INIT_ACK || + sch->type == SCTP_CID_SHUTDOWN_COMPLETE) + flag = 1; + + /* + * Cookie Ack/Echo chunks not the first OR + * Init / Init Ack / Shutdown compl chunks not the only chunks + * OR zero-length. 
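 * (for_each_sctp_chunk() advances by the chunk length rounded up to a
 * multiple of four, so a zero-length chunk would otherwise never make
 * progress.)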
+ */ + if (((sch->type == SCTP_CID_COOKIE_ACK || + sch->type == SCTP_CID_COOKIE_ECHO || + flag) && + count != 0) || !sch->length) { + pr_debug("Basic checks failed\n"); + return 1; + } + + if (map) + set_bit(sch->type, map); + } + + pr_debug("Basic checks passed\n"); + return count == 0; +} + +static int sctp_new_state(enum ip_conntrack_dir dir, + enum sctp_conntrack cur_state, + int chunk_type) +{ + int i; + + pr_debug("Chunk type: %d\n", chunk_type); + + switch (chunk_type) { + case SCTP_CID_INIT: + pr_debug("SCTP_CID_INIT\n"); + i = 0; + break; + case SCTP_CID_INIT_ACK: + pr_debug("SCTP_CID_INIT_ACK\n"); + i = 1; + break; + case SCTP_CID_ABORT: + pr_debug("SCTP_CID_ABORT\n"); + i = 2; + break; + case SCTP_CID_SHUTDOWN: + pr_debug("SCTP_CID_SHUTDOWN\n"); + i = 3; + break; + case SCTP_CID_SHUTDOWN_ACK: + pr_debug("SCTP_CID_SHUTDOWN_ACK\n"); + i = 4; + break; + case SCTP_CID_ERROR: + pr_debug("SCTP_CID_ERROR\n"); + i = 5; + break; + case SCTP_CID_COOKIE_ECHO: + pr_debug("SCTP_CID_COOKIE_ECHO\n"); + i = 6; + break; + case SCTP_CID_COOKIE_ACK: + pr_debug("SCTP_CID_COOKIE_ACK\n"); + i = 7; + break; + case SCTP_CID_SHUTDOWN_COMPLETE: + pr_debug("SCTP_CID_SHUTDOWN_COMPLETE\n"); + i = 8; + break; + case SCTP_CID_HEARTBEAT: + pr_debug("SCTP_CID_HEARTBEAT"); + i = 9; + break; + case SCTP_CID_HEARTBEAT_ACK: + pr_debug("SCTP_CID_HEARTBEAT_ACK"); + i = 10; + break; + default: + /* Other chunks like DATA or SACK do not change the state */ + pr_debug("Unknown chunk type, Will stay in %s\n", + sctp_conntrack_names[cur_state]); + return cur_state; + } + + pr_debug("dir: %d cur_state: %s chunk_type: %d new_state: %s\n", + dir, sctp_conntrack_names[cur_state], chunk_type, + sctp_conntrack_names[sctp_conntracks[dir][i][cur_state]]); + + return sctp_conntracks[dir][i][cur_state]; +} + +/* Don't need lock here: this conntrack not in circulation yet */ +static noinline bool +sctp_new(struct nf_conn *ct, const struct sk_buff *skb, + const struct sctphdr *sh, unsigned int dataoff) +{ + enum sctp_conntrack new_state; + const struct sctp_chunkhdr *sch; + struct sctp_chunkhdr _sch; + u32 offset, count; + + memset(&ct->proto.sctp, 0, sizeof(ct->proto.sctp)); + new_state = SCTP_CONNTRACK_MAX; + for_each_sctp_chunk(skb, sch, _sch, offset, dataoff, count) { + new_state = sctp_new_state(IP_CT_DIR_ORIGINAL, + SCTP_CONNTRACK_NONE, sch->type); + + /* Invalid: delete conntrack */ + if (new_state == SCTP_CONNTRACK_NONE || + new_state == SCTP_CONNTRACK_MAX) { + pr_debug("nf_conntrack_sctp: invalid new deleting.\n"); + return false; + } + + /* Copy the vtag into the state info */ + if (sch->type == SCTP_CID_INIT) { + struct sctp_inithdr _inithdr, *ih; + /* Sec 8.5.1 (A) */ + if (sh->vtag) + return false; + + ih = skb_header_pointer(skb, offset + sizeof(_sch), + sizeof(_inithdr), &_inithdr); + if (!ih) + return false; + + pr_debug("Setting vtag %x for new conn\n", + ih->init_tag); + + ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = ih->init_tag; + } else if (sch->type == SCTP_CID_HEARTBEAT) { + pr_debug("Setting vtag %x for secondary conntrack\n", + sh->vtag); + ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] = sh->vtag; + } else { + /* If it is a shutdown ack OOTB packet, we expect a return + shutdown complete, otherwise an ABORT Sec 8.4 (5) and (8) */ + pr_debug("Setting vtag %x for new conn OOTB\n", + sh->vtag); + ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = sh->vtag; + } + + ct->proto.sctp.state = SCTP_CONNTRACK_NONE; + } + + return true; +} + +static bool sctp_error(struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state 
*state) +{ + const struct sctphdr *sh; + const char *logmsg; + + if (skb->len < dataoff + sizeof(struct sctphdr)) { + logmsg = "nf_ct_sctp: short packet "; + goto out_invalid; + } + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + skb->ip_summed == CHECKSUM_NONE) { + if (skb_ensure_writable(skb, dataoff + sizeof(*sh))) { + logmsg = "nf_ct_sctp: failed to read header "; + goto out_invalid; + } + sh = (const struct sctphdr *)(skb->data + dataoff); + if (sh->checksum != sctp_compute_cksum(skb, dataoff)) { + logmsg = "nf_ct_sctp: bad CRC "; + goto out_invalid; + } + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + return false; +out_invalid: + nf_l4proto_log_invalid(skb, state, IPPROTO_SCTP, "%s", logmsg); + return true; +} + +/* Returns verdict for packet, or -NF_ACCEPT for invalid. */ +int nf_conntrack_sctp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + enum sctp_conntrack new_state, old_state; + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + const struct sctphdr *sh; + struct sctphdr _sctph; + const struct sctp_chunkhdr *sch; + struct sctp_chunkhdr _sch; + u_int32_t offset, count; + unsigned int *timeouts; + unsigned long map[256 / sizeof(unsigned long)] = { 0 }; + bool ignore = false; + + if (sctp_error(skb, dataoff, state)) + return -NF_ACCEPT; + + sh = skb_header_pointer(skb, dataoff, sizeof(_sctph), &_sctph); + if (sh == NULL) + goto out; + + if (do_basic_checks(ct, skb, dataoff, map) != 0) + goto out; + + if (!nf_ct_is_confirmed(ct)) { + /* If an OOTB packet has any of these chunks discard (Sec 8.4) */ + if (test_bit(SCTP_CID_ABORT, map) || + test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) || + test_bit(SCTP_CID_COOKIE_ACK, map)) + return -NF_ACCEPT; + + if (!sctp_new(ct, skb, sh, dataoff)) + return -NF_ACCEPT; + } + + /* Check the verification tag (Sec 8.5) */ + if (!test_bit(SCTP_CID_INIT, map) && + !test_bit(SCTP_CID_SHUTDOWN_COMPLETE, map) && + !test_bit(SCTP_CID_COOKIE_ECHO, map) && + !test_bit(SCTP_CID_ABORT, map) && + !test_bit(SCTP_CID_SHUTDOWN_ACK, map) && + !test_bit(SCTP_CID_HEARTBEAT, map) && + !test_bit(SCTP_CID_HEARTBEAT_ACK, map) && + sh->vtag != ct->proto.sctp.vtag[dir]) { + pr_debug("Verification tag check failed\n"); + goto out; + } + + old_state = new_state = SCTP_CONNTRACK_NONE; + spin_lock_bh(&ct->lock); + for_each_sctp_chunk (skb, sch, _sch, offset, dataoff, count) { + /* Special cases of Verification tag check (Sec 8.5.1) */ + if (sch->type == SCTP_CID_INIT) { + /* (A) vtag MUST be zero */ + if (sh->vtag != 0) + goto out_unlock; + } else if (sch->type == SCTP_CID_ABORT) { + /* (B) vtag MUST match own vtag if T flag is unset OR + * MUST match peer's vtag if T flag is set + */ + if ((!(sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[dir]) || + ((sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[!dir])) + goto out_unlock; + } else if (sch->type == SCTP_CID_SHUTDOWN_COMPLETE) { + /* (C) vtag MUST match own vtag if T flag is unset OR + * MUST match peer's vtag if T flag is set + */ + if ((!(sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[dir]) || + ((sch->flags & SCTP_CHUNK_FLAG_T) && + sh->vtag != ct->proto.sctp.vtag[!dir])) + goto out_unlock; + } else if (sch->type == SCTP_CID_COOKIE_ECHO) { + /* (D) vtag must be same as init_vtag as found in INIT_ACK */ + if (sh->vtag != ct->proto.sctp.vtag[dir]) + goto out_unlock; + } else if (sch->type == SCTP_CID_COOKIE_ACK) { + 
ct->proto.sctp.init[dir] = 0; + ct->proto.sctp.init[!dir] = 0; + } else if (sch->type == SCTP_CID_HEARTBEAT) { + if (ct->proto.sctp.vtag[dir] == 0) { + pr_debug("Setting %d vtag %x for dir %d\n", sch->type, sh->vtag, dir); + ct->proto.sctp.vtag[dir] = sh->vtag; + } else if (sh->vtag != ct->proto.sctp.vtag[dir]) { + if (test_bit(SCTP_CID_DATA, map) || ignore) + goto out_unlock; + + ct->proto.sctp.flags |= SCTP_FLAG_HEARTBEAT_VTAG_FAILED; + ct->proto.sctp.last_dir = dir; + ignore = true; + continue; + } else if (ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) { + ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED; + } + } else if (sch->type == SCTP_CID_HEARTBEAT_ACK) { + if (ct->proto.sctp.vtag[dir] == 0) { + pr_debug("Setting vtag %x for dir %d\n", + sh->vtag, dir); + ct->proto.sctp.vtag[dir] = sh->vtag; + } else if (sh->vtag != ct->proto.sctp.vtag[dir]) { + if (test_bit(SCTP_CID_DATA, map) || ignore) + goto out_unlock; + + if ((ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) == 0 || + ct->proto.sctp.last_dir == dir) + goto out_unlock; + + ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED; + ct->proto.sctp.vtag[dir] = sh->vtag; + ct->proto.sctp.vtag[!dir] = 0; + } else if (ct->proto.sctp.flags & SCTP_FLAG_HEARTBEAT_VTAG_FAILED) { + ct->proto.sctp.flags &= ~SCTP_FLAG_HEARTBEAT_VTAG_FAILED; + } + } + + old_state = ct->proto.sctp.state; + new_state = sctp_new_state(dir, old_state, sch->type); + + /* Invalid */ + if (new_state == SCTP_CONNTRACK_MAX) { + pr_debug("nf_conntrack_sctp: Invalid dir=%i ctype=%u " + "conntrack=%u\n", + dir, sch->type, old_state); + goto out_unlock; + } + + /* If it is an INIT or an INIT ACK note down the vtag */ + if (sch->type == SCTP_CID_INIT) { + struct sctp_inithdr _ih, *ih; + + ih = skb_header_pointer(skb, offset + sizeof(_sch), sizeof(*ih), &_ih); + if (!ih) + goto out_unlock; + + if (ct->proto.sctp.init[dir] && ct->proto.sctp.init[!dir]) + ct->proto.sctp.init[!dir] = 0; + ct->proto.sctp.init[dir] = 1; + + pr_debug("Setting vtag %x for dir %d\n", ih->init_tag, !dir); + ct->proto.sctp.vtag[!dir] = ih->init_tag; + + /* don't renew timeout on init retransmit so + * port reuse by client or NAT middlebox cannot + * keep entry alive indefinitely (incl. nat info). 
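
/*
 * Illustrative sketch (not from this patch): the ABORT/SHUTDOWN_COMPLETE
 * verification-tag rule used in the Sec 8.5.1 special cases above, as a
 * standalone predicate.  FLAG_T and the parameter names are invented for
 * the example; the real code compares against ct->proto.sctp.vtag[dir]
 * and vtag[!dir].
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define FLAG_T 0x01	/* reflected-vtag bit, assumption for the sketch */

static bool abort_vtag_ok(uint32_t pkt_vtag, uint8_t chunk_flags,
			  uint32_t own_vtag, uint32_t peer_vtag)
{
	/* T clear: tag must match the tag expected in this direction;
	 * T set:   the sender reflected the peer's tag instead.
	 */
	if (!(chunk_flags & FLAG_T))
		return pkt_vtag == own_vtag;
	return pkt_vtag == peer_vtag;
}

int main(void)
{
	/* prints: 1 0 1 */
	printf("%d ", abort_vtag_ok(0x1111, 0, 0x1111, 0x2222));
	printf("%d ", abort_vtag_ok(0x2222, 0, 0x1111, 0x2222));
	printf("%d\n", abort_vtag_ok(0x2222, FLAG_T, 0x1111, 0x2222));
	return 0;
}
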
+ */ + if (new_state == SCTP_CONNTRACK_CLOSED && + old_state == SCTP_CONNTRACK_CLOSED && + nf_ct_is_confirmed(ct)) + ignore = true; + } else if (sch->type == SCTP_CID_INIT_ACK) { + struct sctp_inithdr _ih, *ih; + __be32 vtag; + + ih = skb_header_pointer(skb, offset + sizeof(_sch), sizeof(*ih), &_ih); + if (!ih) + goto out_unlock; + + vtag = ct->proto.sctp.vtag[!dir]; + if (!ct->proto.sctp.init[!dir] && vtag && vtag != ih->init_tag) + goto out_unlock; + /* collision */ + if (ct->proto.sctp.init[dir] && ct->proto.sctp.init[!dir] && + vtag != ih->init_tag) + goto out_unlock; + + pr_debug("Setting vtag %x for dir %d\n", ih->init_tag, !dir); + ct->proto.sctp.vtag[!dir] = ih->init_tag; + } + + ct->proto.sctp.state = new_state; + if (old_state != new_state) { + nf_conntrack_event_cache(IPCT_PROTOINFO, ct); + if (new_state == SCTP_CONNTRACK_ESTABLISHED && + !test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } + } + spin_unlock_bh(&ct->lock); + + /* allow but do not refresh timeout */ + if (ignore) + return NF_ACCEPT; + + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = nf_sctp_pernet(nf_ct_net(ct))->timeouts; + + nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[new_state]); + + return NF_ACCEPT; + +out_unlock: + spin_unlock_bh(&ct->lock); +out: + return -NF_ACCEPT; +} + +static bool sctp_can_early_drop(const struct nf_conn *ct) +{ + switch (ct->proto.sctp.state) { + case SCTP_CONNTRACK_SHUTDOWN_SENT: + case SCTP_CONNTRACK_SHUTDOWN_RECD: + case SCTP_CONNTRACK_SHUTDOWN_ACK_SENT: + return true; + default: + break; + } + + return false; +} + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include + +static int sctp_to_nlattr(struct sk_buff *skb, struct nlattr *nla, + struct nf_conn *ct, bool destroy) +{ + struct nlattr *nest_parms; + + spin_lock_bh(&ct->lock); + nest_parms = nla_nest_start(skb, CTA_PROTOINFO_SCTP); + if (!nest_parms) + goto nla_put_failure; + + if (nla_put_u8(skb, CTA_PROTOINFO_SCTP_STATE, ct->proto.sctp.state)) + goto nla_put_failure; + + if (destroy) + goto skip_state; + + if (nla_put_be32(skb, CTA_PROTOINFO_SCTP_VTAG_ORIGINAL, + ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL]) || + nla_put_be32(skb, CTA_PROTOINFO_SCTP_VTAG_REPLY, + ct->proto.sctp.vtag[IP_CT_DIR_REPLY])) + goto nla_put_failure; + +skip_state: + spin_unlock_bh(&ct->lock); + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + spin_unlock_bh(&ct->lock); + return -1; +} + +static const struct nla_policy sctp_nla_policy[CTA_PROTOINFO_SCTP_MAX+1] = { + [CTA_PROTOINFO_SCTP_STATE] = { .type = NLA_U8 }, + [CTA_PROTOINFO_SCTP_VTAG_ORIGINAL] = { .type = NLA_U32 }, + [CTA_PROTOINFO_SCTP_VTAG_REPLY] = { .type = NLA_U32 }, +}; + +#define SCTP_NLATTR_SIZE ( \ + NLA_ALIGN(NLA_HDRLEN + 1) + \ + NLA_ALIGN(NLA_HDRLEN + 4) + \ + NLA_ALIGN(NLA_HDRLEN + 4)) + +static int nlattr_to_sctp(struct nlattr *cda[], struct nf_conn *ct) +{ + struct nlattr *attr = cda[CTA_PROTOINFO_SCTP]; + struct nlattr *tb[CTA_PROTOINFO_SCTP_MAX+1]; + int err; + + /* updates may not contain the internal protocol info, skip parsing */ + if (!attr) + return 0; + + err = nla_parse_nested_deprecated(tb, CTA_PROTOINFO_SCTP_MAX, attr, + sctp_nla_policy, NULL); + if (err < 0) + return err; + + if (!tb[CTA_PROTOINFO_SCTP_STATE] || + !tb[CTA_PROTOINFO_SCTP_VTAG_ORIGINAL] || + !tb[CTA_PROTOINFO_SCTP_VTAG_REPLY]) + return -EINVAL; + + spin_lock_bh(&ct->lock); + ct->proto.sctp.state = nla_get_u8(tb[CTA_PROTOINFO_SCTP_STATE]); + ct->proto.sctp.vtag[IP_CT_DIR_ORIGINAL] = + 
nla_get_be32(tb[CTA_PROTOINFO_SCTP_VTAG_ORIGINAL]); + ct->proto.sctp.vtag[IP_CT_DIR_REPLY] = + nla_get_be32(tb[CTA_PROTOINFO_SCTP_VTAG_REPLY]); + spin_unlock_bh(&ct->lock); + + return 0; +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int sctp_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + unsigned int *timeouts = data; + struct nf_sctp_net *sn = nf_sctp_pernet(net); + int i; + + if (!timeouts) + timeouts = sn->timeouts; + + /* set default SCTP timeouts. */ + for (i=0; itimeouts[i]; + + /* there's a 1:1 mapping between attributes and protocol states. */ + for (i=CTA_TIMEOUT_SCTP_UNSPEC+1; itimeouts[i] = sctp_timeouts[i]; + + /* timeouts[0] is unused, init it so ->timeouts[0] contains + * 'new' timeout, like udp or icmp. + */ + sn->timeouts[0] = sctp_timeouts[SCTP_CONNTRACK_CLOSED]; +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_sctp = { + .l4proto = IPPROTO_SCTP, +#ifdef CONFIG_NF_CONNTRACK_PROCFS + .print_conntrack = sctp_print_conntrack, +#endif + .can_early_drop = sctp_can_early_drop, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .nlattr_size = SCTP_NLATTR_SIZE, + .to_nlattr = sctp_to_nlattr, + .from_nlattr = nlattr_to_sctp, + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = sctp_timeout_nlattr_to_obj, + .obj_to_nlattr = sctp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_SCTP_MAX, + .obj_size = sizeof(unsigned int) * SCTP_CONNTRACK_MAX, + .nla_policy = sctp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_tcp.c b/net/netfilter/nf_conntrack_proto_tcp.c new file mode 100644 index 000000000..3ac1af6f5 --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_tcp.c @@ -0,0 +1,1616 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + * (C) 2002-2013 Jozsef Kadlecsik + * (C) 2006-2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + /* FIXME: Examine ipfilter's timeouts and conntrack transitions more + closely. They're more complex. --RR */ + +static const char *const tcp_conntrack_names[] = { + "NONE", + "SYN_SENT", + "SYN_RECV", + "ESTABLISHED", + "FIN_WAIT", + "CLOSE_WAIT", + "LAST_ACK", + "TIME_WAIT", + "CLOSE", + "SYN_SENT2", +}; + +enum nf_ct_tcp_action { + NFCT_TCP_IGNORE, + NFCT_TCP_INVALID, + NFCT_TCP_ACCEPT, +}; + +#define SECS * HZ +#define MINS * 60 SECS +#define HOURS * 60 MINS +#define DAYS * 24 HOURS + +static const unsigned int tcp_timeouts[TCP_CONNTRACK_TIMEOUT_MAX] = { + [TCP_CONNTRACK_SYN_SENT] = 2 MINS, + [TCP_CONNTRACK_SYN_RECV] = 60 SECS, + [TCP_CONNTRACK_ESTABLISHED] = 5 DAYS, + [TCP_CONNTRACK_FIN_WAIT] = 2 MINS, + [TCP_CONNTRACK_CLOSE_WAIT] = 60 SECS, + [TCP_CONNTRACK_LAST_ACK] = 30 SECS, + [TCP_CONNTRACK_TIME_WAIT] = 2 MINS, + [TCP_CONNTRACK_CLOSE] = 10 SECS, + [TCP_CONNTRACK_SYN_SENT2] = 2 MINS, +/* RFC1122 says the R2 limit should be at least 100 seconds. + Linux uses 15 packets as limit, which corresponds + to ~13-30min depending on RTO. 
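
/*
 * Illustrative sketch (not from this patch): the *_timeout_nlattr_to_obj
 * helpers in both trackers follow one pattern -- copy the per-state
 * defaults, then overwrite only the states for which the user supplied a
 * netlink attribute, converting seconds to jiffies with "* HZ".  All names
 * here are invented and HZ is assumed to be 250 for the example.
 */
#include <stdio.h>

#define HZ 250			/* assumption for this sketch only */
enum { ST_CLOSED, ST_WAIT, ST_OPEN, ST_MAX };

static const unsigned int defaults[ST_MAX] = {
	[ST_CLOSED] = 10 * HZ,
	[ST_WAIT]   = 3 * HZ,
	[ST_OPEN]   = 210 * HZ,
};

static void fill_timeouts(unsigned int *timeouts,
			  const unsigned int *override_secs /* 0 = unset */)
{
	int i;

	for (i = 0; i < ST_MAX; i++)
		timeouts[i] = defaults[i];

	for (i = 0; i < ST_MAX; i++)
		if (override_secs[i])
			timeouts[i] = override_secs[i] * HZ;
}

int main(void)
{
	unsigned int t[ST_MAX];
	const unsigned int override_secs[ST_MAX] = { [ST_OPEN] = 600 };

	fill_timeouts(t, override_secs);
	/* prints: 2500 750 150000 */
	printf("%u %u %u\n", t[ST_CLOSED], t[ST_WAIT], t[ST_OPEN]);
	return 0;
}
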
*/ + [TCP_CONNTRACK_RETRANS] = 5 MINS, + [TCP_CONNTRACK_UNACK] = 5 MINS, +}; + +#define sNO TCP_CONNTRACK_NONE +#define sSS TCP_CONNTRACK_SYN_SENT +#define sSR TCP_CONNTRACK_SYN_RECV +#define sES TCP_CONNTRACK_ESTABLISHED +#define sFW TCP_CONNTRACK_FIN_WAIT +#define sCW TCP_CONNTRACK_CLOSE_WAIT +#define sLA TCP_CONNTRACK_LAST_ACK +#define sTW TCP_CONNTRACK_TIME_WAIT +#define sCL TCP_CONNTRACK_CLOSE +#define sS2 TCP_CONNTRACK_SYN_SENT2 +#define sIV TCP_CONNTRACK_MAX +#define sIG TCP_CONNTRACK_IGNORE + +/* What TCP flags are set from RST/SYN/FIN/ACK. */ +enum tcp_bit_set { + TCP_SYN_SET, + TCP_SYNACK_SET, + TCP_FIN_SET, + TCP_ACK_SET, + TCP_RST_SET, + TCP_NONE_SET, +}; + +/* + * The TCP state transition table needs a few words... + * + * We are the man in the middle. All the packets go through us + * but might get lost in transit to the destination. + * It is assumed that the destinations can't receive segments + * we haven't seen. + * + * The checked segment is in window, but our windows are *not* + * equivalent with the ones of the sender/receiver. We always + * try to guess the state of the current sender. + * + * The meaning of the states are: + * + * NONE: initial state + * SYN_SENT: SYN-only packet seen + * SYN_SENT2: SYN-only packet seen from reply dir, simultaneous open + * SYN_RECV: SYN-ACK packet seen + * ESTABLISHED: ACK packet seen + * FIN_WAIT: FIN packet seen + * CLOSE_WAIT: ACK seen (after FIN) + * LAST_ACK: FIN seen (after FIN) + * TIME_WAIT: last ACK seen + * CLOSE: closed connection (RST) + * + * Packets marked as IGNORED (sIG): + * if they may be either invalid or valid + * and the receiver may send back a connection + * closing RST or a SYN/ACK. + * + * Packets marked as INVALID (sIV): + * if we regard them as truly invalid packets + */ +static const u8 tcp_conntracks[2][6][TCP_CONNTRACK_MAX] = { + { +/* ORIGINAL */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*syn*/ { sSS, sSS, sIG, sIG, sIG, sIG, sIG, sSS, sSS, sS2 }, +/* + * sNO -> sSS Initialize a new connection + * sSS -> sSS Retransmitted SYN + * sS2 -> sS2 Late retransmitted SYN + * sSR -> sIG + * sES -> sIG Error: SYNs in window outside the SYN_SENT state + * are errors. Receiver will reply with RST + * and close the connection. + * Or we are not in sync and hold a dead connection. + * sFW -> sIG + * sCW -> sIG + * sLA -> sIG + * sTW -> sSS Reopened connection (RFC 1122). + * sCL -> sSS + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*synack*/ { sIV, sIV, sSR, sIV, sIV, sIV, sIV, sIV, sIV, sSR }, +/* + * sNO -> sIV Too late and no reason to do anything + * sSS -> sIV Client can't send SYN and then SYN/ACK + * sS2 -> sSR SYN/ACK sent to SYN2 in simultaneous open + * sSR -> sSR Late retransmitted SYN/ACK in simultaneous open + * sES -> sIV Invalid SYN/ACK packets sent by the client + * sFW -> sIV + * sCW -> sIV + * sLA -> sIV + * sTW -> sIV + * sCL -> sIV + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*fin*/ { sIV, sIV, sFW, sFW, sLA, sLA, sLA, sTW, sCL, sIV }, +/* + * sNO -> sIV Too late and no reason to do anything... + * sSS -> sIV Client migth not send FIN in this state: + * we enforce waiting for a SYN/ACK reply first. + * sS2 -> sIV + * sSR -> sFW Close started. + * sES -> sFW + * sFW -> sLA FIN seen in both directions, waiting for + * the last ACK. + * Migth be a retransmitted FIN as well... + * sCW -> sLA + * sLA -> sLA Retransmitted FIN. Remain in the same state. 
+ * sTW -> sTW + * sCL -> sCL + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*ack*/ { sES, sIV, sES, sES, sCW, sCW, sTW, sTW, sCL, sIV }, +/* + * sNO -> sES Assumed. + * sSS -> sIV ACK is invalid: we haven't seen a SYN/ACK yet. + * sS2 -> sIV + * sSR -> sES Established state is reached. + * sES -> sES :-) + * sFW -> sCW Normal close request answered by ACK. + * sCW -> sCW + * sLA -> sTW Last ACK detected (RFC5961 challenged) + * sTW -> sTW Retransmitted last ACK. Remain in the same state. + * sCL -> sCL + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*rst*/ { sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL }, +/*none*/ { sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV } + }, + { +/* REPLY */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*syn*/ { sIV, sS2, sIV, sIV, sIV, sIV, sIV, sSS, sIV, sS2 }, +/* + * sNO -> sIV Never reached. + * sSS -> sS2 Simultaneous open + * sS2 -> sS2 Retransmitted simultaneous SYN + * sSR -> sIV Invalid SYN packets sent by the server + * sES -> sIV + * sFW -> sIV + * sCW -> sIV + * sLA -> sIV + * sTW -> sSS Reopened connection, but server may have switched role + * sCL -> sIV + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*synack*/ { sIV, sSR, sIG, sIG, sIG, sIG, sIG, sIG, sIG, sSR }, +/* + * sSS -> sSR Standard open. + * sS2 -> sSR Simultaneous open + * sSR -> sIG Retransmitted SYN/ACK, ignore it. + * sES -> sIG Late retransmitted SYN/ACK? + * sFW -> sIG Might be SYN/ACK answering ignored SYN + * sCW -> sIG + * sLA -> sIG + * sTW -> sIG + * sCL -> sIG + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*fin*/ { sIV, sIV, sFW, sFW, sLA, sLA, sLA, sTW, sCL, sIV }, +/* + * sSS -> sIV Server might not send FIN in this state. + * sS2 -> sIV + * sSR -> sFW Close started. + * sES -> sFW + * sFW -> sLA FIN seen in both directions. + * sCW -> sLA + * sLA -> sLA Retransmitted FIN. + * sTW -> sTW + * sCL -> sCL + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*ack*/ { sIV, sIG, sSR, sES, sCW, sCW, sTW, sTW, sCL, sIG }, +/* + * sSS -> sIG Might be a half-open connection. + * sS2 -> sIG + * sSR -> sSR Might answer late resent SYN. + * sES -> sES :-) + * sFW -> sCW Normal close request answered by ACK. + * sCW -> sCW + * sLA -> sTW Last ACK detected (RFC5961 challenged) + * sTW -> sTW Retransmitted last ACK. + * sCL -> sCL + */ +/* sNO, sSS, sSR, sES, sFW, sCW, sLA, sTW, sCL, sS2 */ +/*rst*/ { sIV, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL, sCL }, +/*none*/ { sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV, sIV } + } +}; + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +/* Print out the private part of the conntrack. */ +static void tcp_print_conntrack(struct seq_file *s, struct nf_conn *ct) +{ + if (test_bit(IPS_OFFLOAD_BIT, &ct->status)) + return; + + seq_printf(s, "%s ", tcp_conntrack_names[ct->proto.tcp.state]); +} +#endif + +static unsigned int get_conntrack_index(const struct tcphdr *tcph) +{ + if (tcph->rst) return TCP_RST_SET; + else if (tcph->syn) return (tcph->ack ? TCP_SYNACK_SET : TCP_SYN_SET); + else if (tcph->fin) return TCP_FIN_SET; + else if (tcph->ack) return TCP_ACK_SET; + else return TCP_NONE_SET; +} + +/* TCP connection tracking based on 'Real Stateful TCP Packet Filtering + in IP Filter' by Guido van Rooij. + + http://www.sane.nl/events/sane2000/papers.html + http://www.darkart.com/mirrors/www.obfuscation.org/ipf/ + + The boundaries and the conditions are changed according to RFC793: + the packet must intersect the window (i.e. 
segments may be + after the right or before the left edge) and thus receivers may ACK + segments after the right edge of the window. + + td_maxend = max(sack + max(win,1)) seen in reply packets + td_maxwin = max(max(win, 1)) + (sack - ack) seen in sent packets + td_maxwin += seq + len - sender.td_maxend + if seq + len > sender.td_maxend + td_end = max(seq + len) seen in sent packets + + I. Upper bound for valid data: seq <= sender.td_maxend + II. Lower bound for valid data: seq + len >= sender.td_end - receiver.td_maxwin + III. Upper bound for valid (s)ack: sack <= receiver.td_end + IV. Lower bound for valid (s)ack: sack >= receiver.td_end - MAXACKWINDOW + + where sack is the highest right edge of sack block found in the packet + or ack in the case of packet without SACK option. + + The upper bound limit for a valid (s)ack is not ignored - + we doesn't have to deal with fragments. +*/ + +static inline __u32 segment_seq_plus_len(__u32 seq, + size_t len, + unsigned int dataoff, + const struct tcphdr *tcph) +{ + /* XXX Should I use payload length field in IP/IPv6 header ? + * - YK */ + return (seq + len - dataoff - tcph->doff*4 + + (tcph->syn ? 1 : 0) + (tcph->fin ? 1 : 0)); +} + +/* Fixme: what about big packets? */ +#define MAXACKWINCONST 66000 +#define MAXACKWINDOW(sender) \ + ((sender)->td_maxwin > MAXACKWINCONST ? (sender)->td_maxwin \ + : MAXACKWINCONST) + +/* + * Simplified tcp_parse_options routine from tcp_input.c + */ +static void tcp_options(const struct sk_buff *skb, + unsigned int dataoff, + const struct tcphdr *tcph, + struct ip_ct_tcp_state *state) +{ + unsigned char buff[(15 * 4) - sizeof(struct tcphdr)]; + const unsigned char *ptr; + int length = (tcph->doff*4) - sizeof(struct tcphdr); + + if (!length) + return; + + ptr = skb_header_pointer(skb, dataoff + sizeof(struct tcphdr), + length, buff); + if (!ptr) + return; + + state->td_scale = 0; + state->flags &= IP_CT_TCP_FLAG_BE_LIBERAL; + + while (length > 0) { + int opcode=*ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return; + case TCPOPT_NOP: /* Ref: RFC 793 section 3.1 */ + length--; + continue; + default: + if (length < 2) + return; + opsize=*ptr++; + if (opsize < 2) /* "silly options" */ + return; + if (opsize > length) + return; /* don't parse partial options */ + + if (opcode == TCPOPT_SACK_PERM + && opsize == TCPOLEN_SACK_PERM) + state->flags |= IP_CT_TCP_FLAG_SACK_PERM; + else if (opcode == TCPOPT_WINDOW + && opsize == TCPOLEN_WINDOW) { + state->td_scale = *(u_int8_t *)ptr; + + if (state->td_scale > TCP_MAX_WSCALE) + state->td_scale = TCP_MAX_WSCALE; + + state->flags |= + IP_CT_TCP_FLAG_WINDOW_SCALE; + } + ptr += opsize - 2; + length -= opsize; + } + } +} + +static void tcp_sack(const struct sk_buff *skb, unsigned int dataoff, + const struct tcphdr *tcph, __u32 *sack) +{ + unsigned char buff[(15 * 4) - sizeof(struct tcphdr)]; + const unsigned char *ptr; + int length = (tcph->doff*4) - sizeof(struct tcphdr); + __u32 tmp; + + if (!length) + return; + + ptr = skb_header_pointer(skb, dataoff + sizeof(struct tcphdr), + length, buff); + if (!ptr) + return; + + /* Fast path for timestamp-only option */ + if (length == TCPOLEN_TSTAMP_ALIGNED + && *(__be32 *)ptr == htonl((TCPOPT_NOP << 24) + | (TCPOPT_NOP << 16) + | (TCPOPT_TIMESTAMP << 8) + | TCPOLEN_TIMESTAMP)) + return; + + while (length > 0) { + int opcode = *ptr++; + int opsize, i; + + switch (opcode) { + case TCPOPT_EOL: + return; + case TCPOPT_NOP: /* Ref: RFC 793 section 3.1 */ + length--; + continue; + default: + if (length < 2) + return; + opsize = 
*ptr++; + if (opsize < 2) /* "silly options" */ + return; + if (opsize > length) + return; /* don't parse partial options */ + + if (opcode == TCPOPT_SACK + && opsize >= (TCPOLEN_SACK_BASE + + TCPOLEN_SACK_PERBLOCK) + && !((opsize - TCPOLEN_SACK_BASE) + % TCPOLEN_SACK_PERBLOCK)) { + for (i = 0; + i < (opsize - TCPOLEN_SACK_BASE); + i += TCPOLEN_SACK_PERBLOCK) { + tmp = get_unaligned_be32((__be32 *)(ptr+i)+1); + + if (after(tmp, *sack)) + *sack = tmp; + } + return; + } + ptr += opsize - 2; + length -= opsize; + } + } +} + +static void tcp_init_sender(struct ip_ct_tcp_state *sender, + struct ip_ct_tcp_state *receiver, + const struct sk_buff *skb, + unsigned int dataoff, + const struct tcphdr *tcph, + u32 end, u32 win) +{ + /* SYN-ACK in reply to a SYN + * or SYN from reply direction in simultaneous open. + */ + sender->td_end = + sender->td_maxend = end; + sender->td_maxwin = (win == 0 ? 1 : win); + + tcp_options(skb, dataoff, tcph, sender); + /* RFC 1323: + * Both sides must send the Window Scale option + * to enable window scaling in either direction. + */ + if (!(sender->flags & IP_CT_TCP_FLAG_WINDOW_SCALE && + receiver->flags & IP_CT_TCP_FLAG_WINDOW_SCALE)) { + sender->td_scale = 0; + receiver->td_scale = 0; + } +} + +__printf(6, 7) +static enum nf_ct_tcp_action nf_tcp_log_invalid(const struct sk_buff *skb, + const struct nf_conn *ct, + const struct nf_hook_state *state, + const struct ip_ct_tcp_state *sender, + enum nf_ct_tcp_action ret, + const char *fmt, ...) +{ + const struct nf_tcp_net *tn = nf_tcp_pernet(nf_ct_net(ct)); + struct va_format vaf; + va_list args; + bool be_liberal; + + be_liberal = sender->flags & IP_CT_TCP_FLAG_BE_LIBERAL || tn->tcp_be_liberal; + if (be_liberal) + return NFCT_TCP_ACCEPT; + + va_start(args, fmt); + vaf.fmt = fmt; + vaf.va = &args; + nf_ct_l4proto_log_invalid(skb, ct, state, "%pV", &vaf); + va_end(args); + + return ret; +} + +static enum nf_ct_tcp_action +tcp_in_window(struct nf_conn *ct, enum ip_conntrack_dir dir, + unsigned int index, const struct sk_buff *skb, + unsigned int dataoff, const struct tcphdr *tcph, + const struct nf_hook_state *hook_state) +{ + struct ip_ct_tcp *state = &ct->proto.tcp; + struct ip_ct_tcp_state *sender = &state->seen[dir]; + struct ip_ct_tcp_state *receiver = &state->seen[!dir]; + __u32 seq, ack, sack, end, win, swin; + bool in_recv_win, seq_ok; + s32 receiver_offset; + u16 win_raw; + + /* + * Get the required data from the packet. + */ + seq = ntohl(tcph->seq); + ack = sack = ntohl(tcph->ack_seq); + win_raw = ntohs(tcph->window); + win = win_raw; + end = segment_seq_plus_len(seq, skb->len, dataoff, tcph); + + if (receiver->flags & IP_CT_TCP_FLAG_SACK_PERM) + tcp_sack(skb, dataoff, tcph, &sack); + + /* Take into account NAT sequence number mangling */ + receiver_offset = nf_ct_seq_offset(ct, !dir, ack - 1); + ack -= receiver_offset; + sack -= receiver_offset; + + if (sender->td_maxwin == 0) { + /* + * Initialize sender data. + */ + if (tcph->syn) { + tcp_init_sender(sender, receiver, + skb, dataoff, tcph, + end, win); + if (!tcph->ack) + /* Simultaneous open */ + return NFCT_TCP_ACCEPT; + } else { + /* + * We are in the middle of a connection, + * its history is lost for us. + * Let's try to use the data from the packet. + */ + sender->td_end = end; + swin = win << sender->td_scale; + sender->td_maxwin = (swin == 0 ? 
1 : swin); + sender->td_maxend = end + sender->td_maxwin; + if (receiver->td_maxwin == 0) { + /* We haven't seen traffic in the other + * direction yet but we have to tweak window + * tracking to pass III and IV until that + * happens. + */ + receiver->td_end = receiver->td_maxend = sack; + } else if (sack == receiver->td_end + 1) { + /* Likely a reply to a keepalive. + * Needed for III. + */ + receiver->td_end++; + } + + } + } else if (tcph->syn && + after(end, sender->td_end) && + (state->state == TCP_CONNTRACK_SYN_SENT || + state->state == TCP_CONNTRACK_SYN_RECV)) { + /* + * RFC 793: "if a TCP is reinitialized ... then it need + * not wait at all; it must only be sure to use sequence + * numbers larger than those recently used." + * + * Re-init state for this direction, just like for the first + * syn(-ack) reply, it might differ in seq, ack or tcp options. + */ + tcp_init_sender(sender, receiver, + skb, dataoff, tcph, + end, win); + + if (dir == IP_CT_DIR_REPLY && !tcph->ack) + return NFCT_TCP_ACCEPT; + } + + if (!(tcph->ack)) { + /* + * If there is no ACK, just pretend it was set and OK. + */ + ack = sack = receiver->td_end; + } else if (((tcp_flag_word(tcph) & (TCP_FLAG_ACK|TCP_FLAG_RST)) == + (TCP_FLAG_ACK|TCP_FLAG_RST)) + && (ack == 0)) { + /* + * Broken TCP stacks, that set ACK in RST packets as well + * with zero ack value. + */ + ack = sack = receiver->td_end; + } + + if (tcph->rst && seq == 0 && state->state == TCP_CONNTRACK_SYN_SENT) + /* + * RST sent answering SYN. + */ + seq = end = sender->td_end; + + seq_ok = before(seq, sender->td_maxend + 1); + if (!seq_ok) { + u32 overshot = end - sender->td_maxend + 1; + bool ack_ok; + + ack_ok = after(sack, receiver->td_end - MAXACKWINDOW(sender) - 1); + in_recv_win = receiver->td_maxwin && + after(end, sender->td_end - receiver->td_maxwin - 1); + + if (in_recv_win && + ack_ok && + overshot <= receiver->td_maxwin && + before(sack, receiver->td_end + 1)) { + /* Work around TCPs that send more bytes than allowed by + * the receive window. + * + * If the (marked as invalid) packet is allowed to pass by + * the ruleset and the peer acks this data, then its possible + * all future packets will trigger 'ACK is over upper bound' check. + * + * Thus if only the sequence check fails then do update td_end so + * possible ACK for this data can update internal state. + */ + sender->td_end = end; + sender->flags |= IP_CT_TCP_FLAG_DATA_UNACKNOWLEDGED; + + return nf_tcp_log_invalid(skb, ct, hook_state, sender, NFCT_TCP_IGNORE, + "%u bytes more than expected", overshot); + } + + return nf_tcp_log_invalid(skb, ct, hook_state, sender, NFCT_TCP_INVALID, + "SEQ is over upper bound %u (over the window of the receiver)", + sender->td_maxend + 1); + } + + if (!before(sack, receiver->td_end + 1)) + return nf_tcp_log_invalid(skb, ct, hook_state, sender, NFCT_TCP_INVALID, + "ACK is over upper bound %u (ACKed data not seen yet)", + receiver->td_end + 1); + + /* Is the ending sequence in the receive window (if available)? 
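
/*
 * Illustrative sketch (not from this patch): the bound checks in
 * tcp_in_window rely on wrap-around safe sequence comparison --
 * before(a, b) is true when (s32)(a - b) < 0, so the test still works
 * after the 32-bit sequence space wraps.  Standalone helpers plus bound I
 * from the comment block above ("seq <= sender.td_maxend"):
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool seq_before(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;
}

static bool seq_after(uint32_t a, uint32_t b)
{
	return seq_before(b, a);
}

/* Bound I: the segment's sequence must not lie past td_maxend. */
static bool seq_in_upper_bound(uint32_t seq, uint32_t td_maxend)
{
	return seq_before(seq, td_maxend + 1);
}

int main(void)
{
	/* wrap-around: 0x00000010 is "after" 0xfffffff0 */
	printf("%d\n", seq_after(0x00000010u, 0xfffffff0u));	/* 1 */
	printf("%d\n", seq_in_upper_bound(1000, 1000));		/* 1 */
	printf("%d\n", seq_in_upper_bound(1001, 1000));		/* 0 */
	return 0;
}
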
*/ + in_recv_win = !receiver->td_maxwin || + after(end, sender->td_end - receiver->td_maxwin - 1); + if (!in_recv_win) + return nf_tcp_log_invalid(skb, ct, hook_state, sender, NFCT_TCP_IGNORE, + "SEQ is under lower bound %u (already ACKed data retransmitted)", + sender->td_end - receiver->td_maxwin - 1); + if (!after(sack, receiver->td_end - MAXACKWINDOW(sender) - 1)) + return nf_tcp_log_invalid(skb, ct, hook_state, sender, NFCT_TCP_IGNORE, + "ignored ACK under lower bound %u (possible overly delayed)", + receiver->td_end - MAXACKWINDOW(sender) - 1); + + /* Take into account window scaling (RFC 1323). */ + if (!tcph->syn) + win <<= sender->td_scale; + + /* Update sender data. */ + swin = win + (sack - ack); + if (sender->td_maxwin < swin) + sender->td_maxwin = swin; + if (after(end, sender->td_end)) { + sender->td_end = end; + sender->flags |= IP_CT_TCP_FLAG_DATA_UNACKNOWLEDGED; + } + if (tcph->ack) { + if (!(sender->flags & IP_CT_TCP_FLAG_MAXACK_SET)) { + sender->td_maxack = ack; + sender->flags |= IP_CT_TCP_FLAG_MAXACK_SET; + } else if (after(ack, sender->td_maxack)) { + sender->td_maxack = ack; + } + } + + /* Update receiver data. */ + if (receiver->td_maxwin != 0 && after(end, sender->td_maxend)) + receiver->td_maxwin += end - sender->td_maxend; + if (after(sack + win, receiver->td_maxend - 1)) { + receiver->td_maxend = sack + win; + if (win == 0) + receiver->td_maxend++; + } + if (ack == receiver->td_end) + receiver->flags &= ~IP_CT_TCP_FLAG_DATA_UNACKNOWLEDGED; + + /* Check retransmissions. */ + if (index == TCP_ACK_SET) { + if (state->last_dir == dir && + state->last_seq == seq && + state->last_ack == ack && + state->last_end == end && + state->last_win == win_raw) { + state->retrans++; + } else { + state->last_dir = dir; + state->last_seq = seq; + state->last_ack = ack; + state->last_end = end; + state->last_win = win_raw; + state->retrans = 0; + } + } + + return NFCT_TCP_ACCEPT; +} + +static void __cold nf_tcp_handle_invalid(struct nf_conn *ct, + enum ip_conntrack_dir dir, + int index, + const struct sk_buff *skb, + const struct nf_hook_state *hook_state) +{ + const unsigned int *timeouts; + const struct nf_tcp_net *tn; + unsigned int timeout; + u32 expires; + + if (!test_bit(IPS_ASSURED_BIT, &ct->status) || + test_bit(IPS_FIXED_TIMEOUT_BIT, &ct->status)) + return; + + /* We don't want to have connections hanging around in ESTABLISHED + * state for long time 'just because' conntrack deemed a FIN/RST + * out-of-window. + * + * Shrink the timeout just like when there is unacked data. + * This speeds up eviction of 'dead' connections where the + * connection and conntracks internal state are out of sync. 
+ */ + switch (index) { + case TCP_RST_SET: + case TCP_FIN_SET: + break; + default: + return; + } + + if (ct->proto.tcp.last_dir != dir && + (ct->proto.tcp.last_index == TCP_FIN_SET || + ct->proto.tcp.last_index == TCP_RST_SET)) { + expires = nf_ct_expires(ct); + if (expires < 120 * HZ) + return; + + tn = nf_tcp_pernet(nf_ct_net(ct)); + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = tn->timeouts; + + timeout = READ_ONCE(timeouts[TCP_CONNTRACK_UNACK]); + if (expires > timeout) { + nf_ct_l4proto_log_invalid(skb, ct, hook_state, + "packet (index %d, dir %d) response for index %d lower timeout to %u", + index, dir, ct->proto.tcp.last_index, timeout); + + WRITE_ONCE(ct->timeout, timeout + nfct_time_stamp); + } + } else { + ct->proto.tcp.last_index = index; + ct->proto.tcp.last_dir = dir; + } +} + +/* table of valid flag combinations - PUSH, ECE and CWR are always valid */ +static const u8 tcp_valid_flags[(TCPHDR_FIN|TCPHDR_SYN|TCPHDR_RST|TCPHDR_ACK| + TCPHDR_URG) + 1] = +{ + [TCPHDR_SYN] = 1, + [TCPHDR_SYN|TCPHDR_URG] = 1, + [TCPHDR_SYN|TCPHDR_ACK] = 1, + [TCPHDR_RST] = 1, + [TCPHDR_RST|TCPHDR_ACK] = 1, + [TCPHDR_FIN|TCPHDR_ACK] = 1, + [TCPHDR_FIN|TCPHDR_ACK|TCPHDR_URG] = 1, + [TCPHDR_ACK] = 1, + [TCPHDR_ACK|TCPHDR_URG] = 1, +}; + +static void tcp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state, IPPROTO_TCP, "%s", msg); +} + +/* Protect conntrack agaist broken packets. Code taken from ipt_unclean.c. */ +static bool tcp_error(const struct tcphdr *th, + struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + unsigned int tcplen = skb->len - dataoff; + u8 tcpflags; + + /* Not whole TCP header or malformed packet */ + if (th->doff*4 < sizeof(struct tcphdr) || tcplen < th->doff*4) { + tcp_error_log(skb, state, "truncated packet"); + return true; + } + + /* Checksum invalid? Ignore. + * We skip checking packets on the outgoing path + * because the checksum is assumed to be correct. + */ + /* FIXME: Source route IP option packets --RR */ + if (state->net->ct.sysctl_checksum && + state->hook == NF_INET_PRE_ROUTING && + nf_checksum(skb, state->hook, dataoff, IPPROTO_TCP, state->pf)) { + tcp_error_log(skb, state, "bad checksum"); + return true; + } + + /* Check TCP flags. 
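
/*
 * Illustrative sketch (not from this patch): tcp_valid_flags below is a
 * small lookup table indexed by the FIN/SYN/RST/ACK/URG bits (PSH, ECE
 * and CWR are masked off first), so checking whether a flag combination
 * is plausible costs one array access.  Standalone version with the flag
 * bits defined locally:
 */
#include <stdio.h>

#define FIN 0x01
#define SYN 0x02
#define RST 0x04
#define ACK 0x10
#define URG 0x20

static const unsigned char valid[(FIN | SYN | RST | ACK | URG) + 1] = {
	[SYN]             = 1,
	[SYN | URG]       = 1,
	[SYN | ACK]       = 1,
	[RST]             = 1,
	[RST | ACK]       = 1,
	[FIN | ACK]       = 1,
	[FIN | ACK | URG] = 1,
	[ACK]             = 1,
	[ACK | URG]       = 1,
};

int main(void)
{
	/* SYN|FIN is nonsense, SYN|ACK is a normal handshake reply */
	printf("SYN|FIN valid? %d\n", valid[SYN | FIN]);	/* 0 */
	printf("SYN|ACK valid? %d\n", valid[SYN | ACK]);	/* 1 */
	return 0;
}
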
*/ + tcpflags = (tcp_flag_byte(th) & ~(TCPHDR_ECE|TCPHDR_CWR|TCPHDR_PSH)); + if (!tcp_valid_flags[tcpflags]) { + tcp_error_log(skb, state, "invalid tcp flag combination"); + return true; + } + + return false; +} + +static noinline bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb, + unsigned int dataoff, + const struct tcphdr *th) +{ + enum tcp_conntrack new_state; + struct net *net = nf_ct_net(ct); + const struct nf_tcp_net *tn = nf_tcp_pernet(net); + + /* Don't need lock here: this conntrack not in circulation yet */ + new_state = tcp_conntracks[0][get_conntrack_index(th)][TCP_CONNTRACK_NONE]; + + /* Invalid: delete conntrack */ + if (new_state >= TCP_CONNTRACK_MAX) { + pr_debug("nf_ct_tcp: invalid new deleting.\n"); + return false; + } + + if (new_state == TCP_CONNTRACK_SYN_SENT) { + memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); + /* SYN packet */ + ct->proto.tcp.seen[0].td_end = + segment_seq_plus_len(ntohl(th->seq), skb->len, + dataoff, th); + ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); + if (ct->proto.tcp.seen[0].td_maxwin == 0) + ct->proto.tcp.seen[0].td_maxwin = 1; + ct->proto.tcp.seen[0].td_maxend = + ct->proto.tcp.seen[0].td_end; + + tcp_options(skb, dataoff, th, &ct->proto.tcp.seen[0]); + } else if (tn->tcp_loose == 0) { + /* Don't try to pick up connections. */ + return false; + } else { + memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp)); + /* + * We are in the middle of a connection, + * its history is lost for us. + * Let's try to use the data from the packet. + */ + ct->proto.tcp.seen[0].td_end = + segment_seq_plus_len(ntohl(th->seq), skb->len, + dataoff, th); + ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window); + if (ct->proto.tcp.seen[0].td_maxwin == 0) + ct->proto.tcp.seen[0].td_maxwin = 1; + ct->proto.tcp.seen[0].td_maxend = + ct->proto.tcp.seen[0].td_end + + ct->proto.tcp.seen[0].td_maxwin; + + /* We assume SACK and liberal window checking to handle + * window scaling */ + ct->proto.tcp.seen[0].flags = + ct->proto.tcp.seen[1].flags = IP_CT_TCP_FLAG_SACK_PERM | + IP_CT_TCP_FLAG_BE_LIBERAL; + } + + /* tcp_packet will set them */ + ct->proto.tcp.last_index = TCP_NONE_SET; + return true; +} + +static bool tcp_can_early_drop(const struct nf_conn *ct) +{ + switch (ct->proto.tcp.state) { + case TCP_CONNTRACK_FIN_WAIT: + case TCP_CONNTRACK_LAST_ACK: + case TCP_CONNTRACK_TIME_WAIT: + case TCP_CONNTRACK_CLOSE: + case TCP_CONNTRACK_CLOSE_WAIT: + return true; + default: + break; + } + + return false; +} + +static void nf_ct_tcp_state_reset(struct ip_ct_tcp_state *state) +{ + state->td_end = 0; + state->td_maxend = 0; + state->td_maxwin = 0; + state->td_maxack = 0; + state->td_scale = 0; + state->flags &= IP_CT_TCP_FLAG_BE_LIBERAL; +} + +/* Returns verdict for packet, or -1 for invalid. 
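
/*
 * Illustrative sketch (not from this patch): when a connection is picked
 * up mid-stream (the tcp_loose path above), there is no history, so the
 * tracked window is seeded from the one packet seen: td_end from the
 * segment's ending sequence, td_maxwin from the advertised window (never
 * below 1), td_maxend as their sum.  The struct and the bare payload_len
 * parameter are simplifications invented for the example.
 */
#include <stdint.h>
#include <stdio.h>

struct toy_window {
	uint32_t td_end;	/* highest sequence sent */
	uint32_t td_maxwin;	/* largest window seen */
	uint32_t td_maxend;	/* td_end + td_maxwin */
};

static void pickup_init(struct toy_window *w, uint32_t seq,
			uint32_t payload_len, uint16_t win)
{
	w->td_end = seq + payload_len;
	w->td_maxwin = win ? win : 1;	/* a zero window would break bound II */
	w->td_maxend = w->td_end + w->td_maxwin;
}

int main(void)
{
	struct toy_window w;

	pickup_init(&w, 1000, 500, 8192);
	/* prints: end=1500 maxwin=8192 maxend=9692 */
	printf("end=%u maxwin=%u maxend=%u\n",
	       w.td_end, w.td_maxwin, w.td_maxend);
	return 0;
}
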
*/ +int nf_conntrack_tcp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + struct net *net = nf_ct_net(ct); + struct nf_tcp_net *tn = nf_tcp_pernet(net); + struct nf_conntrack_tuple *tuple; + enum tcp_conntrack new_state, old_state; + unsigned int index, *timeouts; + enum nf_ct_tcp_action res; + enum ip_conntrack_dir dir; + const struct tcphdr *th; + struct tcphdr _tcph; + unsigned long timeout; + + th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return -NF_ACCEPT; + + if (tcp_error(th, skb, dataoff, state)) + return -NF_ACCEPT; + + if (!nf_ct_is_confirmed(ct) && !tcp_new(ct, skb, dataoff, th)) + return -NF_ACCEPT; + + spin_lock_bh(&ct->lock); + old_state = ct->proto.tcp.state; + dir = CTINFO2DIR(ctinfo); + index = get_conntrack_index(th); + new_state = tcp_conntracks[dir][index][old_state]; + tuple = &ct->tuplehash[dir].tuple; + + switch (new_state) { + case TCP_CONNTRACK_SYN_SENT: + if (old_state < TCP_CONNTRACK_TIME_WAIT) + break; + /* RFC 1122: "When a connection is closed actively, + * it MUST linger in TIME-WAIT state for a time 2xMSL + * (Maximum Segment Lifetime). However, it MAY accept + * a new SYN from the remote TCP to reopen the connection + * directly from TIME-WAIT state, if..." + * We ignore the conditions because we are in the + * TIME-WAIT state anyway. + * + * Handle aborted connections: we and the server + * think there is an existing connection but the client + * aborts it and starts a new one. + */ + if (((ct->proto.tcp.seen[dir].flags + | ct->proto.tcp.seen[!dir].flags) + & IP_CT_TCP_FLAG_CLOSE_INIT) + || (ct->proto.tcp.last_dir == dir + && ct->proto.tcp.last_index == TCP_RST_SET)) { + /* Attempt to reopen a closed/aborted connection. + * Delete this connection and look up again. */ + spin_unlock_bh(&ct->lock); + + /* Only repeat if we can actually remove the timer. + * Destruction may already be in progress in process + * context and we must give it a chance to terminate. + */ + if (nf_ct_kill(ct)) + return -NF_REPEAT; + return NF_DROP; + } + fallthrough; + case TCP_CONNTRACK_IGNORE: + /* Ignored packets: + * + * Our connection entry may be out of sync, so ignore + * packets which may signal the real connection between + * the client and the server. + * + * a) SYN in ORIGINAL + * b) SYN/ACK in REPLY + * c) ACK in reply direction after initial SYN in original. + * + * If the ignored packet is invalid, the receiver will send + * a RST we'll catch below. + */ + if (index == TCP_SYNACK_SET + && ct->proto.tcp.last_index == TCP_SYN_SET + && ct->proto.tcp.last_dir != dir + && ntohl(th->ack_seq) == ct->proto.tcp.last_end) { + /* b) This SYN/ACK acknowledges a SYN that we earlier + * ignored as invalid. This means that the client and + * the server are both in sync, while the firewall is + * not. We get in sync from the previously annotated + * values. + */ + old_state = TCP_CONNTRACK_SYN_SENT; + new_state = TCP_CONNTRACK_SYN_RECV; + ct->proto.tcp.seen[ct->proto.tcp.last_dir].td_end = + ct->proto.tcp.last_end; + ct->proto.tcp.seen[ct->proto.tcp.last_dir].td_maxend = + ct->proto.tcp.last_end; + ct->proto.tcp.seen[ct->proto.tcp.last_dir].td_maxwin = + ct->proto.tcp.last_win == 0 ? 
+ 1 : ct->proto.tcp.last_win; + ct->proto.tcp.seen[ct->proto.tcp.last_dir].td_scale = + ct->proto.tcp.last_wscale; + ct->proto.tcp.last_flags &= ~IP_CT_EXP_CHALLENGE_ACK; + ct->proto.tcp.seen[ct->proto.tcp.last_dir].flags = + ct->proto.tcp.last_flags; + nf_ct_tcp_state_reset(&ct->proto.tcp.seen[dir]); + break; + } + ct->proto.tcp.last_index = index; + ct->proto.tcp.last_dir = dir; + ct->proto.tcp.last_seq = ntohl(th->seq); + ct->proto.tcp.last_end = + segment_seq_plus_len(ntohl(th->seq), skb->len, dataoff, th); + ct->proto.tcp.last_win = ntohs(th->window); + + /* a) This is a SYN in ORIGINAL. The client and the server + * may be in sync but we are not. In that case, we annotate + * the TCP options and let the packet go through. If it is a + * valid SYN packet, the server will reply with a SYN/ACK, and + * then we'll get in sync. Otherwise, the server potentially + * responds with a challenge ACK if implementing RFC5961. + */ + if (index == TCP_SYN_SET && dir == IP_CT_DIR_ORIGINAL) { + struct ip_ct_tcp_state seen = {}; + + ct->proto.tcp.last_flags = + ct->proto.tcp.last_wscale = 0; + tcp_options(skb, dataoff, th, &seen); + if (seen.flags & IP_CT_TCP_FLAG_WINDOW_SCALE) { + ct->proto.tcp.last_flags |= + IP_CT_TCP_FLAG_WINDOW_SCALE; + ct->proto.tcp.last_wscale = seen.td_scale; + } + if (seen.flags & IP_CT_TCP_FLAG_SACK_PERM) { + ct->proto.tcp.last_flags |= + IP_CT_TCP_FLAG_SACK_PERM; + } + /* Mark the potential for RFC5961 challenge ACK, + * this pose a special problem for LAST_ACK state + * as ACK is intrepretated as ACKing last FIN. + */ + if (old_state == TCP_CONNTRACK_LAST_ACK) + ct->proto.tcp.last_flags |= + IP_CT_EXP_CHALLENGE_ACK; + } + + /* possible challenge ack reply to syn */ + if (old_state == TCP_CONNTRACK_SYN_SENT && + index == TCP_ACK_SET && + dir == IP_CT_DIR_REPLY) + ct->proto.tcp.last_ack = ntohl(th->ack_seq); + + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, + "packet (index %d) in dir %d ignored, state %s", + index, dir, + tcp_conntrack_names[old_state]); + return NF_ACCEPT; + case TCP_CONNTRACK_MAX: + /* Special case for SYN proxy: when the SYN to the server or + * the SYN/ACK from the server is lost, the client may transmit + * a keep-alive packet while in SYN_SENT state. This needs to + * be associated with the original conntrack entry in order to + * generate a new SYN with the correct sequence number. + */ + if (nfct_synproxy(ct) && old_state == TCP_CONNTRACK_SYN_SENT && + index == TCP_ACK_SET && dir == IP_CT_DIR_ORIGINAL && + ct->proto.tcp.last_dir == IP_CT_DIR_ORIGINAL && + ct->proto.tcp.seen[dir].td_end - 1 == ntohl(th->seq)) { + pr_debug("nf_ct_tcp: SYN proxy client keep alive\n"); + spin_unlock_bh(&ct->lock); + return NF_ACCEPT; + } + + /* Invalid packet */ + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, + "packet (index %d) in dir %d invalid, state %s", + index, dir, + tcp_conntrack_names[old_state]); + return -NF_ACCEPT; + case TCP_CONNTRACK_TIME_WAIT: + /* RFC5961 compliance cause stack to send "challenge-ACK" + * e.g. in response to spurious SYNs. Conntrack MUST + * not believe this ACK is acking last FIN. 
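
/*
 * Illustrative sketch (not from this patch): the RFC 5961 handling below
 * reduces to a predicate -- an ACK arriving while the tracker sits in
 * LAST_ACK is treated as a challenge ACK (and ignored for state purposes)
 * only if the previously recorded packet was a SYN from the other
 * direction that was marked as a potential challenge-ACK trigger.  The
 * enum values and flag bit here are invented for the example.
 */
#include <stdbool.h>
#include <stdio.h>

enum { DIR_ORIG, DIR_REPLY };
enum { IDX_SYN, IDX_ACK, IDX_OTHER };
#define EXP_CHALLENGE_ACK 0x40	/* invented flag bit for the sketch */

static bool is_challenge_ack(int cur_index, int cur_dir,
			     int last_index, int last_dir,
			     unsigned int last_flags)
{
	return cur_index == IDX_ACK &&
	       last_index == IDX_SYN &&
	       last_dir != cur_dir &&
	       (last_flags & EXP_CHALLENGE_ACK);
}

int main(void)
{
	/* prints: 1 0 */
	printf("%d ", is_challenge_ack(IDX_ACK, DIR_REPLY,
				       IDX_SYN, DIR_ORIG, EXP_CHALLENGE_ACK));
	printf("%d\n", is_challenge_ack(IDX_ACK, DIR_REPLY,
					IDX_ACK, DIR_ORIG, EXP_CHALLENGE_ACK));
	return 0;
}
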
+ */ + if (old_state == TCP_CONNTRACK_LAST_ACK && + index == TCP_ACK_SET && + ct->proto.tcp.last_dir != dir && + ct->proto.tcp.last_index == TCP_SYN_SET && + (ct->proto.tcp.last_flags & IP_CT_EXP_CHALLENGE_ACK)) { + /* Detected RFC5961 challenge ACK */ + ct->proto.tcp.last_flags &= ~IP_CT_EXP_CHALLENGE_ACK; + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, "challenge-ack ignored"); + return NF_ACCEPT; /* Don't change state */ + } + break; + case TCP_CONNTRACK_SYN_SENT2: + /* tcp_conntracks table is not smart enough to handle + * simultaneous open. + */ + ct->proto.tcp.last_flags |= IP_CT_TCP_SIMULTANEOUS_OPEN; + break; + case TCP_CONNTRACK_SYN_RECV: + if (dir == IP_CT_DIR_REPLY && index == TCP_ACK_SET && + ct->proto.tcp.last_flags & IP_CT_TCP_SIMULTANEOUS_OPEN) + new_state = TCP_CONNTRACK_ESTABLISHED; + break; + case TCP_CONNTRACK_CLOSE: + if (index != TCP_RST_SET) + break; + + /* If we are closing, tuple might have been re-used already. + * last_index, last_ack, and all other ct fields used for + * sequence/window validation are outdated in that case. + * + * As the conntrack can already be expired by GC under pressure, + * just skip validation checks. + */ + if (tcp_can_early_drop(ct)) + goto in_window; + + /* td_maxack might be outdated if we let a SYN through earlier */ + if ((ct->proto.tcp.seen[!dir].flags & IP_CT_TCP_FLAG_MAXACK_SET) && + ct->proto.tcp.last_index != TCP_SYN_SET) { + u32 seq = ntohl(th->seq); + + /* If we are not in established state and SEQ=0 this is most + * likely an answer to a SYN we let go through above (last_index + * can be updated due to out-of-order ACKs). + */ + if (seq == 0 && !nf_conntrack_tcp_established(ct)) + break; + + if (before(seq, ct->proto.tcp.seen[!dir].td_maxack) && + !tn->tcp_ignore_invalid_rst) { + /* Invalid RST */ + spin_unlock_bh(&ct->lock); + nf_ct_l4proto_log_invalid(skb, ct, state, "invalid rst"); + return -NF_ACCEPT; + } + + if (!nf_conntrack_tcp_established(ct) || + seq == ct->proto.tcp.seen[!dir].td_maxack) + break; + + /* Check if rst is part of train, such as + * foo:80 > bar:4379: P, 235946583:235946602(19) ack 42 + * foo:80 > bar:4379: R, 235946602:235946602(0) ack 42 + */ + if (ct->proto.tcp.last_index == TCP_ACK_SET && + ct->proto.tcp.last_dir == dir && + seq == ct->proto.tcp.last_end) + break; + + /* ... RST sequence number doesn't match exactly, keep + * established state to allow a possible challenge ACK. + */ + new_state = old_state; + } + if (((test_bit(IPS_SEEN_REPLY_BIT, &ct->status) + && ct->proto.tcp.last_index == TCP_SYN_SET) + || (!test_bit(IPS_ASSURED_BIT, &ct->status) + && ct->proto.tcp.last_index == TCP_ACK_SET)) + && ntohl(th->ack_seq) == ct->proto.tcp.last_end) { + /* RST sent to invalid SYN or ACK we had let through + * at a) and c) above: + * + * a) SYN was in window then + * c) we hold a half-open connection. + * + * Delete our connection entry. + * We skip window checking, because packet might ACK + * segments we ignored. */ + goto in_window; + } + + /* Reset in response to a challenge-ack we let through earlier */ + if (old_state == TCP_CONNTRACK_SYN_SENT && + ct->proto.tcp.last_index == TCP_ACK_SET && + ct->proto.tcp.last_dir == IP_CT_DIR_REPLY && + ntohl(th->seq) == ct->proto.tcp.last_ack) + goto in_window; + + break; + default: + /* Keep compilers happy. 
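
/*
 * Illustrative sketch (not from this patch): the heart of the RST
 * validation above -- a reset whose sequence number lies below data the
 * other side has already acknowledged (td_maxack of the opposite
 * direction) is treated as stale rather than as a genuine close.  Names
 * are invented for the example; the wrap-safe comparison is the same
 * before() idiom used throughout the tracker.
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool seq_before(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;
}

static bool rst_seq_plausible(uint32_t rst_seq, uint32_t td_maxack)
{
	return !seq_before(rst_seq, td_maxack);
}

int main(void)
{
	/* prints: 0 1 1 */
	printf("%d ", rst_seq_plausible(5000, 6000));	/* stale, reject */
	printf("%d ", rst_seq_plausible(6000, 6000));	/* at the edge, ok */
	printf("%d\n", rst_seq_plausible(6500, 6000));	/* ahead, ok */
	return 0;
}
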
*/ + break; + } + + res = tcp_in_window(ct, dir, index, + skb, dataoff, th, state); + switch (res) { + case NFCT_TCP_IGNORE: + spin_unlock_bh(&ct->lock); + return NF_ACCEPT; + case NFCT_TCP_INVALID: + nf_tcp_handle_invalid(ct, dir, index, skb, state); + spin_unlock_bh(&ct->lock); + return -NF_ACCEPT; + case NFCT_TCP_ACCEPT: + break; + } + in_window: + /* From now on we have got in-window packets */ + ct->proto.tcp.last_index = index; + ct->proto.tcp.last_dir = dir; + + pr_debug("tcp_conntracks: "); + nf_ct_dump_tuple(tuple); + pr_debug("syn=%i ack=%i fin=%i rst=%i old=%i new=%i\n", + (th->syn ? 1 : 0), (th->ack ? 1 : 0), + (th->fin ? 1 : 0), (th->rst ? 1 : 0), + old_state, new_state); + + ct->proto.tcp.state = new_state; + if (old_state != new_state + && new_state == TCP_CONNTRACK_FIN_WAIT) + ct->proto.tcp.seen[dir].flags |= IP_CT_TCP_FLAG_CLOSE_INIT; + + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = tn->timeouts; + + if (ct->proto.tcp.retrans >= tn->tcp_max_retrans && + timeouts[new_state] > timeouts[TCP_CONNTRACK_RETRANS]) + timeout = timeouts[TCP_CONNTRACK_RETRANS]; + else if (unlikely(index == TCP_RST_SET)) + timeout = timeouts[TCP_CONNTRACK_CLOSE]; + else if ((ct->proto.tcp.seen[0].flags | ct->proto.tcp.seen[1].flags) & + IP_CT_TCP_FLAG_DATA_UNACKNOWLEDGED && + timeouts[new_state] > timeouts[TCP_CONNTRACK_UNACK]) + timeout = timeouts[TCP_CONNTRACK_UNACK]; + else if (ct->proto.tcp.last_win == 0 && + timeouts[new_state] > timeouts[TCP_CONNTRACK_RETRANS]) + timeout = timeouts[TCP_CONNTRACK_RETRANS]; + else + timeout = timeouts[new_state]; + spin_unlock_bh(&ct->lock); + + if (new_state != old_state) + nf_conntrack_event_cache(IPCT_PROTOINFO, ct); + + if (!test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + /* If only reply is a RST, we can consider ourselves not to + have an established connection: this is a fairly common + problem case, so we can delete the conntrack + immediately. --RR */ + if (th->rst) { + nf_ct_kill_acct(ct, ctinfo, skb); + return NF_ACCEPT; + } + + if (index == TCP_SYN_SET && old_state == TCP_CONNTRACK_SYN_SENT) { + /* do not renew timeout on SYN retransmit. + * + * Else port reuse by client or NAT middlebox can keep + * entry alive indefinitely (including nat info). + */ + return NF_ACCEPT; + } + + /* ESTABLISHED without SEEN_REPLY, i.e. mid-connection + * pickup with loose=1. Avoid large ESTABLISHED timeout. + */ + if (new_state == TCP_CONNTRACK_ESTABLISHED && + timeout > timeouts[TCP_CONNTRACK_UNACK]) + timeout = timeouts[TCP_CONNTRACK_UNACK]; + } else if (!test_bit(IPS_ASSURED_BIT, &ct->status) + && (old_state == TCP_CONNTRACK_SYN_RECV + || old_state == TCP_CONNTRACK_ESTABLISHED) + && new_state == TCP_CONNTRACK_ESTABLISHED) { + /* Set ASSURED if we see valid ack in ESTABLISHED + after SYN_RECV or a valid answer for a picked up + connection. 
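
/*
 * Illustrative sketch (not from this patch): the timeout chosen above is
 * the per-state default, clamped down when the connection looks unhealthy
 * (too many retransmits, unacknowledged data, a zero window) or replaced
 * by the CLOSE timeout after an RST.  Standalone version of that
 * selection with invented names; the values are seconds rather than
 * jiffies, purely for readability.
 */
#include <stdbool.h>
#include <stdio.h>

enum { T_ESTABLISHED, T_CLOSE, T_RETRANS, T_UNACK, T_MAX };

static unsigned int pick_timeout(const unsigned int *t, int new_state,
				 unsigned int retrans, unsigned int max_retrans,
				 bool saw_rst, bool unacked_data, bool zero_window)
{
	if (retrans >= max_retrans && t[new_state] > t[T_RETRANS])
		return t[T_RETRANS];
	if (saw_rst)
		return t[T_CLOSE];
	if (unacked_data && t[new_state] > t[T_UNACK])
		return t[T_UNACK];
	if (zero_window && t[new_state] > t[T_RETRANS])
		return t[T_RETRANS];
	return t[new_state];
}

int main(void)
{
	const unsigned int t[T_MAX] = {
		[T_ESTABLISHED] = 432000, [T_CLOSE] = 10,
		[T_RETRANS] = 300, [T_UNACK] = 300,
	};

	/* prints: 432000 300 10 */
	printf("%u ", pick_timeout(t, T_ESTABLISHED, 0, 3, false, false, false));
	printf("%u ", pick_timeout(t, T_ESTABLISHED, 0, 3, false, true, false));
	printf("%u\n", pick_timeout(t, T_ESTABLISHED, 0, 3, true, false, false));
	return 0;
}
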
*/ + set_bit(IPS_ASSURED_BIT, &ct->status); + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } + nf_ct_refresh_acct(ct, ctinfo, skb, timeout); + + return NF_ACCEPT; +} + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include + +static int tcp_to_nlattr(struct sk_buff *skb, struct nlattr *nla, + struct nf_conn *ct, bool destroy) +{ + struct nlattr *nest_parms; + struct nf_ct_tcp_flags tmp = {}; + + spin_lock_bh(&ct->lock); + nest_parms = nla_nest_start(skb, CTA_PROTOINFO_TCP); + if (!nest_parms) + goto nla_put_failure; + + if (nla_put_u8(skb, CTA_PROTOINFO_TCP_STATE, ct->proto.tcp.state)) + goto nla_put_failure; + + if (destroy) + goto skip_state; + + if (nla_put_u8(skb, CTA_PROTOINFO_TCP_WSCALE_ORIGINAL, + ct->proto.tcp.seen[0].td_scale) || + nla_put_u8(skb, CTA_PROTOINFO_TCP_WSCALE_REPLY, + ct->proto.tcp.seen[1].td_scale)) + goto nla_put_failure; + + tmp.flags = ct->proto.tcp.seen[0].flags; + if (nla_put(skb, CTA_PROTOINFO_TCP_FLAGS_ORIGINAL, + sizeof(struct nf_ct_tcp_flags), &tmp)) + goto nla_put_failure; + + tmp.flags = ct->proto.tcp.seen[1].flags; + if (nla_put(skb, CTA_PROTOINFO_TCP_FLAGS_REPLY, + sizeof(struct nf_ct_tcp_flags), &tmp)) + goto nla_put_failure; +skip_state: + spin_unlock_bh(&ct->lock); + nla_nest_end(skb, nest_parms); + + return 0; + +nla_put_failure: + spin_unlock_bh(&ct->lock); + return -1; +} + +static const struct nla_policy tcp_nla_policy[CTA_PROTOINFO_TCP_MAX+1] = { + [CTA_PROTOINFO_TCP_STATE] = { .type = NLA_U8 }, + [CTA_PROTOINFO_TCP_WSCALE_ORIGINAL] = { .type = NLA_U8 }, + [CTA_PROTOINFO_TCP_WSCALE_REPLY] = { .type = NLA_U8 }, + [CTA_PROTOINFO_TCP_FLAGS_ORIGINAL] = { .len = sizeof(struct nf_ct_tcp_flags) }, + [CTA_PROTOINFO_TCP_FLAGS_REPLY] = { .len = sizeof(struct nf_ct_tcp_flags) }, +}; + +#define TCP_NLATTR_SIZE ( \ + NLA_ALIGN(NLA_HDRLEN + 1) + \ + NLA_ALIGN(NLA_HDRLEN + 1) + \ + NLA_ALIGN(NLA_HDRLEN + sizeof(struct nf_ct_tcp_flags)) + \ + NLA_ALIGN(NLA_HDRLEN + sizeof(struct nf_ct_tcp_flags))) + +static int nlattr_to_tcp(struct nlattr *cda[], struct nf_conn *ct) +{ + struct nlattr *pattr = cda[CTA_PROTOINFO_TCP]; + struct nlattr *tb[CTA_PROTOINFO_TCP_MAX+1]; + int err; + + /* updates could not contain anything about the private + * protocol info, in that case skip the parsing */ + if (!pattr) + return 0; + + err = nla_parse_nested_deprecated(tb, CTA_PROTOINFO_TCP_MAX, pattr, + tcp_nla_policy, NULL); + if (err < 0) + return err; + + if (tb[CTA_PROTOINFO_TCP_STATE] && + nla_get_u8(tb[CTA_PROTOINFO_TCP_STATE]) >= TCP_CONNTRACK_MAX) + return -EINVAL; + + spin_lock_bh(&ct->lock); + if (tb[CTA_PROTOINFO_TCP_STATE]) + ct->proto.tcp.state = nla_get_u8(tb[CTA_PROTOINFO_TCP_STATE]); + + if (tb[CTA_PROTOINFO_TCP_FLAGS_ORIGINAL]) { + struct nf_ct_tcp_flags *attr = + nla_data(tb[CTA_PROTOINFO_TCP_FLAGS_ORIGINAL]); + ct->proto.tcp.seen[0].flags &= ~attr->mask; + ct->proto.tcp.seen[0].flags |= attr->flags & attr->mask; + } + + if (tb[CTA_PROTOINFO_TCP_FLAGS_REPLY]) { + struct nf_ct_tcp_flags *attr = + nla_data(tb[CTA_PROTOINFO_TCP_FLAGS_REPLY]); + ct->proto.tcp.seen[1].flags &= ~attr->mask; + ct->proto.tcp.seen[1].flags |= attr->flags & attr->mask; + } + + if (tb[CTA_PROTOINFO_TCP_WSCALE_ORIGINAL] && + tb[CTA_PROTOINFO_TCP_WSCALE_REPLY] && + ct->proto.tcp.seen[0].flags & IP_CT_TCP_FLAG_WINDOW_SCALE && + ct->proto.tcp.seen[1].flags & IP_CT_TCP_FLAG_WINDOW_SCALE) { + ct->proto.tcp.seen[0].td_scale = + nla_get_u8(tb[CTA_PROTOINFO_TCP_WSCALE_ORIGINAL]); + ct->proto.tcp.seen[1].td_scale = + nla_get_u8(tb[CTA_PROTOINFO_TCP_WSCALE_REPLY]); + } + spin_unlock_bh(&ct->lock); + 
+	return 0;
+}
+
+static unsigned int tcp_nlattr_tuple_size(void)
+{
+	static unsigned int size __read_mostly;
+
+	if (!size)
+		size = nla_policy_len(nf_ct_port_nla_policy, CTA_PROTO_MAX + 1);
+
+	return size;
+}
+#endif
+
+#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
+
+#include <linux/netfilter/nfnetlink.h>
+#include <linux/netfilter/nfnetlink_cttimeout.h>
+
+static int tcp_timeout_nlattr_to_obj(struct nlattr *tb[],
+				     struct net *net, void *data)
+{
+	struct nf_tcp_net *tn = nf_tcp_pernet(net);
+	unsigned int *timeouts = data;
+	int i;
+
+	if (!timeouts)
+		timeouts = tn->timeouts;
+	/* set default TCP timeouts. */
+	for (i=0; i<TCP_CONNTRACK_TIMEOUT_MAX; i++)
+		timeouts[i] = tn->timeouts[i];
+
+	if (tb[CTA_TIMEOUT_TCP_SYN_SENT]) {
+		timeouts[TCP_CONNTRACK_SYN_SENT] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_SYN_SENT]))*HZ;
+	}
+
+	if (tb[CTA_TIMEOUT_TCP_SYN_RECV]) {
+		timeouts[TCP_CONNTRACK_SYN_RECV] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_SYN_RECV]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_ESTABLISHED]) {
+		timeouts[TCP_CONNTRACK_ESTABLISHED] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_ESTABLISHED]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_FIN_WAIT]) {
+		timeouts[TCP_CONNTRACK_FIN_WAIT] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_FIN_WAIT]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_CLOSE_WAIT]) {
+		timeouts[TCP_CONNTRACK_CLOSE_WAIT] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_CLOSE_WAIT]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_LAST_ACK]) {
+		timeouts[TCP_CONNTRACK_LAST_ACK] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_LAST_ACK]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_TIME_WAIT]) {
+		timeouts[TCP_CONNTRACK_TIME_WAIT] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_TIME_WAIT]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_CLOSE]) {
+		timeouts[TCP_CONNTRACK_CLOSE] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_CLOSE]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_SYN_SENT2]) {
+		timeouts[TCP_CONNTRACK_SYN_SENT2] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_SYN_SENT2]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_RETRANS]) {
+		timeouts[TCP_CONNTRACK_RETRANS] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_RETRANS]))*HZ;
+	}
+	if (tb[CTA_TIMEOUT_TCP_UNACK]) {
+		timeouts[TCP_CONNTRACK_UNACK] =
+			ntohl(nla_get_be32(tb[CTA_TIMEOUT_TCP_UNACK]))*HZ;
+	}
+
+	timeouts[CTA_TIMEOUT_TCP_UNSPEC] = timeouts[CTA_TIMEOUT_TCP_SYN_SENT];
+	return 0;
+}
+
+static int
+tcp_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data)
+{
+	const unsigned int *timeouts = data;
+
+	if (nla_put_be32(skb, CTA_TIMEOUT_TCP_SYN_SENT,
+			 htonl(timeouts[TCP_CONNTRACK_SYN_SENT] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_SYN_RECV,
+			 htonl(timeouts[TCP_CONNTRACK_SYN_RECV] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_ESTABLISHED,
+			 htonl(timeouts[TCP_CONNTRACK_ESTABLISHED] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_FIN_WAIT,
+			 htonl(timeouts[TCP_CONNTRACK_FIN_WAIT] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_CLOSE_WAIT,
+			 htonl(timeouts[TCP_CONNTRACK_CLOSE_WAIT] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_LAST_ACK,
+			 htonl(timeouts[TCP_CONNTRACK_LAST_ACK] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_TIME_WAIT,
+			 htonl(timeouts[TCP_CONNTRACK_TIME_WAIT] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_CLOSE,
+			 htonl(timeouts[TCP_CONNTRACK_CLOSE] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_SYN_SENT2,
+			 htonl(timeouts[TCP_CONNTRACK_SYN_SENT2] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_RETRANS,
+			 htonl(timeouts[TCP_CONNTRACK_RETRANS] / HZ)) ||
+	    nla_put_be32(skb, CTA_TIMEOUT_TCP_UNACK,
+			 htonl(timeouts[TCP_CONNTRACK_UNACK] / HZ)))
+		goto nla_put_failure;
+	return 0;
+
+nla_put_failure:
+	return -ENOSPC;
+}
+
+static const struct nla_policy tcp_timeout_nla_policy[CTA_TIMEOUT_TCP_MAX+1] = {
+	[CTA_TIMEOUT_TCP_SYN_SENT]	= { .type = NLA_U32 },
+ [CTA_TIMEOUT_TCP_SYN_RECV] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_ESTABLISHED] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_FIN_WAIT] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_CLOSE_WAIT] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_LAST_ACK] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_TIME_WAIT] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_CLOSE] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_SYN_SENT2] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_RETRANS] = { .type = NLA_U32 }, + [CTA_TIMEOUT_TCP_UNACK] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_tcp_init_net(struct net *net) +{ + struct nf_tcp_net *tn = nf_tcp_pernet(net); + int i; + + for (i = 0; i < TCP_CONNTRACK_TIMEOUT_MAX; i++) + tn->timeouts[i] = tcp_timeouts[i]; + + /* timeouts[0] is unused, make it same as SYN_SENT so + * ->timeouts[0] contains 'new' timeout, like udp or icmp. + */ + tn->timeouts[0] = tcp_timeouts[TCP_CONNTRACK_SYN_SENT]; + + /* If it is set to zero, we disable picking up already established + * connections. + */ + tn->tcp_loose = 1; + + /* "Be conservative in what you do, + * be liberal in what you accept from others." + * If it's non-zero, we mark only out of window RST segments as INVALID. + */ + tn->tcp_be_liberal = 0; + + /* If it's non-zero, we turn off RST sequence number check */ + tn->tcp_ignore_invalid_rst = 0; + + /* Max number of the retransmitted packets without receiving an (acceptable) + * ACK from the destination. If this number is reached, a shorter timer + * will be started. + */ + tn->tcp_max_retrans = 3; + +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + tn->offload_timeout = 30 * HZ; +#endif +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp = +{ + .l4proto = IPPROTO_TCP, +#ifdef CONFIG_NF_CONNTRACK_PROCFS + .print_conntrack = tcp_print_conntrack, +#endif + .can_early_drop = tcp_can_early_drop, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .to_nlattr = tcp_to_nlattr, + .from_nlattr = nlattr_to_tcp, + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nlattr_tuple_size = tcp_nlattr_tuple_size, + .nlattr_size = TCP_NLATTR_SIZE, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = tcp_timeout_nlattr_to_obj, + .obj_to_nlattr = tcp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_TCP_MAX, + .obj_size = sizeof(unsigned int) * + TCP_CONNTRACK_TIMEOUT_MAX, + .nla_policy = tcp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; diff --git a/net/netfilter/nf_conntrack_proto_udp.c b/net/netfilter/nf_conntrack_proto_udp.c new file mode 100644 index 000000000..3b516cffc --- /dev/null +++ b/net/netfilter/nf_conntrack_proto_udp.c @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + * (C) 2006-2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const unsigned int udp_timeouts[UDP_CT_MAX] = { + [UDP_CT_UNREPLIED] = 30*HZ, + [UDP_CT_REPLIED] = 120*HZ, +}; + +static unsigned int *udp_get_timeouts(struct net *net) +{ + return nf_udp_pernet(net)->timeouts; +} + +static void udp_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state, IPPROTO_UDP, "%s", msg); +} + +static bool udp_error(struct sk_buff *skb, + unsigned int 
dataoff, + const struct nf_hook_state *state) +{ + unsigned int udplen = skb->len - dataoff; + const struct udphdr *hdr; + struct udphdr _hdr; + + /* Header is too small? */ + hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); + if (!hdr) { + udp_error_log(skb, state, "short packet"); + return true; + } + + /* Truncated/malformed packets */ + if (ntohs(hdr->len) > udplen || ntohs(hdr->len) < sizeof(*hdr)) { + udp_error_log(skb, state, "truncated/malformed packet"); + return true; + } + + /* Packet with no checksum */ + if (!hdr->check) + return false; + + /* Checksum invalid? Ignore. + * We skip checking packets on the outgoing path + * because the checksum is assumed to be correct. + * FIXME: Source route IP option packets --RR */ + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum(skb, state->hook, dataoff, IPPROTO_UDP, state->pf)) { + udp_error_log(skb, state, "bad checksum"); + return true; + } + + return false; +} + +/* Returns verdict for packet, and may modify conntracktype */ +int nf_conntrack_udp_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + unsigned int *timeouts; + + if (udp_error(skb, dataoff, state)) + return -NF_ACCEPT; + + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = udp_get_timeouts(nf_ct_net(ct)); + + if (!nf_ct_is_confirmed(ct)) + ct->proto.udp.stream_ts = 2 * HZ + jiffies; + + /* If we've seen traffic both ways, this is some kind of UDP + * stream. Set Assured. + */ + if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + unsigned long extra = timeouts[UDP_CT_UNREPLIED]; + bool stream = false; + + /* Still active after two seconds? Extend timeout. */ + if (time_after(jiffies, ct->proto.udp.stream_ts)) { + extra = timeouts[UDP_CT_REPLIED]; + stream = true; + } + + nf_ct_refresh_acct(ct, ctinfo, skb, extra); + + /* never set ASSURED for IPS_NAT_CLASH, they time out soon */ + if (unlikely((ct->status & IPS_NAT_CLASH))) + return NF_ACCEPT; + + /* Also, more likely to be important, and not a probe */ + if (stream && !test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } else { + nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[UDP_CT_UNREPLIED]); + } + return NF_ACCEPT; +} + +#ifdef CONFIG_NF_CT_PROTO_UDPLITE +static void udplite_error_log(const struct sk_buff *skb, + const struct nf_hook_state *state, + const char *msg) +{ + nf_l4proto_log_invalid(skb, state, IPPROTO_UDPLITE, "%s", msg); +} + +static bool udplite_error(struct sk_buff *skb, + unsigned int dataoff, + const struct nf_hook_state *state) +{ + unsigned int udplen = skb->len - dataoff; + const struct udphdr *hdr; + struct udphdr _hdr; + unsigned int cscov; + + /* Header is too small? */ + hdr = skb_header_pointer(skb, dataoff, sizeof(_hdr), &_hdr); + if (!hdr) { + udplite_error_log(skb, state, "short packet"); + return true; + } + + cscov = ntohs(hdr->len); + if (cscov == 0) { + cscov = udplen; + } else if (cscov < sizeof(*hdr) || cscov > udplen) { + udplite_error_log(skb, state, "invalid checksum coverage"); + return true; + } + + /* UDPLITE mandates checksums */ + if (!hdr->check) { + udplite_error_log(skb, state, "checksum missing"); + return true; + } + + /* Checksum invalid? Ignore. 
*/ + if (state->hook == NF_INET_PRE_ROUTING && + state->net->ct.sysctl_checksum && + nf_checksum_partial(skb, state->hook, dataoff, cscov, IPPROTO_UDP, + state->pf)) { + udplite_error_log(skb, state, "bad checksum"); + return true; + } + + return false; +} + +/* Returns verdict for packet, and may modify conntracktype */ +int nf_conntrack_udplite_packet(struct nf_conn *ct, + struct sk_buff *skb, + unsigned int dataoff, + enum ip_conntrack_info ctinfo, + const struct nf_hook_state *state) +{ + unsigned int *timeouts; + + if (udplite_error(skb, dataoff, state)) + return -NF_ACCEPT; + + timeouts = nf_ct_timeout_lookup(ct); + if (!timeouts) + timeouts = udp_get_timeouts(nf_ct_net(ct)); + + /* If we've seen traffic both ways, this is some kind of UDP + stream. Extend timeout. */ + if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) { + nf_ct_refresh_acct(ct, ctinfo, skb, + timeouts[UDP_CT_REPLIED]); + + if (unlikely((ct->status & IPS_NAT_CLASH))) + return NF_ACCEPT; + + /* Also, more likely to be important, and not a probe */ + if (!test_and_set_bit(IPS_ASSURED_BIT, &ct->status)) + nf_conntrack_event_cache(IPCT_ASSURED, ct); + } else { + nf_ct_refresh_acct(ct, ctinfo, skb, timeouts[UDP_CT_UNREPLIED]); + } + return NF_ACCEPT; +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + +#include +#include + +static int udp_timeout_nlattr_to_obj(struct nlattr *tb[], + struct net *net, void *data) +{ + unsigned int *timeouts = data; + struct nf_udp_net *un = nf_udp_pernet(net); + + if (!timeouts) + timeouts = un->timeouts; + + /* set default timeouts for UDP. */ + timeouts[UDP_CT_UNREPLIED] = un->timeouts[UDP_CT_UNREPLIED]; + timeouts[UDP_CT_REPLIED] = un->timeouts[UDP_CT_REPLIED]; + + if (tb[CTA_TIMEOUT_UDP_UNREPLIED]) { + timeouts[UDP_CT_UNREPLIED] = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_UDP_UNREPLIED])) * HZ; + } + if (tb[CTA_TIMEOUT_UDP_REPLIED]) { + timeouts[UDP_CT_REPLIED] = + ntohl(nla_get_be32(tb[CTA_TIMEOUT_UDP_REPLIED])) * HZ; + } + return 0; +} + +static int +udp_timeout_obj_to_nlattr(struct sk_buff *skb, const void *data) +{ + const unsigned int *timeouts = data; + + if (nla_put_be32(skb, CTA_TIMEOUT_UDP_UNREPLIED, + htonl(timeouts[UDP_CT_UNREPLIED] / HZ)) || + nla_put_be32(skb, CTA_TIMEOUT_UDP_REPLIED, + htonl(timeouts[UDP_CT_REPLIED] / HZ))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy +udp_timeout_nla_policy[CTA_TIMEOUT_UDP_MAX+1] = { + [CTA_TIMEOUT_UDP_UNREPLIED] = { .type = NLA_U32 }, + [CTA_TIMEOUT_UDP_REPLIED] = { .type = NLA_U32 }, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +void nf_conntrack_udp_init_net(struct net *net) +{ + struct nf_udp_net *un = nf_udp_pernet(net); + int i; + + for (i = 0; i < UDP_CT_MAX; i++) + un->timeouts[i] = udp_timeouts[i]; + +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + un->offload_timeout = 30 * HZ; +#endif +} + +const struct nf_conntrack_l4proto nf_conntrack_l4proto_udp = +{ + .l4proto = IPPROTO_UDP, + .allow_clash = true, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = udp_timeout_nlattr_to_obj, + .obj_to_nlattr = udp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_UDP_MAX, + .obj_size = sizeof(unsigned int) * CTA_TIMEOUT_UDP_MAX, + .nla_policy = udp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; + +#ifdef 
CONFIG_NF_CT_PROTO_UDPLITE +const struct nf_conntrack_l4proto nf_conntrack_l4proto_udplite = +{ + .l4proto = IPPROTO_UDPLITE, + .allow_clash = true, +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + .tuple_to_nlattr = nf_ct_port_tuple_to_nlattr, + .nlattr_to_tuple = nf_ct_port_nlattr_to_tuple, + .nlattr_tuple_size = nf_ct_port_nlattr_tuple_size, + .nla_policy = nf_ct_port_nla_policy, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + .ctnl_timeout = { + .nlattr_to_obj = udp_timeout_nlattr_to_obj, + .obj_to_nlattr = udp_timeout_obj_to_nlattr, + .nlattr_max = CTA_TIMEOUT_UDP_MAX, + .obj_size = sizeof(unsigned int) * CTA_TIMEOUT_UDP_MAX, + .nla_policy = udp_timeout_nla_policy, + }, +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ +}; +#endif diff --git a/net/netfilter/nf_conntrack_sane.c b/net/netfilter/nf_conntrack_sane.c new file mode 100644 index 000000000..13dc421fc --- /dev/null +++ b/net/netfilter/nf_conntrack_sane.c @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* SANE connection tracking helper + * (SANE = Scanner Access Now Easy) + * For documentation about the SANE network protocol see + * http://www.sane-project.org/html/doc015.html + */ + +/* Copyright (C) 2007 Red Hat, Inc. + * Author: Michal Schmidt + * Based on the FTP conntrack helper (net/netfilter/nf_conntrack_ftp.c): + * (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + * (C) 2003,2004 USAGI/WIDE Project + * (C) 2003 Yasuyuki Kozakai @USAGI + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define HELPER_NAME "sane" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Michal Schmidt "); +MODULE_DESCRIPTION("SANE connection tracking helper"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +#define MAX_PORTS 8 +static u_int16_t ports[MAX_PORTS]; +static unsigned int ports_c; +module_param_array(ports, ushort, &ports_c, 0400); + +struct sane_request { + __be32 RPC_code; +#define SANE_NET_START 7 /* RPC code */ + + __be32 handle; +}; + +struct sane_reply_net_start { + __be32 status; +#define SANE_STATUS_SUCCESS 0 + + __be16 zero; + __be16 port; + /* other fields aren't interesting for conntrack */ +}; + +static int help(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + unsigned int dataoff, datalen; + const struct tcphdr *th; + struct tcphdr _tcph; + int ret = NF_ACCEPT; + int dir = CTINFO2DIR(ctinfo); + struct nf_ct_sane_master *ct_sane_info = nfct_help_data(ct); + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple *tuple; + struct sane_reply_net_start *reply; + union { + struct sane_request req; + struct sane_reply_net_start repl; + } buf; + + /* Until there's been traffic both ways, don't look in packets. */ + if (ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + /* Not a full tcp header? */ + th = skb_header_pointer(skb, protoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return NF_ACCEPT; + + /* No data? 
*/ + dataoff = protoff + th->doff * 4; + if (dataoff >= skb->len) + return NF_ACCEPT; + + datalen = skb->len - dataoff; + if (dir == IP_CT_DIR_ORIGINAL) { + const struct sane_request *req; + + if (datalen != sizeof(struct sane_request)) + return NF_ACCEPT; + + req = skb_header_pointer(skb, dataoff, datalen, &buf.req); + if (!req) + return NF_ACCEPT; + + if (req->RPC_code != htonl(SANE_NET_START)) { + /* Not an interesting command */ + WRITE_ONCE(ct_sane_info->state, SANE_STATE_NORMAL); + return NF_ACCEPT; + } + + /* We're interested in the next reply */ + WRITE_ONCE(ct_sane_info->state, SANE_STATE_START_REQUESTED); + return NF_ACCEPT; + } + + /* IP_CT_DIR_REPLY */ + + /* Is it a reply to an uninteresting command? */ + if (READ_ONCE(ct_sane_info->state) != SANE_STATE_START_REQUESTED) + return NF_ACCEPT; + + /* It's a reply to SANE_NET_START. */ + WRITE_ONCE(ct_sane_info->state, SANE_STATE_NORMAL); + + if (datalen < sizeof(struct sane_reply_net_start)) { + pr_debug("NET_START reply too short\n"); + return NF_ACCEPT; + } + + datalen = sizeof(struct sane_reply_net_start); + + reply = skb_header_pointer(skb, dataoff, datalen, &buf.repl); + if (!reply) + return NF_ACCEPT; + + if (reply->status != htonl(SANE_STATUS_SUCCESS)) { + /* saned refused the command */ + pr_debug("unsuccessful SANE_STATUS = %u\n", + ntohl(reply->status)); + return NF_ACCEPT; + } + + /* Invalid saned reply? Ignore it. */ + if (reply->zero != 0) + return NF_ACCEPT; + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + nf_ct_helper_log(skb, ct, "cannot alloc expectation"); + return NF_DROP; + } + + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, nf_ct_l3num(ct), + &tuple->src.u3, &tuple->dst.u3, + IPPROTO_TCP, NULL, &reply->port); + + pr_debug("expect: "); + nf_ct_dump_tuple(&exp->tuple); + + /* Can't expect this? Best to drop packet now. 
*/ + if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, "cannot add expectation"); + ret = NF_DROP; + } + + nf_ct_expect_put(exp); + return ret; +} + +static struct nf_conntrack_helper sane[MAX_PORTS * 2] __read_mostly; + +static const struct nf_conntrack_expect_policy sane_exp_policy = { + .max_expected = 1, + .timeout = 5 * 60, +}; + +static void __exit nf_conntrack_sane_fini(void) +{ + nf_conntrack_helpers_unregister(sane, ports_c * 2); +} + +static int __init nf_conntrack_sane_init(void) +{ + int i, ret = 0; + + NF_CT_HELPER_BUILD_BUG_ON(sizeof(struct nf_ct_sane_master)); + + if (ports_c == 0) + ports[ports_c++] = SANE_PORT; + + /* FIXME should be configurable whether IPv4 and IPv6 connections + are tracked or not - YK */ + for (i = 0; i < ports_c; i++) { + nf_ct_helper_init(&sane[2 * i], AF_INET, IPPROTO_TCP, + HELPER_NAME, SANE_PORT, ports[i], ports[i], + &sane_exp_policy, 0, help, NULL, + THIS_MODULE); + nf_ct_helper_init(&sane[2 * i + 1], AF_INET6, IPPROTO_TCP, + HELPER_NAME, SANE_PORT, ports[i], ports[i], + &sane_exp_policy, 0, help, NULL, + THIS_MODULE); + } + + ret = nf_conntrack_helpers_register(sane, ports_c * 2); + if (ret < 0) { + pr_err("failed to register helpers\n"); + return ret; + } + + return 0; +} + +module_init(nf_conntrack_sane_init); +module_exit(nf_conntrack_sane_fini); diff --git a/net/netfilter/nf_conntrack_seqadj.c b/net/netfilter/nf_conntrack_seqadj.c new file mode 100644 index 000000000..7ab2b25b5 --- /dev/null +++ b/net/netfilter/nf_conntrack_seqadj.c @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include + +#include +#include +#include + +int nf_ct_seqadj_init(struct nf_conn *ct, enum ip_conntrack_info ctinfo, + s32 off) +{ + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_conn_seqadj *seqadj; + struct nf_ct_seqadj *this_way; + + if (off == 0) + return 0; + + set_bit(IPS_SEQ_ADJUST_BIT, &ct->status); + + seqadj = nfct_seqadj(ct); + this_way = &seqadj->seq[dir]; + this_way->offset_before = off; + this_way->offset_after = off; + return 0; +} +EXPORT_SYMBOL_GPL(nf_ct_seqadj_init); + +int nf_ct_seqadj_set(struct nf_conn *ct, enum ip_conntrack_info ctinfo, + __be32 seq, s32 off) +{ + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_ct_seqadj *this_way; + + if (off == 0) + return 0; + + if (unlikely(!seqadj)) { + WARN_ONCE(1, "Missing nfct_seqadj_ext_add() setup call\n"); + return 0; + } + + set_bit(IPS_SEQ_ADJUST_BIT, &ct->status); + + spin_lock_bh(&ct->lock); + this_way = &seqadj->seq[dir]; + if (this_way->offset_before == this_way->offset_after || + before(this_way->correction_pos, ntohl(seq))) { + this_way->correction_pos = ntohl(seq); + this_way->offset_before = this_way->offset_after; + this_way->offset_after += off; + } + spin_unlock_bh(&ct->lock); + return 0; +} +EXPORT_SYMBOL_GPL(nf_ct_seqadj_set); + +void nf_ct_tcp_seqadj_set(struct sk_buff *skb, + struct nf_conn *ct, enum ip_conntrack_info ctinfo, + s32 off) +{ + const struct tcphdr *th; + + if (nf_ct_protonum(ct) != IPPROTO_TCP) + return; + + th = (struct tcphdr *)(skb_network_header(skb) + ip_hdrlen(skb)); + nf_ct_seqadj_set(ct, ctinfo, th->seq, off); +} +EXPORT_SYMBOL_GPL(nf_ct_tcp_seqadj_set); + +/* Adjust one found SACK option including checksum correction */ +static void nf_ct_sack_block_adjust(struct sk_buff *skb, + struct tcphdr *tcph, + unsigned int sackoff, + unsigned int sackend, + struct nf_ct_seqadj *seq) +{ + while (sackoff < sackend) { + struct 
tcp_sack_block_wire *sack; + __be32 new_start_seq, new_end_seq; + + sack = (void *)skb->data + sackoff; + if (after(ntohl(sack->start_seq) - seq->offset_before, + seq->correction_pos)) + new_start_seq = htonl(ntohl(sack->start_seq) - + seq->offset_after); + else + new_start_seq = htonl(ntohl(sack->start_seq) - + seq->offset_before); + + if (after(ntohl(sack->end_seq) - seq->offset_before, + seq->correction_pos)) + new_end_seq = htonl(ntohl(sack->end_seq) - + seq->offset_after); + else + new_end_seq = htonl(ntohl(sack->end_seq) - + seq->offset_before); + + pr_debug("sack_adjust: start_seq: %u->%u, end_seq: %u->%u\n", + ntohl(sack->start_seq), ntohl(new_start_seq), + ntohl(sack->end_seq), ntohl(new_end_seq)); + + inet_proto_csum_replace4(&tcph->check, skb, + sack->start_seq, new_start_seq, false); + inet_proto_csum_replace4(&tcph->check, skb, + sack->end_seq, new_end_seq, false); + sack->start_seq = new_start_seq; + sack->end_seq = new_end_seq; + sackoff += sizeof(*sack); + } +} + +/* TCP SACK sequence number adjustment */ +static unsigned int nf_ct_sack_adjust(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + struct tcphdr *tcph = (void *)skb->data + protoff; + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + unsigned int dir, optoff, optend; + + optoff = protoff + sizeof(struct tcphdr); + optend = protoff + tcph->doff * 4; + + if (skb_ensure_writable(skb, optend)) + return 0; + + tcph = (void *)skb->data + protoff; + dir = CTINFO2DIR(ctinfo); + + while (optoff < optend) { + /* Usually: option, length. */ + unsigned char *op = skb->data + optoff; + + switch (op[0]) { + case TCPOPT_EOL: + return 1; + case TCPOPT_NOP: + optoff++; + continue; + default: + /* no partial options */ + if (optoff + 1 == optend || + optoff + op[1] > optend || + op[1] < 2) + return 0; + if (op[0] == TCPOPT_SACK && + op[1] >= 2+TCPOLEN_SACK_PERBLOCK && + ((op[1] - 2) % TCPOLEN_SACK_PERBLOCK) == 0) + nf_ct_sack_block_adjust(skb, tcph, optoff + 2, + optoff+op[1], + &seqadj->seq[!dir]); + optoff += op[1]; + } + } + return 1; +} + +/* TCP sequence number adjustment. 
Returns 1 on success, 0 on failure */ +int nf_ct_seq_adjust(struct sk_buff *skb, + struct nf_conn *ct, enum ip_conntrack_info ctinfo, + unsigned int protoff) +{ + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct tcphdr *tcph; + __be32 newseq, newack; + s32 seqoff, ackoff; + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + struct nf_ct_seqadj *this_way, *other_way; + int res = 1; + + this_way = &seqadj->seq[dir]; + other_way = &seqadj->seq[!dir]; + + if (skb_ensure_writable(skb, protoff + sizeof(*tcph))) + return 0; + + tcph = (void *)skb->data + protoff; + spin_lock_bh(&ct->lock); + if (after(ntohl(tcph->seq), this_way->correction_pos)) + seqoff = this_way->offset_after; + else + seqoff = this_way->offset_before; + + newseq = htonl(ntohl(tcph->seq) + seqoff); + inet_proto_csum_replace4(&tcph->check, skb, tcph->seq, newseq, false); + pr_debug("Adjusting sequence number from %u->%u\n", + ntohl(tcph->seq), ntohl(newseq)); + tcph->seq = newseq; + + if (!tcph->ack) + goto out; + + if (after(ntohl(tcph->ack_seq) - other_way->offset_before, + other_way->correction_pos)) + ackoff = other_way->offset_after; + else + ackoff = other_way->offset_before; + + newack = htonl(ntohl(tcph->ack_seq) - ackoff); + inet_proto_csum_replace4(&tcph->check, skb, tcph->ack_seq, newack, + false); + pr_debug("Adjusting ack number from %u->%u, ack from %u->%u\n", + ntohl(tcph->seq), ntohl(newseq), ntohl(tcph->ack_seq), + ntohl(newack)); + tcph->ack_seq = newack; + + res = nf_ct_sack_adjust(skb, protoff, ct, ctinfo); +out: + spin_unlock_bh(&ct->lock); + + return res; +} +EXPORT_SYMBOL_GPL(nf_ct_seq_adjust); + +s32 nf_ct_seq_offset(const struct nf_conn *ct, + enum ip_conntrack_dir dir, + u32 seq) +{ + struct nf_conn_seqadj *seqadj = nfct_seqadj(ct); + struct nf_ct_seqadj *this_way; + + if (!seqadj) + return 0; + + this_way = &seqadj->seq[dir]; + return after(seq, this_way->correction_pos) ? + this_way->offset_after : this_way->offset_before; +} +EXPORT_SYMBOL_GPL(nf_ct_seq_offset); diff --git a/net/netfilter/nf_conntrack_sip.c b/net/netfilter/nf_conntrack_sip.c new file mode 100644 index 000000000..d0eac27f6 --- /dev/null +++ b/net/netfilter/nf_conntrack_sip.c @@ -0,0 +1,1707 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* SIP extension for IP connection tracking. + * + * (C) 2005 by Christian Hentschel + * based on RR's ip_conntrack_ftp.c and other modules. 
+ * (C) 2007 United Security Providers + * (C) 2007, 2008 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define HELPER_NAME "sip" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Christian Hentschel "); +MODULE_DESCRIPTION("SIP connection tracking helper"); +MODULE_ALIAS("ip_conntrack_sip"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +#define MAX_PORTS 8 +static unsigned short ports[MAX_PORTS]; +static unsigned int ports_c; +module_param_array(ports, ushort, &ports_c, 0400); +MODULE_PARM_DESC(ports, "port numbers of SIP servers"); + +static unsigned int sip_timeout __read_mostly = SIP_TIMEOUT; +module_param(sip_timeout, uint, 0600); +MODULE_PARM_DESC(sip_timeout, "timeout for the master SIP session"); + +static int sip_direct_signalling __read_mostly = 1; +module_param(sip_direct_signalling, int, 0600); +MODULE_PARM_DESC(sip_direct_signalling, "expect incoming calls from registrar " + "only (default 1)"); + +static int sip_direct_media __read_mostly = 1; +module_param(sip_direct_media, int, 0600); +MODULE_PARM_DESC(sip_direct_media, "Expect Media streams between signalling " + "endpoints only (default 1)"); + +static int sip_external_media __read_mostly = 0; +module_param(sip_external_media, int, 0600); +MODULE_PARM_DESC(sip_external_media, "Expect Media streams between external " + "endpoints (default 0)"); + +const struct nf_nat_sip_hooks __rcu *nf_nat_sip_hooks; +EXPORT_SYMBOL_GPL(nf_nat_sip_hooks); + +static int string_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + int len = 0; + + while (dptr < limit && isalpha(*dptr)) { + dptr++; + len++; + } + return len; +} + +static int digits_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + int len = 0; + while (dptr < limit && isdigit(*dptr)) { + dptr++; + len++; + } + return len; +} + +static int iswordc(const char c) +{ + if (isalnum(c) || c == '!' || c == '"' || c == '%' || + (c >= '(' && c <= '+') || c == ':' || c == '<' || c == '>' || + c == '?' 
|| (c >= '[' && c <= ']') || c == '_' || c == '`' || + c == '{' || c == '}' || c == '~' || (c >= '-' && c <= '/') || + c == '\'') + return 1; + return 0; +} + +static int word_len(const char *dptr, const char *limit) +{ + int len = 0; + while (dptr < limit && iswordc(*dptr)) { + dptr++; + len++; + } + return len; +} + +static int callid_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + int len, domain_len; + + len = word_len(dptr, limit); + dptr += len; + if (!len || dptr == limit || *dptr != '@') + return len; + dptr++; + len++; + + domain_len = word_len(dptr, limit); + if (!domain_len) + return 0; + return len + domain_len; +} + +/* get media type + port length */ +static int media_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + int len = string_len(ct, dptr, limit, shift); + + dptr += len; + if (dptr >= limit || *dptr != ' ') + return 0; + len++; + dptr++; + + return len + digits_len(ct, dptr, limit, shift); +} + +static int sip_parse_addr(const struct nf_conn *ct, const char *cp, + const char **endp, union nf_inet_addr *addr, + const char *limit, bool delim) +{ + const char *end; + int ret; + + if (!ct) + return 0; + + memset(addr, 0, sizeof(*addr)); + switch (nf_ct_l3num(ct)) { + case AF_INET: + ret = in4_pton(cp, limit - cp, (u8 *)&addr->ip, -1, &end); + if (ret == 0) + return 0; + break; + case AF_INET6: + if (cp < limit && *cp == '[') + cp++; + else if (delim) + return 0; + + ret = in6_pton(cp, limit - cp, (u8 *)&addr->ip6, -1, &end); + if (ret == 0) + return 0; + + if (end < limit && *end == ']') + end++; + else if (delim) + return 0; + break; + default: + BUG(); + } + + if (endp) + *endp = end; + return 1; +} + +/* skip ip address. returns its length. */ +static int epaddr_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + union nf_inet_addr addr; + const char *aux = dptr; + + if (!sip_parse_addr(ct, dptr, &dptr, &addr, limit, true)) { + pr_debug("ip: %s parse failed.!\n", dptr); + return 0; + } + + /* Port number */ + if (*dptr == ':') { + dptr++; + dptr += digits_len(ct, dptr, limit, shift); + } + return dptr - aux; +} + +/* get address length, skiping user info. */ +static int skp_epaddr_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + const char *start = dptr; + int s = *shift; + + /* Search for @, but stop at the end of the line. + * We are inside a sip: URI, so we don't need to worry about + * continuation lines. */ + while (dptr < limit && + *dptr != '@' && *dptr != '\r' && *dptr != '\n') { + (*shift)++; + dptr++; + } + + if (dptr < limit && *dptr == '@') { + dptr++; + (*shift)++; + } else { + dptr = start; + *shift = s; + } + + return epaddr_len(ct, dptr, limit, shift); +} + +/* Parse a SIP request line of the form: + * + * Request-Line = Method SP Request-URI SP SIP-Version CRLF + * + * and return the offset and length of the address contained in the Request-URI. 
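+ *
+ * For illustration only (example data, not from this file): given a request
+ * line such as
+ *
+ *	INVITE sip:bob@192.0.2.10:5060 SIP/2.0
+ *
+ * the user part "bob@" is skipped, the returned match covers
+ * "192.0.2.10:5060", and *addr/*port are set to 192.0.2.10 and 5060.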
+ */ +int ct_sip_parse_request(const struct nf_conn *ct, + const char *dptr, unsigned int datalen, + unsigned int *matchoff, unsigned int *matchlen, + union nf_inet_addr *addr, __be16 *port) +{ + const char *start = dptr, *limit = dptr + datalen, *end; + unsigned int mlen; + unsigned int p; + int shift = 0; + + /* Skip method and following whitespace */ + mlen = string_len(ct, dptr, limit, NULL); + if (!mlen) + return 0; + dptr += mlen; + if (++dptr >= limit) + return 0; + + /* Find SIP URI */ + for (; dptr < limit - strlen("sip:"); dptr++) { + if (*dptr == '\r' || *dptr == '\n') + return -1; + if (strncasecmp(dptr, "sip:", strlen("sip:")) == 0) { + dptr += strlen("sip:"); + break; + } + } + if (!skp_epaddr_len(ct, dptr, limit, &shift)) + return 0; + dptr += shift; + + if (!sip_parse_addr(ct, dptr, &end, addr, limit, true)) + return -1; + if (end < limit && *end == ':') { + end++; + p = simple_strtoul(end, (char **)&end, 10); + if (p < 1024 || p > 65535) + return -1; + *port = htons(p); + } else + *port = htons(SIP_PORT); + + if (end == dptr) + return 0; + *matchoff = dptr - start; + *matchlen = end - dptr; + return 1; +} +EXPORT_SYMBOL_GPL(ct_sip_parse_request); + +/* SIP header parsing: SIP headers are located at the beginning of a line, but + * may span several lines, in which case the continuation lines begin with a + * whitespace character. RFC 2543 allows lines to be terminated with CR, LF or + * CRLF, RFC 3261 allows only CRLF, we support both. + * + * Headers are followed by (optionally) whitespace, a colon, again (optionally) + * whitespace and the values. Whitespace in this context means any amount of + * tabs, spaces and continuation lines, which are treated as a single whitespace + * character. + * + * Some headers may appear multiple times. A comma separated list of values is + * equivalent to multiple headers. + */ +static const struct sip_header ct_sip_hdrs[] = { + [SIP_HDR_CSEQ] = SIP_HDR("CSeq", NULL, NULL, digits_len), + [SIP_HDR_FROM] = SIP_HDR("From", "f", "sip:", skp_epaddr_len), + [SIP_HDR_TO] = SIP_HDR("To", "t", "sip:", skp_epaddr_len), + [SIP_HDR_CONTACT] = SIP_HDR("Contact", "m", "sip:", skp_epaddr_len), + [SIP_HDR_VIA_UDP] = SIP_HDR("Via", "v", "UDP ", epaddr_len), + [SIP_HDR_VIA_TCP] = SIP_HDR("Via", "v", "TCP ", epaddr_len), + [SIP_HDR_EXPIRES] = SIP_HDR("Expires", NULL, NULL, digits_len), + [SIP_HDR_CONTENT_LENGTH] = SIP_HDR("Content-Length", "l", NULL, digits_len), + [SIP_HDR_CALL_ID] = SIP_HDR("Call-Id", "i", NULL, callid_len), +}; + +static const char *sip_follow_continuation(const char *dptr, const char *limit) +{ + /* Walk past newline */ + if (++dptr >= limit) + return NULL; + + /* Skip '\n' in CR LF */ + if (*(dptr - 1) == '\r' && *dptr == '\n') { + if (++dptr >= limit) + return NULL; + } + + /* Continuation line? 
*/ + if (*dptr != ' ' && *dptr != '\t') + return NULL; + + /* skip leading whitespace */ + for (; dptr < limit; dptr++) { + if (*dptr != ' ' && *dptr != '\t') + break; + } + return dptr; +} + +static const char *sip_skip_whitespace(const char *dptr, const char *limit) +{ + for (; dptr < limit; dptr++) { + if (*dptr == ' ' || *dptr == '\t') + continue; + if (*dptr != '\r' && *dptr != '\n') + break; + dptr = sip_follow_continuation(dptr, limit); + break; + } + return dptr; +} + +/* Search within a SIP header value, dealing with continuation lines */ +static const char *ct_sip_header_search(const char *dptr, const char *limit, + const char *needle, unsigned int len) +{ + for (limit -= len; dptr < limit; dptr++) { + if (*dptr == '\r' || *dptr == '\n') { + dptr = sip_follow_continuation(dptr, limit); + if (dptr == NULL) + break; + continue; + } + + if (strncasecmp(dptr, needle, len) == 0) + return dptr; + } + return NULL; +} + +int ct_sip_get_header(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + enum sip_header_types type, + unsigned int *matchoff, unsigned int *matchlen) +{ + const struct sip_header *hdr = &ct_sip_hdrs[type]; + const char *start = dptr, *limit = dptr + datalen; + int shift = 0; + + for (dptr += dataoff; dptr < limit; dptr++) { + /* Find beginning of line */ + if (*dptr != '\r' && *dptr != '\n') + continue; + if (++dptr >= limit) + break; + if (*(dptr - 1) == '\r' && *dptr == '\n') { + if (++dptr >= limit) + break; + } + + /* Skip continuation lines */ + if (*dptr == ' ' || *dptr == '\t') + continue; + + /* Find header. Compact headers must be followed by a + * non-alphabetic character to avoid mismatches. */ + if (limit - dptr >= hdr->len && + strncasecmp(dptr, hdr->name, hdr->len) == 0) + dptr += hdr->len; + else if (hdr->cname && limit - dptr >= hdr->clen + 1 && + strncasecmp(dptr, hdr->cname, hdr->clen) == 0 && + !isalpha(*(dptr + hdr->clen))) + dptr += hdr->clen; + else + continue; + + /* Find and skip colon */ + dptr = sip_skip_whitespace(dptr, limit); + if (dptr == NULL) + break; + if (*dptr != ':' || ++dptr >= limit) + break; + + /* Skip whitespace after colon */ + dptr = sip_skip_whitespace(dptr, limit); + if (dptr == NULL) + break; + + *matchoff = dptr - start; + if (hdr->search) { + dptr = ct_sip_header_search(dptr, limit, hdr->search, + hdr->slen); + if (!dptr) + return -1; + dptr += hdr->slen; + } + + *matchlen = hdr->match_len(ct, dptr, limit, &shift); + if (!*matchlen) + return -1; + *matchoff = dptr - start + shift; + return 1; + } + return 0; +} +EXPORT_SYMBOL_GPL(ct_sip_get_header); + +/* Get next header field in a list of comma separated values */ +static int ct_sip_next_header(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + enum sip_header_types type, + unsigned int *matchoff, unsigned int *matchlen) +{ + const struct sip_header *hdr = &ct_sip_hdrs[type]; + const char *start = dptr, *limit = dptr + datalen; + int shift = 0; + + dptr += dataoff; + + dptr = ct_sip_header_search(dptr, limit, ",", strlen(",")); + if (!dptr) + return 0; + + dptr = ct_sip_header_search(dptr, limit, hdr->search, hdr->slen); + if (!dptr) + return 0; + dptr += hdr->slen; + + *matchoff = dptr - start; + *matchlen = hdr->match_len(ct, dptr, limit, &shift); + if (!*matchlen) + return -1; + *matchoff += shift; + return 1; +} + +/* Walk through headers until a parsable one is found or no header of the + * given type is left. 
*/ +static int ct_sip_walk_headers(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + enum sip_header_types type, int *in_header, + unsigned int *matchoff, unsigned int *matchlen) +{ + int ret; + + if (in_header && *in_header) { + while (1) { + ret = ct_sip_next_header(ct, dptr, dataoff, datalen, + type, matchoff, matchlen); + if (ret > 0) + return ret; + if (ret == 0) + break; + dataoff = *matchoff; + } + *in_header = 0; + } + + while (1) { + ret = ct_sip_get_header(ct, dptr, dataoff, datalen, + type, matchoff, matchlen); + if (ret > 0) + break; + if (ret == 0) + return ret; + dataoff = *matchoff; + } + + if (in_header) + *in_header = 1; + return 1; +} + +/* Locate a SIP header, parse the URI and return the offset and length of + * the address as well as the address and port themselves. A stream of + * headers can be parsed by handing in a non-NULL datalen and in_header + * pointer. + */ +int ct_sip_parse_header_uri(const struct nf_conn *ct, const char *dptr, + unsigned int *dataoff, unsigned int datalen, + enum sip_header_types type, int *in_header, + unsigned int *matchoff, unsigned int *matchlen, + union nf_inet_addr *addr, __be16 *port) +{ + const char *c, *limit = dptr + datalen; + unsigned int p; + int ret; + + ret = ct_sip_walk_headers(ct, dptr, dataoff ? *dataoff : 0, datalen, + type, in_header, matchoff, matchlen); + WARN_ON(ret < 0); + if (ret == 0) + return ret; + + if (!sip_parse_addr(ct, dptr + *matchoff, &c, addr, limit, true)) + return -1; + if (*c == ':') { + c++; + p = simple_strtoul(c, (char **)&c, 10); + if (p < 1024 || p > 65535) + return -1; + *port = htons(p); + } else + *port = htons(SIP_PORT); + + if (dataoff) + *dataoff = c - dptr; + return 1; +} +EXPORT_SYMBOL_GPL(ct_sip_parse_header_uri); + +static int ct_sip_parse_param(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + const char *name, + unsigned int *matchoff, unsigned int *matchlen) +{ + const char *limit = dptr + datalen; + const char *start; + const char *end; + + limit = ct_sip_header_search(dptr + dataoff, limit, ",", strlen(",")); + if (!limit) + limit = dptr + datalen; + + start = ct_sip_header_search(dptr + dataoff, limit, name, strlen(name)); + if (!start) + return 0; + start += strlen(name); + + end = ct_sip_header_search(start, limit, ";", strlen(";")); + if (!end) + end = limit; + + *matchoff = start - dptr; + *matchlen = end - start; + return 1; +} + +/* Parse address from header parameter and return address, offset and length */ +int ct_sip_parse_address_param(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + const char *name, + unsigned int *matchoff, unsigned int *matchlen, + union nf_inet_addr *addr, bool delim) +{ + const char *limit = dptr + datalen; + const char *start, *end; + + limit = ct_sip_header_search(dptr + dataoff, limit, ",", strlen(",")); + if (!limit) + limit = dptr + datalen; + + start = ct_sip_header_search(dptr + dataoff, limit, name, strlen(name)); + if (!start) + return 0; + + start += strlen(name); + if (!sip_parse_addr(ct, start, &end, addr, limit, delim)) + return 0; + *matchoff = start - dptr; + *matchlen = end - start; + return 1; +} +EXPORT_SYMBOL_GPL(ct_sip_parse_address_param); + +/* Parse numerical header parameter and return value, offset and length */ +int ct_sip_parse_numerical_param(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + const char *name, + unsigned int *matchoff, unsigned int 
*matchlen, + unsigned int *val) +{ + const char *limit = dptr + datalen; + const char *start; + char *end; + + limit = ct_sip_header_search(dptr + dataoff, limit, ",", strlen(",")); + if (!limit) + limit = dptr + datalen; + + start = ct_sip_header_search(dptr + dataoff, limit, name, strlen(name)); + if (!start) + return 0; + + start += strlen(name); + *val = simple_strtoul(start, &end, 0); + if (start == end) + return -1; + if (matchoff && matchlen) { + *matchoff = start - dptr; + *matchlen = end - start; + } + return 1; +} +EXPORT_SYMBOL_GPL(ct_sip_parse_numerical_param); + +static int ct_sip_parse_transport(struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + u8 *proto) +{ + unsigned int matchoff, matchlen; + + if (ct_sip_parse_param(ct, dptr, dataoff, datalen, "transport=", + &matchoff, &matchlen)) { + if (!strncasecmp(dptr + matchoff, "TCP", strlen("TCP"))) + *proto = IPPROTO_TCP; + else if (!strncasecmp(dptr + matchoff, "UDP", strlen("UDP"))) + *proto = IPPROTO_UDP; + else + return 0; + + if (*proto != nf_ct_protonum(ct)) + return 0; + } else + *proto = nf_ct_protonum(ct); + + return 1; +} + +static int sdp_parse_addr(const struct nf_conn *ct, const char *cp, + const char **endp, union nf_inet_addr *addr, + const char *limit) +{ + const char *end; + int ret; + + memset(addr, 0, sizeof(*addr)); + switch (nf_ct_l3num(ct)) { + case AF_INET: + ret = in4_pton(cp, limit - cp, (u8 *)&addr->ip, -1, &end); + break; + case AF_INET6: + ret = in6_pton(cp, limit - cp, (u8 *)&addr->ip6, -1, &end); + break; + default: + BUG(); + } + + if (ret == 0) + return 0; + if (endp) + *endp = end; + return 1; +} + +/* skip ip address. returns its length. */ +static int sdp_addr_len(const struct nf_conn *ct, const char *dptr, + const char *limit, int *shift) +{ + union nf_inet_addr addr; + const char *aux = dptr; + + if (!sdp_parse_addr(ct, dptr, &dptr, &addr, limit)) { + pr_debug("ip: %s parse failed.!\n", dptr); + return 0; + } + + return dptr - aux; +} + +/* SDP header parsing: a SDP session description contains an ordered set of + * headers, starting with a section containing general session parameters, + * optionally followed by multiple media descriptions. + * + * SDP headers always start at the beginning of a line. According to RFC 2327: + * "The sequence CRLF (0x0d0a) is used to end a record, although parsers should + * be tolerant and also accept records terminated with a single newline + * character". We handle both cases. 
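+ *
+ * For illustration only (example data, not from this file), a minimal IPv4
+ * session description whose lines are matched by the ct_sdp_hdrs_v4 table
+ * below:
+ *
+ *	v=0
+ *	o=user 2890844526 2890842807 IN IP4 192.0.2.1
+ *	c=IN IP4 192.0.2.1
+ *	m=audio 49170 RTP/AVP 0
+ *
+ * The "IN IP4 " search string selects the owner and connection addresses,
+ * and media_len() measures "audio 49170" on the m= line.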
+ */ +static const struct sip_header ct_sdp_hdrs_v4[] = { + [SDP_HDR_VERSION] = SDP_HDR("v=", NULL, digits_len), + [SDP_HDR_OWNER] = SDP_HDR("o=", "IN IP4 ", sdp_addr_len), + [SDP_HDR_CONNECTION] = SDP_HDR("c=", "IN IP4 ", sdp_addr_len), + [SDP_HDR_MEDIA] = SDP_HDR("m=", NULL, media_len), +}; + +static const struct sip_header ct_sdp_hdrs_v6[] = { + [SDP_HDR_VERSION] = SDP_HDR("v=", NULL, digits_len), + [SDP_HDR_OWNER] = SDP_HDR("o=", "IN IP6 ", sdp_addr_len), + [SDP_HDR_CONNECTION] = SDP_HDR("c=", "IN IP6 ", sdp_addr_len), + [SDP_HDR_MEDIA] = SDP_HDR("m=", NULL, media_len), +}; + +/* Linear string search within SDP header values */ +static const char *ct_sdp_header_search(const char *dptr, const char *limit, + const char *needle, unsigned int len) +{ + for (limit -= len; dptr < limit; dptr++) { + if (*dptr == '\r' || *dptr == '\n') + break; + if (strncmp(dptr, needle, len) == 0) + return dptr; + } + return NULL; +} + +/* Locate a SDP header (optionally a substring within the header value), + * optionally stopping at the first occurrence of the term header, parse + * it and return the offset and length of the data we're interested in. + */ +int ct_sip_get_sdp_header(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + enum sdp_header_types type, + enum sdp_header_types term, + unsigned int *matchoff, unsigned int *matchlen) +{ + const struct sip_header *hdrs, *hdr, *thdr; + const char *start = dptr, *limit = dptr + datalen; + int shift = 0; + + hdrs = nf_ct_l3num(ct) == NFPROTO_IPV4 ? ct_sdp_hdrs_v4 : ct_sdp_hdrs_v6; + hdr = &hdrs[type]; + thdr = &hdrs[term]; + + for (dptr += dataoff; dptr < limit; dptr++) { + /* Find beginning of line */ + if (*dptr != '\r' && *dptr != '\n') + continue; + if (++dptr >= limit) + break; + if (*(dptr - 1) == '\r' && *dptr == '\n') { + if (++dptr >= limit) + break; + } + + if (term != SDP_HDR_UNSPEC && + limit - dptr >= thdr->len && + strncasecmp(dptr, thdr->name, thdr->len) == 0) + break; + else if (limit - dptr >= hdr->len && + strncasecmp(dptr, hdr->name, hdr->len) == 0) + dptr += hdr->len; + else + continue; + + *matchoff = dptr - start; + if (hdr->search) { + dptr = ct_sdp_header_search(dptr, limit, hdr->search, + hdr->slen); + if (!dptr) + return -1; + dptr += hdr->slen; + } + + *matchlen = hdr->match_len(ct, dptr, limit, &shift); + if (!*matchlen) + return -1; + *matchoff = dptr - start + shift; + return 1; + } + return 0; +} +EXPORT_SYMBOL_GPL(ct_sip_get_sdp_header); + +static int ct_sip_parse_sdp_addr(const struct nf_conn *ct, const char *dptr, + unsigned int dataoff, unsigned int datalen, + enum sdp_header_types type, + enum sdp_header_types term, + unsigned int *matchoff, unsigned int *matchlen, + union nf_inet_addr *addr) +{ + int ret; + + ret = ct_sip_get_sdp_header(ct, dptr, dataoff, datalen, type, term, + matchoff, matchlen); + if (ret <= 0) + return ret; + + if (!sdp_parse_addr(ct, dptr + *matchoff, NULL, addr, + dptr + *matchoff + *matchlen)) + return -1; + return 1; +} + +static int refresh_signalling_expectation(struct nf_conn *ct, + union nf_inet_addr *addr, + u8 proto, __be16 port, + unsigned int expires) +{ + struct nf_conn_help *help = nfct_help(ct); + struct nf_conntrack_expect *exp; + struct hlist_node *next; + int found = 0; + + spin_lock_bh(&nf_conntrack_expect_lock); + hlist_for_each_entry_safe(exp, next, &help->expectations, lnode) { + if (exp->class != SIP_EXPECT_SIGNALLING || + !nf_inet_addr_cmp(&exp->tuple.dst.u3, addr) || + exp->tuple.dst.protonum != proto || + exp->tuple.dst.u.udp.port != 
port) + continue; + if (mod_timer_pending(&exp->timeout, jiffies + expires * HZ)) { + exp->flags &= ~NF_CT_EXPECT_INACTIVE; + found = 1; + break; + } + } + spin_unlock_bh(&nf_conntrack_expect_lock); + return found; +} + +static void flush_expectations(struct nf_conn *ct, bool media) +{ + struct nf_conn_help *help = nfct_help(ct); + struct nf_conntrack_expect *exp; + struct hlist_node *next; + + spin_lock_bh(&nf_conntrack_expect_lock); + hlist_for_each_entry_safe(exp, next, &help->expectations, lnode) { + if ((exp->class != SIP_EXPECT_SIGNALLING) ^ media) + continue; + if (!nf_ct_remove_expect(exp)) + continue; + if (!media) + break; + } + spin_unlock_bh(&nf_conntrack_expect_lock); +} + +static int set_expected_rtp_rtcp(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + union nf_inet_addr *daddr, __be16 port, + enum sip_expectation_classes class, + unsigned int mediaoff, unsigned int medialen) +{ + struct nf_conntrack_expect *exp, *rtp_exp, *rtcp_exp; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct net *net = nf_ct_net(ct); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + union nf_inet_addr *saddr; + struct nf_conntrack_tuple tuple; + int direct_rtp = 0, skip_expect = 0, ret = NF_DROP; + u_int16_t base_port; + __be16 rtp_port, rtcp_port; + const struct nf_nat_sip_hooks *hooks; + + saddr = NULL; + if (sip_direct_media) { + if (!nf_inet_addr_cmp(daddr, &ct->tuplehash[dir].tuple.src.u3)) + return NF_ACCEPT; + saddr = &ct->tuplehash[!dir].tuple.src.u3; + } else if (sip_external_media) { + struct net_device *dev = skb_dst(skb)->dev; + struct net *net = dev_net(dev); + struct flowi fl; + struct dst_entry *dst = NULL; + + memset(&fl, 0, sizeof(fl)); + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + fl.u.ip4.daddr = daddr->ip; + nf_ip_route(net, &dst, &fl, false); + break; + + case NFPROTO_IPV6: + fl.u.ip6.daddr = daddr->in6; + nf_ip6_route(net, &dst, &fl, false); + break; + } + + /* Don't predict any conntracks when media endpoint is reachable + * through the same interface as the signalling peer. + */ + if (dst) { + bool external_media = (dst->dev == dev); + + dst_release(dst); + if (external_media) + return NF_ACCEPT; + } + } + + /* We need to check whether the registration exists before attempting + * to register it since we can see the same media description multiple + * times on different connections in case multiple endpoints receive + * the same call. + * + * RTP optimization: if we find a matching media channel expectation + * and both the expectation and this connection are SNATed, we assume + * both sides can reach each other directly and use the final + * destination address from the expectation. We still need to keep + * the NATed expectations for media that might arrive from the + * outside, and additionally need to expect the direct RTP stream + * in case it passes through us even without NAT. 
+ */ + memset(&tuple, 0, sizeof(tuple)); + if (saddr) + tuple.src.u3 = *saddr; + tuple.src.l3num = nf_ct_l3num(ct); + tuple.dst.protonum = IPPROTO_UDP; + tuple.dst.u3 = *daddr; + tuple.dst.u.udp.port = port; + + do { + exp = __nf_ct_expect_find(net, nf_ct_zone(ct), &tuple); + + if (!exp || exp->master == ct || + nfct_help(exp->master)->helper != nfct_help(ct)->helper || + exp->class != class) + break; +#if IS_ENABLED(CONFIG_NF_NAT) + if (!direct_rtp && + (!nf_inet_addr_cmp(&exp->saved_addr, &exp->tuple.dst.u3) || + exp->saved_proto.udp.port != exp->tuple.dst.u.udp.port) && + ct->status & IPS_NAT_MASK) { + *daddr = exp->saved_addr; + tuple.dst.u3 = exp->saved_addr; + tuple.dst.u.udp.port = exp->saved_proto.udp.port; + direct_rtp = 1; + } else +#endif + skip_expect = 1; + } while (!skip_expect); + + base_port = ntohs(tuple.dst.u.udp.port) & ~1; + rtp_port = htons(base_port); + rtcp_port = htons(base_port + 1); + + if (direct_rtp) { + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks && + !hooks->sdp_port(skb, protoff, dataoff, dptr, datalen, + mediaoff, medialen, ntohs(rtp_port))) + goto err1; + } + + if (skip_expect) + return NF_ACCEPT; + + rtp_exp = nf_ct_expect_alloc(ct); + if (rtp_exp == NULL) + goto err1; + nf_ct_expect_init(rtp_exp, class, nf_ct_l3num(ct), saddr, daddr, + IPPROTO_UDP, NULL, &rtp_port); + + rtcp_exp = nf_ct_expect_alloc(ct); + if (rtcp_exp == NULL) + goto err2; + nf_ct_expect_init(rtcp_exp, class, nf_ct_l3num(ct), saddr, daddr, + IPPROTO_UDP, NULL, &rtcp_port); + + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks && ct->status & IPS_NAT_MASK && !direct_rtp) + ret = hooks->sdp_media(skb, protoff, dataoff, dptr, + datalen, rtp_exp, rtcp_exp, + mediaoff, medialen, daddr); + else { + /* -EALREADY handling works around end-points that send + * SDP messages with identical port but different media type, + * we pretend expectation was set up. + * It also works in the case that SDP messages are sent with + * identical expect tuples but for different master conntracks. 
+ */ + int errp = nf_ct_expect_related(rtp_exp, + NF_CT_EXP_F_SKIP_MASTER); + + if (errp == 0 || errp == -EALREADY) { + int errcp = nf_ct_expect_related(rtcp_exp, + NF_CT_EXP_F_SKIP_MASTER); + + if (errcp == 0 || errcp == -EALREADY) + ret = NF_ACCEPT; + else if (errp == 0) + nf_ct_unexpect_related(rtp_exp); + } + } + nf_ct_expect_put(rtcp_exp); +err2: + nf_ct_expect_put(rtp_exp); +err1: + return ret; +} + +static const struct sdp_media_type sdp_media_types[] = { + SDP_MEDIA_TYPE("audio ", SIP_EXPECT_AUDIO), + SDP_MEDIA_TYPE("video ", SIP_EXPECT_VIDEO), + SDP_MEDIA_TYPE("image ", SIP_EXPECT_IMAGE), +}; + +static const struct sdp_media_type *sdp_media_type(const char *dptr, + unsigned int matchoff, + unsigned int matchlen) +{ + const struct sdp_media_type *t; + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(sdp_media_types); i++) { + t = &sdp_media_types[i]; + if (matchlen < t->len || + strncmp(dptr + matchoff, t->name, t->len)) + continue; + return t; + } + return NULL; +} + +static int process_sdp(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + unsigned int matchoff, matchlen; + unsigned int mediaoff, medialen; + unsigned int sdpoff; + unsigned int caddr_len, maddr_len; + unsigned int i; + union nf_inet_addr caddr, maddr, rtp_addr; + const struct nf_nat_sip_hooks *hooks; + unsigned int port; + const struct sdp_media_type *t; + int ret = NF_ACCEPT; + + hooks = rcu_dereference(nf_nat_sip_hooks); + + /* Find beginning of session description */ + if (ct_sip_get_sdp_header(ct, *dptr, 0, *datalen, + SDP_HDR_VERSION, SDP_HDR_UNSPEC, + &matchoff, &matchlen) <= 0) + return NF_ACCEPT; + sdpoff = matchoff; + + /* The connection information is contained in the session description + * and/or once per media description. The first media description marks + * the end of the session description. */ + caddr_len = 0; + if (ct_sip_parse_sdp_addr(ct, *dptr, sdpoff, *datalen, + SDP_HDR_CONNECTION, SDP_HDR_MEDIA, + &matchoff, &matchlen, &caddr) > 0) + caddr_len = matchlen; + + mediaoff = sdpoff; + for (i = 0; i < ARRAY_SIZE(sdp_media_types); ) { + if (ct_sip_get_sdp_header(ct, *dptr, mediaoff, *datalen, + SDP_HDR_MEDIA, SDP_HDR_UNSPEC, + &mediaoff, &medialen) <= 0) + break; + + /* Get media type and port number. A media port value of zero + * indicates an inactive stream. */ + t = sdp_media_type(*dptr, mediaoff, medialen); + if (!t) { + mediaoff += medialen; + continue; + } + mediaoff += t->len; + medialen -= t->len; + + port = simple_strtoul(*dptr + mediaoff, NULL, 10); + if (port == 0) + continue; + if (port < 1024 || port > 65535) { + nf_ct_helper_log(skb, ct, "wrong port %u", port); + return NF_DROP; + } + + /* The media description overrides the session description. 
*/ + maddr_len = 0; + if (ct_sip_parse_sdp_addr(ct, *dptr, mediaoff, *datalen, + SDP_HDR_CONNECTION, SDP_HDR_MEDIA, + &matchoff, &matchlen, &maddr) > 0) { + maddr_len = matchlen; + memcpy(&rtp_addr, &maddr, sizeof(rtp_addr)); + } else if (caddr_len) + memcpy(&rtp_addr, &caddr, sizeof(rtp_addr)); + else { + nf_ct_helper_log(skb, ct, "cannot parse SDP message"); + return NF_DROP; + } + + ret = set_expected_rtp_rtcp(skb, protoff, dataoff, + dptr, datalen, + &rtp_addr, htons(port), t->class, + mediaoff, medialen); + if (ret != NF_ACCEPT) { + nf_ct_helper_log(skb, ct, + "cannot add expectation for voice"); + return ret; + } + + /* Update media connection address if present */ + if (maddr_len && hooks && ct->status & IPS_NAT_MASK) { + ret = hooks->sdp_addr(skb, protoff, dataoff, + dptr, datalen, mediaoff, + SDP_HDR_CONNECTION, + SDP_HDR_MEDIA, + &rtp_addr); + if (ret != NF_ACCEPT) { + nf_ct_helper_log(skb, ct, "cannot mangle SDP"); + return ret; + } + } + i++; + } + + /* Update session connection and owner addresses */ + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks && ct->status & IPS_NAT_MASK) + ret = hooks->sdp_session(skb, protoff, dataoff, + dptr, datalen, sdpoff, + &rtp_addr); + + return ret; +} +static int process_invite_response(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq, unsigned int code) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + + if ((code >= 100 && code <= 199) || + (code >= 200 && code <= 299)) + return process_sdp(skb, protoff, dataoff, dptr, datalen, cseq); + else if (ct_sip_info->invite_cseq == cseq) + flush_expectations(ct, true); + return NF_ACCEPT; +} + +static int process_update_response(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq, unsigned int code) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + + if ((code >= 100 && code <= 199) || + (code >= 200 && code <= 299)) + return process_sdp(skb, protoff, dataoff, dptr, datalen, cseq); + else if (ct_sip_info->invite_cseq == cseq) + flush_expectations(ct, true); + return NF_ACCEPT; +} + +static int process_prack_response(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq, unsigned int code) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + + if ((code >= 100 && code <= 199) || + (code >= 200 && code <= 299)) + return process_sdp(skb, protoff, dataoff, dptr, datalen, cseq); + else if (ct_sip_info->invite_cseq == cseq) + flush_expectations(ct, true); + return NF_ACCEPT; +} + +static int process_invite_request(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + unsigned int ret; + + flush_expectations(ct, true); + ret = process_sdp(skb, protoff, dataoff, dptr, datalen, cseq); + if (ret == NF_ACCEPT) + ct_sip_info->invite_cseq = cseq; + return ret; +} + +static int process_bye_request(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const 
char **dptr, unsigned int *datalen, + unsigned int cseq) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + flush_expectations(ct, true); + return NF_ACCEPT; +} + +/* Parse a REGISTER request and create a permanent expectation for incoming + * signalling connections. The expectation is marked inactive and is activated + * when receiving a response indicating success from the registrar. + */ +static int process_register_request(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int cseq) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + unsigned int matchoff, matchlen; + struct nf_conntrack_expect *exp; + union nf_inet_addr *saddr, daddr; + const struct nf_nat_sip_hooks *hooks; + struct nf_conntrack_helper *helper; + __be16 port; + u8 proto; + unsigned int expires = 0; + int ret; + + /* Expected connections can not register again. */ + if (ct->status & IPS_EXPECTED) + return NF_ACCEPT; + + /* We must check the expiration time: a value of zero signals the + * registrar to release the binding. We'll remove our expectation + * when receiving the new bindings in the response, but we don't + * want to create new ones. + * + * The expiration time may be contained in Expires: header, the + * Contact: header parameters or the URI parameters. + */ + if (ct_sip_get_header(ct, *dptr, 0, *datalen, SIP_HDR_EXPIRES, + &matchoff, &matchlen) > 0) + expires = simple_strtoul(*dptr + matchoff, NULL, 10); + + ret = ct_sip_parse_header_uri(ct, *dptr, NULL, *datalen, + SIP_HDR_CONTACT, NULL, + &matchoff, &matchlen, &daddr, &port); + if (ret < 0) { + nf_ct_helper_log(skb, ct, "cannot parse contact"); + return NF_DROP; + } else if (ret == 0) + return NF_ACCEPT; + + /* We don't support third-party registrations */ + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, &daddr)) + return NF_ACCEPT; + + if (ct_sip_parse_transport(ct, *dptr, matchoff + matchlen, *datalen, + &proto) == 0) + return NF_ACCEPT; + + if (ct_sip_parse_numerical_param(ct, *dptr, + matchoff + matchlen, *datalen, + "expires=", NULL, NULL, &expires) < 0) { + nf_ct_helper_log(skb, ct, "cannot parse expires"); + return NF_DROP; + } + + if (expires == 0) { + ret = NF_ACCEPT; + goto store_cseq; + } + + exp = nf_ct_expect_alloc(ct); + if (!exp) { + nf_ct_helper_log(skb, ct, "cannot alloc expectation"); + return NF_DROP; + } + + saddr = NULL; + if (sip_direct_signalling) + saddr = &ct->tuplehash[!dir].tuple.src.u3; + + helper = rcu_dereference(nfct_help(ct)->helper); + if (!helper) + return NF_DROP; + + nf_ct_expect_init(exp, SIP_EXPECT_SIGNALLING, nf_ct_l3num(ct), + saddr, &daddr, proto, NULL, &port); + exp->timeout.expires = sip_timeout * HZ; + exp->helper = helper; + exp->flags = NF_CT_EXPECT_PERMANENT | NF_CT_EXPECT_INACTIVE; + + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks && ct->status & IPS_NAT_MASK) + ret = hooks->expect(skb, protoff, dataoff, dptr, datalen, + exp, matchoff, matchlen); + else { + if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, "cannot add expectation"); + ret = NF_DROP; + } else + ret = NF_ACCEPT; + } + nf_ct_expect_put(exp); + +store_cseq: + if (ret == NF_ACCEPT) + ct_sip_info->register_cseq = cseq; + return ret; +} + +static int process_register_response(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, 
unsigned int *datalen, + unsigned int cseq, unsigned int code) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + union nf_inet_addr addr; + __be16 port; + u8 proto; + unsigned int matchoff, matchlen, coff = 0; + unsigned int expires = 0; + int in_contact = 0, ret; + + /* According to RFC 3261, "UAs MUST NOT send a new registration until + * they have received a final response from the registrar for the + * previous one or the previous REGISTER request has timed out". + * + * However, some servers fail to detect retransmissions and send late + * responses, so we store the sequence number of the last valid + * request and compare it here. + */ + if (ct_sip_info->register_cseq != cseq) + return NF_ACCEPT; + + if (code >= 100 && code <= 199) + return NF_ACCEPT; + if (code < 200 || code > 299) + goto flush; + + if (ct_sip_get_header(ct, *dptr, 0, *datalen, SIP_HDR_EXPIRES, + &matchoff, &matchlen) > 0) + expires = simple_strtoul(*dptr + matchoff, NULL, 10); + + while (1) { + unsigned int c_expires = expires; + + ret = ct_sip_parse_header_uri(ct, *dptr, &coff, *datalen, + SIP_HDR_CONTACT, &in_contact, + &matchoff, &matchlen, + &addr, &port); + if (ret < 0) { + nf_ct_helper_log(skb, ct, "cannot parse contact"); + return NF_DROP; + } else if (ret == 0) + break; + + /* We don't support third-party registrations */ + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3, &addr)) + continue; + + if (ct_sip_parse_transport(ct, *dptr, matchoff + matchlen, + *datalen, &proto) == 0) + continue; + + ret = ct_sip_parse_numerical_param(ct, *dptr, + matchoff + matchlen, + *datalen, "expires=", + NULL, NULL, &c_expires); + if (ret < 0) { + nf_ct_helper_log(skb, ct, "cannot parse expires"); + return NF_DROP; + } + if (c_expires == 0) + break; + if (refresh_signalling_expectation(ct, &addr, proto, port, + c_expires)) + return NF_ACCEPT; + } + +flush: + flush_expectations(ct, false); + return NF_ACCEPT; +} + +static const struct sip_handler sip_handlers[] = { + SIP_HANDLER("INVITE", process_invite_request, process_invite_response), + SIP_HANDLER("UPDATE", process_sdp, process_update_response), + SIP_HANDLER("ACK", process_sdp, NULL), + SIP_HANDLER("PRACK", process_sdp, process_prack_response), + SIP_HANDLER("BYE", process_bye_request, NULL), + SIP_HANDLER("REGISTER", process_register_request, process_register_response), +}; + +static int process_sip_response(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + unsigned int matchoff, matchlen, matchend; + unsigned int code, cseq, i; + + if (*datalen < strlen("SIP/2.0 200")) + return NF_ACCEPT; + code = simple_strtoul(*dptr + strlen("SIP/2.0 "), NULL, 10); + if (!code) { + nf_ct_helper_log(skb, ct, "cannot get code"); + return NF_DROP; + } + + if (ct_sip_get_header(ct, *dptr, 0, *datalen, SIP_HDR_CSEQ, + &matchoff, &matchlen) <= 0) { + nf_ct_helper_log(skb, ct, "cannot parse cseq"); + return NF_DROP; + } + cseq = simple_strtoul(*dptr + matchoff, NULL, 10); + if (!cseq && *(*dptr + matchoff) != '0') { + nf_ct_helper_log(skb, ct, "cannot get cseq"); + return NF_DROP; + } + matchend = matchoff + matchlen + 1; + + for (i = 0; i < ARRAY_SIZE(sip_handlers); i++) { + const struct sip_handler *handler; + + handler = &sip_handlers[i]; + if (handler->response == NULL) + continue; + if 
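
/*
 * Aside: a minimal, hypothetical userspace sketch (not part of this patch)
 * of the status-line handling that process_sip_msg()/process_sip_response()
 * perform with strncasecmp() and simple_strtoul(): skip the fixed
 * "SIP/2.0 " prefix and read the decimal response code.
 */
#include <stdlib.h>
#include <string.h>
#include <strings.h>

static unsigned long sip_response_code(const char *msg, size_t len)
{
        static const char prefix[] = "SIP/2.0 ";

        /* shortest valid status line prefix is "SIP/2.0 200" (11 bytes) */
        if (len < strlen(prefix) + 3 ||
            strncasecmp(msg, prefix, strlen(prefix)) != 0)
                return 0;       /* not a response, i.e. a request */
        return strtoul(msg + strlen(prefix), NULL, 10);
}
/* e.g. sip_response_code("SIP/2.0 180 Ringing", 19) yields 180 */
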
(*datalen < matchend + handler->len || + strncasecmp(*dptr + matchend, handler->method, handler->len)) + continue; + return handler->response(skb, protoff, dataoff, dptr, datalen, + cseq, code); + } + return NF_ACCEPT; +} + +static int process_sip_request(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + unsigned int matchoff, matchlen; + unsigned int cseq, i; + union nf_inet_addr addr; + __be16 port; + + /* Many Cisco IP phones use a high source port for SIP requests, but + * listen for the response on port 5060. If we are the local + * router for one of these phones, save the port number from the + * Via: header so that nf_nat_sip can redirect the responses to + * the correct port. + */ + if (ct_sip_parse_header_uri(ct, *dptr, NULL, *datalen, + SIP_HDR_VIA_UDP, NULL, &matchoff, + &matchlen, &addr, &port) > 0 && + port != ct->tuplehash[dir].tuple.src.u.udp.port && + nf_inet_addr_cmp(&addr, &ct->tuplehash[dir].tuple.src.u3)) + ct_sip_info->forced_dport = port; + + for (i = 0; i < ARRAY_SIZE(sip_handlers); i++) { + const struct sip_handler *handler; + + handler = &sip_handlers[i]; + if (handler->request == NULL) + continue; + if (*datalen < handler->len + 2 || + strncasecmp(*dptr, handler->method, handler->len)) + continue; + if ((*dptr)[handler->len] != ' ' || + !isalpha((*dptr)[handler->len+1])) + continue; + + if (ct_sip_get_header(ct, *dptr, 0, *datalen, SIP_HDR_CSEQ, + &matchoff, &matchlen) <= 0) { + nf_ct_helper_log(skb, ct, "cannot parse cseq"); + return NF_DROP; + } + cseq = simple_strtoul(*dptr + matchoff, NULL, 10); + if (!cseq && *(*dptr + matchoff) != '0') { + nf_ct_helper_log(skb, ct, "cannot get cseq"); + return NF_DROP; + } + + return handler->request(skb, protoff, dataoff, dptr, datalen, + cseq); + } + return NF_ACCEPT; +} + +static int process_sip_msg(struct sk_buff *skb, struct nf_conn *ct, + unsigned int protoff, unsigned int dataoff, + const char **dptr, unsigned int *datalen) +{ + const struct nf_nat_sip_hooks *hooks; + int ret; + + if (strncasecmp(*dptr, "SIP/2.0 ", strlen("SIP/2.0 ")) != 0) + ret = process_sip_request(skb, protoff, dataoff, dptr, datalen); + else + ret = process_sip_response(skb, protoff, dataoff, dptr, datalen); + + if (ret == NF_ACCEPT && ct->status & IPS_NAT_MASK) { + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks && !hooks->msg(skb, protoff, dataoff, + dptr, datalen)) { + nf_ct_helper_log(skb, ct, "cannot NAT SIP message"); + ret = NF_DROP; + } + } + + return ret; +} + +static int sip_help_tcp(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + struct tcphdr *th, _tcph; + unsigned int dataoff, datalen; + unsigned int matchoff, matchlen, clen; + unsigned int msglen, origlen; + const char *dptr, *end; + s16 diff, tdiff = 0; + int ret = NF_ACCEPT; + bool term; + + if (ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY) + return NF_ACCEPT; + + /* No Data ? 
*/ + th = skb_header_pointer(skb, protoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return NF_ACCEPT; + dataoff = protoff + th->doff * 4; + if (dataoff >= skb->len) + return NF_ACCEPT; + + nf_ct_refresh(ct, skb, sip_timeout * HZ); + + if (unlikely(skb_linearize(skb))) + return NF_DROP; + + dptr = skb->data + dataoff; + datalen = skb->len - dataoff; + if (datalen < strlen("SIP/2.0 200")) + return NF_ACCEPT; + + while (1) { + if (ct_sip_get_header(ct, dptr, 0, datalen, + SIP_HDR_CONTENT_LENGTH, + &matchoff, &matchlen) <= 0) + break; + + clen = simple_strtoul(dptr + matchoff, (char **)&end, 10); + if (dptr + matchoff == end) + break; + + term = false; + for (; end + strlen("\r\n\r\n") <= dptr + datalen; end++) { + if (end[0] == '\r' && end[1] == '\n' && + end[2] == '\r' && end[3] == '\n') { + term = true; + break; + } + } + if (!term) + break; + end += strlen("\r\n\r\n") + clen; + + msglen = origlen = end - dptr; + if (msglen > datalen) + return NF_ACCEPT; + + ret = process_sip_msg(skb, ct, protoff, dataoff, + &dptr, &msglen); + /* process_sip_* functions report why this packet is dropped */ + if (ret != NF_ACCEPT) + break; + diff = msglen - origlen; + tdiff += diff; + + dataoff += msglen; + dptr += msglen; + datalen = datalen + diff - msglen; + } + + if (ret == NF_ACCEPT && ct->status & IPS_NAT_MASK) { + const struct nf_nat_sip_hooks *hooks; + + hooks = rcu_dereference(nf_nat_sip_hooks); + if (hooks) + hooks->seq_adjust(skb, protoff, tdiff); + } + + return ret; +} + +static int sip_help_udp(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + unsigned int dataoff, datalen; + const char *dptr; + + /* No Data ? */ + dataoff = protoff + sizeof(struct udphdr); + if (dataoff >= skb->len) + return NF_ACCEPT; + + nf_ct_refresh(ct, skb, sip_timeout * HZ); + + if (unlikely(skb_linearize(skb))) + return NF_DROP; + + dptr = skb->data + dataoff; + datalen = skb->len - dataoff; + if (datalen < strlen("SIP/2.0 200")) + return NF_ACCEPT; + + return process_sip_msg(skb, ct, protoff, dataoff, &dptr, &datalen); +} + +static struct nf_conntrack_helper sip[MAX_PORTS * 4] __read_mostly; + +static const struct nf_conntrack_expect_policy sip_exp_policy[SIP_EXPECT_MAX + 1] = { + [SIP_EXPECT_SIGNALLING] = { + .name = "signalling", + .max_expected = 1, + .timeout = 3 * 60, + }, + [SIP_EXPECT_AUDIO] = { + .name = "audio", + .max_expected = 2 * IP_CT_DIR_MAX, + .timeout = 3 * 60, + }, + [SIP_EXPECT_VIDEO] = { + .name = "video", + .max_expected = 2 * IP_CT_DIR_MAX, + .timeout = 3 * 60, + }, + [SIP_EXPECT_IMAGE] = { + .name = "image", + .max_expected = IP_CT_DIR_MAX, + .timeout = 3 * 60, + }, +}; + +static void __exit nf_conntrack_sip_fini(void) +{ + nf_conntrack_helpers_unregister(sip, ports_c * 4); +} + +static int __init nf_conntrack_sip_init(void) +{ + int i, ret; + + NF_CT_HELPER_BUILD_BUG_ON(sizeof(struct nf_ct_sip_master)); + + if (ports_c == 0) + ports[ports_c++] = SIP_PORT; + + for (i = 0; i < ports_c; i++) { + nf_ct_helper_init(&sip[4 * i], AF_INET, IPPROTO_UDP, + HELPER_NAME, SIP_PORT, ports[i], i, + sip_exp_policy, SIP_EXPECT_MAX, sip_help_udp, + NULL, THIS_MODULE); + nf_ct_helper_init(&sip[4 * i + 1], AF_INET, IPPROTO_TCP, + HELPER_NAME, SIP_PORT, ports[i], i, + sip_exp_policy, SIP_EXPECT_MAX, sip_help_tcp, + NULL, THIS_MODULE); + nf_ct_helper_init(&sip[4 * i + 2], AF_INET6, IPPROTO_UDP, + HELPER_NAME, SIP_PORT, ports[i], i, + sip_exp_policy, SIP_EXPECT_MAX, sip_help_udp, + NULL, THIS_MODULE); + nf_ct_helper_init(&sip[4 * i + 3], AF_INET6, IPPROTO_TCP, + 
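
/*
 * Aside: a hypothetical userspace sketch (not part of this patch) of the
 * framing rule sip_help_tcp() applies when several SIP messages share one
 * TCP segment: a message ends at the "\r\n\r\n" terminating its headers,
 * plus Content-Length body bytes. The header lookup here is deliberately
 * simplified (case-sensitive, no compact "l:" form); the kernel helper uses
 * ct_sip_get_header() instead.
 */
#define _GNU_SOURCE             /* for memmem() */
#include <stdlib.h>
#include <string.h>

/* Return the length of the first complete message in buf, 0 if incomplete. */
static size_t sip_msg_len(const char *buf, size_t len)
{
        const char *hdr, *crlf2;
        size_t clen = 0, hdrlen;

        crlf2 = memmem(buf, len, "\r\n\r\n", 4);
        if (!crlf2)
                return 0;                       /* headers not complete yet */
        hdrlen = (size_t)(crlf2 - buf) + 4;

        hdr = memmem(buf, hdrlen, "Content-Length:", 15);
        if (hdr)
                clen = strtoul(hdr + 15, NULL, 10);

        return hdrlen + clen <= len ? hdrlen + clen : 0;
}
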
HELPER_NAME, SIP_PORT, ports[i], i, + sip_exp_policy, SIP_EXPECT_MAX, sip_help_tcp, + NULL, THIS_MODULE); + } + + ret = nf_conntrack_helpers_register(sip, ports_c * 4); + if (ret < 0) { + pr_err("failed to register helpers\n"); + return ret; + } + return 0; +} + +module_init(nf_conntrack_sip_init); +module_exit(nf_conntrack_sip_fini); diff --git a/net/netfilter/nf_conntrack_snmp.c b/net/netfilter/nf_conntrack_snmp.c new file mode 100644 index 000000000..daacf2023 --- /dev/null +++ b/net/netfilter/nf_conntrack_snmp.c @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * SNMP service broadcast connection tracking helper + * + * (c) 2011 Jiri Olsa + */ +#include +#include +#include +#include + +#include +#include +#include +#include + +#define SNMP_PORT 161 + +MODULE_AUTHOR("Jiri Olsa "); +MODULE_DESCRIPTION("SNMP service broadcast connection tracking helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NFCT_HELPER("snmp"); + +static unsigned int timeout __read_mostly = 30; +module_param(timeout, uint, 0400); +MODULE_PARM_DESC(timeout, "timeout for master connection/replies in seconds"); + +int (*nf_nat_snmp_hook)(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo); +EXPORT_SYMBOL_GPL(nf_nat_snmp_hook); + +static int snmp_conntrack_help(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + typeof(nf_nat_snmp_hook) nf_nat_snmp; + + nf_conntrack_broadcast_help(skb, ct, ctinfo, timeout); + + nf_nat_snmp = rcu_dereference(nf_nat_snmp_hook); + if (nf_nat_snmp && ct->status & IPS_NAT_MASK) + return nf_nat_snmp(skb, protoff, ct, ctinfo); + + return NF_ACCEPT; +} + +static struct nf_conntrack_expect_policy exp_policy = { + .max_expected = 1, +}; + +static struct nf_conntrack_helper helper __read_mostly = { + .name = "snmp", + .tuple.src.l3num = NFPROTO_IPV4, + .tuple.src.u.udp.port = cpu_to_be16(SNMP_PORT), + .tuple.dst.protonum = IPPROTO_UDP, + .me = THIS_MODULE, + .help = snmp_conntrack_help, + .expect_policy = &exp_policy, +}; + +static int __init nf_conntrack_snmp_init(void) +{ + exp_policy.timeout = timeout; + return nf_conntrack_helper_register(&helper); +} + +static void __exit nf_conntrack_snmp_fini(void) +{ + nf_conntrack_helper_unregister(&helper); +} + +module_init(nf_conntrack_snmp_init); +module_exit(nf_conntrack_snmp_fini); diff --git a/net/netfilter/nf_conntrack_standalone.c b/net/netfilter/nf_conntrack_standalone.c new file mode 100644 index 000000000..52245dbfa --- /dev/null +++ b/net/netfilter/nf_conntrack_standalone.c @@ -0,0 +1,1254 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SYSCTL +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_LWTUNNEL +#include +#endif +#include + +static bool enable_hooks __read_mostly; +MODULE_PARM_DESC(enable_hooks, "Always enable conntrack hooks"); +module_param(enable_hooks, bool, 0000); + +unsigned int nf_conntrack_net_id __read_mostly; + +#ifdef CONFIG_NF_CONNTRACK_PROCFS +void +print_tuple(struct seq_file *s, const struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_l4proto *l4proto) +{ + switch (tuple->src.l3num) { + case NFPROTO_IPV4: + seq_printf(s, "src=%pI4 dst=%pI4 ", + &tuple->src.u3.ip, &tuple->dst.u3.ip); + break; + case NFPROTO_IPV6: + seq_printf(s, "src=%pI6 dst=%pI6 ", + tuple->src.u3.ip6, tuple->dst.u3.ip6); + break; + default: + break; + 
} + + switch (l4proto->l4proto) { + case IPPROTO_ICMP: + seq_printf(s, "type=%u code=%u id=%u ", + tuple->dst.u.icmp.type, + tuple->dst.u.icmp.code, + ntohs(tuple->src.u.icmp.id)); + break; + case IPPROTO_TCP: + seq_printf(s, "sport=%hu dport=%hu ", + ntohs(tuple->src.u.tcp.port), + ntohs(tuple->dst.u.tcp.port)); + break; + case IPPROTO_UDPLITE: + case IPPROTO_UDP: + seq_printf(s, "sport=%hu dport=%hu ", + ntohs(tuple->src.u.udp.port), + ntohs(tuple->dst.u.udp.port)); + + break; + case IPPROTO_DCCP: + seq_printf(s, "sport=%hu dport=%hu ", + ntohs(tuple->src.u.dccp.port), + ntohs(tuple->dst.u.dccp.port)); + break; + case IPPROTO_SCTP: + seq_printf(s, "sport=%hu dport=%hu ", + ntohs(tuple->src.u.sctp.port), + ntohs(tuple->dst.u.sctp.port)); + break; + case IPPROTO_ICMPV6: + seq_printf(s, "type=%u code=%u id=%u ", + tuple->dst.u.icmp.type, + tuple->dst.u.icmp.code, + ntohs(tuple->src.u.icmp.id)); + break; + case IPPROTO_GRE: + seq_printf(s, "srckey=0x%x dstkey=0x%x ", + ntohs(tuple->src.u.gre.key), + ntohs(tuple->dst.u.gre.key)); + break; + default: + break; + } +} +EXPORT_SYMBOL_GPL(print_tuple); + +struct ct_iter_state { + struct seq_net_private p; + struct hlist_nulls_head *hash; + unsigned int htable_size; + unsigned int bucket; + u_int64_t time_now; +}; + +static struct hlist_nulls_node *ct_get_first(struct seq_file *seq) +{ + struct ct_iter_state *st = seq->private; + struct hlist_nulls_node *n; + + for (st->bucket = 0; + st->bucket < st->htable_size; + st->bucket++) { + n = rcu_dereference( + hlist_nulls_first_rcu(&st->hash[st->bucket])); + if (!is_a_nulls(n)) + return n; + } + return NULL; +} + +static struct hlist_nulls_node *ct_get_next(struct seq_file *seq, + struct hlist_nulls_node *head) +{ + struct ct_iter_state *st = seq->private; + + head = rcu_dereference(hlist_nulls_next_rcu(head)); + while (is_a_nulls(head)) { + if (likely(get_nulls_value(head) == st->bucket)) { + if (++st->bucket >= st->htable_size) + return NULL; + } + head = rcu_dereference( + hlist_nulls_first_rcu(&st->hash[st->bucket])); + } + return head; +} + +static struct hlist_nulls_node *ct_get_idx(struct seq_file *seq, loff_t pos) +{ + struct hlist_nulls_node *head = ct_get_first(seq); + + if (head) + while (pos && (head = ct_get_next(seq, head))) + pos--; + return pos ? 
NULL : head; +} + +static void *ct_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct ct_iter_state *st = seq->private; + + st->time_now = ktime_get_real_ns(); + rcu_read_lock(); + + nf_conntrack_get_ht(&st->hash, &st->htable_size); + return ct_get_idx(seq, *pos); +} + +static void *ct_seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + (*pos)++; + return ct_get_next(s, v); +} + +static void ct_seq_stop(struct seq_file *s, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +#ifdef CONFIG_NF_CONNTRACK_SECMARK +static void ct_show_secctx(struct seq_file *s, const struct nf_conn *ct) +{ + int ret; + u32 len; + char *secctx; + + ret = security_secid_to_secctx(ct->secmark, &secctx, &len); + if (ret) + return; + + seq_printf(s, "secctx=%s ", secctx); + + security_release_secctx(secctx, len); +} +#else +static inline void ct_show_secctx(struct seq_file *s, const struct nf_conn *ct) +{ +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_ZONES +static void ct_show_zone(struct seq_file *s, const struct nf_conn *ct, + int dir) +{ + const struct nf_conntrack_zone *zone = nf_ct_zone(ct); + + if (zone->dir != dir) + return; + switch (zone->dir) { + case NF_CT_DEFAULT_ZONE_DIR: + seq_printf(s, "zone=%u ", zone->id); + break; + case NF_CT_ZONE_DIR_ORIG: + seq_printf(s, "zone-orig=%u ", zone->id); + break; + case NF_CT_ZONE_DIR_REPL: + seq_printf(s, "zone-reply=%u ", zone->id); + break; + default: + break; + } +} +#else +static inline void ct_show_zone(struct seq_file *s, const struct nf_conn *ct, + int dir) +{ +} +#endif + +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP +static void ct_show_delta_time(struct seq_file *s, const struct nf_conn *ct) +{ + struct ct_iter_state *st = s->private; + struct nf_conn_tstamp *tstamp; + s64 delta_time; + + tstamp = nf_conn_tstamp_find(ct); + if (tstamp) { + delta_time = st->time_now - tstamp->start; + if (delta_time > 0) + delta_time = div_s64(delta_time, NSEC_PER_SEC); + else + delta_time = 0; + + seq_printf(s, "delta-time=%llu ", + (unsigned long long)delta_time); + } + return; +} +#else +static inline void +ct_show_delta_time(struct seq_file *s, const struct nf_conn *ct) +{ +} +#endif + +static const char* l3proto_name(u16 proto) +{ + switch (proto) { + case AF_INET: return "ipv4"; + case AF_INET6: return "ipv6"; + } + + return "unknown"; +} + +static const char* l4proto_name(u16 proto) +{ + switch (proto) { + case IPPROTO_ICMP: return "icmp"; + case IPPROTO_TCP: return "tcp"; + case IPPROTO_UDP: return "udp"; + case IPPROTO_DCCP: return "dccp"; + case IPPROTO_GRE: return "gre"; + case IPPROTO_SCTP: return "sctp"; + case IPPROTO_UDPLITE: return "udplite"; + case IPPROTO_ICMPV6: return "icmpv6"; + } + + return "unknown"; +} + +static unsigned int +seq_print_acct(struct seq_file *s, const struct nf_conn *ct, int dir) +{ + struct nf_conn_acct *acct; + struct nf_conn_counter *counter; + + acct = nf_conn_acct_find(ct); + if (!acct) + return 0; + + counter = acct->counter; + seq_printf(s, "packets=%llu bytes=%llu ", + (unsigned long long)atomic64_read(&counter[dir].packets), + (unsigned long long)atomic64_read(&counter[dir].bytes)); + + return 0; +} + +/* return 0 on success, 1 in case of error */ +static int ct_seq_show(struct seq_file *s, void *v) +{ + struct nf_conntrack_tuple_hash *hash = v; + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(hash); + const struct nf_conntrack_l4proto *l4proto; + struct net *net = seq_file_net(s); + int ret = 0; + + WARN_ON(!ct); + if (unlikely(!refcount_inc_not_zero(&ct->ct_general.use))) + return 0; + + /* load ->status 
after refcount increase */ + smp_acquire__after_ctrl_dep(); + + if (nf_ct_should_gc(ct)) { + nf_ct_kill(ct); + goto release; + } + + /* we only want to print DIR_ORIGINAL */ + if (NF_CT_DIRECTION(hash)) + goto release; + + if (!net_eq(nf_ct_net(ct), net)) + goto release; + + l4proto = nf_ct_l4proto_find(nf_ct_protonum(ct)); + + ret = -ENOSPC; + seq_printf(s, "%-8s %u %-8s %u ", + l3proto_name(nf_ct_l3num(ct)), nf_ct_l3num(ct), + l4proto_name(l4proto->l4proto), nf_ct_protonum(ct)); + + if (!test_bit(IPS_OFFLOAD_BIT, &ct->status)) + seq_printf(s, "%ld ", nf_ct_expires(ct) / HZ); + + if (l4proto->print_conntrack) + l4proto->print_conntrack(s, ct); + + print_tuple(s, &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple, + l4proto); + + ct_show_zone(s, ct, NF_CT_ZONE_DIR_ORIG); + + if (seq_has_overflowed(s)) + goto release; + + if (seq_print_acct(s, ct, IP_CT_DIR_ORIGINAL)) + goto release; + + if (!(test_bit(IPS_SEEN_REPLY_BIT, &ct->status))) + seq_puts(s, "[UNREPLIED] "); + + print_tuple(s, &ct->tuplehash[IP_CT_DIR_REPLY].tuple, l4proto); + + ct_show_zone(s, ct, NF_CT_ZONE_DIR_REPL); + + if (seq_print_acct(s, ct, IP_CT_DIR_REPLY)) + goto release; + + if (test_bit(IPS_HW_OFFLOAD_BIT, &ct->status)) + seq_puts(s, "[HW_OFFLOAD] "); + else if (test_bit(IPS_OFFLOAD_BIT, &ct->status)) + seq_puts(s, "[OFFLOAD] "); + else if (test_bit(IPS_ASSURED_BIT, &ct->status)) + seq_puts(s, "[ASSURED] "); + + if (seq_has_overflowed(s)) + goto release; + +#if defined(CONFIG_NF_CONNTRACK_MARK) + seq_printf(s, "mark=%u ", READ_ONCE(ct->mark)); +#endif + + ct_show_secctx(s, ct); + ct_show_zone(s, ct, NF_CT_DEFAULT_ZONE_DIR); + ct_show_delta_time(s, ct); + + seq_printf(s, "use=%u\n", refcount_read(&ct->ct_general.use)); + + if (seq_has_overflowed(s)) + goto release; + + ret = 0; +release: + nf_ct_put(ct); + return ret; +} + +static const struct seq_operations ct_seq_ops = { + .start = ct_seq_start, + .next = ct_seq_next, + .stop = ct_seq_stop, + .show = ct_seq_show +}; + +static void *ct_cpu_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + int cpu; + + if (*pos == 0) + return SEQ_START_TOKEN; + + for (cpu = *pos-1; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(net->ct.stat, cpu); + } + + return NULL; +} + +static void *ct_cpu_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + int cpu; + + for (cpu = *pos; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(net->ct.stat, cpu); + } + (*pos)++; + return NULL; +} + +static void ct_cpu_seq_stop(struct seq_file *seq, void *v) +{ +} + +static int ct_cpu_seq_show(struct seq_file *seq, void *v) +{ + struct net *net = seq_file_net(seq); + const struct ip_conntrack_stat *st = v; + unsigned int nr_conntracks; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "entries clashres found new invalid ignore delete chainlength insert insert_failed drop early_drop icmp_error expect_new expect_create expect_delete search_restart\n"); + return 0; + } + + nr_conntracks = nf_conntrack_count(net); + + seq_printf(seq, "%08x %08x %08x %08x %08x %08x %08x %08x " + "%08x %08x %08x %08x %08x %08x %08x %08x %08x\n", + nr_conntracks, + st->clash_resolve, + st->found, + 0, + st->invalid, + 0, + 0, + st->chaintoolong, + st->insert, + st->insert_failed, + st->drop, + st->early_drop, + st->error, + + st->expect_new, + st->expect_create, + st->expect_delete, + st->search_restart + ); + return 0; +} + +static const struct 
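
/*
 * Aside: with CONFIG_NF_CONNTRACK_PROCFS the seq_file above renders one line
 * per entry in /proc/net/nf_conntrack. A representative line (accounting
 * disabled, mark compiled in, addresses from the documentation ranges):
 *
 *   ipv4     2 tcp      6 431999 ESTABLISHED src=192.0.2.10 dst=198.51.100.7 sport=42318 dport=443 src=198.51.100.7 dst=192.0.2.10 sport=443 dport=42318 [ASSURED] mark=0 use=1
 *
 * A hypothetical userspace reader (not part of this patch) that tallies
 * entries by transport protocol:
 */
#include <stdio.h>
#include <string.h>

int main(void)
{
        char line[1024], l3[16], l4[16];
        unsigned long tcp = 0, udp = 0, other = 0;
        FILE *f = fopen("/proc/net/nf_conntrack", "r");

        if (!f)
                return 1;
        while (fgets(line, sizeof(line), f)) {
                /* fields: l3 name, l3 number (skipped), l4 name */
                if (sscanf(line, "%15s %*u %15s", l3, l4) != 2)
                        continue;
                if (!strcmp(l4, "tcp"))
                        tcp++;
                else if (!strcmp(l4, "udp"))
                        udp++;
                else
                        other++;
        }
        fclose(f);
        printf("tcp=%lu udp=%lu other=%lu\n", tcp, udp, other);
        return 0;
}
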
seq_operations ct_cpu_seq_ops = { + .start = ct_cpu_seq_start, + .next = ct_cpu_seq_next, + .stop = ct_cpu_seq_stop, + .show = ct_cpu_seq_show, +}; + +static int nf_conntrack_standalone_init_proc(struct net *net) +{ + struct proc_dir_entry *pde; + kuid_t root_uid; + kgid_t root_gid; + + pde = proc_create_net("nf_conntrack", 0440, net->proc_net, &ct_seq_ops, + sizeof(struct ct_iter_state)); + if (!pde) + goto out_nf_conntrack; + + root_uid = make_kuid(net->user_ns, 0); + root_gid = make_kgid(net->user_ns, 0); + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(pde, root_uid, root_gid); + + pde = proc_create_net("nf_conntrack", 0444, net->proc_net_stat, + &ct_cpu_seq_ops, sizeof(struct seq_net_private)); + if (!pde) + goto out_stat_nf_conntrack; + return 0; + +out_stat_nf_conntrack: + remove_proc_entry("nf_conntrack", net->proc_net); +out_nf_conntrack: + return -ENOMEM; +} + +static void nf_conntrack_standalone_fini_proc(struct net *net) +{ + remove_proc_entry("nf_conntrack", net->proc_net_stat); + remove_proc_entry("nf_conntrack", net->proc_net); +} +#else +static int nf_conntrack_standalone_init_proc(struct net *net) +{ + return 0; +} + +static void nf_conntrack_standalone_fini_proc(struct net *net) +{ +} +#endif /* CONFIG_NF_CONNTRACK_PROCFS */ + +u32 nf_conntrack_count(const struct net *net) +{ + const struct nf_conntrack_net *cnet = nf_ct_pernet(net); + + return atomic_read(&cnet->count); +} +EXPORT_SYMBOL_GPL(nf_conntrack_count); + +/* Sysctl support */ + +#ifdef CONFIG_SYSCTL +/* size the user *wants to set */ +static unsigned int nf_conntrack_htable_size_user __read_mostly; + +static int +nf_conntrack_hash_sysctl(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int ret; + + /* module_param hashsize could have changed value */ + nf_conntrack_htable_size_user = nf_conntrack_htable_size; + + ret = proc_dointvec(table, write, buffer, lenp, ppos); + if (ret < 0 || !write) + return ret; + + /* update ret, we might not be able to satisfy request */ + ret = nf_conntrack_hash_resize(nf_conntrack_htable_size_user); + + /* update it to the actual value used by conntrack */ + nf_conntrack_htable_size_user = nf_conntrack_htable_size; + return ret; +} + +static struct ctl_table_header *nf_ct_netfilter_header; + +enum nf_ct_sysctl_index { + NF_SYSCTL_CT_MAX, + NF_SYSCTL_CT_COUNT, + NF_SYSCTL_CT_BUCKETS, + NF_SYSCTL_CT_CHECKSUM, + NF_SYSCTL_CT_LOG_INVALID, + NF_SYSCTL_CT_EXPECT_MAX, + NF_SYSCTL_CT_ACCT, +#ifdef CONFIG_NF_CONNTRACK_EVENTS + NF_SYSCTL_CT_EVENTS, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + NF_SYSCTL_CT_TIMESTAMP, +#endif + NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_RECV, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ESTABLISHED, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_FIN_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_LAST_ACK, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_TIME_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_RETRANS, + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_UNACK, +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_OFFLOAD, +#endif + NF_SYSCTL_CT_PROTO_TCP_LOOSE, + NF_SYSCTL_CT_PROTO_TCP_LIBERAL, + NF_SYSCTL_CT_PROTO_TCP_IGNORE_INVALID_RST, + NF_SYSCTL_CT_PROTO_TCP_MAX_RETRANS, + NF_SYSCTL_CT_PROTO_TIMEOUT_UDP, + NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM, +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_OFFLOAD, +#endif + NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP, + 
NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6, +#ifdef CONFIG_NF_CT_PROTO_SCTP + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_CLOSED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_WAIT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_ECHOED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ESTABLISHED, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_RECD, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT, + NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_SENT, +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_RESPOND, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_PARTOPEN, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_OPEN, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSEREQ, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSING, + NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_TIMEWAIT, + NF_SYSCTL_CT_PROTO_DCCP_LOOSE, +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + NF_SYSCTL_CT_PROTO_TIMEOUT_GRE, + NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM, +#endif +#ifdef CONFIG_LWTUNNEL + NF_SYSCTL_CT_LWTUNNEL, +#endif + + __NF_SYSCTL_CT_LAST_SYSCTL, +}; + +#define NF_SYSCTL_CT_LAST_SYSCTL (__NF_SYSCTL_CT_LAST_SYSCTL + 1) + +static struct ctl_table nf_ct_sysctl_table[] = { + [NF_SYSCTL_CT_MAX] = { + .procname = "nf_conntrack_max", + .data = &nf_conntrack_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_COUNT] = { + .procname = "nf_conntrack_count", + .maxlen = sizeof(int), + .mode = 0444, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_BUCKETS] = { + .procname = "nf_conntrack_buckets", + .data = &nf_conntrack_htable_size_user, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = nf_conntrack_hash_sysctl, + }, + [NF_SYSCTL_CT_CHECKSUM] = { + .procname = "nf_conntrack_checksum", + .data = &init_net.ct.sysctl_checksum, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + [NF_SYSCTL_CT_LOG_INVALID] = { + .procname = "nf_conntrack_log_invalid", + .data = &init_net.ct.sysctl_log_invalid, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + [NF_SYSCTL_CT_EXPECT_MAX] = { + .procname = "nf_conntrack_expect_max", + .data = &nf_ct_expect_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + [NF_SYSCTL_CT_ACCT] = { + .procname = "nf_conntrack_acct", + .data = &init_net.ct.sysctl_acct, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#ifdef CONFIG_NF_CONNTRACK_EVENTS + [NF_SYSCTL_CT_EVENTS] = { + .procname = "nf_conntrack_events", + .data = &init_net.ct.sysctl_events, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + [NF_SYSCTL_CT_TIMESTAMP] = { + .procname = "nf_conntrack_timestamp", + .data = &init_net.ct.sysctl_tstamp, + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif + [NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC] = { + .procname = "nf_conntrack_generic_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_SENT] = { + .procname = "nf_conntrack_tcp_timeout_syn_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_SYN_RECV] = { + .procname = 
"nf_conntrack_tcp_timeout_syn_recv", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ESTABLISHED] = { + .procname = "nf_conntrack_tcp_timeout_established", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_FIN_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_fin_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_close_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_LAST_ACK] = { + .procname = "nf_conntrack_tcp_timeout_last_ack", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_TIME_WAIT] = { + .procname = "nf_conntrack_tcp_timeout_time_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_CLOSE] = { + .procname = "nf_conntrack_tcp_timeout_close", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_RETRANS] = { + .procname = "nf_conntrack_tcp_timeout_max_retrans", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_UNACK] = { + .procname = "nf_conntrack_tcp_timeout_unacknowledged", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + [NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_OFFLOAD] = { + .procname = "nf_flowtable_tcp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif + [NF_SYSCTL_CT_PROTO_TCP_LOOSE] = { + .procname = "nf_conntrack_tcp_loose", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + [NF_SYSCTL_CT_PROTO_TCP_LIBERAL] = { + .procname = "nf_conntrack_tcp_be_liberal", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + [NF_SYSCTL_CT_PROTO_TCP_IGNORE_INVALID_RST] = { + .procname = "nf_conntrack_tcp_ignore_invalid_rst", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, + [NF_SYSCTL_CT_PROTO_TCP_MAX_RETRANS] = { + .procname = "nf_conntrack_tcp_max_retrans", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_UDP] = { + .procname = "nf_conntrack_udp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM] = { + .procname = "nf_conntrack_udp_timeout_stream", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + [NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_OFFLOAD] = { + .procname = "nf_flowtable_udp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif + [NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP] = { + .procname = "nf_conntrack_icmp_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + 
[NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6] = { + .procname = "nf_conntrack_icmpv6_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#ifdef CONFIG_NF_CT_PROTO_SCTP + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_CLOSED] = { + .procname = "nf_conntrack_sctp_timeout_closed", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_WAIT] = { + .procname = "nf_conntrack_sctp_timeout_cookie_wait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_COOKIE_ECHOED] = { + .procname = "nf_conntrack_sctp_timeout_cookie_echoed", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ESTABLISHED] = { + .procname = "nf_conntrack_sctp_timeout_established", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_SENT] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_RECD] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_recd", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_SHUTDOWN_ACK_SENT] = { + .procname = "nf_conntrack_sctp_timeout_shutdown_ack_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_HEARTBEAT_SENT] = { + .procname = "nf_conntrack_sctp_timeout_heartbeat_sent", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif +#ifdef CONFIG_NF_CT_PROTO_DCCP + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_REQUEST] = { + .procname = "nf_conntrack_dccp_timeout_request", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_RESPOND] = { + .procname = "nf_conntrack_dccp_timeout_respond", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_PARTOPEN] = { + .procname = "nf_conntrack_dccp_timeout_partopen", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_OPEN] = { + .procname = "nf_conntrack_dccp_timeout_open", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSEREQ] = { + .procname = "nf_conntrack_dccp_timeout_closereq", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_CLOSING] = { + .procname = "nf_conntrack_dccp_timeout_closing", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_TIMEWAIT] = { + .procname = "nf_conntrack_dccp_timeout_timewait", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_DCCP_LOOSE] = { + .procname = "nf_conntrack_dccp_loose", + .maxlen = sizeof(u8), + .mode = 0644, + .proc_handler = proc_dou8vec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif +#ifdef CONFIG_NF_CT_PROTO_GRE + [NF_SYSCTL_CT_PROTO_TIMEOUT_GRE] = { + 
.procname = "nf_conntrack_gre_timeout", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + [NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM] = { + .procname = "nf_conntrack_gre_timeout_stream", + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, +#endif +#ifdef CONFIG_LWTUNNEL + [NF_SYSCTL_CT_LWTUNNEL] = { + .procname = "nf_hooks_lwtunnel", + .data = NULL, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = nf_hooks_lwtunnel_sysctl_handler, + }, +#endif + {} +}; + +static struct ctl_table nf_ct_netfilter_table[] = { + { + .procname = "nf_conntrack_max", + .data = &nf_conntrack_max, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +static void nf_conntrack_standalone_init_tcp_sysctl(struct net *net, + struct ctl_table *table) +{ + struct nf_tcp_net *tn = nf_tcp_pernet(net); + +#define XASSIGN(XNAME, tn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_ ## XNAME].data = \ + &(tn)->timeouts[TCP_CONNTRACK_ ## XNAME] + + XASSIGN(SYN_SENT, tn); + XASSIGN(SYN_RECV, tn); + XASSIGN(ESTABLISHED, tn); + XASSIGN(FIN_WAIT, tn); + XASSIGN(CLOSE_WAIT, tn); + XASSIGN(LAST_ACK, tn); + XASSIGN(TIME_WAIT, tn); + XASSIGN(CLOSE, tn); + XASSIGN(RETRANS, tn); + XASSIGN(UNACK, tn); +#undef XASSIGN +#define XASSIGN(XNAME, rval) \ + table[NF_SYSCTL_CT_PROTO_TCP_ ## XNAME].data = (rval) + + XASSIGN(LOOSE, &tn->tcp_loose); + XASSIGN(LIBERAL, &tn->tcp_be_liberal); + XASSIGN(MAX_RETRANS, &tn->tcp_max_retrans); + XASSIGN(IGNORE_INVALID_RST, &tn->tcp_ignore_invalid_rst); +#undef XASSIGN + +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + table[NF_SYSCTL_CT_PROTO_TIMEOUT_TCP_OFFLOAD].data = &tn->offload_timeout; +#endif + +} + +static void nf_conntrack_standalone_init_sctp_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_SCTP + struct nf_sctp_net *sn = nf_sctp_pernet(net); + +#define XASSIGN(XNAME, sn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_SCTP_ ## XNAME].data = \ + &(sn)->timeouts[SCTP_CONNTRACK_ ## XNAME] + + XASSIGN(CLOSED, sn); + XASSIGN(COOKIE_WAIT, sn); + XASSIGN(COOKIE_ECHOED, sn); + XASSIGN(ESTABLISHED, sn); + XASSIGN(SHUTDOWN_SENT, sn); + XASSIGN(SHUTDOWN_RECD, sn); + XASSIGN(SHUTDOWN_ACK_SENT, sn); + XASSIGN(HEARTBEAT_SENT, sn); +#undef XASSIGN +#endif +} + +static void nf_conntrack_standalone_init_dccp_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_DCCP + struct nf_dccp_net *dn = nf_dccp_pernet(net); + +#define XASSIGN(XNAME, dn) \ + table[NF_SYSCTL_CT_PROTO_TIMEOUT_DCCP_ ## XNAME].data = \ + &(dn)->dccp_timeout[CT_DCCP_ ## XNAME] + + XASSIGN(REQUEST, dn); + XASSIGN(RESPOND, dn); + XASSIGN(PARTOPEN, dn); + XASSIGN(OPEN, dn); + XASSIGN(CLOSEREQ, dn); + XASSIGN(CLOSING, dn); + XASSIGN(TIMEWAIT, dn); +#undef XASSIGN + + table[NF_SYSCTL_CT_PROTO_DCCP_LOOSE].data = &dn->dccp_loose; +#endif +} + +static void nf_conntrack_standalone_init_gre_sysctl(struct net *net, + struct ctl_table *table) +{ +#ifdef CONFIG_NF_CT_PROTO_GRE + struct nf_gre_net *gn = nf_gre_pernet(net); + + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GRE].data = &gn->timeouts[GRE_CT_UNREPLIED]; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GRE_STREAM].data = &gn->timeouts[GRE_CT_REPLIED]; +#endif +} + +static int nf_conntrack_standalone_init_sysctl(struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + struct nf_udp_net *un = nf_udp_pernet(net); + struct ctl_table *table; + + BUILD_BUG_ON(ARRAY_SIZE(nf_ct_sysctl_table) != NF_SYSCTL_CT_LAST_SYSCTL); + + table = 
kmemdup(nf_ct_sysctl_table, sizeof(nf_ct_sysctl_table), + GFP_KERNEL); + if (!table) + return -ENOMEM; + + table[NF_SYSCTL_CT_COUNT].data = &cnet->count; + table[NF_SYSCTL_CT_CHECKSUM].data = &net->ct.sysctl_checksum; + table[NF_SYSCTL_CT_LOG_INVALID].data = &net->ct.sysctl_log_invalid; + table[NF_SYSCTL_CT_ACCT].data = &net->ct.sysctl_acct; +#ifdef CONFIG_NF_CONNTRACK_EVENTS + table[NF_SYSCTL_CT_EVENTS].data = &net->ct.sysctl_events; +#endif +#ifdef CONFIG_NF_CONNTRACK_TIMESTAMP + table[NF_SYSCTL_CT_TIMESTAMP].data = &net->ct.sysctl_tstamp; +#endif + table[NF_SYSCTL_CT_PROTO_TIMEOUT_GENERIC].data = &nf_generic_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_ICMP].data = &nf_icmp_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_ICMPV6].data = &nf_icmpv6_pernet(net)->timeout; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_UDP].data = &un->timeouts[UDP_CT_UNREPLIED]; + table[NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_STREAM].data = &un->timeouts[UDP_CT_REPLIED]; +#if IS_ENABLED(CONFIG_NF_FLOW_TABLE) + table[NF_SYSCTL_CT_PROTO_TIMEOUT_UDP_OFFLOAD].data = &un->offload_timeout; +#endif + + nf_conntrack_standalone_init_tcp_sysctl(net, table); + nf_conntrack_standalone_init_sctp_sysctl(net, table); + nf_conntrack_standalone_init_dccp_sysctl(net, table); + nf_conntrack_standalone_init_gre_sysctl(net, table); + + /* Don't allow non-init_net ns to alter global sysctls */ + if (!net_eq(&init_net, net)) { + table[NF_SYSCTL_CT_MAX].mode = 0444; + table[NF_SYSCTL_CT_EXPECT_MAX].mode = 0444; + table[NF_SYSCTL_CT_BUCKETS].mode = 0444; + } + + cnet->sysctl_header = register_net_sysctl(net, "net/netfilter", table); + if (!cnet->sysctl_header) + goto out_unregister_netfilter; + + return 0; + +out_unregister_netfilter: + kfree(table); + return -ENOMEM; +} + +static void nf_conntrack_standalone_fini_sysctl(struct net *net) +{ + struct nf_conntrack_net *cnet = nf_ct_pernet(net); + struct ctl_table *table; + + table = cnet->sysctl_header->ctl_table_arg; + unregister_net_sysctl_table(cnet->sysctl_header); + kfree(table); +} +#else +static int nf_conntrack_standalone_init_sysctl(struct net *net) +{ + return 0; +} + +static void nf_conntrack_standalone_fini_sysctl(struct net *net) +{ +} +#endif /* CONFIG_SYSCTL */ + +static void nf_conntrack_fini_net(struct net *net) +{ + if (enable_hooks) + nf_ct_netns_put(net, NFPROTO_INET); + + nf_conntrack_standalone_fini_proc(net); + nf_conntrack_standalone_fini_sysctl(net); +} + +static int nf_conntrack_pernet_init(struct net *net) +{ + int ret; + + net->ct.sysctl_checksum = 1; + + ret = nf_conntrack_standalone_init_sysctl(net); + if (ret < 0) + return ret; + + ret = nf_conntrack_standalone_init_proc(net); + if (ret < 0) + goto out_proc; + + ret = nf_conntrack_init_net(net); + if (ret < 0) + goto out_init_net; + + if (enable_hooks) { + ret = nf_ct_netns_get(net, NFPROTO_INET); + if (ret < 0) + goto out_hooks; + } + + return 0; + +out_hooks: + nf_conntrack_cleanup_net(net); +out_init_net: + nf_conntrack_standalone_fini_proc(net); +out_proc: + nf_conntrack_standalone_fini_sysctl(net); + return ret; +} + +static void nf_conntrack_pernet_exit(struct list_head *net_exit_list) +{ + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) + nf_conntrack_fini_net(net); + + nf_conntrack_cleanup_net_list(net_exit_list); +} + +static struct pernet_operations nf_conntrack_net_ops = { + .init = nf_conntrack_pernet_init, + .exit_batch = nf_conntrack_pernet_exit, + .id = &nf_conntrack_net_id, + .size = sizeof(struct nf_conntrack_net), +}; + +static int __init 
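
/*
 * Aside: the per-netns table above is registered under "net/netfilter", so
 * every .procname appears as /proc/sys/net/netfilter/<procname>
 * (nf_conntrack_max, nf_conntrack_count, nf_conntrack_buckets,
 * nf_conntrack_tcp_timeout_established, ...). A hypothetical userspace
 * sketch (not part of this patch) reading one of them:
 */
#include <stdio.h>

static long read_ct_sysctl(const char *name)
{
        char path[256];
        long val;
        FILE *f;

        snprintf(path, sizeof(path), "/proc/sys/net/netfilter/%s", name);
        f = fopen(path, "r");
        if (!f)
                return -1;
        if (fscanf(f, "%ld", &val) != 1)
                val = -1;
        fclose(f);
        return val;
}
/* e.g. compare read_ct_sysctl("nf_conntrack_count") against
 * read_ct_sysctl("nf_conntrack_max") to gauge table pressure.
 */
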
nf_conntrack_standalone_init(void) +{ + int ret = nf_conntrack_init_start(); + if (ret < 0) + goto out_start; + + BUILD_BUG_ON(NFCT_INFOMASK <= IP_CT_NUMBER); + +#ifdef CONFIG_SYSCTL + nf_ct_netfilter_header = + register_net_sysctl(&init_net, "net", nf_ct_netfilter_table); + if (!nf_ct_netfilter_header) { + pr_err("nf_conntrack: can't register to sysctl.\n"); + ret = -ENOMEM; + goto out_sysctl; + } + + nf_conntrack_htable_size_user = nf_conntrack_htable_size; +#endif + + nf_conntrack_init_end(); + + ret = register_pernet_subsys(&nf_conntrack_net_ops); + if (ret < 0) + goto out_pernet; + + return 0; + +out_pernet: +#ifdef CONFIG_SYSCTL + unregister_net_sysctl_table(nf_ct_netfilter_header); +out_sysctl: +#endif + nf_conntrack_cleanup_end(); +out_start: + return ret; +} + +static void __exit nf_conntrack_standalone_fini(void) +{ + nf_conntrack_cleanup_start(); + unregister_pernet_subsys(&nf_conntrack_net_ops); +#ifdef CONFIG_SYSCTL + unregister_net_sysctl_table(nf_ct_netfilter_header); +#endif + nf_conntrack_cleanup_end(); +} + +module_init(nf_conntrack_standalone_init); +module_exit(nf_conntrack_standalone_fini); diff --git a/net/netfilter/nf_conntrack_tftp.c b/net/netfilter/nf_conntrack_tftp.c new file mode 100644 index 000000000..80ee53f29 --- /dev/null +++ b/net/netfilter/nf_conntrack_tftp.c @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 2001-2002 Magnus Boden + * (C) 2006-2012 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#define HELPER_NAME "tftp" + +MODULE_AUTHOR("Magnus Boden "); +MODULE_DESCRIPTION("TFTP connection tracking helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ip_conntrack_tftp"); +MODULE_ALIAS_NFCT_HELPER(HELPER_NAME); + +#define MAX_PORTS 8 +static unsigned short ports[MAX_PORTS]; +static unsigned int ports_c; +module_param_array(ports, ushort, &ports_c, 0400); +MODULE_PARM_DESC(ports, "Port numbers of TFTP servers"); + +unsigned int (*nf_nat_tftp_hook)(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + struct nf_conntrack_expect *exp) __read_mostly; +EXPORT_SYMBOL_GPL(nf_nat_tftp_hook); + +static int tftp_help(struct sk_buff *skb, + unsigned int protoff, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + const struct tftphdr *tfh; + struct tftphdr _tftph; + struct nf_conntrack_expect *exp; + struct nf_conntrack_tuple *tuple; + unsigned int ret = NF_ACCEPT; + typeof(nf_nat_tftp_hook) nf_nat_tftp; + + tfh = skb_header_pointer(skb, protoff + sizeof(struct udphdr), + sizeof(_tftph), &_tftph); + if (tfh == NULL) + return NF_ACCEPT; + + switch (ntohs(tfh->opcode)) { + case TFTP_OPCODE_READ: + case TFTP_OPCODE_WRITE: + /* RRQ and WRQ works the same way */ + nf_ct_dump_tuple(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); + nf_ct_dump_tuple(&ct->tuplehash[IP_CT_DIR_REPLY].tuple); + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + nf_ct_helper_log(skb, ct, "cannot alloc expectation"); + return NF_DROP; + } + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, + nf_ct_l3num(ct), + &tuple->src.u3, &tuple->dst.u3, + IPPROTO_UDP, NULL, &tuple->dst.u.udp.port); + + pr_debug("expect: "); + nf_ct_dump_tuple(&exp->tuple); + + nf_nat_tftp = rcu_dereference(nf_nat_tftp_hook); + if (nf_nat_tftp && ct->status & IPS_NAT_MASK) + ret = nf_nat_tftp(skb, ctinfo, exp); + else if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, ct, "cannot add expectation"); + ret 
= NF_DROP; + } + nf_ct_expect_put(exp); + break; + case TFTP_OPCODE_DATA: + case TFTP_OPCODE_ACK: + pr_debug("Data/ACK opcode\n"); + break; + case TFTP_OPCODE_ERROR: + pr_debug("Error opcode\n"); + break; + default: + pr_debug("Unknown opcode\n"); + } + return ret; +} + +static struct nf_conntrack_helper tftp[MAX_PORTS * 2] __read_mostly; + +static const struct nf_conntrack_expect_policy tftp_exp_policy = { + .max_expected = 1, + .timeout = 5 * 60, +}; + +static void __exit nf_conntrack_tftp_fini(void) +{ + nf_conntrack_helpers_unregister(tftp, ports_c * 2); +} + +static int __init nf_conntrack_tftp_init(void) +{ + int i, ret; + + NF_CT_HELPER_BUILD_BUG_ON(0); + + if (ports_c == 0) + ports[ports_c++] = TFTP_PORT; + + for (i = 0; i < ports_c; i++) { + nf_ct_helper_init(&tftp[2 * i], AF_INET, IPPROTO_UDP, + HELPER_NAME, TFTP_PORT, ports[i], i, + &tftp_exp_policy, 0, tftp_help, NULL, + THIS_MODULE); + nf_ct_helper_init(&tftp[2 * i + 1], AF_INET6, IPPROTO_UDP, + HELPER_NAME, TFTP_PORT, ports[i], i, + &tftp_exp_policy, 0, tftp_help, NULL, + THIS_MODULE); + } + + ret = nf_conntrack_helpers_register(tftp, ports_c * 2); + if (ret < 0) { + pr_err("failed to register helpers\n"); + return ret; + } + return 0; +} + +module_init(nf_conntrack_tftp_init); +module_exit(nf_conntrack_tftp_fini); diff --git a/net/netfilter/nf_conntrack_timeout.c b/net/netfilter/nf_conntrack_timeout.c new file mode 100644 index 000000000..0cc584d3d --- /dev/null +++ b/net/netfilter/nf_conntrack_timeout.c @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2012 by Pablo Neira Ayuso + * (C) 2012 by Vyatta Inc. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +const struct nf_ct_timeout_hooks __rcu *nf_ct_timeout_hook __read_mostly; +EXPORT_SYMBOL_GPL(nf_ct_timeout_hook); + +static int untimeout(struct nf_conn *ct, void *timeout) +{ + struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct); + + if (timeout_ext) { + const struct nf_ct_timeout *t; + + t = rcu_access_pointer(timeout_ext->timeout); + + if (!timeout || t == timeout) + RCU_INIT_POINTER(timeout_ext->timeout, NULL); + } + + /* We are not intended to delete this conntrack. 
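
/*
 * Aside on the TFTP helper above: a read/write request travels from the
 * client's ephemeral port to the well-known server port, but per RFC 1350
 * the server answers from a freshly chosen UDP port (its transfer ID), so
 * without the expectation registered in tftp_help() (any server source port
 * toward the client's port) the data packets would not match the original
 * tuple. A hypothetical userspace sketch (not part of this patch) building
 * the RRQ packet the helper reacts to:
 */
#include <arpa/inet.h>
#include <string.h>

/* Fill buf with "| 01 | filename | 0 | mode | 0 |"; return length or 0. */
static size_t tftp_build_rrq(unsigned char *buf, size_t size,
                             const char *filename, const char *mode)
{
        size_t need = 2 + strlen(filename) + 1 + strlen(mode) + 1;
        unsigned short opcode = htons(1);       /* TFTP_OPCODE_READ */

        if (need > size)
                return 0;
        memcpy(buf, &opcode, 2);
        strcpy((char *)buf + 2, filename);
        strcpy((char *)buf + 2 + strlen(filename) + 1, mode);
        return need;
}
/* typically sent with mode "octet" to udp/69; data then arrives from a new port */
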
*/ + return 0; +} + +void nf_ct_untimeout(struct net *net, struct nf_ct_timeout *timeout) +{ + struct nf_ct_iter_data iter_data = { + .net = net, + .data = timeout, + }; + + nf_ct_iterate_cleanup_net(untimeout, &iter_data); +} +EXPORT_SYMBOL_GPL(nf_ct_untimeout); + +static void __nf_ct_timeout_put(struct nf_ct_timeout *timeout) +{ + const struct nf_ct_timeout_hooks *h = rcu_dereference(nf_ct_timeout_hook); + + if (h) + h->timeout_put(timeout); +} + +int nf_ct_set_timeout(struct net *net, struct nf_conn *ct, + u8 l3num, u8 l4num, const char *timeout_name) +{ + const struct nf_ct_timeout_hooks *h; + struct nf_ct_timeout *timeout; + struct nf_conn_timeout *timeout_ext; + const char *errmsg = NULL; + int ret = 0; + + rcu_read_lock(); + h = rcu_dereference(nf_ct_timeout_hook); + if (!h) { + ret = -ENOENT; + errmsg = "Timeout policy base is empty"; + goto out; + } + + timeout = h->timeout_find_get(net, timeout_name); + if (!timeout) { + ret = -ENOENT; + pr_info_ratelimited("No such timeout policy \"%s\"\n", + timeout_name); + goto out; + } + + if (timeout->l3num != l3num) { + ret = -EINVAL; + pr_info_ratelimited("Timeout policy `%s' can only be used by " + "L%d protocol number %d\n", + timeout_name, 3, timeout->l3num); + goto err_put_timeout; + } + /* Make sure the timeout policy matches any existing protocol tracker, + * otherwise default to generic. + */ + if (timeout->l4proto->l4proto != l4num) { + ret = -EINVAL; + pr_info_ratelimited("Timeout policy `%s' can only be used by " + "L%d protocol number %d\n", + timeout_name, 4, timeout->l4proto->l4proto); + goto err_put_timeout; + } + timeout_ext = nf_ct_timeout_ext_add(ct, timeout, GFP_ATOMIC); + if (!timeout_ext) { + ret = -ENOMEM; + goto err_put_timeout; + } + + rcu_read_unlock(); + return ret; + +err_put_timeout: + __nf_ct_timeout_put(timeout); +out: + rcu_read_unlock(); + if (errmsg) + pr_info_ratelimited("%s\n", errmsg); + return ret; +} +EXPORT_SYMBOL_GPL(nf_ct_set_timeout); + +void nf_ct_destroy_timeout(struct nf_conn *ct) +{ + struct nf_conn_timeout *timeout_ext; + const struct nf_ct_timeout_hooks *h; + + rcu_read_lock(); + h = rcu_dereference(nf_ct_timeout_hook); + + if (h) { + timeout_ext = nf_ct_timeout_find(ct); + if (timeout_ext) { + struct nf_ct_timeout *t; + + t = rcu_dereference(timeout_ext->timeout); + if (t) + h->timeout_put(t); + RCU_INIT_POINTER(timeout_ext->timeout, NULL); + } + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(nf_ct_destroy_timeout); diff --git a/net/netfilter/nf_conntrack_timestamp.c b/net/netfilter/nf_conntrack_timestamp.c new file mode 100644 index 000000000..9e43a0a59 --- /dev/null +++ b/net/netfilter/nf_conntrack_timestamp.c @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2010 Pablo Neira Ayuso + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include +#include + +static bool nf_ct_tstamp __read_mostly; + +module_param_named(tstamp, nf_ct_tstamp, bool, 0644); +MODULE_PARM_DESC(tstamp, "Enable connection tracking flow timestamping."); + +void nf_conntrack_tstamp_pernet_init(struct net *net) +{ + net->ct.sysctl_tstamp = nf_ct_tstamp; +} diff --git a/net/netfilter/nf_dup_netdev.c b/net/netfilter/nf_dup_netdev.c new file mode 100644 index 000000000..a8e2425e4 --- /dev/null +++ b/net/netfilter/nf_dup_netdev.c @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define 
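
/*
 * Aside: a hypothetical sketch (not from the kernel tree) of how a caller
 * such as a conntrack target/expression might attach a timeout policy that
 * was created via nfnetlink_cttimeout; the policy name "tcp-short" is made
 * up for the example, and the error handling mirrors the
 * -ENOENT/-EINVAL/-ENOMEM cases returned by nf_ct_set_timeout() above:
 *
 *      err = nf_ct_set_timeout(net, ct, NFPROTO_IPV4, IPPROTO_TCP,
 *                              "tcp-short");
 *      if (err < 0)
 *              goto err_put_ct;
 */
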
NF_RECURSION_LIMIT 2 + +static DEFINE_PER_CPU(u8, nf_dup_skb_recursion); + +static void nf_do_netdev_egress(struct sk_buff *skb, struct net_device *dev, + enum nf_dev_hooks hook) +{ + if (__this_cpu_read(nf_dup_skb_recursion) > NF_RECURSION_LIMIT) + goto err; + + if (hook == NF_NETDEV_INGRESS && skb_mac_header_was_set(skb)) { + if (skb_cow_head(skb, skb->mac_len)) + goto err; + + skb_push(skb, skb->mac_len); + } + + skb->dev = dev; + skb_clear_tstamp(skb); + __this_cpu_inc(nf_dup_skb_recursion); + dev_queue_xmit(skb); + __this_cpu_dec(nf_dup_skb_recursion); + return; +err: + kfree_skb(skb); +} + +void nf_fwd_netdev_egress(const struct nft_pktinfo *pkt, int oif) +{ + struct net_device *dev; + + dev = dev_get_by_index_rcu(nft_net(pkt), oif); + if (!dev) { + kfree_skb(pkt->skb); + return; + } + + nf_do_netdev_egress(pkt->skb, dev, nft_hook(pkt)); +} +EXPORT_SYMBOL_GPL(nf_fwd_netdev_egress); + +void nf_dup_netdev_egress(const struct nft_pktinfo *pkt, int oif) +{ + struct net_device *dev; + struct sk_buff *skb; + + dev = dev_get_by_index_rcu(nft_net(pkt), oif); + if (dev == NULL) + return; + + skb = skb_clone(pkt->skb, GFP_ATOMIC); + if (skb) + nf_do_netdev_egress(skb, dev, nft_hook(pkt)); +} +EXPORT_SYMBOL_GPL(nf_dup_netdev_egress); + +int nft_fwd_dup_netdev_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + enum flow_action_id id, int oif) +{ + struct flow_action_entry *entry; + struct net_device *dev; + + /* nft_flow_rule_destroy() releases the reference on this device. */ + dev = dev_get_by_index(ctx->net, oif); + if (!dev) + return -EOPNOTSUPP; + + entry = &flow->rule->action.entries[ctx->num_actions++]; + entry->id = id; + entry->dev = dev; + + return 0; +} +EXPORT_SYMBOL_GPL(nft_fwd_dup_netdev_offload); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("Netfilter packet duplication support"); diff --git a/net/netfilter/nf_flow_table_core.c b/net/netfilter/nf_flow_table_core.c new file mode 100644 index 000000000..c1d99cb37 --- /dev/null +++ b/net/netfilter/nf_flow_table_core.c @@ -0,0 +1,699 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_MUTEX(flowtable_lock); +static LIST_HEAD(flowtables); + +static void +flow_offload_fill_dir(struct flow_offload *flow, + enum flow_offload_tuple_dir dir) +{ + struct flow_offload_tuple *ft = &flow->tuplehash[dir].tuple; + struct nf_conntrack_tuple *ctt = &flow->ct->tuplehash[dir].tuple; + + ft->dir = dir; + + switch (ctt->src.l3num) { + case NFPROTO_IPV4: + ft->src_v4 = ctt->src.u3.in; + ft->dst_v4 = ctt->dst.u3.in; + break; + case NFPROTO_IPV6: + ft->src_v6 = ctt->src.u3.in6; + ft->dst_v6 = ctt->dst.u3.in6; + break; + } + + ft->l3proto = ctt->src.l3num; + ft->l4proto = ctt->dst.protonum; + + switch (ctt->dst.protonum) { + case IPPROTO_TCP: + case IPPROTO_UDP: + ft->src_port = ctt->src.u.tcp.port; + ft->dst_port = ctt->dst.u.tcp.port; + break; + } +} + +struct flow_offload *flow_offload_alloc(struct nf_conn *ct) +{ + struct flow_offload *flow; + + if (unlikely(nf_ct_is_dying(ct))) + return NULL; + + flow = kzalloc(sizeof(*flow), GFP_ATOMIC); + if (!flow) + return NULL; + + refcount_inc(&ct->ct_general.use); + flow->ct = ct; + + flow_offload_fill_dir(flow, FLOW_OFFLOAD_DIR_ORIGINAL); + flow_offload_fill_dir(flow, FLOW_OFFLOAD_DIR_REPLY); + + if (ct->status & IPS_SRC_NAT) + __set_bit(NF_FLOW_SNAT, &flow->flags); + if (ct->status & IPS_DST_NAT) + 
__set_bit(NF_FLOW_DNAT, &flow->flags); + + return flow; +} +EXPORT_SYMBOL_GPL(flow_offload_alloc); + +static u32 flow_offload_dst_cookie(struct flow_offload_tuple *flow_tuple) +{ + const struct rt6_info *rt; + + if (flow_tuple->l3proto == NFPROTO_IPV6) { + rt = (const struct rt6_info *)flow_tuple->dst_cache; + return rt6_get_cookie(rt); + } + + return 0; +} + +static int flow_offload_fill_route(struct flow_offload *flow, + const struct nf_flow_route *route, + enum flow_offload_tuple_dir dir) +{ + struct flow_offload_tuple *flow_tuple = &flow->tuplehash[dir].tuple; + struct dst_entry *dst = route->tuple[dir].dst; + int i, j = 0; + + switch (flow_tuple->l3proto) { + case NFPROTO_IPV4: + flow_tuple->mtu = ip_dst_mtu_maybe_forward(dst, true); + break; + case NFPROTO_IPV6: + flow_tuple->mtu = ip6_dst_mtu_maybe_forward(dst, true); + break; + } + + flow_tuple->iifidx = route->tuple[dir].in.ifindex; + for (i = route->tuple[dir].in.num_encaps - 1; i >= 0; i--) { + flow_tuple->encap[j].id = route->tuple[dir].in.encap[i].id; + flow_tuple->encap[j].proto = route->tuple[dir].in.encap[i].proto; + if (route->tuple[dir].in.ingress_vlans & BIT(i)) + flow_tuple->in_vlan_ingress |= BIT(j); + j++; + } + flow_tuple->encap_num = route->tuple[dir].in.num_encaps; + + switch (route->tuple[dir].xmit_type) { + case FLOW_OFFLOAD_XMIT_DIRECT: + memcpy(flow_tuple->out.h_dest, route->tuple[dir].out.h_dest, + ETH_ALEN); + memcpy(flow_tuple->out.h_source, route->tuple[dir].out.h_source, + ETH_ALEN); + flow_tuple->out.ifidx = route->tuple[dir].out.ifindex; + flow_tuple->out.hw_ifidx = route->tuple[dir].out.hw_ifindex; + break; + case FLOW_OFFLOAD_XMIT_XFRM: + case FLOW_OFFLOAD_XMIT_NEIGH: + if (!dst_hold_safe(route->tuple[dir].dst)) + return -1; + + flow_tuple->dst_cache = dst; + flow_tuple->dst_cookie = flow_offload_dst_cookie(flow_tuple); + break; + default: + WARN_ON_ONCE(1); + break; + } + flow_tuple->xmit_type = route->tuple[dir].xmit_type; + + return 0; +} + +static void nft_flow_dst_release(struct flow_offload *flow, + enum flow_offload_tuple_dir dir) +{ + if (flow->tuplehash[dir].tuple.xmit_type == FLOW_OFFLOAD_XMIT_NEIGH || + flow->tuplehash[dir].tuple.xmit_type == FLOW_OFFLOAD_XMIT_XFRM) + dst_release(flow->tuplehash[dir].tuple.dst_cache); +} + +int flow_offload_route_init(struct flow_offload *flow, + const struct nf_flow_route *route) +{ + int err; + + err = flow_offload_fill_route(flow, route, FLOW_OFFLOAD_DIR_ORIGINAL); + if (err < 0) + return err; + + err = flow_offload_fill_route(flow, route, FLOW_OFFLOAD_DIR_REPLY); + if (err < 0) + goto err_route_reply; + + flow->type = NF_FLOW_OFFLOAD_ROUTE; + + return 0; + +err_route_reply: + nft_flow_dst_release(flow, FLOW_OFFLOAD_DIR_ORIGINAL); + + return err; +} +EXPORT_SYMBOL_GPL(flow_offload_route_init); + +static void flow_offload_fixup_tcp(struct ip_ct_tcp *tcp) +{ + tcp->seen[0].td_maxwin = 0; + tcp->seen[1].td_maxwin = 0; +} + +static void flow_offload_fixup_ct(struct nf_conn *ct) +{ + struct net *net = nf_ct_net(ct); + int l4num = nf_ct_protonum(ct); + s32 timeout; + + if (l4num == IPPROTO_TCP) { + struct nf_tcp_net *tn = nf_tcp_pernet(net); + + flow_offload_fixup_tcp(&ct->proto.tcp); + + timeout = tn->timeouts[ct->proto.tcp.state]; + timeout -= tn->offload_timeout; + } else if (l4num == IPPROTO_UDP) { + struct nf_udp_net *tn = nf_udp_pernet(net); + + timeout = tn->timeouts[UDP_CT_REPLIED]; + timeout -= tn->offload_timeout; + } else { + return; + } + + if (timeout < 0) + timeout = 0; + + if (nf_flow_timeout_delta(READ_ONCE(ct->timeout)) > (__s32)timeout) + 
WRITE_ONCE(ct->timeout, nfct_time_stamp + timeout); +} + +static void flow_offload_route_release(struct flow_offload *flow) +{ + nft_flow_dst_release(flow, FLOW_OFFLOAD_DIR_ORIGINAL); + nft_flow_dst_release(flow, FLOW_OFFLOAD_DIR_REPLY); +} + +void flow_offload_free(struct flow_offload *flow) +{ + switch (flow->type) { + case NF_FLOW_OFFLOAD_ROUTE: + flow_offload_route_release(flow); + break; + default: + break; + } + nf_ct_put(flow->ct); + kfree_rcu(flow, rcu_head); +} +EXPORT_SYMBOL_GPL(flow_offload_free); + +static u32 flow_offload_hash(const void *data, u32 len, u32 seed) +{ + const struct flow_offload_tuple *tuple = data; + + return jhash(tuple, offsetof(struct flow_offload_tuple, __hash), seed); +} + +static u32 flow_offload_hash_obj(const void *data, u32 len, u32 seed) +{ + const struct flow_offload_tuple_rhash *tuplehash = data; + + return jhash(&tuplehash->tuple, offsetof(struct flow_offload_tuple, __hash), seed); +} + +static int flow_offload_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct flow_offload_tuple *tuple = arg->key; + const struct flow_offload_tuple_rhash *x = ptr; + + if (memcmp(&x->tuple, tuple, offsetof(struct flow_offload_tuple, __hash))) + return 1; + + return 0; +} + +static const struct rhashtable_params nf_flow_offload_rhash_params = { + .head_offset = offsetof(struct flow_offload_tuple_rhash, node), + .hashfn = flow_offload_hash, + .obj_hashfn = flow_offload_hash_obj, + .obj_cmpfn = flow_offload_hash_cmp, + .automatic_shrinking = true, +}; + +unsigned long flow_offload_get_timeout(struct flow_offload *flow) +{ + unsigned long timeout = NF_FLOW_TIMEOUT; + struct net *net = nf_ct_net(flow->ct); + int l4num = nf_ct_protonum(flow->ct); + + if (l4num == IPPROTO_TCP) { + struct nf_tcp_net *tn = nf_tcp_pernet(net); + + timeout = tn->offload_timeout; + } else if (l4num == IPPROTO_UDP) { + struct nf_udp_net *tn = nf_udp_pernet(net); + + timeout = tn->offload_timeout; + } + + return timeout; +} + +int flow_offload_add(struct nf_flowtable *flow_table, struct flow_offload *flow) +{ + int err; + + flow->timeout = nf_flowtable_time_stamp + flow_offload_get_timeout(flow); + + err = rhashtable_insert_fast(&flow_table->rhashtable, + &flow->tuplehash[0].node, + nf_flow_offload_rhash_params); + if (err < 0) + return err; + + err = rhashtable_insert_fast(&flow_table->rhashtable, + &flow->tuplehash[1].node, + nf_flow_offload_rhash_params); + if (err < 0) { + rhashtable_remove_fast(&flow_table->rhashtable, + &flow->tuplehash[0].node, + nf_flow_offload_rhash_params); + return err; + } + + nf_ct_offload_timeout(flow->ct); + + if (nf_flowtable_hw_offload(flow_table)) { + __set_bit(NF_FLOW_HW, &flow->flags); + nf_flow_offload_add(flow_table, flow); + } + + return 0; +} +EXPORT_SYMBOL_GPL(flow_offload_add); + +void flow_offload_refresh(struct nf_flowtable *flow_table, + struct flow_offload *flow, bool force) +{ + u32 timeout; + + timeout = nf_flowtable_time_stamp + flow_offload_get_timeout(flow); + if (force || timeout - READ_ONCE(flow->timeout) > HZ) + WRITE_ONCE(flow->timeout, timeout); + else + return; + + if (likely(!nf_flowtable_hw_offload(flow_table))) + return; + + nf_flow_offload_add(flow_table, flow); +} +EXPORT_SYMBOL_GPL(flow_offload_refresh); + +static inline bool nf_flow_has_expired(const struct flow_offload *flow) +{ + return nf_flow_timeout_delta(flow->timeout) <= 0; +} + +static void flow_offload_del(struct nf_flowtable *flow_table, + struct flow_offload *flow) +{ + rhashtable_remove_fast(&flow_table->rhashtable, + 
&flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].node, + nf_flow_offload_rhash_params); + rhashtable_remove_fast(&flow_table->rhashtable, + &flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].node, + nf_flow_offload_rhash_params); + flow_offload_free(flow); +} + +void flow_offload_teardown(struct flow_offload *flow) +{ + clear_bit(IPS_OFFLOAD_BIT, &flow->ct->status); + set_bit(NF_FLOW_TEARDOWN, &flow->flags); + flow_offload_fixup_ct(flow->ct); +} +EXPORT_SYMBOL_GPL(flow_offload_teardown); + +struct flow_offload_tuple_rhash * +flow_offload_lookup(struct nf_flowtable *flow_table, + struct flow_offload_tuple *tuple) +{ + struct flow_offload_tuple_rhash *tuplehash; + struct flow_offload *flow; + int dir; + + tuplehash = rhashtable_lookup(&flow_table->rhashtable, tuple, + nf_flow_offload_rhash_params); + if (!tuplehash) + return NULL; + + dir = tuplehash->tuple.dir; + flow = container_of(tuplehash, struct flow_offload, tuplehash[dir]); + if (test_bit(NF_FLOW_TEARDOWN, &flow->flags)) + return NULL; + + if (unlikely(nf_ct_is_dying(flow->ct))) + return NULL; + + return tuplehash; +} +EXPORT_SYMBOL_GPL(flow_offload_lookup); + +static int +nf_flow_table_iterate(struct nf_flowtable *flow_table, + void (*iter)(struct nf_flowtable *flowtable, + struct flow_offload *flow, void *data), + void *data) +{ + struct flow_offload_tuple_rhash *tuplehash; + struct rhashtable_iter hti; + struct flow_offload *flow; + int err = 0; + + rhashtable_walk_enter(&flow_table->rhashtable, &hti); + rhashtable_walk_start(&hti); + + while ((tuplehash = rhashtable_walk_next(&hti))) { + if (IS_ERR(tuplehash)) { + if (PTR_ERR(tuplehash) != -EAGAIN) { + err = PTR_ERR(tuplehash); + break; + } + continue; + } + if (tuplehash->tuple.dir) + continue; + + flow = container_of(tuplehash, struct flow_offload, tuplehash[0]); + + iter(flow_table, flow, data); + } + rhashtable_walk_stop(&hti); + rhashtable_walk_exit(&hti); + + return err; +} + +static bool nf_flow_custom_gc(struct nf_flowtable *flow_table, + const struct flow_offload *flow) +{ + return flow_table->type->gc && flow_table->type->gc(flow); +} + +static void nf_flow_offload_gc_step(struct nf_flowtable *flow_table, + struct flow_offload *flow, void *data) +{ + if (nf_flow_has_expired(flow) || + nf_ct_is_dying(flow->ct) || + nf_flow_custom_gc(flow_table, flow)) + flow_offload_teardown(flow); + + if (test_bit(NF_FLOW_TEARDOWN, &flow->flags)) { + if (test_bit(NF_FLOW_HW, &flow->flags)) { + if (!test_bit(NF_FLOW_HW_DYING, &flow->flags)) + nf_flow_offload_del(flow_table, flow); + else if (test_bit(NF_FLOW_HW_DEAD, &flow->flags)) + flow_offload_del(flow_table, flow); + } else { + flow_offload_del(flow_table, flow); + } + } else if (test_bit(NF_FLOW_HW, &flow->flags)) { + nf_flow_offload_stats(flow_table, flow); + } +} + +void nf_flow_table_gc_run(struct nf_flowtable *flow_table) +{ + nf_flow_table_iterate(flow_table, nf_flow_offload_gc_step, NULL); +} + +static void nf_flow_offload_work_gc(struct work_struct *work) +{ + struct nf_flowtable *flow_table; + + flow_table = container_of(work, struct nf_flowtable, gc_work.work); + nf_flow_table_gc_run(flow_table); + queue_delayed_work(system_power_efficient_wq, &flow_table->gc_work, HZ); +} + +static void nf_flow_nat_port_tcp(struct sk_buff *skb, unsigned int thoff, + __be16 port, __be16 new_port) +{ + struct tcphdr *tcph; + + tcph = (void *)(skb_network_header(skb) + thoff); + inet_proto_csum_replace2(&tcph->check, skb, port, new_port, false); +} + +static void nf_flow_nat_port_udp(struct sk_buff *skb, unsigned int thoff, + __be16 port, __be16 new_port) 
+{ + struct udphdr *udph; + + udph = (void *)(skb_network_header(skb) + thoff); + if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace2(&udph->check, skb, port, + new_port, false); + if (!udph->check) + udph->check = CSUM_MANGLED_0; + } +} + +static void nf_flow_nat_port(struct sk_buff *skb, unsigned int thoff, + u8 protocol, __be16 port, __be16 new_port) +{ + switch (protocol) { + case IPPROTO_TCP: + nf_flow_nat_port_tcp(skb, thoff, port, new_port); + break; + case IPPROTO_UDP: + nf_flow_nat_port_udp(skb, thoff, port, new_port); + break; + } +} + +void nf_flow_snat_port(const struct flow_offload *flow, + struct sk_buff *skb, unsigned int thoff, + u8 protocol, enum flow_offload_tuple_dir dir) +{ + struct flow_ports *hdr; + __be16 port, new_port; + + hdr = (void *)(skb_network_header(skb) + thoff); + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + port = hdr->source; + new_port = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_port; + hdr->source = new_port; + break; + case FLOW_OFFLOAD_DIR_REPLY: + port = hdr->dest; + new_port = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_port; + hdr->dest = new_port; + break; + } + + nf_flow_nat_port(skb, thoff, protocol, port, new_port); +} +EXPORT_SYMBOL_GPL(nf_flow_snat_port); + +void nf_flow_dnat_port(const struct flow_offload *flow, struct sk_buff *skb, + unsigned int thoff, u8 protocol, + enum flow_offload_tuple_dir dir) +{ + struct flow_ports *hdr; + __be16 port, new_port; + + hdr = (void *)(skb_network_header(skb) + thoff); + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + port = hdr->dest; + new_port = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_port; + hdr->dest = new_port; + break; + case FLOW_OFFLOAD_DIR_REPLY: + port = hdr->source; + new_port = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_port; + hdr->source = new_port; + break; + } + + nf_flow_nat_port(skb, thoff, protocol, port, new_port); +} +EXPORT_SYMBOL_GPL(nf_flow_dnat_port); + +int nf_flow_table_init(struct nf_flowtable *flowtable) +{ + int err; + + INIT_DELAYED_WORK(&flowtable->gc_work, nf_flow_offload_work_gc); + flow_block_init(&flowtable->flow_block); + init_rwsem(&flowtable->flow_block_lock); + + err = rhashtable_init(&flowtable->rhashtable, + &nf_flow_offload_rhash_params); + if (err < 0) + return err; + + queue_delayed_work(system_power_efficient_wq, + &flowtable->gc_work, HZ); + + mutex_lock(&flowtable_lock); + list_add(&flowtable->list, &flowtables); + mutex_unlock(&flowtable_lock); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_flow_table_init); + +static void nf_flow_table_do_cleanup(struct nf_flowtable *flow_table, + struct flow_offload *flow, void *data) +{ + struct net_device *dev = data; + + if (!dev) { + flow_offload_teardown(flow); + return; + } + + if (net_eq(nf_ct_net(flow->ct), dev_net(dev)) && + (flow->tuplehash[0].tuple.iifidx == dev->ifindex || + flow->tuplehash[1].tuple.iifidx == dev->ifindex)) + flow_offload_teardown(flow); +} + +void nf_flow_table_gc_cleanup(struct nf_flowtable *flowtable, + struct net_device *dev) +{ + nf_flow_table_iterate(flowtable, nf_flow_table_do_cleanup, dev); + flush_delayed_work(&flowtable->gc_work); + nf_flow_table_offload_flush(flowtable); +} + +void nf_flow_table_cleanup(struct net_device *dev) +{ + struct nf_flowtable *flowtable; + + mutex_lock(&flowtable_lock); + list_for_each_entry(flowtable, &flowtables, list) + nf_flow_table_gc_cleanup(flowtable, dev); + mutex_unlock(&flowtable_lock); +} +EXPORT_SYMBOL_GPL(nf_flow_table_cleanup); + +void nf_flow_table_free(struct 
nf_flowtable *flow_table) +{ + mutex_lock(&flowtable_lock); + list_del(&flow_table->list); + mutex_unlock(&flowtable_lock); + + cancel_delayed_work_sync(&flow_table->gc_work); + nf_flow_table_offload_flush(flow_table); + /* ... no more pending work after this stage ... */ + nf_flow_table_iterate(flow_table, nf_flow_table_do_cleanup, NULL); + nf_flow_table_gc_run(flow_table); + nf_flow_table_offload_flush_cleanup(flow_table); + rhashtable_destroy(&flow_table->rhashtable); +} +EXPORT_SYMBOL_GPL(nf_flow_table_free); + +static int nf_flow_table_init_net(struct net *net) +{ + net->ft.stat = alloc_percpu(struct nf_flow_table_stat); + return net->ft.stat ? 0 : -ENOMEM; +} + +static void nf_flow_table_fini_net(struct net *net) +{ + free_percpu(net->ft.stat); +} + +static int nf_flow_table_pernet_init(struct net *net) +{ + int ret; + + ret = nf_flow_table_init_net(net); + if (ret < 0) + return ret; + + ret = nf_flow_table_init_proc(net); + if (ret < 0) + goto out_proc; + + return 0; + +out_proc: + nf_flow_table_fini_net(net); + return ret; +} + +static void nf_flow_table_pernet_exit(struct list_head *net_exit_list) +{ + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) { + nf_flow_table_fini_proc(net); + nf_flow_table_fini_net(net); + } +} + +static struct pernet_operations nf_flow_table_net_ops = { + .init = nf_flow_table_pernet_init, + .exit_batch = nf_flow_table_pernet_exit, +}; + +static int __init nf_flow_table_module_init(void) +{ + int ret; + + ret = register_pernet_subsys(&nf_flow_table_net_ops); + if (ret < 0) + return ret; + + ret = nf_flow_table_offload_init(); + if (ret) + goto out_offload; + + return 0; + +out_offload: + unregister_pernet_subsys(&nf_flow_table_net_ops); + return ret; +} + +static void __exit nf_flow_table_module_exit(void) +{ + nf_flow_table_offload_exit(); + unregister_pernet_subsys(&nf_flow_table_net_ops); +} + +module_init(nf_flow_table_module_init); +module_exit(nf_flow_table_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("Netfilter flow table module"); diff --git a/net/netfilter/nf_flow_table_inet.c b/net/netfilter/nf_flow_table_inet.c new file mode 100644 index 000000000..9505f9d18 --- /dev/null +++ b/net/netfilter/nf_flow_table_inet.c @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int +nf_flow_offload_inet_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct vlan_ethhdr *veth; + __be16 proto; + + switch (skb->protocol) { + case htons(ETH_P_8021Q): + veth = (struct vlan_ethhdr *)skb_mac_header(skb); + proto = veth->h_vlan_encapsulated_proto; + break; + case htons(ETH_P_PPP_SES): + proto = nf_flow_pppoe_proto(skb); + break; + default: + proto = skb->protocol; + break; + } + + switch (proto) { + case htons(ETH_P_IP): + return nf_flow_offload_ip_hook(priv, skb, state); + case htons(ETH_P_IPV6): + return nf_flow_offload_ipv6_hook(priv, skb, state); + } + + return NF_ACCEPT; +} + +static int nf_flow_rule_route_inet(struct net *net, + struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + const struct flow_offload_tuple *flow_tuple = &flow->tuplehash[dir].tuple; + int err; + + switch (flow_tuple->l3proto) { + case NFPROTO_IPV4: + err = nf_flow_rule_route_ipv4(net, flow, dir, flow_rule); + break; + case NFPROTO_IPV6: + err = nf_flow_rule_route_ipv6(net, flow, dir, flow_rule); + break; + default: + err = -1; + 
break; + } + + return err; +} + +static struct nf_flowtable_type flowtable_inet = { + .family = NFPROTO_INET, + .init = nf_flow_table_init, + .setup = nf_flow_table_offload_setup, + .action = nf_flow_rule_route_inet, + .free = nf_flow_table_free, + .hook = nf_flow_offload_inet_hook, + .owner = THIS_MODULE, +}; + +static struct nf_flowtable_type flowtable_ipv4 = { + .family = NFPROTO_IPV4, + .init = nf_flow_table_init, + .setup = nf_flow_table_offload_setup, + .action = nf_flow_rule_route_ipv4, + .free = nf_flow_table_free, + .hook = nf_flow_offload_ip_hook, + .owner = THIS_MODULE, +}; + +static struct nf_flowtable_type flowtable_ipv6 = { + .family = NFPROTO_IPV6, + .init = nf_flow_table_init, + .setup = nf_flow_table_offload_setup, + .action = nf_flow_rule_route_ipv6, + .free = nf_flow_table_free, + .hook = nf_flow_offload_ipv6_hook, + .owner = THIS_MODULE, +}; + +static int __init nf_flow_inet_module_init(void) +{ + nft_register_flowtable_type(&flowtable_ipv4); + nft_register_flowtable_type(&flowtable_ipv6); + nft_register_flowtable_type(&flowtable_inet); + + return 0; +} + +static void __exit nf_flow_inet_module_exit(void) +{ + nft_unregister_flowtable_type(&flowtable_inet); + nft_unregister_flowtable_type(&flowtable_ipv6); + nft_unregister_flowtable_type(&flowtable_ipv4); +} + +module_init(nf_flow_inet_module_init); +module_exit(nf_flow_inet_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NF_FLOWTABLE(AF_INET); +MODULE_ALIAS_NF_FLOWTABLE(AF_INET6); +MODULE_ALIAS_NF_FLOWTABLE(1); /* NFPROTO_INET */ +MODULE_DESCRIPTION("Netfilter flow table mixed IPv4/IPv6 module"); diff --git a/net/netfilter/nf_flow_table_ip.c b/net/netfilter/nf_flow_table_ip.c new file mode 100644 index 000000000..6feaac9ab --- /dev/null +++ b/net/netfilter/nf_flow_table_ip.c @@ -0,0 +1,689 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +/* For layer 4 checksum field offset. 
*/ +#include +#include + +static int nf_flow_state_check(struct flow_offload *flow, int proto, + struct sk_buff *skb, unsigned int thoff) +{ + struct tcphdr *tcph; + + if (proto != IPPROTO_TCP) + return 0; + + tcph = (void *)(skb_network_header(skb) + thoff); + if (unlikely(tcph->fin || tcph->rst)) { + flow_offload_teardown(flow); + return -1; + } + + return 0; +} + +static void nf_flow_nat_ip_tcp(struct sk_buff *skb, unsigned int thoff, + __be32 addr, __be32 new_addr) +{ + struct tcphdr *tcph; + + tcph = (void *)(skb_network_header(skb) + thoff); + inet_proto_csum_replace4(&tcph->check, skb, addr, new_addr, true); +} + +static void nf_flow_nat_ip_udp(struct sk_buff *skb, unsigned int thoff, + __be32 addr, __be32 new_addr) +{ + struct udphdr *udph; + + udph = (void *)(skb_network_header(skb) + thoff); + if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace4(&udph->check, skb, addr, + new_addr, true); + if (!udph->check) + udph->check = CSUM_MANGLED_0; + } +} + +static void nf_flow_nat_ip_l4proto(struct sk_buff *skb, struct iphdr *iph, + unsigned int thoff, __be32 addr, + __be32 new_addr) +{ + switch (iph->protocol) { + case IPPROTO_TCP: + nf_flow_nat_ip_tcp(skb, thoff, addr, new_addr); + break; + case IPPROTO_UDP: + nf_flow_nat_ip_udp(skb, thoff, addr, new_addr); + break; + } +} + +static void nf_flow_snat_ip(const struct flow_offload *flow, + struct sk_buff *skb, struct iphdr *iph, + unsigned int thoff, enum flow_offload_tuple_dir dir) +{ + __be32 addr, new_addr; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = iph->saddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_v4.s_addr; + iph->saddr = new_addr; + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = iph->daddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_v4.s_addr; + iph->daddr = new_addr; + break; + } + csum_replace4(&iph->check, addr, new_addr); + + nf_flow_nat_ip_l4proto(skb, iph, thoff, addr, new_addr); +} + +static void nf_flow_dnat_ip(const struct flow_offload *flow, + struct sk_buff *skb, struct iphdr *iph, + unsigned int thoff, enum flow_offload_tuple_dir dir) +{ + __be32 addr, new_addr; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = iph->daddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_v4.s_addr; + iph->daddr = new_addr; + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = iph->saddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_v4.s_addr; + iph->saddr = new_addr; + break; + } + csum_replace4(&iph->check, addr, new_addr); + + nf_flow_nat_ip_l4proto(skb, iph, thoff, addr, new_addr); +} + +static void nf_flow_nat_ip(const struct flow_offload *flow, struct sk_buff *skb, + unsigned int thoff, enum flow_offload_tuple_dir dir, + struct iphdr *iph) +{ + if (test_bit(NF_FLOW_SNAT, &flow->flags)) { + nf_flow_snat_port(flow, skb, thoff, iph->protocol, dir); + nf_flow_snat_ip(flow, skb, iph, thoff, dir); + } + if (test_bit(NF_FLOW_DNAT, &flow->flags)) { + nf_flow_dnat_port(flow, skb, thoff, iph->protocol, dir); + nf_flow_dnat_ip(flow, skb, iph, thoff, dir); + } +} + +static bool ip_has_options(unsigned int thoff) +{ + return thoff != sizeof(struct iphdr); +} + +static void nf_flow_tuple_encap(struct sk_buff *skb, + struct flow_offload_tuple *tuple) +{ + struct vlan_ethhdr *veth; + struct pppoe_hdr *phdr; + int i = 0; + + if (skb_vlan_tag_present(skb)) { + tuple->encap[i].id = skb_vlan_tag_get(skb); + tuple->encap[i].proto = skb->vlan_proto; + i++; + } + switch (skb->protocol) { + case htons(ETH_P_8021Q): + 
veth = (struct vlan_ethhdr *)skb_mac_header(skb); + tuple->encap[i].id = ntohs(veth->h_vlan_TCI); + tuple->encap[i].proto = skb->protocol; + break; + case htons(ETH_P_PPP_SES): + phdr = (struct pppoe_hdr *)skb_mac_header(skb); + tuple->encap[i].id = ntohs(phdr->sid); + tuple->encap[i].proto = skb->protocol; + break; + } +} + +static int nf_flow_tuple_ip(struct sk_buff *skb, const struct net_device *dev, + struct flow_offload_tuple *tuple, u32 *hdrsize, + u32 offset) +{ + struct flow_ports *ports; + unsigned int thoff; + struct iphdr *iph; + u8 ipproto; + + if (!pskb_may_pull(skb, sizeof(*iph) + offset)) + return -1; + + iph = (struct iphdr *)(skb_network_header(skb) + offset); + thoff = (iph->ihl * 4); + + if (ip_is_fragment(iph) || + unlikely(ip_has_options(thoff))) + return -1; + + thoff += offset; + + ipproto = iph->protocol; + switch (ipproto) { + case IPPROTO_TCP: + *hdrsize = sizeof(struct tcphdr); + break; + case IPPROTO_UDP: + *hdrsize = sizeof(struct udphdr); + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + *hdrsize = sizeof(struct gre_base_hdr); + break; +#endif + default: + return -1; + } + + if (iph->ttl <= 1) + return -1; + + if (!pskb_may_pull(skb, thoff + *hdrsize)) + return -1; + + switch (ipproto) { + case IPPROTO_TCP: + case IPPROTO_UDP: + ports = (struct flow_ports *)(skb_network_header(skb) + thoff); + tuple->src_port = ports->source; + tuple->dst_port = ports->dest; + break; + case IPPROTO_GRE: { + struct gre_base_hdr *greh; + + greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff); + if ((greh->flags & GRE_VERSION) != GRE_VERSION_0) + return -1; + break; + } + } + + iph = (struct iphdr *)(skb_network_header(skb) + offset); + + tuple->src_v4.s_addr = iph->saddr; + tuple->dst_v4.s_addr = iph->daddr; + tuple->l3proto = AF_INET; + tuple->l4proto = ipproto; + tuple->iifidx = dev->ifindex; + nf_flow_tuple_encap(skb, tuple); + + return 0; +} + +/* Based on ip_exceeds_mtu(). 
*/ +static bool nf_flow_exceeds_mtu(const struct sk_buff *skb, unsigned int mtu) +{ + if (skb->len <= mtu) + return false; + + if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) + return false; + + return true; +} + +static inline bool nf_flow_dst_check(struct flow_offload_tuple *tuple) +{ + if (tuple->xmit_type != FLOW_OFFLOAD_XMIT_NEIGH && + tuple->xmit_type != FLOW_OFFLOAD_XMIT_XFRM) + return true; + + return dst_check(tuple->dst_cache, tuple->dst_cookie); +} + +static unsigned int nf_flow_xmit_xfrm(struct sk_buff *skb, + const struct nf_hook_state *state, + struct dst_entry *dst) +{ + skb_orphan(skb); + skb_dst_set_noref(skb, dst); + dst_output(state->net, state->sk, skb); + return NF_STOLEN; +} + +static bool nf_flow_skb_encap_protocol(const struct sk_buff *skb, __be16 proto, + u32 *offset) +{ + struct vlan_ethhdr *veth; + + switch (skb->protocol) { + case htons(ETH_P_8021Q): + veth = (struct vlan_ethhdr *)skb_mac_header(skb); + if (veth->h_vlan_encapsulated_proto == proto) { + *offset += VLAN_HLEN; + return true; + } + break; + case htons(ETH_P_PPP_SES): + if (nf_flow_pppoe_proto(skb) == proto) { + *offset += PPPOE_SES_HLEN; + return true; + } + break; + } + + return false; +} + +static void nf_flow_encap_pop(struct sk_buff *skb, + struct flow_offload_tuple_rhash *tuplehash) +{ + struct vlan_hdr *vlan_hdr; + int i; + + for (i = 0; i < tuplehash->tuple.encap_num; i++) { + if (skb_vlan_tag_present(skb)) { + __vlan_hwaccel_clear_tag(skb); + continue; + } + switch (skb->protocol) { + case htons(ETH_P_8021Q): + vlan_hdr = (struct vlan_hdr *)skb->data; + __skb_pull(skb, VLAN_HLEN); + vlan_set_encap_proto(skb, vlan_hdr); + skb_reset_network_header(skb); + break; + case htons(ETH_P_PPP_SES): + skb->protocol = nf_flow_pppoe_proto(skb); + skb_pull(skb, PPPOE_SES_HLEN); + skb_reset_network_header(skb); + break; + } + } +} + +static unsigned int nf_flow_queue_xmit(struct net *net, struct sk_buff *skb, + const struct flow_offload_tuple_rhash *tuplehash, + unsigned short type) +{ + struct net_device *outdev; + + outdev = dev_get_by_index_rcu(net, tuplehash->tuple.out.ifidx); + if (!outdev) + return NF_DROP; + + skb->dev = outdev; + dev_hard_header(skb, skb->dev, type, tuplehash->tuple.out.h_dest, + tuplehash->tuple.out.h_source, skb->len); + dev_queue_xmit(skb); + + return NF_STOLEN; +} + +unsigned int +nf_flow_offload_ip_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct flow_offload_tuple_rhash *tuplehash; + struct nf_flowtable *flow_table = priv; + struct flow_offload_tuple tuple = {}; + enum flow_offload_tuple_dir dir; + struct flow_offload *flow; + struct net_device *outdev; + u32 hdrsize, offset = 0; + unsigned int thoff, mtu; + struct rtable *rt; + struct iphdr *iph; + __be32 nexthop; + int ret; + + if (skb->protocol != htons(ETH_P_IP) && + !nf_flow_skb_encap_protocol(skb, htons(ETH_P_IP), &offset)) + return NF_ACCEPT; + + if (nf_flow_tuple_ip(skb, state->in, &tuple, &hdrsize, offset) < 0) + return NF_ACCEPT; + + tuplehash = flow_offload_lookup(flow_table, &tuple); + if (tuplehash == NULL) + return NF_ACCEPT; + + dir = tuplehash->tuple.dir; + flow = container_of(tuplehash, struct flow_offload, tuplehash[dir]); + + mtu = flow->tuplehash[dir].tuple.mtu + offset; + if (unlikely(nf_flow_exceeds_mtu(skb, mtu))) + return NF_ACCEPT; + + iph = (struct iphdr *)(skb_network_header(skb) + offset); + thoff = (iph->ihl * 4) + offset; + if (nf_flow_state_check(flow, iph->protocol, skb, thoff)) + return NF_ACCEPT; + + if 
(!nf_flow_dst_check(&tuplehash->tuple)) { + flow_offload_teardown(flow); + return NF_ACCEPT; + } + + if (skb_try_make_writable(skb, thoff + hdrsize)) + return NF_DROP; + + flow_offload_refresh(flow_table, flow, false); + + nf_flow_encap_pop(skb, tuplehash); + thoff -= offset; + + iph = ip_hdr(skb); + nf_flow_nat_ip(flow, skb, thoff, dir, iph); + + ip_decrease_ttl(iph); + skb_clear_tstamp(skb); + + if (flow_table->flags & NF_FLOWTABLE_COUNTER) + nf_ct_acct_update(flow->ct, tuplehash->tuple.dir, skb->len); + + if (unlikely(tuplehash->tuple.xmit_type == FLOW_OFFLOAD_XMIT_XFRM)) { + rt = (struct rtable *)tuplehash->tuple.dst_cache; + memset(skb->cb, 0, sizeof(struct inet_skb_parm)); + IPCB(skb)->iif = skb->dev->ifindex; + IPCB(skb)->flags = IPSKB_FORWARDED; + return nf_flow_xmit_xfrm(skb, state, &rt->dst); + } + + switch (tuplehash->tuple.xmit_type) { + case FLOW_OFFLOAD_XMIT_NEIGH: + rt = (struct rtable *)tuplehash->tuple.dst_cache; + outdev = rt->dst.dev; + skb->dev = outdev; + nexthop = rt_nexthop(rt, flow->tuplehash[!dir].tuple.src_v4.s_addr); + skb_dst_set_noref(skb, &rt->dst); + neigh_xmit(NEIGH_ARP_TABLE, outdev, &nexthop, skb); + ret = NF_STOLEN; + break; + case FLOW_OFFLOAD_XMIT_DIRECT: + ret = nf_flow_queue_xmit(state->net, skb, tuplehash, ETH_P_IP); + if (ret == NF_DROP) + flow_offload_teardown(flow); + break; + } + + return ret; +} +EXPORT_SYMBOL_GPL(nf_flow_offload_ip_hook); + +static void nf_flow_nat_ipv6_tcp(struct sk_buff *skb, unsigned int thoff, + struct in6_addr *addr, + struct in6_addr *new_addr, + struct ipv6hdr *ip6h) +{ + struct tcphdr *tcph; + + tcph = (void *)(skb_network_header(skb) + thoff); + inet_proto_csum_replace16(&tcph->check, skb, addr->s6_addr32, + new_addr->s6_addr32, true); +} + +static void nf_flow_nat_ipv6_udp(struct sk_buff *skb, unsigned int thoff, + struct in6_addr *addr, + struct in6_addr *new_addr) +{ + struct udphdr *udph; + + udph = (void *)(skb_network_header(skb) + thoff); + if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace16(&udph->check, skb, addr->s6_addr32, + new_addr->s6_addr32, true); + if (!udph->check) + udph->check = CSUM_MANGLED_0; + } +} + +static void nf_flow_nat_ipv6_l4proto(struct sk_buff *skb, struct ipv6hdr *ip6h, + unsigned int thoff, struct in6_addr *addr, + struct in6_addr *new_addr) +{ + switch (ip6h->nexthdr) { + case IPPROTO_TCP: + nf_flow_nat_ipv6_tcp(skb, thoff, addr, new_addr, ip6h); + break; + case IPPROTO_UDP: + nf_flow_nat_ipv6_udp(skb, thoff, addr, new_addr); + break; + } +} + +static void nf_flow_snat_ipv6(const struct flow_offload *flow, + struct sk_buff *skb, struct ipv6hdr *ip6h, + unsigned int thoff, + enum flow_offload_tuple_dir dir) +{ + struct in6_addr addr, new_addr; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = ip6h->saddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_v6; + ip6h->saddr = new_addr; + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = ip6h->daddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_v6; + ip6h->daddr = new_addr; + break; + } + + nf_flow_nat_ipv6_l4proto(skb, ip6h, thoff, &addr, &new_addr); +} + +static void nf_flow_dnat_ipv6(const struct flow_offload *flow, + struct sk_buff *skb, struct ipv6hdr *ip6h, + unsigned int thoff, + enum flow_offload_tuple_dir dir) +{ + struct in6_addr addr, new_addr; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = ip6h->daddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_v6; + ip6h->daddr = new_addr; + break; + case FLOW_OFFLOAD_DIR_REPLY: + 
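+		/* Reply direction: rewrite the source address back to the
+		 * address the original direction targeted (un-DNAT); the
+		 * L4 checksum is fixed up right after this switch.
+		 */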
addr = ip6h->saddr; + new_addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_v6; + ip6h->saddr = new_addr; + break; + } + + nf_flow_nat_ipv6_l4proto(skb, ip6h, thoff, &addr, &new_addr); +} + +static void nf_flow_nat_ipv6(const struct flow_offload *flow, + struct sk_buff *skb, + enum flow_offload_tuple_dir dir, + struct ipv6hdr *ip6h) +{ + unsigned int thoff = sizeof(*ip6h); + + if (test_bit(NF_FLOW_SNAT, &flow->flags)) { + nf_flow_snat_port(flow, skb, thoff, ip6h->nexthdr, dir); + nf_flow_snat_ipv6(flow, skb, ip6h, thoff, dir); + } + if (test_bit(NF_FLOW_DNAT, &flow->flags)) { + nf_flow_dnat_port(flow, skb, thoff, ip6h->nexthdr, dir); + nf_flow_dnat_ipv6(flow, skb, ip6h, thoff, dir); + } +} + +static int nf_flow_tuple_ipv6(struct sk_buff *skb, const struct net_device *dev, + struct flow_offload_tuple *tuple, u32 *hdrsize, + u32 offset) +{ + struct flow_ports *ports; + struct ipv6hdr *ip6h; + unsigned int thoff; + u8 nexthdr; + + thoff = sizeof(*ip6h) + offset; + if (!pskb_may_pull(skb, thoff)) + return -1; + + ip6h = (struct ipv6hdr *)(skb_network_header(skb) + offset); + + nexthdr = ip6h->nexthdr; + switch (nexthdr) { + case IPPROTO_TCP: + *hdrsize = sizeof(struct tcphdr); + break; + case IPPROTO_UDP: + *hdrsize = sizeof(struct udphdr); + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + *hdrsize = sizeof(struct gre_base_hdr); + break; +#endif + default: + return -1; + } + + if (ip6h->hop_limit <= 1) + return -1; + + if (!pskb_may_pull(skb, thoff + *hdrsize)) + return -1; + + switch (nexthdr) { + case IPPROTO_TCP: + case IPPROTO_UDP: + ports = (struct flow_ports *)(skb_network_header(skb) + thoff); + tuple->src_port = ports->source; + tuple->dst_port = ports->dest; + break; + case IPPROTO_GRE: { + struct gre_base_hdr *greh; + + greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff); + if ((greh->flags & GRE_VERSION) != GRE_VERSION_0) + return -1; + break; + } + } + + ip6h = (struct ipv6hdr *)(skb_network_header(skb) + offset); + + tuple->src_v6 = ip6h->saddr; + tuple->dst_v6 = ip6h->daddr; + tuple->l3proto = AF_INET6; + tuple->l4proto = nexthdr; + tuple->iifidx = dev->ifindex; + nf_flow_tuple_encap(skb, tuple); + + return 0; +} + +unsigned int +nf_flow_offload_ipv6_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct flow_offload_tuple_rhash *tuplehash; + struct nf_flowtable *flow_table = priv; + struct flow_offload_tuple tuple = {}; + enum flow_offload_tuple_dir dir; + const struct in6_addr *nexthop; + struct flow_offload *flow; + struct net_device *outdev; + unsigned int thoff, mtu; + u32 hdrsize, offset = 0; + struct ipv6hdr *ip6h; + struct rt6_info *rt; + int ret; + + if (skb->protocol != htons(ETH_P_IPV6) && + !nf_flow_skb_encap_protocol(skb, htons(ETH_P_IPV6), &offset)) + return NF_ACCEPT; + + if (nf_flow_tuple_ipv6(skb, state->in, &tuple, &hdrsize, offset) < 0) + return NF_ACCEPT; + + tuplehash = flow_offload_lookup(flow_table, &tuple); + if (tuplehash == NULL) + return NF_ACCEPT; + + dir = tuplehash->tuple.dir; + flow = container_of(tuplehash, struct flow_offload, tuplehash[dir]); + + mtu = flow->tuplehash[dir].tuple.mtu + offset; + if (unlikely(nf_flow_exceeds_mtu(skb, mtu))) + return NF_ACCEPT; + + ip6h = (struct ipv6hdr *)(skb_network_header(skb) + offset); + thoff = sizeof(*ip6h) + offset; + if (nf_flow_state_check(flow, ip6h->nexthdr, skb, thoff)) + return NF_ACCEPT; + + if (!nf_flow_dst_check(&tuplehash->tuple)) { + flow_offload_teardown(flow); + return NF_ACCEPT; + } + + if (skb_try_make_writable(skb, thoff + 
hdrsize)) + return NF_DROP; + + flow_offload_refresh(flow_table, flow, false); + + nf_flow_encap_pop(skb, tuplehash); + + ip6h = ipv6_hdr(skb); + nf_flow_nat_ipv6(flow, skb, dir, ip6h); + + ip6h->hop_limit--; + skb_clear_tstamp(skb); + + if (flow_table->flags & NF_FLOWTABLE_COUNTER) + nf_ct_acct_update(flow->ct, tuplehash->tuple.dir, skb->len); + + if (unlikely(tuplehash->tuple.xmit_type == FLOW_OFFLOAD_XMIT_XFRM)) { + rt = (struct rt6_info *)tuplehash->tuple.dst_cache; + memset(skb->cb, 0, sizeof(struct inet6_skb_parm)); + IP6CB(skb)->iif = skb->dev->ifindex; + IP6CB(skb)->flags = IP6SKB_FORWARDED; + return nf_flow_xmit_xfrm(skb, state, &rt->dst); + } + + switch (tuplehash->tuple.xmit_type) { + case FLOW_OFFLOAD_XMIT_NEIGH: + rt = (struct rt6_info *)tuplehash->tuple.dst_cache; + outdev = rt->dst.dev; + skb->dev = outdev; + nexthop = rt6_nexthop(rt, &flow->tuplehash[!dir].tuple.src_v6); + skb_dst_set_noref(skb, &rt->dst); + neigh_xmit(NEIGH_ND_TABLE, outdev, nexthop, skb); + ret = NF_STOLEN; + break; + case FLOW_OFFLOAD_XMIT_DIRECT: + ret = nf_flow_queue_xmit(state->net, skb, tuplehash, ETH_P_IPV6); + if (ret == NF_DROP) + flow_offload_teardown(flow); + break; + } + + return ret; +} +EXPORT_SYMBOL_GPL(nf_flow_offload_ipv6_hook); diff --git a/net/netfilter/nf_flow_table_offload.c b/net/netfilter/nf_flow_table_offload.c new file mode 100644 index 000000000..1c26f03fc --- /dev/null +++ b/net/netfilter/nf_flow_table_offload.c @@ -0,0 +1,1241 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct workqueue_struct *nf_flow_offload_add_wq; +static struct workqueue_struct *nf_flow_offload_del_wq; +static struct workqueue_struct *nf_flow_offload_stats_wq; + +struct flow_offload_work { + struct list_head list; + enum flow_cls_command cmd; + struct nf_flowtable *flowtable; + struct flow_offload *flow; + struct work_struct work; +}; + +#define NF_FLOW_DISSECTOR(__match, __type, __field) \ + (__match)->dissector.offset[__type] = \ + offsetof(struct nf_flow_key, __field) + +static void nf_flow_rule_lwt_match(struct nf_flow_match *match, + struct ip_tunnel_info *tun_info) +{ + struct nf_flow_key *mask = &match->mask; + struct nf_flow_key *key = &match->key; + unsigned int enc_keys; + + if (!tun_info || !(tun_info->mode & IP_TUNNEL_INFO_TX)) + return; + + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_ENC_CONTROL, enc_control); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_ENC_KEYID, enc_key_id); + key->enc_key_id.keyid = tunnel_id_to_key32(tun_info->key.tun_id); + mask->enc_key_id.keyid = 0xffffffff; + enc_keys = BIT(FLOW_DISSECTOR_KEY_ENC_KEYID) | + BIT(FLOW_DISSECTOR_KEY_ENC_CONTROL); + + if (ip_tunnel_info_af(tun_info) == AF_INET) { + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS, + enc_ipv4); + key->enc_ipv4.src = tun_info->key.u.ipv4.dst; + key->enc_ipv4.dst = tun_info->key.u.ipv4.src; + if (key->enc_ipv4.src) + mask->enc_ipv4.src = 0xffffffff; + if (key->enc_ipv4.dst) + mask->enc_ipv4.dst = 0xffffffff; + enc_keys |= BIT(FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS); + key->enc_control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + } else { + memcpy(&key->enc_ipv6.src, &tun_info->key.u.ipv6.dst, + sizeof(struct in6_addr)); + memcpy(&key->enc_ipv6.dst, &tun_info->key.u.ipv6.src, + sizeof(struct in6_addr)); + if (memcmp(&key->enc_ipv6.src, &in6addr_any, + sizeof(struct in6_addr))) + memset(&mask->enc_ipv6.src, 0xff, + sizeof(struct in6_addr)); + if (memcmp(&key->enc_ipv6.dst, &in6addr_any, + 
sizeof(struct in6_addr))) + memset(&mask->enc_ipv6.dst, 0xff, + sizeof(struct in6_addr)); + enc_keys |= BIT(FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS); + key->enc_control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + } + + match->dissector.used_keys |= enc_keys; +} + +static void nf_flow_rule_vlan_match(struct flow_dissector_key_vlan *key, + struct flow_dissector_key_vlan *mask, + u16 vlan_id, __be16 proto) +{ + key->vlan_id = vlan_id; + mask->vlan_id = VLAN_VID_MASK; + key->vlan_tpid = proto; + mask->vlan_tpid = 0xffff; +} + +static int nf_flow_rule_match(struct nf_flow_match *match, + const struct flow_offload_tuple *tuple, + struct dst_entry *other_dst) +{ + struct nf_flow_key *mask = &match->mask; + struct nf_flow_key *key = &match->key; + struct ip_tunnel_info *tun_info; + bool vlan_encap = false; + + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_META, meta); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_CONTROL, control); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_BASIC, basic); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_IPV4_ADDRS, ipv4); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_IPV6_ADDRS, ipv6); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_TCP, tcp); + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_PORTS, tp); + + if (other_dst && other_dst->lwtstate) { + tun_info = lwt_tun_info(other_dst->lwtstate); + nf_flow_rule_lwt_match(match, tun_info); + } + + if (tuple->xmit_type == FLOW_OFFLOAD_XMIT_TC) + key->meta.ingress_ifindex = tuple->tc.iifidx; + else + key->meta.ingress_ifindex = tuple->iifidx; + + mask->meta.ingress_ifindex = 0xffffffff; + + if (tuple->encap_num > 0 && !(tuple->in_vlan_ingress & BIT(0)) && + tuple->encap[0].proto == htons(ETH_P_8021Q)) { + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_VLAN, vlan); + nf_flow_rule_vlan_match(&key->vlan, &mask->vlan, + tuple->encap[0].id, + tuple->encap[0].proto); + vlan_encap = true; + } + + if (tuple->encap_num > 1 && !(tuple->in_vlan_ingress & BIT(1)) && + tuple->encap[1].proto == htons(ETH_P_8021Q)) { + if (vlan_encap) { + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_CVLAN, + cvlan); + nf_flow_rule_vlan_match(&key->cvlan, &mask->cvlan, + tuple->encap[1].id, + tuple->encap[1].proto); + } else { + NF_FLOW_DISSECTOR(match, FLOW_DISSECTOR_KEY_VLAN, + vlan); + nf_flow_rule_vlan_match(&key->vlan, &mask->vlan, + tuple->encap[1].id, + tuple->encap[1].proto); + } + } + + switch (tuple->l3proto) { + case AF_INET: + key->control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + key->basic.n_proto = htons(ETH_P_IP); + key->ipv4.src = tuple->src_v4.s_addr; + mask->ipv4.src = 0xffffffff; + key->ipv4.dst = tuple->dst_v4.s_addr; + mask->ipv4.dst = 0xffffffff; + break; + case AF_INET6: + key->control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + key->basic.n_proto = htons(ETH_P_IPV6); + key->ipv6.src = tuple->src_v6; + memset(&mask->ipv6.src, 0xff, sizeof(mask->ipv6.src)); + key->ipv6.dst = tuple->dst_v6; + memset(&mask->ipv6.dst, 0xff, sizeof(mask->ipv6.dst)); + break; + default: + return -EOPNOTSUPP; + } + mask->control.addr_type = 0xffff; + match->dissector.used_keys |= BIT(key->control.addr_type); + mask->basic.n_proto = 0xffff; + + switch (tuple->l4proto) { + case IPPROTO_TCP: + key->tcp.flags = 0; + mask->tcp.flags = cpu_to_be16(be32_to_cpu(TCP_FLAG_RST | TCP_FLAG_FIN) >> 16); + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_TCP); + break; + case IPPROTO_UDP: + case IPPROTO_GRE: + break; + default: + return -EOPNOTSUPP; + } + + key->basic.ip_proto = tuple->l4proto; + mask->basic.ip_proto = 0xff; + + match->dissector.used_keys |= 
BIT(FLOW_DISSECTOR_KEY_META) | + BIT(FLOW_DISSECTOR_KEY_CONTROL) | + BIT(FLOW_DISSECTOR_KEY_BASIC); + + switch (tuple->l4proto) { + case IPPROTO_TCP: + case IPPROTO_UDP: + key->tp.src = tuple->src_port; + mask->tp.src = 0xffff; + key->tp.dst = tuple->dst_port; + mask->tp.dst = 0xffff; + + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_PORTS); + break; + } + + return 0; +} + +static void flow_offload_mangle(struct flow_action_entry *entry, + enum flow_action_mangle_base htype, u32 offset, + const __be32 *value, const __be32 *mask) +{ + entry->id = FLOW_ACTION_MANGLE; + entry->mangle.htype = htype; + entry->mangle.offset = offset; + memcpy(&entry->mangle.mask, mask, sizeof(u32)); + memcpy(&entry->mangle.val, value, sizeof(u32)); +} + +static inline struct flow_action_entry * +flow_action_entry_next(struct nf_flow_rule *flow_rule) +{ + int i = flow_rule->rule->action.num_entries++; + + return &flow_rule->rule->action.entries[i]; +} + +static int flow_offload_eth_src(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry0 = flow_action_entry_next(flow_rule); + struct flow_action_entry *entry1 = flow_action_entry_next(flow_rule); + const struct flow_offload_tuple *other_tuple, *this_tuple; + struct net_device *dev = NULL; + const unsigned char *addr; + u32 mask, val; + u16 val16; + + this_tuple = &flow->tuplehash[dir].tuple; + + switch (this_tuple->xmit_type) { + case FLOW_OFFLOAD_XMIT_DIRECT: + addr = this_tuple->out.h_source; + break; + case FLOW_OFFLOAD_XMIT_NEIGH: + other_tuple = &flow->tuplehash[!dir].tuple; + dev = dev_get_by_index(net, other_tuple->iifidx); + if (!dev) + return -ENOENT; + + addr = dev->dev_addr; + break; + default: + return -EOPNOTSUPP; + } + + mask = ~0xffff0000; + memcpy(&val16, addr, 2); + val = val16 << 16; + flow_offload_mangle(entry0, FLOW_ACT_MANGLE_HDR_TYPE_ETH, 4, + &val, &mask); + + mask = ~0xffffffff; + memcpy(&val, addr + 2, 4); + flow_offload_mangle(entry1, FLOW_ACT_MANGLE_HDR_TYPE_ETH, 8, + &val, &mask); + + dev_put(dev); + + return 0; +} + +static int flow_offload_eth_dst(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry0 = flow_action_entry_next(flow_rule); + struct flow_action_entry *entry1 = flow_action_entry_next(flow_rule); + const struct flow_offload_tuple *other_tuple, *this_tuple; + const struct dst_entry *dst_cache; + unsigned char ha[ETH_ALEN]; + struct neighbour *n; + const void *daddr; + u32 mask, val; + u8 nud_state; + u16 val16; + + this_tuple = &flow->tuplehash[dir].tuple; + + switch (this_tuple->xmit_type) { + case FLOW_OFFLOAD_XMIT_DIRECT: + ether_addr_copy(ha, this_tuple->out.h_dest); + break; + case FLOW_OFFLOAD_XMIT_NEIGH: + other_tuple = &flow->tuplehash[!dir].tuple; + daddr = &other_tuple->src_v4; + dst_cache = this_tuple->dst_cache; + n = dst_neigh_lookup(dst_cache, daddr); + if (!n) + return -ENOENT; + + read_lock_bh(&n->lock); + nud_state = n->nud_state; + ether_addr_copy(ha, n->ha); + read_unlock_bh(&n->lock); + neigh_release(n); + + if (!(nud_state & NUD_VALID)) + return -ENOENT; + break; + default: + return -EOPNOTSUPP; + } + + mask = ~0xffffffff; + memcpy(&val, ha, 4); + flow_offload_mangle(entry0, FLOW_ACT_MANGLE_HDR_TYPE_ETH, 0, + &val, &mask); + + mask = ~0x0000ffff; + memcpy(&val16, ha + 4, 2); + val = val16; + flow_offload_mangle(entry1, FLOW_ACT_MANGLE_HDR_TYPE_ETH, 4, + &val, &mask); + + return 0; +} + 
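+/* Emit a mangle action that rewrites the IPv4 source address (original
+ * direction) or destination address (reply direction) using the address
+ * stored in the opposite tuple, i.e. the SNAT mapping set up by conntrack.
+ */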
+static void flow_offload_ipv4_snat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry = flow_action_entry_next(flow_rule); + u32 mask = ~htonl(0xffffffff); + __be32 addr; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_v4.s_addr; + offset = offsetof(struct iphdr, saddr); + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_v4.s_addr; + offset = offsetof(struct iphdr, daddr); + break; + default: + return; + } + + flow_offload_mangle(entry, FLOW_ACT_MANGLE_HDR_TYPE_IP4, offset, + &addr, &mask); +} + +static void flow_offload_ipv4_dnat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry = flow_action_entry_next(flow_rule); + u32 mask = ~htonl(0xffffffff); + __be32 addr; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_v4.s_addr; + offset = offsetof(struct iphdr, daddr); + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_v4.s_addr; + offset = offsetof(struct iphdr, saddr); + break; + default: + return; + } + + flow_offload_mangle(entry, FLOW_ACT_MANGLE_HDR_TYPE_IP4, offset, + &addr, &mask); +} + +static void flow_offload_ipv6_mangle(struct nf_flow_rule *flow_rule, + unsigned int offset, + const __be32 *addr, const __be32 *mask) +{ + struct flow_action_entry *entry; + int i; + + for (i = 0; i < sizeof(struct in6_addr) / sizeof(u32); i++) { + entry = flow_action_entry_next(flow_rule); + flow_offload_mangle(entry, FLOW_ACT_MANGLE_HDR_TYPE_IP6, + offset + i * sizeof(u32), &addr[i], mask); + } +} + +static void flow_offload_ipv6_snat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + u32 mask = ~htonl(0xffffffff); + const __be32 *addr; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_v6.s6_addr32; + offset = offsetof(struct ipv6hdr, saddr); + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_v6.s6_addr32; + offset = offsetof(struct ipv6hdr, daddr); + break; + default: + return; + } + + flow_offload_ipv6_mangle(flow_rule, offset, addr, &mask); +} + +static void flow_offload_ipv6_dnat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + u32 mask = ~htonl(0xffffffff); + const __be32 *addr; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_v6.s6_addr32; + offset = offsetof(struct ipv6hdr, daddr); + break; + case FLOW_OFFLOAD_DIR_REPLY: + addr = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_v6.s6_addr32; + offset = offsetof(struct ipv6hdr, saddr); + break; + default: + return; + } + + flow_offload_ipv6_mangle(flow_rule, offset, addr, &mask); +} + +static int flow_offload_l4proto(const struct flow_offload *flow) +{ + u8 protonum = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.l4proto; + u8 type = 0; + + switch (protonum) { + case IPPROTO_TCP: + type = FLOW_ACT_MANGLE_HDR_TYPE_TCP; + break; + case IPPROTO_UDP: + type = FLOW_ACT_MANGLE_HDR_TYPE_UDP; + break; + default: + break; + } + + return 
type; +} + +static void flow_offload_port_snat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry = flow_action_entry_next(flow_rule); + u32 mask, port; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + port = ntohs(flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.dst_port); + offset = 0; /* offsetof(struct tcphdr, source); */ + port = htonl(port << 16); + mask = ~htonl(0xffff0000); + break; + case FLOW_OFFLOAD_DIR_REPLY: + port = ntohs(flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.src_port); + offset = 0; /* offsetof(struct tcphdr, dest); */ + port = htonl(port); + mask = ~htonl(0xffff); + break; + default: + return; + } + + flow_offload_mangle(entry, flow_offload_l4proto(flow), offset, + &port, &mask); +} + +static void flow_offload_port_dnat(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry = flow_action_entry_next(flow_rule); + u32 mask, port; + u32 offset; + + switch (dir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + port = ntohs(flow->tuplehash[FLOW_OFFLOAD_DIR_REPLY].tuple.src_port); + offset = 0; /* offsetof(struct tcphdr, dest); */ + port = htonl(port); + mask = ~htonl(0xffff); + break; + case FLOW_OFFLOAD_DIR_REPLY: + port = ntohs(flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.dst_port); + offset = 0; /* offsetof(struct tcphdr, source); */ + port = htonl(port << 16); + mask = ~htonl(0xffff0000); + break; + default: + return; + } + + flow_offload_mangle(entry, flow_offload_l4proto(flow), offset, + &port, &mask); +} + +static void flow_offload_ipv4_checksum(struct net *net, + const struct flow_offload *flow, + struct nf_flow_rule *flow_rule) +{ + u8 protonum = flow->tuplehash[FLOW_OFFLOAD_DIR_ORIGINAL].tuple.l4proto; + struct flow_action_entry *entry = flow_action_entry_next(flow_rule); + + entry->id = FLOW_ACTION_CSUM; + entry->csum_flags = TCA_CSUM_UPDATE_FLAG_IPV4HDR; + + switch (protonum) { + case IPPROTO_TCP: + entry->csum_flags |= TCA_CSUM_UPDATE_FLAG_TCP; + break; + case IPPROTO_UDP: + entry->csum_flags |= TCA_CSUM_UPDATE_FLAG_UDP; + break; + } +} + +static void flow_offload_redirect(struct net *net, + const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + const struct flow_offload_tuple *this_tuple, *other_tuple; + struct flow_action_entry *entry; + struct net_device *dev; + int ifindex; + + this_tuple = &flow->tuplehash[dir].tuple; + switch (this_tuple->xmit_type) { + case FLOW_OFFLOAD_XMIT_DIRECT: + this_tuple = &flow->tuplehash[dir].tuple; + ifindex = this_tuple->out.hw_ifidx; + break; + case FLOW_OFFLOAD_XMIT_NEIGH: + other_tuple = &flow->tuplehash[!dir].tuple; + ifindex = other_tuple->iifidx; + break; + default: + return; + } + + dev = dev_get_by_index(net, ifindex); + if (!dev) + return; + + entry = flow_action_entry_next(flow_rule); + entry->id = FLOW_ACTION_REDIRECT; + entry->dev = dev; +} + +static void flow_offload_encap_tunnel(const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + const struct flow_offload_tuple *this_tuple; + struct flow_action_entry *entry; + struct dst_entry *dst; + + this_tuple = &flow->tuplehash[dir].tuple; + if (this_tuple->xmit_type == FLOW_OFFLOAD_XMIT_DIRECT) + return; + + dst = this_tuple->dst_cache; + if (dst && dst->lwtstate) { + struct ip_tunnel_info *tun_info; + + tun_info = lwt_tun_info(dst->lwtstate); + 
if (tun_info && (tun_info->mode & IP_TUNNEL_INFO_TX)) { + entry = flow_action_entry_next(flow_rule); + entry->id = FLOW_ACTION_TUNNEL_ENCAP; + entry->tunnel = tun_info; + } + } +} + +static void flow_offload_decap_tunnel(const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + const struct flow_offload_tuple *other_tuple; + struct flow_action_entry *entry; + struct dst_entry *dst; + + other_tuple = &flow->tuplehash[!dir].tuple; + if (other_tuple->xmit_type == FLOW_OFFLOAD_XMIT_DIRECT) + return; + + dst = other_tuple->dst_cache; + if (dst && dst->lwtstate) { + struct ip_tunnel_info *tun_info; + + tun_info = lwt_tun_info(dst->lwtstate); + if (tun_info && (tun_info->mode & IP_TUNNEL_INFO_TX)) { + entry = flow_action_entry_next(flow_rule); + entry->id = FLOW_ACTION_TUNNEL_DECAP; + } + } +} + +static int +nf_flow_rule_route_common(struct net *net, const struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + const struct flow_offload_tuple *other_tuple; + const struct flow_offload_tuple *tuple; + int i; + + flow_offload_decap_tunnel(flow, dir, flow_rule); + flow_offload_encap_tunnel(flow, dir, flow_rule); + + if (flow_offload_eth_src(net, flow, dir, flow_rule) < 0 || + flow_offload_eth_dst(net, flow, dir, flow_rule) < 0) + return -1; + + tuple = &flow->tuplehash[dir].tuple; + + for (i = 0; i < tuple->encap_num; i++) { + struct flow_action_entry *entry; + + if (tuple->in_vlan_ingress & BIT(i)) + continue; + + if (tuple->encap[i].proto == htons(ETH_P_8021Q)) { + entry = flow_action_entry_next(flow_rule); + entry->id = FLOW_ACTION_VLAN_POP; + } + } + + other_tuple = &flow->tuplehash[!dir].tuple; + + for (i = 0; i < other_tuple->encap_num; i++) { + struct flow_action_entry *entry; + + if (other_tuple->in_vlan_ingress & BIT(i)) + continue; + + entry = flow_action_entry_next(flow_rule); + + switch (other_tuple->encap[i].proto) { + case htons(ETH_P_PPP_SES): + entry->id = FLOW_ACTION_PPPOE_PUSH; + entry->pppoe.sid = other_tuple->encap[i].id; + break; + case htons(ETH_P_8021Q): + entry->id = FLOW_ACTION_VLAN_PUSH; + entry->vlan.vid = other_tuple->encap[i].id; + entry->vlan.proto = other_tuple->encap[i].proto; + break; + } + } + + return 0; +} + +int nf_flow_rule_route_ipv4(struct net *net, struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + if (nf_flow_rule_route_common(net, flow, dir, flow_rule) < 0) + return -1; + + if (test_bit(NF_FLOW_SNAT, &flow->flags)) { + flow_offload_ipv4_snat(net, flow, dir, flow_rule); + flow_offload_port_snat(net, flow, dir, flow_rule); + } + if (test_bit(NF_FLOW_DNAT, &flow->flags)) { + flow_offload_ipv4_dnat(net, flow, dir, flow_rule); + flow_offload_port_dnat(net, flow, dir, flow_rule); + } + if (test_bit(NF_FLOW_SNAT, &flow->flags) || + test_bit(NF_FLOW_DNAT, &flow->flags)) + flow_offload_ipv4_checksum(net, flow, flow_rule); + + flow_offload_redirect(net, flow, dir, flow_rule); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_flow_rule_route_ipv4); + +int nf_flow_rule_route_ipv6(struct net *net, struct flow_offload *flow, + enum flow_offload_tuple_dir dir, + struct nf_flow_rule *flow_rule) +{ + if (nf_flow_rule_route_common(net, flow, dir, flow_rule) < 0) + return -1; + + if (test_bit(NF_FLOW_SNAT, &flow->flags)) { + flow_offload_ipv6_snat(net, flow, dir, flow_rule); + flow_offload_port_snat(net, flow, dir, flow_rule); + } + if (test_bit(NF_FLOW_DNAT, &flow->flags)) { + flow_offload_ipv6_dnat(net, flow, dir, flow_rule); + 
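+		/* The IPv6 address rewrite above and the port rewrite below
+		 * are emitted as separate mangle actions.  Unlike the IPv4
+		 * path there is no checksum action appended afterwards: the
+		 * IPv6 header carries no checksum of its own.
+		 */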
flow_offload_port_dnat(net, flow, dir, flow_rule); + } + + flow_offload_redirect(net, flow, dir, flow_rule); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_flow_rule_route_ipv6); + +#define NF_FLOW_RULE_ACTION_MAX 16 + +static struct nf_flow_rule * +nf_flow_offload_rule_alloc(struct net *net, + const struct flow_offload_work *offload, + enum flow_offload_tuple_dir dir) +{ + const struct nf_flowtable *flowtable = offload->flowtable; + const struct flow_offload_tuple *tuple, *other_tuple; + struct flow_offload *flow = offload->flow; + struct dst_entry *other_dst = NULL; + struct nf_flow_rule *flow_rule; + int err = -ENOMEM; + + flow_rule = kzalloc(sizeof(*flow_rule), GFP_KERNEL); + if (!flow_rule) + goto err_flow; + + flow_rule->rule = flow_rule_alloc(NF_FLOW_RULE_ACTION_MAX); + if (!flow_rule->rule) + goto err_flow_rule; + + flow_rule->rule->match.dissector = &flow_rule->match.dissector; + flow_rule->rule->match.mask = &flow_rule->match.mask; + flow_rule->rule->match.key = &flow_rule->match.key; + + tuple = &flow->tuplehash[dir].tuple; + other_tuple = &flow->tuplehash[!dir].tuple; + if (other_tuple->xmit_type == FLOW_OFFLOAD_XMIT_NEIGH) + other_dst = other_tuple->dst_cache; + + err = nf_flow_rule_match(&flow_rule->match, tuple, other_dst); + if (err < 0) + goto err_flow_match; + + flow_rule->rule->action.num_entries = 0; + if (flowtable->type->action(net, flow, dir, flow_rule) < 0) + goto err_flow_match; + + return flow_rule; + +err_flow_match: + kfree(flow_rule->rule); +err_flow_rule: + kfree(flow_rule); +err_flow: + return NULL; +} + +static void __nf_flow_offload_destroy(struct nf_flow_rule *flow_rule) +{ + struct flow_action_entry *entry; + int i; + + for (i = 0; i < flow_rule->rule->action.num_entries; i++) { + entry = &flow_rule->rule->action.entries[i]; + if (entry->id != FLOW_ACTION_REDIRECT) + continue; + + dev_put(entry->dev); + } + kfree(flow_rule->rule); + kfree(flow_rule); +} + +static void nf_flow_offload_destroy(struct nf_flow_rule *flow_rule[]) +{ + int i; + + for (i = 0; i < FLOW_OFFLOAD_DIR_MAX; i++) + __nf_flow_offload_destroy(flow_rule[i]); +} + +static int nf_flow_offload_alloc(const struct flow_offload_work *offload, + struct nf_flow_rule *flow_rule[]) +{ + struct net *net = read_pnet(&offload->flowtable->net); + + flow_rule[0] = nf_flow_offload_rule_alloc(net, offload, + FLOW_OFFLOAD_DIR_ORIGINAL); + if (!flow_rule[0]) + return -ENOMEM; + + flow_rule[1] = nf_flow_offload_rule_alloc(net, offload, + FLOW_OFFLOAD_DIR_REPLY); + if (!flow_rule[1]) { + __nf_flow_offload_destroy(flow_rule[0]); + return -ENOMEM; + } + + return 0; +} + +static void nf_flow_offload_init(struct flow_cls_offload *cls_flow, + __be16 proto, int priority, + enum flow_cls_command cmd, + const struct flow_offload_tuple *tuple, + struct netlink_ext_ack *extack) +{ + cls_flow->common.protocol = proto; + cls_flow->common.prio = priority; + cls_flow->common.extack = extack; + cls_flow->command = cmd; + cls_flow->cookie = (unsigned long)tuple; +} + +static int nf_flow_offload_tuple(struct nf_flowtable *flowtable, + struct flow_offload *flow, + struct nf_flow_rule *flow_rule, + enum flow_offload_tuple_dir dir, + int priority, int cmd, + struct flow_stats *stats, + struct list_head *block_cb_list) +{ + struct flow_cls_offload cls_flow = {}; + struct flow_block_cb *block_cb; + struct netlink_ext_ack extack; + __be16 proto = ETH_P_ALL; + int err, i = 0; + + nf_flow_offload_init(&cls_flow, proto, priority, cmd, + &flow->tuplehash[dir].tuple, &extack); + if (cmd == FLOW_CLS_REPLACE) + cls_flow.rule = flow_rule->rule; + 
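+	/* Fan the request out to every callback bound to this flowtable's
+	 * flow block (one per offloading driver) and count how many of them
+	 * accept it; a FLOW_CLS_REPLACE that no callback accepts makes
+	 * flow_offload_rule_add() fail with -ENOENT.
+	 */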
+ down_read(&flowtable->flow_block_lock); + list_for_each_entry(block_cb, block_cb_list, list) { + err = block_cb->cb(TC_SETUP_CLSFLOWER, &cls_flow, + block_cb->cb_priv); + if (err < 0) + continue; + + i++; + } + up_read(&flowtable->flow_block_lock); + + if (cmd == FLOW_CLS_STATS) + memcpy(stats, &cls_flow.stats, sizeof(*stats)); + + return i; +} + +static int flow_offload_tuple_add(struct flow_offload_work *offload, + struct nf_flow_rule *flow_rule, + enum flow_offload_tuple_dir dir) +{ + return nf_flow_offload_tuple(offload->flowtable, offload->flow, + flow_rule, dir, + offload->flowtable->priority, + FLOW_CLS_REPLACE, NULL, + &offload->flowtable->flow_block.cb_list); +} + +static void flow_offload_tuple_del(struct flow_offload_work *offload, + enum flow_offload_tuple_dir dir) +{ + nf_flow_offload_tuple(offload->flowtable, offload->flow, NULL, dir, + offload->flowtable->priority, + FLOW_CLS_DESTROY, NULL, + &offload->flowtable->flow_block.cb_list); +} + +static int flow_offload_rule_add(struct flow_offload_work *offload, + struct nf_flow_rule *flow_rule[]) +{ + int ok_count = 0; + + ok_count += flow_offload_tuple_add(offload, flow_rule[0], + FLOW_OFFLOAD_DIR_ORIGINAL); + if (test_bit(NF_FLOW_HW_BIDIRECTIONAL, &offload->flow->flags)) + ok_count += flow_offload_tuple_add(offload, flow_rule[1], + FLOW_OFFLOAD_DIR_REPLY); + if (ok_count == 0) + return -ENOENT; + + return 0; +} + +static void flow_offload_work_add(struct flow_offload_work *offload) +{ + struct nf_flow_rule *flow_rule[FLOW_OFFLOAD_DIR_MAX]; + int err; + + err = nf_flow_offload_alloc(offload, flow_rule); + if (err < 0) + return; + + err = flow_offload_rule_add(offload, flow_rule); + if (err < 0) + goto out; + + set_bit(IPS_HW_OFFLOAD_BIT, &offload->flow->ct->status); + +out: + nf_flow_offload_destroy(flow_rule); +} + +static void flow_offload_work_del(struct flow_offload_work *offload) +{ + clear_bit(IPS_HW_OFFLOAD_BIT, &offload->flow->ct->status); + flow_offload_tuple_del(offload, FLOW_OFFLOAD_DIR_ORIGINAL); + if (test_bit(NF_FLOW_HW_BIDIRECTIONAL, &offload->flow->flags)) + flow_offload_tuple_del(offload, FLOW_OFFLOAD_DIR_REPLY); + set_bit(NF_FLOW_HW_DEAD, &offload->flow->flags); +} + +static void flow_offload_tuple_stats(struct flow_offload_work *offload, + enum flow_offload_tuple_dir dir, + struct flow_stats *stats) +{ + nf_flow_offload_tuple(offload->flowtable, offload->flow, NULL, dir, + offload->flowtable->priority, + FLOW_CLS_STATS, stats, + &offload->flowtable->flow_block.cb_list); +} + +static void flow_offload_work_stats(struct flow_offload_work *offload) +{ + struct flow_stats stats[FLOW_OFFLOAD_DIR_MAX] = {}; + u64 lastused; + + flow_offload_tuple_stats(offload, FLOW_OFFLOAD_DIR_ORIGINAL, &stats[0]); + if (test_bit(NF_FLOW_HW_BIDIRECTIONAL, &offload->flow->flags)) + flow_offload_tuple_stats(offload, FLOW_OFFLOAD_DIR_REPLY, + &stats[1]); + + lastused = max_t(u64, stats[0].lastused, stats[1].lastused); + offload->flow->timeout = max_t(u64, offload->flow->timeout, + lastused + flow_offload_get_timeout(offload->flow)); + + if (offload->flowtable->flags & NF_FLOWTABLE_COUNTER) { + if (stats[0].pkts) + nf_ct_acct_add(offload->flow->ct, + FLOW_OFFLOAD_DIR_ORIGINAL, + stats[0].pkts, stats[0].bytes); + if (stats[1].pkts) + nf_ct_acct_add(offload->flow->ct, + FLOW_OFFLOAD_DIR_REPLY, + stats[1].pkts, stats[1].bytes); + } +} + +static void flow_offload_work_handler(struct work_struct *work) +{ + struct flow_offload_work *offload; + struct net *net; + + offload = container_of(work, struct flow_offload_work, work); + net = 
read_pnet(&offload->flowtable->net); + switch (offload->cmd) { + case FLOW_CLS_REPLACE: + flow_offload_work_add(offload); + NF_FLOW_TABLE_STAT_DEC_ATOMIC(net, count_wq_add); + break; + case FLOW_CLS_DESTROY: + flow_offload_work_del(offload); + NF_FLOW_TABLE_STAT_DEC_ATOMIC(net, count_wq_del); + break; + case FLOW_CLS_STATS: + flow_offload_work_stats(offload); + NF_FLOW_TABLE_STAT_DEC_ATOMIC(net, count_wq_stats); + break; + default: + WARN_ON_ONCE(1); + } + + clear_bit(NF_FLOW_HW_PENDING, &offload->flow->flags); + kfree(offload); +} + +static void flow_offload_queue_work(struct flow_offload_work *offload) +{ + struct net *net = read_pnet(&offload->flowtable->net); + + if (offload->cmd == FLOW_CLS_REPLACE) { + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_add); + queue_work(nf_flow_offload_add_wq, &offload->work); + } else if (offload->cmd == FLOW_CLS_DESTROY) { + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_del); + queue_work(nf_flow_offload_del_wq, &offload->work); + } else { + NF_FLOW_TABLE_STAT_INC_ATOMIC(net, count_wq_stats); + queue_work(nf_flow_offload_stats_wq, &offload->work); + } +} + +static struct flow_offload_work * +nf_flow_offload_work_alloc(struct nf_flowtable *flowtable, + struct flow_offload *flow, unsigned int cmd) +{ + struct flow_offload_work *offload; + + if (test_and_set_bit(NF_FLOW_HW_PENDING, &flow->flags)) + return NULL; + + offload = kmalloc(sizeof(struct flow_offload_work), GFP_ATOMIC); + if (!offload) { + clear_bit(NF_FLOW_HW_PENDING, &flow->flags); + return NULL; + } + + offload->cmd = cmd; + offload->flow = flow; + offload->flowtable = flowtable; + INIT_WORK(&offload->work, flow_offload_work_handler); + + return offload; +} + + +void nf_flow_offload_add(struct nf_flowtable *flowtable, + struct flow_offload *flow) +{ + struct flow_offload_work *offload; + + offload = nf_flow_offload_work_alloc(flowtable, flow, FLOW_CLS_REPLACE); + if (!offload) + return; + + flow_offload_queue_work(offload); +} + +void nf_flow_offload_del(struct nf_flowtable *flowtable, + struct flow_offload *flow) +{ + struct flow_offload_work *offload; + + offload = nf_flow_offload_work_alloc(flowtable, flow, FLOW_CLS_DESTROY); + if (!offload) + return; + + set_bit(NF_FLOW_HW_DYING, &flow->flags); + flow_offload_queue_work(offload); +} + +void nf_flow_offload_stats(struct nf_flowtable *flowtable, + struct flow_offload *flow) +{ + struct flow_offload_work *offload; + __s32 delta; + + delta = nf_flow_timeout_delta(flow->timeout); + if ((delta >= (9 * flow_offload_get_timeout(flow)) / 10)) + return; + + offload = nf_flow_offload_work_alloc(flowtable, flow, FLOW_CLS_STATS); + if (!offload) + return; + + flow_offload_queue_work(offload); +} + +void nf_flow_table_offload_flush_cleanup(struct nf_flowtable *flowtable) +{ + if (nf_flowtable_hw_offload(flowtable)) { + flush_workqueue(nf_flow_offload_del_wq); + nf_flow_table_gc_run(flowtable); + } +} + +void nf_flow_table_offload_flush(struct nf_flowtable *flowtable) +{ + if (nf_flowtable_hw_offload(flowtable)) { + flush_workqueue(nf_flow_offload_add_wq); + flush_workqueue(nf_flow_offload_del_wq); + flush_workqueue(nf_flow_offload_stats_wq); + } +} + +static int nf_flow_table_block_setup(struct nf_flowtable *flowtable, + struct flow_block_offload *bo, + enum flow_block_command cmd) +{ + struct flow_block_cb *block_cb, *next; + int err = 0; + + down_write(&flowtable->flow_block_lock); + switch (cmd) { + case FLOW_BLOCK_BIND: + list_splice(&bo->cb_list, &flowtable->flow_block.cb_list); + break; + case FLOW_BLOCK_UNBIND: + list_for_each_entry_safe(block_cb, 
next, &bo->cb_list, list) { + list_del(&block_cb->list); + flow_block_cb_free(block_cb); + } + break; + default: + WARN_ON_ONCE(1); + err = -EOPNOTSUPP; + } + up_write(&flowtable->flow_block_lock); + + return err; +} + +static void nf_flow_table_block_offload_init(struct flow_block_offload *bo, + struct net *net, + enum flow_block_command cmd, + struct nf_flowtable *flowtable, + struct netlink_ext_ack *extack) +{ + memset(bo, 0, sizeof(*bo)); + bo->net = net; + bo->block = &flowtable->flow_block; + bo->command = cmd; + bo->binder_type = FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS; + bo->extack = extack; + bo->cb_list_head = &flowtable->flow_block.cb_list; + INIT_LIST_HEAD(&bo->cb_list); +} + +static void nf_flow_table_indr_cleanup(struct flow_block_cb *block_cb) +{ + struct nf_flowtable *flowtable = block_cb->indr.data; + struct net_device *dev = block_cb->indr.dev; + + nf_flow_table_gc_cleanup(flowtable, dev); + down_write(&flowtable->flow_block_lock); + list_del(&block_cb->list); + list_del(&block_cb->driver_list); + flow_block_cb_free(block_cb); + up_write(&flowtable->flow_block_lock); +} + +static int nf_flow_table_indr_offload_cmd(struct flow_block_offload *bo, + struct nf_flowtable *flowtable, + struct net_device *dev, + enum flow_block_command cmd, + struct netlink_ext_ack *extack) +{ + nf_flow_table_block_offload_init(bo, dev_net(dev), cmd, flowtable, + extack); + + return flow_indr_dev_setup_offload(dev, NULL, TC_SETUP_FT, flowtable, bo, + nf_flow_table_indr_cleanup); +} + +static int nf_flow_table_offload_cmd(struct flow_block_offload *bo, + struct nf_flowtable *flowtable, + struct net_device *dev, + enum flow_block_command cmd, + struct netlink_ext_ack *extack) +{ + int err; + + nf_flow_table_block_offload_init(bo, dev_net(dev), cmd, flowtable, + extack); + down_write(&flowtable->flow_block_lock); + err = dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_FT, bo); + up_write(&flowtable->flow_block_lock); + if (err < 0) + return err; + + return 0; +} + +int nf_flow_table_offload_setup(struct nf_flowtable *flowtable, + struct net_device *dev, + enum flow_block_command cmd) +{ + struct netlink_ext_ack extack = {}; + struct flow_block_offload bo; + int err; + + if (!nf_flowtable_hw_offload(flowtable)) + return 0; + + if (dev->netdev_ops->ndo_setup_tc) + err = nf_flow_table_offload_cmd(&bo, flowtable, dev, cmd, + &extack); + else + err = nf_flow_table_indr_offload_cmd(&bo, flowtable, dev, cmd, + &extack); + if (err < 0) + return err; + + return nf_flow_table_block_setup(flowtable, &bo, cmd); +} +EXPORT_SYMBOL_GPL(nf_flow_table_offload_setup); + +int nf_flow_table_offload_init(void) +{ + nf_flow_offload_add_wq = alloc_workqueue("nf_ft_offload_add", + WQ_UNBOUND | WQ_SYSFS, 0); + if (!nf_flow_offload_add_wq) + return -ENOMEM; + + nf_flow_offload_del_wq = alloc_workqueue("nf_ft_offload_del", + WQ_UNBOUND | WQ_SYSFS, 0); + if (!nf_flow_offload_del_wq) + goto err_del_wq; + + nf_flow_offload_stats_wq = alloc_workqueue("nf_ft_offload_stats", + WQ_UNBOUND | WQ_SYSFS, 0); + if (!nf_flow_offload_stats_wq) + goto err_stats_wq; + + return 0; + +err_stats_wq: + destroy_workqueue(nf_flow_offload_del_wq); +err_del_wq: + destroy_workqueue(nf_flow_offload_add_wq); + return -ENOMEM; +} + +void nf_flow_table_offload_exit(void) +{ + destroy_workqueue(nf_flow_offload_add_wq); + destroy_workqueue(nf_flow_offload_del_wq); + destroy_workqueue(nf_flow_offload_stats_wq); +} diff --git a/net/netfilter/nf_flow_table_procfs.c b/net/netfilter/nf_flow_table_procfs.c new file mode 100644 index 000000000..159b033a4 --- /dev/null 
+++ b/net/netfilter/nf_flow_table_procfs.c @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include + +static void *nf_flow_table_cpu_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + int cpu; + + if (*pos == 0) + return SEQ_START_TOKEN; + + for (cpu = *pos - 1; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(net->ft.stat, cpu); + } + + return NULL; +} + +static void *nf_flow_table_cpu_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + int cpu; + + for (cpu = *pos; cpu < nr_cpu_ids; ++cpu) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(net->ft.stat, cpu); + } + (*pos)++; + return NULL; +} + +static void nf_flow_table_cpu_seq_stop(struct seq_file *seq, void *v) +{ +} + +static int nf_flow_table_cpu_seq_show(struct seq_file *seq, void *v) +{ + const struct nf_flow_table_stat *st = v; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "wq_add wq_del wq_stats\n"); + return 0; + } + + seq_printf(seq, "%8d %8d %8d\n", + st->count_wq_add, + st->count_wq_del, + st->count_wq_stats + ); + return 0; +} + +static const struct seq_operations nf_flow_table_cpu_seq_ops = { + .start = nf_flow_table_cpu_seq_start, + .next = nf_flow_table_cpu_seq_next, + .stop = nf_flow_table_cpu_seq_stop, + .show = nf_flow_table_cpu_seq_show, +}; + +int nf_flow_table_init_proc(struct net *net) +{ + struct proc_dir_entry *pde; + + pde = proc_create_net("nf_flowtable", 0444, net->proc_net_stat, + &nf_flow_table_cpu_seq_ops, + sizeof(struct seq_net_private)); + return pde ? 0 : -ENOMEM; +} + +void nf_flow_table_fini_proc(struct net *net) +{ + remove_proc_entry("nf_flowtable", net->proc_net_stat); +} diff --git a/net/netfilter/nf_hooks_lwtunnel.c b/net/netfilter/nf_hooks_lwtunnel.c new file mode 100644 index 000000000..00e89ffd7 --- /dev/null +++ b/net/netfilter/nf_hooks_lwtunnel.c @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include + +static inline int nf_hooks_lwtunnel_get(void) +{ + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) + return 1; + else + return 0; +} + +static inline int nf_hooks_lwtunnel_set(int enable) +{ + if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled)) { + if (!enable) + return -EBUSY; + } else if (enable) { + static_branch_enable(&nf_hooks_lwtunnel_enabled); + } + + return 0; +} + +#ifdef CONFIG_SYSCTL +int nf_hooks_lwtunnel_sysctl_handler(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int proc_nf_hooks_lwtunnel_enabled = 0; + struct ctl_table tmp = { + .procname = table->procname, + .data = &proc_nf_hooks_lwtunnel_enabled, + .maxlen = sizeof(int), + .mode = table->mode, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }; + int ret; + + if (!write) + proc_nf_hooks_lwtunnel_enabled = nf_hooks_lwtunnel_get(); + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) + ret = nf_hooks_lwtunnel_set(proc_nf_hooks_lwtunnel_enabled); + + return ret; +} +EXPORT_SYMBOL_GPL(nf_hooks_lwtunnel_sysctl_handler); +#endif /* CONFIG_SYSCTL */ diff --git a/net/netfilter/nf_internals.h b/net/netfilter/nf_internals.h new file mode 100644 index 000000000..832ae6417 --- /dev/null +++ b/net/netfilter/nf_internals.h @@ -0,0 +1,37 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _NF_INTERNALS_H +#define _NF_INTERNALS_H + +#include +#include +#include + +/* nf_conntrack_netlink.c: applied on tuple filters */ 
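+/* Each CTA_FILTER_F_* bit mirrors one CTA_* netlink attribute that may be
+ * present in a conntrack dump filter; CTA_FILTER_FLAG(CTA_IP_SRC), for
+ * example, expands to CTA_FILTER_F_CTA_IP_SRC, i.e. (1 << 0).
+ */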
+#define CTA_FILTER_F_CTA_IP_SRC (1 << 0) +#define CTA_FILTER_F_CTA_IP_DST (1 << 1) +#define CTA_FILTER_F_CTA_TUPLE_ZONE (1 << 2) +#define CTA_FILTER_F_CTA_PROTO_NUM (1 << 3) +#define CTA_FILTER_F_CTA_PROTO_SRC_PORT (1 << 4) +#define CTA_FILTER_F_CTA_PROTO_DST_PORT (1 << 5) +#define CTA_FILTER_F_CTA_PROTO_ICMP_TYPE (1 << 6) +#define CTA_FILTER_F_CTA_PROTO_ICMP_CODE (1 << 7) +#define CTA_FILTER_F_CTA_PROTO_ICMP_ID (1 << 8) +#define CTA_FILTER_F_CTA_PROTO_ICMPV6_TYPE (1 << 9) +#define CTA_FILTER_F_CTA_PROTO_ICMPV6_CODE (1 << 10) +#define CTA_FILTER_F_CTA_PROTO_ICMPV6_ID (1 << 11) +#define CTA_FILTER_F_MAX (1 << 12) +#define CTA_FILTER_F_ALL (CTA_FILTER_F_MAX-1) +#define CTA_FILTER_FLAG(ctattr) CTA_FILTER_F_ ## ctattr + +/* nf_queue.c */ +void nf_queue_nf_hook_drop(struct net *net); + +/* nf_log.c */ +int __init netfilter_log_init(void); + +/* core.c */ +void nf_hook_entries_delete_raw(struct nf_hook_entries __rcu **pp, + const struct nf_hook_ops *reg); +int nf_hook_entries_insert_raw(struct nf_hook_entries __rcu **pp, + const struct nf_hook_ops *reg); +#endif diff --git a/net/netfilter/nf_log.c b/net/netfilter/nf_log.c new file mode 100644 index 000000000..8a2929014 --- /dev/null +++ b/net/netfilter/nf_log.c @@ -0,0 +1,568 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +/* Internal logging interface, which relies on the real + LOG target modules */ + +#define NFLOGGER_NAME_LEN 64 + +int sysctl_nf_log_all_netns __read_mostly; +EXPORT_SYMBOL(sysctl_nf_log_all_netns); + +static struct nf_logger __rcu *loggers[NFPROTO_NUMPROTO][NF_LOG_TYPE_MAX] __read_mostly; +static DEFINE_MUTEX(nf_log_mutex); + +#define nft_log_dereference(logger) \ + rcu_dereference_protected(logger, lockdep_is_held(&nf_log_mutex)) + +static struct nf_logger *__find_logger(int pf, const char *str_logger) +{ + struct nf_logger *log; + int i; + + for (i = 0; i < NF_LOG_TYPE_MAX; i++) { + if (loggers[pf][i] == NULL) + continue; + + log = nft_log_dereference(loggers[pf][i]); + if (!strncasecmp(str_logger, log->name, strlen(log->name))) + return log; + } + + return NULL; +} + +int nf_log_set(struct net *net, u_int8_t pf, const struct nf_logger *logger) +{ + const struct nf_logger *log; + + if (pf == NFPROTO_UNSPEC || pf >= ARRAY_SIZE(net->nf.nf_loggers)) + return -EOPNOTSUPP; + + mutex_lock(&nf_log_mutex); + log = nft_log_dereference(net->nf.nf_loggers[pf]); + if (log == NULL) + rcu_assign_pointer(net->nf.nf_loggers[pf], logger); + + mutex_unlock(&nf_log_mutex); + + return 0; +} +EXPORT_SYMBOL(nf_log_set); + +void nf_log_unset(struct net *net, const struct nf_logger *logger) +{ + int i; + const struct nf_logger *log; + + mutex_lock(&nf_log_mutex); + for (i = 0; i < NFPROTO_NUMPROTO; i++) { + log = nft_log_dereference(net->nf.nf_loggers[i]); + if (log == logger) + RCU_INIT_POINTER(net->nf.nf_loggers[i], NULL); + } + mutex_unlock(&nf_log_mutex); +} +EXPORT_SYMBOL(nf_log_unset); + +/* return EEXIST if the same logger is registered, 0 on success. 
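+ * Registering with pf == NFPROTO_UNSPEC binds the logger of that type to
+ * every protocol family at once.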
*/ +int nf_log_register(u_int8_t pf, struct nf_logger *logger) +{ + int i; + int ret = 0; + + if (pf >= ARRAY_SIZE(init_net.nf.nf_loggers)) + return -EINVAL; + + mutex_lock(&nf_log_mutex); + + if (pf == NFPROTO_UNSPEC) { + for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++) { + if (rcu_access_pointer(loggers[i][logger->type])) { + ret = -EEXIST; + goto unlock; + } + } + for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++) + rcu_assign_pointer(loggers[i][logger->type], logger); + } else { + if (rcu_access_pointer(loggers[pf][logger->type])) { + ret = -EEXIST; + goto unlock; + } + rcu_assign_pointer(loggers[pf][logger->type], logger); + } + +unlock: + mutex_unlock(&nf_log_mutex); + return ret; +} +EXPORT_SYMBOL(nf_log_register); + +void nf_log_unregister(struct nf_logger *logger) +{ + const struct nf_logger *log; + int i; + + mutex_lock(&nf_log_mutex); + for (i = 0; i < NFPROTO_NUMPROTO; i++) { + log = nft_log_dereference(loggers[i][logger->type]); + if (log == logger) + RCU_INIT_POINTER(loggers[i][logger->type], NULL); + } + mutex_unlock(&nf_log_mutex); + synchronize_rcu(); +} +EXPORT_SYMBOL(nf_log_unregister); + +int nf_log_bind_pf(struct net *net, u_int8_t pf, + const struct nf_logger *logger) +{ + if (pf >= ARRAY_SIZE(net->nf.nf_loggers)) + return -EINVAL; + mutex_lock(&nf_log_mutex); + if (__find_logger(pf, logger->name) == NULL) { + mutex_unlock(&nf_log_mutex); + return -ENOENT; + } + rcu_assign_pointer(net->nf.nf_loggers[pf], logger); + mutex_unlock(&nf_log_mutex); + return 0; +} +EXPORT_SYMBOL(nf_log_bind_pf); + +void nf_log_unbind_pf(struct net *net, u_int8_t pf) +{ + if (pf >= ARRAY_SIZE(net->nf.nf_loggers)) + return; + mutex_lock(&nf_log_mutex); + RCU_INIT_POINTER(net->nf.nf_loggers[pf], NULL); + mutex_unlock(&nf_log_mutex); +} +EXPORT_SYMBOL(nf_log_unbind_pf); + +int nf_logger_find_get(int pf, enum nf_log_type type) +{ + struct nf_logger *logger; + int ret = -ENOENT; + + if (pf == NFPROTO_INET) { + ret = nf_logger_find_get(NFPROTO_IPV4, type); + if (ret < 0) + return ret; + + ret = nf_logger_find_get(NFPROTO_IPV6, type); + if (ret < 0) { + nf_logger_put(NFPROTO_IPV4, type); + return ret; + } + + return 0; + } + + rcu_read_lock(); + logger = rcu_dereference(loggers[pf][type]); + if (logger == NULL) + goto out; + + if (try_module_get(logger->me)) + ret = 0; +out: + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(nf_logger_find_get); + +void nf_logger_put(int pf, enum nf_log_type type) +{ + struct nf_logger *logger; + + if (pf == NFPROTO_INET) { + nf_logger_put(NFPROTO_IPV4, type); + nf_logger_put(NFPROTO_IPV6, type); + return; + } + + BUG_ON(loggers[pf][type] == NULL); + + rcu_read_lock(); + logger = rcu_dereference(loggers[pf][type]); + module_put(logger->me); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(nf_logger_put); + +void nf_log_packet(struct net *net, + u_int8_t pf, + unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *fmt, ...) 
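+/* The variadic arguments are formatted into a stack buffer of
+ * NF_LOG_PREFIXLEN bytes and the packet is handed, under rcu_read_lock(),
+ * to the logger registered for the type requested via loginfo when one is
+ * given, otherwise to the logger bound to this protocol family in this
+ * network namespace.
+ */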
+{ + va_list args; + char prefix[NF_LOG_PREFIXLEN]; + const struct nf_logger *logger; + + rcu_read_lock(); + if (loginfo != NULL) + logger = rcu_dereference(loggers[pf][loginfo->type]); + else + logger = rcu_dereference(net->nf.nf_loggers[pf]); + + if (logger) { + va_start(args, fmt); + vsnprintf(prefix, sizeof(prefix), fmt, args); + va_end(args); + logger->logfn(net, pf, hooknum, skb, in, out, loginfo, prefix); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(nf_log_packet); + +void nf_log_trace(struct net *net, + u_int8_t pf, + unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, const char *fmt, ...) +{ + va_list args; + char prefix[NF_LOG_PREFIXLEN]; + const struct nf_logger *logger; + + rcu_read_lock(); + logger = rcu_dereference(net->nf.nf_loggers[pf]); + if (logger) { + va_start(args, fmt); + vsnprintf(prefix, sizeof(prefix), fmt, args); + va_end(args); + logger->logfn(net, pf, hooknum, skb, in, out, loginfo, prefix); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(nf_log_trace); + +#define S_SIZE (1024 - (sizeof(unsigned int) + 1)) + +struct nf_log_buf { + unsigned int count; + char buf[S_SIZE + 1]; +}; +static struct nf_log_buf emergency, *emergency_ptr = &emergency; + +__printf(2, 3) int nf_log_buf_add(struct nf_log_buf *m, const char *f, ...) +{ + va_list args; + int len; + + if (likely(m->count < S_SIZE)) { + va_start(args, f); + len = vsnprintf(m->buf + m->count, S_SIZE - m->count, f, args); + va_end(args); + if (likely(m->count + len < S_SIZE)) { + m->count += len; + return 0; + } + } + m->count = S_SIZE; + printk_once(KERN_ERR KBUILD_MODNAME " please increase S_SIZE\n"); + return -1; +} +EXPORT_SYMBOL_GPL(nf_log_buf_add); + +struct nf_log_buf *nf_log_buf_open(void) +{ + struct nf_log_buf *m = kmalloc(sizeof(*m), GFP_ATOMIC); + + if (unlikely(!m)) { + local_bh_disable(); + do { + m = xchg(&emergency_ptr, NULL); + } while (!m); + } + m->count = 0; + return m; +} +EXPORT_SYMBOL_GPL(nf_log_buf_open); + +void nf_log_buf_close(struct nf_log_buf *m) +{ + m->buf[m->count] = 0; + printk("%s\n", m->buf); + + if (likely(m != &emergency)) + kfree(m); + else { + emergency_ptr = m; + local_bh_enable(); + } +} +EXPORT_SYMBOL_GPL(nf_log_buf_close); + +#ifdef CONFIG_PROC_FS +static void *seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + + mutex_lock(&nf_log_mutex); + + if (*pos >= ARRAY_SIZE(net->nf.nf_loggers)) + return NULL; + + return pos; +} + +static void *seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + struct net *net = seq_file_net(s); + + (*pos)++; + + if (*pos >= ARRAY_SIZE(net->nf.nf_loggers)) + return NULL; + + return pos; +} + +static void seq_stop(struct seq_file *s, void *v) +{ + mutex_unlock(&nf_log_mutex); +} + +static int seq_show(struct seq_file *s, void *v) +{ + loff_t *pos = v; + const struct nf_logger *logger; + int i; + struct net *net = seq_file_net(s); + + logger = nft_log_dereference(net->nf.nf_loggers[*pos]); + + if (!logger) + seq_printf(s, "%2lld NONE (", *pos); + else + seq_printf(s, "%2lld %s (", *pos, logger->name); + + if (seq_has_overflowed(s)) + return -ENOSPC; + + for (i = 0; i < NF_LOG_TYPE_MAX; i++) { + if (loggers[*pos][i] == NULL) + continue; + + logger = nft_log_dereference(loggers[*pos][i]); + seq_puts(s, logger->name); + if (i == 0 && loggers[*pos][i + 1] != NULL) + seq_puts(s, ","); + + if (seq_has_overflowed(s)) + return -ENOSPC; + } + + seq_puts(s, ")\n"); + + if (seq_has_overflowed(s)) + return -ENOSPC; + return 0; 
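+	/* A line of /proc/net/netfilter/nf_log therefore reads, for example,
+	 * " 2 nf_log_ipv4 (nf_log_ipv4)": the protocol family, the logger
+	 * currently bound to it (or NONE), and the registered candidates in
+	 * parentheses.
+	 */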
+} + +static const struct seq_operations nflog_seq_ops = { + .start = seq_start, + .next = seq_next, + .stop = seq_stop, + .show = seq_show, +}; +#endif /* PROC_FS */ + +#ifdef CONFIG_SYSCTL +static char nf_log_sysctl_fnames[NFPROTO_NUMPROTO-NFPROTO_UNSPEC][3]; +static struct ctl_table nf_log_sysctl_table[NFPROTO_NUMPROTO+1]; +static struct ctl_table_header *nf_log_sysctl_fhdr; + +static struct ctl_table nf_log_sysctl_ftable[] = { + { + .procname = "nf_log_all_netns", + .data = &sysctl_nf_log_all_netns, + .maxlen = sizeof(sysctl_nf_log_all_netns), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +static int nf_log_proc_dostring(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + const struct nf_logger *logger; + char buf[NFLOGGER_NAME_LEN]; + int r = 0; + int tindex = (unsigned long)table->extra1; + struct net *net = table->extra2; + + if (write) { + struct ctl_table tmp = *table; + + /* proc_dostring() can append to existing strings, so we need to + * initialize it as an empty string. + */ + buf[0] = '\0'; + tmp.data = buf; + r = proc_dostring(&tmp, write, buffer, lenp, ppos); + if (r) + return r; + + if (!strcmp(buf, "NONE")) { + nf_log_unbind_pf(net, tindex); + return 0; + } + mutex_lock(&nf_log_mutex); + logger = __find_logger(tindex, buf); + if (logger == NULL) { + mutex_unlock(&nf_log_mutex); + return -ENOENT; + } + rcu_assign_pointer(net->nf.nf_loggers[tindex], logger); + mutex_unlock(&nf_log_mutex); + } else { + struct ctl_table tmp = *table; + + tmp.data = buf; + mutex_lock(&nf_log_mutex); + logger = nft_log_dereference(net->nf.nf_loggers[tindex]); + if (!logger) + strscpy(buf, "NONE", sizeof(buf)); + else + strscpy(buf, logger->name, sizeof(buf)); + mutex_unlock(&nf_log_mutex); + r = proc_dostring(&tmp, write, buffer, lenp, ppos); + } + + return r; +} + +static int netfilter_log_sysctl_init(struct net *net) +{ + int i; + struct ctl_table *table; + + table = nf_log_sysctl_table; + if (!net_eq(net, &init_net)) { + table = kmemdup(nf_log_sysctl_table, + sizeof(nf_log_sysctl_table), + GFP_KERNEL); + if (!table) + goto err_alloc; + } else { + for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++) { + snprintf(nf_log_sysctl_fnames[i], + 3, "%d", i); + nf_log_sysctl_table[i].procname = + nf_log_sysctl_fnames[i]; + nf_log_sysctl_table[i].maxlen = NFLOGGER_NAME_LEN; + nf_log_sysctl_table[i].mode = 0644; + nf_log_sysctl_table[i].proc_handler = + nf_log_proc_dostring; + nf_log_sysctl_table[i].extra1 = + (void *)(unsigned long) i; + } + nf_log_sysctl_fhdr = register_net_sysctl(net, "net/netfilter", + nf_log_sysctl_ftable); + if (!nf_log_sysctl_fhdr) + goto err_freg; + } + + for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++) + table[i].extra2 = net; + + net->nf.nf_log_dir_header = register_net_sysctl(net, + "net/netfilter/nf_log", + table); + if (!net->nf.nf_log_dir_header) + goto err_reg; + + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); + else + unregister_net_sysctl_table(nf_log_sysctl_fhdr); +err_freg: +err_alloc: + return -ENOMEM; +} + +static void netfilter_log_sysctl_exit(struct net *net) +{ + struct ctl_table *table; + + table = net->nf.nf_log_dir_header->ctl_table_arg; + unregister_net_sysctl_table(net->nf.nf_log_dir_header); + if (!net_eq(net, &init_net)) + kfree(table); + else + unregister_net_sysctl_table(nf_log_sysctl_fhdr); +} +#else +static int netfilter_log_sysctl_init(struct net *net) +{ + return 0; +} + +static void netfilter_log_sysctl_exit(struct net *net) +{ +} +#endif /* CONFIG_SYSCTL */ + +static int 
__net_init nf_log_net_init(struct net *net) +{ + int ret = -ENOMEM; + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("nf_log", 0444, net->nf.proc_netfilter, + &nflog_seq_ops, sizeof(struct seq_net_private))) + return ret; +#endif + ret = netfilter_log_sysctl_init(net); + if (ret < 0) + goto out_sysctl; + + return 0; + +out_sysctl: +#ifdef CONFIG_PROC_FS + remove_proc_entry("nf_log", net->nf.proc_netfilter); +#endif + return ret; +} + +static void __net_exit nf_log_net_exit(struct net *net) +{ + netfilter_log_sysctl_exit(net); +#ifdef CONFIG_PROC_FS + remove_proc_entry("nf_log", net->nf.proc_netfilter); +#endif +} + +static struct pernet_operations nf_log_net_ops = { + .init = nf_log_net_init, + .exit = nf_log_net_exit, +}; + +int __init netfilter_log_init(void) +{ + return register_pernet_subsys(&nf_log_net_ops); +} diff --git a/net/netfilter/nf_log_syslog.c b/net/netfilter/nf_log_syslog.c new file mode 100644 index 000000000..584022260 --- /dev/null +++ b/net/netfilter/nf_log_syslog.c @@ -0,0 +1,1083 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static const struct nf_loginfo default_loginfo = { + .type = NF_LOG_TYPE_LOG, + .u = { + .log = { + .level = LOGLEVEL_NOTICE, + .logflags = NF_LOG_DEFAULT_MASK, + }, + }, +}; + +struct arppayload { + unsigned char mac_src[ETH_ALEN]; + unsigned char ip_src[4]; + unsigned char mac_dst[ETH_ALEN]; + unsigned char ip_dst[4]; +}; + +/* Guard against containers flooding syslog. */ +static bool nf_log_allowed(const struct net *net) +{ + return net_eq(net, &init_net) || sysctl_nf_log_all_netns; +} + +static void nf_log_dump_vlan(struct nf_log_buf *m, const struct sk_buff *skb) +{ + u16 vid; + + if (!skb_vlan_tag_present(skb)) + return; + + vid = skb_vlan_tag_get(skb); + nf_log_buf_add(m, "VPROTO=%04x VID=%u ", ntohs(skb->vlan_proto), vid); +} +static void noinline_for_stack +dump_arp_packet(struct nf_log_buf *m, + const struct nf_loginfo *info, + const struct sk_buff *skb, unsigned int nhoff) +{ + const struct arppayload *ap; + struct arppayload _arpp; + const struct arphdr *ah; + unsigned int logflags; + struct arphdr _arph; + + ah = skb_header_pointer(skb, nhoff, sizeof(_arph), &_arph); + if (!ah) { + nf_log_buf_add(m, "TRUNCATED"); + return; + } + + if (info->type == NF_LOG_TYPE_LOG) + logflags = info->u.log.logflags; + else + logflags = NF_LOG_DEFAULT_MASK; + + if (logflags & NF_LOG_MACDECODE) { + nf_log_buf_add(m, "MACSRC=%pM MACDST=%pM ", + eth_hdr(skb)->h_source, eth_hdr(skb)->h_dest); + nf_log_dump_vlan(m, skb); + nf_log_buf_add(m, "MACPROTO=%04x ", + ntohs(eth_hdr(skb)->h_proto)); + } + + nf_log_buf_add(m, "ARP HTYPE=%d PTYPE=0x%04x OPCODE=%d", + ntohs(ah->ar_hrd), ntohs(ah->ar_pro), ntohs(ah->ar_op)); + /* If it's for Ethernet and the lengths are OK, then log the ARP + * payload. 
+ */ + if (ah->ar_hrd != htons(ARPHRD_ETHER) || + ah->ar_hln != ETH_ALEN || + ah->ar_pln != sizeof(__be32)) + return; + + ap = skb_header_pointer(skb, nhoff + sizeof(_arph), sizeof(_arpp), &_arpp); + if (!ap) { + nf_log_buf_add(m, " INCOMPLETE [%zu bytes]", + skb->len - sizeof(_arph)); + return; + } + nf_log_buf_add(m, " MACSRC=%pM IPSRC=%pI4 MACDST=%pM IPDST=%pI4", + ap->mac_src, ap->ip_src, ap->mac_dst, ap->ip_dst); +} + +static void +nf_log_dump_packet_common(struct nf_log_buf *m, u8 pf, + unsigned int hooknum, const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, const char *prefix, + struct net *net) +{ + const struct net_device *physoutdev __maybe_unused; + const struct net_device *physindev __maybe_unused; + + nf_log_buf_add(m, KERN_SOH "%c%sIN=%s OUT=%s ", + '0' + loginfo->u.log.level, prefix, + in ? in->name : "", + out ? out->name : ""); +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + physindev = nf_bridge_get_physindev(skb, net); + if (physindev && in != physindev) + nf_log_buf_add(m, "PHYSIN=%s ", physindev->name); + physoutdev = nf_bridge_get_physoutdev(skb); + if (physoutdev && out != physoutdev) + nf_log_buf_add(m, "PHYSOUT=%s ", physoutdev->name); +#endif +} + +static void nf_log_arp_packet(struct net *net, u_int8_t pf, + unsigned int hooknum, const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *prefix) +{ + struct nf_log_buf *m; + + if (!nf_log_allowed(net)) + return; + + m = nf_log_buf_open(); + + if (!loginfo) + loginfo = &default_loginfo; + + nf_log_dump_packet_common(m, pf, hooknum, skb, in, out, loginfo, + prefix, net); + dump_arp_packet(m, loginfo, skb, skb_network_offset(skb)); + + nf_log_buf_close(m); +} + +static struct nf_logger nf_arp_logger __read_mostly = { + .name = "nf_log_arp", + .type = NF_LOG_TYPE_LOG, + .logfn = nf_log_arp_packet, + .me = THIS_MODULE, +}; + +static void nf_log_dump_sk_uid_gid(struct net *net, struct nf_log_buf *m, + struct sock *sk) +{ + if (!sk || !sk_fullsock(sk) || !net_eq(net, sock_net(sk))) + return; + + read_lock_bh(&sk->sk_callback_lock); + if (sk->sk_socket && sk->sk_socket->file) { + const struct cred *cred = sk->sk_socket->file->f_cred; + + nf_log_buf_add(m, "UID=%u GID=%u ", + from_kuid_munged(&init_user_ns, cred->fsuid), + from_kgid_munged(&init_user_ns, cred->fsgid)); + } + read_unlock_bh(&sk->sk_callback_lock); +} + +static noinline_for_stack int +nf_log_dump_tcp_header(struct nf_log_buf *m, + const struct sk_buff *skb, + u8 proto, int fragment, + unsigned int offset, + unsigned int logflags) +{ + struct tcphdr _tcph; + const struct tcphdr *th; + + /* Max length: 10 "PROTO=TCP " */ + nf_log_buf_add(m, "PROTO=TCP "); + + if (fragment) + return 0; + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + th = skb_header_pointer(skb, offset, sizeof(_tcph), &_tcph); + if (!th) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", skb->len - offset); + return 1; + } + + /* Max length: 20 "SPT=65535 DPT=65535 " */ + nf_log_buf_add(m, "SPT=%u DPT=%u ", + ntohs(th->source), ntohs(th->dest)); + /* Max length: 30 "SEQ=4294967295 ACK=4294967295 " */ + if (logflags & NF_LOG_TCPSEQ) { + nf_log_buf_add(m, "SEQ=%u ACK=%u ", + ntohl(th->seq), ntohl(th->ack_seq)); + } + + /* Max length: 13 "WINDOW=65535 " */ + nf_log_buf_add(m, "WINDOW=%u ", ntohs(th->window)); + /* Max length: 9 "RES=0x3C " */ + nf_log_buf_add(m, "RES=0x%02x ", (u_int8_t)(ntohl(tcp_flag_word(th) & + TCP_RESERVED_BITS) >> 22)); + 
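+	/* After the shift above, the four reserved header bits occupy bits
+	 * 2-5, so RES= reports their raw value multiplied by four (hence the
+	 * 0x3C maximum).
+	 */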
/* Max length: 32 "CWR ECE URG ACK PSH RST SYN FIN " */ + if (th->cwr) + nf_log_buf_add(m, "CWR "); + if (th->ece) + nf_log_buf_add(m, "ECE "); + if (th->urg) + nf_log_buf_add(m, "URG "); + if (th->ack) + nf_log_buf_add(m, "ACK "); + if (th->psh) + nf_log_buf_add(m, "PSH "); + if (th->rst) + nf_log_buf_add(m, "RST "); + if (th->syn) + nf_log_buf_add(m, "SYN "); + if (th->fin) + nf_log_buf_add(m, "FIN "); + /* Max length: 11 "URGP=65535 " */ + nf_log_buf_add(m, "URGP=%u ", ntohs(th->urg_ptr)); + + if ((logflags & NF_LOG_TCPOPT) && th->doff * 4 > sizeof(struct tcphdr)) { + unsigned int optsize = th->doff * 4 - sizeof(struct tcphdr); + u8 _opt[60 - sizeof(struct tcphdr)]; + unsigned int i; + const u8 *op; + + op = skb_header_pointer(skb, offset + sizeof(struct tcphdr), + optsize, _opt); + if (!op) { + nf_log_buf_add(m, "OPT (TRUNCATED)"); + return 1; + } + + /* Max length: 127 "OPT (" 15*4*2chars ") " */ + nf_log_buf_add(m, "OPT ("); + for (i = 0; i < optsize; i++) + nf_log_buf_add(m, "%02X", op[i]); + + nf_log_buf_add(m, ") "); + } + + return 0; +} + +static noinline_for_stack int +nf_log_dump_udp_header(struct nf_log_buf *m, + const struct sk_buff *skb, + u8 proto, int fragment, + unsigned int offset) +{ + struct udphdr _udph; + const struct udphdr *uh; + + if (proto == IPPROTO_UDP) + /* Max length: 10 "PROTO=UDP " */ + nf_log_buf_add(m, "PROTO=UDP "); + else /* Max length: 14 "PROTO=UDPLITE " */ + nf_log_buf_add(m, "PROTO=UDPLITE "); + + if (fragment) + goto out; + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + uh = skb_header_pointer(skb, offset, sizeof(_udph), &_udph); + if (!uh) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", skb->len - offset); + + return 1; + } + + /* Max length: 20 "SPT=65535 DPT=65535 " */ + nf_log_buf_add(m, "SPT=%u DPT=%u LEN=%u ", + ntohs(uh->source), ntohs(uh->dest), ntohs(uh->len)); + +out: + return 0; +} + +/* One level of recursion won't kill us */ +static noinline_for_stack void +dump_ipv4_packet(struct net *net, struct nf_log_buf *m, + const struct nf_loginfo *info, + const struct sk_buff *skb, unsigned int iphoff) +{ + const struct iphdr *ih; + unsigned int logflags; + struct iphdr _iph; + + if (info->type == NF_LOG_TYPE_LOG) + logflags = info->u.log.logflags; + else + logflags = NF_LOG_DEFAULT_MASK; + + ih = skb_header_pointer(skb, iphoff, sizeof(_iph), &_iph); + if (!ih) { + nf_log_buf_add(m, "TRUNCATED"); + return; + } + + /* Important fields: + * TOS, len, DF/MF, fragment offset, TTL, src, dst, options. 
+ * Max length: 40 "SRC=255.255.255.255 DST=255.255.255.255 " + */ + nf_log_buf_add(m, "SRC=%pI4 DST=%pI4 ", &ih->saddr, &ih->daddr); + + /* Max length: 46 "LEN=65535 TOS=0xFF PREC=0xFF TTL=255 ID=65535 " */ + nf_log_buf_add(m, "LEN=%u TOS=0x%02X PREC=0x%02X TTL=%u ID=%u ", + iph_totlen(skb, ih), ih->tos & IPTOS_TOS_MASK, + ih->tos & IPTOS_PREC_MASK, ih->ttl, ntohs(ih->id)); + + /* Max length: 6 "CE DF MF " */ + if (ntohs(ih->frag_off) & IP_CE) + nf_log_buf_add(m, "CE "); + if (ntohs(ih->frag_off) & IP_DF) + nf_log_buf_add(m, "DF "); + if (ntohs(ih->frag_off) & IP_MF) + nf_log_buf_add(m, "MF "); + + /* Max length: 11 "FRAG:65535 " */ + if (ntohs(ih->frag_off) & IP_OFFSET) + nf_log_buf_add(m, "FRAG:%u ", ntohs(ih->frag_off) & IP_OFFSET); + + if ((logflags & NF_LOG_IPOPT) && + ih->ihl * 4 > sizeof(struct iphdr)) { + unsigned char _opt[4 * 15 - sizeof(struct iphdr)]; + const unsigned char *op; + unsigned int i, optsize; + + optsize = ih->ihl * 4 - sizeof(struct iphdr); + op = skb_header_pointer(skb, iphoff + sizeof(_iph), + optsize, _opt); + if (!op) { + nf_log_buf_add(m, "TRUNCATED"); + return; + } + + /* Max length: 127 "OPT (" 15*4*2chars ") " */ + nf_log_buf_add(m, "OPT ("); + for (i = 0; i < optsize; i++) + nf_log_buf_add(m, "%02X", op[i]); + nf_log_buf_add(m, ") "); + } + + switch (ih->protocol) { + case IPPROTO_TCP: + if (nf_log_dump_tcp_header(m, skb, ih->protocol, + ntohs(ih->frag_off) & IP_OFFSET, + iphoff + ih->ihl * 4, logflags)) + return; + break; + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + if (nf_log_dump_udp_header(m, skb, ih->protocol, + ntohs(ih->frag_off) & IP_OFFSET, + iphoff + ih->ihl * 4)) + return; + break; + case IPPROTO_ICMP: { + static const size_t required_len[NR_ICMP_TYPES + 1] = { + [ICMP_ECHOREPLY] = 4, + [ICMP_DEST_UNREACH] = 8 + sizeof(struct iphdr), + [ICMP_SOURCE_QUENCH] = 8 + sizeof(struct iphdr), + [ICMP_REDIRECT] = 8 + sizeof(struct iphdr), + [ICMP_ECHO] = 4, + [ICMP_TIME_EXCEEDED] = 8 + sizeof(struct iphdr), + [ICMP_PARAMETERPROB] = 8 + sizeof(struct iphdr), + [ICMP_TIMESTAMP] = 20, + [ICMP_TIMESTAMPREPLY] = 20, + [ICMP_ADDRESS] = 12, + [ICMP_ADDRESSREPLY] = 12 }; + const struct icmphdr *ich; + struct icmphdr _icmph; + + /* Max length: 11 "PROTO=ICMP " */ + nf_log_buf_add(m, "PROTO=ICMP "); + + if (ntohs(ih->frag_off) & IP_OFFSET) + break; + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + ich = skb_header_pointer(skb, iphoff + ih->ihl * 4, + sizeof(_icmph), &_icmph); + if (!ich) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", + skb->len - iphoff - ih->ihl * 4); + break; + } + + /* Max length: 18 "TYPE=255 CODE=255 " */ + nf_log_buf_add(m, "TYPE=%u CODE=%u ", ich->type, ich->code); + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + if (ich->type <= NR_ICMP_TYPES && + required_len[ich->type] && + skb->len - iphoff - ih->ihl * 4 < required_len[ich->type]) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", + skb->len - iphoff - ih->ihl * 4); + break; + } + + switch (ich->type) { + case ICMP_ECHOREPLY: + case ICMP_ECHO: + /* Max length: 19 "ID=65535 SEQ=65535 " */ + nf_log_buf_add(m, "ID=%u SEQ=%u ", + ntohs(ich->un.echo.id), + ntohs(ich->un.echo.sequence)); + break; + + case ICMP_PARAMETERPROB: + /* Max length: 14 "PARAMETER=255 " */ + nf_log_buf_add(m, "PARAMETER=%u ", + ntohl(ich->un.gateway) >> 24); + break; + case ICMP_REDIRECT: + /* Max length: 24 "GATEWAY=255.255.255.255 " */ + nf_log_buf_add(m, "GATEWAY=%pI4 ", &ich->un.gateway); + fallthrough; + case ICMP_DEST_UNREACH: + case ICMP_SOURCE_QUENCH: + case ICMP_TIME_EXCEEDED: + /* Max length: 
3+maxlen */ + if (!iphoff) { /* Only recurse once. */ + nf_log_buf_add(m, "["); + dump_ipv4_packet(net, m, info, skb, + iphoff + ih->ihl * 4 + sizeof(_icmph)); + nf_log_buf_add(m, "] "); + } + + /* Max length: 10 "MTU=65535 " */ + if (ich->type == ICMP_DEST_UNREACH && + ich->code == ICMP_FRAG_NEEDED) { + nf_log_buf_add(m, "MTU=%u ", + ntohs(ich->un.frag.mtu)); + } + } + break; + } + /* Max Length */ + case IPPROTO_AH: { + const struct ip_auth_hdr *ah; + struct ip_auth_hdr _ahdr; + + if (ntohs(ih->frag_off) & IP_OFFSET) + break; + + /* Max length: 9 "PROTO=AH " */ + nf_log_buf_add(m, "PROTO=AH "); + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + ah = skb_header_pointer(skb, iphoff + ih->ihl * 4, + sizeof(_ahdr), &_ahdr); + if (!ah) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", + skb->len - iphoff - ih->ihl * 4); + break; + } + + /* Length: 15 "SPI=0xF1234567 " */ + nf_log_buf_add(m, "SPI=0x%x ", ntohl(ah->spi)); + break; + } + case IPPROTO_ESP: { + const struct ip_esp_hdr *eh; + struct ip_esp_hdr _esph; + + /* Max length: 10 "PROTO=ESP " */ + nf_log_buf_add(m, "PROTO=ESP "); + + if (ntohs(ih->frag_off) & IP_OFFSET) + break; + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + eh = skb_header_pointer(skb, iphoff + ih->ihl * 4, + sizeof(_esph), &_esph); + if (!eh) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", + skb->len - iphoff - ih->ihl * 4); + break; + } + + /* Length: 15 "SPI=0xF1234567 " */ + nf_log_buf_add(m, "SPI=0x%x ", ntohl(eh->spi)); + break; + } + /* Max length: 10 "PROTO 255 " */ + default: + nf_log_buf_add(m, "PROTO=%u ", ih->protocol); + } + + /* Max length: 15 "UID=4294967295 " */ + if ((logflags & NF_LOG_UID) && !iphoff) + nf_log_dump_sk_uid_gid(net, m, skb->sk); + + /* Max length: 16 "MARK=0xFFFFFFFF " */ + if (!iphoff && skb->mark) + nf_log_buf_add(m, "MARK=0x%x ", skb->mark); + + /* Proto Max log string length */ + /* IP: 40+46+6+11+127 = 230 */ + /* TCP: 10+max(25,20+30+13+9+32+11+127) = 252 */ + /* UDP: 10+max(25,20) = 35 */ + /* UDPLITE: 14+max(25,20) = 39 */ + /* ICMP: 11+max(25, 18+25+max(19,14,24+3+n+10,3+n+10)) = 91+n */ + /* ESP: 10+max(25)+15 = 50 */ + /* AH: 9+max(25)+15 = 49 */ + /* unknown: 10 */ + + /* (ICMP allows recursion one level deep) */ + /* maxlen = IP + ICMP + IP + max(TCP,UDP,ICMP,unknown) */ + /* maxlen = 230+ 91 + 230 + 252 = 803 */ +} + +static noinline_for_stack void +dump_ipv6_packet(struct net *net, struct nf_log_buf *m, + const struct nf_loginfo *info, + const struct sk_buff *skb, unsigned int ip6hoff, + int recurse) +{ + const struct ipv6hdr *ih; + unsigned int hdrlen = 0; + unsigned int logflags; + struct ipv6hdr _ip6h; + unsigned int ptr; + u8 currenthdr; + int fragment; + + if (info->type == NF_LOG_TYPE_LOG) + logflags = info->u.log.logflags; + else + logflags = NF_LOG_DEFAULT_MASK; + + ih = skb_header_pointer(skb, ip6hoff, sizeof(_ip6h), &_ip6h); + if (!ih) { + nf_log_buf_add(m, "TRUNCATED"); + return; + } + + /* Max length: 88 "SRC=0000.0000.0000.0000.0000.0000.0000.0000 DST=0000.0000.0000.0000.0000.0000.0000.0000 " */ + nf_log_buf_add(m, "SRC=%pI6 DST=%pI6 ", &ih->saddr, &ih->daddr); + + /* Max length: 44 "LEN=65535 TC=255 HOPLIMIT=255 FLOWLBL=FFFFF " */ + nf_log_buf_add(m, "LEN=%zu TC=%u HOPLIMIT=%u FLOWLBL=%u ", + ntohs(ih->payload_len) + sizeof(struct ipv6hdr), + (ntohl(*(__be32 *)ih) & 0x0ff00000) >> 20, + ih->hop_limit, + (ntohl(*(__be32 *)ih) & 0x000fffff)); + + fragment = 0; + ptr = ip6hoff + sizeof(struct ipv6hdr); + currenthdr = ih->nexthdr; + while (currenthdr != NEXTHDR_NONE && nf_ip6_ext_hdr(currenthdr)) { + 
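+		/* Each pass handles one extension header (fragment, AH, ESP,
+		 * hop-by-hop/routing/destination options) and advances ptr by
+		 * its length; whatever non-extension header remains is picked
+		 * up by the transport switch further down.
+		 */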
struct ipv6_opt_hdr _hdr; + const struct ipv6_opt_hdr *hp; + + hp = skb_header_pointer(skb, ptr, sizeof(_hdr), &_hdr); + if (!hp) { + nf_log_buf_add(m, "TRUNCATED"); + return; + } + + /* Max length: 48 "OPT (...) " */ + if (logflags & NF_LOG_IPOPT) + nf_log_buf_add(m, "OPT ( "); + + switch (currenthdr) { + case IPPROTO_FRAGMENT: { + struct frag_hdr _fhdr; + const struct frag_hdr *fh; + + nf_log_buf_add(m, "FRAG:"); + fh = skb_header_pointer(skb, ptr, sizeof(_fhdr), + &_fhdr); + if (!fh) { + nf_log_buf_add(m, "TRUNCATED "); + return; + } + + /* Max length: 6 "65535 " */ + nf_log_buf_add(m, "%u ", ntohs(fh->frag_off) & 0xFFF8); + + /* Max length: 11 "INCOMPLETE " */ + if (fh->frag_off & htons(0x0001)) + nf_log_buf_add(m, "INCOMPLETE "); + + nf_log_buf_add(m, "ID:%08x ", + ntohl(fh->identification)); + + if (ntohs(fh->frag_off) & 0xFFF8) + fragment = 1; + + hdrlen = 8; + break; + } + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + case IPPROTO_HOPOPTS: + if (fragment) { + if (logflags & NF_LOG_IPOPT) + nf_log_buf_add(m, ")"); + return; + } + hdrlen = ipv6_optlen(hp); + break; + /* Max Length */ + case IPPROTO_AH: + if (logflags & NF_LOG_IPOPT) { + struct ip_auth_hdr _ahdr; + const struct ip_auth_hdr *ah; + + /* Max length: 3 "AH " */ + nf_log_buf_add(m, "AH "); + + if (fragment) { + nf_log_buf_add(m, ")"); + return; + } + + ah = skb_header_pointer(skb, ptr, sizeof(_ahdr), + &_ahdr); + if (!ah) { + /* Max length: 26 "INCOMPLETE [65535 bytes] )" */ + nf_log_buf_add(m, "INCOMPLETE [%u bytes] )", + skb->len - ptr); + return; + } + + /* Length: 15 "SPI=0xF1234567 */ + nf_log_buf_add(m, "SPI=0x%x ", ntohl(ah->spi)); + } + + hdrlen = ipv6_authlen(hp); + break; + case IPPROTO_ESP: + if (logflags & NF_LOG_IPOPT) { + struct ip_esp_hdr _esph; + const struct ip_esp_hdr *eh; + + /* Max length: 4 "ESP " */ + nf_log_buf_add(m, "ESP "); + + if (fragment) { + nf_log_buf_add(m, ")"); + return; + } + + /* Max length: 26 "INCOMPLETE [65535 bytes] )" */ + eh = skb_header_pointer(skb, ptr, sizeof(_esph), + &_esph); + if (!eh) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] )", + skb->len - ptr); + return; + } + + /* Length: 16 "SPI=0xF1234567 )" */ + nf_log_buf_add(m, "SPI=0x%x )", + ntohl(eh->spi)); + } + return; + default: + /* Max length: 20 "Unknown Ext Hdr 255" */ + nf_log_buf_add(m, "Unknown Ext Hdr %u", currenthdr); + return; + } + if (logflags & NF_LOG_IPOPT) + nf_log_buf_add(m, ") "); + + currenthdr = hp->nexthdr; + ptr += hdrlen; + } + + switch (currenthdr) { + case IPPROTO_TCP: + if (nf_log_dump_tcp_header(m, skb, currenthdr, fragment, + ptr, logflags)) + return; + break; + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + if (nf_log_dump_udp_header(m, skb, currenthdr, fragment, ptr)) + return; + break; + case IPPROTO_ICMPV6: { + struct icmp6hdr _icmp6h; + const struct icmp6hdr *ic; + + /* Max length: 13 "PROTO=ICMPv6 " */ + nf_log_buf_add(m, "PROTO=ICMPv6 "); + + if (fragment) + break; + + /* Max length: 25 "INCOMPLETE [65535 bytes] " */ + ic = skb_header_pointer(skb, ptr, sizeof(_icmp6h), &_icmp6h); + if (!ic) { + nf_log_buf_add(m, "INCOMPLETE [%u bytes] ", + skb->len - ptr); + return; + } + + /* Max length: 18 "TYPE=255 CODE=255 " */ + nf_log_buf_add(m, "TYPE=%u CODE=%u ", + ic->icmp6_type, ic->icmp6_code); + + switch (ic->icmp6_type) { + case ICMPV6_ECHO_REQUEST: + case ICMPV6_ECHO_REPLY: + /* Max length: 19 "ID=65535 SEQ=65535 " */ + nf_log_buf_add(m, "ID=%u SEQ=%u ", + ntohs(ic->icmp6_identifier), + ntohs(ic->icmp6_sequence)); + break; + case ICMPV6_MGM_QUERY: + case ICMPV6_MGM_REPORT: + case 
ICMPV6_MGM_REDUCTION: + break; + + case ICMPV6_PARAMPROB: + /* Max length: 17 "POINTER=ffffffff " */ + nf_log_buf_add(m, "POINTER=%08x ", + ntohl(ic->icmp6_pointer)); + fallthrough; + case ICMPV6_DEST_UNREACH: + case ICMPV6_PKT_TOOBIG: + case ICMPV6_TIME_EXCEED: + /* Max length: 3+maxlen */ + if (recurse) { + nf_log_buf_add(m, "["); + dump_ipv6_packet(net, m, info, skb, + ptr + sizeof(_icmp6h), 0); + nf_log_buf_add(m, "] "); + } + + /* Max length: 10 "MTU=65535 " */ + if (ic->icmp6_type == ICMPV6_PKT_TOOBIG) { + nf_log_buf_add(m, "MTU=%u ", + ntohl(ic->icmp6_mtu)); + } + } + break; + } + /* Max length: 10 "PROTO=255 " */ + default: + nf_log_buf_add(m, "PROTO=%u ", currenthdr); + } + + /* Max length: 15 "UID=4294967295 " */ + if ((logflags & NF_LOG_UID) && recurse) + nf_log_dump_sk_uid_gid(net, m, skb->sk); + + /* Max length: 16 "MARK=0xFFFFFFFF " */ + if (recurse && skb->mark) + nf_log_buf_add(m, "MARK=0x%x ", skb->mark); +} + +static void dump_mac_header(struct nf_log_buf *m, + const struct nf_loginfo *info, + const struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + unsigned int logflags = 0; + + if (info->type == NF_LOG_TYPE_LOG) + logflags = info->u.log.logflags; + + if (!(logflags & NF_LOG_MACDECODE)) + goto fallback; + + switch (dev->type) { + case ARPHRD_ETHER: + nf_log_buf_add(m, "MACSRC=%pM MACDST=%pM ", + eth_hdr(skb)->h_source, eth_hdr(skb)->h_dest); + nf_log_dump_vlan(m, skb); + nf_log_buf_add(m, "MACPROTO=%04x ", + ntohs(eth_hdr(skb)->h_proto)); + return; + default: + break; + } + +fallback: + nf_log_buf_add(m, "MAC="); + if (dev->hard_header_len && + skb->mac_header != skb->network_header) { + const unsigned char *p = skb_mac_header(skb); + unsigned int i; + + if (dev->type == ARPHRD_SIT) { + p -= ETH_HLEN; + + if (p < skb->head) + p = NULL; + } + + if (p) { + nf_log_buf_add(m, "%02x", *p++); + for (i = 1; i < dev->hard_header_len; i++) + nf_log_buf_add(m, ":%02x", *p++); + } + + if (dev->type == ARPHRD_SIT) { + const struct iphdr *iph = + (struct iphdr *)skb_mac_header(skb); + + nf_log_buf_add(m, " TUNNEL=%pI4->%pI4", &iph->saddr, + &iph->daddr); + } + } + nf_log_buf_add(m, " "); +} + +static void nf_log_ip_packet(struct net *net, u_int8_t pf, + unsigned int hooknum, const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *prefix) +{ + struct nf_log_buf *m; + + if (!nf_log_allowed(net)) + return; + + m = nf_log_buf_open(); + + if (!loginfo) + loginfo = &default_loginfo; + + nf_log_dump_packet_common(m, pf, hooknum, skb, in, + out, loginfo, prefix, net); + + if (in) + dump_mac_header(m, loginfo, skb); + + dump_ipv4_packet(net, m, loginfo, skb, skb_network_offset(skb)); + + nf_log_buf_close(m); +} + +static struct nf_logger nf_ip_logger __read_mostly = { + .name = "nf_log_ipv4", + .type = NF_LOG_TYPE_LOG, + .logfn = nf_log_ip_packet, + .me = THIS_MODULE, +}; + +static void nf_log_ip6_packet(struct net *net, u_int8_t pf, + unsigned int hooknum, const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *prefix) +{ + struct nf_log_buf *m; + + if (!nf_log_allowed(net)) + return; + + m = nf_log_buf_open(); + + if (!loginfo) + loginfo = &default_loginfo; + + nf_log_dump_packet_common(m, pf, hooknum, skb, in, out, + loginfo, prefix, net); + + if (in) + dump_mac_header(m, loginfo, skb); + + dump_ipv6_packet(net, m, loginfo, skb, skb_network_offset(skb), 1); + + nf_log_buf_close(m); +} + +static struct nf_logger 
nf_ip6_logger __read_mostly = { + .name = "nf_log_ipv6", + .type = NF_LOG_TYPE_LOG, + .logfn = nf_log_ip6_packet, + .me = THIS_MODULE, +}; + +static void nf_log_unknown_packet(struct net *net, u_int8_t pf, + unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *prefix) +{ + struct nf_log_buf *m; + + if (!nf_log_allowed(net)) + return; + + m = nf_log_buf_open(); + + if (!loginfo) + loginfo = &default_loginfo; + + nf_log_dump_packet_common(m, pf, hooknum, skb, in, out, loginfo, + prefix, net); + + dump_mac_header(m, loginfo, skb); + + nf_log_buf_close(m); +} + +static void nf_log_netdev_packet(struct net *net, u_int8_t pf, + unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *loginfo, + const char *prefix) +{ + switch (skb->protocol) { + case htons(ETH_P_IP): + nf_log_ip_packet(net, pf, hooknum, skb, in, out, loginfo, prefix); + break; + case htons(ETH_P_IPV6): + nf_log_ip6_packet(net, pf, hooknum, skb, in, out, loginfo, prefix); + break; + case htons(ETH_P_ARP): + case htons(ETH_P_RARP): + nf_log_arp_packet(net, pf, hooknum, skb, in, out, loginfo, prefix); + break; + default: + nf_log_unknown_packet(net, pf, hooknum, skb, + in, out, loginfo, prefix); + break; + } +} + +static struct nf_logger nf_netdev_logger __read_mostly = { + .name = "nf_log_netdev", + .type = NF_LOG_TYPE_LOG, + .logfn = nf_log_netdev_packet, + .me = THIS_MODULE, +}; + +static struct nf_logger nf_bridge_logger __read_mostly = { + .name = "nf_log_bridge", + .type = NF_LOG_TYPE_LOG, + .logfn = nf_log_netdev_packet, + .me = THIS_MODULE, +}; + +static int __net_init nf_log_syslog_net_init(struct net *net) +{ + int ret = nf_log_set(net, NFPROTO_IPV4, &nf_ip_logger); + + if (ret) + return ret; + + ret = nf_log_set(net, NFPROTO_ARP, &nf_arp_logger); + if (ret) + goto err1; + + ret = nf_log_set(net, NFPROTO_IPV6, &nf_ip6_logger); + if (ret) + goto err2; + + ret = nf_log_set(net, NFPROTO_NETDEV, &nf_netdev_logger); + if (ret) + goto err3; + + ret = nf_log_set(net, NFPROTO_BRIDGE, &nf_bridge_logger); + if (ret) + goto err4; + return 0; +err4: + nf_log_unset(net, &nf_netdev_logger); +err3: + nf_log_unset(net, &nf_ip6_logger); +err2: + nf_log_unset(net, &nf_arp_logger); +err1: + nf_log_unset(net, &nf_ip_logger); + return ret; +} + +static void __net_exit nf_log_syslog_net_exit(struct net *net) +{ + nf_log_unset(net, &nf_ip_logger); + nf_log_unset(net, &nf_arp_logger); + nf_log_unset(net, &nf_ip6_logger); + nf_log_unset(net, &nf_netdev_logger); + nf_log_unset(net, &nf_bridge_logger); +} + +static struct pernet_operations nf_log_syslog_net_ops = { + .init = nf_log_syslog_net_init, + .exit = nf_log_syslog_net_exit, +}; + +static int __init nf_log_syslog_init(void) +{ + int ret; + + ret = register_pernet_subsys(&nf_log_syslog_net_ops); + if (ret < 0) + return ret; + + ret = nf_log_register(NFPROTO_IPV4, &nf_ip_logger); + if (ret < 0) + goto err1; + + ret = nf_log_register(NFPROTO_ARP, &nf_arp_logger); + if (ret < 0) + goto err2; + + ret = nf_log_register(NFPROTO_IPV6, &nf_ip6_logger); + if (ret < 0) + goto err3; + + ret = nf_log_register(NFPROTO_NETDEV, &nf_netdev_logger); + if (ret < 0) + goto err4; + + ret = nf_log_register(NFPROTO_BRIDGE, &nf_bridge_logger); + if (ret < 0) + goto err5; + + return 0; +err5: + nf_log_unregister(&nf_netdev_logger); +err4: + nf_log_unregister(&nf_ip6_logger); +err3: + nf_log_unregister(&nf_arp_logger); +err2: 
+ nf_log_unregister(&nf_ip_logger); +err1: + pr_err("failed to register logger\n"); + unregister_pernet_subsys(&nf_log_syslog_net_ops); + return ret; +} + +static void __exit nf_log_syslog_exit(void) +{ + unregister_pernet_subsys(&nf_log_syslog_net_ops); + nf_log_unregister(&nf_ip_logger); + nf_log_unregister(&nf_arp_logger); + nf_log_unregister(&nf_ip6_logger); + nf_log_unregister(&nf_netdev_logger); + nf_log_unregister(&nf_bridge_logger); +} + +module_init(nf_log_syslog_init); +module_exit(nf_log_syslog_exit); + +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("Netfilter syslog packet logging"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("nf_log_arp"); +MODULE_ALIAS("nf_log_bridge"); +MODULE_ALIAS("nf_log_ipv4"); +MODULE_ALIAS("nf_log_ipv6"); +MODULE_ALIAS("nf_log_netdev"); +MODULE_ALIAS_NF_LOGGER(AF_BRIDGE, 0); +MODULE_ALIAS_NF_LOGGER(AF_INET, 0); +MODULE_ALIAS_NF_LOGGER(3, 0); +MODULE_ALIAS_NF_LOGGER(5, 0); /* NFPROTO_NETDEV */ +MODULE_ALIAS_NF_LOGGER(AF_INET6, 0); diff --git a/net/netfilter/nf_nat_amanda.c b/net/netfilter/nf_nat_amanda.c new file mode 100644 index 000000000..98deef6cd --- /dev/null +++ b/net/netfilter/nf_nat_amanda.c @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Amanda extension for TCP NAT alteration. + * (C) 2002 by Brian J. Murrell + * based on a copy of HW's ip_nat_irc.c as well as other modules + * (C) 2006-2012 Patrick McHardy + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#define NAT_HELPER_NAME "amanda" + +MODULE_AUTHOR("Brian J. Murrell "); +MODULE_DESCRIPTION("Amanda NAT helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NF_NAT_HELPER(NAT_HELPER_NAME); + +static struct nf_conntrack_nat_helper nat_helper_amanda = + NF_CT_NAT_HELPER_INIT(NAT_HELPER_NAME); + +static unsigned int help(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp) +{ + char buffer[sizeof("65535")]; + u_int16_t port; + + /* Connection comes from client. */ + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->dir = IP_CT_DIR_ORIGINAL; + + /* When you see the packet, we need to NAT it the same as the + * this one (ie. same IP: it will be TCP and master is UDP). */ + exp->expectfn = nf_nat_follow_master; + + /* Try to get same port: if not, try to change it. 
*/ + port = nf_nat_exp_find_port(exp, ntohs(exp->saved_proto.tcp.port)); + if (port == 0) { + nf_ct_helper_log(skb, exp->master, "all ports in use"); + return NF_DROP; + } + + sprintf(buffer, "%u", port); + if (!nf_nat_mangle_udp_packet(skb, exp->master, ctinfo, + protoff, matchoff, matchlen, + buffer, strlen(buffer))) { + nf_ct_helper_log(skb, exp->master, "cannot mangle packet"); + nf_ct_unexpect_related(exp); + return NF_DROP; + } + return NF_ACCEPT; +} + +static void __exit nf_nat_amanda_fini(void) +{ + nf_nat_helper_unregister(&nat_helper_amanda); + RCU_INIT_POINTER(nf_nat_amanda_hook, NULL); + synchronize_rcu(); +} + +static int __init nf_nat_amanda_init(void) +{ + BUG_ON(nf_nat_amanda_hook != NULL); + nf_nat_helper_register(&nat_helper_amanda); + RCU_INIT_POINTER(nf_nat_amanda_hook, help); + return 0; +} + +module_init(nf_nat_amanda_init); +module_exit(nf_nat_amanda_fini); diff --git a/net/netfilter/nf_nat_bpf.c b/net/netfilter/nf_nat_bpf.c new file mode 100644 index 000000000..0fa5a0bbb --- /dev/null +++ b/net/netfilter/nf_nat_bpf.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Unstable NAT Helpers for XDP and TC-BPF hook + * + * These are called from the XDP and SCHED_CLS BPF programs. Note that it is + * allowed to break compatibility for these functions since the interface they + * are exposed through to BPF programs is explicitly unstable. + */ + +#include +#include +#include +#include +#include + +__diag_push(); +__diag_ignore_all("-Wmissing-prototypes", + "Global functions as their definitions will be in nf_nat BTF"); + +/* bpf_ct_set_nat_info - Set source or destination nat address + * + * Set source or destination nat address of the newly allocated + * nf_conn before insertion. This must be invoked for referenced + * PTR_TO_BTF_ID to nf_conn___init. + * + * Parameters: + * @nfct - Pointer to referenced nf_conn object, obtained using + * bpf_xdp_ct_alloc or bpf_skb_ct_alloc. + * @addr - Nat source/destination address + * @port - Nat source/destination port. Non-positive values are + * interpreted as select a random port. + * @manip - NF_NAT_MANIP_SRC or NF_NAT_MANIP_DST + */ +int bpf_ct_set_nat_info(struct nf_conn___init *nfct, + union nf_inet_addr *addr, int port, + enum nf_nat_manip_type manip) +{ + struct nf_conn *ct = (struct nf_conn *)nfct; + u16 proto = nf_ct_l3num(ct); + struct nf_nat_range2 range; + + if (proto != NFPROTO_IPV4 && proto != NFPROTO_IPV6) + return -EINVAL; + + memset(&range, 0, sizeof(struct nf_nat_range2)); + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = *addr; + range.max_addr = range.min_addr; + if (port > 0) { + range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + range.min_proto.all = cpu_to_be16(port); + range.max_proto.all = range.min_proto.all; + } + + return nf_nat_setup_info(ct, &range, manip) == NF_DROP ? 
-ENOMEM : 0; +} + +__diag_pop() + +BTF_SET8_START(nf_nat_kfunc_set) +BTF_ID_FLAGS(func, bpf_ct_set_nat_info, KF_TRUSTED_ARGS) +BTF_SET8_END(nf_nat_kfunc_set) + +static const struct btf_kfunc_id_set nf_bpf_nat_kfunc_set = { + .owner = THIS_MODULE, + .set = &nf_nat_kfunc_set, +}; + +int register_nf_nat_bpf(void) +{ + int ret; + + ret = register_btf_kfunc_id_set(BPF_PROG_TYPE_XDP, + &nf_bpf_nat_kfunc_set); + if (ret) + return ret; + + return register_btf_kfunc_id_set(BPF_PROG_TYPE_SCHED_CLS, + &nf_bpf_nat_kfunc_set); +} diff --git a/net/netfilter/nf_nat_core.c b/net/netfilter/nf_nat_core.c new file mode 100644 index 000000000..e29e4ccb5 --- /dev/null +++ b/net/netfilter/nf_nat_core.c @@ -0,0 +1,1184 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2011 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +static spinlock_t nf_nat_locks[CONNTRACK_LOCKS]; + +static DEFINE_MUTEX(nf_nat_proto_mutex); +static unsigned int nat_net_id __read_mostly; + +static struct hlist_head *nf_nat_bysource __read_mostly; +static unsigned int nf_nat_htable_size __read_mostly; +static siphash_aligned_key_t nf_nat_hash_rnd; + +struct nf_nat_lookup_hook_priv { + struct nf_hook_entries __rcu *entries; + + struct rcu_head rcu_head; +}; + +struct nf_nat_hooks_net { + struct nf_hook_ops *nat_hook_ops; + unsigned int users; +}; + +struct nat_net { + struct nf_nat_hooks_net nat_proto_net[NFPROTO_NUMPROTO]; +}; + +#ifdef CONFIG_XFRM +static void nf_nat_ipv4_decode_session(struct sk_buff *skb, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, + unsigned long statusbit, + struct flowi *fl) +{ + const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; + struct flowi4 *fl4 = &fl->u.ip4; + + if (ct->status & statusbit) { + fl4->daddr = t->dst.u3.ip; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl4->fl4_dport = t->dst.u.all; + } + + statusbit ^= IPS_NAT_MASK; + + if (ct->status & statusbit) { + fl4->saddr = t->src.u3.ip; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl4->fl4_sport = t->src.u.all; + } +} + +static void nf_nat_ipv6_decode_session(struct sk_buff *skb, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, + unsigned long statusbit, + struct flowi *fl) +{ +#if IS_ENABLED(CONFIG_IPV6) + const struct nf_conntrack_tuple *t = &ct->tuplehash[dir].tuple; + struct flowi6 *fl6 = &fl->u.ip6; + + if (ct->status & statusbit) { + fl6->daddr = t->dst.u3.in6; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl6->fl6_dport = t->dst.u.all; + } + + statusbit ^= IPS_NAT_MASK; + + if (ct->status & statusbit) { + fl6->saddr = t->src.u3.in6; + if (t->dst.protonum == IPPROTO_TCP || + t->dst.protonum == IPPROTO_UDP || + t->dst.protonum == IPPROTO_UDPLITE || + t->dst.protonum == IPPROTO_DCCP || + t->dst.protonum == IPPROTO_SCTP) + fl6->fl6_sport = t->src.u.all; + } +#endif +} + +static void 
__nf_nat_decode_session(struct sk_buff *skb, struct flowi *fl) +{ + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + enum ip_conntrack_dir dir; + unsigned long statusbit; + u8 family; + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return; + + family = nf_ct_l3num(ct); + dir = CTINFO2DIR(ctinfo); + if (dir == IP_CT_DIR_ORIGINAL) + statusbit = IPS_DST_NAT; + else + statusbit = IPS_SRC_NAT; + + switch (family) { + case NFPROTO_IPV4: + nf_nat_ipv4_decode_session(skb, ct, dir, statusbit, fl); + return; + case NFPROTO_IPV6: + nf_nat_ipv6_decode_session(skb, ct, dir, statusbit, fl); + return; + } +} +#endif /* CONFIG_XFRM */ + +/* We keep an extra hash for each conntrack, for fast searching. */ +static unsigned int +hash_by_src(const struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple) +{ + unsigned int hash; + struct { + struct nf_conntrack_man src; + u32 net_mix; + u32 protonum; + u32 zone; + } __aligned(SIPHASH_ALIGNMENT) combined; + + get_random_once(&nf_nat_hash_rnd, sizeof(nf_nat_hash_rnd)); + + memset(&combined, 0, sizeof(combined)); + + /* Original src, to ensure we map it consistently if poss. */ + combined.src = tuple->src; + combined.net_mix = net_hash_mix(net); + combined.protonum = tuple->dst.protonum; + + /* Zone ID can be used provided its valid for both directions */ + if (zone->dir == NF_CT_DEFAULT_ZONE_DIR) + combined.zone = zone->id; + + hash = siphash(&combined, sizeof(combined), &nf_nat_hash_rnd); + + return reciprocal_scale(hash, nf_nat_htable_size); +} + +/* Is this tuple already taken? (not by us) */ +static int +nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple, + const struct nf_conn *ignored_conntrack) +{ + /* Conntrack tracking doesn't keep track of outgoing tuples; only + * incoming ones. NAT means they don't have a fixed mapping, + * so we invert the tuple and look for the incoming reply. + * + * We could keep a separate hash if this proves too slow. + */ + struct nf_conntrack_tuple reply; + + nf_ct_invert_tuple(&reply, tuple); + return nf_conntrack_tuple_taken(&reply, ignored_conntrack); +} + +static bool nf_nat_inet_in_range(const struct nf_conntrack_tuple *t, + const struct nf_nat_range2 *range) +{ + if (t->src.l3num == NFPROTO_IPV4) + return ntohl(t->src.u3.ip) >= ntohl(range->min_addr.ip) && + ntohl(t->src.u3.ip) <= ntohl(range->max_addr.ip); + + return ipv6_addr_cmp(&t->src.u3.in6, &range->min_addr.in6) >= 0 && + ipv6_addr_cmp(&t->src.u3.in6, &range->max_addr.in6) <= 0; +} + +/* Is the manipable part of the tuple between min and max incl? */ +static bool l4proto_in_range(const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype, + const union nf_conntrack_man_proto *min, + const union nf_conntrack_man_proto *max) +{ + __be16 port; + + switch (tuple->dst.protonum) { + case IPPROTO_ICMP: + case IPPROTO_ICMPV6: + return ntohs(tuple->src.u.icmp.id) >= ntohs(min->icmp.id) && + ntohs(tuple->src.u.icmp.id) <= ntohs(max->icmp.id); + case IPPROTO_GRE: /* all fall though */ + case IPPROTO_TCP: + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + case IPPROTO_DCCP: + case IPPROTO_SCTP: + if (maniptype == NF_NAT_MANIP_SRC) + port = tuple->src.u.all; + else + port = tuple->dst.u.all; + + return ntohs(port) >= ntohs(min->all) && + ntohs(port) <= ntohs(max->all); + default: + return true; + } +} + +/* If we source map this tuple so reply looks like reply_tuple, will + * that meet the constraints of range. 
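For reference, the bucket index produced by hash_by_src() comes from reciprocal_scale(), which maps a 32-bit hash onto [0, htable_size) with a multiply-and-shift instead of a modulo. A minimal standalone C sketch of just that mapping; the hash value and table size below are arbitrary stand-ins for the siphash result and nf_nat_htable_size, and the helper is a local re-implementation, not the kernel symbol itself:

#include <stdint.h>
#include <stdio.h>

/* Same arithmetic as the kernel's reciprocal_scale(): map a 32-bit
 * value onto [0, ep_ro) without a division. */
static uint32_t reciprocal_scale(uint32_t val, uint32_t ep_ro)
{
        return (uint32_t)(((uint64_t)val * ep_ro) >> 32);
}

int main(void)
{
        uint32_t htable_size = 16384;   /* stand-in for nf_nat_htable_size */
        uint32_t hash = 0x9e3779b9;     /* stand-in for the siphash result */

        printf("bucket = %u\n", reciprocal_scale(hash, htable_size));
        return 0;
}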
+ */ +static int in_range(const struct nf_conntrack_tuple *tuple, + const struct nf_nat_range2 *range) +{ + /* If we are supposed to map IPs, then we must be in the + * range specified, otherwise let this drag us onto a new src IP. + */ + if (range->flags & NF_NAT_RANGE_MAP_IPS && + !nf_nat_inet_in_range(tuple, range)) + return 0; + + if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) + return 1; + + return l4proto_in_range(tuple, NF_NAT_MANIP_SRC, + &range->min_proto, &range->max_proto); +} + +static inline int +same_src(const struct nf_conn *ct, + const struct nf_conntrack_tuple *tuple) +{ + const struct nf_conntrack_tuple *t; + + t = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + return (t->dst.protonum == tuple->dst.protonum && + nf_inet_addr_cmp(&t->src.u3, &tuple->src.u3) && + t->src.u.all == tuple->src.u.all); +} + +/* Only called for SRC manip */ +static int +find_appropriate_src(struct net *net, + const struct nf_conntrack_zone *zone, + const struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple *result, + const struct nf_nat_range2 *range) +{ + unsigned int h = hash_by_src(net, zone, tuple); + const struct nf_conn *ct; + + hlist_for_each_entry_rcu(ct, &nf_nat_bysource[h], nat_bysource) { + if (same_src(ct, tuple) && + net_eq(net, nf_ct_net(ct)) && + nf_ct_zone_equal(ct, zone, IP_CT_DIR_ORIGINAL)) { + /* Copy source part from reply tuple. */ + nf_ct_invert_tuple(result, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple); + result->dst = tuple->dst; + + if (in_range(result, range)) + return 1; + } + } + return 0; +} + +/* For [FUTURE] fragmentation handling, we want the least-used + * src-ip/dst-ip/proto triple. Fairness doesn't come into it. Thus + * if the range specifies 1.2.3.4 ports 10000-10005 and 1.2.3.5 ports + * 1-65535, we don't do pro-rata allocation based on ports; we choose + * the ip with the lowest src-ip/dst-ip/proto usage. + */ +static void +find_best_ips_proto(const struct nf_conntrack_zone *zone, + struct nf_conntrack_tuple *tuple, + const struct nf_nat_range2 *range, + const struct nf_conn *ct, + enum nf_nat_manip_type maniptype) +{ + union nf_inet_addr *var_ipp; + unsigned int i, max; + /* Host order */ + u32 minip, maxip, j, dist; + bool full_range; + + /* No IP mapping? Do nothing. */ + if (!(range->flags & NF_NAT_RANGE_MAP_IPS)) + return; + + if (maniptype == NF_NAT_MANIP_SRC) + var_ipp = &tuple->src.u3; + else + var_ipp = &tuple->dst.u3; + + /* Fast path: only one choice. */ + if (nf_inet_addr_cmp(&range->min_addr, &range->max_addr)) { + *var_ipp = range->min_addr; + return; + } + + if (nf_ct_l3num(ct) == NFPROTO_IPV4) + max = sizeof(var_ipp->ip) / sizeof(u32) - 1; + else + max = sizeof(var_ipp->ip6) / sizeof(u32) - 1; + + /* Hashing source and destination IPs gives a fairly even + * spread in practice (if there are a small number of IPs + * involved, there usually aren't that many connections + * anyway). The consistency means that servers see the same + * client coming from the same IP (some Internet Banking sites + * like this), even across reboots. + */ + j = jhash2((u32 *)&tuple->src.u3, sizeof(tuple->src.u3) / sizeof(u32), + range->flags & NF_NAT_RANGE_PERSISTENT ? + 0 : (__force u32)tuple->dst.u3.all[max] ^ zone->id); + + full_range = false; + for (i = 0; i <= max; i++) { + /* If first bytes of the address are at the maximum, use the + * distance. Otherwise use the full range. 
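A rough standalone illustration of the consistency property described in the comment above: hashing the client's source address and scaling the result into the configured SNAT range always lands the same client on the same address. The hash below is a toy stand-in for the kernel's jhash2()-based calculation, and the address ranges are arbitrary examples:

#include <stdint.h>
#include <stdio.h>

static uint32_t scale(uint32_t val, uint32_t ep_ro)
{
        return (uint32_t)(((uint64_t)val * ep_ro) >> 32);
}

/* Toy mixer standing in for jhash2() over the full source address,
 * destination address and zone. */
static uint32_t toy_hash(uint32_t v)
{
        v ^= v >> 16;
        v *= 0x7feb352d;
        v ^= v >> 15;
        v *= 0x846ca68b;
        v ^= v >> 16;
        return v;
}

int main(void)
{
        uint32_t minip = 0xc0000200;    /* 192.0.2.0,   host order */
        uint32_t maxip = 0xc00002ff;    /* 192.0.2.255, host order */
        uint32_t client = 0x0a000001;   /* 10.0.0.1,    host order */
        uint32_t dist = maxip - minip + 1;

        /* Re-running with the same client yields the same offset. */
        printf("SNAT address offset: %u\n", scale(toy_hash(client), dist));
        return 0;
}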
+ */ + if (!full_range) { + minip = ntohl((__force __be32)range->min_addr.all[i]); + maxip = ntohl((__force __be32)range->max_addr.all[i]); + dist = maxip - minip + 1; + } else { + minip = 0; + dist = ~0; + } + + var_ipp->all[i] = (__force __u32) + htonl(minip + reciprocal_scale(j, dist)); + if (var_ipp->all[i] != range->max_addr.all[i]) + full_range = true; + + if (!(range->flags & NF_NAT_RANGE_PERSISTENT)) + j ^= (__force u32)tuple->dst.u3.all[i]; + } +} + +/* Alter the per-proto part of the tuple (depending on maniptype), to + * give a unique tuple in the given range if possible. + * + * Per-protocol part of tuple is initialized to the incoming packet. + */ +static void nf_nat_l4proto_unique_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_nat_range2 *range, + enum nf_nat_manip_type maniptype, + const struct nf_conn *ct) +{ + unsigned int range_size, min, max, i, attempts; + __be16 *keyptr; + u16 off; + static const unsigned int max_attempts = 128; + + switch (tuple->dst.protonum) { + case IPPROTO_ICMP: + case IPPROTO_ICMPV6: + /* id is same for either direction... */ + keyptr = &tuple->src.u.icmp.id; + if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) { + min = 0; + range_size = 65536; + } else { + min = ntohs(range->min_proto.icmp.id); + range_size = ntohs(range->max_proto.icmp.id) - + ntohs(range->min_proto.icmp.id) + 1; + } + goto find_free_id; +#if IS_ENABLED(CONFIG_NF_CT_PROTO_GRE) + case IPPROTO_GRE: + /* If there is no master conntrack we are not PPTP, + do not change tuples */ + if (!ct->master) + return; + + if (maniptype == NF_NAT_MANIP_SRC) + keyptr = &tuple->src.u.gre.key; + else + keyptr = &tuple->dst.u.gre.key; + + if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) { + min = 1; + range_size = 65535; + } else { + min = ntohs(range->min_proto.gre.key); + range_size = ntohs(range->max_proto.gre.key) - min + 1; + } + goto find_free_id; +#endif + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + case IPPROTO_TCP: + case IPPROTO_SCTP: + case IPPROTO_DCCP: + if (maniptype == NF_NAT_MANIP_SRC) + keyptr = &tuple->src.u.all; + else + keyptr = &tuple->dst.u.all; + + break; + default: + return; + } + + /* If no range specified... */ + if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED)) { + /* If it's dst rewrite, can't change port */ + if (maniptype == NF_NAT_MANIP_DST) + return; + + if (ntohs(*keyptr) < 1024) { + /* Loose convention: >> 512 is credential passing */ + if (ntohs(*keyptr) < 512) { + min = 1; + range_size = 511 - min + 1; + } else { + min = 600; + range_size = 1023 - min + 1; + } + } else { + min = 1024; + range_size = 65535 - 1024 + 1; + } + } else { + min = ntohs(range->min_proto.all); + max = ntohs(range->max_proto.all); + if (unlikely(max < min)) + swap(max, min); + range_size = max - min + 1; + } + +find_free_id: + if (range->flags & NF_NAT_RANGE_PROTO_OFFSET) + off = (ntohs(*keyptr) - ntohs(range->base_proto.all)); + else + off = get_random_u16(); + + attempts = range_size; + if (attempts > max_attempts) + attempts = max_attempts; + + /* We are in softirq; doing a search of the entire range risks + * soft lockup when all tuples are already used. + * + * If we can't find any free port from first offset, pick a new + * one and try again, with ever smaller search window. 
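A condensed standalone model of the search strategy this comment describes: probe at most a bounded number of ports starting from a random offset, then shrink the window and re-randomize rather than walking the whole range in softirq context. tuple_in_use() is a stand-in for nf_nat_used_tuple(), and the port numbers are arbitrary:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Stand-in for nf_nat_used_tuple(): pretend the first 200 ports of the
 * range are already taken. */
static bool tuple_in_use(uint16_t port, uint16_t min)
{
        return port < min + 200;
}

static uint16_t pick_port(uint16_t min, uint32_t range_size)
{
        uint32_t attempts = range_size > 128 ? 128 : range_size;
        uint32_t off = rand() & 0xffff;

        for (;;) {
                uint32_t i;

                for (i = 0; i < attempts; i++, off++) {
                        uint16_t port = min + off % range_size;

                        if (!tuple_in_use(port, min))
                                return port;
                }
                if (attempts >= range_size || attempts < 16)
                        return 0;       /* give up; a later collision check
                                         * resolves the duplicate */
                attempts /= 2;
                off = rand() & 0xffff;
        }
}

int main(void)
{
        printf("picked port %u\n", pick_port(1024, 65535 - 1024 + 1));
        return 0;
}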
+ */ +another_round: + for (i = 0; i < attempts; i++, off++) { + *keyptr = htons(min + off % range_size); + if (!nf_nat_used_tuple(tuple, ct)) + return; + } + + if (attempts >= range_size || attempts < 16) + return; + attempts /= 2; + off = get_random_u16(); + goto another_round; +} + +/* Manipulate the tuple into the range given. For NF_INET_POST_ROUTING, + * we change the source to map into the range. For NF_INET_PRE_ROUTING + * and NF_INET_LOCAL_OUT, we change the destination to map into the + * range. It might not be possible to get a unique tuple, but we try. + * At worst (or if we race), we will end up with a final duplicate in + * __nf_conntrack_confirm and drop the packet. */ +static void +get_unique_tuple(struct nf_conntrack_tuple *tuple, + const struct nf_conntrack_tuple *orig_tuple, + const struct nf_nat_range2 *range, + struct nf_conn *ct, + enum nf_nat_manip_type maniptype) +{ + const struct nf_conntrack_zone *zone; + struct net *net = nf_ct_net(ct); + + zone = nf_ct_zone(ct); + + /* 1) If this srcip/proto/src-proto-part is currently mapped, + * and that same mapping gives a unique tuple within the given + * range, use that. + * + * This is only required for source (ie. NAT/masq) mappings. + * So far, we don't do local source mappings, so multiple + * manips not an issue. + */ + if (maniptype == NF_NAT_MANIP_SRC && + !(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) { + /* try the original tuple first */ + if (in_range(orig_tuple, range)) { + if (!nf_nat_used_tuple(orig_tuple, ct)) { + *tuple = *orig_tuple; + return; + } + } else if (find_appropriate_src(net, zone, + orig_tuple, tuple, range)) { + pr_debug("get_unique_tuple: Found current src map\n"); + if (!nf_nat_used_tuple(tuple, ct)) + return; + } + } + + /* 2) Select the least-used IP/proto combination in the given range */ + *tuple = *orig_tuple; + find_best_ips_proto(zone, tuple, range, ct, maniptype); + + /* 3) The per-protocol part of the manip is made to map into + * the range to make a unique tuple. + */ + + /* Only bother mapping if it's not already in range and unique */ + if (!(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) { + if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) { + if (!(range->flags & NF_NAT_RANGE_PROTO_OFFSET) && + l4proto_in_range(tuple, maniptype, + &range->min_proto, + &range->max_proto) && + (range->min_proto.all == range->max_proto.all || + !nf_nat_used_tuple(tuple, ct))) + return; + } else if (!nf_nat_used_tuple(tuple, ct)) { + return; + } + } + + /* Last chance: get protocol to try to obtain unique tuple. */ + nf_nat_l4proto_unique_tuple(tuple, range, maniptype, ct); +} + +struct nf_conn_nat *nf_ct_nat_ext_add(struct nf_conn *ct) +{ + struct nf_conn_nat *nat = nfct_nat(ct); + if (nat) + return nat; + + if (!nf_ct_is_confirmed(ct)) + nat = nf_ct_ext_add(ct, NF_CT_EXT_NAT, GFP_ATOMIC); + + return nat; +} +EXPORT_SYMBOL_GPL(nf_ct_nat_ext_add); + +unsigned int +nf_nat_setup_info(struct nf_conn *ct, + const struct nf_nat_range2 *range, + enum nf_nat_manip_type maniptype) +{ + struct net *net = nf_ct_net(ct); + struct nf_conntrack_tuple curr_tuple, new_tuple; + + /* Can't setup nat info for confirmed ct. */ + if (nf_ct_is_confirmed(ct)) + return NF_ACCEPT; + + WARN_ON(maniptype != NF_NAT_MANIP_SRC && + maniptype != NF_NAT_MANIP_DST); + + if (WARN_ON(nf_nat_initialized(ct, maniptype))) + return NF_DROP; + + /* What we've got will look like inverse of reply. 
Normally + * this is what is in the conntrack, except for prior + * manipulations (future optimization: if num_manips == 0, + * orig_tp = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple) + */ + nf_ct_invert_tuple(&curr_tuple, + &ct->tuplehash[IP_CT_DIR_REPLY].tuple); + + get_unique_tuple(&new_tuple, &curr_tuple, range, ct, maniptype); + + if (!nf_ct_tuple_equal(&new_tuple, &curr_tuple)) { + struct nf_conntrack_tuple reply; + + /* Alter conntrack table so will recognize replies. */ + nf_ct_invert_tuple(&reply, &new_tuple); + nf_conntrack_alter_reply(ct, &reply); + + /* Non-atomic: we own this at the moment. */ + if (maniptype == NF_NAT_MANIP_SRC) + ct->status |= IPS_SRC_NAT; + else + ct->status |= IPS_DST_NAT; + + if (nfct_help(ct) && !nfct_seqadj(ct)) + if (!nfct_seqadj_ext_add(ct)) + return NF_DROP; + } + + if (maniptype == NF_NAT_MANIP_SRC) { + unsigned int srchash; + spinlock_t *lock; + + srchash = hash_by_src(net, nf_ct_zone(ct), + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); + lock = &nf_nat_locks[srchash % CONNTRACK_LOCKS]; + spin_lock_bh(lock); + hlist_add_head_rcu(&ct->nat_bysource, + &nf_nat_bysource[srchash]); + spin_unlock_bh(lock); + } + + /* It's done. */ + if (maniptype == NF_NAT_MANIP_DST) + ct->status |= IPS_DST_NAT_DONE; + else + ct->status |= IPS_SRC_NAT_DONE; + + return NF_ACCEPT; +} +EXPORT_SYMBOL(nf_nat_setup_info); + +static unsigned int +__nf_nat_alloc_null_binding(struct nf_conn *ct, enum nf_nat_manip_type manip) +{ + /* Force range to this IP; let proto decide mapping for + * per-proto parts (hence not IP_NAT_RANGE_PROTO_SPECIFIED). + * Use reply in case it's already been mangled (eg local packet). + */ + union nf_inet_addr ip = + (manip == NF_NAT_MANIP_SRC ? + ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3 : + ct->tuplehash[IP_CT_DIR_REPLY].tuple.src.u3); + struct nf_nat_range2 range = { + .flags = NF_NAT_RANGE_MAP_IPS, + .min_addr = ip, + .max_addr = ip, + }; + return nf_nat_setup_info(ct, &range, manip); +} + +unsigned int +nf_nat_alloc_null_binding(struct nf_conn *ct, unsigned int hooknum) +{ + return __nf_nat_alloc_null_binding(ct, HOOK2MANIP(hooknum)); +} +EXPORT_SYMBOL_GPL(nf_nat_alloc_null_binding); + +/* Do packet manipulations according to nf_nat_setup_info. */ +unsigned int nf_nat_packet(struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int hooknum, + struct sk_buff *skb) +{ + enum nf_nat_manip_type mtype = HOOK2MANIP(hooknum); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + unsigned int verdict = NF_ACCEPT; + unsigned long statusbit; + + if (mtype == NF_NAT_MANIP_SRC) + statusbit = IPS_SRC_NAT; + else + statusbit = IPS_DST_NAT; + + /* Invert if this is reply dir. */ + if (dir == IP_CT_DIR_REPLY) + statusbit ^= IPS_NAT_MASK; + + /* Non-atomic: these bits don't change. */ + if (ct->status & statusbit) + verdict = nf_nat_manip_pkt(skb, ct, mtype, dir); + + return verdict; +} +EXPORT_SYMBOL_GPL(nf_nat_packet); + +static bool in_vrf_postrouting(const struct nf_hook_state *state) +{ +#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV) + if (state->hook == NF_INET_POST_ROUTING && + netif_is_l3_master(state->out)) + return true; +#endif + return false; +} + +unsigned int +nf_nat_inet_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + struct nf_conn_nat *nat; + /* maniptype == SRC for postrouting. */ + enum nf_nat_manip_type maniptype = HOOK2MANIP(state->hook); + + ct = nf_ct_get(skb, &ctinfo); + /* Can't track? It's not due to stress, or conntrack would + * have dropped it. 
Hence it's the user's responsibilty to + * packet filter it out, or implement conntrack/NAT for that + * protocol. 8) --RR + */ + if (!ct || in_vrf_postrouting(state)) + return NF_ACCEPT; + + nat = nfct_nat(ct); + + switch (ctinfo) { + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + /* Only ICMPs can be IP_CT_IS_REPLY. Fallthrough */ + case IP_CT_NEW: + /* Seen it before? This can happen for loopback, retrans, + * or local packets. + */ + if (!nf_nat_initialized(ct, maniptype)) { + struct nf_nat_lookup_hook_priv *lpriv = priv; + struct nf_hook_entries *e = rcu_dereference(lpriv->entries); + unsigned int ret; + int i; + + if (!e) + goto null_bind; + + for (i = 0; i < e->num_hook_entries; i++) { + ret = e->hooks[i].hook(e->hooks[i].priv, skb, + state); + if (ret != NF_ACCEPT) + return ret; + if (nf_nat_initialized(ct, maniptype)) + goto do_nat; + } +null_bind: + ret = nf_nat_alloc_null_binding(ct, state->hook); + if (ret != NF_ACCEPT) + return ret; + } else { + pr_debug("Already setup manip %s for ct %p (status bits 0x%lx)\n", + maniptype == NF_NAT_MANIP_SRC ? "SRC" : "DST", + ct, ct->status); + if (nf_nat_oif_changed(state->hook, ctinfo, nat, + state->out)) + goto oif_changed; + } + break; + default: + /* ESTABLISHED */ + WARN_ON(ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY); + if (nf_nat_oif_changed(state->hook, ctinfo, nat, state->out)) + goto oif_changed; + } +do_nat: + return nf_nat_packet(ct, ctinfo, state->hook, skb); + +oif_changed: + nf_ct_kill_acct(ct, ctinfo, skb); + return NF_DROP; +} +EXPORT_SYMBOL_GPL(nf_nat_inet_fn); + +struct nf_nat_proto_clean { + u8 l3proto; + u8 l4proto; +}; + +/* kill conntracks with affected NAT section */ +static int nf_nat_proto_remove(struct nf_conn *i, void *data) +{ + const struct nf_nat_proto_clean *clean = data; + + if ((clean->l3proto && nf_ct_l3num(i) != clean->l3proto) || + (clean->l4proto && nf_ct_protonum(i) != clean->l4proto)) + return 0; + + return i->status & IPS_NAT_MASK ? 1 : 0; +} + +static void nf_nat_cleanup_conntrack(struct nf_conn *ct) +{ + unsigned int h; + + h = hash_by_src(nf_ct_net(ct), nf_ct_zone(ct), &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); + spin_lock_bh(&nf_nat_locks[h % CONNTRACK_LOCKS]); + hlist_del_rcu(&ct->nat_bysource); + spin_unlock_bh(&nf_nat_locks[h % CONNTRACK_LOCKS]); +} + +static int nf_nat_proto_clean(struct nf_conn *ct, void *data) +{ + if (nf_nat_proto_remove(ct, data)) + return 1; + + /* This module is being removed and conntrack has nat null binding. + * Remove it from bysource hash, as the table will be freed soon. + * + * Else, when the conntrack is destoyed, nf_nat_cleanup_conntrack() + * will delete entry from already-freed table. + */ + if (test_and_clear_bit(IPS_SRC_NAT_DONE_BIT, &ct->status)) + nf_nat_cleanup_conntrack(ct); + + /* don't delete conntrack. Although that would make things a lot + * simpler, we'd end up flushing all conntracks on nat rmmod. 
+ */ + return 0; +} + +#if IS_ENABLED(CONFIG_NF_CT_NETLINK) + +#include +#include + +static const struct nla_policy protonat_nla_policy[CTA_PROTONAT_MAX+1] = { + [CTA_PROTONAT_PORT_MIN] = { .type = NLA_U16 }, + [CTA_PROTONAT_PORT_MAX] = { .type = NLA_U16 }, +}; + +static int nf_nat_l4proto_nlattr_to_range(struct nlattr *tb[], + struct nf_nat_range2 *range) +{ + if (tb[CTA_PROTONAT_PORT_MIN]) { + range->min_proto.all = nla_get_be16(tb[CTA_PROTONAT_PORT_MIN]); + range->max_proto.all = range->min_proto.all; + range->flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + if (tb[CTA_PROTONAT_PORT_MAX]) { + range->max_proto.all = nla_get_be16(tb[CTA_PROTONAT_PORT_MAX]); + range->flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + return 0; +} + +static int nfnetlink_parse_nat_proto(struct nlattr *attr, + const struct nf_conn *ct, + struct nf_nat_range2 *range) +{ + struct nlattr *tb[CTA_PROTONAT_MAX+1]; + int err; + + err = nla_parse_nested_deprecated(tb, CTA_PROTONAT_MAX, attr, + protonat_nla_policy, NULL); + if (err < 0) + return err; + + return nf_nat_l4proto_nlattr_to_range(tb, range); +} + +static const struct nla_policy nat_nla_policy[CTA_NAT_MAX+1] = { + [CTA_NAT_V4_MINIP] = { .type = NLA_U32 }, + [CTA_NAT_V4_MAXIP] = { .type = NLA_U32 }, + [CTA_NAT_V6_MINIP] = { .len = sizeof(struct in6_addr) }, + [CTA_NAT_V6_MAXIP] = { .len = sizeof(struct in6_addr) }, + [CTA_NAT_PROTO] = { .type = NLA_NESTED }, +}; + +static int nf_nat_ipv4_nlattr_to_range(struct nlattr *tb[], + struct nf_nat_range2 *range) +{ + if (tb[CTA_NAT_V4_MINIP]) { + range->min_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MINIP]); + range->flags |= NF_NAT_RANGE_MAP_IPS; + } + + if (tb[CTA_NAT_V4_MAXIP]) + range->max_addr.ip = nla_get_be32(tb[CTA_NAT_V4_MAXIP]); + else + range->max_addr.ip = range->min_addr.ip; + + return 0; +} + +static int nf_nat_ipv6_nlattr_to_range(struct nlattr *tb[], + struct nf_nat_range2 *range) +{ + if (tb[CTA_NAT_V6_MINIP]) { + nla_memcpy(&range->min_addr.ip6, tb[CTA_NAT_V6_MINIP], + sizeof(struct in6_addr)); + range->flags |= NF_NAT_RANGE_MAP_IPS; + } + + if (tb[CTA_NAT_V6_MAXIP]) + nla_memcpy(&range->max_addr.ip6, tb[CTA_NAT_V6_MAXIP], + sizeof(struct in6_addr)); + else + range->max_addr = range->min_addr; + + return 0; +} + +static int +nfnetlink_parse_nat(const struct nlattr *nat, + const struct nf_conn *ct, struct nf_nat_range2 *range) +{ + struct nlattr *tb[CTA_NAT_MAX+1]; + int err; + + memset(range, 0, sizeof(*range)); + + err = nla_parse_nested_deprecated(tb, CTA_NAT_MAX, nat, + nat_nla_policy, NULL); + if (err < 0) + return err; + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + err = nf_nat_ipv4_nlattr_to_range(tb, range); + break; + case NFPROTO_IPV6: + err = nf_nat_ipv6_nlattr_to_range(tb, range); + break; + default: + err = -EPROTONOSUPPORT; + break; + } + + if (err) + return err; + + if (!tb[CTA_NAT_PROTO]) + return 0; + + return nfnetlink_parse_nat_proto(tb[CTA_NAT_PROTO], ct, range); +} + +/* This function is called under rcu_read_lock() */ +static int +nfnetlink_parse_nat_setup(struct nf_conn *ct, + enum nf_nat_manip_type manip, + const struct nlattr *attr) +{ + struct nf_nat_range2 range; + int err; + + /* Should not happen, restricted to creating new conntracks + * via ctnetlink. + */ + if (WARN_ON_ONCE(nf_nat_initialized(ct, manip))) + return -EEXIST; + + /* No NAT information has been passed, allocate the null-binding */ + if (attr == NULL) + return __nf_nat_alloc_null_binding(ct, manip) == NF_DROP ? 
-ENOMEM : 0; + + err = nfnetlink_parse_nat(attr, ct, &range); + if (err < 0) + return err; + + return nf_nat_setup_info(ct, &range, manip) == NF_DROP ? -ENOMEM : 0; +} +#else +static int +nfnetlink_parse_nat_setup(struct nf_conn *ct, + enum nf_nat_manip_type manip, + const struct nlattr *attr) +{ + return -EOPNOTSUPP; +} +#endif + +static struct nf_ct_helper_expectfn follow_master_nat = { + .name = "nat-follow-master", + .expectfn = nf_nat_follow_master, +}; + +int nf_nat_register_fn(struct net *net, u8 pf, const struct nf_hook_ops *ops, + const struct nf_hook_ops *orig_nat_ops, unsigned int ops_count) +{ + struct nat_net *nat_net = net_generic(net, nat_net_id); + struct nf_nat_hooks_net *nat_proto_net; + struct nf_nat_lookup_hook_priv *priv; + unsigned int hooknum = ops->hooknum; + struct nf_hook_ops *nat_ops; + int i, ret; + + if (WARN_ON_ONCE(pf >= ARRAY_SIZE(nat_net->nat_proto_net))) + return -EINVAL; + + nat_proto_net = &nat_net->nat_proto_net[pf]; + + for (i = 0; i < ops_count; i++) { + if (orig_nat_ops[i].hooknum == hooknum) { + hooknum = i; + break; + } + } + + if (WARN_ON_ONCE(i == ops_count)) + return -EINVAL; + + mutex_lock(&nf_nat_proto_mutex); + if (!nat_proto_net->nat_hook_ops) { + WARN_ON(nat_proto_net->users != 0); + + nat_ops = kmemdup(orig_nat_ops, sizeof(*orig_nat_ops) * ops_count, GFP_KERNEL); + if (!nat_ops) { + mutex_unlock(&nf_nat_proto_mutex); + return -ENOMEM; + } + + for (i = 0; i < ops_count; i++) { + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (priv) { + nat_ops[i].priv = priv; + continue; + } + mutex_unlock(&nf_nat_proto_mutex); + while (i) + kfree(nat_ops[--i].priv); + kfree(nat_ops); + return -ENOMEM; + } + + ret = nf_register_net_hooks(net, nat_ops, ops_count); + if (ret < 0) { + mutex_unlock(&nf_nat_proto_mutex); + for (i = 0; i < ops_count; i++) + kfree(nat_ops[i].priv); + kfree(nat_ops); + return ret; + } + + nat_proto_net->nat_hook_ops = nat_ops; + } + + nat_ops = nat_proto_net->nat_hook_ops; + priv = nat_ops[hooknum].priv; + if (WARN_ON_ONCE(!priv)) { + mutex_unlock(&nf_nat_proto_mutex); + return -EOPNOTSUPP; + } + + ret = nf_hook_entries_insert_raw(&priv->entries, ops); + if (ret == 0) + nat_proto_net->users++; + + mutex_unlock(&nf_nat_proto_mutex); + return ret; +} + +void nf_nat_unregister_fn(struct net *net, u8 pf, const struct nf_hook_ops *ops, + unsigned int ops_count) +{ + struct nat_net *nat_net = net_generic(net, nat_net_id); + struct nf_nat_hooks_net *nat_proto_net; + struct nf_nat_lookup_hook_priv *priv; + struct nf_hook_ops *nat_ops; + int hooknum = ops->hooknum; + int i; + + if (pf >= ARRAY_SIZE(nat_net->nat_proto_net)) + return; + + nat_proto_net = &nat_net->nat_proto_net[pf]; + + mutex_lock(&nf_nat_proto_mutex); + if (WARN_ON(nat_proto_net->users == 0)) + goto unlock; + + nat_proto_net->users--; + + nat_ops = nat_proto_net->nat_hook_ops; + for (i = 0; i < ops_count; i++) { + if (nat_ops[i].hooknum == hooknum) { + hooknum = i; + break; + } + } + if (WARN_ON_ONCE(i == ops_count)) + goto unlock; + priv = nat_ops[hooknum].priv; + nf_hook_entries_delete_raw(&priv->entries, ops); + + if (nat_proto_net->users == 0) { + nf_unregister_net_hooks(net, nat_ops, ops_count); + + for (i = 0; i < ops_count; i++) { + priv = nat_ops[i].priv; + kfree_rcu(priv, rcu_head); + } + + nat_proto_net->nat_hook_ops = NULL; + kfree(nat_ops); + } +unlock: + mutex_unlock(&nf_nat_proto_mutex); +} + +static struct pernet_operations nat_net_ops = { + .id = &nat_net_id, + .size = sizeof(struct nat_net), +}; + +static const struct nf_nat_hook nat_hook = { + 
.parse_nat_setup = nfnetlink_parse_nat_setup, +#ifdef CONFIG_XFRM + .decode_session = __nf_nat_decode_session, +#endif + .manip_pkt = nf_nat_manip_pkt, + .remove_nat_bysrc = nf_nat_cleanup_conntrack, +}; + +static int __init nf_nat_init(void) +{ + int ret, i; + + /* Leave them the same for the moment. */ + nf_nat_htable_size = nf_conntrack_htable_size; + if (nf_nat_htable_size < CONNTRACK_LOCKS) + nf_nat_htable_size = CONNTRACK_LOCKS; + + nf_nat_bysource = nf_ct_alloc_hashtable(&nf_nat_htable_size, 0); + if (!nf_nat_bysource) + return -ENOMEM; + + for (i = 0; i < CONNTRACK_LOCKS; i++) + spin_lock_init(&nf_nat_locks[i]); + + ret = register_pernet_subsys(&nat_net_ops); + if (ret < 0) { + kvfree(nf_nat_bysource); + return ret; + } + + nf_ct_helper_expectfn_register(&follow_master_nat); + + WARN_ON(nf_nat_hook != NULL); + RCU_INIT_POINTER(nf_nat_hook, &nat_hook); + + ret = register_nf_nat_bpf(); + if (ret < 0) { + RCU_INIT_POINTER(nf_nat_hook, NULL); + nf_ct_helper_expectfn_unregister(&follow_master_nat); + synchronize_net(); + unregister_pernet_subsys(&nat_net_ops); + kvfree(nf_nat_bysource); + } + + return ret; +} + +static void __exit nf_nat_cleanup(void) +{ + struct nf_nat_proto_clean clean = {}; + + nf_ct_iterate_destroy(nf_nat_proto_clean, &clean); + + nf_ct_helper_expectfn_unregister(&follow_master_nat); + RCU_INIT_POINTER(nf_nat_hook, NULL); + + synchronize_net(); + kvfree(nf_nat_bysource); + unregister_pernet_subsys(&nat_net_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(nf_nat_init); +module_exit(nf_nat_cleanup); diff --git a/net/netfilter/nf_nat_ftp.c b/net/netfilter/nf_nat_ftp.c new file mode 100644 index 000000000..c92a436d9 --- /dev/null +++ b/net/netfilter/nf_nat_ftp.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* FTP extension for TCP NAT alteration. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NAT_HELPER_NAME "ftp" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Rusty Russell "); +MODULE_DESCRIPTION("ftp NAT helper"); +MODULE_ALIAS_NF_NAT_HELPER(NAT_HELPER_NAME); + +/* FIXME: Time out? --RR */ + +static struct nf_conntrack_nat_helper nat_helper_ftp = + NF_CT_NAT_HELPER_INIT(NAT_HELPER_NAME); + +static int nf_nat_ftp_fmt_cmd(struct nf_conn *ct, enum nf_ct_ftp_type type, + char *buffer, size_t buflen, + union nf_inet_addr *addr, u16 port) +{ + switch (type) { + case NF_CT_FTP_PORT: + case NF_CT_FTP_PASV: + return snprintf(buffer, buflen, "%u,%u,%u,%u,%u,%u", + ((unsigned char *)&addr->ip)[0], + ((unsigned char *)&addr->ip)[1], + ((unsigned char *)&addr->ip)[2], + ((unsigned char *)&addr->ip)[3], + port >> 8, + port & 0xFF); + case NF_CT_FTP_EPRT: + if (nf_ct_l3num(ct) == NFPROTO_IPV4) + return snprintf(buffer, buflen, "|1|%pI4|%u|", + &addr->ip, port); + else + return snprintf(buffer, buflen, "|2|%pI6|%u|", + &addr->ip6, port); + case NF_CT_FTP_EPSV: + return snprintf(buffer, buflen, "|||%u|", port); + } + + return 0; +} + +/* So, this packet has hit the connection tracking matching code. + Mangle it, and change the expectation to match the new version. 
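The string built by nf_nat_ftp_fmt_cmd() for the PORT/PASV case is simply the four address octets plus the port split into its high and low bytes. A small standalone example of that encoding; the address 192.0.2.10 and port 21000 are arbitrary test values:

#include <arpa/inet.h>
#include <stdio.h>

int main(void)
{
        struct in_addr addr;
        unsigned char *b = (unsigned char *)&addr.s_addr;
        unsigned short port = 21000;

        if (inet_pton(AF_INET, "192.0.2.10", &addr) != 1)
                return 1;

        /* Same layout as the NF_CT_FTP_PORT/PASV case: h1,h2,h3,h4,p1,p2 */
        printf("PORT %u,%u,%u,%u,%u,%u\r\n",
               b[0], b[1], b[2], b[3], port >> 8, port & 0xFF);
        return 0;
}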
*/ +static unsigned int nf_nat_ftp(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + enum nf_ct_ftp_type type, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp) +{ + union nf_inet_addr newaddr; + u_int16_t port; + int dir = CTINFO2DIR(ctinfo); + struct nf_conn *ct = exp->master; + char buffer[sizeof("|1||65535|") + INET6_ADDRSTRLEN]; + unsigned int buflen; + + pr_debug("type %i, off %u len %u\n", type, matchoff, matchlen); + + /* Connection will come from wherever this packet goes, hence !dir */ + newaddr = ct->tuplehash[!dir].tuple.dst.u3; + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->dir = !dir; + + /* When you see the packet, we need to NAT it the same as the + * this one. */ + exp->expectfn = nf_nat_follow_master; + + port = nf_nat_exp_find_port(exp, ntohs(exp->saved_proto.tcp.port)); + if (port == 0) { + nf_ct_helper_log(skb, exp->master, "all ports in use"); + return NF_DROP; + } + + buflen = nf_nat_ftp_fmt_cmd(ct, type, buffer, sizeof(buffer), + &newaddr, port); + if (!buflen) + goto out; + + pr_debug("calling nf_nat_mangle_tcp_packet\n"); + + if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff, matchoff, + matchlen, buffer, buflen)) + goto out; + + return NF_ACCEPT; + +out: + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + nf_ct_unexpect_related(exp); + return NF_DROP; +} + +static void __exit nf_nat_ftp_fini(void) +{ + nf_nat_helper_unregister(&nat_helper_ftp); + RCU_INIT_POINTER(nf_nat_ftp_hook, NULL); + synchronize_rcu(); +} + +static int __init nf_nat_ftp_init(void) +{ + BUG_ON(nf_nat_ftp_hook != NULL); + nf_nat_helper_register(&nat_helper_ftp); + RCU_INIT_POINTER(nf_nat_ftp_hook, nf_nat_ftp); + return 0; +} + +/* Prior to 2.6.11, we had a ports param. No longer, but don't break users. */ +static int warn_set(const char *val, const struct kernel_param *kp) +{ + pr_info("kernel >= 2.6.10 only uses 'ports' for conntrack modules\n"); + return 0; +} +module_param_call(ports, warn_set, NULL, NULL, 0); + +module_init(nf_nat_ftp_init); +module_exit(nf_nat_ftp_fini); diff --git a/net/netfilter/nf_nat_helper.c b/net/netfilter/nf_nat_helper.c new file mode 100644 index 000000000..a95a25196 --- /dev/null +++ b/net/netfilter/nf_nat_helper.c @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* nf_nat_helper.c - generic support functions for NAT helpers + * + * (C) 2000-2002 Harald Welte + * (C) 2003-2006 Netfilter Core Team + * (C) 2007-2012 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +/* Frobs data inside this packet, which is linear. 
*/ +static void mangle_contents(struct sk_buff *skb, + unsigned int dataoff, + unsigned int match_offset, + unsigned int match_len, + const char *rep_buffer, + unsigned int rep_len) +{ + unsigned char *data; + + SKB_LINEAR_ASSERT(skb); + data = skb_network_header(skb) + dataoff; + + /* move post-replacement */ + memmove(data + match_offset + rep_len, + data + match_offset + match_len, + skb_tail_pointer(skb) - (skb_network_header(skb) + dataoff + + match_offset + match_len)); + + /* insert data from buffer */ + memcpy(data + match_offset, rep_buffer, rep_len); + + /* update skb info */ + if (rep_len > match_len) { + pr_debug("nf_nat_mangle_packet: Extending packet by " + "%u from %u bytes\n", rep_len - match_len, skb->len); + skb_put(skb, rep_len - match_len); + } else { + pr_debug("nf_nat_mangle_packet: Shrinking packet from " + "%u from %u bytes\n", match_len - rep_len, skb->len); + __skb_trim(skb, skb->len + rep_len - match_len); + } + + if (nf_ct_l3num((struct nf_conn *)skb_nfct(skb)) == NFPROTO_IPV4) { + /* fix IP hdr checksum information */ + ip_hdr(skb)->tot_len = htons(skb->len); + ip_send_check(ip_hdr(skb)); + } else + ipv6_hdr(skb)->payload_len = + htons(skb->len - sizeof(struct ipv6hdr)); +} + +/* Unusual, but possible case. */ +static bool enlarge_skb(struct sk_buff *skb, unsigned int extra) +{ + if (skb->len + extra > 65535) + return false; + + if (pskb_expand_head(skb, 0, extra - skb_tailroom(skb), GFP_ATOMIC)) + return false; + + return true; +} + +/* Generic function for mangling variable-length address changes inside + * NATed TCP connections (like the PORT XXX,XXX,XXX,XXX,XXX,XXX + * command in FTP). + * + * Takes care about all the nasty sequence number changes, checksumming, + * skb enlargement, ... + * + * */ +bool __nf_nat_mangle_tcp_packet(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int match_offset, + unsigned int match_len, + const char *rep_buffer, + unsigned int rep_len, bool adjust) +{ + struct tcphdr *tcph; + int oldlen, datalen; + + if (skb_ensure_writable(skb, skb->len)) + return false; + + if (rep_len > match_len && + rep_len - match_len > skb_tailroom(skb) && + !enlarge_skb(skb, rep_len - match_len)) + return false; + + tcph = (void *)skb->data + protoff; + + oldlen = skb->len - protoff; + mangle_contents(skb, protoff + tcph->doff*4, + match_offset, match_len, rep_buffer, rep_len); + + datalen = skb->len - protoff; + + nf_nat_csum_recalc(skb, nf_ct_l3num(ct), IPPROTO_TCP, + tcph, &tcph->check, datalen, oldlen); + + if (adjust && rep_len != match_len) + nf_ct_seqadj_set(ct, ctinfo, tcph->seq, + (int)rep_len - (int)match_len); + + return true; +} +EXPORT_SYMBOL(__nf_nat_mangle_tcp_packet); + +/* Generic function for mangling variable-length address changes inside + * NATed UDP connections (like the CONNECT DATA XXXXX MESG XXXXX INDEX XXXXX + * command in the Amanda protocol) + * + * Takes care about all the nasty sequence number changes, checksumming, + * skb enlargement, ... + * + * XXX - This function could be merged with nf_nat_mangle_tcp_packet which + * should be fairly easy to do. 
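Stripped of the skb and checksum handling, mangle_contents() is an in-place substring replacement in a linear buffer: shift the tail with memmove(), copy the replacement in, then adjust the length. A standalone sketch of that core step; the buffer contents and offsets are illustrative only:

#include <stdio.h>
#include <string.h>

/* Replace buf[off..off+match_len) with rep (rep_len bytes) and return the
 * new payload length. Assumes off + match_len <= len; cap is the space
 * available in buf. */
static size_t mangle(char *buf, size_t len, size_t cap,
                     size_t off, size_t match_len,
                     const char *rep, size_t rep_len)
{
        if (len - match_len + rep_len > cap)
                return 0;       /* would overflow; caller must grow buf */

        memmove(buf + off + rep_len, buf + off + match_len,
                len - off - match_len);
        memcpy(buf + off, rep, rep_len);
        return len - match_len + rep_len;
}

int main(void)
{
        char buf[64] = "PORT 10,0,0,1,19,137\r\n";
        size_t len = strlen(buf);

        /* Rewrite the address/port argument, as the FTP helper would. */
        len = mangle(buf, len, sizeof(buf), 5, len - 7,
                     "192,0,2,10,82,8", 15);
        buf[len] = '\0';
        printf("%s", buf);
        return 0;
}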
+ */ +bool +nf_nat_mangle_udp_packet(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int match_offset, + unsigned int match_len, + const char *rep_buffer, + unsigned int rep_len) +{ + struct udphdr *udph; + int datalen, oldlen; + + if (skb_ensure_writable(skb, skb->len)) + return false; + + if (rep_len > match_len && + rep_len - match_len > skb_tailroom(skb) && + !enlarge_skb(skb, rep_len - match_len)) + return false; + + udph = (void *)skb->data + protoff; + + oldlen = skb->len - protoff; + mangle_contents(skb, protoff + sizeof(*udph), + match_offset, match_len, rep_buffer, rep_len); + + /* update the length of the UDP packet */ + datalen = skb->len - protoff; + udph->len = htons(datalen); + + /* fix udp checksum if udp checksum was previously calculated */ + if (!udph->check && skb->ip_summed != CHECKSUM_PARTIAL) + return true; + + nf_nat_csum_recalc(skb, nf_ct_l3num(ct), IPPROTO_UDP, + udph, &udph->check, datalen, oldlen); + + return true; +} +EXPORT_SYMBOL(nf_nat_mangle_udp_packet); + +/* Setup NAT on this expected conntrack so it follows master. */ +/* If we fail to get a free NAT slot, we'll get dropped on confirm */ +void nf_nat_follow_master(struct nf_conn *ct, + struct nf_conntrack_expect *exp) +{ + struct nf_nat_range2 range; + + /* This must be a fresh one. */ + BUG_ON(ct->status & IPS_NAT_DONE_MASK); + + /* Change src to where master sends to */ + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr + = ct->master->tuplehash[!exp->dir].tuple.dst.u3; + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC); + + /* For DST manip, map port here to where it's expected. */ + range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED); + range.min_proto = range.max_proto = exp->saved_proto; + range.min_addr = range.max_addr + = ct->master->tuplehash[!exp->dir].tuple.src.u3; + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST); +} +EXPORT_SYMBOL(nf_nat_follow_master); + +u16 nf_nat_exp_find_port(struct nf_conntrack_expect *exp, u16 port) +{ + static const unsigned int max_attempts = 128; + int range, attempts_left; + u16 min = port; + + range = USHRT_MAX - port; + attempts_left = range; + + if (attempts_left > max_attempts) + attempts_left = max_attempts; + + /* Try to get same port: if not, try to change it. */ + for (;;) { + int res; + + exp->tuple.dst.u.tcp.port = htons(port); + res = nf_ct_expect_related(exp, 0); + if (res == 0) + return port; + + if (res != -EBUSY || (--attempts_left < 0)) + break; + + port = min + prandom_u32_max(range); + } + + return 0; +} +EXPORT_SYMBOL_GPL(nf_nat_exp_find_port); diff --git a/net/netfilter/nf_nat_irc.c b/net/netfilter/nf_nat_irc.c new file mode 100644 index 000000000..19c4fcc60 --- /dev/null +++ b/net/netfilter/nf_nat_irc.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* IRC extension for TCP NAT alteration. 
+ * + * (C) 2000-2001 by Harald Welte + * (C) 2004 Rusty Russell IBM Corporation + * based on a copy of RR's ip_nat_ftp.c + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#define NAT_HELPER_NAME "irc" + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("IRC (DCC) NAT helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NF_NAT_HELPER(NAT_HELPER_NAME); + +static struct nf_conntrack_nat_helper nat_helper_irc = + NF_CT_NAT_HELPER_INIT(NAT_HELPER_NAME); + +static unsigned int help(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + unsigned int protoff, + unsigned int matchoff, + unsigned int matchlen, + struct nf_conntrack_expect *exp) +{ + char buffer[sizeof("4294967296 65635")]; + struct nf_conn *ct = exp->master; + union nf_inet_addr newaddr; + u_int16_t port; + + /* Reply comes from server. */ + newaddr = ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3; + + exp->saved_proto.tcp.port = exp->tuple.dst.u.tcp.port; + exp->dir = IP_CT_DIR_REPLY; + exp->expectfn = nf_nat_follow_master; + + port = nf_nat_exp_find_port(exp, + ntohs(exp->saved_proto.tcp.port)); + if (port == 0) { + nf_ct_helper_log(skb, ct, "all ports in use"); + return NF_DROP; + } + + /* strlen("\1DCC CHAT chat AAAAAAAA P\1\n")=27 + * strlen("\1DCC SCHAT chat AAAAAAAA P\1\n")=28 + * strlen("\1DCC SEND F AAAAAAAA P S\1\n")=26 + * strlen("\1DCC MOVE F AAAAAAAA P S\1\n")=26 + * strlen("\1DCC TSEND F AAAAAAAA P S\1\n")=27 + * + * AAAAAAAAA: bound addr (1.0.0.0==16777216, min 8 digits, + * 255.255.255.255==4294967296, 10 digits) + * P: bound port (min 1 d, max 5d (65635)) + * F: filename (min 1 d ) + * S: size (min 1 d ) + * 0x01, \n: terminators + */ + /* AAA = "us", ie. where server normally talks to. */ + snprintf(buffer, sizeof(buffer), "%u %u", ntohl(newaddr.ip), port); + pr_debug("inserting '%s' == %pI4, port %u\n", + buffer, &newaddr.ip, port); + + if (!nf_nat_mangle_tcp_packet(skb, ct, ctinfo, protoff, matchoff, + matchlen, buffer, strlen(buffer))) { + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + nf_ct_unexpect_related(exp); + return NF_DROP; + } + + return NF_ACCEPT; +} + +static void __exit nf_nat_irc_fini(void) +{ + nf_nat_helper_unregister(&nat_helper_irc); + RCU_INIT_POINTER(nf_nat_irc_hook, NULL); + synchronize_rcu(); +} + +static int __init nf_nat_irc_init(void) +{ + BUG_ON(nf_nat_irc_hook != NULL); + nf_nat_helper_register(&nat_helper_irc); + RCU_INIT_POINTER(nf_nat_irc_hook, help); + return 0; +} + +/* Prior to 2.6.11, we had a ports param. No longer, but don't break users. 
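The DCC argument the IRC helper rewrites carries the IPv4 address as a single host-order decimal integer followed by the port, which is what the snprintf("%u %u", ...) above produces and what the buffer is sized for. A standalone example of that formatting; the address and port are arbitrary:

#include <arpa/inet.h>
#include <stdio.h>

int main(void)
{
        struct in_addr addr;
        unsigned short port = 6667;

        if (inet_pton(AF_INET, "192.0.2.10", &addr) != 1)
                return 1;

        /* Same formatting as the helper's snprintf("%u %u", ...) */
        printf("\001DCC CHAT chat %u %u\001\n",
               (unsigned int)ntohl(addr.s_addr), port);
        return 0;
}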
*/ +static int warn_set(const char *val, const struct kernel_param *kp) +{ + pr_info("kernel >= 2.6.10 only uses 'ports' for conntrack modules\n"); + return 0; +} +module_param_call(ports, warn_set, NULL, NULL, 0); + +module_init(nf_nat_irc_init); +module_exit(nf_nat_irc_fini); diff --git a/net/netfilter/nf_nat_masquerade.c b/net/netfilter/nf_nat_masquerade.c new file mode 100644 index 000000000..1a506b0c6 --- /dev/null +++ b/net/netfilter/nf_nat_masquerade.c @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include +#include +#include +#include + +#include + +struct masq_dev_work { + struct work_struct work; + struct net *net; + netns_tracker ns_tracker; + union nf_inet_addr addr; + int ifindex; + int (*iter)(struct nf_conn *i, void *data); +}; + +#define MAX_MASQ_WORKER_COUNT 16 + +static DEFINE_MUTEX(masq_mutex); +static unsigned int masq_refcnt __read_mostly; +static atomic_t masq_worker_count __read_mostly; + +unsigned int +nf_nat_masquerade_ipv4(struct sk_buff *skb, unsigned int hooknum, + const struct nf_nat_range2 *range, + const struct net_device *out) +{ + struct nf_conn *ct; + struct nf_conn_nat *nat; + enum ip_conntrack_info ctinfo; + struct nf_nat_range2 newrange; + const struct rtable *rt; + __be32 newsrc, nh; + + WARN_ON(hooknum != NF_INET_POST_ROUTING); + + ct = nf_ct_get(skb, &ctinfo); + + WARN_ON(!(ct && (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY))); + + /* Source address is 0.0.0.0 - locally generated packet that is + * probably not supposed to be masqueraded. + */ + if (ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.ip == 0) + return NF_ACCEPT; + + rt = skb_rtable(skb); + nh = rt_nexthop(rt, ip_hdr(skb)->daddr); + newsrc = inet_select_addr(out, nh, RT_SCOPE_UNIVERSE); + if (!newsrc) { + pr_info("%s ate my IP address\n", out->name); + return NF_DROP; + } + + nat = nf_ct_nat_ext_add(ct); + if (nat) + nat->masq_index = out->ifindex; + + /* Transfer from original range. */ + memset(&newrange.min_addr, 0, sizeof(newrange.min_addr)); + memset(&newrange.max_addr, 0, sizeof(newrange.max_addr)); + newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr.ip = newsrc; + newrange.max_addr.ip = newsrc; + newrange.min_proto = range->min_proto; + newrange.max_proto = range->max_proto; + + /* Hand modified range to generic setup. */ + return nf_nat_setup_info(ct, &newrange, NF_NAT_MANIP_SRC); +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv4); + +static void iterate_cleanup_work(struct work_struct *work) +{ + struct nf_ct_iter_data iter_data = {}; + struct masq_dev_work *w; + + w = container_of(work, struct masq_dev_work, work); + + iter_data.net = w->net; + iter_data.data = (void *)w; + nf_ct_iterate_cleanup_net(w->iter, &iter_data); + + put_net_track(w->net, &w->ns_tracker); + kfree(w); + atomic_dec(&masq_worker_count); + module_put(THIS_MODULE); +} + +/* Iterate conntrack table in the background and remove conntrack entries + * that use the device/address being removed. + * + * In case too many work items have been queued already or memory allocation + * fails iteration is skipped, conntrack entries will time out eventually. 
+ */ +static void nf_nat_masq_schedule(struct net *net, union nf_inet_addr *addr, + int ifindex, + int (*iter)(struct nf_conn *i, void *data), + gfp_t gfp_flags) +{ + struct masq_dev_work *w; + + if (atomic_read(&masq_worker_count) > MAX_MASQ_WORKER_COUNT) + return; + + net = maybe_get_net(net); + if (!net) + return; + + if (!try_module_get(THIS_MODULE)) + goto err_module; + + w = kzalloc(sizeof(*w), gfp_flags); + if (w) { + /* We can overshoot MAX_MASQ_WORKER_COUNT, no big deal */ + atomic_inc(&masq_worker_count); + + INIT_WORK(&w->work, iterate_cleanup_work); + w->ifindex = ifindex; + w->net = net; + netns_tracker_alloc(net, &w->ns_tracker, gfp_flags); + w->iter = iter; + if (addr) + w->addr = *addr; + schedule_work(&w->work); + return; + } + + module_put(THIS_MODULE); + err_module: + put_net(net); +} + +static int device_cmp(struct nf_conn *i, void *arg) +{ + const struct nf_conn_nat *nat = nfct_nat(i); + const struct masq_dev_work *w = arg; + + if (!nat) + return 0; + return nat->masq_index == w->ifindex; +} + +static int masq_device_event(struct notifier_block *this, + unsigned long event, + void *ptr) +{ + const struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + + if (event == NETDEV_DOWN) { + /* Device was downed. Search entire table for + * conntracks which were associated with that device, + * and forget them. + */ + + nf_nat_masq_schedule(net, NULL, dev->ifindex, + device_cmp, GFP_KERNEL); + } + + return NOTIFY_DONE; +} + +static int inet_cmp(struct nf_conn *ct, void *ptr) +{ + struct nf_conntrack_tuple *tuple; + struct masq_dev_work *w = ptr; + + if (!device_cmp(ct, ptr)) + return 0; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + return nf_inet_addr_cmp(&w->addr, &tuple->dst.u3); +} + +static int masq_inet_event(struct notifier_block *this, + unsigned long event, + void *ptr) +{ + const struct in_ifaddr *ifa = ptr; + const struct in_device *idev; + const struct net_device *dev; + union nf_inet_addr addr; + + if (event != NETDEV_DOWN) + return NOTIFY_DONE; + + /* The masq_dev_notifier will catch the case of the device going + * down. So if the inetdev is dead and being destroyed we have + * no work to do. Otherwise this is an individual address removal + * and we have to perform the flush. 
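+ * For example, "ip addr del 192.0.2.1/24 dev eth0" takes this path,
+ * while the device itself going down is handled by masq_device_event().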
+ */ + idev = ifa->ifa_dev; + if (idev->dead) + return NOTIFY_DONE; + + memset(&addr, 0, sizeof(addr)); + + addr.ip = ifa->ifa_address; + + dev = idev->dev; + nf_nat_masq_schedule(dev_net(idev->dev), &addr, dev->ifindex, + inet_cmp, GFP_KERNEL); + + return NOTIFY_DONE; +} + +static struct notifier_block masq_dev_notifier = { + .notifier_call = masq_device_event, +}; + +static struct notifier_block masq_inet_notifier = { + .notifier_call = masq_inet_event, +}; + +#if IS_ENABLED(CONFIG_IPV6) +static int +nat_ipv6_dev_get_saddr(struct net *net, const struct net_device *dev, + const struct in6_addr *daddr, unsigned int srcprefs, + struct in6_addr *saddr) +{ +#ifdef CONFIG_IPV6_MODULE + const struct nf_ipv6_ops *v6_ops = nf_get_ipv6_ops(); + + if (!v6_ops) + return -EHOSTUNREACH; + + return v6_ops->dev_get_saddr(net, dev, daddr, srcprefs, saddr); +#else + return ipv6_dev_get_saddr(net, dev, daddr, srcprefs, saddr); +#endif +} + +unsigned int +nf_nat_masquerade_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, + const struct net_device *out) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn_nat *nat; + struct in6_addr src; + struct nf_conn *ct; + struct nf_nat_range2 newrange; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct && (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY))); + + if (nat_ipv6_dev_get_saddr(nf_ct_net(ct), out, + &ipv6_hdr(skb)->daddr, 0, &src) < 0) + return NF_DROP; + + nat = nf_ct_nat_ext_add(ct); + if (nat) + nat->masq_index = out->ifindex; + + newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr.in6 = src; + newrange.max_addr.in6 = src; + newrange.min_proto = range->min_proto; + newrange.max_proto = range->max_proto; + + return nf_nat_setup_info(ct, &newrange, NF_NAT_MANIP_SRC); +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_ipv6); + +/* atomic notifier; can't call nf_ct_iterate_cleanup_net (it can sleep). + * + * Defer it to the system workqueue. + * + * As we can have 'a lot' of inet_events (depending on amount of ipv6 + * addresses being deleted), we also need to limit work item queue. 
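+ * That is also why nf_nat_masq_schedule() is called with GFP_ATOMIC from
+ * masq_inet6_event() below: the inet6addr notifier chain is atomic.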
+ */ +static int masq_inet6_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct inet6_ifaddr *ifa = ptr; + const struct net_device *dev; + union nf_inet_addr addr; + + if (event != NETDEV_DOWN) + return NOTIFY_DONE; + + dev = ifa->idev->dev; + + memset(&addr, 0, sizeof(addr)); + + addr.in6 = ifa->addr; + + nf_nat_masq_schedule(dev_net(dev), &addr, dev->ifindex, inet_cmp, + GFP_ATOMIC); + return NOTIFY_DONE; +} + +static struct notifier_block masq_inet6_notifier = { + .notifier_call = masq_inet6_event, +}; + +static int nf_nat_masquerade_ipv6_register_notifier(void) +{ + return register_inet6addr_notifier(&masq_inet6_notifier); +} +#else +static inline int nf_nat_masquerade_ipv6_register_notifier(void) { return 0; } +#endif + +int nf_nat_masquerade_inet_register_notifiers(void) +{ + int ret = 0; + + mutex_lock(&masq_mutex); + if (WARN_ON_ONCE(masq_refcnt == UINT_MAX)) { + ret = -EOVERFLOW; + goto out_unlock; + } + + /* check if the notifier was already set */ + if (++masq_refcnt > 1) + goto out_unlock; + + /* Register for device down reports */ + ret = register_netdevice_notifier(&masq_dev_notifier); + if (ret) + goto err_dec; + /* Register IP address change reports */ + ret = register_inetaddr_notifier(&masq_inet_notifier); + if (ret) + goto err_unregister; + + ret = nf_nat_masquerade_ipv6_register_notifier(); + if (ret) + goto err_unreg_inet; + + mutex_unlock(&masq_mutex); + return ret; +err_unreg_inet: + unregister_inetaddr_notifier(&masq_inet_notifier); +err_unregister: + unregister_netdevice_notifier(&masq_dev_notifier); +err_dec: + masq_refcnt--; +out_unlock: + mutex_unlock(&masq_mutex); + return ret; +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_inet_register_notifiers); + +void nf_nat_masquerade_inet_unregister_notifiers(void) +{ + mutex_lock(&masq_mutex); + /* check if the notifiers still have clients */ + if (--masq_refcnt > 0) + goto out_unlock; + + unregister_netdevice_notifier(&masq_dev_notifier); + unregister_inetaddr_notifier(&masq_inet_notifier); +#if IS_ENABLED(CONFIG_IPV6) + unregister_inet6addr_notifier(&masq_inet6_notifier); +#endif +out_unlock: + mutex_unlock(&masq_mutex); +} +EXPORT_SYMBOL_GPL(nf_nat_masquerade_inet_unregister_notifiers); diff --git a/net/netfilter/nf_nat_proto.c b/net/netfilter/nf_nat_proto.c new file mode 100644 index 000000000..48cc60084 --- /dev/null +++ b/net/netfilter/nf_nat_proto.c @@ -0,0 +1,1103 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +static void nf_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype); + +static void +__udp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, struct udphdr *hdr, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype, bool do_csum) +{ + __be16 *portptr, newport; + + if (maniptype == NF_NAT_MANIP_SRC) { + /* Get rid of src port */ + newport = tuple->src.u.udp.port; + portptr = &hdr->source; + } else { + /* Get rid of dst port */ + newport = tuple->dst.u.udp.port; + portptr = &hdr->dest; + } + if (do_csum) { + nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype); + inet_proto_csum_replace2(&hdr->check, skb, *portptr, newport, + false); + if 
(!hdr->check) + hdr->check = CSUM_MANGLED_0; + } + *portptr = newport; +} + +static bool udp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ + struct udphdr *hdr; + + if (skb_ensure_writable(skb, hdroff + sizeof(*hdr))) + return false; + + hdr = (struct udphdr *)(skb->data + hdroff); + __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, !!hdr->check); + + return true; +} + +static bool udplite_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ +#ifdef CONFIG_NF_CT_PROTO_UDPLITE + struct udphdr *hdr; + + if (skb_ensure_writable(skb, hdroff + sizeof(*hdr))) + return false; + + hdr = (struct udphdr *)(skb->data + hdroff); + __udp_manip_pkt(skb, iphdroff, hdr, tuple, maniptype, true); +#endif + return true; +} + +static bool +sctp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ +#ifdef CONFIG_NF_CT_PROTO_SCTP + struct sctphdr *hdr; + int hdrsize = 8; + + /* This could be an inner header returned in imcp packet; in such + * cases we cannot update the checksum field since it is outside + * of the 8 bytes of transport layer headers we are guaranteed. + */ + if (skb->len >= hdroff + sizeof(*hdr)) + hdrsize = sizeof(*hdr); + + if (skb_ensure_writable(skb, hdroff + hdrsize)) + return false; + + hdr = (struct sctphdr *)(skb->data + hdroff); + + if (maniptype == NF_NAT_MANIP_SRC) { + /* Get rid of src port */ + hdr->source = tuple->src.u.sctp.port; + } else { + /* Get rid of dst port */ + hdr->dest = tuple->dst.u.sctp.port; + } + + if (hdrsize < sizeof(*hdr)) + return true; + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + hdr->checksum = sctp_compute_cksum(skb, hdroff); + skb->ip_summed = CHECKSUM_NONE; + } + +#endif + return true; +} + +static bool +tcp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ + struct tcphdr *hdr; + __be16 *portptr, newport, oldport; + int hdrsize = 8; /* TCP connection tracking guarantees this much */ + + /* this could be a inner header returned in icmp packet; in such + cases we cannot update the checksum field since it is outside of + the 8 bytes of transport layer headers we are guaranteed */ + if (skb->len >= hdroff + sizeof(struct tcphdr)) + hdrsize = sizeof(struct tcphdr); + + if (skb_ensure_writable(skb, hdroff + hdrsize)) + return false; + + hdr = (struct tcphdr *)(skb->data + hdroff); + + if (maniptype == NF_NAT_MANIP_SRC) { + /* Get rid of src port */ + newport = tuple->src.u.tcp.port; + portptr = &hdr->source; + } else { + /* Get rid of dst port */ + newport = tuple->dst.u.tcp.port; + portptr = &hdr->dest; + } + + oldport = *portptr; + *portptr = newport; + + if (hdrsize < sizeof(*hdr)) + return true; + + nf_csum_update(skb, iphdroff, &hdr->check, tuple, maniptype); + inet_proto_csum_replace2(&hdr->check, skb, oldport, newport, false); + return true; +} + +static bool +dccp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ +#ifdef CONFIG_NF_CT_PROTO_DCCP + struct dccp_hdr *hdr; + __be16 *portptr, oldport, newport; + int hdrsize = 8; /* DCCP connection tracking guarantees this much */ + + if (skb->len >= hdroff + sizeof(struct 
dccp_hdr)) + hdrsize = sizeof(struct dccp_hdr); + + if (skb_ensure_writable(skb, hdroff + hdrsize)) + return false; + + hdr = (struct dccp_hdr *)(skb->data + hdroff); + + if (maniptype == NF_NAT_MANIP_SRC) { + newport = tuple->src.u.dccp.port; + portptr = &hdr->dccph_sport; + } else { + newport = tuple->dst.u.dccp.port; + portptr = &hdr->dccph_dport; + } + + oldport = *portptr; + *portptr = newport; + + if (hdrsize < sizeof(*hdr)) + return true; + + nf_csum_update(skb, iphdroff, &hdr->dccph_checksum, tuple, maniptype); + inet_proto_csum_replace2(&hdr->dccph_checksum, skb, oldport, newport, + false); +#endif + return true; +} + +static bool +icmp_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ + struct icmphdr *hdr; + + if (skb_ensure_writable(skb, hdroff + sizeof(*hdr))) + return false; + + hdr = (struct icmphdr *)(skb->data + hdroff); + switch (hdr->type) { + case ICMP_ECHO: + case ICMP_ECHOREPLY: + case ICMP_TIMESTAMP: + case ICMP_TIMESTAMPREPLY: + case ICMP_INFO_REQUEST: + case ICMP_INFO_REPLY: + case ICMP_ADDRESS: + case ICMP_ADDRESSREPLY: + break; + default: + return true; + } + inet_proto_csum_replace2(&hdr->checksum, skb, + hdr->un.echo.id, tuple->src.u.icmp.id, false); + hdr->un.echo.id = tuple->src.u.icmp.id; + return true; +} + +static bool +icmpv6_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ + struct icmp6hdr *hdr; + + if (skb_ensure_writable(skb, hdroff + sizeof(*hdr))) + return false; + + hdr = (struct icmp6hdr *)(skb->data + hdroff); + nf_csum_update(skb, iphdroff, &hdr->icmp6_cksum, tuple, maniptype); + if (hdr->icmp6_type == ICMPV6_ECHO_REQUEST || + hdr->icmp6_type == ICMPV6_ECHO_REPLY) { + inet_proto_csum_replace2(&hdr->icmp6_cksum, skb, + hdr->icmp6_identifier, + tuple->src.u.icmp.id, false); + hdr->icmp6_identifier = tuple->src.u.icmp.id; + } + return true; +} + +/* manipulate a GRE packet according to maniptype */ +static bool +gre_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ +#if IS_ENABLED(CONFIG_NF_CT_PROTO_GRE) + const struct gre_base_hdr *greh; + struct pptp_gre_header *pgreh; + + /* pgreh includes two optional 32bit fields which are not required + * to be there. That's where the magic '8' comes from */ + if (skb_ensure_writable(skb, hdroff + sizeof(*pgreh) - 8)) + return false; + + greh = (void *)skb->data + hdroff; + pgreh = (struct pptp_gre_header *)greh; + + /* we only have destination manip of a packet, since 'source key' + * is not present in the packet itself */ + if (maniptype != NF_NAT_MANIP_DST) + return true; + + switch (greh->flags & GRE_VERSION) { + case GRE_VERSION_0: + /* We do not currently NAT any GREv0 packets. 
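+ * (Version 0 GRE carries no call ID that could be rewritten; only the
+ * PPTP-style version 1 header does.)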
+ * Try to behave like "nf_nat_proto_unknown" */ + break; + case GRE_VERSION_1: + pr_debug("call_id -> 0x%04x\n", ntohs(tuple->dst.u.gre.key)); + pgreh->call_id = tuple->dst.u.gre.key; + break; + default: + pr_debug("can't nat unknown GRE version\n"); + return false; + } +#endif + return true; +} + +static bool l4proto_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, unsigned int hdroff, + const struct nf_conntrack_tuple *tuple, + enum nf_nat_manip_type maniptype) +{ + switch (tuple->dst.protonum) { + case IPPROTO_TCP: + return tcp_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_UDP: + return udp_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_UDPLITE: + return udplite_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_SCTP: + return sctp_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_ICMP: + return icmp_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_ICMPV6: + return icmpv6_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_DCCP: + return dccp_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + case IPPROTO_GRE: + return gre_manip_pkt(skb, iphdroff, hdroff, + tuple, maniptype); + } + + /* If we don't know protocol -- no error, pass it unmodified. */ + return true; +} + +static bool nf_nat_ipv4_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, + const struct nf_conntrack_tuple *target, + enum nf_nat_manip_type maniptype) +{ + struct iphdr *iph; + unsigned int hdroff; + + if (skb_ensure_writable(skb, iphdroff + sizeof(*iph))) + return false; + + iph = (void *)skb->data + iphdroff; + hdroff = iphdroff + iph->ihl * 4; + + if (!l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype)) + return false; + iph = (void *)skb->data + iphdroff; + + if (maniptype == NF_NAT_MANIP_SRC) { + csum_replace4(&iph->check, iph->saddr, target->src.u3.ip); + iph->saddr = target->src.u3.ip; + } else { + csum_replace4(&iph->check, iph->daddr, target->dst.u3.ip); + iph->daddr = target->dst.u3.ip; + } + return true; +} + +static bool nf_nat_ipv6_manip_pkt(struct sk_buff *skb, + unsigned int iphdroff, + const struct nf_conntrack_tuple *target, + enum nf_nat_manip_type maniptype) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *ipv6h; + __be16 frag_off; + int hdroff; + u8 nexthdr; + + if (skb_ensure_writable(skb, iphdroff + sizeof(*ipv6h))) + return false; + + ipv6h = (void *)skb->data + iphdroff; + nexthdr = ipv6h->nexthdr; + hdroff = ipv6_skip_exthdr(skb, iphdroff + sizeof(*ipv6h), + &nexthdr, &frag_off); + if (hdroff < 0) + goto manip_addr; + + if ((frag_off & htons(~0x7)) == 0 && + !l4proto_manip_pkt(skb, iphdroff, hdroff, target, maniptype)) + return false; + + /* must reload, offset might have changed */ + ipv6h = (void *)skb->data + iphdroff; + +manip_addr: + if (maniptype == NF_NAT_MANIP_SRC) + ipv6h->saddr = target->src.u3.in6; + else + ipv6h->daddr = target->dst.u3.in6; + +#endif + return true; +} + +unsigned int nf_nat_manip_pkt(struct sk_buff *skb, struct nf_conn *ct, + enum nf_nat_manip_type mtype, + enum ip_conntrack_dir dir) +{ + struct nf_conntrack_tuple target; + + /* We are aiming to look like inverse of other direction. 
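+ * i.e. rewrite this packet with the tuple that the other direction
+ * expects to see as its reply.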
*/ + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + + switch (target.src.l3num) { + case NFPROTO_IPV6: + if (nf_nat_ipv6_manip_pkt(skb, 0, &target, mtype)) + return NF_ACCEPT; + break; + case NFPROTO_IPV4: + if (nf_nat_ipv4_manip_pkt(skb, 0, &target, mtype)) + return NF_ACCEPT; + break; + default: + WARN_ON_ONCE(1); + break; + } + + return NF_DROP; +} + +static void nf_nat_ipv4_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ + struct iphdr *iph = (struct iphdr *)(skb->data + iphdroff); + __be32 oldip, newip; + + if (maniptype == NF_NAT_MANIP_SRC) { + oldip = iph->saddr; + newip = t->src.u3.ip; + } else { + oldip = iph->daddr; + newip = t->dst.u3.ip; + } + inet_proto_csum_replace4(check, skb, oldip, newip, true); +} + +static void nf_nat_ipv6_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ +#if IS_ENABLED(CONFIG_IPV6) + const struct ipv6hdr *ipv6h = (struct ipv6hdr *)(skb->data + iphdroff); + const struct in6_addr *oldip, *newip; + + if (maniptype == NF_NAT_MANIP_SRC) { + oldip = &ipv6h->saddr; + newip = &t->src.u3.in6; + } else { + oldip = &ipv6h->daddr; + newip = &t->dst.u3.in6; + } + inet_proto_csum_replace16(check, skb, oldip->s6_addr32, + newip->s6_addr32, true); +#endif +} + +static void nf_csum_update(struct sk_buff *skb, + unsigned int iphdroff, __sum16 *check, + const struct nf_conntrack_tuple *t, + enum nf_nat_manip_type maniptype) +{ + switch (t->src.l3num) { + case NFPROTO_IPV4: + nf_nat_ipv4_csum_update(skb, iphdroff, check, t, maniptype); + return; + case NFPROTO_IPV6: + nf_nat_ipv6_csum_update(skb, iphdroff, check, t, maniptype); + return; + } +} + +static void nf_nat_ipv4_csum_recalc(struct sk_buff *skb, + u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + const struct iphdr *iph = ip_hdr(skb); + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + + ip_hdrlen(skb); + skb->csum_offset = (void *)check - data; + *check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, + proto, 0); + } else { + inet_proto_csum_replace2(check, skb, + htons(oldlen), htons(datalen), true); + } +} + +#if IS_ENABLED(CONFIG_IPV6) +static void nf_nat_ipv6_csum_recalc(struct sk_buff *skb, + u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + const struct ipv6hdr *ipv6h = ipv6_hdr(skb); + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_headroom(skb) + skb_network_offset(skb) + + (data - (void *)skb->data); + skb->csum_offset = (void *)check - data; + *check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + datalen, proto, 0); + } else { + inet_proto_csum_replace2(check, skb, + htons(oldlen), htons(datalen), true); + } +} +#endif + +void nf_nat_csum_recalc(struct sk_buff *skb, + u8 nfproto, u8 proto, void *data, __sum16 *check, + int datalen, int oldlen) +{ + switch (nfproto) { + case NFPROTO_IPV4: + nf_nat_ipv4_csum_recalc(skb, proto, data, check, + datalen, oldlen); + return; +#if IS_ENABLED(CONFIG_IPV6) + case NFPROTO_IPV6: + nf_nat_ipv6_csum_recalc(skb, proto, data, check, + datalen, oldlen); + return; +#endif + } + + WARN_ON_ONCE(1); +} + +int nf_nat_icmp_reply_translation(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int hooknum) +{ + struct { + struct 
icmphdr icmp; + struct iphdr ip; + } *inside; + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); + unsigned int hdrlen = ip_hdrlen(skb); + struct nf_conntrack_tuple target; + unsigned long statusbit; + + WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); + + if (skb_ensure_writable(skb, hdrlen + sizeof(*inside))) + return 0; + if (nf_ip_checksum(skb, hooknum, hdrlen, IPPROTO_ICMP)) + return 0; + + inside = (void *)skb->data + hdrlen; + if (inside->icmp.type == ICMP_REDIRECT) { + if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) + return 0; + if (ct->status & IPS_NAT_MASK) + return 0; + } + + if (manip == NF_NAT_MANIP_SRC) + statusbit = IPS_SRC_NAT; + else + statusbit = IPS_DST_NAT; + + /* Invert if this is reply direction */ + if (dir == IP_CT_DIR_REPLY) + statusbit ^= IPS_NAT_MASK; + + if (!(ct->status & statusbit)) + return 1; + + if (!nf_nat_ipv4_manip_pkt(skb, hdrlen + sizeof(inside->icmp), + &ct->tuplehash[!dir].tuple, !manip)) + return 0; + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + /* Reloading "inside" here since manip_pkt may reallocate */ + inside = (void *)skb->data + hdrlen; + inside->icmp.checksum = 0; + inside->icmp.checksum = + csum_fold(skb_checksum(skb, hdrlen, + skb->len - hdrlen, 0)); + } + + /* Change outer to look like the reply to an incoming packet */ + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + target.dst.protonum = IPPROTO_ICMP; + if (!nf_nat_ipv4_manip_pkt(skb, 0, &target, manip)) + return 0; + + return 1; +} +EXPORT_SYMBOL_GPL(nf_nat_icmp_reply_translation); + +static unsigned int +nf_nat_ipv4_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return NF_ACCEPT; + + if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { + if (ip_hdr(skb)->protocol == IPPROTO_ICMP) { + if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, + state->hook)) + return NF_DROP; + else + return NF_ACCEPT; + } + } + + return nf_nat_inet_fn(priv, skb, state); +} + +static unsigned int +nf_nat_ipv4_pre_routing(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + unsigned int ret; + __be32 daddr = ip_hdr(skb)->daddr; + + ret = nf_nat_ipv4_fn(priv, skb, state); + if (ret == NF_ACCEPT && daddr != ip_hdr(skb)->daddr) + skb_dst_drop(skb); + + return ret; +} + +#ifdef CONFIG_XFRM +static int nf_xfrm_me_harder(struct net *net, struct sk_buff *skb, unsigned int family) +{ + struct sock *sk = skb->sk; + struct dst_entry *dst; + unsigned int hh_len; + struct flowi fl; + int err; + + err = xfrm_decode_session(skb, &fl, family); + if (err < 0) + return err; + + dst = skb_dst(skb); + if (dst->xfrm) + dst = ((struct xfrm_dst *)dst)->route; + if (!dst_hold_safe(dst)) + return -EHOSTUNREACH; + + if (sk && !net_eq(net, sock_net(sk))) + sk = NULL; + + dst = xfrm_lookup(net, dst, &fl, sk, 0); + if (IS_ERR(dst)) + return PTR_ERR(dst); + + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + /* Change in oif may mean change in hh_len. 
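+ * e.g. xfrm_lookup() above may have rerouted the packet to a tunnel
+ * device with a larger hard_header_len, so re-check the headroom.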
*/ + hh_len = skb_dst(skb)->dev->hard_header_len; + if (skb_headroom(skb) < hh_len && + pskb_expand_head(skb, hh_len - skb_headroom(skb), 0, GFP_ATOMIC)) + return -ENOMEM; + return 0; +} +#endif + +static unsigned int +nf_nat_ipv4_local_in(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + __be32 saddr = ip_hdr(skb)->saddr; + struct sock *sk = skb->sk; + unsigned int ret; + + ret = nf_nat_ipv4_fn(priv, skb, state); + + if (ret == NF_ACCEPT && sk && saddr != ip_hdr(skb)->saddr && + !inet_sk_transparent(sk)) + skb_orphan(skb); /* TCP edemux obtained wrong socket */ + + return ret; +} + +static unsigned int +nf_nat_ipv4_out(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ +#ifdef CONFIG_XFRM + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + int err; +#endif + unsigned int ret; + + ret = nf_nat_ipv4_fn(priv, skb, state); +#ifdef CONFIG_XFRM + if (ret != NF_ACCEPT) + return ret; + + if (IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (ct->tuplehash[dir].tuple.src.u3.ip != + ct->tuplehash[!dir].tuple.dst.u3.ip || + (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && + ct->tuplehash[dir].tuple.src.u.all != + ct->tuplehash[!dir].tuple.dst.u.all)) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } +#endif + return ret; +} + +static unsigned int +nf_nat_ipv4_local_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned int ret; + int err; + + ret = nf_nat_ipv4_fn(priv, skb, state); + if (ret != NF_ACCEPT) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (ct->tuplehash[dir].tuple.dst.u3.ip != + ct->tuplehash[!dir].tuple.src.u3.ip) { + err = ip_route_me_harder(state->net, state->sk, skb, RTN_UNSPEC); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#ifdef CONFIG_XFRM + else if (!(IPCB(skb)->flags & IPSKB_XFRM_TRANSFORMED) && + ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMP && + ct->tuplehash[dir].tuple.dst.u.all != + ct->tuplehash[!dir].tuple.src.u.all) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#endif + } + return ret; +} + +static const struct nf_hook_ops nf_nat_ipv4_ops[] = { + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv4_pre_routing, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv4_out, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_NAT_SRC, + }, + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv4_local_fn, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv4_local_in, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_NAT_SRC, + }, +}; + +int nf_nat_ipv4_register_fn(struct net *net, const struct nf_hook_ops *ops) +{ + return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv4_ops, + ARRAY_SIZE(nf_nat_ipv4_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv4_register_fn); + +void nf_nat_ipv4_unregister_fn(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv4_ops)); 
+} +EXPORT_SYMBOL_GPL(nf_nat_ipv4_unregister_fn); + +#if IS_ENABLED(CONFIG_IPV6) +int nf_nat_icmpv6_reply_translation(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + unsigned int hooknum, + unsigned int hdrlen) +{ + struct { + struct icmp6hdr icmp6; + struct ipv6hdr ip6; + } *inside; + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + enum nf_nat_manip_type manip = HOOK2MANIP(hooknum); + struct nf_conntrack_tuple target; + unsigned long statusbit; + + WARN_ON(ctinfo != IP_CT_RELATED && ctinfo != IP_CT_RELATED_REPLY); + + if (skb_ensure_writable(skb, hdrlen + sizeof(*inside))) + return 0; + if (nf_ip6_checksum(skb, hooknum, hdrlen, IPPROTO_ICMPV6)) + return 0; + + inside = (void *)skb->data + hdrlen; + if (inside->icmp6.icmp6_type == NDISC_REDIRECT) { + if ((ct->status & IPS_NAT_DONE_MASK) != IPS_NAT_DONE_MASK) + return 0; + if (ct->status & IPS_NAT_MASK) + return 0; + } + + if (manip == NF_NAT_MANIP_SRC) + statusbit = IPS_SRC_NAT; + else + statusbit = IPS_DST_NAT; + + /* Invert if this is reply direction */ + if (dir == IP_CT_DIR_REPLY) + statusbit ^= IPS_NAT_MASK; + + if (!(ct->status & statusbit)) + return 1; + + if (!nf_nat_ipv6_manip_pkt(skb, hdrlen + sizeof(inside->icmp6), + &ct->tuplehash[!dir].tuple, !manip)) + return 0; + + if (skb->ip_summed != CHECKSUM_PARTIAL) { + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + + inside = (void *)skb->data + hdrlen; + inside->icmp6.icmp6_cksum = 0; + inside->icmp6.icmp6_cksum = + csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + skb->len - hdrlen, IPPROTO_ICMPV6, + skb_checksum(skb, hdrlen, + skb->len - hdrlen, 0)); + } + + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + target.dst.protonum = IPPROTO_ICMPV6; + if (!nf_nat_ipv6_manip_pkt(skb, 0, &target, manip)) + return 0; + + return 1; +} +EXPORT_SYMBOL_GPL(nf_nat_icmpv6_reply_translation); + +static unsigned int +nf_nat_ipv6_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + __be16 frag_off; + int hdrlen; + u8 nexthdr; + + ct = nf_ct_get(skb, &ctinfo); + /* Can't track? It's not due to stress, or conntrack would + * have dropped it. Hence it's the user's responsibilty to + * packet filter it out, or implement conntrack/NAT for that + * protocol. 
8) --RR + */ + if (!ct) + return NF_ACCEPT; + + if (ctinfo == IP_CT_RELATED || ctinfo == IP_CT_RELATED_REPLY) { + nexthdr = ipv6_hdr(skb)->nexthdr; + hdrlen = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + + if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { + if (!nf_nat_icmpv6_reply_translation(skb, ct, ctinfo, + state->hook, + hdrlen)) + return NF_DROP; + else + return NF_ACCEPT; + } + } + + return nf_nat_inet_fn(priv, skb, state); +} + +static unsigned int +nf_nat_ipv6_in(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + unsigned int ret; + struct in6_addr daddr = ipv6_hdr(skb)->daddr; + + ret = nf_nat_ipv6_fn(priv, skb, state); + if (ret != NF_DROP && ret != NF_STOLEN && + ipv6_addr_cmp(&daddr, &ipv6_hdr(skb)->daddr)) + skb_dst_drop(skb); + + return ret; +} + +static unsigned int +nf_nat_ipv6_out(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ +#ifdef CONFIG_XFRM + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + int err; +#endif + unsigned int ret; + + ret = nf_nat_ipv6_fn(priv, skb, state); +#ifdef CONFIG_XFRM + if (ret != NF_ACCEPT) + return ret; + + if (IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) + return ret; + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3) || + (ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && + ct->tuplehash[dir].tuple.src.u.all != + ct->tuplehash[!dir].tuple.dst.u.all)) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET6); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } +#endif + + return ret; +} + +static unsigned int +nf_nat_ipv6_local_fn(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned int ret; + int err; + + ret = nf_nat_ipv6_fn(priv, skb, state); + if (ret != NF_ACCEPT) + return ret; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + + if (!nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3, + &ct->tuplehash[!dir].tuple.src.u3)) { + err = nf_ip6_route_me_harder(state->net, state->sk, skb); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#ifdef CONFIG_XFRM + else if (!(IP6CB(skb)->flags & IP6SKB_XFRM_TRANSFORMED) && + ct->tuplehash[dir].tuple.dst.protonum != IPPROTO_ICMPV6 && + ct->tuplehash[dir].tuple.dst.u.all != + ct->tuplehash[!dir].tuple.src.u.all) { + err = nf_xfrm_me_harder(state->net, skb, AF_INET6); + if (err < 0) + ret = NF_DROP_ERR(err); + } +#endif + } + + return ret; +} + +static const struct nf_hook_ops nf_nat_ipv6_ops[] = { + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv6_in, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_PRE_ROUTING, + .priority = NF_IP6_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv6_out, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP6_PRI_NAT_SRC, + }, + /* Before packet filtering, change destination */ + { + .hook = nf_nat_ipv6_local_fn, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_OUT, + .priority = NF_IP6_PRI_NAT_DST, + }, + /* After packet filtering, change source */ + { + .hook = nf_nat_ipv6_fn, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP6_PRI_NAT_SRC, + }, +}; + +int nf_nat_ipv6_register_fn(struct net *net, const struct nf_hook_ops *ops) +{ + return nf_nat_register_fn(net, ops->pf, ops, nf_nat_ipv6_ops, + 
ARRAY_SIZE(nf_nat_ipv6_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv6_register_fn); + +void nf_nat_ipv6_unregister_fn(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_unregister_fn(net, ops->pf, ops, ARRAY_SIZE(nf_nat_ipv6_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_ipv6_unregister_fn); +#endif /* CONFIG_IPV6 */ + +#if defined(CONFIG_NF_TABLES_INET) && IS_ENABLED(CONFIG_NFT_NAT) +int nf_nat_inet_register_fn(struct net *net, const struct nf_hook_ops *ops) +{ + int ret; + + if (WARN_ON_ONCE(ops->pf != NFPROTO_INET)) + return -EINVAL; + + ret = nf_nat_register_fn(net, NFPROTO_IPV6, ops, nf_nat_ipv6_ops, + ARRAY_SIZE(nf_nat_ipv6_ops)); + if (ret) + return ret; + + ret = nf_nat_register_fn(net, NFPROTO_IPV4, ops, nf_nat_ipv4_ops, + ARRAY_SIZE(nf_nat_ipv4_ops)); + if (ret) + nf_nat_unregister_fn(net, NFPROTO_IPV6, ops, + ARRAY_SIZE(nf_nat_ipv6_ops)); + return ret; +} +EXPORT_SYMBOL_GPL(nf_nat_inet_register_fn); + +void nf_nat_inet_unregister_fn(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_unregister_fn(net, NFPROTO_IPV4, ops, ARRAY_SIZE(nf_nat_ipv4_ops)); + nf_nat_unregister_fn(net, NFPROTO_IPV6, ops, ARRAY_SIZE(nf_nat_ipv6_ops)); +} +EXPORT_SYMBOL_GPL(nf_nat_inet_unregister_fn); +#endif /* NFT INET NAT */ diff --git a/net/netfilter/nf_nat_redirect.c b/net/netfilter/nf_nat_redirect.c new file mode 100644 index 000000000..5b37487d9 --- /dev/null +++ b/net/netfilter/nf_nat_redirect.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * Copyright (c) 2011 Patrick McHardy + * + * Based on Rusty Russell's IPv4 REDIRECT target. Development of IPv6 + * NAT funded by Astaro. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int +nf_nat_redirect(struct sk_buff *skb, const struct nf_nat_range2 *range, + const union nf_inet_addr *newdst) +{ + struct nf_nat_range2 newrange; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + + memset(&newrange, 0, sizeof(newrange)); + + newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr = *newdst; + newrange.max_addr = *newdst; + newrange.min_proto = range->min_proto; + newrange.max_proto = range->max_proto; + + return nf_nat_setup_info(ct, &newrange, NF_NAT_MANIP_DST); +} + +unsigned int +nf_nat_redirect_ipv4(struct sk_buff *skb, const struct nf_nat_range2 *range, + unsigned int hooknum) +{ + union nf_inet_addr newdst = {}; + + WARN_ON(hooknum != NF_INET_PRE_ROUTING && + hooknum != NF_INET_LOCAL_OUT); + + /* Local packets: make them go to loopback */ + if (hooknum == NF_INET_LOCAL_OUT) { + newdst.ip = htonl(INADDR_LOOPBACK); + } else { + const struct in_device *indev; + + indev = __in_dev_get_rcu(skb->dev); + if (indev) { + const struct in_ifaddr *ifa; + + ifa = rcu_dereference(indev->ifa_list); + if (ifa) + newdst.ip = ifa->ifa_local; + } + + if (!newdst.ip) + return NF_DROP; + } + + return nf_nat_redirect(skb, range, &newdst); +} +EXPORT_SYMBOL_GPL(nf_nat_redirect_ipv4); + +static const struct in6_addr loopback_addr = IN6ADDR_LOOPBACK_INIT; + +static bool nf_nat_redirect_ipv6_usable(const struct inet6_ifaddr *ifa, unsigned int scope) +{ + unsigned int ifa_addr_type = ipv6_addr_type(&ifa->addr); + + if (ifa_addr_type & IPV6_ADDR_MAPPED) + return false; + + if ((ifa->flags & IFA_F_TENTATIVE) && (!(ifa->flags & IFA_F_OPTIMISTIC))) + return false; + + if (scope) { + unsigned int 
ifa_scope = ifa_addr_type & IPV6_ADDR_SCOPE_MASK; + + if (!(scope & ifa_scope)) + return false; + } + + return true; +} + +unsigned int +nf_nat_redirect_ipv6(struct sk_buff *skb, const struct nf_nat_range2 *range, + unsigned int hooknum) +{ + union nf_inet_addr newdst = {}; + + if (hooknum == NF_INET_LOCAL_OUT) { + newdst.in6 = loopback_addr; + } else { + unsigned int scope = ipv6_addr_scope(&ipv6_hdr(skb)->daddr); + struct inet6_dev *idev; + bool addr = false; + + idev = __in6_dev_get(skb->dev); + if (idev != NULL) { + const struct inet6_ifaddr *ifa; + + read_lock_bh(&idev->lock); + list_for_each_entry(ifa, &idev->addr_list, if_list) { + if (!nf_nat_redirect_ipv6_usable(ifa, scope)) + continue; + + newdst.in6 = ifa->addr; + addr = true; + break; + } + read_unlock_bh(&idev->lock); + } + + if (!addr) + return NF_DROP; + } + + return nf_nat_redirect(skb, range, &newdst); +} +EXPORT_SYMBOL_GPL(nf_nat_redirect_ipv6); diff --git a/net/netfilter/nf_nat_sip.c b/net/netfilter/nf_nat_sip.c new file mode 100644 index 000000000..cf4aeb299 --- /dev/null +++ b/net/netfilter/nf_nat_sip.c @@ -0,0 +1,676 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* SIP extension for NAT alteration. + * + * (C) 2005 by Christian Hentschel + * based on RR's ip_nat_ftp.c and other modules. + * (C) 2007 United Security Providers + * (C) 2007, 2008, 2011, 2012 Patrick McHardy + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define NAT_HELPER_NAME "sip" + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Christian Hentschel "); +MODULE_DESCRIPTION("SIP NAT helper"); +MODULE_ALIAS_NF_NAT_HELPER(NAT_HELPER_NAME); + +static struct nf_conntrack_nat_helper nat_helper_sip = + NF_CT_NAT_HELPER_INIT(NAT_HELPER_NAME); + +static unsigned int mangle_packet(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int matchoff, unsigned int matchlen, + const char *buffer, unsigned int buflen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct tcphdr *th; + unsigned int baseoff; + + if (nf_ct_protonum(ct) == IPPROTO_TCP) { + th = (struct tcphdr *)(skb->data + protoff); + baseoff = protoff + th->doff * 4; + matchoff += dataoff - baseoff; + + if (!__nf_nat_mangle_tcp_packet(skb, ct, ctinfo, + protoff, matchoff, matchlen, + buffer, buflen, false)) + return 0; + } else { + baseoff = protoff + sizeof(struct udphdr); + matchoff += dataoff - baseoff; + + if (!nf_nat_mangle_udp_packet(skb, ct, ctinfo, + protoff, matchoff, matchlen, + buffer, buflen)) + return 0; + } + + /* Reload data pointer and adjust datalen value */ + *dptr = skb->data + dataoff; + *datalen += buflen - matchlen; + return 1; +} + +static int sip_sprintf_addr(const struct nf_conn *ct, char *buffer, + const union nf_inet_addr *addr, bool delim) +{ + if (nf_ct_l3num(ct) == NFPROTO_IPV4) + return sprintf(buffer, "%pI4", &addr->ip); + else { + if (delim) + return sprintf(buffer, "[%pI6c]", &addr->ip6); + else + return sprintf(buffer, "%pI6c", &addr->ip6); + } +} + +static int sip_sprintf_addr_port(const struct nf_conn *ct, char *buffer, + const union nf_inet_addr *addr, u16 port) +{ + if (nf_ct_l3num(ct) == NFPROTO_IPV4) + return sprintf(buffer, "%pI4:%u", &addr->ip, port); + else + return sprintf(buffer, "[%pI6c]:%u", &addr->ip6, port); +} + +static int map_addr(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int matchoff, 
unsigned int matchlen, + union nf_inet_addr *addr, __be16 port) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + char buffer[INET6_ADDRSTRLEN + sizeof("[]:nnnnn")]; + unsigned int buflen; + union nf_inet_addr newaddr; + __be16 newport; + + if (nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, addr) && + ct->tuplehash[dir].tuple.src.u.udp.port == port) { + newaddr = ct->tuplehash[!dir].tuple.dst.u3; + newport = ct->tuplehash[!dir].tuple.dst.u.udp.port; + } else if (nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.dst.u3, addr) && + ct->tuplehash[dir].tuple.dst.u.udp.port == port) { + newaddr = ct->tuplehash[!dir].tuple.src.u3; + newport = ct_sip_info->forced_dport ? : + ct->tuplehash[!dir].tuple.src.u.udp.port; + } else + return 1; + + if (nf_inet_addr_cmp(&newaddr, addr) && newport == port) + return 1; + + buflen = sip_sprintf_addr_port(ct, buffer, &newaddr, ntohs(newport)); + return mangle_packet(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, buffer, buflen); +} + +static int map_sip_addr(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + enum sip_header_types type) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + unsigned int matchlen, matchoff; + union nf_inet_addr addr; + __be16 port; + + if (ct_sip_parse_header_uri(ct, *dptr, NULL, *datalen, type, NULL, + &matchoff, &matchlen, &addr, &port) <= 0) + return 1; + return map_addr(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, &addr, port); +} + +static unsigned int nf_nat_sip(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + unsigned int coff, matchoff, matchlen; + enum sip_header_types hdr; + union nf_inet_addr addr; + __be16 port; + int request, in_header; + + /* Basic rules: requests and responses. 
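+ * A response starts with "SIP/2.0", a request with a method and
+ * Request-URI; the strncasecmp() below keys on that prefix.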
*/ + if (strncasecmp(*dptr, "SIP/2.0", strlen("SIP/2.0")) != 0) { + if (ct_sip_parse_request(ct, *dptr, *datalen, + &matchoff, &matchlen, + &addr, &port) > 0 && + !map_addr(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, &addr, port)) { + nf_ct_helper_log(skb, ct, "cannot mangle SIP message"); + return NF_DROP; + } + request = 1; + } else + request = 0; + + if (nf_ct_protonum(ct) == IPPROTO_TCP) + hdr = SIP_HDR_VIA_TCP; + else + hdr = SIP_HDR_VIA_UDP; + + /* Translate topmost Via header and parameters */ + if (ct_sip_parse_header_uri(ct, *dptr, NULL, *datalen, + hdr, NULL, &matchoff, &matchlen, + &addr, &port) > 0) { + unsigned int olen, matchend, poff, plen, buflen, n; + char buffer[INET6_ADDRSTRLEN + sizeof("[]:nnnnn")]; + + /* We're only interested in headers related to this + * connection */ + if (request) { + if (!nf_inet_addr_cmp(&addr, + &ct->tuplehash[dir].tuple.src.u3) || + port != ct->tuplehash[dir].tuple.src.u.udp.port) + goto next; + } else { + if (!nf_inet_addr_cmp(&addr, + &ct->tuplehash[dir].tuple.dst.u3) || + port != ct->tuplehash[dir].tuple.dst.u.udp.port) + goto next; + } + + olen = *datalen; + if (!map_addr(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, &addr, port)) { + nf_ct_helper_log(skb, ct, "cannot mangle Via header"); + return NF_DROP; + } + + matchend = matchoff + matchlen + *datalen - olen; + + /* The maddr= parameter (RFC 2361) specifies where to send + * the reply. */ + if (ct_sip_parse_address_param(ct, *dptr, matchend, *datalen, + "maddr=", &poff, &plen, + &addr, true) > 0 && + nf_inet_addr_cmp(&addr, &ct->tuplehash[dir].tuple.src.u3) && + !nf_inet_addr_cmp(&addr, &ct->tuplehash[!dir].tuple.dst.u3)) { + buflen = sip_sprintf_addr(ct, buffer, + &ct->tuplehash[!dir].tuple.dst.u3, + true); + if (!mangle_packet(skb, protoff, dataoff, dptr, datalen, + poff, plen, buffer, buflen)) { + nf_ct_helper_log(skb, ct, "cannot mangle maddr"); + return NF_DROP; + } + } + + /* The received= parameter (RFC 2361) contains the address + * from which the server received the request. */ + if (ct_sip_parse_address_param(ct, *dptr, matchend, *datalen, + "received=", &poff, &plen, + &addr, false) > 0 && + nf_inet_addr_cmp(&addr, &ct->tuplehash[dir].tuple.dst.u3) && + !nf_inet_addr_cmp(&addr, &ct->tuplehash[!dir].tuple.src.u3)) { + buflen = sip_sprintf_addr(ct, buffer, + &ct->tuplehash[!dir].tuple.src.u3, + false); + if (!mangle_packet(skb, protoff, dataoff, dptr, datalen, + poff, plen, buffer, buflen)) { + nf_ct_helper_log(skb, ct, "cannot mangle received"); + return NF_DROP; + } + } + + /* The rport= parameter (RFC 3581) contains the port number + * from which the server received the request. 
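+ * It is only rewritten when it matches this direction's destination
+ * port but not the opposite direction's source port, i.e. when the NAT
+ * mapping actually changed the port.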
*/ + if (ct_sip_parse_numerical_param(ct, *dptr, matchend, *datalen, + "rport=", &poff, &plen, + &n) > 0 && + htons(n) == ct->tuplehash[dir].tuple.dst.u.udp.port && + htons(n) != ct->tuplehash[!dir].tuple.src.u.udp.port) { + __be16 p = ct->tuplehash[!dir].tuple.src.u.udp.port; + buflen = sprintf(buffer, "%u", ntohs(p)); + if (!mangle_packet(skb, protoff, dataoff, dptr, datalen, + poff, plen, buffer, buflen)) { + nf_ct_helper_log(skb, ct, "cannot mangle rport"); + return NF_DROP; + } + } + } + +next: + /* Translate Contact headers */ + coff = 0; + in_header = 0; + while (ct_sip_parse_header_uri(ct, *dptr, &coff, *datalen, + SIP_HDR_CONTACT, &in_header, + &matchoff, &matchlen, + &addr, &port) > 0) { + if (!map_addr(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, + &addr, port)) { + nf_ct_helper_log(skb, ct, "cannot mangle contact"); + return NF_DROP; + } + } + + if (!map_sip_addr(skb, protoff, dataoff, dptr, datalen, SIP_HDR_FROM) || + !map_sip_addr(skb, protoff, dataoff, dptr, datalen, SIP_HDR_TO)) { + nf_ct_helper_log(skb, ct, "cannot mangle SIP from/to"); + return NF_DROP; + } + + /* Mangle destination port for Cisco phones, then fix up checksums */ + if (dir == IP_CT_DIR_REPLY && ct_sip_info->forced_dport) { + struct udphdr *uh; + + if (skb_ensure_writable(skb, skb->len)) { + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + return NF_DROP; + } + + uh = (void *)skb->data + protoff; + uh->dest = ct_sip_info->forced_dport; + + if (!nf_nat_mangle_udp_packet(skb, ct, ctinfo, protoff, + 0, 0, NULL, 0)) { + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + return NF_DROP; + } + } + + return NF_ACCEPT; +} + +static void nf_nat_sip_seq_adjust(struct sk_buff *skb, unsigned int protoff, + s16 off) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + const struct tcphdr *th; + + if (nf_ct_protonum(ct) != IPPROTO_TCP || off == 0) + return; + + th = (struct tcphdr *)(skb->data + protoff); + nf_ct_seqadj_set(ct, ctinfo, th->seq, off); +} + +/* Handles expected signalling connections and media streams */ +static void nf_nat_sip_expected(struct nf_conn *ct, + struct nf_conntrack_expect *exp) +{ + struct nf_conn_help *help = nfct_help(ct->master); + struct nf_conntrack_expect *pair_exp; + int range_set_for_snat = 0; + struct nf_nat_range2 range; + + /* This must be a fresh one. */ + BUG_ON(ct->status & IPS_NAT_DONE_MASK); + + /* For DST manip, map port here to where it's expected. */ + range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED); + range.min_proto = range.max_proto = exp->saved_proto; + range.min_addr = range.max_addr = exp->saved_addr; + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST); + + /* Do media streams SRC manip according with the parameters + * found in the paired expectation. 
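+ * The pair is found by walking the master connection's expectation list
+ * and matching this connection's original source address and port
+ * against the expectation's saved address and port.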
+ */ + if (exp->class != SIP_EXPECT_SIGNALLING) { + spin_lock_bh(&nf_conntrack_expect_lock); + hlist_for_each_entry(pair_exp, &help->expectations, lnode) { + if (pair_exp->tuple.src.l3num == nf_ct_l3num(ct) && + pair_exp->tuple.dst.protonum == ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum && + nf_inet_addr_cmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3, &pair_exp->saved_addr) && + ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.all == pair_exp->saved_proto.all) { + range.flags = (NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED); + range.min_proto.all = range.max_proto.all = pair_exp->tuple.dst.u.all; + range.min_addr = range.max_addr = pair_exp->tuple.dst.u3; + range_set_for_snat = 1; + break; + } + } + spin_unlock_bh(&nf_conntrack_expect_lock); + } + + /* When no paired expectation has been found, change src to + * where master sends to, but only if the connection actually came + * from the same source. + */ + if (!range_set_for_snat && + nf_inet_addr_cmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3, + &ct->master->tuplehash[exp->dir].tuple.src.u3)) { + range.flags = NF_NAT_RANGE_MAP_IPS; + range.min_addr = range.max_addr + = ct->master->tuplehash[!exp->dir].tuple.dst.u3; + range_set_for_snat = 1; + } + + /* Perform SRC manip. */ + if (range_set_for_snat) + nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC); +} + +static unsigned int nf_nat_sip_expect(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + struct nf_conntrack_expect *exp, + unsigned int matchoff, + unsigned int matchlen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + struct nf_ct_sip_master *ct_sip_info = nfct_help_data(ct); + union nf_inet_addr newaddr; + u_int16_t port; + __be16 srcport; + char buffer[INET6_ADDRSTRLEN + sizeof("[]:nnnnn")]; + unsigned int buflen; + + /* Connection will come from reply */ + if (nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3)) + newaddr = exp->tuple.dst.u3; + else + newaddr = ct->tuplehash[!dir].tuple.dst.u3; + + /* If the signalling port matches the connection's source port in the + * original direction, try to use the destination port in the opposite + * direction. */ + srcport = ct_sip_info->forced_dport ? 
: + ct->tuplehash[dir].tuple.src.u.udp.port; + if (exp->tuple.dst.u.udp.port == srcport) + port = ntohs(ct->tuplehash[!dir].tuple.dst.u.udp.port); + else + port = ntohs(exp->tuple.dst.u.udp.port); + + exp->saved_addr = exp->tuple.dst.u3; + exp->tuple.dst.u3 = newaddr; + exp->saved_proto.udp.port = exp->tuple.dst.u.udp.port; + exp->dir = !dir; + exp->expectfn = nf_nat_sip_expected; + + port = nf_nat_exp_find_port(exp, port); + if (port == 0) { + nf_ct_helper_log(skb, ct, "all ports in use for SIP"); + return NF_DROP; + } + + if (!nf_inet_addr_cmp(&exp->tuple.dst.u3, &exp->saved_addr) || + exp->tuple.dst.u.udp.port != exp->saved_proto.udp.port) { + buflen = sip_sprintf_addr_port(ct, buffer, &newaddr, port); + if (!mangle_packet(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, buffer, buflen)) { + nf_ct_helper_log(skb, ct, "cannot mangle packet"); + goto err; + } + } + return NF_ACCEPT; + +err: + nf_ct_unexpect_related(exp); + return NF_DROP; +} + +static int mangle_content_len(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + unsigned int matchoff, matchlen; + char buffer[sizeof("65536")]; + int buflen, c_len; + + /* Get actual SDP length */ + if (ct_sip_get_sdp_header(ct, *dptr, 0, *datalen, + SDP_HDR_VERSION, SDP_HDR_UNSPEC, + &matchoff, &matchlen) <= 0) + return 0; + c_len = *datalen - matchoff + strlen("v="); + + /* Now, update SDP length */ + if (ct_sip_get_header(ct, *dptr, 0, *datalen, SIP_HDR_CONTENT_LENGTH, + &matchoff, &matchlen) <= 0) + return 0; + + buflen = sprintf(buffer, "%u", c_len); + return mangle_packet(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, buffer, buflen); +} + +static int mangle_sdp_packet(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int sdpoff, + enum sdp_header_types type, + enum sdp_header_types term, + char *buffer, int buflen) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + unsigned int matchlen, matchoff; + + if (ct_sip_get_sdp_header(ct, *dptr, sdpoff, *datalen, type, term, + &matchoff, &matchlen) <= 0) + return -ENOENT; + return mangle_packet(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, buffer, buflen) ? 
0 : -EINVAL; +} + +static unsigned int nf_nat_sdp_addr(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int sdpoff, + enum sdp_header_types type, + enum sdp_header_types term, + const union nf_inet_addr *addr) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + char buffer[INET6_ADDRSTRLEN]; + unsigned int buflen; + + buflen = sip_sprintf_addr(ct, buffer, addr, false); + if (mangle_sdp_packet(skb, protoff, dataoff, dptr, datalen, + sdpoff, type, term, buffer, buflen)) + return 0; + + return mangle_content_len(skb, protoff, dataoff, dptr, datalen); +} + +static unsigned int nf_nat_sdp_port(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int matchoff, + unsigned int matchlen, + u_int16_t port) +{ + char buffer[sizeof("nnnnn")]; + unsigned int buflen; + + buflen = sprintf(buffer, "%u", port); + if (!mangle_packet(skb, protoff, dataoff, dptr, datalen, + matchoff, matchlen, buffer, buflen)) + return 0; + + return mangle_content_len(skb, protoff, dataoff, dptr, datalen); +} + +static unsigned int nf_nat_sdp_session(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + unsigned int sdpoff, + const union nf_inet_addr *addr) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + char buffer[INET6_ADDRSTRLEN]; + unsigned int buflen; + + /* Mangle session description owner and contact addresses */ + buflen = sip_sprintf_addr(ct, buffer, addr, false); + if (mangle_sdp_packet(skb, protoff, dataoff, dptr, datalen, sdpoff, + SDP_HDR_OWNER, SDP_HDR_MEDIA, buffer, buflen)) + return 0; + + switch (mangle_sdp_packet(skb, protoff, dataoff, dptr, datalen, sdpoff, + SDP_HDR_CONNECTION, SDP_HDR_MEDIA, + buffer, buflen)) { + case 0: + /* + * RFC 2327: + * + * Session description + * + * c=* (connection information - not required if included in all media) + */ + case -ENOENT: + break; + default: + return 0; + } + + return mangle_content_len(skb, protoff, dataoff, dptr, datalen); +} + +/* So, this packet has hit the connection tracking matching code. + Mangle it, and change the expectation to match the new version. */ +static unsigned int nf_nat_sdp_media(struct sk_buff *skb, unsigned int protoff, + unsigned int dataoff, + const char **dptr, unsigned int *datalen, + struct nf_conntrack_expect *rtp_exp, + struct nf_conntrack_expect *rtcp_exp, + unsigned int mediaoff, + unsigned int medialen, + union nf_inet_addr *rtp_addr) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo); + u_int16_t port; + + /* Connection will come from reply */ + if (nf_inet_addr_cmp(&ct->tuplehash[dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3)) + *rtp_addr = rtp_exp->tuple.dst.u3; + else + *rtp_addr = ct->tuplehash[!dir].tuple.dst.u3; + + rtp_exp->saved_addr = rtp_exp->tuple.dst.u3; + rtp_exp->tuple.dst.u3 = *rtp_addr; + rtp_exp->saved_proto.udp.port = rtp_exp->tuple.dst.u.udp.port; + rtp_exp->dir = !dir; + rtp_exp->expectfn = nf_nat_sip_expected; + + rtcp_exp->saved_addr = rtcp_exp->tuple.dst.u3; + rtcp_exp->tuple.dst.u3 = *rtp_addr; + rtcp_exp->saved_proto.udp.port = rtcp_exp->tuple.dst.u.udp.port; + rtcp_exp->dir = !dir; + rtcp_exp->expectfn = nf_nat_sip_expected; + + /* Try to get same pair of ports: if not, try to change them. 
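+ * RTP and RTCP conventionally use an even/odd port pair, so the loop
+ * below probes candidates in steps of two and always keeps
+ * rtcp = rtp + 1.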
*/ + for (port = ntohs(rtp_exp->tuple.dst.u.udp.port); + port != 0; port += 2) { + int ret; + + rtp_exp->tuple.dst.u.udp.port = htons(port); + ret = nf_ct_expect_related(rtp_exp, + NF_CT_EXP_F_SKIP_MASTER); + if (ret == -EBUSY) + continue; + else if (ret < 0) { + port = 0; + break; + } + rtcp_exp->tuple.dst.u.udp.port = htons(port + 1); + ret = nf_ct_expect_related(rtcp_exp, + NF_CT_EXP_F_SKIP_MASTER); + if (ret == 0) + break; + else if (ret == -EBUSY) { + nf_ct_unexpect_related(rtp_exp); + continue; + } else if (ret < 0) { + nf_ct_unexpect_related(rtp_exp); + port = 0; + break; + } + } + + if (port == 0) { + nf_ct_helper_log(skb, ct, "all ports in use for SDP media"); + goto err1; + } + + /* Update media port. */ + if (rtp_exp->tuple.dst.u.udp.port != rtp_exp->saved_proto.udp.port && + !nf_nat_sdp_port(skb, protoff, dataoff, dptr, datalen, + mediaoff, medialen, port)) { + nf_ct_helper_log(skb, ct, "cannot mangle SDP message"); + goto err2; + } + + return NF_ACCEPT; + +err2: + nf_ct_unexpect_related(rtp_exp); + nf_ct_unexpect_related(rtcp_exp); +err1: + return NF_DROP; +} + +static struct nf_ct_helper_expectfn sip_nat = { + .name = "sip", + .expectfn = nf_nat_sip_expected, +}; + +static void __exit nf_nat_sip_fini(void) +{ + nf_nat_helper_unregister(&nat_helper_sip); + RCU_INIT_POINTER(nf_nat_sip_hooks, NULL); + nf_ct_helper_expectfn_unregister(&sip_nat); + synchronize_rcu(); +} + +static const struct nf_nat_sip_hooks sip_hooks = { + .msg = nf_nat_sip, + .seq_adjust = nf_nat_sip_seq_adjust, + .expect = nf_nat_sip_expect, + .sdp_addr = nf_nat_sdp_addr, + .sdp_port = nf_nat_sdp_port, + .sdp_session = nf_nat_sdp_session, + .sdp_media = nf_nat_sdp_media, +}; + +static int __init nf_nat_sip_init(void) +{ + BUG_ON(nf_nat_sip_hooks != NULL); + nf_nat_helper_register(&nat_helper_sip); + RCU_INIT_POINTER(nf_nat_sip_hooks, &sip_hooks); + nf_ct_helper_expectfn_register(&sip_nat); + return 0; +} + +module_init(nf_nat_sip_init); +module_exit(nf_nat_sip_fini); diff --git a/net/netfilter/nf_nat_tftp.c b/net/netfilter/nf_nat_tftp.c new file mode 100644 index 000000000..1a591132d --- /dev/null +++ b/net/netfilter/nf_nat_tftp.c @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 2001-2002 Magnus Boden + */ + +#include +#include + +#include +#include +#include +#include + +#define NAT_HELPER_NAME "tftp" + +MODULE_AUTHOR("Magnus Boden "); +MODULE_DESCRIPTION("TFTP NAT helper"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NF_NAT_HELPER(NAT_HELPER_NAME); + +static struct nf_conntrack_nat_helper nat_helper_tftp = + NF_CT_NAT_HELPER_INIT(NAT_HELPER_NAME); + +static unsigned int help(struct sk_buff *skb, + enum ip_conntrack_info ctinfo, + struct nf_conntrack_expect *exp) +{ + const struct nf_conn *ct = exp->master; + + exp->saved_proto.udp.port + = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u.udp.port; + exp->dir = IP_CT_DIR_REPLY; + exp->expectfn = nf_nat_follow_master; + if (nf_ct_expect_related(exp, 0) != 0) { + nf_ct_helper_log(skb, exp->master, "cannot add expectation"); + return NF_DROP; + } + return NF_ACCEPT; +} + +static void __exit nf_nat_tftp_fini(void) +{ + nf_nat_helper_unregister(&nat_helper_tftp); + RCU_INIT_POINTER(nf_nat_tftp_hook, NULL); + synchronize_rcu(); +} + +static int __init nf_nat_tftp_init(void) +{ + BUG_ON(nf_nat_tftp_hook != NULL); + nf_nat_helper_register(&nat_helper_tftp); + RCU_INIT_POINTER(nf_nat_tftp_hook, help); + return 0; +} + +module_init(nf_nat_tftp_init); +module_exit(nf_nat_tftp_fini); diff --git a/net/netfilter/nf_queue.c b/net/netfilter/nf_queue.c new file mode 
100644 index 000000000..e2f334f70 --- /dev/null +++ b/net/netfilter/nf_queue.c @@ -0,0 +1,356 @@ +/* + * Rusty Russell (C)2000 -- This code is GPL. + * Patrick McHardy (c) 2006-2012 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +static const struct nf_queue_handler __rcu *nf_queue_handler; + +/* + * Hook for nfnetlink_queue to register its queue handler. + * We do this so that most of the NFQUEUE code can be modular. + * + * Once the queue is registered it must reinject all packets it + * receives, no matter what. + */ + +void nf_register_queue_handler(const struct nf_queue_handler *qh) +{ + /* should never happen, we only have one queueing backend in kernel */ + WARN_ON(rcu_access_pointer(nf_queue_handler)); + rcu_assign_pointer(nf_queue_handler, qh); +} +EXPORT_SYMBOL(nf_register_queue_handler); + +/* The caller must flush their queue before this */ +void nf_unregister_queue_handler(void) +{ + RCU_INIT_POINTER(nf_queue_handler, NULL); +} +EXPORT_SYMBOL(nf_unregister_queue_handler); + +static void nf_queue_sock_put(struct sock *sk) +{ +#ifdef CONFIG_INET + sock_gen_put(sk); +#else + sock_put(sk); +#endif +} + +static void nf_queue_entry_release_refs(struct nf_queue_entry *entry) +{ + struct nf_hook_state *state = &entry->state; + + /* Release those devices we held, or Alexey will kill me. */ + dev_put(state->in); + dev_put(state->out); + if (state->sk) + nf_queue_sock_put(state->sk); + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + dev_put(entry->physin); + dev_put(entry->physout); +#endif +} + +void nf_queue_entry_free(struct nf_queue_entry *entry) +{ + nf_queue_entry_release_refs(entry); + kfree(entry); +} +EXPORT_SYMBOL_GPL(nf_queue_entry_free); + +static void __nf_queue_entry_init_physdevs(struct nf_queue_entry *entry) +{ +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + const struct sk_buff *skb = entry->skb; + + if (nf_bridge_info_exists(skb)) { + entry->physin = nf_bridge_get_physindev(skb, entry->state.net); + entry->physout = nf_bridge_get_physoutdev(skb); + } else { + entry->physin = NULL; + entry->physout = NULL; + } +#endif +} + +/* Bump dev refs so they don't vanish while packet is out */ +bool nf_queue_entry_get_refs(struct nf_queue_entry *entry) +{ + struct nf_hook_state *state = &entry->state; + + if (state->sk && !refcount_inc_not_zero(&state->sk->sk_refcnt)) + return false; + + dev_hold(state->in); + dev_hold(state->out); + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + dev_hold(entry->physin); + dev_hold(entry->physout); +#endif + return true; +} +EXPORT_SYMBOL_GPL(nf_queue_entry_get_refs); + +void nf_queue_nf_hook_drop(struct net *net) +{ + const struct nf_queue_handler *qh; + + rcu_read_lock(); + qh = rcu_dereference(nf_queue_handler); + if (qh) + qh->nf_hook_drop(net); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(nf_queue_nf_hook_drop); + +static void nf_ip_saveroute(const struct sk_buff *skb, + struct nf_queue_entry *entry) +{ + struct ip_rt_info *rt_info = nf_queue_entry_reroute(entry); + + if (entry->state.hook == NF_INET_LOCAL_OUT) { + const struct iphdr *iph = ip_hdr(skb); + + rt_info->tos = iph->tos; + rt_info->daddr = iph->daddr; + rt_info->saddr = iph->saddr; + rt_info->mark = skb->mark; + } +} + +static void nf_ip6_saveroute(const struct sk_buff *skb, + struct nf_queue_entry *entry) +{ + struct ip6_rt_info *rt_info = nf_queue_entry_reroute(entry); + + if (entry->state.hook == NF_INET_LOCAL_OUT) { + const struct ipv6hdr *iph = 
ipv6_hdr(skb); + + rt_info->daddr = iph->daddr; + rt_info->saddr = iph->saddr; + rt_info->mark = skb->mark; + } +} + +static int __nf_queue(struct sk_buff *skb, const struct nf_hook_state *state, + unsigned int index, unsigned int queuenum) +{ + struct nf_queue_entry *entry = NULL; + const struct nf_queue_handler *qh; + unsigned int route_key_size; + int status; + + /* QUEUE == DROP if no one is waiting, to be safe. */ + qh = rcu_dereference(nf_queue_handler); + if (!qh) + return -ESRCH; + + switch (state->pf) { + case AF_INET: + route_key_size = sizeof(struct ip_rt_info); + break; + case AF_INET6: + route_key_size = sizeof(struct ip6_rt_info); + break; + default: + route_key_size = 0; + break; + } + + if (skb_sk_is_prefetched(skb)) { + struct sock *sk = skb->sk; + + if (!sk_is_refcounted(sk)) { + if (!refcount_inc_not_zero(&sk->sk_refcnt)) + return -ENOTCONN; + + /* drop refcount on skb_orphan */ + skb->destructor = sock_edemux; + } + } + + entry = kmalloc(sizeof(*entry) + route_key_size, GFP_ATOMIC); + if (!entry) + return -ENOMEM; + + if (skb_dst(skb) && !skb_dst_force(skb)) { + kfree(entry); + return -ENETDOWN; + } + + *entry = (struct nf_queue_entry) { + .skb = skb, + .state = *state, + .hook_index = index, + .size = sizeof(*entry) + route_key_size, + }; + + __nf_queue_entry_init_physdevs(entry); + + if (!nf_queue_entry_get_refs(entry)) { + kfree(entry); + return -ENOTCONN; + } + + switch (entry->state.pf) { + case AF_INET: + nf_ip_saveroute(skb, entry); + break; + case AF_INET6: + nf_ip6_saveroute(skb, entry); + break; + } + + status = qh->outfn(entry, queuenum); + if (status < 0) { + nf_queue_entry_free(entry); + return status; + } + + return 0; +} + +/* Packets leaving via this function must come back through nf_reinject(). */ +int nf_queue(struct sk_buff *skb, struct nf_hook_state *state, + unsigned int index, unsigned int verdict) +{ + int ret; + + ret = __nf_queue(skb, state, index, verdict >> NF_VERDICT_QBITS); + if (ret < 0) { + if (ret == -ESRCH && + (verdict & NF_VERDICT_FLAG_QUEUE_BYPASS)) + return 1; + kfree_skb(skb); + } + + return 0; +} +EXPORT_SYMBOL_GPL(nf_queue); + +static unsigned int nf_iterate(struct sk_buff *skb, + struct nf_hook_state *state, + const struct nf_hook_entries *hooks, + unsigned int *index) +{ + const struct nf_hook_entry *hook; + unsigned int verdict, i = *index; + + while (i < hooks->num_hook_entries) { + hook = &hooks->hooks[i]; +repeat: + verdict = nf_hook_entry_hookfn(hook, skb, state); + if (verdict != NF_ACCEPT) { + *index = i; + if (verdict != NF_REPEAT) + return verdict; + goto repeat; + } + i++; + } + + *index = i; + return NF_ACCEPT; +} + +static struct nf_hook_entries *nf_hook_entries_head(const struct net *net, u8 pf, u8 hooknum) +{ + switch (pf) { +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + case NFPROTO_BRIDGE: + return rcu_dereference(net->nf.hooks_bridge[hooknum]); +#endif + case NFPROTO_IPV4: + return rcu_dereference(net->nf.hooks_ipv4[hooknum]); + case NFPROTO_IPV6: + return rcu_dereference(net->nf.hooks_ipv6[hooknum]); + default: + WARN_ON_ONCE(1); + return NULL; + } + + return NULL; +} + +/* Caller must hold rcu read-side lock */ +void nf_reinject(struct nf_queue_entry *entry, unsigned int verdict) +{ + const struct nf_hook_entry *hook_entry; + const struct nf_hook_entries *hooks; + struct sk_buff *skb = entry->skb; + const struct net *net; + unsigned int i; + int err; + u8 pf; + + net = entry->state.net; + pf = entry->state.pf; + + hooks = nf_hook_entries_head(net, pf, entry->state.hook); + + i = entry->hook_index; + if 
(WARN_ON_ONCE(!hooks || i >= hooks->num_hook_entries)) { + kfree_skb(skb); + nf_queue_entry_free(entry); + return; + } + + hook_entry = &hooks->hooks[i]; + + /* Continue traversal iff userspace said ok... */ + if (verdict == NF_REPEAT) + verdict = nf_hook_entry_hookfn(hook_entry, skb, &entry->state); + + if (verdict == NF_ACCEPT) { + if (nf_reroute(skb, entry) < 0) + verdict = NF_DROP; + } + + if (verdict == NF_ACCEPT) { +next_hook: + ++i; + verdict = nf_iterate(skb, &entry->state, hooks, &i); + } + + switch (verdict & NF_VERDICT_MASK) { + case NF_ACCEPT: + case NF_STOP: + local_bh_disable(); + entry->state.okfn(entry->state.net, entry->state.sk, skb); + local_bh_enable(); + break; + case NF_QUEUE: + err = nf_queue(skb, &entry->state, i, verdict); + if (err == 1) + goto next_hook; + break; + case NF_STOLEN: + break; + default: + kfree_skb(skb); + } + + nf_queue_entry_free(entry); +} +EXPORT_SYMBOL(nf_reinject); diff --git a/net/netfilter/nf_sockopt.c b/net/netfilter/nf_sockopt.c new file mode 100644 index 000000000..34afcd03b --- /dev/null +++ b/net/netfilter/nf_sockopt.c @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include + +#include "nf_internals.h" + +/* Sockopts only registered and called from user context, so + net locking would be overkill. Also, [gs]etsockopt calls may + sleep. */ +static DEFINE_MUTEX(nf_sockopt_mutex); +static LIST_HEAD(nf_sockopts); + +/* Do exclusive ranges overlap? */ +static inline int overlap(int min1, int max1, int min2, int max2) +{ + return max1 > min2 && min1 < max2; +} + +/* Functions to register sockopt ranges (exclusive). */ +int nf_register_sockopt(struct nf_sockopt_ops *reg) +{ + struct nf_sockopt_ops *ops; + int ret = 0; + + mutex_lock(&nf_sockopt_mutex); + list_for_each_entry(ops, &nf_sockopts, list) { + if (ops->pf == reg->pf + && (overlap(ops->set_optmin, ops->set_optmax, + reg->set_optmin, reg->set_optmax) + || overlap(ops->get_optmin, ops->get_optmax, + reg->get_optmin, reg->get_optmax))) { + pr_debug("nf_sock overlap: %u-%u/%u-%u v %u-%u/%u-%u\n", + ops->set_optmin, ops->set_optmax, + ops->get_optmin, ops->get_optmax, + reg->set_optmin, reg->set_optmax, + reg->get_optmin, reg->get_optmax); + ret = -EBUSY; + goto out; + } + } + + list_add(®->list, &nf_sockopts); +out: + mutex_unlock(&nf_sockopt_mutex); + return ret; +} +EXPORT_SYMBOL(nf_register_sockopt); + +void nf_unregister_sockopt(struct nf_sockopt_ops *reg) +{ + mutex_lock(&nf_sockopt_mutex); + list_del(®->list); + mutex_unlock(&nf_sockopt_mutex); +} +EXPORT_SYMBOL(nf_unregister_sockopt); + +static struct nf_sockopt_ops *nf_sockopt_find(struct sock *sk, u_int8_t pf, + int val, int get) +{ + struct nf_sockopt_ops *ops; + + mutex_lock(&nf_sockopt_mutex); + list_for_each_entry(ops, &nf_sockopts, list) { + if (ops->pf == pf) { + if (!try_module_get(ops->owner)) + goto out_nosup; + + if (get) { + if (val >= ops->get_optmin && + val < ops->get_optmax) + goto out; + } else { + if (val >= ops->set_optmin && + val < ops->set_optmax) + goto out; + } + module_put(ops->owner); + } + } +out_nosup: + ops = ERR_PTR(-ENOPROTOOPT); +out: + mutex_unlock(&nf_sockopt_mutex); + return ops; +} + +int nf_setsockopt(struct sock *sk, u_int8_t pf, int val, sockptr_t opt, + unsigned int len) +{ + struct nf_sockopt_ops *ops; + int ret; + + ops = nf_sockopt_find(sk, pf, val, 0); + if (IS_ERR(ops)) + return PTR_ERR(ops); + ret = ops->set(sk, val, opt, len); + module_put(ops->owner); + return ret; +} +EXPORT_SYMBOL(nf_setsockopt); + +int 
nf_getsockopt(struct sock *sk, u_int8_t pf, int val, char __user *opt, + int *len) +{ + struct nf_sockopt_ops *ops; + int ret; + + ops = nf_sockopt_find(sk, pf, val, 1); + if (IS_ERR(ops)) + return PTR_ERR(ops); + ret = ops->get(sk, val, opt, len); + module_put(ops->owner); + return ret; +} +EXPORT_SYMBOL(nf_getsockopt); diff --git a/net/netfilter/nf_synproxy_core.c b/net/netfilter/nf_synproxy_core.c new file mode 100644 index 000000000..16915f8ee --- /dev/null +++ b/net/netfilter/nf_synproxy_core.c @@ -0,0 +1,1220 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2013 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +unsigned int synproxy_net_id; +EXPORT_SYMBOL_GPL(synproxy_net_id); + +bool +synproxy_parse_options(const struct sk_buff *skb, unsigned int doff, + const struct tcphdr *th, struct synproxy_options *opts) +{ + int length = (th->doff * 4) - sizeof(*th); + u8 buf[40], *ptr; + + if (unlikely(length < 0)) + return false; + + ptr = skb_header_pointer(skb, doff + sizeof(*th), length, buf); + if (ptr == NULL) + return false; + + opts->options = 0; + while (length > 0) { + int opcode = *ptr++; + int opsize; + + switch (opcode) { + case TCPOPT_EOL: + return true; + case TCPOPT_NOP: + length--; + continue; + default: + if (length < 2) + return true; + opsize = *ptr++; + if (opsize < 2) + return true; + if (opsize > length) + return true; + + switch (opcode) { + case TCPOPT_MSS: + if (opsize == TCPOLEN_MSS) { + opts->mss_option = get_unaligned_be16(ptr); + opts->options |= NF_SYNPROXY_OPT_MSS; + } + break; + case TCPOPT_WINDOW: + if (opsize == TCPOLEN_WINDOW) { + opts->wscale = *ptr; + if (opts->wscale > TCP_MAX_WSCALE) + opts->wscale = TCP_MAX_WSCALE; + opts->options |= NF_SYNPROXY_OPT_WSCALE; + } + break; + case TCPOPT_TIMESTAMP: + if (opsize == TCPOLEN_TIMESTAMP) { + opts->tsval = get_unaligned_be32(ptr); + opts->tsecr = get_unaligned_be32(ptr + 4); + opts->options |= NF_SYNPROXY_OPT_TIMESTAMP; + } + break; + case TCPOPT_SACK_PERM: + if (opsize == TCPOLEN_SACK_PERM) + opts->options |= NF_SYNPROXY_OPT_SACK_PERM; + break; + } + + ptr += opsize - 2; + length -= opsize; + } + } + return true; +} +EXPORT_SYMBOL_GPL(synproxy_parse_options); + +static unsigned int +synproxy_options_size(const struct synproxy_options *opts) +{ + unsigned int size = 0; + + if (opts->options & NF_SYNPROXY_OPT_MSS) + size += TCPOLEN_MSS_ALIGNED; + if (opts->options & NF_SYNPROXY_OPT_TIMESTAMP) + size += TCPOLEN_TSTAMP_ALIGNED; + else if (opts->options & NF_SYNPROXY_OPT_SACK_PERM) + size += TCPOLEN_SACKPERM_ALIGNED; + if (opts->options & NF_SYNPROXY_OPT_WSCALE) + size += TCPOLEN_WSCALE_ALIGNED; + + return size; +} + +static void +synproxy_build_options(struct tcphdr *th, const struct synproxy_options *opts) +{ + __be32 *ptr = (__be32 *)(th + 1); + u8 options = opts->options; + + if (options & NF_SYNPROXY_OPT_MSS) + *ptr++ = htonl((TCPOPT_MSS << 24) | + (TCPOLEN_MSS << 16) | + opts->mss_option); + + if (options & NF_SYNPROXY_OPT_TIMESTAMP) { + if (options & NF_SYNPROXY_OPT_SACK_PERM) + *ptr++ = htonl((TCPOPT_SACK_PERM << 24) | + (TCPOLEN_SACK_PERM << 16) | + (TCPOPT_TIMESTAMP << 8) | + TCPOLEN_TIMESTAMP); + else + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_NOP << 16) | + (TCPOPT_TIMESTAMP << 8) | + TCPOLEN_TIMESTAMP); + + *ptr++ = htonl(opts->tsval); + *ptr++ = htonl(opts->tsecr); + } else if (options & NF_SYNPROXY_OPT_SACK_PERM) + *ptr++ = htonl((TCPOPT_NOP << 24) | 
+ (TCPOPT_NOP << 16) | + (TCPOPT_SACK_PERM << 8) | + TCPOLEN_SACK_PERM); + + if (options & NF_SYNPROXY_OPT_WSCALE) + *ptr++ = htonl((TCPOPT_NOP << 24) | + (TCPOPT_WINDOW << 16) | + (TCPOLEN_WINDOW << 8) | + opts->wscale); +} + +void synproxy_init_timestamp_cookie(const struct nf_synproxy_info *info, + struct synproxy_options *opts) +{ + opts->tsecr = opts->tsval; + opts->tsval = tcp_time_stamp_raw() & ~0x3f; + + if (opts->options & NF_SYNPROXY_OPT_WSCALE) { + opts->tsval |= opts->wscale; + opts->wscale = info->wscale; + } else + opts->tsval |= 0xf; + + if (opts->options & NF_SYNPROXY_OPT_SACK_PERM) + opts->tsval |= 1 << 4; + + if (opts->options & NF_SYNPROXY_OPT_ECN) + opts->tsval |= 1 << 5; +} +EXPORT_SYMBOL_GPL(synproxy_init_timestamp_cookie); + +static void +synproxy_check_timestamp_cookie(struct synproxy_options *opts) +{ + opts->wscale = opts->tsecr & 0xf; + if (opts->wscale != 0xf) + opts->options |= NF_SYNPROXY_OPT_WSCALE; + + opts->options |= opts->tsecr & (1 << 4) ? NF_SYNPROXY_OPT_SACK_PERM : 0; + + opts->options |= opts->tsecr & (1 << 5) ? NF_SYNPROXY_OPT_ECN : 0; +} + +static unsigned int +synproxy_tstamp_adjust(struct sk_buff *skb, unsigned int protoff, + struct tcphdr *th, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + const struct nf_conn_synproxy *synproxy) +{ + unsigned int optoff, optend; + __be32 *ptr, old; + + if (synproxy->tsoff == 0) + return 1; + + optoff = protoff + sizeof(struct tcphdr); + optend = protoff + th->doff * 4; + + if (skb_ensure_writable(skb, optend)) + return 0; + + while (optoff < optend) { + unsigned char *op = skb->data + optoff; + + switch (op[0]) { + case TCPOPT_EOL: + return 1; + case TCPOPT_NOP: + optoff++; + continue; + default: + if (optoff + 1 == optend || + optoff + op[1] > optend || + op[1] < 2) + return 0; + if (op[0] == TCPOPT_TIMESTAMP && + op[1] == TCPOLEN_TIMESTAMP) { + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) { + ptr = (__be32 *)&op[2]; + old = *ptr; + *ptr = htonl(ntohl(*ptr) - + synproxy->tsoff); + } else { + ptr = (__be32 *)&op[6]; + old = *ptr; + *ptr = htonl(ntohl(*ptr) + + synproxy->tsoff); + } + inet_proto_csum_replace4(&th->check, skb, + old, *ptr, false); + return 1; + } + optoff += op[1]; + } + } + return 1; +} + +#ifdef CONFIG_PROC_FS +static void *synproxy_cpu_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct synproxy_net *snet = synproxy_pernet(seq_file_net(seq)); + int cpu; + + if (*pos == 0) + return SEQ_START_TOKEN; + + for (cpu = *pos - 1; cpu < nr_cpu_ids; cpu++) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(snet->stats, cpu); + } + + return NULL; +} + +static void *synproxy_cpu_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct synproxy_net *snet = synproxy_pernet(seq_file_net(seq)); + int cpu; + + for (cpu = *pos; cpu < nr_cpu_ids; cpu++) { + if (!cpu_possible(cpu)) + continue; + *pos = cpu + 1; + return per_cpu_ptr(snet->stats, cpu); + } + (*pos)++; + return NULL; +} + +static void synproxy_cpu_seq_stop(struct seq_file *seq, void *v) +{ + return; +} + +static int synproxy_cpu_seq_show(struct seq_file *seq, void *v) +{ + struct synproxy_stats *stats = v; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, "entries\t\tsyn_received\t" + "cookie_invalid\tcookie_valid\t" + "cookie_retrans\tconn_reopened\n"); + return 0; + } + + seq_printf(seq, "%08x\t%08x\t%08x\t%08x\t%08x\t%08x\n", 0, + stats->syn_received, + stats->cookie_invalid, + stats->cookie_valid, + stats->cookie_retrans, + stats->conn_reopened); + + return 0; +} + +static const struct seq_operations 
synproxy_cpu_seq_ops = { + .start = synproxy_cpu_seq_start, + .next = synproxy_cpu_seq_next, + .stop = synproxy_cpu_seq_stop, + .show = synproxy_cpu_seq_show, +}; + +static int __net_init synproxy_proc_init(struct net *net) +{ + if (!proc_create_net("synproxy", 0444, net->proc_net_stat, + &synproxy_cpu_seq_ops, sizeof(struct seq_net_private))) + return -ENOMEM; + return 0; +} + +static void __net_exit synproxy_proc_exit(struct net *net) +{ + remove_proc_entry("synproxy", net->proc_net_stat); +} +#else +static int __net_init synproxy_proc_init(struct net *net) +{ + return 0; +} + +static void __net_exit synproxy_proc_exit(struct net *net) +{ + return; +} +#endif /* CONFIG_PROC_FS */ + +static int __net_init synproxy_net_init(struct net *net) +{ + struct synproxy_net *snet = synproxy_pernet(net); + struct nf_conn *ct; + int err = -ENOMEM; + + ct = nf_ct_tmpl_alloc(net, &nf_ct_zone_dflt, GFP_KERNEL); + if (!ct) + goto err1; + + if (!nfct_seqadj_ext_add(ct)) + goto err2; + if (!nfct_synproxy_ext_add(ct)) + goto err2; + + __set_bit(IPS_CONFIRMED_BIT, &ct->status); + snet->tmpl = ct; + + snet->stats = alloc_percpu(struct synproxy_stats); + if (snet->stats == NULL) + goto err2; + + err = synproxy_proc_init(net); + if (err < 0) + goto err3; + + return 0; + +err3: + free_percpu(snet->stats); +err2: + nf_ct_tmpl_free(ct); +err1: + return err; +} + +static void __net_exit synproxy_net_exit(struct net *net) +{ + struct synproxy_net *snet = synproxy_pernet(net); + + nf_ct_put(snet->tmpl); + synproxy_proc_exit(net); + free_percpu(snet->stats); +} + +static struct pernet_operations synproxy_net_ops = { + .init = synproxy_net_init, + .exit = synproxy_net_exit, + .id = &synproxy_net_id, + .size = sizeof(struct synproxy_net), +}; + +static int __init synproxy_core_init(void) +{ + return register_pernet_subsys(&synproxy_net_ops); +} + +static void __exit synproxy_core_exit(void) +{ + unregister_pernet_subsys(&synproxy_net_ops); +} + +module_init(synproxy_core_init); +module_exit(synproxy_core_exit); + +static struct iphdr * +synproxy_build_ip(struct net *net, struct sk_buff *skb, __be32 saddr, + __be32 daddr) +{ + struct iphdr *iph; + + skb_reset_network_header(skb); + iph = skb_put(skb, sizeof(*iph)); + iph->version = 4; + iph->ihl = sizeof(*iph) / 4; + iph->tos = 0; + iph->id = 0; + iph->frag_off = htons(IP_DF); + iph->ttl = READ_ONCE(net->ipv4.sysctl_ip_default_ttl); + iph->protocol = IPPROTO_TCP; + iph->check = 0; + iph->saddr = saddr; + iph->daddr = daddr; + + return iph; +} + +static void +synproxy_send_tcp(struct net *net, + const struct sk_buff *skb, struct sk_buff *nskb, + struct nf_conntrack *nfct, enum ip_conntrack_info ctinfo, + struct iphdr *niph, struct tcphdr *nth, + unsigned int tcp_hdr_size) +{ + nth->check = ~tcp_v4_check(tcp_hdr_size, niph->saddr, niph->daddr, 0); + nskb->ip_summed = CHECKSUM_PARTIAL; + nskb->csum_start = (unsigned char *)nth - nskb->head; + nskb->csum_offset = offsetof(struct tcphdr, check); + + skb_dst_set_noref(nskb, skb_dst(skb)); + nskb->protocol = htons(ETH_P_IP); + if (ip_route_me_harder(net, nskb->sk, nskb, RTN_UNSPEC)) + goto free_nskb; + + if (nfct) { + nf_ct_set(nskb, (struct nf_conn *)nfct, ctinfo); + nf_conntrack_get(nfct); + } + + ip_local_out(net, nskb->sk, nskb); + return; + +free_nskb: + kfree_skb(nskb); +} + +void +synproxy_send_client_synack(struct net *net, + const struct sk_buff *skb, const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct iphdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; 
+ u16 mss = opts->mss_encode; + + iph = ip_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip(net, nskb, iph->daddr, iph->saddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->dest; + nth->dest = th->source; + nth->seq = htonl(__cookie_v4_init_sequence(iph, th, &mss)); + nth->ack_seq = htonl(ntohl(th->seq) + 1); + tcp_flag_word(nth) = TCP_FLAG_SYN | TCP_FLAG_ACK; + if (opts->options & NF_SYNPROXY_OPT_ECN) + tcp_flag_word(nth) |= TCP_FLAG_ECE; + nth->doff = tcp_hdr_size / 4; + nth->window = 0; + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp(net, skb, nskb, skb_nfct(skb), + IP_CT_ESTABLISHED_REPLY, niph, nth, tcp_hdr_size); +} +EXPORT_SYMBOL_GPL(synproxy_send_client_synack); + +static void +synproxy_send_server_syn(struct net *net, + const struct sk_buff *skb, const struct tcphdr *th, + const struct synproxy_options *opts, u32 recv_seq) +{ + struct synproxy_net *snet = synproxy_pernet(net); + struct sk_buff *nskb; + struct iphdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ip_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip(net, nskb, iph->saddr, iph->daddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->source; + nth->dest = th->dest; + nth->seq = htonl(recv_seq - 1); + /* ack_seq is used to relay our ISN to the synproxy hook to initialize + * sequence number translation once a connection tracking entry exists. 
+ */ + nth->ack_seq = htonl(ntohl(th->ack_seq) - 1); + tcp_flag_word(nth) = TCP_FLAG_SYN; + if (opts->options & NF_SYNPROXY_OPT_ECN) + tcp_flag_word(nth) |= TCP_FLAG_ECE | TCP_FLAG_CWR; + nth->doff = tcp_hdr_size / 4; + nth->window = th->window; + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp(net, skb, nskb, &snet->tmpl->ct_general, IP_CT_NEW, + niph, nth, tcp_hdr_size); +} + +static void +synproxy_send_server_ack(struct net *net, + const struct ip_ct_tcp *state, + const struct sk_buff *skb, const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct iphdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ip_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip(net, nskb, iph->daddr, iph->saddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->dest; + nth->dest = th->source; + nth->seq = htonl(ntohl(th->ack_seq)); + nth->ack_seq = htonl(ntohl(th->seq) + 1); + tcp_flag_word(nth) = TCP_FLAG_ACK; + nth->doff = tcp_hdr_size / 4; + nth->window = htons(state->seen[IP_CT_DIR_ORIGINAL].td_maxwin); + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp(net, skb, nskb, NULL, 0, niph, nth, tcp_hdr_size); +} + +static void +synproxy_send_client_ack(struct net *net, + const struct sk_buff *skb, const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct iphdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ip_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip(net, nskb, iph->saddr, iph->daddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->source; + nth->dest = th->dest; + nth->seq = htonl(ntohl(th->seq) + 1); + nth->ack_seq = th->ack_seq; + tcp_flag_word(nth) = TCP_FLAG_ACK; + nth->doff = tcp_hdr_size / 4; + nth->window = htons(ntohs(th->window) >> opts->wscale); + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp(net, skb, nskb, skb_nfct(skb), + IP_CT_ESTABLISHED_REPLY, niph, nth, tcp_hdr_size); +} + +bool +synproxy_recv_client_ack(struct net *net, + const struct sk_buff *skb, const struct tcphdr *th, + struct synproxy_options *opts, u32 recv_seq) +{ + struct synproxy_net *snet = synproxy_pernet(net); + int mss; + + mss = __cookie_v4_check(ip_hdr(skb), th, ntohl(th->ack_seq) - 1); + if (mss == 0) { + this_cpu_inc(snet->stats->cookie_invalid); + return false; + } + + this_cpu_inc(snet->stats->cookie_valid); + opts->mss_option = mss; + opts->options |= NF_SYNPROXY_OPT_MSS; + + if (opts->options & NF_SYNPROXY_OPT_TIMESTAMP) + synproxy_check_timestamp_cookie(opts); + + synproxy_send_server_syn(net, skb, th, opts, recv_seq); + return true; +} +EXPORT_SYMBOL_GPL(synproxy_recv_client_ack); + +unsigned int +ipv4_synproxy_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *nhs) +{ + struct net *net = nhs->net; + struct synproxy_net *snet = synproxy_pernet(net); + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + struct nf_conn_synproxy *synproxy; + 
struct synproxy_options opts = {}; + const struct ip_ct_tcp *state; + struct tcphdr *th, _th; + unsigned int thoff; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return NF_ACCEPT; + + synproxy = nfct_synproxy(ct); + if (!synproxy) + return NF_ACCEPT; + + if (nf_is_loopback_packet(skb) || + ip_hdr(skb)->protocol != IPPROTO_TCP) + return NF_ACCEPT; + + thoff = ip_hdrlen(skb); + th = skb_header_pointer(skb, thoff, sizeof(_th), &_th); + if (!th) + return NF_DROP; + + state = &ct->proto.tcp; + switch (state->state) { + case TCP_CONNTRACK_CLOSE: + if (th->rst && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { + nf_ct_seqadj_init(ct, ctinfo, synproxy->isn - + ntohl(th->seq) + 1); + break; + } + + if (!th->syn || th->ack || + CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) + break; + + /* Reopened connection - reset the sequence number and timestamp + * adjustments, they will get initialized once the connection is + * reestablished. + */ + nf_ct_seqadj_init(ct, ctinfo, 0); + synproxy->tsoff = 0; + this_cpu_inc(snet->stats->conn_reopened); + fallthrough; + case TCP_CONNTRACK_SYN_SENT: + if (!synproxy_parse_options(skb, thoff, th, &opts)) + return NF_DROP; + + if (!th->syn && th->ack && + CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { + /* Keep-Alives are sent with SEG.SEQ = SND.NXT-1, + * therefore we need to add 1 to make the SYN sequence + * number match the one of first SYN. + */ + if (synproxy_recv_client_ack(net, skb, th, &opts, + ntohl(th->seq) + 1)) { + this_cpu_inc(snet->stats->cookie_retrans); + consume_skb(skb); + return NF_STOLEN; + } else { + return NF_DROP; + } + } + + synproxy->isn = ntohl(th->ack_seq); + if (opts.options & NF_SYNPROXY_OPT_TIMESTAMP) + synproxy->its = opts.tsecr; + + nf_conntrack_event_cache(IPCT_SYNPROXY, ct); + break; + case TCP_CONNTRACK_SYN_RECV: + if (!th->syn || !th->ack) + break; + + if (!synproxy_parse_options(skb, thoff, th, &opts)) + return NF_DROP; + + if (opts.options & NF_SYNPROXY_OPT_TIMESTAMP) { + synproxy->tsoff = opts.tsval - synproxy->its; + nf_conntrack_event_cache(IPCT_SYNPROXY, ct); + } + + opts.options &= ~(NF_SYNPROXY_OPT_MSS | + NF_SYNPROXY_OPT_WSCALE | + NF_SYNPROXY_OPT_SACK_PERM); + + swap(opts.tsval, opts.tsecr); + synproxy_send_server_ack(net, state, skb, th, &opts); + + nf_ct_seqadj_init(ct, ctinfo, synproxy->isn - ntohl(th->seq)); + nf_conntrack_event_cache(IPCT_SEQADJ, ct); + + swap(opts.tsval, opts.tsecr); + synproxy_send_client_ack(net, skb, th, &opts); + + consume_skb(skb); + return NF_STOLEN; + default: + break; + } + + synproxy_tstamp_adjust(skb, thoff, th, ct, ctinfo, synproxy); + return NF_ACCEPT; +} +EXPORT_SYMBOL_GPL(ipv4_synproxy_hook); + +static const struct nf_hook_ops ipv4_synproxy_ops[] = { + { + .hook = ipv4_synproxy_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM - 1, + }, + { + .hook = ipv4_synproxy_hook, + .pf = NFPROTO_IPV4, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM - 1, + }, +}; + +int nf_synproxy_ipv4_init(struct synproxy_net *snet, struct net *net) +{ + int err; + + if (snet->hook_ref4 == 0) { + err = nf_register_net_hooks(net, ipv4_synproxy_ops, + ARRAY_SIZE(ipv4_synproxy_ops)); + if (err) + return err; + } + + snet->hook_ref4++; + return 0; +} +EXPORT_SYMBOL_GPL(nf_synproxy_ipv4_init); + +void nf_synproxy_ipv4_fini(struct synproxy_net *snet, struct net *net) +{ + snet->hook_ref4--; + if (snet->hook_ref4 == 0) + nf_unregister_net_hooks(net, ipv4_synproxy_ops, + ARRAY_SIZE(ipv4_synproxy_ops)); +} +EXPORT_SYMBOL_GPL(nf_synproxy_ipv4_fini); + 
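The IPv4 hooks above and the IPv6 hooks below share the timestamp-cookie scheme implemented by synproxy_init_timestamp_cookie() and synproxy_check_timestamp_cookie(): the low six bits of the TCP timestamp value are overwritten so that bits 0-3 carry the peer's window scale (0xf meaning "no window scaling"), bit 4 carries SACK-permitted and bit 5 carries ECN, and the options are recovered later from the timestamp the client echoes back. The stand-alone sketch below is an illustrative user-space re-implementation of that bit layout, not part of this patch; the demo_* and DEMO_OPT_* names are made up for the example.

/*
 * Illustrative sketch only -- not part of the kernel patch above.
 * It mirrors the bit layout used by synproxy_init_timestamp_cookie()
 * and synproxy_check_timestamp_cookie() so the encoding can be checked
 * in user space.  Build with: cc -Wall demo_ts_cookie.c
 */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define DEMO_OPT_WSCALE		(1 << 0)
#define DEMO_OPT_SACK_PERM	(1 << 1)
#define DEMO_OPT_ECN		(1 << 2)

/* Fold the negotiated options into the low six bits of the timestamp,
 * sacrificing six bits of clock resolution (the kernel masks with ~0x3f). */
static uint32_t demo_encode_tsval(uint32_t now, unsigned int options,
				  uint8_t wscale)
{
	uint32_t tsval = now & ~0x3fU;	/* upper bits keep the clock */

	if (options & DEMO_OPT_WSCALE)
		tsval |= wscale & 0xfU;	/* bits 0-3: window scale */
	else
		tsval |= 0xfU;		/* 0xf means "no window scaling" */

	if (options & DEMO_OPT_SACK_PERM)
		tsval |= 1U << 4;	/* bit 4: SACK permitted */
	if (options & DEMO_OPT_ECN)
		tsval |= 1U << 5;	/* bit 5: ECN */

	return tsval;
}

/* Recover the options from the timestamp echoed back by the client
 * (the kernel reads them from opts->tsecr). */
static unsigned int demo_decode_tsecr(uint32_t tsecr, uint8_t *wscale)
{
	unsigned int options = 0;

	*wscale = tsecr & 0xfU;
	if (*wscale != 0xfU)
		options |= DEMO_OPT_WSCALE;
	if (tsecr & (1U << 4))
		options |= DEMO_OPT_SACK_PERM;
	if (tsecr & (1U << 5))
		options |= DEMO_OPT_ECN;

	return options;
}

int main(void)
{
	uint8_t wscale;
	uint32_t tsval = demo_encode_tsval(0x12345678,
					   DEMO_OPT_WSCALE | DEMO_OPT_ECN, 7);
	unsigned int opts = demo_decode_tsecr(tsval, &wscale);

	/* The options survive the round trip through the timestamp. */
	assert(opts == (DEMO_OPT_WSCALE | DEMO_OPT_ECN));
	assert(wscale == 7);
	printf("tsval=0x%08x wscale=%u\n", (unsigned int)tsval,
	       (unsigned int)wscale);
	return 0;
}

Under these assumptions the round trip shows why synproxy_recv_client_ack() only calls synproxy_check_timestamp_cookie() when the timestamp option was negotiated: without it, the window scale, SACK and ECN information has nowhere to travel and the proxied connection falls back to the unscaled defaults.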
+#if IS_ENABLED(CONFIG_IPV6) +static struct ipv6hdr * +synproxy_build_ip_ipv6(struct net *net, struct sk_buff *skb, + const struct in6_addr *saddr, + const struct in6_addr *daddr) +{ + struct ipv6hdr *iph; + + skb_reset_network_header(skb); + iph = skb_put(skb, sizeof(*iph)); + ip6_flow_hdr(iph, 0, 0); + iph->hop_limit = net->ipv6.devconf_all->hop_limit; + iph->nexthdr = IPPROTO_TCP; + iph->saddr = *saddr; + iph->daddr = *daddr; + + return iph; +} + +static void +synproxy_send_tcp_ipv6(struct net *net, + const struct sk_buff *skb, struct sk_buff *nskb, + struct nf_conntrack *nfct, enum ip_conntrack_info ctinfo, + struct ipv6hdr *niph, struct tcphdr *nth, + unsigned int tcp_hdr_size) +{ + struct dst_entry *dst; + struct flowi6 fl6; + int err; + + nth->check = ~tcp_v6_check(tcp_hdr_size, &niph->saddr, &niph->daddr, 0); + nskb->ip_summed = CHECKSUM_PARTIAL; + nskb->csum_start = (unsigned char *)nth - nskb->head; + nskb->csum_offset = offsetof(struct tcphdr, check); + + memset(&fl6, 0, sizeof(fl6)); + fl6.flowi6_proto = IPPROTO_TCP; + fl6.saddr = niph->saddr; + fl6.daddr = niph->daddr; + fl6.fl6_sport = nth->source; + fl6.fl6_dport = nth->dest; + security_skb_classify_flow((struct sk_buff *)skb, + flowi6_to_flowi_common(&fl6)); + err = nf_ip6_route(net, &dst, flowi6_to_flowi(&fl6), false); + if (err) { + goto free_nskb; + } + + dst = xfrm_lookup(net, dst, flowi6_to_flowi(&fl6), NULL, 0); + if (IS_ERR(dst)) + goto free_nskb; + + skb_dst_set(nskb, dst); + + if (nfct) { + nf_ct_set(nskb, (struct nf_conn *)nfct, ctinfo); + nf_conntrack_get(nfct); + } + + ip6_local_out(net, nskb->sk, nskb); + return; + +free_nskb: + kfree_skb(nskb); +} + +void +synproxy_send_client_synack_ipv6(struct net *net, + const struct sk_buff *skb, + const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct ipv6hdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + u16 mss = opts->mss_encode; + + iph = ipv6_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip_ipv6(net, nskb, &iph->daddr, &iph->saddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->dest; + nth->dest = th->source; + nth->seq = htonl(nf_ipv6_cookie_init_sequence(iph, th, &mss)); + nth->ack_seq = htonl(ntohl(th->seq) + 1); + tcp_flag_word(nth) = TCP_FLAG_SYN | TCP_FLAG_ACK; + if (opts->options & NF_SYNPROXY_OPT_ECN) + tcp_flag_word(nth) |= TCP_FLAG_ECE; + nth->doff = tcp_hdr_size / 4; + nth->window = 0; + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp_ipv6(net, skb, nskb, skb_nfct(skb), + IP_CT_ESTABLISHED_REPLY, niph, nth, + tcp_hdr_size); +} +EXPORT_SYMBOL_GPL(synproxy_send_client_synack_ipv6); + +static void +synproxy_send_server_syn_ipv6(struct net *net, const struct sk_buff *skb, + const struct tcphdr *th, + const struct synproxy_options *opts, u32 recv_seq) +{ + struct synproxy_net *snet = synproxy_pernet(net); + struct sk_buff *nskb; + struct ipv6hdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ipv6_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip_ipv6(net, nskb, &iph->saddr, &iph->daddr); + + 
skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->source; + nth->dest = th->dest; + nth->seq = htonl(recv_seq - 1); + /* ack_seq is used to relay our ISN to the synproxy hook to initialize + * sequence number translation once a connection tracking entry exists. + */ + nth->ack_seq = htonl(ntohl(th->ack_seq) - 1); + tcp_flag_word(nth) = TCP_FLAG_SYN; + if (opts->options & NF_SYNPROXY_OPT_ECN) + tcp_flag_word(nth) |= TCP_FLAG_ECE | TCP_FLAG_CWR; + nth->doff = tcp_hdr_size / 4; + nth->window = th->window; + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp_ipv6(net, skb, nskb, &snet->tmpl->ct_general, + IP_CT_NEW, niph, nth, tcp_hdr_size); +} + +static void +synproxy_send_server_ack_ipv6(struct net *net, const struct ip_ct_tcp *state, + const struct sk_buff *skb, + const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct ipv6hdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ipv6_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip_ipv6(net, nskb, &iph->daddr, &iph->saddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->dest; + nth->dest = th->source; + nth->seq = htonl(ntohl(th->ack_seq)); + nth->ack_seq = htonl(ntohl(th->seq) + 1); + tcp_flag_word(nth) = TCP_FLAG_ACK; + nth->doff = tcp_hdr_size / 4; + nth->window = htons(state->seen[IP_CT_DIR_ORIGINAL].td_maxwin); + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp_ipv6(net, skb, nskb, NULL, 0, niph, nth, + tcp_hdr_size); +} + +static void +synproxy_send_client_ack_ipv6(struct net *net, const struct sk_buff *skb, + const struct tcphdr *th, + const struct synproxy_options *opts) +{ + struct sk_buff *nskb; + struct ipv6hdr *iph, *niph; + struct tcphdr *nth; + unsigned int tcp_hdr_size; + + iph = ipv6_hdr(skb); + + tcp_hdr_size = sizeof(*nth) + synproxy_options_size(opts); + nskb = alloc_skb(sizeof(*niph) + tcp_hdr_size + MAX_TCP_HEADER, + GFP_ATOMIC); + if (!nskb) + return; + skb_reserve(nskb, MAX_TCP_HEADER); + + niph = synproxy_build_ip_ipv6(net, nskb, &iph->saddr, &iph->daddr); + + skb_reset_transport_header(nskb); + nth = skb_put(nskb, tcp_hdr_size); + nth->source = th->source; + nth->dest = th->dest; + nth->seq = htonl(ntohl(th->seq) + 1); + nth->ack_seq = th->ack_seq; + tcp_flag_word(nth) = TCP_FLAG_ACK; + nth->doff = tcp_hdr_size / 4; + nth->window = htons(ntohs(th->window) >> opts->wscale); + nth->check = 0; + nth->urg_ptr = 0; + + synproxy_build_options(nth, opts); + + synproxy_send_tcp_ipv6(net, skb, nskb, skb_nfct(skb), + IP_CT_ESTABLISHED_REPLY, niph, nth, + tcp_hdr_size); +} + +bool +synproxy_recv_client_ack_ipv6(struct net *net, + const struct sk_buff *skb, + const struct tcphdr *th, + struct synproxy_options *opts, u32 recv_seq) +{ + struct synproxy_net *snet = synproxy_pernet(net); + int mss; + + mss = nf_cookie_v6_check(ipv6_hdr(skb), th, ntohl(th->ack_seq) - 1); + if (mss == 0) { + this_cpu_inc(snet->stats->cookie_invalid); + return false; + } + + this_cpu_inc(snet->stats->cookie_valid); + opts->mss_option = mss; + opts->options |= NF_SYNPROXY_OPT_MSS; + + if (opts->options & NF_SYNPROXY_OPT_TIMESTAMP) + synproxy_check_timestamp_cookie(opts); + + synproxy_send_server_syn_ipv6(net, skb, th, 
opts, recv_seq); + return true; +} +EXPORT_SYMBOL_GPL(synproxy_recv_client_ack_ipv6); + +unsigned int +ipv6_synproxy_hook(void *priv, struct sk_buff *skb, + const struct nf_hook_state *nhs) +{ + struct net *net = nhs->net; + struct synproxy_net *snet = synproxy_pernet(net); + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + struct nf_conn_synproxy *synproxy; + struct synproxy_options opts = {}; + const struct ip_ct_tcp *state; + struct tcphdr *th, _th; + __be16 frag_off; + u8 nexthdr; + int thoff; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return NF_ACCEPT; + + synproxy = nfct_synproxy(ct); + if (!synproxy) + return NF_ACCEPT; + + if (nf_is_loopback_packet(skb)) + return NF_ACCEPT; + + nexthdr = ipv6_hdr(skb)->nexthdr; + thoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, + &frag_off); + if (thoff < 0 || nexthdr != IPPROTO_TCP) + return NF_ACCEPT; + + th = skb_header_pointer(skb, thoff, sizeof(_th), &_th); + if (!th) + return NF_DROP; + + state = &ct->proto.tcp; + switch (state->state) { + case TCP_CONNTRACK_CLOSE: + if (th->rst && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { + nf_ct_seqadj_init(ct, ctinfo, synproxy->isn - + ntohl(th->seq) + 1); + break; + } + + if (!th->syn || th->ack || + CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) + break; + + /* Reopened connection - reset the sequence number and timestamp + * adjustments, they will get initialized once the connection is + * reestablished. + */ + nf_ct_seqadj_init(ct, ctinfo, 0); + synproxy->tsoff = 0; + this_cpu_inc(snet->stats->conn_reopened); + fallthrough; + case TCP_CONNTRACK_SYN_SENT: + if (!synproxy_parse_options(skb, thoff, th, &opts)) + return NF_DROP; + + if (!th->syn && th->ack && + CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { + /* Keep-Alives are sent with SEG.SEQ = SND.NXT-1, + * therefore we need to add 1 to make the SYN sequence + * number match the one of first SYN. 
+ */ + if (synproxy_recv_client_ack_ipv6(net, skb, th, &opts, + ntohl(th->seq) + 1)) { + this_cpu_inc(snet->stats->cookie_retrans); + consume_skb(skb); + return NF_STOLEN; + } else { + return NF_DROP; + } + } + + synproxy->isn = ntohl(th->ack_seq); + if (opts.options & NF_SYNPROXY_OPT_TIMESTAMP) + synproxy->its = opts.tsecr; + + nf_conntrack_event_cache(IPCT_SYNPROXY, ct); + break; + case TCP_CONNTRACK_SYN_RECV: + if (!th->syn || !th->ack) + break; + + if (!synproxy_parse_options(skb, thoff, th, &opts)) + return NF_DROP; + + if (opts.options & NF_SYNPROXY_OPT_TIMESTAMP) { + synproxy->tsoff = opts.tsval - synproxy->its; + nf_conntrack_event_cache(IPCT_SYNPROXY, ct); + } + + opts.options &= ~(NF_SYNPROXY_OPT_MSS | + NF_SYNPROXY_OPT_WSCALE | + NF_SYNPROXY_OPT_SACK_PERM); + + swap(opts.tsval, opts.tsecr); + synproxy_send_server_ack_ipv6(net, state, skb, th, &opts); + + nf_ct_seqadj_init(ct, ctinfo, synproxy->isn - ntohl(th->seq)); + nf_conntrack_event_cache(IPCT_SEQADJ, ct); + + swap(opts.tsval, opts.tsecr); + synproxy_send_client_ack_ipv6(net, skb, th, &opts); + + consume_skb(skb); + return NF_STOLEN; + default: + break; + } + + synproxy_tstamp_adjust(skb, thoff, th, ct, ctinfo, synproxy); + return NF_ACCEPT; +} +EXPORT_SYMBOL_GPL(ipv6_synproxy_hook); + +static const struct nf_hook_ops ipv6_synproxy_ops[] = { + { + .hook = ipv6_synproxy_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_LOCAL_IN, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM - 1, + }, + { + .hook = ipv6_synproxy_hook, + .pf = NFPROTO_IPV6, + .hooknum = NF_INET_POST_ROUTING, + .priority = NF_IP_PRI_CONNTRACK_CONFIRM - 1, + }, +}; + +int +nf_synproxy_ipv6_init(struct synproxy_net *snet, struct net *net) +{ + int err; + + if (snet->hook_ref6 == 0) { + err = nf_register_net_hooks(net, ipv6_synproxy_ops, + ARRAY_SIZE(ipv6_synproxy_ops)); + if (err) + return err; + } + + snet->hook_ref6++; + return 0; +} +EXPORT_SYMBOL_GPL(nf_synproxy_ipv6_init); + +void +nf_synproxy_ipv6_fini(struct synproxy_net *snet, struct net *net) +{ + snet->hook_ref6--; + if (snet->hook_ref6 == 0) + nf_unregister_net_hooks(net, ipv6_synproxy_ops, + ARRAY_SIZE(ipv6_synproxy_ops)); +} +EXPORT_SYMBOL_GPL(nf_synproxy_ipv6_fini); +#endif /* CONFIG_IPV6 */ + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("nftables SYNPROXY expression support"); diff --git a/net/netfilter/nf_tables_api.c b/net/netfilter/nf_tables_api.c new file mode 100644 index 000000000..1edb21382 --- /dev/null +++ b/net/netfilter/nf_tables_api.c @@ -0,0 +1,10973 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NFT_MODULE_AUTOLOAD_LIMIT (MODULE_NAME_LEN - sizeof("nft-expr-255-")) +#define NFT_SET_MAX_ANONLEN 16 + +unsigned int nf_tables_net_id __read_mostly; + +static LIST_HEAD(nf_tables_expressions); +static LIST_HEAD(nf_tables_objects); +static LIST_HEAD(nf_tables_flowtables); +static LIST_HEAD(nf_tables_destroy_list); +static LIST_HEAD(nf_tables_gc_list); +static DEFINE_SPINLOCK(nf_tables_destroy_list_lock); +static DEFINE_SPINLOCK(nf_tables_gc_list_lock); + +enum { + NFT_VALIDATE_SKIP = 0, + NFT_VALIDATE_NEED, + NFT_VALIDATE_DO, +}; + +static struct rhltable nft_objname_ht; + +static u32 nft_chain_hash(const void *data, u32 len, u32 seed); +static u32 
nft_chain_hash_obj(const void *data, u32 len, u32 seed); +static int nft_chain_hash_cmp(struct rhashtable_compare_arg *, const void *); + +static u32 nft_objname_hash(const void *data, u32 len, u32 seed); +static u32 nft_objname_hash_obj(const void *data, u32 len, u32 seed); +static int nft_objname_hash_cmp(struct rhashtable_compare_arg *, const void *); + +static const struct rhashtable_params nft_chain_ht_params = { + .head_offset = offsetof(struct nft_chain, rhlhead), + .key_offset = offsetof(struct nft_chain, name), + .hashfn = nft_chain_hash, + .obj_hashfn = nft_chain_hash_obj, + .obj_cmpfn = nft_chain_hash_cmp, + .automatic_shrinking = true, +}; + +static const struct rhashtable_params nft_objname_ht_params = { + .head_offset = offsetof(struct nft_object, rhlhead), + .key_offset = offsetof(struct nft_object, key), + .hashfn = nft_objname_hash, + .obj_hashfn = nft_objname_hash_obj, + .obj_cmpfn = nft_objname_hash_cmp, + .automatic_shrinking = true, +}; + +struct nft_audit_data { + struct nft_table *table; + int entries; + int op; + struct list_head list; +}; + +static const u8 nft2audit_op[NFT_MSG_MAX] = { // enum nf_tables_msg_types + [NFT_MSG_NEWTABLE] = AUDIT_NFT_OP_TABLE_REGISTER, + [NFT_MSG_GETTABLE] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELTABLE] = AUDIT_NFT_OP_TABLE_UNREGISTER, + [NFT_MSG_NEWCHAIN] = AUDIT_NFT_OP_CHAIN_REGISTER, + [NFT_MSG_GETCHAIN] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELCHAIN] = AUDIT_NFT_OP_CHAIN_UNREGISTER, + [NFT_MSG_NEWRULE] = AUDIT_NFT_OP_RULE_REGISTER, + [NFT_MSG_GETRULE] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELRULE] = AUDIT_NFT_OP_RULE_UNREGISTER, + [NFT_MSG_NEWSET] = AUDIT_NFT_OP_SET_REGISTER, + [NFT_MSG_GETSET] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELSET] = AUDIT_NFT_OP_SET_UNREGISTER, + [NFT_MSG_NEWSETELEM] = AUDIT_NFT_OP_SETELEM_REGISTER, + [NFT_MSG_GETSETELEM] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELSETELEM] = AUDIT_NFT_OP_SETELEM_UNREGISTER, + [NFT_MSG_NEWGEN] = AUDIT_NFT_OP_GEN_REGISTER, + [NFT_MSG_GETGEN] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_TRACE] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_NEWOBJ] = AUDIT_NFT_OP_OBJ_REGISTER, + [NFT_MSG_GETOBJ] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELOBJ] = AUDIT_NFT_OP_OBJ_UNREGISTER, + [NFT_MSG_GETOBJ_RESET] = AUDIT_NFT_OP_OBJ_RESET, + [NFT_MSG_NEWFLOWTABLE] = AUDIT_NFT_OP_FLOWTABLE_REGISTER, + [NFT_MSG_GETFLOWTABLE] = AUDIT_NFT_OP_INVALID, + [NFT_MSG_DELFLOWTABLE] = AUDIT_NFT_OP_FLOWTABLE_UNREGISTER, +}; + +static void nft_validate_state_update(struct net *net, u8 new_validate_state) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + + switch (nft_net->validate_state) { + case NFT_VALIDATE_SKIP: + WARN_ON_ONCE(new_validate_state == NFT_VALIDATE_DO); + break; + case NFT_VALIDATE_NEED: + break; + case NFT_VALIDATE_DO: + if (new_validate_state == NFT_VALIDATE_NEED) + return; + } + + nft_net->validate_state = new_validate_state; +} +static void nf_tables_trans_destroy_work(struct work_struct *w); +static DECLARE_WORK(trans_destroy_work, nf_tables_trans_destroy_work); + +static void nft_trans_gc_work(struct work_struct *work); +static DECLARE_WORK(trans_gc_work, nft_trans_gc_work); + +static void nft_ctx_init(struct nft_ctx *ctx, + struct net *net, + const struct sk_buff *skb, + const struct nlmsghdr *nlh, + u8 family, + struct nft_table *table, + struct nft_chain *chain, + const struct nlattr * const *nla) +{ + ctx->net = net; + ctx->family = family; + ctx->level = 0; + ctx->table = table; + ctx->chain = chain; + ctx->nla = nla; + ctx->portid = NETLINK_CB(skb).portid; + ctx->report = nlmsg_report(nlh); + ctx->flags = 
nlh->nlmsg_flags; + ctx->seq = nlh->nlmsg_seq; +} + +static struct nft_trans *nft_trans_alloc_gfp(const struct nft_ctx *ctx, + int msg_type, u32 size, gfp_t gfp) +{ + struct nft_trans *trans; + + trans = kzalloc(sizeof(struct nft_trans) + size, gfp); + if (trans == NULL) + return NULL; + + INIT_LIST_HEAD(&trans->list); + INIT_LIST_HEAD(&trans->binding_list); + trans->msg_type = msg_type; + trans->ctx = *ctx; + + return trans; +} + +static struct nft_trans *nft_trans_alloc(const struct nft_ctx *ctx, + int msg_type, u32 size) +{ + return nft_trans_alloc_gfp(ctx, msg_type, size, GFP_KERNEL); +} + +static void nft_trans_list_del(struct nft_trans *trans) +{ + list_del(&trans->list); + list_del(&trans->binding_list); +} + +static void nft_trans_destroy(struct nft_trans *trans) +{ + nft_trans_list_del(trans); + kfree(trans); +} + +static void __nft_set_trans_bind(const struct nft_ctx *ctx, struct nft_set *set, + bool bind) +{ + struct nftables_pernet *nft_net; + struct net *net = ctx->net; + struct nft_trans *trans; + + if (!nft_set_is_anonymous(set)) + return; + + nft_net = nft_pernet(net); + list_for_each_entry_reverse(trans, &nft_net->commit_list, list) { + switch (trans->msg_type) { + case NFT_MSG_NEWSET: + if (nft_trans_set(trans) == set) + nft_trans_set_bound(trans) = bind; + break; + case NFT_MSG_NEWSETELEM: + if (nft_trans_elem_set(trans) == set) + nft_trans_elem_set_bound(trans) = bind; + break; + } + } +} + +static void nft_set_trans_bind(const struct nft_ctx *ctx, struct nft_set *set) +{ + return __nft_set_trans_bind(ctx, set, true); +} + +static void nft_set_trans_unbind(const struct nft_ctx *ctx, struct nft_set *set) +{ + return __nft_set_trans_bind(ctx, set, false); +} + +static void __nft_chain_trans_bind(const struct nft_ctx *ctx, + struct nft_chain *chain, bool bind) +{ + struct nftables_pernet *nft_net; + struct net *net = ctx->net; + struct nft_trans *trans; + + if (!nft_chain_binding(chain)) + return; + + nft_net = nft_pernet(net); + list_for_each_entry_reverse(trans, &nft_net->commit_list, list) { + switch (trans->msg_type) { + case NFT_MSG_NEWCHAIN: + if (nft_trans_chain(trans) == chain) + nft_trans_chain_bound(trans) = bind; + break; + case NFT_MSG_NEWRULE: + if (trans->ctx.chain == chain) + nft_trans_rule_bound(trans) = bind; + break; + } + } +} + +static void nft_chain_trans_bind(const struct nft_ctx *ctx, + struct nft_chain *chain) +{ + __nft_chain_trans_bind(ctx, chain, true); +} + +int nf_tables_bind_chain(const struct nft_ctx *ctx, struct nft_chain *chain) +{ + if (!nft_chain_binding(chain)) + return 0; + + if (nft_chain_binding(ctx->chain)) + return -EOPNOTSUPP; + + if (chain->bound) + return -EBUSY; + + if (!nft_use_inc(&chain->use)) + return -EMFILE; + + chain->bound = true; + nft_chain_trans_bind(ctx, chain); + + return 0; +} + +void nf_tables_unbind_chain(const struct nft_ctx *ctx, struct nft_chain *chain) +{ + __nft_chain_trans_bind(ctx, chain, false); +} + +static int nft_netdev_register_hooks(struct net *net, + struct list_head *hook_list) +{ + struct nft_hook *hook; + int err, j; + + j = 0; + list_for_each_entry(hook, hook_list, list) { + err = nf_register_net_hook(net, &hook->ops); + if (err < 0) + goto err_register; + + j++; + } + return 0; + +err_register: + list_for_each_entry(hook, hook_list, list) { + if (j-- <= 0) + break; + + nf_unregister_net_hook(net, &hook->ops); + } + return err; +} + +static void nft_netdev_unregister_hooks(struct net *net, + struct list_head *hook_list, + bool release_netdev) +{ + struct nft_hook *hook, *next; + + 
list_for_each_entry_safe(hook, next, hook_list, list) { + nf_unregister_net_hook(net, &hook->ops); + if (release_netdev) { + list_del(&hook->list); + kfree_rcu(hook, rcu); + } + } +} + +static int nf_tables_register_hook(struct net *net, + const struct nft_table *table, + struct nft_chain *chain) +{ + struct nft_base_chain *basechain; + const struct nf_hook_ops *ops; + + if (table->flags & NFT_TABLE_F_DORMANT || + !nft_is_base_chain(chain)) + return 0; + + basechain = nft_base_chain(chain); + ops = &basechain->ops; + + if (basechain->type->ops_register) + return basechain->type->ops_register(net, ops); + + if (nft_base_chain_netdev(table->family, basechain->ops.hooknum)) + return nft_netdev_register_hooks(net, &basechain->hook_list); + + return nf_register_net_hook(net, &basechain->ops); +} + +static void __nf_tables_unregister_hook(struct net *net, + const struct nft_table *table, + struct nft_chain *chain, + bool release_netdev) +{ + struct nft_base_chain *basechain; + const struct nf_hook_ops *ops; + + if (table->flags & NFT_TABLE_F_DORMANT || + !nft_is_base_chain(chain)) + return; + basechain = nft_base_chain(chain); + ops = &basechain->ops; + + if (basechain->type->ops_unregister) + return basechain->type->ops_unregister(net, ops); + + if (nft_base_chain_netdev(table->family, basechain->ops.hooknum)) + nft_netdev_unregister_hooks(net, &basechain->hook_list, + release_netdev); + else + nf_unregister_net_hook(net, &basechain->ops); +} + +static void nf_tables_unregister_hook(struct net *net, + const struct nft_table *table, + struct nft_chain *chain) +{ + return __nf_tables_unregister_hook(net, table, chain, false); +} + +static void nft_trans_commit_list_add_tail(struct net *net, struct nft_trans *trans) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + + switch (trans->msg_type) { + case NFT_MSG_NEWSET: + if (!nft_trans_set_update(trans) && + nft_set_is_anonymous(nft_trans_set(trans))) + list_add_tail(&trans->binding_list, &nft_net->binding_list); + break; + case NFT_MSG_NEWCHAIN: + if (!nft_trans_chain_update(trans) && + nft_chain_binding(nft_trans_chain(trans))) + list_add_tail(&trans->binding_list, &nft_net->binding_list); + break; + } + + list_add_tail(&trans->list, &nft_net->commit_list); +} + +static int nft_trans_table_add(struct nft_ctx *ctx, int msg_type) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_table)); + if (trans == NULL) + return -ENOMEM; + + if (msg_type == NFT_MSG_NEWTABLE) + nft_activate_next(ctx->net, ctx->table); + + nft_trans_commit_list_add_tail(ctx->net, trans); + return 0; +} + +static int nft_deltable(struct nft_ctx *ctx) +{ + int err; + + err = nft_trans_table_add(ctx, NFT_MSG_DELTABLE); + if (err < 0) + return err; + + nft_deactivate_next(ctx->net, ctx->table); + return err; +} + +static struct nft_trans *nft_trans_chain_add(struct nft_ctx *ctx, int msg_type) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_chain)); + if (trans == NULL) + return ERR_PTR(-ENOMEM); + + if (msg_type == NFT_MSG_NEWCHAIN) { + nft_activate_next(ctx->net, ctx->chain); + + if (ctx->nla[NFTA_CHAIN_ID]) { + nft_trans_chain_id(trans) = + ntohl(nla_get_be32(ctx->nla[NFTA_CHAIN_ID])); + } + } + nft_trans_chain(trans) = ctx->chain; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return trans; +} + +static int nft_delchain(struct nft_ctx *ctx) +{ + struct nft_trans *trans; + + trans = nft_trans_chain_add(ctx, NFT_MSG_DELCHAIN); + if (IS_ERR(trans)) + return PTR_ERR(trans); + 
+ nft_use_dec(&ctx->table->use); + nft_deactivate_next(ctx->net, ctx->chain); + + return 0; +} + +void nft_rule_expr_activate(const struct nft_ctx *ctx, struct nft_rule *rule) +{ + struct nft_expr *expr; + + expr = nft_expr_first(rule); + while (nft_expr_more(rule, expr)) { + if (expr->ops->activate) + expr->ops->activate(ctx, expr); + + expr = nft_expr_next(expr); + } +} + +void nft_rule_expr_deactivate(const struct nft_ctx *ctx, struct nft_rule *rule, + enum nft_trans_phase phase) +{ + struct nft_expr *expr; + + expr = nft_expr_first(rule); + while (nft_expr_more(rule, expr)) { + if (expr->ops->deactivate) + expr->ops->deactivate(ctx, expr, phase); + + expr = nft_expr_next(expr); + } +} + +static int +nf_tables_delrule_deactivate(struct nft_ctx *ctx, struct nft_rule *rule) +{ + /* You cannot delete the same rule twice */ + if (nft_is_active_next(ctx->net, rule)) { + nft_deactivate_next(ctx->net, rule); + nft_use_dec(&ctx->chain->use); + return 0; + } + return -ENOENT; +} + +static struct nft_trans *nft_trans_rule_add(struct nft_ctx *ctx, int msg_type, + struct nft_rule *rule) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_rule)); + if (trans == NULL) + return NULL; + + if (msg_type == NFT_MSG_NEWRULE && ctx->nla[NFTA_RULE_ID] != NULL) { + nft_trans_rule_id(trans) = + ntohl(nla_get_be32(ctx->nla[NFTA_RULE_ID])); + } + nft_trans_rule(trans) = rule; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return trans; +} + +static int nft_delrule(struct nft_ctx *ctx, struct nft_rule *rule) +{ + struct nft_flow_rule *flow; + struct nft_trans *trans; + int err; + + trans = nft_trans_rule_add(ctx, NFT_MSG_DELRULE, rule); + if (trans == NULL) + return -ENOMEM; + + if (ctx->chain->flags & NFT_CHAIN_HW_OFFLOAD) { + flow = nft_flow_rule_create(ctx->net, rule); + if (IS_ERR(flow)) { + nft_trans_destroy(trans); + return PTR_ERR(flow); + } + + nft_trans_flow_rule(trans) = flow; + } + + err = nf_tables_delrule_deactivate(ctx, rule); + if (err < 0) { + nft_trans_destroy(trans); + return err; + } + nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_PREPARE); + + return 0; +} + +static int nft_delrule_by_chain(struct nft_ctx *ctx) +{ + struct nft_rule *rule; + int err; + + list_for_each_entry(rule, &ctx->chain->rules, list) { + if (!nft_is_active_next(ctx->net, rule)) + continue; + + err = nft_delrule(ctx, rule); + if (err < 0) + return err; + } + return 0; +} + +static int __nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, + struct nft_set *set, + const struct nft_set_desc *desc) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_set)); + if (trans == NULL) + return -ENOMEM; + + if (msg_type == NFT_MSG_NEWSET && ctx->nla[NFTA_SET_ID] && !desc) { + nft_trans_set_id(trans) = + ntohl(nla_get_be32(ctx->nla[NFTA_SET_ID])); + nft_activate_next(ctx->net, set); + } + nft_trans_set(trans) = set; + if (desc) { + nft_trans_set_update(trans) = true; + nft_trans_set_gc_int(trans) = desc->gc_int; + nft_trans_set_timeout(trans) = desc->timeout; + } + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +} + +static int nft_trans_set_add(const struct nft_ctx *ctx, int msg_type, + struct nft_set *set) +{ + return __nft_trans_set_add(ctx, msg_type, set, NULL); +} + +static int nft_mapelem_deactivate(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + nft_setelem_data_deactivate(ctx->net, set, elem); + + return 0; +} + +struct nft_set_elem_catchall 
{ + struct list_head list; + struct rcu_head rcu; + void *elem; +}; + +static void nft_map_catchall_deactivate(const struct nft_ctx *ctx, + struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set_elem_catchall *catchall; + struct nft_set_elem elem; + struct nft_set_ext *ext; + + list_for_each_entry(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + elem.priv = catchall->elem; + nft_setelem_data_deactivate(ctx->net, set, &elem); + break; + } +} + +static void nft_map_deactivate(const struct nft_ctx *ctx, struct nft_set *set) +{ + struct nft_set_iter iter = { + .genmask = nft_genmask_next(ctx->net), + .fn = nft_mapelem_deactivate, + }; + + set->ops->walk(ctx, set, &iter); + WARN_ON_ONCE(iter.err); + + nft_map_catchall_deactivate(ctx, set); +} + +static int nft_delset(const struct nft_ctx *ctx, struct nft_set *set) +{ + int err; + + err = nft_trans_set_add(ctx, NFT_MSG_DELSET, set); + if (err < 0) + return err; + + if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_deactivate(ctx, set); + + nft_deactivate_next(ctx->net, set); + nft_use_dec(&ctx->table->use); + + return err; +} + +static int nft_trans_obj_add(struct nft_ctx *ctx, int msg_type, + struct nft_object *obj) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_obj)); + if (trans == NULL) + return -ENOMEM; + + if (msg_type == NFT_MSG_NEWOBJ) + nft_activate_next(ctx->net, obj); + + nft_trans_obj(trans) = obj; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +} + +static int nft_delobj(struct nft_ctx *ctx, struct nft_object *obj) +{ + int err; + + err = nft_trans_obj_add(ctx, NFT_MSG_DELOBJ, obj); + if (err < 0) + return err; + + nft_deactivate_next(ctx->net, obj); + nft_use_dec(&ctx->table->use); + + return err; +} + +static int nft_trans_flowtable_add(struct nft_ctx *ctx, int msg_type, + struct nft_flowtable *flowtable) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, + sizeof(struct nft_trans_flowtable)); + if (trans == NULL) + return -ENOMEM; + + if (msg_type == NFT_MSG_NEWFLOWTABLE) + nft_activate_next(ctx->net, flowtable); + + INIT_LIST_HEAD(&nft_trans_flowtable_hooks(trans)); + nft_trans_flowtable(trans) = flowtable; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +} + +static int nft_delflowtable(struct nft_ctx *ctx, + struct nft_flowtable *flowtable) +{ + int err; + + err = nft_trans_flowtable_add(ctx, NFT_MSG_DELFLOWTABLE, flowtable); + if (err < 0) + return err; + + nft_deactivate_next(ctx->net, flowtable); + nft_use_dec(&ctx->table->use); + + return err; +} + +static void __nft_reg_track_clobber(struct nft_regs_track *track, u8 dreg) +{ + int i; + + for (i = track->regs[dreg].num_reg; i > 0; i--) + __nft_reg_track_cancel(track, dreg - i); +} + +static void __nft_reg_track_update(struct nft_regs_track *track, + const struct nft_expr *expr, + u8 dreg, u8 num_reg) +{ + track->regs[dreg].selector = expr; + track->regs[dreg].bitwise = NULL; + track->regs[dreg].num_reg = num_reg; +} + +void nft_reg_track_update(struct nft_regs_track *track, + const struct nft_expr *expr, u8 dreg, u8 len) +{ + unsigned int regcount; + int i; + + __nft_reg_track_clobber(track, dreg); + + regcount = DIV_ROUND_UP(len, NFT_REG32_SIZE); + for (i = 0; i < regcount; i++, dreg++) + __nft_reg_track_update(track, expr, dreg, i); +} +EXPORT_SYMBOL_GPL(nft_reg_track_update); + +void nft_reg_track_cancel(struct nft_regs_track *track, 
u8 dreg, u8 len) +{ + unsigned int regcount; + int i; + + __nft_reg_track_clobber(track, dreg); + + regcount = DIV_ROUND_UP(len, NFT_REG32_SIZE); + for (i = 0; i < regcount; i++, dreg++) + __nft_reg_track_cancel(track, dreg); +} +EXPORT_SYMBOL_GPL(nft_reg_track_cancel); + +void __nft_reg_track_cancel(struct nft_regs_track *track, u8 dreg) +{ + track->regs[dreg].selector = NULL; + track->regs[dreg].bitwise = NULL; + track->regs[dreg].num_reg = 0; +} +EXPORT_SYMBOL_GPL(__nft_reg_track_cancel); + +/* + * Tables + */ + +static struct nft_table *nft_table_lookup(const struct net *net, + const struct nlattr *nla, + u8 family, u8 genmask, u32 nlpid) +{ + struct nftables_pernet *nft_net; + struct nft_table *table; + + if (nla == NULL) + return ERR_PTR(-EINVAL); + + nft_net = nft_pernet(net); + list_for_each_entry_rcu(table, &nft_net->tables, list, + lockdep_is_held(&nft_net->commit_mutex)) { + if (!nla_strcmp(nla, table->name) && + table->family == family && + nft_active_genmask(table, genmask)) { + if (nft_table_has_owner(table) && + nlpid && table->nlpid != nlpid) + return ERR_PTR(-EPERM); + + return table; + } + } + + return ERR_PTR(-ENOENT); +} + +static struct nft_table *nft_table_lookup_byhandle(const struct net *net, + const struct nlattr *nla, + int family, u8 genmask, u32 nlpid) +{ + struct nftables_pernet *nft_net; + struct nft_table *table; + + nft_net = nft_pernet(net); + list_for_each_entry(table, &nft_net->tables, list) { + if (be64_to_cpu(nla_get_be64(nla)) == table->handle && + table->family == family && + nft_active_genmask(table, genmask)) { + if (nft_table_has_owner(table) && + nlpid && table->nlpid != nlpid) + return ERR_PTR(-EPERM); + + return table; + } + } + + return ERR_PTR(-ENOENT); +} + +static inline u64 nf_tables_alloc_handle(struct nft_table *table) +{ + return ++table->hgenerator; +} + +static const struct nft_chain_type *chain_type[NFPROTO_NUMPROTO][NFT_CHAIN_T_MAX]; + +static const struct nft_chain_type * +__nft_chain_type_get(u8 family, enum nft_chain_types type) +{ + if (family >= NFPROTO_NUMPROTO || + type >= NFT_CHAIN_T_MAX) + return NULL; + + return chain_type[family][type]; +} + +static const struct nft_chain_type * +__nf_tables_chain_type_lookup(const struct nlattr *nla, u8 family) +{ + const struct nft_chain_type *type; + int i; + + for (i = 0; i < NFT_CHAIN_T_MAX; i++) { + type = __nft_chain_type_get(family, i); + if (!type) + continue; + if (!nla_strcmp(nla, type->name)) + return type; + } + return NULL; +} + +struct nft_module_request { + struct list_head list; + char module[MODULE_NAME_LEN]; + bool done; +}; + +#ifdef CONFIG_MODULES +__printf(2, 3) int nft_request_module(struct net *net, const char *fmt, + ...) +{ + char module_name[MODULE_NAME_LEN]; + struct nftables_pernet *nft_net; + struct nft_module_request *req; + va_list args; + int ret; + + va_start(args, fmt); + ret = vsnprintf(module_name, MODULE_NAME_LEN, fmt, args); + va_end(args); + if (ret >= MODULE_NAME_LEN) + return 0; + + nft_net = nft_pernet(net); + list_for_each_entry(req, &nft_net->module_list, list) { + if (!strcmp(req->module, module_name)) { + if (req->done) + return 0; + + /* A request to load this module already exists. 
*/ + return -EAGAIN; + } + } + + req = kmalloc(sizeof(*req), GFP_KERNEL); + if (!req) + return -ENOMEM; + + req->done = false; + strscpy(req->module, module_name, MODULE_NAME_LEN); + list_add_tail(&req->list, &nft_net->module_list); + + return -EAGAIN; +} +EXPORT_SYMBOL_GPL(nft_request_module); +#endif + +static void lockdep_nfnl_nft_mutex_not_held(void) +{ +#ifdef CONFIG_PROVE_LOCKING + if (debug_locks) + WARN_ON_ONCE(lockdep_nfnl_is_held(NFNL_SUBSYS_NFTABLES)); +#endif +} + +static const struct nft_chain_type * +nf_tables_chain_type_lookup(struct net *net, const struct nlattr *nla, + u8 family, bool autoload) +{ + const struct nft_chain_type *type; + + type = __nf_tables_chain_type_lookup(nla, family); + if (type != NULL) + return type; + + lockdep_nfnl_nft_mutex_not_held(); +#ifdef CONFIG_MODULES + if (autoload) { + if (nft_request_module(net, "nft-chain-%u-%.*s", family, + nla_len(nla), + (const char *)nla_data(nla)) == -EAGAIN) + return ERR_PTR(-EAGAIN); + } +#endif + return ERR_PTR(-ENOENT); +} + +static __be16 nft_base_seq(const struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + + return htons(nft_net->base_seq & 0xffff); +} + +static const struct nla_policy nft_table_policy[NFTA_TABLE_MAX + 1] = { + [NFTA_TABLE_NAME] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_TABLE_FLAGS] = { .type = NLA_U32 }, + [NFTA_TABLE_HANDLE] = { .type = NLA_U64 }, + [NFTA_TABLE_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN } +}; + +static int nf_tables_fill_table_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, int event, u32 flags, + int family, const struct nft_table *table) +{ + struct nlmsghdr *nlh; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, family, + NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_TABLE_NAME, table->name) || + nla_put_be32(skb, NFTA_TABLE_FLAGS, + htonl(table->flags & NFT_TABLE_F_MASK)) || + nla_put_be32(skb, NFTA_TABLE_USE, htonl(table->use)) || + nla_put_be64(skb, NFTA_TABLE_HANDLE, cpu_to_be64(table->handle), + NFTA_TABLE_PAD)) + goto nla_put_failure; + if (nft_table_has_owner(table) && + nla_put_be32(skb, NFTA_TABLE_OWNER, htonl(table->nlpid))) + goto nla_put_failure; + + if (table->udata) { + if (nla_put(skb, NFTA_TABLE_USERDATA, table->udlen, table->udata)) + goto nla_put_failure; + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +struct nftnl_skb_parms { + bool report; +}; +#define NFT_CB(skb) (*(struct nftnl_skb_parms*)&((skb)->cb)) + +static void nft_notify_enqueue(struct sk_buff *skb, bool report, + struct list_head *notify_list) +{ + NFT_CB(skb).report = report; + list_add_tail(&skb->list, notify_list); +} + +static void nf_tables_table_notify(const struct nft_ctx *ctx, int event) +{ + struct nftables_pernet *nft_net; + struct sk_buff *skb; + u16 flags = 0; + int err; + + if (!ctx->report && + !nfnetlink_has_listeners(ctx->net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb == NULL) + goto err; + + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_table_info(skb, ctx->net, ctx->portid, ctx->seq, + event, flags, ctx->family, ctx->table); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_net = nft_pernet(ctx->net); + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + 
nfnetlink_set_err(ctx->net, ctx->portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +static int nf_tables_dump_tables(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + struct nftables_pernet *nft_net; + const struct nft_table *table; + unsigned int idx = 0, s_idx = cb->args[0]; + struct net *net = sock_net(skb->sk); + int family = nfmsg->nfgen_family; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (family != NFPROTO_UNSPEC && family != table->family) + continue; + + if (idx < s_idx) + goto cont; + if (idx > s_idx) + memset(&cb->args[1], 0, + sizeof(cb->args) - sizeof(cb->args[0])); + if (!nft_is_active(net, table)) + continue; + if (nf_tables_fill_table_info(skb, net, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFT_MSG_NEWTABLE, NLM_F_MULTI, + table->family, table) < 0) + goto done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } +done: + rcu_read_unlock(); + cb->args[0] = idx; + return skb->len; +} + +static int nft_netlink_dump_start_rcu(struct sock *nlsk, struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct netlink_dump_control *c) +{ + int err; + + if (!try_module_get(THIS_MODULE)) + return -EINVAL; + + rcu_read_unlock(); + err = netlink_dump_start(nlsk, skb, nlh, c); + rcu_read_lock(); + module_put(THIS_MODULE); + + return err; +} + +/* called with rcu_read_lock held */ +static int nf_tables_gettable(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_table *table; + struct net *net = info->net; + struct sk_buff *skb2; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = nf_tables_dump_tables, + .module = THIS_MODULE, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + table = nft_table_lookup(net, nla[NFTA_TABLE_NAME], family, genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_TABLE_NAME]); + return PTR_ERR(table); + } + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb2) + return -ENOMEM; + + err = nf_tables_fill_table_info(skb2, net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, NFT_MSG_NEWTABLE, + 0, family, table); + if (err < 0) + goto err_fill_table_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_table_info: + kfree_skb(skb2); + return err; +} + +static void nft_table_disable(struct net *net, struct nft_table *table, u32 cnt) +{ + struct nft_chain *chain; + u32 i = 0; + + list_for_each_entry(chain, &table->chains, list) { + if (!nft_is_active_next(net, chain)) + continue; + if (!nft_is_base_chain(chain)) + continue; + + if (cnt && i++ == cnt) + break; + + nf_tables_unregister_hook(net, table, chain); + } +} + +static int nf_tables_table_enable(struct net *net, struct nft_table *table) +{ + struct nft_chain *chain; + int err, i = 0; + + list_for_each_entry(chain, &table->chains, list) { + if (!nft_is_active_next(net, chain)) + continue; + if (!nft_is_base_chain(chain)) + continue; + + err = nf_tables_register_hook(net, table, chain); + if (err < 0) + goto err_register_hooks; + + i++; + } + return 0; + +err_register_hooks: + if (i) + nft_table_disable(net, table, i); + return err; +} + +static void nf_tables_table_disable(struct net *net, struct 
nft_table *table) +{ + table->flags &= ~NFT_TABLE_F_DORMANT; + nft_table_disable(net, table, 0); + table->flags |= NFT_TABLE_F_DORMANT; +} + +#define __NFT_TABLE_F_INTERNAL (NFT_TABLE_F_MASK + 1) +#define __NFT_TABLE_F_WAS_DORMANT (__NFT_TABLE_F_INTERNAL << 0) +#define __NFT_TABLE_F_WAS_AWAKEN (__NFT_TABLE_F_INTERNAL << 1) +#define __NFT_TABLE_F_UPDATE (__NFT_TABLE_F_WAS_DORMANT | \ + __NFT_TABLE_F_WAS_AWAKEN) + +static int nf_tables_updtable(struct nft_ctx *ctx) +{ + struct nft_trans *trans; + u32 flags; + int ret; + + if (!ctx->nla[NFTA_TABLE_FLAGS]) + return 0; + + flags = ntohl(nla_get_be32(ctx->nla[NFTA_TABLE_FLAGS])); + if (flags & ~NFT_TABLE_F_MASK) + return -EOPNOTSUPP; + + if (flags == ctx->table->flags) + return 0; + + if ((nft_table_has_owner(ctx->table) && + !(flags & NFT_TABLE_F_OWNER)) || + (!nft_table_has_owner(ctx->table) && + flags & NFT_TABLE_F_OWNER)) + return -EOPNOTSUPP; + + /* No dormant off/on/off/on games in single transaction */ + if (ctx->table->flags & __NFT_TABLE_F_UPDATE) + return -EINVAL; + + trans = nft_trans_alloc(ctx, NFT_MSG_NEWTABLE, + sizeof(struct nft_trans_table)); + if (trans == NULL) + return -ENOMEM; + + if ((flags & NFT_TABLE_F_DORMANT) && + !(ctx->table->flags & NFT_TABLE_F_DORMANT)) { + ctx->table->flags |= NFT_TABLE_F_DORMANT; + if (!(ctx->table->flags & __NFT_TABLE_F_UPDATE)) + ctx->table->flags |= __NFT_TABLE_F_WAS_AWAKEN; + } else if (!(flags & NFT_TABLE_F_DORMANT) && + ctx->table->flags & NFT_TABLE_F_DORMANT) { + ctx->table->flags &= ~NFT_TABLE_F_DORMANT; + if (!(ctx->table->flags & __NFT_TABLE_F_UPDATE)) { + ret = nf_tables_table_enable(ctx->net, ctx->table); + if (ret < 0) + goto err_register_hooks; + + ctx->table->flags |= __NFT_TABLE_F_WAS_DORMANT; + } + } + + nft_trans_table_update(trans) = true; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; + +err_register_hooks: + nft_trans_destroy(trans); + return ret; +} + +static u32 nft_chain_hash(const void *data, u32 len, u32 seed) +{ + const char *name = data; + + return jhash(name, strlen(name), seed); +} + +static u32 nft_chain_hash_obj(const void *data, u32 len, u32 seed) +{ + const struct nft_chain *chain = data; + + return nft_chain_hash(chain->name, 0, seed); +} + +static int nft_chain_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct nft_chain *chain = ptr; + const char *name = arg->key; + + return strcmp(chain->name, name); +} + +static u32 nft_objname_hash(const void *data, u32 len, u32 seed) +{ + const struct nft_object_hash_key *k = data; + + seed ^= hash_ptr(k->table, 32); + + return jhash(k->name, strlen(k->name), seed); +} + +static u32 nft_objname_hash_obj(const void *data, u32 len, u32 seed) +{ + const struct nft_object *obj = data; + + return nft_objname_hash(&obj->key, 0, seed); +} + +static int nft_objname_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct nft_object_hash_key *k = arg->key; + const struct nft_object *obj = ptr; + + if (obj->key.table != k->table) + return -1; + + return strcmp(obj->key.name, k->name); +} + +static bool nft_supported_family(u8 family) +{ + return false +#ifdef CONFIG_NF_TABLES_INET + || family == NFPROTO_INET +#endif +#ifdef CONFIG_NF_TABLES_IPV4 + || family == NFPROTO_IPV4 +#endif +#ifdef CONFIG_NF_TABLES_ARP + || family == NFPROTO_ARP +#endif +#ifdef CONFIG_NF_TABLES_NETDEV + || family == NFPROTO_NETDEV +#endif +#if IS_ENABLED(CONFIG_NF_TABLES_BRIDGE) + || family == NFPROTO_BRIDGE +#endif +#ifdef CONFIG_NF_TABLES_IPV6 + || family == NFPROTO_IPV6 +#endif + ; +} 
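+
+/* Top-level NFT_MSG_NEWTABLE handler (descriptive note). If a table with
+ * the given name already exists, the request is treated as an update and
+ * handed to nf_tables_updtable(), honouring NLM_F_EXCL and NLM_F_REPLACE;
+ * otherwise a fresh table is allocated, its name, userdata and chain
+ * hashtable are set up, the owner portid is recorded when
+ * NFT_TABLE_F_OWNER is set, and an NFT_MSG_NEWTABLE transaction is queued
+ * on the commit list before the table is added to the per-netns list.
+ */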
+ +static int nf_tables_newtable(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct nftables_pernet *nft_net = nft_pernet(info->net); + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_ctx ctx; + u32 flags = 0; + int err; + + if (!nft_supported_family(family)) + return -EOPNOTSUPP; + + lockdep_assert_held(&nft_net->commit_mutex); + attr = nla[NFTA_TABLE_NAME]; + table = nft_table_lookup(net, attr, family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + if (PTR_ERR(table) != -ENOENT) + return PTR_ERR(table); + } else { + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) + return -EOPNOTSUPP; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + return nf_tables_updtable(&ctx); + } + + if (nla[NFTA_TABLE_FLAGS]) { + flags = ntohl(nla_get_be32(nla[NFTA_TABLE_FLAGS])); + if (flags & ~NFT_TABLE_F_MASK) + return -EOPNOTSUPP; + } + + err = -ENOMEM; + table = kzalloc(sizeof(*table), GFP_KERNEL_ACCOUNT); + if (table == NULL) + goto err_kzalloc; + + table->name = nla_strdup(attr, GFP_KERNEL_ACCOUNT); + if (table->name == NULL) + goto err_strdup; + + if (nla[NFTA_TABLE_USERDATA]) { + table->udata = nla_memdup(nla[NFTA_TABLE_USERDATA], GFP_KERNEL_ACCOUNT); + if (table->udata == NULL) + goto err_table_udata; + + table->udlen = nla_len(nla[NFTA_TABLE_USERDATA]); + } + + err = rhltable_init(&table->chains_ht, &nft_chain_ht_params); + if (err) + goto err_chain_ht; + + INIT_LIST_HEAD(&table->chains); + INIT_LIST_HEAD(&table->sets); + INIT_LIST_HEAD(&table->objects); + INIT_LIST_HEAD(&table->flowtables); + table->family = family; + table->flags = flags; + table->handle = ++nft_net->table_handle; + if (table->flags & NFT_TABLE_F_OWNER) + table->nlpid = NETLINK_CB(skb).portid; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + err = nft_trans_table_add(&ctx, NFT_MSG_NEWTABLE); + if (err < 0) + goto err_trans; + + list_add_tail_rcu(&table->list, &nft_net->tables); + return 0; +err_trans: + rhltable_destroy(&table->chains_ht); +err_chain_ht: + kfree(table->udata); +err_table_udata: + kfree(table->name); +err_strdup: + kfree(table); +err_kzalloc: + return err; +} + +static int nft_flush_table(struct nft_ctx *ctx) +{ + struct nft_flowtable *flowtable, *nft; + struct nft_chain *chain, *nc; + struct nft_object *obj, *ne; + struct nft_set *set, *ns; + int err; + + list_for_each_entry(chain, &ctx->table->chains, list) { + if (!nft_is_active_next(ctx->net, chain)) + continue; + + if (nft_chain_binding(chain)) + continue; + + ctx->chain = chain; + + err = nft_delrule_by_chain(ctx); + if (err < 0) + goto out; + } + + list_for_each_entry_safe(set, ns, &ctx->table->sets, list) { + if (!nft_is_active_next(ctx->net, set)) + continue; + + if (nft_set_is_anonymous(set)) + continue; + + err = nft_delset(ctx, set); + if (err < 0) + goto out; + } + + list_for_each_entry_safe(flowtable, nft, &ctx->table->flowtables, list) { + if (!nft_is_active_next(ctx->net, flowtable)) + continue; + + err = nft_delflowtable(ctx, flowtable); + if (err < 0) + goto out; + } + + list_for_each_entry_safe(obj, ne, &ctx->table->objects, list) { + if (!nft_is_active_next(ctx->net, obj)) + continue; + + err = nft_delobj(ctx, obj); + if (err < 0) + goto out; + } + + 
list_for_each_entry_safe(chain, nc, &ctx->table->chains, list) { + if (!nft_is_active_next(ctx->net, chain)) + continue; + + if (nft_chain_binding(chain)) + continue; + + ctx->chain = chain; + + err = nft_delchain(ctx); + if (err < 0) + goto out; + } + + err = nft_deltable(ctx); +out: + return err; +} + +static int nft_flush(struct nft_ctx *ctx, int family) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + const struct nlattr * const *nla = ctx->nla; + struct nft_table *table, *nt; + int err = 0; + + list_for_each_entry_safe(table, nt, &nft_net->tables, list) { + if (family != AF_UNSPEC && table->family != family) + continue; + + ctx->family = table->family; + + if (!nft_is_active_next(ctx->net, table)) + continue; + + if (nft_table_has_owner(table) && table->nlpid != ctx->portid) + continue; + + if (nla[NFTA_TABLE_NAME] && + nla_strcmp(nla[NFTA_TABLE_NAME], table->name) != 0) + continue; + + ctx->table = table; + + err = nft_flush_table(ctx); + if (err < 0) + goto out; + } +out: + return err; +} + +static int nf_tables_deltable(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_ctx ctx; + + nft_ctx_init(&ctx, net, skb, info->nlh, 0, NULL, NULL, nla); + if (family == AF_UNSPEC || + (!nla[NFTA_TABLE_NAME] && !nla[NFTA_TABLE_HANDLE])) + return nft_flush(&ctx, family); + + if (nla[NFTA_TABLE_HANDLE]) { + attr = nla[NFTA_TABLE_HANDLE]; + table = nft_table_lookup_byhandle(net, attr, family, genmask, + NETLINK_CB(skb).portid); + } else { + attr = nla[NFTA_TABLE_NAME]; + table = nft_table_lookup(net, attr, family, genmask, + NETLINK_CB(skb).portid); + } + + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(table); + } + + if (info->nlh->nlmsg_flags & NLM_F_NONREC && + table->use > 0) + return -EBUSY; + + ctx.family = family; + ctx.table = table; + + return nft_flush_table(&ctx); +} + +static void nf_tables_table_destroy(struct nft_ctx *ctx) +{ + if (WARN_ON(ctx->table->use > 0)) + return; + + rhltable_destroy(&ctx->table->chains_ht); + kfree(ctx->table->name); + kfree(ctx->table->udata); + kfree(ctx->table); +} + +void nft_register_chain_type(const struct nft_chain_type *ctype) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + if (WARN_ON(__nft_chain_type_get(ctype->family, ctype->type))) { + nfnl_unlock(NFNL_SUBSYS_NFTABLES); + return; + } + chain_type[ctype->family][ctype->type] = ctype; + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_register_chain_type); + +void nft_unregister_chain_type(const struct nft_chain_type *ctype) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + chain_type[ctype->family][ctype->type] = NULL; + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_unregister_chain_type); + +/* + * Chains + */ + +static struct nft_chain * +nft_chain_lookup_byhandle(const struct nft_table *table, u64 handle, u8 genmask) +{ + struct nft_chain *chain; + + list_for_each_entry(chain, &table->chains, list) { + if (chain->handle == handle && + nft_active_genmask(chain, genmask)) + return chain; + } + + return ERR_PTR(-ENOENT); +} + +static bool lockdep_commit_lock_is_held(const struct net *net) +{ +#ifdef CONFIG_PROVE_LOCKING + struct nftables_pernet *nft_net = nft_pernet(net); + + return lockdep_is_held(&nft_net->commit_mutex); +#else + return true; +#endif +} + +static struct nft_chain 
*nft_chain_lookup(struct net *net, + struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + char search[NFT_CHAIN_MAXNAMELEN + 1]; + struct rhlist_head *tmp, *list; + struct nft_chain *chain; + + if (nla == NULL) + return ERR_PTR(-EINVAL); + + nla_strscpy(search, nla, sizeof(search)); + + WARN_ON(!rcu_read_lock_held() && + !lockdep_commit_lock_is_held(net)); + + chain = ERR_PTR(-ENOENT); + rcu_read_lock(); + list = rhltable_lookup(&table->chains_ht, search, nft_chain_ht_params); + if (!list) + goto out_unlock; + + rhl_for_each_entry_rcu(chain, tmp, list, rhlhead) { + if (nft_active_genmask(chain, genmask)) + goto out_unlock; + } + chain = ERR_PTR(-ENOENT); +out_unlock: + rcu_read_unlock(); + return chain; +} + +static const struct nla_policy nft_chain_policy[NFTA_CHAIN_MAX + 1] = { + [NFTA_CHAIN_TABLE] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_CHAIN_HANDLE] = { .type = NLA_U64 }, + [NFTA_CHAIN_NAME] = { .type = NLA_STRING, + .len = NFT_CHAIN_MAXNAMELEN - 1 }, + [NFTA_CHAIN_HOOK] = { .type = NLA_NESTED }, + [NFTA_CHAIN_POLICY] = { .type = NLA_U32 }, + [NFTA_CHAIN_TYPE] = { .type = NLA_STRING, + .len = NFT_MODULE_AUTOLOAD_LIMIT }, + [NFTA_CHAIN_COUNTERS] = { .type = NLA_NESTED }, + [NFTA_CHAIN_FLAGS] = { .type = NLA_U32 }, + [NFTA_CHAIN_ID] = { .type = NLA_U32 }, + [NFTA_CHAIN_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN }, +}; + +static const struct nla_policy nft_hook_policy[NFTA_HOOK_MAX + 1] = { + [NFTA_HOOK_HOOKNUM] = { .type = NLA_U32 }, + [NFTA_HOOK_PRIORITY] = { .type = NLA_U32 }, + [NFTA_HOOK_DEV] = { .type = NLA_STRING, + .len = IFNAMSIZ - 1 }, +}; + +static int nft_dump_stats(struct sk_buff *skb, struct nft_stats __percpu *stats) +{ + struct nft_stats *cpu_stats, total; + struct nlattr *nest; + unsigned int seq; + u64 pkts, bytes; + int cpu; + + if (!stats) + return 0; + + memset(&total, 0, sizeof(total)); + for_each_possible_cpu(cpu) { + cpu_stats = per_cpu_ptr(stats, cpu); + do { + seq = u64_stats_fetch_begin_irq(&cpu_stats->syncp); + pkts = cpu_stats->pkts; + bytes = cpu_stats->bytes; + } while (u64_stats_fetch_retry_irq(&cpu_stats->syncp, seq)); + total.pkts += pkts; + total.bytes += bytes; + } + nest = nla_nest_start_noflag(skb, NFTA_CHAIN_COUNTERS); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put_be64(skb, NFTA_COUNTER_PACKETS, cpu_to_be64(total.pkts), + NFTA_COUNTER_PAD) || + nla_put_be64(skb, NFTA_COUNTER_BYTES, cpu_to_be64(total.bytes), + NFTA_COUNTER_PAD)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static int nft_dump_basechain_hook(struct sk_buff *skb, int family, + const struct nft_base_chain *basechain) +{ + const struct nf_hook_ops *ops = &basechain->ops; + struct nft_hook *hook, *first = NULL; + struct nlattr *nest, *nest_devs; + int n = 0; + + nest = nla_nest_start_noflag(skb, NFTA_CHAIN_HOOK); + if (nest == NULL) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HOOK_HOOKNUM, htonl(ops->hooknum))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HOOK_PRIORITY, htonl(ops->priority))) + goto nla_put_failure; + + if (nft_base_chain_netdev(family, ops->hooknum)) { + nest_devs = nla_nest_start_noflag(skb, NFTA_HOOK_DEVS); + list_for_each_entry(hook, &basechain->hook_list, list) { + if (!first) + first = hook; + + if (nla_put_string(skb, NFTA_DEVICE_NAME, + hook->ops.dev->name)) + goto nla_put_failure; + n++; + } + nla_nest_end(skb, nest_devs); + + if (n == 1 && + nla_put_string(skb, NFTA_HOOK_DEV, first->ops.dev->name)) + goto 
nla_put_failure; + } + nla_nest_end(skb, nest); + + return 0; +nla_put_failure: + return -1; +} + +static int nf_tables_fill_chain_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, int event, u32 flags, + int family, const struct nft_table *table, + const struct nft_chain *chain) +{ + struct nlmsghdr *nlh; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, family, + NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_CHAIN_TABLE, table->name)) + goto nla_put_failure; + if (nla_put_be64(skb, NFTA_CHAIN_HANDLE, cpu_to_be64(chain->handle), + NFTA_CHAIN_PAD)) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_CHAIN_NAME, chain->name)) + goto nla_put_failure; + + if (nft_is_base_chain(chain)) { + const struct nft_base_chain *basechain = nft_base_chain(chain); + struct nft_stats __percpu *stats; + + if (nft_dump_basechain_hook(skb, family, basechain)) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_CHAIN_POLICY, + htonl(basechain->policy))) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_CHAIN_TYPE, basechain->type->name)) + goto nla_put_failure; + + stats = rcu_dereference_check(basechain->stats, + lockdep_commit_lock_is_held(net)); + if (nft_dump_stats(skb, stats)) + goto nla_put_failure; + } + + if (chain->flags && + nla_put_be32(skb, NFTA_CHAIN_FLAGS, htonl(chain->flags))) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_CHAIN_USE, htonl(chain->use))) + goto nla_put_failure; + + if (chain->udata && + nla_put(skb, NFTA_CHAIN_USERDATA, chain->udlen, chain->udata)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +static void nf_tables_chain_notify(const struct nft_ctx *ctx, int event) +{ + struct nftables_pernet *nft_net; + struct sk_buff *skb; + u16 flags = 0; + int err; + + if (!ctx->report && + !nfnetlink_has_listeners(ctx->net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb == NULL) + goto err; + + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_chain_info(skb, ctx->net, ctx->portid, ctx->seq, + event, flags, ctx->family, ctx->table, + ctx->chain); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_net = nft_pernet(ctx->net); + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(ctx->net, ctx->portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +static int nf_tables_dump_chains(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + unsigned int idx = 0, s_idx = cb->args[0]; + struct net *net = sock_net(skb->sk); + int family = nfmsg->nfgen_family; + struct nftables_pernet *nft_net; + const struct nft_table *table; + const struct nft_chain *chain; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (family != NFPROTO_UNSPEC && family != table->family) + continue; + + list_for_each_entry_rcu(chain, &table->chains, list) { + if (idx < s_idx) + goto cont; + if (idx > s_idx) + memset(&cb->args[1], 0, + sizeof(cb->args) - sizeof(cb->args[0])); + if (!nft_is_active(net, chain)) + continue; + if (nf_tables_fill_chain_info(skb, net, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFT_MSG_NEWCHAIN, + NLM_F_MULTI, + table->family, table, + chain) < 0) + goto 
done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + } +done: + rcu_read_unlock(); + cb->args[0] = idx; + return skb->len; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getchain(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_chain *chain; + struct net *net = info->net; + struct nft_table *table; + struct sk_buff *skb2; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = nf_tables_dump_chains, + .module = THIS_MODULE, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + table = nft_table_lookup(net, nla[NFTA_CHAIN_TABLE], family, genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_TABLE]); + return PTR_ERR(table); + } + + chain = nft_chain_lookup(net, table, nla[NFTA_CHAIN_NAME], genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_NAME]); + return PTR_ERR(chain); + } + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb2) + return -ENOMEM; + + err = nf_tables_fill_chain_info(skb2, net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, NFT_MSG_NEWCHAIN, + 0, family, table, chain); + if (err < 0) + goto err_fill_chain_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_chain_info: + kfree_skb(skb2); + return err; +} + +static const struct nla_policy nft_counter_policy[NFTA_COUNTER_MAX + 1] = { + [NFTA_COUNTER_PACKETS] = { .type = NLA_U64 }, + [NFTA_COUNTER_BYTES] = { .type = NLA_U64 }, +}; + +static struct nft_stats __percpu *nft_stats_alloc(const struct nlattr *attr) +{ + struct nlattr *tb[NFTA_COUNTER_MAX+1]; + struct nft_stats __percpu *newstats; + struct nft_stats *stats; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_COUNTER_MAX, attr, + nft_counter_policy, NULL); + if (err < 0) + return ERR_PTR(err); + + if (!tb[NFTA_COUNTER_BYTES] || !tb[NFTA_COUNTER_PACKETS]) + return ERR_PTR(-EINVAL); + + newstats = netdev_alloc_pcpu_stats(struct nft_stats); + if (newstats == NULL) + return ERR_PTR(-ENOMEM); + + /* Restore old counters on this cpu, no problem. Per-cpu statistics + * are not exposed to userspace. 
+ */ + preempt_disable(); + stats = this_cpu_ptr(newstats); + stats->bytes = be64_to_cpu(nla_get_be64(tb[NFTA_COUNTER_BYTES])); + stats->pkts = be64_to_cpu(nla_get_be64(tb[NFTA_COUNTER_PACKETS])); + preempt_enable(); + + return newstats; +} + +static void nft_chain_stats_replace(struct nft_trans *trans) +{ + struct nft_base_chain *chain = nft_base_chain(trans->ctx.chain); + + if (!nft_trans_chain_stats(trans)) + return; + + nft_trans_chain_stats(trans) = + rcu_replace_pointer(chain->stats, nft_trans_chain_stats(trans), + lockdep_commit_lock_is_held(trans->ctx.net)); + + if (!nft_trans_chain_stats(trans)) + static_branch_inc(&nft_counters_enabled); +} + +static void nf_tables_chain_free_chain_rules(struct nft_chain *chain) +{ + struct nft_rule_blob *g0 = rcu_dereference_raw(chain->blob_gen_0); + struct nft_rule_blob *g1 = rcu_dereference_raw(chain->blob_gen_1); + + if (g0 != g1) + kvfree(g1); + kvfree(g0); + + /* should be NULL either via abort or via successful commit */ + WARN_ON_ONCE(chain->blob_next); + kvfree(chain->blob_next); +} + +void nf_tables_chain_destroy(struct nft_ctx *ctx) +{ + struct nft_chain *chain = ctx->chain; + struct nft_hook *hook, *next; + + if (WARN_ON(chain->use > 0)) + return; + + /* no concurrent access possible anymore */ + nf_tables_chain_free_chain_rules(chain); + + if (nft_is_base_chain(chain)) { + struct nft_base_chain *basechain = nft_base_chain(chain); + + if (nft_base_chain_netdev(ctx->family, basechain->ops.hooknum)) { + list_for_each_entry_safe(hook, next, + &basechain->hook_list, list) { + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + } + } + module_put(basechain->type->owner); + if (rcu_access_pointer(basechain->stats)) { + static_branch_dec(&nft_counters_enabled); + free_percpu(rcu_dereference_raw(basechain->stats)); + } + kfree(chain->name); + kfree(chain->udata); + kfree(basechain); + } else { + kfree(chain->name); + kfree(chain->udata); + kfree(chain); + } +} + +static struct nft_hook *nft_netdev_hook_alloc(struct net *net, + const struct nlattr *attr) +{ + struct net_device *dev; + char ifname[IFNAMSIZ]; + struct nft_hook *hook; + int err; + + hook = kmalloc(sizeof(struct nft_hook), GFP_KERNEL_ACCOUNT); + if (!hook) { + err = -ENOMEM; + goto err_hook_alloc; + } + + nla_strscpy(ifname, attr, IFNAMSIZ); + /* nf_tables_netdev_event() is called under rtnl_mutex, this is + * indirectly serializing all the other holders of the commit_mutex with + * the rtnl_mutex. 
+ */ + dev = __dev_get_by_name(net, ifname); + if (!dev) { + err = -ENOENT; + goto err_hook_dev; + } + hook->ops.dev = dev; + + return hook; + +err_hook_dev: + kfree(hook); +err_hook_alloc: + return ERR_PTR(err); +} + +static struct nft_hook *nft_hook_list_find(struct list_head *hook_list, + const struct nft_hook *this) +{ + struct nft_hook *hook; + + list_for_each_entry(hook, hook_list, list) { + if (this->ops.dev == hook->ops.dev) + return hook; + } + + return NULL; +} + +static int nf_tables_parse_netdev_hooks(struct net *net, + const struct nlattr *attr, + struct list_head *hook_list) +{ + struct nft_hook *hook, *next; + const struct nlattr *tmp; + int rem, n = 0, err; + + nla_for_each_nested(tmp, attr, rem) { + if (nla_type(tmp) != NFTA_DEVICE_NAME) { + err = -EINVAL; + goto err_hook; + } + + hook = nft_netdev_hook_alloc(net, tmp); + if (IS_ERR(hook)) { + err = PTR_ERR(hook); + goto err_hook; + } + if (nft_hook_list_find(hook_list, hook)) { + kfree(hook); + err = -EEXIST; + goto err_hook; + } + list_add_tail(&hook->list, hook_list); + n++; + + if (n == NFT_NETDEVICE_MAX) { + err = -EFBIG; + goto err_hook; + } + } + + return 0; + +err_hook: + list_for_each_entry_safe(hook, next, hook_list, list) { + list_del(&hook->list); + kfree(hook); + } + return err; +} + +struct nft_chain_hook { + u32 num; + s32 priority; + const struct nft_chain_type *type; + struct list_head list; +}; + +static int nft_chain_parse_netdev(struct net *net, + struct nlattr *tb[], + struct list_head *hook_list) +{ + struct nft_hook *hook; + int err; + + if (tb[NFTA_HOOK_DEV]) { + hook = nft_netdev_hook_alloc(net, tb[NFTA_HOOK_DEV]); + if (IS_ERR(hook)) + return PTR_ERR(hook); + + list_add_tail(&hook->list, hook_list); + } else if (tb[NFTA_HOOK_DEVS]) { + err = nf_tables_parse_netdev_hooks(net, tb[NFTA_HOOK_DEVS], + hook_list); + if (err < 0) + return err; + + if (list_empty(hook_list)) + return -EINVAL; + } else { + return -EINVAL; + } + + return 0; +} + +static int nft_chain_parse_hook(struct net *net, + const struct nlattr * const nla[], + struct nft_chain_hook *hook, u8 family, + struct netlink_ext_ack *extack, bool autoload) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nlattr *ha[NFTA_HOOK_MAX + 1]; + const struct nft_chain_type *type; + int err; + + lockdep_assert_held(&nft_net->commit_mutex); + lockdep_nfnl_nft_mutex_not_held(); + + err = nla_parse_nested_deprecated(ha, NFTA_HOOK_MAX, + nla[NFTA_CHAIN_HOOK], + nft_hook_policy, NULL); + if (err < 0) + return err; + + if (ha[NFTA_HOOK_HOOKNUM] == NULL || + ha[NFTA_HOOK_PRIORITY] == NULL) + return -EINVAL; + + hook->num = ntohl(nla_get_be32(ha[NFTA_HOOK_HOOKNUM])); + hook->priority = ntohl(nla_get_be32(ha[NFTA_HOOK_PRIORITY])); + + type = __nft_chain_type_get(family, NFT_CHAIN_T_DEFAULT); + if (!type) + return -EOPNOTSUPP; + + if (nla[NFTA_CHAIN_TYPE]) { + type = nf_tables_chain_type_lookup(net, nla[NFTA_CHAIN_TYPE], + family, autoload); + if (IS_ERR(type)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_TYPE]); + return PTR_ERR(type); + } + } + if (hook->num >= NFT_MAX_HOOKS || !(type->hook_mask & (1 << hook->num))) + return -EOPNOTSUPP; + + if (type->type == NFT_CHAIN_T_NAT && + hook->priority <= NF_IP_PRI_CONNTRACK) + return -EOPNOTSUPP; + + if (!try_module_get(type->owner)) { + if (nla[NFTA_CHAIN_TYPE]) + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_TYPE]); + return -ENOENT; + } + + hook->type = type; + + INIT_LIST_HEAD(&hook->list); + if (nft_base_chain_netdev(family, hook->num)) { + err = nft_chain_parse_netdev(net, ha, &hook->list); + if (err < 0) 
{ + module_put(type->owner); + return err; + } + } else if (ha[NFTA_HOOK_DEV] || ha[NFTA_HOOK_DEVS]) { + module_put(type->owner); + return -EOPNOTSUPP; + } + + return 0; +} + +static void nft_chain_release_hook(struct nft_chain_hook *hook) +{ + struct nft_hook *h, *next; + + list_for_each_entry_safe(h, next, &hook->list, list) { + list_del(&h->list); + kfree(h); + } + module_put(hook->type->owner); +} + +struct nft_rules_old { + struct rcu_head h; + struct nft_rule_blob *blob; +}; + +static void nft_last_rule(struct nft_rule_blob *blob, const void *ptr) +{ + struct nft_rule_dp *prule; + + prule = (struct nft_rule_dp *)ptr; + prule->is_last = 1; + /* blob size does not include the trailer rule */ +} + +static struct nft_rule_blob *nf_tables_chain_alloc_rules(unsigned int size) +{ + struct nft_rule_blob *blob; + + /* size must include room for the last rule */ + if (size < offsetof(struct nft_rule_dp, data)) + return NULL; + + size += sizeof(struct nft_rule_blob) + sizeof(struct nft_rules_old); + if (size > INT_MAX) + return NULL; + + blob = kvmalloc(size, GFP_KERNEL_ACCOUNT); + if (!blob) + return NULL; + + blob->size = 0; + nft_last_rule(blob, blob->data); + + return blob; +} + +static void nft_basechain_hook_init(struct nf_hook_ops *ops, u8 family, + const struct nft_chain_hook *hook, + struct nft_chain *chain) +{ + ops->pf = family; + ops->hooknum = hook->num; + ops->priority = hook->priority; + ops->priv = chain; + ops->hook = hook->type->hooks[ops->hooknum]; + ops->hook_ops_type = NF_HOOK_OP_NF_TABLES; +} + +static int nft_basechain_init(struct nft_base_chain *basechain, u8 family, + struct nft_chain_hook *hook, u32 flags) +{ + struct nft_chain *chain; + struct nft_hook *h; + + basechain->type = hook->type; + INIT_LIST_HEAD(&basechain->hook_list); + chain = &basechain->chain; + + if (nft_base_chain_netdev(family, hook->num)) { + list_splice_init(&hook->list, &basechain->hook_list); + list_for_each_entry(h, &basechain->hook_list, list) + nft_basechain_hook_init(&h->ops, family, hook, chain); + + basechain->ops.hooknum = hook->num; + basechain->ops.priority = hook->priority; + } else { + nft_basechain_hook_init(&basechain->ops, family, hook, chain); + } + + chain->flags |= NFT_CHAIN_BASE | flags; + basechain->policy = NF_ACCEPT; + if (chain->flags & NFT_CHAIN_HW_OFFLOAD && + !nft_chain_offload_support(basechain)) { + list_splice_init(&basechain->hook_list, &hook->list); + return -EOPNOTSUPP; + } + + flow_block_init(&basechain->flow_block); + + return 0; +} + +int nft_chain_add(struct nft_table *table, struct nft_chain *chain) +{ + int err; + + err = rhltable_insert_key(&table->chains_ht, chain->name, + &chain->rhlhead, nft_chain_ht_params); + if (err) + return err; + + list_add_tail_rcu(&chain->list, &table->chains); + + return 0; +} + +static u64 chain_id; + +static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask, + u8 policy, u32 flags, + struct netlink_ext_ack *extack) +{ + const struct nlattr * const *nla = ctx->nla; + struct nft_table *table = ctx->table; + struct nft_base_chain *basechain; + struct net *net = ctx->net; + char name[NFT_NAME_MAXLEN]; + struct nft_rule_blob *blob; + struct nft_trans *trans; + struct nft_chain *chain; + unsigned int data_size; + int err; + + if (nla[NFTA_CHAIN_HOOK]) { + struct nft_stats __percpu *stats = NULL; + struct nft_chain_hook hook; + + if (flags & NFT_CHAIN_BINDING) + return -EOPNOTSUPP; + + err = nft_chain_parse_hook(net, nla, &hook, family, extack, + true); + if (err < 0) + return err; + + basechain = 
kzalloc(sizeof(*basechain), GFP_KERNEL_ACCOUNT); + if (basechain == NULL) { + nft_chain_release_hook(&hook); + return -ENOMEM; + } + chain = &basechain->chain; + + if (nla[NFTA_CHAIN_COUNTERS]) { + stats = nft_stats_alloc(nla[NFTA_CHAIN_COUNTERS]); + if (IS_ERR(stats)) { + nft_chain_release_hook(&hook); + kfree(basechain); + return PTR_ERR(stats); + } + rcu_assign_pointer(basechain->stats, stats); + } + + err = nft_basechain_init(basechain, family, &hook, flags); + if (err < 0) { + nft_chain_release_hook(&hook); + kfree(basechain); + free_percpu(stats); + return err; + } + if (stats) + static_branch_inc(&nft_counters_enabled); + } else { + if (flags & NFT_CHAIN_BASE) + return -EINVAL; + if (flags & NFT_CHAIN_HW_OFFLOAD) + return -EOPNOTSUPP; + + chain = kzalloc(sizeof(*chain), GFP_KERNEL_ACCOUNT); + if (chain == NULL) + return -ENOMEM; + + chain->flags = flags; + } + ctx->chain = chain; + + INIT_LIST_HEAD(&chain->rules); + chain->handle = nf_tables_alloc_handle(table); + chain->table = table; + + if (nla[NFTA_CHAIN_NAME]) { + chain->name = nla_strdup(nla[NFTA_CHAIN_NAME], GFP_KERNEL_ACCOUNT); + } else { + if (!(flags & NFT_CHAIN_BINDING)) { + err = -EINVAL; + goto err_destroy_chain; + } + + snprintf(name, sizeof(name), "__chain%llu", ++chain_id); + chain->name = kstrdup(name, GFP_KERNEL_ACCOUNT); + } + + if (!chain->name) { + err = -ENOMEM; + goto err_destroy_chain; + } + + if (nla[NFTA_CHAIN_USERDATA]) { + chain->udata = nla_memdup(nla[NFTA_CHAIN_USERDATA], GFP_KERNEL_ACCOUNT); + if (chain->udata == NULL) { + err = -ENOMEM; + goto err_destroy_chain; + } + chain->udlen = nla_len(nla[NFTA_CHAIN_USERDATA]); + } + + data_size = offsetof(struct nft_rule_dp, data); /* last rule */ + blob = nf_tables_chain_alloc_rules(data_size); + if (!blob) { + err = -ENOMEM; + goto err_destroy_chain; + } + + RCU_INIT_POINTER(chain->blob_gen_0, blob); + RCU_INIT_POINTER(chain->blob_gen_1, blob); + + err = nf_tables_register_hook(net, table, chain); + if (err < 0) + goto err_destroy_chain; + + if (!nft_use_inc(&table->use)) { + err = -EMFILE; + goto err_use; + } + + trans = nft_trans_chain_add(ctx, NFT_MSG_NEWCHAIN); + if (IS_ERR(trans)) { + err = PTR_ERR(trans); + goto err_unregister_hook; + } + + nft_trans_chain_policy(trans) = NFT_CHAIN_POLICY_UNSET; + if (nft_is_base_chain(chain)) + nft_trans_chain_policy(trans) = policy; + + err = nft_chain_add(table, chain); + if (err < 0) { + nft_trans_destroy(trans); + goto err_unregister_hook; + } + + return 0; + +err_unregister_hook: + nft_use_dec_restore(&table->use); +err_use: + nf_tables_unregister_hook(net, table, chain); +err_destroy_chain: + nf_tables_chain_destroy(ctx); + + return err; +} + +static bool nft_hook_list_equal(struct list_head *hook_list1, + struct list_head *hook_list2) +{ + struct nft_hook *hook; + int n = 0, m = 0; + + n = 0; + list_for_each_entry(hook, hook_list2, list) { + if (!nft_hook_list_find(hook_list1, hook)) + return false; + + n++; + } + list_for_each_entry(hook, hook_list1, list) + m++; + + return n == m; +} + +static int nf_tables_updchain(struct nft_ctx *ctx, u8 genmask, u8 policy, + u32 flags, const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const struct nlattr * const *nla = ctx->nla; + struct nft_table *table = ctx->table; + struct nft_chain *chain = ctx->chain; + struct nft_base_chain *basechain; + struct nft_stats *stats = NULL; + struct nft_chain_hook hook; + struct nf_hook_ops *ops; + struct nft_trans *trans; + int err; + + if (chain->flags ^ flags) + return -EOPNOTSUPP; + + if (nla[NFTA_CHAIN_HOOK]) { + if 
(!nft_is_base_chain(chain)) { + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + err = nft_chain_parse_hook(ctx->net, nla, &hook, ctx->family, + extack, false); + if (err < 0) + return err; + + basechain = nft_base_chain(chain); + if (basechain->type != hook.type) { + nft_chain_release_hook(&hook); + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + + if (nft_base_chain_netdev(ctx->family, hook.num)) { + if (!nft_hook_list_equal(&basechain->hook_list, + &hook.list)) { + nft_chain_release_hook(&hook); + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + } else { + ops = &basechain->ops; + if (ops->hooknum != hook.num || + ops->priority != hook.priority) { + nft_chain_release_hook(&hook); + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + } + nft_chain_release_hook(&hook); + } + + if (nla[NFTA_CHAIN_HANDLE] && + nla[NFTA_CHAIN_NAME]) { + struct nft_chain *chain2; + + chain2 = nft_chain_lookup(ctx->net, table, + nla[NFTA_CHAIN_NAME], genmask); + if (!IS_ERR(chain2)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_NAME]); + return -EEXIST; + } + } + + if (nla[NFTA_CHAIN_COUNTERS]) { + if (!nft_is_base_chain(chain)) + return -EOPNOTSUPP; + + stats = nft_stats_alloc(nla[NFTA_CHAIN_COUNTERS]); + if (IS_ERR(stats)) + return PTR_ERR(stats); + } + + err = -ENOMEM; + trans = nft_trans_alloc(ctx, NFT_MSG_NEWCHAIN, + sizeof(struct nft_trans_chain)); + if (trans == NULL) + goto err; + + nft_trans_chain_stats(trans) = stats; + nft_trans_chain_update(trans) = true; + + if (nla[NFTA_CHAIN_POLICY]) + nft_trans_chain_policy(trans) = policy; + else + nft_trans_chain_policy(trans) = -1; + + if (nla[NFTA_CHAIN_HANDLE] && + nla[NFTA_CHAIN_NAME]) { + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + struct nft_trans *tmp; + char *name; + + err = -ENOMEM; + name = nla_strdup(nla[NFTA_CHAIN_NAME], GFP_KERNEL_ACCOUNT); + if (!name) + goto err; + + err = -EEXIST; + list_for_each_entry(tmp, &nft_net->commit_list, list) { + if (tmp->msg_type == NFT_MSG_NEWCHAIN && + tmp->ctx.table == table && + nft_trans_chain_update(tmp) && + nft_trans_chain_name(tmp) && + strcmp(name, nft_trans_chain_name(tmp)) == 0) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_NAME]); + kfree(name); + goto err; + } + } + + nft_trans_chain_name(trans) = name; + } + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +err: + free_percpu(stats); + kfree(trans); + return err; +} + +static struct nft_chain *nft_chain_lookup_byid(const struct net *net, + const struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + u32 id = ntohl(nla_get_be32(nla)); + struct nft_trans *trans; + + list_for_each_entry(trans, &nft_net->commit_list, list) { + struct nft_chain *chain = trans->ctx.chain; + + if (trans->msg_type == NFT_MSG_NEWCHAIN && + chain->table == table && + id == nft_trans_chain_id(trans) && + nft_active_genmask(chain, genmask)) + return chain; + } + return ERR_PTR(-ENOENT); +} + +static int nf_tables_newchain(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct nftables_pernet *nft_net = nft_pernet(info->net); + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct nft_chain *chain = NULL; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + u8 policy = NF_ACCEPT; + struct nft_ctx ctx; + u64 handle = 0; + u32 flags = 0; + + lockdep_assert_held(&nft_net->commit_mutex); + + table = 
nft_table_lookup(net, nla[NFTA_CHAIN_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_TABLE]); + return PTR_ERR(table); + } + + chain = NULL; + attr = nla[NFTA_CHAIN_NAME]; + + if (nla[NFTA_CHAIN_HANDLE]) { + handle = be64_to_cpu(nla_get_be64(nla[NFTA_CHAIN_HANDLE])); + chain = nft_chain_lookup_byhandle(table, handle, genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_HANDLE]); + return PTR_ERR(chain); + } + attr = nla[NFTA_CHAIN_HANDLE]; + } else if (nla[NFTA_CHAIN_NAME]) { + chain = nft_chain_lookup(net, table, attr, genmask); + if (IS_ERR(chain)) { + if (PTR_ERR(chain) != -ENOENT) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(chain); + } + chain = NULL; + } + } else if (!nla[NFTA_CHAIN_ID]) { + return -EINVAL; + } + + if (nla[NFTA_CHAIN_POLICY]) { + if (chain != NULL && + !nft_is_base_chain(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_POLICY]); + return -EOPNOTSUPP; + } + + if (chain == NULL && + nla[NFTA_CHAIN_HOOK] == NULL) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_POLICY]); + return -EOPNOTSUPP; + } + + policy = ntohl(nla_get_be32(nla[NFTA_CHAIN_POLICY])); + switch (policy) { + case NF_DROP: + case NF_ACCEPT: + break; + default: + return -EINVAL; + } + } + + if (nla[NFTA_CHAIN_FLAGS]) + flags = ntohl(nla_get_be32(nla[NFTA_CHAIN_FLAGS])); + else if (chain) + flags = chain->flags; + + if (flags & ~NFT_CHAIN_FLAGS) + return -EOPNOTSUPP; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, chain, nla); + + if (chain != NULL) { + if (chain->flags & NFT_CHAIN_BINDING) + return -EINVAL; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, attr); + return -EEXIST; + } + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) + return -EOPNOTSUPP; + + flags |= chain->flags & NFT_CHAIN_BASE; + return nf_tables_updchain(&ctx, genmask, policy, flags, attr, + extack); + } + + return nf_tables_addchain(&ctx, family, genmask, policy, flags, extack); +} + +static int nf_tables_delchain(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_chain *chain; + struct nft_rule *rule; + struct nft_ctx ctx; + u64 handle; + u32 use; + int err; + + table = nft_table_lookup(net, nla[NFTA_CHAIN_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_CHAIN_TABLE]); + return PTR_ERR(table); + } + + if (nla[NFTA_CHAIN_HANDLE]) { + attr = nla[NFTA_CHAIN_HANDLE]; + handle = be64_to_cpu(nla_get_be64(attr)); + chain = nft_chain_lookup_byhandle(table, handle, genmask); + } else { + attr = nla[NFTA_CHAIN_NAME]; + chain = nft_chain_lookup(net, table, attr, genmask); + } + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(chain); + } + + if (nft_chain_binding(chain)) + return -EOPNOTSUPP; + + if (info->nlh->nlmsg_flags & NLM_F_NONREC && + chain->use > 0) + return -EBUSY; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, chain, nla); + + use = chain->use; + list_for_each_entry(rule, &chain->rules, list) { + if (!nft_is_active_next(net, rule)) + continue; + use--; + + err = nft_delrule(&ctx, rule); + if (err < 0) + return err; + } + + /* There are rules and elements that are still holding references to us, + * we cannot do a recursive removal in this case. 
+ */ + if (use > 0) { + NL_SET_BAD_ATTR(extack, attr); + return -EBUSY; + } + + return nft_delchain(&ctx); +} + +/* + * Expressions + */ + +/** + * nft_register_expr - register nf_tables expr type + * @type: expr type + * + * Registers the expr type for use with nf_tables. Returns zero on + * success or a negative errno code otherwise. + */ +int nft_register_expr(struct nft_expr_type *type) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + if (type->family == NFPROTO_UNSPEC) + list_add_tail_rcu(&type->list, &nf_tables_expressions); + else + list_add_rcu(&type->list, &nf_tables_expressions); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); + return 0; +} +EXPORT_SYMBOL_GPL(nft_register_expr); + +/** + * nft_unregister_expr - unregister nf_tables expr type + * @type: expr type + * + * Unregisters the expr typefor use with nf_tables. + */ +void nft_unregister_expr(struct nft_expr_type *type) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + list_del_rcu(&type->list); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_unregister_expr); + +static const struct nft_expr_type *__nft_expr_type_get(u8 family, + struct nlattr *nla) +{ + const struct nft_expr_type *type, *candidate = NULL; + + list_for_each_entry(type, &nf_tables_expressions, list) { + if (!nla_strcmp(nla, type->name)) { + if (!type->family && !candidate) + candidate = type; + else if (type->family == family) + candidate = type; + } + } + return candidate; +} + +#ifdef CONFIG_MODULES +static int nft_expr_type_request_module(struct net *net, u8 family, + struct nlattr *nla) +{ + if (nft_request_module(net, "nft-expr-%u-%.*s", family, + nla_len(nla), (char *)nla_data(nla)) == -EAGAIN) + return -EAGAIN; + + return 0; +} +#endif + +static const struct nft_expr_type *nft_expr_type_get(struct net *net, + u8 family, + struct nlattr *nla) +{ + const struct nft_expr_type *type; + + if (nla == NULL) + return ERR_PTR(-EINVAL); + + type = __nft_expr_type_get(family, nla); + if (type != NULL && try_module_get(type->owner)) + return type; + + lockdep_nfnl_nft_mutex_not_held(); +#ifdef CONFIG_MODULES + if (type == NULL) { + if (nft_expr_type_request_module(net, family, nla) == -EAGAIN) + return ERR_PTR(-EAGAIN); + + if (nft_request_module(net, "nft-expr-%.*s", + nla_len(nla), + (char *)nla_data(nla)) == -EAGAIN) + return ERR_PTR(-EAGAIN); + } +#endif + return ERR_PTR(-ENOENT); +} + +static const struct nla_policy nft_expr_policy[NFTA_EXPR_MAX + 1] = { + [NFTA_EXPR_NAME] = { .type = NLA_STRING, + .len = NFT_MODULE_AUTOLOAD_LIMIT }, + [NFTA_EXPR_DATA] = { .type = NLA_NESTED }, +}; + +static int nf_tables_fill_expr_info(struct sk_buff *skb, + const struct nft_expr *expr) +{ + if (nla_put_string(skb, NFTA_EXPR_NAME, expr->ops->type->name)) + goto nla_put_failure; + + if (expr->ops->dump) { + struct nlattr *data = nla_nest_start_noflag(skb, + NFTA_EXPR_DATA); + if (data == NULL) + goto nla_put_failure; + if (expr->ops->dump(skb, expr) < 0) + goto nla_put_failure; + nla_nest_end(skb, data); + } + + return skb->len; + +nla_put_failure: + return -1; +}; + +int nft_expr_dump(struct sk_buff *skb, unsigned int attr, + const struct nft_expr *expr) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, attr); + if (!nest) + goto nla_put_failure; + if (nf_tables_fill_expr_info(skb, expr) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + return -1; +} + +struct nft_expr_info { + const struct nft_expr_ops *ops; + const struct nlattr *attr; + struct nlattr *tb[NFT_EXPR_MAXATTR + 1]; +}; + +static int nf_tables_expr_parse(const struct 
nft_ctx *ctx, + const struct nlattr *nla, + struct nft_expr_info *info) +{ + const struct nft_expr_type *type; + const struct nft_expr_ops *ops; + struct nlattr *tb[NFTA_EXPR_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_EXPR_MAX, nla, + nft_expr_policy, NULL); + if (err < 0) + return err; + + type = nft_expr_type_get(ctx->net, ctx->family, tb[NFTA_EXPR_NAME]); + if (IS_ERR(type)) + return PTR_ERR(type); + + if (tb[NFTA_EXPR_DATA]) { + err = nla_parse_nested_deprecated(info->tb, type->maxattr, + tb[NFTA_EXPR_DATA], + type->policy, NULL); + if (err < 0) + goto err1; + } else + memset(info->tb, 0, sizeof(info->tb[0]) * (type->maxattr + 1)); + + if (type->select_ops != NULL) { + ops = type->select_ops(ctx, + (const struct nlattr * const *)info->tb); + if (IS_ERR(ops)) { + err = PTR_ERR(ops); +#ifdef CONFIG_MODULES + if (err == -EAGAIN) + if (nft_expr_type_request_module(ctx->net, + ctx->family, + tb[NFTA_EXPR_NAME]) != -EAGAIN) + err = -ENOENT; +#endif + goto err1; + } + } else + ops = type->ops; + + info->attr = nla; + info->ops = ops; + + return 0; + +err1: + module_put(type->owner); + return err; +} + +static int nf_tables_newexpr(const struct nft_ctx *ctx, + const struct nft_expr_info *expr_info, + struct nft_expr *expr) +{ + const struct nft_expr_ops *ops = expr_info->ops; + int err; + + expr->ops = ops; + if (ops->init) { + err = ops->init(ctx, expr, (const struct nlattr **)expr_info->tb); + if (err < 0) + goto err1; + } + + return 0; +err1: + expr->ops = NULL; + return err; +} + +static void nf_tables_expr_destroy(const struct nft_ctx *ctx, + struct nft_expr *expr) +{ + const struct nft_expr_type *type = expr->ops->type; + + if (expr->ops->destroy) + expr->ops->destroy(ctx, expr); + module_put(type->owner); +} + +static struct nft_expr *nft_expr_init(const struct nft_ctx *ctx, + const struct nlattr *nla) +{ + struct nft_expr_info expr_info; + struct nft_expr *expr; + struct module *owner; + int err; + + err = nf_tables_expr_parse(ctx, nla, &expr_info); + if (err < 0) + goto err_expr_parse; + + err = -EOPNOTSUPP; + if (!(expr_info.ops->type->flags & NFT_EXPR_STATEFUL)) + goto err_expr_stateful; + + err = -ENOMEM; + expr = kzalloc(expr_info.ops->size, GFP_KERNEL_ACCOUNT); + if (expr == NULL) + goto err_expr_stateful; + + err = nf_tables_newexpr(ctx, &expr_info, expr); + if (err < 0) + goto err_expr_new; + + return expr; +err_expr_new: + kfree(expr); +err_expr_stateful: + owner = expr_info.ops->type->owner; + if (expr_info.ops->type->release_ops) + expr_info.ops->type->release_ops(expr_info.ops); + + module_put(owner); +err_expr_parse: + return ERR_PTR(err); +} + +int nft_expr_clone(struct nft_expr *dst, struct nft_expr *src) +{ + int err; + + if (src->ops->clone) { + dst->ops = src->ops; + err = src->ops->clone(dst, src); + if (err < 0) + return err; + } else { + memcpy(dst, src, src->ops->size); + } + + __module_get(src->ops->type->owner); + + return 0; +} + +void nft_expr_destroy(const struct nft_ctx *ctx, struct nft_expr *expr) +{ + nf_tables_expr_destroy(ctx, expr); + kfree(expr); +} + +/* + * Rules + */ + +static struct nft_rule *__nft_rule_lookup(const struct nft_chain *chain, + u64 handle) +{ + struct nft_rule *rule; + + // FIXME: this sucks + list_for_each_entry_rcu(rule, &chain->rules, list) { + if (handle == rule->handle) + return rule; + } + + return ERR_PTR(-ENOENT); +} + +static struct nft_rule *nft_rule_lookup(const struct nft_chain *chain, + const struct nlattr *nla) +{ + if (nla == NULL) + return ERR_PTR(-EINVAL); + + return __nft_rule_lookup(chain, 
be64_to_cpu(nla_get_be64(nla))); +} + +static const struct nla_policy nft_rule_policy[NFTA_RULE_MAX + 1] = { + [NFTA_RULE_TABLE] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_RULE_CHAIN] = { .type = NLA_STRING, + .len = NFT_CHAIN_MAXNAMELEN - 1 }, + [NFTA_RULE_HANDLE] = { .type = NLA_U64 }, + [NFTA_RULE_EXPRESSIONS] = { .type = NLA_NESTED }, + [NFTA_RULE_COMPAT] = { .type = NLA_NESTED }, + [NFTA_RULE_POSITION] = { .type = NLA_U64 }, + [NFTA_RULE_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN }, + [NFTA_RULE_ID] = { .type = NLA_U32 }, + [NFTA_RULE_POSITION_ID] = { .type = NLA_U32 }, + [NFTA_RULE_CHAIN_ID] = { .type = NLA_U32 }, +}; + +static int nf_tables_fill_rule_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, int event, + u32 flags, int family, + const struct nft_table *table, + const struct nft_chain *chain, + const struct nft_rule *rule, u64 handle) +{ + struct nlmsghdr *nlh; + const struct nft_expr *expr, *next; + struct nlattr *list; + u16 type = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + + nlh = nfnl_msg_put(skb, portid, seq, type, flags, family, NFNETLINK_V0, + nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_RULE_TABLE, table->name)) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_RULE_CHAIN, chain->name)) + goto nla_put_failure; + if (nla_put_be64(skb, NFTA_RULE_HANDLE, cpu_to_be64(rule->handle), + NFTA_RULE_PAD)) + goto nla_put_failure; + + if (event != NFT_MSG_DELRULE && handle) { + if (nla_put_be64(skb, NFTA_RULE_POSITION, cpu_to_be64(handle), + NFTA_RULE_PAD)) + goto nla_put_failure; + } + + if (chain->flags & NFT_CHAIN_HW_OFFLOAD) + nft_flow_rule_stats(chain, rule); + + list = nla_nest_start_noflag(skb, NFTA_RULE_EXPRESSIONS); + if (list == NULL) + goto nla_put_failure; + nft_rule_for_each_expr(expr, next, rule) { + if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr) < 0) + goto nla_put_failure; + } + nla_nest_end(skb, list); + + if (rule->udata) { + struct nft_userdata *udata = nft_userdata(rule); + if (nla_put(skb, NFTA_RULE_USERDATA, udata->len + 1, + udata->data) < 0) + goto nla_put_failure; + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +static void nf_tables_rule_notify(const struct nft_ctx *ctx, + const struct nft_rule *rule, int event) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + const struct nft_rule *prule; + struct sk_buff *skb; + u64 handle = 0; + u16 flags = 0; + int err; + + if (!ctx->report && + !nfnetlink_has_listeners(ctx->net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb == NULL) + goto err; + + if (event == NFT_MSG_NEWRULE && + !list_is_first(&rule->list, &ctx->chain->rules) && + !list_is_last(&rule->list, &ctx->chain->rules)) { + prule = list_prev_entry(rule, list); + handle = prule->handle; + } + if (ctx->flags & (NLM_F_APPEND | NLM_F_REPLACE)) + flags |= NLM_F_APPEND; + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_rule_info(skb, ctx->net, ctx->portid, ctx->seq, + event, flags, ctx->family, ctx->table, + ctx->chain, rule, handle); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(ctx->net, ctx->portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +struct nft_rule_dump_ctx { + char *table; + char *chain; +}; + +static int __nf_tables_dump_rules(struct sk_buff *skb, + unsigned 
int *idx, + struct netlink_callback *cb, + const struct nft_table *table, + const struct nft_chain *chain) +{ + struct net *net = sock_net(skb->sk); + const struct nft_rule *rule, *prule; + unsigned int s_idx = cb->args[0]; + u64 handle; + + prule = NULL; + list_for_each_entry_rcu(rule, &chain->rules, list) { + if (!nft_is_active(net, rule)) + goto cont_skip; + if (*idx < s_idx) + goto cont; + if (prule) + handle = prule->handle; + else + handle = 0; + + if (nf_tables_fill_rule_info(skb, net, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFT_MSG_NEWRULE, + NLM_F_MULTI | NLM_F_APPEND, + table->family, + table, chain, rule, handle) < 0) + return 1; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + prule = rule; +cont_skip: + (*idx)++; + } + return 0; +} + +static int nf_tables_dump_rules(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + const struct nft_rule_dump_ctx *ctx = cb->data; + struct nft_table *table; + const struct nft_chain *chain; + unsigned int idx = 0; + struct net *net = sock_net(skb->sk); + int family = nfmsg->nfgen_family; + struct nftables_pernet *nft_net; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (family != NFPROTO_UNSPEC && family != table->family) + continue; + + if (ctx && ctx->table && strcmp(ctx->table, table->name) != 0) + continue; + + if (ctx && ctx->table && ctx->chain) { + struct rhlist_head *list, *tmp; + + list = rhltable_lookup(&table->chains_ht, ctx->chain, + nft_chain_ht_params); + if (!list) + goto done; + + rhl_for_each_entry_rcu(chain, tmp, list, rhlhead) { + if (!nft_is_active(net, chain)) + continue; + __nf_tables_dump_rules(skb, &idx, + cb, table, chain); + break; + } + goto done; + } + + list_for_each_entry_rcu(chain, &table->chains, list) { + if (__nf_tables_dump_rules(skb, &idx, cb, table, chain)) + goto done; + } + + if (ctx && ctx->table) + break; + } +done: + rcu_read_unlock(); + + cb->args[0] = idx; + return skb->len; +} + +static int nf_tables_dump_rules_start(struct netlink_callback *cb) +{ + const struct nlattr * const *nla = cb->data; + struct nft_rule_dump_ctx *ctx = NULL; + + if (nla[NFTA_RULE_TABLE] || nla[NFTA_RULE_CHAIN]) { + ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); + if (!ctx) + return -ENOMEM; + + if (nla[NFTA_RULE_TABLE]) { + ctx->table = nla_strdup(nla[NFTA_RULE_TABLE], + GFP_ATOMIC); + if (!ctx->table) { + kfree(ctx); + return -ENOMEM; + } + } + if (nla[NFTA_RULE_CHAIN]) { + ctx->chain = nla_strdup(nla[NFTA_RULE_CHAIN], + GFP_ATOMIC); + if (!ctx->chain) { + kfree(ctx->table); + kfree(ctx); + return -ENOMEM; + } + } + } + + cb->data = ctx; + return 0; +} + +static int nf_tables_dump_rules_done(struct netlink_callback *cb) +{ + struct nft_rule_dump_ctx *ctx = cb->data; + + if (ctx) { + kfree(ctx->table); + kfree(ctx->chain); + kfree(ctx); + } + return 0; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getrule(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_chain *chain; + const struct nft_rule *rule; + struct net *net = info->net; + struct nft_table *table; + struct sk_buff *skb2; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start= nf_tables_dump_rules_start, + .dump = 
nf_tables_dump_rules, + .done = nf_tables_dump_rules_done, + .module = THIS_MODULE, + .data = (void *)nla, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + table = nft_table_lookup(net, nla[NFTA_RULE_TABLE], family, genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_TABLE]); + return PTR_ERR(table); + } + + chain = nft_chain_lookup(net, table, nla[NFTA_RULE_CHAIN], genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_CHAIN]); + return PTR_ERR(chain); + } + + rule = nft_rule_lookup(chain, nla[NFTA_RULE_HANDLE]); + if (IS_ERR(rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_HANDLE]); + return PTR_ERR(rule); + } + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb2) + return -ENOMEM; + + err = nf_tables_fill_rule_info(skb2, net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, NFT_MSG_NEWRULE, 0, + family, table, chain, rule, 0); + if (err < 0) + goto err_fill_rule_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_rule_info: + kfree_skb(skb2); + return err; +} + +void nf_tables_rule_destroy(const struct nft_ctx *ctx, struct nft_rule *rule) +{ + struct nft_expr *expr, *next; + + /* + * Careful: some expressions might not be initialized in case this + * is called on error from nf_tables_newrule(). + */ + expr = nft_expr_first(rule); + while (nft_expr_more(rule, expr)) { + next = nft_expr_next(expr); + nf_tables_expr_destroy(ctx, expr); + expr = next; + } + kfree(rule); +} + +static void nf_tables_rule_release(const struct nft_ctx *ctx, struct nft_rule *rule) +{ + nft_rule_expr_deactivate(ctx, rule, NFT_TRANS_RELEASE); + nf_tables_rule_destroy(ctx, rule); +} + +int nft_chain_validate(const struct nft_ctx *ctx, const struct nft_chain *chain) +{ + struct nft_expr *expr, *last; + const struct nft_data *data; + struct nft_rule *rule; + int err; + + if (ctx->level == NFT_JUMP_STACK_SIZE) + return -EMLINK; + + list_for_each_entry(rule, &chain->rules, list) { + if (!nft_is_active_next(ctx->net, rule)) + continue; + + nft_rule_for_each_expr(expr, last, rule) { + if (!expr->ops->validate) + continue; + + err = expr->ops->validate(ctx, expr, &data); + if (err < 0) + return err; + } + } + + return 0; +} +EXPORT_SYMBOL_GPL(nft_chain_validate); + +static int nft_table_validate(struct net *net, const struct nft_table *table) +{ + struct nft_chain *chain; + struct nft_ctx ctx = { + .net = net, + .family = table->family, + }; + int err; + + list_for_each_entry(chain, &table->chains, list) { + if (!nft_is_base_chain(chain)) + continue; + + ctx.chain = chain; + err = nft_chain_validate(&ctx, chain); + if (err < 0) + return err; + + cond_resched(); + } + + return 0; +} + +int nft_setelem_validate(const struct nft_ctx *ctx, struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + struct nft_ctx *pctx = (struct nft_ctx *)ctx; + const struct nft_data *data; + int err; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_FLAGS) && + *nft_set_ext_flags(ext) & NFT_SET_ELEM_INTERVAL_END) + return 0; + + data = nft_set_ext_data(ext); + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + pctx->level++; + err = nft_chain_validate(ctx, data->verdict.chain); + if (err < 0) + return err; + pctx->level--; + break; + default: + break; + } + + return 0; +} + +int nft_set_catchall_validate(const struct nft_ctx *ctx, struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct 
nft_set_elem_catchall *catchall; + struct nft_set_elem elem; + struct nft_set_ext *ext; + int ret = 0; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + elem.priv = catchall->elem; + ret = nft_setelem_validate(ctx, set, NULL, &elem); + if (ret < 0) + return ret; + } + + return ret; +} + +static struct nft_rule *nft_rule_lookup_byid(const struct net *net, + const struct nft_chain *chain, + const struct nlattr *nla); + +#define NFT_RULE_MAXEXPRS 128 + +static int nf_tables_newrule(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct nftables_pernet *nft_net = nft_pernet(info->net); + struct netlink_ext_ack *extack = info->extack; + unsigned int size, i, n, ulen = 0, usize = 0; + u8 genmask = nft_genmask_next(info->net); + struct nft_rule *rule, *old_rule = NULL; + struct nft_expr_info *expr_info = NULL; + u8 family = info->nfmsg->nfgen_family; + struct nft_flow_rule *flow = NULL; + struct net *net = info->net; + struct nft_userdata *udata; + struct nft_table *table; + struct nft_chain *chain; + struct nft_trans *trans; + u64 handle, pos_handle; + struct nft_expr *expr; + struct nft_ctx ctx; + struct nlattr *tmp; + int err, rem; + + lockdep_assert_held(&nft_net->commit_mutex); + + table = nft_table_lookup(net, nla[NFTA_RULE_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_TABLE]); + return PTR_ERR(table); + } + + if (nla[NFTA_RULE_CHAIN]) { + chain = nft_chain_lookup(net, table, nla[NFTA_RULE_CHAIN], + genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_CHAIN]); + return PTR_ERR(chain); + } + + } else if (nla[NFTA_RULE_CHAIN_ID]) { + chain = nft_chain_lookup_byid(net, table, nla[NFTA_RULE_CHAIN_ID], + genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_CHAIN_ID]); + return PTR_ERR(chain); + } + } else { + return -EINVAL; + } + + if (nft_chain_is_bound(chain)) + return -EOPNOTSUPP; + + if (nla[NFTA_RULE_HANDLE]) { + handle = be64_to_cpu(nla_get_be64(nla[NFTA_RULE_HANDLE])); + rule = __nft_rule_lookup(chain, handle); + if (IS_ERR(rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_HANDLE]); + return PTR_ERR(rule); + } + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_HANDLE]); + return -EEXIST; + } + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) + old_rule = rule; + else + return -EOPNOTSUPP; + } else { + if (!(info->nlh->nlmsg_flags & NLM_F_CREATE) || + info->nlh->nlmsg_flags & NLM_F_REPLACE) + return -EINVAL; + handle = nf_tables_alloc_handle(table); + + if (nla[NFTA_RULE_POSITION]) { + pos_handle = be64_to_cpu(nla_get_be64(nla[NFTA_RULE_POSITION])); + old_rule = __nft_rule_lookup(chain, pos_handle); + if (IS_ERR(old_rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_POSITION]); + return PTR_ERR(old_rule); + } + } else if (nla[NFTA_RULE_POSITION_ID]) { + old_rule = nft_rule_lookup_byid(net, chain, nla[NFTA_RULE_POSITION_ID]); + if (IS_ERR(old_rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_POSITION_ID]); + return PTR_ERR(old_rule); + } + } + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, chain, nla); + + n = 0; + size = 0; + if (nla[NFTA_RULE_EXPRESSIONS]) { + expr_info = kvmalloc_array(NFT_RULE_MAXEXPRS, + sizeof(struct nft_expr_info), + GFP_KERNEL); + if (!expr_info) + return -ENOMEM; + + nla_for_each_nested(tmp, nla[NFTA_RULE_EXPRESSIONS], rem) { + err = -EINVAL; + if 
(nla_type(tmp) != NFTA_LIST_ELEM) + goto err_release_expr; + if (n == NFT_RULE_MAXEXPRS) + goto err_release_expr; + err = nf_tables_expr_parse(&ctx, tmp, &expr_info[n]); + if (err < 0) { + NL_SET_BAD_ATTR(extack, tmp); + goto err_release_expr; + } + size += expr_info[n].ops->size; + n++; + } + } + /* Check for overflow of dlen field */ + err = -EFBIG; + if (size >= 1 << 12) + goto err_release_expr; + + if (nla[NFTA_RULE_USERDATA]) { + ulen = nla_len(nla[NFTA_RULE_USERDATA]); + if (ulen > 0) + usize = sizeof(struct nft_userdata) + ulen; + } + + err = -ENOMEM; + rule = kzalloc(sizeof(*rule) + size + usize, GFP_KERNEL_ACCOUNT); + if (rule == NULL) + goto err_release_expr; + + nft_activate_next(net, rule); + + rule->handle = handle; + rule->dlen = size; + rule->udata = ulen ? 1 : 0; + + if (ulen) { + udata = nft_userdata(rule); + udata->len = ulen - 1; + nla_memcpy(udata->data, nla[NFTA_RULE_USERDATA], ulen); + } + + expr = nft_expr_first(rule); + for (i = 0; i < n; i++) { + err = nf_tables_newexpr(&ctx, &expr_info[i], expr); + if (err < 0) { + NL_SET_BAD_ATTR(extack, expr_info[i].attr); + goto err_release_rule; + } + + if (expr_info[i].ops->validate) + nft_validate_state_update(net, NFT_VALIDATE_NEED); + + expr_info[i].ops = NULL; + expr = nft_expr_next(expr); + } + + if (chain->flags & NFT_CHAIN_HW_OFFLOAD) { + flow = nft_flow_rule_create(net, rule); + if (IS_ERR(flow)) { + err = PTR_ERR(flow); + goto err_release_rule; + } + } + + if (!nft_use_inc(&chain->use)) { + err = -EMFILE; + goto err_release_rule; + } + + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) { + if (nft_chain_binding(chain)) { + err = -EOPNOTSUPP; + goto err_destroy_flow_rule; + } + + err = nft_delrule(&ctx, old_rule); + if (err < 0) + goto err_destroy_flow_rule; + + trans = nft_trans_rule_add(&ctx, NFT_MSG_NEWRULE, rule); + if (trans == NULL) { + err = -ENOMEM; + goto err_destroy_flow_rule; + } + list_add_tail_rcu(&rule->list, &old_rule->list); + } else { + trans = nft_trans_rule_add(&ctx, NFT_MSG_NEWRULE, rule); + if (!trans) { + err = -ENOMEM; + goto err_destroy_flow_rule; + } + + if (info->nlh->nlmsg_flags & NLM_F_APPEND) { + if (old_rule) + list_add_rcu(&rule->list, &old_rule->list); + else + list_add_tail_rcu(&rule->list, &chain->rules); + } else { + if (old_rule) + list_add_tail_rcu(&rule->list, &old_rule->list); + else + list_add_rcu(&rule->list, &chain->rules); + } + } + kvfree(expr_info); + + if (flow) + nft_trans_flow_rule(trans) = flow; + + if (nft_net->validate_state == NFT_VALIDATE_DO) + return nft_table_validate(net, table); + + return 0; + +err_destroy_flow_rule: + nft_use_dec_restore(&chain->use); + if (flow) + nft_flow_rule_destroy(flow); +err_release_rule: + nft_rule_expr_deactivate(&ctx, rule, NFT_TRANS_PREPARE_ERROR); + nf_tables_rule_destroy(&ctx, rule); +err_release_expr: + for (i = 0; i < n; i++) { + if (expr_info[i].ops) { + module_put(expr_info[i].ops->type->owner); + if (expr_info[i].ops->type->release_ops) + expr_info[i].ops->type->release_ops(expr_info[i].ops); + } + } + kvfree(expr_info); + + return err; +} + +static struct nft_rule *nft_rule_lookup_byid(const struct net *net, + const struct nft_chain *chain, + const struct nlattr *nla) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + u32 id = ntohl(nla_get_be32(nla)); + struct nft_trans *trans; + + list_for_each_entry(trans, &nft_net->commit_list, list) { + if (trans->msg_type == NFT_MSG_NEWRULE && + trans->ctx.chain == chain && + id == nft_trans_rule_id(trans)) + return nft_trans_rule(trans); + } + return ERR_PTR(-ENOENT); +} + +static 
int nf_tables_delrule(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct nft_chain *chain = NULL; + struct net *net = info->net; + struct nft_table *table; + struct nft_rule *rule; + struct nft_ctx ctx; + int err = 0; + + table = nft_table_lookup(net, nla[NFTA_RULE_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_TABLE]); + return PTR_ERR(table); + } + + if (nla[NFTA_RULE_CHAIN]) { + chain = nft_chain_lookup(net, table, nla[NFTA_RULE_CHAIN], + genmask); + if (IS_ERR(chain)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_CHAIN]); + return PTR_ERR(chain); + } + if (nft_chain_binding(chain)) + return -EOPNOTSUPP; + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, chain, nla); + + if (chain) { + if (nla[NFTA_RULE_HANDLE]) { + rule = nft_rule_lookup(chain, nla[NFTA_RULE_HANDLE]); + if (IS_ERR(rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_HANDLE]); + return PTR_ERR(rule); + } + + err = nft_delrule(&ctx, rule); + } else if (nla[NFTA_RULE_ID]) { + rule = nft_rule_lookup_byid(net, chain, nla[NFTA_RULE_ID]); + if (IS_ERR(rule)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_RULE_ID]); + return PTR_ERR(rule); + } + + err = nft_delrule(&ctx, rule); + } else { + err = nft_delrule_by_chain(&ctx); + } + } else { + list_for_each_entry(chain, &table->chains, list) { + if (!nft_is_active_next(net, chain)) + continue; + if (nft_chain_binding(chain)) + continue; + + ctx.chain = chain; + err = nft_delrule_by_chain(&ctx); + if (err < 0) + break; + } + } + + return err; +} + +/* + * Sets + */ +static const struct nft_set_type *nft_set_types[] = { + &nft_set_hash_fast_type, + &nft_set_hash_type, + &nft_set_rhash_type, + &nft_set_bitmap_type, + &nft_set_rbtree_type, +#if defined(CONFIG_X86_64) && !defined(CONFIG_UML) + &nft_set_pipapo_avx2_type, +#endif + &nft_set_pipapo_type, +}; + +#define NFT_SET_FEATURES (NFT_SET_INTERVAL | NFT_SET_MAP | \ + NFT_SET_TIMEOUT | NFT_SET_OBJECT | \ + NFT_SET_EVAL) + +static bool nft_set_ops_candidate(const struct nft_set_type *type, u32 flags) +{ + return (flags & type->features) == (flags & NFT_SET_FEATURES); +} + +/* + * Select a set implementation based on the data characteristics and the + * given policy. The total memory use might not be known if no size is + * given, in that case the amount of memory per element is used. 
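+ * The performance policy picks the lowest estimated lookup cost, using + * memory use as the tie breaker; the memory policy picks the smallest + * estimated memory use.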
+ */ +static const struct nft_set_ops * +nft_select_set_ops(const struct nft_ctx *ctx, + const struct nlattr * const nla[], + const struct nft_set_desc *desc) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + const struct nft_set_ops *ops, *bops; + struct nft_set_estimate est, best; + const struct nft_set_type *type; + u32 flags = 0; + int i; + + lockdep_assert_held(&nft_net->commit_mutex); + lockdep_nfnl_nft_mutex_not_held(); + + if (nla[NFTA_SET_FLAGS] != NULL) + flags = ntohl(nla_get_be32(nla[NFTA_SET_FLAGS])); + + bops = NULL; + best.size = ~0; + best.lookup = ~0; + best.space = ~0; + + for (i = 0; i < ARRAY_SIZE(nft_set_types); i++) { + type = nft_set_types[i]; + ops = &type->ops; + + if (!nft_set_ops_candidate(type, flags)) + continue; + if (!ops->estimate(desc, flags, &est)) + continue; + + switch (desc->policy) { + case NFT_SET_POL_PERFORMANCE: + if (est.lookup < best.lookup) + break; + if (est.lookup == best.lookup && + est.space < best.space) + break; + continue; + case NFT_SET_POL_MEMORY: + if (!desc->size) { + if (est.space < best.space) + break; + if (est.space == best.space && + est.lookup < best.lookup) + break; + } else if (est.size < best.size || !bops) { + break; + } + continue; + default: + break; + } + + bops = ops; + best = est; + } + + if (bops != NULL) + return bops; + + return ERR_PTR(-EOPNOTSUPP); +} + +static const struct nla_policy nft_set_policy[NFTA_SET_MAX + 1] = { + [NFTA_SET_TABLE] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_SET_NAME] = { .type = NLA_STRING, + .len = NFT_SET_MAXNAMELEN - 1 }, + [NFTA_SET_FLAGS] = { .type = NLA_U32 }, + [NFTA_SET_KEY_TYPE] = { .type = NLA_U32 }, + [NFTA_SET_KEY_LEN] = { .type = NLA_U32 }, + [NFTA_SET_DATA_TYPE] = { .type = NLA_U32 }, + [NFTA_SET_DATA_LEN] = { .type = NLA_U32 }, + [NFTA_SET_POLICY] = { .type = NLA_U32 }, + [NFTA_SET_DESC] = { .type = NLA_NESTED }, + [NFTA_SET_ID] = { .type = NLA_U32 }, + [NFTA_SET_TIMEOUT] = { .type = NLA_U64 }, + [NFTA_SET_GC_INTERVAL] = { .type = NLA_U32 }, + [NFTA_SET_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN }, + [NFTA_SET_OBJ_TYPE] = { .type = NLA_U32 }, + [NFTA_SET_HANDLE] = { .type = NLA_U64 }, + [NFTA_SET_EXPR] = { .type = NLA_NESTED }, + [NFTA_SET_EXPRESSIONS] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy nft_set_desc_policy[NFTA_SET_DESC_MAX + 1] = { + [NFTA_SET_DESC_SIZE] = { .type = NLA_U32 }, + [NFTA_SET_DESC_CONCAT] = { .type = NLA_NESTED }, +}; + +static struct nft_set *nft_set_lookup(const struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + struct nft_set *set; + + if (nla == NULL) + return ERR_PTR(-EINVAL); + + list_for_each_entry_rcu(set, &table->sets, list) { + if (!nla_strcmp(nla, set->name) && + nft_active_genmask(set, genmask)) + return set; + } + return ERR_PTR(-ENOENT); +} + +static struct nft_set *nft_set_lookup_byhandle(const struct nft_table *table, + const struct nlattr *nla, + u8 genmask) +{ + struct nft_set *set; + + list_for_each_entry(set, &table->sets, list) { + if (be64_to_cpu(nla_get_be64(nla)) == set->handle && + nft_active_genmask(set, genmask)) + return set; + } + return ERR_PTR(-ENOENT); +} + +static struct nft_set *nft_set_lookup_byid(const struct net *net, + const struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + u32 id = ntohl(nla_get_be32(nla)); + struct nft_trans *trans; + + list_for_each_entry(trans, &nft_net->commit_list, list) { + if (trans->msg_type == NFT_MSG_NEWSET) { + struct 
nft_set *set = nft_trans_set(trans); + + if (id == nft_trans_set_id(trans) && + set->table == table && + nft_active_genmask(set, genmask)) + return set; + } + } + return ERR_PTR(-ENOENT); +} + +struct nft_set *nft_set_lookup_global(const struct net *net, + const struct nft_table *table, + const struct nlattr *nla_set_name, + const struct nlattr *nla_set_id, + u8 genmask) +{ + struct nft_set *set; + + set = nft_set_lookup(table, nla_set_name, genmask); + if (IS_ERR(set)) { + if (!nla_set_id) + return set; + + set = nft_set_lookup_byid(net, table, nla_set_id, genmask); + } + return set; +} +EXPORT_SYMBOL_GPL(nft_set_lookup_global); + +static int nf_tables_set_alloc_name(struct nft_ctx *ctx, struct nft_set *set, + const char *name) +{ + const struct nft_set *i; + const char *p; + unsigned long *inuse; + unsigned int n = 0, min = 0; + + p = strchr(name, '%'); + if (p != NULL) { + if (p[1] != 'd' || strchr(p + 2, '%')) + return -EINVAL; + + if (strnlen(name, NFT_SET_MAX_ANONLEN) >= NFT_SET_MAX_ANONLEN) + return -EINVAL; + + inuse = (unsigned long *)get_zeroed_page(GFP_KERNEL); + if (inuse == NULL) + return -ENOMEM; +cont: + list_for_each_entry(i, &ctx->table->sets, list) { + int tmp; + + if (!nft_is_active_next(ctx->net, i)) + continue; + if (!sscanf(i->name, name, &tmp)) + continue; + if (tmp < min || tmp >= min + BITS_PER_BYTE * PAGE_SIZE) + continue; + + set_bit(tmp - min, inuse); + } + + n = find_first_zero_bit(inuse, BITS_PER_BYTE * PAGE_SIZE); + if (n >= BITS_PER_BYTE * PAGE_SIZE) { + min += BITS_PER_BYTE * PAGE_SIZE; + memset(inuse, 0, PAGE_SIZE); + goto cont; + } + free_page((unsigned long)inuse); + } + + set->name = kasprintf(GFP_KERNEL_ACCOUNT, name, min + n); + if (!set->name) + return -ENOMEM; + + list_for_each_entry(i, &ctx->table->sets, list) { + if (!nft_is_active_next(ctx->net, i)) + continue; + if (!strcmp(set->name, i->name)) { + kfree(set->name); + set->name = NULL; + return -ENFILE; + } + } + return 0; +} + +int nf_msecs_to_jiffies64(const struct nlattr *nla, u64 *result) +{ + u64 ms = be64_to_cpu(nla_get_be64(nla)); + u64 max = (u64)(~((u64)0)); + + max = div_u64(max, NSEC_PER_MSEC); + if (ms >= max) + return -ERANGE; + + ms *= NSEC_PER_MSEC; + *result = nsecs_to_jiffies64(ms); + return 0; +} + +__be64 nf_jiffies64_to_msecs(u64 input) +{ + return cpu_to_be64(jiffies64_to_msecs(input)); +} + +static int nf_tables_fill_set_concat(struct sk_buff *skb, + const struct nft_set *set) +{ + struct nlattr *concat, *field; + int i; + + concat = nla_nest_start_noflag(skb, NFTA_SET_DESC_CONCAT); + if (!concat) + return -ENOMEM; + + for (i = 0; i < set->field_count; i++) { + field = nla_nest_start_noflag(skb, NFTA_LIST_ELEM); + if (!field) + return -ENOMEM; + + if (nla_put_be32(skb, NFTA_SET_FIELD_LEN, + htonl(set->field_len[i]))) + return -ENOMEM; + + nla_nest_end(skb, field); + } + + nla_nest_end(skb, concat); + + return 0; +} + +static int nf_tables_fill_set(struct sk_buff *skb, const struct nft_ctx *ctx, + const struct nft_set *set, u16 event, u16 flags) +{ + u64 timeout = READ_ONCE(set->timeout); + u32 gc_int = READ_ONCE(set->gc_int); + u32 portid = ctx->portid; + struct nlmsghdr *nlh; + struct nlattr *nest; + u32 seq = ctx->seq; + int i; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, ctx->family, + NFNETLINK_V0, nft_base_seq(ctx->net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_SET_TABLE, ctx->table->name)) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_SET_NAME, set->name)) + goto 
nla_put_failure; + if (nla_put_be64(skb, NFTA_SET_HANDLE, cpu_to_be64(set->handle), + NFTA_SET_PAD)) + goto nla_put_failure; + if (set->flags != 0) + if (nla_put_be32(skb, NFTA_SET_FLAGS, htonl(set->flags))) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_SET_KEY_TYPE, htonl(set->ktype))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_SET_KEY_LEN, htonl(set->klen))) + goto nla_put_failure; + if (set->flags & NFT_SET_MAP) { + if (nla_put_be32(skb, NFTA_SET_DATA_TYPE, htonl(set->dtype))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_SET_DATA_LEN, htonl(set->dlen))) + goto nla_put_failure; + } + if (set->flags & NFT_SET_OBJECT && + nla_put_be32(skb, NFTA_SET_OBJ_TYPE, htonl(set->objtype))) + goto nla_put_failure; + + if (timeout && + nla_put_be64(skb, NFTA_SET_TIMEOUT, + nf_jiffies64_to_msecs(timeout), + NFTA_SET_PAD)) + goto nla_put_failure; + if (gc_int && + nla_put_be32(skb, NFTA_SET_GC_INTERVAL, htonl(gc_int))) + goto nla_put_failure; + + if (set->policy != NFT_SET_POL_PERFORMANCE) { + if (nla_put_be32(skb, NFTA_SET_POLICY, htonl(set->policy))) + goto nla_put_failure; + } + + if (set->udata && + nla_put(skb, NFTA_SET_USERDATA, set->udlen, set->udata)) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, NFTA_SET_DESC); + if (!nest) + goto nla_put_failure; + if (set->size && + nla_put_be32(skb, NFTA_SET_DESC_SIZE, htonl(set->size))) + goto nla_put_failure; + + if (set->field_count > 1 && + nf_tables_fill_set_concat(skb, set)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (set->num_exprs == 1) { + nest = nla_nest_start_noflag(skb, NFTA_SET_EXPR); + if (nf_tables_fill_expr_info(skb, set->exprs[0]) < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest); + } else if (set->num_exprs > 1) { + nest = nla_nest_start_noflag(skb, NFTA_SET_EXPRESSIONS); + if (nest == NULL) + goto nla_put_failure; + + for (i = 0; i < set->num_exprs; i++) { + if (nft_expr_dump(skb, NFTA_LIST_ELEM, + set->exprs[i]) < 0) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +static void nf_tables_set_notify(const struct nft_ctx *ctx, + const struct nft_set *set, int event, + gfp_t gfp_flags) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + u32 portid = ctx->portid; + struct sk_buff *skb; + u16 flags = 0; + int err; + + if (!ctx->report && + !nfnetlink_has_listeners(ctx->net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, gfp_flags); + if (skb == NULL) + goto err; + + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_set(skb, ctx, set, event, flags); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(ctx->net, portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +static int nf_tables_dump_sets(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nft_set *set; + unsigned int idx, s_idx = cb->args[0]; + struct nft_table *table, *cur_table = (struct nft_table *)cb->args[2]; + struct net *net = sock_net(skb->sk); + struct nft_ctx *ctx = cb->data, ctx_set; + struct nftables_pernet *nft_net; + + if (cb->args[1]) + return skb->len; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (ctx->family != NFPROTO_UNSPEC && + ctx->family != table->family) + continue; + + if 
(ctx->table && ctx->table != table) + continue; + + if (cur_table) { + if (cur_table != table) + continue; + + cur_table = NULL; + } + idx = 0; + list_for_each_entry_rcu(set, &table->sets, list) { + if (idx < s_idx) + goto cont; + if (!nft_is_active(net, set)) + goto cont; + + ctx_set = *ctx; + ctx_set.table = table; + ctx_set.family = table->family; + + if (nf_tables_fill_set(skb, &ctx_set, set, + NFT_MSG_NEWSET, + NLM_F_MULTI) < 0) { + cb->args[0] = idx; + cb->args[2] = (unsigned long) table; + goto done; + } + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + if (s_idx) + s_idx = 0; + } + cb->args[1] = 1; +done: + rcu_read_unlock(); + return skb->len; +} + +static int nf_tables_dump_sets_start(struct netlink_callback *cb) +{ + struct nft_ctx *ctx_dump = NULL; + + ctx_dump = kmemdup(cb->data, sizeof(*ctx_dump), GFP_ATOMIC); + if (ctx_dump == NULL) + return -ENOMEM; + + cb->data = ctx_dump; + return 0; +} + +static int nf_tables_dump_sets_done(struct netlink_callback *cb) +{ + kfree(cb->data); + return 0; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getset(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + struct nft_table *table = NULL; + struct net *net = info->net; + const struct nft_set *set; + struct sk_buff *skb2; + struct nft_ctx ctx; + int err; + + if (nla[NFTA_SET_TABLE]) { + table = nft_table_lookup(net, nla[NFTA_SET_TABLE], family, + genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_TABLE]); + return PTR_ERR(table); + } + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = nf_tables_dump_sets_start, + .dump = nf_tables_dump_sets, + .done = nf_tables_dump_sets_done, + .data = &ctx, + .module = THIS_MODULE, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + /* Only accept unspec with dump */ + if (info->nfmsg->nfgen_family == NFPROTO_UNSPEC) + return -EAFNOSUPPORT; + if (!nla[NFTA_SET_TABLE]) + return -EINVAL; + + set = nft_set_lookup(table, nla[NFTA_SET_NAME], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (skb2 == NULL) + return -ENOMEM; + + err = nf_tables_fill_set(skb2, &ctx, set, NFT_MSG_NEWSET, 0); + if (err < 0) + goto err_fill_set_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_set_info: + kfree_skb(skb2); + return err; +} + +static const struct nla_policy nft_concat_policy[NFTA_SET_FIELD_MAX + 1] = { + [NFTA_SET_FIELD_LEN] = { .type = NLA_U32 }, +}; + +static int nft_set_desc_concat_parse(const struct nlattr *attr, + struct nft_set_desc *desc) +{ + struct nlattr *tb[NFTA_SET_FIELD_MAX + 1]; + u32 len; + int err; + + if (desc->field_count >= ARRAY_SIZE(desc->field_len)) + return -E2BIG; + + err = nla_parse_nested_deprecated(tb, NFTA_SET_FIELD_MAX, attr, + nft_concat_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_SET_FIELD_LEN]) + return -EINVAL; + + len = ntohl(nla_get_be32(tb[NFTA_SET_FIELD_LEN])); + if (!len || len > U8_MAX) + return -EINVAL; + + desc->field_len[desc->field_count++] = len; + + return 0; +} + +static int nft_set_desc_concat(struct nft_set_desc *desc, + const struct nlattr *nla) +{ + u32 num_regs = 0, key_num_regs = 0; + struct nlattr *attr; + int rem, err, i; + + 
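/* Parse the field lengths, then check that their combined 32-bit + * register footprint matches the key length and fits the register set. + */ +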
nla_for_each_nested(attr, nla, rem) { + if (nla_type(attr) != NFTA_LIST_ELEM) + return -EINVAL; + + err = nft_set_desc_concat_parse(attr, desc); + if (err < 0) + return err; + } + + for (i = 0; i < desc->field_count; i++) + num_regs += DIV_ROUND_UP(desc->field_len[i], sizeof(u32)); + + key_num_regs = DIV_ROUND_UP(desc->klen, sizeof(u32)); + if (key_num_regs != num_regs) + return -EINVAL; + + if (num_regs > NFT_REG32_COUNT) + return -E2BIG; + + return 0; +} + +static int nf_tables_set_desc_parse(struct nft_set_desc *desc, + const struct nlattr *nla) +{ + struct nlattr *da[NFTA_SET_DESC_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(da, NFTA_SET_DESC_MAX, nla, + nft_set_desc_policy, NULL); + if (err < 0) + return err; + + if (da[NFTA_SET_DESC_SIZE] != NULL) + desc->size = ntohl(nla_get_be32(da[NFTA_SET_DESC_SIZE])); + if (da[NFTA_SET_DESC_CONCAT]) + err = nft_set_desc_concat(desc, da[NFTA_SET_DESC_CONCAT]); + + return err; +} + +static int nft_set_expr_alloc(struct nft_ctx *ctx, struct nft_set *set, + const struct nlattr * const *nla, + struct nft_expr **exprs, int *num_exprs, + u32 flags) +{ + struct nft_expr *expr; + int err, i; + + if (nla[NFTA_SET_EXPR]) { + expr = nft_set_elem_expr_alloc(ctx, set, nla[NFTA_SET_EXPR]); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_expr_alloc; + } + exprs[0] = expr; + (*num_exprs)++; + } else if (nla[NFTA_SET_EXPRESSIONS]) { + struct nlattr *tmp; + int left; + + if (!(flags & NFT_SET_EXPR)) { + err = -EINVAL; + goto err_set_expr_alloc; + } + i = 0; + nla_for_each_nested(tmp, nla[NFTA_SET_EXPRESSIONS], left) { + if (i == NFT_SET_EXPR_MAX) { + err = -E2BIG; + goto err_set_expr_alloc; + } + if (nla_type(tmp) != NFTA_LIST_ELEM) { + err = -EINVAL; + goto err_set_expr_alloc; + } + expr = nft_set_elem_expr_alloc(ctx, set, tmp); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_expr_alloc; + } + exprs[i++] = expr; + (*num_exprs)++; + } + } + + return 0; + +err_set_expr_alloc: + for (i = 0; i < *num_exprs; i++) + nft_expr_destroy(ctx, exprs[i]); + + return err; +} + +static bool nft_set_is_same(const struct nft_set *set, + const struct nft_set_desc *desc, + struct nft_expr *exprs[], u32 num_exprs, u32 flags) +{ + int i; + + if (set->ktype != desc->ktype || + set->dtype != desc->dtype || + set->flags != flags || + set->klen != desc->klen || + set->dlen != desc->dlen || + set->field_count != desc->field_count || + set->num_exprs != num_exprs) + return false; + + for (i = 0; i < desc->field_count; i++) { + if (set->field_len[i] != desc->field_len[i]) + return false; + } + + for (i = 0; i < num_exprs; i++) { + if (set->exprs[i]->ops != exprs[i]->ops) + return false; + } + + return true; +} + +static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_set_ops *ops; + struct net *net = info->net; + struct nft_set_desc desc; + struct nft_table *table; + unsigned char *udata; + struct nft_set *set; + struct nft_ctx ctx; + size_t alloc_size; + int num_exprs = 0; + char *name; + int err, i; + u16 udlen; + u32 flags; + u64 size; + + if (nla[NFTA_SET_TABLE] == NULL || + nla[NFTA_SET_NAME] == NULL || + nla[NFTA_SET_KEY_LEN] == NULL || + nla[NFTA_SET_ID] == NULL) + return -EINVAL; + + memset(&desc, 0, sizeof(desc)); + + desc.ktype = NFT_DATA_VALUE; + if (nla[NFTA_SET_KEY_TYPE] != NULL) { + desc.ktype = 
ntohl(nla_get_be32(nla[NFTA_SET_KEY_TYPE])); + if ((desc.ktype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK) + return -EINVAL; + } + + desc.klen = ntohl(nla_get_be32(nla[NFTA_SET_KEY_LEN])); + if (desc.klen == 0 || desc.klen > NFT_DATA_VALUE_MAXLEN) + return -EINVAL; + + flags = 0; + if (nla[NFTA_SET_FLAGS] != NULL) { + flags = ntohl(nla_get_be32(nla[NFTA_SET_FLAGS])); + if (flags & ~(NFT_SET_ANONYMOUS | NFT_SET_CONSTANT | + NFT_SET_INTERVAL | NFT_SET_TIMEOUT | + NFT_SET_MAP | NFT_SET_EVAL | + NFT_SET_OBJECT | NFT_SET_CONCAT | NFT_SET_EXPR)) + return -EOPNOTSUPP; + /* Only one of these operations is supported */ + if ((flags & (NFT_SET_MAP | NFT_SET_OBJECT)) == + (NFT_SET_MAP | NFT_SET_OBJECT)) + return -EOPNOTSUPP; + if ((flags & (NFT_SET_EVAL | NFT_SET_OBJECT)) == + (NFT_SET_EVAL | NFT_SET_OBJECT)) + return -EOPNOTSUPP; + } + + desc.dtype = 0; + if (nla[NFTA_SET_DATA_TYPE] != NULL) { + if (!(flags & NFT_SET_MAP)) + return -EINVAL; + + desc.dtype = ntohl(nla_get_be32(nla[NFTA_SET_DATA_TYPE])); + if ((desc.dtype & NFT_DATA_RESERVED_MASK) == NFT_DATA_RESERVED_MASK && + desc.dtype != NFT_DATA_VERDICT) + return -EINVAL; + + if (desc.dtype != NFT_DATA_VERDICT) { + if (nla[NFTA_SET_DATA_LEN] == NULL) + return -EINVAL; + desc.dlen = ntohl(nla_get_be32(nla[NFTA_SET_DATA_LEN])); + if (desc.dlen == 0 || desc.dlen > NFT_DATA_VALUE_MAXLEN) + return -EINVAL; + } else + desc.dlen = sizeof(struct nft_verdict); + } else if (flags & NFT_SET_MAP) + return -EINVAL; + + if (nla[NFTA_SET_OBJ_TYPE] != NULL) { + if (!(flags & NFT_SET_OBJECT)) + return -EINVAL; + + desc.objtype = ntohl(nla_get_be32(nla[NFTA_SET_OBJ_TYPE])); + if (desc.objtype == NFT_OBJECT_UNSPEC || + desc.objtype > NFT_OBJECT_MAX) + return -EOPNOTSUPP; + } else if (flags & NFT_SET_OBJECT) + return -EINVAL; + else + desc.objtype = NFT_OBJECT_UNSPEC; + + desc.timeout = 0; + if (nla[NFTA_SET_TIMEOUT] != NULL) { + if (!(flags & NFT_SET_TIMEOUT)) + return -EINVAL; + + err = nf_msecs_to_jiffies64(nla[NFTA_SET_TIMEOUT], &desc.timeout); + if (err) + return err; + } + desc.gc_int = 0; + if (nla[NFTA_SET_GC_INTERVAL] != NULL) { + if (!(flags & NFT_SET_TIMEOUT)) + return -EINVAL; + desc.gc_int = ntohl(nla_get_be32(nla[NFTA_SET_GC_INTERVAL])); + } + + desc.policy = NFT_SET_POL_PERFORMANCE; + if (nla[NFTA_SET_POLICY] != NULL) { + desc.policy = ntohl(nla_get_be32(nla[NFTA_SET_POLICY])); + switch (desc.policy) { + case NFT_SET_POL_PERFORMANCE: + case NFT_SET_POL_MEMORY: + break; + default: + return -EOPNOTSUPP; + } + } + + if (nla[NFTA_SET_DESC] != NULL) { + err = nf_tables_set_desc_parse(&desc, nla[NFTA_SET_DESC]); + if (err < 0) + return err; + + if (desc.field_count > 1) { + if (!(flags & NFT_SET_CONCAT)) + return -EINVAL; + } else if (flags & NFT_SET_CONCAT) { + return -EINVAL; + } + } else if (flags & NFT_SET_CONCAT) { + return -EINVAL; + } + + if (nla[NFTA_SET_EXPR] || nla[NFTA_SET_EXPRESSIONS]) + desc.expr = true; + + table = nft_table_lookup(net, nla[NFTA_SET_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_TABLE]); + return PTR_ERR(table); + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + set = nft_set_lookup(table, nla[NFTA_SET_NAME], genmask); + if (IS_ERR(set)) { + if (PTR_ERR(set) != -ENOENT) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_NAME]); + return PTR_ERR(set); + } + } else { + struct nft_expr *exprs[NFT_SET_EXPR_MAX] = {}; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_NAME]); + return -EEXIST; + } + if 
(info->nlh->nlmsg_flags & NLM_F_REPLACE) + return -EOPNOTSUPP; + + if (nft_set_is_anonymous(set)) + return -EOPNOTSUPP; + + err = nft_set_expr_alloc(&ctx, set, nla, exprs, &num_exprs, flags); + if (err < 0) + return err; + + err = 0; + if (!nft_set_is_same(set, &desc, exprs, num_exprs, flags)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_NAME]); + err = -EEXIST; + } + + for (i = 0; i < num_exprs; i++) + nft_expr_destroy(&ctx, exprs[i]); + + if (err < 0) + return err; + + return __nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set, &desc); + } + + if (!(info->nlh->nlmsg_flags & NLM_F_CREATE)) + return -ENOENT; + + ops = nft_select_set_ops(&ctx, nla, &desc); + if (IS_ERR(ops)) + return PTR_ERR(ops); + + udlen = 0; + if (nla[NFTA_SET_USERDATA]) + udlen = nla_len(nla[NFTA_SET_USERDATA]); + + size = 0; + if (ops->privsize != NULL) + size = ops->privsize(nla, &desc); + alloc_size = sizeof(*set) + size + udlen; + if (alloc_size < size || alloc_size > INT_MAX) + return -ENOMEM; + + if (!nft_use_inc(&table->use)) + return -EMFILE; + + set = kvzalloc(alloc_size, GFP_KERNEL_ACCOUNT); + if (!set) { + err = -ENOMEM; + goto err_alloc; + } + + name = nla_strdup(nla[NFTA_SET_NAME], GFP_KERNEL_ACCOUNT); + if (!name) { + err = -ENOMEM; + goto err_set_name; + } + + err = nf_tables_set_alloc_name(&ctx, set, name); + kfree(name); + if (err < 0) + goto err_set_name; + + udata = NULL; + if (udlen) { + udata = set->data + size; + nla_memcpy(udata, nla[NFTA_SET_USERDATA], udlen); + } + + INIT_LIST_HEAD(&set->bindings); + INIT_LIST_HEAD(&set->catchall_list); + refcount_set(&set->refs, 1); + set->table = table; + write_pnet(&set->net, net); + set->ops = ops; + set->ktype = desc.ktype; + set->klen = desc.klen; + set->dtype = desc.dtype; + set->objtype = desc.objtype; + set->dlen = desc.dlen; + set->flags = flags; + set->size = desc.size; + set->policy = desc.policy; + set->udlen = udlen; + set->udata = udata; + set->timeout = desc.timeout; + set->gc_int = desc.gc_int; + + set->field_count = desc.field_count; + for (i = 0; i < desc.field_count; i++) + set->field_len[i] = desc.field_len[i]; + + err = ops->init(set, &desc, nla); + if (err < 0) + goto err_set_init; + + err = nft_set_expr_alloc(&ctx, set, nla, set->exprs, &num_exprs, flags); + if (err < 0) + goto err_set_destroy; + + set->num_exprs = num_exprs; + set->handle = nf_tables_alloc_handle(table); + INIT_LIST_HEAD(&set->pending_update); + + err = nft_trans_set_add(&ctx, NFT_MSG_NEWSET, set); + if (err < 0) + goto err_set_expr_alloc; + + list_add_tail_rcu(&set->list, &table->sets); + + return 0; + +err_set_expr_alloc: + for (i = 0; i < set->num_exprs; i++) + nft_expr_destroy(&ctx, set->exprs[i]); +err_set_destroy: + ops->destroy(&ctx, set); +err_set_init: + kfree(set->name); +err_set_name: + kvfree(set); +err_alloc: + nft_use_dec_restore(&table->use); + + return err; +} + +static void nft_set_catchall_destroy(const struct nft_ctx *ctx, + struct nft_set *set) +{ + struct nft_set_elem_catchall *next, *catchall; + + list_for_each_entry_safe(catchall, next, &set->catchall_list, list) { + list_del_rcu(&catchall->list); + nf_tables_set_elem_destroy(ctx, set, catchall->elem); + kfree_rcu(catchall, rcu); + } +} + +static void nft_set_put(struct nft_set *set) +{ + if (refcount_dec_and_test(&set->refs)) { + kfree(set->name); + kvfree(set); + } +} + +static void nft_set_destroy(const struct nft_ctx *ctx, struct nft_set *set) +{ + int i; + + if (WARN_ON(set->use > 0)) + return; + + for (i = 0; i < set->num_exprs; i++) + nft_expr_destroy(ctx, set->exprs[i]); + + set->ops->destroy(ctx, 
set); + nft_set_catchall_destroy(ctx, set); + nft_set_put(set); +} + +static int nf_tables_delset(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_set *set; + struct nft_ctx ctx; + + if (info->nfmsg->nfgen_family == NFPROTO_UNSPEC) + return -EAFNOSUPPORT; + + table = nft_table_lookup(net, nla[NFTA_SET_TABLE], family, + genmask, NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_TABLE]); + return PTR_ERR(table); + } + + if (nla[NFTA_SET_HANDLE]) { + attr = nla[NFTA_SET_HANDLE]; + set = nft_set_lookup_byhandle(table, attr, genmask); + } else { + attr = nla[NFTA_SET_NAME]; + set = nft_set_lookup(table, attr, genmask); + } + + if (IS_ERR(set)) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(set); + } + if (set->use || + (info->nlh->nlmsg_flags & NLM_F_NONREC && + atomic_read(&set->nelems) > 0)) { + NL_SET_BAD_ATTR(extack, attr); + return -EBUSY; + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + return nft_delset(&ctx, set); +} + +static int nft_validate_register_store(const struct nft_ctx *ctx, + enum nft_registers reg, + const struct nft_data *data, + enum nft_data_types type, + unsigned int len); + +static int nft_setelem_data_validate(const struct nft_ctx *ctx, + struct nft_set *set, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + enum nft_registers dreg; + + dreg = nft_type_to_reg(set->dtype); + return nft_validate_register_store(ctx, dreg, nft_set_ext_data(ext), + set->dtype == NFT_DATA_VERDICT ? + NFT_DATA_VERDICT : NFT_DATA_VALUE, + set->dlen); +} + +static int nf_tables_bind_check_setelem(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + return nft_setelem_data_validate(ctx, set, elem); +} + +static int nft_set_catchall_bind_check(const struct nft_ctx *ctx, + struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set_elem_catchall *catchall; + struct nft_set_elem elem; + struct nft_set_ext *ext; + int ret = 0; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + elem.priv = catchall->elem; + ret = nft_setelem_data_validate(ctx, set, &elem); + if (ret < 0) + break; + } + + return ret; +} + +int nf_tables_bind_set(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_binding *binding) +{ + struct nft_set_binding *i; + struct nft_set_iter iter; + + if (!list_empty(&set->bindings) && nft_set_is_anonymous(set)) + return -EBUSY; + + if (binding->flags & NFT_SET_MAP) { + /* If the set is already bound to the same chain all + * jumps are already validated for that chain. 
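+ * In that case, skip the walk over the set elements and bind right away.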
+ */ + list_for_each_entry(i, &set->bindings, list) { + if (i->flags & NFT_SET_MAP && + i->chain == binding->chain) + goto bind; + } + + iter.genmask = nft_genmask_next(ctx->net); + iter.skip = 0; + iter.count = 0; + iter.err = 0; + iter.fn = nf_tables_bind_check_setelem; + + set->ops->walk(ctx, set, &iter); + if (!iter.err) + iter.err = nft_set_catchall_bind_check(ctx, set); + + if (iter.err < 0) + return iter.err; + } +bind: + if (!nft_use_inc(&set->use)) + return -EMFILE; + + binding->chain = ctx->chain; + list_add_tail_rcu(&binding->list, &set->bindings); + nft_set_trans_bind(ctx, set); + + return 0; +} +EXPORT_SYMBOL_GPL(nf_tables_bind_set); + +static void nf_tables_unbind_set(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_binding *binding, bool event) +{ + list_del_rcu(&binding->list); + + if (list_empty(&set->bindings) && nft_set_is_anonymous(set)) { + list_del_rcu(&set->list); + if (event) + nf_tables_set_notify(ctx, set, NFT_MSG_DELSET, + GFP_KERNEL); + } +} + +static void nft_setelem_data_activate(const struct net *net, + const struct nft_set *set, + struct nft_set_elem *elem); + +static int nft_mapelem_activate(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + nft_setelem_data_activate(ctx->net, set, elem); + + return 0; +} + +static void nft_map_catchall_activate(const struct nft_ctx *ctx, + struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set_elem_catchall *catchall; + struct nft_set_elem elem; + struct nft_set_ext *ext; + + list_for_each_entry(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + elem.priv = catchall->elem; + nft_setelem_data_activate(ctx->net, set, &elem); + break; + } +} + +static void nft_map_activate(const struct nft_ctx *ctx, struct nft_set *set) +{ + struct nft_set_iter iter = { + .genmask = nft_genmask_next(ctx->net), + .fn = nft_mapelem_activate, + }; + + set->ops->walk(ctx, set, &iter); + WARN_ON_ONCE(iter.err); + + nft_map_catchall_activate(ctx, set); +} + +void nf_tables_activate_set(const struct nft_ctx *ctx, struct nft_set *set) +{ + if (nft_set_is_anonymous(set)) { + if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_activate(ctx, set); + + nft_clear(ctx->net, set); + } + + nft_use_inc_restore(&set->use); +} +EXPORT_SYMBOL_GPL(nf_tables_activate_set); + +void nf_tables_deactivate_set(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_binding *binding, + enum nft_trans_phase phase) +{ + switch (phase) { + case NFT_TRANS_PREPARE_ERROR: + nft_set_trans_unbind(ctx, set); + if (nft_set_is_anonymous(set)) + nft_deactivate_next(ctx->net, set); + else + list_del_rcu(&binding->list); + + nft_use_dec(&set->use); + break; + case NFT_TRANS_PREPARE: + if (nft_set_is_anonymous(set)) { + if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_deactivate(ctx, set); + + nft_deactivate_next(ctx->net, set); + } + nft_use_dec(&set->use); + return; + case NFT_TRANS_ABORT: + case NFT_TRANS_RELEASE: + if (nft_set_is_anonymous(set) && + set->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_deactivate(ctx, set); + + nft_use_dec(&set->use); + fallthrough; + default: + nf_tables_unbind_set(ctx, set, binding, + phase == NFT_TRANS_COMMIT); + } +} +EXPORT_SYMBOL_GPL(nf_tables_deactivate_set); + +void nf_tables_destroy_set(const struct nft_ctx *ctx, struct nft_set *set) +{ + if (list_empty(&set->bindings) && nft_set_is_anonymous(set)) + 
nft_set_destroy(ctx, set); +} +EXPORT_SYMBOL_GPL(nf_tables_destroy_set); + +const struct nft_set_ext_type nft_set_ext_types[] = { + [NFT_SET_EXT_KEY] = { + .align = __alignof__(u32), + }, + [NFT_SET_EXT_DATA] = { + .align = __alignof__(u32), + }, + [NFT_SET_EXT_EXPRESSIONS] = { + .align = __alignof__(struct nft_set_elem_expr), + }, + [NFT_SET_EXT_OBJREF] = { + .len = sizeof(struct nft_object *), + .align = __alignof__(struct nft_object *), + }, + [NFT_SET_EXT_FLAGS] = { + .len = sizeof(u8), + .align = __alignof__(u8), + }, + [NFT_SET_EXT_TIMEOUT] = { + .len = sizeof(u64), + .align = __alignof__(u64), + }, + [NFT_SET_EXT_EXPIRATION] = { + .len = sizeof(u64), + .align = __alignof__(u64), + }, + [NFT_SET_EXT_USERDATA] = { + .len = sizeof(struct nft_userdata), + .align = __alignof__(struct nft_userdata), + }, + [NFT_SET_EXT_KEY_END] = { + .align = __alignof__(u32), + }, +}; + +/* + * Set elements + */ + +static const struct nla_policy nft_set_elem_policy[NFTA_SET_ELEM_MAX + 1] = { + [NFTA_SET_ELEM_KEY] = { .type = NLA_NESTED }, + [NFTA_SET_ELEM_DATA] = { .type = NLA_NESTED }, + [NFTA_SET_ELEM_FLAGS] = { .type = NLA_U32 }, + [NFTA_SET_ELEM_TIMEOUT] = { .type = NLA_U64 }, + [NFTA_SET_ELEM_EXPIRATION] = { .type = NLA_U64 }, + [NFTA_SET_ELEM_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN }, + [NFTA_SET_ELEM_EXPR] = { .type = NLA_NESTED }, + [NFTA_SET_ELEM_OBJREF] = { .type = NLA_STRING, + .len = NFT_OBJ_MAXNAMELEN - 1 }, + [NFTA_SET_ELEM_KEY_END] = { .type = NLA_NESTED }, + [NFTA_SET_ELEM_EXPRESSIONS] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy nft_set_elem_list_policy[NFTA_SET_ELEM_LIST_MAX + 1] = { + [NFTA_SET_ELEM_LIST_TABLE] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_SET_ELEM_LIST_SET] = { .type = NLA_STRING, + .len = NFT_SET_MAXNAMELEN - 1 }, + [NFTA_SET_ELEM_LIST_ELEMENTS] = { .type = NLA_NESTED }, + [NFTA_SET_ELEM_LIST_SET_ID] = { .type = NLA_U32 }, +}; + +static int nft_set_elem_expr_dump(struct sk_buff *skb, + const struct nft_set *set, + const struct nft_set_ext *ext) +{ + struct nft_set_elem_expr *elem_expr; + u32 size, num_exprs = 0; + struct nft_expr *expr; + struct nlattr *nest; + + elem_expr = nft_set_ext_expr(ext); + nft_setelem_expr_foreach(expr, elem_expr, size) + num_exprs++; + + if (num_exprs == 1) { + expr = nft_setelem_expr_at(elem_expr, 0); + if (nft_expr_dump(skb, NFTA_SET_ELEM_EXPR, expr) < 0) + return -1; + + return 0; + } else if (num_exprs > 1) { + nest = nla_nest_start_noflag(skb, NFTA_SET_ELEM_EXPRESSIONS); + if (nest == NULL) + goto nla_put_failure; + + nft_setelem_expr_foreach(expr, elem_expr, size) { + expr = nft_setelem_expr_at(elem_expr, size); + if (nft_expr_dump(skb, NFTA_LIST_ELEM, expr) < 0) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + return 0; + +nla_put_failure: + return -1; +} + +static int nf_tables_fill_setelem(struct sk_buff *skb, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + unsigned char *b = skb_tail_pointer(skb); + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, NFTA_LIST_ELEM); + if (nest == NULL) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY) && + nft_data_dump(skb, NFTA_SET_ELEM_KEY, nft_set_ext_key(ext), + NFT_DATA_VALUE, set->klen) < 0) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY_END) && + nft_data_dump(skb, NFTA_SET_ELEM_KEY_END, nft_set_ext_key_end(ext), + NFT_DATA_VALUE, set->klen) < 0) + goto nla_put_failure; 
+ + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA) && + nft_data_dump(skb, NFTA_SET_ELEM_DATA, nft_set_ext_data(ext), + set->dtype == NFT_DATA_VERDICT ? NFT_DATA_VERDICT : NFT_DATA_VALUE, + set->dlen) < 0) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_EXPRESSIONS) && + nft_set_elem_expr_dump(skb, set, ext)) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF) && + nla_put_string(skb, NFTA_SET_ELEM_OBJREF, + (*nft_set_ext_obj(ext))->key.name) < 0) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_FLAGS) && + nla_put_be32(skb, NFTA_SET_ELEM_FLAGS, + htonl(*nft_set_ext_flags(ext)))) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_TIMEOUT) && + nla_put_be64(skb, NFTA_SET_ELEM_TIMEOUT, + nf_jiffies64_to_msecs(*nft_set_ext_timeout(ext)), + NFTA_SET_ELEM_PAD)) + goto nla_put_failure; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_EXPIRATION)) { + u64 expires, now = get_jiffies_64(); + + expires = *nft_set_ext_expiration(ext); + if (time_before64(now, expires)) + expires -= now; + else + expires = 0; + + if (nla_put_be64(skb, NFTA_SET_ELEM_EXPIRATION, + nf_jiffies64_to_msecs(expires), + NFTA_SET_ELEM_PAD)) + goto nla_put_failure; + } + + if (nft_set_ext_exists(ext, NFT_SET_EXT_USERDATA)) { + struct nft_userdata *udata; + + udata = nft_set_ext_userdata(ext); + if (nla_put(skb, NFTA_SET_ELEM_USERDATA, + udata->len + 1, udata->data)) + goto nla_put_failure; + } + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nlmsg_trim(skb, b); + return -EMSGSIZE; +} + +struct nft_set_dump_args { + const struct netlink_callback *cb; + struct nft_set_iter iter; + struct sk_buff *skb; +}; + +static int nf_tables_dump_setelem(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + struct nft_set_dump_args *args; + + if (nft_set_elem_expired(ext) || nft_set_elem_is_dead(ext)) + return 0; + + args = container_of(iter, struct nft_set_dump_args, iter); + return nf_tables_fill_setelem(args->skb, set, elem); +} + +struct nft_set_dump_ctx { + const struct nft_set *set; + struct nft_ctx ctx; +}; + +static int nft_set_catchall_dump(struct net *net, struct sk_buff *skb, + const struct nft_set *set) +{ + struct nft_set_elem_catchall *catchall; + u8 genmask = nft_genmask_cur(net); + struct nft_set_elem elem; + struct nft_set_ext *ext; + int ret = 0; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask) || + nft_set_elem_expired(ext)) + continue; + + elem.priv = catchall->elem; + ret = nf_tables_fill_setelem(skb, set, &elem); + break; + } + + return ret; +} + +static int nf_tables_dump_set(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nft_set_dump_ctx *dump_ctx = cb->data; + struct net *net = sock_net(skb->sk); + struct nftables_pernet *nft_net; + struct nft_table *table; + struct nft_set *set; + struct nft_set_dump_args args; + bool set_found = false; + struct nlmsghdr *nlh; + struct nlattr *nest; + u32 portid, seq; + int event; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (dump_ctx->ctx.family != NFPROTO_UNSPEC && + dump_ctx->ctx.family != table->family) + continue; + + if (table != dump_ctx->ctx.table) + continue; + + list_for_each_entry_rcu(set, &table->sets, list) { + if (set == 
dump_ctx->set) { + set_found = true; + break; + } + } + break; + } + + if (!set_found) { + rcu_read_unlock(); + return -ENOENT; + } + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWSETELEM); + portid = NETLINK_CB(cb->skb).portid; + seq = cb->nlh->nlmsg_seq; + + nlh = nfnl_msg_put(skb, portid, seq, event, NLM_F_MULTI, + table->family, NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_SET_ELEM_LIST_TABLE, table->name)) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_SET_ELEM_LIST_SET, set->name)) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, NFTA_SET_ELEM_LIST_ELEMENTS); + if (nest == NULL) + goto nla_put_failure; + + args.cb = cb; + args.skb = skb; + args.iter.genmask = nft_genmask_cur(net); + args.iter.skip = cb->args[0]; + args.iter.count = 0; + args.iter.err = 0; + args.iter.fn = nf_tables_dump_setelem; + set->ops->walk(&dump_ctx->ctx, set, &args.iter); + + if (!args.iter.err && args.iter.count == cb->args[0]) + args.iter.err = nft_set_catchall_dump(net, skb, set); + rcu_read_unlock(); + + nla_nest_end(skb, nest); + nlmsg_end(skb, nlh); + + if (args.iter.err && args.iter.err != -EMSGSIZE) + return args.iter.err; + if (args.iter.count == cb->args[0]) + return 0; + + cb->args[0] = args.iter.count; + return skb->len; + +nla_put_failure: + rcu_read_unlock(); + return -ENOSPC; +} + +static int nf_tables_dump_set_start(struct netlink_callback *cb) +{ + struct nft_set_dump_ctx *dump_ctx = cb->data; + + cb->data = kmemdup(dump_ctx, sizeof(*dump_ctx), GFP_ATOMIC); + + return cb->data ? 0 : -ENOMEM; +} + +static int nf_tables_dump_set_done(struct netlink_callback *cb) +{ + kfree(cb->data); + return 0; +} + +static int nf_tables_fill_setelem_info(struct sk_buff *skb, + const struct nft_ctx *ctx, u32 seq, + u32 portid, int event, u16 flags, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nlmsghdr *nlh; + struct nlattr *nest; + int err; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, ctx->family, + NFNETLINK_V0, nft_base_seq(ctx->net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_SET_TABLE, ctx->table->name)) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_SET_NAME, set->name)) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, NFTA_SET_ELEM_LIST_ELEMENTS); + if (nest == NULL) + goto nla_put_failure; + + err = nf_tables_fill_setelem(skb, set, elem); + if (err < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +static int nft_setelem_parse_flags(const struct nft_set *set, + const struct nlattr *attr, u32 *flags) +{ + if (attr == NULL) + return 0; + + *flags = ntohl(nla_get_be32(attr)); + if (*flags & ~(NFT_SET_ELEM_INTERVAL_END | NFT_SET_ELEM_CATCHALL)) + return -EOPNOTSUPP; + if (!(set->flags & NFT_SET_INTERVAL) && + *flags & NFT_SET_ELEM_INTERVAL_END) + return -EINVAL; + if ((*flags & (NFT_SET_ELEM_INTERVAL_END | NFT_SET_ELEM_CATCHALL)) == + (NFT_SET_ELEM_INTERVAL_END | NFT_SET_ELEM_CATCHALL)) + return -EINVAL; + + return 0; +} + +static int nft_setelem_parse_key(struct nft_ctx *ctx, struct nft_set *set, + struct nft_data *key, struct nlattr *attr) +{ + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = NFT_DATA_VALUE_MAXLEN, + .len = set->klen, + }; + + return nft_data_init(ctx, key, &desc, attr); +} + +static int nft_setelem_parse_data(struct nft_ctx *ctx, struct 
nft_set *set, + struct nft_data_desc *desc, + struct nft_data *data, + struct nlattr *attr) +{ + u32 dtype; + + if (set->dtype == NFT_DATA_VERDICT) + dtype = NFT_DATA_VERDICT; + else + dtype = NFT_DATA_VALUE; + + desc->type = dtype; + desc->size = NFT_DATA_VALUE_MAXLEN; + desc->len = set->dlen; + desc->flags = NFT_DATA_DESC_SETELEM; + + return nft_data_init(ctx, data, desc, attr); +} + +static void *nft_setelem_catchall_get(const struct net *net, + const struct nft_set *set) +{ + struct nft_set_elem_catchall *catchall; + u8 genmask = nft_genmask_cur(net); + struct nft_set_ext *ext; + void *priv = NULL; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask) || + nft_set_elem_expired(ext)) + continue; + + priv = catchall->elem; + break; + } + + return priv; +} + +static int nft_setelem_get(struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_elem *elem, u32 flags) +{ + void *priv; + + if (!(flags & NFT_SET_ELEM_CATCHALL)) { + priv = set->ops->get(ctx->net, set, elem, flags); + if (IS_ERR(priv)) + return PTR_ERR(priv); + } else { + priv = nft_setelem_catchall_get(ctx->net, set); + if (!priv) + return -ENOENT; + } + elem->priv = priv; + + return 0; +} + +static int nft_get_set_elem(struct nft_ctx *ctx, struct nft_set *set, + const struct nlattr *attr) +{ + struct nlattr *nla[NFTA_SET_ELEM_MAX + 1]; + struct nft_set_elem elem; + struct sk_buff *skb; + uint32_t flags = 0; + int err; + + err = nla_parse_nested_deprecated(nla, NFTA_SET_ELEM_MAX, attr, + nft_set_elem_policy, NULL); + if (err < 0) + return err; + + err = nft_setelem_parse_flags(set, nla[NFTA_SET_ELEM_FLAGS], &flags); + if (err < 0) + return err; + + if (!nla[NFTA_SET_ELEM_KEY] && !(flags & NFT_SET_ELEM_CATCHALL)) + return -EINVAL; + + if (nla[NFTA_SET_ELEM_KEY]) { + err = nft_setelem_parse_key(ctx, set, &elem.key.val, + nla[NFTA_SET_ELEM_KEY]); + if (err < 0) + return err; + } + + if (nla[NFTA_SET_ELEM_KEY_END]) { + err = nft_setelem_parse_key(ctx, set, &elem.key_end.val, + nla[NFTA_SET_ELEM_KEY_END]); + if (err < 0) + return err; + } + + err = nft_setelem_get(ctx, set, &elem, flags); + if (err < 0) + return err; + + err = -ENOMEM; + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_ATOMIC); + if (skb == NULL) + return err; + + err = nf_tables_fill_setelem_info(skb, ctx, ctx->seq, ctx->portid, + NFT_MSG_NEWSETELEM, 0, set, &elem); + if (err < 0) + goto err_fill_setelem; + + return nfnetlink_unicast(skb, ctx->net, ctx->portid); + +err_fill_setelem: + kfree_skb(skb); + return err; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getsetelem(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + struct nft_table *table; + struct nft_set *set; + struct nlattr *attr; + struct nft_ctx ctx; + int rem, err = 0; + + table = nft_table_lookup(net, nla[NFTA_SET_ELEM_LIST_TABLE], family, + genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_ELEM_LIST_TABLE]); + return PTR_ERR(table); + } + + set = nft_set_lookup(table, nla[NFTA_SET_ELEM_LIST_SET], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = nf_tables_dump_set_start, + .dump = nf_tables_dump_set, + 
.done = nf_tables_dump_set_done, + .module = THIS_MODULE, + }; + struct nft_set_dump_ctx dump_ctx = { + .set = set, + .ctx = ctx, + }; + + c.data = &dump_ctx; + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + if (!nla[NFTA_SET_ELEM_LIST_ELEMENTS]) + return -EINVAL; + + nla_for_each_nested(attr, nla[NFTA_SET_ELEM_LIST_ELEMENTS], rem) { + err = nft_get_set_elem(&ctx, set, attr); + if (err < 0) { + NL_SET_BAD_ATTR(extack, attr); + break; + } + } + + return err; +} + +static void nf_tables_setelem_notify(const struct nft_ctx *ctx, + const struct nft_set *set, + const struct nft_set_elem *elem, + int event) +{ + struct nftables_pernet *nft_net; + struct net *net = ctx->net; + u32 portid = ctx->portid; + struct sk_buff *skb; + u16 flags = 0; + int err; + + if (!ctx->report && !nfnetlink_has_listeners(net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb == NULL) + goto err; + + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_setelem_info(skb, ctx, 0, portid, event, flags, + set, elem); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_net = nft_pernet(net); + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(net, portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +static struct nft_trans *nft_trans_elem_alloc(struct nft_ctx *ctx, + int msg_type, + struct nft_set *set) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc(ctx, msg_type, sizeof(struct nft_trans_elem)); + if (trans == NULL) + return NULL; + + nft_trans_elem_set(trans) = set; + return trans; +} + +struct nft_expr *nft_set_elem_expr_alloc(const struct nft_ctx *ctx, + const struct nft_set *set, + const struct nlattr *attr) +{ + struct nft_expr *expr; + int err; + + expr = nft_expr_init(ctx, attr); + if (IS_ERR(expr)) + return expr; + + err = -EOPNOTSUPP; + if (expr->ops->type->flags & NFT_EXPR_GC) { + if (set->flags & NFT_SET_TIMEOUT) + goto err_set_elem_expr; + if (!set->ops->gc_init) + goto err_set_elem_expr; + set->ops->gc_init(set); + } + + return expr; + +err_set_elem_expr: + nft_expr_destroy(ctx, expr); + return ERR_PTR(err); +} + +static int nft_set_ext_check(const struct nft_set_ext_tmpl *tmpl, u8 id, u32 len) +{ + len += nft_set_ext_types[id].len; + if (len > tmpl->ext_len[id] || + len > U8_MAX) + return -1; + + return 0; +} + +static int nft_set_ext_memcpy(const struct nft_set_ext_tmpl *tmpl, u8 id, + void *to, const void *from, u32 len) +{ + if (nft_set_ext_check(tmpl, id, len) < 0) + return -1; + + memcpy(to, from, len); + + return 0; +} + +void *nft_set_elem_init(const struct nft_set *set, + const struct nft_set_ext_tmpl *tmpl, + const u32 *key, const u32 *key_end, + const u32 *data, u64 timeout, u64 expiration, gfp_t gfp) +{ + struct nft_set_ext *ext; + void *elem; + + elem = kzalloc(set->ops->elemsize + tmpl->len, gfp); + if (elem == NULL) + return ERR_PTR(-ENOMEM); + + ext = nft_set_elem_ext(set, elem); + nft_set_ext_init(ext, tmpl); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY) && + nft_set_ext_memcpy(tmpl, NFT_SET_EXT_KEY, + nft_set_ext_key(ext), key, set->klen) < 0) + goto err_ext_check; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY_END) && + nft_set_ext_memcpy(tmpl, NFT_SET_EXT_KEY_END, + nft_set_ext_key_end(ext), key_end, set->klen) < 0) + goto err_ext_check; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA) && + nft_set_ext_memcpy(tmpl, NFT_SET_EXT_DATA, + nft_set_ext_data(ext), data, set->dlen) < 0) + goto 
err_ext_check; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_EXPIRATION)) { + *nft_set_ext_expiration(ext) = get_jiffies_64() + expiration; + if (expiration == 0) + *nft_set_ext_expiration(ext) += timeout; + } + if (nft_set_ext_exists(ext, NFT_SET_EXT_TIMEOUT)) + *nft_set_ext_timeout(ext) = timeout; + + return elem; + +err_ext_check: + kfree(elem); + + return ERR_PTR(-EINVAL); +} + +static void __nft_set_elem_expr_destroy(const struct nft_ctx *ctx, + struct nft_expr *expr) +{ + if (expr->ops->destroy_clone) { + expr->ops->destroy_clone(ctx, expr); + module_put(expr->ops->type->owner); + } else { + nf_tables_expr_destroy(ctx, expr); + } +} + +static void nft_set_elem_expr_destroy(const struct nft_ctx *ctx, + struct nft_set_elem_expr *elem_expr) +{ + struct nft_expr *expr; + u32 size; + + nft_setelem_expr_foreach(expr, elem_expr, size) + __nft_set_elem_expr_destroy(ctx, expr); +} + +/* Drop references and destroy. Called from gc, dynset and abort path. */ +void nft_set_elem_destroy(const struct nft_set *set, void *elem, + bool destroy_expr) +{ + struct nft_set_ext *ext = nft_set_elem_ext(set, elem); + struct nft_ctx ctx = { + .net = read_pnet(&set->net), + .family = set->table->family, + }; + + nft_data_release(nft_set_ext_key(ext), NFT_DATA_VALUE); + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA)) + nft_data_release(nft_set_ext_data(ext), set->dtype); + if (destroy_expr && nft_set_ext_exists(ext, NFT_SET_EXT_EXPRESSIONS)) + nft_set_elem_expr_destroy(&ctx, nft_set_ext_expr(ext)); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF)) + nft_use_dec(&(*nft_set_ext_obj(ext))->use); + kfree(elem); +} +EXPORT_SYMBOL_GPL(nft_set_elem_destroy); + +/* Destroy element. References have been already dropped in the preparation + * path via nft_setelem_data_deactivate(). 
+ */ +void nf_tables_set_elem_destroy(const struct nft_ctx *ctx, + const struct nft_set *set, void *elem) +{ + struct nft_set_ext *ext = nft_set_elem_ext(set, elem); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_EXPRESSIONS)) + nft_set_elem_expr_destroy(ctx, nft_set_ext_expr(ext)); + + kfree(elem); +} + +int nft_set_elem_expr_clone(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_expr *expr_array[]) +{ + struct nft_expr *expr; + int err, i, k; + + for (i = 0; i < set->num_exprs; i++) { + expr = kzalloc(set->exprs[i]->ops->size, GFP_KERNEL_ACCOUNT); + if (!expr) + goto err_expr; + + err = nft_expr_clone(expr, set->exprs[i]); + if (err < 0) { + kfree(expr); + goto err_expr; + } + expr_array[i] = expr; + } + + return 0; + +err_expr: + for (k = i - 1; k >= 0; k--) + nft_expr_destroy(ctx, expr_array[k]); + + return -ENOMEM; +} + +static int nft_set_elem_expr_setup(struct nft_ctx *ctx, + const struct nft_set_ext_tmpl *tmpl, + const struct nft_set_ext *ext, + struct nft_expr *expr_array[], + u32 num_exprs) +{ + struct nft_set_elem_expr *elem_expr = nft_set_ext_expr(ext); + u32 len = sizeof(struct nft_set_elem_expr); + struct nft_expr *expr; + int i, err; + + if (num_exprs == 0) + return 0; + + for (i = 0; i < num_exprs; i++) + len += expr_array[i]->ops->size; + + if (nft_set_ext_check(tmpl, NFT_SET_EXT_EXPRESSIONS, len) < 0) + return -EINVAL; + + for (i = 0; i < num_exprs; i++) { + expr = nft_setelem_expr_at(elem_expr, elem_expr->size); + err = nft_expr_clone(expr, expr_array[i]); + if (err < 0) + goto err_elem_expr_setup; + + elem_expr->size += expr_array[i]->ops->size; + nft_expr_destroy(ctx, expr_array[i]); + expr_array[i] = NULL; + } + + return 0; + +err_elem_expr_setup: + for (; i < num_exprs; i++) { + nft_expr_destroy(ctx, expr_array[i]); + expr_array[i] = NULL; + } + + return -ENOMEM; +} + +struct nft_set_ext *nft_set_catchall_lookup(const struct net *net, + const struct nft_set *set) +{ + struct nft_set_elem_catchall *catchall; + u8 genmask = nft_genmask_cur(net); + struct nft_set_ext *ext; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (nft_set_elem_active(ext, genmask) && + !nft_set_elem_expired(ext) && + !nft_set_elem_is_dead(ext)) + return ext; + } + + return NULL; +} +EXPORT_SYMBOL_GPL(nft_set_catchall_lookup); + +static int nft_setelem_catchall_insert(const struct net *net, + struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **pext) +{ + struct nft_set_elem_catchall *catchall; + u8 genmask = nft_genmask_next(net); + struct nft_set_ext *ext; + + list_for_each_entry(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (nft_set_elem_active(ext, genmask)) { + *pext = ext; + return -EEXIST; + } + } + + catchall = kmalloc(sizeof(*catchall), GFP_KERNEL); + if (!catchall) + return -ENOMEM; + + catchall->elem = elem->priv; + list_add_tail_rcu(&catchall->list, &set->catchall_list); + + return 0; +} + +static int nft_setelem_insert(const struct net *net, + struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext, unsigned int flags) +{ + int ret; + + if (flags & NFT_SET_ELEM_CATCHALL) + ret = nft_setelem_catchall_insert(net, set, elem, ext); + else + ret = set->ops->insert(net, set, elem, ext); + + return ret; +} + +static bool nft_setelem_is_catchall(const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + if (nft_set_ext_exists(ext, 
NFT_SET_EXT_FLAGS) && + *nft_set_ext_flags(ext) & NFT_SET_ELEM_CATCHALL) + return true; + + return false; +} + +static void nft_setelem_activate(struct net *net, struct nft_set *set, + struct nft_set_elem *elem) +{ + struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + if (nft_setelem_is_catchall(set, elem)) { + nft_set_elem_change_active(net, set, ext); + } else { + set->ops->activate(net, set, elem); + } +} + +static int nft_setelem_catchall_deactivate(const struct net *net, + struct nft_set *set, + struct nft_set_elem *elem) +{ + struct nft_set_elem_catchall *catchall; + struct nft_set_ext *ext; + + list_for_each_entry(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_is_active_next(net, ext)) + continue; + + kfree(elem->priv); + elem->priv = catchall->elem; + nft_set_elem_change_active(net, set, ext); + return 0; + } + + return -ENOENT; +} + +static int __nft_setelem_deactivate(const struct net *net, + struct nft_set *set, + struct nft_set_elem *elem) +{ + void *priv; + + priv = set->ops->deactivate(net, set, elem); + if (!priv) + return -ENOENT; + + kfree(elem->priv); + elem->priv = priv; + set->ndeact++; + + return 0; +} + +static int nft_setelem_deactivate(const struct net *net, + struct nft_set *set, + struct nft_set_elem *elem, u32 flags) +{ + int ret; + + if (flags & NFT_SET_ELEM_CATCHALL) + ret = nft_setelem_catchall_deactivate(net, set, elem); + else + ret = __nft_setelem_deactivate(net, set, elem); + + return ret; +} + +static void nft_setelem_catchall_destroy(struct nft_set_elem_catchall *catchall) +{ + list_del_rcu(&catchall->list); + kfree_rcu(catchall, rcu); +} + +static void nft_setelem_catchall_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_set_elem_catchall *catchall, *next; + + list_for_each_entry_safe(catchall, next, &set->catchall_list, list) { + if (catchall->elem == elem->priv) { + nft_setelem_catchall_destroy(catchall); + break; + } + } +} + +static void nft_setelem_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + if (nft_setelem_is_catchall(set, elem)) + nft_setelem_catchall_remove(net, set, elem); + else + set->ops->remove(net, set, elem); +} + +static bool nft_setelem_valid_key_end(const struct nft_set *set, + struct nlattr **nla, u32 flags) +{ + if ((set->flags & (NFT_SET_CONCAT | NFT_SET_INTERVAL)) == + (NFT_SET_CONCAT | NFT_SET_INTERVAL)) { + if (flags & NFT_SET_ELEM_INTERVAL_END) + return false; + + if (nla[NFTA_SET_ELEM_KEY_END] && + flags & NFT_SET_ELEM_CATCHALL) + return false; + } else { + if (nla[NFTA_SET_ELEM_KEY_END]) + return false; + } + + return true; +} + +static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set, + const struct nlattr *attr, u32 nlmsg_flags) +{ + struct nft_expr *expr_array[NFT_SET_EXPR_MAX] = {}; + struct nlattr *nla[NFTA_SET_ELEM_MAX + 1]; + u8 genmask = nft_genmask_next(ctx->net); + u32 flags = 0, size = 0, num_exprs = 0; + struct nft_set_ext_tmpl tmpl; + struct nft_set_ext *ext, *ext2; + struct nft_set_elem elem; + struct nft_set_binding *binding; + struct nft_object *obj = NULL; + struct nft_userdata *udata; + struct nft_data_desc desc; + enum nft_registers dreg; + struct nft_trans *trans; + u64 timeout; + u64 expiration; + int err, i; + u8 ulen; + + err = nla_parse_nested_deprecated(nla, NFTA_SET_ELEM_MAX, attr, + nft_set_elem_policy, NULL); + if (err < 0) + return err; + + nft_set_ext_prepare(&tmpl); + + err = nft_setelem_parse_flags(set, 
nla[NFTA_SET_ELEM_FLAGS], &flags); + if (err < 0) + return err; + + if (((flags & NFT_SET_ELEM_CATCHALL) && nla[NFTA_SET_ELEM_KEY]) || + (!(flags & NFT_SET_ELEM_CATCHALL) && !nla[NFTA_SET_ELEM_KEY])) + return -EINVAL; + + if (flags != 0) { + err = nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS); + if (err < 0) + return err; + } + + if (set->flags & NFT_SET_MAP) { + if (nla[NFTA_SET_ELEM_DATA] == NULL && + !(flags & NFT_SET_ELEM_INTERVAL_END)) + return -EINVAL; + } else { + if (nla[NFTA_SET_ELEM_DATA] != NULL) + return -EINVAL; + } + + if (set->flags & NFT_SET_OBJECT) { + if (!nla[NFTA_SET_ELEM_OBJREF] && + !(flags & NFT_SET_ELEM_INTERVAL_END)) + return -EINVAL; + } else { + if (nla[NFTA_SET_ELEM_OBJREF]) + return -EINVAL; + } + + if (!nft_setelem_valid_key_end(set, nla, flags)) + return -EINVAL; + + if ((flags & NFT_SET_ELEM_INTERVAL_END) && + (nla[NFTA_SET_ELEM_DATA] || + nla[NFTA_SET_ELEM_OBJREF] || + nla[NFTA_SET_ELEM_TIMEOUT] || + nla[NFTA_SET_ELEM_EXPIRATION] || + nla[NFTA_SET_ELEM_USERDATA] || + nla[NFTA_SET_ELEM_EXPR] || + nla[NFTA_SET_ELEM_KEY_END] || + nla[NFTA_SET_ELEM_EXPRESSIONS])) + return -EINVAL; + + timeout = 0; + if (nla[NFTA_SET_ELEM_TIMEOUT] != NULL) { + if (!(set->flags & NFT_SET_TIMEOUT)) + return -EINVAL; + err = nf_msecs_to_jiffies64(nla[NFTA_SET_ELEM_TIMEOUT], + &timeout); + if (err) + return err; + } else if (set->flags & NFT_SET_TIMEOUT && + !(flags & NFT_SET_ELEM_INTERVAL_END)) { + timeout = READ_ONCE(set->timeout); + } + + expiration = 0; + if (nla[NFTA_SET_ELEM_EXPIRATION] != NULL) { + if (!(set->flags & NFT_SET_TIMEOUT)) + return -EINVAL; + err = nf_msecs_to_jiffies64(nla[NFTA_SET_ELEM_EXPIRATION], + &expiration); + if (err) + return err; + } + + if (nla[NFTA_SET_ELEM_EXPR]) { + struct nft_expr *expr; + + if (set->num_exprs && set->num_exprs != 1) + return -EOPNOTSUPP; + + expr = nft_set_elem_expr_alloc(ctx, set, + nla[NFTA_SET_ELEM_EXPR]); + if (IS_ERR(expr)) + return PTR_ERR(expr); + + expr_array[0] = expr; + num_exprs = 1; + + if (set->num_exprs && set->exprs[0]->ops != expr->ops) { + err = -EOPNOTSUPP; + goto err_set_elem_expr; + } + } else if (nla[NFTA_SET_ELEM_EXPRESSIONS]) { + struct nft_expr *expr; + struct nlattr *tmp; + int left; + + i = 0; + nla_for_each_nested(tmp, nla[NFTA_SET_ELEM_EXPRESSIONS], left) { + if (i == NFT_SET_EXPR_MAX || + (set->num_exprs && set->num_exprs == i)) { + err = -E2BIG; + goto err_set_elem_expr; + } + if (nla_type(tmp) != NFTA_LIST_ELEM) { + err = -EINVAL; + goto err_set_elem_expr; + } + expr = nft_set_elem_expr_alloc(ctx, set, tmp); + if (IS_ERR(expr)) { + err = PTR_ERR(expr); + goto err_set_elem_expr; + } + expr_array[i] = expr; + num_exprs++; + + if (set->num_exprs && expr->ops != set->exprs[i]->ops) { + err = -EOPNOTSUPP; + goto err_set_elem_expr; + } + i++; + } + if (set->num_exprs && set->num_exprs != i) { + err = -EOPNOTSUPP; + goto err_set_elem_expr; + } + } else if (set->num_exprs > 0 && + !(flags & NFT_SET_ELEM_INTERVAL_END)) { + err = nft_set_elem_expr_clone(ctx, set, expr_array); + if (err < 0) + goto err_set_elem_expr_clone; + + num_exprs = set->num_exprs; + } + + if (nla[NFTA_SET_ELEM_KEY]) { + err = nft_setelem_parse_key(ctx, set, &elem.key.val, + nla[NFTA_SET_ELEM_KEY]); + if (err < 0) + goto err_set_elem_expr; + + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen); + if (err < 0) + goto err_parse_key; + } + + if (nla[NFTA_SET_ELEM_KEY_END]) { + err = nft_setelem_parse_key(ctx, set, &elem.key_end.val, + nla[NFTA_SET_ELEM_KEY_END]); + if (err < 0) + goto err_parse_key; + + err = 
nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen); + if (err < 0) + goto err_parse_key_end; + } + + if (timeout > 0) { + err = nft_set_ext_add(&tmpl, NFT_SET_EXT_EXPIRATION); + if (err < 0) + goto err_parse_key_end; + + if (timeout != READ_ONCE(set->timeout)) { + err = nft_set_ext_add(&tmpl, NFT_SET_EXT_TIMEOUT); + if (err < 0) + goto err_parse_key_end; + } + } + + if (num_exprs) { + for (i = 0; i < num_exprs; i++) + size += expr_array[i]->ops->size; + + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_EXPRESSIONS, + sizeof(struct nft_set_elem_expr) + size); + if (err < 0) + goto err_parse_key_end; + } + + if (nla[NFTA_SET_ELEM_OBJREF] != NULL) { + obj = nft_obj_lookup(ctx->net, ctx->table, + nla[NFTA_SET_ELEM_OBJREF], + set->objtype, genmask); + if (IS_ERR(obj)) { + err = PTR_ERR(obj); + obj = NULL; + goto err_parse_key_end; + } + + if (!nft_use_inc(&obj->use)) { + err = -EMFILE; + obj = NULL; + goto err_parse_key_end; + } + + err = nft_set_ext_add(&tmpl, NFT_SET_EXT_OBJREF); + if (err < 0) + goto err_parse_key_end; + } + + if (nla[NFTA_SET_ELEM_DATA] != NULL) { + err = nft_setelem_parse_data(ctx, set, &desc, &elem.data.val, + nla[NFTA_SET_ELEM_DATA]); + if (err < 0) + goto err_parse_key_end; + + dreg = nft_type_to_reg(set->dtype); + list_for_each_entry(binding, &set->bindings, list) { + struct nft_ctx bind_ctx = { + .net = ctx->net, + .family = ctx->family, + .table = ctx->table, + .chain = (struct nft_chain *)binding->chain, + }; + + if (!(binding->flags & NFT_SET_MAP)) + continue; + + err = nft_validate_register_store(&bind_ctx, dreg, + &elem.data.val, + desc.type, desc.len); + if (err < 0) + goto err_parse_data; + + if (desc.type == NFT_DATA_VERDICT && + (elem.data.val.verdict.code == NFT_GOTO || + elem.data.val.verdict.code == NFT_JUMP)) + nft_validate_state_update(ctx->net, + NFT_VALIDATE_NEED); + } + + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_DATA, desc.len); + if (err < 0) + goto err_parse_data; + } + + /* The full maximum length of userdata can exceed the maximum + * offset value (U8_MAX) for following extensions, therefor it + * must be the last extension added. 
+ */ + ulen = 0; + if (nla[NFTA_SET_ELEM_USERDATA] != NULL) { + ulen = nla_len(nla[NFTA_SET_ELEM_USERDATA]); + if (ulen > 0) { + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_USERDATA, + ulen); + if (err < 0) + goto err_parse_data; + } + } + + elem.priv = nft_set_elem_init(set, &tmpl, elem.key.val.data, + elem.key_end.val.data, elem.data.val.data, + timeout, expiration, GFP_KERNEL_ACCOUNT); + if (IS_ERR(elem.priv)) { + err = PTR_ERR(elem.priv); + goto err_parse_data; + } + + ext = nft_set_elem_ext(set, elem.priv); + if (flags) + *nft_set_ext_flags(ext) = flags; + + if (obj) + *nft_set_ext_obj(ext) = obj; + + if (ulen > 0) { + if (nft_set_ext_check(&tmpl, NFT_SET_EXT_USERDATA, ulen) < 0) { + err = -EINVAL; + goto err_elem_free; + } + udata = nft_set_ext_userdata(ext); + udata->len = ulen - 1; + nla_memcpy(&udata->data, nla[NFTA_SET_ELEM_USERDATA], ulen); + } + err = nft_set_elem_expr_setup(ctx, &tmpl, ext, expr_array, num_exprs); + if (err < 0) + goto err_elem_free; + + trans = nft_trans_elem_alloc(ctx, NFT_MSG_NEWSETELEM, set); + if (trans == NULL) { + err = -ENOMEM; + goto err_elem_free; + } + + ext->genmask = nft_genmask_cur(ctx->net); + + err = nft_setelem_insert(ctx->net, set, &elem, &ext2, flags); + if (err) { + if (err == -EEXIST) { + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA) ^ + nft_set_ext_exists(ext2, NFT_SET_EXT_DATA) || + nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF) ^ + nft_set_ext_exists(ext2, NFT_SET_EXT_OBJREF)) + goto err_element_clash; + if ((nft_set_ext_exists(ext, NFT_SET_EXT_DATA) && + nft_set_ext_exists(ext2, NFT_SET_EXT_DATA) && + memcmp(nft_set_ext_data(ext), + nft_set_ext_data(ext2), set->dlen) != 0) || + (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF) && + nft_set_ext_exists(ext2, NFT_SET_EXT_OBJREF) && + *nft_set_ext_obj(ext) != *nft_set_ext_obj(ext2))) + goto err_element_clash; + else if (!(nlmsg_flags & NLM_F_EXCL)) + err = 0; + } else if (err == -ENOTEMPTY) { + /* ENOTEMPTY reports overlapping between this element + * and an existing one. 
+ */ + err = -EEXIST; + } + goto err_element_clash; + } + + if (!(flags & NFT_SET_ELEM_CATCHALL) && set->size && + !atomic_add_unless(&set->nelems, 1, set->size + set->ndeact)) { + err = -ENFILE; + goto err_set_full; + } + + nft_trans_elem(trans) = elem; + nft_trans_commit_list_add_tail(ctx->net, trans); + return 0; + +err_set_full: + nft_setelem_remove(ctx->net, set, &elem); +err_element_clash: + kfree(trans); +err_elem_free: + nf_tables_set_elem_destroy(ctx, set, elem.priv); +err_parse_data: + if (nla[NFTA_SET_ELEM_DATA] != NULL) + nft_data_release(&elem.data.val, desc.type); +err_parse_key_end: + if (obj) + nft_use_dec_restore(&obj->use); + + nft_data_release(&elem.key_end.val, NFT_DATA_VALUE); +err_parse_key: + nft_data_release(&elem.key.val, NFT_DATA_VALUE); +err_set_elem_expr: + for (i = 0; i < num_exprs && expr_array[i]; i++) + nft_expr_destroy(ctx, expr_array[i]); +err_set_elem_expr_clone: + return err; +} + +static int nf_tables_newsetelem(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct nftables_pernet *nft_net = nft_pernet(info->net); + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_set *set; + struct nft_ctx ctx; + int rem, err; + + if (nla[NFTA_SET_ELEM_LIST_ELEMENTS] == NULL) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_SET_ELEM_LIST_TABLE], family, + genmask, NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_ELEM_LIST_TABLE]); + return PTR_ERR(table); + } + + set = nft_set_lookup_global(net, table, nla[NFTA_SET_ELEM_LIST_SET], + nla[NFTA_SET_ELEM_LIST_SET_ID], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + if (!list_empty(&set->bindings) && + (set->flags & (NFT_SET_CONSTANT | NFT_SET_ANONYMOUS))) + return -EBUSY; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + nla_for_each_nested(attr, nla[NFTA_SET_ELEM_LIST_ELEMENTS], rem) { + err = nft_add_set_elem(&ctx, set, attr, info->nlh->nlmsg_flags); + if (err < 0) { + NL_SET_BAD_ATTR(extack, attr); + return err; + } + } + + if (nft_net->validate_state == NFT_VALIDATE_DO) + return nft_table_validate(net, table); + + return 0; +} + +/** + * nft_data_hold - hold a nft_data item + * + * @data: struct nft_data to release + * @type: type of data + * + * Hold a nft_data item. NFT_DATA_VALUE types can be silently discarded, + * NFT_DATA_VERDICT bumps the reference to chains in case of NFT_JUMP and + * NFT_GOTO verdicts. This function must be called on active data objects + * from the second phase of the commit protocol. 
+ */ +void nft_data_hold(const struct nft_data *data, enum nft_data_types type) +{ + struct nft_chain *chain; + + if (type == NFT_DATA_VERDICT) { + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + chain = data->verdict.chain; + nft_use_inc_restore(&chain->use); + break; + } + } +} + +static void nft_setelem_data_activate(const struct net *net, + const struct nft_set *set, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA)) + nft_data_hold(nft_set_ext_data(ext), set->dtype); + if (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF)) + nft_use_inc_restore(&(*nft_set_ext_obj(ext))->use); +} + +void nft_setelem_data_deactivate(const struct net *net, + const struct nft_set *set, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_DATA)) + nft_data_release(nft_set_ext_data(ext), set->dtype); + if (nft_set_ext_exists(ext, NFT_SET_EXT_OBJREF)) + nft_use_dec(&(*nft_set_ext_obj(ext))->use); +} + +static int nft_del_setelem(struct nft_ctx *ctx, struct nft_set *set, + const struct nlattr *attr) +{ + struct nlattr *nla[NFTA_SET_ELEM_MAX + 1]; + struct nft_set_ext_tmpl tmpl; + struct nft_set_elem elem; + struct nft_set_ext *ext; + struct nft_trans *trans; + u32 flags = 0; + int err; + + err = nla_parse_nested_deprecated(nla, NFTA_SET_ELEM_MAX, attr, + nft_set_elem_policy, NULL); + if (err < 0) + return err; + + err = nft_setelem_parse_flags(set, nla[NFTA_SET_ELEM_FLAGS], &flags); + if (err < 0) + return err; + + if (!nla[NFTA_SET_ELEM_KEY] && !(flags & NFT_SET_ELEM_CATCHALL)) + return -EINVAL; + + if (!nft_setelem_valid_key_end(set, nla, flags)) + return -EINVAL; + + nft_set_ext_prepare(&tmpl); + + if (flags != 0) { + err = nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS); + if (err < 0) + return err; + } + + if (nla[NFTA_SET_ELEM_KEY]) { + err = nft_setelem_parse_key(ctx, set, &elem.key.val, + nla[NFTA_SET_ELEM_KEY]); + if (err < 0) + return err; + + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen); + if (err < 0) + goto fail_elem; + } + + if (nla[NFTA_SET_ELEM_KEY_END]) { + err = nft_setelem_parse_key(ctx, set, &elem.key_end.val, + nla[NFTA_SET_ELEM_KEY_END]); + if (err < 0) + goto fail_elem; + + err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen); + if (err < 0) + goto fail_elem_key_end; + } + + err = -ENOMEM; + elem.priv = nft_set_elem_init(set, &tmpl, elem.key.val.data, + elem.key_end.val.data, NULL, 0, 0, + GFP_KERNEL_ACCOUNT); + if (IS_ERR(elem.priv)) { + err = PTR_ERR(elem.priv); + goto fail_elem_key_end; + } + + ext = nft_set_elem_ext(set, elem.priv); + if (flags) + *nft_set_ext_flags(ext) = flags; + + trans = nft_trans_elem_alloc(ctx, NFT_MSG_DELSETELEM, set); + if (trans == NULL) + goto fail_trans; + + err = nft_setelem_deactivate(ctx->net, set, &elem, flags); + if (err < 0) + goto fail_ops; + + nft_setelem_data_deactivate(ctx->net, set, &elem); + + nft_trans_elem(trans) = elem; + nft_trans_commit_list_add_tail(ctx->net, trans); + return 0; + +fail_ops: + kfree(trans); +fail_trans: + kfree(elem.priv); +fail_elem_key_end: + nft_data_release(&elem.key_end.val, NFT_DATA_VALUE); +fail_elem: + nft_data_release(&elem.key.val, NFT_DATA_VALUE); + return err; +} + +static int nft_setelem_flush(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + struct nft_trans *trans; + int err; + + trans = 
nft_trans_alloc_gfp(ctx, NFT_MSG_DELSETELEM, + sizeof(struct nft_trans_elem), GFP_ATOMIC); + if (!trans) + return -ENOMEM; + + if (!set->ops->flush(ctx->net, set, elem->priv)) { + err = -ENOENT; + goto err1; + } + set->ndeact++; + + nft_setelem_data_deactivate(ctx->net, set, elem); + nft_trans_elem_set(trans) = set; + nft_trans_elem(trans) = *elem; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +err1: + kfree(trans); + return err; +} + +static int __nft_set_catchall_flush(const struct nft_ctx *ctx, + struct nft_set *set, + struct nft_set_elem *elem) +{ + struct nft_trans *trans; + + trans = nft_trans_alloc_gfp(ctx, NFT_MSG_DELSETELEM, + sizeof(struct nft_trans_elem), GFP_KERNEL); + if (!trans) + return -ENOMEM; + + nft_setelem_data_deactivate(ctx->net, set, elem); + nft_trans_elem_set(trans) = set; + nft_trans_elem(trans) = *elem; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; +} + +static int nft_set_catchall_flush(const struct nft_ctx *ctx, + struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set_elem_catchall *catchall; + struct nft_set_elem elem; + struct nft_set_ext *ext; + int ret = 0; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + elem.priv = catchall->elem; + ret = __nft_set_catchall_flush(ctx, set, &elem); + if (ret < 0) + break; + nft_set_elem_change_active(ctx->net, set, ext); + } + + return ret; +} + +static int nft_set_flush(struct nft_ctx *ctx, struct nft_set *set, u8 genmask) +{ + struct nft_set_iter iter = { + .genmask = genmask, + .fn = nft_setelem_flush, + }; + + set->ops->walk(ctx, set, &iter); + if (!iter.err) + iter.err = nft_set_catchall_flush(ctx, set); + + return iter.err; +} + +static int nf_tables_delsetelem(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_set *set; + struct nft_ctx ctx; + int rem, err = 0; + + table = nft_table_lookup(net, nla[NFTA_SET_ELEM_LIST_TABLE], family, + genmask, NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_SET_ELEM_LIST_TABLE]); + return PTR_ERR(table); + } + + set = nft_set_lookup(table, nla[NFTA_SET_ELEM_LIST_SET], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + if (nft_set_is_anonymous(set)) + return -EOPNOTSUPP; + + if (!list_empty(&set->bindings) && (set->flags & NFT_SET_CONSTANT)) + return -EBUSY; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (!nla[NFTA_SET_ELEM_LIST_ELEMENTS]) + return nft_set_flush(&ctx, set, genmask); + + nla_for_each_nested(attr, nla[NFTA_SET_ELEM_LIST_ELEMENTS], rem) { + err = nft_del_setelem(&ctx, set, attr); + if (err < 0) { + NL_SET_BAD_ATTR(extack, attr); + break; + } + } + return err; +} + +/* + * Stateful objects + */ + +/** + * nft_register_obj- register nf_tables stateful object type + * @obj_type: object type + * + * Registers the object type for use with nf_tables. Returns zero on + * success or a negative errno code otherwise. 
+ */ +int nft_register_obj(struct nft_object_type *obj_type) +{ + if (obj_type->type == NFT_OBJECT_UNSPEC) + return -EINVAL; + + nfnl_lock(NFNL_SUBSYS_NFTABLES); + list_add_rcu(&obj_type->list, &nf_tables_objects); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); + return 0; +} +EXPORT_SYMBOL_GPL(nft_register_obj); + +/** + * nft_unregister_obj - unregister nf_tables object type + * @obj_type: object type + * + * Unregisters the object type for use with nf_tables. + */ +void nft_unregister_obj(struct nft_object_type *obj_type) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + list_del_rcu(&obj_type->list); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_unregister_obj); + +struct nft_object *nft_obj_lookup(const struct net *net, + const struct nft_table *table, + const struct nlattr *nla, u32 objtype, + u8 genmask) +{ + struct nft_object_hash_key k = { .table = table }; + char search[NFT_OBJ_MAXNAMELEN]; + struct rhlist_head *tmp, *list; + struct nft_object *obj; + + nla_strscpy(search, nla, sizeof(search)); + k.name = search; + + WARN_ON_ONCE(!rcu_read_lock_held() && + !lockdep_commit_lock_is_held(net)); + + rcu_read_lock(); + list = rhltable_lookup(&nft_objname_ht, &k, nft_objname_ht_params); + if (!list) + goto out; + + rhl_for_each_entry_rcu(obj, tmp, list, rhlhead) { + if (objtype == obj->ops->type->type && + nft_active_genmask(obj, genmask)) { + rcu_read_unlock(); + return obj; + } + } +out: + rcu_read_unlock(); + return ERR_PTR(-ENOENT); +} +EXPORT_SYMBOL_GPL(nft_obj_lookup); + +static struct nft_object *nft_obj_lookup_byhandle(const struct nft_table *table, + const struct nlattr *nla, + u32 objtype, u8 genmask) +{ + struct nft_object *obj; + + list_for_each_entry(obj, &table->objects, list) { + if (be64_to_cpu(nla_get_be64(nla)) == obj->handle && + objtype == obj->ops->type->type && + nft_active_genmask(obj, genmask)) + return obj; + } + return ERR_PTR(-ENOENT); +} + +static const struct nla_policy nft_obj_policy[NFTA_OBJ_MAX + 1] = { + [NFTA_OBJ_TABLE] = { .type = NLA_STRING, + .len = NFT_TABLE_MAXNAMELEN - 1 }, + [NFTA_OBJ_NAME] = { .type = NLA_STRING, + .len = NFT_OBJ_MAXNAMELEN - 1 }, + [NFTA_OBJ_TYPE] = { .type = NLA_U32 }, + [NFTA_OBJ_DATA] = { .type = NLA_NESTED }, + [NFTA_OBJ_HANDLE] = { .type = NLA_U64}, + [NFTA_OBJ_USERDATA] = { .type = NLA_BINARY, + .len = NFT_USERDATA_MAXLEN }, +}; + +static struct nft_object *nft_obj_init(const struct nft_ctx *ctx, + const struct nft_object_type *type, + const struct nlattr *attr) +{ + struct nlattr **tb; + const struct nft_object_ops *ops; + struct nft_object *obj; + int err = -ENOMEM; + + tb = kmalloc_array(type->maxattr + 1, sizeof(*tb), GFP_KERNEL); + if (!tb) + goto err1; + + if (attr) { + err = nla_parse_nested_deprecated(tb, type->maxattr, attr, + type->policy, NULL); + if (err < 0) + goto err2; + } else { + memset(tb, 0, sizeof(tb[0]) * (type->maxattr + 1)); + } + + if (type->select_ops) { + ops = type->select_ops(ctx, (const struct nlattr * const *)tb); + if (IS_ERR(ops)) { + err = PTR_ERR(ops); + goto err2; + } + } else { + ops = type->ops; + } + + err = -ENOMEM; + obj = kzalloc(sizeof(*obj) + ops->size, GFP_KERNEL_ACCOUNT); + if (!obj) + goto err2; + + err = ops->init(ctx, (const struct nlattr * const *)tb, obj); + if (err < 0) + goto err3; + + obj->ops = ops; + + kfree(tb); + return obj; +err3: + kfree(obj); +err2: + kfree(tb); +err1: + return ERR_PTR(err); +} + +static int nft_object_dump(struct sk_buff *skb, unsigned int attr, + struct nft_object *obj, bool reset) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, 
attr); + if (!nest) + goto nla_put_failure; + if (obj->ops->dump(skb, obj, reset) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + return -1; +} + +static const struct nft_object_type *__nft_obj_type_get(u32 objtype) +{ + const struct nft_object_type *type; + + list_for_each_entry(type, &nf_tables_objects, list) { + if (objtype == type->type) + return type; + } + return NULL; +} + +static const struct nft_object_type * +nft_obj_type_get(struct net *net, u32 objtype) +{ + const struct nft_object_type *type; + + type = __nft_obj_type_get(objtype); + if (type != NULL && try_module_get(type->owner)) + return type; + + lockdep_nfnl_nft_mutex_not_held(); +#ifdef CONFIG_MODULES + if (type == NULL) { + if (nft_request_module(net, "nft-obj-%u", objtype) == -EAGAIN) + return ERR_PTR(-EAGAIN); + } +#endif + return ERR_PTR(-ENOENT); +} + +static int nf_tables_updobj(const struct nft_ctx *ctx, + const struct nft_object_type *type, + const struct nlattr *attr, + struct nft_object *obj) +{ + struct nft_object *newobj; + struct nft_trans *trans; + int err = -ENOMEM; + + if (!try_module_get(type->owner)) + return -ENOENT; + + trans = nft_trans_alloc(ctx, NFT_MSG_NEWOBJ, + sizeof(struct nft_trans_obj)); + if (!trans) + goto err_trans; + + newobj = nft_obj_init(ctx, type, attr); + if (IS_ERR(newobj)) { + err = PTR_ERR(newobj); + goto err_free_trans; + } + + nft_trans_obj(trans) = obj; + nft_trans_obj_update(trans) = true; + nft_trans_obj_newobj(trans) = newobj; + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; + +err_free_trans: + kfree(trans); +err_trans: + module_put(type->owner); + return err; +} + +static int nf_tables_newobj(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_object_type *type; + struct net *net = info->net; + struct nft_table *table; + struct nft_object *obj; + struct nft_ctx ctx; + u32 objtype; + int err; + + if (!nla[NFTA_OBJ_TYPE] || + !nla[NFTA_OBJ_NAME] || + !nla[NFTA_OBJ_DATA]) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_OBJ_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_TABLE]); + return PTR_ERR(table); + } + + objtype = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); + obj = nft_obj_lookup(net, table, nla[NFTA_OBJ_NAME], objtype, genmask); + if (IS_ERR(obj)) { + err = PTR_ERR(obj); + if (err != -ENOENT) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_NAME]); + return err; + } + } else { + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_NAME]); + return -EEXIST; + } + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) + return -EOPNOTSUPP; + + type = __nft_obj_type_get(objtype); + if (WARN_ON_ONCE(!type)) + return -ENOENT; + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + return nf_tables_updobj(&ctx, type, nla[NFTA_OBJ_DATA], obj); + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (!nft_use_inc(&table->use)) + return -EMFILE; + + type = nft_obj_type_get(net, objtype); + if (IS_ERR(type)) { + err = PTR_ERR(type); + goto err_type; + } + + obj = nft_obj_init(&ctx, type, nla[NFTA_OBJ_DATA]); + if (IS_ERR(obj)) { + err = PTR_ERR(obj); + goto err_init; + } + obj->key.table = table; + obj->handle = nf_tables_alloc_handle(table); + + obj->key.name = nla_strdup(nla[NFTA_OBJ_NAME], 
GFP_KERNEL_ACCOUNT); + if (!obj->key.name) { + err = -ENOMEM; + goto err_strdup; + } + + if (nla[NFTA_OBJ_USERDATA]) { + obj->udata = nla_memdup(nla[NFTA_OBJ_USERDATA], GFP_KERNEL_ACCOUNT); + if (obj->udata == NULL) + goto err_userdata; + + obj->udlen = nla_len(nla[NFTA_OBJ_USERDATA]); + } + + err = nft_trans_obj_add(&ctx, NFT_MSG_NEWOBJ, obj); + if (err < 0) + goto err_trans; + + err = rhltable_insert(&nft_objname_ht, &obj->rhlhead, + nft_objname_ht_params); + if (err < 0) + goto err_obj_ht; + + list_add_tail_rcu(&obj->list, &table->objects); + + return 0; +err_obj_ht: + /* queued in transaction log */ + INIT_LIST_HEAD(&obj->list); + return err; +err_trans: + kfree(obj->udata); +err_userdata: + kfree(obj->key.name); +err_strdup: + if (obj->ops->destroy) + obj->ops->destroy(&ctx, obj); + kfree(obj); +err_init: + module_put(type->owner); +err_type: + nft_use_dec_restore(&table->use); + + return err; +} + +static int nf_tables_fill_obj_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, int event, u32 flags, + int family, const struct nft_table *table, + struct nft_object *obj, bool reset) +{ + struct nlmsghdr *nlh; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, family, + NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_OBJ_TABLE, table->name) || + nla_put_string(skb, NFTA_OBJ_NAME, obj->key.name) || + nla_put_be32(skb, NFTA_OBJ_TYPE, htonl(obj->ops->type->type)) || + nla_put_be32(skb, NFTA_OBJ_USE, htonl(obj->use)) || + nft_object_dump(skb, NFTA_OBJ_DATA, obj, reset) || + nla_put_be64(skb, NFTA_OBJ_HANDLE, cpu_to_be64(obj->handle), + NFTA_OBJ_PAD)) + goto nla_put_failure; + + if (obj->udata && + nla_put(skb, NFTA_OBJ_USERDATA, obj->udlen, obj->udata)) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +static void audit_log_obj_reset(const struct nft_table *table, + unsigned int base_seq, unsigned int nentries) +{ + char *buf = kasprintf(GFP_ATOMIC, "%s:%u", table->name, base_seq); + + audit_log_nfcfg(buf, table->family, nentries, + AUDIT_NFT_OP_OBJ_RESET, GFP_ATOMIC); + kfree(buf); +} + +struct nft_obj_filter { + char *table; + u32 type; +}; + +static int nf_tables_dump_obj(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + const struct nft_table *table; + unsigned int idx = 0, s_idx = cb->args[0]; + struct nft_obj_filter *filter = cb->data; + struct net *net = sock_net(skb->sk); + int family = nfmsg->nfgen_family; + struct nftables_pernet *nft_net; + unsigned int entries = 0; + struct nft_object *obj; + bool reset = false; + int rc = 0; + + if (NFNL_MSG_TYPE(cb->nlh->nlmsg_type) == NFT_MSG_GETOBJ_RESET) + reset = true; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (family != NFPROTO_UNSPEC && family != table->family) + continue; + + entries = 0; + list_for_each_entry_rcu(obj, &table->objects, list) { + if (!nft_is_active(net, obj)) + goto cont; + if (idx < s_idx) + goto cont; + if (idx > s_idx) + memset(&cb->args[1], 0, + sizeof(cb->args) - sizeof(cb->args[0])); + if (filter && filter->table && + strcmp(filter->table, table->name)) + goto cont; + if (filter && + filter->type != NFT_OBJECT_UNSPEC && + obj->ops->type->type != filter->type) + goto cont; + + rc = nf_tables_fill_obj_info(skb, net, + NETLINK_CB(cb->skb).portid, + 
cb->nlh->nlmsg_seq, + NFT_MSG_NEWOBJ, + NLM_F_MULTI | NLM_F_APPEND, + table->family, table, + obj, reset); + if (rc < 0) + break; + + entries++; + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + if (reset && entries) + audit_log_obj_reset(table, nft_net->base_seq, entries); + if (rc < 0) + break; + } + rcu_read_unlock(); + + cb->args[0] = idx; + return skb->len; +} + +static int nf_tables_dump_obj_start(struct netlink_callback *cb) +{ + const struct nlattr * const *nla = cb->data; + struct nft_obj_filter *filter = NULL; + + if (nla[NFTA_OBJ_TABLE] || nla[NFTA_OBJ_TYPE]) { + filter = kzalloc(sizeof(*filter), GFP_ATOMIC); + if (!filter) + return -ENOMEM; + + if (nla[NFTA_OBJ_TABLE]) { + filter->table = nla_strdup(nla[NFTA_OBJ_TABLE], GFP_ATOMIC); + if (!filter->table) { + kfree(filter); + return -ENOMEM; + } + } + + if (nla[NFTA_OBJ_TYPE]) + filter->type = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); + } + + cb->data = filter; + return 0; +} + +static int nf_tables_dump_obj_done(struct netlink_callback *cb) +{ + struct nft_obj_filter *filter = cb->data; + + if (filter) { + kfree(filter->table); + kfree(filter); + } + + return 0; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getobj(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nft_table *table; + struct net *net = info->net; + struct nft_object *obj; + struct sk_buff *skb2; + bool reset = false; + u32 objtype; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = nf_tables_dump_obj_start, + .dump = nf_tables_dump_obj, + .done = nf_tables_dump_obj_done, + .module = THIS_MODULE, + .data = (void *)nla, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + if (!nla[NFTA_OBJ_NAME] || + !nla[NFTA_OBJ_TYPE]) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_OBJ_TABLE], family, genmask, 0); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_TABLE]); + return PTR_ERR(table); + } + + objtype = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); + obj = nft_obj_lookup(net, table, nla[NFTA_OBJ_NAME], objtype, genmask); + if (IS_ERR(obj)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_NAME]); + return PTR_ERR(obj); + } + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb2) + return -ENOMEM; + + if (NFNL_MSG_TYPE(info->nlh->nlmsg_type) == NFT_MSG_GETOBJ_RESET) + reset = true; + + if (reset) { + const struct nftables_pernet *nft_net; + char *buf; + + nft_net = nft_pernet(net); + buf = kasprintf(GFP_ATOMIC, "%s:%u", table->name, nft_net->base_seq); + + audit_log_nfcfg(buf, + family, + 1, + AUDIT_NFT_OP_OBJ_RESET, + GFP_ATOMIC); + kfree(buf); + } + + err = nf_tables_fill_obj_info(skb2, net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, NFT_MSG_NEWOBJ, 0, + family, table, obj, reset); + if (err < 0) + goto err_fill_obj_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_obj_info: + kfree_skb(skb2); + return err; +} + +static void nft_obj_destroy(const struct nft_ctx *ctx, struct nft_object *obj) +{ + if (obj->ops->destroy) + obj->ops->destroy(ctx, obj); + + module_put(obj->ops->type->owner); + kfree(obj->key.name); + kfree(obj->udata); + kfree(obj); +} + +static int nf_tables_delobj(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = 
info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_object *obj; + struct nft_ctx ctx; + u32 objtype; + + if (!nla[NFTA_OBJ_TYPE] || + (!nla[NFTA_OBJ_NAME] && !nla[NFTA_OBJ_HANDLE])) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_OBJ_TABLE], family, genmask, + NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_OBJ_TABLE]); + return PTR_ERR(table); + } + + objtype = ntohl(nla_get_be32(nla[NFTA_OBJ_TYPE])); + if (nla[NFTA_OBJ_HANDLE]) { + attr = nla[NFTA_OBJ_HANDLE]; + obj = nft_obj_lookup_byhandle(table, attr, objtype, genmask); + } else { + attr = nla[NFTA_OBJ_NAME]; + obj = nft_obj_lookup(net, table, attr, objtype, genmask); + } + + if (IS_ERR(obj)) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(obj); + } + if (obj->use > 0) { + NL_SET_BAD_ATTR(extack, attr); + return -EBUSY; + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + return nft_delobj(&ctx, obj); +} + +static void +__nft_obj_notify(struct net *net, const struct nft_table *table, + struct nft_object *obj, u32 portid, u32 seq, int event, + u16 flags, int family, int report, gfp_t gfp) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct sk_buff *skb; + int err; + + if (!report && + !nfnetlink_has_listeners(net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, gfp); + if (skb == NULL) + goto err; + + err = nf_tables_fill_obj_info(skb, net, portid, seq, event, + flags & (NLM_F_CREATE | NLM_F_EXCL), + family, table, obj, false); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_notify_enqueue(skb, report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(net, portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +void nft_obj_notify(struct net *net, const struct nft_table *table, + struct nft_object *obj, u32 portid, u32 seq, int event, + u16 flags, int family, int report, gfp_t gfp) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + char *buf = kasprintf(gfp, "%s:%u", + table->name, nft_net->base_seq); + + audit_log_nfcfg(buf, + family, + obj->handle, + event == NFT_MSG_NEWOBJ ? 
+ AUDIT_NFT_OP_OBJ_REGISTER : + AUDIT_NFT_OP_OBJ_UNREGISTER, + gfp); + kfree(buf); + + __nft_obj_notify(net, table, obj, portid, seq, event, + flags, family, report, gfp); +} +EXPORT_SYMBOL_GPL(nft_obj_notify); + +static void nf_tables_obj_notify(const struct nft_ctx *ctx, + struct nft_object *obj, int event) +{ + __nft_obj_notify(ctx->net, ctx->table, obj, ctx->portid, + ctx->seq, event, ctx->flags, ctx->family, + ctx->report, GFP_KERNEL); +} + +/* + * Flow tables + */ +void nft_register_flowtable_type(struct nf_flowtable_type *type) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + list_add_tail_rcu(&type->list, &nf_tables_flowtables); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_register_flowtable_type); + +void nft_unregister_flowtable_type(struct nf_flowtable_type *type) +{ + nfnl_lock(NFNL_SUBSYS_NFTABLES); + list_del_rcu(&type->list); + nfnl_unlock(NFNL_SUBSYS_NFTABLES); +} +EXPORT_SYMBOL_GPL(nft_unregister_flowtable_type); + +static const struct nla_policy nft_flowtable_policy[NFTA_FLOWTABLE_MAX + 1] = { + [NFTA_FLOWTABLE_TABLE] = { .type = NLA_STRING, + .len = NFT_NAME_MAXLEN - 1 }, + [NFTA_FLOWTABLE_NAME] = { .type = NLA_STRING, + .len = NFT_NAME_MAXLEN - 1 }, + [NFTA_FLOWTABLE_HOOK] = { .type = NLA_NESTED }, + [NFTA_FLOWTABLE_HANDLE] = { .type = NLA_U64 }, + [NFTA_FLOWTABLE_FLAGS] = { .type = NLA_U32 }, +}; + +struct nft_flowtable *nft_flowtable_lookup(const struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + struct nft_flowtable *flowtable; + + list_for_each_entry_rcu(flowtable, &table->flowtables, list) { + if (!nla_strcmp(nla, flowtable->name) && + nft_active_genmask(flowtable, genmask)) + return flowtable; + } + return ERR_PTR(-ENOENT); +} +EXPORT_SYMBOL_GPL(nft_flowtable_lookup); + +void nf_tables_deactivate_flowtable(const struct nft_ctx *ctx, + struct nft_flowtable *flowtable, + enum nft_trans_phase phase) +{ + switch (phase) { + case NFT_TRANS_PREPARE_ERROR: + case NFT_TRANS_PREPARE: + case NFT_TRANS_ABORT: + case NFT_TRANS_RELEASE: + nft_use_dec(&flowtable->use); + fallthrough; + default: + return; + } +} +EXPORT_SYMBOL_GPL(nf_tables_deactivate_flowtable); + +static struct nft_flowtable * +nft_flowtable_lookup_byhandle(const struct nft_table *table, + const struct nlattr *nla, u8 genmask) +{ + struct nft_flowtable *flowtable; + + list_for_each_entry(flowtable, &table->flowtables, list) { + if (be64_to_cpu(nla_get_be64(nla)) == flowtable->handle && + nft_active_genmask(flowtable, genmask)) + return flowtable; + } + return ERR_PTR(-ENOENT); +} + +struct nft_flowtable_hook { + u32 num; + int priority; + struct list_head list; +}; + +static const struct nla_policy nft_flowtable_hook_policy[NFTA_FLOWTABLE_HOOK_MAX + 1] = { + [NFTA_FLOWTABLE_HOOK_NUM] = { .type = NLA_U32 }, + [NFTA_FLOWTABLE_HOOK_PRIORITY] = { .type = NLA_U32 }, + [NFTA_FLOWTABLE_HOOK_DEVS] = { .type = NLA_NESTED }, +}; + +static int nft_flowtable_parse_hook(const struct nft_ctx *ctx, + const struct nlattr *attr, + struct nft_flowtable_hook *flowtable_hook, + struct nft_flowtable *flowtable, bool add) +{ + struct nlattr *tb[NFTA_FLOWTABLE_HOOK_MAX + 1]; + struct nft_hook *hook; + int hooknum, priority; + int err; + + INIT_LIST_HEAD(&flowtable_hook->list); + + err = nla_parse_nested_deprecated(tb, NFTA_FLOWTABLE_HOOK_MAX, attr, + nft_flowtable_hook_policy, NULL); + if (err < 0) + return err; + + if (add) { + if (!tb[NFTA_FLOWTABLE_HOOK_NUM] || + !tb[NFTA_FLOWTABLE_HOOK_PRIORITY]) + return -EINVAL; + + hooknum = ntohl(nla_get_be32(tb[NFTA_FLOWTABLE_HOOK_NUM])); + if (hooknum != 
NF_NETDEV_INGRESS) + return -EOPNOTSUPP; + + priority = ntohl(nla_get_be32(tb[NFTA_FLOWTABLE_HOOK_PRIORITY])); + + flowtable_hook->priority = priority; + flowtable_hook->num = hooknum; + } else { + if (tb[NFTA_FLOWTABLE_HOOK_NUM]) { + hooknum = ntohl(nla_get_be32(tb[NFTA_FLOWTABLE_HOOK_NUM])); + if (hooknum != flowtable->hooknum) + return -EOPNOTSUPP; + } + + if (tb[NFTA_FLOWTABLE_HOOK_PRIORITY]) { + priority = ntohl(nla_get_be32(tb[NFTA_FLOWTABLE_HOOK_PRIORITY])); + if (priority != flowtable->data.priority) + return -EOPNOTSUPP; + } + + flowtable_hook->priority = flowtable->data.priority; + flowtable_hook->num = flowtable->hooknum; + } + + if (tb[NFTA_FLOWTABLE_HOOK_DEVS]) { + err = nf_tables_parse_netdev_hooks(ctx->net, + tb[NFTA_FLOWTABLE_HOOK_DEVS], + &flowtable_hook->list); + if (err < 0) + return err; + } + + list_for_each_entry(hook, &flowtable_hook->list, list) { + hook->ops.pf = NFPROTO_NETDEV; + hook->ops.hooknum = flowtable_hook->num; + hook->ops.priority = flowtable_hook->priority; + hook->ops.priv = &flowtable->data; + hook->ops.hook = flowtable->data.type->hook; + } + + return err; +} + +static const struct nf_flowtable_type *__nft_flowtable_type_get(u8 family) +{ + const struct nf_flowtable_type *type; + + list_for_each_entry(type, &nf_tables_flowtables, list) { + if (family == type->family) + return type; + } + return NULL; +} + +static const struct nf_flowtable_type * +nft_flowtable_type_get(struct net *net, u8 family) +{ + const struct nf_flowtable_type *type; + + type = __nft_flowtable_type_get(family); + if (type != NULL && try_module_get(type->owner)) + return type; + + lockdep_nfnl_nft_mutex_not_held(); +#ifdef CONFIG_MODULES + if (type == NULL) { + if (nft_request_module(net, "nf-flowtable-%u", family) == -EAGAIN) + return ERR_PTR(-EAGAIN); + } +#endif + return ERR_PTR(-ENOENT); +} + +/* Only called from error and netdev event paths. 
*/ +static void nft_unregister_flowtable_hook(struct net *net, + struct nft_flowtable *flowtable, + struct nft_hook *hook) +{ + nf_unregister_net_hook(net, &hook->ops); + flowtable->data.type->setup(&flowtable->data, hook->ops.dev, + FLOW_BLOCK_UNBIND); +} + +static void __nft_unregister_flowtable_net_hooks(struct net *net, + struct list_head *hook_list, + bool release_netdev) +{ + struct nft_hook *hook, *next; + + list_for_each_entry_safe(hook, next, hook_list, list) { + nf_unregister_net_hook(net, &hook->ops); + if (release_netdev) { + list_del(&hook->list); + kfree_rcu(hook, rcu); + } + } +} + +static void nft_unregister_flowtable_net_hooks(struct net *net, + struct list_head *hook_list) +{ + __nft_unregister_flowtable_net_hooks(net, hook_list, false); +} + +static int nft_register_flowtable_net_hooks(struct net *net, + struct nft_table *table, + struct list_head *hook_list, + struct nft_flowtable *flowtable) +{ + struct nft_hook *hook, *hook2, *next; + struct nft_flowtable *ft; + int err, i = 0; + + list_for_each_entry(hook, hook_list, list) { + list_for_each_entry(ft, &table->flowtables, list) { + if (!nft_is_active_next(net, ft)) + continue; + + list_for_each_entry(hook2, &ft->hook_list, list) { + if (hook->ops.dev == hook2->ops.dev && + hook->ops.pf == hook2->ops.pf) { + err = -EEXIST; + goto err_unregister_net_hooks; + } + } + } + + err = flowtable->data.type->setup(&flowtable->data, + hook->ops.dev, + FLOW_BLOCK_BIND); + if (err < 0) + goto err_unregister_net_hooks; + + err = nf_register_net_hook(net, &hook->ops); + if (err < 0) { + flowtable->data.type->setup(&flowtable->data, + hook->ops.dev, + FLOW_BLOCK_UNBIND); + goto err_unregister_net_hooks; + } + + i++; + } + + return 0; + +err_unregister_net_hooks: + list_for_each_entry_safe(hook, next, hook_list, list) { + if (i-- <= 0) + break; + + nft_unregister_flowtable_hook(net, flowtable, hook); + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + } + + return err; +} + +static void nft_flowtable_hooks_destroy(struct list_head *hook_list) +{ + struct nft_hook *hook, *next; + + list_for_each_entry_safe(hook, next, hook_list, list) { + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + } +} + +static int nft_flowtable_update(struct nft_ctx *ctx, const struct nlmsghdr *nlh, + struct nft_flowtable *flowtable) +{ + const struct nlattr * const *nla = ctx->nla; + struct nft_flowtable_hook flowtable_hook; + struct nft_hook *hook, *next; + struct nft_trans *trans; + bool unregister = false; + u32 flags; + int err; + + err = nft_flowtable_parse_hook(ctx, nla[NFTA_FLOWTABLE_HOOK], + &flowtable_hook, flowtable, false); + if (err < 0) + return err; + + list_for_each_entry_safe(hook, next, &flowtable_hook.list, list) { + if (nft_hook_list_find(&flowtable->hook_list, hook)) { + list_del(&hook->list); + kfree(hook); + } + } + + if (nla[NFTA_FLOWTABLE_FLAGS]) { + flags = ntohl(nla_get_be32(nla[NFTA_FLOWTABLE_FLAGS])); + if (flags & ~NFT_FLOWTABLE_MASK) { + err = -EOPNOTSUPP; + goto err_flowtable_update_hook; + } + if ((flowtable->data.flags & NFT_FLOWTABLE_HW_OFFLOAD) ^ + (flags & NFT_FLOWTABLE_HW_OFFLOAD)) { + err = -EOPNOTSUPP; + goto err_flowtable_update_hook; + } + } else { + flags = flowtable->data.flags; + } + + err = nft_register_flowtable_net_hooks(ctx->net, ctx->table, + &flowtable_hook.list, flowtable); + if (err < 0) + goto err_flowtable_update_hook; + + trans = nft_trans_alloc(ctx, NFT_MSG_NEWFLOWTABLE, + sizeof(struct nft_trans_flowtable)); + if (!trans) { + unregister = true; + err = -ENOMEM; + goto err_flowtable_update_hook; + } 
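+ /* Stash the updated flags and the newly parsed hook list in the transaction; they are applied to the flowtable when the batch is committed. */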
+ + nft_trans_flowtable_flags(trans) = flags; + nft_trans_flowtable(trans) = flowtable; + nft_trans_flowtable_update(trans) = true; + INIT_LIST_HEAD(&nft_trans_flowtable_hooks(trans)); + list_splice(&flowtable_hook.list, &nft_trans_flowtable_hooks(trans)); + + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; + +err_flowtable_update_hook: + list_for_each_entry_safe(hook, next, &flowtable_hook.list, list) { + if (unregister) + nft_unregister_flowtable_hook(ctx->net, flowtable, hook); + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + } + + return err; + +} + +static int nf_tables_newflowtable(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + struct nft_flowtable_hook flowtable_hook; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + const struct nf_flowtable_type *type; + struct nft_flowtable *flowtable; + struct nft_hook *hook, *next; + struct net *net = info->net; + struct nft_table *table; + struct nft_ctx ctx; + int err; + + if (!nla[NFTA_FLOWTABLE_TABLE] || + !nla[NFTA_FLOWTABLE_NAME] || + !nla[NFTA_FLOWTABLE_HOOK]) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_FLOWTABLE_TABLE], family, + genmask, NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_FLOWTABLE_TABLE]); + return PTR_ERR(table); + } + + flowtable = nft_flowtable_lookup(table, nla[NFTA_FLOWTABLE_NAME], + genmask); + if (IS_ERR(flowtable)) { + err = PTR_ERR(flowtable); + if (err != -ENOENT) { + NL_SET_BAD_ATTR(extack, nla[NFTA_FLOWTABLE_NAME]); + return err; + } + } else { + if (info->nlh->nlmsg_flags & NLM_F_EXCL) { + NL_SET_BAD_ATTR(extack, nla[NFTA_FLOWTABLE_NAME]); + return -EEXIST; + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + return nft_flowtable_update(&ctx, info->nlh, flowtable); + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (!nft_use_inc(&table->use)) + return -EMFILE; + + flowtable = kzalloc(sizeof(*flowtable), GFP_KERNEL_ACCOUNT); + if (!flowtable) { + err = -ENOMEM; + goto flowtable_alloc; + } + + flowtable->table = table; + flowtable->handle = nf_tables_alloc_handle(table); + INIT_LIST_HEAD(&flowtable->hook_list); + + flowtable->name = nla_strdup(nla[NFTA_FLOWTABLE_NAME], GFP_KERNEL_ACCOUNT); + if (!flowtable->name) { + err = -ENOMEM; + goto err1; + } + + type = nft_flowtable_type_get(net, family); + if (IS_ERR(type)) { + err = PTR_ERR(type); + goto err2; + } + + if (nla[NFTA_FLOWTABLE_FLAGS]) { + flowtable->data.flags = + ntohl(nla_get_be32(nla[NFTA_FLOWTABLE_FLAGS])); + if (flowtable->data.flags & ~NFT_FLOWTABLE_MASK) { + err = -EOPNOTSUPP; + goto err3; + } + } + + write_pnet(&flowtable->data.net, net); + flowtable->data.type = type; + err = type->init(&flowtable->data); + if (err < 0) + goto err3; + + err = nft_flowtable_parse_hook(&ctx, nla[NFTA_FLOWTABLE_HOOK], + &flowtable_hook, flowtable, true); + if (err < 0) + goto err4; + + list_splice(&flowtable_hook.list, &flowtable->hook_list); + flowtable->data.priority = flowtable_hook.priority; + flowtable->hooknum = flowtable_hook.num; + + err = nft_register_flowtable_net_hooks(ctx.net, table, + &flowtable->hook_list, + flowtable); + if (err < 0) { + nft_flowtable_hooks_destroy(&flowtable->hook_list); + goto err4; + } + + err = nft_trans_flowtable_add(&ctx, NFT_MSG_NEWFLOWTABLE, flowtable); + if (err < 0) + goto err5; + + list_add_tail_rcu(&flowtable->list, &table->flowtables); + + return 0; +err5: + 
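/* Unwind: unregister and release the hooks registered above. */ +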
list_for_each_entry_safe(hook, next, &flowtable->hook_list, list) { + nft_unregister_flowtable_hook(net, flowtable, hook); + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + } +err4: + flowtable->data.type->free(&flowtable->data); +err3: + module_put(type->owner); +err2: + kfree(flowtable->name); +err1: + kfree(flowtable); +flowtable_alloc: + nft_use_dec_restore(&table->use); + + return err; +} + +static void nft_flowtable_hook_release(struct nft_flowtable_hook *flowtable_hook) +{ + struct nft_hook *this, *next; + + list_for_each_entry_safe(this, next, &flowtable_hook->list, list) { + list_del(&this->list); + kfree(this); + } +} + +static int nft_delflowtable_hook(struct nft_ctx *ctx, + struct nft_flowtable *flowtable) +{ + const struct nlattr * const *nla = ctx->nla; + struct nft_flowtable_hook flowtable_hook; + LIST_HEAD(flowtable_del_list); + struct nft_hook *this, *hook; + struct nft_trans *trans; + int err; + + err = nft_flowtable_parse_hook(ctx, nla[NFTA_FLOWTABLE_HOOK], + &flowtable_hook, flowtable, false); + if (err < 0) + return err; + + list_for_each_entry(this, &flowtable_hook.list, list) { + hook = nft_hook_list_find(&flowtable->hook_list, this); + if (!hook) { + err = -ENOENT; + goto err_flowtable_del_hook; + } + list_move(&hook->list, &flowtable_del_list); + } + + trans = nft_trans_alloc(ctx, NFT_MSG_DELFLOWTABLE, + sizeof(struct nft_trans_flowtable)); + if (!trans) { + err = -ENOMEM; + goto err_flowtable_del_hook; + } + + nft_trans_flowtable(trans) = flowtable; + nft_trans_flowtable_update(trans) = true; + INIT_LIST_HEAD(&nft_trans_flowtable_hooks(trans)); + list_splice(&flowtable_del_list, &nft_trans_flowtable_hooks(trans)); + nft_flowtable_hook_release(&flowtable_hook); + + nft_trans_commit_list_add_tail(ctx->net, trans); + + return 0; + +err_flowtable_del_hook: + list_splice(&flowtable_del_list, &flowtable->hook_list); + nft_flowtable_hook_release(&flowtable_hook); + + return err; +} + +static int nf_tables_delflowtable(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct netlink_ext_ack *extack = info->extack; + u8 genmask = nft_genmask_next(info->net); + u8 family = info->nfmsg->nfgen_family; + struct nft_flowtable *flowtable; + struct net *net = info->net; + const struct nlattr *attr; + struct nft_table *table; + struct nft_ctx ctx; + + if (!nla[NFTA_FLOWTABLE_TABLE] || + (!nla[NFTA_FLOWTABLE_NAME] && + !nla[NFTA_FLOWTABLE_HANDLE])) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_FLOWTABLE_TABLE], family, + genmask, NETLINK_CB(skb).portid); + if (IS_ERR(table)) { + NL_SET_BAD_ATTR(extack, nla[NFTA_FLOWTABLE_TABLE]); + return PTR_ERR(table); + } + + if (nla[NFTA_FLOWTABLE_HANDLE]) { + attr = nla[NFTA_FLOWTABLE_HANDLE]; + flowtable = nft_flowtable_lookup_byhandle(table, attr, genmask); + } else { + attr = nla[NFTA_FLOWTABLE_NAME]; + flowtable = nft_flowtable_lookup(table, attr, genmask); + } + + if (IS_ERR(flowtable)) { + NL_SET_BAD_ATTR(extack, attr); + return PTR_ERR(flowtable); + } + + nft_ctx_init(&ctx, net, skb, info->nlh, family, table, NULL, nla); + + if (nla[NFTA_FLOWTABLE_HOOK]) + return nft_delflowtable_hook(&ctx, flowtable); + + if (flowtable->use > 0) { + NL_SET_BAD_ATTR(extack, attr); + return -EBUSY; + } + + return nft_delflowtable(&ctx, flowtable); +} + +static int nf_tables_fill_flowtable_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, int event, + u32 flags, int family, + struct nft_flowtable *flowtable, + struct list_head *hook_list) +{ + struct nlattr *nest, *nest_devs; 
+ struct nft_hook *hook; + struct nlmsghdr *nlh; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, family, + NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_FLOWTABLE_TABLE, flowtable->table->name) || + nla_put_string(skb, NFTA_FLOWTABLE_NAME, flowtable->name) || + nla_put_be32(skb, NFTA_FLOWTABLE_USE, htonl(flowtable->use)) || + nla_put_be64(skb, NFTA_FLOWTABLE_HANDLE, cpu_to_be64(flowtable->handle), + NFTA_FLOWTABLE_PAD) || + nla_put_be32(skb, NFTA_FLOWTABLE_FLAGS, htonl(flowtable->data.flags))) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, NFTA_FLOWTABLE_HOOK); + if (!nest) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_FLOWTABLE_HOOK_NUM, htonl(flowtable->hooknum)) || + nla_put_be32(skb, NFTA_FLOWTABLE_HOOK_PRIORITY, htonl(flowtable->data.priority))) + goto nla_put_failure; + + nest_devs = nla_nest_start_noflag(skb, NFTA_FLOWTABLE_HOOK_DEVS); + if (!nest_devs) + goto nla_put_failure; + + list_for_each_entry_rcu(hook, hook_list, list) { + if (nla_put_string(skb, NFTA_DEVICE_NAME, hook->ops.dev->name)) + goto nla_put_failure; + } + nla_nest_end(skb, nest_devs); + nla_nest_end(skb, nest); + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -1; +} + +struct nft_flowtable_filter { + char *table; +}; + +static int nf_tables_dump_flowtable(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + struct nft_flowtable_filter *filter = cb->data; + unsigned int idx = 0, s_idx = cb->args[0]; + struct net *net = sock_net(skb->sk); + int family = nfmsg->nfgen_family; + struct nft_flowtable *flowtable; + struct nftables_pernet *nft_net; + const struct nft_table *table; + + rcu_read_lock(); + nft_net = nft_pernet(net); + cb->seq = READ_ONCE(nft_net->base_seq); + + list_for_each_entry_rcu(table, &nft_net->tables, list) { + if (family != NFPROTO_UNSPEC && family != table->family) + continue; + + list_for_each_entry_rcu(flowtable, &table->flowtables, list) { + if (!nft_is_active(net, flowtable)) + goto cont; + if (idx < s_idx) + goto cont; + if (idx > s_idx) + memset(&cb->args[1], 0, + sizeof(cb->args) - sizeof(cb->args[0])); + if (filter && filter->table && + strcmp(filter->table, table->name)) + goto cont; + + if (nf_tables_fill_flowtable_info(skb, net, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFT_MSG_NEWFLOWTABLE, + NLM_F_MULTI | NLM_F_APPEND, + table->family, + flowtable, + &flowtable->hook_list) < 0) + goto done; + + nl_dump_check_consistent(cb, nlmsg_hdr(skb)); +cont: + idx++; + } + } +done: + rcu_read_unlock(); + + cb->args[0] = idx; + return skb->len; +} + +static int nf_tables_dump_flowtable_start(struct netlink_callback *cb) +{ + const struct nlattr * const *nla = cb->data; + struct nft_flowtable_filter *filter = NULL; + + if (nla[NFTA_FLOWTABLE_TABLE]) { + filter = kzalloc(sizeof(*filter), GFP_ATOMIC); + if (!filter) + return -ENOMEM; + + filter->table = nla_strdup(nla[NFTA_FLOWTABLE_TABLE], + GFP_ATOMIC); + if (!filter->table) { + kfree(filter); + return -ENOMEM; + } + } + + cb->data = filter; + return 0; +} + +static int nf_tables_dump_flowtable_done(struct netlink_callback *cb) +{ + struct nft_flowtable_filter *filter = cb->data; + + if (!filter) + return 0; + + kfree(filter->table); + kfree(filter); + + return 0; +} + +/* called with rcu_read_lock held */ +static int nf_tables_getflowtable(struct sk_buff *skb, + const struct nfnl_info *info, + const 
struct nlattr * const nla[]) +{ + u8 genmask = nft_genmask_cur(info->net); + u8 family = info->nfmsg->nfgen_family; + struct nft_flowtable *flowtable; + const struct nft_table *table; + struct net *net = info->net; + struct sk_buff *skb2; + int err; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = nf_tables_dump_flowtable_start, + .dump = nf_tables_dump_flowtable, + .done = nf_tables_dump_flowtable_done, + .module = THIS_MODULE, + .data = (void *)nla, + }; + + return nft_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + if (!nla[NFTA_FLOWTABLE_NAME]) + return -EINVAL; + + table = nft_table_lookup(net, nla[NFTA_FLOWTABLE_TABLE], family, + genmask, 0); + if (IS_ERR(table)) + return PTR_ERR(table); + + flowtable = nft_flowtable_lookup(table, nla[NFTA_FLOWTABLE_NAME], + genmask); + if (IS_ERR(flowtable)) + return PTR_ERR(flowtable); + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (!skb2) + return -ENOMEM; + + err = nf_tables_fill_flowtable_info(skb2, net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFT_MSG_NEWFLOWTABLE, 0, family, + flowtable, &flowtable->hook_list); + if (err < 0) + goto err_fill_flowtable_info; + + return nfnetlink_unicast(skb2, net, NETLINK_CB(skb).portid); + +err_fill_flowtable_info: + kfree_skb(skb2); + return err; +} + +static void nf_tables_flowtable_notify(struct nft_ctx *ctx, + struct nft_flowtable *flowtable, + struct list_head *hook_list, + int event) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + struct sk_buff *skb; + u16 flags = 0; + int err; + + if (!ctx->report && + !nfnetlink_has_listeners(ctx->net, NFNLGRP_NFTABLES)) + return; + + skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb == NULL) + goto err; + + if (ctx->flags & (NLM_F_CREATE | NLM_F_EXCL)) + flags |= ctx->flags & (NLM_F_CREATE | NLM_F_EXCL); + + err = nf_tables_fill_flowtable_info(skb, ctx->net, ctx->portid, + ctx->seq, event, flags, + ctx->family, flowtable, hook_list); + if (err < 0) { + kfree_skb(skb); + goto err; + } + + nft_notify_enqueue(skb, ctx->report, &nft_net->notify_list); + return; +err: + nfnetlink_set_err(ctx->net, ctx->portid, NFNLGRP_NFTABLES, -ENOBUFS); +} + +static void nf_tables_flowtable_destroy(struct nft_flowtable *flowtable) +{ + struct nft_hook *hook, *next; + + flowtable->data.type->free(&flowtable->data); + list_for_each_entry_safe(hook, next, &flowtable->hook_list, list) { + flowtable->data.type->setup(&flowtable->data, hook->ops.dev, + FLOW_BLOCK_UNBIND); + list_del_rcu(&hook->list); + kfree(hook); + } + kfree(flowtable->name); + module_put(flowtable->data.type->owner); + kfree(flowtable); +} + +static int nf_tables_fill_gen_info(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nlmsghdr *nlh; + char buf[TASK_COMM_LEN]; + int event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWGEN); + + nlh = nfnl_msg_put(skb, portid, seq, event, 0, AF_UNSPEC, + NFNETLINK_V0, nft_base_seq(net)); + if (!nlh) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_GEN_ID, htonl(nft_net->base_seq)) || + nla_put_be32(skb, NFTA_GEN_PROC_PID, htonl(task_pid_nr(current))) || + nla_put_string(skb, NFTA_GEN_PROC_NAME, get_task_comm(buf, current))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_trim(skb, nlh); + return -EMSGSIZE; +} + +static void nft_flowtable_event(unsigned long event, struct net_device *dev, + struct nft_flowtable *flowtable) +{ + struct nft_hook *hook; + + 
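/* The device is going away: find and remove the hook bound to it. */ +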
list_for_each_entry(hook, &flowtable->hook_list, list) { + if (hook->ops.dev != dev) + continue; + + /* flow_offload_netdev_event() cleans up entries for us. */ + nft_unregister_flowtable_hook(dev_net(dev), flowtable, hook); + list_del_rcu(&hook->list); + kfree_rcu(hook, rcu); + break; + } +} + +static int nf_tables_flowtable_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct nft_flowtable *flowtable; + struct nftables_pernet *nft_net; + struct nft_table *table; + struct net *net; + + if (event != NETDEV_UNREGISTER) + return 0; + + net = dev_net(dev); + nft_net = nft_pernet(net); + mutex_lock(&nft_net->commit_mutex); + list_for_each_entry(table, &nft_net->tables, list) { + list_for_each_entry(flowtable, &table->flowtables, list) { + nft_flowtable_event(event, dev, flowtable); + } + } + mutex_unlock(&nft_net->commit_mutex); + + return NOTIFY_DONE; +} + +static struct notifier_block nf_tables_flowtable_notifier = { + .notifier_call = nf_tables_flowtable_event, +}; + +static void nf_tables_gen_notify(struct net *net, struct sk_buff *skb, + int event) +{ + struct nlmsghdr *nlh = nlmsg_hdr(skb); + struct sk_buff *skb2; + int err; + + if (!nlmsg_report(nlh) && + !nfnetlink_has_listeners(net, NFNLGRP_NFTABLES)) + return; + + skb2 = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (skb2 == NULL) + goto err; + + err = nf_tables_fill_gen_info(skb2, net, NETLINK_CB(skb).portid, + nlh->nlmsg_seq); + if (err < 0) { + kfree_skb(skb2); + goto err; + } + + nfnetlink_send(skb2, net, NETLINK_CB(skb).portid, NFNLGRP_NFTABLES, + nlmsg_report(nlh), GFP_KERNEL); + return; +err: + nfnetlink_set_err(net, NETLINK_CB(skb).portid, NFNLGRP_NFTABLES, + -ENOBUFS); +} + +static int nf_tables_getgen(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + struct sk_buff *skb2; + int err; + + skb2 = alloc_skb(NLMSG_GOODSIZE, GFP_ATOMIC); + if (skb2 == NULL) + return -ENOMEM; + + err = nf_tables_fill_gen_info(skb2, info->net, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq); + if (err < 0) + goto err_fill_gen_info; + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + +err_fill_gen_info: + kfree_skb(skb2); + return err; +} + +static const struct nfnl_callback nf_tables_cb[NFT_MSG_MAX] = { + [NFT_MSG_NEWTABLE] = { + .call = nf_tables_newtable, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_TABLE_MAX, + .policy = nft_table_policy, + }, + [NFT_MSG_GETTABLE] = { + .call = nf_tables_gettable, + .type = NFNL_CB_RCU, + .attr_count = NFTA_TABLE_MAX, + .policy = nft_table_policy, + }, + [NFT_MSG_DELTABLE] = { + .call = nf_tables_deltable, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_TABLE_MAX, + .policy = nft_table_policy, + }, + [NFT_MSG_NEWCHAIN] = { + .call = nf_tables_newchain, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_CHAIN_MAX, + .policy = nft_chain_policy, + }, + [NFT_MSG_GETCHAIN] = { + .call = nf_tables_getchain, + .type = NFNL_CB_RCU, + .attr_count = NFTA_CHAIN_MAX, + .policy = nft_chain_policy, + }, + [NFT_MSG_DELCHAIN] = { + .call = nf_tables_delchain, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_CHAIN_MAX, + .policy = nft_chain_policy, + }, + [NFT_MSG_NEWRULE] = { + .call = nf_tables_newrule, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_RULE_MAX, + .policy = nft_rule_policy, + }, + [NFT_MSG_GETRULE] = { + .call = nf_tables_getrule, + .type = NFNL_CB_RCU, + .attr_count = NFTA_RULE_MAX, + .policy = nft_rule_policy, + }, + [NFT_MSG_DELRULE] = { + .call = nf_tables_delrule, + 
.type = NFNL_CB_BATCH, + .attr_count = NFTA_RULE_MAX, + .policy = nft_rule_policy, + }, + [NFT_MSG_NEWSET] = { + .call = nf_tables_newset, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_SET_MAX, + .policy = nft_set_policy, + }, + [NFT_MSG_GETSET] = { + .call = nf_tables_getset, + .type = NFNL_CB_RCU, + .attr_count = NFTA_SET_MAX, + .policy = nft_set_policy, + }, + [NFT_MSG_DELSET] = { + .call = nf_tables_delset, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_SET_MAX, + .policy = nft_set_policy, + }, + [NFT_MSG_NEWSETELEM] = { + .call = nf_tables_newsetelem, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_SET_ELEM_LIST_MAX, + .policy = nft_set_elem_list_policy, + }, + [NFT_MSG_GETSETELEM] = { + .call = nf_tables_getsetelem, + .type = NFNL_CB_RCU, + .attr_count = NFTA_SET_ELEM_LIST_MAX, + .policy = nft_set_elem_list_policy, + }, + [NFT_MSG_DELSETELEM] = { + .call = nf_tables_delsetelem, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_SET_ELEM_LIST_MAX, + .policy = nft_set_elem_list_policy, + }, + [NFT_MSG_GETGEN] = { + .call = nf_tables_getgen, + .type = NFNL_CB_RCU, + }, + [NFT_MSG_NEWOBJ] = { + .call = nf_tables_newobj, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_OBJ_MAX, + .policy = nft_obj_policy, + }, + [NFT_MSG_GETOBJ] = { + .call = nf_tables_getobj, + .type = NFNL_CB_RCU, + .attr_count = NFTA_OBJ_MAX, + .policy = nft_obj_policy, + }, + [NFT_MSG_DELOBJ] = { + .call = nf_tables_delobj, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_OBJ_MAX, + .policy = nft_obj_policy, + }, + [NFT_MSG_GETOBJ_RESET] = { + .call = nf_tables_getobj, + .type = NFNL_CB_RCU, + .attr_count = NFTA_OBJ_MAX, + .policy = nft_obj_policy, + }, + [NFT_MSG_NEWFLOWTABLE] = { + .call = nf_tables_newflowtable, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_FLOWTABLE_MAX, + .policy = nft_flowtable_policy, + }, + [NFT_MSG_GETFLOWTABLE] = { + .call = nf_tables_getflowtable, + .type = NFNL_CB_RCU, + .attr_count = NFTA_FLOWTABLE_MAX, + .policy = nft_flowtable_policy, + }, + [NFT_MSG_DELFLOWTABLE] = { + .call = nf_tables_delflowtable, + .type = NFNL_CB_BATCH, + .attr_count = NFTA_FLOWTABLE_MAX, + .policy = nft_flowtable_policy, + }, +}; + +static int nf_tables_validate(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_table *table; + + switch (nft_net->validate_state) { + case NFT_VALIDATE_SKIP: + break; + case NFT_VALIDATE_NEED: + nft_validate_state_update(net, NFT_VALIDATE_DO); + fallthrough; + case NFT_VALIDATE_DO: + list_for_each_entry(table, &nft_net->tables, list) { + if (nft_table_validate(net, table) < 0) + return -EAGAIN; + } + + nft_validate_state_update(net, NFT_VALIDATE_SKIP); + break; + } + + return 0; +} + +/* a drop policy has to be deferred until all rules have been activated, + * otherwise a large ruleset that contains a drop-policy base chain will + * cause all packets to get dropped until the full transaction has been + * processed. + * + * We defer the drop policy until the transaction has been finalized. 
+ */ +static void nft_chain_commit_drop_policy(struct nft_trans *trans) +{ + struct nft_base_chain *basechain; + + if (nft_trans_chain_policy(trans) != NF_DROP) + return; + + if (!nft_is_base_chain(trans->ctx.chain)) + return; + + basechain = nft_base_chain(trans->ctx.chain); + basechain->policy = NF_DROP; +} + +static void nft_chain_commit_update(struct nft_trans *trans) +{ + struct nft_base_chain *basechain; + + if (nft_trans_chain_name(trans)) { + rhltable_remove(&trans->ctx.table->chains_ht, + &trans->ctx.chain->rhlhead, + nft_chain_ht_params); + swap(trans->ctx.chain->name, nft_trans_chain_name(trans)); + rhltable_insert_key(&trans->ctx.table->chains_ht, + trans->ctx.chain->name, + &trans->ctx.chain->rhlhead, + nft_chain_ht_params); + } + + if (!nft_is_base_chain(trans->ctx.chain)) + return; + + nft_chain_stats_replace(trans); + + basechain = nft_base_chain(trans->ctx.chain); + + switch (nft_trans_chain_policy(trans)) { + case NF_DROP: + case NF_ACCEPT: + basechain->policy = nft_trans_chain_policy(trans); + break; + } +} + +static void nft_obj_commit_update(struct nft_trans *trans) +{ + struct nft_object *newobj; + struct nft_object *obj; + + obj = nft_trans_obj(trans); + newobj = nft_trans_obj_newobj(trans); + + if (obj->ops->update) + obj->ops->update(obj, newobj); + + nft_obj_destroy(&trans->ctx, newobj); +} + +static void nft_commit_release(struct nft_trans *trans) +{ + switch (trans->msg_type) { + case NFT_MSG_DELTABLE: + nf_tables_table_destroy(&trans->ctx); + break; + case NFT_MSG_NEWCHAIN: + free_percpu(nft_trans_chain_stats(trans)); + kfree(nft_trans_chain_name(trans)); + break; + case NFT_MSG_DELCHAIN: + nf_tables_chain_destroy(&trans->ctx); + break; + case NFT_MSG_DELRULE: + nf_tables_rule_destroy(&trans->ctx, nft_trans_rule(trans)); + break; + case NFT_MSG_DELSET: + nft_set_destroy(&trans->ctx, nft_trans_set(trans)); + break; + case NFT_MSG_DELSETELEM: + nf_tables_set_elem_destroy(&trans->ctx, + nft_trans_elem_set(trans), + nft_trans_elem(trans).priv); + break; + case NFT_MSG_DELOBJ: + nft_obj_destroy(&trans->ctx, nft_trans_obj(trans)); + break; + case NFT_MSG_DELFLOWTABLE: + if (nft_trans_flowtable_update(trans)) + nft_flowtable_hooks_destroy(&nft_trans_flowtable_hooks(trans)); + else + nf_tables_flowtable_destroy(nft_trans_flowtable(trans)); + break; + } + + if (trans->put_net) + put_net(trans->ctx.net); + + kfree(trans); +} + +static void nf_tables_trans_destroy_work(struct work_struct *w) +{ + struct nft_trans *trans, *next; + LIST_HEAD(head); + + spin_lock(&nf_tables_destroy_list_lock); + list_splice_init(&nf_tables_destroy_list, &head); + spin_unlock(&nf_tables_destroy_list_lock); + + if (list_empty(&head)) + return; + + synchronize_rcu(); + + list_for_each_entry_safe(trans, next, &head, list) { + nft_trans_list_del(trans); + nft_commit_release(trans); + } +} + +void nf_tables_trans_destroy_flush_work(void) +{ + flush_work(&trans_destroy_work); +} +EXPORT_SYMBOL_GPL(nf_tables_trans_destroy_flush_work); + +static bool nft_expr_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + return false; +} + +static int nf_tables_commit_chain_prepare(struct net *net, struct nft_chain *chain) +{ + const struct nft_expr *expr, *last; + struct nft_regs_track track = {}; + unsigned int size, data_size; + void *data, *data_boundary; + struct nft_rule_dp *prule; + struct nft_rule *rule; + + /* already handled or inactive chain? 
*/ + if (chain->blob_next || !nft_is_active_next(net, chain)) + return 0; + + data_size = 0; + list_for_each_entry(rule, &chain->rules, list) { + if (nft_is_active_next(net, rule)) { + data_size += sizeof(*prule) + rule->dlen; + if (data_size > INT_MAX) + return -ENOMEM; + } + } + data_size += offsetof(struct nft_rule_dp, data); /* last rule */ + + chain->blob_next = nf_tables_chain_alloc_rules(data_size); + if (!chain->blob_next) + return -ENOMEM; + + data = (void *)chain->blob_next->data; + data_boundary = data + data_size; + size = 0; + + list_for_each_entry(rule, &chain->rules, list) { + if (!nft_is_active_next(net, rule)) + continue; + + prule = (struct nft_rule_dp *)data; + data += offsetof(struct nft_rule_dp, data); + if (WARN_ON_ONCE(data > data_boundary)) + return -ENOMEM; + + size = 0; + track.last = nft_expr_last(rule); + nft_rule_for_each_expr(expr, last, rule) { + track.cur = expr; + + if (nft_expr_reduce(&track, expr)) { + expr = track.cur; + continue; + } + + if (WARN_ON_ONCE(data + size + expr->ops->size > data_boundary)) + return -ENOMEM; + + memcpy(data + size, expr, expr->ops->size); + size += expr->ops->size; + } + if (WARN_ON_ONCE(size >= 1 << 12)) + return -ENOMEM; + + prule->handle = rule->handle; + prule->dlen = size; + prule->is_last = 0; + + data += size; + size = 0; + chain->blob_next->size += (unsigned long)(data - (void *)prule); + } + + prule = (struct nft_rule_dp *)data; + data += offsetof(struct nft_rule_dp, data); + if (WARN_ON_ONCE(data > data_boundary)) + return -ENOMEM; + + nft_last_rule(chain->blob_next, prule); + + return 0; +} + +static void nf_tables_commit_chain_prepare_cancel(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_trans *trans, *next; + + list_for_each_entry_safe(trans, next, &nft_net->commit_list, list) { + struct nft_chain *chain = trans->ctx.chain; + + if (trans->msg_type == NFT_MSG_NEWRULE || + trans->msg_type == NFT_MSG_DELRULE) { + kvfree(chain->blob_next); + chain->blob_next = NULL; + } + } +} + +static void __nf_tables_commit_chain_free_rules_old(struct rcu_head *h) +{ + struct nft_rules_old *o = container_of(h, struct nft_rules_old, h); + + kvfree(o->blob); +} + +static void nf_tables_commit_chain_free_rules_old(struct nft_rule_blob *blob) +{ + struct nft_rules_old *old; + + /* rcu_head is after end marker */ + old = (void *)blob + sizeof(*blob) + blob->size; + old->blob = blob; + + call_rcu(&old->h, __nf_tables_commit_chain_free_rules_old); +} + +static void nf_tables_commit_chain(struct net *net, struct nft_chain *chain) +{ + struct nft_rule_blob *g0, *g1; + bool next_genbit; + + next_genbit = nft_gencursor_next(net); + + g0 = rcu_dereference_protected(chain->blob_gen_0, + lockdep_commit_lock_is_held(net)); + g1 = rcu_dereference_protected(chain->blob_gen_1, + lockdep_commit_lock_is_held(net)); + + /* No changes to this chain? */ + if (chain->blob_next == NULL) { + /* chain had no change in last or next generation */ + if (g0 == g1) + return; + /* + * chain had no change in this generation; make sure next + * one uses same rules as current generation. 
+ */ + if (next_genbit) { + rcu_assign_pointer(chain->blob_gen_1, g0); + nf_tables_commit_chain_free_rules_old(g1); + } else { + rcu_assign_pointer(chain->blob_gen_0, g1); + nf_tables_commit_chain_free_rules_old(g0); + } + + return; + } + + if (next_genbit) + rcu_assign_pointer(chain->blob_gen_1, chain->blob_next); + else + rcu_assign_pointer(chain->blob_gen_0, chain->blob_next); + + chain->blob_next = NULL; + + if (g0 == g1) + return; + + if (next_genbit) + nf_tables_commit_chain_free_rules_old(g1); + else + nf_tables_commit_chain_free_rules_old(g0); +} + +static void nft_obj_del(struct nft_object *obj) +{ + rhltable_remove(&nft_objname_ht, &obj->rhlhead, nft_objname_ht_params); + list_del_rcu(&obj->list); +} + +void nft_chain_del(struct nft_chain *chain) +{ + struct nft_table *table = chain->table; + + WARN_ON_ONCE(rhltable_remove(&table->chains_ht, &chain->rhlhead, + nft_chain_ht_params)); + list_del_rcu(&chain->list); +} + +static void nft_trans_gc_setelem_remove(struct nft_ctx *ctx, + struct nft_trans_gc *trans) +{ + void **priv = trans->priv; + unsigned int i; + + for (i = 0; i < trans->count; i++) { + struct nft_set_elem elem = { + .priv = priv[i], + }; + + nft_setelem_data_deactivate(ctx->net, trans->set, &elem); + nft_setelem_remove(ctx->net, trans->set, &elem); + } +} + +void nft_trans_gc_destroy(struct nft_trans_gc *trans) +{ + nft_set_put(trans->set); + put_net(trans->net); + kfree(trans); +} + +static void nft_trans_gc_trans_free(struct rcu_head *rcu) +{ + struct nft_set_elem elem = {}; + struct nft_trans_gc *trans; + struct nft_ctx ctx = {}; + unsigned int i; + + trans = container_of(rcu, struct nft_trans_gc, rcu); + ctx.net = read_pnet(&trans->set->net); + + for (i = 0; i < trans->count; i++) { + elem.priv = trans->priv[i]; + if (!nft_setelem_is_catchall(trans->set, &elem)) + atomic_dec(&trans->set->nelems); + + nf_tables_set_elem_destroy(&ctx, trans->set, elem.priv); + } + + nft_trans_gc_destroy(trans); +} + +static bool nft_trans_gc_work_done(struct nft_trans_gc *trans) +{ + struct nftables_pernet *nft_net; + struct nft_ctx ctx = {}; + + nft_net = nft_pernet(trans->net); + + mutex_lock(&nft_net->commit_mutex); + + /* Check for race with transaction, otherwise this batch refers to + * stale objects that might not be there anymore. Skip transaction if + * set has been destroyed from control plane transaction in case gc + * worker loses race. 
+ */ + if (READ_ONCE(nft_net->gc_seq) != trans->seq || trans->set->dead) { + mutex_unlock(&nft_net->commit_mutex); + return false; + } + + ctx.net = trans->net; + ctx.table = trans->set->table; + + nft_trans_gc_setelem_remove(&ctx, trans); + mutex_unlock(&nft_net->commit_mutex); + + return true; +} + +static void nft_trans_gc_work(struct work_struct *work) +{ + struct nft_trans_gc *trans, *next; + LIST_HEAD(trans_gc_list); + + spin_lock(&nf_tables_gc_list_lock); + list_splice_init(&nf_tables_gc_list, &trans_gc_list); + spin_unlock(&nf_tables_gc_list_lock); + + list_for_each_entry_safe(trans, next, &trans_gc_list, list) { + list_del(&trans->list); + if (!nft_trans_gc_work_done(trans)) { + nft_trans_gc_destroy(trans); + continue; + } + call_rcu(&trans->rcu, nft_trans_gc_trans_free); + } +} + +struct nft_trans_gc *nft_trans_gc_alloc(struct nft_set *set, + unsigned int gc_seq, gfp_t gfp) +{ + struct net *net = read_pnet(&set->net); + struct nft_trans_gc *trans; + + trans = kzalloc(sizeof(*trans), gfp); + if (!trans) + return NULL; + + trans->net = maybe_get_net(net); + if (!trans->net) { + kfree(trans); + return NULL; + } + + refcount_inc(&set->refs); + trans->set = set; + trans->seq = gc_seq; + + return trans; +} + +void nft_trans_gc_elem_add(struct nft_trans_gc *trans, void *priv) +{ + trans->priv[trans->count++] = priv; +} + +static void nft_trans_gc_queue_work(struct nft_trans_gc *trans) +{ + spin_lock(&nf_tables_gc_list_lock); + list_add_tail(&trans->list, &nf_tables_gc_list); + spin_unlock(&nf_tables_gc_list_lock); + + schedule_work(&trans_gc_work); +} + +static int nft_trans_gc_space(struct nft_trans_gc *trans) +{ + return NFT_TRANS_GC_BATCHCOUNT - trans->count; +} + +struct nft_trans_gc *nft_trans_gc_queue_async(struct nft_trans_gc *gc, + unsigned int gc_seq, gfp_t gfp) +{ + struct nft_set *set; + + if (nft_trans_gc_space(gc)) + return gc; + + set = gc->set; + nft_trans_gc_queue_work(gc); + + return nft_trans_gc_alloc(set, gc_seq, gfp); +} + +void nft_trans_gc_queue_async_done(struct nft_trans_gc *trans) +{ + if (trans->count == 0) { + nft_trans_gc_destroy(trans); + return; + } + + nft_trans_gc_queue_work(trans); +} + +struct nft_trans_gc *nft_trans_gc_queue_sync(struct nft_trans_gc *gc, gfp_t gfp) +{ + struct nft_set *set; + + if (WARN_ON_ONCE(!lockdep_commit_lock_is_held(gc->net))) + return NULL; + + if (nft_trans_gc_space(gc)) + return gc; + + set = gc->set; + call_rcu(&gc->rcu, nft_trans_gc_trans_free); + + return nft_trans_gc_alloc(set, 0, gfp); +} + +void nft_trans_gc_queue_sync_done(struct nft_trans_gc *trans) +{ + WARN_ON_ONCE(!lockdep_commit_lock_is_held(trans->net)); + + if (trans->count == 0) { + nft_trans_gc_destroy(trans); + return; + } + + call_rcu(&trans->rcu, nft_trans_gc_trans_free); +} + +struct nft_trans_gc *nft_trans_gc_catchall_async(struct nft_trans_gc *gc, + unsigned int gc_seq) +{ + struct nft_set_elem_catchall *catchall; + const struct nft_set *set = gc->set; + struct nft_set_ext *ext; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + + if (!nft_set_elem_expired(ext)) + continue; + if (nft_set_elem_is_dead(ext)) + goto dead_elem; + + nft_set_elem_dead(ext); +dead_elem: + gc = nft_trans_gc_queue_async(gc, gc_seq, GFP_ATOMIC); + if (!gc) + return NULL; + + nft_trans_gc_elem_add(gc, catchall->elem); + } + + return gc; +} + +struct nft_trans_gc *nft_trans_gc_catchall_sync(struct nft_trans_gc *gc) +{ + struct nft_set_elem_catchall *catchall, *next; + const struct nft_set *set = gc->set; + struct 
nft_set_elem elem; + struct nft_set_ext *ext; + + WARN_ON_ONCE(!lockdep_commit_lock_is_held(gc->net)); + + list_for_each_entry_safe(catchall, next, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + + if (!nft_set_elem_expired(ext)) + continue; + + gc = nft_trans_gc_queue_sync(gc, GFP_KERNEL); + if (!gc) + return NULL; + + memset(&elem, 0, sizeof(elem)); + elem.priv = catchall->elem; + + nft_setelem_data_deactivate(gc->net, gc->set, &elem); + nft_setelem_catchall_destroy(catchall); + nft_trans_gc_elem_add(gc, elem.priv); + } + + return gc; +} + +static void nf_tables_module_autoload_cleanup(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_module_request *req, *next; + + WARN_ON_ONCE(!list_empty(&nft_net->commit_list)); + list_for_each_entry_safe(req, next, &nft_net->module_list, list) { + WARN_ON_ONCE(!req->done); + list_del(&req->list); + kfree(req); + } +} + +static void nf_tables_commit_release(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_trans *trans; + + /* all side effects have to be made visible. + * For example, if a chain named 'foo' has been deleted, a + * new transaction must not find it anymore. + * + * Memory reclaim happens asynchronously from work queue + * to prevent expensive synchronize_rcu() in commit phase. + */ + if (list_empty(&nft_net->commit_list)) { + nf_tables_module_autoload_cleanup(net); + mutex_unlock(&nft_net->commit_mutex); + return; + } + + trans = list_last_entry(&nft_net->commit_list, + struct nft_trans, list); + get_net(trans->ctx.net); + WARN_ON_ONCE(trans->put_net); + + trans->put_net = true; + spin_lock(&nf_tables_destroy_list_lock); + list_splice_tail_init(&nft_net->commit_list, &nf_tables_destroy_list); + spin_unlock(&nf_tables_destroy_list_lock); + + nf_tables_module_autoload_cleanup(net); + schedule_work(&trans_destroy_work); + + mutex_unlock(&nft_net->commit_mutex); +} + +static void nft_commit_notify(struct net *net, u32 portid) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct sk_buff *batch_skb = NULL, *nskb, *skb; + unsigned char *data; + int len; + + list_for_each_entry_safe(skb, nskb, &nft_net->notify_list, list) { + if (!batch_skb) { +new_batch: + batch_skb = skb; + len = NLMSG_GOODSIZE - skb->len; + list_del(&skb->list); + continue; + } + len -= skb->len; + if (len > 0 && NFT_CB(skb).report == NFT_CB(batch_skb).report) { + data = skb_put(batch_skb, skb->len); + memcpy(data, skb->data, skb->len); + list_del(&skb->list); + kfree_skb(skb); + continue; + } + nfnetlink_send(batch_skb, net, portid, NFNLGRP_NFTABLES, + NFT_CB(batch_skb).report, GFP_KERNEL); + goto new_batch; + } + + if (batch_skb) { + nfnetlink_send(batch_skb, net, portid, NFNLGRP_NFTABLES, + NFT_CB(batch_skb).report, GFP_KERNEL); + } + + WARN_ON_ONCE(!list_empty(&nft_net->notify_list)); +} + +static int nf_tables_commit_audit_alloc(struct list_head *adl, + struct nft_table *table) +{ + struct nft_audit_data *adp; + + list_for_each_entry(adp, adl, list) { + if (adp->table == table) + return 0; + } + adp = kzalloc(sizeof(*adp), GFP_KERNEL); + if (!adp) + return -ENOMEM; + adp->table = table; + list_add(&adp->list, adl); + return 0; +} + +static void nf_tables_commit_audit_free(struct list_head *adl) +{ + struct nft_audit_data *adp, *adn; + + list_for_each_entry_safe(adp, adn, adl, list) { + list_del(&adp->list); + kfree(adp); + } +} + +static void nf_tables_commit_audit_collect(struct list_head *adl, + struct nft_table *table, u32 op) +{ + struct nft_audit_data 
*adp; + + list_for_each_entry(adp, adl, list) { + if (adp->table == table) + goto found; + } + WARN_ONCE(1, "table=%s not expected in commit list", table->name); + return; +found: + adp->entries++; + if (!adp->op || adp->op > op) + adp->op = op; +} + +#define AUNFTABLENAMELEN (NFT_TABLE_MAXNAMELEN + 22) + +static void nf_tables_commit_audit_log(struct list_head *adl, u32 generation) +{ + struct nft_audit_data *adp, *adn; + char aubuf[AUNFTABLENAMELEN]; + + list_for_each_entry_safe(adp, adn, adl, list) { + snprintf(aubuf, AUNFTABLENAMELEN, "%s:%u", adp->table->name, + generation); + audit_log_nfcfg(aubuf, adp->table->family, adp->entries, + nft2audit_op[adp->op], GFP_KERNEL); + list_del(&adp->list); + kfree(adp); + } +} + +static void nft_set_commit_update(struct list_head *set_update_list) +{ + struct nft_set *set, *next; + + list_for_each_entry_safe(set, next, set_update_list, pending_update) { + list_del_init(&set->pending_update); + + if (!set->ops->commit || set->dead) + continue; + + set->ops->commit(set); + } +} + +static unsigned int nft_gc_seq_begin(struct nftables_pernet *nft_net) +{ + unsigned int gc_seq; + + /* Bump gc counter, it becomes odd, this is the busy mark. */ + gc_seq = READ_ONCE(nft_net->gc_seq); + WRITE_ONCE(nft_net->gc_seq, ++gc_seq); + + return gc_seq; +} + +static void nft_gc_seq_end(struct nftables_pernet *nft_net, unsigned int gc_seq) +{ + WRITE_ONCE(nft_net->gc_seq, ++gc_seq); +} + +static int nf_tables_commit(struct net *net, struct sk_buff *skb) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_trans *trans, *next; + unsigned int base_seq, gc_seq; + LIST_HEAD(set_update_list); + struct nft_trans_elem *te; + struct nft_chain *chain; + struct nft_table *table; + LIST_HEAD(adl); + int err; + + if (list_empty(&nft_net->commit_list)) { + mutex_unlock(&nft_net->commit_mutex); + return 0; + } + + list_for_each_entry(trans, &nft_net->binding_list, binding_list) { + switch (trans->msg_type) { + case NFT_MSG_NEWSET: + if (!nft_trans_set_update(trans) && + nft_set_is_anonymous(nft_trans_set(trans)) && + !nft_trans_set_bound(trans)) { + pr_warn_once("nftables ruleset with unbound set\n"); + return -EINVAL; + } + break; + case NFT_MSG_NEWCHAIN: + if (!nft_trans_chain_update(trans) && + nft_chain_binding(nft_trans_chain(trans)) && + !nft_trans_chain_bound(trans)) { + pr_warn_once("nftables ruleset with unbound chain\n"); + return -EINVAL; + } + break; + } + } + + /* 0. Validate ruleset, otherwise roll back for error reporting. */ + if (nf_tables_validate(net) < 0) + return -EAGAIN; + + err = nft_flow_rule_offload_commit(net); + if (err < 0) + return err; + + /* 1. Allocate space for next generation rules_gen_X[] */ + list_for_each_entry_safe(trans, next, &nft_net->commit_list, list) { + int ret; + + ret = nf_tables_commit_audit_alloc(&adl, trans->ctx.table); + if (ret) { + nf_tables_commit_chain_prepare_cancel(net); + nf_tables_commit_audit_free(&adl); + return ret; + } + if (trans->msg_type == NFT_MSG_NEWRULE || + trans->msg_type == NFT_MSG_DELRULE) { + chain = trans->ctx.chain; + + ret = nf_tables_commit_chain_prepare(net, chain); + if (ret < 0) { + nf_tables_commit_chain_prepare_cancel(net); + nf_tables_commit_audit_free(&adl); + return ret; + } + } + } + + /* step 2. Make rules_gen_X visible to packet path */ + list_for_each_entry(table, &nft_net->tables, list) { + list_for_each_entry(chain, &table->chains, list) + nf_tables_commit_chain(net, chain); + } + + /* + * Bump generation counter, invalidate any dump in progress. 
+ * Cannot fail after this point. + */ + base_seq = READ_ONCE(nft_net->base_seq); + while (++base_seq == 0) + ; + + WRITE_ONCE(nft_net->base_seq, base_seq); + + gc_seq = nft_gc_seq_begin(nft_net); + + /* step 3. Start new generation, rules_gen_X now in use. */ + net->nft.gencursor = nft_gencursor_next(net); + + list_for_each_entry_safe(trans, next, &nft_net->commit_list, list) { + nf_tables_commit_audit_collect(&adl, trans->ctx.table, + trans->msg_type); + switch (trans->msg_type) { + case NFT_MSG_NEWTABLE: + if (nft_trans_table_update(trans)) { + if (!(trans->ctx.table->flags & __NFT_TABLE_F_UPDATE)) { + nft_trans_destroy(trans); + break; + } + if (trans->ctx.table->flags & NFT_TABLE_F_DORMANT) + nf_tables_table_disable(net, trans->ctx.table); + + trans->ctx.table->flags &= ~__NFT_TABLE_F_UPDATE; + } else { + nft_clear(net, trans->ctx.table); + } + nf_tables_table_notify(&trans->ctx, NFT_MSG_NEWTABLE); + nft_trans_destroy(trans); + break; + case NFT_MSG_DELTABLE: + list_del_rcu(&trans->ctx.table->list); + nf_tables_table_notify(&trans->ctx, NFT_MSG_DELTABLE); + break; + case NFT_MSG_NEWCHAIN: + if (nft_trans_chain_update(trans)) { + nft_chain_commit_update(trans); + nf_tables_chain_notify(&trans->ctx, NFT_MSG_NEWCHAIN); + /* trans destroyed after rcu grace period */ + } else { + nft_chain_commit_drop_policy(trans); + nft_clear(net, trans->ctx.chain); + nf_tables_chain_notify(&trans->ctx, NFT_MSG_NEWCHAIN); + nft_trans_destroy(trans); + } + break; + case NFT_MSG_DELCHAIN: + nft_chain_del(trans->ctx.chain); + nf_tables_chain_notify(&trans->ctx, NFT_MSG_DELCHAIN); + nf_tables_unregister_hook(trans->ctx.net, + trans->ctx.table, + trans->ctx.chain); + break; + case NFT_MSG_NEWRULE: + nft_clear(trans->ctx.net, nft_trans_rule(trans)); + nf_tables_rule_notify(&trans->ctx, + nft_trans_rule(trans), + NFT_MSG_NEWRULE); + if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) + nft_flow_rule_destroy(nft_trans_flow_rule(trans)); + + nft_trans_destroy(trans); + break; + case NFT_MSG_DELRULE: + list_del_rcu(&nft_trans_rule(trans)->list); + nf_tables_rule_notify(&trans->ctx, + nft_trans_rule(trans), + NFT_MSG_DELRULE); + nft_rule_expr_deactivate(&trans->ctx, + nft_trans_rule(trans), + NFT_TRANS_COMMIT); + + if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) + nft_flow_rule_destroy(nft_trans_flow_rule(trans)); + break; + case NFT_MSG_NEWSET: + if (nft_trans_set_update(trans)) { + struct nft_set *set = nft_trans_set(trans); + + WRITE_ONCE(set->timeout, nft_trans_set_timeout(trans)); + WRITE_ONCE(set->gc_int, nft_trans_set_gc_int(trans)); + } else { + nft_clear(net, nft_trans_set(trans)); + /* This avoids hitting -EBUSY when deleting the table + * from the transaction. 
+ */ + if (nft_set_is_anonymous(nft_trans_set(trans)) && + !list_empty(&nft_trans_set(trans)->bindings)) + nft_use_dec(&trans->ctx.table->use); + } + nf_tables_set_notify(&trans->ctx, nft_trans_set(trans), + NFT_MSG_NEWSET, GFP_KERNEL); + nft_trans_destroy(trans); + break; + case NFT_MSG_DELSET: + nft_trans_set(trans)->dead = 1; + list_del_rcu(&nft_trans_set(trans)->list); + nf_tables_set_notify(&trans->ctx, nft_trans_set(trans), + NFT_MSG_DELSET, GFP_KERNEL); + break; + case NFT_MSG_NEWSETELEM: + te = (struct nft_trans_elem *)trans->data; + + nft_setelem_activate(net, te->set, &te->elem); + nf_tables_setelem_notify(&trans->ctx, te->set, + &te->elem, + NFT_MSG_NEWSETELEM); + if (te->set->ops->commit && + list_empty(&te->set->pending_update)) { + list_add_tail(&te->set->pending_update, + &set_update_list); + } + nft_trans_destroy(trans); + break; + case NFT_MSG_DELSETELEM: + te = (struct nft_trans_elem *)trans->data; + + nf_tables_setelem_notify(&trans->ctx, te->set, + &te->elem, + NFT_MSG_DELSETELEM); + nft_setelem_remove(net, te->set, &te->elem); + if (!nft_setelem_is_catchall(te->set, &te->elem)) { + atomic_dec(&te->set->nelems); + te->set->ndeact--; + } + if (te->set->ops->commit && + list_empty(&te->set->pending_update)) { + list_add_tail(&te->set->pending_update, + &set_update_list); + } + break; + case NFT_MSG_NEWOBJ: + if (nft_trans_obj_update(trans)) { + nft_obj_commit_update(trans); + nf_tables_obj_notify(&trans->ctx, + nft_trans_obj(trans), + NFT_MSG_NEWOBJ); + } else { + nft_clear(net, nft_trans_obj(trans)); + nf_tables_obj_notify(&trans->ctx, + nft_trans_obj(trans), + NFT_MSG_NEWOBJ); + nft_trans_destroy(trans); + } + break; + case NFT_MSG_DELOBJ: + nft_obj_del(nft_trans_obj(trans)); + nf_tables_obj_notify(&trans->ctx, nft_trans_obj(trans), + NFT_MSG_DELOBJ); + break; + case NFT_MSG_NEWFLOWTABLE: + if (nft_trans_flowtable_update(trans)) { + nft_trans_flowtable(trans)->data.flags = + nft_trans_flowtable_flags(trans); + nf_tables_flowtable_notify(&trans->ctx, + nft_trans_flowtable(trans), + &nft_trans_flowtable_hooks(trans), + NFT_MSG_NEWFLOWTABLE); + list_splice(&nft_trans_flowtable_hooks(trans), + &nft_trans_flowtable(trans)->hook_list); + } else { + nft_clear(net, nft_trans_flowtable(trans)); + nf_tables_flowtable_notify(&trans->ctx, + nft_trans_flowtable(trans), + &nft_trans_flowtable(trans)->hook_list, + NFT_MSG_NEWFLOWTABLE); + } + nft_trans_destroy(trans); + break; + case NFT_MSG_DELFLOWTABLE: + if (nft_trans_flowtable_update(trans)) { + nf_tables_flowtable_notify(&trans->ctx, + nft_trans_flowtable(trans), + &nft_trans_flowtable_hooks(trans), + NFT_MSG_DELFLOWTABLE); + nft_unregister_flowtable_net_hooks(net, + &nft_trans_flowtable_hooks(trans)); + } else { + list_del_rcu(&nft_trans_flowtable(trans)->list); + nf_tables_flowtable_notify(&trans->ctx, + nft_trans_flowtable(trans), + &nft_trans_flowtable(trans)->hook_list, + NFT_MSG_DELFLOWTABLE); + nft_unregister_flowtable_net_hooks(net, + &nft_trans_flowtable(trans)->hook_list); + } + break; + } + } + + nft_set_commit_update(&set_update_list); + + nft_commit_notify(net, NETLINK_CB(skb).portid); + nf_tables_gen_notify(net, skb, NFT_MSG_NEWGEN); + nf_tables_commit_audit_log(&adl, nft_net->base_seq); + + nft_gc_seq_end(nft_net, gc_seq); + nf_tables_commit_release(net); + + return 0; +} + +static void nf_tables_module_autoload(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_module_request *req, *next; + LIST_HEAD(module_list); + + list_splice_init(&nft_net->module_list, &module_list); + 
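/* Drop commit_mutex across request_module(), which may sleep. */ +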
mutex_unlock(&nft_net->commit_mutex); + list_for_each_entry_safe(req, next, &module_list, list) { + request_module("%s", req->module); + req->done = true; + } + mutex_lock(&nft_net->commit_mutex); + list_splice(&module_list, &nft_net->module_list); +} + +static void nf_tables_abort_release(struct nft_trans *trans) +{ + switch (trans->msg_type) { + case NFT_MSG_NEWTABLE: + nf_tables_table_destroy(&trans->ctx); + break; + case NFT_MSG_NEWCHAIN: + nf_tables_chain_destroy(&trans->ctx); + break; + case NFT_MSG_NEWRULE: + nf_tables_rule_destroy(&trans->ctx, nft_trans_rule(trans)); + break; + case NFT_MSG_NEWSET: + nft_set_destroy(&trans->ctx, nft_trans_set(trans)); + break; + case NFT_MSG_NEWSETELEM: + nft_set_elem_destroy(nft_trans_elem_set(trans), + nft_trans_elem(trans).priv, true); + break; + case NFT_MSG_NEWOBJ: + nft_obj_destroy(&trans->ctx, nft_trans_obj(trans)); + break; + case NFT_MSG_NEWFLOWTABLE: + if (nft_trans_flowtable_update(trans)) + nft_flowtable_hooks_destroy(&nft_trans_flowtable_hooks(trans)); + else + nf_tables_flowtable_destroy(nft_trans_flowtable(trans)); + break; + } + kfree(trans); +} + +static void nft_set_abort_update(struct list_head *set_update_list) +{ + struct nft_set *set, *next; + + list_for_each_entry_safe(set, next, set_update_list, pending_update) { + list_del_init(&set->pending_update); + + if (!set->ops->abort) + continue; + + set->ops->abort(set); + } +} + +static int __nf_tables_abort(struct net *net, enum nfnl_abort_action action) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_trans *trans, *next; + LIST_HEAD(set_update_list); + struct nft_trans_elem *te; + + if (action == NFNL_ABORT_VALIDATE && + nf_tables_validate(net) < 0) + return -EAGAIN; + + list_for_each_entry_safe_reverse(trans, next, &nft_net->commit_list, + list) { + switch (trans->msg_type) { + case NFT_MSG_NEWTABLE: + if (nft_trans_table_update(trans)) { + if (!(trans->ctx.table->flags & __NFT_TABLE_F_UPDATE)) { + nft_trans_destroy(trans); + break; + } + if (trans->ctx.table->flags & __NFT_TABLE_F_WAS_DORMANT) { + nf_tables_table_disable(net, trans->ctx.table); + trans->ctx.table->flags |= NFT_TABLE_F_DORMANT; + } else if (trans->ctx.table->flags & __NFT_TABLE_F_WAS_AWAKEN) { + trans->ctx.table->flags &= ~NFT_TABLE_F_DORMANT; + } + trans->ctx.table->flags &= ~__NFT_TABLE_F_UPDATE; + nft_trans_destroy(trans); + } else { + list_del_rcu(&trans->ctx.table->list); + } + break; + case NFT_MSG_DELTABLE: + nft_clear(trans->ctx.net, trans->ctx.table); + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWCHAIN: + if (nft_trans_chain_update(trans)) { + free_percpu(nft_trans_chain_stats(trans)); + kfree(nft_trans_chain_name(trans)); + nft_trans_destroy(trans); + } else { + if (nft_trans_chain_bound(trans)) { + nft_trans_destroy(trans); + break; + } + nft_use_dec_restore(&trans->ctx.table->use); + nft_chain_del(trans->ctx.chain); + nf_tables_unregister_hook(trans->ctx.net, + trans->ctx.table, + trans->ctx.chain); + } + break; + case NFT_MSG_DELCHAIN: + nft_use_inc_restore(&trans->ctx.table->use); + nft_clear(trans->ctx.net, trans->ctx.chain); + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWRULE: + if (nft_trans_rule_bound(trans)) { + nft_trans_destroy(trans); + break; + } + nft_use_dec_restore(&trans->ctx.chain->use); + list_del_rcu(&nft_trans_rule(trans)->list); + nft_rule_expr_deactivate(&trans->ctx, + nft_trans_rule(trans), + NFT_TRANS_ABORT); + if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) + nft_flow_rule_destroy(nft_trans_flow_rule(trans)); + break; + case 
NFT_MSG_DELRULE: + nft_use_inc_restore(&trans->ctx.chain->use); + nft_clear(trans->ctx.net, nft_trans_rule(trans)); + nft_rule_expr_activate(&trans->ctx, nft_trans_rule(trans)); + if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) + nft_flow_rule_destroy(nft_trans_flow_rule(trans)); + + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWSET: + if (nft_trans_set_update(trans)) { + nft_trans_destroy(trans); + break; + } + nft_use_dec_restore(&trans->ctx.table->use); + if (nft_trans_set_bound(trans)) { + nft_trans_destroy(trans); + break; + } + nft_trans_set(trans)->dead = 1; + list_del_rcu(&nft_trans_set(trans)->list); + break; + case NFT_MSG_DELSET: + nft_use_inc_restore(&trans->ctx.table->use); + nft_clear(trans->ctx.net, nft_trans_set(trans)); + if (nft_trans_set(trans)->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_activate(&trans->ctx, nft_trans_set(trans)); + + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWSETELEM: + if (nft_trans_elem_set_bound(trans)) { + nft_trans_destroy(trans); + break; + } + te = (struct nft_trans_elem *)trans->data; + nft_setelem_remove(net, te->set, &te->elem); + if (!nft_setelem_is_catchall(te->set, &te->elem)) + atomic_dec(&te->set->nelems); + + if (te->set->ops->abort && + list_empty(&te->set->pending_update)) { + list_add_tail(&te->set->pending_update, + &set_update_list); + } + break; + case NFT_MSG_DELSETELEM: + te = (struct nft_trans_elem *)trans->data; + + nft_setelem_data_activate(net, te->set, &te->elem); + nft_setelem_activate(net, te->set, &te->elem); + if (!nft_setelem_is_catchall(te->set, &te->elem)) + te->set->ndeact--; + + if (te->set->ops->abort && + list_empty(&te->set->pending_update)) { + list_add_tail(&te->set->pending_update, + &set_update_list); + } + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWOBJ: + if (nft_trans_obj_update(trans)) { + nft_obj_destroy(&trans->ctx, nft_trans_obj_newobj(trans)); + nft_trans_destroy(trans); + } else { + nft_use_dec_restore(&trans->ctx.table->use); + nft_obj_del(nft_trans_obj(trans)); + } + break; + case NFT_MSG_DELOBJ: + nft_use_inc_restore(&trans->ctx.table->use); + nft_clear(trans->ctx.net, nft_trans_obj(trans)); + nft_trans_destroy(trans); + break; + case NFT_MSG_NEWFLOWTABLE: + if (nft_trans_flowtable_update(trans)) { + nft_unregister_flowtable_net_hooks(net, + &nft_trans_flowtable_hooks(trans)); + } else { + nft_use_dec_restore(&trans->ctx.table->use); + list_del_rcu(&nft_trans_flowtable(trans)->list); + nft_unregister_flowtable_net_hooks(net, + &nft_trans_flowtable(trans)->hook_list); + } + break; + case NFT_MSG_DELFLOWTABLE: + if (nft_trans_flowtable_update(trans)) { + list_splice(&nft_trans_flowtable_hooks(trans), + &nft_trans_flowtable(trans)->hook_list); + } else { + nft_use_inc_restore(&trans->ctx.table->use); + nft_clear(trans->ctx.net, nft_trans_flowtable(trans)); + } + nft_trans_destroy(trans); + break; + } + } + + nft_set_abort_update(&set_update_list); + + synchronize_rcu(); + + list_for_each_entry_safe_reverse(trans, next, + &nft_net->commit_list, list) { + nft_trans_list_del(trans); + nf_tables_abort_release(trans); + } + + if (action == NFNL_ABORT_AUTOLOAD) + nf_tables_module_autoload(net); + else + nf_tables_module_autoload_cleanup(net); + + return 0; +} + +static int nf_tables_abort(struct net *net, struct sk_buff *skb, + enum nfnl_abort_action action) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + unsigned int gc_seq; + int ret; + + gc_seq = nft_gc_seq_begin(nft_net); + ret = __nf_tables_abort(net, action); + nft_gc_seq_end(nft_net, gc_seq); + + 
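/* commit_mutex was taken by nf_tables_valid_genid() and must be released here on the abort path. */ +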
mutex_unlock(&nft_net->commit_mutex); + + return ret; +} + +static bool nf_tables_valid_genid(struct net *net, u32 genid) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + bool genid_ok; + + mutex_lock(&nft_net->commit_mutex); + + genid_ok = genid == 0 || nft_net->base_seq == genid; + if (!genid_ok) + mutex_unlock(&nft_net->commit_mutex); + + /* else, commit mutex has to be released by commit or abort function */ + return genid_ok; +} + +static const struct nfnetlink_subsystem nf_tables_subsys = { + .name = "nf_tables", + .subsys_id = NFNL_SUBSYS_NFTABLES, + .cb_count = NFT_MSG_MAX, + .cb = nf_tables_cb, + .commit = nf_tables_commit, + .abort = nf_tables_abort, + .valid_genid = nf_tables_valid_genid, + .owner = THIS_MODULE, +}; + +int nft_chain_validate_dependency(const struct nft_chain *chain, + enum nft_chain_types type) +{ + const struct nft_base_chain *basechain; + + if (nft_is_base_chain(chain)) { + basechain = nft_base_chain(chain); + if (basechain->type->type != type) + return -EOPNOTSUPP; + } + return 0; +} +EXPORT_SYMBOL_GPL(nft_chain_validate_dependency); + +int nft_chain_validate_hooks(const struct nft_chain *chain, + unsigned int hook_flags) +{ + struct nft_base_chain *basechain; + + if (nft_is_base_chain(chain)) { + basechain = nft_base_chain(chain); + + if ((1 << basechain->ops.hooknum) & hook_flags) + return 0; + + return -EOPNOTSUPP; + } + + return 0; +} +EXPORT_SYMBOL_GPL(nft_chain_validate_hooks); + +/* + * Loop detection - walk through the ruleset beginning at the destination chain + * of a new jump until either the source chain is reached (loop) or all + * reachable chains have been traversed. + * + * The loop check is performed whenever a new jump verdict is added to an + * expression or verdict map or a verdict map is bound to a new chain. 
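+ *
+ * For example, given two regular chains c1 and c2 in table ip t (the
+ * names are purely illustrative), the second nft(8) command below closes
+ * the cycle c1 -> c2 -> c1 and is therefore rejected with -ELOOP:
+ *
+ *        nft add rule ip t c1 jump c2
+ *        nft add rule ip t c2 jump c1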
+ */ + +static int nf_tables_check_loops(const struct nft_ctx *ctx, + const struct nft_chain *chain); + +static int nft_check_loops(const struct nft_ctx *ctx, + const struct nft_set_ext *ext) +{ + const struct nft_data *data; + int ret; + + data = nft_set_ext_data(ext); + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + ret = nf_tables_check_loops(ctx, data->verdict.chain); + break; + default: + ret = 0; + break; + } + + return ret; +} + +static int nf_tables_loop_check_setelem(const struct nft_ctx *ctx, + struct nft_set *set, + const struct nft_set_iter *iter, + struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + if (nft_set_ext_exists(ext, NFT_SET_EXT_FLAGS) && + *nft_set_ext_flags(ext) & NFT_SET_ELEM_INTERVAL_END) + return 0; + + return nft_check_loops(ctx, ext); +} + +static int nft_set_catchall_loops(const struct nft_ctx *ctx, + struct nft_set *set) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set_elem_catchall *catchall; + struct nft_set_ext *ext; + int ret = 0; + + list_for_each_entry_rcu(catchall, &set->catchall_list, list) { + ext = nft_set_elem_ext(set, catchall->elem); + if (!nft_set_elem_active(ext, genmask)) + continue; + + ret = nft_check_loops(ctx, ext); + if (ret < 0) + return ret; + } + + return ret; +} + +static int nf_tables_check_loops(const struct nft_ctx *ctx, + const struct nft_chain *chain) +{ + const struct nft_rule *rule; + const struct nft_expr *expr, *last; + struct nft_set *set; + struct nft_set_binding *binding; + struct nft_set_iter iter; + + if (ctx->chain == chain) + return -ELOOP; + + list_for_each_entry(rule, &chain->rules, list) { + nft_rule_for_each_expr(expr, last, rule) { + struct nft_immediate_expr *priv; + const struct nft_data *data; + int err; + + if (strcmp(expr->ops->type->name, "immediate")) + continue; + + priv = nft_expr_priv(expr); + if (priv->dreg != NFT_REG_VERDICT) + continue; + + data = &priv->data; + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + err = nf_tables_check_loops(ctx, + data->verdict.chain); + if (err < 0) + return err; + break; + default: + break; + } + } + } + + list_for_each_entry(set, &ctx->table->sets, list) { + if (!nft_is_active_next(ctx->net, set)) + continue; + if (!(set->flags & NFT_SET_MAP) || + set->dtype != NFT_DATA_VERDICT) + continue; + + list_for_each_entry(binding, &set->bindings, list) { + if (!(binding->flags & NFT_SET_MAP) || + binding->chain != chain) + continue; + + iter.genmask = nft_genmask_next(ctx->net); + iter.skip = 0; + iter.count = 0; + iter.err = 0; + iter.fn = nf_tables_loop_check_setelem; + + set->ops->walk(ctx, set, &iter); + if (!iter.err) + iter.err = nft_set_catchall_loops(ctx, set); + + if (iter.err < 0) + return iter.err; + } + } + + return 0; +} + +/** + * nft_parse_u32_check - fetch u32 attribute and check for maximum value + * + * @attr: netlink attribute to fetch value from + * @max: maximum value to be stored in dest + * @dest: pointer to the variable + * + * Parse, check and store a given u32 netlink attribute into variable. + * This function returns -ERANGE if the value goes over maximum value. + * Otherwise a 0 is returned and the attribute value is stored in the + * destination variable. 
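+ *
+ * A minimal caller sketch (the attribute index NFTA_EXAMPLE_PRIO and the
+ * U8_MAX bound are only illustrative, not defined in this file):
+ *
+ *        u32 prio;
+ *        int err;
+ *
+ *        err = nft_parse_u32_check(tb[NFTA_EXAMPLE_PRIO], U8_MAX, &prio);
+ *        if (err < 0)
+ *                return err;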
+ */ +int nft_parse_u32_check(const struct nlattr *attr, int max, u32 *dest) +{ + u32 val; + + val = ntohl(nla_get_be32(attr)); + if (val > max) + return -ERANGE; + + *dest = val; + return 0; +} +EXPORT_SYMBOL_GPL(nft_parse_u32_check); + +static int nft_parse_register(const struct nlattr *attr, u32 *preg) +{ + unsigned int reg; + + reg = ntohl(nla_get_be32(attr)); + switch (reg) { + case NFT_REG_VERDICT...NFT_REG_4: + *preg = reg * NFT_REG_SIZE / NFT_REG32_SIZE; + break; + case NFT_REG32_00...NFT_REG32_15: + *preg = reg + NFT_REG_SIZE / NFT_REG32_SIZE - NFT_REG32_00; + break; + default: + return -ERANGE; + } + + return 0; +} + +/** + * nft_dump_register - dump a register value to a netlink attribute + * + * @skb: socket buffer + * @attr: attribute number + * @reg: register number + * + * Construct a netlink attribute containing the register number. For + * compatibility reasons, register numbers being a multiple of 4 are + * translated to the corresponding 128 bit register numbers. + */ +int nft_dump_register(struct sk_buff *skb, unsigned int attr, unsigned int reg) +{ + if (reg % (NFT_REG_SIZE / NFT_REG32_SIZE) == 0) + reg = reg / (NFT_REG_SIZE / NFT_REG32_SIZE); + else + reg = reg - NFT_REG_SIZE / NFT_REG32_SIZE + NFT_REG32_00; + + return nla_put_be32(skb, attr, htonl(reg)); +} +EXPORT_SYMBOL_GPL(nft_dump_register); + +static int nft_validate_register_load(enum nft_registers reg, unsigned int len) +{ + if (reg < NFT_REG_1 * NFT_REG_SIZE / NFT_REG32_SIZE) + return -EINVAL; + if (len == 0) + return -EINVAL; + if (reg * NFT_REG32_SIZE + len > sizeof_field(struct nft_regs, data)) + return -ERANGE; + + return 0; +} + +int nft_parse_register_load(const struct nlattr *attr, u8 *sreg, u32 len) +{ + u32 reg; + int err; + + err = nft_parse_register(attr, ®); + if (err < 0) + return err; + + err = nft_validate_register_load(reg, len); + if (err < 0) + return err; + + *sreg = reg; + return 0; +} +EXPORT_SYMBOL_GPL(nft_parse_register_load); + +static int nft_validate_register_store(const struct nft_ctx *ctx, + enum nft_registers reg, + const struct nft_data *data, + enum nft_data_types type, + unsigned int len) +{ + int err; + + switch (reg) { + case NFT_REG_VERDICT: + if (type != NFT_DATA_VERDICT) + return -EINVAL; + + if (data != NULL && + (data->verdict.code == NFT_GOTO || + data->verdict.code == NFT_JUMP)) { + err = nf_tables_check_loops(ctx, data->verdict.chain); + if (err < 0) + return err; + } + + return 0; + default: + if (reg < NFT_REG_1 * NFT_REG_SIZE / NFT_REG32_SIZE) + return -EINVAL; + if (len == 0) + return -EINVAL; + if (reg * NFT_REG32_SIZE + len > + sizeof_field(struct nft_regs, data)) + return -ERANGE; + + if (data != NULL && type != NFT_DATA_VALUE) + return -EINVAL; + return 0; + } +} + +int nft_parse_register_store(const struct nft_ctx *ctx, + const struct nlattr *attr, u8 *dreg, + const struct nft_data *data, + enum nft_data_types type, unsigned int len) +{ + int err; + u32 reg; + + err = nft_parse_register(attr, ®); + if (err < 0) + return err; + + err = nft_validate_register_store(ctx, reg, data, type, len); + if (err < 0) + return err; + + *dreg = reg; + return 0; +} +EXPORT_SYMBOL_GPL(nft_parse_register_store); + +static const struct nla_policy nft_verdict_policy[NFTA_VERDICT_MAX + 1] = { + [NFTA_VERDICT_CODE] = { .type = NLA_U32 }, + [NFTA_VERDICT_CHAIN] = { .type = NLA_STRING, + .len = NFT_CHAIN_MAXNAMELEN - 1 }, + [NFTA_VERDICT_CHAIN_ID] = { .type = NLA_U32 }, +}; + +static int nft_verdict_init(const struct nft_ctx *ctx, struct nft_data *data, + struct nft_data_desc *desc, 
const struct nlattr *nla) +{ + u8 genmask = nft_genmask_next(ctx->net); + struct nlattr *tb[NFTA_VERDICT_MAX + 1]; + struct nft_chain *chain; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_VERDICT_MAX, nla, + nft_verdict_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_VERDICT_CODE]) + return -EINVAL; + + /* zero padding hole for memcmp */ + memset(data, 0, sizeof(*data)); + data->verdict.code = ntohl(nla_get_be32(tb[NFTA_VERDICT_CODE])); + + switch (data->verdict.code) { + case NF_ACCEPT: + case NF_DROP: + case NF_QUEUE: + break; + case NFT_CONTINUE: + case NFT_BREAK: + case NFT_RETURN: + break; + case NFT_JUMP: + case NFT_GOTO: + if (tb[NFTA_VERDICT_CHAIN]) { + chain = nft_chain_lookup(ctx->net, ctx->table, + tb[NFTA_VERDICT_CHAIN], + genmask); + } else if (tb[NFTA_VERDICT_CHAIN_ID]) { + chain = nft_chain_lookup_byid(ctx->net, ctx->table, + tb[NFTA_VERDICT_CHAIN_ID], + genmask); + if (IS_ERR(chain)) + return PTR_ERR(chain); + } else { + return -EINVAL; + } + + if (IS_ERR(chain)) + return PTR_ERR(chain); + if (nft_is_base_chain(chain)) + return -EOPNOTSUPP; + if (nft_chain_is_bound(chain)) + return -EINVAL; + if (desc->flags & NFT_DATA_DESC_SETELEM && + chain->flags & NFT_CHAIN_BINDING) + return -EINVAL; + if (!nft_use_inc(&chain->use)) + return -EMFILE; + + data->verdict.chain = chain; + break; + default: + return -EINVAL; + } + + desc->len = sizeof(data->verdict); + + return 0; +} + +static void nft_verdict_uninit(const struct nft_data *data) +{ + struct nft_chain *chain; + + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + chain = data->verdict.chain; + nft_use_dec(&chain->use); + break; + } +} + +int nft_verdict_dump(struct sk_buff *skb, int type, const struct nft_verdict *v) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, type); + if (!nest) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_VERDICT_CODE, htonl(v->code))) + goto nla_put_failure; + + switch (v->code) { + case NFT_JUMP: + case NFT_GOTO: + if (nla_put_string(skb, NFTA_VERDICT_CHAIN, + v->chain->name)) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + return -1; +} + +static int nft_value_init(const struct nft_ctx *ctx, + struct nft_data *data, struct nft_data_desc *desc, + const struct nlattr *nla) +{ + unsigned int len; + + len = nla_len(nla); + if (len == 0) + return -EINVAL; + if (len > desc->size) + return -EOVERFLOW; + if (desc->len) { + if (len != desc->len) + return -EINVAL; + } else { + desc->len = len; + } + + nla_memcpy(data->data, nla, len); + + return 0; +} + +static int nft_value_dump(struct sk_buff *skb, const struct nft_data *data, + unsigned int len) +{ + return nla_put(skb, NFTA_DATA_VALUE, len, data->data); +} + +static const struct nla_policy nft_data_policy[NFTA_DATA_MAX + 1] = { + [NFTA_DATA_VALUE] = { .type = NLA_BINARY }, + [NFTA_DATA_VERDICT] = { .type = NLA_NESTED }, +}; + +/** + * nft_data_init - parse nf_tables data netlink attributes + * + * @ctx: context of the expression using the data + * @data: destination struct nft_data + * @desc: data description + * @nla: netlink attribute containing data + * + * Parse the netlink data attributes and initialize a struct nft_data. + * The type and length of data are returned in the data description. + * + * The caller can indicate that it only wants to accept data of type + * NFT_DATA_VALUE by passing NULL for the ctx argument. 
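+ *
+ * A caller that only accepts plain values might look like this sketch
+ * (NFTA_EXAMPLE_DATA and the priv->data field are illustrative, not part
+ * of this file):
+ *
+ *        struct nft_data_desc desc = {
+ *                .type        = NFT_DATA_VALUE,
+ *                .size        = sizeof(priv->data),
+ *        };
+ *        int err;
+ *
+ *        err = nft_data_init(NULL, &priv->data, &desc, tb[NFTA_EXAMPLE_DATA]);
+ *        if (err < 0)
+ *                return err;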
+ */ +int nft_data_init(const struct nft_ctx *ctx, struct nft_data *data, + struct nft_data_desc *desc, const struct nlattr *nla) +{ + struct nlattr *tb[NFTA_DATA_MAX + 1]; + int err; + + if (WARN_ON_ONCE(!desc->size)) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, NFTA_DATA_MAX, nla, + nft_data_policy, NULL); + if (err < 0) + return err; + + if (tb[NFTA_DATA_VALUE]) { + if (desc->type != NFT_DATA_VALUE) + return -EINVAL; + + err = nft_value_init(ctx, data, desc, tb[NFTA_DATA_VALUE]); + } else if (tb[NFTA_DATA_VERDICT] && ctx != NULL) { + if (desc->type != NFT_DATA_VERDICT) + return -EINVAL; + + err = nft_verdict_init(ctx, data, desc, tb[NFTA_DATA_VERDICT]); + } else { + err = -EINVAL; + } + + return err; +} +EXPORT_SYMBOL_GPL(nft_data_init); + +/** + * nft_data_release - release a nft_data item + * + * @data: struct nft_data to release + * @type: type of data + * + * Release a nft_data item. NFT_DATA_VALUE types can be silently discarded, + * all others need to be released by calling this function. + */ +void nft_data_release(const struct nft_data *data, enum nft_data_types type) +{ + if (type < NFT_DATA_VERDICT) + return; + switch (type) { + case NFT_DATA_VERDICT: + return nft_verdict_uninit(data); + default: + WARN_ON(1); + } +} +EXPORT_SYMBOL_GPL(nft_data_release); + +int nft_data_dump(struct sk_buff *skb, int attr, const struct nft_data *data, + enum nft_data_types type, unsigned int len) +{ + struct nlattr *nest; + int err; + + nest = nla_nest_start_noflag(skb, attr); + if (nest == NULL) + return -1; + + switch (type) { + case NFT_DATA_VALUE: + err = nft_value_dump(skb, data, len); + break; + case NFT_DATA_VERDICT: + err = nft_verdict_dump(skb, NFTA_DATA_VERDICT, &data->verdict); + break; + default: + err = -EINVAL; + WARN_ON(1); + } + + nla_nest_end(skb, nest); + return err; +} +EXPORT_SYMBOL_GPL(nft_data_dump); + +int __nft_release_basechain(struct nft_ctx *ctx) +{ + struct nft_rule *rule, *nr; + + if (WARN_ON(!nft_is_base_chain(ctx->chain))) + return 0; + + nf_tables_unregister_hook(ctx->net, ctx->chain->table, ctx->chain); + list_for_each_entry_safe(rule, nr, &ctx->chain->rules, list) { + list_del(&rule->list); + nft_use_dec(&ctx->chain->use); + nf_tables_rule_release(ctx, rule); + } + nft_chain_del(ctx->chain); + nft_use_dec(&ctx->table->use); + nf_tables_chain_destroy(ctx); + + return 0; +} +EXPORT_SYMBOL_GPL(__nft_release_basechain); + +static void __nft_release_hook(struct net *net, struct nft_table *table) +{ + struct nft_flowtable *flowtable; + struct nft_chain *chain; + + list_for_each_entry(chain, &table->chains, list) + __nf_tables_unregister_hook(net, table, chain, true); + list_for_each_entry(flowtable, &table->flowtables, list) + __nft_unregister_flowtable_net_hooks(net, &flowtable->hook_list, + true); +} + +static void __nft_release_hooks(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_table *table; + + list_for_each_entry(table, &nft_net->tables, list) { + if (nft_table_has_owner(table)) + continue; + + __nft_release_hook(net, table); + } +} + +static void __nft_release_table(struct net *net, struct nft_table *table) +{ + struct nft_flowtable *flowtable, *nf; + struct nft_chain *chain, *nc; + struct nft_object *obj, *ne; + struct nft_rule *rule, *nr; + struct nft_set *set, *ns; + struct nft_ctx ctx = { + .net = net, + .family = NFPROTO_NETDEV, + }; + + ctx.family = table->family; + ctx.table = table; + list_for_each_entry(chain, &table->chains, list) { + if (nft_chain_binding(chain)) + continue; + + ctx.chain = 
chain; + list_for_each_entry_safe(rule, nr, &chain->rules, list) { + list_del(&rule->list); + nft_use_dec(&chain->use); + nf_tables_rule_release(&ctx, rule); + } + } + list_for_each_entry_safe(flowtable, nf, &table->flowtables, list) { + list_del(&flowtable->list); + nft_use_dec(&table->use); + nf_tables_flowtable_destroy(flowtable); + } + list_for_each_entry_safe(set, ns, &table->sets, list) { + list_del(&set->list); + nft_use_dec(&table->use); + if (set->flags & (NFT_SET_MAP | NFT_SET_OBJECT)) + nft_map_deactivate(&ctx, set); + + nft_set_destroy(&ctx, set); + } + list_for_each_entry_safe(obj, ne, &table->objects, list) { + nft_obj_del(obj); + nft_use_dec(&table->use); + nft_obj_destroy(&ctx, obj); + } + list_for_each_entry_safe(chain, nc, &table->chains, list) { + ctx.chain = chain; + nft_chain_del(chain); + nft_use_dec(&table->use); + nf_tables_chain_destroy(&ctx); + } + nf_tables_table_destroy(&ctx); +} + +static void __nft_release_tables(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_table *table, *nt; + + list_for_each_entry_safe(table, nt, &nft_net->tables, list) { + if (nft_table_has_owner(table)) + continue; + + list_del(&table->list); + + __nft_release_table(net, table); + } +} + +static int nft_rcv_nl_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct nft_table *table, *to_delete[8]; + struct nftables_pernet *nft_net; + struct netlink_notify *n = ptr; + struct net *net = n->net; + unsigned int deleted; + bool restart = false; + unsigned int gc_seq; + + if (event != NETLINK_URELEASE || n->protocol != NETLINK_NETFILTER) + return NOTIFY_DONE; + + nft_net = nft_pernet(net); + deleted = 0; + mutex_lock(&nft_net->commit_mutex); + + gc_seq = nft_gc_seq_begin(nft_net); + + if (!list_empty(&nf_tables_destroy_list)) + nf_tables_trans_destroy_flush_work(); +again: + list_for_each_entry(table, &nft_net->tables, list) { + if (nft_table_has_owner(table) && + n->portid == table->nlpid) { + __nft_release_hook(net, table); + list_del_rcu(&table->list); + to_delete[deleted++] = table; + if (deleted >= ARRAY_SIZE(to_delete)) + break; + } + } + if (deleted) { + restart = deleted >= ARRAY_SIZE(to_delete); + synchronize_rcu(); + while (deleted) + __nft_release_table(net, to_delete[--deleted]); + + if (restart) + goto again; + } + nft_gc_seq_end(nft_net, gc_seq); + + mutex_unlock(&nft_net->commit_mutex); + + return NOTIFY_DONE; +} + +static struct notifier_block nft_nl_notifier = { + .notifier_call = nft_rcv_nl_event, +}; + +static int __net_init nf_tables_init_net(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + + INIT_LIST_HEAD(&nft_net->tables); + INIT_LIST_HEAD(&nft_net->commit_list); + INIT_LIST_HEAD(&nft_net->binding_list); + INIT_LIST_HEAD(&nft_net->module_list); + INIT_LIST_HEAD(&nft_net->notify_list); + mutex_init(&nft_net->commit_mutex); + nft_net->base_seq = 1; + nft_net->validate_state = NFT_VALIDATE_SKIP; + nft_net->gc_seq = 0; + + return 0; +} + +static void __net_exit nf_tables_pre_exit_net(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + + mutex_lock(&nft_net->commit_mutex); + __nft_release_hooks(net); + mutex_unlock(&nft_net->commit_mutex); +} + +static void __net_exit nf_tables_exit_net(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + unsigned int gc_seq; + + mutex_lock(&nft_net->commit_mutex); + + gc_seq = nft_gc_seq_begin(nft_net); + + if (!list_empty(&nft_net->commit_list) || + !list_empty(&nft_net->module_list)) + 
__nf_tables_abort(net, NFNL_ABORT_NONE); + + __nft_release_tables(net); + + nft_gc_seq_end(nft_net, gc_seq); + + mutex_unlock(&nft_net->commit_mutex); + WARN_ON_ONCE(!list_empty(&nft_net->tables)); + WARN_ON_ONCE(!list_empty(&nft_net->module_list)); + WARN_ON_ONCE(!list_empty(&nft_net->notify_list)); +} + +static void nf_tables_exit_batch(struct list_head *net_exit_list) +{ + flush_work(&trans_gc_work); +} + +static struct pernet_operations nf_tables_net_ops = { + .init = nf_tables_init_net, + .pre_exit = nf_tables_pre_exit_net, + .exit = nf_tables_exit_net, + .exit_batch = nf_tables_exit_batch, + .id = &nf_tables_net_id, + .size = sizeof(struct nftables_pernet), +}; + +static int __init nf_tables_module_init(void) +{ + int err; + + err = register_pernet_subsys(&nf_tables_net_ops); + if (err < 0) + return err; + + err = nft_chain_filter_init(); + if (err < 0) + goto err_chain_filter; + + err = nf_tables_core_module_init(); + if (err < 0) + goto err_core_module; + + err = register_netdevice_notifier(&nf_tables_flowtable_notifier); + if (err < 0) + goto err_netdev_notifier; + + err = rhltable_init(&nft_objname_ht, &nft_objname_ht_params); + if (err < 0) + goto err_rht_objname; + + err = nft_offload_init(); + if (err < 0) + goto err_offload; + + err = netlink_register_notifier(&nft_nl_notifier); + if (err < 0) + goto err_netlink_notifier; + + /* must be last */ + err = nfnetlink_subsys_register(&nf_tables_subsys); + if (err < 0) + goto err_nfnl_subsys; + + nft_chain_route_init(); + + return err; + +err_nfnl_subsys: + netlink_unregister_notifier(&nft_nl_notifier); +err_netlink_notifier: + nft_offload_exit(); +err_offload: + rhltable_destroy(&nft_objname_ht); +err_rht_objname: + unregister_netdevice_notifier(&nf_tables_flowtable_notifier); +err_netdev_notifier: + nf_tables_core_module_exit(); +err_core_module: + nft_chain_filter_fini(); +err_chain_filter: + unregister_pernet_subsys(&nf_tables_net_ops); + return err; +} + +static void __exit nf_tables_module_exit(void) +{ + nfnetlink_subsys_unregister(&nf_tables_subsys); + netlink_unregister_notifier(&nft_nl_notifier); + nft_offload_exit(); + unregister_netdevice_notifier(&nf_tables_flowtable_notifier); + nft_chain_filter_fini(); + nft_chain_route_fini(); + unregister_pernet_subsys(&nf_tables_net_ops); + cancel_work_sync(&trans_gc_work); + cancel_work_sync(&trans_destroy_work); + rcu_barrier(); + rhltable_destroy(&nft_objname_ht); + nf_tables_core_module_exit(); +} + +module_init(nf_tables_module_init); +module_exit(nf_tables_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_NFTABLES); diff --git a/net/netfilter/nf_tables_core.c b/net/netfilter/nf_tables_core.c new file mode 100644 index 000000000..e0c117229 --- /dev/null +++ b/net/netfilter/nf_tables_core.c @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static noinline void __nft_trace_packet(struct nft_traceinfo *info, + const struct nft_chain *chain, + enum nft_trace_types type) +{ + if (!info->trace || !info->nf_trace) + return; + + info->chain = chain; + info->type = type; + + nft_trace_notify(info); +} + +static inline void nft_trace_packet(const struct nft_pktinfo *pkt, + struct nft_traceinfo *info, + const struct nft_chain *chain, + 
const struct nft_rule_dp *rule, + enum nft_trace_types type) +{ + if (static_branch_unlikely(&nft_trace_enabled)) { + info->nf_trace = pkt->skb->nf_trace; + info->rule = rule; + __nft_trace_packet(info, chain, type); + } +} + +static inline void nft_trace_copy_nftrace(const struct nft_pktinfo *pkt, + struct nft_traceinfo *info) +{ + if (static_branch_unlikely(&nft_trace_enabled)) { + if (info->trace) + info->nf_trace = pkt->skb->nf_trace; + } +} + +static void nft_bitwise_fast_eval(const struct nft_expr *expr, + struct nft_regs *regs) +{ + const struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); + u32 *src = ®s->data[priv->sreg]; + u32 *dst = ®s->data[priv->dreg]; + + *dst = (*src & priv->mask) ^ priv->xor; +} + +static void nft_cmp_fast_eval(const struct nft_expr *expr, + struct nft_regs *regs) +{ + const struct nft_cmp_fast_expr *priv = nft_expr_priv(expr); + + if (((regs->data[priv->sreg] & priv->mask) == priv->data) ^ priv->inv) + return; + regs->verdict.code = NFT_BREAK; +} + +static void nft_cmp16_fast_eval(const struct nft_expr *expr, + struct nft_regs *regs) +{ + const struct nft_cmp16_fast_expr *priv = nft_expr_priv(expr); + const u64 *reg_data = (const u64 *)®s->data[priv->sreg]; + const u64 *mask = (const u64 *)&priv->mask; + const u64 *data = (const u64 *)&priv->data; + + if (((reg_data[0] & mask[0]) == data[0] && + ((reg_data[1] & mask[1]) == data[1])) ^ priv->inv) + return; + regs->verdict.code = NFT_BREAK; +} + +static noinline void __nft_trace_verdict(struct nft_traceinfo *info, + const struct nft_chain *chain, + const struct nft_regs *regs) +{ + enum nft_trace_types type; + + switch (regs->verdict.code) { + case NFT_CONTINUE: + case NFT_RETURN: + type = NFT_TRACETYPE_RETURN; + break; + case NF_STOLEN: + type = NFT_TRACETYPE_RULE; + /* can't access skb->nf_trace; use copy */ + break; + default: + type = NFT_TRACETYPE_RULE; + + if (info->trace) + info->nf_trace = info->pkt->skb->nf_trace; + break; + } + + __nft_trace_packet(info, chain, type); +} + +static inline void nft_trace_verdict(struct nft_traceinfo *info, + const struct nft_chain *chain, + const struct nft_rule_dp *rule, + const struct nft_regs *regs) +{ + if (static_branch_unlikely(&nft_trace_enabled)) { + info->rule = rule; + __nft_trace_verdict(info, chain, regs); + } +} + +static bool nft_payload_fast_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + const struct sk_buff *skb = pkt->skb; + u32 *dest = ®s->data[priv->dreg]; + unsigned char *ptr; + + if (priv->base == NFT_PAYLOAD_NETWORK_HEADER) + ptr = skb_network_header(skb); + else { + if (!(pkt->flags & NFT_PKTINFO_L4PROTO)) + return false; + ptr = skb->data + nft_thoff(pkt); + } + + ptr += priv->offset; + + if (unlikely(ptr + priv->len > skb_tail_pointer(skb))) + return false; + + *dest = 0; + if (priv->len == 2) + *(u16 *)dest = *(u16 *)ptr; + else if (priv->len == 4) + *(u32 *)dest = *(u32 *)ptr; + else + *(u8 *)dest = *(u8 *)ptr; + return true; +} + +DEFINE_STATIC_KEY_FALSE(nft_counters_enabled); + +static noinline void nft_update_chain_stats(const struct nft_chain *chain, + const struct nft_pktinfo *pkt) +{ + struct nft_base_chain *base_chain; + struct nft_stats __percpu *pstats; + struct nft_stats *stats; + + base_chain = nft_base_chain(chain); + + pstats = READ_ONCE(base_chain->stats); + if (pstats) { + local_bh_disable(); + stats = this_cpu_ptr(pstats); + u64_stats_update_begin(&stats->syncp); + stats->pkts++; + stats->bytes += pkt->skb->len; + 
u64_stats_update_end(&stats->syncp); + local_bh_enable(); + } +} + +struct nft_jumpstack { + const struct nft_chain *chain; + const struct nft_rule_dp *rule; + const struct nft_rule_dp *last_rule; +}; + +static void expr_call_ops_eval(const struct nft_expr *expr, + struct nft_regs *regs, + struct nft_pktinfo *pkt) +{ +#ifdef CONFIG_RETPOLINE + unsigned long e = (unsigned long)expr->ops->eval; +#define X(e, fun) \ + do { if ((e) == (unsigned long)(fun)) \ + return fun(expr, regs, pkt); } while (0) + + X(e, nft_payload_eval); + X(e, nft_cmp_eval); + X(e, nft_counter_eval); + X(e, nft_meta_get_eval); + X(e, nft_lookup_eval); + X(e, nft_range_eval); + X(e, nft_immediate_eval); + X(e, nft_byteorder_eval); + X(e, nft_dynset_eval); + X(e, nft_rt_get_eval); + X(e, nft_bitwise_eval); +#undef X +#endif /* CONFIG_RETPOLINE */ + expr->ops->eval(expr, regs, pkt); +} + +#define nft_rule_expr_first(rule) (struct nft_expr *)&rule->data[0] +#define nft_rule_expr_next(expr) ((void *)expr) + expr->ops->size +#define nft_rule_expr_last(rule) (struct nft_expr *)&rule->data[rule->dlen] +#define nft_rule_next(rule) (void *)rule + sizeof(*rule) + rule->dlen + +#define nft_rule_dp_for_each_expr(expr, last, rule) \ + for ((expr) = nft_rule_expr_first(rule), (last) = nft_rule_expr_last(rule); \ + (expr) != (last); \ + (expr) = nft_rule_expr_next(expr)) + +unsigned int +nft_do_chain(struct nft_pktinfo *pkt, void *priv) +{ + const struct nft_chain *chain = priv, *basechain = chain; + const struct nft_rule_dp *rule, *last_rule; + const struct net *net = nft_net(pkt); + const struct nft_expr *expr, *last; + struct nft_regs regs = {}; + unsigned int stackptr = 0; + struct nft_jumpstack jumpstack[NFT_JUMP_STACK_SIZE]; + bool genbit = READ_ONCE(net->nft.gencursor); + struct nft_rule_blob *blob; + struct nft_traceinfo info; + + info.trace = false; + if (static_branch_unlikely(&nft_trace_enabled)) + nft_trace_init(&info, pkt, ®s.verdict, basechain); +do_chain: + if (genbit) + blob = rcu_dereference(chain->blob_gen_1); + else + blob = rcu_dereference(chain->blob_gen_0); + + rule = (struct nft_rule_dp *)blob->data; + last_rule = (void *)blob->data + blob->size; +next_rule: + regs.verdict.code = NFT_CONTINUE; + for (; rule < last_rule; rule = nft_rule_next(rule)) { + nft_rule_dp_for_each_expr(expr, last, rule) { + if (expr->ops == &nft_cmp_fast_ops) + nft_cmp_fast_eval(expr, ®s); + else if (expr->ops == &nft_cmp16_fast_ops) + nft_cmp16_fast_eval(expr, ®s); + else if (expr->ops == &nft_bitwise_fast_ops) + nft_bitwise_fast_eval(expr, ®s); + else if (expr->ops != &nft_payload_fast_ops || + !nft_payload_fast_eval(expr, ®s, pkt)) + expr_call_ops_eval(expr, ®s, pkt); + + if (regs.verdict.code != NFT_CONTINUE) + break; + } + + switch (regs.verdict.code) { + case NFT_BREAK: + regs.verdict.code = NFT_CONTINUE; + nft_trace_copy_nftrace(pkt, &info); + continue; + case NFT_CONTINUE: + nft_trace_packet(pkt, &info, chain, rule, + NFT_TRACETYPE_RULE); + continue; + } + break; + } + + nft_trace_verdict(&info, chain, rule, ®s); + + switch (regs.verdict.code & NF_VERDICT_MASK) { + case NF_ACCEPT: + case NF_DROP: + case NF_QUEUE: + case NF_STOLEN: + return regs.verdict.code; + } + + switch (regs.verdict.code) { + case NFT_JUMP: + if (WARN_ON_ONCE(stackptr >= NFT_JUMP_STACK_SIZE)) + return NF_DROP; + jumpstack[stackptr].chain = chain; + jumpstack[stackptr].rule = nft_rule_next(rule); + jumpstack[stackptr].last_rule = last_rule; + stackptr++; + fallthrough; + case NFT_GOTO: + chain = regs.verdict.chain; + goto do_chain; + case NFT_CONTINUE: + case 
NFT_RETURN: + break; + default: + WARN_ON_ONCE(1); + } + + if (stackptr > 0) { + stackptr--; + chain = jumpstack[stackptr].chain; + rule = jumpstack[stackptr].rule; + last_rule = jumpstack[stackptr].last_rule; + goto next_rule; + } + + nft_trace_packet(pkt, &info, basechain, NULL, NFT_TRACETYPE_POLICY); + + if (static_branch_unlikely(&nft_counters_enabled)) + nft_update_chain_stats(basechain, pkt); + + return nft_base_chain(basechain)->policy; +} +EXPORT_SYMBOL_GPL(nft_do_chain); + +static struct nft_expr_type *nft_basic_types[] = { + &nft_imm_type, + &nft_cmp_type, + &nft_lookup_type, + &nft_bitwise_type, + &nft_byteorder_type, + &nft_payload_type, + &nft_dynset_type, + &nft_range_type, + &nft_meta_type, + &nft_rt_type, + &nft_exthdr_type, + &nft_last_type, + &nft_counter_type, +}; + +static struct nft_object_type *nft_basic_objects[] = { +#ifdef CONFIG_NETWORK_SECMARK + &nft_secmark_obj_type, +#endif + &nft_counter_obj_type, +}; + +int __init nf_tables_core_module_init(void) +{ + int err, i, j = 0; + + nft_counter_init_seqcount(); + + for (i = 0; i < ARRAY_SIZE(nft_basic_objects); i++) { + err = nft_register_obj(nft_basic_objects[i]); + if (err) + goto err; + } + + for (j = 0; j < ARRAY_SIZE(nft_basic_types); j++) { + err = nft_register_expr(nft_basic_types[j]); + if (err) + goto err; + } + + return 0; + +err: + while (j-- > 0) + nft_unregister_expr(nft_basic_types[j]); + + while (i-- > 0) + nft_unregister_obj(nft_basic_objects[i]); + + return err; +} + +void nf_tables_core_module_exit(void) +{ + int i; + + i = ARRAY_SIZE(nft_basic_types); + while (i-- > 0) + nft_unregister_expr(nft_basic_types[i]); + + i = ARRAY_SIZE(nft_basic_objects); + while (i-- > 0) + nft_unregister_obj(nft_basic_objects[i]); +} diff --git a/net/netfilter/nf_tables_offload.c b/net/netfilter/nf_tables_offload.c new file mode 100644 index 000000000..910ef881c --- /dev/null +++ b/net/netfilter/nf_tables_offload.c @@ -0,0 +1,691 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include + +static struct nft_flow_rule *nft_flow_rule_alloc(int num_actions) +{ + struct nft_flow_rule *flow; + + flow = kzalloc(sizeof(struct nft_flow_rule), GFP_KERNEL); + if (!flow) + return NULL; + + flow->rule = flow_rule_alloc(num_actions); + if (!flow->rule) { + kfree(flow); + return NULL; + } + + flow->rule->match.dissector = &flow->match.dissector; + flow->rule->match.mask = &flow->match.mask; + flow->rule->match.key = &flow->match.key; + + return flow; +} + +void nft_flow_rule_set_addr_type(struct nft_flow_rule *flow, + enum flow_dissector_key_id addr_type) +{ + struct nft_flow_match *match = &flow->match; + struct nft_flow_key *mask = &match->mask; + struct nft_flow_key *key = &match->key; + + if (match->dissector.used_keys & BIT(FLOW_DISSECTOR_KEY_CONTROL)) + return; + + key->control.addr_type = addr_type; + mask->control.addr_type = 0xffff; + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_CONTROL); + match->dissector.offset[FLOW_DISSECTOR_KEY_CONTROL] = + offsetof(struct nft_flow_key, control); +} + +struct nft_offload_ethertype { + __be16 value; + __be16 mask; +}; + +static void nft_flow_rule_transfer_vlan(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow) +{ + struct nft_flow_match *match = &flow->match; + struct nft_offload_ethertype ethertype = { + .value = match->key.basic.n_proto, + .mask = match->mask.basic.n_proto, + }; + + if (match->dissector.used_keys & BIT(FLOW_DISSECTOR_KEY_VLAN) && + (match->key.vlan.vlan_tpid == htons(ETH_P_8021Q) || + 
match->key.vlan.vlan_tpid == htons(ETH_P_8021AD))) { + match->key.basic.n_proto = match->key.cvlan.vlan_tpid; + match->mask.basic.n_proto = match->mask.cvlan.vlan_tpid; + match->key.cvlan.vlan_tpid = match->key.vlan.vlan_tpid; + match->mask.cvlan.vlan_tpid = match->mask.vlan.vlan_tpid; + match->key.vlan.vlan_tpid = ethertype.value; + match->mask.vlan.vlan_tpid = ethertype.mask; + match->dissector.offset[FLOW_DISSECTOR_KEY_CVLAN] = + offsetof(struct nft_flow_key, cvlan); + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_CVLAN); + } else if (match->dissector.used_keys & BIT(FLOW_DISSECTOR_KEY_BASIC) && + (match->key.basic.n_proto == htons(ETH_P_8021Q) || + match->key.basic.n_proto == htons(ETH_P_8021AD))) { + match->key.basic.n_proto = match->key.vlan.vlan_tpid; + match->mask.basic.n_proto = match->mask.vlan.vlan_tpid; + match->key.vlan.vlan_tpid = ethertype.value; + match->mask.vlan.vlan_tpid = ethertype.mask; + match->dissector.offset[FLOW_DISSECTOR_KEY_VLAN] = + offsetof(struct nft_flow_key, vlan); + match->dissector.used_keys |= BIT(FLOW_DISSECTOR_KEY_VLAN); + } +} + +struct nft_flow_rule *nft_flow_rule_create(struct net *net, + const struct nft_rule *rule) +{ + struct nft_offload_ctx *ctx; + struct nft_flow_rule *flow; + int num_actions = 0, err; + struct nft_expr *expr; + + expr = nft_expr_first(rule); + while (nft_expr_more(rule, expr)) { + if (expr->ops->offload_action && + expr->ops->offload_action(expr)) + num_actions++; + + expr = nft_expr_next(expr); + } + + if (num_actions == 0) + return ERR_PTR(-EOPNOTSUPP); + + flow = nft_flow_rule_alloc(num_actions); + if (!flow) + return ERR_PTR(-ENOMEM); + + expr = nft_expr_first(rule); + + ctx = kzalloc(sizeof(struct nft_offload_ctx), GFP_KERNEL); + if (!ctx) { + err = -ENOMEM; + goto err_out; + } + ctx->net = net; + ctx->dep.type = NFT_OFFLOAD_DEP_UNSPEC; + + while (nft_expr_more(rule, expr)) { + if (!expr->ops->offload) { + err = -EOPNOTSUPP; + goto err_out; + } + err = expr->ops->offload(ctx, flow, expr); + if (err < 0) + goto err_out; + + expr = nft_expr_next(expr); + } + nft_flow_rule_transfer_vlan(ctx, flow); + + flow->proto = ctx->dep.l3num; + kfree(ctx); + + return flow; +err_out: + kfree(ctx); + nft_flow_rule_destroy(flow); + + return ERR_PTR(err); +} + +void nft_flow_rule_destroy(struct nft_flow_rule *flow) +{ + struct flow_action_entry *entry; + int i; + + flow_action_for_each(i, entry, &flow->rule->action) { + switch (entry->id) { + case FLOW_ACTION_REDIRECT: + case FLOW_ACTION_MIRRED: + dev_put(entry->dev); + break; + default: + break; + } + } + kfree(flow->rule); + kfree(flow); +} + +void nft_offload_set_dependency(struct nft_offload_ctx *ctx, + enum nft_offload_dep_type type) +{ + ctx->dep.type = type; +} + +void nft_offload_update_dependency(struct nft_offload_ctx *ctx, + const void *data, u32 len) +{ + switch (ctx->dep.type) { + case NFT_OFFLOAD_DEP_NETWORK: + WARN_ON(len != sizeof(__u16)); + memcpy(&ctx->dep.l3num, data, sizeof(__u16)); + break; + case NFT_OFFLOAD_DEP_TRANSPORT: + WARN_ON(len != sizeof(__u8)); + memcpy(&ctx->dep.protonum, data, sizeof(__u8)); + break; + default: + break; + } + ctx->dep.type = NFT_OFFLOAD_DEP_UNSPEC; +} + +static void nft_flow_offload_common_init(struct flow_cls_common_offload *common, + __be16 proto, int priority, + struct netlink_ext_ack *extack) +{ + common->protocol = proto; + common->prio = priority; + common->extack = extack; +} + +static int nft_setup_cb_call(enum tc_setup_type type, void *type_data, + struct list_head *cb_list) +{ + struct flow_block_cb *block_cb; + int err; + 
+ list_for_each_entry(block_cb, cb_list, list) { + err = block_cb->cb(type, type_data, block_cb->cb_priv); + if (err < 0) + return err; + } + return 0; +} + +static int nft_chain_offload_priority(const struct nft_base_chain *basechain) +{ + if (basechain->ops.priority <= 0 || + basechain->ops.priority > USHRT_MAX) + return -1; + + return 0; +} + +bool nft_chain_offload_support(const struct nft_base_chain *basechain) +{ + struct net_device *dev; + struct nft_hook *hook; + + if (nft_chain_offload_priority(basechain) < 0) + return false; + + list_for_each_entry(hook, &basechain->hook_list, list) { + if (hook->ops.pf != NFPROTO_NETDEV || + hook->ops.hooknum != NF_NETDEV_INGRESS) + return false; + + dev = hook->ops.dev; + if (!dev->netdev_ops->ndo_setup_tc && !flow_indr_dev_exists()) + return false; + } + + return true; +} + +static void nft_flow_cls_offload_setup(struct flow_cls_offload *cls_flow, + const struct nft_base_chain *basechain, + const struct nft_rule *rule, + const struct nft_flow_rule *flow, + struct netlink_ext_ack *extack, + enum flow_cls_command command) +{ + __be16 proto = ETH_P_ALL; + + memset(cls_flow, 0, sizeof(*cls_flow)); + + if (flow) + proto = flow->proto; + + nft_flow_offload_common_init(&cls_flow->common, proto, + basechain->ops.priority, extack); + cls_flow->command = command; + cls_flow->cookie = (unsigned long) rule; + if (flow) + cls_flow->rule = flow->rule; +} + +static int nft_flow_offload_cmd(const struct nft_chain *chain, + const struct nft_rule *rule, + struct nft_flow_rule *flow, + enum flow_cls_command command, + struct flow_cls_offload *cls_flow) +{ + struct netlink_ext_ack extack = {}; + struct nft_base_chain *basechain; + + if (!nft_is_base_chain(chain)) + return -EOPNOTSUPP; + + basechain = nft_base_chain(chain); + nft_flow_cls_offload_setup(cls_flow, basechain, rule, flow, &extack, + command); + + return nft_setup_cb_call(TC_SETUP_CLSFLOWER, cls_flow, + &basechain->flow_block.cb_list); +} + +static int nft_flow_offload_rule(const struct nft_chain *chain, + struct nft_rule *rule, + struct nft_flow_rule *flow, + enum flow_cls_command command) +{ + struct flow_cls_offload cls_flow; + + return nft_flow_offload_cmd(chain, rule, flow, command, &cls_flow); +} + +int nft_flow_rule_stats(const struct nft_chain *chain, + const struct nft_rule *rule) +{ + struct flow_cls_offload cls_flow = {}; + struct nft_expr *expr, *next; + int err; + + err = nft_flow_offload_cmd(chain, rule, NULL, FLOW_CLS_STATS, + &cls_flow); + if (err < 0) + return err; + + nft_rule_for_each_expr(expr, next, rule) { + if (expr->ops->offload_stats) + expr->ops->offload_stats(expr, &cls_flow.stats); + } + + return 0; +} + +static int nft_flow_offload_bind(struct flow_block_offload *bo, + struct nft_base_chain *basechain) +{ + list_splice(&bo->cb_list, &basechain->flow_block.cb_list); + return 0; +} + +static int nft_flow_offload_unbind(struct flow_block_offload *bo, + struct nft_base_chain *basechain) +{ + struct flow_block_cb *block_cb, *next; + struct flow_cls_offload cls_flow; + struct netlink_ext_ack extack; + struct nft_chain *chain; + struct nft_rule *rule; + + chain = &basechain->chain; + list_for_each_entry(rule, &chain->rules, list) { + memset(&extack, 0, sizeof(extack)); + nft_flow_cls_offload_setup(&cls_flow, basechain, rule, NULL, + &extack, FLOW_CLS_DESTROY); + nft_setup_cb_call(TC_SETUP_CLSFLOWER, &cls_flow, &bo->cb_list); + } + + list_for_each_entry_safe(block_cb, next, &bo->cb_list, list) { + list_del(&block_cb->list); + flow_block_cb_free(block_cb); + } + + return 0; +} + 
+static int nft_block_setup(struct nft_base_chain *basechain, + struct flow_block_offload *bo, + enum flow_block_command cmd) +{ + int err; + + switch (cmd) { + case FLOW_BLOCK_BIND: + err = nft_flow_offload_bind(bo, basechain); + break; + case FLOW_BLOCK_UNBIND: + err = nft_flow_offload_unbind(bo, basechain); + break; + default: + WARN_ON_ONCE(1); + err = -EOPNOTSUPP; + } + + return err; +} + +static void nft_flow_block_offload_init(struct flow_block_offload *bo, + struct net *net, + enum flow_block_command cmd, + struct nft_base_chain *basechain, + struct netlink_ext_ack *extack) +{ + memset(bo, 0, sizeof(*bo)); + bo->net = net; + bo->block = &basechain->flow_block; + bo->command = cmd; + bo->binder_type = FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS; + bo->extack = extack; + bo->cb_list_head = &basechain->flow_block.cb_list; + INIT_LIST_HEAD(&bo->cb_list); +} + +static int nft_block_offload_cmd(struct nft_base_chain *chain, + struct net_device *dev, + enum flow_block_command cmd) +{ + struct netlink_ext_ack extack = {}; + struct flow_block_offload bo; + int err; + + nft_flow_block_offload_init(&bo, dev_net(dev), cmd, chain, &extack); + + err = dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_BLOCK, &bo); + if (err < 0) + return err; + + return nft_block_setup(chain, &bo, cmd); +} + +static void nft_indr_block_cleanup(struct flow_block_cb *block_cb) +{ + struct nft_base_chain *basechain = block_cb->indr.data; + struct net_device *dev = block_cb->indr.dev; + struct netlink_ext_ack extack = {}; + struct nftables_pernet *nft_net; + struct net *net = dev_net(dev); + struct flow_block_offload bo; + + nft_flow_block_offload_init(&bo, dev_net(dev), FLOW_BLOCK_UNBIND, + basechain, &extack); + nft_net = nft_pernet(net); + mutex_lock(&nft_net->commit_mutex); + list_del(&block_cb->driver_list); + list_move(&block_cb->list, &bo.cb_list); + nft_flow_offload_unbind(&bo, basechain); + mutex_unlock(&nft_net->commit_mutex); +} + +static int nft_indr_block_offload_cmd(struct nft_base_chain *basechain, + struct net_device *dev, + enum flow_block_command cmd) +{ + struct netlink_ext_ack extack = {}; + struct flow_block_offload bo; + int err; + + nft_flow_block_offload_init(&bo, dev_net(dev), cmd, basechain, &extack); + + err = flow_indr_dev_setup_offload(dev, NULL, TC_SETUP_BLOCK, basechain, &bo, + nft_indr_block_cleanup); + if (err < 0) + return err; + + if (list_empty(&bo.cb_list)) + return -EOPNOTSUPP; + + return nft_block_setup(basechain, &bo, cmd); +} + +static int nft_chain_offload_cmd(struct nft_base_chain *basechain, + struct net_device *dev, + enum flow_block_command cmd) +{ + int err; + + if (dev->netdev_ops->ndo_setup_tc) + err = nft_block_offload_cmd(basechain, dev, cmd); + else + err = nft_indr_block_offload_cmd(basechain, dev, cmd); + + return err; +} + +static int nft_flow_block_chain(struct nft_base_chain *basechain, + const struct net_device *this_dev, + enum flow_block_command cmd) +{ + struct net_device *dev; + struct nft_hook *hook; + int err, i = 0; + + list_for_each_entry(hook, &basechain->hook_list, list) { + dev = hook->ops.dev; + if (this_dev && this_dev != dev) + continue; + + err = nft_chain_offload_cmd(basechain, dev, cmd); + if (err < 0 && cmd == FLOW_BLOCK_BIND) { + if (!this_dev) + goto err_flow_block; + + return err; + } + i++; + } + + return 0; + +err_flow_block: + list_for_each_entry(hook, &basechain->hook_list, list) { + if (i-- <= 0) + break; + + dev = hook->ops.dev; + nft_chain_offload_cmd(basechain, dev, FLOW_BLOCK_UNBIND); + } + return err; +} + +static int 
nft_flow_offload_chain(struct nft_chain *chain, u8 *ppolicy, + enum flow_block_command cmd) +{ + struct nft_base_chain *basechain; + u8 policy; + + if (!nft_is_base_chain(chain)) + return -EOPNOTSUPP; + + basechain = nft_base_chain(chain); + policy = ppolicy ? *ppolicy : basechain->policy; + + /* Only default policy to accept is supported for now. */ + if (cmd == FLOW_BLOCK_BIND && policy == NF_DROP) + return -EOPNOTSUPP; + + return nft_flow_block_chain(basechain, NULL, cmd); +} + +static void nft_flow_rule_offload_abort(struct net *net, + struct nft_trans *trans) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + int err = 0; + + list_for_each_entry_continue_reverse(trans, &nft_net->commit_list, list) { + if (trans->ctx.family != NFPROTO_NETDEV) + continue; + + switch (trans->msg_type) { + case NFT_MSG_NEWCHAIN: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) || + nft_trans_chain_update(trans)) + continue; + + err = nft_flow_offload_chain(trans->ctx.chain, NULL, + FLOW_BLOCK_UNBIND); + break; + case NFT_MSG_DELCHAIN: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + err = nft_flow_offload_chain(trans->ctx.chain, NULL, + FLOW_BLOCK_BIND); + break; + case NFT_MSG_NEWRULE: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + err = nft_flow_offload_rule(trans->ctx.chain, + nft_trans_rule(trans), + NULL, FLOW_CLS_DESTROY); + break; + case NFT_MSG_DELRULE: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + err = nft_flow_offload_rule(trans->ctx.chain, + nft_trans_rule(trans), + nft_trans_flow_rule(trans), + FLOW_CLS_REPLACE); + break; + } + + if (WARN_ON_ONCE(err)) + break; + } +} + +int nft_flow_rule_offload_commit(struct net *net) +{ + struct nftables_pernet *nft_net = nft_pernet(net); + struct nft_trans *trans; + int err = 0; + u8 policy; + + list_for_each_entry(trans, &nft_net->commit_list, list) { + if (trans->ctx.family != NFPROTO_NETDEV) + continue; + + switch (trans->msg_type) { + case NFT_MSG_NEWCHAIN: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD) || + nft_trans_chain_update(trans)) + continue; + + policy = nft_trans_chain_policy(trans); + err = nft_flow_offload_chain(trans->ctx.chain, &policy, + FLOW_BLOCK_BIND); + break; + case NFT_MSG_DELCHAIN: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + policy = nft_trans_chain_policy(trans); + err = nft_flow_offload_chain(trans->ctx.chain, &policy, + FLOW_BLOCK_UNBIND); + break; + case NFT_MSG_NEWRULE: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + if (trans->ctx.flags & NLM_F_REPLACE || + !(trans->ctx.flags & NLM_F_APPEND)) { + err = -EOPNOTSUPP; + break; + } + err = nft_flow_offload_rule(trans->ctx.chain, + nft_trans_rule(trans), + nft_trans_flow_rule(trans), + FLOW_CLS_REPLACE); + break; + case NFT_MSG_DELRULE: + if (!(trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + err = nft_flow_offload_rule(trans->ctx.chain, + nft_trans_rule(trans), + NULL, FLOW_CLS_DESTROY); + break; + } + + if (err) { + nft_flow_rule_offload_abort(net, trans); + break; + } + } + + return err; +} + +static struct nft_chain *__nft_offload_get_chain(const struct nftables_pernet *nft_net, + struct net_device *dev) +{ + struct nft_base_chain *basechain; + struct nft_hook *hook, *found; + const struct nft_table *table; + struct nft_chain *chain; + + list_for_each_entry(table, &nft_net->tables, list) { + if (table->family != NFPROTO_NETDEV) + continue; + + list_for_each_entry(chain, &table->chains, list) { + if 
(!nft_is_base_chain(chain) || + !(chain->flags & NFT_CHAIN_HW_OFFLOAD)) + continue; + + found = NULL; + basechain = nft_base_chain(chain); + list_for_each_entry(hook, &basechain->hook_list, list) { + if (hook->ops.dev != dev) + continue; + + found = hook; + break; + } + if (!found) + continue; + + return chain; + } + } + + return NULL; +} + +static int nft_offload_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct nftables_pernet *nft_net; + struct net *net = dev_net(dev); + struct nft_chain *chain; + + if (event != NETDEV_UNREGISTER) + return NOTIFY_DONE; + + nft_net = nft_pernet(net); + mutex_lock(&nft_net->commit_mutex); + chain = __nft_offload_get_chain(nft_net, dev); + if (chain) + nft_flow_block_chain(nft_base_chain(chain), dev, + FLOW_BLOCK_UNBIND); + + mutex_unlock(&nft_net->commit_mutex); + + return NOTIFY_DONE; +} + +static struct notifier_block nft_offload_netdev_notifier = { + .notifier_call = nft_offload_netdev_event, +}; + +int nft_offload_init(void) +{ + return register_netdevice_notifier(&nft_offload_netdev_notifier); +} + +void nft_offload_exit(void) +{ + unregister_netdevice_notifier(&nft_offload_netdev_notifier); +} diff --git a/net/netfilter/nf_tables_trace.c b/net/netfilter/nf_tables_trace.c new file mode 100644 index 000000000..1163ba9c1 --- /dev/null +++ b/net/netfilter/nf_tables_trace.c @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2015 Red Hat GmbH + * Author: Florian Westphal + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define NFT_TRACETYPE_LL_HSIZE 20 +#define NFT_TRACETYPE_NETWORK_HSIZE 40 +#define NFT_TRACETYPE_TRANSPORT_HSIZE 20 + +DEFINE_STATIC_KEY_FALSE(nft_trace_enabled); +EXPORT_SYMBOL_GPL(nft_trace_enabled); + +static int trace_fill_header(struct sk_buff *nlskb, u16 type, + const struct sk_buff *skb, + int off, unsigned int len) +{ + struct nlattr *nla; + + if (len == 0) + return 0; + + nla = nla_reserve(nlskb, type, len); + if (!nla || skb_copy_bits(skb, off, nla_data(nla), len)) + return -1; + + return 0; +} + +static int nf_trace_fill_ll_header(struct sk_buff *nlskb, + const struct sk_buff *skb) +{ + struct vlan_ethhdr veth; + int off; + + BUILD_BUG_ON(sizeof(veth) > NFT_TRACETYPE_LL_HSIZE); + + off = skb_mac_header(skb) - skb->data; + if (off != -ETH_HLEN) + return -1; + + if (skb_copy_bits(skb, off, &veth, ETH_HLEN)) + return -1; + + veth.h_vlan_proto = skb->vlan_proto; + veth.h_vlan_TCI = htons(skb_vlan_tag_get(skb)); + veth.h_vlan_encapsulated_proto = skb->protocol; + + return nla_put(nlskb, NFTA_TRACE_LL_HEADER, sizeof(veth), &veth); +} + +static int nf_trace_fill_dev_info(struct sk_buff *nlskb, + const struct net_device *indev, + const struct net_device *outdev) +{ + if (indev) { + if (nla_put_be32(nlskb, NFTA_TRACE_IIF, + htonl(indev->ifindex))) + return -1; + + if (nla_put_be16(nlskb, NFTA_TRACE_IIFTYPE, + htons(indev->type))) + return -1; + } + + if (outdev) { + if (nla_put_be32(nlskb, NFTA_TRACE_OIF, + htonl(outdev->ifindex))) + return -1; + + if (nla_put_be16(nlskb, NFTA_TRACE_OIFTYPE, + htons(outdev->type))) + return -1; + } + + return 0; +} + +static int nf_trace_fill_pkt_info(struct sk_buff *nlskb, + const struct nft_pktinfo *pkt) +{ + const struct sk_buff *skb = pkt->skb; + int off = skb_network_offset(skb); + unsigned int len, nh_end; + + nh_end = pkt->flags & NFT_PKTINFO_L4PROTO ? 
nft_thoff(pkt) : skb->len; + len = min_t(unsigned int, nh_end - skb_network_offset(skb), + NFT_TRACETYPE_NETWORK_HSIZE); + if (trace_fill_header(nlskb, NFTA_TRACE_NETWORK_HEADER, skb, off, len)) + return -1; + + if (pkt->flags & NFT_PKTINFO_L4PROTO) { + len = min_t(unsigned int, skb->len - nft_thoff(pkt), + NFT_TRACETYPE_TRANSPORT_HSIZE); + if (trace_fill_header(nlskb, NFTA_TRACE_TRANSPORT_HEADER, skb, + nft_thoff(pkt), len)) + return -1; + } + + if (!skb_mac_header_was_set(skb)) + return 0; + + if (skb_vlan_tag_get(skb)) + return nf_trace_fill_ll_header(nlskb, skb); + + off = skb_mac_header(skb) - skb->data; + len = min_t(unsigned int, -off, NFT_TRACETYPE_LL_HSIZE); + return trace_fill_header(nlskb, NFTA_TRACE_LL_HEADER, + skb, off, len); +} + +static int nf_trace_fill_rule_info(struct sk_buff *nlskb, + const struct nft_traceinfo *info) +{ + if (!info->rule || info->rule->is_last) + return 0; + + /* a continue verdict with ->type == RETURN means that this is + * an implicit return (end of chain reached). + * + * Since no rule matched, the ->rule pointer is invalid. + */ + if (info->type == NFT_TRACETYPE_RETURN && + info->verdict->code == NFT_CONTINUE) + return 0; + + return nla_put_be64(nlskb, NFTA_TRACE_RULE_HANDLE, + cpu_to_be64(info->rule->handle), + NFTA_TRACE_PAD); +} + +static bool nft_trace_have_verdict_chain(struct nft_traceinfo *info) +{ + switch (info->type) { + case NFT_TRACETYPE_RETURN: + case NFT_TRACETYPE_RULE: + break; + default: + return false; + } + + switch (info->verdict->code) { + case NFT_JUMP: + case NFT_GOTO: + break; + default: + return false; + } + + return true; +} + +void nft_trace_notify(struct nft_traceinfo *info) +{ + const struct nft_pktinfo *pkt = info->pkt; + struct nlmsghdr *nlh; + struct sk_buff *skb; + unsigned int size; + u32 mark = 0; + u16 event; + + if (!nfnetlink_has_listeners(nft_net(pkt), NFNLGRP_NFTRACE)) + return; + + size = nlmsg_total_size(sizeof(struct nfgenmsg)) + + nla_total_size(strlen(info->chain->table->name)) + + nla_total_size(strlen(info->chain->name)) + + nla_total_size_64bit(sizeof(__be64)) + /* rule handle */ + nla_total_size(sizeof(__be32)) + /* trace type */ + nla_total_size(0) + /* VERDICT, nested */ + nla_total_size(sizeof(u32)) + /* verdict code */ + nla_total_size(sizeof(u32)) + /* id */ + nla_total_size(NFT_TRACETYPE_LL_HSIZE) + + nla_total_size(NFT_TRACETYPE_NETWORK_HSIZE) + + nla_total_size(NFT_TRACETYPE_TRANSPORT_HSIZE) + + nla_total_size(sizeof(u32)) + /* iif */ + nla_total_size(sizeof(__be16)) + /* iiftype */ + nla_total_size(sizeof(u32)) + /* oif */ + nla_total_size(sizeof(__be16)) + /* oiftype */ + nla_total_size(sizeof(u32)) + /* mark */ + nla_total_size(sizeof(u32)) + /* nfproto */ + nla_total_size(sizeof(u32)); /* policy */ + + if (nft_trace_have_verdict_chain(info)) + size += nla_total_size(strlen(info->verdict->chain->name)); /* jump target */ + + skb = nlmsg_new(size, GFP_ATOMIC); + if (!skb) + return; + + event = nfnl_msg_type(NFNL_SUBSYS_NFTABLES, NFT_MSG_TRACE); + nlh = nfnl_msg_put(skb, 0, 0, event, 0, info->basechain->type->family, + NFNETLINK_V0, 0); + if (!nlh) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_TRACE_NFPROTO, htonl(nft_pf(pkt)))) + goto nla_put_failure; + + if (nla_put_be32(skb, NFTA_TRACE_TYPE, htonl(info->type))) + goto nla_put_failure; + + if (nla_put_u32(skb, NFTA_TRACE_ID, info->skbid)) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_TRACE_CHAIN, info->chain->name)) + goto nla_put_failure; + + if (nla_put_string(skb, NFTA_TRACE_TABLE, info->chain->table->name)) + goto 
nla_put_failure; + + if (nf_trace_fill_rule_info(skb, info)) + goto nla_put_failure; + + switch (info->type) { + case NFT_TRACETYPE_UNSPEC: + case __NFT_TRACETYPE_MAX: + break; + case NFT_TRACETYPE_RETURN: + case NFT_TRACETYPE_RULE: + if (nft_verdict_dump(skb, NFTA_TRACE_VERDICT, info->verdict)) + goto nla_put_failure; + + /* pkt->skb undefined iff NF_STOLEN, disable dump */ + if (info->verdict->code == NF_STOLEN) + info->packet_dumped = true; + else + mark = pkt->skb->mark; + + break; + case NFT_TRACETYPE_POLICY: + mark = pkt->skb->mark; + + if (nla_put_be32(skb, NFTA_TRACE_POLICY, + htonl(info->basechain->policy))) + goto nla_put_failure; + break; + } + + if (mark && nla_put_be32(skb, NFTA_TRACE_MARK, htonl(mark))) + goto nla_put_failure; + + if (!info->packet_dumped) { + if (nf_trace_fill_dev_info(skb, nft_in(pkt), nft_out(pkt))) + goto nla_put_failure; + + if (nf_trace_fill_pkt_info(skb, pkt)) + goto nla_put_failure; + info->packet_dumped = true; + } + + nlmsg_end(skb, nlh); + nfnetlink_send(skb, nft_net(pkt), 0, NFNLGRP_NFTRACE, 0, GFP_ATOMIC); + return; + + nla_put_failure: + WARN_ON_ONCE(1); + kfree_skb(skb); +} + +void nft_trace_init(struct nft_traceinfo *info, const struct nft_pktinfo *pkt, + const struct nft_verdict *verdict, + const struct nft_chain *chain) +{ + static siphash_key_t trace_key __read_mostly; + struct sk_buff *skb = pkt->skb; + + info->basechain = nft_base_chain(chain); + info->trace = true; + info->nf_trace = pkt->skb->nf_trace; + info->packet_dumped = false; + info->pkt = pkt; + info->verdict = verdict; + + net_get_random_once(&trace_key, sizeof(trace_key)); + + info->skbid = (u32)siphash_3u32(hash32_ptr(skb), + skb_get_hash(skb), + skb->skb_iif, + &trace_key); +} diff --git a/net/netfilter/nfnetlink.c b/net/netfilter/nfnetlink.c new file mode 100644 index 000000000..c9fbe0f70 --- /dev/null +++ b/net/netfilter/nfnetlink.c @@ -0,0 +1,811 @@ +/* Netfilter messages via netlink socket. Allows for user space + * protocol helpers and general trouble making from userspace. + * + * (C) 2001 by Jay Schulist , + * (C) 2002-2005 by Harald Welte + * (C) 2005-2017 by Pablo Neira Ayuso + * + * Initial netfilter messages via netlink development funded and + * generally made possible by Network Robots, Inc. (www.networkrobots.com) + * + * Further development of this code funded by Astaro AG (http://www.astaro.com) + * + * This software may be used and distributed according to the terms + * of the GNU General Public License, incorporated herein by reference. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_NETFILTER); +MODULE_DESCRIPTION("Netfilter messages via netlink socket"); + +#define nfnl_dereference_protected(id) \ + rcu_dereference_protected(table[(id)].subsys, \ + lockdep_nfnl_is_held((id))) + +#define NFNL_MAX_ATTR_COUNT 32 + +static unsigned int nfnetlink_pernet_id __read_mostly; + +#ifdef CONFIG_NF_CONNTRACK_EVENTS +static DEFINE_SPINLOCK(nfnl_grp_active_lock); +#endif + +struct nfnl_net { + struct sock *nfnl; +}; + +static struct { + struct mutex mutex; + const struct nfnetlink_subsystem __rcu *subsys; +} table[NFNL_SUBSYS_COUNT]; + +static struct lock_class_key nfnl_lockdep_keys[NFNL_SUBSYS_COUNT]; + +static const char *const nfnl_lockdep_names[NFNL_SUBSYS_COUNT] = { + [NFNL_SUBSYS_NONE] = "nfnl_subsys_none", + [NFNL_SUBSYS_CTNETLINK] = "nfnl_subsys_ctnetlink", + [NFNL_SUBSYS_CTNETLINK_EXP] = "nfnl_subsys_ctnetlink_exp", + [NFNL_SUBSYS_QUEUE] = "nfnl_subsys_queue", + [NFNL_SUBSYS_ULOG] = "nfnl_subsys_ulog", + [NFNL_SUBSYS_OSF] = "nfnl_subsys_osf", + [NFNL_SUBSYS_IPSET] = "nfnl_subsys_ipset", + [NFNL_SUBSYS_ACCT] = "nfnl_subsys_acct", + [NFNL_SUBSYS_CTNETLINK_TIMEOUT] = "nfnl_subsys_cttimeout", + [NFNL_SUBSYS_CTHELPER] = "nfnl_subsys_cthelper", + [NFNL_SUBSYS_NFTABLES] = "nfnl_subsys_nftables", + [NFNL_SUBSYS_NFT_COMPAT] = "nfnl_subsys_nftcompat", + [NFNL_SUBSYS_HOOK] = "nfnl_subsys_hook", +}; + +static const int nfnl_group2type[NFNLGRP_MAX+1] = { + [NFNLGRP_CONNTRACK_NEW] = NFNL_SUBSYS_CTNETLINK, + [NFNLGRP_CONNTRACK_UPDATE] = NFNL_SUBSYS_CTNETLINK, + [NFNLGRP_CONNTRACK_DESTROY] = NFNL_SUBSYS_CTNETLINK, + [NFNLGRP_CONNTRACK_EXP_NEW] = NFNL_SUBSYS_CTNETLINK_EXP, + [NFNLGRP_CONNTRACK_EXP_UPDATE] = NFNL_SUBSYS_CTNETLINK_EXP, + [NFNLGRP_CONNTRACK_EXP_DESTROY] = NFNL_SUBSYS_CTNETLINK_EXP, + [NFNLGRP_NFTABLES] = NFNL_SUBSYS_NFTABLES, + [NFNLGRP_ACCT_QUOTA] = NFNL_SUBSYS_ACCT, + [NFNLGRP_NFTRACE] = NFNL_SUBSYS_NFTABLES, +}; + +static struct nfnl_net *nfnl_pernet(struct net *net) +{ + return net_generic(net, nfnetlink_pernet_id); +} + +void nfnl_lock(__u8 subsys_id) +{ + mutex_lock(&table[subsys_id].mutex); +} +EXPORT_SYMBOL_GPL(nfnl_lock); + +void nfnl_unlock(__u8 subsys_id) +{ + mutex_unlock(&table[subsys_id].mutex); +} +EXPORT_SYMBOL_GPL(nfnl_unlock); + +#ifdef CONFIG_PROVE_LOCKING +bool lockdep_nfnl_is_held(u8 subsys_id) +{ + return lockdep_is_held(&table[subsys_id].mutex); +} +EXPORT_SYMBOL_GPL(lockdep_nfnl_is_held); +#endif + +int nfnetlink_subsys_register(const struct nfnetlink_subsystem *n) +{ + u8 cb_id; + + /* Sanity-check attr_count size to avoid stack buffer overflow. 
*/ + for (cb_id = 0; cb_id < n->cb_count; cb_id++) + if (WARN_ON(n->cb[cb_id].attr_count > NFNL_MAX_ATTR_COUNT)) + return -EINVAL; + + nfnl_lock(n->subsys_id); + if (table[n->subsys_id].subsys) { + nfnl_unlock(n->subsys_id); + return -EBUSY; + } + rcu_assign_pointer(table[n->subsys_id].subsys, n); + nfnl_unlock(n->subsys_id); + + return 0; +} +EXPORT_SYMBOL_GPL(nfnetlink_subsys_register); + +int nfnetlink_subsys_unregister(const struct nfnetlink_subsystem *n) +{ + nfnl_lock(n->subsys_id); + table[n->subsys_id].subsys = NULL; + nfnl_unlock(n->subsys_id); + synchronize_rcu(); + return 0; +} +EXPORT_SYMBOL_GPL(nfnetlink_subsys_unregister); + +static inline const struct nfnetlink_subsystem *nfnetlink_get_subsys(u16 type) +{ + u8 subsys_id = NFNL_SUBSYS_ID(type); + + if (subsys_id >= NFNL_SUBSYS_COUNT) + return NULL; + + return rcu_dereference(table[subsys_id].subsys); +} + +static inline const struct nfnl_callback * +nfnetlink_find_client(u16 type, const struct nfnetlink_subsystem *ss) +{ + u8 cb_id = NFNL_MSG_TYPE(type); + + if (cb_id >= ss->cb_count) + return NULL; + + return &ss->cb[cb_id]; +} + +int nfnetlink_has_listeners(struct net *net, unsigned int group) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + + return netlink_has_listeners(nfnlnet->nfnl, group); +} +EXPORT_SYMBOL_GPL(nfnetlink_has_listeners); + +int nfnetlink_send(struct sk_buff *skb, struct net *net, u32 portid, + unsigned int group, int echo, gfp_t flags) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + + return nlmsg_notify(nfnlnet->nfnl, skb, portid, group, echo, flags); +} +EXPORT_SYMBOL_GPL(nfnetlink_send); + +int nfnetlink_set_err(struct net *net, u32 portid, u32 group, int error) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + + return netlink_set_err(nfnlnet->nfnl, portid, group, error); +} +EXPORT_SYMBOL_GPL(nfnetlink_set_err); + +int nfnetlink_unicast(struct sk_buff *skb, struct net *net, u32 portid) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + int err; + + err = nlmsg_unicast(nfnlnet->nfnl, skb, portid); + if (err == -EAGAIN) + err = -ENOBUFS; + + return err; +} +EXPORT_SYMBOL_GPL(nfnetlink_unicast); + +void nfnetlink_broadcast(struct net *net, struct sk_buff *skb, __u32 portid, + __u32 group, gfp_t allocation) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + + netlink_broadcast(nfnlnet->nfnl, skb, portid, group, allocation); +} +EXPORT_SYMBOL_GPL(nfnetlink_broadcast); + +/* Process one complete nfnetlink message. 
*/ +static int nfnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + const struct nfnl_callback *nc; + const struct nfnetlink_subsystem *ss; + int type, err; + + /* All the messages must at least contain nfgenmsg */ + if (nlmsg_len(nlh) < sizeof(struct nfgenmsg)) + return 0; + + type = nlh->nlmsg_type; +replay: + rcu_read_lock(); + + ss = nfnetlink_get_subsys(type); + if (!ss) { +#ifdef CONFIG_MODULES + rcu_read_unlock(); + request_module("nfnetlink-subsys-%d", NFNL_SUBSYS_ID(type)); + rcu_read_lock(); + ss = nfnetlink_get_subsys(type); + if (!ss) +#endif + { + rcu_read_unlock(); + return -EINVAL; + } + } + + nc = nfnetlink_find_client(type, ss); + if (!nc) { + rcu_read_unlock(); + return -EINVAL; + } + + { + int min_len = nlmsg_total_size(sizeof(struct nfgenmsg)); + struct nfnl_net *nfnlnet = nfnl_pernet(net); + u8 cb_id = NFNL_MSG_TYPE(nlh->nlmsg_type); + struct nlattr *cda[NFNL_MAX_ATTR_COUNT + 1]; + struct nlattr *attr = (void *)nlh + min_len; + int attrlen = nlh->nlmsg_len - min_len; + __u8 subsys_id = NFNL_SUBSYS_ID(type); + struct nfnl_info info = { + .net = net, + .sk = nfnlnet->nfnl, + .nlh = nlh, + .nfmsg = nlmsg_data(nlh), + .extack = extack, + }; + + /* Sanity-check NFNL_MAX_ATTR_COUNT */ + if (ss->cb[cb_id].attr_count > NFNL_MAX_ATTR_COUNT) { + rcu_read_unlock(); + return -ENOMEM; + } + + err = nla_parse_deprecated(cda, ss->cb[cb_id].attr_count, + attr, attrlen, + ss->cb[cb_id].policy, extack); + if (err < 0) { + rcu_read_unlock(); + return err; + } + + if (!nc->call) { + rcu_read_unlock(); + return -EINVAL; + } + + switch (nc->type) { + case NFNL_CB_RCU: + err = nc->call(skb, &info, (const struct nlattr **)cda); + rcu_read_unlock(); + break; + case NFNL_CB_MUTEX: + rcu_read_unlock(); + nfnl_lock(subsys_id); + if (nfnl_dereference_protected(subsys_id) != ss || + nfnetlink_find_client(type, ss) != nc) { + nfnl_unlock(subsys_id); + err = -EAGAIN; + break; + } + err = nc->call(skb, &info, (const struct nlattr **)cda); + nfnl_unlock(subsys_id); + break; + default: + rcu_read_unlock(); + err = -EINVAL; + break; + } + if (err == -EAGAIN) + goto replay; + return err; + } +} + +struct nfnl_err { + struct list_head head; + struct nlmsghdr *nlh; + int err; + struct netlink_ext_ack extack; +}; + +static int nfnl_err_add(struct list_head *list, struct nlmsghdr *nlh, int err, + const struct netlink_ext_ack *extack) +{ + struct nfnl_err *nfnl_err; + + nfnl_err = kmalloc(sizeof(struct nfnl_err), GFP_KERNEL); + if (nfnl_err == NULL) + return -ENOMEM; + + nfnl_err->nlh = nlh; + nfnl_err->err = err; + nfnl_err->extack = *extack; + list_add_tail(&nfnl_err->head, list); + + return 0; +} + +static void nfnl_err_del(struct nfnl_err *nfnl_err) +{ + list_del(&nfnl_err->head); + kfree(nfnl_err); +} + +static void nfnl_err_reset(struct list_head *err_list) +{ + struct nfnl_err *nfnl_err, *next; + + list_for_each_entry_safe(nfnl_err, next, err_list, head) + nfnl_err_del(nfnl_err); +} + +static void nfnl_err_deliver(struct list_head *err_list, struct sk_buff *skb) +{ + struct nfnl_err *nfnl_err, *next; + + list_for_each_entry_safe(nfnl_err, next, err_list, head) { + netlink_ack(skb, nfnl_err->nlh, nfnl_err->err, + &nfnl_err->extack); + nfnl_err_del(nfnl_err); + } +} + +enum { + NFNL_BATCH_FAILURE = (1 << 0), + NFNL_BATCH_DONE = (1 << 1), + NFNL_BATCH_REPLAY = (1 << 2), +}; + +static void nfnetlink_rcv_batch(struct sk_buff *skb, struct nlmsghdr *nlh, + u16 subsys_id, u32 genid) +{ + struct sk_buff *oskb = skb; + struct 
net *net = sock_net(skb->sk); + const struct nfnetlink_subsystem *ss; + const struct nfnl_callback *nc; + struct netlink_ext_ack extack; + LIST_HEAD(err_list); + u32 status; + int err; + + if (subsys_id >= NFNL_SUBSYS_COUNT) + return netlink_ack(skb, nlh, -EINVAL, NULL); +replay: + status = 0; +replay_abort: + skb = netlink_skb_clone(oskb, GFP_KERNEL); + if (!skb) + return netlink_ack(oskb, nlh, -ENOMEM, NULL); + + nfnl_lock(subsys_id); + ss = nfnl_dereference_protected(subsys_id); + if (!ss) { +#ifdef CONFIG_MODULES + nfnl_unlock(subsys_id); + request_module("nfnetlink-subsys-%d", subsys_id); + nfnl_lock(subsys_id); + ss = nfnl_dereference_protected(subsys_id); + if (!ss) +#endif + { + nfnl_unlock(subsys_id); + netlink_ack(oskb, nlh, -EOPNOTSUPP, NULL); + return kfree_skb(skb); + } + } + + if (!ss->valid_genid || !ss->commit || !ss->abort) { + nfnl_unlock(subsys_id); + netlink_ack(oskb, nlh, -EOPNOTSUPP, NULL); + return kfree_skb(skb); + } + + if (!try_module_get(ss->owner)) { + nfnl_unlock(subsys_id); + netlink_ack(oskb, nlh, -EOPNOTSUPP, NULL); + return kfree_skb(skb); + } + + if (!ss->valid_genid(net, genid)) { + module_put(ss->owner); + nfnl_unlock(subsys_id); + netlink_ack(oskb, nlh, -ERESTART, NULL); + return kfree_skb(skb); + } + + nfnl_unlock(subsys_id); + + while (skb->len >= nlmsg_total_size(0)) { + int msglen, type; + + if (fatal_signal_pending(current)) { + nfnl_err_reset(&err_list); + err = -EINTR; + status = NFNL_BATCH_FAILURE; + goto done; + } + + memset(&extack, 0, sizeof(extack)); + nlh = nlmsg_hdr(skb); + err = 0; + + if (nlh->nlmsg_len < NLMSG_HDRLEN || + skb->len < nlh->nlmsg_len || + nlmsg_len(nlh) < sizeof(struct nfgenmsg)) { + nfnl_err_reset(&err_list); + status |= NFNL_BATCH_FAILURE; + goto done; + } + + /* Only requests are handled by the kernel */ + if (!(nlh->nlmsg_flags & NLM_F_REQUEST)) { + err = -EINVAL; + goto ack; + } + + type = nlh->nlmsg_type; + if (type == NFNL_MSG_BATCH_BEGIN) { + /* Malformed: Batch begin twice */ + nfnl_err_reset(&err_list); + status |= NFNL_BATCH_FAILURE; + goto done; + } else if (type == NFNL_MSG_BATCH_END) { + status |= NFNL_BATCH_DONE; + goto done; + } else if (type < NLMSG_MIN_TYPE) { + err = -EINVAL; + goto ack; + } + + /* We only accept a batch with messages for the same + * subsystem. + */ + if (NFNL_SUBSYS_ID(type) != subsys_id) { + err = -EINVAL; + goto ack; + } + + nc = nfnetlink_find_client(type, ss); + if (!nc) { + err = -EINVAL; + goto ack; + } + + if (nc->type != NFNL_CB_BATCH) { + err = -EINVAL; + goto ack; + } + + { + int min_len = nlmsg_total_size(sizeof(struct nfgenmsg)); + struct nfnl_net *nfnlnet = nfnl_pernet(net); + struct nlattr *cda[NFNL_MAX_ATTR_COUNT + 1]; + struct nlattr *attr = (void *)nlh + min_len; + u8 cb_id = NFNL_MSG_TYPE(nlh->nlmsg_type); + int attrlen = nlh->nlmsg_len - min_len; + struct nfnl_info info = { + .net = net, + .sk = nfnlnet->nfnl, + .nlh = nlh, + .nfmsg = nlmsg_data(nlh), + .extack = &extack, + }; + + /* Sanity-check NFTA_MAX_ATTR */ + if (ss->cb[cb_id].attr_count > NFNL_MAX_ATTR_COUNT) { + err = -ENOMEM; + goto ack; + } + + err = nla_parse_deprecated(cda, + ss->cb[cb_id].attr_count, + attr, attrlen, + ss->cb[cb_id].policy, NULL); + if (err < 0) + goto ack; + + err = nc->call(skb, &info, (const struct nlattr **)cda); + + /* The lock was released to autoload some module, we + * have to abort and start from scratch using the + * original skb. 
+ */ + if (err == -EAGAIN) { + status |= NFNL_BATCH_REPLAY; + goto done; + } + } +ack: + if (nlh->nlmsg_flags & NLM_F_ACK || err) { + /* Errors are delivered once the full batch has been + * processed, this avoids that the same error is + * reported several times when replaying the batch. + */ + if (err == -ENOMEM || + nfnl_err_add(&err_list, nlh, err, &extack) < 0) { + /* We failed to enqueue an error, reset the + * list of errors and send OOM to userspace + * pointing to the batch header. + */ + nfnl_err_reset(&err_list); + netlink_ack(oskb, nlmsg_hdr(oskb), -ENOMEM, + NULL); + status |= NFNL_BATCH_FAILURE; + goto done; + } + /* We don't stop processing the batch on errors, thus, + * userspace gets all the errors that the batch + * triggers. + */ + if (err) + status |= NFNL_BATCH_FAILURE; + } + + msglen = NLMSG_ALIGN(nlh->nlmsg_len); + if (msglen > skb->len) + msglen = skb->len; + skb_pull(skb, msglen); + } +done: + if (status & NFNL_BATCH_REPLAY) { + ss->abort(net, oskb, NFNL_ABORT_AUTOLOAD); + nfnl_err_reset(&err_list); + kfree_skb(skb); + module_put(ss->owner); + goto replay; + } else if (status == NFNL_BATCH_DONE) { + err = ss->commit(net, oskb); + if (err == -EAGAIN) { + status |= NFNL_BATCH_REPLAY; + goto done; + } else if (err) { + ss->abort(net, oskb, NFNL_ABORT_NONE); + netlink_ack(oskb, nlmsg_hdr(oskb), err, NULL); + } + } else { + enum nfnl_abort_action abort_action; + + if (status & NFNL_BATCH_FAILURE) + abort_action = NFNL_ABORT_NONE; + else + abort_action = NFNL_ABORT_VALIDATE; + + err = ss->abort(net, oskb, abort_action); + if (err == -EAGAIN) { + nfnl_err_reset(&err_list); + kfree_skb(skb); + module_put(ss->owner); + status |= NFNL_BATCH_FAILURE; + goto replay_abort; + } + } + + nfnl_err_deliver(&err_list, oskb); + kfree_skb(skb); + module_put(ss->owner); +} + +static const struct nla_policy nfnl_batch_policy[NFNL_BATCH_MAX + 1] = { + [NFNL_BATCH_GENID] = { .type = NLA_U32 }, +}; + +static void nfnetlink_rcv_skb_batch(struct sk_buff *skb, struct nlmsghdr *nlh) +{ + int min_len = nlmsg_total_size(sizeof(struct nfgenmsg)); + struct nlattr *attr = (void *)nlh + min_len; + struct nlattr *cda[NFNL_BATCH_MAX + 1]; + int attrlen = nlh->nlmsg_len - min_len; + struct nfgenmsg *nfgenmsg; + int msglen, err; + u32 gen_id = 0; + u16 res_id; + + msglen = NLMSG_ALIGN(nlh->nlmsg_len); + if (msglen > skb->len) + msglen = skb->len; + + if (skb->len < NLMSG_HDRLEN + sizeof(struct nfgenmsg)) + return; + + err = nla_parse_deprecated(cda, NFNL_BATCH_MAX, attr, attrlen, + nfnl_batch_policy, NULL); + if (err < 0) { + netlink_ack(skb, nlh, err, NULL); + return; + } + if (cda[NFNL_BATCH_GENID]) + gen_id = ntohl(nla_get_be32(cda[NFNL_BATCH_GENID])); + + nfgenmsg = nlmsg_data(nlh); + skb_pull(skb, msglen); + /* Work around old nft using host byte order */ + if (nfgenmsg->res_id == (__force __be16)NFNL_SUBSYS_NFTABLES) + res_id = NFNL_SUBSYS_NFTABLES; + else + res_id = ntohs(nfgenmsg->res_id); + + nfnetlink_rcv_batch(skb, nlh, res_id, gen_id); +} + +static void nfnetlink_rcv(struct sk_buff *skb) +{ + struct nlmsghdr *nlh = nlmsg_hdr(skb); + + if (skb->len < NLMSG_HDRLEN || + nlh->nlmsg_len < NLMSG_HDRLEN || + skb->len < nlh->nlmsg_len) + return; + + if (!netlink_net_capable(skb, CAP_NET_ADMIN)) { + netlink_ack(skb, nlh, -EPERM, NULL); + return; + } + + if (nlh->nlmsg_type == NFNL_MSG_BATCH_BEGIN) + nfnetlink_rcv_skb_batch(skb, nlh); + else + netlink_rcv_skb(skb, nfnetlink_rcv_msg); +} + +static void nfnetlink_bind_event(struct net *net, unsigned int group) +{ +#ifdef CONFIG_NF_CONNTRACK_EVENTS + int 
type, group_bit; + u8 v; + + /* All NFNLGRP_CONNTRACK_* group bits fit into u8. + * The other groups are not relevant and can be ignored. + */ + if (group >= 8) + return; + + type = nfnl_group2type[group]; + + switch (type) { + case NFNL_SUBSYS_CTNETLINK: + break; + case NFNL_SUBSYS_CTNETLINK_EXP: + break; + default: + return; + } + + group_bit = (1 << group); + + spin_lock(&nfnl_grp_active_lock); + v = READ_ONCE(nf_ctnetlink_has_listener); + if ((v & group_bit) == 0) { + v |= group_bit; + + /* read concurrently without nfnl_grp_active_lock held. */ + WRITE_ONCE(nf_ctnetlink_has_listener, v); + } + + spin_unlock(&nfnl_grp_active_lock); +#endif +} + +static int nfnetlink_bind(struct net *net, int group) +{ + const struct nfnetlink_subsystem *ss; + int type; + + if (group <= NFNLGRP_NONE || group > NFNLGRP_MAX) + return 0; + + type = nfnl_group2type[group]; + + rcu_read_lock(); + ss = nfnetlink_get_subsys(type << 8); + rcu_read_unlock(); + if (!ss) + request_module_nowait("nfnetlink-subsys-%d", type); + + nfnetlink_bind_event(net, group); + return 0; +} + +static void nfnetlink_unbind(struct net *net, int group) +{ +#ifdef CONFIG_NF_CONNTRACK_EVENTS + int type, group_bit; + + if (group <= NFNLGRP_NONE || group > NFNLGRP_MAX) + return; + + type = nfnl_group2type[group]; + + switch (type) { + case NFNL_SUBSYS_CTNETLINK: + break; + case NFNL_SUBSYS_CTNETLINK_EXP: + break; + default: + return; + } + + /* ctnetlink_has_listener is u8 */ + if (group >= 8) + return; + + group_bit = (1 << group); + + spin_lock(&nfnl_grp_active_lock); + if (!nfnetlink_has_listeners(net, group)) { + u8 v = READ_ONCE(nf_ctnetlink_has_listener); + + v &= ~group_bit; + + /* read concurrently without nfnl_grp_active_lock held. */ + WRITE_ONCE(nf_ctnetlink_has_listener, v); + } + spin_unlock(&nfnl_grp_active_lock); +#endif +} + +static int __net_init nfnetlink_net_init(struct net *net) +{ + struct nfnl_net *nfnlnet = nfnl_pernet(net); + struct netlink_kernel_cfg cfg = { + .groups = NFNLGRP_MAX, + .input = nfnetlink_rcv, + .bind = nfnetlink_bind, + .unbind = nfnetlink_unbind, + }; + + nfnlnet->nfnl = netlink_kernel_create(net, NETLINK_NETFILTER, &cfg); + if (!nfnlnet->nfnl) + return -ENOMEM; + return 0; +} + +static void __net_exit nfnetlink_net_exit_batch(struct list_head *net_exit_list) +{ + struct nfnl_net *nfnlnet; + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) { + nfnlnet = nfnl_pernet(net); + + netlink_kernel_release(nfnlnet->nfnl); + } +} + +static struct pernet_operations nfnetlink_net_ops = { + .init = nfnetlink_net_init, + .exit_batch = nfnetlink_net_exit_batch, + .id = &nfnetlink_pernet_id, + .size = sizeof(struct nfnl_net), +}; + +static int __init nfnetlink_init(void) +{ + int i; + + for (i = NFNLGRP_NONE + 1; i <= NFNLGRP_MAX; i++) + BUG_ON(nfnl_group2type[i] == NFNL_SUBSYS_NONE); + + for (i=0; i + * (C) 2011 Intra2net AG + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("nfacct: Extended Netfilter accounting infrastructure"); + +struct nf_acct { + atomic64_t pkts; + atomic64_t bytes; + unsigned long flags; + struct list_head head; + refcount_t refcnt; + char name[NFACCT_NAME_MAX]; + struct rcu_head rcu_head; + char data[]; +}; + +struct nfacct_filter { + u32 value; + u32 mask; +}; + +struct nfnl_acct_net { + struct list_head nfnl_acct_list; +}; + +static unsigned int 
nfnl_acct_net_id __read_mostly; + +static inline struct nfnl_acct_net *nfnl_acct_pernet(struct net *net) +{ + return net_generic(net, nfnl_acct_net_id); +} + +#define NFACCT_F_QUOTA (NFACCT_F_QUOTA_PKTS | NFACCT_F_QUOTA_BYTES) +#define NFACCT_OVERQUOTA_BIT 2 /* NFACCT_F_OVERQUOTA */ + +static int nfnl_acct_new(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(info->net); + struct nf_acct *nfacct, *matching = NULL; + unsigned int size = 0; + char *acct_name; + u32 flags = 0; + + if (!tb[NFACCT_NAME]) + return -EINVAL; + + acct_name = nla_data(tb[NFACCT_NAME]); + if (strlen(acct_name) == 0) + return -EINVAL; + + list_for_each_entry(nfacct, &nfnl_acct_net->nfnl_acct_list, head) { + if (strncmp(nfacct->name, acct_name, NFACCT_NAME_MAX) != 0) + continue; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) + return -EEXIST; + + matching = nfacct; + break; + } + + if (matching) { + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) { + /* reset counters if you request a replacement. */ + atomic64_set(&matching->pkts, 0); + atomic64_set(&matching->bytes, 0); + smp_mb__before_atomic(); + /* reset overquota flag if quota is enabled. */ + if ((matching->flags & NFACCT_F_QUOTA)) + clear_bit(NFACCT_OVERQUOTA_BIT, + &matching->flags); + return 0; + } + return -EBUSY; + } + + if (tb[NFACCT_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFACCT_FLAGS])); + if (flags & ~NFACCT_F_QUOTA) + return -EOPNOTSUPP; + if ((flags & NFACCT_F_QUOTA) == NFACCT_F_QUOTA) + return -EINVAL; + if (flags & NFACCT_F_OVERQUOTA) + return -EINVAL; + if ((flags & NFACCT_F_QUOTA) && !tb[NFACCT_QUOTA]) + return -EINVAL; + + size += sizeof(u64); + } + + nfacct = kzalloc(sizeof(struct nf_acct) + size, GFP_KERNEL); + if (nfacct == NULL) + return -ENOMEM; + + if (flags & NFACCT_F_QUOTA) { + u64 *quota = (u64 *)nfacct->data; + + *quota = be64_to_cpu(nla_get_be64(tb[NFACCT_QUOTA])); + nfacct->flags = flags; + } + + nla_strscpy(nfacct->name, tb[NFACCT_NAME], NFACCT_NAME_MAX); + + if (tb[NFACCT_BYTES]) { + atomic64_set(&nfacct->bytes, + be64_to_cpu(nla_get_be64(tb[NFACCT_BYTES]))); + } + if (tb[NFACCT_PKTS]) { + atomic64_set(&nfacct->pkts, + be64_to_cpu(nla_get_be64(tb[NFACCT_PKTS]))); + } + refcount_set(&nfacct->refcnt, 1); + list_add_tail_rcu(&nfacct->head, &nfnl_acct_net->nfnl_acct_list); + return 0; +} + +static int +nfnl_acct_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + int event, struct nf_acct *acct) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? 
NLM_F_MULTI : 0; + u64 pkts, bytes; + u32 old_flags; + + event = nfnl_msg_type(NFNL_SUBSYS_ACCT, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_string(skb, NFACCT_NAME, acct->name)) + goto nla_put_failure; + + old_flags = acct->flags; + if (type == NFNL_MSG_ACCT_GET_CTRZERO) { + pkts = atomic64_xchg(&acct->pkts, 0); + bytes = atomic64_xchg(&acct->bytes, 0); + smp_mb__before_atomic(); + if (acct->flags & NFACCT_F_QUOTA) + clear_bit(NFACCT_OVERQUOTA_BIT, &acct->flags); + } else { + pkts = atomic64_read(&acct->pkts); + bytes = atomic64_read(&acct->bytes); + } + if (nla_put_be64(skb, NFACCT_PKTS, cpu_to_be64(pkts), + NFACCT_PAD) || + nla_put_be64(skb, NFACCT_BYTES, cpu_to_be64(bytes), + NFACCT_PAD) || + nla_put_be32(skb, NFACCT_USE, htonl(refcount_read(&acct->refcnt)))) + goto nla_put_failure; + if (acct->flags & NFACCT_F_QUOTA) { + u64 *quota = (u64 *)acct->data; + + if (nla_put_be32(skb, NFACCT_FLAGS, htonl(old_flags)) || + nla_put_be64(skb, NFACCT_QUOTA, cpu_to_be64(*quota), + NFACCT_PAD)) + goto nla_put_failure; + } + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int +nfnl_acct_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(net); + struct nf_acct *cur, *last; + const struct nfacct_filter *filter = cb->data; + + if (cb->args[2]) + return 0; + + last = (struct nf_acct *)cb->args[1]; + if (cb->args[1]) + cb->args[1] = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(cur, &nfnl_acct_net->nfnl_acct_list, head) { + if (last) { + if (cur != last) + continue; + + last = NULL; + } + + if (filter && (cur->flags & filter->mask) != filter->value) + continue; + + if (nfnl_acct_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFNL_MSG_TYPE(cb->nlh->nlmsg_type), + NFNL_MSG_ACCT_NEW, cur) < 0) { + cb->args[1] = (unsigned long)cur; + break; + } + } + if (!cb->args[1]) + cb->args[2] = 1; + rcu_read_unlock(); + return skb->len; +} + +static int nfnl_acct_done(struct netlink_callback *cb) +{ + kfree(cb->data); + return 0; +} + +static const struct nla_policy filter_policy[NFACCT_FILTER_MAX + 1] = { + [NFACCT_FILTER_MASK] = { .type = NLA_U32 }, + [NFACCT_FILTER_VALUE] = { .type = NLA_U32 }, +}; + +static int nfnl_acct_start(struct netlink_callback *cb) +{ + const struct nlattr *const attr = cb->data; + struct nlattr *tb[NFACCT_FILTER_MAX + 1]; + struct nfacct_filter *filter; + int err; + + if (!attr) + return 0; + + err = nla_parse_nested_deprecated(tb, NFACCT_FILTER_MAX, attr, + filter_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFACCT_FILTER_MASK] || !tb[NFACCT_FILTER_VALUE]) + return -EINVAL; + + filter = kzalloc(sizeof(struct nfacct_filter), GFP_KERNEL); + if (!filter) + return -ENOMEM; + + filter->mask = ntohl(nla_get_be32(tb[NFACCT_FILTER_MASK])); + filter->value = ntohl(nla_get_be32(tb[NFACCT_FILTER_VALUE])); + cb->data = filter; + + return 0; +} + +static int nfnl_acct_get(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(info->net); + int ret = -ENOENT; + struct nf_acct *cur; + char *acct_name; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = nfnl_acct_dump, + .start = nfnl_acct_start, + .done = nfnl_acct_done, + .data = (void *)tb[NFACCT_FILTER], + }; + 
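/* Editor's sketch (illustrative only, not part of the upstream file): the
 * optional NFACCT_FILTER nest parsed in nfnl_acct_start() above narrows a
 * dump to entries whose flags satisfy "(cur->flags & mask) == value" in
 * nfnl_acct_dump(). For example, assuming the NFACCT_F_QUOTA_BYTES flag
 * from the uapi header, a dump restricted to byte-quota counters could use:
 *
 *      filter->mask  = NFACCT_F_QUOTA_BYTES;
 *      filter->value = NFACCT_F_QUOTA_BYTES;
 *
 * Omitting NFACCT_FILTER (or using mask = value = 0) matches every entry.
 */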
+ return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + if (!tb[NFACCT_NAME]) + return -EINVAL; + acct_name = nla_data(tb[NFACCT_NAME]); + + list_for_each_entry(cur, &nfnl_acct_net->nfnl_acct_list, head) { + struct sk_buff *skb2; + + if (strncmp(cur->name, acct_name, NFACCT_NAME_MAX)!= 0) + continue; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb2 == NULL) { + ret = -ENOMEM; + break; + } + + ret = nfnl_acct_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + NFNL_MSG_ACCT_NEW, cur); + if (ret <= 0) { + kfree_skb(skb2); + break; + } + + ret = nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + break; + } + + return ret; +} + +/* try to delete object, fail if it is still in use. */ +static int nfnl_acct_try_del(struct nf_acct *cur) +{ + int ret = 0; + + /* We want to avoid races with nfnl_acct_put. So only when the current + * refcnt is 1, we decrease it to 0. + */ + if (refcount_dec_if_one(&cur->refcnt)) { + /* We are protected by nfnl mutex. */ + list_del_rcu(&cur->head); + kfree_rcu(cur, rcu_head); + } else { + ret = -EBUSY; + } + return ret; +} + +static int nfnl_acct_del(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(info->net); + struct nf_acct *cur, *tmp; + int ret = -ENOENT; + char *acct_name; + + if (!tb[NFACCT_NAME]) { + list_for_each_entry_safe(cur, tmp, &nfnl_acct_net->nfnl_acct_list, head) + nfnl_acct_try_del(cur); + + return 0; + } + acct_name = nla_data(tb[NFACCT_NAME]); + + list_for_each_entry(cur, &nfnl_acct_net->nfnl_acct_list, head) { + if (strncmp(cur->name, acct_name, NFACCT_NAME_MAX) != 0) + continue; + + ret = nfnl_acct_try_del(cur); + if (ret < 0) + return ret; + + break; + } + return ret; +} + +static const struct nla_policy nfnl_acct_policy[NFACCT_MAX+1] = { + [NFACCT_NAME] = { .type = NLA_NUL_STRING, .len = NFACCT_NAME_MAX-1 }, + [NFACCT_BYTES] = { .type = NLA_U64 }, + [NFACCT_PKTS] = { .type = NLA_U64 }, + [NFACCT_FLAGS] = { .type = NLA_U32 }, + [NFACCT_QUOTA] = { .type = NLA_U64 }, + [NFACCT_FILTER] = {.type = NLA_NESTED }, +}; + +static const struct nfnl_callback nfnl_acct_cb[NFNL_MSG_ACCT_MAX] = { + [NFNL_MSG_ACCT_NEW] = { + .call = nfnl_acct_new, + .type = NFNL_CB_MUTEX, + .attr_count = NFACCT_MAX, + .policy = nfnl_acct_policy + }, + [NFNL_MSG_ACCT_GET] = { + .call = nfnl_acct_get, + .type = NFNL_CB_MUTEX, + .attr_count = NFACCT_MAX, + .policy = nfnl_acct_policy + }, + [NFNL_MSG_ACCT_GET_CTRZERO] = { + .call = nfnl_acct_get, + .type = NFNL_CB_MUTEX, + .attr_count = NFACCT_MAX, + .policy = nfnl_acct_policy + }, + [NFNL_MSG_ACCT_DEL] = { + .call = nfnl_acct_del, + .type = NFNL_CB_MUTEX, + .attr_count = NFACCT_MAX, + .policy = nfnl_acct_policy + }, +}; + +static const struct nfnetlink_subsystem nfnl_acct_subsys = { + .name = "acct", + .subsys_id = NFNL_SUBSYS_ACCT, + .cb_count = NFNL_MSG_ACCT_MAX, + .cb = nfnl_acct_cb, +}; + +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_ACCT); + +struct nf_acct *nfnl_acct_find_get(struct net *net, const char *acct_name) +{ + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(net); + struct nf_acct *cur, *acct = NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(cur, &nfnl_acct_net->nfnl_acct_list, head) { + if (strncmp(cur->name, acct_name, NFACCT_NAME_MAX)!= 0) + continue; + + if (!try_module_get(THIS_MODULE)) + goto err; + + if (!refcount_inc_not_zero(&cur->refcnt)) { + module_put(THIS_MODULE); + goto err; + } + + acct = cur; + break; + } 
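/* Editor's sketch (illustrative only, not part of the upstream file): a
 * typical in-kernel consumer of the accounting API exported here -- for
 * example a packet match such as xt_nfacct -- would pair the lookup with
 * per-packet updates roughly as below. The object name and the
 * NFACCT_OVERQUOTA comparison are assumptions shown only for illustration;
 * the verdict taken on an over-quota counter is entirely up to the caller.
 *
 *      struct nf_acct *acct = nfnl_acct_find_get(net, "web-quota");
 *
 *      if (acct) {
 *              nfnl_acct_update(skb, acct);    // bump pkts and bytes
 *              if (nfnl_acct_overquota(net, acct) == NFACCT_OVERQUOTA)
 *                      ;                       // e.g. drop or log; first crossing
 *                                              // also broadcasts NFNLGRP_ACCT_QUOTA
 *              nfnl_acct_put(acct);            // drop the reference taken above
 *      }
 */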
+err: + rcu_read_unlock(); + return acct; +} +EXPORT_SYMBOL_GPL(nfnl_acct_find_get); + +void nfnl_acct_put(struct nf_acct *acct) +{ + if (refcount_dec_and_test(&acct->refcnt)) + kfree_rcu(acct, rcu_head); + + module_put(THIS_MODULE); +} +EXPORT_SYMBOL_GPL(nfnl_acct_put); + +void nfnl_acct_update(const struct sk_buff *skb, struct nf_acct *nfacct) +{ + atomic64_inc(&nfacct->pkts); + atomic64_add(skb->len, &nfacct->bytes); +} +EXPORT_SYMBOL_GPL(nfnl_acct_update); + +static void nfnl_overquota_report(struct net *net, struct nf_acct *nfacct) +{ + int ret; + struct sk_buff *skb; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (skb == NULL) + return; + + ret = nfnl_acct_fill_info(skb, 0, 0, NFNL_MSG_ACCT_OVERQUOTA, 0, + nfacct); + if (ret <= 0) { + kfree_skb(skb); + return; + } + nfnetlink_broadcast(net, skb, 0, NFNLGRP_ACCT_QUOTA, GFP_ATOMIC); +} + +int nfnl_acct_overquota(struct net *net, struct nf_acct *nfacct) +{ + u64 now; + u64 *quota; + int ret = NFACCT_UNDERQUOTA; + + /* no place here if we don't have a quota */ + if (!(nfacct->flags & NFACCT_F_QUOTA)) + return NFACCT_NO_QUOTA; + + quota = (u64 *)nfacct->data; + now = (nfacct->flags & NFACCT_F_QUOTA_PKTS) ? + atomic64_read(&nfacct->pkts) : atomic64_read(&nfacct->bytes); + + ret = now > *quota; + + if (now >= *quota && + !test_and_set_bit(NFACCT_OVERQUOTA_BIT, &nfacct->flags)) { + nfnl_overquota_report(net, nfacct); + } + + return ret; +} +EXPORT_SYMBOL_GPL(nfnl_acct_overquota); + +static int __net_init nfnl_acct_net_init(struct net *net) +{ + INIT_LIST_HEAD(&nfnl_acct_pernet(net)->nfnl_acct_list); + + return 0; +} + +static void __net_exit nfnl_acct_net_exit(struct net *net) +{ + struct nfnl_acct_net *nfnl_acct_net = nfnl_acct_pernet(net); + struct nf_acct *cur, *tmp; + + list_for_each_entry_safe(cur, tmp, &nfnl_acct_net->nfnl_acct_list, head) { + list_del_rcu(&cur->head); + + if (refcount_dec_and_test(&cur->refcnt)) + kfree_rcu(cur, rcu_head); + } +} + +static struct pernet_operations nfnl_acct_ops = { + .init = nfnl_acct_net_init, + .exit = nfnl_acct_net_exit, + .id = &nfnl_acct_net_id, + .size = sizeof(struct nfnl_acct_net), +}; + +static int __init nfnl_acct_init(void) +{ + int ret; + + ret = register_pernet_subsys(&nfnl_acct_ops); + if (ret < 0) { + pr_err("nfnl_acct_init: failed to register pernet ops\n"); + goto err_out; + } + + ret = nfnetlink_subsys_register(&nfnl_acct_subsys); + if (ret < 0) { + pr_err("nfnl_acct_init: cannot register with nfnetlink.\n"); + goto cleanup_pernet; + } + return 0; + +cleanup_pernet: + unregister_pernet_subsys(&nfnl_acct_ops); +err_out: + return ret; +} + +static void __exit nfnl_acct_exit(void) +{ + nfnetlink_subsys_unregister(&nfnl_acct_subsys); + unregister_pernet_subsys(&nfnl_acct_ops); +} + +module_init(nfnl_acct_init); +module_exit(nfnl_acct_exit); diff --git a/net/netfilter/nfnetlink_cthelper.c b/net/netfilter/nfnetlink_cthelper.c new file mode 100644 index 000000000..97248963a --- /dev/null +++ b/net/netfilter/nfnetlink_cthelper.c @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2012 Pablo Neira Ayuso + * + * This software has been sponsored by Vyatta Inc. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("nfnl_cthelper: User-space connection tracking helpers"); + +struct nfnl_cthelper { + struct list_head list; + struct nf_conntrack_helper helper; +}; + +static LIST_HEAD(nfnl_cthelper_list); + +static int +nfnl_userspace_cthelper(struct sk_buff *skb, unsigned int protoff, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + const struct nf_conn_help *help; + struct nf_conntrack_helper *helper; + + help = nfct_help(ct); + if (help == NULL) + return NF_DROP; + + /* rcu_read_lock()ed by nf_hook_thresh */ + helper = rcu_dereference(help->helper); + if (helper == NULL) + return NF_DROP; + + /* This is a user-space helper not yet configured, skip. */ + if ((helper->flags & + (NF_CT_HELPER_F_USERSPACE | NF_CT_HELPER_F_CONFIGURED)) == + NF_CT_HELPER_F_USERSPACE) + return NF_ACCEPT; + + /* If the user-space helper is not available, don't block traffic. */ + return NF_QUEUE_NR(helper->queue_num) | NF_VERDICT_FLAG_QUEUE_BYPASS; +} + +static const struct nla_policy nfnl_cthelper_tuple_pol[NFCTH_TUPLE_MAX+1] = { + [NFCTH_TUPLE_L3PROTONUM] = { .type = NLA_U16, }, + [NFCTH_TUPLE_L4PROTONUM] = { .type = NLA_U8, }, +}; + +static int +nfnl_cthelper_parse_tuple(struct nf_conntrack_tuple *tuple, + const struct nlattr *attr) +{ + int err; + struct nlattr *tb[NFCTH_TUPLE_MAX+1]; + + err = nla_parse_nested_deprecated(tb, NFCTH_TUPLE_MAX, attr, + nfnl_cthelper_tuple_pol, NULL); + if (err < 0) + return err; + + if (!tb[NFCTH_TUPLE_L3PROTONUM] || !tb[NFCTH_TUPLE_L4PROTONUM]) + return -EINVAL; + + /* Not all fields are initialized so first zero the tuple */ + memset(tuple, 0, sizeof(struct nf_conntrack_tuple)); + + tuple->src.l3num = ntohs(nla_get_be16(tb[NFCTH_TUPLE_L3PROTONUM])); + tuple->dst.protonum = nla_get_u8(tb[NFCTH_TUPLE_L4PROTONUM]); + + return 0; +} + +static int +nfnl_cthelper_from_nlattr(struct nlattr *attr, struct nf_conn *ct) +{ + struct nf_conn_help *help = nfct_help(ct); + const struct nf_conntrack_helper *helper; + + if (attr == NULL) + return -EINVAL; + + helper = rcu_dereference(help->helper); + if (!helper || helper->data_len == 0) + return -EINVAL; + + nla_memcpy(help->data, attr, sizeof(help->data)); + return 0; +} + +static int +nfnl_cthelper_to_nlattr(struct sk_buff *skb, const struct nf_conn *ct) +{ + const struct nf_conn_help *help = nfct_help(ct); + const struct nf_conntrack_helper *helper; + + helper = rcu_dereference(help->helper); + if (helper && helper->data_len && + nla_put(skb, CTA_HELP_INFO, helper->data_len, &help->data)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -ENOSPC; +} + +static const struct nla_policy nfnl_cthelper_expect_pol[NFCTH_POLICY_MAX+1] = { + [NFCTH_POLICY_NAME] = { .type = NLA_NUL_STRING, + .len = NF_CT_HELPER_NAME_LEN-1 }, + [NFCTH_POLICY_EXPECT_MAX] = { .type = NLA_U32, }, + [NFCTH_POLICY_EXPECT_TIMEOUT] = { .type = NLA_U32, }, +}; + +static int +nfnl_cthelper_expect_policy(struct nf_conntrack_expect_policy *expect_policy, + const struct nlattr *attr) +{ + int err; + struct nlattr *tb[NFCTH_POLICY_MAX+1]; + + err = nla_parse_nested_deprecated(tb, NFCTH_POLICY_MAX, attr, + nfnl_cthelper_expect_pol, NULL); + if (err < 0) + return err; + + if (!tb[NFCTH_POLICY_NAME] || + !tb[NFCTH_POLICY_EXPECT_MAX] || + !tb[NFCTH_POLICY_EXPECT_TIMEOUT]) + return -EINVAL; + 
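/* Editor's note (illustrative only, not part of the upstream file): the
 * attribute tree handled here and in nfnl_cthelper_parse_expect_policy()
 * below is, roughly:
 *
 *      NFCTH_POLICY
 *        NFCTH_POLICY_SET_NUM              number of classes, 1..NF_CT_MAX_EXPECT_CLASSES
 *        NFCTH_POLICY_SET                  first expectation class
 *          NFCTH_POLICY_NAME
 *          NFCTH_POLICY_EXPECT_MAX
 *          NFCTH_POLICY_EXPECT_TIMEOUT
 *        NFCTH_POLICY_SET + 1              second class, and so on
 *
 * This function validates a single per-class nest; the caller allocates the
 * nf_conntrack_expect_policy array and stores expect_class_max as the class
 * count minus one.
 */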
+ nla_strscpy(expect_policy->name, + tb[NFCTH_POLICY_NAME], NF_CT_HELPER_NAME_LEN); + expect_policy->max_expected = + ntohl(nla_get_be32(tb[NFCTH_POLICY_EXPECT_MAX])); + if (expect_policy->max_expected > NF_CT_EXPECT_MAX_CNT) + return -EINVAL; + + expect_policy->timeout = + ntohl(nla_get_be32(tb[NFCTH_POLICY_EXPECT_TIMEOUT])); + + return 0; +} + +static const struct nla_policy +nfnl_cthelper_expect_policy_set[NFCTH_POLICY_SET_MAX+1] = { + [NFCTH_POLICY_SET_NUM] = { .type = NLA_U32, }, +}; + +static int +nfnl_cthelper_parse_expect_policy(struct nf_conntrack_helper *helper, + const struct nlattr *attr) +{ + int i, ret; + struct nf_conntrack_expect_policy *expect_policy; + struct nlattr *tb[NFCTH_POLICY_SET_MAX+1]; + unsigned int class_max; + + ret = nla_parse_nested_deprecated(tb, NFCTH_POLICY_SET_MAX, attr, + nfnl_cthelper_expect_policy_set, + NULL); + if (ret < 0) + return ret; + + if (!tb[NFCTH_POLICY_SET_NUM]) + return -EINVAL; + + class_max = ntohl(nla_get_be32(tb[NFCTH_POLICY_SET_NUM])); + if (class_max == 0) + return -EINVAL; + if (class_max > NF_CT_MAX_EXPECT_CLASSES) + return -EOVERFLOW; + + expect_policy = kcalloc(class_max, + sizeof(struct nf_conntrack_expect_policy), + GFP_KERNEL); + if (expect_policy == NULL) + return -ENOMEM; + + for (i = 0; i < class_max; i++) { + if (!tb[NFCTH_POLICY_SET+i]) + goto err; + + ret = nfnl_cthelper_expect_policy(&expect_policy[i], + tb[NFCTH_POLICY_SET+i]); + if (ret < 0) + goto err; + } + + helper->expect_class_max = class_max - 1; + helper->expect_policy = expect_policy; + return 0; +err: + kfree(expect_policy); + return -EINVAL; +} + +static int +nfnl_cthelper_create(const struct nlattr * const tb[], + struct nf_conntrack_tuple *tuple) +{ + struct nf_conntrack_helper *helper; + struct nfnl_cthelper *nfcth; + unsigned int size; + int ret; + + if (!tb[NFCTH_TUPLE] || !tb[NFCTH_POLICY] || !tb[NFCTH_PRIV_DATA_LEN]) + return -EINVAL; + + nfcth = kzalloc(sizeof(*nfcth), GFP_KERNEL); + if (nfcth == NULL) + return -ENOMEM; + helper = &nfcth->helper; + + ret = nfnl_cthelper_parse_expect_policy(helper, tb[NFCTH_POLICY]); + if (ret < 0) + goto err1; + + nla_strscpy(helper->name, + tb[NFCTH_NAME], NF_CT_HELPER_NAME_LEN); + size = ntohl(nla_get_be32(tb[NFCTH_PRIV_DATA_LEN])); + if (size > sizeof_field(struct nf_conn_help, data)) { + ret = -ENOMEM; + goto err2; + } + helper->data_len = size; + + helper->flags |= NF_CT_HELPER_F_USERSPACE; + memcpy(&helper->tuple, tuple, sizeof(struct nf_conntrack_tuple)); + + helper->me = THIS_MODULE; + helper->help = nfnl_userspace_cthelper; + helper->from_nlattr = nfnl_cthelper_from_nlattr; + helper->to_nlattr = nfnl_cthelper_to_nlattr; + + /* Default to queue number zero, this can be updated at any time. 
*/ + if (tb[NFCTH_QUEUE_NUM]) + helper->queue_num = ntohl(nla_get_be32(tb[NFCTH_QUEUE_NUM])); + + if (tb[NFCTH_STATUS]) { + int status = ntohl(nla_get_be32(tb[NFCTH_STATUS])); + + switch(status) { + case NFCT_HELPER_STATUS_ENABLED: + helper->flags |= NF_CT_HELPER_F_CONFIGURED; + break; + case NFCT_HELPER_STATUS_DISABLED: + helper->flags &= ~NF_CT_HELPER_F_CONFIGURED; + break; + } + } + + ret = nf_conntrack_helper_register(helper); + if (ret < 0) + goto err2; + + list_add_tail(&nfcth->list, &nfnl_cthelper_list); + return 0; +err2: + kfree(helper->expect_policy); +err1: + kfree(nfcth); + return ret; +} + +static int +nfnl_cthelper_update_policy_one(const struct nf_conntrack_expect_policy *policy, + struct nf_conntrack_expect_policy *new_policy, + const struct nlattr *attr) +{ + struct nlattr *tb[NFCTH_POLICY_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFCTH_POLICY_MAX, attr, + nfnl_cthelper_expect_pol, NULL); + if (err < 0) + return err; + + if (!tb[NFCTH_POLICY_NAME] || + !tb[NFCTH_POLICY_EXPECT_MAX] || + !tb[NFCTH_POLICY_EXPECT_TIMEOUT]) + return -EINVAL; + + if (nla_strcmp(tb[NFCTH_POLICY_NAME], policy->name)) + return -EBUSY; + + new_policy->max_expected = + ntohl(nla_get_be32(tb[NFCTH_POLICY_EXPECT_MAX])); + if (new_policy->max_expected > NF_CT_EXPECT_MAX_CNT) + return -EINVAL; + + new_policy->timeout = + ntohl(nla_get_be32(tb[NFCTH_POLICY_EXPECT_TIMEOUT])); + + return 0; +} + +static int nfnl_cthelper_update_policy_all(struct nlattr *tb[], + struct nf_conntrack_helper *helper) +{ + struct nf_conntrack_expect_policy *new_policy; + struct nf_conntrack_expect_policy *policy; + int i, ret = 0; + + new_policy = kmalloc_array(helper->expect_class_max + 1, + sizeof(*new_policy), GFP_KERNEL); + if (!new_policy) + return -ENOMEM; + + /* Check first that all policy attributes are well-formed, so we don't + * leave things in inconsistent state on errors. + */ + for (i = 0; i < helper->expect_class_max + 1; i++) { + + if (!tb[NFCTH_POLICY_SET + i]) { + ret = -EINVAL; + goto err; + } + + ret = nfnl_cthelper_update_policy_one(&helper->expect_policy[i], + &new_policy[i], + tb[NFCTH_POLICY_SET + i]); + if (ret < 0) + goto err; + } + /* Now we can safely update them. 
*/ + for (i = 0; i < helper->expect_class_max + 1; i++) { + policy = (struct nf_conntrack_expect_policy *) + &helper->expect_policy[i]; + policy->max_expected = new_policy->max_expected; + policy->timeout = new_policy->timeout; + } + +err: + kfree(new_policy); + return ret; +} + +static int nfnl_cthelper_update_policy(struct nf_conntrack_helper *helper, + const struct nlattr *attr) +{ + struct nlattr *tb[NFCTH_POLICY_SET_MAX + 1]; + unsigned int class_max; + int err; + + err = nla_parse_nested_deprecated(tb, NFCTH_POLICY_SET_MAX, attr, + nfnl_cthelper_expect_policy_set, + NULL); + if (err < 0) + return err; + + if (!tb[NFCTH_POLICY_SET_NUM]) + return -EINVAL; + + class_max = ntohl(nla_get_be32(tb[NFCTH_POLICY_SET_NUM])); + if (helper->expect_class_max + 1 != class_max) + return -EBUSY; + + return nfnl_cthelper_update_policy_all(tb, helper); +} + +static int +nfnl_cthelper_update(const struct nlattr * const tb[], + struct nf_conntrack_helper *helper) +{ + u32 size; + int ret; + + if (tb[NFCTH_PRIV_DATA_LEN]) { + size = ntohl(nla_get_be32(tb[NFCTH_PRIV_DATA_LEN])); + if (size != helper->data_len) + return -EBUSY; + } + + if (tb[NFCTH_POLICY]) { + ret = nfnl_cthelper_update_policy(helper, tb[NFCTH_POLICY]); + if (ret < 0) + return ret; + } + if (tb[NFCTH_QUEUE_NUM]) + helper->queue_num = ntohl(nla_get_be32(tb[NFCTH_QUEUE_NUM])); + + if (tb[NFCTH_STATUS]) { + int status = ntohl(nla_get_be32(tb[NFCTH_STATUS])); + + switch(status) { + case NFCT_HELPER_STATUS_ENABLED: + helper->flags |= NF_CT_HELPER_F_CONFIGURED; + break; + case NFCT_HELPER_STATUS_DISABLED: + helper->flags &= ~NF_CT_HELPER_F_CONFIGURED; + break; + } + } + return 0; +} + +static int nfnl_cthelper_new(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + const char *helper_name; + struct nf_conntrack_helper *cur, *helper = NULL; + struct nf_conntrack_tuple tuple; + struct nfnl_cthelper *nlcth; + int ret = 0; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (!tb[NFCTH_NAME] || !tb[NFCTH_TUPLE]) + return -EINVAL; + + helper_name = nla_data(tb[NFCTH_NAME]); + + ret = nfnl_cthelper_parse_tuple(&tuple, tb[NFCTH_TUPLE]); + if (ret < 0) + return ret; + + list_for_each_entry(nlcth, &nfnl_cthelper_list, list) { + cur = &nlcth->helper; + + if (strncmp(cur->name, helper_name, NF_CT_HELPER_NAME_LEN)) + continue; + + if ((tuple.src.l3num != cur->tuple.src.l3num || + tuple.dst.protonum != cur->tuple.dst.protonum)) + continue; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) + return -EEXIST; + + helper = cur; + break; + } + + if (helper == NULL) + ret = nfnl_cthelper_create(tb, &tuple); + else + ret = nfnl_cthelper_update(tb, helper); + + return ret; +} + +static int +nfnl_cthelper_dump_tuple(struct sk_buff *skb, + struct nf_conntrack_helper *helper) +{ + struct nlattr *nest_parms; + + nest_parms = nla_nest_start(skb, NFCTH_TUPLE); + if (nest_parms == NULL) + goto nla_put_failure; + + if (nla_put_be16(skb, NFCTH_TUPLE_L3PROTONUM, + htons(helper->tuple.src.l3num))) + goto nla_put_failure; + + if (nla_put_u8(skb, NFCTH_TUPLE_L4PROTONUM, helper->tuple.dst.protonum)) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + return 0; + +nla_put_failure: + return -1; +} + +static int +nfnl_cthelper_dump_policy(struct sk_buff *skb, + struct nf_conntrack_helper *helper) +{ + int i; + struct nlattr *nest_parms1, *nest_parms2; + + nest_parms1 = nla_nest_start(skb, NFCTH_POLICY); + if (nest_parms1 == NULL) + goto nla_put_failure; + + if (nla_put_be32(skb, NFCTH_POLICY_SET_NUM, + htonl(helper->expect_class_max + 
1))) + goto nla_put_failure; + + for (i = 0; i < helper->expect_class_max + 1; i++) { + nest_parms2 = nla_nest_start(skb, (NFCTH_POLICY_SET + i)); + if (nest_parms2 == NULL) + goto nla_put_failure; + + if (nla_put_string(skb, NFCTH_POLICY_NAME, + helper->expect_policy[i].name)) + goto nla_put_failure; + + if (nla_put_be32(skb, NFCTH_POLICY_EXPECT_MAX, + htonl(helper->expect_policy[i].max_expected))) + goto nla_put_failure; + + if (nla_put_be32(skb, NFCTH_POLICY_EXPECT_TIMEOUT, + htonl(helper->expect_policy[i].timeout))) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms2); + } + nla_nest_end(skb, nest_parms1); + return 0; + +nla_put_failure: + return -1; +} + +static int +nfnl_cthelper_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + int event, struct nf_conntrack_helper *helper) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0; + int status; + + event = nfnl_msg_type(NFNL_SUBSYS_CTHELPER, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_string(skb, NFCTH_NAME, helper->name)) + goto nla_put_failure; + + if (nla_put_be32(skb, NFCTH_QUEUE_NUM, htonl(helper->queue_num))) + goto nla_put_failure; + + if (nfnl_cthelper_dump_tuple(skb, helper) < 0) + goto nla_put_failure; + + if (nfnl_cthelper_dump_policy(skb, helper) < 0) + goto nla_put_failure; + + if (nla_put_be32(skb, NFCTH_PRIV_DATA_LEN, htonl(helper->data_len))) + goto nla_put_failure; + + if (helper->flags & NF_CT_HELPER_F_CONFIGURED) + status = NFCT_HELPER_STATUS_ENABLED; + else + status = NFCT_HELPER_STATUS_DISABLED; + + if (nla_put_be32(skb, NFCTH_STATUS, htonl(status))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int +nfnl_cthelper_dump_table(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nf_conntrack_helper *cur, *last; + + rcu_read_lock(); + last = (struct nf_conntrack_helper *)cb->args[1]; + for (; cb->args[0] < nf_ct_helper_hsize; cb->args[0]++) { +restart: + hlist_for_each_entry_rcu(cur, + &nf_ct_helper_hash[cb->args[0]], hnode) { + + /* skip non-userspace conntrack helpers. 
*/ + if (!(cur->flags & NF_CT_HELPER_F_USERSPACE)) + continue; + + if (cb->args[1]) { + if (cur != last) + continue; + cb->args[1] = 0; + } + if (nfnl_cthelper_fill_info(skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFNL_MSG_TYPE(cb->nlh->nlmsg_type), + NFNL_MSG_CTHELPER_NEW, cur) < 0) { + cb->args[1] = (unsigned long)cur; + goto out; + } + } + } + if (cb->args[1]) { + cb->args[1] = 0; + goto restart; + } +out: + rcu_read_unlock(); + return skb->len; +} + +static int nfnl_cthelper_get(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + int ret = -ENOENT; + struct nf_conntrack_helper *cur; + struct sk_buff *skb2; + char *helper_name = NULL; + struct nf_conntrack_tuple tuple; + struct nfnl_cthelper *nlcth; + bool tuple_set = false; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = nfnl_cthelper_dump_table, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + if (tb[NFCTH_NAME]) + helper_name = nla_data(tb[NFCTH_NAME]); + + if (tb[NFCTH_TUPLE]) { + ret = nfnl_cthelper_parse_tuple(&tuple, tb[NFCTH_TUPLE]); + if (ret < 0) + return ret; + + tuple_set = true; + } + + list_for_each_entry(nlcth, &nfnl_cthelper_list, list) { + cur = &nlcth->helper; + if (helper_name && + strncmp(cur->name, helper_name, NF_CT_HELPER_NAME_LEN)) + continue; + + if (tuple_set && + (tuple.src.l3num != cur->tuple.src.l3num || + tuple.dst.protonum != cur->tuple.dst.protonum)) + continue; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb2 == NULL) { + ret = -ENOMEM; + break; + } + + ret = nfnl_cthelper_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + NFNL_MSG_CTHELPER_NEW, cur); + if (ret <= 0) { + kfree_skb(skb2); + break; + } + + ret = nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + break; + } + + return ret; +} + +static int nfnl_cthelper_del(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + char *helper_name = NULL; + struct nf_conntrack_helper *cur; + struct nf_conntrack_tuple tuple; + bool tuple_set = false, found = false; + struct nfnl_cthelper *nlcth, *n; + int j = 0, ret; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (tb[NFCTH_NAME]) + helper_name = nla_data(tb[NFCTH_NAME]); + + if (tb[NFCTH_TUPLE]) { + ret = nfnl_cthelper_parse_tuple(&tuple, tb[NFCTH_TUPLE]); + if (ret < 0) + return ret; + + tuple_set = true; + } + + ret = -ENOENT; + list_for_each_entry_safe(nlcth, n, &nfnl_cthelper_list, list) { + cur = &nlcth->helper; + j++; + + if (helper_name && + strncmp(cur->name, helper_name, NF_CT_HELPER_NAME_LEN)) + continue; + + if (tuple_set && + (tuple.src.l3num != cur->tuple.src.l3num || + tuple.dst.protonum != cur->tuple.dst.protonum)) + continue; + + if (refcount_dec_if_one(&cur->refcnt)) { + found = true; + nf_conntrack_helper_unregister(cur); + kfree(cur->expect_policy); + + list_del(&nlcth->list); + kfree(nlcth); + } else { + ret = -EBUSY; + } + } + + /* Make sure we return success if we flush and there is no helpers */ + return (found || j == 0) ? 
0 : ret; +} + +static const struct nla_policy nfnl_cthelper_policy[NFCTH_MAX+1] = { + [NFCTH_NAME] = { .type = NLA_NUL_STRING, + .len = NF_CT_HELPER_NAME_LEN-1 }, + [NFCTH_QUEUE_NUM] = { .type = NLA_U32, }, + [NFCTH_PRIV_DATA_LEN] = { .type = NLA_U32, }, + [NFCTH_STATUS] = { .type = NLA_U32, }, +}; + +static const struct nfnl_callback nfnl_cthelper_cb[NFNL_MSG_CTHELPER_MAX] = { + [NFNL_MSG_CTHELPER_NEW] = { + .call = nfnl_cthelper_new, + .type = NFNL_CB_MUTEX, + .attr_count = NFCTH_MAX, + .policy = nfnl_cthelper_policy + }, + [NFNL_MSG_CTHELPER_GET] = { + .call = nfnl_cthelper_get, + .type = NFNL_CB_MUTEX, + .attr_count = NFCTH_MAX, + .policy = nfnl_cthelper_policy + }, + [NFNL_MSG_CTHELPER_DEL] = { + .call = nfnl_cthelper_del, + .type = NFNL_CB_MUTEX, + .attr_count = NFCTH_MAX, + .policy = nfnl_cthelper_policy + }, +}; + +static const struct nfnetlink_subsystem nfnl_cthelper_subsys = { + .name = "cthelper", + .subsys_id = NFNL_SUBSYS_CTHELPER, + .cb_count = NFNL_MSG_CTHELPER_MAX, + .cb = nfnl_cthelper_cb, +}; + +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_CTHELPER); + +static int __init nfnl_cthelper_init(void) +{ + int ret; + + ret = nfnetlink_subsys_register(&nfnl_cthelper_subsys); + if (ret < 0) { + pr_err("nfnl_cthelper: cannot register with nfnetlink.\n"); + goto err_out; + } + return 0; +err_out: + return ret; +} + +static void __exit nfnl_cthelper_exit(void) +{ + struct nf_conntrack_helper *cur; + struct nfnl_cthelper *nlcth, *n; + + nfnetlink_subsys_unregister(&nfnl_cthelper_subsys); + + list_for_each_entry_safe(nlcth, n, &nfnl_cthelper_list, list) { + cur = &nlcth->helper; + + nf_conntrack_helper_unregister(cur); + kfree(cur->expect_policy); + kfree(nlcth); + } +} + +module_init(nfnl_cthelper_init); +module_exit(nfnl_cthelper_exit); diff --git a/net/netfilter/nfnetlink_cttimeout.c b/net/netfilter/nfnetlink_cttimeout.c new file mode 100644 index 000000000..f466af4f8 --- /dev/null +++ b/net/netfilter/nfnetlink_cttimeout.c @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2012 by Pablo Neira Ayuso + * (C) 2012 by Vyatta Inc. 
+ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static unsigned int nfct_timeout_id __read_mostly; + +struct ctnl_timeout { + struct list_head head; + struct list_head free_head; + struct rcu_head rcu_head; + refcount_t refcnt; + char name[CTNL_TIMEOUT_NAME_MAX]; + + /* must be at the end */ + struct nf_ct_timeout timeout; +}; + +struct nfct_timeout_pernet { + struct list_head nfct_timeout_list; + struct list_head nfct_timeout_freelist; +}; + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("cttimeout: Extended Netfilter Connection Tracking timeout tuning"); + +static const struct nla_policy cttimeout_nla_policy[CTA_TIMEOUT_MAX+1] = { + [CTA_TIMEOUT_NAME] = { .type = NLA_NUL_STRING, + .len = CTNL_TIMEOUT_NAME_MAX - 1}, + [CTA_TIMEOUT_L3PROTO] = { .type = NLA_U16 }, + [CTA_TIMEOUT_L4PROTO] = { .type = NLA_U8 }, + [CTA_TIMEOUT_DATA] = { .type = NLA_NESTED }, +}; + +static struct nfct_timeout_pernet *nfct_timeout_pernet(struct net *net) +{ + return net_generic(net, nfct_timeout_id); +} + +static int +ctnl_timeout_parse_policy(void *timeout, + const struct nf_conntrack_l4proto *l4proto, + struct net *net, const struct nlattr *attr) +{ + struct nlattr **tb; + int ret = 0; + + tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb), + GFP_KERNEL); + + if (!tb) + return -ENOMEM; + + ret = nla_parse_nested_deprecated(tb, + l4proto->ctnl_timeout.nlattr_max, + attr, + l4proto->ctnl_timeout.nla_policy, + NULL); + if (ret < 0) + goto err; + + ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeout); + +err: + kfree(tb); + return ret; +} + +static int cttimeout_new_timeout(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(info->net); + __u16 l3num; + __u8 l4num; + const struct nf_conntrack_l4proto *l4proto; + struct ctnl_timeout *timeout, *matching = NULL; + char *name; + int ret; + + if (!cda[CTA_TIMEOUT_NAME] || + !cda[CTA_TIMEOUT_L3PROTO] || + !cda[CTA_TIMEOUT_L4PROTO] || + !cda[CTA_TIMEOUT_DATA]) + return -EINVAL; + + name = nla_data(cda[CTA_TIMEOUT_NAME]); + l3num = ntohs(nla_get_be16(cda[CTA_TIMEOUT_L3PROTO])); + l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); + + list_for_each_entry(timeout, &pernet->nfct_timeout_list, head) { + if (strncmp(timeout->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) + continue; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) + return -EEXIST; + + matching = timeout; + break; + } + + if (matching) { + if (info->nlh->nlmsg_flags & NLM_F_REPLACE) { + /* You cannot replace one timeout policy by another of + * different kind, sorry. + */ + if (matching->timeout.l3num != l3num || + matching->timeout.l4proto->l4proto != l4num) + return -EINVAL; + + return ctnl_timeout_parse_policy(&matching->timeout.data, + matching->timeout.l4proto, + info->net, + cda[CTA_TIMEOUT_DATA]); + } + + return -EBUSY; + } + + l4proto = nf_ct_l4proto_find(l4num); + + /* This protocol is not supportted, skip. 
*/ + if (l4proto->l4proto != l4num) { + ret = -EOPNOTSUPP; + goto err_proto_put; + } + + timeout = kzalloc(sizeof(struct ctnl_timeout) + + l4proto->ctnl_timeout.obj_size, GFP_KERNEL); + if (timeout == NULL) { + ret = -ENOMEM; + goto err_proto_put; + } + + ret = ctnl_timeout_parse_policy(&timeout->timeout.data, l4proto, + info->net, cda[CTA_TIMEOUT_DATA]); + if (ret < 0) + goto err; + + strcpy(timeout->name, nla_data(cda[CTA_TIMEOUT_NAME])); + timeout->timeout.l3num = l3num; + timeout->timeout.l4proto = l4proto; + refcount_set(&timeout->refcnt, 1); + __module_get(THIS_MODULE); + list_add_tail_rcu(&timeout->head, &pernet->nfct_timeout_list); + + return 0; +err: + kfree(timeout); +err_proto_put: + return ret; +} + +static int +ctnl_timeout_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + int event, struct ctnl_timeout *timeout) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0; + const struct nf_conntrack_l4proto *l4proto = timeout->timeout.l4proto; + struct nlattr *nest_parms; + int ret; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_TIMEOUT, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_string(skb, CTA_TIMEOUT_NAME, timeout->name) || + nla_put_be16(skb, CTA_TIMEOUT_L3PROTO, + htons(timeout->timeout.l3num)) || + nla_put_u8(skb, CTA_TIMEOUT_L4PROTO, l4proto->l4proto) || + nla_put_be32(skb, CTA_TIMEOUT_USE, + htonl(refcount_read(&timeout->refcnt)))) + goto nla_put_failure; + + nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA); + if (!nest_parms) + goto nla_put_failure; + + ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->timeout.data); + if (ret < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int +ctnl_timeout_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nfct_timeout_pernet *pernet; + struct net *net = sock_net(skb->sk); + struct ctnl_timeout *cur, *last; + + if (cb->args[2]) + return 0; + + last = (struct ctnl_timeout *)cb->args[1]; + if (cb->args[1]) + cb->args[1] = 0; + + rcu_read_lock(); + pernet = nfct_timeout_pernet(net); + list_for_each_entry_rcu(cur, &pernet->nfct_timeout_list, head) { + if (last) { + if (cur != last) + continue; + + last = NULL; + } + if (ctnl_timeout_fill_info(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NFNL_MSG_TYPE(cb->nlh->nlmsg_type), + IPCTNL_MSG_TIMEOUT_NEW, cur) < 0) { + cb->args[1] = (unsigned long)cur; + break; + } + } + if (!cb->args[1]) + cb->args[2] = 1; + rcu_read_unlock(); + return skb->len; +} + +static int cttimeout_get_timeout(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(info->net); + int ret = -ENOENT; + char *name; + struct ctnl_timeout *cur; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = ctnl_timeout_dump, + }; + return netlink_dump_start(info->sk, skb, info->nlh, &c); + } + + if (!cda[CTA_TIMEOUT_NAME]) + return -EINVAL; + name = nla_data(cda[CTA_TIMEOUT_NAME]); + + list_for_each_entry(cur, &pernet->nfct_timeout_list, head) { + struct sk_buff *skb2; + + if (strncmp(cur->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) + continue; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb2 == NULL) { + ret = -ENOMEM; + break; + } + + ret = ctnl_timeout_fill_info(skb2, 
NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + IPCTNL_MSG_TIMEOUT_NEW, cur); + if (ret <= 0) { + kfree_skb(skb2); + break; + } + + ret = nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); + break; + } + + return ret; +} + +/* try to delete object, fail if it is still in use. */ +static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout) +{ + int ret = 0; + + /* We want to avoid races with ctnl_timeout_put. So only when the + * current refcnt is 1, we decrease it to 0. + */ + if (refcount_dec_if_one(&timeout->refcnt)) { + /* We are protected by nfnl mutex. */ + list_del_rcu(&timeout->head); + nf_ct_untimeout(net, &timeout->timeout); + kfree_rcu(timeout, rcu_head); + } else { + ret = -EBUSY; + } + return ret; +} + +static int cttimeout_del_timeout(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(info->net); + struct ctnl_timeout *cur, *tmp; + int ret = -ENOENT; + char *name; + + if (!cda[CTA_TIMEOUT_NAME]) { + list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_list, + head) + ctnl_timeout_try_del(info->net, cur); + + return 0; + } + name = nla_data(cda[CTA_TIMEOUT_NAME]); + + list_for_each_entry(cur, &pernet->nfct_timeout_list, head) { + if (strncmp(cur->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) + continue; + + ret = ctnl_timeout_try_del(info->net, cur); + if (ret < 0) + return ret; + + break; + } + return ret; +} + +static int cttimeout_default_set(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + const struct nf_conntrack_l4proto *l4proto; + __u8 l4num; + int ret; + + if (!cda[CTA_TIMEOUT_L3PROTO] || + !cda[CTA_TIMEOUT_L4PROTO] || + !cda[CTA_TIMEOUT_DATA]) + return -EINVAL; + + l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); + l4proto = nf_ct_l4proto_find(l4num); + + /* This protocol is not supported, skip. */ + if (l4proto->l4proto != l4num) { + ret = -EOPNOTSUPP; + goto err; + } + + ret = ctnl_timeout_parse_policy(NULL, l4proto, info->net, + cda[CTA_TIMEOUT_DATA]); + if (ret < 0) + goto err; + + return 0; +err: + return ret; +} + +static int +cttimeout_default_fill_info(struct net *net, struct sk_buff *skb, u32 portid, + u32 seq, u32 type, int event, u16 l3num, + const struct nf_conntrack_l4proto *l4proto, + const unsigned int *timeouts) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? 
NLM_F_MULTI : 0; + struct nlattr *nest_parms; + int ret; + + event = nfnl_msg_type(NFNL_SUBSYS_CTNETLINK_TIMEOUT, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, AF_UNSPEC, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_be16(skb, CTA_TIMEOUT_L3PROTO, htons(l3num)) || + nla_put_u8(skb, CTA_TIMEOUT_L4PROTO, l4proto->l4proto)) + goto nla_put_failure; + + nest_parms = nla_nest_start(skb, CTA_TIMEOUT_DATA); + if (!nest_parms) + goto nla_put_failure; + + ret = l4proto->ctnl_timeout.obj_to_nlattr(skb, timeouts); + if (ret < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest_parms); + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int cttimeout_default_get(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + const struct nf_conntrack_l4proto *l4proto; + unsigned int *timeouts = NULL; + struct sk_buff *skb2; + __u16 l3num; + __u8 l4num; + int ret; + + if (!cda[CTA_TIMEOUT_L3PROTO] || !cda[CTA_TIMEOUT_L4PROTO]) + return -EINVAL; + + l3num = ntohs(nla_get_be16(cda[CTA_TIMEOUT_L3PROTO])); + l4num = nla_get_u8(cda[CTA_TIMEOUT_L4PROTO]); + l4proto = nf_ct_l4proto_find(l4num); + + if (l4proto->l4proto != l4num) + return -EOPNOTSUPP; + + switch (l4proto->l4proto) { + case IPPROTO_ICMP: + timeouts = &nf_icmp_pernet(info->net)->timeout; + break; + case IPPROTO_TCP: + timeouts = nf_tcp_pernet(info->net)->timeouts; + break; + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + timeouts = nf_udp_pernet(info->net)->timeouts; + break; + case IPPROTO_DCCP: +#ifdef CONFIG_NF_CT_PROTO_DCCP + timeouts = nf_dccp_pernet(info->net)->dccp_timeout; +#endif + break; + case IPPROTO_ICMPV6: + timeouts = &nf_icmpv6_pernet(info->net)->timeout; + break; + case IPPROTO_SCTP: +#ifdef CONFIG_NF_CT_PROTO_SCTP + timeouts = nf_sctp_pernet(info->net)->timeouts; +#endif + break; + case IPPROTO_GRE: +#ifdef CONFIG_NF_CT_PROTO_GRE + timeouts = nf_gre_pernet(info->net)->timeouts; +#endif + break; + case 255: + timeouts = &nf_generic_pernet(info->net)->timeout; + break; + default: + WARN_ONCE(1, "Missing timeouts for proto %d", l4proto->l4proto); + break; + } + + if (!timeouts) + return -EOPNOTSUPP; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb2) + return -ENOMEM; + + ret = cttimeout_default_fill_info(info->net, skb2, + NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + IPCTNL_MSG_TIMEOUT_DEFAULT_SET, + l3num, l4proto, timeouts); + if (ret <= 0) { + kfree_skb(skb2); + return -ENOMEM; + } + + return nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); +} + +static struct nf_ct_timeout *ctnl_timeout_find_get(struct net *net, + const char *name) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(net); + struct ctnl_timeout *timeout, *matching = NULL; + + list_for_each_entry_rcu(timeout, &pernet->nfct_timeout_list, head) { + if (strncmp(timeout->name, name, CTNL_TIMEOUT_NAME_MAX) != 0) + continue; + + if (!refcount_inc_not_zero(&timeout->refcnt)) + goto err; + matching = timeout; + break; + } +err: + return matching ? 
&matching->timeout : NULL; +} + +static void ctnl_timeout_put(struct nf_ct_timeout *t) +{ + struct ctnl_timeout *timeout = + container_of(t, struct ctnl_timeout, timeout); + + if (refcount_dec_and_test(&timeout->refcnt)) { + kfree_rcu(timeout, rcu_head); + module_put(THIS_MODULE); + } +} + +static const struct nfnl_callback cttimeout_cb[IPCTNL_MSG_TIMEOUT_MAX] = { + [IPCTNL_MSG_TIMEOUT_NEW] = { + .call = cttimeout_new_timeout, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_TIMEOUT_MAX, + .policy = cttimeout_nla_policy + }, + [IPCTNL_MSG_TIMEOUT_GET] = { + .call = cttimeout_get_timeout, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_TIMEOUT_MAX, + .policy = cttimeout_nla_policy + }, + [IPCTNL_MSG_TIMEOUT_DELETE] = { + .call = cttimeout_del_timeout, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_TIMEOUT_MAX, + .policy = cttimeout_nla_policy + }, + [IPCTNL_MSG_TIMEOUT_DEFAULT_SET] = { + .call = cttimeout_default_set, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_TIMEOUT_MAX, + .policy = cttimeout_nla_policy + }, + [IPCTNL_MSG_TIMEOUT_DEFAULT_GET] = { + .call = cttimeout_default_get, + .type = NFNL_CB_MUTEX, + .attr_count = CTA_TIMEOUT_MAX, + .policy = cttimeout_nla_policy + }, +}; + +static const struct nfnetlink_subsystem cttimeout_subsys = { + .name = "conntrack_timeout", + .subsys_id = NFNL_SUBSYS_CTNETLINK_TIMEOUT, + .cb_count = IPCTNL_MSG_TIMEOUT_MAX, + .cb = cttimeout_cb, +}; + +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_CTNETLINK_TIMEOUT); + +static int __net_init cttimeout_net_init(struct net *net) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(net); + + INIT_LIST_HEAD(&pernet->nfct_timeout_list); + INIT_LIST_HEAD(&pernet->nfct_timeout_freelist); + + return 0; +} + +static void __net_exit cttimeout_net_pre_exit(struct net *net) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(net); + struct ctnl_timeout *cur, *tmp; + + list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_list, head) { + list_del_rcu(&cur->head); + list_add(&cur->free_head, &pernet->nfct_timeout_freelist); + } + + /* core calls synchronize_rcu() after this */ +} + +static void __net_exit cttimeout_net_exit(struct net *net) +{ + struct nfct_timeout_pernet *pernet = nfct_timeout_pernet(net); + struct ctnl_timeout *cur, *tmp; + + if (list_empty(&pernet->nfct_timeout_freelist)) + return; + + nf_ct_untimeout(net, NULL); + + list_for_each_entry_safe(cur, tmp, &pernet->nfct_timeout_freelist, free_head) { + list_del(&cur->free_head); + + if (refcount_dec_and_test(&cur->refcnt)) + kfree_rcu(cur, rcu_head); + } +} + +static struct pernet_operations cttimeout_ops = { + .init = cttimeout_net_init, + .pre_exit = cttimeout_net_pre_exit, + .exit = cttimeout_net_exit, + .id = &nfct_timeout_id, + .size = sizeof(struct nfct_timeout_pernet), +}; + +static const struct nf_ct_timeout_hooks hooks = { + .timeout_find_get = ctnl_timeout_find_get, + .timeout_put = ctnl_timeout_put, +}; + +static int __init cttimeout_init(void) +{ + int ret; + + ret = register_pernet_subsys(&cttimeout_ops); + if (ret < 0) + return ret; + + ret = nfnetlink_subsys_register(&cttimeout_subsys); + if (ret < 0) { + pr_err("cttimeout_init: cannot register cttimeout with " + "nfnetlink.\n"); + goto err_out; + } + RCU_INIT_POINTER(nf_ct_timeout_hook, &hooks); + return 0; + +err_out: + unregister_pernet_subsys(&cttimeout_ops); + return ret; +} + +static int untimeout(struct nf_conn *ct, void *timeout) +{ + struct nf_conn_timeout *timeout_ext = nf_ct_timeout_find(ct); + + if (timeout_ext) + RCU_INIT_POINTER(timeout_ext->timeout, NULL); + + return 0; +} 
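
For illustration only (this is not part of the patch): ctnl_timeout_try_del() above removes a timeout object only when its reference count is exactly one, returning -EBUSY otherwise so it cannot race with a concurrent ctnl_timeout_put(). A minimal userspace sketch of that pattern follows, using C11 atomics as a stand-in for the kernel's refcount_dec_if_one(); all names in it are illustrative.

/* Standalone sketch: delete succeeds only when the count drops 1 -> 0. */
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

struct object {
	atomic_uint refcnt;
};

/* Userspace stand-in for refcount_dec_if_one(): compare-and-swap 1 -> 0. */
static bool try_delete(struct object *obj)
{
	unsigned int expected = 1;

	return atomic_compare_exchange_strong(&obj->refcnt, &expected, 0);
}

int main(void)
{
	struct object timeout = { .refcnt = 2 };	/* one extra user still holds it */

	printf("delete with refcnt=2: %s\n", try_delete(&timeout) ? "freed" : "-EBUSY");
	atomic_fetch_sub(&timeout.refcnt, 1);		/* the other user drops its reference */
	printf("delete with refcnt=1: %s\n", try_delete(&timeout) ? "freed" : "-EBUSY");
	return 0;
}
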
+ +static void __exit cttimeout_exit(void) +{ + nfnetlink_subsys_unregister(&cttimeout_subsys); + + unregister_pernet_subsys(&cttimeout_ops); + RCU_INIT_POINTER(nf_ct_timeout_hook, NULL); + + nf_ct_iterate_destroy(untimeout, NULL); +} + +module_init(cttimeout_init); +module_exit(cttimeout_exit); diff --git a/net/netfilter/nfnetlink_hook.c b/net/netfilter/nfnetlink_hook.c new file mode 100644 index 000000000..8120aadf6 --- /dev/null +++ b/net/netfilter/nfnetlink_hook.c @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2021 Red Hat GmbH + * + * Author: Florian Westphal + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +static const struct nla_policy nfnl_hook_nla_policy[NFNLA_HOOK_MAX + 1] = { + [NFNLA_HOOK_HOOKNUM] = { .type = NLA_U32 }, + [NFNLA_HOOK_PRIORITY] = { .type = NLA_U32 }, + [NFNLA_HOOK_DEV] = { .type = NLA_STRING, + .len = IFNAMSIZ - 1 }, + [NFNLA_HOOK_FUNCTION_NAME] = { .type = NLA_NUL_STRING, + .len = KSYM_NAME_LEN, }, + [NFNLA_HOOK_MODULE_NAME] = { .type = NLA_NUL_STRING, + .len = MODULE_NAME_LEN, }, + [NFNLA_HOOK_CHAIN_INFO] = { .type = NLA_NESTED, }, +}; + +static int nf_netlink_dump_start_rcu(struct sock *nlsk, struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct netlink_dump_control *c) +{ + int err; + + if (!try_module_get(THIS_MODULE)) + return -EINVAL; + + rcu_read_unlock(); + err = netlink_dump_start(nlsk, skb, nlh, c); + rcu_read_lock(); + module_put(THIS_MODULE); + + return err; +} + +struct nfnl_dump_hook_data { + char devname[IFNAMSIZ]; + unsigned long headv; + u8 hook; +}; + +static int nfnl_hook_put_nft_chain_info(struct sk_buff *nlskb, + const struct nfnl_dump_hook_data *ctx, + unsigned int seq, + const struct nf_hook_ops *ops) +{ + struct net *net = sock_net(nlskb->sk); + struct nlattr *nest, *nest2; + struct nft_chain *chain; + int ret = 0; + + if (ops->hook_ops_type != NF_HOOK_OP_NF_TABLES) + return 0; + + chain = ops->priv; + if (WARN_ON_ONCE(!chain)) + return 0; + + if (!nft_is_active(net, chain)) + return 0; + + nest = nla_nest_start(nlskb, NFNLA_HOOK_CHAIN_INFO); + if (!nest) + return -EMSGSIZE; + + ret = nla_put_be32(nlskb, NFNLA_HOOK_INFO_TYPE, + htonl(NFNL_HOOK_TYPE_NFTABLES)); + if (ret) + goto cancel_nest; + + nest2 = nla_nest_start(nlskb, NFNLA_HOOK_INFO_DESC); + if (!nest2) + goto cancel_nest; + + ret = nla_put_string(nlskb, NFNLA_CHAIN_TABLE, chain->table->name); + if (ret) + goto cancel_nest; + + ret = nla_put_string(nlskb, NFNLA_CHAIN_NAME, chain->name); + if (ret) + goto cancel_nest; + + ret = nla_put_u8(nlskb, NFNLA_CHAIN_FAMILY, chain->table->family); + if (ret) + goto cancel_nest; + + nla_nest_end(nlskb, nest2); + nla_nest_end(nlskb, nest); + return ret; + +cancel_nest: + nla_nest_cancel(nlskb, nest); + return -EMSGSIZE; +} + +static int nfnl_hook_dump_one(struct sk_buff *nlskb, + const struct nfnl_dump_hook_data *ctx, + const struct nf_hook_ops *ops, + int family, unsigned int seq) +{ + u16 event = nfnl_msg_type(NFNL_SUBSYS_HOOK, NFNL_MSG_HOOK_GET); + unsigned int portid = NETLINK_CB(nlskb).portid; + struct nlmsghdr *nlh; + int ret = -EMSGSIZE; + u32 hooknum; +#ifdef CONFIG_KALLSYMS + char sym[KSYM_SYMBOL_LEN]; + char *module_name; +#endif + nlh = nfnl_msg_put(nlskb, portid, seq, event, + NLM_F_MULTI, family, NFNETLINK_V0, 0); + if (!nlh) + goto nla_put_failure; + +#ifdef CONFIG_KALLSYMS + ret = snprintf(sym, sizeof(sym), "%ps", ops->hook); + if (ret >= sizeof(sym)) { + ret = -EINVAL; + goto 
nla_put_failure; + } + + module_name = strstr(sym, " ["); + if (module_name) { + char *end; + + *module_name = '\0'; + module_name += 2; + end = strchr(module_name, ']'); + if (end) { + *end = 0; + + ret = nla_put_string(nlskb, NFNLA_HOOK_MODULE_NAME, module_name); + if (ret) + goto nla_put_failure; + } + } + + ret = nla_put_string(nlskb, NFNLA_HOOK_FUNCTION_NAME, sym); + if (ret) + goto nla_put_failure; +#endif + + if (ops->pf == NFPROTO_INET && ops->hooknum == NF_INET_INGRESS) + hooknum = NF_NETDEV_INGRESS; + else + hooknum = ops->hooknum; + + ret = nla_put_be32(nlskb, NFNLA_HOOK_HOOKNUM, htonl(hooknum)); + if (ret) + goto nla_put_failure; + + ret = nla_put_be32(nlskb, NFNLA_HOOK_PRIORITY, htonl(ops->priority)); + if (ret) + goto nla_put_failure; + + ret = nfnl_hook_put_nft_chain_info(nlskb, ctx, seq, ops); + if (ret) + goto nla_put_failure; + + nlmsg_end(nlskb, nlh); + return 0; +nla_put_failure: + nlmsg_trim(nlskb, nlh); + return ret; +} + +static const struct nf_hook_entries * +nfnl_hook_entries_head(u8 pf, unsigned int hook, struct net *net, const char *dev) +{ + const struct nf_hook_entries *hook_head = NULL; +#if defined(CONFIG_NETFILTER_INGRESS) || defined(CONFIG_NETFILTER_EGRESS) + struct net_device *netdev; +#endif + + switch (pf) { + case NFPROTO_IPV4: + if (hook >= ARRAY_SIZE(net->nf.hooks_ipv4)) + return ERR_PTR(-EINVAL); + hook_head = rcu_dereference(net->nf.hooks_ipv4[hook]); + break; + case NFPROTO_IPV6: + if (hook >= ARRAY_SIZE(net->nf.hooks_ipv6)) + return ERR_PTR(-EINVAL); + hook_head = rcu_dereference(net->nf.hooks_ipv6[hook]); + break; + case NFPROTO_ARP: +#ifdef CONFIG_NETFILTER_FAMILY_ARP + if (hook >= ARRAY_SIZE(net->nf.hooks_arp)) + return ERR_PTR(-EINVAL); + hook_head = rcu_dereference(net->nf.hooks_arp[hook]); +#endif + break; + case NFPROTO_BRIDGE: +#ifdef CONFIG_NETFILTER_FAMILY_BRIDGE + if (hook >= ARRAY_SIZE(net->nf.hooks_bridge)) + return ERR_PTR(-EINVAL); + hook_head = rcu_dereference(net->nf.hooks_bridge[hook]); +#endif + break; +#if defined(CONFIG_NETFILTER_INGRESS) || defined(CONFIG_NETFILTER_EGRESS) + case NFPROTO_NETDEV: + if (hook >= NF_NETDEV_NUMHOOKS) + return ERR_PTR(-EOPNOTSUPP); + + if (!dev) + return ERR_PTR(-ENODEV); + + netdev = dev_get_by_name_rcu(net, dev); + if (!netdev) + return ERR_PTR(-ENODEV); + +#ifdef CONFIG_NETFILTER_INGRESS + if (hook == NF_NETDEV_INGRESS) + return rcu_dereference(netdev->nf_hooks_ingress); +#endif +#ifdef CONFIG_NETFILTER_EGRESS + if (hook == NF_NETDEV_EGRESS) + return rcu_dereference(netdev->nf_hooks_egress); +#endif + fallthrough; +#endif + default: + return ERR_PTR(-EPROTONOSUPPORT); + } + + return hook_head; +} + +static int nfnl_hook_dump(struct sk_buff *nlskb, + struct netlink_callback *cb) +{ + struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + struct nfnl_dump_hook_data *ctx = cb->data; + int err, family = nfmsg->nfgen_family; + struct net *net = sock_net(nlskb->sk); + struct nf_hook_ops * const *ops; + const struct nf_hook_entries *e; + unsigned int i = cb->args[0]; + + rcu_read_lock(); + + e = nfnl_hook_entries_head(family, ctx->hook, net, ctx->devname); + if (!e) + goto done; + + if (IS_ERR(e)) { + cb->seq++; + goto done; + } + + if ((unsigned long)e != ctx->headv || i >= e->num_hook_entries) + cb->seq++; + + ops = nf_hook_entries_get_hook_ops(e); + + for (; i < e->num_hook_entries; i++) { + err = nfnl_hook_dump_one(nlskb, ctx, ops[i], family, + cb->nlh->nlmsg_seq); + if (err) + break; + } + +done: + nl_dump_check_consistent(cb, nlmsg_hdr(nlskb)); + rcu_read_unlock(); + cb->args[0] = i; + return 
nlskb->len; +} + +static int nfnl_hook_dump_start(struct netlink_callback *cb) +{ + const struct nfgenmsg *nfmsg = nlmsg_data(cb->nlh); + const struct nlattr * const *nla = cb->data; + struct nfnl_dump_hook_data *ctx = NULL; + struct net *net = sock_net(cb->skb->sk); + u8 family = nfmsg->nfgen_family; + char name[IFNAMSIZ] = ""; + const void *head; + u32 hooknum; + + hooknum = ntohl(nla_get_be32(nla[NFNLA_HOOK_HOOKNUM])); + if (hooknum > 255) + return -EINVAL; + + if (family == NFPROTO_NETDEV) { + if (!nla[NFNLA_HOOK_DEV]) + return -EINVAL; + + nla_strscpy(name, nla[NFNLA_HOOK_DEV], sizeof(name)); + } + + rcu_read_lock(); + /* Not dereferenced; for consistency check only */ + head = nfnl_hook_entries_head(family, hooknum, net, name); + rcu_read_unlock(); + + if (head && IS_ERR(head)) + return PTR_ERR(head); + + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + if (!ctx) + return -ENOMEM; + + strscpy(ctx->devname, name, sizeof(ctx->devname)); + ctx->headv = (unsigned long)head; + ctx->hook = hooknum; + + cb->seq = 1; + cb->data = ctx; + + return 0; +} + +static int nfnl_hook_dump_stop(struct netlink_callback *cb) +{ + kfree(cb->data); + return 0; +} + +static int nfnl_hook_get(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nla[]) +{ + if (!nla[NFNLA_HOOK_HOOKNUM]) + return -EINVAL; + + if (info->nlh->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = nfnl_hook_dump_start, + .done = nfnl_hook_dump_stop, + .dump = nfnl_hook_dump, + .module = THIS_MODULE, + .data = (void *)nla, + }; + + return nf_netlink_dump_start_rcu(info->sk, skb, info->nlh, &c); + } + + return -EOPNOTSUPP; +} + +static const struct nfnl_callback nfnl_hook_cb[NFNL_MSG_HOOK_MAX] = { + [NFNL_MSG_HOOK_GET] = { + .call = nfnl_hook_get, + .type = NFNL_CB_RCU, + .attr_count = NFNLA_HOOK_MAX, + .policy = nfnl_hook_nla_policy + }, +}; + +static const struct nfnetlink_subsystem nfhook_subsys = { + .name = "nfhook", + .subsys_id = NFNL_SUBSYS_HOOK, + .cb_count = NFNL_MSG_HOOK_MAX, + .cb = nfnl_hook_cb, +}; + +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_HOOK); + +static int __init nfnetlink_hook_init(void) +{ + return nfnetlink_subsys_register(&nfhook_subsys); +} + +static void __exit nfnetlink_hook_exit(void) +{ + nfnetlink_subsys_unregister(&nfhook_subsys); +} + +module_init(nfnetlink_hook_init); +module_exit(nfnetlink_hook_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("nfnetlink_hook: list registered netfilter hooks"); diff --git a/net/netfilter/nfnetlink_log.c b/net/netfilter/nfnetlink_log.c new file mode 100644 index 000000000..200a82a8f --- /dev/null +++ b/net/netfilter/nfnetlink_log.c @@ -0,0 +1,1209 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for logging packets to userspace via + * nfetlink. 
+ * + * (C) 2005 by Harald Welte + * (C) 2006-2012 Patrick McHardy + * + * Based on the old ipv4-only ipt_ULOG.c: + * (C) 2000-2004 by Harald Welte + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +#include "../bridge/br_private.h" +#endif + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +#define NFULNL_COPY_DISABLED 0xff +#define NFULNL_NLBUFSIZ_DEFAULT NLMSG_GOODSIZE +#define NFULNL_TIMEOUT_DEFAULT 100 /* every second */ +#define NFULNL_QTHRESH_DEFAULT 100 /* 100 packets */ +/* max packet size is limited by 16-bit struct nfattr nfa_len field */ +#define NFULNL_COPY_RANGE_MAX (0xFFFF - NLA_HDRLEN) + +#define PRINTR(x, args...) do { if (net_ratelimit()) \ + printk(x, ## args); } while (0); + +struct nfulnl_instance { + struct hlist_node hlist; /* global list of instances */ + spinlock_t lock; + refcount_t use; /* use count */ + + unsigned int qlen; /* number of nlmsgs in skb */ + struct sk_buff *skb; /* pre-allocatd skb */ + struct timer_list timer; + struct net *net; + netns_tracker ns_tracker; + struct user_namespace *peer_user_ns; /* User namespace of the peer process */ + u32 peer_portid; /* PORTID of the peer process */ + + /* configurable parameters */ + unsigned int flushtimeout; /* timeout until queue flush */ + unsigned int nlbufsiz; /* netlink buffer allocation size */ + unsigned int qthreshold; /* threshold of the queue */ + u_int32_t copy_range; + u_int32_t seq; /* instance-local sequential counter */ + u_int16_t group_num; /* number of this queue */ + u_int16_t flags; + u_int8_t copy_mode; + struct rcu_head rcu; +}; + +#define INSTANCE_BUCKETS 16 + +static unsigned int nfnl_log_net_id __read_mostly; + +struct nfnl_log_net { + spinlock_t instances_lock; + struct hlist_head instance_table[INSTANCE_BUCKETS]; + atomic_t global_seq; +}; + +static struct nfnl_log_net *nfnl_log_pernet(struct net *net) +{ + return net_generic(net, nfnl_log_net_id); +} + +static inline u_int8_t instance_hashfn(u_int16_t group_num) +{ + return ((group_num & 0xff) % INSTANCE_BUCKETS); +} + +static struct nfulnl_instance * +__instance_lookup(struct nfnl_log_net *log, u_int16_t group_num) +{ + struct hlist_head *head; + struct nfulnl_instance *inst; + + head = &log->instance_table[instance_hashfn(group_num)]; + hlist_for_each_entry_rcu(inst, head, hlist) { + if (inst->group_num == group_num) + return inst; + } + return NULL; +} + +static inline void +instance_get(struct nfulnl_instance *inst) +{ + refcount_inc(&inst->use); +} + +static struct nfulnl_instance * +instance_lookup_get(struct nfnl_log_net *log, u_int16_t group_num) +{ + struct nfulnl_instance *inst; + + rcu_read_lock_bh(); + inst = __instance_lookup(log, group_num); + if (inst && !refcount_inc_not_zero(&inst->use)) + inst = NULL; + rcu_read_unlock_bh(); + + return inst; +} + +static void nfulnl_instance_free_rcu(struct rcu_head *head) +{ + struct nfulnl_instance *inst = + container_of(head, struct nfulnl_instance, rcu); + + put_net_track(inst->net, &inst->ns_tracker); + kfree(inst); + module_put(THIS_MODULE); +} + +static void +instance_put(struct nfulnl_instance *inst) +{ + if (inst && refcount_dec_and_test(&inst->use)) + call_rcu(&inst->rcu, nfulnl_instance_free_rcu); +} + +static void nfulnl_timer(struct timer_list *t); + +static struct 
nfulnl_instance * +instance_create(struct net *net, u_int16_t group_num, + u32 portid, struct user_namespace *user_ns) +{ + struct nfulnl_instance *inst; + struct nfnl_log_net *log = nfnl_log_pernet(net); + int err; + + spin_lock_bh(&log->instances_lock); + if (__instance_lookup(log, group_num)) { + err = -EEXIST; + goto out_unlock; + } + + inst = kzalloc(sizeof(*inst), GFP_ATOMIC); + if (!inst) { + err = -ENOMEM; + goto out_unlock; + } + + if (!try_module_get(THIS_MODULE)) { + kfree(inst); + err = -EAGAIN; + goto out_unlock; + } + + INIT_HLIST_NODE(&inst->hlist); + spin_lock_init(&inst->lock); + /* needs to be two, since we _put() after creation */ + refcount_set(&inst->use, 2); + + timer_setup(&inst->timer, nfulnl_timer, 0); + + inst->net = get_net_track(net, &inst->ns_tracker, GFP_ATOMIC); + inst->peer_user_ns = user_ns; + inst->peer_portid = portid; + inst->group_num = group_num; + + inst->qthreshold = NFULNL_QTHRESH_DEFAULT; + inst->flushtimeout = NFULNL_TIMEOUT_DEFAULT; + inst->nlbufsiz = NFULNL_NLBUFSIZ_DEFAULT; + inst->copy_mode = NFULNL_COPY_PACKET; + inst->copy_range = NFULNL_COPY_RANGE_MAX; + + hlist_add_head_rcu(&inst->hlist, + &log->instance_table[instance_hashfn(group_num)]); + + + spin_unlock_bh(&log->instances_lock); + + return inst; + +out_unlock: + spin_unlock_bh(&log->instances_lock); + return ERR_PTR(err); +} + +static void __nfulnl_flush(struct nfulnl_instance *inst); + +/* called with BH disabled */ +static void +__instance_destroy(struct nfulnl_instance *inst) +{ + /* first pull it out of the global list */ + hlist_del_rcu(&inst->hlist); + + /* then flush all pending packets from skb */ + + spin_lock(&inst->lock); + + /* lockless readers wont be able to use us */ + inst->copy_mode = NFULNL_COPY_DISABLED; + + if (inst->skb) + __nfulnl_flush(inst); + spin_unlock(&inst->lock); + + /* and finally put the refcount */ + instance_put(inst); +} + +static inline void +instance_destroy(struct nfnl_log_net *log, + struct nfulnl_instance *inst) +{ + spin_lock_bh(&log->instances_lock); + __instance_destroy(inst); + spin_unlock_bh(&log->instances_lock); +} + +static int +nfulnl_set_mode(struct nfulnl_instance *inst, u_int8_t mode, + unsigned int range) +{ + int status = 0; + + spin_lock_bh(&inst->lock); + + switch (mode) { + case NFULNL_COPY_NONE: + case NFULNL_COPY_META: + inst->copy_mode = mode; + inst->copy_range = 0; + break; + + case NFULNL_COPY_PACKET: + inst->copy_mode = mode; + if (range == 0) + range = NFULNL_COPY_RANGE_MAX; + inst->copy_range = min_t(unsigned int, + range, NFULNL_COPY_RANGE_MAX); + break; + + default: + status = -EINVAL; + break; + } + + spin_unlock_bh(&inst->lock); + + return status; +} + +static int +nfulnl_set_nlbufsiz(struct nfulnl_instance *inst, u_int32_t nlbufsiz) +{ + int status; + + spin_lock_bh(&inst->lock); + if (nlbufsiz < NFULNL_NLBUFSIZ_DEFAULT) + status = -ERANGE; + else if (nlbufsiz > 131072) + status = -ERANGE; + else { + inst->nlbufsiz = nlbufsiz; + status = 0; + } + spin_unlock_bh(&inst->lock); + + return status; +} + +static void +nfulnl_set_timeout(struct nfulnl_instance *inst, u_int32_t timeout) +{ + spin_lock_bh(&inst->lock); + inst->flushtimeout = timeout; + spin_unlock_bh(&inst->lock); +} + +static void +nfulnl_set_qthresh(struct nfulnl_instance *inst, u_int32_t qthresh) +{ + spin_lock_bh(&inst->lock); + inst->qthreshold = qthresh; + spin_unlock_bh(&inst->lock); +} + +static int +nfulnl_set_flags(struct nfulnl_instance *inst, u_int16_t flags) +{ + spin_lock_bh(&inst->lock); + inst->flags = flags; + spin_unlock_bh(&inst->lock); + 
+ return 0; +} + +static struct sk_buff * +nfulnl_alloc_skb(struct net *net, u32 peer_portid, unsigned int inst_size, + unsigned int pkt_size) +{ + struct sk_buff *skb; + unsigned int n; + + /* alloc skb which should be big enough for a whole multipart + * message. WARNING: has to be <= 128k due to slab restrictions */ + + n = max(inst_size, pkt_size); + skb = alloc_skb(n, GFP_ATOMIC | __GFP_NOWARN); + if (!skb) { + if (n > pkt_size) { + /* try to allocate only as much as we need for current + * packet */ + + skb = alloc_skb(pkt_size, GFP_ATOMIC); + } + } + + return skb; +} + +static void +__nfulnl_send(struct nfulnl_instance *inst) +{ + if (inst->qlen > 1) { + struct nlmsghdr *nlh = nlmsg_put(inst->skb, 0, 0, + NLMSG_DONE, + sizeof(struct nfgenmsg), + 0); + if (WARN_ONCE(!nlh, "bad nlskb size: %u, tailroom %d\n", + inst->skb->len, skb_tailroom(inst->skb))) { + kfree_skb(inst->skb); + goto out; + } + } + nfnetlink_unicast(inst->skb, inst->net, inst->peer_portid); +out: + inst->qlen = 0; + inst->skb = NULL; +} + +static void +__nfulnl_flush(struct nfulnl_instance *inst) +{ + /* timer holds a reference */ + if (del_timer(&inst->timer)) + instance_put(inst); + if (inst->skb) + __nfulnl_send(inst); +} + +static void +nfulnl_timer(struct timer_list *t) +{ + struct nfulnl_instance *inst = from_timer(inst, t, timer); + + spin_lock_bh(&inst->lock); + if (inst->skb) + __nfulnl_send(inst); + spin_unlock_bh(&inst->lock); + instance_put(inst); +} + +static u32 nfulnl_get_bridge_size(const struct sk_buff *skb) +{ + u32 size = 0; + + if (!skb_mac_header_was_set(skb)) + return 0; + + if (skb_vlan_tag_present(skb)) { + size += nla_total_size(0); /* nested */ + size += nla_total_size(sizeof(u16)); /* id */ + size += nla_total_size(sizeof(u16)); /* tag */ + } + + if (skb->network_header > skb->mac_header) + size += nla_total_size(skb->network_header - skb->mac_header); + + return size; +} + +static int nfulnl_put_bridge(struct nfulnl_instance *inst, const struct sk_buff *skb) +{ + if (!skb_mac_header_was_set(skb)) + return 0; + + if (skb_vlan_tag_present(skb)) { + struct nlattr *nest; + + nest = nla_nest_start(inst->skb, NFULA_VLAN); + if (!nest) + goto nla_put_failure; + + if (nla_put_be16(inst->skb, NFULA_VLAN_TCI, htons(skb->vlan_tci)) || + nla_put_be16(inst->skb, NFULA_VLAN_PROTO, skb->vlan_proto)) + goto nla_put_failure; + + nla_nest_end(inst->skb, nest); + } + + if (skb->mac_header < skb->network_header) { + int len = (int)(skb->network_header - skb->mac_header); + + if (nla_put(inst->skb, NFULA_L2HDR, len, skb_mac_header(skb))) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -1; +} + +/* This is an inline function, we don't really care about a long + * list of arguments */ +static inline int +__build_packet_message(struct nfnl_log_net *log, + struct nfulnl_instance *inst, + const struct sk_buff *skb, + unsigned int data_len, + u_int8_t pf, + unsigned int hooknum, + const struct net_device *indev, + const struct net_device *outdev, + const char *prefix, unsigned int plen, + const struct nfnl_ct_hook *nfnl_ct, + struct nf_conn *ct, enum ip_conntrack_info ctinfo) +{ + struct nfulnl_msg_packet_hdr pmsg; + struct nlmsghdr *nlh; + sk_buff_data_t old_tail = inst->skb->tail; + struct sock *sk; + const unsigned char *hwhdrp; + ktime_t tstamp; + + nlh = nfnl_msg_put(inst->skb, 0, 0, + nfnl_msg_type(NFNL_SUBSYS_ULOG, NFULNL_MSG_PACKET), + 0, pf, NFNETLINK_V0, htons(inst->group_num)); + if (!nlh) + return -1; + + memset(&pmsg, 0, sizeof(pmsg)); + pmsg.hw_protocol = skb->protocol; + 
pmsg.hook = hooknum; + + if (nla_put(inst->skb, NFULA_PACKET_HDR, sizeof(pmsg), &pmsg)) + goto nla_put_failure; + + if (prefix && + nla_put(inst->skb, NFULA_PREFIX, plen, prefix)) + goto nla_put_failure; + + if (indev) { +#if !IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (nla_put_be32(inst->skb, NFULA_IFINDEX_INDEV, + htonl(indev->ifindex))) + goto nla_put_failure; +#else + if (pf == PF_BRIDGE) { + /* Case 1: outdev is physical input device, we need to + * look for bridge group (when called from + * netfilter_bridge) */ + if (nla_put_be32(inst->skb, NFULA_IFINDEX_PHYSINDEV, + htonl(indev->ifindex)) || + /* this is the bridge group "brX" */ + /* rcu_read_lock()ed by nf_hook_thresh or + * nf_log_packet. + */ + nla_put_be32(inst->skb, NFULA_IFINDEX_INDEV, + htonl(br_port_get_rcu(indev)->br->dev->ifindex))) + goto nla_put_failure; + } else { + int physinif; + + /* Case 2: indev is bridge group, we need to look for + * physical device (when called from ipv4) */ + if (nla_put_be32(inst->skb, NFULA_IFINDEX_INDEV, + htonl(indev->ifindex))) + goto nla_put_failure; + + physinif = nf_bridge_get_physinif(skb); + if (physinif && + nla_put_be32(inst->skb, NFULA_IFINDEX_PHYSINDEV, + htonl(physinif))) + goto nla_put_failure; + } +#endif + } + + if (outdev) { +#if !IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (nla_put_be32(inst->skb, NFULA_IFINDEX_OUTDEV, + htonl(outdev->ifindex))) + goto nla_put_failure; +#else + if (pf == PF_BRIDGE) { + /* Case 1: outdev is physical output device, we need to + * look for bridge group (when called from + * netfilter_bridge) */ + if (nla_put_be32(inst->skb, NFULA_IFINDEX_PHYSOUTDEV, + htonl(outdev->ifindex)) || + /* this is the bridge group "brX" */ + /* rcu_read_lock()ed by nf_hook_thresh or + * nf_log_packet. + */ + nla_put_be32(inst->skb, NFULA_IFINDEX_OUTDEV, + htonl(br_port_get_rcu(outdev)->br->dev->ifindex))) + goto nla_put_failure; + } else { + struct net_device *physoutdev; + + /* Case 2: indev is a bridge group, we need to look + * for physical device (when called from ipv4) */ + if (nla_put_be32(inst->skb, NFULA_IFINDEX_OUTDEV, + htonl(outdev->ifindex))) + goto nla_put_failure; + + physoutdev = nf_bridge_get_physoutdev(skb); + if (physoutdev && + nla_put_be32(inst->skb, NFULA_IFINDEX_PHYSOUTDEV, + htonl(physoutdev->ifindex))) + goto nla_put_failure; + } +#endif + } + + if (skb->mark && + nla_put_be32(inst->skb, NFULA_MARK, htonl(skb->mark))) + goto nla_put_failure; + + if (indev && skb->dev && + skb_mac_header_was_set(skb) && + skb_mac_header_len(skb) != 0) { + struct nfulnl_msg_packet_hw phw; + int len; + + memset(&phw, 0, sizeof(phw)); + len = dev_parse_header(skb, phw.hw_addr); + if (len > 0) { + phw.hw_addrlen = htons(len); + if (nla_put(inst->skb, NFULA_HWADDR, sizeof(phw), &phw)) + goto nla_put_failure; + } + } + + if (indev && skb_mac_header_was_set(skb)) { + if (nla_put_be16(inst->skb, NFULA_HWTYPE, htons(skb->dev->type)) || + nla_put_be16(inst->skb, NFULA_HWLEN, + htons(skb->dev->hard_header_len))) + goto nla_put_failure; + + hwhdrp = skb_mac_header(skb); + + if (skb->dev->type == ARPHRD_SIT) + hwhdrp -= ETH_HLEN; + + if (hwhdrp >= skb->head && + nla_put(inst->skb, NFULA_HWHEADER, + skb->dev->hard_header_len, hwhdrp)) + goto nla_put_failure; + } + + tstamp = skb_tstamp_cond(skb, false); + if (hooknum <= NF_INET_FORWARD && tstamp) { + struct nfulnl_msg_packet_timestamp ts; + struct timespec64 kts = ktime_to_timespec64(tstamp); + ts.sec = cpu_to_be64(kts.tv_sec); + ts.usec = cpu_to_be64(kts.tv_nsec / NSEC_PER_USEC); + + if (nla_put(inst->skb, NFULA_TIMESTAMP, 
sizeof(ts), &ts)) + goto nla_put_failure; + } + + /* UID */ + sk = skb->sk; + if (sk && sk_fullsock(sk)) { + read_lock_bh(&sk->sk_callback_lock); + if (sk->sk_socket && sk->sk_socket->file) { + struct file *file = sk->sk_socket->file; + const struct cred *cred = file->f_cred; + struct user_namespace *user_ns = inst->peer_user_ns; + __be32 uid = htonl(from_kuid_munged(user_ns, cred->fsuid)); + __be32 gid = htonl(from_kgid_munged(user_ns, cred->fsgid)); + read_unlock_bh(&sk->sk_callback_lock); + if (nla_put_be32(inst->skb, NFULA_UID, uid) || + nla_put_be32(inst->skb, NFULA_GID, gid)) + goto nla_put_failure; + } else + read_unlock_bh(&sk->sk_callback_lock); + } + + /* local sequence number */ + if ((inst->flags & NFULNL_CFG_F_SEQ) && + nla_put_be32(inst->skb, NFULA_SEQ, htonl(inst->seq++))) + goto nla_put_failure; + + /* global sequence number */ + if ((inst->flags & NFULNL_CFG_F_SEQ_GLOBAL) && + nla_put_be32(inst->skb, NFULA_SEQ_GLOBAL, + htonl(atomic_inc_return(&log->global_seq)))) + goto nla_put_failure; + + if (ct && nfnl_ct->build(inst->skb, ct, ctinfo, + NFULA_CT, NFULA_CT_INFO) < 0) + goto nla_put_failure; + + if ((pf == NFPROTO_NETDEV || pf == NFPROTO_BRIDGE) && + nfulnl_put_bridge(inst, skb) < 0) + goto nla_put_failure; + + if (data_len) { + struct nlattr *nla; + int size = nla_attr_size(data_len); + + if (skb_tailroom(inst->skb) < nla_total_size(data_len)) + goto nla_put_failure; + + nla = skb_put(inst->skb, nla_total_size(data_len)); + nla->nla_type = NFULA_PAYLOAD; + nla->nla_len = size; + + if (skb_copy_bits(skb, 0, nla_data(nla), data_len)) + BUG(); + } + + nlh->nlmsg_len = inst->skb->tail - old_tail; + return 0; + +nla_put_failure: + PRINTR(KERN_ERR "nfnetlink_log: error creating log nlmsg\n"); + return -1; +} + +static const struct nf_loginfo default_loginfo = { + .type = NF_LOG_TYPE_ULOG, + .u = { + .ulog = { + .copy_len = 0xffff, + .group = 0, + .qthreshold = 1, + }, + }, +}; + +/* log handler for internal netfilter logging api */ +static void +nfulnl_log_packet(struct net *net, + u_int8_t pf, + unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const struct nf_loginfo *li_user, + const char *prefix) +{ + size_t size; + unsigned int data_len; + struct nfulnl_instance *inst; + const struct nf_loginfo *li; + unsigned int qthreshold; + unsigned int plen = 0; + struct nfnl_log_net *log = nfnl_log_pernet(net); + const struct nfnl_ct_hook *nfnl_ct = NULL; + enum ip_conntrack_info ctinfo = 0; + struct nf_conn *ct = NULL; + + if (li_user && li_user->type == NF_LOG_TYPE_ULOG) + li = li_user; + else + li = &default_loginfo; + + inst = instance_lookup_get(log, li->u.ulog.group); + if (!inst) + return; + + if (prefix) + plen = strlen(prefix) + 1; + + /* FIXME: do we want to make the size calculation conditional based on + * what is actually present? way more branches and checks, but more + * memory efficient... 
*/ + size = nlmsg_total_size(sizeof(struct nfgenmsg)) + + nla_total_size(sizeof(struct nfulnl_msg_packet_hdr)) + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ +#endif + + nla_total_size(sizeof(u_int32_t)) /* mark */ + + nla_total_size(sizeof(u_int32_t)) /* uid */ + + nla_total_size(sizeof(u_int32_t)) /* gid */ + + nla_total_size(plen) /* prefix */ + + nla_total_size(sizeof(struct nfulnl_msg_packet_hw)) + + nla_total_size(sizeof(struct nfulnl_msg_packet_timestamp)) + + nla_total_size(sizeof(struct nfgenmsg)); /* NLMSG_DONE */ + + if (in && skb_mac_header_was_set(skb)) { + size += nla_total_size(skb->dev->hard_header_len) + + nla_total_size(sizeof(u_int16_t)) /* hwtype */ + + nla_total_size(sizeof(u_int16_t)); /* hwlen */ + } + + spin_lock_bh(&inst->lock); + + if (inst->flags & NFULNL_CFG_F_SEQ) + size += nla_total_size(sizeof(u_int32_t)); + if (inst->flags & NFULNL_CFG_F_SEQ_GLOBAL) + size += nla_total_size(sizeof(u_int32_t)); +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (inst->flags & NFULNL_CFG_F_CONNTRACK) { + nfnl_ct = rcu_dereference(nfnl_ct_hook); + if (nfnl_ct != NULL) { + ct = nf_ct_get(skb, &ctinfo); + if (ct != NULL) + size += nfnl_ct->build_size(ct); + } + } +#endif + if (pf == NFPROTO_NETDEV || pf == NFPROTO_BRIDGE) + size += nfulnl_get_bridge_size(skb); + + qthreshold = inst->qthreshold; + /* per-rule qthreshold overrides per-instance */ + if (li->u.ulog.qthreshold) + if (qthreshold > li->u.ulog.qthreshold) + qthreshold = li->u.ulog.qthreshold; + + + switch (inst->copy_mode) { + case NFULNL_COPY_META: + case NFULNL_COPY_NONE: + data_len = 0; + break; + + case NFULNL_COPY_PACKET: + data_len = inst->copy_range; + if ((li->u.ulog.flags & NF_LOG_F_COPY_LEN) && + (li->u.ulog.copy_len < data_len)) + data_len = li->u.ulog.copy_len; + + if (data_len > skb->len) + data_len = skb->len; + + size += nla_total_size(data_len); + break; + + case NFULNL_COPY_DISABLED: + default: + goto unlock_and_release; + } + + if (inst->skb && size > skb_tailroom(inst->skb)) { + /* either the queue len is too high or we don't have + * enough room in the skb left. flush to userspace. 
*/ + __nfulnl_flush(inst); + } + + if (!inst->skb) { + inst->skb = nfulnl_alloc_skb(net, inst->peer_portid, + inst->nlbufsiz, size); + if (!inst->skb) + goto alloc_failure; + } + + inst->qlen++; + + __build_packet_message(log, inst, skb, data_len, pf, + hooknum, in, out, prefix, plen, + nfnl_ct, ct, ctinfo); + + if (inst->qlen >= qthreshold) + __nfulnl_flush(inst); + /* timer_pending always called within inst->lock, so there + * is no chance of a race here */ + else if (!timer_pending(&inst->timer)) { + instance_get(inst); + inst->timer.expires = jiffies + (inst->flushtimeout*HZ/100); + add_timer(&inst->timer); + } + +unlock_and_release: + spin_unlock_bh(&inst->lock); + instance_put(inst); + return; + +alloc_failure: + /* FIXME: statistics */ + goto unlock_and_release; +} + +static int +nfulnl_rcv_nl_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct netlink_notify *n = ptr; + struct nfnl_log_net *log = nfnl_log_pernet(n->net); + + if (event == NETLINK_URELEASE && n->protocol == NETLINK_NETFILTER) { + int i; + + /* destroy all instances for this portid */ + spin_lock_bh(&log->instances_lock); + for (i = 0; i < INSTANCE_BUCKETS; i++) { + struct hlist_node *t2; + struct nfulnl_instance *inst; + struct hlist_head *head = &log->instance_table[i]; + + hlist_for_each_entry_safe(inst, t2, head, hlist) { + if (n->portid == inst->peer_portid) + __instance_destroy(inst); + } + } + spin_unlock_bh(&log->instances_lock); + } + return NOTIFY_DONE; +} + +static struct notifier_block nfulnl_rtnl_notifier = { + .notifier_call = nfulnl_rcv_nl_event, +}; + +static int nfulnl_recv_unsupp(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nfula[]) +{ + return -ENOTSUPP; +} + +static struct nf_logger nfulnl_logger __read_mostly = { + .name = "nfnetlink_log", + .type = NF_LOG_TYPE_ULOG, + .logfn = nfulnl_log_packet, + .me = THIS_MODULE, +}; + +static const struct nla_policy nfula_cfg_policy[NFULA_CFG_MAX+1] = { + [NFULA_CFG_CMD] = { .len = sizeof(struct nfulnl_msg_config_cmd) }, + [NFULA_CFG_MODE] = { .len = sizeof(struct nfulnl_msg_config_mode) }, + [NFULA_CFG_TIMEOUT] = { .type = NLA_U32 }, + [NFULA_CFG_QTHRESH] = { .type = NLA_U32 }, + [NFULA_CFG_NLBUFSIZ] = { .type = NLA_U32 }, + [NFULA_CFG_FLAGS] = { .type = NLA_U16 }, +}; + +static int nfulnl_recv_config(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nfula[]) +{ + struct nfnl_log_net *log = nfnl_log_pernet(info->net); + u_int16_t group_num = ntohs(info->nfmsg->res_id); + struct nfulnl_msg_config_cmd *cmd = NULL; + struct nfulnl_instance *inst; + u16 flags = 0; + int ret = 0; + + if (nfula[NFULA_CFG_CMD]) { + u_int8_t pf = info->nfmsg->nfgen_family; + cmd = nla_data(nfula[NFULA_CFG_CMD]); + + /* Commands without queue context */ + switch (cmd->command) { + case NFULNL_CFG_CMD_PF_BIND: + return nf_log_bind_pf(info->net, pf, &nfulnl_logger); + case NFULNL_CFG_CMD_PF_UNBIND: + nf_log_unbind_pf(info->net, pf); + return 0; + } + } + + inst = instance_lookup_get(log, group_num); + if (inst && inst->peer_portid != NETLINK_CB(skb).portid) { + ret = -EPERM; + goto out_put; + } + + /* Check if we support these flags in first place, dependencies should + * be there too not to break atomicity. 
+ */ + if (nfula[NFULA_CFG_FLAGS]) { + flags = ntohs(nla_get_be16(nfula[NFULA_CFG_FLAGS])); + + if ((flags & NFULNL_CFG_F_CONNTRACK) && + !rcu_access_pointer(nfnl_ct_hook)) { +#ifdef CONFIG_MODULES + nfnl_unlock(NFNL_SUBSYS_ULOG); + request_module("ip_conntrack_netlink"); + nfnl_lock(NFNL_SUBSYS_ULOG); + if (rcu_access_pointer(nfnl_ct_hook)) { + ret = -EAGAIN; + goto out_put; + } +#endif + ret = -EOPNOTSUPP; + goto out_put; + } + } + + if (cmd != NULL) { + switch (cmd->command) { + case NFULNL_CFG_CMD_BIND: + if (inst) { + ret = -EBUSY; + goto out_put; + } + + inst = instance_create(info->net, group_num, + NETLINK_CB(skb).portid, + sk_user_ns(NETLINK_CB(skb).sk)); + if (IS_ERR(inst)) { + ret = PTR_ERR(inst); + goto out; + } + break; + case NFULNL_CFG_CMD_UNBIND: + if (!inst) { + ret = -ENODEV; + goto out; + } + + instance_destroy(log, inst); + goto out_put; + default: + ret = -ENOTSUPP; + goto out_put; + } + } else if (!inst) { + ret = -ENODEV; + goto out; + } + + if (nfula[NFULA_CFG_MODE]) { + struct nfulnl_msg_config_mode *params = + nla_data(nfula[NFULA_CFG_MODE]); + + nfulnl_set_mode(inst, params->copy_mode, + ntohl(params->copy_range)); + } + + if (nfula[NFULA_CFG_TIMEOUT]) { + __be32 timeout = nla_get_be32(nfula[NFULA_CFG_TIMEOUT]); + + nfulnl_set_timeout(inst, ntohl(timeout)); + } + + if (nfula[NFULA_CFG_NLBUFSIZ]) { + __be32 nlbufsiz = nla_get_be32(nfula[NFULA_CFG_NLBUFSIZ]); + + nfulnl_set_nlbufsiz(inst, ntohl(nlbufsiz)); + } + + if (nfula[NFULA_CFG_QTHRESH]) { + __be32 qthresh = nla_get_be32(nfula[NFULA_CFG_QTHRESH]); + + nfulnl_set_qthresh(inst, ntohl(qthresh)); + } + + if (nfula[NFULA_CFG_FLAGS]) + nfulnl_set_flags(inst, flags); + +out_put: + instance_put(inst); +out: + return ret; +} + +static const struct nfnl_callback nfulnl_cb[NFULNL_MSG_MAX] = { + [NFULNL_MSG_PACKET] = { + .call = nfulnl_recv_unsupp, + .type = NFNL_CB_MUTEX, + .attr_count = NFULA_MAX, + }, + [NFULNL_MSG_CONFIG] = { + .call = nfulnl_recv_config, + .type = NFNL_CB_MUTEX, + .attr_count = NFULA_CFG_MAX, + .policy = nfula_cfg_policy + }, +}; + +static const struct nfnetlink_subsystem nfulnl_subsys = { + .name = "log", + .subsys_id = NFNL_SUBSYS_ULOG, + .cb_count = NFULNL_MSG_MAX, + .cb = nfulnl_cb, +}; + +#ifdef CONFIG_PROC_FS +struct iter_state { + struct seq_net_private p; + unsigned int bucket; +}; + +static struct hlist_node *get_first(struct net *net, struct iter_state *st) +{ + struct nfnl_log_net *log; + if (!st) + return NULL; + + log = nfnl_log_pernet(net); + + for (st->bucket = 0; st->bucket < INSTANCE_BUCKETS; st->bucket++) { + struct hlist_head *head = &log->instance_table[st->bucket]; + + if (!hlist_empty(head)) + return rcu_dereference_bh(hlist_first_rcu(head)); + } + return NULL; +} + +static struct hlist_node *get_next(struct net *net, struct iter_state *st, + struct hlist_node *h) +{ + h = rcu_dereference_bh(hlist_next_rcu(h)); + while (!h) { + struct nfnl_log_net *log; + struct hlist_head *head; + + if (++st->bucket >= INSTANCE_BUCKETS) + return NULL; + + log = nfnl_log_pernet(net); + head = &log->instance_table[st->bucket]; + h = rcu_dereference_bh(hlist_first_rcu(head)); + } + return h; +} + +static struct hlist_node *get_idx(struct net *net, struct iter_state *st, + loff_t pos) +{ + struct hlist_node *head; + head = get_first(net, st); + + if (head) + while (pos && (head = get_next(net, st, head))) + pos--; + return pos ? 
NULL : head; +} + +static void *seq_start(struct seq_file *s, loff_t *pos) + __acquires(rcu_bh) +{ + rcu_read_lock_bh(); + return get_idx(seq_file_net(s), s->private, *pos); +} + +static void *seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + (*pos)++; + return get_next(seq_file_net(s), s->private, v); +} + +static void seq_stop(struct seq_file *s, void *v) + __releases(rcu_bh) +{ + rcu_read_unlock_bh(); +} + +static int seq_show(struct seq_file *s, void *v) +{ + const struct nfulnl_instance *inst = v; + + seq_printf(s, "%5u %6u %5u %1u %5u %6u %2u\n", + inst->group_num, + inst->peer_portid, inst->qlen, + inst->copy_mode, inst->copy_range, + inst->flushtimeout, refcount_read(&inst->use)); + + return 0; +} + +static const struct seq_operations nful_seq_ops = { + .start = seq_start, + .next = seq_next, + .stop = seq_stop, + .show = seq_show, +}; +#endif /* PROC_FS */ + +static int __net_init nfnl_log_net_init(struct net *net) +{ + unsigned int i; + struct nfnl_log_net *log = nfnl_log_pernet(net); +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *proc; + kuid_t root_uid; + kgid_t root_gid; +#endif + + for (i = 0; i < INSTANCE_BUCKETS; i++) + INIT_HLIST_HEAD(&log->instance_table[i]); + spin_lock_init(&log->instances_lock); + +#ifdef CONFIG_PROC_FS + proc = proc_create_net("nfnetlink_log", 0440, net->nf.proc_netfilter, + &nful_seq_ops, sizeof(struct iter_state)); + if (!proc) + return -ENOMEM; + + root_uid = make_kuid(net->user_ns, 0); + root_gid = make_kgid(net->user_ns, 0); + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(proc, root_uid, root_gid); +#endif + return 0; +} + +static void __net_exit nfnl_log_net_exit(struct net *net) +{ + struct nfnl_log_net *log = nfnl_log_pernet(net); + unsigned int i; + +#ifdef CONFIG_PROC_FS + remove_proc_entry("nfnetlink_log", net->nf.proc_netfilter); +#endif + nf_log_unset(net, &nfulnl_logger); + for (i = 0; i < INSTANCE_BUCKETS; i++) + WARN_ON_ONCE(!hlist_empty(&log->instance_table[i])); +} + +static struct pernet_operations nfnl_log_net_ops = { + .init = nfnl_log_net_init, + .exit = nfnl_log_net_exit, + .id = &nfnl_log_net_id, + .size = sizeof(struct nfnl_log_net), +}; + +static int __init nfnetlink_log_init(void) +{ + int status; + + status = register_pernet_subsys(&nfnl_log_net_ops); + if (status < 0) { + pr_err("failed to register pernet ops\n"); + goto out; + } + + netlink_register_notifier(&nfulnl_rtnl_notifier); + status = nfnetlink_subsys_register(&nfulnl_subsys); + if (status < 0) { + pr_err("failed to create netlink socket\n"); + goto cleanup_netlink_notifier; + } + + status = nf_log_register(NFPROTO_UNSPEC, &nfulnl_logger); + if (status < 0) { + pr_err("failed to register logger\n"); + goto cleanup_subsys; + } + + return status; + +cleanup_subsys: + nfnetlink_subsys_unregister(&nfulnl_subsys); +cleanup_netlink_notifier: + netlink_unregister_notifier(&nfulnl_rtnl_notifier); + unregister_pernet_subsys(&nfnl_log_net_ops); +out: + return status; +} + +static void __exit nfnetlink_log_fini(void) +{ + nfnetlink_subsys_unregister(&nfulnl_subsys); + netlink_unregister_notifier(&nfulnl_rtnl_notifier); + unregister_pernet_subsys(&nfnl_log_net_ops); + nf_log_unregister(&nfulnl_logger); +} + +MODULE_DESCRIPTION("netfilter userspace logging"); +MODULE_AUTHOR("Harald Welte "); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_ULOG); +MODULE_ALIAS_NF_LOGGER(AF_INET, 1); +MODULE_ALIAS_NF_LOGGER(AF_INET6, 1); +MODULE_ALIAS_NF_LOGGER(AF_BRIDGE, 1); +MODULE_ALIAS_NF_LOGGER(3, 1); /* NFPROTO_ARP */ +MODULE_ALIAS_NF_LOGGER(5, 1); /* 
NFPROTO_NETDEV */ + +module_init(nfnetlink_log_init); +module_exit(nfnetlink_log_fini); diff --git a/net/netfilter/nfnetlink_osf.c b/net/netfilter/nfnetlink_osf.c new file mode 100644 index 000000000..50723ba08 --- /dev/null +++ b/net/netfilter/nfnetlink_osf.c @@ -0,0 +1,450 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +/* + * Indexed by dont-fragment bit. + * It is the only constant value in the fingerprint. + */ +struct list_head nf_osf_fingers[2]; +EXPORT_SYMBOL_GPL(nf_osf_fingers); + +static inline int nf_osf_ttl(const struct sk_buff *skb, + int ttl_check, unsigned char f_ttl) +{ + struct in_device *in_dev = __in_dev_get_rcu(skb->dev); + const struct iphdr *ip = ip_hdr(skb); + const struct in_ifaddr *ifa; + int ret = 0; + + if (ttl_check == NF_OSF_TTL_TRUE) + return ip->ttl == f_ttl; + if (ttl_check == NF_OSF_TTL_NOCHECK) + return 1; + else if (ip->ttl <= f_ttl) + return 1; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (inet_ifa_match(ip->saddr, ifa)) { + ret = (ip->ttl == f_ttl); + break; + } + } + + return ret; +} + +struct nf_osf_hdr_ctx { + bool df; + u16 window; + u16 totlen; + const unsigned char *optp; + unsigned int optsize; +}; + +static bool nf_osf_match_one(const struct sk_buff *skb, + const struct nf_osf_user_finger *f, + int ttl_check, + struct nf_osf_hdr_ctx *ctx) +{ + const __u8 *optpinit = ctx->optp; + unsigned int check_WSS = 0; + int fmatch = FMATCH_WRONG; + int foptsize, optnum; + u16 mss = 0; + + if (ctx->totlen != f->ss || !nf_osf_ttl(skb, ttl_check, f->ttl)) + return false; + + /* + * Should not happen if userspace parser was written correctly. + */ + if (f->wss.wc >= OSF_WSS_MAX) + return false; + + /* Check options */ + + foptsize = 0; + for (optnum = 0; optnum < f->opt_num; ++optnum) + foptsize += f->opt[optnum].length; + + if (foptsize > MAX_IPOPTLEN || + ctx->optsize > MAX_IPOPTLEN || + ctx->optsize != foptsize) + return false; + + check_WSS = f->wss.wc; + + for (optnum = 0; optnum < f->opt_num; ++optnum) { + if (f->opt[optnum].kind == *ctx->optp) { + __u32 len = f->opt[optnum].length; + const __u8 *optend = ctx->optp + len; + + fmatch = FMATCH_OK; + + switch (*ctx->optp) { + case OSFOPT_MSS: + mss = ctx->optp[3]; + mss <<= 8; + mss |= ctx->optp[2]; + + mss = ntohs((__force __be16)mss); + break; + case OSFOPT_TS: + break; + } + + ctx->optp = optend; + } else + fmatch = FMATCH_OPT_WRONG; + + if (fmatch != FMATCH_OK) + break; + } + + if (fmatch != FMATCH_OPT_WRONG) { + fmatch = FMATCH_WRONG; + + switch (check_WSS) { + case OSF_WSS_PLAIN: + if (f->wss.val == 0 || ctx->window == f->wss.val) + fmatch = FMATCH_OK; + break; + case OSF_WSS_MSS: + /* + * Some smart modems decrease mangle MSS to + * SMART_MSS_2, so we check standard, decreased + * and the one provided in the fingerprint MSS + * values. 
+ */ +#define SMART_MSS_1 1460 +#define SMART_MSS_2 1448 + if (ctx->window == f->wss.val * mss || + ctx->window == f->wss.val * SMART_MSS_1 || + ctx->window == f->wss.val * SMART_MSS_2) + fmatch = FMATCH_OK; + break; + case OSF_WSS_MTU: + if (ctx->window == f->wss.val * (mss + 40) || + ctx->window == f->wss.val * (SMART_MSS_1 + 40) || + ctx->window == f->wss.val * (SMART_MSS_2 + 40)) + fmatch = FMATCH_OK; + break; + case OSF_WSS_MODULO: + if ((ctx->window % f->wss.val) == 0) + fmatch = FMATCH_OK; + break; + } + } + + if (fmatch != FMATCH_OK) + ctx->optp = optpinit; + + return fmatch == FMATCH_OK; +} + +static const struct tcphdr *nf_osf_hdr_ctx_init(struct nf_osf_hdr_ctx *ctx, + const struct sk_buff *skb, + const struct iphdr *ip, + unsigned char *opts, + struct tcphdr *_tcph) +{ + const struct tcphdr *tcp; + + tcp = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(struct tcphdr), _tcph); + if (!tcp) + return NULL; + + if (!tcp->syn) + return NULL; + + ctx->totlen = ntohs(ip->tot_len); + ctx->df = ntohs(ip->frag_off) & IP_DF; + ctx->window = ntohs(tcp->window); + + if (tcp->doff * 4 > sizeof(struct tcphdr)) { + ctx->optsize = tcp->doff * 4 - sizeof(struct tcphdr); + + ctx->optp = skb_header_pointer(skb, ip_hdrlen(skb) + + sizeof(struct tcphdr), ctx->optsize, opts); + if (!ctx->optp) + return NULL; + } + + return tcp; +} + +bool +nf_osf_match(const struct sk_buff *skb, u_int8_t family, + int hooknum, struct net_device *in, struct net_device *out, + const struct nf_osf_info *info, struct net *net, + const struct list_head *nf_osf_fingers) +{ + const struct iphdr *ip = ip_hdr(skb); + const struct nf_osf_user_finger *f; + unsigned char opts[MAX_IPOPTLEN]; + const struct nf_osf_finger *kf; + int fcount = 0, ttl_check; + int fmatch = FMATCH_WRONG; + struct nf_osf_hdr_ctx ctx; + const struct tcphdr *tcp; + struct tcphdr _tcph; + + memset(&ctx, 0, sizeof(ctx)); + + tcp = nf_osf_hdr_ctx_init(&ctx, skb, ip, opts, &_tcph); + if (!tcp) + return false; + + ttl_check = (info->flags & NF_OSF_TTL) ? 
info->ttl : 0; + + list_for_each_entry_rcu(kf, &nf_osf_fingers[ctx.df], finger_entry) { + + f = &kf->finger; + + if (!(info->flags & NF_OSF_LOG) && strcmp(info->genre, f->genre)) + continue; + + if (!nf_osf_match_one(skb, f, ttl_check, &ctx)) + continue; + + fmatch = FMATCH_OK; + + fcount++; + + if (info->flags & NF_OSF_LOG) + nf_log_packet(net, family, hooknum, skb, + in, out, NULL, + "%s [%s:%s] : %pI4:%d -> %pI4:%d hops=%d\n", + f->genre, f->version, f->subtype, + &ip->saddr, ntohs(tcp->source), + &ip->daddr, ntohs(tcp->dest), + f->ttl - ip->ttl); + + if ((info->flags & NF_OSF_LOG) && + info->loglevel == NF_OSF_LOGLEVEL_FIRST) + break; + } + + if (!fcount && (info->flags & NF_OSF_LOG)) + nf_log_packet(net, family, hooknum, skb, in, out, NULL, + "Remote OS is not known: %pI4:%u -> %pI4:%u\n", + &ip->saddr, ntohs(tcp->source), + &ip->daddr, ntohs(tcp->dest)); + + if (fcount) + fmatch = FMATCH_OK; + + return fmatch == FMATCH_OK; +} +EXPORT_SYMBOL_GPL(nf_osf_match); + +bool nf_osf_find(const struct sk_buff *skb, + const struct list_head *nf_osf_fingers, + const int ttl_check, struct nf_osf_data *data) +{ + const struct iphdr *ip = ip_hdr(skb); + const struct nf_osf_user_finger *f; + unsigned char opts[MAX_IPOPTLEN]; + const struct nf_osf_finger *kf; + struct nf_osf_hdr_ctx ctx; + const struct tcphdr *tcp; + struct tcphdr _tcph; + bool found = false; + + memset(&ctx, 0, sizeof(ctx)); + + tcp = nf_osf_hdr_ctx_init(&ctx, skb, ip, opts, &_tcph); + if (!tcp) + return false; + + list_for_each_entry_rcu(kf, &nf_osf_fingers[ctx.df], finger_entry) { + f = &kf->finger; + if (!nf_osf_match_one(skb, f, ttl_check, &ctx)) + continue; + + data->genre = f->genre; + data->version = f->version; + found = true; + break; + } + + return found; +} +EXPORT_SYMBOL_GPL(nf_osf_find); + +static const struct nla_policy nfnl_osf_policy[OSF_ATTR_MAX + 1] = { + [OSF_ATTR_FINGER] = { .len = sizeof(struct nf_osf_user_finger) }, +}; + +static int nfnl_osf_add_callback(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const osf_attrs[]) +{ + struct nf_osf_user_finger *f; + struct nf_osf_finger *kf = NULL, *sf; + int err = 0; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (!osf_attrs[OSF_ATTR_FINGER]) + return -EINVAL; + + if (!(info->nlh->nlmsg_flags & NLM_F_CREATE)) + return -EINVAL; + + f = nla_data(osf_attrs[OSF_ATTR_FINGER]); + + if (f->opt_num > ARRAY_SIZE(f->opt)) + return -EINVAL; + + if (!memchr(f->genre, 0, MAXGENRELEN) || + !memchr(f->subtype, 0, MAXGENRELEN) || + !memchr(f->version, 0, MAXGENRELEN)) + return -EINVAL; + + kf = kmalloc(sizeof(struct nf_osf_finger), GFP_KERNEL); + if (!kf) + return -ENOMEM; + + memcpy(&kf->finger, f, sizeof(struct nf_osf_user_finger)); + + list_for_each_entry(sf, &nf_osf_fingers[!!f->df], finger_entry) { + if (memcmp(&sf->finger, f, sizeof(struct nf_osf_user_finger))) + continue; + + kfree(kf); + kf = NULL; + + if (info->nlh->nlmsg_flags & NLM_F_EXCL) + err = -EEXIST; + break; + } + + /* + * We are protected by nfnl mutex. 
+ */ + if (kf) + list_add_tail_rcu(&kf->finger_entry, &nf_osf_fingers[!!f->df]); + + return err; +} + +static int nfnl_osf_remove_callback(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const osf_attrs[]) +{ + struct nf_osf_user_finger *f; + struct nf_osf_finger *sf; + int err = -ENOENT; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + if (!osf_attrs[OSF_ATTR_FINGER]) + return -EINVAL; + + f = nla_data(osf_attrs[OSF_ATTR_FINGER]); + + list_for_each_entry(sf, &nf_osf_fingers[!!f->df], finger_entry) { + if (memcmp(&sf->finger, f, sizeof(struct nf_osf_user_finger))) + continue; + + /* + * We are protected by nfnl mutex. + */ + list_del_rcu(&sf->finger_entry); + kfree_rcu(sf, rcu_head); + + err = 0; + break; + } + + return err; +} + +static const struct nfnl_callback nfnl_osf_callbacks[OSF_MSG_MAX] = { + [OSF_MSG_ADD] = { + .call = nfnl_osf_add_callback, + .type = NFNL_CB_MUTEX, + .attr_count = OSF_ATTR_MAX, + .policy = nfnl_osf_policy, + }, + [OSF_MSG_REMOVE] = { + .call = nfnl_osf_remove_callback, + .type = NFNL_CB_MUTEX, + .attr_count = OSF_ATTR_MAX, + .policy = nfnl_osf_policy, + }, +}; + +static const struct nfnetlink_subsystem nfnl_osf_subsys = { + .name = "osf", + .subsys_id = NFNL_SUBSYS_OSF, + .cb_count = OSF_MSG_MAX, + .cb = nfnl_osf_callbacks, +}; + +static int __init nfnl_osf_init(void) +{ + int err = -EINVAL; + int i; + + for (i = 0; i < ARRAY_SIZE(nf_osf_fingers); ++i) + INIT_LIST_HEAD(&nf_osf_fingers[i]); + + err = nfnetlink_subsys_register(&nfnl_osf_subsys); + if (err < 0) { + pr_err("Failed to register OSF nsfnetlink helper (%d)\n", err); + goto err_out_exit; + } + return 0; + +err_out_exit: + return err; +} + +static void __exit nfnl_osf_fini(void) +{ + struct nf_osf_finger *f; + int i; + + nfnetlink_subsys_unregister(&nfnl_osf_subsys); + + rcu_read_lock(); + for (i = 0; i < ARRAY_SIZE(nf_osf_fingers); ++i) { + list_for_each_entry_rcu(f, &nf_osf_fingers[i], finger_entry) { + list_del_rcu(&f->finger_entry); + kfree_rcu(f, rcu_head); + } + } + rcu_read_unlock(); + + rcu_barrier(); +} + +module_init(nfnl_osf_init); +module_exit(nfnl_osf_fini); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_OSF); diff --git a/net/netfilter/nfnetlink_queue.c b/net/netfilter/nfnetlink_queue.c new file mode 100644 index 000000000..87a9009d5 --- /dev/null +++ b/net/netfilter/nfnetlink_queue.c @@ -0,0 +1,1617 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for queueing packets and communicating with + * userspace via nfnetlink. + * + * (C) 2005 by Harald Welte + * (C) 2007 by Patrick McHardy + * + * Based on the old ipv4-only ip_queue.c: + * (C) 2000-2002 James Morris + * (C) 2003-2005 Netfilter Core Team + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +#include "../bridge/br_private.h" +#endif + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +#define NFQNL_QMAX_DEFAULT 1024 + +/* We're using struct nlattr which has 16bit nla_len. Note that nla_len + * includes the header length. Thus, the maximum packet length that we + * support is 65531 bytes. We send truncated packets if the specified length + * is larger than that. Userspace can check for presence of NFQA_CAP_LEN + * attribute to detect truncation. 
+ */ +#define NFQNL_MAX_COPY_RANGE (0xffff - NLA_HDRLEN) + +struct nfqnl_instance { + struct hlist_node hlist; /* global list of queues */ + struct rcu_head rcu; + + u32 peer_portid; + unsigned int queue_maxlen; + unsigned int copy_range; + unsigned int queue_dropped; + unsigned int queue_user_dropped; + + + u_int16_t queue_num; /* number of this queue */ + u_int8_t copy_mode; + u_int32_t flags; /* Set using NFQA_CFG_FLAGS */ +/* + * Following fields are dirtied for each queued packet, + * keep them in same cache line if possible. + */ + spinlock_t lock ____cacheline_aligned_in_smp; + unsigned int queue_total; + unsigned int id_sequence; /* 'sequence' of pkt ids */ + struct list_head queue_list; /* packets in queue */ +}; + +typedef int (*nfqnl_cmpfn)(struct nf_queue_entry *, unsigned long); + +static unsigned int nfnl_queue_net_id __read_mostly; + +#define INSTANCE_BUCKETS 16 +struct nfnl_queue_net { + spinlock_t instances_lock; + struct hlist_head instance_table[INSTANCE_BUCKETS]; +}; + +static struct nfnl_queue_net *nfnl_queue_pernet(struct net *net) +{ + return net_generic(net, nfnl_queue_net_id); +} + +static inline u_int8_t instance_hashfn(u_int16_t queue_num) +{ + return ((queue_num >> 8) ^ queue_num) % INSTANCE_BUCKETS; +} + +static struct nfqnl_instance * +instance_lookup(struct nfnl_queue_net *q, u_int16_t queue_num) +{ + struct hlist_head *head; + struct nfqnl_instance *inst; + + head = &q->instance_table[instance_hashfn(queue_num)]; + hlist_for_each_entry_rcu(inst, head, hlist) { + if (inst->queue_num == queue_num) + return inst; + } + return NULL; +} + +static struct nfqnl_instance * +instance_create(struct nfnl_queue_net *q, u_int16_t queue_num, u32 portid) +{ + struct nfqnl_instance *inst; + unsigned int h; + int err; + + spin_lock(&q->instances_lock); + if (instance_lookup(q, queue_num)) { + err = -EEXIST; + goto out_unlock; + } + + inst = kzalloc(sizeof(*inst), GFP_ATOMIC); + if (!inst) { + err = -ENOMEM; + goto out_unlock; + } + + inst->queue_num = queue_num; + inst->peer_portid = portid; + inst->queue_maxlen = NFQNL_QMAX_DEFAULT; + inst->copy_range = NFQNL_MAX_COPY_RANGE; + inst->copy_mode = NFQNL_COPY_NONE; + spin_lock_init(&inst->lock); + INIT_LIST_HEAD(&inst->queue_list); + + if (!try_module_get(THIS_MODULE)) { + err = -EAGAIN; + goto out_free; + } + + h = instance_hashfn(queue_num); + hlist_add_head_rcu(&inst->hlist, &q->instance_table[h]); + + spin_unlock(&q->instances_lock); + + return inst; + +out_free: + kfree(inst); +out_unlock: + spin_unlock(&q->instances_lock); + return ERR_PTR(err); +} + +static void nfqnl_flush(struct nfqnl_instance *queue, nfqnl_cmpfn cmpfn, + unsigned long data); + +static void +instance_destroy_rcu(struct rcu_head *head) +{ + struct nfqnl_instance *inst = container_of(head, struct nfqnl_instance, + rcu); + + nfqnl_flush(inst, NULL, 0); + kfree(inst); + module_put(THIS_MODULE); +} + +static void +__instance_destroy(struct nfqnl_instance *inst) +{ + hlist_del_rcu(&inst->hlist); + call_rcu(&inst->rcu, instance_destroy_rcu); +} + +static void +instance_destroy(struct nfnl_queue_net *q, struct nfqnl_instance *inst) +{ + spin_lock(&q->instances_lock); + __instance_destroy(inst); + spin_unlock(&q->instances_lock); +} + +static inline void +__enqueue_entry(struct nfqnl_instance *queue, struct nf_queue_entry *entry) +{ + list_add_tail(&entry->list, &queue->queue_list); + queue->queue_total++; +} + +static void +__dequeue_entry(struct nfqnl_instance *queue, struct nf_queue_entry *entry) +{ + list_del(&entry->list); + queue->queue_total--; +} + 
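+
+/* Illustrative sketch only, not part of the queue module itself: as the
+ * NFQNL_MAX_COPY_RANGE comment above notes, userspace can detect that a
+ * packet was truncated by comparing NFQA_CAP_LEN (the original length)
+ * with the NFQA_PAYLOAD data it actually received.  Assuming libmnl-style
+ * attribute parsing into an attr[] array (as used by the libnetfilter_queue
+ * examples; the helper name below is hypothetical):
+ *
+ *	if (attr[NFQA_CAP_LEN]) {
+ *		uint32_t orig = ntohl(mnl_attr_get_u32(attr[NFQA_CAP_LEN]));
+ *		uint16_t got  = mnl_attr_get_payload_len(attr[NFQA_PAYLOAD]);
+ *
+ *		if (orig > got)
+ *			handle_truncated_packet();
+ *	}
+ */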
+static struct nf_queue_entry * +find_dequeue_entry(struct nfqnl_instance *queue, unsigned int id) +{ + struct nf_queue_entry *entry = NULL, *i; + + spin_lock_bh(&queue->lock); + + list_for_each_entry(i, &queue->queue_list, list) { + if (i->id == id) { + entry = i; + break; + } + } + + if (entry) + __dequeue_entry(queue, entry); + + spin_unlock_bh(&queue->lock); + + return entry; +} + +static void nfqnl_reinject(struct nf_queue_entry *entry, unsigned int verdict) +{ + const struct nf_ct_hook *ct_hook; + int err; + + if (verdict == NF_ACCEPT || + verdict == NF_REPEAT || + verdict == NF_STOP) { + rcu_read_lock(); + ct_hook = rcu_dereference(nf_ct_hook); + if (ct_hook) { + err = ct_hook->update(entry->state.net, entry->skb); + if (err < 0) + verdict = NF_DROP; + } + rcu_read_unlock(); + } + nf_reinject(entry, verdict); +} + +static void +nfqnl_flush(struct nfqnl_instance *queue, nfqnl_cmpfn cmpfn, unsigned long data) +{ + struct nf_queue_entry *entry, *next; + + spin_lock_bh(&queue->lock); + list_for_each_entry_safe(entry, next, &queue->queue_list, list) { + if (!cmpfn || cmpfn(entry, data)) { + list_del(&entry->list); + queue->queue_total--; + nfqnl_reinject(entry, NF_DROP); + } + } + spin_unlock_bh(&queue->lock); +} + +static int +nfqnl_put_packet_info(struct sk_buff *nlskb, struct sk_buff *packet, + bool csum_verify) +{ + __u32 flags = 0; + + if (packet->ip_summed == CHECKSUM_PARTIAL) + flags = NFQA_SKB_CSUMNOTREADY; + else if (csum_verify) + flags = NFQA_SKB_CSUM_NOTVERIFIED; + + if (skb_is_gso(packet)) + flags |= NFQA_SKB_GSO; + + return flags ? nla_put_be32(nlskb, NFQA_SKB_INFO, htonl(flags)) : 0; +} + +static int nfqnl_put_sk_uidgid(struct sk_buff *skb, struct sock *sk) +{ + const struct cred *cred; + + if (!sk_fullsock(sk)) + return 0; + + read_lock_bh(&sk->sk_callback_lock); + if (sk->sk_socket && sk->sk_socket->file) { + cred = sk->sk_socket->file->f_cred; + if (nla_put_be32(skb, NFQA_UID, + htonl(from_kuid_munged(&init_user_ns, cred->fsuid)))) + goto nla_put_failure; + if (nla_put_be32(skb, NFQA_GID, + htonl(from_kgid_munged(&init_user_ns, cred->fsgid)))) + goto nla_put_failure; + } + read_unlock_bh(&sk->sk_callback_lock); + return 0; + +nla_put_failure: + read_unlock_bh(&sk->sk_callback_lock); + return -1; +} + +static u32 nfqnl_get_sk_secctx(struct sk_buff *skb, char **secdata) +{ + u32 seclen = 0; +#if IS_ENABLED(CONFIG_NETWORK_SECMARK) + if (!skb || !sk_fullsock(skb->sk)) + return 0; + + read_lock_bh(&skb->sk->sk_callback_lock); + + if (skb->secmark) + security_secid_to_secctx(skb->secmark, secdata, &seclen); + + read_unlock_bh(&skb->sk->sk_callback_lock); +#endif + return seclen; +} + +static u32 nfqnl_get_bridge_size(struct nf_queue_entry *entry) +{ + struct sk_buff *entskb = entry->skb; + u32 nlalen = 0; + + if (entry->state.pf != PF_BRIDGE || !skb_mac_header_was_set(entskb)) + return 0; + + if (skb_vlan_tag_present(entskb)) + nlalen += nla_total_size(nla_total_size(sizeof(__be16)) + + nla_total_size(sizeof(__be16))); + + if (entskb->network_header > entskb->mac_header) + nlalen += nla_total_size((entskb->network_header - + entskb->mac_header)); + + return nlalen; +} + +static int nfqnl_put_bridge(struct nf_queue_entry *entry, struct sk_buff *skb) +{ + struct sk_buff *entskb = entry->skb; + + if (entry->state.pf != PF_BRIDGE || !skb_mac_header_was_set(entskb)) + return 0; + + if (skb_vlan_tag_present(entskb)) { + struct nlattr *nest; + + nest = nla_nest_start(skb, NFQA_VLAN); + if (!nest) + goto nla_put_failure; + + if (nla_put_be16(skb, NFQA_VLAN_TCI, 
htons(entskb->vlan_tci)) || + nla_put_be16(skb, NFQA_VLAN_PROTO, entskb->vlan_proto)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + } + + if (entskb->mac_header < entskb->network_header) { + int len = (int)(entskb->network_header - entskb->mac_header); + + if (nla_put(skb, NFQA_L2HDR, len, skb_mac_header(entskb))) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -1; +} + +static struct sk_buff * +nfqnl_build_packet_message(struct net *net, struct nfqnl_instance *queue, + struct nf_queue_entry *entry, + __be32 **packet_id_ptr) +{ + size_t size; + size_t data_len = 0, cap_len = 0; + unsigned int hlen = 0; + struct sk_buff *skb; + struct nlattr *nla; + struct nfqnl_msg_packet_hdr *pmsg; + struct nlmsghdr *nlh; + struct sk_buff *entskb = entry->skb; + struct net_device *indev; + struct net_device *outdev; + struct nf_conn *ct = NULL; + enum ip_conntrack_info ctinfo = 0; + const struct nfnl_ct_hook *nfnl_ct; + bool csum_verify; + char *secdata = NULL; + u32 seclen = 0; + ktime_t tstamp; + + size = nlmsg_total_size(sizeof(struct nfgenmsg)) + + nla_total_size(sizeof(struct nfqnl_msg_packet_hdr)) + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ + + nla_total_size(sizeof(u_int32_t)) /* ifindex */ +#endif + + nla_total_size(sizeof(u_int32_t)) /* mark */ + + nla_total_size(sizeof(u_int32_t)) /* priority */ + + nla_total_size(sizeof(struct nfqnl_msg_packet_hw)) + + nla_total_size(sizeof(u_int32_t)) /* skbinfo */ + + nla_total_size(sizeof(u_int32_t)); /* cap_len */ + + tstamp = skb_tstamp_cond(entskb, false); + if (tstamp) + size += nla_total_size(sizeof(struct nfqnl_msg_packet_timestamp)); + + size += nfqnl_get_bridge_size(entry); + + if (entry->state.hook <= NF_INET_FORWARD || + (entry->state.hook == NF_INET_POST_ROUTING && entskb->sk == NULL)) + csum_verify = !skb_csum_unnecessary(entskb); + else + csum_verify = false; + + outdev = entry->state.out; + + switch ((enum nfqnl_config_mode)READ_ONCE(queue->copy_mode)) { + case NFQNL_COPY_META: + case NFQNL_COPY_NONE: + break; + + case NFQNL_COPY_PACKET: + if (!(queue->flags & NFQA_CFG_F_GSO) && + entskb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_help(entskb)) + return NULL; + + data_len = READ_ONCE(queue->copy_range); + if (data_len > entskb->len) + data_len = entskb->len; + + hlen = skb_zerocopy_headlen(entskb); + hlen = min_t(unsigned int, hlen, data_len); + size += sizeof(struct nlattr) + hlen; + cap_len = entskb->len; + break; + } + + nfnl_ct = rcu_dereference(nfnl_ct_hook); + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + if (queue->flags & NFQA_CFG_F_CONNTRACK) { + if (nfnl_ct != NULL) { + ct = nf_ct_get(entskb, &ctinfo); + if (ct != NULL) + size += nfnl_ct->build_size(ct); + } + } +#endif + + if (queue->flags & NFQA_CFG_F_UID_GID) { + size += (nla_total_size(sizeof(u_int32_t)) /* uid */ + + nla_total_size(sizeof(u_int32_t))); /* gid */ + } + + if ((queue->flags & NFQA_CFG_F_SECCTX) && entskb->sk) { + seclen = nfqnl_get_sk_secctx(entskb, &secdata); + if (seclen) + size += nla_total_size(seclen); + } + + skb = alloc_skb(size, GFP_ATOMIC); + if (!skb) { + skb_tx_error(entskb); + goto nlmsg_failure; + } + + nlh = nfnl_msg_put(skb, 0, 0, + nfnl_msg_type(NFNL_SUBSYS_QUEUE, NFQNL_MSG_PACKET), + 0, entry->state.pf, NFNETLINK_V0, + htons(queue->queue_num)); + if (!nlh) { + skb_tx_error(entskb); + kfree_skb(skb); + goto nlmsg_failure; + } + + nla = __nla_reserve(skb, NFQA_PACKET_HDR, 
sizeof(*pmsg)); + pmsg = nla_data(nla); + pmsg->hw_protocol = entskb->protocol; + pmsg->hook = entry->state.hook; + *packet_id_ptr = &pmsg->packet_id; + + indev = entry->state.in; + if (indev) { +#if !IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (nla_put_be32(skb, NFQA_IFINDEX_INDEV, htonl(indev->ifindex))) + goto nla_put_failure; +#else + if (entry->state.pf == PF_BRIDGE) { + /* Case 1: indev is physical input device, we need to + * look for bridge group (when called from + * netfilter_bridge) */ + if (nla_put_be32(skb, NFQA_IFINDEX_PHYSINDEV, + htonl(indev->ifindex)) || + /* this is the bridge group "brX" */ + /* rcu_read_lock()ed by __nf_queue */ + nla_put_be32(skb, NFQA_IFINDEX_INDEV, + htonl(br_port_get_rcu(indev)->br->dev->ifindex))) + goto nla_put_failure; + } else { + int physinif; + + /* Case 2: indev is bridge group, we need to look for + * physical device (when called from ipv4) */ + if (nla_put_be32(skb, NFQA_IFINDEX_INDEV, + htonl(indev->ifindex))) + goto nla_put_failure; + + physinif = nf_bridge_get_physinif(entskb); + if (physinif && + nla_put_be32(skb, NFQA_IFINDEX_PHYSINDEV, + htonl(physinif))) + goto nla_put_failure; + } +#endif + } + + if (outdev) { +#if !IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + if (nla_put_be32(skb, NFQA_IFINDEX_OUTDEV, htonl(outdev->ifindex))) + goto nla_put_failure; +#else + if (entry->state.pf == PF_BRIDGE) { + /* Case 1: outdev is physical output device, we need to + * look for bridge group (when called from + * netfilter_bridge) */ + if (nla_put_be32(skb, NFQA_IFINDEX_PHYSOUTDEV, + htonl(outdev->ifindex)) || + /* this is the bridge group "brX" */ + /* rcu_read_lock()ed by __nf_queue */ + nla_put_be32(skb, NFQA_IFINDEX_OUTDEV, + htonl(br_port_get_rcu(outdev)->br->dev->ifindex))) + goto nla_put_failure; + } else { + int physoutif; + + /* Case 2: outdev is bridge group, we need to look for + * physical output device (when called from ipv4) */ + if (nla_put_be32(skb, NFQA_IFINDEX_OUTDEV, + htonl(outdev->ifindex))) + goto nla_put_failure; + + physoutif = nf_bridge_get_physoutif(entskb); + if (physoutif && + nla_put_be32(skb, NFQA_IFINDEX_PHYSOUTDEV, + htonl(physoutif))) + goto nla_put_failure; + } +#endif + } + + if (entskb->mark && + nla_put_be32(skb, NFQA_MARK, htonl(entskb->mark))) + goto nla_put_failure; + + if (entskb->priority && + nla_put_be32(skb, NFQA_PRIORITY, htonl(entskb->priority))) + goto nla_put_failure; + + if (indev && entskb->dev && + skb_mac_header_was_set(entskb) && + skb_mac_header_len(entskb) != 0) { + struct nfqnl_msg_packet_hw phw; + int len; + + memset(&phw, 0, sizeof(phw)); + len = dev_parse_header(entskb, phw.hw_addr); + if (len) { + phw.hw_addrlen = htons(len); + if (nla_put(skb, NFQA_HWADDR, sizeof(phw), &phw)) + goto nla_put_failure; + } + } + + if (nfqnl_put_bridge(entry, skb) < 0) + goto nla_put_failure; + + if (entry->state.hook <= NF_INET_FORWARD && tstamp) { + struct nfqnl_msg_packet_timestamp ts; + struct timespec64 kts = ktime_to_timespec64(tstamp); + + ts.sec = cpu_to_be64(kts.tv_sec); + ts.usec = cpu_to_be64(kts.tv_nsec / NSEC_PER_USEC); + + if (nla_put(skb, NFQA_TIMESTAMP, sizeof(ts), &ts)) + goto nla_put_failure; + } + + if ((queue->flags & NFQA_CFG_F_UID_GID) && entskb->sk && + nfqnl_put_sk_uidgid(skb, entskb->sk) < 0) + goto nla_put_failure; + + if (seclen && nla_put(skb, NFQA_SECCTX, seclen, secdata)) + goto nla_put_failure; + + if (ct && nfnl_ct->build(skb, ct, ctinfo, NFQA_CT, NFQA_CT_INFO) < 0) + goto nla_put_failure; + + if (cap_len > data_len && + nla_put_be32(skb, NFQA_CAP_LEN, htonl(cap_len))) + goto 
nla_put_failure; + + if (nfqnl_put_packet_info(skb, entskb, csum_verify)) + goto nla_put_failure; + + if (data_len) { + struct nlattr *nla; + + if (skb_tailroom(skb) < sizeof(*nla) + hlen) + goto nla_put_failure; + + nla = skb_put(skb, sizeof(*nla)); + nla->nla_type = NFQA_PAYLOAD; + nla->nla_len = nla_attr_size(data_len); + + if (skb_zerocopy(skb, entskb, data_len, hlen)) + goto nla_put_failure; + } + + nlh->nlmsg_len = skb->len; + if (seclen) + security_release_secctx(secdata, seclen); + return skb; + +nla_put_failure: + skb_tx_error(entskb); + kfree_skb(skb); + net_err_ratelimited("nf_queue: error creating packet message\n"); +nlmsg_failure: + if (seclen) + security_release_secctx(secdata, seclen); + return NULL; +} + +static bool nf_ct_drop_unconfirmed(const struct nf_queue_entry *entry) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + static const unsigned long flags = IPS_CONFIRMED | IPS_DYING; + const struct nf_conn *ct = (void *)skb_nfct(entry->skb); + + if (ct && ((ct->status & flags) == IPS_DYING)) + return true; +#endif + return false; +} + +static int +__nfqnl_enqueue_packet(struct net *net, struct nfqnl_instance *queue, + struct nf_queue_entry *entry) +{ + struct sk_buff *nskb; + int err = -ENOBUFS; + __be32 *packet_id_ptr; + int failopen = 0; + + nskb = nfqnl_build_packet_message(net, queue, entry, &packet_id_ptr); + if (nskb == NULL) { + err = -ENOMEM; + goto err_out; + } + spin_lock_bh(&queue->lock); + + if (nf_ct_drop_unconfirmed(entry)) + goto err_out_free_nskb; + + if (queue->queue_total >= queue->queue_maxlen) { + if (queue->flags & NFQA_CFG_F_FAIL_OPEN) { + failopen = 1; + err = 0; + } else { + queue->queue_dropped++; + net_warn_ratelimited("nf_queue: full at %d entries, dropping packets(s)\n", + queue->queue_total); + } + goto err_out_free_nskb; + } + entry->id = ++queue->id_sequence; + *packet_id_ptr = htonl(entry->id); + + /* nfnetlink_unicast will either free the nskb or add it to a socket */ + err = nfnetlink_unicast(nskb, net, queue->peer_portid); + if (err < 0) { + if (queue->flags & NFQA_CFG_F_FAIL_OPEN) { + failopen = 1; + err = 0; + } else { + queue->queue_user_dropped++; + } + goto err_out_unlock; + } + + __enqueue_entry(queue, entry); + + spin_unlock_bh(&queue->lock); + return 0; + +err_out_free_nskb: + kfree_skb(nskb); +err_out_unlock: + spin_unlock_bh(&queue->lock); + if (failopen) + nfqnl_reinject(entry, NF_ACCEPT); +err_out: + return err; +} + +static struct nf_queue_entry * +nf_queue_entry_dup(struct nf_queue_entry *e) +{ + struct nf_queue_entry *entry = kmemdup(e, e->size, GFP_ATOMIC); + + if (!entry) + return NULL; + + if (nf_queue_entry_get_refs(entry)) + return entry; + + kfree(entry); + return NULL; +} + +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) +/* When called from bridge netfilter, skb->data must point to MAC header + * before calling skb_gso_segment(). Else, original MAC header is lost + * and segmented skbs will be sent to wrong destination. 
+ */ +static void nf_bridge_adjust_skb_data(struct sk_buff *skb) +{ + if (nf_bridge_info_get(skb)) + __skb_push(skb, skb->network_header - skb->mac_header); +} + +static void nf_bridge_adjust_segmented_data(struct sk_buff *skb) +{ + if (nf_bridge_info_get(skb)) + __skb_pull(skb, skb->network_header - skb->mac_header); +} +#else +#define nf_bridge_adjust_skb_data(s) do {} while (0) +#define nf_bridge_adjust_segmented_data(s) do {} while (0) +#endif + +static int +__nfqnl_enqueue_packet_gso(struct net *net, struct nfqnl_instance *queue, + struct sk_buff *skb, struct nf_queue_entry *entry) +{ + int ret = -ENOMEM; + struct nf_queue_entry *entry_seg; + + nf_bridge_adjust_segmented_data(skb); + + if (skb->next == NULL) { /* last packet, no need to copy entry */ + struct sk_buff *gso_skb = entry->skb; + entry->skb = skb; + ret = __nfqnl_enqueue_packet(net, queue, entry); + if (ret) + entry->skb = gso_skb; + return ret; + } + + skb_mark_not_on_list(skb); + + entry_seg = nf_queue_entry_dup(entry); + if (entry_seg) { + entry_seg->skb = skb; + ret = __nfqnl_enqueue_packet(net, queue, entry_seg); + if (ret) + nf_queue_entry_free(entry_seg); + } + return ret; +} + +static int +nfqnl_enqueue_packet(struct nf_queue_entry *entry, unsigned int queuenum) +{ + unsigned int queued; + struct nfqnl_instance *queue; + struct sk_buff *skb, *segs, *nskb; + int err = -ENOBUFS; + struct net *net = entry->state.net; + struct nfnl_queue_net *q = nfnl_queue_pernet(net); + + /* rcu_read_lock()ed by nf_hook_thresh */ + queue = instance_lookup(q, queuenum); + if (!queue) + return -ESRCH; + + if (queue->copy_mode == NFQNL_COPY_NONE) + return -EINVAL; + + skb = entry->skb; + + switch (entry->state.pf) { + case NFPROTO_IPV4: + skb->protocol = htons(ETH_P_IP); + break; + case NFPROTO_IPV6: + skb->protocol = htons(ETH_P_IPV6); + break; + } + + if ((queue->flags & NFQA_CFG_F_GSO) || !skb_is_gso(skb)) + return __nfqnl_enqueue_packet(net, queue, entry); + + nf_bridge_adjust_skb_data(skb); + segs = skb_gso_segment(skb, 0); + /* Does not use PTR_ERR to limit the number of error codes that can be + * returned by nf_queue. For instance, callers rely on -ESRCH to + * mean 'ignore this hook'. 
+ */ + if (IS_ERR_OR_NULL(segs)) + goto out_err; + queued = 0; + err = 0; + skb_list_walk_safe(segs, segs, nskb) { + if (err == 0) + err = __nfqnl_enqueue_packet_gso(net, queue, + segs, entry); + if (err == 0) + queued++; + else + kfree_skb(segs); + } + + if (queued) { + if (err) /* some segments are already queued */ + nf_queue_entry_free(entry); + kfree_skb(skb); + return 0; + } + out_err: + nf_bridge_adjust_segmented_data(skb); + return err; +} + +static int +nfqnl_mangle(void *data, unsigned int data_len, struct nf_queue_entry *e, int diff) +{ + struct sk_buff *nskb; + + if (diff < 0) { + unsigned int min_len = skb_transport_offset(e->skb); + + if (data_len < min_len) + return -EINVAL; + + if (pskb_trim(e->skb, data_len)) + return -ENOMEM; + } else if (diff > 0) { + if (data_len > 0xFFFF) + return -EINVAL; + if (diff > skb_tailroom(e->skb)) { + nskb = skb_copy_expand(e->skb, skb_headroom(e->skb), + diff, GFP_ATOMIC); + if (!nskb) + return -ENOMEM; + kfree_skb(e->skb); + e->skb = nskb; + } + skb_put(e->skb, diff); + } + if (skb_ensure_writable(e->skb, data_len)) + return -ENOMEM; + skb_copy_to_linear_data(e->skb, data, data_len); + e->skb->ip_summed = CHECKSUM_NONE; + return 0; +} + +static int +nfqnl_set_mode(struct nfqnl_instance *queue, + unsigned char mode, unsigned int range) +{ + int status = 0; + + spin_lock_bh(&queue->lock); + switch (mode) { + case NFQNL_COPY_NONE: + case NFQNL_COPY_META: + queue->copy_mode = mode; + queue->copy_range = 0; + break; + + case NFQNL_COPY_PACKET: + queue->copy_mode = mode; + if (range == 0 || range > NFQNL_MAX_COPY_RANGE) + queue->copy_range = NFQNL_MAX_COPY_RANGE; + else + queue->copy_range = range; + break; + + default: + status = -EINVAL; + + } + spin_unlock_bh(&queue->lock); + + return status; +} + +static int +dev_cmp(struct nf_queue_entry *entry, unsigned long ifindex) +{ +#if IS_ENABLED(CONFIG_BRIDGE_NETFILTER) + int physinif, physoutif; + + physinif = nf_bridge_get_physinif(entry->skb); + physoutif = nf_bridge_get_physoutif(entry->skb); + + if (physinif == ifindex || physoutif == ifindex) + return 1; +#endif + if (entry->state.in) + if (entry->state.in->ifindex == ifindex) + return 1; + if (entry->state.out) + if (entry->state.out->ifindex == ifindex) + return 1; + + return 0; +} + +/* drop all packets with either indev or outdev == ifindex from all queue + * instances */ +static void +nfqnl_dev_drop(struct net *net, int ifindex) +{ + int i; + struct nfnl_queue_net *q = nfnl_queue_pernet(net); + + rcu_read_lock(); + + for (i = 0; i < INSTANCE_BUCKETS; i++) { + struct nfqnl_instance *inst; + struct hlist_head *head = &q->instance_table[i]; + + hlist_for_each_entry_rcu(inst, head, hlist) + nfqnl_flush(inst, dev_cmp, ifindex); + } + + rcu_read_unlock(); +} + +static int +nfqnl_rcv_dev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + /* Drop any packets associated with the downed device */ + if (event == NETDEV_DOWN) + nfqnl_dev_drop(dev_net(dev), dev->ifindex); + return NOTIFY_DONE; +} + +static struct notifier_block nfqnl_dev_notifier = { + .notifier_call = nfqnl_rcv_dev_event, +}; + +static void nfqnl_nf_hook_drop(struct net *net) +{ + struct nfnl_queue_net *q = nfnl_queue_pernet(net); + int i; + + /* This function is also called on net namespace error unwind, + * when pernet_ops->init() failed and ->exit() functions of the + * previous pernet_ops gets called. 
+ * + * This may result in a call to nfqnl_nf_hook_drop() before + * struct nfnl_queue_net was allocated. + */ + if (!q) + return; + + for (i = 0; i < INSTANCE_BUCKETS; i++) { + struct nfqnl_instance *inst; + struct hlist_head *head = &q->instance_table[i]; + + hlist_for_each_entry_rcu(inst, head, hlist) + nfqnl_flush(inst, NULL, 0); + } +} + +static int +nfqnl_rcv_nl_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct netlink_notify *n = ptr; + struct nfnl_queue_net *q = nfnl_queue_pernet(n->net); + + if (event == NETLINK_URELEASE && n->protocol == NETLINK_NETFILTER) { + int i; + + /* destroy all instances for this portid */ + spin_lock(&q->instances_lock); + for (i = 0; i < INSTANCE_BUCKETS; i++) { + struct hlist_node *t2; + struct nfqnl_instance *inst; + struct hlist_head *head = &q->instance_table[i]; + + hlist_for_each_entry_safe(inst, t2, head, hlist) { + if (n->portid == inst->peer_portid) + __instance_destroy(inst); + } + } + spin_unlock(&q->instances_lock); + } + return NOTIFY_DONE; +} + +static struct notifier_block nfqnl_rtnl_notifier = { + .notifier_call = nfqnl_rcv_nl_event, +}; + +static const struct nla_policy nfqa_vlan_policy[NFQA_VLAN_MAX + 1] = { + [NFQA_VLAN_TCI] = { .type = NLA_U16}, + [NFQA_VLAN_PROTO] = { .type = NLA_U16}, +}; + +static const struct nla_policy nfqa_verdict_policy[NFQA_MAX+1] = { + [NFQA_VERDICT_HDR] = { .len = sizeof(struct nfqnl_msg_verdict_hdr) }, + [NFQA_MARK] = { .type = NLA_U32 }, + [NFQA_PAYLOAD] = { .type = NLA_UNSPEC }, + [NFQA_CT] = { .type = NLA_UNSPEC }, + [NFQA_EXP] = { .type = NLA_UNSPEC }, + [NFQA_VLAN] = { .type = NLA_NESTED }, + [NFQA_PRIORITY] = { .type = NLA_U32 }, +}; + +static const struct nla_policy nfqa_verdict_batch_policy[NFQA_MAX+1] = { + [NFQA_VERDICT_HDR] = { .len = sizeof(struct nfqnl_msg_verdict_hdr) }, + [NFQA_MARK] = { .type = NLA_U32 }, + [NFQA_PRIORITY] = { .type = NLA_U32 }, +}; + +static struct nfqnl_instance * +verdict_instance_lookup(struct nfnl_queue_net *q, u16 queue_num, u32 nlportid) +{ + struct nfqnl_instance *queue; + + queue = instance_lookup(q, queue_num); + if (!queue) + return ERR_PTR(-ENODEV); + + if (queue->peer_portid != nlportid) + return ERR_PTR(-EPERM); + + return queue; +} + +static struct nfqnl_msg_verdict_hdr* +verdicthdr_get(const struct nlattr * const nfqa[]) +{ + struct nfqnl_msg_verdict_hdr *vhdr; + unsigned int verdict; + + if (!nfqa[NFQA_VERDICT_HDR]) + return NULL; + + vhdr = nla_data(nfqa[NFQA_VERDICT_HDR]); + verdict = ntohl(vhdr->verdict) & NF_VERDICT_MASK; + if (verdict > NF_MAX_VERDICT || verdict == NF_STOLEN) + return NULL; + return vhdr; +} + +static int nfq_id_after(unsigned int id, unsigned int max) +{ + return (int)(id - max) > 0; +} + +static int nfqnl_recv_verdict_batch(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const nfqa[]) +{ + struct nfnl_queue_net *q = nfnl_queue_pernet(info->net); + u16 queue_num = ntohs(info->nfmsg->res_id); + struct nf_queue_entry *entry, *tmp; + struct nfqnl_msg_verdict_hdr *vhdr; + struct nfqnl_instance *queue; + unsigned int verdict, maxid; + LIST_HEAD(batch_list); + + queue = verdict_instance_lookup(q, queue_num, + NETLINK_CB(skb).portid); + if (IS_ERR(queue)) + return PTR_ERR(queue); + + vhdr = verdicthdr_get(nfqa); + if (!vhdr) + return -EINVAL; + + verdict = ntohl(vhdr->verdict); + maxid = ntohl(vhdr->id); + + spin_lock_bh(&queue->lock); + + list_for_each_entry_safe(entry, tmp, &queue->queue_list, list) { + if (nfq_id_after(entry->id, maxid)) + break; + __dequeue_entry(queue, entry); 
+ list_add_tail(&entry->list, &batch_list); + } + + spin_unlock_bh(&queue->lock); + + if (list_empty(&batch_list)) + return -ENOENT; + + list_for_each_entry_safe(entry, tmp, &batch_list, list) { + if (nfqa[NFQA_MARK]) + entry->skb->mark = ntohl(nla_get_be32(nfqa[NFQA_MARK])); + + if (nfqa[NFQA_PRIORITY]) + entry->skb->priority = ntohl(nla_get_be32(nfqa[NFQA_PRIORITY])); + + nfqnl_reinject(entry, verdict); + } + return 0; +} + +static struct nf_conn *nfqnl_ct_parse(const struct nfnl_ct_hook *nfnl_ct, + const struct nlmsghdr *nlh, + const struct nlattr * const nfqa[], + struct nf_queue_entry *entry, + enum ip_conntrack_info *ctinfo) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + struct nf_conn *ct; + + ct = nf_ct_get(entry->skb, ctinfo); + if (ct == NULL) + return NULL; + + if (nfnl_ct->parse(nfqa[NFQA_CT], ct) < 0) + return NULL; + + if (nfqa[NFQA_EXP]) + nfnl_ct->attach_expect(nfqa[NFQA_EXP], ct, + NETLINK_CB(entry->skb).portid, + nlmsg_report(nlh)); + return ct; +#else + return NULL; +#endif +} + +static int nfqa_parse_bridge(struct nf_queue_entry *entry, + const struct nlattr * const nfqa[]) +{ + if (nfqa[NFQA_VLAN]) { + struct nlattr *tb[NFQA_VLAN_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFQA_VLAN_MAX, + nfqa[NFQA_VLAN], + nfqa_vlan_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFQA_VLAN_TCI] || !tb[NFQA_VLAN_PROTO]) + return -EINVAL; + + __vlan_hwaccel_put_tag(entry->skb, + nla_get_be16(tb[NFQA_VLAN_PROTO]), + ntohs(nla_get_be16(tb[NFQA_VLAN_TCI]))); + } + + if (nfqa[NFQA_L2HDR]) { + int mac_header_len = entry->skb->network_header - + entry->skb->mac_header; + + if (mac_header_len != nla_len(nfqa[NFQA_L2HDR])) + return -EINVAL; + else if (mac_header_len > 0) + memcpy(skb_mac_header(entry->skb), + nla_data(nfqa[NFQA_L2HDR]), + mac_header_len); + } + + return 0; +} + +static int nfqnl_recv_verdict(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nfqa[]) +{ + struct nfnl_queue_net *q = nfnl_queue_pernet(info->net); + u_int16_t queue_num = ntohs(info->nfmsg->res_id); + const struct nfnl_ct_hook *nfnl_ct; + struct nfqnl_msg_verdict_hdr *vhdr; + enum ip_conntrack_info ctinfo; + struct nfqnl_instance *queue; + struct nf_queue_entry *entry; + struct nf_conn *ct = NULL; + unsigned int verdict; + int err; + + queue = verdict_instance_lookup(q, queue_num, + NETLINK_CB(skb).portid); + if (IS_ERR(queue)) + return PTR_ERR(queue); + + vhdr = verdicthdr_get(nfqa); + if (!vhdr) + return -EINVAL; + + verdict = ntohl(vhdr->verdict); + + entry = find_dequeue_entry(queue, ntohl(vhdr->id)); + if (entry == NULL) + return -ENOENT; + + /* rcu lock already held from nfnl->call_rcu. 
*/ + nfnl_ct = rcu_dereference(nfnl_ct_hook); + + if (nfqa[NFQA_CT]) { + if (nfnl_ct != NULL) + ct = nfqnl_ct_parse(nfnl_ct, info->nlh, nfqa, entry, + &ctinfo); + } + + if (entry->state.pf == PF_BRIDGE) { + err = nfqa_parse_bridge(entry, nfqa); + if (err < 0) + return err; + } + + if (nfqa[NFQA_PAYLOAD]) { + u16 payload_len = nla_len(nfqa[NFQA_PAYLOAD]); + int diff = payload_len - entry->skb->len; + + if (nfqnl_mangle(nla_data(nfqa[NFQA_PAYLOAD]), + payload_len, entry, diff) < 0) + verdict = NF_DROP; + + if (ct && diff) + nfnl_ct->seq_adjust(entry->skb, ct, ctinfo, diff); + } + + if (nfqa[NFQA_MARK]) + entry->skb->mark = ntohl(nla_get_be32(nfqa[NFQA_MARK])); + + if (nfqa[NFQA_PRIORITY]) + entry->skb->priority = ntohl(nla_get_be32(nfqa[NFQA_PRIORITY])); + + nfqnl_reinject(entry, verdict); + return 0; +} + +static int nfqnl_recv_unsupp(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const cda[]) +{ + return -ENOTSUPP; +} + +static const struct nla_policy nfqa_cfg_policy[NFQA_CFG_MAX+1] = { + [NFQA_CFG_CMD] = { .len = sizeof(struct nfqnl_msg_config_cmd) }, + [NFQA_CFG_PARAMS] = { .len = sizeof(struct nfqnl_msg_config_params) }, + [NFQA_CFG_QUEUE_MAXLEN] = { .type = NLA_U32 }, + [NFQA_CFG_MASK] = { .type = NLA_U32 }, + [NFQA_CFG_FLAGS] = { .type = NLA_U32 }, +}; + +static const struct nf_queue_handler nfqh = { + .outfn = nfqnl_enqueue_packet, + .nf_hook_drop = nfqnl_nf_hook_drop, +}; + +static int nfqnl_recv_config(struct sk_buff *skb, const struct nfnl_info *info, + const struct nlattr * const nfqa[]) +{ + struct nfnl_queue_net *q = nfnl_queue_pernet(info->net); + u_int16_t queue_num = ntohs(info->nfmsg->res_id); + struct nfqnl_msg_config_cmd *cmd = NULL; + struct nfqnl_instance *queue; + __u32 flags = 0, mask = 0; + int ret = 0; + + if (nfqa[NFQA_CFG_CMD]) { + cmd = nla_data(nfqa[NFQA_CFG_CMD]); + + /* Obsolete commands without queue context */ + switch (cmd->command) { + case NFQNL_CFG_CMD_PF_BIND: return 0; + case NFQNL_CFG_CMD_PF_UNBIND: return 0; + } + } + + /* Check if we support these flags in first place, dependencies should + * be there too not to break atomicity. + */ + if (nfqa[NFQA_CFG_FLAGS]) { + if (!nfqa[NFQA_CFG_MASK]) { + /* A mask is needed to specify which flags are being + * changed. 
+ */ + return -EINVAL; + } + + flags = ntohl(nla_get_be32(nfqa[NFQA_CFG_FLAGS])); + mask = ntohl(nla_get_be32(nfqa[NFQA_CFG_MASK])); + + if (flags >= NFQA_CFG_F_MAX) + return -EOPNOTSUPP; + +#if !IS_ENABLED(CONFIG_NETWORK_SECMARK) + if (flags & mask & NFQA_CFG_F_SECCTX) + return -EOPNOTSUPP; +#endif + if ((flags & mask & NFQA_CFG_F_CONNTRACK) && + !rcu_access_pointer(nfnl_ct_hook)) { +#ifdef CONFIG_MODULES + nfnl_unlock(NFNL_SUBSYS_QUEUE); + request_module("ip_conntrack_netlink"); + nfnl_lock(NFNL_SUBSYS_QUEUE); + if (rcu_access_pointer(nfnl_ct_hook)) + return -EAGAIN; +#endif + return -EOPNOTSUPP; + } + } + + rcu_read_lock(); + queue = instance_lookup(q, queue_num); + if (queue && queue->peer_portid != NETLINK_CB(skb).portid) { + ret = -EPERM; + goto err_out_unlock; + } + + if (cmd != NULL) { + switch (cmd->command) { + case NFQNL_CFG_CMD_BIND: + if (queue) { + ret = -EBUSY; + goto err_out_unlock; + } + queue = instance_create(q, queue_num, + NETLINK_CB(skb).portid); + if (IS_ERR(queue)) { + ret = PTR_ERR(queue); + goto err_out_unlock; + } + break; + case NFQNL_CFG_CMD_UNBIND: + if (!queue) { + ret = -ENODEV; + goto err_out_unlock; + } + instance_destroy(q, queue); + goto err_out_unlock; + case NFQNL_CFG_CMD_PF_BIND: + case NFQNL_CFG_CMD_PF_UNBIND: + break; + default: + ret = -ENOTSUPP; + goto err_out_unlock; + } + } + + if (!queue) { + ret = -ENODEV; + goto err_out_unlock; + } + + if (nfqa[NFQA_CFG_PARAMS]) { + struct nfqnl_msg_config_params *params = + nla_data(nfqa[NFQA_CFG_PARAMS]); + + nfqnl_set_mode(queue, params->copy_mode, + ntohl(params->copy_range)); + } + + if (nfqa[NFQA_CFG_QUEUE_MAXLEN]) { + __be32 *queue_maxlen = nla_data(nfqa[NFQA_CFG_QUEUE_MAXLEN]); + + spin_lock_bh(&queue->lock); + queue->queue_maxlen = ntohl(*queue_maxlen); + spin_unlock_bh(&queue->lock); + } + + if (nfqa[NFQA_CFG_FLAGS]) { + spin_lock_bh(&queue->lock); + queue->flags &= ~mask; + queue->flags |= flags & mask; + spin_unlock_bh(&queue->lock); + } + +err_out_unlock: + rcu_read_unlock(); + return ret; +} + +static const struct nfnl_callback nfqnl_cb[NFQNL_MSG_MAX] = { + [NFQNL_MSG_PACKET] = { + .call = nfqnl_recv_unsupp, + .type = NFNL_CB_RCU, + .attr_count = NFQA_MAX, + }, + [NFQNL_MSG_VERDICT] = { + .call = nfqnl_recv_verdict, + .type = NFNL_CB_RCU, + .attr_count = NFQA_MAX, + .policy = nfqa_verdict_policy + }, + [NFQNL_MSG_CONFIG] = { + .call = nfqnl_recv_config, + .type = NFNL_CB_MUTEX, + .attr_count = NFQA_CFG_MAX, + .policy = nfqa_cfg_policy + }, + [NFQNL_MSG_VERDICT_BATCH] = { + .call = nfqnl_recv_verdict_batch, + .type = NFNL_CB_RCU, + .attr_count = NFQA_MAX, + .policy = nfqa_verdict_batch_policy + }, +}; + +static const struct nfnetlink_subsystem nfqnl_subsys = { + .name = "nf_queue", + .subsys_id = NFNL_SUBSYS_QUEUE, + .cb_count = NFQNL_MSG_MAX, + .cb = nfqnl_cb, +}; + +#ifdef CONFIG_PROC_FS +struct iter_state { + struct seq_net_private p; + unsigned int bucket; +}; + +static struct hlist_node *get_first(struct seq_file *seq) +{ + struct iter_state *st = seq->private; + struct net *net; + struct nfnl_queue_net *q; + + if (!st) + return NULL; + + net = seq_file_net(seq); + q = nfnl_queue_pernet(net); + for (st->bucket = 0; st->bucket < INSTANCE_BUCKETS; st->bucket++) { + if (!hlist_empty(&q->instance_table[st->bucket])) + return q->instance_table[st->bucket].first; + } + return NULL; +} + +static struct hlist_node *get_next(struct seq_file *seq, struct hlist_node *h) +{ + struct iter_state *st = seq->private; + struct net *net = seq_file_net(seq); + + h = h->next; + while (!h) { + struct 
nfnl_queue_net *q; + + if (++st->bucket >= INSTANCE_BUCKETS) + return NULL; + + q = nfnl_queue_pernet(net); + h = q->instance_table[st->bucket].first; + } + return h; +} + +static struct hlist_node *get_idx(struct seq_file *seq, loff_t pos) +{ + struct hlist_node *head; + head = get_first(seq); + + if (head) + while (pos && (head = get_next(seq, head))) + pos--; + return pos ? NULL : head; +} + +static void *seq_start(struct seq_file *s, loff_t *pos) + __acquires(nfnl_queue_pernet(seq_file_net(s))->instances_lock) +{ + spin_lock(&nfnl_queue_pernet(seq_file_net(s))->instances_lock); + return get_idx(s, *pos); +} + +static void *seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + (*pos)++; + return get_next(s, v); +} + +static void seq_stop(struct seq_file *s, void *v) + __releases(nfnl_queue_pernet(seq_file_net(s))->instances_lock) +{ + spin_unlock(&nfnl_queue_pernet(seq_file_net(s))->instances_lock); +} + +static int seq_show(struct seq_file *s, void *v) +{ + const struct nfqnl_instance *inst = v; + + seq_printf(s, "%5u %6u %5u %1u %5u %5u %5u %8u %2d\n", + inst->queue_num, + inst->peer_portid, inst->queue_total, + inst->copy_mode, inst->copy_range, + inst->queue_dropped, inst->queue_user_dropped, + inst->id_sequence, 1); + return 0; +} + +static const struct seq_operations nfqnl_seq_ops = { + .start = seq_start, + .next = seq_next, + .stop = seq_stop, + .show = seq_show, +}; +#endif /* PROC_FS */ + +static int __net_init nfnl_queue_net_init(struct net *net) +{ + unsigned int i; + struct nfnl_queue_net *q = nfnl_queue_pernet(net); + + for (i = 0; i < INSTANCE_BUCKETS; i++) + INIT_HLIST_HEAD(&q->instance_table[i]); + + spin_lock_init(&q->instances_lock); + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("nfnetlink_queue", 0440, net->nf.proc_netfilter, + &nfqnl_seq_ops, sizeof(struct iter_state))) + return -ENOMEM; +#endif + return 0; +} + +static void __net_exit nfnl_queue_net_exit(struct net *net) +{ + struct nfnl_queue_net *q = nfnl_queue_pernet(net); + unsigned int i; + +#ifdef CONFIG_PROC_FS + remove_proc_entry("nfnetlink_queue", net->nf.proc_netfilter); +#endif + for (i = 0; i < INSTANCE_BUCKETS; i++) + WARN_ON_ONCE(!hlist_empty(&q->instance_table[i])); +} + +static struct pernet_operations nfnl_queue_net_ops = { + .init = nfnl_queue_net_init, + .exit = nfnl_queue_net_exit, + .id = &nfnl_queue_net_id, + .size = sizeof(struct nfnl_queue_net), +}; + +static int __init nfnetlink_queue_init(void) +{ + int status; + + status = register_pernet_subsys(&nfnl_queue_net_ops); + if (status < 0) { + pr_err("failed to register pernet ops\n"); + goto out; + } + + netlink_register_notifier(&nfqnl_rtnl_notifier); + status = nfnetlink_subsys_register(&nfqnl_subsys); + if (status < 0) { + pr_err("failed to create netlink socket\n"); + goto cleanup_netlink_notifier; + } + + status = register_netdevice_notifier(&nfqnl_dev_notifier); + if (status < 0) { + pr_err("failed to register netdevice notifier\n"); + goto cleanup_netlink_subsys; + } + + nf_register_queue_handler(&nfqh); + + return status; + +cleanup_netlink_subsys: + nfnetlink_subsys_unregister(&nfqnl_subsys); +cleanup_netlink_notifier: + netlink_unregister_notifier(&nfqnl_rtnl_notifier); + unregister_pernet_subsys(&nfnl_queue_net_ops); +out: + return status; +} + +static void __exit nfnetlink_queue_fini(void) +{ + nf_unregister_queue_handler(); + unregister_netdevice_notifier(&nfqnl_dev_notifier); + nfnetlink_subsys_unregister(&nfqnl_subsys); + netlink_unregister_notifier(&nfqnl_rtnl_notifier); + unregister_pernet_subsys(&nfnl_queue_net_ops); + + 
rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} + +MODULE_DESCRIPTION("netfilter packet queue handler"); +MODULE_AUTHOR("Harald Welte "); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_QUEUE); + +module_init(nfnetlink_queue_init); +module_exit(nfnetlink_queue_fini); diff --git a/net/netfilter/nft_bitwise.c b/net/netfilter/nft_bitwise.c new file mode 100644 index 000000000..b84312df9 --- /dev/null +++ b/net/netfilter/nft_bitwise.c @@ -0,0 +1,533 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_bitwise { + u8 sreg; + u8 dreg; + enum nft_bitwise_ops op:8; + u8 len; + struct nft_data mask; + struct nft_data xor; + struct nft_data data; +}; + +static void nft_bitwise_eval_bool(u32 *dst, const u32 *src, + const struct nft_bitwise *priv) +{ + unsigned int i; + + for (i = 0; i < DIV_ROUND_UP(priv->len, sizeof(u32)); i++) + dst[i] = (src[i] & priv->mask.data[i]) ^ priv->xor.data[i]; +} + +static void nft_bitwise_eval_lshift(u32 *dst, const u32 *src, + const struct nft_bitwise *priv) +{ + u32 shift = priv->data.data[0]; + unsigned int i; + u32 carry = 0; + + for (i = DIV_ROUND_UP(priv->len, sizeof(u32)); i > 0; i--) { + dst[i - 1] = (src[i - 1] << shift) | carry; + carry = src[i - 1] >> (BITS_PER_TYPE(u32) - shift); + } +} + +static void nft_bitwise_eval_rshift(u32 *dst, const u32 *src, + const struct nft_bitwise *priv) +{ + u32 shift = priv->data.data[0]; + unsigned int i; + u32 carry = 0; + + for (i = 0; i < DIV_ROUND_UP(priv->len, sizeof(u32)); i++) { + dst[i] = carry | (src[i] >> shift); + carry = src[i] << (BITS_PER_TYPE(u32) - shift); + } +} + +void nft_bitwise_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) +{ + const struct nft_bitwise *priv = nft_expr_priv(expr); + const u32 *src = ®s->data[priv->sreg]; + u32 *dst = ®s->data[priv->dreg]; + + switch (priv->op) { + case NFT_BITWISE_BOOL: + nft_bitwise_eval_bool(dst, src, priv); + break; + case NFT_BITWISE_LSHIFT: + nft_bitwise_eval_lshift(dst, src, priv); + break; + case NFT_BITWISE_RSHIFT: + nft_bitwise_eval_rshift(dst, src, priv); + break; + } +} + +static const struct nla_policy nft_bitwise_policy[NFTA_BITWISE_MAX + 1] = { + [NFTA_BITWISE_SREG] = { .type = NLA_U32 }, + [NFTA_BITWISE_DREG] = { .type = NLA_U32 }, + [NFTA_BITWISE_LEN] = { .type = NLA_U32 }, + [NFTA_BITWISE_MASK] = { .type = NLA_NESTED }, + [NFTA_BITWISE_XOR] = { .type = NLA_NESTED }, + [NFTA_BITWISE_OP] = { .type = NLA_U32 }, + [NFTA_BITWISE_DATA] = { .type = NLA_NESTED }, +}; + +static int nft_bitwise_init_bool(struct nft_bitwise *priv, + const struct nlattr *const tb[]) +{ + struct nft_data_desc mask = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->mask), + .len = priv->len, + }; + struct nft_data_desc xor = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->xor), + .len = priv->len, + }; + int err; + + if (tb[NFTA_BITWISE_DATA]) + return -EINVAL; + + if (!tb[NFTA_BITWISE_MASK] || + !tb[NFTA_BITWISE_XOR]) + return -EINVAL; + + err = nft_data_init(NULL, &priv->mask, &mask, tb[NFTA_BITWISE_MASK]); + if (err < 0) + return err; + + err = nft_data_init(NULL, &priv->xor, &xor, tb[NFTA_BITWISE_XOR]); + if (err < 0) + goto err_xor_err; + + return 0; + +err_xor_err: + nft_data_release(&priv->mask, mask.type); + + return err; +} + +static int nft_bitwise_init_shift(struct nft_bitwise 
*priv, + const struct nlattr *const tb[]) +{ + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->data), + .len = sizeof(u32), + }; + int err; + + if (tb[NFTA_BITWISE_MASK] || + tb[NFTA_BITWISE_XOR]) + return -EINVAL; + + if (!tb[NFTA_BITWISE_DATA]) + return -EINVAL; + + err = nft_data_init(NULL, &priv->data, &desc, tb[NFTA_BITWISE_DATA]); + if (err < 0) + return err; + + if (priv->data.data[0] >= BITS_PER_TYPE(u32)) { + nft_data_release(&priv->data, desc.type); + return -EINVAL; + } + + return 0; +} + +static int nft_bitwise_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_bitwise *priv = nft_expr_priv(expr); + u32 len; + int err; + + err = nft_parse_u32_check(tb[NFTA_BITWISE_LEN], U8_MAX, &len); + if (err < 0) + return err; + + priv->len = len; + + err = nft_parse_register_load(tb[NFTA_BITWISE_SREG], &priv->sreg, + priv->len); + if (err < 0) + return err; + + err = nft_parse_register_store(ctx, tb[NFTA_BITWISE_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + priv->len); + if (err < 0) + return err; + + if (tb[NFTA_BITWISE_OP]) { + priv->op = ntohl(nla_get_be32(tb[NFTA_BITWISE_OP])); + switch (priv->op) { + case NFT_BITWISE_BOOL: + case NFT_BITWISE_LSHIFT: + case NFT_BITWISE_RSHIFT: + break; + default: + return -EOPNOTSUPP; + } + } else { + priv->op = NFT_BITWISE_BOOL; + } + + switch(priv->op) { + case NFT_BITWISE_BOOL: + err = nft_bitwise_init_bool(priv, tb); + break; + case NFT_BITWISE_LSHIFT: + case NFT_BITWISE_RSHIFT: + err = nft_bitwise_init_shift(priv, tb); + break; + } + + return err; +} + +static int nft_bitwise_dump_bool(struct sk_buff *skb, + const struct nft_bitwise *priv) +{ + if (nft_data_dump(skb, NFTA_BITWISE_MASK, &priv->mask, + NFT_DATA_VALUE, priv->len) < 0) + return -1; + + if (nft_data_dump(skb, NFTA_BITWISE_XOR, &priv->xor, + NFT_DATA_VALUE, priv->len) < 0) + return -1; + + return 0; +} + +static int nft_bitwise_dump_shift(struct sk_buff *skb, + const struct nft_bitwise *priv) +{ + if (nft_data_dump(skb, NFTA_BITWISE_DATA, &priv->data, + NFT_DATA_VALUE, sizeof(u32)) < 0) + return -1; + return 0; +} + +static int nft_bitwise_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_bitwise *priv = nft_expr_priv(expr); + int err = 0; + + if (nft_dump_register(skb, NFTA_BITWISE_SREG, priv->sreg)) + return -1; + if (nft_dump_register(skb, NFTA_BITWISE_DREG, priv->dreg)) + return -1; + if (nla_put_be32(skb, NFTA_BITWISE_LEN, htonl(priv->len))) + return -1; + if (nla_put_be32(skb, NFTA_BITWISE_OP, htonl(priv->op))) + return -1; + + switch (priv->op) { + case NFT_BITWISE_BOOL: + err = nft_bitwise_dump_bool(skb, priv); + break; + case NFT_BITWISE_LSHIFT: + case NFT_BITWISE_RSHIFT: + err = nft_bitwise_dump_shift(skb, priv); + break; + } + + return err; +} + +static struct nft_data zero; + +static int nft_bitwise_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_bitwise *priv = nft_expr_priv(expr); + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + if (priv->op != NFT_BITWISE_BOOL) + return -EOPNOTSUPP; + + if (memcmp(&priv->xor, &zero, sizeof(priv->xor)) || + priv->sreg != priv->dreg || priv->len != reg->len) + return -EOPNOTSUPP; + + memcpy(®->mask, &priv->mask, sizeof(priv->mask)); + + return 0; +} + +static bool nft_bitwise_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_bitwise *priv = nft_expr_priv(expr); + const struct nft_bitwise *bitwise; + 
unsigned int regcount; + u8 dreg; + int i; + + if (!track->regs[priv->sreg].selector) + return false; + + bitwise = nft_expr_priv(track->regs[priv->dreg].selector); + if (track->regs[priv->sreg].selector == track->regs[priv->dreg].selector && + track->regs[priv->sreg].num_reg == 0 && + track->regs[priv->dreg].bitwise && + track->regs[priv->dreg].bitwise->ops == expr->ops && + priv->sreg == bitwise->sreg && + priv->dreg == bitwise->dreg && + priv->op == bitwise->op && + priv->len == bitwise->len && + !memcmp(&priv->mask, &bitwise->mask, sizeof(priv->mask)) && + !memcmp(&priv->xor, &bitwise->xor, sizeof(priv->xor)) && + !memcmp(&priv->data, &bitwise->data, sizeof(priv->data))) { + track->cur = expr; + return true; + } + + if (track->regs[priv->sreg].bitwise || + track->regs[priv->sreg].num_reg != 0) { + nft_reg_track_cancel(track, priv->dreg, priv->len); + return false; + } + + if (priv->sreg != priv->dreg) { + nft_reg_track_update(track, track->regs[priv->sreg].selector, + priv->dreg, priv->len); + } + + dreg = priv->dreg; + regcount = DIV_ROUND_UP(priv->len, NFT_REG32_SIZE); + for (i = 0; i < regcount; i++, dreg++) + track->regs[dreg].bitwise = expr; + + return false; +} + +static const struct nft_expr_ops nft_bitwise_ops = { + .type = &nft_bitwise_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_bitwise)), + .eval = nft_bitwise_eval, + .init = nft_bitwise_init, + .dump = nft_bitwise_dump, + .reduce = nft_bitwise_reduce, + .offload = nft_bitwise_offload, +}; + +static int +nft_bitwise_extract_u32_data(const struct nlattr * const tb, u32 *out) +{ + struct nft_data data; + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(data), + .len = sizeof(u32), + }; + int err; + + err = nft_data_init(NULL, &data, &desc, tb); + if (err < 0) + return err; + + *out = data.data[0]; + + return 0; +} + +static int nft_bitwise_fast_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); + int err; + + err = nft_parse_register_load(tb[NFTA_BITWISE_SREG], &priv->sreg, + sizeof(u32)); + if (err < 0) + return err; + + err = nft_parse_register_store(ctx, tb[NFTA_BITWISE_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, sizeof(u32)); + if (err < 0) + return err; + + if (tb[NFTA_BITWISE_DATA]) + return -EINVAL; + + if (!tb[NFTA_BITWISE_MASK] || + !tb[NFTA_BITWISE_XOR]) + return -EINVAL; + + err = nft_bitwise_extract_u32_data(tb[NFTA_BITWISE_MASK], &priv->mask); + if (err < 0) + return err; + + err = nft_bitwise_extract_u32_data(tb[NFTA_BITWISE_XOR], &priv->xor); + if (err < 0) + return err; + + return 0; +} + +static int +nft_bitwise_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); + struct nft_data data; + + if (nft_dump_register(skb, NFTA_BITWISE_SREG, priv->sreg)) + return -1; + if (nft_dump_register(skb, NFTA_BITWISE_DREG, priv->dreg)) + return -1; + if (nla_put_be32(skb, NFTA_BITWISE_LEN, htonl(sizeof(u32)))) + return -1; + if (nla_put_be32(skb, NFTA_BITWISE_OP, htonl(NFT_BITWISE_BOOL))) + return -1; + + data.data[0] = priv->mask; + if (nft_data_dump(skb, NFTA_BITWISE_MASK, &data, + NFT_DATA_VALUE, sizeof(u32)) < 0) + return -1; + + data.data[0] = priv->xor; + if (nft_data_dump(skb, NFTA_BITWISE_XOR, &data, + NFT_DATA_VALUE, sizeof(u32)) < 0) + return -1; + + return 0; +} + +static int nft_bitwise_fast_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const 
struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + if (priv->xor || priv->sreg != priv->dreg || reg->len != sizeof(u32)) + return -EOPNOTSUPP; + + reg->mask.data[0] = priv->mask; + return 0; +} + +static bool nft_bitwise_fast_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_bitwise_fast_expr *priv = nft_expr_priv(expr); + const struct nft_bitwise_fast_expr *bitwise; + + if (!track->regs[priv->sreg].selector) + return false; + + bitwise = nft_expr_priv(track->regs[priv->dreg].selector); + if (track->regs[priv->sreg].selector == track->regs[priv->dreg].selector && + track->regs[priv->dreg].bitwise && + track->regs[priv->dreg].bitwise->ops == expr->ops && + priv->sreg == bitwise->sreg && + priv->dreg == bitwise->dreg && + priv->mask == bitwise->mask && + priv->xor == bitwise->xor) { + track->cur = expr; + return true; + } + + if (track->regs[priv->sreg].bitwise) { + nft_reg_track_cancel(track, priv->dreg, NFT_REG32_SIZE); + return false; + } + + if (priv->sreg != priv->dreg) { + track->regs[priv->dreg].selector = + track->regs[priv->sreg].selector; + } + track->regs[priv->dreg].bitwise = expr; + + return false; +} + +const struct nft_expr_ops nft_bitwise_fast_ops = { + .type = &nft_bitwise_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_bitwise_fast_expr)), + .eval = NULL, /* inlined */ + .init = nft_bitwise_fast_init, + .dump = nft_bitwise_fast_dump, + .reduce = nft_bitwise_fast_reduce, + .offload = nft_bitwise_fast_offload, +}; + +static const struct nft_expr_ops * +nft_bitwise_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + int err; + u32 len; + + if (!tb[NFTA_BITWISE_LEN] || + !tb[NFTA_BITWISE_SREG] || + !tb[NFTA_BITWISE_DREG]) + return ERR_PTR(-EINVAL); + + err = nft_parse_u32_check(tb[NFTA_BITWISE_LEN], U8_MAX, &len); + if (err < 0) + return ERR_PTR(err); + + if (len != sizeof(u32)) + return &nft_bitwise_ops; + + if (tb[NFTA_BITWISE_OP] && + ntohl(nla_get_be32(tb[NFTA_BITWISE_OP])) != NFT_BITWISE_BOOL) + return &nft_bitwise_ops; + + return &nft_bitwise_fast_ops; +} + +struct nft_expr_type nft_bitwise_type __read_mostly = { + .name = "bitwise", + .select_ops = nft_bitwise_select_ops, + .policy = nft_bitwise_policy, + .maxattr = NFTA_BITWISE_MAX, + .owner = THIS_MODULE, +}; + +bool nft_expr_reduce_bitwise(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_expr *last = track->last; + const struct nft_expr *next; + + if (expr == last) + return false; + + next = nft_expr_next(expr); + if (next->ops == &nft_bitwise_ops) + return nft_bitwise_reduce(track, next); + else if (next->ops == &nft_bitwise_fast_ops) + return nft_bitwise_fast_reduce(track, next); + + return false; +} +EXPORT_SYMBOL_GPL(nft_expr_reduce_bitwise); diff --git a/net/netfilter/nft_byteorder.c b/net/netfilter/nft_byteorder.c new file mode 100644 index 000000000..605178133 --- /dev/null +++ b/net/netfilter/nft_byteorder.c @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_byteorder { + u8 sreg; + u8 dreg; + enum nft_byteorder_ops op:8; + u8 len; + u8 size; +}; + +void nft_byteorder_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_byteorder *priv = 
nft_expr_priv(expr); + u32 *src = ®s->data[priv->sreg]; + u32 *dst = ®s->data[priv->dreg]; + u16 *s16, *d16; + unsigned int i; + + s16 = (void *)src; + d16 = (void *)dst; + + switch (priv->size) { + case 8: { + u64 *dst64 = (void *)dst; + u64 src64; + + switch (priv->op) { + case NFT_BYTEORDER_NTOH: + for (i = 0; i < priv->len / 8; i++) { + src64 = nft_reg_load64(&src[i]); + nft_reg_store64(&dst64[i], + be64_to_cpu((__force __be64)src64)); + } + break; + case NFT_BYTEORDER_HTON: + for (i = 0; i < priv->len / 8; i++) { + src64 = (__force __u64) + cpu_to_be64(nft_reg_load64(&src[i])); + nft_reg_store64(&dst64[i], src64); + } + break; + } + break; + } + case 4: + switch (priv->op) { + case NFT_BYTEORDER_NTOH: + for (i = 0; i < priv->len / 4; i++) + dst[i] = ntohl((__force __be32)src[i]); + break; + case NFT_BYTEORDER_HTON: + for (i = 0; i < priv->len / 4; i++) + dst[i] = (__force __u32)htonl(src[i]); + break; + } + break; + case 2: + switch (priv->op) { + case NFT_BYTEORDER_NTOH: + for (i = 0; i < priv->len / 2; i++) + d16[i] = ntohs((__force __be16)s16[i]); + break; + case NFT_BYTEORDER_HTON: + for (i = 0; i < priv->len / 2; i++) + d16[i] = (__force __u16)htons(s16[i]); + break; + } + break; + } +} + +static const struct nla_policy nft_byteorder_policy[NFTA_BYTEORDER_MAX + 1] = { + [NFTA_BYTEORDER_SREG] = { .type = NLA_U32 }, + [NFTA_BYTEORDER_DREG] = { .type = NLA_U32 }, + [NFTA_BYTEORDER_OP] = { .type = NLA_U32 }, + [NFTA_BYTEORDER_LEN] = { .type = NLA_U32 }, + [NFTA_BYTEORDER_SIZE] = { .type = NLA_U32 }, +}; + +static int nft_byteorder_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_byteorder *priv = nft_expr_priv(expr); + u32 size, len; + int err; + + if (tb[NFTA_BYTEORDER_SREG] == NULL || + tb[NFTA_BYTEORDER_DREG] == NULL || + tb[NFTA_BYTEORDER_LEN] == NULL || + tb[NFTA_BYTEORDER_SIZE] == NULL || + tb[NFTA_BYTEORDER_OP] == NULL) + return -EINVAL; + + priv->op = ntohl(nla_get_be32(tb[NFTA_BYTEORDER_OP])); + switch (priv->op) { + case NFT_BYTEORDER_NTOH: + case NFT_BYTEORDER_HTON: + break; + default: + return -EINVAL; + } + + err = nft_parse_u32_check(tb[NFTA_BYTEORDER_SIZE], U8_MAX, &size); + if (err < 0) + return err; + + priv->size = size; + + switch (priv->size) { + case 2: + case 4: + case 8: + break; + default: + return -EINVAL; + } + + err = nft_parse_u32_check(tb[NFTA_BYTEORDER_LEN], U8_MAX, &len); + if (err < 0) + return err; + + priv->len = len; + + err = nft_parse_register_load(tb[NFTA_BYTEORDER_SREG], &priv->sreg, + priv->len); + if (err < 0) + return err; + + return nft_parse_register_store(ctx, tb[NFTA_BYTEORDER_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + priv->len); +} + +static int nft_byteorder_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_byteorder *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_BYTEORDER_SREG, priv->sreg)) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_BYTEORDER_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_BYTEORDER_OP, htonl(priv->op))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_BYTEORDER_LEN, htonl(priv->len))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_BYTEORDER_SIZE, htonl(priv->size))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_byteorder_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + struct nft_byteorder *priv = nft_expr_priv(expr); + + nft_reg_track_cancel(track, priv->dreg, priv->len); + + return 
false; +} + +static const struct nft_expr_ops nft_byteorder_ops = { + .type = &nft_byteorder_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_byteorder)), + .eval = nft_byteorder_eval, + .init = nft_byteorder_init, + .dump = nft_byteorder_dump, + .reduce = nft_byteorder_reduce, +}; + +struct nft_expr_type nft_byteorder_type __read_mostly = { + .name = "byteorder", + .ops = &nft_byteorder_ops, + .policy = nft_byteorder_policy, + .maxattr = NFTA_BYTEORDER_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_chain_filter.c b/net/netfilter/nft_chain_filter.c new file mode 100644 index 000000000..274b6f7e6 --- /dev/null +++ b/net/netfilter/nft_chain_filter.c @@ -0,0 +1,456 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_NF_TABLES_IPV4 +static unsigned int nft_do_chain_ipv4(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + nft_set_pktinfo_ipv4(&pkt); + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_ipv4 = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_IPV4, + .hook_mask = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING), + .hooks = { + [NF_INET_LOCAL_IN] = nft_do_chain_ipv4, + [NF_INET_LOCAL_OUT] = nft_do_chain_ipv4, + [NF_INET_FORWARD] = nft_do_chain_ipv4, + [NF_INET_PRE_ROUTING] = nft_do_chain_ipv4, + [NF_INET_POST_ROUTING] = nft_do_chain_ipv4, + }, +}; + +static void nft_chain_filter_ipv4_init(void) +{ + nft_register_chain_type(&nft_chain_filter_ipv4); +} +static void nft_chain_filter_ipv4_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_ipv4); +} + +#else +static inline void nft_chain_filter_ipv4_init(void) {} +static inline void nft_chain_filter_ipv4_fini(void) {} +#endif /* CONFIG_NF_TABLES_IPV4 */ + +#ifdef CONFIG_NF_TABLES_ARP +static unsigned int nft_do_chain_arp(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + nft_set_pktinfo_unspec(&pkt); + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_arp = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_ARP, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_ARP_IN) | + (1 << NF_ARP_OUT), + .hooks = { + [NF_ARP_IN] = nft_do_chain_arp, + [NF_ARP_OUT] = nft_do_chain_arp, + }, +}; + +static void nft_chain_filter_arp_init(void) +{ + nft_register_chain_type(&nft_chain_filter_arp); +} + +static void nft_chain_filter_arp_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_arp); +} +#else +static inline void nft_chain_filter_arp_init(void) {} +static inline void nft_chain_filter_arp_fini(void) {} +#endif /* CONFIG_NF_TABLES_ARP */ + +#ifdef CONFIG_NF_TABLES_IPV6 +static unsigned int nft_do_chain_ipv6(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + nft_set_pktinfo_ipv6(&pkt); + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_ipv6 = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_IPV6, + .hook_mask = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING), + .hooks = { + [NF_INET_LOCAL_IN] = 
nft_do_chain_ipv6, + [NF_INET_LOCAL_OUT] = nft_do_chain_ipv6, + [NF_INET_FORWARD] = nft_do_chain_ipv6, + [NF_INET_PRE_ROUTING] = nft_do_chain_ipv6, + [NF_INET_POST_ROUTING] = nft_do_chain_ipv6, + }, +}; + +static void nft_chain_filter_ipv6_init(void) +{ + nft_register_chain_type(&nft_chain_filter_ipv6); +} + +static void nft_chain_filter_ipv6_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_ipv6); +} +#else +static inline void nft_chain_filter_ipv6_init(void) {} +static inline void nft_chain_filter_ipv6_fini(void) {} +#endif /* CONFIG_NF_TABLES_IPV6 */ + +#ifdef CONFIG_NF_TABLES_INET +static unsigned int nft_do_chain_inet(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + + switch (state->pf) { + case NFPROTO_IPV4: + nft_set_pktinfo_ipv4(&pkt); + break; + case NFPROTO_IPV6: + nft_set_pktinfo_ipv6(&pkt); + break; + default: + break; + } + + return nft_do_chain(&pkt, priv); +} + +static unsigned int nft_do_chain_inet_ingress(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nf_hook_state ingress_state = *state; + struct nft_pktinfo pkt; + + switch (skb->protocol) { + case htons(ETH_P_IP): + /* Original hook is NFPROTO_NETDEV and NF_NETDEV_INGRESS. */ + ingress_state.pf = NFPROTO_IPV4; + ingress_state.hook = NF_INET_INGRESS; + nft_set_pktinfo(&pkt, skb, &ingress_state); + + if (nft_set_pktinfo_ipv4_ingress(&pkt) < 0) + return NF_DROP; + break; + case htons(ETH_P_IPV6): + ingress_state.pf = NFPROTO_IPV6; + ingress_state.hook = NF_INET_INGRESS; + nft_set_pktinfo(&pkt, skb, &ingress_state); + + if (nft_set_pktinfo_ipv6_ingress(&pkt) < 0) + return NF_DROP; + break; + default: + return NF_ACCEPT; + } + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_inet = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_INET, + .hook_mask = (1 << NF_INET_INGRESS) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING), + .hooks = { + [NF_INET_INGRESS] = nft_do_chain_inet_ingress, + [NF_INET_LOCAL_IN] = nft_do_chain_inet, + [NF_INET_LOCAL_OUT] = nft_do_chain_inet, + [NF_INET_FORWARD] = nft_do_chain_inet, + [NF_INET_PRE_ROUTING] = nft_do_chain_inet, + [NF_INET_POST_ROUTING] = nft_do_chain_inet, + }, +}; + +static void nft_chain_filter_inet_init(void) +{ + nft_register_chain_type(&nft_chain_filter_inet); +} + +static void nft_chain_filter_inet_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_inet); +} +#else +static inline void nft_chain_filter_inet_init(void) {} +static inline void nft_chain_filter_inet_fini(void) {} +#endif /* CONFIG_NF_TABLES_IPV6 */ + +#if IS_ENABLED(CONFIG_NF_TABLES_BRIDGE) +static unsigned int +nft_do_chain_bridge(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + + switch (eth_hdr(skb)->h_proto) { + case htons(ETH_P_IP): + nft_set_pktinfo_ipv4_validate(&pkt); + break; + case htons(ETH_P_IPV6): + nft_set_pktinfo_ipv6_validate(&pkt); + break; + default: + nft_set_pktinfo_unspec(&pkt); + break; + } + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_bridge = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_BRIDGE, + .hook_mask = (1 << NF_BR_PRE_ROUTING) | + (1 << NF_BR_LOCAL_IN) | + (1 << NF_BR_FORWARD) | + (1 << NF_BR_LOCAL_OUT) | + (1 << 
NF_BR_POST_ROUTING), + .hooks = { + [NF_BR_PRE_ROUTING] = nft_do_chain_bridge, + [NF_BR_LOCAL_IN] = nft_do_chain_bridge, + [NF_BR_FORWARD] = nft_do_chain_bridge, + [NF_BR_LOCAL_OUT] = nft_do_chain_bridge, + [NF_BR_POST_ROUTING] = nft_do_chain_bridge, + }, +}; + +static void nft_chain_filter_bridge_init(void) +{ + nft_register_chain_type(&nft_chain_filter_bridge); +} + +static void nft_chain_filter_bridge_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_bridge); +} +#else +static inline void nft_chain_filter_bridge_init(void) {} +static inline void nft_chain_filter_bridge_fini(void) {} +#endif /* CONFIG_NF_TABLES_BRIDGE */ + +#ifdef CONFIG_NF_TABLES_NETDEV +static unsigned int nft_do_chain_netdev(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + + switch (skb->protocol) { + case htons(ETH_P_IP): + nft_set_pktinfo_ipv4_validate(&pkt); + break; + case htons(ETH_P_IPV6): + nft_set_pktinfo_ipv6_validate(&pkt); + break; + default: + nft_set_pktinfo_unspec(&pkt); + break; + } + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_filter_netdev = { + .name = "filter", + .type = NFT_CHAIN_T_DEFAULT, + .family = NFPROTO_NETDEV, + .hook_mask = (1 << NF_NETDEV_INGRESS) | + (1 << NF_NETDEV_EGRESS), + .hooks = { + [NF_NETDEV_INGRESS] = nft_do_chain_netdev, + [NF_NETDEV_EGRESS] = nft_do_chain_netdev, + }, +}; + +static void nft_netdev_event(unsigned long event, struct net_device *dev, + struct nft_ctx *ctx) +{ + struct nft_base_chain *basechain = nft_base_chain(ctx->chain); + struct nft_hook *hook, *found = NULL; + int n = 0; + + if (event != NETDEV_UNREGISTER) + return; + + list_for_each_entry(hook, &basechain->hook_list, list) { + if (hook->ops.dev == dev) + found = hook; + + n++; + } + if (!found) + return; + + if (n > 1) { + nf_unregister_net_hook(ctx->net, &found->ops); + list_del_rcu(&found->list); + kfree_rcu(found, rcu); + return; + } + + /* UNREGISTER events are also happening on netns exit. + * + * Although nf_tables core releases all tables/chains, only this event + * handler provides guarantee that hook->ops.dev is still accessible, + * so we cannot skip exiting net namespaces. 
+ */ + __nft_release_basechain(ctx); +} + +static int nf_tables_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct nft_base_chain *basechain; + struct nftables_pernet *nft_net; + struct nft_chain *chain, *nr; + struct nft_table *table; + struct nft_ctx ctx = { + .net = dev_net(dev), + }; + + if (event != NETDEV_UNREGISTER && + event != NETDEV_CHANGENAME) + return NOTIFY_DONE; + + nft_net = nft_pernet(ctx.net); + mutex_lock(&nft_net->commit_mutex); + list_for_each_entry(table, &nft_net->tables, list) { + if (table->family != NFPROTO_NETDEV && + table->family != NFPROTO_INET) + continue; + + ctx.family = table->family; + ctx.table = table; + list_for_each_entry_safe(chain, nr, &table->chains, list) { + if (!nft_is_base_chain(chain)) + continue; + + basechain = nft_base_chain(chain); + if (table->family == NFPROTO_INET && + basechain->ops.hooknum != NF_INET_INGRESS) + continue; + + ctx.chain = chain; + nft_netdev_event(event, dev, &ctx); + } + } + mutex_unlock(&nft_net->commit_mutex); + + return NOTIFY_DONE; +} + +static struct notifier_block nf_tables_netdev_notifier = { + .notifier_call = nf_tables_netdev_event, +}; + +static int nft_chain_filter_netdev_init(void) +{ + int err; + + nft_register_chain_type(&nft_chain_filter_netdev); + + err = register_netdevice_notifier(&nf_tables_netdev_notifier); + if (err) + goto err_register_netdevice_notifier; + + return 0; + +err_register_netdevice_notifier: + nft_unregister_chain_type(&nft_chain_filter_netdev); + + return err; +} + +static void nft_chain_filter_netdev_fini(void) +{ + nft_unregister_chain_type(&nft_chain_filter_netdev); + unregister_netdevice_notifier(&nf_tables_netdev_notifier); +} +#else +static inline int nft_chain_filter_netdev_init(void) { return 0; } +static inline void nft_chain_filter_netdev_fini(void) {} +#endif /* CONFIG_NF_TABLES_NETDEV */ + +int __init nft_chain_filter_init(void) +{ + int err; + + err = nft_chain_filter_netdev_init(); + if (err < 0) + return err; + + nft_chain_filter_ipv4_init(); + nft_chain_filter_ipv6_init(); + nft_chain_filter_arp_init(); + nft_chain_filter_inet_init(); + nft_chain_filter_bridge_init(); + + return 0; +} + +void nft_chain_filter_fini(void) +{ + nft_chain_filter_bridge_fini(); + nft_chain_filter_inet_fini(); + nft_chain_filter_arp_fini(); + nft_chain_filter_ipv6_fini(); + nft_chain_filter_ipv4_fini(); + nft_chain_filter_netdev_fini(); +} diff --git a/net/netfilter/nft_chain_nat.c b/net/netfilter/nft_chain_nat.c new file mode 100644 index 000000000..98e494610 --- /dev/null +++ b/net/netfilter/nft_chain_nat.c @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include +#include +#include +#include + +static unsigned int nft_nat_do_chain(void *priv, struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + nft_set_pktinfo(&pkt, skb, state); + + switch (state->pf) { +#ifdef CONFIG_NF_TABLES_IPV4 + case NFPROTO_IPV4: + nft_set_pktinfo_ipv4(&pkt); + break; +#endif +#ifdef CONFIG_NF_TABLES_IPV6 + case NFPROTO_IPV6: + nft_set_pktinfo_ipv6(&pkt); + break; +#endif + default: + break; + } + + return nft_do_chain(&pkt, priv); +} + +#ifdef CONFIG_NF_TABLES_IPV4 +static const struct nft_chain_type nft_chain_nat_ipv4 = { + .name = "nat", + .type = NFT_CHAIN_T_NAT, + .family = NFPROTO_IPV4, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + 
.hooks = { + [NF_INET_PRE_ROUTING] = nft_nat_do_chain, + [NF_INET_POST_ROUTING] = nft_nat_do_chain, + [NF_INET_LOCAL_OUT] = nft_nat_do_chain, + [NF_INET_LOCAL_IN] = nft_nat_do_chain, + }, + .ops_register = nf_nat_ipv4_register_fn, + .ops_unregister = nf_nat_ipv4_unregister_fn, +}; +#endif + +#ifdef CONFIG_NF_TABLES_IPV6 +static const struct nft_chain_type nft_chain_nat_ipv6 = { + .name = "nat", + .type = NFT_CHAIN_T_NAT, + .family = NFPROTO_IPV6, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .hooks = { + [NF_INET_PRE_ROUTING] = nft_nat_do_chain, + [NF_INET_POST_ROUTING] = nft_nat_do_chain, + [NF_INET_LOCAL_OUT] = nft_nat_do_chain, + [NF_INET_LOCAL_IN] = nft_nat_do_chain, + }, + .ops_register = nf_nat_ipv6_register_fn, + .ops_unregister = nf_nat_ipv6_unregister_fn, +}; +#endif + +#ifdef CONFIG_NF_TABLES_INET +static int nft_nat_inet_reg(struct net *net, const struct nf_hook_ops *ops) +{ + return nf_nat_inet_register_fn(net, ops); +} + +static void nft_nat_inet_unreg(struct net *net, const struct nf_hook_ops *ops) +{ + nf_nat_inet_unregister_fn(net, ops); +} + +static const struct nft_chain_type nft_chain_nat_inet = { + .name = "nat", + .type = NFT_CHAIN_T_NAT, + .family = NFPROTO_INET, + .owner = THIS_MODULE, + .hook_mask = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING), + .hooks = { + [NF_INET_PRE_ROUTING] = nft_nat_do_chain, + [NF_INET_LOCAL_IN] = nft_nat_do_chain, + [NF_INET_LOCAL_OUT] = nft_nat_do_chain, + [NF_INET_POST_ROUTING] = nft_nat_do_chain, + }, + .ops_register = nft_nat_inet_reg, + .ops_unregister = nft_nat_inet_unreg, +}; +#endif + +static int __init nft_chain_nat_init(void) +{ +#ifdef CONFIG_NF_TABLES_IPV6 + nft_register_chain_type(&nft_chain_nat_ipv6); +#endif +#ifdef CONFIG_NF_TABLES_IPV4 + nft_register_chain_type(&nft_chain_nat_ipv4); +#endif +#ifdef CONFIG_NF_TABLES_INET + nft_register_chain_type(&nft_chain_nat_inet); +#endif + + return 0; +} + +static void __exit nft_chain_nat_exit(void) +{ +#ifdef CONFIG_NF_TABLES_IPV4 + nft_unregister_chain_type(&nft_chain_nat_ipv4); +#endif +#ifdef CONFIG_NF_TABLES_IPV6 + nft_unregister_chain_type(&nft_chain_nat_ipv6); +#endif +#ifdef CONFIG_NF_TABLES_INET + nft_unregister_chain_type(&nft_chain_nat_inet); +#endif +} + +module_init(nft_chain_nat_init); +module_exit(nft_chain_nat_exit); + +MODULE_LICENSE("GPL"); +#ifdef CONFIG_NF_TABLES_IPV4 +MODULE_ALIAS_NFT_CHAIN(AF_INET, "nat"); +#endif +#ifdef CONFIG_NF_TABLES_IPV6 +MODULE_ALIAS_NFT_CHAIN(AF_INET6, "nat"); +#endif +#ifdef CONFIG_NF_TABLES_INET +MODULE_ALIAS_NFT_CHAIN(1, "nat"); /* NFPROTO_INET */ +#endif diff --git a/net/netfilter/nft_chain_route.c b/net/netfilter/nft_chain_route.c new file mode 100644 index 000000000..925db0dce --- /dev/null +++ b/net/netfilter/nft_chain_route.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_NF_TABLES_IPV4 +static unsigned int nf_route_table_hook4(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + const struct iphdr *iph; + struct nft_pktinfo pkt; + __be32 saddr, daddr; + unsigned int ret; + u32 mark; + int err; + u8 tos; + + nft_set_pktinfo(&pkt, skb, state); + nft_set_pktinfo_ipv4(&pkt); + + mark = skb->mark; + iph = ip_hdr(skb); + saddr = iph->saddr; + daddr = iph->daddr; + tos = iph->tos; + + ret = 
nft_do_chain(&pkt, priv); + if (ret == NF_ACCEPT) { + iph = ip_hdr(skb); + + if (iph->saddr != saddr || + iph->daddr != daddr || + skb->mark != mark || + iph->tos != tos) { + err = ip_route_me_harder(state->net, state->sk, skb, RTN_UNSPEC); + if (err < 0) + ret = NF_DROP_ERR(err); + } + } + return ret; +} + +static const struct nft_chain_type nft_chain_route_ipv4 = { + .name = "route", + .type = NFT_CHAIN_T_ROUTE, + .family = NFPROTO_IPV4, + .hook_mask = (1 << NF_INET_LOCAL_OUT), + .hooks = { + [NF_INET_LOCAL_OUT] = nf_route_table_hook4, + }, +}; +#endif + +#ifdef CONFIG_NF_TABLES_IPV6 +static unsigned int nf_route_table_hook6(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct in6_addr saddr, daddr; + struct nft_pktinfo pkt; + u32 mark, flowlabel; + unsigned int ret; + u8 hop_limit; + int err; + + nft_set_pktinfo(&pkt, skb, state); + nft_set_pktinfo_ipv6(&pkt); + + /* save source/dest address, mark, hoplimit, flowlabel, priority */ + memcpy(&saddr, &ipv6_hdr(skb)->saddr, sizeof(saddr)); + memcpy(&daddr, &ipv6_hdr(skb)->daddr, sizeof(daddr)); + mark = skb->mark; + hop_limit = ipv6_hdr(skb)->hop_limit; + + /* flowlabel and prio (includes version, which shouldn't change either)*/ + flowlabel = *((u32 *)ipv6_hdr(skb)); + + ret = nft_do_chain(&pkt, priv); + if (ret == NF_ACCEPT && + (memcmp(&ipv6_hdr(skb)->saddr, &saddr, sizeof(saddr)) || + memcmp(&ipv6_hdr(skb)->daddr, &daddr, sizeof(daddr)) || + skb->mark != mark || + ipv6_hdr(skb)->hop_limit != hop_limit || + flowlabel != *((u32 *)ipv6_hdr(skb)))) { + err = nf_ip6_route_me_harder(state->net, state->sk, skb); + if (err < 0) + ret = NF_DROP_ERR(err); + } + + return ret; +} + +static const struct nft_chain_type nft_chain_route_ipv6 = { + .name = "route", + .type = NFT_CHAIN_T_ROUTE, + .family = NFPROTO_IPV6, + .hook_mask = (1 << NF_INET_LOCAL_OUT), + .hooks = { + [NF_INET_LOCAL_OUT] = nf_route_table_hook6, + }, +}; +#endif + +#ifdef CONFIG_NF_TABLES_INET +static unsigned int nf_route_table_inet(void *priv, + struct sk_buff *skb, + const struct nf_hook_state *state) +{ + struct nft_pktinfo pkt; + + switch (state->pf) { + case NFPROTO_IPV4: + return nf_route_table_hook4(priv, skb, state); + case NFPROTO_IPV6: + return nf_route_table_hook6(priv, skb, state); + default: + nft_set_pktinfo(&pkt, skb, state); + break; + } + + return nft_do_chain(&pkt, priv); +} + +static const struct nft_chain_type nft_chain_route_inet = { + .name = "route", + .type = NFT_CHAIN_T_ROUTE, + .family = NFPROTO_INET, + .hook_mask = (1 << NF_INET_LOCAL_OUT), + .hooks = { + [NF_INET_LOCAL_OUT] = nf_route_table_inet, + }, +}; +#endif + +void __init nft_chain_route_init(void) +{ +#ifdef CONFIG_NF_TABLES_IPV6 + nft_register_chain_type(&nft_chain_route_ipv6); +#endif +#ifdef CONFIG_NF_TABLES_IPV4 + nft_register_chain_type(&nft_chain_route_ipv4); +#endif +#ifdef CONFIG_NF_TABLES_INET + nft_register_chain_type(&nft_chain_route_inet); +#endif +} + +void __exit nft_chain_route_fini(void) +{ +#ifdef CONFIG_NF_TABLES_IPV6 + nft_unregister_chain_type(&nft_chain_route_ipv6); +#endif +#ifdef CONFIG_NF_TABLES_IPV4 + nft_unregister_chain_type(&nft_chain_route_ipv4); +#endif +#ifdef CONFIG_NF_TABLES_INET + nft_unregister_chain_type(&nft_chain_route_inet); +#endif +} diff --git a/net/netfilter/nft_cmp.c b/net/netfilter/nft_cmp.c new file mode 100644 index 000000000..963cf8317 --- /dev/null +++ b/net/netfilter/nft_cmp.c @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code 
funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_cmp_expr { + struct nft_data data; + u8 sreg; + u8 len; + enum nft_cmp_ops op:8; +}; + +void nft_cmp_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_cmp_expr *priv = nft_expr_priv(expr); + int d; + + d = memcmp(&regs->data[priv->sreg], &priv->data, priv->len); + switch (priv->op) { + case NFT_CMP_EQ: + if (d != 0) + goto mismatch; + break; + case NFT_CMP_NEQ: + if (d == 0) + goto mismatch; + break; + case NFT_CMP_LT: + if (d == 0) + goto mismatch; + fallthrough; + case NFT_CMP_LTE: + if (d > 0) + goto mismatch; + break; + case NFT_CMP_GT: + if (d == 0) + goto mismatch; + fallthrough; + case NFT_CMP_GTE: + if (d < 0) + goto mismatch; + break; + } + return; + +mismatch: + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_cmp_policy[NFTA_CMP_MAX + 1] = { + [NFTA_CMP_SREG] = { .type = NLA_U32 }, + [NFTA_CMP_OP] = { .type = NLA_U32 }, + [NFTA_CMP_DATA] = { .type = NLA_NESTED }, +}; + +static int nft_cmp_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_cmp_expr *priv = nft_expr_priv(expr); + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->data), + }; + int err; + + err = nft_data_init(NULL, &priv->data, &desc, tb[NFTA_CMP_DATA]); + if (err < 0) + return err; + + err = nft_parse_register_load(tb[NFTA_CMP_SREG], &priv->sreg, desc.len); + if (err < 0) + return err; + + priv->op = ntohl(nla_get_be32(tb[NFTA_CMP_OP])); + priv->len = desc.len; + return 0; +} + +static int nft_cmp_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_cmp_expr *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_CMP_SREG, priv->sreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_CMP_OP, htonl(priv->op))) + goto nla_put_failure; + + if (nft_data_dump(skb, NFTA_CMP_DATA, &priv->data, + NFT_DATA_VALUE, priv->len) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +union nft_cmp_offload_data { + u16 val16; + u32 val32; + u64 val64; +}; + +static void nft_payload_n2h(union nft_cmp_offload_data *data, + const u8 *val, u32 len) +{ + switch (len) { + case 2: + data->val16 = ntohs(*((__be16 *)val)); + break; + case 4: + data->val32 = ntohl(*((__be32 *)val)); + break; + case 8: + data->val64 = be64_to_cpu(*((__be64 *)val)); + break; + default: + WARN_ON_ONCE(1); + break; + } +} + +static int __nft_cmp_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_cmp_expr *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->sreg]; + union nft_cmp_offload_data _data, _datamask; + u8 *mask = (u8 *)&flow->match.mask; + u8 *key = (u8 *)&flow->match.key; + u8 *data, *datamask; + + if (priv->op != NFT_CMP_EQ || priv->len > reg->len) + return -EOPNOTSUPP; + + if (reg->flags & NFT_OFFLOAD_F_NETWORK2HOST) { + nft_payload_n2h(&_data, (u8 *)&priv->data, reg->len); + nft_payload_n2h(&_datamask, (u8 *)&reg->mask, reg->len); + data = (u8 *)&_data; + datamask = (u8 *)&_datamask; + } else { + data = (u8 *)&priv->data; + datamask = (u8 *)&reg->mask; + } + + memcpy(key + reg->offset, data, reg->len); + memcpy(mask + reg->offset, datamask, reg->len); + + flow->match.dissector.used_keys |= BIT(reg->key); + flow->match.dissector.offset[reg->key] = reg->base_offset; + + if (reg->key == FLOW_DISSECTOR_KEY_META &&
reg->offset == offsetof(struct nft_flow_key, meta.ingress_iftype) && + nft_reg_load16(priv->data.data) != ARPHRD_ETHER) + return -EOPNOTSUPP; + + nft_offload_update_dependency(ctx, &priv->data, reg->len); + + return 0; +} + +static int nft_cmp_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_cmp_expr *priv = nft_expr_priv(expr); + + return __nft_cmp_offload(ctx, flow, priv); +} + +static const struct nft_expr_ops nft_cmp_ops = { + .type = &nft_cmp_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_cmp_expr)), + .eval = nft_cmp_eval, + .init = nft_cmp_init, + .dump = nft_cmp_dump, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_cmp_offload, +}; + +/* Calculate the mask for the nft_cmp_fast expression. On big endian the + * mask needs to include the *upper* bytes when interpreting that data as + * something smaller than the full u32, therefore a cpu_to_le32 is done. + */ +static u32 nft_cmp_fast_mask(unsigned int len) +{ + __le32 mask = cpu_to_le32(~0U >> (sizeof_field(struct nft_cmp_fast_expr, + data) * BITS_PER_BYTE - len)); + + return (__force u32)mask; +} + +static int nft_cmp_fast_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_cmp_fast_expr *priv = nft_expr_priv(expr); + struct nft_data data; + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(data), + }; + int err; + + err = nft_data_init(NULL, &data, &desc, tb[NFTA_CMP_DATA]); + if (err < 0) + return err; + + err = nft_parse_register_load(tb[NFTA_CMP_SREG], &priv->sreg, desc.len); + if (err < 0) + return err; + + desc.len *= BITS_PER_BYTE; + + priv->mask = nft_cmp_fast_mask(desc.len); + priv->data = data.data[0] & priv->mask; + priv->len = desc.len; + priv->inv = ntohl(nla_get_be32(tb[NFTA_CMP_OP])) != NFT_CMP_EQ; + return 0; +} + +static int nft_cmp_fast_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_cmp_fast_expr *priv = nft_expr_priv(expr); + struct nft_cmp_expr cmp = { + .data = { + .data = { + [0] = priv->data, + }, + }, + .sreg = priv->sreg, + .len = priv->len / BITS_PER_BYTE, + .op = priv->inv ? NFT_CMP_NEQ : NFT_CMP_EQ, + }; + + return __nft_cmp_offload(ctx, flow, &cmp); +} + +static int nft_cmp_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_cmp_fast_expr *priv = nft_expr_priv(expr); + enum nft_cmp_ops op = priv->inv ? 
NFT_CMP_NEQ : NFT_CMP_EQ; + struct nft_data data; + + if (nft_dump_register(skb, NFTA_CMP_SREG, priv->sreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_CMP_OP, htonl(op))) + goto nla_put_failure; + + data.data[0] = priv->data; + if (nft_data_dump(skb, NFTA_CMP_DATA, &data, + NFT_DATA_VALUE, priv->len / BITS_PER_BYTE) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +const struct nft_expr_ops nft_cmp_fast_ops = { + .type = &nft_cmp_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_cmp_fast_expr)), + .eval = NULL, /* inlined */ + .init = nft_cmp_fast_init, + .dump = nft_cmp_fast_dump, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_cmp_fast_offload, +}; + +static u32 nft_cmp_mask(u32 bitlen) +{ + return (__force u32)cpu_to_le32(~0U >> (sizeof(u32) * BITS_PER_BYTE - bitlen)); +} + +static void nft_cmp16_fast_mask(struct nft_data *data, unsigned int bitlen) +{ + int len = bitlen / BITS_PER_BYTE; + int i, words = len / sizeof(u32); + + for (i = 0; i < words; i++) { + data->data[i] = 0xffffffff; + bitlen -= sizeof(u32) * BITS_PER_BYTE; + } + + if (len % sizeof(u32)) + data->data[i++] = nft_cmp_mask(bitlen); + + for (; i < 4; i++) + data->data[i] = 0; +} + +static int nft_cmp16_fast_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_cmp16_fast_expr *priv = nft_expr_priv(expr); + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->data), + }; + int err; + + err = nft_data_init(NULL, &priv->data, &desc, tb[NFTA_CMP_DATA]); + if (err < 0) + return err; + + err = nft_parse_register_load(tb[NFTA_CMP_SREG], &priv->sreg, desc.len); + if (err < 0) + return err; + + nft_cmp16_fast_mask(&priv->mask, desc.len * BITS_PER_BYTE); + priv->inv = ntohl(nla_get_be32(tb[NFTA_CMP_OP])) != NFT_CMP_EQ; + priv->len = desc.len; + + return 0; +} + +static int nft_cmp16_fast_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_cmp16_fast_expr *priv = nft_expr_priv(expr); + struct nft_cmp_expr cmp = { + .data = priv->data, + .sreg = priv->sreg, + .len = priv->len, + .op = priv->inv ? NFT_CMP_NEQ : NFT_CMP_EQ, + }; + + return __nft_cmp_offload(ctx, flow, &cmp); +} + +static int nft_cmp16_fast_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_cmp16_fast_expr *priv = nft_expr_priv(expr); + enum nft_cmp_ops op = priv->inv ? 
NFT_CMP_NEQ : NFT_CMP_EQ; + + if (nft_dump_register(skb, NFTA_CMP_SREG, priv->sreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_CMP_OP, htonl(op))) + goto nla_put_failure; + + if (nft_data_dump(skb, NFTA_CMP_DATA, &priv->data, + NFT_DATA_VALUE, priv->len) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + + +const struct nft_expr_ops nft_cmp16_fast_ops = { + .type = &nft_cmp_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_cmp16_fast_expr)), + .eval = NULL, /* inlined */ + .init = nft_cmp16_fast_init, + .dump = nft_cmp16_fast_dump, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_cmp16_fast_offload, +}; + +static const struct nft_expr_ops * +nft_cmp_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[]) +{ + struct nft_data data; + struct nft_data_desc desc = { + .type = NFT_DATA_VALUE, + .size = sizeof(data), + }; + enum nft_cmp_ops op; + u8 sreg; + int err; + + if (tb[NFTA_CMP_SREG] == NULL || + tb[NFTA_CMP_OP] == NULL || + tb[NFTA_CMP_DATA] == NULL) + return ERR_PTR(-EINVAL); + + op = ntohl(nla_get_be32(tb[NFTA_CMP_OP])); + switch (op) { + case NFT_CMP_EQ: + case NFT_CMP_NEQ: + case NFT_CMP_LT: + case NFT_CMP_LTE: + case NFT_CMP_GT: + case NFT_CMP_GTE: + break; + default: + return ERR_PTR(-EINVAL); + } + + err = nft_data_init(NULL, &data, &desc, tb[NFTA_CMP_DATA]); + if (err < 0) + return ERR_PTR(err); + + sreg = ntohl(nla_get_be32(tb[NFTA_CMP_SREG])); + + if (op == NFT_CMP_EQ || op == NFT_CMP_NEQ) { + if (desc.len <= sizeof(u32)) + return &nft_cmp_fast_ops; + else if (desc.len <= sizeof(data) && + ((sreg >= NFT_REG_1 && sreg <= NFT_REG_4) || + (sreg >= NFT_REG32_00 && sreg <= NFT_REG32_12 && sreg % 2 == 0))) + return &nft_cmp16_fast_ops; + } + return &nft_cmp_ops; +} + +struct nft_expr_type nft_cmp_type __read_mostly = { + .name = "cmp", + .select_ops = nft_cmp_select_ops, + .policy = nft_cmp_policy, + .maxattr = NFTA_CMP_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_compat.c b/net/netfilter/nft_compat.c new file mode 100644 index 000000000..6952da7df --- /dev/null +++ b/net/netfilter/nft_compat.c @@ -0,0 +1,958 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2012-2013 by Pablo Neira Ayuso + * + * This software has been sponsored by Sophos Astaro + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Used for matches where *info is larger than X byte */ +#define NFT_MATCH_LARGE_THRESH 192 + +struct nft_xt_match_priv { + void *info; +}; + +static int nft_compat_chain_validate_dependency(const struct nft_ctx *ctx, + const char *tablename) +{ + enum nft_chain_types type = NFT_CHAIN_T_DEFAULT; + const struct nft_chain *chain = ctx->chain; + const struct nft_base_chain *basechain; + + if (!tablename || + !nft_is_base_chain(chain)) + return 0; + + basechain = nft_base_chain(chain); + if (strcmp(tablename, "nat") == 0) { + if (ctx->family != NFPROTO_BRIDGE) + type = NFT_CHAIN_T_NAT; + if (basechain->type->type != type) + return -EINVAL; + } + + return 0; +} + +union nft_entry { + struct ipt_entry e4; + struct ip6t_entry e6; + struct ebt_entry ebt; + struct arpt_entry arp; +}; + +static inline void +nft_compat_set_par(struct xt_action_param *par, + const struct nft_pktinfo *pkt, + const void *xt, const void *xt_info) +{ + par->state = pkt->state; + par->thoff = nft_thoff(pkt); + par->fragoff = pkt->fragoff; + par->target = xt; + par->targinfo = xt_info; + par->hotdrop = false; +} + +static void 
nft_target_eval_xt(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + void *info = nft_expr_priv(expr); + struct xt_target *target = expr->ops->data; + struct sk_buff *skb = pkt->skb; + struct xt_action_param xt; + int ret; + + nft_compat_set_par(&xt, pkt, target, info); + + ret = target->target(skb, &xt); + + if (xt.hotdrop) + ret = NF_DROP; + + switch (ret) { + case XT_CONTINUE: + regs->verdict.code = NFT_CONTINUE; + break; + default: + regs->verdict.code = ret; + break; + } +} + +static void nft_target_eval_bridge(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + void *info = nft_expr_priv(expr); + struct xt_target *target = expr->ops->data; + struct sk_buff *skb = pkt->skb; + struct xt_action_param xt; + int ret; + + nft_compat_set_par(&xt, pkt, target, info); + + ret = target->target(skb, &xt); + + if (xt.hotdrop) + ret = NF_DROP; + + switch (ret) { + case EBT_ACCEPT: + regs->verdict.code = NF_ACCEPT; + break; + case EBT_DROP: + regs->verdict.code = NF_DROP; + break; + case EBT_CONTINUE: + regs->verdict.code = NFT_CONTINUE; + break; + case EBT_RETURN: + regs->verdict.code = NFT_RETURN; + break; + default: + regs->verdict.code = ret; + break; + } +} + +static const struct nla_policy nft_target_policy[NFTA_TARGET_MAX + 1] = { + [NFTA_TARGET_NAME] = { .type = NLA_NUL_STRING }, + [NFTA_TARGET_REV] = { .type = NLA_U32 }, + [NFTA_TARGET_INFO] = { .type = NLA_BINARY }, +}; + +static void +nft_target_set_tgchk_param(struct xt_tgchk_param *par, + const struct nft_ctx *ctx, + struct xt_target *target, void *info, + union nft_entry *entry, u16 proto, bool inv) +{ + par->net = ctx->net; + par->table = ctx->table->name; + switch (ctx->family) { + case AF_INET: + entry->e4.ip.proto = proto; + entry->e4.ip.invflags = inv ? IPT_INV_PROTO : 0; + break; + case AF_INET6: + if (proto) + entry->e6.ipv6.flags |= IP6T_F_PROTO; + + entry->e6.ipv6.proto = proto; + entry->e6.ipv6.invflags = inv ? IP6T_INV_PROTO : 0; + break; + case NFPROTO_BRIDGE: + entry->ebt.ethproto = (__force __be16)proto; + entry->ebt.invflags = inv ? 
EBT_IPROTO : 0; + break; + case NFPROTO_ARP: + break; + } + par->entryinfo = entry; + par->target = target; + par->targinfo = info; + if (nft_is_base_chain(ctx->chain)) { + const struct nft_base_chain *basechain = + nft_base_chain(ctx->chain); + const struct nf_hook_ops *ops = &basechain->ops; + + par->hook_mask = 1 << ops->hooknum; + } else { + par->hook_mask = 0; + } + par->family = ctx->family; + par->nft_compat = true; +} + +static void target_compat_from_user(struct xt_target *t, void *in, void *out) +{ + int pad; + + memcpy(out, in, t->targetsize); + pad = XT_ALIGN(t->targetsize) - t->targetsize; + if (pad > 0) + memset(out + t->targetsize, 0, pad); +} + +static const struct nla_policy nft_rule_compat_policy[NFTA_RULE_COMPAT_MAX + 1] = { + [NFTA_RULE_COMPAT_PROTO] = { .type = NLA_U32 }, + [NFTA_RULE_COMPAT_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_parse_compat(const struct nlattr *attr, u16 *proto, bool *inv) +{ + struct nlattr *tb[NFTA_RULE_COMPAT_MAX+1]; + u32 flags; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_RULE_COMPAT_MAX, attr, + nft_rule_compat_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_RULE_COMPAT_PROTO] || !tb[NFTA_RULE_COMPAT_FLAGS]) + return -EINVAL; + + flags = ntohl(nla_get_be32(tb[NFTA_RULE_COMPAT_FLAGS])); + if (flags & ~NFT_RULE_COMPAT_F_MASK) + return -EINVAL; + if (flags & NFT_RULE_COMPAT_F_INV) + *inv = true; + + *proto = ntohl(nla_get_be32(tb[NFTA_RULE_COMPAT_PROTO])); + return 0; +} + +static void nft_compat_wait_for_destructors(void) +{ + /* xtables matches or targets can have side effects, e.g. + * creation/destruction of /proc files. + * The xt ->destroy functions are run asynchronously from + * work queue. If we have pending invocations we thus + * need to wait for those to finish. 
+ */ + nf_tables_trans_destroy_flush_work(); +} + +static int +nft_target_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + void *info = nft_expr_priv(expr); + struct xt_target *target = expr->ops->data; + struct xt_tgchk_param par; + size_t size = XT_ALIGN(nla_len(tb[NFTA_TARGET_INFO])); + u16 proto = 0; + bool inv = false; + union nft_entry e = {}; + int ret; + + target_compat_from_user(target, nla_data(tb[NFTA_TARGET_INFO]), info); + + if (ctx->nla[NFTA_RULE_COMPAT]) { + ret = nft_parse_compat(ctx->nla[NFTA_RULE_COMPAT], &proto, &inv); + if (ret < 0) + return ret; + } + + nft_target_set_tgchk_param(&par, ctx, target, info, &e, proto, inv); + + nft_compat_wait_for_destructors(); + + ret = xt_check_target(&par, size, proto, inv); + if (ret < 0) { + if (ret == -ENOENT) { + const char *modname = NULL; + + if (strcmp(target->name, "LOG") == 0) + modname = "nf_log_syslog"; + else if (strcmp(target->name, "NFLOG") == 0) + modname = "nfnetlink_log"; + + if (modname && + nft_request_module(ctx->net, "%s", modname) == -EAGAIN) + return -EAGAIN; + } + + return ret; + } + + /* The standard target cannot be used */ + if (!target->target) + return -EINVAL; + + return 0; +} + +static void __nft_mt_tg_destroy(struct module *me, const struct nft_expr *expr) +{ + module_put(me); + kfree(expr->ops); +} + +static void +nft_target_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + struct xt_target *target = expr->ops->data; + void *info = nft_expr_priv(expr); + struct module *me = target->me; + struct xt_tgdtor_param par; + + par.net = ctx->net; + par.target = target; + par.targinfo = info; + par.family = ctx->family; + if (par.target->destroy != NULL) + par.target->destroy(&par); + + __nft_mt_tg_destroy(me, expr); +} + +static int nft_extension_dump_info(struct sk_buff *skb, int attr, + const void *info, + unsigned int size, unsigned int user_size) +{ + unsigned int info_size, aligned_size = XT_ALIGN(size); + struct nlattr *nla; + + nla = nla_reserve(skb, attr, aligned_size); + if (!nla) + return -1; + + info_size = user_size ? 
: size; + memcpy(nla_data(nla), info, info_size); + memset(nla_data(nla) + info_size, 0, aligned_size - info_size); + + return 0; +} + +static int nft_target_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct xt_target *target = expr->ops->data; + void *info = nft_expr_priv(expr); + + if (nla_put_string(skb, NFTA_TARGET_NAME, target->name) || + nla_put_be32(skb, NFTA_TARGET_REV, htonl(target->revision)) || + nft_extension_dump_info(skb, NFTA_TARGET_INFO, info, + target->targetsize, target->usersize)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_target_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + struct xt_target *target = expr->ops->data; + unsigned int hook_mask = 0; + int ret; + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_BRIDGE && + ctx->family != NFPROTO_ARP) + return -EOPNOTSUPP; + + if (nft_is_base_chain(ctx->chain)) { + const struct nft_base_chain *basechain = + nft_base_chain(ctx->chain); + const struct nf_hook_ops *ops = &basechain->ops; + + hook_mask = 1 << ops->hooknum; + if (target->hooks && !(hook_mask & target->hooks)) + return -EINVAL; + + ret = nft_compat_chain_validate_dependency(ctx, target->table); + if (ret < 0) + return ret; + } + return 0; +} + +static void __nft_match_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt, + void *info) +{ + struct xt_match *match = expr->ops->data; + struct sk_buff *skb = pkt->skb; + struct xt_action_param xt; + bool ret; + + nft_compat_set_par(&xt, pkt, match, info); + + ret = match->match(skb, &xt); + + if (xt.hotdrop) { + regs->verdict.code = NF_DROP; + return; + } + + switch (ret ? 1 : 0) { + case 1: + regs->verdict.code = NFT_CONTINUE; + break; + case 0: + regs->verdict.code = NFT_BREAK; + break; + } +} + +static void nft_match_large_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_xt_match_priv *priv = nft_expr_priv(expr); + + __nft_match_eval(expr, regs, pkt, priv->info); +} + +static void nft_match_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + __nft_match_eval(expr, regs, pkt, nft_expr_priv(expr)); +} + +static const struct nla_policy nft_match_policy[NFTA_MATCH_MAX + 1] = { + [NFTA_MATCH_NAME] = { .type = NLA_NUL_STRING }, + [NFTA_MATCH_REV] = { .type = NLA_U32 }, + [NFTA_MATCH_INFO] = { .type = NLA_BINARY }, +}; + +/* struct xt_mtchk_param and xt_tgchk_param look very similar */ +static void +nft_match_set_mtchk_param(struct xt_mtchk_param *par, const struct nft_ctx *ctx, + struct xt_match *match, void *info, + union nft_entry *entry, u16 proto, bool inv) +{ + par->net = ctx->net; + par->table = ctx->table->name; + switch (ctx->family) { + case AF_INET: + entry->e4.ip.proto = proto; + entry->e4.ip.invflags = inv ? IPT_INV_PROTO : 0; + break; + case AF_INET6: + if (proto) + entry->e6.ipv6.flags |= IP6T_F_PROTO; + + entry->e6.ipv6.proto = proto; + entry->e6.ipv6.invflags = inv ? IP6T_INV_PROTO : 0; + break; + case NFPROTO_BRIDGE: + entry->ebt.ethproto = (__force __be16)proto; + entry->ebt.invflags = inv ? 
EBT_IPROTO : 0; + break; + case NFPROTO_ARP: + break; + } + par->entryinfo = entry; + par->match = match; + par->matchinfo = info; + if (nft_is_base_chain(ctx->chain)) { + const struct nft_base_chain *basechain = + nft_base_chain(ctx->chain); + const struct nf_hook_ops *ops = &basechain->ops; + + par->hook_mask = 1 << ops->hooknum; + } else { + par->hook_mask = 0; + } + par->family = ctx->family; + par->nft_compat = true; +} + +static void match_compat_from_user(struct xt_match *m, void *in, void *out) +{ + int pad; + + memcpy(out, in, m->matchsize); + pad = XT_ALIGN(m->matchsize) - m->matchsize; + if (pad > 0) + memset(out + m->matchsize, 0, pad); +} + +static int +__nft_match_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[], + void *info) +{ + struct xt_match *match = expr->ops->data; + struct xt_mtchk_param par; + size_t size = XT_ALIGN(nla_len(tb[NFTA_MATCH_INFO])); + u16 proto = 0; + bool inv = false; + union nft_entry e = {}; + int ret; + + match_compat_from_user(match, nla_data(tb[NFTA_MATCH_INFO]), info); + + if (ctx->nla[NFTA_RULE_COMPAT]) { + ret = nft_parse_compat(ctx->nla[NFTA_RULE_COMPAT], &proto, &inv); + if (ret < 0) + return ret; + } + + nft_match_set_mtchk_param(&par, ctx, match, info, &e, proto, inv); + + nft_compat_wait_for_destructors(); + + return xt_check_match(&par, size, proto, inv); +} + +static int +nft_match_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + return __nft_match_init(ctx, expr, tb, nft_expr_priv(expr)); +} + +static int +nft_match_large_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_xt_match_priv *priv = nft_expr_priv(expr); + struct xt_match *m = expr->ops->data; + int ret; + + priv->info = kmalloc(XT_ALIGN(m->matchsize), GFP_KERNEL); + if (!priv->info) + return -ENOMEM; + + ret = __nft_match_init(ctx, expr, tb, priv->info); + if (ret) + kfree(priv->info); + return ret; +} + +static void +__nft_match_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr, + void *info) +{ + struct xt_match *match = expr->ops->data; + struct module *me = match->me; + struct xt_mtdtor_param par; + + par.net = ctx->net; + par.match = match; + par.matchinfo = info; + par.family = ctx->family; + if (par.match->destroy != NULL) + par.match->destroy(&par); + + __nft_mt_tg_destroy(me, expr); +} + +static void +nft_match_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + __nft_match_destroy(ctx, expr, nft_expr_priv(expr)); +} + +static void +nft_match_large_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + struct nft_xt_match_priv *priv = nft_expr_priv(expr); + + __nft_match_destroy(ctx, expr, priv->info); + kfree(priv->info); +} + +static int __nft_match_dump(struct sk_buff *skb, const struct nft_expr *expr, + void *info) +{ + struct xt_match *match = expr->ops->data; + + if (nla_put_string(skb, NFTA_MATCH_NAME, match->name) || + nla_put_be32(skb, NFTA_MATCH_REV, htonl(match->revision)) || + nft_extension_dump_info(skb, NFTA_MATCH_INFO, info, + match->matchsize, match->usersize)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_match_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + return __nft_match_dump(skb, expr, nft_expr_priv(expr)); +} + +static int nft_match_large_dump(struct sk_buff *skb, const struct nft_expr *e) +{ + struct nft_xt_match_priv *priv = nft_expr_priv(e); + + return __nft_match_dump(skb, e, 
priv->info); +} + +static int nft_match_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + struct xt_match *match = expr->ops->data; + unsigned int hook_mask = 0; + int ret; + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_BRIDGE && + ctx->family != NFPROTO_ARP) + return -EOPNOTSUPP; + + if (nft_is_base_chain(ctx->chain)) { + const struct nft_base_chain *basechain = + nft_base_chain(ctx->chain); + const struct nf_hook_ops *ops = &basechain->ops; + + hook_mask = 1 << ops->hooknum; + if (match->hooks && !(hook_mask & match->hooks)) + return -EINVAL; + + ret = nft_compat_chain_validate_dependency(ctx, match->table); + if (ret < 0) + return ret; + } + return 0; +} + +static int +nfnl_compat_fill_info(struct sk_buff *skb, u32 portid, u32 seq, u32 type, + int event, u16 family, const char *name, + int rev, int target) +{ + struct nlmsghdr *nlh; + unsigned int flags = portid ? NLM_F_MULTI : 0; + + event = nfnl_msg_type(NFNL_SUBSYS_NFT_COMPAT, event); + nlh = nfnl_msg_put(skb, portid, seq, event, flags, family, + NFNETLINK_V0, 0); + if (!nlh) + goto nlmsg_failure; + + if (nla_put_string(skb, NFTA_COMPAT_NAME, name) || + nla_put_be32(skb, NFTA_COMPAT_REV, htonl(rev)) || + nla_put_be32(skb, NFTA_COMPAT_TYPE, htonl(target))) + goto nla_put_failure; + + nlmsg_end(skb, nlh); + return skb->len; + +nlmsg_failure: +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -1; +} + +static int nfnl_compat_get_rcu(struct sk_buff *skb, + const struct nfnl_info *info, + const struct nlattr * const tb[]) +{ + u8 family = info->nfmsg->nfgen_family; + const char *name, *fmt; + struct sk_buff *skb2; + int ret = 0, target; + u32 rev; + + if (tb[NFTA_COMPAT_NAME] == NULL || + tb[NFTA_COMPAT_REV] == NULL || + tb[NFTA_COMPAT_TYPE] == NULL) + return -EINVAL; + + name = nla_data(tb[NFTA_COMPAT_NAME]); + rev = ntohl(nla_get_be32(tb[NFTA_COMPAT_REV])); + target = ntohl(nla_get_be32(tb[NFTA_COMPAT_TYPE])); + + switch(family) { + case AF_INET: + fmt = "ipt_%s"; + break; + case AF_INET6: + fmt = "ip6t_%s"; + break; + case NFPROTO_BRIDGE: + fmt = "ebt_%s"; + break; + case NFPROTO_ARP: + fmt = "arpt_%s"; + break; + default: + pr_err("nft_compat: unsupported protocol %d\n", family); + return -EINVAL; + } + + if (!try_module_get(THIS_MODULE)) + return -EINVAL; + + rcu_read_unlock(); + try_then_request_module(xt_find_revision(family, name, rev, target, &ret), + fmt, name); + if (ret < 0) + goto out_put; + + skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb2 == NULL) { + ret = -ENOMEM; + goto out_put; + } + + /* include the best revision for this extension in the message */ + if (nfnl_compat_fill_info(skb2, NETLINK_CB(skb).portid, + info->nlh->nlmsg_seq, + NFNL_MSG_TYPE(info->nlh->nlmsg_type), + NFNL_MSG_COMPAT_GET, + family, name, ret, target) <= 0) { + kfree_skb(skb2); + goto out_put; + } + + ret = nfnetlink_unicast(skb2, info->net, NETLINK_CB(skb).portid); +out_put: + rcu_read_lock(); + module_put(THIS_MODULE); + + return ret; +} + +static const struct nla_policy nfnl_compat_policy_get[NFTA_COMPAT_MAX+1] = { + [NFTA_COMPAT_NAME] = { .type = NLA_NUL_STRING, + .len = NFT_COMPAT_NAME_MAX-1 }, + [NFTA_COMPAT_REV] = { .type = NLA_U32 }, + [NFTA_COMPAT_TYPE] = { .type = NLA_U32 }, +}; + +static const struct nfnl_callback nfnl_nft_compat_cb[NFNL_MSG_COMPAT_MAX] = { + [NFNL_MSG_COMPAT_GET] = { + .call = nfnl_compat_get_rcu, + .type = NFNL_CB_RCU, + .attr_count = NFTA_COMPAT_MAX, + .policy = nfnl_compat_policy_get + }, +}; + +static 
const struct nfnetlink_subsystem nfnl_compat_subsys = { + .name = "nft-compat", + .subsys_id = NFNL_SUBSYS_NFT_COMPAT, + .cb_count = NFNL_MSG_COMPAT_MAX, + .cb = nfnl_nft_compat_cb, +}; + +static struct nft_expr_type nft_match_type; + +static bool nft_match_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct xt_match *match = expr->ops->data; + + return strcmp(match->name, "comment") == 0; +} + +static const struct nft_expr_ops * +nft_match_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + struct nft_expr_ops *ops; + struct xt_match *match; + unsigned int matchsize; + char *mt_name; + u32 rev, family; + int err; + + if (tb[NFTA_MATCH_NAME] == NULL || + tb[NFTA_MATCH_REV] == NULL || + tb[NFTA_MATCH_INFO] == NULL) + return ERR_PTR(-EINVAL); + + mt_name = nla_data(tb[NFTA_MATCH_NAME]); + rev = ntohl(nla_get_be32(tb[NFTA_MATCH_REV])); + family = ctx->family; + + match = xt_request_find_match(family, mt_name, rev); + if (IS_ERR(match)) + return ERR_PTR(-ENOENT); + + if (match->matchsize > nla_len(tb[NFTA_MATCH_INFO])) { + err = -EINVAL; + goto err; + } + + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + if (!ops) { + err = -ENOMEM; + goto err; + } + + ops->type = &nft_match_type; + ops->eval = nft_match_eval; + ops->init = nft_match_init; + ops->destroy = nft_match_destroy; + ops->dump = nft_match_dump; + ops->validate = nft_match_validate; + ops->data = match; + ops->reduce = nft_match_reduce; + + matchsize = NFT_EXPR_SIZE(XT_ALIGN(match->matchsize)); + if (matchsize > NFT_MATCH_LARGE_THRESH) { + matchsize = NFT_EXPR_SIZE(sizeof(struct nft_xt_match_priv)); + + ops->eval = nft_match_large_eval; + ops->init = nft_match_large_init; + ops->destroy = nft_match_large_destroy; + ops->dump = nft_match_large_dump; + } + + ops->size = matchsize; + + return ops; +err: + module_put(match->me); + return ERR_PTR(err); +} + +static void nft_match_release_ops(const struct nft_expr_ops *ops) +{ + struct xt_match *match = ops->data; + + module_put(match->me); + kfree(ops); +} + +static struct nft_expr_type nft_match_type __read_mostly = { + .name = "match", + .select_ops = nft_match_select_ops, + .release_ops = nft_match_release_ops, + .policy = nft_match_policy, + .maxattr = NFTA_MATCH_MAX, + .owner = THIS_MODULE, +}; + +static struct nft_expr_type nft_target_type; + +static const struct nft_expr_ops * +nft_target_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + struct nft_expr_ops *ops; + struct xt_target *target; + char *tg_name; + u32 rev, family; + int err; + + if (tb[NFTA_TARGET_NAME] == NULL || + tb[NFTA_TARGET_REV] == NULL || + tb[NFTA_TARGET_INFO] == NULL) + return ERR_PTR(-EINVAL); + + tg_name = nla_data(tb[NFTA_TARGET_NAME]); + rev = ntohl(nla_get_be32(tb[NFTA_TARGET_REV])); + family = ctx->family; + + if (strcmp(tg_name, XT_ERROR_TARGET) == 0 || + strcmp(tg_name, XT_STANDARD_TARGET) == 0 || + strcmp(tg_name, "standard") == 0) + return ERR_PTR(-EINVAL); + + target = xt_request_find_target(family, tg_name, rev); + if (IS_ERR(target)) + return ERR_PTR(-ENOENT); + + if (!target->target) { + err = -EINVAL; + goto err; + } + + if (target->targetsize > nla_len(tb[NFTA_TARGET_INFO])) { + err = -EINVAL; + goto err; + } + + ops = kzalloc(sizeof(struct nft_expr_ops), GFP_KERNEL); + if (!ops) { + err = -ENOMEM; + goto err; + } + + ops->type = &nft_target_type; + ops->size = NFT_EXPR_SIZE(XT_ALIGN(target->targetsize)); + ops->init = nft_target_init; + ops->destroy = nft_target_destroy; + ops->dump = 
nft_target_dump; + ops->validate = nft_target_validate; + ops->data = target; + ops->reduce = NFT_REDUCE_READONLY; + + if (family == NFPROTO_BRIDGE) + ops->eval = nft_target_eval_bridge; + else + ops->eval = nft_target_eval_xt; + + return ops; +err: + module_put(target->me); + return ERR_PTR(err); +} + +static void nft_target_release_ops(const struct nft_expr_ops *ops) +{ + struct xt_target *target = ops->data; + + module_put(target->me); + kfree(ops); +} + +static struct nft_expr_type nft_target_type __read_mostly = { + .name = "target", + .select_ops = nft_target_select_ops, + .release_ops = nft_target_release_ops, + .policy = nft_target_policy, + .maxattr = NFTA_TARGET_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_compat_module_init(void) +{ + int ret; + + ret = nft_register_expr(&nft_match_type); + if (ret < 0) + return ret; + + ret = nft_register_expr(&nft_target_type); + if (ret < 0) + goto err_match; + + ret = nfnetlink_subsys_register(&nfnl_compat_subsys); + if (ret < 0) { + pr_err("nft_compat: cannot register with nfnetlink.\n"); + goto err_target; + } + + return ret; +err_target: + nft_unregister_expr(&nft_target_type); +err_match: + nft_unregister_expr(&nft_match_type); + return ret; +} + +static void __exit nft_compat_module_exit(void) +{ + nfnetlink_subsys_unregister(&nfnl_compat_subsys); + nft_unregister_expr(&nft_target_type); + nft_unregister_expr(&nft_match_type); +} + +MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_NFT_COMPAT); + +module_init(nft_compat_module_init); +module_exit(nft_compat_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_EXPR("match"); +MODULE_ALIAS_NFT_EXPR("target"); +MODULE_DESCRIPTION("x_tables over nftables support"); diff --git a/net/netfilter/nft_connlimit.c b/net/netfilter/nft_connlimit.c new file mode 100644 index 000000000..d657f999a --- /dev/null +++ b/net/netfilter/nft_connlimit.c @@ -0,0 +1,303 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_connlimit { + struct nf_conncount_list *list; + u32 limit; + bool invert; +}; + +static inline void nft_connlimit_do_eval(struct nft_connlimit *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt, + const struct nft_set_ext *ext) +{ + const struct nf_conntrack_zone *zone = &nf_ct_zone_dflt; + const struct nf_conntrack_tuple *tuple_ptr; + struct nf_conntrack_tuple tuple; + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct; + unsigned int count; + + tuple_ptr = &tuple; + + ct = nf_ct_get(pkt->skb, &ctinfo); + if (ct != NULL) { + tuple_ptr = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + zone = nf_ct_zone(ct); + } else if (!nf_ct_get_tuplepr(pkt->skb, skb_network_offset(pkt->skb), + nft_pf(pkt), nft_net(pkt), &tuple)) { + regs->verdict.code = NF_DROP; + return; + } + + if (nf_conncount_add(nft_net(pkt), priv->list, tuple_ptr, zone)) { + regs->verdict.code = NF_DROP; + return; + } + + count = priv->list->count; + + if ((count > priv->limit) ^ priv->invert) { + regs->verdict.code = NFT_BREAK; + return; + } +} + +static int nft_connlimit_do_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_connlimit *priv) +{ + bool invert = false; + u32 flags, limit; + int err; + + if (!tb[NFTA_CONNLIMIT_COUNT]) + return -EINVAL; + + limit = ntohl(nla_get_be32(tb[NFTA_CONNLIMIT_COUNT])); + + if (tb[NFTA_CONNLIMIT_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFTA_CONNLIMIT_FLAGS])); + if (flags & 
~NFT_CONNLIMIT_F_INV) + return -EOPNOTSUPP; + if (flags & NFT_CONNLIMIT_F_INV) + invert = true; + } + + priv->list = kmalloc(sizeof(*priv->list), GFP_KERNEL_ACCOUNT); + if (!priv->list) + return -ENOMEM; + + nf_conncount_list_init(priv->list); + priv->limit = limit; + priv->invert = invert; + + err = nf_ct_netns_get(ctx->net, ctx->family); + if (err < 0) + goto err_netns; + + return 0; +err_netns: + kfree(priv->list); + + return err; +} + +static void nft_connlimit_do_destroy(const struct nft_ctx *ctx, + struct nft_connlimit *priv) +{ + nf_ct_netns_put(ctx->net, ctx->family); + nf_conncount_cache_free(priv->list); + kfree(priv->list); +} + +static int nft_connlimit_do_dump(struct sk_buff *skb, + struct nft_connlimit *priv) +{ + if (nla_put_be32(skb, NFTA_CONNLIMIT_COUNT, htonl(priv->limit))) + goto nla_put_failure; + if (priv->invert && + nla_put_be32(skb, NFTA_CONNLIMIT_FLAGS, htonl(NFT_CONNLIMIT_F_INV))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static inline void nft_connlimit_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_connlimit *priv = nft_obj_data(obj); + + nft_connlimit_do_eval(priv, regs, pkt, NULL); +} + +static int nft_connlimit_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_connlimit *priv = nft_obj_data(obj); + + return nft_connlimit_do_init(ctx, tb, priv); +} + +static void nft_connlimit_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_connlimit *priv = nft_obj_data(obj); + + nft_connlimit_do_destroy(ctx, priv); +} + +static int nft_connlimit_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + struct nft_connlimit *priv = nft_obj_data(obj); + + return nft_connlimit_do_dump(skb, priv); +} + +static const struct nla_policy nft_connlimit_policy[NFTA_CONNLIMIT_MAX + 1] = { + [NFTA_CONNLIMIT_COUNT] = { .type = NLA_U32 }, + [NFTA_CONNLIMIT_FLAGS] = { .type = NLA_U32 }, +}; + +static struct nft_object_type nft_connlimit_obj_type; +static const struct nft_object_ops nft_connlimit_obj_ops = { + .type = &nft_connlimit_obj_type, + .size = sizeof(struct nft_connlimit), + .eval = nft_connlimit_obj_eval, + .init = nft_connlimit_obj_init, + .destroy = nft_connlimit_obj_destroy, + .dump = nft_connlimit_obj_dump, +}; + +static struct nft_object_type nft_connlimit_obj_type __read_mostly = { + .type = NFT_OBJECT_CONNLIMIT, + .ops = &nft_connlimit_obj_ops, + .maxattr = NFTA_CONNLIMIT_MAX, + .policy = nft_connlimit_policy, + .owner = THIS_MODULE, +}; + +static void nft_connlimit_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + + nft_connlimit_do_eval(priv, regs, pkt, NULL); +} + +static int nft_connlimit_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + + return nft_connlimit_do_dump(skb, priv); +} + +static int nft_connlimit_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + + return nft_connlimit_do_init(ctx, tb, priv); +} + +static void nft_connlimit_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + + nft_connlimit_do_destroy(ctx, priv); +} + +static int nft_connlimit_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct 
nft_connlimit *priv_dst = nft_expr_priv(dst); + struct nft_connlimit *priv_src = nft_expr_priv(src); + + priv_dst->list = kmalloc(sizeof(*priv_dst->list), GFP_ATOMIC); + if (!priv_dst->list) + return -ENOMEM; + + nf_conncount_list_init(priv_dst->list); + priv_dst->limit = priv_src->limit; + priv_dst->invert = priv_src->invert; + + return 0; +} + +static void nft_connlimit_destroy_clone(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + + nf_conncount_cache_free(priv->list); + kfree(priv->list); +} + +static bool nft_connlimit_gc(struct net *net, const struct nft_expr *expr) +{ + struct nft_connlimit *priv = nft_expr_priv(expr); + bool ret; + + local_bh_disable(); + ret = nf_conncount_gc_list(net, priv->list); + local_bh_enable(); + + return ret; +} + +static struct nft_expr_type nft_connlimit_type; +static const struct nft_expr_ops nft_connlimit_ops = { + .type = &nft_connlimit_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_connlimit)), + .eval = nft_connlimit_eval, + .init = nft_connlimit_init, + .destroy = nft_connlimit_destroy, + .clone = nft_connlimit_clone, + .destroy_clone = nft_connlimit_destroy_clone, + .dump = nft_connlimit_dump, + .gc = nft_connlimit_gc, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_connlimit_type __read_mostly = { + .name = "connlimit", + .ops = &nft_connlimit_ops, + .policy = nft_connlimit_policy, + .maxattr = NFTA_CONNLIMIT_MAX, + .flags = NFT_EXPR_STATEFUL | NFT_EXPR_GC, + .owner = THIS_MODULE, +}; + +static int __init nft_connlimit_module_init(void) +{ + int err; + + err = nft_register_obj(&nft_connlimit_obj_type); + if (err < 0) + return err; + + err = nft_register_expr(&nft_connlimit_type); + if (err < 0) + goto err1; + + return 0; +err1: + nft_unregister_obj(&nft_connlimit_obj_type); + return err; +} + +static void __exit nft_connlimit_module_exit(void) +{ + nft_unregister_expr(&nft_connlimit_type); + nft_unregister_obj(&nft_connlimit_obj_type); +} + +module_init(nft_connlimit_module_init); +module_exit(nft_connlimit_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso"); +MODULE_ALIAS_NFT_EXPR("connlimit"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CONNLIMIT); +MODULE_DESCRIPTION("nftables connlimit rule support"); diff --git a/net/netfilter/nft_counter.c b/net/netfilter/nft_counter.c new file mode 100644 index 000000000..f4d3573e8 --- /dev/null +++ b/net/netfilter/nft_counter.c @@ -0,0 +1,308 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_counter { + s64 bytes; + s64 packets; +}; + +struct nft_counter_percpu_priv { + struct nft_counter __percpu *counter; +}; + +static DEFINE_PER_CPU(seqcount_t, nft_counter_seq); + +static inline void nft_counter_do_eval(struct nft_counter_percpu_priv *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_counter *this_cpu; + seqcount_t *myseq; + + local_bh_disable(); + this_cpu = this_cpu_ptr(priv->counter); + myseq = this_cpu_ptr(&nft_counter_seq); + + write_seqcount_begin(myseq); + + this_cpu->bytes += pkt->skb->len; + this_cpu->packets++; + + write_seqcount_end(myseq); + local_bh_enable(); +} + +static inline void nft_counter_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_counter_percpu_priv 
*priv = nft_obj_data(obj); + + nft_counter_do_eval(priv, regs, pkt); +} + +static int nft_counter_do_init(const struct nlattr * const tb[], + struct nft_counter_percpu_priv *priv) +{ + struct nft_counter __percpu *cpu_stats; + struct nft_counter *this_cpu; + + cpu_stats = alloc_percpu_gfp(struct nft_counter, GFP_KERNEL_ACCOUNT); + if (cpu_stats == NULL) + return -ENOMEM; + + preempt_disable(); + this_cpu = this_cpu_ptr(cpu_stats); + if (tb[NFTA_COUNTER_PACKETS]) { + this_cpu->packets = + be64_to_cpu(nla_get_be64(tb[NFTA_COUNTER_PACKETS])); + } + if (tb[NFTA_COUNTER_BYTES]) { + this_cpu->bytes = + be64_to_cpu(nla_get_be64(tb[NFTA_COUNTER_BYTES])); + } + preempt_enable(); + priv->counter = cpu_stats; + return 0; +} + +static int nft_counter_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_counter_percpu_priv *priv = nft_obj_data(obj); + + return nft_counter_do_init(tb, priv); +} + +static void nft_counter_do_destroy(struct nft_counter_percpu_priv *priv) +{ + free_percpu(priv->counter); +} + +static void nft_counter_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_counter_percpu_priv *priv = nft_obj_data(obj); + + nft_counter_do_destroy(priv); +} + +static void nft_counter_reset(struct nft_counter_percpu_priv *priv, + struct nft_counter *total) +{ + struct nft_counter *this_cpu; + + local_bh_disable(); + this_cpu = this_cpu_ptr(priv->counter); + this_cpu->packets -= total->packets; + this_cpu->bytes -= total->bytes; + local_bh_enable(); +} + +static void nft_counter_fetch(struct nft_counter_percpu_priv *priv, + struct nft_counter *total) +{ + struct nft_counter *this_cpu; + const seqcount_t *myseq; + u64 bytes, packets; + unsigned int seq; + int cpu; + + memset(total, 0, sizeof(*total)); + for_each_possible_cpu(cpu) { + myseq = per_cpu_ptr(&nft_counter_seq, cpu); + this_cpu = per_cpu_ptr(priv->counter, cpu); + do { + seq = read_seqcount_begin(myseq); + bytes = this_cpu->bytes; + packets = this_cpu->packets; + } while (read_seqcount_retry(myseq, seq)); + + total->bytes += bytes; + total->packets += packets; + } +} + +static int nft_counter_do_dump(struct sk_buff *skb, + struct nft_counter_percpu_priv *priv, + bool reset) +{ + struct nft_counter total; + + nft_counter_fetch(priv, &total); + + if (nla_put_be64(skb, NFTA_COUNTER_BYTES, cpu_to_be64(total.bytes), + NFTA_COUNTER_PAD) || + nla_put_be64(skb, NFTA_COUNTER_PACKETS, cpu_to_be64(total.packets), + NFTA_COUNTER_PAD)) + goto nla_put_failure; + + if (reset) + nft_counter_reset(priv, &total); + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_counter_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + struct nft_counter_percpu_priv *priv = nft_obj_data(obj); + + return nft_counter_do_dump(skb, priv, reset); +} + +static const struct nla_policy nft_counter_policy[NFTA_COUNTER_MAX + 1] = { + [NFTA_COUNTER_PACKETS] = { .type = NLA_U64 }, + [NFTA_COUNTER_BYTES] = { .type = NLA_U64 }, +}; + +struct nft_object_type nft_counter_obj_type; +static const struct nft_object_ops nft_counter_obj_ops = { + .type = &nft_counter_obj_type, + .size = sizeof(struct nft_counter_percpu_priv), + .eval = nft_counter_obj_eval, + .init = nft_counter_obj_init, + .destroy = nft_counter_obj_destroy, + .dump = nft_counter_obj_dump, +}; + +struct nft_object_type nft_counter_obj_type __read_mostly = { + .type = NFT_OBJECT_COUNTER, + .ops = &nft_counter_obj_ops, + .maxattr = NFTA_COUNTER_MAX, + .policy = nft_counter_policy, + .owner = 
THIS_MODULE, +}; + +void nft_counter_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); + + nft_counter_do_eval(priv, regs, pkt); +} + +static int nft_counter_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); + + return nft_counter_do_dump(skb, priv, false); +} + +static int nft_counter_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); + + return nft_counter_do_init(tb, priv); +} + +static void nft_counter_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); + + nft_counter_do_destroy(priv); +} + +static int nft_counter_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(src); + struct nft_counter_percpu_priv *priv_clone = nft_expr_priv(dst); + struct nft_counter __percpu *cpu_stats; + struct nft_counter *this_cpu; + struct nft_counter total; + + nft_counter_fetch(priv, &total); + + cpu_stats = alloc_percpu_gfp(struct nft_counter, GFP_ATOMIC); + if (cpu_stats == NULL) + return -ENOMEM; + + preempt_disable(); + this_cpu = this_cpu_ptr(cpu_stats); + this_cpu->packets = total.packets; + this_cpu->bytes = total.bytes; + preempt_enable(); + + priv_clone->counter = cpu_stats; + return 0; +} + +static int nft_counter_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + /* No specific offload action is needed, but report success. */ + return 0; +} + +static void nft_counter_offload_stats(struct nft_expr *expr, + const struct flow_stats *stats) +{ + struct nft_counter_percpu_priv *priv = nft_expr_priv(expr); + struct nft_counter *this_cpu; + seqcount_t *myseq; + + preempt_disable(); + this_cpu = this_cpu_ptr(priv->counter); + myseq = this_cpu_ptr(&nft_counter_seq); + + write_seqcount_begin(myseq); + this_cpu->packets += stats->pkts; + this_cpu->bytes += stats->bytes; + write_seqcount_end(myseq); + preempt_enable(); +} + +void nft_counter_init_seqcount(void) +{ + int cpu; + + for_each_possible_cpu(cpu) + seqcount_init(per_cpu_ptr(&nft_counter_seq, cpu)); +} + +struct nft_expr_type nft_counter_type; +static const struct nft_expr_ops nft_counter_ops = { + .type = &nft_counter_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_counter_percpu_priv)), + .eval = nft_counter_eval, + .init = nft_counter_init, + .destroy = nft_counter_destroy, + .destroy_clone = nft_counter_destroy, + .dump = nft_counter_dump, + .clone = nft_counter_clone, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_counter_offload, + .offload_stats = nft_counter_offload_stats, +}; + +struct nft_expr_type nft_counter_type __read_mostly = { + .name = "counter", + .ops = &nft_counter_ops, + .policy = nft_counter_policy, + .maxattr = NFTA_COUNTER_MAX, + .flags = NFT_EXPR_STATEFUL, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_ct.c b/net/netfilter/nft_ct.c new file mode 100644 index 000000000..641dc21f9 --- /dev/null +++ b/net/netfilter/nft_ct.c @@ -0,0 +1,1401 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2016 Pablo Neira Ayuso + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_ct { + enum nft_ct_keys key:8; + enum ip_conntrack_dir dir:8; + u8 len; + union { + u8 dreg; + u8 sreg; + }; +}; + +struct nft_ct_helper_obj { + struct nf_conntrack_helper *helper4; + struct nf_conntrack_helper *helper6; + u8 l4proto; +}; + +#ifdef CONFIG_NF_CONNTRACK_ZONES +static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template); +static unsigned int nft_ct_pcpu_template_refcnt __read_mostly; +static DEFINE_MUTEX(nft_ct_pcpu_mutex); +#endif + +static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c, + enum nft_ct_keys k, + enum ip_conntrack_dir d) +{ + if (d < IP_CT_DIR_MAX) + return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) : + atomic64_read(&c[d].packets); + + return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) + + nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY); +} + +static void nft_ct_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_ct *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct; + const struct nf_conn_help *help; + const struct nf_conntrack_tuple *tuple; + const struct nf_conntrack_helper *helper; + unsigned int state; + + ct = nf_ct_get(pkt->skb, &ctinfo); + + switch (priv->key) { + case NFT_CT_STATE: + if (ct) + state = NF_CT_STATE_BIT(ctinfo); + else if (ctinfo == IP_CT_UNTRACKED) + state = NF_CT_STATE_UNTRACKED_BIT; + else + state = NF_CT_STATE_INVALID_BIT; + *dest = state; + return; + default: + break; + } + + if (ct == NULL) + goto err; + + switch (priv->key) { + case NFT_CT_DIRECTION: + nft_reg_store8(dest, CTINFO2DIR(ctinfo)); + return; + case NFT_CT_STATUS: + *dest = ct->status; + return; +#ifdef CONFIG_NF_CONNTRACK_MARK + case NFT_CT_MARK: + *dest = READ_ONCE(ct->mark); + return; +#endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: + *dest = ct->secmark; + return; +#endif + case NFT_CT_EXPIRATION: + *dest = jiffies_to_msecs(nf_ct_expires(ct)); + return; + case NFT_CT_HELPER: + if (ct->master == NULL) + goto err; + help = nfct_help(ct->master); + if (help == NULL) + goto err; + helper = rcu_dereference(help->helper); + if (helper == NULL) + goto err; + strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN); + return; +#ifdef CONFIG_NF_CONNTRACK_LABELS + case NFT_CT_LABELS: { + struct nf_conn_labels *labels = nf_ct_labels_find(ct); + + if (labels) + memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE); + else + memset(dest, 0, NF_CT_LABELS_MAX_SIZE); + return; + } +#endif + case NFT_CT_BYTES: + case NFT_CT_PKTS: { + const struct nf_conn_acct *acct = nf_conn_acct_find(ct); + u64 count = 0; + + if (acct) + count = nft_ct_get_eval_counter(acct->counter, + priv->key, priv->dir); + memcpy(dest, &count, sizeof(count)); + return; + } + case NFT_CT_AVGPKT: { + const struct nf_conn_acct *acct = nf_conn_acct_find(ct); + u64 avgcnt = 0, bcnt = 0, pcnt = 0; + + if (acct) { + pcnt = nft_ct_get_eval_counter(acct->counter, + NFT_CT_PKTS, priv->dir); + bcnt = nft_ct_get_eval_counter(acct->counter, + NFT_CT_BYTES, priv->dir); + if (pcnt != 0) + avgcnt = div64_u64(bcnt, pcnt); + } + + memcpy(dest, &avgcnt, sizeof(avgcnt)); + return; + } + case NFT_CT_L3PROTOCOL: + nft_reg_store8(dest, nf_ct_l3num(ct)); + return; + case NFT_CT_PROTOCOL: + nft_reg_store8(dest, nf_ct_protonum(ct)); + return; +#ifdef CONFIG_NF_CONNTRACK_ZONES + case NFT_CT_ZONE: { + const struct nf_conntrack_zone *zone = nf_ct_zone(ct); + u16 zoneid; 
+ + if (priv->dir < IP_CT_DIR_MAX) + zoneid = nf_ct_zone_id(zone, priv->dir); + else + zoneid = zone->id; + + nft_reg_store16(dest, zoneid); + return; + } +#endif + case NFT_CT_ID: + *dest = nf_ct_get_id(ct); + return; + default: + break; + } + + tuple = &ct->tuplehash[priv->dir].tuple; + switch (priv->key) { + case NFT_CT_SRC: + memcpy(dest, tuple->src.u3.all, + nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16); + return; + case NFT_CT_DST: + memcpy(dest, tuple->dst.u3.all, + nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16); + return; + case NFT_CT_PROTO_SRC: + nft_reg_store16(dest, (__force u16)tuple->src.u.all); + return; + case NFT_CT_PROTO_DST: + nft_reg_store16(dest, (__force u16)tuple->dst.u.all); + return; + case NFT_CT_SRC_IP: + if (nf_ct_l3num(ct) != NFPROTO_IPV4) + goto err; + *dest = (__force __u32)tuple->src.u3.ip; + return; + case NFT_CT_DST_IP: + if (nf_ct_l3num(ct) != NFPROTO_IPV4) + goto err; + *dest = (__force __u32)tuple->dst.u3.ip; + return; + case NFT_CT_SRC_IP6: + if (nf_ct_l3num(ct) != NFPROTO_IPV6) + goto err; + memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr)); + return; + case NFT_CT_DST_IP6: + if (nf_ct_l3num(ct) != NFPROTO_IPV6) + goto err; + memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr)); + return; + default: + break; + } + return; +err: + regs->verdict.code = NFT_BREAK; +} + +#ifdef CONFIG_NF_CONNTRACK_ZONES +static void nft_ct_set_zone_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR }; + const struct nft_ct *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + enum ip_conntrack_info ctinfo; + u16 value = nft_reg_load16(®s->data[priv->sreg]); + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) /* already tracked */ + return; + + zone.id = value; + + switch (priv->dir) { + case IP_CT_DIR_ORIGINAL: + zone.dir = NF_CT_ZONE_DIR_ORIG; + break; + case IP_CT_DIR_REPLY: + zone.dir = NF_CT_ZONE_DIR_REPL; + break; + default: + break; + } + + ct = this_cpu_read(nft_ct_pcpu_template); + + if (likely(refcount_read(&ct->ct_general.use) == 1)) { + refcount_inc(&ct->ct_general.use); + nf_ct_zone_add(ct, &zone); + } else { + /* previous skb got queued to userspace, allocate temporary + * one until percpu template can be reused. 
+ */ + ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC); + if (!ct) { + regs->verdict.code = NF_DROP; + return; + } + } + + nf_ct_set(skb, ct, IP_CT_NEW); +} +#endif + +static void nft_ct_set_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_ct *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; +#if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK) + u32 value = regs->data[priv->sreg]; +#endif + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL || nf_ct_is_template(ct)) + return; + + switch (priv->key) { +#ifdef CONFIG_NF_CONNTRACK_MARK + case NFT_CT_MARK: + if (READ_ONCE(ct->mark) != value) { + WRITE_ONCE(ct->mark, value); + nf_conntrack_event_cache(IPCT_MARK, ct); + } + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: + if (ct->secmark != value) { + ct->secmark = value; + nf_conntrack_event_cache(IPCT_SECMARK, ct); + } + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + case NFT_CT_LABELS: + nf_connlabels_replace(ct, + ®s->data[priv->sreg], + ®s->data[priv->sreg], + NF_CT_LABELS_MAX_SIZE / sizeof(u32)); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_EVENTS + case NFT_CT_EVENTMASK: { + struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct); + u32 ctmask = regs->data[priv->sreg]; + + if (e) { + if (e->ctmask != ctmask) + e->ctmask = ctmask; + break; + } + + if (ctmask && !nf_ct_is_confirmed(ct)) + nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC); + break; + } +#endif + default: + break; + } +} + +static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = { + [NFTA_CT_DREG] = { .type = NLA_U32 }, + [NFTA_CT_KEY] = { .type = NLA_U32 }, + [NFTA_CT_DIRECTION] = { .type = NLA_U8 }, + [NFTA_CT_SREG] = { .type = NLA_U32 }, +}; + +#ifdef CONFIG_NF_CONNTRACK_ZONES +static void nft_ct_tmpl_put_pcpu(void) +{ + struct nf_conn *ct; + int cpu; + + for_each_possible_cpu(cpu) { + ct = per_cpu(nft_ct_pcpu_template, cpu); + if (!ct) + break; + nf_ct_put(ct); + per_cpu(nft_ct_pcpu_template, cpu) = NULL; + } +} + +static bool nft_ct_tmpl_alloc_pcpu(void) +{ + struct nf_conntrack_zone zone = { .id = 0 }; + struct nf_conn *tmp; + int cpu; + + if (nft_ct_pcpu_template_refcnt) + return true; + + for_each_possible_cpu(cpu) { + tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL); + if (!tmp) { + nft_ct_tmpl_put_pcpu(); + return false; + } + + per_cpu(nft_ct_pcpu_template, cpu) = tmp; + } + + return true; +} +#endif + +static int nft_ct_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_ct *priv = nft_expr_priv(expr); + unsigned int len; + int err; + + priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY])); + priv->dir = IP_CT_DIR_MAX; + switch (priv->key) { + case NFT_CT_DIRECTION: + if (tb[NFTA_CT_DIRECTION] != NULL) + return -EINVAL; + len = sizeof(u8); + break; + case NFT_CT_STATE: + case NFT_CT_STATUS: +#ifdef CONFIG_NF_CONNTRACK_MARK + case NFT_CT_MARK: +#endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: +#endif + case NFT_CT_EXPIRATION: + if (tb[NFTA_CT_DIRECTION] != NULL) + return -EINVAL; + len = sizeof(u32); + break; +#ifdef CONFIG_NF_CONNTRACK_LABELS + case NFT_CT_LABELS: + if (tb[NFTA_CT_DIRECTION] != NULL) + return -EINVAL; + len = NF_CT_LABELS_MAX_SIZE; + break; +#endif + case NFT_CT_HELPER: + if (tb[NFTA_CT_DIRECTION] != NULL) + return -EINVAL; + len = NF_CT_HELPER_NAME_LEN; + break; + + case NFT_CT_L3PROTOCOL: + case NFT_CT_PROTOCOL: + /* 
For compatibility, do not report error if NFTA_CT_DIRECTION + * attribute is specified. + */ + len = sizeof(u8); + break; + case NFT_CT_SRC: + case NFT_CT_DST: + if (tb[NFTA_CT_DIRECTION] == NULL) + return -EINVAL; + + switch (ctx->family) { + case NFPROTO_IPV4: + len = sizeof_field(struct nf_conntrack_tuple, + src.u3.ip); + break; + case NFPROTO_IPV6: + case NFPROTO_INET: + len = sizeof_field(struct nf_conntrack_tuple, + src.u3.ip6); + break; + default: + return -EAFNOSUPPORT; + } + break; + case NFT_CT_SRC_IP: + case NFT_CT_DST_IP: + if (tb[NFTA_CT_DIRECTION] == NULL) + return -EINVAL; + + len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip); + break; + case NFT_CT_SRC_IP6: + case NFT_CT_DST_IP6: + if (tb[NFTA_CT_DIRECTION] == NULL) + return -EINVAL; + + len = sizeof_field(struct nf_conntrack_tuple, src.u3.ip6); + break; + case NFT_CT_PROTO_SRC: + case NFT_CT_PROTO_DST: + if (tb[NFTA_CT_DIRECTION] == NULL) + return -EINVAL; + len = sizeof_field(struct nf_conntrack_tuple, src.u.all); + break; + case NFT_CT_BYTES: + case NFT_CT_PKTS: + case NFT_CT_AVGPKT: + len = sizeof(u64); + break; +#ifdef CONFIG_NF_CONNTRACK_ZONES + case NFT_CT_ZONE: + len = sizeof(u16); + break; +#endif + case NFT_CT_ID: + len = sizeof(u32); + break; + default: + return -EOPNOTSUPP; + } + + if (tb[NFTA_CT_DIRECTION] != NULL) { + priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]); + switch (priv->dir) { + case IP_CT_DIR_ORIGINAL: + case IP_CT_DIR_REPLY: + break; + default: + return -EINVAL; + } + } + + priv->len = len; + err = nft_parse_register_store(ctx, tb[NFTA_CT_DREG], &priv->dreg, NULL, + NFT_DATA_VALUE, len); + if (err < 0) + return err; + + err = nf_ct_netns_get(ctx->net, ctx->family); + if (err < 0) + return err; + + if (priv->key == NFT_CT_BYTES || + priv->key == NFT_CT_PKTS || + priv->key == NFT_CT_AVGPKT) + nf_ct_set_acct(ctx->net, true); + + return 0; +} + +static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv) +{ + switch (priv->key) { +#ifdef CONFIG_NF_CONNTRACK_LABELS + case NFT_CT_LABELS: + nf_connlabels_put(ctx->net); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_ZONES + case NFT_CT_ZONE: + mutex_lock(&nft_ct_pcpu_mutex); + if (--nft_ct_pcpu_template_refcnt == 0) + nft_ct_tmpl_put_pcpu(); + mutex_unlock(&nft_ct_pcpu_mutex); + break; +#endif + default: + break; + } +} + +static int nft_ct_set_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_ct *priv = nft_expr_priv(expr); + unsigned int len; + int err; + + priv->dir = IP_CT_DIR_MAX; + priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY])); + switch (priv->key) { +#ifdef CONFIG_NF_CONNTRACK_MARK + case NFT_CT_MARK: + if (tb[NFTA_CT_DIRECTION]) + return -EINVAL; + len = sizeof_field(struct nf_conn, mark); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + case NFT_CT_LABELS: + if (tb[NFTA_CT_DIRECTION]) + return -EINVAL; + len = NF_CT_LABELS_MAX_SIZE; + err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1); + if (err) + return err; + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_ZONES + case NFT_CT_ZONE: + mutex_lock(&nft_ct_pcpu_mutex); + if (!nft_ct_tmpl_alloc_pcpu()) { + mutex_unlock(&nft_ct_pcpu_mutex); + return -ENOMEM; + } + nft_ct_pcpu_template_refcnt++; + mutex_unlock(&nft_ct_pcpu_mutex); + len = sizeof(u16); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_EVENTS + case NFT_CT_EVENTMASK: + if (tb[NFTA_CT_DIRECTION]) + return -EINVAL; + len = sizeof(u32); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_SECMARK + case NFT_CT_SECMARK: + if (tb[NFTA_CT_DIRECTION]) + 
return -EINVAL; + len = sizeof(u32); + break; +#endif + default: + return -EOPNOTSUPP; + } + + if (tb[NFTA_CT_DIRECTION]) { + priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]); + switch (priv->dir) { + case IP_CT_DIR_ORIGINAL: + case IP_CT_DIR_REPLY: + break; + default: + err = -EINVAL; + goto err1; + } + } + + priv->len = len; + err = nft_parse_register_load(tb[NFTA_CT_SREG], &priv->sreg, len); + if (err < 0) + goto err1; + + err = nf_ct_netns_get(ctx->net, ctx->family); + if (err < 0) + goto err1; + + return 0; + +err1: + __nft_ct_set_destroy(ctx, priv); + return err; +} + +static void nft_ct_get_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, ctx->family); +} + +static void nft_ct_set_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_ct *priv = nft_expr_priv(expr); + + __nft_ct_set_destroy(ctx, priv); + nf_ct_netns_put(ctx->net, ctx->family); +} + +static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_ct *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key))) + goto nla_put_failure; + + switch (priv->key) { + case NFT_CT_SRC: + case NFT_CT_DST: + case NFT_CT_SRC_IP: + case NFT_CT_DST_IP: + case NFT_CT_SRC_IP6: + case NFT_CT_DST_IP6: + case NFT_CT_PROTO_SRC: + case NFT_CT_PROTO_DST: + if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir)) + goto nla_put_failure; + break; + case NFT_CT_BYTES: + case NFT_CT_PKTS: + case NFT_CT_AVGPKT: + case NFT_CT_ZONE: + if (priv->dir < IP_CT_DIR_MAX && + nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir)) + goto nla_put_failure; + break; + default: + break; + } + + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_ct_get_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_ct *priv = nft_expr_priv(expr); + const struct nft_ct *ct; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + ct = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->key != ct->key) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} + +static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_ct *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key))) + goto nla_put_failure; + + switch (priv->key) { + case NFT_CT_ZONE: + if (priv->dir < IP_CT_DIR_MAX && + nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir)) + goto nla_put_failure; + break; + default: + break; + } + + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_ct_type; +static const struct nft_expr_ops nft_ct_get_ops = { + .type = &nft_ct_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_ct)), + .eval = nft_ct_get_eval, + .init = nft_ct_get_init, + .destroy = nft_ct_get_destroy, + .dump = nft_ct_get_dump, + .reduce = nft_ct_get_reduce, +}; + +static bool nft_ct_set_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + int i; + + for (i = 0; i < NFT_REG32_NUM; i++) { + if (!track->regs[i].selector) + continue; + + if (track->regs[i].selector->ops != &nft_ct_get_ops) + continue; + + __nft_reg_track_cancel(track, i); + } + + return false; 
+} + +static const struct nft_expr_ops nft_ct_set_ops = { + .type = &nft_ct_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_ct)), + .eval = nft_ct_set_eval, + .init = nft_ct_set_init, + .destroy = nft_ct_set_destroy, + .dump = nft_ct_set_dump, + .reduce = nft_ct_set_reduce, +}; + +#ifdef CONFIG_NF_CONNTRACK_ZONES +static const struct nft_expr_ops nft_ct_set_zone_ops = { + .type = &nft_ct_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_ct)), + .eval = nft_ct_set_zone_eval, + .init = nft_ct_set_init, + .destroy = nft_ct_set_destroy, + .dump = nft_ct_set_dump, + .reduce = nft_ct_set_reduce, +}; +#endif + +static const struct nft_expr_ops * +nft_ct_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_CT_KEY] == NULL) + return ERR_PTR(-EINVAL); + + if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG]) + return ERR_PTR(-EINVAL); + + if (tb[NFTA_CT_DREG]) + return &nft_ct_get_ops; + + if (tb[NFTA_CT_SREG]) { +#ifdef CONFIG_NF_CONNTRACK_ZONES + if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE)) + return &nft_ct_set_zone_ops; +#endif + return &nft_ct_set_ops; + } + + return ERR_PTR(-EINVAL); +} + +static struct nft_expr_type nft_ct_type __read_mostly = { + .name = "ct", + .select_ops = nft_ct_select_ops, + .policy = nft_ct_policy, + .maxattr = NFTA_CT_MAX, + .owner = THIS_MODULE, +}; + +static void nft_notrack_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct sk_buff *skb = pkt->skb; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(pkt->skb, &ctinfo); + /* Previously seen (loopback or untracked)? Ignore. */ + if (ct || ctinfo == IP_CT_UNTRACKED) + return; + + nf_ct_set(skb, ct, IP_CT_UNTRACKED); +} + +static struct nft_expr_type nft_notrack_type; +static const struct nft_expr_ops nft_notrack_ops = { + .type = &nft_notrack_type, + .size = NFT_EXPR_SIZE(0), + .eval = nft_notrack_eval, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_notrack_type __read_mostly = { + .name = "notrack", + .ops = &nft_notrack_ops, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT +static int +nft_ct_timeout_parse_policy(void *timeouts, + const struct nf_conntrack_l4proto *l4proto, + struct net *net, const struct nlattr *attr) +{ + struct nlattr **tb; + int ret = 0; + + tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb), + GFP_KERNEL); + + if (!tb) + return -ENOMEM; + + ret = nla_parse_nested_deprecated(tb, + l4proto->ctnl_timeout.nlattr_max, + attr, + l4proto->ctnl_timeout.nla_policy, + NULL); + if (ret < 0) + goto err; + + ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts); + +err: + kfree(tb); + return ret; +} + +struct nft_ct_timeout_obj { + struct nf_ct_timeout *timeout; + u8 l4proto; +}; + +static void nft_ct_timeout_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_ct_timeout_obj *priv = nft_obj_data(obj); + struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb); + struct nf_conn_timeout *timeout; + const unsigned int *values; + + if (priv->l4proto != pkt->tprot) + return; + + if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct)) + return; + + timeout = nf_ct_timeout_find(ct); + if (!timeout) { + timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC); + if (!timeout) { + regs->verdict.code = NF_DROP; + return; + } + } + + rcu_assign_pointer(timeout->timeout, priv->timeout); + + /* adjust the timeout as per 'new' state. 
ct is unconfirmed, + * so the current timestamp must not be added. + */ + values = nf_ct_timeout_data(timeout); + if (values) + nf_ct_refresh(ct, pkt->skb, values[0]); +} + +static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_ct_timeout_obj *priv = nft_obj_data(obj); + const struct nf_conntrack_l4proto *l4proto; + struct nf_ct_timeout *timeout; + int l3num = ctx->family; + __u8 l4num; + int ret; + + if (!tb[NFTA_CT_TIMEOUT_L4PROTO] || + !tb[NFTA_CT_TIMEOUT_DATA]) + return -EINVAL; + + if (tb[NFTA_CT_TIMEOUT_L3PROTO]) + l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO])); + + l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]); + priv->l4proto = l4num; + + l4proto = nf_ct_l4proto_find(l4num); + + if (l4proto->l4proto != l4num) { + ret = -EOPNOTSUPP; + goto err_proto_put; + } + + timeout = kzalloc(sizeof(struct nf_ct_timeout) + + l4proto->ctnl_timeout.obj_size, GFP_KERNEL); + if (timeout == NULL) { + ret = -ENOMEM; + goto err_proto_put; + } + + ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net, + tb[NFTA_CT_TIMEOUT_DATA]); + if (ret < 0) + goto err_free_timeout; + + timeout->l3num = l3num; + timeout->l4proto = l4proto; + + ret = nf_ct_netns_get(ctx->net, ctx->family); + if (ret < 0) + goto err_free_timeout; + + priv->timeout = timeout; + return 0; + +err_free_timeout: + kfree(timeout); +err_proto_put: + return ret; +} + +static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_ct_timeout_obj *priv = nft_obj_data(obj); + struct nf_ct_timeout *timeout = priv->timeout; + + nf_ct_untimeout(ctx->net, timeout); + nf_ct_netns_put(ctx->net, ctx->family); + kfree(priv->timeout); +} + +static int nft_ct_timeout_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + const struct nft_ct_timeout_obj *priv = nft_obj_data(obj); + const struct nf_ct_timeout *timeout = priv->timeout; + struct nlattr *nest_params; + int ret; + + if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) || + nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num))) + return -1; + + nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA); + if (!nest_params) + return -1; + + ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data); + if (ret < 0) + return -1; + nla_nest_end(skb, nest_params); + return 0; +} + +static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = { + [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 }, + [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 }, + [NFTA_CT_TIMEOUT_DATA] = {.type = NLA_NESTED }, +}; + +static struct nft_object_type nft_ct_timeout_obj_type; + +static const struct nft_object_ops nft_ct_timeout_obj_ops = { + .type = &nft_ct_timeout_obj_type, + .size = sizeof(struct nft_ct_timeout_obj), + .eval = nft_ct_timeout_obj_eval, + .init = nft_ct_timeout_obj_init, + .destroy = nft_ct_timeout_obj_destroy, + .dump = nft_ct_timeout_obj_dump, +}; + +static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = { + .type = NFT_OBJECT_CT_TIMEOUT, + .ops = &nft_ct_timeout_obj_ops, + .maxattr = NFTA_CT_TIMEOUT_MAX, + .policy = nft_ct_timeout_policy, + .owner = THIS_MODULE, +}; +#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */ + +static int nft_ct_helper_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_ct_helper_obj *priv = nft_obj_data(obj); + struct nf_conntrack_helper *help4, *help6; + char 
name[NF_CT_HELPER_NAME_LEN]; + int family = ctx->family; + int err; + + if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO]) + return -EINVAL; + + priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]); + if (!priv->l4proto) + return -ENOENT; + + nla_strscpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name)); + + if (tb[NFTA_CT_HELPER_L3PROTO]) + family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO])); + + help4 = NULL; + help6 = NULL; + + switch (family) { + case NFPROTO_IPV4: + if (ctx->family == NFPROTO_IPV6) + return -EINVAL; + + help4 = nf_conntrack_helper_try_module_get(name, family, + priv->l4proto); + break; + case NFPROTO_IPV6: + if (ctx->family == NFPROTO_IPV4) + return -EINVAL; + + help6 = nf_conntrack_helper_try_module_get(name, family, + priv->l4proto); + break; + case NFPROTO_NETDEV: + case NFPROTO_BRIDGE: + case NFPROTO_INET: + help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4, + priv->l4proto); + help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6, + priv->l4proto); + break; + default: + return -EAFNOSUPPORT; + } + + /* && is intentional; only error if INET found neither ipv4 or ipv6 */ + if (!help4 && !help6) + return -ENOENT; + + priv->helper4 = help4; + priv->helper6 = help6; + + err = nf_ct_netns_get(ctx->net, ctx->family); + if (err < 0) + goto err_put_helper; + + return 0; + +err_put_helper: + if (priv->helper4) + nf_conntrack_helper_put(priv->helper4); + if (priv->helper6) + nf_conntrack_helper_put(priv->helper6); + return err; +} + +static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_ct_helper_obj *priv = nft_obj_data(obj); + + if (priv->helper4) + nf_conntrack_helper_put(priv->helper4); + if (priv->helper6) + nf_conntrack_helper_put(priv->helper6); + + nf_ct_netns_put(ctx->net, ctx->family); +} + +static void nft_ct_helper_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_ct_helper_obj *priv = nft_obj_data(obj); + struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb); + struct nf_conntrack_helper *to_assign = NULL; + struct nf_conn_help *help; + + if (!ct || + nf_ct_is_confirmed(ct) || + nf_ct_is_template(ct) || + priv->l4proto != nf_ct_protonum(ct)) + return; + + switch (nf_ct_l3num(ct)) { + case NFPROTO_IPV4: + to_assign = priv->helper4; + break; + case NFPROTO_IPV6: + to_assign = priv->helper6; + break; + default: + WARN_ON_ONCE(1); + return; + } + + if (!to_assign) + return; + + if (test_bit(IPS_HELPER_BIT, &ct->status)) + return; + + help = nf_ct_helper_ext_add(ct, GFP_ATOMIC); + if (help) { + rcu_assign_pointer(help->helper, to_assign); + set_bit(IPS_HELPER_BIT, &ct->status); + } +} + +static int nft_ct_helper_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + const struct nft_ct_helper_obj *priv = nft_obj_data(obj); + const struct nf_conntrack_helper *helper; + u16 family; + + if (priv->helper4 && priv->helper6) { + family = NFPROTO_INET; + helper = priv->helper4; + } else if (priv->helper6) { + family = NFPROTO_IPV6; + helper = priv->helper6; + } else { + family = NFPROTO_IPV4; + helper = priv->helper4; + } + + if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name)) + return -1; + + if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto)) + return -1; + + if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family))) + return -1; + + return 0; +} + +static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = { + [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING, + .len = 
NF_CT_HELPER_NAME_LEN - 1 }, + [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 }, + [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 }, +}; + +static struct nft_object_type nft_ct_helper_obj_type; +static const struct nft_object_ops nft_ct_helper_obj_ops = { + .type = &nft_ct_helper_obj_type, + .size = sizeof(struct nft_ct_helper_obj), + .eval = nft_ct_helper_obj_eval, + .init = nft_ct_helper_obj_init, + .destroy = nft_ct_helper_obj_destroy, + .dump = nft_ct_helper_obj_dump, +}; + +static struct nft_object_type nft_ct_helper_obj_type __read_mostly = { + .type = NFT_OBJECT_CT_HELPER, + .ops = &nft_ct_helper_obj_ops, + .maxattr = NFTA_CT_HELPER_MAX, + .policy = nft_ct_helper_policy, + .owner = THIS_MODULE, +}; + +struct nft_ct_expect_obj { + u16 l3num; + __be16 dport; + u8 l4proto; + u8 size; + u32 timeout; +}; + +static int nft_ct_expect_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_ct_expect_obj *priv = nft_obj_data(obj); + + if (!tb[NFTA_CT_EXPECT_L4PROTO] || + !tb[NFTA_CT_EXPECT_DPORT] || + !tb[NFTA_CT_EXPECT_TIMEOUT] || + !tb[NFTA_CT_EXPECT_SIZE]) + return -EINVAL; + + priv->l3num = ctx->family; + if (tb[NFTA_CT_EXPECT_L3PROTO]) + priv->l3num = ntohs(nla_get_be16(tb[NFTA_CT_EXPECT_L3PROTO])); + + priv->l4proto = nla_get_u8(tb[NFTA_CT_EXPECT_L4PROTO]); + priv->dport = nla_get_be16(tb[NFTA_CT_EXPECT_DPORT]); + priv->timeout = nla_get_u32(tb[NFTA_CT_EXPECT_TIMEOUT]); + priv->size = nla_get_u8(tb[NFTA_CT_EXPECT_SIZE]); + + return nf_ct_netns_get(ctx->net, ctx->family); +} + +static void nft_ct_expect_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + nf_ct_netns_put(ctx->net, ctx->family); +} + +static int nft_ct_expect_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + const struct nft_ct_expect_obj *priv = nft_obj_data(obj); + + if (nla_put_be16(skb, NFTA_CT_EXPECT_L3PROTO, htons(priv->l3num)) || + nla_put_u8(skb, NFTA_CT_EXPECT_L4PROTO, priv->l4proto) || + nla_put_be16(skb, NFTA_CT_EXPECT_DPORT, priv->dport) || + nla_put_u32(skb, NFTA_CT_EXPECT_TIMEOUT, priv->timeout) || + nla_put_u8(skb, NFTA_CT_EXPECT_SIZE, priv->size)) + return -1; + + return 0; +} + +static void nft_ct_expect_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_ct_expect_obj *priv = nft_obj_data(obj); + struct nf_conntrack_expect *exp; + enum ip_conntrack_info ctinfo; + struct nf_conn_help *help; + enum ip_conntrack_dir dir; + u16 l3num = priv->l3num; + struct nf_conn *ct; + + ct = nf_ct_get(pkt->skb, &ctinfo); + if (!ct || nf_ct_is_confirmed(ct) || nf_ct_is_template(ct)) { + regs->verdict.code = NFT_BREAK; + return; + } + dir = CTINFO2DIR(ctinfo); + + help = nfct_help(ct); + if (!help) + help = nf_ct_helper_ext_add(ct, GFP_ATOMIC); + if (!help) { + regs->verdict.code = NF_DROP; + return; + } + + if (help->expecting[NF_CT_EXPECT_CLASS_DEFAULT] >= priv->size) { + regs->verdict.code = NFT_BREAK; + return; + } + if (l3num == NFPROTO_INET) + l3num = nf_ct_l3num(ct); + + exp = nf_ct_expect_alloc(ct); + if (exp == NULL) { + regs->verdict.code = NF_DROP; + return; + } + nf_ct_expect_init(exp, NF_CT_EXPECT_CLASS_DEFAULT, l3num, + &ct->tuplehash[!dir].tuple.src.u3, + &ct->tuplehash[!dir].tuple.dst.u3, + priv->l4proto, NULL, &priv->dport); + exp->timeout.expires = jiffies + priv->timeout * HZ; + + if (nf_ct_expect_related(exp, 0) != 0) + regs->verdict.code = NF_DROP; +} + +static const struct nla_policy nft_ct_expect_policy[NFTA_CT_EXPECT_MAX + 1] = { + 
[NFTA_CT_EXPECT_L3PROTO] = { .type = NLA_U16 }, + [NFTA_CT_EXPECT_L4PROTO] = { .type = NLA_U8 }, + [NFTA_CT_EXPECT_DPORT] = { .type = NLA_U16 }, + [NFTA_CT_EXPECT_TIMEOUT] = { .type = NLA_U32 }, + [NFTA_CT_EXPECT_SIZE] = { .type = NLA_U8 }, +}; + +static struct nft_object_type nft_ct_expect_obj_type; + +static const struct nft_object_ops nft_ct_expect_obj_ops = { + .type = &nft_ct_expect_obj_type, + .size = sizeof(struct nft_ct_expect_obj), + .eval = nft_ct_expect_obj_eval, + .init = nft_ct_expect_obj_init, + .destroy = nft_ct_expect_obj_destroy, + .dump = nft_ct_expect_obj_dump, +}; + +static struct nft_object_type nft_ct_expect_obj_type __read_mostly = { + .type = NFT_OBJECT_CT_EXPECT, + .ops = &nft_ct_expect_obj_ops, + .maxattr = NFTA_CT_EXPECT_MAX, + .policy = nft_ct_expect_policy, + .owner = THIS_MODULE, +}; + +static int __init nft_ct_module_init(void) +{ + int err; + + BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE); + + err = nft_register_expr(&nft_ct_type); + if (err < 0) + return err; + + err = nft_register_expr(&nft_notrack_type); + if (err < 0) + goto err1; + + err = nft_register_obj(&nft_ct_helper_obj_type); + if (err < 0) + goto err2; + + err = nft_register_obj(&nft_ct_expect_obj_type); + if (err < 0) + goto err3; +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + err = nft_register_obj(&nft_ct_timeout_obj_type); + if (err < 0) + goto err4; +#endif + return 0; + +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT +err4: + nft_unregister_obj(&nft_ct_expect_obj_type); +#endif +err3: + nft_unregister_obj(&nft_ct_helper_obj_type); +err2: + nft_unregister_expr(&nft_notrack_type); +err1: + nft_unregister_expr(&nft_ct_type); + return err; +} + +static void __exit nft_ct_module_exit(void) +{ +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + nft_unregister_obj(&nft_ct_timeout_obj_type); +#endif + nft_unregister_obj(&nft_ct_expect_obj_type); + nft_unregister_obj(&nft_ct_helper_obj_type); + nft_unregister_expr(&nft_notrack_type); + nft_unregister_expr(&nft_ct_type); +} + +module_init(nft_ct_module_init); +module_exit(nft_ct_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_EXPR("ct"); +MODULE_ALIAS_NFT_EXPR("notrack"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_EXPECT); +MODULE_DESCRIPTION("Netfilter nf_tables conntrack module"); diff --git a/net/netfilter/nft_dup_netdev.c b/net/netfilter/nft_dup_netdev.c new file mode 100644 index 000000000..635074027 --- /dev/null +++ b/net/netfilter/nft_dup_netdev.c @@ -0,0 +1,112 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_dup_netdev { + u8 sreg_dev; +}; + +static void nft_dup_netdev_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_dup_netdev *priv = nft_expr_priv(expr); + int oif = regs->data[priv->sreg_dev]; + + nf_dup_netdev_egress(pkt, oif); +} + +static const struct nla_policy nft_dup_netdev_policy[NFTA_DUP_MAX + 1] = { + [NFTA_DUP_SREG_DEV] = { .type = NLA_U32 }, +}; + +static int nft_dup_netdev_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_dup_netdev *priv = nft_expr_priv(expr); + + if (tb[NFTA_DUP_SREG_DEV] == NULL) + return -EINVAL; + + return nft_parse_register_load(tb[NFTA_DUP_SREG_DEV], &priv->sreg_dev, + sizeof(int)); +} + +static int nft_dup_netdev_dump(struct sk_buff 
*skb, const struct nft_expr *expr) +{ + struct nft_dup_netdev *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_DUP_SREG_DEV, priv->sreg_dev)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_dup_netdev_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_dup_netdev *priv = nft_expr_priv(expr); + int oif = ctx->regs[priv->sreg_dev].data.data[0]; + + return nft_fwd_dup_netdev_offload(ctx, flow, FLOW_ACTION_MIRRED, oif); +} + +static bool nft_dup_netdev_offload_action(const struct nft_expr *expr) +{ + return true; +} + +static struct nft_expr_type nft_dup_netdev_type; +static const struct nft_expr_ops nft_dup_netdev_ops = { + .type = &nft_dup_netdev_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_dup_netdev)), + .eval = nft_dup_netdev_eval, + .init = nft_dup_netdev_init, + .dump = nft_dup_netdev_dump, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_dup_netdev_offload, + .offload_action = nft_dup_netdev_offload_action, +}; + +static struct nft_expr_type nft_dup_netdev_type __read_mostly = { + .family = NFPROTO_NETDEV, + .name = "dup", + .ops = &nft_dup_netdev_ops, + .policy = nft_dup_netdev_policy, + .maxattr = NFTA_DUP_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_dup_netdev_module_init(void) +{ + return nft_register_expr(&nft_dup_netdev_type); +} + +static void __exit nft_dup_netdev_module_exit(void) +{ + nft_unregister_expr(&nft_dup_netdev_type); +} + +module_init(nft_dup_netdev_module_init); +module_exit(nft_dup_netdev_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_AF_EXPR(5, "dup"); +MODULE_DESCRIPTION("nftables netdev packet duplication support"); diff --git a/net/netfilter/nft_dynset.c b/net/netfilter/nft_dynset.c new file mode 100644 index 000000000..a470e5f61 --- /dev/null +++ b/net/netfilter/nft_dynset.c @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_dynset { + struct nft_set *set; + struct nft_set_ext_tmpl tmpl; + enum nft_dynset_ops op:8; + u8 sreg_key; + u8 sreg_data; + bool invert; + bool expr; + u8 num_exprs; + u64 timeout; + struct nft_expr *expr_array[NFT_SET_EXPR_MAX]; + struct nft_set_binding binding; +}; + +static int nft_dynset_expr_setup(const struct nft_dynset *priv, + const struct nft_set_ext *ext) +{ + struct nft_set_elem_expr *elem_expr = nft_set_ext_expr(ext); + struct nft_expr *expr; + int i; + + for (i = 0; i < priv->num_exprs; i++) { + expr = nft_setelem_expr_at(elem_expr, elem_expr->size); + if (nft_expr_clone(expr, priv->expr_array[i]) < 0) + return -1; + + elem_expr->size += priv->expr_array[i]->ops->size; + } + + return 0; +} + +static void *nft_dynset_new(struct nft_set *set, const struct nft_expr *expr, + struct nft_regs *regs) +{ + const struct nft_dynset *priv = nft_expr_priv(expr); + struct nft_set_ext *ext; + u64 timeout; + void *elem; + + if (!atomic_add_unless(&set->nelems, 1, set->size)) + return NULL; + + timeout = priv->timeout ? 
: set->timeout; + elem = nft_set_elem_init(set, &priv->tmpl, + ®s->data[priv->sreg_key], NULL, + ®s->data[priv->sreg_data], + timeout, 0, GFP_ATOMIC); + if (IS_ERR(elem)) + goto err1; + + ext = nft_set_elem_ext(set, elem); + if (priv->num_exprs && nft_dynset_expr_setup(priv, ext) < 0) + goto err2; + + return elem; + +err2: + nft_set_elem_destroy(set, elem, false); +err1: + if (set->size) + atomic_dec(&set->nelems); + return NULL; +} + +void nft_dynset_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) +{ + const struct nft_dynset *priv = nft_expr_priv(expr); + struct nft_set *set = priv->set; + const struct nft_set_ext *ext; + u64 timeout; + + if (priv->op == NFT_DYNSET_OP_DELETE) { + set->ops->delete(set, ®s->data[priv->sreg_key]); + return; + } + + if (set->ops->update(set, ®s->data[priv->sreg_key], nft_dynset_new, + expr, regs, &ext)) { + if (priv->op == NFT_DYNSET_OP_UPDATE && + nft_set_ext_exists(ext, NFT_SET_EXT_EXPIRATION)) { + timeout = priv->timeout ? : set->timeout; + *nft_set_ext_expiration(ext) = get_jiffies_64() + timeout; + } + + nft_set_elem_update_expr(ext, regs, pkt); + + if (priv->invert) + regs->verdict.code = NFT_BREAK; + return; + } + + if (!priv->invert) + regs->verdict.code = NFT_BREAK; +} + +static void nft_dynset_ext_add_expr(struct nft_dynset *priv) +{ + u8 size = 0; + int i; + + for (i = 0; i < priv->num_exprs; i++) + size += priv->expr_array[i]->ops->size; + + nft_set_ext_add_length(&priv->tmpl, NFT_SET_EXT_EXPRESSIONS, + sizeof(struct nft_set_elem_expr) + size); +} + +static struct nft_expr * +nft_dynset_expr_alloc(const struct nft_ctx *ctx, const struct nft_set *set, + const struct nlattr *attr, int pos) +{ + struct nft_expr *expr; + int err; + + expr = nft_set_elem_expr_alloc(ctx, set, attr); + if (IS_ERR(expr)) + return expr; + + if (set->exprs[pos] && set->exprs[pos]->ops != expr->ops) { + err = -EOPNOTSUPP; + goto err_dynset_expr; + } + + return expr; + +err_dynset_expr: + nft_expr_destroy(ctx, expr); + return ERR_PTR(err); +} + +static const struct nla_policy nft_dynset_policy[NFTA_DYNSET_MAX + 1] = { + [NFTA_DYNSET_SET_NAME] = { .type = NLA_STRING, + .len = NFT_SET_MAXNAMELEN - 1 }, + [NFTA_DYNSET_SET_ID] = { .type = NLA_U32 }, + [NFTA_DYNSET_OP] = { .type = NLA_U32 }, + [NFTA_DYNSET_SREG_KEY] = { .type = NLA_U32 }, + [NFTA_DYNSET_SREG_DATA] = { .type = NLA_U32 }, + [NFTA_DYNSET_TIMEOUT] = { .type = NLA_U64 }, + [NFTA_DYNSET_EXPR] = { .type = NLA_NESTED }, + [NFTA_DYNSET_FLAGS] = { .type = NLA_U32 }, + [NFTA_DYNSET_EXPRESSIONS] = { .type = NLA_NESTED }, +}; + +static int nft_dynset_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nftables_pernet *nft_net = nft_pernet(ctx->net); + struct nft_dynset *priv = nft_expr_priv(expr); + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set *set; + u64 timeout; + int err, i; + + lockdep_assert_held(&nft_net->commit_mutex); + + if (tb[NFTA_DYNSET_SET_NAME] == NULL || + tb[NFTA_DYNSET_OP] == NULL || + tb[NFTA_DYNSET_SREG_KEY] == NULL) + return -EINVAL; + + if (tb[NFTA_DYNSET_FLAGS]) { + u32 flags = ntohl(nla_get_be32(tb[NFTA_DYNSET_FLAGS])); + if (flags & ~(NFT_DYNSET_F_INV | NFT_DYNSET_F_EXPR)) + return -EOPNOTSUPP; + if (flags & NFT_DYNSET_F_INV) + priv->invert = true; + if (flags & NFT_DYNSET_F_EXPR) + priv->expr = true; + } + + set = nft_set_lookup_global(ctx->net, ctx->table, + tb[NFTA_DYNSET_SET_NAME], + tb[NFTA_DYNSET_SET_ID], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + if (set->flags & 
NFT_SET_OBJECT) + return -EOPNOTSUPP; + + if (set->ops->update == NULL) + return -EOPNOTSUPP; + + if (set->flags & NFT_SET_CONSTANT) + return -EBUSY; + + priv->op = ntohl(nla_get_be32(tb[NFTA_DYNSET_OP])); + if (priv->op > NFT_DYNSET_OP_DELETE) + return -EOPNOTSUPP; + + timeout = 0; + if (tb[NFTA_DYNSET_TIMEOUT] != NULL) { + if (!(set->flags & NFT_SET_TIMEOUT)) + return -EOPNOTSUPP; + + err = nf_msecs_to_jiffies64(tb[NFTA_DYNSET_TIMEOUT], &timeout); + if (err) + return err; + } + + err = nft_parse_register_load(tb[NFTA_DYNSET_SREG_KEY], &priv->sreg_key, + set->klen); + if (err < 0) + return err; + + if (tb[NFTA_DYNSET_SREG_DATA] != NULL) { + if (!(set->flags & NFT_SET_MAP)) + return -EOPNOTSUPP; + if (set->dtype == NFT_DATA_VERDICT) + return -EOPNOTSUPP; + + err = nft_parse_register_load(tb[NFTA_DYNSET_SREG_DATA], + &priv->sreg_data, set->dlen); + if (err < 0) + return err; + } else if (set->flags & NFT_SET_MAP) + return -EINVAL; + + if ((tb[NFTA_DYNSET_EXPR] || tb[NFTA_DYNSET_EXPRESSIONS]) && + !(set->flags & NFT_SET_EVAL)) + return -EINVAL; + + if (tb[NFTA_DYNSET_EXPR]) { + struct nft_expr *dynset_expr; + + dynset_expr = nft_dynset_expr_alloc(ctx, set, + tb[NFTA_DYNSET_EXPR], 0); + if (IS_ERR(dynset_expr)) + return PTR_ERR(dynset_expr); + + priv->num_exprs++; + priv->expr_array[0] = dynset_expr; + + if (set->num_exprs > 1 || + (set->num_exprs == 1 && + dynset_expr->ops != set->exprs[0]->ops)) { + err = -EOPNOTSUPP; + goto err_expr_free; + } + } else if (tb[NFTA_DYNSET_EXPRESSIONS]) { + struct nft_expr *dynset_expr; + struct nlattr *tmp; + int left; + + if (!priv->expr) + return -EINVAL; + + i = 0; + nla_for_each_nested(tmp, tb[NFTA_DYNSET_EXPRESSIONS], left) { + if (i == NFT_SET_EXPR_MAX) { + err = -E2BIG; + goto err_expr_free; + } + if (nla_type(tmp) != NFTA_LIST_ELEM) { + err = -EINVAL; + goto err_expr_free; + } + dynset_expr = nft_dynset_expr_alloc(ctx, set, tmp, i); + if (IS_ERR(dynset_expr)) { + err = PTR_ERR(dynset_expr); + goto err_expr_free; + } + priv->expr_array[i] = dynset_expr; + priv->num_exprs++; + + if (set->num_exprs) { + if (i >= set->num_exprs) { + err = -EINVAL; + goto err_expr_free; + } + if (dynset_expr->ops != set->exprs[i]->ops) { + err = -EOPNOTSUPP; + goto err_expr_free; + } + } + i++; + } + if (set->num_exprs && set->num_exprs != i) { + err = -EOPNOTSUPP; + goto err_expr_free; + } + } else if (set->num_exprs > 0) { + err = nft_set_elem_expr_clone(ctx, set, priv->expr_array); + if (err < 0) + return err; + + priv->num_exprs = set->num_exprs; + } + + nft_set_ext_prepare(&priv->tmpl); + nft_set_ext_add_length(&priv->tmpl, NFT_SET_EXT_KEY, set->klen); + if (set->flags & NFT_SET_MAP) + nft_set_ext_add_length(&priv->tmpl, NFT_SET_EXT_DATA, set->dlen); + + if (priv->num_exprs) + nft_dynset_ext_add_expr(priv); + + if (set->flags & NFT_SET_TIMEOUT) { + if (timeout || set->timeout) { + nft_set_ext_add(&priv->tmpl, NFT_SET_EXT_TIMEOUT); + nft_set_ext_add(&priv->tmpl, NFT_SET_EXT_EXPIRATION); + } + } + + priv->timeout = timeout; + + err = nf_tables_bind_set(ctx, set, &priv->binding); + if (err < 0) + goto err_expr_free; + + if (set->size == 0) + set->size = 0xffff; + + priv->set = set; + return 0; + +err_expr_free: + for (i = 0; i < priv->num_exprs; i++) + nft_expr_destroy(ctx, priv->expr_array[i]); + return err; +} + +static void nft_dynset_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) +{ + struct nft_dynset *priv = nft_expr_priv(expr); + + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); +} + +static 
void nft_dynset_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_dynset *priv = nft_expr_priv(expr); + + nf_tables_activate_set(ctx, priv->set); +} + +static void nft_dynset_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_dynset *priv = nft_expr_priv(expr); + int i; + + for (i = 0; i < priv->num_exprs; i++) + nft_expr_destroy(ctx, priv->expr_array[i]); + + nf_tables_destroy_set(ctx, priv->set); +} + +static int nft_dynset_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_dynset *priv = nft_expr_priv(expr); + u32 flags = priv->invert ? NFT_DYNSET_F_INV : 0; + int i; + + if (nft_dump_register(skb, NFTA_DYNSET_SREG_KEY, priv->sreg_key)) + goto nla_put_failure; + if (priv->set->flags & NFT_SET_MAP && + nft_dump_register(skb, NFTA_DYNSET_SREG_DATA, priv->sreg_data)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_DYNSET_OP, htonl(priv->op))) + goto nla_put_failure; + if (nla_put_string(skb, NFTA_DYNSET_SET_NAME, priv->set->name)) + goto nla_put_failure; + if (nla_put_be64(skb, NFTA_DYNSET_TIMEOUT, + nf_jiffies64_to_msecs(priv->timeout), + NFTA_DYNSET_PAD)) + goto nla_put_failure; + if (priv->set->num_exprs == 0) { + if (priv->num_exprs == 1) { + if (nft_expr_dump(skb, NFTA_DYNSET_EXPR, + priv->expr_array[0])) + goto nla_put_failure; + } else if (priv->num_exprs > 1) { + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, NFTA_DYNSET_EXPRESSIONS); + if (!nest) + goto nla_put_failure; + + for (i = 0; i < priv->num_exprs; i++) { + if (nft_expr_dump(skb, NFTA_LIST_ELEM, + priv->expr_array[i])) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + } + if (nla_put_be32(skb, NFTA_DYNSET_FLAGS, htonl(flags))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static const struct nft_expr_ops nft_dynset_ops = { + .type = &nft_dynset_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_dynset)), + .eval = nft_dynset_eval, + .init = nft_dynset_init, + .destroy = nft_dynset_destroy, + .activate = nft_dynset_activate, + .deactivate = nft_dynset_deactivate, + .dump = nft_dynset_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +struct nft_expr_type nft_dynset_type __read_mostly = { + .name = "dynset", + .ops = &nft_dynset_ops, + .policy = nft_dynset_policy, + .maxattr = NFTA_DYNSET_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_exthdr.c b/net/netfilter/nft_exthdr.c new file mode 100644 index 000000000..de588f7b6 --- /dev/null +++ b/net/netfilter/nft_exthdr.c @@ -0,0 +1,840 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_exthdr { + u8 type; + u8 offset; + u8 len; + u8 op; + u8 dreg; + u8 sreg; + u8 flags; +}; + +static unsigned int optlen(const u8 *opt, unsigned int offset) +{ + /* Beware zero-length options: make finite progress */ + if (opt[offset] <= TCPOPT_NOP || opt[offset + 1] == 0) + return 1; + else + return opt[offset + 1]; +} + +static int nft_skb_copy_to_reg(const struct sk_buff *skb, int offset, u32 *dest, unsigned int len) +{ + if (len % NFT_REG32_SIZE) + dest[len / NFT_REG32_SIZE] = 0; + + return skb_copy_bits(skb, offset, dest, len); +} + +static void nft_exthdr_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); 
+ u32 *dest = ®s->data[priv->dreg]; + unsigned int offset = 0; + int err; + + if (pkt->skb->protocol != htons(ETH_P_IPV6)) + goto err; + + err = ipv6_find_hdr(pkt->skb, &offset, priv->type, NULL, NULL); + if (priv->flags & NFT_EXTHDR_F_PRESENT) { + nft_reg_store8(dest, err >= 0); + return; + } else if (err < 0) { + goto err; + } + offset += priv->offset; + + if (nft_skb_copy_to_reg(pkt->skb, offset, dest, priv->len) < 0) + goto err; + return; +err: + regs->verdict.code = NFT_BREAK; +} + +/* find the offset to specified option. + * + * If target header is found, its offset is set in *offset and return option + * number. Otherwise, return negative error. + * + * If the first fragment doesn't contain the End of Options it is considered + * invalid. + */ +static int ipv4_find_option(struct net *net, struct sk_buff *skb, + unsigned int *offset, int target) +{ + unsigned char optbuf[sizeof(struct ip_options) + 40]; + struct ip_options *opt = (struct ip_options *)optbuf; + struct iphdr *iph, _iph; + unsigned int start; + bool found = false; + __be32 info; + int optlen; + + iph = skb_header_pointer(skb, 0, sizeof(_iph), &_iph); + if (!iph) + return -EBADMSG; + start = sizeof(struct iphdr); + + optlen = iph->ihl * 4 - (int)sizeof(struct iphdr); + if (optlen <= 0) + return -ENOENT; + + memset(opt, 0, sizeof(struct ip_options)); + /* Copy the options since __ip_options_compile() modifies + * the options. + */ + if (skb_copy_bits(skb, start, opt->__data, optlen)) + return -EBADMSG; + opt->optlen = optlen; + + if (__ip_options_compile(net, opt, NULL, &info)) + return -EBADMSG; + + switch (target) { + case IPOPT_SSRR: + case IPOPT_LSRR: + if (!opt->srr) + break; + found = target == IPOPT_SSRR ? opt->is_strictroute : + !opt->is_strictroute; + if (found) + *offset = opt->srr + start; + break; + case IPOPT_RR: + if (!opt->rr) + break; + *offset = opt->rr + start; + found = true; + break; + case IPOPT_RA: + if (!opt->router_alert) + break; + *offset = opt->router_alert + start; + found = true; + break; + default: + return -EOPNOTSUPP; + } + return found ? 
target : -ENOENT; +} + +static void nft_exthdr_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + struct sk_buff *skb = pkt->skb; + unsigned int offset; + int err; + + if (skb->protocol != htons(ETH_P_IP)) + goto err; + + err = ipv4_find_option(nft_net(pkt), skb, &offset, priv->type); + if (priv->flags & NFT_EXTHDR_F_PRESENT) { + nft_reg_store8(dest, err >= 0); + return; + } else if (err < 0) { + goto err; + } + offset += priv->offset; + + if (nft_skb_copy_to_reg(pkt->skb, offset, dest, priv->len) < 0) + goto err; + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static void * +nft_tcp_header_pointer(const struct nft_pktinfo *pkt, + unsigned int len, void *buffer, unsigned int *tcphdr_len) +{ + struct tcphdr *tcph; + + if (pkt->tprot != IPPROTO_TCP || pkt->fragoff) + return NULL; + + tcph = skb_header_pointer(pkt->skb, nft_thoff(pkt), sizeof(*tcph), buffer); + if (!tcph) + return NULL; + + *tcphdr_len = __tcp_hdrlen(tcph); + if (*tcphdr_len < sizeof(*tcph) || *tcphdr_len > len) + return NULL; + + return skb_header_pointer(pkt->skb, nft_thoff(pkt), *tcphdr_len, buffer); +} + +static void nft_exthdr_tcp_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + u8 buff[sizeof(struct tcphdr) + MAX_TCP_OPTION_SPACE]; + struct nft_exthdr *priv = nft_expr_priv(expr); + unsigned int i, optl, tcphdr_len, offset; + u32 *dest = ®s->data[priv->dreg]; + struct tcphdr *tcph; + u8 *opt; + + tcph = nft_tcp_header_pointer(pkt, sizeof(buff), buff, &tcphdr_len); + if (!tcph) + goto err; + + opt = (u8 *)tcph; + for (i = sizeof(*tcph); i < tcphdr_len - 1; i += optl) { + optl = optlen(opt, i); + + if (priv->type != opt[i]) + continue; + + if (i + optl > tcphdr_len || priv->len + priv->offset > optl) + goto err; + + offset = i + priv->offset; + if (priv->flags & NFT_EXTHDR_F_PRESENT) { + nft_reg_store8(dest, 1); + } else { + if (priv->len % NFT_REG32_SIZE) + dest[priv->len / NFT_REG32_SIZE] = 0; + memcpy(dest, opt + offset, priv->len); + } + + return; + } + +err: + if (priv->flags & NFT_EXTHDR_F_PRESENT) + *dest = 0; + else + regs->verdict.code = NFT_BREAK; +} + +static void nft_exthdr_tcp_set_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + u8 buff[sizeof(struct tcphdr) + MAX_TCP_OPTION_SPACE]; + struct nft_exthdr *priv = nft_expr_priv(expr); + unsigned int i, optl, tcphdr_len, offset; + struct tcphdr *tcph; + u8 *opt; + + tcph = nft_tcp_header_pointer(pkt, sizeof(buff), buff, &tcphdr_len); + if (!tcph) + goto err; + + if (skb_ensure_writable(pkt->skb, nft_thoff(pkt) + tcphdr_len)) + goto err; + + tcph = (struct tcphdr *)(pkt->skb->data + nft_thoff(pkt)); + opt = (u8 *)tcph; + + for (i = sizeof(*tcph); i < tcphdr_len - 1; i += optl) { + union { + __be16 v16; + __be32 v32; + } old, new; + + optl = optlen(opt, i); + + if (priv->type != opt[i]) + continue; + + if (i + optl > tcphdr_len || priv->len + priv->offset > optl) + goto err; + + offset = i + priv->offset; + + switch (priv->len) { + case 2: + old.v16 = (__force __be16)get_unaligned((u16 *)(opt + offset)); + new.v16 = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg]); + + switch (priv->type) { + case TCPOPT_MSS: + /* increase can cause connection to stall */ + if (ntohs(old.v16) <= ntohs(new.v16)) + return; + break; + } + + if (old.v16 == new.v16) + return; + + put_unaligned(new.v16, (__be16*)(opt + offset)); + 
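+			/* The 16-bit option field was rewritten in place above;
+			 * fold the old/new values into the TCP checksum so the
+			 * mangled segment remains valid.
+			 */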
inet_proto_csum_replace2(&tcph->check, pkt->skb, + old.v16, new.v16, false); + break; + case 4: + new.v32 = nft_reg_load_be32(®s->data[priv->sreg]); + old.v32 = (__force __be32)get_unaligned((u32 *)(opt + offset)); + + if (old.v32 == new.v32) + return; + + put_unaligned(new.v32, (__be32*)(opt + offset)); + inet_proto_csum_replace4(&tcph->check, pkt->skb, + old.v32, new.v32, false); + break; + default: + WARN_ON_ONCE(1); + break; + } + + return; + } + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static void nft_exthdr_tcp_strip_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + u8 buff[sizeof(struct tcphdr) + MAX_TCP_OPTION_SPACE]; + struct nft_exthdr *priv = nft_expr_priv(expr); + unsigned int i, tcphdr_len, optl; + struct tcphdr *tcph; + u8 *opt; + + tcph = nft_tcp_header_pointer(pkt, sizeof(buff), buff, &tcphdr_len); + if (!tcph) + goto err; + + if (skb_ensure_writable(pkt->skb, nft_thoff(pkt) + tcphdr_len)) + goto drop; + + tcph = (struct tcphdr *)(pkt->skb->data + nft_thoff(pkt)); + opt = (u8 *)tcph; + + for (i = sizeof(*tcph); i < tcphdr_len - 1; i += optl) { + unsigned int j; + + optl = optlen(opt, i); + if (priv->type != opt[i]) + continue; + + if (i + optl > tcphdr_len) + goto drop; + + for (j = 0; j < optl; ++j) { + u16 n = TCPOPT_NOP; + u16 o = opt[i+j]; + + if ((i + j) % 2 == 0) { + o <<= 8; + n <<= 8; + } + inet_proto_csum_replace2(&tcph->check, pkt->skb, htons(o), + htons(n), false); + } + memset(opt + i, TCPOPT_NOP, optl); + return; + } + + /* option not found, continue. This allows to do multiple + * option removals per rule. + */ + return; +err: + regs->verdict.code = NFT_BREAK; + return; +drop: + /* can't remove, no choice but to drop */ + regs->verdict.code = NF_DROP; +} + +static void nft_exthdr_sctp_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + unsigned int offset = nft_thoff(pkt) + sizeof(struct sctphdr); + struct nft_exthdr *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + const struct sctp_chunkhdr *sch; + struct sctp_chunkhdr _sch; + + if (pkt->tprot != IPPROTO_SCTP) + goto err; + + do { + sch = skb_header_pointer(pkt->skb, offset, sizeof(_sch), &_sch); + if (!sch || !sch->length) + break; + + if (sch->type == priv->type) { + if (priv->flags & NFT_EXTHDR_F_PRESENT) { + nft_reg_store8(dest, true); + return; + } + if (priv->offset + priv->len > ntohs(sch->length) || + offset + ntohs(sch->length) > pkt->skb->len) + break; + + if (nft_skb_copy_to_reg(pkt->skb, offset + priv->offset, + dest, priv->len) < 0) + break; + return; + } + offset += SCTP_PAD4(ntohs(sch->length)); + } while (offset < pkt->skb->len); +err: + if (priv->flags & NFT_EXTHDR_F_PRESENT) + nft_reg_store8(dest, false); + else + regs->verdict.code = NFT_BREAK; +} + +static void nft_exthdr_dccp_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + unsigned int thoff, dataoff, optoff, optlen, i; + u32 *dest = ®s->data[priv->dreg]; + const struct dccp_hdr *dh; + struct dccp_hdr _dh; + + if (pkt->tprot != IPPROTO_DCCP || pkt->fragoff) + goto err; + + thoff = nft_thoff(pkt); + + dh = skb_header_pointer(pkt->skb, thoff, sizeof(_dh), &_dh); + if (!dh) + goto err; + + dataoff = dh->dccph_doff * sizeof(u32); + optoff = __dccp_hdr_len(dh); + if (dataoff <= optoff) + goto err; + + optlen = dataoff - optoff; + + for (i = 0; i < optlen; ) { + /* Options 0 (DCCPO_PADDING) - 31 (DCCPO_MAX_RESERVED) are 1B in + 
* the length; the remaining options are at least 2B long. In + * all cases, the first byte contains the option type. In + * multi-byte options, the second byte contains the option + * length, which must be at least two: 1 for the type plus 1 for + * the length plus 0-253 for any following option data. We + * aren't interested in the option data, only the type and the + * length, so we don't need to read more than two bytes at a + * time. + */ + unsigned int buflen = optlen - i; + u8 buf[2], *bufp; + u8 type, len; + + if (buflen > sizeof(buf)) + buflen = sizeof(buf); + + bufp = skb_header_pointer(pkt->skb, thoff + optoff + i, buflen, + &buf); + if (!bufp) + goto err; + + type = bufp[0]; + + if (type == priv->type) { + nft_reg_store8(dest, 1); + return; + } + + if (type <= DCCPO_MAX_RESERVED) { + i++; + continue; + } + + if (buflen < 2) + goto err; + + len = bufp[1]; + + if (len < 2) + goto err; + + i += len; + } + +err: + *dest = 0; +} + +static const struct nla_policy nft_exthdr_policy[NFTA_EXTHDR_MAX + 1] = { + [NFTA_EXTHDR_DREG] = { .type = NLA_U32 }, + [NFTA_EXTHDR_TYPE] = { .type = NLA_U8 }, + [NFTA_EXTHDR_OFFSET] = { .type = NLA_U32 }, + [NFTA_EXTHDR_LEN] = { .type = NLA_U32 }, + [NFTA_EXTHDR_FLAGS] = { .type = NLA_U32 }, + [NFTA_EXTHDR_OP] = { .type = NLA_U32 }, + [NFTA_EXTHDR_SREG] = { .type = NLA_U32 }, +}; + +static int nft_exthdr_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + u32 offset, len, flags = 0, op = NFT_EXTHDR_OP_IPV6; + int err; + + if (!tb[NFTA_EXTHDR_DREG] || + !tb[NFTA_EXTHDR_TYPE] || + !tb[NFTA_EXTHDR_OFFSET] || + !tb[NFTA_EXTHDR_LEN]) + return -EINVAL; + + err = nft_parse_u32_check(tb[NFTA_EXTHDR_OFFSET], U8_MAX, &offset); + if (err < 0) + return err; + + err = nft_parse_u32_check(tb[NFTA_EXTHDR_LEN], U8_MAX, &len); + if (err < 0) + return err; + + if (tb[NFTA_EXTHDR_FLAGS]) { + err = nft_parse_u32_check(tb[NFTA_EXTHDR_FLAGS], U8_MAX, &flags); + if (err < 0) + return err; + + if (flags & ~NFT_EXTHDR_F_PRESENT) + return -EINVAL; + } + + if (tb[NFTA_EXTHDR_OP]) { + err = nft_parse_u32_check(tb[NFTA_EXTHDR_OP], U8_MAX, &op); + if (err < 0) + return err; + } + + priv->type = nla_get_u8(tb[NFTA_EXTHDR_TYPE]); + priv->offset = offset; + priv->len = len; + priv->flags = flags; + priv->op = op; + + return nft_parse_register_store(ctx, tb[NFTA_EXTHDR_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + priv->len); +} + +static int nft_exthdr_tcp_set_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + u32 offset, len, flags = 0, op = NFT_EXTHDR_OP_IPV6; + int err; + + if (!tb[NFTA_EXTHDR_SREG] || + !tb[NFTA_EXTHDR_TYPE] || + !tb[NFTA_EXTHDR_OFFSET] || + !tb[NFTA_EXTHDR_LEN]) + return -EINVAL; + + if (tb[NFTA_EXTHDR_DREG] || tb[NFTA_EXTHDR_FLAGS]) + return -EINVAL; + + err = nft_parse_u32_check(tb[NFTA_EXTHDR_OFFSET], U8_MAX, &offset); + if (err < 0) + return err; + + err = nft_parse_u32_check(tb[NFTA_EXTHDR_LEN], U8_MAX, &len); + if (err < 0) + return err; + + if (offset < 2) + return -EOPNOTSUPP; + + switch (len) { + case 2: break; + case 4: break; + default: + return -EOPNOTSUPP; + } + + err = nft_parse_u32_check(tb[NFTA_EXTHDR_OP], U8_MAX, &op); + if (err < 0) + return err; + + priv->type = nla_get_u8(tb[NFTA_EXTHDR_TYPE]); + priv->offset = offset; + priv->len = len; + priv->flags = flags; + priv->op = op; + + return nft_parse_register_load(tb[NFTA_EXTHDR_SREG], &priv->sreg, 
+ priv->len); +} + +static int nft_exthdr_tcp_strip_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + + if (tb[NFTA_EXTHDR_SREG] || + tb[NFTA_EXTHDR_DREG] || + tb[NFTA_EXTHDR_FLAGS] || + tb[NFTA_EXTHDR_OFFSET] || + tb[NFTA_EXTHDR_LEN]) + return -EINVAL; + + if (!tb[NFTA_EXTHDR_TYPE]) + return -EINVAL; + + priv->type = nla_get_u8(tb[NFTA_EXTHDR_TYPE]); + priv->op = NFT_EXTHDR_OP_TCPOPT; + + return 0; +} + +static int nft_exthdr_ipv4_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + int err = nft_exthdr_init(ctx, expr, tb); + + if (err < 0) + return err; + + switch (priv->type) { + case IPOPT_SSRR: + case IPOPT_LSRR: + case IPOPT_RR: + case IPOPT_RA: + break; + default: + return -EOPNOTSUPP; + } + return 0; +} + +static int nft_exthdr_dccp_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_exthdr *priv = nft_expr_priv(expr); + int err = nft_exthdr_init(ctx, expr, tb); + + if (err < 0) + return err; + + if (!(priv->flags & NFT_EXTHDR_F_PRESENT)) + return -EOPNOTSUPP; + + return 0; +} + +static int nft_exthdr_dump_common(struct sk_buff *skb, const struct nft_exthdr *priv) +{ + if (nla_put_u8(skb, NFTA_EXTHDR_TYPE, priv->type)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_EXTHDR_OFFSET, htonl(priv->offset))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_EXTHDR_LEN, htonl(priv->len))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_EXTHDR_FLAGS, htonl(priv->flags))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_EXTHDR_OP, htonl(priv->op))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int nft_exthdr_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_exthdr *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_EXTHDR_DREG, priv->dreg)) + return -1; + + return nft_exthdr_dump_common(skb, priv); +} + +static int nft_exthdr_dump_set(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_exthdr *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_EXTHDR_SREG, priv->sreg)) + return -1; + + return nft_exthdr_dump_common(skb, priv); +} + +static int nft_exthdr_dump_strip(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_exthdr *priv = nft_expr_priv(expr); + + return nft_exthdr_dump_common(skb, priv); +} + +static bool nft_exthdr_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_exthdr *priv = nft_expr_priv(expr); + const struct nft_exthdr *exthdr; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + exthdr = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->type != exthdr->type || + priv->op != exthdr->op || + priv->flags != exthdr->flags || + priv->offset != exthdr->offset || + priv->len != exthdr->len) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} + +static const struct nft_expr_ops nft_exthdr_ipv6_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_ipv6_eval, + .init = nft_exthdr_init, + .dump = nft_exthdr_dump, + .reduce = nft_exthdr_reduce, +}; + 
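+/* The ops variants below share struct nft_exthdr but pair different
+ * eval/init handlers; nft_exthdr_select_ops() further down chooses among
+ * them based on NFTA_EXTHDR_OP and on whether userspace supplied a source
+ * register (set/strip) or a destination register (match). For illustration
+ * only (nft CLI syntax assumed, not part of this file): a rule along the
+ * lines of "tcp option maxseg size set 1360" would be backed by
+ * nft_exthdr_tcp_set_ops, while a plain IPv6 extension header match uses
+ * nft_exthdr_ipv6_ops.
+ */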
+static const struct nft_expr_ops nft_exthdr_ipv4_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_ipv4_eval, + .init = nft_exthdr_ipv4_init, + .dump = nft_exthdr_dump, + .reduce = nft_exthdr_reduce, +}; + +static const struct nft_expr_ops nft_exthdr_tcp_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_tcp_eval, + .init = nft_exthdr_init, + .dump = nft_exthdr_dump, + .reduce = nft_exthdr_reduce, +}; + +static const struct nft_expr_ops nft_exthdr_tcp_set_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_tcp_set_eval, + .init = nft_exthdr_tcp_set_init, + .dump = nft_exthdr_dump_set, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops nft_exthdr_tcp_strip_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_tcp_strip_eval, + .init = nft_exthdr_tcp_strip_init, + .dump = nft_exthdr_dump_strip, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops nft_exthdr_sctp_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_sctp_eval, + .init = nft_exthdr_init, + .dump = nft_exthdr_dump, + .reduce = nft_exthdr_reduce, +}; + +static const struct nft_expr_ops nft_exthdr_dccp_ops = { + .type = &nft_exthdr_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_exthdr)), + .eval = nft_exthdr_dccp_eval, + .init = nft_exthdr_dccp_init, + .dump = nft_exthdr_dump, + .reduce = nft_exthdr_reduce, +}; + +static const struct nft_expr_ops * +nft_exthdr_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + u32 op; + + if (!tb[NFTA_EXTHDR_OP]) + return &nft_exthdr_ipv6_ops; + + if (tb[NFTA_EXTHDR_SREG] && tb[NFTA_EXTHDR_DREG]) + return ERR_PTR(-EOPNOTSUPP); + + op = ntohl(nla_get_be32(tb[NFTA_EXTHDR_OP])); + switch (op) { + case NFT_EXTHDR_OP_TCPOPT: + if (tb[NFTA_EXTHDR_SREG]) + return &nft_exthdr_tcp_set_ops; + if (tb[NFTA_EXTHDR_DREG]) + return &nft_exthdr_tcp_ops; + return &nft_exthdr_tcp_strip_ops; + case NFT_EXTHDR_OP_IPV6: + if (tb[NFTA_EXTHDR_DREG]) + return &nft_exthdr_ipv6_ops; + break; + case NFT_EXTHDR_OP_IPV4: + if (ctx->family != NFPROTO_IPV6) { + if (tb[NFTA_EXTHDR_DREG]) + return &nft_exthdr_ipv4_ops; + } + break; + case NFT_EXTHDR_OP_SCTP: + if (tb[NFTA_EXTHDR_DREG]) + return &nft_exthdr_sctp_ops; + break; + case NFT_EXTHDR_OP_DCCP: + if (tb[NFTA_EXTHDR_DREG]) + return &nft_exthdr_dccp_ops; + break; + } + + return ERR_PTR(-EOPNOTSUPP); +} + +struct nft_expr_type nft_exthdr_type __read_mostly = { + .name = "exthdr", + .select_ops = nft_exthdr_select_ops, + .policy = nft_exthdr_policy, + .maxattr = NFTA_EXTHDR_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_fib.c b/net/netfilter/nft_fib.c new file mode 100644 index 000000000..5748415f7 --- /dev/null +++ b/net/netfilter/nft_fib.c @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Generic part shared by ipv4 and ipv6 backends. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const struct nla_policy nft_fib_policy[NFTA_FIB_MAX + 1] = { + [NFTA_FIB_DREG] = { .type = NLA_U32 }, + [NFTA_FIB_RESULT] = { .type = NLA_U32 }, + [NFTA_FIB_FLAGS] = { .type = NLA_U32 }, +}; +EXPORT_SYMBOL(nft_fib_policy); + +#define NFTA_FIB_F_ALL (NFTA_FIB_F_SADDR | NFTA_FIB_F_DADDR | \ + NFTA_FIB_F_MARK | NFTA_FIB_F_IIF | NFTA_FIB_F_OIF | \ + NFTA_FIB_F_PRESENT) + +int nft_fib_validate(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nft_data **data) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + unsigned int hooks; + + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + case NFT_FIB_RESULT_OIFNAME: + hooks = (1 << NF_INET_PRE_ROUTING); + if (priv->flags & NFTA_FIB_F_IIF) { + hooks |= (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD); + } + break; + case NFT_FIB_RESULT_ADDRTYPE: + if (priv->flags & NFTA_FIB_F_IIF) + hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD); + else if (priv->flags & NFTA_FIB_F_OIF) + hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_FORWARD); + else + hooks = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING); + + break; + default: + return -EINVAL; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} +EXPORT_SYMBOL_GPL(nft_fib_validate); + +int nft_fib_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_fib *priv = nft_expr_priv(expr); + unsigned int len; + int err; + + if (!tb[NFTA_FIB_DREG] || !tb[NFTA_FIB_RESULT] || !tb[NFTA_FIB_FLAGS]) + return -EINVAL; + + priv->flags = ntohl(nla_get_be32(tb[NFTA_FIB_FLAGS])); + + if (priv->flags == 0 || (priv->flags & ~NFTA_FIB_F_ALL)) + return -EINVAL; + + if ((priv->flags & (NFTA_FIB_F_SADDR | NFTA_FIB_F_DADDR)) == + (NFTA_FIB_F_SADDR | NFTA_FIB_F_DADDR)) + return -EINVAL; + if ((priv->flags & (NFTA_FIB_F_IIF | NFTA_FIB_F_OIF)) == + (NFTA_FIB_F_IIF | NFTA_FIB_F_OIF)) + return -EINVAL; + if ((priv->flags & (NFTA_FIB_F_SADDR | NFTA_FIB_F_DADDR)) == 0) + return -EINVAL; + + priv->result = ntohl(nla_get_be32(tb[NFTA_FIB_RESULT])); + + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + if (priv->flags & NFTA_FIB_F_OIF) + return -EINVAL; + len = sizeof(int); + break; + case NFT_FIB_RESULT_OIFNAME: + if (priv->flags & NFTA_FIB_F_OIF) + return -EINVAL; + len = IFNAMSIZ; + break; + case NFT_FIB_RESULT_ADDRTYPE: + len = sizeof(u32); + break; + default: + return -EINVAL; + } + + err = nft_parse_register_store(ctx, tb[NFTA_FIB_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); + if (err < 0) + return err; + + return 0; +} +EXPORT_SYMBOL_GPL(nft_fib_init); + +int nft_fib_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_FIB_DREG, priv->dreg)) + return -1; + + if (nla_put_be32(skb, NFTA_FIB_RESULT, htonl(priv->result))) + return -1; + + if (nla_put_be32(skb, NFTA_FIB_FLAGS, htonl(priv->flags))) + return -1; + + return 0; +} +EXPORT_SYMBOL_GPL(nft_fib_dump); + +void nft_fib_store_result(void *reg, const struct nft_fib *priv, + const struct net_device *dev) +{ + u32 *dreg = reg; + int index; + + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + index = dev ? 
dev->ifindex : 0; + if (priv->flags & NFTA_FIB_F_PRESENT) + nft_reg_store8(dreg, !!index); + else + *dreg = index; + + break; + case NFT_FIB_RESULT_OIFNAME: + if (priv->flags & NFTA_FIB_F_PRESENT) + nft_reg_store8(dreg, !!dev); + else + strncpy(reg, dev ? dev->name : "", IFNAMSIZ); + break; + default: + WARN_ON_ONCE(1); + *dreg = 0; + break; + } +} +EXPORT_SYMBOL_GPL(nft_fib_store_result); + +bool nft_fib_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + unsigned int len = NFT_REG32_SIZE; + const struct nft_fib *fib; + + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + break; + case NFT_FIB_RESULT_OIFNAME: + if (priv->flags & NFTA_FIB_F_PRESENT) + len = NFT_REG32_SIZE; + else + len = IFNAMSIZ; + break; + case NFT_FIB_RESULT_ADDRTYPE: + break; + default: + WARN_ON_ONCE(1); + break; + } + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, len); + return false; + } + + fib = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->result != fib->result || + priv->flags != fib->flags) { + nft_reg_track_update(track, expr, priv->dreg, len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(nft_fib_reduce); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); diff --git a/net/netfilter/nft_fib_inet.c b/net/netfilter/nft_fib_inet.c new file mode 100644 index 000000000..666a3741d --- /dev/null +++ b/net/netfilter/nft_fib_inet.c @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static void nft_fib_inet_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + case NFT_FIB_RESULT_OIFNAME: + return nft_fib4_eval(expr, regs, pkt); + case NFT_FIB_RESULT_ADDRTYPE: + return nft_fib4_eval_type(expr, regs, pkt); + } + break; + case NFPROTO_IPV6: + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + case NFT_FIB_RESULT_OIFNAME: + return nft_fib6_eval(expr, regs, pkt); + case NFT_FIB_RESULT_ADDRTYPE: + return nft_fib6_eval_type(expr, regs, pkt); + } + break; + } + + regs->verdict.code = NF_DROP; +} + +static struct nft_expr_type nft_fib_inet_type; +static const struct nft_expr_ops nft_fib_inet_ops = { + .type = &nft_fib_inet_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib_inet_eval, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static struct nft_expr_type nft_fib_inet_type __read_mostly = { + .family = NFPROTO_INET, + .name = "fib", + .ops = &nft_fib_inet_ops, + .policy = nft_fib_policy, + .maxattr = NFTA_FIB_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_fib_inet_module_init(void) +{ + return nft_register_expr(&nft_fib_inet_type); +} + +static void __exit nft_fib_inet_module_exit(void) +{ + nft_unregister_expr(&nft_fib_inet_type); +} + +module_init(nft_fib_inet_module_init); +module_exit(nft_fib_inet_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_ALIAS_NFT_AF_EXPR(1, "fib"); +MODULE_DESCRIPTION("nftables fib inet support"); diff --git a/net/netfilter/nft_fib_netdev.c b/net/netfilter/nft_fib_netdev.c new file mode 100644 index 000000000..9121ec64e --- /dev/null +++ 
b/net/netfilter/nft_fib_netdev.c @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2017 Pablo M. Bermudo Garay + * + * This code is based on net/netfilter/nft_fib_inet.c, written by + * Florian Westphal . + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static void nft_fib_netdev_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_fib *priv = nft_expr_priv(expr); + + switch (ntohs(pkt->skb->protocol)) { + case ETH_P_IP: + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + case NFT_FIB_RESULT_OIFNAME: + return nft_fib4_eval(expr, regs, pkt); + case NFT_FIB_RESULT_ADDRTYPE: + return nft_fib4_eval_type(expr, regs, pkt); + } + break; + case ETH_P_IPV6: + if (!ipv6_mod_enabled()) + break; + switch (priv->result) { + case NFT_FIB_RESULT_OIF: + case NFT_FIB_RESULT_OIFNAME: + return nft_fib6_eval(expr, regs, pkt); + case NFT_FIB_RESULT_ADDRTYPE: + return nft_fib6_eval_type(expr, regs, pkt); + } + break; + } + + regs->verdict.code = NFT_BREAK; +} + +static struct nft_expr_type nft_fib_netdev_type; +static const struct nft_expr_ops nft_fib_netdev_ops = { + .type = &nft_fib_netdev_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fib)), + .eval = nft_fib_netdev_eval, + .init = nft_fib_init, + .dump = nft_fib_dump, + .validate = nft_fib_validate, + .reduce = nft_fib_reduce, +}; + +static struct nft_expr_type nft_fib_netdev_type __read_mostly = { + .family = NFPROTO_NETDEV, + .name = "fib", + .ops = &nft_fib_netdev_ops, + .policy = nft_fib_policy, + .maxattr = NFTA_FIB_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_fib_netdev_module_init(void) +{ + return nft_register_expr(&nft_fib_netdev_type); +} + +static void __exit nft_fib_netdev_module_exit(void) +{ + nft_unregister_expr(&nft_fib_netdev_type); +} + +module_init(nft_fib_netdev_module_init); +module_exit(nft_fib_netdev_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo M. Bermudo Garay "); +MODULE_ALIAS_NFT_AF_EXPR(5, "fib"); +MODULE_DESCRIPTION("nftables netdev fib lookups support"); diff --git a/net/netfilter/nft_flow_offload.c b/net/netfilter/nft_flow_offload.c new file mode 100644 index 000000000..3d9f6dda5 --- /dev/null +++ b/net/netfilter/nft_flow_offload.c @@ -0,0 +1,527 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for ipv4 options. 
*/ +#include +#include +#include +#include +#include + +struct nft_flow_offload { + struct nft_flowtable *flowtable; +}; + +static enum flow_offload_xmit_type nft_xmit_type(struct dst_entry *dst) +{ + if (dst_xfrm(dst)) + return FLOW_OFFLOAD_XMIT_XFRM; + + return FLOW_OFFLOAD_XMIT_NEIGH; +} + +static void nft_default_forward_path(struct nf_flow_route *route, + struct dst_entry *dst_cache, + enum ip_conntrack_dir dir) +{ + route->tuple[!dir].in.ifindex = dst_cache->dev->ifindex; + route->tuple[dir].dst = dst_cache; + route->tuple[dir].xmit_type = nft_xmit_type(dst_cache); +} + +static bool nft_is_valid_ether_device(const struct net_device *dev) +{ + if (!dev || (dev->flags & IFF_LOOPBACK) || dev->type != ARPHRD_ETHER || + dev->addr_len != ETH_ALEN || !is_valid_ether_addr(dev->dev_addr)) + return false; + + return true; +} + +static int nft_dev_fill_forward_path(const struct nf_flow_route *route, + const struct dst_entry *dst_cache, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, u8 *ha, + struct net_device_path_stack *stack) +{ + const void *daddr = &ct->tuplehash[!dir].tuple.src.u3; + struct net_device *dev = dst_cache->dev; + struct neighbour *n; + u8 nud_state; + + if (!nft_is_valid_ether_device(dev)) + goto out; + + n = dst_neigh_lookup(dst_cache, daddr); + if (!n) + return -1; + + read_lock_bh(&n->lock); + nud_state = n->nud_state; + ether_addr_copy(ha, n->ha); + read_unlock_bh(&n->lock); + neigh_release(n); + + if (!(nud_state & NUD_VALID)) + return -1; + +out: + return dev_fill_forward_path(dev, ha, stack); +} + +struct nft_forward_info { + const struct net_device *indev; + const struct net_device *outdev; + const struct net_device *hw_outdev; + struct id { + __u16 id; + __be16 proto; + } encap[NF_FLOW_TABLE_ENCAP_MAX]; + u8 num_encaps; + u8 ingress_vlans; + u8 h_source[ETH_ALEN]; + u8 h_dest[ETH_ALEN]; + enum flow_offload_xmit_type xmit_type; +}; + +static void nft_dev_path_info(const struct net_device_path_stack *stack, + struct nft_forward_info *info, + unsigned char *ha, struct nf_flowtable *flowtable) +{ + const struct net_device_path *path; + int i; + + memcpy(info->h_dest, ha, ETH_ALEN); + + for (i = 0; i < stack->num_paths; i++) { + path = &stack->path[i]; + switch (path->type) { + case DEV_PATH_ETHERNET: + case DEV_PATH_DSA: + case DEV_PATH_VLAN: + case DEV_PATH_PPPOE: + info->indev = path->dev; + if (is_zero_ether_addr(info->h_source)) + memcpy(info->h_source, path->dev->dev_addr, ETH_ALEN); + + if (path->type == DEV_PATH_ETHERNET) + break; + if (path->type == DEV_PATH_DSA) { + i = stack->num_paths; + break; + } + + /* DEV_PATH_VLAN and DEV_PATH_PPPOE */ + if (info->num_encaps >= NF_FLOW_TABLE_ENCAP_MAX) { + info->indev = NULL; + break; + } + if (!info->outdev) + info->outdev = path->dev; + info->encap[info->num_encaps].id = path->encap.id; + info->encap[info->num_encaps].proto = path->encap.proto; + info->num_encaps++; + if (path->type == DEV_PATH_PPPOE) + memcpy(info->h_dest, path->encap.h_dest, ETH_ALEN); + break; + case DEV_PATH_BRIDGE: + if (is_zero_ether_addr(info->h_source)) + memcpy(info->h_source, path->dev->dev_addr, ETH_ALEN); + + switch (path->bridge.vlan_mode) { + case DEV_PATH_BR_VLAN_UNTAG_HW: + info->ingress_vlans |= BIT(info->num_encaps - 1); + break; + case DEV_PATH_BR_VLAN_TAG: + info->encap[info->num_encaps].id = path->bridge.vlan_id; + info->encap[info->num_encaps].proto = path->bridge.vlan_proto; + info->num_encaps++; + break; + case DEV_PATH_BR_VLAN_UNTAG: + info->num_encaps--; + break; + case DEV_PATH_BR_VLAN_KEEP: + break; + } + 
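+			/* The bridge topology has been resolved down to a
+			 * concrete lower device, so the offloaded flow can use
+			 * the cached L2 header and transmit directly instead of
+			 * doing a per-packet neighbour lookup.
+			 */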
info->xmit_type = FLOW_OFFLOAD_XMIT_DIRECT; + break; + default: + info->indev = NULL; + break; + } + } + if (!info->outdev) + info->outdev = info->indev; + + info->hw_outdev = info->indev; + + if (nf_flowtable_hw_offload(flowtable) && + nft_is_valid_ether_device(info->indev)) + info->xmit_type = FLOW_OFFLOAD_XMIT_DIRECT; +} + +static bool nft_flowtable_find_dev(const struct net_device *dev, + struct nft_flowtable *ft) +{ + struct nft_hook *hook; + bool found = false; + + list_for_each_entry_rcu(hook, &ft->hook_list, list) { + if (hook->ops.dev != dev) + continue; + + found = true; + break; + } + + return found; +} + +static void nft_dev_forward_path(struct nf_flow_route *route, + const struct nf_conn *ct, + enum ip_conntrack_dir dir, + struct nft_flowtable *ft) +{ + const struct dst_entry *dst = route->tuple[dir].dst; + struct net_device_path_stack stack; + struct nft_forward_info info = {}; + unsigned char ha[ETH_ALEN]; + int i; + + if (nft_dev_fill_forward_path(route, dst, ct, dir, ha, &stack) >= 0) + nft_dev_path_info(&stack, &info, ha, &ft->data); + + if (!info.indev || !nft_flowtable_find_dev(info.indev, ft)) + return; + + route->tuple[!dir].in.ifindex = info.indev->ifindex; + for (i = 0; i < info.num_encaps; i++) { + route->tuple[!dir].in.encap[i].id = info.encap[i].id; + route->tuple[!dir].in.encap[i].proto = info.encap[i].proto; + } + route->tuple[!dir].in.num_encaps = info.num_encaps; + route->tuple[!dir].in.ingress_vlans = info.ingress_vlans; + + if (info.xmit_type == FLOW_OFFLOAD_XMIT_DIRECT) { + memcpy(route->tuple[dir].out.h_source, info.h_source, ETH_ALEN); + memcpy(route->tuple[dir].out.h_dest, info.h_dest, ETH_ALEN); + route->tuple[dir].out.ifindex = info.outdev->ifindex; + route->tuple[dir].out.hw_ifindex = info.hw_outdev->ifindex; + route->tuple[dir].xmit_type = info.xmit_type; + } +} + +static int nft_flow_route(const struct nft_pktinfo *pkt, + const struct nf_conn *ct, + struct nf_flow_route *route, + enum ip_conntrack_dir dir, + struct nft_flowtable *ft) +{ + struct dst_entry *this_dst = skb_dst(pkt->skb); + struct dst_entry *other_dst = NULL; + struct flowi fl; + + memset(&fl, 0, sizeof(fl)); + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + fl.u.ip4.daddr = ct->tuplehash[dir].tuple.src.u3.ip; + fl.u.ip4.saddr = ct->tuplehash[!dir].tuple.src.u3.ip; + fl.u.ip4.flowi4_oif = nft_in(pkt)->ifindex; + fl.u.ip4.flowi4_iif = this_dst->dev->ifindex; + fl.u.ip4.flowi4_tos = RT_TOS(ip_hdr(pkt->skb)->tos); + fl.u.ip4.flowi4_mark = pkt->skb->mark; + fl.u.ip4.flowi4_flags = FLOWI_FLAG_ANYSRC; + break; + case NFPROTO_IPV6: + fl.u.ip6.daddr = ct->tuplehash[dir].tuple.src.u3.in6; + fl.u.ip6.saddr = ct->tuplehash[!dir].tuple.src.u3.in6; + fl.u.ip6.flowi6_oif = nft_in(pkt)->ifindex; + fl.u.ip6.flowi6_iif = this_dst->dev->ifindex; + fl.u.ip6.flowlabel = ip6_flowinfo(ipv6_hdr(pkt->skb)); + fl.u.ip6.flowi6_mark = pkt->skb->mark; + fl.u.ip6.flowi6_flags = FLOWI_FLAG_ANYSRC; + break; + } + + nf_route(nft_net(pkt), &other_dst, &fl, false, nft_pf(pkt)); + if (!other_dst) + return -ENOENT; + + nft_default_forward_path(route, this_dst, dir); + nft_default_forward_path(route, other_dst, !dir); + + if (route->tuple[dir].xmit_type == FLOW_OFFLOAD_XMIT_NEIGH && + route->tuple[!dir].xmit_type == FLOW_OFFLOAD_XMIT_NEIGH) { + nft_dev_forward_path(route, ct, dir, ft); + nft_dev_forward_path(route, ct, !dir, ft); + } + + return 0; +} + +static bool nft_flow_offload_skip(struct sk_buff *skb, int family) +{ + if (skb_sec_path(skb)) + return true; + + if (family == NFPROTO_IPV4) { + const struct ip_options 
*opt; + + opt = &(IPCB(skb)->opt); + + if (unlikely(opt->optlen)) + return true; + } + + return false; +} + +static void nft_flow_offload_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_flow_offload *priv = nft_expr_priv(expr); + struct nf_flowtable *flowtable = &priv->flowtable->data; + struct tcphdr _tcph, *tcph = NULL; + struct nf_flow_route route = {}; + enum ip_conntrack_info ctinfo; + struct flow_offload *flow; + enum ip_conntrack_dir dir; + struct nf_conn *ct; + int ret; + + if (nft_flow_offload_skip(pkt->skb, nft_pf(pkt))) + goto out; + + ct = nf_ct_get(pkt->skb, &ctinfo); + if (!ct) + goto out; + + switch (ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.protonum) { + case IPPROTO_TCP: + tcph = skb_header_pointer(pkt->skb, nft_thoff(pkt), + sizeof(_tcph), &_tcph); + if (unlikely(!tcph || tcph->fin || tcph->rst || + !nf_conntrack_tcp_established(ct))) + goto out; + break; + case IPPROTO_UDP: + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: { + struct nf_conntrack_tuple *tuple; + + if (ct->status & IPS_NAT_MASK) + goto out; + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + /* No support for GRE v1 */ + if (tuple->src.u.gre.key || tuple->dst.u.gre.key) + goto out; + break; + } +#endif + default: + goto out; + } + + if (nf_ct_ext_exist(ct, NF_CT_EXT_HELPER) || + ct->status & (IPS_SEQ_ADJUST | IPS_NAT_CLASH)) + goto out; + + if (!nf_ct_is_confirmed(ct)) + goto out; + + if (test_and_set_bit(IPS_OFFLOAD_BIT, &ct->status)) + goto out; + + dir = CTINFO2DIR(ctinfo); + if (nft_flow_route(pkt, ct, &route, dir, priv->flowtable) < 0) + goto err_flow_route; + + flow = flow_offload_alloc(ct); + if (!flow) + goto err_flow_alloc; + + if (flow_offload_route_init(flow, &route) < 0) + goto err_flow_add; + + if (tcph) { + ct->proto.tcp.seen[0].flags |= IP_CT_TCP_FLAG_BE_LIBERAL; + ct->proto.tcp.seen[1].flags |= IP_CT_TCP_FLAG_BE_LIBERAL; + } + + ret = flow_offload_add(flowtable, flow); + if (ret < 0) + goto err_flow_add; + + dst_release(route.tuple[!dir].dst); + return; + +err_flow_add: + flow_offload_free(flow); +err_flow_alloc: + dst_release(route.tuple[!dir].dst); +err_flow_route: + clear_bit(IPS_OFFLOAD_BIT, &ct->status); +out: + regs->verdict.code = NFT_BREAK; +} + +static int nft_flow_offload_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + unsigned int hook_mask = (1 << NF_INET_FORWARD); + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + return nft_chain_validate_hooks(ctx->chain, hook_mask); +} + +static const struct nla_policy nft_flow_offload_policy[NFTA_FLOW_MAX + 1] = { + [NFTA_FLOW_TABLE_NAME] = { .type = NLA_STRING, + .len = NFT_NAME_MAXLEN - 1 }, +}; + +static int nft_flow_offload_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_flow_offload *priv = nft_expr_priv(expr); + u8 genmask = nft_genmask_next(ctx->net); + struct nft_flowtable *flowtable; + + if (!tb[NFTA_FLOW_TABLE_NAME]) + return -EINVAL; + + flowtable = nft_flowtable_lookup(ctx->table, tb[NFTA_FLOW_TABLE_NAME], + genmask); + if (IS_ERR(flowtable)) + return PTR_ERR(flowtable); + + if (!nft_use_inc(&flowtable->use)) + return -EMFILE; + + priv->flowtable = flowtable; + + return nf_ct_netns_get(ctx->net, ctx->family); +} + +static void nft_flow_offload_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) +{ + struct 
nft_flow_offload *priv = nft_expr_priv(expr); + + nf_tables_deactivate_flowtable(ctx, priv->flowtable, phase); +} + +static void nft_flow_offload_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_flow_offload *priv = nft_expr_priv(expr); + + nft_use_inc_restore(&priv->flowtable->use); +} + +static void nft_flow_offload_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, ctx->family); +} + +static int nft_flow_offload_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_flow_offload *priv = nft_expr_priv(expr); + + if (nla_put_string(skb, NFTA_FLOW_TABLE_NAME, priv->flowtable->name)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_flow_offload_type; +static const struct nft_expr_ops nft_flow_offload_ops = { + .type = &nft_flow_offload_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_flow_offload)), + .eval = nft_flow_offload_eval, + .init = nft_flow_offload_init, + .activate = nft_flow_offload_activate, + .deactivate = nft_flow_offload_deactivate, + .destroy = nft_flow_offload_destroy, + .validate = nft_flow_offload_validate, + .dump = nft_flow_offload_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_flow_offload_type __read_mostly = { + .name = "flow_offload", + .ops = &nft_flow_offload_ops, + .policy = nft_flow_offload_policy, + .maxattr = NFTA_FLOW_MAX, + .owner = THIS_MODULE, +}; + +static int flow_offload_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (event != NETDEV_DOWN) + return NOTIFY_DONE; + + nf_flow_table_cleanup(dev); + + return NOTIFY_DONE; +} + +static struct notifier_block flow_offload_netdev_notifier = { + .notifier_call = flow_offload_netdev_event, +}; + +static int __init nft_flow_offload_module_init(void) +{ + int err; + + err = register_netdevice_notifier(&flow_offload_netdev_notifier); + if (err) + goto err; + + err = nft_register_expr(&nft_flow_offload_type); + if (err < 0) + goto register_expr; + + return 0; + +register_expr: + unregister_netdevice_notifier(&flow_offload_netdev_notifier); +err: + return err; +} + +static void __exit nft_flow_offload_module_exit(void) +{ + nft_unregister_expr(&nft_flow_offload_type); + unregister_netdevice_notifier(&flow_offload_netdev_notifier); +} + +module_init(nft_flow_offload_module_init); +module_exit(nft_flow_offload_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_EXPR("flow_offload"); +MODULE_DESCRIPTION("nftables hardware flow offload module"); diff --git a/net/netfilter/nft_fwd_netdev.c b/net/netfilter/nft_fwd_netdev.c new file mode 100644 index 000000000..7c5876dc9 --- /dev/null +++ b/net/netfilter/nft_fwd_netdev.c @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_fwd_netdev { + u8 sreg_dev; +}; + +static void nft_fwd_netdev_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_fwd_netdev *priv = nft_expr_priv(expr); + int oif = regs->data[priv->sreg_dev]; + struct sk_buff *skb = pkt->skb; + + /* This is used by ifb only. 
*/ + skb->skb_iif = skb->dev->ifindex; + skb_set_redirected(skb, nft_hook(pkt) == NF_NETDEV_INGRESS); + + nf_fwd_netdev_egress(pkt, oif); + regs->verdict.code = NF_STOLEN; +} + +static const struct nla_policy nft_fwd_netdev_policy[NFTA_FWD_MAX + 1] = { + [NFTA_FWD_SREG_DEV] = { .type = NLA_U32 }, + [NFTA_FWD_SREG_ADDR] = { .type = NLA_U32 }, + [NFTA_FWD_NFPROTO] = { .type = NLA_U32 }, +}; + +static int nft_fwd_netdev_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_fwd_netdev *priv = nft_expr_priv(expr); + + if (tb[NFTA_FWD_SREG_DEV] == NULL) + return -EINVAL; + + return nft_parse_register_load(tb[NFTA_FWD_SREG_DEV], &priv->sreg_dev, + sizeof(int)); +} + +static int nft_fwd_netdev_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_fwd_netdev *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_FWD_SREG_DEV, priv->sreg_dev)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_fwd_netdev_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_fwd_netdev *priv = nft_expr_priv(expr); + int oif = ctx->regs[priv->sreg_dev].data.data[0]; + + return nft_fwd_dup_netdev_offload(ctx, flow, FLOW_ACTION_REDIRECT, oif); +} + +static bool nft_fwd_netdev_offload_action(const struct nft_expr *expr) +{ + return true; +} + +struct nft_fwd_neigh { + u8 sreg_dev; + u8 sreg_addr; + u8 nfproto; +}; + +static void nft_fwd_neigh_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_fwd_neigh *priv = nft_expr_priv(expr); + void *addr = ®s->data[priv->sreg_addr]; + int oif = regs->data[priv->sreg_dev]; + unsigned int verdict = NF_STOLEN; + struct sk_buff *skb = pkt->skb; + struct net_device *dev; + int neigh_table; + + switch (priv->nfproto) { + case NFPROTO_IPV4: { + struct iphdr *iph; + + if (skb->protocol != htons(ETH_P_IP)) { + verdict = NFT_BREAK; + goto out; + } + if (skb_try_make_writable(skb, sizeof(*iph))) { + verdict = NF_DROP; + goto out; + } + iph = ip_hdr(skb); + ip_decrease_ttl(iph); + neigh_table = NEIGH_ARP_TABLE; + break; + } + case NFPROTO_IPV6: { + struct ipv6hdr *ip6h; + + if (skb->protocol != htons(ETH_P_IPV6)) { + verdict = NFT_BREAK; + goto out; + } + if (skb_try_make_writable(skb, sizeof(*ip6h))) { + verdict = NF_DROP; + goto out; + } + ip6h = ipv6_hdr(skb); + ip6h->hop_limit--; + neigh_table = NEIGH_ND_TABLE; + break; + } + default: + verdict = NFT_BREAK; + goto out; + } + + dev = dev_get_by_index_rcu(nft_net(pkt), oif); + if (dev == NULL) + return; + + skb->dev = dev; + skb_clear_tstamp(skb); + neigh_xmit(neigh_table, dev, addr, skb); +out: + regs->verdict.code = verdict; +} + +static int nft_fwd_neigh_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_fwd_neigh *priv = nft_expr_priv(expr); + unsigned int addr_len; + int err; + + if (!tb[NFTA_FWD_SREG_DEV] || + !tb[NFTA_FWD_SREG_ADDR] || + !tb[NFTA_FWD_NFPROTO]) + return -EINVAL; + + priv->nfproto = ntohl(nla_get_be32(tb[NFTA_FWD_NFPROTO])); + + switch (priv->nfproto) { + case NFPROTO_IPV4: + addr_len = sizeof(struct in_addr); + break; + case NFPROTO_IPV6: + addr_len = sizeof(struct in6_addr); + break; + default: + return -EOPNOTSUPP; + } + + err = nft_parse_register_load(tb[NFTA_FWD_SREG_DEV], &priv->sreg_dev, + sizeof(int)); + if (err < 0) + return err; + + return nft_parse_register_load(tb[NFTA_FWD_SREG_ADDR], 
&priv->sreg_addr, + addr_len); +} + +static int nft_fwd_neigh_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_fwd_neigh *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_FWD_SREG_DEV, priv->sreg_dev) || + nft_dump_register(skb, NFTA_FWD_SREG_ADDR, priv->sreg_addr) || + nla_put_be32(skb, NFTA_FWD_NFPROTO, htonl(priv->nfproto))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_fwd_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, (1 << NF_NETDEV_INGRESS) | + (1 << NF_NETDEV_EGRESS)); +} + +static struct nft_expr_type nft_fwd_netdev_type; +static const struct nft_expr_ops nft_fwd_neigh_netdev_ops = { + .type = &nft_fwd_netdev_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fwd_neigh)), + .eval = nft_fwd_neigh_eval, + .init = nft_fwd_neigh_init, + .dump = nft_fwd_neigh_dump, + .validate = nft_fwd_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops nft_fwd_netdev_ops = { + .type = &nft_fwd_netdev_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_fwd_netdev)), + .eval = nft_fwd_netdev_eval, + .init = nft_fwd_netdev_init, + .dump = nft_fwd_netdev_dump, + .validate = nft_fwd_validate, + .reduce = NFT_REDUCE_READONLY, + .offload = nft_fwd_netdev_offload, + .offload_action = nft_fwd_netdev_offload_action, +}; + +static const struct nft_expr_ops * +nft_fwd_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_FWD_SREG_ADDR]) + return &nft_fwd_neigh_netdev_ops; + if (tb[NFTA_FWD_SREG_DEV]) + return &nft_fwd_netdev_ops; + + return ERR_PTR(-EOPNOTSUPP); +} + +static struct nft_expr_type nft_fwd_netdev_type __read_mostly = { + .family = NFPROTO_NETDEV, + .name = "fwd", + .select_ops = nft_fwd_select_ops, + .policy = nft_fwd_netdev_policy, + .maxattr = NFTA_FWD_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_fwd_netdev_module_init(void) +{ + return nft_register_expr(&nft_fwd_netdev_type); +} + +static void __exit nft_fwd_netdev_module_exit(void) +{ + nft_unregister_expr(&nft_fwd_netdev_type); +} + +module_init(nft_fwd_netdev_module_init); +module_exit(nft_fwd_netdev_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_AF_EXPR(5, "fwd"); diff --git a/net/netfilter/nft_hash.c b/net/netfilter/nft_hash.c new file mode 100644 index 000000000..e5631e88b --- /dev/null +++ b/net/netfilter/nft_hash.c @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2016 Laura Garcia + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_jhash { + u8 sreg; + u8 dreg; + u8 len; + bool autogen_seed:1; + u32 modulus; + u32 seed; + u32 offset; +}; + +static void nft_jhash_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_jhash *priv = nft_expr_priv(expr); + const void *data = ®s->data[priv->sreg]; + u32 h; + + h = reciprocal_scale(jhash(data, priv->len, priv->seed), + priv->modulus); + + regs->data[priv->dreg] = h + priv->offset; +} + +struct nft_symhash { + u8 dreg; + u32 modulus; + u32 offset; +}; + +static void nft_symhash_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_symhash *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + u32 h; + + h = reciprocal_scale(__skb_get_hash_symmetric(skb), priv->modulus); + + 
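+	/* reciprocal_scale() maps the symmetric flow hash into [0, modulus);
+	 * the configured offset then shifts the result into
+	 * [offset, offset + modulus), mirroring nft_jhash_eval() above.
+	 */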
regs->data[priv->dreg] = h + priv->offset; +} + +static const struct nla_policy nft_hash_policy[NFTA_HASH_MAX + 1] = { + [NFTA_HASH_SREG] = { .type = NLA_U32 }, + [NFTA_HASH_DREG] = { .type = NLA_U32 }, + [NFTA_HASH_LEN] = { .type = NLA_U32 }, + [NFTA_HASH_MODULUS] = { .type = NLA_U32 }, + [NFTA_HASH_SEED] = { .type = NLA_U32 }, + [NFTA_HASH_OFFSET] = { .type = NLA_U32 }, + [NFTA_HASH_TYPE] = { .type = NLA_U32 }, +}; + +static int nft_jhash_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_jhash *priv = nft_expr_priv(expr); + u32 len; + int err; + + if (!tb[NFTA_HASH_SREG] || + !tb[NFTA_HASH_DREG] || + !tb[NFTA_HASH_LEN] || + !tb[NFTA_HASH_MODULUS]) + return -EINVAL; + + if (tb[NFTA_HASH_OFFSET]) + priv->offset = ntohl(nla_get_be32(tb[NFTA_HASH_OFFSET])); + + err = nft_parse_u32_check(tb[NFTA_HASH_LEN], U8_MAX, &len); + if (err < 0) + return err; + if (len == 0) + return -ERANGE; + + priv->len = len; + + err = nft_parse_register_load(tb[NFTA_HASH_SREG], &priv->sreg, len); + if (err < 0) + return err; + + priv->modulus = ntohl(nla_get_be32(tb[NFTA_HASH_MODULUS])); + if (priv->modulus < 1) + return -ERANGE; + + if (priv->offset + priv->modulus - 1 < priv->offset) + return -EOVERFLOW; + + if (tb[NFTA_HASH_SEED]) { + priv->seed = ntohl(nla_get_be32(tb[NFTA_HASH_SEED])); + } else { + priv->autogen_seed = true; + get_random_bytes(&priv->seed, sizeof(priv->seed)); + } + + return nft_parse_register_store(ctx, tb[NFTA_HASH_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, sizeof(u32)); +} + +static int nft_symhash_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_symhash *priv = nft_expr_priv(expr); + + if (!tb[NFTA_HASH_DREG] || + !tb[NFTA_HASH_MODULUS]) + return -EINVAL; + + if (tb[NFTA_HASH_OFFSET]) + priv->offset = ntohl(nla_get_be32(tb[NFTA_HASH_OFFSET])); + + priv->modulus = ntohl(nla_get_be32(tb[NFTA_HASH_MODULUS])); + if (priv->modulus < 1) + return -ERANGE; + + if (priv->offset + priv->modulus - 1 < priv->offset) + return -EOVERFLOW; + + return nft_parse_register_store(ctx, tb[NFTA_HASH_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + sizeof(u32)); +} + +static int nft_jhash_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_jhash *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_HASH_SREG, priv->sreg)) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_HASH_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HASH_LEN, htonl(priv->len))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HASH_MODULUS, htonl(priv->modulus))) + goto nla_put_failure; + if (!priv->autogen_seed && + nla_put_be32(skb, NFTA_HASH_SEED, htonl(priv->seed))) + goto nla_put_failure; + if (priv->offset != 0) + if (nla_put_be32(skb, NFTA_HASH_OFFSET, htonl(priv->offset))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HASH_TYPE, htonl(NFT_HASH_JENKINS))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_jhash_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_jhash *priv = nft_expr_priv(expr); + + nft_reg_track_cancel(track, priv->dreg, sizeof(u32)); + + return false; +} + +static int nft_symhash_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_symhash *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_HASH_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HASH_MODULUS, 
htonl(priv->modulus))) + goto nla_put_failure; + if (priv->offset != 0) + if (nla_put_be32(skb, NFTA_HASH_OFFSET, htonl(priv->offset))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_HASH_TYPE, htonl(NFT_HASH_SYM))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_symhash_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + struct nft_symhash *priv = nft_expr_priv(expr); + struct nft_symhash *symhash; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, sizeof(u32)); + return false; + } + + symhash = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->offset != symhash->offset || + priv->modulus != symhash->modulus) { + nft_reg_track_update(track, expr, priv->dreg, sizeof(u32)); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return false; +} + +static struct nft_expr_type nft_hash_type; +static const struct nft_expr_ops nft_jhash_ops = { + .type = &nft_hash_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_jhash)), + .eval = nft_jhash_eval, + .init = nft_jhash_init, + .dump = nft_jhash_dump, + .reduce = nft_jhash_reduce, +}; + +static const struct nft_expr_ops nft_symhash_ops = { + .type = &nft_hash_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_symhash)), + .eval = nft_symhash_eval, + .init = nft_symhash_init, + .dump = nft_symhash_dump, + .reduce = nft_symhash_reduce, +}; + +static const struct nft_expr_ops * +nft_hash_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + u32 type; + + if (!tb[NFTA_HASH_TYPE]) + return &nft_jhash_ops; + + type = ntohl(nla_get_be32(tb[NFTA_HASH_TYPE])); + switch (type) { + case NFT_HASH_SYM: + return &nft_symhash_ops; + case NFT_HASH_JENKINS: + return &nft_jhash_ops; + default: + break; + } + return ERR_PTR(-EOPNOTSUPP); +} + +static struct nft_expr_type nft_hash_type __read_mostly = { + .name = "hash", + .select_ops = nft_hash_select_ops, + .policy = nft_hash_policy, + .maxattr = NFTA_HASH_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_hash_module_init(void) +{ + return nft_register_expr(&nft_hash_type); +} + +static void __exit nft_hash_module_exit(void) +{ + nft_unregister_expr(&nft_hash_type); +} + +module_init(nft_hash_module_init); +module_exit(nft_hash_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Laura Garcia "); +MODULE_ALIAS_NFT_EXPR("hash"); +MODULE_DESCRIPTION("Netfilter nftables hash module"); diff --git a/net/netfilter/nft_immediate.c b/net/netfilter/nft_immediate.c new file mode 100644 index 000000000..55fcf0280 --- /dev/null +++ b/net/netfilter/nft_immediate.c @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void nft_immediate_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + + nft_data_copy(®s->data[priv->dreg], &priv->data, priv->dlen); +} + +static const struct nla_policy nft_immediate_policy[NFTA_IMMEDIATE_MAX + 1] = { + [NFTA_IMMEDIATE_DREG] = { .type = NLA_U32 }, + [NFTA_IMMEDIATE_DATA] = { .type = NLA_NESTED }, +}; + +static enum nft_data_types nft_reg_to_type(const struct nlattr *nla) +{ + enum nft_data_types type; + u8 reg; + + reg = ntohl(nla_get_be32(nla)); + if (reg == NFT_REG_VERDICT) 
+ type = NFT_DATA_VERDICT; + else + type = NFT_DATA_VALUE; + + return type; +} + +static int nft_immediate_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_immediate_expr *priv = nft_expr_priv(expr); + struct nft_data_desc desc = { + .size = sizeof(priv->data), + }; + int err; + + if (tb[NFTA_IMMEDIATE_DREG] == NULL || + tb[NFTA_IMMEDIATE_DATA] == NULL) + return -EINVAL; + + desc.type = nft_reg_to_type(tb[NFTA_IMMEDIATE_DREG]); + err = nft_data_init(ctx, &priv->data, &desc, tb[NFTA_IMMEDIATE_DATA]); + if (err < 0) + return err; + + priv->dlen = desc.len; + + err = nft_parse_register_store(ctx, tb[NFTA_IMMEDIATE_DREG], + &priv->dreg, &priv->data, desc.type, + desc.len); + if (err < 0) + goto err1; + + if (priv->dreg == NFT_REG_VERDICT) { + struct nft_chain *chain = priv->data.verdict.chain; + + switch (priv->data.verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + err = nf_tables_bind_chain(ctx, chain); + if (err < 0) + goto err1; + break; + default: + break; + } + } + + return 0; + +err1: + nft_data_release(&priv->data, desc.type); + return err; +} + +static void nft_immediate_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + const struct nft_data *data = &priv->data; + struct nft_ctx chain_ctx; + struct nft_chain *chain; + struct nft_rule *rule; + + if (priv->dreg == NFT_REG_VERDICT) { + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + chain = data->verdict.chain; + if (!nft_chain_binding(chain)) + break; + + chain_ctx = *ctx; + chain_ctx.chain = chain; + + list_for_each_entry(rule, &chain->rules, list) + nft_rule_expr_activate(&chain_ctx, rule); + + nft_clear(ctx->net, chain); + break; + default: + break; + } + } + + return nft_data_hold(&priv->data, nft_dreg_to_type(priv->dreg)); +} + +static void nft_immediate_chain_deactivate(const struct nft_ctx *ctx, + struct nft_chain *chain, + enum nft_trans_phase phase) +{ + struct nft_ctx chain_ctx; + struct nft_rule *rule; + + chain_ctx = *ctx; + chain_ctx.chain = chain; + + list_for_each_entry(rule, &chain->rules, list) + nft_rule_expr_deactivate(&chain_ctx, rule, phase); +} + +static void nft_immediate_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + const struct nft_data *data = &priv->data; + struct nft_chain *chain; + + if (priv->dreg == NFT_REG_VERDICT) { + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + chain = data->verdict.chain; + if (!nft_chain_binding(chain)) + break; + + switch (phase) { + case NFT_TRANS_PREPARE_ERROR: + nf_tables_unbind_chain(ctx, chain); + nft_deactivate_next(ctx->net, chain); + break; + case NFT_TRANS_PREPARE: + nft_immediate_chain_deactivate(ctx, chain, phase); + nft_deactivate_next(ctx->net, chain); + break; + default: + nft_immediate_chain_deactivate(ctx, chain, phase); + nft_chain_del(chain); + chain->bound = false; + nft_use_dec(&chain->table->use); + break; + } + break; + default: + break; + } + } + + if (phase == NFT_TRANS_COMMIT) + return; + + return nft_data_release(&priv->data, nft_dreg_to_type(priv->dreg)); +} + +static void nft_immediate_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + const struct nft_data *data = &priv->data; + struct nft_rule *rule, *n; + struct nft_ctx chain_ctx; + struct nft_chain *chain; + + if 
(priv->dreg != NFT_REG_VERDICT) + return; + + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + chain = data->verdict.chain; + + if (!nft_chain_binding(chain)) + break; + + /* Rule construction failed, but chain is already bound: + * let the transaction records release this chain and its rules. + */ + if (chain->bound) { + nft_use_dec(&chain->use); + break; + } + + /* Rule has been deleted, release chain and its rules. */ + chain_ctx = *ctx; + chain_ctx.chain = chain; + + nft_use_dec(&chain->use); + list_for_each_entry_safe(rule, n, &chain->rules, list) { + nft_use_dec(&chain->use); + list_del(&rule->list); + nf_tables_rule_destroy(&chain_ctx, rule); + } + nf_tables_chain_destroy(&chain_ctx); + break; + default: + break; + } +} + +static int nft_immediate_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_IMMEDIATE_DREG, priv->dreg)) + goto nla_put_failure; + + return nft_data_dump(skb, NFTA_IMMEDIATE_DATA, &priv->data, + nft_dreg_to_type(priv->dreg), priv->dlen); + +nla_put_failure: + return -1; +} + +static int nft_immediate_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **d) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + struct nft_ctx *pctx = (struct nft_ctx *)ctx; + const struct nft_data *data; + int err; + + if (priv->dreg != NFT_REG_VERDICT) + return 0; + + data = &priv->data; + + switch (data->verdict.code) { + case NFT_JUMP: + case NFT_GOTO: + pctx->level++; + err = nft_chain_validate(ctx, data->verdict.chain); + if (err < 0) + return err; + pctx->level--; + break; + default: + break; + } + + return 0; +} + +static int nft_immediate_offload_verdict(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_immediate_expr *priv) +{ + struct flow_action_entry *entry; + const struct nft_data *data; + + entry = &flow->rule->action.entries[ctx->num_actions++]; + + data = &priv->data; + switch (data->verdict.code) { + case NF_ACCEPT: + entry->id = FLOW_ACTION_ACCEPT; + break; + case NF_DROP: + entry->id = FLOW_ACTION_DROP; + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_immediate_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + + if (priv->dreg == NFT_REG_VERDICT) + return nft_immediate_offload_verdict(ctx, flow, priv); + + memcpy(&ctx->regs[priv->dreg].data, &priv->data, sizeof(priv->data)); + + return 0; +} + +static bool nft_immediate_offload_action(const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + + if (priv->dreg == NFT_REG_VERDICT) + return true; + + return false; +} + +static bool nft_immediate_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_immediate_expr *priv = nft_expr_priv(expr); + + if (priv->dreg != NFT_REG_VERDICT) + nft_reg_track_cancel(track, priv->dreg, priv->dlen); + + return false; +} + +static const struct nft_expr_ops nft_imm_ops = { + .type = &nft_imm_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_immediate_expr)), + .eval = nft_immediate_eval, + .init = nft_immediate_init, + .activate = nft_immediate_activate, + .deactivate = nft_immediate_deactivate, + .destroy = nft_immediate_destroy, + .dump = nft_immediate_dump, + .validate = nft_immediate_validate, + .reduce = nft_immediate_reduce, + .offload = nft_immediate_offload, + 
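/*
 * Illustrative userspace sketch (not kernel code): nft_immediate_validate()
 * above raises ctx->level before validating a jump/goto target so that
 * chains jumping into chains stay bounded (the kernel caps this with its
 * jump stack size; the limit below is only an assumption for the sketch).
 */
#include <stdio.h>

#define TOY_MAX_JUMP_DEPTH 16	/* assumed limit for the sketch */

struct toy_chain {
	const char *name;
	const struct toy_chain *jump_target;	/* NULL if the chain does not jump */
};

static int toy_chain_validate(const struct toy_chain *c, unsigned int level)
{
	if (level >= TOY_MAX_JUMP_DEPTH)
		return -1;			/* too deep, reject the ruleset */
	if (c->jump_target)
		return toy_chain_validate(c->jump_target, level + 1);
	return 0;
}

int main(void)
{
	struct toy_chain leaf = { "leaf", NULL };
	struct toy_chain base = { "base", &leaf };

	printf("validate: %d\n", toy_chain_validate(&base, 0));
	return 0;
}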
.offload_action = nft_immediate_offload_action, +}; + +struct nft_expr_type nft_imm_type __read_mostly = { + .name = "immediate", + .ops = &nft_imm_ops, + .policy = nft_immediate_policy, + .maxattr = NFTA_IMMEDIATE_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_last.c b/net/netfilter/nft_last.c new file mode 100644 index 000000000..eaa54964c --- /dev/null +++ b/net/netfilter/nft_last.c @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_last { + unsigned long jiffies; + unsigned int set; +}; + +struct nft_last_priv { + struct nft_last *last; +}; + +static const struct nla_policy nft_last_policy[NFTA_LAST_MAX + 1] = { + [NFTA_LAST_SET] = { .type = NLA_U32 }, + [NFTA_LAST_MSECS] = { .type = NLA_U64 }, +}; + +static int nft_last_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_last_priv *priv = nft_expr_priv(expr); + struct nft_last *last; + u64 last_jiffies; + int err; + + last = kzalloc(sizeof(*last), GFP_KERNEL_ACCOUNT); + if (!last) + return -ENOMEM; + + if (tb[NFTA_LAST_SET]) + last->set = ntohl(nla_get_be32(tb[NFTA_LAST_SET])); + + if (last->set && tb[NFTA_LAST_MSECS]) { + err = nf_msecs_to_jiffies64(tb[NFTA_LAST_MSECS], &last_jiffies); + if (err < 0) + goto err; + + last->jiffies = jiffies - (unsigned long)last_jiffies; + } + priv->last = last; + + return 0; +err: + kfree(last); + + return err; +} + +static void nft_last_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) +{ + struct nft_last_priv *priv = nft_expr_priv(expr); + struct nft_last *last = priv->last; + + if (READ_ONCE(last->jiffies) != jiffies) + WRITE_ONCE(last->jiffies, jiffies); + if (READ_ONCE(last->set) == 0) + WRITE_ONCE(last->set, 1); +} + +static int nft_last_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_last_priv *priv = nft_expr_priv(expr); + struct nft_last *last = priv->last; + unsigned long last_jiffies = READ_ONCE(last->jiffies); + u32 last_set = READ_ONCE(last->set); + __be64 msecs; + + if (time_before(jiffies, last_jiffies)) { + WRITE_ONCE(last->set, 0); + last_set = 0; + } + + if (last_set) + msecs = nf_jiffies64_to_msecs(jiffies - last_jiffies); + else + msecs = 0; + + if (nla_put_be32(skb, NFTA_LAST_SET, htonl(last_set)) || + nla_put_be64(skb, NFTA_LAST_MSECS, msecs, NFTA_LAST_PAD)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_last_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_last_priv *priv = nft_expr_priv(expr); + + kfree(priv->last); +} + +static int nft_last_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct nft_last_priv *priv_dst = nft_expr_priv(dst); + struct nft_last_priv *priv_src = nft_expr_priv(src); + + priv_dst->last = kzalloc(sizeof(*priv_dst->last), GFP_ATOMIC); + if (!priv_dst->last) + return -ENOMEM; + + priv_dst->last->set = priv_src->last->set; + priv_dst->last->jiffies = priv_src->last->jiffies; + + return 0; +} + +static const struct nft_expr_ops nft_last_ops = { + .type = &nft_last_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_last_priv)), + .eval = nft_last_eval, + .init = nft_last_init, + .destroy = nft_last_destroy, + .clone = nft_last_clone, + .dump = nft_last_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +struct nft_expr_type nft_last_type __read_mostly = { + .name = "last", + .ops = &nft_last_ops, + .policy = nft_last_policy, + 
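/*
 * Illustrative userspace sketch (not kernel code): the "last" expression
 * above simply records when it last matched and, on dump, reports how
 * long ago that was in milliseconds.  clock_gettime(CLOCK_MONOTONIC)
 * stands in for the kernel's jiffies counter.
 */
#include <stdint.h>
#include <stdio.h>
#include <time.h>

static struct timespec last_seen;
static int last_set;

static void last_eval(void)			/* run for every matching packet */
{
	clock_gettime(CLOCK_MONOTONIC, &last_seen);
	last_set = 1;
}

static int64_t last_age_ms(void)		/* what the dump path reports */
{
	struct timespec now;

	if (!last_set)
		return -1;
	clock_gettime(CLOCK_MONOTONIC, &now);
	return (now.tv_sec - last_seen.tv_sec) * 1000 +
	       (now.tv_nsec - last_seen.tv_nsec) / 1000000;
}

int main(void)
{
	last_eval();
	printf("age: %lld ms\n", (long long)last_age_ms());
	return 0;
}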
.maxattr = NFTA_LAST_MAX, + .flags = NFT_EXPR_STATEFUL, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_limit.c b/net/netfilter/nft_limit.c new file mode 100644 index 000000000..36ded7d43 --- /dev/null +++ b/net/netfilter/nft_limit.c @@ -0,0 +1,481 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_limit { + spinlock_t lock; + u64 last; + u64 tokens; +}; + +struct nft_limit_priv { + struct nft_limit *limit; + u64 tokens_max; + u64 rate; + u64 nsecs; + u32 burst; + bool invert; +}; + +static inline bool nft_limit_eval(struct nft_limit_priv *priv, u64 cost) +{ + u64 now, tokens; + s64 delta; + + spin_lock_bh(&priv->limit->lock); + now = ktime_get_ns(); + tokens = priv->limit->tokens + now - priv->limit->last; + if (tokens > priv->tokens_max) + tokens = priv->tokens_max; + + priv->limit->last = now; + delta = tokens - cost; + if (delta >= 0) { + priv->limit->tokens = delta; + spin_unlock_bh(&priv->limit->lock); + return priv->invert; + } + priv->limit->tokens = tokens; + spin_unlock_bh(&priv->limit->lock); + return !priv->invert; +} + +/* Use same default as in iptables. */ +#define NFT_LIMIT_PKT_BURST_DEFAULT 5 + +static int nft_limit_init(struct nft_limit_priv *priv, + const struct nlattr * const tb[], bool pkts) +{ + u64 unit, tokens, rate_with_burst; + bool invert = false; + + if (tb[NFTA_LIMIT_RATE] == NULL || + tb[NFTA_LIMIT_UNIT] == NULL) + return -EINVAL; + + priv->rate = be64_to_cpu(nla_get_be64(tb[NFTA_LIMIT_RATE])); + if (priv->rate == 0) + return -EINVAL; + + unit = be64_to_cpu(nla_get_be64(tb[NFTA_LIMIT_UNIT])); + if (check_mul_overflow(unit, NSEC_PER_SEC, &priv->nsecs)) + return -EOVERFLOW; + + if (tb[NFTA_LIMIT_BURST]) + priv->burst = ntohl(nla_get_be32(tb[NFTA_LIMIT_BURST])); + + if (pkts && priv->burst == 0) + priv->burst = NFT_LIMIT_PKT_BURST_DEFAULT; + + if (check_add_overflow(priv->rate, priv->burst, &rate_with_burst)) + return -EOVERFLOW; + + if (pkts) { + u64 tmp = div64_u64(priv->nsecs, priv->rate); + + if (check_mul_overflow(tmp, priv->burst, &tokens)) + return -EOVERFLOW; + } else { + u64 tmp; + + /* The token bucket size limits the number of tokens can be + * accumulated. tokens_max specifies the bucket size. + * tokens_max = unit * (rate + burst) / rate. + */ + if (check_mul_overflow(priv->nsecs, rate_with_burst, &tmp)) + return -EOVERFLOW; + + tokens = div64_u64(tmp, priv->rate); + } + + if (tb[NFTA_LIMIT_FLAGS]) { + u32 flags = ntohl(nla_get_be32(tb[NFTA_LIMIT_FLAGS])); + + if (flags & ~NFT_LIMIT_F_INV) + return -EOPNOTSUPP; + + if (flags & NFT_LIMIT_F_INV) + invert = true; + } + + priv->limit = kmalloc(sizeof(*priv->limit), GFP_KERNEL_ACCOUNT); + if (!priv->limit) + return -ENOMEM; + + priv->limit->tokens = tokens; + priv->tokens_max = priv->limit->tokens; + priv->invert = invert; + priv->limit->last = ktime_get_ns(); + spin_lock_init(&priv->limit->lock); + + return 0; +} + +static int nft_limit_dump(struct sk_buff *skb, const struct nft_limit_priv *priv, + enum nft_limit_type type) +{ + u32 flags = priv->invert ? 
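/*
 * Illustrative userspace sketch (not kernel code): the limit expression
 * above is a token bucket measured in nanoseconds.  In packet mode each
 * packet costs nsecs/rate tokens, elapsed time refills the bucket, and
 * the bucket holds burst packets worth of tokens.  With rate = 10/second
 * and burst = 5, a packet costs 100 ms worth of tokens and the bucket
 * holds 500 ms worth, so roughly five back-to-back packets pass before
 * the limit bites.
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <time.h>

#define TOY_NSEC_PER_SEC 1000000000ULL

struct toy_limit {
	uint64_t nsecs;		/* the unit, in ns (here: one second) */
	uint64_t rate;		/* packets per unit */
	uint64_t tokens;	/* current bucket fill, in ns */
	uint64_t tokens_max;	/* bucket size, in ns */
	uint64_t last;		/* last refill timestamp, in ns */
};

static uint64_t now_ns(void)
{
	struct timespec ts;

	clock_gettime(CLOCK_MONOTONIC, &ts);
	return (uint64_t)ts.tv_sec * TOY_NSEC_PER_SEC + ts.tv_nsec;
}

static void toy_limit_init(struct toy_limit *l, uint64_t rate, uint64_t burst)
{
	l->nsecs = TOY_NSEC_PER_SEC;
	l->rate = rate;
	l->tokens_max = (l->nsecs / rate) * burst;	/* packet-mode bucket size */
	l->tokens = l->tokens_max;
	l->last = now_ns();
}

static bool toy_limit_eval(struct toy_limit *l)	/* true: within the limit */
{
	uint64_t now = now_ns();
	uint64_t cost = l->nsecs / l->rate;		/* per-packet cost */
	uint64_t tokens = l->tokens + (now - l->last);

	if (tokens > l->tokens_max)
		tokens = l->tokens_max;
	l->last = now;

	if (tokens >= cost) {
		l->tokens = tokens - cost;
		return true;
	}
	l->tokens = tokens;
	return false;
}

int main(void)
{
	struct toy_limit l;
	int passed = 0;

	toy_limit_init(&l, 10, 5);
	for (int i = 0; i < 100; i++)
		passed += toy_limit_eval(&l);
	printf("%d of 100 back-to-back packets passed\n", passed);	/* typically 5 */
	return 0;
}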
NFT_LIMIT_F_INV : 0; + u64 secs = div_u64(priv->nsecs, NSEC_PER_SEC); + + if (nla_put_be64(skb, NFTA_LIMIT_RATE, cpu_to_be64(priv->rate), + NFTA_LIMIT_PAD) || + nla_put_be64(skb, NFTA_LIMIT_UNIT, cpu_to_be64(secs), + NFTA_LIMIT_PAD) || + nla_put_be32(skb, NFTA_LIMIT_BURST, htonl(priv->burst)) || + nla_put_be32(skb, NFTA_LIMIT_TYPE, htonl(type)) || + nla_put_be32(skb, NFTA_LIMIT_FLAGS, htonl(flags))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static void nft_limit_destroy(const struct nft_ctx *ctx, + const struct nft_limit_priv *priv) +{ + kfree(priv->limit); +} + +static int nft_limit_clone(struct nft_limit_priv *priv_dst, + const struct nft_limit_priv *priv_src) +{ + priv_dst->tokens_max = priv_src->tokens_max; + priv_dst->rate = priv_src->rate; + priv_dst->nsecs = priv_src->nsecs; + priv_dst->burst = priv_src->burst; + priv_dst->invert = priv_src->invert; + + priv_dst->limit = kmalloc(sizeof(*priv_dst->limit), GFP_ATOMIC); + if (!priv_dst->limit) + return -ENOMEM; + + spin_lock_init(&priv_dst->limit->lock); + priv_dst->limit->tokens = priv_src->tokens_max; + priv_dst->limit->last = ktime_get_ns(); + + return 0; +} + +struct nft_limit_priv_pkts { + struct nft_limit_priv limit; + u64 cost; +}; + +static void nft_limit_pkts_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_limit_priv_pkts *priv = nft_expr_priv(expr); + + if (nft_limit_eval(&priv->limit, priv->cost)) + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_limit_policy[NFTA_LIMIT_MAX + 1] = { + [NFTA_LIMIT_RATE] = { .type = NLA_U64 }, + [NFTA_LIMIT_UNIT] = { .type = NLA_U64 }, + [NFTA_LIMIT_BURST] = { .type = NLA_U32 }, + [NFTA_LIMIT_TYPE] = { .type = NLA_U32 }, + [NFTA_LIMIT_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_limit_pkts_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_limit_priv_pkts *priv = nft_expr_priv(expr); + int err; + + err = nft_limit_init(&priv->limit, tb, true); + if (err < 0) + return err; + + priv->cost = div64_u64(priv->limit.nsecs, priv->limit.rate); + return 0; +} + +static int nft_limit_pkts_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_limit_priv_pkts *priv = nft_expr_priv(expr); + + return nft_limit_dump(skb, &priv->limit, NFT_LIMIT_PKTS); +} + +static void nft_limit_pkts_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_limit_priv_pkts *priv = nft_expr_priv(expr); + + nft_limit_destroy(ctx, &priv->limit); +} + +static int nft_limit_pkts_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct nft_limit_priv_pkts *priv_dst = nft_expr_priv(dst); + struct nft_limit_priv_pkts *priv_src = nft_expr_priv(src); + + priv_dst->cost = priv_src->cost; + + return nft_limit_clone(&priv_dst->limit, &priv_src->limit); +} + +static struct nft_expr_type nft_limit_type; +static const struct nft_expr_ops nft_limit_pkts_ops = { + .type = &nft_limit_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_limit_priv_pkts)), + .eval = nft_limit_pkts_eval, + .init = nft_limit_pkts_init, + .destroy = nft_limit_pkts_destroy, + .clone = nft_limit_pkts_clone, + .dump = nft_limit_pkts_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static void nft_limit_bytes_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_limit_priv *priv = nft_expr_priv(expr); + u64 cost = div64_u64(priv->nsecs * pkt->skb->len, priv->rate); + + if 
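/*
 * Illustrative userspace sketch (not kernel code): in byte mode the cost
 * computed just above scales with packet length -- a packet of len bytes
 * consumes nsecs * len / rate tokens.  Example: at 1 MiB per second, a
 * 1500-byte packet costs roughly 1.43 ms worth of tokens.
 */
#include <stdint.h>
#include <stdio.h>

#define TOY_NSEC_PER_SEC 1000000000ULL

static uint64_t byte_cost_ns(uint64_t nsecs, uint64_t rate_bytes, uint64_t len)
{
	return nsecs * len / rate_bytes;
}

int main(void)
{
	/* 1 MiB per second, one full-size Ethernet payload */
	printf("%llu ns\n",
	       (unsigned long long)byte_cost_ns(TOY_NSEC_PER_SEC, 1048576, 1500));
	return 0;
}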
(nft_limit_eval(priv, cost)) + regs->verdict.code = NFT_BREAK; +} + +static int nft_limit_bytes_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_limit_priv *priv = nft_expr_priv(expr); + + return nft_limit_init(priv, tb, false); +} + +static int nft_limit_bytes_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_limit_priv *priv = nft_expr_priv(expr); + + return nft_limit_dump(skb, priv, NFT_LIMIT_PKT_BYTES); +} + +static void nft_limit_bytes_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_limit_priv *priv = nft_expr_priv(expr); + + nft_limit_destroy(ctx, priv); +} + +static int nft_limit_bytes_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct nft_limit_priv *priv_dst = nft_expr_priv(dst); + struct nft_limit_priv *priv_src = nft_expr_priv(src); + + return nft_limit_clone(priv_dst, priv_src); +} + +static const struct nft_expr_ops nft_limit_bytes_ops = { + .type = &nft_limit_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_limit_priv)), + .eval = nft_limit_bytes_eval, + .init = nft_limit_bytes_init, + .dump = nft_limit_bytes_dump, + .clone = nft_limit_bytes_clone, + .destroy = nft_limit_bytes_destroy, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops * +nft_limit_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_LIMIT_TYPE] == NULL) + return &nft_limit_pkts_ops; + + switch (ntohl(nla_get_be32(tb[NFTA_LIMIT_TYPE]))) { + case NFT_LIMIT_PKTS: + return &nft_limit_pkts_ops; + case NFT_LIMIT_PKT_BYTES: + return &nft_limit_bytes_ops; + } + return ERR_PTR(-EOPNOTSUPP); +} + +static struct nft_expr_type nft_limit_type __read_mostly = { + .name = "limit", + .select_ops = nft_limit_select_ops, + .policy = nft_limit_policy, + .maxattr = NFTA_LIMIT_MAX, + .flags = NFT_EXPR_STATEFUL, + .owner = THIS_MODULE, +}; + +static void nft_limit_obj_pkts_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_limit_priv_pkts *priv = nft_obj_data(obj); + + if (nft_limit_eval(&priv->limit, priv->cost)) + regs->verdict.code = NFT_BREAK; +} + +static int nft_limit_obj_pkts_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_limit_priv_pkts *priv = nft_obj_data(obj); + int err; + + err = nft_limit_init(&priv->limit, tb, true); + if (err < 0) + return err; + + priv->cost = div64_u64(priv->limit.nsecs, priv->limit.rate); + return 0; +} + +static int nft_limit_obj_pkts_dump(struct sk_buff *skb, + struct nft_object *obj, + bool reset) +{ + const struct nft_limit_priv_pkts *priv = nft_obj_data(obj); + + return nft_limit_dump(skb, &priv->limit, NFT_LIMIT_PKTS); +} + +static void nft_limit_obj_pkts_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_limit_priv_pkts *priv = nft_obj_data(obj); + + nft_limit_destroy(ctx, &priv->limit); +} + +static struct nft_object_type nft_limit_obj_type; +static const struct nft_object_ops nft_limit_obj_pkts_ops = { + .type = &nft_limit_obj_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_limit_priv_pkts)), + .init = nft_limit_obj_pkts_init, + .destroy = nft_limit_obj_pkts_destroy, + .eval = nft_limit_obj_pkts_eval, + .dump = nft_limit_obj_pkts_dump, +}; + +static void nft_limit_obj_bytes_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_limit_priv *priv = nft_obj_data(obj); + u64 cost = 
div64_u64(priv->nsecs * pkt->skb->len, priv->rate); + + if (nft_limit_eval(priv, cost)) + regs->verdict.code = NFT_BREAK; +} + +static int nft_limit_obj_bytes_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_limit_priv *priv = nft_obj_data(obj); + + return nft_limit_init(priv, tb, false); +} + +static int nft_limit_obj_bytes_dump(struct sk_buff *skb, + struct nft_object *obj, + bool reset) +{ + const struct nft_limit_priv *priv = nft_obj_data(obj); + + return nft_limit_dump(skb, priv, NFT_LIMIT_PKT_BYTES); +} + +static void nft_limit_obj_bytes_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_limit_priv *priv = nft_obj_data(obj); + + nft_limit_destroy(ctx, priv); +} + +static struct nft_object_type nft_limit_obj_type; +static const struct nft_object_ops nft_limit_obj_bytes_ops = { + .type = &nft_limit_obj_type, + .size = sizeof(struct nft_limit_priv), + .init = nft_limit_obj_bytes_init, + .destroy = nft_limit_obj_bytes_destroy, + .eval = nft_limit_obj_bytes_eval, + .dump = nft_limit_obj_bytes_dump, +}; + +static const struct nft_object_ops * +nft_limit_obj_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (!tb[NFTA_LIMIT_TYPE]) + return &nft_limit_obj_pkts_ops; + + switch (ntohl(nla_get_be32(tb[NFTA_LIMIT_TYPE]))) { + case NFT_LIMIT_PKTS: + return &nft_limit_obj_pkts_ops; + case NFT_LIMIT_PKT_BYTES: + return &nft_limit_obj_bytes_ops; + } + return ERR_PTR(-EOPNOTSUPP); +} + +static struct nft_object_type nft_limit_obj_type __read_mostly = { + .select_ops = nft_limit_obj_select_ops, + .type = NFT_OBJECT_LIMIT, + .maxattr = NFTA_LIMIT_MAX, + .policy = nft_limit_policy, + .owner = THIS_MODULE, +}; + +static int __init nft_limit_module_init(void) +{ + int err; + + err = nft_register_obj(&nft_limit_obj_type); + if (err < 0) + return err; + + err = nft_register_expr(&nft_limit_type); + if (err < 0) + goto err1; + + return 0; +err1: + nft_unregister_obj(&nft_limit_obj_type); + return err; +} + +static void __exit nft_limit_module_exit(void) +{ + nft_unregister_expr(&nft_limit_type); + nft_unregister_obj(&nft_limit_obj_type); +} + +module_init(nft_limit_module_init); +module_exit(nft_limit_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_EXPR("limit"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_LIMIT); +MODULE_DESCRIPTION("nftables limit expression support"); diff --git a/net/netfilter/nft_log.c b/net/netfilter/nft_log.c new file mode 100644 index 000000000..0e13c003f --- /dev/null +++ b/net/netfilter/nft_log.c @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2012-2014 Pablo Neira Ayuso + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const char *nft_log_null_prefix = ""; + +struct nft_log { + struct nf_loginfo loginfo; + char *prefix; +}; + +static bool audit_ip4(struct audit_buffer *ab, struct sk_buff *skb) +{ + struct iphdr _iph; + const struct iphdr *ih; + + ih = skb_header_pointer(skb, skb_network_offset(skb), sizeof(_iph), &_iph); + if (!ih) + return false; + + audit_log_format(ab, " saddr=%pI4 daddr=%pI4 proto=%hhu", + &ih->saddr, &ih->daddr, ih->protocol); + + return true; +} + +static bool audit_ip6(struct audit_buffer *ab, struct sk_buff *skb) +{ + struct ipv6hdr _ip6h; + const struct ipv6hdr *ih; 
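/*
 * Illustrative userspace sketch (not kernel code): the audit helpers
 * above only pull source address, destination address and protocol out
 * of the network header.  This does the same for a hand-built IPv4
 * header; the local ip4_hdr struct mirrors the on-wire layout and is
 * defined here just to keep the example self-contained.
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct ip4_hdr {			/* minimal view of the IPv4 header */
	uint8_t  ver_ihl;
	uint8_t  tos;
	uint16_t tot_len;
	uint16_t id;
	uint16_t frag_off;
	uint8_t  ttl;
	uint8_t  protocol;
	uint16_t check;
	uint32_t saddr;
	uint32_t daddr;
};

static void audit_ip4_sketch(const unsigned char *pkt, size_t len)
{
	struct ip4_hdr ih;
	char src[INET_ADDRSTRLEN], dst[INET_ADDRSTRLEN];

	if (len < sizeof(ih))
		return;
	memcpy(&ih, pkt, sizeof(ih));	/* like skb_header_pointer() */
	inet_ntop(AF_INET, &ih.saddr, src, sizeof(src));
	inet_ntop(AF_INET, &ih.daddr, dst, sizeof(dst));
	printf(" saddr=%s daddr=%s proto=%u\n", src, dst, ih.protocol);
}

int main(void)
{
	unsigned char pkt[20] = { 0x45 };	/* version 4, IHL 5, rest zeroed */

	pkt[9]  = 6;				/* protocol: TCP */
	pkt[12] = 192; pkt[13] = 168; pkt[14] = 0; pkt[15] = 1;	/* saddr */
	pkt[16] = 10;  pkt[17] = 0;   pkt[18] = 0; pkt[19] = 1;	/* daddr */
	audit_ip4_sketch(pkt, sizeof(pkt));
	return 0;
}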
+ u8 nexthdr; + __be16 frag_off; + + ih = skb_header_pointer(skb, skb_network_offset(skb), sizeof(_ip6h), &_ip6h); + if (!ih) + return false; + + nexthdr = ih->nexthdr; + ipv6_skip_exthdr(skb, skb_network_offset(skb) + sizeof(_ip6h), &nexthdr, &frag_off); + + audit_log_format(ab, " saddr=%pI6c daddr=%pI6c proto=%hhu", + &ih->saddr, &ih->daddr, nexthdr); + + return true; +} + +static void nft_log_eval_audit(const struct nft_pktinfo *pkt) +{ + struct sk_buff *skb = pkt->skb; + struct audit_buffer *ab; + int fam = -1; + + if (!audit_enabled) + return; + + ab = audit_log_start(NULL, GFP_ATOMIC, AUDIT_NETFILTER_PKT); + if (!ab) + return; + + audit_log_format(ab, "mark=%#x", skb->mark); + + switch (nft_pf(pkt)) { + case NFPROTO_BRIDGE: + switch (eth_hdr(skb)->h_proto) { + case htons(ETH_P_IP): + fam = audit_ip4(ab, skb) ? NFPROTO_IPV4 : -1; + break; + case htons(ETH_P_IPV6): + fam = audit_ip6(ab, skb) ? NFPROTO_IPV6 : -1; + break; + } + break; + case NFPROTO_IPV4: + fam = audit_ip4(ab, skb) ? NFPROTO_IPV4 : -1; + break; + case NFPROTO_IPV6: + fam = audit_ip6(ab, skb) ? NFPROTO_IPV6 : -1; + break; + } + + if (fam == -1) + audit_log_format(ab, " saddr=? daddr=? proto=-1"); + + audit_log_end(ab); +} + +static void nft_log_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_log *priv = nft_expr_priv(expr); + + if (priv->loginfo.type == NF_LOG_TYPE_LOG && + priv->loginfo.u.log.level == NFT_LOGLEVEL_AUDIT) { + nft_log_eval_audit(pkt); + return; + } + + nf_log_packet(nft_net(pkt), nft_pf(pkt), nft_hook(pkt), pkt->skb, + nft_in(pkt), nft_out(pkt), &priv->loginfo, "%s", + priv->prefix); +} + +static const struct nla_policy nft_log_policy[NFTA_LOG_MAX + 1] = { + [NFTA_LOG_GROUP] = { .type = NLA_U16 }, + [NFTA_LOG_PREFIX] = { .type = NLA_STRING, + .len = NF_LOG_PREFIXLEN - 1 }, + [NFTA_LOG_SNAPLEN] = { .type = NLA_U32 }, + [NFTA_LOG_QTHRESHOLD] = { .type = NLA_U16 }, + [NFTA_LOG_LEVEL] = { .type = NLA_U32 }, + [NFTA_LOG_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_log_modprobe(struct net *net, enum nf_log_type t) +{ + switch (t) { + case NF_LOG_TYPE_LOG: + return nft_request_module(net, "%s", "nf_log_syslog"); + case NF_LOG_TYPE_ULOG: + return nft_request_module(net, "%s", "nfnetlink_log"); + case NF_LOG_TYPE_MAX: + break; + } + + return -ENOENT; +} + +static int nft_log_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_log *priv = nft_expr_priv(expr); + struct nf_loginfo *li = &priv->loginfo; + const struct nlattr *nla; + int err; + + li->type = NF_LOG_TYPE_LOG; + if (tb[NFTA_LOG_LEVEL] != NULL && + tb[NFTA_LOG_GROUP] != NULL) + return -EINVAL; + if (tb[NFTA_LOG_GROUP] != NULL) { + li->type = NF_LOG_TYPE_ULOG; + if (tb[NFTA_LOG_FLAGS] != NULL) + return -EINVAL; + } + + nla = tb[NFTA_LOG_PREFIX]; + if (nla != NULL) { + priv->prefix = kmalloc(nla_len(nla) + 1, GFP_KERNEL); + if (priv->prefix == NULL) + return -ENOMEM; + nla_strscpy(priv->prefix, nla, nla_len(nla) + 1); + } else { + priv->prefix = (char *)nft_log_null_prefix; + } + + switch (li->type) { + case NF_LOG_TYPE_LOG: + if (tb[NFTA_LOG_LEVEL] != NULL) { + li->u.log.level = + ntohl(nla_get_be32(tb[NFTA_LOG_LEVEL])); + } else { + li->u.log.level = NFT_LOGLEVEL_WARNING; + } + if (li->u.log.level > NFT_LOGLEVEL_AUDIT) { + err = -EINVAL; + goto err1; + } + + if (tb[NFTA_LOG_FLAGS] != NULL) { + li->u.log.logflags = + ntohl(nla_get_be32(tb[NFTA_LOG_FLAGS])); + if (li->u.log.logflags & ~NF_LOG_MASK) { + err = -EINVAL; + goto 
err1; + } + } + break; + case NF_LOG_TYPE_ULOG: + li->u.ulog.group = ntohs(nla_get_be16(tb[NFTA_LOG_GROUP])); + if (tb[NFTA_LOG_SNAPLEN] != NULL) { + li->u.ulog.flags |= NF_LOG_F_COPY_LEN; + li->u.ulog.copy_len = + ntohl(nla_get_be32(tb[NFTA_LOG_SNAPLEN])); + } + if (tb[NFTA_LOG_QTHRESHOLD] != NULL) { + li->u.ulog.qthreshold = + ntohs(nla_get_be16(tb[NFTA_LOG_QTHRESHOLD])); + } + break; + } + + if (li->u.log.level == NFT_LOGLEVEL_AUDIT) + return 0; + + err = nf_logger_find_get(ctx->family, li->type); + if (err < 0) { + if (nft_log_modprobe(ctx->net, li->type) == -EAGAIN) + err = -EAGAIN; + + goto err1; + } + + return 0; + +err1: + if (priv->prefix != nft_log_null_prefix) + kfree(priv->prefix); + return err; +} + +static void nft_log_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_log *priv = nft_expr_priv(expr); + struct nf_loginfo *li = &priv->loginfo; + + if (priv->prefix != nft_log_null_prefix) + kfree(priv->prefix); + + if (li->u.log.level == NFT_LOGLEVEL_AUDIT) + return; + + nf_logger_put(ctx->family, li->type); +} + +static int nft_log_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_log *priv = nft_expr_priv(expr); + const struct nf_loginfo *li = &priv->loginfo; + + if (priv->prefix != nft_log_null_prefix) + if (nla_put_string(skb, NFTA_LOG_PREFIX, priv->prefix)) + goto nla_put_failure; + switch (li->type) { + case NF_LOG_TYPE_LOG: + if (nla_put_be32(skb, NFTA_LOG_LEVEL, htonl(li->u.log.level))) + goto nla_put_failure; + + if (li->u.log.logflags) { + if (nla_put_be32(skb, NFTA_LOG_FLAGS, + htonl(li->u.log.logflags))) + goto nla_put_failure; + } + break; + case NF_LOG_TYPE_ULOG: + if (nla_put_be16(skb, NFTA_LOG_GROUP, htons(li->u.ulog.group))) + goto nla_put_failure; + + if (li->u.ulog.flags & NF_LOG_F_COPY_LEN) { + if (nla_put_be32(skb, NFTA_LOG_SNAPLEN, + htonl(li->u.ulog.copy_len))) + goto nla_put_failure; + } + if (li->u.ulog.qthreshold) { + if (nla_put_be16(skb, NFTA_LOG_QTHRESHOLD, + htons(li->u.ulog.qthreshold))) + goto nla_put_failure; + } + break; + } + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_log_type; +static const struct nft_expr_ops nft_log_ops = { + .type = &nft_log_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_log)), + .eval = nft_log_eval, + .init = nft_log_init, + .destroy = nft_log_destroy, + .dump = nft_log_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_log_type __read_mostly = { + .name = "log", + .ops = &nft_log_ops, + .policy = nft_log_policy, + .maxattr = NFTA_LOG_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_log_module_init(void) +{ + return nft_register_expr(&nft_log_type); +} + +static void __exit nft_log_module_exit(void) +{ + nft_unregister_expr(&nft_log_type); +} + +module_init(nft_log_module_init); +module_exit(nft_log_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_EXPR("log"); +MODULE_DESCRIPTION("Netfilter nf_tables log module"); diff --git a/net/netfilter/nft_lookup.c b/net/netfilter/nft_lookup.c new file mode 100644 index 000000000..68a5dea80 --- /dev/null +++ b/net/netfilter/nft_lookup.c @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_lookup { + struct nft_set *set; + u8 sreg; + u8 dreg; + bool invert; + struct 
nft_set_binding binding; +}; + +#ifdef CONFIG_RETPOLINE +bool nft_set_do_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + if (set->ops == &nft_set_hash_fast_type.ops) + return nft_hash_lookup_fast(net, set, key, ext); + if (set->ops == &nft_set_hash_type.ops) + return nft_hash_lookup(net, set, key, ext); + + if (set->ops == &nft_set_rhash_type.ops) + return nft_rhash_lookup(net, set, key, ext); + + if (set->ops == &nft_set_bitmap_type.ops) + return nft_bitmap_lookup(net, set, key, ext); + + if (set->ops == &nft_set_pipapo_type.ops) + return nft_pipapo_lookup(net, set, key, ext); +#if defined(CONFIG_X86_64) && !defined(CONFIG_UML) + if (set->ops == &nft_set_pipapo_avx2_type.ops) + return nft_pipapo_avx2_lookup(net, set, key, ext); +#endif + + if (set->ops == &nft_set_rbtree_type.ops) + return nft_rbtree_lookup(net, set, key, ext); + + WARN_ON_ONCE(1); + return set->ops->lookup(net, set, key, ext); +} +EXPORT_SYMBOL_GPL(nft_set_do_lookup); +#endif + +void nft_lookup_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_lookup *priv = nft_expr_priv(expr); + const struct nft_set *set = priv->set; + const struct nft_set_ext *ext = NULL; + const struct net *net = nft_net(pkt); + bool found; + + found = nft_set_do_lookup(net, set, ®s->data[priv->sreg], &ext) ^ + priv->invert; + if (!found) { + ext = nft_set_catchall_lookup(net, set); + if (!ext) { + regs->verdict.code = NFT_BREAK; + return; + } + } + + if (ext) { + if (set->flags & NFT_SET_MAP) + nft_data_copy(®s->data[priv->dreg], + nft_set_ext_data(ext), set->dlen); + + nft_set_elem_update_expr(ext, regs, pkt); + } +} + +static const struct nla_policy nft_lookup_policy[NFTA_LOOKUP_MAX + 1] = { + [NFTA_LOOKUP_SET] = { .type = NLA_STRING, + .len = NFT_SET_MAXNAMELEN - 1 }, + [NFTA_LOOKUP_SET_ID] = { .type = NLA_U32 }, + [NFTA_LOOKUP_SREG] = { .type = NLA_U32 }, + [NFTA_LOOKUP_DREG] = { .type = NLA_U32 }, + [NFTA_LOOKUP_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_lookup_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_lookup *priv = nft_expr_priv(expr); + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set *set; + u32 flags; + int err; + + if (tb[NFTA_LOOKUP_SET] == NULL || + tb[NFTA_LOOKUP_SREG] == NULL) + return -EINVAL; + + set = nft_set_lookup_global(ctx->net, ctx->table, tb[NFTA_LOOKUP_SET], + tb[NFTA_LOOKUP_SET_ID], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + err = nft_parse_register_load(tb[NFTA_LOOKUP_SREG], &priv->sreg, + set->klen); + if (err < 0) + return err; + + if (tb[NFTA_LOOKUP_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFTA_LOOKUP_FLAGS])); + + if (flags & ~NFT_LOOKUP_F_INV) + return -EINVAL; + + if (flags & NFT_LOOKUP_F_INV) { + if (set->flags & NFT_SET_MAP) + return -EINVAL; + priv->invert = true; + } + } + + if (tb[NFTA_LOOKUP_DREG] != NULL) { + if (priv->invert) + return -EINVAL; + if (!(set->flags & NFT_SET_MAP)) + return -EINVAL; + + err = nft_parse_register_store(ctx, tb[NFTA_LOOKUP_DREG], + &priv->dreg, NULL, set->dtype, + set->dlen); + if (err < 0) + return err; + } else if (set->flags & NFT_SET_MAP) + return -EINVAL; + + priv->binding.flags = set->flags & NFT_SET_MAP; + + err = nf_tables_bind_set(ctx, set, &priv->binding); + if (err < 0) + return err; + + priv->set = set; + return 0; +} + +static void nft_lookup_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) 
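/*
 * Illustrative userspace sketch (not kernel code): with retpolines
 * enabled, nft_set_do_lookup() above avoids a costly indirect call by
 * comparing the set's ops pointer against the known backends and
 * calling the matching one directly.  The same idea in miniature, with
 * made-up stand-in backends:
 */
#include <stdbool.h>
#include <stdio.h>

typedef bool (*lookup_fn)(int key);

static bool hash_lookup(int key)   { return key % 2 == 0; }	/* stand-in backend */
static bool rbtree_lookup(int key) { return key > 10; }		/* stand-in backend */

/* pointer compares plus direct calls instead of one indirect call */
static bool do_lookup(lookup_fn ops, int key)
{
	if (ops == hash_lookup)
		return hash_lookup(key);
	if (ops == rbtree_lookup)
		return rbtree_lookup(key);
	return ops(key);		/* fallback: the plain indirect call */
}

int main(void)
{
	printf("%d %d\n", do_lookup(hash_lookup, 4), do_lookup(rbtree_lookup, 4));
	return 0;
}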
+{ + struct nft_lookup *priv = nft_expr_priv(expr); + + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); +} + +static void nft_lookup_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_lookup *priv = nft_expr_priv(expr); + + nf_tables_activate_set(ctx, priv->set); +} + +static void nft_lookup_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_lookup *priv = nft_expr_priv(expr); + + nf_tables_destroy_set(ctx, priv->set); +} + +static int nft_lookup_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_lookup *priv = nft_expr_priv(expr); + u32 flags = priv->invert ? NFT_LOOKUP_F_INV : 0; + + if (nla_put_string(skb, NFTA_LOOKUP_SET, priv->set->name)) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_LOOKUP_SREG, priv->sreg)) + goto nla_put_failure; + if (priv->set->flags & NFT_SET_MAP) + if (nft_dump_register(skb, NFTA_LOOKUP_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_LOOKUP_FLAGS, htonl(flags))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int nft_lookup_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **d) +{ + const struct nft_lookup *priv = nft_expr_priv(expr); + struct nft_set_iter iter; + + if (!(priv->set->flags & NFT_SET_MAP) || + priv->set->dtype != NFT_DATA_VERDICT) + return 0; + + iter.genmask = nft_genmask_next(ctx->net); + iter.skip = 0; + iter.count = 0; + iter.err = 0; + iter.fn = nft_setelem_validate; + + priv->set->ops->walk(ctx, priv->set, &iter); + if (!iter.err) + iter.err = nft_set_catchall_validate(ctx, priv->set); + + if (iter.err < 0) + return iter.err; + + return 0; +} + +static bool nft_lookup_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_lookup *priv = nft_expr_priv(expr); + + if (priv->set->flags & NFT_SET_MAP) + nft_reg_track_cancel(track, priv->dreg, priv->set->dlen); + + return false; +} + +static const struct nft_expr_ops nft_lookup_ops = { + .type = &nft_lookup_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_lookup)), + .eval = nft_lookup_eval, + .init = nft_lookup_init, + .activate = nft_lookup_activate, + .deactivate = nft_lookup_deactivate, + .destroy = nft_lookup_destroy, + .dump = nft_lookup_dump, + .validate = nft_lookup_validate, + .reduce = nft_lookup_reduce, +}; + +struct nft_expr_type nft_lookup_type __read_mostly = { + .name = "lookup", + .ops = &nft_lookup_ops, + .policy = nft_lookup_policy, + .maxattr = NFTA_LOOKUP_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_masq.c b/net/netfilter/nft_masq.c new file mode 100644 index 000000000..026b4f87d --- /dev/null +++ b/net/netfilter/nft_masq.c @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2014 Arturo Borrero Gonzalez + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_masq { + u32 flags; + u8 sreg_proto_min; + u8 sreg_proto_max; +}; + +static const struct nla_policy nft_masq_policy[NFTA_MASQ_MAX + 1] = { + [NFTA_MASQ_FLAGS] = { .type = NLA_U32 }, + [NFTA_MASQ_REG_PROTO_MIN] = { .type = NLA_U32 }, + [NFTA_MASQ_REG_PROTO_MAX] = { .type = NLA_U32 }, +}; + +static int nft_masq_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + int err; + + err = nft_chain_validate_dependency(ctx->chain, NFT_CHAIN_T_NAT); + if (err < 0) + return err; + + return nft_chain_validate_hooks(ctx->chain, + (1 << 
NF_INET_POST_ROUTING)); +} + +static int nft_masq_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + u32 plen = sizeof_field(struct nf_nat_range, min_proto.all); + struct nft_masq *priv = nft_expr_priv(expr); + int err; + + if (tb[NFTA_MASQ_FLAGS]) { + priv->flags = ntohl(nla_get_be32(tb[NFTA_MASQ_FLAGS])); + if (priv->flags & ~NF_NAT_RANGE_MASK) + return -EINVAL; + } + + if (tb[NFTA_MASQ_REG_PROTO_MIN]) { + err = nft_parse_register_load(tb[NFTA_MASQ_REG_PROTO_MIN], + &priv->sreg_proto_min, plen); + if (err < 0) + return err; + + if (tb[NFTA_MASQ_REG_PROTO_MAX]) { + err = nft_parse_register_load(tb[NFTA_MASQ_REG_PROTO_MAX], + &priv->sreg_proto_max, + plen); + if (err < 0) + return err; + } else { + priv->sreg_proto_max = priv->sreg_proto_min; + } + } + + return nf_ct_netns_get(ctx->net, ctx->family); +} + +static int nft_masq_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_masq *priv = nft_expr_priv(expr); + + if (priv->flags != 0 && + nla_put_be32(skb, NFTA_MASQ_FLAGS, htonl(priv->flags))) + goto nla_put_failure; + + if (priv->sreg_proto_min) { + if (nft_dump_register(skb, NFTA_MASQ_REG_PROTO_MIN, + priv->sreg_proto_min) || + nft_dump_register(skb, NFTA_MASQ_REG_PROTO_MAX, + priv->sreg_proto_max)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_masq_ipv4_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_masq *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + range.flags = priv->flags; + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + } + regs->verdict.code = nf_nat_masquerade_ipv4(pkt->skb, nft_hook(pkt), + &range, nft_out(pkt)); +} + +static void +nft_masq_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV4); +} + +static struct nft_expr_type nft_masq_ipv4_type; +static const struct nft_expr_ops nft_masq_ipv4_ops = { + .type = &nft_masq_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), + .eval = nft_masq_ipv4_eval, + .init = nft_masq_init, + .destroy = nft_masq_ipv4_destroy, + .dump = nft_masq_dump, + .validate = nft_masq_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_masq_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "masq", + .ops = &nft_masq_ipv4_ops, + .policy = nft_masq_policy, + .maxattr = NFTA_MASQ_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_TABLES_IPV6 +static void nft_masq_ipv6_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_masq *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + range.flags = priv->flags; + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16)nft_reg_load16( + ®s->data[priv->sreg_proto_max]); + } + regs->verdict.code = nf_nat_masquerade_ipv6(pkt->skb, &range, + nft_out(pkt)); +} + +static void +nft_masq_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV6); +} + +static struct nft_expr_type nft_masq_ipv6_type; +static const struct nft_expr_ops nft_masq_ipv6_ops = { + 
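/*
 * Illustrative userspace sketch (not kernel code): the masquerade eval
 * above reads a port range out of two registers with nft_reg_load16().
 * The register file is an array of 32-bit words and 16-bit values live
 * in the first two bytes of a word; byte-order handling is omitted here
 * to keep the sketch short.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

static void reg_store16(uint32_t *reg, uint16_t val)
{
	*reg = 0;
	memcpy(reg, &val, sizeof(val));
}

static uint16_t reg_load16(const uint32_t *reg)
{
	uint16_t val;

	memcpy(&val, reg, sizeof(val));
	return val;
}

int main(void)
{
	uint32_t regs[4] = { 0 };

	reg_store16(&regs[1], 1024);	/* e.g. proto-min register */
	reg_store16(&regs[2], 2048);	/* e.g. proto-max register */
	printf("range %u-%u\n", reg_load16(&regs[1]), reg_load16(&regs[2]));
	return 0;
}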
.type = &nft_masq_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), + .eval = nft_masq_ipv6_eval, + .init = nft_masq_init, + .destroy = nft_masq_ipv6_destroy, + .dump = nft_masq_dump, + .validate = nft_masq_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_masq_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "masq", + .ops = &nft_masq_ipv6_ops, + .policy = nft_masq_policy, + .maxattr = NFTA_MASQ_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_masq_module_init_ipv6(void) +{ + return nft_register_expr(&nft_masq_ipv6_type); +} + +static void nft_masq_module_exit_ipv6(void) +{ + nft_unregister_expr(&nft_masq_ipv6_type); +} +#else +static inline int nft_masq_module_init_ipv6(void) { return 0; } +static inline void nft_masq_module_exit_ipv6(void) {} +#endif + +#ifdef CONFIG_NF_TABLES_INET +static void nft_masq_inet_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + return nft_masq_ipv4_eval(expr, regs, pkt); + case NFPROTO_IPV6: + return nft_masq_ipv6_eval(expr, regs, pkt); + } + + WARN_ON_ONCE(1); +} + +static void +nft_masq_inet_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_INET); +} + +static struct nft_expr_type nft_masq_inet_type; +static const struct nft_expr_ops nft_masq_inet_ops = { + .type = &nft_masq_inet_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_masq)), + .eval = nft_masq_inet_eval, + .init = nft_masq_init, + .destroy = nft_masq_inet_destroy, + .dump = nft_masq_dump, + .validate = nft_masq_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_masq_inet_type __read_mostly = { + .family = NFPROTO_INET, + .name = "masq", + .ops = &nft_masq_inet_ops, + .policy = nft_masq_policy, + .maxattr = NFTA_MASQ_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_masq_module_init_inet(void) +{ + return nft_register_expr(&nft_masq_inet_type); +} + +static void nft_masq_module_exit_inet(void) +{ + nft_unregister_expr(&nft_masq_inet_type); +} +#else +static inline int nft_masq_module_init_inet(void) { return 0; } +static inline void nft_masq_module_exit_inet(void) {} +#endif + +static int __init nft_masq_module_init(void) +{ + int ret; + + ret = nft_masq_module_init_ipv6(); + if (ret < 0) + return ret; + + ret = nft_masq_module_init_inet(); + if (ret < 0) { + nft_masq_module_exit_ipv6(); + return ret; + } + + ret = nft_register_expr(&nft_masq_ipv4_type); + if (ret < 0) { + nft_masq_module_exit_inet(); + nft_masq_module_exit_ipv6(); + return ret; + } + + ret = nf_nat_masquerade_inet_register_notifiers(); + if (ret < 0) { + nft_masq_module_exit_ipv6(); + nft_masq_module_exit_inet(); + nft_unregister_expr(&nft_masq_ipv4_type); + return ret; + } + + return ret; +} + +static void __exit nft_masq_module_exit(void) +{ + nft_masq_module_exit_ipv6(); + nft_masq_module_exit_inet(); + nft_unregister_expr(&nft_masq_ipv4_type); + nf_nat_masquerade_inet_unregister_notifiers(); +} + +module_init(nft_masq_module_init); +module_exit(nft_masq_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arturo Borrero Gonzalez "); +MODULE_ALIAS_NFT_EXPR("masq"); +MODULE_DESCRIPTION("Netfilter nftables masquerade expression support"); diff --git a/net/netfilter/nft_meta.c b/net/netfilter/nft_meta.c new file mode 100644 index 000000000..6e8332192 --- /dev/null +++ b/net/netfilter/nft_meta.c @@ -0,0 +1,948 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 
2008-2009 Patrick McHardy + * Copyright (c) 2014 Intel Corporation + * Author: Tomasz Bursztyka + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for TCP_TIME_WAIT */ +#include +#include +#include +#include + +#include /* NF_BR_PRE_ROUTING */ + +#define NFT_META_SECS_PER_MINUTE 60 +#define NFT_META_SECS_PER_HOUR 3600 +#define NFT_META_SECS_PER_DAY 86400 +#define NFT_META_DAYS_PER_WEEK 7 + +static u8 nft_meta_weekday(void) +{ + time64_t secs = ktime_get_real_seconds(); + unsigned int dse; + u8 wday; + + secs -= NFT_META_SECS_PER_MINUTE * sys_tz.tz_minuteswest; + dse = div_u64(secs, NFT_META_SECS_PER_DAY); + wday = (4 + dse) % NFT_META_DAYS_PER_WEEK; + + return wday; +} + +static u32 nft_meta_hour(time64_t secs) +{ + struct tm tm; + + time64_to_tm(secs, 0, &tm); + + return tm.tm_hour * NFT_META_SECS_PER_HOUR + + tm.tm_min * NFT_META_SECS_PER_MINUTE + + tm.tm_sec; +} + +static noinline_for_stack void +nft_meta_get_eval_time(enum nft_meta_keys key, + u32 *dest) +{ + switch (key) { + case NFT_META_TIME_NS: + nft_reg_store64((u64 *)dest, ktime_get_real_ns()); + break; + case NFT_META_TIME_DAY: + nft_reg_store8(dest, nft_meta_weekday()); + break; + case NFT_META_TIME_HOUR: + *dest = nft_meta_hour(ktime_get_real_seconds()); + break; + default: + break; + } +} + +static noinline bool +nft_meta_get_eval_pkttype_lo(const struct nft_pktinfo *pkt, + u32 *dest) +{ + const struct sk_buff *skb = pkt->skb; + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + if (ipv4_is_multicast(ip_hdr(skb)->daddr)) + nft_reg_store8(dest, PACKET_MULTICAST); + else + nft_reg_store8(dest, PACKET_BROADCAST); + break; + case NFPROTO_IPV6: + nft_reg_store8(dest, PACKET_MULTICAST); + break; + case NFPROTO_NETDEV: + switch (skb->protocol) { + case htons(ETH_P_IP): { + int noff = skb_network_offset(skb); + struct iphdr *iph, _iph; + + iph = skb_header_pointer(skb, noff, + sizeof(_iph), &_iph); + if (!iph) + return false; + + if (ipv4_is_multicast(iph->daddr)) + nft_reg_store8(dest, PACKET_MULTICAST); + else + nft_reg_store8(dest, PACKET_BROADCAST); + + break; + } + case htons(ETH_P_IPV6): + nft_reg_store8(dest, PACKET_MULTICAST); + break; + default: + WARN_ON_ONCE(1); + return false; + } + break; + default: + WARN_ON_ONCE(1); + return false; + } + + return true; +} + +static noinline bool +nft_meta_get_eval_skugid(enum nft_meta_keys key, + u32 *dest, + const struct nft_pktinfo *pkt) +{ + struct sock *sk = skb_to_full_sk(pkt->skb); + struct socket *sock; + + if (!sk || !sk_fullsock(sk) || !net_eq(nft_net(pkt), sock_net(sk))) + return false; + + read_lock_bh(&sk->sk_callback_lock); + sock = sk->sk_socket; + if (!sock || !sock->file) { + read_unlock_bh(&sk->sk_callback_lock); + return false; + } + + switch (key) { + case NFT_META_SKUID: + *dest = from_kuid_munged(sock_net(sk)->user_ns, + sock->file->f_cred->fsuid); + break; + case NFT_META_SKGID: + *dest = from_kgid_munged(sock_net(sk)->user_ns, + sock->file->f_cred->fsgid); + break; + default: + break; + } + + read_unlock_bh(&sk->sk_callback_lock); + return true; +} + +#ifdef CONFIG_CGROUP_NET_CLASSID +static noinline bool +nft_meta_get_eval_cgroup(u32 *dest, const struct nft_pktinfo *pkt) +{ + struct sock *sk = skb_to_full_sk(pkt->skb); + + if (!sk || !sk_fullsock(sk) || !net_eq(nft_net(pkt), sock_net(sk))) + return false; + + *dest = sock_cgroup_classid(&sk->sk_cgrp_data); + return true; +} +#endif + +static 
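/*
 * Illustrative userspace sketch (not kernel code): the meta time helpers
 * above derive the weekday and the "hour" (really seconds since
 * midnight) from epoch seconds.  1 January 1970 was a Thursday, which is
 * where the "+ 4" before the modulo comes from; the kernel additionally
 * shifts by the system timezone before the weekday calculation.
 */
#include <stdint.h>
#include <stdio.h>
#include <time.h>

#define TOY_SECS_PER_DAY 86400

static unsigned int toy_weekday(int64_t secs)	/* 0 = Sunday ... 6 = Saturday */
{
	return (4 + secs / TOY_SECS_PER_DAY) % 7;
}

static unsigned int toy_secs_into_day(int64_t secs)
{
	return secs % TOY_SECS_PER_DAY;
}

int main(void)
{
	int64_t now = time(NULL);	/* UTC epoch seconds */

	printf("weekday=%u secs_into_day=%u\n",
	       toy_weekday(now), toy_secs_into_day(now));
	return 0;
}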
noinline bool nft_meta_get_eval_kind(enum nft_meta_keys key, + u32 *dest, + const struct nft_pktinfo *pkt) +{ + const struct net_device *in = nft_in(pkt), *out = nft_out(pkt); + + switch (key) { + case NFT_META_IIFKIND: + if (!in || !in->rtnl_link_ops) + return false; + strncpy((char *)dest, in->rtnl_link_ops->kind, IFNAMSIZ); + break; + case NFT_META_OIFKIND: + if (!out || !out->rtnl_link_ops) + return false; + strncpy((char *)dest, out->rtnl_link_ops->kind, IFNAMSIZ); + break; + default: + return false; + } + + return true; +} + +static void nft_meta_store_ifindex(u32 *dest, const struct net_device *dev) +{ + *dest = dev ? dev->ifindex : 0; +} + +static void nft_meta_store_ifname(u32 *dest, const struct net_device *dev) +{ + strncpy((char *)dest, dev ? dev->name : "", IFNAMSIZ); +} + +static bool nft_meta_store_iftype(u32 *dest, const struct net_device *dev) +{ + if (!dev) + return false; + + nft_reg_store16(dest, dev->type); + return true; +} + +static bool nft_meta_store_ifgroup(u32 *dest, const struct net_device *dev) +{ + if (!dev) + return false; + + *dest = dev->group; + return true; +} + +static bool nft_meta_get_eval_ifname(enum nft_meta_keys key, u32 *dest, + const struct nft_pktinfo *pkt) +{ + switch (key) { + case NFT_META_IIFNAME: + nft_meta_store_ifname(dest, nft_in(pkt)); + break; + case NFT_META_OIFNAME: + nft_meta_store_ifname(dest, nft_out(pkt)); + break; + case NFT_META_IIF: + nft_meta_store_ifindex(dest, nft_in(pkt)); + break; + case NFT_META_OIF: + nft_meta_store_ifindex(dest, nft_out(pkt)); + break; + case NFT_META_IFTYPE: + if (!nft_meta_store_iftype(dest, pkt->skb->dev)) + return false; + break; + case __NFT_META_IIFTYPE: + if (!nft_meta_store_iftype(dest, nft_in(pkt))) + return false; + break; + case NFT_META_OIFTYPE: + if (!nft_meta_store_iftype(dest, nft_out(pkt))) + return false; + break; + case NFT_META_IIFGROUP: + if (!nft_meta_store_ifgroup(dest, nft_in(pkt))) + return false; + break; + case NFT_META_OIFGROUP: + if (!nft_meta_store_ifgroup(dest, nft_out(pkt))) + return false; + break; + default: + return false; + } + + return true; +} + +#ifdef CONFIG_IP_ROUTE_CLASSID +static noinline bool +nft_meta_get_eval_rtclassid(const struct sk_buff *skb, u32 *dest) +{ + const struct dst_entry *dst = skb_dst(skb); + + if (!dst) + return false; + + *dest = dst->tclassid; + return true; +} +#endif + +static noinline u32 nft_meta_get_eval_sdif(const struct nft_pktinfo *pkt) +{ + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + return inet_sdif(pkt->skb); + case NFPROTO_IPV6: + return inet6_sdif(pkt->skb); + } + + return 0; +} + +static noinline void +nft_meta_get_eval_sdifname(u32 *dest, const struct nft_pktinfo *pkt) +{ + u32 sdif = nft_meta_get_eval_sdif(pkt); + const struct net_device *dev; + + dev = sdif ? 
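/*
 * Illustrative userspace sketch (not kernel code): interface names are
 * written into the register file as a fixed-size, NUL-padded block, as
 * nft_meta_store_ifname() above does with strncpy().  16 bytes is the
 * usual Linux IFNAMSIZ, i.e. four 32-bit registers.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define TOY_IFNAMSIZ 16

static void toy_store_ifname(uint32_t *dest, const char *name)
{
	strncpy((char *)dest, name ? name : "", TOY_IFNAMSIZ);
}

int main(void)
{
	uint32_t regs[TOY_IFNAMSIZ / sizeof(uint32_t)] = { 0 };

	toy_store_ifname(regs, "eth0");
	printf("%s\n", (const char *)regs);
	return 0;
}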
dev_get_by_index_rcu(nft_net(pkt), sdif) : NULL; + nft_meta_store_ifname(dest, dev); +} + +void nft_meta_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + const struct sk_buff *skb = pkt->skb; + u32 *dest = ®s->data[priv->dreg]; + + switch (priv->key) { + case NFT_META_LEN: + *dest = skb->len; + break; + case NFT_META_PROTOCOL: + nft_reg_store16(dest, (__force u16)skb->protocol); + break; + case NFT_META_NFPROTO: + nft_reg_store8(dest, nft_pf(pkt)); + break; + case NFT_META_L4PROTO: + if (!(pkt->flags & NFT_PKTINFO_L4PROTO)) + goto err; + nft_reg_store8(dest, pkt->tprot); + break; + case NFT_META_PRIORITY: + *dest = skb->priority; + break; + case NFT_META_MARK: + *dest = skb->mark; + break; + case NFT_META_IIF: + case NFT_META_OIF: + case NFT_META_IIFNAME: + case NFT_META_OIFNAME: + case NFT_META_IIFTYPE: + case NFT_META_OIFTYPE: + case NFT_META_IIFGROUP: + case NFT_META_OIFGROUP: + if (!nft_meta_get_eval_ifname(priv->key, dest, pkt)) + goto err; + break; + case NFT_META_SKUID: + case NFT_META_SKGID: + if (!nft_meta_get_eval_skugid(priv->key, dest, pkt)) + goto err; + break; +#ifdef CONFIG_IP_ROUTE_CLASSID + case NFT_META_RTCLASSID: + if (!nft_meta_get_eval_rtclassid(skb, dest)) + goto err; + break; +#endif +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: + *dest = skb->secmark; + break; +#endif + case NFT_META_PKTTYPE: + if (skb->pkt_type != PACKET_LOOPBACK) { + nft_reg_store8(dest, skb->pkt_type); + break; + } + + if (!nft_meta_get_eval_pkttype_lo(pkt, dest)) + goto err; + break; + case NFT_META_CPU: + *dest = raw_smp_processor_id(); + break; +#ifdef CONFIG_CGROUP_NET_CLASSID + case NFT_META_CGROUP: + if (!nft_meta_get_eval_cgroup(dest, pkt)) + goto err; + break; +#endif + case NFT_META_PRANDOM: + *dest = get_random_u32(); + break; +#ifdef CONFIG_XFRM + case NFT_META_SECPATH: + nft_reg_store8(dest, secpath_exists(skb)); + break; +#endif + case NFT_META_IIFKIND: + case NFT_META_OIFKIND: + if (!nft_meta_get_eval_kind(priv->key, dest, pkt)) + goto err; + break; + case NFT_META_TIME_NS: + case NFT_META_TIME_DAY: + case NFT_META_TIME_HOUR: + nft_meta_get_eval_time(priv->key, dest); + break; + case NFT_META_SDIF: + *dest = nft_meta_get_eval_sdif(pkt); + break; + case NFT_META_SDIFNAME: + nft_meta_get_eval_sdifname(dest, pkt); + break; + default: + WARN_ON(1); + goto err; + } + return; + +err: + regs->verdict.code = NFT_BREAK; +} +EXPORT_SYMBOL_GPL(nft_meta_get_eval); + +void nft_meta_set_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_meta *meta = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + u32 *sreg = ®s->data[meta->sreg]; + u32 value = *sreg; + u8 value8; + + switch (meta->key) { + case NFT_META_MARK: + skb->mark = value; + break; + case NFT_META_PRIORITY: + skb->priority = value; + break; + case NFT_META_PKTTYPE: + value8 = nft_reg_load8(sreg); + + if (skb->pkt_type != value8 && + skb_pkt_type_ok(value8) && + skb_pkt_type_ok(skb->pkt_type)) + skb->pkt_type = value8; + break; + case NFT_META_NFTRACE: + value8 = nft_reg_load8(sreg); + + skb->nf_trace = !!value8; + break; +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: + skb->secmark = value; + break; +#endif + default: + WARN_ON(1); + } +} +EXPORT_SYMBOL_GPL(nft_meta_set_eval); + +const struct nla_policy nft_meta_policy[NFTA_META_MAX + 1] = { + [NFTA_META_DREG] = { .type = NLA_U32 }, + [NFTA_META_KEY] = { .type = NLA_U32 }, + [NFTA_META_SREG] = { 
.type = NLA_U32 }, +}; +EXPORT_SYMBOL_GPL(nft_meta_policy); + +int nft_meta_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_meta *priv = nft_expr_priv(expr); + unsigned int len; + + priv->key = ntohl(nla_get_be32(tb[NFTA_META_KEY])); + switch (priv->key) { + case NFT_META_PROTOCOL: + case NFT_META_IIFTYPE: + case NFT_META_OIFTYPE: + len = sizeof(u16); + break; + case NFT_META_NFPROTO: + case NFT_META_L4PROTO: + case NFT_META_LEN: + case NFT_META_PRIORITY: + case NFT_META_MARK: + case NFT_META_IIF: + case NFT_META_OIF: + case NFT_META_SDIF: + case NFT_META_SKUID: + case NFT_META_SKGID: +#ifdef CONFIG_IP_ROUTE_CLASSID + case NFT_META_RTCLASSID: +#endif +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: +#endif + case NFT_META_PKTTYPE: + case NFT_META_CPU: + case NFT_META_IIFGROUP: + case NFT_META_OIFGROUP: +#ifdef CONFIG_CGROUP_NET_CLASSID + case NFT_META_CGROUP: +#endif + len = sizeof(u32); + break; + case NFT_META_IIFNAME: + case NFT_META_OIFNAME: + case NFT_META_IIFKIND: + case NFT_META_OIFKIND: + case NFT_META_SDIFNAME: + len = IFNAMSIZ; + break; + case NFT_META_PRANDOM: + len = sizeof(u32); + break; +#ifdef CONFIG_XFRM + case NFT_META_SECPATH: + len = sizeof(u8); + break; +#endif + case NFT_META_TIME_NS: + len = sizeof(u64); + break; + case NFT_META_TIME_DAY: + len = sizeof(u8); + break; + case NFT_META_TIME_HOUR: + len = sizeof(u32); + break; + default: + return -EOPNOTSUPP; + } + + priv->len = len; + return nft_parse_register_store(ctx, tb[NFTA_META_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} +EXPORT_SYMBOL_GPL(nft_meta_get_init); + +static int nft_meta_get_validate_sdif(const struct nft_ctx *ctx) +{ + unsigned int hooks; + + switch (ctx->family) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + hooks = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD); + break; + default: + return -EOPNOTSUPP; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} + +static int nft_meta_get_validate_xfrm(const struct nft_ctx *ctx) +{ +#ifdef CONFIG_XFRM + unsigned int hooks; + + switch (ctx->family) { + case NFPROTO_NETDEV: + hooks = 1 << NF_NETDEV_INGRESS; + break; + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD); + break; + default: + return -EOPNOTSUPP; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +#else + return 0; +#endif +} + +static int nft_meta_get_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + + switch (priv->key) { + case NFT_META_SECPATH: + return nft_meta_get_validate_xfrm(ctx); + case NFT_META_SDIF: + case NFT_META_SDIFNAME: + return nft_meta_get_validate_sdif(ctx); + default: + break; + } + + return 0; +} + +int nft_meta_set_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + struct nft_meta *priv = nft_expr_priv(expr); + unsigned int hooks; + + if (priv->key != NFT_META_PKTTYPE) + return 0; + + switch (ctx->family) { + case NFPROTO_BRIDGE: + hooks = 1 << NF_BR_PRE_ROUTING; + break; + case NFPROTO_NETDEV: + hooks = 1 << NF_NETDEV_INGRESS; + break; + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + hooks = 1 << NF_INET_PRE_ROUTING; + break; + default: + return -EOPNOTSUPP; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} +EXPORT_SYMBOL_GPL(nft_meta_set_validate); + +int 
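/*
 * Illustrative userspace sketch (not kernel code): the validate helpers
 * above express "this key may only be used from these hooks" as a bit
 * mask, e.g. sdif is limited to input and forward.  The hook numbers in
 * this sketch are assumptions; the kernel uses enum nf_inet_hooks.
 */
#include <stdbool.h>
#include <stdio.h>

enum toy_hooks { TOY_PRE_ROUTING, TOY_LOCAL_IN, TOY_FORWARD,
		 TOY_LOCAL_OUT, TOY_POST_ROUTING };

static bool toy_hook_allowed(unsigned int hooknum, unsigned int allowed)
{
	return allowed & (1u << hooknum);
}

int main(void)
{
	unsigned int sdif_hooks = (1u << TOY_LOCAL_IN) | (1u << TOY_FORWARD);

	printf("input: %d, postrouting: %d\n",
	       toy_hook_allowed(TOY_LOCAL_IN, sdif_hooks),
	       toy_hook_allowed(TOY_POST_ROUTING, sdif_hooks));
	return 0;
}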
nft_meta_set_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_meta *priv = nft_expr_priv(expr); + unsigned int len; + int err; + + priv->key = ntohl(nla_get_be32(tb[NFTA_META_KEY])); + switch (priv->key) { + case NFT_META_MARK: + case NFT_META_PRIORITY: +#ifdef CONFIG_NETWORK_SECMARK + case NFT_META_SECMARK: +#endif + len = sizeof(u32); + break; + case NFT_META_NFTRACE: + len = sizeof(u8); + break; + case NFT_META_PKTTYPE: + len = sizeof(u8); + break; + default: + return -EOPNOTSUPP; + } + + priv->len = len; + err = nft_parse_register_load(tb[NFTA_META_SREG], &priv->sreg, len); + if (err < 0) + return err; + + if (priv->key == NFT_META_NFTRACE) + static_branch_inc(&nft_trace_enabled); + + return 0; +} +EXPORT_SYMBOL_GPL(nft_meta_set_init); + +int nft_meta_get_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_META_KEY, htonl(priv->key))) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_META_DREG, priv->dreg)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} +EXPORT_SYMBOL_GPL(nft_meta_get_dump); + +int nft_meta_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_META_KEY, htonl(priv->key))) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_META_SREG, priv->sreg)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} +EXPORT_SYMBOL_GPL(nft_meta_set_dump); + +void nft_meta_set_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + + if (priv->key == NFT_META_NFTRACE) + static_branch_dec(&nft_trace_enabled); +} +EXPORT_SYMBOL_GPL(nft_meta_set_destroy); + +static int nft_meta_get_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->key) { + case NFT_META_PROTOCOL: + NFT_OFFLOAD_MATCH_EXACT(FLOW_DISSECTOR_KEY_BASIC, basic, n_proto, + sizeof(__u16), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_NETWORK); + break; + case NFT_META_L4PROTO: + NFT_OFFLOAD_MATCH_EXACT(FLOW_DISSECTOR_KEY_BASIC, basic, ip_proto, + sizeof(__u8), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_TRANSPORT); + break; + case NFT_META_IIF: + NFT_OFFLOAD_MATCH_EXACT(FLOW_DISSECTOR_KEY_META, meta, + ingress_ifindex, sizeof(__u32), reg); + break; + case NFT_META_IIFTYPE: + NFT_OFFLOAD_MATCH_EXACT(FLOW_DISSECTOR_KEY_META, meta, + ingress_iftype, sizeof(__u16), reg); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +bool nft_meta_get_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_meta *priv = nft_expr_priv(expr); + const struct nft_meta *meta; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + meta = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->key != meta->key || + priv->dreg != meta->dreg) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} +EXPORT_SYMBOL_GPL(nft_meta_get_reduce); + +static const struct nft_expr_ops nft_meta_get_ops = { + .type = &nft_meta_type, + 
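/*
 * Illustrative userspace sketch (not kernel code): a toy version of the
 * ->reduce() register tracking used by nft_meta_get_reduce() above and
 * nft_meta_set_reduce() below.  The tracker remembers which expression
 * last filled a register, so an identical second load can be dropped,
 * while an unrelated store cancels the tracking.  The real code compares
 * the expression's ops and parameters rather than bare pointers.
 */
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>

#define TOY_NUM_REGS 4

struct toy_track {
	const void *selector[TOY_NUM_REGS];	/* who last wrote each register */
};

/* returns true if a load by 'expr' into 'reg' is redundant */
static bool toy_track_load(struct toy_track *t, const void *expr, unsigned int reg)
{
	if (t->selector[reg] == expr)
		return true;
	t->selector[reg] = expr;
	return false;
}

static void toy_track_cancel(struct toy_track *t, unsigned int reg)
{
	t->selector[reg] = NULL;	/* e.g. after a meta set changed the source */
}

int main(void)
{
	struct toy_track t = { { NULL } };
	int meta_mark_expr;		/* stands in for a "meta mark" load */

	printf("%d ", toy_track_load(&t, &meta_mark_expr, 0));	/* 0: first load, kept */
	printf("%d ", toy_track_load(&t, &meta_mark_expr, 0));	/* 1: duplicate, dropped */
	toy_track_cancel(&t, 0);
	printf("%d\n", toy_track_load(&t, &meta_mark_expr, 0));	/* 0: must reload */
	return 0;
}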
.size = NFT_EXPR_SIZE(sizeof(struct nft_meta)), + .eval = nft_meta_get_eval, + .init = nft_meta_get_init, + .dump = nft_meta_get_dump, + .reduce = nft_meta_get_reduce, + .validate = nft_meta_get_validate, + .offload = nft_meta_get_offload, +}; + +static bool nft_meta_set_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + int i; + + for (i = 0; i < NFT_REG32_NUM; i++) { + if (!track->regs[i].selector) + continue; + + if (track->regs[i].selector->ops != &nft_meta_get_ops) + continue; + + __nft_reg_track_cancel(track, i); + } + + return false; +} + +static const struct nft_expr_ops nft_meta_set_ops = { + .type = &nft_meta_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_meta)), + .eval = nft_meta_set_eval, + .init = nft_meta_set_init, + .destroy = nft_meta_set_destroy, + .dump = nft_meta_set_dump, + .reduce = nft_meta_set_reduce, + .validate = nft_meta_set_validate, +}; + +static const struct nft_expr_ops * +nft_meta_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_META_KEY] == NULL) + return ERR_PTR(-EINVAL); + + if (tb[NFTA_META_DREG] && tb[NFTA_META_SREG]) + return ERR_PTR(-EINVAL); + +#if IS_ENABLED(CONFIG_NF_TABLES_BRIDGE) && IS_MODULE(CONFIG_NFT_BRIDGE_META) + if (ctx->family == NFPROTO_BRIDGE) + return ERR_PTR(-EAGAIN); +#endif + if (tb[NFTA_META_DREG]) + return &nft_meta_get_ops; + + if (tb[NFTA_META_SREG]) + return &nft_meta_set_ops; + + return ERR_PTR(-EINVAL); +} + +struct nft_expr_type nft_meta_type __read_mostly = { + .name = "meta", + .select_ops = nft_meta_select_ops, + .policy = nft_meta_policy, + .maxattr = NFTA_META_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NETWORK_SECMARK +struct nft_secmark { + u32 secid; + char *ctx; +}; + +static const struct nla_policy nft_secmark_policy[NFTA_SECMARK_MAX + 1] = { + [NFTA_SECMARK_CTX] = { .type = NLA_STRING, .len = NFT_SECMARK_CTX_MAXLEN }, +}; + +static int nft_secmark_compute_secid(struct nft_secmark *priv) +{ + u32 tmp_secid = 0; + int err; + + err = security_secctx_to_secid(priv->ctx, strlen(priv->ctx), &tmp_secid); + if (err) + return err; + + if (!tmp_secid) + return -ENOENT; + + err = security_secmark_relabel_packet(tmp_secid); + if (err) + return err; + + priv->secid = tmp_secid; + return 0; +} + +static void nft_secmark_obj_eval(struct nft_object *obj, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_secmark *priv = nft_obj_data(obj); + struct sk_buff *skb = pkt->skb; + + skb->secmark = priv->secid; +} + +static int nft_secmark_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_secmark *priv = nft_obj_data(obj); + int err; + + if (tb[NFTA_SECMARK_CTX] == NULL) + return -EINVAL; + + priv->ctx = nla_strdup(tb[NFTA_SECMARK_CTX], GFP_KERNEL); + if (!priv->ctx) + return -ENOMEM; + + err = nft_secmark_compute_secid(priv); + if (err) { + kfree(priv->ctx); + return err; + } + + security_secmark_refcount_inc(); + + return 0; +} + +static int nft_secmark_obj_dump(struct sk_buff *skb, struct nft_object *obj, + bool reset) +{ + struct nft_secmark *priv = nft_obj_data(obj); + int err; + + if (nla_put_string(skb, NFTA_SECMARK_CTX, priv->ctx)) + return -1; + + if (reset) { + err = nft_secmark_compute_secid(priv); + if (err) + return err; + } + + return 0; +} + +static void nft_secmark_obj_destroy(const struct nft_ctx *ctx, struct nft_object *obj) +{ + struct nft_secmark *priv = nft_obj_data(obj); + + security_secmark_refcount_dec(); + + kfree(priv->ctx); +} + +static const struct 
nft_object_ops nft_secmark_obj_ops = { + .type = &nft_secmark_obj_type, + .size = sizeof(struct nft_secmark), + .init = nft_secmark_obj_init, + .eval = nft_secmark_obj_eval, + .dump = nft_secmark_obj_dump, + .destroy = nft_secmark_obj_destroy, +}; +struct nft_object_type nft_secmark_obj_type __read_mostly = { + .type = NFT_OBJECT_SECMARK, + .ops = &nft_secmark_obj_ops, + .maxattr = NFTA_SECMARK_MAX, + .policy = nft_secmark_policy, + .owner = THIS_MODULE, +}; +#endif /* CONFIG_NETWORK_SECMARK */ diff --git a/net/netfilter/nft_nat.c b/net/netfilter/nft_nat.c new file mode 100644 index 000000000..ba7bcce72 --- /dev/null +++ b/net/netfilter/nft_nat.c @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2012 Pablo Neira Ayuso + * Copyright (c) 2012 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_nat { + u8 sreg_addr_min; + u8 sreg_addr_max; + u8 sreg_proto_min; + u8 sreg_proto_max; + enum nf_nat_manip_type type:8; + u8 family; + u16 flags; +}; + +static void nft_nat_setup_addr(struct nf_nat_range2 *range, + const struct nft_regs *regs, + const struct nft_nat *priv) +{ + switch (priv->family) { + case AF_INET: + range->min_addr.ip = (__force __be32) + regs->data[priv->sreg_addr_min]; + range->max_addr.ip = (__force __be32) + regs->data[priv->sreg_addr_max]; + break; + case AF_INET6: + memcpy(range->min_addr.ip6, ®s->data[priv->sreg_addr_min], + sizeof(range->min_addr.ip6)); + memcpy(range->max_addr.ip6, ®s->data[priv->sreg_addr_max], + sizeof(range->max_addr.ip6)); + break; + } +} + +static void nft_nat_setup_proto(struct nf_nat_range2 *range, + const struct nft_regs *regs, + const struct nft_nat *priv) +{ + range->min_proto.all = (__force __be16) + nft_reg_load16(®s->data[priv->sreg_proto_min]); + range->max_proto.all = (__force __be16) + nft_reg_load16(®s->data[priv->sreg_proto_max]); +} + +static void nft_nat_setup_netmap(struct nf_nat_range2 *range, + const struct nft_pktinfo *pkt, + const struct nft_nat *priv) +{ + struct sk_buff *skb = pkt->skb; + union nf_inet_addr new_addr; + __be32 netmask; + int i, len = 0; + + switch (priv->type) { + case NFT_NAT_SNAT: + if (nft_pf(pkt) == NFPROTO_IPV4) { + new_addr.ip = ip_hdr(skb)->saddr; + len = sizeof(struct in_addr); + } else { + new_addr.in6 = ipv6_hdr(skb)->saddr; + len = sizeof(struct in6_addr); + } + break; + case NFT_NAT_DNAT: + if (nft_pf(pkt) == NFPROTO_IPV4) { + new_addr.ip = ip_hdr(skb)->daddr; + len = sizeof(struct in_addr); + } else { + new_addr.in6 = ipv6_hdr(skb)->daddr; + len = sizeof(struct in6_addr); + } + break; + } + + for (i = 0; i < len / sizeof(__be32); i++) { + netmask = ~(range->min_addr.ip6[i] ^ range->max_addr.ip6[i]); + new_addr.ip6[i] &= ~netmask; + new_addr.ip6[i] |= range->min_addr.ip6[i] & netmask; + } + + range->min_addr = new_addr; + range->max_addr = new_addr; +} + +static void nft_nat_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_nat *priv = nft_expr_priv(expr); + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(pkt->skb, &ctinfo); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + + if (priv->sreg_addr_min) { + nft_nat_setup_addr(&range, regs, priv); + if (priv->flags & NF_NAT_RANGE_NETMAP) + nft_nat_setup_netmap(&range, pkt, priv); + } + + if (priv->sreg_proto_min) + nft_nat_setup_proto(&range, regs, priv); + + 
range.flags = priv->flags; + + regs->verdict.code = nf_nat_setup_info(ct, &range, priv->type); +} + +static const struct nla_policy nft_nat_policy[NFTA_NAT_MAX + 1] = { + [NFTA_NAT_TYPE] = { .type = NLA_U32 }, + [NFTA_NAT_FAMILY] = { .type = NLA_U32 }, + [NFTA_NAT_REG_ADDR_MIN] = { .type = NLA_U32 }, + [NFTA_NAT_REG_ADDR_MAX] = { .type = NLA_U32 }, + [NFTA_NAT_REG_PROTO_MIN] = { .type = NLA_U32 }, + [NFTA_NAT_REG_PROTO_MAX] = { .type = NLA_U32 }, + [NFTA_NAT_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_nat_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + struct nft_nat *priv = nft_expr_priv(expr); + int err; + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + err = nft_chain_validate_dependency(ctx->chain, NFT_CHAIN_T_NAT); + if (err < 0) + return err; + + switch (priv->type) { + case NFT_NAT_SNAT: + err = nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN)); + break; + case NFT_NAT_DNAT: + err = nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT)); + break; + } + + return err; +} + +static int nft_nat_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_nat *priv = nft_expr_priv(expr); + unsigned int alen, plen; + u32 family; + int err; + + if (tb[NFTA_NAT_TYPE] == NULL || + (tb[NFTA_NAT_REG_ADDR_MIN] == NULL && + tb[NFTA_NAT_REG_PROTO_MIN] == NULL)) + return -EINVAL; + + switch (ntohl(nla_get_be32(tb[NFTA_NAT_TYPE]))) { + case NFT_NAT_SNAT: + priv->type = NF_NAT_MANIP_SRC; + break; + case NFT_NAT_DNAT: + priv->type = NF_NAT_MANIP_DST; + break; + default: + return -EOPNOTSUPP; + } + + if (tb[NFTA_NAT_FAMILY] == NULL) + return -EINVAL; + + family = ntohl(nla_get_be32(tb[NFTA_NAT_FAMILY])); + if (ctx->family != NFPROTO_INET && ctx->family != family) + return -EOPNOTSUPP; + + switch (family) { + case NFPROTO_IPV4: + alen = sizeof_field(struct nf_nat_range, min_addr.ip); + break; + case NFPROTO_IPV6: + alen = sizeof_field(struct nf_nat_range, min_addr.ip6); + break; + default: + if (tb[NFTA_NAT_REG_ADDR_MIN]) + return -EAFNOSUPPORT; + break; + } + priv->family = family; + + if (tb[NFTA_NAT_REG_ADDR_MIN]) { + err = nft_parse_register_load(tb[NFTA_NAT_REG_ADDR_MIN], + &priv->sreg_addr_min, alen); + if (err < 0) + return err; + + if (tb[NFTA_NAT_REG_ADDR_MAX]) { + err = nft_parse_register_load(tb[NFTA_NAT_REG_ADDR_MAX], + &priv->sreg_addr_max, + alen); + if (err < 0) + return err; + } else { + priv->sreg_addr_max = priv->sreg_addr_min; + } + + priv->flags |= NF_NAT_RANGE_MAP_IPS; + } + + plen = sizeof_field(struct nf_nat_range, min_proto.all); + if (tb[NFTA_NAT_REG_PROTO_MIN]) { + err = nft_parse_register_load(tb[NFTA_NAT_REG_PROTO_MIN], + &priv->sreg_proto_min, plen); + if (err < 0) + return err; + + if (tb[NFTA_NAT_REG_PROTO_MAX]) { + err = nft_parse_register_load(tb[NFTA_NAT_REG_PROTO_MAX], + &priv->sreg_proto_max, + plen); + if (err < 0) + return err; + } else { + priv->sreg_proto_max = priv->sreg_proto_min; + } + + priv->flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + + if (tb[NFTA_NAT_FLAGS]) { + priv->flags |= ntohl(nla_get_be32(tb[NFTA_NAT_FLAGS])); + if (priv->flags & ~NF_NAT_RANGE_MASK) + return -EOPNOTSUPP; + } + + return nf_ct_netns_get(ctx->net, family); +} + +static int nft_nat_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_nat *priv = nft_expr_priv(expr); + + switch 
(priv->type) { + case NF_NAT_MANIP_SRC: + if (nla_put_be32(skb, NFTA_NAT_TYPE, htonl(NFT_NAT_SNAT))) + goto nla_put_failure; + break; + case NF_NAT_MANIP_DST: + if (nla_put_be32(skb, NFTA_NAT_TYPE, htonl(NFT_NAT_DNAT))) + goto nla_put_failure; + break; + } + + if (nla_put_be32(skb, NFTA_NAT_FAMILY, htonl(priv->family))) + goto nla_put_failure; + + if (priv->sreg_addr_min) { + if (nft_dump_register(skb, NFTA_NAT_REG_ADDR_MIN, + priv->sreg_addr_min) || + nft_dump_register(skb, NFTA_NAT_REG_ADDR_MAX, + priv->sreg_addr_max)) + goto nla_put_failure; + } + + if (priv->sreg_proto_min) { + if (nft_dump_register(skb, NFTA_NAT_REG_PROTO_MIN, + priv->sreg_proto_min) || + nft_dump_register(skb, NFTA_NAT_REG_PROTO_MAX, + priv->sreg_proto_max)) + goto nla_put_failure; + } + + if (priv->flags != 0) { + if (nla_put_be32(skb, NFTA_NAT_FLAGS, htonl(priv->flags))) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -1; +} + +static void +nft_nat_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + const struct nft_nat *priv = nft_expr_priv(expr); + + nf_ct_netns_put(ctx->net, priv->family); +} + +static struct nft_expr_type nft_nat_type; +static const struct nft_expr_ops nft_nat_ops = { + .type = &nft_nat_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_nat)), + .eval = nft_nat_eval, + .init = nft_nat_init, + .destroy = nft_nat_destroy, + .dump = nft_nat_dump, + .validate = nft_nat_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_nat_type __read_mostly = { + .name = "nat", + .ops = &nft_nat_ops, + .policy = nft_nat_policy, + .maxattr = NFTA_NAT_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_TABLES_INET +static void nft_nat_inet_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_nat *priv = nft_expr_priv(expr); + + if (priv->family == nft_pf(pkt) || + priv->family == NFPROTO_INET) + nft_nat_eval(expr, regs, pkt); +} + +static const struct nft_expr_ops nft_nat_inet_ops = { + .type = &nft_nat_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_nat)), + .eval = nft_nat_inet_eval, + .init = nft_nat_init, + .destroy = nft_nat_destroy, + .dump = nft_nat_dump, + .validate = nft_nat_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_inet_nat_type __read_mostly = { + .name = "nat", + .family = NFPROTO_INET, + .ops = &nft_nat_inet_ops, + .policy = nft_nat_policy, + .maxattr = NFTA_NAT_MAX, + .owner = THIS_MODULE, +}; + +static int nft_nat_inet_module_init(void) +{ + return nft_register_expr(&nft_inet_nat_type); +} + +static void nft_nat_inet_module_exit(void) +{ + nft_unregister_expr(&nft_inet_nat_type); +} +#else +static int nft_nat_inet_module_init(void) { return 0; } +static void nft_nat_inet_module_exit(void) { } +#endif + +static int __init nft_nat_module_init(void) +{ + int ret = nft_nat_inet_module_init(); + + if (ret) + return ret; + + ret = nft_register_expr(&nft_nat_type); + if (ret) + nft_nat_inet_module_exit(); + + return ret; +} + +static void __exit nft_nat_module_exit(void) +{ + nft_nat_inet_module_exit(); + nft_unregister_expr(&nft_nat_type); +} + +module_init(nft_nat_module_init); +module_exit(nft_nat_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Tomasz Bursztyka "); +MODULE_ALIAS_NFT_EXPR("nat"); +MODULE_DESCRIPTION("Network Address Translation support"); diff --git a/net/netfilter/nft_numgen.c b/net/netfilter/nft_numgen.c new file mode 100644 index 000000000..45d3dc9e9 --- /dev/null +++ b/net/netfilter/nft_numgen.c @@ -0,0 +1,255 
@@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2016 Laura Garcia + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_ng_inc { + u8 dreg; + u32 modulus; + atomic_t *counter; + u32 offset; +}; + +static u32 nft_ng_inc_gen(struct nft_ng_inc *priv) +{ + u32 nval, oval; + + do { + oval = atomic_read(priv->counter); + nval = (oval + 1 < priv->modulus) ? oval + 1 : 0; + } while (atomic_cmpxchg(priv->counter, oval, nval) != oval); + + return nval + priv->offset; +} + +static void nft_ng_inc_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_ng_inc *priv = nft_expr_priv(expr); + + regs->data[priv->dreg] = nft_ng_inc_gen(priv); +} + +static const struct nla_policy nft_ng_policy[NFTA_NG_MAX + 1] = { + [NFTA_NG_DREG] = { .type = NLA_U32 }, + [NFTA_NG_MODULUS] = { .type = NLA_U32 }, + [NFTA_NG_TYPE] = { .type = NLA_U32 }, + [NFTA_NG_OFFSET] = { .type = NLA_U32 }, +}; + +static int nft_ng_inc_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_ng_inc *priv = nft_expr_priv(expr); + int err; + + if (tb[NFTA_NG_OFFSET]) + priv->offset = ntohl(nla_get_be32(tb[NFTA_NG_OFFSET])); + + priv->modulus = ntohl(nla_get_be32(tb[NFTA_NG_MODULUS])); + if (priv->modulus == 0) + return -ERANGE; + + if (priv->offset + priv->modulus - 1 < priv->offset) + return -EOVERFLOW; + + priv->counter = kmalloc(sizeof(*priv->counter), GFP_KERNEL); + if (!priv->counter) + return -ENOMEM; + + atomic_set(priv->counter, priv->modulus - 1); + + err = nft_parse_register_store(ctx, tb[NFTA_NG_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, sizeof(u32)); + if (err < 0) + goto err; + + return 0; +err: + kfree(priv->counter); + + return err; +} + +static bool nft_ng_inc_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_ng_inc *priv = nft_expr_priv(expr); + + nft_reg_track_cancel(track, priv->dreg, NFT_REG32_SIZE); + + return false; +} + +static int nft_ng_dump(struct sk_buff *skb, enum nft_registers dreg, + u32 modulus, enum nft_ng_types type, u32 offset) +{ + if (nft_dump_register(skb, NFTA_NG_DREG, dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_NG_MODULUS, htonl(modulus))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_NG_TYPE, htonl(type))) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_NG_OFFSET, htonl(offset))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_ng_inc_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_ng_inc *priv = nft_expr_priv(expr); + + return nft_ng_dump(skb, priv->dreg, priv->modulus, NFT_NG_INCREMENTAL, + priv->offset); +} + +static void nft_ng_inc_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_ng_inc *priv = nft_expr_priv(expr); + + kfree(priv->counter); +} + +struct nft_ng_random { + u8 dreg; + u32 modulus; + u32 offset; +}; + +static u32 nft_ng_random_gen(const struct nft_ng_random *priv) +{ + return reciprocal_scale(get_random_u32(), priv->modulus) + priv->offset; +} + +static void nft_ng_random_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_ng_random *priv = nft_expr_priv(expr); + + regs->data[priv->dreg] = nft_ng_random_gen(priv); +} + +static int nft_ng_random_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct 
nft_ng_random *priv = nft_expr_priv(expr); + + if (tb[NFTA_NG_OFFSET]) + priv->offset = ntohl(nla_get_be32(tb[NFTA_NG_OFFSET])); + + priv->modulus = ntohl(nla_get_be32(tb[NFTA_NG_MODULUS])); + if (priv->modulus == 0) + return -ERANGE; + + if (priv->offset + priv->modulus - 1 < priv->offset) + return -EOVERFLOW; + + return nft_parse_register_store(ctx, tb[NFTA_NG_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, sizeof(u32)); +} + +static int nft_ng_random_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_ng_random *priv = nft_expr_priv(expr); + + return nft_ng_dump(skb, priv->dreg, priv->modulus, NFT_NG_RANDOM, + priv->offset); +} + +static bool nft_ng_random_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_ng_random *priv = nft_expr_priv(expr); + + nft_reg_track_cancel(track, priv->dreg, NFT_REG32_SIZE); + + return false; +} + +static struct nft_expr_type nft_ng_type; +static const struct nft_expr_ops nft_ng_inc_ops = { + .type = &nft_ng_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_ng_inc)), + .eval = nft_ng_inc_eval, + .init = nft_ng_inc_init, + .destroy = nft_ng_inc_destroy, + .dump = nft_ng_inc_dump, + .reduce = nft_ng_inc_reduce, +}; + +static const struct nft_expr_ops nft_ng_random_ops = { + .type = &nft_ng_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_ng_random)), + .eval = nft_ng_random_eval, + .init = nft_ng_random_init, + .dump = nft_ng_random_dump, + .reduce = nft_ng_random_reduce, +}; + +static const struct nft_expr_ops * +nft_ng_select_ops(const struct nft_ctx *ctx, const struct nlattr * const tb[]) +{ + u32 type; + + if (!tb[NFTA_NG_DREG] || + !tb[NFTA_NG_MODULUS] || + !tb[NFTA_NG_TYPE]) + return ERR_PTR(-EINVAL); + + type = ntohl(nla_get_be32(tb[NFTA_NG_TYPE])); + + switch (type) { + case NFT_NG_INCREMENTAL: + return &nft_ng_inc_ops; + case NFT_NG_RANDOM: + return &nft_ng_random_ops; + } + + return ERR_PTR(-EINVAL); +} + +static struct nft_expr_type nft_ng_type __read_mostly = { + .name = "numgen", + .select_ops = nft_ng_select_ops, + .policy = nft_ng_policy, + .maxattr = NFTA_NG_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_ng_module_init(void) +{ + return nft_register_expr(&nft_ng_type); +} + +static void __exit nft_ng_module_exit(void) +{ + nft_unregister_expr(&nft_ng_type); +} + +module_init(nft_ng_module_init); +module_exit(nft_ng_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Laura Garcia "); +MODULE_ALIAS_NFT_EXPR("numgen"); +MODULE_DESCRIPTION("nftables number generator module"); diff --git a/net/netfilter/nft_objref.c b/net/netfilter/nft_objref.c new file mode 100644 index 000000000..0017bd341 --- /dev/null +++ b/net/netfilter/nft_objref.c @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2012-2016 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include + +#define nft_objref_priv(expr) *((struct nft_object **)nft_expr_priv(expr)) + +static void nft_objref_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_object *obj = nft_objref_priv(expr); + + obj->ops->eval(obj, regs, pkt); +} + +static int nft_objref_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_object *obj = nft_objref_priv(expr); + u8 genmask = nft_genmask_next(ctx->net); + u32 objtype; + + if (!tb[NFTA_OBJREF_IMM_NAME] || + !tb[NFTA_OBJREF_IMM_TYPE]) + return -EINVAL; + + objtype = ntohl(nla_get_be32(tb[NFTA_OBJREF_IMM_TYPE])); + 
obj = nft_obj_lookup(ctx->net, ctx->table, + tb[NFTA_OBJREF_IMM_NAME], objtype, + genmask); + if (IS_ERR(obj)) + return -ENOENT; + + if (!nft_use_inc(&obj->use)) + return -EMFILE; + + nft_objref_priv(expr) = obj; + + return 0; +} + +static int nft_objref_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_object *obj = nft_objref_priv(expr); + + if (nla_put_string(skb, NFTA_OBJREF_IMM_NAME, obj->key.name) || + nla_put_be32(skb, NFTA_OBJREF_IMM_TYPE, + htonl(obj->ops->type->type))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_objref_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) +{ + struct nft_object *obj = nft_objref_priv(expr); + + if (phase == NFT_TRANS_COMMIT) + return; + + nft_use_dec(&obj->use); +} + +static void nft_objref_activate(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_object *obj = nft_objref_priv(expr); + + nft_use_inc_restore(&obj->use); +} + +static struct nft_expr_type nft_objref_type; +static const struct nft_expr_ops nft_objref_ops = { + .type = &nft_objref_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_object *)), + .eval = nft_objref_eval, + .init = nft_objref_init, + .activate = nft_objref_activate, + .deactivate = nft_objref_deactivate, + .dump = nft_objref_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +struct nft_objref_map { + struct nft_set *set; + u8 sreg; + struct nft_set_binding binding; +}; + +static void nft_objref_map_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + const struct nft_set *set = priv->set; + struct net *net = nft_net(pkt); + const struct nft_set_ext *ext; + struct nft_object *obj; + bool found; + + found = nft_set_do_lookup(net, set, ®s->data[priv->sreg], &ext); + if (!found) { + ext = nft_set_catchall_lookup(net, set); + if (!ext) { + regs->verdict.code = NFT_BREAK; + return; + } + } + obj = *nft_set_ext_obj(ext); + obj->ops->eval(obj, regs, pkt); +} + +static int nft_objref_map_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + u8 genmask = nft_genmask_next(ctx->net); + struct nft_set *set; + int err; + + set = nft_set_lookup_global(ctx->net, ctx->table, + tb[NFTA_OBJREF_SET_NAME], + tb[NFTA_OBJREF_SET_ID], genmask); + if (IS_ERR(set)) + return PTR_ERR(set); + + if (!(set->flags & NFT_SET_OBJECT)) + return -EINVAL; + + err = nft_parse_register_load(tb[NFTA_OBJREF_SET_SREG], &priv->sreg, + set->klen); + if (err < 0) + return err; + + priv->binding.flags = set->flags & NFT_SET_OBJECT; + + err = nf_tables_bind_set(ctx, set, &priv->binding); + if (err < 0) + return err; + + priv->set = set; + return 0; +} + +static int nft_objref_map_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_objref_map *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_OBJREF_SET_SREG, priv->sreg) || + nla_put_string(skb, NFTA_OBJREF_SET_NAME, priv->set->name)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_objref_map_deactivate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + enum nft_trans_phase phase) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + + nf_tables_deactivate_set(ctx, priv->set, &priv->binding, phase); +} + +static void nft_objref_map_activate(const struct nft_ctx *ctx, + const struct nft_expr 
*expr) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + + nf_tables_activate_set(ctx, priv->set); +} + +static void nft_objref_map_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_objref_map *priv = nft_expr_priv(expr); + + nf_tables_destroy_set(ctx, priv->set); +} + +static struct nft_expr_type nft_objref_type; +static const struct nft_expr_ops nft_objref_map_ops = { + .type = &nft_objref_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_objref_map)), + .eval = nft_objref_map_eval, + .init = nft_objref_map_init, + .activate = nft_objref_map_activate, + .deactivate = nft_objref_map_deactivate, + .destroy = nft_objref_map_destroy, + .dump = nft_objref_map_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops * +nft_objref_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_OBJREF_SET_SREG] && + (tb[NFTA_OBJREF_SET_NAME] || + tb[NFTA_OBJREF_SET_ID])) + return &nft_objref_map_ops; + else if (tb[NFTA_OBJREF_IMM_NAME] && + tb[NFTA_OBJREF_IMM_TYPE]) + return &nft_objref_ops; + + return ERR_PTR(-EOPNOTSUPP); +} + +static const struct nla_policy nft_objref_policy[NFTA_OBJREF_MAX + 1] = { + [NFTA_OBJREF_IMM_NAME] = { .type = NLA_STRING, + .len = NFT_OBJ_MAXNAMELEN - 1 }, + [NFTA_OBJREF_IMM_TYPE] = { .type = NLA_U32 }, + [NFTA_OBJREF_SET_SREG] = { .type = NLA_U32 }, + [NFTA_OBJREF_SET_NAME] = { .type = NLA_STRING, + .len = NFT_SET_MAXNAMELEN - 1 }, + [NFTA_OBJREF_SET_ID] = { .type = NLA_U32 }, +}; + +static struct nft_expr_type nft_objref_type __read_mostly = { + .name = "objref", + .select_ops = nft_objref_select_ops, + .policy = nft_objref_policy, + .maxattr = NFTA_OBJREF_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_objref_module_init(void) +{ + return nft_register_expr(&nft_objref_type); +} + +static void __exit nft_objref_module_exit(void) +{ + nft_unregister_expr(&nft_objref_type); +} + +module_init(nft_objref_module_init); +module_exit(nft_objref_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_EXPR("objref"); +MODULE_DESCRIPTION("nftables stateful object reference module"); diff --git a/net/netfilter/nft_osf.c b/net/netfilter/nft_osf.c new file mode 100644 index 000000000..adacf95b6 --- /dev/null +++ b/net/netfilter/nft_osf.c @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include + +#include +#include + +struct nft_osf { + u8 dreg; + u8 ttl; + u32 flags; +}; + +static const struct nla_policy nft_osf_policy[NFTA_OSF_MAX + 1] = { + [NFTA_OSF_DREG] = { .type = NLA_U32 }, + [NFTA_OSF_TTL] = { .type = NLA_U8 }, + [NFTA_OSF_FLAGS] = { .type = NLA_U32 }, +}; + +static void nft_osf_eval(const struct nft_expr *expr, struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_osf *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + struct sk_buff *skb = pkt->skb; + char os_match[NFT_OSF_MAXGENRELEN + 1]; + const struct tcphdr *tcp; + struct nf_osf_data data; + struct tcphdr _tcph; + + if (pkt->tprot != IPPROTO_TCP) { + regs->verdict.code = NFT_BREAK; + return; + } + + tcp = skb_header_pointer(skb, ip_hdrlen(skb), + sizeof(struct tcphdr), &_tcph); + if (!tcp) { + regs->verdict.code = NFT_BREAK; + return; + } + if (!tcp->syn) { + regs->verdict.code = NFT_BREAK; + return; + } + + if (!nf_osf_find(skb, nf_osf_fingers, priv->ttl, &data)) { + strncpy((char *)dest, "unknown", NFT_OSF_MAXGENRELEN); + } else { + if (priv->flags & NFT_OSF_F_VERSION) + snprintf(os_match, NFT_OSF_MAXGENRELEN, "%s:%s", + 
data.genre, data.version); + else + strscpy(os_match, data.genre, NFT_OSF_MAXGENRELEN); + + strncpy((char *)dest, os_match, NFT_OSF_MAXGENRELEN); + } +} + +static int nft_osf_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_osf *priv = nft_expr_priv(expr); + u32 flags; + int err; + u8 ttl; + + if (!tb[NFTA_OSF_DREG]) + return -EINVAL; + + if (tb[NFTA_OSF_TTL]) { + ttl = nla_get_u8(tb[NFTA_OSF_TTL]); + if (ttl > 2) + return -EINVAL; + priv->ttl = ttl; + } + + if (tb[NFTA_OSF_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFTA_OSF_FLAGS])); + if (flags != NFT_OSF_F_VERSION) + return -EINVAL; + priv->flags = flags; + } + + err = nft_parse_register_store(ctx, tb[NFTA_OSF_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, + NFT_OSF_MAXGENRELEN); + if (err < 0) + return err; + + return 0; +} + +static int nft_osf_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_osf *priv = nft_expr_priv(expr); + + if (nla_put_u8(skb, NFTA_OSF_TTL, priv->ttl)) + goto nla_put_failure; + + if (nla_put_u32(skb, NFTA_OSF_FLAGS, ntohl((__force __be32)priv->flags))) + goto nla_put_failure; + + if (nft_dump_register(skb, NFTA_OSF_DREG, priv->dreg)) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int nft_osf_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + unsigned int hooks; + + switch (ctx->family) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + hooks = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_FORWARD); + break; + default: + return -EOPNOTSUPP; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} + +static bool nft_osf_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + struct nft_osf *priv = nft_expr_priv(expr); + struct nft_osf *osf; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, NFT_OSF_MAXGENRELEN); + return false; + } + + osf = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->flags != osf->flags || + priv->ttl != osf->ttl) { + nft_reg_track_update(track, expr, priv->dreg, NFT_OSF_MAXGENRELEN); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return false; +} + +static struct nft_expr_type nft_osf_type; +static const struct nft_expr_ops nft_osf_op = { + .eval = nft_osf_eval, + .size = NFT_EXPR_SIZE(sizeof(struct nft_osf)), + .init = nft_osf_init, + .dump = nft_osf_dump, + .type = &nft_osf_type, + .validate = nft_osf_validate, + .reduce = nft_osf_reduce, +}; + +static struct nft_expr_type nft_osf_type __read_mostly = { + .ops = &nft_osf_op, + .name = "osf", + .owner = THIS_MODULE, + .policy = nft_osf_policy, + .maxattr = NFTA_OSF_MAX, +}; + +static int __init nft_osf_module_init(void) +{ + return nft_register_expr(&nft_osf_type); +} + +static void __exit nft_osf_module_exit(void) +{ + return nft_unregister_expr(&nft_osf_type); +} + +module_init(nft_osf_module_init); +module_exit(nft_osf_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Fernando Fernandez "); +MODULE_ALIAS_NFT_EXPR("osf"); +MODULE_DESCRIPTION("nftables passive OS fingerprint support"); diff --git a/net/netfilter/nft_payload.c b/net/netfilter/nft_payload.c new file mode 100644 index 000000000..f44f2eaf3 --- /dev/null +++ b/net/netfilter/nft_payload.c @@ -0,0 +1,891 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2016 Pablo Neira Ayuso 
+ * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +/* For layer 4 checksum field offset. */ +#include +#include +#include +#include +#include +#include + +static bool nft_payload_rebuild_vlan_hdr(const struct sk_buff *skb, int mac_off, + struct vlan_ethhdr *veth) +{ + if (skb_copy_bits(skb, mac_off, veth, ETH_HLEN)) + return false; + + veth->h_vlan_proto = skb->vlan_proto; + veth->h_vlan_TCI = htons(skb_vlan_tag_get(skb)); + veth->h_vlan_encapsulated_proto = skb->protocol; + + return true; +} + +/* add vlan header into the user buffer for if tag was removed by offloads */ +static bool +nft_payload_copy_vlan(u32 *d, const struct sk_buff *skb, u8 offset, u8 len) +{ + int mac_off = skb_mac_header(skb) - skb->data; + u8 *vlanh, *dst_u8 = (u8 *) d; + struct vlan_ethhdr veth; + u8 vlan_hlen = 0; + + if ((skb->protocol == htons(ETH_P_8021AD) || + skb->protocol == htons(ETH_P_8021Q)) && + offset >= VLAN_ETH_HLEN && offset < VLAN_ETH_HLEN + VLAN_HLEN) + vlan_hlen += VLAN_HLEN; + + vlanh = (u8 *) &veth; + if (offset < VLAN_ETH_HLEN + vlan_hlen) { + u8 ethlen = len; + + if (vlan_hlen && + skb_copy_bits(skb, mac_off, &veth, VLAN_ETH_HLEN) < 0) + return false; + else if (!nft_payload_rebuild_vlan_hdr(skb, mac_off, &veth)) + return false; + + if (offset + len > VLAN_ETH_HLEN + vlan_hlen) + ethlen -= offset + len - VLAN_ETH_HLEN - vlan_hlen; + + memcpy(dst_u8, vlanh + offset - vlan_hlen, ethlen); + + len -= ethlen; + if (len == 0) + return true; + + dst_u8 += ethlen; + offset = ETH_HLEN + vlan_hlen; + } else { + offset -= VLAN_HLEN + vlan_hlen; + } + + return skb_copy_bits(skb, offset + mac_off, dst_u8, len) == 0; +} + +static int __nft_payload_inner_offset(struct nft_pktinfo *pkt) +{ + unsigned int thoff = nft_thoff(pkt); + + if (!(pkt->flags & NFT_PKTINFO_L4PROTO) || pkt->fragoff) + return -1; + + switch (pkt->tprot) { + case IPPROTO_UDP: + pkt->inneroff = thoff + sizeof(struct udphdr); + break; + case IPPROTO_TCP: { + struct tcphdr *th, _tcph; + + th = skb_header_pointer(pkt->skb, thoff, sizeof(_tcph), &_tcph); + if (!th) + return -1; + + pkt->inneroff = thoff + __tcp_hdrlen(th); + } + break; + default: + return -1; + } + + pkt->flags |= NFT_PKTINFO_INNER; + + return 0; +} + +static int nft_payload_inner_offset(const struct nft_pktinfo *pkt) +{ + if (!(pkt->flags & NFT_PKTINFO_INNER) && + __nft_payload_inner_offset((struct nft_pktinfo *)pkt) < 0) + return -1; + + return pkt->inneroff; +} + +void nft_payload_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + const struct sk_buff *skb = pkt->skb; + u32 *dest = ®s->data[priv->dreg]; + int offset; + + if (priv->len % NFT_REG32_SIZE) + dest[priv->len / NFT_REG32_SIZE] = 0; + + switch (priv->base) { + case NFT_PAYLOAD_LL_HEADER: + if (!skb_mac_header_was_set(skb) || skb_mac_header_len(skb) == 0) + goto err; + + if (skb_vlan_tag_present(skb)) { + if (!nft_payload_copy_vlan(dest, skb, + priv->offset, priv->len)) + goto err; + return; + } + offset = skb_mac_header(skb) - skb->data; + break; + case NFT_PAYLOAD_NETWORK_HEADER: + offset = skb_network_offset(skb); + break; + case NFT_PAYLOAD_TRANSPORT_HEADER: + if (!(pkt->flags & NFT_PKTINFO_L4PROTO) || pkt->fragoff) + goto err; + offset = nft_thoff(pkt); + break; + case NFT_PAYLOAD_INNER_HEADER: + offset = nft_payload_inner_offset(pkt); + if (offset < 0) + goto err; + break; + 
default: + WARN_ON_ONCE(1); + goto err; + } + offset += priv->offset; + + if (skb_copy_bits(skb, offset, dest, priv->len) < 0) + goto err; + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_payload_policy[NFTA_PAYLOAD_MAX + 1] = { + [NFTA_PAYLOAD_SREG] = { .type = NLA_U32 }, + [NFTA_PAYLOAD_DREG] = { .type = NLA_U32 }, + [NFTA_PAYLOAD_BASE] = { .type = NLA_U32 }, + [NFTA_PAYLOAD_OFFSET] = NLA_POLICY_MAX(NLA_BE32, 255), + [NFTA_PAYLOAD_LEN] = NLA_POLICY_MAX(NLA_BE32, 255), + [NFTA_PAYLOAD_CSUM_TYPE] = { .type = NLA_U32 }, + [NFTA_PAYLOAD_CSUM_OFFSET] = NLA_POLICY_MAX(NLA_BE32, 255), + [NFTA_PAYLOAD_CSUM_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_payload_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_payload *priv = nft_expr_priv(expr); + + priv->base = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_BASE])); + priv->offset = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_OFFSET])); + priv->len = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_LEN])); + + return nft_parse_register_store(ctx, tb[NFTA_PAYLOAD_DREG], + &priv->dreg, NULL, NFT_DATA_VALUE, + priv->len); +} + +static int nft_payload_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_PAYLOAD_DREG, priv->dreg) || + nla_put_be32(skb, NFTA_PAYLOAD_BASE, htonl(priv->base)) || + nla_put_be32(skb, NFTA_PAYLOAD_OFFSET, htonl(priv->offset)) || + nla_put_be32(skb, NFTA_PAYLOAD_LEN, htonl(priv->len))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_payload_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + const struct nft_payload *payload; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + payload = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->base != payload->base || + priv->offset != payload->offset || + priv->len != payload->len) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} + +static bool nft_payload_offload_mask(struct nft_offload_reg *reg, + u32 priv_len, u32 field_len) +{ + unsigned int remainder, delta, k; + struct nft_data mask = {}; + __be32 remainder_mask; + + if (priv_len == field_len) { + memset(®->mask, 0xff, priv_len); + return true; + } else if (priv_len > field_len) { + return false; + } + + memset(&mask, 0xff, field_len); + remainder = priv_len % sizeof(u32); + if (remainder) { + k = priv_len / sizeof(u32); + delta = field_len - priv_len; + remainder_mask = htonl(~((1 << (delta * BITS_PER_BYTE)) - 1)); + mask.data[k] = (__force u32)remainder_mask; + } + + memcpy(®->mask, &mask, field_len); + + return true; +} + +static int nft_payload_offload_ll(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->offset) { + case offsetof(struct ethhdr, h_source): + if (!nft_payload_offload_mask(reg, priv->len, ETH_ALEN)) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_ETH_ADDRS, eth_addrs, + src, ETH_ALEN, reg); + break; + case offsetof(struct ethhdr, h_dest): + if (!nft_payload_offload_mask(reg, priv->len, ETH_ALEN)) + return -EOPNOTSUPP; + + 
NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_ETH_ADDRS, eth_addrs, + dst, ETH_ALEN, reg); + break; + case offsetof(struct ethhdr, h_proto): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_BASIC, basic, + n_proto, sizeof(__be16), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_NETWORK); + break; + case offsetof(struct vlan_ethhdr, h_vlan_TCI): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH_FLAGS(FLOW_DISSECTOR_KEY_VLAN, vlan, + vlan_tci, sizeof(__be16), reg, + NFT_OFFLOAD_F_NETWORK2HOST); + break; + case offsetof(struct vlan_ethhdr, h_vlan_encapsulated_proto): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_VLAN, vlan, + vlan_tpid, sizeof(__be16), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_NETWORK); + break; + case offsetof(struct vlan_ethhdr, h_vlan_TCI) + sizeof(struct vlan_hdr): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH_FLAGS(FLOW_DISSECTOR_KEY_CVLAN, cvlan, + vlan_tci, sizeof(__be16), reg, + NFT_OFFLOAD_F_NETWORK2HOST); + break; + case offsetof(struct vlan_ethhdr, h_vlan_encapsulated_proto) + + sizeof(struct vlan_hdr): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_CVLAN, cvlan, + vlan_tpid, sizeof(__be16), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_NETWORK); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_payload_offload_ip(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->offset) { + case offsetof(struct iphdr, saddr): + if (!nft_payload_offload_mask(reg, priv->len, + sizeof(struct in_addr))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_IPV4_ADDRS, ipv4, src, + sizeof(struct in_addr), reg); + nft_flow_rule_set_addr_type(flow, FLOW_DISSECTOR_KEY_IPV4_ADDRS); + break; + case offsetof(struct iphdr, daddr): + if (!nft_payload_offload_mask(reg, priv->len, + sizeof(struct in_addr))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_IPV4_ADDRS, ipv4, dst, + sizeof(struct in_addr), reg); + nft_flow_rule_set_addr_type(flow, FLOW_DISSECTOR_KEY_IPV4_ADDRS); + break; + case offsetof(struct iphdr, protocol): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__u8))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_BASIC, basic, ip_proto, + sizeof(__u8), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_TRANSPORT); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_payload_offload_ip6(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->offset) { + case offsetof(struct ipv6hdr, saddr): + if (!nft_payload_offload_mask(reg, priv->len, + sizeof(struct in6_addr))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_IPV6_ADDRS, ipv6, src, + sizeof(struct in6_addr), reg); + nft_flow_rule_set_addr_type(flow, FLOW_DISSECTOR_KEY_IPV6_ADDRS); + break; + case offsetof(struct ipv6hdr, daddr): + if (!nft_payload_offload_mask(reg, priv->len, + sizeof(struct in6_addr))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_IPV6_ADDRS, ipv6, dst, + 
sizeof(struct in6_addr), reg); + nft_flow_rule_set_addr_type(flow, FLOW_DISSECTOR_KEY_IPV6_ADDRS); + break; + case offsetof(struct ipv6hdr, nexthdr): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__u8))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_BASIC, basic, ip_proto, + sizeof(__u8), reg); + nft_offload_set_dependency(ctx, NFT_OFFLOAD_DEP_TRANSPORT); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_payload_offload_nh(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + int err; + + switch (ctx->dep.l3num) { + case htons(ETH_P_IP): + err = nft_payload_offload_ip(ctx, flow, priv); + break; + case htons(ETH_P_IPV6): + err = nft_payload_offload_ip6(ctx, flow, priv); + break; + default: + return -EOPNOTSUPP; + } + + return err; +} + +static int nft_payload_offload_tcp(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->offset) { + case offsetof(struct tcphdr, source): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_PORTS, tp, src, + sizeof(__be16), reg); + break; + case offsetof(struct tcphdr, dest): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_PORTS, tp, dst, + sizeof(__be16), reg); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_payload_offload_udp(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + struct nft_offload_reg *reg = &ctx->regs[priv->dreg]; + + switch (priv->offset) { + case offsetof(struct udphdr, source): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_PORTS, tp, src, + sizeof(__be16), reg); + break; + case offsetof(struct udphdr, dest): + if (!nft_payload_offload_mask(reg, priv->len, sizeof(__be16))) + return -EOPNOTSUPP; + + NFT_OFFLOAD_MATCH(FLOW_DISSECTOR_KEY_PORTS, tp, dst, + sizeof(__be16), reg); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_payload_offload_th(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_payload *priv) +{ + int err; + + switch (ctx->dep.protonum) { + case IPPROTO_TCP: + err = nft_payload_offload_tcp(ctx, flow, priv); + break; + case IPPROTO_UDP: + err = nft_payload_offload_udp(ctx, flow, priv); + break; + default: + return -EOPNOTSUPP; + } + + return err; +} + +static int nft_payload_offload(struct nft_offload_ctx *ctx, + struct nft_flow_rule *flow, + const struct nft_expr *expr) +{ + const struct nft_payload *priv = nft_expr_priv(expr); + int err; + + switch (priv->base) { + case NFT_PAYLOAD_LL_HEADER: + err = nft_payload_offload_ll(ctx, flow, priv); + break; + case NFT_PAYLOAD_NETWORK_HEADER: + err = nft_payload_offload_nh(ctx, flow, priv); + break; + case NFT_PAYLOAD_TRANSPORT_HEADER: + err = nft_payload_offload_th(ctx, flow, priv); + break; + default: + err = -EOPNOTSUPP; + break; + } + return err; +} + +static const struct nft_expr_ops nft_payload_ops = { + .type = &nft_payload_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_payload)), + .eval = nft_payload_eval, + .init = nft_payload_init, + .dump = nft_payload_dump, + .reduce = nft_payload_reduce, + .offload = nft_payload_offload, +}; + +const struct nft_expr_ops nft_payload_fast_ops = { + .type 
= &nft_payload_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_payload)), + .eval = nft_payload_eval, + .init = nft_payload_init, + .dump = nft_payload_dump, + .reduce = nft_payload_reduce, + .offload = nft_payload_offload, +}; + +static inline void nft_csum_replace(__sum16 *sum, __wsum fsum, __wsum tsum) +{ + *sum = csum_fold(csum_add(csum_sub(~csum_unfold(*sum), fsum), tsum)); + if (*sum == 0) + *sum = CSUM_MANGLED_0; +} + +static bool nft_payload_udp_checksum(struct sk_buff *skb, unsigned int thoff) +{ + struct udphdr *uh, _uh; + + uh = skb_header_pointer(skb, thoff, sizeof(_uh), &_uh); + if (!uh) + return false; + + return (__force bool)uh->check; +} + +static int nft_payload_l4csum_offset(const struct nft_pktinfo *pkt, + struct sk_buff *skb, + unsigned int *l4csum_offset) +{ + if (pkt->fragoff) + return -1; + + switch (pkt->tprot) { + case IPPROTO_TCP: + *l4csum_offset = offsetof(struct tcphdr, check); + break; + case IPPROTO_UDP: + if (!nft_payload_udp_checksum(skb, nft_thoff(pkt))) + return -1; + fallthrough; + case IPPROTO_UDPLITE: + *l4csum_offset = offsetof(struct udphdr, check); + break; + case IPPROTO_ICMPV6: + *l4csum_offset = offsetof(struct icmp6hdr, icmp6_cksum); + break; + default: + return -1; + } + + *l4csum_offset += nft_thoff(pkt); + return 0; +} + +static int nft_payload_csum_sctp(struct sk_buff *skb, int offset) +{ + struct sctphdr *sh; + + if (skb_ensure_writable(skb, offset + sizeof(*sh))) + return -1; + + sh = (struct sctphdr *)(skb->data + offset); + sh->checksum = sctp_compute_cksum(skb, offset); + skb->ip_summed = CHECKSUM_UNNECESSARY; + return 0; +} + +static int nft_payload_l4csum_update(const struct nft_pktinfo *pkt, + struct sk_buff *skb, + __wsum fsum, __wsum tsum) +{ + int l4csum_offset; + __sum16 sum; + + /* If we cannot determine layer 4 checksum offset or this packet doesn't + * require layer 4 checksum recalculation, skip this packet. + */ + if (nft_payload_l4csum_offset(pkt, skb, &l4csum_offset) < 0) + return 0; + + if (skb_copy_bits(skb, l4csum_offset, &sum, sizeof(sum)) < 0) + return -1; + + /* Checksum mangling for an arbitrary amount of bytes, based on + * inet_proto_csum_replace*() functions. 
+ */ + if (skb->ip_summed != CHECKSUM_PARTIAL) { + nft_csum_replace(&sum, fsum, tsum); + if (skb->ip_summed == CHECKSUM_COMPLETE) { + skb->csum = ~csum_add(csum_sub(~(skb->csum), fsum), + tsum); + } + } else { + sum = ~csum_fold(csum_add(csum_sub(csum_unfold(sum), fsum), + tsum)); + } + + if (skb_ensure_writable(skb, l4csum_offset + sizeof(sum)) || + skb_store_bits(skb, l4csum_offset, &sum, sizeof(sum)) < 0) + return -1; + + return 0; +} + +static int nft_payload_csum_inet(struct sk_buff *skb, const u32 *src, + __wsum fsum, __wsum tsum, int csum_offset) +{ + __sum16 sum; + + if (skb_copy_bits(skb, csum_offset, &sum, sizeof(sum)) < 0) + return -1; + + nft_csum_replace(&sum, fsum, tsum); + if (skb_ensure_writable(skb, csum_offset + sizeof(sum)) || + skb_store_bits(skb, csum_offset, &sum, sizeof(sum)) < 0) + return -1; + + return 0; +} + +static void nft_payload_set_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_payload_set *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + const u32 *src = ®s->data[priv->sreg]; + int offset, csum_offset; + __wsum fsum, tsum; + + switch (priv->base) { + case NFT_PAYLOAD_LL_HEADER: + if (!skb_mac_header_was_set(skb)) + goto err; + offset = skb_mac_header(skb) - skb->data; + break; + case NFT_PAYLOAD_NETWORK_HEADER: + offset = skb_network_offset(skb); + break; + case NFT_PAYLOAD_TRANSPORT_HEADER: + if (!(pkt->flags & NFT_PKTINFO_L4PROTO) || pkt->fragoff) + goto err; + offset = nft_thoff(pkt); + break; + case NFT_PAYLOAD_INNER_HEADER: + offset = nft_payload_inner_offset(pkt); + if (offset < 0) + goto err; + break; + default: + WARN_ON_ONCE(1); + goto err; + } + + csum_offset = offset + priv->csum_offset; + offset += priv->offset; + + if ((priv->csum_type == NFT_PAYLOAD_CSUM_INET || priv->csum_flags) && + ((priv->base != NFT_PAYLOAD_TRANSPORT_HEADER && + priv->base != NFT_PAYLOAD_INNER_HEADER) || + skb->ip_summed != CHECKSUM_PARTIAL)) { + fsum = skb_checksum(skb, offset, priv->len, 0); + tsum = csum_partial(src, priv->len, 0); + + if (priv->csum_type == NFT_PAYLOAD_CSUM_INET && + nft_payload_csum_inet(skb, src, fsum, tsum, csum_offset)) + goto err; + + if (priv->csum_flags && + nft_payload_l4csum_update(pkt, skb, fsum, tsum) < 0) + goto err; + } + + if (skb_ensure_writable(skb, max(offset + priv->len, 0)) || + skb_store_bits(skb, offset, src, priv->len) < 0) + goto err; + + if (priv->csum_type == NFT_PAYLOAD_CSUM_SCTP && + pkt->tprot == IPPROTO_SCTP && + skb->ip_summed != CHECKSUM_PARTIAL) { + if (pkt->fragoff == 0 && + nft_payload_csum_sctp(skb, nft_thoff(pkt))) + goto err; + } + + return; +err: + regs->verdict.code = NFT_BREAK; +} + +static int nft_payload_set_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_payload_set *priv = nft_expr_priv(expr); + u32 csum_offset, csum_type = NFT_PAYLOAD_CSUM_NONE; + int err; + + priv->base = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_BASE])); + priv->offset = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_OFFSET])); + priv->len = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_LEN])); + + if (tb[NFTA_PAYLOAD_CSUM_TYPE]) + csum_type = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_CSUM_TYPE])); + if (tb[NFTA_PAYLOAD_CSUM_OFFSET]) { + err = nft_parse_u32_check(tb[NFTA_PAYLOAD_CSUM_OFFSET], U8_MAX, + &csum_offset); + if (err < 0) + return err; + + priv->csum_offset = csum_offset; + } + if (tb[NFTA_PAYLOAD_CSUM_FLAGS]) { + u32 flags; + + flags = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_CSUM_FLAGS])); + if (flags & 
~NFT_PAYLOAD_L4CSUM_PSEUDOHDR) + return -EINVAL; + + priv->csum_flags = flags; + } + + switch (csum_type) { + case NFT_PAYLOAD_CSUM_NONE: + case NFT_PAYLOAD_CSUM_INET: + break; + case NFT_PAYLOAD_CSUM_SCTP: + if (priv->base != NFT_PAYLOAD_TRANSPORT_HEADER) + return -EINVAL; + + if (priv->csum_offset != offsetof(struct sctphdr, checksum)) + return -EINVAL; + break; + default: + return -EOPNOTSUPP; + } + priv->csum_type = csum_type; + + return nft_parse_register_load(tb[NFTA_PAYLOAD_SREG], &priv->sreg, + priv->len); +} + +static int nft_payload_set_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_payload_set *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_PAYLOAD_SREG, priv->sreg) || + nla_put_be32(skb, NFTA_PAYLOAD_BASE, htonl(priv->base)) || + nla_put_be32(skb, NFTA_PAYLOAD_OFFSET, htonl(priv->offset)) || + nla_put_be32(skb, NFTA_PAYLOAD_LEN, htonl(priv->len)) || + nla_put_be32(skb, NFTA_PAYLOAD_CSUM_TYPE, htonl(priv->csum_type)) || + nla_put_be32(skb, NFTA_PAYLOAD_CSUM_OFFSET, + htonl(priv->csum_offset)) || + nla_put_be32(skb, NFTA_PAYLOAD_CSUM_FLAGS, htonl(priv->csum_flags))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_payload_set_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + int i; + + for (i = 0; i < NFT_REG32_NUM; i++) { + if (!track->regs[i].selector) + continue; + + if (track->regs[i].selector->ops != &nft_payload_ops && + track->regs[i].selector->ops != &nft_payload_fast_ops) + continue; + + __nft_reg_track_cancel(track, i); + } + + return false; +} + +static const struct nft_expr_ops nft_payload_set_ops = { + .type = &nft_payload_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_payload_set)), + .eval = nft_payload_set_eval, + .init = nft_payload_set_init, + .dump = nft_payload_set_dump, + .reduce = nft_payload_set_reduce, +}; + +static const struct nft_expr_ops * +nft_payload_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + enum nft_payload_bases base; + unsigned int offset, len; + int err; + + if (tb[NFTA_PAYLOAD_BASE] == NULL || + tb[NFTA_PAYLOAD_OFFSET] == NULL || + tb[NFTA_PAYLOAD_LEN] == NULL) + return ERR_PTR(-EINVAL); + + base = ntohl(nla_get_be32(tb[NFTA_PAYLOAD_BASE])); + switch (base) { + case NFT_PAYLOAD_LL_HEADER: + case NFT_PAYLOAD_NETWORK_HEADER: + case NFT_PAYLOAD_TRANSPORT_HEADER: + case NFT_PAYLOAD_INNER_HEADER: + break; + default: + return ERR_PTR(-EOPNOTSUPP); + } + + if (tb[NFTA_PAYLOAD_SREG] != NULL) { + if (tb[NFTA_PAYLOAD_DREG] != NULL) + return ERR_PTR(-EINVAL); + return &nft_payload_set_ops; + } + + if (tb[NFTA_PAYLOAD_DREG] == NULL) + return ERR_PTR(-EINVAL); + + err = nft_parse_u32_check(tb[NFTA_PAYLOAD_OFFSET], U8_MAX, &offset); + if (err < 0) + return ERR_PTR(err); + + err = nft_parse_u32_check(tb[NFTA_PAYLOAD_LEN], U8_MAX, &len); + if (err < 0) + return ERR_PTR(err); + + if (len <= 4 && is_power_of_2(len) && IS_ALIGNED(offset, len) && + base != NFT_PAYLOAD_LL_HEADER && base != NFT_PAYLOAD_INNER_HEADER) + return &nft_payload_fast_ops; + else + return &nft_payload_ops; +} + +struct nft_expr_type nft_payload_type __read_mostly = { + .name = "payload", + .select_ops = nft_payload_select_ops, + .policy = nft_payload_policy, + .maxattr = NFTA_PAYLOAD_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_queue.c b/net/netfilter/nft_queue.c new file mode 100644 index 000000000..da29e92c0 --- /dev/null +++ b/net/netfilter/nft_queue.c @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * 
Copyright (c) 2013 Eric Leblond + * + * Development of this code partly funded by OISF + * (http://www.openinfosecfoundation.org/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static u32 jhash_initval __read_mostly; + +struct nft_queue { + u8 sreg_qnum; + u16 queuenum; + u16 queues_total; + u16 flags; +}; + +static void nft_queue_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_queue *priv = nft_expr_priv(expr); + u32 queue = priv->queuenum; + u32 ret; + + if (priv->queues_total > 1) { + if (priv->flags & NFT_QUEUE_FLAG_CPU_FANOUT) { + int cpu = raw_smp_processor_id(); + + queue = priv->queuenum + cpu % priv->queues_total; + } else { + queue = nfqueue_hash(pkt->skb, queue, + priv->queues_total, nft_pf(pkt), + jhash_initval); + } + } + + ret = NF_QUEUE_NR(queue); + if (priv->flags & NFT_QUEUE_FLAG_BYPASS) + ret |= NF_VERDICT_FLAG_QUEUE_BYPASS; + + regs->verdict.code = ret; +} + +static void nft_queue_sreg_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_queue *priv = nft_expr_priv(expr); + u32 queue, ret; + + queue = regs->data[priv->sreg_qnum]; + + ret = NF_QUEUE_NR(queue); + if (priv->flags & NFT_QUEUE_FLAG_BYPASS) + ret |= NF_VERDICT_FLAG_QUEUE_BYPASS; + + regs->verdict.code = ret; +} + +static int nft_queue_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + static const unsigned int supported_hooks = ((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING)); + + switch (ctx->family) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + case NFPROTO_BRIDGE: + break; + case NFPROTO_NETDEV: /* lacks okfn */ + fallthrough; + default: + return -EOPNOTSUPP; + } + + return nft_chain_validate_hooks(ctx->chain, supported_hooks); +} + +static const struct nla_policy nft_queue_policy[NFTA_QUEUE_MAX + 1] = { + [NFTA_QUEUE_NUM] = { .type = NLA_U16 }, + [NFTA_QUEUE_TOTAL] = { .type = NLA_U16 }, + [NFTA_QUEUE_FLAGS] = { .type = NLA_U16 }, + [NFTA_QUEUE_SREG_QNUM] = { .type = NLA_U32 }, +}; + +static int nft_queue_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_queue *priv = nft_expr_priv(expr); + u32 maxid; + + priv->queuenum = ntohs(nla_get_be16(tb[NFTA_QUEUE_NUM])); + + if (tb[NFTA_QUEUE_TOTAL]) + priv->queues_total = ntohs(nla_get_be16(tb[NFTA_QUEUE_TOTAL])); + else + priv->queues_total = 1; + + if (priv->queues_total == 0) + return -EINVAL; + + maxid = priv->queues_total - 1 + priv->queuenum; + if (maxid > U16_MAX) + return -ERANGE; + + if (tb[NFTA_QUEUE_FLAGS]) { + priv->flags = ntohs(nla_get_be16(tb[NFTA_QUEUE_FLAGS])); + if (priv->flags & ~NFT_QUEUE_FLAG_MASK) + return -EINVAL; + } + return 0; +} + +static int nft_queue_sreg_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_queue *priv = nft_expr_priv(expr); + int err; + + err = nft_parse_register_load(tb[NFTA_QUEUE_SREG_QNUM], + &priv->sreg_qnum, sizeof(u32)); + if (err < 0) + return err; + + if (tb[NFTA_QUEUE_FLAGS]) { + priv->flags = ntohs(nla_get_be16(tb[NFTA_QUEUE_FLAGS])); + if (priv->flags & ~NFT_QUEUE_FLAG_MASK) + return -EINVAL; + if (priv->flags & NFT_QUEUE_FLAG_CPU_FANOUT) + return -EOPNOTSUPP; + } + + return 0; +} + +static int nft_queue_dump(struct sk_buff *skb, const struct 
nft_expr *expr) +{ + const struct nft_queue *priv = nft_expr_priv(expr); + + if (nla_put_be16(skb, NFTA_QUEUE_NUM, htons(priv->queuenum)) || + nla_put_be16(skb, NFTA_QUEUE_TOTAL, htons(priv->queues_total)) || + nla_put_be16(skb, NFTA_QUEUE_FLAGS, htons(priv->flags))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static int +nft_queue_sreg_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_queue *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_QUEUE_SREG_QNUM, priv->sreg_qnum) || + nla_put_be16(skb, NFTA_QUEUE_FLAGS, htons(priv->flags))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static struct nft_expr_type nft_queue_type; +static const struct nft_expr_ops nft_queue_ops = { + .type = &nft_queue_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_queue)), + .eval = nft_queue_eval, + .init = nft_queue_init, + .dump = nft_queue_dump, + .validate = nft_queue_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops nft_queue_sreg_ops = { + .type = &nft_queue_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_queue)), + .eval = nft_queue_sreg_eval, + .init = nft_queue_sreg_init, + .dump = nft_queue_sreg_dump, + .validate = nft_queue_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static const struct nft_expr_ops * +nft_queue_select_ops(const struct nft_ctx *ctx, + const struct nlattr * const tb[]) +{ + if (tb[NFTA_QUEUE_NUM] && tb[NFTA_QUEUE_SREG_QNUM]) + return ERR_PTR(-EINVAL); + + init_hashrandom(&jhash_initval); + + if (tb[NFTA_QUEUE_NUM]) + return &nft_queue_ops; + + if (tb[NFTA_QUEUE_SREG_QNUM]) + return &nft_queue_sreg_ops; + + return ERR_PTR(-EINVAL); +} + +static struct nft_expr_type nft_queue_type __read_mostly = { + .name = "queue", + .select_ops = nft_queue_select_ops, + .policy = nft_queue_policy, + .maxattr = NFTA_QUEUE_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_queue_module_init(void) +{ + return nft_register_expr(&nft_queue_type); +} + +static void __exit nft_queue_module_exit(void) +{ + nft_unregister_expr(&nft_queue_type); +} + +module_init(nft_queue_module_init); +module_exit(nft_queue_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Eric Leblond "); +MODULE_ALIAS_NFT_EXPR("queue"); +MODULE_DESCRIPTION("Netfilter nftables queue module"); diff --git a/net/netfilter/nft_quota.c b/net/netfilter/nft_quota.c new file mode 100644 index 000000000..410a5fcf8 --- /dev/null +++ b/net/netfilter/nft_quota.c @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2016 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_quota { + atomic64_t quota; + unsigned long flags; + atomic64_t *consumed; +}; + +static inline bool nft_overquota(struct nft_quota *priv, + const struct sk_buff *skb) +{ + return atomic64_add_return(skb->len, priv->consumed) >= + atomic64_read(&priv->quota); +} + +static inline bool nft_quota_invert(struct nft_quota *priv) +{ + return priv->flags & NFT_QUOTA_F_INV; +} + +static inline void nft_quota_do_eval(struct nft_quota *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + if (nft_overquota(priv, pkt->skb) ^ nft_quota_invert(priv)) + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_quota_policy[NFTA_QUOTA_MAX + 1] = { + [NFTA_QUOTA_BYTES] = { .type = NLA_U64 }, + [NFTA_QUOTA_FLAGS] = { .type = NLA_U32 }, + [NFTA_QUOTA_CONSUMED] = { .type = NLA_U64 }, +}; + +#define NFT_QUOTA_DEPLETED_BIT 1 /* 
From NFT_QUOTA_F_DEPLETED. */ + +static void nft_quota_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_quota *priv = nft_obj_data(obj); + bool overquota; + + overquota = nft_overquota(priv, pkt->skb); + if (overquota ^ nft_quota_invert(priv)) + regs->verdict.code = NFT_BREAK; + + if (overquota && + !test_and_set_bit(NFT_QUOTA_DEPLETED_BIT, &priv->flags)) + nft_obj_notify(nft_net(pkt), obj->key.table, obj, 0, 0, + NFT_MSG_NEWOBJ, 0, nft_pf(pkt), 0, GFP_ATOMIC); +} + +static int nft_quota_do_init(const struct nlattr * const tb[], + struct nft_quota *priv) +{ + unsigned long flags = 0; + u64 quota, consumed = 0; + + if (!tb[NFTA_QUOTA_BYTES]) + return -EINVAL; + + quota = be64_to_cpu(nla_get_be64(tb[NFTA_QUOTA_BYTES])); + if (quota > S64_MAX) + return -EOVERFLOW; + + if (tb[NFTA_QUOTA_CONSUMED]) { + consumed = be64_to_cpu(nla_get_be64(tb[NFTA_QUOTA_CONSUMED])); + if (consumed > quota) + return -EINVAL; + } + + if (tb[NFTA_QUOTA_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFTA_QUOTA_FLAGS])); + if (flags & ~NFT_QUOTA_F_INV) + return -EINVAL; + if (flags & NFT_QUOTA_F_DEPLETED) + return -EOPNOTSUPP; + } + + priv->consumed = kmalloc(sizeof(*priv->consumed), GFP_KERNEL_ACCOUNT); + if (!priv->consumed) + return -ENOMEM; + + atomic64_set(&priv->quota, quota); + priv->flags = flags; + atomic64_set(priv->consumed, consumed); + + return 0; +} + +static void nft_quota_do_destroy(const struct nft_ctx *ctx, + struct nft_quota *priv) +{ + kfree(priv->consumed); +} + +static int nft_quota_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_quota *priv = nft_obj_data(obj); + + return nft_quota_do_init(tb, priv); +} + +static void nft_quota_obj_update(struct nft_object *obj, + struct nft_object *newobj) +{ + struct nft_quota *newpriv = nft_obj_data(newobj); + struct nft_quota *priv = nft_obj_data(obj); + u64 newquota; + + newquota = atomic64_read(&newpriv->quota); + atomic64_set(&priv->quota, newquota); + priv->flags = newpriv->flags; +} + +static int nft_quota_do_dump(struct sk_buff *skb, struct nft_quota *priv, + bool reset) +{ + u64 consumed, consumed_cap, quota; + u32 flags = priv->flags; + + /* Since we inconditionally increment consumed quota for each packet + * that we see, don't go over the quota boundary in what we send to + * userspace. 
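+ * For example (illustrative numbers only): with a quota of 1000 bytes, a + * racing packet may leave the counter at 1024; the dump below then reports + * consumed == 1000 and sets NFT_QUOTA_F_DEPLETED instead of exposing a + * value beyond the configured quota.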
+ */ + consumed = atomic64_read(priv->consumed); + quota = atomic64_read(&priv->quota); + if (consumed >= quota) { + consumed_cap = quota; + flags |= NFT_QUOTA_F_DEPLETED; + } else { + consumed_cap = consumed; + } + + if (nla_put_be64(skb, NFTA_QUOTA_BYTES, cpu_to_be64(quota), + NFTA_QUOTA_PAD) || + nla_put_be64(skb, NFTA_QUOTA_CONSUMED, cpu_to_be64(consumed_cap), + NFTA_QUOTA_PAD) || + nla_put_be32(skb, NFTA_QUOTA_FLAGS, htonl(flags))) + goto nla_put_failure; + + if (reset) { + atomic64_sub(consumed, priv->consumed); + clear_bit(NFT_QUOTA_DEPLETED_BIT, &priv->flags); + } + return 0; + +nla_put_failure: + return -1; +} + +static int nft_quota_obj_dump(struct sk_buff *skb, struct nft_object *obj, + bool reset) +{ + struct nft_quota *priv = nft_obj_data(obj); + + return nft_quota_do_dump(skb, priv, reset); +} + +static void nft_quota_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_quota *priv = nft_obj_data(obj); + + return nft_quota_do_destroy(ctx, priv); +} + +static struct nft_object_type nft_quota_obj_type; +static const struct nft_object_ops nft_quota_obj_ops = { + .type = &nft_quota_obj_type, + .size = sizeof(struct nft_quota), + .init = nft_quota_obj_init, + .destroy = nft_quota_obj_destroy, + .eval = nft_quota_obj_eval, + .dump = nft_quota_obj_dump, + .update = nft_quota_obj_update, +}; + +static struct nft_object_type nft_quota_obj_type __read_mostly = { + .type = NFT_OBJECT_QUOTA, + .ops = &nft_quota_obj_ops, + .maxattr = NFTA_QUOTA_MAX, + .policy = nft_quota_policy, + .owner = THIS_MODULE, +}; + +static void nft_quota_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_quota *priv = nft_expr_priv(expr); + + nft_quota_do_eval(priv, regs, pkt); +} + +static int nft_quota_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_quota *priv = nft_expr_priv(expr); + + return nft_quota_do_init(tb, priv); +} + +static int nft_quota_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_quota *priv = nft_expr_priv(expr); + + return nft_quota_do_dump(skb, priv, false); +} + +static void nft_quota_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + struct nft_quota *priv = nft_expr_priv(expr); + + return nft_quota_do_destroy(ctx, priv); +} + +static int nft_quota_clone(struct nft_expr *dst, const struct nft_expr *src) +{ + struct nft_quota *priv_dst = nft_expr_priv(dst); + struct nft_quota *priv_src = nft_expr_priv(src); + + priv_dst->quota = priv_src->quota; + priv_dst->flags = priv_src->flags; + + priv_dst->consumed = kmalloc(sizeof(*priv_dst->consumed), GFP_ATOMIC); + if (!priv_dst->consumed) + return -ENOMEM; + + *priv_dst->consumed = *priv_src->consumed; + + return 0; +} + +static struct nft_expr_type nft_quota_type; +static const struct nft_expr_ops nft_quota_ops = { + .type = &nft_quota_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_quota)), + .eval = nft_quota_eval, + .init = nft_quota_init, + .destroy = nft_quota_destroy, + .clone = nft_quota_clone, + .dump = nft_quota_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_quota_type __read_mostly = { + .name = "quota", + .ops = &nft_quota_ops, + .policy = nft_quota_policy, + .maxattr = NFTA_QUOTA_MAX, + .flags = NFT_EXPR_STATEFUL, + .owner = THIS_MODULE, +}; + +static int __init nft_quota_module_init(void) +{ + int err; + + err = nft_register_obj(&nft_quota_obj_type); + if (err < 0) + return err; + + err = 
nft_register_expr(&nft_quota_type); + if (err < 0) + goto err1; + + return 0; +err1: + nft_unregister_obj(&nft_quota_obj_type); + return err; +} + +static void __exit nft_quota_module_exit(void) +{ + nft_unregister_expr(&nft_quota_type); + nft_unregister_obj(&nft_quota_obj_type); +} + +module_init(nft_quota_module_init); +module_exit(nft_quota_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_EXPR("quota"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_QUOTA); +MODULE_DESCRIPTION("Netfilter nftables quota module"); diff --git a/net/netfilter/nft_range.c b/net/netfilter/nft_range.c new file mode 100644 index 000000000..832f0d725 --- /dev/null +++ b/net/netfilter/nft_range.c @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2016 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_range_expr { + struct nft_data data_from; + struct nft_data data_to; + u8 sreg; + u8 len; + enum nft_range_ops op:8; +}; + +void nft_range_eval(const struct nft_expr *expr, + struct nft_regs *regs, const struct nft_pktinfo *pkt) +{ + const struct nft_range_expr *priv = nft_expr_priv(expr); + int d1, d2; + + d1 = memcmp(®s->data[priv->sreg], &priv->data_from, priv->len); + d2 = memcmp(®s->data[priv->sreg], &priv->data_to, priv->len); + switch (priv->op) { + case NFT_RANGE_EQ: + if (d1 < 0 || d2 > 0) + regs->verdict.code = NFT_BREAK; + break; + case NFT_RANGE_NEQ: + if (d1 >= 0 && d2 <= 0) + regs->verdict.code = NFT_BREAK; + break; + } +} + +static const struct nla_policy nft_range_policy[NFTA_RANGE_MAX + 1] = { + [NFTA_RANGE_SREG] = { .type = NLA_U32 }, + [NFTA_RANGE_OP] = { .type = NLA_U32 }, + [NFTA_RANGE_FROM_DATA] = { .type = NLA_NESTED }, + [NFTA_RANGE_TO_DATA] = { .type = NLA_NESTED }, +}; + +static int nft_range_init(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_range_expr *priv = nft_expr_priv(expr); + struct nft_data_desc desc_from = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->data_from), + }; + struct nft_data_desc desc_to = { + .type = NFT_DATA_VALUE, + .size = sizeof(priv->data_to), + }; + int err; + u32 op; + + if (!tb[NFTA_RANGE_SREG] || + !tb[NFTA_RANGE_OP] || + !tb[NFTA_RANGE_FROM_DATA] || + !tb[NFTA_RANGE_TO_DATA]) + return -EINVAL; + + err = nft_data_init(NULL, &priv->data_from, &desc_from, + tb[NFTA_RANGE_FROM_DATA]); + if (err < 0) + return err; + + err = nft_data_init(NULL, &priv->data_to, &desc_to, + tb[NFTA_RANGE_TO_DATA]); + if (err < 0) + goto err1; + + if (desc_from.len != desc_to.len) { + err = -EINVAL; + goto err2; + } + + err = nft_parse_register_load(tb[NFTA_RANGE_SREG], &priv->sreg, + desc_from.len); + if (err < 0) + goto err2; + + err = nft_parse_u32_check(tb[NFTA_RANGE_OP], U8_MAX, &op); + if (err < 0) + goto err2; + + switch (op) { + case NFT_RANGE_EQ: + case NFT_RANGE_NEQ: + break; + default: + err = -EINVAL; + goto err2; + } + + priv->op = op; + priv->len = desc_from.len; + return 0; +err2: + nft_data_release(&priv->data_to, desc_to.type); +err1: + nft_data_release(&priv->data_from, desc_from.type); + return err; +} + +static int nft_range_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_range_expr *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_RANGE_SREG, priv->sreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_RANGE_OP, htonl(priv->op))) + goto nla_put_failure; + + if (nft_data_dump(skb, NFTA_RANGE_FROM_DATA, &priv->data_from, + 
NFT_DATA_VALUE, priv->len) < 0 || + nft_data_dump(skb, NFTA_RANGE_TO_DATA, &priv->data_to, + NFT_DATA_VALUE, priv->len) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static const struct nft_expr_ops nft_range_ops = { + .type = &nft_range_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_range_expr)), + .eval = nft_range_eval, + .init = nft_range_init, + .dump = nft_range_dump, + .reduce = NFT_REDUCE_READONLY, +}; + +struct nft_expr_type nft_range_type __read_mostly = { + .name = "range", + .ops = &nft_range_ops, + .policy = nft_range_policy, + .maxattr = NFTA_RANGE_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_redir.c b/net/netfilter/nft_redir.c new file mode 100644 index 000000000..08b408d3e --- /dev/null +++ b/net/netfilter/nft_redir.c @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2014 Arturo Borrero Gonzalez + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_redir { + u8 sreg_proto_min; + u8 sreg_proto_max; + u16 flags; +}; + +static const struct nla_policy nft_redir_policy[NFTA_REDIR_MAX + 1] = { + [NFTA_REDIR_REG_PROTO_MIN] = { .type = NLA_U32 }, + [NFTA_REDIR_REG_PROTO_MAX] = { .type = NLA_U32 }, + [NFTA_REDIR_FLAGS] = { .type = NLA_U32 }, +}; + +static int nft_redir_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + int err; + + err = nft_chain_validate_dependency(ctx->chain, NFT_CHAIN_T_NAT); + if (err < 0) + return err; + + return nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT)); +} + +static int nft_redir_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_redir *priv = nft_expr_priv(expr); + unsigned int plen; + int err; + + plen = sizeof_field(struct nf_nat_range, min_proto.all); + if (tb[NFTA_REDIR_REG_PROTO_MIN]) { + err = nft_parse_register_load(tb[NFTA_REDIR_REG_PROTO_MIN], + &priv->sreg_proto_min, plen); + if (err < 0) + return err; + + if (tb[NFTA_REDIR_REG_PROTO_MAX]) { + err = nft_parse_register_load(tb[NFTA_REDIR_REG_PROTO_MAX], + &priv->sreg_proto_max, + plen); + if (err < 0) + return err; + } else { + priv->sreg_proto_max = priv->sreg_proto_min; + } + + priv->flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + } + + if (tb[NFTA_REDIR_FLAGS]) { + priv->flags = ntohl(nla_get_be32(tb[NFTA_REDIR_FLAGS])); + if (priv->flags & ~NF_NAT_RANGE_MASK) + return -EINVAL; + } + + return nf_ct_netns_get(ctx->net, ctx->family); +} + +static int nft_redir_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_redir *priv = nft_expr_priv(expr); + + if (priv->sreg_proto_min) { + if (nft_dump_register(skb, NFTA_REDIR_REG_PROTO_MIN, + priv->sreg_proto_min)) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_REDIR_REG_PROTO_MAX, + priv->sreg_proto_max)) + goto nla_put_failure; + } + + if (priv->flags != 0 && + nla_put_be32(skb, NFTA_REDIR_FLAGS, htonl(priv->flags))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_redir_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_redir *priv = nft_expr_priv(expr); + struct nf_nat_range2 range; + + memset(&range, 0, sizeof(range)); + range.flags = priv->flags; + if (priv->sreg_proto_min) { + range.min_proto.all = (__force __be16) + nft_reg_load16(®s->data[priv->sreg_proto_min]); + range.max_proto.all = (__force __be16) + 
nft_reg_load16(®s->data[priv->sreg_proto_max]); + } + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + regs->verdict.code = nf_nat_redirect_ipv4(pkt->skb, &range, + nft_hook(pkt)); + break; +#ifdef CONFIG_NF_TABLES_IPV6 + case NFPROTO_IPV6: + regs->verdict.code = nf_nat_redirect_ipv6(pkt->skb, &range, + nft_hook(pkt)); + break; +#endif + default: + WARN_ON_ONCE(1); + break; + } +} + +static void +nft_redir_ipv4_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV4); +} + +static struct nft_expr_type nft_redir_ipv4_type; +static const struct nft_expr_ops nft_redir_ipv4_ops = { + .type = &nft_redir_ipv4_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), + .eval = nft_redir_eval, + .init = nft_redir_init, + .destroy = nft_redir_ipv4_destroy, + .dump = nft_redir_dump, + .validate = nft_redir_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_redir_ipv4_type __read_mostly = { + .family = NFPROTO_IPV4, + .name = "redir", + .ops = &nft_redir_ipv4_ops, + .policy = nft_redir_policy, + .maxattr = NFTA_REDIR_MAX, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_NF_TABLES_IPV6 +static void +nft_redir_ipv6_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_IPV6); +} + +static struct nft_expr_type nft_redir_ipv6_type; +static const struct nft_expr_ops nft_redir_ipv6_ops = { + .type = &nft_redir_ipv6_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), + .eval = nft_redir_eval, + .init = nft_redir_init, + .destroy = nft_redir_ipv6_destroy, + .dump = nft_redir_dump, + .validate = nft_redir_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_redir_ipv6_type __read_mostly = { + .family = NFPROTO_IPV6, + .name = "redir", + .ops = &nft_redir_ipv6_ops, + .policy = nft_redir_policy, + .maxattr = NFTA_REDIR_MAX, + .owner = THIS_MODULE, +}; +#endif + +#ifdef CONFIG_NF_TABLES_INET +static void +nft_redir_inet_destroy(const struct nft_ctx *ctx, const struct nft_expr *expr) +{ + nf_ct_netns_put(ctx->net, NFPROTO_INET); +} + +static struct nft_expr_type nft_redir_inet_type; +static const struct nft_expr_ops nft_redir_inet_ops = { + .type = &nft_redir_inet_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_redir)), + .eval = nft_redir_eval, + .init = nft_redir_init, + .destroy = nft_redir_inet_destroy, + .dump = nft_redir_dump, + .validate = nft_redir_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_redir_inet_type __read_mostly = { + .family = NFPROTO_INET, + .name = "redir", + .ops = &nft_redir_inet_ops, + .policy = nft_redir_policy, + .maxattr = NFTA_REDIR_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_redir_module_init_inet(void) +{ + return nft_register_expr(&nft_redir_inet_type); +} +#else +static inline int nft_redir_module_init_inet(void) { return 0; } +#endif + +static int __init nft_redir_module_init(void) +{ + int ret = nft_register_expr(&nft_redir_ipv4_type); + + if (ret) + return ret; + +#ifdef CONFIG_NF_TABLES_IPV6 + ret = nft_register_expr(&nft_redir_ipv6_type); + if (ret) { + nft_unregister_expr(&nft_redir_ipv4_type); + return ret; + } +#endif + + ret = nft_redir_module_init_inet(); + if (ret < 0) { + nft_unregister_expr(&nft_redir_ipv4_type); +#ifdef CONFIG_NF_TABLES_IPV6 + nft_unregister_expr(&nft_redir_ipv6_type); +#endif + return ret; + } + + return ret; +} + +static void __exit nft_redir_module_exit(void) +{ + nft_unregister_expr(&nft_redir_ipv4_type); +#ifdef CONFIG_NF_TABLES_IPV6 + 
nft_unregister_expr(&nft_redir_ipv6_type); +#endif +#ifdef CONFIG_NF_TABLES_INET + nft_unregister_expr(&nft_redir_inet_type); +#endif +} + +module_init(nft_redir_module_init); +module_exit(nft_redir_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Arturo Borrero Gonzalez "); +MODULE_ALIAS_NFT_EXPR("redir"); +MODULE_DESCRIPTION("Netfilter nftables redirect support"); diff --git a/net/netfilter/nft_reject.c b/net/netfilter/nft_reject.c new file mode 100644 index 000000000..927ff8459 --- /dev/null +++ b/net/netfilter/nft_reject.c @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * Copyright (c) 2013 Eric Leblond + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const struct nla_policy nft_reject_policy[NFTA_REJECT_MAX + 1] = { + [NFTA_REJECT_TYPE] = { .type = NLA_U32 }, + [NFTA_REJECT_ICMP_CODE] = { .type = NLA_U8 }, +}; +EXPORT_SYMBOL_GPL(nft_reject_policy); + +int nft_reject_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_PRE_ROUTING)); +} +EXPORT_SYMBOL_GPL(nft_reject_validate); + +int nft_reject_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_reject *priv = nft_expr_priv(expr); + int icmp_code; + + if (tb[NFTA_REJECT_TYPE] == NULL) + return -EINVAL; + + priv->type = ntohl(nla_get_be32(tb[NFTA_REJECT_TYPE])); + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + case NFT_REJECT_ICMPX_UNREACH: + if (tb[NFTA_REJECT_ICMP_CODE] == NULL) + return -EINVAL; + + icmp_code = nla_get_u8(tb[NFTA_REJECT_ICMP_CODE]); + if (priv->type == NFT_REJECT_ICMPX_UNREACH && + icmp_code > NFT_REJECT_ICMPX_MAX) + return -EINVAL; + + priv->icmp_code = icmp_code; + break; + case NFT_REJECT_TCP_RST: + break; + default: + return -EINVAL; + } + + return 0; +} +EXPORT_SYMBOL_GPL(nft_reject_init); + +int nft_reject_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + const struct nft_reject *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_REJECT_TYPE, htonl(priv->type))) + goto nla_put_failure; + + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + case NFT_REJECT_ICMPX_UNREACH: + if (nla_put_u8(skb, NFTA_REJECT_ICMP_CODE, priv->icmp_code)) + goto nla_put_failure; + break; + default: + break; + } + + return 0; + +nla_put_failure: + return -1; +} +EXPORT_SYMBOL_GPL(nft_reject_dump); + +static u8 icmp_code_v4[NFT_REJECT_ICMPX_MAX + 1] = { + [NFT_REJECT_ICMPX_NO_ROUTE] = ICMP_NET_UNREACH, + [NFT_REJECT_ICMPX_PORT_UNREACH] = ICMP_PORT_UNREACH, + [NFT_REJECT_ICMPX_HOST_UNREACH] = ICMP_HOST_UNREACH, + [NFT_REJECT_ICMPX_ADMIN_PROHIBITED] = ICMP_PKT_FILTERED, +}; + +int nft_reject_icmp_code(u8 code) +{ + if (WARN_ON_ONCE(code > NFT_REJECT_ICMPX_MAX)) + return ICMP_NET_UNREACH; + + return icmp_code_v4[code]; +} + +EXPORT_SYMBOL_GPL(nft_reject_icmp_code); + + +static u8 icmp_code_v6[NFT_REJECT_ICMPX_MAX + 1] = { + [NFT_REJECT_ICMPX_NO_ROUTE] = ICMPV6_NOROUTE, + [NFT_REJECT_ICMPX_PORT_UNREACH] = ICMPV6_PORT_UNREACH, + [NFT_REJECT_ICMPX_HOST_UNREACH] = ICMPV6_ADDR_UNREACH, + [NFT_REJECT_ICMPX_ADMIN_PROHIBITED] = ICMPV6_ADM_PROHIBITED, +}; + +int nft_reject_icmpv6_code(u8 code) +{ + if (WARN_ON_ONCE(code > NFT_REJECT_ICMPX_MAX)) + return 
ICMPV6_NOROUTE; + + return icmp_code_v6[code]; +} + +EXPORT_SYMBOL_GPL(nft_reject_icmpv6_code); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Netfilter x_tables over nftables module"); diff --git a/net/netfilter/nft_reject_inet.c b/net/netfilter/nft_reject_inet.c new file mode 100644 index 000000000..973fa31a9 --- /dev/null +++ b/net/netfilter/nft_reject_inet.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2014 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void nft_reject_inet_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_reject *priv = nft_expr_priv(expr); + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nf_send_unreach(pkt->skb, priv->icmp_code, + nft_hook(pkt)); + break; + case NFT_REJECT_TCP_RST: + nf_send_reset(nft_net(pkt), nft_sk(pkt), + pkt->skb, nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nf_send_unreach(pkt->skb, + nft_reject_icmp_code(priv->icmp_code), + nft_hook(pkt)); + break; + } + break; + case NFPROTO_IPV6: + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nf_send_unreach6(nft_net(pkt), pkt->skb, + priv->icmp_code, nft_hook(pkt)); + break; + case NFT_REJECT_TCP_RST: + nf_send_reset6(nft_net(pkt), nft_sk(pkt), + pkt->skb, nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nf_send_unreach6(nft_net(pkt), pkt->skb, + nft_reject_icmpv6_code(priv->icmp_code), + nft_hook(pkt)); + break; + } + break; + } + + regs->verdict.code = NF_DROP; +} + +static int nft_reject_inet_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_INGRESS)); +} + +static struct nft_expr_type nft_reject_inet_type; +static const struct nft_expr_ops nft_reject_inet_ops = { + .type = &nft_reject_inet_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_reject)), + .eval = nft_reject_inet_eval, + .init = nft_reject_init, + .dump = nft_reject_dump, + .validate = nft_reject_inet_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_reject_inet_type __read_mostly = { + .family = NFPROTO_INET, + .name = "reject", + .ops = &nft_reject_inet_ops, + .policy = nft_reject_policy, + .maxattr = NFTA_REJECT_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_reject_inet_module_init(void) +{ + return nft_register_expr(&nft_reject_inet_type); +} + +static void __exit nft_reject_inet_module_exit(void) +{ + nft_unregister_expr(&nft_reject_inet_type); +} + +module_init(nft_reject_inet_module_init); +module_exit(nft_reject_inet_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS_NFT_AF_EXPR(1, "reject"); +MODULE_DESCRIPTION("Netfilter nftables reject inet support"); diff --git a/net/netfilter/nft_reject_netdev.c b/net/netfilter/nft_reject_netdev.c new file mode 100644 index 000000000..7865cd8b1 --- /dev/null +++ b/net/netfilter/nft_reject_netdev.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2020 Laura Garcia Liebana + * Copyright (c) 2020 Jose M. 
Guisado + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void nft_reject_queue_xmit(struct sk_buff *nskb, struct sk_buff *oldskb) +{ + dev_hard_header(nskb, nskb->dev, ntohs(oldskb->protocol), + eth_hdr(oldskb)->h_source, eth_hdr(oldskb)->h_dest, + nskb->len); + dev_queue_xmit(nskb); +} + +static void nft_reject_netdev_send_v4_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v4_tcp_reset(net, oldskb, dev, hook); + if (!nskb) + return; + + nft_reject_queue_xmit(nskb, oldskb); +} + +static void nft_reject_netdev_send_v4_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v4_unreach(net, oldskb, dev, hook, code); + if (!nskb) + return; + + nft_reject_queue_xmit(nskb, oldskb); +} + +static void nft_reject_netdev_send_v6_tcp_reset(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v6_tcp_reset(net, oldskb, dev, hook); + if (!nskb) + return; + + nft_reject_queue_xmit(nskb, oldskb); +} + + +static void nft_reject_netdev_send_v6_unreach(struct net *net, + struct sk_buff *oldskb, + const struct net_device *dev, + int hook, u8 code) +{ + struct sk_buff *nskb; + + nskb = nf_reject_skb_v6_unreach(net, oldskb, dev, hook, code); + if (!nskb) + return; + + nft_reject_queue_xmit(nskb, oldskb); +} + +static void nft_reject_netdev_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct ethhdr *eth = eth_hdr(pkt->skb); + struct nft_reject *priv = nft_expr_priv(expr); + const unsigned char *dest = eth->h_dest; + + if (is_broadcast_ether_addr(dest) || + is_multicast_ether_addr(dest)) + goto out; + + switch (eth->h_proto) { + case htons(ETH_P_IP): + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nft_reject_netdev_send_v4_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + priv->icmp_code); + break; + case NFT_REJECT_TCP_RST: + nft_reject_netdev_send_v4_tcp_reset(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nft_reject_netdev_send_v4_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + nft_reject_icmp_code(priv->icmp_code)); + break; + } + break; + case htons(ETH_P_IPV6): + switch (priv->type) { + case NFT_REJECT_ICMP_UNREACH: + nft_reject_netdev_send_v6_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + priv->icmp_code); + break; + case NFT_REJECT_TCP_RST: + nft_reject_netdev_send_v6_tcp_reset(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt)); + break; + case NFT_REJECT_ICMPX_UNREACH: + nft_reject_netdev_send_v6_unreach(nft_net(pkt), pkt->skb, + nft_in(pkt), + nft_hook(pkt), + nft_reject_icmpv6_code(priv->icmp_code)); + break; + } + break; + default: + /* No explicit way to reject this protocol, drop it. 
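+ * (e.g. an ARP or LLDP frame reaching this expression has no ICMP or + * TCP reset equivalent to send back, so it is simply dropped below)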
*/ + break; + } +out: + regs->verdict.code = NF_DROP; +} + +static int nft_reject_netdev_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, (1 << NF_NETDEV_INGRESS)); +} + +static struct nft_expr_type nft_reject_netdev_type; +static const struct nft_expr_ops nft_reject_netdev_ops = { + .type = &nft_reject_netdev_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_reject)), + .eval = nft_reject_netdev_eval, + .init = nft_reject_init, + .dump = nft_reject_dump, + .validate = nft_reject_netdev_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_reject_netdev_type __read_mostly = { + .family = NFPROTO_NETDEV, + .name = "reject", + .ops = &nft_reject_netdev_ops, + .policy = nft_reject_policy, + .maxattr = NFTA_REJECT_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_reject_netdev_module_init(void) +{ + return nft_register_expr(&nft_reject_netdev_type); +} + +static void __exit nft_reject_netdev_module_exit(void) +{ + nft_unregister_expr(&nft_reject_netdev_type); +} + +module_init(nft_reject_netdev_module_init); +module_exit(nft_reject_netdev_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Laura Garcia Liebana "); +MODULE_AUTHOR("Jose M. Guisado "); +MODULE_DESCRIPTION("Reject packets from netdev via nftables"); +MODULE_ALIAS_NFT_AF_EXPR(5, "reject"); diff --git a/net/netfilter/nft_rt.c b/net/netfilter/nft_rt.c new file mode 100644 index 000000000..7d21e1649 --- /dev/null +++ b/net/netfilter/nft_rt.c @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2016 Anders K. Pedersen + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_rt { + enum nft_rt_keys key:8; + u8 dreg; +}; + +static u16 get_tcpmss(const struct nft_pktinfo *pkt, const struct dst_entry *skbdst) +{ + u32 minlen = sizeof(struct ipv6hdr), mtu = dst_mtu(skbdst); + const struct sk_buff *skb = pkt->skb; + struct dst_entry *dst = NULL; + struct flowi fl; + + memset(&fl, 0, sizeof(fl)); + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + fl.u.ip4.daddr = ip_hdr(skb)->saddr; + minlen = sizeof(struct iphdr) + sizeof(struct tcphdr); + break; + case NFPROTO_IPV6: + fl.u.ip6.daddr = ipv6_hdr(skb)->saddr; + minlen = sizeof(struct ipv6hdr) + sizeof(struct tcphdr); + break; + } + + nf_route(nft_net(pkt), &dst, &fl, false, nft_pf(pkt)); + if (dst) { + mtu = min(mtu, dst_mtu(dst)); + dst_release(dst); + } + + if (mtu <= minlen || mtu > 0xffff) + return TCP_MSS_DEFAULT; + + return mtu - minlen; +} + +void nft_rt_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_rt *priv = nft_expr_priv(expr); + const struct sk_buff *skb = pkt->skb; + u32 *dest = ®s->data[priv->dreg]; + const struct dst_entry *dst; + + dst = skb_dst(skb); + if (!dst) + goto err; + + switch (priv->key) { +#ifdef CONFIG_IP_ROUTE_CLASSID + case NFT_RT_CLASSID: + *dest = dst->tclassid; + break; +#endif + case NFT_RT_NEXTHOP4: + if (nft_pf(pkt) != NFPROTO_IPV4) + goto err; + + *dest = (__force u32)rt_nexthop((const struct rtable *)dst, + ip_hdr(skb)->daddr); + break; + case NFT_RT_NEXTHOP6: + if (nft_pf(pkt) != NFPROTO_IPV6) + goto err; + + memcpy(dest, rt6_nexthop((struct rt6_info *)dst, + &ipv6_hdr(skb)->daddr), + sizeof(struct in6_addr)); + break; + case NFT_RT_TCPMSS: + nft_reg_store16(dest, get_tcpmss(pkt, dst)); + break; +#ifdef CONFIG_XFRM + case NFT_RT_XFRM: + nft_reg_store8(dest, !!dst->xfrm); 
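+ /* dest now holds 1 if an xfrm transform (e.g. IPsec) is attached + * to the packet's route, 0 otherwise. */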
+ break; +#endif + default: + WARN_ON(1); + goto err; + } + return; + +err: + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_rt_policy[NFTA_RT_MAX + 1] = { + [NFTA_RT_DREG] = { .type = NLA_U32 }, + [NFTA_RT_KEY] = { .type = NLA_U32 }, +}; + +static int nft_rt_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_rt *priv = nft_expr_priv(expr); + unsigned int len; + + if (tb[NFTA_RT_KEY] == NULL || + tb[NFTA_RT_DREG] == NULL) + return -EINVAL; + + priv->key = ntohl(nla_get_be32(tb[NFTA_RT_KEY])); + switch (priv->key) { +#ifdef CONFIG_IP_ROUTE_CLASSID + case NFT_RT_CLASSID: +#endif + case NFT_RT_NEXTHOP4: + len = sizeof(u32); + break; + case NFT_RT_NEXTHOP6: + len = sizeof(struct in6_addr); + break; + case NFT_RT_TCPMSS: + len = sizeof(u16); + break; +#ifdef CONFIG_XFRM + case NFT_RT_XFRM: + len = sizeof(u8); + break; +#endif + default: + return -EOPNOTSUPP; + } + + return nft_parse_register_store(ctx, tb[NFTA_RT_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} + +static int nft_rt_get_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_rt *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_RT_KEY, htonl(priv->key))) + goto nla_put_failure; + if (nft_dump_register(skb, NFTA_RT_DREG, priv->dreg)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static int nft_rt_validate(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nft_data **data) +{ + const struct nft_rt *priv = nft_expr_priv(expr); + unsigned int hooks; + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + switch (priv->key) { + case NFT_RT_NEXTHOP4: + case NFT_RT_NEXTHOP6: + case NFT_RT_CLASSID: + case NFT_RT_XFRM: + return 0; + case NFT_RT_TCPMSS: + hooks = (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING); + break; + default: + return -EINVAL; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} + +static const struct nft_expr_ops nft_rt_get_ops = { + .type = &nft_rt_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_rt)), + .eval = nft_rt_get_eval, + .init = nft_rt_get_init, + .dump = nft_rt_get_dump, + .validate = nft_rt_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +struct nft_expr_type nft_rt_type __read_mostly = { + .name = "rt", + .ops = &nft_rt_get_ops, + .policy = nft_rt_policy, + .maxattr = NFTA_RT_MAX, + .owner = THIS_MODULE, +}; diff --git a/net/netfilter/nft_set_bitmap.c b/net/netfilter/nft_set_bitmap.c new file mode 100644 index 000000000..1e5e7a181 --- /dev/null +++ b/net/netfilter/nft_set_bitmap.c @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2017 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_bitmap_elem { + struct list_head head; + struct nft_set_ext ext; +}; + +/* This bitmap uses two bits to represent one element. These two bits determine + * the element state in the current and the future generation. + * + * An element can be in three states. The generation cursor is represented using + * the ^ character, note that this cursor shifts on every successful transaction. + * If no transaction is going on, we observe all elements are in the following + * state: + * + * 11 = this element is active in the current generation. In case of no updates, + * ^ it stays active in the next generation. 
+ * 00 = this element is inactive in the current generation. In case of no + * ^ updates, it stays inactive in the next generation. + * + * On transaction handling, we observe these two temporary states: + * + * 01 = this element is inactive in the current generation and it becomes active + * ^ in the next one. This happens when the element is inserted but commit + * path has not yet been executed yet, so activation is still pending. On + * transaction abortion, the element is removed. + * 10 = this element is active in the current generation and it becomes inactive + * ^ in the next one. This happens when the element is deactivated but commit + * path has not yet been executed yet, so removal is still pending. On + * transaction abortion, the next generation bit is reset to go back to + * restore its previous state. + */ +struct nft_bitmap { + struct list_head list; + u16 bitmap_size; + u8 bitmap[]; +}; + +static inline void nft_bitmap_location(const struct nft_set *set, + const void *key, + u32 *idx, u32 *off) +{ + u32 k; + + if (set->klen == 2) + k = *(u16 *)key; + else + k = *(u8 *)key; + k <<= 1; + + *idx = k / BITS_PER_BYTE; + *off = k % BITS_PER_BYTE; +} + +/* Fetch the two bits that represent the element and check if it is active based + * on the generation mask. + */ +static inline bool +nft_bitmap_active(const u8 *bitmap, u32 idx, u32 off, u8 genmask) +{ + return (bitmap[idx] & (0x3 << off)) & (genmask << off); +} + +INDIRECT_CALLABLE_SCOPE +bool nft_bitmap_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + const struct nft_bitmap *priv = nft_set_priv(set); + u8 genmask = nft_genmask_cur(net); + u32 idx, off; + + nft_bitmap_location(set, key, &idx, &off); + + return nft_bitmap_active(priv->bitmap, idx, off, genmask); +} + +static struct nft_bitmap_elem * +nft_bitmap_elem_find(const struct nft_set *set, struct nft_bitmap_elem *this, + u8 genmask) +{ + const struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *be; + + list_for_each_entry_rcu(be, &priv->list, head) { + if (memcmp(nft_set_ext_key(&be->ext), + nft_set_ext_key(&this->ext), set->klen) || + !nft_set_elem_active(&be->ext, genmask)) + continue; + + return be; + } + return NULL; +} + +static void *nft_bitmap_get(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, unsigned int flags) +{ + const struct nft_bitmap *priv = nft_set_priv(set); + u8 genmask = nft_genmask_cur(net); + struct nft_bitmap_elem *be; + + list_for_each_entry_rcu(be, &priv->list, head) { + if (memcmp(nft_set_ext_key(&be->ext), elem->key.val.data, set->klen) || + !nft_set_elem_active(&be->ext, genmask)) + continue; + + return be; + } + return ERR_PTR(-ENOENT); +} + +static int nft_bitmap_insert(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext) +{ + struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *new = elem->priv, *be; + u8 genmask = nft_genmask_next(net); + u32 idx, off; + + be = nft_bitmap_elem_find(set, new, genmask); + if (be) { + *ext = &be->ext; + return -EEXIST; + } + + nft_bitmap_location(set, nft_set_ext_key(&new->ext), &idx, &off); + /* Enter 01 state. 
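+ * For example, if the next-generation mask is 0x2 and off == 4, the |= + * below sets 0x20 in bitmap[idx]: the element's next-generation bit goes + * to 1 while its current-generation bit (0x10 in this example) stays 0 + * until the transaction is committed.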
*/ + priv->bitmap[idx] |= (genmask << off); + list_add_tail_rcu(&new->head, &priv->list); + + return 0; +} + +static void nft_bitmap_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *be = elem->priv; + u8 genmask = nft_genmask_next(net); + u32 idx, off; + + nft_bitmap_location(set, nft_set_ext_key(&be->ext), &idx, &off); + /* Enter 00 state. */ + priv->bitmap[idx] &= ~(genmask << off); + list_del_rcu(&be->head); +} + +static void nft_bitmap_activate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *be = elem->priv; + u8 genmask = nft_genmask_next(net); + u32 idx, off; + + nft_bitmap_location(set, nft_set_ext_key(&be->ext), &idx, &off); + /* Enter 11 state. */ + priv->bitmap[idx] |= (genmask << off); + nft_set_elem_change_active(net, set, &be->ext); +} + +static bool nft_bitmap_flush(const struct net *net, + const struct nft_set *set, void *_be) +{ + struct nft_bitmap *priv = nft_set_priv(set); + u8 genmask = nft_genmask_next(net); + struct nft_bitmap_elem *be = _be; + u32 idx, off; + + nft_bitmap_location(set, nft_set_ext_key(&be->ext), &idx, &off); + /* Enter 10 state, similar to deactivation. */ + priv->bitmap[idx] &= ~(genmask << off); + nft_set_elem_change_active(net, set, &be->ext); + + return true; +} + +static void *nft_bitmap_deactivate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *this = elem->priv, *be; + u8 genmask = nft_genmask_next(net); + u32 idx, off; + + nft_bitmap_location(set, elem->key.val.data, &idx, &off); + + be = nft_bitmap_elem_find(set, this, genmask); + if (!be) + return NULL; + + /* Enter 10 state. */ + priv->bitmap[idx] &= ~(genmask << off); + nft_set_elem_change_active(net, set, &be->ext); + + return be; +} + +static void nft_bitmap_walk(const struct nft_ctx *ctx, + struct nft_set *set, + struct nft_set_iter *iter) +{ + const struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *be; + struct nft_set_elem elem; + + list_for_each_entry_rcu(be, &priv->list, head) { + if (iter->count < iter->skip) + goto cont; + if (!nft_set_elem_active(&be->ext, iter->genmask)) + goto cont; + + elem.priv = be; + + iter->err = iter->fn(ctx, set, iter, &elem); + + if (iter->err < 0) + return; +cont: + iter->count++; + } +} + +/* The bitmap size is pow(2, key length in bits) / bits per byte. This is + * multiplied by two since each element takes two bits. For 8 bit keys, the + * bitmap consumes 66 bytes. For 16 bit keys, 16388 bytes. 
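+ * Working the helper below through by hand: klen == 1 gives + * ((2 << 7) / 8) << 1 == 64 bytes of raw bitmap, and klen == 2 gives + * ((2 << 15) / 8) << 1 == 16384 bytes.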
+ */ +static inline u32 nft_bitmap_size(u32 klen) +{ + return ((2 << ((klen * BITS_PER_BYTE) - 1)) / BITS_PER_BYTE) << 1; +} + +static inline u64 nft_bitmap_total_size(u32 klen) +{ + return sizeof(struct nft_bitmap) + nft_bitmap_size(klen); +} + +static u64 nft_bitmap_privsize(const struct nlattr * const nla[], + const struct nft_set_desc *desc) +{ + u32 klen = ntohl(nla_get_be32(nla[NFTA_SET_KEY_LEN])); + + return nft_bitmap_total_size(klen); +} + +static int nft_bitmap_init(const struct nft_set *set, + const struct nft_set_desc *desc, + const struct nlattr * const nla[]) +{ + struct nft_bitmap *priv = nft_set_priv(set); + + INIT_LIST_HEAD(&priv->list); + priv->bitmap_size = nft_bitmap_size(set->klen); + + return 0; +} + +static void nft_bitmap_destroy(const struct nft_ctx *ctx, + const struct nft_set *set) +{ + struct nft_bitmap *priv = nft_set_priv(set); + struct nft_bitmap_elem *be, *n; + + list_for_each_entry_safe(be, n, &priv->list, head) + nf_tables_set_elem_destroy(ctx, set, be); +} + +static bool nft_bitmap_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + /* Make sure bitmaps we don't get bitmaps larger than 16 Kbytes. */ + if (desc->klen > 2) + return false; + else if (desc->expr) + return false; + + est->size = nft_bitmap_total_size(desc->klen); + est->lookup = NFT_SET_CLASS_O_1; + est->space = NFT_SET_CLASS_O_1; + + return true; +} + +const struct nft_set_type nft_set_bitmap_type = { + .ops = { + .privsize = nft_bitmap_privsize, + .elemsize = offsetof(struct nft_bitmap_elem, ext), + .estimate = nft_bitmap_estimate, + .init = nft_bitmap_init, + .destroy = nft_bitmap_destroy, + .insert = nft_bitmap_insert, + .remove = nft_bitmap_remove, + .deactivate = nft_bitmap_deactivate, + .flush = nft_bitmap_flush, + .activate = nft_bitmap_activate, + .lookup = nft_bitmap_lookup, + .walk = nft_bitmap_walk, + .get = nft_bitmap_get, + }, +}; diff --git a/net/netfilter/nft_set_hash.c b/net/netfilter/nft_set_hash.c new file mode 100644 index 000000000..2013de934 --- /dev/null +++ b/net/netfilter/nft_set_hash.c @@ -0,0 +1,788 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2014 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* We target a hash table size of 4, element hint is 75% of final size */ +#define NFT_RHASH_ELEMENT_HINT 3 + +struct nft_rhash { + struct rhashtable ht; + struct delayed_work gc_work; +}; + +struct nft_rhash_elem { + struct rhash_head node; + struct nft_set_ext ext; +}; + +struct nft_rhash_cmp_arg { + const struct nft_set *set; + const u32 *key; + u8 genmask; +}; + +static inline u32 nft_rhash_key(const void *data, u32 len, u32 seed) +{ + const struct nft_rhash_cmp_arg *arg = data; + + return jhash(arg->key, len, seed); +} + +static inline u32 nft_rhash_obj(const void *data, u32 len, u32 seed) +{ + const struct nft_rhash_elem *he = data; + + return jhash(nft_set_ext_key(&he->ext), len, seed); +} + +static inline int nft_rhash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct nft_rhash_cmp_arg *x = arg->key; + const struct nft_rhash_elem *he = ptr; + + if (memcmp(nft_set_ext_key(&he->ext), x->key, x->set->klen)) + return 1; + if (nft_set_elem_is_dead(&he->ext)) + return 1; + if (nft_set_elem_expired(&he->ext)) + return 1; + if (!nft_set_elem_active(&he->ext, x->genmask)) + return 1; + return 0; +} + +static 
const struct rhashtable_params nft_rhash_params = { + .head_offset = offsetof(struct nft_rhash_elem, node), + .hashfn = nft_rhash_key, + .obj_hashfn = nft_rhash_obj, + .obj_cmpfn = nft_rhash_cmp, + .automatic_shrinking = true, +}; + +INDIRECT_CALLABLE_SCOPE +bool nft_rhash_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_rhash *priv = nft_set_priv(set); + const struct nft_rhash_elem *he; + struct nft_rhash_cmp_arg arg = { + .genmask = nft_genmask_cur(net), + .set = set, + .key = key, + }; + + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); + if (he != NULL) + *ext = &he->ext; + + return !!he; +} + +static void *nft_rhash_get(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, unsigned int flags) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he; + struct nft_rhash_cmp_arg arg = { + .genmask = nft_genmask_cur(net), + .set = set, + .key = elem->key.val.data, + }; + + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); + if (he != NULL) + return he; + + return ERR_PTR(-ENOENT); +} + +static bool nft_rhash_update(struct nft_set *set, const u32 *key, + void *(*new)(struct nft_set *, + const struct nft_expr *, + struct nft_regs *regs), + const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_set_ext **ext) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he, *prev; + struct nft_rhash_cmp_arg arg = { + .genmask = NFT_GENMASK_ANY, + .set = set, + .key = key, + }; + + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); + if (he != NULL) + goto out; + + he = new(set, expr, regs); + if (he == NULL) + goto err1; + + prev = rhashtable_lookup_get_insert_key(&priv->ht, &arg, &he->node, + nft_rhash_params); + if (IS_ERR(prev)) + goto err2; + + /* Another cpu may race to insert the element with the same key */ + if (prev) { + nft_set_elem_destroy(set, he, true); + atomic_dec(&set->nelems); + he = prev; + } + +out: + *ext = &he->ext; + return true; + +err2: + nft_set_elem_destroy(set, he, true); + atomic_dec(&set->nelems); +err1: + return false; +} + +static int nft_rhash_insert(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he = elem->priv; + struct nft_rhash_cmp_arg arg = { + .genmask = nft_genmask_next(net), + .set = set, + .key = elem->key.val.data, + }; + struct nft_rhash_elem *prev; + + prev = rhashtable_lookup_get_insert_key(&priv->ht, &arg, &he->node, + nft_rhash_params); + if (IS_ERR(prev)) + return PTR_ERR(prev); + if (prev) { + *ext = &prev->ext; + return -EEXIST; + } + return 0; +} + +static void nft_rhash_activate(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_rhash_elem *he = elem->priv; + + nft_set_elem_change_active(net, set, &he->ext); +} + +static bool nft_rhash_flush(const struct net *net, + const struct nft_set *set, void *priv) +{ + struct nft_rhash_elem *he = priv; + + nft_set_elem_change_active(net, set, &he->ext); + + return true; +} + +static void *nft_rhash_deactivate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he; + struct nft_rhash_cmp_arg arg = { + .genmask = nft_genmask_next(net), + .set = set, + .key = elem->key.val.data, + }; + + rcu_read_lock(); + he = 
rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); + if (he) + nft_set_elem_change_active(net, set, &he->ext); + + rcu_read_unlock(); + + return he; +} + +static void nft_rhash_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he = elem->priv; + + rhashtable_remove_fast(&priv->ht, &he->node, nft_rhash_params); +} + +static bool nft_rhash_delete(const struct nft_set *set, + const u32 *key) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_cmp_arg arg = { + .genmask = NFT_GENMASK_ANY, + .set = set, + .key = key, + }; + struct nft_rhash_elem *he; + + he = rhashtable_lookup(&priv->ht, &arg, nft_rhash_params); + if (he == NULL) + return false; + + nft_set_elem_dead(&he->ext); + + return true; +} + +static void nft_rhash_walk(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_iter *iter) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_elem *he; + struct rhashtable_iter hti; + struct nft_set_elem elem; + + rhashtable_walk_enter(&priv->ht, &hti); + rhashtable_walk_start(&hti); + + while ((he = rhashtable_walk_next(&hti))) { + if (IS_ERR(he)) { + if (PTR_ERR(he) != -EAGAIN) { + iter->err = PTR_ERR(he); + break; + } + + continue; + } + + if (iter->count < iter->skip) + goto cont; + if (!nft_set_elem_active(&he->ext, iter->genmask)) + goto cont; + + elem.priv = he; + + iter->err = iter->fn(ctx, set, iter, &elem); + if (iter->err < 0) + break; + +cont: + iter->count++; + } + rhashtable_walk_stop(&hti); + rhashtable_walk_exit(&hti); +} + +static bool nft_rhash_expr_needs_gc_run(const struct nft_set *set, + struct nft_set_ext *ext) +{ + struct nft_set_elem_expr *elem_expr = nft_set_ext_expr(ext); + struct nft_expr *expr; + u32 size; + + nft_setelem_expr_foreach(expr, elem_expr, size) { + if (expr->ops->gc && + expr->ops->gc(read_pnet(&set->net), expr)) + return true; + } + + return false; +} + +static void nft_rhash_gc(struct work_struct *work) +{ + struct nftables_pernet *nft_net; + struct nft_set *set; + struct nft_rhash_elem *he; + struct nft_rhash *priv; + struct rhashtable_iter hti; + struct nft_trans_gc *gc; + struct net *net; + u32 gc_seq; + + priv = container_of(work, struct nft_rhash, gc_work.work); + set = nft_set_container_of(priv); + net = read_pnet(&set->net); + nft_net = nft_pernet(net); + gc_seq = READ_ONCE(nft_net->gc_seq); + + if (nft_set_gc_is_pending(set)) + goto done; + + gc = nft_trans_gc_alloc(set, gc_seq, GFP_KERNEL); + if (!gc) + goto done; + + rhashtable_walk_enter(&priv->ht, &hti); + rhashtable_walk_start(&hti); + + while ((he = rhashtable_walk_next(&hti))) { + if (IS_ERR(he)) { + nft_trans_gc_destroy(gc); + gc = NULL; + goto try_later; + } + + /* Ruleset has been updated, try later. */ + if (READ_ONCE(nft_net->gc_seq) != gc_seq) { + nft_trans_gc_destroy(gc); + gc = NULL; + goto try_later; + } + + if (nft_set_elem_is_dead(&he->ext)) + goto dead_elem; + + if (nft_set_ext_exists(&he->ext, NFT_SET_EXT_EXPRESSIONS) && + nft_rhash_expr_needs_gc_run(set, &he->ext)) + goto needs_gc_run; + + if (!nft_set_elem_expired(&he->ext)) + continue; +needs_gc_run: + nft_set_elem_dead(&he->ext); +dead_elem: + gc = nft_trans_gc_queue_async(gc, gc_seq, GFP_ATOMIC); + if (!gc) + goto try_later; + + nft_trans_gc_elem_add(gc, he); + } + + gc = nft_trans_gc_catchall_async(gc, gc_seq); + +try_later: + /* catchall list iteration requires rcu read side lock. 
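+ * (the rhashtable walk keeps us inside an rcu read-side section until + * rhashtable_walk_stop() below, which is why the catchall pass above is + * done before the walk is stopped)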
*/ + rhashtable_walk_stop(&hti); + rhashtable_walk_exit(&hti); + + if (gc) + nft_trans_gc_queue_async_done(gc); + +done: + queue_delayed_work(system_power_efficient_wq, &priv->gc_work, + nft_set_gc_interval(set)); +} + +static u64 nft_rhash_privsize(const struct nlattr * const nla[], + const struct nft_set_desc *desc) +{ + return sizeof(struct nft_rhash); +} + +static void nft_rhash_gc_init(const struct nft_set *set) +{ + struct nft_rhash *priv = nft_set_priv(set); + + queue_delayed_work(system_power_efficient_wq, &priv->gc_work, + nft_set_gc_interval(set)); +} + +static int nft_rhash_init(const struct nft_set *set, + const struct nft_set_desc *desc, + const struct nlattr * const tb[]) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct rhashtable_params params = nft_rhash_params; + int err; + + params.nelem_hint = desc->size ?: NFT_RHASH_ELEMENT_HINT; + params.key_len = set->klen; + + err = rhashtable_init(&priv->ht, ¶ms); + if (err < 0) + return err; + + INIT_DEFERRABLE_WORK(&priv->gc_work, nft_rhash_gc); + if (set->flags & (NFT_SET_TIMEOUT | NFT_SET_EVAL)) + nft_rhash_gc_init(set); + + return 0; +} + +struct nft_rhash_ctx { + const struct nft_ctx ctx; + const struct nft_set *set; +}; + +static void nft_rhash_elem_destroy(void *ptr, void *arg) +{ + struct nft_rhash_ctx *rhash_ctx = arg; + + nf_tables_set_elem_destroy(&rhash_ctx->ctx, rhash_ctx->set, ptr); +} + +static void nft_rhash_destroy(const struct nft_ctx *ctx, + const struct nft_set *set) +{ + struct nft_rhash *priv = nft_set_priv(set); + struct nft_rhash_ctx rhash_ctx = { + .ctx = *ctx, + .set = set, + }; + + cancel_delayed_work_sync(&priv->gc_work); + rhashtable_free_and_destroy(&priv->ht, nft_rhash_elem_destroy, + (void *)&rhash_ctx); +} + +/* Number of buckets is stored in u32, so cap our result to 1U<<31 */ +#define NFT_MAX_BUCKETS (1U << 31) + +static u32 nft_hash_buckets(u32 size) +{ + u64 val = div_u64((u64)size * 4, 3); + + if (val >= NFT_MAX_BUCKETS) + return NFT_MAX_BUCKETS; + + return roundup_pow_of_two(val); +} + +static bool nft_rhash_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + est->size = ~0; + est->lookup = NFT_SET_CLASS_O_1; + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +struct nft_hash { + u32 seed; + u32 buckets; + struct hlist_head table[]; +}; + +struct nft_hash_elem { + struct hlist_node node; + struct nft_set_ext ext; +}; + +INDIRECT_CALLABLE_SCOPE +bool nft_hash_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_hash *priv = nft_set_priv(set); + u8 genmask = nft_genmask_cur(net); + const struct nft_hash_elem *he; + u32 hash; + + hash = jhash(key, set->klen, priv->seed); + hash = reciprocal_scale(hash, priv->buckets); + hlist_for_each_entry_rcu(he, &priv->table[hash], node) { + if (!memcmp(nft_set_ext_key(&he->ext), key, set->klen) && + nft_set_elem_active(&he->ext, genmask)) { + *ext = &he->ext; + return true; + } + } + return false; +} + +static void *nft_hash_get(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, unsigned int flags) +{ + struct nft_hash *priv = nft_set_priv(set); + u8 genmask = nft_genmask_cur(net); + struct nft_hash_elem *he; + u32 hash; + + hash = jhash(elem->key.val.data, set->klen, priv->seed); + hash = reciprocal_scale(hash, priv->buckets); + hlist_for_each_entry_rcu(he, &priv->table[hash], node) { + if (!memcmp(nft_set_ext_key(&he->ext), elem->key.val.data, set->klen) && + nft_set_elem_active(&he->ext, genmask)) + 
return he; + } + return ERR_PTR(-ENOENT); +} + +INDIRECT_CALLABLE_SCOPE +bool nft_hash_lookup_fast(const struct net *net, + const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_hash *priv = nft_set_priv(set); + u8 genmask = nft_genmask_cur(net); + const struct nft_hash_elem *he; + u32 hash, k1, k2; + + k1 = *key; + hash = jhash_1word(k1, priv->seed); + hash = reciprocal_scale(hash, priv->buckets); + hlist_for_each_entry_rcu(he, &priv->table[hash], node) { + k2 = *(u32 *)nft_set_ext_key(&he->ext)->data; + if (k1 == k2 && + nft_set_elem_active(&he->ext, genmask)) { + *ext = &he->ext; + return true; + } + } + return false; +} + +static u32 nft_jhash(const struct nft_set *set, const struct nft_hash *priv, + const struct nft_set_ext *ext) +{ + const struct nft_data *key = nft_set_ext_key(ext); + u32 hash, k1; + + if (set->klen == 4) { + k1 = *(u32 *)key; + hash = jhash_1word(k1, priv->seed); + } else { + hash = jhash(key, set->klen, priv->seed); + } + hash = reciprocal_scale(hash, priv->buckets); + + return hash; +} + +static int nft_hash_insert(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext) +{ + struct nft_hash_elem *this = elem->priv, *he; + struct nft_hash *priv = nft_set_priv(set); + u8 genmask = nft_genmask_next(net); + u32 hash; + + hash = nft_jhash(set, priv, &this->ext); + hlist_for_each_entry(he, &priv->table[hash], node) { + if (!memcmp(nft_set_ext_key(&this->ext), + nft_set_ext_key(&he->ext), set->klen) && + nft_set_elem_active(&he->ext, genmask)) { + *ext = &he->ext; + return -EEXIST; + } + } + hlist_add_head_rcu(&this->node, &priv->table[hash]); + return 0; +} + +static void nft_hash_activate(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_hash_elem *he = elem->priv; + + nft_set_elem_change_active(net, set, &he->ext); +} + +static bool nft_hash_flush(const struct net *net, + const struct nft_set *set, void *priv) +{ + struct nft_hash_elem *he = priv; + + nft_set_elem_change_active(net, set, &he->ext); + return true; +} + +static void *nft_hash_deactivate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_hash *priv = nft_set_priv(set); + struct nft_hash_elem *this = elem->priv, *he; + u8 genmask = nft_genmask_next(net); + u32 hash; + + hash = nft_jhash(set, priv, &this->ext); + hlist_for_each_entry(he, &priv->table[hash], node) { + if (!memcmp(nft_set_ext_key(&he->ext), &elem->key.val, + set->klen) && + nft_set_elem_active(&he->ext, genmask)) { + nft_set_elem_change_active(net, set, &he->ext); + return he; + } + } + return NULL; +} + +static void nft_hash_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_hash_elem *he = elem->priv; + + hlist_del_rcu(&he->node); +} + +static void nft_hash_walk(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_iter *iter) +{ + struct nft_hash *priv = nft_set_priv(set); + struct nft_hash_elem *he; + struct nft_set_elem elem; + int i; + + for (i = 0; i < priv->buckets; i++) { + hlist_for_each_entry_rcu(he, &priv->table[i], node) { + if (iter->count < iter->skip) + goto cont; + if (!nft_set_elem_active(&he->ext, iter->genmask)) + goto cont; + + elem.priv = he; + + iter->err = iter->fn(ctx, set, iter, &elem); + if (iter->err < 0) + return; +cont: + iter->count++; + } + } +} + +static u64 nft_hash_privsize(const struct nlattr * const nla[], + const struct nft_set_desc *desc) 
+{ + return sizeof(struct nft_hash) + + (u64)nft_hash_buckets(desc->size) * sizeof(struct hlist_head); +} + +static int nft_hash_init(const struct nft_set *set, + const struct nft_set_desc *desc, + const struct nlattr * const tb[]) +{ + struct nft_hash *priv = nft_set_priv(set); + + priv->buckets = nft_hash_buckets(desc->size); + get_random_bytes(&priv->seed, sizeof(priv->seed)); + + return 0; +} + +static void nft_hash_destroy(const struct nft_ctx *ctx, + const struct nft_set *set) +{ + struct nft_hash *priv = nft_set_priv(set); + struct nft_hash_elem *he; + struct hlist_node *next; + int i; + + for (i = 0; i < priv->buckets; i++) { + hlist_for_each_entry_safe(he, next, &priv->table[i], node) { + hlist_del_rcu(&he->node); + nf_tables_set_elem_destroy(ctx, set, he); + } + } +} + +static bool nft_hash_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + if (!desc->size) + return false; + + if (desc->klen == 4) + return false; + + est->size = sizeof(struct nft_hash) + + (u64)nft_hash_buckets(desc->size) * sizeof(struct hlist_head) + + (u64)desc->size * sizeof(struct nft_hash_elem); + est->lookup = NFT_SET_CLASS_O_1; + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +static bool nft_hash_fast_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + if (!desc->size) + return false; + + if (desc->klen != 4) + return false; + + est->size = sizeof(struct nft_hash) + + (u64)nft_hash_buckets(desc->size) * sizeof(struct hlist_head) + + (u64)desc->size * sizeof(struct nft_hash_elem); + est->lookup = NFT_SET_CLASS_O_1; + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +const struct nft_set_type nft_set_rhash_type = { + .features = NFT_SET_MAP | NFT_SET_OBJECT | + NFT_SET_TIMEOUT | NFT_SET_EVAL, + .ops = { + .privsize = nft_rhash_privsize, + .elemsize = offsetof(struct nft_rhash_elem, ext), + .estimate = nft_rhash_estimate, + .init = nft_rhash_init, + .gc_init = nft_rhash_gc_init, + .destroy = nft_rhash_destroy, + .insert = nft_rhash_insert, + .activate = nft_rhash_activate, + .deactivate = nft_rhash_deactivate, + .flush = nft_rhash_flush, + .remove = nft_rhash_remove, + .lookup = nft_rhash_lookup, + .update = nft_rhash_update, + .delete = nft_rhash_delete, + .walk = nft_rhash_walk, + .get = nft_rhash_get, + }, +}; + +const struct nft_set_type nft_set_hash_type = { + .features = NFT_SET_MAP | NFT_SET_OBJECT, + .ops = { + .privsize = nft_hash_privsize, + .elemsize = offsetof(struct nft_hash_elem, ext), + .estimate = nft_hash_estimate, + .init = nft_hash_init, + .destroy = nft_hash_destroy, + .insert = nft_hash_insert, + .activate = nft_hash_activate, + .deactivate = nft_hash_deactivate, + .flush = nft_hash_flush, + .remove = nft_hash_remove, + .lookup = nft_hash_lookup, + .walk = nft_hash_walk, + .get = nft_hash_get, + }, +}; + +const struct nft_set_type nft_set_hash_fast_type = { + .features = NFT_SET_MAP | NFT_SET_OBJECT, + .ops = { + .privsize = nft_hash_privsize, + .elemsize = offsetof(struct nft_hash_elem, ext), + .estimate = nft_hash_fast_estimate, + .init = nft_hash_init, + .destroy = nft_hash_destroy, + .insert = nft_hash_insert, + .activate = nft_hash_activate, + .deactivate = nft_hash_deactivate, + .flush = nft_hash_flush, + .remove = nft_hash_remove, + .lookup = nft_hash_lookup_fast, + .walk = nft_hash_walk, + .get = nft_hash_get, + }, +}; diff --git a/net/netfilter/nft_set_pipapo.c b/net/netfilter/nft_set_pipapo.c new file mode 100644 index 000000000..4e1cc3172 --- /dev/null +++ 
b/net/netfilter/nft_set_pipapo.c @@ -0,0 +1,2331 @@ +// SPDX-License-Identifier: GPL-2.0-only + +/* PIPAPO: PIle PAcket POlicies: set for arbitrary concatenations of ranges + * + * Copyright (c) 2019-2020 Red Hat GmbH + * + * Author: Stefano Brivio + */ + +/** + * DOC: Theory of Operation + * + * + * Problem + * ------- + * + * Match packet bytes against entries composed of ranged or non-ranged packet + * field specifiers, mapping them to arbitrary references. For example: + * + * :: + * + * --- fields ---> + * | [net],[port],[net]... => [reference] + * entries [net],[port],[net]... => [reference] + * | [net],[port],[net]... => [reference] + * V ... + * + * where [net] fields can be IP ranges or netmasks, and [port] fields are port + * ranges. Arbitrary packet fields can be matched. + * + * + * Algorithm Overview + * ------------------ + * + * This algorithm is loosely inspired by [Ligatti 2010], and fundamentally + * relies on the consideration that every contiguous range in a space of b bits + * can be converted into b * 2 netmasks, from Theorem 3 in [Rottenstreich 2010], + * as also illustrated in Section 9 of [Kogan 2014]. + * + * Classification against a number of entries, that require matching given bits + * of a packet field, is performed by grouping those bits in sets of arbitrary + * size, and classifying packet bits one group at a time. + * + * Example: + * to match the source port (16 bits) of a packet, we can divide those 16 bits + * in 4 groups of 4 bits each. Given the entry: + * 0000 0001 0101 1001 + * and a packet with source port: + * 0000 0001 1010 1001 + * first and second groups match, but the third doesn't. We conclude that the + * packet doesn't match the given entry. + * + * Translate the set to a sequence of lookup tables, one per field. Each table + * has two dimensions: bit groups to be matched for a single packet field, and + * all the possible values of said groups (buckets). Input entries are + * represented as one or more rules, depending on the number of composing + * netmasks for the given field specifier, and a group match is indicated as a + * set bit, with number corresponding to the rule index, in all the buckets + * whose value matches the entry for a given group. + * + * Rules are mapped between fields through an array of x, n pairs, with each + * item mapping a matched rule to one or more rules. The position of the pair in + * the array indicates the matched rule to be mapped to the next field, x + * indicates the first rule index in the next field, and n the amount of + * next-field rules the current rule maps to. + * + * The mapping array for the last field maps to the desired references. + * + * To match, we perform table lookups using the values of grouped packet bits, + * and use a sequence of bitwise operations to progressively evaluate rule + * matching. 
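As a concrete illustration of the bit-group idea above, here is a minimal standalone sketch (not part of the patch; group4() and match_entry() are invented names). It splits a 16-bit source port into four 4-bit groups and compares it group by group against one non-ranged entry, as in the source-port example. In the real set type the per-group comparison is replaced by a lookup-table bucket that returns a bitmap of candidate rules, so all entries are screened at once rather than one at a time.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static unsigned int group4(uint16_t v, int g)
{
        /* group 0 holds bits 15..12, group 3 holds bits 3..0 */
        return (v >> (12 - 4 * g)) & 0xf;
}

static bool match_entry(uint16_t entry, uint16_t packet)
{
        int g;

        for (g = 0; g < 4; g++) {
                if (group4(entry, g) != group4(packet, g)) {
                        printf("group %d differs: %x vs %x\n",
                               g, group4(entry, g), group4(packet, g));
                        return false;
                }
        }
        return true;
}

int main(void)
{
        /* 0000 0001 0101 1001 vs 0000 0001 1010 1001: group 2 differs */
        printf("%s\n", match_entry(0x0159, 0x01a9) ? "match" : "no match");
        return 0;
}

With the values from the example, the first and second groups agree and the third does not, so the packet is rejected without looking at the last group.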
+ * + * A stand-alone, reference implementation, also including notes about possible + * future optimisations, is available at: + * https://pipapo.lameexcu.se/ + * + * Insertion + * --------- + * + * - For each packet field: + * + * - divide the b packet bits we want to classify into groups of size t, + * obtaining ceil(b / t) groups + * + * Example: match on destination IP address, with t = 4: 32 bits, 8 groups + * of 4 bits each + * + * - allocate a lookup table with one column ("bucket") for each possible + * value of a group, and with one row for each group + * + * Example: 8 groups, 2^4 buckets: + * + * :: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 + * 1 + * 2 + * 3 + * 4 + * 5 + * 6 + * 7 + * + * - map the bits we want to classify for the current field, for a given + * entry, to a single rule for non-ranged and netmask set items, and to one + * or multiple rules for ranges. Ranges are expanded to composing netmasks + * by pipapo_expand(). + * + * Example: 2 entries, 10.0.0.5:1024 and 192.168.1.0-192.168.2.1:2048 + * - rule #0: 10.0.0.5 + * - rule #1: 192.168.1.0/24 + * - rule #2: 192.168.2.0/31 + * + * - insert references to the rules in the lookup table, selecting buckets + * according to bit values of a rule in the given group. This is done by + * pipapo_insert(). + * + * Example: given: + * - rule #0: 10.0.0.5 mapping to buckets + * < 0 10 0 0 0 0 0 5 > + * - rule #1: 192.168.1.0/24 mapping to buckets + * < 12 0 10 8 0 1 < 0..15 > < 0..15 > > + * - rule #2: 192.168.2.0/31 mapping to buckets + * < 12 0 10 8 0 2 0 < 0..1 > > + * + * these bits are set in the lookup table: + * + * :: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0 1,2 + * 1 1,2 0 + * 2 0 1,2 + * 3 0 1,2 + * 4 0,1,2 + * 5 0 1 2 + * 6 0,1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * 7 1,2 1,2 1 1 1 0,1 1 1 1 1 1 1 1 1 1 1 + * + * - if this is not the last field in the set, fill a mapping array that maps + * rules from the lookup table to rules belonging to the same entry in + * the next lookup table, done by pipapo_map(). + * + * Note that as rules map to contiguous ranges of rules, given how netmask + * expansion and insertion is performed, &union nft_pipapo_map_bucket stores + * this information as pairs of first rule index, rule count. + * + * Example: 2 entries, 10.0.0.5:1024 and 192.168.1.0-192.168.2.1:2048, + * given lookup table #0 for field 0 (see example above): + * + * :: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0 1,2 + * 1 1,2 0 + * 2 0 1,2 + * 3 0 1,2 + * 4 0,1,2 + * 5 0 1 2 + * 6 0,1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * 7 1,2 1,2 1 1 1 0,1 1 1 1 1 1 1 1 1 1 1 + * + * and lookup table #1 for field 1 with: + * - rule #0: 1024 mapping to buckets + * < 0 0 4 0 > + * - rule #1: 2048 mapping to buckets + * < 0 0 5 0 > + * + * :: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0,1 + * 1 0,1 + * 2 0 1 + * 3 0,1 + * + * we need to map rules for 10.0.0.5 in lookup table #0 (rule #0) to 1024 + * in lookup table #1 (rule #0) and rules for 192.168.1.0-192.168.2.1 + * (rules #1, #2) to 2048 in lookup table #2 (rule #1): + * + * :: + * + * rule indices in current field: 0 1 2 + * map to rules in next field: 0 1 1 + * + * - if this is the last field in the set, fill a mapping array that maps + * rules from the last lookup table to element pointers, also done by + * pipapo_map(). + * + * Note that, in this implementation, we have two elements (start, end) for + * each entry. 
The pointer to the end element is stored in this array, and + * the pointer to the start element is linked from it. + * + * Example: entry 10.0.0.5:1024 has a corresponding &struct nft_pipapo_elem + * pointer, 0x66, and element for 192.168.1.0-192.168.2.1:2048 is at 0x42. + * From the rules of lookup table #1 as mapped above: + * + * :: + * + * rule indices in last field: 0 1 + * map to elements: 0x66 0x42 + * + * + * Matching + * -------- + * + * We use a result bitmap, with the size of a single lookup table bucket, to + * represent the matching state that applies at every algorithm step. This is + * done by pipapo_lookup(). + * + * - For each packet field: + * + * - start with an all-ones result bitmap (res_map in pipapo_lookup()) + * + * - perform a lookup into the table corresponding to the current field, + * for each group, and at every group, AND the current result bitmap with + * the value from the lookup table bucket + * + * :: + * + * Example: 192.168.1.5 < 12 0 10 8 0 1 0 5 >, with lookup table from + * insertion examples. + * Lookup table buckets are at least 3 bits wide, we'll assume 8 bits for + * convenience in this example. Initial result bitmap is 0xff, the steps + * below show the value of the result bitmap after each group is processed: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0 1,2 + * result bitmap is now: 0xff & 0x6 [bucket 12] = 0x6 + * + * 1 1,2 0 + * result bitmap is now: 0x6 & 0x6 [bucket 0] = 0x6 + * + * 2 0 1,2 + * result bitmap is now: 0x6 & 0x6 [bucket 10] = 0x6 + * + * 3 0 1,2 + * result bitmap is now: 0x6 & 0x6 [bucket 8] = 0x6 + * + * 4 0,1,2 + * result bitmap is now: 0x6 & 0x7 [bucket 0] = 0x6 + * + * 5 0 1 2 + * result bitmap is now: 0x6 & 0x2 [bucket 1] = 0x2 + * + * 6 0,1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * result bitmap is now: 0x2 & 0x7 [bucket 0] = 0x2 + * + * 7 1,2 1,2 1 1 1 0,1 1 1 1 1 1 1 1 1 1 1 + * final result bitmap for this field is: 0x2 & 0x3 [bucket 5] = 0x2 + * + * - at the next field, start with a new, all-zeroes result bitmap. For each + * bit set in the previous result bitmap, fill the new result bitmap + * (fill_map in pipapo_lookup()) with the rule indices from the + * corresponding buckets of the mapping field for this field, done by + * pipapo_refill() + * + * Example: with mapping table from insertion examples, with the current + * result bitmap from the previous example, 0x02: + * + * :: + * + * rule indices in current field: 0 1 2 + * map to rules in next field: 0 1 1 + * + * the new result bitmap will be 0x02: rule 1 was set, and rule 1 will be + * set. + * + * We can now extend this example to cover the second iteration of the step + * above (lookup and AND bitmap): assuming the port field is + * 2048 < 0 0 5 0 >, with starting result bitmap 0x2, and lookup table + * for "port" field from pre-computation example: + * + * :: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0,1 + * 1 0,1 + * 2 0 1 + * 3 0,1 + * + * operations are: 0x2 & 0x3 [bucket 0] & 0x3 [bucket 0] & 0x2 [bucket 5] + * & 0x3 [bucket 0], resulting bitmap is 0x2. + * + * - if this is the last field in the set, look up the value from the mapping + * array corresponding to the final result bitmap + * + * Example: 0x2 resulting bitmap from 192.168.1.5:2048, mapping array for + * last field from insertion example: + * + * :: + * + * rule indices in last field: 0 1 + * map to elements: 0x66 0x42 + * + * the matching element is at 0x42. 
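The matching walkthrough above condenses into a short standalone sketch. This is an illustration of the steps just described, not the kernel code: it hardcodes two fields with 4-bit groups, uses a single uint8_t as result bitmap (so at most 8 rules per field), and all names (struct field, and_buckets(), refill(), lookup()) are invented; the in-tree implementation uses per-CPU scratch maps, arbitrary bucket widths and group sizes, and pipapo_refill().

#include <stddef.h>
#include <stdint.h>

#define BUCKETS         16      /* 4-bit groups: 2^4 bucket values */
#define MAX_RULES       8       /* one uint8_t result bitmap, as in the examples */

struct field {
        int groups;                     /* 8 for an IPv4 address, 4 for a port */
        uint8_t lt[8][BUCKETS];         /* lt[group][bucket]: bitmap of rules */
        struct {
                int to, n;              /* first rule and rule count in next field */
        } mt[MAX_RULES];
        void *elem[MAX_RULES];          /* last field only: rule -> element */
};

/* AND the bucket selected by each 4-bit group of 'key' into 'res'. */
static uint8_t and_buckets(const struct field *f, const uint8_t *key,
                           uint8_t res)
{
        int g;

        for (g = 0; g < f->groups; g++) {
                uint8_t v = (g & 1) ? key[g / 2] & 0xf : key[g / 2] >> 4;

                res &= f->lt[g][v];
        }
        return res;
}

/* For each rule bit set in 'res', set the next-field rules it maps to. */
static uint8_t refill(const struct field *f, uint8_t res)
{
        uint8_t next = 0;
        int r;

        for (r = 0; r < MAX_RULES; r++)
                if (res & (1 << r))
                        next |= ((1 << f->mt[r].n) - 1) << f->mt[r].to;
        return next;
}

static void *lookup(const struct field *addr_f, const struct field *port_f,
                    const uint8_t *addr, const uint8_t *port)
{
        uint8_t res;
        int r;

        res = and_buckets(addr_f, addr, 0xff);          /* start from all ones */
        if (!res)
                return NULL;
        res = and_buckets(port_f, port, refill(addr_f, res));

        for (r = 0; r < MAX_RULES; r++)                 /* last field: pick element */
                if (res & (1 << r))
                        return port_f->elem[r];
        return NULL;
}

For the 192.168.1.5:2048 walkthrough above, filled with the example tables: the address field ends with only rule 1 set (0x02), refill() maps that to 0x02 for the port field, the port buckets keep it at 0x02, and the surviving rule selects the element stored for 192.168.1.0-192.168.2.1:2048 (0x42 in the example).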
+ * + * + * References + * ---------- + * + * [Ligatti 2010] + * A Packet-classification Algorithm for Arbitrary Bitmask Rules, with + * Automatic Time-space Tradeoffs + * Jay Ligatti, Josh Kuhn, and Chris Gage. + * Proceedings of the IEEE International Conference on Computer + * Communication Networks (ICCCN), August 2010. + * https://www.cse.usf.edu/~ligatti/papers/grouper-conf.pdf + * + * [Rottenstreich 2010] + * Worst-Case TCAM Rule Expansion + * Ori Rottenstreich and Isaac Keslassy. + * 2010 Proceedings IEEE INFOCOM, San Diego, CA, 2010. + * http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.212.4592&rep=rep1&type=pdf + * + * [Kogan 2014] + * SAX-PAC (Scalable And eXpressive PAcket Classification) + * Kirill Kogan, Sergey Nikolenko, Ori Rottenstreich, William Culhane, + * and Patrick Eugster. + * Proceedings of the 2014 ACM conference on SIGCOMM, August 2014. + * https://www.sigcomm.org/sites/default/files/ccr/papers/2014/August/2619239-2626294.pdf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "nft_set_pipapo_avx2.h" +#include "nft_set_pipapo.h" + +/* Current working bitmap index, toggled between field matches */ +static DEFINE_PER_CPU(bool, nft_pipapo_scratch_index); + +/** + * pipapo_refill() - For each set bit, set bits from selected mapping table item + * @map: Bitmap to be scanned for set bits + * @len: Length of bitmap in longs + * @rules: Number of rules in field + * @dst: Destination bitmap + * @mt: Mapping table containing bit set specifiers + * @match_only: Find a single bit and return, don't fill + * + * Iteration over set bits with __builtin_ctzl(): Daniel Lemire, public domain. + * + * For each bit set in map, select the bucket from mapping table with index + * corresponding to the position of the bit set. Use start bit and amount of + * bits specified in bucket to fill region in dst. + * + * Return: -1 on no match, bit position on 'match_only', 0 otherwise. + */ +int pipapo_refill(unsigned long *map, int len, int rules, unsigned long *dst, + union nft_pipapo_map_bucket *mt, bool match_only) +{ + unsigned long bitset; + int k, ret = -1; + + for (k = 0; k < len; k++) { + bitset = map[k]; + while (bitset) { + unsigned long t = bitset & -bitset; + int r = __builtin_ctzl(bitset); + int i = k * BITS_PER_LONG + r; + + if (unlikely(i >= rules)) { + map[k] = 0; + return -1; + } + + if (match_only) { + bitmap_clear(map, i, 1); + return i; + } + + ret = 0; + + bitmap_set(dst, mt[i].to, mt[i].n); + + bitset ^= t; + } + map[k] = 0; + } + + return ret; +} + +/** + * nft_pipapo_lookup() - Lookup function + * @net: Network namespace + * @set: nftables API set representation + * @key: nftables API element representation containing key data + * @ext: nftables API extension pointer, filled with matching reference + * + * For more details, see DOC: Theory of Operation. + * + * Return: true on match, false otherwise. + */ +bool nft_pipapo_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_pipapo *priv = nft_set_priv(set); + unsigned long *res_map, *fill_map; + u8 genmask = nft_genmask_cur(net); + const u8 *rp = (const u8 *)key; + struct nft_pipapo_match *m; + struct nft_pipapo_field *f; + bool map_index; + int i; + + local_bh_disable(); + + map_index = raw_cpu_read(nft_pipapo_scratch_index); + + m = rcu_dereference(priv->match); + + if (unlikely(!m || !*raw_cpu_ptr(m->scratch))) + goto out; + + res_map = *raw_cpu_ptr(m->scratch) + (map_index ? 
m->bsize_max : 0); + fill_map = *raw_cpu_ptr(m->scratch) + (map_index ? 0 : m->bsize_max); + + memset(res_map, 0xff, m->bsize_max * sizeof(*res_map)); + + nft_pipapo_for_each_field(f, i, m) { + bool last = i == m->field_count - 1; + int b; + + /* For each bit group: select lookup table bucket depending on + * packet bytes value, then AND bucket value + */ + if (likely(f->bb == 8)) + pipapo_and_field_buckets_8bit(f, res_map, rp); + else + pipapo_and_field_buckets_4bit(f, res_map, rp); + NFT_PIPAPO_GROUP_BITS_ARE_8_OR_4; + + rp += f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f); + + /* Now populate the bitmap for the next field, unless this is + * the last field, in which case return the matched 'ext' + * pointer if any. + * + * Now res_map contains the matching bitmap, and fill_map is the + * bitmap for the next field. + */ +next_match: + b = pipapo_refill(res_map, f->bsize, f->rules, fill_map, f->mt, + last); + if (b < 0) { + raw_cpu_write(nft_pipapo_scratch_index, map_index); + local_bh_enable(); + + return false; + } + + if (last) { + *ext = &f->mt[b].e->ext; + if (unlikely(nft_set_elem_expired(*ext) || + !nft_set_elem_active(*ext, genmask))) + goto next_match; + + /* Last field: we're just returning the key without + * filling the initial bitmap for the next field, so the + * current inactive bitmap is clean and can be reused as + * *next* bitmap (not initial) for the next packet. + */ + raw_cpu_write(nft_pipapo_scratch_index, map_index); + local_bh_enable(); + + return true; + } + + /* Swap bitmap indices: res_map is the initial bitmap for the + * next field, and fill_map is guaranteed to be all-zeroes at + * this point. + */ + map_index = !map_index; + swap(res_map, fill_map); + + rp += NFT_PIPAPO_GROUPS_PADDING(f); + } + +out: + local_bh_enable(); + return false; +} + +/** + * pipapo_get() - Get matching element reference given key data + * @net: Network namespace + * @set: nftables API set representation + * @data: Key data to be matched against existing elements + * @genmask: If set, check that element is active in given genmask + * + * This is essentially the same as the lookup function, except that it matches + * key data against the uncommitted copy and doesn't use preallocated maps for + * bitmap results. + * + * Return: pointer to &struct nft_pipapo_elem on match, error pointer otherwise. 
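For reference, the scratch area used by nft_pipapo_lookup() above is one per-CPU allocation holding both bitmaps back to back, with the toggling nft_pipapo_scratch_index choosing which half is the result map and which is the fill map. A hypothetical helper, shown only to make that layout explicit (scratch_halves() is an invented name):

#include <stdbool.h>
#include <stddef.h>

/* Pick the two halves of a 2 * bsize_max scratch area: 'index' selects
 * which half holds the current result bitmap, the other one is filled
 * for the next field.
 */
static void scratch_halves(unsigned long *scratch, size_t bsize_max,
                           bool index,
                           unsigned long **res_map, unsigned long **fill_map)
{
        *res_map = scratch + (index ? bsize_max : 0);
        *fill_map = scratch + (index ? 0 : bsize_max);
}

Only the result half is reinitialised (to all ones) on each lookup; the fill half must already be all zeroes, an invariant kept because pipapo_refill() clears the bitmap it scans, and the reason the lookup is careful about which index value it writes back before returning.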
+ */ +static struct nft_pipapo_elem *pipapo_get(const struct net *net, + const struct nft_set *set, + const u8 *data, u8 genmask) +{ + struct nft_pipapo_elem *ret = ERR_PTR(-ENOENT); + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *m = priv->clone; + unsigned long *res_map, *fill_map = NULL; + struct nft_pipapo_field *f; + int i; + + res_map = kmalloc_array(m->bsize_max, sizeof(*res_map), GFP_ATOMIC); + if (!res_map) { + ret = ERR_PTR(-ENOMEM); + goto out; + } + + fill_map = kcalloc(m->bsize_max, sizeof(*res_map), GFP_ATOMIC); + if (!fill_map) { + ret = ERR_PTR(-ENOMEM); + goto out; + } + + memset(res_map, 0xff, m->bsize_max * sizeof(*res_map)); + + nft_pipapo_for_each_field(f, i, m) { + bool last = i == m->field_count - 1; + int b; + + /* For each bit group: select lookup table bucket depending on + * packet bytes value, then AND bucket value + */ + if (f->bb == 8) + pipapo_and_field_buckets_8bit(f, res_map, data); + else if (f->bb == 4) + pipapo_and_field_buckets_4bit(f, res_map, data); + else + BUG(); + + data += f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f); + + /* Now populate the bitmap for the next field, unless this is + * the last field, in which case return the matched 'ext' + * pointer if any. + * + * Now res_map contains the matching bitmap, and fill_map is the + * bitmap for the next field. + */ +next_match: + b = pipapo_refill(res_map, f->bsize, f->rules, fill_map, f->mt, + last); + if (b < 0) + goto out; + + if (last) { + if (nft_set_elem_expired(&f->mt[b].e->ext)) + goto next_match; + if ((genmask && + !nft_set_elem_active(&f->mt[b].e->ext, genmask))) + goto next_match; + + ret = f->mt[b].e; + goto out; + } + + data += NFT_PIPAPO_GROUPS_PADDING(f); + + /* Swap bitmap indices: fill_map will be the initial bitmap for + * the next field (i.e. the new res_map), and res_map is + * guaranteed to be all-zeroes at this point, ready to be filled + * according to the next mapping table. + */ + swap(res_map, fill_map); + } + +out: + kfree(fill_map); + kfree(res_map); + return ret; +} + +/** + * nft_pipapo_get() - Get matching element reference given key data + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * @flags: Unused + */ +static void *nft_pipapo_get(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, unsigned int flags) +{ + return pipapo_get(net, set, (const u8 *)elem->key.val.data, + nft_genmask_cur(net)); +} + +/** + * pipapo_resize() - Resize lookup or mapping table, or both + * @f: Field containing lookup and mapping tables + * @old_rules: Previous amount of rules in field + * @rules: New amount of rules + * + * Increase, decrease or maintain tables size depending on new amount of rules, + * and copy data over. In case the new size is smaller, throw away data for + * highest-numbered rules. + * + * Return: 0 on success, -ENOMEM on allocation failure. 
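Before pipapo_resize() below, it may help to spell out the flat layout of a field's lookup table, since that function copies buckets between two tables whose per-bucket width (bsize, in longs) differs. A minimal sketch of addressing a bucket and setting a rule bit under that layout (bucket_set() and LONG_BITS are invented names; the in-tree equivalent is pipapo_bucket_set() further below, which also handles alignment via NFT_PIPAPO_LT_ALIGN()):

#include <limits.h>
#include <stddef.h>

#define LONG_BITS       (CHAR_BIT * sizeof(unsigned long))

/* One flat array of longs: 'groups' groups, each with 'buckets' buckets,
 * each bucket 'bsize' longs wide, one bit per rule within a bucket.
 * Set the bit of rule 'rule' in the bucket for value 'v' of group 'group'.
 */
static void bucket_set(unsigned long *lt, size_t bsize, unsigned int buckets,
                       unsigned int group, unsigned int v, unsigned int rule)
{
        unsigned long *pos = lt + (group * buckets + v) * bsize;

        pos[rule / LONG_BITS] |= 1UL << (rule % LONG_BITS);
}

Growing the rule count past a multiple of the word size bumps bsize, which is why pipapo_resize() walks group by group and bucket by bucket, copying the smaller of the two widths per bucket and skipping the difference, rather than doing one flat copy.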
+ */ +static int pipapo_resize(struct nft_pipapo_field *f, int old_rules, int rules) +{ + long *new_lt = NULL, *new_p, *old_lt = f->lt, *old_p; + union nft_pipapo_map_bucket *new_mt, *old_mt = f->mt; + size_t new_bucket_size, copy; + int group, bucket; + + new_bucket_size = DIV_ROUND_UP(rules, BITS_PER_LONG); +#ifdef NFT_PIPAPO_ALIGN + new_bucket_size = roundup(new_bucket_size, + NFT_PIPAPO_ALIGN / sizeof(*new_lt)); +#endif + + if (new_bucket_size == f->bsize) + goto mt; + + if (new_bucket_size > f->bsize) + copy = f->bsize; + else + copy = new_bucket_size; + + new_lt = kvzalloc(f->groups * NFT_PIPAPO_BUCKETS(f->bb) * + new_bucket_size * sizeof(*new_lt) + + NFT_PIPAPO_ALIGN_HEADROOM, + GFP_KERNEL); + if (!new_lt) + return -ENOMEM; + + new_p = NFT_PIPAPO_LT_ALIGN(new_lt); + old_p = NFT_PIPAPO_LT_ALIGN(old_lt); + + for (group = 0; group < f->groups; group++) { + for (bucket = 0; bucket < NFT_PIPAPO_BUCKETS(f->bb); bucket++) { + memcpy(new_p, old_p, copy * sizeof(*new_p)); + new_p += copy; + old_p += copy; + + if (new_bucket_size > f->bsize) + new_p += new_bucket_size - f->bsize; + else + old_p += f->bsize - new_bucket_size; + } + } + +mt: + new_mt = kvmalloc(rules * sizeof(*new_mt), GFP_KERNEL); + if (!new_mt) { + kvfree(new_lt); + return -ENOMEM; + } + + memcpy(new_mt, f->mt, min(old_rules, rules) * sizeof(*new_mt)); + if (rules > old_rules) { + memset(new_mt + old_rules, 0, + (rules - old_rules) * sizeof(*new_mt)); + } + + if (new_lt) { + f->bsize = new_bucket_size; + NFT_PIPAPO_LT_ASSIGN(f, new_lt); + kvfree(old_lt); + } + + f->mt = new_mt; + kvfree(old_mt); + + return 0; +} + +/** + * pipapo_bucket_set() - Set rule bit in bucket given group and group value + * @f: Field containing lookup table + * @rule: Rule index + * @group: Group index + * @v: Value of bit group + */ +static void pipapo_bucket_set(struct nft_pipapo_field *f, int rule, int group, + int v) +{ + unsigned long *pos; + + pos = NFT_PIPAPO_LT_ALIGN(f->lt); + pos += f->bsize * NFT_PIPAPO_BUCKETS(f->bb) * group; + pos += f->bsize * v; + + __set_bit(rule, pos); +} + +/** + * pipapo_lt_4b_to_8b() - Switch lookup table group width from 4 bits to 8 bits + * @old_groups: Number of current groups + * @bsize: Size of one bucket, in longs + * @old_lt: Pointer to the current lookup table + * @new_lt: Pointer to the new, pre-allocated lookup table + * + * Each bucket with index b in the new lookup table, belonging to group g, is + * filled with the bit intersection between: + * - bucket with index given by the upper 4 bits of b, from group g, and + * - bucket with index given by the lower 4 bits of b, from group g + 1 + * + * That is, given buckets from the new lookup table N(x, y) and the old lookup + * table O(x, y), with x bucket index, and y group index: + * + * N(b, g) := O(b / 16, g) & O(b % 16, g + 1) + * + * This ensures equivalence of the matching results on lookup. Two examples in + * pictures: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 ... 254 255 + * 0 ^ + * 1 | ^ + * ... ( & ) | + * / \ | + * / \ .-( & )-. + * / bucket \ | | + * group 0 / 1 2 3 \ 4 5 6 7 8 9 10 11 12 13 |14 15 | + * 0 / \ | | + * 1 \ | | + * 2 | --' + * 3 '- + * ... 
+ */ +static void pipapo_lt_4b_to_8b(int old_groups, int bsize, + unsigned long *old_lt, unsigned long *new_lt) +{ + int g, b, i; + + for (g = 0; g < old_groups / 2; g++) { + int src_g0 = g * 2, src_g1 = g * 2 + 1; + + for (b = 0; b < NFT_PIPAPO_BUCKETS(8); b++) { + int src_b0 = b / NFT_PIPAPO_BUCKETS(4); + int src_b1 = b % NFT_PIPAPO_BUCKETS(4); + int src_i0 = src_g0 * NFT_PIPAPO_BUCKETS(4) + src_b0; + int src_i1 = src_g1 * NFT_PIPAPO_BUCKETS(4) + src_b1; + + for (i = 0; i < bsize; i++) { + *new_lt = old_lt[src_i0 * bsize + i] & + old_lt[src_i1 * bsize + i]; + new_lt++; + } + } + } +} + +/** + * pipapo_lt_8b_to_4b() - Switch lookup table group width from 8 bits to 4 bits + * @old_groups: Number of current groups + * @bsize: Size of one bucket, in longs + * @old_lt: Pointer to the current lookup table + * @new_lt: Pointer to the new, pre-allocated lookup table + * + * Each bucket with index b in the new lookup table, belonging to group g, is + * filled with the bit union of: + * - all the buckets with index such that the upper four bits of the lower byte + * equal b, from group g, with g odd + * - all the buckets with index such that the lower four bits equal b, from + * group g, with g even + * + * That is, given buckets from the new lookup table N(x, y) and the old lookup + * table O(x, y), with x bucket index, and y group index: + * + * - with g odd: N(b, g) := U(O(x, g) for each x : x = (b & 0xf0) >> 4) + * - with g even: N(b, g) := U(O(x, g) for each x : x = b & 0x0f) + * + * where U() denotes the arbitrary union operation (binary OR of n terms). This + * ensures equivalence of the matching results on lookup. + */ +static void pipapo_lt_8b_to_4b(int old_groups, int bsize, + unsigned long *old_lt, unsigned long *new_lt) +{ + int g, b, bsrc, i; + + memset(new_lt, 0, old_groups * 2 * NFT_PIPAPO_BUCKETS(4) * bsize * + sizeof(unsigned long)); + + for (g = 0; g < old_groups * 2; g += 2) { + int src_g = g / 2; + + for (b = 0; b < NFT_PIPAPO_BUCKETS(4); b++) { + for (bsrc = NFT_PIPAPO_BUCKETS(8) * src_g; + bsrc < NFT_PIPAPO_BUCKETS(8) * (src_g + 1); + bsrc++) { + if (((bsrc & 0xf0) >> 4) != b) + continue; + + for (i = 0; i < bsize; i++) + new_lt[i] |= old_lt[bsrc * bsize + i]; + } + + new_lt += bsize; + } + + for (b = 0; b < NFT_PIPAPO_BUCKETS(4); b++) { + for (bsrc = NFT_PIPAPO_BUCKETS(8) * src_g; + bsrc < NFT_PIPAPO_BUCKETS(8) * (src_g + 1); + bsrc++) { + if ((bsrc & 0x0f) != b) + continue; + + for (i = 0; i < bsize; i++) + new_lt[i] |= old_lt[bsrc * bsize + i]; + } + + new_lt += bsize; + } + } +} + +/** + * pipapo_lt_bits_adjust() - Adjust group size for lookup table if needed + * @f: Field containing lookup table + */ +static void pipapo_lt_bits_adjust(struct nft_pipapo_field *f) +{ + unsigned long *new_lt; + int groups, bb; + size_t lt_size; + + lt_size = f->groups * NFT_PIPAPO_BUCKETS(f->bb) * f->bsize * + sizeof(*f->lt); + + if (f->bb == NFT_PIPAPO_GROUP_BITS_SMALL_SET && + lt_size > NFT_PIPAPO_LT_SIZE_HIGH) { + groups = f->groups * 2; + bb = NFT_PIPAPO_GROUP_BITS_LARGE_SET; + + lt_size = groups * NFT_PIPAPO_BUCKETS(bb) * f->bsize * + sizeof(*f->lt); + } else if (f->bb == NFT_PIPAPO_GROUP_BITS_LARGE_SET && + lt_size < NFT_PIPAPO_LT_SIZE_LOW) { + groups = f->groups / 2; + bb = NFT_PIPAPO_GROUP_BITS_SMALL_SET; + + lt_size = groups * NFT_PIPAPO_BUCKETS(bb) * f->bsize * + sizeof(*f->lt); + + /* Don't increase group width if the resulting lookup table size + * would exceed the upper size threshold for a "small" set. 
+ */ + if (lt_size > NFT_PIPAPO_LT_SIZE_HIGH) + return; + } else { + return; + } + + new_lt = kvzalloc(lt_size + NFT_PIPAPO_ALIGN_HEADROOM, GFP_KERNEL); + if (!new_lt) + return; + + NFT_PIPAPO_GROUP_BITS_ARE_8_OR_4; + if (f->bb == 4 && bb == 8) { + pipapo_lt_4b_to_8b(f->groups, f->bsize, + NFT_PIPAPO_LT_ALIGN(f->lt), + NFT_PIPAPO_LT_ALIGN(new_lt)); + } else if (f->bb == 8 && bb == 4) { + pipapo_lt_8b_to_4b(f->groups, f->bsize, + NFT_PIPAPO_LT_ALIGN(f->lt), + NFT_PIPAPO_LT_ALIGN(new_lt)); + } else { + BUG(); + } + + f->groups = groups; + f->bb = bb; + kvfree(f->lt); + NFT_PIPAPO_LT_ASSIGN(f, new_lt); +} + +/** + * pipapo_insert() - Insert new rule in field given input key and mask length + * @f: Field containing lookup table + * @k: Input key for classification, without nftables padding + * @mask_bits: Length of mask; matches field length for non-ranged entry + * + * Insert a new rule reference in lookup buckets corresponding to k and + * mask_bits. + * + * Return: 1 on success (one rule inserted), negative error code on failure. + */ +static int pipapo_insert(struct nft_pipapo_field *f, const uint8_t *k, + int mask_bits) +{ + int rule = f->rules, group, ret, bit_offset = 0; + + ret = pipapo_resize(f, f->rules, f->rules + 1); + if (ret) + return ret; + + f->rules++; + + for (group = 0; group < f->groups; group++) { + int i, v; + u8 mask; + + v = k[group / (BITS_PER_BYTE / f->bb)]; + v &= GENMASK(BITS_PER_BYTE - bit_offset - 1, 0); + v >>= (BITS_PER_BYTE - bit_offset) - f->bb; + + bit_offset += f->bb; + bit_offset %= BITS_PER_BYTE; + + if (mask_bits >= (group + 1) * f->bb) { + /* Not masked */ + pipapo_bucket_set(f, rule, group, v); + } else if (mask_bits <= group * f->bb) { + /* Completely masked */ + for (i = 0; i < NFT_PIPAPO_BUCKETS(f->bb); i++) + pipapo_bucket_set(f, rule, group, i); + } else { + /* The mask limit falls on this group */ + mask = GENMASK(f->bb - 1, 0); + mask >>= mask_bits - group * f->bb; + for (i = 0; i < NFT_PIPAPO_BUCKETS(f->bb); i++) { + if ((i & ~mask) == (v & ~mask)) + pipapo_bucket_set(f, rule, group, i); + } + } + } + + pipapo_lt_bits_adjust(f); + + return 1; +} + +/** + * pipapo_step_diff() - Check if setting @step bit in netmask would change it + * @base: Mask we are expanding + * @step: Step bit for given expansion step + * @len: Total length of mask space (set and unset bits), bytes + * + * Convenience function for mask expansion. + * + * Return: true if step bit changes mask (i.e. isn't set), false otherwise. + */ +static bool pipapo_step_diff(u8 *base, int step, int len) +{ + /* Network order, byte-addressed */ +#ifdef __BIG_ENDIAN__ + return !(BIT(step % BITS_PER_BYTE) & base[step / BITS_PER_BYTE]); +#else + return !(BIT(step % BITS_PER_BYTE) & + base[len - 1 - step / BITS_PER_BYTE]); +#endif +} + +/** + * pipapo_step_after_end() - Check if mask exceeds range end with given step + * @base: Mask we are expanding + * @end: End of range + * @step: Step bit for given expansion step, highest bit to be set + * @len: Total length of mask space (set and unset bits), bytes + * + * Convenience function for mask expansion. + * + * Return: true if mask exceeds range setting step bits, false otherwise. 
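pipapo_insert() above handles each group in one of three ways, depending on where the prefix length (mask_bits) falls relative to the group. A tiny standalone illustration of that classification (show_prefix_fill() is an invented name, printing instead of setting bits):

#include <stdio.h>

/* For a prefix of 'mask_bits' bits over 'groups' groups of 'bb' bits each,
 * print which buckets would get the rule bit, mirroring the three cases
 * in pipapo_insert() above.
 */
static void show_prefix_fill(int groups, int bb, int mask_bits)
{
        int g;

        for (g = 0; g < groups; g++) {
                if (mask_bits >= (g + 1) * bb)
                        printf("group %d: the one bucket matching the group value\n", g);
                else if (mask_bits <= g * bb)
                        printf("group %d: all %d buckets\n", g, 1 << bb);
                else
                        printf("group %d: buckets agreeing on the top %d bits\n",
                               g, mask_bits - g * bb);
        }
}

/* show_prefix_fill(8, 4, 24): groups 0..5 get one bucket, 6..7 all 16 */

For the 192.168.1.0/24 rule with 4-bit groups this gives a single bucket for groups 0 to 5 and every bucket for groups 6 and 7, which is exactly the rule #1 row pattern in the insertion example of the DOC comment.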
+ */ +static bool pipapo_step_after_end(const u8 *base, const u8 *end, int step, + int len) +{ + u8 tmp[NFT_PIPAPO_MAX_BYTES]; + int i; + + memcpy(tmp, base, len); + + /* Network order, byte-addressed */ + for (i = 0; i <= step; i++) +#ifdef __BIG_ENDIAN__ + tmp[i / BITS_PER_BYTE] |= BIT(i % BITS_PER_BYTE); +#else + tmp[len - 1 - i / BITS_PER_BYTE] |= BIT(i % BITS_PER_BYTE); +#endif + + return memcmp(tmp, end, len) > 0; +} + +/** + * pipapo_base_sum() - Sum step bit to given len-sized netmask base with carry + * @base: Netmask base + * @step: Step bit to sum + * @len: Netmask length, bytes + */ +static void pipapo_base_sum(u8 *base, int step, int len) +{ + bool carry = false; + int i; + + /* Network order, byte-addressed */ +#ifdef __BIG_ENDIAN__ + for (i = step / BITS_PER_BYTE; i < len; i++) { +#else + for (i = len - 1 - step / BITS_PER_BYTE; i >= 0; i--) { +#endif + if (carry) + base[i]++; + else + base[i] += 1 << (step % BITS_PER_BYTE); + + if (base[i]) + break; + + carry = true; + } +} + +/** + * pipapo_expand() - Expand to composing netmasks, insert into lookup table + * @f: Field containing lookup table + * @start: Start of range + * @end: End of range + * @len: Length of value in bits + * + * Expand range to composing netmasks and insert corresponding rule references + * in lookup buckets. + * + * Return: number of inserted rules on success, negative error code on failure. + */ +static int pipapo_expand(struct nft_pipapo_field *f, + const u8 *start, const u8 *end, int len) +{ + int step, masks = 0, bytes = DIV_ROUND_UP(len, BITS_PER_BYTE); + u8 base[NFT_PIPAPO_MAX_BYTES]; + + memcpy(base, start, bytes); + while (memcmp(base, end, bytes) <= 0) { + int err; + + step = 0; + while (pipapo_step_diff(base, step, bytes)) { + if (pipapo_step_after_end(base, end, step, bytes)) + break; + + step++; + if (step >= len) { + if (!masks) { + err = pipapo_insert(f, base, 0); + if (err < 0) + return err; + masks = 1; + } + goto out; + } + } + + err = pipapo_insert(f, base, len - step); + + if (err < 0) + return err; + + masks++; + pipapo_base_sum(base, step, bytes); + } +out: + return masks; +} + +/** + * pipapo_map() - Insert rules in mapping tables, mapping them between fields + * @m: Matching data, including mapping table + * @map: Table of rule maps: array of first rule and amount of rules + * in next field a given rule maps to, for each field + * @e: For last field, nft_set_ext pointer matching rules map to + */ +static void pipapo_map(struct nft_pipapo_match *m, + union nft_pipapo_map_bucket map[NFT_PIPAPO_MAX_FIELDS], + struct nft_pipapo_elem *e) +{ + struct nft_pipapo_field *f; + int i, j; + + for (i = 0, f = m->f; i < m->field_count - 1; i++, f++) { + for (j = 0; j < map[i].n; j++) { + f->mt[map[i].to + j].to = map[i + 1].to; + f->mt[map[i].to + j].n = map[i + 1].n; + } + } + + /* Last field: map to ext instead of mapping to next field */ + for (j = 0; j < map[i].n; j++) + f->mt[map[i].to + j].e = e; +} + +/** + * pipapo_realloc_scratch() - Reallocate scratch maps for partial match results + * @clone: Copy of matching data with pending insertions and deletions + * @bsize_max: Maximum bucket size, scratch maps cover two buckets + * + * Return: 0 on success, -ENOMEM on failure. 
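pipapo_expand() above produces the composing masks by walking step bits; the net effect is the classic decomposition of an arbitrary interval into at most 2 * b masked blocks. The same result can be reproduced for IPv4 with the usual alignment-based CIDR decomposition, shown here as a standalone sketch to verify the expansion used in the examples (expand_range() is an invented name, and this is not the kernel's step walk; __builtin_ctzll() is the GCC/Clang builtin also used by pipapo_refill()):

#include <stdint.h>
#include <stdio.h>

/* Print the CIDR blocks exactly covering [start, end], host byte order. */
static void expand_range(uint32_t start, uint32_t end)
{
        uint64_t s = start, e = end;

        while (s <= e) {
                /* Largest block the alignment of s allows... */
                unsigned int align = s ? __builtin_ctzll(s) : 32;
                uint64_t size = 1ULL << align;

                /* ...shrunk until it doesn't overshoot e. */
                while (size > e - s + 1)
                        size >>= 1;

                printf("%u.%u.%u.%u/%d\n",
                       (unsigned int)(s >> 24) & 0xff,
                       (unsigned int)(s >> 16) & 0xff,
                       (unsigned int)(s >> 8) & 0xff,
                       (unsigned int)s & 0xff,
                       32 - __builtin_ctzll(size));

                s += size;
        }
}

int main(void)
{
        /* Range from the insertion example above: 192.168.1.0 - 192.168.2.1 */
        expand_range(0xc0a80100, 0xc0a80201);
        return 0;
}

Running it prints 192.168.1.0/24 and 192.168.2.0/31, the two rules used throughout the DOC comment.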
+ */ +static int pipapo_realloc_scratch(struct nft_pipapo_match *clone, + unsigned long bsize_max) +{ + int i; + + for_each_possible_cpu(i) { + unsigned long *scratch; +#ifdef NFT_PIPAPO_ALIGN + unsigned long *scratch_aligned; +#endif + + scratch = kzalloc_node(bsize_max * sizeof(*scratch) * 2 + + NFT_PIPAPO_ALIGN_HEADROOM, + GFP_KERNEL, cpu_to_node(i)); + if (!scratch) { + /* On failure, there's no need to undo previous + * allocations: this means that some scratch maps have + * a bigger allocated size now (this is only called on + * insertion), but the extra space won't be used by any + * CPU as new elements are not inserted and m->bsize_max + * is not updated. + */ + return -ENOMEM; + } + + kfree(*per_cpu_ptr(clone->scratch, i)); + + *per_cpu_ptr(clone->scratch, i) = scratch; + +#ifdef NFT_PIPAPO_ALIGN + scratch_aligned = NFT_PIPAPO_LT_ALIGN(scratch); + *per_cpu_ptr(clone->scratch_aligned, i) = scratch_aligned; +#endif + } + + return 0; +} + +/** + * nft_pipapo_insert() - Validate and insert ranged elements + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * @ext2: Filled with pointer to &struct nft_set_ext in inserted element + * + * Return: 0 on success, error pointer on failure. + */ +static int nft_pipapo_insert(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext2) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS]; + const u8 *start = (const u8 *)elem->key.val.data, *end; + struct nft_pipapo_elem *e = elem->priv, *dup; + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *m = priv->clone; + u8 genmask = nft_genmask_next(net); + struct nft_pipapo_field *f; + const u8 *start_p, *end_p; + int i, bsize_max, err = 0; + + if (nft_set_ext_exists(ext, NFT_SET_EXT_KEY_END)) + end = (const u8 *)nft_set_ext_key_end(ext)->data; + else + end = start; + + dup = pipapo_get(net, set, start, genmask); + if (!IS_ERR(dup)) { + /* Check if we already have the same exact entry */ + const struct nft_data *dup_key, *dup_end; + + dup_key = nft_set_ext_key(&dup->ext); + if (nft_set_ext_exists(&dup->ext, NFT_SET_EXT_KEY_END)) + dup_end = nft_set_ext_key_end(&dup->ext); + else + dup_end = dup_key; + + if (!memcmp(start, dup_key->data, sizeof(*dup_key->data)) && + !memcmp(end, dup_end->data, sizeof(*dup_end->data))) { + *ext2 = &dup->ext; + return -EEXIST; + } + + return -ENOTEMPTY; + } + + if (PTR_ERR(dup) == -ENOENT) { + /* Look for partially overlapping entries */ + dup = pipapo_get(net, set, end, nft_genmask_next(net)); + } + + if (PTR_ERR(dup) != -ENOENT) { + if (IS_ERR(dup)) + return PTR_ERR(dup); + *ext2 = &dup->ext; + return -ENOTEMPTY; + } + + /* Validate */ + start_p = start; + end_p = end; + nft_pipapo_for_each_field(f, i, m) { + if (f->rules >= (unsigned long)NFT_PIPAPO_RULE0_MAX) + return -ENOSPC; + + if (memcmp(start_p, end_p, + f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)) > 0) + return -EINVAL; + + start_p += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + end_p += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + } + + /* Insert */ + priv->dirty = true; + + bsize_max = m->bsize_max; + + nft_pipapo_for_each_field(f, i, m) { + int ret; + + rulemap[i].to = f->rules; + + ret = memcmp(start, end, + f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)); + if (!ret) + ret = pipapo_insert(f, start, f->groups * f->bb); + else + ret = pipapo_expand(f, start, end, f->groups * f->bb); + + if (ret 
< 0) + return ret; + + if (f->bsize > bsize_max) + bsize_max = f->bsize; + + rulemap[i].n = ret; + + start += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + end += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + } + + if (!*get_cpu_ptr(m->scratch) || bsize_max > m->bsize_max) { + put_cpu_ptr(m->scratch); + + err = pipapo_realloc_scratch(m, bsize_max); + if (err) + return err; + + m->bsize_max = bsize_max; + } else { + put_cpu_ptr(m->scratch); + } + + *ext2 = &e->ext; + + pipapo_map(m, rulemap, e); + + return 0; +} + +/** + * pipapo_clone() - Clone matching data to create new working copy + * @old: Existing matching data + * + * Return: copy of matching data passed as 'old', error pointer on failure + */ +static struct nft_pipapo_match *pipapo_clone(struct nft_pipapo_match *old) +{ + struct nft_pipapo_field *dst, *src; + struct nft_pipapo_match *new; + int i; + + new = kmalloc(sizeof(*new) + sizeof(*dst) * old->field_count, + GFP_KERNEL); + if (!new) + return ERR_PTR(-ENOMEM); + + new->field_count = old->field_count; + new->bsize_max = old->bsize_max; + + new->scratch = alloc_percpu(*new->scratch); + if (!new->scratch) + goto out_scratch; + +#ifdef NFT_PIPAPO_ALIGN + new->scratch_aligned = alloc_percpu(*new->scratch_aligned); + if (!new->scratch_aligned) + goto out_scratch; +#endif + for_each_possible_cpu(i) + *per_cpu_ptr(new->scratch, i) = NULL; + + if (pipapo_realloc_scratch(new, old->bsize_max)) + goto out_scratch_realloc; + + rcu_head_init(&new->rcu); + + src = old->f; + dst = new->f; + + for (i = 0; i < old->field_count; i++) { + unsigned long *new_lt; + + memcpy(dst, src, offsetof(struct nft_pipapo_field, lt)); + + new_lt = kvzalloc(src->groups * NFT_PIPAPO_BUCKETS(src->bb) * + src->bsize * sizeof(*dst->lt) + + NFT_PIPAPO_ALIGN_HEADROOM, + GFP_KERNEL); + if (!new_lt) + goto out_lt; + + NFT_PIPAPO_LT_ASSIGN(dst, new_lt); + + memcpy(NFT_PIPAPO_LT_ALIGN(new_lt), + NFT_PIPAPO_LT_ALIGN(src->lt), + src->bsize * sizeof(*dst->lt) * + src->groups * NFT_PIPAPO_BUCKETS(src->bb)); + + dst->mt = kvmalloc(src->rules * sizeof(*src->mt), GFP_KERNEL); + if (!dst->mt) + goto out_mt; + + memcpy(dst->mt, src->mt, src->rules * sizeof(*src->mt)); + src++; + dst++; + } + + return new; + +out_mt: + kvfree(dst->lt); +out_lt: + for (dst--; i > 0; i--) { + kvfree(dst->mt); + kvfree(dst->lt); + dst--; + } +out_scratch_realloc: + for_each_possible_cpu(i) + kfree(*per_cpu_ptr(new->scratch, i)); +#ifdef NFT_PIPAPO_ALIGN + free_percpu(new->scratch_aligned); +#endif +out_scratch: + free_percpu(new->scratch); + kfree(new); + + return ERR_PTR(-ENOMEM); +} + +/** + * pipapo_rules_same_key() - Get number of rules originated from the same entry + * @f: Field containing mapping table + * @first: Index of first rule in set of rules mapping to same entry + * + * Using the fact that all rules in a field that originated from the same entry + * will map to the same set of rules in the next field, or to the same element + * reference, return the cardinality of the set of rules that originated from + * the same entry as the rule with index @first, @first rule included. + * + * In pictures: + * rules + * field #0 0 1 2 3 4 + * map to: 0 1 2-4 2-4 5-9 + * . . ....... . ... + * | | | | \ \ + * | | | | \ \ + * | | | | \ \ + * ' ' ' ' ' \ + * in field #1 0 1 2 3 4 5 ... + * + * if this is called for rule 2 on field #0, it will return 3, as also rules 2 + * and 3 in field 0 map to the same set of rules (2, 3, 4) in the next field. + * + * For the last field in a set, we can rely on associated entries to map to the + * same element references. 
+ * + * Return: Number of rules that originated from the same entry as @first. + */ +static int pipapo_rules_same_key(struct nft_pipapo_field *f, int first) +{ + struct nft_pipapo_elem *e = NULL; /* Keep gcc happy */ + int r; + + for (r = first; r < f->rules; r++) { + if (r != first && e != f->mt[r].e) + return r - first; + + e = f->mt[r].e; + } + + if (r != first) + return r - first; + + return 0; +} + +/** + * pipapo_unmap() - Remove rules from mapping tables, renumber remaining ones + * @mt: Mapping array + * @rules: Original amount of rules in mapping table + * @start: First rule index to be removed + * @n: Amount of rules to be removed + * @to_offset: First rule index, in next field, this group of rules maps to + * @is_last: If this is the last field, delete reference from mapping array + * + * This is used to unmap rules from the mapping table for a single field, + * maintaining consistency and compactness for the existing ones. + * + * In pictures: let's assume that we want to delete rules 2 and 3 from the + * following mapping array: + * + * rules + * 0 1 2 3 4 + * map to: 4-10 4-10 11-15 11-15 16-18 + * + * the result will be: + * + * rules + * 0 1 2 + * map to: 4-10 4-10 11-13 + * + * for fields before the last one. In case this is the mapping table for the + * last field in a set, and rules map to pointers to &struct nft_pipapo_elem: + * + * rules + * 0 1 2 3 4 + * element pointers: 0x42 0x42 0x33 0x33 0x44 + * + * the result will be: + * + * rules + * 0 1 2 + * element pointers: 0x42 0x42 0x44 + */ +static void pipapo_unmap(union nft_pipapo_map_bucket *mt, int rules, + int start, int n, int to_offset, bool is_last) +{ + int i; + + memmove(mt + start, mt + start + n, (rules - start - n) * sizeof(*mt)); + memset(mt + rules - n, 0, n * sizeof(*mt)); + + if (is_last) + return; + + for (i = start; i < rules - n; i++) + mt[i].to -= to_offset; +} + +/** + * pipapo_drop() - Delete entry from lookup and mapping tables, given rule map + * @m: Matching data + * @rulemap: Table of rule maps, arrays of first rule and amount of rules + * in next field a given entry maps to, for each field + * + * For each rule in lookup table buckets mapping to this set of rules, drop + * all bits set in lookup table mapping. In pictures, assuming we want to drop + * rules 0 and 1 from this lookup table: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0 1,2 + * 1 1,2 0 + * 2 0 1,2 + * 3 0 1,2 + * 4 0,1,2 + * 5 0 1 2 + * 6 0,1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * 7 1,2 1,2 1 1 1 0,1 1 1 1 1 1 1 1 1 1 1 + * + * rule 2 becomes rule 0, and the result will be: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 0 + * 1 0 + * 2 0 + * 3 0 + * 4 0 + * 5 0 + * 6 0 + * 7 0 0 + * + * once this is done, call unmap() to drop all the corresponding rule references + * from mapping tables. 
+ */ +static void pipapo_drop(struct nft_pipapo_match *m, + union nft_pipapo_map_bucket rulemap[]) +{ + struct nft_pipapo_field *f; + int i; + + nft_pipapo_for_each_field(f, i, m) { + int g; + + for (g = 0; g < f->groups; g++) { + unsigned long *pos; + int b; + + pos = NFT_PIPAPO_LT_ALIGN(f->lt) + g * + NFT_PIPAPO_BUCKETS(f->bb) * f->bsize; + + for (b = 0; b < NFT_PIPAPO_BUCKETS(f->bb); b++) { + bitmap_cut(pos, pos, rulemap[i].to, + rulemap[i].n, + f->bsize * BITS_PER_LONG); + + pos += f->bsize; + } + } + + pipapo_unmap(f->mt, f->rules, rulemap[i].to, rulemap[i].n, + rulemap[i + 1].n, i == m->field_count - 1); + if (pipapo_resize(f, f->rules, f->rules - rulemap[i].n)) { + /* We can ignore this, a failure to shrink tables down + * doesn't make tables invalid. + */ + ; + } + f->rules -= rulemap[i].n; + + pipapo_lt_bits_adjust(f); + } +} + +static void nft_pipapo_gc_deactivate(struct net *net, struct nft_set *set, + struct nft_pipapo_elem *e) + +{ + struct nft_set_elem elem = { + .priv = e, + }; + + nft_setelem_data_deactivate(net, set, &elem); +} + +/** + * pipapo_gc() - Drop expired entries from set, destroy start and end elements + * @_set: nftables API set representation + * @m: Matching data + */ +static void pipapo_gc(const struct nft_set *_set, struct nft_pipapo_match *m) +{ + struct nft_set *set = (struct nft_set *) _set; + struct nft_pipapo *priv = nft_set_priv(set); + struct net *net = read_pnet(&set->net); + int rules_f0, first_rule = 0; + struct nft_pipapo_elem *e; + struct nft_trans_gc *gc; + + gc = nft_trans_gc_alloc(set, 0, GFP_KERNEL); + if (!gc) + return; + + while ((rules_f0 = pipapo_rules_same_key(m->f, first_rule))) { + union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS]; + struct nft_pipapo_field *f; + int i, start, rules_fx; + + start = first_rule; + rules_fx = rules_f0; + + nft_pipapo_for_each_field(f, i, m) { + rulemap[i].to = start; + rulemap[i].n = rules_fx; + + if (i < m->field_count - 1) { + rules_fx = f->mt[start].n; + start = f->mt[start].to; + } + } + + /* Pick the last field, and its last index */ + f--; + i--; + e = f->mt[rulemap[i].to].e; + + /* synchronous gc never fails, there is no need to set on + * NFT_SET_ELEM_DEAD_BIT. + */ + if (nft_set_elem_expired(&e->ext)) { + priv->dirty = true; + + gc = nft_trans_gc_queue_sync(gc, GFP_ATOMIC); + if (!gc) + return; + + nft_pipapo_gc_deactivate(net, set, e); + pipapo_drop(m, rulemap); + nft_trans_gc_elem_add(gc, e); + + /* And check again current first rule, which is now the + * first we haven't checked. 
+ */ + } else { + first_rule += rules_f0; + } + } + + gc = nft_trans_gc_catchall_sync(gc); + if (gc) { + nft_trans_gc_queue_sync_done(gc); + priv->last_gc = jiffies; + } +} + +/** + * pipapo_free_fields() - Free per-field tables contained in matching data + * @m: Matching data + */ +static void pipapo_free_fields(struct nft_pipapo_match *m) +{ + struct nft_pipapo_field *f; + int i; + + nft_pipapo_for_each_field(f, i, m) { + kvfree(f->lt); + kvfree(f->mt); + } +} + +static void pipapo_free_match(struct nft_pipapo_match *m) +{ + int i; + + for_each_possible_cpu(i) + kfree(*per_cpu_ptr(m->scratch, i)); + +#ifdef NFT_PIPAPO_ALIGN + free_percpu(m->scratch_aligned); +#endif + free_percpu(m->scratch); + + pipapo_free_fields(m); + + kfree(m); +} + +/** + * pipapo_reclaim_match - RCU callback to free fields from old matching data + * @rcu: RCU head + */ +static void pipapo_reclaim_match(struct rcu_head *rcu) +{ + struct nft_pipapo_match *m; + + m = container_of(rcu, struct nft_pipapo_match, rcu); + pipapo_free_match(m); +} + +/** + * nft_pipapo_commit() - Replace lookup data with current working copy + * @set: nftables API set representation + * + * While at it, check if we should perform garbage collection on the working + * copy before committing it for lookup, and don't replace the table if the + * working copy doesn't have pending changes. + * + * We also need to create a new working copy for subsequent insertions and + * deletions. + */ +static void nft_pipapo_commit(const struct nft_set *set) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *new_clone, *old; + + if (time_after_eq(jiffies, priv->last_gc + nft_set_gc_interval(set))) + pipapo_gc(set, priv->clone); + + if (!priv->dirty) + return; + + new_clone = pipapo_clone(priv->clone); + if (IS_ERR(new_clone)) + return; + + priv->dirty = false; + + old = rcu_access_pointer(priv->match); + rcu_assign_pointer(priv->match, priv->clone); + if (old) + call_rcu(&old->rcu, pipapo_reclaim_match); + + priv->clone = new_clone; +} + +static bool nft_pipapo_transaction_mutex_held(const struct nft_set *set) +{ +#ifdef CONFIG_PROVE_LOCKING + const struct net *net = read_pnet(&set->net); + + return lockdep_is_held(&nft_pernet(net)->commit_mutex); +#else + return true; +#endif +} + +static void nft_pipapo_abort(const struct nft_set *set) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *new_clone, *m; + + if (!priv->dirty) + return; + + m = rcu_dereference_protected(priv->match, nft_pipapo_transaction_mutex_held(set)); + + new_clone = pipapo_clone(m); + if (IS_ERR(new_clone)) + return; + + priv->dirty = false; + + pipapo_free_match(priv->clone); + priv->clone = new_clone; +} + +/** + * nft_pipapo_activate() - Mark element reference as active given key, commit + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * + * On insertion, elements are added to a copy of the matching data currently + * in use for lookups, and not directly inserted into current lookup data. Both + * nft_pipapo_insert() and nft_pipapo_activate() are called once for each + * element, hence we can't purpose either one as a real commit operation. 
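nft_pipapo_commit() and nft_pipapo_abort() above implement a small state machine around the two copies: lookups only ever see priv->match, mutations only ever touch priv->clone, and priv->dirty records whether the two diverge. A stripped-down, single-threaded sketch of that logic (struct set_priv, clone_match() and the other names are invented; RCU publication, call_rcu() freeing, garbage collection and the commit_mutex lockdep check are left out):

#include <stdbool.h>
#include <stdlib.h>
#include <string.h>

struct match { int placeholder; };      /* stands in for struct nft_pipapo_match */

struct set_priv {
        struct match *match;    /* copy used by lookups (RCU-published in the kernel) */
        struct match *clone;    /* working copy taking pending insertions/deletions */
        bool dirty;             /* clone has changes not yet in match */
};

static struct match *clone_match(const struct match *m)
{
        struct match *new = malloc(sizeof(*new));

        if (new)
                memcpy(new, m, sizeof(*new));
        return new;
}

static void commit(struct set_priv *p)
{
        struct match *old, *new_clone;

        if (!p->dirty)
                return;

        new_clone = clone_match(p->clone);      /* working copy for the next transaction */
        if (!new_clone)
                return;                         /* keep pending changes, retry later */

        p->dirty = false;
        old = p->match;
        p->match = p->clone;                    /* rcu_assign_pointer() in the kernel */
        p->clone = new_clone;
        free(old);                              /* call_rcu() -> pipapo_reclaim_match() there */
}

static void abort_changes(struct set_priv *p)
{
        struct match *new_clone;

        if (!p->dirty)
                return;

        new_clone = clone_match(p->match);      /* start over from the live copy */
        if (!new_clone)
                return;

        p->dirty = false;
        free(p->clone);                         /* throw away the aborted changes */
        p->clone = new_clone;
}

Both paths bail out if the fresh clone cannot be allocated, as the kernel functions above do.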
+ */ +static void nft_pipapo_activate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_pipapo_elem *e = elem->priv; + + nft_set_elem_change_active(net, set, &e->ext); +} + +/** + * pipapo_deactivate() - Check that element is in set, mark as inactive + * @net: Network namespace + * @set: nftables API set representation + * @data: Input key data + * @ext: nftables API extension pointer, used to check for end element + * + * This is a convenience function that can be called from both + * nft_pipapo_deactivate() and nft_pipapo_flush(), as they are in fact the same + * operation. + * + * Return: deactivated element if found, NULL otherwise. + */ +static void *pipapo_deactivate(const struct net *net, const struct nft_set *set, + const u8 *data, const struct nft_set_ext *ext) +{ + struct nft_pipapo_elem *e; + + e = pipapo_get(net, set, data, nft_genmask_next(net)); + if (IS_ERR(e)) + return NULL; + + nft_set_elem_change_active(net, set, &e->ext); + + return e; +} + +/** + * nft_pipapo_deactivate() - Call pipapo_deactivate() to make element inactive + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * + * Return: deactivated element if found, NULL otherwise. + */ +static void *nft_pipapo_deactivate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + const struct nft_set_ext *ext = nft_set_elem_ext(set, elem->priv); + + return pipapo_deactivate(net, set, (const u8 *)elem->key.val.data, ext); +} + +/** + * nft_pipapo_flush() - Call pipapo_deactivate() to make element inactive + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * + * This is functionally the same as nft_pipapo_deactivate(), with a slightly + * different interface, and it's also called once for each element in a set + * being flushed, so we can't implement, strictly speaking, a flush operation, + * which would otherwise be as simple as allocating an empty copy of the + * matching data. + * + * Note that we could in theory do that, mark the set as flushed, and ignore + * subsequent calls, but we would leak all the elements after the first one, + * because they wouldn't then be freed as result of API calls. + * + * Return: true if element was found and deactivated. + */ +static bool nft_pipapo_flush(const struct net *net, const struct nft_set *set, + void *elem) +{ + struct nft_pipapo_elem *e = elem; + + return pipapo_deactivate(net, set, (const u8 *)nft_set_ext_key(&e->ext), + &e->ext); +} + +/** + * pipapo_get_boundaries() - Get byte interval for associated rules + * @f: Field including lookup table + * @first_rule: First rule (lowest index) + * @rule_count: Number of associated rules + * @left: Byte expression for left boundary (start of range) + * @right: Byte expression for right boundary (end of range) + * + * Given the first rule and amount of rules that originated from the same entry, + * build the original range associated with the entry, and calculate the length + * of the originating netmask. 
+ * + * In pictures: + * + * bucket + * group 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + * 0 1,2 + * 1 1,2 + * 2 1,2 + * 3 1,2 + * 4 1,2 + * 5 1 2 + * 6 1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * 7 1,2 1,2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + * + * this is the lookup table corresponding to the IPv4 range + * 192.168.1.0-192.168.2.1, which was expanded to the two composing netmasks, + * rule #1: 192.168.1.0/24, and rule #2: 192.168.2.0/31. + * + * This function fills @left and @right with the byte values of the leftmost + * and rightmost bucket indices for the lowest and highest rule indices, + * respectively. If @first_rule is 1 and @rule_count is 2, we obtain, in + * nibbles: + * left: < 12, 0, 10, 8, 0, 1, 0, 0 > + * right: < 12, 0, 10, 8, 0, 2, 2, 1 > + * corresponding to bytes: + * left: < 192, 168, 1, 0 > + * right: < 192, 168, 2, 1 > + * with mask length irrelevant here, unused on return, as the range is already + * defined by its start and end points. The mask length is relevant for a single + * ranged entry instead: if @first_rule is 1 and @rule_count is 1, we ignore + * rule 2 above: @left becomes < 192, 168, 1, 0 >, @right becomes + * < 192, 168, 1, 255 >, and the mask length, calculated from the distances + * between leftmost and rightmost bucket indices for each group, would be 24. + * + * Return: mask length, in bits. + */ +static int pipapo_get_boundaries(struct nft_pipapo_field *f, int first_rule, + int rule_count, u8 *left, u8 *right) +{ + int g, mask_len = 0, bit_offset = 0; + u8 *l = left, *r = right; + + for (g = 0; g < f->groups; g++) { + int b, x0, x1; + + x0 = -1; + x1 = -1; + for (b = 0; b < NFT_PIPAPO_BUCKETS(f->bb); b++) { + unsigned long *pos; + + pos = NFT_PIPAPO_LT_ALIGN(f->lt) + + (g * NFT_PIPAPO_BUCKETS(f->bb) + b) * f->bsize; + if (test_bit(first_rule, pos) && x0 == -1) + x0 = b; + if (test_bit(first_rule + rule_count - 1, pos)) + x1 = b; + } + + *l |= x0 << (BITS_PER_BYTE - f->bb - bit_offset); + *r |= x1 << (BITS_PER_BYTE - f->bb - bit_offset); + + bit_offset += f->bb; + if (bit_offset >= BITS_PER_BYTE) { + bit_offset %= BITS_PER_BYTE; + l++; + r++; + } + + if (x1 - x0 == 0) + mask_len += 4; + else if (x1 - x0 == 1) + mask_len += 3; + else if (x1 - x0 == 3) + mask_len += 2; + else if (x1 - x0 == 7) + mask_len += 1; + } + + return mask_len; +} + +/** + * pipapo_match_field() - Match rules against byte ranges + * @f: Field including the lookup table + * @first_rule: First of associated rules originating from same entry + * @rule_count: Amount of associated rules + * @start: Start of range to be matched + * @end: End of range to be matched + * + * Return: true on match, false otherwise. + */ +static bool pipapo_match_field(struct nft_pipapo_field *f, + int first_rule, int rule_count, + const u8 *start, const u8 *end) +{ + u8 right[NFT_PIPAPO_MAX_BYTES] = { 0 }; + u8 left[NFT_PIPAPO_MAX_BYTES] = { 0 }; + + pipapo_get_boundaries(f, first_rule, rule_count, left, right); + + return !memcmp(start, left, + f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)) && + !memcmp(end, right, f->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f)); +} + +/** + * nft_pipapo_remove() - Remove element given key, commit + * @net: Network namespace + * @set: nftables API set representation + * @elem: nftables API element representation containing key data + * + * Similarly to nft_pipapo_activate(), this is used as commit operation by the + * API, but it's called once per element in the pending transaction, so we can't + * implement this as a single commit operation. 
Closest we can get is to remove + * the matched element here, if any, and commit the updated matching data. + */ +static void nft_pipapo_remove(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *m = priv->clone; + struct nft_pipapo_elem *e = elem->priv; + int rules_f0, first_rule = 0; + const u8 *data; + + data = (const u8 *)nft_set_ext_key(&e->ext); + + while ((rules_f0 = pipapo_rules_same_key(m->f, first_rule))) { + union nft_pipapo_map_bucket rulemap[NFT_PIPAPO_MAX_FIELDS]; + const u8 *match_start, *match_end; + struct nft_pipapo_field *f; + int i, start, rules_fx; + + match_start = data; + + if (nft_set_ext_exists(&e->ext, NFT_SET_EXT_KEY_END)) + match_end = (const u8 *)nft_set_ext_key_end(&e->ext)->data; + else + match_end = data; + + start = first_rule; + rules_fx = rules_f0; + + nft_pipapo_for_each_field(f, i, m) { + if (!pipapo_match_field(f, start, rules_fx, + match_start, match_end)) + break; + + rulemap[i].to = start; + rulemap[i].n = rules_fx; + + rules_fx = f->mt[start].n; + start = f->mt[start].to; + + match_start += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + match_end += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + } + + if (i == m->field_count) { + priv->dirty = true; + pipapo_drop(m, rulemap); + return; + } + + first_rule += rules_f0; + } +} + +/** + * nft_pipapo_walk() - Walk over elements + * @ctx: nftables API context + * @set: nftables API set representation + * @iter: Iterator + * + * As elements are referenced in the mapping array for the last field, directly + * scan that array: there's no need to follow rule mappings from the first + * field. + */ +static void nft_pipapo_walk(const struct nft_ctx *ctx, struct nft_set *set, + struct nft_set_iter *iter) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct net *net = read_pnet(&set->net); + struct nft_pipapo_match *m; + struct nft_pipapo_field *f; + int i, r; + + rcu_read_lock(); + if (iter->genmask == nft_genmask_cur(net)) + m = rcu_dereference(priv->match); + else + m = priv->clone; + + if (unlikely(!m)) + goto out; + + for (i = 0, f = m->f; i < m->field_count - 1; i++, f++) + ; + + for (r = 0; r < f->rules; r++) { + struct nft_pipapo_elem *e; + struct nft_set_elem elem; + + if (r < f->rules - 1 && f->mt[r + 1].e == f->mt[r].e) + continue; + + if (iter->count < iter->skip) + goto cont; + + e = f->mt[r].e; + + if (!nft_set_elem_active(&e->ext, iter->genmask)) + goto cont; + + elem.priv = e; + + iter->err = iter->fn(ctx, set, iter, &elem); + if (iter->err < 0) + goto out; + +cont: + iter->count++; + } + +out: + rcu_read_unlock(); +} + +/** + * nft_pipapo_privsize() - Return the size of private data for the set + * @nla: netlink attributes, ignored as size doesn't depend on them + * @desc: Set description, ignored as size doesn't depend on it + * + * Return: size of private data for this set implementation, in bytes + */ +static u64 nft_pipapo_privsize(const struct nlattr * const nla[], + const struct nft_set_desc *desc) +{ + return sizeof(struct nft_pipapo); +} + +/** + * nft_pipapo_estimate() - Set size, space and lookup complexity + * @desc: Set description, element count and field description used + * @features: Flags: NFT_SET_INTERVAL needs to be there + * @est: Storage for estimation data + * + * Return: true if set description is compatible, false otherwise + */ +static bool nft_pipapo_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + if (!(features & NFT_SET_INTERVAL) 
|| + desc->field_count < NFT_PIPAPO_MIN_FIELDS) + return false; + + est->size = pipapo_estimate_size(desc); + if (!est->size) + return false; + + est->lookup = NFT_SET_CLASS_O_LOG_N; + + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +/** + * nft_pipapo_init() - Initialise data for a set instance + * @set: nftables API set representation + * @desc: Set description + * @nla: netlink attributes + * + * Validate number and size of fields passed as NFTA_SET_DESC_CONCAT netlink + * attributes, initialise internal set parameters, current instance of matching + * data and a copy for subsequent insertions. + * + * Return: 0 on success, negative error code on failure. + */ +static int nft_pipapo_init(const struct nft_set *set, + const struct nft_set_desc *desc, + const struct nlattr * const nla[]) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *m; + struct nft_pipapo_field *f; + int err, i, field_count; + + field_count = desc->field_count ? : 1; + + if (field_count > NFT_PIPAPO_MAX_FIELDS) + return -EINVAL; + + m = kmalloc(sizeof(*priv->match) + sizeof(*f) * field_count, + GFP_KERNEL); + if (!m) + return -ENOMEM; + + m->field_count = field_count; + m->bsize_max = 0; + + m->scratch = alloc_percpu(unsigned long *); + if (!m->scratch) { + err = -ENOMEM; + goto out_scratch; + } + for_each_possible_cpu(i) + *per_cpu_ptr(m->scratch, i) = NULL; + +#ifdef NFT_PIPAPO_ALIGN + m->scratch_aligned = alloc_percpu(unsigned long *); + if (!m->scratch_aligned) { + err = -ENOMEM; + goto out_free; + } + for_each_possible_cpu(i) + *per_cpu_ptr(m->scratch_aligned, i) = NULL; +#endif + + rcu_head_init(&m->rcu); + + nft_pipapo_for_each_field(f, i, m) { + int len = desc->field_len[i] ? : set->klen; + + f->bb = NFT_PIPAPO_GROUP_BITS_INIT; + f->groups = len * NFT_PIPAPO_GROUPS_PER_BYTE(f); + + priv->width += round_up(len, sizeof(u32)); + + f->bsize = 0; + f->rules = 0; + NFT_PIPAPO_LT_ASSIGN(f, NULL); + f->mt = NULL; + } + + /* Create an initial clone of matching data for next insertion */ + priv->clone = pipapo_clone(m); + if (IS_ERR(priv->clone)) { + err = PTR_ERR(priv->clone); + goto out_free; + } + + priv->dirty = false; + + rcu_assign_pointer(priv->match, m); + + return 0; + +out_free: +#ifdef NFT_PIPAPO_ALIGN + free_percpu(m->scratch_aligned); +#endif + free_percpu(m->scratch); +out_scratch: + kfree(m); + + return err; +} + +/** + * nft_set_pipapo_match_destroy() - Destroy elements from key mapping array + * @ctx: context + * @set: nftables API set representation + * @m: matching data pointing to key mapping array + */ +static void nft_set_pipapo_match_destroy(const struct nft_ctx *ctx, + const struct nft_set *set, + struct nft_pipapo_match *m) +{ + struct nft_pipapo_field *f; + int i, r; + + for (i = 0, f = m->f; i < m->field_count - 1; i++, f++) + ; + + for (r = 0; r < f->rules; r++) { + struct nft_pipapo_elem *e; + + if (r < f->rules - 1 && f->mt[r + 1].e == f->mt[r].e) + continue; + + e = f->mt[r].e; + + nf_tables_set_elem_destroy(ctx, set, e); + } +} + +/** + * nft_pipapo_destroy() - Free private data for set and all committed elements + * @ctx: context + * @set: nftables API set representation + */ +static void nft_pipapo_destroy(const struct nft_ctx *ctx, + const struct nft_set *set) +{ + struct nft_pipapo *priv = nft_set_priv(set); + struct nft_pipapo_match *m; + int cpu; + + m = rcu_dereference_protected(priv->match, true); + if (m) { + rcu_barrier(); + + nft_set_pipapo_match_destroy(ctx, set, m); + +#ifdef NFT_PIPAPO_ALIGN + free_percpu(m->scratch_aligned); +#endif + 
for_each_possible_cpu(cpu) + kfree(*per_cpu_ptr(m->scratch, cpu)); + free_percpu(m->scratch); + pipapo_free_fields(m); + kfree(m); + priv->match = NULL; + } + + if (priv->clone) { + m = priv->clone; + + if (priv->dirty) + nft_set_pipapo_match_destroy(ctx, set, m); + +#ifdef NFT_PIPAPO_ALIGN + free_percpu(priv->clone->scratch_aligned); +#endif + for_each_possible_cpu(cpu) + kfree(*per_cpu_ptr(priv->clone->scratch, cpu)); + free_percpu(priv->clone->scratch); + + pipapo_free_fields(priv->clone); + kfree(priv->clone); + priv->clone = NULL; + } +} + +/** + * nft_pipapo_gc_init() - Initialise garbage collection + * @set: nftables API set representation + * + * Instead of actually setting up a periodic work for garbage collection, as + * this operation requires a swap of matching data with the working copy, we'll + * do that opportunistically with other commit operations if the interval is + * elapsed, so we just need to set the current jiffies timestamp here. + */ +static void nft_pipapo_gc_init(const struct nft_set *set) +{ + struct nft_pipapo *priv = nft_set_priv(set); + + priv->last_gc = jiffies; +} + +const struct nft_set_type nft_set_pipapo_type = { + .features = NFT_SET_INTERVAL | NFT_SET_MAP | NFT_SET_OBJECT | + NFT_SET_TIMEOUT, + .ops = { + .lookup = nft_pipapo_lookup, + .insert = nft_pipapo_insert, + .activate = nft_pipapo_activate, + .deactivate = nft_pipapo_deactivate, + .flush = nft_pipapo_flush, + .remove = nft_pipapo_remove, + .walk = nft_pipapo_walk, + .get = nft_pipapo_get, + .privsize = nft_pipapo_privsize, + .estimate = nft_pipapo_estimate, + .init = nft_pipapo_init, + .destroy = nft_pipapo_destroy, + .gc_init = nft_pipapo_gc_init, + .commit = nft_pipapo_commit, + .abort = nft_pipapo_abort, + .elemsize = offsetof(struct nft_pipapo_elem, ext), + }, +}; + +#if defined(CONFIG_X86_64) && !defined(CONFIG_UML) +const struct nft_set_type nft_set_pipapo_avx2_type = { + .features = NFT_SET_INTERVAL | NFT_SET_MAP | NFT_SET_OBJECT | + NFT_SET_TIMEOUT, + .ops = { + .lookup = nft_pipapo_avx2_lookup, + .insert = nft_pipapo_insert, + .activate = nft_pipapo_activate, + .deactivate = nft_pipapo_deactivate, + .flush = nft_pipapo_flush, + .remove = nft_pipapo_remove, + .walk = nft_pipapo_walk, + .get = nft_pipapo_get, + .privsize = nft_pipapo_privsize, + .estimate = nft_pipapo_avx2_estimate, + .init = nft_pipapo_init, + .destroy = nft_pipapo_destroy, + .gc_init = nft_pipapo_gc_init, + .commit = nft_pipapo_commit, + .abort = nft_pipapo_abort, + .elemsize = offsetof(struct nft_pipapo_elem, ext), + }, +}; +#endif diff --git a/net/netfilter/nft_set_pipapo.h b/net/netfilter/nft_set_pipapo.h new file mode 100644 index 000000000..25a755915 --- /dev/null +++ b/net/netfilter/nft_set_pipapo.h @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#ifndef _NFT_SET_PIPAPO_H + +#include +#include /* For the maximum length of a field */ + +/* Count of concatenated fields depends on count of 32-bit nftables registers */ +#define NFT_PIPAPO_MAX_FIELDS NFT_REG32_COUNT + +/* Restrict usage to multiple fields, make sure rbtree is used otherwise */ +#define NFT_PIPAPO_MIN_FIELDS 2 + +/* Largest supported field size */ +#define NFT_PIPAPO_MAX_BYTES (sizeof(struct in6_addr)) +#define NFT_PIPAPO_MAX_BITS (NFT_PIPAPO_MAX_BYTES * BITS_PER_BYTE) + +/* Bits to be grouped together in table buckets depending on set size */ +#define NFT_PIPAPO_GROUP_BITS_INIT NFT_PIPAPO_GROUP_BITS_SMALL_SET +#define NFT_PIPAPO_GROUP_BITS_SMALL_SET 8 +#define NFT_PIPAPO_GROUP_BITS_LARGE_SET 4 +#define NFT_PIPAPO_GROUP_BITS_ARE_8_OR_4 \ 
+ BUILD_BUG_ON((NFT_PIPAPO_GROUP_BITS_SMALL_SET != 8) || \ + (NFT_PIPAPO_GROUP_BITS_LARGE_SET != 4)) +#define NFT_PIPAPO_GROUPS_PER_BYTE(f) (BITS_PER_BYTE / (f)->bb) + +/* If a lookup table gets bigger than NFT_PIPAPO_LT_SIZE_HIGH, switch to the + * small group width, and switch to the big group width if the table gets + * smaller than NFT_PIPAPO_LT_SIZE_LOW. + * + * Picking 2MiB as threshold (for a single table) avoids as much as possible + * crossing page boundaries on most architectures (x86-64 and MIPS huge pages, + * ARMv7 supersections, POWER "large" pages, SPARC Level 1 regions, etc.), which + * keeps performance nice in case kvmalloc() gives us non-contiguous areas. + */ +#define NFT_PIPAPO_LT_SIZE_THRESHOLD (1 << 21) +#define NFT_PIPAPO_LT_SIZE_HYSTERESIS (1 << 16) +#define NFT_PIPAPO_LT_SIZE_HIGH NFT_PIPAPO_LT_SIZE_THRESHOLD +#define NFT_PIPAPO_LT_SIZE_LOW NFT_PIPAPO_LT_SIZE_THRESHOLD - \ + NFT_PIPAPO_LT_SIZE_HYSTERESIS + +/* Fields are padded to 32 bits in input registers */ +#define NFT_PIPAPO_GROUPS_PADDED_SIZE(f) \ + (round_up((f)->groups / NFT_PIPAPO_GROUPS_PER_BYTE(f), sizeof(u32))) +#define NFT_PIPAPO_GROUPS_PADDING(f) \ + (NFT_PIPAPO_GROUPS_PADDED_SIZE(f) - (f)->groups / \ + NFT_PIPAPO_GROUPS_PER_BYTE(f)) + +/* Number of buckets given by 2 ^ n, with n bucket bits */ +#define NFT_PIPAPO_BUCKETS(bb) (1 << (bb)) + +/* Each n-bit range maps to up to n * 2 rules */ +#define NFT_PIPAPO_MAP_NBITS (const_ilog2(NFT_PIPAPO_MAX_BITS * 2)) + +/* Use the rest of mapping table buckets for rule indices, but it makes no sense + * to exceed 32 bits + */ +#if BITS_PER_LONG == 64 +#define NFT_PIPAPO_MAP_TOBITS 32 +#else +#define NFT_PIPAPO_MAP_TOBITS (BITS_PER_LONG - NFT_PIPAPO_MAP_NBITS) +#endif + +/* ...which gives us the highest allowed index for a rule */ +#define NFT_PIPAPO_RULE0_MAX ((1UL << (NFT_PIPAPO_MAP_TOBITS - 1)) \ + - (1UL << NFT_PIPAPO_MAP_NBITS)) + +/* Definitions for vectorised implementations */ +#ifdef NFT_PIPAPO_ALIGN +#define NFT_PIPAPO_ALIGN_HEADROOM \ + (NFT_PIPAPO_ALIGN - ARCH_KMALLOC_MINALIGN) +#define NFT_PIPAPO_LT_ALIGN(lt) (PTR_ALIGN((lt), NFT_PIPAPO_ALIGN)) +#define NFT_PIPAPO_LT_ASSIGN(field, x) \ + do { \ + (field)->lt_aligned = NFT_PIPAPO_LT_ALIGN(x); \ + (field)->lt = (x); \ + } while (0) +#else +#define NFT_PIPAPO_ALIGN_HEADROOM 0 +#define NFT_PIPAPO_LT_ALIGN(lt) (lt) +#define NFT_PIPAPO_LT_ASSIGN(field, x) ((field)->lt = (x)) +#endif /* NFT_PIPAPO_ALIGN */ + +#define nft_pipapo_for_each_field(field, index, match) \ + for ((field) = (match)->f, (index) = 0; \ + (index) < (match)->field_count; \ + (index)++, (field)++) + +/** + * union nft_pipapo_map_bucket - Bucket of mapping table + * @to: First rule number (in next field) this rule maps to + * @n: Number of rules (in next field) this rule maps to + * @e: If there's no next field, pointer to element this rule maps to + */ +union nft_pipapo_map_bucket { + struct { +#if BITS_PER_LONG == 64 + static_assert(NFT_PIPAPO_MAP_TOBITS <= 32); + u32 to; + + static_assert(NFT_PIPAPO_MAP_NBITS <= 32); + u32 n; +#else + unsigned long to:NFT_PIPAPO_MAP_TOBITS; + unsigned long n:NFT_PIPAPO_MAP_NBITS; +#endif + }; + struct nft_pipapo_elem *e; +}; + +/** + * struct nft_pipapo_field - Lookup, mapping tables and related data for a field + * @groups: Amount of bit groups + * @rules: Number of inserted rules + * @bsize: Size of each bucket in lookup table, in longs + * @bb: Number of bits grouped together in lookup table buckets + * @lt: Lookup table: 'groups' rows of buckets + * @lt_aligned: Version of @lt aligned to 
NFT_PIPAPO_ALIGN bytes + * @mt: Mapping table: one bucket per rule + */ +struct nft_pipapo_field { + int groups; + unsigned long rules; + size_t bsize; + int bb; +#ifdef NFT_PIPAPO_ALIGN + unsigned long *lt_aligned; +#endif + unsigned long *lt; + union nft_pipapo_map_bucket *mt; +}; + +/** + * struct nft_pipapo_match - Data used for lookup and matching + * @field_count Amount of fields in set + * @scratch: Preallocated per-CPU maps for partial matching results + * @scratch_aligned: Version of @scratch aligned to NFT_PIPAPO_ALIGN bytes + * @bsize_max: Maximum lookup table bucket size of all fields, in longs + * @rcu Matching data is swapped on commits + * @f: Fields, with lookup and mapping tables + */ +struct nft_pipapo_match { + int field_count; +#ifdef NFT_PIPAPO_ALIGN + unsigned long * __percpu *scratch_aligned; +#endif + unsigned long * __percpu *scratch; + size_t bsize_max; + struct rcu_head rcu; + struct nft_pipapo_field f[]; +}; + +/** + * struct nft_pipapo - Representation of a set + * @match: Currently in-use matching data + * @clone: Copy where pending insertions and deletions are kept + * @width: Total bytes to be matched for one packet, including padding + * @dirty: Working copy has pending insertions or deletions + * @last_gc: Timestamp of last garbage collection run, jiffies + */ +struct nft_pipapo { + struct nft_pipapo_match __rcu *match; + struct nft_pipapo_match *clone; + int width; + bool dirty; + unsigned long last_gc; +}; + +struct nft_pipapo_elem; + +/** + * struct nft_pipapo_elem - API-facing representation of single set element + * @ext: nftables API extensions + */ +struct nft_pipapo_elem { + struct nft_set_ext ext; +}; + +int pipapo_refill(unsigned long *map, int len, int rules, unsigned long *dst, + union nft_pipapo_map_bucket *mt, bool match_only); + +/** + * pipapo_and_field_buckets_4bit() - Intersect 4-bit buckets + * @f: Field including lookup table + * @dst: Area to store result + * @data: Input data selecting table buckets + */ +static inline void pipapo_and_field_buckets_4bit(struct nft_pipapo_field *f, + unsigned long *dst, + const u8 *data) +{ + unsigned long *lt = NFT_PIPAPO_LT_ALIGN(f->lt); + int group; + + for (group = 0; group < f->groups; group += BITS_PER_BYTE / 4, data++) { + u8 v; + + v = *data >> 4; + __bitmap_and(dst, dst, lt + v * f->bsize, + f->bsize * BITS_PER_LONG); + lt += f->bsize * NFT_PIPAPO_BUCKETS(4); + + v = *data & 0x0f; + __bitmap_and(dst, dst, lt + v * f->bsize, + f->bsize * BITS_PER_LONG); + lt += f->bsize * NFT_PIPAPO_BUCKETS(4); + } +} + +/** + * pipapo_and_field_buckets_8bit() - Intersect 8-bit buckets + * @f: Field including lookup table + * @dst: Area to store result + * @data: Input data selecting table buckets + */ +static inline void pipapo_and_field_buckets_8bit(struct nft_pipapo_field *f, + unsigned long *dst, + const u8 *data) +{ + unsigned long *lt = NFT_PIPAPO_LT_ALIGN(f->lt); + int group; + + for (group = 0; group < f->groups; group++, data++) { + __bitmap_and(dst, dst, lt + *data * f->bsize, + f->bsize * BITS_PER_LONG); + lt += f->bsize * NFT_PIPAPO_BUCKETS(8); + } +} + +/** + * pipapo_estimate_size() - Estimate worst-case for set size + * @desc: Set description, element count and field description used here + * + * The size for this set type can vary dramatically, as it depends on the number + * of rules (composing netmasks) the entries expand to. We compute the worst + * case here. 
+ * + * In general, for a non-ranged entry or a single composing netmask, we need + * one bit in each of the sixteen NFT_PIPAPO_BUCKETS, for each 4-bit group (that + * is, each input bit needs four bits of matching data), plus a bucket in the + * mapping table for each field. + * + * Return: worst-case set size in bytes, 0 on any overflow + */ +static u64 pipapo_estimate_size(const struct nft_set_desc *desc) +{ + unsigned long entry_size; + u64 size; + int i; + + for (i = 0, entry_size = 0; i < desc->field_count; i++) { + unsigned long rules; + + if (desc->field_len[i] > NFT_PIPAPO_MAX_BYTES) + return 0; + + /* Worst-case ranges for each concatenated field: each n-bit + * field can expand to up to n * 2 rules in each bucket, and + * each rule also needs a mapping bucket. + */ + rules = ilog2(desc->field_len[i] * BITS_PER_BYTE) * 2; + entry_size += rules * + NFT_PIPAPO_BUCKETS(NFT_PIPAPO_GROUP_BITS_INIT) / + BITS_PER_BYTE; + entry_size += rules * sizeof(union nft_pipapo_map_bucket); + } + + /* Rules in lookup and mapping tables are needed for each entry */ + size = desc->size * entry_size; + if (size && div_u64(size, desc->size) != entry_size) + return 0; + + size += sizeof(struct nft_pipapo) + sizeof(struct nft_pipapo_match) * 2; + + size += sizeof(struct nft_pipapo_field) * desc->field_count; + + return size; +} + +#endif /* _NFT_SET_PIPAPO_H */ diff --git a/net/netfilter/nft_set_pipapo_avx2.c b/net/netfilter/nft_set_pipapo_avx2.c new file mode 100644 index 000000000..52e0d026d --- /dev/null +++ b/net/netfilter/nft_set_pipapo_avx2.c @@ -0,0 +1,1228 @@ +// SPDX-License-Identifier: GPL-2.0-only + +/* PIPAPO: PIle PAcket POlicies: AVX2 packet lookup routines + * + * Copyright (c) 2019-2020 Red Hat GmbH + * + * Author: Stefano Brivio + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "nft_set_pipapo_avx2.h" +#include "nft_set_pipapo.h" + +#define NFT_PIPAPO_LONGS_PER_M256 (XSAVE_YMM_SIZE / BITS_PER_LONG) + +/* Load from memory into YMM register with non-temporal hint ("stream load"), + * that is, don't fetch lines from memory into the cache. This avoids pushing + * precious packet data out of the cache hierarchy, and is appropriate when: + * + * - loading buckets from lookup tables, as they are not going to be used + * again before packets are entirely classified + * + * - loading the result bitmap from the previous field, as it's never used + * again + */ +#define NFT_PIPAPO_AVX2_LOAD(reg, loc) \ + asm volatile("vmovntdqa %0, %%ymm" #reg : : "m" (loc)) + +/* Stream a single lookup table bucket into YMM register given lookup table, + * group index, value of packet bits, bucket size. + */ +#define NFT_PIPAPO_AVX2_BUCKET_LOAD4(reg, lt, group, v, bsize) \ + NFT_PIPAPO_AVX2_LOAD(reg, \ + lt[((group) * NFT_PIPAPO_BUCKETS(4) + \ + (v)) * (bsize)]) +#define NFT_PIPAPO_AVX2_BUCKET_LOAD8(reg, lt, group, v, bsize) \ + NFT_PIPAPO_AVX2_LOAD(reg, \ + lt[((group) * NFT_PIPAPO_BUCKETS(8) + \ + (v)) * (bsize)]) + +/* Bitwise AND: the staple operation of this algorithm */ +#define NFT_PIPAPO_AVX2_AND(dst, a, b) \ + asm volatile("vpand %ymm" #a ", %ymm" #b ", %ymm" #dst) + +/* Jump to label if @reg is zero */ +#define NFT_PIPAPO_AVX2_NOMATCH_GOTO(reg, label) \ + asm_volatile_goto("vptest %%ymm" #reg ", %%ymm" #reg ";" \ + "je %l[" #label "]" : : : : label) + +/* Store 256 bits from YMM register into memory. 
Contrary to bucket load + * operation, we don't bypass the cache here, as stored matching results + * are always used shortly after. + */ +#define NFT_PIPAPO_AVX2_STORE(loc, reg) \ + asm volatile("vmovdqa %%ymm" #reg ", %0" : "=m" (loc)) + +/* Zero out a complete YMM register, @reg */ +#define NFT_PIPAPO_AVX2_ZERO(reg) \ + asm volatile("vpxor %ymm" #reg ", %ymm" #reg ", %ymm" #reg) + +/* Current working bitmap index, toggled between field matches */ +static DEFINE_PER_CPU(bool, nft_pipapo_avx2_scratch_index); + +/** + * nft_pipapo_avx2_prepare() - Prepare before main algorithm body + * + * This zeroes out ymm15, which is later used whenever we need to clear a + * memory location, by storing its content into memory. + */ +static void nft_pipapo_avx2_prepare(void) +{ + NFT_PIPAPO_AVX2_ZERO(15); +} + +/** + * nft_pipapo_avx2_fill() - Fill a bitmap region with ones + * @data: Base memory area + * @start: First bit to set + * @len: Count of bits to fill + * + * This is nothing else than a version of bitmap_set(), as used e.g. by + * pipapo_refill(), tailored for the microarchitectures using it and better + * suited for the specific usage: it's very likely that we'll set a small number + * of bits, not crossing a word boundary, and correct branch prediction is + * critical here. + * + * This function doesn't actually use any AVX2 instruction. + */ +static void nft_pipapo_avx2_fill(unsigned long *data, int start, int len) +{ + int offset = start % BITS_PER_LONG; + unsigned long mask; + + data += start / BITS_PER_LONG; + + if (likely(len == 1)) { + *data |= BIT(offset); + return; + } + + if (likely(len < BITS_PER_LONG || offset)) { + if (likely(len + offset <= BITS_PER_LONG)) { + *data |= GENMASK(len - 1 + offset, offset); + return; + } + + *data |= ~0UL << offset; + len -= BITS_PER_LONG - offset; + data++; + + if (len <= BITS_PER_LONG) { + mask = ~0UL >> (BITS_PER_LONG - len); + *data |= mask; + return; + } + } + + memset(data, 0xff, len / BITS_PER_BYTE); + data += len / BITS_PER_LONG; + + len %= BITS_PER_LONG; + if (len) + *data |= ~0UL >> (BITS_PER_LONG - len); +} + +/** + * nft_pipapo_avx2_refill() - Scan bitmap, select mapping table item, set bits + * @offset: Start from given bitmap (equivalent to bucket) offset, in longs + * @map: Bitmap to be scanned for set bits + * @dst: Destination bitmap + * @mt: Mapping table containing bit set specifiers + * @last: Return index of first set bit, if this is the last field + * + * This is an alternative implementation of pipapo_refill() suitable for usage + * with AVX2 lookup routines: we know there are four words to be scanned, at + * a given offset inside the map, for each matching iteration. + * + * This function doesn't actually use any AVX2 instruction. + * + * Return: first set bit index if @last, index of first filled word otherwise. 
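+ *
+ * A minimal sketch, with made-up values: assume @offset is 0 and only bits
+ * 3 and 5 of map[0] are set. If @last is true, 3 is returned right away.
+ * Otherwise, bits mt[3].to .. mt[3].to + mt[3].n - 1 and
+ * mt[5].to .. mt[5].to + mt[5].n - 1 are set in @dst via
+ * nft_pipapo_avx2_fill(), and mt[3].to is returned.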
+ */ +static int nft_pipapo_avx2_refill(int offset, unsigned long *map, + unsigned long *dst, + union nft_pipapo_map_bucket *mt, bool last) +{ + int ret = -1; + +#define NFT_PIPAPO_AVX2_REFILL_ONE_WORD(x) \ + do { \ + while (map[(x)]) { \ + int r = __builtin_ctzl(map[(x)]); \ + int i = (offset + (x)) * BITS_PER_LONG + r; \ + \ + if (last) \ + return i; \ + \ + nft_pipapo_avx2_fill(dst, mt[i].to, mt[i].n); \ + \ + if (ret == -1) \ + ret = mt[i].to; \ + \ + map[(x)] &= ~(1UL << r); \ + } \ + } while (0) + + NFT_PIPAPO_AVX2_REFILL_ONE_WORD(0); + NFT_PIPAPO_AVX2_REFILL_ONE_WORD(1); + NFT_PIPAPO_AVX2_REFILL_ONE_WORD(2); + NFT_PIPAPO_AVX2_REFILL_ONE_WORD(3); +#undef NFT_PIPAPO_AVX2_REFILL_ONE_WORD + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_4b_2() - AVX2-based lookup for 2 four-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * Load buckets from lookup table corresponding to the values of each 4-bit + * group of packet bytes, and perform a bitwise intersection between them. If + * this is the first field in the set, simply AND the buckets together + * (equivalent to using an all-ones starting bitmap), use the provided starting + * bitmap otherwise. Then call nft_pipapo_avx2_refill() to generate the next + * working bitmap, @fill. + * + * This is used for 8-bit fields (i.e. protocol numbers). + * + * Out-of-order (and superscalar) execution is vital here, so it's critical to + * avoid false data dependencies. CPU and compiler could (mostly) take care of + * this on their own, but the operation ordering is explicitly given here with + * a likely execution order in mind, to highlight possible stalls. That's why + * a number of logically distinct operations (i.e. loading buckets, intersecting + * buckets) are interleaved. + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
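+ *
+ * As a small worked example (values chosen for illustration only): for an
+ * 8-bit field with pkt[0] equal to 0x17, pg[0] is 1 and pg[1] is 7, so for
+ * each 256-bit slice the bucket for value 1 in group 0 and the bucket for
+ * value 7 in group 1 are streamed in and ANDed, together with the previous
+ * result from @map unless @first is set.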
+ */ +static int nft_pipapo_avx2_lookup_4b_2(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + u8 pg[2] = { pkt[0] >> 4, pkt[0] & 0xf }; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_LOAD(2, map[i_ul]); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_NOMATCH_GOTO(2, nothing); + NFT_PIPAPO_AVX2_AND(3, 0, 1); + NFT_PIPAPO_AVX2_AND(4, 2, 3); + } + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(4, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 4); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_4b_4() - AVX2-based lookup for 4 four-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 16-bit fields (i.e. ports). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_4b_4(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + u8 pg[4] = { pkt[0] >> 4, pkt[0] & 0xf, pkt[1] >> 4, pkt[1] & 0xf }; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 2, pg[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 3, pg[3], bsize); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + NFT_PIPAPO_AVX2_AND(5, 2, 3); + NFT_PIPAPO_AVX2_AND(7, 4, 5); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + + NFT_PIPAPO_AVX2_LOAD(1, map[i_ul]); + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 2, pg[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(4, lt, 3, pg[3], bsize); + NFT_PIPAPO_AVX2_AND(5, 0, 1); + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nothing); + + NFT_PIPAPO_AVX2_AND(6, 2, 3); + NFT_PIPAPO_AVX2_AND(7, 4, 5); + /* Stall */ + NFT_PIPAPO_AVX2_AND(7, 6, 7); + } + + /* Stall */ + NFT_PIPAPO_AVX2_NOMATCH_GOTO(7, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 7); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_4b_8() - AVX2-based lookup for 8 four-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 32-bit fields (i.e. IPv4 addresses). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_4b_8(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + u8 pg[8] = { pkt[0] >> 4, pkt[0] & 0xf, pkt[1] >> 4, pkt[1] & 0xf, + pkt[2] >> 4, pkt[2] & 0xf, pkt[3] >> 4, pkt[3] & 0xf, + }; + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 2, pg[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 3, pg[3], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(4, lt, 4, pg[4], bsize); + NFT_PIPAPO_AVX2_AND(5, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(6, lt, 5, pg[5], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(7, lt, 6, pg[6], bsize); + NFT_PIPAPO_AVX2_AND(8, 2, 3); + NFT_PIPAPO_AVX2_AND(9, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(10, lt, 7, pg[7], bsize); + NFT_PIPAPO_AVX2_AND(11, 6, 7); + NFT_PIPAPO_AVX2_AND(12, 8, 9); + NFT_PIPAPO_AVX2_AND(13, 10, 11); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(1, 12, 13); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_LOAD(1, map[i_ul]); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 2, pg[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(4, lt, 3, pg[3], bsize); + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nothing); + + NFT_PIPAPO_AVX2_AND(5, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(6, lt, 4, pg[4], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(7, lt, 5, pg[5], bsize); + NFT_PIPAPO_AVX2_AND(8, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(9, lt, 6, pg[6], bsize); + NFT_PIPAPO_AVX2_AND(10, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(11, lt, 7, pg[7], bsize); + NFT_PIPAPO_AVX2_AND(12, 6, 7); + NFT_PIPAPO_AVX2_AND(13, 8, 9); + NFT_PIPAPO_AVX2_AND(14, 10, 11); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(1, 12, 13); + NFT_PIPAPO_AVX2_AND(1, 1, 14); + } + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 1); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; + +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_4b_12() - AVX2-based lookup for 12 four-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 48-bit fields (i.e. MAC addresses/EUI-48). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_4b_12(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + u8 pg[12] = { pkt[0] >> 4, pkt[0] & 0xf, pkt[1] >> 4, pkt[1] & 0xf, + pkt[2] >> 4, pkt[2] & 0xf, pkt[3] >> 4, pkt[3] & 0xf, + pkt[4] >> 4, pkt[4] & 0xf, pkt[5] >> 4, pkt[5] & 0xf, + }; + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (!first) + NFT_PIPAPO_AVX2_LOAD(0, map[i_ul]); + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 2, pg[2], bsize); + + if (!first) { + NFT_PIPAPO_AVX2_NOMATCH_GOTO(0, nothing); + NFT_PIPAPO_AVX2_AND(1, 1, 0); + } + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(4, lt, 3, pg[3], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(5, lt, 4, pg[4], bsize); + NFT_PIPAPO_AVX2_AND(6, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(7, lt, 5, pg[5], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(8, lt, 6, pg[6], bsize); + NFT_PIPAPO_AVX2_AND(9, 1, 4); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(10, lt, 7, pg[7], bsize); + NFT_PIPAPO_AVX2_AND(11, 5, 6); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(12, lt, 8, pg[8], bsize); + NFT_PIPAPO_AVX2_AND(13, 7, 8); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(14, lt, 9, pg[9], bsize); + + NFT_PIPAPO_AVX2_AND(0, 9, 10); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 10, pg[10], bsize); + NFT_PIPAPO_AVX2_AND(2, 11, 12); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 11, pg[11], bsize); + NFT_PIPAPO_AVX2_AND(4, 13, 14); + NFT_PIPAPO_AVX2_AND(5, 0, 1); + + NFT_PIPAPO_AVX2_AND(6, 2, 3); + + /* Stalls */ + NFT_PIPAPO_AVX2_AND(7, 4, 5); + NFT_PIPAPO_AVX2_AND(8, 6, 7); + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(8, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 8); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_4b_32() - AVX2-based lookup for 32 four-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 128-bit fields (i.e. IPv6 addresses). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_4b_32(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + u8 pg[32] = { pkt[0] >> 4, pkt[0] & 0xf, pkt[1] >> 4, pkt[1] & 0xf, + pkt[2] >> 4, pkt[2] & 0xf, pkt[3] >> 4, pkt[3] & 0xf, + pkt[4] >> 4, pkt[4] & 0xf, pkt[5] >> 4, pkt[5] & 0xf, + pkt[6] >> 4, pkt[6] & 0xf, pkt[7] >> 4, pkt[7] & 0xf, + pkt[8] >> 4, pkt[8] & 0xf, pkt[9] >> 4, pkt[9] & 0xf, + pkt[10] >> 4, pkt[10] & 0xf, pkt[11] >> 4, pkt[11] & 0xf, + pkt[12] >> 4, pkt[12] & 0xf, pkt[13] >> 4, pkt[13] & 0xf, + pkt[14] >> 4, pkt[14] & 0xf, pkt[15] >> 4, pkt[15] & 0xf, + }; + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (!first) + NFT_PIPAPO_AVX2_LOAD(0, map[i_ul]); + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 0, pg[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 1, pg[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 2, pg[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(4, lt, 3, pg[3], bsize); + if (!first) { + NFT_PIPAPO_AVX2_NOMATCH_GOTO(0, nothing); + NFT_PIPAPO_AVX2_AND(1, 1, 0); + } + + NFT_PIPAPO_AVX2_AND(5, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(6, lt, 4, pg[4], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(7, lt, 5, pg[5], bsize); + NFT_PIPAPO_AVX2_AND(8, 1, 4); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(9, lt, 6, pg[6], bsize); + NFT_PIPAPO_AVX2_AND(10, 5, 6); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(11, lt, 7, pg[7], bsize); + NFT_PIPAPO_AVX2_AND(12, 7, 8); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(13, lt, 8, pg[8], bsize); + NFT_PIPAPO_AVX2_AND(14, 9, 10); + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 9, pg[9], bsize); + NFT_PIPAPO_AVX2_AND(1, 11, 12); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(2, lt, 10, pg[10], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 11, pg[11], bsize); + NFT_PIPAPO_AVX2_AND(4, 13, 14); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(5, lt, 12, pg[12], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(6, lt, 13, pg[13], bsize); + NFT_PIPAPO_AVX2_AND(7, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(8, lt, 14, pg[14], bsize); + NFT_PIPAPO_AVX2_AND(9, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(10, lt, 15, pg[15], bsize); + NFT_PIPAPO_AVX2_AND(11, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(12, lt, 16, pg[16], bsize); + NFT_PIPAPO_AVX2_AND(13, 6, 7); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(14, lt, 17, pg[17], bsize); + + NFT_PIPAPO_AVX2_AND(0, 8, 9); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(1, lt, 18, pg[18], bsize); + NFT_PIPAPO_AVX2_AND(2, 10, 11); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 19, pg[19], bsize); + NFT_PIPAPO_AVX2_AND(4, 12, 13); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(5, lt, 20, pg[20], bsize); + NFT_PIPAPO_AVX2_AND(6, 14, 0); + NFT_PIPAPO_AVX2_AND(7, 1, 2); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(8, lt, 21, pg[21], bsize); + NFT_PIPAPO_AVX2_AND(9, 3, 4); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(10, lt, 22, pg[22], bsize); + NFT_PIPAPO_AVX2_AND(11, 5, 6); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(12, lt, 23, pg[23], bsize); + NFT_PIPAPO_AVX2_AND(13, 7, 8); + + NFT_PIPAPO_AVX2_BUCKET_LOAD4(14, lt, 24, pg[24], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(0, lt, 25, pg[25], bsize); + NFT_PIPAPO_AVX2_AND(1, 9, 10); + NFT_PIPAPO_AVX2_AND(2, 11, 12); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(3, lt, 26, pg[26], bsize); + NFT_PIPAPO_AVX2_AND(4, 13, 14); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(5, lt, 27, pg[27], bsize); + NFT_PIPAPO_AVX2_AND(6, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(7, lt, 28, pg[28], bsize); + 
NFT_PIPAPO_AVX2_BUCKET_LOAD4(8, lt, 29, pg[29], bsize); + NFT_PIPAPO_AVX2_AND(9, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(10, lt, 30, pg[30], bsize); + NFT_PIPAPO_AVX2_AND(11, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD4(12, lt, 31, pg[31], bsize); + + NFT_PIPAPO_AVX2_AND(0, 6, 7); + NFT_PIPAPO_AVX2_AND(1, 8, 9); + NFT_PIPAPO_AVX2_AND(2, 10, 11); + NFT_PIPAPO_AVX2_AND(3, 12, 0); + + /* Stalls */ + NFT_PIPAPO_AVX2_AND(4, 1, 2); + NFT_PIPAPO_AVX2_AND(5, 3, 4); + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(5, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 5); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_8b_1() - AVX2-based lookup for one eight-bit group + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 8-bit fields (i.e. protocol numbers). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). + */ +static int nft_pipapo_avx2_lookup_8b_1(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 0, pkt[0], bsize); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_LOAD(1, map[i_ul]); + NFT_PIPAPO_AVX2_AND(2, 0, 1); + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nothing); + } + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(2, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 2); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_8b_2() - AVX2-based lookup for 2 eight-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 16-bit fields (i.e. ports). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_8b_2(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + } else { + NFT_PIPAPO_AVX2_LOAD(0, map[i_ul]); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 1, pkt[1], bsize); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(3, 0, 1); + NFT_PIPAPO_AVX2_NOMATCH_GOTO(0, nothing); + NFT_PIPAPO_AVX2_AND(4, 3, 2); + } + + /* Stall */ + NFT_PIPAPO_AVX2_NOMATCH_GOTO(4, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 4); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_8b_4() - AVX2-based lookup for 4 eight-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 32-bit fields (i.e. IPv4 addresses). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
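+ *
+ * For illustration, with the documentation address 192.0.2.1 as key, the
+ * four 8-bit groups are indexed directly by the byte values 192, 0, 2 and
+ * 1: one bucket per group is loaded, and the four buckets are ANDed,
+ * together with the previous result from @map unless @first is set.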
+ */ +static int nft_pipapo_avx2_lookup_8b_4(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 2, pkt[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 3, pkt[3], bsize); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(4, 0, 1); + NFT_PIPAPO_AVX2_AND(5, 2, 3); + NFT_PIPAPO_AVX2_AND(0, 4, 5); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_LOAD(1, map[i_ul]); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 2, pkt[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(4, lt, 3, pkt[3], bsize); + + NFT_PIPAPO_AVX2_AND(5, 0, 1); + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nothing); + NFT_PIPAPO_AVX2_AND(6, 2, 3); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(7, 4, 5); + NFT_PIPAPO_AVX2_AND(0, 6, 7); + } + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(0, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 0); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; + +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_8b_6() - AVX2-based lookup for 6 eight-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 48-bit fields (i.e. MAC addresses/EUI-48). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_8b_6(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (first) { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 2, pkt[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 3, pkt[3], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(4, lt, 4, pkt[4], bsize); + + NFT_PIPAPO_AVX2_AND(5, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(6, lt, 5, pkt[5], bsize); + NFT_PIPAPO_AVX2_AND(7, 2, 3); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(0, 4, 5); + NFT_PIPAPO_AVX2_AND(1, 6, 7); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + } else { + NFT_PIPAPO_AVX2_BUCKET_LOAD8(0, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_LOAD(1, map[i_ul]); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 2, pkt[2], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(4, lt, 3, pkt[3], bsize); + + NFT_PIPAPO_AVX2_AND(5, 0, 1); + NFT_PIPAPO_AVX2_NOMATCH_GOTO(1, nothing); + + NFT_PIPAPO_AVX2_AND(6, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(7, lt, 4, pkt[4], bsize); + NFT_PIPAPO_AVX2_AND(0, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 5, pkt[5], bsize); + NFT_PIPAPO_AVX2_AND(2, 6, 7); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(3, 0, 1); + NFT_PIPAPO_AVX2_AND(4, 2, 3); + } + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(4, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 4); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; + +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_8b_16() - AVX2-based lookup for 16 eight-bit groups + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * See nft_pipapo_avx2_lookup_4b_2(). + * + * This is used for 128-bit fields (i.e. IPv6 addresses). + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
+ */ +static int nft_pipapo_avx2_lookup_8b_16(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + int i, ret = -1, m256_size = f->bsize / NFT_PIPAPO_LONGS_PER_M256, b; + unsigned long *lt = f->lt, bsize = f->bsize; + + lt += offset * NFT_PIPAPO_LONGS_PER_M256; + for (i = offset; i < m256_size; i++, lt += NFT_PIPAPO_LONGS_PER_M256) { + int i_ul = i * NFT_PIPAPO_LONGS_PER_M256; + + if (!first) + NFT_PIPAPO_AVX2_LOAD(0, map[i_ul]); + + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 0, pkt[0], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 1, pkt[1], bsize); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 2, pkt[2], bsize); + if (!first) { + NFT_PIPAPO_AVX2_NOMATCH_GOTO(0, nothing); + NFT_PIPAPO_AVX2_AND(1, 1, 0); + } + NFT_PIPAPO_AVX2_BUCKET_LOAD8(4, lt, 3, pkt[3], bsize); + + NFT_PIPAPO_AVX2_BUCKET_LOAD8(5, lt, 4, pkt[4], bsize); + NFT_PIPAPO_AVX2_AND(6, 1, 2); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(7, lt, 5, pkt[5], bsize); + NFT_PIPAPO_AVX2_AND(0, 3, 4); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 6, pkt[6], bsize); + + NFT_PIPAPO_AVX2_BUCKET_LOAD8(2, lt, 7, pkt[7], bsize); + NFT_PIPAPO_AVX2_AND(3, 5, 6); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(5, lt, 8, pkt[8], bsize); + + NFT_PIPAPO_AVX2_AND(6, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(7, lt, 9, pkt[9], bsize); + NFT_PIPAPO_AVX2_AND(0, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 10, pkt[10], bsize); + NFT_PIPAPO_AVX2_AND(2, 6, 7); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 11, pkt[11], bsize); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(5, lt, 12, pkt[12], bsize); + NFT_PIPAPO_AVX2_AND(6, 2, 3); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(7, lt, 13, pkt[13], bsize); + NFT_PIPAPO_AVX2_AND(0, 4, 5); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(1, lt, 14, pkt[14], bsize); + NFT_PIPAPO_AVX2_AND(2, 6, 7); + NFT_PIPAPO_AVX2_BUCKET_LOAD8(3, lt, 15, pkt[15], bsize); + NFT_PIPAPO_AVX2_AND(4, 0, 1); + + /* Stall */ + NFT_PIPAPO_AVX2_AND(5, 2, 3); + NFT_PIPAPO_AVX2_AND(6, 4, 5); + + NFT_PIPAPO_AVX2_NOMATCH_GOTO(6, nomatch); + NFT_PIPAPO_AVX2_STORE(map[i_ul], 6); + + b = nft_pipapo_avx2_refill(i_ul, &map[i_ul], fill, f->mt, last); + if (last) + return b; + + if (unlikely(ret == -1)) + ret = b / XSAVE_YMM_SIZE; + + continue; + +nomatch: + NFT_PIPAPO_AVX2_STORE(map[i_ul], 15); +nothing: + ; + } + + return ret; +} + +/** + * nft_pipapo_avx2_lookup_slow() - Fallback function for uncommon field sizes + * @map: Previous match result, used as initial bitmap + * @fill: Destination bitmap to be filled with current match result + * @f: Field, containing lookup and mapping tables + * @offset: Ignore buckets before the given index, no bits are filled there + * @pkt: Packet data, pointer to input nftables register + * @first: If this is the first field, don't source previous result + * @last: Last field: stop at the first match and return bit index + * + * This function should never be called, but is provided for the case the field + * size doesn't match any of the known data types. Matching rate is + * substantially lower than AVX2 routines. + * + * Return: -1 on no match, rule index of match if @last, otherwise first long + * word index to be checked next (i.e. first filled word). 
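+ *
+ * A field width without a dedicated routine above (say, a hypothetical
+ * three-byte field, that is, three 8-bit groups) would end up here: the
+ * generic pipapo_and_field_buckets_8bit() or _4bit() and pipapo_refill()
+ * helpers are used instead of the unrolled AVX2 sequences.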
+ */ +static int nft_pipapo_avx2_lookup_slow(unsigned long *map, unsigned long *fill, + struct nft_pipapo_field *f, int offset, + const u8 *pkt, bool first, bool last) +{ + unsigned long bsize = f->bsize; + int i, ret = -1, b; + + if (first) + memset(map, 0xff, bsize * sizeof(*map)); + + for (i = offset; i < bsize; i++) { + if (f->bb == 8) + pipapo_and_field_buckets_8bit(f, map, pkt); + else + pipapo_and_field_buckets_4bit(f, map, pkt); + NFT_PIPAPO_GROUP_BITS_ARE_8_OR_4; + + b = pipapo_refill(map, bsize, f->rules, fill, f->mt, last); + + if (last) + return b; + + if (ret == -1) + ret = b / XSAVE_YMM_SIZE; + } + + return ret; +} + +/** + * nft_pipapo_avx2_estimate() - Set size, space and lookup complexity + * @desc: Set description, element count and field description used + * @features: Flags: NFT_SET_INTERVAL needs to be there + * @est: Storage for estimation data + * + * Return: true if set is compatible and AVX2 available, false otherwise. + */ +bool nft_pipapo_avx2_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + if (!(features & NFT_SET_INTERVAL) || + desc->field_count < NFT_PIPAPO_MIN_FIELDS) + return false; + + if (!boot_cpu_has(X86_FEATURE_AVX2) || !boot_cpu_has(X86_FEATURE_AVX)) + return false; + + est->size = pipapo_estimate_size(desc); + if (!est->size) + return false; + + est->lookup = NFT_SET_CLASS_O_LOG_N; + + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +/** + * nft_pipapo_avx2_lookup() - Lookup function for AVX2 implementation + * @net: Network namespace + * @set: nftables API set representation + * @key: nftables API element representation containing key data + * @ext: nftables API extension pointer, filled with matching reference + * + * For more details, see DOC: Theory of Operation in nft_set_pipapo.c. + * + * This implementation exploits the repetitive characteristic of the algorithm + * to provide a fast, vectorised version using the AVX2 SIMD instruction set. + * + * Return: true on match, false otherwise. + */ +bool nft_pipapo_avx2_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_pipapo *priv = nft_set_priv(set); + unsigned long *res, *fill, *scratch; + u8 genmask = nft_genmask_cur(net); + const u8 *rp = (const u8 *)key; + struct nft_pipapo_match *m; + struct nft_pipapo_field *f; + bool map_index; + int i, ret = 0; + + if (unlikely(!irq_fpu_usable())) + return nft_pipapo_lookup(net, set, key, ext); + + m = rcu_dereference(priv->match); + + /* This also protects access to all data related to scratch maps. + * + * Note that we don't need a valid MXCSR state for any of the + * operations we use here, so pass 0 as mask and spare a LDMXCSR + * instruction. + */ + kernel_fpu_begin_mask(0); + + scratch = *raw_cpu_ptr(m->scratch_aligned); + if (unlikely(!scratch)) { + kernel_fpu_end(); + return false; + } + map_index = raw_cpu_read(nft_pipapo_avx2_scratch_index); + + res = scratch + (map_index ? m->bsize_max : 0); + fill = scratch + (map_index ? 
0 : m->bsize_max); + + /* Starting map doesn't need to be set for this implementation */ + + nft_pipapo_avx2_prepare(); + +next_match: + nft_pipapo_for_each_field(f, i, m) { + bool last = i == m->field_count - 1, first = !i; + +#define NFT_SET_PIPAPO_AVX2_LOOKUP(b, n) \ + (ret = nft_pipapo_avx2_lookup_##b##b_##n(res, fill, f, \ + ret, rp, \ + first, last)) + + if (likely(f->bb == 8)) { + if (f->groups == 1) { + NFT_SET_PIPAPO_AVX2_LOOKUP(8, 1); + } else if (f->groups == 2) { + NFT_SET_PIPAPO_AVX2_LOOKUP(8, 2); + } else if (f->groups == 4) { + NFT_SET_PIPAPO_AVX2_LOOKUP(8, 4); + } else if (f->groups == 6) { + NFT_SET_PIPAPO_AVX2_LOOKUP(8, 6); + } else if (f->groups == 16) { + NFT_SET_PIPAPO_AVX2_LOOKUP(8, 16); + } else { + ret = nft_pipapo_avx2_lookup_slow(res, fill, f, + ret, rp, + first, last); + } + } else { + if (f->groups == 2) { + NFT_SET_PIPAPO_AVX2_LOOKUP(4, 2); + } else if (f->groups == 4) { + NFT_SET_PIPAPO_AVX2_LOOKUP(4, 4); + } else if (f->groups == 8) { + NFT_SET_PIPAPO_AVX2_LOOKUP(4, 8); + } else if (f->groups == 12) { + NFT_SET_PIPAPO_AVX2_LOOKUP(4, 12); + } else if (f->groups == 32) { + NFT_SET_PIPAPO_AVX2_LOOKUP(4, 32); + } else { + ret = nft_pipapo_avx2_lookup_slow(res, fill, f, + ret, rp, + first, last); + } + } + NFT_PIPAPO_GROUP_BITS_ARE_8_OR_4; + +#undef NFT_SET_PIPAPO_AVX2_LOOKUP + + if (ret < 0) + goto out; + + if (last) { + *ext = &f->mt[ret].e->ext; + if (unlikely(nft_set_elem_expired(*ext) || + !nft_set_elem_active(*ext, genmask))) { + ret = 0; + goto next_match; + } + + goto out; + } + + swap(res, fill); + rp += NFT_PIPAPO_GROUPS_PADDED_SIZE(f); + } + +out: + if (i % 2) + raw_cpu_write(nft_pipapo_avx2_scratch_index, !map_index); + kernel_fpu_end(); + + return ret >= 0; +} diff --git a/net/netfilter/nft_set_pipapo_avx2.h b/net/netfilter/nft_set_pipapo_avx2.h new file mode 100644 index 000000000..dbb6aaca8 --- /dev/null +++ b/net/netfilter/nft_set_pipapo_avx2.h @@ -0,0 +1,12 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef _NFT_SET_PIPAPO_AVX2_H + +#if defined(CONFIG_X86_64) && !defined(CONFIG_UML) +#include +#define NFT_PIPAPO_ALIGN (XSAVE_YMM_SIZE / BITS_PER_BYTE) + +bool nft_pipapo_avx2_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est); +#endif /* defined(CONFIG_X86_64) && !defined(CONFIG_UML) */ + +#endif /* _NFT_SET_PIPAPO_AVX2_H */ diff --git a/net/netfilter/nft_set_rbtree.c b/net/netfilter/nft_set_rbtree.c new file mode 100644 index 000000000..e34662f4a --- /dev/null +++ b/net/netfilter/nft_set_rbtree.c @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008-2009 Patrick McHardy + * + * Development of this code funded by Astaro AG (http://www.astaro.com/) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_rbtree { + struct rb_root root; + rwlock_t lock; + seqcount_rwlock_t count; + struct delayed_work gc_work; +}; + +struct nft_rbtree_elem { + struct rb_node node; + struct nft_set_ext ext; +}; + +static bool nft_rbtree_interval_end(const struct nft_rbtree_elem *rbe) +{ + return nft_set_ext_exists(&rbe->ext, NFT_SET_EXT_FLAGS) && + (*nft_set_ext_flags(&rbe->ext) & NFT_SET_ELEM_INTERVAL_END); +} + +static bool nft_rbtree_interval_start(const struct nft_rbtree_elem *rbe) +{ + return !nft_rbtree_interval_end(rbe); +} + +static int nft_rbtree_cmp(const struct nft_set *set, + const struct nft_rbtree_elem *e1, + const struct nft_rbtree_elem *e2) +{ + return memcmp(nft_set_ext_key(&e1->ext), nft_set_ext_key(&e2->ext), + set->klen); 
+} + +static bool nft_rbtree_elem_expired(const struct nft_rbtree_elem *rbe) +{ + return nft_set_elem_expired(&rbe->ext) || + nft_set_elem_is_dead(&rbe->ext); +} + +static bool __nft_rbtree_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext, + unsigned int seq) +{ + struct nft_rbtree *priv = nft_set_priv(set); + const struct nft_rbtree_elem *rbe, *interval = NULL; + u8 genmask = nft_genmask_cur(net); + const struct rb_node *parent; + int d; + + parent = rcu_dereference_raw(priv->root.rb_node); + while (parent != NULL) { + if (read_seqcount_retry(&priv->count, seq)) + return false; + + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + + d = memcmp(nft_set_ext_key(&rbe->ext), key, set->klen); + if (d < 0) { + parent = rcu_dereference_raw(parent->rb_left); + if (interval && + !nft_rbtree_cmp(set, rbe, interval) && + nft_rbtree_interval_end(rbe) && + nft_rbtree_interval_start(interval)) + continue; + interval = rbe; + } else if (d > 0) + parent = rcu_dereference_raw(parent->rb_right); + else { + if (!nft_set_elem_active(&rbe->ext, genmask)) { + parent = rcu_dereference_raw(parent->rb_left); + continue; + } + + if (nft_rbtree_elem_expired(rbe)) + return false; + + if (nft_rbtree_interval_end(rbe)) { + if (nft_set_is_anonymous(set)) + return false; + parent = rcu_dereference_raw(parent->rb_left); + interval = NULL; + continue; + } + + *ext = &rbe->ext; + return true; + } + } + + if (set->flags & NFT_SET_INTERVAL && interval != NULL && + nft_set_elem_active(&interval->ext, genmask) && + !nft_rbtree_elem_expired(interval) && + nft_rbtree_interval_start(interval)) { + *ext = &interval->ext; + return true; + } + + return false; +} + +INDIRECT_CALLABLE_SCOPE +bool nft_rbtree_lookup(const struct net *net, const struct nft_set *set, + const u32 *key, const struct nft_set_ext **ext) +{ + struct nft_rbtree *priv = nft_set_priv(set); + unsigned int seq = read_seqcount_begin(&priv->count); + bool ret; + + ret = __nft_rbtree_lookup(net, set, key, ext, seq); + if (ret || !read_seqcount_retry(&priv->count, seq)) + return ret; + + read_lock_bh(&priv->lock); + seq = read_seqcount_begin(&priv->count); + ret = __nft_rbtree_lookup(net, set, key, ext, seq); + read_unlock_bh(&priv->lock); + + return ret; +} + +static bool __nft_rbtree_get(const struct net *net, const struct nft_set *set, + const u32 *key, struct nft_rbtree_elem **elem, + unsigned int seq, unsigned int flags, u8 genmask) +{ + struct nft_rbtree_elem *rbe, *interval = NULL; + struct nft_rbtree *priv = nft_set_priv(set); + const struct rb_node *parent; + const void *this; + int d; + + parent = rcu_dereference_raw(priv->root.rb_node); + while (parent != NULL) { + if (read_seqcount_retry(&priv->count, seq)) + return false; + + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + + this = nft_set_ext_key(&rbe->ext); + d = memcmp(this, key, set->klen); + if (d < 0) { + parent = rcu_dereference_raw(parent->rb_left); + if (!(flags & NFT_SET_ELEM_INTERVAL_END)) + interval = rbe; + } else if (d > 0) { + parent = rcu_dereference_raw(parent->rb_right); + if (flags & NFT_SET_ELEM_INTERVAL_END) + interval = rbe; + } else { + if (!nft_set_elem_active(&rbe->ext, genmask)) { + parent = rcu_dereference_raw(parent->rb_left); + continue; + } + + if (nft_set_elem_expired(&rbe->ext)) + return false; + + if (!nft_set_ext_exists(&rbe->ext, NFT_SET_EXT_FLAGS) || + (*nft_set_ext_flags(&rbe->ext) & NFT_SET_ELEM_INTERVAL_END) == + (flags & NFT_SET_ELEM_INTERVAL_END)) { + *elem = rbe; + return true; + } + + if 
(nft_rbtree_interval_end(rbe)) + interval = NULL; + + parent = rcu_dereference_raw(parent->rb_left); + } + } + + if (set->flags & NFT_SET_INTERVAL && interval != NULL && + nft_set_elem_active(&interval->ext, genmask) && + !nft_set_elem_expired(&interval->ext) && + ((!nft_rbtree_interval_end(interval) && + !(flags & NFT_SET_ELEM_INTERVAL_END)) || + (nft_rbtree_interval_end(interval) && + (flags & NFT_SET_ELEM_INTERVAL_END)))) { + *elem = interval; + return true; + } + + return false; +} + +static void *nft_rbtree_get(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, unsigned int flags) +{ + struct nft_rbtree *priv = nft_set_priv(set); + unsigned int seq = read_seqcount_begin(&priv->count); + struct nft_rbtree_elem *rbe = ERR_PTR(-ENOENT); + const u32 *key = (const u32 *)&elem->key.val; + u8 genmask = nft_genmask_cur(net); + bool ret; + + ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask); + if (ret || !read_seqcount_retry(&priv->count, seq)) + return rbe; + + read_lock_bh(&priv->lock); + seq = read_seqcount_begin(&priv->count); + ret = __nft_rbtree_get(net, set, key, &rbe, seq, flags, genmask); + if (!ret) + rbe = ERR_PTR(-ENOENT); + read_unlock_bh(&priv->lock); + + return rbe; +} + +static void nft_rbtree_gc_remove(struct net *net, struct nft_set *set, + struct nft_rbtree *priv, + struct nft_rbtree_elem *rbe) +{ + struct nft_set_elem elem = { + .priv = rbe, + }; + + nft_setelem_data_deactivate(net, set, &elem); + rb_erase(&rbe->node, &priv->root); +} + +static const struct nft_rbtree_elem * +nft_rbtree_gc_elem(const struct nft_set *__set, struct nft_rbtree *priv, + struct nft_rbtree_elem *rbe, u8 genmask) +{ + struct nft_set *set = (struct nft_set *)__set; + struct rb_node *prev = rb_prev(&rbe->node); + struct net *net = read_pnet(&set->net); + struct nft_rbtree_elem *rbe_prev; + struct nft_trans_gc *gc; + + gc = nft_trans_gc_alloc(set, 0, GFP_ATOMIC); + if (!gc) + return ERR_PTR(-ENOMEM); + + /* search for end interval coming before this element. + * end intervals don't carry a timeout extension, they + * are coupled with the interval start element. + */ + while (prev) { + rbe_prev = rb_entry(prev, struct nft_rbtree_elem, node); + if (nft_rbtree_interval_end(rbe_prev) && + nft_set_elem_active(&rbe_prev->ext, genmask)) + break; + + prev = rb_prev(prev); + } + + rbe_prev = NULL; + if (prev) { + rbe_prev = rb_entry(prev, struct nft_rbtree_elem, node); + nft_rbtree_gc_remove(net, set, priv, rbe_prev); + + /* There is always room in this trans gc for this element, + * memory allocation never actually happens, hence, the warning + * splat in such case. No need to set NFT_SET_ELEM_DEAD_BIT, + * this is synchronous gc which never fails. + */ + gc = nft_trans_gc_queue_sync(gc, GFP_ATOMIC); + if (WARN_ON_ONCE(!gc)) + return ERR_PTR(-ENOMEM); + + nft_trans_gc_elem_add(gc, rbe_prev); + } + + nft_rbtree_gc_remove(net, set, priv, rbe); + gc = nft_trans_gc_queue_sync(gc, GFP_ATOMIC); + if (WARN_ON_ONCE(!gc)) + return ERR_PTR(-ENOMEM); + + nft_trans_gc_elem_add(gc, rbe); + + nft_trans_gc_queue_sync_done(gc); + + return rbe_prev; +} + +static bool nft_rbtree_update_first(const struct nft_set *set, + struct nft_rbtree_elem *rbe, + struct rb_node *first) +{ + struct nft_rbtree_elem *first_elem; + + first_elem = rb_entry(first, struct nft_rbtree_elem, node); + /* this element is closest to where the new element is to be inserted: + * update the first element for the node list path. 
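+ * Keep whichever of the two candidates has the smaller key.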
+ */ + if (nft_rbtree_cmp(set, rbe, first_elem) < 0) + return true; + + return false; +} + +static int __nft_rbtree_insert(const struct net *net, const struct nft_set *set, + struct nft_rbtree_elem *new, + struct nft_set_ext **ext) +{ + struct nft_rbtree_elem *rbe, *rbe_le = NULL, *rbe_ge = NULL; + struct rb_node *node, *next, *parent, **p, *first = NULL; + struct nft_rbtree *priv = nft_set_priv(set); + u8 cur_genmask = nft_genmask_cur(net); + u8 genmask = nft_genmask_next(net); + int d; + + /* Descend the tree to search for an existing element greater than the + * key value to insert that is greater than the new element. This is the + * first element to walk the ordered elements to find possible overlap. + */ + parent = NULL; + p = &priv->root.rb_node; + while (*p != NULL) { + parent = *p; + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + d = nft_rbtree_cmp(set, rbe, new); + + if (d < 0) { + p = &parent->rb_left; + } else if (d > 0) { + if (!first || + nft_rbtree_update_first(set, rbe, first)) + first = &rbe->node; + + p = &parent->rb_right; + } else { + if (nft_rbtree_interval_end(rbe)) + p = &parent->rb_left; + else + p = &parent->rb_right; + } + } + + if (!first) + first = rb_first(&priv->root); + + /* Detect overlap by going through the list of valid tree nodes. + * Values stored in the tree are in reversed order, starting from + * highest to lowest value. + */ + for (node = first; node != NULL; node = next) { + next = rb_next(node); + + rbe = rb_entry(node, struct nft_rbtree_elem, node); + + if (!nft_set_elem_active(&rbe->ext, genmask)) + continue; + + /* perform garbage collection to avoid bogus overlap reports + * but skip new elements in this transaction. + */ + if (nft_set_elem_expired(&rbe->ext) && + nft_set_elem_active(&rbe->ext, cur_genmask)) { + const struct nft_rbtree_elem *removed_end; + + removed_end = nft_rbtree_gc_elem(set, priv, rbe, genmask); + if (IS_ERR(removed_end)) + return PTR_ERR(removed_end); + + if (removed_end == rbe_le || removed_end == rbe_ge) + return -EAGAIN; + + continue; + } + + d = nft_rbtree_cmp(set, rbe, new); + if (d == 0) { + /* Matching end element: no need to look for an + * overlapping greater or equal element. + */ + if (nft_rbtree_interval_end(rbe)) { + rbe_le = rbe; + break; + } + + /* first element that is greater or equal to key value. */ + if (!rbe_ge) { + rbe_ge = rbe; + continue; + } + + /* this is a closer more or equal element, update it. */ + if (nft_rbtree_cmp(set, rbe_ge, new) != 0) { + rbe_ge = rbe; + continue; + } + + /* element is equal to key value, make sure flags are + * the same, an existing more or equal start element + * must not be replaced by more or equal end element. + */ + if ((nft_rbtree_interval_start(new) && + nft_rbtree_interval_start(rbe_ge)) || + (nft_rbtree_interval_end(new) && + nft_rbtree_interval_end(rbe_ge))) { + rbe_ge = rbe; + continue; + } + } else if (d > 0) { + /* annotate element greater than the new element. */ + rbe_ge = rbe; + continue; + } else if (d < 0) { + /* annotate element less than the new element. */ + rbe_le = rbe; + break; + } + } + + /* - new start element matching existing start element: full overlap + * reported as -EEXIST, cleared by caller if NLM_F_EXCL is not given. + */ + if (rbe_ge && !nft_rbtree_cmp(set, new, rbe_ge) && + nft_rbtree_interval_start(rbe_ge) == nft_rbtree_interval_start(new)) { + *ext = &rbe_ge->ext; + return -EEXIST; + } + + /* - new end element matching existing end element: full overlap + * reported as -EEXIST, cleared by caller if NLM_F_EXCL is not given. 
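+ * This mirrors the start/start check above, but against the closest less
+ * or equal element (rbe_le).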
+ */ + if (rbe_le && !nft_rbtree_cmp(set, new, rbe_le) && + nft_rbtree_interval_end(rbe_le) == nft_rbtree_interval_end(new)) { + *ext = &rbe_le->ext; + return -EEXIST; + } + + /* - new start element with existing closest, less or equal key value + * being a start element: partial overlap, reported as -ENOTEMPTY. + * Anonymous sets allow for two consecutive start element since they + * are constant, skip them to avoid bogus overlap reports. + */ + if (!nft_set_is_anonymous(set) && rbe_le && + nft_rbtree_interval_start(rbe_le) && nft_rbtree_interval_start(new)) + return -ENOTEMPTY; + + /* - new end element with existing closest, less or equal key value + * being a end element: partial overlap, reported as -ENOTEMPTY. + */ + if (rbe_le && + nft_rbtree_interval_end(rbe_le) && nft_rbtree_interval_end(new)) + return -ENOTEMPTY; + + /* - new end element with existing closest, greater or equal key value + * being an end element: partial overlap, reported as -ENOTEMPTY + */ + if (rbe_ge && + nft_rbtree_interval_end(rbe_ge) && nft_rbtree_interval_end(new)) + return -ENOTEMPTY; + + /* Accepted element: pick insertion point depending on key value */ + parent = NULL; + p = &priv->root.rb_node; + while (*p != NULL) { + parent = *p; + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + d = nft_rbtree_cmp(set, rbe, new); + + if (d < 0) + p = &parent->rb_left; + else if (d > 0) + p = &parent->rb_right; + else if (nft_rbtree_interval_end(rbe)) + p = &parent->rb_left; + else + p = &parent->rb_right; + } + + rb_link_node_rcu(&new->node, parent, p); + rb_insert_color(&new->node, &priv->root); + return 0; +} + +static int nft_rbtree_insert(const struct net *net, const struct nft_set *set, + const struct nft_set_elem *elem, + struct nft_set_ext **ext) +{ + struct nft_rbtree *priv = nft_set_priv(set); + struct nft_rbtree_elem *rbe = elem->priv; + int err; + + do { + if (fatal_signal_pending(current)) + return -EINTR; + + cond_resched(); + + write_lock_bh(&priv->lock); + write_seqcount_begin(&priv->count); + err = __nft_rbtree_insert(net, set, rbe, ext); + write_seqcount_end(&priv->count); + write_unlock_bh(&priv->lock); + } while (err == -EAGAIN); + + return err; +} + +static void nft_rbtree_remove(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_rbtree *priv = nft_set_priv(set); + struct nft_rbtree_elem *rbe = elem->priv; + + write_lock_bh(&priv->lock); + write_seqcount_begin(&priv->count); + rb_erase(&rbe->node, &priv->root); + write_seqcount_end(&priv->count); + write_unlock_bh(&priv->lock); +} + +static void nft_rbtree_activate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + struct nft_rbtree_elem *rbe = elem->priv; + + nft_set_elem_change_active(net, set, &rbe->ext); +} + +static bool nft_rbtree_flush(const struct net *net, + const struct nft_set *set, void *priv) +{ + struct nft_rbtree_elem *rbe = priv; + + nft_set_elem_change_active(net, set, &rbe->ext); + + return true; +} + +static void *nft_rbtree_deactivate(const struct net *net, + const struct nft_set *set, + const struct nft_set_elem *elem) +{ + const struct nft_rbtree *priv = nft_set_priv(set); + const struct rb_node *parent = priv->root.rb_node; + struct nft_rbtree_elem *rbe, *this = elem->priv; + u8 genmask = nft_genmask_next(net); + int d; + + while (parent != NULL) { + rbe = rb_entry(parent, struct nft_rbtree_elem, node); + + d = memcmp(nft_set_ext_key(&rbe->ext), &elem->key.val, + set->klen); + if (d < 0) + parent = parent->rb_left; + else if 
(d > 0) + parent = parent->rb_right; + else { + if (nft_rbtree_interval_end(rbe) && + nft_rbtree_interval_start(this)) { + parent = parent->rb_left; + continue; + } else if (nft_rbtree_interval_start(rbe) && + nft_rbtree_interval_end(this)) { + parent = parent->rb_right; + continue; + } else if (nft_set_elem_expired(&rbe->ext)) { + break; + } else if (!nft_set_elem_active(&rbe->ext, genmask)) { + parent = parent->rb_left; + continue; + } + nft_rbtree_flush(net, set, rbe); + return rbe; + } + } + return NULL; +} + +static void nft_rbtree_walk(const struct nft_ctx *ctx, + struct nft_set *set, + struct nft_set_iter *iter) +{ + struct nft_rbtree *priv = nft_set_priv(set); + struct nft_rbtree_elem *rbe; + struct nft_set_elem elem; + struct rb_node *node; + + read_lock_bh(&priv->lock); + for (node = rb_first(&priv->root); node != NULL; node = rb_next(node)) { + rbe = rb_entry(node, struct nft_rbtree_elem, node); + + if (iter->count < iter->skip) + goto cont; + if (!nft_set_elem_active(&rbe->ext, iter->genmask)) + goto cont; + + elem.priv = rbe; + + iter->err = iter->fn(ctx, set, iter, &elem); + if (iter->err < 0) { + read_unlock_bh(&priv->lock); + return; + } +cont: + iter->count++; + } + read_unlock_bh(&priv->lock); +} + +static void nft_rbtree_gc(struct work_struct *work) +{ + struct nft_rbtree_elem *rbe, *rbe_end = NULL; + struct nftables_pernet *nft_net; + struct nft_rbtree *priv; + struct nft_trans_gc *gc; + struct rb_node *node; + struct nft_set *set; + unsigned int gc_seq; + struct net *net; + + priv = container_of(work, struct nft_rbtree, gc_work.work); + set = nft_set_container_of(priv); + net = read_pnet(&set->net); + nft_net = nft_pernet(net); + gc_seq = READ_ONCE(nft_net->gc_seq); + + if (nft_set_gc_is_pending(set)) + goto done; + + gc = nft_trans_gc_alloc(set, gc_seq, GFP_KERNEL); + if (!gc) + goto done; + + read_lock_bh(&priv->lock); + for (node = rb_first(&priv->root); node != NULL; node = rb_next(node)) { + + /* Ruleset has been updated, try later. */ + if (READ_ONCE(nft_net->gc_seq) != gc_seq) { + nft_trans_gc_destroy(gc); + gc = NULL; + goto try_later; + } + + rbe = rb_entry(node, struct nft_rbtree_elem, node); + + if (nft_set_elem_is_dead(&rbe->ext)) + goto dead_elem; + + /* elements are reversed in the rbtree for historical reasons, + * from highest to lowest value, that is why end element is + * always visited before the start element. 
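+ * The end element is stashed in rbe_end and only queued for collection
+ * together with its expired start element.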
+ */ + if (nft_rbtree_interval_end(rbe)) { + rbe_end = rbe; + continue; + } + if (!nft_set_elem_expired(&rbe->ext)) + continue; + + nft_set_elem_dead(&rbe->ext); + + if (!rbe_end) + continue; + + nft_set_elem_dead(&rbe_end->ext); + + gc = nft_trans_gc_queue_async(gc, gc_seq, GFP_ATOMIC); + if (!gc) + goto try_later; + + nft_trans_gc_elem_add(gc, rbe_end); + rbe_end = NULL; +dead_elem: + gc = nft_trans_gc_queue_async(gc, gc_seq, GFP_ATOMIC); + if (!gc) + goto try_later; + + nft_trans_gc_elem_add(gc, rbe); + } + + gc = nft_trans_gc_catchall_async(gc, gc_seq); + +try_later: + read_unlock_bh(&priv->lock); + + if (gc) + nft_trans_gc_queue_async_done(gc); +done: + queue_delayed_work(system_power_efficient_wq, &priv->gc_work, + nft_set_gc_interval(set)); +} + +static u64 nft_rbtree_privsize(const struct nlattr * const nla[], + const struct nft_set_desc *desc) +{ + return sizeof(struct nft_rbtree); +} + +static int nft_rbtree_init(const struct nft_set *set, + const struct nft_set_desc *desc, + const struct nlattr * const nla[]) +{ + struct nft_rbtree *priv = nft_set_priv(set); + + rwlock_init(&priv->lock); + seqcount_rwlock_init(&priv->count, &priv->lock); + priv->root = RB_ROOT; + + INIT_DEFERRABLE_WORK(&priv->gc_work, nft_rbtree_gc); + if (set->flags & NFT_SET_TIMEOUT) + queue_delayed_work(system_power_efficient_wq, &priv->gc_work, + nft_set_gc_interval(set)); + + return 0; +} + +static void nft_rbtree_destroy(const struct nft_ctx *ctx, + const struct nft_set *set) +{ + struct nft_rbtree *priv = nft_set_priv(set); + struct nft_rbtree_elem *rbe; + struct rb_node *node; + + cancel_delayed_work_sync(&priv->gc_work); + rcu_barrier(); + while ((node = priv->root.rb_node) != NULL) { + rb_erase(node, &priv->root); + rbe = rb_entry(node, struct nft_rbtree_elem, node); + nf_tables_set_elem_destroy(ctx, set, rbe); + } +} + +static bool nft_rbtree_estimate(const struct nft_set_desc *desc, u32 features, + struct nft_set_estimate *est) +{ + if (desc->field_count > 1) + return false; + + if (desc->size) + est->size = sizeof(struct nft_rbtree) + + desc->size * sizeof(struct nft_rbtree_elem); + else + est->size = ~0; + + est->lookup = NFT_SET_CLASS_O_LOG_N; + est->space = NFT_SET_CLASS_O_N; + + return true; +} + +const struct nft_set_type nft_set_rbtree_type = { + .features = NFT_SET_INTERVAL | NFT_SET_MAP | NFT_SET_OBJECT | NFT_SET_TIMEOUT, + .ops = { + .privsize = nft_rbtree_privsize, + .elemsize = offsetof(struct nft_rbtree_elem, ext), + .estimate = nft_rbtree_estimate, + .init = nft_rbtree_init, + .destroy = nft_rbtree_destroy, + .insert = nft_rbtree_insert, + .remove = nft_rbtree_remove, + .deactivate = nft_rbtree_deactivate, + .flush = nft_rbtree_flush, + .activate = nft_rbtree_activate, + .lookup = nft_rbtree_lookup, + .walk = nft_rbtree_walk, + .get = nft_rbtree_get, + }, +}; diff --git a/net/netfilter/nft_socket.c b/net/netfilter/nft_socket.c new file mode 100644 index 000000000..f28324fd8 --- /dev/null +++ b/net/netfilter/nft_socket.c @@ -0,0 +1,291 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include + +struct nft_socket { + enum nft_socket_keys key:8; + u8 level; + u8 len; + union { + u8 dreg; + }; +}; + +static void nft_socket_wildcard(const struct nft_pktinfo *pkt, + struct nft_regs *regs, struct sock *sk, + u32 *dest) +{ + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + nft_reg_store8(dest, inet_sk(sk)->inet_rcv_saddr == 0); + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + nft_reg_store8(dest, 
ipv6_addr_any(&sk->sk_v6_rcv_saddr)); + break; +#endif + default: + regs->verdict.code = NFT_BREAK; + return; + } +} + +#ifdef CONFIG_SOCK_CGROUP_DATA +static noinline bool +nft_sock_get_eval_cgroupv2(u32 *dest, struct sock *sk, const struct nft_pktinfo *pkt, u32 level) +{ + struct cgroup *cgrp; + u64 cgid; + + if (!sk_fullsock(sk)) + return false; + + cgrp = cgroup_ancestor(sock_cgroup_ptr(&sk->sk_cgrp_data), level); + if (!cgrp) + return false; + + cgid = cgroup_id(cgrp); + memcpy(dest, &cgid, sizeof(u64)); + return true; +} +#endif + +static struct sock *nft_socket_do_lookup(const struct nft_pktinfo *pkt) +{ + const struct net_device *indev = nft_in(pkt); + const struct sk_buff *skb = pkt->skb; + struct sock *sk = NULL; + + if (!indev) + return NULL; + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + sk = nf_sk_lookup_slow_v4(nft_net(pkt), skb, indev); + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + sk = nf_sk_lookup_slow_v6(nft_net(pkt), skb, indev); + break; +#endif + default: + WARN_ON_ONCE(1); + break; + } + + return sk; +} + +static void nft_socket_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_socket *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + struct sock *sk = skb->sk; + u32 *dest = ®s->data[priv->dreg]; + + if (sk && !net_eq(nft_net(pkt), sock_net(sk))) + sk = NULL; + + if (!sk) + sk = nft_socket_do_lookup(pkt); + + if (!sk) { + regs->verdict.code = NFT_BREAK; + return; + } + + switch(priv->key) { + case NFT_SOCKET_TRANSPARENT: + nft_reg_store8(dest, inet_sk_transparent(sk)); + break; + case NFT_SOCKET_MARK: + if (sk_fullsock(sk)) { + *dest = READ_ONCE(sk->sk_mark); + } else { + regs->verdict.code = NFT_BREAK; + return; + } + break; + case NFT_SOCKET_WILDCARD: + if (!sk_fullsock(sk)) { + regs->verdict.code = NFT_BREAK; + return; + } + nft_socket_wildcard(pkt, regs, sk, dest); + break; +#ifdef CONFIG_SOCK_CGROUP_DATA + case NFT_SOCKET_CGROUPV2: + if (!nft_sock_get_eval_cgroupv2(dest, sk, pkt, priv->level)) { + regs->verdict.code = NFT_BREAK; + return; + } + break; +#endif + default: + WARN_ON(1); + regs->verdict.code = NFT_BREAK; + } + + if (sk != skb->sk) + sock_gen_put(sk); +} + +static const struct nla_policy nft_socket_policy[NFTA_SOCKET_MAX + 1] = { + [NFTA_SOCKET_KEY] = { .type = NLA_U32 }, + [NFTA_SOCKET_DREG] = { .type = NLA_U32 }, + [NFTA_SOCKET_LEVEL] = { .type = NLA_U32 }, +}; + +static int nft_socket_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_socket *priv = nft_expr_priv(expr); + unsigned int len; + + if (!tb[NFTA_SOCKET_DREG] || !tb[NFTA_SOCKET_KEY]) + return -EINVAL; + + switch(ctx->family) { + case NFPROTO_IPV4: +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: +#endif + case NFPROTO_INET: + break; + default: + return -EOPNOTSUPP; + } + + priv->key = ntohl(nla_get_be32(tb[NFTA_SOCKET_KEY])); + switch(priv->key) { + case NFT_SOCKET_TRANSPARENT: + case NFT_SOCKET_WILDCARD: + len = sizeof(u8); + break; + case NFT_SOCKET_MARK: + len = sizeof(u32); + break; +#ifdef CONFIG_CGROUPS + case NFT_SOCKET_CGROUPV2: { + unsigned int level; + + if (!tb[NFTA_SOCKET_LEVEL]) + return -EINVAL; + + level = ntohl(nla_get_be32(tb[NFTA_SOCKET_LEVEL])); + if (level > 255) + return -EOPNOTSUPP; + + priv->level = level; + len = sizeof(u64); + break; + } +#endif + default: + return -EOPNOTSUPP; + } + + priv->len = len; + return nft_parse_register_store(ctx, tb[NFTA_SOCKET_DREG], &priv->dreg, + NULL, 
NFT_DATA_VALUE, len); +} + +static int nft_socket_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_socket *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_SOCKET_KEY, htonl(priv->key))) + return -1; + if (nft_dump_register(skb, NFTA_SOCKET_DREG, priv->dreg)) + return -1; + if (priv->key == NFT_SOCKET_CGROUPV2 && + nla_put_be32(skb, NFTA_SOCKET_LEVEL, htonl(priv->level))) + return -1; + return 0; +} + +static bool nft_socket_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_socket *priv = nft_expr_priv(expr); + const struct nft_socket *socket; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + socket = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->key != socket->key || + priv->dreg != socket->dreg || + priv->level != socket->level) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} + +static int nft_socket_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + return nft_chain_validate_hooks(ctx->chain, + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT)); +} + +static struct nft_expr_type nft_socket_type; +static const struct nft_expr_ops nft_socket_ops = { + .type = &nft_socket_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_socket)), + .eval = nft_socket_eval, + .init = nft_socket_init, + .dump = nft_socket_dump, + .validate = nft_socket_validate, + .reduce = nft_socket_reduce, +}; + +static struct nft_expr_type nft_socket_type __read_mostly = { + .name = "socket", + .ops = &nft_socket_ops, + .policy = nft_socket_policy, + .maxattr = NFTA_SOCKET_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_socket_module_init(void) +{ + return nft_register_expr(&nft_socket_type); +} + +static void __exit nft_socket_module_exit(void) +{ + nft_unregister_expr(&nft_socket_type); +} + +module_init(nft_socket_module_init); +module_exit(nft_socket_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Máté Eckl"); +MODULE_DESCRIPTION("nf_tables socket match module"); +MODULE_ALIAS_NFT_EXPR("socket"); diff --git a/net/netfilter/nft_synproxy.c b/net/netfilter/nft_synproxy.c new file mode 100644 index 000000000..a450f28a5 --- /dev/null +++ b/net/netfilter/nft_synproxy.c @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_synproxy { + struct nf_synproxy_info info; +}; + +static const struct nla_policy nft_synproxy_policy[NFTA_SYNPROXY_MAX + 1] = { + [NFTA_SYNPROXY_MSS] = { .type = NLA_U16 }, + [NFTA_SYNPROXY_WSCALE] = { .type = NLA_U8 }, + [NFTA_SYNPROXY_FLAGS] = { .type = NLA_U32 }, +}; + +static void nft_synproxy_tcp_options(struct synproxy_options *opts, + const struct tcphdr *tcp, + struct synproxy_net *snet, + struct nf_synproxy_info *info, + const struct nft_synproxy *priv) +{ + this_cpu_inc(snet->stats->syn_received); + if (tcp->ece && tcp->cwr) + opts->options |= NF_SYNPROXY_OPT_ECN; + + opts->options &= priv->info.options; + opts->mss_encode = opts->mss_option; + opts->mss_option = info->mss; + if (opts->options & NF_SYNPROXY_OPT_TIMESTAMP) + 
synproxy_init_timestamp_cookie(info, opts); + else + opts->options &= ~(NF_SYNPROXY_OPT_WSCALE | + NF_SYNPROXY_OPT_SACK_PERM | + NF_SYNPROXY_OPT_ECN); +} + +static void nft_synproxy_eval_v4(const struct nft_synproxy *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt, + const struct tcphdr *tcp, + struct tcphdr *_tcph, + struct synproxy_options *opts) +{ + struct nf_synproxy_info info = priv->info; + struct net *net = nft_net(pkt); + struct synproxy_net *snet = synproxy_pernet(net); + struct sk_buff *skb = pkt->skb; + + if (tcp->syn) { + /* Initial SYN from client */ + nft_synproxy_tcp_options(opts, tcp, snet, &info, priv); + synproxy_send_client_synack(net, skb, tcp, opts); + consume_skb(skb); + regs->verdict.code = NF_STOLEN; + } else if (tcp->ack) { + /* ACK from client */ + if (synproxy_recv_client_ack(net, skb, tcp, opts, + ntohl(tcp->seq))) { + consume_skb(skb); + regs->verdict.code = NF_STOLEN; + } else { + regs->verdict.code = NF_DROP; + } + } +} + +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) +static void nft_synproxy_eval_v6(const struct nft_synproxy *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt, + const struct tcphdr *tcp, + struct tcphdr *_tcph, + struct synproxy_options *opts) +{ + struct nf_synproxy_info info = priv->info; + struct net *net = nft_net(pkt); + struct synproxy_net *snet = synproxy_pernet(net); + struct sk_buff *skb = pkt->skb; + + if (tcp->syn) { + /* Initial SYN from client */ + nft_synproxy_tcp_options(opts, tcp, snet, &info, priv); + synproxy_send_client_synack_ipv6(net, skb, tcp, opts); + consume_skb(skb); + regs->verdict.code = NF_STOLEN; + } else if (tcp->ack) { + /* ACK from client */ + if (synproxy_recv_client_ack_ipv6(net, skb, tcp, opts, + ntohl(tcp->seq))) { + consume_skb(skb); + regs->verdict.code = NF_STOLEN; + } else { + regs->verdict.code = NF_DROP; + } + } +} +#endif /* CONFIG_NF_TABLES_IPV6*/ + +static void nft_synproxy_do_eval(const struct nft_synproxy *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct synproxy_options opts = {}; + struct sk_buff *skb = pkt->skb; + int thoff = nft_thoff(pkt); + const struct tcphdr *tcp; + struct tcphdr _tcph; + + if (pkt->tprot != IPPROTO_TCP) { + regs->verdict.code = NFT_BREAK; + return; + } + + if (nf_ip_checksum(skb, nft_hook(pkt), thoff, IPPROTO_TCP)) { + regs->verdict.code = NF_DROP; + return; + } + + tcp = skb_header_pointer(skb, thoff, + sizeof(struct tcphdr), + &_tcph); + if (!tcp) { + regs->verdict.code = NF_DROP; + return; + } + + if (!synproxy_parse_options(skb, thoff, tcp, &opts)) { + regs->verdict.code = NF_DROP; + return; + } + + switch (skb->protocol) { + case htons(ETH_P_IP): + nft_synproxy_eval_v4(priv, regs, pkt, tcp, &_tcph, &opts); + return; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case htons(ETH_P_IPV6): + nft_synproxy_eval_v6(priv, regs, pkt, tcp, &_tcph, &opts); + return; +#endif + } + regs->verdict.code = NFT_BREAK; +} + +static int nft_synproxy_do_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_synproxy *priv) +{ + struct synproxy_net *snet = synproxy_pernet(ctx->net); + u32 flags; + int err; + + if (tb[NFTA_SYNPROXY_MSS]) + priv->info.mss = ntohs(nla_get_be16(tb[NFTA_SYNPROXY_MSS])); + if (tb[NFTA_SYNPROXY_WSCALE]) + priv->info.wscale = nla_get_u8(tb[NFTA_SYNPROXY_WSCALE]); + if (tb[NFTA_SYNPROXY_FLAGS]) { + flags = ntohl(nla_get_be32(tb[NFTA_SYNPROXY_FLAGS])); + if (flags & ~NF_SYNPROXY_OPT_MASK) + return -EOPNOTSUPP; + priv->info.options = flags; + } + + err = nf_ct_netns_get(ctx->net, ctx->family); + 
if (err) + return err; + + switch (ctx->family) { + case NFPROTO_IPV4: + err = nf_synproxy_ipv4_init(snet, ctx->net); + if (err) + goto nf_ct_failure; + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + err = nf_synproxy_ipv6_init(snet, ctx->net); + if (err) + goto nf_ct_failure; + break; +#endif + case NFPROTO_INET: + err = nf_synproxy_ipv4_init(snet, ctx->net); + if (err) + goto nf_ct_failure; + err = nf_synproxy_ipv6_init(snet, ctx->net); + if (err) { + nf_synproxy_ipv4_fini(snet, ctx->net); + goto nf_ct_failure; + } + break; + } + + return 0; + +nf_ct_failure: + nf_ct_netns_put(ctx->net, ctx->family); + return err; +} + +static void nft_synproxy_do_destroy(const struct nft_ctx *ctx) +{ + struct synproxy_net *snet = synproxy_pernet(ctx->net); + + switch (ctx->family) { + case NFPROTO_IPV4: + nf_synproxy_ipv4_fini(snet, ctx->net); + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + nf_synproxy_ipv6_fini(snet, ctx->net); + break; +#endif + case NFPROTO_INET: + nf_synproxy_ipv4_fini(snet, ctx->net); + nf_synproxy_ipv6_fini(snet, ctx->net); + break; + } + nf_ct_netns_put(ctx->net, ctx->family); +} + +static int nft_synproxy_do_dump(struct sk_buff *skb, struct nft_synproxy *priv) +{ + if (nla_put_be16(skb, NFTA_SYNPROXY_MSS, htons(priv->info.mss)) || + nla_put_u8(skb, NFTA_SYNPROXY_WSCALE, priv->info.wscale) || + nla_put_be32(skb, NFTA_SYNPROXY_FLAGS, htonl(priv->info.options))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_synproxy_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_synproxy *priv = nft_expr_priv(expr); + + nft_synproxy_do_eval(priv, regs, pkt); +} + +static int nft_synproxy_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + return nft_chain_validate_hooks(ctx->chain, (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD)); +} + +static int nft_synproxy_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_synproxy *priv = nft_expr_priv(expr); + + return nft_synproxy_do_init(ctx, tb, priv); +} + +static void nft_synproxy_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + nft_synproxy_do_destroy(ctx); +} + +static int nft_synproxy_dump(struct sk_buff *skb, const struct nft_expr *expr) +{ + struct nft_synproxy *priv = nft_expr_priv(expr); + + return nft_synproxy_do_dump(skb, priv); +} + +static struct nft_expr_type nft_synproxy_type; +static const struct nft_expr_ops nft_synproxy_ops = { + .eval = nft_synproxy_eval, + .size = NFT_EXPR_SIZE(sizeof(struct nft_synproxy)), + .init = nft_synproxy_init, + .destroy = nft_synproxy_destroy, + .dump = nft_synproxy_dump, + .type = &nft_synproxy_type, + .validate = nft_synproxy_validate, + .reduce = NFT_REDUCE_READONLY, +}; + +static struct nft_expr_type nft_synproxy_type __read_mostly = { + .ops = &nft_synproxy_ops, + .name = "synproxy", + .owner = THIS_MODULE, + .policy = nft_synproxy_policy, + .maxattr = NFTA_SYNPROXY_MAX, +}; + +static int nft_synproxy_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_synproxy *priv = nft_obj_data(obj); + + return nft_synproxy_do_init(ctx, tb, priv); +} + +static void nft_synproxy_obj_destroy(const struct nft_ctx *ctx, + struct nft_object 
*obj) +{ + nft_synproxy_do_destroy(ctx); +} + +static int nft_synproxy_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + struct nft_synproxy *priv = nft_obj_data(obj); + + return nft_synproxy_do_dump(skb, priv); +} + +static void nft_synproxy_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_synproxy *priv = nft_obj_data(obj); + + nft_synproxy_do_eval(priv, regs, pkt); +} + +static void nft_synproxy_obj_update(struct nft_object *obj, + struct nft_object *newobj) +{ + struct nft_synproxy *newpriv = nft_obj_data(newobj); + struct nft_synproxy *priv = nft_obj_data(obj); + + priv->info = newpriv->info; +} + +static struct nft_object_type nft_synproxy_obj_type; +static const struct nft_object_ops nft_synproxy_obj_ops = { + .type = &nft_synproxy_obj_type, + .size = sizeof(struct nft_synproxy), + .init = nft_synproxy_obj_init, + .destroy = nft_synproxy_obj_destroy, + .dump = nft_synproxy_obj_dump, + .eval = nft_synproxy_obj_eval, + .update = nft_synproxy_obj_update, +}; + +static struct nft_object_type nft_synproxy_obj_type __read_mostly = { + .type = NFT_OBJECT_SYNPROXY, + .ops = &nft_synproxy_obj_ops, + .maxattr = NFTA_SYNPROXY_MAX, + .policy = nft_synproxy_policy, + .owner = THIS_MODULE, +}; + +static int __init nft_synproxy_module_init(void) +{ + int err; + + err = nft_register_obj(&nft_synproxy_obj_type); + if (err < 0) + return err; + + err = nft_register_expr(&nft_synproxy_type); + if (err < 0) + goto err; + + return 0; + +err: + nft_unregister_obj(&nft_synproxy_obj_type); + return err; +} + +static void __exit nft_synproxy_module_exit(void) +{ + nft_unregister_expr(&nft_synproxy_type); + nft_unregister_obj(&nft_synproxy_obj_type); +} + +module_init(nft_synproxy_module_init); +module_exit(nft_synproxy_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Fernando Fernandez "); +MODULE_ALIAS_NFT_EXPR("synproxy"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_SYNPROXY); +MODULE_DESCRIPTION("nftables SYNPROXY expression support"); diff --git a/net/netfilter/nft_tproxy.c b/net/netfilter/nft_tproxy.c new file mode 100644 index 000000000..adb50c395 --- /dev/null +++ b/net/netfilter/nft_tproxy.c @@ -0,0 +1,363 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) +#include +#endif + +struct nft_tproxy { + u8 sreg_addr; + u8 sreg_port; + u8 family; +}; + +static void nft_tproxy_eval_v4(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_tproxy *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + const struct iphdr *iph = ip_hdr(skb); + struct udphdr _hdr, *hp; + __be32 taddr = 0; + __be16 tport = 0; + struct sock *sk; + + if (pkt->tprot != IPPROTO_TCP && + pkt->tprot != IPPROTO_UDP) { + regs->verdict.code = NFT_BREAK; + return; + } + + hp = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(_hdr), &_hdr); + if (!hp) { + regs->verdict.code = NFT_BREAK; + return; + } + + /* check if there's an ongoing connection on the packet addresses, this + * happens if the redirect already happened and the current packet + * belongs to an already established connection + */ + sk = nf_tproxy_get_sock_v4(nft_net(pkt), skb, iph->protocol, + iph->saddr, iph->daddr, + hp->source, hp->dest, + skb->dev, NF_TPROXY_LOOKUP_ESTABLISHED); + + if (priv->sreg_addr) + taddr = nft_reg_load_be32(®s->data[priv->sreg_addr]); + taddr = nf_tproxy_laddr4(skb, taddr, iph->daddr); + 
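+ /* A non-zero port taken from the source register wins; otherwise fall
+ * back to the original destination port.
+ */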
+ if (priv->sreg_port) + tport = nft_reg_load_be16(®s->data[priv->sreg_port]); + if (!tport) + tport = hp->dest; + + /* UDP has no TCP_TIME_WAIT state, so we never enter here */ + if (sk && sk->sk_state == TCP_TIME_WAIT) { + /* reopening a TIME_WAIT connection needs special handling */ + sk = nf_tproxy_handle_time_wait4(nft_net(pkt), skb, taddr, tport, sk); + } else if (!sk) { + /* no, there's no established connection, check if + * there's a listener on the redirected addr/port + */ + sk = nf_tproxy_get_sock_v4(nft_net(pkt), skb, iph->protocol, + iph->saddr, taddr, + hp->source, tport, + skb->dev, NF_TPROXY_LOOKUP_LISTENER); + } + + if (sk && nf_tproxy_sk_is_transparent(sk)) + nf_tproxy_assign_sock(skb, sk); + else + regs->verdict.code = NFT_BREAK; +} + +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) +static void nft_tproxy_eval_v6(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_tproxy *priv = nft_expr_priv(expr); + struct sk_buff *skb = pkt->skb; + const struct ipv6hdr *iph = ipv6_hdr(skb); + int thoff = nft_thoff(pkt); + struct udphdr _hdr, *hp; + struct in6_addr taddr; + __be16 tport = 0; + struct sock *sk; + int l4proto; + + memset(&taddr, 0, sizeof(taddr)); + + if (pkt->tprot != IPPROTO_TCP && + pkt->tprot != IPPROTO_UDP) { + regs->verdict.code = NFT_BREAK; + return; + } + l4proto = pkt->tprot; + + hp = skb_header_pointer(skb, thoff, sizeof(_hdr), &_hdr); + if (hp == NULL) { + regs->verdict.code = NFT_BREAK; + return; + } + + /* check if there's an ongoing connection on the packet addresses, this + * happens if the redirect already happened and the current packet + * belongs to an already established connection + */ + sk = nf_tproxy_get_sock_v6(nft_net(pkt), skb, thoff, l4proto, + &iph->saddr, &iph->daddr, + hp->source, hp->dest, + nft_in(pkt), NF_TPROXY_LOOKUP_ESTABLISHED); + + if (priv->sreg_addr) + memcpy(&taddr, ®s->data[priv->sreg_addr], sizeof(taddr)); + taddr = *nf_tproxy_laddr6(skb, &taddr, &iph->daddr); + + if (priv->sreg_port) + tport = nft_reg_load_be16(®s->data[priv->sreg_port]); + if (!tport) + tport = hp->dest; + + /* UDP has no TCP_TIME_WAIT state, so we never enter here */ + if (sk && sk->sk_state == TCP_TIME_WAIT) { + /* reopening a TIME_WAIT connection needs special handling */ + sk = nf_tproxy_handle_time_wait6(skb, l4proto, thoff, + nft_net(pkt), + &taddr, + tport, + sk); + } else if (!sk) { + /* no there's no established connection, check if + * there's a listener on the redirected addr/port + */ + sk = nf_tproxy_get_sock_v6(nft_net(pkt), skb, thoff, + l4proto, &iph->saddr, &taddr, + hp->source, tport, + nft_in(pkt), NF_TPROXY_LOOKUP_LISTENER); + } + + /* NOTE: assign_sock consumes our sk reference */ + if (sk && nf_tproxy_sk_is_transparent(sk)) + nf_tproxy_assign_sock(skb, sk); + else + regs->verdict.code = NFT_BREAK; +} +#endif + +static void nft_tproxy_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_tproxy *priv = nft_expr_priv(expr); + + switch (nft_pf(pkt)) { + case NFPROTO_IPV4: + switch (priv->family) { + case NFPROTO_IPV4: + case NFPROTO_UNSPEC: + nft_tproxy_eval_v4(expr, regs, pkt); + return; + } + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + switch (priv->family) { + case NFPROTO_IPV6: + case NFPROTO_UNSPEC: + nft_tproxy_eval_v6(expr, regs, pkt); + return; + } +#endif + } + regs->verdict.code = NFT_BREAK; +} + +static const struct nla_policy nft_tproxy_policy[NFTA_TPROXY_MAX + 1] = { + [NFTA_TPROXY_FAMILY] = { 
.type = NLA_U32 }, + [NFTA_TPROXY_REG_ADDR] = { .type = NLA_U32 }, + [NFTA_TPROXY_REG_PORT] = { .type = NLA_U32 }, +}; + +static int nft_tproxy_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_tproxy *priv = nft_expr_priv(expr); + unsigned int alen = 0; + int err; + + if (!tb[NFTA_TPROXY_FAMILY] || + (!tb[NFTA_TPROXY_REG_ADDR] && !tb[NFTA_TPROXY_REG_PORT])) + return -EINVAL; + + priv->family = ntohl(nla_get_be32(tb[NFTA_TPROXY_FAMILY])); + + switch (ctx->family) { + case NFPROTO_IPV4: + if (priv->family != NFPROTO_IPV4) + return -EINVAL; + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + if (priv->family != NFPROTO_IPV6) + return -EINVAL; + break; +#endif + case NFPROTO_INET: + break; + default: + return -EOPNOTSUPP; + } + + /* Address is specified but the rule family is not set accordingly */ + if (priv->family == NFPROTO_UNSPEC && tb[NFTA_TPROXY_REG_ADDR]) + return -EINVAL; + + switch (priv->family) { + case NFPROTO_IPV4: + alen = sizeof_field(union nf_inet_addr, in); + err = nf_defrag_ipv4_enable(ctx->net); + if (err) + return err; + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + alen = sizeof_field(union nf_inet_addr, in6); + err = nf_defrag_ipv6_enable(ctx->net); + if (err) + return err; + break; +#endif + case NFPROTO_UNSPEC: + /* No address is specified here */ + err = nf_defrag_ipv4_enable(ctx->net); + if (err) + return err; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + err = nf_defrag_ipv6_enable(ctx->net); + if (err) + return err; +#endif + break; + default: + return -EOPNOTSUPP; + } + + if (tb[NFTA_TPROXY_REG_ADDR]) { + err = nft_parse_register_load(tb[NFTA_TPROXY_REG_ADDR], + &priv->sreg_addr, alen); + if (err < 0) + return err; + } + + if (tb[NFTA_TPROXY_REG_PORT]) { + err = nft_parse_register_load(tb[NFTA_TPROXY_REG_PORT], + &priv->sreg_port, sizeof(u16)); + if (err < 0) + return err; + } + + return 0; +} + +static void nft_tproxy_destroy(const struct nft_ctx *ctx, + const struct nft_expr *expr) +{ + const struct nft_tproxy *priv = nft_expr_priv(expr); + + switch (priv->family) { + case NFPROTO_IPV4: + nf_defrag_ipv4_disable(ctx->net); + break; +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + case NFPROTO_IPV6: + nf_defrag_ipv6_disable(ctx->net); + break; +#endif + case NFPROTO_UNSPEC: + nf_defrag_ipv4_disable(ctx->net); +#if IS_ENABLED(CONFIG_NF_TABLES_IPV6) + nf_defrag_ipv6_disable(ctx->net); +#endif + break; + } +} + +static int nft_tproxy_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_tproxy *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_TPROXY_FAMILY, htonl(priv->family))) + return -1; + + if (priv->sreg_addr && + nft_dump_register(skb, NFTA_TPROXY_REG_ADDR, priv->sreg_addr)) + return -1; + + if (priv->sreg_port && + nft_dump_register(skb, NFTA_TPROXY_REG_PORT, priv->sreg_port)) + return -1; + + return 0; +} + +static int nft_tproxy_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + return nft_chain_validate_hooks(ctx->chain, 1 << NF_INET_PRE_ROUTING); +} + +static struct nft_expr_type nft_tproxy_type; +static const struct nft_expr_ops nft_tproxy_ops = { + .type = &nft_tproxy_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_tproxy)), + .eval = nft_tproxy_eval, + .init = nft_tproxy_init, + .destroy = nft_tproxy_destroy, + .dump = nft_tproxy_dump, + .reduce = 
NFT_REDUCE_READONLY, + .validate = nft_tproxy_validate, +}; + +static struct nft_expr_type nft_tproxy_type __read_mostly = { + .name = "tproxy", + .ops = &nft_tproxy_ops, + .policy = nft_tproxy_policy, + .maxattr = NFTA_TPROXY_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_tproxy_module_init(void) +{ + return nft_register_expr(&nft_tproxy_type); +} + +static void __exit nft_tproxy_module_exit(void) +{ + nft_unregister_expr(&nft_tproxy_type); +} + +module_init(nft_tproxy_module_init); +module_exit(nft_tproxy_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Máté Eckl"); +MODULE_DESCRIPTION("nf_tables tproxy support module"); +MODULE_ALIAS_NFT_EXPR("tproxy"); diff --git a/net/netfilter/nft_tunnel.c b/net/netfilter/nft_tunnel.c new file mode 100644 index 000000000..983ade4be --- /dev/null +++ b/net/netfilter/nft_tunnel.c @@ -0,0 +1,750 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct nft_tunnel { + enum nft_tunnel_keys key:8; + u8 dreg; + enum nft_tunnel_mode mode:8; + u8 len; +}; + +static void nft_tunnel_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_tunnel *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; + struct ip_tunnel_info *tun_info; + + tun_info = skb_tunnel_info(pkt->skb); + + switch (priv->key) { + case NFT_TUNNEL_PATH: + if (!tun_info) { + nft_reg_store8(dest, false); + return; + } + if (priv->mode == NFT_TUNNEL_MODE_NONE || + (priv->mode == NFT_TUNNEL_MODE_RX && + !(tun_info->mode & IP_TUNNEL_INFO_TX)) || + (priv->mode == NFT_TUNNEL_MODE_TX && + (tun_info->mode & IP_TUNNEL_INFO_TX))) + nft_reg_store8(dest, true); + else + nft_reg_store8(dest, false); + break; + case NFT_TUNNEL_ID: + if (!tun_info) { + regs->verdict.code = NFT_BREAK; + return; + } + if (priv->mode == NFT_TUNNEL_MODE_NONE || + (priv->mode == NFT_TUNNEL_MODE_RX && + !(tun_info->mode & IP_TUNNEL_INFO_TX)) || + (priv->mode == NFT_TUNNEL_MODE_TX && + (tun_info->mode & IP_TUNNEL_INFO_TX))) + *dest = ntohl(tunnel_id_to_key32(tun_info->key.tun_id)); + else + regs->verdict.code = NFT_BREAK; + break; + default: + WARN_ON(1); + regs->verdict.code = NFT_BREAK; + } +} + +static const struct nla_policy nft_tunnel_policy[NFTA_TUNNEL_MAX + 1] = { + [NFTA_TUNNEL_KEY] = { .type = NLA_U32 }, + [NFTA_TUNNEL_DREG] = { .type = NLA_U32 }, + [NFTA_TUNNEL_MODE] = { .type = NLA_U32 }, +}; + +static int nft_tunnel_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_tunnel *priv = nft_expr_priv(expr); + u32 len; + + if (!tb[NFTA_TUNNEL_KEY] || + !tb[NFTA_TUNNEL_DREG]) + return -EINVAL; + + priv->key = ntohl(nla_get_be32(tb[NFTA_TUNNEL_KEY])); + switch (priv->key) { + case NFT_TUNNEL_PATH: + len = sizeof(u8); + break; + case NFT_TUNNEL_ID: + len = sizeof(u32); + break; + default: + return -EOPNOTSUPP; + } + + if (tb[NFTA_TUNNEL_MODE]) { + priv->mode = ntohl(nla_get_be32(tb[NFTA_TUNNEL_MODE])); + if (priv->mode > NFT_TUNNEL_MODE_MAX) + return -EOPNOTSUPP; + } else { + priv->mode = NFT_TUNNEL_MODE_NONE; + } + + priv->len = len; + return nft_parse_register_store(ctx, tb[NFTA_TUNNEL_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} + +static int nft_tunnel_get_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_tunnel *priv = nft_expr_priv(expr); + + if (nla_put_be32(skb, NFTA_TUNNEL_KEY, htonl(priv->key))) + goto 
nla_put_failure; + if (nft_dump_register(skb, NFTA_TUNNEL_DREG, priv->dreg)) + goto nla_put_failure; + if (nla_put_be32(skb, NFTA_TUNNEL_MODE, htonl(priv->mode))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static bool nft_tunnel_get_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_tunnel *priv = nft_expr_priv(expr); + const struct nft_tunnel *tunnel; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + tunnel = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->key != tunnel->key || + priv->dreg != tunnel->dreg || + priv->mode != tunnel->mode) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return false; +} + +static struct nft_expr_type nft_tunnel_type; +static const struct nft_expr_ops nft_tunnel_get_ops = { + .type = &nft_tunnel_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_tunnel)), + .eval = nft_tunnel_get_eval, + .init = nft_tunnel_get_init, + .dump = nft_tunnel_get_dump, + .reduce = nft_tunnel_get_reduce, +}; + +static struct nft_expr_type nft_tunnel_type __read_mostly = { + .name = "tunnel", + .family = NFPROTO_NETDEV, + .ops = &nft_tunnel_get_ops, + .policy = nft_tunnel_policy, + .maxattr = NFTA_TUNNEL_MAX, + .owner = THIS_MODULE, +}; + +struct nft_tunnel_opts { + union { + struct vxlan_metadata vxlan; + struct erspan_metadata erspan; + u8 data[IP_TUNNEL_OPTS_MAX]; + } u; + u32 len; + __be16 flags; +}; + +struct nft_tunnel_obj { + struct metadata_dst *md; + struct nft_tunnel_opts opts; +}; + +static const struct nla_policy nft_tunnel_ip_policy[NFTA_TUNNEL_KEY_IP_MAX + 1] = { + [NFTA_TUNNEL_KEY_IP_SRC] = { .type = NLA_U32 }, + [NFTA_TUNNEL_KEY_IP_DST] = { .type = NLA_U32 }, +}; + +static int nft_tunnel_obj_ip_init(const struct nft_ctx *ctx, + const struct nlattr *attr, + struct ip_tunnel_info *info) +{ + struct nlattr *tb[NFTA_TUNNEL_KEY_IP_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_TUNNEL_KEY_IP_MAX, attr, + nft_tunnel_ip_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_TUNNEL_KEY_IP_DST]) + return -EINVAL; + + if (tb[NFTA_TUNNEL_KEY_IP_SRC]) + info->key.u.ipv4.src = nla_get_be32(tb[NFTA_TUNNEL_KEY_IP_SRC]); + if (tb[NFTA_TUNNEL_KEY_IP_DST]) + info->key.u.ipv4.dst = nla_get_be32(tb[NFTA_TUNNEL_KEY_IP_DST]); + + return 0; +} + +static const struct nla_policy nft_tunnel_ip6_policy[NFTA_TUNNEL_KEY_IP6_MAX + 1] = { + [NFTA_TUNNEL_KEY_IP6_SRC] = { .len = sizeof(struct in6_addr), }, + [NFTA_TUNNEL_KEY_IP6_DST] = { .len = sizeof(struct in6_addr), }, + [NFTA_TUNNEL_KEY_IP6_FLOWLABEL] = { .type = NLA_U32, } +}; + +static int nft_tunnel_obj_ip6_init(const struct nft_ctx *ctx, + const struct nlattr *attr, + struct ip_tunnel_info *info) +{ + struct nlattr *tb[NFTA_TUNNEL_KEY_IP6_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_TUNNEL_KEY_IP6_MAX, attr, + nft_tunnel_ip6_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_TUNNEL_KEY_IP6_DST]) + return -EINVAL; + + if (tb[NFTA_TUNNEL_KEY_IP6_SRC]) { + memcpy(&info->key.u.ipv6.src, + nla_data(tb[NFTA_TUNNEL_KEY_IP6_SRC]), + sizeof(struct in6_addr)); + } + if (tb[NFTA_TUNNEL_KEY_IP6_DST]) { + memcpy(&info->key.u.ipv6.dst, + nla_data(tb[NFTA_TUNNEL_KEY_IP6_DST]), + sizeof(struct in6_addr)); + } + if (tb[NFTA_TUNNEL_KEY_IP6_FLOWLABEL]) + info->key.label = nla_get_be32(tb[NFTA_TUNNEL_KEY_IP6_FLOWLABEL]); + + info->mode |= 
IP_TUNNEL_INFO_IPV6; + + return 0; +} + +static const struct nla_policy nft_tunnel_opts_vxlan_policy[NFTA_TUNNEL_KEY_VXLAN_MAX + 1] = { + [NFTA_TUNNEL_KEY_VXLAN_GBP] = { .type = NLA_U32 }, +}; + +static int nft_tunnel_obj_vxlan_init(const struct nlattr *attr, + struct nft_tunnel_opts *opts) +{ + struct nlattr *tb[NFTA_TUNNEL_KEY_VXLAN_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, NFTA_TUNNEL_KEY_VXLAN_MAX, attr, + nft_tunnel_opts_vxlan_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_TUNNEL_KEY_VXLAN_GBP]) + return -EINVAL; + + opts->u.vxlan.gbp = ntohl(nla_get_be32(tb[NFTA_TUNNEL_KEY_VXLAN_GBP])); + + opts->len = sizeof(struct vxlan_metadata); + opts->flags = TUNNEL_VXLAN_OPT; + + return 0; +} + +static const struct nla_policy nft_tunnel_opts_erspan_policy[NFTA_TUNNEL_KEY_ERSPAN_MAX + 1] = { + [NFTA_TUNNEL_KEY_ERSPAN_VERSION] = { .type = NLA_U32 }, + [NFTA_TUNNEL_KEY_ERSPAN_V1_INDEX] = { .type = NLA_U32 }, + [NFTA_TUNNEL_KEY_ERSPAN_V2_DIR] = { .type = NLA_U8 }, + [NFTA_TUNNEL_KEY_ERSPAN_V2_HWID] = { .type = NLA_U8 }, +}; + +static int nft_tunnel_obj_erspan_init(const struct nlattr *attr, + struct nft_tunnel_opts *opts) +{ + struct nlattr *tb[NFTA_TUNNEL_KEY_ERSPAN_MAX + 1]; + uint8_t hwid, dir; + int err, version; + + err = nla_parse_nested_deprecated(tb, NFTA_TUNNEL_KEY_ERSPAN_MAX, + attr, nft_tunnel_opts_erspan_policy, + NULL); + if (err < 0) + return err; + + if (!tb[NFTA_TUNNEL_KEY_ERSPAN_VERSION]) + return -EINVAL; + + version = ntohl(nla_get_be32(tb[NFTA_TUNNEL_KEY_ERSPAN_VERSION])); + switch (version) { + case ERSPAN_VERSION: + if (!tb[NFTA_TUNNEL_KEY_ERSPAN_V1_INDEX]) + return -EINVAL; + + opts->u.erspan.u.index = + nla_get_be32(tb[NFTA_TUNNEL_KEY_ERSPAN_V1_INDEX]); + break; + case ERSPAN_VERSION2: + if (!tb[NFTA_TUNNEL_KEY_ERSPAN_V2_DIR] || + !tb[NFTA_TUNNEL_KEY_ERSPAN_V2_HWID]) + return -EINVAL; + + hwid = nla_get_u8(tb[NFTA_TUNNEL_KEY_ERSPAN_V2_HWID]); + dir = nla_get_u8(tb[NFTA_TUNNEL_KEY_ERSPAN_V2_DIR]); + + set_hwid(&opts->u.erspan.u.md2, hwid); + opts->u.erspan.u.md2.dir = dir; + break; + default: + return -EOPNOTSUPP; + } + opts->u.erspan.version = version; + + opts->len = sizeof(struct erspan_metadata); + opts->flags = TUNNEL_ERSPAN_OPT; + + return 0; +} + +static const struct nla_policy nft_tunnel_opts_geneve_policy[NFTA_TUNNEL_KEY_GENEVE_MAX + 1] = { + [NFTA_TUNNEL_KEY_GENEVE_CLASS] = { .type = NLA_U16 }, + [NFTA_TUNNEL_KEY_GENEVE_TYPE] = { .type = NLA_U8 }, + [NFTA_TUNNEL_KEY_GENEVE_DATA] = { .type = NLA_BINARY, .len = 128 }, +}; + +static int nft_tunnel_obj_geneve_init(const struct nlattr *attr, + struct nft_tunnel_opts *opts) +{ + struct geneve_opt *opt = (struct geneve_opt *)opts->u.data + opts->len; + struct nlattr *tb[NFTA_TUNNEL_KEY_GENEVE_MAX + 1]; + int err, data_len; + + err = nla_parse_nested(tb, NFTA_TUNNEL_KEY_GENEVE_MAX, attr, + nft_tunnel_opts_geneve_policy, NULL); + if (err < 0) + return err; + + if (!tb[NFTA_TUNNEL_KEY_GENEVE_CLASS] || + !tb[NFTA_TUNNEL_KEY_GENEVE_TYPE] || + !tb[NFTA_TUNNEL_KEY_GENEVE_DATA]) + return -EINVAL; + + attr = tb[NFTA_TUNNEL_KEY_GENEVE_DATA]; + data_len = nla_len(attr); + if (data_len % 4) + return -EINVAL; + + opts->len += sizeof(*opt) + data_len; + if (opts->len > IP_TUNNEL_OPTS_MAX) + return -EINVAL; + + memcpy(opt->opt_data, nla_data(attr), data_len); + opt->length = data_len / 4; + opt->opt_class = nla_get_be16(tb[NFTA_TUNNEL_KEY_GENEVE_CLASS]); + opt->type = nla_get_u8(tb[NFTA_TUNNEL_KEY_GENEVE_TYPE]); + opts->flags = TUNNEL_GENEVE_OPT; + + return 0; +} + +static const struct nla_policy 
nft_tunnel_opts_policy[NFTA_TUNNEL_KEY_OPTS_MAX + 1] = { + [NFTA_TUNNEL_KEY_OPTS_UNSPEC] = { + .strict_start_type = NFTA_TUNNEL_KEY_OPTS_GENEVE }, + [NFTA_TUNNEL_KEY_OPTS_VXLAN] = { .type = NLA_NESTED, }, + [NFTA_TUNNEL_KEY_OPTS_ERSPAN] = { .type = NLA_NESTED, }, + [NFTA_TUNNEL_KEY_OPTS_GENEVE] = { .type = NLA_NESTED, }, +}; + +static int nft_tunnel_obj_opts_init(const struct nft_ctx *ctx, + const struct nlattr *attr, + struct ip_tunnel_info *info, + struct nft_tunnel_opts *opts) +{ + struct nlattr *nla; + __be16 type = 0; + int err, rem; + + err = nla_validate_nested_deprecated(attr, NFTA_TUNNEL_KEY_OPTS_MAX, + nft_tunnel_opts_policy, NULL); + if (err < 0) + return err; + + nla_for_each_attr(nla, nla_data(attr), nla_len(attr), rem) { + switch (nla_type(nla)) { + case NFTA_TUNNEL_KEY_OPTS_VXLAN: + if (type) + return -EINVAL; + err = nft_tunnel_obj_vxlan_init(nla, opts); + if (err) + return err; + type = TUNNEL_VXLAN_OPT; + break; + case NFTA_TUNNEL_KEY_OPTS_ERSPAN: + if (type) + return -EINVAL; + err = nft_tunnel_obj_erspan_init(nla, opts); + if (err) + return err; + type = TUNNEL_ERSPAN_OPT; + break; + case NFTA_TUNNEL_KEY_OPTS_GENEVE: + if (type && type != TUNNEL_GENEVE_OPT) + return -EINVAL; + err = nft_tunnel_obj_geneve_init(nla, opts); + if (err) + return err; + type = TUNNEL_GENEVE_OPT; + break; + default: + return -EOPNOTSUPP; + } + } + + return err; +} + +static const struct nla_policy nft_tunnel_key_policy[NFTA_TUNNEL_KEY_MAX + 1] = { + [NFTA_TUNNEL_KEY_IP] = { .type = NLA_NESTED, }, + [NFTA_TUNNEL_KEY_IP6] = { .type = NLA_NESTED, }, + [NFTA_TUNNEL_KEY_ID] = { .type = NLA_U32, }, + [NFTA_TUNNEL_KEY_FLAGS] = { .type = NLA_U32, }, + [NFTA_TUNNEL_KEY_TOS] = { .type = NLA_U8, }, + [NFTA_TUNNEL_KEY_TTL] = { .type = NLA_U8, }, + [NFTA_TUNNEL_KEY_SPORT] = { .type = NLA_U16, }, + [NFTA_TUNNEL_KEY_DPORT] = { .type = NLA_U16, }, + [NFTA_TUNNEL_KEY_OPTS] = { .type = NLA_NESTED, }, +}; + +static int nft_tunnel_obj_init(const struct nft_ctx *ctx, + const struct nlattr * const tb[], + struct nft_object *obj) +{ + struct nft_tunnel_obj *priv = nft_obj_data(obj); + struct ip_tunnel_info info; + struct metadata_dst *md; + int err; + + if (!tb[NFTA_TUNNEL_KEY_ID]) + return -EINVAL; + + memset(&info, 0, sizeof(info)); + info.mode = IP_TUNNEL_INFO_TX; + info.key.tun_id = key32_to_tunnel_id(nla_get_be32(tb[NFTA_TUNNEL_KEY_ID])); + info.key.tun_flags = TUNNEL_KEY | TUNNEL_CSUM | TUNNEL_NOCACHE; + + if (tb[NFTA_TUNNEL_KEY_IP]) { + err = nft_tunnel_obj_ip_init(ctx, tb[NFTA_TUNNEL_KEY_IP], &info); + if (err < 0) + return err; + } else if (tb[NFTA_TUNNEL_KEY_IP6]) { + err = nft_tunnel_obj_ip6_init(ctx, tb[NFTA_TUNNEL_KEY_IP6], &info); + if (err < 0) + return err; + } else { + return -EINVAL; + } + + if (tb[NFTA_TUNNEL_KEY_SPORT]) { + info.key.tp_src = nla_get_be16(tb[NFTA_TUNNEL_KEY_SPORT]); + } + if (tb[NFTA_TUNNEL_KEY_DPORT]) { + info.key.tp_dst = nla_get_be16(tb[NFTA_TUNNEL_KEY_DPORT]); + } + + if (tb[NFTA_TUNNEL_KEY_FLAGS]) { + u32 tun_flags; + + tun_flags = ntohl(nla_get_be32(tb[NFTA_TUNNEL_KEY_FLAGS])); + if (tun_flags & ~NFT_TUNNEL_F_MASK) + return -EOPNOTSUPP; + + if (tun_flags & NFT_TUNNEL_F_ZERO_CSUM_TX) + info.key.tun_flags &= ~TUNNEL_CSUM; + if (tun_flags & NFT_TUNNEL_F_DONT_FRAGMENT) + info.key.tun_flags |= TUNNEL_DONT_FRAGMENT; + if (tun_flags & NFT_TUNNEL_F_SEQ_NUMBER) + info.key.tun_flags |= TUNNEL_SEQ; + } + if (tb[NFTA_TUNNEL_KEY_TOS]) + info.key.tos = nla_get_u8(tb[NFTA_TUNNEL_KEY_TOS]); + if (tb[NFTA_TUNNEL_KEY_TTL]) + info.key.ttl = nla_get_u8(tb[NFTA_TUNNEL_KEY_TTL]); + else + 
info.key.ttl = U8_MAX; + + if (tb[NFTA_TUNNEL_KEY_OPTS]) { + err = nft_tunnel_obj_opts_init(ctx, tb[NFTA_TUNNEL_KEY_OPTS], + &info, &priv->opts); + if (err < 0) + return err; + } + + md = metadata_dst_alloc(priv->opts.len, METADATA_IP_TUNNEL, GFP_KERNEL); + if (!md) + return -ENOMEM; + + memcpy(&md->u.tun_info, &info, sizeof(info)); +#ifdef CONFIG_DST_CACHE + err = dst_cache_init(&md->u.tun_info.dst_cache, GFP_KERNEL); + if (err < 0) { + metadata_dst_free(md); + return err; + } +#endif + ip_tunnel_info_opts_set(&md->u.tun_info, &priv->opts.u, priv->opts.len, + priv->opts.flags); + priv->md = md; + + return 0; +} + +static inline void nft_tunnel_obj_eval(struct nft_object *obj, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + struct nft_tunnel_obj *priv = nft_obj_data(obj); + struct sk_buff *skb = pkt->skb; + + skb_dst_drop(skb); + dst_hold((struct dst_entry *) priv->md); + skb_dst_set(skb, (struct dst_entry *) priv->md); +} + +static int nft_tunnel_ip_dump(struct sk_buff *skb, struct ip_tunnel_info *info) +{ + struct nlattr *nest; + + if (info->mode & IP_TUNNEL_INFO_IPV6) { + nest = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_IP6); + if (!nest) + return -1; + + if (nla_put_in6_addr(skb, NFTA_TUNNEL_KEY_IP6_SRC, + &info->key.u.ipv6.src) < 0 || + nla_put_in6_addr(skb, NFTA_TUNNEL_KEY_IP6_DST, + &info->key.u.ipv6.dst) < 0 || + nla_put_be32(skb, NFTA_TUNNEL_KEY_IP6_FLOWLABEL, + info->key.label)) { + nla_nest_cancel(skb, nest); + return -1; + } + + nla_nest_end(skb, nest); + } else { + nest = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_IP); + if (!nest) + return -1; + + if (nla_put_in_addr(skb, NFTA_TUNNEL_KEY_IP_SRC, + info->key.u.ipv4.src) < 0 || + nla_put_in_addr(skb, NFTA_TUNNEL_KEY_IP_DST, + info->key.u.ipv4.dst) < 0) { + nla_nest_cancel(skb, nest); + return -1; + } + + nla_nest_end(skb, nest); + } + + return 0; +} + +static int nft_tunnel_opts_dump(struct sk_buff *skb, + struct nft_tunnel_obj *priv) +{ + struct nft_tunnel_opts *opts = &priv->opts; + struct nlattr *nest, *inner; + + nest = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_OPTS); + if (!nest) + return -1; + + if (opts->flags & TUNNEL_VXLAN_OPT) { + inner = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_OPTS_VXLAN); + if (!inner) + goto failure; + if (nla_put_be32(skb, NFTA_TUNNEL_KEY_VXLAN_GBP, + htonl(opts->u.vxlan.gbp))) + goto inner_failure; + nla_nest_end(skb, inner); + } else if (opts->flags & TUNNEL_ERSPAN_OPT) { + inner = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_OPTS_ERSPAN); + if (!inner) + goto failure; + if (nla_put_be32(skb, NFTA_TUNNEL_KEY_ERSPAN_VERSION, + htonl(opts->u.erspan.version))) + goto inner_failure; + switch (opts->u.erspan.version) { + case ERSPAN_VERSION: + if (nla_put_be32(skb, NFTA_TUNNEL_KEY_ERSPAN_V1_INDEX, + opts->u.erspan.u.index)) + goto inner_failure; + break; + case ERSPAN_VERSION2: + if (nla_put_u8(skb, NFTA_TUNNEL_KEY_ERSPAN_V2_HWID, + get_hwid(&opts->u.erspan.u.md2)) || + nla_put_u8(skb, NFTA_TUNNEL_KEY_ERSPAN_V2_DIR, + opts->u.erspan.u.md2.dir)) + goto inner_failure; + break; + } + nla_nest_end(skb, inner); + } else if (opts->flags & TUNNEL_GENEVE_OPT) { + struct geneve_opt *opt; + int offset = 0; + + inner = nla_nest_start_noflag(skb, NFTA_TUNNEL_KEY_OPTS_GENEVE); + if (!inner) + goto failure; + while (opts->len > offset) { + opt = (struct geneve_opt *)opts->u.data + offset; + if (nla_put_be16(skb, NFTA_TUNNEL_KEY_GENEVE_CLASS, + opt->opt_class) || + nla_put_u8(skb, NFTA_TUNNEL_KEY_GENEVE_TYPE, + opt->type) || + nla_put(skb, NFTA_TUNNEL_KEY_GENEVE_DATA, + opt->length * 4, 
opt->opt_data)) + goto inner_failure; + offset += sizeof(*opt) + opt->length * 4; + } + nla_nest_end(skb, inner); + } + nla_nest_end(skb, nest); + return 0; + +inner_failure: + nla_nest_cancel(skb, inner); +failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int nft_tunnel_ports_dump(struct sk_buff *skb, + struct ip_tunnel_info *info) +{ + if (nla_put_be16(skb, NFTA_TUNNEL_KEY_SPORT, info->key.tp_src) < 0 || + nla_put_be16(skb, NFTA_TUNNEL_KEY_DPORT, info->key.tp_dst) < 0) + return -1; + + return 0; +} + +static int nft_tunnel_flags_dump(struct sk_buff *skb, + struct ip_tunnel_info *info) +{ + u32 flags = 0; + + if (info->key.tun_flags & TUNNEL_DONT_FRAGMENT) + flags |= NFT_TUNNEL_F_DONT_FRAGMENT; + if (!(info->key.tun_flags & TUNNEL_CSUM)) + flags |= NFT_TUNNEL_F_ZERO_CSUM_TX; + if (info->key.tun_flags & TUNNEL_SEQ) + flags |= NFT_TUNNEL_F_SEQ_NUMBER; + + if (nla_put_be32(skb, NFTA_TUNNEL_KEY_FLAGS, htonl(flags)) < 0) + return -1; + + return 0; +} + +static int nft_tunnel_obj_dump(struct sk_buff *skb, + struct nft_object *obj, bool reset) +{ + struct nft_tunnel_obj *priv = nft_obj_data(obj); + struct ip_tunnel_info *info = &priv->md->u.tun_info; + + if (nla_put_be32(skb, NFTA_TUNNEL_KEY_ID, + tunnel_id_to_key32(info->key.tun_id)) || + nft_tunnel_ip_dump(skb, info) < 0 || + nft_tunnel_ports_dump(skb, info) < 0 || + nft_tunnel_flags_dump(skb, info) < 0 || + nla_put_u8(skb, NFTA_TUNNEL_KEY_TOS, info->key.tos) || + nla_put_u8(skb, NFTA_TUNNEL_KEY_TTL, info->key.ttl) || + nft_tunnel_opts_dump(skb, priv) < 0) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static void nft_tunnel_obj_destroy(const struct nft_ctx *ctx, + struct nft_object *obj) +{ + struct nft_tunnel_obj *priv = nft_obj_data(obj); + + metadata_dst_free(priv->md); +} + +static struct nft_object_type nft_tunnel_obj_type; +static const struct nft_object_ops nft_tunnel_obj_ops = { + .type = &nft_tunnel_obj_type, + .size = sizeof(struct nft_tunnel_obj), + .eval = nft_tunnel_obj_eval, + .init = nft_tunnel_obj_init, + .destroy = nft_tunnel_obj_destroy, + .dump = nft_tunnel_obj_dump, +}; + +static struct nft_object_type nft_tunnel_obj_type __read_mostly = { + .type = NFT_OBJECT_TUNNEL, + .ops = &nft_tunnel_obj_ops, + .maxattr = NFTA_TUNNEL_KEY_MAX, + .policy = nft_tunnel_key_policy, + .owner = THIS_MODULE, +}; + +static int __init nft_tunnel_module_init(void) +{ + int err; + + err = nft_register_expr(&nft_tunnel_type); + if (err < 0) + return err; + + err = nft_register_obj(&nft_tunnel_obj_type); + if (err < 0) + nft_unregister_expr(&nft_tunnel_type); + + return err; +} + +static void __exit nft_tunnel_module_exit(void) +{ + nft_unregister_obj(&nft_tunnel_obj_type); + nft_unregister_expr(&nft_tunnel_type); +} + +module_init(nft_tunnel_module_init); +module_exit(nft_tunnel_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_ALIAS_NFT_EXPR("tunnel"); +MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_TUNNEL); +MODULE_DESCRIPTION("nftables tunnel expression support"); diff --git a/net/netfilter/nft_xfrm.c b/net/netfilter/nft_xfrm.c new file mode 100644 index 000000000..30259846c --- /dev/null +++ b/net/netfilter/nft_xfrm.c @@ -0,0 +1,324 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * + * Generic part shared by ipv4 and ipv6 backends. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static const struct nla_policy nft_xfrm_policy[NFTA_XFRM_MAX + 1] = { + [NFTA_XFRM_KEY] = { .type = NLA_U32 }, + [NFTA_XFRM_DIR] = { .type = NLA_U8 }, + [NFTA_XFRM_SPNUM] = { .type = NLA_U32 }, + [NFTA_XFRM_DREG] = { .type = NLA_U32 }, +}; + +struct nft_xfrm { + enum nft_xfrm_keys key:8; + u8 dreg; + u8 dir; + u8 spnum; + u8 len; +}; + +static int nft_xfrm_get_init(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nlattr * const tb[]) +{ + struct nft_xfrm *priv = nft_expr_priv(expr); + unsigned int len = 0; + u32 spnum = 0; + u8 dir; + + if (!tb[NFTA_XFRM_KEY] || !tb[NFTA_XFRM_DIR] || !tb[NFTA_XFRM_DREG]) + return -EINVAL; + + switch (ctx->family) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + case NFPROTO_INET: + break; + default: + return -EOPNOTSUPP; + } + + priv->key = ntohl(nla_get_be32(tb[NFTA_XFRM_KEY])); + switch (priv->key) { + case NFT_XFRM_KEY_REQID: + case NFT_XFRM_KEY_SPI: + len = sizeof(u32); + break; + case NFT_XFRM_KEY_DADDR_IP4: + case NFT_XFRM_KEY_SADDR_IP4: + len = sizeof(struct in_addr); + break; + case NFT_XFRM_KEY_DADDR_IP6: + case NFT_XFRM_KEY_SADDR_IP6: + len = sizeof(struct in6_addr); + break; + default: + return -EINVAL; + } + + dir = nla_get_u8(tb[NFTA_XFRM_DIR]); + switch (dir) { + case XFRM_POLICY_IN: + case XFRM_POLICY_OUT: + priv->dir = dir; + break; + default: + return -EINVAL; + } + + if (tb[NFTA_XFRM_SPNUM]) + spnum = ntohl(nla_get_be32(tb[NFTA_XFRM_SPNUM])); + + if (spnum >= XFRM_MAX_DEPTH) + return -ERANGE; + + priv->spnum = spnum; + + priv->len = len; + return nft_parse_register_store(ctx, tb[NFTA_XFRM_DREG], &priv->dreg, + NULL, NFT_DATA_VALUE, len); +} + +/* Return true if key asks for daddr/saddr and current + * state does have a valid address (BEET, TUNNEL). 
+ */ +static bool xfrm_state_addr_ok(enum nft_xfrm_keys k, u8 family, u8 mode) +{ + switch (k) { + case NFT_XFRM_KEY_DADDR_IP4: + case NFT_XFRM_KEY_SADDR_IP4: + if (family == NFPROTO_IPV4) + break; + return false; + case NFT_XFRM_KEY_DADDR_IP6: + case NFT_XFRM_KEY_SADDR_IP6: + if (family == NFPROTO_IPV6) + break; + return false; + default: + return true; + } + + return mode == XFRM_MODE_BEET || mode == XFRM_MODE_TUNNEL; +} + +static void nft_xfrm_state_get_key(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct xfrm_state *state) +{ + u32 *dest = ®s->data[priv->dreg]; + + if (!xfrm_state_addr_ok(priv->key, + state->props.family, + state->props.mode)) { + regs->verdict.code = NFT_BREAK; + return; + } + + switch (priv->key) { + case NFT_XFRM_KEY_UNSPEC: + case __NFT_XFRM_KEY_MAX: + WARN_ON_ONCE(1); + break; + case NFT_XFRM_KEY_DADDR_IP4: + *dest = (__force __u32)state->id.daddr.a4; + return; + case NFT_XFRM_KEY_DADDR_IP6: + memcpy(dest, &state->id.daddr.in6, sizeof(struct in6_addr)); + return; + case NFT_XFRM_KEY_SADDR_IP4: + *dest = (__force __u32)state->props.saddr.a4; + return; + case NFT_XFRM_KEY_SADDR_IP6: + memcpy(dest, &state->props.saddr.in6, sizeof(struct in6_addr)); + return; + case NFT_XFRM_KEY_REQID: + *dest = state->props.reqid; + return; + case NFT_XFRM_KEY_SPI: + *dest = (__force __u32)state->id.spi; + return; + } + + regs->verdict.code = NFT_BREAK; +} + +static void nft_xfrm_get_eval_in(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct sec_path *sp = skb_sec_path(pkt->skb); + const struct xfrm_state *state; + + if (sp == NULL || sp->len <= priv->spnum) { + regs->verdict.code = NFT_BREAK; + return; + } + + state = sp->xvec[priv->spnum]; + nft_xfrm_state_get_key(priv, regs, state); +} + +static void nft_xfrm_get_eval_out(const struct nft_xfrm *priv, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct dst_entry *dst = skb_dst(pkt->skb); + int i; + + for (i = 0; dst && dst->xfrm; + dst = ((const struct xfrm_dst *)dst)->child, i++) { + if (i < priv->spnum) + continue; + + nft_xfrm_state_get_key(priv, regs, dst->xfrm); + return; + } + + regs->verdict.code = NFT_BREAK; +} + +static void nft_xfrm_get_eval(const struct nft_expr *expr, + struct nft_regs *regs, + const struct nft_pktinfo *pkt) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + + switch (priv->dir) { + case XFRM_POLICY_IN: + nft_xfrm_get_eval_in(priv, regs, pkt); + break; + case XFRM_POLICY_OUT: + nft_xfrm_get_eval_out(priv, regs, pkt); + break; + default: + WARN_ON_ONCE(1); + regs->verdict.code = NFT_BREAK; + break; + } +} + +static int nft_xfrm_get_dump(struct sk_buff *skb, + const struct nft_expr *expr) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + + if (nft_dump_register(skb, NFTA_XFRM_DREG, priv->dreg)) + return -1; + + if (nla_put_be32(skb, NFTA_XFRM_KEY, htonl(priv->key))) + return -1; + if (nla_put_u8(skb, NFTA_XFRM_DIR, priv->dir)) + return -1; + if (nla_put_be32(skb, NFTA_XFRM_SPNUM, htonl(priv->spnum))) + return -1; + + return 0; +} + +static int nft_xfrm_validate(const struct nft_ctx *ctx, const struct nft_expr *expr, + const struct nft_data **data) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + unsigned int hooks; + + if (ctx->family != NFPROTO_IPV4 && + ctx->family != NFPROTO_IPV6 && + ctx->family != NFPROTO_INET) + return -EOPNOTSUPP; + + switch (priv->dir) { + case XFRM_POLICY_IN: + hooks = (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_PRE_ROUTING); + break; 
+ case XFRM_POLICY_OUT: + hooks = (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING); + break; + default: + WARN_ON_ONCE(1); + return -EINVAL; + } + + return nft_chain_validate_hooks(ctx->chain, hooks); +} + +static bool nft_xfrm_reduce(struct nft_regs_track *track, + const struct nft_expr *expr) +{ + const struct nft_xfrm *priv = nft_expr_priv(expr); + const struct nft_xfrm *xfrm; + + if (!nft_reg_track_cmp(track, expr, priv->dreg)) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + xfrm = nft_expr_priv(track->regs[priv->dreg].selector); + if (priv->key != xfrm->key || + priv->dreg != xfrm->dreg || + priv->dir != xfrm->dir || + priv->spnum != xfrm->spnum) { + nft_reg_track_update(track, expr, priv->dreg, priv->len); + return false; + } + + if (!track->regs[priv->dreg].bitwise) + return true; + + return nft_expr_reduce_bitwise(track, expr); +} + +static struct nft_expr_type nft_xfrm_type; +static const struct nft_expr_ops nft_xfrm_get_ops = { + .type = &nft_xfrm_type, + .size = NFT_EXPR_SIZE(sizeof(struct nft_xfrm)), + .eval = nft_xfrm_get_eval, + .init = nft_xfrm_get_init, + .dump = nft_xfrm_get_dump, + .validate = nft_xfrm_validate, + .reduce = nft_xfrm_reduce, +}; + +static struct nft_expr_type nft_xfrm_type __read_mostly = { + .name = "xfrm", + .ops = &nft_xfrm_get_ops, + .policy = nft_xfrm_policy, + .maxattr = NFTA_XFRM_MAX, + .owner = THIS_MODULE, +}; + +static int __init nft_xfrm_module_init(void) +{ + return nft_register_expr(&nft_xfrm_type); +} + +static void __exit nft_xfrm_module_exit(void) +{ + nft_unregister_expr(&nft_xfrm_type); +} + +module_init(nft_xfrm_module_init); +module_exit(nft_xfrm_module_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("nf_tables: xfrm/IPSec matching"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_AUTHOR("Máté Eckl "); +MODULE_ALIAS_NFT_EXPR("xfrm"); diff --git a/net/netfilter/utils.c b/net/netfilter/utils.c new file mode 100644 index 000000000..2182d361e --- /dev/null +++ b/net/netfilter/utils.c @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_INET +__sum16 nf_ip_checksum(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, u8 protocol) +{ + const struct iphdr *iph = ip_hdr(skb); + __sum16 csum = 0; + + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + if (hook != NF_INET_PRE_ROUTING && hook != NF_INET_LOCAL_IN) + break; + if ((protocol != IPPROTO_TCP && protocol != IPPROTO_UDP && + !csum_fold(skb->csum)) || + !csum_tcpudp_magic(iph->saddr, iph->daddr, + skb->len - dataoff, protocol, + skb->csum)) { + skb->ip_summed = CHECKSUM_UNNECESSARY; + break; + } + fallthrough; + case CHECKSUM_NONE: + if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) + skb->csum = 0; + else + skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, + skb->len - dataoff, + protocol, 0); + csum = __skb_checksum_complete(skb); + } + return csum; +} +EXPORT_SYMBOL(nf_ip_checksum); +#endif + +static __sum16 nf_ip_checksum_partial(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, unsigned int len, + u8 protocol) +{ + const struct iphdr *iph = ip_hdr(skb); + __sum16 csum = 0; + + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + if (len == skb->len - dataoff) + return nf_ip_checksum(skb, hook, dataoff, protocol); + fallthrough; + case CHECKSUM_NONE: + skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, protocol, + skb->len - dataoff, 0); + skb->ip_summed = CHECKSUM_NONE; + return 
__skb_checksum_complete_head(skb, dataoff + len); + } + return csum; +} + +__sum16 nf_ip6_checksum(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, u8 protocol) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + __sum16 csum = 0; + + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + if (hook != NF_INET_PRE_ROUTING && hook != NF_INET_LOCAL_IN) + break; + if (!csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, + skb->len - dataoff, protocol, + csum_sub(skb->csum, + skb_checksum(skb, 0, + dataoff, 0)))) { + skb->ip_summed = CHECKSUM_UNNECESSARY; + break; + } + fallthrough; + case CHECKSUM_NONE: + skb->csum = ~csum_unfold( + csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, + skb->len - dataoff, + protocol, + csum_sub(0, + skb_checksum(skb, 0, + dataoff, 0)))); + csum = __skb_checksum_complete(skb); + } + return csum; +} +EXPORT_SYMBOL(nf_ip6_checksum); + +static __sum16 nf_ip6_checksum_partial(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, unsigned int len, + u8 protocol) +{ + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + __wsum hsum; + __sum16 csum = 0; + + switch (skb->ip_summed) { + case CHECKSUM_COMPLETE: + if (len == skb->len - dataoff) + return nf_ip6_checksum(skb, hook, dataoff, protocol); + fallthrough; + case CHECKSUM_NONE: + hsum = skb_checksum(skb, 0, dataoff, 0); + skb->csum = ~csum_unfold(csum_ipv6_magic(&ip6h->saddr, + &ip6h->daddr, + skb->len - dataoff, + protocol, + csum_sub(0, hsum))); + skb->ip_summed = CHECKSUM_NONE; + return __skb_checksum_complete_head(skb, dataoff + len); + } + return csum; +}; + +__sum16 nf_checksum(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, u8 protocol, + unsigned short family) +{ + __sum16 csum = 0; + + switch (family) { + case AF_INET: + csum = nf_ip_checksum(skb, hook, dataoff, protocol); + break; + case AF_INET6: + csum = nf_ip6_checksum(skb, hook, dataoff, protocol); + break; + } + + return csum; +} +EXPORT_SYMBOL_GPL(nf_checksum); + +__sum16 nf_checksum_partial(struct sk_buff *skb, unsigned int hook, + unsigned int dataoff, unsigned int len, + u8 protocol, unsigned short family) +{ + __sum16 csum = 0; + + switch (family) { + case AF_INET: + csum = nf_ip_checksum_partial(skb, hook, dataoff, len, + protocol); + break; + case AF_INET6: + csum = nf_ip6_checksum_partial(skb, hook, dataoff, len, + protocol); + break; + } + + return csum; +} +EXPORT_SYMBOL_GPL(nf_checksum_partial); + +int nf_route(struct net *net, struct dst_entry **dst, struct flowi *fl, + bool strict, unsigned short family) +{ + const struct nf_ipv6_ops *v6ops __maybe_unused; + int ret = 0; + + switch (family) { + case AF_INET: + ret = nf_ip_route(net, dst, fl, strict); + break; + case AF_INET6: + ret = nf_ip6_route(net, dst, fl, strict); + break; + } + + return ret; +} +EXPORT_SYMBOL_GPL(nf_route); + +static int nf_ip_reroute(struct sk_buff *skb, const struct nf_queue_entry *entry) +{ +#ifdef CONFIG_INET + const struct ip_rt_info *rt_info = nf_queue_entry_reroute(entry); + + if (entry->state.hook == NF_INET_LOCAL_OUT) { + const struct iphdr *iph = ip_hdr(skb); + + if (!(iph->tos == rt_info->tos && + skb->mark == rt_info->mark && + iph->daddr == rt_info->daddr && + iph->saddr == rt_info->saddr)) + return ip_route_me_harder(entry->state.net, entry->state.sk, + skb, RTN_UNSPEC); + } +#endif + return 0; +} + +int nf_reroute(struct sk_buff *skb, struct nf_queue_entry *entry) +{ + const struct nf_ipv6_ops *v6ops; + int ret = 0; + + switch (entry->state.pf) { + case AF_INET: + ret = nf_ip_reroute(skb, entry); + break; + case AF_INET6: + v6ops = 
rcu_dereference(nf_ipv6_ops); + if (v6ops) + ret = v6ops->reroute(skb, entry); + break; + } + return ret; +} diff --git a/net/netfilter/x_tables.c b/net/netfilter/x_tables.c new file mode 100644 index 000000000..470282cf3 --- /dev/null +++ b/net/netfilter/x_tables.c @@ -0,0 +1,2017 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * x_tables core - Backend for {ip,ip6,arp}_tables + * + * Copyright (C) 2006-2006 Harald Welte + * Copyright (C) 2006-2012 Patrick McHardy + * + * Based on existing ip_tables code which is + * Copyright (C) 1999 Paul `Rusty' Russell & Michael J. Neuling + * Copyright (C) 2000-2005 Netfilter Core Team + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("{ip,ip6,arp,eb}_tables backend module"); + +#define XT_PCPU_BLOCK_SIZE 4096 +#define XT_MAX_TABLE_SIZE (512 * 1024 * 1024) + +struct xt_template { + struct list_head list; + + /* called when table is needed in the given netns */ + int (*table_init)(struct net *net); + + struct module *me; + + /* A unique name... */ + char name[XT_TABLE_MAXNAMELEN]; +}; + +static struct list_head xt_templates[NFPROTO_NUMPROTO]; + +struct xt_pernet { + struct list_head tables[NFPROTO_NUMPROTO]; +}; + +struct compat_delta { + unsigned int offset; /* offset in kernel */ + int delta; /* delta in 32bit user land */ +}; + +struct xt_af { + struct mutex mutex; + struct list_head match; + struct list_head target; +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + struct mutex compat_mutex; + struct compat_delta *compat_tab; + unsigned int number; /* number of slots in compat_tab[] */ + unsigned int cur; /* number of used slots in compat_tab[] */ +#endif +}; + +static unsigned int xt_pernet_id __read_mostly; +static struct xt_af *xt __read_mostly; + +static const char *const xt_prefix[NFPROTO_NUMPROTO] = { + [NFPROTO_UNSPEC] = "x", + [NFPROTO_IPV4] = "ip", + [NFPROTO_ARP] = "arp", + [NFPROTO_BRIDGE] = "eb", + [NFPROTO_IPV6] = "ip6", +}; + +/* Registration hooks for targets. 
*/ +int xt_register_target(struct xt_target *target) +{ + u_int8_t af = target->family; + + mutex_lock(&xt[af].mutex); + list_add(&target->list, &xt[af].target); + mutex_unlock(&xt[af].mutex); + return 0; +} +EXPORT_SYMBOL(xt_register_target); + +void +xt_unregister_target(struct xt_target *target) +{ + u_int8_t af = target->family; + + mutex_lock(&xt[af].mutex); + list_del(&target->list); + mutex_unlock(&xt[af].mutex); +} +EXPORT_SYMBOL(xt_unregister_target); + +int +xt_register_targets(struct xt_target *target, unsigned int n) +{ + unsigned int i; + int err = 0; + + for (i = 0; i < n; i++) { + err = xt_register_target(&target[i]); + if (err) + goto err; + } + return err; + +err: + if (i > 0) + xt_unregister_targets(target, i); + return err; +} +EXPORT_SYMBOL(xt_register_targets); + +void +xt_unregister_targets(struct xt_target *target, unsigned int n) +{ + while (n-- > 0) + xt_unregister_target(&target[n]); +} +EXPORT_SYMBOL(xt_unregister_targets); + +int xt_register_match(struct xt_match *match) +{ + u_int8_t af = match->family; + + mutex_lock(&xt[af].mutex); + list_add(&match->list, &xt[af].match); + mutex_unlock(&xt[af].mutex); + return 0; +} +EXPORT_SYMBOL(xt_register_match); + +void +xt_unregister_match(struct xt_match *match) +{ + u_int8_t af = match->family; + + mutex_lock(&xt[af].mutex); + list_del(&match->list); + mutex_unlock(&xt[af].mutex); +} +EXPORT_SYMBOL(xt_unregister_match); + +int +xt_register_matches(struct xt_match *match, unsigned int n) +{ + unsigned int i; + int err = 0; + + for (i = 0; i < n; i++) { + err = xt_register_match(&match[i]); + if (err) + goto err; + } + return err; + +err: + if (i > 0) + xt_unregister_matches(match, i); + return err; +} +EXPORT_SYMBOL(xt_register_matches); + +void +xt_unregister_matches(struct xt_match *match, unsigned int n) +{ + while (n-- > 0) + xt_unregister_match(&match[n]); +} +EXPORT_SYMBOL(xt_unregister_matches); + + +/* + * These are weird, but module loading must not be done with mutex + * held (since they will register), and we have to have a single + * function to use. + */ + +/* Find match, grabs ref. Returns ERR_PTR() on error. */ +struct xt_match *xt_find_match(u8 af, const char *name, u8 revision) +{ + struct xt_match *m; + int err = -ENOENT; + + if (strnlen(name, XT_EXTENSION_MAXNAMELEN) == XT_EXTENSION_MAXNAMELEN) + return ERR_PTR(-EINVAL); + + mutex_lock(&xt[af].mutex); + list_for_each_entry(m, &xt[af].match, list) { + if (strcmp(m->name, name) == 0) { + if (m->revision == revision) { + if (try_module_get(m->me)) { + mutex_unlock(&xt[af].mutex); + return m; + } + } else + err = -EPROTOTYPE; /* Found something. */ + } + } + mutex_unlock(&xt[af].mutex); + + if (af != NFPROTO_UNSPEC) + /* Try searching again in the family-independent list */ + return xt_find_match(NFPROTO_UNSPEC, name, revision); + + return ERR_PTR(err); +} +EXPORT_SYMBOL(xt_find_match); + +struct xt_match * +xt_request_find_match(uint8_t nfproto, const char *name, uint8_t revision) +{ + struct xt_match *match; + + if (strnlen(name, XT_EXTENSION_MAXNAMELEN) == XT_EXTENSION_MAXNAMELEN) + return ERR_PTR(-EINVAL); + + match = xt_find_match(nfproto, name, revision); + if (IS_ERR(match)) { + request_module("%st_%s", xt_prefix[nfproto], name); + match = xt_find_match(nfproto, name, revision); + } + + return match; +} +EXPORT_SYMBOL_GPL(xt_request_find_match); + +/* Find target, grabs ref. Returns ERR_PTR() on error. 
*/ +static struct xt_target *xt_find_target(u8 af, const char *name, u8 revision) +{ + struct xt_target *t; + int err = -ENOENT; + + if (strnlen(name, XT_EXTENSION_MAXNAMELEN) == XT_EXTENSION_MAXNAMELEN) + return ERR_PTR(-EINVAL); + + mutex_lock(&xt[af].mutex); + list_for_each_entry(t, &xt[af].target, list) { + if (strcmp(t->name, name) == 0) { + if (t->revision == revision) { + if (try_module_get(t->me)) { + mutex_unlock(&xt[af].mutex); + return t; + } + } else + err = -EPROTOTYPE; /* Found something. */ + } + } + mutex_unlock(&xt[af].mutex); + + if (af != NFPROTO_UNSPEC) + /* Try searching again in the family-independent list */ + return xt_find_target(NFPROTO_UNSPEC, name, revision); + + return ERR_PTR(err); +} + +struct xt_target *xt_request_find_target(u8 af, const char *name, u8 revision) +{ + struct xt_target *target; + + if (strnlen(name, XT_EXTENSION_MAXNAMELEN) == XT_EXTENSION_MAXNAMELEN) + return ERR_PTR(-EINVAL); + + target = xt_find_target(af, name, revision); + if (IS_ERR(target)) { + request_module("%st_%s", xt_prefix[af], name); + target = xt_find_target(af, name, revision); + } + + return target; +} +EXPORT_SYMBOL_GPL(xt_request_find_target); + + +static int xt_obj_to_user(u16 __user *psize, u16 size, + void __user *pname, const char *name, + u8 __user *prev, u8 rev) +{ + if (put_user(size, psize)) + return -EFAULT; + if (copy_to_user(pname, name, strlen(name) + 1)) + return -EFAULT; + if (put_user(rev, prev)) + return -EFAULT; + + return 0; +} + +#define XT_OBJ_TO_USER(U, K, TYPE, C_SIZE) \ + xt_obj_to_user(&U->u.TYPE##_size, C_SIZE ? : K->u.TYPE##_size, \ + U->u.user.name, K->u.kernel.TYPE->name, \ + &U->u.user.revision, K->u.kernel.TYPE->revision) + +int xt_data_to_user(void __user *dst, const void *src, + int usersize, int size, int aligned_size) +{ + usersize = usersize ? 
: size; + if (copy_to_user(dst, src, usersize)) + return -EFAULT; + if (usersize != aligned_size && + clear_user(dst + usersize, aligned_size - usersize)) + return -EFAULT; + + return 0; +} +EXPORT_SYMBOL_GPL(xt_data_to_user); + +#define XT_DATA_TO_USER(U, K, TYPE) \ + xt_data_to_user(U->data, K->data, \ + K->u.kernel.TYPE->usersize, \ + K->u.kernel.TYPE->TYPE##size, \ + XT_ALIGN(K->u.kernel.TYPE->TYPE##size)) + +int xt_match_to_user(const struct xt_entry_match *m, + struct xt_entry_match __user *u) +{ + return XT_OBJ_TO_USER(u, m, match, 0) || + XT_DATA_TO_USER(u, m, match); +} +EXPORT_SYMBOL_GPL(xt_match_to_user); + +int xt_target_to_user(const struct xt_entry_target *t, + struct xt_entry_target __user *u) +{ + return XT_OBJ_TO_USER(u, t, target, 0) || + XT_DATA_TO_USER(u, t, target); +} +EXPORT_SYMBOL_GPL(xt_target_to_user); + +static int match_revfn(u8 af, const char *name, u8 revision, int *bestp) +{ + const struct xt_match *m; + int have_rev = 0; + + mutex_lock(&xt[af].mutex); + list_for_each_entry(m, &xt[af].match, list) { + if (strcmp(m->name, name) == 0) { + if (m->revision > *bestp) + *bestp = m->revision; + if (m->revision == revision) + have_rev = 1; + } + } + mutex_unlock(&xt[af].mutex); + + if (af != NFPROTO_UNSPEC && !have_rev) + return match_revfn(NFPROTO_UNSPEC, name, revision, bestp); + + return have_rev; +} + +static int target_revfn(u8 af, const char *name, u8 revision, int *bestp) +{ + const struct xt_target *t; + int have_rev = 0; + + mutex_lock(&xt[af].mutex); + list_for_each_entry(t, &xt[af].target, list) { + if (strcmp(t->name, name) == 0) { + if (t->revision > *bestp) + *bestp = t->revision; + if (t->revision == revision) + have_rev = 1; + } + } + mutex_unlock(&xt[af].mutex); + + if (af != NFPROTO_UNSPEC && !have_rev) + return target_revfn(NFPROTO_UNSPEC, name, revision, bestp); + + return have_rev; +} + +/* Returns true or false (if no such extension at all) */ +int xt_find_revision(u8 af, const char *name, u8 revision, int target, + int *err) +{ + int have_rev, best = -1; + + if (target == 1) + have_rev = target_revfn(af, name, revision, &best); + else + have_rev = match_revfn(af, name, revision, &best); + + /* Nothing at all? Return 0 to try loading module. */ + if (best == -1) { + *err = -ENOENT; + return 0; + } + + *err = best; + if (!have_rev) + *err = -EPROTONOSUPPORT; + return 1; +} +EXPORT_SYMBOL_GPL(xt_find_revision); + +static char * +textify_hooks(char *buf, size_t size, unsigned int mask, uint8_t nfproto) +{ + static const char *const inetbr_names[] = { + "PREROUTING", "INPUT", "FORWARD", + "OUTPUT", "POSTROUTING", "BROUTING", + }; + static const char *const arp_names[] = { + "INPUT", "FORWARD", "OUTPUT", + }; + const char *const *names; + unsigned int i, max; + char *p = buf; + bool np = false; + int res; + + names = (nfproto == NFPROTO_ARP) ? arp_names : inetbr_names; + max = (nfproto == NFPROTO_ARP) ? ARRAY_SIZE(arp_names) : + ARRAY_SIZE(inetbr_names); + *p = '\0'; + for (i = 0; i < max; ++i) { + if (!(mask & (1 << i))) + continue; + res = snprintf(p, size, "%s%s", np ? "/" : "", names[i]); + if (res > 0) { + size -= res; + p += res; + } + np = true; + } + + return buf; +} + +/** + * xt_check_proc_name - check that name is suitable for /proc file creation + * + * @name: file name candidate + * @size: length of buffer + * + * some x_tables modules wish to create a file in /proc. + * This function makes sure that the name is suitable for this + * purpose, it checks that name is NUL terminated and isn't a 'special' + * name, like "..". 
+ * + * returns negative number on error or 0 if name is useable. + */ +int xt_check_proc_name(const char *name, unsigned int size) +{ + if (name[0] == '\0') + return -EINVAL; + + if (strnlen(name, size) == size) + return -ENAMETOOLONG; + + if (strcmp(name, ".") == 0 || + strcmp(name, "..") == 0 || + strchr(name, '/')) + return -EINVAL; + + return 0; +} +EXPORT_SYMBOL(xt_check_proc_name); + +int xt_check_match(struct xt_mtchk_param *par, + unsigned int size, u16 proto, bool inv_proto) +{ + int ret; + + if (XT_ALIGN(par->match->matchsize) != size && + par->match->matchsize != -1) { + /* + * ebt_among is exempt from centralized matchsize checking + * because it uses a dynamic-size data set. + */ + pr_err_ratelimited("%s_tables: %s.%u match: invalid size %u (kernel) != (user) %u\n", + xt_prefix[par->family], par->match->name, + par->match->revision, + XT_ALIGN(par->match->matchsize), size); + return -EINVAL; + } + if (par->match->table != NULL && + strcmp(par->match->table, par->table) != 0) { + pr_info_ratelimited("%s_tables: %s match: only valid in %s table, not %s\n", + xt_prefix[par->family], par->match->name, + par->match->table, par->table); + return -EINVAL; + } + if (par->match->hooks && (par->hook_mask & ~par->match->hooks) != 0) { + char used[64], allow[64]; + + pr_info_ratelimited("%s_tables: %s match: used from hooks %s, but only valid from %s\n", + xt_prefix[par->family], par->match->name, + textify_hooks(used, sizeof(used), + par->hook_mask, par->family), + textify_hooks(allow, sizeof(allow), + par->match->hooks, + par->family)); + return -EINVAL; + } + if (par->match->proto && (par->match->proto != proto || inv_proto)) { + pr_info_ratelimited("%s_tables: %s match: only valid for protocol %u\n", + xt_prefix[par->family], par->match->name, + par->match->proto); + return -EINVAL; + } + if (par->match->checkentry != NULL) { + ret = par->match->checkentry(par); + if (ret < 0) + return ret; + else if (ret > 0) + /* Flag up potential errors. */ + return -EIO; + } + return 0; +} +EXPORT_SYMBOL_GPL(xt_check_match); + +/** xt_check_entry_match - check that matches end before start of target + * + * @match: beginning of xt_entry_match + * @target: beginning of this rules target (alleged end of matches) + * @alignment: alignment requirement of match structures + * + * Validates that all matches add up to the beginning of the target, + * and that each match covers at least the base structure size. + * + * Return: 0 on success, negative errno on failure. + */ +static int xt_check_entry_match(const char *match, const char *target, + const size_t alignment) +{ + const struct xt_entry_match *pos; + int length = target - match; + + if (length == 0) /* no matches */ + return 0; + + pos = (struct xt_entry_match *)match; + do { + if ((unsigned long)pos % alignment) + return -EINVAL; + + if (length < (int)sizeof(struct xt_entry_match)) + return -EINVAL; + + if (pos->u.match_size < sizeof(struct xt_entry_match)) + return -EINVAL; + + if (pos->u.match_size > length) + return -EINVAL; + + length -= pos->u.match_size; + pos = ((void *)((char *)(pos) + (pos)->u.match_size)); + } while (length > 0); + + return 0; +} + +/** xt_check_table_hooks - check hook entry points are sane + * + * @info xt_table_info to check + * @valid_hooks - hook entry points that we can enter from + * + * Validates that the hook entry and underflows points are set up. + * + * Return: 0 on success, negative errno on failure. 
+ */ +int xt_check_table_hooks(const struct xt_table_info *info, unsigned int valid_hooks) +{ + const char *err = "unsorted underflow"; + unsigned int i, max_uflow, max_entry; + bool check_hooks = false; + + BUILD_BUG_ON(ARRAY_SIZE(info->hook_entry) != ARRAY_SIZE(info->underflow)); + + max_entry = 0; + max_uflow = 0; + + for (i = 0; i < ARRAY_SIZE(info->hook_entry); i++) { + if (!(valid_hooks & (1 << i))) + continue; + + if (info->hook_entry[i] == 0xFFFFFFFF) + return -EINVAL; + if (info->underflow[i] == 0xFFFFFFFF) + return -EINVAL; + + if (check_hooks) { + if (max_uflow > info->underflow[i]) + goto error; + + if (max_uflow == info->underflow[i]) { + err = "duplicate underflow"; + goto error; + } + if (max_entry > info->hook_entry[i]) { + err = "unsorted entry"; + goto error; + } + if (max_entry == info->hook_entry[i]) { + err = "duplicate entry"; + goto error; + } + } + max_entry = info->hook_entry[i]; + max_uflow = info->underflow[i]; + check_hooks = true; + } + + return 0; +error: + pr_err_ratelimited("%s at hook %d\n", err, i); + return -EINVAL; +} +EXPORT_SYMBOL(xt_check_table_hooks); + +static bool verdict_ok(int verdict) +{ + if (verdict > 0) + return true; + + if (verdict < 0) { + int v = -verdict - 1; + + if (verdict == XT_RETURN) + return true; + + switch (v) { + case NF_ACCEPT: return true; + case NF_DROP: return true; + case NF_QUEUE: return true; + default: + break; + } + + return false; + } + + return false; +} + +static bool error_tg_ok(unsigned int usersize, unsigned int kernsize, + const char *msg, unsigned int msglen) +{ + return usersize == kernsize && strnlen(msg, msglen) < msglen; +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +int xt_compat_add_offset(u_int8_t af, unsigned int offset, int delta) +{ + struct xt_af *xp = &xt[af]; + + WARN_ON(!mutex_is_locked(&xt[af].compat_mutex)); + + if (WARN_ON(!xp->compat_tab)) + return -ENOMEM; + + if (xp->cur >= xp->number) + return -EINVAL; + + if (xp->cur) + delta += xp->compat_tab[xp->cur - 1].delta; + xp->compat_tab[xp->cur].offset = offset; + xp->compat_tab[xp->cur].delta = delta; + xp->cur++; + return 0; +} +EXPORT_SYMBOL_GPL(xt_compat_add_offset); + +void xt_compat_flush_offsets(u_int8_t af) +{ + WARN_ON(!mutex_is_locked(&xt[af].compat_mutex)); + + if (xt[af].compat_tab) { + vfree(xt[af].compat_tab); + xt[af].compat_tab = NULL; + xt[af].number = 0; + xt[af].cur = 0; + } +} +EXPORT_SYMBOL_GPL(xt_compat_flush_offsets); + +int xt_compat_calc_jump(u_int8_t af, unsigned int offset) +{ + struct compat_delta *tmp = xt[af].compat_tab; + int mid, left = 0, right = xt[af].cur - 1; + + while (left <= right) { + mid = (left + right) >> 1; + if (offset > tmp[mid].offset) + left = mid + 1; + else if (offset < tmp[mid].offset) + right = mid - 1; + else + return mid ? tmp[mid - 1].delta : 0; + } + return left ? tmp[left - 1].delta : 0; +} +EXPORT_SYMBOL_GPL(xt_compat_calc_jump); + +int xt_compat_init_offsets(u8 af, unsigned int number) +{ + size_t mem; + + WARN_ON(!mutex_is_locked(&xt[af].compat_mutex)); + + if (!number || number > (INT_MAX / sizeof(struct compat_delta))) + return -EINVAL; + + if (WARN_ON(xt[af].compat_tab)) + return -EINVAL; + + mem = sizeof(struct compat_delta) * number; + if (mem > XT_MAX_TABLE_SIZE) + return -ENOMEM; + + xt[af].compat_tab = vmalloc(mem); + if (!xt[af].compat_tab) + return -ENOMEM; + + xt[af].number = number; + xt[af].cur = 0; + + return 0; +} +EXPORT_SYMBOL(xt_compat_init_offsets); + +int xt_compat_match_offset(const struct xt_match *match) +{ + u_int16_t csize = match->compatsize ? 
: match->matchsize; + return XT_ALIGN(match->matchsize) - COMPAT_XT_ALIGN(csize); +} +EXPORT_SYMBOL_GPL(xt_compat_match_offset); + +void xt_compat_match_from_user(struct xt_entry_match *m, void **dstptr, + unsigned int *size) +{ + const struct xt_match *match = m->u.kernel.match; + struct compat_xt_entry_match *cm = (struct compat_xt_entry_match *)m; + int off = xt_compat_match_offset(match); + u_int16_t msize = cm->u.user.match_size; + char name[sizeof(m->u.user.name)]; + + m = *dstptr; + memcpy(m, cm, sizeof(*cm)); + if (match->compat_from_user) + match->compat_from_user(m->data, cm->data); + else + memcpy(m->data, cm->data, msize - sizeof(*cm)); + + msize += off; + m->u.user.match_size = msize; + strscpy(name, match->name, sizeof(name)); + module_put(match->me); + strncpy(m->u.user.name, name, sizeof(m->u.user.name)); + + *size += off; + *dstptr += msize; +} +EXPORT_SYMBOL_GPL(xt_compat_match_from_user); + +#define COMPAT_XT_DATA_TO_USER(U, K, TYPE, C_SIZE) \ + xt_data_to_user(U->data, K->data, \ + K->u.kernel.TYPE->usersize, \ + C_SIZE, \ + COMPAT_XT_ALIGN(C_SIZE)) + +int xt_compat_match_to_user(const struct xt_entry_match *m, + void __user **dstptr, unsigned int *size) +{ + const struct xt_match *match = m->u.kernel.match; + struct compat_xt_entry_match __user *cm = *dstptr; + int off = xt_compat_match_offset(match); + u_int16_t msize = m->u.user.match_size - off; + + if (XT_OBJ_TO_USER(cm, m, match, msize)) + return -EFAULT; + + if (match->compat_to_user) { + if (match->compat_to_user((void __user *)cm->data, m->data)) + return -EFAULT; + } else { + if (COMPAT_XT_DATA_TO_USER(cm, m, match, msize - sizeof(*cm))) + return -EFAULT; + } + + *size -= off; + *dstptr += msize; + return 0; +} +EXPORT_SYMBOL_GPL(xt_compat_match_to_user); + +/* non-compat version may have padding after verdict */ +struct compat_xt_standard_target { + struct compat_xt_entry_target t; + compat_uint_t verdict; +}; + +struct compat_xt_error_target { + struct compat_xt_entry_target t; + char errorname[XT_FUNCTION_MAXNAMELEN]; +}; + +int xt_compat_check_entry_offsets(const void *base, const char *elems, + unsigned int target_offset, + unsigned int next_offset) +{ + long size_of_base_struct = elems - (const char *)base; + const struct compat_xt_entry_target *t; + const char *e = base; + + if (target_offset < size_of_base_struct) + return -EINVAL; + + if (target_offset + sizeof(*t) > next_offset) + return -EINVAL; + + t = (void *)(e + target_offset); + if (t->u.target_size < sizeof(*t)) + return -EINVAL; + + if (target_offset + t->u.target_size > next_offset) + return -EINVAL; + + if (strcmp(t->u.user.name, XT_STANDARD_TARGET) == 0) { + const struct compat_xt_standard_target *st = (const void *)t; + + if (COMPAT_XT_ALIGN(target_offset + sizeof(*st)) != next_offset) + return -EINVAL; + + if (!verdict_ok(st->verdict)) + return -EINVAL; + } else if (strcmp(t->u.user.name, XT_ERROR_TARGET) == 0) { + const struct compat_xt_error_target *et = (const void *)t; + + if (!error_tg_ok(t->u.target_size, sizeof(*et), + et->errorname, sizeof(et->errorname))) + return -EINVAL; + } + + /* compat_xt_entry match has less strict alignment requirements, + * otherwise they are identical. In case of padding differences + * we need to add compat version of xt_check_entry_match. 
+ */ + BUILD_BUG_ON(sizeof(struct compat_xt_entry_match) != sizeof(struct xt_entry_match)); + + return xt_check_entry_match(elems, base + target_offset, + __alignof__(struct compat_xt_entry_match)); +} +EXPORT_SYMBOL(xt_compat_check_entry_offsets); +#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ + +/** + * xt_check_entry_offsets - validate arp/ip/ip6t_entry + * + * @base: pointer to arp/ip/ip6t_entry + * @elems: pointer to first xt_entry_match, i.e. ip(6)t_entry->elems + * @target_offset: the arp/ip/ip6_t->target_offset + * @next_offset: the arp/ip/ip6_t->next_offset + * + * validates that target_offset and next_offset are sane and that all + * match sizes (if any) align with the target offset. + * + * This function does not validate the targets or matches themselves, it + * only tests that all the offsets and sizes are correct, that all + * match structures are aligned, and that the last structure ends where + * the target structure begins. + * + * Also see xt_compat_check_entry_offsets for CONFIG_NETFILTER_XTABLES_COMPAT version. + * + * The arp/ip/ip6t_entry structure @base must have passed following tests: + * - it must point to a valid memory location + * - base to base + next_offset must be accessible, i.e. not exceed allocated + * length. + * + * A well-formed entry looks like this: + * + * ip(6)t_entry match [mtdata] match [mtdata] target [tgdata] ip(6)t_entry + * e->elems[]-----' | | + * matchsize | | + * matchsize | | + * | | + * target_offset---------------------------------' | + * next_offset---------------------------------------------------' + * + * elems[]: flexible array member at end of ip(6)/arpt_entry struct. + * This is where matches (if any) and the target reside. + * target_offset: beginning of target. + * next_offset: start of the next rule; also: size of this rule. + * Since targets have a minimum size, target_offset + minlen <= next_offset. + * + * Every match stores its size, sum of sizes must not exceed target_offset. + * + * Return: 0 on success, negative errno on failure. 
+ */ +int xt_check_entry_offsets(const void *base, + const char *elems, + unsigned int target_offset, + unsigned int next_offset) +{ + long size_of_base_struct = elems - (const char *)base; + const struct xt_entry_target *t; + const char *e = base; + + /* target start is within the ip/ip6/arpt_entry struct */ + if (target_offset < size_of_base_struct) + return -EINVAL; + + if (target_offset + sizeof(*t) > next_offset) + return -EINVAL; + + t = (void *)(e + target_offset); + if (t->u.target_size < sizeof(*t)) + return -EINVAL; + + if (target_offset + t->u.target_size > next_offset) + return -EINVAL; + + if (strcmp(t->u.user.name, XT_STANDARD_TARGET) == 0) { + const struct xt_standard_target *st = (const void *)t; + + if (XT_ALIGN(target_offset + sizeof(*st)) != next_offset) + return -EINVAL; + + if (!verdict_ok(st->verdict)) + return -EINVAL; + } else if (strcmp(t->u.user.name, XT_ERROR_TARGET) == 0) { + const struct xt_error_target *et = (const void *)t; + + if (!error_tg_ok(t->u.target_size, sizeof(*et), + et->errorname, sizeof(et->errorname))) + return -EINVAL; + } + + return xt_check_entry_match(elems, base + target_offset, + __alignof__(struct xt_entry_match)); +} +EXPORT_SYMBOL(xt_check_entry_offsets); + +/** + * xt_alloc_entry_offsets - allocate array to store rule head offsets + * + * @size: number of entries + * + * Return: NULL or zeroed kmalloc'd or vmalloc'd array + */ +unsigned int *xt_alloc_entry_offsets(unsigned int size) +{ + if (size > XT_MAX_TABLE_SIZE / sizeof(unsigned int)) + return NULL; + + return kvcalloc(size, sizeof(unsigned int), GFP_KERNEL); + +} +EXPORT_SYMBOL(xt_alloc_entry_offsets); + +/** + * xt_find_jump_offset - check if target is a valid jump offset + * + * @offsets: array containing all valid rule start offsets of a rule blob + * @target: the jump target to search for + * @size: entries in @offset + */ +bool xt_find_jump_offset(const unsigned int *offsets, + unsigned int target, unsigned int size) +{ + int m, low = 0, hi = size; + + while (hi > low) { + m = (low + hi) / 2u; + + if (offsets[m] > target) + hi = m; + else if (offsets[m] < target) + low = m + 1; + else + return true; + } + + return false; +} +EXPORT_SYMBOL(xt_find_jump_offset); + +int xt_check_target(struct xt_tgchk_param *par, + unsigned int size, u16 proto, bool inv_proto) +{ + int ret; + + if (XT_ALIGN(par->target->targetsize) != size) { + pr_err_ratelimited("%s_tables: %s.%u target: invalid size %u (kernel) != (user) %u\n", + xt_prefix[par->family], par->target->name, + par->target->revision, + XT_ALIGN(par->target->targetsize), size); + return -EINVAL; + } + if (par->target->table != NULL && + strcmp(par->target->table, par->table) != 0) { + pr_info_ratelimited("%s_tables: %s target: only valid in %s table, not %s\n", + xt_prefix[par->family], par->target->name, + par->target->table, par->table); + return -EINVAL; + } + if (par->target->hooks && (par->hook_mask & ~par->target->hooks) != 0) { + char used[64], allow[64]; + + pr_info_ratelimited("%s_tables: %s target: used from hooks %s, but only usable from %s\n", + xt_prefix[par->family], par->target->name, + textify_hooks(used, sizeof(used), + par->hook_mask, par->family), + textify_hooks(allow, sizeof(allow), + par->target->hooks, + par->family)); + return -EINVAL; + } + if (par->target->proto && (par->target->proto != proto || inv_proto)) { + pr_info_ratelimited("%s_tables: %s target: only valid for protocol %u\n", + xt_prefix[par->family], par->target->name, + par->target->proto); + return -EINVAL; + } + if (par->target->checkentry != 
NULL) { + ret = par->target->checkentry(par); + if (ret < 0) + return ret; + else if (ret > 0) + /* Flag up potential errors. */ + return -EIO; + } + return 0; +} +EXPORT_SYMBOL_GPL(xt_check_target); + +/** + * xt_copy_counters - copy counters and metadata from a sockptr_t + * + * @arg: src sockptr + * @len: alleged size of userspace memory + * @info: where to store the xt_counters_info metadata + * + * Copies counter meta data from @user and stores it in @info. + * + * vmallocs memory to hold the counters, then copies the counter data + * from @user to the new memory and returns a pointer to it. + * + * If called from a compat syscall, @info gets converted automatically to the + * 64bit representation. + * + * The metadata associated with the counters is stored in @info. + * + * Return: returns pointer that caller has to test via IS_ERR(). + * If IS_ERR is false, caller has to vfree the pointer. + */ +void *xt_copy_counters(sockptr_t arg, unsigned int len, + struct xt_counters_info *info) +{ + size_t offset; + void *mem; + u64 size; + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + if (in_compat_syscall()) { + /* structures only differ in size due to alignment */ + struct compat_xt_counters_info compat_tmp; + + if (len <= sizeof(compat_tmp)) + return ERR_PTR(-EINVAL); + + len -= sizeof(compat_tmp); + if (copy_from_sockptr(&compat_tmp, arg, sizeof(compat_tmp)) != 0) + return ERR_PTR(-EFAULT); + + memcpy(info->name, compat_tmp.name, sizeof(info->name) - 1); + info->num_counters = compat_tmp.num_counters; + offset = sizeof(compat_tmp); + } else +#endif + { + if (len <= sizeof(*info)) + return ERR_PTR(-EINVAL); + + len -= sizeof(*info); + if (copy_from_sockptr(info, arg, sizeof(*info)) != 0) + return ERR_PTR(-EFAULT); + + offset = sizeof(*info); + } + info->name[sizeof(info->name) - 1] = '\0'; + + size = sizeof(struct xt_counters); + size *= info->num_counters; + + if (size != (u64)len) + return ERR_PTR(-EINVAL); + + mem = vmalloc(len); + if (!mem) + return ERR_PTR(-ENOMEM); + + if (copy_from_sockptr_offset(mem, arg, offset, len) == 0) + return mem; + + vfree(mem); + return ERR_PTR(-EFAULT); +} +EXPORT_SYMBOL_GPL(xt_copy_counters); + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +int xt_compat_target_offset(const struct xt_target *target) +{ + u_int16_t csize = target->compatsize ? 
: target->targetsize; + return XT_ALIGN(target->targetsize) - COMPAT_XT_ALIGN(csize); +} +EXPORT_SYMBOL_GPL(xt_compat_target_offset); + +void xt_compat_target_from_user(struct xt_entry_target *t, void **dstptr, + unsigned int *size) +{ + const struct xt_target *target = t->u.kernel.target; + struct compat_xt_entry_target *ct = (struct compat_xt_entry_target *)t; + int off = xt_compat_target_offset(target); + u_int16_t tsize = ct->u.user.target_size; + char name[sizeof(t->u.user.name)]; + + t = *dstptr; + memcpy(t, ct, sizeof(*ct)); + if (target->compat_from_user) + target->compat_from_user(t->data, ct->data); + else + memcpy(t->data, ct->data, tsize - sizeof(*ct)); + + tsize += off; + t->u.user.target_size = tsize; + strscpy(name, target->name, sizeof(name)); + module_put(target->me); + strncpy(t->u.user.name, name, sizeof(t->u.user.name)); + + *size += off; + *dstptr += tsize; +} +EXPORT_SYMBOL_GPL(xt_compat_target_from_user); + +int xt_compat_target_to_user(const struct xt_entry_target *t, + void __user **dstptr, unsigned int *size) +{ + const struct xt_target *target = t->u.kernel.target; + struct compat_xt_entry_target __user *ct = *dstptr; + int off = xt_compat_target_offset(target); + u_int16_t tsize = t->u.user.target_size - off; + + if (XT_OBJ_TO_USER(ct, t, target, tsize)) + return -EFAULT; + + if (target->compat_to_user) { + if (target->compat_to_user((void __user *)ct->data, t->data)) + return -EFAULT; + } else { + if (COMPAT_XT_DATA_TO_USER(ct, t, target, tsize - sizeof(*ct))) + return -EFAULT; + } + + *size -= off; + *dstptr += tsize; + return 0; +} +EXPORT_SYMBOL_GPL(xt_compat_target_to_user); +#endif + +struct xt_table_info *xt_alloc_table_info(unsigned int size) +{ + struct xt_table_info *info = NULL; + size_t sz = sizeof(*info) + size; + + if (sz < sizeof(*info) || sz >= XT_MAX_TABLE_SIZE) + return NULL; + + info = kvmalloc(sz, GFP_KERNEL_ACCOUNT); + if (!info) + return NULL; + + memset(info, 0, sizeof(*info)); + info->size = size; + return info; +} +EXPORT_SYMBOL(xt_alloc_table_info); + +void xt_free_table_info(struct xt_table_info *info) +{ + int cpu; + + if (info->jumpstack != NULL) { + for_each_possible_cpu(cpu) + kvfree(info->jumpstack[cpu]); + kvfree(info->jumpstack); + } + + kvfree(info); +} +EXPORT_SYMBOL(xt_free_table_info); + +struct xt_table *xt_find_table(struct net *net, u8 af, const char *name) +{ + struct xt_pernet *xt_net = net_generic(net, xt_pernet_id); + struct xt_table *t; + + mutex_lock(&xt[af].mutex); + list_for_each_entry(t, &xt_net->tables[af], list) { + if (strcmp(t->name, name) == 0) { + mutex_unlock(&xt[af].mutex); + return t; + } + } + mutex_unlock(&xt[af].mutex); + return NULL; +} +EXPORT_SYMBOL(xt_find_table); + +/* Find table by name, grabs mutex & ref. Returns ERR_PTR on error. 
*/ +struct xt_table *xt_find_table_lock(struct net *net, u_int8_t af, + const char *name) +{ + struct xt_pernet *xt_net = net_generic(net, xt_pernet_id); + struct module *owner = NULL; + struct xt_template *tmpl; + struct xt_table *t; + + mutex_lock(&xt[af].mutex); + list_for_each_entry(t, &xt_net->tables[af], list) + if (strcmp(t->name, name) == 0 && try_module_get(t->me)) + return t; + + /* Table doesn't exist in this netns, check larval list */ + list_for_each_entry(tmpl, &xt_templates[af], list) { + int err; + + if (strcmp(tmpl->name, name)) + continue; + if (!try_module_get(tmpl->me)) + goto out; + + owner = tmpl->me; + + mutex_unlock(&xt[af].mutex); + err = tmpl->table_init(net); + if (err < 0) { + module_put(owner); + return ERR_PTR(err); + } + + mutex_lock(&xt[af].mutex); + break; + } + + /* and once again: */ + list_for_each_entry(t, &xt_net->tables[af], list) + if (strcmp(t->name, name) == 0) + return t; + + module_put(owner); + out: + mutex_unlock(&xt[af].mutex); + return ERR_PTR(-ENOENT); +} +EXPORT_SYMBOL_GPL(xt_find_table_lock); + +struct xt_table *xt_request_find_table_lock(struct net *net, u_int8_t af, + const char *name) +{ + struct xt_table *t = xt_find_table_lock(net, af, name); + +#ifdef CONFIG_MODULES + if (IS_ERR(t)) { + int err = request_module("%stable_%s", xt_prefix[af], name); + if (err < 0) + return ERR_PTR(err); + t = xt_find_table_lock(net, af, name); + } +#endif + + return t; +} +EXPORT_SYMBOL_GPL(xt_request_find_table_lock); + +void xt_table_unlock(struct xt_table *table) +{ + mutex_unlock(&xt[table->af].mutex); +} +EXPORT_SYMBOL_GPL(xt_table_unlock); + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +void xt_compat_lock(u_int8_t af) +{ + mutex_lock(&xt[af].compat_mutex); +} +EXPORT_SYMBOL_GPL(xt_compat_lock); + +void xt_compat_unlock(u_int8_t af) +{ + mutex_unlock(&xt[af].compat_mutex); +} +EXPORT_SYMBOL_GPL(xt_compat_unlock); +#endif + +DEFINE_PER_CPU(seqcount_t, xt_recseq); +EXPORT_PER_CPU_SYMBOL_GPL(xt_recseq); + +struct static_key xt_tee_enabled __read_mostly; +EXPORT_SYMBOL_GPL(xt_tee_enabled); + +static int xt_jumpstack_alloc(struct xt_table_info *i) +{ + unsigned int size; + int cpu; + + size = sizeof(void **) * nr_cpu_ids; + if (size > PAGE_SIZE) + i->jumpstack = kvzalloc(size, GFP_KERNEL); + else + i->jumpstack = kzalloc(size, GFP_KERNEL); + if (i->jumpstack == NULL) + return -ENOMEM; + + /* ruleset without jumps -- no stack needed */ + if (i->stacksize == 0) + return 0; + + /* Jumpstack needs to be able to record two full callchains, one + * from the first rule set traversal, plus one table reentrancy + * via -j TEE without clobbering the callchain that brought us to + * TEE target. + * + * This is done by allocating two jumpstacks per cpu, on reentry + * the upper half of the stack is used. + * + * see the jumpstack setup in ipt_do_table() for more details. + */ + size = sizeof(void *) * i->stacksize * 2u; + for_each_possible_cpu(cpu) { + i->jumpstack[cpu] = kvmalloc_node(size, GFP_KERNEL, + cpu_to_node(cpu)); + if (i->jumpstack[cpu] == NULL) + /* + * Freeing will be done later on by the callers. The + * chain is: xt_replace_table -> __do_replace -> + * do_replace -> xt_free_table_info. 
+ */ + return -ENOMEM; + } + + return 0; +} + +struct xt_counters *xt_counters_alloc(unsigned int counters) +{ + struct xt_counters *mem; + + if (counters == 0 || counters > INT_MAX / sizeof(*mem)) + return NULL; + + counters *= sizeof(*mem); + if (counters > XT_MAX_TABLE_SIZE) + return NULL; + + return vzalloc(counters); +} +EXPORT_SYMBOL(xt_counters_alloc); + +struct xt_table_info * +xt_replace_table(struct xt_table *table, + unsigned int num_counters, + struct xt_table_info *newinfo, + int *error) +{ + struct xt_table_info *private; + unsigned int cpu; + int ret; + + ret = xt_jumpstack_alloc(newinfo); + if (ret < 0) { + *error = ret; + return NULL; + } + + /* Do the substitution. */ + local_bh_disable(); + private = table->private; + + /* Check inside lock: is the old number correct? */ + if (num_counters != private->number) { + pr_debug("num_counters != table->private->number (%u/%u)\n", + num_counters, private->number); + local_bh_enable(); + *error = -EAGAIN; + return NULL; + } + + newinfo->initial_entries = private->initial_entries; + /* + * Ensure contents of newinfo are visible before assigning to + * private. + */ + smp_wmb(); + table->private = newinfo; + + /* make sure all cpus see new ->private value */ + smp_mb(); + + /* + * Even though table entries have now been swapped, other CPU's + * may still be using the old entries... + */ + local_bh_enable(); + + /* ... so wait for even xt_recseq on all cpus */ + for_each_possible_cpu(cpu) { + seqcount_t *s = &per_cpu(xt_recseq, cpu); + u32 seq = raw_read_seqcount(s); + + if (seq & 1) { + do { + cond_resched(); + cpu_relax(); + } while (seq == raw_read_seqcount(s)); + } + } + + audit_log_nfcfg(table->name, table->af, private->number, + !private->number ? AUDIT_XT_OP_REGISTER : + AUDIT_XT_OP_REPLACE, + GFP_KERNEL); + return private; +} +EXPORT_SYMBOL_GPL(xt_replace_table); + +struct xt_table *xt_register_table(struct net *net, + const struct xt_table *input_table, + struct xt_table_info *bootstrap, + struct xt_table_info *newinfo) +{ + struct xt_pernet *xt_net = net_generic(net, xt_pernet_id); + struct xt_table_info *private; + struct xt_table *t, *table; + int ret; + + /* Don't add one object to multiple lists. */ + table = kmemdup(input_table, sizeof(struct xt_table), GFP_KERNEL); + if (!table) { + ret = -ENOMEM; + goto out; + } + + mutex_lock(&xt[table->af].mutex); + /* Don't autoload: we'd eat our tail... */ + list_for_each_entry(t, &xt_net->tables[table->af], list) { + if (strcmp(t->name, table->name) == 0) { + ret = -EEXIST; + goto unlock; + } + } + + /* Simplifies replace_table code. 
*/ + table->private = bootstrap; + + if (!xt_replace_table(table, 0, newinfo, &ret)) + goto unlock; + + private = table->private; + pr_debug("table->private->number = %u\n", private->number); + + /* save number of initial entries */ + private->initial_entries = private->number; + + list_add(&table->list, &xt_net->tables[table->af]); + mutex_unlock(&xt[table->af].mutex); + return table; + +unlock: + mutex_unlock(&xt[table->af].mutex); + kfree(table); +out: + return ERR_PTR(ret); +} +EXPORT_SYMBOL_GPL(xt_register_table); + +void *xt_unregister_table(struct xt_table *table) +{ + struct xt_table_info *private; + + mutex_lock(&xt[table->af].mutex); + private = table->private; + list_del(&table->list); + mutex_unlock(&xt[table->af].mutex); + audit_log_nfcfg(table->name, table->af, private->number, + AUDIT_XT_OP_UNREGISTER, GFP_KERNEL); + kfree(table->ops); + kfree(table); + + return private; +} +EXPORT_SYMBOL_GPL(xt_unregister_table); + +#ifdef CONFIG_PROC_FS +static void *xt_table_seq_start(struct seq_file *seq, loff_t *pos) +{ + u8 af = (unsigned long)pde_data(file_inode(seq->file)); + struct net *net = seq_file_net(seq); + struct xt_pernet *xt_net; + + xt_net = net_generic(net, xt_pernet_id); + + mutex_lock(&xt[af].mutex); + return seq_list_start(&xt_net->tables[af], *pos); +} + +static void *xt_table_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + u8 af = (unsigned long)pde_data(file_inode(seq->file)); + struct net *net = seq_file_net(seq); + struct xt_pernet *xt_net; + + xt_net = net_generic(net, xt_pernet_id); + + return seq_list_next(v, &xt_net->tables[af], pos); +} + +static void xt_table_seq_stop(struct seq_file *seq, void *v) +{ + u_int8_t af = (unsigned long)pde_data(file_inode(seq->file)); + + mutex_unlock(&xt[af].mutex); +} + +static int xt_table_seq_show(struct seq_file *seq, void *v) +{ + struct xt_table *table = list_entry(v, struct xt_table, list); + + if (*table->name) + seq_printf(seq, "%s\n", table->name); + return 0; +} + +static const struct seq_operations xt_table_seq_ops = { + .start = xt_table_seq_start, + .next = xt_table_seq_next, + .stop = xt_table_seq_stop, + .show = xt_table_seq_show, +}; + +/* + * Traverse state for ip{,6}_{tables,matches} for helping crossing + * the multi-AF mutexes. + */ +struct nf_mttg_trav { + struct list_head *head, *curr; + uint8_t class; +}; + +enum { + MTTG_TRAV_INIT, + MTTG_TRAV_NFP_UNSPEC, + MTTG_TRAV_NFP_SPEC, + MTTG_TRAV_DONE, +}; + +static void *xt_mttg_seq_next(struct seq_file *seq, void *v, loff_t *ppos, + bool is_target) +{ + static const uint8_t next_class[] = { + [MTTG_TRAV_NFP_UNSPEC] = MTTG_TRAV_NFP_SPEC, + [MTTG_TRAV_NFP_SPEC] = MTTG_TRAV_DONE, + }; + uint8_t nfproto = (unsigned long)pde_data(file_inode(seq->file)); + struct nf_mttg_trav *trav = seq->private; + + if (ppos != NULL) + ++(*ppos); + + switch (trav->class) { + case MTTG_TRAV_INIT: + trav->class = MTTG_TRAV_NFP_UNSPEC; + mutex_lock(&xt[NFPROTO_UNSPEC].mutex); + trav->head = trav->curr = is_target ? + &xt[NFPROTO_UNSPEC].target : &xt[NFPROTO_UNSPEC].match; + break; + case MTTG_TRAV_NFP_UNSPEC: + trav->curr = trav->curr->next; + if (trav->curr != trav->head) + break; + mutex_unlock(&xt[NFPROTO_UNSPEC].mutex); + mutex_lock(&xt[nfproto].mutex); + trav->head = trav->curr = is_target ? 
+ &xt[nfproto].target : &xt[nfproto].match; + trav->class = next_class[trav->class]; + break; + case MTTG_TRAV_NFP_SPEC: + trav->curr = trav->curr->next; + if (trav->curr != trav->head) + break; + fallthrough; + default: + return NULL; + } + return trav; +} + +static void *xt_mttg_seq_start(struct seq_file *seq, loff_t *pos, + bool is_target) +{ + struct nf_mttg_trav *trav = seq->private; + unsigned int j; + + trav->class = MTTG_TRAV_INIT; + for (j = 0; j < *pos; ++j) + if (xt_mttg_seq_next(seq, NULL, NULL, is_target) == NULL) + return NULL; + return trav; +} + +static void xt_mttg_seq_stop(struct seq_file *seq, void *v) +{ + uint8_t nfproto = (unsigned long)pde_data(file_inode(seq->file)); + struct nf_mttg_trav *trav = seq->private; + + switch (trav->class) { + case MTTG_TRAV_NFP_UNSPEC: + mutex_unlock(&xt[NFPROTO_UNSPEC].mutex); + break; + case MTTG_TRAV_NFP_SPEC: + mutex_unlock(&xt[nfproto].mutex); + break; + } +} + +static void *xt_match_seq_start(struct seq_file *seq, loff_t *pos) +{ + return xt_mttg_seq_start(seq, pos, false); +} + +static void *xt_match_seq_next(struct seq_file *seq, void *v, loff_t *ppos) +{ + return xt_mttg_seq_next(seq, v, ppos, false); +} + +static int xt_match_seq_show(struct seq_file *seq, void *v) +{ + const struct nf_mttg_trav *trav = seq->private; + const struct xt_match *match; + + switch (trav->class) { + case MTTG_TRAV_NFP_UNSPEC: + case MTTG_TRAV_NFP_SPEC: + if (trav->curr == trav->head) + return 0; + match = list_entry(trav->curr, struct xt_match, list); + if (*match->name) + seq_printf(seq, "%s\n", match->name); + } + return 0; +} + +static const struct seq_operations xt_match_seq_ops = { + .start = xt_match_seq_start, + .next = xt_match_seq_next, + .stop = xt_mttg_seq_stop, + .show = xt_match_seq_show, +}; + +static void *xt_target_seq_start(struct seq_file *seq, loff_t *pos) +{ + return xt_mttg_seq_start(seq, pos, true); +} + +static void *xt_target_seq_next(struct seq_file *seq, void *v, loff_t *ppos) +{ + return xt_mttg_seq_next(seq, v, ppos, true); +} + +static int xt_target_seq_show(struct seq_file *seq, void *v) +{ + const struct nf_mttg_trav *trav = seq->private; + const struct xt_target *target; + + switch (trav->class) { + case MTTG_TRAV_NFP_UNSPEC: + case MTTG_TRAV_NFP_SPEC: + if (trav->curr == trav->head) + return 0; + target = list_entry(trav->curr, struct xt_target, list); + if (*target->name) + seq_printf(seq, "%s\n", target->name); + } + return 0; +} + +static const struct seq_operations xt_target_seq_ops = { + .start = xt_target_seq_start, + .next = xt_target_seq_next, + .stop = xt_mttg_seq_stop, + .show = xt_target_seq_show, +}; + +#define FORMAT_TABLES "_tables_names" +#define FORMAT_MATCHES "_tables_matches" +#define FORMAT_TARGETS "_tables_targets" + +#endif /* CONFIG_PROC_FS */ + +/** + * xt_hook_ops_alloc - set up hooks for a new table + * @table: table with metadata needed to set up hooks + * @fn: Hook function + * + * This function will create the nf_hook_ops that the x_table needs + * to hand to xt_hook_link_net(). 
+ */ +struct nf_hook_ops * +xt_hook_ops_alloc(const struct xt_table *table, nf_hookfn *fn) +{ + unsigned int hook_mask = table->valid_hooks; + uint8_t i, num_hooks = hweight32(hook_mask); + uint8_t hooknum; + struct nf_hook_ops *ops; + + if (!num_hooks) + return ERR_PTR(-EINVAL); + + ops = kcalloc(num_hooks, sizeof(*ops), GFP_KERNEL); + if (ops == NULL) + return ERR_PTR(-ENOMEM); + + for (i = 0, hooknum = 0; i < num_hooks && hook_mask != 0; + hook_mask >>= 1, ++hooknum) { + if (!(hook_mask & 1)) + continue; + ops[i].hook = fn; + ops[i].pf = table->af; + ops[i].hooknum = hooknum; + ops[i].priority = table->priority; + ++i; + } + + return ops; +} +EXPORT_SYMBOL_GPL(xt_hook_ops_alloc); + +int xt_register_template(const struct xt_table *table, + int (*table_init)(struct net *net)) +{ + int ret = -EEXIST, af = table->af; + struct xt_template *t; + + mutex_lock(&xt[af].mutex); + + list_for_each_entry(t, &xt_templates[af], list) { + if (WARN_ON_ONCE(strcmp(table->name, t->name) == 0)) + goto out_unlock; + } + + ret = -ENOMEM; + t = kzalloc(sizeof(*t), GFP_KERNEL); + if (!t) + goto out_unlock; + + BUILD_BUG_ON(sizeof(t->name) != sizeof(table->name)); + + strscpy(t->name, table->name, sizeof(t->name)); + t->table_init = table_init; + t->me = table->me; + list_add(&t->list, &xt_templates[af]); + ret = 0; +out_unlock: + mutex_unlock(&xt[af].mutex); + return ret; +} +EXPORT_SYMBOL_GPL(xt_register_template); + +void xt_unregister_template(const struct xt_table *table) +{ + struct xt_template *t; + int af = table->af; + + mutex_lock(&xt[af].mutex); + list_for_each_entry(t, &xt_templates[af], list) { + if (strcmp(table->name, t->name)) + continue; + + list_del(&t->list); + mutex_unlock(&xt[af].mutex); + kfree(t); + return; + } + + mutex_unlock(&xt[af].mutex); + WARN_ON_ONCE(1); +} +EXPORT_SYMBOL_GPL(xt_unregister_template); + +int xt_proto_init(struct net *net, u_int8_t af) +{ +#ifdef CONFIG_PROC_FS + char buf[XT_FUNCTION_MAXNAMELEN]; + struct proc_dir_entry *proc; + kuid_t root_uid; + kgid_t root_gid; +#endif + + if (af >= ARRAY_SIZE(xt_prefix)) + return -EINVAL; + + +#ifdef CONFIG_PROC_FS + root_uid = make_kuid(net->user_ns, 0); + root_gid = make_kgid(net->user_ns, 0); + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_TABLES, sizeof(buf)); + proc = proc_create_net_data(buf, 0440, net->proc_net, &xt_table_seq_ops, + sizeof(struct seq_net_private), + (void *)(unsigned long)af); + if (!proc) + goto out; + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(proc, root_uid, root_gid); + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_MATCHES, sizeof(buf)); + proc = proc_create_seq_private(buf, 0440, net->proc_net, + &xt_match_seq_ops, sizeof(struct nf_mttg_trav), + (void *)(unsigned long)af); + if (!proc) + goto out_remove_tables; + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(proc, root_uid, root_gid); + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_TARGETS, sizeof(buf)); + proc = proc_create_seq_private(buf, 0440, net->proc_net, + &xt_target_seq_ops, sizeof(struct nf_mttg_trav), + (void *)(unsigned long)af); + if (!proc) + goto out_remove_matches; + if (uid_valid(root_uid) && gid_valid(root_gid)) + proc_set_user(proc, root_uid, root_gid); +#endif + + return 0; + +#ifdef CONFIG_PROC_FS +out_remove_matches: + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_MATCHES, sizeof(buf)); + remove_proc_entry(buf, net->proc_net); + +out_remove_tables: + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, 
FORMAT_TABLES, sizeof(buf)); + remove_proc_entry(buf, net->proc_net); +out: + return -1; +#endif +} +EXPORT_SYMBOL_GPL(xt_proto_init); + +void xt_proto_fini(struct net *net, u_int8_t af) +{ +#ifdef CONFIG_PROC_FS + char buf[XT_FUNCTION_MAXNAMELEN]; + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_TABLES, sizeof(buf)); + remove_proc_entry(buf, net->proc_net); + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_TARGETS, sizeof(buf)); + remove_proc_entry(buf, net->proc_net); + + strscpy(buf, xt_prefix[af], sizeof(buf)); + strlcat(buf, FORMAT_MATCHES, sizeof(buf)); + remove_proc_entry(buf, net->proc_net); +#endif /*CONFIG_PROC_FS*/ +} +EXPORT_SYMBOL_GPL(xt_proto_fini); + +/** + * xt_percpu_counter_alloc - allocate x_tables rule counter + * + * @state: pointer to xt_percpu allocation state + * @counter: pointer to counter struct inside the ip(6)/arpt_entry struct + * + * On SMP, the packet counter [ ip(6)t_entry->counters.pcnt ] will then + * contain the address of the real (percpu) counter. + * + * Rule evaluation needs to use xt_get_this_cpu_counter() helper + * to fetch the real percpu counter. + * + * To speed up allocation and improve data locality, a 4kb block is + * allocated. Freeing any counter may free an entire block, so all + * counters allocated using the same state must be freed at the same + * time. + * + * xt_percpu_counter_alloc_state contains the base address of the + * allocated page and the current sub-offset. + * + * returns false on error. + */ +bool xt_percpu_counter_alloc(struct xt_percpu_counter_alloc_state *state, + struct xt_counters *counter) +{ + BUILD_BUG_ON(XT_PCPU_BLOCK_SIZE < (sizeof(*counter) * 2)); + + if (nr_cpu_ids <= 1) + return true; + + if (!state->mem) { + state->mem = __alloc_percpu(XT_PCPU_BLOCK_SIZE, + XT_PCPU_BLOCK_SIZE); + if (!state->mem) + return false; + } + counter->pcnt = (__force unsigned long)(state->mem + state->off); + state->off += sizeof(*counter); + if (state->off > (XT_PCPU_BLOCK_SIZE - sizeof(*counter))) { + state->mem = NULL; + state->off = 0; + } + return true; +} +EXPORT_SYMBOL_GPL(xt_percpu_counter_alloc); + +void xt_percpu_counter_free(struct xt_counters *counters) +{ + unsigned long pcnt = counters->pcnt; + + if (nr_cpu_ids > 1 && (pcnt & (XT_PCPU_BLOCK_SIZE - 1)) == 0) + free_percpu((void __percpu *)pcnt); +} +EXPORT_SYMBOL_GPL(xt_percpu_counter_free); + +static int __net_init xt_net_init(struct net *net) +{ + struct xt_pernet *xt_net = net_generic(net, xt_pernet_id); + int i; + + for (i = 0; i < NFPROTO_NUMPROTO; i++) + INIT_LIST_HEAD(&xt_net->tables[i]); + return 0; +} + +static void __net_exit xt_net_exit(struct net *net) +{ + struct xt_pernet *xt_net = net_generic(net, xt_pernet_id); + int i; + + for (i = 0; i < NFPROTO_NUMPROTO; i++) + WARN_ON_ONCE(!list_empty(&xt_net->tables[i])); +} + +static struct pernet_operations xt_net_ops = { + .init = xt_net_init, + .exit = xt_net_exit, + .id = &xt_pernet_id, + .size = sizeof(struct xt_pernet), +}; + +static int __init xt_init(void) +{ + unsigned int i; + int rv; + + for_each_possible_cpu(i) { + seqcount_init(&per_cpu(xt_recseq, i)); + } + + xt = kcalloc(NFPROTO_NUMPROTO, sizeof(struct xt_af), GFP_KERNEL); + if (!xt) + return -ENOMEM; + + for (i = 0; i < NFPROTO_NUMPROTO; i++) { + mutex_init(&xt[i].mutex); +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + mutex_init(&xt[i].compat_mutex); + xt[i].compat_tab = NULL; +#endif + INIT_LIST_HEAD(&xt[i].target); + INIT_LIST_HEAD(&xt[i].match); + INIT_LIST_HEAD(&xt_templates[i]); + } + rv = 
register_pernet_subsys(&xt_net_ops); + if (rv < 0) + kfree(xt); + return rv; +} + +static void __exit xt_fini(void) +{ + unregister_pernet_subsys(&xt_net_ops); + kfree(xt); +} + +module_init(xt_init); +module_exit(xt_fini); + diff --git a/net/netfilter/xt_AUDIT.c b/net/netfilter/xt_AUDIT.c new file mode 100644 index 000000000..b6a015aee --- /dev/null +++ b/net/netfilter/xt_AUDIT.c @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Creates audit record for dropped/accepted packets + * + * (C) 2010-2011 Thomas Graf + * (C) 2010-2011 Red Hat, Inc. +*/ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Thomas Graf "); +MODULE_DESCRIPTION("Xtables: creates audit records for dropped/accepted packets"); +MODULE_ALIAS("ipt_AUDIT"); +MODULE_ALIAS("ip6t_AUDIT"); +MODULE_ALIAS("ebt_AUDIT"); +MODULE_ALIAS("arpt_AUDIT"); + +static bool audit_ip4(struct audit_buffer *ab, struct sk_buff *skb) +{ + struct iphdr _iph; + const struct iphdr *ih; + + ih = skb_header_pointer(skb, skb_network_offset(skb), sizeof(_iph), &_iph); + if (!ih) + return false; + + audit_log_format(ab, " saddr=%pI4 daddr=%pI4 proto=%hhu", + &ih->saddr, &ih->daddr, ih->protocol); + + return true; +} + +static bool audit_ip6(struct audit_buffer *ab, struct sk_buff *skb) +{ + struct ipv6hdr _ip6h; + const struct ipv6hdr *ih; + u8 nexthdr; + __be16 frag_off; + + ih = skb_header_pointer(skb, skb_network_offset(skb), sizeof(_ip6h), &_ip6h); + if (!ih) + return false; + + nexthdr = ih->nexthdr; + ipv6_skip_exthdr(skb, skb_network_offset(skb) + sizeof(_ip6h), &nexthdr, &frag_off); + + audit_log_format(ab, " saddr=%pI6c daddr=%pI6c proto=%hhu", + &ih->saddr, &ih->daddr, nexthdr); + + return true; +} + +static unsigned int +audit_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct audit_buffer *ab; + int fam = -1; + + if (audit_enabled == AUDIT_OFF) + goto errout; + ab = audit_log_start(NULL, GFP_ATOMIC, AUDIT_NETFILTER_PKT); + if (ab == NULL) + goto errout; + + audit_log_format(ab, "mark=%#x", skb->mark); + + switch (xt_family(par)) { + case NFPROTO_BRIDGE: + switch (eth_hdr(skb)->h_proto) { + case htons(ETH_P_IP): + fam = audit_ip4(ab, skb) ? NFPROTO_IPV4 : -1; + break; + case htons(ETH_P_IPV6): + fam = audit_ip6(ab, skb) ? NFPROTO_IPV6 : -1; + break; + } + break; + case NFPROTO_IPV4: + fam = audit_ip4(ab, skb) ? NFPROTO_IPV4 : -1; + break; + case NFPROTO_IPV6: + fam = audit_ip6(ab, skb) ? NFPROTO_IPV6 : -1; + break; + } + + if (fam == -1) + audit_log_format(ab, " saddr=? daddr=? 
proto=-1"); + + audit_log_end(ab); + +errout: + return XT_CONTINUE; +} + +static unsigned int +audit_tg_ebt(struct sk_buff *skb, const struct xt_action_param *par) +{ + audit_tg(skb, par); + return EBT_CONTINUE; +} + +static int audit_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_audit_info *info = par->targinfo; + + if (info->type > XT_AUDIT_TYPE_MAX) { + pr_info_ratelimited("Audit type out of range (valid range: 0..%u)\n", + XT_AUDIT_TYPE_MAX); + return -ERANGE; + } + + return 0; +} + +static struct xt_target audit_tg_reg[] __read_mostly = { + { + .name = "AUDIT", + .family = NFPROTO_UNSPEC, + .target = audit_tg, + .targetsize = sizeof(struct xt_audit_info), + .checkentry = audit_tg_check, + .me = THIS_MODULE, + }, + { + .name = "AUDIT", + .family = NFPROTO_BRIDGE, + .target = audit_tg_ebt, + .targetsize = sizeof(struct xt_audit_info), + .checkentry = audit_tg_check, + .me = THIS_MODULE, + }, +}; + +static int __init audit_tg_init(void) +{ + return xt_register_targets(audit_tg_reg, ARRAY_SIZE(audit_tg_reg)); +} + +static void __exit audit_tg_exit(void) +{ + xt_unregister_targets(audit_tg_reg, ARRAY_SIZE(audit_tg_reg)); +} + +module_init(audit_tg_init); +module_exit(audit_tg_exit); diff --git a/net/netfilter/xt_CHECKSUM.c b/net/netfilter/xt_CHECKSUM.c new file mode 100644 index 000000000..c8a639f56 --- /dev/null +++ b/net/netfilter/xt_CHECKSUM.c @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* iptables module for the packet checksum mangling + * + * (C) 2002 by Harald Welte + * (C) 2010 Red Hat, Inc. + * + * Author: Michael S. Tsirkin +*/ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include + +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Michael S. Tsirkin "); +MODULE_DESCRIPTION("Xtables: checksum modification"); +MODULE_ALIAS("ipt_CHECKSUM"); +MODULE_ALIAS("ip6t_CHECKSUM"); + +static unsigned int +checksum_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + if (skb->ip_summed == CHECKSUM_PARTIAL && !skb_is_gso(skb)) + skb_checksum_help(skb); + + return XT_CONTINUE; +} + +static int checksum_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_CHECKSUM_info *einfo = par->targinfo; + const struct ip6t_ip6 *i6 = par->entryinfo; + const struct ipt_ip *i4 = par->entryinfo; + + if (einfo->operation & ~XT_CHECKSUM_OP_FILL) { + pr_info_ratelimited("unsupported CHECKSUM operation %x\n", + einfo->operation); + return -EINVAL; + } + if (!einfo->operation) + return -EINVAL; + + switch (par->family) { + case NFPROTO_IPV4: + if (i4->proto == IPPROTO_UDP && + (i4->invflags & XT_INV_PROTO) == 0) + return 0; + break; + case NFPROTO_IPV6: + if ((i6->flags & IP6T_F_PROTO) && + i6->proto == IPPROTO_UDP && + (i6->invflags & XT_INV_PROTO) == 0) + return 0; + break; + } + + pr_warn_once("CHECKSUM should be avoided. 
If really needed, restrict with \"-p udp\" and only use in OUTPUT\n"); + return 0; +} + +static struct xt_target checksum_tg_reg __read_mostly = { + .name = "CHECKSUM", + .family = NFPROTO_UNSPEC, + .target = checksum_tg, + .targetsize = sizeof(struct xt_CHECKSUM_info), + .table = "mangle", + .checkentry = checksum_tg_check, + .me = THIS_MODULE, +}; + +static int __init checksum_tg_init(void) +{ + return xt_register_target(&checksum_tg_reg); +} + +static void __exit checksum_tg_exit(void) +{ + xt_unregister_target(&checksum_tg_reg); +} + +module_init(checksum_tg_init); +module_exit(checksum_tg_exit); diff --git a/net/netfilter/xt_CLASSIFY.c b/net/netfilter/xt_CLASSIFY.c new file mode 100644 index 000000000..0accac98d --- /dev/null +++ b/net/netfilter/xt_CLASSIFY.c @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for setting the skb->priority field + * of an skb for qdisc classification. + */ + +/* (C) 2001-2002 Patrick McHardy + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: Qdisc classification"); +MODULE_ALIAS("ipt_CLASSIFY"); +MODULE_ALIAS("ip6t_CLASSIFY"); +MODULE_ALIAS("arpt_CLASSIFY"); + +static unsigned int +classify_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_classify_target_info *clinfo = par->targinfo; + + skb->priority = clinfo->priority; + return XT_CONTINUE; +} + +static struct xt_target classify_tg_reg[] __read_mostly = { + { + .name = "CLASSIFY", + .revision = 0, + .family = NFPROTO_UNSPEC, + .hooks = (1 << NF_INET_LOCAL_OUT) | (1 << NF_INET_FORWARD) | + (1 << NF_INET_POST_ROUTING), + .target = classify_tg, + .targetsize = sizeof(struct xt_classify_target_info), + .me = THIS_MODULE, + }, + { + .name = "CLASSIFY", + .revision = 0, + .family = NFPROTO_ARP, + .hooks = (1 << NF_ARP_OUT) | (1 << NF_ARP_FORWARD), + .target = classify_tg, + .targetsize = sizeof(struct xt_classify_target_info), + .me = THIS_MODULE, + }, +}; + +static int __init classify_tg_init(void) +{ + return xt_register_targets(classify_tg_reg, ARRAY_SIZE(classify_tg_reg)); +} + +static void __exit classify_tg_exit(void) +{ + xt_unregister_targets(classify_tg_reg, ARRAY_SIZE(classify_tg_reg)); +} + +module_init(classify_tg_init); +module_exit(classify_tg_exit); diff --git a/net/netfilter/xt_CONNSECMARK.c b/net/netfilter/xt_CONNSECMARK.c new file mode 100644 index 000000000..76acecf3e --- /dev/null +++ b/net/netfilter/xt_CONNSECMARK.c @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This module is used to copy security markings from packets + * to connections, and restore security markings from connections + * back to packets. This would normally be performed in conjunction + * with the SECMARK target and state match. + * + * Based somewhat on CONNMARK: + * Copyright (C) 2002,2004 MARA Systems AB + * by Henrik Nordstrom + * + * (C) 2006,2008 Red Hat, Inc., James Morris + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Morris "); +MODULE_DESCRIPTION("Xtables: target for copying between connection and security mark"); +MODULE_ALIAS("ipt_CONNSECMARK"); +MODULE_ALIAS("ip6t_CONNSECMARK"); + +/* + * If the packet has a security mark and the connection does not, copy + * the security mark from the packet to the connection. 
+ */ +static void secmark_save(const struct sk_buff *skb) +{ + if (skb->secmark) { + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + + ct = nf_ct_get(skb, &ctinfo); + if (ct && !ct->secmark) { + ct->secmark = skb->secmark; + nf_conntrack_event_cache(IPCT_SECMARK, ct); + } + } +} + +/* + * If packet has no security mark, and the connection does, restore the + * security mark from the connection to the packet. + */ +static void secmark_restore(struct sk_buff *skb) +{ + if (!skb->secmark) { + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + + ct = nf_ct_get(skb, &ctinfo); + if (ct && ct->secmark) + skb->secmark = ct->secmark; + } +} + +static unsigned int +connsecmark_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_connsecmark_target_info *info = par->targinfo; + + switch (info->mode) { + case CONNSECMARK_SAVE: + secmark_save(skb); + break; + + case CONNSECMARK_RESTORE: + secmark_restore(skb); + break; + + default: + BUG(); + } + + return XT_CONTINUE; +} + +static int connsecmark_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_connsecmark_target_info *info = par->targinfo; + int ret; + + if (strcmp(par->table, "mangle") != 0 && + strcmp(par->table, "security") != 0) { + pr_info_ratelimited("only valid in \'mangle\' or \'security\' table, not \'%s\'\n", + par->table); + return -EINVAL; + } + + switch (info->mode) { + case CONNSECMARK_SAVE: + case CONNSECMARK_RESTORE: + break; + + default: + pr_info_ratelimited("invalid mode: %hu\n", info->mode); + return -EINVAL; + } + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void connsecmark_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_target connsecmark_tg_reg __read_mostly = { + .name = "CONNSECMARK", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = connsecmark_tg_check, + .destroy = connsecmark_tg_destroy, + .target = connsecmark_tg, + .targetsize = sizeof(struct xt_connsecmark_target_info), + .me = THIS_MODULE, +}; + +static int __init connsecmark_tg_init(void) +{ + return xt_register_target(&connsecmark_tg_reg); +} + +static void __exit connsecmark_tg_exit(void) +{ + xt_unregister_target(&connsecmark_tg_reg); +} + +module_init(connsecmark_tg_init); +module_exit(connsecmark_tg_exit); diff --git a/net/netfilter/xt_CT.c b/net/netfilter/xt_CT.c new file mode 100644 index 000000000..2be2f7a7b --- /dev/null +++ b/net/netfilter/xt_CT.c @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2010 Patrick McHardy + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static inline int xt_ct_target(struct sk_buff *skb, struct nf_conn *ct) +{ + /* Previously seen (loopback)? Ignore. 
*/ + if (skb->_nfct != 0) + return XT_CONTINUE; + + if (ct) { + refcount_inc(&ct->ct_general.use); + nf_ct_set(skb, ct, IP_CT_NEW); + } else { + nf_ct_set(skb, ct, IP_CT_UNTRACKED); + } + + return XT_CONTINUE; +} + +static unsigned int xt_ct_target_v0(struct sk_buff *skb, + const struct xt_action_param *par) +{ + const struct xt_ct_target_info *info = par->targinfo; + struct nf_conn *ct = info->ct; + + return xt_ct_target(skb, ct); +} + +static unsigned int xt_ct_target_v1(struct sk_buff *skb, + const struct xt_action_param *par) +{ + const struct xt_ct_target_info_v1 *info = par->targinfo; + struct nf_conn *ct = info->ct; + + return xt_ct_target(skb, ct); +} + +static u8 xt_ct_find_proto(const struct xt_tgchk_param *par) +{ + if (par->family == NFPROTO_IPV4) { + const struct ipt_entry *e = par->entryinfo; + + if (e->ip.invflags & IPT_INV_PROTO) + return 0; + return e->ip.proto; + } else if (par->family == NFPROTO_IPV6) { + const struct ip6t_entry *e = par->entryinfo; + + if (e->ipv6.invflags & IP6T_INV_PROTO) + return 0; + return e->ipv6.proto; + } else + return 0; +} + +static int +xt_ct_set_helper(struct nf_conn *ct, const char *helper_name, + const struct xt_tgchk_param *par) +{ + struct nf_conntrack_helper *helper; + struct nf_conn_help *help; + u8 proto; + + proto = xt_ct_find_proto(par); + if (!proto) { + pr_info_ratelimited("You must specify a L4 protocol and not use inversions on it\n"); + return -ENOENT; + } + + helper = nf_conntrack_helper_try_module_get(helper_name, par->family, + proto); + if (helper == NULL) { + pr_info_ratelimited("No such helper \"%s\"\n", helper_name); + return -ENOENT; + } + + help = nf_ct_helper_ext_add(ct, GFP_KERNEL); + if (help == NULL) { + nf_conntrack_helper_put(helper); + return -ENOMEM; + } + + rcu_assign_pointer(help->helper, helper); + return 0; +} + +static int +xt_ct_set_timeout(struct nf_conn *ct, const struct xt_tgchk_param *par, + const char *timeout_name) +{ +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + const struct nf_conntrack_l4proto *l4proto; + u8 proto; + + proto = xt_ct_find_proto(par); + if (!proto) { + pr_info_ratelimited("You must specify a L4 protocol and not " + "use inversions on it"); + return -EINVAL; + } + l4proto = nf_ct_l4proto_find(proto); + return nf_ct_set_timeout(par->net, ct, par->family, l4proto->l4proto, + timeout_name); + +#else + return -EOPNOTSUPP; +#endif +} + +static u16 xt_ct_flags_to_dir(const struct xt_ct_target_info_v1 *info) +{ + switch (info->flags & (XT_CT_ZONE_DIR_ORIG | + XT_CT_ZONE_DIR_REPL)) { + case XT_CT_ZONE_DIR_ORIG: + return NF_CT_ZONE_DIR_ORIG; + case XT_CT_ZONE_DIR_REPL: + return NF_CT_ZONE_DIR_REPL; + default: + return NF_CT_DEFAULT_ZONE_DIR; + } +} + +static void xt_ct_put_helper(struct nf_conn_help *help) +{ + struct nf_conntrack_helper *helper; + + if (!help) + return; + + /* not yet exposed to other cpus, or ruleset + * already detached (post-replacement). 
+ */ + helper = rcu_dereference_raw(help->helper); + if (helper) + nf_conntrack_helper_put(helper); +} + +static int xt_ct_tg_check(const struct xt_tgchk_param *par, + struct xt_ct_target_info_v1 *info) +{ + struct nf_conntrack_zone zone; + struct nf_conn_help *help; + struct nf_conn *ct; + int ret = -EOPNOTSUPP; + + if (info->flags & XT_CT_NOTRACK) { + ct = NULL; + goto out; + } + +#ifndef CONFIG_NF_CONNTRACK_ZONES + if (info->zone || info->flags & (XT_CT_ZONE_DIR_ORIG | + XT_CT_ZONE_DIR_REPL | + XT_CT_ZONE_MARK)) + goto err1; +#endif + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + goto err1; + + memset(&zone, 0, sizeof(zone)); + zone.id = info->zone; + zone.dir = xt_ct_flags_to_dir(info); + if (info->flags & XT_CT_ZONE_MARK) + zone.flags |= NF_CT_FLAG_MARK; + + ct = nf_ct_tmpl_alloc(par->net, &zone, GFP_KERNEL); + if (!ct) { + ret = -ENOMEM; + goto err2; + } + + if ((info->ct_events || info->exp_events) && + !nf_ct_ecache_ext_add(ct, info->ct_events, info->exp_events, + GFP_KERNEL)) { + ret = -EINVAL; + goto err3; + } + + if (info->helper[0]) { + if (strnlen(info->helper, sizeof(info->helper)) == sizeof(info->helper)) { + ret = -ENAMETOOLONG; + goto err3; + } + + ret = xt_ct_set_helper(ct, info->helper, par); + if (ret < 0) + goto err3; + } + + if (info->timeout[0]) { + if (strnlen(info->timeout, sizeof(info->timeout)) == sizeof(info->timeout)) { + ret = -ENAMETOOLONG; + goto err4; + } + + ret = xt_ct_set_timeout(ct, par, info->timeout); + if (ret < 0) + goto err4; + } + __set_bit(IPS_CONFIRMED_BIT, &ct->status); +out: + info->ct = ct; + return 0; + +err4: + help = nfct_help(ct); + xt_ct_put_helper(help); +err3: + nf_ct_tmpl_free(ct); +err2: + nf_ct_netns_put(par->net, par->family); +err1: + return ret; +} + +static int xt_ct_tg_check_v0(const struct xt_tgchk_param *par) +{ + struct xt_ct_target_info *info = par->targinfo; + struct xt_ct_target_info_v1 info_v1 = { + .flags = info->flags, + .zone = info->zone, + .ct_events = info->ct_events, + .exp_events = info->exp_events, + }; + int ret; + + if (info->flags & ~XT_CT_NOTRACK) + return -EINVAL; + + memcpy(info_v1.helper, info->helper, sizeof(info->helper)); + + ret = xt_ct_tg_check(par, &info_v1); + if (ret < 0) + return ret; + + info->ct = info_v1.ct; + + return ret; +} + +static int xt_ct_tg_check_v1(const struct xt_tgchk_param *par) +{ + struct xt_ct_target_info_v1 *info = par->targinfo; + + if (info->flags & ~XT_CT_NOTRACK) + return -EINVAL; + + return xt_ct_tg_check(par, par->targinfo); +} + +static int xt_ct_tg_check_v2(const struct xt_tgchk_param *par) +{ + struct xt_ct_target_info_v1 *info = par->targinfo; + + if (info->flags & ~XT_CT_MASK) + return -EINVAL; + + return xt_ct_tg_check(par, par->targinfo); +} + +static void xt_ct_tg_destroy(const struct xt_tgdtor_param *par, + struct xt_ct_target_info_v1 *info) +{ + struct nf_conn *ct = info->ct; + struct nf_conn_help *help; + + if (ct) { + help = nfct_help(ct); + xt_ct_put_helper(help); + + nf_ct_netns_put(par->net, par->family); + + nf_ct_destroy_timeout(ct); + nf_ct_put(info->ct); + } +} + +static void xt_ct_tg_destroy_v0(const struct xt_tgdtor_param *par) +{ + struct xt_ct_target_info *info = par->targinfo; + struct xt_ct_target_info_v1 info_v1 = { + .flags = info->flags, + .zone = info->zone, + .ct_events = info->ct_events, + .exp_events = info->exp_events, + .ct = info->ct, + }; + memcpy(info_v1.helper, info->helper, sizeof(info->helper)); + + xt_ct_tg_destroy(par, &info_v1); +} + +static void xt_ct_tg_destroy_v1(const struct xt_tgdtor_param *par) +{ + 
xt_ct_tg_destroy(par, par->targinfo); +} + +static struct xt_target xt_ct_tg_reg[] __read_mostly = { + { + .name = "CT", + .family = NFPROTO_UNSPEC, + .targetsize = sizeof(struct xt_ct_target_info), + .usersize = offsetof(struct xt_ct_target_info, ct), + .checkentry = xt_ct_tg_check_v0, + .destroy = xt_ct_tg_destroy_v0, + .target = xt_ct_target_v0, + .table = "raw", + .me = THIS_MODULE, + }, + { + .name = "CT", + .family = NFPROTO_UNSPEC, + .revision = 1, + .targetsize = sizeof(struct xt_ct_target_info_v1), + .usersize = offsetof(struct xt_ct_target_info, ct), + .checkentry = xt_ct_tg_check_v1, + .destroy = xt_ct_tg_destroy_v1, + .target = xt_ct_target_v1, + .table = "raw", + .me = THIS_MODULE, + }, + { + .name = "CT", + .family = NFPROTO_UNSPEC, + .revision = 2, + .targetsize = sizeof(struct xt_ct_target_info_v1), + .usersize = offsetof(struct xt_ct_target_info, ct), + .checkentry = xt_ct_tg_check_v2, + .destroy = xt_ct_tg_destroy_v1, + .target = xt_ct_target_v1, + .table = "raw", + .me = THIS_MODULE, + }, +}; + +static unsigned int +notrack_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + /* Previously seen (loopback)? Ignore. */ + if (skb->_nfct != 0) + return XT_CONTINUE; + + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + + return XT_CONTINUE; +} + +static struct xt_target notrack_tg_reg __read_mostly = { + .name = "NOTRACK", + .revision = 0, + .family = NFPROTO_UNSPEC, + .target = notrack_tg, + .table = "raw", + .me = THIS_MODULE, +}; + +static int __init xt_ct_tg_init(void) +{ + int ret; + + ret = xt_register_target(¬rack_tg_reg); + if (ret < 0) + return ret; + + ret = xt_register_targets(xt_ct_tg_reg, ARRAY_SIZE(xt_ct_tg_reg)); + if (ret < 0) { + xt_unregister_target(¬rack_tg_reg); + return ret; + } + return 0; +} + +static void __exit xt_ct_tg_exit(void) +{ + xt_unregister_targets(xt_ct_tg_reg, ARRAY_SIZE(xt_ct_tg_reg)); + xt_unregister_target(¬rack_tg_reg); +} + +module_init(xt_ct_tg_init); +module_exit(xt_ct_tg_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: connection tracking target"); +MODULE_ALIAS("ipt_CT"); +MODULE_ALIAS("ip6t_CT"); +MODULE_ALIAS("ipt_NOTRACK"); +MODULE_ALIAS("ip6t_NOTRACK"); diff --git a/net/netfilter/xt_DSCP.c b/net/netfilter/xt_DSCP.c new file mode 100644 index 000000000..cfa44515a --- /dev/null +++ b/net/netfilter/xt_DSCP.c @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* x_tables module for setting the IPv4/IPv6 DSCP field, Version 1.8 + * + * (C) 2002 by Harald Welte + * based on ipt_FTOS.c (C) 2000 by Matthew G. Marsh + * + * See RFC2474 for a description of the DSCP field within the IP Header. 
+*/ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: DSCP/TOS field modification"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_DSCP"); +MODULE_ALIAS("ip6t_DSCP"); +MODULE_ALIAS("ipt_TOS"); +MODULE_ALIAS("ip6t_TOS"); + +#define XT_DSCP_ECN_MASK 3u + +static unsigned int +dscp_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_DSCP_info *dinfo = par->targinfo; + u_int8_t dscp = ipv4_get_dsfield(ip_hdr(skb)) >> XT_DSCP_SHIFT; + + if (dscp != dinfo->dscp) { + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + return NF_DROP; + + ipv4_change_dsfield(ip_hdr(skb), XT_DSCP_ECN_MASK, + dinfo->dscp << XT_DSCP_SHIFT); + + } + return XT_CONTINUE; +} + +static unsigned int +dscp_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_DSCP_info *dinfo = par->targinfo; + u_int8_t dscp = ipv6_get_dsfield(ipv6_hdr(skb)) >> XT_DSCP_SHIFT; + + if (dscp != dinfo->dscp) { + if (skb_ensure_writable(skb, sizeof(struct ipv6hdr))) + return NF_DROP; + + ipv6_change_dsfield(ipv6_hdr(skb), XT_DSCP_ECN_MASK, + dinfo->dscp << XT_DSCP_SHIFT); + } + return XT_CONTINUE; +} + +static int dscp_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_DSCP_info *info = par->targinfo; + + if (info->dscp > XT_DSCP_MAX) + return -EDOM; + return 0; +} + +static unsigned int +tos_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tos_target_info *info = par->targinfo; + struct iphdr *iph = ip_hdr(skb); + u_int8_t orig, nv; + + orig = ipv4_get_dsfield(iph); + nv = (orig & ~info->tos_mask) ^ info->tos_value; + + if (orig != nv) { + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + return NF_DROP; + iph = ip_hdr(skb); + ipv4_change_dsfield(iph, 0, nv); + } + + return XT_CONTINUE; +} + +static unsigned int +tos_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tos_target_info *info = par->targinfo; + struct ipv6hdr *iph = ipv6_hdr(skb); + u_int8_t orig, nv; + + orig = ipv6_get_dsfield(iph); + nv = (orig & ~info->tos_mask) ^ info->tos_value; + + if (orig != nv) { + if (skb_ensure_writable(skb, sizeof(struct iphdr))) + return NF_DROP; + iph = ipv6_hdr(skb); + ipv6_change_dsfield(iph, 0, nv); + } + + return XT_CONTINUE; +} + +static struct xt_target dscp_tg_reg[] __read_mostly = { + { + .name = "DSCP", + .family = NFPROTO_IPV4, + .checkentry = dscp_tg_check, + .target = dscp_tg, + .targetsize = sizeof(struct xt_DSCP_info), + .table = "mangle", + .me = THIS_MODULE, + }, + { + .name = "DSCP", + .family = NFPROTO_IPV6, + .checkentry = dscp_tg_check, + .target = dscp_tg6, + .targetsize = sizeof(struct xt_DSCP_info), + .table = "mangle", + .me = THIS_MODULE, + }, + { + .name = "TOS", + .revision = 1, + .family = NFPROTO_IPV4, + .table = "mangle", + .target = tos_tg, + .targetsize = sizeof(struct xt_tos_target_info), + .me = THIS_MODULE, + }, + { + .name = "TOS", + .revision = 1, + .family = NFPROTO_IPV6, + .table = "mangle", + .target = tos_tg6, + .targetsize = sizeof(struct xt_tos_target_info), + .me = THIS_MODULE, + }, +}; + +static int __init dscp_tg_init(void) +{ + return xt_register_targets(dscp_tg_reg, ARRAY_SIZE(dscp_tg_reg)); +} + +static void __exit dscp_tg_exit(void) +{ + xt_unregister_targets(dscp_tg_reg, ARRAY_SIZE(dscp_tg_reg)); +} + +module_init(dscp_tg_init); +module_exit(dscp_tg_exit); diff --git a/net/netfilter/xt_HL.c b/net/netfilter/xt_HL.c new file mode 100644 index 
000000000..7873b834c --- /dev/null +++ b/net/netfilter/xt_HL.c @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * TTL modification target for IP tables + * (C) 2000,2005 by Harald Welte + * + * Hop Limit modification target for ip6tables + * Maciej Soltysiak + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Harald Welte "); +MODULE_AUTHOR("Maciej Soltysiak "); +MODULE_DESCRIPTION("Xtables: Hoplimit/TTL Limit field modification target"); +MODULE_LICENSE("GPL"); + +static unsigned int +ttl_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct iphdr *iph; + const struct ipt_TTL_info *info = par->targinfo; + int new_ttl; + + if (skb_ensure_writable(skb, sizeof(*iph))) + return NF_DROP; + + iph = ip_hdr(skb); + + switch (info->mode) { + case IPT_TTL_SET: + new_ttl = info->ttl; + break; + case IPT_TTL_INC: + new_ttl = iph->ttl + info->ttl; + if (new_ttl > 255) + new_ttl = 255; + break; + case IPT_TTL_DEC: + new_ttl = iph->ttl - info->ttl; + if (new_ttl < 0) + new_ttl = 0; + break; + default: + new_ttl = iph->ttl; + break; + } + + if (new_ttl != iph->ttl) { + csum_replace2(&iph->check, htons(iph->ttl << 8), + htons(new_ttl << 8)); + iph->ttl = new_ttl; + } + + return XT_CONTINUE; +} + +static unsigned int +hl_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct ipv6hdr *ip6h; + const struct ip6t_HL_info *info = par->targinfo; + int new_hl; + + if (skb_ensure_writable(skb, sizeof(*ip6h))) + return NF_DROP; + + ip6h = ipv6_hdr(skb); + + switch (info->mode) { + case IP6T_HL_SET: + new_hl = info->hop_limit; + break; + case IP6T_HL_INC: + new_hl = ip6h->hop_limit + info->hop_limit; + if (new_hl > 255) + new_hl = 255; + break; + case IP6T_HL_DEC: + new_hl = ip6h->hop_limit - info->hop_limit; + if (new_hl < 0) + new_hl = 0; + break; + default: + new_hl = ip6h->hop_limit; + break; + } + + ip6h->hop_limit = new_hl; + + return XT_CONTINUE; +} + +static int ttl_tg_check(const struct xt_tgchk_param *par) +{ + const struct ipt_TTL_info *info = par->targinfo; + + if (info->mode > IPT_TTL_MAXMODE) + return -EINVAL; + if (info->mode != IPT_TTL_SET && info->ttl == 0) + return -EINVAL; + return 0; +} + +static int hl_tg6_check(const struct xt_tgchk_param *par) +{ + const struct ip6t_HL_info *info = par->targinfo; + + if (info->mode > IP6T_HL_MAXMODE) + return -EINVAL; + if (info->mode != IP6T_HL_SET && info->hop_limit == 0) + return -EINVAL; + return 0; +} + +static struct xt_target hl_tg_reg[] __read_mostly = { + { + .name = "TTL", + .revision = 0, + .family = NFPROTO_IPV4, + .target = ttl_tg, + .targetsize = sizeof(struct ipt_TTL_info), + .table = "mangle", + .checkentry = ttl_tg_check, + .me = THIS_MODULE, + }, + { + .name = "HL", + .revision = 0, + .family = NFPROTO_IPV6, + .target = hl_tg6, + .targetsize = sizeof(struct ip6t_HL_info), + .table = "mangle", + .checkentry = hl_tg6_check, + .me = THIS_MODULE, + }, +}; + +static int __init hl_tg_init(void) +{ + return xt_register_targets(hl_tg_reg, ARRAY_SIZE(hl_tg_reg)); +} + +static void __exit hl_tg_exit(void) +{ + xt_unregister_targets(hl_tg_reg, ARRAY_SIZE(hl_tg_reg)); +} + +module_init(hl_tg_init); +module_exit(hl_tg_exit); +MODULE_ALIAS("ipt_TTL"); +MODULE_ALIAS("ip6t_HL"); diff --git a/net/netfilter/xt_HMARK.c b/net/netfilter/xt_HMARK.c new file mode 100644 index 000000000..8928ec56c --- /dev/null +++ b/net/netfilter/xt_HMARK.c @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_HMARK - 
Netfilter module to set mark by means of hashing + * + * (C) 2012 by Hans Schillstrom + * (C) 2012 by Pablo Neira Ayuso + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +#include +#include + +#include +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +#include +#include +#endif + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Hans Schillstrom "); +MODULE_DESCRIPTION("Xtables: packet marking using hash calculation"); +MODULE_ALIAS("ipt_HMARK"); +MODULE_ALIAS("ip6t_HMARK"); + +struct hmark_tuple { + __be32 src; + __be32 dst; + union hmark_ports uports; + u8 proto; +}; + +static inline __be32 hmark_addr6_mask(const __be32 *addr32, const __be32 *mask) +{ + return (addr32[0] & mask[0]) ^ + (addr32[1] & mask[1]) ^ + (addr32[2] & mask[2]) ^ + (addr32[3] & mask[3]); +} + +static inline __be32 +hmark_addr_mask(int l3num, const __be32 *addr32, const __be32 *mask) +{ + switch (l3num) { + case AF_INET: + return *addr32 & *mask; + case AF_INET6: + return hmark_addr6_mask(addr32, mask); + } + return 0; +} + +static inline void hmark_swap_ports(union hmark_ports *uports, + const struct xt_hmark_info *info) +{ + union hmark_ports hp; + u16 src, dst; + + hp.b32 = (uports->b32 & info->port_mask.b32) | info->port_set.b32; + src = ntohs(hp.b16.src); + dst = ntohs(hp.b16.dst); + + if (dst > src) + uports->v32 = (dst << 16) | src; + else + uports->v32 = (src << 16) | dst; +} + +static int +hmark_ct_set_htuple(const struct sk_buff *skb, struct hmark_tuple *t, + const struct xt_hmark_info *info) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + struct nf_conntrack_tuple *otuple; + struct nf_conntrack_tuple *rtuple; + + if (ct == NULL) + return -1; + + otuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + rtuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + t->src = hmark_addr_mask(otuple->src.l3num, otuple->src.u3.ip6, + info->src_mask.ip6); + t->dst = hmark_addr_mask(otuple->src.l3num, rtuple->src.u3.ip6, + info->dst_mask.ip6); + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_METHOD_L3)) + return 0; + + t->proto = nf_ct_protonum(ct); + if (t->proto != IPPROTO_ICMP) { + t->uports.b16.src = otuple->src.u.all; + t->uports.b16.dst = rtuple->src.u.all; + hmark_swap_ports(&t->uports, info); + } + + return 0; +#else + return -1; +#endif +} + +/* This hash function is endian independent, to ensure consistent hashing if + * the cluster is composed of big and little endian systems. 
*/ +static inline u32 +hmark_hash(struct hmark_tuple *t, const struct xt_hmark_info *info) +{ + u32 hash; + u32 src = ntohl(t->src); + u32 dst = ntohl(t->dst); + + if (dst < src) + swap(src, dst); + + hash = jhash_3words(src, dst, t->uports.v32, info->hashrnd); + hash = hash ^ (t->proto & info->proto_mask); + + return reciprocal_scale(hash, info->hmodulus) + info->hoffset; +} + +static void +hmark_set_tuple_ports(const struct sk_buff *skb, unsigned int nhoff, + struct hmark_tuple *t, const struct xt_hmark_info *info) +{ + int protoff; + + protoff = proto_ports_offset(t->proto); + if (protoff < 0) + return; + + nhoff += protoff; + if (skb_copy_bits(skb, nhoff, &t->uports, sizeof(t->uports)) < 0) + return; + + hmark_swap_ports(&t->uports, info); +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static int get_inner6_hdr(const struct sk_buff *skb, int *offset) +{ + struct icmp6hdr *icmp6h, _ih6; + + icmp6h = skb_header_pointer(skb, *offset, sizeof(_ih6), &_ih6); + if (icmp6h == NULL) + return 0; + + if (icmp6h->icmp6_type && icmp6h->icmp6_type < 128) { + *offset += sizeof(struct icmp6hdr); + return 1; + } + return 0; +} + +static int +hmark_pkt_set_htuple_ipv6(const struct sk_buff *skb, struct hmark_tuple *t, + const struct xt_hmark_info *info) +{ + struct ipv6hdr *ip6, _ip6; + int flag = IP6_FH_F_AUTH; + unsigned int nhoff = 0; + u16 fragoff = 0; + int nexthdr; + + ip6 = (struct ipv6hdr *) (skb->data + skb_network_offset(skb)); + nexthdr = ipv6_find_hdr(skb, &nhoff, -1, &fragoff, &flag); + if (nexthdr < 0) + return 0; + /* No need to check for icmp errors on fragments */ + if ((flag & IP6_FH_F_FRAG) || (nexthdr != IPPROTO_ICMPV6)) + goto noicmp; + /* Use inner header in case of ICMP errors */ + if (get_inner6_hdr(skb, &nhoff)) { + ip6 = skb_header_pointer(skb, nhoff, sizeof(_ip6), &_ip6); + if (ip6 == NULL) + return -1; + /* If AH present, use SPI like in ESP. */ + flag = IP6_FH_F_AUTH; + nexthdr = ipv6_find_hdr(skb, &nhoff, -1, &fragoff, &flag); + if (nexthdr < 0) + return -1; + } +noicmp: + t->src = hmark_addr6_mask(ip6->saddr.s6_addr32, info->src_mask.ip6); + t->dst = hmark_addr6_mask(ip6->daddr.s6_addr32, info->dst_mask.ip6); + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_METHOD_L3)) + return 0; + + t->proto = nexthdr; + if (t->proto == IPPROTO_ICMPV6) + return 0; + + if (flag & IP6_FH_F_FRAG) + return 0; + + hmark_set_tuple_ports(skb, nhoff, t, info); + return 0; +} + +static unsigned int +hmark_tg_v6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_hmark_info *info = par->targinfo; + struct hmark_tuple t; + + memset(&t, 0, sizeof(struct hmark_tuple)); + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_CT)) { + if (hmark_ct_set_htuple(skb, &t, info) < 0) + return XT_CONTINUE; + } else { + if (hmark_pkt_set_htuple_ipv6(skb, &t, info) < 0) + return XT_CONTINUE; + } + + skb->mark = hmark_hash(&t, info); + return XT_CONTINUE; +} +#endif + +static int get_inner_hdr(const struct sk_buff *skb, int iphsz, int *nhoff) +{ + const struct icmphdr *icmph; + struct icmphdr _ih; + + /* Not enough header? */ + icmph = skb_header_pointer(skb, *nhoff + iphsz, sizeof(_ih), &_ih); + if (icmph == NULL || icmph->type > NR_ICMP_TYPES) + return 0; + + /* Error message? 
*/ + if (!icmp_is_err(icmph->type)) + return 0; + + *nhoff += iphsz + sizeof(_ih); + return 1; +} + +static int +hmark_pkt_set_htuple_ipv4(const struct sk_buff *skb, struct hmark_tuple *t, + const struct xt_hmark_info *info) +{ + struct iphdr *ip, _ip; + int nhoff = skb_network_offset(skb); + + ip = (struct iphdr *) (skb->data + nhoff); + if (ip->protocol == IPPROTO_ICMP) { + /* Use inner header in case of ICMP errors */ + if (get_inner_hdr(skb, ip->ihl * 4, &nhoff)) { + ip = skb_header_pointer(skb, nhoff, sizeof(_ip), &_ip); + if (ip == NULL) + return -1; + } + } + + t->src = ip->saddr & info->src_mask.ip; + t->dst = ip->daddr & info->dst_mask.ip; + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_METHOD_L3)) + return 0; + + t->proto = ip->protocol; + + /* ICMP has no ports, skip */ + if (t->proto == IPPROTO_ICMP) + return 0; + + /* follow-up fragments don't contain ports, skip all fragments */ + if (ip_is_fragment(ip)) + return 0; + + hmark_set_tuple_ports(skb, (ip->ihl * 4) + nhoff, t, info); + + return 0; +} + +static unsigned int +hmark_tg_v4(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_hmark_info *info = par->targinfo; + struct hmark_tuple t; + + memset(&t, 0, sizeof(struct hmark_tuple)); + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_CT)) { + if (hmark_ct_set_htuple(skb, &t, info) < 0) + return XT_CONTINUE; + } else { + if (hmark_pkt_set_htuple_ipv4(skb, &t, info) < 0) + return XT_CONTINUE; + } + + skb->mark = hmark_hash(&t, info); + return XT_CONTINUE; +} + +static int hmark_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_hmark_info *info = par->targinfo; + const char *errmsg = "proto mask must be zero with L3 mode"; + + if (!info->hmodulus) + return -EINVAL; + + if (info->proto_mask && + (info->flags & XT_HMARK_FLAG(XT_HMARK_METHOD_L3))) + goto err; + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_SPI_MASK) && + (info->flags & (XT_HMARK_FLAG(XT_HMARK_SPORT_MASK) | + XT_HMARK_FLAG(XT_HMARK_DPORT_MASK)))) + return -EINVAL; + + if (info->flags & XT_HMARK_FLAG(XT_HMARK_SPI) && + (info->flags & (XT_HMARK_FLAG(XT_HMARK_SPORT) | + XT_HMARK_FLAG(XT_HMARK_DPORT)))) { + errmsg = "spi-set and port-set can't be combined"; + goto err; + } + return 0; +err: + pr_info_ratelimited("%s\n", errmsg); + return -EINVAL; +} + +static struct xt_target hmark_tg_reg[] __read_mostly = { + { + .name = "HMARK", + .family = NFPROTO_IPV4, + .target = hmark_tg_v4, + .targetsize = sizeof(struct xt_hmark_info), + .checkentry = hmark_tg_check, + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "HMARK", + .family = NFPROTO_IPV6, + .target = hmark_tg_v6, + .targetsize = sizeof(struct xt_hmark_info), + .checkentry = hmark_tg_check, + .me = THIS_MODULE, + }, +#endif +}; + +static int __init hmark_tg_init(void) +{ + return xt_register_targets(hmark_tg_reg, ARRAY_SIZE(hmark_tg_reg)); +} + +static void __exit hmark_tg_exit(void) +{ + xt_unregister_targets(hmark_tg_reg, ARRAY_SIZE(hmark_tg_reg)); +} + +module_init(hmark_tg_init); +module_exit(hmark_tg_exit); diff --git a/net/netfilter/xt_IDLETIMER.c b/net/netfilter/xt_IDLETIMER.c new file mode 100644 index 000000000..0f8bb0bf5 --- /dev/null +++ b/net/netfilter/xt_IDLETIMER.c @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/netfilter/xt_IDLETIMER.c + * + * Netfilter module to trigger a timer when packet matches. + * After timer expires a kevent will be sent. 
+ * + * Copyright (C) 2004, 2010 Nokia Corporation + * Written by Timo Teras + * + * Converted to x_tables and reworked for upstream inclusion + * by Luciano Coelho + * + * Contact: Luciano Coelho + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct idletimer_tg { + struct list_head entry; + struct alarm alarm; + struct timer_list timer; + struct work_struct work; + + struct kobject *kobj; + struct device_attribute attr; + + unsigned int refcnt; + u8 timer_type; +}; + +static LIST_HEAD(idletimer_tg_list); +static DEFINE_MUTEX(list_mutex); + +static struct kobject *idletimer_tg_kobj; + +static +struct idletimer_tg *__idletimer_tg_find_by_label(const char *label) +{ + struct idletimer_tg *entry; + + list_for_each_entry(entry, &idletimer_tg_list, entry) { + if (!strcmp(label, entry->attr.attr.name)) + return entry; + } + + return NULL; +} + +static ssize_t idletimer_tg_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct idletimer_tg *timer; + unsigned long expires = 0; + struct timespec64 ktimespec = {}; + long time_diff = 0; + + mutex_lock(&list_mutex); + + timer = __idletimer_tg_find_by_label(attr->attr.name); + if (timer) { + if (timer->timer_type & XT_IDLETIMER_ALARM) { + ktime_t expires_alarm = alarm_expires_remaining(&timer->alarm); + ktimespec = ktime_to_timespec64(expires_alarm); + time_diff = ktimespec.tv_sec; + } else { + expires = timer->timer.expires; + time_diff = jiffies_to_msecs(expires - jiffies) / 1000; + } + } + + mutex_unlock(&list_mutex); + + if (time_after(expires, jiffies) || ktimespec.tv_sec > 0) + return sysfs_emit(buf, "%ld\n", time_diff); + + return sysfs_emit(buf, "0\n"); +} + +static void idletimer_tg_work(struct work_struct *work) +{ + struct idletimer_tg *timer = container_of(work, struct idletimer_tg, + work); + + sysfs_notify(idletimer_tg_kobj, NULL, timer->attr.attr.name); +} + +static void idletimer_tg_expired(struct timer_list *t) +{ + struct idletimer_tg *timer = from_timer(timer, t, timer); + + pr_debug("timer %s expired\n", timer->attr.attr.name); + + schedule_work(&timer->work); +} + +static enum alarmtimer_restart idletimer_tg_alarmproc(struct alarm *alarm, + ktime_t now) +{ + struct idletimer_tg *timer = alarm->data; + + pr_debug("alarm %s expired\n", timer->attr.attr.name); + schedule_work(&timer->work); + return ALARMTIMER_NORESTART; +} + +static int idletimer_check_sysfs_name(const char *name, unsigned int size) +{ + int ret; + + ret = xt_check_proc_name(name, size); + if (ret < 0) + return ret; + + if (!strcmp(name, "power") || + !strcmp(name, "subsystem") || + !strcmp(name, "uevent")) + return -EINVAL; + + return 0; +} + +static int idletimer_tg_create(struct idletimer_tg_info *info) +{ + int ret; + + info->timer = kzalloc(sizeof(*info->timer), GFP_KERNEL); + if (!info->timer) { + ret = -ENOMEM; + goto out; + } + + ret = idletimer_check_sysfs_name(info->label, sizeof(info->label)); + if (ret < 0) + goto out_free_timer; + + sysfs_attr_init(&info->timer->attr.attr); + info->timer->attr.attr.name = kstrdup(info->label, GFP_KERNEL); + if (!info->timer->attr.attr.name) { + ret = -ENOMEM; + goto out_free_timer; + } + info->timer->attr.attr.mode = 0444; + info->timer->attr.show = idletimer_tg_show; + + ret = sysfs_create_file(idletimer_tg_kobj, &info->timer->attr.attr); + if (ret < 0) { + pr_debug("couldn't add file to sysfs"); + goto out_free_attr; + } + + list_add(&info->timer->entry, 
&idletimer_tg_list); + + timer_setup(&info->timer->timer, idletimer_tg_expired, 0); + info->timer->refcnt = 1; + + INIT_WORK(&info->timer->work, idletimer_tg_work); + + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + + return 0; + +out_free_attr: + kfree(info->timer->attr.attr.name); +out_free_timer: + kfree(info->timer); +out: + return ret; +} + +static int idletimer_tg_create_v1(struct idletimer_tg_info_v1 *info) +{ + int ret; + + info->timer = kmalloc(sizeof(*info->timer), GFP_KERNEL); + if (!info->timer) { + ret = -ENOMEM; + goto out; + } + + ret = idletimer_check_sysfs_name(info->label, sizeof(info->label)); + if (ret < 0) + goto out_free_timer; + + sysfs_attr_init(&info->timer->attr.attr); + info->timer->attr.attr.name = kstrdup(info->label, GFP_KERNEL); + if (!info->timer->attr.attr.name) { + ret = -ENOMEM; + goto out_free_timer; + } + info->timer->attr.attr.mode = 0444; + info->timer->attr.show = idletimer_tg_show; + + ret = sysfs_create_file(idletimer_tg_kobj, &info->timer->attr.attr); + if (ret < 0) { + pr_debug("couldn't add file to sysfs"); + goto out_free_attr; + } + + /* notify userspace */ + kobject_uevent(idletimer_tg_kobj,KOBJ_ADD); + + list_add(&info->timer->entry, &idletimer_tg_list); + pr_debug("timer type value is %u", info->timer_type); + info->timer->timer_type = info->timer_type; + info->timer->refcnt = 1; + + INIT_WORK(&info->timer->work, idletimer_tg_work); + + if (info->timer->timer_type & XT_IDLETIMER_ALARM) { + ktime_t tout; + alarm_init(&info->timer->alarm, ALARM_BOOTTIME, + idletimer_tg_alarmproc); + info->timer->alarm.data = info->timer; + tout = ktime_set(info->timeout, 0); + alarm_start_relative(&info->timer->alarm, tout); + } else { + timer_setup(&info->timer->timer, idletimer_tg_expired, 0); + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + } + + return 0; + +out_free_attr: + kfree(info->timer->attr.attr.name); +out_free_timer: + kfree(info->timer); +out: + return ret; +} + +/* + * The actual xt_tables plugin. + */ +static unsigned int idletimer_tg_target(struct sk_buff *skb, + const struct xt_action_param *par) +{ + const struct idletimer_tg_info *info = par->targinfo; + + pr_debug("resetting timer %s, timeout period %u\n", + info->label, info->timeout); + + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + + return XT_CONTINUE; +} + +/* + * The actual xt_tables plugin. 
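+ * Editor's note, illustrative only: revision 1 adds the timer_type field; with
+ * XT_IDLETIMER_ALARM the expiry is driven by an ALARM_BOOTTIME alarmtimer, so the
+ * countdown keeps running across system suspend, whereas the default mode uses an
+ * ordinary jiffies-based timer. The userspace extension is assumed to expose this
+ * as an extra flag (e.g. --alarm) on a rule like the one in the header comment.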
+ */ +static unsigned int idletimer_tg_target_v1(struct sk_buff *skb, + const struct xt_action_param *par) +{ + const struct idletimer_tg_info_v1 *info = par->targinfo; + + pr_debug("resetting timer %s, timeout period %u\n", + info->label, info->timeout); + + if (info->timer->timer_type & XT_IDLETIMER_ALARM) { + ktime_t tout = ktime_set(info->timeout, 0); + alarm_start_relative(&info->timer->alarm, tout); + } else { + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + } + + return XT_CONTINUE; +} + +static int idletimer_tg_helper(struct idletimer_tg_info *info) +{ + if (info->timeout == 0) { + pr_debug("timeout value is zero\n"); + return -EINVAL; + } + if (info->timeout >= INT_MAX / 1000) { + pr_debug("timeout value is too big\n"); + return -EINVAL; + } + if (info->label[0] == '\0' || + strnlen(info->label, + MAX_IDLETIMER_LABEL_SIZE) == MAX_IDLETIMER_LABEL_SIZE) { + pr_debug("label is empty or not nul-terminated\n"); + return -EINVAL; + } + return 0; +} + + +static int idletimer_tg_checkentry(const struct xt_tgchk_param *par) +{ + struct idletimer_tg_info *info = par->targinfo; + int ret; + + pr_debug("checkentry targinfo%s\n", info->label); + + ret = idletimer_tg_helper(info); + if(ret < 0) + { + pr_debug("checkentry helper return invalid\n"); + return -EINVAL; + } + mutex_lock(&list_mutex); + + info->timer = __idletimer_tg_find_by_label(info->label); + if (info->timer) { + info->timer->refcnt++; + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + + pr_debug("increased refcnt of timer %s to %u\n", + info->label, info->timer->refcnt); + } else { + ret = idletimer_tg_create(info); + if (ret < 0) { + pr_debug("failed to create timer\n"); + mutex_unlock(&list_mutex); + return ret; + } + } + + mutex_unlock(&list_mutex); + return 0; +} + +static int idletimer_tg_checkentry_v1(const struct xt_tgchk_param *par) +{ + struct idletimer_tg_info_v1 *info = par->targinfo; + int ret; + + pr_debug("checkentry targinfo%s\n", info->label); + + if (info->send_nl_msg) + return -EOPNOTSUPP; + + ret = idletimer_tg_helper((struct idletimer_tg_info *)info); + if(ret < 0) + { + pr_debug("checkentry helper return invalid\n"); + return -EINVAL; + } + + if (info->timer_type > XT_IDLETIMER_ALARM) { + pr_debug("invalid value for timer type\n"); + return -EINVAL; + } + + mutex_lock(&list_mutex); + + info->timer = __idletimer_tg_find_by_label(info->label); + if (info->timer) { + if (info->timer->timer_type != info->timer_type) { + pr_debug("Adding/Replacing rule with same label and different timer type is not allowed\n"); + mutex_unlock(&list_mutex); + return -EINVAL; + } + + info->timer->refcnt++; + if (info->timer_type & XT_IDLETIMER_ALARM) { + /* calculate remaining expiry time */ + ktime_t tout = alarm_expires_remaining(&info->timer->alarm); + struct timespec64 ktimespec = ktime_to_timespec64(tout); + + if (ktimespec.tv_sec > 0) { + pr_debug("time_expiry_remaining %lld\n", + ktimespec.tv_sec); + alarm_start_relative(&info->timer->alarm, tout); + } + } else { + mod_timer(&info->timer->timer, + msecs_to_jiffies(info->timeout * 1000) + jiffies); + } + pr_debug("increased refcnt of timer %s to %u\n", + info->label, info->timer->refcnt); + } else { + ret = idletimer_tg_create_v1(info); + if (ret < 0) { + pr_debug("failed to create timer\n"); + mutex_unlock(&list_mutex); + return ret; + } + } + + mutex_unlock(&list_mutex); + return 0; +} + +static void idletimer_tg_destroy(const struct xt_tgdtor_param *par) +{ + const struct idletimer_tg_info *info = 
par->targinfo; + + pr_debug("destroy targinfo %s\n", info->label); + + mutex_lock(&list_mutex); + + if (--info->timer->refcnt == 0) { + pr_debug("deleting timer %s\n", info->label); + + list_del(&info->timer->entry); + del_timer_sync(&info->timer->timer); + cancel_work_sync(&info->timer->work); + sysfs_remove_file(idletimer_tg_kobj, &info->timer->attr.attr); + kfree(info->timer->attr.attr.name); + kfree(info->timer); + } else { + pr_debug("decreased refcnt of timer %s to %u\n", + info->label, info->timer->refcnt); + } + + mutex_unlock(&list_mutex); +} + +static void idletimer_tg_destroy_v1(const struct xt_tgdtor_param *par) +{ + const struct idletimer_tg_info_v1 *info = par->targinfo; + + pr_debug("destroy targinfo %s\n", info->label); + + mutex_lock(&list_mutex); + + if (--info->timer->refcnt == 0) { + pr_debug("deleting timer %s\n", info->label); + + list_del(&info->timer->entry); + if (info->timer->timer_type & XT_IDLETIMER_ALARM) { + alarm_cancel(&info->timer->alarm); + } else { + del_timer_sync(&info->timer->timer); + } + cancel_work_sync(&info->timer->work); + sysfs_remove_file(idletimer_tg_kobj, &info->timer->attr.attr); + kfree(info->timer->attr.attr.name); + kfree(info->timer); + } else { + pr_debug("decreased refcnt of timer %s to %u\n", + info->label, info->timer->refcnt); + } + + mutex_unlock(&list_mutex); +} + + +static struct xt_target idletimer_tg[] __read_mostly = { + { + .name = "IDLETIMER", + .family = NFPROTO_UNSPEC, + .target = idletimer_tg_target, + .targetsize = sizeof(struct idletimer_tg_info), + .usersize = offsetof(struct idletimer_tg_info, timer), + .checkentry = idletimer_tg_checkentry, + .destroy = idletimer_tg_destroy, + .me = THIS_MODULE, + }, + { + .name = "IDLETIMER", + .family = NFPROTO_UNSPEC, + .revision = 1, + .target = idletimer_tg_target_v1, + .targetsize = sizeof(struct idletimer_tg_info_v1), + .usersize = offsetof(struct idletimer_tg_info_v1, timer), + .checkentry = idletimer_tg_checkentry_v1, + .destroy = idletimer_tg_destroy_v1, + .me = THIS_MODULE, + }, + + +}; + +static struct class *idletimer_tg_class; + +static struct device *idletimer_tg_device; + +static int __init idletimer_tg_init(void) +{ + int err; + + idletimer_tg_class = class_create(THIS_MODULE, "xt_idletimer"); + err = PTR_ERR(idletimer_tg_class); + if (IS_ERR(idletimer_tg_class)) { + pr_debug("couldn't register device class\n"); + goto out; + } + + idletimer_tg_device = device_create(idletimer_tg_class, NULL, + MKDEV(0, 0), NULL, "timers"); + err = PTR_ERR(idletimer_tg_device); + if (IS_ERR(idletimer_tg_device)) { + pr_debug("couldn't register system device\n"); + goto out_class; + } + + idletimer_tg_kobj = &idletimer_tg_device->kobj; + + err = xt_register_targets(idletimer_tg, ARRAY_SIZE(idletimer_tg)); + + if (err < 0) { + pr_debug("couldn't register xt target\n"); + goto out_dev; + } + + return 0; +out_dev: + device_destroy(idletimer_tg_class, MKDEV(0, 0)); +out_class: + class_destroy(idletimer_tg_class); +out: + return err; +} + +static void __exit idletimer_tg_exit(void) +{ + xt_unregister_targets(idletimer_tg, ARRAY_SIZE(idletimer_tg)); + + device_destroy(idletimer_tg_class, MKDEV(0, 0)); + class_destroy(idletimer_tg_class); +} + +module_init(idletimer_tg_init); +module_exit(idletimer_tg_exit); + +MODULE_AUTHOR("Timo Teras "); +MODULE_AUTHOR("Luciano Coelho "); +MODULE_DESCRIPTION("Xtables: idle time monitor"); +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS("ipt_IDLETIMER"); +MODULE_ALIAS("ip6t_IDLETIMER"); diff --git a/net/netfilter/xt_LED.c b/net/netfilter/xt_LED.c new file mode 
100644 index 000000000..0371c387b --- /dev/null +++ b/net/netfilter/xt_LED.c @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_LED.c - netfilter target to make LEDs blink upon packet matches + * + * Copyright (C) 2008 Adam Nielsen + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Adam Nielsen "); +MODULE_DESCRIPTION("Xtables: trigger LED devices on packet match"); +MODULE_ALIAS("ipt_LED"); +MODULE_ALIAS("ip6t_LED"); + +static LIST_HEAD(xt_led_triggers); +static DEFINE_MUTEX(xt_led_mutex); + +/* + * This is declared in here (the kernel module) only, to avoid having these + * dependencies in userspace code. This is what xt_led_info.internal_data + * points to. + */ +struct xt_led_info_internal { + struct list_head list; + int refcnt; + char *trigger_id; + struct led_trigger netfilter_led_trigger; + struct timer_list timer; +}; + +#define XT_LED_BLINK_DELAY 50 /* ms */ + +static unsigned int +led_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_led_info *ledinfo = par->targinfo; + struct xt_led_info_internal *ledinternal = ledinfo->internal_data; + unsigned long led_delay = XT_LED_BLINK_DELAY; + + /* + * If "always blink" is enabled, and there's still some time until the + * LED will switch off, briefly switch it off now. + */ + if ((ledinfo->delay > 0) && ledinfo->always_blink && + timer_pending(&ledinternal->timer)) + led_trigger_blink_oneshot(&ledinternal->netfilter_led_trigger, + &led_delay, &led_delay, 1); + else + led_trigger_event(&ledinternal->netfilter_led_trigger, LED_FULL); + + /* If there's a positive delay, start/update the timer */ + if (ledinfo->delay > 0) { + mod_timer(&ledinternal->timer, + jiffies + msecs_to_jiffies(ledinfo->delay)); + + /* Otherwise if there was no delay given, blink as fast as possible */ + } else if (ledinfo->delay == 0) { + led_trigger_event(&ledinternal->netfilter_led_trigger, LED_OFF); + } + + /* else the delay is negative, which means switch on and stay on */ + + return XT_CONTINUE; +} + +static void led_timeout_callback(struct timer_list *t) +{ + struct xt_led_info_internal *ledinternal = from_timer(ledinternal, t, + timer); + + led_trigger_event(&ledinternal->netfilter_led_trigger, LED_OFF); +} + +static struct xt_led_info_internal *led_trigger_lookup(const char *name) +{ + struct xt_led_info_internal *ledinternal; + + list_for_each_entry(ledinternal, &xt_led_triggers, list) { + if (!strcmp(name, ledinternal->netfilter_led_trigger.name)) { + return ledinternal; + } + } + return NULL; +} + +static int led_tg_check(const struct xt_tgchk_param *par) +{ + struct xt_led_info *ledinfo = par->targinfo; + struct xt_led_info_internal *ledinternal; + int err; + + if (ledinfo->id[0] == '\0') + return -EINVAL; + + mutex_lock(&xt_led_mutex); + + ledinternal = led_trigger_lookup(ledinfo->id); + if (ledinternal) { + ledinternal->refcnt++; + goto out; + } + + err = -ENOMEM; + ledinternal = kzalloc(sizeof(struct xt_led_info_internal), GFP_KERNEL); + if (!ledinternal) + goto exit_mutex_only; + + ledinternal->trigger_id = kstrdup(ledinfo->id, GFP_KERNEL); + if (!ledinternal->trigger_id) + goto exit_internal_alloc; + + ledinternal->refcnt = 1; + ledinternal->netfilter_led_trigger.name = ledinternal->trigger_id; + + err = led_trigger_register(&ledinternal->netfilter_led_trigger); + if (err) { + pr_info_ratelimited("Trigger name is already in use.\n"); + goto exit_alloc; + } + + /* Since the letinternal 
timer can be shared between multiple targets, + * always set it up, even if the current target does not need it + */ + timer_setup(&ledinternal->timer, led_timeout_callback, 0); + + list_add_tail(&ledinternal->list, &xt_led_triggers); + +out: + mutex_unlock(&xt_led_mutex); + + ledinfo->internal_data = ledinternal; + + return 0; + +exit_alloc: + kfree(ledinternal->trigger_id); + +exit_internal_alloc: + kfree(ledinternal); + +exit_mutex_only: + mutex_unlock(&xt_led_mutex); + + return err; +} + +static void led_tg_destroy(const struct xt_tgdtor_param *par) +{ + const struct xt_led_info *ledinfo = par->targinfo; + struct xt_led_info_internal *ledinternal = ledinfo->internal_data; + + mutex_lock(&xt_led_mutex); + + if (--ledinternal->refcnt) { + mutex_unlock(&xt_led_mutex); + return; + } + + list_del(&ledinternal->list); + + del_timer_sync(&ledinternal->timer); + + led_trigger_unregister(&ledinternal->netfilter_led_trigger); + + mutex_unlock(&xt_led_mutex); + + kfree(ledinternal->trigger_id); + kfree(ledinternal); +} + +static struct xt_target led_tg_reg __read_mostly = { + .name = "LED", + .revision = 0, + .family = NFPROTO_UNSPEC, + .target = led_tg, + .targetsize = sizeof(struct xt_led_info), + .usersize = offsetof(struct xt_led_info, internal_data), + .checkentry = led_tg_check, + .destroy = led_tg_destroy, + .me = THIS_MODULE, +}; + +static int __init led_tg_init(void) +{ + return xt_register_target(&led_tg_reg); +} + +static void __exit led_tg_exit(void) +{ + xt_unregister_target(&led_tg_reg); +} + +module_init(led_tg_init); +module_exit(led_tg_exit); diff --git a/net/netfilter/xt_LOG.c b/net/netfilter/xt_LOG.c new file mode 100644 index 000000000..f39244f9c --- /dev/null +++ b/net/netfilter/xt_LOG.c @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for logging packets. 
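+ * Editor's note, illustrative only (prefix text is an assumption): packets are
+ * handed to the nf_log backend registered for the address family, so a rule such as
+ *   iptables -A INPUT -p tcp --dport 23 -j LOG --log-prefix "telnet: " --log-level info
+ * emits one kernel log line per matching packet via nf_log_syslog.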
+ */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static unsigned int +log_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_log_info *loginfo = par->targinfo; + struct net *net = xt_net(par); + struct nf_loginfo li; + + li.type = NF_LOG_TYPE_LOG; + li.u.log.level = loginfo->level; + li.u.log.logflags = loginfo->logflags; + + nf_log_packet(net, xt_family(par), xt_hooknum(par), skb, xt_in(par), + xt_out(par), &li, "%s", loginfo->prefix); + return XT_CONTINUE; +} + +static int log_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_log_info *loginfo = par->targinfo; + int ret; + + if (par->family != NFPROTO_IPV4 && par->family != NFPROTO_IPV6) + return -EINVAL; + + if (loginfo->level >= 8) { + pr_debug("level %u >= 8\n", loginfo->level); + return -EINVAL; + } + + if (loginfo->prefix[sizeof(loginfo->prefix)-1] != '\0') { + pr_debug("prefix is not null-terminated\n"); + return -EINVAL; + } + + ret = nf_logger_find_get(par->family, NF_LOG_TYPE_LOG); + if (ret != 0 && !par->nft_compat) { + request_module("%s", "nf_log_syslog"); + + ret = nf_logger_find_get(par->family, NF_LOG_TYPE_LOG); + } + + return ret; +} + +static void log_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_logger_put(par->family, NF_LOG_TYPE_LOG); +} + +static struct xt_target log_tg_regs[] __read_mostly = { + { + .name = "LOG", + .family = NFPROTO_IPV4, + .target = log_tg, + .targetsize = sizeof(struct xt_log_info), + .checkentry = log_tg_check, + .destroy = log_tg_destroy, + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "LOG", + .family = NFPROTO_IPV6, + .target = log_tg, + .targetsize = sizeof(struct xt_log_info), + .checkentry = log_tg_check, + .destroy = log_tg_destroy, + .me = THIS_MODULE, + }, +#endif +}; + +static int __init log_tg_init(void) +{ + return xt_register_targets(log_tg_regs, ARRAY_SIZE(log_tg_regs)); +} + +static void __exit log_tg_exit(void) +{ + xt_unregister_targets(log_tg_regs, ARRAY_SIZE(log_tg_regs)); +} + +module_init(log_tg_init); +module_exit(log_tg_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_AUTHOR("Jan Rekorajski "); +MODULE_DESCRIPTION("Xtables: IPv4/IPv6 packet logging"); +MODULE_ALIAS("ipt_LOG"); +MODULE_ALIAS("ip6t_LOG"); +MODULE_SOFTDEP("pre: nf_log_syslog"); diff --git a/net/netfilter/xt_MASQUERADE.c b/net/netfilter/xt_MASQUERADE.c new file mode 100644 index 000000000..eae05c178 --- /dev/null +++ b/net/netfilter/xt_MASQUERADE.c @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Masquerade. Simple mapping which alters range to a local IP address + (depending on route). */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("Xtables: automatic-address SNAT"); + +/* FIXME: Multiple targets. 
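+ * (Editor's note, illustrative only, interface name assumed: the common use of
+ *  this target is source NAT to whatever address the outgoing interface currently
+ *  holds, e.g. iptables -t nat -A POSTROUTING -o ppp0 -j MASQUERADE, optionally
+ *  with --to-ports to restrict the source-port range.)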
--RR */ +static int masquerade_tg_check(const struct xt_tgchk_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + + if (mr->range[0].flags & NF_NAT_RANGE_MAP_IPS) { + pr_debug("bad MAP_IPS.\n"); + return -EINVAL; + } + if (mr->rangesize != 1) { + pr_debug("bad rangesize %u\n", mr->rangesize); + return -EINVAL; + } + return nf_ct_netns_get(par->net, par->family); +} + +static unsigned int +masquerade_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct nf_nat_range2 range; + const struct nf_nat_ipv4_multi_range_compat *mr; + + mr = par->targinfo; + range.flags = mr->range[0].flags; + range.min_proto = mr->range[0].min; + range.max_proto = mr->range[0].max; + + return nf_nat_masquerade_ipv4(skb, xt_hooknum(par), &range, + xt_out(par)); +} + +static void masquerade_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +#if IS_ENABLED(CONFIG_IPV6) +static unsigned int +masquerade_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + return nf_nat_masquerade_ipv6(skb, par->targinfo, xt_out(par)); +} + +static int masquerade_tg6_checkentry(const struct xt_tgchk_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + + if (range->flags & NF_NAT_RANGE_MAP_IPS) + return -EINVAL; + + return nf_ct_netns_get(par->net, par->family); +} +#endif + +static struct xt_target masquerade_tg_reg[] __read_mostly = { + { +#if IS_ENABLED(CONFIG_IPV6) + .name = "MASQUERADE", + .family = NFPROTO_IPV6, + .target = masquerade_tg6, + .targetsize = sizeof(struct nf_nat_range), + .table = "nat", + .hooks = 1 << NF_INET_POST_ROUTING, + .checkentry = masquerade_tg6_checkentry, + .destroy = masquerade_tg_destroy, + .me = THIS_MODULE, + }, { +#endif + .name = "MASQUERADE", + .family = NFPROTO_IPV4, + .target = masquerade_tg, + .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat), + .table = "nat", + .hooks = 1 << NF_INET_POST_ROUTING, + .checkentry = masquerade_tg_check, + .destroy = masquerade_tg_destroy, + .me = THIS_MODULE, + } +}; + +static int __init masquerade_tg_init(void) +{ + int ret; + + ret = xt_register_targets(masquerade_tg_reg, + ARRAY_SIZE(masquerade_tg_reg)); + if (ret) + return ret; + + ret = nf_nat_masquerade_inet_register_notifiers(); + if (ret) { + xt_unregister_targets(masquerade_tg_reg, + ARRAY_SIZE(masquerade_tg_reg)); + return ret; + } + + return ret; +} + +static void __exit masquerade_tg_exit(void) +{ + xt_unregister_targets(masquerade_tg_reg, ARRAY_SIZE(masquerade_tg_reg)); + nf_nat_masquerade_inet_unregister_notifiers(); +} + +module_init(masquerade_tg_init); +module_exit(masquerade_tg_exit); +#if IS_ENABLED(CONFIG_IPV6) +MODULE_ALIAS("ip6t_MASQUERADE"); +#endif +MODULE_ALIAS("ipt_MASQUERADE"); diff --git a/net/netfilter/xt_NETMAP.c b/net/netfilter/xt_NETMAP.c new file mode 100644 index 000000000..cb2ee80d8 --- /dev/null +++ b/net/netfilter/xt_NETMAP.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2000-2001 Svenning Soerensen + * Copyright (c) 2011 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int +netmap_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + struct nf_nat_range2 newrange; + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + union nf_inet_addr new_addr, netmask; + unsigned int i; + + ct = nf_ct_get(skb, &ctinfo); + for (i = 0; i < ARRAY_SIZE(range->min_addr.ip6); i++) + 
netmask.ip6[i] = ~(range->min_addr.ip6[i] ^ + range->max_addr.ip6[i]); + + if (xt_hooknum(par) == NF_INET_PRE_ROUTING || + xt_hooknum(par) == NF_INET_LOCAL_OUT) + new_addr.in6 = ipv6_hdr(skb)->daddr; + else + new_addr.in6 = ipv6_hdr(skb)->saddr; + + for (i = 0; i < ARRAY_SIZE(new_addr.ip6); i++) { + new_addr.ip6[i] &= ~netmask.ip6[i]; + new_addr.ip6[i] |= range->min_addr.ip6[i] & + netmask.ip6[i]; + } + + newrange.flags = range->flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr = new_addr; + newrange.max_addr = new_addr; + newrange.min_proto = range->min_proto; + newrange.max_proto = range->max_proto; + + return nf_nat_setup_info(ct, &newrange, HOOK2MANIP(xt_hooknum(par))); +} + +static int netmap_tg6_checkentry(const struct xt_tgchk_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + + if (!(range->flags & NF_NAT_RANGE_MAP_IPS)) + return -EINVAL; + return nf_ct_netns_get(par->net, par->family); +} + +static void netmap_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static unsigned int +netmap_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + __be32 new_ip, netmask; + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + struct nf_nat_range2 newrange; + + WARN_ON(xt_hooknum(par) != NF_INET_PRE_ROUTING && + xt_hooknum(par) != NF_INET_POST_ROUTING && + xt_hooknum(par) != NF_INET_LOCAL_OUT && + xt_hooknum(par) != NF_INET_LOCAL_IN); + ct = nf_ct_get(skb, &ctinfo); + + netmask = ~(mr->range[0].min_ip ^ mr->range[0].max_ip); + + if (xt_hooknum(par) == NF_INET_PRE_ROUTING || + xt_hooknum(par) == NF_INET_LOCAL_OUT) + new_ip = ip_hdr(skb)->daddr & ~netmask; + else + new_ip = ip_hdr(skb)->saddr & ~netmask; + new_ip |= mr->range[0].min_ip & netmask; + + memset(&newrange.min_addr, 0, sizeof(newrange.min_addr)); + memset(&newrange.max_addr, 0, sizeof(newrange.max_addr)); + newrange.flags = mr->range[0].flags | NF_NAT_RANGE_MAP_IPS; + newrange.min_addr.ip = new_ip; + newrange.max_addr.ip = new_ip; + newrange.min_proto = mr->range[0].min; + newrange.max_proto = mr->range[0].max; + + /* Hand modified range to generic setup. 
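+ * Editor's note, worked example (addresses are assumptions): with a rule like
+ * iptables -t nat -A PREROUTING -d 192.168.0.0/16 -j NETMAP --to 10.5.0.0/16,
+ * min_ip/max_ip span 10.5.0.0-10.5.255.255, so netmask is 255.255.0.0 and a
+ * packet for 192.168.3.7 is rewritten to 10.5.3.7: the host bits are kept and
+ * only the network part is replaced before the range is handed over here.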
*/ + return nf_nat_setup_info(ct, &newrange, HOOK2MANIP(xt_hooknum(par))); +} + +static int netmap_tg4_check(const struct xt_tgchk_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + + if (!(mr->range[0].flags & NF_NAT_RANGE_MAP_IPS)) { + pr_debug("bad MAP_IPS.\n"); + return -EINVAL; + } + if (mr->rangesize != 1) { + pr_debug("bad rangesize %u.\n", mr->rangesize); + return -EINVAL; + } + return nf_ct_netns_get(par->net, par->family); +} + +static struct xt_target netmap_tg_reg[] __read_mostly = { + { + .name = "NETMAP", + .family = NFPROTO_IPV6, + .revision = 0, + .target = netmap_tg6, + .targetsize = sizeof(struct nf_nat_range), + .table = "nat", + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .checkentry = netmap_tg6_checkentry, + .destroy = netmap_tg_destroy, + .me = THIS_MODULE, + }, + { + .name = "NETMAP", + .family = NFPROTO_IPV4, + .revision = 0, + .target = netmap_tg4, + .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat), + .table = "nat", + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_LOCAL_IN), + .checkentry = netmap_tg4_check, + .destroy = netmap_tg_destroy, + .me = THIS_MODULE, + }, +}; + +static int __init netmap_tg_init(void) +{ + return xt_register_targets(netmap_tg_reg, ARRAY_SIZE(netmap_tg_reg)); +} + +static void netmap_tg_exit(void) +{ + xt_unregister_targets(netmap_tg_reg, ARRAY_SIZE(netmap_tg_reg)); +} + +module_init(netmap_tg_init); +module_exit(netmap_tg_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: 1:1 NAT mapping of subnets"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS("ip6t_NETMAP"); +MODULE_ALIAS("ipt_NETMAP"); diff --git a/net/netfilter/xt_NFLOG.c b/net/netfilter/xt_NFLOG.c new file mode 100644 index 000000000..e660c3710 --- /dev/null +++ b/net/netfilter/xt_NFLOG.c @@ -0,0 +1,90 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2006 Patrick McHardy + */ + +#include +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Xtables: packet logging to netlink using NFLOG"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_NFLOG"); +MODULE_ALIAS("ip6t_NFLOG"); + +static unsigned int +nflog_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_nflog_info *info = par->targinfo; + struct net *net = xt_net(par); + struct nf_loginfo li; + + li.type = NF_LOG_TYPE_ULOG; + li.u.ulog.copy_len = info->len; + li.u.ulog.group = info->group; + li.u.ulog.qthreshold = info->threshold; + li.u.ulog.flags = 0; + + if (info->flags & XT_NFLOG_F_COPY_LEN) + li.u.ulog.flags |= NF_LOG_F_COPY_LEN; + + nf_log_packet(net, xt_family(par), xt_hooknum(par), skb, xt_in(par), + xt_out(par), &li, "%s", info->prefix); + + return XT_CONTINUE; +} + +static int nflog_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_nflog_info *info = par->targinfo; + int ret; + + if (info->flags & ~XT_NFLOG_MASK) + return -EINVAL; + if (info->prefix[sizeof(info->prefix) - 1] != '\0') + return -EINVAL; + + ret = nf_logger_find_get(par->family, NF_LOG_TYPE_ULOG); + if (ret != 0 && !par->nft_compat) { + request_module("%s", "nfnetlink_log"); + + ret = nf_logger_find_get(par->family, NF_LOG_TYPE_ULOG); + } + + return ret; +} + +static void nflog_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_logger_put(par->family, NF_LOG_TYPE_ULOG); +} + +static struct xt_target nflog_tg_reg __read_mostly = { + 
.name = "NFLOG", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = nflog_tg_check, + .destroy = nflog_tg_destroy, + .target = nflog_tg, + .targetsize = sizeof(struct xt_nflog_info), + .me = THIS_MODULE, +}; + +static int __init nflog_tg_init(void) +{ + return xt_register_target(&nflog_tg_reg); +} + +static void __exit nflog_tg_exit(void) +{ + xt_unregister_target(&nflog_tg_reg); +} + +module_init(nflog_tg_init); +module_exit(nflog_tg_exit); +MODULE_SOFTDEP("pre: nfnetlink_log"); diff --git a/net/netfilter/xt_NFQUEUE.c b/net/netfilter/xt_NFQUEUE.c new file mode 100644 index 000000000..466da23e3 --- /dev/null +++ b/net/netfilter/xt_NFQUEUE.c @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* iptables module for using new netfilter netlink queue + * + * (C) 2005 by Harald Welte + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include + +#include +#include +#include +#include + +#include + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: packet forwarding to netlink"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_NFQUEUE"); +MODULE_ALIAS("ip6t_NFQUEUE"); +MODULE_ALIAS("arpt_NFQUEUE"); + +static u32 jhash_initval __read_mostly; + +static unsigned int +nfqueue_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_NFQ_info *tinfo = par->targinfo; + + return NF_QUEUE_NR(tinfo->queuenum); +} + +static unsigned int +nfqueue_tg_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_NFQ_info_v1 *info = par->targinfo; + u32 queue = info->queuenum; + + if (info->queues_total > 1) { + queue = nfqueue_hash(skb, queue, info->queues_total, + xt_family(par), jhash_initval); + } + return NF_QUEUE_NR(queue); +} + +static unsigned int +nfqueue_tg_v2(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_NFQ_info_v2 *info = par->targinfo; + unsigned int ret = nfqueue_tg_v1(skb, par); + + if (info->bypass) + ret |= NF_VERDICT_FLAG_QUEUE_BYPASS; + return ret; +} + +static int nfqueue_tg_check(const struct xt_tgchk_param *par) +{ + const struct xt_NFQ_info_v3 *info = par->targinfo; + u32 maxid; + + init_hashrandom(&jhash_initval); + + if (info->queues_total == 0) { + pr_info_ratelimited("number of total queues is 0\n"); + return -EINVAL; + } + maxid = info->queues_total - 1 + info->queuenum; + if (maxid > 0xffff) { + pr_info_ratelimited("number of queues (%u) out of range (got %u)\n", + info->queues_total, maxid); + return -ERANGE; + } + if (par->target->revision == 2 && info->flags > 1) + return -EINVAL; + if (par->target->revision == 3 && info->flags & ~NFQ_FLAG_MASK) + return -EINVAL; + + return 0; +} + +static unsigned int +nfqueue_tg_v3(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_NFQ_info_v3 *info = par->targinfo; + u32 queue = info->queuenum; + int ret; + + if (info->queues_total > 1) { + if (info->flags & NFQ_FLAG_CPU_FANOUT) { + int cpu = smp_processor_id(); + + queue = info->queuenum + cpu % info->queues_total; + } else { + queue = nfqueue_hash(skb, queue, info->queues_total, + xt_family(par), jhash_initval); + } + } + + ret = NF_QUEUE_NR(queue); + if (info->flags & NFQ_FLAG_BYPASS) + ret |= NF_VERDICT_FLAG_QUEUE_BYPASS; + + return ret; +} + +static struct xt_target nfqueue_tg_reg[] __read_mostly = { + { + .name = "NFQUEUE", + .family = NFPROTO_UNSPEC, + .target = nfqueue_tg, + .targetsize = sizeof(struct xt_NFQ_info), + .me = THIS_MODULE, + }, + { + .name = "NFQUEUE", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = nfqueue_tg_check, + 
.target = nfqueue_tg_v1, + .targetsize = sizeof(struct xt_NFQ_info_v1), + .me = THIS_MODULE, + }, + { + .name = "NFQUEUE", + .revision = 2, + .family = NFPROTO_UNSPEC, + .checkentry = nfqueue_tg_check, + .target = nfqueue_tg_v2, + .targetsize = sizeof(struct xt_NFQ_info_v2), + .me = THIS_MODULE, + }, + { + .name = "NFQUEUE", + .revision = 3, + .family = NFPROTO_UNSPEC, + .checkentry = nfqueue_tg_check, + .target = nfqueue_tg_v3, + .targetsize = sizeof(struct xt_NFQ_info_v3), + .me = THIS_MODULE, + }, +}; + +static int __init nfqueue_tg_init(void) +{ + return xt_register_targets(nfqueue_tg_reg, ARRAY_SIZE(nfqueue_tg_reg)); +} + +static void __exit nfqueue_tg_exit(void) +{ + xt_unregister_targets(nfqueue_tg_reg, ARRAY_SIZE(nfqueue_tg_reg)); +} + +module_init(nfqueue_tg_init); +module_exit(nfqueue_tg_exit); diff --git a/net/netfilter/xt_RATEEST.c b/net/netfilter/xt_RATEEST.c new file mode 100644 index 000000000..80f6624e2 --- /dev/null +++ b/net/netfilter/xt_RATEEST.c @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2007 Patrick McHardy + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define RATEEST_HSIZE 16 + +struct xt_rateest_net { + struct mutex hash_lock; + struct hlist_head hash[RATEEST_HSIZE]; +}; + +static unsigned int xt_rateest_id; + +static unsigned int jhash_rnd __read_mostly; + +static unsigned int xt_rateest_hash(const char *name) +{ + return jhash(name, sizeof_field(struct xt_rateest, name), jhash_rnd) & + (RATEEST_HSIZE - 1); +} + +static void xt_rateest_hash_insert(struct xt_rateest_net *xn, + struct xt_rateest *est) +{ + unsigned int h; + + h = xt_rateest_hash(est->name); + hlist_add_head(&est->list, &xn->hash[h]); +} + +static struct xt_rateest *__xt_rateest_lookup(struct xt_rateest_net *xn, + const char *name) +{ + struct xt_rateest *est; + unsigned int h; + + h = xt_rateest_hash(name); + hlist_for_each_entry(est, &xn->hash[h], list) { + if (strcmp(est->name, name) == 0) { + est->refcnt++; + return est; + } + } + + return NULL; +} + +struct xt_rateest *xt_rateest_lookup(struct net *net, const char *name) +{ + struct xt_rateest_net *xn = net_generic(net, xt_rateest_id); + struct xt_rateest *est; + + mutex_lock(&xn->hash_lock); + est = __xt_rateest_lookup(xn, name); + mutex_unlock(&xn->hash_lock); + return est; +} +EXPORT_SYMBOL_GPL(xt_rateest_lookup); + +void xt_rateest_put(struct net *net, struct xt_rateest *est) +{ + struct xt_rateest_net *xn = net_generic(net, xt_rateest_id); + + mutex_lock(&xn->hash_lock); + if (--est->refcnt == 0) { + hlist_del(&est->list); + gen_kill_estimator(&est->rate_est); + /* + * gen_estimator est_timer() might access est->lock or bstats, + * wait a RCU grace period before freeing 'est' + */ + kfree_rcu(est, rcu); + } + mutex_unlock(&xn->hash_lock); +} +EXPORT_SYMBOL_GPL(xt_rateest_put); + +static unsigned int +xt_rateest_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_rateest_target_info *info = par->targinfo; + struct gnet_stats_basic_sync *stats = &info->est->bstats; + + spin_lock_bh(&info->est->lock); + u64_stats_add(&stats->bytes, skb->len); + u64_stats_inc(&stats->packets); + spin_unlock_bh(&info->est->lock); + + return XT_CONTINUE; +} + +static int xt_rateest_tg_checkentry(const struct xt_tgchk_param *par) +{ + struct xt_rateest_net *xn = net_generic(par->net, xt_rateest_id); + struct xt_rateest_target_info *info = par->targinfo; + struct xt_rateest *est; + struct { + struct nlattr opt; + 
struct gnet_estimator est; + } cfg; + int ret; + + if (strnlen(info->name, sizeof(est->name)) >= sizeof(est->name)) + return -ENAMETOOLONG; + + net_get_random_once(&jhash_rnd, sizeof(jhash_rnd)); + + mutex_lock(&xn->hash_lock); + est = __xt_rateest_lookup(xn, info->name); + if (est) { + mutex_unlock(&xn->hash_lock); + /* + * If estimator parameters are specified, they must match the + * existing estimator. + */ + if ((!info->interval && !info->ewma_log) || + (info->interval != est->params.interval || + info->ewma_log != est->params.ewma_log)) { + xt_rateest_put(par->net, est); + return -EINVAL; + } + info->est = est; + return 0; + } + + ret = -ENOMEM; + est = kzalloc(sizeof(*est), GFP_KERNEL); + if (!est) + goto err1; + + gnet_stats_basic_sync_init(&est->bstats); + strscpy(est->name, info->name, sizeof(est->name)); + spin_lock_init(&est->lock); + est->refcnt = 1; + est->params.interval = info->interval; + est->params.ewma_log = info->ewma_log; + + cfg.opt.nla_len = nla_attr_size(sizeof(cfg.est)); + cfg.opt.nla_type = TCA_STATS_RATE_EST; + cfg.est.interval = info->interval; + cfg.est.ewma_log = info->ewma_log; + + ret = gen_new_estimator(&est->bstats, NULL, &est->rate_est, + &est->lock, NULL, &cfg.opt); + if (ret < 0) + goto err2; + + info->est = est; + xt_rateest_hash_insert(xn, est); + mutex_unlock(&xn->hash_lock); + return 0; + +err2: + kfree(est); +err1: + mutex_unlock(&xn->hash_lock); + return ret; +} + +static void xt_rateest_tg_destroy(const struct xt_tgdtor_param *par) +{ + struct xt_rateest_target_info *info = par->targinfo; + + xt_rateest_put(par->net, info->est); +} + +static struct xt_target xt_rateest_tg_reg __read_mostly = { + .name = "RATEEST", + .revision = 0, + .family = NFPROTO_UNSPEC, + .target = xt_rateest_tg, + .checkentry = xt_rateest_tg_checkentry, + .destroy = xt_rateest_tg_destroy, + .targetsize = sizeof(struct xt_rateest_target_info), + .usersize = offsetof(struct xt_rateest_target_info, est), + .me = THIS_MODULE, +}; + +static __net_init int xt_rateest_net_init(struct net *net) +{ + struct xt_rateest_net *xn = net_generic(net, xt_rateest_id); + int i; + + mutex_init(&xn->hash_lock); + for (i = 0; i < ARRAY_SIZE(xn->hash); i++) + INIT_HLIST_HEAD(&xn->hash[i]); + return 0; +} + +static struct pernet_operations xt_rateest_net_ops = { + .init = xt_rateest_net_init, + .id = &xt_rateest_id, + .size = sizeof(struct xt_rateest_net), +}; + +static int __init xt_rateest_tg_init(void) +{ + int err = register_pernet_subsys(&xt_rateest_net_ops); + + if (err) + return err; + return xt_register_target(&xt_rateest_tg_reg); +} + +static void __exit xt_rateest_tg_fini(void) +{ + xt_unregister_target(&xt_rateest_tg_reg); + unregister_pernet_subsys(&xt_rateest_net_ops); +} + + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: packet rate estimator"); +MODULE_ALIAS("ipt_RATEEST"); +MODULE_ALIAS("ip6t_RATEEST"); +module_init(xt_rateest_tg_init); +module_exit(xt_rateest_tg_fini); diff --git a/net/netfilter/xt_REDIRECT.c b/net/netfilter/xt_REDIRECT.c new file mode 100644 index 000000000..ff66b56a3 --- /dev/null +++ b/net/netfilter/xt_REDIRECT.c @@ -0,0 +1,124 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * Copyright (c) 2011 Patrick McHardy + * + * Based on Rusty Russell's IPv4 REDIRECT target. Development of IPv6 + * NAT funded by Astaro. 
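+ *
+ * Editor's note, illustrative only (port numbers are assumptions): redirection
+ * maps matching traffic to the local machine, as in
+ *   iptables -t nat -A PREROUTING -i eth0 -p tcp --dport 80 -j REDIRECT --to-ports 3128
+ * which is the classic transparent-proxy setup for an assumed proxy on port 3128.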
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int +redirect_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + return nf_nat_redirect_ipv6(skb, par->targinfo, xt_hooknum(par)); +} + +static int redirect_tg6_checkentry(const struct xt_tgchk_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + + if (range->flags & NF_NAT_RANGE_MAP_IPS) + return -EINVAL; + + return nf_ct_netns_get(par->net, par->family); +} + +static void redirect_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static int redirect_tg4_check(const struct xt_tgchk_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + + if (mr->range[0].flags & NF_NAT_RANGE_MAP_IPS) { + pr_debug("bad MAP_IPS.\n"); + return -EINVAL; + } + if (mr->rangesize != 1) { + pr_debug("bad rangesize %u.\n", mr->rangesize); + return -EINVAL; + } + return nf_ct_netns_get(par->net, par->family); +} + +static unsigned int +redirect_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + struct nf_nat_range2 range = { + .flags = mr->range[0].flags, + .min_proto = mr->range[0].min, + .max_proto = mr->range[0].max, + }; + + return nf_nat_redirect_ipv4(skb, &range, xt_hooknum(par)); +} + +static struct xt_target redirect_tg_reg[] __read_mostly = { + { + .name = "REDIRECT", + .family = NFPROTO_IPV6, + .revision = 0, + .table = "nat", + .checkentry = redirect_tg6_checkentry, + .destroy = redirect_tg_destroy, + .target = redirect_tg6, + .targetsize = sizeof(struct nf_nat_range), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, + { + .name = "REDIRECT", + .family = NFPROTO_IPV4, + .revision = 0, + .table = "nat", + .target = redirect_tg4, + .checkentry = redirect_tg4_check, + .destroy = redirect_tg_destroy, + .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, +}; + +static int __init redirect_tg_init(void) +{ + return xt_register_targets(redirect_tg_reg, + ARRAY_SIZE(redirect_tg_reg)); +} + +static void __exit redirect_tg_exit(void) +{ + xt_unregister_targets(redirect_tg_reg, ARRAY_SIZE(redirect_tg_reg)); +} + +module_init(redirect_tg_init); +module_exit(redirect_tg_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Xtables: Connection redirection to localhost"); +MODULE_ALIAS("ip6t_REDIRECT"); +MODULE_ALIAS("ipt_REDIRECT"); diff --git a/net/netfilter/xt_SECMARK.c b/net/netfilter/xt_SECMARK.c new file mode 100644 index 000000000..498a0bf6f --- /dev/null +++ b/net/netfilter/xt_SECMARK.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Module for modifying the secmark field of the skb, for use by + * security subsystems. 
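+ * Editor's note, illustrative only (the SELinux context below is an assumption):
+ * the target stores an LSM secid in skb->secmark, e.g.
+ *   iptables -t security -A INPUT -p tcp --dport 22 \
+ *     -j SECMARK --selctx system_u:object_r:ssh_server_packet_t:s0
+ * so that per-packet access control can be enforced by the security module.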
+ * + * Based on the nfmark match by: + * (C) 1999-2001 Marc Boucher + * + * (C) 2006,2008 Red Hat, Inc., James Morris + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Morris "); +MODULE_DESCRIPTION("Xtables: packet security mark modification"); +MODULE_ALIAS("ipt_SECMARK"); +MODULE_ALIAS("ip6t_SECMARK"); + +static u8 mode; + +static unsigned int +secmark_tg(struct sk_buff *skb, const struct xt_secmark_target_info_v1 *info) +{ + u32 secmark = 0; + + switch (mode) { + case SECMARK_MODE_SEL: + secmark = info->secid; + break; + default: + BUG(); + } + + skb->secmark = secmark; + return XT_CONTINUE; +} + +static int checkentry_lsm(struct xt_secmark_target_info_v1 *info) +{ + int err; + + info->secctx[SECMARK_SECCTX_MAX - 1] = '\0'; + info->secid = 0; + + err = security_secctx_to_secid(info->secctx, strlen(info->secctx), + &info->secid); + if (err) { + if (err == -EINVAL) + pr_info_ratelimited("invalid security context \'%s\'\n", + info->secctx); + return err; + } + + if (!info->secid) { + pr_info_ratelimited("unable to map security context \'%s\'\n", + info->secctx); + return -ENOENT; + } + + err = security_secmark_relabel_packet(info->secid); + if (err) { + pr_info_ratelimited("unable to obtain relabeling permission\n"); + return err; + } + + security_secmark_refcount_inc(); + return 0; +} + +static int +secmark_tg_check(const char *table, struct xt_secmark_target_info_v1 *info) +{ + int err; + + if (strcmp(table, "mangle") != 0 && + strcmp(table, "security") != 0) { + pr_info_ratelimited("only valid in \'mangle\' or \'security\' table, not \'%s\'\n", + table); + return -EINVAL; + } + + if (mode && mode != info->mode) { + pr_info_ratelimited("mode already set to %hu cannot mix with rules for mode %hu\n", + mode, info->mode); + return -EINVAL; + } + + switch (info->mode) { + case SECMARK_MODE_SEL: + break; + default: + pr_info_ratelimited("invalid mode: %hu\n", info->mode); + return -EINVAL; + } + + err = checkentry_lsm(info); + if (err) + return err; + + if (!mode) + mode = info->mode; + return 0; +} + +static void secmark_tg_destroy(const struct xt_tgdtor_param *par) +{ + switch (mode) { + case SECMARK_MODE_SEL: + security_secmark_refcount_dec(); + } +} + +static int secmark_tg_check_v0(const struct xt_tgchk_param *par) +{ + struct xt_secmark_target_info *info = par->targinfo; + struct xt_secmark_target_info_v1 newinfo = { + .mode = info->mode, + }; + int ret; + + memcpy(newinfo.secctx, info->secctx, SECMARK_SECCTX_MAX); + + ret = secmark_tg_check(par->table, &newinfo); + info->secid = newinfo.secid; + + return ret; +} + +static unsigned int +secmark_tg_v0(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_secmark_target_info *info = par->targinfo; + struct xt_secmark_target_info_v1 newinfo = { + .secid = info->secid, + }; + + return secmark_tg(skb, &newinfo); +} + +static int secmark_tg_check_v1(const struct xt_tgchk_param *par) +{ + return secmark_tg_check(par->table, par->targinfo); +} + +static unsigned int +secmark_tg_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + return secmark_tg(skb, par->targinfo); +} + +static struct xt_target secmark_tg_reg[] __read_mostly = { + { + .name = "SECMARK", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = secmark_tg_check_v0, + .destroy = secmark_tg_destroy, + .target = secmark_tg_v0, + .targetsize = sizeof(struct xt_secmark_target_info), + .me = THIS_MODULE, + }, + { + .name = "SECMARK", + .revision = 
1, + .family = NFPROTO_UNSPEC, + .checkentry = secmark_tg_check_v1, + .destroy = secmark_tg_destroy, + .target = secmark_tg_v1, + .targetsize = sizeof(struct xt_secmark_target_info_v1), + .usersize = offsetof(struct xt_secmark_target_info_v1, secid), + .me = THIS_MODULE, + }, +}; + +static int __init secmark_tg_init(void) +{ + return xt_register_targets(secmark_tg_reg, ARRAY_SIZE(secmark_tg_reg)); +} + +static void __exit secmark_tg_exit(void) +{ + xt_unregister_targets(secmark_tg_reg, ARRAY_SIZE(secmark_tg_reg)); +} + +module_init(secmark_tg_init); +module_exit(secmark_tg_exit); diff --git a/net/netfilter/xt_TCPMSS.c b/net/netfilter/xt_TCPMSS.c new file mode 100644 index 000000000..116a885ad --- /dev/null +++ b/net/netfilter/xt_TCPMSS.c @@ -0,0 +1,345 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is a module which is used for setting the MSS option in TCP packets. + * + * Copyright (C) 2000 Marc Boucher + * Copyright (C) 2007 Patrick McHardy + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Marc Boucher "); +MODULE_DESCRIPTION("Xtables: TCP Maximum Segment Size (MSS) adjustment"); +MODULE_ALIAS("ipt_TCPMSS"); +MODULE_ALIAS("ip6t_TCPMSS"); + +static inline unsigned int +optlen(const u_int8_t *opt, unsigned int offset) +{ + /* Beware zero-length options: make finite progress */ + if (opt[offset] <= TCPOPT_NOP || opt[offset+1] == 0) + return 1; + else + return opt[offset+1]; +} + +static u_int32_t tcpmss_reverse_mtu(struct net *net, + const struct sk_buff *skb, + unsigned int family) +{ + struct flowi fl; + struct rtable *rt = NULL; + u_int32_t mtu = ~0U; + + if (family == PF_INET) { + struct flowi4 *fl4 = &fl.u.ip4; + memset(fl4, 0, sizeof(*fl4)); + fl4->daddr = ip_hdr(skb)->saddr; + } else { + struct flowi6 *fl6 = &fl.u.ip6; + + memset(fl6, 0, sizeof(*fl6)); + fl6->daddr = ipv6_hdr(skb)->saddr; + } + + nf_route(net, (struct dst_entry **)&rt, &fl, false, family); + if (rt != NULL) { + mtu = dst_mtu(&rt->dst); + dst_release(&rt->dst); + } + return mtu; +} + +static int +tcpmss_mangle_packet(struct sk_buff *skb, + const struct xt_action_param *par, + unsigned int family, + unsigned int tcphoff, + unsigned int minlen) +{ + const struct xt_tcpmss_info *info = par->targinfo; + struct tcphdr *tcph; + int len, tcp_hdrlen; + unsigned int i; + __be16 oldval; + u16 newmss; + u8 *opt; + + /* This is a fragment, no TCP header is available */ + if (par->fragoff != 0) + return 0; + + if (skb_ensure_writable(skb, skb->len)) + return -1; + + len = skb->len - tcphoff; + if (len < (int)sizeof(struct tcphdr)) + return -1; + + tcph = (struct tcphdr *)(skb_network_header(skb) + tcphoff); + tcp_hdrlen = tcph->doff * 4; + + if (len < tcp_hdrlen || tcp_hdrlen < sizeof(struct tcphdr)) + return -1; + + if (info->mss == XT_TCPMSS_CLAMP_PMTU) { + struct net *net = xt_net(par); + unsigned int in_mtu = tcpmss_reverse_mtu(net, skb, family); + unsigned int min_mtu = min(dst_mtu(skb_dst(skb)), in_mtu); + + if (min_mtu <= minlen) { + net_err_ratelimited("unknown or invalid path-MTU (%u)\n", + min_mtu); + return -1; + } + newmss = min_mtu - minlen; + } else + newmss = info->mss; + + opt = (u_int8_t *)tcph; + for (i = sizeof(struct tcphdr); i <= tcp_hdrlen - TCPOLEN_MSS; i += optlen(opt, i)) { + if (opt[i] == TCPOPT_MSS && opt[i+1] == TCPOLEN_MSS) { + u_int16_t oldmss; + + oldmss = (opt[i+2] << 8) | opt[i+3]; + + 
/* Never increase MSS, even when setting it, as + * doing so results in problems for hosts that rely + * on MSS being set correctly. + */ + if (oldmss <= newmss) + return 0; + + opt[i+2] = (newmss & 0xff00) >> 8; + opt[i+3] = newmss & 0x00ff; + + inet_proto_csum_replace2(&tcph->check, skb, + htons(oldmss), htons(newmss), + false); + return 0; + } + } + + /* There is data after the header so the option can't be added + * without moving it, and doing so may make the SYN packet + * itself too large. Accept the packet unmodified instead. + */ + if (len > tcp_hdrlen) + return 0; + + /* tcph->doff has 4 bits, do not wrap it to 0 */ + if (tcp_hdrlen >= 15 * 4) + return 0; + + /* + * MSS Option not found ?! add it.. + */ + if (skb_tailroom(skb) < TCPOLEN_MSS) { + if (pskb_expand_head(skb, 0, + TCPOLEN_MSS - skb_tailroom(skb), + GFP_ATOMIC)) + return -1; + tcph = (struct tcphdr *)(skb_network_header(skb) + tcphoff); + } + + skb_put(skb, TCPOLEN_MSS); + + /* + * IPv4: RFC 1122 states "If an MSS option is not received at + * connection setup, TCP MUST assume a default send MSS of 536". + * IPv6: RFC 2460 states IPv6 has a minimum MTU of 1280 and a minimum + * length IPv6 header of 60, ergo the default MSS value is 1220 + * Since no MSS was provided, we must use the default values + */ + if (xt_family(par) == NFPROTO_IPV4) + newmss = min(newmss, (u16)536); + else + newmss = min(newmss, (u16)1220); + + opt = (u_int8_t *)tcph + sizeof(struct tcphdr); + memmove(opt + TCPOLEN_MSS, opt, len - sizeof(struct tcphdr)); + + inet_proto_csum_replace2(&tcph->check, skb, + htons(len), htons(len + TCPOLEN_MSS), true); + opt[0] = TCPOPT_MSS; + opt[1] = TCPOLEN_MSS; + opt[2] = (newmss & 0xff00) >> 8; + opt[3] = newmss & 0x00ff; + + inet_proto_csum_replace4(&tcph->check, skb, 0, *((__be32 *)opt), false); + + oldval = ((__be16 *)tcph)[6]; + tcph->doff += TCPOLEN_MSS/4; + inet_proto_csum_replace2(&tcph->check, skb, + oldval, ((__be16 *)tcph)[6], false); + return TCPOLEN_MSS; +} + +static unsigned int +tcpmss_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct iphdr *iph = ip_hdr(skb); + __be16 newlen; + int ret; + + ret = tcpmss_mangle_packet(skb, par, + PF_INET, + iph->ihl * 4, + sizeof(*iph) + sizeof(struct tcphdr)); + if (ret < 0) + return NF_DROP; + if (ret > 0) { + iph = ip_hdr(skb); + newlen = htons(ntohs(iph->tot_len) + ret); + csum_replace2(&iph->check, iph->tot_len, newlen); + iph->tot_len = newlen; + } + return XT_CONTINUE; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static unsigned int +tcpmss_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + u8 nexthdr; + __be16 frag_off, oldlen, newlen; + int tcphoff; + int ret; + + nexthdr = ipv6h->nexthdr; + tcphoff = ipv6_skip_exthdr(skb, sizeof(*ipv6h), &nexthdr, &frag_off); + if (tcphoff < 0) + return NF_DROP; + ret = tcpmss_mangle_packet(skb, par, + PF_INET6, + tcphoff, + sizeof(*ipv6h) + sizeof(struct tcphdr)); + if (ret < 0) + return NF_DROP; + if (ret > 0) { + ipv6h = ipv6_hdr(skb); + oldlen = ipv6h->payload_len; + newlen = htons(ntohs(oldlen) + ret); + if (skb->ip_summed == CHECKSUM_COMPLETE) + skb->csum = csum_add(csum_sub(skb->csum, (__force __wsum)oldlen), + (__force __wsum)newlen); + ipv6h->payload_len = newlen; + } + return XT_CONTINUE; +} +#endif + +/* Must specify -p tcp --syn */ +static inline bool find_syn_match(const struct xt_entry_match *m) +{ + const struct xt_tcp *tcpinfo = (const struct xt_tcp *)m->data; + + if (strcmp(m->u.kernel.match->name, "tcp") == 0 && + 
tcpinfo->flg_cmp & TCPHDR_SYN && + !(tcpinfo->invflags & XT_TCP_INV_FLAGS)) + return true; + + return false; +} + +static int tcpmss_tg4_check(const struct xt_tgchk_param *par) +{ + const struct xt_tcpmss_info *info = par->targinfo; + const struct ipt_entry *e = par->entryinfo; + const struct xt_entry_match *ematch; + + if (info->mss == XT_TCPMSS_CLAMP_PMTU && + (par->hook_mask & ~((1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING))) != 0) { + pr_info_ratelimited("path-MTU clamping only supported in FORWARD, OUTPUT and POSTROUTING hooks\n"); + return -EINVAL; + } + if (par->nft_compat) + return 0; + + xt_ematch_foreach(ematch, e) + if (find_syn_match(ematch)) + return 0; + pr_info_ratelimited("Only works on TCP SYN packets\n"); + return -EINVAL; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static int tcpmss_tg6_check(const struct xt_tgchk_param *par) +{ + const struct xt_tcpmss_info *info = par->targinfo; + const struct ip6t_entry *e = par->entryinfo; + const struct xt_entry_match *ematch; + + if (info->mss == XT_TCPMSS_CLAMP_PMTU && + (par->hook_mask & ~((1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING))) != 0) { + pr_info_ratelimited("path-MTU clamping only supported in FORWARD, OUTPUT and POSTROUTING hooks\n"); + return -EINVAL; + } + if (par->nft_compat) + return 0; + + xt_ematch_foreach(ematch, e) + if (find_syn_match(ematch)) + return 0; + pr_info_ratelimited("Only works on TCP SYN packets\n"); + return -EINVAL; +} +#endif + +static struct xt_target tcpmss_tg_reg[] __read_mostly = { + { + .family = NFPROTO_IPV4, + .name = "TCPMSS", + .checkentry = tcpmss_tg4_check, + .target = tcpmss_tg4, + .targetsize = sizeof(struct xt_tcpmss_info), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .family = NFPROTO_IPV6, + .name = "TCPMSS", + .checkentry = tcpmss_tg6_check, + .target = tcpmss_tg6, + .targetsize = sizeof(struct xt_tcpmss_info), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, +#endif +}; + +static int __init tcpmss_tg_init(void) +{ + return xt_register_targets(tcpmss_tg_reg, ARRAY_SIZE(tcpmss_tg_reg)); +} + +static void __exit tcpmss_tg_exit(void) +{ + xt_unregister_targets(tcpmss_tg_reg, ARRAY_SIZE(tcpmss_tg_reg)); +} + +module_init(tcpmss_tg_init); +module_exit(tcpmss_tg_exit); diff --git a/net/netfilter/xt_TCPOPTSTRIP.c b/net/netfilter/xt_TCPOPTSTRIP.c new file mode 100644 index 000000000..30e994641 --- /dev/null +++ b/net/netfilter/xt_TCPOPTSTRIP.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * A module for stripping a specific TCP option from TCP packets. 
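+ * Editor's note, illustrative only (option names follow the assumed userspace
+ * syntax): matched options are overwritten in place with TCPOPT_NOP, e.g.
+ *   iptables -t mangle -A FORWARD -p tcp -j TCPOPTSTRIP --strip-options wscale,sack-permitted
+ * which keeps packet length and checksum consistent while hiding those options.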
+ * + * Copyright (C) 2007 Sven Schnelle + * Copyright © CC Computer Consultants GmbH, 2007 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static inline unsigned int optlen(const u_int8_t *opt, unsigned int offset) +{ + /* Beware zero-length options: make finite progress */ + if (opt[offset] <= TCPOPT_NOP || opt[offset+1] == 0) + return 1; + else + return opt[offset+1]; +} + +static unsigned int +tcpoptstrip_mangle_packet(struct sk_buff *skb, + const struct xt_action_param *par, + unsigned int tcphoff) +{ + const struct xt_tcpoptstrip_target_info *info = par->targinfo; + struct tcphdr *tcph, _th; + unsigned int optl, i, j; + u_int16_t n, o; + u_int8_t *opt; + int tcp_hdrlen; + + /* This is a fragment, no TCP header is available */ + if (par->fragoff != 0) + return XT_CONTINUE; + + tcph = skb_header_pointer(skb, tcphoff, sizeof(_th), &_th); + if (!tcph) + return NF_DROP; + + tcp_hdrlen = tcph->doff * 4; + if (tcp_hdrlen < sizeof(struct tcphdr)) + return NF_DROP; + + if (skb_ensure_writable(skb, tcphoff + tcp_hdrlen)) + return NF_DROP; + + /* must reload tcph, might have been moved */ + tcph = (struct tcphdr *)(skb_network_header(skb) + tcphoff); + opt = (u8 *)tcph; + + /* + * Walk through all TCP options - if we find some option to remove, + * set all octets to %TCPOPT_NOP and adjust checksum. + */ + for (i = sizeof(struct tcphdr); i < tcp_hdrlen - 1; i += optl) { + optl = optlen(opt, i); + + if (i + optl > tcp_hdrlen) + break; + + if (!tcpoptstrip_test_bit(info->strip_bmap, opt[i])) + continue; + + for (j = 0; j < optl; ++j) { + o = opt[i+j]; + n = TCPOPT_NOP; + if ((i + j) % 2 == 0) { + o <<= 8; + n <<= 8; + } + inet_proto_csum_replace2(&tcph->check, skb, htons(o), + htons(n), false); + } + memset(opt + i, TCPOPT_NOP, optl); + } + + return XT_CONTINUE; +} + +static unsigned int +tcpoptstrip_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + return tcpoptstrip_mangle_packet(skb, par, ip_hdrlen(skb)); +} + +#if IS_ENABLED(CONFIG_IP6_NF_MANGLE) +static unsigned int +tcpoptstrip_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + struct ipv6hdr *ipv6h = ipv6_hdr(skb); + int tcphoff; + u_int8_t nexthdr; + __be16 frag_off; + + nexthdr = ipv6h->nexthdr; + tcphoff = ipv6_skip_exthdr(skb, sizeof(*ipv6h), &nexthdr, &frag_off); + if (tcphoff < 0) + return NF_DROP; + + return tcpoptstrip_mangle_packet(skb, par, tcphoff); +} +#endif + +static struct xt_target tcpoptstrip_tg_reg[] __read_mostly = { + { + .name = "TCPOPTSTRIP", + .family = NFPROTO_IPV4, + .table = "mangle", + .proto = IPPROTO_TCP, + .target = tcpoptstrip_tg4, + .targetsize = sizeof(struct xt_tcpoptstrip_target_info), + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_MANGLE) + { + .name = "TCPOPTSTRIP", + .family = NFPROTO_IPV6, + .table = "mangle", + .proto = IPPROTO_TCP, + .target = tcpoptstrip_tg6, + .targetsize = sizeof(struct xt_tcpoptstrip_target_info), + .me = THIS_MODULE, + }, +#endif +}; + +static int __init tcpoptstrip_tg_init(void) +{ + return xt_register_targets(tcpoptstrip_tg_reg, + ARRAY_SIZE(tcpoptstrip_tg_reg)); +} + +static void __exit tcpoptstrip_tg_exit(void) +{ + xt_unregister_targets(tcpoptstrip_tg_reg, + ARRAY_SIZE(tcpoptstrip_tg_reg)); +} + +module_init(tcpoptstrip_tg_init); +module_exit(tcpoptstrip_tg_exit); +MODULE_AUTHOR("Sven Schnelle , Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: TCP option stripping"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_TCPOPTSTRIP"); +MODULE_ALIAS("ip6t_TCPOPTSTRIP"); diff --git 
a/net/netfilter/xt_TEE.c b/net/netfilter/xt_TEE.c new file mode 100644 index 000000000..a5ebd5640 --- /dev/null +++ b/net/netfilter/xt_TEE.c @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * "TEE" target extension for Xtables + * Copyright © Sebastian Claßen, 2007 + * Jan Engelhardt, 2007-2010 + * + * based on ipt_ROUTE.c from Cédric de Launois + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct xt_tee_priv { + struct list_head list; + struct xt_tee_tginfo *tginfo; + int oif; +}; + +static unsigned int tee_net_id __read_mostly; +static const union nf_inet_addr tee_zero_address; + +struct tee_net { + struct list_head priv_list; + /* lock protects the priv_list */ + struct mutex lock; +}; + +static unsigned int +tee_tg4(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tee_tginfo *info = par->targinfo; + int oif = info->priv ? info->priv->oif : 0; + + nf_dup_ipv4(xt_net(par), skb, xt_hooknum(par), &info->gw.in, oif); + + return XT_CONTINUE; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static unsigned int +tee_tg6(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tee_tginfo *info = par->targinfo; + int oif = info->priv ? info->priv->oif : 0; + + nf_dup_ipv6(xt_net(par), skb, xt_hooknum(par), &info->gw.in6, oif); + + return XT_CONTINUE; +} +#endif + +static int tee_netdev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct tee_net *tn = net_generic(net, tee_net_id); + struct xt_tee_priv *priv; + + mutex_lock(&tn->lock); + list_for_each_entry(priv, &tn->priv_list, list) { + switch (event) { + case NETDEV_REGISTER: + if (!strcmp(dev->name, priv->tginfo->oif)) + priv->oif = dev->ifindex; + break; + case NETDEV_UNREGISTER: + if (dev->ifindex == priv->oif) + priv->oif = -1; + break; + case NETDEV_CHANGENAME: + if (!strcmp(dev->name, priv->tginfo->oif)) + priv->oif = dev->ifindex; + else if (dev->ifindex == priv->oif) + priv->oif = -1; + break; + } + } + mutex_unlock(&tn->lock); + + return NOTIFY_DONE; +} + +static int tee_tg_check(const struct xt_tgchk_param *par) +{ + struct tee_net *tn = net_generic(par->net, tee_net_id); + struct xt_tee_tginfo *info = par->targinfo; + struct xt_tee_priv *priv; + + /* 0.0.0.0 and :: not allowed */ + if (memcmp(&info->gw, &tee_zero_address, + sizeof(tee_zero_address)) == 0) + return -EINVAL; + + if (info->oif[0]) { + struct net_device *dev; + + if (info->oif[sizeof(info->oif)-1] != '\0') + return -EINVAL; + + priv = kzalloc(sizeof(*priv), GFP_KERNEL); + if (priv == NULL) + return -ENOMEM; + + priv->tginfo = info; + priv->oif = -1; + info->priv = priv; + + dev = dev_get_by_name(par->net, info->oif); + if (dev) { + priv->oif = dev->ifindex; + dev_put(dev); + } + mutex_lock(&tn->lock); + list_add(&priv->list, &tn->priv_list); + mutex_unlock(&tn->lock); + } else + info->priv = NULL; + + static_key_slow_inc(&xt_tee_enabled); + return 0; +} + +static void tee_tg_destroy(const struct xt_tgdtor_param *par) +{ + struct tee_net *tn = net_generic(par->net, tee_net_id); + struct xt_tee_tginfo *info = par->targinfo; + + if (info->priv) { + mutex_lock(&tn->lock); + list_del(&info->priv->list); + mutex_unlock(&tn->lock); + kfree(info->priv); + } + static_key_slow_dec(&xt_tee_enabled); +} + +static struct xt_target tee_tg_reg[] __read_mostly = { + { + .name = "TEE", + .revision = 1, + .family = NFPROTO_IPV4, + 
.target = tee_tg4, + .targetsize = sizeof(struct xt_tee_tginfo), + .usersize = offsetof(struct xt_tee_tginfo, priv), + .checkentry = tee_tg_check, + .destroy = tee_tg_destroy, + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "TEE", + .revision = 1, + .family = NFPROTO_IPV6, + .target = tee_tg6, + .targetsize = sizeof(struct xt_tee_tginfo), + .usersize = offsetof(struct xt_tee_tginfo, priv), + .checkentry = tee_tg_check, + .destroy = tee_tg_destroy, + .me = THIS_MODULE, + }, +#endif +}; + +static int __net_init tee_net_init(struct net *net) +{ + struct tee_net *tn = net_generic(net, tee_net_id); + + INIT_LIST_HEAD(&tn->priv_list); + mutex_init(&tn->lock); + return 0; +} + +static struct pernet_operations tee_net_ops = { + .init = tee_net_init, + .id = &tee_net_id, + .size = sizeof(struct tee_net), +}; + +static struct notifier_block tee_netdev_notifier = { + .notifier_call = tee_netdev_event, +}; + +static int __init tee_tg_init(void) +{ + int ret; + + ret = register_pernet_subsys(&tee_net_ops); + if (ret < 0) + return ret; + + ret = xt_register_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); + if (ret < 0) + goto cleanup_subsys; + + ret = register_netdevice_notifier(&tee_netdev_notifier); + if (ret < 0) + goto unregister_targets; + + return 0; + +unregister_targets: + xt_unregister_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); +cleanup_subsys: + unregister_pernet_subsys(&tee_net_ops); + return ret; +} + +static void __exit tee_tg_exit(void) +{ + unregister_netdevice_notifier(&tee_netdev_notifier); + xt_unregister_targets(tee_tg_reg, ARRAY_SIZE(tee_tg_reg)); + unregister_pernet_subsys(&tee_net_ops); +} + +module_init(tee_tg_init); +module_exit(tee_tg_exit); +MODULE_AUTHOR("Sebastian Claßen "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: Reroute packet copy"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_TEE"); +MODULE_ALIAS("ip6t_TEE"); diff --git a/net/netfilter/xt_TPROXY.c b/net/netfilter/xt_TPROXY.c new file mode 100644 index 000000000..e4bea1d34 --- /dev/null +++ b/net/netfilter/xt_TPROXY.c @@ -0,0 +1,269 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Transparent proxy support for Linux/iptables + * + * Copyright (c) 2006-2010 BalaBit IT Ltd. 
+ * Author: Balazs Scheidler, Krisztian Kovacs + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +#define XT_TPROXY_HAVE_IPV6 1 +#include +#include +#include +#include +#include +#endif + +#include +#include + +static unsigned int +tproxy_tg4(struct net *net, struct sk_buff *skb, __be32 laddr, __be16 lport, + u_int32_t mark_mask, u_int32_t mark_value) +{ + const struct iphdr *iph = ip_hdr(skb); + struct udphdr _hdr, *hp; + struct sock *sk; + + hp = skb_header_pointer(skb, ip_hdrlen(skb), sizeof(_hdr), &_hdr); + if (hp == NULL) + return NF_DROP; + + /* check if there's an ongoing connection on the packet + * addresses, this happens if the redirect already happened + * and the current packet belongs to an already established + * connection */ + sk = nf_tproxy_get_sock_v4(net, skb, iph->protocol, + iph->saddr, iph->daddr, + hp->source, hp->dest, + skb->dev, NF_TPROXY_LOOKUP_ESTABLISHED); + + laddr = nf_tproxy_laddr4(skb, laddr, iph->daddr); + if (!lport) + lport = hp->dest; + + /* UDP has no TCP_TIME_WAIT state, so we never enter here */ + if (sk && sk->sk_state == TCP_TIME_WAIT) + /* reopening a TIME_WAIT connection needs special handling */ + sk = nf_tproxy_handle_time_wait4(net, skb, laddr, lport, sk); + else if (!sk) + /* no, there's no established connection, check if + * there's a listener on the redirected addr/port */ + sk = nf_tproxy_get_sock_v4(net, skb, iph->protocol, + iph->saddr, laddr, + hp->source, lport, + skb->dev, NF_TPROXY_LOOKUP_LISTENER); + + /* NOTE: assign_sock consumes our sk reference */ + if (sk && nf_tproxy_sk_is_transparent(sk)) { + /* This should be in a separate target, but we don't do multiple + targets on the same rule yet */ + skb->mark = (skb->mark & ~mark_mask) ^ mark_value; + nf_tproxy_assign_sock(skb, sk); + return NF_ACCEPT; + } + + return NF_DROP; +} + +static unsigned int +tproxy_tg4_v0(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tproxy_target_info *tgi = par->targinfo; + + return tproxy_tg4(xt_net(par), skb, tgi->laddr, tgi->lport, + tgi->mark_mask, tgi->mark_value); +} + +static unsigned int +tproxy_tg4_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_tproxy_target_info_v1 *tgi = par->targinfo; + + return tproxy_tg4(xt_net(par), skb, tgi->laddr.ip, tgi->lport, + tgi->mark_mask, tgi->mark_value); +} + +#ifdef XT_TPROXY_HAVE_IPV6 + +static unsigned int +tproxy_tg6_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + const struct xt_tproxy_target_info_v1 *tgi = par->targinfo; + struct udphdr _hdr, *hp; + struct sock *sk; + const struct in6_addr *laddr; + __be16 lport; + int thoff = 0; + int tproto; + + tproto = ipv6_find_hdr(skb, &thoff, -1, NULL, NULL); + if (tproto < 0) + return NF_DROP; + + hp = skb_header_pointer(skb, thoff, sizeof(_hdr), &_hdr); + if (!hp) + return NF_DROP; + + /* check if there's an ongoing connection on the packet + * addresses, this happens if the redirect already happened + * and the current packet belongs to an already established + * connection */ + sk = nf_tproxy_get_sock_v6(xt_net(par), skb, thoff, tproto, + &iph->saddr, &iph->daddr, + hp->source, hp->dest, + xt_in(par), NF_TPROXY_LOOKUP_ESTABLISHED); + + laddr = nf_tproxy_laddr6(skb, &tgi->laddr.in6, &iph->daddr); + lport = tgi->lport ? 
tgi->lport : hp->dest; + + /* UDP has no TCP_TIME_WAIT state, so we never enter here */ + if (sk && sk->sk_state == TCP_TIME_WAIT) { + const struct xt_tproxy_target_info_v1 *tgi = par->targinfo; + /* reopening a TIME_WAIT connection needs special handling */ + sk = nf_tproxy_handle_time_wait6(skb, tproto, thoff, + xt_net(par), + &tgi->laddr.in6, + tgi->lport, + sk); + } + else if (!sk) + /* no there's no established connection, check if + * there's a listener on the redirected addr/port */ + sk = nf_tproxy_get_sock_v6(xt_net(par), skb, thoff, + tproto, &iph->saddr, laddr, + hp->source, lport, + xt_in(par), NF_TPROXY_LOOKUP_LISTENER); + + /* NOTE: assign_sock consumes our sk reference */ + if (sk && nf_tproxy_sk_is_transparent(sk)) { + /* This should be in a separate target, but we don't do multiple + targets on the same rule yet */ + skb->mark = (skb->mark & ~tgi->mark_mask) ^ tgi->mark_value; + nf_tproxy_assign_sock(skb, sk); + return NF_ACCEPT; + } + + return NF_DROP; +} + +static int tproxy_tg6_check(const struct xt_tgchk_param *par) +{ + const struct ip6t_ip6 *i = par->entryinfo; + int err; + + err = nf_defrag_ipv6_enable(par->net); + if (err) + return err; + + if ((i->proto == IPPROTO_TCP || i->proto == IPPROTO_UDP) && + !(i->invflags & IP6T_INV_PROTO)) + return 0; + + pr_info_ratelimited("Can be used only with -p tcp or -p udp\n"); + return -EINVAL; +} + +static void tproxy_tg6_destroy(const struct xt_tgdtor_param *par) +{ + nf_defrag_ipv6_disable(par->net); +} +#endif + +static int tproxy_tg4_check(const struct xt_tgchk_param *par) +{ + const struct ipt_ip *i = par->entryinfo; + int err; + + err = nf_defrag_ipv4_enable(par->net); + if (err) + return err; + + if ((i->proto == IPPROTO_TCP || i->proto == IPPROTO_UDP) + && !(i->invflags & IPT_INV_PROTO)) + return 0; + + pr_info_ratelimited("Can be used only with -p tcp or -p udp\n"); + return -EINVAL; +} + +static void tproxy_tg4_destroy(const struct xt_tgdtor_param *par) +{ + nf_defrag_ipv4_disable(par->net); +} + +static struct xt_target tproxy_tg_reg[] __read_mostly = { + { + .name = "TPROXY", + .family = NFPROTO_IPV4, + .table = "mangle", + .target = tproxy_tg4_v0, + .revision = 0, + .targetsize = sizeof(struct xt_tproxy_target_info), + .checkentry = tproxy_tg4_check, + .destroy = tproxy_tg4_destroy, + .hooks = 1 << NF_INET_PRE_ROUTING, + .me = THIS_MODULE, + }, + { + .name = "TPROXY", + .family = NFPROTO_IPV4, + .table = "mangle", + .target = tproxy_tg4_v1, + .revision = 1, + .targetsize = sizeof(struct xt_tproxy_target_info_v1), + .checkentry = tproxy_tg4_check, + .destroy = tproxy_tg4_destroy, + .hooks = 1 << NF_INET_PRE_ROUTING, + .me = THIS_MODULE, + }, +#ifdef XT_TPROXY_HAVE_IPV6 + { + .name = "TPROXY", + .family = NFPROTO_IPV6, + .table = "mangle", + .target = tproxy_tg6_v1, + .revision = 1, + .targetsize = sizeof(struct xt_tproxy_target_info_v1), + .checkentry = tproxy_tg6_check, + .destroy = tproxy_tg6_destroy, + .hooks = 1 << NF_INET_PRE_ROUTING, + .me = THIS_MODULE, + }, +#endif + +}; + +static int __init tproxy_tg_init(void) +{ + return xt_register_targets(tproxy_tg_reg, ARRAY_SIZE(tproxy_tg_reg)); +} + +static void __exit tproxy_tg_exit(void) +{ + xt_unregister_targets(tproxy_tg_reg, ARRAY_SIZE(tproxy_tg_reg)); +} + +module_init(tproxy_tg_init); +module_exit(tproxy_tg_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Balazs Scheidler, Krisztian Kovacs"); +MODULE_DESCRIPTION("Netfilter transparent proxy (TPROXY) target module."); +MODULE_ALIAS("ipt_TPROXY"); +MODULE_ALIAS("ip6t_TPROXY"); diff --git 
a/net/netfilter/xt_TRACE.c b/net/netfilter/xt_TRACE.c new file mode 100644 index 000000000..5582dce98 --- /dev/null +++ b/net/netfilter/xt_TRACE.c @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* This is a module which is used to mark packets for tracing. + */ +#include +#include + +#include +#include + +MODULE_DESCRIPTION("Xtables: packet flow tracing"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_TRACE"); +MODULE_ALIAS("ip6t_TRACE"); + +static int trace_tg_check(const struct xt_tgchk_param *par) +{ + return nf_logger_find_get(par->family, NF_LOG_TYPE_LOG); +} + +static void trace_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_logger_put(par->family, NF_LOG_TYPE_LOG); +} + +static unsigned int +trace_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + skb->nf_trace = 1; + return XT_CONTINUE; +} + +static struct xt_target trace_tg_reg __read_mostly = { + .name = "TRACE", + .revision = 0, + .family = NFPROTO_UNSPEC, + .table = "raw", + .target = trace_tg, + .checkentry = trace_tg_check, + .destroy = trace_tg_destroy, + .me = THIS_MODULE, +}; + +static int __init trace_tg_init(void) +{ + return xt_register_target(&trace_tg_reg); +} + +static void __exit trace_tg_exit(void) +{ + xt_unregister_target(&trace_tg_reg); +} + +module_init(trace_tg_init); +module_exit(trace_tg_exit); +MODULE_SOFTDEP("pre: nf_log_syslog"); diff --git a/net/netfilter/xt_addrtype.c b/net/netfilter/xt_addrtype.c new file mode 100644 index 000000000..e9b2181e8 --- /dev/null +++ b/net/netfilter/xt_addrtype.c @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * iptables module to match inet_addr_type() of an ip. + * + * Copyright (c) 2004 Patrick McHardy + * (C) 2007 Laszlo Attila Toth + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +#include +#include +#include +#endif + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Xtables: address type match"); +MODULE_ALIAS("ipt_addrtype"); +MODULE_ALIAS("ip6t_addrtype"); + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static u32 match_lookup_rt6(struct net *net, const struct net_device *dev, + const struct in6_addr *addr, u16 mask) +{ + struct flowi6 flow; + struct rt6_info *rt; + u32 ret = 0; + int route_err; + + memset(&flow, 0, sizeof(flow)); + flow.daddr = *addr; + if (dev) + flow.flowi6_oif = dev->ifindex; + + if (dev && (mask & XT_ADDRTYPE_LOCAL)) { + if (nf_ipv6_chk_addr(net, addr, dev, true)) + ret = XT_ADDRTYPE_LOCAL; + } + + route_err = nf_ip6_route(net, (struct dst_entry **)&rt, + flowi6_to_flowi(&flow), false); + if (route_err) + return XT_ADDRTYPE_UNREACHABLE; + + if (rt->rt6i_flags & RTF_REJECT) + ret = XT_ADDRTYPE_UNREACHABLE; + + if (dev == NULL && rt->rt6i_flags & RTF_LOCAL) + ret |= XT_ADDRTYPE_LOCAL; + if (ipv6_anycast_destination((struct dst_entry *)rt, addr)) + ret |= XT_ADDRTYPE_ANYCAST; + + dst_release(&rt->dst); + return ret; +} + +static bool match_type6(struct net *net, const struct net_device *dev, + const struct in6_addr *addr, u16 mask) +{ + int addr_type = ipv6_addr_type(addr); + + if ((mask & XT_ADDRTYPE_MULTICAST) && + !(addr_type & IPV6_ADDR_MULTICAST)) + return false; + if ((mask & XT_ADDRTYPE_UNICAST) && !(addr_type & IPV6_ADDR_UNICAST)) + return false; + if ((mask & XT_ADDRTYPE_UNSPEC) && addr_type != IPV6_ADDR_ANY) + return false; + + if ((XT_ADDRTYPE_LOCAL | XT_ADDRTYPE_ANYCAST | + XT_ADDRTYPE_UNREACHABLE) & mask) + return !!(mask & 
match_lookup_rt6(net, dev, addr, mask)); + return true; +} + +static bool +addrtype_mt6(struct net *net, const struct net_device *dev, + const struct sk_buff *skb, const struct xt_addrtype_info_v1 *info) +{ + const struct ipv6hdr *iph = ipv6_hdr(skb); + bool ret = true; + + if (info->source) + ret &= match_type6(net, dev, &iph->saddr, info->source) ^ + (info->flags & XT_ADDRTYPE_INVERT_SOURCE); + if (ret && info->dest) + ret &= match_type6(net, dev, &iph->daddr, info->dest) ^ + !!(info->flags & XT_ADDRTYPE_INVERT_DEST); + return ret; +} +#endif + +static inline bool match_type(struct net *net, const struct net_device *dev, + __be32 addr, u_int16_t mask) +{ + return !!(mask & (1 << inet_dev_addr_type(net, dev, addr))); +} + +static bool +addrtype_mt_v0(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct net *net = xt_net(par); + const struct xt_addrtype_info *info = par->matchinfo; + const struct iphdr *iph = ip_hdr(skb); + bool ret = true; + + if (info->source) + ret &= match_type(net, NULL, iph->saddr, info->source) ^ + info->invert_source; + if (info->dest) + ret &= match_type(net, NULL, iph->daddr, info->dest) ^ + info->invert_dest; + + return ret; +} + +static bool +addrtype_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct net *net = xt_net(par); + const struct xt_addrtype_info_v1 *info = par->matchinfo; + const struct iphdr *iph; + const struct net_device *dev = NULL; + bool ret = true; + + if (info->flags & XT_ADDRTYPE_LIMIT_IFACE_IN) + dev = xt_in(par); + else if (info->flags & XT_ADDRTYPE_LIMIT_IFACE_OUT) + dev = xt_out(par); + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + if (xt_family(par) == NFPROTO_IPV6) + return addrtype_mt6(net, dev, skb, info); +#endif + iph = ip_hdr(skb); + if (info->source) + ret &= match_type(net, dev, iph->saddr, info->source) ^ + (info->flags & XT_ADDRTYPE_INVERT_SOURCE); + if (ret && info->dest) + ret &= match_type(net, dev, iph->daddr, info->dest) ^ + !!(info->flags & XT_ADDRTYPE_INVERT_DEST); + return ret; +} + +static int addrtype_mt_checkentry_v1(const struct xt_mtchk_param *par) +{ + const char *errmsg = "both incoming and outgoing interface limitation cannot be selected"; + struct xt_addrtype_info_v1 *info = par->matchinfo; + + if (info->flags & XT_ADDRTYPE_LIMIT_IFACE_IN && + info->flags & XT_ADDRTYPE_LIMIT_IFACE_OUT) + goto err; + + if (par->hook_mask & ((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN)) && + info->flags & XT_ADDRTYPE_LIMIT_IFACE_OUT) { + errmsg = "output interface limitation not valid in PREROUTING and INPUT"; + goto err; + } + + if (par->hook_mask & ((1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT)) && + info->flags & XT_ADDRTYPE_LIMIT_IFACE_IN) { + errmsg = "input interface limitation not valid in POSTROUTING and OUTPUT"; + goto err; + } + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + if (par->family == NFPROTO_IPV6) { + if ((info->source | info->dest) & XT_ADDRTYPE_BLACKHOLE) { + errmsg = "ipv6 BLACKHOLE matching not supported"; + goto err; + } + if ((info->source | info->dest) >= XT_ADDRTYPE_PROHIBIT) { + errmsg = "ipv6 PROHIBIT (THROW, NAT ..) 
matching not supported"; + goto err; + } + if ((info->source | info->dest) & XT_ADDRTYPE_BROADCAST) { + errmsg = "ipv6 does not support BROADCAST matching"; + goto err; + } + } +#endif + return 0; +err: + pr_info_ratelimited("%s\n", errmsg); + return -EINVAL; +} + +static struct xt_match addrtype_mt_reg[] __read_mostly = { + { + .name = "addrtype", + .family = NFPROTO_IPV4, + .match = addrtype_mt_v0, + .matchsize = sizeof(struct xt_addrtype_info), + .me = THIS_MODULE + }, + { + .name = "addrtype", + .family = NFPROTO_UNSPEC, + .revision = 1, + .match = addrtype_mt_v1, + .checkentry = addrtype_mt_checkentry_v1, + .matchsize = sizeof(struct xt_addrtype_info_v1), + .me = THIS_MODULE + } +}; + +static int __init addrtype_mt_init(void) +{ + return xt_register_matches(addrtype_mt_reg, + ARRAY_SIZE(addrtype_mt_reg)); +} + +static void __exit addrtype_mt_exit(void) +{ + xt_unregister_matches(addrtype_mt_reg, ARRAY_SIZE(addrtype_mt_reg)); +} + +module_init(addrtype_mt_init); +module_exit(addrtype_mt_exit); diff --git a/net/netfilter/xt_bpf.c b/net/netfilter/xt_bpf.c new file mode 100644 index 000000000..849ac552a --- /dev/null +++ b/net/netfilter/xt_bpf.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Xtables module to match packets using a BPF filter. + * Copyright 2013 Google Inc. + * Written by Willem de Bruijn + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("Willem de Bruijn "); +MODULE_DESCRIPTION("Xtables: BPF filter match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_bpf"); +MODULE_ALIAS("ip6t_bpf"); + +static int __bpf_mt_check_bytecode(struct sock_filter *insns, __u16 len, + struct bpf_prog **ret) +{ + struct sock_fprog_kern program; + + if (len > XT_BPF_MAX_NUM_INSTR) + return -EINVAL; + + program.len = len; + program.filter = insns; + + if (bpf_prog_create(ret, &program)) { + pr_info_ratelimited("check failed: parse error\n"); + return -EINVAL; + } + + return 0; +} + +static int __bpf_mt_check_fd(int fd, struct bpf_prog **ret) +{ + struct bpf_prog *prog; + + prog = bpf_prog_get_type(fd, BPF_PROG_TYPE_SOCKET_FILTER); + if (IS_ERR(prog)) + return PTR_ERR(prog); + + *ret = prog; + return 0; +} + +static int __bpf_mt_check_path(const char *path, struct bpf_prog **ret) +{ + if (strnlen(path, XT_BPF_PATH_MAX) == XT_BPF_PATH_MAX) + return -EINVAL; + + *ret = bpf_prog_get_type_path(path, BPF_PROG_TYPE_SOCKET_FILTER); + return PTR_ERR_OR_ZERO(*ret); +} + +static int bpf_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_bpf_info *info = par->matchinfo; + + return __bpf_mt_check_bytecode(info->bpf_program, + info->bpf_program_num_elem, + &info->filter); +} + +static int bpf_mt_check_v1(const struct xt_mtchk_param *par) +{ + struct xt_bpf_info_v1 *info = par->matchinfo; + + if (info->mode == XT_BPF_MODE_BYTECODE) + return __bpf_mt_check_bytecode(info->bpf_program, + info->bpf_program_num_elem, + &info->filter); + else if (info->mode == XT_BPF_MODE_FD_ELF) + return __bpf_mt_check_fd(info->fd, &info->filter); + else if (info->mode == XT_BPF_MODE_PATH_PINNED) + return __bpf_mt_check_path(info->path, &info->filter); + else + return -EINVAL; +} + +static bool bpf_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_bpf_info *info = par->matchinfo; + + return bpf_prog_run(info->filter, skb); +} + +static bool bpf_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_bpf_info_v1 *info = par->matchinfo; + + return 
!!bpf_prog_run_save_cb(info->filter, (struct sk_buff *) skb); +} + +static void bpf_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_bpf_info *info = par->matchinfo; + + bpf_prog_destroy(info->filter); +} + +static void bpf_mt_destroy_v1(const struct xt_mtdtor_param *par) +{ + const struct xt_bpf_info_v1 *info = par->matchinfo; + + bpf_prog_destroy(info->filter); +} + +static struct xt_match bpf_mt_reg[] __read_mostly = { + { + .name = "bpf", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = bpf_mt_check, + .match = bpf_mt, + .destroy = bpf_mt_destroy, + .matchsize = sizeof(struct xt_bpf_info), + .usersize = offsetof(struct xt_bpf_info, filter), + .me = THIS_MODULE, + }, + { + .name = "bpf", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = bpf_mt_check_v1, + .match = bpf_mt_v1, + .destroy = bpf_mt_destroy_v1, + .matchsize = sizeof(struct xt_bpf_info_v1), + .usersize = offsetof(struct xt_bpf_info_v1, filter), + .me = THIS_MODULE, + }, +}; + +static int __init bpf_mt_init(void) +{ + return xt_register_matches(bpf_mt_reg, ARRAY_SIZE(bpf_mt_reg)); +} + +static void __exit bpf_mt_exit(void) +{ + xt_unregister_matches(bpf_mt_reg, ARRAY_SIZE(bpf_mt_reg)); +} + +module_init(bpf_mt_init); +module_exit(bpf_mt_exit); diff --git a/net/netfilter/xt_cgroup.c b/net/netfilter/xt_cgroup.c new file mode 100644 index 000000000..c0f5e9a4f --- /dev/null +++ b/net/netfilter/xt_cgroup.c @@ -0,0 +1,219 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Xtables module to match the process control group. + * + * Might be used to implement individual "per-application" firewall + * policies in contrast to global policies based on control groups. + * Matching is based upon processes tagged to net_cls' classid marker. + * + * (C) 2013 Daniel Borkmann + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Daniel Borkmann "); +MODULE_DESCRIPTION("Xtables: process control group matching"); +MODULE_ALIAS("ipt_cgroup"); +MODULE_ALIAS("ip6t_cgroup"); + +static int cgroup_mt_check_v0(const struct xt_mtchk_param *par) +{ + struct xt_cgroup_info_v0 *info = par->matchinfo; + + if (info->invert & ~1) + return -EINVAL; + + return 0; +} + +static int cgroup_mt_check_v1(const struct xt_mtchk_param *par) +{ + struct xt_cgroup_info_v1 *info = par->matchinfo; + struct cgroup *cgrp; + + if ((info->invert_path & ~1) || (info->invert_classid & ~1)) + return -EINVAL; + + if (!info->has_path && !info->has_classid) { + pr_info("xt_cgroup: no path or classid specified\n"); + return -EINVAL; + } + + if (info->has_path && info->has_classid) { + pr_info_ratelimited("path and classid specified\n"); + return -EINVAL; + } + + info->priv = NULL; + if (info->has_path) { + cgrp = cgroup_get_from_path(info->path); + if (IS_ERR(cgrp)) { + pr_info_ratelimited("invalid path, errno=%ld\n", + PTR_ERR(cgrp)); + return -EINVAL; + } + info->priv = cgrp; + } + + return 0; +} + +static int cgroup_mt_check_v2(const struct xt_mtchk_param *par) +{ + struct xt_cgroup_info_v2 *info = par->matchinfo; + struct cgroup *cgrp; + + if ((info->invert_path & ~1) || (info->invert_classid & ~1)) + return -EINVAL; + + if (!info->has_path && !info->has_classid) { + pr_info("xt_cgroup: no path or classid specified\n"); + return -EINVAL; + } + + if (info->has_path && info->has_classid) { + pr_info_ratelimited("path and classid specified\n"); + return -EINVAL; + } + + info->priv = NULL; + if (info->has_path) { + cgrp = 
cgroup_get_from_path(info->path); + if (IS_ERR(cgrp)) { + pr_info_ratelimited("invalid path, errno=%ld\n", + PTR_ERR(cgrp)); + return -EINVAL; + } + info->priv = cgrp; + } + + return 0; +} + +static bool +cgroup_mt_v0(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_cgroup_info_v0 *info = par->matchinfo; + struct sock *sk = skb->sk; + + if (!sk || !sk_fullsock(sk) || !net_eq(xt_net(par), sock_net(sk))) + return false; + + return (info->id == sock_cgroup_classid(&skb->sk->sk_cgrp_data)) ^ + info->invert; +} + +static bool cgroup_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_cgroup_info_v1 *info = par->matchinfo; + struct sock_cgroup_data *skcd = &skb->sk->sk_cgrp_data; + struct cgroup *ancestor = info->priv; + struct sock *sk = skb->sk; + + if (!sk || !sk_fullsock(sk) || !net_eq(xt_net(par), sock_net(sk))) + return false; + + if (ancestor) + return cgroup_is_descendant(sock_cgroup_ptr(skcd), ancestor) ^ + info->invert_path; + else + return (info->classid == sock_cgroup_classid(skcd)) ^ + info->invert_classid; +} + +static bool cgroup_mt_v2(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_cgroup_info_v2 *info = par->matchinfo; + struct sock_cgroup_data *skcd = &skb->sk->sk_cgrp_data; + struct cgroup *ancestor = info->priv; + struct sock *sk = skb->sk; + + if (!sk || !sk_fullsock(sk) || !net_eq(xt_net(par), sock_net(sk))) + return false; + + if (ancestor) + return cgroup_is_descendant(sock_cgroup_ptr(skcd), ancestor) ^ + info->invert_path; + else + return (info->classid == sock_cgroup_classid(skcd)) ^ + info->invert_classid; +} + +static void cgroup_mt_destroy_v1(const struct xt_mtdtor_param *par) +{ + struct xt_cgroup_info_v1 *info = par->matchinfo; + + if (info->priv) + cgroup_put(info->priv); +} + +static void cgroup_mt_destroy_v2(const struct xt_mtdtor_param *par) +{ + struct xt_cgroup_info_v2 *info = par->matchinfo; + + if (info->priv) + cgroup_put(info->priv); +} + +static struct xt_match cgroup_mt_reg[] __read_mostly = { + { + .name = "cgroup", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = cgroup_mt_check_v0, + .match = cgroup_mt_v0, + .matchsize = sizeof(struct xt_cgroup_info_v0), + .me = THIS_MODULE, + .hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + }, + { + .name = "cgroup", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = cgroup_mt_check_v1, + .match = cgroup_mt_v1, + .matchsize = sizeof(struct xt_cgroup_info_v1), + .usersize = offsetof(struct xt_cgroup_info_v1, priv), + .destroy = cgroup_mt_destroy_v1, + .me = THIS_MODULE, + .hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + }, + { + .name = "cgroup", + .revision = 2, + .family = NFPROTO_UNSPEC, + .checkentry = cgroup_mt_check_v2, + .match = cgroup_mt_v2, + .matchsize = sizeof(struct xt_cgroup_info_v2), + .usersize = offsetof(struct xt_cgroup_info_v2, priv), + .destroy = cgroup_mt_destroy_v2, + .me = THIS_MODULE, + .hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + }, +}; + +static int __init cgroup_mt_init(void) +{ + return xt_register_matches(cgroup_mt_reg, ARRAY_SIZE(cgroup_mt_reg)); +} + +static void __exit cgroup_mt_exit(void) +{ + xt_unregister_matches(cgroup_mt_reg, ARRAY_SIZE(cgroup_mt_reg)); +} + +module_init(cgroup_mt_init); +module_exit(cgroup_mt_exit); diff --git a/net/netfilter/xt_cluster.c b/net/netfilter/xt_cluster.c new file mode 100644 index 
000000000..a047a5453 --- /dev/null +++ b/net/netfilter/xt_cluster.c @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2008-2009 Pablo Neira Ayuso + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include +#include + +static inline u32 nf_ct_orig_ipv4_src(const struct nf_conn *ct) +{ + return (__force u32)ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.ip; +} + +static inline const u32 *nf_ct_orig_ipv6_src(const struct nf_conn *ct) +{ + return (__force u32 *)ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3.ip6; +} + +static inline u_int32_t +xt_cluster_hash_ipv4(u_int32_t ip, const struct xt_cluster_match_info *info) +{ + return jhash_1word(ip, info->hash_seed); +} + +static inline u_int32_t +xt_cluster_hash_ipv6(const void *ip, const struct xt_cluster_match_info *info) +{ + return jhash2(ip, NF_CT_TUPLE_L3SIZE / sizeof(__u32), info->hash_seed); +} + +static inline u_int32_t +xt_cluster_hash(const struct nf_conn *ct, + const struct xt_cluster_match_info *info) +{ + u_int32_t hash = 0; + + switch(nf_ct_l3num(ct)) { + case AF_INET: + hash = xt_cluster_hash_ipv4(nf_ct_orig_ipv4_src(ct), info); + break; + case AF_INET6: + hash = xt_cluster_hash_ipv6(nf_ct_orig_ipv6_src(ct), info); + break; + default: + WARN_ON(1); + break; + } + + return reciprocal_scale(hash, info->total_nodes); +} + +static inline bool +xt_cluster_is_multicast_addr(const struct sk_buff *skb, u_int8_t family) +{ + bool is_multicast = false; + + switch(family) { + case NFPROTO_IPV4: + is_multicast = ipv4_is_multicast(ip_hdr(skb)->daddr); + break; + case NFPROTO_IPV6: + is_multicast = ipv6_addr_is_multicast(&ipv6_hdr(skb)->daddr); + break; + default: + WARN_ON(1); + break; + } + return is_multicast; +} + +static bool +xt_cluster_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct sk_buff *pskb = (struct sk_buff *)skb; + const struct xt_cluster_match_info *info = par->matchinfo; + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + unsigned long hash; + + /* This match assumes that all nodes see the same packets. This can be + * achieved if the switch that connects the cluster nodes supports some + * sort of 'port mirroring'. However, if your switch does not support + * this, your cluster nodes can reply to ARP requests using a multicast MAC + * address. Thus, your switch will flood the same packets to the + * cluster nodes with the same multicast MAC address. Using a multicast + * link address is an RFC 1812 (section 3.3.2) violation, but this works + * fine in practice. + * + * Unfortunately, if you use the multicast MAC address, the link layer + * sets skbuff's pkt_type to PACKET_MULTICAST, which is not accepted + * by TCP and others for packets coming to this node. For that reason, + * this match mangles skbuff's pkt_type if it detects a packet + * addressed to a unicast address but using PACKET_MULTICAST. Yes, I + * know, matches should not alter packets, but we are doing this here + * because we would need to add a PKTTYPE target for this sole purpose.
+ */ + if (!xt_cluster_is_multicast_addr(skb, xt_family(par)) && + skb->pkt_type == PACKET_MULTICAST) { + pskb->pkt_type = PACKET_HOST; + } + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return false; + + if (ct->master) + hash = xt_cluster_hash(ct->master, info); + else + hash = xt_cluster_hash(ct, info); + + return !!((1 << hash) & info->node_mask) ^ + !!(info->flags & XT_CLUSTER_F_INV); +} + +static int xt_cluster_mt_checkentry(const struct xt_mtchk_param *par) +{ + struct xt_cluster_match_info *info = par->matchinfo; + int ret; + + if (info->total_nodes > XT_CLUSTER_NODES_MAX) { + pr_info_ratelimited("you have exceeded the maximum number of cluster nodes (%u > %u)\n", + info->total_nodes, XT_CLUSTER_NODES_MAX); + return -EINVAL; + } + if (info->node_mask >= (1ULL << info->total_nodes)) { + pr_info_ratelimited("node mask cannot exceed total number of nodes\n"); + return -EDOM; + } + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void xt_cluster_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match xt_cluster_match __read_mostly = { + .name = "cluster", + .family = NFPROTO_UNSPEC, + .match = xt_cluster_mt, + .checkentry = xt_cluster_mt_checkentry, + .matchsize = sizeof(struct xt_cluster_match_info), + .destroy = xt_cluster_mt_destroy, + .me = THIS_MODULE, +}; + +static int __init xt_cluster_mt_init(void) +{ + return xt_register_match(&xt_cluster_match); +} + +static void __exit xt_cluster_mt_fini(void) +{ + xt_unregister_match(&xt_cluster_match); +} + +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: hash-based cluster match"); +MODULE_ALIAS("ipt_cluster"); +MODULE_ALIAS("ip6t_cluster"); +module_init(xt_cluster_mt_init); +module_exit(xt_cluster_mt_fini); diff --git a/net/netfilter/xt_comment.c b/net/netfilter/xt_comment.c new file mode 100644 index 000000000..f095557e3 --- /dev/null +++ b/net/netfilter/xt_comment.c @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Implements a dummy match to allow attaching comments to rules + * + * 2003-05-13 Brad Fisher (brad@info-link.net) + */ + +#include +#include +#include +#include + +MODULE_AUTHOR("Brad Fisher "); +MODULE_DESCRIPTION("Xtables: No-op match which can be tagged with a comment"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_comment"); +MODULE_ALIAS("ip6t_comment"); + +static bool +comment_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + /* We always match */ + return true; +} + +static struct xt_match comment_mt_reg __read_mostly = { + .name = "comment", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = comment_mt, + .matchsize = sizeof(struct xt_comment_info), + .me = THIS_MODULE, +}; + +static int __init comment_mt_init(void) +{ + return xt_register_match(&comment_mt_reg); +} + +static void __exit comment_mt_exit(void) +{ + xt_unregister_match(&comment_mt_reg); +} + +module_init(comment_mt_init); +module_exit(comment_mt_exit); diff --git a/net/netfilter/xt_connbytes.c b/net/netfilter/xt_connbytes.c new file mode 100644 index 000000000..93cb018c3 --- /dev/null +++ b/net/netfilter/xt_connbytes.c @@ -0,0 +1,157 @@ +/* Kernel module to match connection tracking byte counter. + * GPL (C) 2002 Martin Devera (devik@cdi.cz). 
+ */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: Number of packets/bytes per connection matching"); +MODULE_ALIAS("ipt_connbytes"); +MODULE_ALIAS("ip6t_connbytes"); + +static bool +connbytes_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_connbytes_info *sinfo = par->matchinfo; + const struct nf_conn *ct; + enum ip_conntrack_info ctinfo; + u_int64_t what = 0; /* initialize to make gcc happy */ + u_int64_t bytes = 0; + u_int64_t pkts = 0; + const struct nf_conn_acct *acct; + const struct nf_conn_counter *counters; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return false; + + acct = nf_conn_acct_find(ct); + if (!acct) + return false; + + counters = acct->counter; + switch (sinfo->what) { + case XT_CONNBYTES_PKTS: + switch (sinfo->direction) { + case XT_CONNBYTES_DIR_ORIGINAL: + what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets); + break; + case XT_CONNBYTES_DIR_REPLY: + what = atomic64_read(&counters[IP_CT_DIR_REPLY].packets); + break; + case XT_CONNBYTES_DIR_BOTH: + what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets); + what += atomic64_read(&counters[IP_CT_DIR_REPLY].packets); + break; + } + break; + case XT_CONNBYTES_BYTES: + switch (sinfo->direction) { + case XT_CONNBYTES_DIR_ORIGINAL: + what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes); + break; + case XT_CONNBYTES_DIR_REPLY: + what = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes); + break; + case XT_CONNBYTES_DIR_BOTH: + what = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes); + what += atomic64_read(&counters[IP_CT_DIR_REPLY].bytes); + break; + } + break; + case XT_CONNBYTES_AVGPKT: + switch (sinfo->direction) { + case XT_CONNBYTES_DIR_ORIGINAL: + bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes); + pkts = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets); + break; + case XT_CONNBYTES_DIR_REPLY: + bytes = atomic64_read(&counters[IP_CT_DIR_REPLY].bytes); + pkts = atomic64_read(&counters[IP_CT_DIR_REPLY].packets); + break; + case XT_CONNBYTES_DIR_BOTH: + bytes = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].bytes) + + atomic64_read(&counters[IP_CT_DIR_REPLY].bytes); + pkts = atomic64_read(&counters[IP_CT_DIR_ORIGINAL].packets) + + atomic64_read(&counters[IP_CT_DIR_REPLY].packets); + break; + } + if (pkts != 0) + what = div64_u64(bytes, pkts); + break; + } + + if (sinfo->count.to >= sinfo->count.from) + return what <= sinfo->count.to && what >= sinfo->count.from; + else /* inverted */ + return what < sinfo->count.to || what > sinfo->count.from; +} + +static int connbytes_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_connbytes_info *sinfo = par->matchinfo; + int ret; + + if (sinfo->what != XT_CONNBYTES_PKTS && + sinfo->what != XT_CONNBYTES_BYTES && + sinfo->what != XT_CONNBYTES_AVGPKT) + return -EINVAL; + + if (sinfo->direction != XT_CONNBYTES_DIR_ORIGINAL && + sinfo->direction != XT_CONNBYTES_DIR_REPLY && + sinfo->direction != XT_CONNBYTES_DIR_BOTH) + return -EINVAL; + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + + /* + * This filter cannot function correctly unless connection tracking + * accounting is enabled, so complain in the hope that someone notices. 
+ */ + if (!nf_ct_acct_enabled(par->net)) { + pr_warn("Forcing CT accounting to be enabled\n"); + nf_ct_set_acct(par->net, true); + } + + return ret; +} + +static void connbytes_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match connbytes_mt_reg __read_mostly = { + .name = "connbytes", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = connbytes_mt_check, + .match = connbytes_mt, + .destroy = connbytes_mt_destroy, + .matchsize = sizeof(struct xt_connbytes_info), + .me = THIS_MODULE, +}; + +static int __init connbytes_mt_init(void) +{ + return xt_register_match(&connbytes_mt_reg); +} + +static void __exit connbytes_mt_exit(void) +{ + xt_unregister_match(&connbytes_mt_reg); +} + +module_init(connbytes_mt_init); +module_exit(connbytes_mt_exit); diff --git a/net/netfilter/xt_connlabel.c b/net/netfilter/xt_connlabel.c new file mode 100644 index 000000000..87505cdad --- /dev/null +++ b/net/netfilter/xt_connlabel.c @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2013 Astaro GmbH & Co KG + */ + +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("Xtables: add/match connection tracking labels"); +MODULE_ALIAS("ipt_connlabel"); +MODULE_ALIAS("ip6t_connlabel"); + +static bool +connlabel_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_connlabel_mtinfo *info = par->matchinfo; + enum ip_conntrack_info ctinfo; + struct nf_conn_labels *labels; + struct nf_conn *ct; + bool invert = info->options & XT_CONNLABEL_OP_INVERT; + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return invert; + + labels = nf_ct_labels_find(ct); + if (!labels) + return invert; + + if (test_bit(info->bit, labels->bits)) + return !invert; + + if (info->options & XT_CONNLABEL_OP_SET) { + if (!test_and_set_bit(info->bit, labels->bits)) + nf_conntrack_event_cache(IPCT_LABEL, ct); + + return !invert; + } + + return invert; +} + +static int connlabel_mt_check(const struct xt_mtchk_param *par) +{ + const int options = XT_CONNLABEL_OP_INVERT | + XT_CONNLABEL_OP_SET; + struct xt_connlabel_mtinfo *info = par->matchinfo; + int ret; + + if (info->options & ~options) { + pr_info_ratelimited("Unknown options in mask %x\n", + info->options); + return -EINVAL; + } + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) { + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; + } + + ret = nf_connlabels_get(par->net, info->bit); + if (ret < 0) + nf_ct_netns_put(par->net, par->family); + return ret; +} + +static void connlabel_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_connlabels_put(par->net); + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match connlabels_mt_reg __read_mostly = { + .name = "connlabel", + .family = NFPROTO_UNSPEC, + .checkentry = connlabel_mt_check, + .match = connlabel_mt, + .matchsize = sizeof(struct xt_connlabel_mtinfo), + .destroy = connlabel_mt_destroy, + .me = THIS_MODULE, +}; + +static int __init connlabel_mt_init(void) +{ + return xt_register_match(&connlabels_mt_reg); +} + +static void __exit connlabel_mt_exit(void) +{ + xt_unregister_match(&connlabels_mt_reg); +} + +module_init(connlabel_mt_init); +module_exit(connlabel_mt_exit); diff --git a/net/netfilter/xt_connlimit.c b/net/netfilter/xt_connlimit.c new file mode 100644 index 000000000..5d04ef80a --- /dev/null +++ b/net/netfilter/xt_connlimit.c @@ -0,0 +1,137 @@ 
+/* + * netfilter module to limit the number of parallel tcp + * connections per IP address. + * (c) 2000 Gerd Knorr + * Nov 2002: Martin Bene : + * only ignore TIME_WAIT or gone connections + * (C) CC Computer Consultants GmbH, 2007 + * + * based on ... + * + * Kernel module to match connection tracking information. + * GPL (C) 1999 Rusty Russell (rusty@rustcorp.com.au). + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +static bool +connlimit_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct net *net = xt_net(par); + const struct xt_connlimit_info *info = par->matchinfo; + struct nf_conntrack_tuple tuple; + const struct nf_conntrack_tuple *tuple_ptr = &tuple; + const struct nf_conntrack_zone *zone = &nf_ct_zone_dflt; + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct; + unsigned int connections; + u32 key[5]; + + ct = nf_ct_get(skb, &ctinfo); + if (ct != NULL) { + tuple_ptr = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + zone = nf_ct_zone(ct); + } else if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), + xt_family(par), net, &tuple)) { + goto hotdrop; + } + + if (xt_family(par) == NFPROTO_IPV6) { + const struct ipv6hdr *iph = ipv6_hdr(skb); + union nf_inet_addr addr; + unsigned int i; + + memcpy(&addr.ip6, (info->flags & XT_CONNLIMIT_DADDR) ? + &iph->daddr : &iph->saddr, sizeof(addr.ip6)); + + for (i = 0; i < ARRAY_SIZE(addr.ip6); ++i) + addr.ip6[i] &= info->mask.ip6[i]; + memcpy(key, &addr, sizeof(addr.ip6)); + key[4] = zone->id; + } else { + const struct iphdr *iph = ip_hdr(skb); + + key[0] = (info->flags & XT_CONNLIMIT_DADDR) ? + (__force __u32)iph->daddr : (__force __u32)iph->saddr; + key[0] &= (__force __u32)info->mask.ip; + key[1] = zone->id; + } + + connections = nf_conncount_count(net, info->data, key, tuple_ptr, + zone); + if (connections == 0) + /* kmalloc failed, drop it entirely */ + goto hotdrop; + + return (connections > info->limit) ^ !!(info->flags & XT_CONNLIMIT_INVERT); + + hotdrop: + par->hotdrop = true; + return false; +} + +static int connlimit_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_connlimit_info *info = par->matchinfo; + unsigned int keylen; + + keylen = sizeof(u32); + if (par->family == NFPROTO_IPV6) + keylen += sizeof(struct in6_addr); + else + keylen += sizeof(struct in_addr); + + /* init private data */ + info->data = nf_conncount_init(par->net, par->family, keylen); + + return PTR_ERR_OR_ZERO(info->data); +} + +static void connlimit_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_connlimit_info *info = par->matchinfo; + + nf_conncount_destroy(par->net, par->family, info->data); +} + +static struct xt_match connlimit_mt_reg __read_mostly = { + .name = "connlimit", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = connlimit_mt_check, + .match = connlimit_mt, + .matchsize = sizeof(struct xt_connlimit_info), + .usersize = offsetof(struct xt_connlimit_info, data), + .destroy = connlimit_mt_destroy, + .me = THIS_MODULE, +}; + +static int __init connlimit_mt_init(void) +{ + return xt_register_match(&connlimit_mt_reg); +} + +static void __exit connlimit_mt_exit(void) +{ + xt_unregister_match(&connlimit_mt_reg); +} + +module_init(connlimit_mt_init); +module_exit(connlimit_mt_exit); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: Number of connections matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_connlimit"); +MODULE_ALIAS("ip6t_connlimit"); diff --git 
a/net/netfilter/xt_connmark.c b/net/netfilter/xt_connmark.c new file mode 100644 index 000000000..ad3c033db --- /dev/null +++ b/net/netfilter/xt_connmark.c @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xt_connmark - Netfilter module to operate on connection marks + * + * Copyright (C) 2002,2004 MARA Systems AB + * by Henrik Nordstrom + * Copyright © CC Computer Consultants GmbH, 2007 - 2008 + * Jan Engelhardt + */ + +#include +#include +#include +#include +#include +#include + +MODULE_AUTHOR("Henrik Nordstrom "); +MODULE_DESCRIPTION("Xtables: connection mark operations"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_CONNMARK"); +MODULE_ALIAS("ip6t_CONNMARK"); +MODULE_ALIAS("ipt_connmark"); +MODULE_ALIAS("ip6t_connmark"); + +static unsigned int +connmark_tg_shift(struct sk_buff *skb, const struct xt_connmark_tginfo2 *info) +{ + enum ip_conntrack_info ctinfo; + u_int32_t new_targetmark; + struct nf_conn *ct; + u_int32_t newmark; + u_int32_t oldmark; + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return XT_CONTINUE; + + switch (info->mode) { + case XT_CONNMARK_SET: + oldmark = READ_ONCE(ct->mark); + newmark = (oldmark & ~info->ctmask) ^ info->ctmark; + if (info->shift_dir == D_SHIFT_RIGHT) + newmark >>= info->shift_bits; + else + newmark <<= info->shift_bits; + + if (READ_ONCE(ct->mark) != newmark) { + WRITE_ONCE(ct->mark, newmark); + nf_conntrack_event_cache(IPCT_MARK, ct); + } + break; + case XT_CONNMARK_SAVE: + new_targetmark = (skb->mark & info->nfmask); + if (info->shift_dir == D_SHIFT_RIGHT) + new_targetmark >>= info->shift_bits; + else + new_targetmark <<= info->shift_bits; + + newmark = (READ_ONCE(ct->mark) & ~info->ctmask) ^ + new_targetmark; + if (READ_ONCE(ct->mark) != newmark) { + WRITE_ONCE(ct->mark, newmark); + nf_conntrack_event_cache(IPCT_MARK, ct); + } + break; + case XT_CONNMARK_RESTORE: + new_targetmark = (READ_ONCE(ct->mark) & info->ctmask); + if (info->shift_dir == D_SHIFT_RIGHT) + new_targetmark >>= info->shift_bits; + else + new_targetmark <<= info->shift_bits; + + newmark = (skb->mark & ~info->nfmask) ^ + new_targetmark; + skb->mark = newmark; + break; + } + return XT_CONTINUE; +} + +static unsigned int +connmark_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_connmark_tginfo1 *info = par->targinfo; + const struct xt_connmark_tginfo2 info2 = { + .ctmark = info->ctmark, + .ctmask = info->ctmask, + .nfmask = info->nfmask, + .mode = info->mode, + }; + + return connmark_tg_shift(skb, &info2); +} + +static unsigned int +connmark_tg_v2(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_connmark_tginfo2 *info = par->targinfo; + + return connmark_tg_shift(skb, info); +} + +static int connmark_tg_check(const struct xt_tgchk_param *par) +{ + int ret; + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void connmark_tg_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static bool +connmark_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_connmark_mtinfo1 *info = par->matchinfo; + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (ct == NULL) + return false; + + return ((READ_ONCE(ct->mark) & info->mask) == info->mark) ^ info->invert; +} + +static int connmark_mt_check(const struct xt_mtchk_param *par) +{ + int ret; + + ret = 
nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void connmark_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_target connmark_tg_reg[] __read_mostly = { + { + .name = "CONNMARK", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = connmark_tg_check, + .target = connmark_tg, + .targetsize = sizeof(struct xt_connmark_tginfo1), + .destroy = connmark_tg_destroy, + .me = THIS_MODULE, + }, + { + .name = "CONNMARK", + .revision = 2, + .family = NFPROTO_UNSPEC, + .checkentry = connmark_tg_check, + .target = connmark_tg_v2, + .targetsize = sizeof(struct xt_connmark_tginfo2), + .destroy = connmark_tg_destroy, + .me = THIS_MODULE, + } +}; + +static struct xt_match connmark_mt_reg __read_mostly = { + .name = "connmark", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = connmark_mt_check, + .match = connmark_mt, + .matchsize = sizeof(struct xt_connmark_mtinfo1), + .destroy = connmark_mt_destroy, + .me = THIS_MODULE, +}; + +static int __init connmark_mt_init(void) +{ + int ret; + + ret = xt_register_targets(connmark_tg_reg, + ARRAY_SIZE(connmark_tg_reg)); + if (ret < 0) + return ret; + ret = xt_register_match(&connmark_mt_reg); + if (ret < 0) { + xt_unregister_targets(connmark_tg_reg, + ARRAY_SIZE(connmark_tg_reg)); + return ret; + } + return 0; +} + +static void __exit connmark_mt_exit(void) +{ + xt_unregister_match(&connmark_mt_reg); + xt_unregister_targets(connmark_tg_reg, ARRAY_SIZE(connmark_tg_reg)); +} + +module_init(connmark_mt_init); +module_exit(connmark_mt_exit); diff --git a/net/netfilter/xt_conntrack.c b/net/netfilter/xt_conntrack.c new file mode 100644 index 000000000..ea299da24 --- /dev/null +++ b/net/netfilter/xt_conntrack.c @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_conntrack - Netfilter module to match connection tracking + * information. (Superset of Rusty's minimalistic state match.) + * + * (C) 2001 Marc Boucher (marc@mbsi.ca). 
+ * (C) 2006-2012 Patrick McHardy + * Copyright © CC Computer Consultants GmbH, 2007 - 2008 + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Marc Boucher "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: connection tracking state match"); +MODULE_ALIAS("ipt_conntrack"); +MODULE_ALIAS("ip6t_conntrack"); + +static bool +conntrack_addrcmp(const union nf_inet_addr *kaddr, + const union nf_inet_addr *uaddr, + const union nf_inet_addr *umask, unsigned int l3proto) +{ + if (l3proto == NFPROTO_IPV4) + return ((kaddr->ip ^ uaddr->ip) & umask->ip) == 0; + else if (l3proto == NFPROTO_IPV6) + return ipv6_masked_addr_cmp(&kaddr->in6, &umask->in6, + &uaddr->in6) == 0; + else + return false; +} + +static inline bool +conntrack_mt_origsrc(const struct nf_conn *ct, + const struct xt_conntrack_mtinfo2 *info, + u_int8_t family) +{ + return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.u3, + &info->origsrc_addr, &info->origsrc_mask, family); +} + +static inline bool +conntrack_mt_origdst(const struct nf_conn *ct, + const struct xt_conntrack_mtinfo2 *info, + u_int8_t family) +{ + return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.dst.u3, + &info->origdst_addr, &info->origdst_mask, family); +} + +static inline bool +conntrack_mt_replsrc(const struct nf_conn *ct, + const struct xt_conntrack_mtinfo2 *info, + u_int8_t family) +{ + return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_REPLY].tuple.src.u3, + &info->replsrc_addr, &info->replsrc_mask, family); +} + +static inline bool +conntrack_mt_repldst(const struct nf_conn *ct, + const struct xt_conntrack_mtinfo2 *info, + u_int8_t family) +{ + return conntrack_addrcmp(&ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3, + &info->repldst_addr, &info->repldst_mask, family); +} + +static inline bool +ct_proto_port_check(const struct xt_conntrack_mtinfo2 *info, + const struct nf_conn *ct) +{ + const struct nf_conntrack_tuple *tuple; + + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + if ((info->match_flags & XT_CONNTRACK_PROTO) && + (nf_ct_protonum(ct) == info->l4proto) ^ + !(info->invert_flags & XT_CONNTRACK_PROTO)) + return false; + + /* Shortcut to match all recognized protocols by using ->src.all. 
*/ + if ((info->match_flags & XT_CONNTRACK_ORIGSRC_PORT) && + (tuple->src.u.all == info->origsrc_port) ^ + !(info->invert_flags & XT_CONNTRACK_ORIGSRC_PORT)) + return false; + + if ((info->match_flags & XT_CONNTRACK_ORIGDST_PORT) && + (tuple->dst.u.all == info->origdst_port) ^ + !(info->invert_flags & XT_CONNTRACK_ORIGDST_PORT)) + return false; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + if ((info->match_flags & XT_CONNTRACK_REPLSRC_PORT) && + (tuple->src.u.all == info->replsrc_port) ^ + !(info->invert_flags & XT_CONNTRACK_REPLSRC_PORT)) + return false; + + if ((info->match_flags & XT_CONNTRACK_REPLDST_PORT) && + (tuple->dst.u.all == info->repldst_port) ^ + !(info->invert_flags & XT_CONNTRACK_REPLDST_PORT)) + return false; + + return true; +} + +static inline bool +port_match(u16 min, u16 max, u16 port, bool invert) +{ + return (port >= min && port <= max) ^ invert; +} + +static inline bool +ct_proto_port_check_v3(const struct xt_conntrack_mtinfo3 *info, + const struct nf_conn *ct) +{ + const struct nf_conntrack_tuple *tuple; + + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + if ((info->match_flags & XT_CONNTRACK_PROTO) && + (nf_ct_protonum(ct) == info->l4proto) ^ + !(info->invert_flags & XT_CONNTRACK_PROTO)) + return false; + + /* Shortcut to match all recognized protocols by using ->src.all. */ + if ((info->match_flags & XT_CONNTRACK_ORIGSRC_PORT) && + !port_match(info->origsrc_port, info->origsrc_port_high, + ntohs(tuple->src.u.all), + info->invert_flags & XT_CONNTRACK_ORIGSRC_PORT)) + return false; + + if ((info->match_flags & XT_CONNTRACK_ORIGDST_PORT) && + !port_match(info->origdst_port, info->origdst_port_high, + ntohs(tuple->dst.u.all), + info->invert_flags & XT_CONNTRACK_ORIGDST_PORT)) + return false; + + tuple = &ct->tuplehash[IP_CT_DIR_REPLY].tuple; + + if ((info->match_flags & XT_CONNTRACK_REPLSRC_PORT) && + !port_match(info->replsrc_port, info->replsrc_port_high, + ntohs(tuple->src.u.all), + info->invert_flags & XT_CONNTRACK_REPLSRC_PORT)) + return false; + + if ((info->match_flags & XT_CONNTRACK_REPLDST_PORT) && + !port_match(info->repldst_port, info->repldst_port_high, + ntohs(tuple->dst.u.all), + info->invert_flags & XT_CONNTRACK_REPLDST_PORT)) + return false; + + return true; +} + +static bool +conntrack_mt(const struct sk_buff *skb, struct xt_action_param *par, + u16 state_mask, u16 status_mask) +{ + const struct xt_conntrack_mtinfo2 *info = par->matchinfo; + enum ip_conntrack_info ctinfo; + const struct nf_conn *ct; + unsigned int statebit; + + ct = nf_ct_get(skb, &ctinfo); + + if (ct) + statebit = XT_CONNTRACK_STATE_BIT(ctinfo); + else if (ctinfo == IP_CT_UNTRACKED) + statebit = XT_CONNTRACK_STATE_UNTRACKED; + else + statebit = XT_CONNTRACK_STATE_INVALID; + + if (info->match_flags & XT_CONNTRACK_STATE) { + if (ct != NULL) { + if (test_bit(IPS_SRC_NAT_BIT, &ct->status)) + statebit |= XT_CONNTRACK_STATE_SNAT; + if (test_bit(IPS_DST_NAT_BIT, &ct->status)) + statebit |= XT_CONNTRACK_STATE_DNAT; + } + if (!!(state_mask & statebit) ^ + !(info->invert_flags & XT_CONNTRACK_STATE)) + return false; + } + + if (ct == NULL) + return info->match_flags & XT_CONNTRACK_STATE; + if ((info->match_flags & XT_CONNTRACK_DIRECTION) && + (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) ^ + !(info->invert_flags & XT_CONNTRACK_DIRECTION)) + return false; + + if (info->match_flags & XT_CONNTRACK_ORIGSRC) + if (conntrack_mt_origsrc(ct, info, xt_family(par)) ^ + !(info->invert_flags & XT_CONNTRACK_ORIGSRC)) + return false; + + if (info->match_flags & XT_CONNTRACK_ORIGDST) + if 
(conntrack_mt_origdst(ct, info, xt_family(par)) ^ + !(info->invert_flags & XT_CONNTRACK_ORIGDST)) + return false; + + if (info->match_flags & XT_CONNTRACK_REPLSRC) + if (conntrack_mt_replsrc(ct, info, xt_family(par)) ^ + !(info->invert_flags & XT_CONNTRACK_REPLSRC)) + return false; + + if (info->match_flags & XT_CONNTRACK_REPLDST) + if (conntrack_mt_repldst(ct, info, xt_family(par)) ^ + !(info->invert_flags & XT_CONNTRACK_REPLDST)) + return false; + + if (par->match->revision != 3) { + if (!ct_proto_port_check(info, ct)) + return false; + } else { + if (!ct_proto_port_check_v3(par->matchinfo, ct)) + return false; + } + + if ((info->match_flags & XT_CONNTRACK_STATUS) && + (!!(status_mask & ct->status) ^ + !(info->invert_flags & XT_CONNTRACK_STATUS))) + return false; + + if (info->match_flags & XT_CONNTRACK_EXPIRES) { + unsigned long expires = nf_ct_expires(ct) / HZ; + + if ((expires >= info->expires_min && + expires <= info->expires_max) ^ + !(info->invert_flags & XT_CONNTRACK_EXPIRES)) + return false; + } + return true; +} + +static bool +conntrack_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_conntrack_mtinfo1 *info = par->matchinfo; + + return conntrack_mt(skb, par, info->state_mask, info->status_mask); +} + +static bool +conntrack_mt_v2(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_conntrack_mtinfo2 *info = par->matchinfo; + + return conntrack_mt(skb, par, info->state_mask, info->status_mask); +} + +static bool +conntrack_mt_v3(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_conntrack_mtinfo3 *info = par->matchinfo; + + return conntrack_mt(skb, par, info->state_mask, info->status_mask); +} + +static int conntrack_mt_check(const struct xt_mtchk_param *par) +{ + int ret; + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void conntrack_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match conntrack_mt_reg[] __read_mostly = { + { + .name = "conntrack", + .revision = 1, + .family = NFPROTO_UNSPEC, + .matchsize = sizeof(struct xt_conntrack_mtinfo1), + .match = conntrack_mt_v1, + .checkentry = conntrack_mt_check, + .destroy = conntrack_mt_destroy, + .me = THIS_MODULE, + }, + { + .name = "conntrack", + .revision = 2, + .family = NFPROTO_UNSPEC, + .matchsize = sizeof(struct xt_conntrack_mtinfo2), + .match = conntrack_mt_v2, + .checkentry = conntrack_mt_check, + .destroy = conntrack_mt_destroy, + .me = THIS_MODULE, + }, + { + .name = "conntrack", + .revision = 3, + .family = NFPROTO_UNSPEC, + .matchsize = sizeof(struct xt_conntrack_mtinfo3), + .match = conntrack_mt_v3, + .checkentry = conntrack_mt_check, + .destroy = conntrack_mt_destroy, + .me = THIS_MODULE, + }, +}; + +static int __init conntrack_mt_init(void) +{ + return xt_register_matches(conntrack_mt_reg, + ARRAY_SIZE(conntrack_mt_reg)); +} + +static void __exit conntrack_mt_exit(void) +{ + xt_unregister_matches(conntrack_mt_reg, ARRAY_SIZE(conntrack_mt_reg)); +} + +module_init(conntrack_mt_init); +module_exit(conntrack_mt_exit); diff --git a/net/netfilter/xt_cpu.c b/net/netfilter/xt_cpu.c new file mode 100644 index 000000000..3bdc302a0 --- /dev/null +++ b/net/netfilter/xt_cpu.c @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match running CPU */ + +/* + * Might be used to distribute connections on several daemons, if + * 
RPS (Remote Packet Steering) is enabled or NIC is multiqueue capable, + * each RX queue IRQ affined to one CPU (1:1 mapping) + */ + +/* (C) 2010 Eric Dumazet + */ + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Eric Dumazet "); +MODULE_DESCRIPTION("Xtables: CPU match"); +MODULE_ALIAS("ipt_cpu"); +MODULE_ALIAS("ip6t_cpu"); + +static int cpu_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_cpu_info *info = par->matchinfo; + + if (info->invert & ~1) + return -EINVAL; + return 0; +} + +static bool cpu_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_cpu_info *info = par->matchinfo; + + return (info->cpu == smp_processor_id()) ^ info->invert; +} + +static struct xt_match cpu_mt_reg __read_mostly = { + .name = "cpu", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = cpu_mt_check, + .match = cpu_mt, + .matchsize = sizeof(struct xt_cpu_info), + .me = THIS_MODULE, +}; + +static int __init cpu_mt_init(void) +{ + return xt_register_match(&cpu_mt_reg); +} + +static void __exit cpu_mt_exit(void) +{ + xt_unregister_match(&cpu_mt_reg); +} + +module_init(cpu_mt_init); +module_exit(cpu_mt_exit); diff --git a/net/netfilter/xt_dccp.c b/net/netfilter/xt_dccp.c new file mode 100644 index 000000000..e5a13ecbe --- /dev/null +++ b/net/netfilter/xt_dccp.c @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * iptables module for DCCP protocol header matching + * + * (C) 2005 by Harald Welte + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: DCCP protocol packet match"); +MODULE_ALIAS("ipt_dccp"); +MODULE_ALIAS("ip6t_dccp"); + +#define DCCHECK(cond, option, flag, invflag) (!((flag) & (option)) \ + || (!!((invflag) & (option)) ^ (cond))) + +static unsigned char *dccp_optbuf; +static DEFINE_SPINLOCK(dccp_buflock); + +static inline bool +dccp_find_option(u_int8_t option, + const struct sk_buff *skb, + unsigned int protoff, + const struct dccp_hdr *dh, + bool *hotdrop) +{ + /* tcp.doff is only 4 bits, ie. max 15 * 4 bytes */ + const unsigned char *op; + unsigned int optoff = __dccp_hdr_len(dh); + unsigned int optlen = dh->dccph_doff*4 - __dccp_hdr_len(dh); + unsigned int i; + + if (dh->dccph_doff * 4 < __dccp_hdr_len(dh)) + goto invalid; + + if (!optlen) + return false; + + spin_lock_bh(&dccp_buflock); + op = skb_header_pointer(skb, protoff + optoff, optlen, dccp_optbuf); + if (op == NULL) { + /* If we don't have the whole header, drop packet. 
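+		 * skb_header_pointer() fills the shared dccp_optbuf (guarded by
+		 * dccp_buflock) when the options are not linear in the skb; a
+		 * NULL return means the option area announced by dccph_doff is
+		 * missing, so fall through to the hotdrop path.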
*/ + goto partial; + } + + for (i = 0; i < optlen; ) { + if (op[i] == option) { + spin_unlock_bh(&dccp_buflock); + return true; + } + + if (op[i] < 2) + i++; + else + i += op[i+1]?:1; + } + + spin_unlock_bh(&dccp_buflock); + return false; + +partial: + spin_unlock_bh(&dccp_buflock); +invalid: + *hotdrop = true; + return false; +} + + +static inline bool +match_types(const struct dccp_hdr *dh, u_int16_t typemask) +{ + return typemask & (1 << dh->dccph_type); +} + +static inline bool +match_option(u_int8_t option, const struct sk_buff *skb, unsigned int protoff, + const struct dccp_hdr *dh, bool *hotdrop) +{ + return dccp_find_option(option, skb, protoff, dh, hotdrop); +} + +static bool +dccp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_dccp_info *info = par->matchinfo; + const struct dccp_hdr *dh; + struct dccp_hdr _dh; + + if (par->fragoff != 0) + return false; + + dh = skb_header_pointer(skb, par->thoff, sizeof(_dh), &_dh); + if (dh == NULL) { + par->hotdrop = true; + return false; + } + + return DCCHECK(ntohs(dh->dccph_sport) >= info->spts[0] + && ntohs(dh->dccph_sport) <= info->spts[1], + XT_DCCP_SRC_PORTS, info->flags, info->invflags) + && DCCHECK(ntohs(dh->dccph_dport) >= info->dpts[0] + && ntohs(dh->dccph_dport) <= info->dpts[1], + XT_DCCP_DEST_PORTS, info->flags, info->invflags) + && DCCHECK(match_types(dh, info->typemask), + XT_DCCP_TYPE, info->flags, info->invflags) + && DCCHECK(match_option(info->option, skb, par->thoff, dh, + &par->hotdrop), + XT_DCCP_OPTION, info->flags, info->invflags); +} + +static int dccp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_dccp_info *info = par->matchinfo; + + if (info->flags & ~XT_DCCP_VALID_FLAGS) + return -EINVAL; + if (info->invflags & ~XT_DCCP_VALID_FLAGS) + return -EINVAL; + if (info->invflags & ~info->flags) + return -EINVAL; + return 0; +} + +static struct xt_match dccp_mt_reg[] __read_mostly = { + { + .name = "dccp", + .family = NFPROTO_IPV4, + .checkentry = dccp_mt_check, + .match = dccp_mt, + .matchsize = sizeof(struct xt_dccp_info), + .proto = IPPROTO_DCCP, + .me = THIS_MODULE, + }, + { + .name = "dccp", + .family = NFPROTO_IPV6, + .checkentry = dccp_mt_check, + .match = dccp_mt, + .matchsize = sizeof(struct xt_dccp_info), + .proto = IPPROTO_DCCP, + .me = THIS_MODULE, + }, +}; + +static int __init dccp_mt_init(void) +{ + int ret; + + /* doff is 8 bits, so the maximum option size is (4*256). Don't put + * this in BSS since DaveM is worried about locked TLB's for kernel + * BSS. 
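+	 * The buffer is shared by every rule using this match, which is why
+	 * dccp_find_option() serialises access with dccp_buflock.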
*/ + dccp_optbuf = kmalloc(256 * 4, GFP_KERNEL); + if (!dccp_optbuf) + return -ENOMEM; + ret = xt_register_matches(dccp_mt_reg, ARRAY_SIZE(dccp_mt_reg)); + if (ret) + goto out_kfree; + return ret; + +out_kfree: + kfree(dccp_optbuf); + return ret; +} + +static void __exit dccp_mt_exit(void) +{ + xt_unregister_matches(dccp_mt_reg, ARRAY_SIZE(dccp_mt_reg)); + kfree(dccp_optbuf); +} + +module_init(dccp_mt_init); +module_exit(dccp_mt_exit); diff --git a/net/netfilter/xt_devgroup.c b/net/netfilter/xt_devgroup.c new file mode 100644 index 000000000..9520dd000 --- /dev/null +++ b/net/netfilter/xt_devgroup.c @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2011 Patrick McHardy + */ + +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: Device group match"); +MODULE_ALIAS("ipt_devgroup"); +MODULE_ALIAS("ip6t_devgroup"); + +static bool devgroup_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_devgroup_info *info = par->matchinfo; + + if (info->flags & XT_DEVGROUP_MATCH_SRC && + (((info->src_group ^ xt_in(par)->group) & info->src_mask ? 1 : 0) ^ + ((info->flags & XT_DEVGROUP_INVERT_SRC) ? 1 : 0))) + return false; + + if (info->flags & XT_DEVGROUP_MATCH_DST && + (((info->dst_group ^ xt_out(par)->group) & info->dst_mask ? 1 : 0) ^ + ((info->flags & XT_DEVGROUP_INVERT_DST) ? 1 : 0))) + return false; + + return true; +} + +static int devgroup_mt_checkentry(const struct xt_mtchk_param *par) +{ + const struct xt_devgroup_info *info = par->matchinfo; + + if (info->flags & ~(XT_DEVGROUP_MATCH_SRC | XT_DEVGROUP_INVERT_SRC | + XT_DEVGROUP_MATCH_DST | XT_DEVGROUP_INVERT_DST)) + return -EINVAL; + + if (info->flags & XT_DEVGROUP_MATCH_SRC && + par->hook_mask & ~((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD))) + return -EINVAL; + + if (info->flags & XT_DEVGROUP_MATCH_DST && + par->hook_mask & ~((1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING))) + return -EINVAL; + + return 0; +} + +static struct xt_match devgroup_mt_reg __read_mostly = { + .name = "devgroup", + .match = devgroup_mt, + .checkentry = devgroup_mt_checkentry, + .matchsize = sizeof(struct xt_devgroup_info), + .family = NFPROTO_UNSPEC, + .me = THIS_MODULE +}; + +static int __init devgroup_mt_init(void) +{ + return xt_register_match(&devgroup_mt_reg); +} + +static void __exit devgroup_mt_exit(void) +{ + xt_unregister_match(&devgroup_mt_reg); +} + +module_init(devgroup_mt_init); +module_exit(devgroup_mt_exit); diff --git a/net/netfilter/xt_dscp.c b/net/netfilter/xt_dscp.c new file mode 100644 index 000000000..fb0169a8f --- /dev/null +++ b/net/netfilter/xt_dscp.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* IP tables module for matching the value of the IPv4/IPv6 DSCP field + * + * (C) 2002 by Harald Welte + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: DSCP/TOS field match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_dscp"); +MODULE_ALIAS("ip6t_dscp"); +MODULE_ALIAS("ipt_tos"); +MODULE_ALIAS("ip6t_tos"); + +static bool +dscp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_dscp_info *info = par->matchinfo; + u_int8_t dscp = ipv4_get_dsfield(ip_hdr(skb)) >> XT_DSCP_SHIFT; + + return (dscp == info->dscp) ^ !!info->invert; +} + +static bool 
+dscp_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_dscp_info *info = par->matchinfo; + u_int8_t dscp = ipv6_get_dsfield(ipv6_hdr(skb)) >> XT_DSCP_SHIFT; + + return (dscp == info->dscp) ^ !!info->invert; +} + +static int dscp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_dscp_info *info = par->matchinfo; + + if (info->dscp > XT_DSCP_MAX) + return -EDOM; + + return 0; +} + +static bool tos_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_tos_match_info *info = par->matchinfo; + + if (xt_family(par) == NFPROTO_IPV4) + return ((ip_hdr(skb)->tos & info->tos_mask) == + info->tos_value) ^ !!info->invert; + else + return ((ipv6_get_dsfield(ipv6_hdr(skb)) & info->tos_mask) == + info->tos_value) ^ !!info->invert; +} + +static struct xt_match dscp_mt_reg[] __read_mostly = { + { + .name = "dscp", + .family = NFPROTO_IPV4, + .checkentry = dscp_mt_check, + .match = dscp_mt, + .matchsize = sizeof(struct xt_dscp_info), + .me = THIS_MODULE, + }, + { + .name = "dscp", + .family = NFPROTO_IPV6, + .checkentry = dscp_mt_check, + .match = dscp_mt6, + .matchsize = sizeof(struct xt_dscp_info), + .me = THIS_MODULE, + }, + { + .name = "tos", + .revision = 1, + .family = NFPROTO_IPV4, + .match = tos_mt, + .matchsize = sizeof(struct xt_tos_match_info), + .me = THIS_MODULE, + }, + { + .name = "tos", + .revision = 1, + .family = NFPROTO_IPV6, + .match = tos_mt, + .matchsize = sizeof(struct xt_tos_match_info), + .me = THIS_MODULE, + }, +}; + +static int __init dscp_mt_init(void) +{ + return xt_register_matches(dscp_mt_reg, ARRAY_SIZE(dscp_mt_reg)); +} + +static void __exit dscp_mt_exit(void) +{ + xt_unregister_matches(dscp_mt_reg, ARRAY_SIZE(dscp_mt_reg)); +} + +module_init(dscp_mt_init); +module_exit(dscp_mt_exit); diff --git a/net/netfilter/xt_ecn.c b/net/netfilter/xt_ecn.c new file mode 100644 index 000000000..b96e8203a --- /dev/null +++ b/net/netfilter/xt_ecn.c @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Xtables module for matching the value of the IPv4/IPv6 and TCP ECN bits + * + * (C) 2002 by Harald Welte + * (C) 2011 Patrick McHardy + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +MODULE_AUTHOR("Harald Welte "); +MODULE_DESCRIPTION("Xtables: Explicit Congestion Notification (ECN) flag match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_ecn"); +MODULE_ALIAS("ip6t_ecn"); + +static bool match_tcp(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_ecn_info *einfo = par->matchinfo; + struct tcphdr _tcph; + const struct tcphdr *th; + + /* In practice, TCP match does this, so can't fail. But let's + * be good citizens. 
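+	 * checkentry only allows ECE/CWR matching together with an explicit
+	 * "-p tcp" rule, so a short header is unexpected here; treat it as a
+	 * plain non-match instead of dropping.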
+ */ + th = skb_header_pointer(skb, par->thoff, sizeof(_tcph), &_tcph); + if (th == NULL) + return false; + + if (einfo->operation & XT_ECN_OP_MATCH_ECE) { + if (einfo->invert & XT_ECN_OP_MATCH_ECE) { + if (th->ece == 1) + return false; + } else { + if (th->ece == 0) + return false; + } + } + + if (einfo->operation & XT_ECN_OP_MATCH_CWR) { + if (einfo->invert & XT_ECN_OP_MATCH_CWR) { + if (th->cwr == 1) + return false; + } else { + if (th->cwr == 0) + return false; + } + } + + return true; +} + +static inline bool match_ip(const struct sk_buff *skb, + const struct xt_ecn_info *einfo) +{ + return ((ip_hdr(skb)->tos & XT_ECN_IP_MASK) == einfo->ip_ect) ^ + !!(einfo->invert & XT_ECN_OP_MATCH_IP); +} + +static bool ecn_mt4(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_ecn_info *info = par->matchinfo; + + if (info->operation & XT_ECN_OP_MATCH_IP && !match_ip(skb, info)) + return false; + + if (info->operation & (XT_ECN_OP_MATCH_ECE | XT_ECN_OP_MATCH_CWR) && + !match_tcp(skb, par)) + return false; + + return true; +} + +static int ecn_mt_check4(const struct xt_mtchk_param *par) +{ + const struct xt_ecn_info *info = par->matchinfo; + const struct ipt_ip *ip = par->entryinfo; + + if (info->operation & XT_ECN_OP_MATCH_MASK) + return -EINVAL; + + if (info->invert & XT_ECN_OP_MATCH_MASK) + return -EINVAL; + + if (info->operation & (XT_ECN_OP_MATCH_ECE | XT_ECN_OP_MATCH_CWR) && + (ip->proto != IPPROTO_TCP || ip->invflags & IPT_INV_PROTO)) { + pr_info_ratelimited("cannot match TCP bits for non-tcp packets\n"); + return -EINVAL; + } + + return 0; +} + +static inline bool match_ipv6(const struct sk_buff *skb, + const struct xt_ecn_info *einfo) +{ + return (((ipv6_hdr(skb)->flow_lbl[0] >> 4) & XT_ECN_IP_MASK) == + einfo->ip_ect) ^ + !!(einfo->invert & XT_ECN_OP_MATCH_IP); +} + +static bool ecn_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_ecn_info *info = par->matchinfo; + + if (info->operation & XT_ECN_OP_MATCH_IP && !match_ipv6(skb, info)) + return false; + + if (info->operation & (XT_ECN_OP_MATCH_ECE | XT_ECN_OP_MATCH_CWR) && + !match_tcp(skb, par)) + return false; + + return true; +} + +static int ecn_mt_check6(const struct xt_mtchk_param *par) +{ + const struct xt_ecn_info *info = par->matchinfo; + const struct ip6t_ip6 *ip = par->entryinfo; + + if (info->operation & XT_ECN_OP_MATCH_MASK) + return -EINVAL; + + if (info->invert & XT_ECN_OP_MATCH_MASK) + return -EINVAL; + + if (info->operation & (XT_ECN_OP_MATCH_ECE | XT_ECN_OP_MATCH_CWR) && + (ip->proto != IPPROTO_TCP || ip->invflags & IP6T_INV_PROTO)) { + pr_info_ratelimited("cannot match TCP bits for non-tcp packets\n"); + return -EINVAL; + } + + return 0; +} + +static struct xt_match ecn_mt_reg[] __read_mostly = { + { + .name = "ecn", + .family = NFPROTO_IPV4, + .match = ecn_mt4, + .matchsize = sizeof(struct xt_ecn_info), + .checkentry = ecn_mt_check4, + .me = THIS_MODULE, + }, + { + .name = "ecn", + .family = NFPROTO_IPV6, + .match = ecn_mt6, + .matchsize = sizeof(struct xt_ecn_info), + .checkentry = ecn_mt_check6, + .me = THIS_MODULE, + }, +}; + +static int __init ecn_mt_init(void) +{ + return xt_register_matches(ecn_mt_reg, ARRAY_SIZE(ecn_mt_reg)); +} + +static void __exit ecn_mt_exit(void) +{ + xt_unregister_matches(ecn_mt_reg, ARRAY_SIZE(ecn_mt_reg)); +} + +module_init(ecn_mt_init); +module_exit(ecn_mt_exit); diff --git a/net/netfilter/xt_esp.c b/net/netfilter/xt_esp.c new file mode 100644 index 000000000..2a1c0ad0f --- /dev/null +++ b/net/netfilter/xt_esp.c @@ -0,0 +1,104 @@ +// 
SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match ESP parameters. */ + +/* (C) 1999-2000 Yon Uriarte + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include + +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Yon Uriarte "); +MODULE_DESCRIPTION("Xtables: IPsec-ESP packet match"); +MODULE_ALIAS("ipt_esp"); +MODULE_ALIAS("ip6t_esp"); + +/* Returns 1 if the spi is matched by the range, 0 otherwise */ +static inline bool +spi_match(u_int32_t min, u_int32_t max, u_int32_t spi, bool invert) +{ + bool r; + pr_debug("spi_match:%c 0x%x <= 0x%x <= 0x%x\n", + invert ? '!' : ' ', min, spi, max); + r = (spi >= min && spi <= max) ^ invert; + pr_debug(" result %s\n", r ? "PASS" : "FAILED"); + return r; +} + +static bool esp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ip_esp_hdr *eh; + struct ip_esp_hdr _esp; + const struct xt_esp *espinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + eh = skb_header_pointer(skb, par->thoff, sizeof(_esp), &_esp); + if (eh == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + pr_debug("Dropping evil ESP tinygram.\n"); + par->hotdrop = true; + return false; + } + + return spi_match(espinfo->spis[0], espinfo->spis[1], ntohl(eh->spi), + !!(espinfo->invflags & XT_ESP_INV_SPI)); +} + +static int esp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_esp *espinfo = par->matchinfo; + + if (espinfo->invflags & ~XT_ESP_INV_MASK) { + pr_debug("unknown flags %X\n", espinfo->invflags); + return -EINVAL; + } + + return 0; +} + +static struct xt_match esp_mt_reg[] __read_mostly = { + { + .name = "esp", + .family = NFPROTO_IPV4, + .checkentry = esp_mt_check, + .match = esp_mt, + .matchsize = sizeof(struct xt_esp), + .proto = IPPROTO_ESP, + .me = THIS_MODULE, + }, + { + .name = "esp", + .family = NFPROTO_IPV6, + .checkentry = esp_mt_check, + .match = esp_mt, + .matchsize = sizeof(struct xt_esp), + .proto = IPPROTO_ESP, + .me = THIS_MODULE, + }, +}; + +static int __init esp_mt_init(void) +{ + return xt_register_matches(esp_mt_reg, ARRAY_SIZE(esp_mt_reg)); +} + +static void __exit esp_mt_exit(void) +{ + xt_unregister_matches(esp_mt_reg, ARRAY_SIZE(esp_mt_reg)); +} + +module_init(esp_mt_init); +module_exit(esp_mt_exit); diff --git a/net/netfilter/xt_hashlimit.c b/net/netfilter/xt_hashlimit.c new file mode 100644 index 000000000..0859b8f76 --- /dev/null +++ b/net/netfilter/xt_hashlimit.c @@ -0,0 +1,1332 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_hashlimit - Netfilter module to limit the number of packets per time + * separately for each hashbucket (sourceip/sourceport/dstip/dstport) + * + * (C) 2003-2004 by Harald Welte + * (C) 2006-2012 Patrick McHardy + * Copyright © CC Computer Consultants GmbH, 2007 - 2008 + * + * Development of this code was funded by Astaro AG, http://www.astaro.com/ + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +#include +#include +#endif + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#define XT_HASHLIMIT_ALL (XT_HASHLIMIT_HASH_DIP | XT_HASHLIMIT_HASH_DPT | \ + XT_HASHLIMIT_HASH_SIP | XT_HASHLIMIT_HASH_SPT | \ + XT_HASHLIMIT_INVERT | XT_HASHLIMIT_BYTES |\ + XT_HASHLIMIT_RATE_MATCH) + 
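+/* Illustrative userspace usage (assumed iptables syntax, not defined by
+ * this module): a per-source-address limit might look like
+ *
+ *   iptables -A INPUT -m hashlimit --hashlimit-above 50/sec \
+ *            --hashlimit-mode srcip --hashlimit-name per_src -j DROP
+ *
+ * Each distinct key (here the source IP) gets its own bucket and token
+ * budget in the named htable implemented below.
+ */
+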
+MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Harald Welte "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: per hash-bucket rate-limit match"); +MODULE_ALIAS("ipt_hashlimit"); +MODULE_ALIAS("ip6t_hashlimit"); + +struct hashlimit_net { + struct hlist_head htables; + struct proc_dir_entry *ipt_hashlimit; + struct proc_dir_entry *ip6t_hashlimit; +}; + +static unsigned int hashlimit_net_id; +static inline struct hashlimit_net *hashlimit_pernet(struct net *net) +{ + return net_generic(net, hashlimit_net_id); +} + +/* need to declare this at the top */ +static const struct seq_operations dl_seq_ops_v2; +static const struct seq_operations dl_seq_ops_v1; +static const struct seq_operations dl_seq_ops; + +/* hash table crap */ +struct dsthash_dst { + union { + struct { + __be32 src; + __be32 dst; + } ip; +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + struct { + __be32 src[4]; + __be32 dst[4]; + } ip6; +#endif + }; + __be16 src_port; + __be16 dst_port; +}; + +struct dsthash_ent { + /* static / read-only parts in the beginning */ + struct hlist_node node; + struct dsthash_dst dst; + + /* modified structure members in the end */ + spinlock_t lock; + unsigned long expires; /* precalculated expiry time */ + struct { + unsigned long prev; /* last modification */ + union { + struct { + u_int64_t credit; + u_int64_t credit_cap; + u_int64_t cost; + }; + struct { + u_int32_t interval, prev_window; + u_int64_t current_rate; + u_int64_t rate; + int64_t burst; + }; + }; + } rateinfo; + struct rcu_head rcu; +}; + +struct xt_hashlimit_htable { + struct hlist_node node; /* global list of all htables */ + refcount_t use; + u_int8_t family; + bool rnd_initialized; + + struct hashlimit_cfg3 cfg; /* config */ + + /* used internally */ + spinlock_t lock; /* lock for list_head */ + u_int32_t rnd; /* random seed for hash */ + unsigned int count; /* number entries in table */ + struct delayed_work gc_work; + + /* seq_file stuff */ + struct proc_dir_entry *pde; + const char *name; + struct net *net; + + struct hlist_head hash[]; /* hashtable itself */ +}; + +static int +cfg_copy(struct hashlimit_cfg3 *to, const void *from, int revision) +{ + if (revision == 1) { + struct hashlimit_cfg1 *cfg = (struct hashlimit_cfg1 *)from; + + to->mode = cfg->mode; + to->avg = cfg->avg; + to->burst = cfg->burst; + to->size = cfg->size; + to->max = cfg->max; + to->gc_interval = cfg->gc_interval; + to->expire = cfg->expire; + to->srcmask = cfg->srcmask; + to->dstmask = cfg->dstmask; + } else if (revision == 2) { + struct hashlimit_cfg2 *cfg = (struct hashlimit_cfg2 *)from; + + to->mode = cfg->mode; + to->avg = cfg->avg; + to->burst = cfg->burst; + to->size = cfg->size; + to->max = cfg->max; + to->gc_interval = cfg->gc_interval; + to->expire = cfg->expire; + to->srcmask = cfg->srcmask; + to->dstmask = cfg->dstmask; + } else if (revision == 3) { + memcpy(to, from, sizeof(struct hashlimit_cfg3)); + } else { + return -EINVAL; + } + + return 0; +} + +static DEFINE_MUTEX(hashlimit_mutex); /* protects htables list */ +static struct kmem_cache *hashlimit_cachep __read_mostly; + +static inline bool dst_cmp(const struct dsthash_ent *ent, + const struct dsthash_dst *b) +{ + return !memcmp(&ent->dst, b, sizeof(ent->dst)); +} + +static u_int32_t +hash_dst(const struct xt_hashlimit_htable *ht, const struct dsthash_dst *dst) +{ + u_int32_t hash = jhash2((const u32 *)dst, + sizeof(*dst)/sizeof(u32), + ht->rnd); + /* + * Instead of returning hash % ht->cfg.size (implying a divide) + * we return the high 32 bits of the (hash * ht->cfg.size) that will 
+ * give results between [0 and cfg.size-1] and same hash distribution, + * but using a multiply, less expensive than a divide + */ + return reciprocal_scale(hash, ht->cfg.size); +} + +static struct dsthash_ent * +dsthash_find(const struct xt_hashlimit_htable *ht, + const struct dsthash_dst *dst) +{ + struct dsthash_ent *ent; + u_int32_t hash = hash_dst(ht, dst); + + if (!hlist_empty(&ht->hash[hash])) { + hlist_for_each_entry_rcu(ent, &ht->hash[hash], node) + if (dst_cmp(ent, dst)) { + spin_lock(&ent->lock); + return ent; + } + } + return NULL; +} + +/* allocate dsthash_ent, initialize dst, put in htable and lock it */ +static struct dsthash_ent * +dsthash_alloc_init(struct xt_hashlimit_htable *ht, + const struct dsthash_dst *dst, bool *race) +{ + struct dsthash_ent *ent; + + spin_lock(&ht->lock); + + /* Two or more packets may race to create the same entry in the + * hashtable, double check if this packet lost race. + */ + ent = dsthash_find(ht, dst); + if (ent != NULL) { + spin_unlock(&ht->lock); + *race = true; + return ent; + } + + /* initialize hash with random val at the time we allocate + * the first hashtable entry */ + if (unlikely(!ht->rnd_initialized)) { + get_random_bytes(&ht->rnd, sizeof(ht->rnd)); + ht->rnd_initialized = true; + } + + if (ht->cfg.max && ht->count >= ht->cfg.max) { + /* FIXME: do something. question is what.. */ + net_err_ratelimited("max count of %u reached\n", ht->cfg.max); + ent = NULL; + } else + ent = kmem_cache_alloc(hashlimit_cachep, GFP_ATOMIC); + if (ent) { + memcpy(&ent->dst, dst, sizeof(ent->dst)); + spin_lock_init(&ent->lock); + + spin_lock(&ent->lock); + hlist_add_head_rcu(&ent->node, &ht->hash[hash_dst(ht, dst)]); + ht->count++; + } + spin_unlock(&ht->lock); + return ent; +} + +static void dsthash_free_rcu(struct rcu_head *head) +{ + struct dsthash_ent *ent = container_of(head, struct dsthash_ent, rcu); + + kmem_cache_free(hashlimit_cachep, ent); +} + +static inline void +dsthash_free(struct xt_hashlimit_htable *ht, struct dsthash_ent *ent) +{ + hlist_del_rcu(&ent->node); + call_rcu(&ent->rcu, dsthash_free_rcu); + ht->count--; +} +static void htable_gc(struct work_struct *work); + +static int htable_create(struct net *net, struct hashlimit_cfg3 *cfg, + const char *name, u_int8_t family, + struct xt_hashlimit_htable **out_hinfo, + int revision) +{ + struct hashlimit_net *hashlimit_net = hashlimit_pernet(net); + struct xt_hashlimit_htable *hinfo; + const struct seq_operations *ops; + unsigned int size, i; + unsigned long nr_pages = totalram_pages(); + int ret; + + if (cfg->size) { + size = cfg->size; + } else { + size = (nr_pages << PAGE_SHIFT) / 16384 / + sizeof(struct hlist_head); + if (nr_pages > 1024 * 1024 * 1024 / PAGE_SIZE) + size = 8192; + if (size < 16) + size = 16; + } + /* FIXME: don't use vmalloc() here or anywhere else -HW */ + hinfo = vmalloc(struct_size(hinfo, hash, size)); + if (hinfo == NULL) + return -ENOMEM; + *out_hinfo = hinfo; + + /* copy match config into hashtable config */ + ret = cfg_copy(&hinfo->cfg, (void *)cfg, 3); + if (ret) { + vfree(hinfo); + return ret; + } + + hinfo->cfg.size = size; + if (hinfo->cfg.max == 0) + hinfo->cfg.max = 8 * hinfo->cfg.size; + else if (hinfo->cfg.max < hinfo->cfg.size) + hinfo->cfg.max = hinfo->cfg.size; + + for (i = 0; i < hinfo->cfg.size; i++) + INIT_HLIST_HEAD(&hinfo->hash[i]); + + refcount_set(&hinfo->use, 1); + hinfo->count = 0; + hinfo->family = family; + hinfo->rnd_initialized = false; + hinfo->name = kstrdup(name, GFP_KERNEL); + if (!hinfo->name) { + vfree(hinfo); + return -ENOMEM; 
+ } + spin_lock_init(&hinfo->lock); + + switch (revision) { + case 1: + ops = &dl_seq_ops_v1; + break; + case 2: + ops = &dl_seq_ops_v2; + break; + default: + ops = &dl_seq_ops; + } + + hinfo->pde = proc_create_seq_data(name, 0, + (family == NFPROTO_IPV4) ? + hashlimit_net->ipt_hashlimit : hashlimit_net->ip6t_hashlimit, + ops, hinfo); + if (hinfo->pde == NULL) { + kfree(hinfo->name); + vfree(hinfo); + return -ENOMEM; + } + hinfo->net = net; + + INIT_DEFERRABLE_WORK(&hinfo->gc_work, htable_gc); + queue_delayed_work(system_power_efficient_wq, &hinfo->gc_work, + msecs_to_jiffies(hinfo->cfg.gc_interval)); + + hlist_add_head(&hinfo->node, &hashlimit_net->htables); + + return 0; +} + +static void htable_selective_cleanup(struct xt_hashlimit_htable *ht, bool select_all) +{ + unsigned int i; + + for (i = 0; i < ht->cfg.size; i++) { + struct dsthash_ent *dh; + struct hlist_node *n; + + spin_lock_bh(&ht->lock); + hlist_for_each_entry_safe(dh, n, &ht->hash[i], node) { + if (time_after_eq(jiffies, dh->expires) || select_all) + dsthash_free(ht, dh); + } + spin_unlock_bh(&ht->lock); + cond_resched(); + } +} + +static void htable_gc(struct work_struct *work) +{ + struct xt_hashlimit_htable *ht; + + ht = container_of(work, struct xt_hashlimit_htable, gc_work.work); + + htable_selective_cleanup(ht, false); + + queue_delayed_work(system_power_efficient_wq, + &ht->gc_work, msecs_to_jiffies(ht->cfg.gc_interval)); +} + +static void htable_remove_proc_entry(struct xt_hashlimit_htable *hinfo) +{ + struct hashlimit_net *hashlimit_net = hashlimit_pernet(hinfo->net); + struct proc_dir_entry *parent; + + if (hinfo->family == NFPROTO_IPV4) + parent = hashlimit_net->ipt_hashlimit; + else + parent = hashlimit_net->ip6t_hashlimit; + + if (parent != NULL) + remove_proc_entry(hinfo->name, parent); +} + +static struct xt_hashlimit_htable *htable_find_get(struct net *net, + const char *name, + u_int8_t family) +{ + struct hashlimit_net *hashlimit_net = hashlimit_pernet(net); + struct xt_hashlimit_htable *hinfo; + + hlist_for_each_entry(hinfo, &hashlimit_net->htables, node) { + if (!strcmp(name, hinfo->name) && + hinfo->family == family) { + refcount_inc(&hinfo->use); + return hinfo; + } + } + return NULL; +} + +static void htable_put(struct xt_hashlimit_htable *hinfo) +{ + if (refcount_dec_and_mutex_lock(&hinfo->use, &hashlimit_mutex)) { + hlist_del(&hinfo->node); + htable_remove_proc_entry(hinfo); + mutex_unlock(&hashlimit_mutex); + + cancel_delayed_work_sync(&hinfo->gc_work); + htable_selective_cleanup(hinfo, true); + kfree(hinfo->name); + vfree(hinfo); + } +} + +/* The algorithm used is the Simple Token Bucket Filter (TBF) + * see net/sched/sch_tbf.c in the linux source tree + */ + +/* Rusty: This is my (non-mathematically-inclined) understanding of + this algorithm. The `average rate' in jiffies becomes your initial + amount of credit `credit' and the most credit you can ever have + `credit_cap'. The `peak rate' becomes the cost of passing the + test, `cost'. + + `prev' tracks the last packet hit: you gain one credit per jiffy. + If you get credit balance more than this, the extra credit is + discarded. Every time the match passes, you lose `cost' credits; + if you don't have that many, the test fails. + + See Alexey's formal explanation in net/sched/sch_tbf.c. + + To get the maximum range, we multiply by this factor (ie. you get N + credits per jiffy). We want to allow a rate as low as 1 per day + (slowest userspace tool allows), which means + CREDITS_PER_JIFFY*HZ*60*60*24 < 2^32 ie. 
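+
+   As an illustration, with HZ=1000 that bound works out to
+   0xFFFFFFFF / (1000*60*60*24) = 49, so MAX_CPJ_v1 is 49 and
+   CREDITS_PER_JIFFY_v1 is the power of two below it, 32.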
+*/ +#define MAX_CPJ_v1 (0xFFFFFFFF / (HZ*60*60*24)) +#define MAX_CPJ (0xFFFFFFFFFFFFFFFFULL / (HZ*60*60*24)) + +/* Repeated shift and or gives us all 1s, final shift and add 1 gives + * us the power of 2 below the theoretical max, so GCC simply does a + * shift. */ +#define _POW2_BELOW2(x) ((x)|((x)>>1)) +#define _POW2_BELOW4(x) (_POW2_BELOW2(x)|_POW2_BELOW2((x)>>2)) +#define _POW2_BELOW8(x) (_POW2_BELOW4(x)|_POW2_BELOW4((x)>>4)) +#define _POW2_BELOW16(x) (_POW2_BELOW8(x)|_POW2_BELOW8((x)>>8)) +#define _POW2_BELOW32(x) (_POW2_BELOW16(x)|_POW2_BELOW16((x)>>16)) +#define _POW2_BELOW64(x) (_POW2_BELOW32(x)|_POW2_BELOW32((x)>>32)) +#define POW2_BELOW32(x) ((_POW2_BELOW32(x)>>1) + 1) +#define POW2_BELOW64(x) ((_POW2_BELOW64(x)>>1) + 1) + +#define CREDITS_PER_JIFFY POW2_BELOW64(MAX_CPJ) +#define CREDITS_PER_JIFFY_v1 POW2_BELOW32(MAX_CPJ_v1) + +/* in byte mode, the lowest possible rate is one packet/second. + * credit_cap is used as a counter that tells us how many times we can + * refill the "credits available" counter when it becomes empty. + */ +#define MAX_CPJ_BYTES (0xFFFFFFFF / HZ) +#define CREDITS_PER_JIFFY_BYTES POW2_BELOW32(MAX_CPJ_BYTES) + +static u32 xt_hashlimit_len_to_chunks(u32 len) +{ + return (len >> XT_HASHLIMIT_BYTE_SHIFT) + 1; +} + +/* Precision saver. */ +static u64 user2credits(u64 user, int revision) +{ + u64 scale = (revision == 1) ? + XT_HASHLIMIT_SCALE : XT_HASHLIMIT_SCALE_v2; + u64 cpj = (revision == 1) ? + CREDITS_PER_JIFFY_v1 : CREDITS_PER_JIFFY; + + /* Avoid overflow: divide the constant operands first */ + if (scale >= HZ * cpj) + return div64_u64(user, div64_u64(scale, HZ * cpj)); + + return user * div64_u64(HZ * cpj, scale); +} + +static u32 user2credits_byte(u32 user) +{ + u64 us = user; + us *= HZ * CREDITS_PER_JIFFY_BYTES; + return (u32) (us >> 32); +} + +static u64 user2rate(u64 user) +{ + if (user != 0) { + return div64_u64(XT_HASHLIMIT_SCALE_v2, user); + } else { + pr_info_ratelimited("invalid rate from userspace: %llu\n", + user); + return 0; + } +} + +static u64 user2rate_bytes(u32 user) +{ + u64 r; + + r = user ? U32_MAX / user : U32_MAX; + return (r - 1) << XT_HASHLIMIT_BYTE_SHIFT; +} + +static void rateinfo_recalc(struct dsthash_ent *dh, unsigned long now, + u32 mode, int revision) +{ + unsigned long delta = now - dh->rateinfo.prev; + u64 cap, cpj; + + if (delta == 0) + return; + + if (revision >= 3 && mode & XT_HASHLIMIT_RATE_MATCH) { + u64 interval = dh->rateinfo.interval * HZ; + + if (delta < interval) + return; + + dh->rateinfo.prev = now; + dh->rateinfo.prev_window = + ((dh->rateinfo.current_rate * interval) > + (delta * dh->rateinfo.rate)); + dh->rateinfo.current_rate = 0; + + return; + } + + dh->rateinfo.prev = now; + + if (mode & XT_HASHLIMIT_BYTES) { + u64 tmp = dh->rateinfo.credit; + dh->rateinfo.credit += CREDITS_PER_JIFFY_BYTES * delta; + cap = CREDITS_PER_JIFFY_BYTES * HZ; + if (tmp >= dh->rateinfo.credit) {/* overflow */ + dh->rateinfo.credit = cap; + return; + } + } else { + cpj = (revision == 1) ? 
+ CREDITS_PER_JIFFY_v1 : CREDITS_PER_JIFFY; + dh->rateinfo.credit += delta * cpj; + cap = dh->rateinfo.credit_cap; + } + if (dh->rateinfo.credit > cap) + dh->rateinfo.credit = cap; +} + +static void rateinfo_init(struct dsthash_ent *dh, + struct xt_hashlimit_htable *hinfo, int revision) +{ + dh->rateinfo.prev = jiffies; + if (revision >= 3 && hinfo->cfg.mode & XT_HASHLIMIT_RATE_MATCH) { + dh->rateinfo.prev_window = 0; + dh->rateinfo.current_rate = 0; + if (hinfo->cfg.mode & XT_HASHLIMIT_BYTES) { + dh->rateinfo.rate = + user2rate_bytes((u32)hinfo->cfg.avg); + if (hinfo->cfg.burst) + dh->rateinfo.burst = + hinfo->cfg.burst * dh->rateinfo.rate; + else + dh->rateinfo.burst = dh->rateinfo.rate; + } else { + dh->rateinfo.rate = user2rate(hinfo->cfg.avg); + dh->rateinfo.burst = + hinfo->cfg.burst + dh->rateinfo.rate; + } + dh->rateinfo.interval = hinfo->cfg.interval; + } else if (hinfo->cfg.mode & XT_HASHLIMIT_BYTES) { + dh->rateinfo.credit = CREDITS_PER_JIFFY_BYTES * HZ; + dh->rateinfo.cost = user2credits_byte(hinfo->cfg.avg); + dh->rateinfo.credit_cap = hinfo->cfg.burst; + } else { + dh->rateinfo.credit = user2credits(hinfo->cfg.avg * + hinfo->cfg.burst, revision); + dh->rateinfo.cost = user2credits(hinfo->cfg.avg, revision); + dh->rateinfo.credit_cap = dh->rateinfo.credit; + } +} + +static inline __be32 maskl(__be32 a, unsigned int l) +{ + return l ? htonl(ntohl(a) & ~0 << (32 - l)) : 0; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static void hashlimit_ipv6_mask(__be32 *i, unsigned int p) +{ + switch (p) { + case 0 ... 31: + i[0] = maskl(i[0], p); + i[1] = i[2] = i[3] = 0; + break; + case 32 ... 63: + i[1] = maskl(i[1], p - 32); + i[2] = i[3] = 0; + break; + case 64 ... 95: + i[2] = maskl(i[2], p - 64); + i[3] = 0; + break; + case 96 ... 127: + i[3] = maskl(i[3], p - 96); + break; + case 128: + break; + } +} +#endif + +static int +hashlimit_init_dst(const struct xt_hashlimit_htable *hinfo, + struct dsthash_dst *dst, + const struct sk_buff *skb, unsigned int protoff) +{ + __be16 _ports[2], *ports; + u8 nexthdr; + int poff; + + memset(dst, 0, sizeof(*dst)); + + switch (hinfo->family) { + case NFPROTO_IPV4: + if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_DIP) + dst->ip.dst = maskl(ip_hdr(skb)->daddr, + hinfo->cfg.dstmask); + if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_SIP) + dst->ip.src = maskl(ip_hdr(skb)->saddr, + hinfo->cfg.srcmask); + + if (!(hinfo->cfg.mode & + (XT_HASHLIMIT_HASH_DPT | XT_HASHLIMIT_HASH_SPT))) + return 0; + nexthdr = ip_hdr(skb)->protocol; + break; +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + case NFPROTO_IPV6: + { + __be16 frag_off; + + if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_DIP) { + memcpy(&dst->ip6.dst, &ipv6_hdr(skb)->daddr, + sizeof(dst->ip6.dst)); + hashlimit_ipv6_mask(dst->ip6.dst, hinfo->cfg.dstmask); + } + if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_SIP) { + memcpy(&dst->ip6.src, &ipv6_hdr(skb)->saddr, + sizeof(dst->ip6.src)); + hashlimit_ipv6_mask(dst->ip6.src, hinfo->cfg.srcmask); + } + + if (!(hinfo->cfg.mode & + (XT_HASHLIMIT_HASH_DPT | XT_HASHLIMIT_HASH_SPT))) + return 0; + nexthdr = ipv6_hdr(skb)->nexthdr; + protoff = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, &frag_off); + if ((int)protoff < 0) + return -1; + break; + } +#endif + default: + BUG(); + return 0; + } + + poff = proto_ports_offset(nexthdr); + if (poff >= 0) { + ports = skb_header_pointer(skb, protoff + poff, sizeof(_ports), + &_ports); + } else { + _ports[0] = _ports[1] = 0; + ports = _ports; + } + if (!ports) + return -1; + if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_SPT) + dst->src_port = ports[0]; + 
if (hinfo->cfg.mode & XT_HASHLIMIT_HASH_DPT) + dst->dst_port = ports[1]; + return 0; +} + +static u32 hashlimit_byte_cost(unsigned int len, struct dsthash_ent *dh) +{ + u64 tmp = xt_hashlimit_len_to_chunks(len); + tmp = tmp * dh->rateinfo.cost; + + if (unlikely(tmp > CREDITS_PER_JIFFY_BYTES * HZ)) + tmp = CREDITS_PER_JIFFY_BYTES * HZ; + + if (dh->rateinfo.credit < tmp && dh->rateinfo.credit_cap) { + dh->rateinfo.credit_cap--; + dh->rateinfo.credit = CREDITS_PER_JIFFY_BYTES * HZ; + } + return (u32) tmp; +} + +static bool +hashlimit_mt_common(const struct sk_buff *skb, struct xt_action_param *par, + struct xt_hashlimit_htable *hinfo, + const struct hashlimit_cfg3 *cfg, int revision) +{ + unsigned long now = jiffies; + struct dsthash_ent *dh; + struct dsthash_dst dst; + bool race = false; + u64 cost; + + if (hashlimit_init_dst(hinfo, &dst, skb, par->thoff) < 0) + goto hotdrop; + + local_bh_disable(); + dh = dsthash_find(hinfo, &dst); + if (dh == NULL) { + dh = dsthash_alloc_init(hinfo, &dst, &race); + if (dh == NULL) { + local_bh_enable(); + goto hotdrop; + } else if (race) { + /* Already got an entry, update expiration timeout */ + dh->expires = now + msecs_to_jiffies(hinfo->cfg.expire); + rateinfo_recalc(dh, now, hinfo->cfg.mode, revision); + } else { + dh->expires = jiffies + msecs_to_jiffies(hinfo->cfg.expire); + rateinfo_init(dh, hinfo, revision); + } + } else { + /* update expiration timeout */ + dh->expires = now + msecs_to_jiffies(hinfo->cfg.expire); + rateinfo_recalc(dh, now, hinfo->cfg.mode, revision); + } + + if (cfg->mode & XT_HASHLIMIT_RATE_MATCH) { + cost = (cfg->mode & XT_HASHLIMIT_BYTES) ? skb->len : 1; + dh->rateinfo.current_rate += cost; + + if (!dh->rateinfo.prev_window && + (dh->rateinfo.current_rate <= dh->rateinfo.burst)) { + spin_unlock(&dh->lock); + local_bh_enable(); + return !(cfg->mode & XT_HASHLIMIT_INVERT); + } else { + goto overlimit; + } + } + + if (cfg->mode & XT_HASHLIMIT_BYTES) + cost = hashlimit_byte_cost(skb->len, dh); + else + cost = dh->rateinfo.cost; + + if (dh->rateinfo.credit >= cost) { + /* below the limit */ + dh->rateinfo.credit -= cost; + spin_unlock(&dh->lock); + local_bh_enable(); + return !(cfg->mode & XT_HASHLIMIT_INVERT); + } + +overlimit: + spin_unlock(&dh->lock); + local_bh_enable(); + /* default match is underlimit - so over the limit, we need to invert */ + return cfg->mode & XT_HASHLIMIT_INVERT; + + hotdrop: + par->hotdrop = true; + return false; +} + +static bool +hashlimit_mt_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_hashlimit_mtinfo1 *info = par->matchinfo; + struct xt_hashlimit_htable *hinfo = info->hinfo; + struct hashlimit_cfg3 cfg = {}; + int ret; + + ret = cfg_copy(&cfg, (void *)&info->cfg, 1); + if (ret) + return ret; + + return hashlimit_mt_common(skb, par, hinfo, &cfg, 1); +} + +static bool +hashlimit_mt_v2(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_hashlimit_mtinfo2 *info = par->matchinfo; + struct xt_hashlimit_htable *hinfo = info->hinfo; + struct hashlimit_cfg3 cfg = {}; + int ret; + + ret = cfg_copy(&cfg, (void *)&info->cfg, 2); + if (ret) + return ret; + + return hashlimit_mt_common(skb, par, hinfo, &cfg, 2); +} + +static bool +hashlimit_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_hashlimit_mtinfo3 *info = par->matchinfo; + struct xt_hashlimit_htable *hinfo = info->hinfo; + + return hashlimit_mt_common(skb, par, hinfo, &info->cfg, 3); +} + +#define HASHLIMIT_MAX_SIZE 1048576 + +static int 
hashlimit_mt_check_common(const struct xt_mtchk_param *par, + struct xt_hashlimit_htable **hinfo, + struct hashlimit_cfg3 *cfg, + const char *name, int revision) +{ + struct net *net = par->net; + int ret; + + if (cfg->gc_interval == 0 || cfg->expire == 0) + return -EINVAL; + if (cfg->size > HASHLIMIT_MAX_SIZE) { + cfg->size = HASHLIMIT_MAX_SIZE; + pr_info_ratelimited("size too large, truncated to %u\n", cfg->size); + } + if (cfg->max > HASHLIMIT_MAX_SIZE) { + cfg->max = HASHLIMIT_MAX_SIZE; + pr_info_ratelimited("max too large, truncated to %u\n", cfg->max); + } + if (par->family == NFPROTO_IPV4) { + if (cfg->srcmask > 32 || cfg->dstmask > 32) + return -EINVAL; + } else { + if (cfg->srcmask > 128 || cfg->dstmask > 128) + return -EINVAL; + } + + if (cfg->mode & ~XT_HASHLIMIT_ALL) { + pr_info_ratelimited("Unknown mode mask %X, kernel too old?\n", + cfg->mode); + return -EINVAL; + } + + /* Check for overflow. */ + if (revision >= 3 && cfg->mode & XT_HASHLIMIT_RATE_MATCH) { + if (cfg->avg == 0 || cfg->avg > U32_MAX) { + pr_info_ratelimited("invalid rate\n"); + return -ERANGE; + } + + if (cfg->interval == 0) { + pr_info_ratelimited("invalid interval\n"); + return -EINVAL; + } + } else if (cfg->mode & XT_HASHLIMIT_BYTES) { + if (user2credits_byte(cfg->avg) == 0) { + pr_info_ratelimited("overflow, rate too high: %llu\n", + cfg->avg); + return -EINVAL; + } + } else if (cfg->burst == 0 || + user2credits(cfg->avg * cfg->burst, revision) < + user2credits(cfg->avg, revision)) { + pr_info_ratelimited("overflow, try lower: %llu/%llu\n", + cfg->avg, cfg->burst); + return -ERANGE; + } + + mutex_lock(&hashlimit_mutex); + *hinfo = htable_find_get(net, name, par->family); + if (*hinfo == NULL) { + ret = htable_create(net, cfg, name, par->family, + hinfo, revision); + if (ret < 0) { + mutex_unlock(&hashlimit_mutex); + return ret; + } + } + mutex_unlock(&hashlimit_mutex); + + return 0; +} + +static int hashlimit_mt_check_v1(const struct xt_mtchk_param *par) +{ + struct xt_hashlimit_mtinfo1 *info = par->matchinfo; + struct hashlimit_cfg3 cfg = {}; + int ret; + + ret = xt_check_proc_name(info->name, sizeof(info->name)); + if (ret) + return ret; + + ret = cfg_copy(&cfg, (void *)&info->cfg, 1); + if (ret) + return ret; + + return hashlimit_mt_check_common(par, &info->hinfo, + &cfg, info->name, 1); +} + +static int hashlimit_mt_check_v2(const struct xt_mtchk_param *par) +{ + struct xt_hashlimit_mtinfo2 *info = par->matchinfo; + struct hashlimit_cfg3 cfg = {}; + int ret; + + ret = xt_check_proc_name(info->name, sizeof(info->name)); + if (ret) + return ret; + + ret = cfg_copy(&cfg, (void *)&info->cfg, 2); + if (ret) + return ret; + + return hashlimit_mt_check_common(par, &info->hinfo, + &cfg, info->name, 2); +} + +static int hashlimit_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_hashlimit_mtinfo3 *info = par->matchinfo; + int ret; + + ret = xt_check_proc_name(info->name, sizeof(info->name)); + if (ret) + return ret; + + return hashlimit_mt_check_common(par, &info->hinfo, &info->cfg, + info->name, 3); +} + +static void hashlimit_mt_destroy_v2(const struct xt_mtdtor_param *par) +{ + const struct xt_hashlimit_mtinfo2 *info = par->matchinfo; + + htable_put(info->hinfo); +} + +static void hashlimit_mt_destroy_v1(const struct xt_mtdtor_param *par) +{ + const struct xt_hashlimit_mtinfo1 *info = par->matchinfo; + + htable_put(info->hinfo); +} + +static void hashlimit_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_hashlimit_mtinfo3 *info = par->matchinfo; + + htable_put(info->hinfo); +} + 
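+/* Three revisions are registered per family, one for each matchinfo
+ * layout; .usersize stops at the kernel-private hinfo pointer so that
+ * field is never copied to or from userspace.
+ */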
+static struct xt_match hashlimit_mt_reg[] __read_mostly = { + { + .name = "hashlimit", + .revision = 1, + .family = NFPROTO_IPV4, + .match = hashlimit_mt_v1, + .matchsize = sizeof(struct xt_hashlimit_mtinfo1), + .usersize = offsetof(struct xt_hashlimit_mtinfo1, hinfo), + .checkentry = hashlimit_mt_check_v1, + .destroy = hashlimit_mt_destroy_v1, + .me = THIS_MODULE, + }, + { + .name = "hashlimit", + .revision = 2, + .family = NFPROTO_IPV4, + .match = hashlimit_mt_v2, + .matchsize = sizeof(struct xt_hashlimit_mtinfo2), + .usersize = offsetof(struct xt_hashlimit_mtinfo2, hinfo), + .checkentry = hashlimit_mt_check_v2, + .destroy = hashlimit_mt_destroy_v2, + .me = THIS_MODULE, + }, + { + .name = "hashlimit", + .revision = 3, + .family = NFPROTO_IPV4, + .match = hashlimit_mt, + .matchsize = sizeof(struct xt_hashlimit_mtinfo3), + .usersize = offsetof(struct xt_hashlimit_mtinfo3, hinfo), + .checkentry = hashlimit_mt_check, + .destroy = hashlimit_mt_destroy, + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "hashlimit", + .revision = 1, + .family = NFPROTO_IPV6, + .match = hashlimit_mt_v1, + .matchsize = sizeof(struct xt_hashlimit_mtinfo1), + .usersize = offsetof(struct xt_hashlimit_mtinfo1, hinfo), + .checkentry = hashlimit_mt_check_v1, + .destroy = hashlimit_mt_destroy_v1, + .me = THIS_MODULE, + }, + { + .name = "hashlimit", + .revision = 2, + .family = NFPROTO_IPV6, + .match = hashlimit_mt_v2, + .matchsize = sizeof(struct xt_hashlimit_mtinfo2), + .usersize = offsetof(struct xt_hashlimit_mtinfo2, hinfo), + .checkentry = hashlimit_mt_check_v2, + .destroy = hashlimit_mt_destroy_v2, + .me = THIS_MODULE, + }, + { + .name = "hashlimit", + .revision = 3, + .family = NFPROTO_IPV6, + .match = hashlimit_mt, + .matchsize = sizeof(struct xt_hashlimit_mtinfo3), + .usersize = offsetof(struct xt_hashlimit_mtinfo3, hinfo), + .checkentry = hashlimit_mt_check, + .destroy = hashlimit_mt_destroy, + .me = THIS_MODULE, + }, +#endif +}; + +/* PROC stuff */ +static void *dl_seq_start(struct seq_file *s, loff_t *pos) + __acquires(htable->lock) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket; + + spin_lock_bh(&htable->lock); + if (*pos >= htable->cfg.size) + return NULL; + + bucket = kmalloc(sizeof(unsigned int), GFP_ATOMIC); + if (!bucket) + return ERR_PTR(-ENOMEM); + + *bucket = *pos; + return bucket; +} + +static void *dl_seq_next(struct seq_file *s, void *v, loff_t *pos) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket = v; + + *pos = ++(*bucket); + if (*pos >= htable->cfg.size) { + kfree(v); + return NULL; + } + return bucket; +} + +static void dl_seq_stop(struct seq_file *s, void *v) + __releases(htable->lock) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket = v; + + if (!IS_ERR(bucket)) + kfree(bucket); + spin_unlock_bh(&htable->lock); +} + +static void dl_seq_print(struct dsthash_ent *ent, u_int8_t family, + struct seq_file *s) +{ + switch (family) { + case NFPROTO_IPV4: + seq_printf(s, "%ld %pI4:%u->%pI4:%u %llu %llu %llu\n", + (long)(ent->expires - jiffies)/HZ, + &ent->dst.ip.src, + ntohs(ent->dst.src_port), + &ent->dst.ip.dst, + ntohs(ent->dst.dst_port), + ent->rateinfo.credit, ent->rateinfo.credit_cap, + ent->rateinfo.cost); + break; +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + case NFPROTO_IPV6: + seq_printf(s, "%ld %pI6:%u->%pI6:%u %llu %llu %llu\n", + (long)(ent->expires - jiffies)/HZ, + &ent->dst.ip6.src, + ntohs(ent->dst.src_port), + 
&ent->dst.ip6.dst, + ntohs(ent->dst.dst_port), + ent->rateinfo.credit, ent->rateinfo.credit_cap, + ent->rateinfo.cost); + break; +#endif + default: + BUG(); + } +} + +static int dl_seq_real_show_v2(struct dsthash_ent *ent, u_int8_t family, + struct seq_file *s) +{ + struct xt_hashlimit_htable *ht = pde_data(file_inode(s->file)); + + spin_lock(&ent->lock); + /* recalculate to show accurate numbers */ + rateinfo_recalc(ent, jiffies, ht->cfg.mode, 2); + + dl_seq_print(ent, family, s); + + spin_unlock(&ent->lock); + return seq_has_overflowed(s); +} + +static int dl_seq_real_show_v1(struct dsthash_ent *ent, u_int8_t family, + struct seq_file *s) +{ + struct xt_hashlimit_htable *ht = pde_data(file_inode(s->file)); + + spin_lock(&ent->lock); + /* recalculate to show accurate numbers */ + rateinfo_recalc(ent, jiffies, ht->cfg.mode, 1); + + dl_seq_print(ent, family, s); + + spin_unlock(&ent->lock); + return seq_has_overflowed(s); +} + +static int dl_seq_real_show(struct dsthash_ent *ent, u_int8_t family, + struct seq_file *s) +{ + struct xt_hashlimit_htable *ht = pde_data(file_inode(s->file)); + + spin_lock(&ent->lock); + /* recalculate to show accurate numbers */ + rateinfo_recalc(ent, jiffies, ht->cfg.mode, 3); + + dl_seq_print(ent, family, s); + + spin_unlock(&ent->lock); + return seq_has_overflowed(s); +} + +static int dl_seq_show_v2(struct seq_file *s, void *v) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket = (unsigned int *)v; + struct dsthash_ent *ent; + + if (!hlist_empty(&htable->hash[*bucket])) { + hlist_for_each_entry(ent, &htable->hash[*bucket], node) + if (dl_seq_real_show_v2(ent, htable->family, s)) + return -1; + } + return 0; +} + +static int dl_seq_show_v1(struct seq_file *s, void *v) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket = v; + struct dsthash_ent *ent; + + if (!hlist_empty(&htable->hash[*bucket])) { + hlist_for_each_entry(ent, &htable->hash[*bucket], node) + if (dl_seq_real_show_v1(ent, htable->family, s)) + return -1; + } + return 0; +} + +static int dl_seq_show(struct seq_file *s, void *v) +{ + struct xt_hashlimit_htable *htable = pde_data(file_inode(s->file)); + unsigned int *bucket = v; + struct dsthash_ent *ent; + + if (!hlist_empty(&htable->hash[*bucket])) { + hlist_for_each_entry(ent, &htable->hash[*bucket], node) + if (dl_seq_real_show(ent, htable->family, s)) + return -1; + } + return 0; +} + +static const struct seq_operations dl_seq_ops_v1 = { + .start = dl_seq_start, + .next = dl_seq_next, + .stop = dl_seq_stop, + .show = dl_seq_show_v1 +}; + +static const struct seq_operations dl_seq_ops_v2 = { + .start = dl_seq_start, + .next = dl_seq_next, + .stop = dl_seq_stop, + .show = dl_seq_show_v2 +}; + +static const struct seq_operations dl_seq_ops = { + .start = dl_seq_start, + .next = dl_seq_next, + .stop = dl_seq_stop, + .show = dl_seq_show +}; + +static int __net_init hashlimit_proc_net_init(struct net *net) +{ + struct hashlimit_net *hashlimit_net = hashlimit_pernet(net); + + hashlimit_net->ipt_hashlimit = proc_mkdir("ipt_hashlimit", net->proc_net); + if (!hashlimit_net->ipt_hashlimit) + return -ENOMEM; +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + hashlimit_net->ip6t_hashlimit = proc_mkdir("ip6t_hashlimit", net->proc_net); + if (!hashlimit_net->ip6t_hashlimit) { + remove_proc_entry("ipt_hashlimit", net->proc_net); + return -ENOMEM; + } +#endif + return 0; +} + +static void __net_exit hashlimit_proc_net_exit(struct net *net) +{ + struct xt_hashlimit_htable *hinfo; + 
struct hashlimit_net *hashlimit_net = hashlimit_pernet(net); + + /* hashlimit_net_exit() is called before hashlimit_mt_destroy(). + * Make sure that the parent ipt_hashlimit and ip6t_hashlimit proc + * entries is empty before trying to remove it. + */ + mutex_lock(&hashlimit_mutex); + hlist_for_each_entry(hinfo, &hashlimit_net->htables, node) + htable_remove_proc_entry(hinfo); + hashlimit_net->ipt_hashlimit = NULL; + hashlimit_net->ip6t_hashlimit = NULL; + mutex_unlock(&hashlimit_mutex); + + remove_proc_entry("ipt_hashlimit", net->proc_net); +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + remove_proc_entry("ip6t_hashlimit", net->proc_net); +#endif +} + +static int __net_init hashlimit_net_init(struct net *net) +{ + struct hashlimit_net *hashlimit_net = hashlimit_pernet(net); + + INIT_HLIST_HEAD(&hashlimit_net->htables); + return hashlimit_proc_net_init(net); +} + +static void __net_exit hashlimit_net_exit(struct net *net) +{ + hashlimit_proc_net_exit(net); +} + +static struct pernet_operations hashlimit_net_ops = { + .init = hashlimit_net_init, + .exit = hashlimit_net_exit, + .id = &hashlimit_net_id, + .size = sizeof(struct hashlimit_net), +}; + +static int __init hashlimit_mt_init(void) +{ + int err; + + err = register_pernet_subsys(&hashlimit_net_ops); + if (err < 0) + return err; + err = xt_register_matches(hashlimit_mt_reg, + ARRAY_SIZE(hashlimit_mt_reg)); + if (err < 0) + goto err1; + + err = -ENOMEM; + hashlimit_cachep = kmem_cache_create("xt_hashlimit", + sizeof(struct dsthash_ent), 0, 0, + NULL); + if (!hashlimit_cachep) { + pr_warn("unable to create slab cache\n"); + goto err2; + } + return 0; + +err2: + xt_unregister_matches(hashlimit_mt_reg, ARRAY_SIZE(hashlimit_mt_reg)); +err1: + unregister_pernet_subsys(&hashlimit_net_ops); + return err; + +} + +static void __exit hashlimit_mt_exit(void) +{ + xt_unregister_matches(hashlimit_mt_reg, ARRAY_SIZE(hashlimit_mt_reg)); + unregister_pernet_subsys(&hashlimit_net_ops); + + rcu_barrier(); + kmem_cache_destroy(hashlimit_cachep); +} + +module_init(hashlimit_mt_init); +module_exit(hashlimit_mt_exit); diff --git a/net/netfilter/xt_helper.c b/net/netfilter/xt_helper.c new file mode 100644 index 000000000..a5a167f94 --- /dev/null +++ b/net/netfilter/xt_helper.c @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* iptables module to match on related connections */ +/* + * (C) 2001 Martin Josefsson + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Martin Josefsson "); +MODULE_DESCRIPTION("Xtables: Related connection matching"); +MODULE_ALIAS("ipt_helper"); +MODULE_ALIAS("ip6t_helper"); + + +static bool +helper_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_helper_info *info = par->matchinfo; + const struct nf_conn *ct; + const struct nf_conn_help *master_help; + const struct nf_conntrack_helper *helper; + enum ip_conntrack_info ctinfo; + bool ret = info->invert; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || !ct->master) + return ret; + + master_help = nfct_help(ct->master); + if (!master_help) + return ret; + + /* rcu_read_lock()ed by nf_hook_thresh */ + helper = rcu_dereference(master_help->helper); + if (!helper) + return ret; + + if (info->name[0] == '\0') + ret = !ret; + else + ret ^= !strncmp(helper->name, info->name, + strlen(helper->name)); + return ret; +} + +static int helper_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_helper_info *info = par->matchinfo; + int ret; 
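+	/* helper_mt() follows ct->master to the assigned helper, so pin
+	 * conntrack support for this family in the rule's netns first. */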
+ + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) { + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; + } + info->name[sizeof(info->name) - 1] = '\0'; + return 0; +} + +static void helper_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match helper_mt_reg __read_mostly = { + .name = "helper", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = helper_mt_check, + .match = helper_mt, + .destroy = helper_mt_destroy, + .matchsize = sizeof(struct xt_helper_info), + .me = THIS_MODULE, +}; + +static int __init helper_mt_init(void) +{ + return xt_register_match(&helper_mt_reg); +} + +static void __exit helper_mt_exit(void) +{ + xt_unregister_match(&helper_mt_reg); +} + +module_init(helper_mt_init); +module_exit(helper_mt_exit); diff --git a/net/netfilter/xt_hl.c b/net/netfilter/xt_hl.c new file mode 100644 index 000000000..c1a70f8f0 --- /dev/null +++ b/net/netfilter/xt_hl.c @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * IP tables module for matching the value of the TTL + * (C) 2000,2001 by Harald Welte + * + * Hop Limit matching module + * (C) 2001-2002 Maciej Soltysiak + */ + +#include +#include +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Maciej Soltysiak "); +MODULE_DESCRIPTION("Xtables: Hoplimit/TTL field match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_ttl"); +MODULE_ALIAS("ip6t_hl"); + +static bool ttl_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ipt_ttl_info *info = par->matchinfo; + const u8 ttl = ip_hdr(skb)->ttl; + + switch (info->mode) { + case IPT_TTL_EQ: + return ttl == info->ttl; + case IPT_TTL_NE: + return ttl != info->ttl; + case IPT_TTL_LT: + return ttl < info->ttl; + case IPT_TTL_GT: + return ttl > info->ttl; + } + + return false; +} + +static bool hl_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct ip6t_hl_info *info = par->matchinfo; + const struct ipv6hdr *ip6h = ipv6_hdr(skb); + + switch (info->mode) { + case IP6T_HL_EQ: + return ip6h->hop_limit == info->hop_limit; + case IP6T_HL_NE: + return ip6h->hop_limit != info->hop_limit; + case IP6T_HL_LT: + return ip6h->hop_limit < info->hop_limit; + case IP6T_HL_GT: + return ip6h->hop_limit > info->hop_limit; + } + + return false; +} + +static struct xt_match hl_mt_reg[] __read_mostly = { + { + .name = "ttl", + .revision = 0, + .family = NFPROTO_IPV4, + .match = ttl_mt, + .matchsize = sizeof(struct ipt_ttl_info), + .me = THIS_MODULE, + }, + { + .name = "hl", + .revision = 0, + .family = NFPROTO_IPV6, + .match = hl_mt6, + .matchsize = sizeof(struct ip6t_hl_info), + .me = THIS_MODULE, + }, +}; + +static int __init hl_mt_init(void) +{ + return xt_register_matches(hl_mt_reg, ARRAY_SIZE(hl_mt_reg)); +} + +static void __exit hl_mt_exit(void) +{ + xt_unregister_matches(hl_mt_reg, ARRAY_SIZE(hl_mt_reg)); +} + +module_init(hl_mt_init); +module_exit(hl_mt_exit); diff --git a/net/netfilter/xt_ipcomp.c b/net/netfilter/xt_ipcomp.c new file mode 100644 index 000000000..472da639a --- /dev/null +++ b/net/netfilter/xt_ipcomp.c @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Kernel module to match IPComp parameters for IPv4 and IPv6 + * + * Copyright (C) 2013 WindRiver + * + * Author: + * Fan Du + * + * Based on: + * net/netfilter/xt_esp.c + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); 
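
spi_match(), defined just below, reduces the IPComp test to a closed range check XORed with the rule's invert flag. A minimal userspace sketch of that expression (illustrative only, not part of the patch) makes the truth table easy to verify:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Same expression as spi_match(): inside [min, max], flipped by invert. */
static bool cpi_in_range(uint32_t min, uint32_t max, uint32_t cpi, bool invert)
{
	return (cpi >= min && cpi <= max) ^ invert;
}

int main(void)
{
	printf("%d\n", cpi_in_range(100, 200, 150, false)); /* 1: in range */
	printf("%d\n", cpi_in_range(100, 200, 150, true));  /* 0: in range, inverted */
	printf("%d\n", cpi_in_range(100, 200, 300, false)); /* 0: out of range */
	printf("%d\n", cpi_in_range(100, 200, 300, true));  /* 1: out of range, inverted */
	return 0;
}
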
+MODULE_AUTHOR("Fan Du "); +MODULE_DESCRIPTION("Xtables: IPv4/6 IPsec-IPComp SPI match"); +MODULE_ALIAS("ipt_ipcomp"); +MODULE_ALIAS("ip6t_ipcomp"); + +/* Returns 1 if the spi is matched by the range, 0 otherwise */ +static inline bool +spi_match(u_int32_t min, u_int32_t max, u_int32_t spi, bool invert) +{ + bool r; + pr_debug("spi_match:%c 0x%x <= 0x%x <= 0x%x\n", + invert ? '!' : ' ', min, spi, max); + r = (spi >= min && spi <= max) ^ invert; + pr_debug(" result %s\n", r ? "PASS" : "FAILED"); + return r; +} + +static bool comp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct ip_comp_hdr _comphdr; + const struct ip_comp_hdr *chdr; + const struct xt_ipcomp *compinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + chdr = skb_header_pointer(skb, par->thoff, sizeof(_comphdr), &_comphdr); + if (chdr == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + pr_debug("Dropping evil IPComp tinygram.\n"); + par->hotdrop = true; + return false; + } + + return spi_match(compinfo->spis[0], compinfo->spis[1], + ntohs(chdr->cpi), + !!(compinfo->invflags & XT_IPCOMP_INV_SPI)); +} + +static int comp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_ipcomp *compinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + if (compinfo->invflags & ~XT_IPCOMP_INV_MASK) { + pr_info_ratelimited("unknown flags %X\n", compinfo->invflags); + return -EINVAL; + } + return 0; +} + +static struct xt_match comp_mt_reg[] __read_mostly = { + { + .name = "ipcomp", + .family = NFPROTO_IPV4, + .match = comp_mt, + .matchsize = sizeof(struct xt_ipcomp), + .proto = IPPROTO_COMP, + .checkentry = comp_mt_check, + .me = THIS_MODULE, + }, + { + .name = "ipcomp", + .family = NFPROTO_IPV6, + .match = comp_mt, + .matchsize = sizeof(struct xt_ipcomp), + .proto = IPPROTO_COMP, + .checkentry = comp_mt_check, + .me = THIS_MODULE, + }, +}; + +static int __init comp_mt_init(void) +{ + return xt_register_matches(comp_mt_reg, ARRAY_SIZE(comp_mt_reg)); +} + +static void __exit comp_mt_exit(void) +{ + xt_unregister_matches(comp_mt_reg, ARRAY_SIZE(comp_mt_reg)); +} + +module_init(comp_mt_init); +module_exit(comp_mt_exit); diff --git a/net/netfilter/xt_iprange.c b/net/netfilter/xt_iprange.c new file mode 100644 index 000000000..0c9e014e3 --- /dev/null +++ b/net/netfilter/xt_iprange.c @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_iprange - Netfilter module to match IP address ranges + * + * (C) 2003 Jozsef Kadlecsik + * (C) CC Computer Consultants GmbH, 2008 + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +static bool +iprange_mt4(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_iprange_mtinfo *info = par->matchinfo; + const struct iphdr *iph = ip_hdr(skb); + bool m; + + if (info->flags & IPRANGE_SRC) { + m = ntohl(iph->saddr) < ntohl(info->src_min.ip); + m |= ntohl(iph->saddr) > ntohl(info->src_max.ip); + m ^= !!(info->flags & IPRANGE_SRC_INV); + if (m) { + pr_debug("src IP %pI4 NOT in range %s%pI4-%pI4\n", + &iph->saddr, + (info->flags & IPRANGE_SRC_INV) ? 
"(INV) " : "", + &info->src_min.ip, + &info->src_max.ip); + return false; + } + } + if (info->flags & IPRANGE_DST) { + m = ntohl(iph->daddr) < ntohl(info->dst_min.ip); + m |= ntohl(iph->daddr) > ntohl(info->dst_max.ip); + m ^= !!(info->flags & IPRANGE_DST_INV); + if (m) { + pr_debug("dst IP %pI4 NOT in range %s%pI4-%pI4\n", + &iph->daddr, + (info->flags & IPRANGE_DST_INV) ? "(INV) " : "", + &info->dst_min.ip, + &info->dst_max.ip); + return false; + } + } + return true; +} + +static inline int +iprange_ipv6_lt(const struct in6_addr *a, const struct in6_addr *b) +{ + unsigned int i; + + for (i = 0; i < 4; ++i) { + if (a->s6_addr32[i] != b->s6_addr32[i]) + return ntohl(a->s6_addr32[i]) < ntohl(b->s6_addr32[i]); + } + + return 0; +} + +static bool +iprange_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_iprange_mtinfo *info = par->matchinfo; + const struct ipv6hdr *iph = ipv6_hdr(skb); + bool m; + + if (info->flags & IPRANGE_SRC) { + m = iprange_ipv6_lt(&iph->saddr, &info->src_min.in6); + m |= iprange_ipv6_lt(&info->src_max.in6, &iph->saddr); + m ^= !!(info->flags & IPRANGE_SRC_INV); + if (m) { + pr_debug("src IP %pI6 NOT in range %s%pI6-%pI6\n", + &iph->saddr, + (info->flags & IPRANGE_SRC_INV) ? "(INV) " : "", + &info->src_min.in6, + &info->src_max.in6); + return false; + } + } + if (info->flags & IPRANGE_DST) { + m = iprange_ipv6_lt(&iph->daddr, &info->dst_min.in6); + m |= iprange_ipv6_lt(&info->dst_max.in6, &iph->daddr); + m ^= !!(info->flags & IPRANGE_DST_INV); + if (m) { + pr_debug("dst IP %pI6 NOT in range %s%pI6-%pI6\n", + &iph->daddr, + (info->flags & IPRANGE_DST_INV) ? "(INV) " : "", + &info->dst_min.in6, + &info->dst_max.in6); + return false; + } + } + return true; +} + +static struct xt_match iprange_mt_reg[] __read_mostly = { + { + .name = "iprange", + .revision = 1, + .family = NFPROTO_IPV4, + .match = iprange_mt4, + .matchsize = sizeof(struct xt_iprange_mtinfo), + .me = THIS_MODULE, + }, + { + .name = "iprange", + .revision = 1, + .family = NFPROTO_IPV6, + .match = iprange_mt6, + .matchsize = sizeof(struct xt_iprange_mtinfo), + .me = THIS_MODULE, + }, +}; + +static int __init iprange_mt_init(void) +{ + return xt_register_matches(iprange_mt_reg, ARRAY_SIZE(iprange_mt_reg)); +} + +static void __exit iprange_mt_exit(void) +{ + xt_unregister_matches(iprange_mt_reg, ARRAY_SIZE(iprange_mt_reg)); +} + +module_init(iprange_mt_init); +module_exit(iprange_mt_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: arbitrary IPv4 range matching"); +MODULE_ALIAS("ipt_iprange"); +MODULE_ALIAS("ip6t_iprange"); diff --git a/net/netfilter/xt_ipvs.c b/net/netfilter/xt_ipvs.c new file mode 100644 index 000000000..253c71cc9 --- /dev/null +++ b/net/netfilter/xt_ipvs.c @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_ipvs - kernel module to match IPVS connection properties + * + * Author: Hannes Eder + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#ifdef CONFIG_IP_VS_IPV6 +#include +#endif +#include +#include +#include +#include +#include + +#include + +MODULE_AUTHOR("Hannes Eder "); +MODULE_DESCRIPTION("Xtables: match IPVS connection properties"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_ipvs"); +MODULE_ALIAS("ip6t_ipvs"); + +/* borrowed from xt_conntrack */ +static bool ipvs_mt_addrcmp(const union nf_inet_addr *kaddr, + const union nf_inet_addr *uaddr, + const union nf_inet_addr *umask, + unsigned int l3proto) +{ + if 
(l3proto == NFPROTO_IPV4) + return ((kaddr->ip ^ uaddr->ip) & umask->ip) == 0; +#ifdef CONFIG_IP_VS_IPV6 + else if (l3proto == NFPROTO_IPV6) + return ipv6_masked_addr_cmp(&kaddr->in6, &umask->in6, + &uaddr->in6) == 0; +#endif + else + return false; +} + +static bool +ipvs_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_ipvs_mtinfo *data = par->matchinfo; + struct netns_ipvs *ipvs = net_ipvs(xt_net(par)); + /* ipvs_mt_check ensures that family is only NFPROTO_IPV[46]. */ + const u_int8_t family = xt_family(par); + struct ip_vs_iphdr iph; + struct ip_vs_protocol *pp; + struct ip_vs_conn *cp; + bool match = true; + + if (data->bitmask == XT_IPVS_IPVS_PROPERTY) { + match = skb->ipvs_property ^ + !!(data->invert & XT_IPVS_IPVS_PROPERTY); + goto out; + } + + /* other flags than XT_IPVS_IPVS_PROPERTY are set */ + if (!skb->ipvs_property) { + match = false; + goto out; + } + + ip_vs_fill_iph_skb(family, skb, true, &iph); + + if (data->bitmask & XT_IPVS_PROTO) + if ((iph.protocol == data->l4proto) ^ + !(data->invert & XT_IPVS_PROTO)) { + match = false; + goto out; + } + + pp = ip_vs_proto_get(iph.protocol); + if (unlikely(!pp)) { + match = false; + goto out; + } + + /* + * Check if the packet belongs to an existing entry + */ + cp = pp->conn_out_get(ipvs, family, skb, &iph); + if (unlikely(cp == NULL)) { + match = false; + goto out; + } + + /* + * We found a connection, i.e. ct != 0, make sure to call + * __ip_vs_conn_put before returning. In our case jump to out_put_con. + */ + + if (data->bitmask & XT_IPVS_VPORT) + if ((cp->vport == data->vport) ^ + !(data->invert & XT_IPVS_VPORT)) { + match = false; + goto out_put_cp; + } + + if (data->bitmask & XT_IPVS_VPORTCTL) + if ((cp->control != NULL && + cp->control->vport == data->vportctl) ^ + !(data->invert & XT_IPVS_VPORTCTL)) { + match = false; + goto out_put_cp; + } + + if (data->bitmask & XT_IPVS_DIR) { + enum ip_conntrack_info ctinfo; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct == NULL) { + match = false; + goto out_put_cp; + } + + if ((ctinfo >= IP_CT_IS_REPLY) ^ + !!(data->invert & XT_IPVS_DIR)) { + match = false; + goto out_put_cp; + } + } + + if (data->bitmask & XT_IPVS_METHOD) + if (((cp->flags & IP_VS_CONN_F_FWD_MASK) == data->fwd_method) ^ + !(data->invert & XT_IPVS_METHOD)) { + match = false; + goto out_put_cp; + } + + if (data->bitmask & XT_IPVS_VADDR) { + if (ipvs_mt_addrcmp(&cp->vaddr, &data->vaddr, + &data->vmask, family) ^ + !(data->invert & XT_IPVS_VADDR)) { + match = false; + goto out_put_cp; + } + } + +out_put_cp: + __ip_vs_conn_put(cp); +out: + pr_debug("match=%d\n", match); + return match; +} + +static int ipvs_mt_check(const struct xt_mtchk_param *par) +{ + if (par->family != NFPROTO_IPV4 +#ifdef CONFIG_IP_VS_IPV6 + && par->family != NFPROTO_IPV6 +#endif + ) { + pr_info_ratelimited("protocol family %u not supported\n", + par->family); + return -EINVAL; + } + + return 0; +} + +static struct xt_match xt_ipvs_mt_reg __read_mostly = { + .name = "ipvs", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = ipvs_mt, + .checkentry = ipvs_mt_check, + .matchsize = XT_ALIGN(sizeof(struct xt_ipvs_mtinfo)), + .me = THIS_MODULE, +}; + +static int __init ipvs_mt_init(void) +{ + return xt_register_match(&xt_ipvs_mt_reg); +} + +static void __exit ipvs_mt_exit(void) +{ + xt_unregister_match(&xt_ipvs_mt_reg); +} + +module_init(ipvs_mt_init); +module_exit(ipvs_mt_exit); diff --git a/net/netfilter/xt_l2tp.c b/net/netfilter/xt_l2tp.c new file mode 100644 index 000000000..a61eb81e9 --- /dev/null +++ 
b/net/netfilter/xt_l2tp.c @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match L2TP header parameters. */ + +/* (C) 2013 James Chapman + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +/* L2TP header masks */ +#define L2TP_HDR_T_BIT 0x8000 +#define L2TP_HDR_L_BIT 0x4000 +#define L2TP_HDR_VER 0x000f + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("James Chapman "); +MODULE_DESCRIPTION("Xtables: L2TP header match"); +MODULE_ALIAS("ipt_l2tp"); +MODULE_ALIAS("ip6t_l2tp"); + +/* The L2TP fields that can be matched */ +struct l2tp_data { + u32 tid; + u32 sid; + u8 type; + u8 version; +}; + +union l2tp_val { + __be16 val16[2]; + __be32 val32; +}; + +static bool l2tp_match(const struct xt_l2tp_info *info, struct l2tp_data *data) +{ + if ((info->flags & XT_L2TP_TYPE) && (info->type != data->type)) + return false; + + if ((info->flags & XT_L2TP_VERSION) && (info->version != data->version)) + return false; + + /* Check tid only for L2TPv3 control or any L2TPv2 packets */ + if ((info->flags & XT_L2TP_TID) && + ((data->type == XT_L2TP_TYPE_CONTROL) || (data->version == 2)) && + (info->tid != data->tid)) + return false; + + /* Check sid only for L2TP data packets */ + if ((info->flags & XT_L2TP_SID) && (data->type == XT_L2TP_TYPE_DATA) && + (info->sid != data->sid)) + return false; + + return true; +} + +/* Parse L2TP header fields when UDP encapsulation is used. Handles + * L2TPv2 and L2TPv3. Note the L2TPv3 control and data packets have a + * different format. See + * RFC2661, Section 3.1, L2TPv2 Header Format + * RFC3931, Section 3.2.1, L2TPv3 Control Message Header + * RFC3931, Section 3.2.2, L2TPv3 Data Message Header + * RFC3931, Section 4.1.2.1, L2TPv3 Session Header over UDP + */ +static bool l2tp_udp_mt(const struct sk_buff *skb, struct xt_action_param *par, u16 thoff) +{ + const struct xt_l2tp_info *info = par->matchinfo; + int uhlen = sizeof(struct udphdr); + int offs = thoff + uhlen; + union l2tp_val *lh; + union l2tp_val lhbuf; + u16 flags; + struct l2tp_data data = { 0, }; + + if (par->fragoff != 0) + return false; + + /* Extract L2TP header fields. The flags in the first 16 bits + * tell us where the other fields are. + */ + lh = skb_header_pointer(skb, offs, 2, &lhbuf); + if (lh == NULL) + return false; + + flags = ntohs(lh->val16[0]); + if (flags & L2TP_HDR_T_BIT) + data.type = XT_L2TP_TYPE_CONTROL; + else + data.type = XT_L2TP_TYPE_DATA; + data.version = (u8) flags & L2TP_HDR_VER; + + /* Now extract the L2TP tid/sid. These are in different places + * for L2TPv2 (rfc2661) and L2TPv3 (rfc3931). For L2TPv2, we + * must also check to see if the length field is present, + * since this affects the offsets into the packet of the + * tid/sid fields. + */ + if (data.version == 3) { + lh = skb_header_pointer(skb, offs + 4, 4, &lhbuf); + if (lh == NULL) + return false; + if (data.type == XT_L2TP_TYPE_CONTROL) + data.tid = ntohl(lh->val32); + else + data.sid = ntohl(lh->val32); + } else if (data.version == 2) { + if (flags & L2TP_HDR_L_BIT) + offs += 2; + lh = skb_header_pointer(skb, offs + 2, 4, &lhbuf); + if (lh == NULL) + return false; + data.tid = (u32) ntohs(lh->val16[0]); + data.sid = (u32) ntohs(lh->val16[1]); + } else + return false; + + return l2tp_match(info, &data); +} + +/* Parse L2TP header fields for IP encapsulation (no UDP header). + * L2TPv3 data packets have a different form with IP encap. 
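
l2tp_udp_mt() above derives everything from the first 16 bits of the L2TP header: the T bit selects control vs data, the low nibble carries the version, and for v2 a set L bit pushes the tid/sid words two bytes further in. The standalone sketch below decodes a made-up flags word (0x4002, an L2TPv2 data header with the Length bit set) the same way; it is an illustration, not part of the patch.

#include <stdint.h>
#include <stdio.h>

#define L2TP_HDR_T_BIT 0x8000
#define L2TP_HDR_L_BIT 0x4000
#define L2TP_HDR_VER   0x000f

int main(void)
{
	/* Made-up first 16 bits of an L2TPv2 data header: T=0, L=1, Ver=2. */
	uint16_t flags = 0x4002;
	unsigned int tid_off;

	printf("type    = %s\n",
	       (flags & L2TP_HDR_T_BIT) ? "control" : "data");
	printf("version = %u\n", (unsigned int)(flags & L2TP_HDR_VER));

	/* As in l2tp_udp_mt(): the tid/sid words sit two bytes after the
	 * flags, pushed another two bytes along when the L bit is set. */
	tid_off = 2 + ((flags & L2TP_HDR_L_BIT) ? 2 : 0);
	printf("tid offset within L2TP header = %u\n", tid_off);
	return 0;
}
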
See + * RC3931, Section 4.1.1.1, L2TPv3 Session Header over IP. + * RC3931, Section 4.1.1.2, L2TPv3 Control and Data Traffic over IP. + */ +static bool l2tp_ip_mt(const struct sk_buff *skb, struct xt_action_param *par, u16 thoff) +{ + const struct xt_l2tp_info *info = par->matchinfo; + union l2tp_val *lh; + union l2tp_val lhbuf; + struct l2tp_data data = { 0, }; + + /* For IP encap, the L2TP sid is the first 32-bits. */ + lh = skb_header_pointer(skb, thoff, sizeof(lhbuf), &lhbuf); + if (lh == NULL) + return false; + if (lh->val32 == 0) { + /* Must be a control packet. The L2TP tid is further + * into the packet. + */ + data.type = XT_L2TP_TYPE_CONTROL; + lh = skb_header_pointer(skb, thoff + 8, sizeof(lhbuf), + &lhbuf); + if (lh == NULL) + return false; + data.tid = ntohl(lh->val32); + } else { + data.sid = ntohl(lh->val32); + data.type = XT_L2TP_TYPE_DATA; + } + + data.version = 3; + + return l2tp_match(info, &data); +} + +static bool l2tp_mt4(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct iphdr *iph = ip_hdr(skb); + u8 ipproto = iph->protocol; + + /* l2tp_mt_check4 already restricts the transport protocol */ + switch (ipproto) { + case IPPROTO_UDP: + return l2tp_udp_mt(skb, par, par->thoff); + case IPPROTO_L2TP: + return l2tp_ip_mt(skb, par, par->thoff); + } + + return false; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static bool l2tp_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + unsigned int thoff = 0; + unsigned short fragoff = 0; + int ipproto; + + ipproto = ipv6_find_hdr(skb, &thoff, -1, &fragoff, NULL); + if (fragoff != 0) + return false; + + /* l2tp_mt_check6 already restricts the transport protocol */ + switch (ipproto) { + case IPPROTO_UDP: + return l2tp_udp_mt(skb, par, thoff); + case IPPROTO_L2TP: + return l2tp_ip_mt(skb, par, thoff); + } + + return false; +} +#endif + +static int l2tp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_l2tp_info *info = par->matchinfo; + + /* Check for invalid flags */ + if (info->flags & ~(XT_L2TP_TID | XT_L2TP_SID | XT_L2TP_VERSION | + XT_L2TP_TYPE)) { + pr_info_ratelimited("unknown flags: %x\n", info->flags); + return -EINVAL; + } + + /* At least one of tid, sid or type=control must be specified */ + if ((!(info->flags & XT_L2TP_TID)) && + (!(info->flags & XT_L2TP_SID)) && + ((!(info->flags & XT_L2TP_TYPE)) || + (info->type != XT_L2TP_TYPE_CONTROL))) { + pr_info_ratelimited("invalid flags combination: %x\n", + info->flags); + return -EINVAL; + } + + /* If version 2 is specified, check that incompatible params + * are not supplied + */ + if (info->flags & XT_L2TP_VERSION) { + if ((info->version < 2) || (info->version > 3)) { + pr_info_ratelimited("wrong L2TP version: %u\n", + info->version); + return -EINVAL; + } + + if (info->version == 2) { + if ((info->flags & XT_L2TP_TID) && + (info->tid > 0xffff)) { + pr_info_ratelimited("v2 tid > 0xffff: %u\n", + info->tid); + return -EINVAL; + } + if ((info->flags & XT_L2TP_SID) && + (info->sid > 0xffff)) { + pr_info_ratelimited("v2 sid > 0xffff: %u\n", + info->sid); + return -EINVAL; + } + } + } + + return 0; +} + +static int l2tp_mt_check4(const struct xt_mtchk_param *par) +{ + const struct xt_l2tp_info *info = par->matchinfo; + const struct ipt_entry *e = par->entryinfo; + const struct ipt_ip *ip = &e->ip; + int ret; + + ret = l2tp_mt_check(par); + if (ret != 0) + return ret; + + if ((ip->proto != IPPROTO_UDP) && + (ip->proto != IPPROTO_L2TP)) { + pr_info_ratelimited("missing protocol rule (udp|l2tpip)\n"); + return -EINVAL; + } + + if 
((ip->proto == IPPROTO_L2TP) && + (info->version == 2)) { + pr_info_ratelimited("v2 doesn't support IP mode\n"); + return -EINVAL; + } + + return 0; +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static int l2tp_mt_check6(const struct xt_mtchk_param *par) +{ + const struct xt_l2tp_info *info = par->matchinfo; + const struct ip6t_entry *e = par->entryinfo; + const struct ip6t_ip6 *ip = &e->ipv6; + int ret; + + ret = l2tp_mt_check(par); + if (ret != 0) + return ret; + + if ((ip->proto != IPPROTO_UDP) && + (ip->proto != IPPROTO_L2TP)) { + pr_info_ratelimited("missing protocol rule (udp|l2tpip)\n"); + return -EINVAL; + } + + if ((ip->proto == IPPROTO_L2TP) && + (info->version == 2)) { + pr_info_ratelimited("v2 doesn't support IP mode\n"); + return -EINVAL; + } + + return 0; +} +#endif + +static struct xt_match l2tp_mt_reg[] __read_mostly = { + { + .name = "l2tp", + .revision = 0, + .family = NFPROTO_IPV4, + .match = l2tp_mt4, + .matchsize = XT_ALIGN(sizeof(struct xt_l2tp_info)), + .checkentry = l2tp_mt_check4, + .hooks = ((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD)), + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "l2tp", + .revision = 0, + .family = NFPROTO_IPV6, + .match = l2tp_mt6, + .matchsize = XT_ALIGN(sizeof(struct xt_l2tp_info)), + .checkentry = l2tp_mt_check6, + .hooks = ((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_FORWARD)), + .me = THIS_MODULE, + }, +#endif +}; + +static int __init l2tp_mt_init(void) +{ + return xt_register_matches(&l2tp_mt_reg[0], ARRAY_SIZE(l2tp_mt_reg)); +} + +static void __exit l2tp_mt_exit(void) +{ + xt_unregister_matches(&l2tp_mt_reg[0], ARRAY_SIZE(l2tp_mt_reg)); +} + +module_init(l2tp_mt_init); +module_exit(l2tp_mt_exit); diff --git a/net/netfilter/xt_length.c b/net/netfilter/xt_length.c new file mode 100644 index 000000000..ca730cedb --- /dev/null +++ b/net/netfilter/xt_length.c @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match packet length. 
*/ +/* (C) 1999-2001 James Morris + */ + +#include +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("James Morris "); +MODULE_DESCRIPTION("Xtables: Packet length (Layer3,4,5) match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_length"); +MODULE_ALIAS("ip6t_length"); + +static bool +length_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_length_info *info = par->matchinfo; + u32 pktlen = skb_ip_totlen(skb); + + return (pktlen >= info->min && pktlen <= info->max) ^ info->invert; +} + +static bool +length_mt6(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_length_info *info = par->matchinfo; + u32 pktlen = skb->len; + + return (pktlen >= info->min && pktlen <= info->max) ^ info->invert; +} + +static struct xt_match length_mt_reg[] __read_mostly = { + { + .name = "length", + .family = NFPROTO_IPV4, + .match = length_mt, + .matchsize = sizeof(struct xt_length_info), + .me = THIS_MODULE, + }, + { + .name = "length", + .family = NFPROTO_IPV6, + .match = length_mt6, + .matchsize = sizeof(struct xt_length_info), + .me = THIS_MODULE, + }, +}; + +static int __init length_mt_init(void) +{ + return xt_register_matches(length_mt_reg, ARRAY_SIZE(length_mt_reg)); +} + +static void __exit length_mt_exit(void) +{ + xt_unregister_matches(length_mt_reg, ARRAY_SIZE(length_mt_reg)); +} + +module_init(length_mt_init); +module_exit(length_mt_exit); diff --git a/net/netfilter/xt_limit.c b/net/netfilter/xt_limit.c new file mode 100644 index 000000000..8b4fd2785 --- /dev/null +++ b/net/netfilter/xt_limit.c @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999 Jérôme de Vivie + * (C) 1999 Hervé Eychenne + * (C) 2006-2012 Patrick McHardy + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include + +struct xt_limit_priv { + unsigned long prev; + u32 credit; +}; + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Herve Eychenne "); +MODULE_DESCRIPTION("Xtables: rate-limit match"); +MODULE_ALIAS("ipt_limit"); +MODULE_ALIAS("ip6t_limit"); + +/* The algorithm used is the Simple Token Bucket Filter (TBF) + * see net/sched/sch_tbf.c in the linux source tree + */ + +/* Rusty: This is my (non-mathematically-inclined) understanding of + this algorithm. The `average rate' in jiffies becomes your initial + amount of credit `credit' and the most credit you can ever have + `credit_cap'. The `peak rate' becomes the cost of passing the + test, `cost'. + + `prev' tracks the last packet hit: you gain one credit per jiffy. + If you get credit balance more than this, the extra credit is + discarded. Every time the match passes, you lose `cost' credits; + if you don't have that many, the test fails. + + See Alexey's formal explanation in net/sched/sch_tbf.c. + + To get the maximum range, we multiply by this factor (ie. you get N + credits per jiffy). We want to allow a rate as low as 1 per day + (slowest userspace tool allows), which means + CREDITS_PER_JIFFY*HZ*60*60*24 < 2^32. ie. */ +#define MAX_CPJ (0xFFFFFFFF / (HZ*60*60*24)) + +/* Repeated shift and or gives us all 1s, final shift and add 1 gives + * us the power of 2 below the theoretical max, so GCC simply does a + * shift. 
*/ +#define _POW2_BELOW2(x) ((x)|((x)>>1)) +#define _POW2_BELOW4(x) (_POW2_BELOW2(x)|_POW2_BELOW2((x)>>2)) +#define _POW2_BELOW8(x) (_POW2_BELOW4(x)|_POW2_BELOW4((x)>>4)) +#define _POW2_BELOW16(x) (_POW2_BELOW8(x)|_POW2_BELOW8((x)>>8)) +#define _POW2_BELOW32(x) (_POW2_BELOW16(x)|_POW2_BELOW16((x)>>16)) +#define POW2_BELOW32(x) ((_POW2_BELOW32(x)>>1) + 1) + +#define CREDITS_PER_JIFFY POW2_BELOW32(MAX_CPJ) + +static bool +limit_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_rateinfo *r = par->matchinfo; + struct xt_limit_priv *priv = r->master; + unsigned long now; + u32 old_credit, new_credit, credit_increase = 0; + bool ret; + + /* fastpath if there is nothing to update */ + if ((READ_ONCE(priv->credit) < r->cost) && (READ_ONCE(priv->prev) == jiffies)) + return false; + + do { + now = jiffies; + credit_increase += (now - xchg(&priv->prev, now)) * CREDITS_PER_JIFFY; + old_credit = READ_ONCE(priv->credit); + new_credit = old_credit; + new_credit += credit_increase; + if (new_credit > r->credit_cap) + new_credit = r->credit_cap; + if (new_credit >= r->cost) { + ret = true; + new_credit -= r->cost; + } else { + ret = false; + } + } while (cmpxchg(&priv->credit, old_credit, new_credit) != old_credit); + + return ret; +} + +/* Precision saver. */ +static u32 user2credits(u32 user) +{ + /* If multiplying would overflow... */ + if (user > 0xFFFFFFFF / (HZ*CREDITS_PER_JIFFY)) + /* Divide first. */ + return (user / XT_LIMIT_SCALE) * HZ * CREDITS_PER_JIFFY; + + return (user * HZ * CREDITS_PER_JIFFY) / XT_LIMIT_SCALE; +} + +static int limit_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_rateinfo *r = par->matchinfo; + struct xt_limit_priv *priv; + + /* Check for overflow. */ + if (r->burst == 0 + || user2credits(r->avg * r->burst) < user2credits(r->avg)) { + pr_info_ratelimited("Overflow, try lower: %u/%u\n", + r->avg, r->burst); + return -ERANGE; + } + + priv = kmalloc(sizeof(*priv), GFP_KERNEL); + if (priv == NULL) + return -ENOMEM; + + /* For SMP, we only want to use one set of state. */ + r->master = priv; + /* User avg in seconds * XT_LIMIT_SCALE: convert to jiffies * + 128. */ + priv->prev = jiffies; + priv->credit = user2credits(r->avg * r->burst); /* Credits full. */ + if (r->cost == 0) { + r->credit_cap = priv->credit; /* Credits full. */ + r->cost = user2credits(r->avg); + } + + return 0; +} + +static void limit_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_rateinfo *info = par->matchinfo; + + kfree(info->master); +} + +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT +struct compat_xt_rateinfo { + u_int32_t avg; + u_int32_t burst; + + compat_ulong_t prev; + u_int32_t credit; + u_int32_t credit_cap, cost; + + u_int32_t master; +}; + +/* To keep the full "prev" timestamp, the upper 32 bits are stored in the + * master pointer, which does not need to be preserved. 
*/ +static void limit_mt_compat_from_user(void *dst, const void *src) +{ + const struct compat_xt_rateinfo *cm = src; + struct xt_rateinfo m = { + .avg = cm->avg, + .burst = cm->burst, + .prev = cm->prev | (unsigned long)cm->master << 32, + .credit = cm->credit, + .credit_cap = cm->credit_cap, + .cost = cm->cost, + }; + memcpy(dst, &m, sizeof(m)); +} + +static int limit_mt_compat_to_user(void __user *dst, const void *src) +{ + const struct xt_rateinfo *m = src; + struct compat_xt_rateinfo cm = { + .avg = m->avg, + .burst = m->burst, + .prev = m->prev, + .credit = m->credit, + .credit_cap = m->credit_cap, + .cost = m->cost, + .master = m->prev >> 32, + }; + return copy_to_user(dst, &cm, sizeof(cm)) ? -EFAULT : 0; +} +#endif /* CONFIG_NETFILTER_XTABLES_COMPAT */ + +static struct xt_match limit_mt_reg __read_mostly = { + .name = "limit", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = limit_mt, + .checkentry = limit_mt_check, + .destroy = limit_mt_destroy, + .matchsize = sizeof(struct xt_rateinfo), +#ifdef CONFIG_NETFILTER_XTABLES_COMPAT + .compatsize = sizeof(struct compat_xt_rateinfo), + .compat_from_user = limit_mt_compat_from_user, + .compat_to_user = limit_mt_compat_to_user, +#endif + .usersize = offsetof(struct xt_rateinfo, prev), + .me = THIS_MODULE, +}; + +static int __init limit_mt_init(void) +{ + return xt_register_match(&limit_mt_reg); +} + +static void __exit limit_mt_exit(void) +{ + xt_unregister_match(&limit_mt_reg); +} + +module_init(limit_mt_init); +module_exit(limit_mt_exit); diff --git a/net/netfilter/xt_mac.c b/net/netfilter/xt_mac.c new file mode 100644 index 000000000..81649da57 --- /dev/null +++ b/net/netfilter/xt_mac.c @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match MAC address parameters. 
*/ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("Xtables: MAC address match"); +MODULE_ALIAS("ipt_mac"); +MODULE_ALIAS("ip6t_mac"); + +static bool mac_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_mac_info *info = par->matchinfo; + bool ret; + + if (skb->dev == NULL || skb->dev->type != ARPHRD_ETHER) + return false; + if (skb_mac_header(skb) < skb->head) + return false; + if (skb_mac_header(skb) + ETH_HLEN > skb->data) + return false; + ret = ether_addr_equal(eth_hdr(skb)->h_source, info->srcaddr); + ret ^= info->invert; + return ret; +} + +static struct xt_match mac_mt_reg __read_mostly = { + .name = "mac", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = mac_mt, + .matchsize = sizeof(struct xt_mac_info), + .hooks = (1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_FORWARD), + .me = THIS_MODULE, +}; + +static int __init mac_mt_init(void) +{ + return xt_register_match(&mac_mt_reg); +} + +static void __exit mac_mt_exit(void) +{ + xt_unregister_match(&mac_mt_reg); +} + +module_init(mac_mt_init); +module_exit(mac_mt_exit); diff --git a/net/netfilter/xt_mark.c b/net/netfilter/xt_mark.c new file mode 100644 index 000000000..1ad74b592 --- /dev/null +++ b/net/netfilter/xt_mark.c @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_mark - Netfilter module to match NFMARK value + * + * (C) 1999-2001 Marc Boucher + * Copyright © CC Computer Consultants GmbH, 2007 - 2008 + * Jan Engelhardt + */ + +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Marc Boucher "); +MODULE_DESCRIPTION("Xtables: packet mark operations"); +MODULE_ALIAS("ipt_mark"); +MODULE_ALIAS("ip6t_mark"); +MODULE_ALIAS("ipt_MARK"); +MODULE_ALIAS("ip6t_MARK"); +MODULE_ALIAS("arpt_MARK"); + +static unsigned int +mark_tg(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_mark_tginfo2 *info = par->targinfo; + + skb->mark = (skb->mark & ~info->mask) ^ info->mark; + return XT_CONTINUE; +} + +static bool +mark_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_mark_mtinfo1 *info = par->matchinfo; + + return ((skb->mark & info->mask) == info->mark) ^ info->invert; +} + +static struct xt_target mark_tg_reg __read_mostly = { + .name = "MARK", + .revision = 2, + .family = NFPROTO_UNSPEC, + .target = mark_tg, + .targetsize = sizeof(struct xt_mark_tginfo2), + .me = THIS_MODULE, +}; + +static struct xt_match mark_mt_reg __read_mostly = { + .name = "mark", + .revision = 1, + .family = NFPROTO_UNSPEC, + .match = mark_mt, + .matchsize = sizeof(struct xt_mark_mtinfo1), + .me = THIS_MODULE, +}; + +static int __init mark_mt_init(void) +{ + int ret; + + ret = xt_register_target(&mark_tg_reg); + if (ret < 0) + return ret; + ret = xt_register_match(&mark_mt_reg); + if (ret < 0) { + xt_unregister_target(&mark_tg_reg); + return ret; + } + return 0; +} + +static void __exit mark_mt_exit(void) +{ + xt_unregister_match(&mark_mt_reg); + xt_unregister_target(&mark_tg_reg); +} + +module_init(mark_mt_init); +module_exit(mark_mt_exit); diff --git a/net/netfilter/xt_multiport.c b/net/netfilter/xt_multiport.c new file mode 100644 index 000000000..44a00f5ac --- /dev/null +++ b/net/netfilter/xt_multiport.c @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match 
one of a list of TCP/UDP(-Lite)/SCTP/DCCP ports: + ports are in the same place so we can treat them as equal. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2004 Netfilter Core Team + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Netfilter Core Team "); +MODULE_DESCRIPTION("Xtables: multiple port matching for TCP, UDP, UDP-Lite, SCTP and DCCP"); +MODULE_ALIAS("ipt_multiport"); +MODULE_ALIAS("ip6t_multiport"); + +/* Returns 1 if the port is matched by the test, 0 otherwise. */ +static inline bool +ports_match_v1(const struct xt_multiport_v1 *minfo, + u_int16_t src, u_int16_t dst) +{ + unsigned int i; + u_int16_t s, e; + + for (i = 0; i < minfo->count; i++) { + s = minfo->ports[i]; + + if (minfo->pflags[i]) { + /* range port matching */ + e = minfo->ports[++i]; + pr_debug("src or dst matches with %d-%d?\n", s, e); + + switch (minfo->flags) { + case XT_MULTIPORT_SOURCE: + if (src >= s && src <= e) + return true ^ minfo->invert; + break; + case XT_MULTIPORT_DESTINATION: + if (dst >= s && dst <= e) + return true ^ minfo->invert; + break; + case XT_MULTIPORT_EITHER: + if ((dst >= s && dst <= e) || + (src >= s && src <= e)) + return true ^ minfo->invert; + break; + default: + break; + } + } else { + /* exact port matching */ + pr_debug("src or dst matches with %d?\n", s); + + switch (minfo->flags) { + case XT_MULTIPORT_SOURCE: + if (src == s) + return true ^ minfo->invert; + break; + case XT_MULTIPORT_DESTINATION: + if (dst == s) + return true ^ minfo->invert; + break; + case XT_MULTIPORT_EITHER: + if (src == s || dst == s) + return true ^ minfo->invert; + break; + default: + break; + } + } + } + + return minfo->invert; +} + +static bool +multiport_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const __be16 *pptr; + __be16 _ports[2]; + const struct xt_multiport_v1 *multiinfo = par->matchinfo; + + if (par->fragoff != 0) + return false; + + pptr = skb_header_pointer(skb, par->thoff, sizeof(_ports), _ports); + if (pptr == NULL) { + /* We've been asked to examine this packet, and we + * can't. Hence, no choice but to drop. + */ + pr_debug("Dropping evil offset=0 tinygram.\n"); + par->hotdrop = true; + return false; + } + + return ports_match_v1(multiinfo, ntohs(pptr[0]), ntohs(pptr[1])); +} + +static inline bool +check(u_int16_t proto, + u_int8_t ip_invflags, + u_int8_t match_flags, + u_int8_t count) +{ + /* Must specify supported protocol, no unknown flags or bad count */ + return (proto == IPPROTO_TCP || proto == IPPROTO_UDP + || proto == IPPROTO_UDPLITE + || proto == IPPROTO_SCTP || proto == IPPROTO_DCCP) + && !(ip_invflags & XT_INV_PROTO) + && (match_flags == XT_MULTIPORT_SOURCE + || match_flags == XT_MULTIPORT_DESTINATION + || match_flags == XT_MULTIPORT_EITHER) + && count <= XT_MULTI_PORTS; +} + +static int multiport_mt_check(const struct xt_mtchk_param *par) +{ + const struct ipt_ip *ip = par->entryinfo; + const struct xt_multiport_v1 *multiinfo = par->matchinfo; + + return check(ip->proto, ip->invflags, multiinfo->flags, + multiinfo->count) ? 0 : -EINVAL; +} + +static int multiport_mt6_check(const struct xt_mtchk_param *par) +{ + const struct ip6t_ip6 *ip = par->entryinfo; + const struct xt_multiport_v1 *multiinfo = par->matchinfo; + + return check(ip->proto, ip->invflags, multiinfo->flags, + multiinfo->count) ? 
0 : -EINVAL; +} + +static struct xt_match multiport_mt_reg[] __read_mostly = { + { + .name = "multiport", + .family = NFPROTO_IPV4, + .revision = 1, + .checkentry = multiport_mt_check, + .match = multiport_mt, + .matchsize = sizeof(struct xt_multiport_v1), + .me = THIS_MODULE, + }, + { + .name = "multiport", + .family = NFPROTO_IPV6, + .revision = 1, + .checkentry = multiport_mt6_check, + .match = multiport_mt, + .matchsize = sizeof(struct xt_multiport_v1), + .me = THIS_MODULE, + }, +}; + +static int __init multiport_mt_init(void) +{ + return xt_register_matches(multiport_mt_reg, + ARRAY_SIZE(multiport_mt_reg)); +} + +static void __exit multiport_mt_exit(void) +{ + xt_unregister_matches(multiport_mt_reg, ARRAY_SIZE(multiport_mt_reg)); +} + +module_init(multiport_mt_init); +module_exit(multiport_mt_exit); diff --git a/net/netfilter/xt_nat.c b/net/netfilter/xt_nat.c new file mode 100644 index 000000000..b4f7bbc3f --- /dev/null +++ b/net/netfilter/xt_nat.c @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2006 Netfilter Core Team + * (C) 2011 Patrick McHardy + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +static int xt_nat_checkentry_v0(const struct xt_tgchk_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + + if (mr->rangesize != 1) { + pr_info_ratelimited("multiple ranges no longer supported\n"); + return -EINVAL; + } + return nf_ct_netns_get(par->net, par->family); +} + +static int xt_nat_checkentry(const struct xt_tgchk_param *par) +{ + return nf_ct_netns_get(par->net, par->family); +} + +static void xt_nat_destroy(const struct xt_tgdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static void xt_nat_convert_range(struct nf_nat_range2 *dst, + const struct nf_nat_ipv4_range *src) +{ + memset(&dst->min_addr, 0, sizeof(dst->min_addr)); + memset(&dst->max_addr, 0, sizeof(dst->max_addr)); + memset(&dst->base_proto, 0, sizeof(dst->base_proto)); + + dst->flags = src->flags; + dst->min_addr.ip = src->min_ip; + dst->max_addr.ip = src->max_ip; + dst->min_proto = src->min; + dst->max_proto = src->max; +} + +static unsigned int +xt_snat_target_v0(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + struct nf_nat_range2 range; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY))); + + xt_nat_convert_range(&range, &mr->range[0]); + return nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC); +} + +static unsigned int +xt_dnat_target_v0(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_ipv4_multi_range_compat *mr = par->targinfo; + struct nf_nat_range2 range; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED))); + + xt_nat_convert_range(&range, &mr->range[0]); + return nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST); +} + +static unsigned int +xt_snat_target_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_range *range_v1 = par->targinfo; + struct nf_nat_range2 range; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + 
ctinfo == IP_CT_RELATED_REPLY))); + + memcpy(&range, range_v1, sizeof(*range_v1)); + memset(&range.base_proto, 0, sizeof(range.base_proto)); + + return nf_nat_setup_info(ct, &range, NF_NAT_MANIP_SRC); +} + +static unsigned int +xt_dnat_target_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_range *range_v1 = par->targinfo; + struct nf_nat_range2 range; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED))); + + memcpy(&range, range_v1, sizeof(*range_v1)); + memset(&range.base_proto, 0, sizeof(range.base_proto)); + + return nf_nat_setup_info(ct, &range, NF_NAT_MANIP_DST); +} + +static unsigned int +xt_snat_target_v2(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED || + ctinfo == IP_CT_RELATED_REPLY))); + + return nf_nat_setup_info(ct, range, NF_NAT_MANIP_SRC); +} + +static unsigned int +xt_dnat_target_v2(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct nf_nat_range2 *range = par->targinfo; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + WARN_ON(!(ct != NULL && + (ctinfo == IP_CT_NEW || ctinfo == IP_CT_RELATED))); + + return nf_nat_setup_info(ct, range, NF_NAT_MANIP_DST); +} + +static struct xt_target xt_nat_target_reg[] __read_mostly = { + { + .name = "SNAT", + .revision = 0, + .checkentry = xt_nat_checkentry_v0, + .destroy = xt_nat_destroy, + .target = xt_snat_target_v0, + .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat), + .family = NFPROTO_IPV4, + .table = "nat", + .hooks = (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, + { + .name = "DNAT", + .revision = 0, + .checkentry = xt_nat_checkentry_v0, + .destroy = xt_nat_destroy, + .target = xt_dnat_target_v0, + .targetsize = sizeof(struct nf_nat_ipv4_multi_range_compat), + .family = NFPROTO_IPV4, + .table = "nat", + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, + { + .name = "SNAT", + .revision = 1, + .checkentry = xt_nat_checkentry, + .destroy = xt_nat_destroy, + .target = xt_snat_target_v1, + .targetsize = sizeof(struct nf_nat_range), + .table = "nat", + .hooks = (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, + { + .name = "DNAT", + .revision = 1, + .checkentry = xt_nat_checkentry, + .destroy = xt_nat_destroy, + .target = xt_dnat_target_v1, + .targetsize = sizeof(struct nf_nat_range), + .table = "nat", + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, + { + .name = "SNAT", + .revision = 2, + .checkentry = xt_nat_checkentry, + .destroy = xt_nat_destroy, + .target = xt_snat_target_v2, + .targetsize = sizeof(struct nf_nat_range2), + .table = "nat", + .hooks = (1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, + { + .name = "DNAT", + .revision = 2, + .checkentry = xt_nat_checkentry, + .destroy = xt_nat_destroy, + .target = xt_dnat_target_v2, + .targetsize = sizeof(struct nf_nat_range2), + .table = "nat", + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_OUT), + .me = THIS_MODULE, + }, +}; + +static int __init xt_nat_init(void) +{ + return xt_register_targets(xt_nat_target_reg, + 
ARRAY_SIZE(xt_nat_target_reg)); +} + +static void __exit xt_nat_exit(void) +{ + xt_unregister_targets(xt_nat_target_reg, ARRAY_SIZE(xt_nat_target_reg)); +} + +module_init(xt_nat_init); +module_exit(xt_nat_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_ALIAS("ipt_SNAT"); +MODULE_ALIAS("ipt_DNAT"); +MODULE_ALIAS("ip6t_SNAT"); +MODULE_ALIAS("ip6t_DNAT"); +MODULE_DESCRIPTION("SNAT and DNAT targets support"); diff --git a/net/netfilter/xt_nfacct.c b/net/netfilter/xt_nfacct.c new file mode 100644 index 000000000..7c6bf1c16 --- /dev/null +++ b/net/netfilter/xt_nfacct.c @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * (C) 2011 Pablo Neira Ayuso + * (C) 2011 Intra2net AG + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("Xtables: match for the extended accounting infrastructure"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_nfacct"); +MODULE_ALIAS("ip6t_nfacct"); + +static bool nfacct_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + int overquota; + const struct xt_nfacct_match_info *info = par->targinfo; + + nfnl_acct_update(skb, info->nfacct); + + overquota = nfnl_acct_overquota(xt_net(par), info->nfacct); + + return overquota != NFACCT_UNDERQUOTA; +} + +static int +nfacct_mt_checkentry(const struct xt_mtchk_param *par) +{ + struct xt_nfacct_match_info *info = par->matchinfo; + struct nf_acct *nfacct; + + nfacct = nfnl_acct_find_get(par->net, info->name); + if (nfacct == NULL) { + pr_info_ratelimited("accounting object `%s' does not exists\n", + info->name); + return -ENOENT; + } + info->nfacct = nfacct; + return 0; +} + +static void +nfacct_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_nfacct_match_info *info = par->matchinfo; + + nfnl_acct_put(info->nfacct); +} + +static struct xt_match nfacct_mt_reg[] __read_mostly = { + { + .name = "nfacct", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = nfacct_mt_checkentry, + .match = nfacct_mt, + .destroy = nfacct_mt_destroy, + .matchsize = sizeof(struct xt_nfacct_match_info), + .usersize = offsetof(struct xt_nfacct_match_info, nfacct), + .me = THIS_MODULE, + }, + { + .name = "nfacct", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = nfacct_mt_checkentry, + .match = nfacct_mt, + .destroy = nfacct_mt_destroy, + .matchsize = sizeof(struct xt_nfacct_match_info_v1), + .usersize = offsetof(struct xt_nfacct_match_info_v1, nfacct), + .me = THIS_MODULE, + }, +}; + +static int __init nfacct_mt_init(void) +{ + return xt_register_matches(nfacct_mt_reg, ARRAY_SIZE(nfacct_mt_reg)); +} + +static void __exit nfacct_mt_exit(void) +{ + xt_unregister_matches(nfacct_mt_reg, ARRAY_SIZE(nfacct_mt_reg)); +} + +module_init(nfacct_mt_init); +module_exit(nfacct_mt_exit); diff --git a/net/netfilter/xt_osf.c b/net/netfilter/xt_osf.c new file mode 100644 index 000000000..dc9485854 --- /dev/null +++ b/net/netfilter/xt_osf.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2003+ Evgeniy Polyakov + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include + +static bool +xt_osf_match_packet(const struct sk_buff *skb, struct xt_action_param *p) +{ + return nf_osf_match(skb, xt_family(p), xt_hooknum(p), xt_in(p), + xt_out(p), p->matchinfo, xt_net(p), nf_osf_fingers); 
+} + +static struct xt_match xt_osf_match = { + .name = "osf", + .revision = 0, + .family = NFPROTO_IPV4, + .proto = IPPROTO_TCP, + .hooks = (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_FORWARD), + .match = xt_osf_match_packet, + .matchsize = sizeof(struct xt_osf_info), + .me = THIS_MODULE, +}; + +static int __init xt_osf_init(void) +{ + int err; + + err = xt_register_match(&xt_osf_match); + if (err) { + pr_err("Failed to register OS fingerprint " + "matching module (%d)\n", err); + return err; + } + + return 0; +} + +static void __exit xt_osf_fini(void) +{ + xt_unregister_match(&xt_osf_match); +} + +module_init(xt_osf_init); +module_exit(xt_osf_fini); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Evgeniy Polyakov "); +MODULE_DESCRIPTION("Passive OS fingerprint matching."); +MODULE_ALIAS("ipt_osf"); +MODULE_ALIAS("ip6t_osf"); diff --git a/net/netfilter/xt_owner.c b/net/netfilter/xt_owner.c new file mode 100644 index 000000000..50332888c --- /dev/null +++ b/net/netfilter/xt_owner.c @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Kernel module to match various things tied to sockets associated with + * locally generated outgoing packets. + * + * (C) 2000 Marc Boucher + * + * Copyright © CC Computer Consultants GmbH, 2007 - 2008 + */ +#include +#include +#include +#include + +#include +#include +#include +#include + +static int owner_check(const struct xt_mtchk_param *par) +{ + struct xt_owner_match_info *info = par->matchinfo; + struct net *net = par->net; + + if (info->match & ~XT_OWNER_MASK) + return -EINVAL; + + /* Only allow the common case where the userns of the writer + * matches the userns of the network namespace. + */ + if ((info->match & (XT_OWNER_UID|XT_OWNER_GID)) && + (current_user_ns() != net->user_ns)) + return -EINVAL; + + /* Ensure the uids are valid */ + if (info->match & XT_OWNER_UID) { + kuid_t uid_min = make_kuid(net->user_ns, info->uid_min); + kuid_t uid_max = make_kuid(net->user_ns, info->uid_max); + + if (!uid_valid(uid_min) || !uid_valid(uid_max) || + (info->uid_max < info->uid_min) || + uid_lt(uid_max, uid_min)) { + return -EINVAL; + } + } + + /* Ensure the gids are valid */ + if (info->match & XT_OWNER_GID) { + kgid_t gid_min = make_kgid(net->user_ns, info->gid_min); + kgid_t gid_max = make_kgid(net->user_ns, info->gid_max); + + if (!gid_valid(gid_min) || !gid_valid(gid_max) || + (info->gid_max < info->gid_min) || + gid_lt(gid_max, gid_min)) { + return -EINVAL; + } + } + + return 0; +} + +static bool +owner_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_owner_match_info *info = par->matchinfo; + const struct file *filp; + struct sock *sk = skb_to_full_sk(skb); + struct net *net = xt_net(par); + + if (!sk || !sk->sk_socket || !net_eq(net, sock_net(sk))) + return (info->match ^ info->invert) == 0; + else if (info->match & info->invert & XT_OWNER_SOCKET) + /* + * Socket exists but user wanted ! --socket-exists. + * (Single ampersands intended.) + */ + return false; + + read_lock_bh(&sk->sk_callback_lock); + filp = sk->sk_socket ? 
sk->sk_socket->file : NULL; + if (filp == NULL) { + read_unlock_bh(&sk->sk_callback_lock); + return ((info->match ^ info->invert) & + (XT_OWNER_UID | XT_OWNER_GID)) == 0; + } + + if (info->match & XT_OWNER_UID) { + kuid_t uid_min = make_kuid(net->user_ns, info->uid_min); + kuid_t uid_max = make_kuid(net->user_ns, info->uid_max); + if ((uid_gte(filp->f_cred->fsuid, uid_min) && + uid_lte(filp->f_cred->fsuid, uid_max)) ^ + !(info->invert & XT_OWNER_UID)) { + read_unlock_bh(&sk->sk_callback_lock); + return false; + } + } + + if (info->match & XT_OWNER_GID) { + unsigned int i, match = false; + kgid_t gid_min = make_kgid(net->user_ns, info->gid_min); + kgid_t gid_max = make_kgid(net->user_ns, info->gid_max); + struct group_info *gi = filp->f_cred->group_info; + + if (gid_gte(filp->f_cred->fsgid, gid_min) && + gid_lte(filp->f_cred->fsgid, gid_max)) + match = true; + + if (!match && (info->match & XT_OWNER_SUPPL_GROUPS) && gi) { + for (i = 0; i < gi->ngroups; ++i) { + kgid_t group = gi->gid[i]; + + if (gid_gte(group, gid_min) && + gid_lte(group, gid_max)) { + match = true; + break; + } + } + } + + if (match ^ !(info->invert & XT_OWNER_GID)) { + read_unlock_bh(&sk->sk_callback_lock); + return false; + } + } + + read_unlock_bh(&sk->sk_callback_lock); + return true; +} + +static struct xt_match owner_mt_reg __read_mostly = { + .name = "owner", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = owner_check, + .match = owner_mt, + .matchsize = sizeof(struct xt_owner_match_info), + .hooks = (1 << NF_INET_LOCAL_OUT) | + (1 << NF_INET_POST_ROUTING), + .me = THIS_MODULE, +}; + +static int __init owner_mt_init(void) +{ + return xt_register_match(&owner_mt_reg); +} + +static void __exit owner_mt_exit(void) +{ + xt_unregister_match(&owner_mt_reg); +} + +module_init(owner_mt_init); +module_exit(owner_mt_exit); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: socket owner matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_owner"); +MODULE_ALIAS("ip6t_owner"); diff --git a/net/netfilter/xt_physdev.c b/net/netfilter/xt_physdev.c new file mode 100644 index 000000000..343e65f37 --- /dev/null +++ b/net/netfilter/xt_physdev.c @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match the bridge port in and + * out device for IP packets coming into contact with a bridge. */ + +/* (C) 2001-2003 Bart De Schuymer + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Bart De Schuymer "); +MODULE_DESCRIPTION("Xtables: Bridge physical device match"); +MODULE_ALIAS("ipt_physdev"); +MODULE_ALIAS("ip6t_physdev"); + + +static bool +physdev_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_physdev_info *info = par->matchinfo; + const struct net_device *physdev; + unsigned long ret; + const char *indev, *outdev; + + /* Not a bridged IP packet or no info available yet: + * LOCAL_OUT/mangle and LOCAL_OUT/nat don't know if + * the destination device will be a bridge. 
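
The uid test in owner_mt() above is easy to misread because the range result is XORed with the negated invert bit before deciding to bail out. The standalone sketch below mirrors that logic with plain integers in place of kuid_t and the uid_gte()/uid_lte() helpers (illustrative only, not part of the patch):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Plain-integer version of the uid check in owner_mt(). */
static bool uid_rule_matches(uint32_t fsuid, uint32_t uid_min,
			     uint32_t uid_max, bool invert)
{
	bool in_range = fsuid >= uid_min && fsuid <= uid_max;

	if (in_range ^ !invert)
		return false;	/* the spot where owner_mt() returns false */
	return true;
}

int main(void)
{
	printf("%d\n", uid_rule_matches(1000, 1000, 1999, false)); /* 1 */
	printf("%d\n", uid_rule_matches(2500, 1000, 1999, false)); /* 0 */
	printf("%d\n", uid_rule_matches(2500, 1000, 1999, true));  /* 1 */
	printf("%d\n", uid_rule_matches(1000, 1000, 1999, true));  /* 0 */
	return 0;
}
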
*/ + if (!nf_bridge_info_exists(skb)) { + /* Return MATCH if the invert flags of the used options are on */ + if ((info->bitmask & XT_PHYSDEV_OP_BRIDGED) && + !(info->invert & XT_PHYSDEV_OP_BRIDGED)) + return false; + if ((info->bitmask & XT_PHYSDEV_OP_ISIN) && + !(info->invert & XT_PHYSDEV_OP_ISIN)) + return false; + if ((info->bitmask & XT_PHYSDEV_OP_ISOUT) && + !(info->invert & XT_PHYSDEV_OP_ISOUT)) + return false; + if ((info->bitmask & XT_PHYSDEV_OP_IN) && + !(info->invert & XT_PHYSDEV_OP_IN)) + return false; + if ((info->bitmask & XT_PHYSDEV_OP_OUT) && + !(info->invert & XT_PHYSDEV_OP_OUT)) + return false; + return true; + } + + physdev = nf_bridge_get_physoutdev(skb); + outdev = physdev ? physdev->name : NULL; + + /* This only makes sense in the FORWARD and POSTROUTING chains */ + if ((info->bitmask & XT_PHYSDEV_OP_BRIDGED) && + (!!outdev ^ !(info->invert & XT_PHYSDEV_OP_BRIDGED))) + return false; + + physdev = nf_bridge_get_physindev(skb, xt_net(par)); + indev = physdev ? physdev->name : NULL; + + if ((info->bitmask & XT_PHYSDEV_OP_ISIN && + (!indev ^ !!(info->invert & XT_PHYSDEV_OP_ISIN))) || + (info->bitmask & XT_PHYSDEV_OP_ISOUT && + (!outdev ^ !!(info->invert & XT_PHYSDEV_OP_ISOUT)))) + return false; + + if (!(info->bitmask & XT_PHYSDEV_OP_IN)) + goto match_outdev; + + if (indev) { + ret = ifname_compare_aligned(indev, info->physindev, + info->in_mask); + + if (!ret ^ !(info->invert & XT_PHYSDEV_OP_IN)) + return false; + } + +match_outdev: + if (!(info->bitmask & XT_PHYSDEV_OP_OUT)) + return true; + + if (!outdev) + return false; + + ret = ifname_compare_aligned(outdev, info->physoutdev, info->out_mask); + + return (!!ret ^ !(info->invert & XT_PHYSDEV_OP_OUT)); +} + +static int physdev_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_physdev_info *info = par->matchinfo; + static bool brnf_probed __read_mostly; + + if (!(info->bitmask & XT_PHYSDEV_OP_MASK) || + info->bitmask & ~XT_PHYSDEV_OP_MASK) + return -EINVAL; + if (info->bitmask & (XT_PHYSDEV_OP_OUT | XT_PHYSDEV_OP_ISOUT) && + (!(info->bitmask & XT_PHYSDEV_OP_BRIDGED) || + info->invert & XT_PHYSDEV_OP_BRIDGED) && + par->hook_mask & (1 << NF_INET_LOCAL_OUT)) { + pr_info_ratelimited("--physdev-out and --physdev-is-out only supported in the FORWARD and POSTROUTING chains with bridged traffic\n"); + return -EINVAL; + } + + if (!brnf_probed) { + brnf_probed = true; + request_module("br_netfilter"); + } + + return 0; +} + +static struct xt_match physdev_mt_reg __read_mostly = { + .name = "physdev", + .revision = 0, + .family = NFPROTO_UNSPEC, + .checkentry = physdev_mt_check, + .match = physdev_mt, + .matchsize = sizeof(struct xt_physdev_info), + .me = THIS_MODULE, +}; + +static int __init physdev_mt_init(void) +{ + return xt_register_match(&physdev_mt_reg); +} + +static void __exit physdev_mt_exit(void) +{ + xt_unregister_match(&physdev_mt_reg); +} + +module_init(physdev_mt_init); +module_exit(physdev_mt_exit); diff --git a/net/netfilter/xt_pkttype.c b/net/netfilter/xt_pkttype.c new file mode 100644 index 000000000..f48946aef --- /dev/null +++ b/net/netfilter/xt_pkttype.c @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* (C) 1999-2001 Michal Ludvig + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Michal Ludvig "); +MODULE_DESCRIPTION("Xtables: link layer packet type match"); +MODULE_ALIAS("ipt_pkttype"); +MODULE_ALIAS("ip6t_pkttype"); + +static bool +pkttype_mt(const struct sk_buff *skb, struct xt_action_param 
*par) +{ + const struct xt_pkttype_info *info = par->matchinfo; + u_int8_t type; + + if (skb->pkt_type != PACKET_LOOPBACK) + type = skb->pkt_type; + else if (xt_family(par) == NFPROTO_IPV4 && + ipv4_is_multicast(ip_hdr(skb)->daddr)) + type = PACKET_MULTICAST; + else if (xt_family(par) == NFPROTO_IPV6) + type = PACKET_MULTICAST; + else + type = PACKET_BROADCAST; + + return (type == info->pkttype) ^ info->invert; +} + +static struct xt_match pkttype_mt_reg __read_mostly = { + .name = "pkttype", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = pkttype_mt, + .matchsize = sizeof(struct xt_pkttype_info), + .me = THIS_MODULE, +}; + +static int __init pkttype_mt_init(void) +{ + return xt_register_match(&pkttype_mt_reg); +} + +static void __exit pkttype_mt_exit(void) +{ + xt_unregister_match(&pkttype_mt_reg); +} + +module_init(pkttype_mt_init); +module_exit(pkttype_mt_exit); diff --git a/net/netfilter/xt_policy.c b/net/netfilter/xt_policy.c new file mode 100644 index 000000000..cb6e82790 --- /dev/null +++ b/net/netfilter/xt_policy.c @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* IP tables module for matching IPsec policy + * + * Copyright (c) 2004,2005 Patrick McHardy, + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Xtables: IPsec policy match"); +MODULE_LICENSE("GPL"); + +static inline bool +xt_addr_cmp(const union nf_inet_addr *a1, const union nf_inet_addr *m, + const union nf_inet_addr *a2, unsigned short family) +{ + switch (family) { + case NFPROTO_IPV4: + return ((a1->ip ^ a2->ip) & m->ip) == 0; + case NFPROTO_IPV6: + return ipv6_masked_addr_cmp(&a1->in6, &m->in6, &a2->in6) == 0; + } + return false; +} + +static bool +match_xfrm_state(const struct xfrm_state *x, const struct xt_policy_elem *e, + unsigned short family) +{ +#define MATCH_ADDR(x,y,z) (!e->match.x || \ + (xt_addr_cmp(&e->x, &e->y, (const union nf_inet_addr *)(z), family) \ + ^ e->invert.x)) +#define MATCH(x,y) (!e->match.x || ((e->x == (y)) ^ e->invert.x)) + + return MATCH_ADDR(saddr, smask, &x->props.saddr) && + MATCH_ADDR(daddr, dmask, &x->id.daddr) && + MATCH(proto, x->id.proto) && + MATCH(mode, x->props.mode) && + MATCH(spi, x->id.spi) && + MATCH(reqid, x->props.reqid); +} + +static int +match_policy_in(const struct sk_buff *skb, const struct xt_policy_info *info, + unsigned short family) +{ + const struct xt_policy_elem *e; + const struct sec_path *sp = skb_sec_path(skb); + int strict = info->flags & XT_POLICY_MATCH_STRICT; + int i, pos; + + if (sp == NULL) + return -1; + if (strict && info->len != sp->len) + return 0; + + for (i = sp->len - 1; i >= 0; i--) { + pos = strict ? i - sp->len + 1 : 0; + if (pos >= info->len) + return 0; + e = &info->pol[pos]; + + if (match_xfrm_state(sp->xvec[i], e, family)) { + if (!strict) + return 1; + } else if (strict) + return 0; + } + + return strict ? 1 : 0; +} + +static int +match_policy_out(const struct sk_buff *skb, const struct xt_policy_info *info, + unsigned short family) +{ + const struct xt_policy_elem *e; + const struct dst_entry *dst = skb_dst(skb); + int strict = info->flags & XT_POLICY_MATCH_STRICT; + int i, pos; + + if (dst->xfrm == NULL) + return -1; + + for (i = 0; dst && dst->xfrm; + dst = ((struct xfrm_dst *)dst)->child, i++) { + pos = strict ? 
i : 0; + if (pos >= info->len) + return 0; + e = &info->pol[pos]; + + if (match_xfrm_state(dst->xfrm, e, family)) { + if (!strict) + return 1; + } else if (strict) + return 0; + } + + return strict ? i == info->len : 0; +} + +static bool +policy_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_policy_info *info = par->matchinfo; + int ret; + + if (info->flags & XT_POLICY_MATCH_IN) + ret = match_policy_in(skb, info, xt_family(par)); + else + ret = match_policy_out(skb, info, xt_family(par)); + + if (ret < 0) + ret = info->flags & XT_POLICY_MATCH_NONE ? true : false; + else if (info->flags & XT_POLICY_MATCH_NONE) + ret = false; + + return ret; +} + +static int policy_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_policy_info *info = par->matchinfo; + const char *errmsg = "neither incoming nor outgoing policy selected"; + + if (!(info->flags & (XT_POLICY_MATCH_IN|XT_POLICY_MATCH_OUT))) + goto err; + + if (par->hook_mask & ((1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN)) && info->flags & XT_POLICY_MATCH_OUT) { + errmsg = "output policy not valid in PREROUTING and INPUT"; + goto err; + } + if (par->hook_mask & ((1 << NF_INET_POST_ROUTING) | + (1 << NF_INET_LOCAL_OUT)) && info->flags & XT_POLICY_MATCH_IN) { + errmsg = "input policy not valid in POSTROUTING and OUTPUT"; + goto err; + } + if (info->len > XT_POLICY_MAX_ELEM) { + errmsg = "too many policy elements"; + goto err; + } + return 0; +err: + pr_info_ratelimited("%s\n", errmsg); + return -EINVAL; +} + +static struct xt_match policy_mt_reg[] __read_mostly = { + { + .name = "policy", + .family = NFPROTO_IPV4, + .checkentry = policy_mt_check, + .match = policy_mt, + .matchsize = sizeof(struct xt_policy_info), + .me = THIS_MODULE, + }, + { + .name = "policy", + .family = NFPROTO_IPV6, + .checkentry = policy_mt_check, + .match = policy_mt, + .matchsize = sizeof(struct xt_policy_info), + .me = THIS_MODULE, + }, +}; + +static int __init policy_mt_init(void) +{ + return xt_register_matches(policy_mt_reg, ARRAY_SIZE(policy_mt_reg)); +} + +static void __exit policy_mt_exit(void) +{ + xt_unregister_matches(policy_mt_reg, ARRAY_SIZE(policy_mt_reg)); +} + +module_init(policy_mt_init); +module_exit(policy_mt_exit); +MODULE_ALIAS("ipt_policy"); +MODULE_ALIAS("ip6t_policy"); diff --git a/net/netfilter/xt_quota.c b/net/netfilter/xt_quota.c new file mode 100644 index 000000000..4452cc93b --- /dev/null +++ b/net/netfilter/xt_quota.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * netfilter module to enforce network quotas + * + * Sam Johnston + */ +#include +#include +#include + +#include +#include +#include + +struct xt_quota_priv { + spinlock_t lock; + uint64_t quota; +}; + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Sam Johnston "); +MODULE_DESCRIPTION("Xtables: countdown quota match"); +MODULE_ALIAS("ipt_quota"); +MODULE_ALIAS("ip6t_quota"); + +static bool +quota_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct xt_quota_info *q = (void *)par->matchinfo; + struct xt_quota_priv *priv = q->master; + bool ret = q->flags & XT_QUOTA_INVERT; + + spin_lock_bh(&priv->lock); + if (priv->quota >= skb->len) { + priv->quota -= skb->len; + ret = !ret; + } else { + /* we do not allow even small packets from now on */ + priv->quota = 0; + } + spin_unlock_bh(&priv->lock); + + return ret; +} + +static int quota_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_quota_info *q = par->matchinfo; + + if (q->flags & ~XT_QUOTA_MASK) + return -EINVAL; + + q->master = 
kmalloc(sizeof(*q->master), GFP_KERNEL); + if (q->master == NULL) + return -ENOMEM; + + spin_lock_init(&q->master->lock); + q->master->quota = q->quota; + return 0; +} + +static void quota_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_quota_info *q = par->matchinfo; + + kfree(q->master); +} + +static struct xt_match quota_mt_reg __read_mostly = { + .name = "quota", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = quota_mt, + .checkentry = quota_mt_check, + .destroy = quota_mt_destroy, + .matchsize = sizeof(struct xt_quota_info), + .usersize = offsetof(struct xt_quota_info, master), + .me = THIS_MODULE, +}; + +static int __init quota_mt_init(void) +{ + return xt_register_match(&quota_mt_reg); +} + +static void __exit quota_mt_exit(void) +{ + xt_unregister_match(&quota_mt_reg); +} + +module_init(quota_mt_init); +module_exit(quota_mt_exit); diff --git a/net/netfilter/xt_rateest.c b/net/netfilter/xt_rateest.c new file mode 100644 index 000000000..72324bd97 --- /dev/null +++ b/net/netfilter/xt_rateest.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * (C) 2007 Patrick McHardy + */ +#include +#include +#include + +#include +#include +#include + + +static bool +xt_rateest_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_rateest_match_info *info = par->matchinfo; + struct gnet_stats_rate_est64 sample = {0}; + u_int32_t bps1, bps2, pps1, pps2; + bool ret = true; + + gen_estimator_read(&info->est1->rate_est, &sample); + + if (info->flags & XT_RATEEST_MATCH_DELTA) { + bps1 = info->bps1 >= sample.bps ? info->bps1 - sample.bps : 0; + pps1 = info->pps1 >= sample.pps ? info->pps1 - sample.pps : 0; + } else { + bps1 = sample.bps; + pps1 = sample.pps; + } + + if (info->flags & XT_RATEEST_MATCH_ABS) { + bps2 = info->bps2; + pps2 = info->pps2; + } else { + gen_estimator_read(&info->est2->rate_est, &sample); + + if (info->flags & XT_RATEEST_MATCH_DELTA) { + bps2 = info->bps2 >= sample.bps ? info->bps2 - sample.bps : 0; + pps2 = info->pps2 >= sample.pps ? info->pps2 - sample.pps : 0; + } else { + bps2 = sample.bps; + pps2 = sample.pps; + } + } + + switch (info->mode) { + case XT_RATEEST_MATCH_LT: + if (info->flags & XT_RATEEST_MATCH_BPS) + ret &= bps1 < bps2; + if (info->flags & XT_RATEEST_MATCH_PPS) + ret &= pps1 < pps2; + break; + case XT_RATEEST_MATCH_GT: + if (info->flags & XT_RATEEST_MATCH_BPS) + ret &= bps1 > bps2; + if (info->flags & XT_RATEEST_MATCH_PPS) + ret &= pps1 > pps2; + break; + case XT_RATEEST_MATCH_EQ: + if (info->flags & XT_RATEEST_MATCH_BPS) + ret &= bps1 == bps2; + if (info->flags & XT_RATEEST_MATCH_PPS) + ret &= pps1 == pps2; + break; + } + + ret ^= info->flags & XT_RATEEST_MATCH_INVERT ? 
true : false; + return ret; +} + +static int xt_rateest_mt_checkentry(const struct xt_mtchk_param *par) +{ + struct xt_rateest_match_info *info = par->matchinfo; + struct xt_rateest *est1, *est2; + int ret = -EINVAL; + + if (hweight32(info->flags & (XT_RATEEST_MATCH_ABS | + XT_RATEEST_MATCH_REL)) != 1) + goto err1; + + if (!(info->flags & (XT_RATEEST_MATCH_BPS | XT_RATEEST_MATCH_PPS))) + goto err1; + + switch (info->mode) { + case XT_RATEEST_MATCH_EQ: + case XT_RATEEST_MATCH_LT: + case XT_RATEEST_MATCH_GT: + break; + default: + goto err1; + } + + ret = -ENOENT; + est1 = xt_rateest_lookup(par->net, info->name1); + if (!est1) + goto err1; + + est2 = NULL; + if (info->flags & XT_RATEEST_MATCH_REL) { + est2 = xt_rateest_lookup(par->net, info->name2); + if (!est2) + goto err2; + } + + info->est1 = est1; + info->est2 = est2; + return 0; + +err2: + xt_rateest_put(par->net, est1); +err1: + return ret; +} + +static void xt_rateest_mt_destroy(const struct xt_mtdtor_param *par) +{ + struct xt_rateest_match_info *info = par->matchinfo; + + xt_rateest_put(par->net, info->est1); + if (info->est2) + xt_rateest_put(par->net, info->est2); +} + +static struct xt_match xt_rateest_mt_reg __read_mostly = { + .name = "rateest", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = xt_rateest_mt, + .checkentry = xt_rateest_mt_checkentry, + .destroy = xt_rateest_mt_destroy, + .matchsize = sizeof(struct xt_rateest_match_info), + .usersize = offsetof(struct xt_rateest_match_info, est1), + .me = THIS_MODULE, +}; + +static int __init xt_rateest_mt_init(void) +{ + return xt_register_match(&xt_rateest_mt_reg); +} + +static void __exit xt_rateest_mt_fini(void) +{ + xt_unregister_match(&xt_rateest_mt_reg); +} + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("xtables rate estimator match"); +MODULE_ALIAS("ipt_rateest"); +MODULE_ALIAS("ip6t_rateest"); +module_init(xt_rateest_mt_init); +module_exit(xt_rateest_mt_fini); diff --git a/net/netfilter/xt_realm.c b/net/netfilter/xt_realm.c new file mode 100644 index 000000000..6df485f44 --- /dev/null +++ b/net/netfilter/xt_realm.c @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* IP tables module for matching the routing realm + * + * (C) 2003 by Sampsa Ranta + */ + +#include +#include +#include +#include + +#include +#include +#include + +MODULE_AUTHOR("Sampsa Ranta "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Xtables: Routing realm match"); +MODULE_ALIAS("ipt_realm"); + +static bool +realm_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_realm_info *info = par->matchinfo; + const struct dst_entry *dst = skb_dst(skb); + + return (info->id == (dst->tclassid & info->mask)) ^ info->invert; +} + +static struct xt_match realm_mt_reg __read_mostly = { + .name = "realm", + .match = realm_mt, + .matchsize = sizeof(struct xt_realm_info), + .hooks = (1 << NF_INET_POST_ROUTING) | (1 << NF_INET_FORWARD) | + (1 << NF_INET_LOCAL_OUT) | (1 << NF_INET_LOCAL_IN), + .family = NFPROTO_UNSPEC, + .me = THIS_MODULE +}; + +static int __init realm_mt_init(void) +{ + return xt_register_match(&realm_mt_reg); +} + +static void __exit realm_mt_exit(void) +{ + xt_unregister_match(&realm_mt_reg); +} + +module_init(realm_mt_init); +module_exit(realm_mt_exit); diff --git a/net/netfilter/xt_recent.c b/net/netfilter/xt_recent.c new file mode 100644 index 000000000..ef93e0d3b --- /dev/null +++ b/net/netfilter/xt_recent.c @@ -0,0 +1,763 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2006 Patrick McHardy + * 
Copyright © CC Computer Consultants GmbH, 2007 - 2008 + * + * This is a replacement of the old ipt_recent module, which carried the + * following copyright notice: + * + * Author: Stephen Frost + * Copyright 2002-2003, Stephen Frost, 2.5.x port by laforge@netfilter.org + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +MODULE_AUTHOR("Patrick McHardy "); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: \"recently-seen\" host matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_recent"); +MODULE_ALIAS("ip6t_recent"); + +static unsigned int ip_list_tot __read_mostly = 100; +static unsigned int ip_list_hash_size __read_mostly; +static unsigned int ip_list_perms __read_mostly = 0644; +static unsigned int ip_list_uid __read_mostly; +static unsigned int ip_list_gid __read_mostly; +module_param(ip_list_tot, uint, 0400); +module_param(ip_list_hash_size, uint, 0400); +module_param(ip_list_perms, uint, 0400); +module_param(ip_list_uid, uint, 0644); +module_param(ip_list_gid, uint, 0644); +MODULE_PARM_DESC(ip_list_tot, "number of IPs to remember per list"); +MODULE_PARM_DESC(ip_list_hash_size, "size of hash table used to look up IPs"); +MODULE_PARM_DESC(ip_list_perms, "permissions on /proc/net/xt_recent/* files"); +MODULE_PARM_DESC(ip_list_uid, "default owner of /proc/net/xt_recent/* files"); +MODULE_PARM_DESC(ip_list_gid, "default owning group of /proc/net/xt_recent/* files"); + +/* retained for backwards compatibility */ +static unsigned int ip_pkt_list_tot __read_mostly; +module_param(ip_pkt_list_tot, uint, 0400); +MODULE_PARM_DESC(ip_pkt_list_tot, "number of packets per IP address to remember (max. 
255)"); + +#define XT_RECENT_MAX_NSTAMPS 256 + +struct recent_entry { + struct list_head list; + struct list_head lru_list; + union nf_inet_addr addr; + u_int16_t family; + u_int8_t ttl; + u_int8_t index; + u_int16_t nstamps; + unsigned long stamps[]; +}; + +struct recent_table { + struct list_head list; + char name[XT_RECENT_NAME_LEN]; + union nf_inet_addr mask; + unsigned int refcnt; + unsigned int entries; + u8 nstamps_max_mask; + struct list_head lru_list; + struct list_head iphash[]; +}; + +struct recent_net { + struct list_head tables; +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *xt_recent; +#endif +}; + +static unsigned int recent_net_id __read_mostly; + +static inline struct recent_net *recent_pernet(struct net *net) +{ + return net_generic(net, recent_net_id); +} + +static DEFINE_SPINLOCK(recent_lock); +static DEFINE_MUTEX(recent_mutex); + +#ifdef CONFIG_PROC_FS +static const struct proc_ops recent_mt_proc_ops; +#endif + +static u_int32_t hash_rnd __read_mostly; + +static inline unsigned int recent_entry_hash4(const union nf_inet_addr *addr) +{ + return jhash_1word((__force u32)addr->ip, hash_rnd) & + (ip_list_hash_size - 1); +} + +static inline unsigned int recent_entry_hash6(const union nf_inet_addr *addr) +{ + return jhash2((u32 *)addr->ip6, ARRAY_SIZE(addr->ip6), hash_rnd) & + (ip_list_hash_size - 1); +} + +static struct recent_entry * +recent_entry_lookup(const struct recent_table *table, + const union nf_inet_addr *addrp, u_int16_t family, + u_int8_t ttl) +{ + struct recent_entry *e; + unsigned int h; + + if (family == NFPROTO_IPV4) + h = recent_entry_hash4(addrp); + else + h = recent_entry_hash6(addrp); + + list_for_each_entry(e, &table->iphash[h], list) + if (e->family == family && + memcmp(&e->addr, addrp, sizeof(e->addr)) == 0 && + (ttl == e->ttl || ttl == 0 || e->ttl == 0)) + return e; + return NULL; +} + +static void recent_entry_remove(struct recent_table *t, struct recent_entry *e) +{ + list_del(&e->list); + list_del(&e->lru_list); + kfree(e); + t->entries--; +} + +/* + * Drop entries with timestamps older then 'time'. + */ +static void recent_entry_reap(struct recent_table *t, unsigned long time, + struct recent_entry *working, bool update) +{ + struct recent_entry *e; + + /* + * The head of the LRU list is always the oldest entry. + */ + e = list_entry(t->lru_list.next, struct recent_entry, lru_list); + + /* + * Do not reap the entry which are going to be updated. + */ + if (e == working && update) + return; + + /* + * The last time stamp is the most recent. 
+ */ + if (time_after(time, e->stamps[e->index-1])) + recent_entry_remove(t, e); +} + +static struct recent_entry * +recent_entry_init(struct recent_table *t, const union nf_inet_addr *addr, + u_int16_t family, u_int8_t ttl) +{ + struct recent_entry *e; + unsigned int nstamps_max = t->nstamps_max_mask; + + if (t->entries >= ip_list_tot) { + e = list_entry(t->lru_list.next, struct recent_entry, lru_list); + recent_entry_remove(t, e); + } + + nstamps_max += 1; + e = kmalloc(struct_size(e, stamps, nstamps_max), GFP_ATOMIC); + if (e == NULL) + return NULL; + memcpy(&e->addr, addr, sizeof(e->addr)); + e->ttl = ttl; + e->stamps[0] = jiffies; + e->nstamps = 1; + e->index = 1; + e->family = family; + if (family == NFPROTO_IPV4) + list_add_tail(&e->list, &t->iphash[recent_entry_hash4(addr)]); + else + list_add_tail(&e->list, &t->iphash[recent_entry_hash6(addr)]); + list_add_tail(&e->lru_list, &t->lru_list); + t->entries++; + return e; +} + +static void recent_entry_update(struct recent_table *t, struct recent_entry *e) +{ + e->index &= t->nstamps_max_mask; + e->stamps[e->index++] = jiffies; + if (e->index > e->nstamps) + e->nstamps = e->index; + list_move_tail(&e->lru_list, &t->lru_list); +} + +static struct recent_table *recent_table_lookup(struct recent_net *recent_net, + const char *name) +{ + struct recent_table *t; + + list_for_each_entry(t, &recent_net->tables, list) + if (!strcmp(t->name, name)) + return t; + return NULL; +} + +static void recent_table_flush(struct recent_table *t) +{ + struct recent_entry *e, *next; + unsigned int i; + + for (i = 0; i < ip_list_hash_size; i++) + list_for_each_entry_safe(e, next, &t->iphash[i], list) + recent_entry_remove(t, e); +} + +static bool +recent_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct net *net = xt_net(par); + struct recent_net *recent_net = recent_pernet(net); + const struct xt_recent_mtinfo_v1 *info = par->matchinfo; + struct recent_table *t; + struct recent_entry *e; + union nf_inet_addr addr = {}, addr_mask; + u_int8_t ttl; + bool ret = info->invert; + + if (xt_family(par) == NFPROTO_IPV4) { + const struct iphdr *iph = ip_hdr(skb); + + if (info->side == XT_RECENT_DEST) + addr.ip = iph->daddr; + else + addr.ip = iph->saddr; + + ttl = iph->ttl; + } else { + const struct ipv6hdr *iph = ipv6_hdr(skb); + + if (info->side == XT_RECENT_DEST) + memcpy(&addr.in6, &iph->daddr, sizeof(addr.in6)); + else + memcpy(&addr.in6, &iph->saddr, sizeof(addr.in6)); + + ttl = iph->hop_limit; + } + + /* use TTL as seen before forwarding */ + if (xt_out(par) != NULL && + (!skb->sk || !net_eq(net, sock_net(skb->sk)))) + ttl++; + + spin_lock_bh(&recent_lock); + t = recent_table_lookup(recent_net, info->name); + + nf_inet_addr_mask(&addr, &addr_mask, &t->mask); + + e = recent_entry_lookup(t, &addr_mask, xt_family(par), + (info->check_set & XT_RECENT_TTL) ? 
ttl : 0); + if (e == NULL) { + if (!(info->check_set & XT_RECENT_SET)) + goto out; + e = recent_entry_init(t, &addr_mask, xt_family(par), ttl); + if (e == NULL) + par->hotdrop = true; + ret = !ret; + goto out; + } + + if (info->check_set & XT_RECENT_SET) + ret = !ret; + else if (info->check_set & XT_RECENT_REMOVE) { + recent_entry_remove(t, e); + ret = !ret; + } else if (info->check_set & (XT_RECENT_CHECK | XT_RECENT_UPDATE)) { + unsigned long time = jiffies - info->seconds * HZ; + unsigned int i, hits = 0; + + for (i = 0; i < e->nstamps; i++) { + if (info->seconds && time_after(time, e->stamps[i])) + continue; + if (!info->hit_count || ++hits >= info->hit_count) { + ret = !ret; + break; + } + } + + /* info->seconds must be non-zero */ + if (info->check_set & XT_RECENT_REAP) + recent_entry_reap(t, time, e, + info->check_set & XT_RECENT_UPDATE && ret); + } + + if (info->check_set & XT_RECENT_SET || + (info->check_set & XT_RECENT_UPDATE && ret)) { + recent_entry_update(t, e); + e->ttl = ttl; + } +out: + spin_unlock_bh(&recent_lock); + return ret; +} + +static void recent_table_free(void *addr) +{ + kvfree(addr); +} + +static int recent_mt_check(const struct xt_mtchk_param *par, + const struct xt_recent_mtinfo_v1 *info) +{ + struct recent_net *recent_net = recent_pernet(par->net); + struct recent_table *t; +#ifdef CONFIG_PROC_FS + struct proc_dir_entry *pde; + kuid_t uid; + kgid_t gid; +#endif + unsigned int nstamp_mask; + unsigned int i; + int ret = -EINVAL; + + net_get_random_once(&hash_rnd, sizeof(hash_rnd)); + + if (info->check_set & ~XT_RECENT_VALID_FLAGS) { + pr_info_ratelimited("Unsupported userspace flags (%08x)\n", + info->check_set); + return -EINVAL; + } + if (hweight8(info->check_set & + (XT_RECENT_SET | XT_RECENT_REMOVE | + XT_RECENT_CHECK | XT_RECENT_UPDATE)) != 1) + return -EINVAL; + if ((info->check_set & (XT_RECENT_SET | XT_RECENT_REMOVE)) && + (info->seconds || info->hit_count || + (info->check_set & XT_RECENT_MODIFIERS))) + return -EINVAL; + if ((info->check_set & XT_RECENT_REAP) && !info->seconds) + return -EINVAL; + if (info->hit_count >= XT_RECENT_MAX_NSTAMPS) { + pr_info_ratelimited("hitcount (%u) is larger than allowed maximum (%u)\n", + info->hit_count, XT_RECENT_MAX_NSTAMPS - 1); + return -EINVAL; + } + ret = xt_check_proc_name(info->name, sizeof(info->name)); + if (ret) + return ret; + + if (ip_pkt_list_tot && info->hit_count < ip_pkt_list_tot) + nstamp_mask = roundup_pow_of_two(ip_pkt_list_tot) - 1; + else if (info->hit_count) + nstamp_mask = roundup_pow_of_two(info->hit_count) - 1; + else + nstamp_mask = 32 - 1; + + mutex_lock(&recent_mutex); + t = recent_table_lookup(recent_net, info->name); + if (t != NULL) { + if (nstamp_mask > t->nstamps_max_mask) { + spin_lock_bh(&recent_lock); + recent_table_flush(t); + t->nstamps_max_mask = nstamp_mask; + spin_unlock_bh(&recent_lock); + } + + t->refcnt++; + ret = 0; + goto out; + } + + t = kvzalloc(struct_size(t, iphash, ip_list_hash_size), GFP_KERNEL); + if (t == NULL) { + ret = -ENOMEM; + goto out; + } + t->refcnt = 1; + t->nstamps_max_mask = nstamp_mask; + + memcpy(&t->mask, &info->mask, sizeof(t->mask)); + strcpy(t->name, info->name); + INIT_LIST_HEAD(&t->lru_list); + for (i = 0; i < ip_list_hash_size; i++) + INIT_LIST_HEAD(&t->iphash[i]); +#ifdef CONFIG_PROC_FS + uid = make_kuid(&init_user_ns, ip_list_uid); + gid = make_kgid(&init_user_ns, ip_list_gid); + if (!uid_valid(uid) || !gid_valid(gid)) { + recent_table_free(t); + ret = -EINVAL; + goto out; + } + pde = proc_create_data(t->name, ip_list_perms, 
recent_net->xt_recent, + &recent_mt_proc_ops, t); + if (pde == NULL) { + recent_table_free(t); + ret = -ENOMEM; + goto out; + } + proc_set_user(pde, uid, gid); +#endif + spin_lock_bh(&recent_lock); + list_add_tail(&t->list, &recent_net->tables); + spin_unlock_bh(&recent_lock); + ret = 0; +out: + mutex_unlock(&recent_mutex); + return ret; +} + +static int recent_mt_check_v0(const struct xt_mtchk_param *par) +{ + const struct xt_recent_mtinfo_v0 *info_v0 = par->matchinfo; + struct xt_recent_mtinfo_v1 info_v1; + + /* Copy revision 0 structure to revision 1 */ + memcpy(&info_v1, info_v0, sizeof(struct xt_recent_mtinfo)); + /* Set default mask to ensure backward compatible behaviour */ + memset(info_v1.mask.all, 0xFF, sizeof(info_v1.mask.all)); + + return recent_mt_check(par, &info_v1); +} + +static int recent_mt_check_v1(const struct xt_mtchk_param *par) +{ + return recent_mt_check(par, par->matchinfo); +} + +static void recent_mt_destroy(const struct xt_mtdtor_param *par) +{ + struct recent_net *recent_net = recent_pernet(par->net); + const struct xt_recent_mtinfo_v1 *info = par->matchinfo; + struct recent_table *t; + + mutex_lock(&recent_mutex); + t = recent_table_lookup(recent_net, info->name); + if (--t->refcnt == 0) { + spin_lock_bh(&recent_lock); + list_del(&t->list); + spin_unlock_bh(&recent_lock); +#ifdef CONFIG_PROC_FS + if (recent_net->xt_recent != NULL) + remove_proc_entry(t->name, recent_net->xt_recent); +#endif + recent_table_flush(t); + recent_table_free(t); + } + mutex_unlock(&recent_mutex); +} + +#ifdef CONFIG_PROC_FS +struct recent_iter_state { + const struct recent_table *table; + unsigned int bucket; +}; + +static void *recent_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(recent_lock) +{ + struct recent_iter_state *st = seq->private; + const struct recent_table *t = st->table; + struct recent_entry *e; + loff_t p = *pos; + + spin_lock_bh(&recent_lock); + + for (st->bucket = 0; st->bucket < ip_list_hash_size; st->bucket++) + list_for_each_entry(e, &t->iphash[st->bucket], list) + if (p-- == 0) + return e; + return NULL; +} + +static void *recent_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct recent_iter_state *st = seq->private; + const struct recent_table *t = st->table; + const struct recent_entry *e = v; + const struct list_head *head = e->list.next; + + (*pos)++; + while (head == &t->iphash[st->bucket]) { + if (++st->bucket >= ip_list_hash_size) + return NULL; + head = t->iphash[st->bucket].next; + } + return list_entry(head, struct recent_entry, list); +} + +static void recent_seq_stop(struct seq_file *s, void *v) + __releases(recent_lock) +{ + spin_unlock_bh(&recent_lock); +} + +static int recent_seq_show(struct seq_file *seq, void *v) +{ + const struct recent_entry *e = v; + struct recent_iter_state *st = seq->private; + const struct recent_table *t = st->table; + unsigned int i; + + i = (e->index - 1) & t->nstamps_max_mask; + + if (e->family == NFPROTO_IPV4) + seq_printf(seq, "src=%pI4 ttl: %u last_seen: %lu oldest_pkt: %u", + &e->addr.ip, e->ttl, e->stamps[i], e->index); + else + seq_printf(seq, "src=%pI6 ttl: %u last_seen: %lu oldest_pkt: %u", + &e->addr.in6, e->ttl, e->stamps[i], e->index); + for (i = 0; i < e->nstamps; i++) + seq_printf(seq, "%s %lu", i ? 
"," : "", e->stamps[i]); + seq_putc(seq, '\n'); + return 0; +} + +static const struct seq_operations recent_seq_ops = { + .start = recent_seq_start, + .next = recent_seq_next, + .stop = recent_seq_stop, + .show = recent_seq_show, +}; + +static int recent_seq_open(struct inode *inode, struct file *file) +{ + struct recent_iter_state *st; + + st = __seq_open_private(file, &recent_seq_ops, sizeof(*st)); + if (st == NULL) + return -ENOMEM; + + st->table = pde_data(inode); + return 0; +} + +static ssize_t +recent_mt_proc_write(struct file *file, const char __user *input, + size_t size, loff_t *loff) +{ + struct recent_table *t = pde_data(file_inode(file)); + struct recent_entry *e; + char buf[sizeof("+b335:1d35:1e55:dead:c0de:1715:255.255.255.255")]; + const char *c = buf; + union nf_inet_addr addr = {}; + u_int16_t family; + bool add, succ; + + if (size == 0) + return 0; + if (size > sizeof(buf)) + size = sizeof(buf); + if (copy_from_user(buf, input, size) != 0) + return -EFAULT; + + /* Strict protocol! */ + if (*loff != 0) + return -ESPIPE; + switch (*c) { + case '/': /* flush table */ + spin_lock_bh(&recent_lock); + recent_table_flush(t); + spin_unlock_bh(&recent_lock); + return size; + case '-': /* remove address */ + add = false; + break; + case '+': /* add address */ + add = true; + break; + default: + pr_info_ratelimited("Need \"+ip\", \"-ip\" or \"/\"\n"); + return -EINVAL; + } + + ++c; + --size; + if (strnchr(c, size, ':') != NULL) { + family = NFPROTO_IPV6; + succ = in6_pton(c, size, (void *)&addr, '\n', NULL); + } else { + family = NFPROTO_IPV4; + succ = in4_pton(c, size, (void *)&addr, '\n', NULL); + } + + if (!succ) + return -EINVAL; + + spin_lock_bh(&recent_lock); + e = recent_entry_lookup(t, &addr, family, 0); + if (e == NULL) { + if (add) + recent_entry_init(t, &addr, family, 0); + } else { + if (add) + recent_entry_update(t, e); + else + recent_entry_remove(t, e); + } + spin_unlock_bh(&recent_lock); + /* Note we removed one above */ + *loff += size + 1; + return size + 1; +} + +static const struct proc_ops recent_mt_proc_ops = { + .proc_open = recent_seq_open, + .proc_read = seq_read, + .proc_write = recent_mt_proc_write, + .proc_release = seq_release_private, + .proc_lseek = seq_lseek, +}; + +static int __net_init recent_proc_net_init(struct net *net) +{ + struct recent_net *recent_net = recent_pernet(net); + + recent_net->xt_recent = proc_mkdir("xt_recent", net->proc_net); + if (!recent_net->xt_recent) + return -ENOMEM; + return 0; +} + +static void __net_exit recent_proc_net_exit(struct net *net) +{ + struct recent_net *recent_net = recent_pernet(net); + struct recent_table *t; + + /* recent_net_exit() is called before recent_mt_destroy(). Make sure + * that the parent xt_recent proc entry is empty before trying to + * remove it. 
+ */ + spin_lock_bh(&recent_lock); + list_for_each_entry(t, &recent_net->tables, list) + remove_proc_entry(t->name, recent_net->xt_recent); + + recent_net->xt_recent = NULL; + spin_unlock_bh(&recent_lock); + + remove_proc_entry("xt_recent", net->proc_net); +} +#else +static inline int recent_proc_net_init(struct net *net) +{ + return 0; +} + +static inline void recent_proc_net_exit(struct net *net) +{ +} +#endif /* CONFIG_PROC_FS */ + +static int __net_init recent_net_init(struct net *net) +{ + struct recent_net *recent_net = recent_pernet(net); + + INIT_LIST_HEAD(&recent_net->tables); + return recent_proc_net_init(net); +} + +static void __net_exit recent_net_exit(struct net *net) +{ + recent_proc_net_exit(net); +} + +static struct pernet_operations recent_net_ops = { + .init = recent_net_init, + .exit = recent_net_exit, + .id = &recent_net_id, + .size = sizeof(struct recent_net), +}; + +static struct xt_match recent_mt_reg[] __read_mostly = { + { + .name = "recent", + .revision = 0, + .family = NFPROTO_IPV4, + .match = recent_mt, + .matchsize = sizeof(struct xt_recent_mtinfo), + .checkentry = recent_mt_check_v0, + .destroy = recent_mt_destroy, + .me = THIS_MODULE, + }, + { + .name = "recent", + .revision = 0, + .family = NFPROTO_IPV6, + .match = recent_mt, + .matchsize = sizeof(struct xt_recent_mtinfo), + .checkentry = recent_mt_check_v0, + .destroy = recent_mt_destroy, + .me = THIS_MODULE, + }, + { + .name = "recent", + .revision = 1, + .family = NFPROTO_IPV4, + .match = recent_mt, + .matchsize = sizeof(struct xt_recent_mtinfo_v1), + .checkentry = recent_mt_check_v1, + .destroy = recent_mt_destroy, + .me = THIS_MODULE, + }, + { + .name = "recent", + .revision = 1, + .family = NFPROTO_IPV6, + .match = recent_mt, + .matchsize = sizeof(struct xt_recent_mtinfo_v1), + .checkentry = recent_mt_check_v1, + .destroy = recent_mt_destroy, + .me = THIS_MODULE, + } +}; + +static int __init recent_mt_init(void) +{ + int err; + + BUILD_BUG_ON_NOT_POWER_OF_2(XT_RECENT_MAX_NSTAMPS); + + if (!ip_list_tot || ip_pkt_list_tot >= XT_RECENT_MAX_NSTAMPS) + return -EINVAL; + ip_list_hash_size = 1 << fls(ip_list_tot); + + err = register_pernet_subsys(&recent_net_ops); + if (err) + return err; + err = xt_register_matches(recent_mt_reg, ARRAY_SIZE(recent_mt_reg)); + if (err) + unregister_pernet_subsys(&recent_net_ops); + return err; +} + +static void __exit recent_mt_exit(void) +{ + xt_unregister_matches(recent_mt_reg, ARRAY_SIZE(recent_mt_reg)); + unregister_pernet_subsys(&recent_net_ops); +} + +module_init(recent_mt_init); +module_exit(recent_mt_exit); diff --git a/net/netfilter/xt_repldata.h b/net/netfilter/xt_repldata.h new file mode 100644 index 000000000..68ccbe50b --- /dev/null +++ b/net/netfilter/xt_repldata.h @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Today's hack: quantum tunneling in structs + * + * 'entries' and 'term' are never anywhere referenced by word in code. In fact, + * they serve as the hanging-off data accessed through repl.data[]. 
+ */ + +/* tbl has the following structure equivalent, but is C99 compliant: + * struct { + * struct type##_replace repl; + * struct type##_standard entries[nhooks]; + * struct type##_error term; + * } *tbl; + */ + +#define xt_alloc_initial_table(type, typ2) ({ \ + unsigned int hook_mask = info->valid_hooks; \ + unsigned int nhooks = hweight32(hook_mask); \ + unsigned int bytes = 0, hooknum = 0, i = 0; \ + struct { \ + struct type##_replace repl; \ + struct type##_standard entries[]; \ + } *tbl; \ + struct type##_error *term; \ + size_t term_offset = (offsetof(typeof(*tbl), entries[nhooks]) + \ + __alignof__(*term) - 1) & ~(__alignof__(*term) - 1); \ + tbl = kzalloc(term_offset + sizeof(*term), GFP_KERNEL); \ + if (tbl == NULL) \ + return NULL; \ + term = (struct type##_error *)&(((char *)tbl)[term_offset]); \ + strncpy(tbl->repl.name, info->name, sizeof(tbl->repl.name)); \ + *term = (struct type##_error)typ2##_ERROR_INIT; \ + tbl->repl.valid_hooks = hook_mask; \ + tbl->repl.num_entries = nhooks + 1; \ + tbl->repl.size = nhooks * sizeof(struct type##_standard) + \ + sizeof(struct type##_error); \ + for (; hook_mask != 0; hook_mask >>= 1, ++hooknum) { \ + if (!(hook_mask & 1)) \ + continue; \ + tbl->repl.hook_entry[hooknum] = bytes; \ + tbl->repl.underflow[hooknum] = bytes; \ + tbl->entries[i++] = (struct type##_standard) \ + typ2##_STANDARD_INIT(NF_ACCEPT); \ + bytes += sizeof(struct type##_standard); \ + } \ + tbl; \ +}) diff --git a/net/netfilter/xt_sctp.c b/net/netfilter/xt_sctp.c new file mode 100644 index 000000000..d4bf089c9 --- /dev/null +++ b/net/netfilter/xt_sctp.c @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Kiran Kumar Immidi"); +MODULE_DESCRIPTION("Xtables: SCTP protocol packet match"); +MODULE_ALIAS("ipt_sctp"); +MODULE_ALIAS("ip6t_sctp"); + +#define SCCHECK(cond, option, flag, invflag) (!((flag) & (option)) \ + || (!!((invflag) & (option)) ^ (cond))) + +static bool +match_flags(const struct xt_sctp_flag_info *flag_info, + const int flag_count, + u_int8_t chunktype, + u_int8_t chunkflags) +{ + int i; + + for (i = 0; i < flag_count; i++) + if (flag_info[i].chunktype == chunktype) + return (chunkflags & flag_info[i].flag_mask) == flag_info[i].flag; + + return true; +} + +static inline bool +match_packet(const struct sk_buff *skb, + unsigned int offset, + const struct xt_sctp_info *info, + bool *hotdrop) +{ + u_int32_t chunkmapcopy[256 / sizeof (u_int32_t)]; + const struct sctp_chunkhdr *sch; + struct sctp_chunkhdr _sch; + int chunk_match_type = info->chunk_match_type; + const struct xt_sctp_flag_info *flag_info = info->flag_info; + int flag_count = info->flag_count; + +#ifdef DEBUG + int i = 0; +#endif + + if (chunk_match_type == SCTP_CHUNK_MATCH_ALL) + SCTP_CHUNKMAP_COPY(chunkmapcopy, info->chunkmap); + + do { + sch = skb_header_pointer(skb, offset, sizeof(_sch), &_sch); + if (sch == NULL || sch->length == 0) { + pr_debug("Dropping invalid SCTP packet.\n"); + *hotdrop = true; + return false; + } +#ifdef DEBUG + pr_debug("Chunk num: %d\toffset: %d\ttype: %d\tlength: %d" + "\tflags: %x\n", + ++i, offset, sch->type, htons(sch->length), + sch->flags); +#endif + offset += SCTP_PAD4(ntohs(sch->length)); + + pr_debug("skb->len: %d\toffset: %d\n", skb->len, offset); + + if (SCTP_CHUNKMAP_IS_SET(info->chunkmap, sch->type)) { + switch (chunk_match_type) { + case SCTP_CHUNK_MATCH_ANY: + 
if (match_flags(flag_info, flag_count, + sch->type, sch->flags)) { + return true; + } + break; + + case SCTP_CHUNK_MATCH_ALL: + if (match_flags(flag_info, flag_count, + sch->type, sch->flags)) + SCTP_CHUNKMAP_CLEAR(chunkmapcopy, sch->type); + break; + + case SCTP_CHUNK_MATCH_ONLY: + if (!match_flags(flag_info, flag_count, + sch->type, sch->flags)) + return false; + break; + } + } else { + switch (chunk_match_type) { + case SCTP_CHUNK_MATCH_ONLY: + return false; + } + } + } while (offset < skb->len); + + switch (chunk_match_type) { + case SCTP_CHUNK_MATCH_ALL: + return SCTP_CHUNKMAP_IS_CLEAR(chunkmapcopy); + case SCTP_CHUNK_MATCH_ANY: + return false; + case SCTP_CHUNK_MATCH_ONLY: + return true; + } + + /* This will never be reached, but required to stop compiler whine */ + return false; +} + +static bool +sctp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_sctp_info *info = par->matchinfo; + const struct sctphdr *sh; + struct sctphdr _sh; + + if (par->fragoff != 0) { + pr_debug("Dropping non-first fragment.. FIXME\n"); + return false; + } + + sh = skb_header_pointer(skb, par->thoff, sizeof(_sh), &_sh); + if (sh == NULL) { + pr_debug("Dropping evil TCP offset=0 tinygram.\n"); + par->hotdrop = true; + return false; + } + pr_debug("spt: %d\tdpt: %d\n", ntohs(sh->source), ntohs(sh->dest)); + + return SCCHECK(ntohs(sh->source) >= info->spts[0] + && ntohs(sh->source) <= info->spts[1], + XT_SCTP_SRC_PORTS, info->flags, info->invflags) && + SCCHECK(ntohs(sh->dest) >= info->dpts[0] + && ntohs(sh->dest) <= info->dpts[1], + XT_SCTP_DEST_PORTS, info->flags, info->invflags) && + SCCHECK(match_packet(skb, par->thoff + sizeof(_sh), + info, &par->hotdrop), + XT_SCTP_CHUNK_TYPES, info->flags, info->invflags); +} + +static int sctp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_sctp_info *info = par->matchinfo; + + if (info->flag_count > ARRAY_SIZE(info->flag_info)) + return -EINVAL; + if (info->flags & ~XT_SCTP_VALID_FLAGS) + return -EINVAL; + if (info->invflags & ~XT_SCTP_VALID_FLAGS) + return -EINVAL; + if (info->invflags & ~info->flags) + return -EINVAL; + if (!(info->flags & XT_SCTP_CHUNK_TYPES)) + return 0; + if (info->chunk_match_type & (SCTP_CHUNK_MATCH_ALL | + SCTP_CHUNK_MATCH_ANY | SCTP_CHUNK_MATCH_ONLY)) + return 0; + return -EINVAL; +} + +static struct xt_match sctp_mt_reg[] __read_mostly = { + { + .name = "sctp", + .family = NFPROTO_IPV4, + .checkentry = sctp_mt_check, + .match = sctp_mt, + .matchsize = sizeof(struct xt_sctp_info), + .proto = IPPROTO_SCTP, + .me = THIS_MODULE + }, + { + .name = "sctp", + .family = NFPROTO_IPV6, + .checkentry = sctp_mt_check, + .match = sctp_mt, + .matchsize = sizeof(struct xt_sctp_info), + .proto = IPPROTO_SCTP, + .me = THIS_MODULE + }, +}; + +static int __init sctp_mt_init(void) +{ + return xt_register_matches(sctp_mt_reg, ARRAY_SIZE(sctp_mt_reg)); +} + +static void __exit sctp_mt_exit(void) +{ + xt_unregister_matches(sctp_mt_reg, ARRAY_SIZE(sctp_mt_reg)); +} + +module_init(sctp_mt_init); +module_exit(sctp_mt_exit); diff --git a/net/netfilter/xt_set.c b/net/netfilter/xt_set.c new file mode 100644 index 000000000..731bc2caf --- /dev/null +++ b/net/netfilter/xt_set.c @@ -0,0 +1,712 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2000-2002 Joakim Axelsson + * Patrick Schaaf + * Martin Josefsson + * Copyright (C) 2003-2013 Jozsef Kadlecsik + */ + +/* Kernel module which implements the set match and SET target + * for netfilter/iptables. 
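+ *
+ * Illustrative usage only (this example is not part of the upstream
+ * comment; the set name "blocklist" is an arbitrary assumption):
+ *
+ *   ipset create blocklist hash:ip
+ *   iptables -A INPUT -m set --match-set blocklist src -j DROP
+ *   iptables -A OUTPUT -j SET --add-set blocklist dst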
+ */ + +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Jozsef Kadlecsik "); +MODULE_DESCRIPTION("Xtables: IP set match and target module"); +MODULE_ALIAS("xt_SET"); +MODULE_ALIAS("ipt_set"); +MODULE_ALIAS("ip6t_set"); +MODULE_ALIAS("ipt_SET"); +MODULE_ALIAS("ip6t_SET"); + +static inline int +match_set(ip_set_id_t index, const struct sk_buff *skb, + const struct xt_action_param *par, + struct ip_set_adt_opt *opt, int inv) +{ + if (ip_set_test(index, skb, par, opt)) + inv = !inv; + return inv; +} + +#define ADT_OPT(n, f, d, fs, cfs, t, p, b, po, bo) \ +struct ip_set_adt_opt n = { \ + .family = f, \ + .dim = d, \ + .flags = fs, \ + .cmdflags = cfs, \ + .ext.timeout = t, \ + .ext.packets = p, \ + .ext.bytes = b, \ + .ext.packets_op = po, \ + .ext.bytes_op = bo, \ +} + +/* Revision 0 interface: backward compatible with netfilter/iptables */ + +static bool +set_match_v0(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_set_info_match_v0 *info = par->matchinfo; + + ADT_OPT(opt, xt_family(par), info->match_set.u.compat.dim, + info->match_set.u.compat.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + return match_set(info->match_set.index, skb, par, &opt, + info->match_set.u.compat.flags & IPSET_INV_MATCH); +} + +static void +compat_flags(struct xt_set_info_v0 *info) +{ + u_int8_t i; + + /* Fill out compatibility data according to enum ip_set_kopt */ + info->u.compat.dim = IPSET_DIM_ZERO; + if (info->u.flags[0] & IPSET_MATCH_INV) + info->u.compat.flags |= IPSET_INV_MATCH; + for (i = 0; i < IPSET_DIM_MAX - 1 && info->u.flags[i]; i++) { + info->u.compat.dim++; + if (info->u.flags[i] & IPSET_SRC) + info->u.compat.flags |= (1 << info->u.compat.dim); + } +} + +static int +set_match_v0_checkentry(const struct xt_mtchk_param *par) +{ + struct xt_set_info_match_v0 *info = par->matchinfo; + ip_set_id_t index; + + index = ip_set_nfnl_get_byindex(par->net, info->match_set.index); + + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find set identified by id %u to match\n", + info->match_set.index); + return -ENOENT; + } + if (info->match_set.u.flags[IPSET_DIM_MAX - 1] != 0) { + pr_info_ratelimited("set match dimension is over the limit!\n"); + ip_set_nfnl_put(par->net, info->match_set.index); + return -ERANGE; + } + + /* Fill out compatibility data */ + compat_flags(&info->match_set); + + return 0; +} + +static void +set_match_v0_destroy(const struct xt_mtdtor_param *par) +{ + struct xt_set_info_match_v0 *info = par->matchinfo; + + ip_set_nfnl_put(par->net, info->match_set.index); +} + +/* Revision 1 match */ + +static bool +set_match_v1(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_set_info_match_v1 *info = par->matchinfo; + + ADT_OPT(opt, xt_family(par), info->match_set.dim, + info->match_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + if (opt.flags & IPSET_RETURN_NOMATCH) + opt.cmdflags |= IPSET_FLAG_RETURN_NOMATCH; + + return match_set(info->match_set.index, skb, par, &opt, + info->match_set.flags & IPSET_INV_MATCH); +} + +static int +set_match_v1_checkentry(const struct xt_mtchk_param *par) +{ + struct xt_set_info_match_v1 *info = par->matchinfo; + ip_set_id_t index; + + index = ip_set_nfnl_get_byindex(par->net, info->match_set.index); + + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find set identified by id %u to match\n", + info->match_set.index); + return -ENOENT; + } + if (info->match_set.dim > IPSET_DIM_MAX) { + pr_info_ratelimited("set match dimension is over the limit!\n"); + 
ip_set_nfnl_put(par->net, info->match_set.index); + return -ERANGE; + } + + return 0; +} + +static void +set_match_v1_destroy(const struct xt_mtdtor_param *par) +{ + struct xt_set_info_match_v1 *info = par->matchinfo; + + ip_set_nfnl_put(par->net, info->match_set.index); +} + +/* Revision 3 match */ + +static bool +set_match_v3(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_set_info_match_v3 *info = par->matchinfo; + + ADT_OPT(opt, xt_family(par), info->match_set.dim, + info->match_set.flags, info->flags, UINT_MAX, + info->packets.value, info->bytes.value, + info->packets.op, info->bytes.op); + + if (info->packets.op != IPSET_COUNTER_NONE || + info->bytes.op != IPSET_COUNTER_NONE) + opt.cmdflags |= IPSET_FLAG_MATCH_COUNTERS; + + return match_set(info->match_set.index, skb, par, &opt, + info->match_set.flags & IPSET_INV_MATCH); +} + +#define set_match_v3_checkentry set_match_v1_checkentry +#define set_match_v3_destroy set_match_v1_destroy + +/* Revision 4 match */ + +static bool +set_match_v4(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_set_info_match_v4 *info = par->matchinfo; + + ADT_OPT(opt, xt_family(par), info->match_set.dim, + info->match_set.flags, info->flags, UINT_MAX, + info->packets.value, info->bytes.value, + info->packets.op, info->bytes.op); + + if (info->packets.op != IPSET_COUNTER_NONE || + info->bytes.op != IPSET_COUNTER_NONE) + opt.cmdflags |= IPSET_FLAG_MATCH_COUNTERS; + + return match_set(info->match_set.index, skb, par, &opt, + info->match_set.flags & IPSET_INV_MATCH); +} + +#define set_match_v4_checkentry set_match_v1_checkentry +#define set_match_v4_destroy set_match_v1_destroy + +/* Revision 0 interface: backward compatible with netfilter/iptables */ + +static unsigned int +set_target_v0(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_set_info_target_v0 *info = par->targinfo; + + ADT_OPT(add_opt, xt_family(par), info->add_set.u.compat.dim, + info->add_set.u.compat.flags, 0, UINT_MAX, + 0, 0, 0, 0); + ADT_OPT(del_opt, xt_family(par), info->del_set.u.compat.dim, + info->del_set.u.compat.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_add(info->add_set.index, skb, par, &add_opt); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_del(info->del_set.index, skb, par, &del_opt); + + return XT_CONTINUE; +} + +static int +set_target_v0_checkentry(const struct xt_tgchk_param *par) +{ + struct xt_set_info_target_v0 *info = par->targinfo; + ip_set_id_t index; + + if (info->add_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, info->add_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find add_set index %u as target\n", + info->add_set.index); + return -ENOENT; + } + } + + if (info->del_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, info->del_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find del_set index %u as target\n", + info->del_set.index); + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + return -ENOENT; + } + } + if (info->add_set.u.flags[IPSET_DIM_MAX - 1] != 0 || + info->del_set.u.flags[IPSET_DIM_MAX - 1] != 0) { + pr_info_ratelimited("SET target dimension over the limit!\n"); + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); + 
return -ERANGE; + } + + /* Fill out compatibility data */ + compat_flags(&info->add_set); + compat_flags(&info->del_set); + + return 0; +} + +static void +set_target_v0_destroy(const struct xt_tgdtor_param *par) +{ + const struct xt_set_info_target_v0 *info = par->targinfo; + + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); +} + +/* Revision 1 target */ + +static unsigned int +set_target_v1(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_set_info_target_v1 *info = par->targinfo; + + ADT_OPT(add_opt, xt_family(par), info->add_set.dim, + info->add_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + ADT_OPT(del_opt, xt_family(par), info->del_set.dim, + info->del_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_add(info->add_set.index, skb, par, &add_opt); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_del(info->del_set.index, skb, par, &del_opt); + + return XT_CONTINUE; +} + +static int +set_target_v1_checkentry(const struct xt_tgchk_param *par) +{ + const struct xt_set_info_target_v1 *info = par->targinfo; + ip_set_id_t index; + + if (info->add_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, info->add_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find add_set index %u as target\n", + info->add_set.index); + return -ENOENT; + } + } + + if (info->del_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, info->del_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find del_set index %u as target\n", + info->del_set.index); + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + return -ENOENT; + } + } + if (info->add_set.dim > IPSET_DIM_MAX || + info->del_set.dim > IPSET_DIM_MAX) { + pr_info_ratelimited("SET target dimension over the limit!\n"); + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); + return -ERANGE; + } + + return 0; +} + +static void +set_target_v1_destroy(const struct xt_tgdtor_param *par) +{ + const struct xt_set_info_target_v1 *info = par->targinfo; + + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); +} + +/* Revision 2 target */ + +static unsigned int +set_target_v2(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_set_info_target_v2 *info = par->targinfo; + + ADT_OPT(add_opt, xt_family(par), info->add_set.dim, + info->add_set.flags, info->flags, info->timeout, + 0, 0, 0, 0); + ADT_OPT(del_opt, xt_family(par), info->del_set.dim, + info->del_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + /* Normalize to fit into jiffies */ + if (add_opt.ext.timeout != IPSET_NO_TIMEOUT && + add_opt.ext.timeout > IPSET_MAX_TIMEOUT) + add_opt.ext.timeout = IPSET_MAX_TIMEOUT; + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_add(info->add_set.index, skb, par, &add_opt); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_del(info->del_set.index, skb, par, &del_opt); + + return XT_CONTINUE; +} + +#define set_target_v2_checkentry set_target_v1_checkentry +#define set_target_v2_destroy set_target_v1_destroy + +/* Revision 
3 target */ + +#define MOPT(opt, member) ((opt).ext.skbinfo.member) + +static unsigned int +set_target_v3(struct sk_buff *skb, const struct xt_action_param *par) +{ + const struct xt_set_info_target_v3 *info = par->targinfo; + int ret; + + ADT_OPT(add_opt, xt_family(par), info->add_set.dim, + info->add_set.flags, info->flags, info->timeout, + 0, 0, 0, 0); + ADT_OPT(del_opt, xt_family(par), info->del_set.dim, + info->del_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + ADT_OPT(map_opt, xt_family(par), info->map_set.dim, + info->map_set.flags, 0, UINT_MAX, + 0, 0, 0, 0); + + /* Normalize to fit into jiffies */ + if (add_opt.ext.timeout != IPSET_NO_TIMEOUT && + add_opt.ext.timeout > IPSET_MAX_TIMEOUT) + add_opt.ext.timeout = IPSET_MAX_TIMEOUT; + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_add(info->add_set.index, skb, par, &add_opt); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_del(info->del_set.index, skb, par, &del_opt); + if (info->map_set.index != IPSET_INVALID_ID) { + map_opt.cmdflags |= info->flags & (IPSET_FLAG_MAP_SKBMARK | + IPSET_FLAG_MAP_SKBPRIO | + IPSET_FLAG_MAP_SKBQUEUE); + ret = match_set(info->map_set.index, skb, par, &map_opt, + info->map_set.flags & IPSET_INV_MATCH); + if (!ret) + return XT_CONTINUE; + if (map_opt.cmdflags & IPSET_FLAG_MAP_SKBMARK) + skb->mark = (skb->mark & ~MOPT(map_opt,skbmarkmask)) + ^ MOPT(map_opt, skbmark); + if (map_opt.cmdflags & IPSET_FLAG_MAP_SKBPRIO) + skb->priority = MOPT(map_opt, skbprio); + if ((map_opt.cmdflags & IPSET_FLAG_MAP_SKBQUEUE) && + skb->dev && + skb->dev->real_num_tx_queues > MOPT(map_opt, skbqueue)) + skb_set_queue_mapping(skb, MOPT(map_opt, skbqueue)); + } + return XT_CONTINUE; +} + +static int +set_target_v3_checkentry(const struct xt_tgchk_param *par) +{ + const struct xt_set_info_target_v3 *info = par->targinfo; + ip_set_id_t index; + int ret = 0; + + if (info->add_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, + info->add_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find add_set index %u as target\n", + info->add_set.index); + return -ENOENT; + } + } + + if (info->del_set.index != IPSET_INVALID_ID) { + index = ip_set_nfnl_get_byindex(par->net, + info->del_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find del_set index %u as target\n", + info->del_set.index); + ret = -ENOENT; + goto cleanup_add; + } + } + + if (info->map_set.index != IPSET_INVALID_ID) { + if (strncmp(par->table, "mangle", 7)) { + pr_info_ratelimited("--map-set only usable from mangle table\n"); + ret = -EINVAL; + goto cleanup_del; + } + if (((info->flags & IPSET_FLAG_MAP_SKBPRIO) | + (info->flags & IPSET_FLAG_MAP_SKBQUEUE)) && + (par->hook_mask & ~(1 << NF_INET_FORWARD | + 1 << NF_INET_LOCAL_OUT | + 1 << NF_INET_POST_ROUTING))) { + pr_info_ratelimited("mapping of prio or/and queue is allowed only from OUTPUT/FORWARD/POSTROUTING chains\n"); + ret = -EINVAL; + goto cleanup_del; + } + index = ip_set_nfnl_get_byindex(par->net, + info->map_set.index); + if (index == IPSET_INVALID_ID) { + pr_info_ratelimited("Cannot find map_set index %u as target\n", + info->map_set.index); + ret = -ENOENT; + goto cleanup_del; + } + } + + if (info->add_set.dim > IPSET_DIM_MAX || + info->del_set.dim > IPSET_DIM_MAX || + info->map_set.dim > IPSET_DIM_MAX) { + pr_info_ratelimited("SET target dimension over the limit!\n"); + ret = -ERANGE; + goto cleanup_mark; + } + + return 0; +cleanup_mark: + if (info->map_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, 
info->map_set.index); +cleanup_del: + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); +cleanup_add: + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + return ret; +} + +static void +set_target_v3_destroy(const struct xt_tgdtor_param *par) +{ + const struct xt_set_info_target_v3 *info = par->targinfo; + + if (info->add_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->add_set.index); + if (info->del_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->del_set.index); + if (info->map_set.index != IPSET_INVALID_ID) + ip_set_nfnl_put(par->net, info->map_set.index); +} + +static struct xt_match set_matches[] __read_mostly = { + { + .name = "set", + .family = NFPROTO_IPV4, + .revision = 0, + .match = set_match_v0, + .matchsize = sizeof(struct xt_set_info_match_v0), + .checkentry = set_match_v0_checkentry, + .destroy = set_match_v0_destroy, + .me = THIS_MODULE + }, + { + .name = "set", + .family = NFPROTO_IPV4, + .revision = 1, + .match = set_match_v1, + .matchsize = sizeof(struct xt_set_info_match_v1), + .checkentry = set_match_v1_checkentry, + .destroy = set_match_v1_destroy, + .me = THIS_MODULE + }, + { + .name = "set", + .family = NFPROTO_IPV6, + .revision = 1, + .match = set_match_v1, + .matchsize = sizeof(struct xt_set_info_match_v1), + .checkentry = set_match_v1_checkentry, + .destroy = set_match_v1_destroy, + .me = THIS_MODULE + }, + /* --return-nomatch flag support */ + { + .name = "set", + .family = NFPROTO_IPV4, + .revision = 2, + .match = set_match_v1, + .matchsize = sizeof(struct xt_set_info_match_v1), + .checkentry = set_match_v1_checkentry, + .destroy = set_match_v1_destroy, + .me = THIS_MODULE + }, + { + .name = "set", + .family = NFPROTO_IPV6, + .revision = 2, + .match = set_match_v1, + .matchsize = sizeof(struct xt_set_info_match_v1), + .checkentry = set_match_v1_checkentry, + .destroy = set_match_v1_destroy, + .me = THIS_MODULE + }, + /* counters support: update, match */ + { + .name = "set", + .family = NFPROTO_IPV4, + .revision = 3, + .match = set_match_v3, + .matchsize = sizeof(struct xt_set_info_match_v3), + .checkentry = set_match_v3_checkentry, + .destroy = set_match_v3_destroy, + .me = THIS_MODULE + }, + { + .name = "set", + .family = NFPROTO_IPV6, + .revision = 3, + .match = set_match_v3, + .matchsize = sizeof(struct xt_set_info_match_v3), + .checkentry = set_match_v3_checkentry, + .destroy = set_match_v3_destroy, + .me = THIS_MODULE + }, + /* new revision for counters support: update, match */ + { + .name = "set", + .family = NFPROTO_IPV4, + .revision = 4, + .match = set_match_v4, + .matchsize = sizeof(struct xt_set_info_match_v4), + .checkentry = set_match_v4_checkentry, + .destroy = set_match_v4_destroy, + .me = THIS_MODULE + }, + { + .name = "set", + .family = NFPROTO_IPV6, + .revision = 4, + .match = set_match_v4, + .matchsize = sizeof(struct xt_set_info_match_v4), + .checkentry = set_match_v4_checkentry, + .destroy = set_match_v4_destroy, + .me = THIS_MODULE + }, +}; + +static struct xt_target set_targets[] __read_mostly = { + { + .name = "SET", + .revision = 0, + .family = NFPROTO_IPV4, + .target = set_target_v0, + .targetsize = sizeof(struct xt_set_info_target_v0), + .checkentry = set_target_v0_checkentry, + .destroy = set_target_v0_destroy, + .me = THIS_MODULE + }, + { + .name = "SET", + .revision = 1, + .family = NFPROTO_IPV4, + .target = set_target_v1, + .targetsize = sizeof(struct xt_set_info_target_v1), + .checkentry = 
set_target_v1_checkentry, + .destroy = set_target_v1_destroy, + .me = THIS_MODULE + }, + { + .name = "SET", + .revision = 1, + .family = NFPROTO_IPV6, + .target = set_target_v1, + .targetsize = sizeof(struct xt_set_info_target_v1), + .checkentry = set_target_v1_checkentry, + .destroy = set_target_v1_destroy, + .me = THIS_MODULE + }, + /* --timeout and --exist flags support */ + { + .name = "SET", + .revision = 2, + .family = NFPROTO_IPV4, + .target = set_target_v2, + .targetsize = sizeof(struct xt_set_info_target_v2), + .checkentry = set_target_v2_checkentry, + .destroy = set_target_v2_destroy, + .me = THIS_MODULE + }, + { + .name = "SET", + .revision = 2, + .family = NFPROTO_IPV6, + .target = set_target_v2, + .targetsize = sizeof(struct xt_set_info_target_v2), + .checkentry = set_target_v2_checkentry, + .destroy = set_target_v2_destroy, + .me = THIS_MODULE + }, + /* --map-set support */ + { + .name = "SET", + .revision = 3, + .family = NFPROTO_IPV4, + .target = set_target_v3, + .targetsize = sizeof(struct xt_set_info_target_v3), + .checkentry = set_target_v3_checkentry, + .destroy = set_target_v3_destroy, + .me = THIS_MODULE + }, + { + .name = "SET", + .revision = 3, + .family = NFPROTO_IPV6, + .target = set_target_v3, + .targetsize = sizeof(struct xt_set_info_target_v3), + .checkentry = set_target_v3_checkentry, + .destroy = set_target_v3_destroy, + .me = THIS_MODULE + }, +}; + +static int __init xt_set_init(void) +{ + int ret = xt_register_matches(set_matches, ARRAY_SIZE(set_matches)); + + if (!ret) { + ret = xt_register_targets(set_targets, + ARRAY_SIZE(set_targets)); + if (ret) + xt_unregister_matches(set_matches, + ARRAY_SIZE(set_matches)); + } + return ret; +} + +static void __exit xt_set_fini(void) +{ + xt_unregister_matches(set_matches, ARRAY_SIZE(set_matches)); + xt_unregister_targets(set_targets, ARRAY_SIZE(set_targets)); +} + +module_init(xt_set_init); +module_exit(xt_set_fini); diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c new file mode 100644 index 000000000..76e01f292 --- /dev/null +++ b/net/netfilter/xt_socket.c @@ -0,0 +1,336 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Transparent proxy support for Linux/iptables + * + * Copyright (C) 2007-2008 BalaBit IT Ltd. + * Author: Krisztian Kovacs + */ +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +#include +#include +#include +#endif + +#include +#include + +/* "socket" match based redirection (no specific rule) + * =================================================== + * + * There are connections with dynamic endpoints (e.g. FTP data + * connection) that the user is unable to add explicit rules + * for. These are taken care of by a generic "socket" rule. It is + * assumed that the proxy application is trusted to open such + * connections without explicit iptables rule (except of course the + * generic 'socket' rule). In this case the following sockets are + * matched in preference order: + * + * - match: if there's a fully established connection matching the + * _packet_ tuple + * + * - match: if there's a non-zero bound listener (possibly with a + * non-local address) We don't accept zero-bound listeners, since + * then local services could intercept traffic going through the + * box. 
+ */ +static bool +socket_match(const struct sk_buff *skb, struct xt_action_param *par, + const struct xt_socket_mtinfo1 *info) +{ + struct sk_buff *pskb = (struct sk_buff *)skb; + struct sock *sk = skb->sk; + + if (sk && !net_eq(xt_net(par), sock_net(sk))) + sk = NULL; + + if (!sk) + sk = nf_sk_lookup_slow_v4(xt_net(par), skb, xt_in(par)); + + if (sk) { + bool wildcard; + bool transparent = true; + + /* Ignore sockets listening on INADDR_ANY, + * unless XT_SOCKET_NOWILDCARD is set + */ + wildcard = (!(info->flags & XT_SOCKET_NOWILDCARD) && + sk_fullsock(sk) && + inet_sk(sk)->inet_rcv_saddr == 0); + + /* Ignore non-transparent sockets, + * if XT_SOCKET_TRANSPARENT is used + */ + if (info->flags & XT_SOCKET_TRANSPARENT) + transparent = inet_sk_transparent(sk); + + if (info->flags & XT_SOCKET_RESTORESKMARK && !wildcard && + transparent && sk_fullsock(sk)) + pskb->mark = READ_ONCE(sk->sk_mark); + + if (sk != skb->sk) + sock_gen_put(sk); + + if (wildcard || !transparent) + sk = NULL; + } + + return sk != NULL; +} + +static bool +socket_mt4_v0(const struct sk_buff *skb, struct xt_action_param *par) +{ + static struct xt_socket_mtinfo1 xt_info_v0 = { + .flags = 0, + }; + + return socket_match(skb, par, &xt_info_v0); +} + +static bool +socket_mt4_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par) +{ + return socket_match(skb, par, par->matchinfo); +} + +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) +static bool +socket_mt6_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_socket_mtinfo1 *info = (struct xt_socket_mtinfo1 *) par->matchinfo; + struct sk_buff *pskb = (struct sk_buff *)skb; + struct sock *sk = skb->sk; + + if (sk && !net_eq(xt_net(par), sock_net(sk))) + sk = NULL; + + if (!sk) + sk = nf_sk_lookup_slow_v6(xt_net(par), skb, xt_in(par)); + + if (sk) { + bool wildcard; + bool transparent = true; + + /* Ignore sockets listening on INADDR_ANY + * unless XT_SOCKET_NOWILDCARD is set + */ + wildcard = (!(info->flags & XT_SOCKET_NOWILDCARD) && + sk_fullsock(sk) && + ipv6_addr_any(&sk->sk_v6_rcv_saddr)); + + /* Ignore non-transparent sockets, + * if XT_SOCKET_TRANSPARENT is used + */ + if (info->flags & XT_SOCKET_TRANSPARENT) + transparent = inet_sk_transparent(sk); + + if (info->flags & XT_SOCKET_RESTORESKMARK && !wildcard && + transparent && sk_fullsock(sk)) + pskb->mark = READ_ONCE(sk->sk_mark); + + if (sk != skb->sk) + sock_gen_put(sk); + + if (wildcard || !transparent) + sk = NULL; + } + + return sk != NULL; +} +#endif + +static int socket_mt_enable_defrag(struct net *net, int family) +{ + switch (family) { + case NFPROTO_IPV4: + return nf_defrag_ipv4_enable(net); +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + case NFPROTO_IPV6: + return nf_defrag_ipv6_enable(net); +#endif + } + WARN_ONCE(1, "Unknown family %d\n", family); + return 0; +} + +static int socket_mt_v1_check(const struct xt_mtchk_param *par) +{ + const struct xt_socket_mtinfo1 *info = (struct xt_socket_mtinfo1 *) par->matchinfo; + int err; + + err = socket_mt_enable_defrag(par->net, par->family); + if (err) + return err; + + if (info->flags & ~XT_SOCKET_FLAGS_V1) { + pr_info_ratelimited("unknown flags 0x%x\n", + info->flags & ~XT_SOCKET_FLAGS_V1); + return -EINVAL; + } + return 0; +} + +static int socket_mt_v2_check(const struct xt_mtchk_param *par) +{ + const struct xt_socket_mtinfo2 *info = (struct xt_socket_mtinfo2 *) par->matchinfo; + int err; + + err = socket_mt_enable_defrag(par->net, par->family); + if (err) + return err; + + if (info->flags & ~XT_SOCKET_FLAGS_V2) { + 
pr_info_ratelimited("unknown flags 0x%x\n", + info->flags & ~XT_SOCKET_FLAGS_V2); + return -EINVAL; + } + return 0; +} + +static int socket_mt_v3_check(const struct xt_mtchk_param *par) +{ + const struct xt_socket_mtinfo3 *info = + (struct xt_socket_mtinfo3 *)par->matchinfo; + int err; + + err = socket_mt_enable_defrag(par->net, par->family); + if (err) + return err; + if (info->flags & ~XT_SOCKET_FLAGS_V3) { + pr_info_ratelimited("unknown flags 0x%x\n", + info->flags & ~XT_SOCKET_FLAGS_V3); + return -EINVAL; + } + return 0; +} + +static void socket_mt_destroy(const struct xt_mtdtor_param *par) +{ + if (par->family == NFPROTO_IPV4) + nf_defrag_ipv4_disable(par->net); +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + else if (par->family == NFPROTO_IPV6) + nf_defrag_ipv6_disable(par->net); +#endif +} + +static struct xt_match socket_mt_reg[] __read_mostly = { + { + .name = "socket", + .revision = 0, + .family = NFPROTO_IPV4, + .match = socket_mt4_v0, + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, + { + .name = "socket", + .revision = 1, + .family = NFPROTO_IPV4, + .match = socket_mt4_v1_v2_v3, + .destroy = socket_mt_destroy, + .checkentry = socket_mt_v1_check, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "socket", + .revision = 1, + .family = NFPROTO_IPV6, + .match = socket_mt6_v1_v2_v3, + .checkentry = socket_mt_v1_check, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .destroy = socket_mt_destroy, + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#endif + { + .name = "socket", + .revision = 2, + .family = NFPROTO_IPV4, + .match = socket_mt4_v1_v2_v3, + .checkentry = socket_mt_v2_check, + .destroy = socket_mt_destroy, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "socket", + .revision = 2, + .family = NFPROTO_IPV6, + .match = socket_mt6_v1_v2_v3, + .checkentry = socket_mt_v2_check, + .destroy = socket_mt_destroy, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#endif + { + .name = "socket", + .revision = 3, + .family = NFPROTO_IPV4, + .match = socket_mt4_v1_v2_v3, + .checkentry = socket_mt_v3_check, + .destroy = socket_mt_destroy, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#if IS_ENABLED(CONFIG_IP6_NF_IPTABLES) + { + .name = "socket", + .revision = 3, + .family = NFPROTO_IPV6, + .match = socket_mt6_v1_v2_v3, + .checkentry = socket_mt_v3_check, + .destroy = socket_mt_destroy, + .matchsize = sizeof(struct xt_socket_mtinfo1), + .hooks = (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN), + .me = THIS_MODULE, + }, +#endif +}; + +static int __init socket_mt_init(void) +{ + return xt_register_matches(socket_mt_reg, ARRAY_SIZE(socket_mt_reg)); +} + +static void __exit socket_mt_exit(void) +{ + xt_unregister_matches(socket_mt_reg, ARRAY_SIZE(socket_mt_reg)); +} + +module_init(socket_mt_init); +module_exit(socket_mt_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Krisztian Kovacs, Balazs Scheidler"); +MODULE_DESCRIPTION("x_tables socket match module"); +MODULE_ALIAS("ipt_socket"); +MODULE_ALIAS("ip6t_socket"); diff --git 
a/net/netfilter/xt_state.c b/net/netfilter/xt_state.c new file mode 100644 index 000000000..bbe07b1be --- /dev/null +++ b/net/netfilter/xt_state.c @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match connection tracking information. */ + +/* (C) 1999-2001 Paul `Rusty' Russell + * (C) 2002-2005 Netfilter Core Team + */ + +#include +#include +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Rusty Russell "); +MODULE_DESCRIPTION("ip[6]_tables connection tracking state match module"); +MODULE_ALIAS("ipt_state"); +MODULE_ALIAS("ip6t_state"); + +static bool +state_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_state_info *sinfo = par->matchinfo; + enum ip_conntrack_info ctinfo; + unsigned int statebit; + struct nf_conn *ct = nf_ct_get(skb, &ctinfo); + + if (ct) + statebit = XT_STATE_BIT(ctinfo); + else if (ctinfo == IP_CT_UNTRACKED) + statebit = XT_STATE_UNTRACKED; + else + statebit = XT_STATE_INVALID; + + return (sinfo->statemask & statebit); +} + +static int state_mt_check(const struct xt_mtchk_param *par) +{ + int ret; + + ret = nf_ct_netns_get(par->net, par->family); + if (ret < 0) + pr_info_ratelimited("cannot load conntrack support for proto=%u\n", + par->family); + return ret; +} + +static void state_mt_destroy(const struct xt_mtdtor_param *par) +{ + nf_ct_netns_put(par->net, par->family); +} + +static struct xt_match state_mt_reg __read_mostly = { + .name = "state", + .family = NFPROTO_UNSPEC, + .checkentry = state_mt_check, + .match = state_mt, + .destroy = state_mt_destroy, + .matchsize = sizeof(struct xt_state_info), + .me = THIS_MODULE, +}; + +static int __init state_mt_init(void) +{ + return xt_register_match(&state_mt_reg); +} + +static void __exit state_mt_exit(void) +{ + xt_unregister_match(&state_mt_reg); +} + +module_init(state_mt_init); +module_exit(state_mt_exit); diff --git a/net/netfilter/xt_statistic.c b/net/netfilter/xt_statistic.c new file mode 100644 index 000000000..b26c1dcfc --- /dev/null +++ b/net/netfilter/xt_statistic.c @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2006 Patrick McHardy + * + * Based on ipt_random and ipt_nth by Fabrice MARIE . + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +struct xt_statistic_priv { + atomic_t count; +} ____cacheline_aligned_in_smp; + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("Xtables: statistics-based matching (\"Nth\", random)"); +MODULE_ALIAS("ipt_statistic"); +MODULE_ALIAS("ip6t_statistic"); + +static bool +statistic_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_statistic_info *info = par->matchinfo; + bool ret = info->flags & XT_STATISTIC_INVERT; + int nval, oval; + + switch (info->mode) { + case XT_STATISTIC_MODE_RANDOM: + if ((get_random_u32() & 0x7FFFFFFF) < info->u.random.probability) + ret = !ret; + break; + case XT_STATISTIC_MODE_NTH: + do { + oval = atomic_read(&info->master->count); + nval = (oval == info->u.nth.every) ? 
0 : oval + 1; + } while (atomic_cmpxchg(&info->master->count, oval, nval) != oval); + if (nval == 0) + ret = !ret; + break; + } + + return ret; +} + +static int statistic_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_statistic_info *info = par->matchinfo; + + if (info->mode > XT_STATISTIC_MODE_MAX || + info->flags & ~XT_STATISTIC_MASK) + return -EINVAL; + + info->master = kzalloc(sizeof(*info->master), GFP_KERNEL); + if (info->master == NULL) + return -ENOMEM; + atomic_set(&info->master->count, info->u.nth.count); + + return 0; +} + +static void statistic_mt_destroy(const struct xt_mtdtor_param *par) +{ + const struct xt_statistic_info *info = par->matchinfo; + + kfree(info->master); +} + +static struct xt_match xt_statistic_mt_reg __read_mostly = { + .name = "statistic", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = statistic_mt, + .checkentry = statistic_mt_check, + .destroy = statistic_mt_destroy, + .matchsize = sizeof(struct xt_statistic_info), + .usersize = offsetof(struct xt_statistic_info, master), + .me = THIS_MODULE, +}; + +static int __init statistic_mt_init(void) +{ + return xt_register_match(&xt_statistic_mt_reg); +} + +static void __exit statistic_mt_exit(void) +{ + xt_unregister_match(&xt_statistic_mt_reg); +} + +module_init(statistic_mt_init); +module_exit(statistic_mt_exit); diff --git a/net/netfilter/xt_string.c b/net/netfilter/xt_string.c new file mode 100644 index 000000000..8ce25bc9b --- /dev/null +++ b/net/netfilter/xt_string.c @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* String matching match for iptables + * + * (C) 2005 Pablo Neira Ayuso + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_AUTHOR("Pablo Neira Ayuso "); +MODULE_DESCRIPTION("Xtables: string-based matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_string"); +MODULE_ALIAS("ip6t_string"); +MODULE_ALIAS("ebt_string"); + +static bool +string_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_string_info *conf = par->matchinfo; + bool invert; + + invert = conf->u.v1.flags & XT_STRING_FLAG_INVERT; + + return (skb_find_text((struct sk_buff *)skb, conf->from_offset, + conf->to_offset, conf->config) + != UINT_MAX) ^ invert; +} + +#define STRING_TEXT_PRIV(m) ((struct xt_string_info *)(m)) + +static int string_mt_check(const struct xt_mtchk_param *par) +{ + struct xt_string_info *conf = par->matchinfo; + struct ts_config *ts_conf; + int flags = TS_AUTOLOAD; + + /* Damn, can't handle this case properly with iptables... 
*/ + if (conf->from_offset > conf->to_offset) + return -EINVAL; + if (conf->algo[XT_STRING_MAX_ALGO_NAME_SIZE - 1] != '\0') + return -EINVAL; + if (conf->patlen > XT_STRING_MAX_PATTERN_SIZE) + return -EINVAL; + if (conf->u.v1.flags & + ~(XT_STRING_FLAG_IGNORECASE | XT_STRING_FLAG_INVERT)) + return -EINVAL; + if (conf->u.v1.flags & XT_STRING_FLAG_IGNORECASE) + flags |= TS_IGNORECASE; + ts_conf = textsearch_prepare(conf->algo, conf->pattern, conf->patlen, + GFP_KERNEL, flags); + if (IS_ERR(ts_conf)) + return PTR_ERR(ts_conf); + + conf->config = ts_conf; + return 0; +} + +static void string_mt_destroy(const struct xt_mtdtor_param *par) +{ + textsearch_destroy(STRING_TEXT_PRIV(par->matchinfo)->config); +} + +static struct xt_match xt_string_mt_reg __read_mostly = { + .name = "string", + .revision = 1, + .family = NFPROTO_UNSPEC, + .checkentry = string_mt_check, + .match = string_mt, + .destroy = string_mt_destroy, + .matchsize = sizeof(struct xt_string_info), + .usersize = offsetof(struct xt_string_info, config), + .me = THIS_MODULE, +}; + +static int __init string_mt_init(void) +{ + return xt_register_match(&xt_string_mt_reg); +} + +static void __exit string_mt_exit(void) +{ + xt_unregister_match(&xt_string_mt_reg); +} + +module_init(string_mt_init); +module_exit(string_mt_exit); diff --git a/net/netfilter/xt_tcpmss.c b/net/netfilter/xt_tcpmss.c new file mode 100644 index 000000000..37704ab01 --- /dev/null +++ b/net/netfilter/xt_tcpmss.c @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Kernel module to match TCP MSS values. */ + +/* Copyright (C) 2000 Marc Boucher + * Portions (C) 2005 by Harald Welte + */ + +#include +#include +#include + +#include +#include + +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Marc Boucher "); +MODULE_DESCRIPTION("Xtables: TCP MSS match"); +MODULE_ALIAS("ipt_tcpmss"); +MODULE_ALIAS("ip6t_tcpmss"); + +static bool +tcpmss_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_tcpmss_match_info *info = par->matchinfo; + const struct tcphdr *th; + struct tcphdr _tcph; + /* tcp.doff is only 4 bits, ie. max 15 * 4 bytes */ + const u_int8_t *op; + u8 _opt[15 * 4 - sizeof(_tcph)]; + unsigned int i, optlen; + + /* If we don't have the whole header, drop packet. */ + th = skb_header_pointer(skb, par->thoff, sizeof(_tcph), &_tcph); + if (th == NULL) + goto dropit; + + /* Malformed. */ + if (th->doff*4 < sizeof(*th)) + goto dropit; + + optlen = th->doff*4 - sizeof(*th); + if (!optlen) + goto out; + + /* Truncated options. */ + op = skb_header_pointer(skb, par->thoff + sizeof(*th), optlen, _opt); + if (op == NULL) + goto dropit; + + for (i = 0; i < optlen; ) { + if (op[i] == TCPOPT_MSS + && (optlen - i) >= TCPOLEN_MSS + && op[i+1] == TCPOLEN_MSS) { + u_int16_t mssval; + + mssval = (op[i+2] << 8) | op[i+3]; + + return (mssval >= info->mss_min && + mssval <= info->mss_max) ^ info->invert; + } + if (op[i] < 2) + i++; + else + i += op[i+1] ? 
: 1; + } +out: + return info->invert; + +dropit: + par->hotdrop = true; + return false; +} + +static struct xt_match tcpmss_mt_reg[] __read_mostly = { + { + .name = "tcpmss", + .family = NFPROTO_IPV4, + .match = tcpmss_mt, + .matchsize = sizeof(struct xt_tcpmss_match_info), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, + { + .name = "tcpmss", + .family = NFPROTO_IPV6, + .match = tcpmss_mt, + .matchsize = sizeof(struct xt_tcpmss_match_info), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, +}; + +static int __init tcpmss_mt_init(void) +{ + return xt_register_matches(tcpmss_mt_reg, ARRAY_SIZE(tcpmss_mt_reg)); +} + +static void __exit tcpmss_mt_exit(void) +{ + xt_unregister_matches(tcpmss_mt_reg, ARRAY_SIZE(tcpmss_mt_reg)); +} + +module_init(tcpmss_mt_init); +module_exit(tcpmss_mt_exit); diff --git a/net/netfilter/xt_tcpudp.c b/net/netfilter/xt_tcpudp.c new file mode 100644 index 000000000..11ec2abf0 --- /dev/null +++ b/net/netfilter/xt_tcpudp.c @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: GPL-2.0-only +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +MODULE_DESCRIPTION("Xtables: TCP, UDP and UDP-Lite match"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("xt_tcp"); +MODULE_ALIAS("xt_udp"); +MODULE_ALIAS("ipt_udp"); +MODULE_ALIAS("ipt_tcp"); +MODULE_ALIAS("ip6t_udp"); +MODULE_ALIAS("ip6t_tcp"); + +/* Returns 1 if the port is matched by the range, 0 otherwise */ +static inline bool +port_match(u_int16_t min, u_int16_t max, u_int16_t port, bool invert) +{ + return (port >= min && port <= max) ^ invert; +} + +static bool +tcp_find_option(u_int8_t option, + const struct sk_buff *skb, + unsigned int protoff, + unsigned int optlen, + bool invert, + bool *hotdrop) +{ + /* tcp.doff is only 4 bits, ie. max 15 * 4 bytes */ + const u_int8_t *op; + u_int8_t _opt[60 - sizeof(struct tcphdr)]; + unsigned int i; + + pr_debug("finding option\n"); + + if (!optlen) + return invert; + + /* If we don't have the whole header, drop packet. */ + op = skb_header_pointer(skb, protoff + sizeof(struct tcphdr), + optlen, _opt); + if (op == NULL) { + *hotdrop = true; + return false; + } + + for (i = 0; i < optlen; ) { + if (op[i] == option) return !invert; + if (op[i] < 2) i++; + else i += op[i+1]?:1; + } + + return invert; +} + +static bool tcp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct tcphdr *th; + struct tcphdr _tcph; + const struct xt_tcp *tcpinfo = par->matchinfo; + + if (par->fragoff != 0) { + /* To quote Alan: + + Don't allow a fragment of TCP 8 bytes in. Nobody normal + causes this. Its a cracker trying to break in by doing a + flag overwrite to pass the direction checks. + */ + if (par->fragoff == 1) { + pr_debug("Dropping evil TCP offset=1 frag.\n"); + par->hotdrop = true; + } + /* Must not be a fragment. */ + return false; + } + + th = skb_header_pointer(skb, par->thoff, sizeof(_tcph), &_tcph); + if (th == NULL) { + /* We've been asked to examine this packet, and we + can't. Hence, no choice but to drop. 
*/ + pr_debug("Dropping evil TCP offset=0 tinygram.\n"); + par->hotdrop = true; + return false; + } + + if (!port_match(tcpinfo->spts[0], tcpinfo->spts[1], + ntohs(th->source), + !!(tcpinfo->invflags & XT_TCP_INV_SRCPT))) + return false; + if (!port_match(tcpinfo->dpts[0], tcpinfo->dpts[1], + ntohs(th->dest), + !!(tcpinfo->invflags & XT_TCP_INV_DSTPT))) + return false; + if (!NF_INVF(tcpinfo, XT_TCP_INV_FLAGS, + (((unsigned char *)th)[13] & tcpinfo->flg_mask) == tcpinfo->flg_cmp)) + return false; + if (tcpinfo->option) { + if (th->doff * 4 < sizeof(_tcph)) { + par->hotdrop = true; + return false; + } + if (!tcp_find_option(tcpinfo->option, skb, par->thoff, + th->doff*4 - sizeof(_tcph), + tcpinfo->invflags & XT_TCP_INV_OPTION, + &par->hotdrop)) + return false; + } + return true; +} + +static int tcp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_tcp *tcpinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + return (tcpinfo->invflags & ~XT_TCP_INV_MASK) ? -EINVAL : 0; +} + +static bool udp_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct udphdr *uh; + struct udphdr _udph; + const struct xt_udp *udpinfo = par->matchinfo; + + /* Must not be a fragment. */ + if (par->fragoff != 0) + return false; + + uh = skb_header_pointer(skb, par->thoff, sizeof(_udph), &_udph); + if (uh == NULL) { + /* We've been asked to examine this packet, and we + can't. Hence, no choice but to drop. */ + pr_debug("Dropping evil UDP tinygram.\n"); + par->hotdrop = true; + return false; + } + + return port_match(udpinfo->spts[0], udpinfo->spts[1], + ntohs(uh->source), + !!(udpinfo->invflags & XT_UDP_INV_SRCPT)) + && port_match(udpinfo->dpts[0], udpinfo->dpts[1], + ntohs(uh->dest), + !!(udpinfo->invflags & XT_UDP_INV_DSTPT)); +} + +static int udp_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_udp *udpinfo = par->matchinfo; + + /* Must specify no unknown invflags */ + return (udpinfo->invflags & ~XT_UDP_INV_MASK) ? 
-EINVAL : 0; +} + +static struct xt_match tcpudp_mt_reg[] __read_mostly = { + { + .name = "tcp", + .family = NFPROTO_IPV4, + .checkentry = tcp_mt_check, + .match = tcp_mt, + .matchsize = sizeof(struct xt_tcp), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, + { + .name = "tcp", + .family = NFPROTO_IPV6, + .checkentry = tcp_mt_check, + .match = tcp_mt, + .matchsize = sizeof(struct xt_tcp), + .proto = IPPROTO_TCP, + .me = THIS_MODULE, + }, + { + .name = "udp", + .family = NFPROTO_IPV4, + .checkentry = udp_mt_check, + .match = udp_mt, + .matchsize = sizeof(struct xt_udp), + .proto = IPPROTO_UDP, + .me = THIS_MODULE, + }, + { + .name = "udp", + .family = NFPROTO_IPV6, + .checkentry = udp_mt_check, + .match = udp_mt, + .matchsize = sizeof(struct xt_udp), + .proto = IPPROTO_UDP, + .me = THIS_MODULE, + }, + { + .name = "udplite", + .family = NFPROTO_IPV4, + .checkentry = udp_mt_check, + .match = udp_mt, + .matchsize = sizeof(struct xt_udp), + .proto = IPPROTO_UDPLITE, + .me = THIS_MODULE, + }, + { + .name = "udplite", + .family = NFPROTO_IPV6, + .checkentry = udp_mt_check, + .match = udp_mt, + .matchsize = sizeof(struct xt_udp), + .proto = IPPROTO_UDPLITE, + .me = THIS_MODULE, + }, +}; + +static int __init tcpudp_mt_init(void) +{ + return xt_register_matches(tcpudp_mt_reg, ARRAY_SIZE(tcpudp_mt_reg)); +} + +static void __exit tcpudp_mt_exit(void) +{ + xt_unregister_matches(tcpudp_mt_reg, ARRAY_SIZE(tcpudp_mt_reg)); +} + +module_init(tcpudp_mt_init); +module_exit(tcpudp_mt_exit); diff --git a/net/netfilter/xt_time.c b/net/netfilter/xt_time.c new file mode 100644 index 000000000..6aa12d0f5 --- /dev/null +++ b/net/netfilter/xt_time.c @@ -0,0 +1,300 @@ +/* + * xt_time + * Copyright © CC Computer Consultants GmbH, 2007 + * + * based on ipt_time by Fabrice MARIE + * This is a module which is used for time matching + * It is using some modified code from dietlibc (localtime() function) + * that you can find at https://www.fefe.de/dietlibc/ + * This file is distributed under the terms of the GNU General Public + * License (GPL). Copies of the GPL can be obtained from gnu.org/gpl. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include + +struct xtm { + u_int8_t month; /* (1-12) */ + u_int8_t monthday; /* (1-31) */ + u_int8_t weekday; /* (1-7) */ + u_int8_t hour; /* (0-23) */ + u_int8_t minute; /* (0-59) */ + u_int8_t second; /* (0-59) */ + unsigned int dse; +}; + +extern struct timezone sys_tz; /* ouch */ + +static const u_int16_t days_since_year[] = { + 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, +}; + +static const u_int16_t days_since_leapyear[] = { + 0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, +}; + +/* + * Since time progresses forward, it is best to organize this array in reverse, + * to minimize lookup time. 
+ */ +enum { + DSE_FIRST = 2039, + SECONDS_PER_DAY = 86400, +}; +static const u_int16_t days_since_epoch[] = { + /* 2039 - 2030 */ + 25202, 24837, 24472, 24106, 23741, 23376, 23011, 22645, 22280, 21915, + /* 2029 - 2020 */ + 21550, 21184, 20819, 20454, 20089, 19723, 19358, 18993, 18628, 18262, + /* 2019 - 2010 */ + 17897, 17532, 17167, 16801, 16436, 16071, 15706, 15340, 14975, 14610, + /* 2009 - 2000 */ + 14245, 13879, 13514, 13149, 12784, 12418, 12053, 11688, 11323, 10957, + /* 1999 - 1990 */ + 10592, 10227, 9862, 9496, 9131, 8766, 8401, 8035, 7670, 7305, + /* 1989 - 1980 */ + 6940, 6574, 6209, 5844, 5479, 5113, 4748, 4383, 4018, 3652, + /* 1979 - 1970 */ + 3287, 2922, 2557, 2191, 1826, 1461, 1096, 730, 365, 0, +}; + +static inline bool is_leap(unsigned int y) +{ + return y % 4 == 0 && (y % 100 != 0 || y % 400 == 0); +} + +/* + * Each network packet has a (nano)seconds-since-the-epoch (SSTE) timestamp. + * Since we match against days and daytime, the SSTE value needs to be + * computed back into human-readable dates. + * + * This is done in three separate functions so that the most expensive + * calculations are done last, in case a "simple match" can be found earlier. + */ +static inline unsigned int localtime_1(struct xtm *r, time64_t time) +{ + unsigned int v, w; + + /* Each day has 86400s, so finding the hour/minute is actually easy. */ + div_u64_rem(time, SECONDS_PER_DAY, &v); + r->second = v % 60; + w = v / 60; + r->minute = w % 60; + r->hour = w / 60; + return v; +} + +static inline void localtime_2(struct xtm *r, time64_t time) +{ + /* + * Here comes the rest (weekday, monthday). First, divide the SSTE + * by seconds-per-day to get the number of _days_ since the epoch. + */ + r->dse = div_u64(time, SECONDS_PER_DAY); + + /* + * 1970-01-01 (w=0) was a Thursday (4). + * -1 and +1 map Sunday properly onto 7. + */ + r->weekday = (4 + r->dse - 1) % 7 + 1; +} + +static void localtime_3(struct xtm *r, time64_t time) +{ + unsigned int year, i, w = r->dse; + + /* + * In each year, a certain number of days-since-the-epoch have passed. + * Find the year that is closest to said days. + * + * Consider, for example, w=21612 (2029-03-04). Loop will abort on + * dse[i] <= w, which happens when dse[i] == 21550. This implies + * year == 2009. w will then be 62. + */ + for (i = 0, year = DSE_FIRST; days_since_epoch[i] > w; + ++i, --year) + /* just loop */; + + w -= days_since_epoch[i]; + + /* + * By now we have the current year, and the day of the year. + * r->yearday = w; + * + * On to finding the month (like above). In each month, a certain + * number of days-since-New Year have passed, and find the closest + * one. + * + * Consider w=62 (in a non-leap year). Loop will abort on + * dsy[i] < w, which happens when dsy[i] == 31+28 (i == 2). + * Concludes i == 2, i.e. 3rd month => March. + * + * (A different approach to use would be to subtract a monthlength + * from w repeatedly while counting.) 
+ */ + if (is_leap(year)) { + /* use days_since_leapyear[] in a leap year */ + for (i = ARRAY_SIZE(days_since_leapyear) - 1; + i > 0 && days_since_leapyear[i] > w; --i) + /* just loop */; + r->monthday = w - days_since_leapyear[i] + 1; + } else { + for (i = ARRAY_SIZE(days_since_year) - 1; + i > 0 && days_since_year[i] > w; --i) + /* just loop */; + r->monthday = w - days_since_year[i] + 1; + } + + r->month = i + 1; +} + +static bool +time_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_time_info *info = par->matchinfo; + unsigned int packet_time; + struct xtm current_time; + time64_t stamp; + + /* + * We need real time here, but we can neither use skb->tstamp + * nor __net_timestamp(). + * + * skb->tstamp and skb->skb_mstamp_ns overlap, however, they + * use different clock types (real vs monotonic). + * + * Suppose you have two rules: + * 1. match before 13:00 + * 2. match after 13:00 + * + * If you match against processing time (ktime_get_real_seconds) it + * may happen that the same packet matches both rules if + * it arrived at the right moment before 13:00, so it would be + * better to check skb->tstamp and set it via __net_timestamp() + * if needed. This however breaks outgoing packets tx timestamp, + * and causes them to get delayed forever by fq packet scheduler. + */ + stamp = ktime_get_real_seconds(); + + if (info->flags & XT_TIME_LOCAL_TZ) + /* Adjust for local timezone */ + stamp -= 60 * sys_tz.tz_minuteswest; + + /* + * xt_time will match when _all_ of the following hold: + * - 'now' is in the global time range date_start..date_end + * - 'now' is in the monthday mask + * - 'now' is in the weekday mask + * - 'now' is in the daytime range time_start..time_end + * (and by default, libxt_time will set these so as to match) + * + * note: info->date_start/stop are unsigned 32-bit values that + * can hold values beyond y2038, but not after y2106. + */ + + if (stamp < info->date_start || stamp > info->date_stop) + return false; + + packet_time = localtime_1(¤t_time, stamp); + + if (info->daytime_start < info->daytime_stop) { + if (packet_time < info->daytime_start || + packet_time > info->daytime_stop) + return false; + } else { + if (packet_time < info->daytime_start && + packet_time > info->daytime_stop) + return false; + + /** if user asked to ignore 'next day', then e.g. + * '1 PM Wed, August 1st' should be treated + * like 'Tue 1 PM July 31st'. + * + * This also causes + * 'Monday, "23:00 to 01:00", to match for 2 hours, starting + * Monday 23:00 to Tuesday 01:00. 
+ */ + if ((info->flags & XT_TIME_CONTIGUOUS) && + packet_time <= info->daytime_stop) + stamp -= SECONDS_PER_DAY; + } + + localtime_2(¤t_time, stamp); + + if (!(info->weekdays_match & (1 << current_time.weekday))) + return false; + + /* Do not spend time computing monthday if all days match anyway */ + if (info->monthdays_match != XT_TIME_ALL_MONTHDAYS) { + localtime_3(¤t_time, stamp); + if (!(info->monthdays_match & (1 << current_time.monthday))) + return false; + } + + return true; +} + +static int time_mt_check(const struct xt_mtchk_param *par) +{ + const struct xt_time_info *info = par->matchinfo; + + if (info->daytime_start > XT_TIME_MAX_DAYTIME || + info->daytime_stop > XT_TIME_MAX_DAYTIME) { + pr_info_ratelimited("invalid argument - start or stop time greater than 23:59:59\n"); + return -EDOM; + } + + if (info->flags & ~XT_TIME_ALL_FLAGS) { + pr_info_ratelimited("unknown flags 0x%x\n", + info->flags & ~XT_TIME_ALL_FLAGS); + return -EINVAL; + } + + if ((info->flags & XT_TIME_CONTIGUOUS) && + info->daytime_start < info->daytime_stop) + return -EINVAL; + + return 0; +} + +static struct xt_match xt_time_mt_reg __read_mostly = { + .name = "time", + .family = NFPROTO_UNSPEC, + .match = time_mt, + .checkentry = time_mt_check, + .matchsize = sizeof(struct xt_time_info), + .me = THIS_MODULE, +}; + +static int __init time_mt_init(void) +{ + int minutes = sys_tz.tz_minuteswest; + + if (minutes < 0) /* east of Greenwich */ + pr_info("kernel timezone is +%02d%02d\n", + -minutes / 60, -minutes % 60); + else /* west of Greenwich */ + pr_info("kernel timezone is -%02d%02d\n", + minutes / 60, minutes % 60); + + return xt_register_match(&xt_time_mt_reg); +} + +static void __exit time_mt_exit(void) +{ + xt_unregister_match(&xt_time_mt_reg); +} + +module_init(time_mt_init); +module_exit(time_mt_exit); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: time-based matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_time"); +MODULE_ALIAS("ip6t_time"); diff --git a/net/netfilter/xt_u32.c b/net/netfilter/xt_u32.c new file mode 100644 index 000000000..117d4615d --- /dev/null +++ b/net/netfilter/xt_u32.c @@ -0,0 +1,145 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xt_u32 - kernel module to match u32 packet content + * + * Original author: Don Cohen + * (C) CC Computer Consultants GmbH, 2007 + */ + +#include +#include +#include +#include +#include +#include +#include + +static bool u32_match_it(const struct xt_u32 *data, + const struct sk_buff *skb) +{ + const struct xt_u32_test *ct; + unsigned int testind; + unsigned int nnums; + unsigned int nvals; + unsigned int i; + __be32 n; + u_int32_t pos; + u_int32_t val; + u_int32_t at; + + /* + * Small example: "0 >> 28 == 4 && 8 & 0xFF0000 >> 16 = 6, 17" + * (=IPv4 and (TCP or UDP)). Outer loop runs over the "&&" operands. 
+ */ + for (testind = 0; testind < data->ntests; ++testind) { + ct = &data->tests[testind]; + at = 0; + pos = ct->location[0].number; + + if (skb->len < 4 || pos > skb->len - 4) + return false; + + if (skb_copy_bits(skb, pos, &n, sizeof(n)) < 0) + BUG(); + val = ntohl(n); + nnums = ct->nnums; + + /* Inner loop runs over "&", "<<", ">>" and "@" operands */ + for (i = 1; i < nnums; ++i) { + u_int32_t number = ct->location[i].number; + switch (ct->location[i].nextop) { + case XT_U32_AND: + val &= number; + break; + case XT_U32_LEFTSH: + val <<= number; + break; + case XT_U32_RIGHTSH: + val >>= number; + break; + case XT_U32_AT: + if (at + val < at) + return false; + at += val; + pos = number; + if (at + 4 < at || skb->len < at + 4 || + pos > skb->len - at - 4) + return false; + + if (skb_copy_bits(skb, at + pos, &n, + sizeof(n)) < 0) + BUG(); + val = ntohl(n); + break; + } + } + + /* Run over the "," and ":" operands */ + nvals = ct->nvalues; + for (i = 0; i < nvals; ++i) + if (ct->value[i].min <= val && val <= ct->value[i].max) + break; + + if (i >= ct->nvalues) + return false; + } + + return true; +} + +static bool u32_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_u32 *data = par->matchinfo; + bool ret; + + ret = u32_match_it(data, skb); + return ret ^ data->invert; +} + +static int u32_mt_checkentry(const struct xt_mtchk_param *par) +{ + const struct xt_u32 *data = par->matchinfo; + const struct xt_u32_test *ct; + unsigned int i; + + if (data->ntests > ARRAY_SIZE(data->tests)) + return -EINVAL; + + for (i = 0; i < data->ntests; ++i) { + ct = &data->tests[i]; + + if (ct->nnums > ARRAY_SIZE(ct->location) || + ct->nvalues > ARRAY_SIZE(ct->value)) + return -EINVAL; + } + + return 0; +} + +static struct xt_match xt_u32_mt_reg __read_mostly = { + .name = "u32", + .revision = 0, + .family = NFPROTO_UNSPEC, + .match = u32_mt, + .checkentry = u32_mt_checkentry, + .matchsize = sizeof(struct xt_u32), + .me = THIS_MODULE, +}; + +static int __init u32_mt_init(void) +{ + return xt_register_match(&xt_u32_mt_reg); +} + +static void __exit u32_mt_exit(void) +{ + xt_unregister_match(&xt_u32_mt_reg); +} + +module_init(u32_mt_init); +module_exit(u32_mt_exit); +MODULE_AUTHOR("Jan Engelhardt "); +MODULE_DESCRIPTION("Xtables: arbitrary byte matching"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_u32"); +MODULE_ALIAS("ip6t_u32"); diff --git a/net/netlabel/Kconfig b/net/netlabel/Kconfig new file mode 100644 index 000000000..4383ac296 --- /dev/null +++ b/net/netlabel/Kconfig @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# NetLabel configuration +# + +config NETLABEL + bool "NetLabel subsystem support" + depends on SECURITY + select CRC_CCITT if IPV6 + default n + help + NetLabel provides support for explicit network packet labeling + protocols such as CIPSO and RIPSO. For more information see + Documentation/netlabel as well as the NetLabel SourceForge project + for configuration tools and additional documentation. + + * https://github.com/netlabel/netlabel_tools + + If you are unsure, say N. diff --git a/net/netlabel/Makefile b/net/netlabel/Makefile new file mode 100644 index 000000000..5a46381a6 --- /dev/null +++ b/net/netlabel/Makefile @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the NetLabel subsystem. 
+# + +# base objects +obj-y := netlabel_user.o netlabel_kapi.o +obj-y += netlabel_domainhash.o netlabel_addrlist.o + +# management objects +obj-y += netlabel_mgmt.o + +# protocol modules +obj-y += netlabel_unlabeled.o +obj-y += netlabel_cipso_v4.o +obj-$(subst m,y,$(CONFIG_IPV6)) += netlabel_calipso.o diff --git a/net/netlabel/netlabel_addrlist.c b/net/netlabel/netlabel_addrlist.c new file mode 100644 index 000000000..3282acf7f --- /dev/null +++ b/net/netlabel/netlabel_addrlist.c @@ -0,0 +1,369 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel Network Address Lists + * + * This file contains network address list functions used to manage ordered + * lists of network addresses for use by the NetLabel subsystem. The NetLabel + * system manages static and dynamic label mappings for network protocols such + * as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_addrlist.h" + +/* + * Address List Functions + */ + +/** + * netlbl_af4list_search - Search for a matching IPv4 address entry + * @addr: IPv4 address + * @head: the list head + * + * Description: + * Searches the IPv4 address list given by @head. If a matching address entry + * is found it is returned, otherwise NULL is returned. The caller is + * responsible for calling the rcu_read_[un]lock() functions. + * + */ +struct netlbl_af4list *netlbl_af4list_search(__be32 addr, + struct list_head *head) +{ + struct netlbl_af4list *iter; + + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && (addr & iter->mask) == iter->addr) + return iter; + + return NULL; +} + +/** + * netlbl_af4list_search_exact - Search for an exact IPv4 address entry + * @addr: IPv4 address + * @mask: IPv4 address mask + * @head: the list head + * + * Description: + * Searches the IPv4 address list given by @head. If an exact match if found + * it is returned, otherwise NULL is returned. The caller is responsible for + * calling the rcu_read_[un]lock() functions. + * + */ +struct netlbl_af4list *netlbl_af4list_search_exact(__be32 addr, + __be32 mask, + struct list_head *head) +{ + struct netlbl_af4list *iter; + + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && iter->addr == addr && iter->mask == mask) + return iter; + + return NULL; +} + + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_af6list_search - Search for a matching IPv6 address entry + * @addr: IPv6 address + * @head: the list head + * + * Description: + * Searches the IPv6 address list given by @head. If a matching address entry + * is found it is returned, otherwise NULL is returned. The caller is + * responsible for calling the rcu_read_[un]lock() functions. + * + */ +struct netlbl_af6list *netlbl_af6list_search(const struct in6_addr *addr, + struct list_head *head) +{ + struct netlbl_af6list *iter; + + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && + ipv6_masked_addr_cmp(&iter->addr, &iter->mask, addr) == 0) + return iter; + + return NULL; +} + +/** + * netlbl_af6list_search_exact - Search for an exact IPv6 address entry + * @addr: IPv6 address + * @mask: IPv6 address mask + * @head: the list head + * + * Description: + * Searches the IPv6 address list given by @head. If an exact match if found + * it is returned, otherwise NULL is returned. The caller is responsible for + * calling the rcu_read_[un]lock() functions. 
+ * + */ +struct netlbl_af6list *netlbl_af6list_search_exact(const struct in6_addr *addr, + const struct in6_addr *mask, + struct list_head *head) +{ + struct netlbl_af6list *iter; + + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && + ipv6_addr_equal(&iter->addr, addr) && + ipv6_addr_equal(&iter->mask, mask)) + return iter; + + return NULL; +} +#endif /* IPv6 */ + +/** + * netlbl_af4list_add - Add a new IPv4 address entry to a list + * @entry: address entry + * @head: the list head + * + * Description: + * Add a new address entry to the list pointed to by @head. On success zero is + * returned, otherwise a negative value is returned. The caller is responsible + * for calling the necessary locking functions. + * + */ +int netlbl_af4list_add(struct netlbl_af4list *entry, struct list_head *head) +{ + struct netlbl_af4list *iter; + + iter = netlbl_af4list_search(entry->addr, head); + if (iter != NULL && + iter->addr == entry->addr && iter->mask == entry->mask) + return -EEXIST; + + /* in order to speed up address searches through the list (the common + * case) we need to keep the list in order based on the size of the + * address mask such that the entry with the widest mask (smallest + * numerical value) appears first in the list */ + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && + ntohl(entry->mask) > ntohl(iter->mask)) { + __list_add_rcu(&entry->list, + iter->list.prev, + &iter->list); + return 0; + } + list_add_tail_rcu(&entry->list, head); + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_af6list_add - Add a new IPv6 address entry to a list + * @entry: address entry + * @head: the list head + * + * Description: + * Add a new address entry to the list pointed to by @head. On success zero is + * returned, otherwise a negative value is returned. The caller is responsible + * for calling the necessary locking functions. + * + */ +int netlbl_af6list_add(struct netlbl_af6list *entry, struct list_head *head) +{ + struct netlbl_af6list *iter; + + iter = netlbl_af6list_search(&entry->addr, head); + if (iter != NULL && + ipv6_addr_equal(&iter->addr, &entry->addr) && + ipv6_addr_equal(&iter->mask, &entry->mask)) + return -EEXIST; + + /* in order to speed up address searches through the list (the common + * case) we need to keep the list in order based on the size of the + * address mask such that the entry with the widest mask (smallest + * numerical value) appears first in the list */ + list_for_each_entry_rcu(iter, head, list) + if (iter->valid && + ipv6_addr_cmp(&entry->mask, &iter->mask) > 0) { + __list_add_rcu(&entry->list, + iter->list.prev, + &iter->list); + return 0; + } + list_add_tail_rcu(&entry->list, head); + return 0; +} +#endif /* IPv6 */ + +/** + * netlbl_af4list_remove_entry - Remove an IPv4 address entry + * @entry: address entry + * + * Description: + * Remove the specified IP address entry. The caller is responsible for + * calling the necessary locking functions. + * + */ +void netlbl_af4list_remove_entry(struct netlbl_af4list *entry) +{ + entry->valid = 0; + list_del_rcu(&entry->list); +} + +/** + * netlbl_af4list_remove - Remove an IPv4 address entry + * @addr: IP address + * @mask: IP address mask + * @head: the list head + * + * Description: + * Remove an IP address entry from the list pointed to by @head. Returns the + * entry on success, NULL on failure. The caller is responsible for calling + * the necessary locking functions. 
+ * + */ +struct netlbl_af4list *netlbl_af4list_remove(__be32 addr, __be32 mask, + struct list_head *head) +{ + struct netlbl_af4list *entry; + + entry = netlbl_af4list_search_exact(addr, mask, head); + if (entry == NULL) + return NULL; + netlbl_af4list_remove_entry(entry); + return entry; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_af6list_remove_entry - Remove an IPv6 address entry + * @entry: address entry + * + * Description: + * Remove the specified IP address entry. The caller is responsible for + * calling the necessary locking functions. + * + */ +void netlbl_af6list_remove_entry(struct netlbl_af6list *entry) +{ + entry->valid = 0; + list_del_rcu(&entry->list); +} + +/** + * netlbl_af6list_remove - Remove an IPv6 address entry + * @addr: IP address + * @mask: IP address mask + * @head: the list head + * + * Description: + * Remove an IP address entry from the list pointed to by @head. Returns the + * entry on success, NULL on failure. The caller is responsible for calling + * the necessary locking functions. + * + */ +struct netlbl_af6list *netlbl_af6list_remove(const struct in6_addr *addr, + const struct in6_addr *mask, + struct list_head *head) +{ + struct netlbl_af6list *entry; + + entry = netlbl_af6list_search_exact(addr, mask, head); + if (entry == NULL) + return NULL; + netlbl_af6list_remove_entry(entry); + return entry; +} +#endif /* IPv6 */ + +/* + * Audit Helper Functions + */ + +#ifdef CONFIG_AUDIT +/** + * netlbl_af4list_audit_addr - Audit an IPv4 address + * @audit_buf: audit buffer + * @src: true if source address, false if destination + * @dev: network interface + * @addr: IP address + * @mask: IP address mask + * + * Description: + * Write the IPv4 address and address mask, if necessary, to @audit_buf. + * + */ +void netlbl_af4list_audit_addr(struct audit_buffer *audit_buf, + int src, const char *dev, + __be32 addr, __be32 mask) +{ + u32 mask_val = ntohl(mask); + char *dir = (src ? "src" : "dst"); + + if (dev != NULL) + audit_log_format(audit_buf, " netif=%s", dev); + audit_log_format(audit_buf, " %s=%pI4", dir, &addr); + if (mask_val != 0xffffffff) { + u32 mask_len = 0; + while (mask_val > 0) { + mask_val <<= 1; + mask_len++; + } + audit_log_format(audit_buf, " %s_prefixlen=%d", dir, mask_len); + } +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_af6list_audit_addr - Audit an IPv6 address + * @audit_buf: audit buffer + * @src: true if source address, false if destination + * @dev: network interface + * @addr: IP address + * @mask: IP address mask + * + * Description: + * Write the IPv6 address and address mask, if necessary, to @audit_buf. + * + */ +void netlbl_af6list_audit_addr(struct audit_buffer *audit_buf, + int src, + const char *dev, + const struct in6_addr *addr, + const struct in6_addr *mask) +{ + char *dir = (src ? 
"src" : "dst"); + + if (dev != NULL) + audit_log_format(audit_buf, " netif=%s", dev); + audit_log_format(audit_buf, " %s=%pI6", dir, addr); + if (ntohl(mask->s6_addr32[3]) != 0xffffffff) { + u32 mask_len = 0; + u32 mask_val; + int iter = -1; + while (ntohl(mask->s6_addr32[++iter]) == 0xffffffff) + mask_len += 32; + mask_val = ntohl(mask->s6_addr32[iter]); + while (mask_val > 0) { + mask_val <<= 1; + mask_len++; + } + audit_log_format(audit_buf, " %s_prefixlen=%d", dir, mask_len); + } +} +#endif /* IPv6 */ +#endif /* CONFIG_AUDIT */ diff --git a/net/netlabel/netlabel_addrlist.h b/net/netlabel/netlabel_addrlist.h new file mode 100644 index 000000000..a01cf4955 --- /dev/null +++ b/net/netlabel/netlabel_addrlist.h @@ -0,0 +1,194 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel Network Address Lists + * + * This file contains network address list functions used to manage ordered + * lists of network addresses for use by the NetLabel subsystem. The NetLabel + * system manages static and dynamic label mappings for network protocols such + * as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2008 + */ + +#ifndef _NETLABEL_ADDRLIST_H +#define _NETLABEL_ADDRLIST_H + +#include +#include +#include +#include +#include + +/** + * struct netlbl_af4list - NetLabel IPv4 address list + * @addr: IPv4 address + * @mask: IPv4 address mask + * @valid: valid flag + * @list: list structure, used internally + */ +struct netlbl_af4list { + __be32 addr; + __be32 mask; + + u32 valid; + struct list_head list; +}; + +/** + * struct netlbl_af6list - NetLabel IPv6 address list + * @addr: IPv6 address + * @mask: IPv6 address mask + * @valid: valid flag + * @list: list structure, used internally + */ +struct netlbl_af6list { + struct in6_addr addr; + struct in6_addr mask; + + u32 valid; + struct list_head list; +}; + +#define __af4list_entry(ptr) container_of(ptr, struct netlbl_af4list, list) + +static inline struct netlbl_af4list *__af4list_valid(struct list_head *s, + struct list_head *h) +{ + struct list_head *i = s; + struct netlbl_af4list *n = __af4list_entry(s); + while (i != h && !n->valid) { + i = i->next; + n = __af4list_entry(i); + } + return n; +} + +static inline struct netlbl_af4list *__af4list_valid_rcu(struct list_head *s, + struct list_head *h) +{ + struct list_head *i = s; + struct netlbl_af4list *n = __af4list_entry(s); + while (i != h && !n->valid) { + i = rcu_dereference(list_next_rcu(i)); + n = __af4list_entry(i); + } + return n; +} + +#define netlbl_af4list_foreach(iter, head) \ + for (iter = __af4list_valid((head)->next, head); \ + &iter->list != (head); \ + iter = __af4list_valid(iter->list.next, head)) + +#define netlbl_af4list_foreach_rcu(iter, head) \ + for (iter = __af4list_valid_rcu((head)->next, head); \ + &iter->list != (head); \ + iter = __af4list_valid_rcu(iter->list.next, head)) + +#define netlbl_af4list_foreach_safe(iter, tmp, head) \ + for (iter = __af4list_valid((head)->next, head), \ + tmp = __af4list_valid(iter->list.next, head); \ + &iter->list != (head); \ + iter = tmp, tmp = __af4list_valid(iter->list.next, head)) + +int netlbl_af4list_add(struct netlbl_af4list *entry, + struct list_head *head); +struct netlbl_af4list *netlbl_af4list_remove(__be32 addr, __be32 mask, + struct list_head *head); +void netlbl_af4list_remove_entry(struct netlbl_af4list *entry); +struct netlbl_af4list *netlbl_af4list_search(__be32 addr, + struct list_head *head); +struct netlbl_af4list 
*netlbl_af4list_search_exact(__be32 addr, + __be32 mask, + struct list_head *head); + +#ifdef CONFIG_AUDIT +void netlbl_af4list_audit_addr(struct audit_buffer *audit_buf, + int src, const char *dev, + __be32 addr, __be32 mask); +#else +static inline void netlbl_af4list_audit_addr(struct audit_buffer *audit_buf, + int src, const char *dev, + __be32 addr, __be32 mask) +{ +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) + +#define __af6list_entry(ptr) container_of(ptr, struct netlbl_af6list, list) + +static inline struct netlbl_af6list *__af6list_valid(struct list_head *s, + struct list_head *h) +{ + struct list_head *i = s; + struct netlbl_af6list *n = __af6list_entry(s); + while (i != h && !n->valid) { + i = i->next; + n = __af6list_entry(i); + } + return n; +} + +static inline struct netlbl_af6list *__af6list_valid_rcu(struct list_head *s, + struct list_head *h) +{ + struct list_head *i = s; + struct netlbl_af6list *n = __af6list_entry(s); + while (i != h && !n->valid) { + i = rcu_dereference(list_next_rcu(i)); + n = __af6list_entry(i); + } + return n; +} + +#define netlbl_af6list_foreach(iter, head) \ + for (iter = __af6list_valid((head)->next, head); \ + &iter->list != (head); \ + iter = __af6list_valid(iter->list.next, head)) + +#define netlbl_af6list_foreach_rcu(iter, head) \ + for (iter = __af6list_valid_rcu((head)->next, head); \ + &iter->list != (head); \ + iter = __af6list_valid_rcu(iter->list.next, head)) + +#define netlbl_af6list_foreach_safe(iter, tmp, head) \ + for (iter = __af6list_valid((head)->next, head), \ + tmp = __af6list_valid(iter->list.next, head); \ + &iter->list != (head); \ + iter = tmp, tmp = __af6list_valid(iter->list.next, head)) + +int netlbl_af6list_add(struct netlbl_af6list *entry, + struct list_head *head); +struct netlbl_af6list *netlbl_af6list_remove(const struct in6_addr *addr, + const struct in6_addr *mask, + struct list_head *head); +void netlbl_af6list_remove_entry(struct netlbl_af6list *entry); +struct netlbl_af6list *netlbl_af6list_search(const struct in6_addr *addr, + struct list_head *head); +struct netlbl_af6list *netlbl_af6list_search_exact(const struct in6_addr *addr, + const struct in6_addr *mask, + struct list_head *head); + +#ifdef CONFIG_AUDIT +void netlbl_af6list_audit_addr(struct audit_buffer *audit_buf, + int src, + const char *dev, + const struct in6_addr *addr, + const struct in6_addr *mask); +#else +static inline void netlbl_af6list_audit_addr(struct audit_buffer *audit_buf, + int src, + const char *dev, + const struct in6_addr *addr, + const struct in6_addr *mask) +{ +} +#endif +#endif /* IPV6 */ + +#endif diff --git a/net/netlabel/netlabel_calipso.c b/net/netlabel/netlabel_calipso.c new file mode 100644 index 000000000..a07c2216d --- /dev/null +++ b/net/netlabel/netlabel_calipso.c @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel CALIPSO/IPv6 Support + * + * This file defines the CALIPSO/IPv6 functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and CALIPSO. 
+ * + * Authors: Paul Moore + * Huw Davies + */ + +/* (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + * (c) Copyright Huw Davies , 2015 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_user.h" +#include "netlabel_calipso.h" +#include "netlabel_mgmt.h" +#include "netlabel_domainhash.h" + +/* Argument struct for calipso_doi_walk() */ +struct netlbl_calipso_doiwalk_arg { + struct netlink_callback *nl_cb; + struct sk_buff *skb; + u32 seq; +}; + +/* Argument struct for netlbl_domhsh_walk() */ +struct netlbl_domhsh_walk_arg { + struct netlbl_audit *audit_info; + u32 doi; +}; + +/* NetLabel Generic NETLINK CALIPSO family */ +static struct genl_family netlbl_calipso_gnl_family; + +/* NetLabel Netlink attribute policy */ +static const struct nla_policy calipso_genl_policy[NLBL_CALIPSO_A_MAX + 1] = { + [NLBL_CALIPSO_A_DOI] = { .type = NLA_U32 }, + [NLBL_CALIPSO_A_MTYPE] = { .type = NLA_U32 }, +}; + +static const struct netlbl_calipso_ops *calipso_ops; + +/** + * netlbl_calipso_ops_register - Register the CALIPSO operations + * @ops: ops to register + * + * Description: + * Register the CALIPSO packet engine operations. + * + */ +const struct netlbl_calipso_ops * +netlbl_calipso_ops_register(const struct netlbl_calipso_ops *ops) +{ + return xchg(&calipso_ops, ops); +} +EXPORT_SYMBOL(netlbl_calipso_ops_register); + +static const struct netlbl_calipso_ops *netlbl_calipso_ops_get(void) +{ + return READ_ONCE(calipso_ops); +} + +/* NetLabel Command Handlers + */ +/** + * netlbl_calipso_add_pass - Adds a CALIPSO pass DOI definition + * @info: the Generic NETLINK info block + * @audit_info: NetLabel audit information + * + * Description: + * Create a new CALIPSO_MAP_PASS DOI definition based on the given ADD message + * and add it to the CALIPSO engine. Return zero on success and non-zero on + * error. + * + */ +static int netlbl_calipso_add_pass(struct genl_info *info, + struct netlbl_audit *audit_info) +{ + int ret_val; + struct calipso_doi *doi_def = NULL; + + doi_def = kmalloc(sizeof(*doi_def), GFP_KERNEL); + if (!doi_def) + return -ENOMEM; + doi_def->type = CALIPSO_MAP_PASS; + doi_def->doi = nla_get_u32(info->attrs[NLBL_CALIPSO_A_DOI]); + ret_val = calipso_doi_add(doi_def, audit_info); + if (ret_val != 0) + calipso_doi_free(doi_def); + + return ret_val; +} + +/** + * netlbl_calipso_add - Handle an ADD message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Create a new DOI definition based on the given ADD message and add it to the + * CALIPSO engine. Returns zero on success, negative values on failure. + * + */ +static int netlbl_calipso_add(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -EINVAL; + struct netlbl_audit audit_info; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (!info->attrs[NLBL_CALIPSO_A_DOI] || + !info->attrs[NLBL_CALIPSO_A_MTYPE]) + return -EINVAL; + + if (!ops) + return -EOPNOTSUPP; + + netlbl_netlink_auditinfo(&audit_info); + switch (nla_get_u32(info->attrs[NLBL_CALIPSO_A_MTYPE])) { + case CALIPSO_MAP_PASS: + ret_val = netlbl_calipso_add_pass(info, &audit_info); + break; + } + if (ret_val == 0) + atomic_inc(&netlabel_mgmt_protocount); + + return ret_val; +} + +/** + * netlbl_calipso_list - Handle a LIST message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated LIST message and respond accordingly. 
+ * Returns zero on success and negative values on error. + * + */ +static int netlbl_calipso_list(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val; + struct sk_buff *ans_skb = NULL; + void *data; + u32 doi; + struct calipso_doi *doi_def; + + if (!info->attrs[NLBL_CALIPSO_A_DOI]) { + ret_val = -EINVAL; + goto list_failure; + } + + doi = nla_get_u32(info->attrs[NLBL_CALIPSO_A_DOI]); + + doi_def = calipso_doi_getdef(doi); + if (!doi_def) { + ret_val = -EINVAL; + goto list_failure; + } + + ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!ans_skb) { + ret_val = -ENOMEM; + goto list_failure_put; + } + data = genlmsg_put_reply(ans_skb, info, &netlbl_calipso_gnl_family, + 0, NLBL_CALIPSO_C_LIST); + if (!data) { + ret_val = -ENOMEM; + goto list_failure_put; + } + + ret_val = nla_put_u32(ans_skb, NLBL_CALIPSO_A_MTYPE, doi_def->type); + if (ret_val != 0) + goto list_failure_put; + + calipso_doi_putdef(doi_def); + + genlmsg_end(ans_skb, data); + return genlmsg_reply(ans_skb, info); + +list_failure_put: + calipso_doi_putdef(doi_def); +list_failure: + kfree_skb(ans_skb); + return ret_val; +} + +/** + * netlbl_calipso_listall_cb - calipso_doi_walk() callback for LISTALL + * @doi_def: the CALIPSO DOI definition + * @arg: the netlbl_calipso_doiwalk_arg structure + * + * Description: + * This function is designed to be used as a callback to the + * calipso_doi_walk() function for use in generating a response for a LISTALL + * message. Returns the size of the message on success, negative values on + * failure. + * + */ +static int netlbl_calipso_listall_cb(struct calipso_doi *doi_def, void *arg) +{ + int ret_val = -ENOMEM; + struct netlbl_calipso_doiwalk_arg *cb_arg = arg; + void *data; + + data = genlmsg_put(cb_arg->skb, NETLINK_CB(cb_arg->nl_cb->skb).portid, + cb_arg->seq, &netlbl_calipso_gnl_family, + NLM_F_MULTI, NLBL_CALIPSO_C_LISTALL); + if (!data) + goto listall_cb_failure; + + ret_val = nla_put_u32(cb_arg->skb, NLBL_CALIPSO_A_DOI, doi_def->doi); + if (ret_val != 0) + goto listall_cb_failure; + ret_val = nla_put_u32(cb_arg->skb, + NLBL_CALIPSO_A_MTYPE, + doi_def->type); + if (ret_val != 0) + goto listall_cb_failure; + + genlmsg_end(cb_arg->skb, data); + return 0; + +listall_cb_failure: + genlmsg_cancel(cb_arg->skb, data); + return ret_val; +} + +/** + * netlbl_calipso_listall - Handle a LISTALL message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated LISTALL message and respond accordingly. Returns + * zero on success and negative values on error. + * + */ +static int netlbl_calipso_listall(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct netlbl_calipso_doiwalk_arg cb_arg; + u32 doi_skip = cb->args[0]; + + cb_arg.nl_cb = cb; + cb_arg.skb = skb; + cb_arg.seq = cb->nlh->nlmsg_seq; + + calipso_doi_walk(&doi_skip, netlbl_calipso_listall_cb, &cb_arg); + + cb->args[0] = doi_skip; + return skb->len; +} + +/** + * netlbl_calipso_remove_cb - netlbl_calipso_remove() callback for REMOVE + * @entry: LSM domain mapping entry + * @arg: the netlbl_domhsh_walk_arg structure + * + * Description: + * This function is intended for use by netlbl_calipso_remove() as the callback + * for the netlbl_domhsh_walk() function; it removes LSM domain map entries + * which are associated with the CALIPSO DOI specified in @arg. Returns zero on + * success, negative values on failure. 
+ * + */ +static int netlbl_calipso_remove_cb(struct netlbl_dom_map *entry, void *arg) +{ + struct netlbl_domhsh_walk_arg *cb_arg = arg; + + if (entry->def.type == NETLBL_NLTYPE_CALIPSO && + entry->def.calipso->doi == cb_arg->doi) + return netlbl_domhsh_remove_entry(entry, cb_arg->audit_info); + + return 0; +} + +/** + * netlbl_calipso_remove - Handle a REMOVE message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated REMOVE message and respond accordingly. Returns + * zero on success, negative values on failure. + * + */ +static int netlbl_calipso_remove(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -EINVAL; + struct netlbl_domhsh_walk_arg cb_arg; + struct netlbl_audit audit_info; + u32 skip_bkt = 0; + u32 skip_chain = 0; + + if (!info->attrs[NLBL_CALIPSO_A_DOI]) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + cb_arg.doi = nla_get_u32(info->attrs[NLBL_CALIPSO_A_DOI]); + cb_arg.audit_info = &audit_info; + ret_val = netlbl_domhsh_walk(&skip_bkt, &skip_chain, + netlbl_calipso_remove_cb, &cb_arg); + if (ret_val == 0 || ret_val == -ENOENT) { + ret_val = calipso_doi_remove(cb_arg.doi, &audit_info); + if (ret_val == 0) + atomic_dec(&netlabel_mgmt_protocount); + } + + return ret_val; +} + +/* NetLabel Generic NETLINK Command Definitions + */ + +static const struct genl_small_ops netlbl_calipso_ops[] = { + { + .cmd = NLBL_CALIPSO_C_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_calipso_add, + .dumpit = NULL, + }, + { + .cmd = NLBL_CALIPSO_C_REMOVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_calipso_remove, + .dumpit = NULL, + }, + { + .cmd = NLBL_CALIPSO_C_LIST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = netlbl_calipso_list, + .dumpit = NULL, + }, + { + .cmd = NLBL_CALIPSO_C_LISTALL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_calipso_listall, + }, +}; + +static struct genl_family netlbl_calipso_gnl_family __ro_after_init = { + .hdrsize = 0, + .name = NETLBL_NLTYPE_CALIPSO_NAME, + .version = NETLBL_PROTO_VERSION, + .maxattr = NLBL_CALIPSO_A_MAX, + .policy = calipso_genl_policy, + .module = THIS_MODULE, + .small_ops = netlbl_calipso_ops, + .n_small_ops = ARRAY_SIZE(netlbl_calipso_ops), + .resv_start_op = NLBL_CALIPSO_C_LISTALL + 1, +}; + +/* NetLabel Generic NETLINK Protocol Functions + */ + +/** + * netlbl_calipso_genl_init - Register the CALIPSO NetLabel component + * + * Description: + * Register the CALIPSO packet NetLabel component with the Generic NETLINK + * mechanism. Returns zero on success, negative values on failure. + * + */ +int __init netlbl_calipso_genl_init(void) +{ + return genl_register_family(&netlbl_calipso_gnl_family); +} + +/** + * calipso_doi_add - Add a new DOI to the CALIPSO protocol engine + * @doi_def: the DOI structure + * @audit_info: NetLabel audit information + * + * Description: + * The caller defines a new DOI for use by the CALIPSO engine and calls this + * function to add it to the list of acceptable domains. The caller must + * ensure that the mapping table specified in @doi_def->map meets all of the + * requirements of the mapping type (see calipso.h for details). Returns + * zero on success and non-zero on failure. 
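+ *
+ * A minimal usage sketch, mirroring what netlbl_calipso_add_pass() above
+ * does (illustrative only; error handling trimmed and the DOI value is
+ * made up):
+ *
+ *   doi_def = kmalloc(sizeof(*doi_def), GFP_KERNEL);
+ *   doi_def->type = CALIPSO_MAP_PASS;
+ *   doi_def->doi = 16;
+ *   if (calipso_doi_add(doi_def, audit_info) != 0)
+ *           calipso_doi_free(doi_def);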
+ * + */ +int calipso_doi_add(struct calipso_doi *doi_def, + struct netlbl_audit *audit_info) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->doi_add(doi_def, audit_info); + return ret_val; +} + +/** + * calipso_doi_free - Frees a DOI definition + * @doi_def: the DOI definition + * + * Description: + * This function frees all of the memory associated with a DOI definition. + * + */ +void calipso_doi_free(struct calipso_doi *doi_def) +{ + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ops->doi_free(doi_def); +} + +/** + * calipso_doi_remove - Remove an existing DOI from the CALIPSO protocol engine + * @doi: the DOI value + * @audit_info: NetLabel audit information + * + * Description: + * Removes a DOI definition from the CALIPSO engine. The NetLabel routines will + * be called to release their own LSM domain mappings as well as our own + * domain list. Returns zero on success and negative values on failure. + * + */ +int calipso_doi_remove(u32 doi, struct netlbl_audit *audit_info) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->doi_remove(doi, audit_info); + return ret_val; +} + +/** + * calipso_doi_getdef - Returns a reference to a valid DOI definition + * @doi: the DOI value + * + * Description: + * Searches for a valid DOI definition and if one is found it is returned to + * the caller. Otherwise NULL is returned. The caller must ensure that + * calipso_doi_putdef() is called when the caller is done. + * + */ +struct calipso_doi *calipso_doi_getdef(u32 doi) +{ + struct calipso_doi *ret_val = NULL; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->doi_getdef(doi); + return ret_val; +} + +/** + * calipso_doi_putdef - Releases a reference for the given DOI definition + * @doi_def: the DOI definition + * + * Description: + * Releases a DOI definition reference obtained from calipso_doi_getdef(). + * + */ +void calipso_doi_putdef(struct calipso_doi *doi_def) +{ + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ops->doi_putdef(doi_def); +} + +/** + * calipso_doi_walk - Iterate through the DOI definitions + * @skip_cnt: skip past this number of DOI definitions, updated + * @callback: callback for each DOI definition + * @cb_arg: argument for the callback function + * + * Description: + * Iterate over the DOI definition list, skipping the first @skip_cnt entries. + * For each entry call @callback, if @callback returns a negative value stop + * 'walking' through the list and return. Updates the value in @skip_cnt upon + * return. Returns zero on success, negative values on failure. + * + */ +int calipso_doi_walk(u32 *skip_cnt, + int (*callback)(struct calipso_doi *doi_def, void *arg), + void *cb_arg) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->doi_walk(skip_cnt, callback, cb_arg); + return ret_val; +} + +/** + * calipso_sock_getattr - Get the security attributes from a sock + * @sk: the sock + * @secattr: the security attributes + * + * Description: + * Query @sk to see if there is a CALIPSO option attached to the sock and if + * there is return the CALIPSO security attributes in @secattr. This function + * requires that @sk be locked, or privately held, but it does not do any + * locking itself. Returns zero on success and negative values on failure. 
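+ *
+ * A calling sketch (illustrative; the caller supplies @sk and a
+ * struct netlbl_lsm_secattr and, as noted above, owns the locking):
+ *
+ *   lock_sock(sk);
+ *   ret = calipso_sock_getattr(sk, &secattr);
+ *   release_sock(sk);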
+ * + */ +int calipso_sock_getattr(struct sock *sk, struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->sock_getattr(sk, secattr); + return ret_val; +} + +/** + * calipso_sock_setattr - Add a CALIPSO option to a socket + * @sk: the socket + * @doi_def: the CALIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CALIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. This function requires + * exclusive access to @sk, which means it either needs to be in the + * process of being created or locked. Returns zero on success and negative + * values on failure. + * + */ +int calipso_sock_setattr(struct sock *sk, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->sock_setattr(sk, doi_def, secattr); + return ret_val; +} + +/** + * calipso_sock_delattr - Delete the CALIPSO option from a socket + * @sk: the socket + * + * Description: + * Removes the CALIPSO option from a socket, if present. + * + */ +void calipso_sock_delattr(struct sock *sk) +{ + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ops->sock_delattr(sk); +} + +/** + * calipso_req_setattr - Add a CALIPSO option to a connection request socket + * @req: the connection request socket + * @doi_def: the CALIPSO DOI to use + * @secattr: the specific security attributes of the socket + * + * Description: + * Set the CALIPSO option on the given socket using the DOI definition and + * security attributes passed to the function. Returns zero on success and + * negative values on failure. + * + */ +int calipso_req_setattr(struct request_sock *req, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->req_setattr(req, doi_def, secattr); + return ret_val; +} + +/** + * calipso_req_delattr - Delete the CALIPSO option from a request socket + * @req: the request socket + * + * Description: + * Removes the CALIPSO option from a request socket, if present. + * + */ +void calipso_req_delattr(struct request_sock *req) +{ + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ops->req_delattr(req); +} + +/** + * calipso_optptr - Find the CALIPSO option in the packet + * @skb: the packet + * + * Description: + * Parse the packet's IP header looking for a CALIPSO option. Returns a pointer + * to the start of the CALIPSO option on success, NULL if one if not found. + * + */ +unsigned char *calipso_optptr(const struct sk_buff *skb) +{ + unsigned char *ret_val = NULL; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->skbuff_optptr(skb); + return ret_val; +} + +/** + * calipso_getattr - Get the security attributes from a memory block. + * @calipso: the CALIPSO option + * @secattr: the security attributes + * + * Description: + * Inspect @calipso and return the security attributes in @secattr. + * Returns zero on success and negative values on failure. 
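+ *
+ * Typically paired with calipso_optptr() above; a rough sketch
+ * (illustrative, error handling trimmed):
+ *
+ *   ptr = calipso_optptr(skb);
+ *   if (ptr)
+ *           ret = calipso_getattr(ptr, &secattr);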
+ * + */ +int calipso_getattr(const unsigned char *calipso, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->opt_getattr(calipso, secattr); + return ret_val; +} + +/** + * calipso_skbuff_setattr - Set the CALIPSO option on a packet + * @skb: the packet + * @doi_def: the CALIPSO DOI to use + * @secattr: the security attributes + * + * Description: + * Set the CALIPSO option on the given packet based on the security attributes. + * Returns a pointer to the IP header on success and NULL on failure. + * + */ +int calipso_skbuff_setattr(struct sk_buff *skb, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->skbuff_setattr(skb, doi_def, secattr); + return ret_val; +} + +/** + * calipso_skbuff_delattr - Delete any CALIPSO options from a packet + * @skb: the packet + * + * Description: + * Removes any and all CALIPSO options from the given packet. Returns zero on + * success, negative values on failure. + * + */ +int calipso_skbuff_delattr(struct sk_buff *skb) +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->skbuff_delattr(skb); + return ret_val; +} + +/** + * calipso_cache_invalidate - Invalidates the current CALIPSO cache + * + * Description: + * Invalidates and frees any entries in the CALIPSO cache. Returns zero on + * success and negative values on failure. + * + */ +void calipso_cache_invalidate(void) +{ + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ops->cache_invalidate(); +} + +/** + * calipso_cache_add - Add an entry to the CALIPSO cache + * @calipso_ptr: the CALIPSO option + * @secattr: the packet's security attributes + * + * Description: + * Add a new entry into the CALIPSO label mapping cache. + * Returns zero on success, negative values on failure. + * + */ +int calipso_cache_add(const unsigned char *calipso_ptr, + const struct netlbl_lsm_secattr *secattr) + +{ + int ret_val = -ENOMSG; + const struct netlbl_calipso_ops *ops = netlbl_calipso_ops_get(); + + if (ops) + ret_val = ops->cache_add(calipso_ptr, secattr); + return ret_val; +} diff --git a/net/netlabel/netlabel_calipso.h b/net/netlabel/netlabel_calipso.h new file mode 100644 index 000000000..ef3e9a7ab --- /dev/null +++ b/net/netlabel/netlabel_calipso.h @@ -0,0 +1,137 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel CALIPSO Support + * + * This file defines the CALIPSO functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. + * + * Authors: Paul Moore + * Huw Davies + */ + +/* (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + * (c) Copyright Huw Davies , 2015 + */ + +#ifndef _NETLABEL_CALIPSO +#define _NETLABEL_CALIPSO + +#include +#include + +/* The following NetLabel payloads are supported by the CALIPSO subsystem. + * + * o ADD: + * Sent by an application to add a new DOI mapping table. + * + * Required attributes: + * + * NLBL_CALIPSO_A_DOI + * NLBL_CALIPSO_A_MTYPE + * + * If using CALIPSO_MAP_PASS no additional attributes are required. + * + * o REMOVE: + * Sent by an application to remove a specific DOI mapping table from the + * CALIPSO system. 
+ * + * Required attributes: + * + * NLBL_CALIPSO_A_DOI + * + * o LIST: + * Sent by an application to list the details of a DOI definition. On + * success the kernel should send a response using the following format. + * + * Required attributes: + * + * NLBL_CALIPSO_A_DOI + * + * The valid response message format depends on the type of the DOI mapping, + * the defined formats are shown below. + * + * Required attributes: + * + * NLBL_CALIPSO_A_MTYPE + * + * If using CALIPSO_MAP_PASS no additional attributes are required. + * + * o LISTALL: + * This message is sent by an application to list the valid DOIs on the + * system. When sent by an application there is no payload and the + * NLM_F_DUMP flag should be set. The kernel should respond with a series of + * the following messages. + * + * Required attributes: + * + * NLBL_CALIPSO_A_DOI + * NLBL_CALIPSO_A_MTYPE + * + */ + +/* NetLabel CALIPSO commands */ +enum { + NLBL_CALIPSO_C_UNSPEC, + NLBL_CALIPSO_C_ADD, + NLBL_CALIPSO_C_REMOVE, + NLBL_CALIPSO_C_LIST, + NLBL_CALIPSO_C_LISTALL, + __NLBL_CALIPSO_C_MAX, +}; + +/* NetLabel CALIPSO attributes */ +enum { + NLBL_CALIPSO_A_UNSPEC, + NLBL_CALIPSO_A_DOI, + /* (NLA_U32) + * the DOI value */ + NLBL_CALIPSO_A_MTYPE, + /* (NLA_U32) + * the mapping table type (defined in the calipso.h header as + * CALIPSO_MAP_*) */ + __NLBL_CALIPSO_A_MAX, +}; + +#define NLBL_CALIPSO_A_MAX (__NLBL_CALIPSO_A_MAX - 1) + +/* NetLabel protocol functions */ +#if IS_ENABLED(CONFIG_IPV6) +int netlbl_calipso_genl_init(void); +#else +static inline int netlbl_calipso_genl_init(void) +{ + return 0; +} +#endif + +int calipso_doi_add(struct calipso_doi *doi_def, + struct netlbl_audit *audit_info); +void calipso_doi_free(struct calipso_doi *doi_def); +int calipso_doi_remove(u32 doi, struct netlbl_audit *audit_info); +struct calipso_doi *calipso_doi_getdef(u32 doi); +void calipso_doi_putdef(struct calipso_doi *doi_def); +int calipso_doi_walk(u32 *skip_cnt, + int (*callback)(struct calipso_doi *doi_def, void *arg), + void *cb_arg); +int calipso_sock_getattr(struct sock *sk, struct netlbl_lsm_secattr *secattr); +int calipso_sock_setattr(struct sock *sk, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr); +void calipso_sock_delattr(struct sock *sk); +int calipso_req_setattr(struct request_sock *req, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr); +void calipso_req_delattr(struct request_sock *req); +unsigned char *calipso_optptr(const struct sk_buff *skb); +int calipso_getattr(const unsigned char *calipso, + struct netlbl_lsm_secattr *secattr); +int calipso_skbuff_setattr(struct sk_buff *skb, + const struct calipso_doi *doi_def, + const struct netlbl_lsm_secattr *secattr); +int calipso_skbuff_delattr(struct sk_buff *skb); +void calipso_cache_invalidate(void); +int calipso_cache_add(const unsigned char *calipso_ptr, + const struct netlbl_lsm_secattr *secattr); + +#endif diff --git a/net/netlabel/netlabel_cipso_v4.c b/net/netlabel/netlabel_cipso_v4.c new file mode 100644 index 000000000..fa08ee75a --- /dev/null +++ b/net/netlabel/netlabel_cipso_v4.c @@ -0,0 +1,788 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel CIPSO/IPv4 Support + * + * This file defines the CIPSO/IPv4 functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. 
+ * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_user.h" +#include "netlabel_cipso_v4.h" +#include "netlabel_mgmt.h" +#include "netlabel_domainhash.h" + +/* Argument struct for cipso_v4_doi_walk() */ +struct netlbl_cipsov4_doiwalk_arg { + struct netlink_callback *nl_cb; + struct sk_buff *skb; + u32 seq; +}; + +/* Argument struct for netlbl_domhsh_walk() */ +struct netlbl_domhsh_walk_arg { + struct netlbl_audit *audit_info; + u32 doi; +}; + +/* NetLabel Generic NETLINK CIPSOv4 family */ +static struct genl_family netlbl_cipsov4_gnl_family; +/* NetLabel Netlink attribute policy */ +static const struct nla_policy netlbl_cipsov4_genl_policy[NLBL_CIPSOV4_A_MAX + 1] = { + [NLBL_CIPSOV4_A_DOI] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_MTYPE] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_TAG] = { .type = NLA_U8 }, + [NLBL_CIPSOV4_A_TAGLST] = { .type = NLA_NESTED }, + [NLBL_CIPSOV4_A_MLSLVLLOC] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_MLSLVLREM] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_MLSLVL] = { .type = NLA_NESTED }, + [NLBL_CIPSOV4_A_MLSLVLLST] = { .type = NLA_NESTED }, + [NLBL_CIPSOV4_A_MLSCATLOC] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_MLSCATREM] = { .type = NLA_U32 }, + [NLBL_CIPSOV4_A_MLSCAT] = { .type = NLA_NESTED }, + [NLBL_CIPSOV4_A_MLSCATLST] = { .type = NLA_NESTED }, +}; + +/* + * Helper Functions + */ + +/** + * netlbl_cipsov4_add_common - Parse the common sections of a ADD message + * @info: the Generic NETLINK info block + * @doi_def: the CIPSO V4 DOI definition + * + * Description: + * Parse the common sections of a ADD message and fill in the related values + * in @doi_def. Returns zero on success, negative values on failure. + * + */ +static int netlbl_cipsov4_add_common(struct genl_info *info, + struct cipso_v4_doi *doi_def) +{ + struct nlattr *nla; + int nla_rem; + u32 iter = 0; + + doi_def->doi = nla_get_u32(info->attrs[NLBL_CIPSOV4_A_DOI]); + + if (nla_validate_nested_deprecated(info->attrs[NLBL_CIPSOV4_A_TAGLST], + NLBL_CIPSOV4_A_MAX, + netlbl_cipsov4_genl_policy, + NULL) != 0) + return -EINVAL; + + nla_for_each_nested(nla, info->attrs[NLBL_CIPSOV4_A_TAGLST], nla_rem) + if (nla_type(nla) == NLBL_CIPSOV4_A_TAG) { + if (iter >= CIPSO_V4_TAG_MAXCNT) + return -EINVAL; + doi_def->tags[iter++] = nla_get_u8(nla); + } + while (iter < CIPSO_V4_TAG_MAXCNT) + doi_def->tags[iter++] = CIPSO_V4_TAG_INVALID; + + return 0; +} + +/* + * NetLabel Command Handlers + */ + +/** + * netlbl_cipsov4_add_std - Adds a CIPSO V4 DOI definition + * @info: the Generic NETLINK info block + * @audit_info: NetLabel audit information + * + * Description: + * Create a new CIPSO_V4_MAP_TRANS DOI definition based on the given ADD + * message and add it to the CIPSO V4 engine. Return zero on success and + * non-zero on error. 
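+ *
+ * The level and category mappings arrive as nested netlink attributes,
+ * laid out roughly as below (a sketch; the numeric values are
+ * illustrative):
+ *
+ *   NLBL_CIPSOV4_A_MLSLVLLST
+ *     NLBL_CIPSOV4_A_MLSLVL
+ *       NLBL_CIPSOV4_A_MLSLVLLOC = 0
+ *       NLBL_CIPSOV4_A_MLSLVLREM = 0
+ *   NLBL_CIPSOV4_A_MLSCATLST
+ *     NLBL_CIPSOV4_A_MLSCAT
+ *       NLBL_CIPSOV4_A_MLSCATLOC = 0
+ *       NLBL_CIPSOV4_A_MLSCATREM = 0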
+ * + */ +static int netlbl_cipsov4_add_std(struct genl_info *info, + struct netlbl_audit *audit_info) +{ + int ret_val = -EINVAL; + struct cipso_v4_doi *doi_def = NULL; + struct nlattr *nla_a; + struct nlattr *nla_b; + int nla_a_rem; + int nla_b_rem; + u32 iter; + + if (!info->attrs[NLBL_CIPSOV4_A_TAGLST] || + !info->attrs[NLBL_CIPSOV4_A_MLSLVLLST]) + return -EINVAL; + + if (nla_validate_nested_deprecated(info->attrs[NLBL_CIPSOV4_A_MLSLVLLST], + NLBL_CIPSOV4_A_MAX, + netlbl_cipsov4_genl_policy, + NULL) != 0) + return -EINVAL; + + doi_def = kmalloc(sizeof(*doi_def), GFP_KERNEL); + if (doi_def == NULL) + return -ENOMEM; + doi_def->map.std = kzalloc(sizeof(*doi_def->map.std), GFP_KERNEL); + if (doi_def->map.std == NULL) { + kfree(doi_def); + return -ENOMEM; + } + doi_def->type = CIPSO_V4_MAP_TRANS; + + ret_val = netlbl_cipsov4_add_common(info, doi_def); + if (ret_val != 0) + goto add_std_failure; + ret_val = -EINVAL; + + nla_for_each_nested(nla_a, + info->attrs[NLBL_CIPSOV4_A_MLSLVLLST], + nla_a_rem) + if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSLVL) { + if (nla_validate_nested_deprecated(nla_a, + NLBL_CIPSOV4_A_MAX, + netlbl_cipsov4_genl_policy, + NULL) != 0) + goto add_std_failure; + nla_for_each_nested(nla_b, nla_a, nla_b_rem) + switch (nla_type(nla_b)) { + case NLBL_CIPSOV4_A_MLSLVLLOC: + if (nla_get_u32(nla_b) > + CIPSO_V4_MAX_LOC_LVLS) + goto add_std_failure; + if (nla_get_u32(nla_b) >= + doi_def->map.std->lvl.local_size) + doi_def->map.std->lvl.local_size = + nla_get_u32(nla_b) + 1; + break; + case NLBL_CIPSOV4_A_MLSLVLREM: + if (nla_get_u32(nla_b) > + CIPSO_V4_MAX_REM_LVLS) + goto add_std_failure; + if (nla_get_u32(nla_b) >= + doi_def->map.std->lvl.cipso_size) + doi_def->map.std->lvl.cipso_size = + nla_get_u32(nla_b) + 1; + break; + } + } + doi_def->map.std->lvl.local = kcalloc(doi_def->map.std->lvl.local_size, + sizeof(u32), + GFP_KERNEL | __GFP_NOWARN); + if (doi_def->map.std->lvl.local == NULL) { + ret_val = -ENOMEM; + goto add_std_failure; + } + doi_def->map.std->lvl.cipso = kcalloc(doi_def->map.std->lvl.cipso_size, + sizeof(u32), + GFP_KERNEL | __GFP_NOWARN); + if (doi_def->map.std->lvl.cipso == NULL) { + ret_val = -ENOMEM; + goto add_std_failure; + } + for (iter = 0; iter < doi_def->map.std->lvl.local_size; iter++) + doi_def->map.std->lvl.local[iter] = CIPSO_V4_INV_LVL; + for (iter = 0; iter < doi_def->map.std->lvl.cipso_size; iter++) + doi_def->map.std->lvl.cipso[iter] = CIPSO_V4_INV_LVL; + nla_for_each_nested(nla_a, + info->attrs[NLBL_CIPSOV4_A_MLSLVLLST], + nla_a_rem) + if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSLVL) { + struct nlattr *lvl_loc; + struct nlattr *lvl_rem; + + lvl_loc = nla_find_nested(nla_a, + NLBL_CIPSOV4_A_MLSLVLLOC); + lvl_rem = nla_find_nested(nla_a, + NLBL_CIPSOV4_A_MLSLVLREM); + if (lvl_loc == NULL || lvl_rem == NULL) + goto add_std_failure; + doi_def->map.std->lvl.local[nla_get_u32(lvl_loc)] = + nla_get_u32(lvl_rem); + doi_def->map.std->lvl.cipso[nla_get_u32(lvl_rem)] = + nla_get_u32(lvl_loc); + } + + if (info->attrs[NLBL_CIPSOV4_A_MLSCATLST]) { + if (nla_validate_nested_deprecated(info->attrs[NLBL_CIPSOV4_A_MLSCATLST], + NLBL_CIPSOV4_A_MAX, + netlbl_cipsov4_genl_policy, + NULL) != 0) + goto add_std_failure; + + nla_for_each_nested(nla_a, + info->attrs[NLBL_CIPSOV4_A_MLSCATLST], + nla_a_rem) + if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSCAT) { + if (nla_validate_nested_deprecated(nla_a, + NLBL_CIPSOV4_A_MAX, + netlbl_cipsov4_genl_policy, + NULL) != 0) + goto add_std_failure; + nla_for_each_nested(nla_b, nla_a, nla_b_rem) + switch (nla_type(nla_b)) { + case 
NLBL_CIPSOV4_A_MLSCATLOC: + if (nla_get_u32(nla_b) > + CIPSO_V4_MAX_LOC_CATS) + goto add_std_failure; + if (nla_get_u32(nla_b) >= + doi_def->map.std->cat.local_size) + doi_def->map.std->cat.local_size = + nla_get_u32(nla_b) + 1; + break; + case NLBL_CIPSOV4_A_MLSCATREM: + if (nla_get_u32(nla_b) > + CIPSO_V4_MAX_REM_CATS) + goto add_std_failure; + if (nla_get_u32(nla_b) >= + doi_def->map.std->cat.cipso_size) + doi_def->map.std->cat.cipso_size = + nla_get_u32(nla_b) + 1; + break; + } + } + doi_def->map.std->cat.local = kcalloc( + doi_def->map.std->cat.local_size, + sizeof(u32), + GFP_KERNEL | __GFP_NOWARN); + if (doi_def->map.std->cat.local == NULL) { + ret_val = -ENOMEM; + goto add_std_failure; + } + doi_def->map.std->cat.cipso = kcalloc( + doi_def->map.std->cat.cipso_size, + sizeof(u32), + GFP_KERNEL | __GFP_NOWARN); + if (doi_def->map.std->cat.cipso == NULL) { + ret_val = -ENOMEM; + goto add_std_failure; + } + for (iter = 0; iter < doi_def->map.std->cat.local_size; iter++) + doi_def->map.std->cat.local[iter] = CIPSO_V4_INV_CAT; + for (iter = 0; iter < doi_def->map.std->cat.cipso_size; iter++) + doi_def->map.std->cat.cipso[iter] = CIPSO_V4_INV_CAT; + nla_for_each_nested(nla_a, + info->attrs[NLBL_CIPSOV4_A_MLSCATLST], + nla_a_rem) + if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSCAT) { + struct nlattr *cat_loc; + struct nlattr *cat_rem; + + cat_loc = nla_find_nested(nla_a, + NLBL_CIPSOV4_A_MLSCATLOC); + cat_rem = nla_find_nested(nla_a, + NLBL_CIPSOV4_A_MLSCATREM); + if (cat_loc == NULL || cat_rem == NULL) + goto add_std_failure; + doi_def->map.std->cat.local[ + nla_get_u32(cat_loc)] = + nla_get_u32(cat_rem); + doi_def->map.std->cat.cipso[ + nla_get_u32(cat_rem)] = + nla_get_u32(cat_loc); + } + } + + ret_val = cipso_v4_doi_add(doi_def, audit_info); + if (ret_val != 0) + goto add_std_failure; + return 0; + +add_std_failure: + cipso_v4_doi_free(doi_def); + return ret_val; +} + +/** + * netlbl_cipsov4_add_pass - Adds a CIPSO V4 DOI definition + * @info: the Generic NETLINK info block + * @audit_info: NetLabel audit information + * + * Description: + * Create a new CIPSO_V4_MAP_PASS DOI definition based on the given ADD message + * and add it to the CIPSO V4 engine. Return zero on success and non-zero on + * error. + * + */ +static int netlbl_cipsov4_add_pass(struct genl_info *info, + struct netlbl_audit *audit_info) +{ + int ret_val; + struct cipso_v4_doi *doi_def = NULL; + + if (!info->attrs[NLBL_CIPSOV4_A_TAGLST]) + return -EINVAL; + + doi_def = kmalloc(sizeof(*doi_def), GFP_KERNEL); + if (doi_def == NULL) + return -ENOMEM; + doi_def->type = CIPSO_V4_MAP_PASS; + + ret_val = netlbl_cipsov4_add_common(info, doi_def); + if (ret_val != 0) + goto add_pass_failure; + + ret_val = cipso_v4_doi_add(doi_def, audit_info); + if (ret_val != 0) + goto add_pass_failure; + return 0; + +add_pass_failure: + cipso_v4_doi_free(doi_def); + return ret_val; +} + +/** + * netlbl_cipsov4_add_local - Adds a CIPSO V4 DOI definition + * @info: the Generic NETLINK info block + * @audit_info: NetLabel audit information + * + * Description: + * Create a new CIPSO_V4_MAP_LOCAL DOI definition based on the given ADD + * message and add it to the CIPSO V4 engine. Return zero on success and + * non-zero on error. 
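+ *
+ * A minimal ADD payload for this mapping type, sketched as attributes
+ * (the DOI value is made up and the tag shown is only an example of one
+ * that could be listed):
+ *
+ *   NLBL_CIPSOV4_A_DOI = 5
+ *   NLBL_CIPSOV4_A_MTYPE = CIPSO_V4_MAP_LOCAL
+ *   NLBL_CIPSOV4_A_TAGLST
+ *     NLBL_CIPSOV4_A_TAG = CIPSO_V4_TAG_LOCAL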
+ * + */ +static int netlbl_cipsov4_add_local(struct genl_info *info, + struct netlbl_audit *audit_info) +{ + int ret_val; + struct cipso_v4_doi *doi_def = NULL; + + if (!info->attrs[NLBL_CIPSOV4_A_TAGLST]) + return -EINVAL; + + doi_def = kmalloc(sizeof(*doi_def), GFP_KERNEL); + if (doi_def == NULL) + return -ENOMEM; + doi_def->type = CIPSO_V4_MAP_LOCAL; + + ret_val = netlbl_cipsov4_add_common(info, doi_def); + if (ret_val != 0) + goto add_local_failure; + + ret_val = cipso_v4_doi_add(doi_def, audit_info); + if (ret_val != 0) + goto add_local_failure; + return 0; + +add_local_failure: + cipso_v4_doi_free(doi_def); + return ret_val; +} + +/** + * netlbl_cipsov4_add - Handle an ADD message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Create a new DOI definition based on the given ADD message and add it to the + * CIPSO V4 engine. Returns zero on success, negative values on failure. + * + */ +static int netlbl_cipsov4_add(struct sk_buff *skb, struct genl_info *info) + +{ + int ret_val = -EINVAL; + struct netlbl_audit audit_info; + + if (!info->attrs[NLBL_CIPSOV4_A_DOI] || + !info->attrs[NLBL_CIPSOV4_A_MTYPE]) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + switch (nla_get_u32(info->attrs[NLBL_CIPSOV4_A_MTYPE])) { + case CIPSO_V4_MAP_TRANS: + ret_val = netlbl_cipsov4_add_std(info, &audit_info); + break; + case CIPSO_V4_MAP_PASS: + ret_val = netlbl_cipsov4_add_pass(info, &audit_info); + break; + case CIPSO_V4_MAP_LOCAL: + ret_val = netlbl_cipsov4_add_local(info, &audit_info); + break; + } + if (ret_val == 0) + atomic_inc(&netlabel_mgmt_protocount); + + return ret_val; +} + +/** + * netlbl_cipsov4_list - Handle a LIST message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated LIST message and respond accordingly. While the + * response message generated by the kernel is straightforward, determining + * before hand the size of the buffer to allocate is not (we have to generate + * the message to know the size). In order to keep this function sane what we + * do is allocate a buffer of NLMSG_GOODSIZE and try to fit the response in + * that size, if we fail then we restart with a larger buffer and try again. + * We continue in this manner until we hit a limit of failed attempts then we + * give up and just send an error message. Returns zero on success and + * negative values on error. 
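+ *
+ * In outline, the retry logic below amounts to (a sketch, not the exact
+ * code):
+ *
+ *   nlsze_mult = 1;
+ * list_start:
+ *   ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE * nlsze_mult, GFP_KERNEL);
+ *   ... build the reply ...
+ *   if a nla_put() fails and nlsze_mult < 4:
+ *           kfree_skb(ans_skb);
+ *           nlsze_mult *= 2;
+ *           goto list_start;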
+ * + */ +static int netlbl_cipsov4_list(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val; + struct sk_buff *ans_skb = NULL; + u32 nlsze_mult = 1; + void *data; + u32 doi; + struct nlattr *nla_a; + struct nlattr *nla_b; + struct cipso_v4_doi *doi_def; + u32 iter; + + if (!info->attrs[NLBL_CIPSOV4_A_DOI]) { + ret_val = -EINVAL; + goto list_failure; + } + +list_start: + ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE * nlsze_mult, GFP_KERNEL); + if (ans_skb == NULL) { + ret_val = -ENOMEM; + goto list_failure; + } + data = genlmsg_put_reply(ans_skb, info, &netlbl_cipsov4_gnl_family, + 0, NLBL_CIPSOV4_C_LIST); + if (data == NULL) { + ret_val = -ENOMEM; + goto list_failure; + } + + doi = nla_get_u32(info->attrs[NLBL_CIPSOV4_A_DOI]); + + rcu_read_lock(); + doi_def = cipso_v4_doi_getdef(doi); + if (doi_def == NULL) { + ret_val = -EINVAL; + goto list_failure_lock; + } + + ret_val = nla_put_u32(ans_skb, NLBL_CIPSOV4_A_MTYPE, doi_def->type); + if (ret_val != 0) + goto list_failure_lock; + + nla_a = nla_nest_start_noflag(ans_skb, NLBL_CIPSOV4_A_TAGLST); + if (nla_a == NULL) { + ret_val = -ENOMEM; + goto list_failure_lock; + } + for (iter = 0; + iter < CIPSO_V4_TAG_MAXCNT && + doi_def->tags[iter] != CIPSO_V4_TAG_INVALID; + iter++) { + ret_val = nla_put_u8(ans_skb, + NLBL_CIPSOV4_A_TAG, + doi_def->tags[iter]); + if (ret_val != 0) + goto list_failure_lock; + } + nla_nest_end(ans_skb, nla_a); + + switch (doi_def->type) { + case CIPSO_V4_MAP_TRANS: + nla_a = nla_nest_start_noflag(ans_skb, + NLBL_CIPSOV4_A_MLSLVLLST); + if (nla_a == NULL) { + ret_val = -ENOMEM; + goto list_failure_lock; + } + for (iter = 0; + iter < doi_def->map.std->lvl.local_size; + iter++) { + if (doi_def->map.std->lvl.local[iter] == + CIPSO_V4_INV_LVL) + continue; + + nla_b = nla_nest_start_noflag(ans_skb, + NLBL_CIPSOV4_A_MLSLVL); + if (nla_b == NULL) { + ret_val = -ENOMEM; + goto list_retry; + } + ret_val = nla_put_u32(ans_skb, + NLBL_CIPSOV4_A_MLSLVLLOC, + iter); + if (ret_val != 0) + goto list_retry; + ret_val = nla_put_u32(ans_skb, + NLBL_CIPSOV4_A_MLSLVLREM, + doi_def->map.std->lvl.local[iter]); + if (ret_val != 0) + goto list_retry; + nla_nest_end(ans_skb, nla_b); + } + nla_nest_end(ans_skb, nla_a); + + nla_a = nla_nest_start_noflag(ans_skb, + NLBL_CIPSOV4_A_MLSCATLST); + if (nla_a == NULL) { + ret_val = -ENOMEM; + goto list_retry; + } + for (iter = 0; + iter < doi_def->map.std->cat.local_size; + iter++) { + if (doi_def->map.std->cat.local[iter] == + CIPSO_V4_INV_CAT) + continue; + + nla_b = nla_nest_start_noflag(ans_skb, + NLBL_CIPSOV4_A_MLSCAT); + if (nla_b == NULL) { + ret_val = -ENOMEM; + goto list_retry; + } + ret_val = nla_put_u32(ans_skb, + NLBL_CIPSOV4_A_MLSCATLOC, + iter); + if (ret_val != 0) + goto list_retry; + ret_val = nla_put_u32(ans_skb, + NLBL_CIPSOV4_A_MLSCATREM, + doi_def->map.std->cat.local[iter]); + if (ret_val != 0) + goto list_retry; + nla_nest_end(ans_skb, nla_b); + } + nla_nest_end(ans_skb, nla_a); + + break; + } + cipso_v4_doi_putdef(doi_def); + rcu_read_unlock(); + + genlmsg_end(ans_skb, data); + return genlmsg_reply(ans_skb, info); + +list_retry: + /* XXX - this limit is a guesstimate */ + if (nlsze_mult < 4) { + cipso_v4_doi_putdef(doi_def); + rcu_read_unlock(); + kfree_skb(ans_skb); + nlsze_mult *= 2; + goto list_start; + } +list_failure_lock: + cipso_v4_doi_putdef(doi_def); + rcu_read_unlock(); +list_failure: + kfree_skb(ans_skb); + return ret_val; +} + +/** + * netlbl_cipsov4_listall_cb - cipso_v4_doi_walk() callback for LISTALL + * @doi_def: the CIPSOv4 DOI definition + * @arg: the 
netlbl_cipsov4_doiwalk_arg structure + * + * Description: + * This function is designed to be used as a callback to the + * cipso_v4_doi_walk() function for use in generating a response for a LISTALL + * message. Returns the size of the message on success, negative values on + * failure. + * + */ +static int netlbl_cipsov4_listall_cb(struct cipso_v4_doi *doi_def, void *arg) +{ + int ret_val = -ENOMEM; + struct netlbl_cipsov4_doiwalk_arg *cb_arg = arg; + void *data; + + data = genlmsg_put(cb_arg->skb, NETLINK_CB(cb_arg->nl_cb->skb).portid, + cb_arg->seq, &netlbl_cipsov4_gnl_family, + NLM_F_MULTI, NLBL_CIPSOV4_C_LISTALL); + if (data == NULL) + goto listall_cb_failure; + + ret_val = nla_put_u32(cb_arg->skb, NLBL_CIPSOV4_A_DOI, doi_def->doi); + if (ret_val != 0) + goto listall_cb_failure; + ret_val = nla_put_u32(cb_arg->skb, + NLBL_CIPSOV4_A_MTYPE, + doi_def->type); + if (ret_val != 0) + goto listall_cb_failure; + + genlmsg_end(cb_arg->skb, data); + return 0; + +listall_cb_failure: + genlmsg_cancel(cb_arg->skb, data); + return ret_val; +} + +/** + * netlbl_cipsov4_listall - Handle a LISTALL message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated LISTALL message and respond accordingly. Returns + * zero on success and negative values on error. + * + */ +static int netlbl_cipsov4_listall(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct netlbl_cipsov4_doiwalk_arg cb_arg; + u32 doi_skip = cb->args[0]; + + cb_arg.nl_cb = cb; + cb_arg.skb = skb; + cb_arg.seq = cb->nlh->nlmsg_seq; + + cipso_v4_doi_walk(&doi_skip, netlbl_cipsov4_listall_cb, &cb_arg); + + cb->args[0] = doi_skip; + return skb->len; +} + +/** + * netlbl_cipsov4_remove_cb - netlbl_cipsov4_remove() callback for REMOVE + * @entry: LSM domain mapping entry + * @arg: the netlbl_domhsh_walk_arg structure + * + * Description: + * This function is intended for use by netlbl_cipsov4_remove() as the callback + * for the netlbl_domhsh_walk() function; it removes LSM domain map entries + * which are associated with the CIPSO DOI specified in @arg. Returns zero on + * success, negative values on failure. + * + */ +static int netlbl_cipsov4_remove_cb(struct netlbl_dom_map *entry, void *arg) +{ + struct netlbl_domhsh_walk_arg *cb_arg = arg; + + if (entry->def.type == NETLBL_NLTYPE_CIPSOV4 && + entry->def.cipso->doi == cb_arg->doi) + return netlbl_domhsh_remove_entry(entry, cb_arg->audit_info); + + return 0; +} + +/** + * netlbl_cipsov4_remove - Handle a REMOVE message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated REMOVE message and respond accordingly. Returns + * zero on success, negative values on failure. 
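+ *
+ * Removal is a two step process (a sketch of the code below): first drop
+ * any LSM domain mappings that reference the DOI, then remove the DOI
+ * definition itself:
+ *
+ *   netlbl_domhsh_walk(&skip_bkt, &skip_chain,
+ *                      netlbl_cipsov4_remove_cb, &cb_arg);
+ *   cipso_v4_doi_remove(cb_arg.doi, &audit_info);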
+ * + */ +static int netlbl_cipsov4_remove(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -EINVAL; + struct netlbl_domhsh_walk_arg cb_arg; + struct netlbl_audit audit_info; + u32 skip_bkt = 0; + u32 skip_chain = 0; + + if (!info->attrs[NLBL_CIPSOV4_A_DOI]) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + cb_arg.doi = nla_get_u32(info->attrs[NLBL_CIPSOV4_A_DOI]); + cb_arg.audit_info = &audit_info; + ret_val = netlbl_domhsh_walk(&skip_bkt, &skip_chain, + netlbl_cipsov4_remove_cb, &cb_arg); + if (ret_val == 0 || ret_val == -ENOENT) { + ret_val = cipso_v4_doi_remove(cb_arg.doi, &audit_info); + if (ret_val == 0) + atomic_dec(&netlabel_mgmt_protocount); + } + + return ret_val; +} + +/* + * NetLabel Generic NETLINK Command Definitions + */ + +static const struct genl_small_ops netlbl_cipsov4_ops[] = { + { + .cmd = NLBL_CIPSOV4_C_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_cipsov4_add, + .dumpit = NULL, + }, + { + .cmd = NLBL_CIPSOV4_C_REMOVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_cipsov4_remove, + .dumpit = NULL, + }, + { + .cmd = NLBL_CIPSOV4_C_LIST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = netlbl_cipsov4_list, + .dumpit = NULL, + }, + { + .cmd = NLBL_CIPSOV4_C_LISTALL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_cipsov4_listall, + }, +}; + +static struct genl_family netlbl_cipsov4_gnl_family __ro_after_init = { + .hdrsize = 0, + .name = NETLBL_NLTYPE_CIPSOV4_NAME, + .version = NETLBL_PROTO_VERSION, + .maxattr = NLBL_CIPSOV4_A_MAX, + .policy = netlbl_cipsov4_genl_policy, + .module = THIS_MODULE, + .small_ops = netlbl_cipsov4_ops, + .n_small_ops = ARRAY_SIZE(netlbl_cipsov4_ops), + .resv_start_op = NLBL_CIPSOV4_C_LISTALL + 1, +}; + +/* + * NetLabel Generic NETLINK Protocol Functions + */ + +/** + * netlbl_cipsov4_genl_init - Register the CIPSOv4 NetLabel component + * + * Description: + * Register the CIPSOv4 packet NetLabel component with the Generic NETLINK + * mechanism. Returns zero on success, negative values on failure. + * + */ +int __init netlbl_cipsov4_genl_init(void) +{ + return genl_register_family(&netlbl_cipsov4_gnl_family); +} diff --git a/net/netlabel/netlabel_cipso_v4.h b/net/netlabel/netlabel_cipso_v4.h new file mode 100644 index 000000000..85d7ecb05 --- /dev/null +++ b/net/netlabel/netlabel_cipso_v4.h @@ -0,0 +1,155 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel CIPSO/IPv4 Support + * + * This file defines the CIPSO/IPv4 functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#ifndef _NETLABEL_CIPSO_V4 +#define _NETLABEL_CIPSO_V4 + +#include + +/* + * The following NetLabel payloads are supported by the CIPSO subsystem. + * + * o ADD: + * Sent by an application to add a new DOI mapping table. + * + * Required attributes: + * + * NLBL_CIPSOV4_A_DOI + * NLBL_CIPSOV4_A_MTYPE + * NLBL_CIPSOV4_A_TAGLST + * + * If using CIPSO_V4_MAP_TRANS the following attributes are required: + * + * NLBL_CIPSOV4_A_MLSLVLLST + * NLBL_CIPSOV4_A_MLSCATLST + * + * If using CIPSO_V4_MAP_PASS or CIPSO_V4_MAP_LOCAL no additional attributes + * are required. 
+ * + * o REMOVE: + * Sent by an application to remove a specific DOI mapping table from the + * CIPSO V4 system. + * + * Required attributes: + * + * NLBL_CIPSOV4_A_DOI + * + * o LIST: + * Sent by an application to list the details of a DOI definition. On + * success the kernel should send a response using the following format. + * + * Required attributes: + * + * NLBL_CIPSOV4_A_DOI + * + * The valid response message format depends on the type of the DOI mapping, + * the defined formats are shown below. + * + * Required attributes: + * + * NLBL_CIPSOV4_A_MTYPE + * NLBL_CIPSOV4_A_TAGLST + * + * If using CIPSO_V4_MAP_TRANS the following attributes are required: + * + * NLBL_CIPSOV4_A_MLSLVLLST + * NLBL_CIPSOV4_A_MLSCATLST + * + * If using CIPSO_V4_MAP_PASS or CIPSO_V4_MAP_LOCAL no additional attributes + * are required. + * + * o LISTALL: + * This message is sent by an application to list the valid DOIs on the + * system. When sent by an application there is no payload and the + * NLM_F_DUMP flag should be set. The kernel should respond with a series of + * the following messages. + * + * Required attributes: + * + * NLBL_CIPSOV4_A_DOI + * NLBL_CIPSOV4_A_MTYPE + * + */ + +/* NetLabel CIPSOv4 commands */ +enum { + NLBL_CIPSOV4_C_UNSPEC, + NLBL_CIPSOV4_C_ADD, + NLBL_CIPSOV4_C_REMOVE, + NLBL_CIPSOV4_C_LIST, + NLBL_CIPSOV4_C_LISTALL, + __NLBL_CIPSOV4_C_MAX, +}; + +/* NetLabel CIPSOv4 attributes */ +enum { + NLBL_CIPSOV4_A_UNSPEC, + NLBL_CIPSOV4_A_DOI, + /* (NLA_U32) + * the DOI value */ + NLBL_CIPSOV4_A_MTYPE, + /* (NLA_U32) + * the mapping table type (defined in the cipso_ipv4.h header as + * CIPSO_V4_MAP_*) */ + NLBL_CIPSOV4_A_TAG, + /* (NLA_U8) + * a CIPSO tag type, meant to be used within a NLBL_CIPSOV4_A_TAGLST + * attribute */ + NLBL_CIPSOV4_A_TAGLST, + /* (NLA_NESTED) + * the CIPSO tag list for the DOI, there must be at least one + * NLBL_CIPSOV4_A_TAG attribute, tags listed first are given higher + * priorirty when sending packets */ + NLBL_CIPSOV4_A_MLSLVLLOC, + /* (NLA_U32) + * the local MLS sensitivity level */ + NLBL_CIPSOV4_A_MLSLVLREM, + /* (NLA_U32) + * the remote MLS sensitivity level */ + NLBL_CIPSOV4_A_MLSLVL, + /* (NLA_NESTED) + * a MLS sensitivity level mapping, must contain only one attribute of + * each of the following types: NLBL_CIPSOV4_A_MLSLVLLOC and + * NLBL_CIPSOV4_A_MLSLVLREM */ + NLBL_CIPSOV4_A_MLSLVLLST, + /* (NLA_NESTED) + * the CIPSO level mappings, there must be at least one + * NLBL_CIPSOV4_A_MLSLVL attribute */ + NLBL_CIPSOV4_A_MLSCATLOC, + /* (NLA_U32) + * the local MLS category */ + NLBL_CIPSOV4_A_MLSCATREM, + /* (NLA_U32) + * the remote MLS category */ + NLBL_CIPSOV4_A_MLSCAT, + /* (NLA_NESTED) + * a MLS category mapping, must contain only one attribute of each of + * the following types: NLBL_CIPSOV4_A_MLSCATLOC and + * NLBL_CIPSOV4_A_MLSCATREM */ + NLBL_CIPSOV4_A_MLSCATLST, + /* (NLA_NESTED) + * the CIPSO category mappings, there must be at least one + * NLBL_CIPSOV4_A_MLSCAT attribute */ + __NLBL_CIPSOV4_A_MAX, +}; +#define NLBL_CIPSOV4_A_MAX (__NLBL_CIPSOV4_A_MAX - 1) + +/* NetLabel protocol functions */ +int netlbl_cipsov4_genl_init(void); + +/* Free the memory associated with a CIPSOv4 DOI definition */ +void netlbl_cipsov4_doi_free(struct rcu_head *entry); + +#endif diff --git a/net/netlabel/netlabel_domainhash.c b/net/netlabel/netlabel_domainhash.c new file mode 100644 index 000000000..8158a2597 --- /dev/null +++ b/net/netlabel/netlabel_domainhash.c @@ -0,0 +1,972 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel Domain Hash Table 
+ * + * This file manages the domain hash table that NetLabel uses to determine + * which network labeling protocol to use for a given domain. The NetLabel + * system manages static and dynamic label mappings for network protocols such + * as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_mgmt.h" +#include "netlabel_addrlist.h" +#include "netlabel_calipso.h" +#include "netlabel_domainhash.h" +#include "netlabel_user.h" + +struct netlbl_domhsh_tbl { + struct list_head *tbl; + u32 size; +}; + +/* Domain hash table */ +/* updates should be so rare that having one spinlock for the entire hash table + * should be okay */ +static DEFINE_SPINLOCK(netlbl_domhsh_lock); +#define netlbl_domhsh_rcu_deref(p) \ + rcu_dereference_check(p, lockdep_is_held(&netlbl_domhsh_lock)) +static struct netlbl_domhsh_tbl __rcu *netlbl_domhsh; +static struct netlbl_dom_map __rcu *netlbl_domhsh_def_ipv4; +static struct netlbl_dom_map __rcu *netlbl_domhsh_def_ipv6; + +/* + * Domain Hash Table Helper Functions + */ + +/** + * netlbl_domhsh_free_entry - Frees a domain hash table entry + * @entry: the entry's RCU field + * + * Description: + * This function is designed to be used as a callback to the call_rcu() + * function so that the memory allocated to a hash table entry can be released + * safely. + * + */ +static void netlbl_domhsh_free_entry(struct rcu_head *entry) +{ + struct netlbl_dom_map *ptr; + struct netlbl_af4list *iter4; + struct netlbl_af4list *tmp4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; + struct netlbl_af6list *tmp6; +#endif /* IPv6 */ + + ptr = container_of(entry, struct netlbl_dom_map, rcu); + if (ptr->def.type == NETLBL_NLTYPE_ADDRSELECT) { + netlbl_af4list_foreach_safe(iter4, tmp4, + &ptr->def.addrsel->list4) { + netlbl_af4list_remove_entry(iter4); + kfree(netlbl_domhsh_addr4_entry(iter4)); + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_safe(iter6, tmp6, + &ptr->def.addrsel->list6) { + netlbl_af6list_remove_entry(iter6); + kfree(netlbl_domhsh_addr6_entry(iter6)); + } +#endif /* IPv6 */ + kfree(ptr->def.addrsel); + } + kfree(ptr->domain); + kfree(ptr); +} + +/** + * netlbl_domhsh_hash - Hashing function for the domain hash table + * @key: the domain name to hash + * + * Description: + * This is the hashing function for the domain hash table, it returns the + * correct bucket number for the domain. The caller is responsible for + * ensuring that the hash table is protected with either a RCU read lock or the + * hash table lock. + * + */ +static u32 netlbl_domhsh_hash(const char *key) +{ + u32 iter; + u32 val; + u32 len; + + /* This is taken (with slight modification) from + * security/selinux/ss/symtab.c:symhash() */ + + for (iter = 0, val = 0, len = strlen(key); iter < len; iter++) + val = (val << 4 | (val >> (8 * sizeof(u32) - 4))) ^ key[iter]; + return val & (netlbl_domhsh_rcu_deref(netlbl_domhsh)->size - 1); +} + +static bool netlbl_family_match(u16 f1, u16 f2) +{ + return (f1 == f2) || (f1 == AF_UNSPEC) || (f2 == AF_UNSPEC); +} + +/** + * netlbl_domhsh_search - Search for a domain entry + * @domain: the domain + * @family: the address family + * + * Description: + * Searches the domain hash table and returns a pointer to the hash table + * entry if found, otherwise NULL is returned. 
@family may be %AF_UNSPEC + * which matches any address family entries. The caller is responsible for + * ensuring that the hash table is protected with either a RCU read lock or the + * hash table lock. + * + */ +static struct netlbl_dom_map *netlbl_domhsh_search(const char *domain, + u16 family) +{ + u32 bkt; + struct list_head *bkt_list; + struct netlbl_dom_map *iter; + + if (domain != NULL) { + bkt = netlbl_domhsh_hash(domain); + bkt_list = &netlbl_domhsh_rcu_deref(netlbl_domhsh)->tbl[bkt]; + list_for_each_entry_rcu(iter, bkt_list, list, + lockdep_is_held(&netlbl_domhsh_lock)) + if (iter->valid && + netlbl_family_match(iter->family, family) && + strcmp(iter->domain, domain) == 0) + return iter; + } + + return NULL; +} + +/** + * netlbl_domhsh_search_def - Search for a domain entry + * @domain: the domain + * @family: the address family + * + * Description: + * Searches the domain hash table and returns a pointer to the hash table + * entry if an exact match is found, if an exact match is not present in the + * hash table then the default entry is returned if valid otherwise NULL is + * returned. @family may be %AF_UNSPEC which matches any address family + * entries. The caller is responsible ensuring that the hash table is + * protected with either a RCU read lock or the hash table lock. + * + */ +static struct netlbl_dom_map *netlbl_domhsh_search_def(const char *domain, + u16 family) +{ + struct netlbl_dom_map *entry; + + entry = netlbl_domhsh_search(domain, family); + if (entry != NULL) + return entry; + if (family == AF_INET || family == AF_UNSPEC) { + entry = netlbl_domhsh_rcu_deref(netlbl_domhsh_def_ipv4); + if (entry != NULL && entry->valid) + return entry; + } + if (family == AF_INET6 || family == AF_UNSPEC) { + entry = netlbl_domhsh_rcu_deref(netlbl_domhsh_def_ipv6); + if (entry != NULL && entry->valid) + return entry; + } + + return NULL; +} + +/** + * netlbl_domhsh_audit_add - Generate an audit entry for an add event + * @entry: the entry being added + * @addr4: the IPv4 address information + * @addr6: the IPv6 address information + * @result: the result code + * @audit_info: NetLabel audit information + * + * Description: + * Generate an audit record for adding a new NetLabel/LSM mapping entry with + * the given information. Caller is responsible for holding the necessary + * locks. + * + */ +static void netlbl_domhsh_audit_add(struct netlbl_dom_map *entry, + struct netlbl_af4list *addr4, + struct netlbl_af6list *addr6, + int result, + struct netlbl_audit *audit_info) +{ + struct audit_buffer *audit_buf; + struct cipso_v4_doi *cipsov4 = NULL; + struct calipso_doi *calipso = NULL; + u32 type; + + audit_buf = netlbl_audit_start_common(AUDIT_MAC_MAP_ADD, audit_info); + if (audit_buf != NULL) { + audit_log_format(audit_buf, " nlbl_domain=%s", + entry->domain ? 
entry->domain : "(default)"); + if (addr4 != NULL) { + struct netlbl_domaddr4_map *map4; + map4 = netlbl_domhsh_addr4_entry(addr4); + type = map4->def.type; + cipsov4 = map4->def.cipso; + netlbl_af4list_audit_addr(audit_buf, 0, NULL, + addr4->addr, addr4->mask); +#if IS_ENABLED(CONFIG_IPV6) + } else if (addr6 != NULL) { + struct netlbl_domaddr6_map *map6; + map6 = netlbl_domhsh_addr6_entry(addr6); + type = map6->def.type; + calipso = map6->def.calipso; + netlbl_af6list_audit_addr(audit_buf, 0, NULL, + &addr6->addr, &addr6->mask); +#endif /* IPv6 */ + } else { + type = entry->def.type; + cipsov4 = entry->def.cipso; + calipso = entry->def.calipso; + } + switch (type) { + case NETLBL_NLTYPE_UNLABELED: + audit_log_format(audit_buf, " nlbl_protocol=unlbl"); + break; + case NETLBL_NLTYPE_CIPSOV4: + BUG_ON(cipsov4 == NULL); + audit_log_format(audit_buf, + " nlbl_protocol=cipsov4 cipso_doi=%u", + cipsov4->doi); + break; + case NETLBL_NLTYPE_CALIPSO: + BUG_ON(calipso == NULL); + audit_log_format(audit_buf, + " nlbl_protocol=calipso calipso_doi=%u", + calipso->doi); + break; + } + audit_log_format(audit_buf, " res=%u", result == 0 ? 1 : 0); + audit_log_end(audit_buf); + } +} + +/** + * netlbl_domhsh_validate - Validate a new domain mapping entry + * @entry: the entry to validate + * + * This function validates the new domain mapping entry to ensure that it is + * a valid entry. Returns zero on success, negative values on failure. + * + */ +static int netlbl_domhsh_validate(const struct netlbl_dom_map *entry) +{ + struct netlbl_af4list *iter4; + struct netlbl_domaddr4_map *map4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; + struct netlbl_domaddr6_map *map6; +#endif /* IPv6 */ + + if (entry == NULL) + return -EINVAL; + + if (entry->family != AF_INET && entry->family != AF_INET6 && + (entry->family != AF_UNSPEC || + entry->def.type != NETLBL_NLTYPE_UNLABELED)) + return -EINVAL; + + switch (entry->def.type) { + case NETLBL_NLTYPE_UNLABELED: + if (entry->def.cipso != NULL || entry->def.calipso != NULL || + entry->def.addrsel != NULL) + return -EINVAL; + break; + case NETLBL_NLTYPE_CIPSOV4: + if (entry->family != AF_INET || + entry->def.cipso == NULL) + return -EINVAL; + break; + case NETLBL_NLTYPE_CALIPSO: + if (entry->family != AF_INET6 || + entry->def.calipso == NULL) + return -EINVAL; + break; + case NETLBL_NLTYPE_ADDRSELECT: + netlbl_af4list_foreach(iter4, &entry->def.addrsel->list4) { + map4 = netlbl_domhsh_addr4_entry(iter4); + switch (map4->def.type) { + case NETLBL_NLTYPE_UNLABELED: + if (map4->def.cipso != NULL) + return -EINVAL; + break; + case NETLBL_NLTYPE_CIPSOV4: + if (map4->def.cipso == NULL) + return -EINVAL; + break; + default: + return -EINVAL; + } + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach(iter6, &entry->def.addrsel->list6) { + map6 = netlbl_domhsh_addr6_entry(iter6); + switch (map6->def.type) { + case NETLBL_NLTYPE_UNLABELED: + if (map6->def.calipso != NULL) + return -EINVAL; + break; + case NETLBL_NLTYPE_CALIPSO: + if (map6->def.calipso == NULL) + return -EINVAL; + break; + default: + return -EINVAL; + } + } +#endif /* IPv6 */ + break; + default: + return -EINVAL; + } + + return 0; +} + +/* + * Domain Hash Table Functions + */ + +/** + * netlbl_domhsh_init - Init for the domain hash + * @size: the number of bits to use for the hash buckets + * + * Description: + * Initializes the domain hash table, should be called only by + * netlbl_user_init() during initialization. Returns zero on success, non-zero + * values on error. 
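+ *
+ * Note that @size is an exponent rather than a bucket count; the code
+ * below allocates 1 << @size buckets, for example (the value here is
+ * illustrative):
+ *
+ *   netlbl_domhsh_init(7);    creates a table with 128 hash buckets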
+ * + */ +int __init netlbl_domhsh_init(u32 size) +{ + u32 iter; + struct netlbl_domhsh_tbl *hsh_tbl; + + if (size == 0) + return -EINVAL; + + hsh_tbl = kmalloc(sizeof(*hsh_tbl), GFP_KERNEL); + if (hsh_tbl == NULL) + return -ENOMEM; + hsh_tbl->size = 1 << size; + hsh_tbl->tbl = kcalloc(hsh_tbl->size, + sizeof(struct list_head), + GFP_KERNEL); + if (hsh_tbl->tbl == NULL) { + kfree(hsh_tbl); + return -ENOMEM; + } + for (iter = 0; iter < hsh_tbl->size; iter++) + INIT_LIST_HEAD(&hsh_tbl->tbl[iter]); + + spin_lock(&netlbl_domhsh_lock); + rcu_assign_pointer(netlbl_domhsh, hsh_tbl); + spin_unlock(&netlbl_domhsh_lock); + + return 0; +} + +/** + * netlbl_domhsh_add - Adds a entry to the domain hash table + * @entry: the entry to add + * @audit_info: NetLabel audit information + * + * Description: + * Adds a new entry to the domain hash table and handles any updates to the + * lower level protocol handler (i.e. CIPSO). @entry->family may be set to + * %AF_UNSPEC which will add an entry that matches all address families. This + * is only useful for the unlabelled type and will only succeed if there is no + * existing entry for any address family with the same domain. Returns zero + * on success, negative on failure. + * + */ +int netlbl_domhsh_add(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info) +{ + int ret_val = 0; + struct netlbl_dom_map *entry_old, *entry_b; + struct netlbl_af4list *iter4; + struct netlbl_af4list *tmp4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; + struct netlbl_af6list *tmp6; +#endif /* IPv6 */ + + ret_val = netlbl_domhsh_validate(entry); + if (ret_val != 0) + return ret_val; + + /* XXX - we can remove this RCU read lock as the spinlock protects the + * entire function, but before we do we need to fixup the + * netlbl_af[4,6]list RCU functions to do "the right thing" with + * respect to rcu_dereference() when only a spinlock is held. */ + rcu_read_lock(); + spin_lock(&netlbl_domhsh_lock); + if (entry->domain != NULL) + entry_old = netlbl_domhsh_search(entry->domain, entry->family); + else + entry_old = netlbl_domhsh_search_def(entry->domain, + entry->family); + if (entry_old == NULL) { + entry->valid = 1; + + if (entry->domain != NULL) { + u32 bkt = netlbl_domhsh_hash(entry->domain); + list_add_tail_rcu(&entry->list, + &rcu_dereference(netlbl_domhsh)->tbl[bkt]); + } else { + INIT_LIST_HEAD(&entry->list); + switch (entry->family) { + case AF_INET: + rcu_assign_pointer(netlbl_domhsh_def_ipv4, + entry); + break; + case AF_INET6: + rcu_assign_pointer(netlbl_domhsh_def_ipv6, + entry); + break; + case AF_UNSPEC: + if (entry->def.type != + NETLBL_NLTYPE_UNLABELED) { + ret_val = -EINVAL; + goto add_return; + } + entry_b = kzalloc(sizeof(*entry_b), GFP_ATOMIC); + if (entry_b == NULL) { + ret_val = -ENOMEM; + goto add_return; + } + entry_b->family = AF_INET6; + entry_b->def.type = NETLBL_NLTYPE_UNLABELED; + entry_b->valid = 1; + entry->family = AF_INET; + rcu_assign_pointer(netlbl_domhsh_def_ipv4, + entry); + rcu_assign_pointer(netlbl_domhsh_def_ipv6, + entry_b); + break; + default: + /* Already checked in + * netlbl_domhsh_validate(). 
*/ + ret_val = -EINVAL; + goto add_return; + } + } + + if (entry->def.type == NETLBL_NLTYPE_ADDRSELECT) { + netlbl_af4list_foreach_rcu(iter4, + &entry->def.addrsel->list4) + netlbl_domhsh_audit_add(entry, iter4, NULL, + ret_val, audit_info); +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, + &entry->def.addrsel->list6) + netlbl_domhsh_audit_add(entry, NULL, iter6, + ret_val, audit_info); +#endif /* IPv6 */ + } else + netlbl_domhsh_audit_add(entry, NULL, NULL, + ret_val, audit_info); + } else if (entry_old->def.type == NETLBL_NLTYPE_ADDRSELECT && + entry->def.type == NETLBL_NLTYPE_ADDRSELECT) { + struct list_head *old_list4; + struct list_head *old_list6; + + old_list4 = &entry_old->def.addrsel->list4; + old_list6 = &entry_old->def.addrsel->list6; + + /* we only allow the addition of address selectors if all of + * the selectors do not exist in the existing domain map */ + netlbl_af4list_foreach_rcu(iter4, &entry->def.addrsel->list4) + if (netlbl_af4list_search_exact(iter4->addr, + iter4->mask, + old_list4)) { + ret_val = -EEXIST; + goto add_return; + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, &entry->def.addrsel->list6) + if (netlbl_af6list_search_exact(&iter6->addr, + &iter6->mask, + old_list6)) { + ret_val = -EEXIST; + goto add_return; + } +#endif /* IPv6 */ + + netlbl_af4list_foreach_safe(iter4, tmp4, + &entry->def.addrsel->list4) { + netlbl_af4list_remove_entry(iter4); + iter4->valid = 1; + ret_val = netlbl_af4list_add(iter4, old_list4); + netlbl_domhsh_audit_add(entry_old, iter4, NULL, + ret_val, audit_info); + if (ret_val != 0) + goto add_return; + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_safe(iter6, tmp6, + &entry->def.addrsel->list6) { + netlbl_af6list_remove_entry(iter6); + iter6->valid = 1; + ret_val = netlbl_af6list_add(iter6, old_list6); + netlbl_domhsh_audit_add(entry_old, NULL, iter6, + ret_val, audit_info); + if (ret_val != 0) + goto add_return; + } +#endif /* IPv6 */ + /* cleanup the new entry since we've moved everything over */ + netlbl_domhsh_free_entry(&entry->rcu); + } else + ret_val = -EINVAL; + +add_return: + spin_unlock(&netlbl_domhsh_lock); + rcu_read_unlock(); + return ret_val; +} + +/** + * netlbl_domhsh_add_default - Adds the default entry to the domain hash table + * @entry: the entry to add + * @audit_info: NetLabel audit information + * + * Description: + * Adds a new default entry to the domain hash table and handles any updates + * to the lower level protocol handler (i.e. CIPSO). Returns zero on success, + * negative on failure. + * + */ +int netlbl_domhsh_add_default(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info) +{ + return netlbl_domhsh_add(entry, audit_info); +} + +/** + * netlbl_domhsh_remove_entry - Removes a given entry from the domain table + * @entry: the entry to remove + * @audit_info: NetLabel audit information + * + * Description: + * Removes an entry from the domain hash table and handles any updates to the + * lower level protocol handler (i.e. CIPSO). Caller is responsible for + * ensuring that the RCU read lock is held. Returns zero on success, negative + * on failure. 
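+ *
+ * A usage sketch (illustrative only; "domain" and "audit_info" are assumed
+ * to be supplied by the caller, and the lookup and removal are bracketed by
+ * the required RCU read lock; a NULL lookup result is tolerated because the
+ * removal helper returns -ENOENT for it):
+ *
+ *	struct netlbl_dom_map *entry;
+ *	int rc;
+ *
+ *	rcu_read_lock();
+ *	entry = netlbl_domhsh_getentry(domain, AF_INET);
+ *	rc = netlbl_domhsh_remove_entry(entry, audit_info);
+ *	rcu_read_unlock();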
+ * + */ +int netlbl_domhsh_remove_entry(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info) +{ + int ret_val = 0; + struct audit_buffer *audit_buf; + struct netlbl_af4list *iter4; + struct netlbl_domaddr4_map *map4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; + struct netlbl_domaddr6_map *map6; +#endif /* IPv6 */ + + if (entry == NULL) + return -ENOENT; + + spin_lock(&netlbl_domhsh_lock); + if (entry->valid) { + entry->valid = 0; + if (entry == rcu_dereference(netlbl_domhsh_def_ipv4)) + RCU_INIT_POINTER(netlbl_domhsh_def_ipv4, NULL); + else if (entry == rcu_dereference(netlbl_domhsh_def_ipv6)) + RCU_INIT_POINTER(netlbl_domhsh_def_ipv6, NULL); + else + list_del_rcu(&entry->list); + } else + ret_val = -ENOENT; + spin_unlock(&netlbl_domhsh_lock); + + if (ret_val) + return ret_val; + + audit_buf = netlbl_audit_start_common(AUDIT_MAC_MAP_DEL, audit_info); + if (audit_buf != NULL) { + audit_log_format(audit_buf, + " nlbl_domain=%s res=1", + entry->domain ? entry->domain : "(default)"); + audit_log_end(audit_buf); + } + + switch (entry->def.type) { + case NETLBL_NLTYPE_ADDRSELECT: + netlbl_af4list_foreach_rcu(iter4, &entry->def.addrsel->list4) { + map4 = netlbl_domhsh_addr4_entry(iter4); + cipso_v4_doi_putdef(map4->def.cipso); + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, &entry->def.addrsel->list6) { + map6 = netlbl_domhsh_addr6_entry(iter6); + calipso_doi_putdef(map6->def.calipso); + } +#endif /* IPv6 */ + break; + case NETLBL_NLTYPE_CIPSOV4: + cipso_v4_doi_putdef(entry->def.cipso); + break; +#if IS_ENABLED(CONFIG_IPV6) + case NETLBL_NLTYPE_CALIPSO: + calipso_doi_putdef(entry->def.calipso); + break; +#endif /* IPv6 */ + } + call_rcu(&entry->rcu, netlbl_domhsh_free_entry); + + return ret_val; +} + +/** + * netlbl_domhsh_remove_af4 - Removes an address selector entry + * @domain: the domain + * @addr: IPv4 address + * @mask: IPv4 address mask + * @audit_info: NetLabel audit information + * + * Description: + * Removes an individual address selector from a domain mapping and potentially + * the entire mapping if it is empty. Returns zero on success, negative values + * on failure. 
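+ *
+ * A usage sketch (illustrative only; the domain string, the 192.168.0.0/16
+ * selector and the audit_info pointer are assumed values):
+ *
+ *	struct in_addr addr = { .s_addr = htonl(0xc0a80000) };
+ *	struct in_addr mask = { .s_addr = htonl(0xffff0000) };
+ *	int rc;
+ *
+ *	rc = netlbl_domhsh_remove_af4("example_domain", &addr, &mask,
+ *				      audit_info);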
+ * + */ +int netlbl_domhsh_remove_af4(const char *domain, + const struct in_addr *addr, + const struct in_addr *mask, + struct netlbl_audit *audit_info) +{ + struct netlbl_dom_map *entry_map; + struct netlbl_af4list *entry_addr; + struct netlbl_af4list *iter4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; +#endif /* IPv6 */ + struct netlbl_domaddr4_map *entry; + + rcu_read_lock(); + + if (domain) + entry_map = netlbl_domhsh_search(domain, AF_INET); + else + entry_map = netlbl_domhsh_search_def(domain, AF_INET); + if (entry_map == NULL || + entry_map->def.type != NETLBL_NLTYPE_ADDRSELECT) + goto remove_af4_failure; + + spin_lock(&netlbl_domhsh_lock); + entry_addr = netlbl_af4list_remove(addr->s_addr, mask->s_addr, + &entry_map->def.addrsel->list4); + spin_unlock(&netlbl_domhsh_lock); + + if (entry_addr == NULL) + goto remove_af4_failure; + netlbl_af4list_foreach_rcu(iter4, &entry_map->def.addrsel->list4) + goto remove_af4_single_addr; +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, &entry_map->def.addrsel->list6) + goto remove_af4_single_addr; +#endif /* IPv6 */ + /* the domain mapping is empty so remove it from the mapping table */ + netlbl_domhsh_remove_entry(entry_map, audit_info); + +remove_af4_single_addr: + rcu_read_unlock(); + /* yick, we can't use call_rcu here because we don't have a rcu head + * pointer but hopefully this should be a rare case so the pause + * shouldn't be a problem */ + synchronize_rcu(); + entry = netlbl_domhsh_addr4_entry(entry_addr); + cipso_v4_doi_putdef(entry->def.cipso); + kfree(entry); + return 0; + +remove_af4_failure: + rcu_read_unlock(); + return -ENOENT; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_domhsh_remove_af6 - Removes an address selector entry + * @domain: the domain + * @addr: IPv6 address + * @mask: IPv6 address mask + * @audit_info: NetLabel audit information + * + * Description: + * Removes an individual address selector from a domain mapping and potentially + * the entire mapping if it is empty. Returns zero on success, negative values + * on failure. 
+ * + */ +int netlbl_domhsh_remove_af6(const char *domain, + const struct in6_addr *addr, + const struct in6_addr *mask, + struct netlbl_audit *audit_info) +{ + struct netlbl_dom_map *entry_map; + struct netlbl_af6list *entry_addr; + struct netlbl_af4list *iter4; + struct netlbl_af6list *iter6; + struct netlbl_domaddr6_map *entry; + + rcu_read_lock(); + + if (domain) + entry_map = netlbl_domhsh_search(domain, AF_INET6); + else + entry_map = netlbl_domhsh_search_def(domain, AF_INET6); + if (entry_map == NULL || + entry_map->def.type != NETLBL_NLTYPE_ADDRSELECT) + goto remove_af6_failure; + + spin_lock(&netlbl_domhsh_lock); + entry_addr = netlbl_af6list_remove(addr, mask, + &entry_map->def.addrsel->list6); + spin_unlock(&netlbl_domhsh_lock); + + if (entry_addr == NULL) + goto remove_af6_failure; + netlbl_af4list_foreach_rcu(iter4, &entry_map->def.addrsel->list4) + goto remove_af6_single_addr; + netlbl_af6list_foreach_rcu(iter6, &entry_map->def.addrsel->list6) + goto remove_af6_single_addr; + /* the domain mapping is empty so remove it from the mapping table */ + netlbl_domhsh_remove_entry(entry_map, audit_info); + +remove_af6_single_addr: + rcu_read_unlock(); + /* yick, we can't use call_rcu here because we don't have a rcu head + * pointer but hopefully this should be a rare case so the pause + * shouldn't be a problem */ + synchronize_rcu(); + entry = netlbl_domhsh_addr6_entry(entry_addr); + calipso_doi_putdef(entry->def.calipso); + kfree(entry); + return 0; + +remove_af6_failure: + rcu_read_unlock(); + return -ENOENT; +} +#endif /* IPv6 */ + +/** + * netlbl_domhsh_remove - Removes an entry from the domain hash table + * @domain: the domain to remove + * @family: address family + * @audit_info: NetLabel audit information + * + * Description: + * Removes an entry from the domain hash table and handles any updates to the + * lower level protocol handler (i.e. CIPSO). @family may be %AF_UNSPEC which + * removes all address family entries. Returns zero on success, negative on + * failure. + * + */ +int netlbl_domhsh_remove(const char *domain, u16 family, + struct netlbl_audit *audit_info) +{ + int ret_val = -EINVAL; + struct netlbl_dom_map *entry; + + rcu_read_lock(); + + if (family == AF_INET || family == AF_UNSPEC) { + if (domain) + entry = netlbl_domhsh_search(domain, AF_INET); + else + entry = netlbl_domhsh_search_def(domain, AF_INET); + ret_val = netlbl_domhsh_remove_entry(entry, audit_info); + if (ret_val && ret_val != -ENOENT) + goto done; + } + if (family == AF_INET6 || family == AF_UNSPEC) { + int ret_val2; + + if (domain) + entry = netlbl_domhsh_search(domain, AF_INET6); + else + entry = netlbl_domhsh_search_def(domain, AF_INET6); + ret_val2 = netlbl_domhsh_remove_entry(entry, audit_info); + if (ret_val2 != -ENOENT) + ret_val = ret_val2; + } +done: + rcu_read_unlock(); + + return ret_val; +} + +/** + * netlbl_domhsh_remove_default - Removes the default entry from the table + * @family: address family + * @audit_info: NetLabel audit information + * + * Description: + * Removes/resets the default entry corresponding to @family from the domain + * hash table and handles any updates to the lower level protocol handler + * (i.e. CIPSO). @family may be %AF_UNSPEC which removes all address family + * entries. Returns zero on success, negative on failure. 
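+ *
+ * A usage sketch (illustrative only): clearing both the IPv4 and the IPv6
+ * default mapping in a single call by passing the %AF_UNSPEC wildcard:
+ *
+ *	int rc = netlbl_domhsh_remove_default(AF_UNSPEC, audit_info);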
+ * + */ +int netlbl_domhsh_remove_default(u16 family, struct netlbl_audit *audit_info) +{ + return netlbl_domhsh_remove(NULL, family, audit_info); +} + +/** + * netlbl_domhsh_getentry - Get an entry from the domain hash table + * @domain: the domain name to search for + * @family: address family + * + * Description: + * Look through the domain hash table searching for an entry to match @domain, + * with address family @family, return a pointer to a copy of the entry or + * NULL. The caller is responsible for ensuring that rcu_read_[un]lock() is + * called. + * + */ +struct netlbl_dom_map *netlbl_domhsh_getentry(const char *domain, u16 family) +{ + if (family == AF_UNSPEC) + return NULL; + return netlbl_domhsh_search_def(domain, family); +} + +/** + * netlbl_domhsh_getentry_af4 - Get an entry from the domain hash table + * @domain: the domain name to search for + * @addr: the IP address to search for + * + * Description: + * Look through the domain hash table searching for an entry to match @domain + * and @addr, return a pointer to a copy of the entry or NULL. The caller is + * responsible for ensuring that rcu_read_[un]lock() is called. + * + */ +struct netlbl_dommap_def *netlbl_domhsh_getentry_af4(const char *domain, + __be32 addr) +{ + struct netlbl_dom_map *dom_iter; + struct netlbl_af4list *addr_iter; + + dom_iter = netlbl_domhsh_search_def(domain, AF_INET); + if (dom_iter == NULL) + return NULL; + + if (dom_iter->def.type != NETLBL_NLTYPE_ADDRSELECT) + return &dom_iter->def; + addr_iter = netlbl_af4list_search(addr, &dom_iter->def.addrsel->list4); + if (addr_iter == NULL) + return NULL; + return &(netlbl_domhsh_addr4_entry(addr_iter)->def); +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_domhsh_getentry_af6 - Get an entry from the domain hash table + * @domain: the domain name to search for + * @addr: the IP address to search for + * + * Description: + * Look through the domain hash table searching for an entry to match @domain + * and @addr, return a pointer to a copy of the entry or NULL. The caller is + * responsible for ensuring that rcu_read_[un]lock() is called. + * + */ +struct netlbl_dommap_def *netlbl_domhsh_getentry_af6(const char *domain, + const struct in6_addr *addr) +{ + struct netlbl_dom_map *dom_iter; + struct netlbl_af6list *addr_iter; + + dom_iter = netlbl_domhsh_search_def(domain, AF_INET6); + if (dom_iter == NULL) + return NULL; + + if (dom_iter->def.type != NETLBL_NLTYPE_ADDRSELECT) + return &dom_iter->def; + addr_iter = netlbl_af6list_search(addr, &dom_iter->def.addrsel->list6); + if (addr_iter == NULL) + return NULL; + return &(netlbl_domhsh_addr6_entry(addr_iter)->def); +} +#endif /* IPv6 */ + +/** + * netlbl_domhsh_walk - Iterate through the domain mapping hash table + * @skip_bkt: the number of buckets to skip at the start + * @skip_chain: the number of entries to skip in the first iterated bucket + * @callback: callback for each entry + * @cb_arg: argument for the callback function + * + * Description: + * Iterate over the domain mapping hash table, skipping the first @skip_bkt + * buckets and @skip_chain entries. For each entry in the table call + * @callback, if @callback returns a negative value stop 'walking' through the + * table and return. Updates the values in @skip_bkt and @skip_chain on + * return. Returns zero on success, negative values on failure. 
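+ *
+ * A usage sketch (illustrative only; the counting callback and the local
+ * variables are assumptions made for the example):
+ *
+ *	static int domhsh_count_cb(struct netlbl_dom_map *entry, void *arg)
+ *	{
+ *		u32 *cnt = arg;
+ *
+ *		(*cnt)++;
+ *		return 0;
+ *	}
+ *
+ *	u32 skip_bkt = 0;
+ *	u32 skip_chain = 0;
+ *	u32 cnt = 0;
+ *
+ *	netlbl_domhsh_walk(&skip_bkt, &skip_chain, domhsh_count_cb, &cnt);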
+ * + */ +int netlbl_domhsh_walk(u32 *skip_bkt, + u32 *skip_chain, + int (*callback) (struct netlbl_dom_map *entry, void *arg), + void *cb_arg) +{ + int ret_val = -ENOENT; + u32 iter_bkt; + struct list_head *iter_list; + struct netlbl_dom_map *iter_entry; + u32 chain_cnt = 0; + + rcu_read_lock(); + for (iter_bkt = *skip_bkt; + iter_bkt < rcu_dereference(netlbl_domhsh)->size; + iter_bkt++, chain_cnt = 0) { + iter_list = &rcu_dereference(netlbl_domhsh)->tbl[iter_bkt]; + list_for_each_entry_rcu(iter_entry, iter_list, list) + if (iter_entry->valid) { + if (chain_cnt++ < *skip_chain) + continue; + ret_val = callback(iter_entry, cb_arg); + if (ret_val < 0) { + chain_cnt--; + goto walk_return; + } + } + } + +walk_return: + rcu_read_unlock(); + *skip_bkt = iter_bkt; + *skip_chain = chain_cnt; + return ret_val; +} diff --git a/net/netlabel/netlabel_domainhash.h b/net/netlabel/netlabel_domainhash.h new file mode 100644 index 000000000..9f80972ae --- /dev/null +++ b/net/netlabel/netlabel_domainhash.h @@ -0,0 +1,106 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel Domain Hash Table + * + * This file manages the domain hash table that NetLabel uses to determine + * which network labeling protocol to use for a given domain. The NetLabel + * system manages static and dynamic label mappings for network protocols such + * as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + */ + +#ifndef _NETLABEL_DOMAINHASH_H +#define _NETLABEL_DOMAINHASH_H + +#include +#include +#include + +#include "netlabel_addrlist.h" + +/* Domain hash table size */ +/* XXX - currently this number is an uneducated guess */ +#define NETLBL_DOMHSH_BITSIZE 7 + +/* Domain mapping definition structures */ +struct netlbl_domaddr_map { + struct list_head list4; + struct list_head list6; +}; +struct netlbl_dommap_def { + u32 type; + union { + struct netlbl_domaddr_map *addrsel; + struct cipso_v4_doi *cipso; + struct calipso_doi *calipso; + }; +}; +#define netlbl_domhsh_addr4_entry(iter) \ + container_of(iter, struct netlbl_domaddr4_map, list) +struct netlbl_domaddr4_map { + struct netlbl_dommap_def def; + + struct netlbl_af4list list; +}; +#define netlbl_domhsh_addr6_entry(iter) \ + container_of(iter, struct netlbl_domaddr6_map, list) +struct netlbl_domaddr6_map { + struct netlbl_dommap_def def; + + struct netlbl_af6list list; +}; + +struct netlbl_dom_map { + char *domain; + u16 family; + struct netlbl_dommap_def def; + + u32 valid; + struct list_head list; + struct rcu_head rcu; +}; + +/* init function */ +int netlbl_domhsh_init(u32 size); + +/* Manipulate the domain hash table */ +int netlbl_domhsh_add(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info); +int netlbl_domhsh_add_default(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info); +int netlbl_domhsh_remove_entry(struct netlbl_dom_map *entry, + struct netlbl_audit *audit_info); +int netlbl_domhsh_remove_af4(const char *domain, + const struct in_addr *addr, + const struct in_addr *mask, + struct netlbl_audit *audit_info); +int netlbl_domhsh_remove_af6(const char *domain, + const struct in6_addr *addr, + const struct in6_addr *mask, + struct netlbl_audit *audit_info); +int netlbl_domhsh_remove(const char *domain, u16 family, + struct netlbl_audit *audit_info); +int netlbl_domhsh_remove_default(u16 family, struct netlbl_audit *audit_info); +struct netlbl_dom_map *netlbl_domhsh_getentry(const char *domain, u16 family); +struct netlbl_dommap_def 
*netlbl_domhsh_getentry_af4(const char *domain, + __be32 addr); +#if IS_ENABLED(CONFIG_IPV6) +struct netlbl_dommap_def *netlbl_domhsh_getentry_af6(const char *domain, + const struct in6_addr *addr); +int netlbl_domhsh_remove_af6(const char *domain, + const struct in6_addr *addr, + const struct in6_addr *mask, + struct netlbl_audit *audit_info); +#endif /* IPv6 */ + +int netlbl_domhsh_walk(u32 *skip_bkt, + u32 *skip_chain, + int (*callback) (struct netlbl_dom_map *entry, void *arg), + void *cb_arg); + +#endif diff --git a/net/netlabel/netlabel_kapi.c b/net/netlabel/netlabel_kapi.c new file mode 100644 index 000000000..27511c90a --- /dev/null +++ b/net/netlabel/netlabel_kapi.c @@ -0,0 +1,1526 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel Kernel API + * + * This file defines the kernel API for the NetLabel system. The NetLabel + * system manages static and dynamic label mappings for network protocols such + * as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_domainhash.h" +#include "netlabel_unlabeled.h" +#include "netlabel_cipso_v4.h" +#include "netlabel_calipso.h" +#include "netlabel_user.h" +#include "netlabel_mgmt.h" +#include "netlabel_addrlist.h" + +/* + * Configuration Functions + */ + +/** + * netlbl_cfg_map_del - Remove a NetLabel/LSM domain mapping + * @domain: the domain mapping to remove + * @family: address family + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Removes a NetLabel/LSM domain mapping. A @domain value of NULL causes the + * default domain mapping to be removed. Returns zero on success, negative + * values on failure. + * + */ +int netlbl_cfg_map_del(const char *domain, + u16 family, + const void *addr, + const void *mask, + struct netlbl_audit *audit_info) +{ + if (addr == NULL && mask == NULL) { + return netlbl_domhsh_remove(domain, family, audit_info); + } else if (addr != NULL && mask != NULL) { + switch (family) { + case AF_INET: + return netlbl_domhsh_remove_af4(domain, addr, mask, + audit_info); +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + return netlbl_domhsh_remove_af6(domain, addr, mask, + audit_info); +#endif /* IPv6 */ + default: + return -EPFNOSUPPORT; + } + } else + return -EINVAL; +} + +/** + * netlbl_cfg_unlbl_map_add - Add a new unlabeled mapping + * @domain: the domain mapping to add + * @family: address family + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Adds a new unlabeled NetLabel/LSM domain mapping. A @domain value of NULL + * causes a new default domain mapping to be added. Returns zero on success, + * negative values on failure. 
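+ *
+ * A usage sketch (illustrative only; a NULL domain selects the default
+ * mapping, NULL addr/mask means no address selector, and %AF_UNSPEC makes
+ * the entry match every address family, which gives the kind of catch-all
+ * "unlabeled traffic allowed by default" configuration):
+ *
+ *	int rc;
+ *
+ *	rc = netlbl_cfg_unlbl_map_add(NULL, AF_UNSPEC, NULL, NULL, audit_info);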
+ * + */ +int netlbl_cfg_unlbl_map_add(const char *domain, + u16 family, + const void *addr, + const void *mask, + struct netlbl_audit *audit_info) +{ + int ret_val = -ENOMEM; + struct netlbl_dom_map *entry; + struct netlbl_domaddr_map *addrmap = NULL; + struct netlbl_domaddr4_map *map4 = NULL; + struct netlbl_domaddr6_map *map6 = NULL; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (entry == NULL) + return -ENOMEM; + if (domain != NULL) { + entry->domain = kstrdup(domain, GFP_ATOMIC); + if (entry->domain == NULL) + goto cfg_unlbl_map_add_failure; + } + entry->family = family; + + if (addr == NULL && mask == NULL) + entry->def.type = NETLBL_NLTYPE_UNLABELED; + else if (addr != NULL && mask != NULL) { + addrmap = kzalloc(sizeof(*addrmap), GFP_ATOMIC); + if (addrmap == NULL) + goto cfg_unlbl_map_add_failure; + INIT_LIST_HEAD(&addrmap->list4); + INIT_LIST_HEAD(&addrmap->list6); + + switch (family) { + case AF_INET: { + const struct in_addr *addr4 = addr; + const struct in_addr *mask4 = mask; + map4 = kzalloc(sizeof(*map4), GFP_ATOMIC); + if (map4 == NULL) + goto cfg_unlbl_map_add_failure; + map4->def.type = NETLBL_NLTYPE_UNLABELED; + map4->list.addr = addr4->s_addr & mask4->s_addr; + map4->list.mask = mask4->s_addr; + map4->list.valid = 1; + ret_val = netlbl_af4list_add(&map4->list, + &addrmap->list4); + if (ret_val != 0) + goto cfg_unlbl_map_add_failure; + break; + } +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: { + const struct in6_addr *addr6 = addr; + const struct in6_addr *mask6 = mask; + map6 = kzalloc(sizeof(*map6), GFP_ATOMIC); + if (map6 == NULL) + goto cfg_unlbl_map_add_failure; + map6->def.type = NETLBL_NLTYPE_UNLABELED; + map6->list.addr = *addr6; + map6->list.addr.s6_addr32[0] &= mask6->s6_addr32[0]; + map6->list.addr.s6_addr32[1] &= mask6->s6_addr32[1]; + map6->list.addr.s6_addr32[2] &= mask6->s6_addr32[2]; + map6->list.addr.s6_addr32[3] &= mask6->s6_addr32[3]; + map6->list.mask = *mask6; + map6->list.valid = 1; + ret_val = netlbl_af6list_add(&map6->list, + &addrmap->list6); + if (ret_val != 0) + goto cfg_unlbl_map_add_failure; + break; + } +#endif /* IPv6 */ + default: + goto cfg_unlbl_map_add_failure; + } + + entry->def.addrsel = addrmap; + entry->def.type = NETLBL_NLTYPE_ADDRSELECT; + } else { + ret_val = -EINVAL; + goto cfg_unlbl_map_add_failure; + } + + ret_val = netlbl_domhsh_add(entry, audit_info); + if (ret_val != 0) + goto cfg_unlbl_map_add_failure; + + return 0; + +cfg_unlbl_map_add_failure: + kfree(entry->domain); + kfree(entry); + kfree(addrmap); + kfree(map4); + kfree(map6); + return ret_val; +} + + +/** + * netlbl_cfg_unlbl_static_add - Adds a new static label + * @net: network namespace + * @dev_name: interface name + * @addr: IP address in network byte order (struct in[6]_addr) + * @mask: address mask in network byte order (struct in[6]_addr) + * @family: address family + * @secid: LSM secid value for the entry + * @audit_info: NetLabel audit information + * + * Description: + * Adds a new NetLabel static label to be used when protocol provided labels + * are not present on incoming traffic. If @dev_name is NULL then the default + * interface will be used. Returns zero on success, negative values on failure. 
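+ *
+ * A usage sketch (illustrative only; the 10.0.0.1/32 address, the LSM
+ * supplied "secid" value and audit_info are assumptions, and the NULL
+ * device name selects the default interface entry):
+ *
+ *	struct in_addr addr = { .s_addr = htonl(0x0a000001) };
+ *	struct in_addr mask = { .s_addr = htonl(0xffffffff) };
+ *	int rc;
+ *
+ *	rc = netlbl_cfg_unlbl_static_add(&init_net, NULL, &addr, &mask,
+ *					 AF_INET, secid, audit_info);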
+ * + */ +int netlbl_cfg_unlbl_static_add(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u16 family, + u32 secid, + struct netlbl_audit *audit_info) +{ + u32 addr_len; + + switch (family) { + case AF_INET: + addr_len = sizeof(struct in_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + addr_len = sizeof(struct in6_addr); + break; +#endif /* IPv6 */ + default: + return -EPFNOSUPPORT; + } + + return netlbl_unlhsh_add(net, + dev_name, addr, mask, addr_len, + secid, audit_info); +} + +/** + * netlbl_cfg_unlbl_static_del - Removes an existing static label + * @net: network namespace + * @dev_name: interface name + * @addr: IP address in network byte order (struct in[6]_addr) + * @mask: address mask in network byte order (struct in[6]_addr) + * @family: address family + * @audit_info: NetLabel audit information + * + * Description: + * Removes an existing NetLabel static label used when protocol provided labels + * are not present on incoming traffic. If @dev_name is NULL then the default + * interface will be used. Returns zero on success, negative values on failure. + * + */ +int netlbl_cfg_unlbl_static_del(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u16 family, + struct netlbl_audit *audit_info) +{ + u32 addr_len; + + switch (family) { + case AF_INET: + addr_len = sizeof(struct in_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + addr_len = sizeof(struct in6_addr); + break; +#endif /* IPv6 */ + default: + return -EPFNOSUPPORT; + } + + return netlbl_unlhsh_remove(net, + dev_name, addr, mask, addr_len, + audit_info); +} + +/** + * netlbl_cfg_cipsov4_add - Add a new CIPSOv4 DOI definition + * @doi_def: CIPSO DOI definition + * @audit_info: NetLabel audit information + * + * Description: + * Add a new CIPSO DOI definition as defined by @doi_def. Returns zero on + * success and negative values on failure. + * + */ +int netlbl_cfg_cipsov4_add(struct cipso_v4_doi *doi_def, + struct netlbl_audit *audit_info) +{ + return cipso_v4_doi_add(doi_def, audit_info); +} + +/** + * netlbl_cfg_cipsov4_del - Remove an existing CIPSOv4 DOI definition + * @doi: CIPSO DOI + * @audit_info: NetLabel audit information + * + * Description: + * Remove an existing CIPSO DOI definition matching @doi. Returns zero on + * success and negative values on failure. + * + */ +void netlbl_cfg_cipsov4_del(u32 doi, struct netlbl_audit *audit_info) +{ + cipso_v4_doi_remove(doi, audit_info); +} + +/** + * netlbl_cfg_cipsov4_map_add - Add a new CIPSOv4 DOI mapping + * @doi: the CIPSO DOI + * @domain: the domain mapping to add + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Add a new NetLabel/LSM domain mapping for the given CIPSO DOI to the NetLabel + * subsystem. A @domain value of NULL adds a new default domain mapping. + * Returns zero on success, negative values on failure. 
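+ *
+ * A usage sketch (illustrative only; the DOI value 3 and the domain string
+ * are assumed, and the DOI must already be known to NetLabel, e.g. via a
+ * previous netlbl_cfg_cipsov4_add() call, otherwise -ENOENT is returned):
+ *
+ *	int rc;
+ *
+ *	rc = netlbl_cfg_cipsov4_map_add(3, "example_domain", NULL, NULL,
+ *					audit_info);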
+ * + */ +int netlbl_cfg_cipsov4_map_add(u32 doi, + const char *domain, + const struct in_addr *addr, + const struct in_addr *mask, + struct netlbl_audit *audit_info) +{ + int ret_val = -ENOMEM; + struct cipso_v4_doi *doi_def; + struct netlbl_dom_map *entry; + struct netlbl_domaddr_map *addrmap = NULL; + struct netlbl_domaddr4_map *addrinfo = NULL; + + doi_def = cipso_v4_doi_getdef(doi); + if (doi_def == NULL) + return -ENOENT; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (entry == NULL) + goto out_entry; + entry->family = AF_INET; + if (domain != NULL) { + entry->domain = kstrdup(domain, GFP_ATOMIC); + if (entry->domain == NULL) + goto out_domain; + } + + if (addr == NULL && mask == NULL) { + entry->def.cipso = doi_def; + entry->def.type = NETLBL_NLTYPE_CIPSOV4; + } else if (addr != NULL && mask != NULL) { + addrmap = kzalloc(sizeof(*addrmap), GFP_ATOMIC); + if (addrmap == NULL) + goto out_addrmap; + INIT_LIST_HEAD(&addrmap->list4); + INIT_LIST_HEAD(&addrmap->list6); + + addrinfo = kzalloc(sizeof(*addrinfo), GFP_ATOMIC); + if (addrinfo == NULL) + goto out_addrinfo; + addrinfo->def.cipso = doi_def; + addrinfo->def.type = NETLBL_NLTYPE_CIPSOV4; + addrinfo->list.addr = addr->s_addr & mask->s_addr; + addrinfo->list.mask = mask->s_addr; + addrinfo->list.valid = 1; + ret_val = netlbl_af4list_add(&addrinfo->list, &addrmap->list4); + if (ret_val != 0) + goto cfg_cipsov4_map_add_failure; + + entry->def.addrsel = addrmap; + entry->def.type = NETLBL_NLTYPE_ADDRSELECT; + } else { + ret_val = -EINVAL; + goto out_addrmap; + } + + ret_val = netlbl_domhsh_add(entry, audit_info); + if (ret_val != 0) + goto cfg_cipsov4_map_add_failure; + + return 0; + +cfg_cipsov4_map_add_failure: + kfree(addrinfo); +out_addrinfo: + kfree(addrmap); +out_addrmap: + kfree(entry->domain); +out_domain: + kfree(entry); +out_entry: + cipso_v4_doi_putdef(doi_def); + return ret_val; +} + +/** + * netlbl_cfg_calipso_add - Add a new CALIPSO DOI definition + * @doi_def: CALIPSO DOI definition + * @audit_info: NetLabel audit information + * + * Description: + * Add a new CALIPSO DOI definition as defined by @doi_def. Returns zero on + * success and negative values on failure. + * + */ +int netlbl_cfg_calipso_add(struct calipso_doi *doi_def, + struct netlbl_audit *audit_info) +{ +#if IS_ENABLED(CONFIG_IPV6) + return calipso_doi_add(doi_def, audit_info); +#else /* IPv6 */ + return -ENOSYS; +#endif /* IPv6 */ +} + +/** + * netlbl_cfg_calipso_del - Remove an existing CALIPSO DOI definition + * @doi: CALIPSO DOI + * @audit_info: NetLabel audit information + * + * Description: + * Remove an existing CALIPSO DOI definition matching @doi. Returns zero on + * success and negative values on failure. + * + */ +void netlbl_cfg_calipso_del(u32 doi, struct netlbl_audit *audit_info) +{ +#if IS_ENABLED(CONFIG_IPV6) + calipso_doi_remove(doi, audit_info); +#endif /* IPv6 */ +} + +/** + * netlbl_cfg_calipso_map_add - Add a new CALIPSO DOI mapping + * @doi: the CALIPSO DOI + * @domain: the domain mapping to add + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Add a new NetLabel/LSM domain mapping for the given CALIPSO DOI to the + * NetLabel subsystem. A @domain value of NULL adds a new default domain + * mapping. Returns zero on success, negative values on failure. 
+ * + */ +int netlbl_cfg_calipso_map_add(u32 doi, + const char *domain, + const struct in6_addr *addr, + const struct in6_addr *mask, + struct netlbl_audit *audit_info) +{ +#if IS_ENABLED(CONFIG_IPV6) + int ret_val = -ENOMEM; + struct calipso_doi *doi_def; + struct netlbl_dom_map *entry; + struct netlbl_domaddr_map *addrmap = NULL; + struct netlbl_domaddr6_map *addrinfo = NULL; + + doi_def = calipso_doi_getdef(doi); + if (doi_def == NULL) + return -ENOENT; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (entry == NULL) + goto out_entry; + entry->family = AF_INET6; + if (domain != NULL) { + entry->domain = kstrdup(domain, GFP_ATOMIC); + if (entry->domain == NULL) + goto out_domain; + } + + if (addr == NULL && mask == NULL) { + entry->def.calipso = doi_def; + entry->def.type = NETLBL_NLTYPE_CALIPSO; + } else if (addr != NULL && mask != NULL) { + addrmap = kzalloc(sizeof(*addrmap), GFP_ATOMIC); + if (addrmap == NULL) + goto out_addrmap; + INIT_LIST_HEAD(&addrmap->list4); + INIT_LIST_HEAD(&addrmap->list6); + + addrinfo = kzalloc(sizeof(*addrinfo), GFP_ATOMIC); + if (addrinfo == NULL) + goto out_addrinfo; + addrinfo->def.calipso = doi_def; + addrinfo->def.type = NETLBL_NLTYPE_CALIPSO; + addrinfo->list.addr = *addr; + addrinfo->list.addr.s6_addr32[0] &= mask->s6_addr32[0]; + addrinfo->list.addr.s6_addr32[1] &= mask->s6_addr32[1]; + addrinfo->list.addr.s6_addr32[2] &= mask->s6_addr32[2]; + addrinfo->list.addr.s6_addr32[3] &= mask->s6_addr32[3]; + addrinfo->list.mask = *mask; + addrinfo->list.valid = 1; + ret_val = netlbl_af6list_add(&addrinfo->list, &addrmap->list6); + if (ret_val != 0) + goto cfg_calipso_map_add_failure; + + entry->def.addrsel = addrmap; + entry->def.type = NETLBL_NLTYPE_ADDRSELECT; + } else { + ret_val = -EINVAL; + goto out_addrmap; + } + + ret_val = netlbl_domhsh_add(entry, audit_info); + if (ret_val != 0) + goto cfg_calipso_map_add_failure; + + return 0; + +cfg_calipso_map_add_failure: + kfree(addrinfo); +out_addrinfo: + kfree(addrmap); +out_addrmap: + kfree(entry->domain); +out_domain: + kfree(entry); +out_entry: + calipso_doi_putdef(doi_def); + return ret_val; +#else /* IPv6 */ + return -ENOSYS; +#endif /* IPv6 */ +} + +/* + * Security Attribute Functions + */ + +#define _CM_F_NONE 0x00000000 +#define _CM_F_ALLOC 0x00000001 +#define _CM_F_WALK 0x00000002 + +/** + * _netlbl_catmap_getnode - Get a individual node from a catmap + * @catmap: pointer to the category bitmap + * @offset: the requested offset + * @cm_flags: catmap flags, see _CM_F_* + * @gfp_flags: memory allocation flags + * + * Description: + * Iterate through the catmap looking for the node associated with @offset. + * If the _CM_F_ALLOC flag is set in @cm_flags and there is no associated node, + * one will be created and inserted into the catmap. If the _CM_F_WALK flag is + * set in @cm_flags and there is no associated node, the next highest node will + * be returned. Returns a pointer to the node on success, NULL on failure. 
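+ *
+ * A behaviour sketch (illustrative only; it assumes a 64-bit build where a
+ * single node covers 256 category bits):
+ *
+ *	node = _netlbl_catmap_getnode(&catmap, 300, _CM_F_ALLOC, GFP_ATOMIC);
+ *
+ * returns, creating it first if necessary, the node whose startbit is 256
+ * and whose range of bits 256 through 511 contains bit 300.  The same call
+ * with _CM_F_WALK instead never allocates: it returns the covering node if
+ * one exists, the next node above the offset if not, or NULL when neither
+ * exists.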
+ * + */ +static struct netlbl_lsm_catmap *_netlbl_catmap_getnode( + struct netlbl_lsm_catmap **catmap, + u32 offset, + unsigned int cm_flags, + gfp_t gfp_flags) +{ + struct netlbl_lsm_catmap *iter = *catmap; + struct netlbl_lsm_catmap *prev = NULL; + + if (iter == NULL) + goto catmap_getnode_alloc; + if (offset < iter->startbit) + goto catmap_getnode_walk; + while (iter && offset >= (iter->startbit + NETLBL_CATMAP_SIZE)) { + prev = iter; + iter = iter->next; + } + if (iter == NULL || offset < iter->startbit) + goto catmap_getnode_walk; + + return iter; + +catmap_getnode_walk: + if (cm_flags & _CM_F_WALK) + return iter; +catmap_getnode_alloc: + if (!(cm_flags & _CM_F_ALLOC)) + return NULL; + + iter = netlbl_catmap_alloc(gfp_flags); + if (iter == NULL) + return NULL; + iter->startbit = offset & ~(NETLBL_CATMAP_SIZE - 1); + + if (prev == NULL) { + iter->next = *catmap; + *catmap = iter; + } else { + iter->next = prev->next; + prev->next = iter; + } + + return iter; +} + +/** + * netlbl_catmap_walk - Walk a LSM secattr catmap looking for a bit + * @catmap: the category bitmap + * @offset: the offset to start searching at, in bits + * + * Description: + * This function walks a LSM secattr category bitmap starting at @offset and + * returns the spot of the first set bit or -ENOENT if no bits are set. + * + */ +int netlbl_catmap_walk(struct netlbl_lsm_catmap *catmap, u32 offset) +{ + struct netlbl_lsm_catmap *iter; + u32 idx; + u32 bit; + NETLBL_CATMAP_MAPTYPE bitmap; + + iter = _netlbl_catmap_getnode(&catmap, offset, _CM_F_WALK, 0); + if (iter == NULL) + return -ENOENT; + if (offset > iter->startbit) { + offset -= iter->startbit; + idx = offset / NETLBL_CATMAP_MAPSIZE; + bit = offset % NETLBL_CATMAP_MAPSIZE; + } else { + idx = 0; + bit = 0; + } + bitmap = iter->bitmap[idx] >> bit; + + for (;;) { + if (bitmap != 0) { + while ((bitmap & NETLBL_CATMAP_BIT) == 0) { + bitmap >>= 1; + bit++; + } + return iter->startbit + + (NETLBL_CATMAP_MAPSIZE * idx) + bit; + } + if (++idx >= NETLBL_CATMAP_MAPCNT) { + if (iter->next != NULL) { + iter = iter->next; + idx = 0; + } else + return -ENOENT; + } + bitmap = iter->bitmap[idx]; + bit = 0; + } + + return -ENOENT; +} +EXPORT_SYMBOL(netlbl_catmap_walk); + +/** + * netlbl_catmap_walkrng - Find the end of a string of set bits + * @catmap: the category bitmap + * @offset: the offset to start searching at, in bits + * + * Description: + * This function walks a LSM secattr category bitmap starting at @offset and + * returns the spot of the first cleared bit or -ENOENT if the offset is past + * the end of the bitmap. 
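+ *
+ * A usage sketch (illustrative only): pairing netlbl_catmap_walk() and
+ * netlbl_catmap_walkrng() to visit every contiguous run of set category
+ * bits, for example when building a ranged category encoding:
+ *
+ *	int start, end;
+ *
+ *	start = netlbl_catmap_walk(catmap, 0);
+ *	while (start >= 0) {
+ *		end = netlbl_catmap_walkrng(catmap, start);
+ *		if (end < 0)
+ *			break;
+ *		... handle the run from bit "start" to bit "end" ...
+ *		start = netlbl_catmap_walk(catmap, end + 1);
+ *	}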
+ * + */ +int netlbl_catmap_walkrng(struct netlbl_lsm_catmap *catmap, u32 offset) +{ + struct netlbl_lsm_catmap *iter; + struct netlbl_lsm_catmap *prev = NULL; + u32 idx; + u32 bit; + NETLBL_CATMAP_MAPTYPE bitmask; + NETLBL_CATMAP_MAPTYPE bitmap; + + iter = _netlbl_catmap_getnode(&catmap, offset, _CM_F_WALK, 0); + if (iter == NULL) + return -ENOENT; + if (offset > iter->startbit) { + offset -= iter->startbit; + idx = offset / NETLBL_CATMAP_MAPSIZE; + bit = offset % NETLBL_CATMAP_MAPSIZE; + } else { + idx = 0; + bit = 0; + } + bitmask = NETLBL_CATMAP_BIT << bit; + + for (;;) { + bitmap = iter->bitmap[idx]; + while (bitmask != 0 && (bitmap & bitmask) != 0) { + bitmask <<= 1; + bit++; + } + + if (prev && idx == 0 && bit == 0) + return prev->startbit + NETLBL_CATMAP_SIZE - 1; + else if (bitmask != 0) + return iter->startbit + + (NETLBL_CATMAP_MAPSIZE * idx) + bit - 1; + else if (++idx >= NETLBL_CATMAP_MAPCNT) { + if (iter->next == NULL) + return iter->startbit + NETLBL_CATMAP_SIZE - 1; + prev = iter; + iter = iter->next; + idx = 0; + } + bitmask = NETLBL_CATMAP_BIT; + bit = 0; + } + + return -ENOENT; +} + +/** + * netlbl_catmap_getlong - Export an unsigned long bitmap + * @catmap: pointer to the category bitmap + * @offset: pointer to the requested offset + * @bitmap: the exported bitmap + * + * Description: + * Export a bitmap with an offset greater than or equal to @offset and return + * it in @bitmap. The @offset must be aligned to an unsigned long and will be + * updated on return if different from what was requested; if the catmap is + * empty at the requested offset and beyond, the @offset is set to (u32)-1. + * Returns zero on success, negative values on failure. + * + */ +int netlbl_catmap_getlong(struct netlbl_lsm_catmap *catmap, + u32 *offset, + unsigned long *bitmap) +{ + struct netlbl_lsm_catmap *iter; + u32 off = *offset; + u32 idx; + + /* only allow aligned offsets */ + if ((off & (BITS_PER_LONG - 1)) != 0) + return -EINVAL; + + /* a null catmap is equivalent to an empty one */ + if (!catmap) { + *offset = (u32)-1; + return 0; + } + + if (off < catmap->startbit) { + off = catmap->startbit; + *offset = off; + } + iter = _netlbl_catmap_getnode(&catmap, off, _CM_F_WALK, 0); + if (iter == NULL) { + *offset = (u32)-1; + return 0; + } + + if (off < iter->startbit) { + *offset = iter->startbit; + off = 0; + } else + off -= iter->startbit; + idx = off / NETLBL_CATMAP_MAPSIZE; + *bitmap = iter->bitmap[idx] >> (off % NETLBL_CATMAP_MAPSIZE); + + return 0; +} + +/** + * netlbl_catmap_setbit - Set a bit in a LSM secattr catmap + * @catmap: pointer to the category bitmap + * @bit: the bit to set + * @flags: memory allocation flags + * + * Description: + * Set the bit specified by @bit in @catmap. Returns zero on success, + * negative values on failure. + * + */ +int netlbl_catmap_setbit(struct netlbl_lsm_catmap **catmap, + u32 bit, + gfp_t flags) +{ + struct netlbl_lsm_catmap *iter; + u32 idx; + + iter = _netlbl_catmap_getnode(catmap, bit, _CM_F_ALLOC, flags); + if (iter == NULL) + return -ENOMEM; + + bit -= iter->startbit; + idx = bit / NETLBL_CATMAP_MAPSIZE; + iter->bitmap[idx] |= NETLBL_CATMAP_BIT << (bit % NETLBL_CATMAP_MAPSIZE); + + return 0; +} +EXPORT_SYMBOL(netlbl_catmap_setbit); + +/** + * netlbl_catmap_setrng - Set a range of bits in a LSM secattr catmap + * @catmap: pointer to the category bitmap + * @start: the starting bit + * @end: the last bit in the string + * @flags: memory allocation flags + * + * Description: + * Set a range of bits, starting at @start and ending with @end. 
Returns zero + * on success, negative values on failure. + * + */ +int netlbl_catmap_setrng(struct netlbl_lsm_catmap **catmap, + u32 start, + u32 end, + gfp_t flags) +{ + int rc = 0; + u32 spot = start; + + while (rc == 0 && spot <= end) { + if (((spot & (BITS_PER_LONG - 1)) == 0) && + ((end - spot) > BITS_PER_LONG)) { + rc = netlbl_catmap_setlong(catmap, + spot, + (unsigned long)-1, + flags); + spot += BITS_PER_LONG; + } else + rc = netlbl_catmap_setbit(catmap, spot++, flags); + } + + return rc; +} + +/** + * netlbl_catmap_setlong - Import an unsigned long bitmap + * @catmap: pointer to the category bitmap + * @offset: offset to the start of the imported bitmap + * @bitmap: the bitmap to import + * @flags: memory allocation flags + * + * Description: + * Import the bitmap specified in @bitmap into @catmap, using the offset + * in @offset. The offset must be aligned to an unsigned long. Returns zero + * on success, negative values on failure. + * + */ +int netlbl_catmap_setlong(struct netlbl_lsm_catmap **catmap, + u32 offset, + unsigned long bitmap, + gfp_t flags) +{ + struct netlbl_lsm_catmap *iter; + u32 idx; + + /* only allow aligned offsets */ + if ((offset & (BITS_PER_LONG - 1)) != 0) + return -EINVAL; + + iter = _netlbl_catmap_getnode(catmap, offset, _CM_F_ALLOC, flags); + if (iter == NULL) + return -ENOMEM; + + offset -= iter->startbit; + idx = offset / NETLBL_CATMAP_MAPSIZE; + iter->bitmap[idx] |= (NETLBL_CATMAP_MAPTYPE)bitmap + << (offset % NETLBL_CATMAP_MAPSIZE); + + return 0; +} + +/* Bitmap functions + */ + +/** + * netlbl_bitmap_walk - Walk a bitmap looking for a bit + * @bitmap: the bitmap + * @bitmap_len: length in bits + * @offset: starting offset + * @state: if non-zero, look for a set (1) bit else look for a cleared (0) bit + * + * Description: + * Starting at @offset, walk the bitmap from left to right until either the + * desired bit is found or we reach the end. Return the bit offset, -1 if + * not found, or -2 if error. + */ +int netlbl_bitmap_walk(const unsigned char *bitmap, u32 bitmap_len, + u32 offset, u8 state) +{ + u32 bit_spot; + u32 byte_offset; + unsigned char bitmask; + unsigned char byte; + + if (offset >= bitmap_len) + return -1; + byte_offset = offset / 8; + byte = bitmap[byte_offset]; + bit_spot = offset; + bitmask = 0x80 >> (offset % 8); + + while (bit_spot < bitmap_len) { + if ((state && (byte & bitmask) == bitmask) || + (state == 0 && (byte & bitmask) == 0)) + return bit_spot; + + if (++bit_spot >= bitmap_len) + return -1; + bitmask >>= 1; + if (bitmask == 0) { + byte = bitmap[++byte_offset]; + bitmask = 0x80; + } + } + + return -1; +} +EXPORT_SYMBOL(netlbl_bitmap_walk); + +/** + * netlbl_bitmap_setbit - Sets a single bit in a bitmap + * @bitmap: the bitmap + * @bit: the bit + * @state: if non-zero, set the bit (1) else clear the bit (0) + * + * Description: + * Set a single bit in the bitmask. Returns zero on success, negative values + * on error. + */ +void netlbl_bitmap_setbit(unsigned char *bitmap, u32 bit, u8 state) +{ + u32 byte_spot; + u8 bitmask; + + /* gcc always rounds to zero when doing integer division */ + byte_spot = bit / 8; + bitmask = 0x80 >> (bit % 8); + if (state) + bitmap[byte_spot] |= bitmask; + else + bitmap[byte_spot] &= ~bitmask; +} +EXPORT_SYMBOL(netlbl_bitmap_setbit); + +/* + * LSM Functions + */ + +/** + * netlbl_enabled - Determine if the NetLabel subsystem is enabled + * + * Description: + * The LSM can use this function to determine if it should use NetLabel + * security attributes in it's enforcement mechanism. 
Currently, NetLabel is + * considered to be enabled when it's configuration contains a valid setup for + * at least one labeled protocol (i.e. NetLabel can understand incoming + * labeled packets of at least one type); otherwise NetLabel is considered to + * be disabled. + * + */ +int netlbl_enabled(void) +{ + /* At some point we probably want to expose this mechanism to the user + * as well so that admins can toggle NetLabel regardless of the + * configuration */ + return (atomic_read(&netlabel_mgmt_protocount) > 0); +} + +/** + * netlbl_sock_setattr - Label a socket using the correct protocol + * @sk: the socket to label + * @family: protocol family + * @secattr: the security attributes + * + * Description: + * Attach the correct label to the given socket using the security attributes + * specified in @secattr. This function requires exclusive access to @sk, + * which means it either needs to be in the process of being created or locked. + * Returns zero on success, -EDESTADDRREQ if the domain is configured to use + * network address selectors (can't blindly label the socket), and negative + * values on all other failures. + * + */ +int netlbl_sock_setattr(struct sock *sk, + u16 family, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct netlbl_dom_map *dom_entry; + + rcu_read_lock(); + dom_entry = netlbl_domhsh_getentry(secattr->domain, family); + if (dom_entry == NULL) { + ret_val = -ENOENT; + goto socket_setattr_return; + } + switch (family) { + case AF_INET: + switch (dom_entry->def.type) { + case NETLBL_NLTYPE_ADDRSELECT: + ret_val = -EDESTADDRREQ; + break; + case NETLBL_NLTYPE_CIPSOV4: + ret_val = cipso_v4_sock_setattr(sk, + dom_entry->def.cipso, + secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + switch (dom_entry->def.type) { + case NETLBL_NLTYPE_ADDRSELECT: + ret_val = -EDESTADDRREQ; + break; + case NETLBL_NLTYPE_CALIPSO: + ret_val = calipso_sock_setattr(sk, + dom_entry->def.calipso, + secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#endif /* IPv6 */ + default: + ret_val = -EPROTONOSUPPORT; + } + +socket_setattr_return: + rcu_read_unlock(); + return ret_val; +} + +/** + * netlbl_sock_delattr - Delete all the NetLabel labels on a socket + * @sk: the socket + * + * Description: + * Remove all the NetLabel labeling from @sk. The caller is responsible for + * ensuring that @sk is locked. + * + */ +void netlbl_sock_delattr(struct sock *sk) +{ + switch (sk->sk_family) { + case AF_INET: + cipso_v4_sock_delattr(sk); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + calipso_sock_delattr(sk); + break; +#endif /* IPv6 */ + } +} + +/** + * netlbl_sock_getattr - Determine the security attributes of a sock + * @sk: the sock + * @secattr: the security attributes + * + * Description: + * Examines the given sock to see if any NetLabel style labeling has been + * applied to the sock, if so it parses the socket label and returns the + * security attributes in @secattr. Returns zero on success, negative values + * on failure. 
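+ *
+ * A usage sketch (illustrative only; the secattr init/destroy helpers are
+ * the ones provided by the NetLabel header, and what the caller does with
+ * the returned attributes is LSM specific):
+ *
+ *	struct netlbl_lsm_secattr secattr;
+ *	int rc;
+ *
+ *	netlbl_lsm_secattr_init(&secattr);
+ *	rc = netlbl_sock_getattr(sk, &secattr);
+ *	if (rc == 0)
+ *		... translate secattr into the LSM's own label ...
+ *	netlbl_lsm_secattr_destroy(&secattr);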
+ * + */ +int netlbl_sock_getattr(struct sock *sk, + struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + + switch (sk->sk_family) { + case AF_INET: + ret_val = cipso_v4_sock_getattr(sk, secattr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ret_val = calipso_sock_getattr(sk, secattr); + break; +#endif /* IPv6 */ + default: + ret_val = -EPROTONOSUPPORT; + } + + return ret_val; +} + +/** + * netlbl_conn_setattr - Label a connected socket using the correct protocol + * @sk: the socket to label + * @addr: the destination address + * @secattr: the security attributes + * + * Description: + * Attach the correct label to the given connected socket using the security + * attributes specified in @secattr. The caller is responsible for ensuring + * that @sk is locked. Returns zero on success, negative values on failure. + * + */ +int netlbl_conn_setattr(struct sock *sk, + struct sockaddr *addr, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct sockaddr_in *addr4; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 *addr6; +#endif + struct netlbl_dommap_def *entry; + + rcu_read_lock(); + switch (addr->sa_family) { + case AF_INET: + addr4 = (struct sockaddr_in *)addr; + entry = netlbl_domhsh_getentry_af4(secattr->domain, + addr4->sin_addr.s_addr); + if (entry == NULL) { + ret_val = -ENOENT; + goto conn_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CIPSOV4: + ret_val = cipso_v4_sock_setattr(sk, + entry->cipso, secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + /* just delete the protocols we support for right now + * but we could remove other protocols if needed */ + netlbl_sock_delattr(sk); + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + addr6 = (struct sockaddr_in6 *)addr; + entry = netlbl_domhsh_getentry_af6(secattr->domain, + &addr6->sin6_addr); + if (entry == NULL) { + ret_val = -ENOENT; + goto conn_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CALIPSO: + ret_val = calipso_sock_setattr(sk, + entry->calipso, secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + /* just delete the protocols we support for right now + * but we could remove other protocols if needed */ + netlbl_sock_delattr(sk); + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#endif /* IPv6 */ + default: + ret_val = -EPROTONOSUPPORT; + } + +conn_setattr_return: + rcu_read_unlock(); + return ret_val; +} + +/** + * netlbl_req_setattr - Label a request socket using the correct protocol + * @req: the request socket to label + * @secattr: the security attributes + * + * Description: + * Attach the correct label to the given socket using the security attributes + * specified in @secattr. Returns zero on success, negative values on failure. 
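+ *
+ * A usage sketch (illustrative only; an LSM would typically make this call
+ * from its inet_conn_request hook with a secattr it has already populated
+ * for the incoming connection, so both names here are assumptions):
+ *
+ *	struct netlbl_lsm_secattr secattr;
+ *	int rc;
+ *
+ *	... fill in secattr from the LSM's label for the request ...
+ *	rc = netlbl_req_setattr(req, &secattr);
+ *	if (rc != 0)
+ *		... apply LSM policy, e.g. drop the request ...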
+ * + */ +int netlbl_req_setattr(struct request_sock *req, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct netlbl_dommap_def *entry; + struct inet_request_sock *ireq = inet_rsk(req); + + rcu_read_lock(); + switch (req->rsk_ops->family) { + case AF_INET: + entry = netlbl_domhsh_getentry_af4(secattr->domain, + ireq->ir_rmt_addr); + if (entry == NULL) { + ret_val = -ENOENT; + goto req_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CIPSOV4: + ret_val = cipso_v4_req_setattr(req, + entry->cipso, secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + netlbl_req_delattr(req); + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + entry = netlbl_domhsh_getentry_af6(secattr->domain, + &ireq->ir_v6_rmt_addr); + if (entry == NULL) { + ret_val = -ENOENT; + goto req_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CALIPSO: + ret_val = calipso_req_setattr(req, + entry->calipso, secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + netlbl_req_delattr(req); + ret_val = 0; + break; + default: + ret_val = -ENOENT; + } + break; +#endif /* IPv6 */ + default: + ret_val = -EPROTONOSUPPORT; + } + +req_setattr_return: + rcu_read_unlock(); + return ret_val; +} + +/** +* netlbl_req_delattr - Delete all the NetLabel labels on a socket +* @req: the socket +* +* Description: +* Remove all the NetLabel labeling from @req. +* +*/ +void netlbl_req_delattr(struct request_sock *req) +{ + switch (req->rsk_ops->family) { + case AF_INET: + cipso_v4_req_delattr(req); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + calipso_req_delattr(req); + break; +#endif /* IPv6 */ + } +} + +/** + * netlbl_skbuff_setattr - Label a packet using the correct protocol + * @skb: the packet + * @family: protocol family + * @secattr: the security attributes + * + * Description: + * Attach the correct label to the given packet using the security attributes + * specified in @secattr. Returns zero on success, negative values on failure. 
+ * + */ +int netlbl_skbuff_setattr(struct sk_buff *skb, + u16 family, + const struct netlbl_lsm_secattr *secattr) +{ + int ret_val; + struct iphdr *hdr4; +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *hdr6; +#endif + struct netlbl_dommap_def *entry; + + rcu_read_lock(); + switch (family) { + case AF_INET: + hdr4 = ip_hdr(skb); + entry = netlbl_domhsh_getentry_af4(secattr->domain, + hdr4->daddr); + if (entry == NULL) { + ret_val = -ENOENT; + goto skbuff_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CIPSOV4: + ret_val = cipso_v4_skbuff_setattr(skb, entry->cipso, + secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + /* just delete the protocols we support for right now + * but we could remove other protocols if needed */ + ret_val = cipso_v4_skbuff_delattr(skb); + break; + default: + ret_val = -ENOENT; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + hdr6 = ipv6_hdr(skb); + entry = netlbl_domhsh_getentry_af6(secattr->domain, + &hdr6->daddr); + if (entry == NULL) { + ret_val = -ENOENT; + goto skbuff_setattr_return; + } + switch (entry->type) { + case NETLBL_NLTYPE_CALIPSO: + ret_val = calipso_skbuff_setattr(skb, entry->calipso, + secattr); + break; + case NETLBL_NLTYPE_UNLABELED: + /* just delete the protocols we support for right now + * but we could remove other protocols if needed */ + ret_val = calipso_skbuff_delattr(skb); + break; + default: + ret_val = -ENOENT; + } + break; +#endif /* IPv6 */ + default: + ret_val = -EPROTONOSUPPORT; + } + +skbuff_setattr_return: + rcu_read_unlock(); + return ret_val; +} + +/** + * netlbl_skbuff_getattr - Determine the security attributes of a packet + * @skb: the packet + * @family: protocol family + * @secattr: the security attributes + * + * Description: + * Examines the given packet to see if a recognized form of packet labeling + * is present, if so it parses the packet label and returns the security + * attributes in @secattr. Returns zero on success, negative values on + * failure. + * + */ +int netlbl_skbuff_getattr(const struct sk_buff *skb, + u16 family, + struct netlbl_lsm_secattr *secattr) +{ + unsigned char *ptr; + + switch (family) { + case AF_INET: + ptr = cipso_v4_optptr(skb); + if (ptr && cipso_v4_getattr(ptr, secattr) == 0) + return 0; + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ptr = calipso_optptr(skb); + if (ptr && calipso_getattr(ptr, secattr) == 0) + return 0; + break; +#endif /* IPv6 */ + } + + return netlbl_unlabel_getattr(skb, family, secattr); +} + +/** + * netlbl_skbuff_err - Handle a LSM error on a sk_buff + * @skb: the packet + * @family: the family + * @error: the error code + * @gateway: true if host is acting as a gateway, false otherwise + * + * Description: + * Deal with a LSM problem when handling the packet in @skb, typically this is + * a permission denied problem (-EACCES). The correct action is determined + * according to the packet's labeling protocol. + * + */ +void netlbl_skbuff_err(struct sk_buff *skb, u16 family, int error, int gateway) +{ + switch (family) { + case AF_INET: + if (cipso_v4_optptr(skb)) + cipso_v4_error(skb, error, gateway); + break; + } +} + +/** + * netlbl_cache_invalidate - Invalidate all of the NetLabel protocol caches + * + * Description: + * For all of the NetLabel protocols that support some form of label mapping + * cache, invalidate the cache. Returns zero on success, negative values on + * error. 
+ * + */ +void netlbl_cache_invalidate(void) +{ + cipso_v4_cache_invalidate(); +#if IS_ENABLED(CONFIG_IPV6) + calipso_cache_invalidate(); +#endif /* IPv6 */ +} + +/** + * netlbl_cache_add - Add an entry to a NetLabel protocol cache + * @skb: the packet + * @family: the family + * @secattr: the packet's security attributes + * + * Description: + * Add the LSM security attributes for the given packet to the underlying + * NetLabel protocol's label mapping cache. Returns zero on success, negative + * values on error. + * + */ +int netlbl_cache_add(const struct sk_buff *skb, u16 family, + const struct netlbl_lsm_secattr *secattr) +{ + unsigned char *ptr; + + if ((secattr->flags & NETLBL_SECATTR_CACHE) == 0) + return -ENOMSG; + + switch (family) { + case AF_INET: + ptr = cipso_v4_optptr(skb); + if (ptr) + return cipso_v4_cache_add(ptr, secattr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + ptr = calipso_optptr(skb); + if (ptr) + return calipso_cache_add(ptr, secattr); + break; +#endif /* IPv6 */ + } + return -ENOMSG; +} + +/* + * Protocol Engine Functions + */ + +/** + * netlbl_audit_start - Start an audit message + * @type: audit message type + * @audit_info: NetLabel audit information + * + * Description: + * Start an audit message using the type specified in @type and fill the audit + * message with some fields common to all NetLabel audit messages. This + * function should only be used by protocol engines, not LSMs. Returns a + * pointer to the audit buffer on success, NULL on failure. + * + */ +struct audit_buffer *netlbl_audit_start(int type, + struct netlbl_audit *audit_info) +{ + return netlbl_audit_start_common(type, audit_info); +} +EXPORT_SYMBOL(netlbl_audit_start); + +/* + * Setup Functions + */ + +/** + * netlbl_init - Initialize NetLabel + * + * Description: + * Perform the required NetLabel initialization before first use. + * + */ +static int __init netlbl_init(void) +{ + int ret_val; + + printk(KERN_INFO "NetLabel: Initializing\n"); + printk(KERN_INFO "NetLabel: domain hash size = %u\n", + (1 << NETLBL_DOMHSH_BITSIZE)); + printk(KERN_INFO "NetLabel: protocols = UNLABELED CIPSOv4 CALIPSO\n"); + + ret_val = netlbl_domhsh_init(NETLBL_DOMHSH_BITSIZE); + if (ret_val != 0) + goto init_failure; + + ret_val = netlbl_unlabel_init(NETLBL_UNLHSH_BITSIZE); + if (ret_val != 0) + goto init_failure; + + ret_val = netlbl_netlink_init(); + if (ret_val != 0) + goto init_failure; + + ret_val = netlbl_unlabel_defconf(); + if (ret_val != 0) + goto init_failure; + printk(KERN_INFO "NetLabel: unlabeled traffic allowed by default\n"); + + return 0; + +init_failure: + panic("NetLabel: failed to initialize properly (%d)\n", ret_val); +} + +subsys_initcall(netlbl_init); diff --git a/net/netlabel/netlabel_mgmt.c b/net/netlabel/netlabel_mgmt.c new file mode 100644 index 000000000..689eaa2af --- /dev/null +++ b/net/netlabel/netlabel_mgmt.c @@ -0,0 +1,847 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel Management Support + * + * This file defines the management functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. 
+ * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006, 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_calipso.h" +#include "netlabel_domainhash.h" +#include "netlabel_user.h" +#include "netlabel_mgmt.h" + +/* NetLabel configured protocol counter */ +atomic_t netlabel_mgmt_protocount = ATOMIC_INIT(0); + +/* Argument struct for netlbl_domhsh_walk() */ +struct netlbl_domhsh_walk_arg { + struct netlink_callback *nl_cb; + struct sk_buff *skb; + u32 seq; +}; + +/* NetLabel Generic NETLINK CIPSOv4 family */ +static struct genl_family netlbl_mgmt_gnl_family; + +/* NetLabel Netlink attribute policy */ +static const struct nla_policy netlbl_mgmt_genl_policy[NLBL_MGMT_A_MAX + 1] = { + [NLBL_MGMT_A_DOMAIN] = { .type = NLA_NUL_STRING }, + [NLBL_MGMT_A_PROTOCOL] = { .type = NLA_U32 }, + [NLBL_MGMT_A_VERSION] = { .type = NLA_U32 }, + [NLBL_MGMT_A_CV4DOI] = { .type = NLA_U32 }, + [NLBL_MGMT_A_FAMILY] = { .type = NLA_U16 }, + [NLBL_MGMT_A_CLPDOI] = { .type = NLA_U32 }, +}; + +/* + * Helper Functions + */ + +/** + * netlbl_mgmt_add_common - Handle an ADD message + * @info: the Generic NETLINK info block + * @audit_info: NetLabel audit information + * + * Description: + * Helper function for the ADD and ADDDEF messages to add the domain mappings + * from the message to the hash table. See netlabel.h for a description of the + * message format. Returns zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_add_common(struct genl_info *info, + struct netlbl_audit *audit_info) +{ + void *pmap = NULL; + int ret_val = -EINVAL; + struct netlbl_domaddr_map *addrmap = NULL; + struct cipso_v4_doi *cipsov4 = NULL; +#if IS_ENABLED(CONFIG_IPV6) + struct calipso_doi *calipso = NULL; +#endif + u32 tmp_val; + struct netlbl_dom_map *entry = kzalloc(sizeof(*entry), GFP_KERNEL); + + if (!entry) + return -ENOMEM; + entry->def.type = nla_get_u32(info->attrs[NLBL_MGMT_A_PROTOCOL]); + if (info->attrs[NLBL_MGMT_A_DOMAIN]) { + size_t tmp_size = nla_len(info->attrs[NLBL_MGMT_A_DOMAIN]); + entry->domain = kmalloc(tmp_size, GFP_KERNEL); + if (entry->domain == NULL) { + ret_val = -ENOMEM; + goto add_free_entry; + } + nla_strscpy(entry->domain, + info->attrs[NLBL_MGMT_A_DOMAIN], tmp_size); + } + + /* NOTE: internally we allow/use a entry->def.type value of + * NETLBL_NLTYPE_ADDRSELECT but we don't currently allow users + * to pass that as a protocol value because we need to know the + * "real" protocol */ + + switch (entry->def.type) { + case NETLBL_NLTYPE_UNLABELED: + if (info->attrs[NLBL_MGMT_A_FAMILY]) + entry->family = + nla_get_u16(info->attrs[NLBL_MGMT_A_FAMILY]); + else + entry->family = AF_UNSPEC; + break; + case NETLBL_NLTYPE_CIPSOV4: + if (!info->attrs[NLBL_MGMT_A_CV4DOI]) + goto add_free_domain; + + tmp_val = nla_get_u32(info->attrs[NLBL_MGMT_A_CV4DOI]); + cipsov4 = cipso_v4_doi_getdef(tmp_val); + if (cipsov4 == NULL) + goto add_free_domain; + entry->family = AF_INET; + entry->def.cipso = cipsov4; + break; +#if IS_ENABLED(CONFIG_IPV6) + case NETLBL_NLTYPE_CALIPSO: + if (!info->attrs[NLBL_MGMT_A_CLPDOI]) + goto add_free_domain; + + tmp_val = nla_get_u32(info->attrs[NLBL_MGMT_A_CLPDOI]); + calipso = calipso_doi_getdef(tmp_val); + if (calipso == NULL) + goto add_free_domain; + entry->family = AF_INET6; + entry->def.calipso = calipso; + break; +#endif /* IPv6 */ + default: + goto add_free_domain; + } + + if 
((entry->family == AF_INET && info->attrs[NLBL_MGMT_A_IPV6ADDR]) || + (entry->family == AF_INET6 && info->attrs[NLBL_MGMT_A_IPV4ADDR])) + goto add_doi_put_def; + + if (info->attrs[NLBL_MGMT_A_IPV4ADDR]) { + struct in_addr *addr; + struct in_addr *mask; + struct netlbl_domaddr4_map *map; + + addrmap = kzalloc(sizeof(*addrmap), GFP_KERNEL); + if (addrmap == NULL) { + ret_val = -ENOMEM; + goto add_doi_put_def; + } + INIT_LIST_HEAD(&addrmap->list4); + INIT_LIST_HEAD(&addrmap->list6); + + if (nla_len(info->attrs[NLBL_MGMT_A_IPV4ADDR]) != + sizeof(struct in_addr)) { + ret_val = -EINVAL; + goto add_free_addrmap; + } + if (nla_len(info->attrs[NLBL_MGMT_A_IPV4MASK]) != + sizeof(struct in_addr)) { + ret_val = -EINVAL; + goto add_free_addrmap; + } + addr = nla_data(info->attrs[NLBL_MGMT_A_IPV4ADDR]); + mask = nla_data(info->attrs[NLBL_MGMT_A_IPV4MASK]); + + map = kzalloc(sizeof(*map), GFP_KERNEL); + if (map == NULL) { + ret_val = -ENOMEM; + goto add_free_addrmap; + } + pmap = map; + map->list.addr = addr->s_addr & mask->s_addr; + map->list.mask = mask->s_addr; + map->list.valid = 1; + map->def.type = entry->def.type; + if (cipsov4) + map->def.cipso = cipsov4; + + ret_val = netlbl_af4list_add(&map->list, &addrmap->list4); + if (ret_val != 0) + goto add_free_map; + + entry->family = AF_INET; + entry->def.type = NETLBL_NLTYPE_ADDRSELECT; + entry->def.addrsel = addrmap; +#if IS_ENABLED(CONFIG_IPV6) + } else if (info->attrs[NLBL_MGMT_A_IPV6ADDR]) { + struct in6_addr *addr; + struct in6_addr *mask; + struct netlbl_domaddr6_map *map; + + addrmap = kzalloc(sizeof(*addrmap), GFP_KERNEL); + if (addrmap == NULL) { + ret_val = -ENOMEM; + goto add_doi_put_def; + } + INIT_LIST_HEAD(&addrmap->list4); + INIT_LIST_HEAD(&addrmap->list6); + + if (nla_len(info->attrs[NLBL_MGMT_A_IPV6ADDR]) != + sizeof(struct in6_addr)) { + ret_val = -EINVAL; + goto add_free_addrmap; + } + if (nla_len(info->attrs[NLBL_MGMT_A_IPV6MASK]) != + sizeof(struct in6_addr)) { + ret_val = -EINVAL; + goto add_free_addrmap; + } + addr = nla_data(info->attrs[NLBL_MGMT_A_IPV6ADDR]); + mask = nla_data(info->attrs[NLBL_MGMT_A_IPV6MASK]); + + map = kzalloc(sizeof(*map), GFP_KERNEL); + if (map == NULL) { + ret_val = -ENOMEM; + goto add_free_addrmap; + } + pmap = map; + map->list.addr = *addr; + map->list.addr.s6_addr32[0] &= mask->s6_addr32[0]; + map->list.addr.s6_addr32[1] &= mask->s6_addr32[1]; + map->list.addr.s6_addr32[2] &= mask->s6_addr32[2]; + map->list.addr.s6_addr32[3] &= mask->s6_addr32[3]; + map->list.mask = *mask; + map->list.valid = 1; + map->def.type = entry->def.type; + if (calipso) + map->def.calipso = calipso; + + ret_val = netlbl_af6list_add(&map->list, &addrmap->list6); + if (ret_val != 0) + goto add_free_map; + + entry->family = AF_INET6; + entry->def.type = NETLBL_NLTYPE_ADDRSELECT; + entry->def.addrsel = addrmap; +#endif /* IPv6 */ + } + + ret_val = netlbl_domhsh_add(entry, audit_info); + if (ret_val != 0) + goto add_free_map; + + return 0; + +add_free_map: + kfree(pmap); +add_free_addrmap: + kfree(addrmap); +add_doi_put_def: + cipso_v4_doi_putdef(cipsov4); +#if IS_ENABLED(CONFIG_IPV6) + calipso_doi_putdef(calipso); +#endif +add_free_domain: + kfree(entry->domain); +add_free_entry: + kfree(entry); + return ret_val; +} + +/** + * netlbl_mgmt_listentry - List a NetLabel/LSM domain map entry + * @skb: the NETLINK buffer + * @entry: the map entry + * + * Description: + * This function is a helper function used by the LISTALL and LISTDEF command + * handlers. The caller is responsible for ensuring that the RCU read lock + * is held. 
Returns zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_listentry(struct sk_buff *skb, + struct netlbl_dom_map *entry) +{ + int ret_val = 0; + struct nlattr *nla_a; + struct nlattr *nla_b; + struct netlbl_af4list *iter4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; +#endif + + if (entry->domain != NULL) { + ret_val = nla_put_string(skb, + NLBL_MGMT_A_DOMAIN, entry->domain); + if (ret_val != 0) + return ret_val; + } + + ret_val = nla_put_u16(skb, NLBL_MGMT_A_FAMILY, entry->family); + if (ret_val != 0) + return ret_val; + + switch (entry->def.type) { + case NETLBL_NLTYPE_ADDRSELECT: + nla_a = nla_nest_start_noflag(skb, NLBL_MGMT_A_SELECTORLIST); + if (nla_a == NULL) + return -ENOMEM; + + netlbl_af4list_foreach_rcu(iter4, &entry->def.addrsel->list4) { + struct netlbl_domaddr4_map *map4; + struct in_addr addr_struct; + + nla_b = nla_nest_start_noflag(skb, + NLBL_MGMT_A_ADDRSELECTOR); + if (nla_b == NULL) + return -ENOMEM; + + addr_struct.s_addr = iter4->addr; + ret_val = nla_put_in_addr(skb, NLBL_MGMT_A_IPV4ADDR, + addr_struct.s_addr); + if (ret_val != 0) + return ret_val; + addr_struct.s_addr = iter4->mask; + ret_val = nla_put_in_addr(skb, NLBL_MGMT_A_IPV4MASK, + addr_struct.s_addr); + if (ret_val != 0) + return ret_val; + map4 = netlbl_domhsh_addr4_entry(iter4); + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, + map4->def.type); + if (ret_val != 0) + return ret_val; + switch (map4->def.type) { + case NETLBL_NLTYPE_CIPSOV4: + ret_val = nla_put_u32(skb, NLBL_MGMT_A_CV4DOI, + map4->def.cipso->doi); + if (ret_val != 0) + return ret_val; + break; + } + + nla_nest_end(skb, nla_b); + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, &entry->def.addrsel->list6) { + struct netlbl_domaddr6_map *map6; + + nla_b = nla_nest_start_noflag(skb, + NLBL_MGMT_A_ADDRSELECTOR); + if (nla_b == NULL) + return -ENOMEM; + + ret_val = nla_put_in6_addr(skb, NLBL_MGMT_A_IPV6ADDR, + &iter6->addr); + if (ret_val != 0) + return ret_val; + ret_val = nla_put_in6_addr(skb, NLBL_MGMT_A_IPV6MASK, + &iter6->mask); + if (ret_val != 0) + return ret_val; + map6 = netlbl_domhsh_addr6_entry(iter6); + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, + map6->def.type); + if (ret_val != 0) + return ret_val; + + switch (map6->def.type) { + case NETLBL_NLTYPE_CALIPSO: + ret_val = nla_put_u32(skb, NLBL_MGMT_A_CLPDOI, + map6->def.calipso->doi); + if (ret_val != 0) + return ret_val; + break; + } + + nla_nest_end(skb, nla_b); + } +#endif /* IPv6 */ + + nla_nest_end(skb, nla_a); + break; + case NETLBL_NLTYPE_UNLABELED: + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, + entry->def.type); + break; + case NETLBL_NLTYPE_CIPSOV4: + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, + entry->def.type); + if (ret_val != 0) + return ret_val; + ret_val = nla_put_u32(skb, NLBL_MGMT_A_CV4DOI, + entry->def.cipso->doi); + break; + case NETLBL_NLTYPE_CALIPSO: + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, + entry->def.type); + if (ret_val != 0) + return ret_val; + ret_val = nla_put_u32(skb, NLBL_MGMT_A_CLPDOI, + entry->def.calipso->doi); + break; + } + + return ret_val; +} + +/* + * NetLabel Command Handlers + */ + +/** + * netlbl_mgmt_add - Handle an ADD message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated ADD message and add the domains from the message + * to the hash table. See netlabel.h for a description of the message format. + * Returns zero on success, negative values on failure. 
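+ *
+ * As a concrete illustration (the domain string and DOI value are example
+ * values only), a request mapping the LSM domain "httpd_t" to CIPSOv4
+ * DOI 3 would carry:
+ *
+ *	NLBL_MGMT_A_DOMAIN   = "httpd_t"
+ *	NLBL_MGMT_A_PROTOCOL = NETLBL_NLTYPE_CIPSOV4
+ *	NLBL_MGMT_A_CV4DOI   = 3
+ *
+ * and may optionally be restricted to a single IPv4 network by also
+ * supplying both NLBL_MGMT_A_IPV4ADDR and NLBL_MGMT_A_IPV4MASK.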
+ * + */ +static int netlbl_mgmt_add(struct sk_buff *skb, struct genl_info *info) +{ + struct netlbl_audit audit_info; + + if ((!info->attrs[NLBL_MGMT_A_DOMAIN]) || + (!info->attrs[NLBL_MGMT_A_PROTOCOL]) || + (info->attrs[NLBL_MGMT_A_IPV4ADDR] && + info->attrs[NLBL_MGMT_A_IPV6ADDR]) || + (info->attrs[NLBL_MGMT_A_IPV4MASK] && + info->attrs[NLBL_MGMT_A_IPV6MASK]) || + ((info->attrs[NLBL_MGMT_A_IPV4ADDR] != NULL) ^ + (info->attrs[NLBL_MGMT_A_IPV4MASK] != NULL)) || + ((info->attrs[NLBL_MGMT_A_IPV6ADDR] != NULL) ^ + (info->attrs[NLBL_MGMT_A_IPV6MASK] != NULL))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + return netlbl_mgmt_add_common(info, &audit_info); +} + +/** + * netlbl_mgmt_remove - Handle a REMOVE message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated REMOVE message and remove the specified domain + * mappings. Returns zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_remove(struct sk_buff *skb, struct genl_info *info) +{ + char *domain; + struct netlbl_audit audit_info; + + if (!info->attrs[NLBL_MGMT_A_DOMAIN]) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + domain = nla_data(info->attrs[NLBL_MGMT_A_DOMAIN]); + return netlbl_domhsh_remove(domain, AF_UNSPEC, &audit_info); +} + +/** + * netlbl_mgmt_listall_cb - netlbl_domhsh_walk() callback for LISTALL + * @entry: the domain mapping hash table entry + * @arg: the netlbl_domhsh_walk_arg structure + * + * Description: + * This function is designed to be used as a callback to the + * netlbl_domhsh_walk() function for use in generating a response for a LISTALL + * message. Returns the size of the message on success, negative values on + * failure. + * + */ +static int netlbl_mgmt_listall_cb(struct netlbl_dom_map *entry, void *arg) +{ + int ret_val = -ENOMEM; + struct netlbl_domhsh_walk_arg *cb_arg = arg; + void *data; + + data = genlmsg_put(cb_arg->skb, NETLINK_CB(cb_arg->nl_cb->skb).portid, + cb_arg->seq, &netlbl_mgmt_gnl_family, + NLM_F_MULTI, NLBL_MGMT_C_LISTALL); + if (data == NULL) + goto listall_cb_failure; + + ret_val = netlbl_mgmt_listentry(cb_arg->skb, entry); + if (ret_val != 0) + goto listall_cb_failure; + + cb_arg->seq++; + genlmsg_end(cb_arg->skb, data); + return 0; + +listall_cb_failure: + genlmsg_cancel(cb_arg->skb, data); + return ret_val; +} + +/** + * netlbl_mgmt_listall - Handle a LISTALL message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated LISTALL message and dumps the domain hash table in + * a form suitable for use in a kernel generated LISTALL message. Returns zero + * on success, negative values on failure. + * + */ +static int netlbl_mgmt_listall(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct netlbl_domhsh_walk_arg cb_arg; + u32 skip_bkt = cb->args[0]; + u32 skip_chain = cb->args[1]; + + cb_arg.nl_cb = cb; + cb_arg.skb = skb; + cb_arg.seq = cb->nlh->nlmsg_seq; + + netlbl_domhsh_walk(&skip_bkt, + &skip_chain, + netlbl_mgmt_listall_cb, + &cb_arg); + + cb->args[0] = skip_bkt; + cb->args[1] = skip_chain; + return skb->len; +} + +/** + * netlbl_mgmt_adddef - Handle an ADDDEF message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated ADDDEF message and respond accordingly. Returns + * zero on success, negative values on failure. 
+ * + */ +static int netlbl_mgmt_adddef(struct sk_buff *skb, struct genl_info *info) +{ + struct netlbl_audit audit_info; + + if ((!info->attrs[NLBL_MGMT_A_PROTOCOL]) || + (info->attrs[NLBL_MGMT_A_IPV4ADDR] && + info->attrs[NLBL_MGMT_A_IPV6ADDR]) || + (info->attrs[NLBL_MGMT_A_IPV4MASK] && + info->attrs[NLBL_MGMT_A_IPV6MASK]) || + ((info->attrs[NLBL_MGMT_A_IPV4ADDR] != NULL) ^ + (info->attrs[NLBL_MGMT_A_IPV4MASK] != NULL)) || + ((info->attrs[NLBL_MGMT_A_IPV6ADDR] != NULL) ^ + (info->attrs[NLBL_MGMT_A_IPV6MASK] != NULL))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + return netlbl_mgmt_add_common(info, &audit_info); +} + +/** + * netlbl_mgmt_removedef - Handle a REMOVEDEF message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated REMOVEDEF message and remove the default domain + * mapping. Returns zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_removedef(struct sk_buff *skb, struct genl_info *info) +{ + struct netlbl_audit audit_info; + + netlbl_netlink_auditinfo(&audit_info); + + return netlbl_domhsh_remove_default(AF_UNSPEC, &audit_info); +} + +/** + * netlbl_mgmt_listdef - Handle a LISTDEF message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated LISTDEF message and dumps the default domain + * mapping in a form suitable for use in a kernel generated LISTDEF message. + * Returns zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_listdef(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -ENOMEM; + struct sk_buff *ans_skb = NULL; + void *data; + struct netlbl_dom_map *entry; + u16 family; + + if (info->attrs[NLBL_MGMT_A_FAMILY]) + family = nla_get_u16(info->attrs[NLBL_MGMT_A_FAMILY]); + else + family = AF_INET; + + ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (ans_skb == NULL) + return -ENOMEM; + data = genlmsg_put_reply(ans_skb, info, &netlbl_mgmt_gnl_family, + 0, NLBL_MGMT_C_LISTDEF); + if (data == NULL) + goto listdef_failure; + + rcu_read_lock(); + entry = netlbl_domhsh_getentry(NULL, family); + if (entry == NULL) { + ret_val = -ENOENT; + goto listdef_failure_lock; + } + ret_val = netlbl_mgmt_listentry(ans_skb, entry); + rcu_read_unlock(); + if (ret_val != 0) + goto listdef_failure; + + genlmsg_end(ans_skb, data); + return genlmsg_reply(ans_skb, info); + +listdef_failure_lock: + rcu_read_unlock(); +listdef_failure: + kfree_skb(ans_skb); + return ret_val; +} + +/** + * netlbl_mgmt_protocols_cb - Write an individual PROTOCOL message response + * @skb: the skb to write to + * @cb: the NETLINK callback + * @protocol: the NetLabel protocol to use in the message + * + * Description: + * This function is to be used in conjunction with netlbl_mgmt_protocols() to + * answer a application's PROTOCOLS message. Returns the size of the message + * on success, negative values on failure. 
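+ *
+ * For example, with IPv6 support enabled a PROTOCOLS dump ends up as a
+ * series of three NLM_F_MULTI messages, each built by this helper and
+ * each carrying a single NLBL_MGMT_A_PROTOCOL attribute:
+ * NETLBL_NLTYPE_UNLABELED, NETLBL_NLTYPE_CIPSOV4 and NETLBL_NLTYPE_CALIPSO
+ * respectively.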
+ * + */ +static int netlbl_mgmt_protocols_cb(struct sk_buff *skb, + struct netlink_callback *cb, + u32 protocol) +{ + int ret_val = -ENOMEM; + void *data; + + data = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &netlbl_mgmt_gnl_family, NLM_F_MULTI, + NLBL_MGMT_C_PROTOCOLS); + if (data == NULL) + goto protocols_cb_failure; + + ret_val = nla_put_u32(skb, NLBL_MGMT_A_PROTOCOL, protocol); + if (ret_val != 0) + goto protocols_cb_failure; + + genlmsg_end(skb, data); + return 0; + +protocols_cb_failure: + genlmsg_cancel(skb, data); + return ret_val; +} + +/** + * netlbl_mgmt_protocols - Handle a PROTOCOLS message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated PROTOCOLS message and respond accordingly. + * + */ +static int netlbl_mgmt_protocols(struct sk_buff *skb, + struct netlink_callback *cb) +{ + u32 protos_sent = cb->args[0]; + + if (protos_sent == 0) { + if (netlbl_mgmt_protocols_cb(skb, + cb, + NETLBL_NLTYPE_UNLABELED) < 0) + goto protocols_return; + protos_sent++; + } + if (protos_sent == 1) { + if (netlbl_mgmt_protocols_cb(skb, + cb, + NETLBL_NLTYPE_CIPSOV4) < 0) + goto protocols_return; + protos_sent++; + } +#if IS_ENABLED(CONFIG_IPV6) + if (protos_sent == 2) { + if (netlbl_mgmt_protocols_cb(skb, + cb, + NETLBL_NLTYPE_CALIPSO) < 0) + goto protocols_return; + protos_sent++; + } +#endif + +protocols_return: + cb->args[0] = protos_sent; + return skb->len; +} + +/** + * netlbl_mgmt_version - Handle a VERSION message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated VERSION message and respond accordingly. Returns + * zero on success, negative values on failure. + * + */ +static int netlbl_mgmt_version(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -ENOMEM; + struct sk_buff *ans_skb = NULL; + void *data; + + ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (ans_skb == NULL) + return -ENOMEM; + data = genlmsg_put_reply(ans_skb, info, &netlbl_mgmt_gnl_family, + 0, NLBL_MGMT_C_VERSION); + if (data == NULL) + goto version_failure; + + ret_val = nla_put_u32(ans_skb, + NLBL_MGMT_A_VERSION, + NETLBL_PROTO_VERSION); + if (ret_val != 0) + goto version_failure; + + genlmsg_end(ans_skb, data); + return genlmsg_reply(ans_skb, info); + +version_failure: + kfree_skb(ans_skb); + return ret_val; +} + + +/* + * NetLabel Generic NETLINK Command Definitions + */ + +static const struct genl_small_ops netlbl_mgmt_genl_ops[] = { + { + .cmd = NLBL_MGMT_C_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_mgmt_add, + .dumpit = NULL, + }, + { + .cmd = NLBL_MGMT_C_REMOVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_mgmt_remove, + .dumpit = NULL, + }, + { + .cmd = NLBL_MGMT_C_LISTALL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_mgmt_listall, + }, + { + .cmd = NLBL_MGMT_C_ADDDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_mgmt_adddef, + .dumpit = NULL, + }, + { + .cmd = NLBL_MGMT_C_REMOVEDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_mgmt_removedef, + .dumpit = NULL, + }, + { + .cmd = NLBL_MGMT_C_LISTDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = 
netlbl_mgmt_listdef, + .dumpit = NULL, + }, + { + .cmd = NLBL_MGMT_C_PROTOCOLS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_mgmt_protocols, + }, + { + .cmd = NLBL_MGMT_C_VERSION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = netlbl_mgmt_version, + .dumpit = NULL, + }, +}; + +static struct genl_family netlbl_mgmt_gnl_family __ro_after_init = { + .hdrsize = 0, + .name = NETLBL_NLTYPE_MGMT_NAME, + .version = NETLBL_PROTO_VERSION, + .maxattr = NLBL_MGMT_A_MAX, + .policy = netlbl_mgmt_genl_policy, + .module = THIS_MODULE, + .small_ops = netlbl_mgmt_genl_ops, + .n_small_ops = ARRAY_SIZE(netlbl_mgmt_genl_ops), + .resv_start_op = NLBL_MGMT_C_VERSION + 1, +}; + +/* + * NetLabel Generic NETLINK Protocol Functions + */ + +/** + * netlbl_mgmt_genl_init - Register the NetLabel management component + * + * Description: + * Register the NetLabel management component with the Generic NETLINK + * mechanism. Returns zero on success, negative values on failure. + * + */ +int __init netlbl_mgmt_genl_init(void) +{ + return genl_register_family(&netlbl_mgmt_gnl_family); +} diff --git a/net/netlabel/netlabel_mgmt.h b/net/netlabel/netlabel_mgmt.h new file mode 100644 index 000000000..db20dfbbd --- /dev/null +++ b/net/netlabel/netlabel_mgmt.h @@ -0,0 +1,225 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel Management Support + * + * This file defines the management functions for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#ifndef _NETLABEL_MGMT_H +#define _NETLABEL_MGMT_H + +#include +#include + +/* + * The following NetLabel payloads are supported by the management interface. + * + * o ADD: + * Sent by an application to add a domain mapping to the NetLabel system. + * + * Required attributes: + * + * NLBL_MGMT_A_DOMAIN + * NLBL_MGMT_A_PROTOCOL + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_MGMT_A_IPV4ADDR + * NLBL_MGMT_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_MGMT_A_IPV6ADDR + * NLBL_MGMT_A_IPV6MASK + * + * If using NETLBL_NLTYPE_CIPSOV4 the following attributes are required: + * + * NLBL_MGMT_A_CV4DOI + * + * If using NETLBL_NLTYPE_UNLABELED no other attributes are required, + * however the following attribute may optionally be sent: + * + * NLBL_MGMT_A_FAMILY + * + * o REMOVE: + * Sent by an application to remove a domain mapping from the NetLabel + * system. + * + * Required attributes: + * + * NLBL_MGMT_A_DOMAIN + * + * o LISTALL: + * This message can be sent either from an application or by the kernel in + * response to an application generated LISTALL message. When sent by an + * application there is no payload and the NLM_F_DUMP flag should be set. + * The kernel should respond with a series of the following messages. 
+ * + * Required attributes: + * + * NLBL_MGMT_A_DOMAIN + * NLBL_MGMT_A_FAMILY + * + * If the IP address selectors are not used the following attribute is + * required: + * + * NLBL_MGMT_A_PROTOCOL + * + * If the IP address selectors are used then the following attritbute is + * required: + * + * NLBL_MGMT_A_SELECTORLIST + * + * If the mapping is using the NETLBL_NLTYPE_CIPSOV4 type then the following + * attributes are required: + * + * NLBL_MGMT_A_CV4DOI + * + * If the mapping is using the NETLBL_NLTYPE_UNLABELED type no other + * attributes are required. + * + * o ADDDEF: + * Sent by an application to set the default domain mapping for the NetLabel + * system. + * + * Required attributes: + * + * NLBL_MGMT_A_PROTOCOL + * + * If using NETLBL_NLTYPE_CIPSOV4 the following attributes are required: + * + * NLBL_MGMT_A_CV4DOI + * + * If using NETLBL_NLTYPE_UNLABELED no other attributes are required, + * however the following attribute may optionally be sent: + * + * NLBL_MGMT_A_FAMILY + * + * o REMOVEDEF: + * Sent by an application to remove the default domain mapping from the + * NetLabel system, there is no payload. + * + * o LISTDEF: + * This message can be sent either from an application or by the kernel in + * response to an application generated LISTDEF message. When sent by an + * application there may be an optional payload. + * + * NLBL_MGMT_A_FAMILY + * + * On success the kernel should send a response using the following format: + * + * If the IP address selectors are not used the following attributes are + * required: + * + * NLBL_MGMT_A_PROTOCOL + * NLBL_MGMT_A_FAMILY + * + * If the IP address selectors are used then the following attritbute is + * required: + * + * NLBL_MGMT_A_SELECTORLIST + * + * If the mapping is using the NETLBL_NLTYPE_CIPSOV4 type then the following + * attributes are required: + * + * NLBL_MGMT_A_CV4DOI + * + * If the mapping is using the NETLBL_NLTYPE_UNLABELED type no other + * attributes are required. + * + * o PROTOCOLS: + * Sent by an application to request a list of configured NetLabel protocols + * in the kernel. When sent by an application there is no payload and the + * NLM_F_DUMP flag should be set. The kernel should respond with a series of + * the following messages. + * + * Required attributes: + * + * NLBL_MGMT_A_PROTOCOL + * + * o VERSION: + * Sent by an application to request the NetLabel version. When sent by an + * application there is no payload. This message type is also used by the + * kernel to respond to an VERSION request. 
+ * + * Required attributes: + * + * NLBL_MGMT_A_VERSION + * + */ + +/* NetLabel Management commands */ +enum { + NLBL_MGMT_C_UNSPEC, + NLBL_MGMT_C_ADD, + NLBL_MGMT_C_REMOVE, + NLBL_MGMT_C_LISTALL, + NLBL_MGMT_C_ADDDEF, + NLBL_MGMT_C_REMOVEDEF, + NLBL_MGMT_C_LISTDEF, + NLBL_MGMT_C_PROTOCOLS, + NLBL_MGMT_C_VERSION, + __NLBL_MGMT_C_MAX, +}; + +/* NetLabel Management attributes */ +enum { + NLBL_MGMT_A_UNSPEC, + NLBL_MGMT_A_DOMAIN, + /* (NLA_NUL_STRING) + * the NULL terminated LSM domain string */ + NLBL_MGMT_A_PROTOCOL, + /* (NLA_U32) + * the NetLabel protocol type (defined by NETLBL_NLTYPE_*) */ + NLBL_MGMT_A_VERSION, + /* (NLA_U32) + * the NetLabel protocol version number (defined by + * NETLBL_PROTO_VERSION) */ + NLBL_MGMT_A_CV4DOI, + /* (NLA_U32) + * the CIPSOv4 DOI value */ + NLBL_MGMT_A_IPV6ADDR, + /* (NLA_BINARY, struct in6_addr) + * an IPv6 address */ + NLBL_MGMT_A_IPV6MASK, + /* (NLA_BINARY, struct in6_addr) + * an IPv6 address mask */ + NLBL_MGMT_A_IPV4ADDR, + /* (NLA_BINARY, struct in_addr) + * an IPv4 address */ + NLBL_MGMT_A_IPV4MASK, + /* (NLA_BINARY, struct in_addr) + * and IPv4 address mask */ + NLBL_MGMT_A_ADDRSELECTOR, + /* (NLA_NESTED) + * an IP address selector, must contain an address, mask, and protocol + * attribute plus any protocol specific attributes */ + NLBL_MGMT_A_SELECTORLIST, + /* (NLA_NESTED) + * the selector list, there must be at least one + * NLBL_MGMT_A_ADDRSELECTOR attribute */ + NLBL_MGMT_A_FAMILY, + /* (NLA_U16) + * The address family */ + NLBL_MGMT_A_CLPDOI, + /* (NLA_U32) + * the CALIPSO DOI value */ + __NLBL_MGMT_A_MAX, +}; +#define NLBL_MGMT_A_MAX (__NLBL_MGMT_A_MAX - 1) + +/* NetLabel protocol functions */ +int netlbl_mgmt_genl_init(void); + +/* NetLabel configured protocol reference counter */ +extern atomic_t netlabel_mgmt_protocount; + +#endif diff --git a/net/netlabel/netlabel_unlabeled.c b/net/netlabel/netlabel_unlabeled.c new file mode 100644 index 000000000..9996883bf --- /dev/null +++ b/net/netlabel/netlabel_unlabeled.c @@ -0,0 +1,1557 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel Unlabeled Support + * + * This file defines functions for dealing with unlabeled packets for the + * NetLabel system. The NetLabel system manages static and dynamic label + * mappings for network protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 - 2008 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_user.h" +#include "netlabel_addrlist.h" +#include "netlabel_domainhash.h" +#include "netlabel_unlabeled.h" +#include "netlabel_mgmt.h" + +/* NOTE: at present we always use init's network namespace since we don't + * presently support different namespaces even though the majority of + * the functions in this file are "namespace safe" */ + +/* The unlabeled connection hash table which we use to map network interfaces + * and addresses of unlabeled packets to a user specified secid value for the + * LSM. The hash table is used to lookup the network interface entry + * (struct netlbl_unlhsh_iface) and then the interface entry is used to + * lookup an IP address match from an ordered list. If a network interface + * match can not be found in the hash table then the default entry + * (netlbl_unlhsh_def) is used. 
The IP address entry list + * (struct netlbl_unlhsh_addr) is ordered such that the entries with a + * larger netmask come first. + */ +struct netlbl_unlhsh_tbl { + struct list_head *tbl; + u32 size; +}; +#define netlbl_unlhsh_addr4_entry(iter) \ + container_of(iter, struct netlbl_unlhsh_addr4, list) +struct netlbl_unlhsh_addr4 { + u32 secid; + + struct netlbl_af4list list; + struct rcu_head rcu; +}; +#define netlbl_unlhsh_addr6_entry(iter) \ + container_of(iter, struct netlbl_unlhsh_addr6, list) +struct netlbl_unlhsh_addr6 { + u32 secid; + + struct netlbl_af6list list; + struct rcu_head rcu; +}; +struct netlbl_unlhsh_iface { + int ifindex; + struct list_head addr4_list; + struct list_head addr6_list; + + u32 valid; + struct list_head list; + struct rcu_head rcu; +}; + +/* Argument struct for netlbl_unlhsh_walk() */ +struct netlbl_unlhsh_walk_arg { + struct netlink_callback *nl_cb; + struct sk_buff *skb; + u32 seq; +}; + +/* Unlabeled connection hash table */ +/* updates should be so rare that having one spinlock for the entire + * hash table should be okay */ +static DEFINE_SPINLOCK(netlbl_unlhsh_lock); +#define netlbl_unlhsh_rcu_deref(p) \ + rcu_dereference_check(p, lockdep_is_held(&netlbl_unlhsh_lock)) +static struct netlbl_unlhsh_tbl __rcu *netlbl_unlhsh; +static struct netlbl_unlhsh_iface __rcu *netlbl_unlhsh_def; + +/* Accept unlabeled packets flag */ +static u8 netlabel_unlabel_acceptflg; + +/* NetLabel Generic NETLINK unlabeled family */ +static struct genl_family netlbl_unlabel_gnl_family; + +/* NetLabel Netlink attribute policy */ +static const struct nla_policy netlbl_unlabel_genl_policy[NLBL_UNLABEL_A_MAX + 1] = { + [NLBL_UNLABEL_A_ACPTFLG] = { .type = NLA_U8 }, + [NLBL_UNLABEL_A_IPV6ADDR] = { .type = NLA_BINARY, + .len = sizeof(struct in6_addr) }, + [NLBL_UNLABEL_A_IPV6MASK] = { .type = NLA_BINARY, + .len = sizeof(struct in6_addr) }, + [NLBL_UNLABEL_A_IPV4ADDR] = { .type = NLA_BINARY, + .len = sizeof(struct in_addr) }, + [NLBL_UNLABEL_A_IPV4MASK] = { .type = NLA_BINARY, + .len = sizeof(struct in_addr) }, + [NLBL_UNLABEL_A_IFACE] = { .type = NLA_NUL_STRING, + .len = IFNAMSIZ - 1 }, + [NLBL_UNLABEL_A_SECCTX] = { .type = NLA_BINARY } +}; + +/* + * Unlabeled Connection Hash Table Functions + */ + +/** + * netlbl_unlhsh_free_iface - Frees an interface entry from the hash table + * @entry: the entry's RCU field + * + * Description: + * This function is designed to be used as a callback to the call_rcu() + * function so that memory allocated to a hash table interface entry can be + * released safely. It is important to note that this function does not free + * the IPv4 and IPv6 address lists contained as part of an interface entry. It + * is up to the rest of the code to make sure an interface entry is only freed + * once it's address lists are empty. 
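+ *
+ * The callers later in this file follow the same pattern when they retire
+ * an interface entry: unlink it under netlbl_unlhsh_lock and let RCU defer
+ * the actual free, e.g.:
+ *
+ *	iface->valid = 0;
+ *	list_del_rcu(&iface->list);
+ *	spin_unlock(&netlbl_unlhsh_lock);
+ *	call_rcu(&iface->rcu, netlbl_unlhsh_free_iface);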
+ * + */ +static void netlbl_unlhsh_free_iface(struct rcu_head *entry) +{ + struct netlbl_unlhsh_iface *iface; + struct netlbl_af4list *iter4; + struct netlbl_af4list *tmp4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; + struct netlbl_af6list *tmp6; +#endif /* IPv6 */ + + iface = container_of(entry, struct netlbl_unlhsh_iface, rcu); + + /* no need for locks here since we are the only one with access to this + * structure */ + + netlbl_af4list_foreach_safe(iter4, tmp4, &iface->addr4_list) { + netlbl_af4list_remove_entry(iter4); + kfree(netlbl_unlhsh_addr4_entry(iter4)); + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_safe(iter6, tmp6, &iface->addr6_list) { + netlbl_af6list_remove_entry(iter6); + kfree(netlbl_unlhsh_addr6_entry(iter6)); + } +#endif /* IPv6 */ + kfree(iface); +} + +/** + * netlbl_unlhsh_hash - Hashing function for the hash table + * @ifindex: the network interface/device to hash + * + * Description: + * This is the hashing function for the unlabeled hash table, it returns the + * bucket number for the given device/interface. The caller is responsible for + * ensuring that the hash table is protected with either a RCU read lock or + * the hash table lock. + * + */ +static u32 netlbl_unlhsh_hash(int ifindex) +{ + return ifindex & (netlbl_unlhsh_rcu_deref(netlbl_unlhsh)->size - 1); +} + +/** + * netlbl_unlhsh_search_iface - Search for a matching interface entry + * @ifindex: the network interface + * + * Description: + * Searches the unlabeled connection hash table and returns a pointer to the + * interface entry which matches @ifindex, otherwise NULL is returned. The + * caller is responsible for ensuring that the hash table is protected with + * either a RCU read lock or the hash table lock. + * + */ +static struct netlbl_unlhsh_iface *netlbl_unlhsh_search_iface(int ifindex) +{ + u32 bkt; + struct list_head *bkt_list; + struct netlbl_unlhsh_iface *iter; + + bkt = netlbl_unlhsh_hash(ifindex); + bkt_list = &netlbl_unlhsh_rcu_deref(netlbl_unlhsh)->tbl[bkt]; + list_for_each_entry_rcu(iter, bkt_list, list, + lockdep_is_held(&netlbl_unlhsh_lock)) + if (iter->valid && iter->ifindex == ifindex) + return iter; + + return NULL; +} + +/** + * netlbl_unlhsh_add_addr4 - Add a new IPv4 address entry to the hash table + * @iface: the associated interface entry + * @addr: IPv4 address in network byte order + * @mask: IPv4 address mask in network byte order + * @secid: LSM secid value for entry + * + * Description: + * Add a new address entry into the unlabeled connection hash table using the + * interface entry specified by @iface. On success zero is returned, otherwise + * a negative value is returned. 
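+ *
+ * For example (addresses picked purely for illustration), adding 192.168.1.7
+ * with a 255.255.255.0 mask stores the masked network address 192.168.1.0,
+ * so that the ordered address list can later match an incoming packet by
+ * checking (packet_addr & list.mask) == list.addr.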
+ * + */ +static int netlbl_unlhsh_add_addr4(struct netlbl_unlhsh_iface *iface, + const struct in_addr *addr, + const struct in_addr *mask, + u32 secid) +{ + int ret_val; + struct netlbl_unlhsh_addr4 *entry; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (entry == NULL) + return -ENOMEM; + + entry->list.addr = addr->s_addr & mask->s_addr; + entry->list.mask = mask->s_addr; + entry->list.valid = 1; + entry->secid = secid; + + spin_lock(&netlbl_unlhsh_lock); + ret_val = netlbl_af4list_add(&entry->list, &iface->addr4_list); + spin_unlock(&netlbl_unlhsh_lock); + + if (ret_val != 0) + kfree(entry); + return ret_val; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_unlhsh_add_addr6 - Add a new IPv6 address entry to the hash table + * @iface: the associated interface entry + * @addr: IPv6 address in network byte order + * @mask: IPv6 address mask in network byte order + * @secid: LSM secid value for entry + * + * Description: + * Add a new address entry into the unlabeled connection hash table using the + * interface entry specified by @iface. On success zero is returned, otherwise + * a negative value is returned. + * + */ +static int netlbl_unlhsh_add_addr6(struct netlbl_unlhsh_iface *iface, + const struct in6_addr *addr, + const struct in6_addr *mask, + u32 secid) +{ + int ret_val; + struct netlbl_unlhsh_addr6 *entry; + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (entry == NULL) + return -ENOMEM; + + entry->list.addr = *addr; + entry->list.addr.s6_addr32[0] &= mask->s6_addr32[0]; + entry->list.addr.s6_addr32[1] &= mask->s6_addr32[1]; + entry->list.addr.s6_addr32[2] &= mask->s6_addr32[2]; + entry->list.addr.s6_addr32[3] &= mask->s6_addr32[3]; + entry->list.mask = *mask; + entry->list.valid = 1; + entry->secid = secid; + + spin_lock(&netlbl_unlhsh_lock); + ret_val = netlbl_af6list_add(&entry->list, &iface->addr6_list); + spin_unlock(&netlbl_unlhsh_lock); + + if (ret_val != 0) + kfree(entry); + return 0; +} +#endif /* IPv6 */ + +/** + * netlbl_unlhsh_add_iface - Adds a new interface entry to the hash table + * @ifindex: network interface + * + * Description: + * Add a new, empty, interface entry into the unlabeled connection hash table. + * On success a pointer to the new interface entry is returned, on failure NULL + * is returned. 
+ * + */ +static struct netlbl_unlhsh_iface *netlbl_unlhsh_add_iface(int ifindex) +{ + u32 bkt; + struct netlbl_unlhsh_iface *iface; + + iface = kzalloc(sizeof(*iface), GFP_ATOMIC); + if (iface == NULL) + return NULL; + + iface->ifindex = ifindex; + INIT_LIST_HEAD(&iface->addr4_list); + INIT_LIST_HEAD(&iface->addr6_list); + iface->valid = 1; + + spin_lock(&netlbl_unlhsh_lock); + if (ifindex > 0) { + bkt = netlbl_unlhsh_hash(ifindex); + if (netlbl_unlhsh_search_iface(ifindex) != NULL) + goto add_iface_failure; + list_add_tail_rcu(&iface->list, + &netlbl_unlhsh_rcu_deref(netlbl_unlhsh)->tbl[bkt]); + } else { + INIT_LIST_HEAD(&iface->list); + if (netlbl_unlhsh_rcu_deref(netlbl_unlhsh_def) != NULL) + goto add_iface_failure; + rcu_assign_pointer(netlbl_unlhsh_def, iface); + } + spin_unlock(&netlbl_unlhsh_lock); + + return iface; + +add_iface_failure: + spin_unlock(&netlbl_unlhsh_lock); + kfree(iface); + return NULL; +} + +/** + * netlbl_unlhsh_add - Adds a new entry to the unlabeled connection hash table + * @net: network namespace + * @dev_name: interface name + * @addr: IP address in network byte order + * @mask: address mask in network byte order + * @addr_len: length of address/mask (4 for IPv4, 16 for IPv6) + * @secid: LSM secid value for the entry + * @audit_info: NetLabel audit information + * + * Description: + * Adds a new entry to the unlabeled connection hash table. Returns zero on + * success, negative values on failure. + * + */ +int netlbl_unlhsh_add(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u32 addr_len, + u32 secid, + struct netlbl_audit *audit_info) +{ + int ret_val; + int ifindex; + struct net_device *dev; + struct netlbl_unlhsh_iface *iface; + struct audit_buffer *audit_buf = NULL; + char *secctx = NULL; + u32 secctx_len; + + if (addr_len != sizeof(struct in_addr) && + addr_len != sizeof(struct in6_addr)) + return -EINVAL; + + rcu_read_lock(); + if (dev_name != NULL) { + dev = dev_get_by_name_rcu(net, dev_name); + if (dev == NULL) { + ret_val = -ENODEV; + goto unlhsh_add_return; + } + ifindex = dev->ifindex; + iface = netlbl_unlhsh_search_iface(ifindex); + } else { + ifindex = 0; + iface = rcu_dereference(netlbl_unlhsh_def); + } + if (iface == NULL) { + iface = netlbl_unlhsh_add_iface(ifindex); + if (iface == NULL) { + ret_val = -ENOMEM; + goto unlhsh_add_return; + } + } + audit_buf = netlbl_audit_start_common(AUDIT_MAC_UNLBL_STCADD, + audit_info); + switch (addr_len) { + case sizeof(struct in_addr): { + const struct in_addr *addr4 = addr; + const struct in_addr *mask4 = mask; + + ret_val = netlbl_unlhsh_add_addr4(iface, addr4, mask4, secid); + if (audit_buf != NULL) + netlbl_af4list_audit_addr(audit_buf, 1, + dev_name, + addr4->s_addr, + mask4->s_addr); + break; + } +#if IS_ENABLED(CONFIG_IPV6) + case sizeof(struct in6_addr): { + const struct in6_addr *addr6 = addr; + const struct in6_addr *mask6 = mask; + + ret_val = netlbl_unlhsh_add_addr6(iface, addr6, mask6, secid); + if (audit_buf != NULL) + netlbl_af6list_audit_addr(audit_buf, 1, + dev_name, + addr6, mask6); + break; + } +#endif /* IPv6 */ + default: + ret_val = -EINVAL; + } + if (ret_val == 0) + atomic_inc(&netlabel_mgmt_protocount); + +unlhsh_add_return: + rcu_read_unlock(); + if (audit_buf != NULL) { + if (security_secid_to_secctx(secid, + &secctx, + &secctx_len) == 0) { + audit_log_format(audit_buf, " sec_obj=%s", secctx); + security_release_secctx(secctx, secctx_len); + } + audit_log_format(audit_buf, " res=%u", ret_val == 0 ? 
1 : 0); + audit_log_end(audit_buf); + } + return ret_val; +} + +/** + * netlbl_unlhsh_remove_addr4 - Remove an IPv4 address entry + * @net: network namespace + * @iface: interface entry + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Remove an IP address entry from the unlabeled connection hash table. + * Returns zero on success, negative values on failure. + * + */ +static int netlbl_unlhsh_remove_addr4(struct net *net, + struct netlbl_unlhsh_iface *iface, + const struct in_addr *addr, + const struct in_addr *mask, + struct netlbl_audit *audit_info) +{ + struct netlbl_af4list *list_entry; + struct netlbl_unlhsh_addr4 *entry; + struct audit_buffer *audit_buf; + struct net_device *dev; + char *secctx; + u32 secctx_len; + + spin_lock(&netlbl_unlhsh_lock); + list_entry = netlbl_af4list_remove(addr->s_addr, mask->s_addr, + &iface->addr4_list); + spin_unlock(&netlbl_unlhsh_lock); + if (list_entry != NULL) + entry = netlbl_unlhsh_addr4_entry(list_entry); + else + entry = NULL; + + audit_buf = netlbl_audit_start_common(AUDIT_MAC_UNLBL_STCDEL, + audit_info); + if (audit_buf != NULL) { + dev = dev_get_by_index(net, iface->ifindex); + netlbl_af4list_audit_addr(audit_buf, 1, + (dev != NULL ? dev->name : NULL), + addr->s_addr, mask->s_addr); + dev_put(dev); + if (entry != NULL && + security_secid_to_secctx(entry->secid, + &secctx, &secctx_len) == 0) { + audit_log_format(audit_buf, " sec_obj=%s", secctx); + security_release_secctx(secctx, secctx_len); + } + audit_log_format(audit_buf, " res=%u", entry != NULL ? 1 : 0); + audit_log_end(audit_buf); + } + + if (entry == NULL) + return -ENOENT; + + kfree_rcu(entry, rcu); + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6) +/** + * netlbl_unlhsh_remove_addr6 - Remove an IPv6 address entry + * @net: network namespace + * @iface: interface entry + * @addr: IP address + * @mask: IP address mask + * @audit_info: NetLabel audit information + * + * Description: + * Remove an IP address entry from the unlabeled connection hash table. + * Returns zero on success, negative values on failure. + * + */ +static int netlbl_unlhsh_remove_addr6(struct net *net, + struct netlbl_unlhsh_iface *iface, + const struct in6_addr *addr, + const struct in6_addr *mask, + struct netlbl_audit *audit_info) +{ + struct netlbl_af6list *list_entry; + struct netlbl_unlhsh_addr6 *entry; + struct audit_buffer *audit_buf; + struct net_device *dev; + char *secctx; + u32 secctx_len; + + spin_lock(&netlbl_unlhsh_lock); + list_entry = netlbl_af6list_remove(addr, mask, &iface->addr6_list); + spin_unlock(&netlbl_unlhsh_lock); + if (list_entry != NULL) + entry = netlbl_unlhsh_addr6_entry(list_entry); + else + entry = NULL; + + audit_buf = netlbl_audit_start_common(AUDIT_MAC_UNLBL_STCDEL, + audit_info); + if (audit_buf != NULL) { + dev = dev_get_by_index(net, iface->ifindex); + netlbl_af6list_audit_addr(audit_buf, 1, + (dev != NULL ? dev->name : NULL), + addr, mask); + dev_put(dev); + if (entry != NULL && + security_secid_to_secctx(entry->secid, + &secctx, &secctx_len) == 0) { + audit_log_format(audit_buf, " sec_obj=%s", secctx); + security_release_secctx(secctx, secctx_len); + } + audit_log_format(audit_buf, " res=%u", entry != NULL ? 
1 : 0); + audit_log_end(audit_buf); + } + + if (entry == NULL) + return -ENOENT; + + kfree_rcu(entry, rcu); + return 0; +} +#endif /* IPv6 */ + +/** + * netlbl_unlhsh_condremove_iface - Remove an interface entry + * @iface: the interface entry + * + * Description: + * Remove an interface entry from the unlabeled connection hash table if it is + * empty. An interface entry is considered to be empty if there are no + * address entries assigned to it. + * + */ +static void netlbl_unlhsh_condremove_iface(struct netlbl_unlhsh_iface *iface) +{ + struct netlbl_af4list *iter4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *iter6; +#endif /* IPv6 */ + + spin_lock(&netlbl_unlhsh_lock); + netlbl_af4list_foreach_rcu(iter4, &iface->addr4_list) + goto unlhsh_condremove_failure; +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(iter6, &iface->addr6_list) + goto unlhsh_condremove_failure; +#endif /* IPv6 */ + iface->valid = 0; + if (iface->ifindex > 0) + list_del_rcu(&iface->list); + else + RCU_INIT_POINTER(netlbl_unlhsh_def, NULL); + spin_unlock(&netlbl_unlhsh_lock); + + call_rcu(&iface->rcu, netlbl_unlhsh_free_iface); + return; + +unlhsh_condremove_failure: + spin_unlock(&netlbl_unlhsh_lock); +} + +/** + * netlbl_unlhsh_remove - Remove an entry from the unlabeled hash table + * @net: network namespace + * @dev_name: interface name + * @addr: IP address in network byte order + * @mask: address mask in network byte order + * @addr_len: length of address/mask (4 for IPv4, 16 for IPv6) + * @audit_info: NetLabel audit information + * + * Description: + * Removes and existing entry from the unlabeled connection hash table. + * Returns zero on success, negative values on failure. + * + */ +int netlbl_unlhsh_remove(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u32 addr_len, + struct netlbl_audit *audit_info) +{ + int ret_val; + struct net_device *dev; + struct netlbl_unlhsh_iface *iface; + + if (addr_len != sizeof(struct in_addr) && + addr_len != sizeof(struct in6_addr)) + return -EINVAL; + + rcu_read_lock(); + if (dev_name != NULL) { + dev = dev_get_by_name_rcu(net, dev_name); + if (dev == NULL) { + ret_val = -ENODEV; + goto unlhsh_remove_return; + } + iface = netlbl_unlhsh_search_iface(dev->ifindex); + } else + iface = rcu_dereference(netlbl_unlhsh_def); + if (iface == NULL) { + ret_val = -ENOENT; + goto unlhsh_remove_return; + } + switch (addr_len) { + case sizeof(struct in_addr): + ret_val = netlbl_unlhsh_remove_addr4(net, + iface, addr, mask, + audit_info); + break; +#if IS_ENABLED(CONFIG_IPV6) + case sizeof(struct in6_addr): + ret_val = netlbl_unlhsh_remove_addr6(net, + iface, addr, mask, + audit_info); + break; +#endif /* IPv6 */ + default: + ret_val = -EINVAL; + } + if (ret_val == 0) { + netlbl_unlhsh_condremove_iface(iface); + atomic_dec(&netlabel_mgmt_protocount); + } + +unlhsh_remove_return: + rcu_read_unlock(); + return ret_val; +} + +/* + * General Helper Functions + */ + +/** + * netlbl_unlhsh_netdev_handler - Network device notification handler + * @this: notifier block + * @event: the event + * @ptr: the netdevice notifier info (cast to void) + * + * Description: + * Handle network device events, although at present all we care about is a + * network device going away. In the case of a device going away we clear any + * related entries from the unlabeled connection hash table. 
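+ *
+ * A registration sketch (shown for illustration only; the notifier_block
+ * name below is assumed rather than quoted from the initialization code):
+ *
+ *	static struct notifier_block netlbl_unlhsh_netdev_notifier = {
+ *		.notifier_call = netlbl_unlhsh_netdev_handler,
+ *	};
+ *
+ *	register_netdevice_notifier(&netlbl_unlhsh_netdev_notifier);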
+ * + */ +static int netlbl_unlhsh_netdev_handler(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct netlbl_unlhsh_iface *iface = NULL; + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + /* XXX - should this be a check for NETDEV_DOWN or _UNREGISTER? */ + if (event == NETDEV_DOWN) { + spin_lock(&netlbl_unlhsh_lock); + iface = netlbl_unlhsh_search_iface(dev->ifindex); + if (iface != NULL && iface->valid) { + iface->valid = 0; + list_del_rcu(&iface->list); + } else + iface = NULL; + spin_unlock(&netlbl_unlhsh_lock); + } + + if (iface != NULL) + call_rcu(&iface->rcu, netlbl_unlhsh_free_iface); + + return NOTIFY_DONE; +} + +/** + * netlbl_unlabel_acceptflg_set - Set the unlabeled accept flag + * @value: desired value + * @audit_info: NetLabel audit information + * + * Description: + * Set the value of the unlabeled accept flag to @value. + * + */ +static void netlbl_unlabel_acceptflg_set(u8 value, + struct netlbl_audit *audit_info) +{ + struct audit_buffer *audit_buf; + u8 old_val; + + old_val = netlabel_unlabel_acceptflg; + netlabel_unlabel_acceptflg = value; + audit_buf = netlbl_audit_start_common(AUDIT_MAC_UNLBL_ALLOW, + audit_info); + if (audit_buf != NULL) { + audit_log_format(audit_buf, + " unlbl_accept=%u old=%u", value, old_val); + audit_log_end(audit_buf); + } +} + +/** + * netlbl_unlabel_addrinfo_get - Get the IPv4/6 address information + * @info: the Generic NETLINK info block + * @addr: the IP address + * @mask: the IP address mask + * @len: the address length + * + * Description: + * Examine the Generic NETLINK message and extract the IP address information. + * Returns zero on success, negative values on failure. + * + */ +static int netlbl_unlabel_addrinfo_get(struct genl_info *info, + void **addr, + void **mask, + u32 *len) +{ + u32 addr_len; + + if (info->attrs[NLBL_UNLABEL_A_IPV4ADDR] && + info->attrs[NLBL_UNLABEL_A_IPV4MASK]) { + addr_len = nla_len(info->attrs[NLBL_UNLABEL_A_IPV4ADDR]); + if (addr_len != sizeof(struct in_addr) && + addr_len != nla_len(info->attrs[NLBL_UNLABEL_A_IPV4MASK])) + return -EINVAL; + *len = addr_len; + *addr = nla_data(info->attrs[NLBL_UNLABEL_A_IPV4ADDR]); + *mask = nla_data(info->attrs[NLBL_UNLABEL_A_IPV4MASK]); + return 0; + } else if (info->attrs[NLBL_UNLABEL_A_IPV6ADDR]) { + addr_len = nla_len(info->attrs[NLBL_UNLABEL_A_IPV6ADDR]); + if (addr_len != sizeof(struct in6_addr) && + addr_len != nla_len(info->attrs[NLBL_UNLABEL_A_IPV6MASK])) + return -EINVAL; + *len = addr_len; + *addr = nla_data(info->attrs[NLBL_UNLABEL_A_IPV6ADDR]); + *mask = nla_data(info->attrs[NLBL_UNLABEL_A_IPV6MASK]); + return 0; + } + + return -EINVAL; +} + +/* + * NetLabel Command Handlers + */ + +/** + * netlbl_unlabel_accept - Handle an ACCEPT message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated ACCEPT message and set the accept flag accordingly. + * Returns zero on success, negative values on failure. 
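+ *
+ * For example, a request carrying NLBL_UNLABEL_A_ACPTFLG = 1 turns on the
+ * acceptance of unlabeled packets while a value of 0 turns it back off; any
+ * other value, or a missing attribute, is rejected with -EINVAL.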
+ * + */ +static int netlbl_unlabel_accept(struct sk_buff *skb, struct genl_info *info) +{ + u8 value; + struct netlbl_audit audit_info; + + if (info->attrs[NLBL_UNLABEL_A_ACPTFLG]) { + value = nla_get_u8(info->attrs[NLBL_UNLABEL_A_ACPTFLG]); + if (value == 1 || value == 0) { + netlbl_netlink_auditinfo(&audit_info); + netlbl_unlabel_acceptflg_set(value, &audit_info); + return 0; + } + } + + return -EINVAL; +} + +/** + * netlbl_unlabel_list - Handle a LIST message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated LIST message and respond with the current status. + * Returns zero on success, negative values on failure. + * + */ +static int netlbl_unlabel_list(struct sk_buff *skb, struct genl_info *info) +{ + int ret_val = -EINVAL; + struct sk_buff *ans_skb; + void *data; + + ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (ans_skb == NULL) + goto list_failure; + data = genlmsg_put_reply(ans_skb, info, &netlbl_unlabel_gnl_family, + 0, NLBL_UNLABEL_C_LIST); + if (data == NULL) { + ret_val = -ENOMEM; + goto list_failure; + } + + ret_val = nla_put_u8(ans_skb, + NLBL_UNLABEL_A_ACPTFLG, + netlabel_unlabel_acceptflg); + if (ret_val != 0) + goto list_failure; + + genlmsg_end(ans_skb, data); + return genlmsg_reply(ans_skb, info); + +list_failure: + kfree_skb(ans_skb); + return ret_val; +} + +/** + * netlbl_unlabel_staticadd - Handle a STATICADD message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated STATICADD message and add a new unlabeled + * connection entry to the hash table. Returns zero on success, negative + * values on failure. + * + */ +static int netlbl_unlabel_staticadd(struct sk_buff *skb, + struct genl_info *info) +{ + int ret_val; + char *dev_name; + void *addr; + void *mask; + u32 addr_len; + u32 secid; + struct netlbl_audit audit_info; + + /* Don't allow users to add both IPv4 and IPv6 addresses for a + * single entry. However, allow users to create two entries, one each + * for IPv4 and IPv6, with the same LSM security context which should + * achieve the same result. */ + if (!info->attrs[NLBL_UNLABEL_A_SECCTX] || + !info->attrs[NLBL_UNLABEL_A_IFACE] || + !((!info->attrs[NLBL_UNLABEL_A_IPV4ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV4MASK]) ^ + (!info->attrs[NLBL_UNLABEL_A_IPV6ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV6MASK]))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + ret_val = netlbl_unlabel_addrinfo_get(info, &addr, &mask, &addr_len); + if (ret_val != 0) + return ret_val; + dev_name = nla_data(info->attrs[NLBL_UNLABEL_A_IFACE]); + ret_val = security_secctx_to_secid( + nla_data(info->attrs[NLBL_UNLABEL_A_SECCTX]), + nla_len(info->attrs[NLBL_UNLABEL_A_SECCTX]), + &secid); + if (ret_val != 0) + return ret_val; + + return netlbl_unlhsh_add(&init_net, + dev_name, addr, mask, addr_len, secid, + &audit_info); +} + +/** + * netlbl_unlabel_staticadddef - Handle a STATICADDDEF message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated STATICADDDEF message and add a new default + * unlabeled connection entry. Returns zero on success, negative values on + * failure. + * + */ +static int netlbl_unlabel_staticadddef(struct sk_buff *skb, + struct genl_info *info) +{ + int ret_val; + void *addr; + void *mask; + u32 addr_len; + u32 secid; + struct netlbl_audit audit_info; + + /* Don't allow users to add both IPv4 and IPv6 addresses for a + * single entry. 
However, allow users to create two entries, one each + * for IPv4 and IPv6, with the same LSM security context which should + * achieve the same result. */ + if (!info->attrs[NLBL_UNLABEL_A_SECCTX] || + !((!info->attrs[NLBL_UNLABEL_A_IPV4ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV4MASK]) ^ + (!info->attrs[NLBL_UNLABEL_A_IPV6ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV6MASK]))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + ret_val = netlbl_unlabel_addrinfo_get(info, &addr, &mask, &addr_len); + if (ret_val != 0) + return ret_val; + ret_val = security_secctx_to_secid( + nla_data(info->attrs[NLBL_UNLABEL_A_SECCTX]), + nla_len(info->attrs[NLBL_UNLABEL_A_SECCTX]), + &secid); + if (ret_val != 0) + return ret_val; + + return netlbl_unlhsh_add(&init_net, + NULL, addr, mask, addr_len, secid, + &audit_info); +} + +/** + * netlbl_unlabel_staticremove - Handle a STATICREMOVE message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated STATICREMOVE message and remove the specified + * unlabeled connection entry. Returns zero on success, negative values on + * failure. + * + */ +static int netlbl_unlabel_staticremove(struct sk_buff *skb, + struct genl_info *info) +{ + int ret_val; + char *dev_name; + void *addr; + void *mask; + u32 addr_len; + struct netlbl_audit audit_info; + + /* See the note in netlbl_unlabel_staticadd() about not allowing both + * IPv4 and IPv6 in the same entry. */ + if (!info->attrs[NLBL_UNLABEL_A_IFACE] || + !((!info->attrs[NLBL_UNLABEL_A_IPV4ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV4MASK]) ^ + (!info->attrs[NLBL_UNLABEL_A_IPV6ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV6MASK]))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + ret_val = netlbl_unlabel_addrinfo_get(info, &addr, &mask, &addr_len); + if (ret_val != 0) + return ret_val; + dev_name = nla_data(info->attrs[NLBL_UNLABEL_A_IFACE]); + + return netlbl_unlhsh_remove(&init_net, + dev_name, addr, mask, addr_len, + &audit_info); +} + +/** + * netlbl_unlabel_staticremovedef - Handle a STATICREMOVEDEF message + * @skb: the NETLINK buffer + * @info: the Generic NETLINK info block + * + * Description: + * Process a user generated STATICREMOVEDEF message and remove the default + * unlabeled connection entry. Returns zero on success, negative values on + * failure. + * + */ +static int netlbl_unlabel_staticremovedef(struct sk_buff *skb, + struct genl_info *info) +{ + int ret_val; + void *addr; + void *mask; + u32 addr_len; + struct netlbl_audit audit_info; + + /* See the note in netlbl_unlabel_staticadd() about not allowing both + * IPv4 and IPv6 in the same entry. */ + if (!((!info->attrs[NLBL_UNLABEL_A_IPV4ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV4MASK]) ^ + (!info->attrs[NLBL_UNLABEL_A_IPV6ADDR] || + !info->attrs[NLBL_UNLABEL_A_IPV6MASK]))) + return -EINVAL; + + netlbl_netlink_auditinfo(&audit_info); + + ret_val = netlbl_unlabel_addrinfo_get(info, &addr, &mask, &addr_len); + if (ret_val != 0) + return ret_val; + + return netlbl_unlhsh_remove(&init_net, + NULL, addr, mask, addr_len, + &audit_info); +} + + +/** + * netlbl_unlabel_staticlist_gen - Generate messages for STATICLIST[DEF] + * @cmd: command/message + * @iface: the interface entry + * @addr4: the IPv4 address entry + * @addr6: the IPv6 address entry + * @arg: the netlbl_unlhsh_walk_arg structure + * + * Description: + * This function is designed to be used to generate a response for a + * STATICLIST or STATICLISTDEF message. 
When called either @addr4 or @addr6 + * can be specified, not both, the other unspecified entry should be set to + * NULL by the caller. Returns the size of the message on success, negative + * values on failure. + * + */ +static int netlbl_unlabel_staticlist_gen(u32 cmd, + const struct netlbl_unlhsh_iface *iface, + const struct netlbl_unlhsh_addr4 *addr4, + const struct netlbl_unlhsh_addr6 *addr6, + void *arg) +{ + int ret_val = -ENOMEM; + struct netlbl_unlhsh_walk_arg *cb_arg = arg; + struct net_device *dev; + void *data; + u32 secid; + char *secctx; + u32 secctx_len; + + data = genlmsg_put(cb_arg->skb, NETLINK_CB(cb_arg->nl_cb->skb).portid, + cb_arg->seq, &netlbl_unlabel_gnl_family, + NLM_F_MULTI, cmd); + if (data == NULL) + goto list_cb_failure; + + if (iface->ifindex > 0) { + dev = dev_get_by_index(&init_net, iface->ifindex); + if (!dev) { + ret_val = -ENODEV; + goto list_cb_failure; + } + ret_val = nla_put_string(cb_arg->skb, + NLBL_UNLABEL_A_IFACE, dev->name); + dev_put(dev); + if (ret_val != 0) + goto list_cb_failure; + } + + if (addr4) { + struct in_addr addr_struct; + + addr_struct.s_addr = addr4->list.addr; + ret_val = nla_put_in_addr(cb_arg->skb, + NLBL_UNLABEL_A_IPV4ADDR, + addr_struct.s_addr); + if (ret_val != 0) + goto list_cb_failure; + + addr_struct.s_addr = addr4->list.mask; + ret_val = nla_put_in_addr(cb_arg->skb, + NLBL_UNLABEL_A_IPV4MASK, + addr_struct.s_addr); + if (ret_val != 0) + goto list_cb_failure; + + secid = addr4->secid; + } else { + ret_val = nla_put_in6_addr(cb_arg->skb, + NLBL_UNLABEL_A_IPV6ADDR, + &addr6->list.addr); + if (ret_val != 0) + goto list_cb_failure; + + ret_val = nla_put_in6_addr(cb_arg->skb, + NLBL_UNLABEL_A_IPV6MASK, + &addr6->list.mask); + if (ret_val != 0) + goto list_cb_failure; + + secid = addr6->secid; + } + + ret_val = security_secid_to_secctx(secid, &secctx, &secctx_len); + if (ret_val != 0) + goto list_cb_failure; + ret_val = nla_put(cb_arg->skb, + NLBL_UNLABEL_A_SECCTX, + secctx_len, + secctx); + security_release_secctx(secctx, secctx_len); + if (ret_val != 0) + goto list_cb_failure; + + cb_arg->seq++; + genlmsg_end(cb_arg->skb, data); + return 0; + +list_cb_failure: + genlmsg_cancel(cb_arg->skb, data); + return ret_val; +} + +/** + * netlbl_unlabel_staticlist - Handle a STATICLIST message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated STATICLIST message and dump the unlabeled + * connection hash table in a form suitable for use in a kernel generated + * STATICLIST message. Returns the length of @skb. 
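To make the per-record format built by netlbl_unlabel_staticlist_gen() above concrete, a minimal libnl-3 consumer could look like the sketch below. It assumes the NLBL_UNLABEL_* command/attribute constants and NETLBL_PROTO_VERSION are mirrored in the userspace build (as the netlabel_tools project does) and that the family name resolves from NETLBL_NLTYPE_UNLABELED_NAME ("NLBL_UNLBL"); error handling is omitted for brevity.

	#include <stdio.h>
	#include <netlink/netlink.h>
	#include <netlink/attr.h>
	#include <netlink/genl/genl.h>
	#include <netlink/genl/ctrl.h>

	static int staticlist_record(struct nl_msg *msg, void *arg)
	{
		struct nlattr *tb[NLBL_UNLABEL_A_MAX + 1];

		/* Each NLM_F_MULTI record carries IFACE, ADDR/MASK and SECCTX. */
		if (genlmsg_parse(nlmsg_hdr(msg), 0, tb, NLBL_UNLABEL_A_MAX, NULL) < 0)
			return NL_SKIP;
		if (tb[NLBL_UNLABEL_A_IFACE] && tb[NLBL_UNLABEL_A_SECCTX])
			printf("%s -> %.*s\n",
			       nla_get_string(tb[NLBL_UNLABEL_A_IFACE]),
			       nla_len(tb[NLBL_UNLABEL_A_SECCTX]),
			       (char *)nla_data(tb[NLBL_UNLABEL_A_SECCTX]));
		return NL_OK;
	}

	static int dump_staticlist(void)
	{
		struct nl_sock *sk = nl_socket_alloc();
		int family;

		genl_connect(sk);
		family = genl_ctrl_resolve(sk, "NLBL_UNLBL"); /* NETLBL_NLTYPE_UNLABELED_NAME */
		nl_socket_modify_cb(sk, NL_CB_VALID, NL_CB_CUSTOM, staticlist_record, NULL);
		genl_send_simple(sk, family, NLBL_UNLABEL_C_STATICLIST,
				 NETLBL_PROTO_VERSION, NLM_F_DUMP);
		nl_recvmsgs_default(sk);
		nl_socket_free(sk);
		return 0;
	}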
+ * + */ +static int netlbl_unlabel_staticlist(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct netlbl_unlhsh_walk_arg cb_arg; + u32 skip_bkt = cb->args[0]; + u32 skip_chain = cb->args[1]; + u32 skip_addr4 = cb->args[2]; + u32 iter_bkt, iter_chain = 0, iter_addr4 = 0, iter_addr6 = 0; + struct netlbl_unlhsh_iface *iface; + struct list_head *iter_list; + struct netlbl_af4list *addr4; +#if IS_ENABLED(CONFIG_IPV6) + u32 skip_addr6 = cb->args[3]; + struct netlbl_af6list *addr6; +#endif + + cb_arg.nl_cb = cb; + cb_arg.skb = skb; + cb_arg.seq = cb->nlh->nlmsg_seq; + + rcu_read_lock(); + for (iter_bkt = skip_bkt; + iter_bkt < rcu_dereference(netlbl_unlhsh)->size; + iter_bkt++) { + iter_list = &rcu_dereference(netlbl_unlhsh)->tbl[iter_bkt]; + list_for_each_entry_rcu(iface, iter_list, list) { + if (!iface->valid || + iter_chain++ < skip_chain) + continue; + netlbl_af4list_foreach_rcu(addr4, + &iface->addr4_list) { + if (iter_addr4++ < skip_addr4) + continue; + if (netlbl_unlabel_staticlist_gen( + NLBL_UNLABEL_C_STATICLIST, + iface, + netlbl_unlhsh_addr4_entry(addr4), + NULL, + &cb_arg) < 0) { + iter_addr4--; + iter_chain--; + goto unlabel_staticlist_return; + } + } + iter_addr4 = 0; + skip_addr4 = 0; +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(addr6, + &iface->addr6_list) { + if (iter_addr6++ < skip_addr6) + continue; + if (netlbl_unlabel_staticlist_gen( + NLBL_UNLABEL_C_STATICLIST, + iface, + NULL, + netlbl_unlhsh_addr6_entry(addr6), + &cb_arg) < 0) { + iter_addr6--; + iter_chain--; + goto unlabel_staticlist_return; + } + } + iter_addr6 = 0; + skip_addr6 = 0; +#endif /* IPv6 */ + } + iter_chain = 0; + skip_chain = 0; + } + +unlabel_staticlist_return: + rcu_read_unlock(); + cb->args[0] = iter_bkt; + cb->args[1] = iter_chain; + cb->args[2] = iter_addr4; + cb->args[3] = iter_addr6; + return skb->len; +} + +/** + * netlbl_unlabel_staticlistdef - Handle a STATICLISTDEF message + * @skb: the NETLINK buffer + * @cb: the NETLINK callback + * + * Description: + * Process a user generated STATICLISTDEF message and dump the default + * unlabeled connection entry in a form suitable for use in a kernel generated + * STATICLISTDEF message. Returns the length of @skb. 
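The skip_*/iter_* bookkeeping in netlbl_unlabel_staticlist() above implements the usual netlink dump-resume contract: each callback invocation fills one skb, records how far it got in cb->args[], and the next invocation skips what was already emitted. Distilled into a hypothetical dumpit, where my_item_count() and my_item_fill() are placeholders rather than symbols from this file:

	static int my_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
	{
		unsigned long idx, start = cb->args[0];

		for (idx = start; idx < my_item_count(); idx++) {
			/* Stop when the skb is full; resume from here next time. */
			if (my_item_fill(skb, cb, idx) < 0)
				break;
		}
		cb->args[0] = idx;	/* first index not yet dumped */
		return skb->len;	/* non-zero => netlink core calls us again */
	}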
+ * + */ +static int netlbl_unlabel_staticlistdef(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct netlbl_unlhsh_walk_arg cb_arg; + struct netlbl_unlhsh_iface *iface; + u32 iter_addr4 = 0, iter_addr6 = 0; + struct netlbl_af4list *addr4; +#if IS_ENABLED(CONFIG_IPV6) + struct netlbl_af6list *addr6; +#endif + + cb_arg.nl_cb = cb; + cb_arg.skb = skb; + cb_arg.seq = cb->nlh->nlmsg_seq; + + rcu_read_lock(); + iface = rcu_dereference(netlbl_unlhsh_def); + if (iface == NULL || !iface->valid) + goto unlabel_staticlistdef_return; + + netlbl_af4list_foreach_rcu(addr4, &iface->addr4_list) { + if (iter_addr4++ < cb->args[0]) + continue; + if (netlbl_unlabel_staticlist_gen(NLBL_UNLABEL_C_STATICLISTDEF, + iface, + netlbl_unlhsh_addr4_entry(addr4), + NULL, + &cb_arg) < 0) { + iter_addr4--; + goto unlabel_staticlistdef_return; + } + } +#if IS_ENABLED(CONFIG_IPV6) + netlbl_af6list_foreach_rcu(addr6, &iface->addr6_list) { + if (iter_addr6++ < cb->args[1]) + continue; + if (netlbl_unlabel_staticlist_gen(NLBL_UNLABEL_C_STATICLISTDEF, + iface, + NULL, + netlbl_unlhsh_addr6_entry(addr6), + &cb_arg) < 0) { + iter_addr6--; + goto unlabel_staticlistdef_return; + } + } +#endif /* IPv6 */ + +unlabel_staticlistdef_return: + rcu_read_unlock(); + cb->args[0] = iter_addr4; + cb->args[1] = iter_addr6; + return skb->len; +} + +/* + * NetLabel Generic NETLINK Command Definitions + */ + +static const struct genl_small_ops netlbl_unlabel_genl_ops[] = { + { + .cmd = NLBL_UNLABEL_C_STATICADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_unlabel_staticadd, + .dumpit = NULL, + }, + { + .cmd = NLBL_UNLABEL_C_STATICREMOVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_unlabel_staticremove, + .dumpit = NULL, + }, + { + .cmd = NLBL_UNLABEL_C_STATICLIST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_unlabel_staticlist, + }, + { + .cmd = NLBL_UNLABEL_C_STATICADDDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_unlabel_staticadddef, + .dumpit = NULL, + }, + { + .cmd = NLBL_UNLABEL_C_STATICREMOVEDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_unlabel_staticremovedef, + .dumpit = NULL, + }, + { + .cmd = NLBL_UNLABEL_C_STATICLISTDEF, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = NULL, + .dumpit = netlbl_unlabel_staticlistdef, + }, + { + .cmd = NLBL_UNLABEL_C_ACCEPT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = netlbl_unlabel_accept, + .dumpit = NULL, + }, + { + .cmd = NLBL_UNLABEL_C_LIST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, + .doit = netlbl_unlabel_list, + .dumpit = NULL, + }, +}; + +static struct genl_family netlbl_unlabel_gnl_family __ro_after_init = { + .hdrsize = 0, + .name = NETLBL_NLTYPE_UNLABELED_NAME, + .version = NETLBL_PROTO_VERSION, + .maxattr = NLBL_UNLABEL_A_MAX, + .policy = netlbl_unlabel_genl_policy, + .module = THIS_MODULE, + .small_ops = netlbl_unlabel_genl_ops, + .n_small_ops = ARRAY_SIZE(netlbl_unlabel_genl_ops), + .resv_start_op = NLBL_UNLABEL_C_STATICLISTDEF + 1, +}; + +/* + * NetLabel Generic NETLINK Protocol Functions + */ + +/** + * netlbl_unlabel_genl_init - Register the Unlabeled NetLabel component + * + * Description: 
+ * Register the unlabeled packet NetLabel component with the Generic NETLINK + * mechanism. Returns zero on success, negative values on failure. + * + */ +int __init netlbl_unlabel_genl_init(void) +{ + return genl_register_family(&netlbl_unlabel_gnl_family); +} + +/* + * NetLabel KAPI Hooks + */ + +static struct notifier_block netlbl_unlhsh_netdev_notifier = { + .notifier_call = netlbl_unlhsh_netdev_handler, +}; + +/** + * netlbl_unlabel_init - Initialize the unlabeled connection hash table + * @size: the number of bits to use for the hash buckets + * + * Description: + * Initializes the unlabeled connection hash table and registers a network + * device notification handler. This function should only be called by the + * NetLabel subsystem itself during initialization. Returns zero on success, + * non-zero values on error. + * + */ +int __init netlbl_unlabel_init(u32 size) +{ + u32 iter; + struct netlbl_unlhsh_tbl *hsh_tbl; + + if (size == 0) + return -EINVAL; + + hsh_tbl = kmalloc(sizeof(*hsh_tbl), GFP_KERNEL); + if (hsh_tbl == NULL) + return -ENOMEM; + hsh_tbl->size = 1 << size; + hsh_tbl->tbl = kcalloc(hsh_tbl->size, + sizeof(struct list_head), + GFP_KERNEL); + if (hsh_tbl->tbl == NULL) { + kfree(hsh_tbl); + return -ENOMEM; + } + for (iter = 0; iter < hsh_tbl->size; iter++) + INIT_LIST_HEAD(&hsh_tbl->tbl[iter]); + + spin_lock(&netlbl_unlhsh_lock); + rcu_assign_pointer(netlbl_unlhsh, hsh_tbl); + spin_unlock(&netlbl_unlhsh_lock); + + register_netdevice_notifier(&netlbl_unlhsh_netdev_notifier); + + return 0; +} + +/** + * netlbl_unlabel_getattr - Get the security attributes for an unlabled packet + * @skb: the packet + * @family: protocol family + * @secattr: the security attributes + * + * Description: + * Determine the security attributes, if any, for an unlabled packet and return + * them in @secattr. Returns zero on success and negative values on failure. + * + */ +int netlbl_unlabel_getattr(const struct sk_buff *skb, + u16 family, + struct netlbl_lsm_secattr *secattr) +{ + struct netlbl_unlhsh_iface *iface; + + rcu_read_lock(); + iface = netlbl_unlhsh_search_iface(skb->skb_iif); + if (iface == NULL) + iface = rcu_dereference(netlbl_unlhsh_def); + if (iface == NULL || !iface->valid) + goto unlabel_getattr_nolabel; + +#if IS_ENABLED(CONFIG_IPV6) + /* When resolving a fallback label, check the sk_buff version as + * it is possible (e.g. SCTP) to have family = PF_INET6 while + * receiving ip_hdr(skb)->version = 4. 
+ */ + if (family == PF_INET6 && ip_hdr(skb)->version == 4) + family = PF_INET; +#endif /* IPv6 */ + + switch (family) { + case PF_INET: { + struct iphdr *hdr4; + struct netlbl_af4list *addr4; + + hdr4 = ip_hdr(skb); + addr4 = netlbl_af4list_search(hdr4->saddr, + &iface->addr4_list); + if (addr4 == NULL) + goto unlabel_getattr_nolabel; + secattr->attr.secid = netlbl_unlhsh_addr4_entry(addr4)->secid; + break; + } +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: { + struct ipv6hdr *hdr6; + struct netlbl_af6list *addr6; + + hdr6 = ipv6_hdr(skb); + addr6 = netlbl_af6list_search(&hdr6->saddr, + &iface->addr6_list); + if (addr6 == NULL) + goto unlabel_getattr_nolabel; + secattr->attr.secid = netlbl_unlhsh_addr6_entry(addr6)->secid; + break; + } +#endif /* IPv6 */ + default: + goto unlabel_getattr_nolabel; + } + rcu_read_unlock(); + + secattr->flags |= NETLBL_SECATTR_SECID; + secattr->type = NETLBL_NLTYPE_UNLABELED; + return 0; + +unlabel_getattr_nolabel: + rcu_read_unlock(); + if (netlabel_unlabel_acceptflg == 0) + return -ENOMSG; + secattr->type = NETLBL_NLTYPE_UNLABELED; + return 0; +} + +/** + * netlbl_unlabel_defconf - Set the default config to allow unlabeled packets + * + * Description: + * Set the default NetLabel configuration to allow incoming unlabeled packets + * and to send unlabeled network traffic by default. + * + */ +int __init netlbl_unlabel_defconf(void) +{ + int ret_val; + struct netlbl_dom_map *entry; + struct netlbl_audit audit_info; + + /* Only the kernel is allowed to call this function and the only time + * it is called is at bootup before the audit subsystem is reporting + * messages so don't worry to much about these values. */ + security_current_getsecid_subj(&audit_info.secid); + audit_info.loginuid = GLOBAL_ROOT_UID; + audit_info.sessionid = 0; + + entry = kzalloc(sizeof(*entry), GFP_KERNEL); + if (entry == NULL) + return -ENOMEM; + entry->family = AF_UNSPEC; + entry->def.type = NETLBL_NLTYPE_UNLABELED; + ret_val = netlbl_domhsh_add_default(entry, &audit_info); + if (ret_val != 0) + return ret_val; + + netlbl_unlabel_acceptflg_set(1, &audit_info); + + return 0; +} diff --git a/net/netlabel/netlabel_unlabeled.h b/net/netlabel/netlabel_unlabeled.h new file mode 100644 index 000000000..058e3a285 --- /dev/null +++ b/net/netlabel/netlabel_unlabeled.h @@ -0,0 +1,231 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel Unlabeled Support + * + * This file defines functions for dealing with unlabeled packets for the + * NetLabel system. The NetLabel system manages static and dynamic label + * mappings for network protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#ifndef _NETLABEL_UNLABELED_H +#define _NETLABEL_UNLABELED_H + +#include + +/* + * The following NetLabel payloads are supported by the Unlabeled subsystem. + * + * o STATICADD + * This message is sent from an application to add a new static label for + * incoming unlabeled connections. + * + * Required attributes: + * + * NLBL_UNLABEL_A_IFACE + * NLBL_UNLABEL_A_SECCTX + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o STATICREMOVE + * This message is sent from an application to remove an existing static + * label for incoming unlabeled connections. 
+ * + * Required attributes: + * + * NLBL_UNLABEL_A_IFACE + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o STATICLIST + * This message can be sent either from an application or by the kernel in + * response to an application generated STATICLIST message. When sent by an + * application there is no payload and the NLM_F_DUMP flag should be set. + * The kernel should response with a series of the following messages. + * + * Required attributes: + * + * NLBL_UNLABEL_A_IFACE + * NLBL_UNLABEL_A_SECCTX + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o STATICADDDEF + * This message is sent from an application to set the default static + * label for incoming unlabeled connections. + * + * Required attribute: + * + * NLBL_UNLABEL_A_SECCTX + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o STATICREMOVEDEF + * This message is sent from an application to remove the existing default + * static label for incoming unlabeled connections. + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o STATICLISTDEF + * This message can be sent either from an application or by the kernel in + * response to an application generated STATICLISTDEF message. When sent by + * an application there is no payload and the NLM_F_DUMP flag should be set. + * The kernel should response with the following message. + * + * Required attribute: + * + * NLBL_UNLABEL_A_SECCTX + * + * If IPv4 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV4ADDR + * NLBL_UNLABEL_A_IPV4MASK + * + * If IPv6 is specified the following attributes are required: + * + * NLBL_UNLABEL_A_IPV6ADDR + * NLBL_UNLABEL_A_IPV6MASK + * + * o ACCEPT + * This message is sent from an application to specify if the kernel should + * allow unlabled packets to pass if they do not match any of the static + * mappings defined in the unlabeled module. + * + * Required attributes: + * + * NLBL_UNLABEL_A_ACPTFLG + * + * o LIST + * This message can be sent either from an application or by the kernel in + * response to an application generated LIST message. When sent by an + * application there is no payload. The kernel should respond to a LIST + * message with a LIST message on success. 
+ * + * Required attributes: + * + * NLBL_UNLABEL_A_ACPTFLG + * + */ + +/* NetLabel Unlabeled commands */ +enum { + NLBL_UNLABEL_C_UNSPEC, + NLBL_UNLABEL_C_ACCEPT, + NLBL_UNLABEL_C_LIST, + NLBL_UNLABEL_C_STATICADD, + NLBL_UNLABEL_C_STATICREMOVE, + NLBL_UNLABEL_C_STATICLIST, + NLBL_UNLABEL_C_STATICADDDEF, + NLBL_UNLABEL_C_STATICREMOVEDEF, + NLBL_UNLABEL_C_STATICLISTDEF, + __NLBL_UNLABEL_C_MAX, +}; + +/* NetLabel Unlabeled attributes */ +enum { + NLBL_UNLABEL_A_UNSPEC, + NLBL_UNLABEL_A_ACPTFLG, + /* (NLA_U8) + * if true then unlabeled packets are allowed to pass, else unlabeled + * packets are rejected */ + NLBL_UNLABEL_A_IPV6ADDR, + /* (NLA_BINARY, struct in6_addr) + * an IPv6 address */ + NLBL_UNLABEL_A_IPV6MASK, + /* (NLA_BINARY, struct in6_addr) + * an IPv6 address mask */ + NLBL_UNLABEL_A_IPV4ADDR, + /* (NLA_BINARY, struct in_addr) + * an IPv4 address */ + NLBL_UNLABEL_A_IPV4MASK, + /* (NLA_BINARY, struct in_addr) + * and IPv4 address mask */ + NLBL_UNLABEL_A_IFACE, + /* (NLA_NULL_STRING) + * network interface */ + NLBL_UNLABEL_A_SECCTX, + /* (NLA_BINARY) + * a LSM specific security context */ + __NLBL_UNLABEL_A_MAX, +}; +#define NLBL_UNLABEL_A_MAX (__NLBL_UNLABEL_A_MAX - 1) + +/* NetLabel protocol functions */ +int netlbl_unlabel_genl_init(void); + +/* Unlabeled connection hash table size */ +/* XXX - currently this number is an uneducated guess */ +#define NETLBL_UNLHSH_BITSIZE 7 + +/* General Unlabeled init function */ +int netlbl_unlabel_init(u32 size); + +/* Static/Fallback label management functions */ +int netlbl_unlhsh_add(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u32 addr_len, + u32 secid, + struct netlbl_audit *audit_info); +int netlbl_unlhsh_remove(struct net *net, + const char *dev_name, + const void *addr, + const void *mask, + u32 addr_len, + struct netlbl_audit *audit_info); + +/* Process Unlabeled incoming network packets */ +int netlbl_unlabel_getattr(const struct sk_buff *skb, + u16 family, + struct netlbl_lsm_secattr *secattr); + +/* Set the default configuration to allow Unlabeled packets */ +int netlbl_unlabel_defconf(void); + +#endif diff --git a/net/netlabel/netlabel_user.c b/net/netlabel/netlabel_user.c new file mode 100644 index 000000000..3ed4fea2a --- /dev/null +++ b/net/netlabel/netlabel_user.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NetLabel NETLINK Interface + * + * This file defines the NETLINK interface for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netlabel_mgmt.h" +#include "netlabel_unlabeled.h" +#include "netlabel_cipso_v4.h" +#include "netlabel_calipso.h" +#include "netlabel_user.h" + +/* + * NetLabel NETLINK Setup Functions + */ + +/** + * netlbl_netlink_init - Initialize the NETLINK communication channel + * + * Description: + * Call out to the NetLabel components so they can register their families and + * commands with the Generic NETLINK mechanism. Returns zero on success and + * non-zero on failure. 
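Given the commands and attributes above, toggling the accept flag from userspace is a single u8 attribute on an ACCEPT request. A minimal libnl-3 sketch, assuming the NLBL_UNLABEL_* values and NETLBL_PROTO_VERSION are mirrored in the userspace build and the family name comes from NETLBL_NLTYPE_UNLABELED_NAME ("NLBL_UNLBL"); the op is registered with GENL_ADMIN_PERM, so CAP_NET_ADMIN is required:

	#include <netlink/netlink.h>
	#include <netlink/genl/genl.h>
	#include <netlink/genl/ctrl.h>

	int nlbl_unlabel_accept_on(void)
	{
		struct nl_sock *sk = nl_socket_alloc();
		struct nl_msg *msg = nlmsg_alloc();
		int family, ret;

		genl_connect(sk);
		family = genl_ctrl_resolve(sk, "NLBL_UNLBL"); /* NETLBL_NLTYPE_UNLABELED_NAME */

		/* ACCEPT carries exactly one attribute: the u8 accept flag. */
		genlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, family, 0, 0,
			    NLBL_UNLABEL_C_ACCEPT, NETLBL_PROTO_VERSION);
		nla_put_u8(msg, NLBL_UNLABEL_A_ACPTFLG, 1);

		ret = nl_send_auto(sk, msg);
		nlmsg_free(msg);
		nl_socket_free(sk);
		return ret < 0 ? ret : 0;
	}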
+ * + */ +int __init netlbl_netlink_init(void) +{ + int ret_val; + + ret_val = netlbl_mgmt_genl_init(); + if (ret_val != 0) + return ret_val; + + ret_val = netlbl_cipsov4_genl_init(); + if (ret_val != 0) + return ret_val; + + ret_val = netlbl_calipso_genl_init(); + if (ret_val != 0) + return ret_val; + + return netlbl_unlabel_genl_init(); +} + +/* + * NetLabel Audit Functions + */ + +/** + * netlbl_audit_start_common - Start an audit message + * @type: audit message type + * @audit_info: NetLabel audit information + * + * Description: + * Start an audit message using the type specified in @type and fill the audit + * message with some fields common to all NetLabel audit messages. Returns + * a pointer to the audit buffer on success, NULL on failure. + * + */ +struct audit_buffer *netlbl_audit_start_common(int type, + struct netlbl_audit *audit_info) +{ + struct audit_buffer *audit_buf; + char *secctx; + u32 secctx_len; + + if (audit_enabled == AUDIT_OFF) + return NULL; + + audit_buf = audit_log_start(audit_context(), GFP_ATOMIC, type); + if (audit_buf == NULL) + return NULL; + + audit_log_format(audit_buf, "netlabel: auid=%u ses=%u", + from_kuid(&init_user_ns, audit_info->loginuid), + audit_info->sessionid); + + if (audit_info->secid != 0 && + security_secid_to_secctx(audit_info->secid, + &secctx, + &secctx_len) == 0) { + audit_log_format(audit_buf, " subj=%s", secctx); + security_release_secctx(secctx, secctx_len); + } + + return audit_buf; +} diff --git a/net/netlabel/netlabel_user.h b/net/netlabel/netlabel_user.h new file mode 100644 index 000000000..d6c5b31eb --- /dev/null +++ b/net/netlabel/netlabel_user.h @@ -0,0 +1,49 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * NetLabel NETLINK Interface + * + * This file defines the NETLINK interface for the NetLabel system. The + * NetLabel system manages static and dynamic label mappings for network + * protocols such as CIPSO and RIPSO. + * + * Author: Paul Moore + */ + +/* + * (c) Copyright Hewlett-Packard Development Company, L.P., 2006 + */ + +#ifndef _NETLABEL_USER_H +#define _NETLABEL_USER_H + +#include +#include +#include +#include +#include +#include +#include + +/* NetLabel NETLINK helper functions */ + +/** + * netlbl_netlink_auditinfo - Fetch the audit information from a NETLINK msg + * @audit_info: NetLabel audit information + */ +static inline void netlbl_netlink_auditinfo(struct netlbl_audit *audit_info) +{ + security_current_getsecid_subj(&audit_info->secid); + audit_info->loginuid = audit_get_loginuid(current); + audit_info->sessionid = audit_get_sessionid(current); +} + +/* NetLabel NETLINK I/O functions */ + +int netlbl_netlink_init(void); + +/* NetLabel Audit Functions */ + +struct audit_buffer *netlbl_audit_start_common(int type, + struct netlbl_audit *audit_info); + +#endif diff --git a/net/netlink/Kconfig b/net/netlink/Kconfig new file mode 100644 index 000000000..1039d4f2c --- /dev/null +++ b/net/netlink/Kconfig @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Netlink Sockets +# + +config NETLINK_DIAG + tristate "NETLINK: socket monitoring interface" + default n + help + Support for NETLINK socket monitoring interface used by the ss tool. + If unsure, say Y. diff --git a/net/netlink/Makefile b/net/netlink/Makefile new file mode 100644 index 000000000..e05202708 --- /dev/null +++ b/net/netlink/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the netlink driver. 
+# + +obj-y := af_netlink.o genetlink.o policy.o + +obj-$(CONFIG_NETLINK_DIAG) += netlink_diag.o +netlink_diag-y := diag.o diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c new file mode 100644 index 000000000..6857a4965 --- /dev/null +++ b/net/netlink/af_netlink.c @@ -0,0 +1,2917 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NETLINK Kernel-user communication protocol. + * + * Authors: Alan Cox + * Alexey Kuznetsov + * Patrick McHardy + * + * Tue Jun 26 14:36:48 MEST 2001 Herbert "herp" Rosmanith + * added netlink_proto_exit + * Tue Jan 22 18:32:44 BRST 2002 Arnaldo C. de Melo + * use nlk_sk, as sk->protinfo is on a diet 8) + * Fri Jul 22 19:51:12 MEST 2005 Harald Welte + * - inc module use count of module that owns + * the kernel socket in case userspace opens + * socket of same protocol + * - remove all module support, since netlink is + * mandatory if CONFIG_NET=y these days + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#define CREATE_TRACE_POINTS +#include + +#include "af_netlink.h" + +struct listeners { + struct rcu_head rcu; + unsigned long masks[]; +}; + +/* state bits */ +#define NETLINK_S_CONGESTED 0x0 + +static inline int netlink_is_kernel(struct sock *sk) +{ + return nlk_test_bit(KERNEL_SOCKET, sk); +} + +struct netlink_table *nl_table __read_mostly; +EXPORT_SYMBOL_GPL(nl_table); + +static DECLARE_WAIT_QUEUE_HEAD(nl_table_wait); + +static struct lock_class_key nlk_cb_mutex_keys[MAX_LINKS]; + +static const char *const nlk_cb_mutex_key_strings[MAX_LINKS + 1] = { + "nlk_cb_mutex-ROUTE", + "nlk_cb_mutex-1", + "nlk_cb_mutex-USERSOCK", + "nlk_cb_mutex-FIREWALL", + "nlk_cb_mutex-SOCK_DIAG", + "nlk_cb_mutex-NFLOG", + "nlk_cb_mutex-XFRM", + "nlk_cb_mutex-SELINUX", + "nlk_cb_mutex-ISCSI", + "nlk_cb_mutex-AUDIT", + "nlk_cb_mutex-FIB_LOOKUP", + "nlk_cb_mutex-CONNECTOR", + "nlk_cb_mutex-NETFILTER", + "nlk_cb_mutex-IP6_FW", + "nlk_cb_mutex-DNRTMSG", + "nlk_cb_mutex-KOBJECT_UEVENT", + "nlk_cb_mutex-GENERIC", + "nlk_cb_mutex-17", + "nlk_cb_mutex-SCSITRANSPORT", + "nlk_cb_mutex-ECRYPTFS", + "nlk_cb_mutex-RDMA", + "nlk_cb_mutex-CRYPTO", + "nlk_cb_mutex-SMC", + "nlk_cb_mutex-23", + "nlk_cb_mutex-24", + "nlk_cb_mutex-25", + "nlk_cb_mutex-26", + "nlk_cb_mutex-27", + "nlk_cb_mutex-28", + "nlk_cb_mutex-29", + "nlk_cb_mutex-30", + "nlk_cb_mutex-31", + "nlk_cb_mutex-MAX_LINKS" +}; + +static int netlink_dump(struct sock *sk); + +/* nl_table locking explained: + * Lookup and traversal are protected with an RCU read-side lock. Insertion + * and removal are protected with per bucket lock while using RCU list + * modification primitives and may run in parallel to RCU protected lookups. + * Destruction of the Netlink socket may only occur *after* nl_table_lock has + * been acquired * either during or after the socket has been removed from + * the list and after an RCU grace period. 
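The read side this locking comment describes ends up looking like the fragment below; it is essentially the body of netlink_lookup() later in this file. The hash walk runs under RCU, and the reference is taken before the read-side section is left, so a concurrent removal cannot free the socket under the caller.

	rcu_read_lock();
	sk = __netlink_lookup(table, portid, net);	/* RCU-safe rhashtable lookup */
	if (sk)
		sock_hold(sk);				/* pin before dropping the read lock */
	rcu_read_unlock();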
+ */ +DEFINE_RWLOCK(nl_table_lock); +EXPORT_SYMBOL_GPL(nl_table_lock); +static atomic_t nl_table_users = ATOMIC_INIT(0); + +#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock)); + +static BLOCKING_NOTIFIER_HEAD(netlink_chain); + + +static const struct rhashtable_params netlink_rhashtable_params; + +void do_trace_netlink_extack(const char *msg) +{ + trace_netlink_extack(msg); +} +EXPORT_SYMBOL(do_trace_netlink_extack); + +static inline u32 netlink_group_mask(u32 group) +{ + if (group > 32) + return 0; + return group ? 1 << (group - 1) : 0; +} + +static struct sk_buff *netlink_to_full_skb(const struct sk_buff *skb, + gfp_t gfp_mask) +{ + unsigned int len = skb_end_offset(skb); + struct sk_buff *new; + + new = alloc_skb(len, gfp_mask); + if (new == NULL) + return NULL; + + NETLINK_CB(new).portid = NETLINK_CB(skb).portid; + NETLINK_CB(new).dst_group = NETLINK_CB(skb).dst_group; + NETLINK_CB(new).creds = NETLINK_CB(skb).creds; + + skb_put_data(new, skb->data, len); + return new; +} + +static unsigned int netlink_tap_net_id; + +struct netlink_tap_net { + struct list_head netlink_tap_all; + struct mutex netlink_tap_lock; +}; + +int netlink_add_tap(struct netlink_tap *nt) +{ + struct net *net = dev_net(nt->dev); + struct netlink_tap_net *nn = net_generic(net, netlink_tap_net_id); + + if (unlikely(nt->dev->type != ARPHRD_NETLINK)) + return -EINVAL; + + mutex_lock(&nn->netlink_tap_lock); + list_add_rcu(&nt->list, &nn->netlink_tap_all); + mutex_unlock(&nn->netlink_tap_lock); + + __module_get(nt->module); + + return 0; +} +EXPORT_SYMBOL_GPL(netlink_add_tap); + +static int __netlink_remove_tap(struct netlink_tap *nt) +{ + struct net *net = dev_net(nt->dev); + struct netlink_tap_net *nn = net_generic(net, netlink_tap_net_id); + bool found = false; + struct netlink_tap *tmp; + + mutex_lock(&nn->netlink_tap_lock); + + list_for_each_entry(tmp, &nn->netlink_tap_all, list) { + if (nt == tmp) { + list_del_rcu(&nt->list); + found = true; + goto out; + } + } + + pr_warn("__netlink_remove_tap: %p not found\n", nt); +out: + mutex_unlock(&nn->netlink_tap_lock); + + if (found) + module_put(nt->module); + + return found ? 0 : -ENODEV; +} + +int netlink_remove_tap(struct netlink_tap *nt) +{ + int ret; + + ret = __netlink_remove_tap(nt); + synchronize_net(); + + return ret; +} +EXPORT_SYMBOL_GPL(netlink_remove_tap); + +static __net_init int netlink_tap_init_net(struct net *net) +{ + struct netlink_tap_net *nn = net_generic(net, netlink_tap_net_id); + + INIT_LIST_HEAD(&nn->netlink_tap_all); + mutex_init(&nn->netlink_tap_lock); + return 0; +} + +static struct pernet_operations netlink_tap_net_ops = { + .init = netlink_tap_init_net, + .id = &netlink_tap_net_id, + .size = sizeof(struct netlink_tap_net), +}; + +static bool netlink_filter_tap(const struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + + /* We take the more conservative approach and + * whitelist socket protocols that may pass. 
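For the tap API above, a monitoring module (the in-tree nlmon driver is the existing user) would register roughly as follows. The device and wrapper functions are hypothetical; only struct netlink_tap and netlink_add_tap()/netlink_remove_tap() are taken from this file and include/linux/netlink.h.

	static struct netlink_tap my_tap;	/* hypothetical monitoring module */

	static int my_tap_attach(struct net_device *dev)
	{
		/* netlink_add_tap() rejects anything but ARPHRD_NETLINK devices. */
		my_tap.dev = dev;
		my_tap.module = THIS_MODULE;
		return netlink_add_tap(&my_tap);
	}

	static void my_tap_detach(void)
	{
		/* netlink_remove_tap() synchronizes with RCU readers before returning. */
		netlink_remove_tap(&my_tap);
	}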
+ */ + switch (sk->sk_protocol) { + case NETLINK_ROUTE: + case NETLINK_USERSOCK: + case NETLINK_SOCK_DIAG: + case NETLINK_NFLOG: + case NETLINK_XFRM: + case NETLINK_FIB_LOOKUP: + case NETLINK_NETFILTER: + case NETLINK_GENERIC: + return true; + } + + return false; +} + +static int __netlink_deliver_tap_skb(struct sk_buff *skb, + struct net_device *dev) +{ + struct sk_buff *nskb; + struct sock *sk = skb->sk; + int ret = -ENOMEM; + + if (!net_eq(dev_net(dev), sock_net(sk))) + return 0; + + dev_hold(dev); + + if (is_vmalloc_addr(skb->head)) + nskb = netlink_to_full_skb(skb, GFP_ATOMIC); + else + nskb = skb_clone(skb, GFP_ATOMIC); + if (nskb) { + nskb->dev = dev; + nskb->protocol = htons((u16) sk->sk_protocol); + nskb->pkt_type = netlink_is_kernel(sk) ? + PACKET_KERNEL : PACKET_USER; + skb_reset_network_header(nskb); + ret = dev_queue_xmit(nskb); + if (unlikely(ret > 0)) + ret = net_xmit_errno(ret); + } + + dev_put(dev); + return ret; +} + +static void __netlink_deliver_tap(struct sk_buff *skb, struct netlink_tap_net *nn) +{ + int ret; + struct netlink_tap *tmp; + + if (!netlink_filter_tap(skb)) + return; + + list_for_each_entry_rcu(tmp, &nn->netlink_tap_all, list) { + ret = __netlink_deliver_tap_skb(skb, tmp->dev); + if (unlikely(ret)) + break; + } +} + +static void netlink_deliver_tap(struct net *net, struct sk_buff *skb) +{ + struct netlink_tap_net *nn = net_generic(net, netlink_tap_net_id); + + rcu_read_lock(); + + if (unlikely(!list_empty(&nn->netlink_tap_all))) + __netlink_deliver_tap(skb, nn); + + rcu_read_unlock(); +} + +static void netlink_deliver_tap_kernel(struct sock *dst, struct sock *src, + struct sk_buff *skb) +{ + if (!(netlink_is_kernel(dst) && netlink_is_kernel(src))) + netlink_deliver_tap(sock_net(dst), skb); +} + +static void netlink_overrun(struct sock *sk) +{ + if (!nlk_test_bit(RECV_NO_ENOBUFS, sk)) { + if (!test_and_set_bit(NETLINK_S_CONGESTED, + &nlk_sk(sk)->state)) { + WRITE_ONCE(sk->sk_err, ENOBUFS); + sk_error_report(sk); + } + } + atomic_inc(&sk->sk_drops); +} + +static void netlink_rcv_wake(struct sock *sk) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (skb_queue_empty_lockless(&sk->sk_receive_queue)) + clear_bit(NETLINK_S_CONGESTED, &nlk->state); + if (!test_bit(NETLINK_S_CONGESTED, &nlk->state)) + wake_up_interruptible(&nlk->wait); +} + +static void netlink_skb_destructor(struct sk_buff *skb) +{ + if (is_vmalloc_addr(skb->head)) { + if (!skb->cloned || + !atomic_dec_return(&(skb_shinfo(skb)->dataref))) + vfree_atomic(skb->head); + + skb->head = NULL; + } + if (skb->sk != NULL) + sock_rfree(skb); +} + +static void netlink_skb_set_owner_r(struct sk_buff *skb, struct sock *sk) +{ + WARN_ON(skb->sk != NULL); + skb->sk = sk; + skb->destructor = netlink_skb_destructor; + atomic_add(skb->truesize, &sk->sk_rmem_alloc); + sk_mem_charge(sk, skb->truesize); +} + +static void netlink_sock_destruct(struct sock *sk) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (nlk->cb_running) { + if (nlk->cb.done) + nlk->cb.done(&nlk->cb); + module_put(nlk->cb.module); + kfree_skb(nlk->cb.skb); + } + + skb_queue_purge(&sk->sk_receive_queue); + + if (!sock_flag(sk, SOCK_DEAD)) { + printk(KERN_ERR "Freeing alive netlink socket %p\n", sk); + return; + } + + WARN_ON(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); + WARN_ON(nlk_sk(sk)->groups); +} + +static void netlink_sock_destruct_work(struct work_struct *work) +{ + struct netlink_sock *nlk = container_of(work, struct netlink_sock, + work); + + sk_free(&nlk->sk); +} + +/* This lock without 
WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on + * SMP. Look, when several writers sleep and reader wakes them up, all but one + * immediately hit write lock and grab all the cpus. Exclusive sleep solves + * this, _but_ remember, it adds useless work on UP machines. + */ + +void netlink_table_grab(void) + __acquires(nl_table_lock) +{ + might_sleep(); + + write_lock_irq(&nl_table_lock); + + if (atomic_read(&nl_table_users)) { + DECLARE_WAITQUEUE(wait, current); + + add_wait_queue_exclusive(&nl_table_wait, &wait); + for (;;) { + set_current_state(TASK_UNINTERRUPTIBLE); + if (atomic_read(&nl_table_users) == 0) + break; + write_unlock_irq(&nl_table_lock); + schedule(); + write_lock_irq(&nl_table_lock); + } + + __set_current_state(TASK_RUNNING); + remove_wait_queue(&nl_table_wait, &wait); + } +} + +void netlink_table_ungrab(void) + __releases(nl_table_lock) +{ + write_unlock_irq(&nl_table_lock); + wake_up(&nl_table_wait); +} + +static inline void +netlink_lock_table(void) +{ + unsigned long flags; + + /* read_lock() synchronizes us to netlink_table_grab */ + + read_lock_irqsave(&nl_table_lock, flags); + atomic_inc(&nl_table_users); + read_unlock_irqrestore(&nl_table_lock, flags); +} + +static inline void +netlink_unlock_table(void) +{ + if (atomic_dec_and_test(&nl_table_users)) + wake_up(&nl_table_wait); +} + +struct netlink_compare_arg +{ + possible_net_t pnet; + u32 portid; +}; + +/* Doing sizeof directly may yield 4 extra bytes on 64-bit. */ +#define netlink_compare_arg_len \ + (offsetof(struct netlink_compare_arg, portid) + sizeof(u32)) + +static inline int netlink_compare(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct netlink_compare_arg *x = arg->key; + const struct netlink_sock *nlk = ptr; + + return nlk->portid != x->portid || + !net_eq(sock_net(&nlk->sk), read_pnet(&x->pnet)); +} + +static void netlink_compare_arg_init(struct netlink_compare_arg *arg, + struct net *net, u32 portid) +{ + memset(arg, 0, sizeof(*arg)); + write_pnet(&arg->pnet, net); + arg->portid = portid; +} + +static struct sock *__netlink_lookup(struct netlink_table *table, u32 portid, + struct net *net) +{ + struct netlink_compare_arg arg; + + netlink_compare_arg_init(&arg, net, portid); + return rhashtable_lookup_fast(&table->hash, &arg, + netlink_rhashtable_params); +} + +static int __netlink_insert(struct netlink_table *table, struct sock *sk) +{ + struct netlink_compare_arg arg; + + netlink_compare_arg_init(&arg, sock_net(sk), nlk_sk(sk)->portid); + return rhashtable_lookup_insert_key(&table->hash, &arg, + &nlk_sk(sk)->node, + netlink_rhashtable_params); +} + +static struct sock *netlink_lookup(struct net *net, int protocol, u32 portid) +{ + struct netlink_table *table = &nl_table[protocol]; + struct sock *sk; + + rcu_read_lock(); + sk = __netlink_lookup(table, portid, net); + if (sk) + sock_hold(sk); + rcu_read_unlock(); + + return sk; +} + +static const struct proto_ops netlink_ops; + +static void +netlink_update_listeners(struct sock *sk) +{ + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + unsigned long mask; + unsigned int i; + struct listeners *listeners; + + listeners = nl_deref_protected(tbl->listeners); + if (!listeners) + return; + + for (i = 0; i < NLGRPLONGS(tbl->groups); i++) { + mask = 0; + sk_for_each_bound(sk, &tbl->mc_list) { + if (i < NLGRPLONGS(nlk_sk(sk)->ngroups)) + mask |= nlk_sk(sk)->groups[i]; + } + listeners->masks[i] = mask; + } + /* this function is only called with the netlink table "grabbed", which + * makes sure updates are visible before 
bind or setsockopt return. */ +} + +static int netlink_insert(struct sock *sk, u32 portid) +{ + struct netlink_table *table = &nl_table[sk->sk_protocol]; + int err; + + lock_sock(sk); + + err = nlk_sk(sk)->portid == portid ? 0 : -EBUSY; + if (nlk_sk(sk)->bound) + goto err; + + /* portid can be read locklessly from netlink_getname(). */ + WRITE_ONCE(nlk_sk(sk)->portid, portid); + + sock_hold(sk); + + err = __netlink_insert(table, sk); + if (err) { + /* In case the hashtable backend returns with -EBUSY + * from here, it must not escape to the caller. + */ + if (unlikely(err == -EBUSY)) + err = -EOVERFLOW; + if (err == -EEXIST) + err = -EADDRINUSE; + sock_put(sk); + goto err; + } + + /* We need to ensure that the socket is hashed and visible. */ + smp_wmb(); + /* Paired with lockless reads from netlink_bind(), + * netlink_connect() and netlink_sendmsg(). + */ + WRITE_ONCE(nlk_sk(sk)->bound, portid); + +err: + release_sock(sk); + return err; +} + +static void netlink_remove(struct sock *sk) +{ + struct netlink_table *table; + + table = &nl_table[sk->sk_protocol]; + if (!rhashtable_remove_fast(&table->hash, &nlk_sk(sk)->node, + netlink_rhashtable_params)) { + WARN_ON(refcount_read(&sk->sk_refcnt) == 1); + __sock_put(sk); + } + + netlink_table_grab(); + if (nlk_sk(sk)->subscriptions) { + __sk_del_bind_node(sk); + netlink_update_listeners(sk); + } + if (sk->sk_protocol == NETLINK_GENERIC) + atomic_inc(&genl_sk_destructing_cnt); + netlink_table_ungrab(); +} + +static struct proto netlink_proto = { + .name = "NETLINK", + .owner = THIS_MODULE, + .obj_size = sizeof(struct netlink_sock), +}; + +static int __netlink_create(struct net *net, struct socket *sock, + struct mutex *cb_mutex, int protocol, + int kern) +{ + struct sock *sk; + struct netlink_sock *nlk; + + sock->ops = &netlink_ops; + + sk = sk_alloc(net, PF_NETLINK, GFP_KERNEL, &netlink_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + nlk = nlk_sk(sk); + if (cb_mutex) { + nlk->cb_mutex = cb_mutex; + } else { + nlk->cb_mutex = &nlk->cb_def_mutex; + mutex_init(nlk->cb_mutex); + lockdep_set_class_and_name(nlk->cb_mutex, + nlk_cb_mutex_keys + protocol, + nlk_cb_mutex_key_strings[protocol]); + } + init_waitqueue_head(&nlk->wait); + + sk->sk_destruct = netlink_sock_destruct; + sk->sk_protocol = protocol; + return 0; +} + +static int netlink_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct module *module = NULL; + struct mutex *cb_mutex; + struct netlink_sock *nlk; + int (*bind)(struct net *net, int group); + void (*unbind)(struct net *net, int group); + int err = 0; + + sock->state = SS_UNCONNECTED; + + if (sock->type != SOCK_RAW && sock->type != SOCK_DGRAM) + return -ESOCKTNOSUPPORT; + + if (protocol < 0 || protocol >= MAX_LINKS) + return -EPROTONOSUPPORT; + protocol = array_index_nospec(protocol, MAX_LINKS); + + netlink_lock_table(); +#ifdef CONFIG_MODULES + if (!nl_table[protocol].registered) { + netlink_unlock_table(); + request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol); + netlink_lock_table(); + } +#endif + if (nl_table[protocol].registered && + try_module_get(nl_table[protocol].module)) + module = nl_table[protocol].module; + else + err = -EPROTONOSUPPORT; + cb_mutex = nl_table[protocol].cb_mutex; + bind = nl_table[protocol].bind; + unbind = nl_table[protocol].unbind; + netlink_unlock_table(); + + if (err < 0) + goto out; + + err = __netlink_create(net, sock, cb_mutex, protocol, kern); + if (err < 0) + goto out_module; + + sock_prot_inuse_add(net, &netlink_proto, 1); + + nlk 
= nlk_sk(sock->sk); + nlk->module = module; + nlk->netlink_bind = bind; + nlk->netlink_unbind = unbind; +out: + return err; + +out_module: + module_put(module); + goto out; +} + +static void deferred_put_nlk_sk(struct rcu_head *head) +{ + struct netlink_sock *nlk = container_of(head, struct netlink_sock, rcu); + struct sock *sk = &nlk->sk; + + kfree(nlk->groups); + nlk->groups = NULL; + + if (!refcount_dec_and_test(&sk->sk_refcnt)) + return; + + if (nlk->cb_running && nlk->cb.done) { + INIT_WORK(&nlk->work, netlink_sock_destruct_work); + schedule_work(&nlk->work); + return; + } + + sk_free(sk); +} + +static int netlink_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk; + + if (!sk) + return 0; + + netlink_remove(sk); + sock_orphan(sk); + nlk = nlk_sk(sk); + + /* + * OK. Socket is unlinked, any packets that arrive now + * will be purged. + */ + + /* must not acquire netlink_table_lock in any way again before unbind + * and notifying genetlink is done as otherwise it might deadlock + */ + if (nlk->netlink_unbind) { + int i; + + for (i = 0; i < nlk->ngroups; i++) + if (test_bit(i, nlk->groups)) + nlk->netlink_unbind(sock_net(sk), i + 1); + } + if (sk->sk_protocol == NETLINK_GENERIC && + atomic_dec_return(&genl_sk_destructing_cnt) == 0) + wake_up(&genl_sk_destructing_waitq); + + sock->sk = NULL; + wake_up_interruptible_all(&nlk->wait); + + skb_queue_purge(&sk->sk_write_queue); + + if (nlk->portid && nlk->bound) { + struct netlink_notify n = { + .net = sock_net(sk), + .protocol = sk->sk_protocol, + .portid = nlk->portid, + }; + blocking_notifier_call_chain(&netlink_chain, + NETLINK_URELEASE, &n); + } + + module_put(nlk->module); + + if (netlink_is_kernel(sk)) { + netlink_table_grab(); + BUG_ON(nl_table[sk->sk_protocol].registered == 0); + if (--nl_table[sk->sk_protocol].registered == 0) { + struct listeners *old; + + old = nl_deref_protected(nl_table[sk->sk_protocol].listeners); + RCU_INIT_POINTER(nl_table[sk->sk_protocol].listeners, NULL); + kfree_rcu(old, rcu); + nl_table[sk->sk_protocol].module = NULL; + nl_table[sk->sk_protocol].bind = NULL; + nl_table[sk->sk_protocol].unbind = NULL; + nl_table[sk->sk_protocol].flags = 0; + nl_table[sk->sk_protocol].registered = 0; + } + netlink_table_ungrab(); + } + + sock_prot_inuse_add(sock_net(sk), &netlink_proto, -1); + call_rcu(&nlk->rcu, deferred_put_nlk_sk); + return 0; +} + +static int netlink_autobind(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct netlink_table *table = &nl_table[sk->sk_protocol]; + s32 portid = task_tgid_vnr(current); + int err; + s32 rover = -4096; + bool ok; + +retry: + cond_resched(); + rcu_read_lock(); + ok = !__netlink_lookup(table, portid, net); + rcu_read_unlock(); + if (!ok) { + /* Bind collision, search negative portid values. */ + if (rover == -4096) + /* rover will be in range [S32_MIN, -4097] */ + rover = S32_MIN + prandom_u32_max(-4096 - S32_MIN); + else if (rover >= -4096) + rover = -4097; + portid = rover--; + goto retry; + } + + err = netlink_insert(sk, portid); + if (err == -EADDRINUSE) + goto retry; + + /* If 2 threads race to autobind, that is fine. */ + if (err == -EBUSY) + err = 0; + + return err; +} + +/** + * __netlink_ns_capable - General netlink message capability test + * @nsp: NETLINK_CB of the socket buffer holding a netlink command from userspace. 
+ * @user_ns: The user namespace of the capability to use + * @cap: The capability to use + * + * Test to see if the opener of the socket we received the message + * from had when the netlink socket was created and the sender of the + * message has the capability @cap in the user namespace @user_ns. + */ +bool __netlink_ns_capable(const struct netlink_skb_parms *nsp, + struct user_namespace *user_ns, int cap) +{ + return ((nsp->flags & NETLINK_SKB_DST) || + file_ns_capable(nsp->sk->sk_socket->file, user_ns, cap)) && + ns_capable(user_ns, cap); +} +EXPORT_SYMBOL(__netlink_ns_capable); + +/** + * netlink_ns_capable - General netlink message capability test + * @skb: socket buffer holding a netlink command from userspace + * @user_ns: The user namespace of the capability to use + * @cap: The capability to use + * + * Test to see if the opener of the socket we received the message + * from had when the netlink socket was created and the sender of the + * message has the capability @cap in the user namespace @user_ns. + */ +bool netlink_ns_capable(const struct sk_buff *skb, + struct user_namespace *user_ns, int cap) +{ + return __netlink_ns_capable(&NETLINK_CB(skb), user_ns, cap); +} +EXPORT_SYMBOL(netlink_ns_capable); + +/** + * netlink_capable - Netlink global message capability test + * @skb: socket buffer holding a netlink command from userspace + * @cap: The capability to use + * + * Test to see if the opener of the socket we received the message + * from had when the netlink socket was created and the sender of the + * message has the capability @cap in all user namespaces. + */ +bool netlink_capable(const struct sk_buff *skb, int cap) +{ + return netlink_ns_capable(skb, &init_user_ns, cap); +} +EXPORT_SYMBOL(netlink_capable); + +/** + * netlink_net_capable - Netlink network namespace message capability test + * @skb: socket buffer holding a netlink command from userspace + * @cap: The capability to use + * + * Test to see if the opener of the socket we received the message + * from had when the netlink socket was created and the sender of the + * message has the capability @cap over the network namespace of + * the socket we received the message from. 
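These helpers are the usual permission gate in message handlers; a typical doit callback checks the sending socket before acting on the request. A brief sketch, where my_doit and its command are hypothetical:

	static int my_doit(struct sk_buff *skb, struct genl_info *info)
	{
		/* Require CAP_NET_ADMIN in the network namespace of the
		 * socket the request arrived on. */
		if (!netlink_net_capable(skb, CAP_NET_ADMIN))
			return -EPERM;

		/* ... act on the validated request ... */
		return 0;
	}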
+ */ +bool netlink_net_capable(const struct sk_buff *skb, int cap) +{ + return netlink_ns_capable(skb, sock_net(skb->sk)->user_ns, cap); +} +EXPORT_SYMBOL(netlink_net_capable); + +static inline int netlink_allowed(const struct socket *sock, unsigned int flag) +{ + return (nl_table[sock->sk->sk_protocol].flags & flag) || + ns_capable(sock_net(sock->sk)->user_ns, CAP_NET_ADMIN); +} + +static void +netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (nlk->subscriptions && !subscriptions) + __sk_del_bind_node(sk); + else if (!nlk->subscriptions && subscriptions) + sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); + nlk->subscriptions = subscriptions; +} + +static int netlink_realloc_groups(struct sock *sk) +{ + struct netlink_sock *nlk = nlk_sk(sk); + unsigned int groups; + unsigned long *new_groups; + int err = 0; + + netlink_table_grab(); + + groups = nl_table[sk->sk_protocol].groups; + if (!nl_table[sk->sk_protocol].registered) { + err = -ENOENT; + goto out_unlock; + } + + if (nlk->ngroups >= groups) + goto out_unlock; + + new_groups = krealloc(nlk->groups, NLGRPSZ(groups), GFP_ATOMIC); + if (new_groups == NULL) { + err = -ENOMEM; + goto out_unlock; + } + memset((char *)new_groups + NLGRPSZ(nlk->ngroups), 0, + NLGRPSZ(groups) - NLGRPSZ(nlk->ngroups)); + + nlk->groups = new_groups; + nlk->ngroups = groups; + out_unlock: + netlink_table_ungrab(); + return err; +} + +static void netlink_undo_bind(int group, long unsigned int groups, + struct sock *sk) +{ + struct netlink_sock *nlk = nlk_sk(sk); + int undo; + + if (!nlk->netlink_unbind) + return; + + for (undo = 0; undo < group; undo++) + if (test_bit(undo, &groups)) + nlk->netlink_unbind(sock_net(sk), undo + 1); +} + +static int netlink_bind(struct socket *sock, struct sockaddr *addr, + int addr_len) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct netlink_sock *nlk = nlk_sk(sk); + struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr; + int err = 0; + unsigned long groups; + bool bound; + + if (addr_len < sizeof(struct sockaddr_nl)) + return -EINVAL; + + if (nladdr->nl_family != AF_NETLINK) + return -EINVAL; + groups = nladdr->nl_groups; + + /* Only superuser is allowed to listen multicasts */ + if (groups) { + if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV)) + return -EPERM; + err = netlink_realloc_groups(sk); + if (err) + return err; + } + + if (nlk->ngroups < BITS_PER_LONG) + groups &= (1UL << nlk->ngroups) - 1; + + /* Paired with WRITE_ONCE() in netlink_insert() */ + bound = READ_ONCE(nlk->bound); + if (bound) { + /* Ensure nlk->portid is up-to-date. */ + smp_rmb(); + + if (nladdr->nl_pid != nlk->portid) + return -EINVAL; + } + + if (nlk->netlink_bind && groups) { + int group; + + /* nl_groups is a u32, so cap the maximum groups we can bind */ + for (group = 0; group < BITS_PER_TYPE(u32); group++) { + if (!test_bit(group, &groups)) + continue; + err = nlk->netlink_bind(net, group + 1); + if (!err) + continue; + netlink_undo_bind(group, groups, sk); + return err; + } + } + + /* No need for barriers here as we return to user-space without + * using any of the bound attributes. + */ + netlink_lock_table(); + if (!bound) { + err = nladdr->nl_pid ? 
+ netlink_insert(sk, nladdr->nl_pid) : + netlink_autobind(sock); + if (err) { + netlink_undo_bind(BITS_PER_TYPE(u32), groups, sk); + goto unlock; + } + } + + if (!groups && (nlk->groups == NULL || !(u32)nlk->groups[0])) + goto unlock; + netlink_unlock_table(); + + netlink_table_grab(); + netlink_update_subscriptions(sk, nlk->subscriptions + + hweight32(groups) - + hweight32(nlk->groups[0])); + nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | groups; + netlink_update_listeners(sk); + netlink_table_ungrab(); + + return 0; + +unlock: + netlink_unlock_table(); + return err; +} + +static int netlink_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + int err = 0; + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + struct sockaddr_nl *nladdr = (struct sockaddr_nl *)addr; + + if (alen < sizeof(addr->sa_family)) + return -EINVAL; + + if (addr->sa_family == AF_UNSPEC) { + /* paired with READ_ONCE() in netlink_getsockbyportid() */ + WRITE_ONCE(sk->sk_state, NETLINK_UNCONNECTED); + /* dst_portid and dst_group can be read locklessly */ + WRITE_ONCE(nlk->dst_portid, 0); + WRITE_ONCE(nlk->dst_group, 0); + return 0; + } + if (addr->sa_family != AF_NETLINK) + return -EINVAL; + + if (alen < sizeof(struct sockaddr_nl)) + return -EINVAL; + + if ((nladdr->nl_groups || nladdr->nl_pid) && + !netlink_allowed(sock, NL_CFG_F_NONROOT_SEND)) + return -EPERM; + + /* No need for barriers here as we return to user-space without + * using any of the bound attributes. + * Paired with WRITE_ONCE() in netlink_insert(). + */ + if (!READ_ONCE(nlk->bound)) + err = netlink_autobind(sock); + + if (err == 0) { + /* paired with READ_ONCE() in netlink_getsockbyportid() */ + WRITE_ONCE(sk->sk_state, NETLINK_CONNECTED); + /* dst_portid and dst_group can be read locklessly */ + WRITE_ONCE(nlk->dst_portid, nladdr->nl_pid); + WRITE_ONCE(nlk->dst_group, ffs(nladdr->nl_groups)); + } + + return err; +} + +static int netlink_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_nl *, nladdr, addr); + + nladdr->nl_family = AF_NETLINK; + nladdr->nl_pad = 0; + + if (peer) { + /* Paired with WRITE_ONCE() in netlink_connect() */ + nladdr->nl_pid = READ_ONCE(nlk->dst_portid); + nladdr->nl_groups = netlink_group_mask(READ_ONCE(nlk->dst_group)); + } else { + /* Paired with WRITE_ONCE() in netlink_insert() */ + nladdr->nl_pid = READ_ONCE(nlk->portid); + netlink_lock_table(); + nladdr->nl_groups = nlk->groups ? nlk->groups[0] : 0; + netlink_unlock_table(); + } + return sizeof(*nladdr); +} + +static int netlink_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + /* try to hand this ioctl down to the NIC drivers. 
+ */ + return -ENOIOCTLCMD; +} + +static struct sock *netlink_getsockbyportid(struct sock *ssk, u32 portid) +{ + struct sock *sock; + struct netlink_sock *nlk; + + sock = netlink_lookup(sock_net(ssk), ssk->sk_protocol, portid); + if (!sock) + return ERR_PTR(-ECONNREFUSED); + + /* Don't bother queuing skb if kernel socket has no input function */ + nlk = nlk_sk(sock); + /* dst_portid and sk_state can be changed in netlink_connect() */ + if (READ_ONCE(sock->sk_state) == NETLINK_CONNECTED && + READ_ONCE(nlk->dst_portid) != nlk_sk(ssk)->portid) { + sock_put(sock); + return ERR_PTR(-ECONNREFUSED); + } + return sock; +} + +struct sock *netlink_getsockbyfilp(struct file *filp) +{ + struct inode *inode = file_inode(filp); + struct sock *sock; + + if (!S_ISSOCK(inode->i_mode)) + return ERR_PTR(-ENOTSOCK); + + sock = SOCKET_I(inode)->sk; + if (sock->sk_family != AF_NETLINK) + return ERR_PTR(-EINVAL); + + sock_hold(sock); + return sock; +} + +static struct sk_buff *netlink_alloc_large_skb(unsigned int size, + int broadcast) +{ + struct sk_buff *skb; + void *data; + + if (size <= NLMSG_GOODSIZE || broadcast) + return alloc_skb(size, GFP_KERNEL); + + size = SKB_DATA_ALIGN(size) + + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + + data = vmalloc(size); + if (data == NULL) + return NULL; + + skb = __build_skb(data, size); + if (skb == NULL) + vfree(data); + else + skb->destructor = netlink_skb_destructor; + + return skb; +} + +/* + * Attach a skb to a netlink socket. + * The caller must hold a reference to the destination socket. On error, the + * reference is dropped. The skb is not send to the destination, just all + * all error checks are performed and memory in the queue is reserved. + * Return values: + * < 0: error. skb freed, reference to sock dropped. + * 0: continue + * 1: repeat lookup - reference dropped while waiting for socket memory. 
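The three return values documented here drive the retry loop in netlink_unicast() just below; distilled, a caller of netlink_attachskb() is expected to follow this shape (a sketch of the existing pattern, not a new API):

	retry:
		sk = netlink_getsockbyportid(ssk, portid);
		if (IS_ERR(sk)) {
			kfree_skb(skb);
			return PTR_ERR(sk);
		}
		err = netlink_attachskb(sk, skb, &timeo, ssk);
		if (err == 1)
			goto retry;		/* sock reference was dropped, look it up again */
		if (err)
			return err;		/* skb freed and reference dropped on error */
		return netlink_sendskb(sk, skb);	/* queues the skb and drops the reference */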
+ */ +int netlink_attachskb(struct sock *sk, struct sk_buff *skb, + long *timeo, struct sock *ssk) +{ + struct netlink_sock *nlk; + + nlk = nlk_sk(sk); + + if ((atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf || + test_bit(NETLINK_S_CONGESTED, &nlk->state))) { + DECLARE_WAITQUEUE(wait, current); + if (!*timeo) { + if (!ssk || netlink_is_kernel(ssk)) + netlink_overrun(sk); + sock_put(sk); + kfree_skb(skb); + return -EAGAIN; + } + + __set_current_state(TASK_INTERRUPTIBLE); + add_wait_queue(&nlk->wait, &wait); + + if ((atomic_read(&sk->sk_rmem_alloc) > sk->sk_rcvbuf || + test_bit(NETLINK_S_CONGESTED, &nlk->state)) && + !sock_flag(sk, SOCK_DEAD)) + *timeo = schedule_timeout(*timeo); + + __set_current_state(TASK_RUNNING); + remove_wait_queue(&nlk->wait, &wait); + sock_put(sk); + + if (signal_pending(current)) { + kfree_skb(skb); + return sock_intr_errno(*timeo); + } + return 1; + } + netlink_skb_set_owner_r(skb, sk); + return 0; +} + +static int __netlink_sendskb(struct sock *sk, struct sk_buff *skb) +{ + int len = skb->len; + + netlink_deliver_tap(sock_net(sk), skb); + + skb_queue_tail(&sk->sk_receive_queue, skb); + sk->sk_data_ready(sk); + return len; +} + +int netlink_sendskb(struct sock *sk, struct sk_buff *skb) +{ + int len = __netlink_sendskb(sk, skb); + + sock_put(sk); + return len; +} + +void netlink_detachskb(struct sock *sk, struct sk_buff *skb) +{ + kfree_skb(skb); + sock_put(sk); +} + +static struct sk_buff *netlink_trim(struct sk_buff *skb, gfp_t allocation) +{ + int delta; + + WARN_ON(skb->sk != NULL); + delta = skb->end - skb->tail; + if (is_vmalloc_addr(skb->head) || delta * 2 < skb->truesize) + return skb; + + if (skb_shared(skb)) { + struct sk_buff *nskb = skb_clone(skb, allocation); + if (!nskb) + return skb; + consume_skb(skb); + skb = nskb; + } + + pskb_expand_head(skb, 0, -delta, + (allocation & ~__GFP_DIRECT_RECLAIM) | + __GFP_NOWARN | __GFP_NORETRY); + return skb; +} + +static int netlink_unicast_kernel(struct sock *sk, struct sk_buff *skb, + struct sock *ssk) +{ + int ret; + struct netlink_sock *nlk = nlk_sk(sk); + + ret = -ECONNREFUSED; + if (nlk->netlink_rcv != NULL) { + ret = skb->len; + netlink_skb_set_owner_r(skb, sk); + NETLINK_CB(skb).sk = ssk; + netlink_deliver_tap_kernel(sk, ssk, skb); + nlk->netlink_rcv(skb); + consume_skb(skb); + } else { + kfree_skb(skb); + } + sock_put(sk); + return ret; +} + +int netlink_unicast(struct sock *ssk, struct sk_buff *skb, + u32 portid, int nonblock) +{ + struct sock *sk; + int err; + long timeo; + + skb = netlink_trim(skb, gfp_any()); + + timeo = sock_sndtimeo(ssk, nonblock); +retry: + sk = netlink_getsockbyportid(ssk, portid); + if (IS_ERR(sk)) { + kfree_skb(skb); + return PTR_ERR(sk); + } + if (netlink_is_kernel(sk)) + return netlink_unicast_kernel(sk, skb, ssk); + + if (sk_filter(sk, skb)) { + err = skb->len; + kfree_skb(skb); + sock_put(sk); + return err; + } + + err = netlink_attachskb(sk, skb, &timeo, ssk); + if (err == 1) + goto retry; + if (err) + return err; + + return netlink_sendskb(sk, skb); +} +EXPORT_SYMBOL(netlink_unicast); + +int netlink_has_listeners(struct sock *sk, unsigned int group) +{ + int res = 0; + struct listeners *listeners; + + BUG_ON(!netlink_is_kernel(sk)); + + rcu_read_lock(); + listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners); + + if (listeners && group - 1 < nl_table[sk->sk_protocol].groups) + res = test_bit(group - 1, listeners->masks); + + rcu_read_unlock(); + + return res; +} +EXPORT_SYMBOL_GPL(netlink_has_listeners); + +bool netlink_strict_get_check(struct sk_buff *skb) +{ + 
return nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk); +} +EXPORT_SYMBOL_GPL(netlink_strict_get_check); + +static int netlink_broadcast_deliver(struct sock *sk, struct sk_buff *skb) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf && + !test_bit(NETLINK_S_CONGESTED, &nlk->state)) { + netlink_skb_set_owner_r(skb, sk); + __netlink_sendskb(sk, skb); + return atomic_read(&sk->sk_rmem_alloc) > (sk->sk_rcvbuf >> 1); + } + return -1; +} + +struct netlink_broadcast_data { + struct sock *exclude_sk; + struct net *net; + u32 portid; + u32 group; + int failure; + int delivery_failure; + int congested; + int delivered; + gfp_t allocation; + struct sk_buff *skb, *skb2; +}; + +static void do_one_broadcast(struct sock *sk, + struct netlink_broadcast_data *p) +{ + struct netlink_sock *nlk = nlk_sk(sk); + int val; + + if (p->exclude_sk == sk) + return; + + if (nlk->portid == p->portid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) + return; + + if (!net_eq(sock_net(sk), p->net)) { + if (!nlk_test_bit(LISTEN_ALL_NSID, sk)) + return; + + if (!peernet_has_id(sock_net(sk), p->net)) + return; + + if (!file_ns_capable(sk->sk_socket->file, p->net->user_ns, + CAP_NET_BROADCAST)) + return; + } + + if (p->failure) { + netlink_overrun(sk); + return; + } + + sock_hold(sk); + if (p->skb2 == NULL) { + if (skb_shared(p->skb)) { + p->skb2 = skb_clone(p->skb, p->allocation); + } else { + p->skb2 = skb_get(p->skb); + /* + * skb ownership may have been set when + * delivered to a previous socket. + */ + skb_orphan(p->skb2); + } + } + if (p->skb2 == NULL) { + netlink_overrun(sk); + /* Clone failed. Notify ALL listeners. */ + p->failure = 1; + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk)) + p->delivery_failure = 1; + goto out; + } + if (sk_filter(sk, p->skb2)) { + kfree_skb(p->skb2); + p->skb2 = NULL; + goto out; + } + NETLINK_CB(p->skb2).nsid = peernet2id(sock_net(sk), p->net); + if (NETLINK_CB(p->skb2).nsid != NETNSA_NSID_NOT_ASSIGNED) + NETLINK_CB(p->skb2).nsid_is_set = true; + val = netlink_broadcast_deliver(sk, p->skb2); + if (val < 0) { + netlink_overrun(sk); + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk)) + p->delivery_failure = 1; + } else { + p->congested |= val; + p->delivered = 1; + p->skb2 = NULL; + } +out: + sock_put(sk); +} + +int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 portid, + u32 group, gfp_t allocation) +{ + struct net *net = sock_net(ssk); + struct netlink_broadcast_data info; + struct sock *sk; + + skb = netlink_trim(skb, allocation); + + info.exclude_sk = ssk; + info.net = net; + info.portid = portid; + info.group = group; + info.failure = 0; + info.delivery_failure = 0; + info.congested = 0; + info.delivered = 0; + info.allocation = allocation; + info.skb = skb; + info.skb2 = NULL; + + /* While we sleep in clone, do not allow to change socket list */ + + netlink_lock_table(); + + sk_for_each_bound(sk, &nl_table[ssk->sk_protocol].mc_list) + do_one_broadcast(sk, &info); + + consume_skb(skb); + + netlink_unlock_table(); + + if (info.delivery_failure) { + kfree_skb(info.skb2); + return -ENOBUFS; + } + consume_skb(info.skb2); + + if (info.delivered) { + if (info.congested && gfpflags_allow_blocking(allocation)) + yield(); + return 0; + } + return -ESRCH; +} +EXPORT_SYMBOL(netlink_broadcast); + +struct netlink_set_err_data { + struct sock *exclude_sk; + u32 portid; + u32 group; + int code; +}; + +static int do_one_set_err(struct sock *sk, struct netlink_set_err_data *p) +{ + struct netlink_sock *nlk = nlk_sk(sk); + 
int ret = 0; + + if (sk == p->exclude_sk) + goto out; + + if (!net_eq(sock_net(sk), sock_net(p->exclude_sk))) + goto out; + + if (nlk->portid == p->portid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) + goto out; + + if (p->code == ENOBUFS && nlk_test_bit(RECV_NO_ENOBUFS, sk)) { + ret = 1; + goto out; + } + + WRITE_ONCE(sk->sk_err, p->code); + sk_error_report(sk); +out: + return ret; +} + +/** + * netlink_set_err - report error to broadcast listeners + * @ssk: the kernel netlink socket, as returned by netlink_kernel_create() + * @portid: the PORTID of a process that we want to skip (if any) + * @group: the broadcast group that will notice the error + * @code: error code, must be negative (as usual in kernelspace) + * + * This function returns the number of broadcast listeners that have set the + * NETLINK_NO_ENOBUFS socket option. + */ +int netlink_set_err(struct sock *ssk, u32 portid, u32 group, int code) +{ + struct netlink_set_err_data info; + unsigned long flags; + struct sock *sk; + int ret = 0; + + info.exclude_sk = ssk; + info.portid = portid; + info.group = group; + /* sk->sk_err wants a positive error value */ + info.code = -code; + + read_lock_irqsave(&nl_table_lock, flags); + + sk_for_each_bound(sk, &nl_table[ssk->sk_protocol].mc_list) + ret += do_one_set_err(sk, &info); + + read_unlock_irqrestore(&nl_table_lock, flags); + return ret; +} +EXPORT_SYMBOL(netlink_set_err); + +/* must be called with netlink table grabbed */ +static void netlink_update_socket_mc(struct netlink_sock *nlk, + unsigned int group, + int is_new) +{ + int old, new = !!is_new, subscriptions; + + old = test_bit(group - 1, nlk->groups); + subscriptions = nlk->subscriptions - old + new; + if (new) + __set_bit(group - 1, nlk->groups); + else + __clear_bit(group - 1, nlk->groups); + netlink_update_subscriptions(&nlk->sk, subscriptions); + netlink_update_listeners(&nlk->sk); +} + +static int netlink_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + unsigned int val = 0; + int nr = -1; + + if (level != SOL_NETLINK) + return -ENOPROTOOPT; + + if (optlen >= sizeof(int) && + copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + switch (optname) { + case NETLINK_PKTINFO: + nr = NETLINK_F_RECV_PKTINFO; + break; + case NETLINK_ADD_MEMBERSHIP: + case NETLINK_DROP_MEMBERSHIP: { + int err; + + if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV)) + return -EPERM; + err = netlink_realloc_groups(sk); + if (err) + return err; + if (!val || val - 1 >= nlk->ngroups) + return -EINVAL; + if (optname == NETLINK_ADD_MEMBERSHIP && nlk->netlink_bind) { + err = nlk->netlink_bind(sock_net(sk), val); + if (err) + return err; + } + netlink_table_grab(); + netlink_update_socket_mc(nlk, val, + optname == NETLINK_ADD_MEMBERSHIP); + netlink_table_ungrab(); + if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind) + nlk->netlink_unbind(sock_net(sk), val); + + break; + } + case NETLINK_BROADCAST_ERROR: + nr = NETLINK_F_BROADCAST_SEND_ERROR; + break; + case NETLINK_NO_ENOBUFS: + assign_bit(NETLINK_F_RECV_NO_ENOBUFS, &nlk->flags, val); + if (val) { + clear_bit(NETLINK_S_CONGESTED, &nlk->state); + wake_up_interruptible(&nlk->wait); + } + break; + case NETLINK_LISTEN_ALL_NSID: + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_BROADCAST)) + return -EPERM; + nr = NETLINK_F_LISTEN_ALL_NSID; + break; + case NETLINK_CAP_ACK: + nr = NETLINK_F_CAP_ACK; + break; + case NETLINK_EXT_ACK: + nr 
= NETLINK_F_EXT_ACK; + break; + case NETLINK_GET_STRICT_CHK: + nr = NETLINK_F_STRICT_CHK; + break; + default: + return -ENOPROTOOPT; + } + if (nr >= 0) + assign_bit(nr, &nlk->flags, val); + return 0; +} + +static int netlink_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + unsigned int flag; + int len, val; + + if (level != SOL_NETLINK) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + switch (optname) { + case NETLINK_PKTINFO: + flag = NETLINK_F_RECV_PKTINFO; + break; + case NETLINK_BROADCAST_ERROR: + flag = NETLINK_F_BROADCAST_SEND_ERROR; + break; + case NETLINK_NO_ENOBUFS: + flag = NETLINK_F_RECV_NO_ENOBUFS; + break; + case NETLINK_LIST_MEMBERSHIPS: { + int pos, idx, shift, err = 0; + + netlink_lock_table(); + for (pos = 0; pos * 8 < nlk->ngroups; pos += sizeof(u32)) { + if (len - pos < sizeof(u32)) + break; + + idx = pos / sizeof(unsigned long); + shift = (pos % sizeof(unsigned long)) * 8; + if (put_user((u32)(nlk->groups[idx] >> shift), + (u32 __user *)(optval + pos))) { + err = -EFAULT; + break; + } + } + if (put_user(ALIGN(BITS_TO_BYTES(nlk->ngroups), sizeof(u32)), optlen)) + err = -EFAULT; + netlink_unlock_table(); + return err; + } + case NETLINK_CAP_ACK: + flag = NETLINK_F_CAP_ACK; + break; + case NETLINK_EXT_ACK: + flag = NETLINK_F_EXT_ACK; + break; + case NETLINK_GET_STRICT_CHK: + flag = NETLINK_F_STRICT_CHK; + break; + default: + return -ENOPROTOOPT; + } + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + val = test_bit(flag, &nlk->flags); + + if (put_user(len, optlen) || + copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb) +{ + struct nl_pktinfo info; + + info.group = NETLINK_CB(skb).dst_group; + put_cmsg(msg, SOL_NETLINK, NETLINK_PKTINFO, sizeof(info), &info); +} + +static void netlink_cmsg_listen_all_nsid(struct sock *sk, struct msghdr *msg, + struct sk_buff *skb) +{ + if (!NETLINK_CB(skb).nsid_is_set) + return; + + put_cmsg(msg, SOL_NETLINK, NETLINK_LISTEN_ALL_NSID, sizeof(int), + &NETLINK_CB(skb).nsid); +} + +static int netlink_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name); + u32 dst_portid; + u32 dst_group; + struct sk_buff *skb; + int err; + struct scm_cookie scm; + u32 netlink_skb_flags = 0; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + if (len == 0) { + pr_warn_once("Zero length message leads to an empty skb\n"); + return -ENODATA; + } + + err = scm_send(sock, msg, &scm, true); + if (err < 0) + return err; + + if (msg->msg_namelen) { + err = -EINVAL; + if (msg->msg_namelen < sizeof(struct sockaddr_nl)) + goto out; + if (addr->nl_family != AF_NETLINK) + goto out; + dst_portid = addr->nl_pid; + dst_group = ffs(addr->nl_groups); + err = -EPERM; + if ((dst_group || dst_portid) && + !netlink_allowed(sock, NL_CFG_F_NONROOT_SEND)) + goto out; + netlink_skb_flags |= NETLINK_SKB_DST; + } else { + /* Paired with WRITE_ONCE() in netlink_connect() */ + dst_portid = READ_ONCE(nlk->dst_portid); + dst_group = READ_ONCE(nlk->dst_group); + } + + /* Paired with WRITE_ONCE() in netlink_insert() */ + if (!READ_ONCE(nlk->bound)) { + err = netlink_autobind(sock); + if (err) + goto out; + } else { + /* Ensure nlk is hashed and 
visible. */ + smp_rmb(); + } + + err = -EMSGSIZE; + if (len > sk->sk_sndbuf - 32) + goto out; + err = -ENOBUFS; + skb = netlink_alloc_large_skb(len, dst_group); + if (skb == NULL) + goto out; + + NETLINK_CB(skb).portid = nlk->portid; + NETLINK_CB(skb).dst_group = dst_group; + NETLINK_CB(skb).creds = scm.creds; + NETLINK_CB(skb).flags = netlink_skb_flags; + + err = -EFAULT; + if (memcpy_from_msg(skb_put(skb, len), msg, len)) { + kfree_skb(skb); + goto out; + } + + err = security_netlink_send(sk, skb); + if (err) { + kfree_skb(skb); + goto out; + } + + if (dst_group) { + refcount_inc(&skb->users); + netlink_broadcast(sk, skb, dst_portid, dst_group, GFP_KERNEL); + } + err = netlink_unicast(sk, skb, dst_portid, msg->msg_flags & MSG_DONTWAIT); + +out: + scm_destroy(&scm); + return err; +} + +static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct scm_cookie scm; + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + size_t copied, max_recvmsg_len; + struct sk_buff *skb, *data_skb; + int err, ret; + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + copied = 0; + + skb = skb_recv_datagram(sk, flags, &err); + if (skb == NULL) + goto out; + + data_skb = skb; + +#ifdef CONFIG_COMPAT_NETLINK_MESSAGES + if (unlikely(skb_shinfo(skb)->frag_list)) { + /* + * If this skb has a frag_list, then here that means that we + * will have to use the frag_list skb's data for compat tasks + * and the regular skb's data for normal (non-compat) tasks. + * + * If we need to send the compat skb, assign it to the + * 'data_skb' variable so that it will be used below for data + * copying. We keep 'skb' for everything else, including + * freeing both later. + */ + if (flags & MSG_CMSG_COMPAT) + data_skb = skb_shinfo(skb)->frag_list; + } +#endif + + /* Record the max length of recvmsg() calls for future allocations */ + max_recvmsg_len = max(READ_ONCE(nlk->max_recvmsg_len), len); + max_recvmsg_len = min_t(size_t, max_recvmsg_len, + SKB_WITH_OVERHEAD(32768)); + WRITE_ONCE(nlk->max_recvmsg_len, max_recvmsg_len); + + copied = data_skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + err = skb_copy_datagram_msg(data_skb, 0, msg, copied); + + if (msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_nl *, addr, msg->msg_name); + addr->nl_family = AF_NETLINK; + addr->nl_pad = 0; + addr->nl_pid = NETLINK_CB(skb).portid; + addr->nl_groups = netlink_group_mask(NETLINK_CB(skb).dst_group); + msg->msg_namelen = sizeof(*addr); + } + + if (nlk_test_bit(RECV_PKTINFO, sk)) + netlink_cmsg_recv_pktinfo(msg, skb); + if (nlk_test_bit(LISTEN_ALL_NSID, sk)) + netlink_cmsg_listen_all_nsid(sk, msg, skb); + + memset(&scm, 0, sizeof(scm)); + scm.creds = *NETLINK_CREDS(skb); + if (flags & MSG_TRUNC) + copied = data_skb->len; + + skb_free_datagram(sk, skb); + + if (READ_ONCE(nlk->cb_running) && + atomic_read(&sk->sk_rmem_alloc) <= sk->sk_rcvbuf / 2) { + ret = netlink_dump(sk); + if (ret) { + WRITE_ONCE(sk->sk_err, -ret); + sk_error_report(sk); + } + } + + scm_recv(sock, msg, &scm, flags); +out: + netlink_rcv_wake(sk); + return err ? : copied; +} + +static void netlink_data_ready(struct sock *sk) +{ + BUG(); +} + +/* + * We export these functions to other modules. They provide a + * complete set of kernel non-blocking support for message + * queueing. 
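+ *
+ * A rough sketch of the usual calling pattern, with placeholder names
+ * (MY_NETLINK_PROTO, my_input() and my_doit() are not defined here):
+ *
+ *	static void my_input(struct sk_buff *skb)
+ *	{
+ *		netlink_rcv_skb(skb, my_doit);
+ *	}
+ *
+ *	struct netlink_kernel_cfg cfg = {
+ *		.input = my_input,
+ *	};
+ *	struct sock *nlsk = netlink_kernel_create(&init_net, MY_NETLINK_PROTO, &cfg);
+ *
+ * The socket is torn down again with netlink_kernel_release(nlsk).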
+ */ + +struct sock * +__netlink_kernel_create(struct net *net, int unit, struct module *module, + struct netlink_kernel_cfg *cfg) +{ + struct socket *sock; + struct sock *sk; + struct netlink_sock *nlk; + struct listeners *listeners = NULL; + struct mutex *cb_mutex = cfg ? cfg->cb_mutex : NULL; + unsigned int groups; + + BUG_ON(!nl_table); + + if (unit < 0 || unit >= MAX_LINKS) + return NULL; + + if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) + return NULL; + + if (__netlink_create(net, sock, cb_mutex, unit, 1) < 0) + goto out_sock_release_nosk; + + sk = sock->sk; + + if (!cfg || cfg->groups < 32) + groups = 32; + else + groups = cfg->groups; + + listeners = kzalloc(sizeof(*listeners) + NLGRPSZ(groups), GFP_KERNEL); + if (!listeners) + goto out_sock_release; + + sk->sk_data_ready = netlink_data_ready; + if (cfg && cfg->input) + nlk_sk(sk)->netlink_rcv = cfg->input; + + if (netlink_insert(sk, 0)) + goto out_sock_release; + + nlk = nlk_sk(sk); + set_bit(NETLINK_F_KERNEL_SOCKET, &nlk->flags); + + netlink_table_grab(); + if (!nl_table[unit].registered) { + nl_table[unit].groups = groups; + rcu_assign_pointer(nl_table[unit].listeners, listeners); + nl_table[unit].cb_mutex = cb_mutex; + nl_table[unit].module = module; + if (cfg) { + nl_table[unit].bind = cfg->bind; + nl_table[unit].unbind = cfg->unbind; + nl_table[unit].flags = cfg->flags; + if (cfg->compare) + nl_table[unit].compare = cfg->compare; + } + nl_table[unit].registered = 1; + } else { + kfree(listeners); + nl_table[unit].registered++; + } + netlink_table_ungrab(); + return sk; + +out_sock_release: + kfree(listeners); + netlink_kernel_release(sk); + return NULL; + +out_sock_release_nosk: + sock_release(sock); + return NULL; +} +EXPORT_SYMBOL(__netlink_kernel_create); + +void +netlink_kernel_release(struct sock *sk) +{ + if (sk == NULL || sk->sk_socket == NULL) + return; + + sock_release(sk->sk_socket); +} +EXPORT_SYMBOL(netlink_kernel_release); + +int __netlink_change_ngroups(struct sock *sk, unsigned int groups) +{ + struct listeners *new, *old; + struct netlink_table *tbl = &nl_table[sk->sk_protocol]; + + if (groups < 32) + groups = 32; + + if (NLGRPSZ(tbl->groups) < NLGRPSZ(groups)) { + new = kzalloc(sizeof(*new) + NLGRPSZ(groups), GFP_ATOMIC); + if (!new) + return -ENOMEM; + old = nl_deref_protected(tbl->listeners); + memcpy(new->masks, old->masks, NLGRPSZ(tbl->groups)); + rcu_assign_pointer(tbl->listeners, new); + + kfree_rcu(old, rcu); + } + tbl->groups = groups; + + return 0; +} + +/** + * netlink_change_ngroups - change number of multicast groups + * + * This changes the number of multicast groups that are available + * on a certain netlink family. Note that it is not possible to + * change the number of groups to below 32. Also note that it does + * not implicitly call netlink_clear_multicast_users() when the + * number of groups is reduced. + * + * @sk: The kernel netlink socket, as returned by netlink_kernel_create(). + * @groups: The new number of groups. 
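+ *
+ * For example, a caller that has just reserved additional multicast groups
+ * would typically do (names illustrative):
+ *
+ *	err = netlink_change_ngroups(kernel_sk, total_groups);
+ *
+ * which is what generic netlink does when its group bitmap grows.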
+ */ +int netlink_change_ngroups(struct sock *sk, unsigned int groups) +{ + int err; + + netlink_table_grab(); + err = __netlink_change_ngroups(sk, groups); + netlink_table_ungrab(); + + return err; +} + +void __netlink_clear_multicast_users(struct sock *ksk, unsigned int group) +{ + struct sock *sk; + struct netlink_table *tbl = &nl_table[ksk->sk_protocol]; + + sk_for_each_bound(sk, &tbl->mc_list) + netlink_update_socket_mc(nlk_sk(sk), group, 0); +} + +struct nlmsghdr * +__nlmsg_put(struct sk_buff *skb, u32 portid, u32 seq, int type, int len, int flags) +{ + struct nlmsghdr *nlh; + int size = nlmsg_msg_size(len); + + nlh = skb_put(skb, NLMSG_ALIGN(size)); + nlh->nlmsg_type = type; + nlh->nlmsg_len = size; + nlh->nlmsg_flags = flags; + nlh->nlmsg_pid = portid; + nlh->nlmsg_seq = seq; + if (!__builtin_constant_p(size) || NLMSG_ALIGN(size) - size != 0) + memset(nlmsg_data(nlh) + len, 0, NLMSG_ALIGN(size) - size); + return nlh; +} +EXPORT_SYMBOL(__nlmsg_put); + +/* + * It looks a bit ugly. + * It would be better to create kernel thread. + */ + +static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb, + struct netlink_callback *cb, + struct netlink_ext_ack *extack) +{ + struct nlmsghdr *nlh; + + nlh = nlmsg_put_answer(skb, cb, NLMSG_DONE, sizeof(nlk->dump_done_errno), + NLM_F_MULTI | cb->answer_flags); + if (WARN_ON(!nlh)) + return -ENOBUFS; + + nl_dump_check_consistent(cb, nlh); + memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno)); + + if (extack->_msg && test_bit(NETLINK_F_EXT_ACK, &nlk->flags)) { + nlh->nlmsg_flags |= NLM_F_ACK_TLVS; + if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg)) + nlmsg_end(skb, nlh); + } + + return 0; +} + +static int netlink_dump(struct sock *sk) +{ + struct netlink_sock *nlk = nlk_sk(sk); + struct netlink_ext_ack extack = {}; + struct netlink_callback *cb; + struct sk_buff *skb = NULL; + size_t max_recvmsg_len; + struct module *module; + int err = -ENOBUFS; + int alloc_min_size; + int alloc_size; + + mutex_lock(nlk->cb_mutex); + if (!nlk->cb_running) { + err = -EINVAL; + goto errout_skb; + } + + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) + goto errout_skb; + + /* NLMSG_GOODSIZE is small to avoid high order allocations being + * required, but it makes sense to _attempt_ a 16K bytes allocation + * to reduce number of system calls on dump operations, if user + * ever provided a big enough buffer. + */ + cb = &nlk->cb; + alloc_min_size = max_t(int, cb->min_dump_alloc, NLMSG_GOODSIZE); + + max_recvmsg_len = READ_ONCE(nlk->max_recvmsg_len); + if (alloc_min_size < max_recvmsg_len) { + alloc_size = max_recvmsg_len; + skb = alloc_skb(alloc_size, + (GFP_KERNEL & ~__GFP_DIRECT_RECLAIM) | + __GFP_NOWARN | __GFP_NORETRY); + } + if (!skb) { + alloc_size = alloc_min_size; + skb = alloc_skb(alloc_size, GFP_KERNEL); + } + if (!skb) + goto errout_skb; + + /* Trim skb to allocated size. User is expected to provide buffer as + * large as max(min_dump_alloc, 16KiB (mac_recvmsg_len capped at + * netlink_recvmsg())). dump will pack as many smaller messages as + * could fit within the allocated skb. skb is typically allocated + * with larger space than required (could be as much as near 2x the + * requested size with align to next power of 2 approach). Allowing + * dump to use the excess space makes it difficult for a user to have a + * reasonable static buffer based on the expected largest dump of a + * single netdev. The outcome is MSG_TRUNC error. 
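+ *
+ * For instance, a 16KiB alloc_skb() request may come back with close to
+ * 32KiB of tailroom; without the skb_reserve() below the dump could pack
+ * nearly twice the data the user sized their buffer for, and recvmsg()
+ * would then flag MSG_TRUNC.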
+ */ + skb_reserve(skb, skb_tailroom(skb) - alloc_size); + + /* Make sure malicious BPF programs can not read unitialized memory + * from skb->head -> skb->data + */ + skb_reset_network_header(skb); + skb_reset_mac_header(skb); + + netlink_skb_set_owner_r(skb, sk); + + if (nlk->dump_done_errno > 0) { + cb->extack = &extack; + nlk->dump_done_errno = cb->dump(skb, cb); + cb->extack = NULL; + } + + if (nlk->dump_done_errno > 0 || + skb_tailroom(skb) < nlmsg_total_size(sizeof(nlk->dump_done_errno))) { + mutex_unlock(nlk->cb_mutex); + + if (sk_filter(sk, skb)) + kfree_skb(skb); + else + __netlink_sendskb(sk, skb); + return 0; + } + + if (netlink_dump_done(nlk, skb, cb, &extack)) + goto errout_skb; + +#ifdef CONFIG_COMPAT_NETLINK_MESSAGES + /* frag_list skb's data is used for compat tasks + * and the regular skb's data for normal (non-compat) tasks. + * See netlink_recvmsg(). + */ + if (unlikely(skb_shinfo(skb)->frag_list)) { + if (netlink_dump_done(nlk, skb_shinfo(skb)->frag_list, cb, &extack)) + goto errout_skb; + } +#endif + + if (sk_filter(sk, skb)) + kfree_skb(skb); + else + __netlink_sendskb(sk, skb); + + if (cb->done) + cb->done(cb); + + WRITE_ONCE(nlk->cb_running, false); + module = cb->module; + skb = cb->skb; + mutex_unlock(nlk->cb_mutex); + module_put(module); + consume_skb(skb); + return 0; + +errout_skb: + mutex_unlock(nlk->cb_mutex); + kfree_skb(skb); + return err; +} + +int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb, + const struct nlmsghdr *nlh, + struct netlink_dump_control *control) +{ + struct netlink_callback *cb; + struct netlink_sock *nlk; + struct sock *sk; + int ret; + + refcount_inc(&skb->users); + + sk = netlink_lookup(sock_net(ssk), ssk->sk_protocol, NETLINK_CB(skb).portid); + if (sk == NULL) { + ret = -ECONNREFUSED; + goto error_free; + } + + nlk = nlk_sk(sk); + mutex_lock(nlk->cb_mutex); + /* A dump is in progress... */ + if (nlk->cb_running) { + ret = -EBUSY; + goto error_unlock; + } + /* add reference of module which cb->dump belongs to */ + if (!try_module_get(control->module)) { + ret = -EPROTONOSUPPORT; + goto error_unlock; + } + + cb = &nlk->cb; + memset(cb, 0, sizeof(*cb)); + cb->dump = control->dump; + cb->done = control->done; + cb->nlh = nlh; + cb->data = control->data; + cb->module = control->module; + cb->min_dump_alloc = control->min_dump_alloc; + cb->skb = skb; + + cb->strict_check = nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk); + + if (control->start) { + ret = control->start(cb); + if (ret) + goto error_put; + } + + WRITE_ONCE(nlk->cb_running, true); + nlk->dump_done_errno = INT_MAX; + + mutex_unlock(nlk->cb_mutex); + + ret = netlink_dump(sk); + + sock_put(sk); + + if (ret) + return ret; + + /* We successfully started a dump, by returning -EINTR we + * signal not to send ACK even if it was requested. 
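+ * (netlink_rcv_skb() treats -EINTR specially and skips netlink_ack()
+ * for such requests.)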
+ */ + return -EINTR; + +error_put: + module_put(control->module); +error_unlock: + sock_put(sk); + mutex_unlock(nlk->cb_mutex); +error_free: + kfree_skb(skb); + return ret; +} +EXPORT_SYMBOL(__netlink_dump_start); + +static size_t +netlink_ack_tlv_len(struct netlink_sock *nlk, int err, + const struct netlink_ext_ack *extack) +{ + size_t tlvlen; + + if (!extack || !test_bit(NETLINK_F_EXT_ACK, &nlk->flags)) + return 0; + + tlvlen = 0; + if (extack->_msg) + tlvlen += nla_total_size(strlen(extack->_msg) + 1); + if (extack->cookie_len) + tlvlen += nla_total_size(extack->cookie_len); + + /* Following attributes are only reported as error (not warning) */ + if (!err) + return tlvlen; + + if (extack->bad_attr) + tlvlen += nla_total_size(sizeof(u32)); + if (extack->policy) + tlvlen += netlink_policy_dump_attr_size_estimate(extack->policy); + if (extack->miss_type) + tlvlen += nla_total_size(sizeof(u32)); + if (extack->miss_nest) + tlvlen += nla_total_size(sizeof(u32)); + + return tlvlen; +} + +static void +netlink_ack_tlv_fill(struct sk_buff *in_skb, struct sk_buff *skb, + struct nlmsghdr *nlh, int err, + const struct netlink_ext_ack *extack) +{ + if (extack->_msg) + WARN_ON(nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg)); + if (extack->cookie_len) + WARN_ON(nla_put(skb, NLMSGERR_ATTR_COOKIE, + extack->cookie_len, extack->cookie)); + + if (!err) + return; + + if (extack->bad_attr && + !WARN_ON((u8 *)extack->bad_attr < in_skb->data || + (u8 *)extack->bad_attr >= in_skb->data + in_skb->len)) + WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_OFFS, + (u8 *)extack->bad_attr - (u8 *)nlh)); + if (extack->policy) + netlink_policy_dump_write_attr(skb, extack->policy, + NLMSGERR_ATTR_POLICY); + if (extack->miss_type) + WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_MISS_TYPE, + extack->miss_type)); + if (extack->miss_nest && + !WARN_ON((u8 *)extack->miss_nest < in_skb->data || + (u8 *)extack->miss_nest > in_skb->data + in_skb->len)) + WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_MISS_NEST, + (u8 *)extack->miss_nest - (u8 *)nlh)); +} + +void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err, + const struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + struct nlmsghdr *rep; + struct nlmsgerr *errmsg; + size_t payload = sizeof(*errmsg); + struct netlink_sock *nlk = nlk_sk(NETLINK_CB(in_skb).sk); + unsigned int flags = 0; + size_t tlvlen; + + /* Error messages get the original request appened, unless the user + * requests to cap the error message, and get extra error data if + * requested. 
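+ *
+ * The reply built below is laid out as:
+ *
+ *	struct nlmsghdr		(type NLMSG_ERROR, possibly with NLM_F_CAPPED
+ *				 and NLM_F_ACK_TLVS set in nlmsg_flags)
+ *	struct nlmsgerr		(the error code plus the header of the
+ *				 offending request)
+ *	original request payload, unless NLM_F_CAPPED is set
+ *	NLMSGERR_ATTR_* TLVs, if extended ACKs are enabled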
+ */ + if (err && !test_bit(NETLINK_F_CAP_ACK, &nlk->flags)) + payload += nlmsg_len(nlh); + else + flags |= NLM_F_CAPPED; + + tlvlen = netlink_ack_tlv_len(nlk, err, extack); + if (tlvlen) + flags |= NLM_F_ACK_TLVS; + + skb = nlmsg_new(payload + tlvlen, GFP_KERNEL); + if (!skb) + goto err_skb; + + rep = nlmsg_put(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq, + NLMSG_ERROR, sizeof(*errmsg), flags); + if (!rep) + goto err_bad_put; + errmsg = nlmsg_data(rep); + errmsg->error = err; + errmsg->msg = *nlh; + + if (!(flags & NLM_F_CAPPED)) { + if (!nlmsg_append(skb, nlmsg_len(nlh))) + goto err_bad_put; + + memcpy(nlmsg_data(&errmsg->msg), nlmsg_data(nlh), + nlmsg_len(nlh)); + } + + if (tlvlen) + netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack); + + nlmsg_end(skb, rep); + + nlmsg_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).portid); + + return; + +err_bad_put: + nlmsg_free(skb); +err_skb: + WRITE_ONCE(NETLINK_CB(in_skb).sk->sk_err, ENOBUFS); + sk_error_report(NETLINK_CB(in_skb).sk); +} +EXPORT_SYMBOL(netlink_ack); + +int netlink_rcv_skb(struct sk_buff *skb, int (*cb)(struct sk_buff *, + struct nlmsghdr *, + struct netlink_ext_ack *)) +{ + struct netlink_ext_ack extack; + struct nlmsghdr *nlh; + int err; + + while (skb->len >= nlmsg_total_size(0)) { + int msglen; + + memset(&extack, 0, sizeof(extack)); + nlh = nlmsg_hdr(skb); + err = 0; + + if (nlh->nlmsg_len < NLMSG_HDRLEN || skb->len < nlh->nlmsg_len) + return 0; + + /* Only requests are handled by the kernel */ + if (!(nlh->nlmsg_flags & NLM_F_REQUEST)) + goto ack; + + /* Skip control messages */ + if (nlh->nlmsg_type < NLMSG_MIN_TYPE) + goto ack; + + err = cb(skb, nlh, &extack); + if (err == -EINTR) + goto skip; + +ack: + if (nlh->nlmsg_flags & NLM_F_ACK || err) + netlink_ack(skb, nlh, err, &extack); + +skip: + msglen = NLMSG_ALIGN(nlh->nlmsg_len); + if (msglen > skb->len) + msglen = skb->len; + skb_pull(skb, msglen); + } + + return 0; +} +EXPORT_SYMBOL(netlink_rcv_skb); + +/** + * nlmsg_notify - send a notification netlink message + * @sk: netlink socket to use + * @skb: notification message + * @portid: destination netlink portid for reports or 0 + * @group: destination multicast group or 0 + * @report: 1 to report back, 0 to disable + * @flags: allocation flags + */ +int nlmsg_notify(struct sock *sk, struct sk_buff *skb, u32 portid, + unsigned int group, int report, gfp_t flags) +{ + int err = 0; + + if (group) { + int exclude_portid = 0; + + if (report) { + refcount_inc(&skb->users); + exclude_portid = portid; + } + + /* errors reported via destination sk->sk_err, but propagate + * delivery errors if NETLINK_BROADCAST_ERROR flag is set */ + err = nlmsg_multicast(sk, skb, exclude_portid, group, flags); + if (err == -ESRCH) + err = 0; + } + + if (report) { + int err2; + + err2 = nlmsg_unicast(sk, skb, portid); + if (!err) + err = err2; + } + + return err; +} +EXPORT_SYMBOL(nlmsg_notify); + +#ifdef CONFIG_PROC_FS +struct nl_seq_iter { + struct seq_net_private p; + struct rhashtable_iter hti; + int link; +}; + +static void netlink_walk_start(struct nl_seq_iter *iter) +{ + rhashtable_walk_enter(&nl_table[iter->link].hash, &iter->hti); + rhashtable_walk_start(&iter->hti); +} + +static void netlink_walk_stop(struct nl_seq_iter *iter) +{ + rhashtable_walk_stop(&iter->hti); + rhashtable_walk_exit(&iter->hti); +} + +static void *__netlink_seq_next(struct seq_file *seq) +{ + struct nl_seq_iter *iter = seq->private; + struct netlink_sock *nlk; + + do { + for (;;) { + nlk = rhashtable_walk_next(&iter->hti); + + if (IS_ERR(nlk)) { + if (PTR_ERR(nlk) 
== -EAGAIN) + continue; + + return nlk; + } + + if (nlk) + break; + + netlink_walk_stop(iter); + if (++iter->link >= MAX_LINKS) + return NULL; + + netlink_walk_start(iter); + } + } while (sock_net(&nlk->sk) != seq_file_net(seq)); + + return nlk; +} + +static void *netlink_seq_start(struct seq_file *seq, loff_t *posp) + __acquires(RCU) +{ + struct nl_seq_iter *iter = seq->private; + void *obj = SEQ_START_TOKEN; + loff_t pos; + + iter->link = 0; + + netlink_walk_start(iter); + + for (pos = *posp; pos && obj && !IS_ERR(obj); pos--) + obj = __netlink_seq_next(seq); + + return obj; +} + +static void *netlink_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return __netlink_seq_next(seq); +} + +static void netlink_native_seq_stop(struct seq_file *seq, void *v) +{ + struct nl_seq_iter *iter = seq->private; + + if (iter->link >= MAX_LINKS) + return; + + netlink_walk_stop(iter); +} + + +static int netlink_native_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "sk Eth Pid Groups " + "Rmem Wmem Dump Locks Drops Inode\n"); + } else { + struct sock *s = v; + struct netlink_sock *nlk = nlk_sk(s); + + seq_printf(seq, "%pK %-3d %-10u %08x %-8d %-8d %-5d %-8d %-8u %-8lu\n", + s, + s->sk_protocol, + nlk->portid, + nlk->groups ? (u32)nlk->groups[0] : 0, + sk_rmem_alloc_get(s), + sk_wmem_alloc_get(s), + READ_ONCE(nlk->cb_running), + refcount_read(&s->sk_refcnt), + atomic_read(&s->sk_drops), + sock_i_ino(s) + ); + + } + return 0; +} + +#ifdef CONFIG_BPF_SYSCALL +struct bpf_iter__netlink { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct netlink_sock *, sk); +}; + +DEFINE_BPF_ITER_FUNC(netlink, struct bpf_iter_meta *meta, struct netlink_sock *sk) + +static int netlink_prog_seq_show(struct bpf_prog *prog, + struct bpf_iter_meta *meta, + void *v) +{ + struct bpf_iter__netlink ctx; + + meta->seq_num--; /* skip SEQ_START_TOKEN */ + ctx.meta = meta; + ctx.sk = nlk_sk((struct sock *)v); + return bpf_iter_run_prog(prog, &ctx); +} + +static int netlink_seq_show(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + meta.seq = seq; + prog = bpf_iter_get_info(&meta, false); + if (!prog) + return netlink_native_seq_show(seq, v); + + if (v != SEQ_START_TOKEN) + return netlink_prog_seq_show(prog, &meta, v); + + return 0; +} + +static void netlink_seq_stop(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + if (!v) { + meta.seq = seq; + prog = bpf_iter_get_info(&meta, true); + if (prog) + (void)netlink_prog_seq_show(prog, &meta, v); + } + + netlink_native_seq_stop(seq, v); +} +#else +static int netlink_seq_show(struct seq_file *seq, void *v) +{ + return netlink_native_seq_show(seq, v); +} + +static void netlink_seq_stop(struct seq_file *seq, void *v) +{ + netlink_native_seq_stop(seq, v); +} +#endif + +static const struct seq_operations netlink_seq_ops = { + .start = netlink_seq_start, + .next = netlink_seq_next, + .stop = netlink_seq_stop, + .show = netlink_seq_show, +}; +#endif + +int netlink_register_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_register(&netlink_chain, nb); +} +EXPORT_SYMBOL(netlink_register_notifier); + +int netlink_unregister_notifier(struct notifier_block *nb) +{ + return blocking_notifier_chain_unregister(&netlink_chain, nb); +} +EXPORT_SYMBOL(netlink_unregister_notifier); + +static const struct proto_ops netlink_ops = { + .family = PF_NETLINK, + .owner = THIS_MODULE, + .release = netlink_release, + .bind = netlink_bind, 
+ .connect = netlink_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = netlink_getname, + .poll = datagram_poll, + .ioctl = netlink_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = netlink_setsockopt, + .getsockopt = netlink_getsockopt, + .sendmsg = netlink_sendmsg, + .recvmsg = netlink_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static const struct net_proto_family netlink_family_ops = { + .family = PF_NETLINK, + .create = netlink_create, + .owner = THIS_MODULE, /* for consistency 8) */ +}; + +static int __net_init netlink_net_init(struct net *net) +{ +#ifdef CONFIG_PROC_FS + if (!proc_create_net("netlink", 0, net->proc_net, &netlink_seq_ops, + sizeof(struct nl_seq_iter))) + return -ENOMEM; +#endif + return 0; +} + +static void __net_exit netlink_net_exit(struct net *net) +{ +#ifdef CONFIG_PROC_FS + remove_proc_entry("netlink", net->proc_net); +#endif +} + +static void __init netlink_add_usersock_entry(void) +{ + struct listeners *listeners; + int groups = 32; + + listeners = kzalloc(sizeof(*listeners) + NLGRPSZ(groups), GFP_KERNEL); + if (!listeners) + panic("netlink_add_usersock_entry: Cannot allocate listeners\n"); + + netlink_table_grab(); + + nl_table[NETLINK_USERSOCK].groups = groups; + rcu_assign_pointer(nl_table[NETLINK_USERSOCK].listeners, listeners); + nl_table[NETLINK_USERSOCK].module = THIS_MODULE; + nl_table[NETLINK_USERSOCK].registered = 1; + nl_table[NETLINK_USERSOCK].flags = NL_CFG_F_NONROOT_SEND; + + netlink_table_ungrab(); +} + +static struct pernet_operations __net_initdata netlink_net_ops = { + .init = netlink_net_init, + .exit = netlink_net_exit, +}; + +static inline u32 netlink_hash(const void *data, u32 len, u32 seed) +{ + const struct netlink_sock *nlk = data; + struct netlink_compare_arg arg; + + netlink_compare_arg_init(&arg, sock_net(&nlk->sk), nlk->portid); + return jhash2((u32 *)&arg, netlink_compare_arg_len / sizeof(u32), seed); +} + +static const struct rhashtable_params netlink_rhashtable_params = { + .head_offset = offsetof(struct netlink_sock, node), + .key_len = netlink_compare_arg_len, + .obj_hashfn = netlink_hash, + .obj_cmpfn = netlink_compare, + .automatic_shrinking = true, +}; + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) +BTF_ID_LIST(btf_netlink_sock_id) +BTF_ID(struct, netlink_sock) + +static const struct bpf_iter_seq_info netlink_seq_info = { + .seq_ops = &netlink_seq_ops, + .init_seq_private = bpf_iter_init_seq_net, + .fini_seq_private = bpf_iter_fini_seq_net, + .seq_priv_size = sizeof(struct nl_seq_iter), +}; + +static struct bpf_iter_reg netlink_reg_info = { + .target = "netlink", + .ctx_arg_info_size = 1, + .ctx_arg_info = { + { offsetof(struct bpf_iter__netlink, sk), + PTR_TO_BTF_ID_OR_NULL }, + }, + .seq_info = &netlink_seq_info, +}; + +static int __init bpf_iter_register(void) +{ + netlink_reg_info.ctx_arg_info[0].btf_id = *btf_netlink_sock_id; + return bpf_iter_reg_target(&netlink_reg_info); +} +#endif + +static int __init netlink_proto_init(void) +{ + int i; + int err = proto_register(&netlink_proto, 0); + + if (err != 0) + goto out; + +#if defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + err = bpf_iter_register(); + if (err) + goto out; +#endif + + BUILD_BUG_ON(sizeof(struct netlink_skb_parms) > sizeof_field(struct sk_buff, cb)); + + nl_table = kcalloc(MAX_LINKS, sizeof(*nl_table), GFP_KERNEL); + if (!nl_table) + goto panic; + + for (i = 0; i < MAX_LINKS; i++) { + if (rhashtable_init(&nl_table[i].hash, + 
&netlink_rhashtable_params) < 0) { + while (--i > 0) + rhashtable_destroy(&nl_table[i].hash); + kfree(nl_table); + goto panic; + } + } + + netlink_add_usersock_entry(); + + sock_register(&netlink_family_ops); + register_pernet_subsys(&netlink_net_ops); + register_pernet_subsys(&netlink_tap_net_ops); + /* The netlink device handler may be needed early. */ + rtnetlink_init(); +out: + return err; +panic: + panic("netlink_init: Cannot allocate nl_table\n"); +} + +core_initcall(netlink_proto_init); diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h new file mode 100644 index 000000000..b30b8fc76 --- /dev/null +++ b/net/netlink/af_netlink.h @@ -0,0 +1,78 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _AF_NETLINK_H +#define _AF_NETLINK_H + +#include +#include +#include +#include + +/* flags */ +enum { + NETLINK_F_KERNEL_SOCKET, + NETLINK_F_RECV_PKTINFO, + NETLINK_F_BROADCAST_SEND_ERROR, + NETLINK_F_RECV_NO_ENOBUFS, + NETLINK_F_LISTEN_ALL_NSID, + NETLINK_F_CAP_ACK, + NETLINK_F_EXT_ACK, + NETLINK_F_STRICT_CHK, +}; + +#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) +#define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long)) + +struct netlink_sock { + /* struct sock has to be the first member of netlink_sock */ + struct sock sk; + unsigned long flags; + u32 portid; + u32 dst_portid; + u32 dst_group; + u32 subscriptions; + u32 ngroups; + unsigned long *groups; + unsigned long state; + size_t max_recvmsg_len; + wait_queue_head_t wait; + bool bound; + bool cb_running; + int dump_done_errno; + struct netlink_callback cb; + struct mutex *cb_mutex; + struct mutex cb_def_mutex; + void (*netlink_rcv)(struct sk_buff *skb); + int (*netlink_bind)(struct net *net, int group); + void (*netlink_unbind)(struct net *net, int group); + struct module *module; + + struct rhash_head node; + struct rcu_head rcu; + struct work_struct work; +}; + +static inline struct netlink_sock *nlk_sk(struct sock *sk) +{ + return container_of(sk, struct netlink_sock, sk); +} + +#define nlk_test_bit(nr, sk) test_bit(NETLINK_F_##nr, &nlk_sk(sk)->flags) + +struct netlink_table { + struct rhashtable hash; + struct hlist_head mc_list; + struct listeners __rcu *listeners; + unsigned int flags; + unsigned int groups; + struct mutex *cb_mutex; + struct module *module; + int (*bind)(struct net *net, int group); + void (*unbind)(struct net *net, int group); + bool (*compare)(struct net *net, struct sock *sock); + int registered; +}; + +extern struct netlink_table *nl_table; +extern rwlock_t nl_table_lock; + +#endif diff --git a/net/netlink/diag.c b/net/netlink/diag.c new file mode 100644 index 000000000..9c4f231be --- /dev/null +++ b/net/netlink/diag.c @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include + +#include +#include +#include +#include +#include + +#include "af_netlink.h" + +static int sk_diag_dump_groups(struct sock *sk, struct sk_buff *nlskb) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (nlk->groups == NULL) + return 0; + + return nla_put(nlskb, NETLINK_DIAG_GROUPS, NLGRPSZ(nlk->ngroups), + nlk->groups); +} + +static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb) +{ + struct netlink_sock *nlk = nlk_sk(sk); + u32 flags = 0; + + if (nlk->cb_running) + flags |= NDIAG_FLAG_CB_RUNNING; + if (nlk_test_bit(RECV_PKTINFO, sk)) + flags |= NDIAG_FLAG_PKTINFO; + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk)) + flags |= NDIAG_FLAG_BROADCAST_ERROR; + if (nlk_test_bit(RECV_NO_ENOBUFS, sk)) + flags |= NDIAG_FLAG_NO_ENOBUFS; + if (nlk_test_bit(LISTEN_ALL_NSID, sk)) + flags |= 
NDIAG_FLAG_LISTEN_ALL_NSID; + if (nlk_test_bit(CAP_ACK, sk)) + flags |= NDIAG_FLAG_CAP_ACK; + + return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags); +} + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + struct netlink_diag_req *req, + u32 portid, u32 seq, u32 flags, int sk_ino) +{ + struct nlmsghdr *nlh; + struct netlink_diag_msg *rep; + struct netlink_sock *nlk = nlk_sk(sk); + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), + flags); + if (!nlh) + return -EMSGSIZE; + + rep = nlmsg_data(nlh); + rep->ndiag_family = AF_NETLINK; + rep->ndiag_type = sk->sk_type; + rep->ndiag_protocol = sk->sk_protocol; + rep->ndiag_state = sk->sk_state; + + rep->ndiag_ino = sk_ino; + rep->ndiag_portid = nlk->portid; + rep->ndiag_dst_portid = nlk->dst_portid; + rep->ndiag_dst_group = nlk->dst_group; + sock_diag_save_cookie(sk, rep->ndiag_cookie); + + if ((req->ndiag_show & NDIAG_SHOW_GROUPS) && + sk_diag_dump_groups(sk, skb)) + goto out_nlmsg_trim; + + if ((req->ndiag_show & NDIAG_SHOW_MEMINFO) && + sock_diag_put_meminfo(sk, skb, NETLINK_DIAG_MEMINFO)) + goto out_nlmsg_trim; + + if ((req->ndiag_show & NDIAG_SHOW_FLAGS) && + sk_diag_put_flags(sk, skb)) + goto out_nlmsg_trim; + + nlmsg_end(skb, nlh); + return 0; + +out_nlmsg_trim: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + int protocol, int s_num) +{ + struct rhashtable_iter *hti = (void *)cb->args[2]; + struct netlink_table *tbl = &nl_table[protocol]; + struct net *net = sock_net(skb->sk); + struct netlink_diag_req *req; + struct netlink_sock *nlsk; + unsigned long flags; + struct sock *sk; + int num = 2; + int ret = 0; + + req = nlmsg_data(cb->nlh); + + if (s_num > 1) + goto mc_list; + + num--; + + if (!hti) { + hti = kmalloc(sizeof(*hti), GFP_KERNEL); + if (!hti) + return -ENOMEM; + + cb->args[2] = (long)hti; + } + + if (!s_num) + rhashtable_walk_enter(&tbl->hash, hti); + + rhashtable_walk_start(hti); + + while ((nlsk = rhashtable_walk_next(hti))) { + if (IS_ERR(nlsk)) { + ret = PTR_ERR(nlsk); + if (ret == -EAGAIN) { + ret = 0; + continue; + } + break; + } + + sk = (struct sock *)nlsk; + + if (!net_eq(sock_net(sk), net)) + continue; + + if (sk_diag_fill(sk, skb, req, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + sock_i_ino(sk)) < 0) { + ret = 1; + break; + } + } + + rhashtable_walk_stop(hti); + + if (ret) + goto done; + + rhashtable_walk_exit(hti); + num++; + +mc_list: + read_lock_irqsave(&nl_table_lock, flags); + sk_for_each_bound(sk, &tbl->mc_list) { + if (sk_hashed(sk)) + continue; + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) { + num++; + continue; + } + + if (sk_diag_fill(sk, skb, req, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + __sock_i_ino(sk)) < 0) { + ret = 1; + break; + } + num++; + } + read_unlock_irqrestore(&nl_table_lock, flags); + +done: + cb->args[0] = num; + + return ret; +} + +static int netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct netlink_diag_req *req; + int s_num = cb->args[0]; + int err = 0; + + req = nlmsg_data(cb->nlh); + + if (req->sdiag_protocol == NDIAG_PROTO_ALL) { + int i; + + for (i = cb->args[1]; i < MAX_LINKS; i++) { + err = __netlink_diag_dump(skb, cb, i, s_num); + if (err) + break; + s_num = 0; + } + cb->args[1] = i; + } else { + if (req->sdiag_protocol >= MAX_LINKS) + return -ENOENT; + + err = __netlink_diag_dump(skb, cb, req->sdiag_protocol, s_num); + } + + return err < 0 ? 
err : skb->len; +} + +static int netlink_diag_dump_done(struct netlink_callback *cb) +{ + struct rhashtable_iter *hti = (void *)cb->args[2]; + + if (cb->args[0] == 1) + rhashtable_walk_exit(hti); + + kfree(hti); + + return 0; +} + +static int netlink_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct netlink_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = netlink_diag_dump, + .done = netlink_diag_dump_done, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } else + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler netlink_diag_handler = { + .family = AF_NETLINK, + .dump = netlink_diag_handler_dump, +}; + +static int __init netlink_diag_init(void) +{ + return sock_diag_register(&netlink_diag_handler); +} + +static void __exit netlink_diag_exit(void) +{ + sock_diag_unregister(&netlink_diag_handler); +} + +module_init(netlink_diag_init); +module_exit(netlink_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 16 /* AF_NETLINK */); diff --git a/net/netlink/genetlink.c b/net/netlink/genetlink.c new file mode 100644 index 000000000..505d3b910 --- /dev/null +++ b/net/netlink/genetlink.c @@ -0,0 +1,1565 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * NETLINK Generic Netlink Family + * + * Authors: Jamal Hadi Salim + * Thomas Graf + * Johannes Berg + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static DEFINE_MUTEX(genl_mutex); /* serialization of message processing */ +static DECLARE_RWSEM(cb_lock); + +atomic_t genl_sk_destructing_cnt = ATOMIC_INIT(0); +DECLARE_WAIT_QUEUE_HEAD(genl_sk_destructing_waitq); + +void genl_lock(void) +{ + mutex_lock(&genl_mutex); +} +EXPORT_SYMBOL(genl_lock); + +void genl_unlock(void) +{ + mutex_unlock(&genl_mutex); +} +EXPORT_SYMBOL(genl_unlock); + +static void genl_lock_all(void) +{ + down_write(&cb_lock); + genl_lock(); +} + +static void genl_unlock_all(void) +{ + genl_unlock(); + up_write(&cb_lock); +} + +static DEFINE_IDR(genl_fam_idr); + +/* + * Bitmap of multicast groups that are currently in use. + * + * To avoid an allocation at boot of just one unsigned long, + * declare it global instead. + * Bit 0 is marked as already used since group 0 is invalid. + * Bit 1 is marked as already used since the drop-monitor code + * abuses the API and thinks it can statically use group 1. + * That group will typically conflict with other groups that + * any proper users use. + * Bit 16 is marked as used since it's used for generic netlink + * and the code no longer marks pre-reserved IDs as used. + * Bit 17 is marked as already used since the VFS quota code + * also abused this API and relied on family == group ID, we + * cater to that by giving it a static family and group ID. + * Bit 18 is marked as already used since the PMCRAID driver + * did the same thing as the VFS quota code (maybe copied?) 
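+ *
+ * With bits 0, 1, 16, 17 and 18 set as described above, the initial
+ * value below works out to 0x3 | BIT(16) | BIT(17) | BIT(18) == 0x70003.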
+ */ +static unsigned long mc_group_start = 0x3 | BIT(GENL_ID_CTRL) | + BIT(GENL_ID_VFS_DQUOT) | + BIT(GENL_ID_PMCRAID); +static unsigned long *mc_groups = &mc_group_start; +static unsigned long mc_groups_longs = 1; + +/* We need the last attribute with non-zero ID therefore a 2-entry array */ +static struct nla_policy genl_policy_reject_all[] = { + { .type = NLA_REJECT }, + { .type = NLA_REJECT }, +}; + +static int genl_ctrl_event(int event, const struct genl_family *family, + const struct genl_multicast_group *grp, + int grp_id); + +static void +genl_op_fill_in_reject_policy(const struct genl_family *family, + struct genl_ops *op) +{ + BUILD_BUG_ON(ARRAY_SIZE(genl_policy_reject_all) - 1 != 1); + + if (op->policy || op->cmd < family->resv_start_op) + return; + + op->policy = genl_policy_reject_all; + op->maxattr = 1; +} + +static const struct genl_family *genl_family_find_byid(unsigned int id) +{ + return idr_find(&genl_fam_idr, id); +} + +static const struct genl_family *genl_family_find_byname(char *name) +{ + const struct genl_family *family; + unsigned int id; + + idr_for_each_entry(&genl_fam_idr, family, id) + if (strcmp(family->name, name) == 0) + return family; + + return NULL; +} + +static int genl_get_cmd_cnt(const struct genl_family *family) +{ + return family->n_ops + family->n_small_ops; +} + +static void genl_op_from_full(const struct genl_family *family, + unsigned int i, struct genl_ops *op) +{ + *op = family->ops[i]; + + if (!op->maxattr) + op->maxattr = family->maxattr; + if (!op->policy) + op->policy = family->policy; + + genl_op_fill_in_reject_policy(family, op); +} + +static int genl_get_cmd_full(u32 cmd, const struct genl_family *family, + struct genl_ops *op) +{ + int i; + + for (i = 0; i < family->n_ops; i++) + if (family->ops[i].cmd == cmd) { + genl_op_from_full(family, i, op); + return 0; + } + + return -ENOENT; +} + +static void genl_op_from_small(const struct genl_family *family, + unsigned int i, struct genl_ops *op) +{ + memset(op, 0, sizeof(*op)); + op->doit = family->small_ops[i].doit; + op->dumpit = family->small_ops[i].dumpit; + op->cmd = family->small_ops[i].cmd; + op->internal_flags = family->small_ops[i].internal_flags; + op->flags = family->small_ops[i].flags; + op->validate = family->small_ops[i].validate; + + op->maxattr = family->maxattr; + op->policy = family->policy; + + genl_op_fill_in_reject_policy(family, op); +} + +static int genl_get_cmd_small(u32 cmd, const struct genl_family *family, + struct genl_ops *op) +{ + int i; + + for (i = 0; i < family->n_small_ops; i++) + if (family->small_ops[i].cmd == cmd) { + genl_op_from_small(family, i, op); + return 0; + } + + return -ENOENT; +} + +static int genl_get_cmd(u32 cmd, const struct genl_family *family, + struct genl_ops *op) +{ + if (!genl_get_cmd_full(cmd, family, op)) + return 0; + return genl_get_cmd_small(cmd, family, op); +} + +static void genl_get_cmd_by_index(unsigned int i, + const struct genl_family *family, + struct genl_ops *op) +{ + if (i < family->n_ops) + genl_op_from_full(family, i, op); + else if (i < family->n_ops + family->n_small_ops) + genl_op_from_small(family, i - family->n_ops, op); + else + WARN_ON_ONCE(1); +} + +static int genl_allocate_reserve_groups(int n_groups, int *first_id) +{ + unsigned long *new_groups; + int start = 0; + int i; + int id; + bool fits; + + do { + if (start == 0) + id = find_first_zero_bit(mc_groups, + mc_groups_longs * + BITS_PER_LONG); + else + id = find_next_zero_bit(mc_groups, + mc_groups_longs * BITS_PER_LONG, + start); + + fits = true; + for 
(i = id; + i < min_t(int, id + n_groups, + mc_groups_longs * BITS_PER_LONG); + i++) { + if (test_bit(i, mc_groups)) { + start = i; + fits = false; + break; + } + } + + if (id + n_groups > mc_groups_longs * BITS_PER_LONG) { + unsigned long new_longs = mc_groups_longs + + BITS_TO_LONGS(n_groups); + size_t nlen = new_longs * sizeof(unsigned long); + + if (mc_groups == &mc_group_start) { + new_groups = kzalloc(nlen, GFP_KERNEL); + if (!new_groups) + return -ENOMEM; + mc_groups = new_groups; + *mc_groups = mc_group_start; + } else { + new_groups = krealloc(mc_groups, nlen, + GFP_KERNEL); + if (!new_groups) + return -ENOMEM; + mc_groups = new_groups; + for (i = 0; i < BITS_TO_LONGS(n_groups); i++) + mc_groups[mc_groups_longs + i] = 0; + } + mc_groups_longs = new_longs; + } + } while (!fits); + + for (i = id; i < id + n_groups; i++) + set_bit(i, mc_groups); + *first_id = id; + return 0; +} + +static struct genl_family genl_ctrl; + +static int genl_validate_assign_mc_groups(struct genl_family *family) +{ + int first_id; + int n_groups = family->n_mcgrps; + int err = 0, i; + bool groups_allocated = false; + + if (!n_groups) + return 0; + + for (i = 0; i < n_groups; i++) { + const struct genl_multicast_group *grp = &family->mcgrps[i]; + + if (WARN_ON(grp->name[0] == '\0')) + return -EINVAL; + if (WARN_ON(memchr(grp->name, '\0', GENL_NAMSIZ) == NULL)) + return -EINVAL; + } + + /* special-case our own group and hacks */ + if (family == &genl_ctrl) { + first_id = GENL_ID_CTRL; + BUG_ON(n_groups != 1); + } else if (strcmp(family->name, "NET_DM") == 0) { + first_id = 1; + BUG_ON(n_groups != 1); + } else if (family->id == GENL_ID_VFS_DQUOT) { + first_id = GENL_ID_VFS_DQUOT; + BUG_ON(n_groups != 1); + } else if (family->id == GENL_ID_PMCRAID) { + first_id = GENL_ID_PMCRAID; + BUG_ON(n_groups != 1); + } else { + groups_allocated = true; + err = genl_allocate_reserve_groups(n_groups, &first_id); + if (err) + return err; + } + + family->mcgrp_offset = first_id; + + /* if still initializing, can't and don't need to realloc bitmaps */ + if (!init_net.genl_sock) + return 0; + + if (family->netnsok) { + struct net *net; + + netlink_table_grab(); + rcu_read_lock(); + for_each_net_rcu(net) { + err = __netlink_change_ngroups(net->genl_sock, + mc_groups_longs * BITS_PER_LONG); + if (err) { + /* + * No need to roll back, can only fail if + * memory allocation fails and then the + * number of _possible_ groups has been + * increased on some sockets which is ok. 
+ */ + break; + } + } + rcu_read_unlock(); + netlink_table_ungrab(); + } else { + err = netlink_change_ngroups(init_net.genl_sock, + mc_groups_longs * BITS_PER_LONG); + } + + if (groups_allocated && err) { + for (i = 0; i < family->n_mcgrps; i++) + clear_bit(family->mcgrp_offset + i, mc_groups); + } + + return err; +} + +static void genl_unregister_mc_groups(const struct genl_family *family) +{ + struct net *net; + int i; + + netlink_table_grab(); + rcu_read_lock(); + for_each_net_rcu(net) { + for (i = 0; i < family->n_mcgrps; i++) + __netlink_clear_multicast_users( + net->genl_sock, family->mcgrp_offset + i); + } + rcu_read_unlock(); + netlink_table_ungrab(); + + for (i = 0; i < family->n_mcgrps; i++) { + int grp_id = family->mcgrp_offset + i; + + if (grp_id != 1) + clear_bit(grp_id, mc_groups); + genl_ctrl_event(CTRL_CMD_DELMCAST_GRP, family, + &family->mcgrps[i], grp_id); + } +} + +static int genl_validate_ops(const struct genl_family *family) +{ + int i, j; + + if (WARN_ON(family->n_ops && !family->ops) || + WARN_ON(family->n_small_ops && !family->small_ops)) + return -EINVAL; + + for (i = 0; i < genl_get_cmd_cnt(family); i++) { + struct genl_ops op; + + genl_get_cmd_by_index(i, family, &op); + if (op.dumpit == NULL && op.doit == NULL) + return -EINVAL; + if (WARN_ON(op.cmd >= family->resv_start_op && op.validate)) + return -EINVAL; + for (j = i + 1; j < genl_get_cmd_cnt(family); j++) { + struct genl_ops op2; + + genl_get_cmd_by_index(j, family, &op2); + if (op.cmd == op2.cmd) + return -EINVAL; + } + } + + return 0; +} + +/** + * genl_register_family - register a generic netlink family + * @family: generic netlink family + * + * Registers the specified family after validating it first. Only one + * family may be registered with the same family name or identifier. + * + * The family's ops, multicast groups and module pointer must already + * be assigned. + * + * Return 0 on success or a negative error code. + */ +int genl_register_family(struct genl_family *family) +{ + int err, i; + int start = GENL_START_ALLOC, end = GENL_MAX_ID; + + err = genl_validate_ops(family); + if (err) + return err; + + genl_lock_all(); + + if (genl_family_find_byname(family->name)) { + err = -EEXIST; + goto errout_locked; + } + + /* + * Sadly, a few cases need to be special-cased + * due to them having previously abused the API + * and having used their family ID also as their + * multicast group ID, so we use reserved IDs + * for both to be sure we can do that mapping. 
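+ *
+ * Concretely: the ctrl family is pinned to GENL_ID_CTRL, "pmcraid" to
+ * GENL_ID_PMCRAID and "VFS_DQUOT" to GENL_ID_VFS_DQUOT below; every other
+ * family gets an ID allocated cyclically from GENL_START_ALLOC onwards.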
+ */ + if (family == &genl_ctrl) { + /* and this needs to be special for initial family lookups */ + start = end = GENL_ID_CTRL; + } else if (strcmp(family->name, "pmcraid") == 0) { + start = end = GENL_ID_PMCRAID; + } else if (strcmp(family->name, "VFS_DQUOT") == 0) { + start = end = GENL_ID_VFS_DQUOT; + } + + family->id = idr_alloc_cyclic(&genl_fam_idr, family, + start, end + 1, GFP_KERNEL); + if (family->id < 0) { + err = family->id; + goto errout_locked; + } + + err = genl_validate_assign_mc_groups(family); + if (err) + goto errout_remove; + + genl_unlock_all(); + + /* send all events */ + genl_ctrl_event(CTRL_CMD_NEWFAMILY, family, NULL, 0); + for (i = 0; i < family->n_mcgrps; i++) + genl_ctrl_event(CTRL_CMD_NEWMCAST_GRP, family, + &family->mcgrps[i], family->mcgrp_offset + i); + + return 0; + +errout_remove: + idr_remove(&genl_fam_idr, family->id); +errout_locked: + genl_unlock_all(); + return err; +} +EXPORT_SYMBOL(genl_register_family); + +/** + * genl_unregister_family - unregister generic netlink family + * @family: generic netlink family + * + * Unregisters the specified family. + * + * Returns 0 on success or a negative error code. + */ +int genl_unregister_family(const struct genl_family *family) +{ + genl_lock_all(); + + if (!genl_family_find_byid(family->id)) { + genl_unlock_all(); + return -ENOENT; + } + + genl_unregister_mc_groups(family); + + idr_remove(&genl_fam_idr, family->id); + + up_write(&cb_lock); + wait_event(genl_sk_destructing_waitq, + atomic_read(&genl_sk_destructing_cnt) == 0); + genl_unlock(); + + genl_ctrl_event(CTRL_CMD_DELFAMILY, family, NULL, 0); + + return 0; +} +EXPORT_SYMBOL(genl_unregister_family); + +/** + * genlmsg_put - Add generic netlink header to netlink message + * @skb: socket buffer holding the message + * @portid: netlink portid the message is addressed to + * @seq: sequence number (usually the one of the sender) + * @family: generic netlink family + * @flags: netlink message flags + * @cmd: generic netlink command + * + * Returns pointer to user specific header + */ +void *genlmsg_put(struct sk_buff *skb, u32 portid, u32 seq, + const struct genl_family *family, int flags, u8 cmd) +{ + struct nlmsghdr *nlh; + struct genlmsghdr *hdr; + + nlh = nlmsg_put(skb, portid, seq, family->id, GENL_HDRLEN + + family->hdrsize, flags); + if (nlh == NULL) + return NULL; + + hdr = nlmsg_data(nlh); + hdr->cmd = cmd; + hdr->version = family->version; + hdr->reserved = 0; + + return (char *) hdr + GENL_HDRLEN; +} +EXPORT_SYMBOL(genlmsg_put); + +static struct genl_dumpit_info *genl_dumpit_info_alloc(void) +{ + return kmalloc(sizeof(struct genl_dumpit_info), GFP_KERNEL); +} + +static void genl_dumpit_info_free(const struct genl_dumpit_info *info) +{ + kfree(info); +} + +static struct nlattr ** +genl_family_rcv_msg_attrs_parse(const struct genl_family *family, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + const struct genl_ops *ops, + int hdrlen, + enum genl_validate_flags no_strict_flag) +{ + enum netlink_validation validate = ops->validate & no_strict_flag ? 
+ NL_VALIDATE_LIBERAL : + NL_VALIDATE_STRICT; + struct nlattr **attrbuf; + int err; + + if (!ops->maxattr) + return NULL; + + attrbuf = kmalloc_array(ops->maxattr + 1, + sizeof(struct nlattr *), GFP_KERNEL); + if (!attrbuf) + return ERR_PTR(-ENOMEM); + + err = __nlmsg_parse(nlh, hdrlen, attrbuf, ops->maxattr, ops->policy, + validate, extack); + if (err) { + kfree(attrbuf); + return ERR_PTR(err); + } + return attrbuf; +} + +static void genl_family_rcv_msg_attrs_free(struct nlattr **attrbuf) +{ + kfree(attrbuf); +} + +struct genl_start_context { + const struct genl_family *family; + struct nlmsghdr *nlh; + struct netlink_ext_ack *extack; + const struct genl_ops *ops; + int hdrlen; +}; + +static int genl_start(struct netlink_callback *cb) +{ + struct genl_start_context *ctx = cb->data; + const struct genl_ops *ops = ctx->ops; + struct genl_dumpit_info *info; + struct nlattr **attrs = NULL; + int rc = 0; + + if (ops->validate & GENL_DONT_VALIDATE_DUMP) + goto no_attrs; + + if (ctx->nlh->nlmsg_len < nlmsg_msg_size(ctx->hdrlen)) + return -EINVAL; + + attrs = genl_family_rcv_msg_attrs_parse(ctx->family, ctx->nlh, ctx->extack, + ops, ctx->hdrlen, + GENL_DONT_VALIDATE_DUMP_STRICT); + if (IS_ERR(attrs)) + return PTR_ERR(attrs); + +no_attrs: + info = genl_dumpit_info_alloc(); + if (!info) { + genl_family_rcv_msg_attrs_free(attrs); + return -ENOMEM; + } + info->family = ctx->family; + info->op = *ops; + info->attrs = attrs; + + cb->data = info; + if (ops->start) { + if (!ctx->family->parallel_ops) + genl_lock(); + rc = ops->start(cb); + if (!ctx->family->parallel_ops) + genl_unlock(); + } + + if (rc) { + genl_family_rcv_msg_attrs_free(info->attrs); + genl_dumpit_info_free(info); + cb->data = NULL; + } + return rc; +} + +static int genl_lock_dumpit(struct sk_buff *skb, struct netlink_callback *cb) +{ + const struct genl_ops *ops = &genl_dumpit_info(cb)->op; + int rc; + + genl_lock(); + rc = ops->dumpit(skb, cb); + genl_unlock(); + return rc; +} + +static int genl_lock_done(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + const struct genl_ops *ops = &info->op; + int rc = 0; + + if (ops->done) { + genl_lock(); + rc = ops->done(cb); + genl_unlock(); + } + genl_family_rcv_msg_attrs_free(info->attrs); + genl_dumpit_info_free(info); + return rc; +} + +static int genl_parallel_done(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + const struct genl_ops *ops = &info->op; + int rc = 0; + + if (ops->done) + rc = ops->done(cb); + genl_family_rcv_msg_attrs_free(info->attrs); + genl_dumpit_info_free(info); + return rc; +} + +static int genl_family_rcv_msg_dumpit(const struct genl_family *family, + struct sk_buff *skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + const struct genl_ops *ops, + int hdrlen, struct net *net) +{ + struct genl_start_context ctx; + int err; + + if (!ops->dumpit) + return -EOPNOTSUPP; + + ctx.family = family; + ctx.nlh = nlh; + ctx.extack = extack; + ctx.ops = ops; + ctx.hdrlen = hdrlen; + + if (!family->parallel_ops) { + struct netlink_dump_control c = { + .module = family->module, + .data = &ctx, + .start = genl_start, + .dump = genl_lock_dumpit, + .done = genl_lock_done, + }; + + genl_unlock(); + err = __netlink_dump_start(net->genl_sock, skb, nlh, &c); + genl_lock(); + } else { + struct netlink_dump_control c = { + .module = family->module, + .data = &ctx, + .start = genl_start, + .dump = ops->dumpit, + .done = genl_parallel_done, + }; + + err = 
__netlink_dump_start(net->genl_sock, skb, nlh, &c); + } + + return err; +} + +static int genl_family_rcv_msg_doit(const struct genl_family *family, + struct sk_buff *skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack, + const struct genl_ops *ops, + int hdrlen, struct net *net) +{ + struct nlattr **attrbuf; + struct genl_info info; + int err; + + if (!ops->doit) + return -EOPNOTSUPP; + + attrbuf = genl_family_rcv_msg_attrs_parse(family, nlh, extack, + ops, hdrlen, + GENL_DONT_VALIDATE_STRICT); + if (IS_ERR(attrbuf)) + return PTR_ERR(attrbuf); + + info.snd_seq = nlh->nlmsg_seq; + info.snd_portid = NETLINK_CB(skb).portid; + info.nlhdr = nlh; + info.genlhdr = nlmsg_data(nlh); + info.userhdr = nlmsg_data(nlh) + GENL_HDRLEN; + info.attrs = attrbuf; + info.extack = extack; + genl_info_net_set(&info, net); + memset(&info.user_ptr, 0, sizeof(info.user_ptr)); + + if (family->pre_doit) { + err = family->pre_doit(ops, skb, &info); + if (err) + goto out; + } + + err = ops->doit(skb, &info); + + if (family->post_doit) + family->post_doit(ops, skb, &info); + +out: + genl_family_rcv_msg_attrs_free(attrbuf); + + return err; +} + +static int genl_header_check(const struct genl_family *family, + struct nlmsghdr *nlh, struct genlmsghdr *hdr, + struct netlink_ext_ack *extack) +{ + u16 flags; + + /* Only for commands added after we started validating */ + if (hdr->cmd < family->resv_start_op) + return 0; + + if (hdr->reserved) { + NL_SET_ERR_MSG(extack, "genlmsghdr.reserved field is not 0"); + return -EINVAL; + } + + /* Old netlink flags have pretty loose semantics, allow only the flags + * consumed by the core where we can enforce the meaning. + */ + flags = nlh->nlmsg_flags; + if ((flags & NLM_F_DUMP) == NLM_F_DUMP) /* DUMP is 2 bits */ + flags &= ~NLM_F_DUMP; + if (flags & ~(NLM_F_REQUEST | NLM_F_ACK | NLM_F_ECHO)) { + NL_SET_ERR_MSG(extack, + "ambiguous or reserved bits set in nlmsg_flags"); + return -EINVAL; + } + + return 0; +} + +static int genl_family_rcv_msg(const struct genl_family *family, + struct sk_buff *skb, + struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct genlmsghdr *hdr = nlmsg_data(nlh); + struct genl_ops op; + int hdrlen; + + /* this family doesn't exist in this netns */ + if (!family->netnsok && !net_eq(net, &init_net)) + return -ENOENT; + + hdrlen = GENL_HDRLEN + family->hdrsize; + if (nlh->nlmsg_len < nlmsg_msg_size(hdrlen)) + return -EINVAL; + + if (genl_header_check(family, nlh, hdr, extack)) + return -EINVAL; + + if (genl_get_cmd(hdr->cmd, family, &op)) + return -EOPNOTSUPP; + + if ((op.flags & GENL_ADMIN_PERM) && + !netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if ((op.flags & GENL_UNS_ADMIN_PERM) && + !netlink_ns_capable(skb, net->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if ((nlh->nlmsg_flags & NLM_F_DUMP) == NLM_F_DUMP) + return genl_family_rcv_msg_dumpit(family, skb, nlh, extack, + &op, hdrlen, net); + else + return genl_family_rcv_msg_doit(family, skb, nlh, extack, + &op, hdrlen, net); +} + +static int genl_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + const struct genl_family *family; + int err; + + family = genl_family_find_byid(nlh->nlmsg_type); + if (family == NULL) + return -ENOENT; + + if (!family->parallel_ops) + genl_lock(); + + err = genl_family_rcv_msg(family, skb, nlh, extack); + + if (!family->parallel_ops) + genl_unlock(); + + return err; +} + +static void genl_rcv(struct sk_buff *skb) +{ + down_read(&cb_lock); + netlink_rcv_skb(skb, 
&genl_rcv_msg); + up_read(&cb_lock); +} + +/************************************************************************** + * Controller + **************************************************************************/ + +static struct genl_family genl_ctrl; + +static int ctrl_fill_info(const struct genl_family *family, u32 portid, u32 seq, + u32 flags, struct sk_buff *skb, u8 cmd) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &genl_ctrl, flags, cmd); + if (hdr == NULL) + return -1; + + if (nla_put_string(skb, CTRL_ATTR_FAMILY_NAME, family->name) || + nla_put_u16(skb, CTRL_ATTR_FAMILY_ID, family->id) || + nla_put_u32(skb, CTRL_ATTR_VERSION, family->version) || + nla_put_u32(skb, CTRL_ATTR_HDRSIZE, family->hdrsize) || + nla_put_u32(skb, CTRL_ATTR_MAXATTR, family->maxattr)) + goto nla_put_failure; + + if (genl_get_cmd_cnt(family)) { + struct nlattr *nla_ops; + int i; + + nla_ops = nla_nest_start_noflag(skb, CTRL_ATTR_OPS); + if (nla_ops == NULL) + goto nla_put_failure; + + for (i = 0; i < genl_get_cmd_cnt(family); i++) { + struct nlattr *nest; + struct genl_ops op; + u32 op_flags; + + genl_get_cmd_by_index(i, family, &op); + op_flags = op.flags; + if (op.dumpit) + op_flags |= GENL_CMD_CAP_DUMP; + if (op.doit) + op_flags |= GENL_CMD_CAP_DO; + if (op.policy) + op_flags |= GENL_CMD_CAP_HASPOL; + + nest = nla_nest_start_noflag(skb, i + 1); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, CTRL_ATTR_OP_ID, op.cmd) || + nla_put_u32(skb, CTRL_ATTR_OP_FLAGS, op_flags)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + } + + nla_nest_end(skb, nla_ops); + } + + if (family->n_mcgrps) { + struct nlattr *nla_grps; + int i; + + nla_grps = nla_nest_start_noflag(skb, CTRL_ATTR_MCAST_GROUPS); + if (nla_grps == NULL) + goto nla_put_failure; + + for (i = 0; i < family->n_mcgrps; i++) { + struct nlattr *nest; + const struct genl_multicast_group *grp; + + grp = &family->mcgrps[i]; + + nest = nla_nest_start_noflag(skb, i + 1); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID, + family->mcgrp_offset + i) || + nla_put_string(skb, CTRL_ATTR_MCAST_GRP_NAME, + grp->name)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + } + nla_nest_end(skb, nla_grps); + } + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ctrl_fill_mcgrp_info(const struct genl_family *family, + const struct genl_multicast_group *grp, + int grp_id, u32 portid, u32 seq, u32 flags, + struct sk_buff *skb, u8 cmd) +{ + void *hdr; + struct nlattr *nla_grps; + struct nlattr *nest; + + hdr = genlmsg_put(skb, portid, seq, &genl_ctrl, flags, cmd); + if (hdr == NULL) + return -1; + + if (nla_put_string(skb, CTRL_ATTR_FAMILY_NAME, family->name) || + nla_put_u16(skb, CTRL_ATTR_FAMILY_ID, family->id)) + goto nla_put_failure; + + nla_grps = nla_nest_start_noflag(skb, CTRL_ATTR_MCAST_GROUPS); + if (nla_grps == NULL) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, 1); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, CTRL_ATTR_MCAST_GRP_ID, grp_id) || + nla_put_string(skb, CTRL_ATTR_MCAST_GRP_NAME, + grp->name)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + nla_nest_end(skb, nla_grps); + + genlmsg_end(skb, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +static int ctrl_dumpfamily(struct sk_buff *skb, struct netlink_callback *cb) +{ + int n = 0; + struct genl_family *rt; + struct net *net = sock_net(skb->sk); + int fams_to_skip = 
cb->args[0]; + unsigned int id; + + idr_for_each_entry(&genl_fam_idr, rt, id) { + if (!rt->netnsok && !net_eq(net, &init_net)) + continue; + + if (n++ < fams_to_skip) + continue; + + if (ctrl_fill_info(rt, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + skb, CTRL_CMD_NEWFAMILY) < 0) { + n--; + break; + } + } + + cb->args[0] = n; + return skb->len; +} + +static struct sk_buff *ctrl_build_family_msg(const struct genl_family *family, + u32 portid, int seq, u8 cmd) +{ + struct sk_buff *skb; + int err; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb == NULL) + return ERR_PTR(-ENOBUFS); + + err = ctrl_fill_info(family, portid, seq, 0, skb, cmd); + if (err < 0) { + nlmsg_free(skb); + return ERR_PTR(err); + } + + return skb; +} + +static struct sk_buff * +ctrl_build_mcgrp_msg(const struct genl_family *family, + const struct genl_multicast_group *grp, + int grp_id, u32 portid, int seq, u8 cmd) +{ + struct sk_buff *skb; + int err; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (skb == NULL) + return ERR_PTR(-ENOBUFS); + + err = ctrl_fill_mcgrp_info(family, grp, grp_id, portid, + seq, 0, skb, cmd); + if (err < 0) { + nlmsg_free(skb); + return ERR_PTR(err); + } + + return skb; +} + +static const struct nla_policy ctrl_policy_family[] = { + [CTRL_ATTR_FAMILY_ID] = { .type = NLA_U16 }, + [CTRL_ATTR_FAMILY_NAME] = { .type = NLA_NUL_STRING, + .len = GENL_NAMSIZ - 1 }, +}; + +static int ctrl_getfamily(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + const struct genl_family *res = NULL; + int err = -EINVAL; + + if (info->attrs[CTRL_ATTR_FAMILY_ID]) { + u16 id = nla_get_u16(info->attrs[CTRL_ATTR_FAMILY_ID]); + res = genl_family_find_byid(id); + err = -ENOENT; + } + + if (info->attrs[CTRL_ATTR_FAMILY_NAME]) { + char *name; + + name = nla_data(info->attrs[CTRL_ATTR_FAMILY_NAME]); + res = genl_family_find_byname(name); +#ifdef CONFIG_MODULES + if (res == NULL) { + genl_unlock(); + up_read(&cb_lock); + request_module("net-pf-%d-proto-%d-family-%s", + PF_NETLINK, NETLINK_GENERIC, name); + down_read(&cb_lock); + genl_lock(); + res = genl_family_find_byname(name); + } +#endif + err = -ENOENT; + } + + if (res == NULL) + return err; + + if (!res->netnsok && !net_eq(genl_info_net(info), &init_net)) { + /* family doesn't exist here */ + return -ENOENT; + } + + msg = ctrl_build_family_msg(res, info->snd_portid, info->snd_seq, + CTRL_CMD_NEWFAMILY); + if (IS_ERR(msg)) + return PTR_ERR(msg); + + return genlmsg_reply(msg, info); +} + +static int genl_ctrl_event(int event, const struct genl_family *family, + const struct genl_multicast_group *grp, + int grp_id) +{ + struct sk_buff *msg; + + /* genl is still initialising */ + if (!init_net.genl_sock) + return 0; + + switch (event) { + case CTRL_CMD_NEWFAMILY: + case CTRL_CMD_DELFAMILY: + WARN_ON(grp); + msg = ctrl_build_family_msg(family, 0, 0, event); + break; + case CTRL_CMD_NEWMCAST_GRP: + case CTRL_CMD_DELMCAST_GRP: + BUG_ON(!grp); + msg = ctrl_build_mcgrp_msg(family, grp, grp_id, 0, 0, event); + break; + default: + return -EINVAL; + } + + if (IS_ERR(msg)) + return PTR_ERR(msg); + + if (!family->netnsok) { + genlmsg_multicast_netns(&genl_ctrl, &init_net, msg, 0, + 0, GFP_KERNEL); + } else { + rcu_read_lock(); + genlmsg_multicast_allns(&genl_ctrl, msg, 0, + 0, GFP_ATOMIC); + rcu_read_unlock(); + } + + return 0; +} + +struct ctrl_dump_policy_ctx { + struct netlink_policy_dump_state *state; + const struct genl_family *rt; + unsigned int opidx; + u32 op; + u16 fam_id; + u8 policies:1, + single_op:1; +}; + 
+static const struct nla_policy ctrl_policy_policy[] = { + [CTRL_ATTR_FAMILY_ID] = { .type = NLA_U16 }, + [CTRL_ATTR_FAMILY_NAME] = { .type = NLA_NUL_STRING, + .len = GENL_NAMSIZ - 1 }, + [CTRL_ATTR_OP] = { .type = NLA_U32 }, +}; + +static int ctrl_dumppolicy_start(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + struct nlattr **tb = info->attrs; + const struct genl_family *rt; + struct genl_ops op; + int err, i; + + BUILD_BUG_ON(sizeof(*ctx) > sizeof(cb->ctx)); + + if (!tb[CTRL_ATTR_FAMILY_ID] && !tb[CTRL_ATTR_FAMILY_NAME]) + return -EINVAL; + + if (tb[CTRL_ATTR_FAMILY_ID]) { + ctx->fam_id = nla_get_u16(tb[CTRL_ATTR_FAMILY_ID]); + } else { + rt = genl_family_find_byname( + nla_data(tb[CTRL_ATTR_FAMILY_NAME])); + if (!rt) + return -ENOENT; + ctx->fam_id = rt->id; + } + + rt = genl_family_find_byid(ctx->fam_id); + if (!rt) + return -ENOENT; + + ctx->rt = rt; + + if (tb[CTRL_ATTR_OP]) { + ctx->single_op = true; + ctx->op = nla_get_u32(tb[CTRL_ATTR_OP]); + + err = genl_get_cmd(ctx->op, rt, &op); + if (err) { + NL_SET_BAD_ATTR(cb->extack, tb[CTRL_ATTR_OP]); + return err; + } + + if (!op.policy) + return -ENODATA; + + return netlink_policy_dump_add_policy(&ctx->state, op.policy, + op.maxattr); + } + + for (i = 0; i < genl_get_cmd_cnt(rt); i++) { + genl_get_cmd_by_index(i, rt, &op); + + if (op.policy) { + err = netlink_policy_dump_add_policy(&ctx->state, + op.policy, + op.maxattr); + if (err) + goto err_free_state; + } + } + + if (!ctx->state) + return -ENODATA; + return 0; + +err_free_state: + netlink_policy_dump_free(ctx->state); + return err; +} + +static void *ctrl_dumppolicy_prep(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + void *hdr; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, &genl_ctrl, + NLM_F_MULTI, CTRL_CMD_GETPOLICY); + if (!hdr) + return NULL; + + if (nla_put_u16(skb, CTRL_ATTR_FAMILY_ID, ctx->fam_id)) + return NULL; + + return hdr; +} + +static int ctrl_dumppolicy_put_op(struct sk_buff *skb, + struct netlink_callback *cb, + struct genl_ops *op) +{ + struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + struct nlattr *nest_pol, *nest_op; + void *hdr; + int idx; + + /* skip if we have nothing to show */ + if (!op->policy) + return 0; + if (!op->doit && + (!op->dumpit || op->validate & GENL_DONT_VALIDATE_DUMP)) + return 0; + + hdr = ctrl_dumppolicy_prep(skb, cb); + if (!hdr) + return -ENOBUFS; + + nest_pol = nla_nest_start(skb, CTRL_ATTR_OP_POLICY); + if (!nest_pol) + goto err; + + nest_op = nla_nest_start(skb, op->cmd); + if (!nest_op) + goto err; + + /* for now both do/dump are always the same */ + idx = netlink_policy_dump_get_policy_idx(ctx->state, + op->policy, + op->maxattr); + + if (op->doit && nla_put_u32(skb, CTRL_ATTR_POLICY_DO, idx)) + goto err; + + if (op->dumpit && !(op->validate & GENL_DONT_VALIDATE_DUMP) && + nla_put_u32(skb, CTRL_ATTR_POLICY_DUMP, idx)) + goto err; + + nla_nest_end(skb, nest_op); + nla_nest_end(skb, nest_pol); + genlmsg_end(skb, hdr); + + return 0; +err: + genlmsg_cancel(skb, hdr); + return -ENOBUFS; +} + +static int ctrl_dumppolicy(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + void *hdr; + + if (!ctx->policies) { + while (ctx->opidx < genl_get_cmd_cnt(ctx->rt)) { + struct genl_ops op; + + if (ctx->single_op) { + int err; + + err = genl_get_cmd(ctx->op, ctx->rt, &op); + if (WARN_ON(err)) + 
return skb->len; + + /* break out of the loop after this one */ + ctx->opidx = genl_get_cmd_cnt(ctx->rt); + } else { + genl_get_cmd_by_index(ctx->opidx, ctx->rt, &op); + } + + if (ctrl_dumppolicy_put_op(skb, cb, &op)) + return skb->len; + + ctx->opidx++; + } + + /* completed with the per-op policy index list */ + ctx->policies = true; + } + + while (netlink_policy_dump_loop(ctx->state)) { + struct nlattr *nest; + + hdr = ctrl_dumppolicy_prep(skb, cb); + if (!hdr) + goto nla_put_failure; + + nest = nla_nest_start(skb, CTRL_ATTR_POLICY); + if (!nest) + goto nla_put_failure; + + if (netlink_policy_dump_write(skb, ctx->state)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + genlmsg_end(skb, hdr); + } + + return skb->len; + +nla_put_failure: + genlmsg_cancel(skb, hdr); + return skb->len; +} + +static int ctrl_dumppolicy_done(struct netlink_callback *cb) +{ + struct ctrl_dump_policy_ctx *ctx = (void *)cb->ctx; + + netlink_policy_dump_free(ctx->state); + return 0; +} + +static const struct genl_ops genl_ctrl_ops[] = { + { + .cmd = CTRL_CMD_GETFAMILY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .policy = ctrl_policy_family, + .maxattr = ARRAY_SIZE(ctrl_policy_family) - 1, + .doit = ctrl_getfamily, + .dumpit = ctrl_dumpfamily, + }, + { + .cmd = CTRL_CMD_GETPOLICY, + .policy = ctrl_policy_policy, + .maxattr = ARRAY_SIZE(ctrl_policy_policy) - 1, + .start = ctrl_dumppolicy_start, + .dumpit = ctrl_dumppolicy, + .done = ctrl_dumppolicy_done, + }, +}; + +static const struct genl_multicast_group genl_ctrl_groups[] = { + { .name = "notify", }, +}; + +static struct genl_family genl_ctrl __ro_after_init = { + .module = THIS_MODULE, + .ops = genl_ctrl_ops, + .n_ops = ARRAY_SIZE(genl_ctrl_ops), + .resv_start_op = CTRL_CMD_GETPOLICY + 1, + .mcgrps = genl_ctrl_groups, + .n_mcgrps = ARRAY_SIZE(genl_ctrl_groups), + .id = GENL_ID_CTRL, + .name = "nlctrl", + .version = 0x2, + .netnsok = true, +}; + +static int genl_bind(struct net *net, int group) +{ + const struct genl_family *family; + unsigned int id; + int ret = 0; + + down_read(&cb_lock); + + idr_for_each_entry(&genl_fam_idr, family, id) { + const struct genl_multicast_group *grp; + int i; + + if (family->n_mcgrps == 0) + continue; + + i = group - family->mcgrp_offset; + if (i < 0 || i >= family->n_mcgrps) + continue; + + grp = &family->mcgrps[i]; + if ((grp->flags & GENL_UNS_ADMIN_PERM) && + !ns_capable(net->user_ns, CAP_NET_ADMIN)) + ret = -EPERM; + if (grp->cap_sys_admin && + !ns_capable(net->user_ns, CAP_SYS_ADMIN)) + ret = -EPERM; + + break; + } + + up_read(&cb_lock); + return ret; +} + +static int __net_init genl_pernet_init(struct net *net) +{ + struct netlink_kernel_cfg cfg = { + .input = genl_rcv, + .flags = NL_CFG_F_NONROOT_RECV, + .bind = genl_bind, + }; + + /* we'll bump the group number right afterwards */ + net->genl_sock = netlink_kernel_create(net, NETLINK_GENERIC, &cfg); + + if (!net->genl_sock && net_eq(net, &init_net)) + panic("GENL: Cannot initialize generic netlink\n"); + + if (!net->genl_sock) + return -ENOMEM; + + return 0; +} + +static void __net_exit genl_pernet_exit(struct net *net) +{ + netlink_kernel_release(net->genl_sock); + net->genl_sock = NULL; +} + +static struct pernet_operations genl_pernet_ops = { + .init = genl_pernet_init, + .exit = genl_pernet_exit, +}; + +static int __init genl_init(void) +{ + int err; + + err = genl_register_family(&genl_ctrl); + if (err < 0) + goto problem; + + err = register_pernet_subsys(&genl_pernet_ops); + if (err) + goto problem; + + return 0; + +problem: + 
panic("GENL: Cannot register controller: %d\n", err); +} + +core_initcall(genl_init); + +static int genlmsg_mcast(struct sk_buff *skb, u32 portid, unsigned long group, + gfp_t flags) +{ + struct sk_buff *tmp; + struct net *net, *prev = NULL; + bool delivered = false; + int err; + + for_each_net_rcu(net) { + if (prev) { + tmp = skb_clone(skb, flags); + if (!tmp) { + err = -ENOMEM; + goto error; + } + err = nlmsg_multicast(prev->genl_sock, tmp, + portid, group, flags); + if (!err) + delivered = true; + else if (err != -ESRCH) + goto error; + } + + prev = net; + } + + err = nlmsg_multicast(prev->genl_sock, skb, portid, group, flags); + if (!err) + delivered = true; + else if (err != -ESRCH) + return err; + return delivered ? 0 : -ESRCH; + error: + kfree_skb(skb); + return err; +} + +int genlmsg_multicast_allns(const struct genl_family *family, + struct sk_buff *skb, u32 portid, + unsigned int group, gfp_t flags) +{ + if (WARN_ON_ONCE(group >= family->n_mcgrps)) + return -EINVAL; + + group = family->mcgrp_offset + group; + return genlmsg_mcast(skb, portid, group, flags); +} +EXPORT_SYMBOL(genlmsg_multicast_allns); + +void genl_notify(const struct genl_family *family, struct sk_buff *skb, + struct genl_info *info, u32 group, gfp_t flags) +{ + struct net *net = genl_info_net(info); + struct sock *sk = net->genl_sock; + + if (WARN_ON_ONCE(group >= family->n_mcgrps)) + return; + + group = family->mcgrp_offset + group; + nlmsg_notify(sk, skb, info->snd_portid, group, + nlmsg_report(info->nlhdr), flags); +} +EXPORT_SYMBOL(genl_notify); diff --git a/net/netlink/policy.c b/net/netlink/policy.c new file mode 100644 index 000000000..87e3de0fd --- /dev/null +++ b/net/netlink/policy.c @@ -0,0 +1,479 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * NETLINK Policy advertisement to userspace + * + * Authors: Johannes Berg + * + * Copyright 2019 Intel Corporation + */ + +#include +#include +#include +#include + +#define INITIAL_POLICIES_ALLOC 10 + +struct netlink_policy_dump_state { + unsigned int policy_idx; + unsigned int attr_idx; + unsigned int n_alloc; + struct { + const struct nla_policy *policy; + unsigned int maxtype; + } policies[]; +}; + +static int add_policy(struct netlink_policy_dump_state **statep, + const struct nla_policy *policy, + unsigned int maxtype) +{ + struct netlink_policy_dump_state *state = *statep; + unsigned int n_alloc, i; + + if (!policy || !maxtype) + return 0; + + for (i = 0; i < state->n_alloc; i++) { + if (state->policies[i].policy == policy && + state->policies[i].maxtype == maxtype) + return 0; + + if (!state->policies[i].policy) { + state->policies[i].policy = policy; + state->policies[i].maxtype = maxtype; + return 0; + } + } + + n_alloc = state->n_alloc + INITIAL_POLICIES_ALLOC; + state = krealloc(state, struct_size(state, policies, n_alloc), + GFP_KERNEL); + if (!state) + return -ENOMEM; + + memset(&state->policies[state->n_alloc], 0, + flex_array_size(state, policies, n_alloc - state->n_alloc)); + + state->policies[state->n_alloc].policy = policy; + state->policies[state->n_alloc].maxtype = maxtype; + state->n_alloc = n_alloc; + *statep = state; + + return 0; +} + +/** + * netlink_policy_dump_get_policy_idx - retrieve policy index + * @state: the policy dump state + * @policy: the policy to find + * @maxtype: the policy's maxattr + * + * Returns: the index of the given policy in the dump state + * + * Call this to find a policy index when you've added multiple and e.g. + * need to tell userspace which command has which policy (by index). 
+ * + * Note: this will WARN and return 0 if the policy isn't found, which + * means it wasn't added in the first place, which would be an + * internal consistency bug. + */ +int netlink_policy_dump_get_policy_idx(struct netlink_policy_dump_state *state, + const struct nla_policy *policy, + unsigned int maxtype) +{ + unsigned int i; + + if (WARN_ON(!policy || !maxtype)) + return 0; + + for (i = 0; i < state->n_alloc; i++) { + if (state->policies[i].policy == policy && + state->policies[i].maxtype == maxtype) + return i; + } + + WARN_ON(1); + return 0; +} + +static struct netlink_policy_dump_state *alloc_state(void) +{ + struct netlink_policy_dump_state *state; + + state = kzalloc(struct_size(state, policies, INITIAL_POLICIES_ALLOC), + GFP_KERNEL); + if (!state) + return ERR_PTR(-ENOMEM); + state->n_alloc = INITIAL_POLICIES_ALLOC; + + return state; +} + +/** + * netlink_policy_dump_add_policy - add a policy to the dump + * @pstate: state to add to, may be reallocated, must be %NULL the first time + * @policy: the new policy to add to the dump + * @maxtype: the new policy's max attr type + * + * Returns: 0 on success, a negative error code otherwise. + * + * Call this to allocate a policy dump state, and to add policies to it. This + * should be called from the dump start() callback. + * + * Note: on failures, any previously allocated state is freed. + */ +int netlink_policy_dump_add_policy(struct netlink_policy_dump_state **pstate, + const struct nla_policy *policy, + unsigned int maxtype) +{ + struct netlink_policy_dump_state *state = *pstate; + unsigned int policy_idx; + int err; + + if (!state) { + state = alloc_state(); + if (IS_ERR(state)) + return PTR_ERR(state); + } + + /* + * walk the policies and nested ones first, and build + * a linear list of them. + */ + + err = add_policy(&state, policy, maxtype); + if (err) + goto err_try_undo; + + for (policy_idx = 0; + policy_idx < state->n_alloc && state->policies[policy_idx].policy; + policy_idx++) { + const struct nla_policy *policy; + unsigned int type; + + policy = state->policies[policy_idx].policy; + + for (type = 0; + type <= state->policies[policy_idx].maxtype; + type++) { + switch (policy[type].type) { + case NLA_NESTED: + case NLA_NESTED_ARRAY: + err = add_policy(&state, + policy[type].nested_policy, + policy[type].len); + if (err) + goto err_try_undo; + break; + default: + break; + } + } + } + + *pstate = state; + return 0; + +err_try_undo: + /* Try to preserve reasonable unwind semantics - if we're starting from + * scratch clean up fully, otherwise record what we got and caller will. 
+ */ + if (!*pstate) + netlink_policy_dump_free(state); + else + *pstate = state; + return err; +} + +static bool +netlink_policy_dump_finished(struct netlink_policy_dump_state *state) +{ + return state->policy_idx >= state->n_alloc || + !state->policies[state->policy_idx].policy; +} + +/** + * netlink_policy_dump_loop - dumping loop indicator + * @state: the policy dump state + * + * Returns: %true if the dump continues, %false otherwise + * + * Note: this frees the dump state when finishing + */ +bool netlink_policy_dump_loop(struct netlink_policy_dump_state *state) +{ + return !netlink_policy_dump_finished(state); +} + +int netlink_policy_dump_attr_size_estimate(const struct nla_policy *pt) +{ + /* nested + type */ + int common = 2 * nla_attr_size(sizeof(u32)); + + switch (pt->type) { + case NLA_UNSPEC: + case NLA_REJECT: + /* these actually don't need any space */ + return 0; + case NLA_NESTED: + case NLA_NESTED_ARRAY: + /* common, policy idx, policy maxattr */ + return common + 2 * nla_attr_size(sizeof(u32)); + case NLA_U8: + case NLA_U16: + case NLA_U32: + case NLA_U64: + case NLA_MSECS: + case NLA_S8: + case NLA_S16: + case NLA_S32: + case NLA_S64: + /* maximum is common, u64 min/max with padding */ + return common + + 2 * (nla_attr_size(0) + nla_attr_size(sizeof(u64))); + case NLA_BITFIELD32: + return common + nla_attr_size(sizeof(u32)); + case NLA_STRING: + case NLA_NUL_STRING: + case NLA_BINARY: + /* maximum is common, u32 min-length/max-length */ + return common + 2 * nla_attr_size(sizeof(u32)); + case NLA_FLAG: + return common; + } + + /* this should then cause a warning later */ + return 0; +} + +static int +__netlink_policy_dump_write_attr(struct netlink_policy_dump_state *state, + struct sk_buff *skb, + const struct nla_policy *pt, + int nestattr) +{ + int estimate = netlink_policy_dump_attr_size_estimate(pt); + enum netlink_attribute_type type; + struct nlattr *attr; + + attr = nla_nest_start(skb, nestattr); + if (!attr) + return -ENOBUFS; + + switch (pt->type) { + default: + case NLA_UNSPEC: + case NLA_REJECT: + /* skip - use NLA_MIN_LEN to advertise such */ + nla_nest_cancel(skb, attr); + return -ENODATA; + case NLA_NESTED: + type = NL_ATTR_TYPE_NESTED; + fallthrough; + case NLA_NESTED_ARRAY: + if (pt->type == NLA_NESTED_ARRAY) + type = NL_ATTR_TYPE_NESTED_ARRAY; + if (state && pt->nested_policy && pt->len && + (nla_put_u32(skb, NL_POLICY_TYPE_ATTR_POLICY_IDX, + netlink_policy_dump_get_policy_idx(state, + pt->nested_policy, + pt->len)) || + nla_put_u32(skb, NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE, + pt->len))) + goto nla_put_failure; + break; + case NLA_U8: + case NLA_U16: + case NLA_U32: + case NLA_U64: + case NLA_MSECS: { + struct netlink_range_validation range; + + if (pt->type == NLA_U8) + type = NL_ATTR_TYPE_U8; + else if (pt->type == NLA_U16) + type = NL_ATTR_TYPE_U16; + else if (pt->type == NLA_U32) + type = NL_ATTR_TYPE_U32; + else + type = NL_ATTR_TYPE_U64; + + if (pt->validation_type == NLA_VALIDATE_MASK) { + if (nla_put_u64_64bit(skb, NL_POLICY_TYPE_ATTR_MASK, + pt->mask, + NL_POLICY_TYPE_ATTR_PAD)) + goto nla_put_failure; + break; + } + + nla_get_range_unsigned(pt, &range); + + if (nla_put_u64_64bit(skb, NL_POLICY_TYPE_ATTR_MIN_VALUE_U, + range.min, NL_POLICY_TYPE_ATTR_PAD) || + nla_put_u64_64bit(skb, NL_POLICY_TYPE_ATTR_MAX_VALUE_U, + range.max, NL_POLICY_TYPE_ATTR_PAD)) + goto nla_put_failure; + break; + } + case NLA_S8: + case NLA_S16: + case NLA_S32: + case NLA_S64: { + struct netlink_range_validation_signed range; + + if (pt->type == NLA_S8) + type = 
NL_ATTR_TYPE_S8; + else if (pt->type == NLA_S16) + type = NL_ATTR_TYPE_S16; + else if (pt->type == NLA_S32) + type = NL_ATTR_TYPE_S32; + else + type = NL_ATTR_TYPE_S64; + + nla_get_range_signed(pt, &range); + + if (nla_put_s64(skb, NL_POLICY_TYPE_ATTR_MIN_VALUE_S, + range.min, NL_POLICY_TYPE_ATTR_PAD) || + nla_put_s64(skb, NL_POLICY_TYPE_ATTR_MAX_VALUE_S, + range.max, NL_POLICY_TYPE_ATTR_PAD)) + goto nla_put_failure; + break; + } + case NLA_BITFIELD32: + type = NL_ATTR_TYPE_BITFIELD32; + if (nla_put_u32(skb, NL_POLICY_TYPE_ATTR_BITFIELD32_MASK, + pt->bitfield32_valid)) + goto nla_put_failure; + break; + case NLA_STRING: + case NLA_NUL_STRING: + case NLA_BINARY: + if (pt->type == NLA_STRING) + type = NL_ATTR_TYPE_STRING; + else if (pt->type == NLA_NUL_STRING) + type = NL_ATTR_TYPE_NUL_STRING; + else + type = NL_ATTR_TYPE_BINARY; + + if (pt->validation_type == NLA_VALIDATE_RANGE || + pt->validation_type == NLA_VALIDATE_RANGE_WARN_TOO_LONG) { + struct netlink_range_validation range; + + nla_get_range_unsigned(pt, &range); + + if (range.min && + nla_put_u32(skb, NL_POLICY_TYPE_ATTR_MIN_LENGTH, + range.min)) + goto nla_put_failure; + + if (range.max < U16_MAX && + nla_put_u32(skb, NL_POLICY_TYPE_ATTR_MAX_LENGTH, + range.max)) + goto nla_put_failure; + } else if (pt->len && + nla_put_u32(skb, NL_POLICY_TYPE_ATTR_MAX_LENGTH, + pt->len)) { + goto nla_put_failure; + } + break; + case NLA_FLAG: + type = NL_ATTR_TYPE_FLAG; + break; + } + + if (nla_put_u32(skb, NL_POLICY_TYPE_ATTR_TYPE, type)) + goto nla_put_failure; + + nla_nest_end(skb, attr); + WARN_ON(attr->nla_len > estimate); + + return 0; +nla_put_failure: + nla_nest_cancel(skb, attr); + return -ENOBUFS; +} + +/** + * netlink_policy_dump_write_attr - write a given attribute policy + * @skb: the message skb to write to + * @pt: the attribute's policy + * @nestattr: the nested attribute ID to use + * + * Returns: 0 on success, an error code otherwise; -%ENODATA is + * special, indicating that there's no policy data and + * the attribute is generally rejected. 
+ */ +int netlink_policy_dump_write_attr(struct sk_buff *skb, + const struct nla_policy *pt, + int nestattr) +{ + return __netlink_policy_dump_write_attr(NULL, skb, pt, nestattr); +} + +/** + * netlink_policy_dump_write - write current policy dump attributes + * @skb: the message skb to write to + * @state: the policy dump state + * + * Returns: 0 on success, an error code otherwise + */ +int netlink_policy_dump_write(struct sk_buff *skb, + struct netlink_policy_dump_state *state) +{ + const struct nla_policy *pt; + struct nlattr *policy; + bool again; + int err; + +send_attribute: + again = false; + + pt = &state->policies[state->policy_idx].policy[state->attr_idx]; + + policy = nla_nest_start(skb, state->policy_idx); + if (!policy) + return -ENOBUFS; + + err = __netlink_policy_dump_write_attr(state, skb, pt, state->attr_idx); + if (err == -ENODATA) { + nla_nest_cancel(skb, policy); + again = true; + goto next; + } else if (err) { + goto nla_put_failure; + } + + /* finish and move state to next attribute */ + nla_nest_end(skb, policy); + +next: + state->attr_idx += 1; + if (state->attr_idx > state->policies[state->policy_idx].maxtype) { + state->attr_idx = 0; + state->policy_idx++; + } + + if (again) { + if (netlink_policy_dump_finished(state)) + return -ENODATA; + goto send_attribute; + } + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, policy); + return -ENOBUFS; +} + +/** + * netlink_policy_dump_free - free policy dump state + * @state: the policy dump state to free + * + * Call this from the done() method to ensure dump state is freed. + */ +void netlink_policy_dump_free(struct netlink_policy_dump_state *state) +{ + kfree(state); +} diff --git a/net/netrom/Makefile b/net/netrom/Makefile new file mode 100644 index 000000000..603e36c9a --- /dev/null +++ b/net/netrom/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux NET/ROM layer. 
+# + +obj-$(CONFIG_NETROM) += netrom.o + +netrom-y := af_netrom.o nr_dev.o nr_in.o nr_loopback.o \ + nr_out.o nr_route.o nr_subr.o nr_timer.o +netrom-$(CONFIG_SYSCTL) += sysctl_net_netrom.o diff --git a/net/netrom/af_netrom.c b/net/netrom/af_netrom.c new file mode 100644 index 000000000..ec5747969 --- /dev/null +++ b/net/netrom/af_netrom.c @@ -0,0 +1,1537 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright Darryl Miles G7LED (dlm@g7led.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int nr_ndevs = 4; + +int sysctl_netrom_default_path_quality = NR_DEFAULT_QUAL; +int sysctl_netrom_obsolescence_count_initialiser = NR_DEFAULT_OBS; +int sysctl_netrom_network_ttl_initialiser = NR_DEFAULT_TTL; +int sysctl_netrom_transport_timeout = NR_DEFAULT_T1; +int sysctl_netrom_transport_maximum_tries = NR_DEFAULT_N2; +int sysctl_netrom_transport_acknowledge_delay = NR_DEFAULT_T2; +int sysctl_netrom_transport_busy_delay = NR_DEFAULT_T4; +int sysctl_netrom_transport_requested_window_size = NR_DEFAULT_WINDOW; +int sysctl_netrom_transport_no_activity_timeout = NR_DEFAULT_IDLE; +int sysctl_netrom_routing_control = NR_DEFAULT_ROUTING; +int sysctl_netrom_link_fails_count = NR_DEFAULT_FAILS; +int sysctl_netrom_reset_circuit = NR_DEFAULT_RESET; + +static unsigned short circuit = 0x101; + +static HLIST_HEAD(nr_list); +static DEFINE_SPINLOCK(nr_list_lock); + +static const struct proto_ops nr_proto_ops; + +/* + * NETROM network devices are virtual network devices encapsulating NETROM + * frames into AX.25 which will be sent through an AX.25 device, so form a + * special "super class" of normal net devices; split their locks off into a + * separate class since they always nest. + */ +static struct lock_class_key nr_netdev_xmit_lock_key; +static struct lock_class_key nr_netdev_addr_lock_key; + +static void nr_set_lockdep_one(struct net_device *dev, + struct netdev_queue *txq, + void *_unused) +{ + lockdep_set_class(&txq->_xmit_lock, &nr_netdev_xmit_lock_key); +} + +static void nr_set_lockdep_key(struct net_device *dev) +{ + lockdep_set_class(&dev->addr_list_lock, &nr_netdev_addr_lock_key); + netdev_for_each_tx_queue(dev, nr_set_lockdep_one, NULL); +} + +/* + * Socket removal during an interrupt is now safe. + */ +static void nr_remove_socket(struct sock *sk) +{ + spin_lock_bh(&nr_list_lock); + sk_del_node_init(sk); + spin_unlock_bh(&nr_list_lock); +} + +/* + * Kill all bound sockets on a dropped device. + */ +static void nr_kill_by_device(struct net_device *dev) +{ + struct sock *s; + + spin_lock_bh(&nr_list_lock); + sk_for_each(s, &nr_list) + if (nr_sk(s)->device == dev) + nr_disconnect(s, ENETUNREACH); + spin_unlock_bh(&nr_list_lock); +} + +/* + * Handle device status changes. 
+ */ +static int nr_device_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (event != NETDEV_DOWN) + return NOTIFY_DONE; + + nr_kill_by_device(dev); + nr_rt_device_down(dev); + + return NOTIFY_DONE; +} + +/* + * Add a socket to the bound sockets list. + */ +static void nr_insert_socket(struct sock *sk) +{ + spin_lock_bh(&nr_list_lock); + sk_add_node(sk, &nr_list); + spin_unlock_bh(&nr_list_lock); +} + +/* + * Find a socket that wants to accept the Connect Request we just + * received. + */ +static struct sock *nr_find_listener(ax25_address *addr) +{ + struct sock *s; + + spin_lock_bh(&nr_list_lock); + sk_for_each(s, &nr_list) + if (!ax25cmp(&nr_sk(s)->source_addr, addr) && + s->sk_state == TCP_LISTEN) { + sock_hold(s); + goto found; + } + s = NULL; +found: + spin_unlock_bh(&nr_list_lock); + return s; +} + +/* + * Find a connected NET/ROM socket given my circuit IDs. + */ +static struct sock *nr_find_socket(unsigned char index, unsigned char id) +{ + struct sock *s; + + spin_lock_bh(&nr_list_lock); + sk_for_each(s, &nr_list) { + struct nr_sock *nr = nr_sk(s); + + if (nr->my_index == index && nr->my_id == id) { + sock_hold(s); + goto found; + } + } + s = NULL; +found: + spin_unlock_bh(&nr_list_lock); + return s; +} + +/* + * Find a connected NET/ROM socket given their circuit IDs. + */ +static struct sock *nr_find_peer(unsigned char index, unsigned char id, + ax25_address *dest) +{ + struct sock *s; + + spin_lock_bh(&nr_list_lock); + sk_for_each(s, &nr_list) { + struct nr_sock *nr = nr_sk(s); + + if (nr->your_index == index && nr->your_id == id && + !ax25cmp(&nr->dest_addr, dest)) { + sock_hold(s); + goto found; + } + } + s = NULL; +found: + spin_unlock_bh(&nr_list_lock); + return s; +} + +/* + * Find next free circuit ID. + */ +static unsigned short nr_find_next_circuit(void) +{ + unsigned short id = circuit; + unsigned char i, j; + struct sock *sk; + + for (;;) { + i = id / 256; + j = id % 256; + + if (i != 0 && j != 0) { + if ((sk=nr_find_socket(i, j)) == NULL) + break; + sock_put(sk); + } + + id++; + } + + return id; +} + +/* + * Deferred destroy. + */ +void nr_destroy_socket(struct sock *); + +/* + * Handler for deferred kills. + */ +static void nr_destroy_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + bh_lock_sock(sk); + sock_hold(sk); + nr_destroy_socket(sk); + bh_unlock_sock(sk); + sock_put(sk); +} + +/* + * This is called from user mode and the timers. Thus it protects itself + * against interrupt users but doesn't worry about being called during + * work. Once it is removed from the queue no interrupt or bottom half + * will touch it and we are (fairly 8-) ) safe. 
+ */ +void nr_destroy_socket(struct sock *sk) +{ + struct sk_buff *skb; + + nr_remove_socket(sk); + + nr_stop_heartbeat(sk); + nr_stop_t1timer(sk); + nr_stop_t2timer(sk); + nr_stop_t4timer(sk); + nr_stop_idletimer(sk); + + nr_clear_queues(sk); /* Flush the queues */ + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + if (skb->sk != sk) { /* A pending connection */ + /* Queue the unaccepted socket for death */ + sock_set_flag(skb->sk, SOCK_DEAD); + nr_start_heartbeat(skb->sk); + nr_sk(skb->sk)->state = NR_STATE_0; + } + + kfree_skb(skb); + } + + if (sk_has_allocations(sk)) { + /* Defer: outstanding buffers */ + sk->sk_timer.function = nr_destroy_timer; + sk->sk_timer.expires = jiffies + 2 * HZ; + add_timer(&sk->sk_timer); + } else + sock_put(sk); +} + +/* + * Handling for system calls applied via the various interfaces to a + * NET/ROM socket object. + */ + +static int nr_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + unsigned int opt; + + if (level != SOL_NETROM) + return -ENOPROTOOPT; + + if (optlen < sizeof(unsigned int)) + return -EINVAL; + + if (copy_from_sockptr(&opt, optval, sizeof(opt))) + return -EFAULT; + + switch (optname) { + case NETROM_T1: + if (opt < 1 || opt > UINT_MAX / HZ) + return -EINVAL; + nr->t1 = opt * HZ; + return 0; + + case NETROM_T2: + if (opt < 1 || opt > UINT_MAX / HZ) + return -EINVAL; + nr->t2 = opt * HZ; + return 0; + + case NETROM_N2: + if (opt < 1 || opt > 31) + return -EINVAL; + nr->n2 = opt; + return 0; + + case NETROM_T4: + if (opt < 1 || opt > UINT_MAX / HZ) + return -EINVAL; + nr->t4 = opt * HZ; + return 0; + + case NETROM_IDLE: + if (opt > UINT_MAX / (60 * HZ)) + return -EINVAL; + nr->idle = opt * 60 * HZ; + return 0; + + default: + return -ENOPROTOOPT; + } +} + +static int nr_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + int val = 0; + int len; + + if (level != SOL_NETROM) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < 0) + return -EINVAL; + + switch (optname) { + case NETROM_T1: + val = nr->t1 / HZ; + break; + + case NETROM_T2: + val = nr->t2 / HZ; + break; + + case NETROM_N2: + val = nr->n2; + break; + + case NETROM_T4: + val = nr->t4 / HZ; + break; + + case NETROM_IDLE: + val = nr->idle / (60 * HZ); + break; + + default: + return -ENOPROTOOPT; + } + + len = min_t(unsigned int, len, sizeof(int)); + + if (put_user(len, optlen)) + return -EFAULT; + + return copy_to_user(optval, &val, len) ? 
-EFAULT : 0; +} + +static int nr_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + + lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + release_sock(sk); + return -EINVAL; + } + + if (sk->sk_state != TCP_LISTEN) { + memset(&nr_sk(sk)->user_addr, 0, AX25_ADDR_LEN); + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + release_sock(sk); + return 0; + } + release_sock(sk); + + return -EOPNOTSUPP; +} + +static struct proto nr_proto = { + .name = "NETROM", + .owner = THIS_MODULE, + .obj_size = sizeof(struct nr_sock), +}; + +static int nr_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct nr_sock *nr; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + if (sock->type != SOCK_SEQPACKET || protocol != 0) + return -ESOCKTNOSUPPORT; + + sk = sk_alloc(net, PF_NETROM, GFP_ATOMIC, &nr_proto, kern); + if (sk == NULL) + return -ENOMEM; + + nr = nr_sk(sk); + + sock_init_data(sock, sk); + + sock->ops = &nr_proto_ops; + sk->sk_protocol = protocol; + + skb_queue_head_init(&nr->ack_queue); + skb_queue_head_init(&nr->reseq_queue); + skb_queue_head_init(&nr->frag_queue); + + nr_init_timers(sk); + + nr->t1 = + msecs_to_jiffies(sysctl_netrom_transport_timeout); + nr->t2 = + msecs_to_jiffies(sysctl_netrom_transport_acknowledge_delay); + nr->n2 = + msecs_to_jiffies(sysctl_netrom_transport_maximum_tries); + nr->t4 = + msecs_to_jiffies(sysctl_netrom_transport_busy_delay); + nr->idle = + msecs_to_jiffies(sysctl_netrom_transport_no_activity_timeout); + nr->window = sysctl_netrom_transport_requested_window_size; + + nr->bpqext = 1; + nr->state = NR_STATE_0; + + return 0; +} + +static struct sock *nr_make_new(struct sock *osk) +{ + struct sock *sk; + struct nr_sock *nr, *onr; + + if (osk->sk_type != SOCK_SEQPACKET) + return NULL; + + sk = sk_alloc(sock_net(osk), PF_NETROM, GFP_ATOMIC, osk->sk_prot, 0); + if (sk == NULL) + return NULL; + + nr = nr_sk(sk); + + sock_init_data(NULL, sk); + + sk->sk_type = osk->sk_type; + sk->sk_priority = osk->sk_priority; + sk->sk_protocol = osk->sk_protocol; + sk->sk_rcvbuf = osk->sk_rcvbuf; + sk->sk_sndbuf = osk->sk_sndbuf; + sk->sk_state = TCP_ESTABLISHED; + sock_copy_flags(sk, osk); + + skb_queue_head_init(&nr->ack_queue); + skb_queue_head_init(&nr->reseq_queue); + skb_queue_head_init(&nr->frag_queue); + + nr_init_timers(sk); + + onr = nr_sk(osk); + + nr->t1 = onr->t1; + nr->t2 = onr->t2; + nr->n2 = onr->n2; + nr->t4 = onr->t4; + nr->idle = onr->idle; + nr->window = onr->window; + + nr->device = onr->device; + nr->bpqext = onr->bpqext; + + return sk; +} + +static int nr_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr; + + if (sk == NULL) return 0; + + sock_hold(sk); + sock_orphan(sk); + lock_sock(sk); + nr = nr_sk(sk); + + switch (nr->state) { + case NR_STATE_0: + case NR_STATE_1: + case NR_STATE_2: + nr_disconnect(sk, 0); + nr_destroy_socket(sk); + break; + + case NR_STATE_3: + nr_clear_queues(sk); + nr->n2count = 0; + nr_write_internal(sk, NR_DISCREQ); + nr_start_t1timer(sk); + nr_stop_t2timer(sk); + nr_stop_t4timer(sk); + nr_stop_idletimer(sk); + nr->state = NR_STATE_2; + sk->sk_state = TCP_CLOSE; + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DESTROY); + break; + + default: + break; + } + + sock->sk = NULL; + release_sock(sk); + sock_put(sk); + + return 0; +} + +static int nr_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + 
struct full_sockaddr_ax25 *addr = (struct full_sockaddr_ax25 *)uaddr; + struct net_device *dev; + ax25_uid_assoc *user; + ax25_address *source; + + lock_sock(sk); + if (!sock_flag(sk, SOCK_ZAPPED)) { + release_sock(sk); + return -EINVAL; + } + if (addr_len < sizeof(struct sockaddr_ax25) || addr_len > sizeof(struct full_sockaddr_ax25)) { + release_sock(sk); + return -EINVAL; + } + if (addr_len < (addr->fsa_ax25.sax25_ndigis * sizeof(ax25_address) + sizeof(struct sockaddr_ax25))) { + release_sock(sk); + return -EINVAL; + } + if (addr->fsa_ax25.sax25_family != AF_NETROM) { + release_sock(sk); + return -EINVAL; + } + if ((dev = nr_dev_get(&addr->fsa_ax25.sax25_call)) == NULL) { + release_sock(sk); + return -EADDRNOTAVAIL; + } + + /* + * Only the super user can set an arbitrary user callsign. + */ + if (addr->fsa_ax25.sax25_ndigis == 1) { + if (!capable(CAP_NET_BIND_SERVICE)) { + dev_put(dev); + release_sock(sk); + return -EPERM; + } + nr->user_addr = addr->fsa_digipeater[0]; + nr->source_addr = addr->fsa_ax25.sax25_call; + } else { + source = &addr->fsa_ax25.sax25_call; + + user = ax25_findbyuid(current_euid()); + if (user) { + nr->user_addr = user->call; + ax25_uid_put(user); + } else { + if (ax25_uid_policy && !capable(CAP_NET_BIND_SERVICE)) { + release_sock(sk); + dev_put(dev); + return -EPERM; + } + nr->user_addr = *source; + } + + nr->source_addr = *source; + } + + nr->device = dev; + nr_insert_socket(sk); + + sock_reset_flag(sk, SOCK_ZAPPED); + dev_put(dev); + release_sock(sk); + + return 0; +} + +static int nr_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + struct sockaddr_ax25 *addr = (struct sockaddr_ax25 *)uaddr; + const ax25_address *source = NULL; + ax25_uid_assoc *user; + struct net_device *dev; + int err = 0; + + lock_sock(sk); + if (sk->sk_state == TCP_ESTABLISHED && sock->state == SS_CONNECTING) { + sock->state = SS_CONNECTED; + goto out_release; /* Connect completed during a ERESTARTSYS event */ + } + + if (sk->sk_state == TCP_CLOSE && sock->state == SS_CONNECTING) { + sock->state = SS_UNCONNECTED; + err = -ECONNREFUSED; + goto out_release; + } + + if (sk->sk_state == TCP_ESTABLISHED) { + err = -EISCONN; /* No reconnect on a seqpacket socket */ + goto out_release; + } + + if (sock->state == SS_CONNECTING) { + err = -EALREADY; + goto out_release; + } + + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + if (addr_len != sizeof(struct sockaddr_ax25) && addr_len != sizeof(struct full_sockaddr_ax25)) { + err = -EINVAL; + goto out_release; + } + if (addr->sax25_family != AF_NETROM) { + err = -EINVAL; + goto out_release; + } + if (sock_flag(sk, SOCK_ZAPPED)) { /* Must bind first - autobinding in this may or may not work */ + sock_reset_flag(sk, SOCK_ZAPPED); + + if ((dev = nr_dev_first()) == NULL) { + err = -ENETUNREACH; + goto out_release; + } + source = (const ax25_address *)dev->dev_addr; + + user = ax25_findbyuid(current_euid()); + if (user) { + nr->user_addr = user->call; + ax25_uid_put(user); + } else { + if (ax25_uid_policy && !capable(CAP_NET_ADMIN)) { + dev_put(dev); + err = -EPERM; + goto out_release; + } + nr->user_addr = *source; + } + + nr->source_addr = *source; + nr->device = dev; + + dev_put(dev); + nr_insert_socket(sk); /* Finish the bind */ + } + + nr->dest_addr = addr->sax25_call; + + release_sock(sk); + circuit = nr_find_next_circuit(); + lock_sock(sk); + + nr->my_index = circuit / 256; + nr->my_id = circuit % 256; + + circuit++; + + /* Move to 
connecting socket, start sending Connect Requests */ + sock->state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; + + nr_establish_data_link(sk); + + nr->state = NR_STATE_1; + + nr_start_heartbeat(sk); + + /* Now the loop */ + if (sk->sk_state != TCP_ESTABLISHED && (flags & O_NONBLOCK)) { + err = -EINPROGRESS; + goto out_release; + } + + /* + * A Connect Ack with Choke or timeout or failed routing will go to + * closed. + */ + if (sk->sk_state == TCP_SYN_SENT) { + DEFINE_WAIT(wait); + + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + if (sk->sk_state != TCP_SYN_SENT) + break; + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + if (err) + goto out_release; + } + + if (sk->sk_state != TCP_ESTABLISHED) { + sock->state = SS_UNCONNECTED; + err = sock_error(sk); /* Always set at this point */ + goto out_release; + } + + sock->state = SS_CONNECTED; + +out_release: + release_sock(sk); + + return err; +} + +static int nr_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sk_buff *skb; + struct sock *newsk; + DEFINE_WAIT(wait); + struct sock *sk; + int err = 0; + + if ((sk = sock->sk) == NULL) + return -EINVAL; + + lock_sock(sk); + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EOPNOTSUPP; + goto out_release; + } + + if (sk->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto out_release; + } + + /* + * The write queue this time is holding sockets ready to use + * hooked into the SABM we saved + */ + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + skb = skb_dequeue(&sk->sk_receive_queue); + if (skb) + break; + + if (flags & O_NONBLOCK) { + err = -EWOULDBLOCK; + break; + } + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + if (err) + goto out_release; + + newsk = skb->sk; + sock_graft(newsk, newsock); + + /* Now attach up the new socket */ + kfree_skb(skb); + sk_acceptq_removed(sk); + +out_release: + release_sock(sk); + + return err; +} + +static int nr_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct full_sockaddr_ax25 *sax = (struct full_sockaddr_ax25 *)uaddr; + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + int uaddr_len; + + memset(&sax->fsa_ax25, 0, sizeof(struct sockaddr_ax25)); + + lock_sock(sk); + if (peer != 0) { + if (sk->sk_state != TCP_ESTABLISHED) { + release_sock(sk); + return -ENOTCONN; + } + sax->fsa_ax25.sax25_family = AF_NETROM; + sax->fsa_ax25.sax25_ndigis = 1; + sax->fsa_ax25.sax25_call = nr->user_addr; + memset(sax->fsa_digipeater, 0, sizeof(sax->fsa_digipeater)); + sax->fsa_digipeater[0] = nr->dest_addr; + uaddr_len = sizeof(struct full_sockaddr_ax25); + } else { + sax->fsa_ax25.sax25_family = AF_NETROM; + sax->fsa_ax25.sax25_ndigis = 0; + sax->fsa_ax25.sax25_call = nr->source_addr; + uaddr_len = sizeof(struct sockaddr_ax25); + } + release_sock(sk); + + return uaddr_len; +} + +int nr_rx_frame(struct sk_buff *skb, struct net_device *dev) +{ + struct sock *sk; + struct sock *make; + struct nr_sock *nr_make; + ax25_address *src, *dest, *user; + unsigned short circuit_index, circuit_id; + unsigned short peer_circuit_index, peer_circuit_id; + unsigned short frametype, flags, window, timeout; + int ret; + + skb_orphan(skb); + + /* + * skb->data points to the netrom frame start + */ + + src = (ax25_address 
*)(skb->data + 0); + dest = (ax25_address *)(skb->data + 7); + + circuit_index = skb->data[15]; + circuit_id = skb->data[16]; + peer_circuit_index = skb->data[17]; + peer_circuit_id = skb->data[18]; + frametype = skb->data[19] & 0x0F; + flags = skb->data[19] & 0xF0; + + /* + * Check for an incoming IP over NET/ROM frame. + */ + if (frametype == NR_PROTOEXT && + circuit_index == NR_PROTO_IP && circuit_id == NR_PROTO_IP) { + skb_pull(skb, NR_NETWORK_LEN + NR_TRANSPORT_LEN); + skb_reset_transport_header(skb); + + return nr_rx_ip(skb, dev); + } + + /* + * Find an existing socket connection, based on circuit ID, if it's + * a Connect Request base it on their circuit ID. + * + * Circuit ID 0/0 is not valid but it could still be a "reset" for a + * circuit that no longer exists at the other end ... + */ + + sk = NULL; + + if (circuit_index == 0 && circuit_id == 0) { + if (frametype == NR_CONNACK && flags == NR_CHOKE_FLAG) + sk = nr_find_peer(peer_circuit_index, peer_circuit_id, src); + } else { + if (frametype == NR_CONNREQ) + sk = nr_find_peer(circuit_index, circuit_id, src); + else + sk = nr_find_socket(circuit_index, circuit_id); + } + + if (sk != NULL) { + bh_lock_sock(sk); + skb_reset_transport_header(skb); + + if (frametype == NR_CONNACK && skb->len == 22) + nr_sk(sk)->bpqext = 1; + else + nr_sk(sk)->bpqext = 0; + + ret = nr_process_rx_frame(sk, skb); + bh_unlock_sock(sk); + sock_put(sk); + return ret; + } + + /* + * Now it should be a CONNREQ. + */ + if (frametype != NR_CONNREQ) { + /* + * Here it would be nice to be able to send a reset but + * NET/ROM doesn't have one. We've tried to extend the protocol + * by sending NR_CONNACK | NR_CHOKE_FLAGS replies but that + * apparently kills BPQ boxes... :-( + * So now we try to follow the established behaviour of + * G8PZT's Xrouter which is sending packets with command type 7 + * as an extension of the protocol. 
+ */ + if (sysctl_netrom_reset_circuit && + (frametype != NR_RESET || flags != 0)) + nr_transmit_reset(skb, 1); + + return 0; + } + + sk = nr_find_listener(dest); + + user = (ax25_address *)(skb->data + 21); + + if (sk == NULL || sk_acceptq_is_full(sk) || + (make = nr_make_new(sk)) == NULL) { + nr_transmit_refusal(skb, 0); + if (sk) + sock_put(sk); + return 0; + } + + bh_lock_sock(sk); + + window = skb->data[20]; + + sock_hold(make); + skb->sk = make; + skb->destructor = sock_efree; + make->sk_state = TCP_ESTABLISHED; + + /* Fill in his circuit details */ + nr_make = nr_sk(make); + nr_make->source_addr = *dest; + nr_make->dest_addr = *src; + nr_make->user_addr = *user; + + nr_make->your_index = circuit_index; + nr_make->your_id = circuit_id; + + bh_unlock_sock(sk); + circuit = nr_find_next_circuit(); + bh_lock_sock(sk); + + nr_make->my_index = circuit / 256; + nr_make->my_id = circuit % 256; + + circuit++; + + /* Window negotiation */ + if (window < nr_make->window) + nr_make->window = window; + + /* L4 timeout negotiation */ + if (skb->len == 37) { + timeout = skb->data[36] * 256 + skb->data[35]; + if (timeout * HZ < nr_make->t1) + nr_make->t1 = timeout * HZ; + nr_make->bpqext = 1; + } else { + nr_make->bpqext = 0; + } + + nr_write_internal(make, NR_CONNACK); + + nr_make->condition = 0x00; + nr_make->vs = 0; + nr_make->va = 0; + nr_make->vr = 0; + nr_make->vl = 0; + nr_make->state = NR_STATE_3; + sk_acceptq_added(sk); + skb_queue_head(&sk->sk_receive_queue, skb); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + + bh_unlock_sock(sk); + sock_put(sk); + + nr_insert_socket(make); + + nr_start_heartbeat(make); + nr_start_idletimer(make); + + return 1; +} + +static int nr_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct nr_sock *nr = nr_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_ax25 *, usax, msg->msg_name); + int err; + struct sockaddr_ax25 sax; + struct sk_buff *skb; + unsigned char *asmptr; + int size; + + if (msg->msg_flags & ~(MSG_DONTWAIT|MSG_EOR|MSG_CMSG_COMPAT)) + return -EINVAL; + + lock_sock(sk); + if (sock_flag(sk, SOCK_ZAPPED)) { + err = -EADDRNOTAVAIL; + goto out; + } + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + send_sig(SIGPIPE, current, 0); + err = -EPIPE; + goto out; + } + + if (nr->device == NULL) { + err = -ENETUNREACH; + goto out; + } + + if (usax) { + if (msg->msg_namelen < sizeof(sax)) { + err = -EINVAL; + goto out; + } + sax = *usax; + if (ax25cmp(&nr->dest_addr, &sax.sax25_call) != 0) { + err = -EISCONN; + goto out; + } + if (sax.sax25_family != AF_NETROM) { + err = -EINVAL; + goto out; + } + } else { + if (sk->sk_state != TCP_ESTABLISHED) { + err = -ENOTCONN; + goto out; + } + sax.sax25_family = AF_NETROM; + sax.sax25_call = nr->dest_addr; + } + + /* Build a packet - the conventional user limit is 236 bytes. 
We can + do ludicrously large NetROM frames but must not overflow */ + if (len > 65536) { + err = -EMSGSIZE; + goto out; + } + + size = len + NR_NETWORK_LEN + NR_TRANSPORT_LEN; + + if ((skb = sock_alloc_send_skb(sk, size, msg->msg_flags & MSG_DONTWAIT, &err)) == NULL) + goto out; + + skb_reserve(skb, size - len); + skb_reset_transport_header(skb); + + /* + * Push down the NET/ROM header + */ + + asmptr = skb_push(skb, NR_TRANSPORT_LEN); + + /* Build a NET/ROM Transport header */ + + *asmptr++ = nr->your_index; + *asmptr++ = nr->your_id; + *asmptr++ = 0; /* To be filled in later */ + *asmptr++ = 0; /* Ditto */ + *asmptr++ = NR_INFO; + + /* + * Put the data on the end + */ + skb_put(skb, len); + + /* User data follows immediately after the NET/ROM transport header */ + if (memcpy_from_msg(skb_transport_header(skb), msg, len)) { + kfree_skb(skb); + err = -EFAULT; + goto out; + } + + if (sk->sk_state != TCP_ESTABLISHED) { + kfree_skb(skb); + err = -ENOTCONN; + goto out; + } + + nr_output(sk, skb); /* Shove it onto the queue */ + + err = len; +out: + release_sock(sk); + return err; +} + +static int nr_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + DECLARE_SOCKADDR(struct sockaddr_ax25 *, sax, msg->msg_name); + size_t copied; + struct sk_buff *skb; + int er; + + /* + * This works for seqpacket too. The receiver has ordered the queue for + * us! We do one quick check first though + */ + + lock_sock(sk); + if (sk->sk_state != TCP_ESTABLISHED) { + release_sock(sk); + return -ENOTCONN; + } + + /* Now we can treat all alike */ + skb = skb_recv_datagram(sk, flags, &er); + if (!skb) { + release_sock(sk); + return er; + } + + skb_reset_transport_header(skb); + copied = skb->len; + + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + er = skb_copy_datagram_msg(skb, 0, msg, copied); + if (er < 0) { + skb_free_datagram(sk, skb); + release_sock(sk); + return er; + } + + if (sax != NULL) { + memset(sax, 0, sizeof(*sax)); + sax->sax25_family = AF_NETROM; + skb_copy_from_linear_data_offset(skb, 7, sax->sax25_call.ax25_call, + AX25_ADDR_LEN); + msg->msg_namelen = sizeof(*sax); + } + + skb_free_datagram(sk, skb); + + release_sock(sk); + return copied; +} + + +static int nr_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + void __user *argp = (void __user *)arg; + + switch (cmd) { + case TIOCOUTQ: { + long amount; + + lock_sock(sk); + amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (amount < 0) + amount = 0; + release_sock(sk); + return put_user(amount, (int __user *)argp); + } + + case TIOCINQ: { + struct sk_buff *skb; + long amount = 0L; + + lock_sock(sk); + /* These two are safe on a single CPU system as only user tasks fiddle here */ + if ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) + amount = skb->len; + release_sock(sk); + return put_user(amount, (int __user *)argp); + } + + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + return -EINVAL; + + case SIOCADDRT: + case SIOCDELRT: + case SIOCNRDECOBS: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return nr_rt_ioctl(cmd, argp); + + default: + return -ENOIOCTLCMD; + } + + return 0; +} + +#ifdef CONFIG_PROC_FS + +static void *nr_info_start(struct seq_file *seq, loff_t *pos) + __acquires(&nr_list_lock) +{ + spin_lock_bh(&nr_list_lock); + return 
seq_hlist_start_head(&nr_list, *pos); +} + +static void *nr_info_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &nr_list, pos); +} + +static void nr_info_stop(struct seq_file *seq, void *v) + __releases(&nr_list_lock) +{ + spin_unlock_bh(&nr_list_lock); +} + +static int nr_info_show(struct seq_file *seq, void *v) +{ + struct sock *s = sk_entry(v); + struct net_device *dev; + struct nr_sock *nr; + const char *devname; + char buf[11]; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, +"user_addr dest_node src_node dev my your st vs vr va t1 t2 t4 idle n2 wnd Snd-Q Rcv-Q inode\n"); + + else { + + bh_lock_sock(s); + nr = nr_sk(s); + + if ((dev = nr->device) == NULL) + devname = "???"; + else + devname = dev->name; + + seq_printf(seq, "%-9s ", ax2asc(buf, &nr->user_addr)); + seq_printf(seq, "%-9s ", ax2asc(buf, &nr->dest_addr)); + seq_printf(seq, +"%-9s %-3s %02X/%02X %02X/%02X %2d %3d %3d %3d %3lu/%03lu %2lu/%02lu %3lu/%03lu %3lu/%03lu %2d/%02d %3d %5d %5d %ld\n", + ax2asc(buf, &nr->source_addr), + devname, + nr->my_index, + nr->my_id, + nr->your_index, + nr->your_id, + nr->state, + nr->vs, + nr->vr, + nr->va, + ax25_display_timer(&nr->t1timer) / HZ, + nr->t1 / HZ, + ax25_display_timer(&nr->t2timer) / HZ, + nr->t2 / HZ, + ax25_display_timer(&nr->t4timer) / HZ, + nr->t4 / HZ, + ax25_display_timer(&nr->idletimer) / (60 * HZ), + nr->idle / (60 * HZ), + nr->n2count, + nr->n2, + nr->window, + sk_wmem_alloc_get(s), + sk_rmem_alloc_get(s), + s->sk_socket ? SOCK_INODE(s->sk_socket)->i_ino : 0L); + + bh_unlock_sock(s); + } + return 0; +} + +static const struct seq_operations nr_info_seqops = { + .start = nr_info_start, + .next = nr_info_next, + .stop = nr_info_stop, + .show = nr_info_show, +}; +#endif /* CONFIG_PROC_FS */ + +static const struct net_proto_family nr_family_ops = { + .family = PF_NETROM, + .create = nr_create, + .owner = THIS_MODULE, +}; + +static const struct proto_ops nr_proto_ops = { + .family = PF_NETROM, + .owner = THIS_MODULE, + .release = nr_release, + .bind = nr_bind, + .connect = nr_connect, + .socketpair = sock_no_socketpair, + .accept = nr_accept, + .getname = nr_getname, + .poll = datagram_poll, + .ioctl = nr_ioctl, + .gettstamp = sock_gettstamp, + .listen = nr_listen, + .shutdown = sock_no_shutdown, + .setsockopt = nr_setsockopt, + .getsockopt = nr_getsockopt, + .sendmsg = nr_sendmsg, + .recvmsg = nr_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct notifier_block nr_dev_notifier = { + .notifier_call = nr_device_event, +}; + +static struct net_device **dev_nr; + +static struct ax25_protocol nr_pid = { + .pid = AX25_P_NETROM, + .func = nr_route_frame +}; + +static struct ax25_linkfail nr_linkfail_notifier = { + .func = nr_link_failed, +}; + +static int __init nr_proto_init(void) +{ + int i; + int rc = proto_register(&nr_proto, 0); + + if (rc) + return rc; + + if (nr_ndevs > 0x7fffffff/sizeof(struct net_device *)) { + pr_err("NET/ROM: %s - nr_ndevs parameter too large\n", + __func__); + rc = -EINVAL; + goto unregister_proto; + } + + dev_nr = kcalloc(nr_ndevs, sizeof(struct net_device *), GFP_KERNEL); + if (!dev_nr) { + pr_err("NET/ROM: %s - unable to allocate device array\n", + __func__); + rc = -ENOMEM; + goto unregister_proto; + } + + for (i = 0; i < nr_ndevs; i++) { + char name[IFNAMSIZ]; + struct net_device *dev; + + sprintf(name, "nr%d", i); + dev = alloc_netdev(0, name, NET_NAME_UNKNOWN, nr_setup); + if (!dev) { + rc = -ENOMEM; + goto fail; + } + + dev->base_addr = i; + rc = register_netdev(dev); + if (rc) { 
+ free_netdev(dev); + goto fail; + } + nr_set_lockdep_key(dev); + dev_nr[i] = dev; + } + + rc = sock_register(&nr_family_ops); + if (rc) + goto fail; + + rc = register_netdevice_notifier(&nr_dev_notifier); + if (rc) + goto out_sock; + + ax25_register_pid(&nr_pid); + ax25_linkfail_register(&nr_linkfail_notifier); + +#ifdef CONFIG_SYSCTL + rc = nr_register_sysctl(); + if (rc) + goto out_sysctl; +#endif + + nr_loopback_init(); + + rc = -ENOMEM; + if (!proc_create_seq("nr", 0444, init_net.proc_net, &nr_info_seqops)) + goto proc_remove1; + if (!proc_create_seq("nr_neigh", 0444, init_net.proc_net, + &nr_neigh_seqops)) + goto proc_remove2; + if (!proc_create_seq("nr_nodes", 0444, init_net.proc_net, + &nr_node_seqops)) + goto proc_remove3; + + return 0; + +proc_remove3: + remove_proc_entry("nr_neigh", init_net.proc_net); +proc_remove2: + remove_proc_entry("nr", init_net.proc_net); +proc_remove1: + + nr_loopback_clear(); + nr_rt_free(); + +#ifdef CONFIG_SYSCTL + nr_unregister_sysctl(); +out_sysctl: +#endif + ax25_linkfail_release(&nr_linkfail_notifier); + ax25_protocol_release(AX25_P_NETROM); + unregister_netdevice_notifier(&nr_dev_notifier); +out_sock: + sock_unregister(PF_NETROM); +fail: + while (--i >= 0) { + unregister_netdev(dev_nr[i]); + free_netdev(dev_nr[i]); + } + kfree(dev_nr); +unregister_proto: + proto_unregister(&nr_proto); + return rc; +} + +module_init(nr_proto_init); + +module_param(nr_ndevs, int, 0); +MODULE_PARM_DESC(nr_ndevs, "number of NET/ROM devices"); + +MODULE_AUTHOR("Jonathan Naylor G4KLX "); +MODULE_DESCRIPTION("The amateur radio NET/ROM network and transport layer protocol"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_NETROM); + +static void __exit nr_exit(void) +{ + int i; + + remove_proc_entry("nr", init_net.proc_net); + remove_proc_entry("nr_neigh", init_net.proc_net); + remove_proc_entry("nr_nodes", init_net.proc_net); + nr_loopback_clear(); + + nr_rt_free(); + +#ifdef CONFIG_SYSCTL + nr_unregister_sysctl(); +#endif + + ax25_linkfail_release(&nr_linkfail_notifier); + ax25_protocol_release(AX25_P_NETROM); + + unregister_netdevice_notifier(&nr_dev_notifier); + + sock_unregister(PF_NETROM); + + for (i = 0; i < nr_ndevs; i++) { + struct net_device *dev = dev_nr[i]; + if (dev) { + unregister_netdev(dev); + free_netdev(dev); + } + } + + kfree(dev_nr); + proto_unregister(&nr_proto); +} +module_exit(nr_exit); diff --git a/net/netrom/nr_dev.c b/net/netrom/nr_dev.c new file mode 100644 index 000000000..3aaac4a22 --- /dev/null +++ b/net/netrom/nr_dev.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For the statistics structure. */ +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +/* + * Only allow IP over NET/ROM frames through if the netrom device is up. 
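+ * If the device is not running the frame is dropped and counted in
+ * rx_dropped; otherwise it is handed to the IP stack via netif_rx() as an
+ * ETH_P_IP packet.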
+ */ + +int nr_rx_ip(struct sk_buff *skb, struct net_device *dev) +{ + struct net_device_stats *stats = &dev->stats; + + if (!netif_running(dev)) { + stats->rx_dropped++; + return 0; + } + + stats->rx_packets++; + stats->rx_bytes += skb->len; + + skb->protocol = htons(ETH_P_IP); + + /* Spoof incoming device */ + skb->dev = dev; + skb->mac_header = skb->network_header; + skb_reset_network_header(skb); + skb->pkt_type = PACKET_HOST; + + netif_rx(skb); + + return 1; +} + +static int nr_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + unsigned char *buff = skb_push(skb, NR_NETWORK_LEN + NR_TRANSPORT_LEN); + + memcpy(buff, (saddr != NULL) ? saddr : dev->dev_addr, dev->addr_len); + buff[6] &= ~AX25_CBIT; + buff[6] &= ~AX25_EBIT; + buff[6] |= AX25_SSSID_SPARE; + buff += AX25_ADDR_LEN; + + if (daddr != NULL) + memcpy(buff, daddr, dev->addr_len); + buff[6] &= ~AX25_CBIT; + buff[6] |= AX25_EBIT; + buff[6] |= AX25_SSSID_SPARE; + buff += AX25_ADDR_LEN; + + *buff++ = sysctl_netrom_network_ttl_initialiser; + + *buff++ = NR_PROTO_IP; + *buff++ = NR_PROTO_IP; + *buff++ = 0; + *buff++ = 0; + *buff++ = NR_PROTOEXT; + + if (daddr != NULL) + return 37; + + return -37; +} + +static int __must_check nr_set_mac_address(struct net_device *dev, void *addr) +{ + struct sockaddr *sa = addr; + int err; + + if (!memcmp(dev->dev_addr, sa->sa_data, dev->addr_len)) + return 0; + + if (dev->flags & IFF_UP) { + err = ax25_listen_register((ax25_address *)sa->sa_data, NULL); + if (err) + return err; + + ax25_listen_release((const ax25_address *)dev->dev_addr, NULL); + } + + dev_addr_set(dev, sa->sa_data); + + return 0; +} + +static int nr_open(struct net_device *dev) +{ + int err; + + err = ax25_listen_register((const ax25_address *)dev->dev_addr, NULL); + if (err) + return err; + + netif_start_queue(dev); + + return 0; +} + +static int nr_close(struct net_device *dev) +{ + ax25_listen_release((const ax25_address *)dev->dev_addr, NULL); + netif_stop_queue(dev); + return 0; +} + +static netdev_tx_t nr_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct net_device_stats *stats = &dev->stats; + unsigned int len = skb->len; + + if (!nr_route_frame(skb, NULL)) { + kfree_skb(skb); + stats->tx_errors++; + return NETDEV_TX_OK; + } + + stats->tx_packets++; + stats->tx_bytes += len; + + return NETDEV_TX_OK; +} + +static const struct header_ops nr_header_ops = { + .create = nr_header, +}; + +static const struct net_device_ops nr_netdev_ops = { + .ndo_open = nr_open, + .ndo_stop = nr_close, + .ndo_start_xmit = nr_xmit, + .ndo_set_mac_address = nr_set_mac_address, +}; + +void nr_setup(struct net_device *dev) +{ + dev->mtu = NR_MAX_PACKET_SIZE; + dev->netdev_ops = &nr_netdev_ops; + dev->header_ops = &nr_header_ops; + dev->hard_header_len = NR_NETWORK_LEN + NR_TRANSPORT_LEN; + dev->addr_len = AX25_ADDR_LEN; + dev->type = ARPHRD_NETROM; + + /* New-style flags. 
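+ NET/ROM does no ARP; destination callsigns are resolved by the NET/ROM routing layer.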
*/ + dev->flags = IFF_NOARP; +} diff --git a/net/netrom/nr_in.c b/net/netrom/nr_in.c new file mode 100644 index 000000000..2f084b6f6 --- /dev/null +++ b/net/netrom/nr_in.c @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright Darryl Miles G7LED (dlm@g7led.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int nr_queue_rx_frame(struct sock *sk, struct sk_buff *skb, int more) +{ + struct sk_buff *skbo, *skbn = skb; + struct nr_sock *nr = nr_sk(sk); + + skb_pull(skb, NR_NETWORK_LEN + NR_TRANSPORT_LEN); + + nr_start_idletimer(sk); + + if (more) { + nr->fraglen += skb->len; + skb_queue_tail(&nr->frag_queue, skb); + return 0; + } + + if (!more && nr->fraglen > 0) { /* End of fragment */ + nr->fraglen += skb->len; + skb_queue_tail(&nr->frag_queue, skb); + + if ((skbn = alloc_skb(nr->fraglen, GFP_ATOMIC)) == NULL) + return 1; + + skb_reset_transport_header(skbn); + + while ((skbo = skb_dequeue(&nr->frag_queue)) != NULL) { + skb_copy_from_linear_data(skbo, + skb_put(skbn, skbo->len), + skbo->len); + kfree_skb(skbo); + } + + nr->fraglen = 0; + } + + return sock_queue_rcv_skb(sk, skbn); +} + +/* + * State machine for state 1, Awaiting Connection State. + * The handling of the timer(s) is in file nr_timer.c. + * Handling of state 0 and connection release is in netrom.c. + */ +static int nr_state1_machine(struct sock *sk, struct sk_buff *skb, + int frametype) +{ + switch (frametype) { + case NR_CONNACK: { + struct nr_sock *nr = nr_sk(sk); + + nr_stop_t1timer(sk); + nr_start_idletimer(sk); + nr->your_index = skb->data[17]; + nr->your_id = skb->data[18]; + nr->vs = 0; + nr->va = 0; + nr->vr = 0; + nr->vl = 0; + nr->state = NR_STATE_3; + nr->n2count = 0; + nr->window = skb->data[20]; + sk->sk_state = TCP_ESTABLISHED; + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + break; + } + + case NR_CONNACK | NR_CHOKE_FLAG: + nr_disconnect(sk, ECONNREFUSED); + break; + + case NR_RESET: + if (sysctl_netrom_reset_circuit) + nr_disconnect(sk, ECONNRESET); + break; + + default: + break; + } + return 0; +} + +/* + * State machine for state 2, Awaiting Release State. + * The handling of the timer(s) is in file nr_timer.c + * Handling of state 0 and connection release is in netrom.c. + */ +static int nr_state2_machine(struct sock *sk, struct sk_buff *skb, + int frametype) +{ + switch (frametype) { + case NR_CONNACK | NR_CHOKE_FLAG: + nr_disconnect(sk, ECONNRESET); + break; + + case NR_DISCREQ: + nr_write_internal(sk, NR_DISCACK); + fallthrough; + case NR_DISCACK: + nr_disconnect(sk, 0); + break; + + case NR_RESET: + if (sysctl_netrom_reset_circuit) + nr_disconnect(sk, ECONNRESET); + break; + + default: + break; + } + return 0; +} + +/* + * State machine for state 3, Connected State. + * The handling of the timer(s) is in file nr_timer.c + * Handling of state 0 and connection release is in netrom.c. 
+ */ +static int nr_state3_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct nr_sock *nrom = nr_sk(sk); + struct sk_buff_head temp_queue; + struct sk_buff *skbn; + unsigned short save_vr; + unsigned short nr, ns; + int queued = 0; + + nr = skb->data[18]; + + switch (frametype) { + case NR_CONNREQ: + nr_write_internal(sk, NR_CONNACK); + break; + + case NR_DISCREQ: + nr_write_internal(sk, NR_DISCACK); + nr_disconnect(sk, 0); + break; + + case NR_CONNACK | NR_CHOKE_FLAG: + case NR_DISCACK: + nr_disconnect(sk, ECONNRESET); + break; + + case NR_INFOACK: + case NR_INFOACK | NR_CHOKE_FLAG: + case NR_INFOACK | NR_NAK_FLAG: + case NR_INFOACK | NR_NAK_FLAG | NR_CHOKE_FLAG: + if (frametype & NR_CHOKE_FLAG) { + nrom->condition |= NR_COND_PEER_RX_BUSY; + nr_start_t4timer(sk); + } else { + nrom->condition &= ~NR_COND_PEER_RX_BUSY; + nr_stop_t4timer(sk); + } + if (!nr_validate_nr(sk, nr)) { + break; + } + if (frametype & NR_NAK_FLAG) { + nr_frames_acked(sk, nr); + nr_send_nak_frame(sk); + } else { + if (nrom->condition & NR_COND_PEER_RX_BUSY) { + nr_frames_acked(sk, nr); + } else { + nr_check_iframes_acked(sk, nr); + } + } + break; + + case NR_INFO: + case NR_INFO | NR_NAK_FLAG: + case NR_INFO | NR_CHOKE_FLAG: + case NR_INFO | NR_MORE_FLAG: + case NR_INFO | NR_NAK_FLAG | NR_CHOKE_FLAG: + case NR_INFO | NR_CHOKE_FLAG | NR_MORE_FLAG: + case NR_INFO | NR_NAK_FLAG | NR_MORE_FLAG: + case NR_INFO | NR_NAK_FLAG | NR_CHOKE_FLAG | NR_MORE_FLAG: + if (frametype & NR_CHOKE_FLAG) { + nrom->condition |= NR_COND_PEER_RX_BUSY; + nr_start_t4timer(sk); + } else { + nrom->condition &= ~NR_COND_PEER_RX_BUSY; + nr_stop_t4timer(sk); + } + if (nr_validate_nr(sk, nr)) { + if (frametype & NR_NAK_FLAG) { + nr_frames_acked(sk, nr); + nr_send_nak_frame(sk); + } else { + if (nrom->condition & NR_COND_PEER_RX_BUSY) { + nr_frames_acked(sk, nr); + } else { + nr_check_iframes_acked(sk, nr); + } + } + } + queued = 1; + skb_queue_head(&nrom->reseq_queue, skb); + if (nrom->condition & NR_COND_OWN_RX_BUSY) + break; + skb_queue_head_init(&temp_queue); + do { + save_vr = nrom->vr; + while ((skbn = skb_dequeue(&nrom->reseq_queue)) != NULL) { + ns = skbn->data[17]; + if (ns == nrom->vr) { + if (nr_queue_rx_frame(sk, skbn, frametype & NR_MORE_FLAG) == 0) { + nrom->vr = (nrom->vr + 1) % NR_MODULUS; + } else { + nrom->condition |= NR_COND_OWN_RX_BUSY; + skb_queue_tail(&temp_queue, skbn); + } + } else if (nr_in_rx_window(sk, ns)) { + skb_queue_tail(&temp_queue, skbn); + } else { + kfree_skb(skbn); + } + } + while ((skbn = skb_dequeue(&temp_queue)) != NULL) { + skb_queue_tail(&nrom->reseq_queue, skbn); + } + } while (save_vr != nrom->vr); + /* + * Window is full, ack it immediately. 
+ */ + if (((nrom->vl + nrom->window) % NR_MODULUS) == nrom->vr) { + nr_enquiry_response(sk); + } else { + if (!(nrom->condition & NR_COND_ACK_PENDING)) { + nrom->condition |= NR_COND_ACK_PENDING; + nr_start_t2timer(sk); + } + } + break; + + case NR_RESET: + if (sysctl_netrom_reset_circuit) + nr_disconnect(sk, ECONNRESET); + break; + + default: + break; + } + return queued; +} + +/* Higher level upcall for a LAPB frame - called with sk locked */ +int nr_process_rx_frame(struct sock *sk, struct sk_buff *skb) +{ + struct nr_sock *nr = nr_sk(sk); + int queued = 0, frametype; + + if (nr->state == NR_STATE_0) + return 0; + + frametype = skb->data[19]; + + switch (nr->state) { + case NR_STATE_1: + queued = nr_state1_machine(sk, skb, frametype); + break; + case NR_STATE_2: + queued = nr_state2_machine(sk, skb, frametype); + break; + case NR_STATE_3: + queued = nr_state3_machine(sk, skb, frametype); + break; + } + + nr_kick(sk); + + return queued; +} diff --git a/net/netrom/nr_loopback.c b/net/netrom/nr_loopback.c new file mode 100644 index 000000000..511819fbf --- /dev/null +++ b/net/netrom/nr_loopback.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Tomi Manninen OH2BNS (oh2bns@sral.fi) + */ +#include +#include +#include +#include +#include +#include +#include +#include + +static void nr_loopback_timer(struct timer_list *); + +static struct sk_buff_head loopback_queue; +static DEFINE_TIMER(loopback_timer, nr_loopback_timer); + +void __init nr_loopback_init(void) +{ + skb_queue_head_init(&loopback_queue); +} + +static inline int nr_loopback_running(void) +{ + return timer_pending(&loopback_timer); +} + +int nr_loopback_queue(struct sk_buff *skb) +{ + struct sk_buff *skbn; + + if ((skbn = alloc_skb(skb->len, GFP_ATOMIC)) != NULL) { + skb_copy_from_linear_data(skb, skb_put(skbn, skb->len), skb->len); + skb_reset_transport_header(skbn); + + skb_queue_tail(&loopback_queue, skbn); + + if (!nr_loopback_running()) + mod_timer(&loopback_timer, jiffies + 10); + } + + kfree_skb(skb); + return 1; +} + +static void nr_loopback_timer(struct timer_list *unused) +{ + struct sk_buff *skb; + ax25_address *nr_dest; + struct net_device *dev; + + if ((skb = skb_dequeue(&loopback_queue)) != NULL) { + nr_dest = (ax25_address *)(skb->data + 7); + + dev = nr_dev_get(nr_dest); + + if (dev == NULL || nr_rx_frame(skb, dev) == 0) + kfree_skb(skb); + + dev_put(dev); + + if (!skb_queue_empty(&loopback_queue) && !nr_loopback_running()) + mod_timer(&loopback_timer, jiffies + 10); + } +} + +void nr_loopback_clear(void) +{ + del_timer_sync(&loopback_timer); + skb_queue_purge(&loopback_queue); +} diff --git a/net/netrom/nr_out.c b/net/netrom/nr_out.c new file mode 100644 index 000000000..44929657f --- /dev/null +++ b/net/netrom/nr_out.c @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright Darryl Miles G7LED (dlm@g7led.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This is where all NET/ROM frames pass, except for IP-over-NET/ROM which + * cannot be fragmented in this manner. 
+ */ +void nr_output(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *skbn; + unsigned char transport[NR_TRANSPORT_LEN]; + int err, frontlen, len; + + if (skb->len - NR_TRANSPORT_LEN > NR_MAX_PACKET_SIZE) { + /* Save a copy of the Transport Header */ + skb_copy_from_linear_data(skb, transport, NR_TRANSPORT_LEN); + skb_pull(skb, NR_TRANSPORT_LEN); + + frontlen = skb_headroom(skb); + + while (skb->len > 0) { + if ((skbn = sock_alloc_send_skb(sk, frontlen + NR_MAX_PACKET_SIZE, 0, &err)) == NULL) + return; + + skb_reserve(skbn, frontlen); + + len = (NR_MAX_PACKET_SIZE > skb->len) ? skb->len : NR_MAX_PACKET_SIZE; + + /* Copy the user data */ + skb_copy_from_linear_data(skb, skb_put(skbn, len), len); + skb_pull(skb, len); + + /* Duplicate the Transport Header */ + skb_push(skbn, NR_TRANSPORT_LEN); + skb_copy_to_linear_data(skbn, transport, + NR_TRANSPORT_LEN); + if (skb->len > 0) + skbn->data[4] |= NR_MORE_FLAG; + + skb_queue_tail(&sk->sk_write_queue, skbn); /* Throw it on the queue */ + } + + kfree_skb(skb); + } else { + skb_queue_tail(&sk->sk_write_queue, skb); /* Throw it on the queue */ + } + + nr_kick(sk); +} + +/* + * This procedure is passed a buffer descriptor for an iframe. It builds + * the rest of the control part of the frame and then writes it out. + */ +static void nr_send_iframe(struct sock *sk, struct sk_buff *skb) +{ + struct nr_sock *nr = nr_sk(sk); + + if (skb == NULL) + return; + + skb->data[2] = nr->vs; + skb->data[3] = nr->vr; + + if (nr->condition & NR_COND_OWN_RX_BUSY) + skb->data[4] |= NR_CHOKE_FLAG; + + nr_start_idletimer(sk); + + nr_transmit_buffer(sk, skb); +} + +void nr_send_nak_frame(struct sock *sk) +{ + struct sk_buff *skb, *skbn; + struct nr_sock *nr = nr_sk(sk); + + if ((skb = skb_peek(&nr->ack_queue)) == NULL) + return; + + if ((skbn = skb_clone(skb, GFP_ATOMIC)) == NULL) + return; + + skbn->data[2] = nr->va; + skbn->data[3] = nr->vr; + + if (nr->condition & NR_COND_OWN_RX_BUSY) + skbn->data[4] |= NR_CHOKE_FLAG; + + nr_transmit_buffer(sk, skbn); + + nr->condition &= ~NR_COND_ACK_PENDING; + nr->vl = nr->vr; + + nr_stop_t1timer(sk); +} + +void nr_kick(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + struct sk_buff *skb, *skbn; + unsigned short start, end; + + if (nr->state != NR_STATE_3) + return; + + if (nr->condition & NR_COND_PEER_RX_BUSY) + return; + + if (!skb_peek(&sk->sk_write_queue)) + return; + + start = (skb_peek(&nr->ack_queue) == NULL) ? nr->va : nr->vs; + end = (nr->va + nr->window) % NR_MODULUS; + + if (start == end) + return; + + nr->vs = start; + + /* + * Transmit data until either we're out of data to send or + * the window is full. + */ + + /* + * Dequeue the frame and copy it. + */ + skb = skb_dequeue(&sk->sk_write_queue); + + do { + if ((skbn = skb_clone(skb, GFP_ATOMIC)) == NULL) { + skb_queue_head(&sk->sk_write_queue, skb); + break; + } + + skb_set_owner_w(skbn, sk); + + /* + * Transmit the frame copy. + */ + nr_send_iframe(sk, skbn); + + nr->vs = (nr->vs + 1) % NR_MODULUS; + + /* + * Requeue the original data frame. + */ + skb_queue_tail(&nr->ack_queue, skb); + + } while (nr->vs != end && + (skb = skb_dequeue(&sk->sk_write_queue)) != NULL); + + nr->vl = nr->vr; + nr->condition &= ~NR_COND_ACK_PENDING; + + if (!nr_t1timer_running(sk)) + nr_start_t1timer(sk); +} + +void nr_transmit_buffer(struct sock *sk, struct sk_buff *skb) +{ + struct nr_sock *nr = nr_sk(sk); + unsigned char *dptr; + + /* + * Add the protocol byte and network header. 
+ */ + dptr = skb_push(skb, NR_NETWORK_LEN); + + memcpy(dptr, &nr->source_addr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] &= ~AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + + memcpy(dptr, &nr->dest_addr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] |= AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + + *dptr++ = sysctl_netrom_network_ttl_initialiser; + + if (!nr_route_frame(skb, NULL)) { + kfree_skb(skb); + nr_disconnect(sk, ENETUNREACH); + } +} + +/* + * The following routines are taken from page 170 of the 7th ARRL Computer + * Networking Conference paper, as is the whole state machine. + */ + +void nr_establish_data_link(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + nr->condition = 0x00; + nr->n2count = 0; + + nr_write_internal(sk, NR_CONNREQ); + + nr_stop_t2timer(sk); + nr_stop_t4timer(sk); + nr_stop_idletimer(sk); + nr_start_t1timer(sk); +} + +/* + * Never send a NAK when we are CHOKEd. + */ +void nr_enquiry_response(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + int frametype = NR_INFOACK; + + if (nr->condition & NR_COND_OWN_RX_BUSY) { + frametype |= NR_CHOKE_FLAG; + } else { + if (skb_peek(&nr->reseq_queue) != NULL) + frametype |= NR_NAK_FLAG; + } + + nr_write_internal(sk, frametype); + + nr->vl = nr->vr; + nr->condition &= ~NR_COND_ACK_PENDING; +} + +void nr_check_iframes_acked(struct sock *sk, unsigned short nr) +{ + struct nr_sock *nrom = nr_sk(sk); + + if (nrom->vs == nr) { + nr_frames_acked(sk, nr); + nr_stop_t1timer(sk); + nrom->n2count = 0; + } else { + if (nrom->va != nr) { + nr_frames_acked(sk, nr); + nr_start_t1timer(sk); + } + } +} diff --git a/net/netrom/nr_route.c b/net/netrom/nr_route.c new file mode 100644 index 000000000..baea3cbd7 --- /dev/null +++ b/net/netrom/nr_route.c @@ -0,0 +1,983 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright Tomi Manninen OH2BNS (oh2bns@sral.fi) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include +#include +#include +#include +#include + +static unsigned int nr_neigh_no = 1; + +static HLIST_HEAD(nr_node_list); +static DEFINE_SPINLOCK(nr_node_list_lock); +static HLIST_HEAD(nr_neigh_list); +static DEFINE_SPINLOCK(nr_neigh_list_lock); + +static struct nr_node *nr_node_get(ax25_address *callsign) +{ + struct nr_node *found = NULL; + struct nr_node *nr_node; + + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each(nr_node, &nr_node_list) + if (ax25cmp(callsign, &nr_node->callsign) == 0) { + nr_node_hold(nr_node); + found = nr_node; + break; + } + spin_unlock_bh(&nr_node_list_lock); + return found; +} + +static struct nr_neigh *nr_neigh_get_dev(ax25_address *callsign, + struct net_device *dev) +{ + struct nr_neigh *found = NULL; + struct nr_neigh *nr_neigh; + + spin_lock_bh(&nr_neigh_list_lock); + nr_neigh_for_each(nr_neigh, &nr_neigh_list) + if (ax25cmp(callsign, &nr_neigh->callsign) == 0 && + nr_neigh->dev == dev) { + nr_neigh_hold(nr_neigh); + found = nr_neigh; + break; + } + spin_unlock_bh(&nr_neigh_list_lock); + return found; +} + +static void nr_remove_neigh(struct nr_neigh *); + +/* re-sort the routes in quality order. 
*/ +static void re_sort_routes(struct nr_node *nr_node, int x, int y) +{ + if (nr_node->routes[y].quality > nr_node->routes[x].quality) { + if (nr_node->which == x) + nr_node->which = y; + else if (nr_node->which == y) + nr_node->which = x; + + swap(nr_node->routes[x], nr_node->routes[y]); + } +} + +/* + * Add a new route to a node, and in the process add the node and the + * neighbour if it is new. + */ +static int __must_check nr_add_node(ax25_address *nr, const char *mnemonic, + ax25_address *ax25, ax25_digi *ax25_digi, struct net_device *dev, + int quality, int obs_count) +{ + struct nr_node *nr_node; + struct nr_neigh *nr_neigh; + int i, found; + struct net_device *odev; + + if ((odev=nr_dev_get(nr)) != NULL) { /* Can't add routes to ourself */ + dev_put(odev); + return -EINVAL; + } + + nr_node = nr_node_get(nr); + + nr_neigh = nr_neigh_get_dev(ax25, dev); + + /* + * The L2 link to a neighbour has failed in the past + * and now a frame comes from this neighbour. We assume + * it was a temporary trouble with the link and reset the + * routes now (and not wait for a node broadcast). + */ + if (nr_neigh != NULL && nr_neigh->failed != 0 && quality == 0) { + struct nr_node *nr_nodet; + + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each(nr_nodet, &nr_node_list) { + nr_node_lock(nr_nodet); + for (i = 0; i < nr_nodet->count; i++) + if (nr_nodet->routes[i].neighbour == nr_neigh) + if (i < nr_nodet->which) + nr_nodet->which = i; + nr_node_unlock(nr_nodet); + } + spin_unlock_bh(&nr_node_list_lock); + } + + if (nr_neigh != NULL) + nr_neigh->failed = 0; + + if (quality == 0 && nr_neigh != NULL && nr_node != NULL) { + nr_neigh_put(nr_neigh); + nr_node_put(nr_node); + return 0; + } + + if (nr_neigh == NULL) { + if ((nr_neigh = kmalloc(sizeof(*nr_neigh), GFP_ATOMIC)) == NULL) { + if (nr_node) + nr_node_put(nr_node); + return -ENOMEM; + } + + nr_neigh->callsign = *ax25; + nr_neigh->digipeat = NULL; + nr_neigh->ax25 = NULL; + nr_neigh->dev = dev; + nr_neigh->quality = sysctl_netrom_default_path_quality; + nr_neigh->locked = 0; + nr_neigh->count = 0; + nr_neigh->number = nr_neigh_no++; + nr_neigh->failed = 0; + refcount_set(&nr_neigh->refcount, 1); + + if (ax25_digi != NULL && ax25_digi->ndigi > 0) { + nr_neigh->digipeat = kmemdup(ax25_digi, + sizeof(*ax25_digi), + GFP_KERNEL); + if (nr_neigh->digipeat == NULL) { + kfree(nr_neigh); + if (nr_node) + nr_node_put(nr_node); + return -ENOMEM; + } + } + + spin_lock_bh(&nr_neigh_list_lock); + hlist_add_head(&nr_neigh->neigh_node, &nr_neigh_list); + nr_neigh_hold(nr_neigh); + spin_unlock_bh(&nr_neigh_list_lock); + } + + if (quality != 0 && ax25cmp(nr, ax25) == 0 && !nr_neigh->locked) + nr_neigh->quality = quality; + + if (nr_node == NULL) { + if ((nr_node = kmalloc(sizeof(*nr_node), GFP_ATOMIC)) == NULL) { + if (nr_neigh) + nr_neigh_put(nr_neigh); + return -ENOMEM; + } + + nr_node->callsign = *nr; + strcpy(nr_node->mnemonic, mnemonic); + + nr_node->which = 0; + nr_node->count = 1; + refcount_set(&nr_node->refcount, 1); + spin_lock_init(&nr_node->node_lock); + + nr_node->routes[0].quality = quality; + nr_node->routes[0].obs_count = obs_count; + nr_node->routes[0].neighbour = nr_neigh; + + nr_neigh_hold(nr_neigh); + nr_neigh->count++; + + spin_lock_bh(&nr_node_list_lock); + hlist_add_head(&nr_node->node_node, &nr_node_list); + /* refcount initialized at 1 */ + spin_unlock_bh(&nr_node_list_lock); + + nr_neigh_put(nr_neigh); + return 0; + } + nr_node_lock(nr_node); + + if (quality != 0) + strcpy(nr_node->mnemonic, mnemonic); + + for (found = 0, i = 0; i < 
nr_node->count; i++) { + if (nr_node->routes[i].neighbour == nr_neigh) { + nr_node->routes[i].quality = quality; + nr_node->routes[i].obs_count = obs_count; + found = 1; + break; + } + } + + if (!found) { + /* We have space at the bottom, slot it in */ + if (nr_node->count < 3) { + nr_node->routes[2] = nr_node->routes[1]; + nr_node->routes[1] = nr_node->routes[0]; + + nr_node->routes[0].quality = quality; + nr_node->routes[0].obs_count = obs_count; + nr_node->routes[0].neighbour = nr_neigh; + + nr_node->which++; + nr_node->count++; + nr_neigh_hold(nr_neigh); + nr_neigh->count++; + } else { + /* It must be better than the worst */ + if (quality > nr_node->routes[2].quality) { + nr_node->routes[2].neighbour->count--; + nr_neigh_put(nr_node->routes[2].neighbour); + + if (nr_node->routes[2].neighbour->count == 0 && !nr_node->routes[2].neighbour->locked) + nr_remove_neigh(nr_node->routes[2].neighbour); + + nr_node->routes[2].quality = quality; + nr_node->routes[2].obs_count = obs_count; + nr_node->routes[2].neighbour = nr_neigh; + + nr_neigh_hold(nr_neigh); + nr_neigh->count++; + } + } + } + + /* Now re-sort the routes in quality order */ + switch (nr_node->count) { + case 3: + re_sort_routes(nr_node, 0, 1); + re_sort_routes(nr_node, 1, 2); + fallthrough; + case 2: + re_sort_routes(nr_node, 0, 1); + break; + case 1: + break; + } + + for (i = 0; i < nr_node->count; i++) { + if (nr_node->routes[i].neighbour == nr_neigh) { + if (i < nr_node->which) + nr_node->which = i; + break; + } + } + + nr_neigh_put(nr_neigh); + nr_node_unlock(nr_node); + nr_node_put(nr_node); + return 0; +} + +static inline void __nr_remove_node(struct nr_node *nr_node) +{ + hlist_del_init(&nr_node->node_node); + nr_node_put(nr_node); +} + +#define nr_remove_node_locked(__node) \ + __nr_remove_node(__node) + +static void nr_remove_node(struct nr_node *nr_node) +{ + spin_lock_bh(&nr_node_list_lock); + __nr_remove_node(nr_node); + spin_unlock_bh(&nr_node_list_lock); +} + +static inline void __nr_remove_neigh(struct nr_neigh *nr_neigh) +{ + hlist_del_init(&nr_neigh->neigh_node); + nr_neigh_put(nr_neigh); +} + +#define nr_remove_neigh_locked(__neigh) \ + __nr_remove_neigh(__neigh) + +static void nr_remove_neigh(struct nr_neigh *nr_neigh) +{ + spin_lock_bh(&nr_neigh_list_lock); + __nr_remove_neigh(nr_neigh); + spin_unlock_bh(&nr_neigh_list_lock); +} + +/* + * "Delete" a node. Strictly speaking remove a route to a node. The node + * is only deleted if no routes are left to it. 
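+ * The neighbour's use count is dropped too, and an unlocked neighbour that
+ * no node routes through any longer is removed from the neighbour list.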
+ */ +static int nr_del_node(ax25_address *callsign, ax25_address *neighbour, struct net_device *dev) +{ + struct nr_node *nr_node; + struct nr_neigh *nr_neigh; + int i; + + nr_node = nr_node_get(callsign); + + if (nr_node == NULL) + return -EINVAL; + + nr_neigh = nr_neigh_get_dev(neighbour, dev); + + if (nr_neigh == NULL) { + nr_node_put(nr_node); + return -EINVAL; + } + + nr_node_lock(nr_node); + for (i = 0; i < nr_node->count; i++) { + if (nr_node->routes[i].neighbour == nr_neigh) { + nr_neigh->count--; + nr_neigh_put(nr_neigh); + + if (nr_neigh->count == 0 && !nr_neigh->locked) + nr_remove_neigh(nr_neigh); + nr_neigh_put(nr_neigh); + + nr_node->count--; + + if (nr_node->count == 0) { + nr_remove_node(nr_node); + } else { + switch (i) { + case 0: + nr_node->routes[0] = nr_node->routes[1]; + fallthrough; + case 1: + nr_node->routes[1] = nr_node->routes[2]; + fallthrough; + case 2: + break; + } + nr_node_put(nr_node); + } + nr_node_unlock(nr_node); + + return 0; + } + } + nr_neigh_put(nr_neigh); + nr_node_unlock(nr_node); + nr_node_put(nr_node); + + return -EINVAL; +} + +/* + * Lock a neighbour with a quality. + */ +static int __must_check nr_add_neigh(ax25_address *callsign, + ax25_digi *ax25_digi, struct net_device *dev, unsigned int quality) +{ + struct nr_neigh *nr_neigh; + + nr_neigh = nr_neigh_get_dev(callsign, dev); + if (nr_neigh) { + nr_neigh->quality = quality; + nr_neigh->locked = 1; + nr_neigh_put(nr_neigh); + return 0; + } + + if ((nr_neigh = kmalloc(sizeof(*nr_neigh), GFP_ATOMIC)) == NULL) + return -ENOMEM; + + nr_neigh->callsign = *callsign; + nr_neigh->digipeat = NULL; + nr_neigh->ax25 = NULL; + nr_neigh->dev = dev; + nr_neigh->quality = quality; + nr_neigh->locked = 1; + nr_neigh->count = 0; + nr_neigh->number = nr_neigh_no++; + nr_neigh->failed = 0; + refcount_set(&nr_neigh->refcount, 1); + + if (ax25_digi != NULL && ax25_digi->ndigi > 0) { + nr_neigh->digipeat = kmemdup(ax25_digi, sizeof(*ax25_digi), + GFP_KERNEL); + if (nr_neigh->digipeat == NULL) { + kfree(nr_neigh); + return -ENOMEM; + } + } + + spin_lock_bh(&nr_neigh_list_lock); + hlist_add_head(&nr_neigh->neigh_node, &nr_neigh_list); + /* refcount is initialized at 1 */ + spin_unlock_bh(&nr_neigh_list_lock); + + return 0; +} + +/* + * "Delete" a neighbour. The neighbour is only removed if the number + * of nodes that may use it is zero. + */ +static int nr_del_neigh(ax25_address *callsign, struct net_device *dev, unsigned int quality) +{ + struct nr_neigh *nr_neigh; + + nr_neigh = nr_neigh_get_dev(callsign, dev); + + if (nr_neigh == NULL) return -EINVAL; + + nr_neigh->quality = quality; + nr_neigh->locked = 0; + + if (nr_neigh->count == 0) + nr_remove_neigh(nr_neigh); + nr_neigh_put(nr_neigh); + + return 0; +} + +/* + * Decrement the obsolescence count by one. If a route is reduced to a + * count of zero, remove it. Also remove any unlocked neighbours with + * zero nodes routing via it. 
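+ * An obs_count of zero marks a locked entry that is never aged out; a
+ * count of one expires on this pass and its route slot is compacted away.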
+ */ +static int nr_dec_obs(void) +{ + struct nr_neigh *nr_neigh; + struct nr_node *s; + struct hlist_node *nodet; + int i; + + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each_safe(s, nodet, &nr_node_list) { + nr_node_lock(s); + for (i = 0; i < s->count; i++) { + switch (s->routes[i].obs_count) { + case 0: /* A locked entry */ + break; + + case 1: /* From 1 -> 0 */ + nr_neigh = s->routes[i].neighbour; + + nr_neigh->count--; + nr_neigh_put(nr_neigh); + + if (nr_neigh->count == 0 && !nr_neigh->locked) + nr_remove_neigh(nr_neigh); + + s->count--; + + switch (i) { + case 0: + s->routes[0] = s->routes[1]; + fallthrough; + case 1: + s->routes[1] = s->routes[2]; + break; + case 2: + break; + } + break; + + default: + s->routes[i].obs_count--; + break; + + } + } + + if (s->count <= 0) + nr_remove_node_locked(s); + nr_node_unlock(s); + } + spin_unlock_bh(&nr_node_list_lock); + + return 0; +} + +/* + * A device has been removed. Remove its routes and neighbours. + */ +void nr_rt_device_down(struct net_device *dev) +{ + struct nr_neigh *s; + struct hlist_node *nodet, *node2t; + struct nr_node *t; + int i; + + spin_lock_bh(&nr_neigh_list_lock); + nr_neigh_for_each_safe(s, nodet, &nr_neigh_list) { + if (s->dev == dev) { + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each_safe(t, node2t, &nr_node_list) { + nr_node_lock(t); + for (i = 0; i < t->count; i++) { + if (t->routes[i].neighbour == s) { + t->count--; + + switch (i) { + case 0: + t->routes[0] = t->routes[1]; + fallthrough; + case 1: + t->routes[1] = t->routes[2]; + break; + case 2: + break; + } + } + } + + if (t->count <= 0) + nr_remove_node_locked(t); + nr_node_unlock(t); + } + spin_unlock_bh(&nr_node_list_lock); + + nr_remove_neigh_locked(s); + } + } + spin_unlock_bh(&nr_neigh_list_lock); +} + +/* + * Check that the device given is a valid AX.25 interface that is "up". + * Or a valid ethernet interface with an AX.25 callsign binding. + */ +static struct net_device *nr_ax25_dev_get(char *devname) +{ + struct net_device *dev; + + if ((dev = dev_get_by_name(&init_net, devname)) == NULL) + return NULL; + + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_AX25) + return dev; + + dev_put(dev); + return NULL; +} + +/* + * Find the first active NET/ROM device, usually "nr0". + */ +struct net_device *nr_dev_first(void) +{ + struct net_device *dev, *first = NULL; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_NETROM) + if (first == NULL || strncmp(dev->name, first->name, 3) < 0) + first = dev; + } + dev_hold(first); + rcu_read_unlock(); + + return first; +} + +/* + * Find the NET/ROM device for the given callsign. + */ +struct net_device *nr_dev_get(ax25_address *addr) +{ + struct net_device *dev; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_NETROM && + ax25cmp(addr, (const ax25_address *)dev->dev_addr) == 0) { + dev_hold(dev); + goto out; + } + } + dev = NULL; +out: + rcu_read_unlock(); + return dev; +} + +static ax25_digi *nr_call_to_digi(ax25_digi *digi, int ndigis, + ax25_address *digipeaters) +{ + int i; + + if (ndigis == 0) + return NULL; + + for (i = 0; i < ndigis; i++) { + digi->calls[i] = digipeaters[i]; + digi->repeated[i] = 0; + } + + digi->ndigi = ndigis; + digi->lastrepeat = -1; + + return digi; +} + +/* + * Handle the ioctls that control the routing functions. 
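+ * SIOCADDRT and SIOCDELRT add or delete NETROM_NODE/NETROM_NEIGH entries;
+ * SIOCNRDECOBS runs one obsolescence pass. All three require CAP_NET_ADMIN,
+ * which nr_ioctl() has already checked.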
+ */ +int nr_rt_ioctl(unsigned int cmd, void __user *arg) +{ + struct nr_route_struct nr_route; + struct net_device *dev; + ax25_digi digi; + int ret; + + switch (cmd) { + case SIOCADDRT: + if (copy_from_user(&nr_route, arg, sizeof(struct nr_route_struct))) + return -EFAULT; + if (nr_route.ndigis > AX25_MAX_DIGIS) + return -EINVAL; + if ((dev = nr_ax25_dev_get(nr_route.device)) == NULL) + return -EINVAL; + switch (nr_route.type) { + case NETROM_NODE: + if (strnlen(nr_route.mnemonic, 7) == 7) { + ret = -EINVAL; + break; + } + + ret = nr_add_node(&nr_route.callsign, + nr_route.mnemonic, + &nr_route.neighbour, + nr_call_to_digi(&digi, nr_route.ndigis, + nr_route.digipeaters), + dev, nr_route.quality, + nr_route.obs_count); + break; + case NETROM_NEIGH: + ret = nr_add_neigh(&nr_route.callsign, + nr_call_to_digi(&digi, nr_route.ndigis, + nr_route.digipeaters), + dev, nr_route.quality); + break; + default: + ret = -EINVAL; + } + dev_put(dev); + return ret; + + case SIOCDELRT: + if (copy_from_user(&nr_route, arg, sizeof(struct nr_route_struct))) + return -EFAULT; + if ((dev = nr_ax25_dev_get(nr_route.device)) == NULL) + return -EINVAL; + switch (nr_route.type) { + case NETROM_NODE: + ret = nr_del_node(&nr_route.callsign, + &nr_route.neighbour, dev); + break; + case NETROM_NEIGH: + ret = nr_del_neigh(&nr_route.callsign, + dev, nr_route.quality); + break; + default: + ret = -EINVAL; + } + dev_put(dev); + return ret; + + case SIOCNRDECOBS: + return nr_dec_obs(); + + default: + return -EINVAL; + } + + return 0; +} + +/* + * A level 2 link has timed out, therefore it appears to be a poor link, + * then don't use that neighbour until it is reset. + */ +void nr_link_failed(ax25_cb *ax25, int reason) +{ + struct nr_neigh *s, *nr_neigh = NULL; + struct nr_node *nr_node = NULL; + + spin_lock_bh(&nr_neigh_list_lock); + nr_neigh_for_each(s, &nr_neigh_list) { + if (s->ax25 == ax25) { + nr_neigh_hold(s); + nr_neigh = s; + break; + } + } + spin_unlock_bh(&nr_neigh_list_lock); + + if (nr_neigh == NULL) + return; + + nr_neigh->ax25 = NULL; + ax25_cb_put(ax25); + + if (++nr_neigh->failed < sysctl_netrom_link_fails_count) { + nr_neigh_put(nr_neigh); + return; + } + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each(nr_node, &nr_node_list) { + nr_node_lock(nr_node); + if (nr_node->which < nr_node->count && + nr_node->routes[nr_node->which].neighbour == nr_neigh) + nr_node->which++; + nr_node_unlock(nr_node); + } + spin_unlock_bh(&nr_node_list_lock); + nr_neigh_put(nr_neigh); +} + +/* + * Route a frame to an appropriate AX.25 connection. A NULL ax25_cb + * indicates an internally generated frame. 
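+ * Frames addressed to one of our own NET/ROM interfaces are delivered
+ * locally (or looped back if we generated them); everything else has its
+ * TTL decremented and is forwarded to the currently selected neighbour.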
+ */ +int nr_route_frame(struct sk_buff *skb, ax25_cb *ax25) +{ + ax25_address *nr_src, *nr_dest; + struct nr_neigh *nr_neigh; + struct nr_node *nr_node; + struct net_device *dev; + unsigned char *dptr; + ax25_cb *ax25s; + int ret; + struct sk_buff *skbn; + + + nr_src = (ax25_address *)(skb->data + 0); + nr_dest = (ax25_address *)(skb->data + 7); + + if (ax25 != NULL) { + ret = nr_add_node(nr_src, "", &ax25->dest_addr, ax25->digipeat, + ax25->ax25_dev->dev, 0, + sysctl_netrom_obsolescence_count_initialiser); + if (ret) + return ret; + } + + if ((dev = nr_dev_get(nr_dest)) != NULL) { /* Its for me */ + if (ax25 == NULL) /* Its from me */ + ret = nr_loopback_queue(skb); + else + ret = nr_rx_frame(skb, dev); + dev_put(dev); + return ret; + } + + if (!sysctl_netrom_routing_control && ax25 != NULL) + return 0; + + /* Its Time-To-Live has expired */ + if (skb->data[14] == 1) { + return 0; + } + + nr_node = nr_node_get(nr_dest); + if (nr_node == NULL) + return 0; + nr_node_lock(nr_node); + + if (nr_node->which >= nr_node->count) { + nr_node_unlock(nr_node); + nr_node_put(nr_node); + return 0; + } + + nr_neigh = nr_node->routes[nr_node->which].neighbour; + + if ((dev = nr_dev_first()) == NULL) { + nr_node_unlock(nr_node); + nr_node_put(nr_node); + return 0; + } + + /* We are going to change the netrom headers so we should get our + own skb, we also did not know until now how much header space + we had to reserve... - RXQ */ + if ((skbn=skb_copy_expand(skb, dev->hard_header_len, 0, GFP_ATOMIC)) == NULL) { + nr_node_unlock(nr_node); + nr_node_put(nr_node); + dev_put(dev); + return 0; + } + kfree_skb(skb); + skb=skbn; + skb->data[14]--; + + dptr = skb_push(skb, 1); + *dptr = AX25_P_NETROM; + + ax25s = nr_neigh->ax25; + nr_neigh->ax25 = ax25_send_frame(skb, 256, + (const ax25_address *)dev->dev_addr, + &nr_neigh->callsign, + nr_neigh->digipeat, nr_neigh->dev); + if (ax25s) + ax25_cb_put(ax25s); + + dev_put(dev); + ret = (nr_neigh->ax25 != NULL); + nr_node_unlock(nr_node); + nr_node_put(nr_node); + + return ret; +} + +#ifdef CONFIG_PROC_FS + +static void *nr_node_start(struct seq_file *seq, loff_t *pos) + __acquires(&nr_node_list_lock) +{ + spin_lock_bh(&nr_node_list_lock); + return seq_hlist_start_head(&nr_node_list, *pos); +} + +static void *nr_node_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &nr_node_list, pos); +} + +static void nr_node_stop(struct seq_file *seq, void *v) + __releases(&nr_node_list_lock) +{ + spin_unlock_bh(&nr_node_list_lock); +} + +static int nr_node_show(struct seq_file *seq, void *v) +{ + char buf[11]; + int i; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "callsign mnemonic w n qual obs neigh qual obs neigh qual obs neigh\n"); + else { + struct nr_node *nr_node = hlist_entry(v, struct nr_node, + node_node); + + nr_node_lock(nr_node); + seq_printf(seq, "%-9s %-7s %d %d", + ax2asc(buf, &nr_node->callsign), + (nr_node->mnemonic[0] == '\0') ? 
"*" : nr_node->mnemonic, + nr_node->which + 1, + nr_node->count); + + for (i = 0; i < nr_node->count; i++) { + seq_printf(seq, " %3d %d %05d", + nr_node->routes[i].quality, + nr_node->routes[i].obs_count, + nr_node->routes[i].neighbour->number); + } + nr_node_unlock(nr_node); + + seq_puts(seq, "\n"); + } + return 0; +} + +const struct seq_operations nr_node_seqops = { + .start = nr_node_start, + .next = nr_node_next, + .stop = nr_node_stop, + .show = nr_node_show, +}; + +static void *nr_neigh_start(struct seq_file *seq, loff_t *pos) + __acquires(&nr_neigh_list_lock) +{ + spin_lock_bh(&nr_neigh_list_lock); + return seq_hlist_start_head(&nr_neigh_list, *pos); +} + +static void *nr_neigh_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &nr_neigh_list, pos); +} + +static void nr_neigh_stop(struct seq_file *seq, void *v) + __releases(&nr_neigh_list_lock) +{ + spin_unlock_bh(&nr_neigh_list_lock); +} + +static int nr_neigh_show(struct seq_file *seq, void *v) +{ + char buf[11]; + int i; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, "addr callsign dev qual lock count failed digipeaters\n"); + else { + struct nr_neigh *nr_neigh; + + nr_neigh = hlist_entry(v, struct nr_neigh, neigh_node); + seq_printf(seq, "%05d %-9s %-4s %3d %d %3d %3d", + nr_neigh->number, + ax2asc(buf, &nr_neigh->callsign), + nr_neigh->dev ? nr_neigh->dev->name : "???", + nr_neigh->quality, + nr_neigh->locked, + nr_neigh->count, + nr_neigh->failed); + + if (nr_neigh->digipeat != NULL) { + for (i = 0; i < nr_neigh->digipeat->ndigi; i++) + seq_printf(seq, " %s", + ax2asc(buf, &nr_neigh->digipeat->calls[i])); + } + + seq_puts(seq, "\n"); + } + return 0; +} + +const struct seq_operations nr_neigh_seqops = { + .start = nr_neigh_start, + .next = nr_neigh_next, + .stop = nr_neigh_stop, + .show = nr_neigh_show, +}; +#endif + +/* + * Free all memory associated with the nodes and routes lists. + */ +void nr_rt_free(void) +{ + struct nr_neigh *s = NULL; + struct nr_node *t = NULL; + struct hlist_node *nodet; + + spin_lock_bh(&nr_neigh_list_lock); + spin_lock_bh(&nr_node_list_lock); + nr_node_for_each_safe(t, nodet, &nr_node_list) { + nr_node_lock(t); + nr_remove_node_locked(t); + nr_node_unlock(t); + } + nr_neigh_for_each_safe(s, nodet, &nr_neigh_list) { + while(s->count) { + s->count--; + nr_neigh_put(s); + } + nr_remove_neigh_locked(s); + } + spin_unlock_bh(&nr_node_list_lock); + spin_unlock_bh(&nr_neigh_list_lock); +} diff --git a/net/netrom/nr_subr.c b/net/netrom/nr_subr.c new file mode 100644 index 000000000..e2d2af924 --- /dev/null +++ b/net/netrom/nr_subr.c @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This routine purges all of the queues of frames. + */ +void nr_clear_queues(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + skb_queue_purge(&sk->sk_write_queue); + skb_queue_purge(&nr->ack_queue); + skb_queue_purge(&nr->reseq_queue); + skb_queue_purge(&nr->frag_queue); +} + +/* + * This routine purges the input queue of those frames that have been + * acknowledged. This replaces the boxes labelled "V(a) <- N(r)" on the + * SDL diagram. 
+ */ +void nr_frames_acked(struct sock *sk, unsigned short nr) +{ + struct nr_sock *nrom = nr_sk(sk); + struct sk_buff *skb; + + /* + * Remove all the ack-ed frames from the ack queue. + */ + if (nrom->va != nr) { + while (skb_peek(&nrom->ack_queue) != NULL && nrom->va != nr) { + skb = skb_dequeue(&nrom->ack_queue); + kfree_skb(skb); + nrom->va = (nrom->va + 1) % NR_MODULUS; + } + } +} + +/* + * Requeue all the un-ack-ed frames on the output queue to be picked + * up by nr_kick called from the timer. This arrangement handles the + * possibility of an empty output queue. + */ +void nr_requeue_frames(struct sock *sk) +{ + struct sk_buff *skb, *skb_prev = NULL; + + while ((skb = skb_dequeue(&nr_sk(sk)->ack_queue)) != NULL) { + if (skb_prev == NULL) + skb_queue_head(&sk->sk_write_queue, skb); + else + skb_append(skb_prev, skb, &sk->sk_write_queue); + skb_prev = skb; + } +} + +/* + * Validate that the value of nr is between va and vs. Return true or + * false for testing. + */ +int nr_validate_nr(struct sock *sk, unsigned short nr) +{ + struct nr_sock *nrom = nr_sk(sk); + unsigned short vc = nrom->va; + + while (vc != nrom->vs) { + if (nr == vc) return 1; + vc = (vc + 1) % NR_MODULUS; + } + + return nr == nrom->vs; +} + +/* + * Check that ns is within the receive window. + */ +int nr_in_rx_window(struct sock *sk, unsigned short ns) +{ + struct nr_sock *nr = nr_sk(sk); + unsigned short vc = nr->vr; + unsigned short vt = (nr->vl + nr->window) % NR_MODULUS; + + while (vc != vt) { + if (ns == vc) return 1; + vc = (vc + 1) % NR_MODULUS; + } + + return 0; +} + +/* + * This routine is called when the HDLC layer internally generates a + * control frame. + */ +void nr_write_internal(struct sock *sk, int frametype) +{ + struct nr_sock *nr = nr_sk(sk); + struct sk_buff *skb; + unsigned char *dptr; + int len, timeout; + + len = NR_TRANSPORT_LEN; + + switch (frametype & 0x0F) { + case NR_CONNREQ: + len += 17; + break; + case NR_CONNACK: + len += (nr->bpqext) ? 
2 : 1; + break; + case NR_DISCREQ: + case NR_DISCACK: + case NR_INFOACK: + break; + default: + printk(KERN_ERR "NET/ROM: nr_write_internal - invalid frame type %d\n", frametype); + return; + } + + skb = alloc_skb(NR_NETWORK_LEN + len, GFP_ATOMIC); + if (!skb) + return; + + /* + * Space for AX.25 and NET/ROM network header + */ + skb_reserve(skb, NR_NETWORK_LEN); + + dptr = skb_put(skb, len); + + switch (frametype & 0x0F) { + case NR_CONNREQ: + timeout = nr->t1 / HZ; + *dptr++ = nr->my_index; + *dptr++ = nr->my_id; + *dptr++ = 0; + *dptr++ = 0; + *dptr++ = frametype; + *dptr++ = nr->window; + memcpy(dptr, &nr->user_addr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] &= ~AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + memcpy(dptr, &nr->source_addr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] &= ~AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + *dptr++ = timeout % 256; + *dptr++ = timeout / 256; + break; + + case NR_CONNACK: + *dptr++ = nr->your_index; + *dptr++ = nr->your_id; + *dptr++ = nr->my_index; + *dptr++ = nr->my_id; + *dptr++ = frametype; + *dptr++ = nr->window; + if (nr->bpqext) *dptr++ = sysctl_netrom_network_ttl_initialiser; + break; + + case NR_DISCREQ: + case NR_DISCACK: + *dptr++ = nr->your_index; + *dptr++ = nr->your_id; + *dptr++ = 0; + *dptr++ = 0; + *dptr++ = frametype; + break; + + case NR_INFOACK: + *dptr++ = nr->your_index; + *dptr++ = nr->your_id; + *dptr++ = 0; + *dptr++ = nr->vr; + *dptr++ = frametype; + break; + } + + nr_transmit_buffer(sk, skb); +} + +/* + * This routine is called to send an error reply. + */ +void __nr_transmit_reply(struct sk_buff *skb, int mine, unsigned char cmdflags) +{ + struct sk_buff *skbn; + unsigned char *dptr; + int len; + + len = NR_NETWORK_LEN + NR_TRANSPORT_LEN + 1; + + if ((skbn = alloc_skb(len, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skbn, 0); + + dptr = skb_put(skbn, NR_NETWORK_LEN + NR_TRANSPORT_LEN); + + skb_copy_from_linear_data_offset(skb, 7, dptr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] &= ~AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + + skb_copy_from_linear_data(skb, dptr, AX25_ADDR_LEN); + dptr[6] &= ~AX25_CBIT; + dptr[6] |= AX25_EBIT; + dptr[6] |= AX25_SSSID_SPARE; + dptr += AX25_ADDR_LEN; + + *dptr++ = sysctl_netrom_network_ttl_initialiser; + + if (mine) { + *dptr++ = 0; + *dptr++ = 0; + *dptr++ = skb->data[15]; + *dptr++ = skb->data[16]; + } else { + *dptr++ = skb->data[15]; + *dptr++ = skb->data[16]; + *dptr++ = 0; + *dptr++ = 0; + } + + *dptr++ = cmdflags; + *dptr++ = 0; + + if (!nr_route_frame(skbn, NULL)) + kfree_skb(skbn); +} + +void nr_disconnect(struct sock *sk, int reason) +{ + nr_stop_t1timer(sk); + nr_stop_t2timer(sk); + nr_stop_t4timer(sk); + nr_stop_idletimer(sk); + + nr_clear_queues(sk); + + nr_sk(sk)->state = NR_STATE_0; + + sk->sk_state = TCP_CLOSE; + sk->sk_err = reason; + sk->sk_shutdown |= SEND_SHUTDOWN; + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } +} diff --git a/net/netrom/nr_timer.c b/net/netrom/nr_timer.c new file mode 100644 index 000000000..4e7c968cd --- /dev/null +++ b/net/netrom/nr_timer.c @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) 2002 Ralf Baechle DO1GRB (ralf@gnu.org) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include + +static void nr_heartbeat_expiry(struct timer_list *); +static void nr_t1timer_expiry(struct timer_list *); +static void nr_t2timer_expiry(struct timer_list *); +static void nr_t4timer_expiry(struct timer_list *); +static void nr_idletimer_expiry(struct timer_list *); + +void nr_init_timers(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + timer_setup(&nr->t1timer, nr_t1timer_expiry, 0); + timer_setup(&nr->t2timer, nr_t2timer_expiry, 0); + timer_setup(&nr->t4timer, nr_t4timer_expiry, 0); + timer_setup(&nr->idletimer, nr_idletimer_expiry, 0); + + /* initialized by sock_init_data */ + sk->sk_timer.function = nr_heartbeat_expiry; +} + +void nr_start_t1timer(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + sk_reset_timer(sk, &nr->t1timer, jiffies + nr->t1); +} + +void nr_start_t2timer(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + sk_reset_timer(sk, &nr->t2timer, jiffies + nr->t2); +} + +void nr_start_t4timer(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + sk_reset_timer(sk, &nr->t4timer, jiffies + nr->t4); +} + +void nr_start_idletimer(struct sock *sk) +{ + struct nr_sock *nr = nr_sk(sk); + + if (nr->idle > 0) + sk_reset_timer(sk, &nr->idletimer, jiffies + nr->idle); +} + +void nr_start_heartbeat(struct sock *sk) +{ + sk_reset_timer(sk, &sk->sk_timer, jiffies + 5 * HZ); +} + +void nr_stop_t1timer(struct sock *sk) +{ + sk_stop_timer(sk, &nr_sk(sk)->t1timer); +} + +void nr_stop_t2timer(struct sock *sk) +{ + sk_stop_timer(sk, &nr_sk(sk)->t2timer); +} + +void nr_stop_t4timer(struct sock *sk) +{ + sk_stop_timer(sk, &nr_sk(sk)->t4timer); +} + +void nr_stop_idletimer(struct sock *sk) +{ + sk_stop_timer(sk, &nr_sk(sk)->idletimer); +} + +void nr_stop_heartbeat(struct sock *sk) +{ + sk_stop_timer(sk, &sk->sk_timer); +} + +int nr_t1timer_running(struct sock *sk) +{ + return timer_pending(&nr_sk(sk)->t1timer); +} + +static void nr_heartbeat_expiry(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + struct nr_sock *nr = nr_sk(sk); + + bh_lock_sock(sk); + switch (nr->state) { + case NR_STATE_0: + /* Magic here: If we listen() and a new link dies before it + is accepted() it isn't 'dead' so doesn't get removed. */ + if (sock_flag(sk, SOCK_DESTROY) || + (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_DEAD))) { + sock_hold(sk); + bh_unlock_sock(sk); + nr_destroy_socket(sk); + goto out; + } + break; + + case NR_STATE_3: + /* + * Check for the state of the receive buffer. 
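+ * Once the receive queue drains below half of sk_rcvbuf, clear our own
+ * RX-busy condition and send an NR_INFOACK so the peer may resume sending.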
+ */ + if (atomic_read(&sk->sk_rmem_alloc) < (sk->sk_rcvbuf / 2) && + (nr->condition & NR_COND_OWN_RX_BUSY)) { + nr->condition &= ~NR_COND_OWN_RX_BUSY; + nr->condition &= ~NR_COND_ACK_PENDING; + nr->vl = nr->vr; + nr_write_internal(sk, NR_INFOACK); + break; + } + break; + } + + nr_start_heartbeat(sk); + bh_unlock_sock(sk); +out: + sock_put(sk); +} + +static void nr_t2timer_expiry(struct timer_list *t) +{ + struct nr_sock *nr = from_timer(nr, t, t2timer); + struct sock *sk = &nr->sock; + + bh_lock_sock(sk); + if (nr->condition & NR_COND_ACK_PENDING) { + nr->condition &= ~NR_COND_ACK_PENDING; + nr_enquiry_response(sk); + } + bh_unlock_sock(sk); + sock_put(sk); +} + +static void nr_t4timer_expiry(struct timer_list *t) +{ + struct nr_sock *nr = from_timer(nr, t, t4timer); + struct sock *sk = &nr->sock; + + bh_lock_sock(sk); + nr_sk(sk)->condition &= ~NR_COND_PEER_RX_BUSY; + bh_unlock_sock(sk); + sock_put(sk); +} + +static void nr_idletimer_expiry(struct timer_list *t) +{ + struct nr_sock *nr = from_timer(nr, t, idletimer); + struct sock *sk = &nr->sock; + + bh_lock_sock(sk); + + nr_clear_queues(sk); + + nr->n2count = 0; + nr_write_internal(sk, NR_DISCREQ); + nr->state = NR_STATE_2; + + nr_start_t1timer(sk); + nr_stop_t2timer(sk); + nr_stop_t4timer(sk); + + sk->sk_state = TCP_CLOSE; + sk->sk_err = 0; + sk->sk_shutdown |= SEND_SHUTDOWN; + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } + bh_unlock_sock(sk); + sock_put(sk); +} + +static void nr_t1timer_expiry(struct timer_list *t) +{ + struct nr_sock *nr = from_timer(nr, t, t1timer); + struct sock *sk = &nr->sock; + + bh_lock_sock(sk); + switch (nr->state) { + case NR_STATE_1: + if (nr->n2count == nr->n2) { + nr_disconnect(sk, ETIMEDOUT); + goto out; + } else { + nr->n2count++; + nr_write_internal(sk, NR_CONNREQ); + } + break; + + case NR_STATE_2: + if (nr->n2count == nr->n2) { + nr_disconnect(sk, ETIMEDOUT); + goto out; + } else { + nr->n2count++; + nr_write_internal(sk, NR_DISCREQ); + } + break; + + case NR_STATE_3: + if (nr->n2count == nr->n2) { + nr_disconnect(sk, ETIMEDOUT); + goto out; + } else { + nr->n2count++; + nr_requeue_frames(sk); + } + break; + } + + nr_start_t1timer(sk); +out: + bh_unlock_sock(sk); + sock_put(sk); +} diff --git a/net/netrom/sysctl_net_netrom.c b/net/netrom/sysctl_net_netrom.c new file mode 100644 index 000000000..79fb2d3f4 --- /dev/null +++ b/net/netrom/sysctl_net_netrom.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) 1996 Mike Shaver (shaver@zeroknowledge.com) + */ +#include +#include +#include +#include +#include + +/* + * Values taken from NET/ROM documentation. 
+ */ +static int min_quality[] = {0}, max_quality[] = {255}; +static int min_obs[] = {0}, max_obs[] = {255}; +static int min_ttl[] = {0}, max_ttl[] = {255}; +static int min_t1[] = {5 * HZ}; +static int max_t1[] = {600 * HZ}; +static int min_n2[] = {2}, max_n2[] = {127}; +static int min_t2[] = {1 * HZ}; +static int max_t2[] = {60 * HZ}; +static int min_t4[] = {1 * HZ}; +static int max_t4[] = {1000 * HZ}; +static int min_window[] = {1}, max_window[] = {127}; +static int min_idle[] = {0 * HZ}; +static int max_idle[] = {65535 * HZ}; +static int min_route[] = {0}, max_route[] = {1}; +static int min_fails[] = {1}, max_fails[] = {10}; +static int min_reset[] = {0}, max_reset[] = {1}; + +static struct ctl_table_header *nr_table_header; + +static struct ctl_table nr_table[] = { + { + .procname = "default_path_quality", + .data = &sysctl_netrom_default_path_quality, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_quality, + .extra2 = &max_quality + }, + { + .procname = "obsolescence_count_initialiser", + .data = &sysctl_netrom_obsolescence_count_initialiser, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_obs, + .extra2 = &max_obs + }, + { + .procname = "network_ttl_initialiser", + .data = &sysctl_netrom_network_ttl_initialiser, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ttl, + .extra2 = &max_ttl + }, + { + .procname = "transport_timeout", + .data = &sysctl_netrom_transport_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t1, + .extra2 = &max_t1 + }, + { + .procname = "transport_maximum_tries", + .data = &sysctl_netrom_transport_maximum_tries, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_n2, + .extra2 = &max_n2 + }, + { + .procname = "transport_acknowledge_delay", + .data = &sysctl_netrom_transport_acknowledge_delay, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t2, + .extra2 = &max_t2 + }, + { + .procname = "transport_busy_delay", + .data = &sysctl_netrom_transport_busy_delay, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_t4, + .extra2 = &max_t4 + }, + { + .procname = "transport_requested_window_size", + .data = &sysctl_netrom_transport_requested_window_size, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_window, + .extra2 = &max_window + }, + { + .procname = "transport_no_activity_timeout", + .data = &sysctl_netrom_transport_no_activity_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_idle, + .extra2 = &max_idle + }, + { + .procname = "routing_control", + .data = &sysctl_netrom_routing_control, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_route, + .extra2 = &max_route + }, + { + .procname = "link_fails_count", + .data = &sysctl_netrom_link_fails_count, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_fails, + .extra2 = &max_fails + }, + { + .procname = "reset", + .data = &sysctl_netrom_reset_circuit, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_reset, + .extra2 = &max_reset + }, + { } +}; + +int __init nr_register_sysctl(void) +{ + nr_table_header = 
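/*
 * The timer-related entries above are plain integers interpreted as jiffies,
 * which is why their bounds are expressed as multiples of HZ (e.g. the
 * transport timeout T1 is clamped to 5*HZ..600*HZ).  Tuning sketch from user
 * space, assuming the table registered here appears under
 * /proc/sys/net/netrom as usual:
 *
 *   # 12 second T1 on a HZ=100 kernel (1200 jiffies)
 *   echo 1200 > /proc/sys/net/netrom/transport_timeout
 */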
register_net_sysctl(&init_net, "net/netrom", nr_table); + if (!nr_table_header) + return -ENOMEM; + return 0; +} + +void nr_unregister_sysctl(void) +{ + unregister_net_sysctl_table(nr_table_header); +} diff --git a/net/nfc/Kconfig b/net/nfc/Kconfig new file mode 100644 index 000000000..466a0279b --- /dev/null +++ b/net/nfc/Kconfig @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# NFC subsystem configuration +# + +menuconfig NFC + depends on RFKILL || !RFKILL + tristate "NFC subsystem support" + default n + help + Say Y here if you want to build support for NFC (Near field + communication) devices. + + To compile this support as a module, choose M here: the module will + be called nfc. + +config NFC_DIGITAL + depends on NFC + select CRC_CCITT + select CRC_ITU_T + tristate "NFC Digital Protocol stack support" + default n + help + Say Y if you want to build NFC digital protocol stack support. + This is needed by NFC chipsets whose firmware only implement + the NFC analog layer. + + To compile this support as a module, choose M here: the module will + be called nfc_digital. + +source "net/nfc/nci/Kconfig" +source "net/nfc/hci/Kconfig" + +source "drivers/nfc/Kconfig" diff --git a/net/nfc/Makefile b/net/nfc/Makefile new file mode 100644 index 000000000..2ffc69b47 --- /dev/null +++ b/net/nfc/Makefile @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux NFC subsystem. +# + +obj-$(CONFIG_NFC) += nfc.o +obj-$(CONFIG_NFC_NCI) += nci/ +obj-$(CONFIG_NFC_HCI) += hci/ +obj-$(CONFIG_NFC_DIGITAL) += nfc_digital.o + +nfc-objs := core.o netlink.o af_nfc.o rawsock.o llcp_core.o llcp_commands.o \ + llcp_sock.o + +nfc_digital-objs := digital_core.o digital_technology.o digital_dep.o diff --git a/net/nfc/af_nfc.c b/net/nfc/af_nfc.c new file mode 100644 index 000000000..dda323e0a --- /dev/null +++ b/net/nfc/af_nfc.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Instituto Nokia de Tecnologia + * + * Authors: + * Aloisio Almeida Jr + * Lauro Ramos Venancio + */ + +#include +#include + +#include "nfc.h" + +static DEFINE_RWLOCK(proto_tab_lock); +static const struct nfc_protocol *proto_tab[NFC_SOCKPROTO_MAX]; + +static int nfc_sock_create(struct net *net, struct socket *sock, int proto, + int kern) +{ + int rc = -EPROTONOSUPPORT; + + if (net != &init_net) + return -EAFNOSUPPORT; + + if (proto < 0 || proto >= NFC_SOCKPROTO_MAX) + return -EINVAL; + + read_lock(&proto_tab_lock); + if (proto_tab[proto] && try_module_get(proto_tab[proto]->owner)) { + rc = proto_tab[proto]->create(net, sock, proto_tab[proto], kern); + module_put(proto_tab[proto]->owner); + } + read_unlock(&proto_tab_lock); + + return rc; +} + +static const struct net_proto_family nfc_sock_family_ops = { + .owner = THIS_MODULE, + .family = PF_NFC, + .create = nfc_sock_create, +}; + +int nfc_proto_register(const struct nfc_protocol *nfc_proto) +{ + int rc; + + if (nfc_proto->id < 0 || nfc_proto->id >= NFC_SOCKPROTO_MAX) + return -EINVAL; + + rc = proto_register(nfc_proto->proto, 0); + if (rc) + return rc; + + write_lock(&proto_tab_lock); + if (proto_tab[nfc_proto->id]) + rc = -EBUSY; + else + proto_tab[nfc_proto->id] = nfc_proto; + write_unlock(&proto_tab_lock); + + if (rc) + proto_unregister(nfc_proto->proto); + + return rc; +} +EXPORT_SYMBOL(nfc_proto_register); + +void nfc_proto_unregister(const struct nfc_protocol *nfc_proto) +{ + write_lock(&proto_tab_lock); + proto_tab[nfc_proto->id] = NULL; + write_unlock(&proto_tab_lock); + + proto_unregister(nfc_proto->proto); +} 
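/*
 * Usage sketch for the registration API above, modelled on the raw socket
 * protocol; the names are illustrative:
 *
 *   static const struct nfc_protocol rawsock_nfc_proto = {
 *           .id     = NFC_SOCKPROTO_RAW,
 *           .proto  = &rawsock_proto,
 *           .owner  = THIS_MODULE,
 *           .create = rawsock_create,
 *   };
 *
 *   rc = nfc_proto_register(&rawsock_nfc_proto);
 *   ...
 *   nfc_proto_unregister(&rawsock_nfc_proto);
 *
 * nfc_sock_create() then dispatches socket(AF_NFC, ..., proto) calls to the
 * matching ->create() handler while holding proto_tab_lock for reading.
 */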
+EXPORT_SYMBOL(nfc_proto_unregister); + +int __init af_nfc_init(void) +{ + return sock_register(&nfc_sock_family_ops); +} + +void __exit af_nfc_exit(void) +{ + sock_unregister(PF_NFC); +} diff --git a/net/nfc/core.c b/net/nfc/core.c new file mode 100644 index 000000000..eb2c0959e --- /dev/null +++ b/net/nfc/core.c @@ -0,0 +1,1247 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Instituto Nokia de Tecnologia + * + * Authors: + * Lauro Ramos Venancio + * Aloisio Almeida Jr + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include + +#include + +#include "nfc.h" + +#define VERSION "0.1" + +#define NFC_CHECK_PRES_FREQ_MS 2000 + +int nfc_devlist_generation; +DEFINE_MUTEX(nfc_devlist_mutex); + +/* NFC device ID bitmap */ +static DEFINE_IDA(nfc_index_ida); + +int nfc_fw_download(struct nfc_dev *dev, const char *firmware_name) +{ + int rc = 0; + + pr_debug("%s do firmware %s\n", dev_name(&dev->dev), firmware_name); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->dev_up) { + rc = -EBUSY; + goto error; + } + + if (!dev->ops->fw_download) { + rc = -EOPNOTSUPP; + goto error; + } + + dev->fw_download_in_progress = true; + rc = dev->ops->fw_download(dev, firmware_name); + if (rc) + dev->fw_download_in_progress = false; + +error: + device_unlock(&dev->dev); + return rc; +} + +/** + * nfc_fw_download_done - inform that a firmware download was completed + * + * @dev: The nfc device to which firmware was downloaded + * @firmware_name: The firmware filename + * @result: The positive value of a standard errno value + */ +int nfc_fw_download_done(struct nfc_dev *dev, const char *firmware_name, + u32 result) +{ + dev->fw_download_in_progress = false; + + return nfc_genl_fw_download_done(dev, firmware_name, result); +} +EXPORT_SYMBOL(nfc_fw_download_done); + +/** + * nfc_dev_up - turn on the NFC device + * + * @dev: The nfc device to be turned on + * + * The device remains up until the nfc_dev_down function is called. 
+ */ +int nfc_dev_up(struct nfc_dev *dev) +{ + int rc = 0; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->rfkill && rfkill_blocked(dev->rfkill)) { + rc = -ERFKILL; + goto error; + } + + if (dev->fw_download_in_progress) { + rc = -EBUSY; + goto error; + } + + if (dev->dev_up) { + rc = -EALREADY; + goto error; + } + + if (dev->ops->dev_up) + rc = dev->ops->dev_up(dev); + + if (!rc) + dev->dev_up = true; + + /* We have to enable the device before discovering SEs */ + if (dev->ops->discover_se && dev->ops->discover_se(dev)) + pr_err("SE discovery failed\n"); + +error: + device_unlock(&dev->dev); + return rc; +} + +/** + * nfc_dev_down - turn off the NFC device + * + * @dev: The nfc device to be turned off + */ +int nfc_dev_down(struct nfc_dev *dev) +{ + int rc = 0; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (!dev->dev_up) { + rc = -EALREADY; + goto error; + } + + if (dev->polling || dev->active_target) { + rc = -EBUSY; + goto error; + } + + if (dev->ops->dev_down) + dev->ops->dev_down(dev); + + dev->dev_up = false; + +error: + device_unlock(&dev->dev); + return rc; +} + +static int nfc_rfkill_set_block(void *data, bool blocked) +{ + struct nfc_dev *dev = data; + + pr_debug("%s blocked %d", dev_name(&dev->dev), blocked); + + if (!blocked) + return 0; + + nfc_dev_down(dev); + + return 0; +} + +static const struct rfkill_ops nfc_rfkill_ops = { + .set_block = nfc_rfkill_set_block, +}; + +/** + * nfc_start_poll - start polling for nfc targets + * + * @dev: The nfc device that must start polling + * @im_protocols: bitset of nfc initiator protocols to be used for polling + * @tm_protocols: bitset of nfc transport protocols to be used for polling + * + * The device remains polling for targets until a target is found or + * the nfc_stop_poll function is called. 
+ */ +int nfc_start_poll(struct nfc_dev *dev, u32 im_protocols, u32 tm_protocols) +{ + int rc; + + pr_debug("dev_name %s initiator protocols 0x%x target protocols 0x%x\n", + dev_name(&dev->dev), im_protocols, tm_protocols); + + if (!im_protocols && !tm_protocols) + return -EINVAL; + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (!dev->dev_up) { + rc = -ENODEV; + goto error; + } + + if (dev->polling) { + rc = -EBUSY; + goto error; + } + + rc = dev->ops->start_poll(dev, im_protocols, tm_protocols); + if (!rc) { + dev->polling = true; + dev->rf_mode = NFC_RF_NONE; + } + +error: + device_unlock(&dev->dev); + return rc; +} + +/** + * nfc_stop_poll - stop polling for nfc targets + * + * @dev: The nfc device that must stop polling + */ +int nfc_stop_poll(struct nfc_dev *dev) +{ + int rc = 0; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (!dev->polling) { + rc = -EINVAL; + goto error; + } + + dev->ops->stop_poll(dev); + dev->polling = false; + dev->rf_mode = NFC_RF_NONE; + +error: + device_unlock(&dev->dev); + return rc; +} + +static struct nfc_target *nfc_find_target(struct nfc_dev *dev, u32 target_idx) +{ + int i; + + for (i = 0; i < dev->n_targets; i++) { + if (dev->targets[i].idx == target_idx) + return &dev->targets[i]; + } + + return NULL; +} + +int nfc_dep_link_up(struct nfc_dev *dev, int target_index, u8 comm_mode) +{ + int rc = 0; + u8 *gb; + size_t gb_len; + struct nfc_target *target; + + pr_debug("dev_name=%s comm %d\n", dev_name(&dev->dev), comm_mode); + + if (!dev->ops->dep_link_up) + return -EOPNOTSUPP; + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->dep_link_up == true) { + rc = -EALREADY; + goto error; + } + + gb = nfc_llcp_general_bytes(dev, &gb_len); + if (gb_len > NFC_MAX_GT_LEN) { + rc = -EINVAL; + goto error; + } + + target = nfc_find_target(dev, target_index); + if (target == NULL) { + rc = -ENOTCONN; + goto error; + } + + rc = dev->ops->dep_link_up(dev, target, comm_mode, gb, gb_len); + if (!rc) { + dev->active_target = target; + dev->rf_mode = NFC_RF_INITIATOR; + } + +error: + device_unlock(&dev->dev); + return rc; +} + +int nfc_dep_link_down(struct nfc_dev *dev) +{ + int rc = 0; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + if (!dev->ops->dep_link_down) + return -EOPNOTSUPP; + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->dep_link_up == false) { + rc = -EALREADY; + goto error; + } + + rc = dev->ops->dep_link_down(dev); + if (!rc) { + dev->dep_link_up = false; + dev->active_target = NULL; + dev->rf_mode = NFC_RF_NONE; + nfc_llcp_mac_is_down(dev); + nfc_genl_dep_link_down_event(dev); + } + +error: + device_unlock(&dev->dev); + + return rc; +} + +int nfc_dep_link_is_up(struct nfc_dev *dev, u32 target_idx, + u8 comm_mode, u8 rf_mode) +{ + dev->dep_link_up = true; + + if (!dev->active_target && rf_mode == NFC_RF_INITIATOR) { + struct nfc_target *target; + + target = nfc_find_target(dev, target_idx); + if (target == NULL) + return -ENOTCONN; + + dev->active_target = target; + } + + dev->polling = false; + dev->rf_mode = rf_mode; + + nfc_llcp_mac_is_up(dev, target_idx, comm_mode, rf_mode); + + return nfc_genl_dep_link_up_event(dev, target_idx, comm_mode, rf_mode); +} +EXPORT_SYMBOL(nfc_dep_link_is_up); + +/** + * nfc_activate_target - prepare the target for data exchange + * + * @dev: The nfc device that 
found the target + * @target_idx: index of the target that must be activated + * @protocol: nfc protocol that will be used for data exchange + */ +int nfc_activate_target(struct nfc_dev *dev, u32 target_idx, u32 protocol) +{ + int rc; + struct nfc_target *target; + + pr_debug("dev_name=%s target_idx=%u protocol=%u\n", + dev_name(&dev->dev), target_idx, protocol); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->active_target) { + rc = -EBUSY; + goto error; + } + + target = nfc_find_target(dev, target_idx); + if (target == NULL) { + rc = -ENOTCONN; + goto error; + } + + rc = dev->ops->activate_target(dev, target, protocol); + if (!rc) { + dev->active_target = target; + dev->rf_mode = NFC_RF_INITIATOR; + + if (dev->ops->check_presence && !dev->shutting_down) + mod_timer(&dev->check_pres_timer, jiffies + + msecs_to_jiffies(NFC_CHECK_PRES_FREQ_MS)); + } + +error: + device_unlock(&dev->dev); + return rc; +} + +/** + * nfc_deactivate_target - deactivate a nfc target + * + * @dev: The nfc device that found the target + * @target_idx: index of the target that must be deactivated + * @mode: idle or sleep? + */ +int nfc_deactivate_target(struct nfc_dev *dev, u32 target_idx, u8 mode) +{ + int rc = 0; + + pr_debug("dev_name=%s target_idx=%u\n", + dev_name(&dev->dev), target_idx); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (dev->active_target == NULL) { + rc = -ENOTCONN; + goto error; + } + + if (dev->active_target->idx != target_idx) { + rc = -ENOTCONN; + goto error; + } + + if (dev->ops->check_presence) + del_timer_sync(&dev->check_pres_timer); + + dev->ops->deactivate_target(dev, dev->active_target, mode); + dev->active_target = NULL; + +error: + device_unlock(&dev->dev); + return rc; +} + +/** + * nfc_data_exchange - transceive data + * + * @dev: The nfc device that found the target + * @target_idx: index of the target + * @skb: data to be sent + * @cb: callback called when the response is received + * @cb_context: parameter for the callback function + * + * The user must wait for the callback before calling this function again. 
+ */ +int nfc_data_exchange(struct nfc_dev *dev, u32 target_idx, struct sk_buff *skb, + data_exchange_cb_t cb, void *cb_context) +{ + int rc; + + pr_debug("dev_name=%s target_idx=%u skb->len=%u\n", + dev_name(&dev->dev), target_idx, skb->len); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + kfree_skb(skb); + goto error; + } + + if (dev->rf_mode == NFC_RF_INITIATOR && dev->active_target != NULL) { + if (dev->active_target->idx != target_idx) { + rc = -EADDRNOTAVAIL; + kfree_skb(skb); + goto error; + } + + if (dev->ops->check_presence) + del_timer_sync(&dev->check_pres_timer); + + rc = dev->ops->im_transceive(dev, dev->active_target, skb, cb, + cb_context); + + if (!rc && dev->ops->check_presence && !dev->shutting_down) + mod_timer(&dev->check_pres_timer, jiffies + + msecs_to_jiffies(NFC_CHECK_PRES_FREQ_MS)); + } else if (dev->rf_mode == NFC_RF_TARGET && dev->ops->tm_send != NULL) { + rc = dev->ops->tm_send(dev, skb); + } else { + rc = -ENOTCONN; + kfree_skb(skb); + goto error; + } + + +error: + device_unlock(&dev->dev); + return rc; +} + +struct nfc_se *nfc_find_se(struct nfc_dev *dev, u32 se_idx) +{ + struct nfc_se *se; + + list_for_each_entry(se, &dev->secure_elements, list) + if (se->idx == se_idx) + return se; + + return NULL; +} +EXPORT_SYMBOL(nfc_find_se); + +int nfc_enable_se(struct nfc_dev *dev, u32 se_idx) +{ + struct nfc_se *se; + int rc; + + pr_debug("%s se index %d\n", dev_name(&dev->dev), se_idx); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (!dev->dev_up) { + rc = -ENODEV; + goto error; + } + + if (dev->polling) { + rc = -EBUSY; + goto error; + } + + if (!dev->ops->enable_se || !dev->ops->disable_se) { + rc = -EOPNOTSUPP; + goto error; + } + + se = nfc_find_se(dev, se_idx); + if (!se) { + rc = -EINVAL; + goto error; + } + + if (se->state == NFC_SE_ENABLED) { + rc = -EALREADY; + goto error; + } + + rc = dev->ops->enable_se(dev, se_idx); + if (rc >= 0) + se->state = NFC_SE_ENABLED; + +error: + device_unlock(&dev->dev); + return rc; +} + +int nfc_disable_se(struct nfc_dev *dev, u32 se_idx) +{ + struct nfc_se *se; + int rc; + + pr_debug("%s se index %d\n", dev_name(&dev->dev), se_idx); + + device_lock(&dev->dev); + + if (dev->shutting_down) { + rc = -ENODEV; + goto error; + } + + if (!dev->dev_up) { + rc = -ENODEV; + goto error; + } + + if (!dev->ops->enable_se || !dev->ops->disable_se) { + rc = -EOPNOTSUPP; + goto error; + } + + se = nfc_find_se(dev, se_idx); + if (!se) { + rc = -EINVAL; + goto error; + } + + if (se->state == NFC_SE_DISABLED) { + rc = -EALREADY; + goto error; + } + + rc = dev->ops->disable_se(dev, se_idx); + if (rc >= 0) + se->state = NFC_SE_DISABLED; + +error: + device_unlock(&dev->dev); + return rc; +} + +int nfc_set_remote_general_bytes(struct nfc_dev *dev, const u8 *gb, u8 gb_len) +{ + pr_debug("dev_name=%s gb_len=%d\n", dev_name(&dev->dev), gb_len); + + return nfc_llcp_set_remote_gb(dev, gb, gb_len); +} +EXPORT_SYMBOL(nfc_set_remote_general_bytes); + +u8 *nfc_get_local_general_bytes(struct nfc_dev *dev, size_t *gb_len) +{ + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + return nfc_llcp_general_bytes(dev, gb_len); +} +EXPORT_SYMBOL(nfc_get_local_general_bytes); + +int nfc_tm_data_received(struct nfc_dev *dev, struct sk_buff *skb) +{ + /* Only LLCP target mode for now */ + if (dev->dep_link_up == false) { + kfree_skb(skb); + return -ENOLINK; + } + + return nfc_llcp_data_received(dev, skb); +} +EXPORT_SYMBOL(nfc_tm_data_received); + +int nfc_tm_activated(struct nfc_dev *dev, 
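/*
 * nfc_data_exchange() above is asynchronous: the skb is handed to the driver
 * (im_transceive in initiator mode, tm_send in target mode) and, in initiator
 * mode, the response is delivered through the data_exchange_cb_t callback.
 * Minimal caller-side sketch (hypothetical names):
 *
 *   static void my_recv_cb(void *context, struct sk_buff *resp, int err)
 *   {
 *           if (err) {
 *                   // resp is NULL on error, nothing to free here
 *                   return;
 *           }
 *           // consume resp, then kfree_skb(resp)
 *   }
 *
 *   rc = nfc_data_exchange(dev, target_idx, skb, my_recv_cb, my_ctx);
 */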
u32 protocol, u8 comm_mode, + const u8 *gb, size_t gb_len) +{ + int rc; + + device_lock(&dev->dev); + + dev->polling = false; + + if (gb != NULL) { + rc = nfc_set_remote_general_bytes(dev, gb, gb_len); + if (rc < 0) + goto out; + } + + dev->rf_mode = NFC_RF_TARGET; + + if (protocol == NFC_PROTO_NFC_DEP_MASK) + nfc_dep_link_is_up(dev, 0, comm_mode, NFC_RF_TARGET); + + rc = nfc_genl_tm_activated(dev, protocol); + +out: + device_unlock(&dev->dev); + + return rc; +} +EXPORT_SYMBOL(nfc_tm_activated); + +int nfc_tm_deactivated(struct nfc_dev *dev) +{ + dev->dep_link_up = false; + dev->rf_mode = NFC_RF_NONE; + + return nfc_genl_tm_deactivated(dev); +} +EXPORT_SYMBOL(nfc_tm_deactivated); + +/** + * nfc_alloc_send_skb - allocate a skb for data exchange responses + * + * @dev: device sending the response + * @sk: socket sending the response + * @flags: MSG_DONTWAIT flag + * @size: size to allocate + * @err: pointer to memory to store the error code + */ +struct sk_buff *nfc_alloc_send_skb(struct nfc_dev *dev, struct sock *sk, + unsigned int flags, unsigned int size, + unsigned int *err) +{ + struct sk_buff *skb; + unsigned int total_size; + + total_size = size + + dev->tx_headroom + dev->tx_tailroom + NFC_HEADER_SIZE; + + skb = sock_alloc_send_skb(sk, total_size, flags & MSG_DONTWAIT, err); + if (skb) + skb_reserve(skb, dev->tx_headroom + NFC_HEADER_SIZE); + + return skb; +} + +/** + * nfc_alloc_recv_skb - allocate a skb for data exchange responses + * + * @size: size to allocate + * @gfp: gfp flags + */ +struct sk_buff *nfc_alloc_recv_skb(unsigned int size, gfp_t gfp) +{ + struct sk_buff *skb; + unsigned int total_size; + + total_size = size + 1; + skb = alloc_skb(total_size, gfp); + + if (skb) + skb_reserve(skb, 1); + + return skb; +} +EXPORT_SYMBOL(nfc_alloc_recv_skb); + +/** + * nfc_targets_found - inform that targets were found + * + * @dev: The nfc device that found the targets + * @targets: array of nfc targets found + * @n_targets: targets array size + * + * The device driver must call this function when one or many nfc targets + * are found. After calling this function, the device driver must stop + * polling for targets. + * NOTE: This function can be called with targets=NULL and n_targets=0 to + * notify a driver error, meaning that the polling operation cannot complete. + * IMPORTANT: this function must not be called from an atomic context. + * In addition, it must also not be called from a context that would prevent + * the NFC Core to call other nfc ops entry point concurrently. 
+ */ +int nfc_targets_found(struct nfc_dev *dev, + struct nfc_target *targets, int n_targets) +{ + int i; + + pr_debug("dev_name=%s n_targets=%d\n", dev_name(&dev->dev), n_targets); + + for (i = 0; i < n_targets; i++) + targets[i].idx = dev->target_next_idx++; + + device_lock(&dev->dev); + + if (dev->polling == false) { + device_unlock(&dev->dev); + return 0; + } + + dev->polling = false; + + dev->targets_generation++; + + kfree(dev->targets); + dev->targets = NULL; + + if (targets) { + dev->targets = kmemdup(targets, + n_targets * sizeof(struct nfc_target), + GFP_ATOMIC); + + if (!dev->targets) { + dev->n_targets = 0; + device_unlock(&dev->dev); + return -ENOMEM; + } + } + + dev->n_targets = n_targets; + device_unlock(&dev->dev); + + nfc_genl_targets_found(dev); + + return 0; +} +EXPORT_SYMBOL(nfc_targets_found); + +/** + * nfc_target_lost - inform that an activated target went out of field + * + * @dev: The nfc device that had the activated target in field + * @target_idx: the nfc index of the target + * + * The device driver must call this function when the activated target + * goes out of the field. + * IMPORTANT: this function must not be called from an atomic context. + * In addition, it must also not be called from a context that would prevent + * the NFC Core to call other nfc ops entry point concurrently. + */ +int nfc_target_lost(struct nfc_dev *dev, u32 target_idx) +{ + const struct nfc_target *tg; + int i; + + pr_debug("dev_name %s n_target %d\n", dev_name(&dev->dev), target_idx); + + device_lock(&dev->dev); + + for (i = 0; i < dev->n_targets; i++) { + tg = &dev->targets[i]; + if (tg->idx == target_idx) + break; + } + + if (i == dev->n_targets) { + device_unlock(&dev->dev); + return -EINVAL; + } + + dev->targets_generation++; + dev->n_targets--; + dev->active_target = NULL; + + if (dev->n_targets) { + memcpy(&dev->targets[i], &dev->targets[i + 1], + (dev->n_targets - i) * sizeof(struct nfc_target)); + } else { + kfree(dev->targets); + dev->targets = NULL; + } + + device_unlock(&dev->dev); + + nfc_genl_target_lost(dev, target_idx); + + return 0; +} +EXPORT_SYMBOL(nfc_target_lost); + +inline void nfc_driver_failure(struct nfc_dev *dev, int err) +{ + nfc_targets_found(dev, NULL, 0); +} +EXPORT_SYMBOL(nfc_driver_failure); + +int nfc_add_se(struct nfc_dev *dev, u32 se_idx, u16 type) +{ + struct nfc_se *se; + int rc; + + pr_debug("%s se index %d\n", dev_name(&dev->dev), se_idx); + + se = nfc_find_se(dev, se_idx); + if (se) + return -EALREADY; + + se = kzalloc(sizeof(struct nfc_se), GFP_KERNEL); + if (!se) + return -ENOMEM; + + se->idx = se_idx; + se->type = type; + se->state = NFC_SE_DISABLED; + INIT_LIST_HEAD(&se->list); + + list_add(&se->list, &dev->secure_elements); + + rc = nfc_genl_se_added(dev, se_idx, type); + if (rc < 0) { + list_del(&se->list); + kfree(se); + + return rc; + } + + return 0; +} +EXPORT_SYMBOL(nfc_add_se); + +int nfc_remove_se(struct nfc_dev *dev, u32 se_idx) +{ + struct nfc_se *se, *n; + int rc; + + pr_debug("%s se index %d\n", dev_name(&dev->dev), se_idx); + + list_for_each_entry_safe(se, n, &dev->secure_elements, list) + if (se->idx == se_idx) { + rc = nfc_genl_se_removed(dev, se_idx); + if (rc < 0) + return rc; + + list_del(&se->list); + kfree(se); + + return 0; + } + + return -EINVAL; +} +EXPORT_SYMBOL(nfc_remove_se); + +int nfc_se_transaction(struct nfc_dev *dev, u8 se_idx, + struct nfc_evt_transaction *evt_transaction) +{ + int rc; + + pr_debug("transaction: %x\n", se_idx); + + device_lock(&dev->dev); + + if (!evt_transaction) { + rc = -EPROTO; + goto 
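/*
 * Driver-side sketch of reporting a discovered tag through
 * nfc_targets_found() above (values are illustrative):
 *
 *   struct nfc_target tgt = {
 *           .supported_protocols = NFC_PROTO_MIFARE_MASK,
 *   };
 *   // fill in the technology-specific fields (SENS_RES/SEL_RES/NFCID, ...)
 *   nfc_targets_found(nfc_dev, &tgt, 1);
 *
 * The core then clears its polling state, keeps a kmemdup()'d copy of the
 * array and emits a netlink event; nfc_target_lost() is the converse path
 * when an activated target leaves the field.
 */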
out; + } + + rc = nfc_genl_se_transaction(dev, se_idx, evt_transaction); +out: + device_unlock(&dev->dev); + return rc; +} +EXPORT_SYMBOL(nfc_se_transaction); + +int nfc_se_connectivity(struct nfc_dev *dev, u8 se_idx) +{ + int rc; + + pr_debug("connectivity: %x\n", se_idx); + + device_lock(&dev->dev); + rc = nfc_genl_se_connectivity(dev, se_idx); + device_unlock(&dev->dev); + return rc; +} +EXPORT_SYMBOL(nfc_se_connectivity); + +static void nfc_release(struct device *d) +{ + struct nfc_dev *dev = to_nfc_dev(d); + struct nfc_se *se, *n; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + nfc_genl_data_exit(&dev->genl_data); + kfree(dev->targets); + + list_for_each_entry_safe(se, n, &dev->secure_elements, list) { + nfc_genl_se_removed(dev, se->idx); + list_del(&se->list); + kfree(se); + } + + ida_free(&nfc_index_ida, dev->idx); + + kfree(dev); +} + +static void nfc_check_pres_work(struct work_struct *work) +{ + struct nfc_dev *dev = container_of(work, struct nfc_dev, + check_pres_work); + int rc; + + device_lock(&dev->dev); + + if (dev->active_target && timer_pending(&dev->check_pres_timer) == 0) { + rc = dev->ops->check_presence(dev, dev->active_target); + if (rc == -EOPNOTSUPP) + goto exit; + if (rc) { + u32 active_target_idx = dev->active_target->idx; + device_unlock(&dev->dev); + nfc_target_lost(dev, active_target_idx); + return; + } + + if (!dev->shutting_down) + mod_timer(&dev->check_pres_timer, jiffies + + msecs_to_jiffies(NFC_CHECK_PRES_FREQ_MS)); + } + +exit: + device_unlock(&dev->dev); +} + +static void nfc_check_pres_timeout(struct timer_list *t) +{ + struct nfc_dev *dev = from_timer(dev, t, check_pres_timer); + + schedule_work(&dev->check_pres_work); +} + +struct class nfc_class = { + .name = "nfc", + .dev_release = nfc_release, +}; +EXPORT_SYMBOL(nfc_class); + +static int match_idx(struct device *d, const void *data) +{ + struct nfc_dev *dev = to_nfc_dev(d); + const unsigned int *idx = data; + + return dev->idx == *idx; +} + +struct nfc_dev *nfc_get_device(unsigned int idx) +{ + struct device *d; + + d = class_find_device(&nfc_class, NULL, &idx, match_idx); + if (!d) + return NULL; + + return to_nfc_dev(d); +} + +/** + * nfc_allocate_device - allocate a new nfc device + * + * @ops: device operations + * @supported_protocols: NFC protocols supported by the device + * @tx_headroom: reserved space at beginning of skb + * @tx_tailroom: reserved space at end of skb + */ +struct nfc_dev *nfc_allocate_device(const struct nfc_ops *ops, + u32 supported_protocols, + int tx_headroom, int tx_tailroom) +{ + struct nfc_dev *dev; + int rc; + + if (!ops->start_poll || !ops->stop_poll || !ops->activate_target || + !ops->deactivate_target || !ops->im_transceive) + return NULL; + + if (!supported_protocols) + return NULL; + + dev = kzalloc(sizeof(struct nfc_dev), GFP_KERNEL); + if (!dev) + return NULL; + + rc = ida_alloc(&nfc_index_ida, GFP_KERNEL); + if (rc < 0) + goto err_free_dev; + dev->idx = rc; + + dev->dev.class = &nfc_class; + dev_set_name(&dev->dev, "nfc%d", dev->idx); + device_initialize(&dev->dev); + + dev->ops = ops; + dev->supported_protocols = supported_protocols; + dev->tx_headroom = tx_headroom; + dev->tx_tailroom = tx_tailroom; + INIT_LIST_HEAD(&dev->secure_elements); + + nfc_genl_data_init(&dev->genl_data); + + dev->rf_mode = NFC_RF_NONE; + + /* first generation must not be 0 */ + dev->targets_generation = 1; + + if (ops->check_presence) { + timer_setup(&dev->check_pres_timer, nfc_check_pres_timeout, 0); + INIT_WORK(&dev->check_pres_work, nfc_check_pres_work); + } + + return 
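/*
 * Typical driver lifecycle around the allocator above (sketch, error
 * handling elided; my_nfc_ops is a driver-provided const struct nfc_ops):
 *
 *   dev = nfc_allocate_device(&my_nfc_ops, protocols, headroom, tailroom);
 *   rc  = nfc_register_device(dev);
 *   ...
 *   nfc_unregister_device(dev);
 *   nfc_free_device(dev);    // drops the reference taken at allocation
 */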
dev; + +err_free_dev: + kfree(dev); + + return NULL; +} +EXPORT_SYMBOL(nfc_allocate_device); + +/** + * nfc_register_device - register a nfc device in the nfc subsystem + * + * @dev: The nfc device to register + */ +int nfc_register_device(struct nfc_dev *dev) +{ + int rc; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + mutex_lock(&nfc_devlist_mutex); + nfc_devlist_generation++; + rc = device_add(&dev->dev); + mutex_unlock(&nfc_devlist_mutex); + + if (rc < 0) + return rc; + + rc = nfc_llcp_register_device(dev); + if (rc) + pr_err("Could not register llcp device\n"); + + device_lock(&dev->dev); + dev->rfkill = rfkill_alloc(dev_name(&dev->dev), &dev->dev, + RFKILL_TYPE_NFC, &nfc_rfkill_ops, dev); + if (dev->rfkill) { + if (rfkill_register(dev->rfkill) < 0) { + rfkill_destroy(dev->rfkill); + dev->rfkill = NULL; + } + } + dev->shutting_down = false; + device_unlock(&dev->dev); + + rc = nfc_genl_device_added(dev); + if (rc) + pr_debug("The userspace won't be notified that the device %s was added\n", + dev_name(&dev->dev)); + + return 0; +} +EXPORT_SYMBOL(nfc_register_device); + +/** + * nfc_unregister_device - unregister a nfc device in the nfc subsystem + * + * @dev: The nfc device to unregister + */ +void nfc_unregister_device(struct nfc_dev *dev) +{ + int rc; + + pr_debug("dev_name=%s\n", dev_name(&dev->dev)); + + rc = nfc_genl_device_removed(dev); + if (rc) + pr_debug("The userspace won't be notified that the device %s " + "was removed\n", dev_name(&dev->dev)); + + device_lock(&dev->dev); + if (dev->rfkill) { + rfkill_unregister(dev->rfkill); + rfkill_destroy(dev->rfkill); + dev->rfkill = NULL; + } + dev->shutting_down = true; + device_unlock(&dev->dev); + + if (dev->ops->check_presence) { + del_timer_sync(&dev->check_pres_timer); + cancel_work_sync(&dev->check_pres_work); + } + + nfc_llcp_unregister_device(dev); + + mutex_lock(&nfc_devlist_mutex); + nfc_devlist_generation++; + device_del(&dev->dev); + mutex_unlock(&nfc_devlist_mutex); +} +EXPORT_SYMBOL(nfc_unregister_device); + +static int __init nfc_init(void) +{ + int rc; + + pr_info("NFC Core ver %s\n", VERSION); + + rc = class_register(&nfc_class); + if (rc) + return rc; + + rc = nfc_genl_init(); + if (rc) + goto err_genl; + + /* the first generation must not be 0 */ + nfc_devlist_generation = 1; + + rc = rawsock_init(); + if (rc) + goto err_rawsock; + + rc = nfc_llcp_init(); + if (rc) + goto err_llcp_sock; + + rc = af_nfc_init(); + if (rc) + goto err_af_nfc; + + return 0; + +err_af_nfc: + nfc_llcp_exit(); +err_llcp_sock: + rawsock_exit(); +err_rawsock: + nfc_genl_exit(); +err_genl: + class_unregister(&nfc_class); + return rc; +} + +static void __exit nfc_exit(void) +{ + af_nfc_exit(); + nfc_llcp_exit(); + rawsock_exit(); + nfc_genl_exit(); + class_unregister(&nfc_class); +} + +subsys_initcall(nfc_init); +module_exit(nfc_exit); + +MODULE_AUTHOR("Lauro Ramos Venancio "); +MODULE_DESCRIPTION("NFC Core ver " VERSION); +MODULE_VERSION(VERSION); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_NFC); +MODULE_ALIAS_GENL_FAMILY(NFC_GENL_NAME); diff --git a/net/nfc/digital.h b/net/nfc/digital.h new file mode 100644 index 000000000..c33b2f7e1 --- /dev/null +++ b/net/nfc/digital.h @@ -0,0 +1,171 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * NFC Digital Protocol stack + * Copyright (c) 2013, Intel Corporation. 
+ */ + +#ifndef __DIGITAL_H +#define __DIGITAL_H + +#include +#include + +#include +#include + +#define PROTOCOL_ERR(req) pr_err("%d: NFC Digital Protocol error: %s\n", \ + __LINE__, req) + +#define DIGITAL_CMD_IN_SEND 0 +#define DIGITAL_CMD_TG_SEND 1 +#define DIGITAL_CMD_TG_LISTEN 2 +#define DIGITAL_CMD_TG_LISTEN_MDAA 3 +#define DIGITAL_CMD_TG_LISTEN_MD 4 + +#define DIGITAL_MAX_HEADER_LEN 7 +#define DIGITAL_CRC_LEN 2 + +#define DIGITAL_SENSF_NFCID2_NFC_DEP_B1 0x01 +#define DIGITAL_SENSF_NFCID2_NFC_DEP_B2 0xFE + +#define DIGITAL_SENS_RES_NFC_DEP 0x0100 +#define DIGITAL_SEL_RES_NFC_DEP 0x40 +#define DIGITAL_SENSF_FELICA_SC 0xFFFF + +#define DIGITAL_DRV_CAPS_IN_CRC(ddev) \ + ((ddev)->driver_capabilities & NFC_DIGITAL_DRV_CAPS_IN_CRC) +#define DIGITAL_DRV_CAPS_TG_CRC(ddev) \ + ((ddev)->driver_capabilities & NFC_DIGITAL_DRV_CAPS_TG_CRC) + +struct digital_data_exch { + data_exchange_cb_t cb; + void *cb_context; +}; + +struct sk_buff *digital_skb_alloc(struct nfc_digital_dev *ddev, + unsigned int len); + +int digital_send_cmd(struct nfc_digital_dev *ddev, u8 cmd_type, + struct sk_buff *skb, struct digital_tg_mdaa_params *params, + u16 timeout, nfc_digital_cmd_complete_t cmd_cb, + void *cb_context); + +int digital_in_configure_hw(struct nfc_digital_dev *ddev, int type, int param); +static inline int digital_in_send_cmd(struct nfc_digital_dev *ddev, + struct sk_buff *skb, u16 timeout, + nfc_digital_cmd_complete_t cmd_cb, + void *cb_context) +{ + return digital_send_cmd(ddev, DIGITAL_CMD_IN_SEND, skb, NULL, timeout, + cmd_cb, cb_context); +} + +void digital_poll_next_tech(struct nfc_digital_dev *ddev); + +int digital_in_send_sens_req(struct nfc_digital_dev *ddev, u8 rf_tech); +int digital_in_send_sensb_req(struct nfc_digital_dev *ddev, u8 rf_tech); +int digital_in_send_sensf_req(struct nfc_digital_dev *ddev, u8 rf_tech); +int digital_in_send_iso15693_inv_req(struct nfc_digital_dev *ddev, u8 rf_tech); + +int digital_in_iso_dep_pull_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb); +int digital_in_iso_dep_push_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb); + +int digital_target_found(struct nfc_digital_dev *ddev, + struct nfc_target *target, u8 protocol); + +int digital_in_recv_mifare_res(struct sk_buff *resp); + +int digital_in_send_atr_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, __u8 comm_mode, __u8 *gb, + size_t gb_len); +int digital_in_send_dep_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, struct sk_buff *skb, + struct digital_data_exch *data_exch); + +int digital_tg_configure_hw(struct nfc_digital_dev *ddev, int type, int param); +static inline int digital_tg_send_cmd(struct nfc_digital_dev *ddev, + struct sk_buff *skb, u16 timeout, + nfc_digital_cmd_complete_t cmd_cb, void *cb_context) +{ + return digital_send_cmd(ddev, DIGITAL_CMD_TG_SEND, skb, NULL, timeout, + cmd_cb, cb_context); +} + +void digital_tg_recv_sens_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); + +void digital_tg_recv_sensf_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); + +static inline int digital_tg_listen(struct nfc_digital_dev *ddev, u16 timeout, + nfc_digital_cmd_complete_t cb, void *arg) +{ + return digital_send_cmd(ddev, DIGITAL_CMD_TG_LISTEN, NULL, NULL, + timeout, cb, arg); +} + +void digital_tg_recv_atr_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); + +int digital_tg_send_dep_res(struct nfc_digital_dev *ddev, struct sk_buff *skb); + +int digital_tg_listen_nfca(struct nfc_digital_dev *ddev, u8 
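/*
 * The digital_in_send_cmd()/digital_tg_send_cmd()/digital_tg_listen() helpers
 * above are thin wrappers around digital_send_cmd(), which only queues a
 * struct digital_cmd; digital_wq_cmd() in digital_core.c later dequeues it
 * and calls the matching driver hook (in_send_cmd, tg_send_cmd, tg_listen,
 * ...), and digital_wq_cmd_complete() runs cmd_cb with the response skb
 * (or an ERR_PTR-encoded error).
 */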
rf_tech); +int digital_tg_listen_nfcf(struct nfc_digital_dev *ddev, u8 rf_tech); +void digital_tg_recv_md_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); + +typedef u16 (*crc_func_t)(u16, const u8 *, size_t); + +#define CRC_A_INIT 0x6363 +#define CRC_B_INIT 0xFFFF +#define CRC_F_INIT 0x0000 + +void digital_skb_add_crc(struct sk_buff *skb, crc_func_t crc_func, u16 init, + u8 bitwise_inv, u8 msb_first); + +static inline void digital_skb_add_crc_a(struct sk_buff *skb) +{ + digital_skb_add_crc(skb, crc_ccitt, CRC_A_INIT, 0, 0); +} + +static inline void digital_skb_add_crc_b(struct sk_buff *skb) +{ + digital_skb_add_crc(skb, crc_ccitt, CRC_B_INIT, 1, 0); +} + +static inline void digital_skb_add_crc_f(struct sk_buff *skb) +{ + digital_skb_add_crc(skb, crc_itu_t, CRC_F_INIT, 0, 1); +} + +static inline void digital_skb_add_crc_none(struct sk_buff *skb) +{ + return; +} + +int digital_skb_check_crc(struct sk_buff *skb, crc_func_t crc_func, + u16 crc_init, u8 bitwise_inv, u8 msb_first); + +static inline int digital_skb_check_crc_a(struct sk_buff *skb) +{ + return digital_skb_check_crc(skb, crc_ccitt, CRC_A_INIT, 0, 0); +} + +static inline int digital_skb_check_crc_b(struct sk_buff *skb) +{ + return digital_skb_check_crc(skb, crc_ccitt, CRC_B_INIT, 1, 0); +} + +static inline int digital_skb_check_crc_f(struct sk_buff *skb) +{ + return digital_skb_check_crc(skb, crc_itu_t, CRC_F_INIT, 0, 1); +} + +static inline int digital_skb_check_crc_none(struct sk_buff *skb) +{ + return 0; +} + +#endif /* __DIGITAL_H */ diff --git a/net/nfc/digital_core.c b/net/nfc/digital_core.c new file mode 100644 index 000000000..d63d2e5dc --- /dev/null +++ b/net/nfc/digital_core.c @@ -0,0 +1,861 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * NFC Digital Protocol stack + * Copyright (c) 2013, Intel Corporation. 
+ */ + +#define pr_fmt(fmt) "digital: %s: " fmt, __func__ + +#include + +#include "digital.h" + +#define DIGITAL_PROTO_NFCA_RF_TECH \ + (NFC_PROTO_JEWEL_MASK | NFC_PROTO_MIFARE_MASK | \ + NFC_PROTO_NFC_DEP_MASK | NFC_PROTO_ISO14443_MASK) + +#define DIGITAL_PROTO_NFCB_RF_TECH NFC_PROTO_ISO14443_B_MASK + +#define DIGITAL_PROTO_NFCF_RF_TECH \ + (NFC_PROTO_FELICA_MASK | NFC_PROTO_NFC_DEP_MASK) + +#define DIGITAL_PROTO_ISO15693_RF_TECH NFC_PROTO_ISO15693_MASK + +/* Delay between each poll frame (ms) */ +#define DIGITAL_POLL_INTERVAL 10 + +struct digital_cmd { + struct list_head queue; + + u8 type; + u8 pending; + + u16 timeout; + struct sk_buff *req; + struct sk_buff *resp; + struct digital_tg_mdaa_params *mdaa_params; + + nfc_digital_cmd_complete_t cmd_cb; + void *cb_context; +}; + +struct sk_buff *digital_skb_alloc(struct nfc_digital_dev *ddev, + unsigned int len) +{ + struct sk_buff *skb; + + skb = alloc_skb(len + ddev->tx_headroom + ddev->tx_tailroom, + GFP_KERNEL); + if (skb) + skb_reserve(skb, ddev->tx_headroom); + + return skb; +} + +void digital_skb_add_crc(struct sk_buff *skb, crc_func_t crc_func, u16 init, + u8 bitwise_inv, u8 msb_first) +{ + u16 crc; + + crc = crc_func(init, skb->data, skb->len); + + if (bitwise_inv) + crc = ~crc; + + if (msb_first) + crc = __fswab16(crc); + + skb_put_u8(skb, crc & 0xFF); + skb_put_u8(skb, (crc >> 8) & 0xFF); +} + +int digital_skb_check_crc(struct sk_buff *skb, crc_func_t crc_func, + u16 crc_init, u8 bitwise_inv, u8 msb_first) +{ + int rc; + u16 crc; + + if (skb->len <= 2) + return -EIO; + + crc = crc_func(crc_init, skb->data, skb->len - 2); + + if (bitwise_inv) + crc = ~crc; + + if (msb_first) + crc = __swab16(crc); + + rc = (skb->data[skb->len - 2] - (crc & 0xFF)) + + (skb->data[skb->len - 1] - ((crc >> 8) & 0xFF)); + + if (rc) + return -EIO; + + skb_trim(skb, skb->len - 2); + + return 0; +} + +static inline void digital_switch_rf(struct nfc_digital_dev *ddev, bool on) +{ + ddev->ops->switch_rf(ddev, on); +} + +static inline void digital_abort_cmd(struct nfc_digital_dev *ddev) +{ + ddev->ops->abort_cmd(ddev); +} + +static void digital_wq_cmd_complete(struct work_struct *work) +{ + struct digital_cmd *cmd; + struct nfc_digital_dev *ddev = container_of(work, + struct nfc_digital_dev, + cmd_complete_work); + + mutex_lock(&ddev->cmd_lock); + + cmd = list_first_entry_or_null(&ddev->cmd_queue, struct digital_cmd, + queue); + if (!cmd) { + mutex_unlock(&ddev->cmd_lock); + return; + } + + list_del(&cmd->queue); + + mutex_unlock(&ddev->cmd_lock); + + if (!IS_ERR(cmd->resp)) + print_hex_dump_debug("DIGITAL RX: ", DUMP_PREFIX_NONE, 16, 1, + cmd->resp->data, cmd->resp->len, false); + + cmd->cmd_cb(ddev, cmd->cb_context, cmd->resp); + + kfree(cmd->mdaa_params); + kfree(cmd); + + schedule_work(&ddev->cmd_work); +} + +static void digital_send_cmd_complete(struct nfc_digital_dev *ddev, + void *arg, struct sk_buff *resp) +{ + struct digital_cmd *cmd = arg; + + cmd->resp = resp; + + schedule_work(&ddev->cmd_complete_work); +} + +static void digital_wq_cmd(struct work_struct *work) +{ + int rc; + struct digital_cmd *cmd; + struct digital_tg_mdaa_params *params; + struct nfc_digital_dev *ddev = container_of(work, + struct nfc_digital_dev, + cmd_work); + + mutex_lock(&ddev->cmd_lock); + + cmd = list_first_entry_or_null(&ddev->cmd_queue, struct digital_cmd, + queue); + if (!cmd || cmd->pending) { + mutex_unlock(&ddev->cmd_lock); + return; + } + + cmd->pending = 1; + + mutex_unlock(&ddev->cmd_lock); + + if (cmd->req) + print_hex_dump_debug("DIGITAL TX: ", 
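/*
 * CRC handling recap for the command path below: digital_skb_add_crc()
 * appends the 16-bit CRC (optionally bit-inverted and/or byte-swapped
 * depending on the technology) and digital_skb_check_crc() recomputes it
 * over skb->len - 2 bytes, compares against the trailing two bytes and
 * trims them on success.  E.g. CRC-A uses crc_ccitt() with init 0x6363,
 * no inversion, low byte first; CRC-B additionally inverts the result;
 * CRC-F uses crc_itu_t() with init 0x0000, high byte first.
 */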
DUMP_PREFIX_NONE, 16, 1, + cmd->req->data, cmd->req->len, false); + + switch (cmd->type) { + case DIGITAL_CMD_IN_SEND: + rc = ddev->ops->in_send_cmd(ddev, cmd->req, cmd->timeout, + digital_send_cmd_complete, cmd); + break; + + case DIGITAL_CMD_TG_SEND: + rc = ddev->ops->tg_send_cmd(ddev, cmd->req, cmd->timeout, + digital_send_cmd_complete, cmd); + break; + + case DIGITAL_CMD_TG_LISTEN: + rc = ddev->ops->tg_listen(ddev, cmd->timeout, + digital_send_cmd_complete, cmd); + break; + + case DIGITAL_CMD_TG_LISTEN_MDAA: + params = cmd->mdaa_params; + + rc = ddev->ops->tg_listen_mdaa(ddev, params, cmd->timeout, + digital_send_cmd_complete, cmd); + break; + + case DIGITAL_CMD_TG_LISTEN_MD: + rc = ddev->ops->tg_listen_md(ddev, cmd->timeout, + digital_send_cmd_complete, cmd); + break; + + default: + pr_err("Unknown cmd type %d\n", cmd->type); + return; + } + + if (!rc) + return; + + pr_err("in_send_command returned err %d\n", rc); + + mutex_lock(&ddev->cmd_lock); + list_del(&cmd->queue); + mutex_unlock(&ddev->cmd_lock); + + kfree_skb(cmd->req); + kfree(cmd->mdaa_params); + kfree(cmd); + + schedule_work(&ddev->cmd_work); +} + +int digital_send_cmd(struct nfc_digital_dev *ddev, u8 cmd_type, + struct sk_buff *skb, struct digital_tg_mdaa_params *params, + u16 timeout, nfc_digital_cmd_complete_t cmd_cb, + void *cb_context) +{ + struct digital_cmd *cmd; + + cmd = kzalloc(sizeof(*cmd), GFP_KERNEL); + if (!cmd) + return -ENOMEM; + + cmd->type = cmd_type; + cmd->timeout = timeout; + cmd->req = skb; + cmd->mdaa_params = params; + cmd->cmd_cb = cmd_cb; + cmd->cb_context = cb_context; + INIT_LIST_HEAD(&cmd->queue); + + mutex_lock(&ddev->cmd_lock); + list_add_tail(&cmd->queue, &ddev->cmd_queue); + mutex_unlock(&ddev->cmd_lock); + + schedule_work(&ddev->cmd_work); + + return 0; +} + +int digital_in_configure_hw(struct nfc_digital_dev *ddev, int type, int param) +{ + int rc; + + rc = ddev->ops->in_configure_hw(ddev, type, param); + if (rc) + pr_err("in_configure_hw failed: %d\n", rc); + + return rc; +} + +int digital_tg_configure_hw(struct nfc_digital_dev *ddev, int type, int param) +{ + int rc; + + rc = ddev->ops->tg_configure_hw(ddev, type, param); + if (rc) + pr_err("tg_configure_hw failed: %d\n", rc); + + return rc; +} + +static int digital_tg_listen_mdaa(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + struct digital_tg_mdaa_params *params; + int rc; + + params = kzalloc(sizeof(*params), GFP_KERNEL); + if (!params) + return -ENOMEM; + + params->sens_res = DIGITAL_SENS_RES_NFC_DEP; + get_random_bytes(params->nfcid1, sizeof(params->nfcid1)); + params->sel_res = DIGITAL_SEL_RES_NFC_DEP; + + params->nfcid2[0] = DIGITAL_SENSF_NFCID2_NFC_DEP_B1; + params->nfcid2[1] = DIGITAL_SENSF_NFCID2_NFC_DEP_B2; + get_random_bytes(params->nfcid2 + 2, NFC_NFCID2_MAXSIZE - 2); + params->sc = DIGITAL_SENSF_FELICA_SC; + + rc = digital_send_cmd(ddev, DIGITAL_CMD_TG_LISTEN_MDAA, NULL, params, + 500, digital_tg_recv_atr_req, NULL); + if (rc) + kfree(params); + + return rc; +} + +static int digital_tg_listen_md(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + return digital_send_cmd(ddev, DIGITAL_CMD_TG_LISTEN_MD, NULL, NULL, 500, + digital_tg_recv_md_req, NULL); +} + +int digital_target_found(struct nfc_digital_dev *ddev, + struct nfc_target *target, u8 protocol) +{ + int rc; + u8 framing; + u8 rf_tech; + u8 poll_tech_count; + int (*check_crc)(struct sk_buff *skb); + void (*add_crc)(struct sk_buff *skb); + + rf_tech = ddev->poll_techs[ddev->poll_tech_index].rf_tech; + + switch (protocol) { + case NFC_PROTO_JEWEL: + framing = 
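/*
 * Note on the listen path above: when the driver supports mode-detection
 * auto-anticollision (tg_listen_mdaa), the stack advertises itself as an
 * NFC-DEP-only target, using SENS_RES 0x0100 and SEL_RES 0x40 for NFC-A,
 * an NFCID2 starting with 0x01 0xFE plus random bytes and system code
 * 0xFFFF for NFC-F, with a 500 ms listen timeout per round.
 */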
NFC_DIGITAL_FRAMING_NFCA_T1T; + check_crc = digital_skb_check_crc_b; + add_crc = digital_skb_add_crc_b; + break; + + case NFC_PROTO_MIFARE: + framing = NFC_DIGITAL_FRAMING_NFCA_T2T; + check_crc = digital_skb_check_crc_a; + add_crc = digital_skb_add_crc_a; + break; + + case NFC_PROTO_FELICA: + framing = NFC_DIGITAL_FRAMING_NFCF_T3T; + check_crc = digital_skb_check_crc_f; + add_crc = digital_skb_add_crc_f; + break; + + case NFC_PROTO_NFC_DEP: + if (rf_tech == NFC_DIGITAL_RF_TECH_106A) { + framing = NFC_DIGITAL_FRAMING_NFCA_NFC_DEP; + check_crc = digital_skb_check_crc_a; + add_crc = digital_skb_add_crc_a; + } else { + framing = NFC_DIGITAL_FRAMING_NFCF_NFC_DEP; + check_crc = digital_skb_check_crc_f; + add_crc = digital_skb_add_crc_f; + } + break; + + case NFC_PROTO_ISO15693: + framing = NFC_DIGITAL_FRAMING_ISO15693_T5T; + check_crc = digital_skb_check_crc_b; + add_crc = digital_skb_add_crc_b; + break; + + case NFC_PROTO_ISO14443: + framing = NFC_DIGITAL_FRAMING_NFCA_T4T; + check_crc = digital_skb_check_crc_a; + add_crc = digital_skb_add_crc_a; + break; + + case NFC_PROTO_ISO14443_B: + framing = NFC_DIGITAL_FRAMING_NFCB_T4T; + check_crc = digital_skb_check_crc_b; + add_crc = digital_skb_add_crc_b; + break; + + default: + pr_err("Invalid protocol %d\n", protocol); + return -EINVAL; + } + + pr_debug("rf_tech=%d, protocol=%d\n", rf_tech, protocol); + + ddev->curr_rf_tech = rf_tech; + + if (DIGITAL_DRV_CAPS_IN_CRC(ddev)) { + ddev->skb_add_crc = digital_skb_add_crc_none; + ddev->skb_check_crc = digital_skb_check_crc_none; + } else { + ddev->skb_add_crc = add_crc; + ddev->skb_check_crc = check_crc; + } + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, framing); + if (rc) + return rc; + + target->supported_protocols = (1 << protocol); + + poll_tech_count = ddev->poll_tech_count; + ddev->poll_tech_count = 0; + + rc = nfc_targets_found(ddev->nfc_dev, target, 1); + if (rc) { + ddev->poll_tech_count = poll_tech_count; + return rc; + } + + return 0; +} + +void digital_poll_next_tech(struct nfc_digital_dev *ddev) +{ + u8 rand_mod; + + digital_switch_rf(ddev, 0); + + mutex_lock(&ddev->poll_lock); + + if (!ddev->poll_tech_count) { + mutex_unlock(&ddev->poll_lock); + return; + } + + get_random_bytes(&rand_mod, sizeof(rand_mod)); + ddev->poll_tech_index = rand_mod % ddev->poll_tech_count; + + mutex_unlock(&ddev->poll_lock); + + schedule_delayed_work(&ddev->poll_work, + msecs_to_jiffies(DIGITAL_POLL_INTERVAL)); +} + +static void digital_wq_poll(struct work_struct *work) +{ + int rc; + struct digital_poll_tech *poll_tech; + struct nfc_digital_dev *ddev = container_of(work, + struct nfc_digital_dev, + poll_work.work); + mutex_lock(&ddev->poll_lock); + + if (!ddev->poll_tech_count) { + mutex_unlock(&ddev->poll_lock); + return; + } + + poll_tech = &ddev->poll_techs[ddev->poll_tech_index]; + + mutex_unlock(&ddev->poll_lock); + + rc = poll_tech->poll_func(ddev, poll_tech->rf_tech); + if (rc) + digital_poll_next_tech(ddev); +} + +static void digital_add_poll_tech(struct nfc_digital_dev *ddev, u8 rf_tech, + digital_poll_t poll_func) +{ + struct digital_poll_tech *poll_tech; + + if (ddev->poll_tech_count >= NFC_DIGITAL_POLL_MODE_COUNT_MAX) + return; + + poll_tech = &ddev->poll_techs[ddev->poll_tech_count++]; + + poll_tech->rf_tech = rf_tech; + poll_tech->poll_func = poll_func; +} + +/** + * digital_start_poll - start_poll operation + * @nfc_dev: device to be polled + * @im_protocols: bitset of nfc initiator protocols to be used for polling + * @tm_protocols: bitset of nfc transport protocols to be used 
for polling + * + * For every supported protocol, the corresponding polling function is added + * to the table of polling technologies (ddev->poll_techs[]) using + * digital_add_poll_tech(). + * When a polling function fails (by timeout or protocol error) the next one is + * schedule by digital_poll_next_tech() on the poll workqueue (ddev->poll_work). + */ +static int digital_start_poll(struct nfc_dev *nfc_dev, __u32 im_protocols, + __u32 tm_protocols) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + u32 matching_im_protocols, matching_tm_protocols; + + pr_debug("protocols: im 0x%x, tm 0x%x, supported 0x%x\n", im_protocols, + tm_protocols, ddev->protocols); + + matching_im_protocols = ddev->protocols & im_protocols; + matching_tm_protocols = ddev->protocols & tm_protocols; + + if (!matching_im_protocols && !matching_tm_protocols) { + pr_err("Unknown protocol\n"); + return -EINVAL; + } + + if (ddev->poll_tech_count) { + pr_err("Already polling\n"); + return -EBUSY; + } + + if (ddev->curr_protocol) { + pr_err("A target is already active\n"); + return -EBUSY; + } + + ddev->poll_tech_count = 0; + ddev->poll_tech_index = 0; + + if (matching_im_protocols & DIGITAL_PROTO_NFCA_RF_TECH) + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_106A, + digital_in_send_sens_req); + + if (matching_im_protocols & DIGITAL_PROTO_NFCB_RF_TECH) + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_106B, + digital_in_send_sensb_req); + + if (matching_im_protocols & DIGITAL_PROTO_NFCF_RF_TECH) { + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_212F, + digital_in_send_sensf_req); + + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_424F, + digital_in_send_sensf_req); + } + + if (matching_im_protocols & DIGITAL_PROTO_ISO15693_RF_TECH) + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_ISO15693, + digital_in_send_iso15693_inv_req); + + if (matching_tm_protocols & NFC_PROTO_NFC_DEP_MASK) { + if (ddev->ops->tg_listen_mdaa) { + digital_add_poll_tech(ddev, 0, + digital_tg_listen_mdaa); + } else if (ddev->ops->tg_listen_md) { + digital_add_poll_tech(ddev, 0, + digital_tg_listen_md); + } else { + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_106A, + digital_tg_listen_nfca); + + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_212F, + digital_tg_listen_nfcf); + + digital_add_poll_tech(ddev, NFC_DIGITAL_RF_TECH_424F, + digital_tg_listen_nfcf); + } + } + + if (!ddev->poll_tech_count) { + pr_err("Unsupported protocols: im=0x%x, tm=0x%x\n", + matching_im_protocols, matching_tm_protocols); + return -EINVAL; + } + + schedule_delayed_work(&ddev->poll_work, 0); + + return 0; +} + +static void digital_stop_poll(struct nfc_dev *nfc_dev) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + mutex_lock(&ddev->poll_lock); + + if (!ddev->poll_tech_count) { + pr_err("Polling operation was not running\n"); + mutex_unlock(&ddev->poll_lock); + return; + } + + ddev->poll_tech_count = 0; + + mutex_unlock(&ddev->poll_lock); + + cancel_delayed_work_sync(&ddev->poll_work); + + digital_abort_cmd(ddev); +} + +static int digital_dev_up(struct nfc_dev *nfc_dev) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + digital_switch_rf(ddev, 1); + + return 0; +} + +static int digital_dev_down(struct nfc_dev *nfc_dev) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + digital_switch_rf(ddev, 0); + + return 0; +} + +static int digital_dep_link_up(struct nfc_dev *nfc_dev, + struct nfc_target *target, + __u8 comm_mode, __u8 *gb, size_t gb_len) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + int 
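/*
 * Worked example for the poll-table construction above: with
 * im_protocols = NFC_PROTO_MIFARE_MASK | NFC_PROTO_FELICA_MASK and no
 * tm_protocols, three entries end up in ddev->poll_techs[]: SENS_REQ on
 * NFC_DIGITAL_RF_TECH_106A plus SENSF_REQ on 212F and 424F.
 * digital_wq_poll() then tries one entry at a time and, on timeout or
 * error, digital_poll_next_tech() picks a random entry and reschedules
 * the work after DIGITAL_POLL_INTERVAL (10 ms).
 */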
rc; + + rc = digital_in_send_atr_req(ddev, target, comm_mode, gb, gb_len); + + if (!rc) + ddev->curr_protocol = NFC_PROTO_NFC_DEP; + + return rc; +} + +static int digital_dep_link_down(struct nfc_dev *nfc_dev) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + digital_abort_cmd(ddev); + + ddev->curr_protocol = 0; + + return 0; +} + +static int digital_activate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, __u32 protocol) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + if (ddev->poll_tech_count) { + pr_err("Can't activate a target while polling\n"); + return -EBUSY; + } + + if (ddev->curr_protocol) { + pr_err("A target is already active\n"); + return -EBUSY; + } + + ddev->curr_protocol = protocol; + + return 0; +} + +static void digital_deactivate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, + u8 mode) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + + if (!ddev->curr_protocol) { + pr_err("No active target\n"); + return; + } + + digital_abort_cmd(ddev); + ddev->curr_protocol = 0; +} + +static int digital_tg_send(struct nfc_dev *dev, struct sk_buff *skb) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(dev); + + return digital_tg_send_dep_res(ddev, skb); +} + +static void digital_in_send_complete(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct digital_data_exch *data_exch = arg; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto done; + } + + if (ddev->curr_protocol == NFC_PROTO_MIFARE) { + rc = digital_in_recv_mifare_res(resp); + /* crc check is done in digital_in_recv_mifare_res() */ + goto done; + } + + if ((ddev->curr_protocol == NFC_PROTO_ISO14443) || + (ddev->curr_protocol == NFC_PROTO_ISO14443_B)) { + rc = digital_in_iso_dep_pull_sod(ddev, resp); + if (rc) + goto done; + } + + rc = ddev->skb_check_crc(resp); + +done: + if (rc) { + kfree_skb(resp); + resp = NULL; + } + + data_exch->cb(data_exch->cb_context, resp, rc); + + kfree(data_exch); +} + +static int digital_in_send(struct nfc_dev *nfc_dev, struct nfc_target *target, + struct sk_buff *skb, data_exchange_cb_t cb, + void *cb_context) +{ + struct nfc_digital_dev *ddev = nfc_get_drvdata(nfc_dev); + struct digital_data_exch *data_exch; + int rc; + + data_exch = kzalloc(sizeof(*data_exch), GFP_KERNEL); + if (!data_exch) + return -ENOMEM; + + data_exch->cb = cb; + data_exch->cb_context = cb_context; + + if (ddev->curr_protocol == NFC_PROTO_NFC_DEP) { + rc = digital_in_send_dep_req(ddev, target, skb, data_exch); + goto exit; + } + + if ((ddev->curr_protocol == NFC_PROTO_ISO14443) || + (ddev->curr_protocol == NFC_PROTO_ISO14443_B)) { + rc = digital_in_iso_dep_push_sod(ddev, skb); + if (rc) + goto exit; + } + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, 500, digital_in_send_complete, + data_exch); + +exit: + if (rc) + kfree(data_exch); + + return rc; +} + +static const struct nfc_ops digital_nfc_ops = { + .dev_up = digital_dev_up, + .dev_down = digital_dev_down, + .start_poll = digital_start_poll, + .stop_poll = digital_stop_poll, + .dep_link_up = digital_dep_link_up, + .dep_link_down = digital_dep_link_down, + .activate_target = digital_activate_target, + .deactivate_target = digital_deactivate_target, + .tm_send = digital_tg_send, + .im_transceive = digital_in_send, +}; + +struct nfc_digital_dev *nfc_digital_allocate_device(const struct nfc_digital_ops *ops, + __u32 supported_protocols, + __u32 driver_capabilities, + int tx_headroom, int tx_tailroom) +{ + struct nfc_digital_dev 
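/*
 * Driver-side sketch for the allocator below (assumes a hypothetical
 * my_digital_ops providing at least in_configure_hw, in_send_cmd,
 * tg_configure_hw, tg_send_cmd, tg_listen, abort_cmd and switch_rf):
 *
 *   ddev = nfc_digital_allocate_device(&my_digital_ops, protocols,
 *                                      NFC_DIGITAL_DRV_CAPS_IN_CRC,
 *                                      headroom, tailroom);
 *   rc = nfc_digital_register_device(ddev);
 *   ...
 *   nfc_digital_unregister_device(ddev);
 *   nfc_digital_free_device(ddev);
 */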
*ddev; + + if (!ops->in_configure_hw || !ops->in_send_cmd || !ops->tg_listen || + !ops->tg_configure_hw || !ops->tg_send_cmd || !ops->abort_cmd || + !ops->switch_rf || (ops->tg_listen_md && !ops->tg_get_rf_tech)) + return NULL; + + ddev = kzalloc(sizeof(*ddev), GFP_KERNEL); + if (!ddev) + return NULL; + + ddev->driver_capabilities = driver_capabilities; + ddev->ops = ops; + + mutex_init(&ddev->cmd_lock); + INIT_LIST_HEAD(&ddev->cmd_queue); + + INIT_WORK(&ddev->cmd_work, digital_wq_cmd); + INIT_WORK(&ddev->cmd_complete_work, digital_wq_cmd_complete); + + mutex_init(&ddev->poll_lock); + INIT_DELAYED_WORK(&ddev->poll_work, digital_wq_poll); + + if (supported_protocols & NFC_PROTO_JEWEL_MASK) + ddev->protocols |= NFC_PROTO_JEWEL_MASK; + if (supported_protocols & NFC_PROTO_MIFARE_MASK) + ddev->protocols |= NFC_PROTO_MIFARE_MASK; + if (supported_protocols & NFC_PROTO_FELICA_MASK) + ddev->protocols |= NFC_PROTO_FELICA_MASK; + if (supported_protocols & NFC_PROTO_NFC_DEP_MASK) + ddev->protocols |= NFC_PROTO_NFC_DEP_MASK; + if (supported_protocols & NFC_PROTO_ISO15693_MASK) + ddev->protocols |= NFC_PROTO_ISO15693_MASK; + if (supported_protocols & NFC_PROTO_ISO14443_MASK) + ddev->protocols |= NFC_PROTO_ISO14443_MASK; + if (supported_protocols & NFC_PROTO_ISO14443_B_MASK) + ddev->protocols |= NFC_PROTO_ISO14443_B_MASK; + + ddev->tx_headroom = tx_headroom + DIGITAL_MAX_HEADER_LEN; + ddev->tx_tailroom = tx_tailroom + DIGITAL_CRC_LEN; + + ddev->nfc_dev = nfc_allocate_device(&digital_nfc_ops, ddev->protocols, + ddev->tx_headroom, + ddev->tx_tailroom); + if (!ddev->nfc_dev) { + pr_err("nfc_allocate_device failed\n"); + goto free_dev; + } + + nfc_set_drvdata(ddev->nfc_dev, ddev); + + return ddev; + +free_dev: + kfree(ddev); + + return NULL; +} +EXPORT_SYMBOL(nfc_digital_allocate_device); + +void nfc_digital_free_device(struct nfc_digital_dev *ddev) +{ + nfc_free_device(ddev->nfc_dev); + kfree(ddev); +} +EXPORT_SYMBOL(nfc_digital_free_device); + +int nfc_digital_register_device(struct nfc_digital_dev *ddev) +{ + return nfc_register_device(ddev->nfc_dev); +} +EXPORT_SYMBOL(nfc_digital_register_device); + +void nfc_digital_unregister_device(struct nfc_digital_dev *ddev) +{ + struct digital_cmd *cmd, *n; + + nfc_unregister_device(ddev->nfc_dev); + + mutex_lock(&ddev->poll_lock); + ddev->poll_tech_count = 0; + mutex_unlock(&ddev->poll_lock); + + cancel_delayed_work_sync(&ddev->poll_work); + cancel_work_sync(&ddev->cmd_work); + cancel_work_sync(&ddev->cmd_complete_work); + + list_for_each_entry_safe(cmd, n, &ddev->cmd_queue, queue) { + list_del(&cmd->queue); + + /* Call the command callback if any and pass it a ENODEV error. + * This gives a chance to the command issuer to free any + * allocated buffer. + */ + if (cmd->cmd_cb) + cmd->cmd_cb(ddev, cmd->cb_context, ERR_PTR(-ENODEV)); + + kfree(cmd->mdaa_params); + kfree(cmd); + } +} +EXPORT_SYMBOL(nfc_digital_unregister_device); + +MODULE_LICENSE("GPL"); diff --git a/net/nfc/digital_dep.c b/net/nfc/digital_dep.c new file mode 100644 index 000000000..3982fa084 --- /dev/null +++ b/net/nfc/digital_dep.c @@ -0,0 +1,1633 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * NFC Digital Protocol stack + * Copyright (c) 2013, Intel Corporation. 
+ */ + +#define pr_fmt(fmt) "digital: %s: " fmt, __func__ + +#include "digital.h" + +#define DIGITAL_NFC_DEP_N_RETRY_NACK 2 +#define DIGITAL_NFC_DEP_N_RETRY_ATN 2 + +#define DIGITAL_NFC_DEP_FRAME_DIR_OUT 0xD4 +#define DIGITAL_NFC_DEP_FRAME_DIR_IN 0xD5 + +#define DIGITAL_NFC_DEP_NFCA_SOD_SB 0xF0 + +#define DIGITAL_CMD_ATR_REQ 0x00 +#define DIGITAL_CMD_ATR_RES 0x01 +#define DIGITAL_CMD_PSL_REQ 0x04 +#define DIGITAL_CMD_PSL_RES 0x05 +#define DIGITAL_CMD_DEP_REQ 0x06 +#define DIGITAL_CMD_DEP_RES 0x07 + +#define DIGITAL_ATR_REQ_MIN_SIZE 16 +#define DIGITAL_ATR_REQ_MAX_SIZE 64 + +#define DIGITAL_ATR_RES_TO_WT(s) ((s) & 0xF) + +#define DIGITAL_DID_MAX 14 + +#define DIGITAL_PAYLOAD_SIZE_MAX 254 +#define DIGITAL_PAYLOAD_BITS_TO_PP(s) (((s) & 0x3) << 4) +#define DIGITAL_PAYLOAD_PP_TO_BITS(s) (((s) >> 4) & 0x3) +#define DIGITAL_PAYLOAD_BITS_TO_FSL(s) ((s) & 0x3) +#define DIGITAL_PAYLOAD_FSL_TO_BITS(s) ((s) & 0x3) + +#define DIGITAL_GB_BIT 0x02 + +#define DIGITAL_NFC_DEP_PFB_TYPE(pfb) ((pfb) & 0xE0) + +#define DIGITAL_NFC_DEP_PFB_TIMEOUT_BIT 0x10 +#define DIGITAL_NFC_DEP_PFB_MI_BIT 0x10 +#define DIGITAL_NFC_DEP_PFB_NACK_BIT 0x10 +#define DIGITAL_NFC_DEP_PFB_DID_BIT 0x04 + +#define DIGITAL_NFC_DEP_PFB_IS_TIMEOUT(pfb) \ + ((pfb) & DIGITAL_NFC_DEP_PFB_TIMEOUT_BIT) +#define DIGITAL_NFC_DEP_MI_BIT_SET(pfb) ((pfb) & DIGITAL_NFC_DEP_PFB_MI_BIT) +#define DIGITAL_NFC_DEP_NACK_BIT_SET(pfb) ((pfb) & DIGITAL_NFC_DEP_PFB_NACK_BIT) +#define DIGITAL_NFC_DEP_NAD_BIT_SET(pfb) ((pfb) & 0x08) +#define DIGITAL_NFC_DEP_DID_BIT_SET(pfb) ((pfb) & DIGITAL_NFC_DEP_PFB_DID_BIT) +#define DIGITAL_NFC_DEP_PFB_PNI(pfb) ((pfb) & 0x03) + +#define DIGITAL_NFC_DEP_RTOX_VALUE(data) ((data) & 0x3F) +#define DIGITAL_NFC_DEP_RTOX_MAX 59 + +#define DIGITAL_NFC_DEP_PFB_I_PDU 0x00 +#define DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU 0x40 +#define DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU 0x80 + +struct digital_atr_req { + u8 dir; + u8 cmd; + u8 nfcid3[10]; + u8 did; + u8 bs; + u8 br; + u8 pp; + u8 gb[]; +} __packed; + +struct digital_atr_res { + u8 dir; + u8 cmd; + u8 nfcid3[10]; + u8 did; + u8 bs; + u8 br; + u8 to; + u8 pp; + u8 gb[]; +} __packed; + +struct digital_psl_req { + u8 dir; + u8 cmd; + u8 did; + u8 brs; + u8 fsl; +} __packed; + +struct digital_psl_res { + u8 dir; + u8 cmd; + u8 did; +} __packed; + +struct digital_dep_req_res { + u8 dir; + u8 cmd; + u8 pfb; +} __packed; + +static void digital_in_recv_dep_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); +static void digital_tg_recv_dep_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp); + +static const u8 digital_payload_bits_map[4] = { + [0] = 64, + [1] = 128, + [2] = 192, + [3] = 254 +}; + +/* Response Waiting Time for ATR_RES PDU in ms + * + * RWT(ATR_RES) = RWT(nfcdep,activation) + dRWT(nfcdep) + dT(nfcdep,initiator) + * + * with: + * RWT(nfcdep,activation) = 4096 * 2^12 / f(c) s + * dRWT(nfcdep) = 16 / f(c) s + * dT(nfcdep,initiator) = 100 ms + * f(c) = 13560000 Hz + */ +#define DIGITAL_ATR_RES_RWT 1337 + +/* Response Waiting Time for other DEP PDUs in ms + * + * max_rwt = rwt + dRWT(nfcdep) + dT(nfcdep,initiator) + * + * with: + * rwt = (256 * 16 / f(c)) * 2^wt s + * dRWT(nfcdep) = 16 / f(c) s + * dT(nfcdep,initiator) = 100 ms + * f(c) = 13560000 Hz + * 0 <= wt <= 14 (given by the target by the TO field of ATR_RES response) + */ +#define DIGITAL_NFC_DEP_IN_MAX_WT 14 +#define DIGITAL_NFC_DEP_TG_MAX_WT 14 +static const u16 digital_rwt_map[DIGITAL_NFC_DEP_IN_MAX_WT + 1] = { + 100, 101, 101, 102, 105, + 110, 119, 139, 177, 255, + 409, 719, 1337, 2575, 5049, 
+}; + +static u8 digital_payload_bits_to_size(u8 payload_bits) +{ + if (payload_bits >= ARRAY_SIZE(digital_payload_bits_map)) + return 0; + + return digital_payload_bits_map[payload_bits]; +} + +static u8 digital_payload_size_to_bits(u8 payload_size) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(digital_payload_bits_map); i++) + if (digital_payload_bits_map[i] == payload_size) + return i; + + return 0xff; +} + +static void digital_skb_push_dep_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb) +{ + skb_push(skb, sizeof(u8)); + + skb->data[0] = skb->len; + + if (ddev->curr_rf_tech == NFC_DIGITAL_RF_TECH_106A) + *(u8 *)skb_push(skb, sizeof(u8)) = DIGITAL_NFC_DEP_NFCA_SOD_SB; +} + +static int digital_skb_pull_dep_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb) +{ + u8 size; + + if (skb->len < 2) + return -EIO; + + if (ddev->curr_rf_tech == NFC_DIGITAL_RF_TECH_106A) + skb_pull(skb, sizeof(u8)); + + size = skb->data[0]; + if (size != skb->len) + return -EIO; + + skb_pull(skb, sizeof(u8)); + + return 0; +} + +static struct sk_buff * +digital_send_dep_data_prep(struct nfc_digital_dev *ddev, struct sk_buff *skb, + struct digital_dep_req_res *dep_req_res, + struct digital_data_exch *data_exch) +{ + struct sk_buff *new_skb; + + if (skb->len > ddev->remote_payload_max) { + dep_req_res->pfb |= DIGITAL_NFC_DEP_PFB_MI_BIT; + + new_skb = digital_skb_alloc(ddev, ddev->remote_payload_max); + if (!new_skb) { + kfree_skb(ddev->chaining_skb); + ddev->chaining_skb = NULL; + + return ERR_PTR(-ENOMEM); + } + + skb_put_data(new_skb, skb->data, ddev->remote_payload_max); + skb_pull(skb, ddev->remote_payload_max); + + ddev->chaining_skb = skb; + ddev->data_exch = data_exch; + } else { + ddev->chaining_skb = NULL; + new_skb = skb; + } + + return new_skb; +} + +static struct sk_buff * +digital_recv_dep_data_gather(struct nfc_digital_dev *ddev, u8 pfb, + struct sk_buff *resp, + int (*send_ack)(struct nfc_digital_dev *ddev, + struct digital_data_exch + *data_exch), + struct digital_data_exch *data_exch) +{ + struct sk_buff *new_skb; + int rc; + + if (DIGITAL_NFC_DEP_MI_BIT_SET(pfb) && (!ddev->chaining_skb)) { + ddev->chaining_skb = + nfc_alloc_recv_skb(8 * ddev->local_payload_max, + GFP_KERNEL); + if (!ddev->chaining_skb) { + rc = -ENOMEM; + goto error; + } + } + + if (ddev->chaining_skb) { + if (resp->len > skb_tailroom(ddev->chaining_skb)) { + new_skb = skb_copy_expand(ddev->chaining_skb, + skb_headroom( + ddev->chaining_skb), + 8 * ddev->local_payload_max, + GFP_KERNEL); + if (!new_skb) { + rc = -ENOMEM; + goto error; + } + + kfree_skb(ddev->chaining_skb); + ddev->chaining_skb = new_skb; + } + + skb_put_data(ddev->chaining_skb, resp->data, resp->len); + + kfree_skb(resp); + resp = NULL; + + if (DIGITAL_NFC_DEP_MI_BIT_SET(pfb)) { + rc = send_ack(ddev, data_exch); + if (rc) + goto error; + + return NULL; + } + + resp = ddev->chaining_skb; + ddev->chaining_skb = NULL; + } + + return resp; + +error: + kfree_skb(resp); + + kfree_skb(ddev->chaining_skb); + ddev->chaining_skb = NULL; + + return ERR_PTR(rc); +} + +static void digital_in_recv_psl_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + struct digital_psl_res *psl_res; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + rc = ddev->skb_check_crc(resp); + if (rc) { + PROTOCOL_ERR("14.4.1.6"); + goto exit; + } + + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + psl_res = (struct digital_psl_res 
*)resp->data; + + if ((resp->len != sizeof(*psl_res)) || + (psl_res->dir != DIGITAL_NFC_DEP_FRAME_DIR_IN) || + (psl_res->cmd != DIGITAL_CMD_PSL_RES)) { + rc = -EIO; + goto exit; + } + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, + NFC_DIGITAL_RF_TECH_424F); + if (rc) + goto exit; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCF_NFC_DEP); + if (rc) + goto exit; + + if (!DIGITAL_DRV_CAPS_IN_CRC(ddev) && + (ddev->curr_rf_tech == NFC_DIGITAL_RF_TECH_106A)) { + ddev->skb_add_crc = digital_skb_add_crc_f; + ddev->skb_check_crc = digital_skb_check_crc_f; + } + + ddev->curr_rf_tech = NFC_DIGITAL_RF_TECH_424F; + + nfc_dep_link_is_up(ddev->nfc_dev, target->idx, NFC_COMM_ACTIVE, + NFC_RF_INITIATOR); + + ddev->curr_nfc_dep_pni = 0; + +exit: + dev_kfree_skb(resp); + + if (rc) + ddev->curr_protocol = 0; +} + +static int digital_in_send_psl_req(struct nfc_digital_dev *ddev, + struct nfc_target *target) +{ + struct sk_buff *skb; + struct digital_psl_req *psl_req; + int rc; + u8 payload_size, payload_bits; + + skb = digital_skb_alloc(ddev, sizeof(*psl_req)); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(*psl_req)); + + psl_req = (struct digital_psl_req *)skb->data; + + psl_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + psl_req->cmd = DIGITAL_CMD_PSL_REQ; + psl_req->did = 0; + psl_req->brs = (0x2 << 3) | 0x2; /* 424F both directions */ + + payload_size = min(ddev->local_payload_max, ddev->remote_payload_max); + payload_bits = digital_payload_size_to_bits(payload_size); + psl_req->fsl = DIGITAL_PAYLOAD_BITS_TO_FSL(payload_bits); + + ddev->local_payload_max = payload_size; + ddev->remote_payload_max = payload_size; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, ddev->dep_rwt, + digital_in_recv_psl_res, target); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_atr_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + struct digital_atr_res *atr_res; + u8 gb_len, payload_bits; + u8 wt; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + rc = ddev->skb_check_crc(resp); + if (rc) { + PROTOCOL_ERR("14.4.1.6"); + goto exit; + } + + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + if (resp->len < sizeof(struct digital_atr_res)) { + rc = -EIO; + goto exit; + } + + gb_len = resp->len - sizeof(struct digital_atr_res); + + atr_res = (struct digital_atr_res *)resp->data; + + wt = DIGITAL_ATR_RES_TO_WT(atr_res->to); + if (wt > DIGITAL_NFC_DEP_IN_MAX_WT) + wt = DIGITAL_NFC_DEP_IN_MAX_WT; + ddev->dep_rwt = digital_rwt_map[wt]; + + payload_bits = DIGITAL_PAYLOAD_PP_TO_BITS(atr_res->pp); + ddev->remote_payload_max = digital_payload_bits_to_size(payload_bits); + + if (!ddev->remote_payload_max) { + rc = -EINVAL; + goto exit; + } + + rc = nfc_set_remote_general_bytes(ddev->nfc_dev, atr_res->gb, gb_len); + if (rc) + goto exit; + + if ((ddev->protocols & NFC_PROTO_FELICA_MASK) && + (ddev->curr_rf_tech != NFC_DIGITAL_RF_TECH_424F)) { + rc = digital_in_send_psl_req(ddev, target); + if (!rc) + goto exit; + } + + rc = nfc_dep_link_is_up(ddev->nfc_dev, target->idx, NFC_COMM_ACTIVE, + NFC_RF_INITIATOR); + + ddev->curr_nfc_dep_pni = 0; + +exit: + dev_kfree_skb(resp); + + if (rc) + ddev->curr_protocol = 0; +} + +int digital_in_send_atr_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, __u8 comm_mode, __u8 *gb, + size_t 
gb_len) +{ + struct sk_buff *skb; + struct digital_atr_req *atr_req; + uint size; + int rc; + u8 payload_bits; + + size = DIGITAL_ATR_REQ_MIN_SIZE + gb_len; + + if (size > DIGITAL_ATR_REQ_MAX_SIZE) { + PROTOCOL_ERR("14.6.1.1"); + return -EINVAL; + } + + skb = digital_skb_alloc(ddev, size); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(struct digital_atr_req)); + + atr_req = (struct digital_atr_req *)skb->data; + memset(atr_req, 0, sizeof(struct digital_atr_req)); + + atr_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + atr_req->cmd = DIGITAL_CMD_ATR_REQ; + if (target->nfcid2_len) + memcpy(atr_req->nfcid3, target->nfcid2, NFC_NFCID2_MAXSIZE); + else + get_random_bytes(atr_req->nfcid3, NFC_NFCID3_MAXSIZE); + + atr_req->did = 0; + atr_req->bs = 0; + atr_req->br = 0; + + ddev->local_payload_max = DIGITAL_PAYLOAD_SIZE_MAX; + payload_bits = digital_payload_size_to_bits(ddev->local_payload_max); + atr_req->pp = DIGITAL_PAYLOAD_BITS_TO_PP(payload_bits); + + if (gb_len) { + atr_req->pp |= DIGITAL_GB_BIT; + skb_put_data(skb, gb, gb_len); + } + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, DIGITAL_ATR_RES_RWT, + digital_in_recv_atr_res, target); + if (rc) + kfree_skb(skb); + + return rc; +} + +static int digital_in_send_ack(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch) +{ + struct digital_dep_req_res *dep_req; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_req = (struct digital_dep_req_res *)skb->data; + + dep_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + dep_req->cmd = DIGITAL_CMD_DEP_REQ; + dep_req->pfb = DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU | + ddev->curr_nfc_dep_pni; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + ddev->saved_skb = pskb_copy(skb, GFP_KERNEL); + + rc = digital_in_send_cmd(ddev, skb, ddev->dep_rwt, + digital_in_recv_dep_res, data_exch); + if (rc) { + kfree_skb(skb); + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + } + + return rc; +} + +static int digital_in_send_nack(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch) +{ + struct digital_dep_req_res *dep_req; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_req = (struct digital_dep_req_res *)skb->data; + + dep_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + dep_req->cmd = DIGITAL_CMD_DEP_REQ; + dep_req->pfb = DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU | + DIGITAL_NFC_DEP_PFB_NACK_BIT | ddev->curr_nfc_dep_pni; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, ddev->dep_rwt, + digital_in_recv_dep_res, data_exch); + if (rc) + kfree_skb(skb); + + return rc; +} + +static int digital_in_send_atn(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch) +{ + struct digital_dep_req_res *dep_req; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_req = (struct digital_dep_req_res *)skb->data; + + dep_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + dep_req->cmd = DIGITAL_CMD_DEP_REQ; + dep_req->pfb = DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, ddev->dep_rwt, + digital_in_recv_dep_res, data_exch); + if (rc) + 
kfree_skb(skb); + + return rc; +} + +static int digital_in_send_rtox(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch, u8 rtox) +{ + struct digital_dep_req_res *dep_req; + struct sk_buff *skb; + int rc; + u16 rwt_int; + + rwt_int = ddev->dep_rwt * rtox; + if (rwt_int > digital_rwt_map[DIGITAL_NFC_DEP_IN_MAX_WT]) + rwt_int = digital_rwt_map[DIGITAL_NFC_DEP_IN_MAX_WT]; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_put_u8(skb, rtox); + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_req = (struct digital_dep_req_res *)skb->data; + + dep_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + dep_req->cmd = DIGITAL_CMD_DEP_REQ; + dep_req->pfb = DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU | + DIGITAL_NFC_DEP_PFB_TIMEOUT_BIT; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_in_send_cmd(ddev, skb, rwt_int, + digital_in_recv_dep_res, data_exch); + if (rc) + kfree_skb(skb); + + return rc; +} + +static int digital_in_send_saved_skb(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch) +{ + int rc; + + if (!ddev->saved_skb) + return -EINVAL; + + skb_get(ddev->saved_skb); + + rc = digital_in_send_cmd(ddev, ddev->saved_skb, ddev->dep_rwt, + digital_in_recv_dep_res, data_exch); + if (rc) + kfree_skb(ddev->saved_skb); + + return rc; +} + +static void digital_in_recv_dep_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct digital_data_exch *data_exch = arg; + struct digital_dep_req_res *dep_res; + u8 pfb; + uint size; + int rc; + u8 rtox; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + + if ((rc == -EIO || (rc == -ETIMEDOUT && ddev->nack_count)) && + (ddev->nack_count++ < DIGITAL_NFC_DEP_N_RETRY_NACK)) { + ddev->atn_count = 0; + + rc = digital_in_send_nack(ddev, data_exch); + if (rc) + goto error; + + return; + } else if ((rc == -ETIMEDOUT) && + (ddev->atn_count++ < DIGITAL_NFC_DEP_N_RETRY_ATN)) { + ddev->nack_count = 0; + + rc = digital_in_send_atn(ddev, data_exch); + if (rc) + goto error; + + return; + } + + goto exit; + } + + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + rc = ddev->skb_check_crc(resp); + if (rc) { + if ((resp->len >= 4) && + (ddev->nack_count++ < DIGITAL_NFC_DEP_N_RETRY_NACK)) { + ddev->atn_count = 0; + + rc = digital_in_send_nack(ddev, data_exch); + if (rc) + goto error; + + kfree_skb(resp); + + return; + } + + PROTOCOL_ERR("14.4.1.6"); + goto error; + } + + ddev->atn_count = 0; + ddev->nack_count = 0; + + if (resp->len > ddev->local_payload_max) { + rc = -EMSGSIZE; + goto exit; + } + + size = sizeof(struct digital_dep_req_res); + dep_res = (struct digital_dep_req_res *)resp->data; + + if (resp->len < size || dep_res->dir != DIGITAL_NFC_DEP_FRAME_DIR_IN || + dep_res->cmd != DIGITAL_CMD_DEP_RES) { + rc = -EIO; + goto error; + } + + pfb = dep_res->pfb; + + if (DIGITAL_NFC_DEP_DID_BIT_SET(pfb)) { + PROTOCOL_ERR("14.8.2.1"); + rc = -EIO; + goto error; + } + + if (DIGITAL_NFC_DEP_NAD_BIT_SET(pfb)) { + rc = -EIO; + goto exit; + } + + if (size > resp->len) { + rc = -EIO; + goto error; + } + + skb_pull(resp, size); + + switch (DIGITAL_NFC_DEP_PFB_TYPE(pfb)) { + case DIGITAL_NFC_DEP_PFB_I_PDU: + if (DIGITAL_NFC_DEP_PFB_PNI(pfb) != ddev->curr_nfc_dep_pni) { + PROTOCOL_ERR("14.12.3.3"); + rc = -EIO; + goto error; + } + + ddev->curr_nfc_dep_pni = + DIGITAL_NFC_DEP_PFB_PNI(ddev->curr_nfc_dep_pni + 1); + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + resp = digital_recv_dep_data_gather(ddev, 
pfb, resp, + digital_in_send_ack, + data_exch); + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto error; + } + + /* If resp is NULL then we're still chaining so return and + * wait for the next part of the PDU. Else, the PDU is + * complete so pass it up. + */ + if (!resp) + return; + + rc = 0; + break; + + case DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU: + if (DIGITAL_NFC_DEP_NACK_BIT_SET(pfb)) { + PROTOCOL_ERR("14.12.4.5"); + rc = -EIO; + goto exit; + } + + if (DIGITAL_NFC_DEP_PFB_PNI(pfb) != ddev->curr_nfc_dep_pni) { + PROTOCOL_ERR("14.12.3.3"); + rc = -EIO; + goto exit; + } + + ddev->curr_nfc_dep_pni = + DIGITAL_NFC_DEP_PFB_PNI(ddev->curr_nfc_dep_pni + 1); + + if (!ddev->chaining_skb) { + PROTOCOL_ERR("14.12.4.3"); + rc = -EIO; + goto exit; + } + + /* The initiator has received a valid ACK. Free the last sent + * PDU and keep on sending chained skb. + */ + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + rc = digital_in_send_dep_req(ddev, NULL, + ddev->chaining_skb, + ddev->data_exch); + if (rc) + goto error; + + goto free_resp; + + case DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU: + if (!DIGITAL_NFC_DEP_PFB_IS_TIMEOUT(pfb)) { /* ATN */ + rc = digital_in_send_saved_skb(ddev, data_exch); + if (rc) + goto error; + + goto free_resp; + } + + if (ddev->atn_count || ddev->nack_count) { + PROTOCOL_ERR("14.12.4.4"); + rc = -EIO; + goto error; + } + + rtox = DIGITAL_NFC_DEP_RTOX_VALUE(resp->data[0]); + if (!rtox || rtox > DIGITAL_NFC_DEP_RTOX_MAX) { + PROTOCOL_ERR("14.8.4.1"); + rc = -EIO; + goto error; + } + + rc = digital_in_send_rtox(ddev, data_exch, rtox); + if (rc) + goto error; + + goto free_resp; + } + +exit: + data_exch->cb(data_exch->cb_context, resp, rc); + +error: + kfree(data_exch); + + kfree_skb(ddev->chaining_skb); + ddev->chaining_skb = NULL; + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + if (rc) + kfree_skb(resp); + + return; + +free_resp: + dev_kfree_skb(resp); +} + +int digital_in_send_dep_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, struct sk_buff *skb, + struct digital_data_exch *data_exch) +{ + struct digital_dep_req_res *dep_req; + struct sk_buff *chaining_skb, *tmp_skb; + int rc; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_req = (struct digital_dep_req_res *)skb->data; + + dep_req->dir = DIGITAL_NFC_DEP_FRAME_DIR_OUT; + dep_req->cmd = DIGITAL_CMD_DEP_REQ; + dep_req->pfb = ddev->curr_nfc_dep_pni; + + ddev->atn_count = 0; + ddev->nack_count = 0; + + chaining_skb = ddev->chaining_skb; + + tmp_skb = digital_send_dep_data_prep(ddev, skb, dep_req, data_exch); + if (IS_ERR(tmp_skb)) + return PTR_ERR(tmp_skb); + + digital_skb_push_dep_sod(ddev, tmp_skb); + + ddev->skb_add_crc(tmp_skb); + + ddev->saved_skb = pskb_copy(tmp_skb, GFP_KERNEL); + + rc = digital_in_send_cmd(ddev, tmp_skb, ddev->dep_rwt, + digital_in_recv_dep_res, data_exch); + if (rc) { + if (tmp_skb != skb) + kfree_skb(tmp_skb); + + kfree_skb(chaining_skb); + ddev->chaining_skb = NULL; + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + } + + return rc; +} + +static void digital_tg_set_rf_tech(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + ddev->curr_rf_tech = rf_tech; + + ddev->skb_add_crc = digital_skb_add_crc_none; + ddev->skb_check_crc = digital_skb_check_crc_none; + + if (DIGITAL_DRV_CAPS_TG_CRC(ddev)) + return; + + switch (ddev->curr_rf_tech) { + case NFC_DIGITAL_RF_TECH_106A: + ddev->skb_add_crc = digital_skb_add_crc_a; + ddev->skb_check_crc = digital_skb_check_crc_a; + break; + + case NFC_DIGITAL_RF_TECH_212F: + case NFC_DIGITAL_RF_TECH_424F: + 
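+		/* NFC-F technologies (212 and 424 kbps) share the Felica CRC helpers. */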
ddev->skb_add_crc = digital_skb_add_crc_f; + ddev->skb_check_crc = digital_skb_check_crc_f; + break; + + default: + break; + } +} + +static int digital_tg_send_ack(struct nfc_digital_dev *ddev, + struct digital_data_exch *data_exch) +{ + struct digital_dep_req_res *dep_res; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_res = (struct digital_dep_req_res *)skb->data; + + dep_res->dir = DIGITAL_NFC_DEP_FRAME_DIR_IN; + dep_res->cmd = DIGITAL_CMD_DEP_RES; + dep_res->pfb = DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU | + ddev->curr_nfc_dep_pni; + + if (ddev->did) { + dep_res->pfb |= DIGITAL_NFC_DEP_PFB_DID_BIT; + + skb_put_data(skb, &ddev->did, sizeof(ddev->did)); + } + + ddev->curr_nfc_dep_pni = + DIGITAL_NFC_DEP_PFB_PNI(ddev->curr_nfc_dep_pni + 1); + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + ddev->saved_skb = pskb_copy(skb, GFP_KERNEL); + + rc = digital_tg_send_cmd(ddev, skb, 1500, digital_tg_recv_dep_req, + data_exch); + if (rc) { + kfree_skb(skb); + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + } + + return rc; +} + +static int digital_tg_send_atn(struct nfc_digital_dev *ddev) +{ + struct digital_dep_req_res *dep_res; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_res = (struct digital_dep_req_res *)skb->data; + + dep_res->dir = DIGITAL_NFC_DEP_FRAME_DIR_IN; + dep_res->cmd = DIGITAL_CMD_DEP_RES; + dep_res->pfb = DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU; + + if (ddev->did) { + dep_res->pfb |= DIGITAL_NFC_DEP_PFB_DID_BIT; + + skb_put_data(skb, &ddev->did, sizeof(ddev->did)); + } + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + rc = digital_tg_send_cmd(ddev, skb, 1500, digital_tg_recv_dep_req, + NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static int digital_tg_send_saved_skb(struct nfc_digital_dev *ddev) +{ + int rc; + + if (!ddev->saved_skb) + return -EINVAL; + + skb_get(ddev->saved_skb); + + rc = digital_tg_send_cmd(ddev, ddev->saved_skb, 1500, + digital_tg_recv_dep_req, NULL); + if (rc) + kfree_skb(ddev->saved_skb); + + return rc; +} + +static void digital_tg_recv_dep_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + int rc; + struct digital_dep_req_res *dep_req; + u8 pfb; + size_t size; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + rc = ddev->skb_check_crc(resp); + if (rc) { + PROTOCOL_ERR("14.4.1.6"); + goto exit; + } + + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + if (resp->len > ddev->local_payload_max) { + rc = -EMSGSIZE; + goto exit; + } + + size = sizeof(struct digital_dep_req_res); + dep_req = (struct digital_dep_req_res *)resp->data; + + if (resp->len < size || dep_req->dir != DIGITAL_NFC_DEP_FRAME_DIR_OUT || + dep_req->cmd != DIGITAL_CMD_DEP_REQ) { + rc = -EIO; + goto exit; + } + + pfb = dep_req->pfb; + + if (DIGITAL_NFC_DEP_DID_BIT_SET(pfb)) { + if (ddev->did && (ddev->did == resp->data[3])) { + size++; + } else { + rc = -EIO; + goto exit; + } + } else if (ddev->did) { + rc = -EIO; + goto exit; + } + + if (DIGITAL_NFC_DEP_NAD_BIT_SET(pfb)) { + rc = -EIO; + goto exit; + } + + if (size > resp->len) { + rc = -EIO; + goto exit; + } + + skb_pull(resp, size); + + switch (DIGITAL_NFC_DEP_PFB_TYPE(pfb)) { + case DIGITAL_NFC_DEP_PFB_I_PDU: + pr_debug("DIGITAL_NFC_DEP_PFB_I_PDU\n"); + + 
if (ddev->atn_count) { + /* The target has received (and replied to) at least one + * ATN DEP_REQ. + */ + ddev->atn_count = 0; + + /* pni of resp PDU equal to the target current pni - 1 + * means resp is the previous DEP_REQ PDU received from + * the initiator so the target replies with saved_skb + * which is the previous DEP_RES saved in + * digital_tg_send_dep_res(). + */ + if (DIGITAL_NFC_DEP_PFB_PNI(pfb) == + DIGITAL_NFC_DEP_PFB_PNI(ddev->curr_nfc_dep_pni - 1)) { + rc = digital_tg_send_saved_skb(ddev); + if (rc) + goto exit; + + goto free_resp; + } + + /* atn_count > 0 and PDU pni != curr_nfc_dep_pni - 1 + * means the target probably did not received the last + * DEP_REQ PDU sent by the initiator. The target + * fallbacks to normal processing then. + */ + } + + if (DIGITAL_NFC_DEP_PFB_PNI(pfb) != ddev->curr_nfc_dep_pni) { + PROTOCOL_ERR("14.12.3.4"); + rc = -EIO; + goto exit; + } + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + resp = digital_recv_dep_data_gather(ddev, pfb, resp, + digital_tg_send_ack, NULL); + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + /* If resp is NULL then we're still chaining so return and + * wait for the next part of the PDU. Else, the PDU is + * complete so pass it up. + */ + if (!resp) + return; + + rc = 0; + break; + case DIGITAL_NFC_DEP_PFB_ACK_NACK_PDU: + if (DIGITAL_NFC_DEP_NACK_BIT_SET(pfb)) { /* NACK */ + if (DIGITAL_NFC_DEP_PFB_PNI(pfb + 1) != + ddev->curr_nfc_dep_pni) { + rc = -EIO; + goto exit; + } + + ddev->atn_count = 0; + + rc = digital_tg_send_saved_skb(ddev); + if (rc) + goto exit; + + goto free_resp; + } + + /* ACK */ + if (ddev->atn_count) { + /* The target has previously received one or more ATN + * PDUs. + */ + ddev->atn_count = 0; + + /* If the ACK PNI is equal to the target PNI - 1 means + * that the initiator did not receive the previous PDU + * sent by the target so re-send it. + */ + if (DIGITAL_NFC_DEP_PFB_PNI(pfb + 1) == + ddev->curr_nfc_dep_pni) { + rc = digital_tg_send_saved_skb(ddev); + if (rc) + goto exit; + + goto free_resp; + } + + /* Otherwise, the target did not receive the previous + * ACK PDU from the initiator. Fallback to normal + * processing of chained PDU then. 
+ */ + } + + /* Keep on sending chained PDU */ + if (!ddev->chaining_skb || + DIGITAL_NFC_DEP_PFB_PNI(pfb) != + ddev->curr_nfc_dep_pni) { + rc = -EIO; + goto exit; + } + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + rc = digital_tg_send_dep_res(ddev, ddev->chaining_skb); + if (rc) + goto exit; + + goto free_resp; + case DIGITAL_NFC_DEP_PFB_SUPERVISOR_PDU: + if (DIGITAL_NFC_DEP_PFB_IS_TIMEOUT(pfb)) { + rc = -EINVAL; + goto exit; + } + + rc = digital_tg_send_atn(ddev); + if (rc) + goto exit; + + ddev->atn_count++; + + goto free_resp; + } + + rc = nfc_tm_data_received(ddev->nfc_dev, resp); + if (rc) + resp = NULL; + +exit: + kfree_skb(ddev->chaining_skb); + ddev->chaining_skb = NULL; + + ddev->atn_count = 0; + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + + if (rc) + kfree_skb(resp); + + return; + +free_resp: + dev_kfree_skb(resp); +} + +int digital_tg_send_dep_res(struct nfc_digital_dev *ddev, struct sk_buff *skb) +{ + struct digital_dep_req_res *dep_res; + struct sk_buff *chaining_skb, *tmp_skb; + int rc; + + skb_push(skb, sizeof(struct digital_dep_req_res)); + + dep_res = (struct digital_dep_req_res *)skb->data; + + dep_res->dir = DIGITAL_NFC_DEP_FRAME_DIR_IN; + dep_res->cmd = DIGITAL_CMD_DEP_RES; + dep_res->pfb = ddev->curr_nfc_dep_pni; + + if (ddev->did) { + dep_res->pfb |= DIGITAL_NFC_DEP_PFB_DID_BIT; + + skb_put_data(skb, &ddev->did, sizeof(ddev->did)); + } + + ddev->curr_nfc_dep_pni = + DIGITAL_NFC_DEP_PFB_PNI(ddev->curr_nfc_dep_pni + 1); + + chaining_skb = ddev->chaining_skb; + + tmp_skb = digital_send_dep_data_prep(ddev, skb, dep_res, NULL); + if (IS_ERR(tmp_skb)) + return PTR_ERR(tmp_skb); + + digital_skb_push_dep_sod(ddev, tmp_skb); + + ddev->skb_add_crc(tmp_skb); + + ddev->saved_skb = pskb_copy(tmp_skb, GFP_KERNEL); + + rc = digital_tg_send_cmd(ddev, tmp_skb, 1500, digital_tg_recv_dep_req, + NULL); + if (rc) { + if (tmp_skb != skb) + kfree_skb(tmp_skb); + + kfree_skb(chaining_skb); + ddev->chaining_skb = NULL; + + kfree_skb(ddev->saved_skb); + ddev->saved_skb = NULL; + } + + return rc; +} + +static void digital_tg_send_psl_res_complete(struct nfc_digital_dev *ddev, + void *arg, struct sk_buff *resp) +{ + u8 rf_tech = (unsigned long)arg; + + if (IS_ERR(resp)) + return; + + digital_tg_set_rf_tech(ddev, rf_tech); + + digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, rf_tech); + + digital_tg_listen(ddev, 1500, digital_tg_recv_dep_req, NULL); + + dev_kfree_skb(resp); +} + +static int digital_tg_send_psl_res(struct nfc_digital_dev *ddev, u8 did, + u8 rf_tech) +{ + struct digital_psl_res *psl_res; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, sizeof(struct digital_psl_res)); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(struct digital_psl_res)); + + psl_res = (struct digital_psl_res *)skb->data; + + psl_res->dir = DIGITAL_NFC_DEP_FRAME_DIR_IN; + psl_res->cmd = DIGITAL_CMD_PSL_RES; + psl_res->did = did; + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + ddev->curr_nfc_dep_pni = 0; + + rc = digital_tg_send_cmd(ddev, skb, 0, digital_tg_send_psl_res_complete, + (void *)(unsigned long)rf_tech); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_tg_recv_psl_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + int rc; + struct digital_psl_req *psl_req; + u8 rf_tech; + u8 dsi, payload_size, payload_bits; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + rc = ddev->skb_check_crc(resp); + if (rc) { + PROTOCOL_ERR("14.4.1.6"); + goto exit; + } 
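+	/* CRC is valid: strip the NFC-DEP SoD (the SB byte on NFC-A plus the
+	 * length byte) before parsing the PSL_REQ fields.
+	 */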
+ + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + psl_req = (struct digital_psl_req *)resp->data; + + if (resp->len != sizeof(struct digital_psl_req) || + psl_req->dir != DIGITAL_NFC_DEP_FRAME_DIR_OUT || + psl_req->cmd != DIGITAL_CMD_PSL_REQ) { + rc = -EIO; + goto exit; + } + + dsi = (psl_req->brs >> 3) & 0x07; + switch (dsi) { + case 0: + rf_tech = NFC_DIGITAL_RF_TECH_106A; + break; + case 1: + rf_tech = NFC_DIGITAL_RF_TECH_212F; + break; + case 2: + rf_tech = NFC_DIGITAL_RF_TECH_424F; + break; + default: + pr_err("Unsupported dsi value %d\n", dsi); + goto exit; + } + + payload_bits = DIGITAL_PAYLOAD_FSL_TO_BITS(psl_req->fsl); + payload_size = digital_payload_bits_to_size(payload_bits); + + if (!payload_size || (payload_size > min(ddev->local_payload_max, + ddev->remote_payload_max))) { + rc = -EINVAL; + goto exit; + } + + ddev->local_payload_max = payload_size; + ddev->remote_payload_max = payload_size; + + rc = digital_tg_send_psl_res(ddev, psl_req->did, rf_tech); + +exit: + kfree_skb(resp); +} + +static void digital_tg_send_atr_res_complete(struct nfc_digital_dev *ddev, + void *arg, struct sk_buff *resp) +{ + int offset; + + if (IS_ERR(resp)) { + digital_poll_next_tech(ddev); + return; + } + + offset = 2; + if (resp->data[0] == DIGITAL_NFC_DEP_NFCA_SOD_SB) + offset++; + + ddev->atn_count = 0; + + if (resp->data[offset] == DIGITAL_CMD_PSL_REQ) + digital_tg_recv_psl_req(ddev, arg, resp); + else + digital_tg_recv_dep_req(ddev, arg, resp); +} + +static int digital_tg_send_atr_res(struct nfc_digital_dev *ddev, + struct digital_atr_req *atr_req) +{ + struct digital_atr_res *atr_res; + struct sk_buff *skb; + u8 *gb, payload_bits; + size_t gb_len; + int rc; + + gb = nfc_get_local_general_bytes(ddev->nfc_dev, &gb_len); + if (!gb) + gb_len = 0; + + skb = digital_skb_alloc(ddev, sizeof(struct digital_atr_res) + gb_len); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(struct digital_atr_res)); + atr_res = (struct digital_atr_res *)skb->data; + + memset(atr_res, 0, sizeof(struct digital_atr_res)); + + atr_res->dir = DIGITAL_NFC_DEP_FRAME_DIR_IN; + atr_res->cmd = DIGITAL_CMD_ATR_RES; + memcpy(atr_res->nfcid3, atr_req->nfcid3, sizeof(atr_req->nfcid3)); + atr_res->to = DIGITAL_NFC_DEP_TG_MAX_WT; + + ddev->local_payload_max = DIGITAL_PAYLOAD_SIZE_MAX; + payload_bits = digital_payload_size_to_bits(ddev->local_payload_max); + atr_res->pp = DIGITAL_PAYLOAD_BITS_TO_PP(payload_bits); + + if (gb_len) { + skb_put(skb, gb_len); + + atr_res->pp |= DIGITAL_GB_BIT; + memcpy(atr_res->gb, gb, gb_len); + } + + digital_skb_push_dep_sod(ddev, skb); + + ddev->skb_add_crc(skb); + + ddev->curr_nfc_dep_pni = 0; + + rc = digital_tg_send_cmd(ddev, skb, 999, + digital_tg_send_atr_res_complete, NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +void digital_tg_recv_atr_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + int rc; + struct digital_atr_req *atr_req; + size_t gb_len, min_size; + u8 poll_tech_count, payload_bits; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (!resp->len) { + rc = -EIO; + goto exit; + } + + if (resp->data[0] == DIGITAL_NFC_DEP_NFCA_SOD_SB) { + min_size = DIGITAL_ATR_REQ_MIN_SIZE + 2; + digital_tg_set_rf_tech(ddev, NFC_DIGITAL_RF_TECH_106A); + } else { + min_size = DIGITAL_ATR_REQ_MIN_SIZE + 1; + digital_tg_set_rf_tech(ddev, NFC_DIGITAL_RF_TECH_212F); + } + + if (resp->len < min_size) { + rc = -EIO; + goto exit; + } + + ddev->curr_protocol = NFC_PROTO_NFC_DEP_MASK; + + rc 
= ddev->skb_check_crc(resp); + if (rc) { + PROTOCOL_ERR("14.4.1.6"); + goto exit; + } + + rc = digital_skb_pull_dep_sod(ddev, resp); + if (rc) { + PROTOCOL_ERR("14.4.1.2"); + goto exit; + } + + atr_req = (struct digital_atr_req *)resp->data; + + if (atr_req->dir != DIGITAL_NFC_DEP_FRAME_DIR_OUT || + atr_req->cmd != DIGITAL_CMD_ATR_REQ || + atr_req->did > DIGITAL_DID_MAX) { + rc = -EINVAL; + goto exit; + } + + payload_bits = DIGITAL_PAYLOAD_PP_TO_BITS(atr_req->pp); + ddev->remote_payload_max = digital_payload_bits_to_size(payload_bits); + + if (!ddev->remote_payload_max) { + rc = -EINVAL; + goto exit; + } + + ddev->did = atr_req->did; + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFC_DEP_ACTIVATED); + if (rc) + goto exit; + + rc = digital_tg_send_atr_res(ddev, atr_req); + if (rc) + goto exit; + + gb_len = resp->len - sizeof(struct digital_atr_req); + + poll_tech_count = ddev->poll_tech_count; + ddev->poll_tech_count = 0; + + rc = nfc_tm_activated(ddev->nfc_dev, NFC_PROTO_NFC_DEP_MASK, + NFC_COMM_PASSIVE, atr_req->gb, gb_len); + if (rc) { + ddev->poll_tech_count = poll_tech_count; + goto exit; + } + + rc = 0; +exit: + if (rc) + digital_poll_next_tech(ddev); + + dev_kfree_skb(resp); +} diff --git a/net/nfc/digital_technology.c b/net/nfc/digital_technology.c new file mode 100644 index 000000000..3adf45898 --- /dev/null +++ b/net/nfc/digital_technology.c @@ -0,0 +1,1300 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * NFC Digital Protocol stack + * Copyright (c) 2013, Intel Corporation. + */ + +#define pr_fmt(fmt) "digital: %s: " fmt, __func__ + +#include "digital.h" + +#define DIGITAL_CMD_SENS_REQ 0x26 +#define DIGITAL_CMD_ALL_REQ 0x52 +#define DIGITAL_CMD_SEL_REQ_CL1 0x93 +#define DIGITAL_CMD_SEL_REQ_CL2 0x95 +#define DIGITAL_CMD_SEL_REQ_CL3 0x97 + +#define DIGITAL_SDD_REQ_SEL_PAR 0x20 + +#define DIGITAL_SDD_RES_CT 0x88 +#define DIGITAL_SDD_RES_LEN 5 +#define DIGITAL_SEL_RES_LEN 1 + +#define DIGITAL_SEL_RES_NFCID1_COMPLETE(sel_res) (!((sel_res) & 0x04)) +#define DIGITAL_SEL_RES_IS_T2T(sel_res) (!((sel_res) & 0x60)) +#define DIGITAL_SEL_RES_IS_T4T(sel_res) ((sel_res) & 0x20) +#define DIGITAL_SEL_RES_IS_NFC_DEP(sel_res) ((sel_res) & 0x40) + +#define DIGITAL_SENS_RES_IS_T1T(sens_res) (((sens_res) & 0x0C00) == 0x0C00) +#define DIGITAL_SENS_RES_IS_VALID(sens_res) \ + ((!((sens_res) & 0x001F) && (((sens_res) & 0x0C00) == 0x0C00)) || \ + (((sens_res) & 0x001F) && ((sens_res) & 0x0C00) != 0x0C00)) + +#define DIGITAL_MIFARE_READ_RES_LEN 16 +#define DIGITAL_MIFARE_ACK_RES 0x0A + +#define DIGITAL_CMD_SENSB_REQ 0x05 +#define DIGITAL_SENSB_ADVANCED BIT(5) +#define DIGITAL_SENSB_EXTENDED BIT(4) +#define DIGITAL_SENSB_ALLB_REQ BIT(3) +#define DIGITAL_SENSB_N(n) ((n) & 0x7) + +#define DIGITAL_CMD_SENSB_RES 0x50 + +#define DIGITAL_CMD_ATTRIB_REQ 0x1D +#define DIGITAL_ATTRIB_P1_TR0_DEFAULT (0x0 << 6) +#define DIGITAL_ATTRIB_P1_TR1_DEFAULT (0x0 << 4) +#define DIGITAL_ATTRIB_P1_SUPRESS_EOS BIT(3) +#define DIGITAL_ATTRIB_P1_SUPRESS_SOS BIT(2) +#define DIGITAL_ATTRIB_P2_LISTEN_POLL_1 (0x0 << 6) +#define DIGITAL_ATTRIB_P2_POLL_LISTEN_1 (0x0 << 4) +#define DIGITAL_ATTRIB_P2_MAX_FRAME_256 0x8 +#define DIGITAL_ATTRIB_P4_DID(n) ((n) & 0xf) + +#define DIGITAL_CMD_SENSF_REQ 0x00 +#define DIGITAL_CMD_SENSF_RES 0x01 + +#define DIGITAL_SENSF_RES_MIN_LENGTH 17 +#define DIGITAL_SENSF_RES_RD_AP_B1 0x00 +#define DIGITAL_SENSF_RES_RD_AP_B2 0x8F + +#define DIGITAL_SENSF_REQ_RC_NONE 0 +#define DIGITAL_SENSF_REQ_RC_SC 1 +#define DIGITAL_SENSF_REQ_RC_AP 2 + +#define 
DIGITAL_CMD_ISO15693_INVENTORY_REQ 0x01 + +#define DIGITAL_ISO15693_REQ_FLAG_DATA_RATE BIT(1) +#define DIGITAL_ISO15693_REQ_FLAG_INVENTORY BIT(2) +#define DIGITAL_ISO15693_REQ_FLAG_NB_SLOTS BIT(5) +#define DIGITAL_ISO15693_RES_FLAG_ERROR BIT(0) +#define DIGITAL_ISO15693_RES_IS_VALID(flags) \ + (!((flags) & DIGITAL_ISO15693_RES_FLAG_ERROR)) + +#define DIGITAL_ISO_DEP_I_PCB 0x02 +#define DIGITAL_ISO_DEP_PNI(pni) ((pni) & 0x01) + +#define DIGITAL_ISO_DEP_PCB_TYPE(pcb) ((pcb) & 0xC0) + +#define DIGITAL_ISO_DEP_I_BLOCK 0x00 + +#define DIGITAL_ISO_DEP_BLOCK_HAS_DID(pcb) ((pcb) & 0x08) + +static const u8 digital_ats_fsc[] = { + 16, 24, 32, 40, 48, 64, 96, 128, +}; + +#define DIGITAL_ATS_FSCI(t0) ((t0) & 0x0F) +#define DIGITAL_SENSB_FSCI(pi2) (((pi2) & 0xF0) >> 4) +#define DIGITAL_ATS_MAX_FSC 256 + +#define DIGITAL_RATS_BYTE1 0xE0 +#define DIGITAL_RATS_PARAM 0x80 + +struct digital_sdd_res { + u8 nfcid1[4]; + u8 bcc; +} __packed; + +struct digital_sel_req { + u8 sel_cmd; + u8 b2; + u8 nfcid1[4]; + u8 bcc; +} __packed; + +struct digital_sensb_req { + u8 cmd; + u8 afi; + u8 param; +} __packed; + +struct digital_sensb_res { + u8 cmd; + u8 nfcid0[4]; + u8 app_data[4]; + u8 proto_info[3]; +} __packed; + +struct digital_attrib_req { + u8 cmd; + u8 nfcid0[4]; + u8 param1; + u8 param2; + u8 param3; + u8 param4; +} __packed; + +struct digital_attrib_res { + u8 mbli_did; +} __packed; + +struct digital_sensf_req { + u8 cmd; + u8 sc1; + u8 sc2; + u8 rc; + u8 tsn; +} __packed; + +struct digital_sensf_res { + u8 cmd; + u8 nfcid2[8]; + u8 pad0[2]; + u8 pad1[3]; + u8 mrti_check; + u8 mrti_update; + u8 pad2; + u8 rd[2]; +} __packed; + +struct digital_iso15693_inv_req { + u8 flags; + u8 cmd; + u8 mask_len; + u64 mask; +} __packed; + +struct digital_iso15693_inv_res { + u8 flags; + u8 dsfid; + u64 uid; +} __packed; + +static int digital_in_send_sdd_req(struct nfc_digital_dev *ddev, + struct nfc_target *target); + +int digital_in_iso_dep_pull_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb) +{ + u8 pcb; + u8 block_type; + + if (skb->len < 1) + return -EIO; + + pcb = *skb->data; + block_type = DIGITAL_ISO_DEP_PCB_TYPE(pcb); + + /* No support fo R-block nor S-block */ + if (block_type != DIGITAL_ISO_DEP_I_BLOCK) { + pr_err("ISO_DEP R-block and S-block not supported\n"); + return -EIO; + } + + if (DIGITAL_ISO_DEP_BLOCK_HAS_DID(pcb)) { + pr_err("DID field in ISO_DEP PCB not supported\n"); + return -EIO; + } + + skb_pull(skb, 1); + + return 0; +} + +int digital_in_iso_dep_push_sod(struct nfc_digital_dev *ddev, + struct sk_buff *skb) +{ + /* + * Chaining not supported so skb->len + 1 PCB byte + 2 CRC bytes must + * not be greater than remote FSC + */ + if (skb->len + 3 > ddev->target_fsc) + return -EIO; + + skb_push(skb, 1); + + *skb->data = DIGITAL_ISO_DEP_I_PCB | ddev->curr_nfc_dep_pni; + + ddev->curr_nfc_dep_pni = + DIGITAL_ISO_DEP_PNI(ddev->curr_nfc_dep_pni + 1); + + return 0; +} + +static void digital_in_recv_ats(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + u8 fsdi; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len < 2) { + rc = -EIO; + goto exit; + } + + fsdi = DIGITAL_ATS_FSCI(resp->data[1]); + if (fsdi >= 8) + ddev->target_fsc = DIGITAL_ATS_MAX_FSC; + else + ddev->target_fsc = digital_ats_fsc[fsdi]; + + ddev->curr_nfc_dep_pni = 0; + + rc = digital_target_found(ddev, target, NFC_PROTO_ISO14443); + +exit: + dev_kfree_skb(resp); + kfree(target); + + if (rc) + digital_poll_next_tech(ddev); +} + 
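The ISO-DEP helpers above pack two pieces of protocol arithmetic that are easy to miss: the FSCI nibble of the ATS T0 byte indexes digital_ats_fsc[] (values of 8 and above select the 256-byte maximum), and every I-block spends one PCB byte plus two CRC bytes of that frame size, while the block number alternates modulo 2. The userspace sketch below is illustrative only and is not part of the upstream file; the table values and limits are copied from the definitions above.

/* Illustrative sketch (not kernel code): FSCI-to-FSC lookup and the
 * per-I-block payload budget used by the ISO-DEP helpers above.
 */
#include <assert.h>
#include <stdio.h>

static const unsigned int ats_fsc[] = { 16, 24, 32, 40, 48, 64, 96, 128 };
#define ATS_MAX_FSC 256U

static unsigned int fsci_to_fsc(unsigned int fsci)
{
	/* FSCI values of 8 and above are treated as the maximum frame size. */
	return (fsci < 8) ? ats_fsc[fsci] : ATS_MAX_FSC;
}

static unsigned int iso_dep_max_payload(unsigned int fsc)
{
	/* One PCB byte and two CRC bytes are consumed per I-block. */
	return fsc - 3;
}

int main(void)
{
	unsigned int pni = 0;

	assert(fsci_to_fsc(8) == 256);			/* reserved FSCI -> maximum FSC */
	assert(iso_dep_max_payload(fsci_to_fsc(5)) == 61);	/* 64 - 3 */

	/* The block number alternates 0, 1, 0, 1, ... like curr_nfc_dep_pni. */
	pni = (pni + 1) & 0x01;
	printf("pni after one I-block: %u\n", pni);

	return 0;
}

The same fsc - 3 budget is what digital_in_iso_dep_push_sod() enforces before transmitting, since ISO-DEP chaining is not supported by this implementation.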
+static int digital_in_send_rats(struct nfc_digital_dev *ddev, + struct nfc_target *target) +{ + int rc; + struct sk_buff *skb; + + skb = digital_skb_alloc(ddev, 2); + if (!skb) + return -ENOMEM; + + skb_put_u8(skb, DIGITAL_RATS_BYTE1); + skb_put_u8(skb, DIGITAL_RATS_PARAM); + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_ats, + target); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_sel_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + int rc; + u8 sel_res; + u8 nfc_proto; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (!DIGITAL_DRV_CAPS_IN_CRC(ddev)) { + rc = digital_skb_check_crc_a(resp); + if (rc) { + PROTOCOL_ERR("4.4.1.3"); + goto exit; + } + } + + if (resp->len != DIGITAL_SEL_RES_LEN) { + rc = -EIO; + goto exit; + } + + sel_res = resp->data[0]; + + if (!DIGITAL_SEL_RES_NFCID1_COMPLETE(sel_res)) { + rc = digital_in_send_sdd_req(ddev, target); + if (rc) + goto exit; + + goto exit_free_skb; + } + + target->sel_res = sel_res; + + if (DIGITAL_SEL_RES_IS_T2T(sel_res)) { + nfc_proto = NFC_PROTO_MIFARE; + } else if (DIGITAL_SEL_RES_IS_NFC_DEP(sel_res)) { + nfc_proto = NFC_PROTO_NFC_DEP; + } else if (DIGITAL_SEL_RES_IS_T4T(sel_res)) { + rc = digital_in_send_rats(ddev, target); + if (rc) + goto exit; + /* + * Skip target_found and don't free it for now. This will be + * done when receiving the ATS + */ + goto exit_free_skb; + } else { + rc = -EOPNOTSUPP; + goto exit; + } + + rc = digital_target_found(ddev, target, nfc_proto); + +exit: + kfree(target); + +exit_free_skb: + dev_kfree_skb(resp); + + if (rc) + digital_poll_next_tech(ddev); +} + +static int digital_in_send_sel_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, + struct digital_sdd_res *sdd_res) +{ + struct sk_buff *skb; + struct digital_sel_req *sel_req; + u8 sel_cmd; + int rc; + + skb = digital_skb_alloc(ddev, sizeof(struct digital_sel_req)); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(struct digital_sel_req)); + sel_req = (struct digital_sel_req *)skb->data; + + if (target->nfcid1_len <= 4) + sel_cmd = DIGITAL_CMD_SEL_REQ_CL1; + else if (target->nfcid1_len < 10) + sel_cmd = DIGITAL_CMD_SEL_REQ_CL2; + else + sel_cmd = DIGITAL_CMD_SEL_REQ_CL3; + + sel_req->sel_cmd = sel_cmd; + sel_req->b2 = 0x70; + memcpy(sel_req->nfcid1, sdd_res->nfcid1, 4); + sel_req->bcc = sdd_res->bcc; + + if (DIGITAL_DRV_CAPS_IN_CRC(ddev)) { + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_STANDARD_WITH_CRC_A); + if (rc) + goto exit; + } else { + digital_skb_add_crc_a(skb); + } + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_sel_res, + target); +exit: + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_sdd_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + struct digital_sdd_res *sdd_res; + int rc; + u8 offset, size; + u8 i, bcc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len < DIGITAL_SDD_RES_LEN) { + PROTOCOL_ERR("4.7.2.8"); + rc = -EINVAL; + goto exit; + } + + sdd_res = (struct digital_sdd_res *)resp->data; + + for (i = 0, bcc = 0; i < 4; i++) + bcc ^= sdd_res->nfcid1[i]; + + if (bcc != sdd_res->bcc) { + PROTOCOL_ERR("4.7.2.6"); + rc = -EINVAL; + goto exit; + } + + if (sdd_res->nfcid1[0] == DIGITAL_SDD_RES_CT) { + offset = 1; + size = 3; + } else { + offset = 0; + size = 4; + } + + memcpy(target->nfcid1 + 
target->nfcid1_len, sdd_res->nfcid1 + offset, + size); + target->nfcid1_len += size; + + rc = digital_in_send_sel_req(ddev, target, sdd_res); + +exit: + dev_kfree_skb(resp); + + if (rc) { + kfree(target); + digital_poll_next_tech(ddev); + } +} + +static int digital_in_send_sdd_req(struct nfc_digital_dev *ddev, + struct nfc_target *target) +{ + int rc; + struct sk_buff *skb; + u8 sel_cmd; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_STANDARD); + if (rc) + return rc; + + skb = digital_skb_alloc(ddev, 2); + if (!skb) + return -ENOMEM; + + if (target->nfcid1_len == 0) + sel_cmd = DIGITAL_CMD_SEL_REQ_CL1; + else if (target->nfcid1_len == 3) + sel_cmd = DIGITAL_CMD_SEL_REQ_CL2; + else + sel_cmd = DIGITAL_CMD_SEL_REQ_CL3; + + skb_put_u8(skb, sel_cmd); + skb_put_u8(skb, DIGITAL_SDD_REQ_SEL_PAR); + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_sdd_res, + target); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_sens_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = NULL; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len < sizeof(u16)) { + rc = -EIO; + goto exit; + } + + target = kzalloc(sizeof(struct nfc_target), GFP_KERNEL); + if (!target) { + rc = -ENOMEM; + goto exit; + } + + target->sens_res = __le16_to_cpu(*(__le16 *)resp->data); + + if (!DIGITAL_SENS_RES_IS_VALID(target->sens_res)) { + PROTOCOL_ERR("4.6.3.3"); + rc = -EINVAL; + goto exit; + } + + if (DIGITAL_SENS_RES_IS_T1T(target->sens_res)) + rc = digital_target_found(ddev, target, NFC_PROTO_JEWEL); + else + rc = digital_in_send_sdd_req(ddev, target); + +exit: + dev_kfree_skb(resp); + + if (rc) { + kfree(target); + digital_poll_next_tech(ddev); + } +} + +int digital_in_send_sens_req(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + struct sk_buff *skb; + int rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, + NFC_DIGITAL_RF_TECH_106A); + if (rc) + return rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_SHORT); + if (rc) + return rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_put_u8(skb, DIGITAL_CMD_SENS_REQ); + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_sens_res, NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +int digital_in_recv_mifare_res(struct sk_buff *resp) +{ + /* Successful READ command response is 16 data bytes + 2 CRC bytes long. + * Since the driver can't differentiate a ACK/NACK response from a valid + * READ response, the CRC calculation must be handled at digital level + * even if the driver supports it for this technology. + */ + if (resp->len == DIGITAL_MIFARE_READ_RES_LEN + DIGITAL_CRC_LEN) { + if (digital_skb_check_crc_a(resp)) { + PROTOCOL_ERR("9.4.1.2"); + return -EIO; + } + + return 0; + } + + /* ACK response (i.e. successful WRITE). */ + if (resp->len == 1 && resp->data[0] == DIGITAL_MIFARE_ACK_RES) { + resp->data[0] = 0; + return 0; + } + + /* NACK and any other responses are treated as error. 
*/ + return -EIO; +} + +static void digital_in_recv_attrib_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = arg; + struct digital_attrib_res *attrib_res; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len < sizeof(*attrib_res)) { + PROTOCOL_ERR("12.6.2"); + rc = -EIO; + goto exit; + } + + attrib_res = (struct digital_attrib_res *)resp->data; + + if (attrib_res->mbli_did & 0x0f) { + PROTOCOL_ERR("12.6.2.1"); + rc = -EIO; + goto exit; + } + + rc = digital_target_found(ddev, target, NFC_PROTO_ISO14443_B); + +exit: + dev_kfree_skb(resp); + kfree(target); + + if (rc) + digital_poll_next_tech(ddev); +} + +static int digital_in_send_attrib_req(struct nfc_digital_dev *ddev, + struct nfc_target *target, + struct digital_sensb_res *sensb_res) +{ + struct digital_attrib_req *attrib_req; + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, sizeof(*attrib_req)); + if (!skb) + return -ENOMEM; + + attrib_req = skb_put(skb, sizeof(*attrib_req)); + + attrib_req->cmd = DIGITAL_CMD_ATTRIB_REQ; + memcpy(attrib_req->nfcid0, sensb_res->nfcid0, + sizeof(attrib_req->nfcid0)); + attrib_req->param1 = DIGITAL_ATTRIB_P1_TR0_DEFAULT | + DIGITAL_ATTRIB_P1_TR1_DEFAULT; + attrib_req->param2 = DIGITAL_ATTRIB_P2_LISTEN_POLL_1 | + DIGITAL_ATTRIB_P2_POLL_LISTEN_1 | + DIGITAL_ATTRIB_P2_MAX_FRAME_256; + attrib_req->param3 = sensb_res->proto_info[1] & 0x07; + attrib_req->param4 = DIGITAL_ATTRIB_P4_DID(0); + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_attrib_res, + target); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_sensb_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct nfc_target *target = NULL; + struct digital_sensb_res *sensb_res; + u8 fsci; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len != sizeof(*sensb_res)) { + PROTOCOL_ERR("5.6.2.1"); + rc = -EIO; + goto exit; + } + + sensb_res = (struct digital_sensb_res *)resp->data; + + if (sensb_res->cmd != DIGITAL_CMD_SENSB_RES) { + PROTOCOL_ERR("5.6.2"); + rc = -EIO; + goto exit; + } + + if (!(sensb_res->proto_info[1] & BIT(0))) { + PROTOCOL_ERR("5.6.2.12"); + rc = -EIO; + goto exit; + } + + if (sensb_res->proto_info[1] & BIT(3)) { + PROTOCOL_ERR("5.6.2.16"); + rc = -EIO; + goto exit; + } + + fsci = DIGITAL_SENSB_FSCI(sensb_res->proto_info[1]); + if (fsci >= 8) + ddev->target_fsc = DIGITAL_ATS_MAX_FSC; + else + ddev->target_fsc = digital_ats_fsc[fsci]; + + target = kzalloc(sizeof(struct nfc_target), GFP_KERNEL); + if (!target) { + rc = -ENOMEM; + goto exit; + } + + rc = digital_in_send_attrib_req(ddev, target, sensb_res); + +exit: + dev_kfree_skb(resp); + + if (rc) { + kfree(target); + digital_poll_next_tech(ddev); + } +} + +int digital_in_send_sensb_req(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + struct digital_sensb_req *sensb_req; + struct sk_buff *skb; + int rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, + NFC_DIGITAL_RF_TECH_106B); + if (rc) + return rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCB); + if (rc) + return rc; + + skb = digital_skb_alloc(ddev, sizeof(*sensb_req)); + if (!skb) + return -ENOMEM; + + sensb_req = skb_put(skb, sizeof(*sensb_req)); + + sensb_req->cmd = DIGITAL_CMD_SENSB_REQ; + sensb_req->afi = 0x00; /* All families and sub-families */ + sensb_req->param = DIGITAL_SENSB_N(0); + + rc = digital_in_send_cmd(ddev, skb, 30, 
digital_in_recv_sensb_res, + NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_sensf_res(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + int rc; + u8 proto; + struct nfc_target target; + struct digital_sensf_res *sensf_res; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (resp->len < DIGITAL_SENSF_RES_MIN_LENGTH) { + rc = -EIO; + goto exit; + } + + if (!DIGITAL_DRV_CAPS_IN_CRC(ddev)) { + rc = digital_skb_check_crc_f(resp); + if (rc) { + PROTOCOL_ERR("6.4.1.8"); + goto exit; + } + } + + skb_pull(resp, 1); + + memset(&target, 0, sizeof(struct nfc_target)); + + sensf_res = (struct digital_sensf_res *)resp->data; + + memcpy(target.sensf_res, sensf_res, resp->len); + target.sensf_res_len = resp->len; + + memcpy(target.nfcid2, sensf_res->nfcid2, NFC_NFCID2_MAXSIZE); + target.nfcid2_len = NFC_NFCID2_MAXSIZE; + + if (target.nfcid2[0] == DIGITAL_SENSF_NFCID2_NFC_DEP_B1 && + target.nfcid2[1] == DIGITAL_SENSF_NFCID2_NFC_DEP_B2) + proto = NFC_PROTO_NFC_DEP; + else + proto = NFC_PROTO_FELICA; + + rc = digital_target_found(ddev, &target, proto); + +exit: + dev_kfree_skb(resp); + + if (rc) + digital_poll_next_tech(ddev); +} + +int digital_in_send_sensf_req(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + struct digital_sensf_req *sensf_req; + struct sk_buff *skb; + int rc; + u8 size; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, rf_tech); + if (rc) + return rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCF); + if (rc) + return rc; + + size = sizeof(struct digital_sensf_req); + + skb = digital_skb_alloc(ddev, size); + if (!skb) + return -ENOMEM; + + skb_put(skb, size); + + sensf_req = (struct digital_sensf_req *)skb->data; + sensf_req->cmd = DIGITAL_CMD_SENSF_REQ; + sensf_req->sc1 = 0xFF; + sensf_req->sc2 = 0xFF; + sensf_req->rc = 0; + sensf_req->tsn = 0; + + *(u8 *)skb_push(skb, 1) = size + 1; + + if (!DIGITAL_DRV_CAPS_IN_CRC(ddev)) + digital_skb_add_crc_f(skb); + + rc = digital_in_send_cmd(ddev, skb, 30, digital_in_recv_sensf_res, + NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_in_recv_iso15693_inv_res(struct nfc_digital_dev *ddev, + void *arg, struct sk_buff *resp) +{ + struct digital_iso15693_inv_res *res; + struct nfc_target *target = NULL; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto out_free_skb; + } + + if (resp->len != sizeof(*res)) { + rc = -EIO; + goto out_free_skb; + } + + res = (struct digital_iso15693_inv_res *)resp->data; + + if (!DIGITAL_ISO15693_RES_IS_VALID(res->flags)) { + PROTOCOL_ERR("ISO15693 - 10.3.1"); + rc = -EINVAL; + goto out_free_skb; + } + + target = kzalloc(sizeof(*target), GFP_KERNEL); + if (!target) { + rc = -ENOMEM; + goto out_free_skb; + } + + target->is_iso15693 = 1; + target->iso15693_dsfid = res->dsfid; + memcpy(target->iso15693_uid, &res->uid, sizeof(target->iso15693_uid)); + + rc = digital_target_found(ddev, target, NFC_PROTO_ISO15693); + + kfree(target); + +out_free_skb: + dev_kfree_skb(resp); + + if (rc) + digital_poll_next_tech(ddev); +} + +int digital_in_send_iso15693_inv_req(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + struct digital_iso15693_inv_req *req; + struct sk_buff *skb; + int rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, + NFC_DIGITAL_RF_TECH_ISO15693); + if (rc) + return rc; + + rc = digital_in_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_ISO15693_INVENTORY); + if (rc) + return rc; + 
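+	/* RF tech and inventory framing are configured; build a mask-less,
+	 * single-slot INVENTORY_REQ next.
+	 */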
+ skb = digital_skb_alloc(ddev, sizeof(*req)); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(*req) - sizeof(req->mask)); /* No mask */ + req = (struct digital_iso15693_inv_req *)skb->data; + + /* Single sub-carrier, high data rate, no AFI, single slot + * Inventory command + */ + req->flags = DIGITAL_ISO15693_REQ_FLAG_DATA_RATE | + DIGITAL_ISO15693_REQ_FLAG_INVENTORY | + DIGITAL_ISO15693_REQ_FLAG_NB_SLOTS; + req->cmd = DIGITAL_CMD_ISO15693_INVENTORY_REQ; + req->mask_len = 0; + + rc = digital_in_send_cmd(ddev, skb, 30, + digital_in_recv_iso15693_inv_res, NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static int digital_tg_send_sel_res(struct nfc_digital_dev *ddev) +{ + struct sk_buff *skb; + int rc; + + skb = digital_skb_alloc(ddev, 1); + if (!skb) + return -ENOMEM; + + skb_put_u8(skb, DIGITAL_SEL_RES_NFC_DEP); + + if (!DIGITAL_DRV_CAPS_TG_CRC(ddev)) + digital_skb_add_crc_a(skb); + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_ANTICOL_COMPLETE); + if (rc) { + kfree_skb(skb); + return rc; + } + + rc = digital_tg_send_cmd(ddev, skb, 300, digital_tg_recv_atr_req, + NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_tg_recv_sel_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (!DIGITAL_DRV_CAPS_TG_CRC(ddev)) { + rc = digital_skb_check_crc_a(resp); + if (rc) { + PROTOCOL_ERR("4.4.1.3"); + goto exit; + } + } + + /* Silently ignore SEL_REQ content and send a SEL_RES for NFC-DEP */ + + rc = digital_tg_send_sel_res(ddev); + +exit: + if (rc) + digital_poll_next_tech(ddev); + + dev_kfree_skb(resp); +} + +static int digital_tg_send_sdd_res(struct nfc_digital_dev *ddev) +{ + struct sk_buff *skb; + struct digital_sdd_res *sdd_res; + int rc, i; + + skb = digital_skb_alloc(ddev, sizeof(struct digital_sdd_res)); + if (!skb) + return -ENOMEM; + + skb_put(skb, sizeof(struct digital_sdd_res)); + sdd_res = (struct digital_sdd_res *)skb->data; + + sdd_res->nfcid1[0] = 0x08; + get_random_bytes(sdd_res->nfcid1 + 1, 3); + + sdd_res->bcc = 0; + for (i = 0; i < 4; i++) + sdd_res->bcc ^= sdd_res->nfcid1[i]; + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_STANDARD_WITH_CRC_A); + if (rc) { + kfree_skb(skb); + return rc; + } + + rc = digital_tg_send_cmd(ddev, skb, 300, digital_tg_recv_sel_req, + NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +static void digital_tg_recv_sdd_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + u8 *sdd_req; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + sdd_req = resp->data; + + if (resp->len < 2 || sdd_req[0] != DIGITAL_CMD_SEL_REQ_CL1 || + sdd_req[1] != DIGITAL_SDD_REQ_SEL_PAR) { + rc = -EINVAL; + goto exit; + } + + rc = digital_tg_send_sdd_res(ddev); + +exit: + if (rc) + digital_poll_next_tech(ddev); + + dev_kfree_skb(resp); +} + +static int digital_tg_send_sens_res(struct nfc_digital_dev *ddev) +{ + struct sk_buff *skb; + u8 *sens_res; + int rc; + + skb = digital_skb_alloc(ddev, 2); + if (!skb) + return -ENOMEM; + + sens_res = skb_put(skb, 2); + + sens_res[0] = (DIGITAL_SENS_RES_NFC_DEP >> 8) & 0xFF; + sens_res[1] = DIGITAL_SENS_RES_NFC_DEP & 0xFF; + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_STANDARD); + if (rc) { + kfree_skb(skb); + return rc; + } + + rc = digital_tg_send_cmd(ddev, skb, 300, digital_tg_recv_sdd_req, 
+ NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +void digital_tg_recv_sens_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + u8 sens_req; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + sens_req = resp->data[0]; + + if (!resp->len || (sens_req != DIGITAL_CMD_SENS_REQ && + sens_req != DIGITAL_CMD_ALL_REQ)) { + rc = -EINVAL; + goto exit; + } + + rc = digital_tg_send_sens_res(ddev); + +exit: + if (rc) + digital_poll_next_tech(ddev); + + dev_kfree_skb(resp); +} + +static void digital_tg_recv_atr_or_sensf_req(struct nfc_digital_dev *ddev, + void *arg, struct sk_buff *resp) +{ + if (!IS_ERR(resp) && (resp->len >= 2) && + (resp->data[1] == DIGITAL_CMD_SENSF_REQ)) + digital_tg_recv_sensf_req(ddev, arg, resp); + else + digital_tg_recv_atr_req(ddev, arg, resp); + + return; +} + +static int digital_tg_send_sensf_res(struct nfc_digital_dev *ddev, + struct digital_sensf_req *sensf_req) +{ + struct sk_buff *skb; + u8 size; + int rc; + struct digital_sensf_res *sensf_res; + + size = sizeof(struct digital_sensf_res); + + if (sensf_req->rc == DIGITAL_SENSF_REQ_RC_NONE) + size -= sizeof(sensf_res->rd); + + skb = digital_skb_alloc(ddev, size); + if (!skb) + return -ENOMEM; + + skb_put(skb, size); + + sensf_res = (struct digital_sensf_res *)skb->data; + + memset(sensf_res, 0, size); + + sensf_res->cmd = DIGITAL_CMD_SENSF_RES; + sensf_res->nfcid2[0] = DIGITAL_SENSF_NFCID2_NFC_DEP_B1; + sensf_res->nfcid2[1] = DIGITAL_SENSF_NFCID2_NFC_DEP_B2; + get_random_bytes(&sensf_res->nfcid2[2], 6); + + switch (sensf_req->rc) { + case DIGITAL_SENSF_REQ_RC_SC: + sensf_res->rd[0] = sensf_req->sc1; + sensf_res->rd[1] = sensf_req->sc2; + break; + case DIGITAL_SENSF_REQ_RC_AP: + sensf_res->rd[0] = DIGITAL_SENSF_RES_RD_AP_B1; + sensf_res->rd[1] = DIGITAL_SENSF_RES_RD_AP_B2; + break; + } + + *(u8 *)skb_push(skb, sizeof(u8)) = size + 1; + + if (!DIGITAL_DRV_CAPS_TG_CRC(ddev)) + digital_skb_add_crc_f(skb); + + rc = digital_tg_send_cmd(ddev, skb, 300, + digital_tg_recv_atr_or_sensf_req, NULL); + if (rc) + kfree_skb(skb); + + return rc; +} + +void digital_tg_recv_sensf_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + struct digital_sensf_req *sensf_req; + int rc; + + if (IS_ERR(resp)) { + rc = PTR_ERR(resp); + resp = NULL; + goto exit; + } + + if (!DIGITAL_DRV_CAPS_TG_CRC(ddev)) { + rc = digital_skb_check_crc_f(resp); + if (rc) { + PROTOCOL_ERR("6.4.1.8"); + goto exit; + } + } + + if (resp->len != sizeof(struct digital_sensf_req) + 1) { + rc = -EINVAL; + goto exit; + } + + skb_pull(resp, 1); + sensf_req = (struct digital_sensf_req *)resp->data; + + if (sensf_req->cmd != DIGITAL_CMD_SENSF_REQ) { + rc = -EINVAL; + goto exit; + } + + rc = digital_tg_send_sensf_res(ddev, sensf_req); + +exit: + if (rc) + digital_poll_next_tech(ddev); + + dev_kfree_skb(resp); +} + +static int digital_tg_config_nfca(struct nfc_digital_dev *ddev) +{ + int rc; + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, + NFC_DIGITAL_RF_TECH_106A); + if (rc) + return rc; + + return digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCA_NFC_DEP); +} + +int digital_tg_listen_nfca(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + int rc; + + rc = digital_tg_config_nfca(ddev); + if (rc) + return rc; + + return digital_tg_listen(ddev, 300, digital_tg_recv_sens_req, NULL); +} + +static int digital_tg_config_nfcf(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + int rc; + + rc = digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_RF_TECH, 
rf_tech); + if (rc) + return rc; + + return digital_tg_configure_hw(ddev, NFC_DIGITAL_CONFIG_FRAMING, + NFC_DIGITAL_FRAMING_NFCF_NFC_DEP); +} + +int digital_tg_listen_nfcf(struct nfc_digital_dev *ddev, u8 rf_tech) +{ + int rc; + + rc = digital_tg_config_nfcf(ddev, rf_tech); + if (rc) + return rc; + + return digital_tg_listen(ddev, 300, digital_tg_recv_sensf_req, NULL); +} + +void digital_tg_recv_md_req(struct nfc_digital_dev *ddev, void *arg, + struct sk_buff *resp) +{ + u8 rf_tech; + int rc; + + if (IS_ERR(resp)) { + resp = NULL; + goto exit_free_skb; + } + + rc = ddev->ops->tg_get_rf_tech(ddev, &rf_tech); + if (rc) + goto exit_free_skb; + + switch (rf_tech) { + case NFC_DIGITAL_RF_TECH_106A: + rc = digital_tg_config_nfca(ddev); + if (rc) + goto exit_free_skb; + digital_tg_recv_sens_req(ddev, arg, resp); + break; + case NFC_DIGITAL_RF_TECH_212F: + case NFC_DIGITAL_RF_TECH_424F: + rc = digital_tg_config_nfcf(ddev, rf_tech); + if (rc) + goto exit_free_skb; + digital_tg_recv_sensf_req(ddev, arg, resp); + break; + default: + goto exit_free_skb; + } + + return; + +exit_free_skb: + digital_poll_next_tech(ddev); + dev_kfree_skb(resp); +} diff --git a/net/nfc/hci/Kconfig b/net/nfc/hci/Kconfig new file mode 100644 index 000000000..9500b8a27 --- /dev/null +++ b/net/nfc/hci/Kconfig @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: GPL-2.0-only +config NFC_HCI + depends on NFC + tristate "NFC HCI implementation" + default n + help + Say Y here if you want to build support for a kernel NFC HCI + implementation. This is mostly needed for devices that only process + HCI frames, like for example the NXP pn544. + +config NFC_SHDLC + depends on NFC_HCI + select CRC_CCITT + bool "SHDLC link layer for HCI based NFC drivers" + default n + help + Say yes if you use an NFC HCI driver that requires SHDLC link layer. + If unsure, say N here. diff --git a/net/nfc/hci/Makefile b/net/nfc/hci/Makefile new file mode 100644 index 000000000..5a0aaae6f --- /dev/null +++ b/net/nfc/hci/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux NFC HCI layer. +# + +obj-$(CONFIG_NFC_HCI) += hci.o + +hci-y := core.o hcp.o command.o llc.o llc_nop.o +hci-$(CONFIG_NFC_SHDLC) += llc_shdlc.o diff --git a/net/nfc/hci/command.c b/net/nfc/hci/command.c new file mode 100644 index 000000000..af6bacb3b --- /dev/null +++ b/net/nfc/hci/command.c @@ -0,0 +1,344 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2012 Intel Corporation. All rights reserved. + */ + +#define pr_fmt(fmt) "hci: %s: " fmt, __func__ + +#include +#include +#include +#include + +#include + +#include "hci.h" + +#define MAX_FWI 4949 + +static int nfc_hci_execute_cmd_async(struct nfc_hci_dev *hdev, u8 pipe, u8 cmd, + const u8 *param, size_t param_len, + data_exchange_cb_t cb, void *cb_context) +{ + pr_debug("exec cmd async through pipe=%d, cmd=%d, plen=%zd\n", pipe, + cmd, param_len); + + /* TODO: Define hci cmd execution delay. Should it be the same + * for all commands? + */ + return nfc_hci_hcp_message_tx(hdev, pipe, NFC_HCI_HCP_COMMAND, cmd, + param, param_len, cb, cb_context, MAX_FWI); +} + +/* + * HCI command execution completion callback. 
+ * err will be a standard linux error (may be converted from HCI response) + * skb contains the response data and must be disposed, or may be NULL if + * an error occurred + */ +static void nfc_hci_execute_cb(void *context, struct sk_buff *skb, int err) +{ + struct hcp_exec_waiter *hcp_ew = (struct hcp_exec_waiter *)context; + + pr_debug("HCI Cmd completed with result=%d\n", err); + + hcp_ew->exec_result = err; + if (hcp_ew->exec_result == 0) + hcp_ew->result_skb = skb; + else + kfree_skb(skb); + hcp_ew->exec_complete = true; + + wake_up(hcp_ew->wq); +} + +static int nfc_hci_execute_cmd(struct nfc_hci_dev *hdev, u8 pipe, u8 cmd, + const u8 *param, size_t param_len, + struct sk_buff **skb) +{ + DECLARE_WAIT_QUEUE_HEAD_ONSTACK(ew_wq); + struct hcp_exec_waiter hcp_ew; + hcp_ew.wq = &ew_wq; + hcp_ew.exec_complete = false; + hcp_ew.result_skb = NULL; + + pr_debug("exec cmd sync through pipe=%d, cmd=%d, plen=%zd\n", pipe, + cmd, param_len); + + /* TODO: Define hci cmd execution delay. Should it be the same + * for all commands? + */ + hcp_ew.exec_result = nfc_hci_hcp_message_tx(hdev, pipe, + NFC_HCI_HCP_COMMAND, cmd, + param, param_len, + nfc_hci_execute_cb, &hcp_ew, + MAX_FWI); + if (hcp_ew.exec_result < 0) + return hcp_ew.exec_result; + + wait_event(ew_wq, hcp_ew.exec_complete == true); + + if (hcp_ew.exec_result == 0) { + if (skb) + *skb = hcp_ew.result_skb; + else + kfree_skb(hcp_ew.result_skb); + } + + return hcp_ew.exec_result; +} + +int nfc_hci_send_event(struct nfc_hci_dev *hdev, u8 gate, u8 event, + const u8 *param, size_t param_len) +{ + u8 pipe; + + pr_debug("%d to gate %d\n", event, gate); + + pipe = hdev->gate2pipe[gate]; + if (pipe == NFC_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + return nfc_hci_hcp_message_tx(hdev, pipe, NFC_HCI_HCP_EVENT, event, + param, param_len, NULL, NULL, 0); +} +EXPORT_SYMBOL(nfc_hci_send_event); + +/* + * Execute an hci command sent to gate. + * skb will contain response data if success. skb can be NULL if you are not + * interested by the response. + */ +int nfc_hci_send_cmd(struct nfc_hci_dev *hdev, u8 gate, u8 cmd, + const u8 *param, size_t param_len, struct sk_buff **skb) +{ + u8 pipe; + + pipe = hdev->gate2pipe[gate]; + if (pipe == NFC_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + return nfc_hci_execute_cmd(hdev, pipe, cmd, param, param_len, skb); +} +EXPORT_SYMBOL(nfc_hci_send_cmd); + +int nfc_hci_send_cmd_async(struct nfc_hci_dev *hdev, u8 gate, u8 cmd, + const u8 *param, size_t param_len, + data_exchange_cb_t cb, void *cb_context) +{ + u8 pipe; + + pipe = hdev->gate2pipe[gate]; + if (pipe == NFC_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + return nfc_hci_execute_cmd_async(hdev, pipe, cmd, param, param_len, + cb, cb_context); +} +EXPORT_SYMBOL(nfc_hci_send_cmd_async); + +int nfc_hci_set_param(struct nfc_hci_dev *hdev, u8 gate, u8 idx, + const u8 *param, size_t param_len) +{ + int r; + u8 *tmp; + + /* TODO ELa: reg idx must be inserted before param, but we don't want + * to ask the caller to do it to keep a simpler API. + * For now, just create a new temporary param buffer. This is far from + * optimal though, and the plan is to modify APIs to pass idx down to + * nfc_hci_hcp_message_tx where the frame is actually built, thereby + * eliminating the need for the temp allocation-copy here. 
+ */ + + pr_debug("idx=%d to gate %d\n", idx, gate); + + tmp = kmalloc(1 + param_len, GFP_KERNEL); + if (tmp == NULL) + return -ENOMEM; + + *tmp = idx; + memcpy(tmp + 1, param, param_len); + + r = nfc_hci_send_cmd(hdev, gate, NFC_HCI_ANY_SET_PARAMETER, + tmp, param_len + 1, NULL); + + kfree(tmp); + + return r; +} +EXPORT_SYMBOL(nfc_hci_set_param); + +int nfc_hci_get_param(struct nfc_hci_dev *hdev, u8 gate, u8 idx, + struct sk_buff **skb) +{ + pr_debug("gate=%d regidx=%d\n", gate, idx); + + return nfc_hci_send_cmd(hdev, gate, NFC_HCI_ANY_GET_PARAMETER, + &idx, 1, skb); +} +EXPORT_SYMBOL(nfc_hci_get_param); + +static int nfc_hci_open_pipe(struct nfc_hci_dev *hdev, u8 pipe) +{ + struct sk_buff *skb; + int r; + + pr_debug("pipe=%d\n", pipe); + + r = nfc_hci_execute_cmd(hdev, pipe, NFC_HCI_ANY_OPEN_PIPE, + NULL, 0, &skb); + if (r == 0) { + /* dest host other than host controller will send + * number of pipes already open on this gate before + * execution. The number can be found in skb->data[0] + */ + kfree_skb(skb); + } + + return r; +} + +static int nfc_hci_close_pipe(struct nfc_hci_dev *hdev, u8 pipe) +{ + return nfc_hci_execute_cmd(hdev, pipe, NFC_HCI_ANY_CLOSE_PIPE, + NULL, 0, NULL); +} + +static u8 nfc_hci_create_pipe(struct nfc_hci_dev *hdev, u8 dest_host, + u8 dest_gate, int *result) +{ + struct sk_buff *skb; + struct hci_create_pipe_params params; + struct hci_create_pipe_resp *resp; + u8 pipe; + + pr_debug("gate=%d\n", dest_gate); + + params.src_gate = NFC_HCI_ADMIN_GATE; + params.dest_host = dest_host; + params.dest_gate = dest_gate; + + *result = nfc_hci_execute_cmd(hdev, NFC_HCI_ADMIN_PIPE, + NFC_HCI_ADM_CREATE_PIPE, + (u8 *) ¶ms, sizeof(params), &skb); + if (*result < 0) + return NFC_HCI_INVALID_PIPE; + + resp = (struct hci_create_pipe_resp *)skb->data; + pipe = resp->pipe; + kfree_skb(skb); + + pr_debug("pipe created=%d\n", pipe); + + return pipe; +} + +static int nfc_hci_delete_pipe(struct nfc_hci_dev *hdev, u8 pipe) +{ + return nfc_hci_execute_cmd(hdev, NFC_HCI_ADMIN_PIPE, + NFC_HCI_ADM_DELETE_PIPE, &pipe, 1, NULL); +} + +static int nfc_hci_clear_all_pipes(struct nfc_hci_dev *hdev) +{ + u8 param[2]; + size_t param_len = 2; + + /* TODO: Find out what the identity reference data is + * and fill param with it. 
HCI spec 6.1.3.5 */ + + if (test_bit(NFC_HCI_QUIRK_SHORT_CLEAR, &hdev->quirks)) + param_len = 0; + + return nfc_hci_execute_cmd(hdev, NFC_HCI_ADMIN_PIPE, + NFC_HCI_ADM_CLEAR_ALL_PIPE, param, param_len, + NULL); +} + +int nfc_hci_disconnect_gate(struct nfc_hci_dev *hdev, u8 gate) +{ + int r; + u8 pipe = hdev->gate2pipe[gate]; + + if (pipe == NFC_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + r = nfc_hci_close_pipe(hdev, pipe); + if (r < 0) + return r; + + if (pipe != NFC_HCI_LINK_MGMT_PIPE && pipe != NFC_HCI_ADMIN_PIPE) { + r = nfc_hci_delete_pipe(hdev, pipe); + if (r < 0) + return r; + } + + hdev->gate2pipe[gate] = NFC_HCI_INVALID_PIPE; + + return 0; +} +EXPORT_SYMBOL(nfc_hci_disconnect_gate); + +int nfc_hci_disconnect_all_gates(struct nfc_hci_dev *hdev) +{ + int r; + + r = nfc_hci_clear_all_pipes(hdev); + if (r < 0) + return r; + + nfc_hci_reset_pipes(hdev); + + return 0; +} +EXPORT_SYMBOL(nfc_hci_disconnect_all_gates); + +int nfc_hci_connect_gate(struct nfc_hci_dev *hdev, u8 dest_host, u8 dest_gate, + u8 pipe) +{ + bool pipe_created = false; + int r; + + if (pipe == NFC_HCI_DO_NOT_CREATE_PIPE) + return 0; + + if (hdev->gate2pipe[dest_gate] != NFC_HCI_INVALID_PIPE) + return -EADDRINUSE; + + if (pipe != NFC_HCI_INVALID_PIPE) + goto open_pipe; + + switch (dest_gate) { + case NFC_HCI_LINK_MGMT_GATE: + pipe = NFC_HCI_LINK_MGMT_PIPE; + break; + case NFC_HCI_ADMIN_GATE: + pipe = NFC_HCI_ADMIN_PIPE; + break; + default: + pipe = nfc_hci_create_pipe(hdev, dest_host, dest_gate, &r); + if (pipe == NFC_HCI_INVALID_PIPE) + return r; + pipe_created = true; + break; + } + +open_pipe: + r = nfc_hci_open_pipe(hdev, pipe); + if (r < 0) { + if (pipe_created) + if (nfc_hci_delete_pipe(hdev, pipe) < 0) { + /* TODO: Cannot clean by deleting pipe... + * -> inconsistent state */ + } + return r; + } + + hdev->pipes[pipe].gate = dest_gate; + hdev->pipes[pipe].dest_host = dest_host; + hdev->gate2pipe[dest_gate] = pipe; + + return 0; +} +EXPORT_SYMBOL(nfc_hci_connect_gate); diff --git a/net/nfc/hci/core.c b/net/nfc/hci/core.c new file mode 100644 index 000000000..ceb87db57 --- /dev/null +++ b/net/nfc/hci/core.c @@ -0,0 +1,1105 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2012 Intel Corporation. All rights reserved. 
+ */ + +#define pr_fmt(fmt) "hci: %s: " fmt, __func__ + +#include +#include +#include +#include + +#include +#include +#include + +#include "hci.h" + +/* Largest headroom needed for outgoing HCI commands */ +#define HCI_CMDS_HEADROOM 1 + +int nfc_hci_result_to_errno(u8 result) +{ + switch (result) { + case NFC_HCI_ANY_OK: + return 0; + case NFC_HCI_ANY_E_REG_PAR_UNKNOWN: + return -EOPNOTSUPP; + case NFC_HCI_ANY_E_TIMEOUT: + return -ETIME; + default: + return -1; + } +} +EXPORT_SYMBOL(nfc_hci_result_to_errno); + +void nfc_hci_reset_pipes(struct nfc_hci_dev *hdev) +{ + int i = 0; + + for (i = 0; i < NFC_HCI_MAX_PIPES; i++) { + hdev->pipes[i].gate = NFC_HCI_INVALID_GATE; + hdev->pipes[i].dest_host = NFC_HCI_INVALID_HOST; + } + memset(hdev->gate2pipe, NFC_HCI_INVALID_PIPE, sizeof(hdev->gate2pipe)); +} +EXPORT_SYMBOL(nfc_hci_reset_pipes); + +void nfc_hci_reset_pipes_per_host(struct nfc_hci_dev *hdev, u8 host) +{ + int i = 0; + + for (i = 0; i < NFC_HCI_MAX_PIPES; i++) { + if (hdev->pipes[i].dest_host != host) + continue; + + hdev->pipes[i].gate = NFC_HCI_INVALID_GATE; + hdev->pipes[i].dest_host = NFC_HCI_INVALID_HOST; + } +} +EXPORT_SYMBOL(nfc_hci_reset_pipes_per_host); + +static void nfc_hci_msg_tx_work(struct work_struct *work) +{ + struct nfc_hci_dev *hdev = container_of(work, struct nfc_hci_dev, + msg_tx_work); + struct hci_msg *msg; + struct sk_buff *skb; + int r = 0; + + mutex_lock(&hdev->msg_tx_mutex); + if (hdev->shutting_down) + goto exit; + + if (hdev->cmd_pending_msg) { + if (timer_pending(&hdev->cmd_timer) == 0) { + if (hdev->cmd_pending_msg->cb) + hdev->cmd_pending_msg->cb(hdev-> + cmd_pending_msg-> + cb_context, + NULL, + -ETIME); + kfree(hdev->cmd_pending_msg); + hdev->cmd_pending_msg = NULL; + } else { + goto exit; + } + } + +next_msg: + if (list_empty(&hdev->msg_tx_queue)) + goto exit; + + msg = list_first_entry(&hdev->msg_tx_queue, struct hci_msg, msg_l); + list_del(&msg->msg_l); + + pr_debug("msg_tx_queue has a cmd to send\n"); + while ((skb = skb_dequeue(&msg->msg_frags)) != NULL) { + r = nfc_llc_xmit_from_hci(hdev->llc, skb); + if (r < 0) { + kfree_skb(skb); + skb_queue_purge(&msg->msg_frags); + if (msg->cb) + msg->cb(msg->cb_context, NULL, r); + kfree(msg); + break; + } + } + + if (r) + goto next_msg; + + if (msg->wait_response == false) { + kfree(msg); + goto next_msg; + } + + hdev->cmd_pending_msg = msg; + mod_timer(&hdev->cmd_timer, jiffies + + msecs_to_jiffies(hdev->cmd_pending_msg->completion_delay)); + +exit: + mutex_unlock(&hdev->msg_tx_mutex); +} + +static void nfc_hci_msg_rx_work(struct work_struct *work) +{ + struct nfc_hci_dev *hdev = container_of(work, struct nfc_hci_dev, + msg_rx_work); + struct sk_buff *skb; + const struct hcp_message *message; + u8 pipe; + u8 type; + u8 instruction; + + while ((skb = skb_dequeue(&hdev->msg_rx_queue)) != NULL) { + pipe = skb->data[0]; + skb_pull(skb, NFC_HCI_HCP_PACKET_HEADER_LEN); + message = (struct hcp_message *)skb->data; + type = HCP_MSG_GET_TYPE(message->header); + instruction = HCP_MSG_GET_CMD(message->header); + skb_pull(skb, NFC_HCI_HCP_MESSAGE_HEADER_LEN); + + nfc_hci_hcp_message_rx(hdev, pipe, type, instruction, skb); + } +} + +static void __nfc_hci_cmd_completion(struct nfc_hci_dev *hdev, int err, + struct sk_buff *skb) +{ + del_timer_sync(&hdev->cmd_timer); + + if (hdev->cmd_pending_msg->cb) + hdev->cmd_pending_msg->cb(hdev->cmd_pending_msg->cb_context, + skb, err); + else + kfree_skb(skb); + + kfree(hdev->cmd_pending_msg); + hdev->cmd_pending_msg = NULL; + + schedule_work(&hdev->msg_tx_work); +} + +void 
nfc_hci_resp_received(struct nfc_hci_dev *hdev, u8 result, + struct sk_buff *skb) +{ + mutex_lock(&hdev->msg_tx_mutex); + + if (hdev->cmd_pending_msg == NULL) { + kfree_skb(skb); + goto exit; + } + + __nfc_hci_cmd_completion(hdev, nfc_hci_result_to_errno(result), skb); + +exit: + mutex_unlock(&hdev->msg_tx_mutex); +} + +void nfc_hci_cmd_received(struct nfc_hci_dev *hdev, u8 pipe, u8 cmd, + struct sk_buff *skb) +{ + u8 status = NFC_HCI_ANY_OK; + const struct hci_create_pipe_resp *create_info; + const struct hci_delete_pipe_noti *delete_info; + const struct hci_all_pipe_cleared_noti *cleared_info; + u8 gate; + + pr_debug("from pipe %x cmd %x\n", pipe, cmd); + + if (pipe >= NFC_HCI_MAX_PIPES) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + + gate = hdev->pipes[pipe].gate; + + switch (cmd) { + case NFC_HCI_ADM_NOTIFY_PIPE_CREATED: + if (skb->len != 5) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + create_info = (struct hci_create_pipe_resp *)skb->data; + + if (create_info->pipe >= NFC_HCI_MAX_PIPES) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + + /* Save the new created pipe and bind with local gate, + * the description for skb->data[3] is destination gate id + * but since we received this cmd from host controller, we + * are the destination and it is our local gate + */ + hdev->gate2pipe[create_info->dest_gate] = create_info->pipe; + hdev->pipes[create_info->pipe].gate = create_info->dest_gate; + hdev->pipes[create_info->pipe].dest_host = + create_info->src_host; + break; + case NFC_HCI_ANY_OPEN_PIPE: + if (gate == NFC_HCI_INVALID_GATE) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + break; + case NFC_HCI_ADM_NOTIFY_PIPE_DELETED: + if (skb->len != 1) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + delete_info = (struct hci_delete_pipe_noti *)skb->data; + + if (delete_info->pipe >= NFC_HCI_MAX_PIPES) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + + hdev->pipes[delete_info->pipe].gate = NFC_HCI_INVALID_GATE; + hdev->pipes[delete_info->pipe].dest_host = NFC_HCI_INVALID_HOST; + break; + case NFC_HCI_ADM_NOTIFY_ALL_PIPE_CLEARED: + if (skb->len != 1) { + status = NFC_HCI_ANY_E_NOK; + goto exit; + } + cleared_info = (struct hci_all_pipe_cleared_noti *)skb->data; + + nfc_hci_reset_pipes_per_host(hdev, cleared_info->host); + break; + default: + pr_info("Discarded unknown cmd %x to gate %x\n", cmd, gate); + break; + } + + if (hdev->ops->cmd_received) + hdev->ops->cmd_received(hdev, pipe, cmd, skb); + +exit: + nfc_hci_hcp_message_tx(hdev, pipe, NFC_HCI_HCP_RESPONSE, + status, NULL, 0, NULL, NULL, 0); + + kfree_skb(skb); +} + +u32 nfc_hci_sak_to_protocol(u8 sak) +{ + switch (NFC_HCI_TYPE_A_SEL_PROT(sak)) { + case NFC_HCI_TYPE_A_SEL_PROT_MIFARE: + return NFC_PROTO_MIFARE_MASK; + case NFC_HCI_TYPE_A_SEL_PROT_ISO14443: + return NFC_PROTO_ISO14443_MASK; + case NFC_HCI_TYPE_A_SEL_PROT_DEP: + return NFC_PROTO_NFC_DEP_MASK; + case NFC_HCI_TYPE_A_SEL_PROT_ISO14443_DEP: + return NFC_PROTO_ISO14443_MASK | NFC_PROTO_NFC_DEP_MASK; + default: + return 0xffffffff; + } +} +EXPORT_SYMBOL(nfc_hci_sak_to_protocol); + +int nfc_hci_target_discovered(struct nfc_hci_dev *hdev, u8 gate) +{ + struct nfc_target *targets; + struct sk_buff *atqa_skb = NULL; + struct sk_buff *sak_skb = NULL; + struct sk_buff *uid_skb = NULL; + int r; + + pr_debug("from gate %d\n", gate); + + targets = kzalloc(sizeof(struct nfc_target), GFP_KERNEL); + if (targets == NULL) + return -ENOMEM; + + switch (gate) { + case NFC_HCI_RF_READER_A_GATE: + r = nfc_hci_get_param(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_RF_READER_A_ATQA, 
&atqa_skb); + if (r < 0) + goto exit; + + r = nfc_hci_get_param(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_RF_READER_A_SAK, &sak_skb); + if (r < 0) + goto exit; + + if (atqa_skb->len != 2 || sak_skb->len != 1) { + r = -EPROTO; + goto exit; + } + + targets->supported_protocols = + nfc_hci_sak_to_protocol(sak_skb->data[0]); + if (targets->supported_protocols == 0xffffffff) { + r = -EPROTO; + goto exit; + } + + targets->sens_res = be16_to_cpu(*(__be16 *)atqa_skb->data); + targets->sel_res = sak_skb->data[0]; + + r = nfc_hci_get_param(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_RF_READER_A_UID, &uid_skb); + if (r < 0) + goto exit; + + if (uid_skb->len == 0 || uid_skb->len > NFC_NFCID1_MAXSIZE) { + r = -EPROTO; + goto exit; + } + + memcpy(targets->nfcid1, uid_skb->data, uid_skb->len); + targets->nfcid1_len = uid_skb->len; + + if (hdev->ops->complete_target_discovered) { + r = hdev->ops->complete_target_discovered(hdev, gate, + targets); + if (r < 0) + goto exit; + } + break; + case NFC_HCI_RF_READER_B_GATE: + targets->supported_protocols = NFC_PROTO_ISO14443_B_MASK; + break; + default: + if (hdev->ops->target_from_gate) + r = hdev->ops->target_from_gate(hdev, gate, targets); + else + r = -EPROTO; + if (r < 0) + goto exit; + + if (hdev->ops->complete_target_discovered) { + r = hdev->ops->complete_target_discovered(hdev, gate, + targets); + if (r < 0) + goto exit; + } + break; + } + + /* if driver set the new gate, we will skip the old one */ + if (targets->hci_reader_gate == 0x00) + targets->hci_reader_gate = gate; + + r = nfc_targets_found(hdev->ndev, targets, 1); + +exit: + kfree(targets); + kfree_skb(atqa_skb); + kfree_skb(sak_skb); + kfree_skb(uid_skb); + + return r; +} +EXPORT_SYMBOL(nfc_hci_target_discovered); + +void nfc_hci_event_received(struct nfc_hci_dev *hdev, u8 pipe, u8 event, + struct sk_buff *skb) +{ + int r = 0; + u8 gate; + + if (pipe >= NFC_HCI_MAX_PIPES) { + pr_err("Discarded event %x to invalid pipe %x\n", event, pipe); + goto exit; + } + + gate = hdev->pipes[pipe].gate; + if (gate == NFC_HCI_INVALID_GATE) { + pr_err("Discarded event %x to unopened pipe %x\n", event, pipe); + goto exit; + } + + if (hdev->ops->event_received) { + r = hdev->ops->event_received(hdev, pipe, event, skb); + if (r <= 0) + goto exit_noskb; + } + + switch (event) { + case NFC_HCI_EVT_TARGET_DISCOVERED: + if (skb->len < 1) { /* no status data? */ + r = -EPROTO; + goto exit; + } + + if (skb->data[0] == 3) { + /* TODO: Multiple targets in field, none activated + * poll is supposedly stopped, but there is no + * single target to activate, so nothing to report + * up. + * if we need to restart poll, we must save the + * protocols from the initial poll and reuse here. 
+ */ + } + + if (skb->data[0] != 0) { + r = -EPROTO; + goto exit; + } + + r = nfc_hci_target_discovered(hdev, gate); + break; + default: + pr_info("Discarded unknown event %x to gate %x\n", event, gate); + r = -EINVAL; + break; + } + +exit: + kfree_skb(skb); + +exit_noskb: + if (r) + nfc_hci_driver_failure(hdev, r); +} + +static void nfc_hci_cmd_timeout(struct timer_list *t) +{ + struct nfc_hci_dev *hdev = from_timer(hdev, t, cmd_timer); + + schedule_work(&hdev->msg_tx_work); +} + +static int hci_dev_connect_gates(struct nfc_hci_dev *hdev, u8 gate_count, + const struct nfc_hci_gate *gates) +{ + int r; + while (gate_count--) { + r = nfc_hci_connect_gate(hdev, NFC_HCI_HOST_CONTROLLER_ID, + gates->gate, gates->pipe); + if (r < 0) + return r; + gates++; + } + + return 0; +} + +static int hci_dev_session_init(struct nfc_hci_dev *hdev) +{ + struct sk_buff *skb = NULL; + int r; + + if (hdev->init_data.gates[0].gate != NFC_HCI_ADMIN_GATE) + return -EPROTO; + + r = nfc_hci_connect_gate(hdev, NFC_HCI_HOST_CONTROLLER_ID, + hdev->init_data.gates[0].gate, + hdev->init_data.gates[0].pipe); + if (r < 0) + goto exit; + + r = nfc_hci_get_param(hdev, NFC_HCI_ADMIN_GATE, + NFC_HCI_ADMIN_SESSION_IDENTITY, &skb); + if (r < 0) + goto disconnect_all; + + if (skb->len && skb->len == strlen(hdev->init_data.session_id) && + (memcmp(hdev->init_data.session_id, skb->data, + skb->len) == 0) && hdev->ops->load_session) { + /* Restore gate<->pipe table from some proprietary location. */ + + r = hdev->ops->load_session(hdev); + + if (r < 0) + goto disconnect_all; + } else { + + r = nfc_hci_disconnect_all_gates(hdev); + if (r < 0) + goto exit; + + r = hci_dev_connect_gates(hdev, hdev->init_data.gate_count, + hdev->init_data.gates); + if (r < 0) + goto disconnect_all; + + r = nfc_hci_set_param(hdev, NFC_HCI_ADMIN_GATE, + NFC_HCI_ADMIN_SESSION_IDENTITY, + hdev->init_data.session_id, + strlen(hdev->init_data.session_id)); + } + if (r == 0) + goto exit; + +disconnect_all: + nfc_hci_disconnect_all_gates(hdev); + +exit: + kfree_skb(skb); + + return r; +} + +static int hci_dev_version(struct nfc_hci_dev *hdev) +{ + int r; + struct sk_buff *skb; + + r = nfc_hci_get_param(hdev, NFC_HCI_ID_MGMT_GATE, + NFC_HCI_ID_MGMT_VERSION_SW, &skb); + if (r == -EOPNOTSUPP) { + pr_info("Software/Hardware info not available\n"); + return 0; + } + if (r < 0) + return r; + + if (skb->len != 3) { + kfree_skb(skb); + return -EINVAL; + } + + hdev->sw_romlib = (skb->data[0] & 0xf0) >> 4; + hdev->sw_patch = skb->data[0] & 0x0f; + hdev->sw_flashlib_major = skb->data[1]; + hdev->sw_flashlib_minor = skb->data[2]; + + kfree_skb(skb); + + r = nfc_hci_get_param(hdev, NFC_HCI_ID_MGMT_GATE, + NFC_HCI_ID_MGMT_VERSION_HW, &skb); + if (r < 0) + return r; + + if (skb->len != 3) { + kfree_skb(skb); + return -EINVAL; + } + + hdev->hw_derivative = (skb->data[0] & 0xe0) >> 5; + hdev->hw_version = skb->data[0] & 0x1f; + hdev->hw_mpw = (skb->data[1] & 0xc0) >> 6; + hdev->hw_software = skb->data[1] & 0x3f; + hdev->hw_bsid = skb->data[2]; + + kfree_skb(skb); + + pr_info("SOFTWARE INFO:\n"); + pr_info("RomLib : %d\n", hdev->sw_romlib); + pr_info("Patch : %d\n", hdev->sw_patch); + pr_info("FlashLib Major : %d\n", hdev->sw_flashlib_major); + pr_info("FlashLib Minor : %d\n", hdev->sw_flashlib_minor); + pr_info("HARDWARE INFO:\n"); + pr_info("Derivative : %d\n", hdev->hw_derivative); + pr_info("HW Version : %d\n", hdev->hw_version); + pr_info("#MPW : %d\n", hdev->hw_mpw); + pr_info("Software : %d\n", hdev->hw_software); + pr_info("BSID Version : %d\n", hdev->hw_bsid); + + 
return 0; +} + +static int hci_dev_up(struct nfc_dev *nfc_dev) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + int r = 0; + + if (hdev->ops->open) { + r = hdev->ops->open(hdev); + if (r < 0) + return r; + } + + r = nfc_llc_start(hdev->llc); + if (r < 0) + goto exit_close; + + r = hci_dev_session_init(hdev); + if (r < 0) + goto exit_llc; + + r = nfc_hci_send_event(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_EVT_END_OPERATION, NULL, 0); + if (r < 0) + goto exit_llc; + + if (hdev->ops->hci_ready) { + r = hdev->ops->hci_ready(hdev); + if (r < 0) + goto exit_llc; + } + + r = hci_dev_version(hdev); + if (r < 0) + goto exit_llc; + + return 0; + +exit_llc: + nfc_llc_stop(hdev->llc); + +exit_close: + if (hdev->ops->close) + hdev->ops->close(hdev); + + return r; +} + +static int hci_dev_down(struct nfc_dev *nfc_dev) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + nfc_llc_stop(hdev->llc); + + if (hdev->ops->close) + hdev->ops->close(hdev); + + nfc_hci_reset_pipes(hdev); + + return 0; +} + +static int hci_start_poll(struct nfc_dev *nfc_dev, + u32 im_protocols, u32 tm_protocols) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->start_poll) + return hdev->ops->start_poll(hdev, im_protocols, tm_protocols); + else + return nfc_hci_send_event(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_EVT_READER_REQUESTED, + NULL, 0); +} + +static void hci_stop_poll(struct nfc_dev *nfc_dev) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->stop_poll) + hdev->ops->stop_poll(hdev); + else + nfc_hci_send_event(hdev, NFC_HCI_RF_READER_A_GATE, + NFC_HCI_EVT_END_OPERATION, NULL, 0); +} + +static int hci_dep_link_up(struct nfc_dev *nfc_dev, struct nfc_target *target, + __u8 comm_mode, __u8 *gb, size_t gb_len) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (!hdev->ops->dep_link_up) + return 0; + + return hdev->ops->dep_link_up(hdev, target, comm_mode, + gb, gb_len); +} + +static int hci_dep_link_down(struct nfc_dev *nfc_dev) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (!hdev->ops->dep_link_down) + return 0; + + return hdev->ops->dep_link_down(hdev); +} + +static int hci_activate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, u32 protocol) +{ + return 0; +} + +static void hci_deactivate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, + u8 mode) +{ +} + +#define HCI_CB_TYPE_TRANSCEIVE 1 + +static void hci_transceive_cb(void *context, struct sk_buff *skb, int err) +{ + struct nfc_hci_dev *hdev = context; + + switch (hdev->async_cb_type) { + case HCI_CB_TYPE_TRANSCEIVE: + /* + * TODO: Check RF Error indicator to make sure data is valid. + * It seems that HCI cmd can complete without error, but data + * can be invalid if an RF error occurred? Ignore for now. 
+ */ + if (err == 0) + skb_trim(skb, skb->len - 1); /* RF Err ind */ + + hdev->async_cb(hdev->async_cb_context, skb, err); + break; + default: + if (err == 0) + kfree_skb(skb); + break; + } +} + +static int hci_transceive(struct nfc_dev *nfc_dev, struct nfc_target *target, + struct sk_buff *skb, data_exchange_cb_t cb, + void *cb_context) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + int r; + + pr_debug("target_idx=%d\n", target->idx); + + switch (target->hci_reader_gate) { + case NFC_HCI_RF_READER_A_GATE: + case NFC_HCI_RF_READER_B_GATE: + if (hdev->ops->im_transceive) { + r = hdev->ops->im_transceive(hdev, target, skb, cb, + cb_context); + if (r <= 0) /* handled */ + break; + } + + *(u8 *)skb_push(skb, 1) = 0; /* CTR, see spec:10.2.2.1 */ + + hdev->async_cb_type = HCI_CB_TYPE_TRANSCEIVE; + hdev->async_cb = cb; + hdev->async_cb_context = cb_context; + + r = nfc_hci_send_cmd_async(hdev, target->hci_reader_gate, + NFC_HCI_WR_XCHG_DATA, skb->data, + skb->len, hci_transceive_cb, hdev); + break; + default: + if (hdev->ops->im_transceive) { + r = hdev->ops->im_transceive(hdev, target, skb, cb, + cb_context); + if (r == 1) + r = -ENOTSUPP; + } else { + r = -ENOTSUPP; + } + break; + } + + kfree_skb(skb); + + return r; +} + +static int hci_tm_send(struct nfc_dev *nfc_dev, struct sk_buff *skb) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (!hdev->ops->tm_send) { + kfree_skb(skb); + return -ENOTSUPP; + } + + return hdev->ops->tm_send(hdev, skb); +} + +static int hci_check_presence(struct nfc_dev *nfc_dev, + struct nfc_target *target) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (!hdev->ops->check_presence) + return 0; + + return hdev->ops->check_presence(hdev, target); +} + +static int hci_discover_se(struct nfc_dev *nfc_dev) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->discover_se) + return hdev->ops->discover_se(hdev); + + return 0; +} + +static int hci_enable_se(struct nfc_dev *nfc_dev, u32 se_idx) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->enable_se) + return hdev->ops->enable_se(hdev, se_idx); + + return 0; +} + +static int hci_disable_se(struct nfc_dev *nfc_dev, u32 se_idx) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->disable_se) + return hdev->ops->disable_se(hdev, se_idx); + + return 0; +} + +static int hci_se_io(struct nfc_dev *nfc_dev, u32 se_idx, + u8 *apdu, size_t apdu_length, + se_io_cb_t cb, void *cb_context) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (hdev->ops->se_io) + return hdev->ops->se_io(hdev, se_idx, apdu, + apdu_length, cb, cb_context); + + return 0; +} + +static void nfc_hci_failure(struct nfc_hci_dev *hdev, int err) +{ + mutex_lock(&hdev->msg_tx_mutex); + + if (hdev->cmd_pending_msg == NULL) { + nfc_driver_failure(hdev->ndev, err); + goto exit; + } + + __nfc_hci_cmd_completion(hdev, err, NULL); + +exit: + mutex_unlock(&hdev->msg_tx_mutex); +} + +static void nfc_hci_llc_failure(struct nfc_hci_dev *hdev, int err) +{ + nfc_hci_failure(hdev, err); +} + +static void nfc_hci_recv_from_llc(struct nfc_hci_dev *hdev, struct sk_buff *skb) +{ + struct hcp_packet *packet; + u8 type; + u8 instruction; + struct sk_buff *hcp_skb; + u8 pipe; + struct sk_buff *frag_skb; + int msg_len; + + packet = (struct hcp_packet *)skb->data; + if ((packet->header & ~NFC_HCI_FRAGMENT) == 0) { + skb_queue_tail(&hdev->rx_hcp_frags, skb); + return; + } + + /* it's the last fragment. Does it need re-aggregation? 
*/ + if (skb_queue_len(&hdev->rx_hcp_frags)) { + pipe = packet->header & NFC_HCI_FRAGMENT; + skb_queue_tail(&hdev->rx_hcp_frags, skb); + + msg_len = 0; + skb_queue_walk(&hdev->rx_hcp_frags, frag_skb) { + msg_len += (frag_skb->len - + NFC_HCI_HCP_PACKET_HEADER_LEN); + } + + hcp_skb = nfc_alloc_recv_skb(NFC_HCI_HCP_PACKET_HEADER_LEN + + msg_len, GFP_KERNEL); + if (hcp_skb == NULL) { + nfc_hci_failure(hdev, -ENOMEM); + return; + } + + skb_put_u8(hcp_skb, pipe); + + skb_queue_walk(&hdev->rx_hcp_frags, frag_skb) { + msg_len = frag_skb->len - NFC_HCI_HCP_PACKET_HEADER_LEN; + skb_put_data(hcp_skb, + frag_skb->data + NFC_HCI_HCP_PACKET_HEADER_LEN, + msg_len); + } + + skb_queue_purge(&hdev->rx_hcp_frags); + } else { + packet->header &= NFC_HCI_FRAGMENT; + hcp_skb = skb; + } + + /* if this is a response, dispatch immediately to + * unblock waiting cmd context. Otherwise, enqueue to dispatch + * in separate context where handler can also execute command. + */ + packet = (struct hcp_packet *)hcp_skb->data; + type = HCP_MSG_GET_TYPE(packet->message.header); + if (type == NFC_HCI_HCP_RESPONSE) { + pipe = packet->header; + instruction = HCP_MSG_GET_CMD(packet->message.header); + skb_pull(hcp_skb, NFC_HCI_HCP_PACKET_HEADER_LEN + + NFC_HCI_HCP_MESSAGE_HEADER_LEN); + nfc_hci_hcp_message_rx(hdev, pipe, type, instruction, hcp_skb); + } else { + skb_queue_tail(&hdev->msg_rx_queue, hcp_skb); + schedule_work(&hdev->msg_rx_work); + } +} + +static int hci_fw_download(struct nfc_dev *nfc_dev, const char *firmware_name) +{ + struct nfc_hci_dev *hdev = nfc_get_drvdata(nfc_dev); + + if (!hdev->ops->fw_download) + return -ENOTSUPP; + + return hdev->ops->fw_download(hdev, firmware_name); +} + +static const struct nfc_ops hci_nfc_ops = { + .dev_up = hci_dev_up, + .dev_down = hci_dev_down, + .start_poll = hci_start_poll, + .stop_poll = hci_stop_poll, + .dep_link_up = hci_dep_link_up, + .dep_link_down = hci_dep_link_down, + .activate_target = hci_activate_target, + .deactivate_target = hci_deactivate_target, + .im_transceive = hci_transceive, + .tm_send = hci_tm_send, + .check_presence = hci_check_presence, + .fw_download = hci_fw_download, + .discover_se = hci_discover_se, + .enable_se = hci_enable_se, + .disable_se = hci_disable_se, + .se_io = hci_se_io, +}; + +struct nfc_hci_dev *nfc_hci_allocate_device(const struct nfc_hci_ops *ops, + struct nfc_hci_init_data *init_data, + unsigned long quirks, + u32 protocols, + const char *llc_name, + int tx_headroom, + int tx_tailroom, + int max_link_payload) +{ + struct nfc_hci_dev *hdev; + + if (ops->xmit == NULL) + return NULL; + + if (protocols == 0) + return NULL; + + hdev = kzalloc(sizeof(struct nfc_hci_dev), GFP_KERNEL); + if (hdev == NULL) + return NULL; + + hdev->llc = nfc_llc_allocate(llc_name, hdev, ops->xmit, + nfc_hci_recv_from_llc, tx_headroom, + tx_tailroom, nfc_hci_llc_failure); + if (hdev->llc == NULL) { + kfree(hdev); + return NULL; + } + + hdev->ndev = nfc_allocate_device(&hci_nfc_ops, protocols, + tx_headroom + HCI_CMDS_HEADROOM, + tx_tailroom); + if (!hdev->ndev) { + nfc_llc_free(hdev->llc); + kfree(hdev); + return NULL; + } + + hdev->ops = ops; + hdev->max_data_link_payload = max_link_payload; + hdev->init_data = *init_data; + + nfc_set_drvdata(hdev->ndev, hdev); + + nfc_hci_reset_pipes(hdev); + + hdev->quirks = quirks; + + return hdev; +} +EXPORT_SYMBOL(nfc_hci_allocate_device); + +void nfc_hci_free_device(struct nfc_hci_dev *hdev) +{ + nfc_free_device(hdev->ndev); + nfc_llc_free(hdev->llc); + kfree(hdev); +} +EXPORT_SYMBOL(nfc_hci_free_device); + +int 
nfc_hci_register_device(struct nfc_hci_dev *hdev) +{ + mutex_init(&hdev->msg_tx_mutex); + + INIT_LIST_HEAD(&hdev->msg_tx_queue); + + INIT_WORK(&hdev->msg_tx_work, nfc_hci_msg_tx_work); + + timer_setup(&hdev->cmd_timer, nfc_hci_cmd_timeout, 0); + + skb_queue_head_init(&hdev->rx_hcp_frags); + + INIT_WORK(&hdev->msg_rx_work, nfc_hci_msg_rx_work); + + skb_queue_head_init(&hdev->msg_rx_queue); + + return nfc_register_device(hdev->ndev); +} +EXPORT_SYMBOL(nfc_hci_register_device); + +void nfc_hci_unregister_device(struct nfc_hci_dev *hdev) +{ + struct hci_msg *msg, *n; + + mutex_lock(&hdev->msg_tx_mutex); + + if (hdev->cmd_pending_msg) { + if (hdev->cmd_pending_msg->cb) + hdev->cmd_pending_msg->cb( + hdev->cmd_pending_msg->cb_context, + NULL, -ESHUTDOWN); + kfree(hdev->cmd_pending_msg); + hdev->cmd_pending_msg = NULL; + } + + hdev->shutting_down = true; + + mutex_unlock(&hdev->msg_tx_mutex); + + del_timer_sync(&hdev->cmd_timer); + cancel_work_sync(&hdev->msg_tx_work); + + cancel_work_sync(&hdev->msg_rx_work); + + nfc_unregister_device(hdev->ndev); + + skb_queue_purge(&hdev->rx_hcp_frags); + skb_queue_purge(&hdev->msg_rx_queue); + + list_for_each_entry_safe(msg, n, &hdev->msg_tx_queue, msg_l) { + list_del(&msg->msg_l); + skb_queue_purge(&msg->msg_frags); + kfree(msg); + } +} +EXPORT_SYMBOL(nfc_hci_unregister_device); + +void nfc_hci_set_clientdata(struct nfc_hci_dev *hdev, void *clientdata) +{ + hdev->clientdata = clientdata; +} +EXPORT_SYMBOL(nfc_hci_set_clientdata); + +void *nfc_hci_get_clientdata(struct nfc_hci_dev *hdev) +{ + return hdev->clientdata; +} +EXPORT_SYMBOL(nfc_hci_get_clientdata); + +void nfc_hci_driver_failure(struct nfc_hci_dev *hdev, int err) +{ + nfc_hci_failure(hdev, err); +} +EXPORT_SYMBOL(nfc_hci_driver_failure); + +void nfc_hci_recv_frame(struct nfc_hci_dev *hdev, struct sk_buff *skb) +{ + nfc_llc_rcv_from_drv(hdev->llc, skb); +} +EXPORT_SYMBOL(nfc_hci_recv_frame); + +static int __init nfc_hci_init(void) +{ + return nfc_llc_init(); +} + +static void __exit nfc_hci_exit(void) +{ + nfc_llc_exit(); +} + +subsys_initcall(nfc_hci_init); +module_exit(nfc_hci_exit); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("NFC HCI Core"); diff --git a/net/nfc/hci/hci.h b/net/nfc/hci/hci.h new file mode 100644 index 000000000..a59c96fcf --- /dev/null +++ b/net/nfc/hci/hci.h @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (C) 2012 Intel Corporation. All rights reserved. 
+ */ + +#ifndef __LOCAL_HCI_H +#define __LOCAL_HCI_H + +#include + +struct gate_pipe_map { + u8 gate; + u8 pipe; +}; + +struct hcp_message { + u8 header; /* type -cmd,evt,rsp- + instruction */ + u8 data[]; +} __packed; + +struct hcp_packet { + u8 header; /* cbit+pipe */ + struct hcp_message message; +} __packed; + +struct hcp_exec_waiter { + wait_queue_head_t *wq; + bool exec_complete; + int exec_result; + struct sk_buff *result_skb; +}; + +struct hci_msg { + struct list_head msg_l; + struct sk_buff_head msg_frags; + bool wait_response; + data_exchange_cb_t cb; + void *cb_context; + unsigned long completion_delay; +}; + +struct hci_create_pipe_params { + u8 src_gate; + u8 dest_host; + u8 dest_gate; +} __packed; + +struct hci_create_pipe_resp { + u8 src_host; + u8 src_gate; + u8 dest_host; + u8 dest_gate; + u8 pipe; +} __packed; + +struct hci_delete_pipe_noti { + u8 pipe; +} __packed; + +struct hci_all_pipe_cleared_noti { + u8 host; +} __packed; + +#define NFC_HCI_FRAGMENT 0x7f + +#define HCP_HEADER(type, instr) ((((type) & 0x03) << 6) | ((instr) & 0x3f)) +#define HCP_MSG_GET_TYPE(header) ((header & 0xc0) >> 6) +#define HCP_MSG_GET_CMD(header) (header & 0x3f) + +int nfc_hci_hcp_message_tx(struct nfc_hci_dev *hdev, u8 pipe, + u8 type, u8 instruction, + const u8 *payload, size_t payload_len, + data_exchange_cb_t cb, void *cb_context, + unsigned long completion_delay); + +void nfc_hci_hcp_message_rx(struct nfc_hci_dev *hdev, u8 pipe, u8 type, + u8 instruction, struct sk_buff *skb); + +/* HCP headers */ +#define NFC_HCI_HCP_PACKET_HEADER_LEN 1 +#define NFC_HCI_HCP_MESSAGE_HEADER_LEN 1 +#define NFC_HCI_HCP_HEADER_LEN 2 + +/* HCP types */ +#define NFC_HCI_HCP_COMMAND 0x00 +#define NFC_HCI_HCP_EVENT 0x01 +#define NFC_HCI_HCP_RESPONSE 0x02 + +/* Generic commands */ +#define NFC_HCI_ANY_SET_PARAMETER 0x01 +#define NFC_HCI_ANY_GET_PARAMETER 0x02 +#define NFC_HCI_ANY_OPEN_PIPE 0x03 +#define NFC_HCI_ANY_CLOSE_PIPE 0x04 + +/* Reader RF commands */ +#define NFC_HCI_WR_XCHG_DATA 0x10 + +/* Admin commands */ +#define NFC_HCI_ADM_CREATE_PIPE 0x10 +#define NFC_HCI_ADM_DELETE_PIPE 0x11 +#define NFC_HCI_ADM_NOTIFY_PIPE_CREATED 0x12 +#define NFC_HCI_ADM_NOTIFY_PIPE_DELETED 0x13 +#define NFC_HCI_ADM_CLEAR_ALL_PIPE 0x14 +#define NFC_HCI_ADM_NOTIFY_ALL_PIPE_CLEARED 0x15 + +/* Generic responses */ +#define NFC_HCI_ANY_OK 0x00 +#define NFC_HCI_ANY_E_NOT_CONNECTED 0x01 +#define NFC_HCI_ANY_E_CMD_PAR_UNKNOWN 0x02 +#define NFC_HCI_ANY_E_NOK 0x03 +#define NFC_HCI_ANY_E_PIPES_FULL 0x04 +#define NFC_HCI_ANY_E_REG_PAR_UNKNOWN 0x05 +#define NFC_HCI_ANY_E_PIPE_NOT_OPENED 0x06 +#define NFC_HCI_ANY_E_CMD_NOT_SUPPORTED 0x07 +#define NFC_HCI_ANY_E_INHIBITED 0x08 +#define NFC_HCI_ANY_E_TIMEOUT 0x09 +#define NFC_HCI_ANY_E_REG_ACCESS_DENIED 0x0a +#define NFC_HCI_ANY_E_PIPE_ACCESS_DENIED 0x0b + +#endif /* __LOCAL_HCI_H */ diff --git a/net/nfc/hci/hcp.c b/net/nfc/hci/hcp.c new file mode 100644 index 000000000..4902f5064 --- /dev/null +++ b/net/nfc/hci/hcp.c @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2012 Intel Corporation. All rights reserved. + */ + +#define pr_fmt(fmt) "hci: %s: " fmt, __func__ + +#include +#include +#include + +#include + +#include "hci.h" + +/* + * Payload is the HCP message data only. Instruction will be prepended. + * Guarantees that cb will be called upon completion or timeout delay + * counted from the moment the cmd is sent to the transport. 
+ */ +int nfc_hci_hcp_message_tx(struct nfc_hci_dev *hdev, u8 pipe, + u8 type, u8 instruction, + const u8 *payload, size_t payload_len, + data_exchange_cb_t cb, void *cb_context, + unsigned long completion_delay) +{ + struct nfc_dev *ndev = hdev->ndev; + struct hci_msg *cmd; + const u8 *ptr = payload; + int hci_len, err; + bool firstfrag = true; + + cmd = kzalloc(sizeof(struct hci_msg), GFP_KERNEL); + if (cmd == NULL) + return -ENOMEM; + + INIT_LIST_HEAD(&cmd->msg_l); + skb_queue_head_init(&cmd->msg_frags); + cmd->wait_response = (type == NFC_HCI_HCP_COMMAND) ? true : false; + cmd->cb = cb; + cmd->cb_context = cb_context; + cmd->completion_delay = completion_delay; + + hci_len = payload_len + 1; + while (hci_len > 0) { + struct sk_buff *skb; + int skb_len, data_link_len; + struct hcp_packet *packet; + + if (NFC_HCI_HCP_PACKET_HEADER_LEN + hci_len <= + hdev->max_data_link_payload) + data_link_len = hci_len; + else + data_link_len = hdev->max_data_link_payload - + NFC_HCI_HCP_PACKET_HEADER_LEN; + + skb_len = ndev->tx_headroom + NFC_HCI_HCP_PACKET_HEADER_LEN + + data_link_len + ndev->tx_tailroom; + hci_len -= data_link_len; + + skb = alloc_skb(skb_len, GFP_KERNEL); + if (skb == NULL) { + err = -ENOMEM; + goto out_skb_err; + } + skb_reserve(skb, ndev->tx_headroom); + + skb_put(skb, NFC_HCI_HCP_PACKET_HEADER_LEN + data_link_len); + + /* Only the last fragment will have the cb bit set to 1 */ + packet = (struct hcp_packet *)skb->data; + packet->header = pipe; + if (firstfrag) { + firstfrag = false; + packet->message.header = HCP_HEADER(type, instruction); + } else { + packet->message.header = *ptr++; + } + if (ptr) { + memcpy(packet->message.data, ptr, data_link_len - 1); + ptr += data_link_len - 1; + } + + /* This is the last fragment, set the cb bit */ + if (hci_len == 0) + packet->header |= ~NFC_HCI_FRAGMENT; + + skb_queue_tail(&cmd->msg_frags, skb); + } + + mutex_lock(&hdev->msg_tx_mutex); + + if (hdev->shutting_down) { + err = -ESHUTDOWN; + mutex_unlock(&hdev->msg_tx_mutex); + goto out_skb_err; + } + + list_add_tail(&cmd->msg_l, &hdev->msg_tx_queue); + mutex_unlock(&hdev->msg_tx_mutex); + + schedule_work(&hdev->msg_tx_work); + + return 0; + +out_skb_err: + skb_queue_purge(&cmd->msg_frags); + kfree(cmd); + + return err; +} + +/* + * Receive hcp message for pipe, with type and cmd. + * skb contains optional message data only. + */ +void nfc_hci_hcp_message_rx(struct nfc_hci_dev *hdev, u8 pipe, u8 type, + u8 instruction, struct sk_buff *skb) +{ + switch (type) { + case NFC_HCI_HCP_RESPONSE: + nfc_hci_resp_received(hdev, instruction, skb); + break; + case NFC_HCI_HCP_COMMAND: + nfc_hci_cmd_received(hdev, pipe, instruction, skb); + break; + case NFC_HCI_HCP_EVENT: + nfc_hci_event_received(hdev, pipe, instruction, skb); + break; + default: + pr_err("UNKNOWN MSG Type %d, instruction=%d\n", + type, instruction); + kfree_skb(skb); + break; + } +} diff --git a/net/nfc/hci/llc.c b/net/nfc/hci/llc.c new file mode 100644 index 000000000..2140f6724 --- /dev/null +++ b/net/nfc/hci/llc.c @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Link Layer Control manager + * + * Copyright (C) 2012 Intel Corporation. All rights reserved. 
+ */ + +#include + +#include "llc.h" + +static LIST_HEAD(llc_engines); + +int __init nfc_llc_init(void) +{ + int r; + + r = nfc_llc_nop_register(); + if (r) + goto exit; + + r = nfc_llc_shdlc_register(); + if (r) + goto exit; + + return 0; + +exit: + nfc_llc_exit(); + return r; +} + +void nfc_llc_exit(void) +{ + struct nfc_llc_engine *llc_engine, *n; + + list_for_each_entry_safe(llc_engine, n, &llc_engines, entry) { + list_del(&llc_engine->entry); + kfree(llc_engine->name); + kfree(llc_engine); + } +} + +int nfc_llc_register(const char *name, const struct nfc_llc_ops *ops) +{ + struct nfc_llc_engine *llc_engine; + + llc_engine = kzalloc(sizeof(struct nfc_llc_engine), GFP_KERNEL); + if (llc_engine == NULL) + return -ENOMEM; + + llc_engine->name = kstrdup(name, GFP_KERNEL); + if (llc_engine->name == NULL) { + kfree(llc_engine); + return -ENOMEM; + } + llc_engine->ops = ops; + + INIT_LIST_HEAD(&llc_engine->entry); + list_add_tail(&llc_engine->entry, &llc_engines); + + return 0; +} + +static struct nfc_llc_engine *nfc_llc_name_to_engine(const char *name) +{ + struct nfc_llc_engine *llc_engine; + + list_for_each_entry(llc_engine, &llc_engines, entry) { + if (strcmp(llc_engine->name, name) == 0) + return llc_engine; + } + + return NULL; +} + +void nfc_llc_unregister(const char *name) +{ + struct nfc_llc_engine *llc_engine; + + llc_engine = nfc_llc_name_to_engine(name); + if (llc_engine == NULL) + return; + + list_del(&llc_engine->entry); + kfree(llc_engine->name); + kfree(llc_engine); +} + +struct nfc_llc *nfc_llc_allocate(const char *name, struct nfc_hci_dev *hdev, + xmit_to_drv_t xmit_to_drv, + rcv_to_hci_t rcv_to_hci, int tx_headroom, + int tx_tailroom, llc_failure_t llc_failure) +{ + struct nfc_llc_engine *llc_engine; + struct nfc_llc *llc; + + llc_engine = nfc_llc_name_to_engine(name); + if (llc_engine == NULL) + return NULL; + + llc = kzalloc(sizeof(struct nfc_llc), GFP_KERNEL); + if (llc == NULL) + return NULL; + + llc->data = llc_engine->ops->init(hdev, xmit_to_drv, rcv_to_hci, + tx_headroom, tx_tailroom, + &llc->rx_headroom, &llc->rx_tailroom, + llc_failure); + if (llc->data == NULL) { + kfree(llc); + return NULL; + } + llc->ops = llc_engine->ops; + + return llc; +} + +void nfc_llc_free(struct nfc_llc *llc) +{ + llc->ops->deinit(llc); + kfree(llc); +} + +int nfc_llc_start(struct nfc_llc *llc) +{ + return llc->ops->start(llc); +} +EXPORT_SYMBOL(nfc_llc_start); + +int nfc_llc_stop(struct nfc_llc *llc) +{ + return llc->ops->stop(llc); +} +EXPORT_SYMBOL(nfc_llc_stop); + +void nfc_llc_rcv_from_drv(struct nfc_llc *llc, struct sk_buff *skb) +{ + llc->ops->rcv_from_drv(llc, skb); +} + +int nfc_llc_xmit_from_hci(struct nfc_llc *llc, struct sk_buff *skb) +{ + return llc->ops->xmit_from_hci(llc, skb); +} + +void *nfc_llc_get_data(struct nfc_llc *llc) +{ + return llc->data; +} diff --git a/net/nfc/hci/llc.h b/net/nfc/hci/llc.h new file mode 100644 index 000000000..d66271d21 --- /dev/null +++ b/net/nfc/hci/llc.h @@ -0,0 +1,56 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Link Layer Control manager + * + * Copyright (C) 2012 Intel Corporation. All rights reserved. 
+ */ + +#ifndef __LOCAL_LLC_H_ +#define __LOCAL_LLC_H_ + +#include +#include +#include + +struct nfc_llc_ops { + void *(*init) (struct nfc_hci_dev *hdev, xmit_to_drv_t xmit_to_drv, + rcv_to_hci_t rcv_to_hci, int tx_headroom, + int tx_tailroom, int *rx_headroom, int *rx_tailroom, + llc_failure_t llc_failure); + void (*deinit) (struct nfc_llc *llc); + int (*start) (struct nfc_llc *llc); + int (*stop) (struct nfc_llc *llc); + void (*rcv_from_drv) (struct nfc_llc *llc, struct sk_buff *skb); + int (*xmit_from_hci) (struct nfc_llc *llc, struct sk_buff *skb); +}; + +struct nfc_llc_engine { + const char *name; + const struct nfc_llc_ops *ops; + struct list_head entry; +}; + +struct nfc_llc { + void *data; + const struct nfc_llc_ops *ops; + int rx_headroom; + int rx_tailroom; +}; + +void *nfc_llc_get_data(struct nfc_llc *llc); + +int nfc_llc_register(const char *name, const struct nfc_llc_ops *ops); +void nfc_llc_unregister(const char *name); + +int nfc_llc_nop_register(void); + +#if defined(CONFIG_NFC_SHDLC) +int nfc_llc_shdlc_register(void); +#else +static inline int nfc_llc_shdlc_register(void) +{ + return 0; +} +#endif + +#endif /* __LOCAL_LLC_H_ */ diff --git a/net/nfc/hci/llc_nop.c b/net/nfc/hci/llc_nop.c new file mode 100644 index 000000000..a58716f16 --- /dev/null +++ b/net/nfc/hci/llc_nop.c @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * nop (passthrough) Link Layer Control + * + * Copyright (C) 2012 Intel Corporation. All rights reserved. + */ + +#include + +#include "llc.h" + +struct llc_nop { + struct nfc_hci_dev *hdev; + xmit_to_drv_t xmit_to_drv; + rcv_to_hci_t rcv_to_hci; + int tx_headroom; + int tx_tailroom; + llc_failure_t llc_failure; +}; + +static void *llc_nop_init(struct nfc_hci_dev *hdev, xmit_to_drv_t xmit_to_drv, + rcv_to_hci_t rcv_to_hci, int tx_headroom, + int tx_tailroom, int *rx_headroom, int *rx_tailroom, + llc_failure_t llc_failure) +{ + struct llc_nop *llc_nop; + + *rx_headroom = 0; + *rx_tailroom = 0; + + llc_nop = kzalloc(sizeof(struct llc_nop), GFP_KERNEL); + if (llc_nop == NULL) + return NULL; + + llc_nop->hdev = hdev; + llc_nop->xmit_to_drv = xmit_to_drv; + llc_nop->rcv_to_hci = rcv_to_hci; + llc_nop->tx_headroom = tx_headroom; + llc_nop->tx_tailroom = tx_tailroom; + llc_nop->llc_failure = llc_failure; + + return llc_nop; +} + +static void llc_nop_deinit(struct nfc_llc *llc) +{ + kfree(nfc_llc_get_data(llc)); +} + +static int llc_nop_start(struct nfc_llc *llc) +{ + return 0; +} + +static int llc_nop_stop(struct nfc_llc *llc) +{ + return 0; +} + +static void llc_nop_rcv_from_drv(struct nfc_llc *llc, struct sk_buff *skb) +{ + struct llc_nop *llc_nop = nfc_llc_get_data(llc); + + llc_nop->rcv_to_hci(llc_nop->hdev, skb); +} + +static int llc_nop_xmit_from_hci(struct nfc_llc *llc, struct sk_buff *skb) +{ + struct llc_nop *llc_nop = nfc_llc_get_data(llc); + + return llc_nop->xmit_to_drv(llc_nop->hdev, skb); +} + +static const struct nfc_llc_ops llc_nop_ops = { + .init = llc_nop_init, + .deinit = llc_nop_deinit, + .start = llc_nop_start, + .stop = llc_nop_stop, + .rcv_from_drv = llc_nop_rcv_from_drv, + .xmit_from_hci = llc_nop_xmit_from_hci, +}; + +int nfc_llc_nop_register(void) +{ + return nfc_llc_register(LLC_NOP_NAME, &llc_nop_ops); +} diff --git a/net/nfc/hci/llc_shdlc.c b/net/nfc/hci/llc_shdlc.c new file mode 100644 index 000000000..e90f70385 --- /dev/null +++ b/net/nfc/hci/llc_shdlc.c @@ -0,0 +1,818 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * shdlc Link Layer Control + * + * Copyright (C) 2012 Intel Corporation. All rights reserved. 
+ */ + +#define pr_fmt(fmt) "shdlc: %s: " fmt, __func__ + +#include +#include +#include +#include +#include + +#include "llc.h" + +enum shdlc_state { + SHDLC_DISCONNECTED = 0, + SHDLC_CONNECTING = 1, + SHDLC_NEGOTIATING = 2, + SHDLC_HALF_CONNECTED = 3, + SHDLC_CONNECTED = 4 +}; + +struct llc_shdlc { + struct nfc_hci_dev *hdev; + xmit_to_drv_t xmit_to_drv; + rcv_to_hci_t rcv_to_hci; + + struct mutex state_mutex; + enum shdlc_state state; + int hard_fault; + + wait_queue_head_t *connect_wq; + int connect_tries; + int connect_result; + struct timer_list connect_timer;/* aka T3 in spec 10.6.1 */ + + u8 w; /* window size */ + bool srej_support; + + struct timer_list t1_timer; /* send ack timeout */ + bool t1_active; + + struct timer_list t2_timer; /* guard/retransmit timeout */ + bool t2_active; + + int ns; /* next seq num for send */ + int nr; /* next expected seq num for receive */ + int dnr; /* oldest sent unacked seq num */ + + struct sk_buff_head rcv_q; + + struct sk_buff_head send_q; + bool rnr; /* other side is not ready to receive */ + + struct sk_buff_head ack_pending_q; + + struct work_struct sm_work; + + int tx_headroom; + int tx_tailroom; + + llc_failure_t llc_failure; +}; + +#define SHDLC_LLC_HEAD_ROOM 2 + +#define SHDLC_MAX_WINDOW 4 +#define SHDLC_SREJ_SUPPORT false + +#define SHDLC_CONTROL_HEAD_MASK 0xe0 +#define SHDLC_CONTROL_HEAD_I 0x80 +#define SHDLC_CONTROL_HEAD_I2 0xa0 +#define SHDLC_CONTROL_HEAD_S 0xc0 +#define SHDLC_CONTROL_HEAD_U 0xe0 + +#define SHDLC_CONTROL_NS_MASK 0x38 +#define SHDLC_CONTROL_NR_MASK 0x07 +#define SHDLC_CONTROL_TYPE_MASK 0x18 + +#define SHDLC_CONTROL_M_MASK 0x1f + +enum sframe_type { + S_FRAME_RR = 0x00, + S_FRAME_REJ = 0x01, + S_FRAME_RNR = 0x02, + S_FRAME_SREJ = 0x03 +}; + +enum uframe_modifier { + U_FRAME_UA = 0x06, + U_FRAME_RSET = 0x19 +}; + +#define SHDLC_CONNECT_VALUE_MS 5 +#define SHDLC_T1_VALUE_MS(w) ((5 * w) / 4) +#define SHDLC_T2_VALUE_MS 300 + +#define SHDLC_DUMP_SKB(info, skb) \ +do { \ + pr_debug("%s:\n", info); \ + print_hex_dump(KERN_DEBUG, "shdlc: ", DUMP_PREFIX_OFFSET, \ + 16, 1, skb->data, skb->len, 0); \ +} while (0) + +/* checks x < y <= z modulo 8 */ +static bool llc_shdlc_x_lt_y_lteq_z(int x, int y, int z) +{ + if (x < z) + return ((x < y) && (y <= z)) ? true : false; + else + return ((y > x) || (y <= z)) ? true : false; +} + +/* checks x <= y < z modulo 8 */ +static bool llc_shdlc_x_lteq_y_lt_z(int x, int y, int z) +{ + if (x <= z) + return ((x <= y) && (y < z)) ? true : false; + else /* x > z -> z+8 > x */ + return ((y >= x) || (y < z)) ? true : false; +} + +static struct sk_buff *llc_shdlc_alloc_skb(const struct llc_shdlc *shdlc, + int payload_len) +{ + struct sk_buff *skb; + + skb = alloc_skb(shdlc->tx_headroom + SHDLC_LLC_HEAD_ROOM + + shdlc->tx_tailroom + payload_len, GFP_KERNEL); + if (skb) + skb_reserve(skb, shdlc->tx_headroom + SHDLC_LLC_HEAD_ROOM); + + return skb; +} + +/* immediately sends an S frame. */ +static int llc_shdlc_send_s_frame(const struct llc_shdlc *shdlc, + enum sframe_type sframe_type, int nr) +{ + int r; + struct sk_buff *skb; + + pr_debug("sframe_type=%d nr=%d\n", sframe_type, nr); + + skb = llc_shdlc_alloc_skb(shdlc, 0); + if (skb == NULL) + return -ENOMEM; + + *(u8 *)skb_push(skb, 1) = SHDLC_CONTROL_HEAD_S | (sframe_type << 3) | nr; + + r = shdlc->xmit_to_drv(shdlc->hdev, skb); + + kfree_skb(skb); + + return r; +} + +/* immediately sends an U frame. 
skb may contain optional payload */ +static int llc_shdlc_send_u_frame(const struct llc_shdlc *shdlc, + struct sk_buff *skb, + enum uframe_modifier uframe_modifier) +{ + int r; + + pr_debug("uframe_modifier=%d\n", uframe_modifier); + + *(u8 *)skb_push(skb, 1) = SHDLC_CONTROL_HEAD_U | uframe_modifier; + + r = shdlc->xmit_to_drv(shdlc->hdev, skb); + + kfree_skb(skb); + + return r; +} + +/* + * Free ack_pending frames until y_nr - 1, and reset t2 according to + * the remaining oldest ack_pending frame sent time + */ +static void llc_shdlc_reset_t2(struct llc_shdlc *shdlc, int y_nr) +{ + struct sk_buff *skb; + int dnr = shdlc->dnr; /* MUST initially be < y_nr */ + + pr_debug("release ack pending up to frame %d excluded\n", y_nr); + + while (dnr != y_nr) { + pr_debug("release ack pending frame %d\n", dnr); + + skb = skb_dequeue(&shdlc->ack_pending_q); + kfree_skb(skb); + + dnr = (dnr + 1) % 8; + } + + if (skb_queue_empty(&shdlc->ack_pending_q)) { + if (shdlc->t2_active) { + del_timer_sync(&shdlc->t2_timer); + shdlc->t2_active = false; + + pr_debug("All sent frames acked. Stopped T2(retransmit)\n"); + } + } else { + skb = skb_peek(&shdlc->ack_pending_q); + + mod_timer(&shdlc->t2_timer, *(unsigned long *)skb->cb + + msecs_to_jiffies(SHDLC_T2_VALUE_MS)); + shdlc->t2_active = true; + + pr_debug("Start T2(retransmit) for remaining unacked sent frames\n"); + } +} + +/* + * Receive validated frames from lower layer. skb contains HCI payload only. + * Handle according to algorithm at spec:10.8.2 + */ +static void llc_shdlc_rcv_i_frame(struct llc_shdlc *shdlc, + struct sk_buff *skb, int ns, int nr) +{ + int x_ns = ns; + int y_nr = nr; + + pr_debug("recvd I-frame %d, remote waiting frame %d\n", ns, nr); + + if (shdlc->state != SHDLC_CONNECTED) + goto exit; + + if (x_ns != shdlc->nr) { + llc_shdlc_send_s_frame(shdlc, S_FRAME_REJ, shdlc->nr); + goto exit; + } + + if (!shdlc->t1_active) { + shdlc->t1_active = true; + mod_timer(&shdlc->t1_timer, jiffies + + msecs_to_jiffies(SHDLC_T1_VALUE_MS(shdlc->w))); + pr_debug("(re)Start T1(send ack)\n"); + } + + if (skb->len) { + shdlc->rcv_to_hci(shdlc->hdev, skb); + skb = NULL; + } + + shdlc->nr = (shdlc->nr + 1) % 8; + + if (llc_shdlc_x_lt_y_lteq_z(shdlc->dnr, y_nr, shdlc->ns)) { + llc_shdlc_reset_t2(shdlc, y_nr); + + shdlc->dnr = y_nr; + } + +exit: + kfree_skb(skb); +} + +static void llc_shdlc_rcv_ack(struct llc_shdlc *shdlc, int y_nr) +{ + pr_debug("remote acked up to frame %d excluded\n", y_nr); + + if (llc_shdlc_x_lt_y_lteq_z(shdlc->dnr, y_nr, shdlc->ns)) { + llc_shdlc_reset_t2(shdlc, y_nr); + shdlc->dnr = y_nr; + } +} + +static void llc_shdlc_requeue_ack_pending(struct llc_shdlc *shdlc) +{ + struct sk_buff *skb; + + pr_debug("ns reset to %d\n", shdlc->dnr); + + while ((skb = skb_dequeue_tail(&shdlc->ack_pending_q))) { + skb_pull(skb, 1); /* remove control field */ + skb_queue_head(&shdlc->send_q, skb); + } + shdlc->ns = shdlc->dnr; +} + +static void llc_shdlc_rcv_rej(struct llc_shdlc *shdlc, int y_nr) +{ + struct sk_buff *skb; + + pr_debug("remote asks retransmission from frame %d\n", y_nr); + + if (llc_shdlc_x_lteq_y_lt_z(shdlc->dnr, y_nr, shdlc->ns)) { + if (shdlc->t2_active) { + del_timer_sync(&shdlc->t2_timer); + shdlc->t2_active = false; + pr_debug("Stopped T2(retransmit)\n"); + } + + if (shdlc->dnr != y_nr) { + while ((shdlc->dnr = ((shdlc->dnr + 1) % 8)) != y_nr) { + skb = skb_dequeue(&shdlc->ack_pending_q); + kfree_skb(skb); + } + } + + llc_shdlc_requeue_ack_pending(shdlc); + } +} + +/* See spec RR:10.8.3 REJ:10.8.4 */ +static void 
llc_shdlc_rcv_s_frame(struct llc_shdlc *shdlc, + enum sframe_type s_frame_type, int nr) +{ + struct sk_buff *skb; + + if (shdlc->state != SHDLC_CONNECTED) + return; + + switch (s_frame_type) { + case S_FRAME_RR: + llc_shdlc_rcv_ack(shdlc, nr); + if (shdlc->rnr == true) { /* see SHDLC 10.7.7 */ + shdlc->rnr = false; + if (shdlc->send_q.qlen == 0) { + skb = llc_shdlc_alloc_skb(shdlc, 0); + if (skb) + skb_queue_tail(&shdlc->send_q, skb); + } + } + break; + case S_FRAME_REJ: + llc_shdlc_rcv_rej(shdlc, nr); + break; + case S_FRAME_RNR: + llc_shdlc_rcv_ack(shdlc, nr); + shdlc->rnr = true; + break; + default: + break; + } +} + +static void llc_shdlc_connect_complete(struct llc_shdlc *shdlc, int r) +{ + pr_debug("result=%d\n", r); + + del_timer_sync(&shdlc->connect_timer); + + if (r == 0) { + shdlc->ns = 0; + shdlc->nr = 0; + shdlc->dnr = 0; + + shdlc->state = SHDLC_HALF_CONNECTED; + } else { + shdlc->state = SHDLC_DISCONNECTED; + } + + shdlc->connect_result = r; + + wake_up(shdlc->connect_wq); +} + +static int llc_shdlc_connect_initiate(const struct llc_shdlc *shdlc) +{ + struct sk_buff *skb; + + skb = llc_shdlc_alloc_skb(shdlc, 2); + if (skb == NULL) + return -ENOMEM; + + skb_put_u8(skb, SHDLC_MAX_WINDOW); + skb_put_u8(skb, SHDLC_SREJ_SUPPORT ? 1 : 0); + + return llc_shdlc_send_u_frame(shdlc, skb, U_FRAME_RSET); +} + +static int llc_shdlc_connect_send_ua(const struct llc_shdlc *shdlc) +{ + struct sk_buff *skb; + + skb = llc_shdlc_alloc_skb(shdlc, 0); + if (skb == NULL) + return -ENOMEM; + + return llc_shdlc_send_u_frame(shdlc, skb, U_FRAME_UA); +} + +static void llc_shdlc_rcv_u_frame(struct llc_shdlc *shdlc, + struct sk_buff *skb, + enum uframe_modifier u_frame_modifier) +{ + u8 w = SHDLC_MAX_WINDOW; + bool srej_support = SHDLC_SREJ_SUPPORT; + int r; + + pr_debug("u_frame_modifier=%d\n", u_frame_modifier); + + switch (u_frame_modifier) { + case U_FRAME_RSET: + switch (shdlc->state) { + case SHDLC_NEGOTIATING: + case SHDLC_CONNECTING: + /* + * We sent RSET, but chip wants to negotiate or we + * got RSET before we managed to send out our. + */ + if (skb->len > 0) + w = skb->data[0]; + + if (skb->len > 1) + srej_support = skb->data[1] & 0x01 ? true : + false; + + if ((w <= SHDLC_MAX_WINDOW) && + (SHDLC_SREJ_SUPPORT || (srej_support == false))) { + shdlc->w = w; + shdlc->srej_support = srej_support; + r = llc_shdlc_connect_send_ua(shdlc); + llc_shdlc_connect_complete(shdlc, r); + } + break; + case SHDLC_HALF_CONNECTED: + /* + * Chip resent RSET due to its timeout - Ignote it + * as we already sent UA. + */ + break; + case SHDLC_CONNECTED: + /* + * Chip wants to reset link. This is unexpected and + * unsupported. 
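+			 * The error is latched in hard_fault and reported to
+			 * the HCI layer via llc_failure() on the next state
+			 * machine run.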
+ */ + shdlc->hard_fault = -ECONNRESET; + break; + default: + break; + } + break; + case U_FRAME_UA: + if ((shdlc->state == SHDLC_CONNECTING && + shdlc->connect_tries > 0) || + (shdlc->state == SHDLC_NEGOTIATING)) { + llc_shdlc_connect_complete(shdlc, 0); + shdlc->state = SHDLC_CONNECTED; + } + break; + default: + break; + } + + kfree_skb(skb); +} + +static void llc_shdlc_handle_rcv_queue(struct llc_shdlc *shdlc) +{ + struct sk_buff *skb; + u8 control; + int nr; + int ns; + enum sframe_type s_frame_type; + enum uframe_modifier u_frame_modifier; + + if (shdlc->rcv_q.qlen) + pr_debug("rcvQlen=%d\n", shdlc->rcv_q.qlen); + + while ((skb = skb_dequeue(&shdlc->rcv_q)) != NULL) { + control = skb->data[0]; + skb_pull(skb, 1); + switch (control & SHDLC_CONTROL_HEAD_MASK) { + case SHDLC_CONTROL_HEAD_I: + case SHDLC_CONTROL_HEAD_I2: + if (shdlc->state == SHDLC_HALF_CONNECTED) + shdlc->state = SHDLC_CONNECTED; + + ns = (control & SHDLC_CONTROL_NS_MASK) >> 3; + nr = control & SHDLC_CONTROL_NR_MASK; + llc_shdlc_rcv_i_frame(shdlc, skb, ns, nr); + break; + case SHDLC_CONTROL_HEAD_S: + if (shdlc->state == SHDLC_HALF_CONNECTED) + shdlc->state = SHDLC_CONNECTED; + + s_frame_type = (control & SHDLC_CONTROL_TYPE_MASK) >> 3; + nr = control & SHDLC_CONTROL_NR_MASK; + llc_shdlc_rcv_s_frame(shdlc, s_frame_type, nr); + kfree_skb(skb); + break; + case SHDLC_CONTROL_HEAD_U: + u_frame_modifier = control & SHDLC_CONTROL_M_MASK; + llc_shdlc_rcv_u_frame(shdlc, skb, u_frame_modifier); + break; + default: + pr_err("UNKNOWN Control=%d\n", control); + kfree_skb(skb); + break; + } + } +} + +static int llc_shdlc_w_used(int ns, int dnr) +{ + int unack_count; + + if (dnr <= ns) + unack_count = ns - dnr; + else + unack_count = 8 - dnr + ns; + + return unack_count; +} + +/* Send frames according to algorithm at spec:10.8.1 */ +static void llc_shdlc_handle_send_queue(struct llc_shdlc *shdlc) +{ + struct sk_buff *skb; + int r; + unsigned long time_sent; + + if (shdlc->send_q.qlen) + pr_debug("sendQlen=%d ns=%d dnr=%d rnr=%s w_room=%d unackQlen=%d\n", + shdlc->send_q.qlen, shdlc->ns, shdlc->dnr, + shdlc->rnr == false ? 
"false" : "true", + shdlc->w - llc_shdlc_w_used(shdlc->ns, shdlc->dnr), + shdlc->ack_pending_q.qlen); + + while (shdlc->send_q.qlen && shdlc->ack_pending_q.qlen < shdlc->w && + (shdlc->rnr == false)) { + + if (shdlc->t1_active) { + del_timer_sync(&shdlc->t1_timer); + shdlc->t1_active = false; + pr_debug("Stopped T1(send ack)\n"); + } + + skb = skb_dequeue(&shdlc->send_q); + + *(u8 *)skb_push(skb, 1) = SHDLC_CONTROL_HEAD_I | (shdlc->ns << 3) | + shdlc->nr; + + pr_debug("Sending I-Frame %d, waiting to rcv %d\n", shdlc->ns, + shdlc->nr); + SHDLC_DUMP_SKB("shdlc frame written", skb); + + r = shdlc->xmit_to_drv(shdlc->hdev, skb); + if (r < 0) { + shdlc->hard_fault = r; + break; + } + + shdlc->ns = (shdlc->ns + 1) % 8; + + time_sent = jiffies; + *(unsigned long *)skb->cb = time_sent; + + skb_queue_tail(&shdlc->ack_pending_q, skb); + + if (shdlc->t2_active == false) { + shdlc->t2_active = true; + mod_timer(&shdlc->t2_timer, time_sent + + msecs_to_jiffies(SHDLC_T2_VALUE_MS)); + pr_debug("Started T2 (retransmit)\n"); + } + } +} + +static void llc_shdlc_connect_timeout(struct timer_list *t) +{ + struct llc_shdlc *shdlc = from_timer(shdlc, t, connect_timer); + + schedule_work(&shdlc->sm_work); +} + +static void llc_shdlc_t1_timeout(struct timer_list *t) +{ + struct llc_shdlc *shdlc = from_timer(shdlc, t, t1_timer); + + pr_debug("SoftIRQ: need to send ack\n"); + + schedule_work(&shdlc->sm_work); +} + +static void llc_shdlc_t2_timeout(struct timer_list *t) +{ + struct llc_shdlc *shdlc = from_timer(shdlc, t, t2_timer); + + pr_debug("SoftIRQ: need to retransmit\n"); + + schedule_work(&shdlc->sm_work); +} + +static void llc_shdlc_sm_work(struct work_struct *work) +{ + struct llc_shdlc *shdlc = container_of(work, struct llc_shdlc, sm_work); + int r; + + mutex_lock(&shdlc->state_mutex); + + switch (shdlc->state) { + case SHDLC_DISCONNECTED: + skb_queue_purge(&shdlc->rcv_q); + skb_queue_purge(&shdlc->send_q); + skb_queue_purge(&shdlc->ack_pending_q); + break; + case SHDLC_CONNECTING: + if (shdlc->hard_fault) { + llc_shdlc_connect_complete(shdlc, shdlc->hard_fault); + break; + } + + if (shdlc->connect_tries++ < 5) + r = llc_shdlc_connect_initiate(shdlc); + else + r = -ETIME; + if (r < 0) { + llc_shdlc_connect_complete(shdlc, r); + } else { + mod_timer(&shdlc->connect_timer, jiffies + + msecs_to_jiffies(SHDLC_CONNECT_VALUE_MS)); + + shdlc->state = SHDLC_NEGOTIATING; + } + break; + case SHDLC_NEGOTIATING: + if (timer_pending(&shdlc->connect_timer) == 0) { + shdlc->state = SHDLC_CONNECTING; + schedule_work(&shdlc->sm_work); + } + + llc_shdlc_handle_rcv_queue(shdlc); + + if (shdlc->hard_fault) { + llc_shdlc_connect_complete(shdlc, shdlc->hard_fault); + break; + } + break; + case SHDLC_HALF_CONNECTED: + case SHDLC_CONNECTED: + llc_shdlc_handle_rcv_queue(shdlc); + llc_shdlc_handle_send_queue(shdlc); + + if (shdlc->t1_active && timer_pending(&shdlc->t1_timer) == 0) { + pr_debug("Handle T1(send ack) elapsed (T1 now inactive)\n"); + + shdlc->t1_active = false; + r = llc_shdlc_send_s_frame(shdlc, S_FRAME_RR, + shdlc->nr); + if (r < 0) + shdlc->hard_fault = r; + } + + if (shdlc->t2_active && timer_pending(&shdlc->t2_timer) == 0) { + pr_debug("Handle T2(retransmit) elapsed (T2 inactive)\n"); + + shdlc->t2_active = false; + + llc_shdlc_requeue_ack_pending(shdlc); + llc_shdlc_handle_send_queue(shdlc); + } + + if (shdlc->hard_fault) + shdlc->llc_failure(shdlc->hdev, shdlc->hard_fault); + break; + default: + break; + } + mutex_unlock(&shdlc->state_mutex); +} + +/* + * Called from syscall context to establish shdlc link. 
Sleeps until + * link is ready or failure. + */ +static int llc_shdlc_connect(struct llc_shdlc *shdlc) +{ + DECLARE_WAIT_QUEUE_HEAD_ONSTACK(connect_wq); + + mutex_lock(&shdlc->state_mutex); + + shdlc->state = SHDLC_CONNECTING; + shdlc->connect_wq = &connect_wq; + shdlc->connect_tries = 0; + shdlc->connect_result = 1; + + mutex_unlock(&shdlc->state_mutex); + + schedule_work(&shdlc->sm_work); + + wait_event(connect_wq, shdlc->connect_result != 1); + + return shdlc->connect_result; +} + +static void llc_shdlc_disconnect(struct llc_shdlc *shdlc) +{ + mutex_lock(&shdlc->state_mutex); + + shdlc->state = SHDLC_DISCONNECTED; + + mutex_unlock(&shdlc->state_mutex); + + schedule_work(&shdlc->sm_work); +} + +/* + * Receive an incoming shdlc frame. Frame has already been crc-validated. + * skb contains only LLC header and payload. + * If skb == NULL, it is a notification that the link below is dead. + */ +static void llc_shdlc_recv_frame(struct llc_shdlc *shdlc, struct sk_buff *skb) +{ + if (skb == NULL) { + pr_err("NULL Frame -> link is dead\n"); + shdlc->hard_fault = -EREMOTEIO; + } else { + SHDLC_DUMP_SKB("incoming frame", skb); + skb_queue_tail(&shdlc->rcv_q, skb); + } + + schedule_work(&shdlc->sm_work); +} + +static void *llc_shdlc_init(struct nfc_hci_dev *hdev, xmit_to_drv_t xmit_to_drv, + rcv_to_hci_t rcv_to_hci, int tx_headroom, + int tx_tailroom, int *rx_headroom, int *rx_tailroom, + llc_failure_t llc_failure) +{ + struct llc_shdlc *shdlc; + + *rx_headroom = SHDLC_LLC_HEAD_ROOM; + *rx_tailroom = 0; + + shdlc = kzalloc(sizeof(struct llc_shdlc), GFP_KERNEL); + if (shdlc == NULL) + return NULL; + + mutex_init(&shdlc->state_mutex); + shdlc->state = SHDLC_DISCONNECTED; + + timer_setup(&shdlc->connect_timer, llc_shdlc_connect_timeout, 0); + timer_setup(&shdlc->t1_timer, llc_shdlc_t1_timeout, 0); + timer_setup(&shdlc->t2_timer, llc_shdlc_t2_timeout, 0); + + shdlc->w = SHDLC_MAX_WINDOW; + shdlc->srej_support = SHDLC_SREJ_SUPPORT; + + skb_queue_head_init(&shdlc->rcv_q); + skb_queue_head_init(&shdlc->send_q); + skb_queue_head_init(&shdlc->ack_pending_q); + + INIT_WORK(&shdlc->sm_work, llc_shdlc_sm_work); + + shdlc->hdev = hdev; + shdlc->xmit_to_drv = xmit_to_drv; + shdlc->rcv_to_hci = rcv_to_hci; + shdlc->tx_headroom = tx_headroom; + shdlc->tx_tailroom = tx_tailroom; + shdlc->llc_failure = llc_failure; + + return shdlc; +} + +static void llc_shdlc_deinit(struct nfc_llc *llc) +{ + struct llc_shdlc *shdlc = nfc_llc_get_data(llc); + + skb_queue_purge(&shdlc->rcv_q); + skb_queue_purge(&shdlc->send_q); + skb_queue_purge(&shdlc->ack_pending_q); + + kfree(shdlc); +} + +static int llc_shdlc_start(struct nfc_llc *llc) +{ + struct llc_shdlc *shdlc = nfc_llc_get_data(llc); + + return llc_shdlc_connect(shdlc); +} + +static int llc_shdlc_stop(struct nfc_llc *llc) +{ + struct llc_shdlc *shdlc = nfc_llc_get_data(llc); + + llc_shdlc_disconnect(shdlc); + + return 0; +} + +static void llc_shdlc_rcv_from_drv(struct nfc_llc *llc, struct sk_buff *skb) +{ + struct llc_shdlc *shdlc = nfc_llc_get_data(llc); + + llc_shdlc_recv_frame(shdlc, skb); +} + +static int llc_shdlc_xmit_from_hci(struct nfc_llc *llc, struct sk_buff *skb) +{ + struct llc_shdlc *shdlc = nfc_llc_get_data(llc); + + skb_queue_tail(&shdlc->send_q, skb); + + schedule_work(&shdlc->sm_work); + + return 0; +} + +static const struct nfc_llc_ops llc_shdlc_ops = { + .init = llc_shdlc_init, + .deinit = llc_shdlc_deinit, + .start = llc_shdlc_start, + .stop = llc_shdlc_stop, + .rcv_from_drv = llc_shdlc_rcv_from_drv, + .xmit_from_hci = llc_shdlc_xmit_from_hci, +}; + +int 
nfc_llc_shdlc_register(void) +{ + return nfc_llc_register(LLC_SHDLC_NAME, &llc_shdlc_ops); +} diff --git a/net/nfc/llcp.h b/net/nfc/llcp.h new file mode 100644 index 000000000..d8345ed57 --- /dev/null +++ b/net/nfc/llcp.h @@ -0,0 +1,252 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (C) 2011 Intel Corporation. All rights reserved. + */ + +enum llcp_state { + LLCP_CONNECTED = 1, /* wait_for_packet() wants that */ + LLCP_CONNECTING, + LLCP_CLOSED, + LLCP_BOUND, + LLCP_LISTEN, +}; + +#define LLCP_DEFAULT_LTO 100 +#define LLCP_DEFAULT_RW 1 +#define LLCP_DEFAULT_MIU 128 + +#define LLCP_MAX_LTO 0xff +#define LLCP_MAX_RW 15 +#define LLCP_MAX_MIUX 0x7ff +#define LLCP_MAX_MIU (LLCP_MAX_MIUX + 128) + +#define LLCP_WKS_NUM_SAP 16 +#define LLCP_SDP_NUM_SAP 16 +#define LLCP_LOCAL_NUM_SAP 32 +#define LLCP_LOCAL_SAP_OFFSET (LLCP_WKS_NUM_SAP + LLCP_SDP_NUM_SAP) +#define LLCP_MAX_SAP (LLCP_WKS_NUM_SAP + LLCP_SDP_NUM_SAP + LLCP_LOCAL_NUM_SAP) +#define LLCP_SDP_UNBOUND (LLCP_MAX_SAP + 1) + +struct nfc_llcp_sock; + +struct llcp_sock_list { + struct hlist_head head; + rwlock_t lock; +}; + +struct nfc_llcp_sdp_tlv { + u8 *tlv; + u8 tlv_len; + + char *uri; + u8 tid; + u8 sap; + + unsigned long time; + + struct hlist_node node; +}; + +struct nfc_llcp_local { + struct list_head list; + struct nfc_dev *dev; + + struct kref ref; + + struct mutex sdp_lock; + + struct timer_list link_timer; + struct sk_buff_head tx_queue; + struct work_struct tx_work; + struct work_struct rx_work; + struct sk_buff *rx_pending; + struct work_struct timeout_work; + + u32 target_idx; + u8 rf_mode; + u8 comm_mode; + u8 lto; + u8 rw; + __be16 miux; + unsigned long local_wks; /* Well known services */ + unsigned long local_sdp; /* Local services */ + unsigned long local_sap; /* Local SAPs, not available for discovery */ + atomic_t local_sdp_cnt[LLCP_SDP_NUM_SAP]; + + /* local */ + u8 gb[NFC_MAX_GT_LEN]; + u8 gb_len; + + /* remote */ + u8 remote_gb[NFC_MAX_GT_LEN]; + u8 remote_gb_len; + + u8 remote_version; + u16 remote_miu; + u16 remote_lto; + u8 remote_opt; + u16 remote_wks; + + struct mutex sdreq_lock; + struct hlist_head pending_sdreqs; + struct timer_list sdreq_timer; + struct work_struct sdreq_timeout_work; + u8 sdreq_next_tid; + + /* sockets array */ + struct llcp_sock_list sockets; + struct llcp_sock_list connecting_sockets; + struct llcp_sock_list raw_sockets; +}; + +struct nfc_llcp_sock { + struct sock sk; + struct nfc_dev *dev; + struct nfc_llcp_local *local; + u32 target_idx; + u32 nfc_protocol; + + /* Link parameters */ + u8 ssap; + u8 dsap; + char *service_name; + size_t service_name_len; + u8 rw; + __be16 miux; + + + /* Remote link parameters */ + u8 remote_rw; + u16 remote_miu; + + /* Link variables */ + u8 send_n; + u8 send_ack_n; + u8 recv_n; + u8 recv_ack_n; + + /* Is the remote peer ready to receive */ + u8 remote_ready; + + /* Reserved source SAP */ + u8 reserved_ssap; + + struct sk_buff_head tx_queue; + struct sk_buff_head tx_pending_queue; + + struct list_head accept_queue; + struct sock *parent; +}; + +struct nfc_llcp_ui_cb { + __u8 dsap; + __u8 ssap; +}; + +#define nfc_llcp_ui_skb_cb(__skb) ((struct nfc_llcp_ui_cb *)&((__skb)->cb[0])) + +#define nfc_llcp_sock(sk) ((struct nfc_llcp_sock *) (sk)) +#define nfc_llcp_dev(sk) (nfc_llcp_sock((sk))->dev) + +#define LLCP_HEADER_SIZE 2 +#define LLCP_SEQUENCE_SIZE 1 +#define LLCP_AGF_PDU_HEADER_SIZE 2 + +/* LLCP versions: 1.1 is 1.0 plus SDP */ +#define LLCP_VERSION_10 0x10 +#define LLCP_VERSION_11 0x11 + +/* LLCP PDU types */ +#define LLCP_PDU_SYMM 0x0 
+#define LLCP_PDU_PAX 0x1 +#define LLCP_PDU_AGF 0x2 +#define LLCP_PDU_UI 0x3 +#define LLCP_PDU_CONNECT 0x4 +#define LLCP_PDU_DISC 0x5 +#define LLCP_PDU_CC 0x6 +#define LLCP_PDU_DM 0x7 +#define LLCP_PDU_FRMR 0x8 +#define LLCP_PDU_SNL 0x9 +#define LLCP_PDU_I 0xc +#define LLCP_PDU_RR 0xd +#define LLCP_PDU_RNR 0xe + +/* Parameters TLV types */ +#define LLCP_TLV_VERSION 0x1 +#define LLCP_TLV_MIUX 0x2 +#define LLCP_TLV_WKS 0x3 +#define LLCP_TLV_LTO 0x4 +#define LLCP_TLV_RW 0x5 +#define LLCP_TLV_SN 0x6 +#define LLCP_TLV_OPT 0x7 +#define LLCP_TLV_SDREQ 0x8 +#define LLCP_TLV_SDRES 0x9 +#define LLCP_TLV_MAX 0xa + +/* Well known LLCP SAP */ +#define LLCP_SAP_SDP 0x1 +#define LLCP_SAP_IP 0x2 +#define LLCP_SAP_OBEX 0x3 +#define LLCP_SAP_SNEP 0x4 +#define LLCP_SAP_MAX 0xff + +/* Disconnection reason code */ +#define LLCP_DM_DISC 0x00 +#define LLCP_DM_NOCONN 0x01 +#define LLCP_DM_NOBOUND 0x02 +#define LLCP_DM_REJ 0x03 + + +void nfc_llcp_sock_link(struct llcp_sock_list *l, struct sock *s); +void nfc_llcp_sock_unlink(struct llcp_sock_list *l, struct sock *s); +void nfc_llcp_socket_remote_param_init(struct nfc_llcp_sock *sock); +struct nfc_llcp_local *nfc_llcp_find_local(struct nfc_dev *dev); +int nfc_llcp_local_put(struct nfc_llcp_local *local); +u8 nfc_llcp_get_sdp_ssap(struct nfc_llcp_local *local, + struct nfc_llcp_sock *sock); +u8 nfc_llcp_get_local_ssap(struct nfc_llcp_local *local); +void nfc_llcp_put_ssap(struct nfc_llcp_local *local, u8 ssap); +int nfc_llcp_queue_i_frames(struct nfc_llcp_sock *sock); +void nfc_llcp_send_to_raw_sock(struct nfc_llcp_local *local, + struct sk_buff *skb, u8 direction); + +/* Sock API */ +struct sock *nfc_llcp_sock_alloc(struct socket *sock, int type, gfp_t gfp, int kern); +void nfc_llcp_sock_free(struct nfc_llcp_sock *sock); +void nfc_llcp_accept_unlink(struct sock *sk); +void nfc_llcp_accept_enqueue(struct sock *parent, struct sock *sk); +struct sock *nfc_llcp_accept_dequeue(struct sock *sk, struct socket *newsock); + +/* TLV API */ +int nfc_llcp_parse_gb_tlv(struct nfc_llcp_local *local, + const u8 *tlv_array, u16 tlv_array_len); +int nfc_llcp_parse_connection_tlv(struct nfc_llcp_sock *sock, + const u8 *tlv_array, u16 tlv_array_len); + +/* Commands API */ +void nfc_llcp_recv(void *data, struct sk_buff *skb, int err); +u8 *nfc_llcp_build_tlv(u8 type, const u8 *value, u8 value_length, u8 *tlv_length); +struct nfc_llcp_sdp_tlv *nfc_llcp_build_sdres_tlv(u8 tid, u8 sap); +struct nfc_llcp_sdp_tlv *nfc_llcp_build_sdreq_tlv(u8 tid, const char *uri, + size_t uri_len); +void nfc_llcp_free_sdp_tlv(struct nfc_llcp_sdp_tlv *sdp); +void nfc_llcp_free_sdp_tlv_list(struct hlist_head *sdp_head); +void nfc_llcp_recv(void *data, struct sk_buff *skb, int err); +int nfc_llcp_send_symm(struct nfc_dev *dev); +int nfc_llcp_send_connect(struct nfc_llcp_sock *sock); +int nfc_llcp_send_cc(struct nfc_llcp_sock *sock); +int nfc_llcp_send_snl_sdres(struct nfc_llcp_local *local, + struct hlist_head *tlv_list, size_t tlvs_len); +int nfc_llcp_send_snl_sdreq(struct nfc_llcp_local *local, + struct hlist_head *tlv_list, size_t tlvs_len); +int nfc_llcp_send_dm(struct nfc_llcp_local *local, u8 ssap, u8 dsap, u8 reason); +int nfc_llcp_send_disconnect(struct nfc_llcp_sock *sock); +int nfc_llcp_send_i_frame(struct nfc_llcp_sock *sock, + struct msghdr *msg, size_t len); +int nfc_llcp_send_ui_frame(struct nfc_llcp_sock *sock, u8 ssap, u8 dsap, + struct msghdr *msg, size_t len); +int nfc_llcp_send_rr(struct nfc_llcp_sock *sock); + +/* Socket API */ +int __init nfc_llcp_sock_init(void); +void 
nfc_llcp_sock_exit(void); diff --git a/net/nfc/llcp_commands.c b/net/nfc/llcp_commands.c new file mode 100644 index 000000000..e2680a3be --- /dev/null +++ b/net/nfc/llcp_commands.c @@ -0,0 +1,815 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Intel Corporation. All rights reserved. + */ + +#define pr_fmt(fmt) "llcp: %s: " fmt, __func__ + +#include +#include +#include +#include + +#include + +#include "nfc.h" +#include "llcp.h" + +static const u8 llcp_tlv_length[LLCP_TLV_MAX] = { + 0, + 1, /* VERSION */ + 2, /* MIUX */ + 2, /* WKS */ + 1, /* LTO */ + 1, /* RW */ + 0, /* SN */ + 1, /* OPT */ + 0, /* SDREQ */ + 2, /* SDRES */ + +}; + +static u8 llcp_tlv8(const u8 *tlv, u8 type) +{ + if (tlv[0] != type || tlv[1] != llcp_tlv_length[tlv[0]]) + return 0; + + return tlv[2]; +} + +static u16 llcp_tlv16(const u8 *tlv, u8 type) +{ + if (tlv[0] != type || tlv[1] != llcp_tlv_length[tlv[0]]) + return 0; + + return be16_to_cpu(*((__be16 *)(tlv + 2))); +} + + +static u8 llcp_tlv_version(const u8 *tlv) +{ + return llcp_tlv8(tlv, LLCP_TLV_VERSION); +} + +static u16 llcp_tlv_miux(const u8 *tlv) +{ + return llcp_tlv16(tlv, LLCP_TLV_MIUX) & 0x7ff; +} + +static u16 llcp_tlv_wks(const u8 *tlv) +{ + return llcp_tlv16(tlv, LLCP_TLV_WKS); +} + +static u16 llcp_tlv_lto(const u8 *tlv) +{ + return llcp_tlv8(tlv, LLCP_TLV_LTO); +} + +static u8 llcp_tlv_opt(const u8 *tlv) +{ + return llcp_tlv8(tlv, LLCP_TLV_OPT); +} + +static u8 llcp_tlv_rw(const u8 *tlv) +{ + return llcp_tlv8(tlv, LLCP_TLV_RW) & 0xf; +} + +u8 *nfc_llcp_build_tlv(u8 type, const u8 *value, u8 value_length, u8 *tlv_length) +{ + u8 *tlv, length; + + pr_debug("type %d\n", type); + + if (type >= LLCP_TLV_MAX) + return NULL; + + length = llcp_tlv_length[type]; + if (length == 0 && value_length == 0) + return NULL; + else if (length == 0) + length = value_length; + + *tlv_length = 2 + length; + tlv = kzalloc(2 + length, GFP_KERNEL); + if (tlv == NULL) + return tlv; + + tlv[0] = type; + tlv[1] = length; + memcpy(tlv + 2, value, length); + + return tlv; +} + +struct nfc_llcp_sdp_tlv *nfc_llcp_build_sdres_tlv(u8 tid, u8 sap) +{ + struct nfc_llcp_sdp_tlv *sdres; + u8 value[2]; + + sdres = kzalloc(sizeof(struct nfc_llcp_sdp_tlv), GFP_KERNEL); + if (sdres == NULL) + return NULL; + + value[0] = tid; + value[1] = sap; + + sdres->tlv = nfc_llcp_build_tlv(LLCP_TLV_SDRES, value, 2, + &sdres->tlv_len); + if (sdres->tlv == NULL) { + kfree(sdres); + return NULL; + } + + sdres->tid = tid; + sdres->sap = sap; + + INIT_HLIST_NODE(&sdres->node); + + return sdres; +} + +struct nfc_llcp_sdp_tlv *nfc_llcp_build_sdreq_tlv(u8 tid, const char *uri, + size_t uri_len) +{ + struct nfc_llcp_sdp_tlv *sdreq; + + pr_debug("uri: %s, len: %zu\n", uri, uri_len); + + /* sdreq->tlv_len is u8, takes uri_len, + 3 for header, + 1 for NULL */ + if (WARN_ON_ONCE(uri_len > U8_MAX - 4)) + return NULL; + + sdreq = kzalloc(sizeof(struct nfc_llcp_sdp_tlv), GFP_KERNEL); + if (sdreq == NULL) + return NULL; + + sdreq->tlv_len = uri_len + 3; + + if (uri[uri_len - 1] == 0) + sdreq->tlv_len--; + + sdreq->tlv = kzalloc(sdreq->tlv_len + 1, GFP_KERNEL); + if (sdreq->tlv == NULL) { + kfree(sdreq); + return NULL; + } + + sdreq->tlv[0] = LLCP_TLV_SDREQ; + sdreq->tlv[1] = sdreq->tlv_len - 2; + sdreq->tlv[2] = tid; + + sdreq->tid = tid; + sdreq->uri = sdreq->tlv + 3; + memcpy(sdreq->uri, uri, uri_len); + + sdreq->time = jiffies; + + INIT_HLIST_NODE(&sdreq->node); + + return sdreq; +} + +void nfc_llcp_free_sdp_tlv(struct nfc_llcp_sdp_tlv *sdp) +{ + kfree(sdp->tlv); + kfree(sdp); +} + +void 
nfc_llcp_free_sdp_tlv_list(struct hlist_head *head) +{ + struct nfc_llcp_sdp_tlv *sdp; + struct hlist_node *n; + + hlist_for_each_entry_safe(sdp, n, head, node) { + hlist_del(&sdp->node); + + nfc_llcp_free_sdp_tlv(sdp); + } +} + +int nfc_llcp_parse_gb_tlv(struct nfc_llcp_local *local, + const u8 *tlv_array, u16 tlv_array_len) +{ + const u8 *tlv = tlv_array; + u8 type, length, offset = 0; + + pr_debug("TLV array length %d\n", tlv_array_len); + + if (local == NULL) + return -ENODEV; + + while (offset < tlv_array_len) { + type = tlv[0]; + length = tlv[1]; + + pr_debug("type 0x%x length %d\n", type, length); + + switch (type) { + case LLCP_TLV_VERSION: + local->remote_version = llcp_tlv_version(tlv); + break; + case LLCP_TLV_MIUX: + local->remote_miu = llcp_tlv_miux(tlv) + 128; + break; + case LLCP_TLV_WKS: + local->remote_wks = llcp_tlv_wks(tlv); + break; + case LLCP_TLV_LTO: + local->remote_lto = llcp_tlv_lto(tlv) * 10; + break; + case LLCP_TLV_OPT: + local->remote_opt = llcp_tlv_opt(tlv); + break; + default: + pr_err("Invalid gt tlv value 0x%x\n", type); + break; + } + + offset += length + 2; + tlv += length + 2; + } + + pr_debug("version 0x%x miu %d lto %d opt 0x%x wks 0x%x\n", + local->remote_version, local->remote_miu, + local->remote_lto, local->remote_opt, + local->remote_wks); + + return 0; +} + +int nfc_llcp_parse_connection_tlv(struct nfc_llcp_sock *sock, + const u8 *tlv_array, u16 tlv_array_len) +{ + const u8 *tlv = tlv_array; + u8 type, length, offset = 0; + + pr_debug("TLV array length %d\n", tlv_array_len); + + if (sock == NULL) + return -ENOTCONN; + + while (offset < tlv_array_len) { + type = tlv[0]; + length = tlv[1]; + + pr_debug("type 0x%x length %d\n", type, length); + + switch (type) { + case LLCP_TLV_MIUX: + sock->remote_miu = llcp_tlv_miux(tlv) + 128; + break; + case LLCP_TLV_RW: + sock->remote_rw = llcp_tlv_rw(tlv); + break; + case LLCP_TLV_SN: + break; + default: + pr_err("Invalid gt tlv value 0x%x\n", type); + break; + } + + offset += length + 2; + tlv += length + 2; + } + + pr_debug("sock %p rw %d miu %d\n", sock, + sock->remote_rw, sock->remote_miu); + + return 0; +} + +static struct sk_buff *llcp_add_header(struct sk_buff *pdu, + u8 dsap, u8 ssap, u8 ptype) +{ + u8 header[2]; + + pr_debug("ptype 0x%x dsap 0x%x ssap 0x%x\n", ptype, dsap, ssap); + + header[0] = (u8)((dsap << 2) | (ptype >> 2)); + header[1] = (u8)((ptype << 6) | ssap); + + pr_debug("header 0x%x 0x%x\n", header[0], header[1]); + + skb_put_data(pdu, header, LLCP_HEADER_SIZE); + + return pdu; +} + +static struct sk_buff *llcp_add_tlv(struct sk_buff *pdu, const u8 *tlv, + u8 tlv_length) +{ + /* XXX Add an skb length check */ + + if (tlv == NULL) + return NULL; + + skb_put_data(pdu, tlv, tlv_length); + + return pdu; +} + +static struct sk_buff *llcp_allocate_pdu(struct nfc_llcp_sock *sock, + u8 cmd, u16 size) +{ + struct sk_buff *skb; + int err; + + if (sock->ssap == 0) + return NULL; + + skb = nfc_alloc_send_skb(sock->dev, &sock->sk, MSG_DONTWAIT, + size + LLCP_HEADER_SIZE, &err); + if (skb == NULL) { + pr_err("Could not allocate PDU\n"); + return NULL; + } + + skb = llcp_add_header(skb, sock->dsap, sock->ssap, cmd); + + return skb; +} + +int nfc_llcp_send_disconnect(struct nfc_llcp_sock *sock) +{ + struct sk_buff *skb; + struct nfc_dev *dev; + struct nfc_llcp_local *local; + + local = sock->local; + if (local == NULL) + return -ENODEV; + + dev = sock->dev; + if (dev == NULL) + return -ENODEV; + + skb = llcp_allocate_pdu(sock, LLCP_PDU_DISC, 0); + if (skb == NULL) + return -ENOMEM; + + 
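+	/* DISC is queued on the link tx queue and sent asynchronously
+	 * by nfc_llcp_tx_work().
+	 */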
skb_queue_tail(&local->tx_queue, skb); + + return 0; +} + +int nfc_llcp_send_symm(struct nfc_dev *dev) +{ + struct sk_buff *skb; + struct nfc_llcp_local *local; + u16 size = 0; + int err; + + local = nfc_llcp_find_local(dev); + if (local == NULL) + return -ENODEV; + + size += LLCP_HEADER_SIZE; + size += dev->tx_headroom + dev->tx_tailroom + NFC_HEADER_SIZE; + + skb = alloc_skb(size, GFP_KERNEL); + if (skb == NULL) { + err = -ENOMEM; + goto out; + } + + skb_reserve(skb, dev->tx_headroom + NFC_HEADER_SIZE); + + skb = llcp_add_header(skb, 0, 0, LLCP_PDU_SYMM); + + __net_timestamp(skb); + + nfc_llcp_send_to_raw_sock(local, skb, NFC_DIRECTION_TX); + + err = nfc_data_exchange(dev, local->target_idx, skb, + nfc_llcp_recv, local); +out: + nfc_llcp_local_put(local); + return err; +} + +int nfc_llcp_send_connect(struct nfc_llcp_sock *sock) +{ + struct nfc_llcp_local *local; + struct sk_buff *skb; + const u8 *service_name_tlv = NULL; + const u8 *miux_tlv = NULL; + const u8 *rw_tlv = NULL; + u8 service_name_tlv_length = 0; + u8 miux_tlv_length, rw_tlv_length, rw; + int err; + u16 size = 0; + __be16 miux; + + local = sock->local; + if (local == NULL) + return -ENODEV; + + if (sock->service_name != NULL) { + service_name_tlv = nfc_llcp_build_tlv(LLCP_TLV_SN, + sock->service_name, + sock->service_name_len, + &service_name_tlv_length); + if (!service_name_tlv) { + err = -ENOMEM; + goto error_tlv; + } + size += service_name_tlv_length; + } + + /* If the socket parameters are not set, use the local ones */ + miux = be16_to_cpu(sock->miux) > LLCP_MAX_MIUX ? + local->miux : sock->miux; + rw = sock->rw > LLCP_MAX_RW ? local->rw : sock->rw; + + miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&miux, 0, + &miux_tlv_length); + if (!miux_tlv) { + err = -ENOMEM; + goto error_tlv; + } + size += miux_tlv_length; + + rw_tlv = nfc_llcp_build_tlv(LLCP_TLV_RW, &rw, 0, &rw_tlv_length); + if (!rw_tlv) { + err = -ENOMEM; + goto error_tlv; + } + size += rw_tlv_length; + + pr_debug("SKB size %d SN length %zu\n", size, sock->service_name_len); + + skb = llcp_allocate_pdu(sock, LLCP_PDU_CONNECT, size); + if (skb == NULL) { + err = -ENOMEM; + goto error_tlv; + } + + llcp_add_tlv(skb, service_name_tlv, service_name_tlv_length); + llcp_add_tlv(skb, miux_tlv, miux_tlv_length); + llcp_add_tlv(skb, rw_tlv, rw_tlv_length); + + skb_queue_tail(&local->tx_queue, skb); + + err = 0; + +error_tlv: + if (err) + pr_err("error %d\n", err); + + kfree(service_name_tlv); + kfree(miux_tlv); + kfree(rw_tlv); + + return err; +} + +int nfc_llcp_send_cc(struct nfc_llcp_sock *sock) +{ + struct nfc_llcp_local *local; + struct sk_buff *skb; + const u8 *miux_tlv = NULL; + const u8 *rw_tlv = NULL; + u8 miux_tlv_length, rw_tlv_length, rw; + int err; + u16 size = 0; + __be16 miux; + + local = sock->local; + if (local == NULL) + return -ENODEV; + + /* If the socket parameters are not set, use the local ones */ + miux = be16_to_cpu(sock->miux) > LLCP_MAX_MIUX ? + local->miux : sock->miux; + rw = sock->rw > LLCP_MAX_RW ? 
local->rw : sock->rw; + + miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&miux, 0, + &miux_tlv_length); + if (!miux_tlv) { + err = -ENOMEM; + goto error_tlv; + } + size += miux_tlv_length; + + rw_tlv = nfc_llcp_build_tlv(LLCP_TLV_RW, &rw, 0, &rw_tlv_length); + if (!rw_tlv) { + err = -ENOMEM; + goto error_tlv; + } + size += rw_tlv_length; + + skb = llcp_allocate_pdu(sock, LLCP_PDU_CC, size); + if (skb == NULL) { + err = -ENOMEM; + goto error_tlv; + } + + llcp_add_tlv(skb, miux_tlv, miux_tlv_length); + llcp_add_tlv(skb, rw_tlv, rw_tlv_length); + + skb_queue_tail(&local->tx_queue, skb); + + err = 0; + +error_tlv: + if (err) + pr_err("error %d\n", err); + + kfree(miux_tlv); + kfree(rw_tlv); + + return err; +} + +static struct sk_buff *nfc_llcp_allocate_snl(struct nfc_llcp_local *local, + size_t tlv_length) +{ + struct sk_buff *skb; + struct nfc_dev *dev; + u16 size = 0; + + if (local == NULL) + return ERR_PTR(-ENODEV); + + dev = local->dev; + if (dev == NULL) + return ERR_PTR(-ENODEV); + + size += LLCP_HEADER_SIZE; + size += dev->tx_headroom + dev->tx_tailroom + NFC_HEADER_SIZE; + size += tlv_length; + + skb = alloc_skb(size, GFP_KERNEL); + if (skb == NULL) + return ERR_PTR(-ENOMEM); + + skb_reserve(skb, dev->tx_headroom + NFC_HEADER_SIZE); + + skb = llcp_add_header(skb, LLCP_SAP_SDP, LLCP_SAP_SDP, LLCP_PDU_SNL); + + return skb; +} + +int nfc_llcp_send_snl_sdres(struct nfc_llcp_local *local, + struct hlist_head *tlv_list, size_t tlvs_len) +{ + struct nfc_llcp_sdp_tlv *sdp; + struct hlist_node *n; + struct sk_buff *skb; + + skb = nfc_llcp_allocate_snl(local, tlvs_len); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + hlist_for_each_entry_safe(sdp, n, tlv_list, node) { + skb_put_data(skb, sdp->tlv, sdp->tlv_len); + + hlist_del(&sdp->node); + + nfc_llcp_free_sdp_tlv(sdp); + } + + skb_queue_tail(&local->tx_queue, skb); + + return 0; +} + +int nfc_llcp_send_snl_sdreq(struct nfc_llcp_local *local, + struct hlist_head *tlv_list, size_t tlvs_len) +{ + struct nfc_llcp_sdp_tlv *sdreq; + struct hlist_node *n; + struct sk_buff *skb; + + skb = nfc_llcp_allocate_snl(local, tlvs_len); + if (IS_ERR(skb)) + return PTR_ERR(skb); + + mutex_lock(&local->sdreq_lock); + + if (hlist_empty(&local->pending_sdreqs)) + mod_timer(&local->sdreq_timer, + jiffies + msecs_to_jiffies(3 * local->remote_lto)); + + hlist_for_each_entry_safe(sdreq, n, tlv_list, node) { + pr_debug("tid %d for %s\n", sdreq->tid, sdreq->uri); + + skb_put_data(skb, sdreq->tlv, sdreq->tlv_len); + + hlist_del(&sdreq->node); + + hlist_add_head(&sdreq->node, &local->pending_sdreqs); + } + + mutex_unlock(&local->sdreq_lock); + + skb_queue_tail(&local->tx_queue, skb); + + return 0; +} + +int nfc_llcp_send_dm(struct nfc_llcp_local *local, u8 ssap, u8 dsap, u8 reason) +{ + struct sk_buff *skb; + struct nfc_dev *dev; + u16 size = 1; /* Reason code */ + + pr_debug("Sending DM reason 0x%x\n", reason); + + if (local == NULL) + return -ENODEV; + + dev = local->dev; + if (dev == NULL) + return -ENODEV; + + size += LLCP_HEADER_SIZE; + size += dev->tx_headroom + dev->tx_tailroom + NFC_HEADER_SIZE; + + skb = alloc_skb(size, GFP_KERNEL); + if (skb == NULL) + return -ENOMEM; + + skb_reserve(skb, dev->tx_headroom + NFC_HEADER_SIZE); + + skb = llcp_add_header(skb, dsap, ssap, LLCP_PDU_DM); + + skb_put_data(skb, &reason, 1); + + skb_queue_head(&local->tx_queue, skb); + + return 0; +} + +int nfc_llcp_send_i_frame(struct nfc_llcp_sock *sock, + struct msghdr *msg, size_t len) +{ + struct sk_buff *pdu; + struct sock *sk = &sock->sk; + struct nfc_llcp_local *local; + size_t 
frag_len = 0, remaining_len; + u8 *msg_data, *msg_ptr; + u16 remote_miu; + + pr_debug("Send I frame len %zd\n", len); + + local = sock->local; + if (local == NULL) + return -ENODEV; + + /* Remote is ready but has not acknowledged our frames */ + if((sock->remote_ready && + skb_queue_len(&sock->tx_pending_queue) >= sock->remote_rw && + skb_queue_len(&sock->tx_queue) >= 2 * sock->remote_rw)) { + pr_err("Pending queue is full %d frames\n", + skb_queue_len(&sock->tx_pending_queue)); + return -ENOBUFS; + } + + /* Remote is not ready and we've been queueing enough frames */ + if ((!sock->remote_ready && + skb_queue_len(&sock->tx_queue) >= 2 * sock->remote_rw)) { + pr_err("Tx queue is full %d frames\n", + skb_queue_len(&sock->tx_queue)); + return -ENOBUFS; + } + + msg_data = kmalloc(len, GFP_USER | __GFP_NOWARN); + if (msg_data == NULL) + return -ENOMEM; + + if (memcpy_from_msg(msg_data, msg, len)) { + kfree(msg_data); + return -EFAULT; + } + + remaining_len = len; + msg_ptr = msg_data; + + do { + remote_miu = sock->remote_miu > LLCP_MAX_MIU ? + LLCP_DEFAULT_MIU : sock->remote_miu; + + frag_len = min_t(size_t, remote_miu, remaining_len); + + pr_debug("Fragment %zd bytes remaining %zd", + frag_len, remaining_len); + + pdu = llcp_allocate_pdu(sock, LLCP_PDU_I, + frag_len + LLCP_SEQUENCE_SIZE); + if (pdu == NULL) { + kfree(msg_data); + return -ENOMEM; + } + + skb_put(pdu, LLCP_SEQUENCE_SIZE); + + if (likely(frag_len > 0)) + skb_put_data(pdu, msg_ptr, frag_len); + + skb_queue_tail(&sock->tx_queue, pdu); + + lock_sock(sk); + + nfc_llcp_queue_i_frames(sock); + + release_sock(sk); + + remaining_len -= frag_len; + msg_ptr += frag_len; + } while (remaining_len > 0); + + kfree(msg_data); + + return len; +} + +int nfc_llcp_send_ui_frame(struct nfc_llcp_sock *sock, u8 ssap, u8 dsap, + struct msghdr *msg, size_t len) +{ + struct sk_buff *pdu; + struct nfc_llcp_local *local; + size_t frag_len = 0, remaining_len; + u8 *msg_ptr, *msg_data; + u16 remote_miu; + int err; + + pr_debug("Send UI frame len %zd\n", len); + + local = sock->local; + if (local == NULL) + return -ENODEV; + + msg_data = kmalloc(len, GFP_USER | __GFP_NOWARN); + if (msg_data == NULL) + return -ENOMEM; + + if (memcpy_from_msg(msg_data, msg, len)) { + kfree(msg_data); + return -EFAULT; + } + + remaining_len = len; + msg_ptr = msg_data; + + do { + remote_miu = sock->remote_miu > LLCP_MAX_MIU ? 
+ local->remote_miu : sock->remote_miu; + + frag_len = min_t(size_t, remote_miu, remaining_len); + + pr_debug("Fragment %zd bytes remaining %zd", + frag_len, remaining_len); + + pdu = nfc_alloc_send_skb(sock->dev, &sock->sk, 0, + frag_len + LLCP_HEADER_SIZE, &err); + if (pdu == NULL) { + pr_err("Could not allocate PDU (error=%d)\n", err); + len -= remaining_len; + if (len == 0) + len = err; + break; + } + + pdu = llcp_add_header(pdu, dsap, ssap, LLCP_PDU_UI); + + if (likely(frag_len > 0)) + skb_put_data(pdu, msg_ptr, frag_len); + + /* No need to check for the peer RW for UI frames */ + skb_queue_tail(&local->tx_queue, pdu); + + remaining_len -= frag_len; + msg_ptr += frag_len; + } while (remaining_len > 0); + + kfree(msg_data); + + return len; +} + +int nfc_llcp_send_rr(struct nfc_llcp_sock *sock) +{ + struct sk_buff *skb; + struct nfc_llcp_local *local; + + pr_debug("Send rr nr %d\n", sock->recv_n); + + local = sock->local; + if (local == NULL) + return -ENODEV; + + skb = llcp_allocate_pdu(sock, LLCP_PDU_RR, LLCP_SEQUENCE_SIZE); + if (skb == NULL) + return -ENOMEM; + + skb_put(skb, LLCP_SEQUENCE_SIZE); + + skb->data[2] = sock->recv_n; + + skb_queue_head(&local->tx_queue, skb); + + return 0; +} diff --git a/net/nfc/llcp_core.c b/net/nfc/llcp_core.c new file mode 100644 index 000000000..18be13fb9 --- /dev/null +++ b/net/nfc/llcp_core.c @@ -0,0 +1,1695 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Intel Corporation. All rights reserved. + * Copyright (C) 2014 Marvell International Ltd. + */ + +#define pr_fmt(fmt) "llcp: %s: " fmt, __func__ + +#include +#include +#include +#include + +#include "nfc.h" +#include "llcp.h" + +static u8 llcp_magic[3] = {0x46, 0x66, 0x6d}; + +static LIST_HEAD(llcp_devices); +/* Protects llcp_devices list */ +static DEFINE_SPINLOCK(llcp_devices_lock); + +static void nfc_llcp_rx_skb(struct nfc_llcp_local *local, struct sk_buff *skb); + +void nfc_llcp_sock_link(struct llcp_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_add_node(sk, &l->head); + write_unlock(&l->lock); +} + +void nfc_llcp_sock_unlink(struct llcp_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_del_node_init(sk); + write_unlock(&l->lock); +} + +void nfc_llcp_socket_remote_param_init(struct nfc_llcp_sock *sock) +{ + sock->remote_rw = LLCP_DEFAULT_RW; + sock->remote_miu = LLCP_MAX_MIU + 1; +} + +static void nfc_llcp_socket_purge(struct nfc_llcp_sock *sock) +{ + struct nfc_llcp_local *local = sock->local; + struct sk_buff *s, *tmp; + + skb_queue_purge(&sock->tx_queue); + skb_queue_purge(&sock->tx_pending_queue); + + if (local == NULL) + return; + + /* Search for local pending SKBs that are related to this socket */ + skb_queue_walk_safe(&local->tx_queue, s, tmp) { + if (s->sk != &sock->sk) + continue; + + skb_unlink(s, &local->tx_queue); + kfree_skb(s); + } +} + +static void nfc_llcp_socket_release(struct nfc_llcp_local *local, bool device, + int err) +{ + struct sock *sk; + struct hlist_node *tmp; + struct nfc_llcp_sock *llcp_sock; + + skb_queue_purge(&local->tx_queue); + + write_lock(&local->sockets.lock); + + sk_for_each_safe(sk, tmp, &local->sockets.head) { + llcp_sock = nfc_llcp_sock(sk); + + bh_lock_sock(sk); + + nfc_llcp_socket_purge(llcp_sock); + + if (sk->sk_state == LLCP_CONNECTED) + nfc_put_device(llcp_sock->dev); + + if (sk->sk_state == LLCP_LISTEN) { + struct nfc_llcp_sock *lsk, *n; + struct sock *accept_sk; + + list_for_each_entry_safe(lsk, n, + &llcp_sock->accept_queue, + accept_queue) { + accept_sk = &lsk->sk; + 
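+				/* Unlink and close every child socket still
+				 * sitting on the accept queue.
+				 */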
bh_lock_sock(accept_sk); + + nfc_llcp_accept_unlink(accept_sk); + + if (err) + accept_sk->sk_err = err; + accept_sk->sk_state = LLCP_CLOSED; + accept_sk->sk_state_change(sk); + + bh_unlock_sock(accept_sk); + } + } + + if (err) + sk->sk_err = err; + sk->sk_state = LLCP_CLOSED; + sk->sk_state_change(sk); + + bh_unlock_sock(sk); + + sk_del_node_init(sk); + } + + write_unlock(&local->sockets.lock); + + /* If we still have a device, we keep the RAW sockets alive */ + if (device == true) + return; + + write_lock(&local->raw_sockets.lock); + + sk_for_each_safe(sk, tmp, &local->raw_sockets.head) { + llcp_sock = nfc_llcp_sock(sk); + + bh_lock_sock(sk); + + nfc_llcp_socket_purge(llcp_sock); + + if (err) + sk->sk_err = err; + sk->sk_state = LLCP_CLOSED; + sk->sk_state_change(sk); + + bh_unlock_sock(sk); + + sk_del_node_init(sk); + } + + write_unlock(&local->raw_sockets.lock); +} + +static struct nfc_llcp_local *nfc_llcp_local_get(struct nfc_llcp_local *local) +{ + /* Since using nfc_llcp_local may result in usage of nfc_dev, whenever + * we hold a reference to local, we also need to hold a reference to + * the device to avoid UAF. + */ + if (!nfc_get_device(local->dev->idx)) + return NULL; + + kref_get(&local->ref); + + return local; +} + +static void local_cleanup(struct nfc_llcp_local *local) +{ + nfc_llcp_socket_release(local, false, ENXIO); + del_timer_sync(&local->link_timer); + skb_queue_purge(&local->tx_queue); + cancel_work_sync(&local->tx_work); + cancel_work_sync(&local->rx_work); + cancel_work_sync(&local->timeout_work); + kfree_skb(local->rx_pending); + local->rx_pending = NULL; + del_timer_sync(&local->sdreq_timer); + cancel_work_sync(&local->sdreq_timeout_work); + nfc_llcp_free_sdp_tlv_list(&local->pending_sdreqs); +} + +static void local_release(struct kref *ref) +{ + struct nfc_llcp_local *local; + + local = container_of(ref, struct nfc_llcp_local, ref); + + local_cleanup(local); + kfree(local); +} + +int nfc_llcp_local_put(struct nfc_llcp_local *local) +{ + struct nfc_dev *dev; + int ret; + + if (local == NULL) + return 0; + + dev = local->dev; + + ret = kref_put(&local->ref, local_release); + nfc_put_device(dev); + + return ret; +} + +static struct nfc_llcp_sock *nfc_llcp_sock_get(struct nfc_llcp_local *local, + u8 ssap, u8 dsap) +{ + struct sock *sk; + struct nfc_llcp_sock *llcp_sock, *tmp_sock; + + pr_debug("ssap dsap %d %d\n", ssap, dsap); + + if (ssap == 0 && dsap == 0) + return NULL; + + read_lock(&local->sockets.lock); + + llcp_sock = NULL; + + sk_for_each(sk, &local->sockets.head) { + tmp_sock = nfc_llcp_sock(sk); + + if (tmp_sock->ssap == ssap && tmp_sock->dsap == dsap) { + llcp_sock = tmp_sock; + sock_hold(&llcp_sock->sk); + break; + } + } + + read_unlock(&local->sockets.lock); + + return llcp_sock; +} + +static void nfc_llcp_sock_put(struct nfc_llcp_sock *sock) +{ + sock_put(&sock->sk); +} + +static void nfc_llcp_timeout_work(struct work_struct *work) +{ + struct nfc_llcp_local *local = container_of(work, struct nfc_llcp_local, + timeout_work); + + nfc_dep_link_down(local->dev); +} + +static void nfc_llcp_symm_timer(struct timer_list *t) +{ + struct nfc_llcp_local *local = from_timer(local, t, link_timer); + + pr_err("SYMM timeout\n"); + + schedule_work(&local->timeout_work); +} + +static void nfc_llcp_sdreq_timeout_work(struct work_struct *work) +{ + unsigned long time; + HLIST_HEAD(nl_sdres_list); + struct hlist_node *n; + struct nfc_llcp_sdp_tlv *sdp; + struct nfc_llcp_local *local = container_of(work, struct nfc_llcp_local, + sdreq_timeout_work); + + 
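+	/* Expire pending SDREQs older than 3 * remote LTO, mark them
+	 * unbound and report them to user space via netlink.
+	 */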
mutex_lock(&local->sdreq_lock); + + time = jiffies - msecs_to_jiffies(3 * local->remote_lto); + + hlist_for_each_entry_safe(sdp, n, &local->pending_sdreqs, node) { + if (time_after(sdp->time, time)) + continue; + + sdp->sap = LLCP_SDP_UNBOUND; + + hlist_del(&sdp->node); + + hlist_add_head(&sdp->node, &nl_sdres_list); + } + + if (!hlist_empty(&local->pending_sdreqs)) + mod_timer(&local->sdreq_timer, + jiffies + msecs_to_jiffies(3 * local->remote_lto)); + + mutex_unlock(&local->sdreq_lock); + + if (!hlist_empty(&nl_sdres_list)) + nfc_genl_llc_send_sdres(local->dev, &nl_sdres_list); +} + +static void nfc_llcp_sdreq_timer(struct timer_list *t) +{ + struct nfc_llcp_local *local = from_timer(local, t, sdreq_timer); + + schedule_work(&local->sdreq_timeout_work); +} + +struct nfc_llcp_local *nfc_llcp_find_local(struct nfc_dev *dev) +{ + struct nfc_llcp_local *local; + struct nfc_llcp_local *res = NULL; + + spin_lock(&llcp_devices_lock); + list_for_each_entry(local, &llcp_devices, list) + if (local->dev == dev) { + res = nfc_llcp_local_get(local); + break; + } + spin_unlock(&llcp_devices_lock); + + return res; +} + +static struct nfc_llcp_local *nfc_llcp_remove_local(struct nfc_dev *dev) +{ + struct nfc_llcp_local *local, *tmp; + + spin_lock(&llcp_devices_lock); + list_for_each_entry_safe(local, tmp, &llcp_devices, list) + if (local->dev == dev) { + list_del(&local->list); + spin_unlock(&llcp_devices_lock); + return local; + } + spin_unlock(&llcp_devices_lock); + + pr_warn("Shutting down device not found\n"); + + return NULL; +} + +static char *wks[] = { + NULL, + NULL, /* SDP */ + "urn:nfc:sn:ip", + "urn:nfc:sn:obex", + "urn:nfc:sn:snep", +}; + +static int nfc_llcp_wks_sap(const char *service_name, size_t service_name_len) +{ + int sap, num_wks; + + pr_debug("%s\n", service_name); + + if (service_name == NULL) + return -EINVAL; + + num_wks = ARRAY_SIZE(wks); + + for (sap = 0; sap < num_wks; sap++) { + if (wks[sap] == NULL) + continue; + + if (strncmp(wks[sap], service_name, service_name_len) == 0) + return sap; + } + + return -EINVAL; +} + +static +struct nfc_llcp_sock *nfc_llcp_sock_from_sn(struct nfc_llcp_local *local, + const u8 *sn, size_t sn_len, + bool needref) +{ + struct sock *sk; + struct nfc_llcp_sock *llcp_sock, *tmp_sock; + + pr_debug("sn %zd %p\n", sn_len, sn); + + if (sn == NULL || sn_len == 0) + return NULL; + + read_lock(&local->sockets.lock); + + llcp_sock = NULL; + + sk_for_each(sk, &local->sockets.head) { + tmp_sock = nfc_llcp_sock(sk); + + pr_debug("llcp sock %p\n", tmp_sock); + + if (tmp_sock->sk.sk_type == SOCK_STREAM && + tmp_sock->sk.sk_state != LLCP_LISTEN) + continue; + + if (tmp_sock->sk.sk_type == SOCK_DGRAM && + tmp_sock->sk.sk_state != LLCP_BOUND) + continue; + + if (tmp_sock->service_name == NULL || + tmp_sock->service_name_len == 0) + continue; + + if (tmp_sock->service_name_len != sn_len) + continue; + + if (memcmp(sn, tmp_sock->service_name, sn_len) == 0) { + llcp_sock = tmp_sock; + if (needref) + sock_hold(&llcp_sock->sk); + break; + } + } + + read_unlock(&local->sockets.lock); + + pr_debug("Found llcp sock %p\n", llcp_sock); + + return llcp_sock; +} + +u8 nfc_llcp_get_sdp_ssap(struct nfc_llcp_local *local, + struct nfc_llcp_sock *sock) +{ + mutex_lock(&local->sdp_lock); + + if (sock->service_name != NULL && sock->service_name_len > 0) { + int ssap = nfc_llcp_wks_sap(sock->service_name, + sock->service_name_len); + + if (ssap > 0) { + pr_debug("WKS %d\n", ssap); + + /* This is a WKS, let's check if it's free */ + if (test_bit(ssap, &local->local_wks)) { + 
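+				/* Well-known SAP already bound by another socket */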
mutex_unlock(&local->sdp_lock); + + return LLCP_SAP_MAX; + } + + set_bit(ssap, &local->local_wks); + mutex_unlock(&local->sdp_lock); + + return ssap; + } + + /* + * Check if there already is a non WKS socket bound + * to this service name. + */ + if (nfc_llcp_sock_from_sn(local, sock->service_name, + sock->service_name_len, + false) != NULL) { + mutex_unlock(&local->sdp_lock); + + return LLCP_SAP_MAX; + } + + mutex_unlock(&local->sdp_lock); + + return LLCP_SDP_UNBOUND; + + } else if (sock->ssap != 0 && sock->ssap < LLCP_WKS_NUM_SAP) { + if (!test_bit(sock->ssap, &local->local_wks)) { + set_bit(sock->ssap, &local->local_wks); + mutex_unlock(&local->sdp_lock); + + return sock->ssap; + } + } + + mutex_unlock(&local->sdp_lock); + + return LLCP_SAP_MAX; +} + +u8 nfc_llcp_get_local_ssap(struct nfc_llcp_local *local) +{ + u8 local_ssap; + + mutex_lock(&local->sdp_lock); + + local_ssap = find_first_zero_bit(&local->local_sap, LLCP_LOCAL_NUM_SAP); + if (local_ssap == LLCP_LOCAL_NUM_SAP) { + mutex_unlock(&local->sdp_lock); + return LLCP_SAP_MAX; + } + + set_bit(local_ssap, &local->local_sap); + + mutex_unlock(&local->sdp_lock); + + return local_ssap + LLCP_LOCAL_SAP_OFFSET; +} + +void nfc_llcp_put_ssap(struct nfc_llcp_local *local, u8 ssap) +{ + u8 local_ssap; + unsigned long *sdp; + + if (ssap < LLCP_WKS_NUM_SAP) { + local_ssap = ssap; + sdp = &local->local_wks; + } else if (ssap < LLCP_LOCAL_NUM_SAP) { + atomic_t *client_cnt; + + local_ssap = ssap - LLCP_WKS_NUM_SAP; + sdp = &local->local_sdp; + client_cnt = &local->local_sdp_cnt[local_ssap]; + + pr_debug("%d clients\n", atomic_read(client_cnt)); + + mutex_lock(&local->sdp_lock); + + if (atomic_dec_and_test(client_cnt)) { + struct nfc_llcp_sock *l_sock; + + pr_debug("No more clients for SAP %d\n", ssap); + + clear_bit(local_ssap, sdp); + + /* Find the listening sock and set it back to UNBOUND */ + l_sock = nfc_llcp_sock_get(local, ssap, LLCP_SAP_SDP); + if (l_sock) { + l_sock->ssap = LLCP_SDP_UNBOUND; + nfc_llcp_sock_put(l_sock); + } + } + + mutex_unlock(&local->sdp_lock); + + return; + } else if (ssap < LLCP_MAX_SAP) { + local_ssap = ssap - LLCP_LOCAL_NUM_SAP; + sdp = &local->local_sap; + } else { + return; + } + + mutex_lock(&local->sdp_lock); + + clear_bit(local_ssap, sdp); + + mutex_unlock(&local->sdp_lock); +} + +static u8 nfc_llcp_reserve_sdp_ssap(struct nfc_llcp_local *local) +{ + u8 ssap; + + mutex_lock(&local->sdp_lock); + + ssap = find_first_zero_bit(&local->local_sdp, LLCP_SDP_NUM_SAP); + if (ssap == LLCP_SDP_NUM_SAP) { + mutex_unlock(&local->sdp_lock); + + return LLCP_SAP_MAX; + } + + pr_debug("SDP ssap %d\n", LLCP_WKS_NUM_SAP + ssap); + + set_bit(ssap, &local->local_sdp); + + mutex_unlock(&local->sdp_lock); + + return LLCP_WKS_NUM_SAP + ssap; +} + +static int nfc_llcp_build_gb(struct nfc_llcp_local *local) +{ + u8 *gb_cur, version, version_length; + u8 lto_length, wks_length, miux_length; + const u8 *version_tlv = NULL, *lto_tlv = NULL, + *wks_tlv = NULL, *miux_tlv = NULL; + __be16 wks = cpu_to_be16(local->local_wks); + u8 gb_len = 0; + int ret = 0; + + version = LLCP_VERSION_11; + version_tlv = nfc_llcp_build_tlv(LLCP_TLV_VERSION, &version, + 1, &version_length); + if (!version_tlv) { + ret = -ENOMEM; + goto out; + } + gb_len += version_length; + + lto_tlv = nfc_llcp_build_tlv(LLCP_TLV_LTO, &local->lto, 1, <o_length); + if (!lto_tlv) { + ret = -ENOMEM; + goto out; + } + gb_len += lto_length; + + pr_debug("Local wks 0x%lx\n", local->local_wks); + wks_tlv = nfc_llcp_build_tlv(LLCP_TLV_WKS, (u8 *)&wks, 2, &wks_length); + if (!wks_tlv) { 
+ ret = -ENOMEM; + goto out; + } + gb_len += wks_length; + + miux_tlv = nfc_llcp_build_tlv(LLCP_TLV_MIUX, (u8 *)&local->miux, 0, + &miux_length); + if (!miux_tlv) { + ret = -ENOMEM; + goto out; + } + gb_len += miux_length; + + gb_len += ARRAY_SIZE(llcp_magic); + + if (gb_len > NFC_MAX_GT_LEN) { + ret = -EINVAL; + goto out; + } + + gb_cur = local->gb; + + memcpy(gb_cur, llcp_magic, ARRAY_SIZE(llcp_magic)); + gb_cur += ARRAY_SIZE(llcp_magic); + + memcpy(gb_cur, version_tlv, version_length); + gb_cur += version_length; + + memcpy(gb_cur, lto_tlv, lto_length); + gb_cur += lto_length; + + memcpy(gb_cur, wks_tlv, wks_length); + gb_cur += wks_length; + + memcpy(gb_cur, miux_tlv, miux_length); + gb_cur += miux_length; + + local->gb_len = gb_len; + +out: + kfree(version_tlv); + kfree(lto_tlv); + kfree(wks_tlv); + kfree(miux_tlv); + + return ret; +} + +u8 *nfc_llcp_general_bytes(struct nfc_dev *dev, size_t *general_bytes_len) +{ + struct nfc_llcp_local *local; + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + *general_bytes_len = 0; + return NULL; + } + + nfc_llcp_build_gb(local); + + *general_bytes_len = local->gb_len; + + nfc_llcp_local_put(local); + + return local->gb; +} + +int nfc_llcp_set_remote_gb(struct nfc_dev *dev, const u8 *gb, u8 gb_len) +{ + struct nfc_llcp_local *local; + int err; + + if (gb_len < 3 || gb_len > NFC_MAX_GT_LEN) + return -EINVAL; + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + pr_err("No LLCP device\n"); + return -ENODEV; + } + + memset(local->remote_gb, 0, NFC_MAX_GT_LEN); + memcpy(local->remote_gb, gb, gb_len); + local->remote_gb_len = gb_len; + + if (memcmp(local->remote_gb, llcp_magic, 3)) { + pr_err("MAC does not support LLCP\n"); + err = -EINVAL; + goto out; + } + + err = nfc_llcp_parse_gb_tlv(local, + &local->remote_gb[3], + local->remote_gb_len - 3); +out: + nfc_llcp_local_put(local); + return err; +} + +static u8 nfc_llcp_dsap(const struct sk_buff *pdu) +{ + return (pdu->data[0] & 0xfc) >> 2; +} + +static u8 nfc_llcp_ptype(const struct sk_buff *pdu) +{ + return ((pdu->data[0] & 0x03) << 2) | ((pdu->data[1] & 0xc0) >> 6); +} + +static u8 nfc_llcp_ssap(const struct sk_buff *pdu) +{ + return pdu->data[1] & 0x3f; +} + +static u8 nfc_llcp_ns(const struct sk_buff *pdu) +{ + return pdu->data[2] >> 4; +} + +static u8 nfc_llcp_nr(const struct sk_buff *pdu) +{ + return pdu->data[2] & 0xf; +} + +static void nfc_llcp_set_nrns(struct nfc_llcp_sock *sock, struct sk_buff *pdu) +{ + pdu->data[2] = (sock->send_n << 4) | (sock->recv_n); + sock->send_n = (sock->send_n + 1) % 16; + sock->recv_ack_n = (sock->recv_n - 1) % 16; +} + +void nfc_llcp_send_to_raw_sock(struct nfc_llcp_local *local, + struct sk_buff *skb, u8 direction) +{ + struct sk_buff *skb_copy = NULL, *nskb; + struct sock *sk; + u8 *data; + + read_lock(&local->raw_sockets.lock); + + sk_for_each(sk, &local->raw_sockets.head) { + if (sk->sk_state != LLCP_BOUND) + continue; + + if (skb_copy == NULL) { + skb_copy = __pskb_copy_fclone(skb, NFC_RAW_HEADER_SIZE, + GFP_ATOMIC, true); + + if (skb_copy == NULL) + continue; + + data = skb_push(skb_copy, NFC_RAW_HEADER_SIZE); + + data[0] = local->dev ? 
local->dev->idx : 0xFF; + data[1] = direction & 0x01; + data[1] |= (RAW_PAYLOAD_LLCP << 1); + } + + nskb = skb_clone(skb_copy, GFP_ATOMIC); + if (!nskb) + continue; + + if (sock_queue_rcv_skb(sk, nskb)) + kfree_skb(nskb); + } + + read_unlock(&local->raw_sockets.lock); + + kfree_skb(skb_copy); +} + +static void nfc_llcp_tx_work(struct work_struct *work) +{ + struct nfc_llcp_local *local = container_of(work, struct nfc_llcp_local, + tx_work); + struct sk_buff *skb; + struct sock *sk; + struct nfc_llcp_sock *llcp_sock; + + skb = skb_dequeue(&local->tx_queue); + if (skb != NULL) { + sk = skb->sk; + llcp_sock = nfc_llcp_sock(sk); + + if (llcp_sock == NULL && nfc_llcp_ptype(skb) == LLCP_PDU_I) { + kfree_skb(skb); + nfc_llcp_send_symm(local->dev); + } else if (llcp_sock && !llcp_sock->remote_ready) { + skb_queue_head(&local->tx_queue, skb); + nfc_llcp_send_symm(local->dev); + } else { + struct sk_buff *copy_skb = NULL; + u8 ptype = nfc_llcp_ptype(skb); + int ret; + + pr_debug("Sending pending skb\n"); + print_hex_dump_debug("LLCP Tx: ", DUMP_PREFIX_OFFSET, + 16, 1, skb->data, skb->len, true); + + if (ptype == LLCP_PDU_I) + copy_skb = skb_copy(skb, GFP_ATOMIC); + + __net_timestamp(skb); + + nfc_llcp_send_to_raw_sock(local, skb, + NFC_DIRECTION_TX); + + ret = nfc_data_exchange(local->dev, local->target_idx, + skb, nfc_llcp_recv, local); + + if (ret) { + kfree_skb(copy_skb); + goto out; + } + + if (ptype == LLCP_PDU_I && copy_skb) + skb_queue_tail(&llcp_sock->tx_pending_queue, + copy_skb); + } + } else { + nfc_llcp_send_symm(local->dev); + } + +out: + mod_timer(&local->link_timer, + jiffies + msecs_to_jiffies(2 * local->remote_lto)); +} + +static struct nfc_llcp_sock *nfc_llcp_connecting_sock_get(struct nfc_llcp_local *local, + u8 ssap) +{ + struct sock *sk; + struct nfc_llcp_sock *llcp_sock; + + read_lock(&local->connecting_sockets.lock); + + sk_for_each(sk, &local->connecting_sockets.head) { + llcp_sock = nfc_llcp_sock(sk); + + if (llcp_sock->ssap == ssap) { + sock_hold(&llcp_sock->sk); + goto out; + } + } + + llcp_sock = NULL; + +out: + read_unlock(&local->connecting_sockets.lock); + + return llcp_sock; +} + +static struct nfc_llcp_sock *nfc_llcp_sock_get_sn(struct nfc_llcp_local *local, + const u8 *sn, size_t sn_len) +{ + return nfc_llcp_sock_from_sn(local, sn, sn_len, true); +} + +static const u8 *nfc_llcp_connect_sn(const struct sk_buff *skb, size_t *sn_len) +{ + u8 type, length; + const u8 *tlv = &skb->data[2]; + size_t tlv_array_len = skb->len - LLCP_HEADER_SIZE, offset = 0; + + while (offset < tlv_array_len) { + type = tlv[0]; + length = tlv[1]; + + pr_debug("type 0x%x length %d\n", type, length); + + if (type == LLCP_TLV_SN) { + *sn_len = length; + return &tlv[2]; + } + + offset += length + 2; + tlv += length + 2; + } + + return NULL; +} + +static void nfc_llcp_recv_ui(struct nfc_llcp_local *local, + struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + struct nfc_llcp_ui_cb *ui_cb; + u8 dsap, ssap; + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + ui_cb = nfc_llcp_ui_skb_cb(skb); + ui_cb->dsap = dsap; + ui_cb->ssap = ssap; + + pr_debug("%d %d\n", dsap, ssap); + + /* We're looking for a bound socket, not a client one */ + llcp_sock = nfc_llcp_sock_get(local, dsap, LLCP_SAP_SDP); + if (llcp_sock == NULL || llcp_sock->sk.sk_type != SOCK_DGRAM) + return; + + /* There is no sequence with UI frames */ + skb_pull(skb, LLCP_HEADER_SIZE); + if (!sock_queue_rcv_skb(&llcp_sock->sk, skb)) { + /* + * UI frames will be freed from the socket layer, so we + * need to keep them alive 
until someone receives them. + */ + skb_get(skb); + } else { + pr_err("Receive queue is full\n"); + } + + nfc_llcp_sock_put(llcp_sock); +} + +static void nfc_llcp_recv_connect(struct nfc_llcp_local *local, + const struct sk_buff *skb) +{ + struct sock *new_sk, *parent; + struct nfc_llcp_sock *sock, *new_sock; + u8 dsap, ssap, reason; + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + pr_debug("%d %d\n", dsap, ssap); + + if (dsap != LLCP_SAP_SDP) { + sock = nfc_llcp_sock_get(local, dsap, LLCP_SAP_SDP); + if (sock == NULL || sock->sk.sk_state != LLCP_LISTEN) { + reason = LLCP_DM_NOBOUND; + goto fail; + } + } else { + const u8 *sn; + size_t sn_len; + + sn = nfc_llcp_connect_sn(skb, &sn_len); + if (sn == NULL) { + reason = LLCP_DM_NOBOUND; + goto fail; + } + + pr_debug("Service name length %zu\n", sn_len); + + sock = nfc_llcp_sock_get_sn(local, sn, sn_len); + if (sock == NULL) { + reason = LLCP_DM_NOBOUND; + goto fail; + } + } + + lock_sock(&sock->sk); + + parent = &sock->sk; + + if (sk_acceptq_is_full(parent)) { + reason = LLCP_DM_REJ; + release_sock(&sock->sk); + sock_put(&sock->sk); + goto fail; + } + + if (sock->ssap == LLCP_SDP_UNBOUND) { + u8 ssap = nfc_llcp_reserve_sdp_ssap(local); + + pr_debug("First client, reserving %d\n", ssap); + + if (ssap == LLCP_SAP_MAX) { + reason = LLCP_DM_REJ; + release_sock(&sock->sk); + sock_put(&sock->sk); + goto fail; + } + + sock->ssap = ssap; + } + + new_sk = nfc_llcp_sock_alloc(NULL, parent->sk_type, GFP_ATOMIC, 0); + if (new_sk == NULL) { + reason = LLCP_DM_REJ; + release_sock(&sock->sk); + sock_put(&sock->sk); + goto fail; + } + + new_sock = nfc_llcp_sock(new_sk); + + new_sock->local = nfc_llcp_local_get(local); + if (!new_sock->local) { + reason = LLCP_DM_REJ; + sock_put(&new_sock->sk); + release_sock(&sock->sk); + sock_put(&sock->sk); + goto fail; + } + + new_sock->dev = local->dev; + new_sock->rw = sock->rw; + new_sock->miux = sock->miux; + new_sock->nfc_protocol = sock->nfc_protocol; + new_sock->dsap = ssap; + new_sock->target_idx = local->target_idx; + new_sock->parent = parent; + new_sock->ssap = sock->ssap; + if (sock->ssap < LLCP_LOCAL_NUM_SAP && sock->ssap >= LLCP_WKS_NUM_SAP) { + atomic_t *client_count; + + pr_debug("reserved_ssap %d for %p\n", sock->ssap, new_sock); + + client_count = + &local->local_sdp_cnt[sock->ssap - LLCP_WKS_NUM_SAP]; + + atomic_inc(client_count); + new_sock->reserved_ssap = sock->ssap; + } + + nfc_llcp_parse_connection_tlv(new_sock, &skb->data[LLCP_HEADER_SIZE], + skb->len - LLCP_HEADER_SIZE); + + pr_debug("new sock %p sk %p\n", new_sock, &new_sock->sk); + + nfc_llcp_sock_link(&local->sockets, new_sk); + + nfc_llcp_accept_enqueue(&sock->sk, new_sk); + + nfc_get_device(local->dev->idx); + + new_sk->sk_state = LLCP_CONNECTED; + + /* Wake the listening processes */ + parent->sk_data_ready(parent); + + /* Send CC */ + nfc_llcp_send_cc(new_sock); + + release_sock(&sock->sk); + sock_put(&sock->sk); + + return; + +fail: + /* Send DM */ + nfc_llcp_send_dm(local, dsap, ssap, reason); +} + +int nfc_llcp_queue_i_frames(struct nfc_llcp_sock *sock) +{ + int nr_frames = 0; + struct nfc_llcp_local *local = sock->local; + + pr_debug("Remote ready %d tx queue len %d remote rw %d", + sock->remote_ready, skb_queue_len(&sock->tx_pending_queue), + sock->remote_rw); + + /* Try to queue some I frames for transmission */ + while (sock->remote_ready && + skb_queue_len(&sock->tx_pending_queue) < sock->remote_rw) { + struct sk_buff *pdu; + + pdu = skb_dequeue(&sock->tx_queue); + if (pdu == NULL) + break; + + /* Update N(S)/N(R) */ 
+ nfc_llcp_set_nrns(sock, pdu); + + skb_queue_tail(&local->tx_queue, pdu); + nr_frames++; + } + + return nr_frames; +} + +static void nfc_llcp_recv_hdlc(struct nfc_llcp_local *local, + struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + struct sock *sk; + u8 dsap, ssap, ptype, ns, nr; + + ptype = nfc_llcp_ptype(skb); + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + ns = nfc_llcp_ns(skb); + nr = nfc_llcp_nr(skb); + + pr_debug("%d %d R %d S %d\n", dsap, ssap, nr, ns); + + llcp_sock = nfc_llcp_sock_get(local, dsap, ssap); + if (llcp_sock == NULL) { + nfc_llcp_send_dm(local, dsap, ssap, LLCP_DM_NOCONN); + return; + } + + sk = &llcp_sock->sk; + lock_sock(sk); + if (sk->sk_state == LLCP_CLOSED) { + release_sock(sk); + nfc_llcp_sock_put(llcp_sock); + } + + /* Pass the payload upstream */ + if (ptype == LLCP_PDU_I) { + pr_debug("I frame, queueing on %p\n", &llcp_sock->sk); + + if (ns == llcp_sock->recv_n) + llcp_sock->recv_n = (llcp_sock->recv_n + 1) % 16; + else + pr_err("Received out of sequence I PDU\n"); + + skb_pull(skb, LLCP_HEADER_SIZE + LLCP_SEQUENCE_SIZE); + if (!sock_queue_rcv_skb(&llcp_sock->sk, skb)) { + /* + * I frames will be freed from the socket layer, so we + * need to keep them alive until someone receives them. + */ + skb_get(skb); + } else { + pr_err("Receive queue is full\n"); + } + } + + /* Remove skbs from the pending queue */ + if (llcp_sock->send_ack_n != nr) { + struct sk_buff *s, *tmp; + u8 n; + + llcp_sock->send_ack_n = nr; + + /* Remove and free all skbs until ns == nr */ + skb_queue_walk_safe(&llcp_sock->tx_pending_queue, s, tmp) { + n = nfc_llcp_ns(s); + + skb_unlink(s, &llcp_sock->tx_pending_queue); + kfree_skb(s); + + if (n == nr) + break; + } + + /* Re-queue the remaining skbs for transmission */ + skb_queue_reverse_walk_safe(&llcp_sock->tx_pending_queue, + s, tmp) { + skb_unlink(s, &llcp_sock->tx_pending_queue); + skb_queue_head(&local->tx_queue, s); + } + } + + if (ptype == LLCP_PDU_RR) + llcp_sock->remote_ready = true; + else if (ptype == LLCP_PDU_RNR) + llcp_sock->remote_ready = false; + + if (nfc_llcp_queue_i_frames(llcp_sock) == 0 && ptype == LLCP_PDU_I) + nfc_llcp_send_rr(llcp_sock); + + release_sock(sk); + nfc_llcp_sock_put(llcp_sock); +} + +static void nfc_llcp_recv_disc(struct nfc_llcp_local *local, + const struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + struct sock *sk; + u8 dsap, ssap; + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + if ((dsap == 0) && (ssap == 0)) { + pr_debug("Connection termination"); + nfc_dep_link_down(local->dev); + return; + } + + llcp_sock = nfc_llcp_sock_get(local, dsap, ssap); + if (llcp_sock == NULL) { + nfc_llcp_send_dm(local, dsap, ssap, LLCP_DM_NOCONN); + return; + } + + sk = &llcp_sock->sk; + lock_sock(sk); + + nfc_llcp_socket_purge(llcp_sock); + + if (sk->sk_state == LLCP_CLOSED) { + release_sock(sk); + nfc_llcp_sock_put(llcp_sock); + } + + if (sk->sk_state == LLCP_CONNECTED) { + nfc_put_device(local->dev); + sk->sk_state = LLCP_CLOSED; + sk->sk_state_change(sk); + } + + nfc_llcp_send_dm(local, dsap, ssap, LLCP_DM_DISC); + + release_sock(sk); + nfc_llcp_sock_put(llcp_sock); +} + +static void nfc_llcp_recv_cc(struct nfc_llcp_local *local, + const struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + struct sock *sk; + u8 dsap, ssap; + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + llcp_sock = nfc_llcp_connecting_sock_get(local, dsap); + if (llcp_sock == NULL) { + pr_err("Invalid CC\n"); + nfc_llcp_send_dm(local, dsap, ssap, LLCP_DM_NOCONN); + + return; 
+ } + + sk = &llcp_sock->sk; + + /* Unlink from connecting and link to the client array */ + nfc_llcp_sock_unlink(&local->connecting_sockets, sk); + nfc_llcp_sock_link(&local->sockets, sk); + llcp_sock->dsap = ssap; + + nfc_llcp_parse_connection_tlv(llcp_sock, &skb->data[LLCP_HEADER_SIZE], + skb->len - LLCP_HEADER_SIZE); + + sk->sk_state = LLCP_CONNECTED; + sk->sk_state_change(sk); + + nfc_llcp_sock_put(llcp_sock); +} + +static void nfc_llcp_recv_dm(struct nfc_llcp_local *local, + const struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + struct sock *sk; + u8 dsap, ssap, reason; + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + reason = skb->data[2]; + + pr_debug("%d %d reason %d\n", ssap, dsap, reason); + + switch (reason) { + case LLCP_DM_NOBOUND: + case LLCP_DM_REJ: + llcp_sock = nfc_llcp_connecting_sock_get(local, dsap); + break; + + default: + llcp_sock = nfc_llcp_sock_get(local, dsap, ssap); + break; + } + + if (llcp_sock == NULL) { + pr_debug("Already closed\n"); + return; + } + + sk = &llcp_sock->sk; + + sk->sk_err = ENXIO; + sk->sk_state = LLCP_CLOSED; + sk->sk_state_change(sk); + + nfc_llcp_sock_put(llcp_sock); +} + +static void nfc_llcp_recv_snl(struct nfc_llcp_local *local, + const struct sk_buff *skb) +{ + struct nfc_llcp_sock *llcp_sock; + u8 dsap, ssap, type, length, tid, sap; + const u8 *tlv; + u16 tlv_len, offset; + const char *service_name; + size_t service_name_len; + struct nfc_llcp_sdp_tlv *sdp; + HLIST_HEAD(llc_sdres_list); + size_t sdres_tlvs_len; + HLIST_HEAD(nl_sdres_list); + + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + pr_debug("%d %d\n", dsap, ssap); + + if (dsap != LLCP_SAP_SDP || ssap != LLCP_SAP_SDP) { + pr_err("Wrong SNL SAP\n"); + return; + } + + tlv = &skb->data[LLCP_HEADER_SIZE]; + tlv_len = skb->len - LLCP_HEADER_SIZE; + offset = 0; + sdres_tlvs_len = 0; + + while (offset < tlv_len) { + type = tlv[0]; + length = tlv[1]; + + switch (type) { + case LLCP_TLV_SDREQ: + tid = tlv[2]; + service_name = (char *) &tlv[3]; + service_name_len = length - 1; + + pr_debug("Looking for %.16s\n", service_name); + + if (service_name_len == strlen("urn:nfc:sn:sdp") && + !strncmp(service_name, "urn:nfc:sn:sdp", + service_name_len)) { + sap = 1; + goto add_snl; + } + + llcp_sock = nfc_llcp_sock_from_sn(local, service_name, + service_name_len, + true); + if (!llcp_sock) { + sap = 0; + goto add_snl; + } + + /* + * We found a socket but its ssap has not been reserved + * yet. We need to assign it for good and send a reply. + * The ssap will be freed when the socket is closed. 
+ */ + if (llcp_sock->ssap == LLCP_SDP_UNBOUND) { + atomic_t *client_count; + + sap = nfc_llcp_reserve_sdp_ssap(local); + + pr_debug("Reserving %d\n", sap); + + if (sap == LLCP_SAP_MAX) { + sap = 0; + nfc_llcp_sock_put(llcp_sock); + goto add_snl; + } + + client_count = + &local->local_sdp_cnt[sap - + LLCP_WKS_NUM_SAP]; + + atomic_inc(client_count); + + llcp_sock->ssap = sap; + llcp_sock->reserved_ssap = sap; + } else { + sap = llcp_sock->ssap; + } + + pr_debug("%p %d\n", llcp_sock, sap); + + nfc_llcp_sock_put(llcp_sock); +add_snl: + sdp = nfc_llcp_build_sdres_tlv(tid, sap); + if (sdp == NULL) + goto exit; + + sdres_tlvs_len += sdp->tlv_len; + hlist_add_head(&sdp->node, &llc_sdres_list); + break; + + case LLCP_TLV_SDRES: + mutex_lock(&local->sdreq_lock); + + pr_debug("LLCP_TLV_SDRES: searching tid %d\n", tlv[2]); + + hlist_for_each_entry(sdp, &local->pending_sdreqs, node) { + if (sdp->tid != tlv[2]) + continue; + + sdp->sap = tlv[3]; + + pr_debug("Found: uri=%s, sap=%d\n", + sdp->uri, sdp->sap); + + hlist_del(&sdp->node); + + hlist_add_head(&sdp->node, &nl_sdres_list); + + break; + } + + mutex_unlock(&local->sdreq_lock); + break; + + default: + pr_err("Invalid SNL tlv value 0x%x\n", type); + break; + } + + offset += length + 2; + tlv += length + 2; + } + +exit: + if (!hlist_empty(&nl_sdres_list)) + nfc_genl_llc_send_sdres(local->dev, &nl_sdres_list); + + if (!hlist_empty(&llc_sdres_list)) + nfc_llcp_send_snl_sdres(local, &llc_sdres_list, sdres_tlvs_len); +} + +static void nfc_llcp_recv_agf(struct nfc_llcp_local *local, struct sk_buff *skb) +{ + u8 ptype; + u16 pdu_len; + struct sk_buff *new_skb; + + if (skb->len <= LLCP_HEADER_SIZE) { + pr_err("Malformed AGF PDU\n"); + return; + } + + skb_pull(skb, LLCP_HEADER_SIZE); + + while (skb->len > LLCP_AGF_PDU_HEADER_SIZE) { + pdu_len = skb->data[0] << 8 | skb->data[1]; + + skb_pull(skb, LLCP_AGF_PDU_HEADER_SIZE); + + if (pdu_len < LLCP_HEADER_SIZE || pdu_len > skb->len) { + pr_err("Malformed AGF PDU\n"); + return; + } + + ptype = nfc_llcp_ptype(skb); + + if (ptype == LLCP_PDU_SYMM || ptype == LLCP_PDU_AGF) + goto next; + + new_skb = nfc_alloc_recv_skb(pdu_len, GFP_KERNEL); + if (new_skb == NULL) { + pr_err("Could not allocate PDU\n"); + return; + } + + skb_put_data(new_skb, skb->data, pdu_len); + + nfc_llcp_rx_skb(local, new_skb); + + kfree_skb(new_skb); +next: + skb_pull(skb, pdu_len); + } +} + +static void nfc_llcp_rx_skb(struct nfc_llcp_local *local, struct sk_buff *skb) +{ + u8 dsap, ssap, ptype; + + ptype = nfc_llcp_ptype(skb); + dsap = nfc_llcp_dsap(skb); + ssap = nfc_llcp_ssap(skb); + + pr_debug("ptype 0x%x dsap 0x%x ssap 0x%x\n", ptype, dsap, ssap); + + if (ptype != LLCP_PDU_SYMM) + print_hex_dump_debug("LLCP Rx: ", DUMP_PREFIX_OFFSET, 16, 1, + skb->data, skb->len, true); + + switch (ptype) { + case LLCP_PDU_SYMM: + pr_debug("SYMM\n"); + break; + + case LLCP_PDU_UI: + pr_debug("UI\n"); + nfc_llcp_recv_ui(local, skb); + break; + + case LLCP_PDU_CONNECT: + pr_debug("CONNECT\n"); + nfc_llcp_recv_connect(local, skb); + break; + + case LLCP_PDU_DISC: + pr_debug("DISC\n"); + nfc_llcp_recv_disc(local, skb); + break; + + case LLCP_PDU_CC: + pr_debug("CC\n"); + nfc_llcp_recv_cc(local, skb); + break; + + case LLCP_PDU_DM: + pr_debug("DM\n"); + nfc_llcp_recv_dm(local, skb); + break; + + case LLCP_PDU_SNL: + pr_debug("SNL\n"); + nfc_llcp_recv_snl(local, skb); + break; + + case LLCP_PDU_I: + case LLCP_PDU_RR: + case LLCP_PDU_RNR: + pr_debug("I frame\n"); + nfc_llcp_recv_hdlc(local, skb); + break; + + case LLCP_PDU_AGF: + pr_debug("AGF frame\n"); + 
nfc_llcp_recv_agf(local, skb); + break; + } +} + +static void nfc_llcp_rx_work(struct work_struct *work) +{ + struct nfc_llcp_local *local = container_of(work, struct nfc_llcp_local, + rx_work); + struct sk_buff *skb; + + skb = local->rx_pending; + if (skb == NULL) { + pr_debug("No pending SKB\n"); + return; + } + + __net_timestamp(skb); + + nfc_llcp_send_to_raw_sock(local, skb, NFC_DIRECTION_RX); + + nfc_llcp_rx_skb(local, skb); + + schedule_work(&local->tx_work); + kfree_skb(local->rx_pending); + local->rx_pending = NULL; +} + +static void __nfc_llcp_recv(struct nfc_llcp_local *local, struct sk_buff *skb) +{ + local->rx_pending = skb; + del_timer(&local->link_timer); + schedule_work(&local->rx_work); +} + +void nfc_llcp_recv(void *data, struct sk_buff *skb, int err) +{ + struct nfc_llcp_local *local = (struct nfc_llcp_local *) data; + + if (err < 0) { + pr_err("LLCP PDU receive err %d\n", err); + return; + } + + __nfc_llcp_recv(local, skb); +} + +int nfc_llcp_data_received(struct nfc_dev *dev, struct sk_buff *skb) +{ + struct nfc_llcp_local *local; + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + kfree_skb(skb); + return -ENODEV; + } + + __nfc_llcp_recv(local, skb); + + nfc_llcp_local_put(local); + + return 0; +} + +void nfc_llcp_mac_is_down(struct nfc_dev *dev) +{ + struct nfc_llcp_local *local; + + local = nfc_llcp_find_local(dev); + if (local == NULL) + return; + + local->remote_miu = LLCP_DEFAULT_MIU; + local->remote_lto = LLCP_DEFAULT_LTO; + + /* Close and purge all existing sockets */ + nfc_llcp_socket_release(local, true, 0); + + nfc_llcp_local_put(local); +} + +void nfc_llcp_mac_is_up(struct nfc_dev *dev, u32 target_idx, + u8 comm_mode, u8 rf_mode) +{ + struct nfc_llcp_local *local; + + pr_debug("rf mode %d\n", rf_mode); + + local = nfc_llcp_find_local(dev); + if (local == NULL) + return; + + local->target_idx = target_idx; + local->comm_mode = comm_mode; + local->rf_mode = rf_mode; + + if (rf_mode == NFC_RF_INITIATOR) { + pr_debug("Queueing Tx work\n"); + + schedule_work(&local->tx_work); + } else { + mod_timer(&local->link_timer, + jiffies + msecs_to_jiffies(local->remote_lto)); + } + + nfc_llcp_local_put(local); +} + +int nfc_llcp_register_device(struct nfc_dev *ndev) +{ + struct nfc_llcp_local *local; + + local = kzalloc(sizeof(struct nfc_llcp_local), GFP_KERNEL); + if (local == NULL) + return -ENOMEM; + + /* As we are going to initialize local's refcount, we need to get the + * nfc_dev to avoid UAF, otherwise there is no point in continuing. + * See nfc_llcp_local_get(). 
+ */ + local->dev = nfc_get_device(ndev->idx); + if (!local->dev) { + kfree(local); + return -ENODEV; + } + + INIT_LIST_HEAD(&local->list); + kref_init(&local->ref); + mutex_init(&local->sdp_lock); + timer_setup(&local->link_timer, nfc_llcp_symm_timer, 0); + + skb_queue_head_init(&local->tx_queue); + INIT_WORK(&local->tx_work, nfc_llcp_tx_work); + + local->rx_pending = NULL; + INIT_WORK(&local->rx_work, nfc_llcp_rx_work); + + INIT_WORK(&local->timeout_work, nfc_llcp_timeout_work); + + rwlock_init(&local->sockets.lock); + rwlock_init(&local->connecting_sockets.lock); + rwlock_init(&local->raw_sockets.lock); + + local->lto = 150; /* 1500 ms */ + local->rw = LLCP_MAX_RW; + local->miux = cpu_to_be16(LLCP_MAX_MIUX); + local->local_wks = 0x1; /* LLC Link Management */ + + nfc_llcp_build_gb(local); + + local->remote_miu = LLCP_DEFAULT_MIU; + local->remote_lto = LLCP_DEFAULT_LTO; + + mutex_init(&local->sdreq_lock); + INIT_HLIST_HEAD(&local->pending_sdreqs); + timer_setup(&local->sdreq_timer, nfc_llcp_sdreq_timer, 0); + INIT_WORK(&local->sdreq_timeout_work, nfc_llcp_sdreq_timeout_work); + + spin_lock(&llcp_devices_lock); + list_add(&local->list, &llcp_devices); + spin_unlock(&llcp_devices_lock); + + return 0; +} + +void nfc_llcp_unregister_device(struct nfc_dev *dev) +{ + struct nfc_llcp_local *local = nfc_llcp_remove_local(dev); + + if (local == NULL) { + pr_debug("No such device\n"); + return; + } + + local_cleanup(local); + + nfc_llcp_local_put(local); +} + +int __init nfc_llcp_init(void) +{ + return nfc_llcp_sock_init(); +} + +void nfc_llcp_exit(void) +{ + nfc_llcp_sock_exit(); +} diff --git a/net/nfc/llcp_sock.c b/net/nfc/llcp_sock.c new file mode 100644 index 000000000..645677f84 --- /dev/null +++ b/net/nfc/llcp_sock.c @@ -0,0 +1,1061 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Intel Corporation. All rights reserved. 
+ */ + +#define pr_fmt(fmt) "llcp: %s: " fmt, __func__ + +#include +#include +#include +#include +#include + +#include "nfc.h" +#include "llcp.h" + +static int sock_wait_state(struct sock *sk, int state, unsigned long timeo) +{ + DECLARE_WAITQUEUE(wait, current); + int err = 0; + + pr_debug("sk %p", sk); + + add_wait_queue(sk_sleep(sk), &wait); + set_current_state(TASK_INTERRUPTIBLE); + + while (sk->sk_state != state) { + if (!timeo) { + err = -EINPROGRESS; + break; + } + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + set_current_state(TASK_INTERRUPTIBLE); + + err = sock_error(sk); + if (err) + break; + } + + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + return err; +} + +static struct proto llcp_sock_proto = { + .name = "NFC_LLCP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct nfc_llcp_sock), +}; + +static int llcp_sock_bind(struct socket *sock, struct sockaddr *addr, int alen) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + struct nfc_llcp_local *local; + struct nfc_dev *dev; + struct sockaddr_nfc_llcp llcp_addr; + int len, ret = 0; + + if (!addr || alen < offsetofend(struct sockaddr, sa_family) || + addr->sa_family != AF_NFC) + return -EINVAL; + + pr_debug("sk %p addr %p family %d\n", sk, addr, addr->sa_family); + + memset(&llcp_addr, 0, sizeof(llcp_addr)); + len = min_t(unsigned int, sizeof(llcp_addr), alen); + memcpy(&llcp_addr, addr, len); + + /* This is going to be a listening socket, dsap must be 0 */ + if (llcp_addr.dsap != 0) + return -EINVAL; + + lock_sock(sk); + + if (sk->sk_state != LLCP_CLOSED) { + ret = -EBADFD; + goto error; + } + + dev = nfc_get_device(llcp_addr.dev_idx); + if (dev == NULL) { + ret = -ENODEV; + goto error; + } + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + ret = -ENODEV; + goto put_dev; + } + + llcp_sock->dev = dev; + llcp_sock->local = local; + llcp_sock->nfc_protocol = llcp_addr.nfc_protocol; + llcp_sock->service_name_len = min_t(unsigned int, + llcp_addr.service_name_len, + NFC_LLCP_MAX_SERVICE_NAME); + llcp_sock->service_name = kmemdup(llcp_addr.service_name, + llcp_sock->service_name_len, + GFP_KERNEL); + if (!llcp_sock->service_name) { + ret = -ENOMEM; + goto sock_llcp_put_local; + } + llcp_sock->ssap = nfc_llcp_get_sdp_ssap(local, llcp_sock); + if (llcp_sock->ssap == LLCP_SAP_MAX) { + ret = -EADDRINUSE; + goto free_service_name; + } + + llcp_sock->reserved_ssap = llcp_sock->ssap; + + nfc_llcp_sock_link(&local->sockets, sk); + + pr_debug("Socket bound to SAP %d\n", llcp_sock->ssap); + + sk->sk_state = LLCP_BOUND; + nfc_put_device(dev); + release_sock(sk); + + return 0; + +free_service_name: + kfree(llcp_sock->service_name); + llcp_sock->service_name = NULL; + +sock_llcp_put_local: + nfc_llcp_local_put(llcp_sock->local); + llcp_sock->local = NULL; + llcp_sock->dev = NULL; + +put_dev: + nfc_put_device(dev); + +error: + release_sock(sk); + return ret; +} + +static int llcp_raw_sock_bind(struct socket *sock, struct sockaddr *addr, + int alen) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + struct nfc_llcp_local *local; + struct nfc_dev *dev; + struct sockaddr_nfc_llcp llcp_addr; + int len, ret = 0; + + if (!addr || alen < offsetofend(struct sockaddr, sa_family) || + addr->sa_family != AF_NFC) + return -EINVAL; + + pr_debug("sk %p addr %p family %d\n", sk, addr, addr->sa_family); + + memset(&llcp_addr, 0, 
sizeof(llcp_addr)); + len = min_t(unsigned int, sizeof(llcp_addr), alen); + memcpy(&llcp_addr, addr, len); + + lock_sock(sk); + + if (sk->sk_state != LLCP_CLOSED) { + ret = -EBADFD; + goto error; + } + + dev = nfc_get_device(llcp_addr.dev_idx); + if (dev == NULL) { + ret = -ENODEV; + goto error; + } + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + ret = -ENODEV; + goto put_dev; + } + + llcp_sock->dev = dev; + llcp_sock->local = local; + llcp_sock->nfc_protocol = llcp_addr.nfc_protocol; + + nfc_llcp_sock_link(&local->raw_sockets, sk); + + sk->sk_state = LLCP_BOUND; + +put_dev: + nfc_put_device(dev); + +error: + release_sock(sk); + return ret; +} + +static int llcp_sock_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int ret = 0; + + pr_debug("sk %p backlog %d\n", sk, backlog); + + lock_sock(sk); + + if ((sock->type != SOCK_SEQPACKET && sock->type != SOCK_STREAM) || + sk->sk_state != LLCP_BOUND) { + ret = -EBADFD; + goto error; + } + + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + + pr_debug("Socket listening\n"); + sk->sk_state = LLCP_LISTEN; + +error: + release_sock(sk); + + return ret; +} + +static int nfc_llcp_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + u32 opt; + int err = 0; + + pr_debug("%p optname %d\n", sk, optname); + + if (level != SOL_NFC) + return -ENOPROTOOPT; + + lock_sock(sk); + + switch (optname) { + case NFC_LLCP_RW: + if (sk->sk_state == LLCP_CONNECTED || + sk->sk_state == LLCP_BOUND || + sk->sk_state == LLCP_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt > LLCP_MAX_RW) { + err = -EINVAL; + break; + } + + llcp_sock->rw = (u8) opt; + + break; + + case NFC_LLCP_MIUX: + if (sk->sk_state == LLCP_CONNECTED || + sk->sk_state == LLCP_BOUND || + sk->sk_state == LLCP_LISTEN) { + err = -EINVAL; + break; + } + + if (copy_from_sockptr(&opt, optval, sizeof(u32))) { + err = -EFAULT; + break; + } + + if (opt > LLCP_MAX_MIUX) { + err = -EINVAL; + break; + } + + llcp_sock->miux = cpu_to_be16((u16) opt); + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + + pr_debug("%p rw %d miux %d\n", llcp_sock, + llcp_sock->rw, llcp_sock->miux); + + return err; +} + +static int nfc_llcp_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct nfc_llcp_local *local; + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + int len, err = 0; + u16 miux, remote_miu; + u8 rw; + + pr_debug("%p optname %d\n", sk, optname); + + if (level != SOL_NFC) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + local = llcp_sock->local; + if (!local) + return -ENODEV; + + len = min_t(u32, len, sizeof(u32)); + + lock_sock(sk); + + switch (optname) { + case NFC_LLCP_RW: + rw = llcp_sock->rw > LLCP_MAX_RW ? local->rw : llcp_sock->rw; + if (put_user(rw, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case NFC_LLCP_MIUX: + miux = be16_to_cpu(llcp_sock->miux) > LLCP_MAX_MIUX ? + be16_to_cpu(local->miux) : be16_to_cpu(llcp_sock->miux); + + if (put_user(miux, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case NFC_LLCP_REMOTE_MIU: + remote_miu = llcp_sock->remote_miu > LLCP_MAX_MIU ? 
+ local->remote_miu : llcp_sock->remote_miu; + + if (put_user(remote_miu, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case NFC_LLCP_REMOTE_LTO: + if (put_user(local->remote_lto / 10, (u32 __user *) optval)) + err = -EFAULT; + + break; + + case NFC_LLCP_REMOTE_RW: + if (put_user(llcp_sock->remote_rw, (u32 __user *) optval)) + err = -EFAULT; + + break; + + default: + err = -ENOPROTOOPT; + break; + } + + release_sock(sk); + + if (put_user(len, optlen)) + return -EFAULT; + + return err; +} + +void nfc_llcp_accept_unlink(struct sock *sk) +{ + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + + pr_debug("state %d\n", sk->sk_state); + + list_del_init(&llcp_sock->accept_queue); + sk_acceptq_removed(llcp_sock->parent); + llcp_sock->parent = NULL; + + sock_put(sk); +} + +void nfc_llcp_accept_enqueue(struct sock *parent, struct sock *sk) +{ + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + struct nfc_llcp_sock *llcp_sock_parent = nfc_llcp_sock(parent); + + /* Lock will be free from unlink */ + sock_hold(sk); + + list_add_tail(&llcp_sock->accept_queue, + &llcp_sock_parent->accept_queue); + llcp_sock->parent = parent; + sk_acceptq_added(parent); +} + +struct sock *nfc_llcp_accept_dequeue(struct sock *parent, + struct socket *newsock) +{ + struct nfc_llcp_sock *lsk, *n, *llcp_parent; + struct sock *sk; + + llcp_parent = nfc_llcp_sock(parent); + + list_for_each_entry_safe(lsk, n, &llcp_parent->accept_queue, + accept_queue) { + sk = &lsk->sk; + lock_sock(sk); + + if (sk->sk_state == LLCP_CLOSED) { + release_sock(sk); + nfc_llcp_accept_unlink(sk); + continue; + } + + if (sk->sk_state == LLCP_CONNECTED || !newsock) { + list_del_init(&lsk->accept_queue); + sock_put(sk); + + if (newsock) + sock_graft(sk, newsock); + + release_sock(sk); + + pr_debug("Returning sk state %d\n", sk->sk_state); + + sk_acceptq_removed(parent); + + return sk; + } + + release_sock(sk); + } + + return NULL; +} + +static int llcp_sock_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + DECLARE_WAITQUEUE(wait, current); + struct sock *sk = sock->sk, *new_sk; + long timeo; + int ret = 0; + + pr_debug("parent %p\n", sk); + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (sk->sk_state != LLCP_LISTEN) { + ret = -EBADFD; + goto error; + } + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + /* Wait for an incoming connection. 
*/ + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (!(new_sk = nfc_llcp_accept_dequeue(sk, newsock))) { + set_current_state(TASK_INTERRUPTIBLE); + + if (!timeo) { + ret = -EAGAIN; + break; + } + + if (signal_pending(current)) { + ret = sock_intr_errno(timeo); + break; + } + + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + } + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + + if (ret) + goto error; + + newsock->state = SS_CONNECTED; + + pr_debug("new socket %p\n", new_sk); + +error: + release_sock(sk); + + return ret; +} + +static int llcp_sock_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + DECLARE_SOCKADDR(struct sockaddr_nfc_llcp *, llcp_addr, uaddr); + + if (llcp_sock == NULL || llcp_sock->dev == NULL) + return -EBADFD; + + pr_debug("%p %d %d %d\n", sk, llcp_sock->target_idx, + llcp_sock->dsap, llcp_sock->ssap); + + memset(llcp_addr, 0, sizeof(*llcp_addr)); + + lock_sock(sk); + if (!llcp_sock->dev) { + release_sock(sk); + return -EBADFD; + } + llcp_addr->sa_family = AF_NFC; + llcp_addr->dev_idx = llcp_sock->dev->idx; + llcp_addr->target_idx = llcp_sock->target_idx; + llcp_addr->nfc_protocol = llcp_sock->nfc_protocol; + llcp_addr->dsap = llcp_sock->dsap; + llcp_addr->ssap = llcp_sock->ssap; + llcp_addr->service_name_len = llcp_sock->service_name_len; + memcpy(llcp_addr->service_name, llcp_sock->service_name, + llcp_addr->service_name_len); + release_sock(sk); + + return sizeof(struct sockaddr_nfc_llcp); +} + +static inline __poll_t llcp_accept_poll(struct sock *parent) +{ + struct nfc_llcp_sock *llcp_sock, *parent_sock; + struct sock *sk; + + parent_sock = nfc_llcp_sock(parent); + + list_for_each_entry(llcp_sock, &parent_sock->accept_queue, + accept_queue) { + sk = &llcp_sock->sk; + + if (sk->sk_state == LLCP_CONNECTED) + return EPOLLIN | EPOLLRDNORM; + } + + return 0; +} + +static __poll_t llcp_sock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask = 0; + + pr_debug("%p\n", sk); + + sock_poll_wait(file, sock, wait); + + if (sk->sk_state == LLCP_LISTEN) + return llcp_accept_poll(sk); + + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? 
EPOLLPRI : 0); + + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + if (sk->sk_state == LLCP_CLOSED) + mask |= EPOLLHUP; + + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + + if (sk->sk_shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + if (sock_writeable(sk) && sk->sk_state == LLCP_CONNECTED) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + else + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + pr_debug("mask 0x%x\n", mask); + + return mask; +} + +static int llcp_sock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_local *local; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + int err = 0; + + if (!sk) + return 0; + + pr_debug("%p\n", sk); + + local = llcp_sock->local; + if (local == NULL) { + err = -ENODEV; + goto out; + } + + lock_sock(sk); + + /* Send a DISC */ + if (sk->sk_state == LLCP_CONNECTED) + nfc_llcp_send_disconnect(llcp_sock); + + if (sk->sk_state == LLCP_LISTEN) { + struct nfc_llcp_sock *lsk, *n; + struct sock *accept_sk; + + list_for_each_entry_safe(lsk, n, &llcp_sock->accept_queue, + accept_queue) { + accept_sk = &lsk->sk; + lock_sock(accept_sk); + + nfc_llcp_send_disconnect(lsk); + nfc_llcp_accept_unlink(accept_sk); + + release_sock(accept_sk); + } + } + + if (sock->type == SOCK_RAW) + nfc_llcp_sock_unlink(&local->raw_sockets, sk); + else + nfc_llcp_sock_unlink(&local->sockets, sk); + + if (llcp_sock->reserved_ssap < LLCP_SAP_MAX) + nfc_llcp_put_ssap(llcp_sock->local, llcp_sock->ssap); + + release_sock(sk); + +out: + sock_orphan(sk); + sock_put(sk); + + return err; +} + +static int llcp_sock_connect(struct socket *sock, struct sockaddr *_addr, + int len, int flags) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + struct sockaddr_nfc_llcp *addr = (struct sockaddr_nfc_llcp *)_addr; + struct nfc_dev *dev; + struct nfc_llcp_local *local; + int ret = 0; + + pr_debug("sock %p sk %p flags 0x%x\n", sock, sk, flags); + + if (!addr || len < sizeof(*addr) || addr->sa_family != AF_NFC) + return -EINVAL; + + if (addr->service_name_len == 0 && addr->dsap == 0) + return -EINVAL; + + pr_debug("addr dev_idx=%u target_idx=%u protocol=%u\n", addr->dev_idx, + addr->target_idx, addr->nfc_protocol); + + lock_sock(sk); + + if (sk->sk_state == LLCP_CONNECTED) { + ret = -EISCONN; + goto error; + } + if (sk->sk_state == LLCP_CONNECTING) { + ret = -EINPROGRESS; + goto error; + } + + dev = nfc_get_device(addr->dev_idx); + if (dev == NULL) { + ret = -ENODEV; + goto error; + } + + local = nfc_llcp_find_local(dev); + if (local == NULL) { + ret = -ENODEV; + goto put_dev; + } + + device_lock(&dev->dev); + if (dev->dep_link_up == false) { + ret = -ENOLINK; + device_unlock(&dev->dev); + goto sock_llcp_put_local; + } + device_unlock(&dev->dev); + + if (local->rf_mode == NFC_RF_INITIATOR && + addr->target_idx != local->target_idx) { + ret = -ENOLINK; + goto sock_llcp_put_local; + } + + llcp_sock->dev = dev; + llcp_sock->local = local; + llcp_sock->ssap = nfc_llcp_get_local_ssap(local); + if (llcp_sock->ssap == LLCP_SAP_MAX) { + ret = -ENOMEM; + goto sock_llcp_nullify; + } + + llcp_sock->reserved_ssap = llcp_sock->ssap; + + if (addr->service_name_len == 0) + llcp_sock->dsap = addr->dsap; + else + llcp_sock->dsap = LLCP_SAP_SDP; + llcp_sock->nfc_protocol = addr->nfc_protocol; + llcp_sock->service_name_len = min_t(unsigned int, + addr->service_name_len, + NFC_LLCP_MAX_SERVICE_NAME); + llcp_sock->service_name = kmemdup(addr->service_name, + 
llcp_sock->service_name_len, + GFP_KERNEL); + if (!llcp_sock->service_name) { + ret = -ENOMEM; + goto sock_llcp_release; + } + + nfc_llcp_sock_link(&local->connecting_sockets, sk); + + ret = nfc_llcp_send_connect(llcp_sock); + if (ret) + goto sock_unlink; + + sk->sk_state = LLCP_CONNECTING; + + ret = sock_wait_state(sk, LLCP_CONNECTED, + sock_sndtimeo(sk, flags & O_NONBLOCK)); + if (ret && ret != -EINPROGRESS) + goto sock_unlink; + + release_sock(sk); + + return ret; + +sock_unlink: + nfc_llcp_sock_unlink(&local->connecting_sockets, sk); + kfree(llcp_sock->service_name); + llcp_sock->service_name = NULL; + +sock_llcp_release: + nfc_llcp_put_ssap(local, llcp_sock->ssap); + +sock_llcp_nullify: + llcp_sock->local = NULL; + llcp_sock->dev = NULL; + +sock_llcp_put_local: + nfc_llcp_local_put(local); + +put_dev: + nfc_put_device(dev); + +error: + release_sock(sk); + return ret; +} + +static int llcp_sock_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + int ret; + + pr_debug("sock %p sk %p", sock, sk); + + ret = sock_error(sk); + if (ret) + return ret; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + lock_sock(sk); + + if (!llcp_sock->local) { + release_sock(sk); + return -ENODEV; + } + + if (sk->sk_type == SOCK_DGRAM) { + DECLARE_SOCKADDR(struct sockaddr_nfc_llcp *, addr, + msg->msg_name); + + if (msg->msg_namelen < sizeof(*addr)) { + release_sock(sk); + return -EINVAL; + } + + release_sock(sk); + + return nfc_llcp_send_ui_frame(llcp_sock, addr->dsap, addr->ssap, + msg, len); + } + + if (sk->sk_state != LLCP_CONNECTED) { + release_sock(sk); + return -ENOTCONN; + } + + release_sock(sk); + + return nfc_llcp_send_i_frame(llcp_sock, msg, len); +} + +static int llcp_sock_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct sock *sk = sock->sk; + unsigned int copied, rlen; + struct sk_buff *skb, *cskb; + int err = 0; + + pr_debug("%p %zu\n", sk, len); + + lock_sock(sk); + + if (sk->sk_state == LLCP_CLOSED && + skb_queue_empty(&sk->sk_receive_queue)) { + release_sock(sk); + return 0; + } + + release_sock(sk); + + if (flags & (MSG_OOB)) + return -EOPNOTSUPP; + + skb = skb_recv_datagram(sk, flags, &err); + if (!skb) { + pr_err("Recv datagram failed state %d %d %d", + sk->sk_state, err, sock_error(sk)); + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 0; + + return err; + } + + rlen = skb->len; /* real length of skb */ + copied = min_t(unsigned int, rlen, len); + + cskb = skb; + if (skb_copy_datagram_msg(cskb, 0, msg, copied)) { + if (!(flags & MSG_PEEK)) + skb_queue_head(&sk->sk_receive_queue, skb); + return -EFAULT; + } + + sock_recv_timestamp(msg, sk, skb); + + if (sk->sk_type == SOCK_DGRAM && msg->msg_name) { + struct nfc_llcp_ui_cb *ui_cb = nfc_llcp_ui_skb_cb(skb); + DECLARE_SOCKADDR(struct sockaddr_nfc_llcp *, sockaddr, + msg->msg_name); + + msg->msg_namelen = sizeof(struct sockaddr_nfc_llcp); + + pr_debug("Datagram socket %d %d\n", ui_cb->dsap, ui_cb->ssap); + + memset(sockaddr, 0, sizeof(*sockaddr)); + sockaddr->sa_family = AF_NFC; + sockaddr->nfc_protocol = NFC_PROTO_NFC_DEP; + sockaddr->dsap = ui_cb->dsap; + sockaddr->ssap = ui_cb->ssap; + } + + /* Mark read part of skb as used */ + if (!(flags & MSG_PEEK)) { + + /* SOCK_STREAM: re-queue skb if it contains unreceived data */ + if (sk->sk_type == SOCK_STREAM || + sk->sk_type == SOCK_DGRAM || + sk->sk_type == SOCK_RAW) { + skb_pull(skb, copied); + if (skb->len) { + skb_queue_head(&sk->sk_receive_queue, 
skb); + goto done; + } + } + + kfree_skb(skb); + } + + /* XXX Queue backlogged skbs */ + +done: + /* SOCK_SEQPACKET: return real length if MSG_TRUNC is set */ + if (sk->sk_type == SOCK_SEQPACKET && (flags & MSG_TRUNC)) + copied = rlen; + + return copied; +} + +static const struct proto_ops llcp_sock_ops = { + .family = PF_NFC, + .owner = THIS_MODULE, + .bind = llcp_sock_bind, + .connect = llcp_sock_connect, + .release = llcp_sock_release, + .socketpair = sock_no_socketpair, + .accept = llcp_sock_accept, + .getname = llcp_sock_getname, + .poll = llcp_sock_poll, + .ioctl = sock_no_ioctl, + .listen = llcp_sock_listen, + .shutdown = sock_no_shutdown, + .setsockopt = nfc_llcp_setsockopt, + .getsockopt = nfc_llcp_getsockopt, + .sendmsg = llcp_sock_sendmsg, + .recvmsg = llcp_sock_recvmsg, + .mmap = sock_no_mmap, +}; + +static const struct proto_ops llcp_rawsock_ops = { + .family = PF_NFC, + .owner = THIS_MODULE, + .bind = llcp_raw_sock_bind, + .connect = sock_no_connect, + .release = llcp_sock_release, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = llcp_sock_getname, + .poll = llcp_sock_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = sock_no_sendmsg, + .recvmsg = llcp_sock_recvmsg, + .mmap = sock_no_mmap, +}; + +static void llcp_sock_destruct(struct sock *sk) +{ + struct nfc_llcp_sock *llcp_sock = nfc_llcp_sock(sk); + + pr_debug("%p\n", sk); + + if (sk->sk_state == LLCP_CONNECTED) + nfc_put_device(llcp_sock->dev); + + skb_queue_purge(&sk->sk_receive_queue); + + nfc_llcp_sock_free(llcp_sock); + + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Freeing alive NFC LLCP socket %p\n", sk); + return; + } +} + +struct sock *nfc_llcp_sock_alloc(struct socket *sock, int type, gfp_t gfp, int kern) +{ + struct sock *sk; + struct nfc_llcp_sock *llcp_sock; + + sk = sk_alloc(&init_net, PF_NFC, gfp, &llcp_sock_proto, kern); + if (!sk) + return NULL; + + llcp_sock = nfc_llcp_sock(sk); + + sock_init_data(sock, sk); + sk->sk_state = LLCP_CLOSED; + sk->sk_protocol = NFC_SOCKPROTO_LLCP; + sk->sk_type = type; + sk->sk_destruct = llcp_sock_destruct; + + llcp_sock->ssap = 0; + llcp_sock->dsap = LLCP_SAP_SDP; + llcp_sock->rw = LLCP_MAX_RW + 1; + llcp_sock->miux = cpu_to_be16(LLCP_MAX_MIUX + 1); + llcp_sock->send_n = llcp_sock->send_ack_n = 0; + llcp_sock->recv_n = llcp_sock->recv_ack_n = 0; + llcp_sock->remote_ready = 1; + llcp_sock->reserved_ssap = LLCP_SAP_MAX; + nfc_llcp_socket_remote_param_init(llcp_sock); + skb_queue_head_init(&llcp_sock->tx_queue); + skb_queue_head_init(&llcp_sock->tx_pending_queue); + INIT_LIST_HEAD(&llcp_sock->accept_queue); + + if (sock != NULL) + sock->state = SS_UNCONNECTED; + + return sk; +} + +void nfc_llcp_sock_free(struct nfc_llcp_sock *sock) +{ + kfree(sock->service_name); + + skb_queue_purge(&sock->tx_queue); + skb_queue_purge(&sock->tx_pending_queue); + + list_del_init(&sock->accept_queue); + + sock->parent = NULL; + + nfc_llcp_local_put(sock->local); +} + +static int llcp_sock_create(struct net *net, struct socket *sock, + const struct nfc_protocol *nfc_proto, int kern) +{ + struct sock *sk; + + pr_debug("%p\n", sock); + + if (sock->type != SOCK_STREAM && + sock->type != SOCK_DGRAM && + sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + if (sock->type == SOCK_RAW) { + if (!capable(CAP_NET_RAW)) + return -EPERM; + sock->ops = &llcp_rawsock_ops; + } else { + sock->ops = &llcp_sock_ops; + } + + sk = nfc_llcp_sock_alloc(sock, sock->type, GFP_ATOMIC, kern); + if (sk == NULL) + return -ENOMEM; + + return 0; 
+} + +static const struct nfc_protocol llcp_nfc_proto = { + .id = NFC_SOCKPROTO_LLCP, + .proto = &llcp_sock_proto, + .owner = THIS_MODULE, + .create = llcp_sock_create +}; + +int __init nfc_llcp_sock_init(void) +{ + return nfc_proto_register(&llcp_nfc_proto); +} + +void nfc_llcp_sock_exit(void) +{ + nfc_proto_unregister(&llcp_nfc_proto); +} diff --git a/net/nfc/nci/Kconfig b/net/nfc/nci/Kconfig new file mode 100644 index 000000000..ff1f295fe --- /dev/null +++ b/net/nfc/nci/Kconfig @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: GPL-2.0-only +config NFC_NCI + depends on NFC + tristate "NCI protocol support" + default n + help + NCI (NFC Controller Interface) is a communication protocol between + an NFC Controller (NFCC) and a Device Host (DH). + + Say Y here to compile NCI support into the kernel or say M to + compile it as module (nci). + +config NFC_NCI_SPI + depends on NFC_NCI && SPI + select CRC_CCITT + tristate "NCI over SPI protocol support" + default n + help + NCI (NFC Controller Interface) is a communication protocol between + an NFC Controller (NFCC) and a Device Host (DH). + + Say yes if you use an NCI driver that requires SPI link layer. + +config NFC_NCI_UART + depends on NFC_NCI && TTY + tristate "NCI over UART protocol support" + default n + help + Say yes if you use an NCI driver that requires UART link layer. diff --git a/net/nfc/nci/Makefile b/net/nfc/nci/Makefile new file mode 100644 index 000000000..c3362c499 --- /dev/null +++ b/net/nfc/nci/Makefile @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux NFC NCI layer. +# + +obj-$(CONFIG_NFC_NCI) += nci.o + +nci-objs := core.o data.o lib.o ntf.o rsp.o hci.o + +nci_spi-y += spi.o +obj-$(CONFIG_NFC_NCI_SPI) += nci_spi.o + +nci_uart-y += uart.o +obj-$(CONFIG_NFC_NCI_UART) += nci_uart.o diff --git a/net/nfc/nci/core.c b/net/nfc/nci/core.c new file mode 100644 index 000000000..7535afd15 --- /dev/null +++ b/net/nfc/nci/core.c @@ -0,0 +1,1574 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller (NFCC) and a Device Host (DH). + * + * Copyright (C) 2011 Texas Instruments, Inc. + * Copyright (C) 2014 Marvell International Ltd. + * + * Written by Ilan Elias + * + * Acknowledgements: + * This file is based on hci_core.c, which was written + * by Maxim Krasnyansky. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../nfc.h" +#include +#include +#include + +struct core_conn_create_data { + int length; + struct nci_core_conn_create_cmd *cmd; +}; + +static void nci_cmd_work(struct work_struct *work); +static void nci_rx_work(struct work_struct *work); +static void nci_tx_work(struct work_struct *work); + +struct nci_conn_info *nci_get_conn_info_by_conn_id(struct nci_dev *ndev, + int conn_id) +{ + struct nci_conn_info *conn_info; + + list_for_each_entry(conn_info, &ndev->conn_info_list, list) { + if (conn_info->conn_id == conn_id) + return conn_info; + } + + return NULL; +} + +int nci_get_conn_info_by_dest_type_params(struct nci_dev *ndev, u8 dest_type, + const struct dest_spec_params *params) +{ + const struct nci_conn_info *conn_info; + + list_for_each_entry(conn_info, &ndev->conn_info_list, list) { + if (conn_info->dest_type == dest_type) { + if (!params) + return conn_info->conn_id; + + if (params->id == conn_info->dest_params->id && + params->protocol == conn_info->dest_params->protocol) + return conn_info->conn_id; + } + } + + return -EINVAL; +} +EXPORT_SYMBOL(nci_get_conn_info_by_dest_type_params); + +/* ---- NCI requests ---- */ + +void nci_req_complete(struct nci_dev *ndev, int result) +{ + if (ndev->req_status == NCI_REQ_PEND) { + ndev->req_result = result; + ndev->req_status = NCI_REQ_DONE; + complete(&ndev->req_completion); + } +} +EXPORT_SYMBOL(nci_req_complete); + +static void nci_req_cancel(struct nci_dev *ndev, int err) +{ + if (ndev->req_status == NCI_REQ_PEND) { + ndev->req_result = err; + ndev->req_status = NCI_REQ_CANCELED; + complete(&ndev->req_completion); + } +} + +/* Execute request and wait for completion. */ +static int __nci_request(struct nci_dev *ndev, + void (*req)(struct nci_dev *ndev, const void *opt), + const void *opt, __u32 timeout) +{ + int rc = 0; + long completion_rc; + + ndev->req_status = NCI_REQ_PEND; + + reinit_completion(&ndev->req_completion); + req(ndev, opt); + completion_rc = + wait_for_completion_interruptible_timeout(&ndev->req_completion, + timeout); + + pr_debug("wait_for_completion return %ld\n", completion_rc); + + if (completion_rc > 0) { + switch (ndev->req_status) { + case NCI_REQ_DONE: + rc = nci_to_errno(ndev->req_result); + break; + + case NCI_REQ_CANCELED: + rc = -ndev->req_result; + break; + + default: + rc = -ETIMEDOUT; + break; + } + } else { + pr_err("wait_for_completion_interruptible_timeout failed %ld\n", + completion_rc); + + rc = ((completion_rc == 0) ? (-ETIMEDOUT) : (completion_rc)); + } + + ndev->req_status = ndev->req_result = 0; + + return rc; +} + +inline int nci_request(struct nci_dev *ndev, + void (*req)(struct nci_dev *ndev, + const void *opt), + const void *opt, __u32 timeout) +{ + int rc; + + /* Serialize all requests */ + mutex_lock(&ndev->req_lock); + /* check the state after obtaing the lock against any races + * from nci_close_device when the device gets removed. 
+ */ + if (test_bit(NCI_UP, &ndev->flags)) + rc = __nci_request(ndev, req, opt, timeout); + else + rc = -ENETDOWN; + mutex_unlock(&ndev->req_lock); + + return rc; +} + +static void nci_reset_req(struct nci_dev *ndev, const void *opt) +{ + struct nci_core_reset_cmd cmd; + + cmd.reset_type = NCI_RESET_TYPE_RESET_CONFIG; + nci_send_cmd(ndev, NCI_OP_CORE_RESET_CMD, 1, &cmd); +} + +static void nci_init_req(struct nci_dev *ndev, const void *opt) +{ + u8 plen = 0; + + if (opt) + plen = sizeof(struct nci_core_init_v2_cmd); + + nci_send_cmd(ndev, NCI_OP_CORE_INIT_CMD, plen, opt); +} + +static void nci_init_complete_req(struct nci_dev *ndev, const void *opt) +{ + struct nci_rf_disc_map_cmd cmd; + struct disc_map_config *cfg = cmd.mapping_configs; + __u8 *num = &cmd.num_mapping_configs; + int i; + + /* set rf mapping configurations */ + *num = 0; + + /* by default mapping is set to NCI_RF_INTERFACE_FRAME */ + for (i = 0; i < ndev->num_supported_rf_interfaces; i++) { + if (ndev->supported_rf_interfaces[i] == + NCI_RF_INTERFACE_ISO_DEP) { + cfg[*num].rf_protocol = NCI_RF_PROTOCOL_ISO_DEP; + cfg[*num].mode = NCI_DISC_MAP_MODE_POLL | + NCI_DISC_MAP_MODE_LISTEN; + cfg[*num].rf_interface = NCI_RF_INTERFACE_ISO_DEP; + (*num)++; + } else if (ndev->supported_rf_interfaces[i] == + NCI_RF_INTERFACE_NFC_DEP) { + cfg[*num].rf_protocol = NCI_RF_PROTOCOL_NFC_DEP; + cfg[*num].mode = NCI_DISC_MAP_MODE_POLL | + NCI_DISC_MAP_MODE_LISTEN; + cfg[*num].rf_interface = NCI_RF_INTERFACE_NFC_DEP; + (*num)++; + } + + if (*num == NCI_MAX_NUM_MAPPING_CONFIGS) + break; + } + + nci_send_cmd(ndev, NCI_OP_RF_DISCOVER_MAP_CMD, + (1 + ((*num) * sizeof(struct disc_map_config))), &cmd); +} + +struct nci_set_config_param { + __u8 id; + size_t len; + const __u8 *val; +}; + +static void nci_set_config_req(struct nci_dev *ndev, const void *opt) +{ + const struct nci_set_config_param *param = opt; + struct nci_core_set_config_cmd cmd; + + BUG_ON(param->len > NCI_MAX_PARAM_LEN); + + cmd.num_params = 1; + cmd.param.id = param->id; + cmd.param.len = param->len; + memcpy(cmd.param.val, param->val, param->len); + + nci_send_cmd(ndev, NCI_OP_CORE_SET_CONFIG_CMD, (3 + param->len), &cmd); +} + +struct nci_rf_discover_param { + __u32 im_protocols; + __u32 tm_protocols; +}; + +static void nci_rf_discover_req(struct nci_dev *ndev, const void *opt) +{ + const struct nci_rf_discover_param *param = opt; + struct nci_rf_disc_cmd cmd; + + cmd.num_disc_configs = 0; + + if ((cmd.num_disc_configs < NCI_MAX_NUM_RF_CONFIGS) && + (param->im_protocols & NFC_PROTO_JEWEL_MASK || + param->im_protocols & NFC_PROTO_MIFARE_MASK || + param->im_protocols & NFC_PROTO_ISO14443_MASK || + param->im_protocols & NFC_PROTO_NFC_DEP_MASK)) { + cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode = + NCI_NFC_A_PASSIVE_POLL_MODE; + cmd.disc_configs[cmd.num_disc_configs].frequency = 1; + cmd.num_disc_configs++; + } + + if ((cmd.num_disc_configs < NCI_MAX_NUM_RF_CONFIGS) && + (param->im_protocols & NFC_PROTO_ISO14443_B_MASK)) { + cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode = + NCI_NFC_B_PASSIVE_POLL_MODE; + cmd.disc_configs[cmd.num_disc_configs].frequency = 1; + cmd.num_disc_configs++; + } + + if ((cmd.num_disc_configs < NCI_MAX_NUM_RF_CONFIGS) && + (param->im_protocols & NFC_PROTO_FELICA_MASK || + param->im_protocols & NFC_PROTO_NFC_DEP_MASK)) { + cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode = + NCI_NFC_F_PASSIVE_POLL_MODE; + cmd.disc_configs[cmd.num_disc_configs].frequency = 1; + cmd.num_disc_configs++; + } + + if ((cmd.num_disc_configs < 
NCI_MAX_NUM_RF_CONFIGS) &&
+        (param->im_protocols & NFC_PROTO_ISO15693_MASK)) {
+        cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode =
+            NCI_NFC_V_PASSIVE_POLL_MODE;
+        cmd.disc_configs[cmd.num_disc_configs].frequency = 1;
+        cmd.num_disc_configs++;
+    }
+
+    if ((cmd.num_disc_configs < NCI_MAX_NUM_RF_CONFIGS - 1) &&
+        (param->tm_protocols & NFC_PROTO_NFC_DEP_MASK)) {
+        cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode =
+            NCI_NFC_A_PASSIVE_LISTEN_MODE;
+        cmd.disc_configs[cmd.num_disc_configs].frequency = 1;
+        cmd.num_disc_configs++;
+        cmd.disc_configs[cmd.num_disc_configs].rf_tech_and_mode =
+            NCI_NFC_F_PASSIVE_LISTEN_MODE;
+        cmd.disc_configs[cmd.num_disc_configs].frequency = 1;
+        cmd.num_disc_configs++;
+    }
+
+    nci_send_cmd(ndev, NCI_OP_RF_DISCOVER_CMD,
+                 (1 + (cmd.num_disc_configs * sizeof(struct disc_config))),
+                 &cmd);
+}
+
+struct nci_rf_discover_select_param {
+    __u8 rf_discovery_id;
+    __u8 rf_protocol;
+};
+
+static void nci_rf_discover_select_req(struct nci_dev *ndev, const void *opt)
+{
+    const struct nci_rf_discover_select_param *param = opt;
+    struct nci_rf_discover_select_cmd cmd;
+
+    cmd.rf_discovery_id = param->rf_discovery_id;
+    cmd.rf_protocol = param->rf_protocol;
+
+    switch (cmd.rf_protocol) {
+    case NCI_RF_PROTOCOL_ISO_DEP:
+        cmd.rf_interface = NCI_RF_INTERFACE_ISO_DEP;
+        break;
+
+    case NCI_RF_PROTOCOL_NFC_DEP:
+        cmd.rf_interface = NCI_RF_INTERFACE_NFC_DEP;
+        break;
+
+    default:
+        cmd.rf_interface = NCI_RF_INTERFACE_FRAME;
+        break;
+    }
+
+    nci_send_cmd(ndev, NCI_OP_RF_DISCOVER_SELECT_CMD,
+                 sizeof(struct nci_rf_discover_select_cmd), &cmd);
+}
+
+static void nci_rf_deactivate_req(struct nci_dev *ndev, const void *opt)
+{
+    struct nci_rf_deactivate_cmd cmd;
+
+    cmd.type = (unsigned long)opt;
+
+    nci_send_cmd(ndev, NCI_OP_RF_DEACTIVATE_CMD,
+                 sizeof(struct nci_rf_deactivate_cmd), &cmd);
+}
+
+struct nci_cmd_param {
+    __u16 opcode;
+    size_t len;
+    const __u8 *payload;
+};
+
+static void nci_generic_req(struct nci_dev *ndev, const void *opt)
+{
+    const struct nci_cmd_param *param = opt;
+
+    nci_send_cmd(ndev, param->opcode, param->len, param->payload);
+}
+
+int nci_prop_cmd(struct nci_dev *ndev, __u8 oid, size_t len, const __u8 *payload)
+{
+    struct nci_cmd_param param;
+
+    param.opcode = nci_opcode_pack(NCI_GID_PROPRIETARY, oid);
+    param.len = len;
+    param.payload = payload;
+
+    return __nci_request(ndev, nci_generic_req, &param,
+                         msecs_to_jiffies(NCI_CMD_TIMEOUT));
+}
+EXPORT_SYMBOL(nci_prop_cmd);
+
+int nci_core_cmd(struct nci_dev *ndev, __u16 opcode, size_t len,
+                 const __u8 *payload)
+{
+    struct nci_cmd_param param;
+
+    param.opcode = opcode;
+    param.len = len;
+    param.payload = payload;
+
+    return __nci_request(ndev, nci_generic_req, &param,
+                         msecs_to_jiffies(NCI_CMD_TIMEOUT));
+}
+EXPORT_SYMBOL(nci_core_cmd);
+
+int nci_core_reset(struct nci_dev *ndev)
+{
+    return __nci_request(ndev, nci_reset_req, (void *)0,
+                         msecs_to_jiffies(NCI_RESET_TIMEOUT));
+}
+EXPORT_SYMBOL(nci_core_reset);
+
+int nci_core_init(struct nci_dev *ndev)
+{
+    return __nci_request(ndev, nci_init_req, (void *)0,
+                         msecs_to_jiffies(NCI_INIT_TIMEOUT));
+}
+EXPORT_SYMBOL(nci_core_init);
+
+struct nci_loopback_data {
+    u8 conn_id;
+    struct sk_buff *data;
+};
+
+static void nci_send_data_req(struct nci_dev *ndev, const void *opt)
+{
+    const struct nci_loopback_data *data = opt;
+
+    nci_send_data(ndev, data->conn_id, data->data);
+}
+
+static void nci_nfcc_loopback_cb(void *context, struct sk_buff *skb, int err)
+{
+    struct nci_dev *ndev = (struct nci_dev *)context;
+    struct nci_conn_info *conn_info;
+
+ conn_info = nci_get_conn_info_by_conn_id(ndev, ndev->cur_conn_id); + if (!conn_info) { + nci_req_complete(ndev, NCI_STATUS_REJECTED); + return; + } + + conn_info->rx_skb = skb; + + nci_req_complete(ndev, NCI_STATUS_OK); +} + +int nci_nfcc_loopback(struct nci_dev *ndev, const void *data, size_t data_len, + struct sk_buff **resp) +{ + int r; + struct nci_loopback_data loopback_data; + struct nci_conn_info *conn_info; + struct sk_buff *skb; + int conn_id = nci_get_conn_info_by_dest_type_params(ndev, + NCI_DESTINATION_NFCC_LOOPBACK, NULL); + + if (conn_id < 0) { + r = nci_core_conn_create(ndev, NCI_DESTINATION_NFCC_LOOPBACK, + 0, 0, NULL); + if (r != NCI_STATUS_OK) + return r; + + conn_id = nci_get_conn_info_by_dest_type_params(ndev, + NCI_DESTINATION_NFCC_LOOPBACK, + NULL); + } + + conn_info = nci_get_conn_info_by_conn_id(ndev, conn_id); + if (!conn_info) + return -EPROTO; + + /* store cb and context to be used on receiving data */ + conn_info->data_exchange_cb = nci_nfcc_loopback_cb; + conn_info->data_exchange_cb_context = ndev; + + skb = nci_skb_alloc(ndev, NCI_DATA_HDR_SIZE + data_len, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, NCI_DATA_HDR_SIZE); + skb_put_data(skb, data, data_len); + + loopback_data.conn_id = conn_id; + loopback_data.data = skb; + + ndev->cur_conn_id = conn_id; + r = nci_request(ndev, nci_send_data_req, &loopback_data, + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + if (r == NCI_STATUS_OK && resp) + *resp = conn_info->rx_skb; + + return r; +} +EXPORT_SYMBOL(nci_nfcc_loopback); + +static int nci_open_device(struct nci_dev *ndev) +{ + int rc = 0; + + mutex_lock(&ndev->req_lock); + + if (test_bit(NCI_UNREG, &ndev->flags)) { + rc = -ENODEV; + goto done; + } + + if (test_bit(NCI_UP, &ndev->flags)) { + rc = -EALREADY; + goto done; + } + + if (ndev->ops->open(ndev)) { + rc = -EIO; + goto done; + } + + atomic_set(&ndev->cmd_cnt, 1); + + set_bit(NCI_INIT, &ndev->flags); + + if (ndev->ops->init) + rc = ndev->ops->init(ndev); + + if (!rc) { + rc = __nci_request(ndev, nci_reset_req, (void *)0, + msecs_to_jiffies(NCI_RESET_TIMEOUT)); + } + + if (!rc && ndev->ops->setup) { + rc = ndev->ops->setup(ndev); + } + + if (!rc) { + struct nci_core_init_v2_cmd nci_init_v2_cmd = { + .feature1 = NCI_FEATURE_DISABLE, + .feature2 = NCI_FEATURE_DISABLE + }; + const void *opt = NULL; + + if (ndev->nci_ver & NCI_VER_2_MASK) + opt = &nci_init_v2_cmd; + + rc = __nci_request(ndev, nci_init_req, opt, + msecs_to_jiffies(NCI_INIT_TIMEOUT)); + } + + if (!rc && ndev->ops->post_setup) + rc = ndev->ops->post_setup(ndev); + + if (!rc) { + rc = __nci_request(ndev, nci_init_complete_req, (void *)0, + msecs_to_jiffies(NCI_INIT_TIMEOUT)); + } + + clear_bit(NCI_INIT, &ndev->flags); + + if (!rc) { + set_bit(NCI_UP, &ndev->flags); + nci_clear_target_list(ndev); + atomic_set(&ndev->state, NCI_IDLE); + } else { + /* Init failed, cleanup */ + skb_queue_purge(&ndev->cmd_q); + skb_queue_purge(&ndev->rx_q); + skb_queue_purge(&ndev->tx_q); + + ndev->ops->close(ndev); + ndev->flags &= BIT(NCI_UNREG); + } + +done: + mutex_unlock(&ndev->req_lock); + return rc; +} + +static int nci_close_device(struct nci_dev *ndev) +{ + nci_req_cancel(ndev, ENODEV); + + /* This mutex needs to be held as a barrier for + * caller nci_unregister_device + */ + mutex_lock(&ndev->req_lock); + + if (!test_and_clear_bit(NCI_UP, &ndev->flags)) { + /* Need to flush the cmd wq in case + * there is a queued/running cmd_work + */ + flush_workqueue(ndev->cmd_wq); + del_timer_sync(&ndev->cmd_timer); + del_timer_sync(&ndev->data_timer); + 
mutex_unlock(&ndev->req_lock); + return 0; + } + + /* Drop RX and TX queues */ + skb_queue_purge(&ndev->rx_q); + skb_queue_purge(&ndev->tx_q); + + /* Flush RX and TX wq */ + flush_workqueue(ndev->rx_wq); + flush_workqueue(ndev->tx_wq); + + /* Reset device */ + skb_queue_purge(&ndev->cmd_q); + atomic_set(&ndev->cmd_cnt, 1); + + set_bit(NCI_INIT, &ndev->flags); + __nci_request(ndev, nci_reset_req, (void *)0, + msecs_to_jiffies(NCI_RESET_TIMEOUT)); + + /* After this point our queues are empty + * and no works are scheduled. + */ + ndev->ops->close(ndev); + + clear_bit(NCI_INIT, &ndev->flags); + + /* Flush cmd wq */ + flush_workqueue(ndev->cmd_wq); + + del_timer_sync(&ndev->cmd_timer); + + /* Clear flags except NCI_UNREG */ + ndev->flags &= BIT(NCI_UNREG); + + mutex_unlock(&ndev->req_lock); + + return 0; +} + +/* NCI command timer function */ +static void nci_cmd_timer(struct timer_list *t) +{ + struct nci_dev *ndev = from_timer(ndev, t, cmd_timer); + + atomic_set(&ndev->cmd_cnt, 1); + queue_work(ndev->cmd_wq, &ndev->cmd_work); +} + +/* NCI data exchange timer function */ +static void nci_data_timer(struct timer_list *t) +{ + struct nci_dev *ndev = from_timer(ndev, t, data_timer); + + set_bit(NCI_DATA_EXCHANGE_TO, &ndev->flags); + queue_work(ndev->rx_wq, &ndev->rx_work); +} + +static int nci_dev_up(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + return nci_open_device(ndev); +} + +static int nci_dev_down(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + return nci_close_device(ndev); +} + +int nci_set_config(struct nci_dev *ndev, __u8 id, size_t len, const __u8 *val) +{ + struct nci_set_config_param param; + + if (!val || !len) + return 0; + + param.id = id; + param.len = len; + param.val = val; + + return __nci_request(ndev, nci_set_config_req, ¶m, + msecs_to_jiffies(NCI_SET_CONFIG_TIMEOUT)); +} +EXPORT_SYMBOL(nci_set_config); + +static void nci_nfcee_discover_req(struct nci_dev *ndev, const void *opt) +{ + struct nci_nfcee_discover_cmd cmd; + __u8 action = (unsigned long)opt; + + cmd.discovery_action = action; + + nci_send_cmd(ndev, NCI_OP_NFCEE_DISCOVER_CMD, 1, &cmd); +} + +int nci_nfcee_discover(struct nci_dev *ndev, u8 action) +{ + unsigned long opt = action; + + return __nci_request(ndev, nci_nfcee_discover_req, (void *)opt, + msecs_to_jiffies(NCI_CMD_TIMEOUT)); +} +EXPORT_SYMBOL(nci_nfcee_discover); + +static void nci_nfcee_mode_set_req(struct nci_dev *ndev, const void *opt) +{ + const struct nci_nfcee_mode_set_cmd *cmd = opt; + + nci_send_cmd(ndev, NCI_OP_NFCEE_MODE_SET_CMD, + sizeof(struct nci_nfcee_mode_set_cmd), cmd); +} + +int nci_nfcee_mode_set(struct nci_dev *ndev, u8 nfcee_id, u8 nfcee_mode) +{ + struct nci_nfcee_mode_set_cmd cmd; + + cmd.nfcee_id = nfcee_id; + cmd.nfcee_mode = nfcee_mode; + + return __nci_request(ndev, nci_nfcee_mode_set_req, &cmd, + msecs_to_jiffies(NCI_CMD_TIMEOUT)); +} +EXPORT_SYMBOL(nci_nfcee_mode_set); + +static void nci_core_conn_create_req(struct nci_dev *ndev, const void *opt) +{ + const struct core_conn_create_data *data = opt; + + nci_send_cmd(ndev, NCI_OP_CORE_CONN_CREATE_CMD, data->length, data->cmd); +} + +int nci_core_conn_create(struct nci_dev *ndev, u8 destination_type, + u8 number_destination_params, + size_t params_len, + const struct core_conn_create_dest_spec_params *params) +{ + int r; + struct nci_core_conn_create_cmd *cmd; + struct core_conn_create_data data; + + data.length = params_len + sizeof(struct nci_core_conn_create_cmd); + cmd = kzalloc(data.length, GFP_KERNEL); 
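/*
 * Illustrative sketch, not part of this patch: a simplified NFCEE
 * bring-up sequence using the two helpers above.  The nfcee_id passed to
 * nci_nfcee_mode_set() normally comes from an NFCEE_DISCOVER notification;
 * the id 0x01 below is only an example, and mode 0x01 is "enable" as
 * defined by the NCI NFCEE_MODE_SET command.
 */
static int my_enable_example_nfcee(struct nci_dev *ndev)
{
	int r;

	r = nci_nfcee_discover(ndev, NCI_NFCEE_DISCOVERY_ACTION_ENABLE);
	if (r != NCI_STATUS_OK)
		return -EPROTO;

	/* hypothetical id; in practice taken from the discover notification */
	return nci_nfcee_mode_set(ndev, 0x01, 0x01 /* enable */);
}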
+ if (!cmd) + return -ENOMEM; + + cmd->destination_type = destination_type; + cmd->number_destination_params = number_destination_params; + + data.cmd = cmd; + + if (params) { + memcpy(cmd->params, params, params_len); + if (params->length > 0) + memcpy(&ndev->cur_params, + ¶ms->value[DEST_SPEC_PARAMS_ID_INDEX], + sizeof(struct dest_spec_params)); + else + ndev->cur_params.id = 0; + } else { + ndev->cur_params.id = 0; + } + ndev->cur_dest_type = destination_type; + + r = __nci_request(ndev, nci_core_conn_create_req, &data, + msecs_to_jiffies(NCI_CMD_TIMEOUT)); + kfree(cmd); + return r; +} +EXPORT_SYMBOL(nci_core_conn_create); + +static void nci_core_conn_close_req(struct nci_dev *ndev, const void *opt) +{ + __u8 conn_id = (unsigned long)opt; + + nci_send_cmd(ndev, NCI_OP_CORE_CONN_CLOSE_CMD, 1, &conn_id); +} + +int nci_core_conn_close(struct nci_dev *ndev, u8 conn_id) +{ + unsigned long opt = conn_id; + + ndev->cur_conn_id = conn_id; + return __nci_request(ndev, nci_core_conn_close_req, (void *)opt, + msecs_to_jiffies(NCI_CMD_TIMEOUT)); +} +EXPORT_SYMBOL(nci_core_conn_close); + +static int nci_set_local_general_bytes(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + struct nci_set_config_param param; + int rc; + + param.val = nfc_get_local_general_bytes(nfc_dev, ¶m.len); + if ((param.val == NULL) || (param.len == 0)) + return 0; + + if (param.len > NFC_MAX_GT_LEN) + return -EINVAL; + + param.id = NCI_PN_ATR_REQ_GEN_BYTES; + + rc = nci_request(ndev, nci_set_config_req, ¶m, + msecs_to_jiffies(NCI_SET_CONFIG_TIMEOUT)); + if (rc) + return rc; + + param.id = NCI_LN_ATR_RES_GEN_BYTES; + + return nci_request(ndev, nci_set_config_req, ¶m, + msecs_to_jiffies(NCI_SET_CONFIG_TIMEOUT)); +} + +static int nci_set_listen_parameters(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + int rc; + __u8 val; + + val = NCI_LA_SEL_INFO_NFC_DEP_MASK; + + rc = nci_set_config(ndev, NCI_LA_SEL_INFO, 1, &val); + if (rc) + return rc; + + val = NCI_LF_PROTOCOL_TYPE_NFC_DEP_MASK; + + rc = nci_set_config(ndev, NCI_LF_PROTOCOL_TYPE, 1, &val); + if (rc) + return rc; + + val = NCI_LF_CON_BITR_F_212 | NCI_LF_CON_BITR_F_424; + + return nci_set_config(ndev, NCI_LF_CON_BITR_F, 1, &val); +} + +static int nci_start_poll(struct nfc_dev *nfc_dev, + __u32 im_protocols, __u32 tm_protocols) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + struct nci_rf_discover_param param; + int rc; + + if ((atomic_read(&ndev->state) == NCI_DISCOVERY) || + (atomic_read(&ndev->state) == NCI_W4_ALL_DISCOVERIES)) { + pr_err("unable to start poll, since poll is already active\n"); + return -EBUSY; + } + + if (ndev->target_active_prot) { + pr_err("there is an active target\n"); + return -EBUSY; + } + + if ((atomic_read(&ndev->state) == NCI_W4_HOST_SELECT) || + (atomic_read(&ndev->state) == NCI_POLL_ACTIVE)) { + pr_debug("target active or w4 select, implicitly deactivate\n"); + + rc = nci_request(ndev, nci_rf_deactivate_req, + (void *)NCI_DEACTIVATE_TYPE_IDLE_MODE, + msecs_to_jiffies(NCI_RF_DEACTIVATE_TIMEOUT)); + if (rc) + return -EBUSY; + } + + if ((im_protocols | tm_protocols) & NFC_PROTO_NFC_DEP_MASK) { + rc = nci_set_local_general_bytes(nfc_dev); + if (rc) { + pr_err("failed to set local general bytes\n"); + return rc; + } + } + + if (tm_protocols & NFC_PROTO_NFC_DEP_MASK) { + rc = nci_set_listen_parameters(nfc_dev); + if (rc) + pr_err("failed to set listen parameters\n"); + } + + param.im_protocols = im_protocols; + param.tm_protocols = tm_protocols; + rc = nci_request(ndev, 
nci_rf_discover_req, ¶m, + msecs_to_jiffies(NCI_RF_DISC_TIMEOUT)); + + if (!rc) + ndev->poll_prots = im_protocols; + + return rc; +} + +static void nci_stop_poll(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if ((atomic_read(&ndev->state) != NCI_DISCOVERY) && + (atomic_read(&ndev->state) != NCI_W4_ALL_DISCOVERIES)) { + pr_err("unable to stop poll, since poll is not active\n"); + return; + } + + nci_request(ndev, nci_rf_deactivate_req, + (void *)NCI_DEACTIVATE_TYPE_IDLE_MODE, + msecs_to_jiffies(NCI_RF_DEACTIVATE_TIMEOUT)); +} + +static int nci_activate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, __u32 protocol) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + struct nci_rf_discover_select_param param; + const struct nfc_target *nci_target = NULL; + int i; + int rc = 0; + + pr_debug("target_idx %d, protocol 0x%x\n", target->idx, protocol); + + if ((atomic_read(&ndev->state) != NCI_W4_HOST_SELECT) && + (atomic_read(&ndev->state) != NCI_POLL_ACTIVE)) { + pr_err("there is no available target to activate\n"); + return -EINVAL; + } + + if (ndev->target_active_prot) { + pr_err("there is already an active target\n"); + return -EBUSY; + } + + for (i = 0; i < ndev->n_targets; i++) { + if (ndev->targets[i].idx == target->idx) { + nci_target = &ndev->targets[i]; + break; + } + } + + if (!nci_target) { + pr_err("unable to find the selected target\n"); + return -EINVAL; + } + + if (protocol >= NFC_PROTO_MAX) { + pr_err("the requested nfc protocol is invalid\n"); + return -EINVAL; + } + + if (!(nci_target->supported_protocols & (1 << protocol))) { + pr_err("target does not support the requested protocol 0x%x\n", + protocol); + return -EINVAL; + } + + if (atomic_read(&ndev->state) == NCI_W4_HOST_SELECT) { + param.rf_discovery_id = nci_target->logical_idx; + + if (protocol == NFC_PROTO_JEWEL) + param.rf_protocol = NCI_RF_PROTOCOL_T1T; + else if (protocol == NFC_PROTO_MIFARE) + param.rf_protocol = NCI_RF_PROTOCOL_T2T; + else if (protocol == NFC_PROTO_FELICA) + param.rf_protocol = NCI_RF_PROTOCOL_T3T; + else if (protocol == NFC_PROTO_ISO14443 || + protocol == NFC_PROTO_ISO14443_B) + param.rf_protocol = NCI_RF_PROTOCOL_ISO_DEP; + else + param.rf_protocol = NCI_RF_PROTOCOL_NFC_DEP; + + rc = nci_request(ndev, nci_rf_discover_select_req, ¶m, + msecs_to_jiffies(NCI_RF_DISC_SELECT_TIMEOUT)); + } + + if (!rc) + ndev->target_active_prot = protocol; + + return rc; +} + +static void nci_deactivate_target(struct nfc_dev *nfc_dev, + struct nfc_target *target, + __u8 mode) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + unsigned long nci_mode = NCI_DEACTIVATE_TYPE_IDLE_MODE; + + if (!ndev->target_active_prot) { + pr_err("unable to deactivate target, no active target\n"); + return; + } + + ndev->target_active_prot = 0; + + switch (mode) { + case NFC_TARGET_MODE_SLEEP: + nci_mode = NCI_DEACTIVATE_TYPE_SLEEP_MODE; + break; + } + + if (atomic_read(&ndev->state) == NCI_POLL_ACTIVE) { + nci_request(ndev, nci_rf_deactivate_req, (void *)nci_mode, + msecs_to_jiffies(NCI_RF_DEACTIVATE_TIMEOUT)); + } +} + +static int nci_dep_link_up(struct nfc_dev *nfc_dev, struct nfc_target *target, + __u8 comm_mode, __u8 *gb, size_t gb_len) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + int rc; + + pr_debug("target_idx %d, comm_mode %d\n", target->idx, comm_mode); + + rc = nci_activate_target(nfc_dev, target, NFC_PROTO_NFC_DEP); + if (rc) + return rc; + + rc = nfc_set_remote_general_bytes(nfc_dev, ndev->remote_gb, + ndev->remote_gb_len); + if (!rc) + rc = 
nfc_dep_link_is_up(nfc_dev, target->idx, NFC_COMM_PASSIVE, + NFC_RF_INITIATOR); + + return rc; +} + +static int nci_dep_link_down(struct nfc_dev *nfc_dev) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + int rc; + + if (nfc_dev->rf_mode == NFC_RF_INITIATOR) { + nci_deactivate_target(nfc_dev, NULL, NCI_DEACTIVATE_TYPE_IDLE_MODE); + } else { + if (atomic_read(&ndev->state) == NCI_LISTEN_ACTIVE || + atomic_read(&ndev->state) == NCI_DISCOVERY) { + nci_request(ndev, nci_rf_deactivate_req, (void *)0, + msecs_to_jiffies(NCI_RF_DEACTIVATE_TIMEOUT)); + } + + rc = nfc_tm_deactivated(nfc_dev); + if (rc) + pr_err("error when signaling tm deactivation\n"); + } + + return 0; +} + + +static int nci_transceive(struct nfc_dev *nfc_dev, struct nfc_target *target, + struct sk_buff *skb, + data_exchange_cb_t cb, void *cb_context) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + int rc; + struct nci_conn_info *conn_info; + + conn_info = ndev->rf_conn_info; + if (!conn_info) + return -EPROTO; + + pr_debug("target_idx %d, len %d\n", target->idx, skb->len); + + if (!ndev->target_active_prot) { + pr_err("unable to exchange data, no active target\n"); + return -EINVAL; + } + + if (test_and_set_bit(NCI_DATA_EXCHANGE, &ndev->flags)) + return -EBUSY; + + /* store cb and context to be used on receiving data */ + conn_info->data_exchange_cb = cb; + conn_info->data_exchange_cb_context = cb_context; + + rc = nci_send_data(ndev, NCI_STATIC_RF_CONN_ID, skb); + if (rc) + clear_bit(NCI_DATA_EXCHANGE, &ndev->flags); + + return rc; +} + +static int nci_tm_send(struct nfc_dev *nfc_dev, struct sk_buff *skb) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + int rc; + + rc = nci_send_data(ndev, NCI_STATIC_RF_CONN_ID, skb); + if (rc) + pr_err("unable to send data\n"); + + return rc; +} + +static int nci_enable_se(struct nfc_dev *nfc_dev, u32 se_idx) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if (ndev->ops->enable_se) + return ndev->ops->enable_se(ndev, se_idx); + + return 0; +} + +static int nci_disable_se(struct nfc_dev *nfc_dev, u32 se_idx) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if (ndev->ops->disable_se) + return ndev->ops->disable_se(ndev, se_idx); + + return 0; +} + +static int nci_discover_se(struct nfc_dev *nfc_dev) +{ + int r; + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if (ndev->ops->discover_se) { + r = nci_nfcee_discover(ndev, NCI_NFCEE_DISCOVERY_ACTION_ENABLE); + if (r != NCI_STATUS_OK) + return -EPROTO; + + return ndev->ops->discover_se(ndev); + } + + return 0; +} + +static int nci_se_io(struct nfc_dev *nfc_dev, u32 se_idx, + u8 *apdu, size_t apdu_length, + se_io_cb_t cb, void *cb_context) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if (ndev->ops->se_io) + return ndev->ops->se_io(ndev, se_idx, apdu, + apdu_length, cb, cb_context); + + return 0; +} + +static int nci_fw_download(struct nfc_dev *nfc_dev, const char *firmware_name) +{ + struct nci_dev *ndev = nfc_get_drvdata(nfc_dev); + + if (!ndev->ops->fw_download) + return -ENOTSUPP; + + return ndev->ops->fw_download(ndev, firmware_name); +} + +static const struct nfc_ops nci_nfc_ops = { + .dev_up = nci_dev_up, + .dev_down = nci_dev_down, + .start_poll = nci_start_poll, + .stop_poll = nci_stop_poll, + .dep_link_up = nci_dep_link_up, + .dep_link_down = nci_dep_link_down, + .activate_target = nci_activate_target, + .deactivate_target = nci_deactivate_target, + .im_transceive = nci_transceive, + .tm_send = nci_tm_send, + .enable_se = nci_enable_se, + .disable_se = nci_disable_se, + .discover_se = 
nci_discover_se, + .se_io = nci_se_io, + .fw_download = nci_fw_download, +}; + +/* ---- Interface to NCI drivers ---- */ +/** + * nci_allocate_device - allocate a new nci device + * + * @ops: device operations + * @supported_protocols: NFC protocols supported by the device + * @tx_headroom: Reserved space at beginning of skb + * @tx_tailroom: Reserved space at end of skb + */ +struct nci_dev *nci_allocate_device(const struct nci_ops *ops, + __u32 supported_protocols, + int tx_headroom, int tx_tailroom) +{ + struct nci_dev *ndev; + + pr_debug("supported_protocols 0x%x\n", supported_protocols); + + if (!ops->open || !ops->close || !ops->send) + return NULL; + + if (!supported_protocols) + return NULL; + + ndev = kzalloc(sizeof(struct nci_dev), GFP_KERNEL); + if (!ndev) + return NULL; + + ndev->ops = ops; + + if (ops->n_prop_ops > NCI_MAX_PROPRIETARY_CMD) { + pr_err("Too many proprietary commands: %zd\n", + ops->n_prop_ops); + goto free_nci; + } + + ndev->tx_headroom = tx_headroom; + ndev->tx_tailroom = tx_tailroom; + init_completion(&ndev->req_completion); + + ndev->nfc_dev = nfc_allocate_device(&nci_nfc_ops, + supported_protocols, + tx_headroom + NCI_DATA_HDR_SIZE, + tx_tailroom); + if (!ndev->nfc_dev) + goto free_nci; + + ndev->hci_dev = nci_hci_allocate(ndev); + if (!ndev->hci_dev) + goto free_nfc; + + nfc_set_drvdata(ndev->nfc_dev, ndev); + + return ndev; + +free_nfc: + nfc_free_device(ndev->nfc_dev); +free_nci: + kfree(ndev); + return NULL; +} +EXPORT_SYMBOL(nci_allocate_device); + +/** + * nci_free_device - deallocate nci device + * + * @ndev: The nci device to deallocate + */ +void nci_free_device(struct nci_dev *ndev) +{ + nfc_free_device(ndev->nfc_dev); + nci_hci_deallocate(ndev); + kfree(ndev); +} +EXPORT_SYMBOL(nci_free_device); + +/** + * nci_register_device - register a nci device in the nfc subsystem + * + * @ndev: The nci device to register + */ +int nci_register_device(struct nci_dev *ndev) +{ + int rc; + struct device *dev = &ndev->nfc_dev->dev; + char name[32]; + + ndev->flags = 0; + + INIT_WORK(&ndev->cmd_work, nci_cmd_work); + snprintf(name, sizeof(name), "%s_nci_cmd_wq", dev_name(dev)); + ndev->cmd_wq = create_singlethread_workqueue(name); + if (!ndev->cmd_wq) { + rc = -ENOMEM; + goto exit; + } + + INIT_WORK(&ndev->rx_work, nci_rx_work); + snprintf(name, sizeof(name), "%s_nci_rx_wq", dev_name(dev)); + ndev->rx_wq = create_singlethread_workqueue(name); + if (!ndev->rx_wq) { + rc = -ENOMEM; + goto destroy_cmd_wq_exit; + } + + INIT_WORK(&ndev->tx_work, nci_tx_work); + snprintf(name, sizeof(name), "%s_nci_tx_wq", dev_name(dev)); + ndev->tx_wq = create_singlethread_workqueue(name); + if (!ndev->tx_wq) { + rc = -ENOMEM; + goto destroy_rx_wq_exit; + } + + skb_queue_head_init(&ndev->cmd_q); + skb_queue_head_init(&ndev->rx_q); + skb_queue_head_init(&ndev->tx_q); + + timer_setup(&ndev->cmd_timer, nci_cmd_timer, 0); + timer_setup(&ndev->data_timer, nci_data_timer, 0); + + mutex_init(&ndev->req_lock); + INIT_LIST_HEAD(&ndev->conn_info_list); + + rc = nfc_register_device(ndev->nfc_dev); + if (rc) + goto destroy_tx_wq_exit; + + goto exit; + +destroy_tx_wq_exit: + destroy_workqueue(ndev->tx_wq); + +destroy_rx_wq_exit: + destroy_workqueue(ndev->rx_wq); + +destroy_cmd_wq_exit: + destroy_workqueue(ndev->cmd_wq); + +exit: + return rc; +} +EXPORT_SYMBOL(nci_register_device); + +/** + * nci_unregister_device - unregister a nci device in the nfc subsystem + * + * @ndev: The nci device to unregister + */ +void nci_unregister_device(struct nci_dev *ndev) +{ + struct nci_conn_info *conn_info, 
*n; + + /* This set_bit is not protected with specialized barrier, + * However, it is fine because the mutex_lock(&ndev->req_lock); + * in nci_close_device() will help to emit one. + */ + set_bit(NCI_UNREG, &ndev->flags); + + nci_close_device(ndev); + + destroy_workqueue(ndev->cmd_wq); + destroy_workqueue(ndev->rx_wq); + destroy_workqueue(ndev->tx_wq); + + list_for_each_entry_safe(conn_info, n, &ndev->conn_info_list, list) { + list_del(&conn_info->list); + /* conn_info is allocated with devm_kzalloc */ + } + + nfc_unregister_device(ndev->nfc_dev); +} +EXPORT_SYMBOL(nci_unregister_device); + +/** + * nci_recv_frame - receive frame from NCI drivers + * + * @ndev: The nci device + * @skb: The sk_buff to receive + */ +int nci_recv_frame(struct nci_dev *ndev, struct sk_buff *skb) +{ + pr_debug("len %d\n", skb->len); + + if (!ndev || (!test_bit(NCI_UP, &ndev->flags) && + !test_bit(NCI_INIT, &ndev->flags))) { + kfree_skb(skb); + return -ENXIO; + } + + /* Queue frame for rx worker thread */ + skb_queue_tail(&ndev->rx_q, skb); + queue_work(ndev->rx_wq, &ndev->rx_work); + + return 0; +} +EXPORT_SYMBOL(nci_recv_frame); + +int nci_send_frame(struct nci_dev *ndev, struct sk_buff *skb) +{ + pr_debug("len %d\n", skb->len); + + if (!ndev) { + kfree_skb(skb); + return -ENODEV; + } + + /* Get rid of skb owner, prior to sending to the driver. */ + skb_orphan(skb); + + /* Send copy to sniffer */ + nfc_send_to_raw_sock(ndev->nfc_dev, skb, + RAW_PAYLOAD_NCI, NFC_DIRECTION_TX); + + return ndev->ops->send(ndev, skb); +} +EXPORT_SYMBOL(nci_send_frame); + +/* Send NCI command */ +int nci_send_cmd(struct nci_dev *ndev, __u16 opcode, __u8 plen, const void *payload) +{ + struct nci_ctrl_hdr *hdr; + struct sk_buff *skb; + + pr_debug("opcode 0x%x, plen %d\n", opcode, plen); + + skb = nci_skb_alloc(ndev, (NCI_CTRL_HDR_SIZE + plen), GFP_KERNEL); + if (!skb) { + pr_err("no memory for command\n"); + return -ENOMEM; + } + + hdr = skb_put(skb, NCI_CTRL_HDR_SIZE); + hdr->gid = nci_opcode_gid(opcode); + hdr->oid = nci_opcode_oid(opcode); + hdr->plen = plen; + + nci_mt_set((__u8 *)hdr, NCI_MT_CMD_PKT); + nci_pbf_set((__u8 *)hdr, NCI_PBF_LAST); + + if (plen) + skb_put_data(skb, payload, plen); + + skb_queue_tail(&ndev->cmd_q, skb); + queue_work(ndev->cmd_wq, &ndev->cmd_work); + + return 0; +} +EXPORT_SYMBOL(nci_send_cmd); + +/* Proprietary commands API */ +static const struct nci_driver_ops *ops_cmd_lookup(const struct nci_driver_ops *ops, + size_t n_ops, + __u16 opcode) +{ + size_t i; + const struct nci_driver_ops *op; + + if (!ops || !n_ops) + return NULL; + + for (i = 0; i < n_ops; i++) { + op = &ops[i]; + if (op->opcode == opcode) + return op; + } + + return NULL; +} + +static int nci_op_rsp_packet(struct nci_dev *ndev, __u16 rsp_opcode, + struct sk_buff *skb, const struct nci_driver_ops *ops, + size_t n_ops) +{ + const struct nci_driver_ops *op; + + op = ops_cmd_lookup(ops, n_ops, rsp_opcode); + if (!op || !op->rsp) + return -ENOTSUPP; + + return op->rsp(ndev, skb); +} + +static int nci_op_ntf_packet(struct nci_dev *ndev, __u16 ntf_opcode, + struct sk_buff *skb, const struct nci_driver_ops *ops, + size_t n_ops) +{ + const struct nci_driver_ops *op; + + op = ops_cmd_lookup(ops, n_ops, ntf_opcode); + if (!op || !op->ntf) + return -ENOTSUPP; + + return op->ntf(ndev, skb); +} + +int nci_prop_rsp_packet(struct nci_dev *ndev, __u16 opcode, + struct sk_buff *skb) +{ + return nci_op_rsp_packet(ndev, opcode, skb, ndev->ops->prop_ops, + ndev->ops->n_prop_ops); +} + +int nci_prop_ntf_packet(struct nci_dev *ndev, __u16 opcode, + struct 
sk_buff *skb) +{ + return nci_op_ntf_packet(ndev, opcode, skb, ndev->ops->prop_ops, + ndev->ops->n_prop_ops); +} + +int nci_core_rsp_packet(struct nci_dev *ndev, __u16 opcode, + struct sk_buff *skb) +{ + return nci_op_rsp_packet(ndev, opcode, skb, ndev->ops->core_ops, + ndev->ops->n_core_ops); +} + +int nci_core_ntf_packet(struct nci_dev *ndev, __u16 opcode, + struct sk_buff *skb) +{ + return nci_op_ntf_packet(ndev, opcode, skb, ndev->ops->core_ops, + ndev->ops->n_core_ops); +} + +/* ---- NCI TX Data worker thread ---- */ + +static void nci_tx_work(struct work_struct *work) +{ + struct nci_dev *ndev = container_of(work, struct nci_dev, tx_work); + struct nci_conn_info *conn_info; + struct sk_buff *skb; + + conn_info = nci_get_conn_info_by_conn_id(ndev, ndev->cur_conn_id); + if (!conn_info) + return; + + pr_debug("credits_cnt %d\n", atomic_read(&conn_info->credits_cnt)); + + /* Send queued tx data */ + while (atomic_read(&conn_info->credits_cnt)) { + skb = skb_dequeue(&ndev->tx_q); + if (!skb) + return; + + /* Check if data flow control is used */ + if (atomic_read(&conn_info->credits_cnt) != + NCI_DATA_FLOW_CONTROL_NOT_USED) + atomic_dec(&conn_info->credits_cnt); + + pr_debug("NCI TX: MT=data, PBF=%d, conn_id=%d, plen=%d\n", + nci_pbf(skb->data), + nci_conn_id(skb->data), + nci_plen(skb->data)); + + nci_send_frame(ndev, skb); + + mod_timer(&ndev->data_timer, + jiffies + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + } +} + +/* ----- NCI RX worker thread (data & control) ----- */ + +static void nci_rx_work(struct work_struct *work) +{ + struct nci_dev *ndev = container_of(work, struct nci_dev, rx_work); + struct sk_buff *skb; + + while ((skb = skb_dequeue(&ndev->rx_q))) { + + /* Send copy to sniffer */ + nfc_send_to_raw_sock(ndev->nfc_dev, skb, + RAW_PAYLOAD_NCI, NFC_DIRECTION_RX); + + /* Process frame */ + switch (nci_mt(skb->data)) { + case NCI_MT_RSP_PKT: + nci_rsp_packet(ndev, skb); + break; + + case NCI_MT_NTF_PKT: + nci_ntf_packet(ndev, skb); + break; + + case NCI_MT_DATA_PKT: + nci_rx_data_packet(ndev, skb); + break; + + default: + pr_err("unknown MT 0x%x\n", nci_mt(skb->data)); + kfree_skb(skb); + break; + } + } + + /* check if a data exchange timeout has occurred */ + if (test_bit(NCI_DATA_EXCHANGE_TO, &ndev->flags)) { + /* complete the data exchange transaction, if exists */ + if (test_bit(NCI_DATA_EXCHANGE, &ndev->flags)) + nci_data_exchange_complete(ndev, NULL, + ndev->cur_conn_id, + -ETIMEDOUT); + + clear_bit(NCI_DATA_EXCHANGE_TO, &ndev->flags); + } +} + +/* ----- NCI TX CMD worker thread ----- */ + +static void nci_cmd_work(struct work_struct *work) +{ + struct nci_dev *ndev = container_of(work, struct nci_dev, cmd_work); + struct sk_buff *skb; + + pr_debug("cmd_cnt %d\n", atomic_read(&ndev->cmd_cnt)); + + /* Send queued command */ + if (atomic_read(&ndev->cmd_cnt)) { + skb = skb_dequeue(&ndev->cmd_q); + if (!skb) + return; + + atomic_dec(&ndev->cmd_cnt); + + pr_debug("NCI TX: MT=cmd, PBF=%d, GID=0x%x, OID=0x%x, plen=%d\n", + nci_pbf(skb->data), + nci_opcode_gid(nci_opcode(skb->data)), + nci_opcode_oid(nci_opcode(skb->data)), + nci_plen(skb->data)); + + nci_send_frame(ndev, skb); + + mod_timer(&ndev->cmd_timer, + jiffies + msecs_to_jiffies(NCI_CMD_TIMEOUT)); + } +} + +MODULE_LICENSE("GPL"); diff --git a/net/nfc/nci/data.c b/net/nfc/nci/data.c new file mode 100644 index 000000000..3d36ea570 --- /dev/null +++ b/net/nfc/nci/data.c @@ -0,0 +1,301 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller 
(NFCC) and a Device Host (DH). + * + * Copyright (C) 2011 Texas Instruments, Inc. + * Copyright (C) 2014 Marvell International Ltd. + * + * Written by Ilan Elias + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include +#include +#include + +#include "../nfc.h" +#include +#include +#include + +/* Complete data exchange transaction and forward skb to nfc core */ +void nci_data_exchange_complete(struct nci_dev *ndev, struct sk_buff *skb, + __u8 conn_id, int err) +{ + const struct nci_conn_info *conn_info; + data_exchange_cb_t cb; + void *cb_context; + + conn_info = nci_get_conn_info_by_conn_id(ndev, conn_id); + if (!conn_info) { + kfree_skb(skb); + goto exit; + } + + cb = conn_info->data_exchange_cb; + cb_context = conn_info->data_exchange_cb_context; + + pr_debug("len %d, err %d\n", skb ? skb->len : 0, err); + + /* data exchange is complete, stop the data timer */ + del_timer_sync(&ndev->data_timer); + clear_bit(NCI_DATA_EXCHANGE_TO, &ndev->flags); + + if (cb) { + /* forward skb to nfc core */ + cb(cb_context, skb, err); + } else if (skb) { + pr_err("no rx callback, dropping rx data...\n"); + + /* no waiting callback, free skb */ + kfree_skb(skb); + } + +exit: + clear_bit(NCI_DATA_EXCHANGE, &ndev->flags); +} + +/* ----------------- NCI TX Data ----------------- */ + +static inline void nci_push_data_hdr(struct nci_dev *ndev, + __u8 conn_id, + struct sk_buff *skb, + __u8 pbf) +{ + struct nci_data_hdr *hdr; + int plen = skb->len; + + hdr = skb_push(skb, NCI_DATA_HDR_SIZE); + hdr->conn_id = conn_id; + hdr->rfu = 0; + hdr->plen = plen; + + nci_mt_set((__u8 *)hdr, NCI_MT_DATA_PKT); + nci_pbf_set((__u8 *)hdr, pbf); +} + +int nci_conn_max_data_pkt_payload_size(struct nci_dev *ndev, __u8 conn_id) +{ + const struct nci_conn_info *conn_info; + + conn_info = nci_get_conn_info_by_conn_id(ndev, conn_id); + if (!conn_info) + return -EPROTO; + + return conn_info->max_pkt_payload_len; +} +EXPORT_SYMBOL(nci_conn_max_data_pkt_payload_size); + +static int nci_queue_tx_data_frags(struct nci_dev *ndev, + __u8 conn_id, + struct sk_buff *skb) { + const struct nci_conn_info *conn_info; + int total_len = skb->len; + const unsigned char *data = skb->data; + unsigned long flags; + struct sk_buff_head frags_q; + struct sk_buff *skb_frag; + int frag_len; + int rc = 0; + + pr_debug("conn_id 0x%x, total_len %d\n", conn_id, total_len); + + conn_info = nci_get_conn_info_by_conn_id(ndev, conn_id); + if (!conn_info) { + rc = -EPROTO; + goto exit; + } + + __skb_queue_head_init(&frags_q); + + while (total_len) { + frag_len = + min_t(int, total_len, conn_info->max_pkt_payload_len); + + skb_frag = nci_skb_alloc(ndev, + (NCI_DATA_HDR_SIZE + frag_len), + GFP_ATOMIC); + if (skb_frag == NULL) { + rc = -ENOMEM; + goto free_exit; + } + skb_reserve(skb_frag, NCI_DATA_HDR_SIZE); + + /* first, copy the data */ + skb_put_data(skb_frag, data, frag_len); + + /* second, set the header */ + nci_push_data_hdr(ndev, conn_id, skb_frag, + ((total_len == frag_len) ? 
+ (NCI_PBF_LAST) : (NCI_PBF_CONT))); + + __skb_queue_tail(&frags_q, skb_frag); + + data += frag_len; + total_len -= frag_len; + + pr_debug("frag_len %d, remaining total_len %d\n", + frag_len, total_len); + } + + /* queue all fragments atomically */ + spin_lock_irqsave(&ndev->tx_q.lock, flags); + + while ((skb_frag = __skb_dequeue(&frags_q)) != NULL) + __skb_queue_tail(&ndev->tx_q, skb_frag); + + spin_unlock_irqrestore(&ndev->tx_q.lock, flags); + + /* free the original skb */ + kfree_skb(skb); + + goto exit; + +free_exit: + while ((skb_frag = __skb_dequeue(&frags_q)) != NULL) + kfree_skb(skb_frag); + +exit: + return rc; +} + +/* Send NCI data */ +int nci_send_data(struct nci_dev *ndev, __u8 conn_id, struct sk_buff *skb) +{ + const struct nci_conn_info *conn_info; + int rc = 0; + + pr_debug("conn_id 0x%x, plen %d\n", conn_id, skb->len); + + conn_info = nci_get_conn_info_by_conn_id(ndev, conn_id); + if (!conn_info) { + rc = -EPROTO; + goto free_exit; + } + + /* check if the packet need to be fragmented */ + if (skb->len <= conn_info->max_pkt_payload_len) { + /* no need to fragment packet */ + nci_push_data_hdr(ndev, conn_id, skb, NCI_PBF_LAST); + + skb_queue_tail(&ndev->tx_q, skb); + } else { + /* fragment packet and queue the fragments */ + rc = nci_queue_tx_data_frags(ndev, conn_id, skb); + if (rc) { + pr_err("failed to fragment tx data packet\n"); + goto free_exit; + } + } + + ndev->cur_conn_id = conn_id; + queue_work(ndev->tx_wq, &ndev->tx_work); + + goto exit; + +free_exit: + kfree_skb(skb); + +exit: + return rc; +} +EXPORT_SYMBOL(nci_send_data); + +/* ----------------- NCI RX Data ----------------- */ + +static void nci_add_rx_data_frag(struct nci_dev *ndev, + struct sk_buff *skb, + __u8 pbf, __u8 conn_id, __u8 status) +{ + int reassembly_len; + int err = 0; + + if (status) { + err = status; + goto exit; + } + + if (ndev->rx_data_reassembly) { + reassembly_len = ndev->rx_data_reassembly->len; + + /* first, make enough room for the already accumulated data */ + if (skb_cow_head(skb, reassembly_len)) { + pr_err("error adding room for accumulated rx data\n"); + + kfree_skb(skb); + skb = NULL; + + kfree_skb(ndev->rx_data_reassembly); + ndev->rx_data_reassembly = NULL; + + err = -ENOMEM; + goto exit; + } + + /* second, combine the two fragments */ + memcpy(skb_push(skb, reassembly_len), + ndev->rx_data_reassembly->data, + reassembly_len); + + /* third, free old reassembly */ + kfree_skb(ndev->rx_data_reassembly); + ndev->rx_data_reassembly = NULL; + } + + if (pbf == NCI_PBF_CONT) { + /* need to wait for next fragment, store skb and exit */ + ndev->rx_data_reassembly = skb; + return; + } + +exit: + if (ndev->nfc_dev->rf_mode == NFC_RF_TARGET) { + /* Data received in Target mode, forward to nfc core */ + err = nfc_tm_data_received(ndev->nfc_dev, skb); + if (err) + pr_err("unable to handle received data\n"); + } else { + nci_data_exchange_complete(ndev, skb, conn_id, err); + } +} + +/* Rx Data packet */ +void nci_rx_data_packet(struct nci_dev *ndev, struct sk_buff *skb) +{ + __u8 pbf = nci_pbf(skb->data); + __u8 status = 0; + __u8 conn_id = nci_conn_id(skb->data); + const struct nci_conn_info *conn_info; + + pr_debug("len %d\n", skb->len); + + pr_debug("NCI RX: MT=data, PBF=%d, conn_id=%d, plen=%d\n", + nci_pbf(skb->data), + nci_conn_id(skb->data), + nci_plen(skb->data)); + + conn_info = nci_get_conn_info_by_conn_id(ndev, nci_conn_id(skb->data)); + if (!conn_info) { + kfree_skb(skb); + return; + } + + /* strip the nci data header */ + skb_pull(skb, NCI_DATA_HDR_SIZE); + + if 
(ndev->target_active_prot == NFC_PROTO_MIFARE || + ndev->target_active_prot == NFC_PROTO_JEWEL || + ndev->target_active_prot == NFC_PROTO_FELICA || + ndev->target_active_prot == NFC_PROTO_ISO15693) { + /* frame I/F => remove the status byte */ + pr_debug("frame I/F => remove the status byte\n"); + status = skb->data[skb->len - 1]; + skb_trim(skb, (skb->len - 1)); + } + + nci_add_rx_data_frag(ndev, skb, pbf, conn_id, nci_to_errno(status)); +} diff --git a/net/nfc/nci/hci.c b/net/nfc/nci/hci.c new file mode 100644 index 000000000..78c4b6add --- /dev/null +++ b/net/nfc/nci/hci.c @@ -0,0 +1,791 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller (NFCC) and a Device Host (DH). + * This is the HCI over NCI implementation, as specified in the 10.2 + * section of the NCI 1.1 specification. + * + * Copyright (C) 2014 STMicroelectronics SAS. All rights reserved. + */ + +#include + +#include "../nfc.h" +#include +#include +#include + +struct nci_data { + u8 conn_id; + u8 pipe; + u8 cmd; + const u8 *data; + u32 data_len; +} __packed; + +struct nci_hci_create_pipe_params { + u8 src_gate; + u8 dest_host; + u8 dest_gate; +} __packed; + +struct nci_hci_create_pipe_resp { + u8 src_host; + u8 src_gate; + u8 dest_host; + u8 dest_gate; + u8 pipe; +} __packed; + +struct nci_hci_delete_pipe_noti { + u8 pipe; +} __packed; + +struct nci_hci_all_pipe_cleared_noti { + u8 host; +} __packed; + +struct nci_hcp_message { + u8 header; /* type -cmd,evt,rsp- + instruction */ + u8 data[]; +} __packed; + +struct nci_hcp_packet { + u8 header; /* cbit+pipe */ + struct nci_hcp_message message; +} __packed; + +#define NCI_HCI_ANY_SET_PARAMETER 0x01 +#define NCI_HCI_ANY_GET_PARAMETER 0x02 +#define NCI_HCI_ANY_CLOSE_PIPE 0x04 +#define NCI_HCI_ADM_CLEAR_ALL_PIPE 0x14 + +#define NCI_HFP_NO_CHAINING 0x80 + +#define NCI_NFCEE_ID_HCI 0x80 + +#define NCI_EVT_HOT_PLUG 0x03 + +#define NCI_HCI_ADMIN_PARAM_SESSION_IDENTITY 0x01 +#define NCI_HCI_ADM_CREATE_PIPE 0x10 +#define NCI_HCI_ADM_DELETE_PIPE 0x11 + +/* HCP headers */ +#define NCI_HCI_HCP_PACKET_HEADER_LEN 1 +#define NCI_HCI_HCP_MESSAGE_HEADER_LEN 1 +#define NCI_HCI_HCP_HEADER_LEN 2 + +/* HCP types */ +#define NCI_HCI_HCP_COMMAND 0x00 +#define NCI_HCI_HCP_EVENT 0x01 +#define NCI_HCI_HCP_RESPONSE 0x02 + +#define NCI_HCI_ADM_NOTIFY_PIPE_CREATED 0x12 +#define NCI_HCI_ADM_NOTIFY_PIPE_DELETED 0x13 +#define NCI_HCI_ADM_NOTIFY_ALL_PIPE_CLEARED 0x15 + +#define NCI_HCI_FRAGMENT 0x7f +#define NCI_HCP_HEADER(type, instr) ((((type) & 0x03) << 6) |\ + ((instr) & 0x3f)) + +#define NCI_HCP_MSG_GET_TYPE(header) ((header & 0xc0) >> 6) +#define NCI_HCP_MSG_GET_CMD(header) (header & 0x3f) +#define NCI_HCP_MSG_GET_PIPE(header) (header & 0x7f) + +static int nci_hci_result_to_errno(u8 result) +{ + switch (result) { + case NCI_HCI_ANY_OK: + return 0; + case NCI_HCI_ANY_E_REG_PAR_UNKNOWN: + return -EOPNOTSUPP; + case NCI_HCI_ANY_E_TIMEOUT: + return -ETIME; + default: + return -1; + } +} + +/* HCI core */ +static void nci_hci_reset_pipes(struct nci_hci_dev *hdev) +{ + int i; + + for (i = 0; i < NCI_HCI_MAX_PIPES; i++) { + hdev->pipes[i].gate = NCI_HCI_INVALID_GATE; + hdev->pipes[i].host = NCI_HCI_INVALID_HOST; + } + memset(hdev->gate2pipe, NCI_HCI_INVALID_PIPE, sizeof(hdev->gate2pipe)); +} + +static void nci_hci_reset_pipes_per_host(struct nci_dev *ndev, u8 host) +{ + int i; + + for (i = 0; i < NCI_HCI_MAX_PIPES; i++) { + if (ndev->hci_dev->pipes[i].host == host) { + ndev->hci_dev->pipes[i].gate = NCI_HCI_INVALID_GATE; + 
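/*
 * Illustrative sketch, not part of this patch: how the two HCP header
 * bytes are laid out, using the macros defined above.  Byte 0 is the
 * packet header (chaining bit + pipe id, pipe ids fit in the low 7 bits),
 * byte 1 is the message header (2-bit type + 6-bit instruction).
 * nci_hci_send_data() below builds exactly this layout before pushing
 * the NCI data header.  For example,
 * NCI_HCP_HEADER(NCI_HCI_HCP_COMMAND, NCI_HCI_ANY_SET_PARAMETER) == 0x01.
 */
static void my_build_hcp_headers(u8 pipe, u8 type, u8 instruction,
				 bool last_fragment, u8 hdr[2])
{
	hdr[0] = pipe;				/* CB clear: more fragments follow */
	if (last_fragment)
		hdr[0] |= NCI_HFP_NO_CHAINING;	/* CB set on the final fragment */

	hdr[1] = NCI_HCP_HEADER(type, instruction);
}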
ndev->hci_dev->pipes[i].host = NCI_HCI_INVALID_HOST; + } + } +} + +/* Fragment HCI data over NCI packet. + * NFC Forum NCI 10.2.2 Data Exchange: + * The payload of the Data Packets sent on the Logical Connection SHALL be + * valid HCP packets, as defined within [ETSI_102622]. Each Data Packet SHALL + * contain a single HCP packet. NCI Segmentation and Reassembly SHALL NOT be + * applied to Data Messages in either direction. The HCI fragmentation mechanism + * is used if required. + */ +static int nci_hci_send_data(struct nci_dev *ndev, u8 pipe, + const u8 data_type, const u8 *data, + size_t data_len) +{ + const struct nci_conn_info *conn_info; + struct sk_buff *skb; + int len, i, r; + u8 cb = pipe; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + i = 0; + skb = nci_skb_alloc(ndev, conn_info->max_pkt_payload_len + + NCI_DATA_HDR_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, NCI_DATA_HDR_SIZE + 2); + *(u8 *)skb_push(skb, 1) = data_type; + + do { + /* If last packet add NCI_HFP_NO_CHAINING */ + if (i + conn_info->max_pkt_payload_len - + (skb->len + 1) >= data_len) { + cb |= NCI_HFP_NO_CHAINING; + len = data_len - i; + } else { + len = conn_info->max_pkt_payload_len - skb->len - 1; + } + + *(u8 *)skb_push(skb, 1) = cb; + + if (len > 0) + skb_put_data(skb, data + i, len); + + r = nci_send_data(ndev, conn_info->conn_id, skb); + if (r < 0) + return r; + + i += len; + + if (i < data_len) { + skb = nci_skb_alloc(ndev, + conn_info->max_pkt_payload_len + + NCI_DATA_HDR_SIZE, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + skb_reserve(skb, NCI_DATA_HDR_SIZE + 1); + } + } while (i < data_len); + + return i; +} + +static void nci_hci_send_data_req(struct nci_dev *ndev, const void *opt) +{ + const struct nci_data *data = opt; + + nci_hci_send_data(ndev, data->pipe, data->cmd, + data->data, data->data_len); +} + +int nci_hci_send_event(struct nci_dev *ndev, u8 gate, u8 event, + const u8 *param, size_t param_len) +{ + u8 pipe = ndev->hci_dev->gate2pipe[gate]; + + if (pipe == NCI_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + return nci_hci_send_data(ndev, pipe, + NCI_HCP_HEADER(NCI_HCI_HCP_EVENT, event), + param, param_len); +} +EXPORT_SYMBOL(nci_hci_send_event); + +int nci_hci_send_cmd(struct nci_dev *ndev, u8 gate, u8 cmd, + const u8 *param, size_t param_len, + struct sk_buff **skb) +{ + const struct nci_hcp_message *message; + const struct nci_conn_info *conn_info; + struct nci_data data; + int r; + u8 pipe = ndev->hci_dev->gate2pipe[gate]; + + if (pipe == NCI_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + data.conn_id = conn_info->conn_id; + data.pipe = pipe; + data.cmd = NCI_HCP_HEADER(NCI_HCI_HCP_COMMAND, cmd); + data.data = param; + data.data_len = param_len; + + r = nci_request(ndev, nci_hci_send_data_req, &data, + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + if (r == NCI_STATUS_OK) { + message = (struct nci_hcp_message *)conn_info->rx_skb->data; + r = nci_hci_result_to_errno( + NCI_HCP_MSG_GET_CMD(message->header)); + skb_pull(conn_info->rx_skb, NCI_HCI_HCP_MESSAGE_HEADER_LEN); + + if (!r && skb) + *skb = conn_info->rx_skb; + } + + return r; +} +EXPORT_SYMBOL(nci_hci_send_cmd); + +int nci_hci_clear_all_pipes(struct nci_dev *ndev) +{ + int r; + + r = nci_hci_send_cmd(ndev, NCI_HCI_ADMIN_GATE, + NCI_HCI_ADM_CLEAR_ALL_PIPE, NULL, 0, NULL); + if (r < 0) + return r; + + nci_hci_reset_pipes(ndev->hci_dev); + return r; +} +EXPORT_SYMBOL(nci_hci_clear_all_pipes); + +static void 
nci_hci_event_received(struct nci_dev *ndev, u8 pipe, + u8 event, struct sk_buff *skb) +{ + if (ndev->ops->hci_event_received) + ndev->ops->hci_event_received(ndev, pipe, event, skb); +} + +static void nci_hci_cmd_received(struct nci_dev *ndev, u8 pipe, + u8 cmd, struct sk_buff *skb) +{ + u8 gate = ndev->hci_dev->pipes[pipe].gate; + u8 status = NCI_HCI_ANY_OK | ~NCI_HCI_FRAGMENT; + u8 dest_gate, new_pipe; + struct nci_hci_create_pipe_resp *create_info; + struct nci_hci_delete_pipe_noti *delete_info; + struct nci_hci_all_pipe_cleared_noti *cleared_info; + + pr_debug("from gate %x pipe %x cmd %x\n", gate, pipe, cmd); + + switch (cmd) { + case NCI_HCI_ADM_NOTIFY_PIPE_CREATED: + if (skb->len != 5) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + create_info = (struct nci_hci_create_pipe_resp *)skb->data; + dest_gate = create_info->dest_gate; + new_pipe = create_info->pipe; + if (new_pipe >= NCI_HCI_MAX_PIPES) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + + /* Save the new created pipe and bind with local gate, + * the description for skb->data[3] is destination gate id + * but since we received this cmd from host controller, we + * are the destination and it is our local gate + */ + ndev->hci_dev->gate2pipe[dest_gate] = new_pipe; + ndev->hci_dev->pipes[new_pipe].gate = dest_gate; + ndev->hci_dev->pipes[new_pipe].host = + create_info->src_host; + break; + case NCI_HCI_ANY_OPEN_PIPE: + /* If the pipe is not created report an error */ + if (gate == NCI_HCI_INVALID_GATE) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + break; + case NCI_HCI_ADM_NOTIFY_PIPE_DELETED: + if (skb->len != 1) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + delete_info = (struct nci_hci_delete_pipe_noti *)skb->data; + if (delete_info->pipe >= NCI_HCI_MAX_PIPES) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + + ndev->hci_dev->pipes[delete_info->pipe].gate = + NCI_HCI_INVALID_GATE; + ndev->hci_dev->pipes[delete_info->pipe].host = + NCI_HCI_INVALID_HOST; + break; + case NCI_HCI_ADM_NOTIFY_ALL_PIPE_CLEARED: + if (skb->len != 1) { + status = NCI_HCI_ANY_E_NOK; + goto exit; + } + + cleared_info = + (struct nci_hci_all_pipe_cleared_noti *)skb->data; + nci_hci_reset_pipes_per_host(ndev, cleared_info->host); + break; + default: + pr_debug("Discarded unknown cmd %x to gate %x\n", cmd, gate); + break; + } + + if (ndev->ops->hci_cmd_received) + ndev->ops->hci_cmd_received(ndev, pipe, cmd, skb); + +exit: + nci_hci_send_data(ndev, pipe, status, NULL, 0); + + kfree_skb(skb); +} + +static void nci_hci_resp_received(struct nci_dev *ndev, u8 pipe, + struct sk_buff *skb) +{ + struct nci_conn_info *conn_info; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + goto exit; + + conn_info->rx_skb = skb; + +exit: + nci_req_complete(ndev, NCI_STATUS_OK); +} + +/* Receive hcp message for pipe, with type and cmd. + * skb contains optional message data only. 
+ */ +static void nci_hci_hcp_message_rx(struct nci_dev *ndev, u8 pipe, + u8 type, u8 instruction, struct sk_buff *skb) +{ + switch (type) { + case NCI_HCI_HCP_RESPONSE: + nci_hci_resp_received(ndev, pipe, skb); + break; + case NCI_HCI_HCP_COMMAND: + nci_hci_cmd_received(ndev, pipe, instruction, skb); + break; + case NCI_HCI_HCP_EVENT: + nci_hci_event_received(ndev, pipe, instruction, skb); + break; + default: + pr_err("UNKNOWN MSG Type %d, instruction=%d\n", + type, instruction); + kfree_skb(skb); + break; + } + + nci_req_complete(ndev, NCI_STATUS_OK); +} + +static void nci_hci_msg_rx_work(struct work_struct *work) +{ + struct nci_hci_dev *hdev = + container_of(work, struct nci_hci_dev, msg_rx_work); + struct sk_buff *skb; + const struct nci_hcp_message *message; + u8 pipe, type, instruction; + + while ((skb = skb_dequeue(&hdev->msg_rx_queue)) != NULL) { + pipe = NCI_HCP_MSG_GET_PIPE(skb->data[0]); + skb_pull(skb, NCI_HCI_HCP_PACKET_HEADER_LEN); + message = (struct nci_hcp_message *)skb->data; + type = NCI_HCP_MSG_GET_TYPE(message->header); + instruction = NCI_HCP_MSG_GET_CMD(message->header); + skb_pull(skb, NCI_HCI_HCP_MESSAGE_HEADER_LEN); + + nci_hci_hcp_message_rx(hdev->ndev, pipe, + type, instruction, skb); + } +} + +void nci_hci_data_received_cb(void *context, + struct sk_buff *skb, int err) +{ + struct nci_dev *ndev = (struct nci_dev *)context; + struct nci_hcp_packet *packet; + u8 pipe, type; + struct sk_buff *hcp_skb; + struct sk_buff *frag_skb; + int msg_len; + + if (err) { + nci_req_complete(ndev, err); + return; + } + + packet = (struct nci_hcp_packet *)skb->data; + if ((packet->header & ~NCI_HCI_FRAGMENT) == 0) { + skb_queue_tail(&ndev->hci_dev->rx_hcp_frags, skb); + return; + } + + /* it's the last fragment. Does it need re-aggregation? */ + if (skb_queue_len(&ndev->hci_dev->rx_hcp_frags)) { + pipe = NCI_HCP_MSG_GET_PIPE(packet->header); + skb_queue_tail(&ndev->hci_dev->rx_hcp_frags, skb); + + msg_len = 0; + skb_queue_walk(&ndev->hci_dev->rx_hcp_frags, frag_skb) { + msg_len += (frag_skb->len - + NCI_HCI_HCP_PACKET_HEADER_LEN); + } + + hcp_skb = nfc_alloc_recv_skb(NCI_HCI_HCP_PACKET_HEADER_LEN + + msg_len, GFP_KERNEL); + if (!hcp_skb) { + nci_req_complete(ndev, -ENOMEM); + return; + } + + skb_put_u8(hcp_skb, pipe); + + skb_queue_walk(&ndev->hci_dev->rx_hcp_frags, frag_skb) { + msg_len = frag_skb->len - NCI_HCI_HCP_PACKET_HEADER_LEN; + skb_put_data(hcp_skb, + frag_skb->data + NCI_HCI_HCP_PACKET_HEADER_LEN, + msg_len); + } + + skb_queue_purge(&ndev->hci_dev->rx_hcp_frags); + } else { + packet->header &= NCI_HCI_FRAGMENT; + hcp_skb = skb; + } + + /* if this is a response, dispatch immediately to + * unblock waiting cmd context. Otherwise, enqueue to dispatch + * in separate context where handler can also execute command. 
+ */ + packet = (struct nci_hcp_packet *)hcp_skb->data; + type = NCI_HCP_MSG_GET_TYPE(packet->message.header); + if (type == NCI_HCI_HCP_RESPONSE) { + pipe = NCI_HCP_MSG_GET_PIPE(packet->header); + skb_pull(hcp_skb, NCI_HCI_HCP_PACKET_HEADER_LEN); + nci_hci_hcp_message_rx(ndev, pipe, type, + NCI_STATUS_OK, hcp_skb); + } else { + skb_queue_tail(&ndev->hci_dev->msg_rx_queue, hcp_skb); + schedule_work(&ndev->hci_dev->msg_rx_work); + } +} + +int nci_hci_open_pipe(struct nci_dev *ndev, u8 pipe) +{ + struct nci_data data; + const struct nci_conn_info *conn_info; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + data.conn_id = conn_info->conn_id; + data.pipe = pipe; + data.cmd = NCI_HCP_HEADER(NCI_HCI_HCP_COMMAND, + NCI_HCI_ANY_OPEN_PIPE); + data.data = NULL; + data.data_len = 0; + + return nci_request(ndev, nci_hci_send_data_req, &data, + msecs_to_jiffies(NCI_DATA_TIMEOUT)); +} +EXPORT_SYMBOL(nci_hci_open_pipe); + +static u8 nci_hci_create_pipe(struct nci_dev *ndev, u8 dest_host, + u8 dest_gate, int *result) +{ + u8 pipe; + struct sk_buff *skb; + struct nci_hci_create_pipe_params params; + const struct nci_hci_create_pipe_resp *resp; + + pr_debug("gate=%d\n", dest_gate); + + params.src_gate = NCI_HCI_ADMIN_GATE; + params.dest_host = dest_host; + params.dest_gate = dest_gate; + + *result = nci_hci_send_cmd(ndev, NCI_HCI_ADMIN_GATE, + NCI_HCI_ADM_CREATE_PIPE, + (u8 *)¶ms, sizeof(params), &skb); + if (*result < 0) + return NCI_HCI_INVALID_PIPE; + + resp = (struct nci_hci_create_pipe_resp *)skb->data; + pipe = resp->pipe; + kfree_skb(skb); + + pr_debug("pipe created=%d\n", pipe); + + return pipe; +} + +static int nci_hci_delete_pipe(struct nci_dev *ndev, u8 pipe) +{ + return nci_hci_send_cmd(ndev, NCI_HCI_ADMIN_GATE, + NCI_HCI_ADM_DELETE_PIPE, &pipe, 1, NULL); +} + +int nci_hci_set_param(struct nci_dev *ndev, u8 gate, u8 idx, + const u8 *param, size_t param_len) +{ + const struct nci_hcp_message *message; + const struct nci_conn_info *conn_info; + struct nci_data data; + int r; + u8 *tmp; + u8 pipe = ndev->hci_dev->gate2pipe[gate]; + + pr_debug("idx=%d to gate %d\n", idx, gate); + + if (pipe == NCI_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + tmp = kmalloc(1 + param_len, GFP_KERNEL); + if (!tmp) + return -ENOMEM; + + *tmp = idx; + memcpy(tmp + 1, param, param_len); + + data.conn_id = conn_info->conn_id; + data.pipe = pipe; + data.cmd = NCI_HCP_HEADER(NCI_HCI_HCP_COMMAND, + NCI_HCI_ANY_SET_PARAMETER); + data.data = tmp; + data.data_len = param_len + 1; + + r = nci_request(ndev, nci_hci_send_data_req, &data, + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + if (r == NCI_STATUS_OK) { + message = (struct nci_hcp_message *)conn_info->rx_skb->data; + r = nci_hci_result_to_errno( + NCI_HCP_MSG_GET_CMD(message->header)); + skb_pull(conn_info->rx_skb, NCI_HCI_HCP_MESSAGE_HEADER_LEN); + } + + kfree(tmp); + return r; +} +EXPORT_SYMBOL(nci_hci_set_param); + +int nci_hci_get_param(struct nci_dev *ndev, u8 gate, u8 idx, + struct sk_buff **skb) +{ + const struct nci_hcp_message *message; + const struct nci_conn_info *conn_info; + struct nci_data data; + int r; + u8 pipe = ndev->hci_dev->gate2pipe[gate]; + + pr_debug("idx=%d to gate %d\n", idx, gate); + + if (pipe == NCI_HCI_INVALID_PIPE) + return -EADDRNOTAVAIL; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + data.conn_id = conn_info->conn_id; + data.pipe = pipe; + data.cmd = NCI_HCP_HEADER(NCI_HCI_HCP_COMMAND, + 
NCI_HCI_ANY_GET_PARAMETER); + data.data = &idx; + data.data_len = 1; + + r = nci_request(ndev, nci_hci_send_data_req, &data, + msecs_to_jiffies(NCI_DATA_TIMEOUT)); + + if (r == NCI_STATUS_OK) { + message = (struct nci_hcp_message *)conn_info->rx_skb->data; + r = nci_hci_result_to_errno( + NCI_HCP_MSG_GET_CMD(message->header)); + skb_pull(conn_info->rx_skb, NCI_HCI_HCP_MESSAGE_HEADER_LEN); + + if (!r && skb) + *skb = conn_info->rx_skb; + } + + return r; +} +EXPORT_SYMBOL(nci_hci_get_param); + +int nci_hci_connect_gate(struct nci_dev *ndev, + u8 dest_host, u8 dest_gate, u8 pipe) +{ + bool pipe_created = false; + int r; + + if (pipe == NCI_HCI_DO_NOT_OPEN_PIPE) + return 0; + + if (ndev->hci_dev->gate2pipe[dest_gate] != NCI_HCI_INVALID_PIPE) + return -EADDRINUSE; + + if (pipe != NCI_HCI_INVALID_PIPE) + goto open_pipe; + + switch (dest_gate) { + case NCI_HCI_LINK_MGMT_GATE: + pipe = NCI_HCI_LINK_MGMT_PIPE; + break; + case NCI_HCI_ADMIN_GATE: + pipe = NCI_HCI_ADMIN_PIPE; + break; + default: + pipe = nci_hci_create_pipe(ndev, dest_host, dest_gate, &r); + if (pipe == NCI_HCI_INVALID_PIPE) + return r; + pipe_created = true; + break; + } + +open_pipe: + r = nci_hci_open_pipe(ndev, pipe); + if (r < 0) { + if (pipe_created) { + if (nci_hci_delete_pipe(ndev, pipe) < 0) { + /* TODO: Cannot clean by deleting pipe... + * -> inconsistent state + */ + } + } + return r; + } + + ndev->hci_dev->pipes[pipe].gate = dest_gate; + ndev->hci_dev->pipes[pipe].host = dest_host; + ndev->hci_dev->gate2pipe[dest_gate] = pipe; + + return 0; +} +EXPORT_SYMBOL(nci_hci_connect_gate); + +static int nci_hci_dev_connect_gates(struct nci_dev *ndev, + u8 gate_count, + const struct nci_hci_gate *gates) +{ + int r; + + while (gate_count--) { + r = nci_hci_connect_gate(ndev, gates->dest_host, + gates->gate, gates->pipe); + if (r < 0) + return r; + gates++; + } + + return 0; +} + +int nci_hci_dev_session_init(struct nci_dev *ndev) +{ + struct nci_conn_info *conn_info; + struct sk_buff *skb; + int r; + + ndev->hci_dev->count_pipes = 0; + ndev->hci_dev->expected_pipes = 0; + + conn_info = ndev->hci_dev->conn_info; + if (!conn_info) + return -EPROTO; + + conn_info->data_exchange_cb = nci_hci_data_received_cb; + conn_info->data_exchange_cb_context = ndev; + + nci_hci_reset_pipes(ndev->hci_dev); + + if (ndev->hci_dev->init_data.gates[0].gate != NCI_HCI_ADMIN_GATE) + return -EPROTO; + + r = nci_hci_connect_gate(ndev, + ndev->hci_dev->init_data.gates[0].dest_host, + ndev->hci_dev->init_data.gates[0].gate, + ndev->hci_dev->init_data.gates[0].pipe); + if (r < 0) + return r; + + r = nci_hci_get_param(ndev, NCI_HCI_ADMIN_GATE, + NCI_HCI_ADMIN_PARAM_SESSION_IDENTITY, &skb); + if (r < 0) + return r; + + if (skb->len && + skb->len == strlen(ndev->hci_dev->init_data.session_id) && + !memcmp(ndev->hci_dev->init_data.session_id, skb->data, skb->len) && + ndev->ops->hci_load_session) { + /* Restore gate<->pipe table from some proprietary location. 
*/ + r = ndev->ops->hci_load_session(ndev); + } else { + r = nci_hci_clear_all_pipes(ndev); + if (r < 0) + goto exit; + + r = nci_hci_dev_connect_gates(ndev, + ndev->hci_dev->init_data.gate_count, + ndev->hci_dev->init_data.gates); + if (r < 0) + goto exit; + + r = nci_hci_set_param(ndev, NCI_HCI_ADMIN_GATE, + NCI_HCI_ADMIN_PARAM_SESSION_IDENTITY, + ndev->hci_dev->init_data.session_id, + strlen(ndev->hci_dev->init_data.session_id)); + } + +exit: + kfree_skb(skb); + + return r; +} +EXPORT_SYMBOL(nci_hci_dev_session_init); + +struct nci_hci_dev *nci_hci_allocate(struct nci_dev *ndev) +{ + struct nci_hci_dev *hdev; + + hdev = kzalloc(sizeof(*hdev), GFP_KERNEL); + if (!hdev) + return NULL; + + skb_queue_head_init(&hdev->rx_hcp_frags); + INIT_WORK(&hdev->msg_rx_work, nci_hci_msg_rx_work); + skb_queue_head_init(&hdev->msg_rx_queue); + hdev->ndev = ndev; + + return hdev; +} + +void nci_hci_deallocate(struct nci_dev *ndev) +{ + kfree(ndev->hci_dev); +} diff --git a/net/nfc/nci/lib.c b/net/nfc/nci/lib.c new file mode 100644 index 000000000..473323f80 --- /dev/null +++ b/net/nfc/nci/lib.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller (NFCC) and a Device Host (DH). + * + * Copyright (C) 2011 Texas Instruments, Inc. + * + * Written by Ilan Elias + * + * Acknowledgements: + * This file is based on lib.c, which was written + * by Maxim Krasnyansky. + */ + +#include +#include +#include +#include + +#include +#include + +/* NCI status codes to Unix errno mapping */ +int nci_to_errno(__u8 code) +{ + switch (code) { + case NCI_STATUS_OK: + return 0; + + case NCI_STATUS_REJECTED: + return -EBUSY; + + case NCI_STATUS_RF_FRAME_CORRUPTED: + return -EBADMSG; + + case NCI_STATUS_NOT_INITIALIZED: + return -EHOSTDOWN; + + case NCI_STATUS_SYNTAX_ERROR: + case NCI_STATUS_SEMANTIC_ERROR: + case NCI_STATUS_INVALID_PARAM: + case NCI_STATUS_RF_PROTOCOL_ERROR: + case NCI_STATUS_NFCEE_PROTOCOL_ERROR: + return -EPROTO; + + case NCI_STATUS_UNKNOWN_GID: + case NCI_STATUS_UNKNOWN_OID: + return -EBADRQC; + + case NCI_STATUS_MESSAGE_SIZE_EXCEEDED: + return -EMSGSIZE; + + case NCI_STATUS_DISCOVERY_ALREADY_STARTED: + return -EALREADY; + + case NCI_STATUS_DISCOVERY_TARGET_ACTIVATION_FAILED: + case NCI_STATUS_NFCEE_INTERFACE_ACTIVATION_FAILED: + return -ECONNREFUSED; + + case NCI_STATUS_RF_TRANSMISSION_ERROR: + case NCI_STATUS_NFCEE_TRANSMISSION_ERROR: + return -ECOMM; + + case NCI_STATUS_RF_TIMEOUT_ERROR: + case NCI_STATUS_NFCEE_TIMEOUT_ERROR: + return -ETIMEDOUT; + + case NCI_STATUS_FAILED: + default: + return -ENOSYS; + } +} +EXPORT_SYMBOL(nci_to_errno); diff --git a/net/nfc/nci/ntf.c b/net/nfc/nci/ntf.c new file mode 100644 index 000000000..994a0a1ef --- /dev/null +++ b/net/nfc/nci/ntf.c @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller (NFCC) and a Device Host (DH). + * + * Copyright (C) 2014 Marvell International Ltd. + * Copyright (C) 2011 Texas Instruments, Inc. + * + * Written by Ilan Elias + * + * Acknowledgements: + * This file is based on hci_event.c, which was written + * by Maxim Krasnyansky. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include +#include + +#include "../nfc.h" +#include +#include +#include + +/* Handle NCI Notification packets */ + +static void nci_core_reset_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + /* Handle NCI 2.x core reset notification */ + const struct nci_core_reset_ntf *ntf = (void *)skb->data; + + ndev->nci_ver = ntf->nci_ver; + pr_debug("nci_ver 0x%x, config_status 0x%x\n", + ntf->nci_ver, ntf->config_status); + + ndev->manufact_id = ntf->manufact_id; + ndev->manufact_specific_info = + __le32_to_cpu(ntf->manufact_specific_info); + + nci_req_complete(ndev, NCI_STATUS_OK); +} + +static void nci_core_conn_credits_ntf_packet(struct nci_dev *ndev, + struct sk_buff *skb) +{ + struct nci_core_conn_credit_ntf *ntf = (void *) skb->data; + struct nci_conn_info *conn_info; + int i; + + pr_debug("num_entries %d\n", ntf->num_entries); + + if (ntf->num_entries > NCI_MAX_NUM_CONN) + ntf->num_entries = NCI_MAX_NUM_CONN; + + /* update the credits */ + for (i = 0; i < ntf->num_entries; i++) { + ntf->conn_entries[i].conn_id = + nci_conn_id(&ntf->conn_entries[i].conn_id); + + pr_debug("entry[%d]: conn_id %d, credits %d\n", + i, ntf->conn_entries[i].conn_id, + ntf->conn_entries[i].credits); + + conn_info = nci_get_conn_info_by_conn_id(ndev, + ntf->conn_entries[i].conn_id); + if (!conn_info) + return; + + atomic_add(ntf->conn_entries[i].credits, + &conn_info->credits_cnt); + } + + /* trigger the next tx */ + if (!skb_queue_empty(&ndev->tx_q)) + queue_work(ndev->tx_wq, &ndev->tx_work); +} + +static void nci_core_generic_error_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + + if (atomic_read(&ndev->state) == NCI_W4_HOST_SELECT) { + /* Activation failed, so complete the request + (the state remains the same) */ + nci_req_complete(ndev, status); + } +} + +static void nci_core_conn_intf_error_ntf_packet(struct nci_dev *ndev, + struct sk_buff *skb) +{ + struct nci_core_intf_error_ntf *ntf = (void *) skb->data; + + ntf->conn_id = nci_conn_id(&ntf->conn_id); + + pr_debug("status 0x%x, conn_id %d\n", ntf->status, ntf->conn_id); + + /* complete the data exchange transaction, if exists */ + if (test_bit(NCI_DATA_EXCHANGE, &ndev->flags)) + nci_data_exchange_complete(ndev, NULL, ntf->conn_id, -EIO); +} + +static const __u8 * +nci_extract_rf_params_nfca_passive_poll(struct nci_dev *ndev, + struct rf_tech_specific_params_nfca_poll *nfca_poll, + const __u8 *data) +{ + nfca_poll->sens_res = __le16_to_cpu(*((__le16 *)data)); + data += 2; + + nfca_poll->nfcid1_len = min_t(__u8, *data++, NFC_NFCID1_MAXSIZE); + + pr_debug("sens_res 0x%x, nfcid1_len %d\n", + nfca_poll->sens_res, nfca_poll->nfcid1_len); + + memcpy(nfca_poll->nfcid1, data, nfca_poll->nfcid1_len); + data += nfca_poll->nfcid1_len; + + nfca_poll->sel_res_len = *data++; + + if (nfca_poll->sel_res_len != 0) + nfca_poll->sel_res = *data++; + + pr_debug("sel_res_len %d, sel_res 0x%x\n", + nfca_poll->sel_res_len, + nfca_poll->sel_res); + + return data; +} + +static const __u8 * +nci_extract_rf_params_nfcb_passive_poll(struct nci_dev *ndev, + struct rf_tech_specific_params_nfcb_poll *nfcb_poll, + const __u8 *data) +{ + nfcb_poll->sensb_res_len = min_t(__u8, *data++, NFC_SENSB_RES_MAXSIZE); + + pr_debug("sensb_res_len %d\n", nfcb_poll->sensb_res_len); + + memcpy(nfcb_poll->sensb_res, data, nfcb_poll->sensb_res_len); + data += nfcb_poll->sensb_res_len; + + return data; +} + +static 
const __u8 * +nci_extract_rf_params_nfcf_passive_poll(struct nci_dev *ndev, + struct rf_tech_specific_params_nfcf_poll *nfcf_poll, + const __u8 *data) +{ + nfcf_poll->bit_rate = *data++; + nfcf_poll->sensf_res_len = min_t(__u8, *data++, NFC_SENSF_RES_MAXSIZE); + + pr_debug("bit_rate %d, sensf_res_len %d\n", + nfcf_poll->bit_rate, nfcf_poll->sensf_res_len); + + memcpy(nfcf_poll->sensf_res, data, nfcf_poll->sensf_res_len); + data += nfcf_poll->sensf_res_len; + + return data; +} + +static const __u8 * +nci_extract_rf_params_nfcv_passive_poll(struct nci_dev *ndev, + struct rf_tech_specific_params_nfcv_poll *nfcv_poll, + const __u8 *data) +{ + ++data; + nfcv_poll->dsfid = *data++; + memcpy(nfcv_poll->uid, data, NFC_ISO15693_UID_MAXSIZE); + data += NFC_ISO15693_UID_MAXSIZE; + return data; +} + +static const __u8 * +nci_extract_rf_params_nfcf_passive_listen(struct nci_dev *ndev, + struct rf_tech_specific_params_nfcf_listen *nfcf_listen, + const __u8 *data) +{ + nfcf_listen->local_nfcid2_len = min_t(__u8, *data++, + NFC_NFCID2_MAXSIZE); + memcpy(nfcf_listen->local_nfcid2, data, nfcf_listen->local_nfcid2_len); + data += nfcf_listen->local_nfcid2_len; + + return data; +} + +static __u32 nci_get_prop_rf_protocol(struct nci_dev *ndev, __u8 rf_protocol) +{ + if (ndev->ops->get_rfprotocol) + return ndev->ops->get_rfprotocol(ndev, rf_protocol); + return 0; +} + +static int nci_add_new_protocol(struct nci_dev *ndev, + struct nfc_target *target, + __u8 rf_protocol, + __u8 rf_tech_and_mode, + const void *params) +{ + const struct rf_tech_specific_params_nfca_poll *nfca_poll; + const struct rf_tech_specific_params_nfcb_poll *nfcb_poll; + const struct rf_tech_specific_params_nfcf_poll *nfcf_poll; + const struct rf_tech_specific_params_nfcv_poll *nfcv_poll; + __u32 protocol; + + if (rf_protocol == NCI_RF_PROTOCOL_T1T) + protocol = NFC_PROTO_JEWEL_MASK; + else if (rf_protocol == NCI_RF_PROTOCOL_T2T) + protocol = NFC_PROTO_MIFARE_MASK; + else if (rf_protocol == NCI_RF_PROTOCOL_ISO_DEP) + if (rf_tech_and_mode == NCI_NFC_A_PASSIVE_POLL_MODE) + protocol = NFC_PROTO_ISO14443_MASK; + else + protocol = NFC_PROTO_ISO14443_B_MASK; + else if (rf_protocol == NCI_RF_PROTOCOL_T3T) + protocol = NFC_PROTO_FELICA_MASK; + else if (rf_protocol == NCI_RF_PROTOCOL_NFC_DEP) + protocol = NFC_PROTO_NFC_DEP_MASK; + else if (rf_protocol == NCI_RF_PROTOCOL_T5T) + protocol = NFC_PROTO_ISO15693_MASK; + else + protocol = nci_get_prop_rf_protocol(ndev, rf_protocol); + + if (!(protocol & ndev->poll_prots)) { + pr_err("the target found does not have the desired protocol\n"); + return -EPROTO; + } + + if (rf_tech_and_mode == NCI_NFC_A_PASSIVE_POLL_MODE) { + nfca_poll = (struct rf_tech_specific_params_nfca_poll *)params; + + target->sens_res = nfca_poll->sens_res; + target->sel_res = nfca_poll->sel_res; + target->nfcid1_len = nfca_poll->nfcid1_len; + if (target->nfcid1_len > ARRAY_SIZE(target->nfcid1)) + return -EPROTO; + if (target->nfcid1_len > 0) { + memcpy(target->nfcid1, nfca_poll->nfcid1, + target->nfcid1_len); + } + } else if (rf_tech_and_mode == NCI_NFC_B_PASSIVE_POLL_MODE) { + nfcb_poll = (struct rf_tech_specific_params_nfcb_poll *)params; + + target->sensb_res_len = nfcb_poll->sensb_res_len; + if (target->sensb_res_len > ARRAY_SIZE(target->sensb_res)) + return -EPROTO; + if (target->sensb_res_len > 0) { + memcpy(target->sensb_res, nfcb_poll->sensb_res, + target->sensb_res_len); + } + } else if (rf_tech_and_mode == NCI_NFC_F_PASSIVE_POLL_MODE) { + nfcf_poll = (struct rf_tech_specific_params_nfcf_poll *)params; + + target->sensf_res_len 
= nfcf_poll->sensf_res_len; + if (target->sensf_res_len > ARRAY_SIZE(target->sensf_res)) + return -EPROTO; + if (target->sensf_res_len > 0) { + memcpy(target->sensf_res, nfcf_poll->sensf_res, + target->sensf_res_len); + } + } else if (rf_tech_and_mode == NCI_NFC_V_PASSIVE_POLL_MODE) { + nfcv_poll = (struct rf_tech_specific_params_nfcv_poll *)params; + + target->is_iso15693 = 1; + target->iso15693_dsfid = nfcv_poll->dsfid; + memcpy(target->iso15693_uid, nfcv_poll->uid, NFC_ISO15693_UID_MAXSIZE); + } else { + pr_err("unsupported rf_tech_and_mode 0x%x\n", rf_tech_and_mode); + return -EPROTO; + } + + target->supported_protocols |= protocol; + + pr_debug("protocol 0x%x\n", protocol); + + return 0; +} + +static void nci_add_new_target(struct nci_dev *ndev, + const struct nci_rf_discover_ntf *ntf) +{ + struct nfc_target *target; + int i, rc; + + for (i = 0; i < ndev->n_targets; i++) { + target = &ndev->targets[i]; + if (target->logical_idx == ntf->rf_discovery_id) { + /* This target already exists, add the new protocol */ + nci_add_new_protocol(ndev, target, ntf->rf_protocol, + ntf->rf_tech_and_mode, + &ntf->rf_tech_specific_params); + return; + } + } + + /* This is a new target, check if we've enough room */ + if (ndev->n_targets == NCI_MAX_DISCOVERED_TARGETS) { + pr_debug("not enough room, ignoring new target...\n"); + return; + } + + target = &ndev->targets[ndev->n_targets]; + + rc = nci_add_new_protocol(ndev, target, ntf->rf_protocol, + ntf->rf_tech_and_mode, + &ntf->rf_tech_specific_params); + if (!rc) { + target->logical_idx = ntf->rf_discovery_id; + ndev->n_targets++; + + pr_debug("logical idx %d, n_targets %d\n", target->logical_idx, + ndev->n_targets); + } +} + +void nci_clear_target_list(struct nci_dev *ndev) +{ + memset(ndev->targets, 0, + (sizeof(struct nfc_target)*NCI_MAX_DISCOVERED_TARGETS)); + + ndev->n_targets = 0; +} + +static void nci_rf_discover_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + struct nci_rf_discover_ntf ntf; + const __u8 *data = skb->data; + bool add_target = true; + + ntf.rf_discovery_id = *data++; + ntf.rf_protocol = *data++; + ntf.rf_tech_and_mode = *data++; + ntf.rf_tech_specific_params_len = *data++; + + pr_debug("rf_discovery_id %d\n", ntf.rf_discovery_id); + pr_debug("rf_protocol 0x%x\n", ntf.rf_protocol); + pr_debug("rf_tech_and_mode 0x%x\n", ntf.rf_tech_and_mode); + pr_debug("rf_tech_specific_params_len %d\n", + ntf.rf_tech_specific_params_len); + + if (ntf.rf_tech_specific_params_len > 0) { + switch (ntf.rf_tech_and_mode) { + case NCI_NFC_A_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfca_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfca_poll), data); + break; + + case NCI_NFC_B_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcb_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcb_poll), data); + break; + + case NCI_NFC_F_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcf_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcf_poll), data); + break; + + case NCI_NFC_V_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcv_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcv_poll), data); + break; + + default: + pr_err("unsupported rf_tech_and_mode 0x%x\n", + ntf.rf_tech_and_mode); + data += ntf.rf_tech_specific_params_len; + add_target = false; + } + } + + ntf.ntf_type = *data++; + pr_debug("ntf_type %d\n", ntf.ntf_type); + + if (add_target == true) + nci_add_new_target(ndev, &ntf); + + if (ntf.ntf_type == NCI_DISCOVER_NTF_TYPE_MORE) { + atomic_set(&ndev->state, NCI_W4_ALL_DISCOVERIES); + } else { + 
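nci_add_new_protocol() above is essentially a table lookup: the RF protocol reported by the controller is translated into an NFC protocol bit and rejected unless that bit intersects the mask the poll was started with. A minimal standalone sketch of that check follows; the enum values are illustrative placeholders rather than the real constants from the NFC/NCI headers, and the ISO-DEP A/B distinction and the proprietary-protocol fallback are omitted for brevity.

#include <stdio.h>

/* Placeholder values -- the real codes live in the NCI and NFC headers. */
enum rf_proto { RF_T1T = 1, RF_T2T, RF_T3T, RF_ISO_DEP, RF_NFC_DEP, RF_T5T };
enum proto_bit {
	PROTO_JEWEL    = 1 << 0, PROTO_MIFARE  = 1 << 1, PROTO_FELICA   = 1 << 2,
	PROTO_ISO14443 = 1 << 3, PROTO_NFC_DEP = 1 << 4, PROTO_ISO15693 = 1 << 5,
};

/* Map an RF protocol to a protocol bit; accept it only if the poll asked for it. */
static int match_protocol(enum rf_proto rf, unsigned int poll_mask)
{
	unsigned int bit;

	switch (rf) {
	case RF_T1T:     bit = PROTO_JEWEL;    break;
	case RF_T2T:     bit = PROTO_MIFARE;   break;
	case RF_T3T:     bit = PROTO_FELICA;   break;
	case RF_ISO_DEP: bit = PROTO_ISO14443; break;
	case RF_NFC_DEP: bit = PROTO_NFC_DEP;  break;
	case RF_T5T:     bit = PROTO_ISO15693; break;
	default:         return -1;            /* unknown/proprietary protocol */
	}
	return (bit & poll_mask) ? (int)bit : -1;
}

int main(void)
{
	unsigned int poll_mask = PROTO_MIFARE | PROTO_ISO14443;

	printf("T2T -> %d\n", match_protocol(RF_T2T, poll_mask));	/* accepted */
	printf("T3T -> %d\n", match_protocol(RF_T3T, poll_mask));	/* rejected */
	return 0;
}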
atomic_set(&ndev->state, NCI_W4_HOST_SELECT); + nfc_targets_found(ndev->nfc_dev, ndev->targets, + ndev->n_targets); + } +} + +static int nci_extract_activation_params_iso_dep(struct nci_dev *ndev, + struct nci_rf_intf_activated_ntf *ntf, + const __u8 *data) +{ + struct activation_params_nfca_poll_iso_dep *nfca_poll; + struct activation_params_nfcb_poll_iso_dep *nfcb_poll; + + switch (ntf->activation_rf_tech_and_mode) { + case NCI_NFC_A_PASSIVE_POLL_MODE: + nfca_poll = &ntf->activation_params.nfca_poll_iso_dep; + nfca_poll->rats_res_len = min_t(__u8, *data++, 20); + pr_debug("rats_res_len %d\n", nfca_poll->rats_res_len); + if (nfca_poll->rats_res_len > 0) { + memcpy(nfca_poll->rats_res, + data, nfca_poll->rats_res_len); + } + break; + + case NCI_NFC_B_PASSIVE_POLL_MODE: + nfcb_poll = &ntf->activation_params.nfcb_poll_iso_dep; + nfcb_poll->attrib_res_len = min_t(__u8, *data++, 50); + pr_debug("attrib_res_len %d\n", nfcb_poll->attrib_res_len); + if (nfcb_poll->attrib_res_len > 0) { + memcpy(nfcb_poll->attrib_res, + data, nfcb_poll->attrib_res_len); + } + break; + + default: + pr_err("unsupported activation_rf_tech_and_mode 0x%x\n", + ntf->activation_rf_tech_and_mode); + return NCI_STATUS_RF_PROTOCOL_ERROR; + } + + return NCI_STATUS_OK; +} + +static int nci_extract_activation_params_nfc_dep(struct nci_dev *ndev, + struct nci_rf_intf_activated_ntf *ntf, + const __u8 *data) +{ + struct activation_params_poll_nfc_dep *poll; + struct activation_params_listen_nfc_dep *listen; + + switch (ntf->activation_rf_tech_and_mode) { + case NCI_NFC_A_PASSIVE_POLL_MODE: + case NCI_NFC_F_PASSIVE_POLL_MODE: + poll = &ntf->activation_params.poll_nfc_dep; + poll->atr_res_len = min_t(__u8, *data++, + NFC_ATR_RES_MAXSIZE - 2); + pr_debug("atr_res_len %d\n", poll->atr_res_len); + if (poll->atr_res_len > 0) + memcpy(poll->atr_res, data, poll->atr_res_len); + break; + + case NCI_NFC_A_PASSIVE_LISTEN_MODE: + case NCI_NFC_F_PASSIVE_LISTEN_MODE: + listen = &ntf->activation_params.listen_nfc_dep; + listen->atr_req_len = min_t(__u8, *data++, + NFC_ATR_REQ_MAXSIZE - 2); + pr_debug("atr_req_len %d\n", listen->atr_req_len); + if (listen->atr_req_len > 0) + memcpy(listen->atr_req, data, listen->atr_req_len); + break; + + default: + pr_err("unsupported activation_rf_tech_and_mode 0x%x\n", + ntf->activation_rf_tech_and_mode); + return NCI_STATUS_RF_PROTOCOL_ERROR; + } + + return NCI_STATUS_OK; +} + +static void nci_target_auto_activated(struct nci_dev *ndev, + const struct nci_rf_intf_activated_ntf *ntf) +{ + struct nfc_target *target; + int rc; + + target = &ndev->targets[ndev->n_targets]; + + rc = nci_add_new_protocol(ndev, target, ntf->rf_protocol, + ntf->activation_rf_tech_and_mode, + &ntf->rf_tech_specific_params); + if (rc) + return; + + target->logical_idx = ntf->rf_discovery_id; + ndev->n_targets++; + + pr_debug("logical idx %d, n_targets %d\n", + target->logical_idx, ndev->n_targets); + + nfc_targets_found(ndev->nfc_dev, ndev->targets, ndev->n_targets); +} + +static int nci_store_general_bytes_nfc_dep(struct nci_dev *ndev, + const struct nci_rf_intf_activated_ntf *ntf) +{ + ndev->remote_gb_len = 0; + + if (ntf->activation_params_len <= 0) + return NCI_STATUS_OK; + + switch (ntf->activation_rf_tech_and_mode) { + case NCI_NFC_A_PASSIVE_POLL_MODE: + case NCI_NFC_F_PASSIVE_POLL_MODE: + ndev->remote_gb_len = min_t(__u8, + (ntf->activation_params.poll_nfc_dep.atr_res_len + - NFC_ATR_RES_GT_OFFSET), + NFC_ATR_RES_GB_MAXSIZE); + memcpy(ndev->remote_gb, + (ntf->activation_params.poll_nfc_dep.atr_res + + NFC_ATR_RES_GT_OFFSET), + 
ndev->remote_gb_len); + break; + + case NCI_NFC_A_PASSIVE_LISTEN_MODE: + case NCI_NFC_F_PASSIVE_LISTEN_MODE: + ndev->remote_gb_len = min_t(__u8, + (ntf->activation_params.listen_nfc_dep.atr_req_len + - NFC_ATR_REQ_GT_OFFSET), + NFC_ATR_REQ_GB_MAXSIZE); + memcpy(ndev->remote_gb, + (ntf->activation_params.listen_nfc_dep.atr_req + + NFC_ATR_REQ_GT_OFFSET), + ndev->remote_gb_len); + break; + + default: + pr_err("unsupported activation_rf_tech_and_mode 0x%x\n", + ntf->activation_rf_tech_and_mode); + return NCI_STATUS_RF_PROTOCOL_ERROR; + } + + return NCI_STATUS_OK; +} + +static void nci_rf_intf_activated_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + struct nci_conn_info *conn_info; + struct nci_rf_intf_activated_ntf ntf; + const __u8 *data = skb->data; + int err = NCI_STATUS_OK; + + ntf.rf_discovery_id = *data++; + ntf.rf_interface = *data++; + ntf.rf_protocol = *data++; + ntf.activation_rf_tech_and_mode = *data++; + ntf.max_data_pkt_payload_size = *data++; + ntf.initial_num_credits = *data++; + ntf.rf_tech_specific_params_len = *data++; + + pr_debug("rf_discovery_id %d\n", ntf.rf_discovery_id); + pr_debug("rf_interface 0x%x\n", ntf.rf_interface); + pr_debug("rf_protocol 0x%x\n", ntf.rf_protocol); + pr_debug("activation_rf_tech_and_mode 0x%x\n", + ntf.activation_rf_tech_and_mode); + pr_debug("max_data_pkt_payload_size 0x%x\n", + ntf.max_data_pkt_payload_size); + pr_debug("initial_num_credits 0x%x\n", + ntf.initial_num_credits); + pr_debug("rf_tech_specific_params_len %d\n", + ntf.rf_tech_specific_params_len); + + /* If this contains a value of 0x00 (NFCEE Direct RF + * Interface) then all following parameters SHALL contain a + * value of 0 and SHALL be ignored. + */ + if (ntf.rf_interface == NCI_RF_INTERFACE_NFCEE_DIRECT) + goto listen; + + if (ntf.rf_tech_specific_params_len > 0) { + switch (ntf.activation_rf_tech_and_mode) { + case NCI_NFC_A_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfca_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfca_poll), data); + break; + + case NCI_NFC_B_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcb_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcb_poll), data); + break; + + case NCI_NFC_F_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcf_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcf_poll), data); + break; + + case NCI_NFC_V_PASSIVE_POLL_MODE: + data = nci_extract_rf_params_nfcv_passive_poll(ndev, + &(ntf.rf_tech_specific_params.nfcv_poll), data); + break; + + case NCI_NFC_A_PASSIVE_LISTEN_MODE: + /* no RF technology specific parameters */ + break; + + case NCI_NFC_F_PASSIVE_LISTEN_MODE: + data = nci_extract_rf_params_nfcf_passive_listen(ndev, + &(ntf.rf_tech_specific_params.nfcf_listen), + data); + break; + + default: + pr_err("unsupported activation_rf_tech_and_mode 0x%x\n", + ntf.activation_rf_tech_and_mode); + err = NCI_STATUS_RF_PROTOCOL_ERROR; + goto exit; + } + } + + ntf.data_exch_rf_tech_and_mode = *data++; + ntf.data_exch_tx_bit_rate = *data++; + ntf.data_exch_rx_bit_rate = *data++; + ntf.activation_params_len = *data++; + + pr_debug("data_exch_rf_tech_and_mode 0x%x\n", + ntf.data_exch_rf_tech_and_mode); + pr_debug("data_exch_tx_bit_rate 0x%x\n", ntf.data_exch_tx_bit_rate); + pr_debug("data_exch_rx_bit_rate 0x%x\n", ntf.data_exch_rx_bit_rate); + pr_debug("activation_params_len %d\n", ntf.activation_params_len); + + if (ntf.activation_params_len > 0) { + switch (ntf.rf_interface) { + case NCI_RF_INTERFACE_ISO_DEP: + err = nci_extract_activation_params_iso_dep(ndev, + &ntf, data); + break; + + case 
NCI_RF_INTERFACE_NFC_DEP: + err = nci_extract_activation_params_nfc_dep(ndev, + &ntf, data); + break; + + case NCI_RF_INTERFACE_FRAME: + /* no activation params */ + break; + + default: + pr_err("unsupported rf_interface 0x%x\n", + ntf.rf_interface); + err = NCI_STATUS_RF_PROTOCOL_ERROR; + break; + } + } + +exit: + if (err == NCI_STATUS_OK) { + conn_info = ndev->rf_conn_info; + if (!conn_info) + return; + + conn_info->max_pkt_payload_len = ntf.max_data_pkt_payload_size; + conn_info->initial_num_credits = ntf.initial_num_credits; + + /* set the available credits to initial value */ + atomic_set(&conn_info->credits_cnt, + conn_info->initial_num_credits); + + /* store general bytes to be reported later in dep_link_up */ + if (ntf.rf_interface == NCI_RF_INTERFACE_NFC_DEP) { + err = nci_store_general_bytes_nfc_dep(ndev, &ntf); + if (err != NCI_STATUS_OK) + pr_err("unable to store general bytes\n"); + } + } + + if (!(ntf.activation_rf_tech_and_mode & NCI_RF_TECH_MODE_LISTEN_MASK)) { + /* Poll mode */ + if (atomic_read(&ndev->state) == NCI_DISCOVERY) { + /* A single target was found and activated + * automatically */ + atomic_set(&ndev->state, NCI_POLL_ACTIVE); + if (err == NCI_STATUS_OK) + nci_target_auto_activated(ndev, &ntf); + } else { /* ndev->state == NCI_W4_HOST_SELECT */ + /* A selected target was activated, so complete the + * request */ + atomic_set(&ndev->state, NCI_POLL_ACTIVE); + nci_req_complete(ndev, err); + } + } else { +listen: + /* Listen mode */ + atomic_set(&ndev->state, NCI_LISTEN_ACTIVE); + if (err == NCI_STATUS_OK && + ntf.rf_protocol == NCI_RF_PROTOCOL_NFC_DEP) { + err = nfc_tm_activated(ndev->nfc_dev, + NFC_PROTO_NFC_DEP_MASK, + NFC_COMM_PASSIVE, + ndev->remote_gb, + ndev->remote_gb_len); + if (err != NCI_STATUS_OK) + pr_err("error when signaling tm activation\n"); + } + } +} + +static void nci_rf_deactivate_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_conn_info *conn_info; + const struct nci_rf_deactivate_ntf *ntf = (void *)skb->data; + + pr_debug("entry, type 0x%x, reason 0x%x\n", ntf->type, ntf->reason); + + conn_info = ndev->rf_conn_info; + if (!conn_info) + return; + + /* drop tx data queue */ + skb_queue_purge(&ndev->tx_q); + + /* drop partial rx data packet */ + if (ndev->rx_data_reassembly) { + kfree_skb(ndev->rx_data_reassembly); + ndev->rx_data_reassembly = NULL; + } + + /* complete the data exchange transaction, if exists */ + if (test_bit(NCI_DATA_EXCHANGE, &ndev->flags)) + nci_data_exchange_complete(ndev, NULL, NCI_STATIC_RF_CONN_ID, + -EIO); + + switch (ntf->type) { + case NCI_DEACTIVATE_TYPE_IDLE_MODE: + nci_clear_target_list(ndev); + atomic_set(&ndev->state, NCI_IDLE); + break; + case NCI_DEACTIVATE_TYPE_SLEEP_MODE: + case NCI_DEACTIVATE_TYPE_SLEEP_AF_MODE: + atomic_set(&ndev->state, NCI_W4_HOST_SELECT); + break; + case NCI_DEACTIVATE_TYPE_DISCOVERY: + nci_clear_target_list(ndev); + atomic_set(&ndev->state, NCI_DISCOVERY); + break; + } + + nci_req_complete(ndev, NCI_STATUS_OK); +} + +static void nci_nfcee_discover_ntf_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + u8 status = NCI_STATUS_OK; + const struct nci_nfcee_discover_ntf *nfcee_ntf = + (struct nci_nfcee_discover_ntf *)skb->data; + + /* NFCForum NCI 9.2.1 HCI Network Specific Handling + * If the NFCC supports the HCI Network, it SHALL return one, + * and only one, NFCEE_DISCOVER_NTF with a Protocol type of + * “HCI Access”, even if the HCI Network contains multiple NFCEEs. 
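The RF_DEACTIVATE_NTF handler above reduces to a small state table: the deactivation type selects the next RF state and decides whether the cached target list survives (it does for the sleep modes, where the host may still select another discovered target). A compact restatement of that table as standalone C, purely for illustration and with hypothetical enum values rather than the kernel's constants:

#include <stdio.h>

/* Hypothetical values -- the real ones come from the NCI headers. */
enum deact_type { DEACT_IDLE, DEACT_SLEEP, DEACT_SLEEP_AF, DEACT_DISCOVERY };
enum rf_state   { ST_IDLE, ST_DISCOVERY, ST_W4_HOST_SELECT };

struct rf_ctx {
	enum rf_state state;
	int n_targets;			/* cached discovery results */
};

static void handle_deactivate(struct rf_ctx *ctx, enum deact_type type)
{
	switch (type) {
	case DEACT_IDLE:		/* everything is forgotten */
		ctx->n_targets = 0;
		ctx->state = ST_IDLE;
		break;
	case DEACT_SLEEP:
	case DEACT_SLEEP_AF:		/* targets kept, host may select another one */
		ctx->state = ST_W4_HOST_SELECT;
		break;
	case DEACT_DISCOVERY:		/* back to polling, forget stale targets */
		ctx->n_targets = 0;
		ctx->state = ST_DISCOVERY;
		break;
	}
}

int main(void)
{
	struct rf_ctx ctx = { .state = ST_DISCOVERY, .n_targets = 3 };

	handle_deactivate(&ctx, DEACT_SLEEP);
	printf("state=%d targets=%d\n", ctx.state, ctx.n_targets);	/* 2 3 */
	handle_deactivate(&ctx, DEACT_IDLE);
	printf("state=%d targets=%d\n", ctx.state, ctx.n_targets);	/* 0 0 */
	return 0;
}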
+ */ + ndev->hci_dev->nfcee_id = nfcee_ntf->nfcee_id; + ndev->cur_params.id = nfcee_ntf->nfcee_id; + + nci_req_complete(ndev, status); +} + +void nci_ntf_packet(struct nci_dev *ndev, struct sk_buff *skb) +{ + __u16 ntf_opcode = nci_opcode(skb->data); + + pr_debug("NCI RX: MT=ntf, PBF=%d, GID=0x%x, OID=0x%x, plen=%d\n", + nci_pbf(skb->data), + nci_opcode_gid(ntf_opcode), + nci_opcode_oid(ntf_opcode), + nci_plen(skb->data)); + + /* strip the nci control header */ + skb_pull(skb, NCI_CTRL_HDR_SIZE); + + if (nci_opcode_gid(ntf_opcode) == NCI_GID_PROPRIETARY) { + if (nci_prop_ntf_packet(ndev, ntf_opcode, skb) == -ENOTSUPP) { + pr_err("unsupported ntf opcode 0x%x\n", + ntf_opcode); + } + + goto end; + } + + switch (ntf_opcode) { + case NCI_OP_CORE_RESET_NTF: + nci_core_reset_ntf_packet(ndev, skb); + break; + + case NCI_OP_CORE_CONN_CREDITS_NTF: + nci_core_conn_credits_ntf_packet(ndev, skb); + break; + + case NCI_OP_CORE_GENERIC_ERROR_NTF: + nci_core_generic_error_ntf_packet(ndev, skb); + break; + + case NCI_OP_CORE_INTF_ERROR_NTF: + nci_core_conn_intf_error_ntf_packet(ndev, skb); + break; + + case NCI_OP_RF_DISCOVER_NTF: + nci_rf_discover_ntf_packet(ndev, skb); + break; + + case NCI_OP_RF_INTF_ACTIVATED_NTF: + nci_rf_intf_activated_ntf_packet(ndev, skb); + break; + + case NCI_OP_RF_DEACTIVATE_NTF: + nci_rf_deactivate_ntf_packet(ndev, skb); + break; + + case NCI_OP_NFCEE_DISCOVER_NTF: + nci_nfcee_discover_ntf_packet(ndev, skb); + break; + + case NCI_OP_RF_NFCEE_ACTION_NTF: + break; + + default: + pr_err("unknown ntf opcode 0x%x\n", ntf_opcode); + break; + } + + nci_core_ntf_packet(ndev, ntf_opcode, skb); +end: + kfree_skb(skb); +} diff --git a/net/nfc/nci/rsp.c b/net/nfc/nci/rsp.c new file mode 100644 index 000000000..b911ab78b --- /dev/null +++ b/net/nfc/nci/rsp.c @@ -0,0 +1,428 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * The NFC Controller Interface is the communication protocol between an + * NFC Controller (NFCC) and a Device Host (DH). + * + * Copyright (C) 2011 Texas Instruments, Inc. + * + * Written by Ilan Elias + * + * Acknowledgements: + * This file is based on hci_event.c, which was written + * by Maxim Krasnyansky. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include +#include + +#include "../nfc.h" +#include +#include + +/* Handle NCI Response packets */ + +static void nci_core_reset_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_core_reset_rsp *rsp = (void *)skb->data; + + pr_debug("status 0x%x\n", rsp->status); + + /* Handle NCI 1.x ver */ + if (skb->len != 1) { + if (rsp->status == NCI_STATUS_OK) { + ndev->nci_ver = rsp->nci_ver; + pr_debug("nci_ver 0x%x, config_status 0x%x\n", + rsp->nci_ver, rsp->config_status); + } + + nci_req_complete(ndev, rsp->status); + } +} + +static u8 nci_core_init_rsp_packet_v1(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_core_init_rsp_1 *rsp_1 = (void *)skb->data; + const struct nci_core_init_rsp_2 *rsp_2; + + pr_debug("status 0x%x\n", rsp_1->status); + + if (rsp_1->status != NCI_STATUS_OK) + return rsp_1->status; + + ndev->nfcc_features = __le32_to_cpu(rsp_1->nfcc_features); + ndev->num_supported_rf_interfaces = rsp_1->num_supported_rf_interfaces; + + ndev->num_supported_rf_interfaces = + min((int)ndev->num_supported_rf_interfaces, + NCI_MAX_SUPPORTED_RF_INTERFACES); + + memcpy(ndev->supported_rf_interfaces, + rsp_1->supported_rf_interfaces, + ndev->num_supported_rf_interfaces); + + rsp_2 = (void *) (skb->data + 6 + rsp_1->num_supported_rf_interfaces); + + ndev->max_logical_connections = rsp_2->max_logical_connections; + ndev->max_routing_table_size = + __le16_to_cpu(rsp_2->max_routing_table_size); + ndev->max_ctrl_pkt_payload_len = + rsp_2->max_ctrl_pkt_payload_len; + ndev->max_size_for_large_params = + __le16_to_cpu(rsp_2->max_size_for_large_params); + ndev->manufact_id = + rsp_2->manufact_id; + ndev->manufact_specific_info = + __le32_to_cpu(rsp_2->manufact_specific_info); + + return NCI_STATUS_OK; +} + +static u8 nci_core_init_rsp_packet_v2(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_core_init_rsp_nci_ver2 *rsp = (void *)skb->data; + const u8 *supported_rf_interface = rsp->supported_rf_interfaces; + u8 rf_interface_idx = 0; + u8 rf_extension_cnt = 0; + + pr_debug("status %x\n", rsp->status); + + if (rsp->status != NCI_STATUS_OK) + return rsp->status; + + ndev->nfcc_features = __le32_to_cpu(rsp->nfcc_features); + ndev->num_supported_rf_interfaces = rsp->num_supported_rf_interfaces; + + ndev->num_supported_rf_interfaces = + min((int)ndev->num_supported_rf_interfaces, + NCI_MAX_SUPPORTED_RF_INTERFACES); + + while (rf_interface_idx < ndev->num_supported_rf_interfaces) { + ndev->supported_rf_interfaces[rf_interface_idx++] = *supported_rf_interface++; + + /* skip rf extension parameters */ + rf_extension_cnt = *supported_rf_interface++; + supported_rf_interface += rf_extension_cnt; + } + + ndev->max_logical_connections = rsp->max_logical_connections; + ndev->max_routing_table_size = + __le16_to_cpu(rsp->max_routing_table_size); + ndev->max_ctrl_pkt_payload_len = + rsp->max_ctrl_pkt_payload_len; + ndev->max_size_for_large_params = NCI_MAX_LARGE_PARAMS_NCI_v2; + + return NCI_STATUS_OK; +} + +static void nci_core_init_rsp_packet(struct nci_dev *ndev, const struct sk_buff *skb) +{ + u8 status = 0; + + if (!(ndev->nci_ver & NCI_VER_2_MASK)) + status = nci_core_init_rsp_packet_v1(ndev, skb); + else + status = nci_core_init_rsp_packet_v2(ndev, skb); + + if (status != NCI_STATUS_OK) + goto exit; + + pr_debug("nfcc_features 0x%x\n", + ndev->nfcc_features); + pr_debug("num_supported_rf_interfaces %d\n", + ndev->num_supported_rf_interfaces); + 
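nci_core_init_rsp_packet_v2() above walks a variable-length list: each supported RF interface entry carries the interface identifier, then a count of extension parameters, then that many extension bytes that are skipped. A standalone sketch of the same walk over a byte buffer is given below; it is a hedged illustration only, with simplified bounds handling, and does not claim to reproduce the driver's exact parsing.

#include <stdio.h>

/* Walk "id, ext_count, ext_count bytes" records and collect the interface ids. */
static int parse_rf_interfaces(const unsigned char *buf, int buf_len,
			       int num_entries, unsigned char *out, int out_max)
{
	int pos = 0, n = 0;

	while (n < num_entries && n < out_max) {
		if (pos + 2 > buf_len)
			return -1;		/* truncated entry */
		out[n++] = buf[pos++];		/* interface id */
		pos += 1 + buf[pos];		/* skip the count byte and the extensions */
		if (pos > buf_len)
			return -1;
	}
	return n;
}

int main(void)
{
	/* Two entries: id 0x01 with no extensions, id 0x02 with two extension bytes. */
	const unsigned char rsp[] = { 0x01, 0x00, 0x02, 0x02, 0xaa, 0xbb };
	unsigned char ifaces[4];
	int n = parse_rf_interfaces(rsp, sizeof(rsp), 2, ifaces, 4);

	for (int i = 0; i < n; i++)
		printf("rf interface 0x%02x\n", ifaces[i]);
	return 0;
}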
pr_debug("supported_rf_interfaces[0] 0x%x\n", + ndev->supported_rf_interfaces[0]); + pr_debug("supported_rf_interfaces[1] 0x%x\n", + ndev->supported_rf_interfaces[1]); + pr_debug("supported_rf_interfaces[2] 0x%x\n", + ndev->supported_rf_interfaces[2]); + pr_debug("supported_rf_interfaces[3] 0x%x\n", + ndev->supported_rf_interfaces[3]); + pr_debug("max_logical_connections %d\n", + ndev->max_logical_connections); + pr_debug("max_routing_table_size %d\n", + ndev->max_routing_table_size); + pr_debug("max_ctrl_pkt_payload_len %d\n", + ndev->max_ctrl_pkt_payload_len); + pr_debug("max_size_for_large_params %d\n", + ndev->max_size_for_large_params); + pr_debug("manufact_id 0x%x\n", + ndev->manufact_id); + pr_debug("manufact_specific_info 0x%x\n", + ndev->manufact_specific_info); + +exit: + nci_req_complete(ndev, status); +} + +static void nci_core_set_config_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_core_set_config_rsp *rsp = (void *)skb->data; + + pr_debug("status 0x%x\n", rsp->status); + + nci_req_complete(ndev, rsp->status); +} + +static void nci_rf_disc_map_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + + nci_req_complete(ndev, status); +} + +static void nci_rf_disc_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + struct nci_conn_info *conn_info; + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + + if (status == NCI_STATUS_OK) { + atomic_set(&ndev->state, NCI_DISCOVERY); + + conn_info = ndev->rf_conn_info; + if (!conn_info) { + conn_info = devm_kzalloc(&ndev->nfc_dev->dev, + sizeof(struct nci_conn_info), + GFP_KERNEL); + if (!conn_info) { + status = NCI_STATUS_REJECTED; + goto exit; + } + conn_info->conn_id = NCI_STATIC_RF_CONN_ID; + INIT_LIST_HEAD(&conn_info->list); + list_add(&conn_info->list, &ndev->conn_info_list); + ndev->rf_conn_info = conn_info; + } + } + +exit: + nci_req_complete(ndev, status); +} + +static void nci_rf_disc_select_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + + /* Complete the request on intf_activated_ntf or generic_error_ntf */ + if (status != NCI_STATUS_OK) + nci_req_complete(ndev, status); +} + +static void nci_rf_deactivate_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + + /* If target was active, complete the request only in deactivate_ntf */ + if ((status != NCI_STATUS_OK) || + (atomic_read(&ndev->state) != NCI_POLL_ACTIVE)) { + nci_clear_target_list(ndev); + atomic_set(&ndev->state, NCI_IDLE); + nci_req_complete(ndev, status); + } +} + +static void nci_nfcee_discover_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + const struct nci_nfcee_discover_rsp *discover_rsp; + + if (skb->len != 2) { + nci_req_complete(ndev, NCI_STATUS_NFCEE_PROTOCOL_ERROR); + return; + } + + discover_rsp = (struct nci_nfcee_discover_rsp *)skb->data; + + if (discover_rsp->status != NCI_STATUS_OK || + discover_rsp->num_nfcee == 0) + nci_req_complete(ndev, discover_rsp->status); +} + +static void nci_nfcee_mode_set_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + nci_req_complete(ndev, status); +} + +static void nci_core_conn_create_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + __u8 status = skb->data[0]; + struct 
nci_conn_info *conn_info = NULL; + const struct nci_core_conn_create_rsp *rsp; + + pr_debug("status 0x%x\n", status); + + if (status == NCI_STATUS_OK) { + rsp = (struct nci_core_conn_create_rsp *)skb->data; + + conn_info = devm_kzalloc(&ndev->nfc_dev->dev, + sizeof(*conn_info), GFP_KERNEL); + if (!conn_info) { + status = NCI_STATUS_REJECTED; + goto exit; + } + + conn_info->dest_params = devm_kzalloc(&ndev->nfc_dev->dev, + sizeof(struct dest_spec_params), + GFP_KERNEL); + if (!conn_info->dest_params) { + status = NCI_STATUS_REJECTED; + goto free_conn_info; + } + + conn_info->dest_type = ndev->cur_dest_type; + conn_info->dest_params->id = ndev->cur_params.id; + conn_info->dest_params->protocol = ndev->cur_params.protocol; + conn_info->conn_id = rsp->conn_id; + + /* Note: data_exchange_cb and data_exchange_cb_context need to + * be specify out of nci_core_conn_create_rsp_packet + */ + + INIT_LIST_HEAD(&conn_info->list); + list_add(&conn_info->list, &ndev->conn_info_list); + + if (ndev->cur_params.id == ndev->hci_dev->nfcee_id) + ndev->hci_dev->conn_info = conn_info; + + conn_info->conn_id = rsp->conn_id; + conn_info->max_pkt_payload_len = rsp->max_ctrl_pkt_payload_len; + atomic_set(&conn_info->credits_cnt, rsp->credits_cnt); + } + +free_conn_info: + if (status == NCI_STATUS_REJECTED) + devm_kfree(&ndev->nfc_dev->dev, conn_info); +exit: + + nci_req_complete(ndev, status); +} + +static void nci_core_conn_close_rsp_packet(struct nci_dev *ndev, + const struct sk_buff *skb) +{ + struct nci_conn_info *conn_info; + __u8 status = skb->data[0]; + + pr_debug("status 0x%x\n", status); + if (status == NCI_STATUS_OK) { + conn_info = nci_get_conn_info_by_conn_id(ndev, + ndev->cur_conn_id); + if (conn_info) { + list_del(&conn_info->list); + if (conn_info == ndev->rf_conn_info) + ndev->rf_conn_info = NULL; + devm_kfree(&ndev->nfc_dev->dev, conn_info); + } + } + nci_req_complete(ndev, status); +} + +void nci_rsp_packet(struct nci_dev *ndev, struct sk_buff *skb) +{ + __u16 rsp_opcode = nci_opcode(skb->data); + + /* we got a rsp, stop the cmd timer */ + del_timer(&ndev->cmd_timer); + + pr_debug("NCI RX: MT=rsp, PBF=%d, GID=0x%x, OID=0x%x, plen=%d\n", + nci_pbf(skb->data), + nci_opcode_gid(rsp_opcode), + nci_opcode_oid(rsp_opcode), + nci_plen(skb->data)); + + /* strip the nci control header */ + skb_pull(skb, NCI_CTRL_HDR_SIZE); + + if (nci_opcode_gid(rsp_opcode) == NCI_GID_PROPRIETARY) { + if (nci_prop_rsp_packet(ndev, rsp_opcode, skb) == -ENOTSUPP) { + pr_err("unsupported rsp opcode 0x%x\n", + rsp_opcode); + } + + goto end; + } + + switch (rsp_opcode) { + case NCI_OP_CORE_RESET_RSP: + nci_core_reset_rsp_packet(ndev, skb); + break; + + case NCI_OP_CORE_INIT_RSP: + nci_core_init_rsp_packet(ndev, skb); + break; + + case NCI_OP_CORE_SET_CONFIG_RSP: + nci_core_set_config_rsp_packet(ndev, skb); + break; + + case NCI_OP_CORE_CONN_CREATE_RSP: + nci_core_conn_create_rsp_packet(ndev, skb); + break; + + case NCI_OP_CORE_CONN_CLOSE_RSP: + nci_core_conn_close_rsp_packet(ndev, skb); + break; + + case NCI_OP_RF_DISCOVER_MAP_RSP: + nci_rf_disc_map_rsp_packet(ndev, skb); + break; + + case NCI_OP_RF_DISCOVER_RSP: + nci_rf_disc_rsp_packet(ndev, skb); + break; + + case NCI_OP_RF_DISCOVER_SELECT_RSP: + nci_rf_disc_select_rsp_packet(ndev, skb); + break; + + case NCI_OP_RF_DEACTIVATE_RSP: + nci_rf_deactivate_rsp_packet(ndev, skb); + break; + + case NCI_OP_NFCEE_DISCOVER_RSP: + nci_nfcee_discover_rsp_packet(ndev, skb); + break; + + case NCI_OP_NFCEE_MODE_SET_RSP: + nci_nfcee_mode_set_rsp_packet(ndev, skb); + break; + + default: + 
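nci_rsp_packet() and nci_ntf_packet() both start by pulling the message type, GID and OID out of the three-byte NCI control header before dispatching on the opcode. A small standalone decoder for that header is sketched below, assuming the layout defined by the NCI specification (MT/PBF/GID in octet 0, OID in octet 1, payload length in octet 2); the masks are written out explicitly here rather than taken from the kernel headers.

#include <stdio.h>

struct nci_ctrl_hdr_fields {
	unsigned int mt;	/* message type: data, command, response or notification */
	unsigned int pbf;	/* packet boundary flag: more segments follow */
	unsigned int gid;	/* group id */
	unsigned int oid;	/* opcode id within the group */
	unsigned int plen;	/* payload length */
};

static struct nci_ctrl_hdr_fields decode_ctrl_hdr(const unsigned char *pkt)
{
	struct nci_ctrl_hdr_fields h;

	h.mt   = (pkt[0] >> 5) & 0x07;
	h.pbf  = (pkt[0] >> 4) & 0x01;
	h.gid  =  pkt[0]       & 0x0f;
	h.oid  =  pkt[1]       & 0x3f;
	h.plen =  pkt[2];
	return h;
}

int main(void)
{
	/* 0x40 = response (MT=2) for GID 0, OID 0, with a 3-byte payload */
	const unsigned char pkt[] = { 0x40, 0x00, 0x03 };
	struct nci_ctrl_hdr_fields h = decode_ctrl_hdr(pkt);

	printf("MT=%u PBF=%u GID=0x%x OID=0x%x plen=%u\n",
	       h.mt, h.pbf, h.gid, h.oid, h.plen);
	return 0;
}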
pr_err("unknown rsp opcode 0x%x\n", rsp_opcode); + break; + } + + nci_core_rsp_packet(ndev, rsp_opcode, skb); +end: + kfree_skb(skb); + + /* trigger the next cmd */ + atomic_set(&ndev->cmd_cnt, 1); + if (!skb_queue_empty(&ndev->cmd_q)) + queue_work(ndev->cmd_wq, &ndev->cmd_work); +} diff --git a/net/nfc/nci/spi.c b/net/nfc/nci/spi.c new file mode 100644 index 000000000..b68150c97 --- /dev/null +++ b/net/nfc/nci/spi.c @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (C) 2013 Intel Corporation. All rights reserved. + */ + +#define pr_fmt(fmt) "nci_spi: %s: " fmt, __func__ + +#include + +#include +#include +#include +#include + +#define NCI_SPI_ACK_SHIFT 6 +#define NCI_SPI_MSB_PAYLOAD_MASK 0x3F + +#define NCI_SPI_SEND_TIMEOUT (NCI_CMD_TIMEOUT > NCI_DATA_TIMEOUT ? \ + NCI_CMD_TIMEOUT : NCI_DATA_TIMEOUT) + +#define NCI_SPI_DIRECT_WRITE 0x01 +#define NCI_SPI_DIRECT_READ 0x02 + +#define ACKNOWLEDGE_NONE 0 +#define ACKNOWLEDGE_ACK 1 +#define ACKNOWLEDGE_NACK 2 + +#define CRC_INIT 0xFFFF + +static int __nci_spi_send(struct nci_spi *nspi, const struct sk_buff *skb, + int cs_change) +{ + struct spi_message m; + struct spi_transfer t; + + memset(&t, 0, sizeof(struct spi_transfer)); + /* a NULL skb means we just want the SPI chip select line to raise */ + if (skb) { + t.tx_buf = skb->data; + t.len = skb->len; + } else { + /* still set tx_buf non NULL to make the driver happy */ + t.tx_buf = &t; + t.len = 0; + } + t.cs_change = cs_change; + t.delay.value = nspi->xfer_udelay; + t.delay.unit = SPI_DELAY_UNIT_USECS; + t.speed_hz = nspi->xfer_speed_hz; + + spi_message_init(&m); + spi_message_add_tail(&t, &m); + + return spi_sync(nspi->spi, &m); +} + +int nci_spi_send(struct nci_spi *nspi, + struct completion *write_handshake_completion, + struct sk_buff *skb) +{ + unsigned int payload_len = skb->len; + unsigned char *hdr; + int ret; + long completion_rc; + + /* add the NCI SPI header to the start of the buffer */ + hdr = skb_push(skb, NCI_SPI_HDR_LEN); + hdr[0] = NCI_SPI_DIRECT_WRITE; + hdr[1] = nspi->acknowledge_mode; + hdr[2] = payload_len >> 8; + hdr[3] = payload_len & 0xFF; + + if (nspi->acknowledge_mode == NCI_SPI_CRC_ENABLED) { + u16 crc; + + crc = crc_ccitt(CRC_INIT, skb->data, skb->len); + skb_put_u8(skb, crc >> 8); + skb_put_u8(skb, crc & 0xFF); + } + + if (write_handshake_completion) { + /* Trick SPI driver to raise chip select */ + ret = __nci_spi_send(nspi, NULL, 1); + if (ret) + goto done; + + /* wait for NFC chip hardware handshake to complete */ + if (wait_for_completion_timeout(write_handshake_completion, + msecs_to_jiffies(1000)) == 0) { + ret = -ETIME; + goto done; + } + } + + ret = __nci_spi_send(nspi, skb, 0); + if (ret != 0 || nspi->acknowledge_mode == NCI_SPI_CRC_DISABLED) + goto done; + + reinit_completion(&nspi->req_completion); + completion_rc = wait_for_completion_interruptible_timeout( + &nspi->req_completion, + NCI_SPI_SEND_TIMEOUT); + + if (completion_rc <= 0 || nspi->req_result == ACKNOWLEDGE_NACK) + ret = -EIO; + +done: + kfree_skb(skb); + + return ret; +} +EXPORT_SYMBOL_GPL(nci_spi_send); + +/* ---- Interface to NCI SPI drivers ---- */ + +/** + * nci_spi_allocate_spi - allocate a new nci spi + * + * @spi: SPI device + * @acknowledge_mode: Acknowledge mode used by the NFC device + * @delay: delay between transactions in us + * @ndev: nci dev to send incoming nci frames to + */ +struct nci_spi *nci_spi_allocate_spi(struct spi_device *spi, + u8 acknowledge_mode, unsigned int delay, + struct nci_dev *ndev) +{ + struct nci_spi *nspi; + + nspi = 
devm_kzalloc(&spi->dev, sizeof(struct nci_spi), GFP_KERNEL); + if (!nspi) + return NULL; + + nspi->acknowledge_mode = acknowledge_mode; + nspi->xfer_udelay = delay; + /* Use controller max SPI speed by default */ + nspi->xfer_speed_hz = 0; + nspi->spi = spi; + nspi->ndev = ndev; + init_completion(&nspi->req_completion); + + return nspi; +} +EXPORT_SYMBOL_GPL(nci_spi_allocate_spi); + +static int send_acknowledge(struct nci_spi *nspi, u8 acknowledge) +{ + struct sk_buff *skb; + unsigned char *hdr; + u16 crc; + int ret; + + skb = nci_skb_alloc(nspi->ndev, 0, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + /* add the NCI SPI header to the start of the buffer */ + hdr = skb_push(skb, NCI_SPI_HDR_LEN); + hdr[0] = NCI_SPI_DIRECT_WRITE; + hdr[1] = NCI_SPI_CRC_ENABLED; + hdr[2] = acknowledge << NCI_SPI_ACK_SHIFT; + hdr[3] = 0; + + crc = crc_ccitt(CRC_INIT, skb->data, skb->len); + skb_put_u8(skb, crc >> 8); + skb_put_u8(skb, crc & 0xFF); + + ret = __nci_spi_send(nspi, skb, 0); + + kfree_skb(skb); + + return ret; +} + +static struct sk_buff *__nci_spi_read(struct nci_spi *nspi) +{ + struct sk_buff *skb; + struct spi_message m; + unsigned char req[2], resp_hdr[2]; + struct spi_transfer tx, rx; + unsigned short rx_len = 0; + int ret; + + spi_message_init(&m); + + memset(&tx, 0, sizeof(struct spi_transfer)); + req[0] = NCI_SPI_DIRECT_READ; + req[1] = nspi->acknowledge_mode; + tx.tx_buf = req; + tx.len = 2; + tx.cs_change = 0; + tx.speed_hz = nspi->xfer_speed_hz; + spi_message_add_tail(&tx, &m); + + memset(&rx, 0, sizeof(struct spi_transfer)); + rx.rx_buf = resp_hdr; + rx.len = 2; + rx.cs_change = 1; + rx.speed_hz = nspi->xfer_speed_hz; + spi_message_add_tail(&rx, &m); + + ret = spi_sync(nspi->spi, &m); + if (ret) + return NULL; + + if (nspi->acknowledge_mode == NCI_SPI_CRC_ENABLED) + rx_len = ((resp_hdr[0] & NCI_SPI_MSB_PAYLOAD_MASK) << 8) + + resp_hdr[1] + NCI_SPI_CRC_LEN; + else + rx_len = (resp_hdr[0] << 8) | resp_hdr[1]; + + skb = nci_skb_alloc(nspi->ndev, rx_len, GFP_KERNEL); + if (!skb) + return NULL; + + spi_message_init(&m); + + memset(&rx, 0, sizeof(struct spi_transfer)); + rx.rx_buf = skb_put(skb, rx_len); + rx.len = rx_len; + rx.cs_change = 0; + rx.delay.value = nspi->xfer_udelay; + rx.delay.unit = SPI_DELAY_UNIT_USECS; + rx.speed_hz = nspi->xfer_speed_hz; + spi_message_add_tail(&rx, &m); + + ret = spi_sync(nspi->spi, &m); + if (ret) + goto receive_error; + + if (nspi->acknowledge_mode == NCI_SPI_CRC_ENABLED) { + *(u8 *)skb_push(skb, 1) = resp_hdr[1]; + *(u8 *)skb_push(skb, 1) = resp_hdr[0]; + } + + return skb; + +receive_error: + kfree_skb(skb); + + return NULL; +} + +static int nci_spi_check_crc(struct sk_buff *skb) +{ + u16 crc_data = (skb->data[skb->len - 2] << 8) | + skb->data[skb->len - 1]; + int ret; + + ret = (crc_ccitt(CRC_INIT, skb->data, skb->len - NCI_SPI_CRC_LEN) + == crc_data); + + skb_trim(skb, skb->len - NCI_SPI_CRC_LEN); + + return ret; +} + +static u8 nci_spi_get_ack(struct sk_buff *skb) +{ + u8 ret; + + ret = skb->data[0] >> NCI_SPI_ACK_SHIFT; + + /* Remove NFCC part of the header: ACK, NACK and MSB payload len */ + skb_pull(skb, 2); + + return ret; +} + +/** + * nci_spi_read - read frame from NCI SPI drivers + * + * @nspi: The nci spi + * Context: can sleep + * + * This call may only be used from a context that may sleep. The sleep + * is non-interruptible, and has no timeout. + * + * It returns an allocated skb containing the frame on success, or NULL. 
+ */ +struct sk_buff *nci_spi_read(struct nci_spi *nspi) +{ + struct sk_buff *skb; + + /* Retrieve frame from SPI */ + skb = __nci_spi_read(nspi); + if (!skb) + goto done; + + if (nspi->acknowledge_mode == NCI_SPI_CRC_ENABLED) { + if (!nci_spi_check_crc(skb)) { + send_acknowledge(nspi, ACKNOWLEDGE_NACK); + goto done; + } + + /* In case of acknowledged mode: if ACK or NACK received, + * unblock completion of latest frame sent. + */ + nspi->req_result = nci_spi_get_ack(skb); + if (nspi->req_result) + complete(&nspi->req_completion); + } + + /* If there is no payload (ACK/NACK only frame), + * free the socket buffer + */ + if (!skb->len) { + kfree_skb(skb); + skb = NULL; + goto done; + } + + if (nspi->acknowledge_mode == NCI_SPI_CRC_ENABLED) + send_acknowledge(nspi, ACKNOWLEDGE_ACK); + +done: + + return skb; +} +EXPORT_SYMBOL_GPL(nci_spi_read); + +MODULE_LICENSE("GPL"); diff --git a/net/nfc/nci/uart.c b/net/nfc/nci/uart.c new file mode 100644 index 000000000..cc8fa9e36 --- /dev/null +++ b/net/nfc/nci/uart.c @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2015, Marvell International Ltd. + * + * Inspired (hugely) by HCI LDISC implementation in Bluetooth. + * + * Copyright (C) 2000-2001 Qualcomm Incorporated + * Copyright (C) 2002-2003 Maxim Krasnyansky + * Copyright (C) 2004-2005 Marcel Holtmann + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +/* TX states */ +#define NCI_UART_SENDING 1 +#define NCI_UART_TX_WAKEUP 2 + +static struct nci_uart *nci_uart_drivers[NCI_UART_DRIVER_MAX]; + +static inline struct sk_buff *nci_uart_dequeue(struct nci_uart *nu) +{ + struct sk_buff *skb = nu->tx_skb; + + if (!skb) + skb = skb_dequeue(&nu->tx_q); + else + nu->tx_skb = NULL; + + return skb; +} + +static inline int nci_uart_queue_empty(struct nci_uart *nu) +{ + if (nu->tx_skb) + return 0; + + return skb_queue_empty(&nu->tx_q); +} + +static int nci_uart_tx_wakeup(struct nci_uart *nu) +{ + if (test_and_set_bit(NCI_UART_SENDING, &nu->tx_state)) { + set_bit(NCI_UART_TX_WAKEUP, &nu->tx_state); + return 0; + } + + schedule_work(&nu->write_work); + + return 0; +} + +static void nci_uart_write_work(struct work_struct *work) +{ + struct nci_uart *nu = container_of(work, struct nci_uart, write_work); + struct tty_struct *tty = nu->tty; + struct sk_buff *skb; + +restart: + clear_bit(NCI_UART_TX_WAKEUP, &nu->tx_state); + + if (nu->ops.tx_start) + nu->ops.tx_start(nu); + + while ((skb = nci_uart_dequeue(nu))) { + int len; + + set_bit(TTY_DO_WRITE_WAKEUP, &tty->flags); + len = tty->ops->write(tty, skb->data, skb->len); + skb_pull(skb, len); + if (skb->len) { + nu->tx_skb = skb; + break; + } + kfree_skb(skb); + } + + if (test_bit(NCI_UART_TX_WAKEUP, &nu->tx_state)) + goto restart; + + if (nu->ops.tx_done && nci_uart_queue_empty(nu)) + nu->ops.tx_done(nu); + + clear_bit(NCI_UART_SENDING, &nu->tx_state); +} + +static int nci_uart_set_driver(struct tty_struct *tty, unsigned int driver) +{ + struct nci_uart *nu = NULL; + int ret; + + if (driver >= NCI_UART_DRIVER_MAX) + return -EINVAL; + + if (!nci_uart_drivers[driver]) + return -ENOENT; + + nu = kzalloc(sizeof(*nu), GFP_KERNEL); + if (!nu) + return -ENOMEM; + + memcpy(nu, nci_uart_drivers[driver], sizeof(struct nci_uart)); + nu->tty = tty; + tty->disc_data = nu; + skb_queue_head_init(&nu->tx_q); + INIT_WORK(&nu->write_work, nci_uart_write_work); + spin_lock_init(&nu->rx_lock); + + ret = nu->ops.open(nu); 
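nci_spi_send() and __nci_spi_read() above frame every transfer as a four-byte header (direction, acknowledge mode, 16-bit payload length), the payload, and, in CRC mode, a big-endian CRC-CCITT computed over header plus payload; the ACK/NACK answer then comes back in the top two bits of the first response byte. The userspace sketch below builds and re-checks such a frame. The header byte values are illustrative, and the bitwise CRC is intended to match the kernel's crc_ccitt() (reflected polynomial 0x8408, seeded with 0xFFFF) but should be verified against lib/crc-ccitt.c before being relied on.

#include <stdio.h>
#include <string.h>

/* Bitwise form of the reflected CCITT CRC (polynomial 0x8408). */
static unsigned short crc_ccitt_sw(unsigned short crc,
				   const unsigned char *buf, size_t len)
{
	while (len--) {
		crc ^= *buf++;
		for (int i = 0; i < 8; i++)
			crc = (crc & 1) ? (crc >> 1) ^ 0x8408 : crc >> 1;
	}
	return crc;
}

/* Build a "direct write" frame: 4-byte header, payload, 2-byte CRC (big endian). */
static size_t build_frame(unsigned char *out, const unsigned char *payload,
			  size_t plen)
{
	unsigned short crc;
	size_t n = 0;

	out[n++] = 0x01;		/* direct write (illustrative value) */
	out[n++] = 0x01;		/* CRC/acknowledge mode enabled (illustrative) */
	out[n++] = plen >> 8;		/* payload length, MSB */
	out[n++] = plen & 0xff;		/* payload length, LSB */
	memcpy(out + n, payload, plen);
	n += plen;

	crc = crc_ccitt_sw(0xffff, out, n);
	out[n++] = crc >> 8;
	out[n++] = crc & 0xff;
	return n;
}

int main(void)
{
	const unsigned char payload[] = { 0x20, 0x00, 0x01, 0x00 };	/* arbitrary bytes */
	unsigned char frame[64];
	size_t len = build_frame(frame, payload, sizeof(payload));

	/* Receiver side: recompute the CRC over everything but the trailing two bytes. */
	unsigned short crc = crc_ccitt_sw(0xffff, frame, len - 2);
	unsigned short wire = (frame[len - 2] << 8) | frame[len - 1];

	printf("frame len %zu, crc %s\n", len, crc == wire ? "ok" : "bad");
	return 0;
}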
+ if (ret) { + tty->disc_data = NULL; + kfree(nu); + } else if (!try_module_get(nu->owner)) { + nu->ops.close(nu); + tty->disc_data = NULL; + kfree(nu); + return -ENOENT; + } + return ret; +} + +/* ------ LDISC part ------ */ + +/* nci_uart_tty_open + * + * Called when line discipline changed to NCI_UART. + * + * Arguments: + * tty pointer to tty info structure + * Return Value: + * 0 if success, otherwise error code + */ +static int nci_uart_tty_open(struct tty_struct *tty) +{ + /* Error if the tty has no write op instead of leaving an exploitable + * hole + */ + if (!tty->ops->write) + return -EOPNOTSUPP; + + tty->disc_data = NULL; + tty->receive_room = 65536; + + /* Flush any pending characters in the driver */ + tty_driver_flush_buffer(tty); + + return 0; +} + +/* nci_uart_tty_close() + * + * Called when the line discipline is changed to something + * else, the tty is closed, or the tty detects a hangup. + */ +static void nci_uart_tty_close(struct tty_struct *tty) +{ + struct nci_uart *nu = (void *)tty->disc_data; + + /* Detach from the tty */ + tty->disc_data = NULL; + + if (!nu) + return; + + kfree_skb(nu->tx_skb); + kfree_skb(nu->rx_skb); + + skb_queue_purge(&nu->tx_q); + + nu->ops.close(nu); + nu->tty = NULL; + module_put(nu->owner); + + cancel_work_sync(&nu->write_work); + + kfree(nu); +} + +/* nci_uart_tty_wakeup() + * + * Callback for transmit wakeup. Called when low level + * device driver can accept more send data. + * + * Arguments: tty pointer to associated tty instance data + * Return Value: None + */ +static void nci_uart_tty_wakeup(struct tty_struct *tty) +{ + struct nci_uart *nu = (void *)tty->disc_data; + + if (!nu) + return; + + clear_bit(TTY_DO_WRITE_WAKEUP, &tty->flags); + + if (tty != nu->tty) + return; + + nci_uart_tx_wakeup(nu); +} + +/* -- Default recv_buf handler -- + * + * This handler supposes that NCI frames are sent over UART link without any + * framing. It reads NCI header, retrieve the packet size and once all packet + * bytes are received it passes it to nci_uart driver for processing. + */ +static int nci_uart_default_recv_buf(struct nci_uart *nu, const u8 *data, + int count) +{ + int chunk_len; + + if (!nu->ndev) { + nfc_err(nu->tty->dev, + "receive data from tty but no NCI dev is attached yet, drop buffer\n"); + return 0; + } + + /* Decode all incoming data in packets + * and enqueue then for processing. + */ + while (count > 0) { + /* If this is the first data of a packet, allocate a buffer */ + if (!nu->rx_skb) { + nu->rx_packet_len = -1; + nu->rx_skb = nci_skb_alloc(nu->ndev, + NCI_MAX_PACKET_SIZE, + GFP_ATOMIC); + if (!nu->rx_skb) + return -ENOMEM; + } + + /* Eat byte after byte till full packet header is received */ + if (nu->rx_skb->len < NCI_CTRL_HDR_SIZE) { + skb_put_u8(nu->rx_skb, *data++); + --count; + continue; + } + + /* Header was received but packet len was not read */ + if (nu->rx_packet_len < 0) + nu->rx_packet_len = NCI_CTRL_HDR_SIZE + + nci_plen(nu->rx_skb->data); + + /* Compute how many bytes are missing and how many bytes can + * be consumed. 
+ */ + chunk_len = nu->rx_packet_len - nu->rx_skb->len; + if (count < chunk_len) + chunk_len = count; + skb_put_data(nu->rx_skb, data, chunk_len); + data += chunk_len; + count -= chunk_len; + + /* Check if packet is fully received */ + if (nu->rx_packet_len == nu->rx_skb->len) { + /* Pass RX packet to driver */ + if (nu->ops.recv(nu, nu->rx_skb) != 0) + nfc_err(nu->tty->dev, "corrupted RX packet\n"); + /* Next packet will be a new one */ + nu->rx_skb = NULL; + } + } + + return 0; +} + +/* nci_uart_tty_receive() + * + * Called by tty low level driver when receive data is + * available. + * + * Arguments: tty pointer to tty instance data + * data pointer to received data + * flags pointer to flags for data + * count count of received data in bytes + * + * Return Value: None + */ +static void nci_uart_tty_receive(struct tty_struct *tty, const u8 *data, + const char *flags, int count) +{ + struct nci_uart *nu = (void *)tty->disc_data; + + if (!nu || tty != nu->tty) + return; + + spin_lock(&nu->rx_lock); + nci_uart_default_recv_buf(nu, data, count); + spin_unlock(&nu->rx_lock); + + tty_unthrottle(tty); +} + +/* nci_uart_tty_ioctl() + * + * Process IOCTL system call for the tty device. + * + * Arguments: + * + * tty pointer to tty instance data + * cmd IOCTL command code + * arg argument for IOCTL call (cmd dependent) + * + * Return Value: Command dependent + */ +static int nci_uart_tty_ioctl(struct tty_struct *tty, unsigned int cmd, + unsigned long arg) +{ + struct nci_uart *nu = (void *)tty->disc_data; + int err = 0; + + switch (cmd) { + case NCIUARTSETDRIVER: + if (!nu) + return nci_uart_set_driver(tty, (unsigned int)arg); + else + return -EBUSY; + break; + default: + err = n_tty_ioctl_helper(tty, cmd, arg); + break; + } + + return err; +} + +/* We don't provide read/write/poll interface for user space. 
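The default recv_buf handler above implements a small framing state machine: bytes are accumulated until the three-byte control header is complete, the payload length is read from it, and the packet is handed off once header plus payload bytes have arrived. A standalone restatement of that loop over an in-memory byte stream follows; it is illustrative only, and the buffer sizes and three-byte-header assumption simply mirror the code above.

#include <stdio.h>
#include <string.h>

#define HDR_SIZE 3			/* NCI control header, as used above */
#define MAX_PKT  (HDR_SIZE + 255)	/* one-byte payload length field */

struct reassembler {
	unsigned char pkt[MAX_PKT];
	int have;			/* bytes collected so far */
	int want;			/* total packet size, -1 until the header is in */
};

/* Feed arbitrary chunks; call deliver() for every complete packet. */
static void feed(struct reassembler *r, const unsigned char *data, int count,
		 void (*deliver)(const unsigned char *, int))
{
	while (count > 0) {
		int chunk;

		if (r->have < HDR_SIZE) {		/* still collecting the header */
			r->pkt[r->have++] = *data++;
			count--;
			if (r->have == HDR_SIZE)
				r->want = HDR_SIZE + r->pkt[2];
			continue;
		}
		chunk = r->want - r->have;		/* bytes still missing */
		if (chunk > count)
			chunk = count;
		memcpy(r->pkt + r->have, data, chunk);
		r->have += chunk;
		data += chunk;
		count -= chunk;

		if (r->have == r->want) {		/* full packet: deliver and reset */
			deliver(r->pkt, r->have);
			r->have = 0;
			r->want = -1;
		}
	}
}

static void print_packet(const unsigned char *pkt, int len)
{
	printf("packet of %d bytes, plen %u\n", len, pkt[2]);
}

int main(void)
{
	/* Two packets split across three arbitrary chunks. */
	const unsigned char stream[] = { 0x40, 0x00, 0x02, 0xaa, 0xbb,
					 0x60, 0x07, 0x01, 0x05 };
	struct reassembler r = { .have = 0, .want = -1 };

	feed(&r, stream, 4, print_packet);
	feed(&r, stream + 4, 3, print_packet);
	feed(&r, stream + 7, 2, print_packet);
	return 0;
}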
*/ +static ssize_t nci_uart_tty_read(struct tty_struct *tty, struct file *file, + unsigned char *buf, size_t nr, + void **cookie, unsigned long offset) +{ + return 0; +} + +static ssize_t nci_uart_tty_write(struct tty_struct *tty, struct file *file, + const unsigned char *data, size_t count) +{ + return 0; +} + +static __poll_t nci_uart_tty_poll(struct tty_struct *tty, + struct file *filp, poll_table *wait) +{ + return 0; +} + +static int nci_uart_send(struct nci_uart *nu, struct sk_buff *skb) +{ + /* Queue TX packet */ + skb_queue_tail(&nu->tx_q, skb); + + /* Try to start TX (if possible) */ + nci_uart_tx_wakeup(nu); + + return 0; +} + +int nci_uart_register(struct nci_uart *nu) +{ + if (!nu || !nu->ops.open || + !nu->ops.recv || !nu->ops.close) + return -EINVAL; + + /* Set the send callback */ + nu->ops.send = nci_uart_send; + + /* Add this driver in the driver list */ + if (nci_uart_drivers[nu->driver]) { + pr_err("driver %d is already registered\n", nu->driver); + return -EBUSY; + } + nci_uart_drivers[nu->driver] = nu; + + pr_info("NCI uart driver '%s [%d]' registered\n", nu->name, nu->driver); + + return 0; +} +EXPORT_SYMBOL_GPL(nci_uart_register); + +void nci_uart_unregister(struct nci_uart *nu) +{ + pr_info("NCI uart driver '%s [%d]' unregistered\n", nu->name, + nu->driver); + + /* Remove this driver from the driver list */ + nci_uart_drivers[nu->driver] = NULL; +} +EXPORT_SYMBOL_GPL(nci_uart_unregister); + +void nci_uart_set_config(struct nci_uart *nu, int baudrate, int flow_ctrl) +{ + struct ktermios new_termios; + + if (!nu->tty) + return; + + down_read(&nu->tty->termios_rwsem); + new_termios = nu->tty->termios; + up_read(&nu->tty->termios_rwsem); + tty_termios_encode_baud_rate(&new_termios, baudrate, baudrate); + + if (flow_ctrl) + new_termios.c_cflag |= CRTSCTS; + else + new_termios.c_cflag &= ~CRTSCTS; + + tty_set_termios(nu->tty, &new_termios); +} +EXPORT_SYMBOL_GPL(nci_uart_set_config); + +static struct tty_ldisc_ops nci_uart_ldisc = { + .owner = THIS_MODULE, + .num = N_NCI, + .name = "n_nci", + .open = nci_uart_tty_open, + .close = nci_uart_tty_close, + .read = nci_uart_tty_read, + .write = nci_uart_tty_write, + .poll = nci_uart_tty_poll, + .receive_buf = nci_uart_tty_receive, + .write_wakeup = nci_uart_tty_wakeup, + .ioctl = nci_uart_tty_ioctl, + .compat_ioctl = nci_uart_tty_ioctl, +}; + +static int __init nci_uart_init(void) +{ + return tty_register_ldisc(&nci_uart_ldisc); +} + +static void __exit nci_uart_exit(void) +{ + tty_unregister_ldisc(&nci_uart_ldisc); +} + +module_init(nci_uart_init); +module_exit(nci_uart_exit); + +MODULE_AUTHOR("Marvell International Ltd."); +MODULE_DESCRIPTION("NFC NCI UART driver"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_LDISC(N_NCI); diff --git a/net/nfc/netlink.c b/net/nfc/netlink.c new file mode 100644 index 000000000..e9ac6a6f9 --- /dev/null +++ b/net/nfc/netlink.c @@ -0,0 +1,1932 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Instituto Nokia de Tecnologia + * + * Authors: + * Lauro Ramos Venancio + * Aloisio Almeida Jr + * + * Vendor commands implementation based on net/wireless/nl80211.c + * which is: + * + * Copyright 2006-2010 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include + +#include "nfc.h" +#include "llcp.h" + +static const struct genl_multicast_group nfc_genl_mcgrps[] = { + { .name = NFC_GENL_MCAST_EVENT_NAME, }, +}; + +static struct genl_family nfc_genl_family; +static const struct 
nla_policy nfc_genl_policy[NFC_ATTR_MAX + 1] = { + [NFC_ATTR_DEVICE_INDEX] = { .type = NLA_U32 }, + [NFC_ATTR_DEVICE_NAME] = { .type = NLA_STRING, + .len = NFC_DEVICE_NAME_MAXSIZE }, + [NFC_ATTR_PROTOCOLS] = { .type = NLA_U32 }, + [NFC_ATTR_TARGET_INDEX] = { .type = NLA_U32 }, + [NFC_ATTR_COMM_MODE] = { .type = NLA_U8 }, + [NFC_ATTR_RF_MODE] = { .type = NLA_U8 }, + [NFC_ATTR_DEVICE_POWERED] = { .type = NLA_U8 }, + [NFC_ATTR_IM_PROTOCOLS] = { .type = NLA_U32 }, + [NFC_ATTR_TM_PROTOCOLS] = { .type = NLA_U32 }, + [NFC_ATTR_LLC_PARAM_LTO] = { .type = NLA_U8 }, + [NFC_ATTR_LLC_PARAM_RW] = { .type = NLA_U8 }, + [NFC_ATTR_LLC_PARAM_MIUX] = { .type = NLA_U16 }, + [NFC_ATTR_LLC_SDP] = { .type = NLA_NESTED }, + [NFC_ATTR_FIRMWARE_NAME] = { .type = NLA_STRING, + .len = NFC_FIRMWARE_NAME_MAXSIZE }, + [NFC_ATTR_SE_INDEX] = { .type = NLA_U32 }, + [NFC_ATTR_SE_APDU] = { .type = NLA_BINARY }, + [NFC_ATTR_VENDOR_ID] = { .type = NLA_U32 }, + [NFC_ATTR_VENDOR_SUBCMD] = { .type = NLA_U32 }, + [NFC_ATTR_VENDOR_DATA] = { .type = NLA_BINARY }, + +}; + +static const struct nla_policy nfc_sdp_genl_policy[NFC_SDP_ATTR_MAX + 1] = { + [NFC_SDP_ATTR_URI] = { .type = NLA_STRING, + .len = U8_MAX - 4 }, + [NFC_SDP_ATTR_SAP] = { .type = NLA_U8 }, +}; + +static int nfc_genl_send_target(struct sk_buff *msg, struct nfc_target *target, + struct netlink_callback *cb, int flags) +{ + void *hdr; + + hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &nfc_genl_family, flags, NFC_CMD_GET_TARGET); + if (!hdr) + return -EMSGSIZE; + + genl_dump_check_consistent(cb, hdr); + + if (nla_put_u32(msg, NFC_ATTR_TARGET_INDEX, target->idx) || + nla_put_u32(msg, NFC_ATTR_PROTOCOLS, target->supported_protocols) || + nla_put_u16(msg, NFC_ATTR_TARGET_SENS_RES, target->sens_res) || + nla_put_u8(msg, NFC_ATTR_TARGET_SEL_RES, target->sel_res)) + goto nla_put_failure; + if (target->nfcid1_len > 0 && + nla_put(msg, NFC_ATTR_TARGET_NFCID1, target->nfcid1_len, + target->nfcid1)) + goto nla_put_failure; + if (target->sensb_res_len > 0 && + nla_put(msg, NFC_ATTR_TARGET_SENSB_RES, target->sensb_res_len, + target->sensb_res)) + goto nla_put_failure; + if (target->sensf_res_len > 0 && + nla_put(msg, NFC_ATTR_TARGET_SENSF_RES, target->sensf_res_len, + target->sensf_res)) + goto nla_put_failure; + + if (target->is_iso15693) { + if (nla_put_u8(msg, NFC_ATTR_TARGET_ISO15693_DSFID, + target->iso15693_dsfid) || + nla_put(msg, NFC_ATTR_TARGET_ISO15693_UID, + sizeof(target->iso15693_uid), target->iso15693_uid)) + goto nla_put_failure; + } + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static struct nfc_dev *__get_device_from_cb(struct netlink_callback *cb) +{ + const struct genl_dumpit_info *info = genl_dumpit_info(cb); + struct nfc_dev *dev; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX]) + return ERR_PTR(-EINVAL); + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return ERR_PTR(-ENODEV); + + return dev; +} + +static int nfc_genl_dump_targets(struct sk_buff *skb, + struct netlink_callback *cb) +{ + int i = cb->args[0]; + struct nfc_dev *dev = (struct nfc_dev *) cb->args[1]; + int rc; + + if (!dev) { + dev = __get_device_from_cb(cb); + if (IS_ERR(dev)) + return PTR_ERR(dev); + + cb->args[1] = (long) dev; + } + + device_lock(&dev->dev); + + cb->seq = dev->targets_generation; + + while (i < dev->n_targets) { + rc = nfc_genl_send_target(skb, &dev->targets[i], cb, + NLM_F_MULTI); + if (rc < 0) + break; + + i++; + } + + 
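nfc_genl_dump_targets() above follows the usual netlink dump pattern: the callback is invoked repeatedly, persists its cursor in cb->args[] between invocations, and stops filling the current message once space runs out, to be resumed on the next call. The same idea, detached from the genetlink API, as a runnable sketch; it is an analogy only, and the names and batch size are invented for illustration.

#include <stdio.h>

#define N_TARGETS 7

struct dump_state {
	int next;			/* analogous to cb->args[0] */
};

/* Emit at most 'room' entries per call; return how many were written. */
static int dump_targets(struct dump_state *st, int *out, int room)
{
	int n = 0;

	while (st->next < N_TARGETS && n < room)
		out[n++] = st->next++;
	return n;
}

int main(void)
{
	struct dump_state st = { 0 };
	int batch[3];
	int n;

	/* Each iteration corresponds to one dump invocation / one output message. */
	while ((n = dump_targets(&st, batch, 3)) > 0) {
		printf("batch:");
		for (int i = 0; i < n; i++)
			printf(" target %d", batch[i]);
		printf("\n");
	}
	return 0;
}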
device_unlock(&dev->dev); + + cb->args[0] = i; + + return skb->len; +} + +static int nfc_genl_dump_targets_done(struct netlink_callback *cb) +{ + struct nfc_dev *dev = (struct nfc_dev *) cb->args[1]; + + if (dev) + nfc_put_device(dev); + + return 0; +} + +int nfc_genl_targets_found(struct nfc_dev *dev) +{ + struct sk_buff *msg; + void *hdr; + + dev->genl_data.poll_req_portid = 0; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_TARGETS_FOUND); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_ATOMIC); + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_target_lost(struct nfc_dev *dev, u32 target_idx) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_TARGET_LOST); + if (!hdr) + goto free_msg; + + if (nla_put_string(msg, NFC_ATTR_DEVICE_NAME, nfc_device_name(dev)) || + nla_put_u32(msg, NFC_ATTR_TARGET_INDEX, target_idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_tm_activated(struct nfc_dev *dev, u32 protocol) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_TM_ACTIVATED); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + if (nla_put_u32(msg, NFC_ATTR_TM_PROTOCOLS, protocol)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_tm_deactivated(struct nfc_dev *dev) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_TM_DEACTIVATED); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +static int nfc_genl_setup_device_added(struct nfc_dev *dev, struct sk_buff *msg) +{ + if (nla_put_string(msg, NFC_ATTR_DEVICE_NAME, nfc_device_name(dev)) || + nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_PROTOCOLS, dev->supported_protocols) || + nla_put_u8(msg, NFC_ATTR_DEVICE_POWERED, dev->dev_up) || + nla_put_u8(msg, NFC_ATTR_RF_MODE, dev->rf_mode)) + return -1; + return 0; +} + +int nfc_genl_device_added(struct nfc_dev *dev) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_DEVICE_ADDED); + if (!hdr) + goto free_msg; + + if (nfc_genl_setup_device_added(dev, msg)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: 
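Every nla_put_*() in the event helpers above is checked because attribute insertion can fail once the allocated message runs out of room, at which point the code unwinds through the nla_put_failure/free_msg labels. A toy TLV writer makes that failure mode concrete; it is only an analogy, not the netlink attribute API, and the attribute types and buffer size are invented.

#include <stdio.h>
#include <string.h>

/* A toy TLV writer: every put can fail when the buffer is full, which is why
 * each attribute insertion is checked and the caller bails out on failure.
 */
struct tlv_msg {
	unsigned char buf[32];
	size_t used;
};

static int tlv_put(struct tlv_msg *m, unsigned char type,
		   const void *data, unsigned char len)
{
	if (m->used + 2 + len > sizeof(m->buf))
		return -1;			/* no room: caller must unwind */
	m->buf[m->used++] = type;
	m->buf[m->used++] = len;
	memcpy(m->buf + m->used, data, len);
	m->used += len;
	return 0;
}

int main(void)
{
	struct tlv_msg msg = { .used = 0 };
	unsigned int dev_idx = 0;
	const char name[] = "nfc0";

	if (tlv_put(&msg, 1, &dev_idx, sizeof(dev_idx)) ||
	    tlv_put(&msg, 2, name, sizeof(name)))
		goto put_failure;		/* mirrors the nla_put_failure unwind */

	printf("message built, %zu bytes\n", msg.used);
	return 0;

put_failure:
	printf("attribute did not fit, dropping message\n");
	return 1;
}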
+ nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_device_removed(struct nfc_dev *dev) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_DEVICE_REMOVED); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_llc_send_sdres(struct nfc_dev *dev, struct hlist_head *sdres_list) +{ + struct sk_buff *msg; + struct nlattr *sdp_attr, *uri_attr; + struct nfc_llcp_sdp_tlv *sdres; + struct hlist_node *n; + void *hdr; + int rc = -EMSGSIZE; + int i; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_LLC_SDRES); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + sdp_attr = nla_nest_start_noflag(msg, NFC_ATTR_LLC_SDP); + if (sdp_attr == NULL) { + rc = -ENOMEM; + goto nla_put_failure; + } + + i = 1; + hlist_for_each_entry_safe(sdres, n, sdres_list, node) { + pr_debug("uri: %s, sap: %d\n", sdres->uri, sdres->sap); + + uri_attr = nla_nest_start_noflag(msg, i++); + if (uri_attr == NULL) { + rc = -ENOMEM; + goto nla_put_failure; + } + + if (nla_put_u8(msg, NFC_SDP_ATTR_SAP, sdres->sap)) + goto nla_put_failure; + + if (nla_put_string(msg, NFC_SDP_ATTR_URI, sdres->uri)) + goto nla_put_failure; + + nla_nest_end(msg, uri_attr); + + hlist_del(&sdres->node); + + nfc_llcp_free_sdp_tlv(sdres); + } + + nla_nest_end(msg, sdp_attr); + + genlmsg_end(msg, hdr); + + return genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_ATOMIC); + +nla_put_failure: +free_msg: + nlmsg_free(msg); + + nfc_llcp_free_sdp_tlv_list(sdres_list); + + return rc; +} + +int nfc_genl_se_added(struct nfc_dev *dev, u32 se_idx, u16 type) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_SE_ADDED); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, se_idx) || + nla_put_u8(msg, NFC_ATTR_SE_TYPE, type)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_se_removed(struct nfc_dev *dev, u32 se_idx) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_SE_REMOVED); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, se_idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_se_transaction(struct nfc_dev *dev, u8 se_idx, + struct nfc_evt_transaction *evt_transaction) +{ + struct nfc_se *se; + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + 
NFC_EVENT_SE_TRANSACTION); + if (!hdr) + goto free_msg; + + se = nfc_find_se(dev, se_idx); + if (!se) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, se_idx) || + nla_put_u8(msg, NFC_ATTR_SE_TYPE, se->type) || + nla_put(msg, NFC_ATTR_SE_AID, evt_transaction->aid_len, + evt_transaction->aid) || + nla_put(msg, NFC_ATTR_SE_PARAMS, evt_transaction->params_len, + evt_transaction->params)) + goto nla_put_failure; + + /* evt_transaction is no more used */ + devm_kfree(&dev->dev, evt_transaction); + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + /* evt_transaction is no more used */ + devm_kfree(&dev->dev, evt_transaction); + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_se_connectivity(struct nfc_dev *dev, u8 se_idx) +{ + const struct nfc_se *se; + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_EVENT_SE_CONNECTIVITY); + if (!hdr) + goto free_msg; + + se = nfc_find_se(dev, se_idx); + if (!se) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, se_idx) || + nla_put_u8(msg, NFC_ATTR_SE_TYPE, se->type)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +static int nfc_genl_send_device(struct sk_buff *msg, struct nfc_dev *dev, + u32 portid, u32 seq, + struct netlink_callback *cb, + int flags) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &nfc_genl_family, flags, + NFC_CMD_GET_DEVICE); + if (!hdr) + return -EMSGSIZE; + + if (cb) + genl_dump_check_consistent(cb, hdr); + + if (nfc_genl_setup_device_added(dev, msg)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nfc_genl_dump_devices(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct class_dev_iter *iter = (struct class_dev_iter *) cb->args[0]; + struct nfc_dev *dev = (struct nfc_dev *) cb->args[1]; + bool first_call = false; + + if (!iter) { + first_call = true; + iter = kmalloc(sizeof(struct class_dev_iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + cb->args[0] = (long) iter; + } + + mutex_lock(&nfc_devlist_mutex); + + cb->seq = nfc_devlist_generation; + + if (first_call) { + nfc_device_iter_init(iter); + dev = nfc_device_iter_next(iter); + } + + while (dev) { + int rc; + + rc = nfc_genl_send_device(skb, dev, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, cb, NLM_F_MULTI); + if (rc < 0) + break; + + dev = nfc_device_iter_next(iter); + } + + mutex_unlock(&nfc_devlist_mutex); + + cb->args[1] = (long) dev; + + return skb->len; +} + +static int nfc_genl_dump_devices_done(struct netlink_callback *cb) +{ + struct class_dev_iter *iter = (struct class_dev_iter *) cb->args[0]; + + if (iter) { + nfc_device_iter_exit(iter); + kfree(iter); + } + + return 0; +} + +int nfc_genl_dep_link_up_event(struct nfc_dev *dev, u32 target_idx, + u8 comm_mode, u8 rf_mode) +{ + struct sk_buff *msg; + void *hdr; + + pr_debug("DEP link is up\n"); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, NFC_CMD_DEP_LINK_UP); + if (!hdr) + goto free_msg; + + if 
(nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + if (rf_mode == NFC_RF_INITIATOR && + nla_put_u32(msg, NFC_ATTR_TARGET_INDEX, target_idx)) + goto nla_put_failure; + if (nla_put_u8(msg, NFC_ATTR_COMM_MODE, comm_mode) || + nla_put_u8(msg, NFC_ATTR_RF_MODE, rf_mode)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + dev->dep_link_up = true; + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_ATOMIC); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +int nfc_genl_dep_link_down_event(struct nfc_dev *dev) +{ + struct sk_buff *msg; + void *hdr; + + pr_debug("DEP link is down\n"); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_CMD_DEP_LINK_DOWN); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_ATOMIC); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +static int nfc_genl_get_device(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct nfc_dev *dev; + u32 idx; + int rc = -ENOBUFS; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + rc = -ENOMEM; + goto out_putdev; + } + + rc = nfc_genl_send_device(msg, dev, info->snd_portid, info->snd_seq, + NULL, 0); + if (rc < 0) + goto out_free; + + nfc_put_device(dev); + + return genlmsg_reply(msg, info); + +out_free: + nlmsg_free(msg); +out_putdev: + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_dev_up(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_dev_up(dev); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_dev_down(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_dev_down(dev); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_start_poll(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + u32 im_protocols = 0, tm_protocols = 0; + + pr_debug("Poll start\n"); + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + ((!info->attrs[NFC_ATTR_IM_PROTOCOLS] && + !info->attrs[NFC_ATTR_PROTOCOLS]) && + !info->attrs[NFC_ATTR_TM_PROTOCOLS])) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + if (info->attrs[NFC_ATTR_TM_PROTOCOLS]) + tm_protocols = nla_get_u32(info->attrs[NFC_ATTR_TM_PROTOCOLS]); + + if (info->attrs[NFC_ATTR_IM_PROTOCOLS]) + im_protocols = nla_get_u32(info->attrs[NFC_ATTR_IM_PROTOCOLS]); + else if (info->attrs[NFC_ATTR_PROTOCOLS]) + im_protocols = nla_get_u32(info->attrs[NFC_ATTR_PROTOCOLS]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + mutex_lock(&dev->genl_data.genl_data_mutex); + + rc = nfc_start_poll(dev, im_protocols, tm_protocols); + if (!rc) + dev->genl_data.poll_req_portid = 
info->snd_portid; + + mutex_unlock(&dev->genl_data.genl_data_mutex); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_stop_poll(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + device_lock(&dev->dev); + + if (!dev->polling) { + device_unlock(&dev->dev); + nfc_put_device(dev); + return -EINVAL; + } + + device_unlock(&dev->dev); + + mutex_lock(&dev->genl_data.genl_data_mutex); + + if (dev->genl_data.poll_req_portid != info->snd_portid) { + rc = -EBUSY; + goto out; + } + + rc = nfc_stop_poll(dev); + dev->genl_data.poll_req_portid = 0; + +out: + mutex_unlock(&dev->genl_data.genl_data_mutex); + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_activate_target(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + u32 device_idx, target_idx, protocol; + int rc; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_TARGET_INDEX] || + !info->attrs[NFC_ATTR_PROTOCOLS]) + return -EINVAL; + + device_idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(device_idx); + if (!dev) + return -ENODEV; + + target_idx = nla_get_u32(info->attrs[NFC_ATTR_TARGET_INDEX]); + protocol = nla_get_u32(info->attrs[NFC_ATTR_PROTOCOLS]); + + nfc_deactivate_target(dev, target_idx, NFC_TARGET_MODE_SLEEP); + rc = nfc_activate_target(dev, target_idx, protocol); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_deactivate_target(struct sk_buff *skb, + struct genl_info *info) +{ + struct nfc_dev *dev; + u32 device_idx, target_idx; + int rc; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_TARGET_INDEX]) + return -EINVAL; + + device_idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(device_idx); + if (!dev) + return -ENODEV; + + target_idx = nla_get_u32(info->attrs[NFC_ATTR_TARGET_INDEX]); + + rc = nfc_deactivate_target(dev, target_idx, NFC_TARGET_MODE_SLEEP); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_dep_link_up(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc, tgt_idx; + u32 idx; + u8 comm; + + pr_debug("DEP link up\n"); + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_COMM_MODE]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + if (!info->attrs[NFC_ATTR_TARGET_INDEX]) + tgt_idx = NFC_TARGET_IDX_ANY; + else + tgt_idx = nla_get_u32(info->attrs[NFC_ATTR_TARGET_INDEX]); + + comm = nla_get_u8(info->attrs[NFC_ATTR_COMM_MODE]); + + if (comm != NFC_COMM_ACTIVE && comm != NFC_COMM_PASSIVE) + return -EINVAL; + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_dep_link_up(dev, tgt_idx, comm); + + nfc_put_device(dev); + + return rc; +} + +static int nfc_genl_dep_link_down(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_TARGET_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_dep_link_down(dev); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_send_params(struct sk_buff *msg, + struct nfc_llcp_local *local, + u32 portid, u32 seq) +{ + void *hdr; + + hdr = genlmsg_put(msg, portid, seq, &nfc_genl_family, 0, 
+ NFC_CMD_LLC_GET_PARAMS); + if (!hdr) + return -EMSGSIZE; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, local->dev->idx) || + nla_put_u8(msg, NFC_ATTR_LLC_PARAM_LTO, local->lto) || + nla_put_u8(msg, NFC_ATTR_LLC_PARAM_RW, local->rw) || + nla_put_u16(msg, NFC_ATTR_LLC_PARAM_MIUX, be16_to_cpu(local->miux))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nfc_genl_llc_get_params(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + struct nfc_llcp_local *local; + int rc = 0; + struct sk_buff *msg = NULL; + u32 idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_FIRMWARE_NAME]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + device_lock(&dev->dev); + + local = nfc_llcp_find_local(dev); + if (!local) { + rc = -ENODEV; + goto exit; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + rc = -ENOMEM; + goto put_local; + } + + rc = nfc_genl_send_params(msg, local, info->snd_portid, info->snd_seq); + +put_local: + nfc_llcp_local_put(local); + +exit: + device_unlock(&dev->dev); + + nfc_put_device(dev); + + if (rc < 0) { + if (msg) + nlmsg_free(msg); + + return rc; + } + + return genlmsg_reply(msg, info); +} + +static int nfc_genl_llc_set_params(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + struct nfc_llcp_local *local; + u8 rw = 0; + u16 miux = 0; + u32 idx; + int rc = 0; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + (!info->attrs[NFC_ATTR_LLC_PARAM_LTO] && + !info->attrs[NFC_ATTR_LLC_PARAM_RW] && + !info->attrs[NFC_ATTR_LLC_PARAM_MIUX])) + return -EINVAL; + + if (info->attrs[NFC_ATTR_LLC_PARAM_RW]) { + rw = nla_get_u8(info->attrs[NFC_ATTR_LLC_PARAM_RW]); + + if (rw > LLCP_MAX_RW) + return -EINVAL; + } + + if (info->attrs[NFC_ATTR_LLC_PARAM_MIUX]) { + miux = nla_get_u16(info->attrs[NFC_ATTR_LLC_PARAM_MIUX]); + + if (miux > LLCP_MAX_MIUX) + return -EINVAL; + } + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + device_lock(&dev->dev); + + local = nfc_llcp_find_local(dev); + if (!local) { + rc = -ENODEV; + goto exit; + } + + if (info->attrs[NFC_ATTR_LLC_PARAM_LTO]) { + if (dev->dep_link_up) { + rc = -EINPROGRESS; + goto put_local; + } + + local->lto = nla_get_u8(info->attrs[NFC_ATTR_LLC_PARAM_LTO]); + } + + if (info->attrs[NFC_ATTR_LLC_PARAM_RW]) + local->rw = rw; + + if (info->attrs[NFC_ATTR_LLC_PARAM_MIUX]) + local->miux = cpu_to_be16(miux); + +put_local: + nfc_llcp_local_put(local); + +exit: + device_unlock(&dev->dev); + + nfc_put_device(dev); + + return rc; +} + +static int nfc_genl_llc_sdreq(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + struct nfc_llcp_local *local; + struct nlattr *attr, *sdp_attrs[NFC_SDP_ATTR_MAX+1]; + u32 idx; + u8 tid; + char *uri; + int rc = 0, rem; + size_t uri_len, tlvs_len; + struct hlist_head sdreq_list; + struct nfc_llcp_sdp_tlv *sdreq; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_LLC_SDP]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + device_lock(&dev->dev); + + if (dev->dep_link_up == false) { + rc = -ENOLINK; + goto exit; + } + + local = nfc_llcp_find_local(dev); + if (!local) { + rc = -ENODEV; + goto exit; + } + + INIT_HLIST_HEAD(&sdreq_list); + + tlvs_len = 0; + + 
nla_for_each_nested(attr, info->attrs[NFC_ATTR_LLC_SDP], rem) { + rc = nla_parse_nested_deprecated(sdp_attrs, NFC_SDP_ATTR_MAX, + attr, nfc_sdp_genl_policy, + info->extack); + + if (rc != 0) { + rc = -EINVAL; + goto put_local; + } + + if (!sdp_attrs[NFC_SDP_ATTR_URI]) + continue; + + uri_len = nla_len(sdp_attrs[NFC_SDP_ATTR_URI]); + if (uri_len == 0) + continue; + + uri = nla_data(sdp_attrs[NFC_SDP_ATTR_URI]); + if (uri == NULL || *uri == 0) + continue; + + tid = local->sdreq_next_tid++; + + sdreq = nfc_llcp_build_sdreq_tlv(tid, uri, uri_len); + if (sdreq == NULL) { + rc = -ENOMEM; + goto put_local; + } + + tlvs_len += sdreq->tlv_len; + + hlist_add_head(&sdreq->node, &sdreq_list); + } + + if (hlist_empty(&sdreq_list)) { + rc = -EINVAL; + goto put_local; + } + + rc = nfc_llcp_send_snl_sdreq(local, &sdreq_list, tlvs_len); + +put_local: + nfc_llcp_local_put(local); + +exit: + device_unlock(&dev->dev); + + nfc_put_device(dev); + + return rc; +} + +static int nfc_genl_fw_download(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx; + char firmware_name[NFC_FIRMWARE_NAME_MAXSIZE + 1]; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || !info->attrs[NFC_ATTR_FIRMWARE_NAME]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + nla_strscpy(firmware_name, info->attrs[NFC_ATTR_FIRMWARE_NAME], + sizeof(firmware_name)); + + rc = nfc_fw_download(dev, firmware_name); + + nfc_put_device(dev); + return rc; +} + +int nfc_genl_fw_download_done(struct nfc_dev *dev, const char *firmware_name, + u32 result) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return -ENOMEM; + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_CMD_FW_DOWNLOAD); + if (!hdr) + goto free_msg; + + if (nla_put_string(msg, NFC_ATTR_FIRMWARE_NAME, firmware_name) || + nla_put_u32(msg, NFC_ATTR_FIRMWARE_DOWNLOAD_STATUS, result) || + nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_ATOMIC); + + return 0; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + return -EMSGSIZE; +} + +static int nfc_genl_enable_se(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx, se_idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_SE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + se_idx = nla_get_u32(info->attrs[NFC_ATTR_SE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_enable_se(dev, se_idx); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_disable_se(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + int rc; + u32 idx, se_idx; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_SE_INDEX]) + return -EINVAL; + + idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + se_idx = nla_get_u32(info->attrs[NFC_ATTR_SE_INDEX]); + + dev = nfc_get_device(idx); + if (!dev) + return -ENODEV; + + rc = nfc_disable_se(dev, se_idx); + + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_send_se(struct sk_buff *msg, struct nfc_dev *dev, + u32 portid, u32 seq, + struct netlink_callback *cb, + int flags) +{ + void *hdr; + struct nfc_se *se, *n; + + list_for_each_entry_safe(se, n, &dev->secure_elements, list) { + hdr = genlmsg_put(msg, portid, seq, &nfc_genl_family, flags, + 
NFC_CMD_GET_SE); + if (!hdr) + goto nla_put_failure; + + if (cb) + genl_dump_check_consistent(cb, hdr); + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, dev->idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, se->idx) || + nla_put_u8(msg, NFC_ATTR_SE_TYPE, se->type)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + } + + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nfc_genl_dump_ses(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct class_dev_iter *iter = (struct class_dev_iter *) cb->args[0]; + struct nfc_dev *dev = (struct nfc_dev *) cb->args[1]; + bool first_call = false; + + if (!iter) { + first_call = true; + iter = kmalloc(sizeof(struct class_dev_iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + cb->args[0] = (long) iter; + } + + mutex_lock(&nfc_devlist_mutex); + + cb->seq = nfc_devlist_generation; + + if (first_call) { + nfc_device_iter_init(iter); + dev = nfc_device_iter_next(iter); + } + + while (dev) { + int rc; + + rc = nfc_genl_send_se(skb, dev, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, cb, NLM_F_MULTI); + if (rc < 0) + break; + + dev = nfc_device_iter_next(iter); + } + + mutex_unlock(&nfc_devlist_mutex); + + cb->args[1] = (long) dev; + + return skb->len; +} + +static int nfc_genl_dump_ses_done(struct netlink_callback *cb) +{ + struct class_dev_iter *iter = (struct class_dev_iter *) cb->args[0]; + + if (iter) { + nfc_device_iter_exit(iter); + kfree(iter); + } + + return 0; +} + +static int nfc_se_io(struct nfc_dev *dev, u32 se_idx, + u8 *apdu, size_t apdu_length, + se_io_cb_t cb, void *cb_context) +{ + struct nfc_se *se; + int rc; + + pr_debug("%s se index %d\n", dev_name(&dev->dev), se_idx); + + device_lock(&dev->dev); + + if (!device_is_registered(&dev->dev)) { + rc = -ENODEV; + goto error; + } + + if (!dev->dev_up) { + rc = -ENODEV; + goto error; + } + + if (!dev->ops->se_io) { + rc = -EOPNOTSUPP; + goto error; + } + + se = nfc_find_se(dev, se_idx); + if (!se) { + rc = -EINVAL; + goto error; + } + + if (se->state != NFC_SE_ENABLED) { + rc = -ENODEV; + goto error; + } + + rc = dev->ops->se_io(dev, se_idx, apdu, + apdu_length, cb, cb_context); + + device_unlock(&dev->dev); + return rc; + +error: + device_unlock(&dev->dev); + kfree(cb_context); + return rc; +} + +struct se_io_ctx { + u32 dev_idx; + u32 se_idx; +}; + +static void se_io_cb(void *context, u8 *apdu, size_t apdu_len, int err) +{ + struct se_io_ctx *ctx = context; + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + kfree(ctx); + return; + } + + hdr = genlmsg_put(msg, 0, 0, &nfc_genl_family, 0, + NFC_CMD_SE_IO); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NFC_ATTR_DEVICE_INDEX, ctx->dev_idx) || + nla_put_u32(msg, NFC_ATTR_SE_INDEX, ctx->se_idx) || + nla_put(msg, NFC_ATTR_SE_APDU, apdu_len, apdu)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast(&nfc_genl_family, msg, 0, 0, GFP_KERNEL); + + kfree(ctx); + + return; + +nla_put_failure: +free_msg: + nlmsg_free(msg); + kfree(ctx); + + return; +} + +static int nfc_genl_se_io(struct sk_buff *skb, struct genl_info *info) +{ + struct nfc_dev *dev; + struct se_io_ctx *ctx; + u32 dev_idx, se_idx; + u8 *apdu; + size_t apdu_len; + int rc; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_SE_INDEX] || + !info->attrs[NFC_ATTR_SE_APDU]) + return -EINVAL; + + dev_idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + se_idx = nla_get_u32(info->attrs[NFC_ATTR_SE_INDEX]); + + dev = nfc_get_device(dev_idx); + 
if (!dev) + return -ENODEV; + + if (!dev->ops || !dev->ops->se_io) { + rc = -EOPNOTSUPP; + goto put_dev; + } + + apdu_len = nla_len(info->attrs[NFC_ATTR_SE_APDU]); + if (apdu_len == 0) { + rc = -EINVAL; + goto put_dev; + } + + apdu = nla_data(info->attrs[NFC_ATTR_SE_APDU]); + if (!apdu) { + rc = -EINVAL; + goto put_dev; + } + + ctx = kzalloc(sizeof(struct se_io_ctx), GFP_KERNEL); + if (!ctx) { + rc = -ENOMEM; + goto put_dev; + } + + ctx->dev_idx = dev_idx; + ctx->se_idx = se_idx; + + rc = nfc_se_io(dev, se_idx, apdu, apdu_len, se_io_cb, ctx); + +put_dev: + nfc_put_device(dev); + return rc; +} + +static int nfc_genl_vendor_cmd(struct sk_buff *skb, + struct genl_info *info) +{ + struct nfc_dev *dev; + const struct nfc_vendor_cmd *cmd; + u32 dev_idx, vid, subcmd; + u8 *data; + size_t data_len; + int i, err; + + if (!info->attrs[NFC_ATTR_DEVICE_INDEX] || + !info->attrs[NFC_ATTR_VENDOR_ID] || + !info->attrs[NFC_ATTR_VENDOR_SUBCMD]) + return -EINVAL; + + dev_idx = nla_get_u32(info->attrs[NFC_ATTR_DEVICE_INDEX]); + vid = nla_get_u32(info->attrs[NFC_ATTR_VENDOR_ID]); + subcmd = nla_get_u32(info->attrs[NFC_ATTR_VENDOR_SUBCMD]); + + dev = nfc_get_device(dev_idx); + if (!dev) + return -ENODEV; + + if (!dev->vendor_cmds || !dev->n_vendor_cmds) { + err = -ENODEV; + goto put_dev; + } + + if (info->attrs[NFC_ATTR_VENDOR_DATA]) { + data = nla_data(info->attrs[NFC_ATTR_VENDOR_DATA]); + data_len = nla_len(info->attrs[NFC_ATTR_VENDOR_DATA]); + if (data_len == 0) { + err = -EINVAL; + goto put_dev; + } + } else { + data = NULL; + data_len = 0; + } + + for (i = 0; i < dev->n_vendor_cmds; i++) { + cmd = &dev->vendor_cmds[i]; + + if (cmd->vendor_id != vid || cmd->subcmd != subcmd) + continue; + + dev->cur_cmd_info = info; + err = cmd->doit(dev, data, data_len); + dev->cur_cmd_info = NULL; + goto put_dev; + } + + err = -EOPNOTSUPP; + +put_dev: + nfc_put_device(dev); + return err; +} + +/* message building helper */ +static inline void *nfc_hdr_put(struct sk_buff *skb, u32 portid, u32 seq, + int flags, u8 cmd) +{ + /* since there is no private header just add the generic one */ + return genlmsg_put(skb, portid, seq, &nfc_genl_family, flags, cmd); +} + +static struct sk_buff * +__nfc_alloc_vendor_cmd_skb(struct nfc_dev *dev, int approxlen, + u32 portid, u32 seq, + enum nfc_attrs attr, + u32 oui, u32 subcmd, gfp_t gfp) +{ + struct sk_buff *skb; + void *hdr; + + skb = nlmsg_new(approxlen + 100, gfp); + if (!skb) + return NULL; + + hdr = nfc_hdr_put(skb, portid, seq, 0, NFC_CMD_VENDOR); + if (!hdr) { + kfree_skb(skb); + return NULL; + } + + if (nla_put_u32(skb, NFC_ATTR_DEVICE_INDEX, dev->idx)) + goto nla_put_failure; + if (nla_put_u32(skb, NFC_ATTR_VENDOR_ID, oui)) + goto nla_put_failure; + if (nla_put_u32(skb, NFC_ATTR_VENDOR_SUBCMD, subcmd)) + goto nla_put_failure; + + ((void **)skb->cb)[0] = dev; + ((void **)skb->cb)[1] = hdr; + + return skb; + +nla_put_failure: + kfree_skb(skb); + return NULL; +} + +struct sk_buff *__nfc_alloc_vendor_cmd_reply_skb(struct nfc_dev *dev, + enum nfc_attrs attr, + u32 oui, u32 subcmd, + int approxlen) +{ + if (WARN_ON(!dev->cur_cmd_info)) + return NULL; + + return __nfc_alloc_vendor_cmd_skb(dev, approxlen, + dev->cur_cmd_info->snd_portid, + dev->cur_cmd_info->snd_seq, attr, + oui, subcmd, GFP_KERNEL); +} +EXPORT_SYMBOL(__nfc_alloc_vendor_cmd_reply_skb); + +int nfc_vendor_cmd_reply(struct sk_buff *skb) +{ + struct nfc_dev *dev = ((void **)skb->cb)[0]; + void *hdr = ((void **)skb->cb)[1]; + + /* clear CB data for netlink core to own from now on */ + memset(skb->cb, 0, sizeof(skb->cb)); + 
+ if (WARN_ON(!dev->cur_cmd_info)) { + kfree_skb(skb); + return -EINVAL; + } + + genlmsg_end(skb, hdr); + return genlmsg_reply(skb, dev->cur_cmd_info); +} +EXPORT_SYMBOL(nfc_vendor_cmd_reply); + +static const struct genl_ops nfc_genl_ops[] = { + { + .cmd = NFC_CMD_GET_DEVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_get_device, + .dumpit = nfc_genl_dump_devices, + .done = nfc_genl_dump_devices_done, + }, + { + .cmd = NFC_CMD_DEV_UP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_dev_up, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_DEV_DOWN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_dev_down, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_START_POLL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_start_poll, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_STOP_POLL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_stop_poll, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_DEP_LINK_UP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_dep_link_up, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_DEP_LINK_DOWN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_dep_link_down, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_GET_TARGET, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = nfc_genl_dump_targets, + .done = nfc_genl_dump_targets_done, + }, + { + .cmd = NFC_CMD_LLC_GET_PARAMS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_llc_get_params, + }, + { + .cmd = NFC_CMD_LLC_SET_PARAMS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_llc_set_params, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_LLC_SDREQ, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_llc_sdreq, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_FW_DOWNLOAD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_fw_download, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_ENABLE_SE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_enable_se, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_DISABLE_SE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_disable_se, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_GET_SE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = nfc_genl_dump_ses, + .done = nfc_genl_dump_ses_done, + }, + { + .cmd = NFC_CMD_SE_IO, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_se_io, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_ACTIVATE_TARGET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_activate_target, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_VENDOR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_vendor_cmd, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NFC_CMD_DEACTIVATE_TARGET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nfc_genl_deactivate_target, + .flags = GENL_ADMIN_PERM, + }, +}; + +static struct genl_family nfc_genl_family __ro_after_init = { + .hdrsize = 0, + .name = NFC_GENL_NAME, + .version = 
NFC_GENL_VERSION, + .maxattr = NFC_ATTR_MAX, + .policy = nfc_genl_policy, + .module = THIS_MODULE, + .ops = nfc_genl_ops, + .n_ops = ARRAY_SIZE(nfc_genl_ops), + .resv_start_op = NFC_CMD_DEACTIVATE_TARGET + 1, + .mcgrps = nfc_genl_mcgrps, + .n_mcgrps = ARRAY_SIZE(nfc_genl_mcgrps), +}; + + +struct urelease_work { + struct work_struct w; + u32 portid; +}; + +static void nfc_urelease_event_work(struct work_struct *work) +{ + struct urelease_work *w = container_of(work, struct urelease_work, w); + struct class_dev_iter iter; + struct nfc_dev *dev; + + pr_debug("portid %d\n", w->portid); + + mutex_lock(&nfc_devlist_mutex); + + nfc_device_iter_init(&iter); + dev = nfc_device_iter_next(&iter); + + while (dev) { + mutex_lock(&dev->genl_data.genl_data_mutex); + + if (dev->genl_data.poll_req_portid == w->portid) { + nfc_stop_poll(dev); + dev->genl_data.poll_req_portid = 0; + } + + mutex_unlock(&dev->genl_data.genl_data_mutex); + + dev = nfc_device_iter_next(&iter); + } + + nfc_device_iter_exit(&iter); + + mutex_unlock(&nfc_devlist_mutex); + + kfree(w); +} + +static int nfc_genl_rcv_nl_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct netlink_notify *n = ptr; + struct urelease_work *w; + + if (event != NETLINK_URELEASE || n->protocol != NETLINK_GENERIC) + goto out; + + pr_debug("NETLINK_URELEASE event from id %d\n", n->portid); + + w = kmalloc(sizeof(*w), GFP_ATOMIC); + if (w) { + INIT_WORK(&w->w, nfc_urelease_event_work); + w->portid = n->portid; + schedule_work(&w->w); + } + +out: + return NOTIFY_DONE; +} + +void nfc_genl_data_init(struct nfc_genl_data *genl_data) +{ + genl_data->poll_req_portid = 0; + mutex_init(&genl_data->genl_data_mutex); +} + +void nfc_genl_data_exit(struct nfc_genl_data *genl_data) +{ + mutex_destroy(&genl_data->genl_data_mutex); +} + +static struct notifier_block nl_notifier = { + .notifier_call = nfc_genl_rcv_nl_event, +}; + +/** + * nfc_genl_init() - Initialize netlink interface + * + * This initialization function registers the nfc netlink family. + */ +int __init nfc_genl_init(void) +{ + int rc; + + rc = genl_register_family(&nfc_genl_family); + if (rc) + return rc; + + netlink_register_notifier(&nl_notifier); + + return 0; +} + +/** + * nfc_genl_exit() - Deinitialize netlink interface + * + * This exit function unregisters the nfc netlink family. 
+ */ +void nfc_genl_exit(void) +{ + netlink_unregister_notifier(&nl_notifier); + genl_unregister_family(&nfc_genl_family); +} diff --git a/net/nfc/nfc.h b/net/nfc/nfc.h new file mode 100644 index 000000000..0b1e6466f --- /dev/null +++ b/net/nfc/nfc.h @@ -0,0 +1,151 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* + * Copyright (C) 2011 Instituto Nokia de Tecnologia + * + * Authors: + * Lauro Ramos Venancio + * Aloisio Almeida Jr + */ + +#ifndef __LOCAL_NFC_H +#define __LOCAL_NFC_H + +#include +#include + +#define NFC_TARGET_MODE_IDLE 0 +#define NFC_TARGET_MODE_SLEEP 1 + +struct nfc_protocol { + int id; + struct proto *proto; + struct module *owner; + int (*create)(struct net *net, struct socket *sock, + const struct nfc_protocol *nfc_proto, int kern); +}; + +struct nfc_rawsock { + struct sock sk; + struct nfc_dev *dev; + u32 target_idx; + struct work_struct tx_work; + bool tx_work_scheduled; +}; + +struct nfc_sock_list { + struct hlist_head head; + rwlock_t lock; +}; + +#define nfc_rawsock(sk) ((struct nfc_rawsock *) sk) +#define to_rawsock_sk(_tx_work) \ + ((struct sock *) container_of(_tx_work, struct nfc_rawsock, tx_work)) + +struct nfc_llcp_sdp_tlv; + +void nfc_llcp_mac_is_down(struct nfc_dev *dev); +void nfc_llcp_mac_is_up(struct nfc_dev *dev, u32 target_idx, + u8 comm_mode, u8 rf_mode); +int nfc_llcp_register_device(struct nfc_dev *dev); +void nfc_llcp_unregister_device(struct nfc_dev *dev); +int nfc_llcp_set_remote_gb(struct nfc_dev *dev, const u8 *gb, u8 gb_len); +u8 *nfc_llcp_general_bytes(struct nfc_dev *dev, size_t *general_bytes_len); +int nfc_llcp_data_received(struct nfc_dev *dev, struct sk_buff *skb); +struct nfc_llcp_local *nfc_llcp_find_local(struct nfc_dev *dev); +int nfc_llcp_local_put(struct nfc_llcp_local *local); +int __init nfc_llcp_init(void); +void nfc_llcp_exit(void); +void nfc_llcp_free_sdp_tlv(struct nfc_llcp_sdp_tlv *sdp); +void nfc_llcp_free_sdp_tlv_list(struct hlist_head *head); + +int __init rawsock_init(void); +void rawsock_exit(void); + +int __init af_nfc_init(void); +void af_nfc_exit(void); +int nfc_proto_register(const struct nfc_protocol *nfc_proto); +void nfc_proto_unregister(const struct nfc_protocol *nfc_proto); + +extern int nfc_devlist_generation; +extern struct mutex nfc_devlist_mutex; + +int __init nfc_genl_init(void); +void nfc_genl_exit(void); + +void nfc_genl_data_init(struct nfc_genl_data *genl_data); +void nfc_genl_data_exit(struct nfc_genl_data *genl_data); + +int nfc_genl_targets_found(struct nfc_dev *dev); +int nfc_genl_target_lost(struct nfc_dev *dev, u32 target_idx); + +int nfc_genl_device_added(struct nfc_dev *dev); +int nfc_genl_device_removed(struct nfc_dev *dev); + +int nfc_genl_dep_link_up_event(struct nfc_dev *dev, u32 target_idx, + u8 comm_mode, u8 rf_mode); +int nfc_genl_dep_link_down_event(struct nfc_dev *dev); + +int nfc_genl_tm_activated(struct nfc_dev *dev, u32 protocol); +int nfc_genl_tm_deactivated(struct nfc_dev *dev); + +int nfc_genl_llc_send_sdres(struct nfc_dev *dev, struct hlist_head *sdres_list); + +int nfc_genl_se_added(struct nfc_dev *dev, u32 se_idx, u16 type); +int nfc_genl_se_removed(struct nfc_dev *dev, u32 se_idx); +int nfc_genl_se_transaction(struct nfc_dev *dev, u8 se_idx, + struct nfc_evt_transaction *evt_transaction); +int nfc_genl_se_connectivity(struct nfc_dev *dev, u8 se_idx); + +struct nfc_dev *nfc_get_device(unsigned int idx); + +static inline void nfc_put_device(struct nfc_dev *dev) +{ + put_device(&dev->dev); +} + +static inline void nfc_device_iter_init(struct class_dev_iter *iter) +{ + 
class_dev_iter_init(iter, &nfc_class, NULL, NULL); +} + +static inline struct nfc_dev *nfc_device_iter_next(struct class_dev_iter *iter) +{ + struct device *d = class_dev_iter_next(iter); + if (!d) + return NULL; + + return to_nfc_dev(d); +} + +static inline void nfc_device_iter_exit(struct class_dev_iter *iter) +{ + class_dev_iter_exit(iter); +} + +int nfc_fw_download(struct nfc_dev *dev, const char *firmware_name); +int nfc_genl_fw_download_done(struct nfc_dev *dev, const char *firmware_name, + u32 result); + +int nfc_dev_up(struct nfc_dev *dev); + +int nfc_dev_down(struct nfc_dev *dev); + +int nfc_start_poll(struct nfc_dev *dev, u32 im_protocols, u32 tm_protocols); + +int nfc_stop_poll(struct nfc_dev *dev); + +int nfc_dep_link_up(struct nfc_dev *dev, int target_idx, u8 comm_mode); + +int nfc_dep_link_down(struct nfc_dev *dev); + +int nfc_activate_target(struct nfc_dev *dev, u32 target_idx, u32 protocol); + +int nfc_deactivate_target(struct nfc_dev *dev, u32 target_idx, u8 mode); + +int nfc_data_exchange(struct nfc_dev *dev, u32 target_idx, struct sk_buff *skb, + data_exchange_cb_t cb, void *cb_context); + +int nfc_enable_se(struct nfc_dev *dev, u32 se_idx); +int nfc_disable_se(struct nfc_dev *dev, u32 se_idx); + +#endif /* __LOCAL_NFC_H */ diff --git a/net/nfc/rawsock.c b/net/nfc/rawsock.c new file mode 100644 index 000000000..8dd569765 --- /dev/null +++ b/net/nfc/rawsock.c @@ -0,0 +1,418 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2011 Instituto Nokia de Tecnologia + * + * Authors: + * Aloisio Almeida Jr + * Lauro Ramos Venancio + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": %s: " fmt, __func__ + +#include +#include +#include + +#include "nfc.h" + +static struct nfc_sock_list raw_sk_list = { + .lock = __RW_LOCK_UNLOCKED(raw_sk_list.lock) +}; + +static void nfc_sock_link(struct nfc_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_add_node(sk, &l->head); + write_unlock(&l->lock); +} + +static void nfc_sock_unlink(struct nfc_sock_list *l, struct sock *sk) +{ + write_lock(&l->lock); + sk_del_node_init(sk); + write_unlock(&l->lock); +} + +static void rawsock_write_queue_purge(struct sock *sk) +{ + pr_debug("sk=%p\n", sk); + + spin_lock_bh(&sk->sk_write_queue.lock); + __skb_queue_purge(&sk->sk_write_queue); + nfc_rawsock(sk)->tx_work_scheduled = false; + spin_unlock_bh(&sk->sk_write_queue.lock); +} + +static void rawsock_report_error(struct sock *sk, int err) +{ + pr_debug("sk=%p err=%d\n", sk, err); + + sk->sk_shutdown = SHUTDOWN_MASK; + sk->sk_err = -err; + sk_error_report(sk); + + rawsock_write_queue_purge(sk); +} + +static int rawsock_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + pr_debug("sock=%p sk=%p\n", sock, sk); + + if (!sk) + return 0; + + if (sock->type == SOCK_RAW) + nfc_sock_unlink(&raw_sk_list, sk); + + sock_orphan(sk); + sock_put(sk); + + return 0; +} + +static int rawsock_connect(struct socket *sock, struct sockaddr *_addr, + int len, int flags) +{ + struct sock *sk = sock->sk; + struct sockaddr_nfc *addr = (struct sockaddr_nfc *)_addr; + struct nfc_dev *dev; + int rc = 0; + + pr_debug("sock=%p sk=%p flags=%d\n", sock, sk, flags); + + if (!addr || len < sizeof(struct sockaddr_nfc) || + addr->sa_family != AF_NFC) + return -EINVAL; + + pr_debug("addr dev_idx=%u target_idx=%u protocol=%u\n", + addr->dev_idx, addr->target_idx, addr->nfc_protocol); + + lock_sock(sk); + + if (sock->state == SS_CONNECTED) { + rc = -EISCONN; + goto error; + } + + dev = nfc_get_device(addr->dev_idx); + if (!dev) { + rc = -ENODEV; + goto 
error; + } + + if (addr->target_idx > dev->target_next_idx - 1 || + addr->target_idx < dev->target_next_idx - dev->n_targets) { + rc = -EINVAL; + goto put_dev; + } + + rc = nfc_activate_target(dev, addr->target_idx, addr->nfc_protocol); + if (rc) + goto put_dev; + + nfc_rawsock(sk)->dev = dev; + nfc_rawsock(sk)->target_idx = addr->target_idx; + sock->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + sk->sk_state_change(sk); + + release_sock(sk); + return 0; + +put_dev: + nfc_put_device(dev); +error: + release_sock(sk); + return rc; +} + +static int rawsock_add_header(struct sk_buff *skb) +{ + *(u8 *)skb_push(skb, NFC_HEADER_SIZE) = 0; + + return 0; +} + +static void rawsock_data_exchange_complete(void *context, struct sk_buff *skb, + int err) +{ + struct sock *sk = (struct sock *) context; + + BUG_ON(in_hardirq()); + + pr_debug("sk=%p err=%d\n", sk, err); + + if (err) + goto error; + + err = rawsock_add_header(skb); + if (err) + goto error_skb; + + err = sock_queue_rcv_skb(sk, skb); + if (err) + goto error_skb; + + spin_lock_bh(&sk->sk_write_queue.lock); + if (!skb_queue_empty(&sk->sk_write_queue)) + schedule_work(&nfc_rawsock(sk)->tx_work); + else + nfc_rawsock(sk)->tx_work_scheduled = false; + spin_unlock_bh(&sk->sk_write_queue.lock); + + sock_put(sk); + return; + +error_skb: + kfree_skb(skb); + +error: + rawsock_report_error(sk, err); + sock_put(sk); +} + +static void rawsock_tx_work(struct work_struct *work) +{ + struct sock *sk = to_rawsock_sk(work); + struct nfc_dev *dev = nfc_rawsock(sk)->dev; + u32 target_idx = nfc_rawsock(sk)->target_idx; + struct sk_buff *skb; + int rc; + + pr_debug("sk=%p target_idx=%u\n", sk, target_idx); + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + rawsock_write_queue_purge(sk); + return; + } + + skb = skb_dequeue(&sk->sk_write_queue); + + sock_hold(sk); + rc = nfc_data_exchange(dev, target_idx, skb, + rawsock_data_exchange_complete, sk); + if (rc) { + rawsock_report_error(sk, rc); + sock_put(sk); + } +} + +static int rawsock_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct nfc_dev *dev = nfc_rawsock(sk)->dev; + struct sk_buff *skb; + int rc; + + pr_debug("sock=%p sk=%p len=%zu\n", sock, sk, len); + + if (msg->msg_namelen) + return -EOPNOTSUPP; + + if (sock->state != SS_CONNECTED) + return -ENOTCONN; + + skb = nfc_alloc_send_skb(dev, sk, msg->msg_flags, len, &rc); + if (skb == NULL) + return rc; + + rc = memcpy_from_msg(skb_put(skb, len), msg, len); + if (rc < 0) { + kfree_skb(skb); + return rc; + } + + spin_lock_bh(&sk->sk_write_queue.lock); + __skb_queue_tail(&sk->sk_write_queue, skb); + if (!nfc_rawsock(sk)->tx_work_scheduled) { + schedule_work(&nfc_rawsock(sk)->tx_work); + nfc_rawsock(sk)->tx_work_scheduled = true; + } + spin_unlock_bh(&sk->sk_write_queue.lock); + + return len; +} + +static int rawsock_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int copied; + int rc; + + pr_debug("sock=%p sk=%p len=%zu flags=%d\n", sock, sk, len, flags); + + skb = skb_recv_datagram(sk, flags, &rc); + if (!skb) + return rc; + + copied = skb->len; + if (len < copied) { + msg->msg_flags |= MSG_TRUNC; + copied = len; + } + + rc = skb_copy_datagram_msg(skb, 0, msg, copied); + + skb_free_datagram(sk, skb); + + return rc ? 
: copied; +} + +static const struct proto_ops rawsock_ops = { + .family = PF_NFC, + .owner = THIS_MODULE, + .release = rawsock_release, + .bind = sock_no_bind, + .connect = rawsock_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = rawsock_sendmsg, + .recvmsg = rawsock_recvmsg, + .mmap = sock_no_mmap, +}; + +static const struct proto_ops rawsock_raw_ops = { + .family = PF_NFC, + .owner = THIS_MODULE, + .release = rawsock_release, + .bind = sock_no_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = datagram_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = sock_no_sendmsg, + .recvmsg = rawsock_recvmsg, + .mmap = sock_no_mmap, +}; + +static void rawsock_destruct(struct sock *sk) +{ + pr_debug("sk=%p\n", sk); + + if (sk->sk_state == TCP_ESTABLISHED) { + nfc_deactivate_target(nfc_rawsock(sk)->dev, + nfc_rawsock(sk)->target_idx, + NFC_TARGET_MODE_IDLE); + nfc_put_device(nfc_rawsock(sk)->dev); + } + + skb_queue_purge(&sk->sk_receive_queue); + + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Freeing alive NFC raw socket %p\n", sk); + return; + } +} + +static int rawsock_create(struct net *net, struct socket *sock, + const struct nfc_protocol *nfc_proto, int kern) +{ + struct sock *sk; + + pr_debug("sock=%p\n", sock); + + if ((sock->type != SOCK_SEQPACKET) && (sock->type != SOCK_RAW)) + return -ESOCKTNOSUPPORT; + + if (sock->type == SOCK_RAW) { + if (!ns_capable(net->user_ns, CAP_NET_RAW)) + return -EPERM; + sock->ops = &rawsock_raw_ops; + } else { + sock->ops = &rawsock_ops; + } + + sk = sk_alloc(net, PF_NFC, GFP_ATOMIC, nfc_proto->proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + sk->sk_protocol = nfc_proto->id; + sk->sk_destruct = rawsock_destruct; + sock->state = SS_UNCONNECTED; + if (sock->type == SOCK_RAW) + nfc_sock_link(&raw_sk_list, sk); + else { + INIT_WORK(&nfc_rawsock(sk)->tx_work, rawsock_tx_work); + nfc_rawsock(sk)->tx_work_scheduled = false; + } + + return 0; +} + +void nfc_send_to_raw_sock(struct nfc_dev *dev, struct sk_buff *skb, + u8 payload_type, u8 direction) +{ + struct sk_buff *skb_copy = NULL, *nskb; + struct sock *sk; + u8 *data; + + read_lock(&raw_sk_list.lock); + + sk_for_each(sk, &raw_sk_list.head) { + if (!skb_copy) { + skb_copy = __pskb_copy_fclone(skb, NFC_RAW_HEADER_SIZE, + GFP_ATOMIC, true); + if (!skb_copy) + continue; + + data = skb_push(skb_copy, NFC_RAW_HEADER_SIZE); + + data[0] = dev ? 
dev->idx : 0xFF; + data[1] = direction & 0x01; + data[1] |= (payload_type << 1); + } + + nskb = skb_clone(skb_copy, GFP_ATOMIC); + if (!nskb) + continue; + + if (sock_queue_rcv_skb(sk, nskb)) + kfree_skb(nskb); + } + + read_unlock(&raw_sk_list.lock); + + kfree_skb(skb_copy); +} +EXPORT_SYMBOL(nfc_send_to_raw_sock); + +static struct proto rawsock_proto = { + .name = "NFC_RAW", + .owner = THIS_MODULE, + .obj_size = sizeof(struct nfc_rawsock), +}; + +static const struct nfc_protocol rawsock_nfc_proto = { + .id = NFC_SOCKPROTO_RAW, + .proto = &rawsock_proto, + .owner = THIS_MODULE, + .create = rawsock_create +}; + +int __init rawsock_init(void) +{ + int rc; + + rc = nfc_proto_register(&rawsock_nfc_proto); + + return rc; +} + +void rawsock_exit(void) +{ + nfc_proto_unregister(&rawsock_nfc_proto); +} diff --git a/net/nsh/Kconfig b/net/nsh/Kconfig new file mode 100644 index 000000000..84f2a2823 --- /dev/null +++ b/net/nsh/Kconfig @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig NET_NSH + tristate "Network Service Header (NSH) protocol" + default n + help + Network Service Header is an implementation of Service Function + Chaining (RFC 7665). The current implementation in Linux supports + only MD type 1 and only with the openvswitch module. + + If unsure, say N. diff --git a/net/nsh/Makefile b/net/nsh/Makefile new file mode 100644 index 000000000..de34a615a --- /dev/null +++ b/net/nsh/Makefile @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_NET_NSH) += nsh.o diff --git a/net/nsh/nsh.c b/net/nsh/nsh.c new file mode 100644 index 000000000..0f23e5e8e --- /dev/null +++ b/net/nsh/nsh.c @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Network Service Header + * + * Copyright (c) 2017 Red Hat, Inc. -- Jiri Benc + */ + +#include +#include +#include +#include +#include + +int nsh_push(struct sk_buff *skb, const struct nshhdr *pushed_nh) +{ + struct nshhdr *nh; + size_t length = nsh_hdr_len(pushed_nh); + u8 next_proto; + + if (skb->mac_len) { + next_proto = TUN_P_ETHERNET; + } else { + next_proto = tun_p_from_eth_p(skb->protocol); + if (!next_proto) + return -EAFNOSUPPORT; + } + + /* Add the NSH header */ + if (skb_cow_head(skb, length) < 0) + return -ENOMEM; + + skb_push(skb, length); + nh = (struct nshhdr *)(skb->data); + memcpy(nh, pushed_nh, length); + nh->np = next_proto; + skb_postpush_rcsum(skb, nh, length); + + skb->protocol = htons(ETH_P_NSH); + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb_reset_mac_len(skb); + + return 0; +} +EXPORT_SYMBOL_GPL(nsh_push); + +int nsh_pop(struct sk_buff *skb) +{ + struct nshhdr *nh; + size_t length; + __be16 inner_proto; + + if (!pskb_may_pull(skb, NSH_BASE_HDR_LEN)) + return -ENOMEM; + nh = (struct nshhdr *)(skb->data); + length = nsh_hdr_len(nh); + if (length < NSH_BASE_HDR_LEN) + return -EINVAL; + inner_proto = tun_p_to_eth_p(nh->np); + if (!pskb_may_pull(skb, length)) + return -ENOMEM; + + if (!inner_proto) + return -EAFNOSUPPORT; + + skb_pull_rcsum(skb, length); + skb_reset_mac_header(skb); + skb_reset_network_header(skb); + skb_reset_mac_len(skb); + skb->protocol = inner_proto; + + return 0; +} +EXPORT_SYMBOL_GPL(nsh_pop); + +static struct sk_buff *nsh_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + u16 mac_offset = skb->mac_header; + unsigned int nsh_len, mac_len; + __be16 proto; + + skb_reset_network_header(skb); + + mac_len = skb->mac_len; + + if (unlikely(!pskb_may_pull(skb, NSH_BASE_HDR_LEN))) + goto out; + 
nsh_len = nsh_hdr_len(nsh_hdr(skb)); + if (nsh_len < NSH_BASE_HDR_LEN) + goto out; + if (unlikely(!pskb_may_pull(skb, nsh_len))) + goto out; + + proto = tun_p_to_eth_p(nsh_hdr(skb)->np); + if (!proto) + goto out; + + __skb_pull(skb, nsh_len); + + skb_reset_mac_header(skb); + skb->mac_len = proto == htons(ETH_P_TEB) ? ETH_HLEN : 0; + skb->protocol = proto; + + features &= NETIF_F_SG; + segs = skb_mac_gso_segment(skb, features); + if (IS_ERR_OR_NULL(segs)) { + skb_gso_error_unwind(skb, htons(ETH_P_NSH), nsh_len, + mac_offset, mac_len); + goto out; + } + + for (skb = segs; skb; skb = skb->next) { + skb->protocol = htons(ETH_P_NSH); + __skb_push(skb, nsh_len); + skb->mac_header = mac_offset; + skb->network_header = skb->mac_header + mac_len; + skb->mac_len = mac_len; + } + +out: + return segs; +} + +static struct packet_offload nsh_packet_offload __read_mostly = { + .type = htons(ETH_P_NSH), + .priority = 15, + .callbacks = { + .gso_segment = nsh_gso_segment, + }, +}; + +static int __init nsh_init_module(void) +{ + dev_add_offload(&nsh_packet_offload); + return 0; +} + +static void __exit nsh_cleanup_module(void) +{ + dev_remove_offload(&nsh_packet_offload); +} + +module_init(nsh_init_module); +module_exit(nsh_cleanup_module); + +MODULE_AUTHOR("Jiri Benc "); +MODULE_DESCRIPTION("NSH protocol"); +MODULE_LICENSE("GPL v2"); diff --git a/net/openvswitch/Kconfig b/net/openvswitch/Kconfig new file mode 100644 index 000000000..15bd287f5 --- /dev/null +++ b/net/openvswitch/Kconfig @@ -0,0 +1,74 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Open vSwitch +# + +config OPENVSWITCH + tristate "Open vSwitch" + depends on INET + depends on !NF_CONNTRACK || \ + (NF_CONNTRACK && ((!NF_DEFRAG_IPV6 || NF_DEFRAG_IPV6) && \ + (!NF_NAT || NF_NAT) && \ + (!NETFILTER_CONNCOUNT || NETFILTER_CONNCOUNT))) + select LIBCRC32C + select MPLS + select NET_MPLS_GSO + select DST_CACHE + select NET_NSH + help + Open vSwitch is a multilayer Ethernet switch targeted at virtualized + environments. In addition to supporting a variety of features + expected in a traditional hardware switch, it enables fine-grained + programmatic extension and flow-based control of the network. This + control is useful in a wide variety of applications but is + particularly important in multi-server virtualization deployments, + which are often characterized by highly dynamic endpoints and the + need to maintain logical abstractions for multiple tenants. + + The Open vSwitch datapath provides an in-kernel fast path for packet + forwarding. It is complemented by a userspace daemon, ovs-vswitchd, + which is able to accept configuration from a variety of sources and + translate it into packet processing rules. + + See http://openvswitch.org for more information and userspace + utilities. + + To compile this code as a module, choose M here: the module will be + called openvswitch. + + If unsure, say N. + +config OPENVSWITCH_GRE + tristate "Open vSwitch GRE tunneling support" + depends on OPENVSWITCH + depends on NET_IPGRE + default OPENVSWITCH + help + If you say Y here, then Open vSwitch will be able to create GRE + vports. + + Say N to exclude this support and reduce the binary size. + + If unsure, say Y. + +config OPENVSWITCH_VXLAN + tristate "Open vSwitch VXLAN tunneling support" + depends on OPENVSWITCH + depends on VXLAN + default OPENVSWITCH + help + If you say Y here, then Open vSwitch will be able to create VXLAN vports. + + Say N to exclude this support and reduce the binary size. + + If unsure, say Y. 
+ +config OPENVSWITCH_GENEVE + tristate "Open vSwitch Geneve tunneling support" + depends on OPENVSWITCH + depends on GENEVE + default OPENVSWITCH + help + If you say Y here, then Open vSwitch will be able to create Geneve vports. + + Say N to exclude this support and reduce the binary size. diff --git a/net/openvswitch/Makefile b/net/openvswitch/Makefile new file mode 100644 index 000000000..28982630b --- /dev/null +++ b/net/openvswitch/Makefile @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Open vSwitch. +# + +obj-$(CONFIG_OPENVSWITCH) += openvswitch.o + +openvswitch-y := \ + actions.o \ + datapath.o \ + dp_notify.o \ + flow.o \ + flow_netlink.o \ + flow_table.o \ + meter.o \ + openvswitch_trace.o \ + vport.o \ + vport-internal_dev.o \ + vport-netdev.o + +ifneq ($(CONFIG_NF_CONNTRACK),) +openvswitch-y += conntrack.o +endif + +obj-$(CONFIG_OPENVSWITCH_VXLAN)+= vport-vxlan.o +obj-$(CONFIG_OPENVSWITCH_GENEVE)+= vport-geneve.o +obj-$(CONFIG_OPENVSWITCH_GRE) += vport-gre.o + +CFLAGS_openvswitch_trace.o = -I$(src) diff --git a/net/openvswitch/actions.c b/net/openvswitch/actions.c new file mode 100644 index 000000000..a8cf9a887 --- /dev/null +++ b/net/openvswitch/actions.c @@ -0,0 +1,1629 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2017 Nicira, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "flow.h" +#include "conntrack.h" +#include "vport.h" +#include "flow_netlink.h" +#include "openvswitch_trace.h" + +struct deferred_action { + struct sk_buff *skb; + const struct nlattr *actions; + int actions_len; + + /* Store pkt_key clone when creating deferred action. */ + struct sw_flow_key pkt_key; +}; + +#define MAX_L2_LEN (VLAN_ETH_HLEN + 3 * MPLS_HLEN) +struct ovs_frag_data { + unsigned long dst; + struct vport *vport; + struct ovs_skb_cb cb; + __be16 inner_protocol; + u16 network_offset; /* valid only for MPLS */ + u16 vlan_tci; + __be16 vlan_proto; + unsigned int l2_len; + u8 mac_proto; + u8 l2_data[MAX_L2_LEN]; +}; + +static DEFINE_PER_CPU(struct ovs_frag_data, ovs_frag_data_storage); + +#define DEFERRED_ACTION_FIFO_SIZE 10 +#define OVS_RECURSION_LIMIT 5 +#define OVS_DEFERRED_ACTION_THRESHOLD (OVS_RECURSION_LIMIT - 2) +struct action_fifo { + int head; + int tail; + /* Deferred action fifo queue storage. */ + struct deferred_action fifo[DEFERRED_ACTION_FIFO_SIZE]; +}; + +struct action_flow_keys { + struct sw_flow_key key[OVS_DEFERRED_ACTION_THRESHOLD]; +}; + +static struct action_fifo __percpu *action_fifos; +static struct action_flow_keys __percpu *flow_keys; +static DEFINE_PER_CPU(int, exec_actions_level); + +/* Make a clone of the 'key', using the pre-allocated percpu 'flow_keys' + * space. Return NULL if out of key spaces. 
+ */ +static struct sw_flow_key *clone_key(const struct sw_flow_key *key_) +{ + struct action_flow_keys *keys = this_cpu_ptr(flow_keys); + int level = this_cpu_read(exec_actions_level); + struct sw_flow_key *key = NULL; + + if (level <= OVS_DEFERRED_ACTION_THRESHOLD) { + key = &keys->key[level - 1]; + *key = *key_; + } + + return key; +} + +static void action_fifo_init(struct action_fifo *fifo) +{ + fifo->head = 0; + fifo->tail = 0; +} + +static bool action_fifo_is_empty(const struct action_fifo *fifo) +{ + return (fifo->head == fifo->tail); +} + +static struct deferred_action *action_fifo_get(struct action_fifo *fifo) +{ + if (action_fifo_is_empty(fifo)) + return NULL; + + return &fifo->fifo[fifo->tail++]; +} + +static struct deferred_action *action_fifo_put(struct action_fifo *fifo) +{ + if (fifo->head >= DEFERRED_ACTION_FIFO_SIZE - 1) + return NULL; + + return &fifo->fifo[fifo->head++]; +} + +/* Return true if fifo is not full */ +static struct deferred_action *add_deferred_actions(struct sk_buff *skb, + const struct sw_flow_key *key, + const struct nlattr *actions, + const int actions_len) +{ + struct action_fifo *fifo; + struct deferred_action *da; + + fifo = this_cpu_ptr(action_fifos); + da = action_fifo_put(fifo); + if (da) { + da->skb = skb; + da->actions = actions; + da->actions_len = actions_len; + da->pkt_key = *key; + } + + return da; +} + +static void invalidate_flow_key(struct sw_flow_key *key) +{ + key->mac_proto |= SW_FLOW_KEY_INVALID; +} + +static bool is_flow_key_valid(const struct sw_flow_key *key) +{ + return !(key->mac_proto & SW_FLOW_KEY_INVALID); +} + +static int clone_execute(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + u32 recirc_id, + const struct nlattr *actions, int len, + bool last, bool clone_flow_key); + +static int do_execute_actions(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + const struct nlattr *attr, int len); + +static int push_mpls(struct sk_buff *skb, struct sw_flow_key *key, + __be32 mpls_lse, __be16 mpls_ethertype, __u16 mac_len) +{ + int err; + + err = skb_mpls_push(skb, mpls_lse, mpls_ethertype, mac_len, !!mac_len); + if (err) + return err; + + if (!mac_len) + key->mac_proto = MAC_PROTO_NONE; + + invalidate_flow_key(key); + return 0; +} + +static int pop_mpls(struct sk_buff *skb, struct sw_flow_key *key, + const __be16 ethertype) +{ + int err; + + err = skb_mpls_pop(skb, ethertype, skb->mac_len, + ovs_key_mac_proto(key) == MAC_PROTO_ETHERNET); + if (err) + return err; + + if (ethertype == htons(ETH_P_TEB)) + key->mac_proto = MAC_PROTO_ETHERNET; + + invalidate_flow_key(key); + return 0; +} + +static int set_mpls(struct sk_buff *skb, struct sw_flow_key *flow_key, + const __be32 *mpls_lse, const __be32 *mask) +{ + struct mpls_shim_hdr *stack; + __be32 lse; + int err; + + if (!pskb_may_pull(skb, skb_network_offset(skb) + MPLS_HLEN)) + return -ENOMEM; + + stack = mpls_hdr(skb); + lse = OVS_MASKED(stack->label_stack_entry, *mpls_lse, *mask); + err = skb_mpls_update_lse(skb, lse); + if (err) + return err; + + flow_key->mpls.lse[0] = lse; + return 0; +} + +static int pop_vlan(struct sk_buff *skb, struct sw_flow_key *key) +{ + int err; + + err = skb_vlan_pop(skb); + if (skb_vlan_tag_present(skb)) { + invalidate_flow_key(key); + } else { + key->eth.vlan.tci = 0; + key->eth.vlan.tpid = 0; + } + return err; +} + +static int push_vlan(struct sk_buff *skb, struct sw_flow_key *key, + const struct ovs_action_push_vlan *vlan) +{ + if (skb_vlan_tag_present(skb)) { + invalidate_flow_key(key); + } else { + 
key->eth.vlan.tci = vlan->vlan_tci; + key->eth.vlan.tpid = vlan->vlan_tpid; + } + return skb_vlan_push(skb, vlan->vlan_tpid, + ntohs(vlan->vlan_tci) & ~VLAN_CFI_MASK); +} + +/* 'src' is already properly masked. */ +static void ether_addr_copy_masked(u8 *dst_, const u8 *src_, const u8 *mask_) +{ + u16 *dst = (u16 *)dst_; + const u16 *src = (const u16 *)src_; + const u16 *mask = (const u16 *)mask_; + + OVS_SET_MASKED(dst[0], src[0], mask[0]); + OVS_SET_MASKED(dst[1], src[1], mask[1]); + OVS_SET_MASKED(dst[2], src[2], mask[2]); +} + +static int set_eth_addr(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_ethernet *key, + const struct ovs_key_ethernet *mask) +{ + int err; + + err = skb_ensure_writable(skb, ETH_HLEN); + if (unlikely(err)) + return err; + + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_ALEN * 2); + + ether_addr_copy_masked(eth_hdr(skb)->h_source, key->eth_src, + mask->eth_src); + ether_addr_copy_masked(eth_hdr(skb)->h_dest, key->eth_dst, + mask->eth_dst); + + skb_postpush_rcsum(skb, eth_hdr(skb), ETH_ALEN * 2); + + ether_addr_copy(flow_key->eth.src, eth_hdr(skb)->h_source); + ether_addr_copy(flow_key->eth.dst, eth_hdr(skb)->h_dest); + return 0; +} + +/* pop_eth does not support VLAN packets as this action is never called + * for them. + */ +static int pop_eth(struct sk_buff *skb, struct sw_flow_key *key) +{ + int err; + + err = skb_eth_pop(skb); + if (err) + return err; + + /* safe right before invalidate_flow_key */ + key->mac_proto = MAC_PROTO_NONE; + invalidate_flow_key(key); + return 0; +} + +static int push_eth(struct sk_buff *skb, struct sw_flow_key *key, + const struct ovs_action_push_eth *ethh) +{ + int err; + + err = skb_eth_push(skb, ethh->addresses.eth_dst, + ethh->addresses.eth_src); + if (err) + return err; + + /* safe right before invalidate_flow_key */ + key->mac_proto = MAC_PROTO_ETHERNET; + invalidate_flow_key(key); + return 0; +} + +static int push_nsh(struct sk_buff *skb, struct sw_flow_key *key, + const struct nshhdr *nh) +{ + int err; + + err = nsh_push(skb, nh); + if (err) + return err; + + /* safe right before invalidate_flow_key */ + key->mac_proto = MAC_PROTO_NONE; + invalidate_flow_key(key); + return 0; +} + +static int pop_nsh(struct sk_buff *skb, struct sw_flow_key *key) +{ + int err; + + err = nsh_pop(skb); + if (err) + return err; + + /* safe right before invalidate_flow_key */ + if (skb->protocol == htons(ETH_P_TEB)) + key->mac_proto = MAC_PROTO_ETHERNET; + else + key->mac_proto = MAC_PROTO_NONE; + invalidate_flow_key(key); + return 0; +} + +static void update_ip_l4_checksum(struct sk_buff *skb, struct iphdr *nh, + __be32 addr, __be32 new_addr) +{ + int transport_len = skb->len - skb_transport_offset(skb); + + if (nh->frag_off & htons(IP_OFFSET)) + return; + + if (nh->protocol == IPPROTO_TCP) { + if (likely(transport_len >= sizeof(struct tcphdr))) + inet_proto_csum_replace4(&tcp_hdr(skb)->check, skb, + addr, new_addr, true); + } else if (nh->protocol == IPPROTO_UDP) { + if (likely(transport_len >= sizeof(struct udphdr))) { + struct udphdr *uh = udp_hdr(skb); + + if (uh->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace4(&uh->check, skb, + addr, new_addr, true); + if (!uh->check) + uh->check = CSUM_MANGLED_0; + } + } + } +} + +static void set_ip_addr(struct sk_buff *skb, struct iphdr *nh, + __be32 *addr, __be32 new_addr) +{ + update_ip_l4_checksum(skb, nh, *addr, new_addr); + csum_replace4(&nh->check, *addr, new_addr); + skb_clear_hash(skb); + ovs_ct_clear(skb, NULL); + *addr = new_addr; +} + +static void 
update_ipv6_checksum(struct sk_buff *skb, u8 l4_proto, + __be32 addr[4], const __be32 new_addr[4]) +{ + int transport_len = skb->len - skb_transport_offset(skb); + + if (l4_proto == NEXTHDR_TCP) { + if (likely(transport_len >= sizeof(struct tcphdr))) + inet_proto_csum_replace16(&tcp_hdr(skb)->check, skb, + addr, new_addr, true); + } else if (l4_proto == NEXTHDR_UDP) { + if (likely(transport_len >= sizeof(struct udphdr))) { + struct udphdr *uh = udp_hdr(skb); + + if (uh->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace16(&uh->check, skb, + addr, new_addr, true); + if (!uh->check) + uh->check = CSUM_MANGLED_0; + } + } + } else if (l4_proto == NEXTHDR_ICMP) { + if (likely(transport_len >= sizeof(struct icmp6hdr))) + inet_proto_csum_replace16(&icmp6_hdr(skb)->icmp6_cksum, + skb, addr, new_addr, true); + } +} + +static void mask_ipv6_addr(const __be32 old[4], const __be32 addr[4], + const __be32 mask[4], __be32 masked[4]) +{ + masked[0] = OVS_MASKED(old[0], addr[0], mask[0]); + masked[1] = OVS_MASKED(old[1], addr[1], mask[1]); + masked[2] = OVS_MASKED(old[2], addr[2], mask[2]); + masked[3] = OVS_MASKED(old[3], addr[3], mask[3]); +} + +static void set_ipv6_addr(struct sk_buff *skb, u8 l4_proto, + __be32 addr[4], const __be32 new_addr[4], + bool recalculate_csum) +{ + if (recalculate_csum) + update_ipv6_checksum(skb, l4_proto, addr, new_addr); + + skb_clear_hash(skb); + ovs_ct_clear(skb, NULL); + memcpy(addr, new_addr, sizeof(__be32[4])); +} + +static void set_ipv6_dsfield(struct sk_buff *skb, struct ipv6hdr *nh, u8 ipv6_tclass, u8 mask) +{ + u8 old_ipv6_tclass = ipv6_get_dsfield(nh); + + ipv6_tclass = OVS_MASKED(old_ipv6_tclass, ipv6_tclass, mask); + + if (skb->ip_summed == CHECKSUM_COMPLETE) + csum_replace(&skb->csum, (__force __wsum)(old_ipv6_tclass << 12), + (__force __wsum)(ipv6_tclass << 12)); + + ipv6_change_dsfield(nh, ~mask, ipv6_tclass); +} + +static void set_ipv6_fl(struct sk_buff *skb, struct ipv6hdr *nh, u32 fl, u32 mask) +{ + u32 ofl; + + ofl = nh->flow_lbl[0] << 16 | nh->flow_lbl[1] << 8 | nh->flow_lbl[2]; + fl = OVS_MASKED(ofl, fl, mask); + + /* Bits 21-24 are always unmasked, so this retains their values. */ + nh->flow_lbl[0] = (u8)(fl >> 16); + nh->flow_lbl[1] = (u8)(fl >> 8); + nh->flow_lbl[2] = (u8)fl; + + if (skb->ip_summed == CHECKSUM_COMPLETE) + csum_replace(&skb->csum, (__force __wsum)htonl(ofl), (__force __wsum)htonl(fl)); +} + +static void set_ipv6_ttl(struct sk_buff *skb, struct ipv6hdr *nh, u8 new_ttl, u8 mask) +{ + new_ttl = OVS_MASKED(nh->hop_limit, new_ttl, mask); + + if (skb->ip_summed == CHECKSUM_COMPLETE) + csum_replace(&skb->csum, (__force __wsum)(nh->hop_limit << 8), + (__force __wsum)(new_ttl << 8)); + nh->hop_limit = new_ttl; +} + +static void set_ip_ttl(struct sk_buff *skb, struct iphdr *nh, u8 new_ttl, + u8 mask) +{ + new_ttl = OVS_MASKED(nh->ttl, new_ttl, mask); + + csum_replace2(&nh->check, htons(nh->ttl << 8), htons(new_ttl << 8)); + nh->ttl = new_ttl; +} + +static int set_ipv4(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_ipv4 *key, + const struct ovs_key_ipv4 *mask) +{ + struct iphdr *nh; + __be32 new_addr; + int err; + + err = skb_ensure_writable(skb, skb_network_offset(skb) + + sizeof(struct iphdr)); + if (unlikely(err)) + return err; + + nh = ip_hdr(skb); + + /* Setting an IP addresses is typically only a side effect of + * matching on them in the current userspace implementation, so it + * makes sense to check if the value actually changed. 
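+ * Skipping unchanged addresses also avoids needless checksum
+ * updates, hash clearing and conntrack clearing in set_ip_addr().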
+ */ + if (mask->ipv4_src) { + new_addr = OVS_MASKED(nh->saddr, key->ipv4_src, mask->ipv4_src); + + if (unlikely(new_addr != nh->saddr)) { + set_ip_addr(skb, nh, &nh->saddr, new_addr); + flow_key->ipv4.addr.src = new_addr; + } + } + if (mask->ipv4_dst) { + new_addr = OVS_MASKED(nh->daddr, key->ipv4_dst, mask->ipv4_dst); + + if (unlikely(new_addr != nh->daddr)) { + set_ip_addr(skb, nh, &nh->daddr, new_addr); + flow_key->ipv4.addr.dst = new_addr; + } + } + if (mask->ipv4_tos) { + ipv4_change_dsfield(nh, ~mask->ipv4_tos, key->ipv4_tos); + flow_key->ip.tos = nh->tos; + } + if (mask->ipv4_ttl) { + set_ip_ttl(skb, nh, key->ipv4_ttl, mask->ipv4_ttl); + flow_key->ip.ttl = nh->ttl; + } + + return 0; +} + +static bool is_ipv6_mask_nonzero(const __be32 addr[4]) +{ + return !!(addr[0] | addr[1] | addr[2] | addr[3]); +} + +static int set_ipv6(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_ipv6 *key, + const struct ovs_key_ipv6 *mask) +{ + struct ipv6hdr *nh; + int err; + + err = skb_ensure_writable(skb, skb_network_offset(skb) + + sizeof(struct ipv6hdr)); + if (unlikely(err)) + return err; + + nh = ipv6_hdr(skb); + + /* Setting an IP addresses is typically only a side effect of + * matching on them in the current userspace implementation, so it + * makes sense to check if the value actually changed. + */ + if (is_ipv6_mask_nonzero(mask->ipv6_src)) { + __be32 *saddr = (__be32 *)&nh->saddr; + __be32 masked[4]; + + mask_ipv6_addr(saddr, key->ipv6_src, mask->ipv6_src, masked); + + if (unlikely(memcmp(saddr, masked, sizeof(masked)))) { + set_ipv6_addr(skb, flow_key->ip.proto, saddr, masked, + true); + memcpy(&flow_key->ipv6.addr.src, masked, + sizeof(flow_key->ipv6.addr.src)); + } + } + if (is_ipv6_mask_nonzero(mask->ipv6_dst)) { + unsigned int offset = 0; + int flags = IP6_FH_F_SKIP_RH; + bool recalc_csum = true; + __be32 *daddr = (__be32 *)&nh->daddr; + __be32 masked[4]; + + mask_ipv6_addr(daddr, key->ipv6_dst, mask->ipv6_dst, masked); + + if (unlikely(memcmp(daddr, masked, sizeof(masked)))) { + if (ipv6_ext_hdr(nh->nexthdr)) + recalc_csum = (ipv6_find_hdr(skb, &offset, + NEXTHDR_ROUTING, + NULL, &flags) + != NEXTHDR_ROUTING); + + set_ipv6_addr(skb, flow_key->ip.proto, daddr, masked, + recalc_csum); + memcpy(&flow_key->ipv6.addr.dst, masked, + sizeof(flow_key->ipv6.addr.dst)); + } + } + if (mask->ipv6_tclass) { + set_ipv6_dsfield(skb, nh, key->ipv6_tclass, mask->ipv6_tclass); + flow_key->ip.tos = ipv6_get_dsfield(nh); + } + if (mask->ipv6_label) { + set_ipv6_fl(skb, nh, ntohl(key->ipv6_label), + ntohl(mask->ipv6_label)); + flow_key->ipv6.label = + *(__be32 *)nh & htonl(IPV6_FLOWINFO_FLOWLABEL); + } + if (mask->ipv6_hlimit) { + set_ipv6_ttl(skb, nh, key->ipv6_hlimit, mask->ipv6_hlimit); + flow_key->ip.ttl = nh->hop_limit; + } + return 0; +} + +static int set_nsh(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct nlattr *a) +{ + struct nshhdr *nh; + size_t length; + int err; + u8 flags; + u8 ttl; + int i; + + struct ovs_key_nsh key; + struct ovs_key_nsh mask; + + err = nsh_key_from_nlattr(a, &key, &mask); + if (err) + return err; + + /* Make sure the NSH base header is there */ + if (!pskb_may_pull(skb, skb_network_offset(skb) + NSH_BASE_HDR_LEN)) + return -ENOMEM; + + nh = nsh_hdr(skb); + length = nsh_hdr_len(nh); + + /* Make sure the whole NSH header is there */ + err = skb_ensure_writable(skb, skb_network_offset(skb) + + length); + if (unlikely(err)) + return err; + + nh = nsh_hdr(skb); + skb_postpull_rcsum(skb, nh, length); + flags = nsh_get_flags(nh); + flags = 
OVS_MASKED(flags, key.base.flags, mask.base.flags); + flow_key->nsh.base.flags = flags; + ttl = nsh_get_ttl(nh); + ttl = OVS_MASKED(ttl, key.base.ttl, mask.base.ttl); + flow_key->nsh.base.ttl = ttl; + nsh_set_flags_and_ttl(nh, flags, ttl); + nh->path_hdr = OVS_MASKED(nh->path_hdr, key.base.path_hdr, + mask.base.path_hdr); + flow_key->nsh.base.path_hdr = nh->path_hdr; + switch (nh->mdtype) { + case NSH_M_TYPE1: + for (i = 0; i < NSH_MD1_CONTEXT_SIZE; i++) { + nh->md1.context[i] = + OVS_MASKED(nh->md1.context[i], key.context[i], + mask.context[i]); + } + memcpy(flow_key->nsh.context, nh->md1.context, + sizeof(nh->md1.context)); + break; + case NSH_M_TYPE2: + memset(flow_key->nsh.context, 0, + sizeof(flow_key->nsh.context)); + break; + default: + return -EINVAL; + } + skb_postpush_rcsum(skb, nh, length); + return 0; +} + +/* Must follow skb_ensure_writable() since that can move the skb data. */ +static void set_tp_port(struct sk_buff *skb, __be16 *port, + __be16 new_port, __sum16 *check) +{ + ovs_ct_clear(skb, NULL); + inet_proto_csum_replace2(check, skb, *port, new_port, false); + *port = new_port; +} + +static int set_udp(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_udp *key, + const struct ovs_key_udp *mask) +{ + struct udphdr *uh; + __be16 src, dst; + int err; + + err = skb_ensure_writable(skb, skb_transport_offset(skb) + + sizeof(struct udphdr)); + if (unlikely(err)) + return err; + + uh = udp_hdr(skb); + /* Either of the masks is non-zero, so do not bother checking them. */ + src = OVS_MASKED(uh->source, key->udp_src, mask->udp_src); + dst = OVS_MASKED(uh->dest, key->udp_dst, mask->udp_dst); + + if (uh->check && skb->ip_summed != CHECKSUM_PARTIAL) { + if (likely(src != uh->source)) { + set_tp_port(skb, &uh->source, src, &uh->check); + flow_key->tp.src = src; + } + if (likely(dst != uh->dest)) { + set_tp_port(skb, &uh->dest, dst, &uh->check); + flow_key->tp.dst = dst; + } + + if (unlikely(!uh->check)) + uh->check = CSUM_MANGLED_0; + } else { + uh->source = src; + uh->dest = dst; + flow_key->tp.src = src; + flow_key->tp.dst = dst; + ovs_ct_clear(skb, NULL); + } + + skb_clear_hash(skb); + + return 0; +} + +static int set_tcp(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_tcp *key, + const struct ovs_key_tcp *mask) +{ + struct tcphdr *th; + __be16 src, dst; + int err; + + err = skb_ensure_writable(skb, skb_transport_offset(skb) + + sizeof(struct tcphdr)); + if (unlikely(err)) + return err; + + th = tcp_hdr(skb); + src = OVS_MASKED(th->source, key->tcp_src, mask->tcp_src); + if (likely(src != th->source)) { + set_tp_port(skb, &th->source, src, &th->check); + flow_key->tp.src = src; + } + dst = OVS_MASKED(th->dest, key->tcp_dst, mask->tcp_dst); + if (likely(dst != th->dest)) { + set_tp_port(skb, &th->dest, dst, &th->check); + flow_key->tp.dst = dst; + } + skb_clear_hash(skb); + + return 0; +} + +static int set_sctp(struct sk_buff *skb, struct sw_flow_key *flow_key, + const struct ovs_key_sctp *key, + const struct ovs_key_sctp *mask) +{ + unsigned int sctphoff = skb_transport_offset(skb); + struct sctphdr *sh; + __le32 old_correct_csum, new_csum, old_csum; + int err; + + err = skb_ensure_writable(skb, sctphoff + sizeof(struct sctphdr)); + if (unlikely(err)) + return err; + + sh = sctp_hdr(skb); + old_csum = sh->checksum; + old_correct_csum = sctp_compute_cksum(skb, sctphoff); + + sh->source = OVS_MASKED(sh->source, key->sctp_src, mask->sctp_src); + sh->dest = OVS_MASKED(sh->dest, key->sctp_dst, mask->sctp_dst); + + new_csum = 
sctp_compute_cksum(skb, sctphoff); + + /* Carry any checksum errors through. */ + sh->checksum = old_csum ^ old_correct_csum ^ new_csum; + + skb_clear_hash(skb); + ovs_ct_clear(skb, NULL); + + flow_key->tp.src = sh->source; + flow_key->tp.dst = sh->dest; + + return 0; +} + +static int ovs_vport_output(struct net *net, struct sock *sk, + struct sk_buff *skb) +{ + struct ovs_frag_data *data = this_cpu_ptr(&ovs_frag_data_storage); + struct vport *vport = data->vport; + + if (skb_cow_head(skb, data->l2_len) < 0) { + kfree_skb(skb); + return -ENOMEM; + } + + __skb_dst_copy(skb, data->dst); + *OVS_CB(skb) = data->cb; + skb->inner_protocol = data->inner_protocol; + if (data->vlan_tci & VLAN_CFI_MASK) + __vlan_hwaccel_put_tag(skb, data->vlan_proto, data->vlan_tci & ~VLAN_CFI_MASK); + else + __vlan_hwaccel_clear_tag(skb); + + /* Reconstruct the MAC header. */ + skb_push(skb, data->l2_len); + memcpy(skb->data, &data->l2_data, data->l2_len); + skb_postpush_rcsum(skb, skb->data, data->l2_len); + skb_reset_mac_header(skb); + + if (eth_p_mpls(skb->protocol)) { + skb->inner_network_header = skb->network_header; + skb_set_network_header(skb, data->network_offset); + skb_reset_mac_len(skb); + } + + ovs_vport_send(vport, skb, data->mac_proto); + return 0; +} + +static unsigned int +ovs_dst_get_mtu(const struct dst_entry *dst) +{ + return dst->dev->mtu; +} + +static struct dst_ops ovs_dst_ops = { + .family = AF_UNSPEC, + .mtu = ovs_dst_get_mtu, +}; + +/* prepare_frag() is called once per (larger-than-MTU) frame; its inverse is + * ovs_vport_output(), which is called once per fragmented packet. + */ +static void prepare_frag(struct vport *vport, struct sk_buff *skb, + u16 orig_network_offset, u8 mac_proto) +{ + unsigned int hlen = skb_network_offset(skb); + struct ovs_frag_data *data; + + data = this_cpu_ptr(&ovs_frag_data_storage); + data->dst = skb->_skb_refdst; + data->vport = vport; + data->cb = *OVS_CB(skb); + data->inner_protocol = skb->inner_protocol; + data->network_offset = orig_network_offset; + if (skb_vlan_tag_present(skb)) + data->vlan_tci = skb_vlan_tag_get(skb) | VLAN_CFI_MASK; + else + data->vlan_tci = 0; + data->vlan_proto = skb->vlan_proto; + data->mac_proto = mac_proto; + data->l2_len = hlen; + memcpy(&data->l2_data, skb->data, hlen); + + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + skb_pull(skb, hlen); +} + +static void ovs_fragment(struct net *net, struct vport *vport, + struct sk_buff *skb, u16 mru, + struct sw_flow_key *key) +{ + u16 orig_network_offset = 0; + + if (eth_p_mpls(skb->protocol)) { + orig_network_offset = skb_network_offset(skb); + skb->network_header = skb->inner_network_header; + } + + if (skb_network_offset(skb) > MAX_L2_LEN) { + OVS_NLERR(1, "L2 header too long to fragment"); + goto err; + } + + if (key->eth.type == htons(ETH_P_IP)) { + struct rtable ovs_rt = { 0 }; + unsigned long orig_dst; + + prepare_frag(vport, skb, orig_network_offset, + ovs_key_mac_proto(key)); + dst_init(&ovs_rt.dst, &ovs_dst_ops, NULL, 1, + DST_OBSOLETE_NONE, DST_NOCOUNT); + ovs_rt.dst.dev = vport->dev; + + orig_dst = skb->_skb_refdst; + skb_dst_set_noref(skb, &ovs_rt.dst); + IPCB(skb)->frag_max_size = mru; + + ip_do_fragment(net, skb->sk, skb, ovs_vport_output); + refdst_drop(orig_dst); + } else if (key->eth.type == htons(ETH_P_IPV6)) { + unsigned long orig_dst; + struct rt6_info ovs_rt; + + prepare_frag(vport, skb, orig_network_offset, + ovs_key_mac_proto(key)); + memset(&ovs_rt, 0, sizeof(ovs_rt)); + dst_init(&ovs_rt.dst, &ovs_dst_ops, NULL, 1, + DST_OBSOLETE_NONE, DST_NOCOUNT); + 
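+ /* The on-stack dst exists only so the IPv6 fragmentation code can
+ * reach the egress device and its MTU through ovs_dst_ops;
+ * DST_NOCOUNT and skb_dst_set_noref() keep it out of refcounting.
+ */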
ovs_rt.dst.dev = vport->dev; + + orig_dst = skb->_skb_refdst; + skb_dst_set_noref(skb, &ovs_rt.dst); + IP6CB(skb)->frag_max_size = mru; + + ipv6_stub->ipv6_fragment(net, skb->sk, skb, ovs_vport_output); + refdst_drop(orig_dst); + } else { + WARN_ONCE(1, "Failed fragment ->%s: eth=%04x, MRU=%d, MTU=%d.", + ovs_vport_name(vport), ntohs(key->eth.type), mru, + vport->dev->mtu); + goto err; + } + + return; +err: + kfree_skb(skb); +} + +static void do_output(struct datapath *dp, struct sk_buff *skb, int out_port, + struct sw_flow_key *key) +{ + struct vport *vport = ovs_vport_rcu(dp, out_port); + + if (likely(vport && netif_carrier_ok(vport->dev))) { + u16 mru = OVS_CB(skb)->mru; + u32 cutlen = OVS_CB(skb)->cutlen; + + if (unlikely(cutlen > 0)) { + if (skb->len - cutlen > ovs_mac_header_len(key)) + pskb_trim(skb, skb->len - cutlen); + else + pskb_trim(skb, ovs_mac_header_len(key)); + } + + if (likely(!mru || + (skb->len <= mru + vport->dev->hard_header_len))) { + ovs_vport_send(vport, skb, ovs_key_mac_proto(key)); + } else if (mru <= vport->dev->mtu) { + struct net *net = read_pnet(&dp->net); + + ovs_fragment(net, vport, skb, mru, key); + } else { + kfree_skb(skb); + } + } else { + kfree_skb(skb); + } +} + +static int output_userspace(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, const struct nlattr *attr, + const struct nlattr *actions, int actions_len, + uint32_t cutlen) +{ + struct dp_upcall_info upcall; + const struct nlattr *a; + int rem; + + memset(&upcall, 0, sizeof(upcall)); + upcall.cmd = OVS_PACKET_CMD_ACTION; + upcall.mru = OVS_CB(skb)->mru; + + for (a = nla_data(attr), rem = nla_len(attr); rem > 0; + a = nla_next(a, &rem)) { + switch (nla_type(a)) { + case OVS_USERSPACE_ATTR_USERDATA: + upcall.userdata = a; + break; + + case OVS_USERSPACE_ATTR_PID: + if (dp->user_features & + OVS_DP_F_DISPATCH_UPCALL_PER_CPU) + upcall.portid = + ovs_dp_get_upcall_portid(dp, + smp_processor_id()); + else + upcall.portid = nla_get_u32(a); + break; + + case OVS_USERSPACE_ATTR_EGRESS_TUN_PORT: { + /* Get out tunnel info. */ + struct vport *vport; + + vport = ovs_vport_rcu(dp, nla_get_u32(a)); + if (vport) { + int err; + + err = dev_fill_metadata_dst(vport->dev, skb); + if (!err) + upcall.egress_tun_info = skb_tunnel_info(skb); + } + + break; + } + + case OVS_USERSPACE_ATTR_ACTIONS: { + /* Include actions. */ + upcall.actions = actions; + upcall.actions_len = actions_len; + break; + } + + } /* End of switch. */ + } + + return ovs_dp_upcall(dp, skb, key, &upcall, cutlen); +} + +static int dec_ttl_exception_handler(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + const struct nlattr *attr) +{ + /* The first attribute is always 'OVS_DEC_TTL_ATTR_ACTION'. */ + struct nlattr *actions = nla_data(attr); + + if (nla_len(actions)) + return clone_execute(dp, skb, key, 0, nla_data(actions), + nla_len(actions), true, false); + + consume_skb(skb); + return 0; +} + +/* When 'last' is true, sample() should always consume the 'skb'. + * Otherwise, sample() should keep 'skb' intact regardless what + * actions are executed within sample(). + */ +static int sample(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, const struct nlattr *attr, + bool last) +{ + struct nlattr *actions; + struct nlattr *sample_arg; + int rem = nla_len(attr); + const struct sample_arg *arg; + bool clone_flow_key; + + /* The first action is always 'OVS_SAMPLE_ATTR_ARG'. 
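+ * The remaining nested attributes are the actions to run when the
+ * probability check passes; a probability of U32_MAX means
+ * "always sample".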
*/ + sample_arg = nla_data(attr); + arg = nla_data(sample_arg); + actions = nla_next(sample_arg, &rem); + + if ((arg->probability != U32_MAX) && + (!arg->probability || get_random_u32() > arg->probability)) { + if (last) + consume_skb(skb); + return 0; + } + + clone_flow_key = !arg->exec; + return clone_execute(dp, skb, key, 0, actions, rem, last, + clone_flow_key); +} + +/* When 'last' is true, clone() should always consume the 'skb'. + * Otherwise, clone() should keep 'skb' intact regardless what + * actions are executed within clone(). + */ +static int clone(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, const struct nlattr *attr, + bool last) +{ + struct nlattr *actions; + struct nlattr *clone_arg; + int rem = nla_len(attr); + bool dont_clone_flow_key; + + /* The first action is always 'OVS_CLONE_ATTR_EXEC'. */ + clone_arg = nla_data(attr); + dont_clone_flow_key = nla_get_u32(clone_arg); + actions = nla_next(clone_arg, &rem); + + return clone_execute(dp, skb, key, 0, actions, rem, last, + !dont_clone_flow_key); +} + +static void execute_hash(struct sk_buff *skb, struct sw_flow_key *key, + const struct nlattr *attr) +{ + struct ovs_action_hash *hash_act = nla_data(attr); + u32 hash = 0; + + /* OVS_HASH_ALG_L4 is the only possible hash algorithm. */ + hash = skb_get_hash(skb); + hash = jhash_1word(hash, hash_act->hash_basis); + if (!hash) + hash = 0x1; + + key->ovs_flow_hash = hash; +} + +static int execute_set_action(struct sk_buff *skb, + struct sw_flow_key *flow_key, + const struct nlattr *a) +{ + /* Only tunnel set execution is supported without a mask. */ + if (nla_type(a) == OVS_KEY_ATTR_TUNNEL_INFO) { + struct ovs_tunnel_info *tun = nla_data(a); + + skb_dst_drop(skb); + dst_hold((struct dst_entry *)tun->tun_dst); + skb_dst_set(skb, (struct dst_entry *)tun->tun_dst); + return 0; + } + + return -EINVAL; +} + +/* Mask is at the midpoint of the data. */ +#define get_mask(a, type) ((const type)nla_data(a) + 1) + +static int execute_masked_set_action(struct sk_buff *skb, + struct sw_flow_key *flow_key, + const struct nlattr *a) +{ + int err = 0; + + switch (nla_type(a)) { + case OVS_KEY_ATTR_PRIORITY: + OVS_SET_MASKED(skb->priority, nla_get_u32(a), + *get_mask(a, u32 *)); + flow_key->phy.priority = skb->priority; + break; + + case OVS_KEY_ATTR_SKB_MARK: + OVS_SET_MASKED(skb->mark, nla_get_u32(a), *get_mask(a, u32 *)); + flow_key->phy.skb_mark = skb->mark; + break; + + case OVS_KEY_ATTR_TUNNEL_INFO: + /* Masked data not supported for tunnel. 
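+ * Tunnel metadata is set through the unmasked OVS_ACTION_ATTR_SET
+ * path in execute_set_action() instead.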
*/ + err = -EINVAL; + break; + + case OVS_KEY_ATTR_ETHERNET: + err = set_eth_addr(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_ethernet *)); + break; + + case OVS_KEY_ATTR_NSH: + err = set_nsh(skb, flow_key, a); + break; + + case OVS_KEY_ATTR_IPV4: + err = set_ipv4(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_ipv4 *)); + break; + + case OVS_KEY_ATTR_IPV6: + err = set_ipv6(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_ipv6 *)); + break; + + case OVS_KEY_ATTR_TCP: + err = set_tcp(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_tcp *)); + break; + + case OVS_KEY_ATTR_UDP: + err = set_udp(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_udp *)); + break; + + case OVS_KEY_ATTR_SCTP: + err = set_sctp(skb, flow_key, nla_data(a), + get_mask(a, struct ovs_key_sctp *)); + break; + + case OVS_KEY_ATTR_MPLS: + err = set_mpls(skb, flow_key, nla_data(a), get_mask(a, + __be32 *)); + break; + + case OVS_KEY_ATTR_CT_STATE: + case OVS_KEY_ATTR_CT_ZONE: + case OVS_KEY_ATTR_CT_MARK: + case OVS_KEY_ATTR_CT_LABELS: + case OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4: + case OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6: + err = -EINVAL; + break; + } + + return err; +} + +static int execute_recirc(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + const struct nlattr *a, bool last) +{ + u32 recirc_id; + + if (!is_flow_key_valid(key)) { + int err; + + err = ovs_flow_key_update(skb, key); + if (err) + return err; + } + BUG_ON(!is_flow_key_valid(key)); + + recirc_id = nla_get_u32(a); + return clone_execute(dp, skb, key, recirc_id, NULL, 0, last, true); +} + +static int execute_check_pkt_len(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + const struct nlattr *attr, bool last) +{ + struct ovs_skb_cb *ovs_cb = OVS_CB(skb); + const struct nlattr *actions, *cpl_arg; + int len, max_len, rem = nla_len(attr); + const struct check_pkt_len_arg *arg; + bool clone_flow_key; + + /* The first netlink attribute in 'attr' is always + * 'OVS_CHECK_PKT_LEN_ATTR_ARG'. + */ + cpl_arg = nla_data(attr); + arg = nla_data(cpl_arg); + + len = ovs_cb->mru ? ovs_cb->mru + skb->mac_len : skb->len; + max_len = arg->pkt_len; + + if ((skb_is_gso(skb) && skb_gso_validate_mac_len(skb, max_len)) || + len <= max_len) { + /* Second netlink attribute in 'attr' is always + * 'OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL'. + */ + actions = nla_next(cpl_arg, &rem); + clone_flow_key = !arg->exec_for_lesser_equal; + } else { + /* Third netlink attribute in 'attr' is always + * 'OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER'. 
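+ * so step over the 'lesser or equal' actions with a second
+ * nla_next() to reach it.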
+ */ + actions = nla_next(cpl_arg, &rem); + actions = nla_next(actions, &rem); + clone_flow_key = !arg->exec_for_greater; + } + + return clone_execute(dp, skb, key, 0, nla_data(actions), + nla_len(actions), last, clone_flow_key); +} + +static int execute_dec_ttl(struct sk_buff *skb, struct sw_flow_key *key) +{ + int err; + + if (skb->protocol == htons(ETH_P_IPV6)) { + struct ipv6hdr *nh; + + err = skb_ensure_writable(skb, skb_network_offset(skb) + + sizeof(*nh)); + if (unlikely(err)) + return err; + + nh = ipv6_hdr(skb); + + if (nh->hop_limit <= 1) + return -EHOSTUNREACH; + + key->ip.ttl = --nh->hop_limit; + } else if (skb->protocol == htons(ETH_P_IP)) { + struct iphdr *nh; + u8 old_ttl; + + err = skb_ensure_writable(skb, skb_network_offset(skb) + + sizeof(*nh)); + if (unlikely(err)) + return err; + + nh = ip_hdr(skb); + if (nh->ttl <= 1) + return -EHOSTUNREACH; + + old_ttl = nh->ttl--; + csum_replace2(&nh->check, htons(old_ttl << 8), + htons(nh->ttl << 8)); + key->ip.ttl = nh->ttl; + } + return 0; +} + +/* Execute a list of actions against 'skb'. */ +static int do_execute_actions(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, + const struct nlattr *attr, int len) +{ + const struct nlattr *a; + int rem; + + for (a = attr, rem = len; rem > 0; + a = nla_next(a, &rem)) { + int err = 0; + + if (trace_ovs_do_execute_action_enabled()) + trace_ovs_do_execute_action(dp, skb, key, a, rem); + + switch (nla_type(a)) { + case OVS_ACTION_ATTR_OUTPUT: { + int port = nla_get_u32(a); + struct sk_buff *clone; + + /* Every output action needs a separate clone + * of 'skb', In case the output action is the + * last action, cloning can be avoided. + */ + if (nla_is_last(a, rem)) { + do_output(dp, skb, port, key); + /* 'skb' has been used for output. + */ + return 0; + } + + clone = skb_clone(skb, GFP_ATOMIC); + if (clone) + do_output(dp, clone, port, key); + OVS_CB(skb)->cutlen = 0; + break; + } + + case OVS_ACTION_ATTR_TRUNC: { + struct ovs_action_trunc *trunc = nla_data(a); + + if (skb->len > trunc->max_len) + OVS_CB(skb)->cutlen = skb->len - trunc->max_len; + break; + } + + case OVS_ACTION_ATTR_USERSPACE: + output_userspace(dp, skb, key, a, attr, + len, OVS_CB(skb)->cutlen); + OVS_CB(skb)->cutlen = 0; + break; + + case OVS_ACTION_ATTR_HASH: + execute_hash(skb, key, a); + break; + + case OVS_ACTION_ATTR_PUSH_MPLS: { + struct ovs_action_push_mpls *mpls = nla_data(a); + + err = push_mpls(skb, key, mpls->mpls_lse, + mpls->mpls_ethertype, skb->mac_len); + break; + } + case OVS_ACTION_ATTR_ADD_MPLS: { + struct ovs_action_add_mpls *mpls = nla_data(a); + __u16 mac_len = 0; + + if (mpls->tun_flags & OVS_MPLS_L3_TUNNEL_FLAG_MASK) + mac_len = skb->mac_len; + + err = push_mpls(skb, key, mpls->mpls_lse, + mpls->mpls_ethertype, mac_len); + break; + } + case OVS_ACTION_ATTR_POP_MPLS: + err = pop_mpls(skb, key, nla_get_be16(a)); + break; + + case OVS_ACTION_ATTR_PUSH_VLAN: + err = push_vlan(skb, key, nla_data(a)); + break; + + case OVS_ACTION_ATTR_POP_VLAN: + err = pop_vlan(skb, key); + break; + + case OVS_ACTION_ATTR_RECIRC: { + bool last = nla_is_last(a, rem); + + err = execute_recirc(dp, skb, key, a, last); + if (last) { + /* If this is the last action, the skb has + * been consumed or freed. + * Return immediately. 
+ */ + return err; + } + break; + } + + case OVS_ACTION_ATTR_SET: + err = execute_set_action(skb, key, nla_data(a)); + break; + + case OVS_ACTION_ATTR_SET_MASKED: + case OVS_ACTION_ATTR_SET_TO_MASKED: + err = execute_masked_set_action(skb, key, nla_data(a)); + break; + + case OVS_ACTION_ATTR_SAMPLE: { + bool last = nla_is_last(a, rem); + + err = sample(dp, skb, key, a, last); + if (last) + return err; + + break; + } + + case OVS_ACTION_ATTR_CT: + if (!is_flow_key_valid(key)) { + err = ovs_flow_key_update(skb, key); + if (err) + return err; + } + + err = ovs_ct_execute(ovs_dp_get_net(dp), skb, key, + nla_data(a)); + + /* Hide stolen IP fragments from user space. */ + if (err) + return err == -EINPROGRESS ? 0 : err; + break; + + case OVS_ACTION_ATTR_CT_CLEAR: + err = ovs_ct_clear(skb, key); + break; + + case OVS_ACTION_ATTR_PUSH_ETH: + err = push_eth(skb, key, nla_data(a)); + break; + + case OVS_ACTION_ATTR_POP_ETH: + err = pop_eth(skb, key); + break; + + case OVS_ACTION_ATTR_PUSH_NSH: { + u8 buffer[NSH_HDR_MAX_LEN]; + struct nshhdr *nh = (struct nshhdr *)buffer; + + err = nsh_hdr_from_nlattr(nla_data(a), nh, + NSH_HDR_MAX_LEN); + if (unlikely(err)) + break; + err = push_nsh(skb, key, nh); + break; + } + + case OVS_ACTION_ATTR_POP_NSH: + err = pop_nsh(skb, key); + break; + + case OVS_ACTION_ATTR_METER: + if (ovs_meter_execute(dp, skb, key, nla_get_u32(a))) { + consume_skb(skb); + return 0; + } + break; + + case OVS_ACTION_ATTR_CLONE: { + bool last = nla_is_last(a, rem); + + err = clone(dp, skb, key, a, last); + if (last) + return err; + + break; + } + + case OVS_ACTION_ATTR_CHECK_PKT_LEN: { + bool last = nla_is_last(a, rem); + + err = execute_check_pkt_len(dp, skb, key, a, last); + if (last) + return err; + + break; + } + + case OVS_ACTION_ATTR_DEC_TTL: + err = execute_dec_ttl(skb, key); + if (err == -EHOSTUNREACH) + return dec_ttl_exception_handler(dp, skb, + key, a); + break; + } + + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + } + + consume_skb(skb); + return 0; +} + +/* Execute the actions on the clone of the packet. The effect of the + * execution does not affect the original 'skb' nor the original 'key'. + * + * The execution may be deferred in case the actions can not be executed + * immediately. + */ +static int clone_execute(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, u32 recirc_id, + const struct nlattr *actions, int len, + bool last, bool clone_flow_key) +{ + struct deferred_action *da; + struct sw_flow_key *clone; + + skb = last ? skb : skb_clone(skb, GFP_ATOMIC); + if (!skb) { + /* Out of memory, skip this action. + */ + return 0; + } + + /* When clone_flow_key is false, the 'key' will not be change + * by the actions, then the 'key' can be used directly. + * Otherwise, try to clone key from the next recursion level of + * 'flow_keys'. If clone is successful, execute the actions + * without deferring. + */ + clone = clone_flow_key ? clone_key(key) : key; + if (clone) { + int err = 0; + + if (actions) { /* Sample action */ + if (clone_flow_key) + __this_cpu_inc(exec_actions_level); + + err = do_execute_actions(dp, skb, clone, + actions, len); + + if (clone_flow_key) + __this_cpu_dec(exec_actions_level); + } else { /* Recirc action */ + clone->recirc_id = recirc_id; + ovs_dp_process_packet(skb, clone); + } + return err; + } + + /* Out of 'flow_keys' space. 
Defer actions */ + da = add_deferred_actions(skb, key, actions, len); + if (da) { + if (!actions) { /* Recirc action */ + key = &da->pkt_key; + key->recirc_id = recirc_id; + } + } else { + /* Out of per CPU action FIFO space. Drop the 'skb' and + * log an error. + */ + kfree_skb(skb); + + if (net_ratelimit()) { + if (actions) { /* Sample action */ + pr_warn("%s: deferred action limit reached, drop sample action\n", + ovs_dp_name(dp)); + } else { /* Recirc action */ + pr_warn("%s: deferred action limit reached, drop recirc action (recirc_id=%#x)\n", + ovs_dp_name(dp), recirc_id); + } + } + } + return 0; +} + +static void process_deferred_actions(struct datapath *dp) +{ + struct action_fifo *fifo = this_cpu_ptr(action_fifos); + + /* Do not touch the FIFO in case there is no deferred actions. */ + if (action_fifo_is_empty(fifo)) + return; + + /* Finishing executing all deferred actions. */ + do { + struct deferred_action *da = action_fifo_get(fifo); + struct sk_buff *skb = da->skb; + struct sw_flow_key *key = &da->pkt_key; + const struct nlattr *actions = da->actions; + int actions_len = da->actions_len; + + if (actions) + do_execute_actions(dp, skb, key, actions, actions_len); + else + ovs_dp_process_packet(skb, key); + } while (!action_fifo_is_empty(fifo)); + + /* Reset FIFO for the next packet. */ + action_fifo_init(fifo); +} + +/* Execute a list of actions against 'skb'. */ +int ovs_execute_actions(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_actions *acts, + struct sw_flow_key *key) +{ + int err, level; + + level = __this_cpu_inc_return(exec_actions_level); + if (unlikely(level > OVS_RECURSION_LIMIT)) { + net_crit_ratelimited("ovs: recursion limit reached on datapath %s, probable configuration error\n", + ovs_dp_name(dp)); + kfree_skb(skb); + err = -ENETDOWN; + goto out; + } + + OVS_CB(skb)->acts_origlen = acts->orig_len; + err = do_execute_actions(dp, skb, key, + acts->actions, acts->actions_len); + + if (level == 1) + process_deferred_actions(dp); + +out: + __this_cpu_dec(exec_actions_level); + return err; +} + +int action_fifos_init(void) +{ + action_fifos = alloc_percpu(struct action_fifo); + if (!action_fifos) + return -ENOMEM; + + flow_keys = alloc_percpu(struct action_flow_keys); + if (!flow_keys) { + free_percpu(action_fifos); + return -ENOMEM; + } + + return 0; +} + +void action_fifos_exit(void) +{ + free_percpu(action_fifos); + free_percpu(flow_keys); +} diff --git a/net/openvswitch/conntrack.c b/net/openvswitch/conntrack.c new file mode 100644 index 000000000..0591cfb28 --- /dev/null +++ b/net/openvswitch/conntrack.c @@ -0,0 +1,2326 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015 Nicira, Inc. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_NF_NAT) +#include +#endif + +#include + +#include "datapath.h" +#include "conntrack.h" +#include "flow.h" +#include "flow_netlink.h" + +struct ovs_ct_len_tbl { + int maxlen; + int minlen; +}; + +/* Metadata mark for masked write to conntrack mark */ +struct md_mark { + u32 value; + u32 mask; +}; + +/* Metadata label for masked write to conntrack label. */ +struct md_labels { + struct ovs_key_ct_labels value; + struct ovs_key_ct_labels mask; +}; + +enum ovs_ct_nat { + OVS_CT_NAT = 1 << 0, /* NAT for committed connections only. */ + OVS_CT_SRC_NAT = 1 << 1, /* Source NAT for NEW connections. 
*/ + OVS_CT_DST_NAT = 1 << 2, /* Destination NAT for NEW connections. */ +}; + +/* Conntrack action context for execution. */ +struct ovs_conntrack_info { + struct nf_conntrack_helper *helper; + struct nf_conntrack_zone zone; + struct nf_conn *ct; + u8 commit : 1; + u8 nat : 3; /* enum ovs_ct_nat */ + u8 force : 1; + u8 have_eventmask : 1; + u16 family; + u32 eventmask; /* Mask of 1 << IPCT_*. */ + struct md_mark mark; + struct md_labels labels; + char timeout[CTNL_TIMEOUT_NAME_MAX]; + struct nf_ct_timeout *nf_ct_timeout; +#if IS_ENABLED(CONFIG_NF_NAT) + struct nf_nat_range2 range; /* Only present for SRC NAT and DST NAT. */ +#endif +}; + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) +#define OVS_CT_LIMIT_UNLIMITED 0 +#define OVS_CT_LIMIT_DEFAULT OVS_CT_LIMIT_UNLIMITED +#define CT_LIMIT_HASH_BUCKETS 512 +static DEFINE_STATIC_KEY_FALSE(ovs_ct_limit_enabled); + +struct ovs_ct_limit { + /* Elements in ovs_ct_limit_info->limits hash table */ + struct hlist_node hlist_node; + struct rcu_head rcu; + u16 zone; + u32 limit; +}; + +struct ovs_ct_limit_info { + u32 default_limit; + struct hlist_head *limits; + struct nf_conncount_data *data; +}; + +static const struct nla_policy ct_limit_policy[OVS_CT_LIMIT_ATTR_MAX + 1] = { + [OVS_CT_LIMIT_ATTR_ZONE_LIMIT] = { .type = NLA_NESTED, }, +}; +#endif + +static bool labels_nonzero(const struct ovs_key_ct_labels *labels); + +static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info); + +static u16 key_to_nfproto(const struct sw_flow_key *key) +{ + switch (ntohs(key->eth.type)) { + case ETH_P_IP: + return NFPROTO_IPV4; + case ETH_P_IPV6: + return NFPROTO_IPV6; + default: + return NFPROTO_UNSPEC; + } +} + +/* Map SKB connection state into the values used by flow definition. */ +static u8 ovs_ct_get_state(enum ip_conntrack_info ctinfo) +{ + u8 ct_state = OVS_CS_F_TRACKED; + + switch (ctinfo) { + case IP_CT_ESTABLISHED_REPLY: + case IP_CT_RELATED_REPLY: + ct_state |= OVS_CS_F_REPLY_DIR; + break; + default: + break; + } + + switch (ctinfo) { + case IP_CT_ESTABLISHED: + case IP_CT_ESTABLISHED_REPLY: + ct_state |= OVS_CS_F_ESTABLISHED; + break; + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + ct_state |= OVS_CS_F_RELATED; + break; + case IP_CT_NEW: + ct_state |= OVS_CS_F_NEW; + break; + default: + break; + } + + return ct_state; +} + +static u32 ovs_ct_get_mark(const struct nf_conn *ct) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) + return ct ? READ_ONCE(ct->mark) : 0; +#else + return 0; +#endif +} + +/* Guard against conntrack labels max size shrinking below 128 bits. */ +#if NF_CT_LABELS_MAX_SIZE < 16 +#error NF_CT_LABELS_MAX_SIZE must be at least 16 bytes +#endif + +static void ovs_ct_get_labels(const struct nf_conn *ct, + struct ovs_key_ct_labels *labels) +{ + struct nf_conn_labels *cl = ct ? 
nf_ct_labels_find(ct) : NULL; + + if (cl) + memcpy(labels, cl->bits, OVS_CT_LABELS_LEN); + else + memset(labels, 0, OVS_CT_LABELS_LEN); +} + +static void __ovs_ct_update_key_orig_tp(struct sw_flow_key *key, + const struct nf_conntrack_tuple *orig, + u8 icmp_proto) +{ + key->ct_orig_proto = orig->dst.protonum; + if (orig->dst.protonum == icmp_proto) { + key->ct.orig_tp.src = htons(orig->dst.u.icmp.type); + key->ct.orig_tp.dst = htons(orig->dst.u.icmp.code); + } else { + key->ct.orig_tp.src = orig->src.u.all; + key->ct.orig_tp.dst = orig->dst.u.all; + } +} + +static void __ovs_ct_update_key(struct sw_flow_key *key, u8 state, + const struct nf_conntrack_zone *zone, + const struct nf_conn *ct) +{ + key->ct_state = state; + key->ct_zone = zone->id; + key->ct.mark = ovs_ct_get_mark(ct); + ovs_ct_get_labels(ct, &key->ct.labels); + + if (ct) { + const struct nf_conntrack_tuple *orig; + + /* Use the master if we have one. */ + if (ct->master) + ct = ct->master; + orig = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + + /* IP version must match with the master connection. */ + if (key->eth.type == htons(ETH_P_IP) && + nf_ct_l3num(ct) == NFPROTO_IPV4) { + key->ipv4.ct_orig.src = orig->src.u3.ip; + key->ipv4.ct_orig.dst = orig->dst.u3.ip; + __ovs_ct_update_key_orig_tp(key, orig, IPPROTO_ICMP); + return; + } else if (key->eth.type == htons(ETH_P_IPV6) && + !sw_flow_key_is_nd(key) && + nf_ct_l3num(ct) == NFPROTO_IPV6) { + key->ipv6.ct_orig.src = orig->src.u3.in6; + key->ipv6.ct_orig.dst = orig->dst.u3.in6; + __ovs_ct_update_key_orig_tp(key, orig, NEXTHDR_ICMP); + return; + } + } + /* Clear 'ct_orig_proto' to mark the non-existence of conntrack + * original direction key fields. + */ + key->ct_orig_proto = 0; +} + +/* Update 'key' based on skb->_nfct. If 'post_ct' is true, then OVS has + * previously sent the packet to conntrack via the ct action. If + * 'keep_nat_flags' is true, the existing NAT flags retained, else they are + * initialized from the connection status. + */ +static void ovs_ct_update_key(const struct sk_buff *skb, + const struct ovs_conntrack_info *info, + struct sw_flow_key *key, bool post_ct, + bool keep_nat_flags) +{ + const struct nf_conntrack_zone *zone = &nf_ct_zone_dflt; + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + u8 state = 0; + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + state = ovs_ct_get_state(ctinfo); + /* All unconfirmed entries are NEW connections. */ + if (!nf_ct_is_confirmed(ct)) + state |= OVS_CS_F_NEW; + /* OVS persists the related flag for the duration of the + * connection. + */ + if (ct->master) + state |= OVS_CS_F_RELATED; + if (keep_nat_flags) { + state |= key->ct_state & OVS_CS_F_NAT_MASK; + } else { + if (ct->status & IPS_SRC_NAT) + state |= OVS_CS_F_SRC_NAT; + if (ct->status & IPS_DST_NAT) + state |= OVS_CS_F_DST_NAT; + } + zone = nf_ct_zone(ct); + } else if (post_ct) { + state = OVS_CS_F_TRACKED | OVS_CS_F_INVALID; + if (info) + zone = &info->zone; + } + __ovs_ct_update_key(key, state, zone, ct); +} + +/* This is called to initialize CT key fields possibly coming in from the local + * stack. 
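+ * No conntrack action context is available here, so the key is
+ * rebuilt with a NULL info and the NAT flags are taken from the
+ * connection status rather than from the existing key.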
+ */ +void ovs_ct_fill_key(const struct sk_buff *skb, + struct sw_flow_key *key, + bool post_ct) +{ + ovs_ct_update_key(skb, NULL, key, post_ct, false); +} + +int ovs_ct_put_key(const struct sw_flow_key *swkey, + const struct sw_flow_key *output, struct sk_buff *skb) +{ + if (nla_put_u32(skb, OVS_KEY_ATTR_CT_STATE, output->ct_state)) + return -EMSGSIZE; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) && + nla_put_u16(skb, OVS_KEY_ATTR_CT_ZONE, output->ct_zone)) + return -EMSGSIZE; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && + nla_put_u32(skb, OVS_KEY_ATTR_CT_MARK, output->ct.mark)) + return -EMSGSIZE; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + nla_put(skb, OVS_KEY_ATTR_CT_LABELS, sizeof(output->ct.labels), + &output->ct.labels)) + return -EMSGSIZE; + + if (swkey->ct_orig_proto) { + if (swkey->eth.type == htons(ETH_P_IP)) { + struct ovs_key_ct_tuple_ipv4 orig; + + memset(&orig, 0, sizeof(orig)); + orig.ipv4_src = output->ipv4.ct_orig.src; + orig.ipv4_dst = output->ipv4.ct_orig.dst; + orig.src_port = output->ct.orig_tp.src; + orig.dst_port = output->ct.orig_tp.dst; + orig.ipv4_proto = output->ct_orig_proto; + + if (nla_put(skb, OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4, + sizeof(orig), &orig)) + return -EMSGSIZE; + } else if (swkey->eth.type == htons(ETH_P_IPV6)) { + struct ovs_key_ct_tuple_ipv6 orig; + + memset(&orig, 0, sizeof(orig)); + memcpy(orig.ipv6_src, output->ipv6.ct_orig.src.s6_addr32, + sizeof(orig.ipv6_src)); + memcpy(orig.ipv6_dst, output->ipv6.ct_orig.dst.s6_addr32, + sizeof(orig.ipv6_dst)); + orig.src_port = output->ct.orig_tp.src; + orig.dst_port = output->ct.orig_tp.dst; + orig.ipv6_proto = output->ct_orig_proto; + + if (nla_put(skb, OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6, + sizeof(orig), &orig)) + return -EMSGSIZE; + } + } + + return 0; +} + +static int ovs_ct_set_mark(struct nf_conn *ct, struct sw_flow_key *key, + u32 ct_mark, u32 mask) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) + u32 new_mark; + + new_mark = ct_mark | (READ_ONCE(ct->mark) & ~(mask)); + if (READ_ONCE(ct->mark) != new_mark) { + WRITE_ONCE(ct->mark, new_mark); + if (nf_ct_is_confirmed(ct)) + nf_conntrack_event_cache(IPCT_MARK, ct); + key->ct.mark = new_mark; + } + + return 0; +#else + return -ENOTSUPP; +#endif +} + +static struct nf_conn_labels *ovs_ct_get_conn_labels(struct nf_conn *ct) +{ + struct nf_conn_labels *cl; + + cl = nf_ct_labels_find(ct); + if (!cl) { + nf_ct_labels_ext_add(ct); + cl = nf_ct_labels_find(ct); + } + + return cl; +} + +/* Initialize labels for a new, yet to be committed conntrack entry. Note that + * since the new connection is not yet confirmed, and thus no-one else has + * access to it's labels, we simply write them over. + */ +static int ovs_ct_init_labels(struct nf_conn *ct, struct sw_flow_key *key, + const struct ovs_key_ct_labels *labels, + const struct ovs_key_ct_labels *mask) +{ + struct nf_conn_labels *cl, *master_cl; + bool have_mask = labels_nonzero(mask); + + /* Inherit master's labels to the related connection? */ + master_cl = ct->master ? nf_ct_labels_find(ct->master) : NULL; + + if (!master_cl && !have_mask) + return 0; /* Nothing to do. */ + + cl = ovs_ct_get_conn_labels(ct); + if (!cl) + return -ENOSPC; + + /* Inherit the master's labels, if any. 
*/ + if (master_cl) + *cl = *master_cl; + + if (have_mask) { + u32 *dst = (u32 *)cl->bits; + int i; + + for (i = 0; i < OVS_CT_LABELS_LEN_32; i++) + dst[i] = (dst[i] & ~mask->ct_labels_32[i]) | + (labels->ct_labels_32[i] + & mask->ct_labels_32[i]); + } + + /* Labels are included in the IPCTNL_MSG_CT_NEW event only if the + * IPCT_LABEL bit is set in the event cache. + */ + nf_conntrack_event_cache(IPCT_LABEL, ct); + + memcpy(&key->ct.labels, cl->bits, OVS_CT_LABELS_LEN); + + return 0; +} + +static int ovs_ct_set_labels(struct nf_conn *ct, struct sw_flow_key *key, + const struct ovs_key_ct_labels *labels, + const struct ovs_key_ct_labels *mask) +{ + struct nf_conn_labels *cl; + int err; + + cl = ovs_ct_get_conn_labels(ct); + if (!cl) + return -ENOSPC; + + err = nf_connlabels_replace(ct, labels->ct_labels_32, + mask->ct_labels_32, + OVS_CT_LABELS_LEN_32); + if (err) + return err; + + memcpy(&key->ct.labels, cl->bits, OVS_CT_LABELS_LEN); + + return 0; +} + +/* 'skb' should already be pulled to nh_ofs. */ +static int ovs_ct_helper(struct sk_buff *skb, u16 proto) +{ + const struct nf_conntrack_helper *helper; + const struct nf_conn_help *help; + enum ip_conntrack_info ctinfo; + unsigned int protoff; + struct nf_conn *ct; + int err; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct || ctinfo == IP_CT_RELATED_REPLY) + return NF_ACCEPT; + + help = nfct_help(ct); + if (!help) + return NF_ACCEPT; + + helper = rcu_dereference(help->helper); + if (!helper) + return NF_ACCEPT; + + switch (proto) { + case NFPROTO_IPV4: + protoff = ip_hdrlen(skb); + break; + case NFPROTO_IPV6: { + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + __be16 frag_off; + int ofs; + + ofs = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr, + &frag_off); + if (ofs < 0 || (frag_off & htons(~0x7)) != 0) { + pr_debug("proto header not found\n"); + return NF_ACCEPT; + } + protoff = ofs; + break; + } + default: + WARN_ONCE(1, "helper invoked on non-IP family!"); + return NF_DROP; + } + + err = helper->help(skb, protoff, ct, ctinfo); + if (err != NF_ACCEPT) + return err; + + /* Adjust seqs after helper. This is needed due to some helpers (e.g., + * FTP with NAT) adusting the TCP payload size when mangling IP + * addresses and/or port numbers in the text-based control connection. + */ + if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) && + !nf_ct_seq_adjust(skb, ct, ctinfo, protoff)) + return NF_DROP; + return NF_ACCEPT; +} + +/* Returns 0 on success, -EINPROGRESS if 'skb' is stolen, or other nonzero + * value if 'skb' is freed. + */ +static int handle_fragments(struct net *net, struct sw_flow_key *key, + u16 zone, struct sk_buff *skb) +{ + struct ovs_skb_cb ovs_cb = *OVS_CB(skb); + int err; + + if (key->eth.type == htons(ETH_P_IP)) { + enum ip_defrag_users user = IP_DEFRAG_CONNTRACK_IN + zone; + + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + err = ip_defrag(net, skb, user); + if (err) + return err; + + ovs_cb.mru = IPCB(skb)->frag_max_size; +#if IS_ENABLED(CONFIG_NF_DEFRAG_IPV6) + } else if (key->eth.type == htons(ETH_P_IPV6)) { + enum ip6_defrag_users user = IP6_DEFRAG_CONNTRACK_IN + zone; + + memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm)); + err = nf_ct_frag6_gather(net, skb, user); + if (err) { + if (err != -EINPROGRESS) + kfree_skb(skb); + return err; + } + + key->ip.proto = ipv6_hdr(skb)->nexthdr; + ovs_cb.mru = IP6CB(skb)->frag_max_size; +#endif + } else { + kfree_skb(skb); + return -EPFNOSUPPORT; + } + + /* The key extracted from the fragment that completed this datagram + * likely didn't have an L4 header, so regenerate it. 
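+ * Only the L3/L4 portions need refreshing; the L2 part of the key
+ * is unaffected by reassembly.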
+ */ + ovs_flow_key_update_l3l4(skb, key); + + key->ip.frag = OVS_FRAG_TYPE_NONE; + skb_clear_hash(skb); + skb->ignore_df = 1; + *OVS_CB(skb) = ovs_cb; + + return 0; +} + +static struct nf_conntrack_expect * +ovs_ct_expect_find(struct net *net, const struct nf_conntrack_zone *zone, + u16 proto, const struct sk_buff *skb) +{ + struct nf_conntrack_tuple tuple; + struct nf_conntrack_expect *exp; + + if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), proto, net, &tuple)) + return NULL; + + exp = __nf_ct_expect_find(net, zone, &tuple); + if (exp) { + struct nf_conntrack_tuple_hash *h; + + /* Delete existing conntrack entry, if it clashes with the + * expectation. This can happen since conntrack ALGs do not + * check for clashes between (new) expectations and existing + * conntrack entries. nf_conntrack_in() will check the + * expectations only if a conntrack entry can not be found, + * which can lead to OVS finding the expectation (here) in the + * init direction, but which will not be removed by the + * nf_conntrack_in() call, if a matching conntrack entry is + * found instead. In this case all init direction packets + * would be reported as new related packets, while reply + * direction packets would be reported as un-related + * established packets. + */ + h = nf_conntrack_find_get(net, zone, &tuple); + if (h) { + struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + nf_ct_delete(ct, 0, 0); + nf_ct_put(ct); + } + } + + return exp; +} + +/* This replicates logic from nf_conntrack_core.c that is not exported. */ +static enum ip_conntrack_info +ovs_ct_get_info(const struct nf_conntrack_tuple_hash *h) +{ + const struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h); + + if (NF_CT_DIRECTION(h) == IP_CT_DIR_REPLY) + return IP_CT_ESTABLISHED_REPLY; + /* Once we've had two way comms, always ESTABLISHED. */ + if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status)) + return IP_CT_ESTABLISHED; + if (test_bit(IPS_EXPECTED_BIT, &ct->status)) + return IP_CT_RELATED; + return IP_CT_NEW; +} + +/* Find an existing connection which this packet belongs to without + * re-attributing statistics or modifying the connection state. This allows an + * skb->_nfct lost due to an upcall to be recovered during actions execution. + * + * Must be called with rcu_read_lock. + * + * On success, populates skb->_nfct and returns the connection. Returns NULL + * if there is no existing entry. + */ +static struct nf_conn * +ovs_ct_find_existing(struct net *net, const struct nf_conntrack_zone *zone, + u8 l3num, struct sk_buff *skb, bool natted) +{ + struct nf_conntrack_tuple tuple; + struct nf_conntrack_tuple_hash *h; + struct nf_conn *ct; + + if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), l3num, + net, &tuple)) { + pr_debug("ovs_ct_find_existing: Can't get tuple\n"); + return NULL; + } + + /* Must invert the tuple if skb has been transformed by NAT. */ + if (natted) { + struct nf_conntrack_tuple inverse; + + if (!nf_ct_invert_tuple(&inverse, &tuple)) { + pr_debug("ovs_ct_find_existing: Inversion failed!\n"); + return NULL; + } + tuple = inverse; + } + + /* look for tuple match */ + h = nf_conntrack_find_get(net, zone, &tuple); + if (!h) + return NULL; /* Not found. */ + + ct = nf_ct_tuplehash_to_ctrack(h); + + /* Inverted packet tuple matches the reverse direction conntrack tuple, + * select the other tuplehash to get the right 'ctinfo' bits for this + * packet. 
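+ * Both tuplehashes point at the same nf_conn, so flipping the
+ * direction only changes the ctinfo computed by ovs_ct_get_info().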
+ */ + if (natted) + h = &ct->tuplehash[!h->tuple.dst.dir]; + + nf_ct_set(skb, ct, ovs_ct_get_info(h)); + return ct; +} + +static +struct nf_conn *ovs_ct_executed(struct net *net, + const struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb, + bool *ct_executed) +{ + struct nf_conn *ct = NULL; + + /* If no ct, check if we have evidence that an existing conntrack entry + * might be found for this skb. This happens when we lose a skb->_nfct + * due to an upcall, or if the direction is being forced. If the + * connection was not confirmed, it is not cached and needs to be run + * through conntrack again. + */ + *ct_executed = (key->ct_state & OVS_CS_F_TRACKED) && + !(key->ct_state & OVS_CS_F_INVALID) && + (key->ct_zone == info->zone.id); + + if (*ct_executed || (!key->ct_state && info->force)) { + ct = ovs_ct_find_existing(net, &info->zone, info->family, skb, + !!(key->ct_state & + OVS_CS_F_NAT_MASK)); + } + + return ct; +} + +/* Determine whether skb->_nfct is equal to the result of conntrack lookup. */ +static bool skb_nfct_cached(struct net *net, + const struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + bool ct_executed = true; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + ct = ovs_ct_executed(net, key, info, skb, &ct_executed); + + if (ct) + nf_ct_get(skb, &ctinfo); + else + return false; + + if (!net_eq(net, read_pnet(&ct->ct_net))) + return false; + if (!nf_ct_zone_equal_any(info->ct, nf_ct_zone(ct))) + return false; + if (info->helper) { + struct nf_conn_help *help; + + help = nf_ct_ext_find(ct, NF_CT_EXT_HELPER); + if (help && rcu_access_pointer(help->helper) != info->helper) + return false; + } + if (info->nf_ct_timeout) { + struct nf_conn_timeout *timeout_ext; + + timeout_ext = nf_ct_timeout_find(ct); + if (!timeout_ext || info->nf_ct_timeout != + rcu_dereference(timeout_ext->timeout)) + return false; + } + /* Force conntrack entry direction to the current packet? */ + if (info->force && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { + /* Delete the conntrack entry if confirmed, else just release + * the reference. 
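+ * Returning false below makes the caller run the packet through
+ * nf_conntrack_in() again, so the entry is re-created with this
+ * packet as the original direction.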
+ */ + if (nf_ct_is_confirmed(ct)) + nf_ct_delete(ct, 0, 0); + + nf_ct_put(ct); + nf_ct_set(skb, NULL, 0); + return false; + } + + return ct_executed; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static void ovs_nat_update_key(struct sw_flow_key *key, + const struct sk_buff *skb, + enum nf_nat_manip_type maniptype) +{ + if (maniptype == NF_NAT_MANIP_SRC) { + __be16 src; + + key->ct_state |= OVS_CS_F_SRC_NAT; + if (key->eth.type == htons(ETH_P_IP)) + key->ipv4.addr.src = ip_hdr(skb)->saddr; + else if (key->eth.type == htons(ETH_P_IPV6)) + memcpy(&key->ipv6.addr.src, &ipv6_hdr(skb)->saddr, + sizeof(key->ipv6.addr.src)); + else + return; + + if (key->ip.proto == IPPROTO_UDP) + src = udp_hdr(skb)->source; + else if (key->ip.proto == IPPROTO_TCP) + src = tcp_hdr(skb)->source; + else if (key->ip.proto == IPPROTO_SCTP) + src = sctp_hdr(skb)->source; + else + return; + + key->tp.src = src; + } else { + __be16 dst; + + key->ct_state |= OVS_CS_F_DST_NAT; + if (key->eth.type == htons(ETH_P_IP)) + key->ipv4.addr.dst = ip_hdr(skb)->daddr; + else if (key->eth.type == htons(ETH_P_IPV6)) + memcpy(&key->ipv6.addr.dst, &ipv6_hdr(skb)->daddr, + sizeof(key->ipv6.addr.dst)); + else + return; + + if (key->ip.proto == IPPROTO_UDP) + dst = udp_hdr(skb)->dest; + else if (key->ip.proto == IPPROTO_TCP) + dst = tcp_hdr(skb)->dest; + else if (key->ip.proto == IPPROTO_SCTP) + dst = sctp_hdr(skb)->dest; + else + return; + + key->tp.dst = dst; + } +} + +/* Modelled after nf_nat_ipv[46]_fn(). + * range is only used for new, uninitialized NAT state. + * Returns either NF_ACCEPT or NF_DROP. + */ +static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + const struct nf_nat_range2 *range, + enum nf_nat_manip_type maniptype, struct sw_flow_key *key) +{ + int hooknum, nh_off, err = NF_ACCEPT; + + nh_off = skb_network_offset(skb); + skb_pull_rcsum(skb, nh_off); + + /* See HOOK2MANIP(). */ + if (maniptype == NF_NAT_MANIP_SRC) + hooknum = NF_INET_LOCAL_IN; /* Source NAT */ + else + hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */ + + switch (ctinfo) { + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + if (IS_ENABLED(CONFIG_NF_NAT) && + skb->protocol == htons(ETH_P_IP) && + ip_hdr(skb)->protocol == IPPROTO_ICMP) { + if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, + hooknum)) + err = NF_DROP; + goto push; + } else if (IS_ENABLED(CONFIG_IPV6) && + skb->protocol == htons(ETH_P_IPV6)) { + __be16 frag_off; + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + int hdrlen = ipv6_skip_exthdr(skb, + sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + + if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { + if (!nf_nat_icmpv6_reply_translation(skb, ct, + ctinfo, + hooknum, + hdrlen)) + err = NF_DROP; + goto push; + } + } + /* Non-ICMP, fall thru to initialize if needed. */ + fallthrough; + case IP_CT_NEW: + /* Seen it before? This can happen for loopback, retrans, + * or local packets. + */ + if (!nf_nat_initialized(ct, maniptype)) { + /* Initialize according to the NAT action. */ + err = (range && range->flags & NF_NAT_RANGE_MAP_IPS) + /* Action is set up to establish a new + * mapping. + */ + ? nf_nat_setup_info(ct, range, maniptype) + : nf_nat_alloc_null_binding(ct, hooknum); + if (err != NF_ACCEPT) + goto push; + } + break; + + case IP_CT_ESTABLISHED: + case IP_CT_ESTABLISHED_REPLY: + break; + + default: + err = NF_DROP; + goto push; + } + + err = nf_nat_packet(ct, ctinfo, hooknum, skb); +push: + skb_push_rcsum(skb, nh_off); + + /* Update the flow key if NAT successful. 
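+ * Later actions and any upcall then see the translated addresses
+ * and ports instead of the original ones.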
*/ + if (err == NF_ACCEPT) + ovs_nat_update_key(key, skb, maniptype); + + return err; +} + +/* Returns NF_DROP if the packet should be dropped, NF_ACCEPT otherwise. */ +static int ovs_ct_nat(struct net *net, struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + enum nf_nat_manip_type maniptype; + int err; + + /* Add NAT extension if not confirmed yet. */ + if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct)) + return NF_ACCEPT; /* Can't NAT. */ + + /* Determine NAT type. + * Check if the NAT type can be deduced from the tracked connection. + * Make sure new expected connections (IP_CT_RELATED) are NATted only + * when committing. + */ + if (info->nat & OVS_CT_NAT && ctinfo != IP_CT_NEW && + ct->status & IPS_NAT_MASK && + (ctinfo != IP_CT_RELATED || info->commit)) { + /* NAT an established or related connection like before. */ + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) + /* This is the REPLY direction for a connection + * for which NAT was applied in the forward + * direction. Do the reverse NAT. + */ + maniptype = ct->status & IPS_SRC_NAT + ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC; + else + maniptype = ct->status & IPS_SRC_NAT + ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST; + } else if (info->nat & OVS_CT_SRC_NAT) { + maniptype = NF_NAT_MANIP_SRC; + } else if (info->nat & OVS_CT_DST_NAT) { + maniptype = NF_NAT_MANIP_DST; + } else { + return NF_ACCEPT; /* Connection is not NATed. */ + } + err = ovs_ct_nat_execute(skb, ct, ctinfo, &info->range, maniptype, key); + + if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) { + if (ct->status & IPS_SRC_NAT) { + if (maniptype == NF_NAT_MANIP_SRC) + maniptype = NF_NAT_MANIP_DST; + else + maniptype = NF_NAT_MANIP_SRC; + + err = ovs_ct_nat_execute(skb, ct, ctinfo, &info->range, + maniptype, key); + } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { + err = ovs_ct_nat_execute(skb, ct, ctinfo, NULL, + NF_NAT_MANIP_SRC, key); + } + } + + return err; +} +#else /* !CONFIG_NF_NAT */ +static int ovs_ct_nat(struct net *net, struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + return NF_ACCEPT; +} +#endif + +/* Pass 'skb' through conntrack in 'net', using zone configured in 'info', if + * not done already. Update key with new CT state after passing the packet + * through conntrack. + * Note that if the packet is deemed invalid by conntrack, skb->_nfct will be + * set to NULL and 0 will be returned. + */ +static int __ovs_ct_lookup(struct net *net, struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + /* If we are recirculating packets to match on conntrack fields and + * committing with a separate conntrack action, then we don't need to + * actually run the packet through conntrack twice unless it's for a + * different zone. + */ + bool cached = skb_nfct_cached(net, key, info, skb); + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + if (!cached) { + struct nf_hook_state state = { + .hook = NF_INET_PRE_ROUTING, + .pf = info->family, + .net = net, + }; + struct nf_conn *tmpl = info->ct; + int err; + + /* Associate skb with specified zone. 
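+ * The template allocated in ovs_ct_copy_action() carries the
+ * configured zone (and any timeout policy), so attaching it here
+ * makes nf_conntrack_in() look up or create the entry in that zone
+ * rather than in the default zone.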
*/ + if (tmpl) { + ct = nf_ct_get(skb, &ctinfo); + nf_ct_put(ct); + nf_conntrack_get(&tmpl->ct_general); + nf_ct_set(skb, tmpl, IP_CT_NEW); + } + + err = nf_conntrack_in(skb, &state); + if (err != NF_ACCEPT) + return -ENOENT; + + /* Clear CT state NAT flags to mark that we have not yet done + * NAT after the nf_conntrack_in() call. We can actually clear + * the whole state, as it will be re-initialized below. + */ + key->ct_state = 0; + + /* Update the key, but keep the NAT flags. */ + ovs_ct_update_key(skb, info, key, true, true); + } + + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + bool add_helper = false; + + /* Packets starting a new connection must be NATted before the + * helper, so that the helper knows about the NAT. We enforce + * this by delaying both NAT and helper calls for unconfirmed + * connections until the committing CT action. For later + * packets NAT and Helper may be called in either order. + * + * NAT will be done only if the CT action has NAT, and only + * once per packet (per zone), as guarded by the NAT bits in + * the key->ct_state. + */ + if (info->nat && !(key->ct_state & OVS_CS_F_NAT_MASK) && + (nf_ct_is_confirmed(ct) || info->commit) && + ovs_ct_nat(net, key, info, skb, ct, ctinfo) != NF_ACCEPT) { + return -EINVAL; + } + + /* Userspace may decide to perform a ct lookup without a helper + * specified followed by a (recirculate and) commit with one, + * or attach a helper in a later commit. Therefore, for + * connections which we will commit, we may need to attach + * the helper here. + */ + if (!nf_ct_is_confirmed(ct) && info->commit && + info->helper && !nfct_help(ct)) { + int err = __nf_ct_try_assign_helper(ct, info->ct, + GFP_ATOMIC); + if (err) + return err; + add_helper = true; + + /* helper installed, add seqadj if NAT is required */ + if (info->nat && !nfct_seqadj(ct)) { + if (!nfct_seqadj_ext_add(ct)) + return -EINVAL; + } + } + + /* Call the helper only if: + * - nf_conntrack_in() was executed above ("!cached") or a + * helper was just attached ("add_helper") for a confirmed + * connection, or + * - When committing an unconfirmed connection. + */ + if ((nf_ct_is_confirmed(ct) ? !cached || add_helper : + info->commit) && + ovs_ct_helper(skb, info->family) != NF_ACCEPT) { + return -EINVAL; + } + + if (nf_ct_protonum(ct) == IPPROTO_TCP && + nf_ct_is_confirmed(ct) && nf_conntrack_tcp_established(ct)) { + /* Be liberal for tcp packets so that out-of-window + * packets are not marked invalid. + */ + nf_ct_set_tcp_be_liberal(ct); + } + + nf_conn_act_ct_ext_fill(skb, ct, ctinfo); + } + + return 0; +} + +/* Lookup connection and read fields into key. */ +static int ovs_ct_lookup(struct net *net, struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + struct nf_conntrack_expect *exp; + + /* If we pass an expected packet through nf_conntrack_in() the + * expectation is typically removed, but the packet could still be + * lost in upcall processing. To prevent this from happening we + * perform an explicit expectation lookup. Expected connections are + * always new, and will be passed through conntrack only when they are + * committed, as it is OK to remove the expectation at that time. + */ + exp = ovs_ct_expect_find(net, &info->zone, info->family, skb); + if (exp) { + u8 state; + + /* NOTE: New connections are NATted and Helped only when + * committed, so we are not calling into NAT here. 
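+ * Instead, the key is filled in from the expectation's master
+ * connection and flagged new + related, so a later commit still runs
+ * the packet through conntrack proper.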
+ */ + state = OVS_CS_F_TRACKED | OVS_CS_F_NEW | OVS_CS_F_RELATED; + __ovs_ct_update_key(key, state, &info->zone, exp->master); + } else { + struct nf_conn *ct; + int err; + + err = __ovs_ct_lookup(net, key, info, skb); + if (err) + return err; + + ct = (struct nf_conn *)skb_nfct(skb); + if (ct) + nf_ct_deliver_cached_events(ct); + } + + return 0; +} + +static bool labels_nonzero(const struct ovs_key_ct_labels *labels) +{ + size_t i; + + for (i = 0; i < OVS_CT_LABELS_LEN_32; i++) + if (labels->ct_labels_32[i]) + return true; + + return false; +} + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) +static struct hlist_head *ct_limit_hash_bucket( + const struct ovs_ct_limit_info *info, u16 zone) +{ + return &info->limits[zone & (CT_LIMIT_HASH_BUCKETS - 1)]; +} + +/* Call with ovs_mutex */ +static void ct_limit_set(const struct ovs_ct_limit_info *info, + struct ovs_ct_limit *new_ct_limit) +{ + struct ovs_ct_limit *ct_limit; + struct hlist_head *head; + + head = ct_limit_hash_bucket(info, new_ct_limit->zone); + hlist_for_each_entry_rcu(ct_limit, head, hlist_node) { + if (ct_limit->zone == new_ct_limit->zone) { + hlist_replace_rcu(&ct_limit->hlist_node, + &new_ct_limit->hlist_node); + kfree_rcu(ct_limit, rcu); + return; + } + } + + hlist_add_head_rcu(&new_ct_limit->hlist_node, head); +} + +/* Call with ovs_mutex */ +static void ct_limit_del(const struct ovs_ct_limit_info *info, u16 zone) +{ + struct ovs_ct_limit *ct_limit; + struct hlist_head *head; + struct hlist_node *n; + + head = ct_limit_hash_bucket(info, zone); + hlist_for_each_entry_safe(ct_limit, n, head, hlist_node) { + if (ct_limit->zone == zone) { + hlist_del_rcu(&ct_limit->hlist_node); + kfree_rcu(ct_limit, rcu); + return; + } + } +} + +/* Call with RCU read lock */ +static u32 ct_limit_get(const struct ovs_ct_limit_info *info, u16 zone) +{ + struct ovs_ct_limit *ct_limit; + struct hlist_head *head; + + head = ct_limit_hash_bucket(info, zone); + hlist_for_each_entry_rcu(ct_limit, head, hlist_node) { + if (ct_limit->zone == zone) + return ct_limit->limit; + } + + return info->default_limit; +} + +static int ovs_ct_check_limit(struct net *net, + const struct ovs_conntrack_info *info, + const struct nf_conntrack_tuple *tuple) +{ + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + const struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info; + u32 per_zone_limit, connections; + u32 conncount_key; + + conncount_key = info->zone.id; + + per_zone_limit = ct_limit_get(ct_limit_info, info->zone.id); + if (per_zone_limit == OVS_CT_LIMIT_UNLIMITED) + return 0; + + connections = nf_conncount_count(net, ct_limit_info->data, + &conncount_key, tuple, &info->zone); + if (connections > per_zone_limit) + return -ENOMEM; + + return 0; +} +#endif + +/* Lookup connection and confirm if unconfirmed. 
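+ * For reference, a commit request arrives as a nested
+ * OVS_ACTION_ATTR_CT netlink attribute, parsed by parse_ct() below;
+ * a minimal sketch of the attributes that lead here (values are just
+ * examples):
+ *
+ *	OVS_ACTION_ATTR_CT
+ *	    OVS_CT_ATTR_COMMIT                (flag, sets info->commit)
+ *	    OVS_CT_ATTR_ZONE   = 5            (u16, sets info->zone.id)
+ *	    OVS_CT_ATTR_MARK   = value/mask   (struct md_mark)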
*/ +static int ovs_ct_commit(struct net *net, struct sw_flow_key *key, + const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + int err; + + err = __ovs_ct_lookup(net, key, info, skb); + if (err) + return err; + + /* The connection could be invalid, in which case this is a no-op.*/ + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return 0; + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) + if (static_branch_unlikely(&ovs_ct_limit_enabled)) { + if (!nf_ct_is_confirmed(ct)) { + err = ovs_ct_check_limit(net, info, + &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple); + if (err) { + net_warn_ratelimited("openvswitch: zone: %u " + "exceeds conntrack limit\n", + info->zone.id); + return err; + } + } + } +#endif + + /* Set the conntrack event mask if given. NEW and DELETE events have + * their own groups, but the NFNLGRP_CONNTRACK_UPDATE group listener + * typically would receive many kinds of updates. Setting the event + * mask allows those events to be filtered. The set event mask will + * remain in effect for the lifetime of the connection unless changed + * by a further CT action with both the commit flag and the eventmask + * option. */ + if (info->have_eventmask) { + struct nf_conntrack_ecache *cache = nf_ct_ecache_find(ct); + + if (cache) + cache->ctmask = info->eventmask; + } + + /* Apply changes before confirming the connection so that the initial + * conntrack NEW netlink event carries the values given in the CT + * action. + */ + if (info->mark.mask) { + err = ovs_ct_set_mark(ct, key, info->mark.value, + info->mark.mask); + if (err) + return err; + } + if (!nf_ct_is_confirmed(ct)) { + err = ovs_ct_init_labels(ct, key, &info->labels.value, + &info->labels.mask); + if (err) + return err; + + nf_conn_act_ct_ext_add(skb, ct, ctinfo); + } else if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + labels_nonzero(&info->labels.mask)) { + err = ovs_ct_set_labels(ct, key, &info->labels.value, + &info->labels.mask); + if (err) + return err; + } + /* This will take care of sending queued events even if the connection + * is already confirmed. + */ + if (nf_conntrack_confirm(skb) != NF_ACCEPT) + return -EINVAL; + + return 0; +} + +/* Trim the skb to the length specified by the IP/IPv6 header, + * removing any trailing lower-layer padding. This prepares the skb + * for higher-layer processing that assumes skb->len excludes padding + * (such as nf_ip_checksum). The caller needs to pull the skb to the + * network header, and ensure ip_hdr/ipv6_hdr points to valid data. + */ +static int ovs_skb_network_trim(struct sk_buff *skb) +{ + unsigned int len; + int err; + + switch (skb->protocol) { + case htons(ETH_P_IP): + len = ntohs(ip_hdr(skb)->tot_len); + break; + case htons(ETH_P_IPV6): + len = sizeof(struct ipv6hdr) + + ntohs(ipv6_hdr(skb)->payload_len); + break; + default: + len = skb->len; + } + + err = pskb_trim_rcsum(skb, len); + if (err) + kfree_skb(skb); + + return err; +} + +/* Returns 0 on success, -EINPROGRESS if 'skb' is stolen, or other nonzero + * value if 'skb' is freed. + */ +int ovs_ct_execute(struct net *net, struct sk_buff *skb, + struct sw_flow_key *key, + const struct ovs_conntrack_info *info) +{ + int nh_ofs; + int err; + + /* The conntrack module expects to be working at L3. 
*/ + nh_ofs = skb_network_offset(skb); + skb_pull_rcsum(skb, nh_ofs); + + err = ovs_skb_network_trim(skb); + if (err) + return err; + + if (key->ip.frag != OVS_FRAG_TYPE_NONE) { + err = handle_fragments(net, key, info->zone.id, skb); + if (err) + return err; + } + + if (info->commit) + err = ovs_ct_commit(net, key, info, skb); + else + err = ovs_ct_lookup(net, key, info, skb); + + skb_push_rcsum(skb, nh_ofs); + if (err) + kfree_skb(skb); + return err; +} + +int ovs_ct_clear(struct sk_buff *skb, struct sw_flow_key *key) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + + nf_ct_put(ct); + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + + if (key) + ovs_ct_fill_key(skb, key, false); + + return 0; +} + +static int ovs_ct_add_helper(struct ovs_conntrack_info *info, const char *name, + const struct sw_flow_key *key, bool log) +{ + struct nf_conntrack_helper *helper; + struct nf_conn_help *help; + int ret = 0; + + helper = nf_conntrack_helper_try_module_get(name, info->family, + key->ip.proto); + if (!helper) { + OVS_NLERR(log, "Unknown helper \"%s\"", name); + return -EINVAL; + } + + help = nf_ct_helper_ext_add(info->ct, GFP_KERNEL); + if (!help) { + nf_conntrack_helper_put(helper); + return -ENOMEM; + } + +#if IS_ENABLED(CONFIG_NF_NAT) + if (info->nat) { + ret = nf_nat_helper_try_module_get(name, info->family, + key->ip.proto); + if (ret) { + nf_conntrack_helper_put(helper); + OVS_NLERR(log, "Failed to load \"%s\" NAT helper, error: %d", + name, ret); + return ret; + } + } +#endif + rcu_assign_pointer(help->helper, helper); + info->helper = helper; + return ret; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static int parse_nat(const struct nlattr *attr, + struct ovs_conntrack_info *info, bool log) +{ + struct nlattr *a; + int rem; + bool have_ip_max = false; + bool have_proto_max = false; + bool ip_vers = (info->family == NFPROTO_IPV6); + + nla_for_each_nested(a, attr, rem) { + static const int ovs_nat_attr_lens[OVS_NAT_ATTR_MAX + 1][2] = { + [OVS_NAT_ATTR_SRC] = {0, 0}, + [OVS_NAT_ATTR_DST] = {0, 0}, + [OVS_NAT_ATTR_IP_MIN] = {sizeof(struct in_addr), + sizeof(struct in6_addr)}, + [OVS_NAT_ATTR_IP_MAX] = {sizeof(struct in_addr), + sizeof(struct in6_addr)}, + [OVS_NAT_ATTR_PROTO_MIN] = {sizeof(u16), sizeof(u16)}, + [OVS_NAT_ATTR_PROTO_MAX] = {sizeof(u16), sizeof(u16)}, + [OVS_NAT_ATTR_PERSISTENT] = {0, 0}, + [OVS_NAT_ATTR_PROTO_HASH] = {0, 0}, + [OVS_NAT_ATTR_PROTO_RANDOM] = {0, 0}, + }; + int type = nla_type(a); + + if (type > OVS_NAT_ATTR_MAX) { + OVS_NLERR(log, "Unknown NAT attribute (type=%d, max=%d)", + type, OVS_NAT_ATTR_MAX); + return -EINVAL; + } + + if (nla_len(a) != ovs_nat_attr_lens[type][ip_vers]) { + OVS_NLERR(log, "NAT attribute type %d has unexpected length (%d != %d)", + type, nla_len(a), + ovs_nat_attr_lens[type][ip_vers]); + return -EINVAL; + } + + switch (type) { + case OVS_NAT_ATTR_SRC: + case OVS_NAT_ATTR_DST: + if (info->nat) { + OVS_NLERR(log, "Only one type of NAT may be specified"); + return -ERANGE; + } + info->nat |= OVS_CT_NAT; + info->nat |= ((type == OVS_NAT_ATTR_SRC) + ? 
OVS_CT_SRC_NAT : OVS_CT_DST_NAT); + break; + + case OVS_NAT_ATTR_IP_MIN: + nla_memcpy(&info->range.min_addr, a, + sizeof(info->range.min_addr)); + info->range.flags |= NF_NAT_RANGE_MAP_IPS; + break; + + case OVS_NAT_ATTR_IP_MAX: + have_ip_max = true; + nla_memcpy(&info->range.max_addr, a, + sizeof(info->range.max_addr)); + info->range.flags |= NF_NAT_RANGE_MAP_IPS; + break; + + case OVS_NAT_ATTR_PROTO_MIN: + info->range.min_proto.all = htons(nla_get_u16(a)); + info->range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + break; + + case OVS_NAT_ATTR_PROTO_MAX: + have_proto_max = true; + info->range.max_proto.all = htons(nla_get_u16(a)); + info->range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + break; + + case OVS_NAT_ATTR_PERSISTENT: + info->range.flags |= NF_NAT_RANGE_PERSISTENT; + break; + + case OVS_NAT_ATTR_PROTO_HASH: + info->range.flags |= NF_NAT_RANGE_PROTO_RANDOM; + break; + + case OVS_NAT_ATTR_PROTO_RANDOM: + info->range.flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY; + break; + + default: + OVS_NLERR(log, "Unknown nat attribute (%d)", type); + return -EINVAL; + } + } + + if (rem > 0) { + OVS_NLERR(log, "NAT attribute has %d unknown bytes", rem); + return -EINVAL; + } + if (!info->nat) { + /* Do not allow flags if no type is given. */ + if (info->range.flags) { + OVS_NLERR(log, + "NAT flags may be given only when NAT range (SRC or DST) is also specified." + ); + return -EINVAL; + } + info->nat = OVS_CT_NAT; /* NAT existing connections. */ + } else if (!info->commit) { + OVS_NLERR(log, + "NAT attributes may be specified only when CT COMMIT flag is also specified." + ); + return -EINVAL; + } + /* Allow missing IP_MAX. */ + if (info->range.flags & NF_NAT_RANGE_MAP_IPS && !have_ip_max) { + memcpy(&info->range.max_addr, &info->range.min_addr, + sizeof(info->range.max_addr)); + } + /* Allow missing PROTO_MAX. */ + if (info->range.flags & NF_NAT_RANGE_PROTO_SPECIFIED && + !have_proto_max) { + info->range.max_proto.all = info->range.min_proto.all; + } + return 0; +} +#endif + +static const struct ovs_ct_len_tbl ovs_ct_attr_lens[OVS_CT_ATTR_MAX + 1] = { + [OVS_CT_ATTR_COMMIT] = { .minlen = 0, .maxlen = 0 }, + [OVS_CT_ATTR_FORCE_COMMIT] = { .minlen = 0, .maxlen = 0 }, + [OVS_CT_ATTR_ZONE] = { .minlen = sizeof(u16), + .maxlen = sizeof(u16) }, + [OVS_CT_ATTR_MARK] = { .minlen = sizeof(struct md_mark), + .maxlen = sizeof(struct md_mark) }, + [OVS_CT_ATTR_LABELS] = { .minlen = sizeof(struct md_labels), + .maxlen = sizeof(struct md_labels) }, + [OVS_CT_ATTR_HELPER] = { .minlen = 1, + .maxlen = NF_CT_HELPER_NAME_LEN }, +#if IS_ENABLED(CONFIG_NF_NAT) + /* NAT length is checked when parsing the nested attributes. 
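+ * That is why OVS_CT_ATTR_NAT below uses maxlen = INT_MAX: parse_nat()
+ * above checks every nested attribute against ovs_nat_attr_lens[].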
*/ + [OVS_CT_ATTR_NAT] = { .minlen = 0, .maxlen = INT_MAX }, +#endif + [OVS_CT_ATTR_EVENTMASK] = { .minlen = sizeof(u32), + .maxlen = sizeof(u32) }, + [OVS_CT_ATTR_TIMEOUT] = { .minlen = 1, + .maxlen = CTNL_TIMEOUT_NAME_MAX }, +}; + +static int parse_ct(const struct nlattr *attr, struct ovs_conntrack_info *info, + const char **helper, bool log) +{ + struct nlattr *a; + int rem; + + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + int maxlen; + int minlen; + + if (type > OVS_CT_ATTR_MAX) { + OVS_NLERR(log, + "Unknown conntrack attr (type=%d, max=%d)", + type, OVS_CT_ATTR_MAX); + return -EINVAL; + } + + maxlen = ovs_ct_attr_lens[type].maxlen; + minlen = ovs_ct_attr_lens[type].minlen; + if (nla_len(a) < minlen || nla_len(a) > maxlen) { + OVS_NLERR(log, + "Conntrack attr type has unexpected length (type=%d, length=%d, expected=%d)", + type, nla_len(a), maxlen); + return -EINVAL; + } + + switch (type) { + case OVS_CT_ATTR_FORCE_COMMIT: + info->force = true; + fallthrough; + case OVS_CT_ATTR_COMMIT: + info->commit = true; + break; +#ifdef CONFIG_NF_CONNTRACK_ZONES + case OVS_CT_ATTR_ZONE: + info->zone.id = nla_get_u16(a); + break; +#endif +#ifdef CONFIG_NF_CONNTRACK_MARK + case OVS_CT_ATTR_MARK: { + struct md_mark *mark = nla_data(a); + + if (!mark->mask) { + OVS_NLERR(log, "ct_mark mask cannot be 0"); + return -EINVAL; + } + info->mark = *mark; + break; + } +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + case OVS_CT_ATTR_LABELS: { + struct md_labels *labels = nla_data(a); + + if (!labels_nonzero(&labels->mask)) { + OVS_NLERR(log, "ct_labels mask cannot be 0"); + return -EINVAL; + } + info->labels = *labels; + break; + } +#endif + case OVS_CT_ATTR_HELPER: + *helper = nla_data(a); + if (!memchr(*helper, '\0', nla_len(a))) { + OVS_NLERR(log, "Invalid conntrack helper"); + return -EINVAL; + } + break; +#if IS_ENABLED(CONFIG_NF_NAT) + case OVS_CT_ATTR_NAT: { + int err = parse_nat(a, info, log); + + if (err) + return err; + break; + } +#endif + case OVS_CT_ATTR_EVENTMASK: + info->have_eventmask = true; + info->eventmask = nla_get_u32(a); + break; +#ifdef CONFIG_NF_CONNTRACK_TIMEOUT + case OVS_CT_ATTR_TIMEOUT: + memcpy(info->timeout, nla_data(a), nla_len(a)); + if (!memchr(info->timeout, '\0', nla_len(a))) { + OVS_NLERR(log, "Invalid conntrack timeout"); + return -EINVAL; + } + break; +#endif + + default: + OVS_NLERR(log, "Unknown conntrack attr (%d)", + type); + return -EINVAL; + } + } + +#ifdef CONFIG_NF_CONNTRACK_MARK + if (!info->commit && info->mark.mask) { + OVS_NLERR(log, + "Setting conntrack mark requires 'commit' flag."); + return -EINVAL; + } +#endif +#ifdef CONFIG_NF_CONNTRACK_LABELS + if (!info->commit && labels_nonzero(&info->labels.mask)) { + OVS_NLERR(log, + "Setting conntrack labels requires 'commit' flag."); + return -EINVAL; + } +#endif + if (rem > 0) { + OVS_NLERR(log, "Conntrack attr has %d unknown bytes", rem); + return -EINVAL; + } + + return 0; +} + +bool ovs_ct_verify(struct net *net, enum ovs_key_attr attr) +{ + if (attr == OVS_KEY_ATTR_CT_STATE) + return true; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) && + attr == OVS_KEY_ATTR_CT_ZONE) + return true; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && + attr == OVS_KEY_ATTR_CT_MARK) + return true; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + attr == OVS_KEY_ATTR_CT_LABELS) { + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + + return ovs_net->xt_label; + } + + return false; +} + +int ovs_ct_copy_action(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions 
**sfa, bool log) +{ + struct ovs_conntrack_info ct_info; + const char *helper = NULL; + u16 family; + int err; + + family = key_to_nfproto(key); + if (family == NFPROTO_UNSPEC) { + OVS_NLERR(log, "ct family unspecified"); + return -EINVAL; + } + + memset(&ct_info, 0, sizeof(ct_info)); + ct_info.family = family; + + nf_ct_zone_init(&ct_info.zone, NF_CT_DEFAULT_ZONE_ID, + NF_CT_DEFAULT_ZONE_DIR, 0); + + err = parse_ct(attr, &ct_info, &helper, log); + if (err) + return err; + + /* Set up template for tracking connections in specific zones. */ + ct_info.ct = nf_ct_tmpl_alloc(net, &ct_info.zone, GFP_KERNEL); + if (!ct_info.ct) { + OVS_NLERR(log, "Failed to allocate conntrack template"); + return -ENOMEM; + } + + if (ct_info.timeout[0]) { + if (nf_ct_set_timeout(net, ct_info.ct, family, key->ip.proto, + ct_info.timeout)) + pr_info_ratelimited("Failed to associated timeout " + "policy `%s'\n", ct_info.timeout); + else + ct_info.nf_ct_timeout = rcu_dereference( + nf_ct_timeout_find(ct_info.ct)->timeout); + + } + + if (helper) { + err = ovs_ct_add_helper(&ct_info, helper, key, log); + if (err) + goto err_free_ct; + } + + err = ovs_nla_add_action(sfa, OVS_ACTION_ATTR_CT, &ct_info, + sizeof(ct_info), log); + if (err) + goto err_free_ct; + + __set_bit(IPS_CONFIRMED_BIT, &ct_info.ct->status); + return 0; +err_free_ct: + __ovs_ct_free_action(&ct_info); + return err; +} + +#if IS_ENABLED(CONFIG_NF_NAT) +static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + struct nlattr *start; + + start = nla_nest_start_noflag(skb, OVS_CT_ATTR_NAT); + if (!start) + return false; + + if (info->nat & OVS_CT_SRC_NAT) { + if (nla_put_flag(skb, OVS_NAT_ATTR_SRC)) + return false; + } else if (info->nat & OVS_CT_DST_NAT) { + if (nla_put_flag(skb, OVS_NAT_ATTR_DST)) + return false; + } else { + goto out; + } + + if (info->range.flags & NF_NAT_RANGE_MAP_IPS) { + if (IS_ENABLED(CONFIG_NF_NAT) && + info->family == NFPROTO_IPV4) { + if (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MIN, + info->range.min_addr.ip) || + (info->range.max_addr.ip + != info->range.min_addr.ip && + (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MAX, + info->range.max_addr.ip)))) + return false; + } else if (IS_ENABLED(CONFIG_IPV6) && + info->family == NFPROTO_IPV6) { + if (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MIN, + &info->range.min_addr.in6) || + (memcmp(&info->range.max_addr.in6, + &info->range.min_addr.in6, + sizeof(info->range.max_addr.in6)) && + (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MAX, + &info->range.max_addr.in6)))) + return false; + } else { + return false; + } + } + if (info->range.flags & NF_NAT_RANGE_PROTO_SPECIFIED && + (nla_put_u16(skb, OVS_NAT_ATTR_PROTO_MIN, + ntohs(info->range.min_proto.all)) || + (info->range.max_proto.all != info->range.min_proto.all && + nla_put_u16(skb, OVS_NAT_ATTR_PROTO_MAX, + ntohs(info->range.max_proto.all))))) + return false; + + if (info->range.flags & NF_NAT_RANGE_PERSISTENT && + nla_put_flag(skb, OVS_NAT_ATTR_PERSISTENT)) + return false; + if (info->range.flags & NF_NAT_RANGE_PROTO_RANDOM && + nla_put_flag(skb, OVS_NAT_ATTR_PROTO_HASH)) + return false; + if (info->range.flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY && + nla_put_flag(skb, OVS_NAT_ATTR_PROTO_RANDOM)) + return false; +out: + nla_nest_end(skb, start); + + return true; +} +#endif + +int ovs_ct_action_to_attr(const struct ovs_conntrack_info *ct_info, + struct sk_buff *skb) +{ + struct nlattr *start; + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_CT); + if (!start) + return -EMSGSIZE; + + if (ct_info->commit && 
nla_put_flag(skb, ct_info->force + ? OVS_CT_ATTR_FORCE_COMMIT + : OVS_CT_ATTR_COMMIT)) + return -EMSGSIZE; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) && + nla_put_u16(skb, OVS_CT_ATTR_ZONE, ct_info->zone.id)) + return -EMSGSIZE; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && ct_info->mark.mask && + nla_put(skb, OVS_CT_ATTR_MARK, sizeof(ct_info->mark), + &ct_info->mark)) + return -EMSGSIZE; + if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + labels_nonzero(&ct_info->labels.mask) && + nla_put(skb, OVS_CT_ATTR_LABELS, sizeof(ct_info->labels), + &ct_info->labels)) + return -EMSGSIZE; + if (ct_info->helper) { + if (nla_put_string(skb, OVS_CT_ATTR_HELPER, + ct_info->helper->name)) + return -EMSGSIZE; + } + if (ct_info->have_eventmask && + nla_put_u32(skb, OVS_CT_ATTR_EVENTMASK, ct_info->eventmask)) + return -EMSGSIZE; + if (ct_info->timeout[0]) { + if (nla_put_string(skb, OVS_CT_ATTR_TIMEOUT, ct_info->timeout)) + return -EMSGSIZE; + } + +#if IS_ENABLED(CONFIG_NF_NAT) + if (ct_info->nat && !ovs_ct_nat_to_attr(ct_info, skb)) + return -EMSGSIZE; +#endif + nla_nest_end(skb, start); + + return 0; +} + +void ovs_ct_free_action(const struct nlattr *a) +{ + struct ovs_conntrack_info *ct_info = nla_data(a); + + __ovs_ct_free_action(ct_info); +} + +static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info) +{ + if (ct_info->helper) { +#if IS_ENABLED(CONFIG_NF_NAT) + if (ct_info->nat) + nf_nat_helper_put(ct_info->helper); +#endif + nf_conntrack_helper_put(ct_info->helper); + } + if (ct_info->ct) { + if (ct_info->timeout[0]) + nf_ct_destroy_timeout(ct_info->ct); + nf_ct_tmpl_free(ct_info->ct); + } +} + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) +static int ovs_ct_limit_init(struct net *net, struct ovs_net *ovs_net) +{ + int i, err; + + ovs_net->ct_limit_info = kmalloc(sizeof(*ovs_net->ct_limit_info), + GFP_KERNEL); + if (!ovs_net->ct_limit_info) + return -ENOMEM; + + ovs_net->ct_limit_info->default_limit = OVS_CT_LIMIT_DEFAULT; + ovs_net->ct_limit_info->limits = + kmalloc_array(CT_LIMIT_HASH_BUCKETS, sizeof(struct hlist_head), + GFP_KERNEL); + if (!ovs_net->ct_limit_info->limits) { + kfree(ovs_net->ct_limit_info); + return -ENOMEM; + } + + for (i = 0; i < CT_LIMIT_HASH_BUCKETS; i++) + INIT_HLIST_HEAD(&ovs_net->ct_limit_info->limits[i]); + + ovs_net->ct_limit_info->data = + nf_conncount_init(net, NFPROTO_INET, sizeof(u32)); + + if (IS_ERR(ovs_net->ct_limit_info->data)) { + err = PTR_ERR(ovs_net->ct_limit_info->data); + kfree(ovs_net->ct_limit_info->limits); + kfree(ovs_net->ct_limit_info); + pr_err("openvswitch: failed to init nf_conncount %d\n", err); + return err; + } + return 0; +} + +static void ovs_ct_limit_exit(struct net *net, struct ovs_net *ovs_net) +{ + const struct ovs_ct_limit_info *info = ovs_net->ct_limit_info; + int i; + + nf_conncount_destroy(net, NFPROTO_INET, info->data); + for (i = 0; i < CT_LIMIT_HASH_BUCKETS; ++i) { + struct hlist_head *head = &info->limits[i]; + struct ovs_ct_limit *ct_limit; + + hlist_for_each_entry_rcu(ct_limit, head, hlist_node, + lockdep_ovsl_is_held()) + kfree_rcu(ct_limit, rcu); + } + kfree(info->limits); + kfree(info); +} + +static struct sk_buff * +ovs_ct_limit_cmd_reply_start(struct genl_info *info, u8 cmd, + struct ovs_header **ovs_reply_header) +{ + struct ovs_header *ovs_header = info->userhdr; + struct sk_buff *skb; + + skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + *ovs_reply_header = genlmsg_put(skb, info->snd_portid, + info->snd_seq, + &dp_ct_limit_genl_family, 0, cmd); + + if (!*ovs_reply_header) { 
+ nlmsg_free(skb); + return ERR_PTR(-EMSGSIZE); + } + (*ovs_reply_header)->dp_ifindex = ovs_header->dp_ifindex; + + return skb; +} + +static bool check_zone_id(int zone_id, u16 *pzone) +{ + if (zone_id >= 0 && zone_id <= 65535) { + *pzone = (u16)zone_id; + return true; + } + return false; +} + +static int ovs_ct_limit_set_zone_limit(struct nlattr *nla_zone_limit, + struct ovs_ct_limit_info *info) +{ + struct ovs_zone_limit *zone_limit; + int rem; + u16 zone; + + rem = NLA_ALIGN(nla_len(nla_zone_limit)); + zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit); + + while (rem >= sizeof(*zone_limit)) { + if (unlikely(zone_limit->zone_id == + OVS_ZONE_LIMIT_DEFAULT_ZONE)) { + ovs_lock(); + info->default_limit = zone_limit->limit; + ovs_unlock(); + } else if (unlikely(!check_zone_id( + zone_limit->zone_id, &zone))) { + OVS_NLERR(true, "zone id is out of range"); + } else { + struct ovs_ct_limit *ct_limit; + + ct_limit = kmalloc(sizeof(*ct_limit), + GFP_KERNEL_ACCOUNT); + if (!ct_limit) + return -ENOMEM; + + ct_limit->zone = zone; + ct_limit->limit = zone_limit->limit; + + ovs_lock(); + ct_limit_set(info, ct_limit); + ovs_unlock(); + } + rem -= NLA_ALIGN(sizeof(*zone_limit)); + zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit + + NLA_ALIGN(sizeof(*zone_limit))); + } + + if (rem) + OVS_NLERR(true, "set zone limit has %d unknown bytes", rem); + + return 0; +} + +static int ovs_ct_limit_del_zone_limit(struct nlattr *nla_zone_limit, + struct ovs_ct_limit_info *info) +{ + struct ovs_zone_limit *zone_limit; + int rem; + u16 zone; + + rem = NLA_ALIGN(nla_len(nla_zone_limit)); + zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit); + + while (rem >= sizeof(*zone_limit)) { + if (unlikely(zone_limit->zone_id == + OVS_ZONE_LIMIT_DEFAULT_ZONE)) { + ovs_lock(); + info->default_limit = OVS_CT_LIMIT_DEFAULT; + ovs_unlock(); + } else if (unlikely(!check_zone_id( + zone_limit->zone_id, &zone))) { + OVS_NLERR(true, "zone id is out of range"); + } else { + ovs_lock(); + ct_limit_del(info, zone); + ovs_unlock(); + } + rem -= NLA_ALIGN(sizeof(*zone_limit)); + zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit + + NLA_ALIGN(sizeof(*zone_limit))); + } + + if (rem) + OVS_NLERR(true, "del zone limit has %d unknown bytes", rem); + + return 0; +} + +static int ovs_ct_limit_get_default_limit(struct ovs_ct_limit_info *info, + struct sk_buff *reply) +{ + struct ovs_zone_limit zone_limit = { + .zone_id = OVS_ZONE_LIMIT_DEFAULT_ZONE, + .limit = info->default_limit, + }; + + return nla_put_nohdr(reply, sizeof(zone_limit), &zone_limit); +} + +static int __ovs_ct_limit_get_zone_limit(struct net *net, + struct nf_conncount_data *data, + u16 zone_id, u32 limit, + struct sk_buff *reply) +{ + struct nf_conntrack_zone ct_zone; + struct ovs_zone_limit zone_limit; + u32 conncount_key = zone_id; + + zone_limit.zone_id = zone_id; + zone_limit.limit = limit; + nf_ct_zone_init(&ct_zone, zone_id, NF_CT_DEFAULT_ZONE_DIR, 0); + + zone_limit.count = nf_conncount_count(net, data, &conncount_key, NULL, + &ct_zone); + return nla_put_nohdr(reply, sizeof(zone_limit), &zone_limit); +} + +static int ovs_ct_limit_get_zone_limit(struct net *net, + struct nlattr *nla_zone_limit, + struct ovs_ct_limit_info *info, + struct sk_buff *reply) +{ + struct ovs_zone_limit *zone_limit; + int rem, err; + u32 limit; + u16 zone; + + rem = NLA_ALIGN(nla_len(nla_zone_limit)); + zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit); + + while (rem >= sizeof(*zone_limit)) { + if (unlikely(zone_limit->zone_id == + 
OVS_ZONE_LIMIT_DEFAULT_ZONE)) { + err = ovs_ct_limit_get_default_limit(info, reply); + if (err) + return err; + } else if (unlikely(!check_zone_id(zone_limit->zone_id, + &zone))) { + OVS_NLERR(true, "zone id is out of range"); + } else { + rcu_read_lock(); + limit = ct_limit_get(info, zone); + rcu_read_unlock(); + + err = __ovs_ct_limit_get_zone_limit( + net, info->data, zone, limit, reply); + if (err) + return err; + } + rem -= NLA_ALIGN(sizeof(*zone_limit)); + zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit + + NLA_ALIGN(sizeof(*zone_limit))); + } + + if (rem) + OVS_NLERR(true, "get zone limit has %d unknown bytes", rem); + + return 0; +} + +static int ovs_ct_limit_get_all_zone_limit(struct net *net, + struct ovs_ct_limit_info *info, + struct sk_buff *reply) +{ + struct ovs_ct_limit *ct_limit; + struct hlist_head *head; + int i, err = 0; + + err = ovs_ct_limit_get_default_limit(info, reply); + if (err) + return err; + + rcu_read_lock(); + for (i = 0; i < CT_LIMIT_HASH_BUCKETS; ++i) { + head = &info->limits[i]; + hlist_for_each_entry_rcu(ct_limit, head, hlist_node) { + err = __ovs_ct_limit_get_zone_limit(net, info->data, + ct_limit->zone, ct_limit->limit, reply); + if (err) + goto exit_err; + } + } + +exit_err: + rcu_read_unlock(); + return err; +} + +static int ovs_ct_limit_cmd_set(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct sk_buff *reply; + struct ovs_header *ovs_reply_header; + struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id); + struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info; + int err; + + reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_SET, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + if (!a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) { + err = -EINVAL; + goto exit_err; + } + + err = ovs_ct_limit_set_zone_limit(a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT], + ct_limit_info); + if (err) + goto exit_err; + + static_branch_enable(&ovs_ct_limit_enabled); + + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_err: + nlmsg_free(reply); + return err; +} + +static int ovs_ct_limit_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct sk_buff *reply; + struct ovs_header *ovs_reply_header; + struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id); + struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info; + int err; + + reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_DEL, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + if (!a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) { + err = -EINVAL; + goto exit_err; + } + + err = ovs_ct_limit_del_zone_limit(a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT], + ct_limit_info); + if (err) + goto exit_err; + + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_err: + nlmsg_free(reply); + return err; +} + +static int ovs_ct_limit_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct nlattr *nla_reply; + struct sk_buff *reply; + struct ovs_header *ovs_reply_header; + struct net *net = sock_net(skb->sk); + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info; + int err; + + reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_GET, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + nla_reply = nla_nest_start_noflag(reply, OVS_CT_LIMIT_ATTR_ZONE_LIMIT); + if (!nla_reply) { + err = 
-EMSGSIZE; + goto exit_err; + } + + if (a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) { + err = ovs_ct_limit_get_zone_limit( + net, a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT], ct_limit_info, + reply); + if (err) + goto exit_err; + } else { + err = ovs_ct_limit_get_all_zone_limit(net, ct_limit_info, + reply); + if (err) + goto exit_err; + } + + nla_nest_end(reply, nla_reply); + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_err: + nlmsg_free(reply); + return err; +} + +static const struct genl_small_ops ct_limit_genl_ops[] = { + { .cmd = OVS_CT_LIMIT_CMD_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN + * privilege. + */ + .doit = ovs_ct_limit_cmd_set, + }, + { .cmd = OVS_CT_LIMIT_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN + * privilege. + */ + .doit = ovs_ct_limit_cmd_del, + }, + { .cmd = OVS_CT_LIMIT_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. */ + .doit = ovs_ct_limit_cmd_get, + }, +}; + +static const struct genl_multicast_group ovs_ct_limit_multicast_group = { + .name = OVS_CT_LIMIT_MCGROUP, +}; + +struct genl_family dp_ct_limit_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_CT_LIMIT_FAMILY, + .version = OVS_CT_LIMIT_VERSION, + .maxattr = OVS_CT_LIMIT_ATTR_MAX, + .policy = ct_limit_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = ct_limit_genl_ops, + .n_small_ops = ARRAY_SIZE(ct_limit_genl_ops), + .resv_start_op = OVS_CT_LIMIT_CMD_GET + 1, + .mcgrps = &ovs_ct_limit_multicast_group, + .n_mcgrps = 1, + .module = THIS_MODULE, +}; +#endif + +int ovs_ct_init(struct net *net) +{ + unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE; + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + + if (nf_connlabels_get(net, n_bits - 1)) { + ovs_net->xt_label = false; + OVS_NLERR(true, "Failed to set connlabel length"); + } else { + ovs_net->xt_label = true; + } + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) + return ovs_ct_limit_init(net, ovs_net); +#else + return 0; +#endif +} + +void ovs_ct_exit(struct net *net) +{ + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) + ovs_ct_limit_exit(net, ovs_net); +#endif + + if (ovs_net->xt_label) + nf_connlabels_put(net); +} diff --git a/net/openvswitch/conntrack.h b/net/openvswitch/conntrack.h new file mode 100644 index 000000000..317e525c8 --- /dev/null +++ b/net/openvswitch/conntrack.h @@ -0,0 +1,106 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2015 Nicira, Inc. 
+ */ + +#ifndef OVS_CONNTRACK_H +#define OVS_CONNTRACK_H 1 + +#include "flow.h" + +struct ovs_conntrack_info; +struct ovs_ct_limit_info; +enum ovs_key_attr; + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +int ovs_ct_init(struct net *); +void ovs_ct_exit(struct net *); +bool ovs_ct_verify(struct net *, enum ovs_key_attr attr); +int ovs_ct_copy_action(struct net *, const struct nlattr *, + const struct sw_flow_key *, struct sw_flow_actions **, + bool log); +int ovs_ct_action_to_attr(const struct ovs_conntrack_info *, struct sk_buff *); + +int ovs_ct_execute(struct net *, struct sk_buff *, struct sw_flow_key *, + const struct ovs_conntrack_info *); +int ovs_ct_clear(struct sk_buff *skb, struct sw_flow_key *key); + +void ovs_ct_fill_key(const struct sk_buff *skb, struct sw_flow_key *key, + bool post_ct); +int ovs_ct_put_key(const struct sw_flow_key *swkey, + const struct sw_flow_key *output, struct sk_buff *skb); +void ovs_ct_free_action(const struct nlattr *a); + +#define CT_SUPPORTED_MASK (OVS_CS_F_NEW | OVS_CS_F_ESTABLISHED | \ + OVS_CS_F_RELATED | OVS_CS_F_REPLY_DIR | \ + OVS_CS_F_INVALID | OVS_CS_F_TRACKED | \ + OVS_CS_F_SRC_NAT | OVS_CS_F_DST_NAT) +#else +#include + +static inline int ovs_ct_init(struct net *net) { return 0; } + +static inline void ovs_ct_exit(struct net *net) { } + +static inline bool ovs_ct_verify(struct net *net, int attr) +{ + return false; +} + +static inline int ovs_ct_copy_action(struct net *net, const struct nlattr *nla, + const struct sw_flow_key *key, + struct sw_flow_actions **acts, bool log) +{ + return -ENOTSUPP; +} + +static inline int ovs_ct_action_to_attr(const struct ovs_conntrack_info *info, + struct sk_buff *skb) +{ + return -ENOTSUPP; +} + +static inline int ovs_ct_execute(struct net *net, struct sk_buff *skb, + struct sw_flow_key *key, + const struct ovs_conntrack_info *info) +{ + kfree_skb(skb); + return -ENOTSUPP; +} + +static inline int ovs_ct_clear(struct sk_buff *skb, + struct sw_flow_key *key) +{ + return -ENOTSUPP; +} + +static inline void ovs_ct_fill_key(const struct sk_buff *skb, + struct sw_flow_key *key, + bool post_ct) +{ + key->ct_state = 0; + key->ct_zone = 0; + key->ct.mark = 0; + memset(&key->ct.labels, 0, sizeof(key->ct.labels)); + /* Clear 'ct_orig_proto' to mark the non-existence of original + * direction key fields. + */ + key->ct_orig_proto = 0; +} + +static inline int ovs_ct_put_key(const struct sw_flow_key *swkey, + const struct sw_flow_key *output, + struct sk_buff *skb) +{ + return 0; +} + +static inline void ovs_ct_free_action(const struct nlattr *a) { } + +#define CT_SUPPORTED_MASK 0 +#endif /* CONFIG_NF_CONNTRACK */ + +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) +extern struct genl_family dp_ct_limit_genl_family; +#endif +#endif /* ovs_conntrack.h */ diff --git a/net/openvswitch/datapath.c b/net/openvswitch/datapath.c new file mode 100644 index 000000000..3c7b24535 --- /dev/null +++ b/net/openvswitch/datapath.c @@ -0,0 +1,2762 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2014 Nicira, Inc. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "flow.h" +#include "flow_table.h" +#include "flow_netlink.h" +#include "meter.h" +#include "openvswitch_trace.h" +#include "vport-internal_dev.h" +#include "vport-netdev.h" + +unsigned int ovs_net_id __read_mostly; + +static struct genl_family dp_packet_genl_family; +static struct genl_family dp_flow_genl_family; +static struct genl_family dp_datapath_genl_family; + +static const struct nla_policy flow_policy[]; + +static const struct genl_multicast_group ovs_dp_flow_multicast_group = { + .name = OVS_FLOW_MCGROUP, +}; + +static const struct genl_multicast_group ovs_dp_datapath_multicast_group = { + .name = OVS_DATAPATH_MCGROUP, +}; + +static const struct genl_multicast_group ovs_dp_vport_multicast_group = { + .name = OVS_VPORT_MCGROUP, +}; + +/* Check if need to build a reply message. + * OVS userspace sets the NLM_F_ECHO flag if it needs the reply. */ +static bool ovs_must_notify(struct genl_family *family, struct genl_info *info, + unsigned int group) +{ + return info->nlhdr->nlmsg_flags & NLM_F_ECHO || + genl_has_listeners(family, genl_info_net(info), group); +} + +static void ovs_notify(struct genl_family *family, + struct sk_buff *skb, struct genl_info *info) +{ + genl_notify(family, skb, info, 0, GFP_KERNEL); +} + +/** + * DOC: Locking: + * + * All writes e.g. Writes to device state (add/remove datapath, port, set + * operations on vports, etc.), Writes to other state (flow table + * modifications, set miscellaneous datapath parameters, etc.) are protected + * by ovs_lock. + * + * Reads are protected by RCU. + * + * There are a few special cases (mostly stats) that have their own + * synchronization but they nest under all of above and don't interact with + * each other. + * + * The RTNL lock nests inside ovs_mutex. + */ + +static DEFINE_MUTEX(ovs_mutex); + +void ovs_lock(void) +{ + mutex_lock(&ovs_mutex); +} + +void ovs_unlock(void) +{ + mutex_unlock(&ovs_mutex); +} + +#ifdef CONFIG_LOCKDEP +int lockdep_ovsl_is_held(void) +{ + if (debug_locks) + return lockdep_is_held(&ovs_mutex); + else + return 1; +} +#endif + +static struct vport *new_vport(const struct vport_parms *); +static int queue_gso_packets(struct datapath *dp, struct sk_buff *, + const struct sw_flow_key *, + const struct dp_upcall_info *, + uint32_t cutlen); +static int queue_userspace_packet(struct datapath *dp, struct sk_buff *, + const struct sw_flow_key *, + const struct dp_upcall_info *, + uint32_t cutlen); + +static void ovs_dp_masks_rebalance(struct work_struct *work); + +static int ovs_dp_set_upcall_portids(struct datapath *, const struct nlattr *); + +/* Must be called with rcu_read_lock or ovs_mutex. 
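+ * The name returned is that of the datapath's OVSP_LOCAL internal
+ * port, so the string is only valid as long as the caller keeps
+ * holding one of those locks.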
*/ +const char *ovs_dp_name(const struct datapath *dp) +{ + struct vport *vport = ovs_vport_ovsl_rcu(dp, OVSP_LOCAL); + return ovs_vport_name(vport); +} + +static int get_dpifindex(const struct datapath *dp) +{ + struct vport *local; + int ifindex; + + rcu_read_lock(); + + local = ovs_vport_rcu(dp, OVSP_LOCAL); + if (local) + ifindex = local->dev->ifindex; + else + ifindex = 0; + + rcu_read_unlock(); + + return ifindex; +} + +static void destroy_dp_rcu(struct rcu_head *rcu) +{ + struct datapath *dp = container_of(rcu, struct datapath, rcu); + + ovs_flow_tbl_destroy(&dp->table); + free_percpu(dp->stats_percpu); + kfree(dp->ports); + ovs_meters_exit(dp); + kfree(rcu_dereference_raw(dp->upcall_portids)); + kfree(dp); +} + +static struct hlist_head *vport_hash_bucket(const struct datapath *dp, + u16 port_no) +{ + return &dp->ports[port_no & (DP_VPORT_HASH_BUCKETS - 1)]; +} + +/* Called with ovs_mutex or RCU read lock. */ +struct vport *ovs_lookup_vport(const struct datapath *dp, u16 port_no) +{ + struct vport *vport; + struct hlist_head *head; + + head = vport_hash_bucket(dp, port_no); + hlist_for_each_entry_rcu(vport, head, dp_hash_node, + lockdep_ovsl_is_held()) { + if (vport->port_no == port_no) + return vport; + } + return NULL; +} + +/* Called with ovs_mutex. */ +static struct vport *new_vport(const struct vport_parms *parms) +{ + struct vport *vport; + + vport = ovs_vport_add(parms); + if (!IS_ERR(vport)) { + struct datapath *dp = parms->dp; + struct hlist_head *head = vport_hash_bucket(dp, vport->port_no); + + hlist_add_head_rcu(&vport->dp_hash_node, head); + } + return vport; +} + +void ovs_dp_detach_port(struct vport *p) +{ + ASSERT_OVSL(); + + /* First drop references to device. */ + hlist_del_rcu(&p->dp_hash_node); + + /* Then destroy it. */ + ovs_vport_del(p); +} + +/* Must be called with rcu_read_lock. */ +void ovs_dp_process_packet(struct sk_buff *skb, struct sw_flow_key *key) +{ + const struct vport *p = OVS_CB(skb)->input_vport; + struct datapath *dp = p->dp; + struct sw_flow *flow; + struct sw_flow_actions *sf_acts; + struct dp_stats_percpu *stats; + u64 *stats_counter; + u32 n_mask_hit; + u32 n_cache_hit; + int error; + + stats = this_cpu_ptr(dp->stats_percpu); + + /* Look up flow. */ + flow = ovs_flow_tbl_lookup_stats(&dp->table, key, skb_get_hash(skb), + &n_mask_hit, &n_cache_hit); + if (unlikely(!flow)) { + struct dp_upcall_info upcall; + + memset(&upcall, 0, sizeof(upcall)); + upcall.cmd = OVS_PACKET_CMD_MISS; + + if (dp->user_features & OVS_DP_F_DISPATCH_UPCALL_PER_CPU) + upcall.portid = + ovs_dp_get_upcall_portid(dp, smp_processor_id()); + else + upcall.portid = ovs_vport_find_upcall_portid(p, skb); + + upcall.mru = OVS_CB(skb)->mru; + error = ovs_dp_upcall(dp, skb, key, &upcall, 0); + switch (error) { + case 0: + case -EAGAIN: + case -ERESTARTSYS: + case -EINTR: + consume_skb(skb); + break; + default: + kfree_skb(skb); + break; + } + stats_counter = &stats->n_missed; + goto out; + } + + ovs_flow_stats_update(flow, key->tp.flags, skb); + sf_acts = rcu_dereference(flow->sf_acts); + error = ovs_execute_actions(dp, skb, sf_acts, key); + if (unlikely(error)) + net_dbg_ratelimited("ovs: action execution error on datapath %s: %d\n", + ovs_dp_name(dp), error); + + stats_counter = &stats->n_hit; + +out: + /* Update datapath statistics. 
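+ * Writers bracket the per-CPU counters with u64_stats_update_begin()/
+ * u64_stats_update_end(); readers such as get_dp_stats() below use
+ * u64_stats_fetch_begin_irq()/u64_stats_fetch_retry_irq() instead of
+ * taking a lock.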
*/ + u64_stats_update_begin(&stats->syncp); + (*stats_counter)++; + stats->n_mask_hit += n_mask_hit; + stats->n_cache_hit += n_cache_hit; + u64_stats_update_end(&stats->syncp); +} + +int ovs_dp_upcall(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_key *key, + const struct dp_upcall_info *upcall_info, + uint32_t cutlen) +{ + struct dp_stats_percpu *stats; + int err; + + if (trace_ovs_dp_upcall_enabled()) + trace_ovs_dp_upcall(dp, skb, key, upcall_info); + + if (upcall_info->portid == 0) { + err = -ENOTCONN; + goto err; + } + + if (!skb_is_gso(skb)) + err = queue_userspace_packet(dp, skb, key, upcall_info, cutlen); + else + err = queue_gso_packets(dp, skb, key, upcall_info, cutlen); + if (err) + goto err; + + return 0; + +err: + stats = this_cpu_ptr(dp->stats_percpu); + + u64_stats_update_begin(&stats->syncp); + stats->n_lost++; + u64_stats_update_end(&stats->syncp); + + return err; +} + +static int queue_gso_packets(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_key *key, + const struct dp_upcall_info *upcall_info, + uint32_t cutlen) +{ + unsigned int gso_type = skb_shinfo(skb)->gso_type; + struct sw_flow_key later_key; + struct sk_buff *segs, *nskb; + int err; + + BUILD_BUG_ON(sizeof(*OVS_CB(skb)) > SKB_GSO_CB_OFFSET); + segs = __skb_gso_segment(skb, NETIF_F_SG, false); + if (IS_ERR(segs)) + return PTR_ERR(segs); + if (segs == NULL) + return -EINVAL; + + if (gso_type & SKB_GSO_UDP) { + /* The initial flow key extracted by ovs_flow_key_extract() + * in this case is for a first fragment, so we need to + * properly mark later fragments. + */ + later_key = *key; + later_key.ip.frag = OVS_FRAG_TYPE_LATER; + } + + /* Queue all of the segments. */ + skb_list_walk_safe(segs, skb, nskb) { + if (gso_type & SKB_GSO_UDP && skb != segs) + key = &later_key; + + err = queue_userspace_packet(dp, skb, key, upcall_info, cutlen); + if (err) + break; + + } + + /* Free all of the segments. 
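+ * The segments are released here whether or not they were queued;
+ * kfree_skb() vs consume_skb() only affects drop accounting, since
+ * segments after a failed queue attempt never reached userspace.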
*/ + skb_list_walk_safe(segs, skb, nskb) { + if (err) + kfree_skb(skb); + else + consume_skb(skb); + } + return err; +} + +static size_t upcall_msg_size(const struct dp_upcall_info *upcall_info, + unsigned int hdrlen, int actions_attrlen) +{ + size_t size = NLMSG_ALIGN(sizeof(struct ovs_header)) + + nla_total_size(hdrlen) /* OVS_PACKET_ATTR_PACKET */ + + nla_total_size(ovs_key_attr_size()) /* OVS_PACKET_ATTR_KEY */ + + nla_total_size(sizeof(unsigned int)) /* OVS_PACKET_ATTR_LEN */ + + nla_total_size(sizeof(u64)); /* OVS_PACKET_ATTR_HASH */ + + /* OVS_PACKET_ATTR_USERDATA */ + if (upcall_info->userdata) + size += NLA_ALIGN(upcall_info->userdata->nla_len); + + /* OVS_PACKET_ATTR_EGRESS_TUN_KEY */ + if (upcall_info->egress_tun_info) + size += nla_total_size(ovs_tun_key_attr_size()); + + /* OVS_PACKET_ATTR_ACTIONS */ + if (upcall_info->actions_len) + size += nla_total_size(actions_attrlen); + + /* OVS_PACKET_ATTR_MRU */ + if (upcall_info->mru) + size += nla_total_size(sizeof(upcall_info->mru)); + + return size; +} + +static void pad_packet(struct datapath *dp, struct sk_buff *skb) +{ + if (!(dp->user_features & OVS_DP_F_UNALIGNED)) { + size_t plen = NLA_ALIGN(skb->len) - skb->len; + + if (plen > 0) + skb_put_zero(skb, plen); + } +} + +static int queue_userspace_packet(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_key *key, + const struct dp_upcall_info *upcall_info, + uint32_t cutlen) +{ + struct ovs_header *upcall; + struct sk_buff *nskb = NULL; + struct sk_buff *user_skb = NULL; /* to be queued to userspace */ + struct nlattr *nla; + size_t len; + unsigned int hlen; + int err, dp_ifindex; + u64 hash; + + dp_ifindex = get_dpifindex(dp); + if (!dp_ifindex) + return -ENODEV; + + if (skb_vlan_tag_present(skb)) { + nskb = skb_clone(skb, GFP_ATOMIC); + if (!nskb) + return -ENOMEM; + + nskb = __vlan_hwaccel_push_inside(nskb); + if (!nskb) + return -ENOMEM; + + skb = nskb; + } + + if (nla_attr_size(skb->len) > USHRT_MAX) { + err = -EFBIG; + goto out; + } + + /* Complete checksum if needed */ + if (skb->ip_summed == CHECKSUM_PARTIAL && + (err = skb_csum_hwoffload_help(skb, 0))) + goto out; + + /* Older versions of OVS user space enforce alignment of the last + * Netlink attribute to NLA_ALIGNTO which would require extensive + * padding logic. Only perform zerocopy if padding is not required. 
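+ * With OVS_DP_F_UNALIGNED only the linear head (skb_zerocopy_headlen())
+ * is copied and the page frags are attached as-is; otherwise the whole
+ * packet is copied so that pad_packet() can append the trailing
+ * alignment bytes.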
+ */ + if (dp->user_features & OVS_DP_F_UNALIGNED) + hlen = skb_zerocopy_headlen(skb); + else + hlen = skb->len; + + len = upcall_msg_size(upcall_info, hlen - cutlen, + OVS_CB(skb)->acts_origlen); + user_skb = genlmsg_new(len, GFP_ATOMIC); + if (!user_skb) { + err = -ENOMEM; + goto out; + } + + upcall = genlmsg_put(user_skb, 0, 0, &dp_packet_genl_family, + 0, upcall_info->cmd); + if (!upcall) { + err = -EINVAL; + goto out; + } + upcall->dp_ifindex = dp_ifindex; + + err = ovs_nla_put_key(key, key, OVS_PACKET_ATTR_KEY, false, user_skb); + if (err) + goto out; + + if (upcall_info->userdata) + __nla_put(user_skb, OVS_PACKET_ATTR_USERDATA, + nla_len(upcall_info->userdata), + nla_data(upcall_info->userdata)); + + if (upcall_info->egress_tun_info) { + nla = nla_nest_start_noflag(user_skb, + OVS_PACKET_ATTR_EGRESS_TUN_KEY); + if (!nla) { + err = -EMSGSIZE; + goto out; + } + err = ovs_nla_put_tunnel_info(user_skb, + upcall_info->egress_tun_info); + if (err) + goto out; + + nla_nest_end(user_skb, nla); + } + + if (upcall_info->actions_len) { + nla = nla_nest_start_noflag(user_skb, OVS_PACKET_ATTR_ACTIONS); + if (!nla) { + err = -EMSGSIZE; + goto out; + } + err = ovs_nla_put_actions(upcall_info->actions, + upcall_info->actions_len, + user_skb); + if (!err) + nla_nest_end(user_skb, nla); + else + nla_nest_cancel(user_skb, nla); + } + + /* Add OVS_PACKET_ATTR_MRU */ + if (upcall_info->mru && + nla_put_u16(user_skb, OVS_PACKET_ATTR_MRU, upcall_info->mru)) { + err = -ENOBUFS; + goto out; + } + + /* Add OVS_PACKET_ATTR_LEN when packet is truncated */ + if (cutlen > 0 && + nla_put_u32(user_skb, OVS_PACKET_ATTR_LEN, skb->len)) { + err = -ENOBUFS; + goto out; + } + + /* Add OVS_PACKET_ATTR_HASH */ + hash = skb_get_hash_raw(skb); + if (skb->sw_hash) + hash |= OVS_PACKET_HASH_SW_BIT; + + if (skb->l4_hash) + hash |= OVS_PACKET_HASH_L4_BIT; + + if (nla_put(user_skb, OVS_PACKET_ATTR_HASH, sizeof (u64), &hash)) { + err = -ENOBUFS; + goto out; + } + + /* Only reserve room for attribute header, packet data is added + * in skb_zerocopy() */ + if (!(nla = nla_reserve(user_skb, OVS_PACKET_ATTR_PACKET, 0))) { + err = -ENOBUFS; + goto out; + } + nla->nla_len = nla_attr_size(skb->len - cutlen); + + err = skb_zerocopy(user_skb, skb, skb->len - cutlen, hlen); + if (err) + goto out; + + /* Pad OVS_PACKET_ATTR_PACKET if linear copy was performed */ + pad_packet(dp, user_skb); + + ((struct nlmsghdr *) user_skb->data)->nlmsg_len = user_skb->len; + + err = genlmsg_unicast(ovs_dp_get_net(dp), user_skb, upcall_info->portid); + user_skb = NULL; +out: + if (err) + skb_tx_error(skb); + consume_skb(user_skb); + consume_skb(nskb); + + return err; +} + +static int ovs_packet_cmd_execute(struct sk_buff *skb, struct genl_info *info) +{ + struct ovs_header *ovs_header = info->userhdr; + struct net *net = sock_net(skb->sk); + struct nlattr **a = info->attrs; + struct sw_flow_actions *acts; + struct sk_buff *packet; + struct sw_flow *flow; + struct sw_flow_actions *sf_acts; + struct datapath *dp; + struct vport *input_vport; + u16 mru = 0; + u64 hash; + int len; + int err; + bool log = !a[OVS_PACKET_ATTR_PROBE]; + + err = -EINVAL; + if (!a[OVS_PACKET_ATTR_PACKET] || !a[OVS_PACKET_ATTR_KEY] || + !a[OVS_PACKET_ATTR_ACTIONS]) + goto err; + + len = nla_len(a[OVS_PACKET_ATTR_PACKET]); + packet = __dev_alloc_skb(NET_IP_ALIGN + len, GFP_KERNEL); + err = -ENOMEM; + if (!packet) + goto err; + skb_reserve(packet, NET_IP_ALIGN); + + nla_memcpy(__skb_put(packet, len), a[OVS_PACKET_ATTR_PACKET], len); + + /* Set packet's mru */ + if 
(a[OVS_PACKET_ATTR_MRU]) { + mru = nla_get_u16(a[OVS_PACKET_ATTR_MRU]); + packet->ignore_df = 1; + } + OVS_CB(packet)->mru = mru; + + if (a[OVS_PACKET_ATTR_HASH]) { + hash = nla_get_u64(a[OVS_PACKET_ATTR_HASH]); + + __skb_set_hash(packet, hash & 0xFFFFFFFFULL, + !!(hash & OVS_PACKET_HASH_SW_BIT), + !!(hash & OVS_PACKET_HASH_L4_BIT)); + } + + /* Build an sw_flow for sending this packet. */ + flow = ovs_flow_alloc(); + err = PTR_ERR(flow); + if (IS_ERR(flow)) + goto err_kfree_skb; + + err = ovs_flow_key_extract_userspace(net, a[OVS_PACKET_ATTR_KEY], + packet, &flow->key, log); + if (err) + goto err_flow_free; + + err = ovs_nla_copy_actions(net, a[OVS_PACKET_ATTR_ACTIONS], + &flow->key, &acts, log); + if (err) + goto err_flow_free; + + rcu_assign_pointer(flow->sf_acts, acts); + packet->priority = flow->key.phy.priority; + packet->mark = flow->key.phy.skb_mark; + + rcu_read_lock(); + dp = get_dp_rcu(net, ovs_header->dp_ifindex); + err = -ENODEV; + if (!dp) + goto err_unlock; + + input_vport = ovs_vport_rcu(dp, flow->key.phy.in_port); + if (!input_vport) + input_vport = ovs_vport_rcu(dp, OVSP_LOCAL); + + if (!input_vport) + goto err_unlock; + + packet->dev = input_vport->dev; + OVS_CB(packet)->input_vport = input_vport; + sf_acts = rcu_dereference(flow->sf_acts); + + local_bh_disable(); + err = ovs_execute_actions(dp, packet, sf_acts, &flow->key); + local_bh_enable(); + rcu_read_unlock(); + + ovs_flow_free(flow, false); + return err; + +err_unlock: + rcu_read_unlock(); +err_flow_free: + ovs_flow_free(flow, false); +err_kfree_skb: + kfree_skb(packet); +err: + return err; +} + +static const struct nla_policy packet_policy[OVS_PACKET_ATTR_MAX + 1] = { + [OVS_PACKET_ATTR_PACKET] = { .len = ETH_HLEN }, + [OVS_PACKET_ATTR_KEY] = { .type = NLA_NESTED }, + [OVS_PACKET_ATTR_ACTIONS] = { .type = NLA_NESTED }, + [OVS_PACKET_ATTR_PROBE] = { .type = NLA_FLAG }, + [OVS_PACKET_ATTR_MRU] = { .type = NLA_U16 }, + [OVS_PACKET_ATTR_HASH] = { .type = NLA_U64 }, +}; + +static const struct genl_small_ops dp_packet_genl_ops[] = { + { .cmd = OVS_PACKET_CMD_EXECUTE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. 
*/ + .doit = ovs_packet_cmd_execute + } +}; + +static struct genl_family dp_packet_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_PACKET_FAMILY, + .version = OVS_PACKET_VERSION, + .maxattr = OVS_PACKET_ATTR_MAX, + .policy = packet_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = dp_packet_genl_ops, + .n_small_ops = ARRAY_SIZE(dp_packet_genl_ops), + .resv_start_op = OVS_PACKET_CMD_EXECUTE + 1, + .module = THIS_MODULE, +}; + +static void get_dp_stats(const struct datapath *dp, struct ovs_dp_stats *stats, + struct ovs_dp_megaflow_stats *mega_stats) +{ + int i; + + memset(mega_stats, 0, sizeof(*mega_stats)); + + stats->n_flows = ovs_flow_tbl_count(&dp->table); + mega_stats->n_masks = ovs_flow_tbl_num_masks(&dp->table); + + stats->n_hit = stats->n_missed = stats->n_lost = 0; + + for_each_possible_cpu(i) { + const struct dp_stats_percpu *percpu_stats; + struct dp_stats_percpu local_stats; + unsigned int start; + + percpu_stats = per_cpu_ptr(dp->stats_percpu, i); + + do { + start = u64_stats_fetch_begin_irq(&percpu_stats->syncp); + local_stats = *percpu_stats; + } while (u64_stats_fetch_retry_irq(&percpu_stats->syncp, start)); + + stats->n_hit += local_stats.n_hit; + stats->n_missed += local_stats.n_missed; + stats->n_lost += local_stats.n_lost; + mega_stats->n_mask_hit += local_stats.n_mask_hit; + mega_stats->n_cache_hit += local_stats.n_cache_hit; + } +} + +static bool should_fill_key(const struct sw_flow_id *sfid, uint32_t ufid_flags) +{ + return ovs_identifier_is_ufid(sfid) && + !(ufid_flags & OVS_UFID_F_OMIT_KEY); +} + +static bool should_fill_mask(uint32_t ufid_flags) +{ + return !(ufid_flags & OVS_UFID_F_OMIT_MASK); +} + +static bool should_fill_actions(uint32_t ufid_flags) +{ + return !(ufid_flags & OVS_UFID_F_OMIT_ACTIONS); +} + +static size_t ovs_flow_cmd_msg_size(const struct sw_flow_actions *acts, + const struct sw_flow_id *sfid, + uint32_t ufid_flags) +{ + size_t len = NLMSG_ALIGN(sizeof(struct ovs_header)); + + /* OVS_FLOW_ATTR_UFID, or unmasked flow key as fallback + * see ovs_nla_put_identifier() + */ + if (sfid && ovs_identifier_is_ufid(sfid)) + len += nla_total_size(sfid->ufid_len); + else + len += nla_total_size(ovs_key_attr_size()); + + /* OVS_FLOW_ATTR_KEY */ + if (!sfid || should_fill_key(sfid, ufid_flags)) + len += nla_total_size(ovs_key_attr_size()); + + /* OVS_FLOW_ATTR_MASK */ + if (should_fill_mask(ufid_flags)) + len += nla_total_size(ovs_key_attr_size()); + + /* OVS_FLOW_ATTR_ACTIONS */ + if (should_fill_actions(ufid_flags)) + len += nla_total_size(acts->orig_len); + + return len + + nla_total_size_64bit(sizeof(struct ovs_flow_stats)) /* OVS_FLOW_ATTR_STATS */ + + nla_total_size(1) /* OVS_FLOW_ATTR_TCP_FLAGS */ + + nla_total_size_64bit(8); /* OVS_FLOW_ATTR_USED */ +} + +/* Called with ovs_mutex or RCU read lock. */ +static int ovs_flow_cmd_fill_stats(const struct sw_flow *flow, + struct sk_buff *skb) +{ + struct ovs_flow_stats stats; + __be16 tcp_flags; + unsigned long used; + + ovs_flow_stats_get(flow, &stats, &used, &tcp_flags); + + if (used && + nla_put_u64_64bit(skb, OVS_FLOW_ATTR_USED, ovs_flow_used_time(used), + OVS_FLOW_ATTR_PAD)) + return -EMSGSIZE; + + if (stats.n_packets && + nla_put_64bit(skb, OVS_FLOW_ATTR_STATS, + sizeof(struct ovs_flow_stats), &stats, + OVS_FLOW_ATTR_PAD)) + return -EMSGSIZE; + + if ((u8)ntohs(tcp_flags) && + nla_put_u8(skb, OVS_FLOW_ATTR_TCP_FLAGS, (u8)ntohs(tcp_flags))) + return -EMSGSIZE; + + return 0; +} + +/* Called with ovs_mutex or RCU read lock. 
*/ +static int ovs_flow_cmd_fill_actions(const struct sw_flow *flow, + struct sk_buff *skb, int skb_orig_len) +{ + struct nlattr *start; + int err; + + /* If OVS_FLOW_ATTR_ACTIONS doesn't fit, skip dumping the actions if + * this is the first flow to be dumped into 'skb'. This is unusual for + * Netlink but individual action lists can be longer than + * NLMSG_GOODSIZE and thus entirely undumpable if we didn't do this. + * The userspace caller can always fetch the actions separately if it + * really wants them. (Most userspace callers in fact don't care.) + * + * This can only fail for dump operations because the skb is always + * properly sized for single flows. + */ + start = nla_nest_start_noflag(skb, OVS_FLOW_ATTR_ACTIONS); + if (start) { + const struct sw_flow_actions *sf_acts; + + sf_acts = rcu_dereference_ovsl(flow->sf_acts); + err = ovs_nla_put_actions(sf_acts->actions, + sf_acts->actions_len, skb); + + if (!err) + nla_nest_end(skb, start); + else { + if (skb_orig_len) + return err; + + nla_nest_cancel(skb, start); + } + } else if (skb_orig_len) { + return -EMSGSIZE; + } + + return 0; +} + +/* Called with ovs_mutex or RCU read lock. */ +static int ovs_flow_cmd_fill_info(const struct sw_flow *flow, int dp_ifindex, + struct sk_buff *skb, u32 portid, + u32 seq, u32 flags, u8 cmd, u32 ufid_flags) +{ + const int skb_orig_len = skb->len; + struct ovs_header *ovs_header; + int err; + + ovs_header = genlmsg_put(skb, portid, seq, &dp_flow_genl_family, + flags, cmd); + if (!ovs_header) + return -EMSGSIZE; + + ovs_header->dp_ifindex = dp_ifindex; + + err = ovs_nla_put_identifier(flow, skb); + if (err) + goto error; + + if (should_fill_key(&flow->id, ufid_flags)) { + err = ovs_nla_put_masked_key(flow, skb); + if (err) + goto error; + } + + if (should_fill_mask(ufid_flags)) { + err = ovs_nla_put_mask(flow, skb); + if (err) + goto error; + } + + err = ovs_flow_cmd_fill_stats(flow, skb); + if (err) + goto error; + + if (should_fill_actions(ufid_flags)) { + err = ovs_flow_cmd_fill_actions(flow, skb, skb_orig_len); + if (err) + goto error; + } + + genlmsg_end(skb, ovs_header); + return 0; + +error: + genlmsg_cancel(skb, ovs_header); + return err; +} + +/* May not be called with RCU read lock. */ +static struct sk_buff *ovs_flow_cmd_alloc_info(const struct sw_flow_actions *acts, + const struct sw_flow_id *sfid, + struct genl_info *info, + bool always, + uint32_t ufid_flags) +{ + struct sk_buff *skb; + size_t len; + + if (!always && !ovs_must_notify(&dp_flow_genl_family, info, 0)) + return NULL; + + len = ovs_flow_cmd_msg_size(acts, sfid, ufid_flags); + skb = genlmsg_new(len, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + return skb; +} + +/* Called with ovs_mutex. 
*/ +static struct sk_buff *ovs_flow_cmd_build_info(const struct sw_flow *flow, + int dp_ifindex, + struct genl_info *info, u8 cmd, + bool always, u32 ufid_flags) +{ + struct sk_buff *skb; + int retval; + + skb = ovs_flow_cmd_alloc_info(ovsl_dereference(flow->sf_acts), + &flow->id, info, always, ufid_flags); + if (IS_ERR_OR_NULL(skb)) + return skb; + + retval = ovs_flow_cmd_fill_info(flow, dp_ifindex, skb, + info->snd_portid, info->snd_seq, 0, + cmd, ufid_flags); + if (WARN_ON_ONCE(retval < 0)) { + kfree_skb(skb); + skb = ERR_PTR(retval); + } + return skb; +} + +static int ovs_flow_cmd_new(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct sw_flow *flow = NULL, *new_flow; + struct sw_flow_mask mask; + struct sk_buff *reply; + struct datapath *dp; + struct sw_flow_key *key; + struct sw_flow_actions *acts; + struct sw_flow_match match; + u32 ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); + int error; + bool log = !a[OVS_FLOW_ATTR_PROBE]; + + /* Must have key and actions. */ + error = -EINVAL; + if (!a[OVS_FLOW_ATTR_KEY]) { + OVS_NLERR(log, "Flow key attr not present in new flow."); + goto error; + } + if (!a[OVS_FLOW_ATTR_ACTIONS]) { + OVS_NLERR(log, "Flow actions attr not present in new flow."); + goto error; + } + + /* Most of the time we need to allocate a new flow, do it before + * locking. + */ + new_flow = ovs_flow_alloc(); + if (IS_ERR(new_flow)) { + error = PTR_ERR(new_flow); + goto error; + } + + /* Extract key. */ + key = kzalloc(sizeof(*key), GFP_KERNEL); + if (!key) { + error = -ENOMEM; + goto err_kfree_flow; + } + + ovs_match_init(&match, key, false, &mask); + error = ovs_nla_get_match(net, &match, a[OVS_FLOW_ATTR_KEY], + a[OVS_FLOW_ATTR_MASK], log); + if (error) + goto err_kfree_key; + + ovs_flow_mask_key(&new_flow->key, key, true, &mask); + + /* Extract flow identifier. */ + error = ovs_nla_get_identifier(&new_flow->id, a[OVS_FLOW_ATTR_UFID], + key, log); + if (error) + goto err_kfree_key; + + /* Validate actions. */ + error = ovs_nla_copy_actions(net, a[OVS_FLOW_ATTR_ACTIONS], + &new_flow->key, &acts, log); + if (error) { + OVS_NLERR(log, "Flow actions may not be safe on all matching packets."); + goto err_kfree_key; + } + + reply = ovs_flow_cmd_alloc_info(acts, &new_flow->id, info, false, + ufid_flags); + if (IS_ERR(reply)) { + error = PTR_ERR(reply); + goto err_kfree_acts; + } + + ovs_lock(); + dp = get_dp(net, ovs_header->dp_ifindex); + if (unlikely(!dp)) { + error = -ENODEV; + goto err_unlock_ovs; + } + + /* Check if this is a duplicate flow */ + if (ovs_identifier_is_ufid(&new_flow->id)) + flow = ovs_flow_tbl_lookup_ufid(&dp->table, &new_flow->id); + if (!flow) + flow = ovs_flow_tbl_lookup(&dp->table, key); + if (likely(!flow)) { + rcu_assign_pointer(new_flow->sf_acts, acts); + + /* Put flow in bucket. */ + error = ovs_flow_tbl_insert(&dp->table, new_flow, &mask); + if (unlikely(error)) { + acts = NULL; + goto err_unlock_ovs; + } + + if (unlikely(reply)) { + error = ovs_flow_cmd_fill_info(new_flow, + ovs_header->dp_ifindex, + reply, info->snd_portid, + info->snd_seq, 0, + OVS_FLOW_CMD_NEW, + ufid_flags); + BUG_ON(error < 0); + } + ovs_unlock(); + } else { + struct sw_flow_actions *old_acts; + + /* Bail out if we're not allowed to modify an existing flow. + * We accept NLM_F_CREATE in place of the intended NLM_F_EXCL + * because Generic Netlink treats the latter as a dump + * request. 
We also accept NLM_F_EXCL in case that bug ever + * gets fixed. + */ + if (unlikely(info->nlhdr->nlmsg_flags & (NLM_F_CREATE + | NLM_F_EXCL))) { + error = -EEXIST; + goto err_unlock_ovs; + } + /* The flow identifier has to be the same for flow updates. + * Look for any overlapping flow. + */ + if (unlikely(!ovs_flow_cmp(flow, &match))) { + if (ovs_identifier_is_key(&flow->id)) + flow = ovs_flow_tbl_lookup_exact(&dp->table, + &match); + else /* UFID matches but key is different */ + flow = NULL; + if (!flow) { + error = -ENOENT; + goto err_unlock_ovs; + } + } + /* Update actions. */ + old_acts = ovsl_dereference(flow->sf_acts); + rcu_assign_pointer(flow->sf_acts, acts); + + if (unlikely(reply)) { + error = ovs_flow_cmd_fill_info(flow, + ovs_header->dp_ifindex, + reply, info->snd_portid, + info->snd_seq, 0, + OVS_FLOW_CMD_NEW, + ufid_flags); + BUG_ON(error < 0); + } + ovs_unlock(); + + ovs_nla_free_flow_actions_rcu(old_acts); + ovs_flow_free(new_flow, false); + } + + if (reply) + ovs_notify(&dp_flow_genl_family, reply, info); + + kfree(key); + return 0; + +err_unlock_ovs: + ovs_unlock(); + kfree_skb(reply); +err_kfree_acts: + ovs_nla_free_flow_actions(acts); +err_kfree_key: + kfree(key); +err_kfree_flow: + ovs_flow_free(new_flow, false); +error: + return error; +} + +/* Factor out action copy to avoid "Wframe-larger-than=1024" warning. */ +static noinline_for_stack +struct sw_flow_actions *get_flow_actions(struct net *net, + const struct nlattr *a, + const struct sw_flow_key *key, + const struct sw_flow_mask *mask, + bool log) +{ + struct sw_flow_actions *acts; + struct sw_flow_key masked_key; + int error; + + ovs_flow_mask_key(&masked_key, key, true, mask); + error = ovs_nla_copy_actions(net, a, &masked_key, &acts, log); + if (error) { + OVS_NLERR(log, + "Actions may not be safe on all matching packets"); + return ERR_PTR(error); + } + + return acts; +} + +/* Factor out match-init and action-copy to avoid + * "Wframe-larger-than=1024" warning. Because mask is only + * used to get actions, we new a function to save some + * stack space. + * + * If there are not key and action attrs, we return 0 + * directly. In the case, the caller will also not use the + * match as before. If there is action attr, we try to get + * actions and save them to *acts. Before returning from + * the function, we reset the match->mask pointer. Because + * we should not to return match object with dangling reference + * to mask. + * */ +static noinline_for_stack int +ovs_nla_init_match_and_action(struct net *net, + struct sw_flow_match *match, + struct sw_flow_key *key, + struct nlattr **a, + struct sw_flow_actions **acts, + bool log) +{ + struct sw_flow_mask mask; + int error = 0; + + if (a[OVS_FLOW_ATTR_KEY]) { + ovs_match_init(match, key, true, &mask); + error = ovs_nla_get_match(net, match, a[OVS_FLOW_ATTR_KEY], + a[OVS_FLOW_ATTR_MASK], log); + if (error) + goto error; + } + + if (a[OVS_FLOW_ATTR_ACTIONS]) { + if (!a[OVS_FLOW_ATTR_KEY]) { + OVS_NLERR(log, + "Flow key attribute not present in set flow."); + error = -EINVAL; + goto error; + } + + *acts = get_flow_actions(net, a[OVS_FLOW_ATTR_ACTIONS], key, + &mask, log); + if (IS_ERR(*acts)) { + error = PTR_ERR(*acts); + goto error; + } + } + + /* On success, error is 0. 
*/ +error: + match->mask = NULL; + return error; +} + +static int ovs_flow_cmd_set(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct sw_flow_key key; + struct sw_flow *flow; + struct sk_buff *reply = NULL; + struct datapath *dp; + struct sw_flow_actions *old_acts = NULL, *acts = NULL; + struct sw_flow_match match; + struct sw_flow_id sfid; + u32 ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); + int error = 0; + bool log = !a[OVS_FLOW_ATTR_PROBE]; + bool ufid_present; + + ufid_present = ovs_nla_get_ufid(&sfid, a[OVS_FLOW_ATTR_UFID], log); + if (!a[OVS_FLOW_ATTR_KEY] && !ufid_present) { + OVS_NLERR(log, + "Flow set message rejected, Key attribute missing."); + return -EINVAL; + } + + error = ovs_nla_init_match_and_action(net, &match, &key, a, + &acts, log); + if (error) + goto error; + + if (acts) { + /* Can allocate before locking if have acts. */ + reply = ovs_flow_cmd_alloc_info(acts, &sfid, info, false, + ufid_flags); + if (IS_ERR(reply)) { + error = PTR_ERR(reply); + goto err_kfree_acts; + } + } + + ovs_lock(); + dp = get_dp(net, ovs_header->dp_ifindex); + if (unlikely(!dp)) { + error = -ENODEV; + goto err_unlock_ovs; + } + /* Check that the flow exists. */ + if (ufid_present) + flow = ovs_flow_tbl_lookup_ufid(&dp->table, &sfid); + else + flow = ovs_flow_tbl_lookup_exact(&dp->table, &match); + if (unlikely(!flow)) { + error = -ENOENT; + goto err_unlock_ovs; + } + + /* Update actions, if present. */ + if (likely(acts)) { + old_acts = ovsl_dereference(flow->sf_acts); + rcu_assign_pointer(flow->sf_acts, acts); + + if (unlikely(reply)) { + error = ovs_flow_cmd_fill_info(flow, + ovs_header->dp_ifindex, + reply, info->snd_portid, + info->snd_seq, 0, + OVS_FLOW_CMD_SET, + ufid_flags); + BUG_ON(error < 0); + } + } else { + /* Could not alloc without acts before locking. */ + reply = ovs_flow_cmd_build_info(flow, ovs_header->dp_ifindex, + info, OVS_FLOW_CMD_SET, false, + ufid_flags); + + if (IS_ERR(reply)) { + error = PTR_ERR(reply); + goto err_unlock_ovs; + } + } + + /* Clear stats. 
*/ + if (a[OVS_FLOW_ATTR_CLEAR]) + ovs_flow_stats_clear(flow); + ovs_unlock(); + + if (reply) + ovs_notify(&dp_flow_genl_family, reply, info); + if (old_acts) + ovs_nla_free_flow_actions_rcu(old_acts); + + return 0; + +err_unlock_ovs: + ovs_unlock(); + kfree_skb(reply); +err_kfree_acts: + ovs_nla_free_flow_actions(acts); +error: + return error; +} + +static int ovs_flow_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct net *net = sock_net(skb->sk); + struct sw_flow_key key; + struct sk_buff *reply; + struct sw_flow *flow; + struct datapath *dp; + struct sw_flow_match match; + struct sw_flow_id ufid; + u32 ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); + int err = 0; + bool log = !a[OVS_FLOW_ATTR_PROBE]; + bool ufid_present; + + ufid_present = ovs_nla_get_ufid(&ufid, a[OVS_FLOW_ATTR_UFID], log); + if (a[OVS_FLOW_ATTR_KEY]) { + ovs_match_init(&match, &key, true, NULL); + err = ovs_nla_get_match(net, &match, a[OVS_FLOW_ATTR_KEY], NULL, + log); + } else if (!ufid_present) { + OVS_NLERR(log, + "Flow get message rejected, Key attribute missing."); + err = -EINVAL; + } + if (err) + return err; + + ovs_lock(); + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + err = -ENODEV; + goto unlock; + } + + if (ufid_present) + flow = ovs_flow_tbl_lookup_ufid(&dp->table, &ufid); + else + flow = ovs_flow_tbl_lookup_exact(&dp->table, &match); + if (!flow) { + err = -ENOENT; + goto unlock; + } + + reply = ovs_flow_cmd_build_info(flow, ovs_header->dp_ifindex, info, + OVS_FLOW_CMD_GET, true, ufid_flags); + if (IS_ERR(reply)) { + err = PTR_ERR(reply); + goto unlock; + } + + ovs_unlock(); + return genlmsg_reply(reply, info); +unlock: + ovs_unlock(); + return err; +} + +static int ovs_flow_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct net *net = sock_net(skb->sk); + struct sw_flow_key key; + struct sk_buff *reply; + struct sw_flow *flow = NULL; + struct datapath *dp; + struct sw_flow_match match; + struct sw_flow_id ufid; + u32 ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); + int err; + bool log = !a[OVS_FLOW_ATTR_PROBE]; + bool ufid_present; + + ufid_present = ovs_nla_get_ufid(&ufid, a[OVS_FLOW_ATTR_UFID], log); + if (a[OVS_FLOW_ATTR_KEY]) { + ovs_match_init(&match, &key, true, NULL); + err = ovs_nla_get_match(net, &match, a[OVS_FLOW_ATTR_KEY], + NULL, log); + if (unlikely(err)) + return err; + } + + ovs_lock(); + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (unlikely(!dp)) { + err = -ENODEV; + goto unlock; + } + + if (unlikely(!a[OVS_FLOW_ATTR_KEY] && !ufid_present)) { + err = ovs_flow_tbl_flush(&dp->table); + goto unlock; + } + + if (ufid_present) + flow = ovs_flow_tbl_lookup_ufid(&dp->table, &ufid); + else + flow = ovs_flow_tbl_lookup_exact(&dp->table, &match); + if (unlikely(!flow)) { + err = -ENOENT; + goto unlock; + } + + ovs_flow_tbl_remove(&dp->table, flow); + ovs_unlock(); + + reply = ovs_flow_cmd_alloc_info((const struct sw_flow_actions __force *) flow->sf_acts, + &flow->id, info, false, ufid_flags); + if (likely(reply)) { + if (!IS_ERR(reply)) { + rcu_read_lock(); /*To keep RCU checker happy. 
*/ + err = ovs_flow_cmd_fill_info(flow, ovs_header->dp_ifindex, + reply, info->snd_portid, + info->snd_seq, 0, + OVS_FLOW_CMD_DEL, + ufid_flags); + rcu_read_unlock(); + if (WARN_ON_ONCE(err < 0)) { + kfree_skb(reply); + goto out_free; + } + + ovs_notify(&dp_flow_genl_family, reply, info); + } else { + netlink_set_err(sock_net(skb->sk)->genl_sock, 0, 0, + PTR_ERR(reply)); + } + } + +out_free: + ovs_flow_free(flow, true); + return 0; +unlock: + ovs_unlock(); + return err; +} + +static int ovs_flow_cmd_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nlattr *a[__OVS_FLOW_ATTR_MAX]; + struct ovs_header *ovs_header = genlmsg_data(nlmsg_data(cb->nlh)); + struct table_instance *ti; + struct datapath *dp; + u32 ufid_flags; + int err; + + err = genlmsg_parse_deprecated(cb->nlh, &dp_flow_genl_family, a, + OVS_FLOW_ATTR_MAX, flow_policy, NULL); + if (err) + return err; + ufid_flags = ovs_nla_get_ufid_flags(a[OVS_FLOW_ATTR_UFID_FLAGS]); + + rcu_read_lock(); + dp = get_dp_rcu(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + rcu_read_unlock(); + return -ENODEV; + } + + ti = rcu_dereference(dp->table.ti); + for (;;) { + struct sw_flow *flow; + u32 bucket, obj; + + bucket = cb->args[0]; + obj = cb->args[1]; + flow = ovs_flow_tbl_dump_next(ti, &bucket, &obj); + if (!flow) + break; + + if (ovs_flow_cmd_fill_info(flow, ovs_header->dp_ifindex, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + OVS_FLOW_CMD_GET, ufid_flags) < 0) + break; + + cb->args[0] = bucket; + cb->args[1] = obj; + } + rcu_read_unlock(); + return skb->len; +} + +static const struct nla_policy flow_policy[OVS_FLOW_ATTR_MAX + 1] = { + [OVS_FLOW_ATTR_KEY] = { .type = NLA_NESTED }, + [OVS_FLOW_ATTR_MASK] = { .type = NLA_NESTED }, + [OVS_FLOW_ATTR_ACTIONS] = { .type = NLA_NESTED }, + [OVS_FLOW_ATTR_CLEAR] = { .type = NLA_FLAG }, + [OVS_FLOW_ATTR_PROBE] = { .type = NLA_FLAG }, + [OVS_FLOW_ATTR_UFID] = { .type = NLA_UNSPEC, .len = 1 }, + [OVS_FLOW_ATTR_UFID_FLAGS] = { .type = NLA_U32 }, +}; + +static const struct genl_small_ops dp_flow_genl_ops[] = { + { .cmd = OVS_FLOW_CMD_NEW, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_flow_cmd_new + }, + { .cmd = OVS_FLOW_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_flow_cmd_del + }, + { .cmd = OVS_FLOW_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. */ + .doit = ovs_flow_cmd_get, + .dumpit = ovs_flow_cmd_dump + }, + { .cmd = OVS_FLOW_CMD_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. 
*/ + .doit = ovs_flow_cmd_set, + }, +}; + +static struct genl_family dp_flow_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_FLOW_FAMILY, + .version = OVS_FLOW_VERSION, + .maxattr = OVS_FLOW_ATTR_MAX, + .policy = flow_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = dp_flow_genl_ops, + .n_small_ops = ARRAY_SIZE(dp_flow_genl_ops), + .resv_start_op = OVS_FLOW_CMD_SET + 1, + .mcgrps = &ovs_dp_flow_multicast_group, + .n_mcgrps = 1, + .module = THIS_MODULE, +}; + +static size_t ovs_dp_cmd_msg_size(void) +{ + size_t msgsize = NLMSG_ALIGN(sizeof(struct ovs_header)); + + msgsize += nla_total_size(IFNAMSIZ); + msgsize += nla_total_size_64bit(sizeof(struct ovs_dp_stats)); + msgsize += nla_total_size_64bit(sizeof(struct ovs_dp_megaflow_stats)); + msgsize += nla_total_size(sizeof(u32)); /* OVS_DP_ATTR_USER_FEATURES */ + msgsize += nla_total_size(sizeof(u32)); /* OVS_DP_ATTR_MASKS_CACHE_SIZE */ + msgsize += nla_total_size(sizeof(u32) * nr_cpu_ids); /* OVS_DP_ATTR_PER_CPU_PIDS */ + + return msgsize; +} + +/* Called with ovs_mutex. */ +static int ovs_dp_cmd_fill_info(struct datapath *dp, struct sk_buff *skb, + u32 portid, u32 seq, u32 flags, u8 cmd) +{ + struct ovs_header *ovs_header; + struct ovs_dp_stats dp_stats; + struct ovs_dp_megaflow_stats dp_megaflow_stats; + struct dp_nlsk_pids *pids = ovsl_dereference(dp->upcall_portids); + int err, pids_len; + + ovs_header = genlmsg_put(skb, portid, seq, &dp_datapath_genl_family, + flags, cmd); + if (!ovs_header) + goto error; + + ovs_header->dp_ifindex = get_dpifindex(dp); + + err = nla_put_string(skb, OVS_DP_ATTR_NAME, ovs_dp_name(dp)); + if (err) + goto nla_put_failure; + + get_dp_stats(dp, &dp_stats, &dp_megaflow_stats); + if (nla_put_64bit(skb, OVS_DP_ATTR_STATS, sizeof(struct ovs_dp_stats), + &dp_stats, OVS_DP_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_64bit(skb, OVS_DP_ATTR_MEGAFLOW_STATS, + sizeof(struct ovs_dp_megaflow_stats), + &dp_megaflow_stats, OVS_DP_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u32(skb, OVS_DP_ATTR_USER_FEATURES, dp->user_features)) + goto nla_put_failure; + + if (nla_put_u32(skb, OVS_DP_ATTR_MASKS_CACHE_SIZE, + ovs_flow_tbl_masks_cache_size(&dp->table))) + goto nla_put_failure; + + if (dp->user_features & OVS_DP_F_DISPATCH_UPCALL_PER_CPU && pids) { + pids_len = min(pids->n_pids, nr_cpu_ids) * sizeof(u32); + if (nla_put(skb, OVS_DP_ATTR_PER_CPU_PIDS, pids_len, &pids->pids)) + goto nla_put_failure; + } + + genlmsg_end(skb, ovs_header); + return 0; + +nla_put_failure: + genlmsg_cancel(skb, ovs_header); +error: + return -EMSGSIZE; +} + +static struct sk_buff *ovs_dp_cmd_alloc_info(void) +{ + return genlmsg_new(ovs_dp_cmd_msg_size(), GFP_KERNEL); +} + +/* Called with rcu_read_lock or ovs_mutex. */ +static struct datapath *lookup_datapath(struct net *net, + const struct ovs_header *ovs_header, + struct nlattr *a[OVS_DP_ATTR_MAX + 1]) +{ + struct datapath *dp; + + if (!a[OVS_DP_ATTR_NAME]) + dp = get_dp(net, ovs_header->dp_ifindex); + else { + struct vport *vport; + + vport = ovs_vport_locate(net, nla_data(a[OVS_DP_ATTR_NAME])); + dp = vport && vport->port_no == OVSP_LOCAL ? vport->dp : NULL; + } + return dp ? 
dp : ERR_PTR(-ENODEV); +} + +static void ovs_dp_reset_user_features(struct sk_buff *skb, + struct genl_info *info) +{ + struct datapath *dp; + + dp = lookup_datapath(sock_net(skb->sk), info->userhdr, + info->attrs); + if (IS_ERR(dp)) + return; + + pr_warn("%s: Dropping previously announced user features\n", + ovs_dp_name(dp)); + dp->user_features = 0; +} + +static int ovs_dp_set_upcall_portids(struct datapath *dp, + const struct nlattr *ids) +{ + struct dp_nlsk_pids *old, *dp_nlsk_pids; + + if (!nla_len(ids) || nla_len(ids) % sizeof(u32)) + return -EINVAL; + + old = ovsl_dereference(dp->upcall_portids); + + dp_nlsk_pids = kmalloc(sizeof(*dp_nlsk_pids) + nla_len(ids), + GFP_KERNEL); + if (!dp_nlsk_pids) + return -ENOMEM; + + dp_nlsk_pids->n_pids = nla_len(ids) / sizeof(u32); + nla_memcpy(dp_nlsk_pids->pids, ids, nla_len(ids)); + + rcu_assign_pointer(dp->upcall_portids, dp_nlsk_pids); + + kfree_rcu(old, rcu); + + return 0; +} + +u32 ovs_dp_get_upcall_portid(const struct datapath *dp, uint32_t cpu_id) +{ + struct dp_nlsk_pids *dp_nlsk_pids; + + dp_nlsk_pids = rcu_dereference(dp->upcall_portids); + + if (dp_nlsk_pids) { + if (cpu_id < dp_nlsk_pids->n_pids) { + return dp_nlsk_pids->pids[cpu_id]; + } else if (dp_nlsk_pids->n_pids > 0 && + cpu_id >= dp_nlsk_pids->n_pids) { + /* If the number of netlink PIDs is mismatched with + * the number of CPUs as seen by the kernel, log this + * and send the upcall to an arbitrary socket (0) in + * order to not drop packets + */ + pr_info_ratelimited("cpu_id mismatch with handler threads"); + return dp_nlsk_pids->pids[cpu_id % + dp_nlsk_pids->n_pids]; + } else { + return 0; + } + } else { + return 0; + } +} + +static int ovs_dp_change(struct datapath *dp, struct nlattr *a[]) +{ + u32 user_features = 0, old_features = dp->user_features; + int err; + + if (a[OVS_DP_ATTR_USER_FEATURES]) { + user_features = nla_get_u32(a[OVS_DP_ATTR_USER_FEATURES]); + + if (user_features & ~(OVS_DP_F_VPORT_PIDS | + OVS_DP_F_UNALIGNED | + OVS_DP_F_TC_RECIRC_SHARING | + OVS_DP_F_DISPATCH_UPCALL_PER_CPU)) + return -EOPNOTSUPP; + +#if !IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + if (user_features & OVS_DP_F_TC_RECIRC_SHARING) + return -EOPNOTSUPP; +#endif + } + + if (a[OVS_DP_ATTR_MASKS_CACHE_SIZE]) { + int err; + u32 cache_size; + + cache_size = nla_get_u32(a[OVS_DP_ATTR_MASKS_CACHE_SIZE]); + err = ovs_flow_tbl_masks_cache_resize(&dp->table, cache_size); + if (err) + return err; + } + + dp->user_features = user_features; + + if (dp->user_features & OVS_DP_F_DISPATCH_UPCALL_PER_CPU && + a[OVS_DP_ATTR_PER_CPU_PIDS]) { + /* Upcall Netlink Port IDs have been updated */ + err = ovs_dp_set_upcall_portids(dp, + a[OVS_DP_ATTR_PER_CPU_PIDS]); + if (err) + return err; + } + + if ((dp->user_features & OVS_DP_F_TC_RECIRC_SHARING) && + !(old_features & OVS_DP_F_TC_RECIRC_SHARING)) + tc_skb_ext_tc_enable(); + else if (!(dp->user_features & OVS_DP_F_TC_RECIRC_SHARING) && + (old_features & OVS_DP_F_TC_RECIRC_SHARING)) + tc_skb_ext_tc_disable(); + + return 0; +} + +static int ovs_dp_stats_init(struct datapath *dp) +{ + dp->stats_percpu = netdev_alloc_pcpu_stats(struct dp_stats_percpu); + if (!dp->stats_percpu) + return -ENOMEM; + + return 0; +} + +static int ovs_dp_vport_init(struct datapath *dp) +{ + int i; + + dp->ports = kmalloc_array(DP_VPORT_HASH_BUCKETS, + sizeof(struct hlist_head), + GFP_KERNEL); + if (!dp->ports) + return -ENOMEM; + + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) + INIT_HLIST_HEAD(&dp->ports[i]); + + return 0; +} + +static int ovs_dp_cmd_new(struct sk_buff *skb, struct genl_info 
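ovs_dp_get_upcall_portid() above maps a CPU id to a per-CPU handler socket and wraps around with a modulo when userspace registered fewer portids than there are CPUs, rather than dropping the packet. A plain-C sketch of just that selection policy, with an invented pid table:

#include <stdint.h>
#include <stdio.h>

static uint32_t pick_portid(const uint32_t *pids, uint32_t n_pids,
                            uint32_t cpu_id)
{
        if (!n_pids)
                return 0;                     /* no sockets registered */
        if (cpu_id < n_pids)
                return pids[cpu_id];          /* normal 1:1 mapping */
        return pids[cpu_id % n_pids];         /* fewer handlers than CPUs */
}

int main(void)
{
        uint32_t pids[] = { 100, 101, 102 };  /* hypothetical netlink portids */

        printf("cpu5 -> %u\n", (unsigned)pick_portid(pids, 3, 5)); /* pids[2] */
        return 0;
}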
*info) +{ + struct nlattr **a = info->attrs; + struct vport_parms parms; + struct sk_buff *reply; + struct datapath *dp; + struct vport *vport; + struct ovs_net *ovs_net; + int err; + + err = -EINVAL; + if (!a[OVS_DP_ATTR_NAME] || !a[OVS_DP_ATTR_UPCALL_PID]) + goto err; + + reply = ovs_dp_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + err = -ENOMEM; + dp = kzalloc(sizeof(*dp), GFP_KERNEL); + if (dp == NULL) + goto err_destroy_reply; + + ovs_dp_set_net(dp, sock_net(skb->sk)); + + /* Allocate table. */ + err = ovs_flow_tbl_init(&dp->table); + if (err) + goto err_destroy_dp; + + err = ovs_dp_stats_init(dp); + if (err) + goto err_destroy_table; + + err = ovs_dp_vport_init(dp); + if (err) + goto err_destroy_stats; + + err = ovs_meters_init(dp); + if (err) + goto err_destroy_ports; + + /* Set up our datapath device. */ + parms.name = nla_data(a[OVS_DP_ATTR_NAME]); + parms.type = OVS_VPORT_TYPE_INTERNAL; + parms.options = NULL; + parms.dp = dp; + parms.port_no = OVSP_LOCAL; + parms.upcall_portids = a[OVS_DP_ATTR_UPCALL_PID]; + parms.desired_ifindex = a[OVS_DP_ATTR_IFINDEX] + ? nla_get_s32(a[OVS_DP_ATTR_IFINDEX]) : 0; + + /* So far only local changes have been made, now need the lock. */ + ovs_lock(); + + err = ovs_dp_change(dp, a); + if (err) + goto err_unlock_and_destroy_meters; + + vport = new_vport(&parms); + if (IS_ERR(vport)) { + err = PTR_ERR(vport); + if (err == -EBUSY) + err = -EEXIST; + + if (err == -EEXIST) { + /* An outdated user space instance that does not understand + * the concept of user_features has attempted to create a new + * datapath and is likely to reuse it. Drop all user features. + */ + if (info->genlhdr->version < OVS_DP_VER_FEATURES) + ovs_dp_reset_user_features(skb, info); + } + + goto err_destroy_portids; + } + + err = ovs_dp_cmd_fill_info(dp, reply, info->snd_portid, + info->snd_seq, 0, OVS_DP_CMD_NEW); + BUG_ON(err < 0); + + ovs_net = net_generic(ovs_dp_get_net(dp), ovs_net_id); + list_add_tail_rcu(&dp->list_node, &ovs_net->dps); + + ovs_unlock(); + + ovs_notify(&dp_datapath_genl_family, reply, info); + return 0; + +err_destroy_portids: + kfree(rcu_dereference_raw(dp->upcall_portids)); +err_unlock_and_destroy_meters: + ovs_unlock(); + ovs_meters_exit(dp); +err_destroy_ports: + kfree(dp->ports); +err_destroy_stats: + free_percpu(dp->stats_percpu); +err_destroy_table: + ovs_flow_tbl_destroy(&dp->table); +err_destroy_dp: + kfree(dp); +err_destroy_reply: + kfree_skb(reply); +err: + return err; +} + +/* Called with ovs_mutex. */ +static void __dp_destroy(struct datapath *dp) +{ + struct flow_table *table = &dp->table; + int i; + + if (dp->user_features & OVS_DP_F_TC_RECIRC_SHARING) + tc_skb_ext_tc_disable(); + + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) { + struct vport *vport; + struct hlist_node *n; + + hlist_for_each_entry_safe(vport, n, &dp->ports[i], dp_hash_node) + if (vport->port_no != OVSP_LOCAL) + ovs_dp_detach_port(vport); + } + + list_del_rcu(&dp->list_node); + + /* OVSP_LOCAL is datapath internal port. We need to make sure that + * all ports in datapath are destroyed first before freeing datapath. + */ + ovs_dp_detach_port(ovs_vport_ovsl(dp, OVSP_LOCAL)); + + /* Flush sw_flow in the tables. RCU cb only releases resource + * such as dp, ports and tables. That may avoid some issues + * such as RCU usage warning. + */ + table_instance_flow_flush(table, ovsl_dereference(table->ti), + ovsl_dereference(table->ufid_ti)); + + /* RCU destroy the ports, meters and flow tables. 
*/ + call_rcu(&dp->rcu, destroy_dp_rcu); +} + +static int ovs_dp_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *reply; + struct datapath *dp; + int err; + + reply = ovs_dp_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); + dp = lookup_datapath(sock_net(skb->sk), info->userhdr, info->attrs); + err = PTR_ERR(dp); + if (IS_ERR(dp)) + goto err_unlock_free; + + err = ovs_dp_cmd_fill_info(dp, reply, info->snd_portid, + info->snd_seq, 0, OVS_DP_CMD_DEL); + BUG_ON(err < 0); + + __dp_destroy(dp); + ovs_unlock(); + + ovs_notify(&dp_datapath_genl_family, reply, info); + + return 0; + +err_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_dp_cmd_set(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *reply; + struct datapath *dp; + int err; + + reply = ovs_dp_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); + dp = lookup_datapath(sock_net(skb->sk), info->userhdr, info->attrs); + err = PTR_ERR(dp); + if (IS_ERR(dp)) + goto err_unlock_free; + + err = ovs_dp_change(dp, info->attrs); + if (err) + goto err_unlock_free; + + err = ovs_dp_cmd_fill_info(dp, reply, info->snd_portid, + info->snd_seq, 0, OVS_DP_CMD_SET); + BUG_ON(err < 0); + + ovs_unlock(); + ovs_notify(&dp_datapath_genl_family, reply, info); + + return 0; + +err_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_dp_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *reply; + struct datapath *dp; + int err; + + reply = ovs_dp_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); + dp = lookup_datapath(sock_net(skb->sk), info->userhdr, info->attrs); + if (IS_ERR(dp)) { + err = PTR_ERR(dp); + goto err_unlock_free; + } + err = ovs_dp_cmd_fill_info(dp, reply, info->snd_portid, + info->snd_seq, 0, OVS_DP_CMD_GET); + BUG_ON(err < 0); + ovs_unlock(); + + return genlmsg_reply(reply, info); + +err_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_dp_cmd_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id); + struct datapath *dp; + int skip = cb->args[0]; + int i = 0; + + ovs_lock(); + list_for_each_entry(dp, &ovs_net->dps, list_node) { + if (i >= skip && + ovs_dp_cmd_fill_info(dp, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + OVS_DP_CMD_GET) < 0) + break; + i++; + } + ovs_unlock(); + + cb->args[0] = i; + + return skb->len; +} + +static const struct nla_policy datapath_policy[OVS_DP_ATTR_MAX + 1] = { + [OVS_DP_ATTR_NAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, + [OVS_DP_ATTR_UPCALL_PID] = { .type = NLA_U32 }, + [OVS_DP_ATTR_USER_FEATURES] = { .type = NLA_U32 }, + [OVS_DP_ATTR_MASKS_CACHE_SIZE] = NLA_POLICY_RANGE(NLA_U32, 0, + PCPU_MIN_UNIT_SIZE / sizeof(struct mask_cache_entry)), + [OVS_DP_ATTR_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 0), +}; + +static const struct genl_small_ops dp_datapath_genl_ops[] = { + { .cmd = OVS_DP_CMD_NEW, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_dp_cmd_new + }, + { .cmd = OVS_DP_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_dp_cmd_del + }, + { .cmd = OVS_DP_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. 
*/ + .doit = ovs_dp_cmd_get, + .dumpit = ovs_dp_cmd_dump + }, + { .cmd = OVS_DP_CMD_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_dp_cmd_set, + }, +}; + +static struct genl_family dp_datapath_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_DATAPATH_FAMILY, + .version = OVS_DATAPATH_VERSION, + .maxattr = OVS_DP_ATTR_MAX, + .policy = datapath_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = dp_datapath_genl_ops, + .n_small_ops = ARRAY_SIZE(dp_datapath_genl_ops), + .resv_start_op = OVS_DP_CMD_SET + 1, + .mcgrps = &ovs_dp_datapath_multicast_group, + .n_mcgrps = 1, + .module = THIS_MODULE, +}; + +/* Called with ovs_mutex or RCU read lock. */ +static int ovs_vport_cmd_fill_info(struct vport *vport, struct sk_buff *skb, + struct net *net, u32 portid, u32 seq, + u32 flags, u8 cmd, gfp_t gfp) +{ + struct ovs_header *ovs_header; + struct ovs_vport_stats vport_stats; + int err; + + ovs_header = genlmsg_put(skb, portid, seq, &dp_vport_genl_family, + flags, cmd); + if (!ovs_header) + return -EMSGSIZE; + + ovs_header->dp_ifindex = get_dpifindex(vport->dp); + + if (nla_put_u32(skb, OVS_VPORT_ATTR_PORT_NO, vport->port_no) || + nla_put_u32(skb, OVS_VPORT_ATTR_TYPE, vport->ops->type) || + nla_put_string(skb, OVS_VPORT_ATTR_NAME, + ovs_vport_name(vport)) || + nla_put_u32(skb, OVS_VPORT_ATTR_IFINDEX, vport->dev->ifindex)) + goto nla_put_failure; + + if (!net_eq(net, dev_net(vport->dev))) { + int id = peernet2id_alloc(net, dev_net(vport->dev), gfp); + + if (nla_put_s32(skb, OVS_VPORT_ATTR_NETNSID, id)) + goto nla_put_failure; + } + + ovs_vport_get_stats(vport, &vport_stats); + if (nla_put_64bit(skb, OVS_VPORT_ATTR_STATS, + sizeof(struct ovs_vport_stats), &vport_stats, + OVS_VPORT_ATTR_PAD)) + goto nla_put_failure; + + if (ovs_vport_get_upcall_portids(vport, skb)) + goto nla_put_failure; + + err = ovs_vport_get_options(vport, skb); + if (err == -EMSGSIZE) + goto error; + + genlmsg_end(skb, ovs_header); + return 0; + +nla_put_failure: + err = -EMSGSIZE; +error: + genlmsg_cancel(skb, ovs_header); + return err; +} + +static struct sk_buff *ovs_vport_cmd_alloc_info(void) +{ + return nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); +} + +/* Called with ovs_mutex, only via ovs_dp_notify_wq(). */ +struct sk_buff *ovs_vport_cmd_build_info(struct vport *vport, struct net *net, + u32 portid, u32 seq, u8 cmd) +{ + struct sk_buff *skb; + int retval; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + retval = ovs_vport_cmd_fill_info(vport, skb, net, portid, seq, 0, cmd, + GFP_KERNEL); + BUG_ON(retval < 0); + + return skb; +} + +/* Called with ovs_mutex or RCU read lock. 
*/ +static struct vport *lookup_vport(struct net *net, + const struct ovs_header *ovs_header, + struct nlattr *a[OVS_VPORT_ATTR_MAX + 1]) +{ + struct datapath *dp; + struct vport *vport; + + if (a[OVS_VPORT_ATTR_IFINDEX]) + return ERR_PTR(-EOPNOTSUPP); + if (a[OVS_VPORT_ATTR_NAME]) { + vport = ovs_vport_locate(net, nla_data(a[OVS_VPORT_ATTR_NAME])); + if (!vport) + return ERR_PTR(-ENODEV); + if (ovs_header->dp_ifindex && + ovs_header->dp_ifindex != get_dpifindex(vport->dp)) + return ERR_PTR(-ENODEV); + return vport; + } else if (a[OVS_VPORT_ATTR_PORT_NO]) { + u32 port_no = nla_get_u32(a[OVS_VPORT_ATTR_PORT_NO]); + + if (port_no >= DP_MAX_PORTS) + return ERR_PTR(-EFBIG); + + dp = get_dp(net, ovs_header->dp_ifindex); + if (!dp) + return ERR_PTR(-ENODEV); + + vport = ovs_vport_ovsl_rcu(dp, port_no); + if (!vport) + return ERR_PTR(-ENODEV); + return vport; + } else + return ERR_PTR(-EINVAL); + +} + +static unsigned int ovs_get_max_headroom(struct datapath *dp) +{ + unsigned int dev_headroom, max_headroom = 0; + struct net_device *dev; + struct vport *vport; + int i; + + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) { + hlist_for_each_entry_rcu(vport, &dp->ports[i], dp_hash_node, + lockdep_ovsl_is_held()) { + dev = vport->dev; + dev_headroom = netdev_get_fwd_headroom(dev); + if (dev_headroom > max_headroom) + max_headroom = dev_headroom; + } + } + + return max_headroom; +} + +/* Called with ovs_mutex */ +static void ovs_update_headroom(struct datapath *dp, unsigned int new_headroom) +{ + struct vport *vport; + int i; + + dp->max_headroom = new_headroom; + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) { + hlist_for_each_entry_rcu(vport, &dp->ports[i], dp_hash_node, + lockdep_ovsl_is_held()) + netdev_set_rx_headroom(vport->dev, new_headroom); + } +} + +static int ovs_vport_cmd_new(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct vport_parms parms; + struct sk_buff *reply; + struct vport *vport; + struct datapath *dp; + unsigned int new_headroom; + u32 port_no; + int err; + + if (!a[OVS_VPORT_ATTR_NAME] || !a[OVS_VPORT_ATTR_TYPE] || + !a[OVS_VPORT_ATTR_UPCALL_PID]) + return -EINVAL; + + parms.type = nla_get_u32(a[OVS_VPORT_ATTR_TYPE]); + + if (a[OVS_VPORT_ATTR_IFINDEX] && parms.type != OVS_VPORT_TYPE_INTERNAL) + return -EOPNOTSUPP; + + port_no = a[OVS_VPORT_ATTR_PORT_NO] + ? nla_get_u32(a[OVS_VPORT_ATTR_PORT_NO]) : 0; + if (port_no >= DP_MAX_PORTS) + return -EFBIG; + + reply = ovs_vport_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); +restart: + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + err = -ENODEV; + if (!dp) + goto exit_unlock_free; + + if (port_no) { + vport = ovs_vport_ovsl(dp, port_no); + err = -EBUSY; + if (vport) + goto exit_unlock_free; + } else { + for (port_no = 1; ; port_no++) { + if (port_no >= DP_MAX_PORTS) { + err = -EFBIG; + goto exit_unlock_free; + } + vport = ovs_vport_ovsl(dp, port_no); + if (!vport) + break; + } + } + + parms.name = nla_data(a[OVS_VPORT_ATTR_NAME]); + parms.options = a[OVS_VPORT_ATTR_OPTIONS]; + parms.dp = dp; + parms.port_no = port_no; + parms.upcall_portids = a[OVS_VPORT_ATTR_UPCALL_PID]; + parms.desired_ifindex = a[OVS_VPORT_ATTR_IFINDEX] + ? 
nla_get_s32(a[OVS_VPORT_ATTR_IFINDEX]) : 0; + + vport = new_vport(&parms); + err = PTR_ERR(vport); + if (IS_ERR(vport)) { + if (err == -EAGAIN) + goto restart; + goto exit_unlock_free; + } + + err = ovs_vport_cmd_fill_info(vport, reply, genl_info_net(info), + info->snd_portid, info->snd_seq, 0, + OVS_VPORT_CMD_NEW, GFP_KERNEL); + + new_headroom = netdev_get_fwd_headroom(vport->dev); + + if (new_headroom > dp->max_headroom) + ovs_update_headroom(dp, new_headroom); + else + netdev_set_rx_headroom(vport->dev, dp->max_headroom); + + BUG_ON(err < 0); + ovs_unlock(); + + ovs_notify(&dp_vport_genl_family, reply, info); + return 0; + +exit_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_vport_cmd_set(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct sk_buff *reply; + struct vport *vport; + int err; + + reply = ovs_vport_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); + vport = lookup_vport(sock_net(skb->sk), info->userhdr, a); + err = PTR_ERR(vport); + if (IS_ERR(vport)) + goto exit_unlock_free; + + if (a[OVS_VPORT_ATTR_TYPE] && + nla_get_u32(a[OVS_VPORT_ATTR_TYPE]) != vport->ops->type) { + err = -EINVAL; + goto exit_unlock_free; + } + + if (a[OVS_VPORT_ATTR_OPTIONS]) { + err = ovs_vport_set_options(vport, a[OVS_VPORT_ATTR_OPTIONS]); + if (err) + goto exit_unlock_free; + } + + + if (a[OVS_VPORT_ATTR_UPCALL_PID]) { + struct nlattr *ids = a[OVS_VPORT_ATTR_UPCALL_PID]; + + err = ovs_vport_set_upcall_portids(vport, ids); + if (err) + goto exit_unlock_free; + } + + err = ovs_vport_cmd_fill_info(vport, reply, genl_info_net(info), + info->snd_portid, info->snd_seq, 0, + OVS_VPORT_CMD_SET, GFP_KERNEL); + BUG_ON(err < 0); + + ovs_unlock(); + ovs_notify(&dp_vport_genl_family, reply, info); + return 0; + +exit_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_vport_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + bool update_headroom = false; + struct nlattr **a = info->attrs; + struct sk_buff *reply; + struct datapath *dp; + struct vport *vport; + unsigned int new_headroom; + int err; + + reply = ovs_vport_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + ovs_lock(); + vport = lookup_vport(sock_net(skb->sk), info->userhdr, a); + err = PTR_ERR(vport); + if (IS_ERR(vport)) + goto exit_unlock_free; + + if (vport->port_no == OVSP_LOCAL) { + err = -EINVAL; + goto exit_unlock_free; + } + + err = ovs_vport_cmd_fill_info(vport, reply, genl_info_net(info), + info->snd_portid, info->snd_seq, 0, + OVS_VPORT_CMD_DEL, GFP_KERNEL); + BUG_ON(err < 0); + + /* the vport deletion may trigger dp headroom update */ + dp = vport->dp; + if (netdev_get_fwd_headroom(vport->dev) == dp->max_headroom) + update_headroom = true; + + netdev_reset_rx_headroom(vport->dev); + ovs_dp_detach_port(vport); + + if (update_headroom) { + new_headroom = ovs_get_max_headroom(dp); + + if (new_headroom < dp->max_headroom) + ovs_update_headroom(dp, new_headroom); + } + ovs_unlock(); + + ovs_notify(&dp_vport_genl_family, reply, info); + return 0; + +exit_unlock_free: + ovs_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_vport_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct ovs_header *ovs_header = info->userhdr; + struct sk_buff *reply; + struct vport *vport; + int err; + + reply = ovs_vport_cmd_alloc_info(); + if (!reply) + return -ENOMEM; + + rcu_read_lock(); + vport = lookup_vport(sock_net(skb->sk), ovs_header, a); + err = PTR_ERR(vport); + 
if (IS_ERR(vport)) + goto exit_unlock_free; + err = ovs_vport_cmd_fill_info(vport, reply, genl_info_net(info), + info->snd_portid, info->snd_seq, 0, + OVS_VPORT_CMD_GET, GFP_ATOMIC); + BUG_ON(err < 0); + rcu_read_unlock(); + + return genlmsg_reply(reply, info); + +exit_unlock_free: + rcu_read_unlock(); + kfree_skb(reply); + return err; +} + +static int ovs_vport_cmd_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct ovs_header *ovs_header = genlmsg_data(nlmsg_data(cb->nlh)); + struct datapath *dp; + int bucket = cb->args[0], skip = cb->args[1]; + int i, j = 0; + + rcu_read_lock(); + dp = get_dp_rcu(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + rcu_read_unlock(); + return -ENODEV; + } + for (i = bucket; i < DP_VPORT_HASH_BUCKETS; i++) { + struct vport *vport; + + j = 0; + hlist_for_each_entry_rcu(vport, &dp->ports[i], dp_hash_node) { + if (j >= skip && + ovs_vport_cmd_fill_info(vport, skb, + sock_net(skb->sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, + OVS_VPORT_CMD_GET, + GFP_ATOMIC) < 0) + goto out; + + j++; + } + skip = 0; + } +out: + rcu_read_unlock(); + + cb->args[0] = i; + cb->args[1] = j; + + return skb->len; +} + +static void ovs_dp_masks_rebalance(struct work_struct *work) +{ + struct ovs_net *ovs_net = container_of(work, struct ovs_net, + masks_rebalance.work); + struct datapath *dp; + + ovs_lock(); + + list_for_each_entry(dp, &ovs_net->dps, list_node) + ovs_flow_masks_rebalance(&dp->table); + + ovs_unlock(); + + schedule_delayed_work(&ovs_net->masks_rebalance, + msecs_to_jiffies(DP_MASKS_REBALANCE_INTERVAL)); +} + +static const struct nla_policy vport_policy[OVS_VPORT_ATTR_MAX + 1] = { + [OVS_VPORT_ATTR_NAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ - 1 }, + [OVS_VPORT_ATTR_STATS] = { .len = sizeof(struct ovs_vport_stats) }, + [OVS_VPORT_ATTR_PORT_NO] = { .type = NLA_U32 }, + [OVS_VPORT_ATTR_TYPE] = { .type = NLA_U32 }, + [OVS_VPORT_ATTR_UPCALL_PID] = { .type = NLA_UNSPEC }, + [OVS_VPORT_ATTR_OPTIONS] = { .type = NLA_NESTED }, + [OVS_VPORT_ATTR_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 0), + [OVS_VPORT_ATTR_NETNSID] = { .type = NLA_S32 }, +}; + +static const struct genl_small_ops dp_vport_genl_ops[] = { + { .cmd = OVS_VPORT_CMD_NEW, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_vport_cmd_new + }, + { .cmd = OVS_VPORT_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. */ + .doit = ovs_vport_cmd_del + }, + { .cmd = OVS_VPORT_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. */ + .doit = ovs_vport_cmd_get, + .dumpit = ovs_vport_cmd_dump + }, + { .cmd = OVS_VPORT_CMD_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN privilege. 
*/ + .doit = ovs_vport_cmd_set, + }, +}; + +struct genl_family dp_vport_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_VPORT_FAMILY, + .version = OVS_VPORT_VERSION, + .maxattr = OVS_VPORT_ATTR_MAX, + .policy = vport_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = dp_vport_genl_ops, + .n_small_ops = ARRAY_SIZE(dp_vport_genl_ops), + .resv_start_op = OVS_VPORT_CMD_SET + 1, + .mcgrps = &ovs_dp_vport_multicast_group, + .n_mcgrps = 1, + .module = THIS_MODULE, +}; + +static struct genl_family * const dp_genl_families[] = { + &dp_datapath_genl_family, + &dp_vport_genl_family, + &dp_flow_genl_family, + &dp_packet_genl_family, + &dp_meter_genl_family, +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) + &dp_ct_limit_genl_family, +#endif +}; + +static void dp_unregister_genl(int n_families) +{ + int i; + + for (i = 0; i < n_families; i++) + genl_unregister_family(dp_genl_families[i]); +} + +static int __init dp_register_genl(void) +{ + int err; + int i; + + for (i = 0; i < ARRAY_SIZE(dp_genl_families); i++) { + + err = genl_register_family(dp_genl_families[i]); + if (err) + goto error; + } + + return 0; + +error: + dp_unregister_genl(i); + return err; +} + +static int __net_init ovs_init_net(struct net *net) +{ + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + int err; + + INIT_LIST_HEAD(&ovs_net->dps); + INIT_WORK(&ovs_net->dp_notify_work, ovs_dp_notify_wq); + INIT_DELAYED_WORK(&ovs_net->masks_rebalance, ovs_dp_masks_rebalance); + + err = ovs_ct_init(net); + if (err) + return err; + + schedule_delayed_work(&ovs_net->masks_rebalance, + msecs_to_jiffies(DP_MASKS_REBALANCE_INTERVAL)); + return 0; +} + +static void __net_exit list_vports_from_net(struct net *net, struct net *dnet, + struct list_head *head) +{ + struct ovs_net *ovs_net = net_generic(net, ovs_net_id); + struct datapath *dp; + + list_for_each_entry(dp, &ovs_net->dps, list_node) { + int i; + + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) { + struct vport *vport; + + hlist_for_each_entry(vport, &dp->ports[i], dp_hash_node) { + if (vport->ops->type != OVS_VPORT_TYPE_INTERNAL) + continue; + + if (dev_net(vport->dev) == dnet) + list_add(&vport->detach_list, head); + } + } + } +} + +static void __net_exit ovs_exit_net(struct net *dnet) +{ + struct datapath *dp, *dp_next; + struct ovs_net *ovs_net = net_generic(dnet, ovs_net_id); + struct vport *vport, *vport_next; + struct net *net; + LIST_HEAD(head); + + ovs_lock(); + + ovs_ct_exit(dnet); + + list_for_each_entry_safe(dp, dp_next, &ovs_net->dps, list_node) + __dp_destroy(dp); + + down_read(&net_rwsem); + for_each_net(net) + list_vports_from_net(net, dnet, &head); + up_read(&net_rwsem); + + /* Detach all vports from given namespace. 
*/ + list_for_each_entry_safe(vport, vport_next, &head, detach_list) { + list_del(&vport->detach_list); + ovs_dp_detach_port(vport); + } + + ovs_unlock(); + + cancel_delayed_work_sync(&ovs_net->masks_rebalance); + cancel_work_sync(&ovs_net->dp_notify_work); +} + +static struct pernet_operations ovs_net_ops = { + .init = ovs_init_net, + .exit = ovs_exit_net, + .id = &ovs_net_id, + .size = sizeof(struct ovs_net), +}; + +static int __init dp_init(void) +{ + int err; + + BUILD_BUG_ON(sizeof(struct ovs_skb_cb) > + sizeof_field(struct sk_buff, cb)); + + pr_info("Open vSwitch switching datapath\n"); + + err = action_fifos_init(); + if (err) + goto error; + + err = ovs_internal_dev_rtnl_link_register(); + if (err) + goto error_action_fifos_exit; + + err = ovs_flow_init(); + if (err) + goto error_unreg_rtnl_link; + + err = ovs_vport_init(); + if (err) + goto error_flow_exit; + + err = register_pernet_device(&ovs_net_ops); + if (err) + goto error_vport_exit; + + err = register_netdevice_notifier(&ovs_dp_device_notifier); + if (err) + goto error_netns_exit; + + err = ovs_netdev_init(); + if (err) + goto error_unreg_notifier; + + err = dp_register_genl(); + if (err < 0) + goto error_unreg_netdev; + + return 0; + +error_unreg_netdev: + ovs_netdev_exit(); +error_unreg_notifier: + unregister_netdevice_notifier(&ovs_dp_device_notifier); +error_netns_exit: + unregister_pernet_device(&ovs_net_ops); +error_vport_exit: + ovs_vport_exit(); +error_flow_exit: + ovs_flow_exit(); +error_unreg_rtnl_link: + ovs_internal_dev_rtnl_link_unregister(); +error_action_fifos_exit: + action_fifos_exit(); +error: + return err; +} + +static void dp_cleanup(void) +{ + dp_unregister_genl(ARRAY_SIZE(dp_genl_families)); + ovs_netdev_exit(); + unregister_netdevice_notifier(&ovs_dp_device_notifier); + unregister_pernet_device(&ovs_net_ops); + rcu_barrier(); + ovs_vport_exit(); + ovs_flow_exit(); + ovs_internal_dev_rtnl_link_unregister(); + action_fifos_exit(); +} + +module_init(dp_init); +module_exit(dp_cleanup); + +MODULE_DESCRIPTION("Open vSwitch switching datapath"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_GENL_FAMILY(OVS_DATAPATH_FAMILY); +MODULE_ALIAS_GENL_FAMILY(OVS_VPORT_FAMILY); +MODULE_ALIAS_GENL_FAMILY(OVS_FLOW_FAMILY); +MODULE_ALIAS_GENL_FAMILY(OVS_PACKET_FAMILY); +MODULE_ALIAS_GENL_FAMILY(OVS_METER_FAMILY); +MODULE_ALIAS_GENL_FAMILY(OVS_CT_LIMIT_FAMILY); diff --git a/net/openvswitch/datapath.h b/net/openvswitch/datapath.h new file mode 100644 index 000000000..0cd29971a --- /dev/null +++ b/net/openvswitch/datapath.h @@ -0,0 +1,285 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2014 Nicira, Inc. + */ + +#ifndef DATAPATH_H +#define DATAPATH_H 1 + +#include +#include +#include +#include +#include +#include +#include + +#include "conntrack.h" +#include "flow.h" +#include "flow_table.h" +#include "meter.h" +#include "vport-internal_dev.h" + +#define DP_MAX_PORTS USHRT_MAX +#define DP_VPORT_HASH_BUCKETS 1024 +#define DP_MASKS_REBALANCE_INTERVAL 4000 + +/** + * struct dp_stats_percpu - per-cpu packet processing statistics for a given + * datapath. + * @n_hit: Number of received packets for which a matching flow was found in + * the flow table. + * @n_miss: Number of received packets that had no matching flow in the flow + * table. The sum of @n_hit and @n_miss is the number of packets that have + * been received by the datapath. 
+ * @n_lost: Number of received packets that had no matching flow in the flow + * table that could not be sent to userspace (normally due to an overflow in + * one of the datapath's queues). + * @n_mask_hit: Number of masks looked up for flow match. + * @n_mask_hit / (@n_hit + @n_missed) will be the average masks looked + * up per packet. + * @n_cache_hit: The number of received packets that had their mask found using + * the mask cache. + */ +struct dp_stats_percpu { + u64 n_hit; + u64 n_missed; + u64 n_lost; + u64 n_mask_hit; + u64 n_cache_hit; + struct u64_stats_sync syncp; +}; + +/** + * struct dp_nlsk_pids - array of netlink portids of for a datapath. + * This is used when OVS_DP_F_DISPATCH_UPCALL_PER_CPU + * is enabled and must be protected by rcu. + * @rcu: RCU callback head for deferred destruction. + * @n_pids: Size of @pids array. + * @pids: Array storing the Netlink socket PIDs indexed by CPU ID for packets + * that miss the flow table. + */ +struct dp_nlsk_pids { + struct rcu_head rcu; + u32 n_pids; + u32 pids[]; +}; + +/** + * struct datapath - datapath for flow-based packet switching + * @rcu: RCU callback head for deferred destruction. + * @list_node: Element in global 'dps' list. + * @table: flow table. + * @ports: Hash table for ports. %OVSP_LOCAL port always exists. Protected by + * ovs_mutex and RCU. + * @stats_percpu: Per-CPU datapath statistics. + * @net: Reference to net namespace. + * @max_headroom: the maximum headroom of all vports in this datapath; it will + * be used by all the internal vports in this dp. + * @upcall_portids: RCU protected 'struct dp_nlsk_pids'. + * + * Context: See the comment on locking at the top of datapath.c for additional + * locking information. + */ +struct datapath { + struct rcu_head rcu; + struct list_head list_node; + + /* Flow table. */ + struct flow_table table; + + /* Switch ports. */ + struct hlist_head *ports; + + /* Stats. */ + struct dp_stats_percpu __percpu *stats_percpu; + + /* Network namespace ref. */ + possible_net_t net; + + u32 user_features; + + u32 max_headroom; + + /* Switch meters. */ + struct dp_meter_table meter_tbl; + + struct dp_nlsk_pids __rcu *upcall_portids; +}; + +/** + * struct ovs_skb_cb - OVS data in skb CB + * @input_vport: The original vport packet came in on. This value is cached + * when a packet is received by OVS. + * @mru: The maximum received fragement size; 0 if the packet is not + * fragmented. + * @acts_origlen: The netlink size of the flow actions applied to this skb. + * @cutlen: The number of bytes from the packet end to be removed. + */ +struct ovs_skb_cb { + struct vport *input_vport; + u16 mru; + u16 acts_origlen; + u32 cutlen; +}; +#define OVS_CB(skb) ((struct ovs_skb_cb *)(skb)->cb) + +/** + * struct dp_upcall - metadata to include with a packet to send to userspace + * @cmd: One of %OVS_PACKET_CMD_*. + * @userdata: If nonnull, its variable-length value is passed to userspace as + * %OVS_PACKET_ATTR_USERDATA. + * @portid: Netlink portid to which packet should be sent. If @portid is 0 + * then no packet is sent and the packet is accounted in the datapath's @n_lost + * counter. + * @egress_tun_info: If nonnull, becomes %OVS_PACKET_ATTR_EGRESS_TUN_KEY. + * @mru: If not zero, Maximum received IP fragment size. + */ +struct dp_upcall_info { + struct ip_tunnel_info *egress_tun_info; + const struct nlattr *userdata; + const struct nlattr *actions; + int actions_len; + u32 portid; + u8 cmd; + u16 mru; +}; + +/** + * struct ovs_net - Per net-namespace data for ovs. 
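struct ovs_skb_cb and the OVS_CB() macro above overlay private per-packet state on skb->cb, and dp_init() guards that overlay with BUILD_BUG_ON(sizeof(struct ovs_skb_cb) > sizeof_field(struct sk_buff, cb)). A userspace sketch of the same pattern, assuming the 48-byte buffer size of sk_buff::cb in current kernels and using hypothetical names:

#include <stdint.h>
#include <stdio.h>

struct fake_skb {
        char cb[48];            /* per-packet scratch space, like sk_buff::cb */
};

struct my_cb {
        uint16_t mru;
        uint16_t acts_origlen;
        uint32_t cutlen;
};

/* Compile-time guard standing in for the BUILD_BUG_ON() in dp_init(). */
_Static_assert(sizeof(struct my_cb) <= sizeof(((struct fake_skb *)0)->cb),
               "private CB must fit in the scratch buffer");

#define MY_CB(skb) ((struct my_cb *)(skb)->cb)

int main(void)
{
        struct fake_skb skb = { { 0 } };

        MY_CB(&skb)->mru = 1500;
        printf("mru=%u\n", (unsigned)MY_CB(&skb)->mru);
        return 0;
}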
+ * @dps: List of datapaths to enable dumping them all out. + * Protected by genl_mutex. + */ +struct ovs_net { + struct list_head dps; + struct work_struct dp_notify_work; + struct delayed_work masks_rebalance; +#if IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT) + struct ovs_ct_limit_info *ct_limit_info; +#endif + + /* Module reference for configuring conntrack. */ + bool xt_label; +}; + +/** + * enum ovs_pkt_hash_types - hash info to include with a packet + * to send to userspace. + * @OVS_PACKET_HASH_SW_BIT: indicates hash was computed in software stack. + * @OVS_PACKET_HASH_L4_BIT: indicates hash is a canonical 4-tuple hash + * over transport ports. + */ +enum ovs_pkt_hash_types { + OVS_PACKET_HASH_SW_BIT = (1ULL << 32), + OVS_PACKET_HASH_L4_BIT = (1ULL << 33), +}; + +extern unsigned int ovs_net_id; +void ovs_lock(void); +void ovs_unlock(void); + +#ifdef CONFIG_LOCKDEP +int lockdep_ovsl_is_held(void); +#else +#define lockdep_ovsl_is_held() 1 +#endif + +#define ASSERT_OVSL() WARN_ON(!lockdep_ovsl_is_held()) +#define ovsl_dereference(p) \ + rcu_dereference_protected(p, lockdep_ovsl_is_held()) +#define rcu_dereference_ovsl(p) \ + rcu_dereference_check(p, lockdep_ovsl_is_held()) + +static inline struct net *ovs_dp_get_net(const struct datapath *dp) +{ + return read_pnet(&dp->net); +} + +static inline void ovs_dp_set_net(struct datapath *dp, struct net *net) +{ + write_pnet(&dp->net, net); +} + +struct vport *ovs_lookup_vport(const struct datapath *dp, u16 port_no); + +static inline struct vport *ovs_vport_rcu(const struct datapath *dp, int port_no) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + return ovs_lookup_vport(dp, port_no); +} + +static inline struct vport *ovs_vport_ovsl_rcu(const struct datapath *dp, int port_no) +{ + WARN_ON_ONCE(!rcu_read_lock_held() && !lockdep_ovsl_is_held()); + return ovs_lookup_vport(dp, port_no); +} + +static inline struct vport *ovs_vport_ovsl(const struct datapath *dp, int port_no) +{ + ASSERT_OVSL(); + return ovs_lookup_vport(dp, port_no); +} + +/* Must be called with rcu_read_lock. */ +static inline struct datapath *get_dp_rcu(struct net *net, int dp_ifindex) +{ + struct net_device *dev = dev_get_by_index_rcu(net, dp_ifindex); + + if (dev) { + struct vport *vport = ovs_internal_dev_get_vport(dev); + + if (vport) + return vport->dp; + } + + return NULL; +} + +/* The caller must hold either ovs_mutex or rcu_read_lock to keep the + * returned dp pointer valid. 
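+ *
+ * A minimal usage sketch (illustrative only, error handling elided):
+ *
+ *	ovs_lock();
+ *	dp = get_dp(net, dp_ifindex);
+ *	if (dp)
+ *		...	use dp; it stays valid while ovs_mutex is held
+ *	ovs_unlock();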
+ */ +static inline struct datapath *get_dp(struct net *net, int dp_ifindex) +{ + struct datapath *dp; + + WARN_ON_ONCE(!rcu_read_lock_held() && !lockdep_ovsl_is_held()); + rcu_read_lock(); + dp = get_dp_rcu(net, dp_ifindex); + rcu_read_unlock(); + + return dp; +} + +extern struct notifier_block ovs_dp_device_notifier; +extern struct genl_family dp_vport_genl_family; + +void ovs_dp_process_packet(struct sk_buff *skb, struct sw_flow_key *key); +void ovs_dp_detach_port(struct vport *); +int ovs_dp_upcall(struct datapath *, struct sk_buff *, + const struct sw_flow_key *, const struct dp_upcall_info *, + uint32_t cutlen); + +u32 ovs_dp_get_upcall_portid(const struct datapath *dp, uint32_t cpu_id); + +const char *ovs_dp_name(const struct datapath *dp); +struct sk_buff *ovs_vport_cmd_build_info(struct vport *vport, struct net *net, + u32 portid, u32 seq, u8 cmd); + +int ovs_execute_actions(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_actions *, struct sw_flow_key *); + +void ovs_dp_notify_wq(struct work_struct *work); + +int action_fifos_init(void); +void action_fifos_exit(void); + +/* 'KEY' must not have any bits set outside of the 'MASK' */ +#define OVS_MASKED(OLD, KEY, MASK) ((KEY) | ((OLD) & ~(MASK))) +#define OVS_SET_MASKED(OLD, KEY, MASK) ((OLD) = OVS_MASKED(OLD, KEY, MASK)) + +#define OVS_NLERR(logging_allowed, fmt, ...) \ +do { \ + if (logging_allowed && net_ratelimit()) \ + pr_info("netlink: " fmt "\n", ##__VA_ARGS__); \ +} while (0) +#endif /* datapath.h */ diff --git a/net/openvswitch/dp_notify.c b/net/openvswitch/dp_notify.c new file mode 100644 index 000000000..7af0cde8b --- /dev/null +++ b/net/openvswitch/dp_notify.c @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2012 Nicira, Inc. + */ + +#include +#include +#include + +#include "datapath.h" +#include "vport-internal_dev.h" +#include "vport-netdev.h" + +static void dp_detach_port_notify(struct vport *vport) +{ + struct sk_buff *notify; + struct datapath *dp; + + dp = vport->dp; + notify = ovs_vport_cmd_build_info(vport, ovs_dp_get_net(dp), + 0, 0, OVS_VPORT_CMD_DEL); + ovs_dp_detach_port(vport); + if (IS_ERR(notify)) { + genl_set_err(&dp_vport_genl_family, ovs_dp_get_net(dp), 0, + 0, PTR_ERR(notify)); + return; + } + + genlmsg_multicast_netns(&dp_vport_genl_family, + ovs_dp_get_net(dp), notify, 0, + 0, GFP_KERNEL); +} + +void ovs_dp_notify_wq(struct work_struct *work) +{ + struct ovs_net *ovs_net = container_of(work, struct ovs_net, dp_notify_work); + struct datapath *dp; + + ovs_lock(); + list_for_each_entry(dp, &ovs_net->dps, list_node) { + int i; + + for (i = 0; i < DP_VPORT_HASH_BUCKETS; i++) { + struct vport *vport; + struct hlist_node *n; + + hlist_for_each_entry_safe(vport, n, &dp->ports[i], dp_hash_node) { + if (vport->ops->type == OVS_VPORT_TYPE_INTERNAL) + continue; + + if (!(netif_is_ovs_port(vport->dev))) + dp_detach_port_notify(vport); + } + } + } + ovs_unlock(); +} + +static int dp_device_event(struct notifier_block *unused, unsigned long event, + void *ptr) +{ + struct ovs_net *ovs_net; + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct vport *vport = NULL; + + if (!ovs_is_internal_dev(dev)) + vport = ovs_netdev_get_vport(dev); + + if (!vport) + return NOTIFY_DONE; + + if (event == NETDEV_UNREGISTER) { + /* upper_dev_unlink and decrement promisc immediately */ + ovs_netdev_detach_dev(vport); + + /* schedule vport destroy, dev_put and genl notification */ + ovs_net = net_generic(dev_net(dev), ovs_net_id); + queue_work(system_wq, 
&ovs_net->dp_notify_work); + } + + return NOTIFY_DONE; +} + +struct notifier_block ovs_dp_device_notifier = { + .notifier_call = dp_device_event +}; diff --git a/net/openvswitch/flow.c b/net/openvswitch/flow.c new file mode 100644 index 000000000..e20d1a973 --- /dev/null +++ b/net/openvswitch/flow.c @@ -0,0 +1,1115 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2014 Nicira, Inc. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "conntrack.h" +#include "datapath.h" +#include "flow.h" +#include "flow_netlink.h" +#include "vport.h" + +u64 ovs_flow_used_time(unsigned long flow_jiffies) +{ + struct timespec64 cur_ts; + u64 cur_ms, idle_ms; + + ktime_get_ts64(&cur_ts); + idle_ms = jiffies_to_msecs(jiffies - flow_jiffies); + cur_ms = (u64)(u32)cur_ts.tv_sec * MSEC_PER_SEC + + cur_ts.tv_nsec / NSEC_PER_MSEC; + + return cur_ms - idle_ms; +} + +#define TCP_FLAGS_BE16(tp) (*(__be16 *)&tcp_flag_word(tp) & htons(0x0FFF)) + +void ovs_flow_stats_update(struct sw_flow *flow, __be16 tcp_flags, + const struct sk_buff *skb) +{ + struct sw_flow_stats *stats; + unsigned int cpu = smp_processor_id(); + int len = skb->len + (skb_vlan_tag_present(skb) ? VLAN_HLEN : 0); + + stats = rcu_dereference(flow->stats[cpu]); + + /* Check if already have CPU-specific stats. */ + if (likely(stats)) { + spin_lock(&stats->lock); + /* Mark if we write on the pre-allocated stats. */ + if (cpu == 0 && unlikely(flow->stats_last_writer != cpu)) + flow->stats_last_writer = cpu; + } else { + stats = rcu_dereference(flow->stats[0]); /* Pre-allocated. */ + spin_lock(&stats->lock); + + /* If the current CPU is the only writer on the + * pre-allocated stats keep using them. + */ + if (unlikely(flow->stats_last_writer != cpu)) { + /* A previous locker may have already allocated the + * stats, so we need to check again. If CPU-specific + * stats were already allocated, we update the pre- + * allocated stats as we have already locked them. + */ + if (likely(flow->stats_last_writer != -1) && + likely(!rcu_access_pointer(flow->stats[cpu]))) { + /* Try to allocate CPU-specific stats. */ + struct sw_flow_stats *new_stats; + + new_stats = + kmem_cache_alloc_node(flow_stats_cache, + GFP_NOWAIT | + __GFP_THISNODE | + __GFP_NOWARN | + __GFP_NOMEMALLOC, + numa_node_id()); + if (likely(new_stats)) { + new_stats->used = jiffies; + new_stats->packet_count = 1; + new_stats->byte_count = len; + new_stats->tcp_flags = tcp_flags; + spin_lock_init(&new_stats->lock); + + rcu_assign_pointer(flow->stats[cpu], + new_stats); + cpumask_set_cpu(cpu, &flow->cpu_used_mask); + goto unlock; + } + } + flow->stats_last_writer = cpu; + } + } + + stats->used = jiffies; + stats->packet_count++; + stats->byte_count += len; + stats->tcp_flags |= tcp_flags; +unlock: + spin_unlock(&stats->lock); +} + +/* Must be called with rcu_read_lock or ovs_mutex. 
*/ +void ovs_flow_stats_get(const struct sw_flow *flow, + struct ovs_flow_stats *ovs_stats, + unsigned long *used, __be16 *tcp_flags) +{ + int cpu; + + *used = 0; + *tcp_flags = 0; + memset(ovs_stats, 0, sizeof(*ovs_stats)); + + /* We open code this to make sure cpu 0 is always considered */ + for (cpu = 0; cpu < nr_cpu_ids; cpu = cpumask_next(cpu, &flow->cpu_used_mask)) { + struct sw_flow_stats *stats = rcu_dereference_ovsl(flow->stats[cpu]); + + if (stats) { + /* Local CPU may write on non-local stats, so we must + * block bottom-halves here. + */ + spin_lock_bh(&stats->lock); + if (!*used || time_after(stats->used, *used)) + *used = stats->used; + *tcp_flags |= stats->tcp_flags; + ovs_stats->n_packets += stats->packet_count; + ovs_stats->n_bytes += stats->byte_count; + spin_unlock_bh(&stats->lock); + } + } +} + +/* Called with ovs_mutex. */ +void ovs_flow_stats_clear(struct sw_flow *flow) +{ + int cpu; + + /* We open code this to make sure cpu 0 is always considered */ + for (cpu = 0; cpu < nr_cpu_ids; cpu = cpumask_next(cpu, &flow->cpu_used_mask)) { + struct sw_flow_stats *stats = ovsl_dereference(flow->stats[cpu]); + + if (stats) { + spin_lock_bh(&stats->lock); + stats->used = 0; + stats->packet_count = 0; + stats->byte_count = 0; + stats->tcp_flags = 0; + spin_unlock_bh(&stats->lock); + } + } +} + +static int check_header(struct sk_buff *skb, int len) +{ + if (unlikely(skb->len < len)) + return -EINVAL; + if (unlikely(!pskb_may_pull(skb, len))) + return -ENOMEM; + return 0; +} + +static bool arphdr_ok(struct sk_buff *skb) +{ + return pskb_may_pull(skb, skb_network_offset(skb) + + sizeof(struct arp_eth_header)); +} + +static int check_iphdr(struct sk_buff *skb) +{ + unsigned int nh_ofs = skb_network_offset(skb); + unsigned int ip_len; + int err; + + err = check_header(skb, nh_ofs + sizeof(struct iphdr)); + if (unlikely(err)) + return err; + + ip_len = ip_hdrlen(skb); + if (unlikely(ip_len < sizeof(struct iphdr) || + skb->len < nh_ofs + ip_len)) + return -EINVAL; + + skb_set_transport_header(skb, nh_ofs + ip_len); + return 0; +} + +static bool tcphdr_ok(struct sk_buff *skb) +{ + int th_ofs = skb_transport_offset(skb); + int tcp_len; + + if (unlikely(!pskb_may_pull(skb, th_ofs + sizeof(struct tcphdr)))) + return false; + + tcp_len = tcp_hdrlen(skb); + if (unlikely(tcp_len < sizeof(struct tcphdr) || + skb->len < th_ofs + tcp_len)) + return false; + + return true; +} + +static bool udphdr_ok(struct sk_buff *skb) +{ + return pskb_may_pull(skb, skb_transport_offset(skb) + + sizeof(struct udphdr)); +} + +static bool sctphdr_ok(struct sk_buff *skb) +{ + return pskb_may_pull(skb, skb_transport_offset(skb) + + sizeof(struct sctphdr)); +} + +static bool icmphdr_ok(struct sk_buff *skb) +{ + return pskb_may_pull(skb, skb_transport_offset(skb) + + sizeof(struct icmphdr)); +} + +/** + * get_ipv6_ext_hdrs() - Parses packet and sets IPv6 extension header flags. + * + * @skb: buffer where extension header data starts in packet + * @nh: ipv6 header + * @ext_hdrs: flags are stored here + * + * OFPIEH12_UNREP is set if more than one of a given IPv6 extension header + * is unexpectedly encountered. (Two destination options headers may be + * expected and would not cause this bit to be set.) 
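+ *
+ * (These OFPIEH12_* values mirror the OpenFlow 1.2 IPv6 extension header
+ * pseudo-field bits; see enum ofp12_ipv6exthdr_flags in flow.h.)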
+ * + * OFPIEH12_UNSEQ is set if IPv6 extension headers were not in the order + * preferred (but not required) by RFC 2460: + * + * When more than one extension header is used in the same packet, it is + * recommended that those headers appear in the following order: + * IPv6 header + * Hop-by-Hop Options header + * Destination Options header + * Routing header + * Fragment header + * Authentication header + * Encapsulating Security Payload header + * Destination Options header + * upper-layer header + */ +static void get_ipv6_ext_hdrs(struct sk_buff *skb, struct ipv6hdr *nh, + u16 *ext_hdrs) +{ + u8 next_type = nh->nexthdr; + unsigned int start = skb_network_offset(skb) + sizeof(struct ipv6hdr); + int dest_options_header_count = 0; + + *ext_hdrs = 0; + + while (ipv6_ext_hdr(next_type)) { + struct ipv6_opt_hdr _hdr, *hp; + + switch (next_type) { + case IPPROTO_NONE: + *ext_hdrs |= OFPIEH12_NONEXT; + /* stop parsing */ + return; + + case IPPROTO_ESP: + if (*ext_hdrs & OFPIEH12_ESP) + *ext_hdrs |= OFPIEH12_UNREP; + if ((*ext_hdrs & ~(OFPIEH12_HOP | OFPIEH12_DEST | + OFPIEH12_ROUTER | IPPROTO_FRAGMENT | + OFPIEH12_AUTH | OFPIEH12_UNREP)) || + dest_options_header_count >= 2) { + *ext_hdrs |= OFPIEH12_UNSEQ; + } + *ext_hdrs |= OFPIEH12_ESP; + break; + + case IPPROTO_AH: + if (*ext_hdrs & OFPIEH12_AUTH) + *ext_hdrs |= OFPIEH12_UNREP; + if ((*ext_hdrs & + ~(OFPIEH12_HOP | OFPIEH12_DEST | OFPIEH12_ROUTER | + IPPROTO_FRAGMENT | OFPIEH12_UNREP)) || + dest_options_header_count >= 2) { + *ext_hdrs |= OFPIEH12_UNSEQ; + } + *ext_hdrs |= OFPIEH12_AUTH; + break; + + case IPPROTO_DSTOPTS: + if (dest_options_header_count == 0) { + if (*ext_hdrs & + ~(OFPIEH12_HOP | OFPIEH12_UNREP)) + *ext_hdrs |= OFPIEH12_UNSEQ; + *ext_hdrs |= OFPIEH12_DEST; + } else if (dest_options_header_count == 1) { + if (*ext_hdrs & + ~(OFPIEH12_HOP | OFPIEH12_DEST | + OFPIEH12_ROUTER | OFPIEH12_FRAG | + OFPIEH12_AUTH | OFPIEH12_ESP | + OFPIEH12_UNREP)) { + *ext_hdrs |= OFPIEH12_UNSEQ; + } + } else { + *ext_hdrs |= OFPIEH12_UNREP; + } + dest_options_header_count++; + break; + + case IPPROTO_FRAGMENT: + if (*ext_hdrs & OFPIEH12_FRAG) + *ext_hdrs |= OFPIEH12_UNREP; + if ((*ext_hdrs & ~(OFPIEH12_HOP | + OFPIEH12_DEST | + OFPIEH12_ROUTER | + OFPIEH12_UNREP)) || + dest_options_header_count >= 2) { + *ext_hdrs |= OFPIEH12_UNSEQ; + } + *ext_hdrs |= OFPIEH12_FRAG; + break; + + case IPPROTO_ROUTING: + if (*ext_hdrs & OFPIEH12_ROUTER) + *ext_hdrs |= OFPIEH12_UNREP; + if ((*ext_hdrs & ~(OFPIEH12_HOP | + OFPIEH12_DEST | + OFPIEH12_UNREP)) || + dest_options_header_count >= 2) { + *ext_hdrs |= OFPIEH12_UNSEQ; + } + *ext_hdrs |= OFPIEH12_ROUTER; + break; + + case IPPROTO_HOPOPTS: + if (*ext_hdrs & OFPIEH12_HOP) + *ext_hdrs |= OFPIEH12_UNREP; + /* OFPIEH12_HOP is set to 1 if a hop-by-hop IPv6 + * extension header is present as the first + * extension header in the packet. 
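+			 * (RFC 2460 requires the Hop-by-Hop Options header,
+			 * when present, to immediately follow the IPv6 header,
+			 * which is why any later occurrence is flagged as
+			 * out-of-sequence below.)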
+ */ + if (*ext_hdrs == 0) + *ext_hdrs |= OFPIEH12_HOP; + else + *ext_hdrs |= OFPIEH12_UNSEQ; + break; + + default: + return; + } + + hp = skb_header_pointer(skb, start, sizeof(_hdr), &_hdr); + if (!hp) + break; + next_type = hp->nexthdr; + start += ipv6_optlen(hp); + } +} + +static int parse_ipv6hdr(struct sk_buff *skb, struct sw_flow_key *key) +{ + unsigned short frag_off; + unsigned int payload_ofs = 0; + unsigned int nh_ofs = skb_network_offset(skb); + unsigned int nh_len; + struct ipv6hdr *nh; + int err, nexthdr, flags = 0; + + err = check_header(skb, nh_ofs + sizeof(*nh)); + if (unlikely(err)) + return err; + + nh = ipv6_hdr(skb); + + get_ipv6_ext_hdrs(skb, nh, &key->ipv6.exthdrs); + + key->ip.proto = NEXTHDR_NONE; + key->ip.tos = ipv6_get_dsfield(nh); + key->ip.ttl = nh->hop_limit; + key->ipv6.label = *(__be32 *)nh & htonl(IPV6_FLOWINFO_FLOWLABEL); + key->ipv6.addr.src = nh->saddr; + key->ipv6.addr.dst = nh->daddr; + + nexthdr = ipv6_find_hdr(skb, &payload_ofs, -1, &frag_off, &flags); + if (flags & IP6_FH_F_FRAG) { + if (frag_off) { + key->ip.frag = OVS_FRAG_TYPE_LATER; + key->ip.proto = NEXTHDR_FRAGMENT; + return 0; + } + key->ip.frag = OVS_FRAG_TYPE_FIRST; + } else { + key->ip.frag = OVS_FRAG_TYPE_NONE; + } + + /* Delayed handling of error in ipv6_find_hdr() as it + * always sets flags and frag_off to a valid value which may be + * used to set key->ip.frag above. + */ + if (unlikely(nexthdr < 0)) + return -EPROTO; + + nh_len = payload_ofs - nh_ofs; + skb_set_transport_header(skb, nh_ofs + nh_len); + key->ip.proto = nexthdr; + return nh_len; +} + +static bool icmp6hdr_ok(struct sk_buff *skb) +{ + return pskb_may_pull(skb, skb_transport_offset(skb) + + sizeof(struct icmp6hdr)); +} + +/** + * parse_vlan_tag - Parse vlan tag from vlan header. + * @skb: skb containing frame to parse + * @key_vh: pointer to parsed vlan tag + * @untag_vlan: should the vlan header be removed from the frame + * + * Return: ERROR on memory error. + * %0 if it encounters a non-vlan or incomplete packet. + * %1 after successfully parsing vlan tag. + */ +static int parse_vlan_tag(struct sk_buff *skb, struct vlan_head *key_vh, + bool untag_vlan) +{ + struct vlan_head *vh = (struct vlan_head *)skb->data; + + if (likely(!eth_type_vlan(vh->tpid))) + return 0; + + if (unlikely(skb->len < sizeof(struct vlan_head) + sizeof(__be16))) + return 0; + + if (unlikely(!pskb_may_pull(skb, sizeof(struct vlan_head) + + sizeof(__be16)))) + return -ENOMEM; + + vh = (struct vlan_head *)skb->data; + key_vh->tci = vh->tci | htons(VLAN_CFI_MASK); + key_vh->tpid = vh->tpid; + + if (unlikely(untag_vlan)) { + int offset = skb->data - skb_mac_header(skb); + u16 tci; + int err; + + __skb_push(skb, offset); + err = __skb_vlan_pop(skb, &tci); + __skb_pull(skb, offset); + if (err) + return err; + __vlan_hwaccel_put_tag(skb, key_vh->tpid, tci); + } else { + __skb_pull(skb, sizeof(struct vlan_head)); + } + return 1; +} + +static void clear_vlan(struct sw_flow_key *key) +{ + key->eth.vlan.tci = 0; + key->eth.vlan.tpid = 0; + key->eth.cvlan.tci = 0; + key->eth.cvlan.tpid = 0; +} + +static int parse_vlan(struct sk_buff *skb, struct sw_flow_key *key) +{ + int res; + + if (skb_vlan_tag_present(skb)) { + key->eth.vlan.tci = htons(skb->vlan_tci) | htons(VLAN_CFI_MASK); + key->eth.vlan.tpid = skb->vlan_proto; + } else { + /* Parse outer vlan tag in the non-accelerated case. */ + res = parse_vlan_tag(skb, &key->eth.vlan, true); + if (res <= 0) + return res; + } + + /* Parse inner vlan tag. 
*/ + res = parse_vlan_tag(skb, &key->eth.cvlan, false); + if (res <= 0) + return res; + + return 0; +} + +static __be16 parse_ethertype(struct sk_buff *skb) +{ + struct llc_snap_hdr { + u8 dsap; /* Always 0xAA */ + u8 ssap; /* Always 0xAA */ + u8 ctrl; + u8 oui[3]; + __be16 ethertype; + }; + struct llc_snap_hdr *llc; + __be16 proto; + + proto = *(__be16 *) skb->data; + __skb_pull(skb, sizeof(__be16)); + + if (eth_proto_is_802_3(proto)) + return proto; + + if (skb->len < sizeof(struct llc_snap_hdr)) + return htons(ETH_P_802_2); + + if (unlikely(!pskb_may_pull(skb, sizeof(struct llc_snap_hdr)))) + return htons(0); + + llc = (struct llc_snap_hdr *) skb->data; + if (llc->dsap != LLC_SAP_SNAP || + llc->ssap != LLC_SAP_SNAP || + (llc->oui[0] | llc->oui[1] | llc->oui[2]) != 0) + return htons(ETH_P_802_2); + + __skb_pull(skb, sizeof(struct llc_snap_hdr)); + + if (eth_proto_is_802_3(llc->ethertype)) + return llc->ethertype; + + return htons(ETH_P_802_2); +} + +static int parse_icmpv6(struct sk_buff *skb, struct sw_flow_key *key, + int nh_len) +{ + struct icmp6hdr *icmp = icmp6_hdr(skb); + + /* The ICMPv6 type and code fields use the 16-bit transport port + * fields, so we need to store them in 16-bit network byte order. + */ + key->tp.src = htons(icmp->icmp6_type); + key->tp.dst = htons(icmp->icmp6_code); + memset(&key->ipv6.nd, 0, sizeof(key->ipv6.nd)); + + if (icmp->icmp6_code == 0 && + (icmp->icmp6_type == NDISC_NEIGHBOUR_SOLICITATION || + icmp->icmp6_type == NDISC_NEIGHBOUR_ADVERTISEMENT)) { + int icmp_len = skb->len - skb_transport_offset(skb); + struct nd_msg *nd; + int offset; + + /* In order to process neighbor discovery options, we need the + * entire packet. + */ + if (unlikely(icmp_len < sizeof(*nd))) + return 0; + + if (unlikely(skb_linearize(skb))) + return -ENOMEM; + + nd = (struct nd_msg *)skb_transport_header(skb); + key->ipv6.nd.target = nd->target; + + icmp_len -= sizeof(*nd); + offset = 0; + while (icmp_len >= 8) { + struct nd_opt_hdr *nd_opt = + (struct nd_opt_hdr *)(nd->opt + offset); + int opt_len = nd_opt->nd_opt_len * 8; + + if (unlikely(!opt_len || opt_len > icmp_len)) + return 0; + + /* Store the link layer address if the appropriate + * option is provided. It is considered an error if + * the same link layer option is specified twice. 
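+			 * (A duplicate option causes the ND fields collected
+			 * so far to be cleared again; see the invalid: label
+			 * below.)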
+ */ + if (nd_opt->nd_opt_type == ND_OPT_SOURCE_LL_ADDR + && opt_len == 8) { + if (unlikely(!is_zero_ether_addr(key->ipv6.nd.sll))) + goto invalid; + ether_addr_copy(key->ipv6.nd.sll, + &nd->opt[offset+sizeof(*nd_opt)]); + } else if (nd_opt->nd_opt_type == ND_OPT_TARGET_LL_ADDR + && opt_len == 8) { + if (unlikely(!is_zero_ether_addr(key->ipv6.nd.tll))) + goto invalid; + ether_addr_copy(key->ipv6.nd.tll, + &nd->opt[offset+sizeof(*nd_opt)]); + } + + icmp_len -= opt_len; + offset += opt_len; + } + } + + return 0; + +invalid: + memset(&key->ipv6.nd.target, 0, sizeof(key->ipv6.nd.target)); + memset(key->ipv6.nd.sll, 0, sizeof(key->ipv6.nd.sll)); + memset(key->ipv6.nd.tll, 0, sizeof(key->ipv6.nd.tll)); + + return 0; +} + +static int parse_nsh(struct sk_buff *skb, struct sw_flow_key *key) +{ + struct nshhdr *nh; + unsigned int nh_ofs = skb_network_offset(skb); + u8 version, length; + int err; + + err = check_header(skb, nh_ofs + NSH_BASE_HDR_LEN); + if (unlikely(err)) + return err; + + nh = nsh_hdr(skb); + version = nsh_get_ver(nh); + length = nsh_hdr_len(nh); + + if (version != 0) + return -EINVAL; + + err = check_header(skb, nh_ofs + length); + if (unlikely(err)) + return err; + + nh = nsh_hdr(skb); + key->nsh.base.flags = nsh_get_flags(nh); + key->nsh.base.ttl = nsh_get_ttl(nh); + key->nsh.base.mdtype = nh->mdtype; + key->nsh.base.np = nh->np; + key->nsh.base.path_hdr = nh->path_hdr; + switch (key->nsh.base.mdtype) { + case NSH_M_TYPE1: + if (length != NSH_M_TYPE1_LEN) + return -EINVAL; + memcpy(key->nsh.context, nh->md1.context, + sizeof(nh->md1)); + break; + case NSH_M_TYPE2: + memset(key->nsh.context, 0, + sizeof(nh->md1)); + break; + default: + return -EINVAL; + } + + return 0; +} + +/** + * key_extract_l3l4 - extracts L3/L4 header information. + * @skb: sk_buff that contains the frame, with skb->data pointing to the + * L3 header + * @key: output flow key + * + * Return: %0 if successful, otherwise a negative errno value. + */ +static int key_extract_l3l4(struct sk_buff *skb, struct sw_flow_key *key) +{ + int error; + + /* Network layer. */ + if (key->eth.type == htons(ETH_P_IP)) { + struct iphdr *nh; + __be16 offset; + + error = check_iphdr(skb); + if (unlikely(error)) { + memset(&key->ip, 0, sizeof(key->ip)); + memset(&key->ipv4, 0, sizeof(key->ipv4)); + if (error == -EINVAL) { + skb->transport_header = skb->network_header; + error = 0; + } + return error; + } + + nh = ip_hdr(skb); + key->ipv4.addr.src = nh->saddr; + key->ipv4.addr.dst = nh->daddr; + + key->ip.proto = nh->protocol; + key->ip.tos = nh->tos; + key->ip.ttl = nh->ttl; + + offset = nh->frag_off & htons(IP_OFFSET); + if (offset) { + key->ip.frag = OVS_FRAG_TYPE_LATER; + memset(&key->tp, 0, sizeof(key->tp)); + return 0; + } + if (nh->frag_off & htons(IP_MF) || + skb_shinfo(skb)->gso_type & SKB_GSO_UDP) + key->ip.frag = OVS_FRAG_TYPE_FIRST; + else + key->ip.frag = OVS_FRAG_TYPE_NONE; + + /* Transport layer. 
*/ + if (key->ip.proto == IPPROTO_TCP) { + if (tcphdr_ok(skb)) { + struct tcphdr *tcp = tcp_hdr(skb); + key->tp.src = tcp->source; + key->tp.dst = tcp->dest; + key->tp.flags = TCP_FLAGS_BE16(tcp); + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + + } else if (key->ip.proto == IPPROTO_UDP) { + if (udphdr_ok(skb)) { + struct udphdr *udp = udp_hdr(skb); + key->tp.src = udp->source; + key->tp.dst = udp->dest; + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } else if (key->ip.proto == IPPROTO_SCTP) { + if (sctphdr_ok(skb)) { + struct sctphdr *sctp = sctp_hdr(skb); + key->tp.src = sctp->source; + key->tp.dst = sctp->dest; + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } else if (key->ip.proto == IPPROTO_ICMP) { + if (icmphdr_ok(skb)) { + struct icmphdr *icmp = icmp_hdr(skb); + /* The ICMP type and code fields use the 16-bit + * transport port fields, so we need to store + * them in 16-bit network byte order. */ + key->tp.src = htons(icmp->type); + key->tp.dst = htons(icmp->code); + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } + + } else if (key->eth.type == htons(ETH_P_ARP) || + key->eth.type == htons(ETH_P_RARP)) { + struct arp_eth_header *arp; + bool arp_available = arphdr_ok(skb); + + arp = (struct arp_eth_header *)skb_network_header(skb); + + if (arp_available && + arp->ar_hrd == htons(ARPHRD_ETHER) && + arp->ar_pro == htons(ETH_P_IP) && + arp->ar_hln == ETH_ALEN && + arp->ar_pln == 4) { + + /* We only match on the lower 8 bits of the opcode. */ + if (ntohs(arp->ar_op) <= 0xff) + key->ip.proto = ntohs(arp->ar_op); + else + key->ip.proto = 0; + + memcpy(&key->ipv4.addr.src, arp->ar_sip, sizeof(key->ipv4.addr.src)); + memcpy(&key->ipv4.addr.dst, arp->ar_tip, sizeof(key->ipv4.addr.dst)); + ether_addr_copy(key->ipv4.arp.sha, arp->ar_sha); + ether_addr_copy(key->ipv4.arp.tha, arp->ar_tha); + } else { + memset(&key->ip, 0, sizeof(key->ip)); + memset(&key->ipv4, 0, sizeof(key->ipv4)); + } + } else if (eth_p_mpls(key->eth.type)) { + u8 label_count = 1; + + memset(&key->mpls, 0, sizeof(key->mpls)); + skb_set_inner_network_header(skb, skb->mac_len); + while (1) { + __be32 lse; + + error = check_header(skb, skb->mac_len + + label_count * MPLS_HLEN); + if (unlikely(error)) + return 0; + + memcpy(&lse, skb_inner_network_header(skb), MPLS_HLEN); + + if (label_count <= MPLS_LABEL_DEPTH) + memcpy(&key->mpls.lse[label_count - 1], &lse, + MPLS_HLEN); + + skb_set_inner_network_header(skb, skb->mac_len + + label_count * MPLS_HLEN); + if (lse & htonl(MPLS_LS_S_MASK)) + break; + + label_count++; + } + if (label_count > MPLS_LABEL_DEPTH) + label_count = MPLS_LABEL_DEPTH; + + key->mpls.num_labels_mask = GENMASK(label_count - 1, 0); + } else if (key->eth.type == htons(ETH_P_IPV6)) { + int nh_len; /* IPv6 Header + Extensions */ + + nh_len = parse_ipv6hdr(skb, key); + if (unlikely(nh_len < 0)) { + switch (nh_len) { + case -EINVAL: + memset(&key->ip, 0, sizeof(key->ip)); + memset(&key->ipv6.addr, 0, sizeof(key->ipv6.addr)); + fallthrough; + case -EPROTO: + skb->transport_header = skb->network_header; + error = 0; + break; + default: + error = nh_len; + } + return error; + } + + if (key->ip.frag == OVS_FRAG_TYPE_LATER) { + memset(&key->tp, 0, sizeof(key->tp)); + return 0; + } + if (skb_shinfo(skb)->gso_type & SKB_GSO_UDP) + key->ip.frag = OVS_FRAG_TYPE_FIRST; + + /* Transport layer. 
*/ + if (key->ip.proto == NEXTHDR_TCP) { + if (tcphdr_ok(skb)) { + struct tcphdr *tcp = tcp_hdr(skb); + key->tp.src = tcp->source; + key->tp.dst = tcp->dest; + key->tp.flags = TCP_FLAGS_BE16(tcp); + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } else if (key->ip.proto == NEXTHDR_UDP) { + if (udphdr_ok(skb)) { + struct udphdr *udp = udp_hdr(skb); + key->tp.src = udp->source; + key->tp.dst = udp->dest; + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } else if (key->ip.proto == NEXTHDR_SCTP) { + if (sctphdr_ok(skb)) { + struct sctphdr *sctp = sctp_hdr(skb); + key->tp.src = sctp->source; + key->tp.dst = sctp->dest; + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } else if (key->ip.proto == NEXTHDR_ICMP) { + if (icmp6hdr_ok(skb)) { + error = parse_icmpv6(skb, key, nh_len); + if (error) + return error; + } else { + memset(&key->tp, 0, sizeof(key->tp)); + } + } + } else if (key->eth.type == htons(ETH_P_NSH)) { + error = parse_nsh(skb, key); + if (error) + return error; + } + return 0; +} + +/** + * key_extract - extracts a flow key from an Ethernet frame. + * @skb: sk_buff that contains the frame, with skb->data pointing to the + * Ethernet header + * @key: output flow key + * + * The caller must ensure that skb->len >= ETH_HLEN. + * + * Initializes @skb header fields as follows: + * + * - skb->mac_header: the L2 header. + * + * - skb->network_header: just past the L2 header, or just past the + * VLAN header, to the first byte of the L2 payload. + * + * - skb->transport_header: If key->eth.type is ETH_P_IP or ETH_P_IPV6 + * on output, then just past the IP header, if one is present and + * of a correct length, otherwise the same as skb->network_header. + * For other key->eth.type values it is left untouched. + * + * - skb->protocol: the type of the data starting at skb->network_header. + * Equals to key->eth.type. + * + * Return: %0 if successful, otherwise a negative errno value. + */ +static int key_extract(struct sk_buff *skb, struct sw_flow_key *key) +{ + struct ethhdr *eth; + + /* Flags are always used as part of stats */ + key->tp.flags = 0; + + skb_reset_mac_header(skb); + + /* Link layer. */ + clear_vlan(key); + if (ovs_key_mac_proto(key) == MAC_PROTO_NONE) { + if (unlikely(eth_type_vlan(skb->protocol))) + return -EINVAL; + + skb_reset_network_header(skb); + key->eth.type = skb->protocol; + } else { + eth = eth_hdr(skb); + ether_addr_copy(key->eth.src, eth->h_source); + ether_addr_copy(key->eth.dst, eth->h_dest); + + __skb_pull(skb, 2 * ETH_ALEN); + /* We are going to push all headers that we pull, so no need to + * update skb->csum here. + */ + + if (unlikely(parse_vlan(skb, key))) + return -ENOMEM; + + key->eth.type = parse_ethertype(skb); + if (unlikely(key->eth.type == htons(0))) + return -ENOMEM; + + /* Multiple tagged packets need to retain TPID to satisfy + * skb_vlan_pop(), which will later shift the ethertype into + * skb->protocol. + */ + if (key->eth.cvlan.tci & htons(VLAN_CFI_MASK)) + skb->protocol = key->eth.cvlan.tpid; + else + skb->protocol = key->eth.type; + + skb_reset_network_header(skb); + __skb_push(skb, skb->data - skb_mac_header(skb)); + } + + skb_reset_mac_len(skb); + + /* Fill out L3/L4 key info, if any */ + return key_extract_l3l4(skb, key); +} + +/* In the case of conntrack fragment handling it expects L3 headers, + * add a helper. 
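+ * (ovs_flow_key_update_l3l4() below re-extracts only the L3/L4 portion of an
+ * already populated flow key.)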
+ */
+int ovs_flow_key_update_l3l4(struct sk_buff *skb, struct sw_flow_key *key)
+{
+	return key_extract_l3l4(skb, key);
+}
+
+int ovs_flow_key_update(struct sk_buff *skb, struct sw_flow_key *key)
+{
+	int res;
+
+	res = key_extract(skb, key);
+	if (!res)
+		key->mac_proto &= ~SW_FLOW_KEY_INVALID;
+
+	return res;
+}
+
+static int key_extract_mac_proto(struct sk_buff *skb)
+{
+	switch (skb->dev->type) {
+	case ARPHRD_ETHER:
+		return MAC_PROTO_ETHERNET;
+	case ARPHRD_NONE:
+		if (skb->protocol == htons(ETH_P_TEB))
+			return MAC_PROTO_ETHERNET;
+		return MAC_PROTO_NONE;
+	}
+	WARN_ON_ONCE(1);
+	return -EINVAL;
+}
+
+int ovs_flow_key_extract(const struct ip_tunnel_info *tun_info,
+			 struct sk_buff *skb, struct sw_flow_key *key)
+{
+#if IS_ENABLED(CONFIG_NET_TC_SKB_EXT)
+	struct tc_skb_ext *tc_ext;
+#endif
+	bool post_ct = false, post_ct_snat = false, post_ct_dnat = false;
+	int res, err;
+	u16 zone = 0;
+
+	/* Extract metadata from packet. */
+	if (tun_info) {
+		key->tun_proto = ip_tunnel_info_af(tun_info);
+		memcpy(&key->tun_key, &tun_info->key, sizeof(key->tun_key));
+
+		if (tun_info->options_len) {
+			BUILD_BUG_ON((1 << (sizeof(tun_info->options_len) *
+						   8)) - 1
+					> sizeof(key->tun_opts));
+
+			ip_tunnel_info_opts_get(TUN_METADATA_OPTS(key, tun_info->options_len),
+						tun_info);
+			key->tun_opts_len = tun_info->options_len;
+		} else {
+			key->tun_opts_len = 0;
+		}
+	} else {
+		key->tun_proto = 0;
+		key->tun_opts_len = 0;
+		memset(&key->tun_key, 0, sizeof(key->tun_key));
+	}
+
+	key->phy.priority = skb->priority;
+	key->phy.in_port = OVS_CB(skb)->input_vport->port_no;
+	key->phy.skb_mark = skb->mark;
+	key->ovs_flow_hash = 0;
+	res = key_extract_mac_proto(skb);
+	if (res < 0)
+		return res;
+	key->mac_proto = res;
+
+#if IS_ENABLED(CONFIG_NET_TC_SKB_EXT)
+	if (tc_skb_ext_tc_enabled()) {
+		tc_ext = skb_ext_find(skb, TC_SKB_EXT);
+		key->recirc_id = tc_ext ? tc_ext->chain : 0;
+		OVS_CB(skb)->mru = tc_ext ? tc_ext->mru : 0;
+		post_ct = tc_ext ? tc_ext->post_ct : false;
+		post_ct_snat = post_ct ? tc_ext->post_ct_snat : false;
+		post_ct_dnat = post_ct ? tc_ext->post_ct_dnat : false;
+		zone = post_ct ? tc_ext->zone : 0;
+	} else {
+		key->recirc_id = 0;
+	}
+#else
+	key->recirc_id = 0;
+#endif
+
+	err = key_extract(skb, key);
+	if (!err) {
+		ovs_ct_fill_key(skb, key, post_ct);	/* Must be after key_extract(). */
+		if (post_ct) {
+			if (!skb_get_nfct(skb)) {
+				key->ct_zone = zone;
+			} else {
+				if (!post_ct_dnat)
+					key->ct_state &= ~OVS_CS_F_DST_NAT;
+				if (!post_ct_snat)
+					key->ct_state &= ~OVS_CS_F_SRC_NAT;
+			}
+		}
+	}
+	return err;
+}
+
+int ovs_flow_key_extract_userspace(struct net *net, const struct nlattr *attr,
+				   struct sk_buff *skb,
+				   struct sw_flow_key *key, bool log)
+{
+	const struct nlattr *a[OVS_KEY_ATTR_MAX + 1];
+	u64 attrs = 0;
+	int err;
+
+	err = parse_flow_nlattrs(attr, a, &attrs, log);
+	if (err)
+		return -EINVAL;
+
+	/* Extract metadata from netlink attributes. */
+	err = ovs_nla_get_flow_metadata(net, a, attrs, key, log);
+	if (err)
+		return err;
+
+	/* key_extract assumes that skb->protocol is set-up for
+	 * layer 3 packets which is the case for other callers,
+	 * in particular packets received from the network stack.
+	 * Here the correct value can be set from the metadata
+	 * extracted above.
+	 * For an L2 packet, the key eth type would be zero; the skb protocol
+	 * would be set to the correct value later during key extraction.
+ */ + + skb->protocol = key->eth.type; + err = key_extract(skb, key); + if (err) + return err; + + /* Check that we have conntrack original direction tuple metadata only + * for packets for which it makes sense. Otherwise the key may be + * corrupted due to overlapping key fields. + */ + if (attrs & (1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4) && + key->eth.type != htons(ETH_P_IP)) + return -EINVAL; + if (attrs & (1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6) && + (key->eth.type != htons(ETH_P_IPV6) || + sw_flow_key_is_nd(key))) + return -EINVAL; + + return 0; +} diff --git a/net/openvswitch/flow.h b/net/openvswitch/flow.h new file mode 100644 index 000000000..073ab73ff --- /dev/null +++ b/net/openvswitch/flow.h @@ -0,0 +1,298 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2017 Nicira, Inc. + */ + +#ifndef FLOW_H +#define FLOW_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct sk_buff; + +enum sw_flow_mac_proto { + MAC_PROTO_NONE = 0, + MAC_PROTO_ETHERNET, +}; +#define SW_FLOW_KEY_INVALID 0x80 +#define MPLS_LABEL_DEPTH 3 + +/* Bit definitions for IPv6 Extension Header pseudo-field. */ +enum ofp12_ipv6exthdr_flags { + OFPIEH12_NONEXT = 1 << 0, /* "No next header" encountered. */ + OFPIEH12_ESP = 1 << 1, /* Encrypted Sec Payload header present. */ + OFPIEH12_AUTH = 1 << 2, /* Authentication header present. */ + OFPIEH12_DEST = 1 << 3, /* 1 or 2 dest headers present. */ + OFPIEH12_FRAG = 1 << 4, /* Fragment header present. */ + OFPIEH12_ROUTER = 1 << 5, /* Router header present. */ + OFPIEH12_HOP = 1 << 6, /* Hop-by-hop header present. */ + OFPIEH12_UNREP = 1 << 7, /* Unexpected repeats encountered. */ + OFPIEH12_UNSEQ = 1 << 8 /* Unexpected sequencing encountered. */ +}; + +/* Store options at the end of the array if they are less than the + * maximum size. This allows us to get the benefits of variable length + * matching for small options. + */ +#define TUN_METADATA_OFFSET(opt_len) \ + (sizeof_field(struct sw_flow_key, tun_opts) - opt_len) +#define TUN_METADATA_OPTS(flow_key, opt_len) \ + ((void *)((flow_key)->tun_opts + TUN_METADATA_OFFSET(opt_len))) + +struct ovs_tunnel_info { + struct metadata_dst *tun_dst; +}; + +struct vlan_head { + __be16 tpid; /* Vlan type. Generally 802.1q or 802.1ad.*/ + __be16 tci; /* 0 if no VLAN, VLAN_CFI_MASK set otherwise. */ +}; + +#define OVS_SW_FLOW_KEY_METADATA_SIZE \ + (offsetof(struct sw_flow_key, recirc_id) + \ + sizeof_field(struct sw_flow_key, recirc_id)) + +struct ovs_key_nsh { + struct ovs_nsh_key_base base; + __be32 context[NSH_MD1_CONTEXT_SIZE]; +}; + +struct sw_flow_key { + u8 tun_opts[IP_TUNNEL_OPTS_MAX]; + u8 tun_opts_len; + struct ip_tunnel_key tun_key; /* Encapsulating tunnel key. */ + struct { + u32 priority; /* Packet QoS priority. */ + u32 skb_mark; /* SKB mark. */ + u16 in_port; /* Input switch port (or DP_MAX_PORTS). */ + } __packed phy; /* Safe when right after 'tun_key'. */ + u8 mac_proto; /* MAC layer protocol (e.g. Ethernet). */ + u8 tun_proto; /* Protocol of encapsulating tunnel. */ + u32 ovs_flow_hash; /* Datapath computed hash value. */ + u32 recirc_id; /* Recirculation ID. */ + struct { + u8 src[ETH_ALEN]; /* Ethernet source address. */ + u8 dst[ETH_ALEN]; /* Ethernet destination address. */ + struct vlan_head vlan; + struct vlan_head cvlan; + __be16 type; /* Ethernet frame type. */ + } eth; + /* Filling a hole of two bytes. 
*/ + u8 ct_state; + u8 ct_orig_proto; /* CT original direction tuple IP + * protocol. + */ + union { + struct { + u8 proto; /* IP protocol or lower 8 bits of ARP opcode. */ + u8 tos; /* IP ToS. */ + u8 ttl; /* IP TTL/hop limit. */ + u8 frag; /* One of OVS_FRAG_TYPE_*. */ + } ip; + }; + u16 ct_zone; /* Conntrack zone. */ + struct { + __be16 src; /* TCP/UDP/SCTP source port. */ + __be16 dst; /* TCP/UDP/SCTP destination port. */ + __be16 flags; /* TCP flags. */ + } tp; + union { + struct { + struct { + __be32 src; /* IP source address. */ + __be32 dst; /* IP destination address. */ + } addr; + union { + struct { + __be32 src; + __be32 dst; + } ct_orig; /* Conntrack original direction fields. */ + struct { + u8 sha[ETH_ALEN]; /* ARP source hardware address. */ + u8 tha[ETH_ALEN]; /* ARP target hardware address. */ + } arp; + }; + } ipv4; + struct { + struct { + struct in6_addr src; /* IPv6 source address. */ + struct in6_addr dst; /* IPv6 destination address. */ + } addr; + __be32 label; /* IPv6 flow label. */ + u16 exthdrs; /* IPv6 extension header flags */ + union { + struct { + struct in6_addr src; + struct in6_addr dst; + } ct_orig; /* Conntrack original direction fields. */ + struct { + struct in6_addr target; /* ND target address. */ + u8 sll[ETH_ALEN]; /* ND source link layer address. */ + u8 tll[ETH_ALEN]; /* ND target link layer address. */ + } nd; + }; + } ipv6; + struct { + u32 num_labels_mask; /* labels present bitmap of effective length MPLS_LABEL_DEPTH */ + __be32 lse[MPLS_LABEL_DEPTH]; /* label stack entry */ + } mpls; + + struct ovs_key_nsh nsh; /* network service header */ + }; + struct { + /* Connection tracking fields not packed above. */ + struct { + __be16 src; /* CT orig tuple tp src port. */ + __be16 dst; /* CT orig tuple tp dst port. */ + } orig_tp; + u32 mark; + struct ovs_key_ct_labels labels; + } ct; + +} __aligned(BITS_PER_LONG/8); /* Ensure that we can do comparisons as longs. */ + +static inline bool sw_flow_key_is_nd(const struct sw_flow_key *key) +{ + return key->eth.type == htons(ETH_P_IPV6) && + key->ip.proto == NEXTHDR_ICMP && + key->tp.dst == 0 && + (key->tp.src == htons(NDISC_NEIGHBOUR_SOLICITATION) || + key->tp.src == htons(NDISC_NEIGHBOUR_ADVERTISEMENT)); +} + +struct sw_flow_key_range { + unsigned short int start; + unsigned short int end; +}; + +struct sw_flow_mask { + int ref_count; + struct rcu_head rcu; + struct sw_flow_key_range range; + struct sw_flow_key key; +}; + +struct sw_flow_match { + struct sw_flow_key *key; + struct sw_flow_key_range range; + struct sw_flow_mask *mask; +}; + +#define MAX_UFID_LENGTH 16 /* 128 bits */ + +struct sw_flow_id { + u32 ufid_len; + union { + u32 ufid[MAX_UFID_LENGTH / 4]; + struct sw_flow_key *unmasked_key; + }; +}; + +struct sw_flow_actions { + struct rcu_head rcu; + size_t orig_len; /* From flow_cmd_new netlink actions size */ + u32 actions_len; + struct nlattr actions[]; +}; + +struct sw_flow_stats { + u64 packet_count; /* Number of packets matched. */ + u64 byte_count; /* Number of bytes matched. */ + unsigned long used; /* Last used time (in jiffies). */ + spinlock_t lock; /* Lock for atomic stats update. */ + __be16 tcp_flags; /* Union of seen TCP flags. */ +}; + +struct sw_flow { + struct rcu_head rcu; + struct { + struct hlist_node node[2]; + u32 hash; + } flow_table, ufid_table; + int stats_last_writer; /* CPU id of the last writer on + * 'stats[0]'. 
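+					 * A value of -1 is treated as "no
+					 * writer yet" by
+					 * ovs_flow_stats_update().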
+ */ + struct sw_flow_key key; + struct sw_flow_id id; + struct cpumask cpu_used_mask; + struct sw_flow_mask *mask; + struct sw_flow_actions __rcu *sf_acts; + struct sw_flow_stats __rcu *stats[]; /* One for each CPU. First one + * is allocated at flow creation time, + * the rest are allocated on demand + * while holding the 'stats[0].lock'. + */ +}; + +struct arp_eth_header { + __be16 ar_hrd; /* format of hardware address */ + __be16 ar_pro; /* format of protocol address */ + unsigned char ar_hln; /* length of hardware address */ + unsigned char ar_pln; /* length of protocol address */ + __be16 ar_op; /* ARP opcode (command) */ + + /* Ethernet+IPv4 specific members. */ + unsigned char ar_sha[ETH_ALEN]; /* sender hardware address */ + unsigned char ar_sip[4]; /* sender IP address */ + unsigned char ar_tha[ETH_ALEN]; /* target hardware address */ + unsigned char ar_tip[4]; /* target IP address */ +} __packed; + +static inline u8 ovs_key_mac_proto(const struct sw_flow_key *key) +{ + return key->mac_proto & ~SW_FLOW_KEY_INVALID; +} + +static inline u16 __ovs_mac_header_len(u8 mac_proto) +{ + return mac_proto == MAC_PROTO_ETHERNET ? ETH_HLEN : 0; +} + +static inline u16 ovs_mac_header_len(const struct sw_flow_key *key) +{ + return __ovs_mac_header_len(ovs_key_mac_proto(key)); +} + +static inline bool ovs_identifier_is_ufid(const struct sw_flow_id *sfid) +{ + return sfid->ufid_len; +} + +static inline bool ovs_identifier_is_key(const struct sw_flow_id *sfid) +{ + return !ovs_identifier_is_ufid(sfid); +} + +void ovs_flow_stats_update(struct sw_flow *, __be16 tcp_flags, + const struct sk_buff *); +void ovs_flow_stats_get(const struct sw_flow *, struct ovs_flow_stats *, + unsigned long *used, __be16 *tcp_flags); +void ovs_flow_stats_clear(struct sw_flow *); +u64 ovs_flow_used_time(unsigned long flow_jiffies); + +int ovs_flow_key_update(struct sk_buff *skb, struct sw_flow_key *key); +int ovs_flow_key_update_l3l4(struct sk_buff *skb, struct sw_flow_key *key); +int ovs_flow_key_extract(const struct ip_tunnel_info *tun_info, + struct sk_buff *skb, + struct sw_flow_key *key); +/* Extract key from packet coming from userspace. */ +int ovs_flow_key_extract_userspace(struct net *net, const struct nlattr *attr, + struct sk_buff *skb, + struct sw_flow_key *key, bool log); + +#endif /* flow.h */ diff --git a/net/openvswitch/flow_netlink.c b/net/openvswitch/flow_netlink.c new file mode 100644 index 000000000..ead5418c1 --- /dev/null +++ b/net/openvswitch/flow_netlink.c @@ -0,0 +1,3783 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2017 Nicira, Inc. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include "flow.h" +#include "datapath.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "flow_netlink.h" + +struct ovs_len_tbl { + int len; + const struct ovs_len_tbl *next; +}; + +#define OVS_ATTR_NESTED -1 +#define OVS_ATTR_VARIABLE -2 + +static bool actions_may_change_flow(const struct nlattr *actions) +{ + struct nlattr *nla; + int rem; + + nla_for_each_nested(nla, actions, rem) { + u16 action = nla_type(nla); + + switch (action) { + case OVS_ACTION_ATTR_OUTPUT: + case OVS_ACTION_ATTR_RECIRC: + case OVS_ACTION_ATTR_TRUNC: + case OVS_ACTION_ATTR_USERSPACE: + break; + + case OVS_ACTION_ATTR_CT: + case OVS_ACTION_ATTR_CT_CLEAR: + case OVS_ACTION_ATTR_HASH: + case OVS_ACTION_ATTR_POP_ETH: + case OVS_ACTION_ATTR_POP_MPLS: + case OVS_ACTION_ATTR_POP_NSH: + case OVS_ACTION_ATTR_POP_VLAN: + case OVS_ACTION_ATTR_PUSH_ETH: + case OVS_ACTION_ATTR_PUSH_MPLS: + case OVS_ACTION_ATTR_PUSH_NSH: + case OVS_ACTION_ATTR_PUSH_VLAN: + case OVS_ACTION_ATTR_SAMPLE: + case OVS_ACTION_ATTR_SET: + case OVS_ACTION_ATTR_SET_MASKED: + case OVS_ACTION_ATTR_METER: + case OVS_ACTION_ATTR_CHECK_PKT_LEN: + case OVS_ACTION_ATTR_ADD_MPLS: + case OVS_ACTION_ATTR_DEC_TTL: + default: + return true; + } + } + return false; +} + +static void update_range(struct sw_flow_match *match, + size_t offset, size_t size, bool is_mask) +{ + struct sw_flow_key_range *range; + size_t start = rounddown(offset, sizeof(long)); + size_t end = roundup(offset + size, sizeof(long)); + + if (!is_mask) + range = &match->range; + else + range = &match->mask->range; + + if (range->start == range->end) { + range->start = start; + range->end = end; + return; + } + + if (range->start > start) + range->start = start; + + if (range->end < end) + range->end = end; +} + +#define SW_FLOW_KEY_PUT(match, field, value, is_mask) \ + do { \ + update_range(match, offsetof(struct sw_flow_key, field), \ + sizeof((match)->key->field), is_mask); \ + if (is_mask) \ + (match)->mask->key.field = value; \ + else \ + (match)->key->field = value; \ + } while (0) + +#define SW_FLOW_KEY_MEMCPY_OFFSET(match, offset, value_p, len, is_mask) \ + do { \ + update_range(match, offset, len, is_mask); \ + if (is_mask) \ + memcpy((u8 *)&(match)->mask->key + offset, value_p, \ + len); \ + else \ + memcpy((u8 *)(match)->key + offset, value_p, len); \ + } while (0) + +#define SW_FLOW_KEY_MEMCPY(match, field, value_p, len, is_mask) \ + SW_FLOW_KEY_MEMCPY_OFFSET(match, offsetof(struct sw_flow_key, field), \ + value_p, len, is_mask) + +#define SW_FLOW_KEY_MEMSET_FIELD(match, field, value, is_mask) \ + do { \ + update_range(match, offsetof(struct sw_flow_key, field), \ + sizeof((match)->key->field), is_mask); \ + if (is_mask) \ + memset((u8 *)&(match)->mask->key.field, value, \ + sizeof((match)->mask->key.field)); \ + else \ + memset((u8 *)&(match)->key->field, value, \ + sizeof((match)->key->field)); \ + } while (0) + +static bool match_validate(const struct sw_flow_match *match, + u64 key_attrs, u64 mask_attrs, bool log) +{ + u64 key_expected = 0; + u64 mask_allowed = key_attrs; /* At most allow all key attributes */ + + /* The following mask attributes allowed only if they + * pass the validation tests. 
*/ + mask_allowed &= ~((1 << OVS_KEY_ATTR_IPV4) + | (1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4) + | (1 << OVS_KEY_ATTR_IPV6) + | (1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6) + | (1 << OVS_KEY_ATTR_TCP) + | (1 << OVS_KEY_ATTR_TCP_FLAGS) + | (1 << OVS_KEY_ATTR_UDP) + | (1 << OVS_KEY_ATTR_SCTP) + | (1 << OVS_KEY_ATTR_ICMP) + | (1 << OVS_KEY_ATTR_ICMPV6) + | (1 << OVS_KEY_ATTR_ARP) + | (1 << OVS_KEY_ATTR_ND) + | (1 << OVS_KEY_ATTR_MPLS) + | (1 << OVS_KEY_ATTR_NSH)); + + /* Always allowed mask fields. */ + mask_allowed |= ((1 << OVS_KEY_ATTR_TUNNEL) + | (1 << OVS_KEY_ATTR_IN_PORT) + | (1 << OVS_KEY_ATTR_ETHERTYPE)); + + /* Check key attributes. */ + if (match->key->eth.type == htons(ETH_P_ARP) + || match->key->eth.type == htons(ETH_P_RARP)) { + key_expected |= 1 << OVS_KEY_ATTR_ARP; + if (match->mask && (match->mask->key.eth.type == htons(0xffff))) + mask_allowed |= 1 << OVS_KEY_ATTR_ARP; + } + + if (eth_p_mpls(match->key->eth.type)) { + key_expected |= 1 << OVS_KEY_ATTR_MPLS; + if (match->mask && (match->mask->key.eth.type == htons(0xffff))) + mask_allowed |= 1 << OVS_KEY_ATTR_MPLS; + } + + if (match->key->eth.type == htons(ETH_P_IP)) { + key_expected |= 1 << OVS_KEY_ATTR_IPV4; + if (match->mask && match->mask->key.eth.type == htons(0xffff)) { + mask_allowed |= 1 << OVS_KEY_ATTR_IPV4; + mask_allowed |= 1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4; + } + + if (match->key->ip.frag != OVS_FRAG_TYPE_LATER) { + if (match->key->ip.proto == IPPROTO_UDP) { + key_expected |= 1 << OVS_KEY_ATTR_UDP; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_UDP; + } + + if (match->key->ip.proto == IPPROTO_SCTP) { + key_expected |= 1 << OVS_KEY_ATTR_SCTP; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_SCTP; + } + + if (match->key->ip.proto == IPPROTO_TCP) { + key_expected |= 1 << OVS_KEY_ATTR_TCP; + key_expected |= 1 << OVS_KEY_ATTR_TCP_FLAGS; + if (match->mask && (match->mask->key.ip.proto == 0xff)) { + mask_allowed |= 1 << OVS_KEY_ATTR_TCP; + mask_allowed |= 1 << OVS_KEY_ATTR_TCP_FLAGS; + } + } + + if (match->key->ip.proto == IPPROTO_ICMP) { + key_expected |= 1 << OVS_KEY_ATTR_ICMP; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_ICMP; + } + } + } + + if (match->key->eth.type == htons(ETH_P_IPV6)) { + key_expected |= 1 << OVS_KEY_ATTR_IPV6; + if (match->mask && match->mask->key.eth.type == htons(0xffff)) { + mask_allowed |= 1 << OVS_KEY_ATTR_IPV6; + mask_allowed |= 1 << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6; + } + + if (match->key->ip.frag != OVS_FRAG_TYPE_LATER) { + if (match->key->ip.proto == IPPROTO_UDP) { + key_expected |= 1 << OVS_KEY_ATTR_UDP; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_UDP; + } + + if (match->key->ip.proto == IPPROTO_SCTP) { + key_expected |= 1 << OVS_KEY_ATTR_SCTP; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_SCTP; + } + + if (match->key->ip.proto == IPPROTO_TCP) { + key_expected |= 1 << OVS_KEY_ATTR_TCP; + key_expected |= 1 << OVS_KEY_ATTR_TCP_FLAGS; + if (match->mask && (match->mask->key.ip.proto == 0xff)) { + mask_allowed |= 1 << OVS_KEY_ATTR_TCP; + mask_allowed |= 1 << OVS_KEY_ATTR_TCP_FLAGS; + } + } + + if (match->key->ip.proto == IPPROTO_ICMPV6) { + key_expected |= 1 << OVS_KEY_ATTR_ICMPV6; + if (match->mask && (match->mask->key.ip.proto == 0xff)) + mask_allowed |= 1 << OVS_KEY_ATTR_ICMPV6; + + if (match->key->tp.src == + htons(NDISC_NEIGHBOUR_SOLICITATION) || + 
match->key->tp.src == htons(NDISC_NEIGHBOUR_ADVERTISEMENT)) { + key_expected |= 1 << OVS_KEY_ATTR_ND; + /* Original direction conntrack tuple + * uses the same space as the ND fields + * in the key, so both are not allowed + * at the same time. + */ + mask_allowed &= ~(1ULL << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6); + if (match->mask && (match->mask->key.tp.src == htons(0xff))) + mask_allowed |= 1 << OVS_KEY_ATTR_ND; + } + } + } + } + + if (match->key->eth.type == htons(ETH_P_NSH)) { + key_expected |= 1 << OVS_KEY_ATTR_NSH; + if (match->mask && + match->mask->key.eth.type == htons(0xffff)) { + mask_allowed |= 1 << OVS_KEY_ATTR_NSH; + } + } + + if ((key_attrs & key_expected) != key_expected) { + /* Key attributes check failed. */ + OVS_NLERR(log, "Missing key (keys=%llx, expected=%llx)", + (unsigned long long)key_attrs, + (unsigned long long)key_expected); + return false; + } + + if ((mask_attrs & mask_allowed) != mask_attrs) { + /* Mask attributes check failed. */ + OVS_NLERR(log, "Unexpected mask (mask=%llx, allowed=%llx)", + (unsigned long long)mask_attrs, + (unsigned long long)mask_allowed); + return false; + } + + return true; +} + +size_t ovs_tun_key_attr_size(void) +{ + /* Whenever adding new OVS_TUNNEL_KEY_ FIELDS, we should consider + * updating this function. + */ + return nla_total_size_64bit(8) /* OVS_TUNNEL_KEY_ATTR_ID */ + + nla_total_size(16) /* OVS_TUNNEL_KEY_ATTR_IPV[46]_SRC */ + + nla_total_size(16) /* OVS_TUNNEL_KEY_ATTR_IPV[46]_DST */ + + nla_total_size(1) /* OVS_TUNNEL_KEY_ATTR_TOS */ + + nla_total_size(1) /* OVS_TUNNEL_KEY_ATTR_TTL */ + + nla_total_size(0) /* OVS_TUNNEL_KEY_ATTR_DONT_FRAGMENT */ + + nla_total_size(0) /* OVS_TUNNEL_KEY_ATTR_CSUM */ + + nla_total_size(0) /* OVS_TUNNEL_KEY_ATTR_OAM */ + + nla_total_size(256) /* OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS */ + /* OVS_TUNNEL_KEY_ATTR_VXLAN_OPTS and + * OVS_TUNNEL_KEY_ATTR_ERSPAN_OPTS is mutually exclusive with + * OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS and covered by it. + */ + + nla_total_size(2) /* OVS_TUNNEL_KEY_ATTR_TP_SRC */ + + nla_total_size(2); /* OVS_TUNNEL_KEY_ATTR_TP_DST */ +} + +static size_t ovs_nsh_key_attr_size(void) +{ + /* Whenever adding new OVS_NSH_KEY_ FIELDS, we should consider + * updating this function. + */ + return nla_total_size(NSH_BASE_HDR_LEN) /* OVS_NSH_KEY_ATTR_BASE */ + /* OVS_NSH_KEY_ATTR_MD1 and OVS_NSH_KEY_ATTR_MD2 are + * mutually exclusive, so the bigger one can cover + * the small one. + */ + + nla_total_size(NSH_CTX_HDRS_MAX_LEN); +} + +size_t ovs_key_attr_size(void) +{ + /* Whenever adding new OVS_KEY_ FIELDS, we should consider + * updating this function. 
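+	 * The BUILD_BUG_ON() below is meant to trip whenever OVS_KEY_ATTR_MAX
+	 * changes, so that this size estimate gets revisited.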
+ */ + BUILD_BUG_ON(OVS_KEY_ATTR_MAX != 32); + + return nla_total_size(4) /* OVS_KEY_ATTR_PRIORITY */ + + nla_total_size(0) /* OVS_KEY_ATTR_TUNNEL */ + + ovs_tun_key_attr_size() + + nla_total_size(4) /* OVS_KEY_ATTR_IN_PORT */ + + nla_total_size(4) /* OVS_KEY_ATTR_SKB_MARK */ + + nla_total_size(4) /* OVS_KEY_ATTR_DP_HASH */ + + nla_total_size(4) /* OVS_KEY_ATTR_RECIRC_ID */ + + nla_total_size(4) /* OVS_KEY_ATTR_CT_STATE */ + + nla_total_size(2) /* OVS_KEY_ATTR_CT_ZONE */ + + nla_total_size(4) /* OVS_KEY_ATTR_CT_MARK */ + + nla_total_size(16) /* OVS_KEY_ATTR_CT_LABELS */ + + nla_total_size(40) /* OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6 */ + + nla_total_size(0) /* OVS_KEY_ATTR_NSH */ + + ovs_nsh_key_attr_size() + + nla_total_size(12) /* OVS_KEY_ATTR_ETHERNET */ + + nla_total_size(2) /* OVS_KEY_ATTR_ETHERTYPE */ + + nla_total_size(4) /* OVS_KEY_ATTR_VLAN */ + + nla_total_size(0) /* OVS_KEY_ATTR_ENCAP */ + + nla_total_size(2) /* OVS_KEY_ATTR_ETHERTYPE */ + + nla_total_size(40) /* OVS_KEY_ATTR_IPV6 */ + + nla_total_size(2) /* OVS_KEY_ATTR_ICMPV6 */ + + nla_total_size(28) /* OVS_KEY_ATTR_ND */ + + nla_total_size(2); /* OVS_KEY_ATTR_IPV6_EXTHDRS */ +} + +static const struct ovs_len_tbl ovs_vxlan_ext_key_lens[OVS_VXLAN_EXT_MAX + 1] = { + [OVS_VXLAN_EXT_GBP] = { .len = sizeof(u32) }, +}; + +static const struct ovs_len_tbl ovs_tunnel_key_lens[OVS_TUNNEL_KEY_ATTR_MAX + 1] = { + [OVS_TUNNEL_KEY_ATTR_ID] = { .len = sizeof(u64) }, + [OVS_TUNNEL_KEY_ATTR_IPV4_SRC] = { .len = sizeof(u32) }, + [OVS_TUNNEL_KEY_ATTR_IPV4_DST] = { .len = sizeof(u32) }, + [OVS_TUNNEL_KEY_ATTR_TOS] = { .len = 1 }, + [OVS_TUNNEL_KEY_ATTR_TTL] = { .len = 1 }, + [OVS_TUNNEL_KEY_ATTR_DONT_FRAGMENT] = { .len = 0 }, + [OVS_TUNNEL_KEY_ATTR_CSUM] = { .len = 0 }, + [OVS_TUNNEL_KEY_ATTR_TP_SRC] = { .len = sizeof(u16) }, + [OVS_TUNNEL_KEY_ATTR_TP_DST] = { .len = sizeof(u16) }, + [OVS_TUNNEL_KEY_ATTR_OAM] = { .len = 0 }, + [OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS] = { .len = OVS_ATTR_VARIABLE }, + [OVS_TUNNEL_KEY_ATTR_VXLAN_OPTS] = { .len = OVS_ATTR_NESTED, + .next = ovs_vxlan_ext_key_lens }, + [OVS_TUNNEL_KEY_ATTR_IPV6_SRC] = { .len = sizeof(struct in6_addr) }, + [OVS_TUNNEL_KEY_ATTR_IPV6_DST] = { .len = sizeof(struct in6_addr) }, + [OVS_TUNNEL_KEY_ATTR_ERSPAN_OPTS] = { .len = OVS_ATTR_VARIABLE }, + [OVS_TUNNEL_KEY_ATTR_IPV4_INFO_BRIDGE] = { .len = 0 }, +}; + +static const struct ovs_len_tbl +ovs_nsh_key_attr_lens[OVS_NSH_KEY_ATTR_MAX + 1] = { + [OVS_NSH_KEY_ATTR_BASE] = { .len = sizeof(struct ovs_nsh_key_base) }, + [OVS_NSH_KEY_ATTR_MD1] = { .len = sizeof(struct ovs_nsh_key_md1) }, + [OVS_NSH_KEY_ATTR_MD2] = { .len = OVS_ATTR_VARIABLE }, +}; + +/* The size of the argument for each %OVS_KEY_ATTR_* Netlink attribute. 
*/ +static const struct ovs_len_tbl ovs_key_lens[OVS_KEY_ATTR_MAX + 1] = { + [OVS_KEY_ATTR_ENCAP] = { .len = OVS_ATTR_NESTED }, + [OVS_KEY_ATTR_PRIORITY] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_IN_PORT] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_SKB_MARK] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_ETHERNET] = { .len = sizeof(struct ovs_key_ethernet) }, + [OVS_KEY_ATTR_VLAN] = { .len = sizeof(__be16) }, + [OVS_KEY_ATTR_ETHERTYPE] = { .len = sizeof(__be16) }, + [OVS_KEY_ATTR_IPV4] = { .len = sizeof(struct ovs_key_ipv4) }, + [OVS_KEY_ATTR_IPV6] = { .len = sizeof(struct ovs_key_ipv6) }, + [OVS_KEY_ATTR_TCP] = { .len = sizeof(struct ovs_key_tcp) }, + [OVS_KEY_ATTR_TCP_FLAGS] = { .len = sizeof(__be16) }, + [OVS_KEY_ATTR_UDP] = { .len = sizeof(struct ovs_key_udp) }, + [OVS_KEY_ATTR_SCTP] = { .len = sizeof(struct ovs_key_sctp) }, + [OVS_KEY_ATTR_ICMP] = { .len = sizeof(struct ovs_key_icmp) }, + [OVS_KEY_ATTR_ICMPV6] = { .len = sizeof(struct ovs_key_icmpv6) }, + [OVS_KEY_ATTR_ARP] = { .len = sizeof(struct ovs_key_arp) }, + [OVS_KEY_ATTR_ND] = { .len = sizeof(struct ovs_key_nd) }, + [OVS_KEY_ATTR_RECIRC_ID] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_DP_HASH] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_TUNNEL] = { .len = OVS_ATTR_NESTED, + .next = ovs_tunnel_key_lens, }, + [OVS_KEY_ATTR_MPLS] = { .len = OVS_ATTR_VARIABLE }, + [OVS_KEY_ATTR_CT_STATE] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_CT_ZONE] = { .len = sizeof(u16) }, + [OVS_KEY_ATTR_CT_MARK] = { .len = sizeof(u32) }, + [OVS_KEY_ATTR_CT_LABELS] = { .len = sizeof(struct ovs_key_ct_labels) }, + [OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4] = { + .len = sizeof(struct ovs_key_ct_tuple_ipv4) }, + [OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6] = { + .len = sizeof(struct ovs_key_ct_tuple_ipv6) }, + [OVS_KEY_ATTR_NSH] = { .len = OVS_ATTR_NESTED, + .next = ovs_nsh_key_attr_lens, }, + [OVS_KEY_ATTR_IPV6_EXTHDRS] = { + .len = sizeof(struct ovs_key_ipv6_exthdrs) }, +}; + +static bool check_attr_len(unsigned int attr_len, unsigned int expected_len) +{ + return expected_len == attr_len || + expected_len == OVS_ATTR_NESTED || + expected_len == OVS_ATTR_VARIABLE; +} + +static bool is_all_zero(const u8 *fp, size_t size) +{ + int i; + + if (!fp) + return false; + + for (i = 0; i < size; i++) + if (fp[i]) + return false; + + return true; +} + +static int __parse_flow_nlattrs(const struct nlattr *attr, + const struct nlattr *a[], + u64 *attrsp, bool log, bool nz) +{ + const struct nlattr *nla; + u64 attrs; + int rem; + + attrs = *attrsp; + nla_for_each_nested(nla, attr, rem) { + u16 type = nla_type(nla); + int expected_len; + + if (type > OVS_KEY_ATTR_MAX) { + OVS_NLERR(log, "Key type %d is out of range max %d", + type, OVS_KEY_ATTR_MAX); + return -EINVAL; + } + + if (type == OVS_KEY_ATTR_PACKET_TYPE || + type == OVS_KEY_ATTR_ND_EXTENSIONS || + type == OVS_KEY_ATTR_TUNNEL_INFO) { + OVS_NLERR(log, "Key type %d is not supported", type); + return -EINVAL; + } + + if (attrs & (1ULL << type)) { + OVS_NLERR(log, "Duplicate key (type %d).", type); + return -EINVAL; + } + + expected_len = ovs_key_lens[type].len; + if (!check_attr_len(nla_len(nla), expected_len)) { + OVS_NLERR(log, "Key %d has unexpected len %d expected %d", + type, nla_len(nla), expected_len); + return -EINVAL; + } + + if (!nz || !is_all_zero(nla_data(nla), nla_len(nla))) { + attrs |= 1ULL << type; + a[type] = nla; + } + } + if (rem) { + OVS_NLERR(log, "Message has %d unknown bytes.", rem); + return -EINVAL; + } + + *attrsp = attrs; + return 0; +} + +static int parse_flow_mask_nlattrs(const struct nlattr *attr, + const struct 
nlattr *a[], u64 *attrsp, + bool log) +{ + return __parse_flow_nlattrs(attr, a, attrsp, log, true); +} + +int parse_flow_nlattrs(const struct nlattr *attr, const struct nlattr *a[], + u64 *attrsp, bool log) +{ + return __parse_flow_nlattrs(attr, a, attrsp, log, false); +} + +static int genev_tun_opt_from_nlattr(const struct nlattr *a, + struct sw_flow_match *match, bool is_mask, + bool log) +{ + unsigned long opt_key_offset; + + if (nla_len(a) > sizeof(match->key->tun_opts)) { + OVS_NLERR(log, "Geneve option length err (len %d, max %zu).", + nla_len(a), sizeof(match->key->tun_opts)); + return -EINVAL; + } + + if (nla_len(a) % 4 != 0) { + OVS_NLERR(log, "Geneve opt len %d is not a multiple of 4.", + nla_len(a)); + return -EINVAL; + } + + /* We need to record the length of the options passed + * down, otherwise packets with the same format but + * additional options will be silently matched. + */ + if (!is_mask) { + SW_FLOW_KEY_PUT(match, tun_opts_len, nla_len(a), + false); + } else { + /* This is somewhat unusual because it looks at + * both the key and mask while parsing the + * attributes (and by extension assumes the key + * is parsed first). Normally, we would verify + * that each is the correct length and that the + * attributes line up in the validate function. + * However, that is difficult because this is + * variable length and we won't have the + * information later. + */ + if (match->key->tun_opts_len != nla_len(a)) { + OVS_NLERR(log, "Geneve option len %d != mask len %d", + match->key->tun_opts_len, nla_len(a)); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, tun_opts_len, 0xff, true); + } + + opt_key_offset = TUN_METADATA_OFFSET(nla_len(a)); + SW_FLOW_KEY_MEMCPY_OFFSET(match, opt_key_offset, nla_data(a), + nla_len(a), is_mask); + return 0; +} + +static int vxlan_tun_opt_from_nlattr(const struct nlattr *attr, + struct sw_flow_match *match, bool is_mask, + bool log) +{ + struct nlattr *a; + int rem; + unsigned long opt_key_offset; + struct vxlan_metadata opts; + + BUILD_BUG_ON(sizeof(opts) > sizeof(match->key->tun_opts)); + + memset(&opts, 0, sizeof(opts)); + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + + if (type > OVS_VXLAN_EXT_MAX) { + OVS_NLERR(log, "VXLAN extension %d out of range max %d", + type, OVS_VXLAN_EXT_MAX); + return -EINVAL; + } + + if (!check_attr_len(nla_len(a), + ovs_vxlan_ext_key_lens[type].len)) { + OVS_NLERR(log, "VXLAN extension %d has unexpected len %d expected %d", + type, nla_len(a), + ovs_vxlan_ext_key_lens[type].len); + return -EINVAL; + } + + switch (type) { + case OVS_VXLAN_EXT_GBP: + opts.gbp = nla_get_u32(a); + break; + default: + OVS_NLERR(log, "Unknown VXLAN extension attribute %d", + type); + return -EINVAL; + } + } + if (rem) { + OVS_NLERR(log, "VXLAN extension message has %d unknown bytes.", + rem); + return -EINVAL; + } + + if (!is_mask) + SW_FLOW_KEY_PUT(match, tun_opts_len, sizeof(opts), false); + else + SW_FLOW_KEY_PUT(match, tun_opts_len, 0xff, true); + + opt_key_offset = TUN_METADATA_OFFSET(sizeof(opts)); + SW_FLOW_KEY_MEMCPY_OFFSET(match, opt_key_offset, &opts, sizeof(opts), + is_mask); + return 0; +} + +static int erspan_tun_opt_from_nlattr(const struct nlattr *a, + struct sw_flow_match *match, bool is_mask, + bool log) +{ + unsigned long opt_key_offset; + + BUILD_BUG_ON(sizeof(struct erspan_metadata) > + sizeof(match->key->tun_opts)); + + if (nla_len(a) > sizeof(match->key->tun_opts)) { + OVS_NLERR(log, "ERSPAN option length err (len %d, max %zu).", + nla_len(a), sizeof(match->key->tun_opts)); + return -EINVAL; + } 
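+ /* As with the Geneve options above, record the options length in the flow key (here fixed to sizeof(struct erspan_metadata)) and copy the raw option bytes into the tun_opts area below. */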
+ + if (!is_mask) + SW_FLOW_KEY_PUT(match, tun_opts_len, + sizeof(struct erspan_metadata), false); + else + SW_FLOW_KEY_PUT(match, tun_opts_len, 0xff, true); + + opt_key_offset = TUN_METADATA_OFFSET(nla_len(a)); + SW_FLOW_KEY_MEMCPY_OFFSET(match, opt_key_offset, nla_data(a), + nla_len(a), is_mask); + return 0; +} + +static int ip_tun_from_nlattr(const struct nlattr *attr, + struct sw_flow_match *match, bool is_mask, + bool log) +{ + bool ttl = false, ipv4 = false, ipv6 = false; + bool info_bridge_mode = false; + __be16 tun_flags = 0; + int opts_type = 0; + struct nlattr *a; + int rem; + + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + int err; + + if (type > OVS_TUNNEL_KEY_ATTR_MAX) { + OVS_NLERR(log, "Tunnel attr %d out of range max %d", + type, OVS_TUNNEL_KEY_ATTR_MAX); + return -EINVAL; + } + + if (!check_attr_len(nla_len(a), + ovs_tunnel_key_lens[type].len)) { + OVS_NLERR(log, "Tunnel attr %d has unexpected len %d expected %d", + type, nla_len(a), ovs_tunnel_key_lens[type].len); + return -EINVAL; + } + + switch (type) { + case OVS_TUNNEL_KEY_ATTR_ID: + SW_FLOW_KEY_PUT(match, tun_key.tun_id, + nla_get_be64(a), is_mask); + tun_flags |= TUNNEL_KEY; + break; + case OVS_TUNNEL_KEY_ATTR_IPV4_SRC: + SW_FLOW_KEY_PUT(match, tun_key.u.ipv4.src, + nla_get_in_addr(a), is_mask); + ipv4 = true; + break; + case OVS_TUNNEL_KEY_ATTR_IPV4_DST: + SW_FLOW_KEY_PUT(match, tun_key.u.ipv4.dst, + nla_get_in_addr(a), is_mask); + ipv4 = true; + break; + case OVS_TUNNEL_KEY_ATTR_IPV6_SRC: + SW_FLOW_KEY_PUT(match, tun_key.u.ipv6.src, + nla_get_in6_addr(a), is_mask); + ipv6 = true; + break; + case OVS_TUNNEL_KEY_ATTR_IPV6_DST: + SW_FLOW_KEY_PUT(match, tun_key.u.ipv6.dst, + nla_get_in6_addr(a), is_mask); + ipv6 = true; + break; + case OVS_TUNNEL_KEY_ATTR_TOS: + SW_FLOW_KEY_PUT(match, tun_key.tos, + nla_get_u8(a), is_mask); + break; + case OVS_TUNNEL_KEY_ATTR_TTL: + SW_FLOW_KEY_PUT(match, tun_key.ttl, + nla_get_u8(a), is_mask); + ttl = true; + break; + case OVS_TUNNEL_KEY_ATTR_DONT_FRAGMENT: + tun_flags |= TUNNEL_DONT_FRAGMENT; + break; + case OVS_TUNNEL_KEY_ATTR_CSUM: + tun_flags |= TUNNEL_CSUM; + break; + case OVS_TUNNEL_KEY_ATTR_TP_SRC: + SW_FLOW_KEY_PUT(match, tun_key.tp_src, + nla_get_be16(a), is_mask); + break; + case OVS_TUNNEL_KEY_ATTR_TP_DST: + SW_FLOW_KEY_PUT(match, tun_key.tp_dst, + nla_get_be16(a), is_mask); + break; + case OVS_TUNNEL_KEY_ATTR_OAM: + tun_flags |= TUNNEL_OAM; + break; + case OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS: + if (opts_type) { + OVS_NLERR(log, "Multiple metadata blocks provided"); + return -EINVAL; + } + + err = genev_tun_opt_from_nlattr(a, match, is_mask, log); + if (err) + return err; + + tun_flags |= TUNNEL_GENEVE_OPT; + opts_type = type; + break; + case OVS_TUNNEL_KEY_ATTR_VXLAN_OPTS: + if (opts_type) { + OVS_NLERR(log, "Multiple metadata blocks provided"); + return -EINVAL; + } + + err = vxlan_tun_opt_from_nlattr(a, match, is_mask, log); + if (err) + return err; + + tun_flags |= TUNNEL_VXLAN_OPT; + opts_type = type; + break; + case OVS_TUNNEL_KEY_ATTR_PAD: + break; + case OVS_TUNNEL_KEY_ATTR_ERSPAN_OPTS: + if (opts_type) { + OVS_NLERR(log, "Multiple metadata blocks provided"); + return -EINVAL; + } + + err = erspan_tun_opt_from_nlattr(a, match, is_mask, + log); + if (err) + return err; + + tun_flags |= TUNNEL_ERSPAN_OPT; + opts_type = type; + break; + case OVS_TUNNEL_KEY_ATTR_IPV4_INFO_BRIDGE: + info_bridge_mode = true; + ipv4 = true; + break; + default: + OVS_NLERR(log, "Unknown IP tunnel attribute %d", + type); + return -EINVAL; + } + } + + SW_FLOW_KEY_PUT(match, 
tun_key.tun_flags, tun_flags, is_mask); + if (is_mask) + SW_FLOW_KEY_MEMSET_FIELD(match, tun_proto, 0xff, true); + else + SW_FLOW_KEY_PUT(match, tun_proto, ipv6 ? AF_INET6 : AF_INET, + false); + + if (rem > 0) { + OVS_NLERR(log, "IP tunnel attribute has %d unknown bytes.", + rem); + return -EINVAL; + } + + if (ipv4 && ipv6) { + OVS_NLERR(log, "Mixed IPv4 and IPv6 tunnel attributes"); + return -EINVAL; + } + + if (!is_mask) { + if (!ipv4 && !ipv6) { + OVS_NLERR(log, "IP tunnel dst address not specified"); + return -EINVAL; + } + if (ipv4) { + if (info_bridge_mode) { + if (match->key->tun_key.u.ipv4.src || + match->key->tun_key.u.ipv4.dst || + match->key->tun_key.tp_src || + match->key->tun_key.tp_dst || + match->key->tun_key.ttl || + match->key->tun_key.tos || + tun_flags & ~TUNNEL_KEY) { + OVS_NLERR(log, "IPv4 tun info is not correct"); + return -EINVAL; + } + } else if (!match->key->tun_key.u.ipv4.dst) { + OVS_NLERR(log, "IPv4 tunnel dst address is zero"); + return -EINVAL; + } + } + if (ipv6 && ipv6_addr_any(&match->key->tun_key.u.ipv6.dst)) { + OVS_NLERR(log, "IPv6 tunnel dst address is zero"); + return -EINVAL; + } + + if (!ttl && !info_bridge_mode) { + OVS_NLERR(log, "IP tunnel TTL not specified."); + return -EINVAL; + } + } + + return opts_type; +} + +static int vxlan_opt_to_nlattr(struct sk_buff *skb, + const void *tun_opts, int swkey_tun_opts_len) +{ + const struct vxlan_metadata *opts = tun_opts; + struct nlattr *nla; + + nla = nla_nest_start_noflag(skb, OVS_TUNNEL_KEY_ATTR_VXLAN_OPTS); + if (!nla) + return -EMSGSIZE; + + if (nla_put_u32(skb, OVS_VXLAN_EXT_GBP, opts->gbp) < 0) + return -EMSGSIZE; + + nla_nest_end(skb, nla); + return 0; +} + +static int __ip_tun_to_nlattr(struct sk_buff *skb, + const struct ip_tunnel_key *output, + const void *tun_opts, int swkey_tun_opts_len, + unsigned short tun_proto, u8 mode) +{ + if (output->tun_flags & TUNNEL_KEY && + nla_put_be64(skb, OVS_TUNNEL_KEY_ATTR_ID, output->tun_id, + OVS_TUNNEL_KEY_ATTR_PAD)) + return -EMSGSIZE; + + if (mode & IP_TUNNEL_INFO_BRIDGE) + return nla_put_flag(skb, OVS_TUNNEL_KEY_ATTR_IPV4_INFO_BRIDGE) + ? 
-EMSGSIZE : 0; + + switch (tun_proto) { + case AF_INET: + if (output->u.ipv4.src && + nla_put_in_addr(skb, OVS_TUNNEL_KEY_ATTR_IPV4_SRC, + output->u.ipv4.src)) + return -EMSGSIZE; + if (output->u.ipv4.dst && + nla_put_in_addr(skb, OVS_TUNNEL_KEY_ATTR_IPV4_DST, + output->u.ipv4.dst)) + return -EMSGSIZE; + break; + case AF_INET6: + if (!ipv6_addr_any(&output->u.ipv6.src) && + nla_put_in6_addr(skb, OVS_TUNNEL_KEY_ATTR_IPV6_SRC, + &output->u.ipv6.src)) + return -EMSGSIZE; + if (!ipv6_addr_any(&output->u.ipv6.dst) && + nla_put_in6_addr(skb, OVS_TUNNEL_KEY_ATTR_IPV6_DST, + &output->u.ipv6.dst)) + return -EMSGSIZE; + break; + } + if (output->tos && + nla_put_u8(skb, OVS_TUNNEL_KEY_ATTR_TOS, output->tos)) + return -EMSGSIZE; + if (nla_put_u8(skb, OVS_TUNNEL_KEY_ATTR_TTL, output->ttl)) + return -EMSGSIZE; + if ((output->tun_flags & TUNNEL_DONT_FRAGMENT) && + nla_put_flag(skb, OVS_TUNNEL_KEY_ATTR_DONT_FRAGMENT)) + return -EMSGSIZE; + if ((output->tun_flags & TUNNEL_CSUM) && + nla_put_flag(skb, OVS_TUNNEL_KEY_ATTR_CSUM)) + return -EMSGSIZE; + if (output->tp_src && + nla_put_be16(skb, OVS_TUNNEL_KEY_ATTR_TP_SRC, output->tp_src)) + return -EMSGSIZE; + if (output->tp_dst && + nla_put_be16(skb, OVS_TUNNEL_KEY_ATTR_TP_DST, output->tp_dst)) + return -EMSGSIZE; + if ((output->tun_flags & TUNNEL_OAM) && + nla_put_flag(skb, OVS_TUNNEL_KEY_ATTR_OAM)) + return -EMSGSIZE; + if (swkey_tun_opts_len) { + if (output->tun_flags & TUNNEL_GENEVE_OPT && + nla_put(skb, OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS, + swkey_tun_opts_len, tun_opts)) + return -EMSGSIZE; + else if (output->tun_flags & TUNNEL_VXLAN_OPT && + vxlan_opt_to_nlattr(skb, tun_opts, swkey_tun_opts_len)) + return -EMSGSIZE; + else if (output->tun_flags & TUNNEL_ERSPAN_OPT && + nla_put(skb, OVS_TUNNEL_KEY_ATTR_ERSPAN_OPTS, + swkey_tun_opts_len, tun_opts)) + return -EMSGSIZE; + } + + return 0; +} + +static int ip_tun_to_nlattr(struct sk_buff *skb, + const struct ip_tunnel_key *output, + const void *tun_opts, int swkey_tun_opts_len, + unsigned short tun_proto, u8 mode) +{ + struct nlattr *nla; + int err; + + nla = nla_nest_start_noflag(skb, OVS_KEY_ATTR_TUNNEL); + if (!nla) + return -EMSGSIZE; + + err = __ip_tun_to_nlattr(skb, output, tun_opts, swkey_tun_opts_len, + tun_proto, mode); + if (err) + return err; + + nla_nest_end(skb, nla); + return 0; +} + +int ovs_nla_put_tunnel_info(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + return __ip_tun_to_nlattr(skb, &tun_info->key, + ip_tunnel_info_opts(tun_info), + tun_info->options_len, + ip_tunnel_info_af(tun_info), tun_info->mode); +} + +static int encode_vlan_from_nlattrs(struct sw_flow_match *match, + const struct nlattr *a[], + bool is_mask, bool inner) +{ + __be16 tci = 0; + __be16 tpid = 0; + + if (a[OVS_KEY_ATTR_VLAN]) + tci = nla_get_be16(a[OVS_KEY_ATTR_VLAN]); + + if (a[OVS_KEY_ATTR_ETHERTYPE]) + tpid = nla_get_be16(a[OVS_KEY_ATTR_ETHERTYPE]); + + if (likely(!inner)) { + SW_FLOW_KEY_PUT(match, eth.vlan.tpid, tpid, is_mask); + SW_FLOW_KEY_PUT(match, eth.vlan.tci, tci, is_mask); + } else { + SW_FLOW_KEY_PUT(match, eth.cvlan.tpid, tpid, is_mask); + SW_FLOW_KEY_PUT(match, eth.cvlan.tci, tci, is_mask); + } + return 0; +} + +static int validate_vlan_from_nlattrs(const struct sw_flow_match *match, + u64 key_attrs, bool inner, + const struct nlattr **a, bool log) +{ + __be16 tci = 0; + + if (!((key_attrs & (1 << OVS_KEY_ATTR_ETHERNET)) && + (key_attrs & (1 << OVS_KEY_ATTR_ETHERTYPE)) && + eth_type_vlan(nla_get_be16(a[OVS_KEY_ATTR_ETHERTYPE])))) { + /* Not a VLAN. 
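+ * Either the Ethernet/EtherType attributes are missing or the EtherType is not a VLAN TPID, so there is nothing further to validate here.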
*/ + return 0; + } + + if (!((key_attrs & (1 << OVS_KEY_ATTR_VLAN)) && + (key_attrs & (1 << OVS_KEY_ATTR_ENCAP)))) { + OVS_NLERR(log, "Invalid %s frame", (inner) ? "C-VLAN" : "VLAN"); + return -EINVAL; + } + + if (a[OVS_KEY_ATTR_VLAN]) + tci = nla_get_be16(a[OVS_KEY_ATTR_VLAN]); + + if (!(tci & htons(VLAN_CFI_MASK))) { + if (tci) { + OVS_NLERR(log, "%s TCI does not have VLAN_CFI_MASK bit set.", + (inner) ? "C-VLAN" : "VLAN"); + return -EINVAL; + } else if (nla_len(a[OVS_KEY_ATTR_ENCAP])) { + /* Corner case for truncated VLAN header. */ + OVS_NLERR(log, "Truncated %s header has non-zero encap attribute.", + (inner) ? "C-VLAN" : "VLAN"); + return -EINVAL; + } + } + + return 1; +} + +static int validate_vlan_mask_from_nlattrs(const struct sw_flow_match *match, + u64 key_attrs, bool inner, + const struct nlattr **a, bool log) +{ + __be16 tci = 0; + __be16 tpid = 0; + bool encap_valid = !!(match->key->eth.vlan.tci & + htons(VLAN_CFI_MASK)); + bool i_encap_valid = !!(match->key->eth.cvlan.tci & + htons(VLAN_CFI_MASK)); + + if (!(key_attrs & (1 << OVS_KEY_ATTR_ENCAP))) { + /* Not a VLAN. */ + return 0; + } + + if ((!inner && !encap_valid) || (inner && !i_encap_valid)) { + OVS_NLERR(log, "Encap mask attribute is set for non-%s frame.", + (inner) ? "C-VLAN" : "VLAN"); + return -EINVAL; + } + + if (a[OVS_KEY_ATTR_VLAN]) + tci = nla_get_be16(a[OVS_KEY_ATTR_VLAN]); + + if (a[OVS_KEY_ATTR_ETHERTYPE]) + tpid = nla_get_be16(a[OVS_KEY_ATTR_ETHERTYPE]); + + if (tpid != htons(0xffff)) { + OVS_NLERR(log, "Must have an exact match on %s TPID (mask=%x).", + (inner) ? "C-VLAN" : "VLAN", ntohs(tpid)); + return -EINVAL; + } + if (!(tci & htons(VLAN_CFI_MASK))) { + OVS_NLERR(log, "%s TCI mask does not have exact match for VLAN_CFI_MASK bit.", + (inner) ? "C-VLAN" : "VLAN"); + return -EINVAL; + } + + return 1; +} + +static int __parse_vlan_from_nlattrs(struct sw_flow_match *match, + u64 *key_attrs, bool inner, + const struct nlattr **a, bool is_mask, + bool log) +{ + int err; + const struct nlattr *encap; + + if (!is_mask) + err = validate_vlan_from_nlattrs(match, *key_attrs, inner, + a, log); + else + err = validate_vlan_mask_from_nlattrs(match, *key_attrs, inner, + a, log); + if (err <= 0) + return err; + + err = encode_vlan_from_nlattrs(match, a, is_mask, inner); + if (err) + return err; + + *key_attrs &= ~(1 << OVS_KEY_ATTR_ENCAP); + *key_attrs &= ~(1 << OVS_KEY_ATTR_VLAN); + *key_attrs &= ~(1 << OVS_KEY_ATTR_ETHERTYPE); + + encap = a[OVS_KEY_ATTR_ENCAP]; + + if (!is_mask) + err = parse_flow_nlattrs(encap, a, key_attrs, log); + else + err = parse_flow_mask_nlattrs(encap, a, key_attrs, log); + + return err; +} + +static int parse_vlan_from_nlattrs(struct sw_flow_match *match, + u64 *key_attrs, const struct nlattr **a, + bool is_mask, bool log) +{ + int err; + bool encap_valid = false; + + err = __parse_vlan_from_nlattrs(match, key_attrs, false, a, + is_mask, log); + if (err) + return err; + + encap_valid = !!(match->key->eth.vlan.tci & htons(VLAN_CFI_MASK)); + if (encap_valid) { + err = __parse_vlan_from_nlattrs(match, key_attrs, true, a, + is_mask, log); + if (err) + return err; + } + + return 0; +} + +static int parse_eth_type_from_nlattrs(struct sw_flow_match *match, + u64 *attrs, const struct nlattr **a, + bool is_mask, bool log) +{ + __be16 eth_type; + + eth_type = nla_get_be16(a[OVS_KEY_ATTR_ETHERTYPE]); + if (is_mask) { + /* Always exact match EtherType. 
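+ * For a mask the value supplied by userspace is ignored and replaced with 0xffff, so the EtherType itself can never be wildcarded.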
*/ + eth_type = htons(0xffff); + } else if (!eth_proto_is_802_3(eth_type)) { + OVS_NLERR(log, "EtherType %x is less than min %x", + ntohs(eth_type), ETH_P_802_3_MIN); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, eth.type, eth_type, is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_ETHERTYPE); + return 0; +} + +static int metadata_from_nlattrs(struct net *net, struct sw_flow_match *match, + u64 *attrs, const struct nlattr **a, + bool is_mask, bool log) +{ + u8 mac_proto = MAC_PROTO_ETHERNET; + + if (*attrs & (1 << OVS_KEY_ATTR_DP_HASH)) { + u32 hash_val = nla_get_u32(a[OVS_KEY_ATTR_DP_HASH]); + + SW_FLOW_KEY_PUT(match, ovs_flow_hash, hash_val, is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_DP_HASH); + } + + if (*attrs & (1 << OVS_KEY_ATTR_RECIRC_ID)) { + u32 recirc_id = nla_get_u32(a[OVS_KEY_ATTR_RECIRC_ID]); + + SW_FLOW_KEY_PUT(match, recirc_id, recirc_id, is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_RECIRC_ID); + } + + if (*attrs & (1 << OVS_KEY_ATTR_PRIORITY)) { + SW_FLOW_KEY_PUT(match, phy.priority, + nla_get_u32(a[OVS_KEY_ATTR_PRIORITY]), is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_PRIORITY); + } + + if (*attrs & (1 << OVS_KEY_ATTR_IN_PORT)) { + u32 in_port = nla_get_u32(a[OVS_KEY_ATTR_IN_PORT]); + + if (is_mask) { + in_port = 0xffffffff; /* Always exact match in_port. */ + } else if (in_port >= DP_MAX_PORTS) { + OVS_NLERR(log, "Port %d exceeds max allowable %d", + in_port, DP_MAX_PORTS); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, phy.in_port, in_port, is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_IN_PORT); + } else if (!is_mask) { + SW_FLOW_KEY_PUT(match, phy.in_port, DP_MAX_PORTS, is_mask); + } + + if (*attrs & (1 << OVS_KEY_ATTR_SKB_MARK)) { + uint32_t mark = nla_get_u32(a[OVS_KEY_ATTR_SKB_MARK]); + + SW_FLOW_KEY_PUT(match, phy.skb_mark, mark, is_mask); + *attrs &= ~(1 << OVS_KEY_ATTR_SKB_MARK); + } + if (*attrs & (1 << OVS_KEY_ATTR_TUNNEL)) { + if (ip_tun_from_nlattr(a[OVS_KEY_ATTR_TUNNEL], match, + is_mask, log) < 0) + return -EINVAL; + *attrs &= ~(1 << OVS_KEY_ATTR_TUNNEL); + } + + if (*attrs & (1 << OVS_KEY_ATTR_CT_STATE) && + ovs_ct_verify(net, OVS_KEY_ATTR_CT_STATE)) { + u32 ct_state = nla_get_u32(a[OVS_KEY_ATTR_CT_STATE]); + + if (ct_state & ~CT_SUPPORTED_MASK) { + OVS_NLERR(log, "ct_state flags %08x unsupported", + ct_state); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, ct_state, ct_state, is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_STATE); + } + if (*attrs & (1 << OVS_KEY_ATTR_CT_ZONE) && + ovs_ct_verify(net, OVS_KEY_ATTR_CT_ZONE)) { + u16 ct_zone = nla_get_u16(a[OVS_KEY_ATTR_CT_ZONE]); + + SW_FLOW_KEY_PUT(match, ct_zone, ct_zone, is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_ZONE); + } + if (*attrs & (1 << OVS_KEY_ATTR_CT_MARK) && + ovs_ct_verify(net, OVS_KEY_ATTR_CT_MARK)) { + u32 mark = nla_get_u32(a[OVS_KEY_ATTR_CT_MARK]); + + SW_FLOW_KEY_PUT(match, ct.mark, mark, is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_MARK); + } + if (*attrs & (1 << OVS_KEY_ATTR_CT_LABELS) && + ovs_ct_verify(net, OVS_KEY_ATTR_CT_LABELS)) { + const struct ovs_key_ct_labels *cl; + + cl = nla_data(a[OVS_KEY_ATTR_CT_LABELS]); + SW_FLOW_KEY_MEMCPY(match, ct.labels, cl->ct_labels, + sizeof(*cl), is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_LABELS); + } + if (*attrs & (1ULL << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4)) { + const struct ovs_key_ct_tuple_ipv4 *ct; + + ct = nla_data(a[OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4]); + + SW_FLOW_KEY_PUT(match, ipv4.ct_orig.src, ct->ipv4_src, is_mask); + SW_FLOW_KEY_PUT(match, ipv4.ct_orig.dst, ct->ipv4_dst, is_mask); + SW_FLOW_KEY_PUT(match, ct.orig_tp.src, ct->src_port, 
is_mask); + SW_FLOW_KEY_PUT(match, ct.orig_tp.dst, ct->dst_port, is_mask); + SW_FLOW_KEY_PUT(match, ct_orig_proto, ct->ipv4_proto, is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4); + } + if (*attrs & (1ULL << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6)) { + const struct ovs_key_ct_tuple_ipv6 *ct; + + ct = nla_data(a[OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6]); + + SW_FLOW_KEY_MEMCPY(match, ipv6.ct_orig.src, &ct->ipv6_src, + sizeof(match->key->ipv6.ct_orig.src), + is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv6.ct_orig.dst, &ct->ipv6_dst, + sizeof(match->key->ipv6.ct_orig.dst), + is_mask); + SW_FLOW_KEY_PUT(match, ct.orig_tp.src, ct->src_port, is_mask); + SW_FLOW_KEY_PUT(match, ct.orig_tp.dst, ct->dst_port, is_mask); + SW_FLOW_KEY_PUT(match, ct_orig_proto, ct->ipv6_proto, is_mask); + *attrs &= ~(1ULL << OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6); + } + + /* For layer 3 packets the Ethernet type is provided + * and treated as metadata but no MAC addresses are provided. + */ + if (!(*attrs & (1ULL << OVS_KEY_ATTR_ETHERNET)) && + (*attrs & (1ULL << OVS_KEY_ATTR_ETHERTYPE))) + mac_proto = MAC_PROTO_NONE; + + /* Always exact match mac_proto */ + SW_FLOW_KEY_PUT(match, mac_proto, is_mask ? 0xff : mac_proto, is_mask); + + if (mac_proto == MAC_PROTO_NONE) + return parse_eth_type_from_nlattrs(match, attrs, a, is_mask, + log); + + return 0; +} + +int nsh_hdr_from_nlattr(const struct nlattr *attr, + struct nshhdr *nh, size_t size) +{ + struct nlattr *a; + int rem; + u8 flags = 0; + u8 ttl = 0; + int mdlen = 0; + + /* validate_nsh has checked this, so we needn't do a duplicate check here + */ + if (size < NSH_BASE_HDR_LEN) + return -ENOBUFS; + + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + + switch (type) { + case OVS_NSH_KEY_ATTR_BASE: { + const struct ovs_nsh_key_base *base = nla_data(a); + + flags = base->flags; + ttl = base->ttl; + nh->np = base->np; + nh->mdtype = base->mdtype; + nh->path_hdr = base->path_hdr; + break; + } + case OVS_NSH_KEY_ATTR_MD1: + mdlen = nla_len(a); + if (mdlen > size - NSH_BASE_HDR_LEN) + return -ENOBUFS; + memcpy(&nh->md1, nla_data(a), mdlen); + break; + + case OVS_NSH_KEY_ATTR_MD2: + mdlen = nla_len(a); + if (mdlen > size - NSH_BASE_HDR_LEN) + return -ENOBUFS; + memcpy(&nh->md2, nla_data(a), mdlen); + break; + + default: + return -EINVAL; + } + } + + /* nsh header length = NSH_BASE_HDR_LEN + mdlen */ + nh->ver_flags_ttl_len = 0; + nsh_set_flags_ttl_len(nh, flags, ttl, NSH_BASE_HDR_LEN + mdlen); + + return 0; +} + +int nsh_key_from_nlattr(const struct nlattr *attr, + struct ovs_key_nsh *nsh, struct ovs_key_nsh *nsh_mask) +{ + struct nlattr *a; + int rem; + + /* validate_nsh has checked this, so we needn't do a duplicate check here + */ + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + + switch (type) { + case OVS_NSH_KEY_ATTR_BASE: { + const struct ovs_nsh_key_base *base = nla_data(a); + const struct ovs_nsh_key_base *base_mask = base + 1; + + nsh->base = *base; + nsh_mask->base = *base_mask; + break; + } + case OVS_NSH_KEY_ATTR_MD1: { + const struct ovs_nsh_key_md1 *md1 = nla_data(a); + const struct ovs_nsh_key_md1 *md1_mask = md1 + 1; + + memcpy(nsh->context, md1->context, sizeof(*md1)); + memcpy(nsh_mask->context, md1_mask->context, + sizeof(*md1_mask)); + break; + } + case OVS_NSH_KEY_ATTR_MD2: + /* Not supported yet */ + return -ENOTSUPP; + default: + return -EINVAL; + } + } + + return 0; +} + +static int nsh_key_put_from_nlattr(const struct nlattr *attr, + struct sw_flow_match *match, bool is_mask, + bool is_push_nsh, bool log) +{ + struct nlattr *a; + int
rem; + bool has_base = false; + bool has_md1 = false; + bool has_md2 = false; + u8 mdtype = 0; + int mdlen = 0; + + if (WARN_ON(is_push_nsh && is_mask)) + return -EINVAL; + + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + int i; + + if (type > OVS_NSH_KEY_ATTR_MAX) { + OVS_NLERR(log, "nsh attr %d is out of range max %d", + type, OVS_NSH_KEY_ATTR_MAX); + return -EINVAL; + } + + if (!check_attr_len(nla_len(a), + ovs_nsh_key_attr_lens[type].len)) { + OVS_NLERR( + log, + "nsh attr %d has unexpected len %d expected %d", + type, + nla_len(a), + ovs_nsh_key_attr_lens[type].len + ); + return -EINVAL; + } + + switch (type) { + case OVS_NSH_KEY_ATTR_BASE: { + const struct ovs_nsh_key_base *base = nla_data(a); + + has_base = true; + mdtype = base->mdtype; + SW_FLOW_KEY_PUT(match, nsh.base.flags, + base->flags, is_mask); + SW_FLOW_KEY_PUT(match, nsh.base.ttl, + base->ttl, is_mask); + SW_FLOW_KEY_PUT(match, nsh.base.mdtype, + base->mdtype, is_mask); + SW_FLOW_KEY_PUT(match, nsh.base.np, + base->np, is_mask); + SW_FLOW_KEY_PUT(match, nsh.base.path_hdr, + base->path_hdr, is_mask); + break; + } + case OVS_NSH_KEY_ATTR_MD1: { + const struct ovs_nsh_key_md1 *md1 = nla_data(a); + + has_md1 = true; + for (i = 0; i < NSH_MD1_CONTEXT_SIZE; i++) + SW_FLOW_KEY_PUT(match, nsh.context[i], + md1->context[i], is_mask); + break; + } + case OVS_NSH_KEY_ATTR_MD2: + if (!is_push_nsh) /* Not supported MD type 2 yet */ + return -ENOTSUPP; + + has_md2 = true; + mdlen = nla_len(a); + if (mdlen > NSH_CTX_HDRS_MAX_LEN || mdlen <= 0) { + OVS_NLERR( + log, + "Invalid MD length %d for MD type %d", + mdlen, + mdtype + ); + return -EINVAL; + } + break; + default: + OVS_NLERR(log, "Unknown nsh attribute %d", + type); + return -EINVAL; + } + } + + if (rem > 0) { + OVS_NLERR(log, "nsh attribute has %d unknown bytes.", rem); + return -EINVAL; + } + + if (has_md1 && has_md2) { + OVS_NLERR( + 1, + "invalid nsh attribute: md1 and md2 are exclusive." + ); + return -EINVAL; + } + + if (!is_mask) { + if ((has_md1 && mdtype != NSH_M_TYPE1) || + (has_md2 && mdtype != NSH_M_TYPE2)) { + OVS_NLERR(1, "nsh attribute has unmatched MD type %d.", + mdtype); + return -EINVAL; + } + + if (is_push_nsh && + (!has_base || (!has_md1 && !has_md2))) { + OVS_NLERR( + 1, + "push_nsh: missing base or metadata attributes" + ); + return -EINVAL; + } + } + + return 0; +} + +static int ovs_key_from_nlattrs(struct net *net, struct sw_flow_match *match, + u64 attrs, const struct nlattr **a, + bool is_mask, bool log) +{ + int err; + + err = metadata_from_nlattrs(net, match, &attrs, a, is_mask, log); + if (err) + return err; + + if (attrs & (1 << OVS_KEY_ATTR_ETHERNET)) { + const struct ovs_key_ethernet *eth_key; + + eth_key = nla_data(a[OVS_KEY_ATTR_ETHERNET]); + SW_FLOW_KEY_MEMCPY(match, eth.src, + eth_key->eth_src, ETH_ALEN, is_mask); + SW_FLOW_KEY_MEMCPY(match, eth.dst, + eth_key->eth_dst, ETH_ALEN, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_ETHERNET); + + if (attrs & (1 << OVS_KEY_ATTR_VLAN)) { + /* VLAN attribute is always parsed before getting here since it + * may occur multiple times. 
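+ * If it is still present in 'attrs' here, the earlier VLAN parsing did not consume it, which means the supplied attribute combination was inconsistent.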
+ */ + OVS_NLERR(log, "VLAN attribute unexpected."); + return -EINVAL; + } + + if (attrs & (1 << OVS_KEY_ATTR_ETHERTYPE)) { + err = parse_eth_type_from_nlattrs(match, &attrs, a, is_mask, + log); + if (err) + return err; + } else if (!is_mask) { + SW_FLOW_KEY_PUT(match, eth.type, htons(ETH_P_802_2), is_mask); + } + } else if (!match->key->eth.type) { + OVS_NLERR(log, "Either Ethernet header or EtherType is required."); + return -EINVAL; + } + + if (attrs & (1 << OVS_KEY_ATTR_IPV4)) { + const struct ovs_key_ipv4 *ipv4_key; + + ipv4_key = nla_data(a[OVS_KEY_ATTR_IPV4]); + if (!is_mask && ipv4_key->ipv4_frag > OVS_FRAG_TYPE_MAX) { + OVS_NLERR(log, "IPv4 frag type %d is out of range max %d", + ipv4_key->ipv4_frag, OVS_FRAG_TYPE_MAX); + return -EINVAL; + } + SW_FLOW_KEY_PUT(match, ip.proto, + ipv4_key->ipv4_proto, is_mask); + SW_FLOW_KEY_PUT(match, ip.tos, + ipv4_key->ipv4_tos, is_mask); + SW_FLOW_KEY_PUT(match, ip.ttl, + ipv4_key->ipv4_ttl, is_mask); + SW_FLOW_KEY_PUT(match, ip.frag, + ipv4_key->ipv4_frag, is_mask); + SW_FLOW_KEY_PUT(match, ipv4.addr.src, + ipv4_key->ipv4_src, is_mask); + SW_FLOW_KEY_PUT(match, ipv4.addr.dst, + ipv4_key->ipv4_dst, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_IPV4); + } + + if (attrs & (1 << OVS_KEY_ATTR_IPV6)) { + const struct ovs_key_ipv6 *ipv6_key; + + ipv6_key = nla_data(a[OVS_KEY_ATTR_IPV6]); + if (!is_mask && ipv6_key->ipv6_frag > OVS_FRAG_TYPE_MAX) { + OVS_NLERR(log, "IPv6 frag type %d is out of range max %d", + ipv6_key->ipv6_frag, OVS_FRAG_TYPE_MAX); + return -EINVAL; + } + + if (!is_mask && ipv6_key->ipv6_label & htonl(0xFFF00000)) { + OVS_NLERR(log, "IPv6 flow label %x is out of range (max=%x)", + ntohl(ipv6_key->ipv6_label), (1 << 20) - 1); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, ipv6.label, + ipv6_key->ipv6_label, is_mask); + SW_FLOW_KEY_PUT(match, ip.proto, + ipv6_key->ipv6_proto, is_mask); + SW_FLOW_KEY_PUT(match, ip.tos, + ipv6_key->ipv6_tclass, is_mask); + SW_FLOW_KEY_PUT(match, ip.ttl, + ipv6_key->ipv6_hlimit, is_mask); + SW_FLOW_KEY_PUT(match, ip.frag, + ipv6_key->ipv6_frag, is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv6.addr.src, + ipv6_key->ipv6_src, + sizeof(match->key->ipv6.addr.src), + is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv6.addr.dst, + ipv6_key->ipv6_dst, + sizeof(match->key->ipv6.addr.dst), + is_mask); + + attrs &= ~(1 << OVS_KEY_ATTR_IPV6); + } + + if (attrs & (1ULL << OVS_KEY_ATTR_IPV6_EXTHDRS)) { + const struct ovs_key_ipv6_exthdrs *ipv6_exthdrs_key; + + ipv6_exthdrs_key = nla_data(a[OVS_KEY_ATTR_IPV6_EXTHDRS]); + + SW_FLOW_KEY_PUT(match, ipv6.exthdrs, + ipv6_exthdrs_key->hdrs, is_mask); + + attrs &= ~(1ULL << OVS_KEY_ATTR_IPV6_EXTHDRS); + } + + if (attrs & (1 << OVS_KEY_ATTR_ARP)) { + const struct ovs_key_arp *arp_key; + + arp_key = nla_data(a[OVS_KEY_ATTR_ARP]); + if (!is_mask && (arp_key->arp_op & htons(0xff00))) { + OVS_NLERR(log, "Unknown ARP opcode (opcode=%d).", + arp_key->arp_op); + return -EINVAL; + } + + SW_FLOW_KEY_PUT(match, ipv4.addr.src, + arp_key->arp_sip, is_mask); + SW_FLOW_KEY_PUT(match, ipv4.addr.dst, + arp_key->arp_tip, is_mask); + SW_FLOW_KEY_PUT(match, ip.proto, + ntohs(arp_key->arp_op), is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv4.arp.sha, + arp_key->arp_sha, ETH_ALEN, is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv4.arp.tha, + arp_key->arp_tha, ETH_ALEN, is_mask); + + attrs &= ~(1 << OVS_KEY_ATTR_ARP); + } + + if (attrs & (1 << OVS_KEY_ATTR_NSH)) { + if (nsh_key_put_from_nlattr(a[OVS_KEY_ATTR_NSH], match, + is_mask, false, log) < 0) + return -EINVAL; + attrs &= ~(1 << OVS_KEY_ATTR_NSH); + } + + if (attrs & (1 << 
OVS_KEY_ATTR_MPLS)) { + const struct ovs_key_mpls *mpls_key; + u32 hdr_len; + u32 label_count, label_count_mask, i; + + mpls_key = nla_data(a[OVS_KEY_ATTR_MPLS]); + hdr_len = nla_len(a[OVS_KEY_ATTR_MPLS]); + label_count = hdr_len / sizeof(struct ovs_key_mpls); + + if (label_count == 0 || label_count > MPLS_LABEL_DEPTH || + hdr_len % sizeof(struct ovs_key_mpls)) + return -EINVAL; + + label_count_mask = GENMASK(label_count - 1, 0); + + for (i = 0 ; i < label_count; i++) + SW_FLOW_KEY_PUT(match, mpls.lse[i], + mpls_key[i].mpls_lse, is_mask); + + SW_FLOW_KEY_PUT(match, mpls.num_labels_mask, + label_count_mask, is_mask); + + attrs &= ~(1 << OVS_KEY_ATTR_MPLS); + } + + if (attrs & (1 << OVS_KEY_ATTR_TCP)) { + const struct ovs_key_tcp *tcp_key; + + tcp_key = nla_data(a[OVS_KEY_ATTR_TCP]); + SW_FLOW_KEY_PUT(match, tp.src, tcp_key->tcp_src, is_mask); + SW_FLOW_KEY_PUT(match, tp.dst, tcp_key->tcp_dst, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_TCP); + } + + if (attrs & (1 << OVS_KEY_ATTR_TCP_FLAGS)) { + SW_FLOW_KEY_PUT(match, tp.flags, + nla_get_be16(a[OVS_KEY_ATTR_TCP_FLAGS]), + is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_TCP_FLAGS); + } + + if (attrs & (1 << OVS_KEY_ATTR_UDP)) { + const struct ovs_key_udp *udp_key; + + udp_key = nla_data(a[OVS_KEY_ATTR_UDP]); + SW_FLOW_KEY_PUT(match, tp.src, udp_key->udp_src, is_mask); + SW_FLOW_KEY_PUT(match, tp.dst, udp_key->udp_dst, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_UDP); + } + + if (attrs & (1 << OVS_KEY_ATTR_SCTP)) { + const struct ovs_key_sctp *sctp_key; + + sctp_key = nla_data(a[OVS_KEY_ATTR_SCTP]); + SW_FLOW_KEY_PUT(match, tp.src, sctp_key->sctp_src, is_mask); + SW_FLOW_KEY_PUT(match, tp.dst, sctp_key->sctp_dst, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_SCTP); + } + + if (attrs & (1 << OVS_KEY_ATTR_ICMP)) { + const struct ovs_key_icmp *icmp_key; + + icmp_key = nla_data(a[OVS_KEY_ATTR_ICMP]); + SW_FLOW_KEY_PUT(match, tp.src, + htons(icmp_key->icmp_type), is_mask); + SW_FLOW_KEY_PUT(match, tp.dst, + htons(icmp_key->icmp_code), is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_ICMP); + } + + if (attrs & (1 << OVS_KEY_ATTR_ICMPV6)) { + const struct ovs_key_icmpv6 *icmpv6_key; + + icmpv6_key = nla_data(a[OVS_KEY_ATTR_ICMPV6]); + SW_FLOW_KEY_PUT(match, tp.src, + htons(icmpv6_key->icmpv6_type), is_mask); + SW_FLOW_KEY_PUT(match, tp.dst, + htons(icmpv6_key->icmpv6_code), is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_ICMPV6); + } + + if (attrs & (1 << OVS_KEY_ATTR_ND)) { + const struct ovs_key_nd *nd_key; + + nd_key = nla_data(a[OVS_KEY_ATTR_ND]); + SW_FLOW_KEY_MEMCPY(match, ipv6.nd.target, + nd_key->nd_target, + sizeof(match->key->ipv6.nd.target), + is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv6.nd.sll, + nd_key->nd_sll, ETH_ALEN, is_mask); + SW_FLOW_KEY_MEMCPY(match, ipv6.nd.tll, + nd_key->nd_tll, ETH_ALEN, is_mask); + attrs &= ~(1 << OVS_KEY_ATTR_ND); + } + + if (attrs != 0) { + OVS_NLERR(log, "Unknown key attributes %llx", + (unsigned long long)attrs); + return -EINVAL; + } + + return 0; +} + +static void nlattr_set(struct nlattr *attr, u8 val, + const struct ovs_len_tbl *tbl) +{ + struct nlattr *nla; + int rem; + + /* The nlattr stream should already have been validated */ + nla_for_each_nested(nla, attr, rem) { + if (tbl[nla_type(nla)].len == OVS_ATTR_NESTED) + nlattr_set(nla, val, tbl[nla_type(nla)].next ? 
: tbl); + else + memset(nla_data(nla), val, nla_len(nla)); + + if (nla_type(nla) == OVS_KEY_ATTR_CT_STATE) + *(u32 *)nla_data(nla) &= CT_SUPPORTED_MASK; + } +} + +static void mask_set_nlattr(struct nlattr *attr, u8 val) +{ + nlattr_set(attr, val, ovs_key_lens); +} + +/** + * ovs_nla_get_match - parses Netlink attributes into a flow key and + * mask. In case the 'mask' is NULL, the flow is treated as an exact match + * flow. Otherwise, it is treated as a wildcarded flow, except that the mask + * does not include any don't care bits. + * @net: Used to determine per-namespace field support. + * @match: receives the extracted flow match information. + * @nla_key: Netlink attribute holding nested %OVS_KEY_ATTR_* Netlink attribute + * sequence. The fields should be those of the packet that triggered + * the creation of this flow. + * @nla_mask: Optional. Netlink attribute holding nested %OVS_KEY_ATTR_* + * Netlink attribute specifying the mask field of the wildcarded flow. + * @log: Boolean to allow kernel error logging. Normally true, but when + * probing for feature compatibility this should be passed in as false to + * suppress unnecessary error logging. + */ +int ovs_nla_get_match(struct net *net, struct sw_flow_match *match, + const struct nlattr *nla_key, + const struct nlattr *nla_mask, + bool log) +{ + const struct nlattr *a[OVS_KEY_ATTR_MAX + 1]; + struct nlattr *newmask = NULL; + u64 key_attrs = 0; + u64 mask_attrs = 0; + int err; + + err = parse_flow_nlattrs(nla_key, a, &key_attrs, log); + if (err) + return err; + + err = parse_vlan_from_nlattrs(match, &key_attrs, a, false, log); + if (err) + return err; + + err = ovs_key_from_nlattrs(net, match, key_attrs, a, false, log); + if (err) + return err; + + if (match->mask) { + if (!nla_mask) { + /* Create an exact match mask. We need to set to 0xff + * all the 'match->mask' fields that have been touched + * in 'match->key'. We cannot simply memset + * 'match->mask', because padding bytes and fields not + * specified in 'match->key' should be left to 0. + * Instead, we use a stream of netlink attributes, + * copied from 'key' and set to 0xff. + * ovs_key_from_nlattrs() will take care of filling + * 'match->mask' appropriately. + */ + newmask = kmemdup(nla_key, + nla_total_size(nla_len(nla_key)), + GFP_KERNEL); + if (!newmask) + return -ENOMEM; + + mask_set_nlattr(newmask, 0xff); + + /* The userspace does not send tunnel attributes that + * are 0, but we should not wildcard them nonetheless. + */ + if (match->key->tun_proto) + SW_FLOW_KEY_MEMSET_FIELD(match, tun_key, + 0xff, true); + + nla_mask = newmask; + } + + err = parse_flow_mask_nlattrs(nla_mask, a, &mask_attrs, log); + if (err) + goto free_newmask; + + /* Always match on tci.
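+ * Both the outer and the inner (customer) TCI masks are forced to 0xffff below, so a VLAN tag can never be partially wildcarded.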
*/ + SW_FLOW_KEY_PUT(match, eth.vlan.tci, htons(0xffff), true); + SW_FLOW_KEY_PUT(match, eth.cvlan.tci, htons(0xffff), true); + + err = parse_vlan_from_nlattrs(match, &mask_attrs, a, true, log); + if (err) + goto free_newmask; + + err = ovs_key_from_nlattrs(net, match, mask_attrs, a, true, + log); + if (err) + goto free_newmask; + } + + if (!match_validate(match, key_attrs, mask_attrs, log)) + err = -EINVAL; + +free_newmask: + kfree(newmask); + return err; +} + +static size_t get_ufid_len(const struct nlattr *attr, bool log) +{ + size_t len; + + if (!attr) + return 0; + + len = nla_len(attr); + if (len < 1 || len > MAX_UFID_LENGTH) { + OVS_NLERR(log, "ufid size %u bytes exceeds the range (1, %d)", + nla_len(attr), MAX_UFID_LENGTH); + return 0; + } + + return len; +} + +/* Initializes 'flow->ufid', returning true if 'attr' contains a valid UFID, + * or false otherwise. + */ +bool ovs_nla_get_ufid(struct sw_flow_id *sfid, const struct nlattr *attr, + bool log) +{ + sfid->ufid_len = get_ufid_len(attr, log); + if (sfid->ufid_len) + memcpy(sfid->ufid, nla_data(attr), sfid->ufid_len); + + return sfid->ufid_len; +} + +int ovs_nla_get_identifier(struct sw_flow_id *sfid, const struct nlattr *ufid, + const struct sw_flow_key *key, bool log) +{ + struct sw_flow_key *new_key; + + if (ovs_nla_get_ufid(sfid, ufid, log)) + return 0; + + /* If UFID was not provided, use unmasked key. */ + new_key = kmalloc(sizeof(*new_key), GFP_KERNEL); + if (!new_key) + return -ENOMEM; + memcpy(new_key, key, sizeof(*key)); + sfid->unmasked_key = new_key; + + return 0; +} + +u32 ovs_nla_get_ufid_flags(const struct nlattr *attr) +{ + return attr ? nla_get_u32(attr) : 0; +} + +/** + * ovs_nla_get_flow_metadata - parses Netlink attributes into a flow key. + * @net: Network namespace. + * @key: Receives extracted in_port, priority, tun_key, skb_mark and conntrack + * metadata. + * @a: Array of netlink attributes holding parsed %OVS_KEY_ATTR_* Netlink + * attributes. + * @attrs: Bit mask for the netlink attributes included in @a. + * @log: Boolean to allow kernel error logging. Normally true, but when + * probing for feature compatibility this should be passed in as false to + * suppress unnecessary error logging. + * + * This parses a series of Netlink attributes that form a flow key, which must + * take the same form accepted by flow_from_nlattrs(), but only enough of it to + * get the metadata, that is, the parts of the flow key that cannot be + * extracted from the packet itself. + * + * This must be called before the packet key fields are filled in 'key'. + */ + +int ovs_nla_get_flow_metadata(struct net *net, + const struct nlattr *a[OVS_KEY_ATTR_MAX + 1], + u64 attrs, struct sw_flow_key *key, bool log) +{ + struct sw_flow_match match; + + memset(&match, 0, sizeof(match)); + match.key = key; + + key->ct_state = 0; + key->ct_zone = 0; + key->ct_orig_proto = 0; + memset(&key->ct, 0, sizeof(key->ct)); + memset(&key->ipv4.ct_orig, 0, sizeof(key->ipv4.ct_orig)); + memset(&key->ipv6.ct_orig, 0, sizeof(key->ipv6.ct_orig)); + + key->phy.in_port = DP_MAX_PORTS; + + return metadata_from_nlattrs(net, &match, &attrs, a, false, log); +} + +static int ovs_nla_put_vlan(struct sk_buff *skb, const struct vlan_head *vh, + bool is_mask) +{ + __be16 eth_type = !is_mask ? 
vh->tpid : htons(0xffff); + + if (nla_put_be16(skb, OVS_KEY_ATTR_ETHERTYPE, eth_type) || + nla_put_be16(skb, OVS_KEY_ATTR_VLAN, vh->tci)) + return -EMSGSIZE; + return 0; +} + +static int nsh_key_to_nlattr(const struct ovs_key_nsh *nsh, bool is_mask, + struct sk_buff *skb) +{ + struct nlattr *start; + + start = nla_nest_start_noflag(skb, OVS_KEY_ATTR_NSH); + if (!start) + return -EMSGSIZE; + + if (nla_put(skb, OVS_NSH_KEY_ATTR_BASE, sizeof(nsh->base), &nsh->base)) + goto nla_put_failure; + + if (is_mask || nsh->base.mdtype == NSH_M_TYPE1) { + if (nla_put(skb, OVS_NSH_KEY_ATTR_MD1, + sizeof(nsh->context), nsh->context)) + goto nla_put_failure; + } + + /* Don't support MD type 2 yet */ + + nla_nest_end(skb, start); + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int __ovs_nla_put_key(const struct sw_flow_key *swkey, + const struct sw_flow_key *output, bool is_mask, + struct sk_buff *skb) +{ + struct ovs_key_ethernet *eth_key; + struct nlattr *nla; + struct nlattr *encap = NULL; + struct nlattr *in_encap = NULL; + + if (nla_put_u32(skb, OVS_KEY_ATTR_RECIRC_ID, output->recirc_id)) + goto nla_put_failure; + + if (nla_put_u32(skb, OVS_KEY_ATTR_DP_HASH, output->ovs_flow_hash)) + goto nla_put_failure; + + if (nla_put_u32(skb, OVS_KEY_ATTR_PRIORITY, output->phy.priority)) + goto nla_put_failure; + + if ((swkey->tun_proto || is_mask)) { + const void *opts = NULL; + + if (output->tun_key.tun_flags & TUNNEL_OPTIONS_PRESENT) + opts = TUN_METADATA_OPTS(output, swkey->tun_opts_len); + + if (ip_tun_to_nlattr(skb, &output->tun_key, opts, + swkey->tun_opts_len, swkey->tun_proto, 0)) + goto nla_put_failure; + } + + if (swkey->phy.in_port == DP_MAX_PORTS) { + if (is_mask && (output->phy.in_port == 0xffff)) + if (nla_put_u32(skb, OVS_KEY_ATTR_IN_PORT, 0xffffffff)) + goto nla_put_failure; + } else { + u16 upper_u16; + upper_u16 = !is_mask ? 0 : 0xffff; + + if (nla_put_u32(skb, OVS_KEY_ATTR_IN_PORT, + (upper_u16 << 16) | output->phy.in_port)) + goto nla_put_failure; + } + + if (nla_put_u32(skb, OVS_KEY_ATTR_SKB_MARK, output->phy.skb_mark)) + goto nla_put_failure; + + if (ovs_ct_put_key(swkey, output, skb)) + goto nla_put_failure; + + if (ovs_key_mac_proto(swkey) == MAC_PROTO_ETHERNET) { + nla = nla_reserve(skb, OVS_KEY_ATTR_ETHERNET, sizeof(*eth_key)); + if (!nla) + goto nla_put_failure; + + eth_key = nla_data(nla); + ether_addr_copy(eth_key->eth_src, output->eth.src); + ether_addr_copy(eth_key->eth_dst, output->eth.dst); + + if (swkey->eth.vlan.tci || eth_type_vlan(swkey->eth.type)) { + if (ovs_nla_put_vlan(skb, &output->eth.vlan, is_mask)) + goto nla_put_failure; + encap = nla_nest_start_noflag(skb, OVS_KEY_ATTR_ENCAP); + if (!swkey->eth.vlan.tci) + goto unencap; + + if (swkey->eth.cvlan.tci || eth_type_vlan(swkey->eth.type)) { + if (ovs_nla_put_vlan(skb, &output->eth.cvlan, is_mask)) + goto nla_put_failure; + in_encap = nla_nest_start_noflag(skb, + OVS_KEY_ATTR_ENCAP); + if (!swkey->eth.cvlan.tci) + goto unencap; + } + } + + if (swkey->eth.type == htons(ETH_P_802_2)) { + /* + * Ethertype 802.2 is represented in the netlink with omitted + * OVS_KEY_ATTR_ETHERTYPE in the flow key attribute, and + * 0xffff in the mask attribute. Ethertype can also + * be wildcarded. 
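+ * Hence, for masks, the attribute is only emitted when the mask value is non-zero; a fully wildcarded EtherType is represented by omitting it.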
+ */ + if (is_mask && output->eth.type) + if (nla_put_be16(skb, OVS_KEY_ATTR_ETHERTYPE, + output->eth.type)) + goto nla_put_failure; + goto unencap; + } + } + + if (nla_put_be16(skb, OVS_KEY_ATTR_ETHERTYPE, output->eth.type)) + goto nla_put_failure; + + if (eth_type_vlan(swkey->eth.type)) { + /* There are 3 VLAN tags, we don't know anything about the rest + * of the packet, so truncate here. + */ + WARN_ON_ONCE(!(encap && in_encap)); + goto unencap; + } + + if (swkey->eth.type == htons(ETH_P_IP)) { + struct ovs_key_ipv4 *ipv4_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_IPV4, sizeof(*ipv4_key)); + if (!nla) + goto nla_put_failure; + ipv4_key = nla_data(nla); + ipv4_key->ipv4_src = output->ipv4.addr.src; + ipv4_key->ipv4_dst = output->ipv4.addr.dst; + ipv4_key->ipv4_proto = output->ip.proto; + ipv4_key->ipv4_tos = output->ip.tos; + ipv4_key->ipv4_ttl = output->ip.ttl; + ipv4_key->ipv4_frag = output->ip.frag; + } else if (swkey->eth.type == htons(ETH_P_IPV6)) { + struct ovs_key_ipv6 *ipv6_key; + struct ovs_key_ipv6_exthdrs *ipv6_exthdrs_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_IPV6, sizeof(*ipv6_key)); + if (!nla) + goto nla_put_failure; + ipv6_key = nla_data(nla); + memcpy(ipv6_key->ipv6_src, &output->ipv6.addr.src, + sizeof(ipv6_key->ipv6_src)); + memcpy(ipv6_key->ipv6_dst, &output->ipv6.addr.dst, + sizeof(ipv6_key->ipv6_dst)); + ipv6_key->ipv6_label = output->ipv6.label; + ipv6_key->ipv6_proto = output->ip.proto; + ipv6_key->ipv6_tclass = output->ip.tos; + ipv6_key->ipv6_hlimit = output->ip.ttl; + ipv6_key->ipv6_frag = output->ip.frag; + + nla = nla_reserve(skb, OVS_KEY_ATTR_IPV6_EXTHDRS, + sizeof(*ipv6_exthdrs_key)); + if (!nla) + goto nla_put_failure; + ipv6_exthdrs_key = nla_data(nla); + ipv6_exthdrs_key->hdrs = output->ipv6.exthdrs; + } else if (swkey->eth.type == htons(ETH_P_NSH)) { + if (nsh_key_to_nlattr(&output->nsh, is_mask, skb)) + goto nla_put_failure; + } else if (swkey->eth.type == htons(ETH_P_ARP) || + swkey->eth.type == htons(ETH_P_RARP)) { + struct ovs_key_arp *arp_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_ARP, sizeof(*arp_key)); + if (!nla) + goto nla_put_failure; + arp_key = nla_data(nla); + memset(arp_key, 0, sizeof(struct ovs_key_arp)); + arp_key->arp_sip = output->ipv4.addr.src; + arp_key->arp_tip = output->ipv4.addr.dst; + arp_key->arp_op = htons(output->ip.proto); + ether_addr_copy(arp_key->arp_sha, output->ipv4.arp.sha); + ether_addr_copy(arp_key->arp_tha, output->ipv4.arp.tha); + } else if (eth_p_mpls(swkey->eth.type)) { + u8 i, num_labels; + struct ovs_key_mpls *mpls_key; + + num_labels = hweight_long(output->mpls.num_labels_mask); + nla = nla_reserve(skb, OVS_KEY_ATTR_MPLS, + num_labels * sizeof(*mpls_key)); + if (!nla) + goto nla_put_failure; + + mpls_key = nla_data(nla); + for (i = 0; i < num_labels; i++) + mpls_key[i].mpls_lse = output->mpls.lse[i]; + } + + if ((swkey->eth.type == htons(ETH_P_IP) || + swkey->eth.type == htons(ETH_P_IPV6)) && + swkey->ip.frag != OVS_FRAG_TYPE_LATER) { + + if (swkey->ip.proto == IPPROTO_TCP) { + struct ovs_key_tcp *tcp_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_TCP, sizeof(*tcp_key)); + if (!nla) + goto nla_put_failure; + tcp_key = nla_data(nla); + tcp_key->tcp_src = output->tp.src; + tcp_key->tcp_dst = output->tp.dst; + if (nla_put_be16(skb, OVS_KEY_ATTR_TCP_FLAGS, + output->tp.flags)) + goto nla_put_failure; + } else if (swkey->ip.proto == IPPROTO_UDP) { + struct ovs_key_udp *udp_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_UDP, sizeof(*udp_key)); + if (!nla) + goto nla_put_failure; + udp_key = nla_data(nla); + 
udp_key->udp_src = output->tp.src; + udp_key->udp_dst = output->tp.dst; + } else if (swkey->ip.proto == IPPROTO_SCTP) { + struct ovs_key_sctp *sctp_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_SCTP, sizeof(*sctp_key)); + if (!nla) + goto nla_put_failure; + sctp_key = nla_data(nla); + sctp_key->sctp_src = output->tp.src; + sctp_key->sctp_dst = output->tp.dst; + } else if (swkey->eth.type == htons(ETH_P_IP) && + swkey->ip.proto == IPPROTO_ICMP) { + struct ovs_key_icmp *icmp_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_ICMP, sizeof(*icmp_key)); + if (!nla) + goto nla_put_failure; + icmp_key = nla_data(nla); + icmp_key->icmp_type = ntohs(output->tp.src); + icmp_key->icmp_code = ntohs(output->tp.dst); + } else if (swkey->eth.type == htons(ETH_P_IPV6) && + swkey->ip.proto == IPPROTO_ICMPV6) { + struct ovs_key_icmpv6 *icmpv6_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_ICMPV6, + sizeof(*icmpv6_key)); + if (!nla) + goto nla_put_failure; + icmpv6_key = nla_data(nla); + icmpv6_key->icmpv6_type = ntohs(output->tp.src); + icmpv6_key->icmpv6_code = ntohs(output->tp.dst); + + if (swkey->tp.src == htons(NDISC_NEIGHBOUR_SOLICITATION) || + swkey->tp.src == htons(NDISC_NEIGHBOUR_ADVERTISEMENT)) { + struct ovs_key_nd *nd_key; + + nla = nla_reserve(skb, OVS_KEY_ATTR_ND, sizeof(*nd_key)); + if (!nla) + goto nla_put_failure; + nd_key = nla_data(nla); + memcpy(nd_key->nd_target, &output->ipv6.nd.target, + sizeof(nd_key->nd_target)); + ether_addr_copy(nd_key->nd_sll, output->ipv6.nd.sll); + ether_addr_copy(nd_key->nd_tll, output->ipv6.nd.tll); + } + } + } + +unencap: + if (in_encap) + nla_nest_end(skb, in_encap); + if (encap) + nla_nest_end(skb, encap); + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +int ovs_nla_put_key(const struct sw_flow_key *swkey, + const struct sw_flow_key *output, int attr, bool is_mask, + struct sk_buff *skb) +{ + int err; + struct nlattr *nla; + + nla = nla_nest_start_noflag(skb, attr); + if (!nla) + return -EMSGSIZE; + err = __ovs_nla_put_key(swkey, output, is_mask, skb); + if (err) + return err; + nla_nest_end(skb, nla); + + return 0; +} + +/* Called with ovs_mutex or RCU read lock. */ +int ovs_nla_put_identifier(const struct sw_flow *flow, struct sk_buff *skb) +{ + if (ovs_identifier_is_ufid(&flow->id)) + return nla_put(skb, OVS_FLOW_ATTR_UFID, flow->id.ufid_len, + flow->id.ufid); + + return ovs_nla_put_key(flow->id.unmasked_key, flow->id.unmasked_key, + OVS_FLOW_ATTR_KEY, false, skb); +} + +/* Called with ovs_mutex or RCU read lock. */ +int ovs_nla_put_masked_key(const struct sw_flow *flow, struct sk_buff *skb) +{ + return ovs_nla_put_key(&flow->key, &flow->key, + OVS_FLOW_ATTR_KEY, false, skb); +} + +/* Called with ovs_mutex or RCU read lock. 
*/ +int ovs_nla_put_mask(const struct sw_flow *flow, struct sk_buff *skb) +{ + return ovs_nla_put_key(&flow->key, &flow->mask->key, + OVS_FLOW_ATTR_MASK, true, skb); +} + +#define MAX_ACTIONS_BUFSIZE (32 * 1024) + +static struct sw_flow_actions *nla_alloc_flow_actions(int size) +{ + struct sw_flow_actions *sfa; + + WARN_ON_ONCE(size > MAX_ACTIONS_BUFSIZE); + + sfa = kmalloc(kmalloc_size_roundup(sizeof(*sfa) + size), GFP_KERNEL); + if (!sfa) + return ERR_PTR(-ENOMEM); + + sfa->actions_len = 0; + return sfa; +} + +static void ovs_nla_free_nested_actions(const struct nlattr *actions, int len); + +static void ovs_nla_free_check_pkt_len_action(const struct nlattr *action) +{ + const struct nlattr *a; + int rem; + + nla_for_each_nested(a, action, rem) { + switch (nla_type(a)) { + case OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL: + case OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER: + ovs_nla_free_nested_actions(nla_data(a), nla_len(a)); + break; + } + } +} + +static void ovs_nla_free_clone_action(const struct nlattr *action) +{ + const struct nlattr *a = nla_data(action); + int rem = nla_len(action); + + switch (nla_type(a)) { + case OVS_CLONE_ATTR_EXEC: + /* The real list of actions follows this attribute. */ + a = nla_next(a, &rem); + ovs_nla_free_nested_actions(a, rem); + break; + } +} + +static void ovs_nla_free_dec_ttl_action(const struct nlattr *action) +{ + const struct nlattr *a = nla_data(action); + + switch (nla_type(a)) { + case OVS_DEC_TTL_ATTR_ACTION: + ovs_nla_free_nested_actions(nla_data(a), nla_len(a)); + break; + } +} + +static void ovs_nla_free_sample_action(const struct nlattr *action) +{ + const struct nlattr *a = nla_data(action); + int rem = nla_len(action); + + switch (nla_type(a)) { + case OVS_SAMPLE_ATTR_ARG: + /* The real list of actions follows this attribute. */ + a = nla_next(a, &rem); + ovs_nla_free_nested_actions(a, rem); + break; + } +} + +static void ovs_nla_free_set_action(const struct nlattr *a) +{ + const struct nlattr *ovs_key = nla_data(a); + struct ovs_tunnel_info *ovs_tun; + + switch (nla_type(ovs_key)) { + case OVS_KEY_ATTR_TUNNEL_INFO: + ovs_tun = nla_data(ovs_key); + dst_release((struct dst_entry *)ovs_tun->tun_dst); + break; + } +} + +static void ovs_nla_free_nested_actions(const struct nlattr *actions, int len) +{ + const struct nlattr *a; + int rem; + + /* Whenever new actions are added, the need to update this + * function should be considered. + */ + BUILD_BUG_ON(OVS_ACTION_ATTR_MAX != 23); + + if (!actions) + return; + + nla_for_each_attr(a, actions, len, rem) { + switch (nla_type(a)) { + case OVS_ACTION_ATTR_CHECK_PKT_LEN: + ovs_nla_free_check_pkt_len_action(a); + break; + + case OVS_ACTION_ATTR_CLONE: + ovs_nla_free_clone_action(a); + break; + + case OVS_ACTION_ATTR_CT: + ovs_ct_free_action(a); + break; + + case OVS_ACTION_ATTR_DEC_TTL: + ovs_nla_free_dec_ttl_action(a); + break; + + case OVS_ACTION_ATTR_SAMPLE: + ovs_nla_free_sample_action(a); + break; + + case OVS_ACTION_ATTR_SET: + ovs_nla_free_set_action(a); + break; + } + } +} + +void ovs_nla_free_flow_actions(struct sw_flow_actions *sf_acts) +{ + if (!sf_acts) + return; + + ovs_nla_free_nested_actions(sf_acts->actions, sf_acts->actions_len); + kfree(sf_acts); +} + +static void __ovs_nla_free_flow_actions(struct rcu_head *head) +{ + ovs_nla_free_flow_actions(container_of(head, struct sw_flow_actions, rcu)); +} + +/* Schedules 'sf_acts' to be freed after the next RCU grace period. + * The caller must hold rcu_read_lock for this to be sensible. 
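+ * call_rcu() defers the actual kfree() until every reader that may still hold a reference to 'sf_acts' has left its RCU read-side critical section.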
*/ +void ovs_nla_free_flow_actions_rcu(struct sw_flow_actions *sf_acts) +{ + call_rcu(&sf_acts->rcu, __ovs_nla_free_flow_actions); +} + +static struct nlattr *reserve_sfa_size(struct sw_flow_actions **sfa, + int attr_len, bool log) +{ + + struct sw_flow_actions *acts; + int new_acts_size; + size_t req_size = NLA_ALIGN(attr_len); + int next_offset = offsetof(struct sw_flow_actions, actions) + + (*sfa)->actions_len; + + if (req_size <= (ksize(*sfa) - next_offset)) + goto out; + + new_acts_size = max(next_offset + req_size, ksize(*sfa) * 2); + + if (new_acts_size > MAX_ACTIONS_BUFSIZE) { + if ((next_offset + req_size) > MAX_ACTIONS_BUFSIZE) { + OVS_NLERR(log, "Flow action size exceeds max %u", + MAX_ACTIONS_BUFSIZE); + return ERR_PTR(-EMSGSIZE); + } + new_acts_size = MAX_ACTIONS_BUFSIZE; + } + + acts = nla_alloc_flow_actions(new_acts_size); + if (IS_ERR(acts)) + return (void *)acts; + + memcpy(acts->actions, (*sfa)->actions, (*sfa)->actions_len); + acts->actions_len = (*sfa)->actions_len; + acts->orig_len = (*sfa)->orig_len; + kfree(*sfa); + *sfa = acts; + +out: + (*sfa)->actions_len += req_size; + return (struct nlattr *) ((unsigned char *)(*sfa) + next_offset); +} + +static struct nlattr *__add_action(struct sw_flow_actions **sfa, + int attrtype, void *data, int len, bool log) +{ + struct nlattr *a; + + a = reserve_sfa_size(sfa, nla_attr_size(len), log); + if (IS_ERR(a)) + return a; + + a->nla_type = attrtype; + a->nla_len = nla_attr_size(len); + + if (data) + memcpy(nla_data(a), data, len); + memset((unsigned char *) a + a->nla_len, 0, nla_padlen(len)); + + return a; +} + +int ovs_nla_add_action(struct sw_flow_actions **sfa, int attrtype, void *data, + int len, bool log) +{ + struct nlattr *a; + + a = __add_action(sfa, attrtype, data, len, log); + + return PTR_ERR_OR_ZERO(a); +} + +static inline int add_nested_action_start(struct sw_flow_actions **sfa, + int attrtype, bool log) +{ + int used = (*sfa)->actions_len; + int err; + + err = ovs_nla_add_action(sfa, attrtype, NULL, 0, log); + if (err) + return err; + + return used; +} + +static inline void add_nested_action_end(struct sw_flow_actions *sfa, + int st_offset) +{ + struct nlattr *a = (struct nlattr *) ((unsigned char *)sfa->actions + + st_offset); + + a->nla_len = sfa->actions_len - st_offset; +} + +static int __ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, bool log); + +static int validate_and_copy_sample(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, bool log, bool last) +{ + const struct nlattr *attrs[OVS_SAMPLE_ATTR_MAX + 1]; + const struct nlattr *probability, *actions; + const struct nlattr *a; + int rem, start, err; + struct sample_arg arg; + + memset(attrs, 0, sizeof(attrs)); + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + if (!type || type > OVS_SAMPLE_ATTR_MAX || attrs[type]) + return -EINVAL; + attrs[type] = a; + } + if (rem) + return -EINVAL; + + probability = attrs[OVS_SAMPLE_ATTR_PROBABILITY]; + if (!probability || nla_len(probability) != sizeof(u32)) + return -EINVAL; + + actions = attrs[OVS_SAMPLE_ATTR_ACTIONS]; + if (!actions || (nla_len(actions) && nla_len(actions) < NLA_HDRLEN)) + return -EINVAL; + + /* validation done, copy sample action. 
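+ * The nested OVS_ACTION_ATTR_SAMPLE is rebuilt below as an OVS_SAMPLE_ATTR_ARG (probability plus an 'execute in place' flag) followed by the copied nested actions.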
*/ + start = add_nested_action_start(sfa, OVS_ACTION_ATTR_SAMPLE, log); + if (start < 0) + return start; + + /* When both skb and flow may be changed, put the sample + * into a deferred fifo. On the other hand, if only skb + * may be modified, the actions can be executed in place. + * + * Do this analysis at the flow installation time. + * Set 'clone_action->exec' to true if the actions can be + * executed without being deferred. + * + * If the sample is the last action, it can always be executed + * rather than deferred. + */ + arg.exec = last || !actions_may_change_flow(actions); + arg.probability = nla_get_u32(probability); + + err = ovs_nla_add_action(sfa, OVS_SAMPLE_ATTR_ARG, &arg, sizeof(arg), + log); + if (err) + return err; + + err = __ovs_nla_copy_actions(net, actions, key, sfa, + eth_type, vlan_tci, mpls_label_count, log); + + if (err) + return err; + + add_nested_action_end(*sfa, start); + + return 0; +} + +static int validate_and_copy_dec_ttl(struct net *net, + const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, bool log) +{ + const struct nlattr *attrs[OVS_DEC_TTL_ATTR_MAX + 1]; + int start, action_start, err, rem; + const struct nlattr *a, *actions; + + memset(attrs, 0, sizeof(attrs)); + nla_for_each_nested(a, attr, rem) { + int type = nla_type(a); + + /* Ignore unknown attributes to be future proof. */ + if (type > OVS_DEC_TTL_ATTR_MAX) + continue; + + if (!type || attrs[type]) { + OVS_NLERR(log, "Duplicate or invalid key (type %d).", + type); + return -EINVAL; + } + + attrs[type] = a; + } + + if (rem) { + OVS_NLERR(log, "Message has %d unknown bytes.", rem); + return -EINVAL; + } + + actions = attrs[OVS_DEC_TTL_ATTR_ACTION]; + if (!actions || (nla_len(actions) && nla_len(actions) < NLA_HDRLEN)) { + OVS_NLERR(log, "Missing valid actions attribute."); + return -EINVAL; + } + + start = add_nested_action_start(sfa, OVS_ACTION_ATTR_DEC_TTL, log); + if (start < 0) + return start; + + action_start = add_nested_action_start(sfa, OVS_DEC_TTL_ATTR_ACTION, log); + if (action_start < 0) + return action_start; + + err = __ovs_nla_copy_actions(net, actions, key, sfa, eth_type, + vlan_tci, mpls_label_count, log); + if (err) + return err; + + add_nested_action_end(*sfa, action_start); + add_nested_action_end(*sfa, start); + return 0; +} + +static int validate_and_copy_clone(struct net *net, + const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, bool log, bool last) +{ + int start, err; + u32 exec; + + if (nla_len(attr) && nla_len(attr) < NLA_HDRLEN) + return -EINVAL; + + start = add_nested_action_start(sfa, OVS_ACTION_ATTR_CLONE, log); + if (start < 0) + return start; + + exec = last || !actions_may_change_flow(attr); + + err = ovs_nla_add_action(sfa, OVS_CLONE_ATTR_EXEC, &exec, + sizeof(exec), log); + if (err) + return err; + + err = __ovs_nla_copy_actions(net, attr, key, sfa, + eth_type, vlan_tci, mpls_label_count, log); + if (err) + return err; + + add_nested_action_end(*sfa, start); + + return 0; +} + +void ovs_match_init(struct sw_flow_match *match, + struct sw_flow_key *key, + bool reset_key, + struct sw_flow_mask *mask) +{ + memset(match, 0, sizeof(*match)); + match->key = key; + match->mask = mask; + + if (reset_key) + memset(key, 0, sizeof(*key)); + + if (mask) { + memset(&mask->key, 0, sizeof(mask->key)); + mask->range.start = mask->range.end = 0; + } +} + +static int
validate_geneve_opts(struct sw_flow_key *key) +{ + struct geneve_opt *option; + int opts_len = key->tun_opts_len; + bool crit_opt = false; + + option = (struct geneve_opt *)TUN_METADATA_OPTS(key, key->tun_opts_len); + while (opts_len > 0) { + int len; + + if (opts_len < sizeof(*option)) + return -EINVAL; + + len = sizeof(*option) + option->length * 4; + if (len > opts_len) + return -EINVAL; + + crit_opt |= !!(option->type & GENEVE_CRIT_OPT_TYPE); + + option = (struct geneve_opt *)((u8 *)option + len); + opts_len -= len; + } + + key->tun_key.tun_flags |= crit_opt ? TUNNEL_CRIT_OPT : 0; + + return 0; +} + +static int validate_and_copy_set_tun(const struct nlattr *attr, + struct sw_flow_actions **sfa, bool log) +{ + struct sw_flow_match match; + struct sw_flow_key key; + struct metadata_dst *tun_dst; + struct ip_tunnel_info *tun_info; + struct ovs_tunnel_info *ovs_tun; + struct nlattr *a; + int err = 0, start, opts_type; + __be16 dst_opt_type; + + dst_opt_type = 0; + ovs_match_init(&match, &key, true, NULL); + opts_type = ip_tun_from_nlattr(nla_data(attr), &match, false, log); + if (opts_type < 0) + return opts_type; + + if (key.tun_opts_len) { + switch (opts_type) { + case OVS_TUNNEL_KEY_ATTR_GENEVE_OPTS: + err = validate_geneve_opts(&key); + if (err < 0) + return err; + dst_opt_type = TUNNEL_GENEVE_OPT; + break; + case OVS_TUNNEL_KEY_ATTR_VXLAN_OPTS: + dst_opt_type = TUNNEL_VXLAN_OPT; + break; + case OVS_TUNNEL_KEY_ATTR_ERSPAN_OPTS: + dst_opt_type = TUNNEL_ERSPAN_OPT; + break; + } + } + + start = add_nested_action_start(sfa, OVS_ACTION_ATTR_SET, log); + if (start < 0) + return start; + + tun_dst = metadata_dst_alloc(key.tun_opts_len, METADATA_IP_TUNNEL, + GFP_KERNEL); + + if (!tun_dst) + return -ENOMEM; + + err = dst_cache_init(&tun_dst->u.tun_info.dst_cache, GFP_KERNEL); + if (err) { + dst_release((struct dst_entry *)tun_dst); + return err; + } + + a = __add_action(sfa, OVS_KEY_ATTR_TUNNEL_INFO, NULL, + sizeof(*ovs_tun), log); + if (IS_ERR(a)) { + dst_release((struct dst_entry *)tun_dst); + return PTR_ERR(a); + } + + ovs_tun = nla_data(a); + ovs_tun->tun_dst = tun_dst; + + tun_info = &tun_dst->u.tun_info; + tun_info->mode = IP_TUNNEL_INFO_TX; + if (key.tun_proto == AF_INET6) + tun_info->mode |= IP_TUNNEL_INFO_IPV6; + else if (key.tun_proto == AF_INET && key.tun_key.u.ipv4.dst == 0) + tun_info->mode |= IP_TUNNEL_INFO_BRIDGE; + tun_info->key = key.tun_key; + + /* We need to store the options in the action itself since + * everything else will go away after flow setup. We can append + * it to tun_info and then point there. + */ + ip_tunnel_info_opts_set(tun_info, + TUN_METADATA_OPTS(&key, key.tun_opts_len), + key.tun_opts_len, dst_opt_type); + add_nested_action_end(*sfa, start); + + return err; +} + +static bool validate_nsh(const struct nlattr *attr, bool is_mask, + bool is_push_nsh, bool log) +{ + struct sw_flow_match match; + struct sw_flow_key key; + int ret = 0; + + ovs_match_init(&match, &key, true, NULL); + ret = nsh_key_put_from_nlattr(attr, &match, is_mask, + is_push_nsh, log); + return !ret; +} + +/* Return false if there are any non-masked bits set. + * Mask follows data immediately, before any netlink padding. 
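+ * The caller passes half of the attribute payload as 'len', since a masked set action stores the key and its mask back to back.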
+ */ +static bool validate_masked(u8 *data, int len) +{ + u8 *mask = data + len; + + while (len--) + if (*data++ & ~*mask++) + return false; + + return true; +} + +static int validate_set(const struct nlattr *a, + const struct sw_flow_key *flow_key, + struct sw_flow_actions **sfa, bool *skip_copy, + u8 mac_proto, __be16 eth_type, bool masked, bool log) +{ + const struct nlattr *ovs_key = nla_data(a); + int key_type = nla_type(ovs_key); + size_t key_len; + + /* There can be only one key in a action */ + if (nla_total_size(nla_len(ovs_key)) != nla_len(a)) + return -EINVAL; + + key_len = nla_len(ovs_key); + if (masked) + key_len /= 2; + + if (key_type > OVS_KEY_ATTR_MAX || + !check_attr_len(key_len, ovs_key_lens[key_type].len)) + return -EINVAL; + + if (masked && !validate_masked(nla_data(ovs_key), key_len)) + return -EINVAL; + + switch (key_type) { + case OVS_KEY_ATTR_PRIORITY: + case OVS_KEY_ATTR_SKB_MARK: + case OVS_KEY_ATTR_CT_MARK: + case OVS_KEY_ATTR_CT_LABELS: + break; + + case OVS_KEY_ATTR_ETHERNET: + if (mac_proto != MAC_PROTO_ETHERNET) + return -EINVAL; + break; + + case OVS_KEY_ATTR_TUNNEL: { + int err; + + if (masked) + return -EINVAL; /* Masked tunnel set not supported. */ + + *skip_copy = true; + err = validate_and_copy_set_tun(a, sfa, log); + if (err) + return err; + break; + } + case OVS_KEY_ATTR_IPV4: { + const struct ovs_key_ipv4 *ipv4_key; + + if (eth_type != htons(ETH_P_IP)) + return -EINVAL; + + ipv4_key = nla_data(ovs_key); + + if (masked) { + const struct ovs_key_ipv4 *mask = ipv4_key + 1; + + /* Non-writeable fields. */ + if (mask->ipv4_proto || mask->ipv4_frag) + return -EINVAL; + } else { + if (ipv4_key->ipv4_proto != flow_key->ip.proto) + return -EINVAL; + + if (ipv4_key->ipv4_frag != flow_key->ip.frag) + return -EINVAL; + } + break; + } + case OVS_KEY_ATTR_IPV6: { + const struct ovs_key_ipv6 *ipv6_key; + + if (eth_type != htons(ETH_P_IPV6)) + return -EINVAL; + + ipv6_key = nla_data(ovs_key); + + if (masked) { + const struct ovs_key_ipv6 *mask = ipv6_key + 1; + + /* Non-writeable fields. */ + if (mask->ipv6_proto || mask->ipv6_frag) + return -EINVAL; + + /* Invalid bits in the flow label mask? */ + if (ntohl(mask->ipv6_label) & 0xFFF00000) + return -EINVAL; + } else { + if (ipv6_key->ipv6_proto != flow_key->ip.proto) + return -EINVAL; + + if (ipv6_key->ipv6_frag != flow_key->ip.frag) + return -EINVAL; + } + if (ntohl(ipv6_key->ipv6_label) & 0xFFF00000) + return -EINVAL; + + break; + } + case OVS_KEY_ATTR_TCP: + if ((eth_type != htons(ETH_P_IP) && + eth_type != htons(ETH_P_IPV6)) || + flow_key->ip.proto != IPPROTO_TCP) + return -EINVAL; + + break; + + case OVS_KEY_ATTR_UDP: + if ((eth_type != htons(ETH_P_IP) && + eth_type != htons(ETH_P_IPV6)) || + flow_key->ip.proto != IPPROTO_UDP) + return -EINVAL; + + break; + + case OVS_KEY_ATTR_MPLS: + if (!eth_p_mpls(eth_type)) + return -EINVAL; + break; + + case OVS_KEY_ATTR_SCTP: + if ((eth_type != htons(ETH_P_IP) && + eth_type != htons(ETH_P_IPV6)) || + flow_key->ip.proto != IPPROTO_SCTP) + return -EINVAL; + + break; + + case OVS_KEY_ATTR_NSH: + if (eth_type != htons(ETH_P_NSH)) + return -EINVAL; + if (!validate_nsh(nla_data(a), masked, false, log)) + return -EINVAL; + break; + + default: + return -EINVAL; + } + + /* Convert non-masked non-tunnel set actions to masked set actions. 
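The key is copied and followed by an all-ones mask, producing an OVS_ACTION_ATTR_SET_TO_MASKED action; masked_set_action_to_set_action_attr() reverses this when the flow is dumped.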
*/ + if (!masked && key_type != OVS_KEY_ATTR_TUNNEL) { + int start, len = key_len * 2; + struct nlattr *at; + + *skip_copy = true; + + start = add_nested_action_start(sfa, + OVS_ACTION_ATTR_SET_TO_MASKED, + log); + if (start < 0) + return start; + + at = __add_action(sfa, key_type, NULL, len, log); + if (IS_ERR(at)) + return PTR_ERR(at); + + memcpy(nla_data(at), nla_data(ovs_key), key_len); /* Key. */ + memset(nla_data(at) + key_len, 0xff, key_len); /* Mask. */ + /* Clear non-writeable bits from otherwise writeable fields. */ + if (key_type == OVS_KEY_ATTR_IPV6) { + struct ovs_key_ipv6 *mask = nla_data(at) + key_len; + + mask->ipv6_label &= htonl(0x000FFFFF); + } + add_nested_action_end(*sfa, start); + } + + return 0; +} + +static int validate_userspace(const struct nlattr *attr) +{ + static const struct nla_policy userspace_policy[OVS_USERSPACE_ATTR_MAX + 1] = { + [OVS_USERSPACE_ATTR_PID] = {.type = NLA_U32 }, + [OVS_USERSPACE_ATTR_USERDATA] = {.type = NLA_UNSPEC }, + [OVS_USERSPACE_ATTR_EGRESS_TUN_PORT] = {.type = NLA_U32 }, + }; + struct nlattr *a[OVS_USERSPACE_ATTR_MAX + 1]; + int error; + + error = nla_parse_nested_deprecated(a, OVS_USERSPACE_ATTR_MAX, attr, + userspace_policy, NULL); + if (error) + return error; + + if (!a[OVS_USERSPACE_ATTR_PID] || + !nla_get_u32(a[OVS_USERSPACE_ATTR_PID])) + return -EINVAL; + + return 0; +} + +static const struct nla_policy cpl_policy[OVS_CHECK_PKT_LEN_ATTR_MAX + 1] = { + [OVS_CHECK_PKT_LEN_ATTR_PKT_LEN] = {.type = NLA_U16 }, + [OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER] = {.type = NLA_NESTED }, + [OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL] = {.type = NLA_NESTED }, +}; + +static int validate_and_copy_check_pkt_len(struct net *net, + const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, + bool log, bool last) +{ + const struct nlattr *acts_if_greater, *acts_if_lesser_eq; + struct nlattr *a[OVS_CHECK_PKT_LEN_ATTR_MAX + 1]; + struct check_pkt_len_arg arg; + int nested_acts_start; + int start, err; + + err = nla_parse_deprecated_strict(a, OVS_CHECK_PKT_LEN_ATTR_MAX, + nla_data(attr), nla_len(attr), + cpl_policy, NULL); + if (err) + return err; + + if (!a[OVS_CHECK_PKT_LEN_ATTR_PKT_LEN] || + !nla_get_u16(a[OVS_CHECK_PKT_LEN_ATTR_PKT_LEN])) + return -EINVAL; + + acts_if_lesser_eq = a[OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL]; + acts_if_greater = a[OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER]; + + /* Both the nested action should be present. */ + if (!acts_if_greater || !acts_if_lesser_eq) + return -EINVAL; + + /* validation done, copy the nested actions. 
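The copied OVS_ACTION_ATTR_CHECK_PKT_LEN action carries an OVS_CHECK_PKT_LEN_ATTR_ARG (struct check_pkt_len_arg) followed by both nested action lists.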
*/ + start = add_nested_action_start(sfa, OVS_ACTION_ATTR_CHECK_PKT_LEN, + log); + if (start < 0) + return start; + + arg.pkt_len = nla_get_u16(a[OVS_CHECK_PKT_LEN_ATTR_PKT_LEN]); + arg.exec_for_lesser_equal = + last || !actions_may_change_flow(acts_if_lesser_eq); + arg.exec_for_greater = + last || !actions_may_change_flow(acts_if_greater); + + err = ovs_nla_add_action(sfa, OVS_CHECK_PKT_LEN_ATTR_ARG, &arg, + sizeof(arg), log); + if (err) + return err; + + nested_acts_start = add_nested_action_start(sfa, + OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL, log); + if (nested_acts_start < 0) + return nested_acts_start; + + err = __ovs_nla_copy_actions(net, acts_if_lesser_eq, key, sfa, + eth_type, vlan_tci, mpls_label_count, log); + + if (err) + return err; + + add_nested_action_end(*sfa, nested_acts_start); + + nested_acts_start = add_nested_action_start(sfa, + OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER, log); + if (nested_acts_start < 0) + return nested_acts_start; + + err = __ovs_nla_copy_actions(net, acts_if_greater, key, sfa, + eth_type, vlan_tci, mpls_label_count, log); + + if (err) + return err; + + add_nested_action_end(*sfa, nested_acts_start); + add_nested_action_end(*sfa, start); + return 0; +} + +static int copy_action(const struct nlattr *from, + struct sw_flow_actions **sfa, bool log) +{ + int totlen = NLA_ALIGN(from->nla_len); + struct nlattr *to; + + to = reserve_sfa_size(sfa, from->nla_len, log); + if (IS_ERR(to)) + return PTR_ERR(to); + + memcpy(to, from, totlen); + return 0; +} + +static int __ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, + __be16 eth_type, __be16 vlan_tci, + u32 mpls_label_count, bool log) +{ + u8 mac_proto = ovs_key_mac_proto(key); + const struct nlattr *a; + int rem, err; + + nla_for_each_nested(a, attr, rem) { + /* Expected argument lengths, (u32)-1 for variable length. 
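Each attribute's length is checked against this table before the per-type validation below.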
*/ + static const u32 action_lens[OVS_ACTION_ATTR_MAX + 1] = { + [OVS_ACTION_ATTR_OUTPUT] = sizeof(u32), + [OVS_ACTION_ATTR_RECIRC] = sizeof(u32), + [OVS_ACTION_ATTR_USERSPACE] = (u32)-1, + [OVS_ACTION_ATTR_PUSH_MPLS] = sizeof(struct ovs_action_push_mpls), + [OVS_ACTION_ATTR_POP_MPLS] = sizeof(__be16), + [OVS_ACTION_ATTR_PUSH_VLAN] = sizeof(struct ovs_action_push_vlan), + [OVS_ACTION_ATTR_POP_VLAN] = 0, + [OVS_ACTION_ATTR_SET] = (u32)-1, + [OVS_ACTION_ATTR_SET_MASKED] = (u32)-1, + [OVS_ACTION_ATTR_SAMPLE] = (u32)-1, + [OVS_ACTION_ATTR_HASH] = sizeof(struct ovs_action_hash), + [OVS_ACTION_ATTR_CT] = (u32)-1, + [OVS_ACTION_ATTR_CT_CLEAR] = 0, + [OVS_ACTION_ATTR_TRUNC] = sizeof(struct ovs_action_trunc), + [OVS_ACTION_ATTR_PUSH_ETH] = sizeof(struct ovs_action_push_eth), + [OVS_ACTION_ATTR_POP_ETH] = 0, + [OVS_ACTION_ATTR_PUSH_NSH] = (u32)-1, + [OVS_ACTION_ATTR_POP_NSH] = 0, + [OVS_ACTION_ATTR_METER] = sizeof(u32), + [OVS_ACTION_ATTR_CLONE] = (u32)-1, + [OVS_ACTION_ATTR_CHECK_PKT_LEN] = (u32)-1, + [OVS_ACTION_ATTR_ADD_MPLS] = sizeof(struct ovs_action_add_mpls), + [OVS_ACTION_ATTR_DEC_TTL] = (u32)-1, + }; + const struct ovs_action_push_vlan *vlan; + int type = nla_type(a); + bool skip_copy; + + if (type > OVS_ACTION_ATTR_MAX || + (action_lens[type] != nla_len(a) && + action_lens[type] != (u32)-1)) + return -EINVAL; + + skip_copy = false; + switch (type) { + case OVS_ACTION_ATTR_UNSPEC: + return -EINVAL; + + case OVS_ACTION_ATTR_USERSPACE: + err = validate_userspace(a); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_OUTPUT: + if (nla_get_u32(a) >= DP_MAX_PORTS) + return -EINVAL; + break; + + case OVS_ACTION_ATTR_TRUNC: { + const struct ovs_action_trunc *trunc = nla_data(a); + + if (trunc->max_len < ETH_HLEN) + return -EINVAL; + break; + } + + case OVS_ACTION_ATTR_HASH: { + const struct ovs_action_hash *act_hash = nla_data(a); + + switch (act_hash->hash_alg) { + case OVS_HASH_ALG_L4: + break; + default: + return -EINVAL; + } + + break; + } + + case OVS_ACTION_ATTR_POP_VLAN: + if (mac_proto != MAC_PROTO_ETHERNET) + return -EINVAL; + vlan_tci = htons(0); + break; + + case OVS_ACTION_ATTR_PUSH_VLAN: + if (mac_proto != MAC_PROTO_ETHERNET) + return -EINVAL; + vlan = nla_data(a); + if (!eth_type_vlan(vlan->vlan_tpid)) + return -EINVAL; + if (!(vlan->vlan_tci & htons(VLAN_CFI_MASK))) + return -EINVAL; + vlan_tci = vlan->vlan_tci; + break; + + case OVS_ACTION_ATTR_RECIRC: + break; + + case OVS_ACTION_ATTR_ADD_MPLS: { + const struct ovs_action_add_mpls *mpls = nla_data(a); + + if (!eth_p_mpls(mpls->mpls_ethertype)) + return -EINVAL; + + if (mpls->tun_flags & OVS_MPLS_L3_TUNNEL_FLAG_MASK) { + if (vlan_tci & htons(VLAN_CFI_MASK) || + (eth_type != htons(ETH_P_IP) && + eth_type != htons(ETH_P_IPV6) && + eth_type != htons(ETH_P_ARP) && + eth_type != htons(ETH_P_RARP) && + !eth_p_mpls(eth_type))) + return -EINVAL; + mpls_label_count++; + } else { + if (mac_proto == MAC_PROTO_ETHERNET) { + mpls_label_count = 1; + mac_proto = MAC_PROTO_NONE; + } else { + mpls_label_count++; + } + } + eth_type = mpls->mpls_ethertype; + break; + } + + case OVS_ACTION_ATTR_PUSH_MPLS: { + const struct ovs_action_push_mpls *mpls = nla_data(a); + + if (!eth_p_mpls(mpls->mpls_ethertype)) + return -EINVAL; + /* Prohibit push MPLS other than to a white list + * for packets that have a known tag order. 
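+ * The whitelist is IP, IPv6, ARP, RARP or an existing MPLS stack, and only when no VLAN tag is in place.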
+ */ + if (vlan_tci & htons(VLAN_CFI_MASK) || + (eth_type != htons(ETH_P_IP) && + eth_type != htons(ETH_P_IPV6) && + eth_type != htons(ETH_P_ARP) && + eth_type != htons(ETH_P_RARP) && + !eth_p_mpls(eth_type))) + return -EINVAL; + eth_type = mpls->mpls_ethertype; + mpls_label_count++; + break; + } + + case OVS_ACTION_ATTR_POP_MPLS: { + __be16 proto; + if (vlan_tci & htons(VLAN_CFI_MASK) || + !eth_p_mpls(eth_type)) + return -EINVAL; + + /* Disallow subsequent L2.5+ set actions and mpls_pop + * actions once the last MPLS label in the packet is + * popped as there is no check here to ensure that + * the new eth type is valid and thus set actions could + * write off the end of the packet or otherwise corrupt + * it. + * + * Support for these actions is planned using packet + * recirculation. + */ + proto = nla_get_be16(a); + + if (proto == htons(ETH_P_TEB) && + mac_proto != MAC_PROTO_NONE) + return -EINVAL; + + mpls_label_count--; + + if (!eth_p_mpls(proto) || !mpls_label_count) + eth_type = htons(0); + else + eth_type = proto; + + break; + } + + case OVS_ACTION_ATTR_SET: + err = validate_set(a, key, sfa, + &skip_copy, mac_proto, eth_type, + false, log); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_SET_MASKED: + err = validate_set(a, key, sfa, + &skip_copy, mac_proto, eth_type, + true, log); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_SAMPLE: { + bool last = nla_is_last(a, rem); + + err = validate_and_copy_sample(net, a, key, sfa, + eth_type, vlan_tci, + mpls_label_count, + log, last); + if (err) + return err; + skip_copy = true; + break; + } + + case OVS_ACTION_ATTR_CT: + err = ovs_ct_copy_action(net, a, key, sfa, log); + if (err) + return err; + skip_copy = true; + break; + + case OVS_ACTION_ATTR_CT_CLEAR: + break; + + case OVS_ACTION_ATTR_PUSH_ETH: + /* Disallow pushing an Ethernet header if one + * is already present */ + if (mac_proto != MAC_PROTO_NONE) + return -EINVAL; + mac_proto = MAC_PROTO_ETHERNET; + break; + + case OVS_ACTION_ATTR_POP_ETH: + if (mac_proto != MAC_PROTO_ETHERNET) + return -EINVAL; + if (vlan_tci & htons(VLAN_CFI_MASK)) + return -EINVAL; + mac_proto = MAC_PROTO_NONE; + break; + + case OVS_ACTION_ATTR_PUSH_NSH: + if (mac_proto != MAC_PROTO_ETHERNET) { + u8 next_proto; + + next_proto = tun_p_from_eth_p(eth_type); + if (!next_proto) + return -EINVAL; + } + mac_proto = MAC_PROTO_NONE; + if (!validate_nsh(nla_data(a), false, true, true)) + return -EINVAL; + break; + + case OVS_ACTION_ATTR_POP_NSH: { + __be16 inner_proto; + + if (eth_type != htons(ETH_P_NSH)) + return -EINVAL; + inner_proto = tun_p_to_eth_p(key->nsh.base.np); + if (!inner_proto) + return -EINVAL; + if (key->nsh.base.np == TUN_P_ETHERNET) + mac_proto = MAC_PROTO_ETHERNET; + else + mac_proto = MAC_PROTO_NONE; + break; + } + + case OVS_ACTION_ATTR_METER: + /* Non-existent meters are simply ignored. 
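The meter id is not looked up at validation time.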
*/ + break; + + case OVS_ACTION_ATTR_CLONE: { + bool last = nla_is_last(a, rem); + + err = validate_and_copy_clone(net, a, key, sfa, + eth_type, vlan_tci, + mpls_label_count, + log, last); + if (err) + return err; + skip_copy = true; + break; + } + + case OVS_ACTION_ATTR_CHECK_PKT_LEN: { + bool last = nla_is_last(a, rem); + + err = validate_and_copy_check_pkt_len(net, a, key, sfa, + eth_type, + vlan_tci, + mpls_label_count, + log, last); + if (err) + return err; + skip_copy = true; + break; + } + + case OVS_ACTION_ATTR_DEC_TTL: + err = validate_and_copy_dec_ttl(net, a, key, sfa, + eth_type, vlan_tci, + mpls_label_count, log); + if (err) + return err; + skip_copy = true; + break; + + default: + OVS_NLERR(log, "Unknown Action type %d", type); + return -EINVAL; + } + if (!skip_copy) { + err = copy_action(a, sfa, log); + if (err) + return err; + } + } + + if (rem > 0) + return -EINVAL; + + return 0; +} + +/* 'key' must be the masked key. */ +int ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, bool log) +{ + int err; + u32 mpls_label_count = 0; + + *sfa = nla_alloc_flow_actions(min(nla_len(attr), MAX_ACTIONS_BUFSIZE)); + if (IS_ERR(*sfa)) + return PTR_ERR(*sfa); + + if (eth_p_mpls(key->eth.type)) + mpls_label_count = hweight_long(key->mpls.num_labels_mask); + + (*sfa)->orig_len = nla_len(attr); + err = __ovs_nla_copy_actions(net, attr, key, sfa, key->eth.type, + key->eth.vlan.tci, mpls_label_count, log); + if (err) + ovs_nla_free_flow_actions(*sfa); + + return err; +} + +static int sample_action_to_attr(const struct nlattr *attr, + struct sk_buff *skb) +{ + struct nlattr *start, *ac_start = NULL, *sample_arg; + int err = 0, rem = nla_len(attr); + const struct sample_arg *arg; + struct nlattr *actions; + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_SAMPLE); + if (!start) + return -EMSGSIZE; + + sample_arg = nla_data(attr); + arg = nla_data(sample_arg); + actions = nla_next(sample_arg, &rem); + + if (nla_put_u32(skb, OVS_SAMPLE_ATTR_PROBABILITY, arg->probability)) { + err = -EMSGSIZE; + goto out; + } + + ac_start = nla_nest_start_noflag(skb, OVS_SAMPLE_ATTR_ACTIONS); + if (!ac_start) { + err = -EMSGSIZE; + goto out; + } + + err = ovs_nla_put_actions(actions, rem, skb); + +out: + if (err) { + nla_nest_cancel(skb, ac_start); + nla_nest_cancel(skb, start); + } else { + nla_nest_end(skb, ac_start); + nla_nest_end(skb, start); + } + + return err; +} + +static int clone_action_to_attr(const struct nlattr *attr, + struct sk_buff *skb) +{ + struct nlattr *start; + int err = 0, rem = nla_len(attr); + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_CLONE); + if (!start) + return -EMSGSIZE; + + /* Skipping the OVS_CLONE_ATTR_EXEC that is always the first attribute. */ + attr = nla_next(nla_data(attr), &rem); + err = ovs_nla_put_actions(attr, rem, skb); + + if (err) + nla_nest_cancel(skb, start); + else + nla_nest_end(skb, start); + + return err; +} + +static int check_pkt_len_action_to_attr(const struct nlattr *attr, + struct sk_buff *skb) +{ + struct nlattr *start, *ac_start = NULL; + const struct check_pkt_len_arg *arg; + const struct nlattr *a, *cpl_arg; + int err = 0, rem = nla_len(attr); + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_CHECK_PKT_LEN); + if (!start) + return -EMSGSIZE; + + /* The first nested attribute in 'attr' is always + * 'OVS_CHECK_PKT_LEN_ATTR_ARG'. 
+ */ + cpl_arg = nla_data(attr); + arg = nla_data(cpl_arg); + + if (nla_put_u16(skb, OVS_CHECK_PKT_LEN_ATTR_PKT_LEN, arg->pkt_len)) { + err = -EMSGSIZE; + goto out; + } + + /* Second nested attribute in 'attr' is always + * 'OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL'. + */ + a = nla_next(cpl_arg, &rem); + ac_start = nla_nest_start_noflag(skb, + OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_LESS_EQUAL); + if (!ac_start) { + err = -EMSGSIZE; + goto out; + } + + err = ovs_nla_put_actions(nla_data(a), nla_len(a), skb); + if (err) { + nla_nest_cancel(skb, ac_start); + goto out; + } else { + nla_nest_end(skb, ac_start); + } + + /* Third nested attribute in 'attr' is always + * OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER. + */ + a = nla_next(a, &rem); + ac_start = nla_nest_start_noflag(skb, + OVS_CHECK_PKT_LEN_ATTR_ACTIONS_IF_GREATER); + if (!ac_start) { + err = -EMSGSIZE; + goto out; + } + + err = ovs_nla_put_actions(nla_data(a), nla_len(a), skb); + if (err) { + nla_nest_cancel(skb, ac_start); + goto out; + } else { + nla_nest_end(skb, ac_start); + } + + nla_nest_end(skb, start); + return 0; + +out: + nla_nest_cancel(skb, start); + return err; +} + +static int dec_ttl_action_to_attr(const struct nlattr *attr, + struct sk_buff *skb) +{ + struct nlattr *start, *action_start; + const struct nlattr *a; + int err = 0, rem; + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_DEC_TTL); + if (!start) + return -EMSGSIZE; + + nla_for_each_attr(a, nla_data(attr), nla_len(attr), rem) { + switch (nla_type(a)) { + case OVS_DEC_TTL_ATTR_ACTION: + + action_start = nla_nest_start_noflag(skb, OVS_DEC_TTL_ATTR_ACTION); + if (!action_start) { + err = -EMSGSIZE; + goto out; + } + + err = ovs_nla_put_actions(nla_data(a), nla_len(a), skb); + if (err) + goto out; + + nla_nest_end(skb, action_start); + break; + + default: + /* Ignore all other option to be future compatible */ + break; + } + } + + nla_nest_end(skb, start); + return 0; + +out: + nla_nest_cancel(skb, start); + return err; +} + +static int set_action_to_attr(const struct nlattr *a, struct sk_buff *skb) +{ + const struct nlattr *ovs_key = nla_data(a); + int key_type = nla_type(ovs_key); + struct nlattr *start; + int err; + + switch (key_type) { + case OVS_KEY_ATTR_TUNNEL_INFO: { + struct ovs_tunnel_info *ovs_tun = nla_data(ovs_key); + struct ip_tunnel_info *tun_info = &ovs_tun->tun_dst->u.tun_info; + + start = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_SET); + if (!start) + return -EMSGSIZE; + + err = ip_tun_to_nlattr(skb, &tun_info->key, + ip_tunnel_info_opts(tun_info), + tun_info->options_len, + ip_tunnel_info_af(tun_info), tun_info->mode); + if (err) + return err; + nla_nest_end(skb, start); + break; + } + default: + if (nla_put(skb, OVS_ACTION_ATTR_SET, nla_len(a), ovs_key)) + return -EMSGSIZE; + break; + } + + return 0; +} + +static int masked_set_action_to_set_action_attr(const struct nlattr *a, + struct sk_buff *skb) +{ + const struct nlattr *ovs_key = nla_data(a); + struct nlattr *nla; + size_t key_len = nla_len(ovs_key) / 2; + + /* Revert the conversion we did from a non-masked set action to + * masked set action. 
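+ * Only the key half of the attribute is copied back; the all-ones mask appended by validate_set() is dropped.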
+ */ + nla = nla_nest_start_noflag(skb, OVS_ACTION_ATTR_SET); + if (!nla) + return -EMSGSIZE; + + if (nla_put(skb, nla_type(ovs_key), key_len, nla_data(ovs_key))) + return -EMSGSIZE; + + nla_nest_end(skb, nla); + return 0; +} + +int ovs_nla_put_actions(const struct nlattr *attr, int len, struct sk_buff *skb) +{ + const struct nlattr *a; + int rem, err; + + nla_for_each_attr(a, attr, len, rem) { + int type = nla_type(a); + + switch (type) { + case OVS_ACTION_ATTR_SET: + err = set_action_to_attr(a, skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_SET_TO_MASKED: + err = masked_set_action_to_set_action_attr(a, skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_SAMPLE: + err = sample_action_to_attr(a, skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_CT: + err = ovs_ct_action_to_attr(nla_data(a), skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_CLONE: + err = clone_action_to_attr(a, skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_CHECK_PKT_LEN: + err = check_pkt_len_action_to_attr(a, skb); + if (err) + return err; + break; + + case OVS_ACTION_ATTR_DEC_TTL: + err = dec_ttl_action_to_attr(a, skb); + if (err) + return err; + break; + + default: + if (nla_put(skb, type, nla_len(a), nla_data(a))) + return -EMSGSIZE; + break; + } + } + + return 0; +} diff --git a/net/openvswitch/flow_netlink.h b/net/openvswitch/flow_netlink.h new file mode 100644 index 000000000..fe7f77fc5 --- /dev/null +++ b/net/openvswitch/flow_netlink.h @@ -0,0 +1,73 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2013 Nicira, Inc. + */ + + +#ifndef FLOW_NETLINK_H +#define FLOW_NETLINK_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "flow.h" + +size_t ovs_tun_key_attr_size(void); +size_t ovs_key_attr_size(void); + +void ovs_match_init(struct sw_flow_match *match, + struct sw_flow_key *key, bool reset_key, + struct sw_flow_mask *mask); + +int ovs_nla_put_key(const struct sw_flow_key *, const struct sw_flow_key *, + int attr, bool is_mask, struct sk_buff *); +int parse_flow_nlattrs(const struct nlattr *attr, const struct nlattr *a[], + u64 *attrsp, bool log); +int ovs_nla_get_flow_metadata(struct net *net, + const struct nlattr *a[OVS_KEY_ATTR_MAX + 1], + u64 attrs, struct sw_flow_key *key, bool log); + +int ovs_nla_put_identifier(const struct sw_flow *flow, struct sk_buff *skb); +int ovs_nla_put_masked_key(const struct sw_flow *flow, struct sk_buff *skb); +int ovs_nla_put_mask(const struct sw_flow *flow, struct sk_buff *skb); + +int ovs_nla_get_match(struct net *, struct sw_flow_match *, + const struct nlattr *key, const struct nlattr *mask, + bool log); + +int ovs_nla_put_tunnel_info(struct sk_buff *skb, + struct ip_tunnel_info *tun_info); + +bool ovs_nla_get_ufid(struct sw_flow_id *, const struct nlattr *, bool log); +int ovs_nla_get_identifier(struct sw_flow_id *sfid, const struct nlattr *ufid, + const struct sw_flow_key *key, bool log); +u32 ovs_nla_get_ufid_flags(const struct nlattr *attr); + +int ovs_nla_copy_actions(struct net *net, const struct nlattr *attr, + const struct sw_flow_key *key, + struct sw_flow_actions **sfa, bool log); +int ovs_nla_add_action(struct sw_flow_actions **sfa, int attrtype, + void *data, int len, bool log); +int ovs_nla_put_actions(const struct nlattr *attr, + int len, struct sk_buff *skb); + +void ovs_nla_free_flow_actions(struct sw_flow_actions *); +void ovs_nla_free_flow_actions_rcu(struct 
sw_flow_actions *); + +int nsh_key_from_nlattr(const struct nlattr *attr, struct ovs_key_nsh *nsh, + struct ovs_key_nsh *nsh_mask); +int nsh_hdr_from_nlattr(const struct nlattr *attr, struct nshhdr *nh, + size_t size); + +#endif /* flow_netlink.h */ diff --git a/net/openvswitch/flow_table.c b/net/openvswitch/flow_table.c new file mode 100644 index 000000000..d4a2db0b2 --- /dev/null +++ b/net/openvswitch/flow_table.c @@ -0,0 +1,1222 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2014 Nicira, Inc. + */ + +#include "flow.h" +#include "datapath.h" +#include "flow_netlink.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TBL_MIN_BUCKETS 1024 +#define MASK_ARRAY_SIZE_MIN 16 +#define REHASH_INTERVAL (10 * 60 * HZ) + +#define MC_DEFAULT_HASH_ENTRIES 256 +#define MC_HASH_SHIFT 8 +#define MC_HASH_SEGS ((sizeof(uint32_t) * 8) / MC_HASH_SHIFT) + +static struct kmem_cache *flow_cache; +struct kmem_cache *flow_stats_cache __read_mostly; + +static u16 range_n_bytes(const struct sw_flow_key_range *range) +{ + return range->end - range->start; +} + +void ovs_flow_mask_key(struct sw_flow_key *dst, const struct sw_flow_key *src, + bool full, const struct sw_flow_mask *mask) +{ + int start = full ? 0 : mask->range.start; + int len = full ? sizeof *dst : range_n_bytes(&mask->range); + const long *m = (const long *)((const u8 *)&mask->key + start); + const long *s = (const long *)((const u8 *)src + start); + long *d = (long *)((u8 *)dst + start); + int i; + + /* If 'full' is true then all of 'dst' is fully initialized. Otherwise, + * if 'full' is false the memory outside of the 'mask->range' is left + * uninitialized. This can be used as an optimization when further + * operations on 'dst' only use contents within 'mask->range'. + */ + for (i = 0; i < len; i += sizeof(long)) + *d++ = *s++ & *m++; +} + +struct sw_flow *ovs_flow_alloc(void) +{ + struct sw_flow *flow; + struct sw_flow_stats *stats; + + flow = kmem_cache_zalloc(flow_cache, GFP_KERNEL); + if (!flow) + return ERR_PTR(-ENOMEM); + + flow->stats_last_writer = -1; + + /* Initialize the default stat node. */ + stats = kmem_cache_alloc_node(flow_stats_cache, + GFP_KERNEL | __GFP_ZERO, + node_online(0) ? 
0 : NUMA_NO_NODE); + if (!stats) + goto err; + + spin_lock_init(&stats->lock); + + RCU_INIT_POINTER(flow->stats[0], stats); + + cpumask_set_cpu(0, &flow->cpu_used_mask); + + return flow; +err: + kmem_cache_free(flow_cache, flow); + return ERR_PTR(-ENOMEM); +} + +int ovs_flow_tbl_count(const struct flow_table *table) +{ + return table->count; +} + +static void flow_free(struct sw_flow *flow) +{ + int cpu; + + if (ovs_identifier_is_key(&flow->id)) + kfree(flow->id.unmasked_key); + if (flow->sf_acts) + ovs_nla_free_flow_actions((struct sw_flow_actions __force *) + flow->sf_acts); + /* We open code this to make sure cpu 0 is always considered */ + for (cpu = 0; cpu < nr_cpu_ids; + cpu = cpumask_next(cpu, &flow->cpu_used_mask)) { + if (flow->stats[cpu]) + kmem_cache_free(flow_stats_cache, + (struct sw_flow_stats __force *)flow->stats[cpu]); + } + + kmem_cache_free(flow_cache, flow); +} + +static void rcu_free_flow_callback(struct rcu_head *rcu) +{ + struct sw_flow *flow = container_of(rcu, struct sw_flow, rcu); + + flow_free(flow); +} + +void ovs_flow_free(struct sw_flow *flow, bool deferred) +{ + if (!flow) + return; + + if (deferred) + call_rcu(&flow->rcu, rcu_free_flow_callback); + else + flow_free(flow); +} + +static void __table_instance_destroy(struct table_instance *ti) +{ + kvfree(ti->buckets); + kfree(ti); +} + +static struct table_instance *table_instance_alloc(int new_size) +{ + struct table_instance *ti = kmalloc(sizeof(*ti), GFP_KERNEL); + int i; + + if (!ti) + return NULL; + + ti->buckets = kvmalloc_array(new_size, sizeof(struct hlist_head), + GFP_KERNEL); + if (!ti->buckets) { + kfree(ti); + return NULL; + } + + for (i = 0; i < new_size; i++) + INIT_HLIST_HEAD(&ti->buckets[i]); + + ti->n_buckets = new_size; + ti->node_ver = 0; + get_random_bytes(&ti->hash_seed, sizeof(u32)); + + return ti; +} + +static void __mask_array_destroy(struct mask_array *ma) +{ + free_percpu(ma->masks_usage_stats); + kfree(ma); +} + +static void mask_array_rcu_cb(struct rcu_head *rcu) +{ + struct mask_array *ma = container_of(rcu, struct mask_array, rcu); + + __mask_array_destroy(ma); +} + +static void tbl_mask_array_reset_counters(struct mask_array *ma) +{ + int i, cpu; + + /* As the per CPU counters are not atomic we can not go ahead and + * reset them from another CPU. To be able to still have an approximate + * zero based counter we store the value at reset, and subtract it + * later when processing. 
+ */ + for (i = 0; i < ma->max; i++) { + ma->masks_usage_zero_cntr[i] = 0; + + for_each_possible_cpu(cpu) { + struct mask_array_stats *stats; + unsigned int start; + u64 counter; + + stats = per_cpu_ptr(ma->masks_usage_stats, cpu); + do { + start = u64_stats_fetch_begin_irq(&stats->syncp); + counter = stats->usage_cntrs[i]; + } while (u64_stats_fetch_retry_irq(&stats->syncp, start)); + + ma->masks_usage_zero_cntr[i] += counter; + } + } +} + +static struct mask_array *tbl_mask_array_alloc(int size) +{ + struct mask_array *new; + + size = max(MASK_ARRAY_SIZE_MIN, size); + new = kzalloc(sizeof(struct mask_array) + + sizeof(struct sw_flow_mask *) * size + + sizeof(u64) * size, GFP_KERNEL); + if (!new) + return NULL; + + new->masks_usage_zero_cntr = (u64 *)((u8 *)new + + sizeof(struct mask_array) + + sizeof(struct sw_flow_mask *) * + size); + + new->masks_usage_stats = __alloc_percpu(sizeof(struct mask_array_stats) + + sizeof(u64) * size, + __alignof__(u64)); + if (!new->masks_usage_stats) { + kfree(new); + return NULL; + } + + new->count = 0; + new->max = size; + + return new; +} + +static int tbl_mask_array_realloc(struct flow_table *tbl, int size) +{ + struct mask_array *old; + struct mask_array *new; + + new = tbl_mask_array_alloc(size); + if (!new) + return -ENOMEM; + + old = ovsl_dereference(tbl->mask_array); + if (old) { + int i; + + for (i = 0; i < old->max; i++) { + if (ovsl_dereference(old->masks[i])) + new->masks[new->count++] = old->masks[i]; + } + call_rcu(&old->rcu, mask_array_rcu_cb); + } + + rcu_assign_pointer(tbl->mask_array, new); + + return 0; +} + +static int tbl_mask_array_add_mask(struct flow_table *tbl, + struct sw_flow_mask *new) +{ + struct mask_array *ma = ovsl_dereference(tbl->mask_array); + int err, ma_count = READ_ONCE(ma->count); + + if (ma_count >= ma->max) { + err = tbl_mask_array_realloc(tbl, ma->max + + MASK_ARRAY_SIZE_MIN); + if (err) + return err; + + ma = ovsl_dereference(tbl->mask_array); + } else { + /* On every add or delete we need to reset the counters so + * every new mask gets a fair chance of being prioritized. + */ + tbl_mask_array_reset_counters(ma); + } + + BUG_ON(ovsl_dereference(ma->masks[ma_count])); + + rcu_assign_pointer(ma->masks[ma_count], new); + WRITE_ONCE(ma->count, ma_count + 1); + + return 0; +} + +static void tbl_mask_array_del_mask(struct flow_table *tbl, + struct sw_flow_mask *mask) +{ + struct mask_array *ma = ovsl_dereference(tbl->mask_array); + int i, ma_count = READ_ONCE(ma->count); + + /* Remove the deleted mask pointers from the array */ + for (i = 0; i < ma_count; i++) { + if (mask == ovsl_dereference(ma->masks[i])) + goto found; + } + + BUG(); + return; + +found: + WRITE_ONCE(ma->count, ma_count - 1); + + rcu_assign_pointer(ma->masks[i], ma->masks[ma_count - 1]); + RCU_INIT_POINTER(ma->masks[ma_count - 1], NULL); + + kfree_rcu(mask, rcu); + + /* Shrink the mask array if necessary. */ + if (ma->max >= (MASK_ARRAY_SIZE_MIN * 2) && + ma_count <= (ma->max / 3)) + tbl_mask_array_realloc(tbl, ma->max / 2); + else + tbl_mask_array_reset_counters(ma); + +} + +/* Remove 'mask' from the mask list, if it is not needed any more. */ +static void flow_mask_remove(struct flow_table *tbl, struct sw_flow_mask *mask) +{ + if (mask) { + /* ovs-lock is required to protect mask-refcount and + * mask list. 
+ */ + ASSERT_OVSL(); + BUG_ON(!mask->ref_count); + mask->ref_count--; + + if (!mask->ref_count) + tbl_mask_array_del_mask(tbl, mask); + } +} + +static void __mask_cache_destroy(struct mask_cache *mc) +{ + free_percpu(mc->mask_cache); + kfree(mc); +} + +static void mask_cache_rcu_cb(struct rcu_head *rcu) +{ + struct mask_cache *mc = container_of(rcu, struct mask_cache, rcu); + + __mask_cache_destroy(mc); +} + +static struct mask_cache *tbl_mask_cache_alloc(u32 size) +{ + struct mask_cache_entry __percpu *cache = NULL; + struct mask_cache *new; + + /* Only allow size to be 0, or a power of 2, and does not exceed + * percpu allocation size. + */ + if ((!is_power_of_2(size) && size != 0) || + (size * sizeof(struct mask_cache_entry)) > PCPU_MIN_UNIT_SIZE) + return NULL; + + new = kzalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return NULL; + + new->cache_size = size; + if (new->cache_size > 0) { + cache = __alloc_percpu(array_size(sizeof(struct mask_cache_entry), + new->cache_size), + __alignof__(struct mask_cache_entry)); + if (!cache) { + kfree(new); + return NULL; + } + } + + new->mask_cache = cache; + return new; +} +int ovs_flow_tbl_masks_cache_resize(struct flow_table *table, u32 size) +{ + struct mask_cache *mc = rcu_dereference_ovsl(table->mask_cache); + struct mask_cache *new; + + if (size == mc->cache_size) + return 0; + + if ((!is_power_of_2(size) && size != 0) || + (size * sizeof(struct mask_cache_entry)) > PCPU_MIN_UNIT_SIZE) + return -EINVAL; + + new = tbl_mask_cache_alloc(size); + if (!new) + return -ENOMEM; + + rcu_assign_pointer(table->mask_cache, new); + call_rcu(&mc->rcu, mask_cache_rcu_cb); + + return 0; +} + +int ovs_flow_tbl_init(struct flow_table *table) +{ + struct table_instance *ti, *ufid_ti; + struct mask_cache *mc; + struct mask_array *ma; + + mc = tbl_mask_cache_alloc(MC_DEFAULT_HASH_ENTRIES); + if (!mc) + return -ENOMEM; + + ma = tbl_mask_array_alloc(MASK_ARRAY_SIZE_MIN); + if (!ma) + goto free_mask_cache; + + ti = table_instance_alloc(TBL_MIN_BUCKETS); + if (!ti) + goto free_mask_array; + + ufid_ti = table_instance_alloc(TBL_MIN_BUCKETS); + if (!ufid_ti) + goto free_ti; + + rcu_assign_pointer(table->ti, ti); + rcu_assign_pointer(table->ufid_ti, ufid_ti); + rcu_assign_pointer(table->mask_array, ma); + rcu_assign_pointer(table->mask_cache, mc); + table->last_rehash = jiffies; + table->count = 0; + table->ufid_count = 0; + return 0; + +free_ti: + __table_instance_destroy(ti); +free_mask_array: + __mask_array_destroy(ma); +free_mask_cache: + __mask_cache_destroy(mc); + return -ENOMEM; +} + +static void flow_tbl_destroy_rcu_cb(struct rcu_head *rcu) +{ + struct table_instance *ti; + + ti = container_of(rcu, struct table_instance, rcu); + __table_instance_destroy(ti); +} + +static void table_instance_flow_free(struct flow_table *table, + struct table_instance *ti, + struct table_instance *ufid_ti, + struct sw_flow *flow) +{ + hlist_del_rcu(&flow->flow_table.node[ti->node_ver]); + table->count--; + + if (ovs_identifier_is_ufid(&flow->id)) { + hlist_del_rcu(&flow->ufid_table.node[ufid_ti->node_ver]); + table->ufid_count--; + } + + flow_mask_remove(table, flow->mask); +} + +/* Must be called with OVS mutex held. 
*/ +void table_instance_flow_flush(struct flow_table *table, + struct table_instance *ti, + struct table_instance *ufid_ti) +{ + int i; + + for (i = 0; i < ti->n_buckets; i++) { + struct hlist_head *head = &ti->buckets[i]; + struct hlist_node *n; + struct sw_flow *flow; + + hlist_for_each_entry_safe(flow, n, head, + flow_table.node[ti->node_ver]) { + + table_instance_flow_free(table, ti, ufid_ti, + flow); + ovs_flow_free(flow, true); + } + } + + if (WARN_ON(table->count != 0 || + table->ufid_count != 0)) { + table->count = 0; + table->ufid_count = 0; + } +} + +static void table_instance_destroy(struct table_instance *ti, + struct table_instance *ufid_ti) +{ + call_rcu(&ti->rcu, flow_tbl_destroy_rcu_cb); + call_rcu(&ufid_ti->rcu, flow_tbl_destroy_rcu_cb); +} + +/* No need for locking this function is called from RCU callback or + * error path. + */ +void ovs_flow_tbl_destroy(struct flow_table *table) +{ + struct table_instance *ti = rcu_dereference_raw(table->ti); + struct table_instance *ufid_ti = rcu_dereference_raw(table->ufid_ti); + struct mask_cache *mc = rcu_dereference_raw(table->mask_cache); + struct mask_array *ma = rcu_dereference_raw(table->mask_array); + + call_rcu(&mc->rcu, mask_cache_rcu_cb); + call_rcu(&ma->rcu, mask_array_rcu_cb); + table_instance_destroy(ti, ufid_ti); +} + +struct sw_flow *ovs_flow_tbl_dump_next(struct table_instance *ti, + u32 *bucket, u32 *last) +{ + struct sw_flow *flow; + struct hlist_head *head; + int ver; + int i; + + ver = ti->node_ver; + while (*bucket < ti->n_buckets) { + i = 0; + head = &ti->buckets[*bucket]; + hlist_for_each_entry_rcu(flow, head, flow_table.node[ver]) { + if (i < *last) { + i++; + continue; + } + *last = i + 1; + return flow; + } + (*bucket)++; + *last = 0; + } + + return NULL; +} + +static struct hlist_head *find_bucket(struct table_instance *ti, u32 hash) +{ + hash = jhash_1word(hash, ti->hash_seed); + return &ti->buckets[hash & (ti->n_buckets - 1)]; +} + +static void table_instance_insert(struct table_instance *ti, + struct sw_flow *flow) +{ + struct hlist_head *head; + + head = find_bucket(ti, flow->flow_table.hash); + hlist_add_head_rcu(&flow->flow_table.node[ti->node_ver], head); +} + +static void ufid_table_instance_insert(struct table_instance *ti, + struct sw_flow *flow) +{ + struct hlist_head *head; + + head = find_bucket(ti, flow->ufid_table.hash); + hlist_add_head_rcu(&flow->ufid_table.node[ti->node_ver], head); +} + +static void flow_table_copy_flows(struct table_instance *old, + struct table_instance *new, bool ufid) +{ + int old_ver; + int i; + + old_ver = old->node_ver; + new->node_ver = !old_ver; + + /* Insert in new table. 
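Flows are re-linked under the flipped node_ver, so RCU readers still walking the old table instance are not disturbed.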
*/ + for (i = 0; i < old->n_buckets; i++) { + struct sw_flow *flow; + struct hlist_head *head = &old->buckets[i]; + + if (ufid) + hlist_for_each_entry_rcu(flow, head, + ufid_table.node[old_ver], + lockdep_ovsl_is_held()) + ufid_table_instance_insert(new, flow); + else + hlist_for_each_entry_rcu(flow, head, + flow_table.node[old_ver], + lockdep_ovsl_is_held()) + table_instance_insert(new, flow); + } +} + +static struct table_instance *table_instance_rehash(struct table_instance *ti, + int n_buckets, bool ufid) +{ + struct table_instance *new_ti; + + new_ti = table_instance_alloc(n_buckets); + if (!new_ti) + return NULL; + + flow_table_copy_flows(ti, new_ti, ufid); + + return new_ti; +} + +int ovs_flow_tbl_flush(struct flow_table *flow_table) +{ + struct table_instance *old_ti, *new_ti; + struct table_instance *old_ufid_ti, *new_ufid_ti; + + new_ti = table_instance_alloc(TBL_MIN_BUCKETS); + if (!new_ti) + return -ENOMEM; + new_ufid_ti = table_instance_alloc(TBL_MIN_BUCKETS); + if (!new_ufid_ti) + goto err_free_ti; + + old_ti = ovsl_dereference(flow_table->ti); + old_ufid_ti = ovsl_dereference(flow_table->ufid_ti); + + rcu_assign_pointer(flow_table->ti, new_ti); + rcu_assign_pointer(flow_table->ufid_ti, new_ufid_ti); + flow_table->last_rehash = jiffies; + + table_instance_flow_flush(flow_table, old_ti, old_ufid_ti); + table_instance_destroy(old_ti, old_ufid_ti); + return 0; + +err_free_ti: + __table_instance_destroy(new_ti); + return -ENOMEM; +} + +static u32 flow_hash(const struct sw_flow_key *key, + const struct sw_flow_key_range *range) +{ + const u32 *hash_key = (const u32 *)((const u8 *)key + range->start); + + /* Make sure number of hash bytes are multiple of u32. */ + int hash_u32s = range_n_bytes(range) >> 2; + + return jhash2(hash_key, hash_u32s, 0); +} + +static int flow_key_start(const struct sw_flow_key *key) +{ + if (key->tun_proto) + return 0; + else + return rounddown(offsetof(struct sw_flow_key, phy), + sizeof(long)); +} + +static bool cmp_key(const struct sw_flow_key *key1, + const struct sw_flow_key *key2, + int key_start, int key_end) +{ + const long *cp1 = (const long *)((const u8 *)key1 + key_start); + const long *cp2 = (const long *)((const u8 *)key2 + key_start); + int i; + + for (i = key_start; i < key_end; i += sizeof(long)) + if (*cp1++ ^ *cp2++) + return false; + + return true; +} + +static bool flow_cmp_masked_key(const struct sw_flow *flow, + const struct sw_flow_key *key, + const struct sw_flow_key_range *range) +{ + return cmp_key(&flow->key, key, range->start, range->end); +} + +static bool ovs_flow_cmp_unmasked_key(const struct sw_flow *flow, + const struct sw_flow_match *match) +{ + struct sw_flow_key *key = match->key; + int key_start = flow_key_start(key); + int key_end = match->range.end; + + BUG_ON(ovs_identifier_is_ufid(&flow->id)); + return cmp_key(flow->id.unmasked_key, key, key_start, key_end); +} + +static struct sw_flow *masked_flow_lookup(struct table_instance *ti, + const struct sw_flow_key *unmasked, + const struct sw_flow_mask *mask, + u32 *n_mask_hit) +{ + struct sw_flow *flow; + struct hlist_head *head; + u32 hash; + struct sw_flow_key masked_key; + + ovs_flow_mask_key(&masked_key, unmasked, false, mask); + hash = flow_hash(&masked_key, &mask->range); + head = find_bucket(ti, hash); + (*n_mask_hit)++; + + hlist_for_each_entry_rcu(flow, head, flow_table.node[ti->node_ver], + lockdep_ovsl_is_held()) { + if (flow->mask == mask && flow->flow_table.hash == hash && + flow_cmp_masked_key(flow, &masked_key, &mask->range)) + return flow; + } + return 
NULL; +} + +/* Flow lookup does full lookup on flow table. It starts with + * mask from index passed in *index. + * This function MUST be called with BH disabled due to the use + * of CPU specific variables. + */ +static struct sw_flow *flow_lookup(struct flow_table *tbl, + struct table_instance *ti, + struct mask_array *ma, + const struct sw_flow_key *key, + u32 *n_mask_hit, + u32 *n_cache_hit, + u32 *index) +{ + struct mask_array_stats *stats = this_cpu_ptr(ma->masks_usage_stats); + struct sw_flow *flow; + struct sw_flow_mask *mask; + int i; + + if (likely(*index < ma->max)) { + mask = rcu_dereference_ovsl(ma->masks[*index]); + if (mask) { + flow = masked_flow_lookup(ti, key, mask, n_mask_hit); + if (flow) { + u64_stats_update_begin(&stats->syncp); + stats->usage_cntrs[*index]++; + u64_stats_update_end(&stats->syncp); + (*n_cache_hit)++; + return flow; + } + } + } + + for (i = 0; i < ma->max; i++) { + + if (i == *index) + continue; + + mask = rcu_dereference_ovsl(ma->masks[i]); + if (unlikely(!mask)) + break; + + flow = masked_flow_lookup(ti, key, mask, n_mask_hit); + if (flow) { /* Found */ + *index = i; + u64_stats_update_begin(&stats->syncp); + stats->usage_cntrs[*index]++; + u64_stats_update_end(&stats->syncp); + return flow; + } + } + + return NULL; +} + +/* + * mask_cache maps flow to probable mask. This cache is not tightly + * coupled cache, It means updates to mask list can result in inconsistent + * cache entry in mask cache. + * This is per cpu cache and is divided in MC_HASH_SEGS segments. + * In case of a hash collision the entry is hashed in next segment. + * */ +struct sw_flow *ovs_flow_tbl_lookup_stats(struct flow_table *tbl, + const struct sw_flow_key *key, + u32 skb_hash, + u32 *n_mask_hit, + u32 *n_cache_hit) +{ + struct mask_cache *mc = rcu_dereference(tbl->mask_cache); + struct mask_array *ma = rcu_dereference(tbl->mask_array); + struct table_instance *ti = rcu_dereference(tbl->ti); + struct mask_cache_entry *entries, *ce; + struct sw_flow *flow; + u32 hash; + int seg; + + *n_mask_hit = 0; + *n_cache_hit = 0; + if (unlikely(!skb_hash || mc->cache_size == 0)) { + u32 mask_index = 0; + u32 cache = 0; + + return flow_lookup(tbl, ti, ma, key, n_mask_hit, &cache, + &mask_index); + } + + /* Pre and post recirulation flows usually have the same skb_hash + * value. To avoid hash collisions, rehash the 'skb_hash' with + * 'recirc_id'. */ + if (key->recirc_id) + skb_hash = jhash_1word(skb_hash, key->recirc_id); + + ce = NULL; + hash = skb_hash; + entries = this_cpu_ptr(mc->mask_cache); + + /* Find the cache entry 'ce' to operate on. */ + for (seg = 0; seg < MC_HASH_SEGS; seg++) { + int index = hash & (mc->cache_size - 1); + struct mask_cache_entry *e; + + e = &entries[index]; + if (e->skb_hash == skb_hash) { + flow = flow_lookup(tbl, ti, ma, key, n_mask_hit, + n_cache_hit, &e->mask_index); + if (!flow) + e->skb_hash = 0; + return flow; + } + + if (!ce || e->skb_hash < ce->skb_hash) + ce = e; /* A better replacement cache candidate. */ + + hash >>= MC_HASH_SHIFT; + } + + /* Cache miss, do full lookup. 
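On success the replacement entry 'ce' records the skb_hash (flow_lookup() already updated its mask_index), so later packets with the same hash probe that mask first.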
*/ + flow = flow_lookup(tbl, ti, ma, key, n_mask_hit, n_cache_hit, + &ce->mask_index); + if (flow) + ce->skb_hash = skb_hash; + + *n_cache_hit = 0; + return flow; +} + +struct sw_flow *ovs_flow_tbl_lookup(struct flow_table *tbl, + const struct sw_flow_key *key) +{ + struct table_instance *ti = rcu_dereference_ovsl(tbl->ti); + struct mask_array *ma = rcu_dereference_ovsl(tbl->mask_array); + u32 __always_unused n_mask_hit; + u32 __always_unused n_cache_hit; + struct sw_flow *flow; + u32 index = 0; + + /* This function gets called trough the netlink interface and therefore + * is preemptible. However, flow_lookup() function needs to be called + * with BH disabled due to CPU specific variables. + */ + local_bh_disable(); + flow = flow_lookup(tbl, ti, ma, key, &n_mask_hit, &n_cache_hit, &index); + local_bh_enable(); + return flow; +} + +struct sw_flow *ovs_flow_tbl_lookup_exact(struct flow_table *tbl, + const struct sw_flow_match *match) +{ + struct mask_array *ma = ovsl_dereference(tbl->mask_array); + int i; + + /* Always called under ovs-mutex. */ + for (i = 0; i < ma->max; i++) { + struct table_instance *ti = rcu_dereference_ovsl(tbl->ti); + u32 __always_unused n_mask_hit; + struct sw_flow_mask *mask; + struct sw_flow *flow; + + mask = ovsl_dereference(ma->masks[i]); + if (!mask) + continue; + + flow = masked_flow_lookup(ti, match->key, mask, &n_mask_hit); + if (flow && ovs_identifier_is_key(&flow->id) && + ovs_flow_cmp_unmasked_key(flow, match)) { + return flow; + } + } + + return NULL; +} + +static u32 ufid_hash(const struct sw_flow_id *sfid) +{ + return jhash(sfid->ufid, sfid->ufid_len, 0); +} + +static bool ovs_flow_cmp_ufid(const struct sw_flow *flow, + const struct sw_flow_id *sfid) +{ + if (flow->id.ufid_len != sfid->ufid_len) + return false; + + return !memcmp(flow->id.ufid, sfid->ufid, sfid->ufid_len); +} + +bool ovs_flow_cmp(const struct sw_flow *flow, + const struct sw_flow_match *match) +{ + if (ovs_identifier_is_ufid(&flow->id)) + return flow_cmp_masked_key(flow, match->key, &match->range); + + return ovs_flow_cmp_unmasked_key(flow, match); +} + +struct sw_flow *ovs_flow_tbl_lookup_ufid(struct flow_table *tbl, + const struct sw_flow_id *ufid) +{ + struct table_instance *ti = rcu_dereference_ovsl(tbl->ufid_ti); + struct sw_flow *flow; + struct hlist_head *head; + u32 hash; + + hash = ufid_hash(ufid); + head = find_bucket(ti, hash); + hlist_for_each_entry_rcu(flow, head, ufid_table.node[ti->node_ver], + lockdep_ovsl_is_held()) { + if (flow->ufid_table.hash == hash && + ovs_flow_cmp_ufid(flow, ufid)) + return flow; + } + return NULL; +} + +int ovs_flow_tbl_num_masks(const struct flow_table *table) +{ + struct mask_array *ma = rcu_dereference_ovsl(table->mask_array); + return READ_ONCE(ma->count); +} + +u32 ovs_flow_tbl_masks_cache_size(const struct flow_table *table) +{ + struct mask_cache *mc = rcu_dereference_ovsl(table->mask_cache); + + return READ_ONCE(mc->cache_size); +} + +static struct table_instance *table_instance_expand(struct table_instance *ti, + bool ufid) +{ + return table_instance_rehash(ti, ti->n_buckets * 2, ufid); +} + +/* Must be called with OVS mutex held. 
*/ +void ovs_flow_tbl_remove(struct flow_table *table, struct sw_flow *flow) +{ + struct table_instance *ti = ovsl_dereference(table->ti); + struct table_instance *ufid_ti = ovsl_dereference(table->ufid_ti); + + BUG_ON(table->count == 0); + table_instance_flow_free(table, ti, ufid_ti, flow); +} + +static struct sw_flow_mask *mask_alloc(void) +{ + struct sw_flow_mask *mask; + + mask = kmalloc(sizeof(*mask), GFP_KERNEL); + if (mask) + mask->ref_count = 1; + + return mask; +} + +static bool mask_equal(const struct sw_flow_mask *a, + const struct sw_flow_mask *b) +{ + const u8 *a_ = (const u8 *)&a->key + a->range.start; + const u8 *b_ = (const u8 *)&b->key + b->range.start; + + return (a->range.end == b->range.end) + && (a->range.start == b->range.start) + && (memcmp(a_, b_, range_n_bytes(&a->range)) == 0); +} + +static struct sw_flow_mask *flow_mask_find(const struct flow_table *tbl, + const struct sw_flow_mask *mask) +{ + struct mask_array *ma; + int i; + + ma = ovsl_dereference(tbl->mask_array); + for (i = 0; i < ma->max; i++) { + struct sw_flow_mask *t; + t = ovsl_dereference(ma->masks[i]); + + if (t && mask_equal(mask, t)) + return t; + } + + return NULL; +} + +/* Add 'mask' into the mask list, if it is not already there. */ +static int flow_mask_insert(struct flow_table *tbl, struct sw_flow *flow, + const struct sw_flow_mask *new) +{ + struct sw_flow_mask *mask; + + mask = flow_mask_find(tbl, new); + if (!mask) { + /* Allocate a new mask if none exsits. */ + mask = mask_alloc(); + if (!mask) + return -ENOMEM; + mask->key = new->key; + mask->range = new->range; + + /* Add mask to mask-list. */ + if (tbl_mask_array_add_mask(tbl, mask)) { + kfree(mask); + return -ENOMEM; + } + } else { + BUG_ON(!mask->ref_count); + mask->ref_count++; + } + + flow->mask = mask; + return 0; +} + +/* Must be called with OVS mutex held. */ +static void flow_key_insert(struct flow_table *table, struct sw_flow *flow) +{ + struct table_instance *new_ti = NULL; + struct table_instance *ti; + + flow->flow_table.hash = flow_hash(&flow->key, &flow->mask->range); + ti = ovsl_dereference(table->ti); + table_instance_insert(ti, flow); + table->count++; + + /* Expand table, if necessary, to make room. */ + if (table->count > ti->n_buckets) + new_ti = table_instance_expand(ti, false); + else if (time_after(jiffies, table->last_rehash + REHASH_INTERVAL)) + new_ti = table_instance_rehash(ti, ti->n_buckets, false); + + if (new_ti) { + rcu_assign_pointer(table->ti, new_ti); + call_rcu(&ti->rcu, flow_tbl_destroy_rcu_cb); + table->last_rehash = jiffies; + } +} + +/* Must be called with OVS mutex held. */ +static void flow_ufid_insert(struct flow_table *table, struct sw_flow *flow) +{ + struct table_instance *ti; + + flow->ufid_table.hash = ufid_hash(&flow->id); + ti = ovsl_dereference(table->ufid_ti); + ufid_table_instance_insert(ti, flow); + table->ufid_count++; + + /* Expand table, if necessary, to make room. */ + if (table->ufid_count > ti->n_buckets) { + struct table_instance *new_ti; + + new_ti = table_instance_expand(ti, true); + if (new_ti) { + rcu_assign_pointer(table->ufid_ti, new_ti); + call_rcu(&ti->rcu, flow_tbl_destroy_rcu_cb); + } + } +} + +/* Must be called with OVS mutex held. 
*/ +int ovs_flow_tbl_insert(struct flow_table *table, struct sw_flow *flow, + const struct sw_flow_mask *mask) +{ + int err; + + err = flow_mask_insert(table, flow, mask); + if (err) + return err; + flow_key_insert(table, flow); + if (ovs_identifier_is_ufid(&flow->id)) + flow_ufid_insert(table, flow); + + return 0; +} + +static int compare_mask_and_count(const void *a, const void *b) +{ + const struct mask_count *mc_a = a; + const struct mask_count *mc_b = b; + + return (s64)mc_b->counter - (s64)mc_a->counter; +} + +/* Must be called with OVS mutex held. */ +void ovs_flow_masks_rebalance(struct flow_table *table) +{ + struct mask_array *ma = rcu_dereference_ovsl(table->mask_array); + struct mask_count *masks_and_count; + struct mask_array *new; + int masks_entries = 0; + int i; + + /* Build array of all current entries with use counters. */ + masks_and_count = kmalloc_array(ma->max, sizeof(*masks_and_count), + GFP_KERNEL); + if (!masks_and_count) + return; + + for (i = 0; i < ma->max; i++) { + struct sw_flow_mask *mask; + int cpu; + + mask = rcu_dereference_ovsl(ma->masks[i]); + if (unlikely(!mask)) + break; + + masks_and_count[i].index = i; + masks_and_count[i].counter = 0; + + for_each_possible_cpu(cpu) { + struct mask_array_stats *stats; + unsigned int start; + u64 counter; + + stats = per_cpu_ptr(ma->masks_usage_stats, cpu); + do { + start = u64_stats_fetch_begin_irq(&stats->syncp); + counter = stats->usage_cntrs[i]; + } while (u64_stats_fetch_retry_irq(&stats->syncp, + start)); + + masks_and_count[i].counter += counter; + } + + /* Subtract the zero count value. */ + masks_and_count[i].counter -= ma->masks_usage_zero_cntr[i]; + + /* Rather than calling tbl_mask_array_reset_counters() + * below when no change is needed, do it inline here. + */ + ma->masks_usage_zero_cntr[i] += masks_and_count[i].counter; + } + + if (i == 0) + goto free_mask_entries; + + /* Sort the entries */ + masks_entries = i; + sort(masks_and_count, masks_entries, sizeof(*masks_and_count), + compare_mask_and_count, NULL); + + /* If the order is the same, nothing to do... */ + for (i = 0; i < masks_entries; i++) { + if (i != masks_and_count[i].index) + break; + } + if (i == masks_entries) + goto free_mask_entries; + + /* Rebuilt the new list in order of usage. */ + new = tbl_mask_array_alloc(ma->max); + if (!new) + goto free_mask_entries; + + for (i = 0; i < masks_entries; i++) { + int index = masks_and_count[i].index; + + if (ovsl_dereference(ma->masks[index])) + new->masks[new->count++] = ma->masks[index]; + } + + rcu_assign_pointer(table->mask_array, new); + call_rcu(&ma->rcu, mask_array_rcu_cb); + +free_mask_entries: + kfree(masks_and_count); +} + +/* Initializes the flow module. + * Returns zero if successful or a negative error code. */ +int ovs_flow_init(void) +{ + BUILD_BUG_ON(__alignof__(struct sw_flow_key) % __alignof__(long)); + BUILD_BUG_ON(sizeof(struct sw_flow_key) % sizeof(long)); + + flow_cache = kmem_cache_create("sw_flow", sizeof(struct sw_flow) + + (nr_cpu_ids + * sizeof(struct sw_flow_stats *)), + 0, 0, NULL); + if (flow_cache == NULL) + return -ENOMEM; + + flow_stats_cache + = kmem_cache_create("sw_flow_stats", sizeof(struct sw_flow_stats), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (flow_stats_cache == NULL) { + kmem_cache_destroy(flow_cache); + flow_cache = NULL; + return -ENOMEM; + } + + return 0; +} + +/* Uninitializes the flow module. 
*/ +void ovs_flow_exit(void) +{ + kmem_cache_destroy(flow_stats_cache); + kmem_cache_destroy(flow_cache); +} diff --git a/net/openvswitch/flow_table.h b/net/openvswitch/flow_table.h new file mode 100644 index 000000000..9e659db78 --- /dev/null +++ b/net/openvswitch/flow_table.h @@ -0,0 +1,115 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2013 Nicira, Inc. + */ + +#ifndef FLOW_TABLE_H +#define FLOW_TABLE_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "flow.h" + +struct mask_cache_entry { + u32 skb_hash; + u32 mask_index; +}; + +struct mask_cache { + struct rcu_head rcu; + u32 cache_size; /* Must be ^2 value. */ + struct mask_cache_entry __percpu *mask_cache; +}; + +struct mask_count { + int index; + u64 counter; +}; + +struct mask_array_stats { + struct u64_stats_sync syncp; + u64 usage_cntrs[]; +}; + +struct mask_array { + struct rcu_head rcu; + int count, max; + struct mask_array_stats __percpu *masks_usage_stats; + u64 *masks_usage_zero_cntr; + struct sw_flow_mask __rcu *masks[]; +}; + +struct table_instance { + struct hlist_head *buckets; + unsigned int n_buckets; + struct rcu_head rcu; + int node_ver; + u32 hash_seed; +}; + +struct flow_table { + struct table_instance __rcu *ti; + struct table_instance __rcu *ufid_ti; + struct mask_cache __rcu *mask_cache; + struct mask_array __rcu *mask_array; + unsigned long last_rehash; + unsigned int count; + unsigned int ufid_count; +}; + +extern struct kmem_cache *flow_stats_cache; + +int ovs_flow_init(void); +void ovs_flow_exit(void); + +struct sw_flow *ovs_flow_alloc(void); +void ovs_flow_free(struct sw_flow *, bool deferred); + +int ovs_flow_tbl_init(struct flow_table *); +int ovs_flow_tbl_count(const struct flow_table *table); +void ovs_flow_tbl_destroy(struct flow_table *table); +int ovs_flow_tbl_flush(struct flow_table *flow_table); + +int ovs_flow_tbl_insert(struct flow_table *table, struct sw_flow *flow, + const struct sw_flow_mask *mask); +void ovs_flow_tbl_remove(struct flow_table *table, struct sw_flow *flow); +int ovs_flow_tbl_num_masks(const struct flow_table *table); +u32 ovs_flow_tbl_masks_cache_size(const struct flow_table *table); +int ovs_flow_tbl_masks_cache_resize(struct flow_table *table, u32 size); +struct sw_flow *ovs_flow_tbl_dump_next(struct table_instance *table, + u32 *bucket, u32 *idx); +struct sw_flow *ovs_flow_tbl_lookup_stats(struct flow_table *, + const struct sw_flow_key *, + u32 skb_hash, + u32 *n_mask_hit, + u32 *n_cache_hit); +struct sw_flow *ovs_flow_tbl_lookup(struct flow_table *, + const struct sw_flow_key *); +struct sw_flow *ovs_flow_tbl_lookup_exact(struct flow_table *tbl, + const struct sw_flow_match *match); +struct sw_flow *ovs_flow_tbl_lookup_ufid(struct flow_table *, + const struct sw_flow_id *); + +bool ovs_flow_cmp(const struct sw_flow *, const struct sw_flow_match *); + +void ovs_flow_mask_key(struct sw_flow_key *dst, const struct sw_flow_key *src, + bool full, const struct sw_flow_mask *mask); + +void ovs_flow_masks_rebalance(struct flow_table *table); +void table_instance_flow_flush(struct flow_table *table, + struct table_instance *ti, + struct table_instance *ufid_ti); + +#endif /* flow_table.h */ diff --git a/net/openvswitch/meter.c b/net/openvswitch/meter.c new file mode 100644 index 000000000..f2698d231 --- /dev/null +++ b/net/openvswitch/meter.c @@ -0,0 +1,768 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2017 Nicira, Inc. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "datapath.h" +#include "meter.h" + +static const struct nla_policy meter_policy[OVS_METER_ATTR_MAX + 1] = { + [OVS_METER_ATTR_ID] = { .type = NLA_U32, }, + [OVS_METER_ATTR_KBPS] = { .type = NLA_FLAG }, + [OVS_METER_ATTR_STATS] = { .len = sizeof(struct ovs_flow_stats) }, + [OVS_METER_ATTR_BANDS] = { .type = NLA_NESTED }, + [OVS_METER_ATTR_USED] = { .type = NLA_U64 }, + [OVS_METER_ATTR_CLEAR] = { .type = NLA_FLAG }, + [OVS_METER_ATTR_MAX_METERS] = { .type = NLA_U32 }, + [OVS_METER_ATTR_MAX_BANDS] = { .type = NLA_U32 }, +}; + +static const struct nla_policy band_policy[OVS_BAND_ATTR_MAX + 1] = { + [OVS_BAND_ATTR_TYPE] = { .type = NLA_U32, }, + [OVS_BAND_ATTR_RATE] = { .type = NLA_U32, }, + [OVS_BAND_ATTR_BURST] = { .type = NLA_U32, }, + [OVS_BAND_ATTR_STATS] = { .len = sizeof(struct ovs_flow_stats) }, +}; + +static u32 meter_hash(struct dp_meter_instance *ti, u32 id) +{ + return id % ti->n_meters; +} + +static void ovs_meter_free(struct dp_meter *meter) +{ + if (!meter) + return; + + kfree_rcu(meter, rcu); +} + +/* Call with ovs_mutex or RCU read lock. */ +static struct dp_meter *lookup_meter(const struct dp_meter_table *tbl, + u32 meter_id) +{ + struct dp_meter_instance *ti = rcu_dereference_ovsl(tbl->ti); + u32 hash = meter_hash(ti, meter_id); + struct dp_meter *meter; + + meter = rcu_dereference_ovsl(ti->dp_meters[hash]); + if (meter && likely(meter->id == meter_id)) + return meter; + + return NULL; +} + +static struct dp_meter_instance *dp_meter_instance_alloc(const u32 size) +{ + struct dp_meter_instance *ti; + + ti = kvzalloc(sizeof(*ti) + + sizeof(struct dp_meter *) * size, + GFP_KERNEL); + if (!ti) + return NULL; + + ti->n_meters = size; + + return ti; +} + +static void dp_meter_instance_free(struct dp_meter_instance *ti) +{ + kvfree(ti); +} + +static void dp_meter_instance_free_rcu(struct rcu_head *rcu) +{ + struct dp_meter_instance *ti; + + ti = container_of(rcu, struct dp_meter_instance, rcu); + kvfree(ti); +} + +static int +dp_meter_instance_realloc(struct dp_meter_table *tbl, u32 size) +{ + struct dp_meter_instance *ti = rcu_dereference_ovsl(tbl->ti); + int n_meters = min(size, ti->n_meters); + struct dp_meter_instance *new_ti; + int i; + + new_ti = dp_meter_instance_alloc(size); + if (!new_ti) + return -ENOMEM; + + for (i = 0; i < n_meters; i++) + if (rcu_dereference_ovsl(ti->dp_meters[i])) + new_ti->dp_meters[i] = ti->dp_meters[i]; + + rcu_assign_pointer(tbl->ti, new_ti); + call_rcu(&ti->rcu, dp_meter_instance_free_rcu); + + return 0; +} + +static void dp_meter_instance_insert(struct dp_meter_instance *ti, + struct dp_meter *meter) +{ + u32 hash; + + hash = meter_hash(ti, meter->id); + rcu_assign_pointer(ti->dp_meters[hash], meter); +} + +static void dp_meter_instance_remove(struct dp_meter_instance *ti, + struct dp_meter *meter) +{ + u32 hash; + + hash = meter_hash(ti, meter->id); + RCU_INIT_POINTER(ti->dp_meters[hash], NULL); +} + +static int attach_meter(struct dp_meter_table *tbl, struct dp_meter *meter) +{ + struct dp_meter_instance *ti = rcu_dereference_ovsl(tbl->ti); + u32 hash = meter_hash(ti, meter->id); + int err; + + /* In generally, slots selected should be empty, because + * OvS uses id-pool to fetch a available id. + */ + if (unlikely(rcu_dereference_ovsl(ti->dp_meters[hash]))) + return -EBUSY; + + dp_meter_instance_insert(ti, meter); + + /* That function is thread-safe. 
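+ * The caller holds ovs_mutex (see ovs_meter_cmd_set()), so the unlocked
+ * counter update below is safe.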
*/ + tbl->count++; + if (tbl->count >= tbl->max_meters_allowed) { + err = -EFBIG; + goto attach_err; + } + + if (tbl->count >= ti->n_meters && + dp_meter_instance_realloc(tbl, ti->n_meters * 2)) { + err = -ENOMEM; + goto attach_err; + } + + return 0; + +attach_err: + dp_meter_instance_remove(ti, meter); + tbl->count--; + return err; +} + +static int detach_meter(struct dp_meter_table *tbl, struct dp_meter *meter) +{ + struct dp_meter_instance *ti; + + ASSERT_OVSL(); + if (!meter) + return 0; + + ti = rcu_dereference_ovsl(tbl->ti); + dp_meter_instance_remove(ti, meter); + + tbl->count--; + + /* Shrink the meter array if necessary. */ + if (ti->n_meters > DP_METER_ARRAY_SIZE_MIN && + tbl->count <= (ti->n_meters / 4)) { + int half_size = ti->n_meters / 2; + int i; + + /* Avoid hash collision, don't move slots to other place. + * Make sure there are no references of meters in array + * which will be released. + */ + for (i = half_size; i < ti->n_meters; i++) + if (rcu_dereference_ovsl(ti->dp_meters[i])) + goto out; + + if (dp_meter_instance_realloc(tbl, half_size)) + goto shrink_err; + } + +out: + return 0; + +shrink_err: + dp_meter_instance_insert(ti, meter); + tbl->count++; + return -ENOMEM; +} + +static struct sk_buff * +ovs_meter_cmd_reply_start(struct genl_info *info, u8 cmd, + struct ovs_header **ovs_reply_header) +{ + struct sk_buff *skb; + struct ovs_header *ovs_header = info->userhdr; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return ERR_PTR(-ENOMEM); + + *ovs_reply_header = genlmsg_put(skb, info->snd_portid, + info->snd_seq, + &dp_meter_genl_family, 0, cmd); + if (!*ovs_reply_header) { + nlmsg_free(skb); + return ERR_PTR(-EMSGSIZE); + } + (*ovs_reply_header)->dp_ifindex = ovs_header->dp_ifindex; + + return skb; +} + +static int ovs_meter_cmd_reply_stats(struct sk_buff *reply, u32 meter_id, + struct dp_meter *meter) +{ + struct nlattr *nla; + struct dp_meter_band *band; + u16 i; + + if (nla_put_u32(reply, OVS_METER_ATTR_ID, meter_id)) + goto error; + + if (nla_put(reply, OVS_METER_ATTR_STATS, + sizeof(struct ovs_flow_stats), &meter->stats)) + goto error; + + if (nla_put_u64_64bit(reply, OVS_METER_ATTR_USED, meter->used, + OVS_METER_ATTR_PAD)) + goto error; + + nla = nla_nest_start_noflag(reply, OVS_METER_ATTR_BANDS); + if (!nla) + goto error; + + band = meter->bands; + + for (i = 0; i < meter->n_bands; ++i, ++band) { + struct nlattr *band_nla; + + band_nla = nla_nest_start_noflag(reply, OVS_BAND_ATTR_UNSPEC); + if (!band_nla || nla_put(reply, OVS_BAND_ATTR_STATS, + sizeof(struct ovs_flow_stats), + &band->stats)) + goto error; + nla_nest_end(reply, band_nla); + } + nla_nest_end(reply, nla); + + return 0; +error: + return -EMSGSIZE; +} + +static int ovs_meter_cmd_features(struct sk_buff *skb, struct genl_info *info) +{ + struct ovs_header *ovs_header = info->userhdr; + struct ovs_header *ovs_reply_header; + struct nlattr *nla, *band_nla; + struct sk_buff *reply; + struct datapath *dp; + int err = -EMSGSIZE; + + reply = ovs_meter_cmd_reply_start(info, OVS_METER_CMD_FEATURES, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + ovs_lock(); + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + err = -ENODEV; + goto exit_unlock; + } + + if (nla_put_u32(reply, OVS_METER_ATTR_MAX_METERS, + dp->meter_tbl.max_meters_allowed)) + goto exit_unlock; + + ovs_unlock(); + + if (nla_put_u32(reply, OVS_METER_ATTR_MAX_BANDS, DP_MAX_BANDS)) + goto nla_put_failure; + + nla = nla_nest_start_noflag(reply, OVS_METER_ATTR_BANDS); + if (!nla) + goto 
nla_put_failure; + + band_nla = nla_nest_start_noflag(reply, OVS_BAND_ATTR_UNSPEC); + if (!band_nla) + goto nla_put_failure; + /* Currently only DROP band type is supported. */ + if (nla_put_u32(reply, OVS_BAND_ATTR_TYPE, OVS_METER_BAND_TYPE_DROP)) + goto nla_put_failure; + nla_nest_end(reply, band_nla); + nla_nest_end(reply, nla); + + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_unlock: + ovs_unlock(); +nla_put_failure: + nlmsg_free(reply); + return err; +} + +static struct dp_meter *dp_meter_create(struct nlattr **a) +{ + struct nlattr *nla; + int rem; + u16 n_bands = 0; + struct dp_meter *meter; + struct dp_meter_band *band; + int err; + + /* Validate attributes, count the bands. */ + if (!a[OVS_METER_ATTR_BANDS]) + return ERR_PTR(-EINVAL); + + nla_for_each_nested(nla, a[OVS_METER_ATTR_BANDS], rem) + if (++n_bands > DP_MAX_BANDS) + return ERR_PTR(-EINVAL); + + /* Allocate and set up the meter before locking anything. */ + meter = kzalloc(struct_size(meter, bands, n_bands), GFP_KERNEL_ACCOUNT); + if (!meter) + return ERR_PTR(-ENOMEM); + + meter->id = nla_get_u32(a[OVS_METER_ATTR_ID]); + meter->used = div_u64(ktime_get_ns(), 1000 * 1000); + meter->kbps = a[OVS_METER_ATTR_KBPS] ? 1 : 0; + meter->keep_stats = !a[OVS_METER_ATTR_CLEAR]; + spin_lock_init(&meter->lock); + if (meter->keep_stats && a[OVS_METER_ATTR_STATS]) { + meter->stats = *(struct ovs_flow_stats *) + nla_data(a[OVS_METER_ATTR_STATS]); + } + meter->n_bands = n_bands; + + /* Set up meter bands. */ + band = meter->bands; + nla_for_each_nested(nla, a[OVS_METER_ATTR_BANDS], rem) { + struct nlattr *attr[OVS_BAND_ATTR_MAX + 1]; + u32 band_max_delta_t; + + err = nla_parse_deprecated((struct nlattr **)&attr, + OVS_BAND_ATTR_MAX, nla_data(nla), + nla_len(nla), band_policy, NULL); + if (err) + goto exit_free_meter; + + if (!attr[OVS_BAND_ATTR_TYPE] || + !attr[OVS_BAND_ATTR_RATE] || + !attr[OVS_BAND_ATTR_BURST]) { + err = -EINVAL; + goto exit_free_meter; + } + + band->type = nla_get_u32(attr[OVS_BAND_ATTR_TYPE]); + band->rate = nla_get_u32(attr[OVS_BAND_ATTR_RATE]); + if (band->rate == 0) { + err = -EINVAL; + goto exit_free_meter; + } + + band->burst_size = nla_get_u32(attr[OVS_BAND_ATTR_BURST]); + /* Figure out max delta_t that is enough to fill any bucket. + * Keep max_delta_t size to the bucket units: + * pkts => 1/1000 packets, kilobits => bits. + * + * Start with a full bucket. 
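+ *
+ * Illustrative example (hypothetical values): burst_size = 1000 and
+ * rate = 500 give bucket = 1,000,000 units and band_max_delta_t =
+ * 1,000,000 / 500 = 2000, i.e. a 2 second idle period refills the band.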
+ */ + band->bucket = band->burst_size * 1000ULL; + band_max_delta_t = div_u64(band->bucket, band->rate); + if (band_max_delta_t > meter->max_delta_t) + meter->max_delta_t = band_max_delta_t; + band++; + } + + return meter; + +exit_free_meter: + kfree(meter); + return ERR_PTR(err); +} + +static int ovs_meter_cmd_set(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr **a = info->attrs; + struct dp_meter *meter, *old_meter; + struct sk_buff *reply; + struct ovs_header *ovs_reply_header; + struct ovs_header *ovs_header = info->userhdr; + struct dp_meter_table *meter_tbl; + struct datapath *dp; + int err; + u32 meter_id; + bool failed; + + if (!a[OVS_METER_ATTR_ID]) + return -EINVAL; + + meter = dp_meter_create(a); + if (IS_ERR(meter)) + return PTR_ERR(meter); + + reply = ovs_meter_cmd_reply_start(info, OVS_METER_CMD_SET, + &ovs_reply_header); + if (IS_ERR(reply)) { + err = PTR_ERR(reply); + goto exit_free_meter; + } + + ovs_lock(); + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + err = -ENODEV; + goto exit_unlock; + } + + meter_tbl = &dp->meter_tbl; + meter_id = nla_get_u32(a[OVS_METER_ATTR_ID]); + + old_meter = lookup_meter(meter_tbl, meter_id); + err = detach_meter(meter_tbl, old_meter); + if (err) + goto exit_unlock; + + err = attach_meter(meter_tbl, meter); + if (err) + goto exit_free_old_meter; + + ovs_unlock(); + + /* Build response with the meter_id and stats from + * the old meter, if any. + */ + failed = nla_put_u32(reply, OVS_METER_ATTR_ID, meter_id); + WARN_ON(failed); + if (old_meter) { + spin_lock_bh(&old_meter->lock); + if (old_meter->keep_stats) { + err = ovs_meter_cmd_reply_stats(reply, meter_id, + old_meter); + WARN_ON(err); + } + spin_unlock_bh(&old_meter->lock); + ovs_meter_free(old_meter); + } + + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_free_old_meter: + ovs_meter_free(old_meter); +exit_unlock: + ovs_unlock(); + nlmsg_free(reply); +exit_free_meter: + kfree(meter); + return err; +} + +static int ovs_meter_cmd_get(struct sk_buff *skb, struct genl_info *info) +{ + struct ovs_header *ovs_header = info->userhdr; + struct ovs_header *ovs_reply_header; + struct nlattr **a = info->attrs; + struct dp_meter *meter; + struct sk_buff *reply; + struct datapath *dp; + u32 meter_id; + int err; + + if (!a[OVS_METER_ATTR_ID]) + return -EINVAL; + + meter_id = nla_get_u32(a[OVS_METER_ATTR_ID]); + + reply = ovs_meter_cmd_reply_start(info, OVS_METER_CMD_GET, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + ovs_lock(); + + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + err = -ENODEV; + goto exit_unlock; + } + + /* Locate meter, copy stats. 
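+ * The stats are copied under the meter's spinlock so the reply carries a
+ * consistent snapshot.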
*/ + meter = lookup_meter(&dp->meter_tbl, meter_id); + if (!meter) { + err = -ENOENT; + goto exit_unlock; + } + + spin_lock_bh(&meter->lock); + err = ovs_meter_cmd_reply_stats(reply, meter_id, meter); + spin_unlock_bh(&meter->lock); + if (err) + goto exit_unlock; + + ovs_unlock(); + + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_unlock: + ovs_unlock(); + nlmsg_free(reply); + return err; +} + +static int ovs_meter_cmd_del(struct sk_buff *skb, struct genl_info *info) +{ + struct ovs_header *ovs_header = info->userhdr; + struct ovs_header *ovs_reply_header; + struct nlattr **a = info->attrs; + struct dp_meter *old_meter; + struct sk_buff *reply; + struct datapath *dp; + u32 meter_id; + int err; + + if (!a[OVS_METER_ATTR_ID]) + return -EINVAL; + + reply = ovs_meter_cmd_reply_start(info, OVS_METER_CMD_DEL, + &ovs_reply_header); + if (IS_ERR(reply)) + return PTR_ERR(reply); + + ovs_lock(); + + dp = get_dp(sock_net(skb->sk), ovs_header->dp_ifindex); + if (!dp) { + err = -ENODEV; + goto exit_unlock; + } + + meter_id = nla_get_u32(a[OVS_METER_ATTR_ID]); + old_meter = lookup_meter(&dp->meter_tbl, meter_id); + if (old_meter) { + spin_lock_bh(&old_meter->lock); + err = ovs_meter_cmd_reply_stats(reply, meter_id, old_meter); + WARN_ON(err); + spin_unlock_bh(&old_meter->lock); + + err = detach_meter(&dp->meter_tbl, old_meter); + if (err) + goto exit_unlock; + } + + ovs_unlock(); + ovs_meter_free(old_meter); + genlmsg_end(reply, ovs_reply_header); + return genlmsg_reply(reply, info); + +exit_unlock: + ovs_unlock(); + nlmsg_free(reply); + return err; +} + +/* Meter action execution. + * + * Return true 'meter_id' drop band is triggered. The 'skb' should be + * dropped by the caller'. + */ +bool ovs_meter_execute(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, u32 meter_id) +{ + long long int now_ms = div_u64(ktime_get_ns(), 1000 * 1000); + long long int long_delta_ms; + struct dp_meter_band *band; + struct dp_meter *meter; + int i, band_exceeded_max = -1; + u32 band_exceeded_rate = 0; + u32 delta_ms; + u32 cost; + + meter = lookup_meter(&dp->meter_tbl, meter_id); + /* Do not drop the packet when there is no meter. */ + if (!meter) + return false; + + /* Lock the meter while using it. */ + spin_lock(&meter->lock); + + long_delta_ms = (now_ms - meter->used); /* ms */ + if (long_delta_ms < 0) { + /* This condition means that we have several threads fighting + * for a meter lock, and the one who received the packets a + * bit later wins. Assuming that all racing threads received + * packets at the same time to avoid overflow. + */ + long_delta_ms = 0; + } + + /* Make sure delta_ms will not be too large, so that bucket will not + * wrap around below. + */ + delta_ms = (long_delta_ms > (long long int)meter->max_delta_t) + ? meter->max_delta_t : (u32)long_delta_ms; + + /* Update meter statistics. + */ + meter->used = now_ms; + meter->stats.n_packets += 1; + meter->stats.n_bytes += skb->len; + + /* Bucket rate is either in kilobits per second, or in packets per + * second. We maintain the bucket in the units of either bits or + * 1/1000th of a packet, correspondingly. + * Then, when rate is multiplied with milliseconds, we get the + * bucket units: + * msec * kbps = bits, and + * msec * packets/sec = 1/1000 packets. + * + * 'cost' is the number of bucket units in this packet. + */ + cost = (meter->kbps) ? skb->len * 8 : 1000; + + /* Update all bands and find the one hit with the highest rate. 
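+ *
+ * Worked example (hypothetical values) for a kbps meter: a 1500 byte
+ * packet costs 1500 * 8 = 12000 bits. A band with rate = 1000 kbps
+ * regains 1000 bits per elapsed millisecond, so back-to-back packets
+ * (small delta_ms) drain the bucket by roughly 11000 bits each until it
+ * can no longer cover the cost and the band is hit.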
*/ + for (i = 0; i < meter->n_bands; ++i) { + long long int max_bucket_size; + + band = &meter->bands[i]; + max_bucket_size = band->burst_size * 1000LL; + + band->bucket += delta_ms * band->rate; + if (band->bucket > max_bucket_size) + band->bucket = max_bucket_size; + + if (band->bucket >= cost) { + band->bucket -= cost; + } else if (band->rate > band_exceeded_rate) { + band_exceeded_rate = band->rate; + band_exceeded_max = i; + } + } + + if (band_exceeded_max >= 0) { + /* Update band statistics. */ + band = &meter->bands[band_exceeded_max]; + band->stats.n_packets += 1; + band->stats.n_bytes += skb->len; + + /* Drop band triggered, let the caller drop the 'skb'. */ + if (band->type == OVS_METER_BAND_TYPE_DROP) { + spin_unlock(&meter->lock); + return true; + } + } + + spin_unlock(&meter->lock); + return false; +} + +static const struct genl_small_ops dp_meter_genl_ops[] = { + { .cmd = OVS_METER_CMD_FEATURES, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. */ + .doit = ovs_meter_cmd_features + }, + { .cmd = OVS_METER_CMD_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN + * privilege. + */ + .doit = ovs_meter_cmd_set, + }, + { .cmd = OVS_METER_CMD_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = 0, /* OK for unprivileged users. */ + .doit = ovs_meter_cmd_get, + }, + { .cmd = OVS_METER_CMD_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, /* Requires CAP_NET_ADMIN + * privilege. + */ + .doit = ovs_meter_cmd_del + }, +}; + +static const struct genl_multicast_group ovs_meter_multicast_group = { + .name = OVS_METER_MCGROUP, +}; + +struct genl_family dp_meter_genl_family __ro_after_init = { + .hdrsize = sizeof(struct ovs_header), + .name = OVS_METER_FAMILY, + .version = OVS_METER_VERSION, + .maxattr = OVS_METER_ATTR_MAX, + .policy = meter_policy, + .netnsok = true, + .parallel_ops = true, + .small_ops = dp_meter_genl_ops, + .n_small_ops = ARRAY_SIZE(dp_meter_genl_ops), + .resv_start_op = OVS_METER_CMD_GET + 1, + .mcgrps = &ovs_meter_multicast_group, + .n_mcgrps = 1, + .module = THIS_MODULE, +}; + +int ovs_meters_init(struct datapath *dp) +{ + struct dp_meter_table *tbl = &dp->meter_tbl; + struct dp_meter_instance *ti; + unsigned long free_mem_bytes; + + ti = dp_meter_instance_alloc(DP_METER_ARRAY_SIZE_MIN); + if (!ti) + return -ENOMEM; + + /* Allow meters in a datapath to use ~3.12% of physical memory. */ + free_mem_bytes = nr_free_buffer_pages() * (PAGE_SIZE >> 5); + tbl->max_meters_allowed = min(free_mem_bytes / sizeof(struct dp_meter), + DP_METER_NUM_MAX); + if (!tbl->max_meters_allowed) + goto out_err; + + rcu_assign_pointer(tbl->ti, ti); + tbl->count = 0; + + return 0; + +out_err: + dp_meter_instance_free(ti); + return -ENOMEM; +} + +void ovs_meters_exit(struct datapath *dp) +{ + struct dp_meter_table *tbl = &dp->meter_tbl; + struct dp_meter_instance *ti = rcu_dereference_raw(tbl->ti); + int i; + + for (i = 0; i < ti->n_meters; i++) + ovs_meter_free(rcu_dereference_raw(ti->dp_meters[i])); + + dp_meter_instance_free(ti); +} diff --git a/net/openvswitch/meter.h b/net/openvswitch/meter.h new file mode 100644 index 000000000..0c33889a8 --- /dev/null +++ b/net/openvswitch/meter.h @@ -0,0 +1,63 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2017 Nicira, Inc. 
+ */ + +#ifndef METER_H +#define METER_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "flow.h" +struct datapath; + +#define DP_MAX_BANDS 1 +#define DP_METER_ARRAY_SIZE_MIN BIT_ULL(10) +#define DP_METER_NUM_MAX (200000UL) + +struct dp_meter_band { + u32 type; + u32 rate; + u32 burst_size; + u64 bucket; /* 1/1000 packets, or in bits */ + struct ovs_flow_stats stats; +}; + +struct dp_meter { + spinlock_t lock; /* Per meter lock */ + struct rcu_head rcu; + u32 id; + u16 kbps:1, keep_stats:1; + u16 n_bands; + u32 max_delta_t; + u64 used; + struct ovs_flow_stats stats; + struct dp_meter_band bands[]; +}; + +struct dp_meter_instance { + struct rcu_head rcu; + u32 n_meters; + struct dp_meter __rcu *dp_meters[]; +}; + +struct dp_meter_table { + struct dp_meter_instance __rcu *ti; + u32 count; + u32 max_meters_allowed; +}; + +extern struct genl_family dp_meter_genl_family; +int ovs_meters_init(struct datapath *dp); +void ovs_meters_exit(struct datapath *dp); +bool ovs_meter_execute(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, u32 meter_id); + +#endif /* meter.h */ diff --git a/net/openvswitch/openvswitch_trace.c b/net/openvswitch/openvswitch_trace.c new file mode 100644 index 000000000..62c5f7d6f --- /dev/null +++ b/net/openvswitch/openvswitch_trace.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: GPL-2.0 +/* bug in tracepoint.h, it should include this */ +#include + +/* sparse isn't too happy with all macros... */ +#ifndef __CHECKER__ +#define CREATE_TRACE_POINTS +#include "openvswitch_trace.h" + +#endif diff --git a/net/openvswitch/openvswitch_trace.h b/net/openvswitch/openvswitch_trace.h new file mode 100644 index 000000000..3eb35d9eb --- /dev/null +++ b/net/openvswitch/openvswitch_trace.h @@ -0,0 +1,158 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#undef TRACE_SYSTEM +#define TRACE_SYSTEM openvswitch + +#if !defined(_TRACE_OPENVSWITCH_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_OPENVSWITCH_H + +#include + +#include "datapath.h" + +TRACE_EVENT(ovs_do_execute_action, + + TP_PROTO(struct datapath *dp, struct sk_buff *skb, + struct sw_flow_key *key, const struct nlattr *a, int rem), + + TP_ARGS(dp, skb, key, a, rem), + + TP_STRUCT__entry( + __field( void *, dpaddr ) + __string( dp_name, ovs_dp_name(dp) ) + __string( dev_name, skb->dev->name ) + __field( void *, skbaddr ) + __field( unsigned int, len ) + __field( unsigned int, data_len ) + __field( unsigned int, truesize ) + __field( u8, nr_frags ) + __field( u16, gso_size ) + __field( u16, gso_type ) + __field( u32, ovs_flow_hash ) + __field( u32, recirc_id ) + __field( void *, keyaddr ) + __field( u16, key_eth_type ) + __field( u8, key_ct_state ) + __field( u8, key_ct_orig_proto ) + __field( u16, key_ct_zone ) + __field( unsigned int, flow_key_valid ) + __field( u8, action_type ) + __field( unsigned int, action_len ) + __field( void *, action_data ) + __field( u8, is_last ) + ), + + TP_fast_assign( + __entry->dpaddr = dp; + __assign_str(dp_name, ovs_dp_name(dp)); + __assign_str(dev_name, skb->dev->name); + __entry->skbaddr = skb; + __entry->len = skb->len; + __entry->data_len = skb->data_len; + __entry->truesize = skb->truesize; + __entry->nr_frags = skb_shinfo(skb)->nr_frags; + __entry->gso_size = skb_shinfo(skb)->gso_size; + __entry->gso_type = skb_shinfo(skb)->gso_type; + __entry->ovs_flow_hash = key->ovs_flow_hash; + __entry->recirc_id = key->recirc_id; + __entry->keyaddr = key; + __entry->key_eth_type = key->eth.type; + __entry->key_ct_state = key->ct_state; + 
__entry->key_ct_orig_proto = key->ct_orig_proto; + __entry->key_ct_zone = key->ct_zone; + __entry->flow_key_valid = !(key->mac_proto & SW_FLOW_KEY_INVALID); + __entry->action_type = nla_type(a); + __entry->action_len = nla_len(a); + __entry->action_data = nla_data(a); + __entry->is_last = nla_is_last(a, rem); + ), + + TP_printk("dpaddr=%p dp_name=%s dev=%s skbaddr=%p len=%u data_len=%u truesize=%u nr_frags=%d gso_size=%d gso_type=%#x ovs_flow_hash=0x%08x recirc_id=0x%08x keyaddr=%p eth_type=0x%04x ct_state=%02x ct_orig_proto=%02x ct_Zone=%04x flow_key_valid=%d action_type=%u action_len=%u action_data=%p is_last=%d", + __entry->dpaddr, __get_str(dp_name), __get_str(dev_name), + __entry->skbaddr, __entry->len, __entry->data_len, + __entry->truesize, __entry->nr_frags, __entry->gso_size, + __entry->gso_type, __entry->ovs_flow_hash, + __entry->recirc_id, __entry->keyaddr, __entry->key_eth_type, + __entry->key_ct_state, __entry->key_ct_orig_proto, + __entry->key_ct_zone, + __entry->flow_key_valid, + __entry->action_type, __entry->action_len, + __entry->action_data, __entry->is_last) +); + +TRACE_EVENT(ovs_dp_upcall, + + TP_PROTO(struct datapath *dp, struct sk_buff *skb, + const struct sw_flow_key *key, + const struct dp_upcall_info *upcall_info), + + TP_ARGS(dp, skb, key, upcall_info), + + TP_STRUCT__entry( + __field( void *, dpaddr ) + __string( dp_name, ovs_dp_name(dp) ) + __string( dev_name, skb->dev->name ) + __field( void *, skbaddr ) + __field( unsigned int, len ) + __field( unsigned int, data_len ) + __field( unsigned int, truesize ) + __field( u8, nr_frags ) + __field( u16, gso_size ) + __field( u16, gso_type ) + __field( u32, ovs_flow_hash ) + __field( u32, recirc_id ) + __field( const void *, keyaddr ) + __field( u16, key_eth_type ) + __field( u8, key_ct_state ) + __field( u8, key_ct_orig_proto ) + __field( u16, key_ct_zone ) + __field( unsigned int, flow_key_valid ) + __field( u8, upcall_cmd ) + __field( u32, upcall_port ) + __field( u16, upcall_mru ) + ), + + TP_fast_assign( + __entry->dpaddr = dp; + __assign_str(dp_name, ovs_dp_name(dp)); + __assign_str(dev_name, skb->dev->name); + __entry->skbaddr = skb; + __entry->len = skb->len; + __entry->data_len = skb->data_len; + __entry->truesize = skb->truesize; + __entry->nr_frags = skb_shinfo(skb)->nr_frags; + __entry->gso_size = skb_shinfo(skb)->gso_size; + __entry->gso_type = skb_shinfo(skb)->gso_type; + __entry->ovs_flow_hash = key->ovs_flow_hash; + __entry->recirc_id = key->recirc_id; + __entry->keyaddr = key; + __entry->key_eth_type = key->eth.type; + __entry->key_ct_state = key->ct_state; + __entry->key_ct_orig_proto = key->ct_orig_proto; + __entry->key_ct_zone = key->ct_zone; + __entry->flow_key_valid = !(key->mac_proto & SW_FLOW_KEY_INVALID); + __entry->upcall_cmd = upcall_info->cmd; + __entry->upcall_port = upcall_info->portid; + __entry->upcall_mru = upcall_info->mru; + ), + + TP_printk("dpaddr=%p dp_name=%s dev=%s skbaddr=%p len=%u data_len=%u truesize=%u nr_frags=%d gso_size=%d gso_type=%#x ovs_flow_hash=0x%08x recirc_id=0x%08x keyaddr=%p eth_type=0x%04x ct_state=%02x ct_orig_proto=%02x ct_zone=%04x flow_key_valid=%d upcall_cmd=%u upcall_port=%u upcall_mru=%u", + __entry->dpaddr, __get_str(dp_name), __get_str(dev_name), + __entry->skbaddr, __entry->len, __entry->data_len, + __entry->truesize, __entry->nr_frags, __entry->gso_size, + __entry->gso_type, __entry->ovs_flow_hash, + __entry->recirc_id, __entry->keyaddr, __entry->key_eth_type, + __entry->key_ct_state, __entry->key_ct_orig_proto, + __entry->key_ct_zone, + 
__entry->flow_key_valid, + __entry->upcall_cmd, __entry->upcall_port, + __entry->upcall_mru) +); + +#endif /* _TRACE_OPENVSWITCH_H */ + +/* This part must be outside protection */ +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE openvswitch_trace +#include diff --git a/net/openvswitch/vport-geneve.c b/net/openvswitch/vport-geneve.c new file mode 100644 index 000000000..89a8e1501 --- /dev/null +++ b/net/openvswitch/vport-geneve.c @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2014 Nicira, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "vport.h" +#include "vport-netdev.h" + +static struct vport_ops ovs_geneve_vport_ops; +/** + * struct geneve_port - Keeps track of open UDP ports + * @dst_port: destination port. + */ +struct geneve_port { + u16 dst_port; +}; + +static inline struct geneve_port *geneve_vport(const struct vport *vport) +{ + return vport_priv(vport); +} + +static int geneve_get_options(const struct vport *vport, + struct sk_buff *skb) +{ + struct geneve_port *geneve_port = geneve_vport(vport); + + if (nla_put_u16(skb, OVS_TUNNEL_ATTR_DST_PORT, geneve_port->dst_port)) + return -EMSGSIZE; + return 0; +} + +static struct vport *geneve_tnl_create(const struct vport_parms *parms) +{ + struct net *net = ovs_dp_get_net(parms->dp); + struct nlattr *options = parms->options; + struct geneve_port *geneve_port; + struct net_device *dev; + struct vport *vport; + struct nlattr *a; + u16 dst_port; + int err; + + if (!options) { + err = -EINVAL; + goto error; + } + + a = nla_find_nested(options, OVS_TUNNEL_ATTR_DST_PORT); + if (a && nla_len(a) == sizeof(u16)) { + dst_port = nla_get_u16(a); + } else { + /* Require destination port from userspace. 
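+ * There is no default port: OVS_TUNNEL_ATTR_DST_PORT must be present as
+ * a 16-bit attribute, otherwise the vport is rejected with -EINVAL.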
*/ + err = -EINVAL; + goto error; + } + + vport = ovs_vport_alloc(sizeof(struct geneve_port), + &ovs_geneve_vport_ops, parms); + if (IS_ERR(vport)) + return vport; + + geneve_port = geneve_vport(vport); + geneve_port->dst_port = dst_port; + + rtnl_lock(); + dev = geneve_dev_create_fb(net, parms->name, NET_NAME_USER, dst_port); + if (IS_ERR(dev)) { + rtnl_unlock(); + ovs_vport_free(vport); + return ERR_CAST(dev); + } + + err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); + if (err < 0) { + rtnl_delete_link(dev); + rtnl_unlock(); + ovs_vport_free(vport); + goto error; + } + + rtnl_unlock(); + return vport; +error: + return ERR_PTR(err); +} + +static struct vport *geneve_create(const struct vport_parms *parms) +{ + struct vport *vport; + + vport = geneve_tnl_create(parms); + if (IS_ERR(vport)) + return vport; + + return ovs_netdev_link(vport, parms->name); +} + +static struct vport_ops ovs_geneve_vport_ops = { + .type = OVS_VPORT_TYPE_GENEVE, + .create = geneve_create, + .destroy = ovs_netdev_tunnel_destroy, + .get_options = geneve_get_options, + .send = dev_queue_xmit, +}; + +static int __init ovs_geneve_tnl_init(void) +{ + return ovs_vport_ops_register(&ovs_geneve_vport_ops); +} + +static void __exit ovs_geneve_tnl_exit(void) +{ + ovs_vport_ops_unregister(&ovs_geneve_vport_ops); +} + +module_init(ovs_geneve_tnl_init); +module_exit(ovs_geneve_tnl_exit); + +MODULE_DESCRIPTION("OVS: Geneve switching port"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("vport-type-5"); diff --git a/net/openvswitch/vport-gre.c b/net/openvswitch/vport-gre.c new file mode 100644 index 000000000..e6b5e76a9 --- /dev/null +++ b/net/openvswitch/vport-gre.c @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2014 Nicira, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "vport.h" +#include "vport-netdev.h" + +static struct vport_ops ovs_gre_vport_ops; + +static struct vport *gre_tnl_create(const struct vport_parms *parms) +{ + struct net *net = ovs_dp_get_net(parms->dp); + struct net_device *dev; + struct vport *vport; + int err; + + vport = ovs_vport_alloc(0, &ovs_gre_vport_ops, parms); + if (IS_ERR(vport)) + return vport; + + rtnl_lock(); + dev = gretap_fb_dev_create(net, parms->name, NET_NAME_USER); + if (IS_ERR(dev)) { + rtnl_unlock(); + ovs_vport_free(vport); + return ERR_CAST(dev); + } + + err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); + if (err < 0) { + rtnl_delete_link(dev); + rtnl_unlock(); + ovs_vport_free(vport); + return ERR_PTR(err); + } + + rtnl_unlock(); + return vport; +} + +static struct vport *gre_create(const struct vport_parms *parms) +{ + struct vport *vport; + + vport = gre_tnl_create(parms); + if (IS_ERR(vport)) + return vport; + + return ovs_netdev_link(vport, parms->name); +} + +static struct vport_ops ovs_gre_vport_ops = { + .type = OVS_VPORT_TYPE_GRE, + .create = gre_create, + .send = dev_queue_xmit, + .destroy = ovs_netdev_tunnel_destroy, +}; + +static int __init ovs_gre_tnl_init(void) +{ + return ovs_vport_ops_register(&ovs_gre_vport_ops); +} + +static void __exit ovs_gre_tnl_exit(void) +{ + ovs_vport_ops_unregister(&ovs_gre_vport_ops); +} + +module_init(ovs_gre_tnl_init); +module_exit(ovs_gre_tnl_exit); + +MODULE_DESCRIPTION("OVS: GRE switching port"); +MODULE_LICENSE("GPL"); 
+MODULE_ALIAS("vport-type-3"); diff --git a/net/openvswitch/vport-internal_dev.c b/net/openvswitch/vport-internal_dev.c new file mode 100644 index 000000000..74c88a6ba --- /dev/null +++ b/net/openvswitch/vport-internal_dev.c @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2012 Nicira, Inc. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "datapath.h" +#include "vport-internal_dev.h" +#include "vport-netdev.h" + +struct internal_dev { + struct vport *vport; +}; + +static struct vport_ops ovs_internal_vport_ops; + +static struct internal_dev *internal_dev_priv(struct net_device *netdev) +{ + return netdev_priv(netdev); +} + +/* Called with rcu_read_lock_bh. */ +static netdev_tx_t +internal_dev_xmit(struct sk_buff *skb, struct net_device *netdev) +{ + int len, err; + + /* store len value because skb can be freed inside ovs_vport_receive() */ + len = skb->len; + + rcu_read_lock(); + err = ovs_vport_receive(internal_dev_priv(netdev)->vport, skb, NULL); + rcu_read_unlock(); + + if (likely(!err)) + dev_sw_netstats_tx_add(netdev, 1, len); + else + netdev->stats.tx_errors++; + + return NETDEV_TX_OK; +} + +static int internal_dev_open(struct net_device *netdev) +{ + netif_start_queue(netdev); + return 0; +} + +static int internal_dev_stop(struct net_device *netdev) +{ + netif_stop_queue(netdev); + return 0; +} + +static void internal_dev_getinfo(struct net_device *netdev, + struct ethtool_drvinfo *info) +{ + strscpy(info->driver, "openvswitch", sizeof(info->driver)); +} + +static const struct ethtool_ops internal_dev_ethtool_ops = { + .get_drvinfo = internal_dev_getinfo, + .get_link = ethtool_op_get_link, +}; + +static void internal_dev_destructor(struct net_device *dev) +{ + struct vport *vport = ovs_internal_dev_get_vport(dev); + + ovs_vport_free(vport); +} + +static const struct net_device_ops internal_dev_netdev_ops = { + .ndo_open = internal_dev_open, + .ndo_stop = internal_dev_stop, + .ndo_start_xmit = internal_dev_xmit, + .ndo_set_mac_address = eth_mac_addr, + .ndo_get_stats64 = dev_get_tstats64, +}; + +static struct rtnl_link_ops internal_dev_link_ops __read_mostly = { + .kind = "openvswitch", +}; + +static void do_setup(struct net_device *netdev) +{ + ether_setup(netdev); + + netdev->max_mtu = ETH_MAX_MTU; + + netdev->netdev_ops = &internal_dev_netdev_ops; + + netdev->priv_flags &= ~IFF_TX_SKB_SHARING; + netdev->priv_flags |= IFF_LIVE_ADDR_CHANGE | IFF_OPENVSWITCH | + IFF_NO_QUEUE; + netdev->needs_free_netdev = true; + netdev->priv_destructor = NULL; + netdev->ethtool_ops = &internal_dev_ethtool_ops; + netdev->rtnl_link_ops = &internal_dev_link_ops; + + netdev->features = NETIF_F_LLTX | NETIF_F_SG | NETIF_F_FRAGLIST | + NETIF_F_HIGHDMA | NETIF_F_HW_CSUM | + NETIF_F_GSO_SOFTWARE | NETIF_F_GSO_ENCAP_ALL; + + netdev->vlan_features = netdev->features; + netdev->hw_enc_features = netdev->features; + netdev->features |= NETIF_F_HW_VLAN_CTAG_TX | NETIF_F_HW_VLAN_STAG_TX; + netdev->hw_features = netdev->features & ~NETIF_F_LLTX; + + eth_hw_addr_random(netdev); +} + +static struct vport *internal_dev_create(const struct vport_parms *parms) +{ + struct vport *vport; + struct internal_dev *internal_dev; + struct net_device *dev; + int err; + + vport = ovs_vport_alloc(0, &ovs_internal_vport_ops, parms); + if (IS_ERR(vport)) { + err = PTR_ERR(vport); + goto error; + } + + dev = alloc_netdev(sizeof(struct internal_dev), + parms->name, NET_NAME_USER, do_setup); + vport->dev = dev; + if (!vport->dev) { + 
err = -ENOMEM; + goto error_free_vport; + } + vport->dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!vport->dev->tstats) { + err = -ENOMEM; + goto error_free_netdev; + } + + dev_net_set(vport->dev, ovs_dp_get_net(vport->dp)); + dev->ifindex = parms->desired_ifindex; + internal_dev = internal_dev_priv(vport->dev); + internal_dev->vport = vport; + + /* Restrict bridge port to current netns. */ + if (vport->port_no == OVSP_LOCAL) + vport->dev->features |= NETIF_F_NETNS_LOCAL; + + rtnl_lock(); + err = register_netdevice(vport->dev); + if (err) + goto error_unlock; + vport->dev->priv_destructor = internal_dev_destructor; + + dev_set_promiscuity(vport->dev, 1); + rtnl_unlock(); + netif_start_queue(vport->dev); + + return vport; + +error_unlock: + rtnl_unlock(); + free_percpu(dev->tstats); +error_free_netdev: + free_netdev(dev); +error_free_vport: + ovs_vport_free(vport); +error: + return ERR_PTR(err); +} + +static void internal_dev_destroy(struct vport *vport) +{ + netif_stop_queue(vport->dev); + rtnl_lock(); + dev_set_promiscuity(vport->dev, -1); + + /* unregister_netdevice() waits for an RCU grace period. */ + unregister_netdevice(vport->dev); + free_percpu(vport->dev->tstats); + rtnl_unlock(); +} + +static int internal_dev_recv(struct sk_buff *skb) +{ + struct net_device *netdev = skb->dev; + + if (unlikely(!(netdev->flags & IFF_UP))) { + kfree_skb(skb); + netdev->stats.rx_dropped++; + return NETDEV_TX_OK; + } + + skb_dst_drop(skb); + nf_reset_ct(skb); + secpath_reset(skb); + + skb->pkt_type = PACKET_HOST; + skb->protocol = eth_type_trans(skb, netdev); + skb_postpull_rcsum(skb, eth_hdr(skb), ETH_HLEN); + dev_sw_netstats_rx_add(netdev, skb->len); + + netif_rx(skb); + return NETDEV_TX_OK; +} + +static struct vport_ops ovs_internal_vport_ops = { + .type = OVS_VPORT_TYPE_INTERNAL, + .create = internal_dev_create, + .destroy = internal_dev_destroy, + .send = internal_dev_recv, +}; + +int ovs_is_internal_dev(const struct net_device *netdev) +{ + return netdev->netdev_ops == &internal_dev_netdev_ops; +} + +struct vport *ovs_internal_dev_get_vport(struct net_device *netdev) +{ + if (!ovs_is_internal_dev(netdev)) + return NULL; + + return internal_dev_priv(netdev)->vport; +} + +int ovs_internal_dev_rtnl_link_register(void) +{ + int err; + + err = rtnl_link_register(&internal_dev_link_ops); + if (err < 0) + return err; + + err = ovs_vport_ops_register(&ovs_internal_vport_ops); + if (err < 0) + rtnl_link_unregister(&internal_dev_link_ops); + + return err; +} + +void ovs_internal_dev_rtnl_link_unregister(void) +{ + ovs_vport_ops_unregister(&ovs_internal_vport_ops); + rtnl_link_unregister(&internal_dev_link_ops); +} diff --git a/net/openvswitch/vport-internal_dev.h b/net/openvswitch/vport-internal_dev.h new file mode 100644 index 000000000..0112d1b09 --- /dev/null +++ b/net/openvswitch/vport-internal_dev.h @@ -0,0 +1,17 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2011 Nicira, Inc. 
+ */ + +#ifndef VPORT_INTERNAL_DEV_H +#define VPORT_INTERNAL_DEV_H 1 + +#include "datapath.h" +#include "vport.h" + +int ovs_is_internal_dev(const struct net_device *); +struct vport *ovs_internal_dev_get_vport(struct net_device *); +int ovs_internal_dev_rtnl_link_register(void); +void ovs_internal_dev_rtnl_link_unregister(void); + +#endif /* vport-internal_dev.h */ diff --git a/net/openvswitch/vport-netdev.c b/net/openvswitch/vport-netdev.c new file mode 100644 index 000000000..2f61d5bdc --- /dev/null +++ b/net/openvswitch/vport-netdev.c @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2012 Nicira, Inc. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "datapath.h" +#include "vport.h" +#include "vport-internal_dev.h" +#include "vport-netdev.h" + +static struct vport_ops ovs_netdev_vport_ops; + +/* Must be called with rcu_read_lock. */ +static void netdev_port_receive(struct sk_buff *skb) +{ + struct vport *vport; + + vport = ovs_netdev_get_vport(skb->dev); + if (unlikely(!vport)) + goto error; + + if (unlikely(skb_warn_if_lro(skb))) + goto error; + + /* Make our own copy of the packet. Otherwise we will mangle the + * packet for anyone who came before us (e.g. tcpdump via AF_PACKET). + */ + skb = skb_share_check(skb, GFP_ATOMIC); + if (unlikely(!skb)) + return; + + if (skb->dev->type == ARPHRD_ETHER) + skb_push_rcsum(skb, ETH_HLEN); + + ovs_vport_receive(vport, skb, skb_tunnel_info(skb)); + return; +error: + kfree_skb(skb); +} + +/* Called with rcu_read_lock and bottom-halves disabled. */ +static rx_handler_result_t netdev_frame_hook(struct sk_buff **pskb) +{ + struct sk_buff *skb = *pskb; + + if (unlikely(skb->pkt_type == PACKET_LOOPBACK)) + return RX_HANDLER_PASS; + + netdev_port_receive(skb); + return RX_HANDLER_CONSUMED; +} + +static struct net_device *get_dpdev(const struct datapath *dp) +{ + struct vport *local; + + local = ovs_vport_ovsl(dp, OVSP_LOCAL); + return local->dev; +} + +struct vport *ovs_netdev_link(struct vport *vport, const char *name) +{ + int err; + + vport->dev = dev_get_by_name(ovs_dp_get_net(vport->dp), name); + if (!vport->dev) { + err = -ENODEV; + goto error_free_vport; + } + netdev_tracker_alloc(vport->dev, &vport->dev_tracker, GFP_KERNEL); + if (vport->dev->flags & IFF_LOOPBACK || + (vport->dev->type != ARPHRD_ETHER && + vport->dev->type != ARPHRD_NONE) || + ovs_is_internal_dev(vport->dev)) { + err = -EINVAL; + goto error_put; + } + + rtnl_lock(); + err = netdev_master_upper_dev_link(vport->dev, + get_dpdev(vport->dp), + NULL, NULL, NULL); + if (err) + goto error_unlock; + + err = netdev_rx_handler_register(vport->dev, netdev_frame_hook, + vport); + if (err) + goto error_master_upper_dev_unlink; + + dev_disable_lro(vport->dev); + dev_set_promiscuity(vport->dev, 1); + vport->dev->priv_flags |= IFF_OVS_DATAPATH; + rtnl_unlock(); + + return vport; + +error_master_upper_dev_unlink: + netdev_upper_dev_unlink(vport->dev, get_dpdev(vport->dp)); +error_unlock: + rtnl_unlock(); +error_put: + netdev_put(vport->dev, &vport->dev_tracker); +error_free_vport: + ovs_vport_free(vport); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(ovs_netdev_link); + +static struct vport *netdev_create(const struct vport_parms *parms) +{ + struct vport *vport; + + vport = ovs_vport_alloc(0, &ovs_netdev_vport_ops, parms); + if (IS_ERR(vport)) + return vport; + + return ovs_netdev_link(vport, parms->name); +} + +static void 
vport_netdev_free(struct rcu_head *rcu) +{ + struct vport *vport = container_of(rcu, struct vport, rcu); + + netdev_put(vport->dev, &vport->dev_tracker); + ovs_vport_free(vport); +} + +void ovs_netdev_detach_dev(struct vport *vport) +{ + ASSERT_RTNL(); + vport->dev->priv_flags &= ~IFF_OVS_DATAPATH; + netdev_rx_handler_unregister(vport->dev); + netdev_upper_dev_unlink(vport->dev, + netdev_master_upper_dev_get(vport->dev)); + dev_set_promiscuity(vport->dev, -1); +} + +static void netdev_destroy(struct vport *vport) +{ + rtnl_lock(); + if (netif_is_ovs_port(vport->dev)) + ovs_netdev_detach_dev(vport); + rtnl_unlock(); + + call_rcu(&vport->rcu, vport_netdev_free); +} + +void ovs_netdev_tunnel_destroy(struct vport *vport) +{ + rtnl_lock(); + if (netif_is_ovs_port(vport->dev)) + ovs_netdev_detach_dev(vport); + + /* We can be invoked by both explicit vport deletion and + * underlying netdev deregistration; delete the link only + * if it's not already shutting down. + */ + if (vport->dev->reg_state == NETREG_REGISTERED) + rtnl_delete_link(vport->dev); + netdev_put(vport->dev, &vport->dev_tracker); + vport->dev = NULL; + rtnl_unlock(); + + call_rcu(&vport->rcu, vport_netdev_free); +} +EXPORT_SYMBOL_GPL(ovs_netdev_tunnel_destroy); + +/* Returns null if this device is not attached to a datapath. */ +struct vport *ovs_netdev_get_vport(struct net_device *dev) +{ + if (likely(netif_is_ovs_port(dev))) + return (struct vport *) + rcu_dereference_rtnl(dev->rx_handler_data); + else + return NULL; +} + +static struct vport_ops ovs_netdev_vport_ops = { + .type = OVS_VPORT_TYPE_NETDEV, + .create = netdev_create, + .destroy = netdev_destroy, + .send = dev_queue_xmit, +}; + +int __init ovs_netdev_init(void) +{ + return ovs_vport_ops_register(&ovs_netdev_vport_ops); +} + +void ovs_netdev_exit(void) +{ + ovs_vport_ops_unregister(&ovs_netdev_vport_ops); +} diff --git a/net/openvswitch/vport-netdev.h b/net/openvswitch/vport-netdev.h new file mode 100644 index 000000000..c5d83a43b --- /dev/null +++ b/net/openvswitch/vport-netdev.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2011 Nicira, Inc. + */ + +#ifndef VPORT_NETDEV_H +#define VPORT_NETDEV_H 1 + +#include +#include + +#include "vport.h" + +struct vport *ovs_netdev_get_vport(struct net_device *dev); + +struct vport *ovs_netdev_link(struct vport *vport, const char *name); +void ovs_netdev_detach_dev(struct vport *); + +int __init ovs_netdev_init(void); +void ovs_netdev_exit(void); + +void ovs_netdev_tunnel_destroy(struct vport *vport); +#endif /* vport_netdev.h */ diff --git a/net/openvswitch/vport-vxlan.c b/net/openvswitch/vport-vxlan.c new file mode 100644 index 000000000..188e9c136 --- /dev/null +++ b/net/openvswitch/vport-vxlan.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2014 Nicira, Inc. + * Copyright (c) 2013 Cisco Systems, Inc. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "vport.h" +#include "vport-netdev.h" + +static struct vport_ops ovs_vxlan_netdev_vport_ops; + +static int vxlan_get_options(const struct vport *vport, struct sk_buff *skb) +{ + struct vxlan_dev *vxlan = netdev_priv(vport->dev); + __be16 dst_port = vxlan->cfg.dst_port; + + if (nla_put_u16(skb, OVS_TUNNEL_ATTR_DST_PORT, ntohs(dst_port))) + return -EMSGSIZE; + + if (vxlan->cfg.flags & VXLAN_F_GBP) { + struct nlattr *exts; + + exts = nla_nest_start_noflag(skb, OVS_TUNNEL_ATTR_EXTENSION); + if (!exts) + return -EMSGSIZE; + + if (vxlan->cfg.flags & VXLAN_F_GBP && + nla_put_flag(skb, OVS_VXLAN_EXT_GBP)) + return -EMSGSIZE; + + nla_nest_end(skb, exts); + } + + return 0; +} + +static const struct nla_policy exts_policy[OVS_VXLAN_EXT_MAX + 1] = { + [OVS_VXLAN_EXT_GBP] = { .type = NLA_FLAG, }, +}; + +static int vxlan_configure_exts(struct vport *vport, struct nlattr *attr, + struct vxlan_config *conf) +{ + struct nlattr *exts[OVS_VXLAN_EXT_MAX + 1]; + int err; + + if (nla_len(attr) < sizeof(struct nlattr)) + return -EINVAL; + + err = nla_parse_nested_deprecated(exts, OVS_VXLAN_EXT_MAX, attr, + exts_policy, NULL); + if (err < 0) + return err; + + if (exts[OVS_VXLAN_EXT_GBP]) + conf->flags |= VXLAN_F_GBP; + + return 0; +} + +static struct vport *vxlan_tnl_create(const struct vport_parms *parms) +{ + struct net *net = ovs_dp_get_net(parms->dp); + struct nlattr *options = parms->options; + struct net_device *dev; + struct vport *vport; + struct nlattr *a; + int err; + struct vxlan_config conf = { + .no_share = true, + .flags = VXLAN_F_COLLECT_METADATA | VXLAN_F_UDP_ZERO_CSUM6_RX, + /* Don't restrict the packets that can be sent by MTU */ + .mtu = IP_MAX_MTU, + }; + + if (!options) { + err = -EINVAL; + goto error; + } + + a = nla_find_nested(options, OVS_TUNNEL_ATTR_DST_PORT); + if (a && nla_len(a) == sizeof(u16)) { + conf.dst_port = htons(nla_get_u16(a)); + } else { + /* Require destination port from userspace. 
*/ + err = -EINVAL; + goto error; + } + + vport = ovs_vport_alloc(0, &ovs_vxlan_netdev_vport_ops, parms); + if (IS_ERR(vport)) + return vport; + + a = nla_find_nested(options, OVS_TUNNEL_ATTR_EXTENSION); + if (a) { + err = vxlan_configure_exts(vport, a, &conf); + if (err) { + ovs_vport_free(vport); + goto error; + } + } + + rtnl_lock(); + dev = vxlan_dev_create(net, parms->name, NET_NAME_USER, &conf); + if (IS_ERR(dev)) { + rtnl_unlock(); + ovs_vport_free(vport); + return ERR_CAST(dev); + } + + err = dev_change_flags(dev, dev->flags | IFF_UP, NULL); + if (err < 0) { + rtnl_delete_link(dev); + rtnl_unlock(); + ovs_vport_free(vport); + goto error; + } + + rtnl_unlock(); + return vport; +error: + return ERR_PTR(err); +} + +static struct vport *vxlan_create(const struct vport_parms *parms) +{ + struct vport *vport; + + vport = vxlan_tnl_create(parms); + if (IS_ERR(vport)) + return vport; + + return ovs_netdev_link(vport, parms->name); +} + +static struct vport_ops ovs_vxlan_netdev_vport_ops = { + .type = OVS_VPORT_TYPE_VXLAN, + .create = vxlan_create, + .destroy = ovs_netdev_tunnel_destroy, + .get_options = vxlan_get_options, + .send = dev_queue_xmit, +}; + +static int __init ovs_vxlan_tnl_init(void) +{ + return ovs_vport_ops_register(&ovs_vxlan_netdev_vport_ops); +} + +static void __exit ovs_vxlan_tnl_exit(void) +{ + ovs_vport_ops_unregister(&ovs_vxlan_netdev_vport_ops); +} + +module_init(ovs_vxlan_tnl_init); +module_exit(ovs_vxlan_tnl_exit); + +MODULE_DESCRIPTION("OVS: VXLAN switching port"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("vport-type-4"); diff --git a/net/openvswitch/vport.c b/net/openvswitch/vport.c new file mode 100644 index 000000000..82a74f998 --- /dev/null +++ b/net/openvswitch/vport.c @@ -0,0 +1,516 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2007-2014 Nicira, Inc. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" +#include "vport.h" +#include "vport-internal_dev.h" + +static LIST_HEAD(vport_ops_list); + +/* Protected by RCU read lock for reading, ovs_mutex for writing. */ +static struct hlist_head *dev_table; +#define VPORT_HASH_BUCKETS 1024 + +/** + * ovs_vport_init - initialize vport subsystem + * + * Called at module load time to initialize the vport subsystem. + */ +int ovs_vport_init(void) +{ + dev_table = kcalloc(VPORT_HASH_BUCKETS, sizeof(struct hlist_head), + GFP_KERNEL); + if (!dev_table) + return -ENOMEM; + + return 0; +} + +/** + * ovs_vport_exit - shutdown vport subsystem + * + * Called at module exit time to shutdown the vport subsystem. 
+ */ +void ovs_vport_exit(void) +{ + kfree(dev_table); +} + +static struct hlist_head *hash_bucket(const struct net *net, const char *name) +{ + unsigned int hash = jhash(name, strlen(name), (unsigned long) net); + return &dev_table[hash & (VPORT_HASH_BUCKETS - 1)]; +} + +int __ovs_vport_ops_register(struct vport_ops *ops) +{ + int err = -EEXIST; + struct vport_ops *o; + + ovs_lock(); + list_for_each_entry(o, &vport_ops_list, list) + if (ops->type == o->type) + goto errout; + + list_add_tail(&ops->list, &vport_ops_list); + err = 0; +errout: + ovs_unlock(); + return err; +} +EXPORT_SYMBOL_GPL(__ovs_vport_ops_register); + +void ovs_vport_ops_unregister(struct vport_ops *ops) +{ + ovs_lock(); + list_del(&ops->list); + ovs_unlock(); +} +EXPORT_SYMBOL_GPL(ovs_vport_ops_unregister); + +/** + * ovs_vport_locate - find a port that has already been created + * + * @net: network namespace + * @name: name of port to find + * + * Must be called with ovs or RCU read lock. + */ +struct vport *ovs_vport_locate(const struct net *net, const char *name) +{ + struct hlist_head *bucket = hash_bucket(net, name); + struct vport *vport; + + hlist_for_each_entry_rcu(vport, bucket, hash_node, + lockdep_ovsl_is_held()) + if (!strcmp(name, ovs_vport_name(vport)) && + net_eq(ovs_dp_get_net(vport->dp), net)) + return vport; + + return NULL; +} + +/** + * ovs_vport_alloc - allocate and initialize new vport + * + * @priv_size: Size of private data area to allocate. + * @ops: vport device ops + * @parms: information about new vport. + * + * Allocate and initialize a new vport defined by @ops. The vport will contain + * a private data area of size @priv_size that can be accessed using + * vport_priv(). Some parameters of the vport will be initialized from @parms. + * @vports that are no longer needed should be released with + * vport_free(). + */ +struct vport *ovs_vport_alloc(int priv_size, const struct vport_ops *ops, + const struct vport_parms *parms) +{ + struct vport *vport; + size_t alloc_size; + + alloc_size = sizeof(struct vport); + if (priv_size) { + alloc_size = ALIGN(alloc_size, VPORT_ALIGN); + alloc_size += priv_size; + } + + vport = kzalloc(alloc_size, GFP_KERNEL); + if (!vport) + return ERR_PTR(-ENOMEM); + + vport->dp = parms->dp; + vport->port_no = parms->port_no; + vport->ops = ops; + INIT_HLIST_NODE(&vport->dp_hash_node); + + if (ovs_vport_set_upcall_portids(vport, parms->upcall_portids)) { + kfree(vport); + return ERR_PTR(-EINVAL); + } + + return vport; +} +EXPORT_SYMBOL_GPL(ovs_vport_alloc); + +/** + * ovs_vport_free - uninitialize and free vport + * + * @vport: vport to free + * + * Frees a vport allocated with vport_alloc() when it is no longer needed. + * + * The caller must ensure that an RCU grace period has passed since the last + * time @vport was in a datapath. + */ +void ovs_vport_free(struct vport *vport) +{ + /* vport is freed from RCU callback or error path, Therefore + * it is safe to use raw dereference. + */ + kfree(rcu_dereference_raw(vport->upcall_portids)); + kfree(vport); +} +EXPORT_SYMBOL_GPL(ovs_vport_free); + +static struct vport_ops *ovs_vport_lookup(const struct vport_parms *parms) +{ + struct vport_ops *ops; + + list_for_each_entry(ops, &vport_ops_list, list) + if (ops->type == parms->type) + return ops; + + return NULL; +} + +/** + * ovs_vport_add - add vport device (for kernel callers) + * + * @parms: Information about new vport. + * + * Creates a new vport with the specified configuration (which is dependent on + * device type). ovs_mutex must be held. 
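+ *
+ * If no matching vport_ops is registered, ovs_mutex is temporarily
+ * dropped to request the "vport-type-%d" module; -EAGAIN is returned if
+ * the type is now available (so the caller can retry the whole port
+ * addition), otherwise -EAFNOSUPPORT.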
+ */ +struct vport *ovs_vport_add(const struct vport_parms *parms) +{ + struct vport_ops *ops; + struct vport *vport; + + ops = ovs_vport_lookup(parms); + if (ops) { + struct hlist_head *bucket; + + if (!try_module_get(ops->owner)) + return ERR_PTR(-EAFNOSUPPORT); + + vport = ops->create(parms); + if (IS_ERR(vport)) { + module_put(ops->owner); + return vport; + } + + bucket = hash_bucket(ovs_dp_get_net(vport->dp), + ovs_vport_name(vport)); + hlist_add_head_rcu(&vport->hash_node, bucket); + return vport; + } + + /* Unlock to attempt module load and return -EAGAIN if load + * was successful as we need to restart the port addition + * workflow. + */ + ovs_unlock(); + request_module("vport-type-%d", parms->type); + ovs_lock(); + + if (!ovs_vport_lookup(parms)) + return ERR_PTR(-EAFNOSUPPORT); + else + return ERR_PTR(-EAGAIN); +} + +/** + * ovs_vport_set_options - modify existing vport device (for kernel callers) + * + * @vport: vport to modify. + * @options: New configuration. + * + * Modifies an existing device with the specified configuration (which is + * dependent on device type). ovs_mutex must be held. + */ +int ovs_vport_set_options(struct vport *vport, struct nlattr *options) +{ + if (!vport->ops->set_options) + return -EOPNOTSUPP; + return vport->ops->set_options(vport, options); +} + +/** + * ovs_vport_del - delete existing vport device + * + * @vport: vport to delete. + * + * Detaches @vport from its datapath and destroys it. ovs_mutex must + * be held. + */ +void ovs_vport_del(struct vport *vport) +{ + hlist_del_rcu(&vport->hash_node); + module_put(vport->ops->owner); + vport->ops->destroy(vport); +} + +/** + * ovs_vport_get_stats - retrieve device stats + * + * @vport: vport from which to retrieve the stats + * @stats: location to store stats + * + * Retrieves transmit, receive, and error stats for the given device. + * + * Must be called with ovs_mutex or rcu_read_lock. + */ +void ovs_vport_get_stats(struct vport *vport, struct ovs_vport_stats *stats) +{ + const struct rtnl_link_stats64 *dev_stats; + struct rtnl_link_stats64 temp; + + dev_stats = dev_get_stats(vport->dev, &temp); + stats->rx_errors = dev_stats->rx_errors; + stats->tx_errors = dev_stats->tx_errors; + stats->tx_dropped = dev_stats->tx_dropped; + stats->rx_dropped = dev_stats->rx_dropped; + + stats->rx_bytes = dev_stats->rx_bytes; + stats->rx_packets = dev_stats->rx_packets; + stats->tx_bytes = dev_stats->tx_bytes; + stats->tx_packets = dev_stats->tx_packets; +} + +/** + * ovs_vport_get_options - retrieve device options + * + * @vport: vport from which to retrieve the options. + * @skb: sk_buff where options should be appended. + * + * Retrieves the configuration of the given device, appending an + * %OVS_VPORT_ATTR_OPTIONS attribute that in turn contains nested + * vport-specific attributes to @skb. + * + * Returns 0 if successful, -EMSGSIZE if @skb has insufficient room, or another + * negative error code if a real error occurred. If an error occurs, @skb is + * left unmodified. + * + * Must be called with ovs_mutex or rcu_read_lock. + */ +int ovs_vport_get_options(const struct vport *vport, struct sk_buff *skb) +{ + struct nlattr *nla; + int err; + + if (!vport->ops->get_options) + return 0; + + nla = nla_nest_start_noflag(skb, OVS_VPORT_ATTR_OPTIONS); + if (!nla) + return -EMSGSIZE; + + err = vport->ops->get_options(vport, skb); + if (err) { + nla_nest_cancel(skb, nla); + return err; + } + + nla_nest_end(skb, nla); + return 0; +} + +/** + * ovs_vport_set_upcall_portids - set upcall portids of @vport. 
+ * + * @vport: vport to modify. + * @ids: new configuration, an array of port ids. + * + * Sets the vport's upcall_portids to @ids. + * + * Returns 0 if successful, -EINVAL if @ids is zero length or cannot be parsed + * as an array of U32. + * + * Must be called with ovs_mutex. + */ +int ovs_vport_set_upcall_portids(struct vport *vport, const struct nlattr *ids) +{ + struct vport_portids *old, *vport_portids; + + if (!nla_len(ids) || nla_len(ids) % sizeof(u32)) + return -EINVAL; + + old = ovsl_dereference(vport->upcall_portids); + + vport_portids = kmalloc(sizeof(*vport_portids) + nla_len(ids), + GFP_KERNEL); + if (!vport_portids) + return -ENOMEM; + + vport_portids->n_ids = nla_len(ids) / sizeof(u32); + vport_portids->rn_ids = reciprocal_value(vport_portids->n_ids); + nla_memcpy(vport_portids->ids, ids, nla_len(ids)); + + rcu_assign_pointer(vport->upcall_portids, vport_portids); + + if (old) + kfree_rcu(old, rcu); + return 0; +} + +/** + * ovs_vport_get_upcall_portids - get the upcall_portids of @vport. + * + * @vport: vport from which to retrieve the portids. + * @skb: sk_buff where portids should be appended. + * + * Retrieves the configuration of the given vport, appending the + * %OVS_VPORT_ATTR_UPCALL_PID attribute which is the array of upcall + * portids to @skb. + * + * Returns 0 if successful, -EMSGSIZE if @skb has insufficient room. + * If an error occurs, @skb is left unmodified. Must be called with + * ovs_mutex or rcu_read_lock. + */ +int ovs_vport_get_upcall_portids(const struct vport *vport, + struct sk_buff *skb) +{ + struct vport_portids *ids; + + ids = rcu_dereference_ovsl(vport->upcall_portids); + + if (vport->dp->user_features & OVS_DP_F_VPORT_PIDS) + return nla_put(skb, OVS_VPORT_ATTR_UPCALL_PID, + ids->n_ids * sizeof(u32), (void *)ids->ids); + else + return nla_put_u32(skb, OVS_VPORT_ATTR_UPCALL_PID, ids->ids[0]); +} + +/** + * ovs_vport_find_upcall_portid - find the upcall portid to send upcall. + * + * @vport: vport from which the missed packet is received. + * @skb: skb that the missed packet was received. + * + * Uses the skb_get_hash() to select the upcall portid to send the + * upcall. + * + * Returns the portid of the target socket. Must be called with rcu_read_lock. + */ +u32 ovs_vport_find_upcall_portid(const struct vport *vport, + struct sk_buff *skb) +{ + struct vport_portids *ids; + u32 ids_index; + u32 hash; + + ids = rcu_dereference(vport->upcall_portids); + + /* If there is only one portid, select it in the fast-path. */ + if (ids->n_ids == 1) + return ids->ids[0]; + + hash = skb_get_hash(skb); + ids_index = hash - ids->n_ids * reciprocal_divide(hash, ids->rn_ids); + return ids->ids[ids_index]; +} + +/** + * ovs_vport_receive - pass up received packet to the datapath for processing + * + * @vport: vport that received the packet + * @skb: skb that was received + * @tun_info: tunnel (if any) that carried packet + * + * Must be called with rcu_read_lock. The packet cannot be shared and + * skb->data should point to the Ethernet header. + */ +int ovs_vport_receive(struct vport *vport, struct sk_buff *skb, + const struct ip_tunnel_info *tun_info) +{ + struct sw_flow_key key; + int error; + + OVS_CB(skb)->input_vport = vport; + OVS_CB(skb)->mru = 0; + OVS_CB(skb)->cutlen = 0; + if (unlikely(dev_net(skb->dev) != ovs_dp_get_net(vport->dp))) { + u32 mark; + + mark = skb->mark; + skb_scrub_packet(skb, true); + skb->mark = mark; + tun_info = NULL; + } + + /* Extract flow from 'skb' into 'key'. 
*/ + error = ovs_flow_key_extract(tun_info, skb, &key); + if (unlikely(error)) { + kfree_skb(skb); + return error; + } + ovs_dp_process_packet(skb, &key); + return 0; +} + +static int packet_length(const struct sk_buff *skb, + struct net_device *dev) +{ + int length = skb->len - dev->hard_header_len; + + if (!skb_vlan_tag_present(skb) && + eth_type_vlan(skb->protocol)) + length -= VLAN_HLEN; + + /* Don't subtract for multiple VLAN tags. Most (all?) drivers allow + * (ETH_LEN + VLAN_HLEN) in addition to the mtu value, but almost none + * account for 802.1ad. e.g. is_skb_forwardable(). + */ + + return length > 0 ? length : 0; +} + +void ovs_vport_send(struct vport *vport, struct sk_buff *skb, u8 mac_proto) +{ + int mtu = vport->dev->mtu; + + switch (vport->dev->type) { + case ARPHRD_NONE: + if (mac_proto == MAC_PROTO_ETHERNET) { + skb_reset_network_header(skb); + skb_reset_mac_len(skb); + skb->protocol = htons(ETH_P_TEB); + } else if (mac_proto != MAC_PROTO_NONE) { + WARN_ON_ONCE(1); + goto drop; + } + break; + case ARPHRD_ETHER: + if (mac_proto != MAC_PROTO_ETHERNET) + goto drop; + break; + default: + goto drop; + } + + if (unlikely(packet_length(skb, vport->dev) > mtu && + !skb_is_gso(skb))) { + vport->dev->stats.tx_errors++; + if (vport->dev->flags & IFF_UP) + net_warn_ratelimited("%s: dropped over-mtu packet: " + "%d > %d\n", vport->dev->name, + packet_length(skb, vport->dev), + mtu); + goto drop; + } + + skb->dev = vport->dev; + skb_clear_tstamp(skb); + vport->ops->send(skb); + return; + +drop: + kfree_skb(skb); +} diff --git a/net/openvswitch/vport.h b/net/openvswitch/vport.h new file mode 100644 index 000000000..6ff45e8a0 --- /dev/null +++ b/net/openvswitch/vport.h @@ -0,0 +1,193 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (c) 2007-2012 Nicira, Inc. + */ + +#ifndef VPORT_H +#define VPORT_H 1 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "datapath.h" + +struct vport; +struct vport_parms; + +/* The following definitions are for users of the vport subsystem: */ + +int ovs_vport_init(void); +void ovs_vport_exit(void); + +struct vport *ovs_vport_add(const struct vport_parms *); +void ovs_vport_del(struct vport *); + +struct vport *ovs_vport_locate(const struct net *net, const char *name); + +void ovs_vport_get_stats(struct vport *, struct ovs_vport_stats *); + +int ovs_vport_set_options(struct vport *, struct nlattr *options); +int ovs_vport_get_options(const struct vport *, struct sk_buff *); + +int ovs_vport_set_upcall_portids(struct vport *, const struct nlattr *pids); +int ovs_vport_get_upcall_portids(const struct vport *, struct sk_buff *); +u32 ovs_vport_find_upcall_portid(const struct vport *, struct sk_buff *); + +/** + * struct vport_portids - array of netlink portids of a vport. + * must be protected by rcu. + * @rn_ids: The reciprocal value of @n_ids. + * @rcu: RCU callback head for deferred destruction. + * @n_ids: Size of @ids array. + * @ids: Array storing the Netlink socket pids to be used for packets received + * on this port that miss the flow table. + */ +struct vport_portids { + struct reciprocal_value rn_ids; + struct rcu_head rcu; + u32 n_ids; + u32 ids[]; +}; + +/** + * struct vport - one port within a datapath + * @dev: Pointer to net_device. + * @dev_tracker: refcount tracker for @dev reference + * @dp: Datapath to which this port belongs. + * @upcall_portids: RCU protected 'struct vport_portids'. + * @port_no: Index into @dp's @ports array. 
+ * @hash_node: Element in @dev_table hash table in vport.c. + * @dp_hash_node: Element in @datapath->ports hash table in datapath.c. + * @ops: Class structure. + * @detach_list: list used for detaching vport in net-exit call. + * @rcu: RCU callback head for deferred destruction. + */ +struct vport { + struct net_device *dev; + netdevice_tracker dev_tracker; + struct datapath *dp; + struct vport_portids __rcu *upcall_portids; + u16 port_no; + + struct hlist_node hash_node; + struct hlist_node dp_hash_node; + const struct vport_ops *ops; + + struct list_head detach_list; + struct rcu_head rcu; +}; + +/** + * struct vport_parms - parameters for creating a new vport + * + * @name: New vport's name. + * @type: New vport's type. + * @options: %OVS_VPORT_ATTR_OPTIONS attribute from Netlink message, %NULL if + * none was supplied. + * @desired_ifindex: New vport's ifindex. + * @dp: New vport's datapath. + * @port_no: New vport's port number. + */ +struct vport_parms { + const char *name; + enum ovs_vport_type type; + int desired_ifindex; + struct nlattr *options; + + /* For ovs_vport_alloc(). */ + struct datapath *dp; + u16 port_no; + struct nlattr *upcall_portids; +}; + +/** + * struct vport_ops - definition of a type of virtual port + * + * @type: %OVS_VPORT_TYPE_* value for this type of virtual port. + * @create: Create a new vport configured as specified. On success returns + * a new vport allocated with ovs_vport_alloc(), otherwise an ERR_PTR() value. + * @destroy: Destroys a vport. Must call vport_free() on the vport but not + * before an RCU grace period has elapsed. + * @set_options: Modify the configuration of an existing vport. May be %NULL + * if modification is not supported. + * @get_options: Appends vport-specific attributes for the configuration of an + * existing vport to a &struct sk_buff. May be %NULL for a vport that does not + * have any configuration. + * @send: Send a packet on the device. + * zero for dropped packets or negative for error. + */ +struct vport_ops { + enum ovs_vport_type type; + + /* Called with ovs_mutex. */ + struct vport *(*create)(const struct vport_parms *); + void (*destroy)(struct vport *); + + int (*set_options)(struct vport *, struct nlattr *); + int (*get_options)(const struct vport *, struct sk_buff *); + + int (*send)(struct sk_buff *skb); + struct module *owner; + struct list_head list; +}; + +struct vport *ovs_vport_alloc(int priv_size, const struct vport_ops *, + const struct vport_parms *); +void ovs_vport_free(struct vport *); + +#define VPORT_ALIGN 8 + +/** + * vport_priv - access private data area of vport + * + * @vport: vport to access + * + * If a nonzero size was passed in priv_size of vport_alloc() a private data + * area was allocated on creation. This allows that area to be accessed and + * used for any purpose needed by the vport implementer. + */ +static inline void *vport_priv(const struct vport *vport) +{ + return (u8 *)(uintptr_t)vport + ALIGN(sizeof(struct vport), VPORT_ALIGN); +} + +/** + * vport_from_priv - lookup vport from private data pointer + * + * @priv: Start of private data area. + * + * It is sometimes useful to translate from a pointer to the private data + * area to the vport, such as in the case where the private data pointer is + * the result of a hash table lookup. @priv must point to the start of the + * private data area. 
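+ *
+ * ovs_vport_alloc() lays the two out back to back, so the translation is
+ * plain pointer arithmetic: the private area starts
+ * ALIGN(sizeof(struct vport), VPORT_ALIGN) bytes past the vport, and this
+ * helper subtracts that same offset again. For example, if
+ * sizeof(struct vport) happened to be 106 on some build, the offset would
+ * be ALIGN(106, 8) == 112 in both directions.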
+ */ +static inline struct vport *vport_from_priv(void *priv) +{ + return (struct vport *)((u8 *)priv - ALIGN(sizeof(struct vport), VPORT_ALIGN)); +} + +int ovs_vport_receive(struct vport *, struct sk_buff *, + const struct ip_tunnel_info *); + +static inline const char *ovs_vport_name(struct vport *vport) +{ + return vport->dev->name; +} + +int __ovs_vport_ops_register(struct vport_ops *ops); +#define ovs_vport_ops_register(ops) \ + ({ \ + (ops)->owner = THIS_MODULE; \ + __ovs_vport_ops_register(ops); \ + }) + +void ovs_vport_ops_unregister(struct vport_ops *ops); +void ovs_vport_send(struct vport *vport, struct sk_buff *skb, u8 mac_proto); + +#endif /* vport.h */ diff --git a/net/packet/Kconfig b/net/packet/Kconfig new file mode 100644 index 000000000..2997382d5 --- /dev/null +++ b/net/packet/Kconfig @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Packet configuration +# + +config PACKET + tristate "Packet socket" + help + The Packet protocol is used by applications which communicate + directly with network devices without an intermediate network + protocol implemented in the kernel, e.g. tcpdump. If you want them + to work, choose Y. + + To compile this driver as a module, choose M here: the module will + be called af_packet. + + If unsure, say Y. + +config PACKET_DIAG + tristate "Packet: sockets monitoring interface" + depends on PACKET + default n + help + Support for PF_PACKET sockets monitoring interface used by the ss tool. + If unsure, say Y. diff --git a/net/packet/Makefile b/net/packet/Makefile new file mode 100644 index 000000000..97d502e21 --- /dev/null +++ b/net/packet/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the packet AF. +# + +obj-$(CONFIG_PACKET) += af_packet.o +obj-$(CONFIG_PACKET_DIAG) += af_packet_diag.o +af_packet_diag-y += diag.o diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c new file mode 100644 index 000000000..51882f07e --- /dev/null +++ b/net/packet/af_packet.c @@ -0,0 +1,4768 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * INET An implementation of the TCP/IP protocol suite for the LINUX + * operating system. INET is implemented using the BSD Socket + * interface as the means of communication with the user level. + * + * PACKET - implements raw packet sockets. + * + * Authors: Ross Biro + * Fred N. van Kempen, + * Alan Cox, + * + * Fixes: + * Alan Cox : verify_area() now used correctly + * Alan Cox : new skbuff lists, look ma no backlogs! + * Alan Cox : tidied skbuff lists. + * Alan Cox : Now uses generic datagram routines I + * added. Also fixed the peek/read crash + * from all old Linux datagram code. + * Alan Cox : Uses the improved datagram code. + * Alan Cox : Added NULL's for socket options. + * Alan Cox : Re-commented the code. + * Alan Cox : Use new kernel side addressing + * Rob Janssen : Correct MTU usage. + * Dave Platt : Counter leaks caused by incorrect + * interrupt locking and some slightly + * dubious gcc output. Can you read + * compiler: it said _VOLATILE_ + * Richard Kooijman : Timestamp fixes. + * Alan Cox : New buffers. Use sk->mac.raw. + * Alan Cox : sendmsg/recvmsg support. + * Alan Cox : Protocol setting support + * Alexey Kuznetsov : Untied from IPv4 stack. + * Cyrus Durgin : Fixed kerneld for kmod. + * Michal Ostrowski : Module initialization cleanup. + * Ulises Alonso : Frame number limit removal and + * packet_set_ring memory leak. + * Eric Biederman : Allow for > 8 byte hardware addresses. 
+ * The convention is that longer addresses + * will simply extend the hardware address + * byte arrays at the end of sockaddr_ll + * and packet_mreq. + * Johann Baudy : Added TX RING. + * Chetan Loke : Implemented TPACKET_V3 block abstraction + * layer. + * Copyright (C) 2011, + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_INET +#include +#endif +#include +#include +#include + +#include "internal.h" + +/* + Assumptions: + - If the device has no dev->header_ops->create, there is no LL header + visible above the device. In this case, its hard_header_len should be 0. + The device may prepend its own header internally. In this case, its + needed_headroom should be set to the space needed for it to add its + internal header. + For example, a WiFi driver pretending to be an Ethernet driver should + set its hard_header_len to be the Ethernet header length, and set its + needed_headroom to be (the real WiFi header length - the fake Ethernet + header length). + - packet socket receives packets with pulled ll header, + so that SOCK_RAW should push it back. + +On receive: +----------- + +Incoming, dev_has_header(dev) == true + mac_header -> ll header + data -> data + +Outgoing, dev_has_header(dev) == true + mac_header -> ll header + data -> ll header + +Incoming, dev_has_header(dev) == false + mac_header -> data + However drivers often make it point to the ll header. + This is incorrect because the ll header should be invisible to us. + data -> data + +Outgoing, dev_has_header(dev) == false + mac_header -> data. ll header is invisible to us. + data -> data + +Resume + If dev_has_header(dev) == false we are unable to restore the ll header, + because it is invisible to us. + + +On transmit: +------------ + +dev_has_header(dev) == true + mac_header -> ll header + data -> ll header + +dev_has_header(dev) == false (ll header is invisible to us) + mac_header -> data + data -> data + + We should set network_header on output to the correct position, + packet classifier depends on it. + */ + +/* Private packet socket structures. */ + +/* identical to struct packet_mreq except it has + * a longer address field. 
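+ *
+ * (Userspace typically passes the ordinary struct packet_mreq via
+ *  setsockopt(fd, SOL_PACKET, PACKET_ADD_MEMBERSHIP, &mreq, sizeof(mreq)),
+ *  e.g. with mr_type = PACKET_MR_PROMISC; this wider copy merely leaves
+ *  room for hardware addresses longer than the 8 bytes packet_mreq holds.)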
+ */ +struct packet_mreq_max { + int mr_ifindex; + unsigned short mr_type; + unsigned short mr_alen; + unsigned char mr_address[MAX_ADDR_LEN]; +}; + +union tpacket_uhdr { + struct tpacket_hdr *h1; + struct tpacket2_hdr *h2; + struct tpacket3_hdr *h3; + void *raw; +}; + +static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u, + int closing, int tx_ring); + +#define V3_ALIGNMENT (8) + +#define BLK_HDR_LEN (ALIGN(sizeof(struct tpacket_block_desc), V3_ALIGNMENT)) + +#define BLK_PLUS_PRIV(sz_of_priv) \ + (BLK_HDR_LEN + ALIGN((sz_of_priv), V3_ALIGNMENT)) + +#define BLOCK_STATUS(x) ((x)->hdr.bh1.block_status) +#define BLOCK_NUM_PKTS(x) ((x)->hdr.bh1.num_pkts) +#define BLOCK_O2FP(x) ((x)->hdr.bh1.offset_to_first_pkt) +#define BLOCK_LEN(x) ((x)->hdr.bh1.blk_len) +#define BLOCK_SNUM(x) ((x)->hdr.bh1.seq_num) +#define BLOCK_O2PRIV(x) ((x)->offset_to_priv) + +struct packet_sock; +static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev); + +static void *packet_previous_frame(struct packet_sock *po, + struct packet_ring_buffer *rb, + int status); +static void packet_increment_head(struct packet_ring_buffer *buff); +static int prb_curr_blk_in_use(struct tpacket_block_desc *); +static void *prb_dispatch_next_block(struct tpacket_kbdq_core *, + struct packet_sock *); +static void prb_retire_current_block(struct tpacket_kbdq_core *, + struct packet_sock *, unsigned int status); +static int prb_queue_frozen(struct tpacket_kbdq_core *); +static void prb_open_block(struct tpacket_kbdq_core *, + struct tpacket_block_desc *); +static void prb_retire_rx_blk_timer_expired(struct timer_list *); +static void _prb_refresh_rx_retire_blk_timer(struct tpacket_kbdq_core *); +static void prb_fill_rxhash(struct tpacket_kbdq_core *, struct tpacket3_hdr *); +static void prb_clear_rxhash(struct tpacket_kbdq_core *, + struct tpacket3_hdr *); +static void prb_fill_vlan_info(struct tpacket_kbdq_core *, + struct tpacket3_hdr *); +static void packet_flush_mclist(struct sock *sk); +static u16 packet_pick_tx_queue(struct sk_buff *skb); + +struct packet_skb_cb { + union { + struct sockaddr_pkt pkt; + union { + /* Trick: alias skb original length with + * ll.sll_family and ll.protocol in order + * to save room. + */ + unsigned int origlen; + struct sockaddr_ll ll; + }; + } sa; +}; + +#define vio_le() virtio_legacy_is_little_endian() + +#define PACKET_SKB_CB(__skb) ((struct packet_skb_cb *)((__skb)->cb)) + +#define GET_PBDQC_FROM_RB(x) ((struct tpacket_kbdq_core *)(&(x)->prb_bdqc)) +#define GET_PBLOCK_DESC(x, bid) \ + ((struct tpacket_block_desc *)((x)->pkbdq[(bid)].buffer)) +#define GET_CURR_PBLOCK_DESC_FROM_CORE(x) \ + ((struct tpacket_block_desc *)((x)->pkbdq[(x)->kactive_blk_num].buffer)) +#define GET_NEXT_PRB_BLK_NUM(x) \ + (((x)->kactive_blk_num < ((x)->knum_blocks-1)) ? 
\ + ((x)->kactive_blk_num+1) : 0) + +static void __fanout_unlink(struct sock *sk, struct packet_sock *po); +static void __fanout_link(struct sock *sk, struct packet_sock *po); + +#ifdef CONFIG_NETFILTER_EGRESS +static noinline struct sk_buff *nf_hook_direct_egress(struct sk_buff *skb) +{ + struct sk_buff *next, *head = NULL, *tail; + int rc; + + rcu_read_lock(); + for (; skb != NULL; skb = next) { + next = skb->next; + skb_mark_not_on_list(skb); + + if (!nf_hook_egress(skb, &rc, skb->dev)) + continue; + + if (!head) + head = skb; + else + tail->next = skb; + + tail = skb; + } + rcu_read_unlock(); + + return head; +} +#endif + +static int packet_direct_xmit(struct sk_buff *skb) +{ +#ifdef CONFIG_NETFILTER_EGRESS + if (nf_hook_egress_active()) { + skb = nf_hook_direct_egress(skb); + if (!skb) + return NET_XMIT_DROP; + } +#endif + return dev_direct_xmit(skb, packet_pick_tx_queue(skb)); +} + +static struct net_device *packet_cached_dev_get(struct packet_sock *po) +{ + struct net_device *dev; + + rcu_read_lock(); + dev = rcu_dereference(po->cached_dev); + dev_hold(dev); + rcu_read_unlock(); + + return dev; +} + +static void packet_cached_dev_assign(struct packet_sock *po, + struct net_device *dev) +{ + rcu_assign_pointer(po->cached_dev, dev); +} + +static void packet_cached_dev_reset(struct packet_sock *po) +{ + RCU_INIT_POINTER(po->cached_dev, NULL); +} + +static bool packet_use_direct_xmit(const struct packet_sock *po) +{ + /* Paired with WRITE_ONCE() in packet_setsockopt() */ + return READ_ONCE(po->xmit) == packet_direct_xmit; +} + +static u16 packet_pick_tx_queue(struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + const struct net_device_ops *ops = dev->netdev_ops; + int cpu = raw_smp_processor_id(); + u16 queue_index; + +#ifdef CONFIG_XPS + skb->sender_cpu = cpu + 1; +#endif + skb_record_rx_queue(skb, cpu % dev->real_num_tx_queues); + if (ops->ndo_select_queue) { + queue_index = ops->ndo_select_queue(dev, skb, NULL); + queue_index = netdev_cap_txqueue(dev, queue_index); + } else { + queue_index = netdev_pick_tx(dev, skb, NULL); + } + + return queue_index; +} + +/* __register_prot_hook must be invoked through register_prot_hook + * or from a context in which asynchronous accesses to the packet + * socket is not possible (packet_create()). + */ +static void __register_prot_hook(struct sock *sk) +{ + struct packet_sock *po = pkt_sk(sk); + + if (!po->running) { + if (po->fanout) + __fanout_link(sk, po); + else + dev_add_pack(&po->prot_hook); + + sock_hold(sk); + po->running = 1; + } +} + +static void register_prot_hook(struct sock *sk) +{ + lockdep_assert_held_once(&pkt_sk(sk)->bind_lock); + __register_prot_hook(sk); +} + +/* If the sync parameter is true, we will temporarily drop + * the po->bind_lock and do a synchronize_net to make sure no + * asynchronous packet processing paths still refer to the elements + * of po->prot_hook. If the sync parameter is false, it is the + * callers responsibility to take care of this. 
+ */ +static void __unregister_prot_hook(struct sock *sk, bool sync) +{ + struct packet_sock *po = pkt_sk(sk); + + lockdep_assert_held_once(&po->bind_lock); + + po->running = 0; + + if (po->fanout) + __fanout_unlink(sk, po); + else + __dev_remove_pack(&po->prot_hook); + + __sock_put(sk); + + if (sync) { + spin_unlock(&po->bind_lock); + synchronize_net(); + spin_lock(&po->bind_lock); + } +} + +static void unregister_prot_hook(struct sock *sk, bool sync) +{ + struct packet_sock *po = pkt_sk(sk); + + if (po->running) + __unregister_prot_hook(sk, sync); +} + +static inline struct page * __pure pgv_to_page(void *addr) +{ + if (is_vmalloc_addr(addr)) + return vmalloc_to_page(addr); + return virt_to_page(addr); +} + +static void __packet_set_status(struct packet_sock *po, void *frame, int status) +{ + union tpacket_uhdr h; + + /* WRITE_ONCE() are paired with READ_ONCE() in __packet_get_status */ + + h.raw = frame; + switch (po->tp_version) { + case TPACKET_V1: + WRITE_ONCE(h.h1->tp_status, status); + flush_dcache_page(pgv_to_page(&h.h1->tp_status)); + break; + case TPACKET_V2: + WRITE_ONCE(h.h2->tp_status, status); + flush_dcache_page(pgv_to_page(&h.h2->tp_status)); + break; + case TPACKET_V3: + WRITE_ONCE(h.h3->tp_status, status); + flush_dcache_page(pgv_to_page(&h.h3->tp_status)); + break; + default: + WARN(1, "TPACKET version not supported.\n"); + BUG(); + } + + smp_wmb(); +} + +static int __packet_get_status(const struct packet_sock *po, void *frame) +{ + union tpacket_uhdr h; + + smp_rmb(); + + /* READ_ONCE() are paired with WRITE_ONCE() in __packet_set_status */ + + h.raw = frame; + switch (po->tp_version) { + case TPACKET_V1: + flush_dcache_page(pgv_to_page(&h.h1->tp_status)); + return READ_ONCE(h.h1->tp_status); + case TPACKET_V2: + flush_dcache_page(pgv_to_page(&h.h2->tp_status)); + return READ_ONCE(h.h2->tp_status); + case TPACKET_V3: + flush_dcache_page(pgv_to_page(&h.h3->tp_status)); + return READ_ONCE(h.h3->tp_status); + default: + WARN(1, "TPACKET version not supported.\n"); + BUG(); + return 0; + } +} + +static __u32 tpacket_get_timestamp(struct sk_buff *skb, struct timespec64 *ts, + unsigned int flags) +{ + struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb); + + if (shhwtstamps && + (flags & SOF_TIMESTAMPING_RAW_HARDWARE) && + ktime_to_timespec64_cond(shhwtstamps->hwtstamp, ts)) + return TP_STATUS_TS_RAW_HARDWARE; + + if ((flags & SOF_TIMESTAMPING_SOFTWARE) && + ktime_to_timespec64_cond(skb_tstamp(skb), ts)) + return TP_STATUS_TS_SOFTWARE; + + return 0; +} + +static __u32 __packet_set_timestamp(struct packet_sock *po, void *frame, + struct sk_buff *skb) +{ + union tpacket_uhdr h; + struct timespec64 ts; + __u32 ts_status; + + if (!(ts_status = tpacket_get_timestamp(skb, &ts, po->tp_tstamp))) + return 0; + + h.raw = frame; + /* + * versions 1 through 3 overflow the timestamps in y2106, since they + * all store the seconds in a 32-bit unsigned integer. + * If we create a version 4, that should have a 64-bit timestamp, + * either 64-bit seconds + 32-bit nanoseconds, or just 64-bit + * nanoseconds. 
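+ * (2^32 seconds is roughly 136 years, so a 32-bit seconds count based at
+ * the 1970 epoch wraps early in 2106.)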
+ */ + switch (po->tp_version) { + case TPACKET_V1: + h.h1->tp_sec = ts.tv_sec; + h.h1->tp_usec = ts.tv_nsec / NSEC_PER_USEC; + break; + case TPACKET_V2: + h.h2->tp_sec = ts.tv_sec; + h.h2->tp_nsec = ts.tv_nsec; + break; + case TPACKET_V3: + h.h3->tp_sec = ts.tv_sec; + h.h3->tp_nsec = ts.tv_nsec; + break; + default: + WARN(1, "TPACKET version not supported.\n"); + BUG(); + } + + /* one flush is safe, as both fields always lie on the same cacheline */ + flush_dcache_page(pgv_to_page(&h.h1->tp_sec)); + smp_wmb(); + + return ts_status; +} + +static void *packet_lookup_frame(const struct packet_sock *po, + const struct packet_ring_buffer *rb, + unsigned int position, + int status) +{ + unsigned int pg_vec_pos, frame_offset; + union tpacket_uhdr h; + + pg_vec_pos = position / rb->frames_per_block; + frame_offset = position % rb->frames_per_block; + + h.raw = rb->pg_vec[pg_vec_pos].buffer + + (frame_offset * rb->frame_size); + + if (status != __packet_get_status(po, h.raw)) + return NULL; + + return h.raw; +} + +static void *packet_current_frame(struct packet_sock *po, + struct packet_ring_buffer *rb, + int status) +{ + return packet_lookup_frame(po, rb, rb->head, status); +} + +static void prb_del_retire_blk_timer(struct tpacket_kbdq_core *pkc) +{ + del_timer_sync(&pkc->retire_blk_timer); +} + +static void prb_shutdown_retire_blk_timer(struct packet_sock *po, + struct sk_buff_head *rb_queue) +{ + struct tpacket_kbdq_core *pkc; + + pkc = GET_PBDQC_FROM_RB(&po->rx_ring); + + spin_lock_bh(&rb_queue->lock); + pkc->delete_blk_timer = 1; + spin_unlock_bh(&rb_queue->lock); + + prb_del_retire_blk_timer(pkc); +} + +static void prb_setup_retire_blk_timer(struct packet_sock *po) +{ + struct tpacket_kbdq_core *pkc; + + pkc = GET_PBDQC_FROM_RB(&po->rx_ring); + timer_setup(&pkc->retire_blk_timer, prb_retire_rx_blk_timer_expired, + 0); + pkc->retire_blk_timer.expires = jiffies; +} + +static int prb_calc_retire_blk_tmo(struct packet_sock *po, + int blk_size_in_bytes) +{ + struct net_device *dev; + unsigned int mbits, div; + struct ethtool_link_ksettings ecmd; + int err; + + rtnl_lock(); + dev = __dev_get_by_index(sock_net(&po->sk), po->ifindex); + if (unlikely(!dev)) { + rtnl_unlock(); + return DEFAULT_PRB_RETIRE_TOV; + } + err = __ethtool_get_link_ksettings(dev, &ecmd); + rtnl_unlock(); + if (err) + return DEFAULT_PRB_RETIRE_TOV; + + /* If the link speed is so slow you don't really + * need to worry about perf anyways + */ + if (ecmd.base.speed < SPEED_1000 || + ecmd.base.speed == SPEED_UNKNOWN) + return DEFAULT_PRB_RETIRE_TOV; + + div = ecmd.base.speed / 1000; + mbits = (blk_size_in_bytes * 8) / (1024 * 1024); + + if (div) + mbits /= div; + + if (div) + return mbits + 1; + return mbits; +} + +static void prb_init_ft_ops(struct tpacket_kbdq_core *p1, + union tpacket_req_u *req_u) +{ + p1->feature_req_word = req_u->req3.tp_feature_req_word; +} + +static void init_prb_bdqc(struct packet_sock *po, + struct packet_ring_buffer *rb, + struct pgv *pg_vec, + union tpacket_req_u *req_u) +{ + struct tpacket_kbdq_core *p1 = GET_PBDQC_FROM_RB(rb); + struct tpacket_block_desc *pbd; + + memset(p1, 0x0, sizeof(*p1)); + + p1->knxt_seq_num = 1; + p1->pkbdq = pg_vec; + pbd = (struct tpacket_block_desc *)pg_vec[0].buffer; + p1->pkblk_start = pg_vec[0].buffer; + p1->kblk_size = req_u->req3.tp_block_size; + p1->knum_blocks = req_u->req3.tp_block_nr; + p1->hdrlen = po->tp_hdrlen; + p1->version = po->tp_version; + p1->last_kactive_blk_num = 0; + po->stats.stats3.tp_freeze_q_cnt = 0; + if (req_u->req3.tp_retire_blk_tov) + 
p1->retire_blk_tov = req_u->req3.tp_retire_blk_tov; + else + p1->retire_blk_tov = prb_calc_retire_blk_tmo(po, + req_u->req3.tp_block_size); + p1->tov_in_jiffies = msecs_to_jiffies(p1->retire_blk_tov); + p1->blk_sizeof_priv = req_u->req3.tp_sizeof_priv; + rwlock_init(&p1->blk_fill_in_prog_lock); + + p1->max_frame_len = p1->kblk_size - BLK_PLUS_PRIV(p1->blk_sizeof_priv); + prb_init_ft_ops(p1, req_u); + prb_setup_retire_blk_timer(po); + prb_open_block(p1, pbd); +} + +/* Do NOT update the last_blk_num first. + * Assumes sk_buff_head lock is held. + */ +static void _prb_refresh_rx_retire_blk_timer(struct tpacket_kbdq_core *pkc) +{ + mod_timer(&pkc->retire_blk_timer, + jiffies + pkc->tov_in_jiffies); + pkc->last_kactive_blk_num = pkc->kactive_blk_num; +} + +/* + * Timer logic: + * 1) We refresh the timer only when we open a block. + * By doing this we don't waste cycles refreshing the timer + * on packet-by-packet basis. + * + * With a 1MB block-size, on a 1Gbps line, it will take + * i) ~8 ms to fill a block + ii) memcpy etc. + * In this cut we are not accounting for the memcpy time. + * + * So, if the user sets the 'tmo' to 10ms then the timer + * will never fire while the block is still getting filled + * (which is what we want). However, the user could choose + * to close a block early and that's fine. + * + * But when the timer does fire, we check whether or not to refresh it. + * Since the tmo granularity is in msecs, it is not too expensive + * to refresh the timer, lets say every '8' msecs. + * Either the user can set the 'tmo' or we can derive it based on + * a) line-speed and b) block-size. + * prb_calc_retire_blk_tmo() calculates the tmo. + * + */ +static void prb_retire_rx_blk_timer_expired(struct timer_list *t) +{ + struct packet_sock *po = + from_timer(po, t, rx_ring.prb_bdqc.retire_blk_timer); + struct tpacket_kbdq_core *pkc = GET_PBDQC_FROM_RB(&po->rx_ring); + unsigned int frozen; + struct tpacket_block_desc *pbd; + + spin_lock(&po->sk.sk_receive_queue.lock); + + frozen = prb_queue_frozen(pkc); + pbd = GET_CURR_PBLOCK_DESC_FROM_CORE(pkc); + + if (unlikely(pkc->delete_blk_timer)) + goto out; + + /* We only need to plug the race when the block is partially filled. + * tpacket_rcv: + * lock(); increment BLOCK_NUM_PKTS; unlock() + * copy_bits() is in progress ... + * timer fires on other cpu: + * we can't retire the current block because copy_bits + * is in progress. + * + */ + if (BLOCK_NUM_PKTS(pbd)) { + /* Waiting for skb_copy_bits to finish... */ + write_lock(&pkc->blk_fill_in_prog_lock); + write_unlock(&pkc->blk_fill_in_prog_lock); + } + + if (pkc->last_kactive_blk_num == pkc->kactive_blk_num) { + if (!frozen) { + if (!BLOCK_NUM_PKTS(pbd)) { + /* An empty block. Just refresh the timer. */ + goto refresh_timer; + } + prb_retire_current_block(pkc, po, TP_STATUS_BLK_TMO); + if (!prb_dispatch_next_block(pkc, po)) + goto refresh_timer; + else + goto out; + } else { + /* Case 1. Queue was frozen because user-space was + * lagging behind. + */ + if (prb_curr_blk_in_use(pbd)) { + /* + * Ok, user-space is still behind. + * So just refresh the timer. + */ + goto refresh_timer; + } else { + /* Case 2. queue was frozen,user-space caught up, + * now the link went idle && the timer fired. + * We don't have a block to close.So we open this + * block and restart the timer. + * opening a block thaws the queue,restarts timer + * Thawing/timer-refresh is a side effect. 
+ */ + prb_open_block(pkc, pbd); + goto out; + } + } + } + +refresh_timer: + _prb_refresh_rx_retire_blk_timer(pkc); + +out: + spin_unlock(&po->sk.sk_receive_queue.lock); +} + +static void prb_flush_block(struct tpacket_kbdq_core *pkc1, + struct tpacket_block_desc *pbd1, __u32 status) +{ + /* Flush everything minus the block header */ + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1 + u8 *start, *end; + + start = (u8 *)pbd1; + + /* Skip the block header(we know header WILL fit in 4K) */ + start += PAGE_SIZE; + + end = (u8 *)PAGE_ALIGN((unsigned long)pkc1->pkblk_end); + for (; start < end; start += PAGE_SIZE) + flush_dcache_page(pgv_to_page(start)); + + smp_wmb(); +#endif + + /* Now update the block status. */ + + BLOCK_STATUS(pbd1) = status; + + /* Flush the block header */ + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1 + start = (u8 *)pbd1; + flush_dcache_page(pgv_to_page(start)); + + smp_wmb(); +#endif +} + +/* + * Side effect: + * + * 1) flush the block + * 2) Increment active_blk_num + * + * Note:We DONT refresh the timer on purpose. + * Because almost always the next block will be opened. + */ +static void prb_close_block(struct tpacket_kbdq_core *pkc1, + struct tpacket_block_desc *pbd1, + struct packet_sock *po, unsigned int stat) +{ + __u32 status = TP_STATUS_USER | stat; + + struct tpacket3_hdr *last_pkt; + struct tpacket_hdr_v1 *h1 = &pbd1->hdr.bh1; + struct sock *sk = &po->sk; + + if (atomic_read(&po->tp_drops)) + status |= TP_STATUS_LOSING; + + last_pkt = (struct tpacket3_hdr *)pkc1->prev; + last_pkt->tp_next_offset = 0; + + /* Get the ts of the last pkt */ + if (BLOCK_NUM_PKTS(pbd1)) { + h1->ts_last_pkt.ts_sec = last_pkt->tp_sec; + h1->ts_last_pkt.ts_nsec = last_pkt->tp_nsec; + } else { + /* Ok, we tmo'd - so get the current time. + * + * It shouldn't really happen as we don't close empty + * blocks. See prb_retire_rx_blk_timer_expired(). + */ + struct timespec64 ts; + ktime_get_real_ts64(&ts); + h1->ts_last_pkt.ts_sec = ts.tv_sec; + h1->ts_last_pkt.ts_nsec = ts.tv_nsec; + } + + smp_wmb(); + + /* Flush the block */ + prb_flush_block(pkc1, pbd1, status); + + sk->sk_data_ready(sk); + + pkc1->kactive_blk_num = GET_NEXT_PRB_BLK_NUM(pkc1); +} + +static void prb_thaw_queue(struct tpacket_kbdq_core *pkc) +{ + pkc->reset_pending_on_curr_blk = 0; +} + +/* + * Side effect of opening a block: + * + * 1) prb_queue is thawed. + * 2) retire_blk_timer is refreshed. + * + */ +static void prb_open_block(struct tpacket_kbdq_core *pkc1, + struct tpacket_block_desc *pbd1) +{ + struct timespec64 ts; + struct tpacket_hdr_v1 *h1 = &pbd1->hdr.bh1; + + smp_rmb(); + + /* We could have just memset this but we will lose the + * flexibility of making the priv area sticky + */ + + BLOCK_SNUM(pbd1) = pkc1->knxt_seq_num++; + BLOCK_NUM_PKTS(pbd1) = 0; + BLOCK_LEN(pbd1) = BLK_PLUS_PRIV(pkc1->blk_sizeof_priv); + + ktime_get_real_ts64(&ts); + + h1->ts_first_pkt.ts_sec = ts.tv_sec; + h1->ts_first_pkt.ts_nsec = ts.tv_nsec; + + pkc1->pkblk_start = (char *)pbd1; + pkc1->nxt_offset = pkc1->pkblk_start + BLK_PLUS_PRIV(pkc1->blk_sizeof_priv); + + BLOCK_O2FP(pbd1) = (__u32)BLK_PLUS_PRIV(pkc1->blk_sizeof_priv); + BLOCK_O2PRIV(pbd1) = BLK_HDR_LEN; + + pbd1->version = pkc1->version; + pkc1->prev = pkc1->nxt_offset; + pkc1->pkblk_end = pkc1->pkblk_start + pkc1->kblk_size; + + prb_thaw_queue(pkc1); + _prb_refresh_rx_retire_blk_timer(pkc1); + + smp_wmb(); +} + +/* + * Queue freeze logic: + * 1) Assume tp_block_nr = 8 blocks. + * 2) At time 't0', user opens Rx ring. 
+ * 3) Some time past 't0', kernel starts filling blocks starting from 0 .. 7 + * 4) user-space is either sleeping or processing block '0'. + * 5) tpacket_rcv is currently filling block '7', since there is no space left, + * it will close block-7,loop around and try to fill block '0'. + * call-flow: + * __packet_lookup_frame_in_block + * prb_retire_current_block() + * prb_dispatch_next_block() + * |->(BLOCK_STATUS == USER) evaluates to true + * 5.1) Since block-0 is currently in-use, we just freeze the queue. + * 6) Now there are two cases: + * 6.1) Link goes idle right after the queue is frozen. + * But remember, the last open_block() refreshed the timer. + * When this timer expires,it will refresh itself so that we can + * re-open block-0 in near future. + * 6.2) Link is busy and keeps on receiving packets. This is a simple + * case and __packet_lookup_frame_in_block will check if block-0 + * is free and can now be re-used. + */ +static void prb_freeze_queue(struct tpacket_kbdq_core *pkc, + struct packet_sock *po) +{ + pkc->reset_pending_on_curr_blk = 1; + po->stats.stats3.tp_freeze_q_cnt++; +} + +#define TOTAL_PKT_LEN_INCL_ALIGN(length) (ALIGN((length), V3_ALIGNMENT)) + +/* + * If the next block is free then we will dispatch it + * and return a good offset. + * Else, we will freeze the queue. + * So, caller must check the return value. + */ +static void *prb_dispatch_next_block(struct tpacket_kbdq_core *pkc, + struct packet_sock *po) +{ + struct tpacket_block_desc *pbd; + + smp_rmb(); + + /* 1. Get current block num */ + pbd = GET_CURR_PBLOCK_DESC_FROM_CORE(pkc); + + /* 2. If this block is currently in_use then freeze the queue */ + if (TP_STATUS_USER & BLOCK_STATUS(pbd)) { + prb_freeze_queue(pkc, po); + return NULL; + } + + /* + * 3. + * open this block and return the offset where the first packet + * needs to get stored. + */ + prb_open_block(pkc, pbd); + return (void *)pkc->nxt_offset; +} + +static void prb_retire_current_block(struct tpacket_kbdq_core *pkc, + struct packet_sock *po, unsigned int status) +{ + struct tpacket_block_desc *pbd = GET_CURR_PBLOCK_DESC_FROM_CORE(pkc); + + /* retire/close the current block */ + if (likely(TP_STATUS_KERNEL == BLOCK_STATUS(pbd))) { + /* + * Plug the case where copy_bits() is in progress on + * cpu-0 and tpacket_rcv() got invoked on cpu-1, didn't + * have space to copy the pkt in the current block and + * called prb_retire_current_block() + * + * We don't need to worry about the TMO case because + * the timer-handler already handled this case. + */ + if (!(status & TP_STATUS_BLK_TMO)) { + /* Waiting for skb_copy_bits to finish... 
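+ * (prb_fill_curr_block() holds the read side of blk_fill_in_prog_lock
+ * while a packet is being copied into the block; taking and immediately
+ * releasing the write side here simply waits for any such reader to
+ * drop it.)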
*/ + write_lock(&pkc->blk_fill_in_prog_lock); + write_unlock(&pkc->blk_fill_in_prog_lock); + } + prb_close_block(pkc, pbd, po, status); + return; + } +} + +static int prb_curr_blk_in_use(struct tpacket_block_desc *pbd) +{ + return TP_STATUS_USER & BLOCK_STATUS(pbd); +} + +static int prb_queue_frozen(struct tpacket_kbdq_core *pkc) +{ + return pkc->reset_pending_on_curr_blk; +} + +static void prb_clear_blk_fill_status(struct packet_ring_buffer *rb) + __releases(&pkc->blk_fill_in_prog_lock) +{ + struct tpacket_kbdq_core *pkc = GET_PBDQC_FROM_RB(rb); + + read_unlock(&pkc->blk_fill_in_prog_lock); +} + +static void prb_fill_rxhash(struct tpacket_kbdq_core *pkc, + struct tpacket3_hdr *ppd) +{ + ppd->hv1.tp_rxhash = skb_get_hash(pkc->skb); +} + +static void prb_clear_rxhash(struct tpacket_kbdq_core *pkc, + struct tpacket3_hdr *ppd) +{ + ppd->hv1.tp_rxhash = 0; +} + +static void prb_fill_vlan_info(struct tpacket_kbdq_core *pkc, + struct tpacket3_hdr *ppd) +{ + if (skb_vlan_tag_present(pkc->skb)) { + ppd->hv1.tp_vlan_tci = skb_vlan_tag_get(pkc->skb); + ppd->hv1.tp_vlan_tpid = ntohs(pkc->skb->vlan_proto); + ppd->tp_status = TP_STATUS_VLAN_VALID | TP_STATUS_VLAN_TPID_VALID; + } else { + ppd->hv1.tp_vlan_tci = 0; + ppd->hv1.tp_vlan_tpid = 0; + ppd->tp_status = TP_STATUS_AVAILABLE; + } +} + +static void prb_run_all_ft_ops(struct tpacket_kbdq_core *pkc, + struct tpacket3_hdr *ppd) +{ + ppd->hv1.tp_padding = 0; + prb_fill_vlan_info(pkc, ppd); + + if (pkc->feature_req_word & TP_FT_REQ_FILL_RXHASH) + prb_fill_rxhash(pkc, ppd); + else + prb_clear_rxhash(pkc, ppd); +} + +static void prb_fill_curr_block(char *curr, + struct tpacket_kbdq_core *pkc, + struct tpacket_block_desc *pbd, + unsigned int len) + __acquires(&pkc->blk_fill_in_prog_lock) +{ + struct tpacket3_hdr *ppd; + + ppd = (struct tpacket3_hdr *)curr; + ppd->tp_next_offset = TOTAL_PKT_LEN_INCL_ALIGN(len); + pkc->prev = curr; + pkc->nxt_offset += TOTAL_PKT_LEN_INCL_ALIGN(len); + BLOCK_LEN(pbd) += TOTAL_PKT_LEN_INCL_ALIGN(len); + BLOCK_NUM_PKTS(pbd) += 1; + read_lock(&pkc->blk_fill_in_prog_lock); + prb_run_all_ft_ops(pkc, ppd); +} + +/* Assumes caller has the sk->rx_queue.lock */ +static void *__packet_lookup_frame_in_block(struct packet_sock *po, + struct sk_buff *skb, + unsigned int len + ) +{ + struct tpacket_kbdq_core *pkc; + struct tpacket_block_desc *pbd; + char *curr, *end; + + pkc = GET_PBDQC_FROM_RB(&po->rx_ring); + pbd = GET_CURR_PBLOCK_DESC_FROM_CORE(pkc); + + /* Queue is frozen when user space is lagging behind */ + if (prb_queue_frozen(pkc)) { + /* + * Check if that last block which caused the queue to freeze, + * is still in_use by user-space. + */ + if (prb_curr_blk_in_use(pbd)) { + /* Can't record this packet */ + return NULL; + } else { + /* + * Ok, the block was released by user-space. + * Now let's open that block. + * opening a block also thaws the queue. + * Thawing is a side effect. 
+ */ + prb_open_block(pkc, pbd); + } + } + + smp_mb(); + curr = pkc->nxt_offset; + pkc->skb = skb; + end = (char *)pbd + pkc->kblk_size; + + /* first try the current block */ + if (curr+TOTAL_PKT_LEN_INCL_ALIGN(len) < end) { + prb_fill_curr_block(curr, pkc, pbd, len); + return (void *)curr; + } + + /* Ok, close the current block */ + prb_retire_current_block(pkc, po, 0); + + /* Now, try to dispatch the next block */ + curr = (char *)prb_dispatch_next_block(pkc, po); + if (curr) { + pbd = GET_CURR_PBLOCK_DESC_FROM_CORE(pkc); + prb_fill_curr_block(curr, pkc, pbd, len); + return (void *)curr; + } + + /* + * No free blocks are available.user_space hasn't caught up yet. + * Queue was just frozen and now this packet will get dropped. + */ + return NULL; +} + +static void *packet_current_rx_frame(struct packet_sock *po, + struct sk_buff *skb, + int status, unsigned int len) +{ + char *curr = NULL; + switch (po->tp_version) { + case TPACKET_V1: + case TPACKET_V2: + curr = packet_lookup_frame(po, &po->rx_ring, + po->rx_ring.head, status); + return curr; + case TPACKET_V3: + return __packet_lookup_frame_in_block(po, skb, len); + default: + WARN(1, "TPACKET version not supported\n"); + BUG(); + return NULL; + } +} + +static void *prb_lookup_block(const struct packet_sock *po, + const struct packet_ring_buffer *rb, + unsigned int idx, + int status) +{ + struct tpacket_kbdq_core *pkc = GET_PBDQC_FROM_RB(rb); + struct tpacket_block_desc *pbd = GET_PBLOCK_DESC(pkc, idx); + + if (status != BLOCK_STATUS(pbd)) + return NULL; + return pbd; +} + +static int prb_previous_blk_num(struct packet_ring_buffer *rb) +{ + unsigned int prev; + if (rb->prb_bdqc.kactive_blk_num) + prev = rb->prb_bdqc.kactive_blk_num-1; + else + prev = rb->prb_bdqc.knum_blocks-1; + return prev; +} + +/* Assumes caller has held the rx_queue.lock */ +static void *__prb_previous_block(struct packet_sock *po, + struct packet_ring_buffer *rb, + int status) +{ + unsigned int previous = prb_previous_blk_num(rb); + return prb_lookup_block(po, rb, previous, status); +} + +static void *packet_previous_rx_frame(struct packet_sock *po, + struct packet_ring_buffer *rb, + int status) +{ + if (po->tp_version <= TPACKET_V2) + return packet_previous_frame(po, rb, status); + + return __prb_previous_block(po, rb, status); +} + +static void packet_increment_rx_head(struct packet_sock *po, + struct packet_ring_buffer *rb) +{ + switch (po->tp_version) { + case TPACKET_V1: + case TPACKET_V2: + return packet_increment_head(rb); + case TPACKET_V3: + default: + WARN(1, "TPACKET version not supported.\n"); + BUG(); + return; + } +} + +static void *packet_previous_frame(struct packet_sock *po, + struct packet_ring_buffer *rb, + int status) +{ + unsigned int previous = rb->head ? rb->head - 1 : rb->frame_max; + return packet_lookup_frame(po, rb, previous, status); +} + +static void packet_increment_head(struct packet_ring_buffer *buff) +{ + buff->head = buff->head != buff->frame_max ? buff->head+1 : 0; +} + +static void packet_inc_pending(struct packet_ring_buffer *rb) +{ + this_cpu_inc(*rb->pending_refcnt); +} + +static void packet_dec_pending(struct packet_ring_buffer *rb) +{ + this_cpu_dec(*rb->pending_refcnt); +} + +static unsigned int packet_read_pending(const struct packet_ring_buffer *rb) +{ + unsigned int refcnt = 0; + int cpu; + + /* We don't use pending refcount in rx_ring. 
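+ * (packet_alloc_pending() only allocates it for the tx_ring, so the
+ * NULL check below is the rx_ring fast path.)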
*/ + if (rb->pending_refcnt == NULL) + return 0; + + for_each_possible_cpu(cpu) + refcnt += *per_cpu_ptr(rb->pending_refcnt, cpu); + + return refcnt; +} + +static int packet_alloc_pending(struct packet_sock *po) +{ + po->rx_ring.pending_refcnt = NULL; + + po->tx_ring.pending_refcnt = alloc_percpu(unsigned int); + if (unlikely(po->tx_ring.pending_refcnt == NULL)) + return -ENOBUFS; + + return 0; +} + +static void packet_free_pending(struct packet_sock *po) +{ + free_percpu(po->tx_ring.pending_refcnt); +} + +#define ROOM_POW_OFF 2 +#define ROOM_NONE 0x0 +#define ROOM_LOW 0x1 +#define ROOM_NORMAL 0x2 + +static bool __tpacket_has_room(const struct packet_sock *po, int pow_off) +{ + int idx, len; + + len = READ_ONCE(po->rx_ring.frame_max) + 1; + idx = READ_ONCE(po->rx_ring.head); + if (pow_off) + idx += len >> pow_off; + if (idx >= len) + idx -= len; + return packet_lookup_frame(po, &po->rx_ring, idx, TP_STATUS_KERNEL); +} + +static bool __tpacket_v3_has_room(const struct packet_sock *po, int pow_off) +{ + int idx, len; + + len = READ_ONCE(po->rx_ring.prb_bdqc.knum_blocks); + idx = READ_ONCE(po->rx_ring.prb_bdqc.kactive_blk_num); + if (pow_off) + idx += len >> pow_off; + if (idx >= len) + idx -= len; + return prb_lookup_block(po, &po->rx_ring, idx, TP_STATUS_KERNEL); +} + +static int __packet_rcv_has_room(const struct packet_sock *po, + const struct sk_buff *skb) +{ + const struct sock *sk = &po->sk; + int ret = ROOM_NONE; + + if (po->prot_hook.func != tpacket_rcv) { + int rcvbuf = READ_ONCE(sk->sk_rcvbuf); + int avail = rcvbuf - atomic_read(&sk->sk_rmem_alloc) + - (skb ? skb->truesize : 0); + + if (avail > (rcvbuf >> ROOM_POW_OFF)) + return ROOM_NORMAL; + else if (avail > 0) + return ROOM_LOW; + else + return ROOM_NONE; + } + + if (po->tp_version == TPACKET_V3) { + if (__tpacket_v3_has_room(po, ROOM_POW_OFF)) + ret = ROOM_NORMAL; + else if (__tpacket_v3_has_room(po, 0)) + ret = ROOM_LOW; + } else { + if (__tpacket_has_room(po, ROOM_POW_OFF)) + ret = ROOM_NORMAL; + else if (__tpacket_has_room(po, 0)) + ret = ROOM_LOW; + } + + return ret; +} + +static int packet_rcv_has_room(struct packet_sock *po, struct sk_buff *skb) +{ + int pressure, ret; + + ret = __packet_rcv_has_room(po, skb); + pressure = ret != ROOM_NORMAL; + + if (READ_ONCE(po->pressure) != pressure) + WRITE_ONCE(po->pressure, pressure); + + return ret; +} + +static void packet_rcv_try_clear_pressure(struct packet_sock *po) +{ + if (READ_ONCE(po->pressure) && + __packet_rcv_has_room(po, NULL) == ROOM_NORMAL) + WRITE_ONCE(po->pressure, 0); +} + +static void packet_sock_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_error_queue); + + WARN_ON(atomic_read(&sk->sk_rmem_alloc)); + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); + + if (!sock_flag(sk, SOCK_DEAD)) { + pr_err("Attempt to release alive packet socket: %p\n", sk); + return; + } + + sk_refcnt_debug_dec(sk); +} + +static bool fanout_flow_is_huge(struct packet_sock *po, struct sk_buff *skb) +{ + u32 *history = po->rollover->history; + u32 victim, rxhash; + int i, count = 0; + + rxhash = skb_get_hash(skb); + for (i = 0; i < ROLLOVER_HLEN; i++) + if (READ_ONCE(history[i]) == rxhash) + count++; + + victim = prandom_u32_max(ROLLOVER_HLEN); + + /* Avoid dirtying the cache line if possible */ + if (READ_ONCE(history[victim]) != rxhash) + WRITE_ONCE(history[victim], rxhash); + + return count > (ROLLOVER_HLEN >> 1); +} + +static unsigned int fanout_demux_hash(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + return reciprocal_scale(__skb_get_hash_symmetric(skb), 
num); +} + +static unsigned int fanout_demux_lb(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + unsigned int val = atomic_inc_return(&f->rr_cur); + + return val % num; +} + +static unsigned int fanout_demux_cpu(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + return smp_processor_id() % num; +} + +static unsigned int fanout_demux_rnd(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + return prandom_u32_max(num); +} + +static unsigned int fanout_demux_rollover(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int idx, bool try_self, + unsigned int num) +{ + struct packet_sock *po, *po_next, *po_skip = NULL; + unsigned int i, j, room = ROOM_NONE; + + po = pkt_sk(rcu_dereference(f->arr[idx])); + + if (try_self) { + room = packet_rcv_has_room(po, skb); + if (room == ROOM_NORMAL || + (room == ROOM_LOW && !fanout_flow_is_huge(po, skb))) + return idx; + po_skip = po; + } + + i = j = min_t(int, po->rollover->sock, num - 1); + do { + po_next = pkt_sk(rcu_dereference(f->arr[i])); + if (po_next != po_skip && !READ_ONCE(po_next->pressure) && + packet_rcv_has_room(po_next, skb) == ROOM_NORMAL) { + if (i != j) + po->rollover->sock = i; + atomic_long_inc(&po->rollover->num); + if (room == ROOM_LOW) + atomic_long_inc(&po->rollover->num_huge); + return i; + } + + if (++i == num) + i = 0; + } while (i != j); + + atomic_long_inc(&po->rollover->num_failed); + return idx; +} + +static unsigned int fanout_demux_qm(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + return skb_get_queue_mapping(skb) % num; +} + +static unsigned int fanout_demux_bpf(struct packet_fanout *f, + struct sk_buff *skb, + unsigned int num) +{ + struct bpf_prog *prog; + unsigned int ret = 0; + + rcu_read_lock(); + prog = rcu_dereference(f->bpf_prog); + if (prog) + ret = bpf_prog_run_clear_cb(prog, skb) % num; + rcu_read_unlock(); + + return ret; +} + +static bool fanout_has_flag(struct packet_fanout *f, u16 flag) +{ + return f->flags & (flag >> 8); +} + +static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct packet_fanout *f = pt->af_packet_priv; + unsigned int num = READ_ONCE(f->num_members); + struct net *net = read_pnet(&f->net); + struct packet_sock *po; + unsigned int idx; + + if (!net_eq(dev_net(dev), net) || !num) { + kfree_skb(skb); + return 0; + } + + if (fanout_has_flag(f, PACKET_FANOUT_FLAG_DEFRAG)) { + skb = ip_check_defrag(net, skb, IP_DEFRAG_AF_PACKET); + if (!skb) + return 0; + } + switch (f->type) { + case PACKET_FANOUT_HASH: + default: + idx = fanout_demux_hash(f, skb, num); + break; + case PACKET_FANOUT_LB: + idx = fanout_demux_lb(f, skb, num); + break; + case PACKET_FANOUT_CPU: + idx = fanout_demux_cpu(f, skb, num); + break; + case PACKET_FANOUT_RND: + idx = fanout_demux_rnd(f, skb, num); + break; + case PACKET_FANOUT_QM: + idx = fanout_demux_qm(f, skb, num); + break; + case PACKET_FANOUT_ROLLOVER: + idx = fanout_demux_rollover(f, skb, 0, false, num); + break; + case PACKET_FANOUT_CBPF: + case PACKET_FANOUT_EBPF: + idx = fanout_demux_bpf(f, skb, num); + break; + } + + if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER)) + idx = fanout_demux_rollover(f, skb, idx, true, num); + + po = pkt_sk(rcu_dereference(f->arr[idx])); + return po->prot_hook.func(skb, dev, &po->prot_hook, orig_dev); +} + +DEFINE_MUTEX(fanout_mutex); +EXPORT_SYMBOL_GPL(fanout_mutex); +static LIST_HEAD(fanout_list); +static u16 fanout_next_id; + +static void 
__fanout_link(struct sock *sk, struct packet_sock *po) +{ + struct packet_fanout *f = po->fanout; + + spin_lock(&f->lock); + rcu_assign_pointer(f->arr[f->num_members], sk); + smp_wmb(); + f->num_members++; + if (f->num_members == 1) + dev_add_pack(&f->prot_hook); + spin_unlock(&f->lock); +} + +static void __fanout_unlink(struct sock *sk, struct packet_sock *po) +{ + struct packet_fanout *f = po->fanout; + int i; + + spin_lock(&f->lock); + for (i = 0; i < f->num_members; i++) { + if (rcu_dereference_protected(f->arr[i], + lockdep_is_held(&f->lock)) == sk) + break; + } + BUG_ON(i >= f->num_members); + rcu_assign_pointer(f->arr[i], + rcu_dereference_protected(f->arr[f->num_members - 1], + lockdep_is_held(&f->lock))); + f->num_members--; + if (f->num_members == 0) + __dev_remove_pack(&f->prot_hook); + spin_unlock(&f->lock); +} + +static bool match_fanout_group(struct packet_type *ptype, struct sock *sk) +{ + if (sk->sk_family != PF_PACKET) + return false; + + return ptype->af_packet_priv == pkt_sk(sk)->fanout; +} + +static void fanout_init_data(struct packet_fanout *f) +{ + switch (f->type) { + case PACKET_FANOUT_LB: + atomic_set(&f->rr_cur, 0); + break; + case PACKET_FANOUT_CBPF: + case PACKET_FANOUT_EBPF: + RCU_INIT_POINTER(f->bpf_prog, NULL); + break; + } +} + +static void __fanout_set_data_bpf(struct packet_fanout *f, struct bpf_prog *new) +{ + struct bpf_prog *old; + + spin_lock(&f->lock); + old = rcu_dereference_protected(f->bpf_prog, lockdep_is_held(&f->lock)); + rcu_assign_pointer(f->bpf_prog, new); + spin_unlock(&f->lock); + + if (old) { + synchronize_net(); + bpf_prog_destroy(old); + } +} + +static int fanout_set_data_cbpf(struct packet_sock *po, sockptr_t data, + unsigned int len) +{ + struct bpf_prog *new; + struct sock_fprog fprog; + int ret; + + if (sock_flag(&po->sk, SOCK_FILTER_LOCKED)) + return -EPERM; + + ret = copy_bpf_fprog_from_user(&fprog, data, len); + if (ret) + return ret; + + ret = bpf_prog_create_from_user(&new, &fprog, NULL, false); + if (ret) + return ret; + + __fanout_set_data_bpf(po->fanout, new); + return 0; +} + +static int fanout_set_data_ebpf(struct packet_sock *po, sockptr_t data, + unsigned int len) +{ + struct bpf_prog *new; + u32 fd; + + if (sock_flag(&po->sk, SOCK_FILTER_LOCKED)) + return -EPERM; + if (len != sizeof(fd)) + return -EINVAL; + if (copy_from_sockptr(&fd, data, len)) + return -EFAULT; + + new = bpf_prog_get_type(fd, BPF_PROG_TYPE_SOCKET_FILTER); + if (IS_ERR(new)) + return PTR_ERR(new); + + __fanout_set_data_bpf(po->fanout, new); + return 0; +} + +static int fanout_set_data(struct packet_sock *po, sockptr_t data, + unsigned int len) +{ + switch (po->fanout->type) { + case PACKET_FANOUT_CBPF: + return fanout_set_data_cbpf(po, data, len); + case PACKET_FANOUT_EBPF: + return fanout_set_data_ebpf(po, data, len); + default: + return -EINVAL; + } +} + +static void fanout_release_data(struct packet_fanout *f) +{ + switch (f->type) { + case PACKET_FANOUT_CBPF: + case PACKET_FANOUT_EBPF: + __fanout_set_data_bpf(f, NULL); + } +} + +static bool __fanout_id_is_free(struct sock *sk, u16 candidate_id) +{ + struct packet_fanout *f; + + list_for_each_entry(f, &fanout_list, list) { + if (f->id == candidate_id && + read_pnet(&f->net) == sock_net(sk)) { + return false; + } + } + return true; +} + +static bool fanout_find_new_id(struct sock *sk, u16 *new_id) +{ + u16 id = fanout_next_id; + + do { + if (__fanout_id_is_free(sk, id)) { + *new_id = id; + fanout_next_id = id + 1; + return true; + } + + id++; + } while (id != fanout_next_id); + + return false; +} + 
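+/* A userspace-side sketch (illustrative, not part of this file): fanout_add()
+ * below is reached through setsockopt(PACKET_FANOUT) on a bound packet
+ * socket, with the group id in the low 16 bits and type/flags in the high
+ * 16 bits:
+ *
+ *	int fd = socket(AF_PACKET, SOCK_RAW, htons(ETH_P_ALL));
+ *	// ... bind(fd, ...) to an interface, then:
+ *	int arg = group_id | (PACKET_FANOUT_HASH << 16);
+ *	setsockopt(fd, SOL_PACKET, PACKET_FANOUT, &arg, sizeof(arg));
+ *
+ * Sockets that join with the same id in the same network namespace land in
+ * one struct packet_fanout, and packet_rcv_fanout() then demuxes each packet
+ * to a single member according to the chosen type.
+ */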
+static int fanout_add(struct sock *sk, struct fanout_args *args) +{ + struct packet_rollover *rollover = NULL; + struct packet_sock *po = pkt_sk(sk); + u16 type_flags = args->type_flags; + struct packet_fanout *f, *match; + u8 type = type_flags & 0xff; + u8 flags = type_flags >> 8; + u16 id = args->id; + int err; + + switch (type) { + case PACKET_FANOUT_ROLLOVER: + if (type_flags & PACKET_FANOUT_FLAG_ROLLOVER) + return -EINVAL; + break; + case PACKET_FANOUT_HASH: + case PACKET_FANOUT_LB: + case PACKET_FANOUT_CPU: + case PACKET_FANOUT_RND: + case PACKET_FANOUT_QM: + case PACKET_FANOUT_CBPF: + case PACKET_FANOUT_EBPF: + break; + default: + return -EINVAL; + } + + mutex_lock(&fanout_mutex); + + err = -EALREADY; + if (po->fanout) + goto out; + + if (type == PACKET_FANOUT_ROLLOVER || + (type_flags & PACKET_FANOUT_FLAG_ROLLOVER)) { + err = -ENOMEM; + rollover = kzalloc(sizeof(*rollover), GFP_KERNEL); + if (!rollover) + goto out; + atomic_long_set(&rollover->num, 0); + atomic_long_set(&rollover->num_huge, 0); + atomic_long_set(&rollover->num_failed, 0); + } + + if (type_flags & PACKET_FANOUT_FLAG_UNIQUEID) { + if (id != 0) { + err = -EINVAL; + goto out; + } + if (!fanout_find_new_id(sk, &id)) { + err = -ENOMEM; + goto out; + } + /* ephemeral flag for the first socket in the group: drop it */ + flags &= ~(PACKET_FANOUT_FLAG_UNIQUEID >> 8); + } + + match = NULL; + list_for_each_entry(f, &fanout_list, list) { + if (f->id == id && + read_pnet(&f->net) == sock_net(sk)) { + match = f; + break; + } + } + err = -EINVAL; + if (match) { + if (match->flags != flags) + goto out; + if (args->max_num_members && + args->max_num_members != match->max_num_members) + goto out; + } else { + if (args->max_num_members > PACKET_FANOUT_MAX) + goto out; + if (!args->max_num_members) + /* legacy PACKET_FANOUT_MAX */ + args->max_num_members = 256; + err = -ENOMEM; + match = kvzalloc(struct_size(match, arr, args->max_num_members), + GFP_KERNEL); + if (!match) + goto out; + write_pnet(&match->net, sock_net(sk)); + match->id = id; + match->type = type; + match->flags = flags; + INIT_LIST_HEAD(&match->list); + spin_lock_init(&match->lock); + refcount_set(&match->sk_ref, 0); + fanout_init_data(match); + match->prot_hook.type = po->prot_hook.type; + match->prot_hook.dev = po->prot_hook.dev; + match->prot_hook.func = packet_rcv_fanout; + match->prot_hook.af_packet_priv = match; + match->prot_hook.af_packet_net = read_pnet(&match->net); + match->prot_hook.id_match = match_fanout_group; + match->max_num_members = args->max_num_members; + list_add(&match->list, &fanout_list); + } + err = -EINVAL; + + spin_lock(&po->bind_lock); + if (po->running && + match->type == type && + match->prot_hook.type == po->prot_hook.type && + match->prot_hook.dev == po->prot_hook.dev) { + err = -ENOSPC; + if (refcount_read(&match->sk_ref) < match->max_num_members) { + __dev_remove_pack(&po->prot_hook); + + /* Paired with packet_setsockopt(PACKET_FANOUT_DATA) */ + WRITE_ONCE(po->fanout, match); + + po->rollover = rollover; + rollover = NULL; + refcount_set(&match->sk_ref, refcount_read(&match->sk_ref) + 1); + __fanout_link(sk, po); + err = 0; + } + } + spin_unlock(&po->bind_lock); + + if (err && !refcount_read(&match->sk_ref)) { + list_del(&match->list); + kvfree(match); + } + +out: + kfree(rollover); + mutex_unlock(&fanout_mutex); + return err; +} + +/* If pkt_sk(sk)->fanout->sk_ref is zero, this function removes + * pkt_sk(sk)->fanout from fanout_list and returns pkt_sk(sk)->fanout. 
+ * It is the responsibility of the caller to call fanout_release_data() and + * free the returned packet_fanout (after synchronize_net()) + */ +static struct packet_fanout *fanout_release(struct sock *sk) +{ + struct packet_sock *po = pkt_sk(sk); + struct packet_fanout *f; + + mutex_lock(&fanout_mutex); + f = po->fanout; + if (f) { + po->fanout = NULL; + + if (refcount_dec_and_test(&f->sk_ref)) + list_del(&f->list); + else + f = NULL; + } + mutex_unlock(&fanout_mutex); + + return f; +} + +static bool packet_extra_vlan_len_allowed(const struct net_device *dev, + struct sk_buff *skb) +{ + /* Earlier code assumed this would be a VLAN pkt, double-check + * this now that we have the actual packet in hand. We can only + * do this check on Ethernet devices. + */ + if (unlikely(dev->type != ARPHRD_ETHER)) + return false; + + skb_reset_mac_header(skb); + return likely(eth_hdr(skb)->h_proto == htons(ETH_P_8021Q)); +} + +static const struct proto_ops packet_ops; + +static const struct proto_ops packet_ops_spkt; + +static int packet_rcv_spkt(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct sock *sk; + struct sockaddr_pkt *spkt; + + /* + * When we registered the protocol we saved the socket in the data + * field for just this event. + */ + + sk = pt->af_packet_priv; + + /* + * Yank back the headers [hope the device set this + * right or kerboom...] + * + * Incoming packets have ll header pulled, + * push it back. + * + * For outgoing ones skb->data == skb_mac_header(skb) + * so that this procedure is noop. + */ + + if (skb->pkt_type == PACKET_LOOPBACK) + goto out; + + if (!net_eq(dev_net(dev), sock_net(sk))) + goto out; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (skb == NULL) + goto oom; + + /* drop any routing info */ + skb_dst_drop(skb); + + /* drop conntrack reference */ + nf_reset_ct(skb); + + spkt = &PACKET_SKB_CB(skb)->sa.pkt; + + skb_push(skb, skb->data - skb_mac_header(skb)); + + /* + * The SOCK_PACKET socket receives _all_ frames. + */ + + spkt->spkt_family = dev->type; + strscpy(spkt->spkt_device, dev->name, sizeof(spkt->spkt_device)); + spkt->spkt_protocol = skb->protocol; + + /* + * Charge the memory to the socket. This is done specifically + * to prevent sockets using all the memory up. + */ + + if (sock_queue_rcv_skb(sk, skb) == 0) + return 0; + +out: + kfree_skb(skb); +oom: + return 0; +} + +static void packet_parse_headers(struct sk_buff *skb, struct socket *sock) +{ + int depth; + + if ((!skb->protocol || skb->protocol == htons(ETH_P_ALL)) && + sock->type == SOCK_RAW) { + skb_reset_mac_header(skb); + skb->protocol = dev_parse_header_protocol(skb); + } + + /* Move network header to the right position for VLAN tagged packets */ + if (likely(skb->dev->type == ARPHRD_ETHER) && + eth_type_vlan(skb->protocol) && + vlan_get_protocol_and_depth(skb, skb->protocol, &depth) != 0) + skb_set_network_header(skb, depth); + + skb_probe_transport_header(skb); +} + +/* + * Output a raw packet to a device layer. This bypasses all the other + * protocol layers and you must therefore supply it with a complete frame + */ + +static int packet_sendmsg_spkt(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + DECLARE_SOCKADDR(struct sockaddr_pkt *, saddr, msg->msg_name); + struct sk_buff *skb = NULL; + struct net_device *dev; + struct sockcm_cookie sockc; + __be16 proto = 0; + int err; + int extra_len = 0; + + /* + * Get and verify the address. 
+ */ + + if (saddr) { + if (msg->msg_namelen < sizeof(struct sockaddr)) + return -EINVAL; + if (msg->msg_namelen == sizeof(struct sockaddr_pkt)) + proto = saddr->spkt_protocol; + } else + return -ENOTCONN; /* SOCK_PACKET must be sent giving an address */ + + /* + * Find the device first to size check it + */ + + saddr->spkt_device[sizeof(saddr->spkt_device) - 1] = 0; +retry: + rcu_read_lock(); + dev = dev_get_by_name_rcu(sock_net(sk), saddr->spkt_device); + err = -ENODEV; + if (dev == NULL) + goto out_unlock; + + err = -ENETDOWN; + if (!(dev->flags & IFF_UP)) + goto out_unlock; + + /* + * You may not queue a frame bigger than the mtu. This is the lowest level + * raw protocol and you must do your own fragmentation at this level. + */ + + if (unlikely(sock_flag(sk, SOCK_NOFCS))) { + if (!netif_supports_nofcs(dev)) { + err = -EPROTONOSUPPORT; + goto out_unlock; + } + extra_len = 4; /* We're doing our own CRC */ + } + + err = -EMSGSIZE; + if (len > dev->mtu + dev->hard_header_len + VLAN_HLEN + extra_len) + goto out_unlock; + + if (!skb) { + size_t reserved = LL_RESERVED_SPACE(dev); + int tlen = dev->needed_tailroom; + unsigned int hhlen = dev->header_ops ? dev->hard_header_len : 0; + + rcu_read_unlock(); + skb = sock_wmalloc(sk, len + reserved + tlen, 0, GFP_KERNEL); + if (skb == NULL) + return -ENOBUFS; + /* FIXME: Save some space for broken drivers that write a hard + * header at transmission time by themselves. PPP is the notable + * one here. This should really be fixed at the driver level. + */ + skb_reserve(skb, reserved); + skb_reset_network_header(skb); + + /* Try to align data part correctly */ + if (hhlen) { + skb->data -= hhlen; + skb->tail -= hhlen; + if (len < hhlen) + skb_reset_network_header(skb); + } + err = memcpy_from_msg(skb_put(skb, len), msg, len); + if (err) + goto out_free; + goto retry; + } + + if (!dev_validate_header(dev, skb->data, len) || !skb->len) { + err = -EINVAL; + goto out_unlock; + } + if (len > (dev->mtu + dev->hard_header_len + extra_len) && + !packet_extra_vlan_len_allowed(dev, skb)) { + err = -EMSGSIZE; + goto out_unlock; + } + + sockcm_init(&sockc, sk); + if (msg->msg_controllen) { + err = sock_cmsg_send(sk, msg, &sockc); + if (unlikely(err)) + goto out_unlock; + } + + skb->protocol = proto; + skb->dev = dev; + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = READ_ONCE(sk->sk_mark); + skb->tstamp = sockc.transmit_time; + + skb_setup_tx_timestamp(skb, sockc.tsflags); + + if (unlikely(extra_len == 4)) + skb->no_fcs = 1; + + packet_parse_headers(skb, sock); + + dev_queue_xmit(skb); + rcu_read_unlock(); + return len; + +out_unlock: + rcu_read_unlock(); +out_free: + kfree_skb(skb); + return err; +} + +static unsigned int run_filter(struct sk_buff *skb, + const struct sock *sk, + unsigned int res) +{ + struct sk_filter *filter; + + rcu_read_lock(); + filter = rcu_dereference(sk->sk_filter); + if (filter != NULL) + res = bpf_prog_run_clear_cb(filter->prog, skb); + rcu_read_unlock(); + + return res; +} + +static int packet_rcv_vnet(struct msghdr *msg, const struct sk_buff *skb, + size_t *len) +{ + struct virtio_net_hdr vnet_hdr; + + if (*len < sizeof(vnet_hdr)) + return -EINVAL; + *len -= sizeof(vnet_hdr); + + if (virtio_net_hdr_from_skb(skb, &vnet_hdr, vio_le(), true, 0)) + return -EINVAL; + + return memcpy_to_msg(msg, (void *)&vnet_hdr, sizeof(vnet_hdr)); +} + +/* + * This function makes lazy skb cloning in hope that most of packets + * are discarded by BPF. + * + * Note tricky part: we DO mangle shared skb! 
skb->data, skb->len + * and skb->cb are mangled. It works because (and until) packets + * falling here are owned by current CPU. Output packets are cloned + * by dev_queue_xmit_nit(), input packets are processed by net_bh + * sequentially, so that if we return skb to original state on exit, + * we will not harm anyone. + */ + +static int packet_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct sock *sk; + struct sockaddr_ll *sll; + struct packet_sock *po; + u8 *skb_head = skb->data; + int skb_len = skb->len; + unsigned int snaplen, res; + bool is_drop_n_account = false; + + if (skb->pkt_type == PACKET_LOOPBACK) + goto drop; + + sk = pt->af_packet_priv; + po = pkt_sk(sk); + + if (!net_eq(dev_net(dev), sock_net(sk))) + goto drop; + + skb->dev = dev; + + if (dev_has_header(dev)) { + /* The device has an explicit notion of ll header, + * exported to higher levels. + * + * Otherwise, the device hides details of its frame + * structure, so that corresponding packet head is + * never delivered to user. + */ + if (sk->sk_type != SOCK_DGRAM) + skb_push(skb, skb->data - skb_mac_header(skb)); + else if (skb->pkt_type == PACKET_OUTGOING) { + /* Special case: outgoing packets have ll header at head */ + skb_pull(skb, skb_network_offset(skb)); + } + } + + snaplen = skb->len; + + res = run_filter(skb, sk, snaplen); + if (!res) + goto drop_n_restore; + if (snaplen > res) + snaplen = res; + + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf) + goto drop_n_acct; + + if (skb_shared(skb)) { + struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC); + if (nskb == NULL) + goto drop_n_acct; + + if (skb_head != skb->data) { + skb->data = skb_head; + skb->len = skb_len; + } + consume_skb(skb); + skb = nskb; + } + + sock_skb_cb_check_size(sizeof(*PACKET_SKB_CB(skb)) + MAX_ADDR_LEN - 8); + + sll = &PACKET_SKB_CB(skb)->sa.ll; + sll->sll_hatype = dev->type; + sll->sll_pkttype = skb->pkt_type; + if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV))) + sll->sll_ifindex = orig_dev->ifindex; + else + sll->sll_ifindex = dev->ifindex; + + sll->sll_halen = dev_parse_header(skb, sll->sll_addr); + + /* sll->sll_family and sll->sll_protocol are set in packet_recvmsg(). + * Use their space for storing the original skb length. 
+ */ + PACKET_SKB_CB(skb)->sa.origlen = skb->len; + + if (pskb_trim(skb, snaplen)) + goto drop_n_acct; + + skb_set_owner_r(skb, sk); + skb->dev = NULL; + skb_dst_drop(skb); + + /* drop conntrack reference */ + nf_reset_ct(skb); + + spin_lock(&sk->sk_receive_queue.lock); + po->stats.stats1.tp_packets++; + sock_skb_set_dropcount(sk, skb); + skb_clear_delivery_time(skb); + __skb_queue_tail(&sk->sk_receive_queue, skb); + spin_unlock(&sk->sk_receive_queue.lock); + sk->sk_data_ready(sk); + return 0; + +drop_n_acct: + is_drop_n_account = true; + atomic_inc(&po->tp_drops); + atomic_inc(&sk->sk_drops); + +drop_n_restore: + if (skb_head != skb->data && skb_shared(skb)) { + skb->data = skb_head; + skb->len = skb_len; + } +drop: + if (!is_drop_n_account) + consume_skb(skb); + else + kfree_skb(skb); + return 0; +} + +static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct sock *sk; + struct packet_sock *po; + struct sockaddr_ll *sll; + union tpacket_uhdr h; + u8 *skb_head = skb->data; + int skb_len = skb->len; + unsigned int snaplen, res; + unsigned long status = TP_STATUS_USER; + unsigned short macoff, hdrlen; + unsigned int netoff; + struct sk_buff *copy_skb = NULL; + struct timespec64 ts; + __u32 ts_status; + bool is_drop_n_account = false; + unsigned int slot_id = 0; + bool do_vnet = false; + + /* struct tpacket{2,3}_hdr is aligned to a multiple of TPACKET_ALIGNMENT. + * We may add members to them until current aligned size without forcing + * userspace to call getsockopt(..., PACKET_HDRLEN, ...). + */ + BUILD_BUG_ON(TPACKET_ALIGN(sizeof(*h.h2)) != 32); + BUILD_BUG_ON(TPACKET_ALIGN(sizeof(*h.h3)) != 48); + + if (skb->pkt_type == PACKET_LOOPBACK) + goto drop; + + sk = pt->af_packet_priv; + po = pkt_sk(sk); + + if (!net_eq(dev_net(dev), sock_net(sk))) + goto drop; + + if (dev_has_header(dev)) { + if (sk->sk_type != SOCK_DGRAM) + skb_push(skb, skb->data - skb_mac_header(skb)); + else if (skb->pkt_type == PACKET_OUTGOING) { + /* Special case: outgoing packets have ll header at head */ + skb_pull(skb, skb_network_offset(skb)); + } + } + + snaplen = skb->len; + + res = run_filter(skb, sk, snaplen); + if (!res) + goto drop_n_restore; + + /* If we are flooded, just give up */ + if (__packet_rcv_has_room(po, skb) == ROOM_NONE) { + atomic_inc(&po->tp_drops); + goto drop_n_restore; + } + + if (skb->ip_summed == CHECKSUM_PARTIAL) + status |= TP_STATUS_CSUMNOTREADY; + else if (skb->pkt_type != PACKET_OUTGOING && + skb_csum_unnecessary(skb)) + status |= TP_STATUS_CSUM_VALID; + + if (snaplen > res) + snaplen = res; + + if (sk->sk_type == SOCK_DGRAM) { + macoff = netoff = TPACKET_ALIGN(po->tp_hdrlen) + 16 + + po->tp_reserve; + } else { + unsigned int maclen = skb_network_offset(skb); + netoff = TPACKET_ALIGN(po->tp_hdrlen + + (maclen < 16 ? 
16 : maclen)) + + po->tp_reserve; + if (po->has_vnet_hdr) { + netoff += sizeof(struct virtio_net_hdr); + do_vnet = true; + } + macoff = netoff - maclen; + } + if (netoff > USHRT_MAX) { + atomic_inc(&po->tp_drops); + goto drop_n_restore; + } + if (po->tp_version <= TPACKET_V2) { + if (macoff + snaplen > po->rx_ring.frame_size) { + if (po->copy_thresh && + atomic_read(&sk->sk_rmem_alloc) < sk->sk_rcvbuf) { + if (skb_shared(skb)) { + copy_skb = skb_clone(skb, GFP_ATOMIC); + } else { + copy_skb = skb_get(skb); + skb_head = skb->data; + } + if (copy_skb) { + memset(&PACKET_SKB_CB(copy_skb)->sa.ll, 0, + sizeof(PACKET_SKB_CB(copy_skb)->sa.ll)); + skb_set_owner_r(copy_skb, sk); + } + } + snaplen = po->rx_ring.frame_size - macoff; + if ((int)snaplen < 0) { + snaplen = 0; + do_vnet = false; + } + } + } else if (unlikely(macoff + snaplen > + GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len)) { + u32 nval; + + nval = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len - macoff; + pr_err_once("tpacket_rcv: packet too big, clamped from %u to %u. macoff=%u\n", + snaplen, nval, macoff); + snaplen = nval; + if (unlikely((int)snaplen < 0)) { + snaplen = 0; + macoff = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len; + do_vnet = false; + } + } + spin_lock(&sk->sk_receive_queue.lock); + h.raw = packet_current_rx_frame(po, skb, + TP_STATUS_KERNEL, (macoff+snaplen)); + if (!h.raw) + goto drop_n_account; + + if (po->tp_version <= TPACKET_V2) { + slot_id = po->rx_ring.head; + if (test_bit(slot_id, po->rx_ring.rx_owner_map)) + goto drop_n_account; + __set_bit(slot_id, po->rx_ring.rx_owner_map); + } + + if (do_vnet && + virtio_net_hdr_from_skb(skb, h.raw + macoff - + sizeof(struct virtio_net_hdr), + vio_le(), true, 0)) { + if (po->tp_version == TPACKET_V3) + prb_clear_blk_fill_status(&po->rx_ring); + goto drop_n_account; + } + + if (po->tp_version <= TPACKET_V2) { + packet_increment_rx_head(po, &po->rx_ring); + /* + * LOSING will be reported till you read the stats, + * because it's COR - Clear On Read. + * Anyways, moving it for V1/V2 only as V3 doesn't need this + * at packet level. + */ + if (atomic_read(&po->tp_drops)) + status |= TP_STATUS_LOSING; + } + + po->stats.stats1.tp_packets++; + if (copy_skb) { + status |= TP_STATUS_COPY; + skb_clear_delivery_time(copy_skb); + __skb_queue_tail(&sk->sk_receive_queue, copy_skb); + } + spin_unlock(&sk->sk_receive_queue.lock); + + skb_copy_bits(skb, 0, h.raw + macoff, snaplen); + + /* Always timestamp; prefer an existing software timestamp taken + * closer to the time of capture. 
+ */ + ts_status = tpacket_get_timestamp(skb, &ts, + po->tp_tstamp | SOF_TIMESTAMPING_SOFTWARE); + if (!ts_status) + ktime_get_real_ts64(&ts); + + status |= ts_status; + + switch (po->tp_version) { + case TPACKET_V1: + h.h1->tp_len = skb->len; + h.h1->tp_snaplen = snaplen; + h.h1->tp_mac = macoff; + h.h1->tp_net = netoff; + h.h1->tp_sec = ts.tv_sec; + h.h1->tp_usec = ts.tv_nsec / NSEC_PER_USEC; + hdrlen = sizeof(*h.h1); + break; + case TPACKET_V2: + h.h2->tp_len = skb->len; + h.h2->tp_snaplen = snaplen; + h.h2->tp_mac = macoff; + h.h2->tp_net = netoff; + h.h2->tp_sec = ts.tv_sec; + h.h2->tp_nsec = ts.tv_nsec; + if (skb_vlan_tag_present(skb)) { + h.h2->tp_vlan_tci = skb_vlan_tag_get(skb); + h.h2->tp_vlan_tpid = ntohs(skb->vlan_proto); + status |= TP_STATUS_VLAN_VALID | TP_STATUS_VLAN_TPID_VALID; + } else { + h.h2->tp_vlan_tci = 0; + h.h2->tp_vlan_tpid = 0; + } + memset(h.h2->tp_padding, 0, sizeof(h.h2->tp_padding)); + hdrlen = sizeof(*h.h2); + break; + case TPACKET_V3: + /* tp_nxt_offset,vlan are already populated above. + * So DONT clear those fields here + */ + h.h3->tp_status |= status; + h.h3->tp_len = skb->len; + h.h3->tp_snaplen = snaplen; + h.h3->tp_mac = macoff; + h.h3->tp_net = netoff; + h.h3->tp_sec = ts.tv_sec; + h.h3->tp_nsec = ts.tv_nsec; + memset(h.h3->tp_padding, 0, sizeof(h.h3->tp_padding)); + hdrlen = sizeof(*h.h3); + break; + default: + BUG(); + } + + sll = h.raw + TPACKET_ALIGN(hdrlen); + sll->sll_halen = dev_parse_header(skb, sll->sll_addr); + sll->sll_family = AF_PACKET; + sll->sll_hatype = dev->type; + sll->sll_protocol = skb->protocol; + sll->sll_pkttype = skb->pkt_type; + if (unlikely(packet_sock_flag(po, PACKET_SOCK_ORIGDEV))) + sll->sll_ifindex = orig_dev->ifindex; + else + sll->sll_ifindex = dev->ifindex; + + smp_mb(); + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE == 1 + if (po->tp_version <= TPACKET_V2) { + u8 *start, *end; + + end = (u8 *) PAGE_ALIGN((unsigned long) h.raw + + macoff + snaplen); + + for (start = h.raw; start < end; start += PAGE_SIZE) + flush_dcache_page(pgv_to_page(start)); + } + smp_wmb(); +#endif + + if (po->tp_version <= TPACKET_V2) { + spin_lock(&sk->sk_receive_queue.lock); + __packet_set_status(po, h.raw, status); + __clear_bit(slot_id, po->rx_ring.rx_owner_map); + spin_unlock(&sk->sk_receive_queue.lock); + sk->sk_data_ready(sk); + } else if (po->tp_version == TPACKET_V3) { + prb_clear_blk_fill_status(&po->rx_ring); + } + +drop_n_restore: + if (skb_head != skb->data && skb_shared(skb)) { + skb->data = skb_head; + skb->len = skb_len; + } +drop: + if (!is_drop_n_account) + consume_skb(skb); + else + kfree_skb(skb); + return 0; + +drop_n_account: + spin_unlock(&sk->sk_receive_queue.lock); + atomic_inc(&po->tp_drops); + is_drop_n_account = true; + + sk->sk_data_ready(sk); + kfree_skb(copy_skb); + goto drop_n_restore; +} + +static void tpacket_destruct_skb(struct sk_buff *skb) +{ + struct packet_sock *po = pkt_sk(skb->sk); + + if (likely(po->tx_ring.pg_vec)) { + void *ph; + __u32 ts; + + ph = skb_zcopy_get_nouarg(skb); + packet_dec_pending(&po->tx_ring); + + ts = __packet_set_timestamp(po, ph, skb); + __packet_set_status(po, ph, TP_STATUS_AVAILABLE | ts); + + if (!packet_read_pending(&po->tx_ring)) + complete(&po->skb_completion); + } + + sock_wfree(skb); +} + +static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len) +{ + if ((vnet_hdr->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) && + (__virtio16_to_cpu(vio_le(), vnet_hdr->csum_start) + + __virtio16_to_cpu(vio_le(), vnet_hdr->csum_offset) + 2 > + __virtio16_to_cpu(vio_le(), 
vnet_hdr->hdr_len))) + vnet_hdr->hdr_len = __cpu_to_virtio16(vio_le(), + __virtio16_to_cpu(vio_le(), vnet_hdr->csum_start) + + __virtio16_to_cpu(vio_le(), vnet_hdr->csum_offset) + 2); + + if (__virtio16_to_cpu(vio_le(), vnet_hdr->hdr_len) > len) + return -EINVAL; + + return 0; +} + +static int packet_snd_vnet_parse(struct msghdr *msg, size_t *len, + struct virtio_net_hdr *vnet_hdr) +{ + if (*len < sizeof(*vnet_hdr)) + return -EINVAL; + *len -= sizeof(*vnet_hdr); + + if (!copy_from_iter_full(vnet_hdr, sizeof(*vnet_hdr), &msg->msg_iter)) + return -EFAULT; + + return __packet_snd_vnet_parse(vnet_hdr, *len); +} + +static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb, + void *frame, struct net_device *dev, void *data, int tp_len, + __be16 proto, unsigned char *addr, int hlen, int copylen, + const struct sockcm_cookie *sockc) +{ + union tpacket_uhdr ph; + int to_write, offset, len, nr_frags, len_max; + struct socket *sock = po->sk.sk_socket; + struct page *page; + int err; + + ph.raw = frame; + + skb->protocol = proto; + skb->dev = dev; + skb->priority = READ_ONCE(po->sk.sk_priority); + skb->mark = READ_ONCE(po->sk.sk_mark); + skb->tstamp = sockc->transmit_time; + skb_setup_tx_timestamp(skb, sockc->tsflags); + skb_zcopy_set_nouarg(skb, ph.raw); + + skb_reserve(skb, hlen); + skb_reset_network_header(skb); + + to_write = tp_len; + + if (sock->type == SOCK_DGRAM) { + err = dev_hard_header(skb, dev, ntohs(proto), addr, + NULL, tp_len); + if (unlikely(err < 0)) + return -EINVAL; + } else if (copylen) { + int hdrlen = min_t(int, copylen, tp_len); + + skb_push(skb, dev->hard_header_len); + skb_put(skb, copylen - dev->hard_header_len); + err = skb_store_bits(skb, 0, data, hdrlen); + if (unlikely(err)) + return err; + if (!dev_validate_header(dev, skb->data, hdrlen)) + return -EINVAL; + + data += hdrlen; + to_write -= hdrlen; + } + + offset = offset_in_page(data); + len_max = PAGE_SIZE - offset; + len = ((to_write > len_max) ? len_max : to_write); + + skb->data_len = to_write; + skb->len += to_write; + skb->truesize += to_write; + refcount_add(to_write, &po->sk.sk_wmem_alloc); + + while (likely(to_write)) { + nr_frags = skb_shinfo(skb)->nr_frags; + + if (unlikely(nr_frags >= MAX_SKB_FRAGS)) { + pr_err("Packet exceed the number of skb frags(%lu)\n", + MAX_SKB_FRAGS); + return -EFAULT; + } + + page = pgv_to_page(data); + data += len; + flush_dcache_page(page); + get_page(page); + skb_fill_page_desc(skb, nr_frags, page, offset, len); + to_write -= len; + offset = 0; + len_max = PAGE_SIZE; + len = ((to_write > len_max) ? 
len_max : to_write); + } + + packet_parse_headers(skb, sock); + + return tp_len; +} + +static int tpacket_parse_header(struct packet_sock *po, void *frame, + int size_max, void **data) +{ + union tpacket_uhdr ph; + int tp_len, off; + + ph.raw = frame; + + switch (po->tp_version) { + case TPACKET_V3: + if (ph.h3->tp_next_offset != 0) { + pr_warn_once("variable sized slot not supported"); + return -EINVAL; + } + tp_len = ph.h3->tp_len; + break; + case TPACKET_V2: + tp_len = ph.h2->tp_len; + break; + default: + tp_len = ph.h1->tp_len; + break; + } + if (unlikely(tp_len > size_max)) { + pr_err("packet size is too long (%d > %d)\n", tp_len, size_max); + return -EMSGSIZE; + } + + if (unlikely(po->tp_tx_has_off)) { + int off_min, off_max; + + off_min = po->tp_hdrlen - sizeof(struct sockaddr_ll); + off_max = po->tx_ring.frame_size - tp_len; + if (po->sk.sk_type == SOCK_DGRAM) { + switch (po->tp_version) { + case TPACKET_V3: + off = ph.h3->tp_net; + break; + case TPACKET_V2: + off = ph.h2->tp_net; + break; + default: + off = ph.h1->tp_net; + break; + } + } else { + switch (po->tp_version) { + case TPACKET_V3: + off = ph.h3->tp_mac; + break; + case TPACKET_V2: + off = ph.h2->tp_mac; + break; + default: + off = ph.h1->tp_mac; + break; + } + } + if (unlikely((off < off_min) || (off_max < off))) + return -EINVAL; + } else { + off = po->tp_hdrlen - sizeof(struct sockaddr_ll); + } + + *data = frame + off; + return tp_len; +} + +static int tpacket_snd(struct packet_sock *po, struct msghdr *msg) +{ + struct sk_buff *skb = NULL; + struct net_device *dev; + struct virtio_net_hdr *vnet_hdr = NULL; + struct sockcm_cookie sockc; + __be16 proto; + int err, reserve = 0; + void *ph; + DECLARE_SOCKADDR(struct sockaddr_ll *, saddr, msg->msg_name); + bool need_wait = !(msg->msg_flags & MSG_DONTWAIT); + unsigned char *addr = NULL; + int tp_len, size_max; + void *data; + int len_sum = 0; + int status = TP_STATUS_AVAILABLE; + int hlen, tlen, copylen = 0; + long timeo = 0; + + mutex_lock(&po->pg_vec_lock); + + /* packet_sendmsg() check on tx_ring.pg_vec was lockless, + * we need to confirm it under protection of pg_vec_lock. 
+ */ + if (unlikely(!po->tx_ring.pg_vec)) { + err = -EBUSY; + goto out; + } + if (likely(saddr == NULL)) { + dev = packet_cached_dev_get(po); + proto = READ_ONCE(po->num); + } else { + err = -EINVAL; + if (msg->msg_namelen < sizeof(struct sockaddr_ll)) + goto out; + if (msg->msg_namelen < (saddr->sll_halen + + offsetof(struct sockaddr_ll, + sll_addr))) + goto out; + proto = saddr->sll_protocol; + dev = dev_get_by_index(sock_net(&po->sk), saddr->sll_ifindex); + if (po->sk.sk_socket->type == SOCK_DGRAM) { + if (dev && msg->msg_namelen < dev->addr_len + + offsetof(struct sockaddr_ll, sll_addr)) + goto out_put; + addr = saddr->sll_addr; + } + } + + err = -ENXIO; + if (unlikely(dev == NULL)) + goto out; + err = -ENETDOWN; + if (unlikely(!(dev->flags & IFF_UP))) + goto out_put; + + sockcm_init(&sockc, &po->sk); + if (msg->msg_controllen) { + err = sock_cmsg_send(&po->sk, msg, &sockc); + if (unlikely(err)) + goto out_put; + } + + if (po->sk.sk_socket->type == SOCK_RAW) + reserve = dev->hard_header_len; + size_max = po->tx_ring.frame_size + - (po->tp_hdrlen - sizeof(struct sockaddr_ll)); + + if ((size_max > dev->mtu + reserve + VLAN_HLEN) && !po->has_vnet_hdr) + size_max = dev->mtu + reserve + VLAN_HLEN; + + reinit_completion(&po->skb_completion); + + do { + ph = packet_current_frame(po, &po->tx_ring, + TP_STATUS_SEND_REQUEST); + if (unlikely(ph == NULL)) { + if (need_wait && skb) { + timeo = sock_sndtimeo(&po->sk, msg->msg_flags & MSG_DONTWAIT); + timeo = wait_for_completion_interruptible_timeout(&po->skb_completion, timeo); + if (timeo <= 0) { + err = !timeo ? -ETIMEDOUT : -ERESTARTSYS; + goto out_put; + } + } + /* check for additional frames */ + continue; + } + + skb = NULL; + tp_len = tpacket_parse_header(po, ph, size_max, &data); + if (tp_len < 0) + goto tpacket_error; + + status = TP_STATUS_SEND_REQUEST; + hlen = LL_RESERVED_SPACE(dev); + tlen = dev->needed_tailroom; + if (po->has_vnet_hdr) { + vnet_hdr = data; + data += sizeof(*vnet_hdr); + tp_len -= sizeof(*vnet_hdr); + if (tp_len < 0 || + __packet_snd_vnet_parse(vnet_hdr, tp_len)) { + tp_len = -EINVAL; + goto tpacket_error; + } + copylen = __virtio16_to_cpu(vio_le(), + vnet_hdr->hdr_len); + } + copylen = max_t(int, copylen, dev->hard_header_len); + skb = sock_alloc_send_skb(&po->sk, + hlen + tlen + sizeof(struct sockaddr_ll) + + (copylen - dev->hard_header_len), + !need_wait, &err); + + if (unlikely(skb == NULL)) { + /* we assume the socket was initially writeable ... 
*/ + if (likely(len_sum > 0)) + err = len_sum; + goto out_status; + } + tp_len = tpacket_fill_skb(po, skb, ph, dev, data, tp_len, proto, + addr, hlen, copylen, &sockc); + if (likely(tp_len >= 0) && + tp_len > dev->mtu + reserve && + !po->has_vnet_hdr && + !packet_extra_vlan_len_allowed(dev, skb)) + tp_len = -EMSGSIZE; + + if (unlikely(tp_len < 0)) { +tpacket_error: + if (po->tp_loss) { + __packet_set_status(po, ph, + TP_STATUS_AVAILABLE); + packet_increment_head(&po->tx_ring); + kfree_skb(skb); + continue; + } else { + status = TP_STATUS_WRONG_FORMAT; + err = tp_len; + goto out_status; + } + } + + if (po->has_vnet_hdr) { + if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) { + tp_len = -EINVAL; + goto tpacket_error; + } + virtio_net_hdr_set_proto(skb, vnet_hdr); + } + + skb->destructor = tpacket_destruct_skb; + __packet_set_status(po, ph, TP_STATUS_SENDING); + packet_inc_pending(&po->tx_ring); + + status = TP_STATUS_SEND_REQUEST; + /* Paired with WRITE_ONCE() in packet_setsockopt() */ + err = READ_ONCE(po->xmit)(skb); + if (unlikely(err != 0)) { + if (err > 0) + err = net_xmit_errno(err); + if (err && __packet_get_status(po, ph) == + TP_STATUS_AVAILABLE) { + /* skb was destructed already */ + skb = NULL; + goto out_status; + } + /* + * skb was dropped but not destructed yet; + * let's treat it like congestion or err < 0 + */ + err = 0; + } + packet_increment_head(&po->tx_ring); + len_sum += tp_len; + } while (likely((ph != NULL) || + /* Note: packet_read_pending() might be slow if we have + * to call it as it's per_cpu variable, but in fast-path + * we already short-circuit the loop with the first + * condition, and luckily don't have to go that path + * anyway. + */ + (need_wait && packet_read_pending(&po->tx_ring)))); + + err = len_sum; + goto out_put; + +out_status: + __packet_set_status(po, ph, status); + kfree_skb(skb); +out_put: + dev_put(dev); +out: + mutex_unlock(&po->pg_vec_lock); + return err; +} + +static struct sk_buff *packet_alloc_skb(struct sock *sk, size_t prepad, + size_t reserve, size_t len, + size_t linear, int noblock, + int *err) +{ + struct sk_buff *skb; + + /* Under a page? Don't bother with paged skb. */ + if (prepad + len < PAGE_SIZE || !linear) + linear = len; + + skb = sock_alloc_send_pskb(sk, prepad + linear, len - linear, noblock, + err, 0); + if (!skb) + return NULL; + + skb_reserve(skb, reserve); + skb_put(skb, linear); + skb->data_len = len - linear; + skb->len += len - linear; + + return skb; +} + +static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + DECLARE_SOCKADDR(struct sockaddr_ll *, saddr, msg->msg_name); + struct sk_buff *skb; + struct net_device *dev; + __be16 proto; + unsigned char *addr = NULL; + int err, reserve = 0; + struct sockcm_cookie sockc; + struct virtio_net_hdr vnet_hdr = { 0 }; + int offset = 0; + struct packet_sock *po = pkt_sk(sk); + bool has_vnet_hdr = false; + int hlen, tlen, linear; + int extra_len = 0; + + /* + * Get and verify the address. 
+ */ + + if (likely(saddr == NULL)) { + dev = packet_cached_dev_get(po); + proto = READ_ONCE(po->num); + } else { + err = -EINVAL; + if (msg->msg_namelen < sizeof(struct sockaddr_ll)) + goto out; + if (msg->msg_namelen < (saddr->sll_halen + offsetof(struct sockaddr_ll, sll_addr))) + goto out; + proto = saddr->sll_protocol; + dev = dev_get_by_index(sock_net(sk), saddr->sll_ifindex); + if (sock->type == SOCK_DGRAM) { + if (dev && msg->msg_namelen < dev->addr_len + + offsetof(struct sockaddr_ll, sll_addr)) + goto out_unlock; + addr = saddr->sll_addr; + } + } + + err = -ENXIO; + if (unlikely(dev == NULL)) + goto out_unlock; + err = -ENETDOWN; + if (unlikely(!(dev->flags & IFF_UP))) + goto out_unlock; + + sockcm_init(&sockc, sk); + sockc.mark = READ_ONCE(sk->sk_mark); + if (msg->msg_controllen) { + err = sock_cmsg_send(sk, msg, &sockc); + if (unlikely(err)) + goto out_unlock; + } + + if (sock->type == SOCK_RAW) + reserve = dev->hard_header_len; + if (po->has_vnet_hdr) { + err = packet_snd_vnet_parse(msg, &len, &vnet_hdr); + if (err) + goto out_unlock; + has_vnet_hdr = true; + } + + if (unlikely(sock_flag(sk, SOCK_NOFCS))) { + if (!netif_supports_nofcs(dev)) { + err = -EPROTONOSUPPORT; + goto out_unlock; + } + extra_len = 4; /* We're doing our own CRC */ + } + + err = -EMSGSIZE; + if (!vnet_hdr.gso_type && + (len > dev->mtu + reserve + VLAN_HLEN + extra_len)) + goto out_unlock; + + err = -ENOBUFS; + hlen = LL_RESERVED_SPACE(dev); + tlen = dev->needed_tailroom; + linear = __virtio16_to_cpu(vio_le(), vnet_hdr.hdr_len); + linear = max(linear, min_t(int, len, dev->hard_header_len)); + skb = packet_alloc_skb(sk, hlen + tlen, hlen, len, linear, + msg->msg_flags & MSG_DONTWAIT, &err); + if (skb == NULL) + goto out_unlock; + + skb_reset_network_header(skb); + + err = -EINVAL; + if (sock->type == SOCK_DGRAM) { + offset = dev_hard_header(skb, dev, ntohs(proto), addr, NULL, len); + if (unlikely(offset < 0)) + goto out_free; + } else if (reserve) { + skb_reserve(skb, -reserve); + if (len < reserve + sizeof(struct ipv6hdr) && + dev->min_header_len != dev->hard_header_len) + skb_reset_network_header(skb); + } + + /* Returns -EFAULT on error */ + err = skb_copy_datagram_from_iter(skb, offset, &msg->msg_iter, len); + if (err) + goto out_free; + + if ((sock->type == SOCK_RAW && + !dev_validate_header(dev, skb->data, len)) || !skb->len) { + err = -EINVAL; + goto out_free; + } + + skb_setup_tx_timestamp(skb, sockc.tsflags); + + if (!vnet_hdr.gso_type && (len > dev->mtu + reserve + extra_len) && + !packet_extra_vlan_len_allowed(dev, skb)) { + err = -EMSGSIZE; + goto out_free; + } + + skb->protocol = proto; + skb->dev = dev; + skb->priority = READ_ONCE(sk->sk_priority); + skb->mark = sockc.mark; + skb->tstamp = sockc.transmit_time; + + if (unlikely(extra_len == 4)) + skb->no_fcs = 1; + + packet_parse_headers(skb, sock); + + if (has_vnet_hdr) { + err = virtio_net_hdr_to_skb(skb, &vnet_hdr, vio_le()); + if (err) + goto out_free; + len += sizeof(vnet_hdr); + virtio_net_hdr_set_proto(skb, &vnet_hdr); + } + + /* Paired with WRITE_ONCE() in packet_setsockopt() */ + err = READ_ONCE(po->xmit)(skb); + if (unlikely(err != 0)) { + if (err > 0) + err = net_xmit_errno(err); + if (err) + goto out_unlock; + } + + dev_put(dev); + + return len; + +out_free: + kfree_skb(skb); +out_unlock: + dev_put(dev); +out: + return err; +} + +static int packet_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + + /* Reading tx_ring.pg_vec without holding 
pg_vec_lock is racy. + * tpacket_snd() will redo the check safely. + */ + if (data_race(po->tx_ring.pg_vec)) + return tpacket_snd(po, msg); + + return packet_snd(sock, msg, len); +} + +/* + * Close a PACKET socket. This is fairly simple. We immediately go + * to 'closed' state and remove our protocol entry in the device list. + */ + +static int packet_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct packet_sock *po; + struct packet_fanout *f; + struct net *net; + union tpacket_req_u req_u; + + if (!sk) + return 0; + + net = sock_net(sk); + po = pkt_sk(sk); + + mutex_lock(&net->packet.sklist_lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&net->packet.sklist_lock); + + sock_prot_inuse_add(net, sk->sk_prot, -1); + + spin_lock(&po->bind_lock); + unregister_prot_hook(sk, false); + packet_cached_dev_reset(po); + + if (po->prot_hook.dev) { + netdev_put(po->prot_hook.dev, &po->prot_hook.dev_tracker); + po->prot_hook.dev = NULL; + } + spin_unlock(&po->bind_lock); + + packet_flush_mclist(sk); + + lock_sock(sk); + if (po->rx_ring.pg_vec) { + memset(&req_u, 0, sizeof(req_u)); + packet_set_ring(sk, &req_u, 1, 0); + } + + if (po->tx_ring.pg_vec) { + memset(&req_u, 0, sizeof(req_u)); + packet_set_ring(sk, &req_u, 1, 1); + } + release_sock(sk); + + f = fanout_release(sk); + + synchronize_net(); + + kfree(po->rollover); + if (f) { + fanout_release_data(f); + kvfree(f); + } + /* + * Now the socket is dead. No more input will appear. + */ + sock_orphan(sk); + sock->sk = NULL; + + /* Purge queues */ + + skb_queue_purge(&sk->sk_receive_queue); + packet_free_pending(po); + sk_refcnt_debug_release(sk); + + sock_put(sk); + return 0; +} + +/* + * Attach a packet hook. + */ + +static int packet_do_bind(struct sock *sk, const char *name, int ifindex, + __be16 proto) +{ + struct packet_sock *po = pkt_sk(sk); + struct net_device *dev = NULL; + bool unlisted = false; + bool need_rehook; + int ret = 0; + + lock_sock(sk); + spin_lock(&po->bind_lock); + if (!proto) + proto = po->num; + + rcu_read_lock(); + + if (po->fanout) { + ret = -EINVAL; + goto out_unlock; + } + + if (name) { + dev = dev_get_by_name_rcu(sock_net(sk), name); + if (!dev) { + ret = -ENODEV; + goto out_unlock; + } + } else if (ifindex) { + dev = dev_get_by_index_rcu(sock_net(sk), ifindex); + if (!dev) { + ret = -ENODEV; + goto out_unlock; + } + } + + need_rehook = po->prot_hook.type != proto || po->prot_hook.dev != dev; + + if (need_rehook) { + dev_hold(dev); + if (po->running) { + rcu_read_unlock(); + /* prevents packet_notifier() from calling + * register_prot_hook() + */ + WRITE_ONCE(po->num, 0); + __unregister_prot_hook(sk, true); + rcu_read_lock(); + if (dev) + unlisted = !dev_get_by_index_rcu(sock_net(sk), + dev->ifindex); + } + + BUG_ON(po->running); + WRITE_ONCE(po->num, proto); + po->prot_hook.type = proto; + + netdev_put(po->prot_hook.dev, &po->prot_hook.dev_tracker); + + if (unlikely(unlisted)) { + po->prot_hook.dev = NULL; + WRITE_ONCE(po->ifindex, -1); + packet_cached_dev_reset(po); + } else { + netdev_hold(dev, &po->prot_hook.dev_tracker, + GFP_ATOMIC); + po->prot_hook.dev = dev; + WRITE_ONCE(po->ifindex, dev ? 
dev->ifindex : 0); + packet_cached_dev_assign(po, dev); + } + dev_put(dev); + } + + if (proto == 0 || !need_rehook) + goto out_unlock; + + if (!unlisted && (!dev || (dev->flags & IFF_UP))) { + register_prot_hook(sk); + } else { + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + +out_unlock: + rcu_read_unlock(); + spin_unlock(&po->bind_lock); + release_sock(sk); + return ret; +} + +/* + * Bind a packet socket to a device + */ + +static int packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr, + int addr_len) +{ + struct sock *sk = sock->sk; + char name[sizeof(uaddr->sa_data) + 1]; + + /* + * Check legality + */ + + if (addr_len != sizeof(struct sockaddr)) + return -EINVAL; + /* uaddr->sa_data comes from the userspace, it's not guaranteed to be + * zero-terminated. + */ + memcpy(name, uaddr->sa_data, sizeof(uaddr->sa_data)); + name[sizeof(uaddr->sa_data)] = 0; + + return packet_do_bind(sk, name, 0, 0); +} + +static int packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sockaddr_ll *sll = (struct sockaddr_ll *)uaddr; + struct sock *sk = sock->sk; + + /* + * Check legality + */ + + if (addr_len < sizeof(struct sockaddr_ll)) + return -EINVAL; + if (sll->sll_family != AF_PACKET) + return -EINVAL; + + return packet_do_bind(sk, NULL, sll->sll_ifindex, sll->sll_protocol); +} + +static struct proto packet_proto = { + .name = "PACKET", + .owner = THIS_MODULE, + .obj_size = sizeof(struct packet_sock), +}; + +/* + * Create a packet of type SOCK_PACKET. + */ + +static int packet_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct packet_sock *po; + __be16 proto = (__force __be16)protocol; /* weird, but documented */ + int err; + + if (!ns_capable(net->user_ns, CAP_NET_RAW)) + return -EPERM; + if (sock->type != SOCK_DGRAM && sock->type != SOCK_RAW && + sock->type != SOCK_PACKET) + return -ESOCKTNOSUPPORT; + + sock->state = SS_UNCONNECTED; + + err = -ENOBUFS; + sk = sk_alloc(net, PF_PACKET, GFP_KERNEL, &packet_proto, kern); + if (sk == NULL) + goto out; + + sock->ops = &packet_ops; + if (sock->type == SOCK_PACKET) + sock->ops = &packet_ops_spkt; + + sock_init_data(sock, sk); + + po = pkt_sk(sk); + init_completion(&po->skb_completion); + sk->sk_family = PF_PACKET; + po->num = proto; + po->xmit = dev_queue_xmit; + + err = packet_alloc_pending(po); + if (err) + goto out2; + + packet_cached_dev_reset(po); + + sk->sk_destruct = packet_sock_destruct; + sk_refcnt_debug_inc(sk); + + /* + * Attach a protocol block + */ + + spin_lock_init(&po->bind_lock); + mutex_init(&po->pg_vec_lock); + po->rollover = NULL; + po->prot_hook.func = packet_rcv; + + if (sock->type == SOCK_PACKET) + po->prot_hook.func = packet_rcv_spkt; + + po->prot_hook.af_packet_priv = sk; + po->prot_hook.af_packet_net = sock_net(sk); + + if (proto) { + po->prot_hook.type = proto; + __register_prot_hook(sk); + } + + mutex_lock(&net->packet.sklist_lock); + sk_add_node_tail_rcu(sk, &net->packet.sklist); + mutex_unlock(&net->packet.sklist_lock); + + sock_prot_inuse_add(net, &packet_proto, 1); + + return 0; +out2: + sk_free(sk); +out: + return err; +} + +/* + * Pull a packet from our receive queue and hand it to the user. + * If necessary we block. 
+ */ + +static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk = sock->sk; + struct sk_buff *skb; + int copied, err; + int vnet_hdr_len = 0; + unsigned int origlen = 0; + + err = -EINVAL; + if (flags & ~(MSG_PEEK|MSG_DONTWAIT|MSG_TRUNC|MSG_CMSG_COMPAT|MSG_ERRQUEUE)) + goto out; + +#if 0 + /* What error should we return now? EUNATTACH? */ + if (pkt_sk(sk)->ifindex < 0) + return -ENODEV; +#endif + + if (flags & MSG_ERRQUEUE) { + err = sock_recv_errqueue(sk, msg, len, + SOL_PACKET, PACKET_TX_TIMESTAMP); + goto out; + } + + /* + * Call the generic datagram receiver. This handles all sorts + * of horrible races and re-entrancy so we can forget about it + * in the protocol layers. + * + * Now it will return ENETDOWN, if device have just gone down, + * but then it will block. + */ + + skb = skb_recv_datagram(sk, flags, &err); + + /* + * An error occurred so return it. Because skb_recv_datagram() + * handles the blocking we don't see and worry about blocking + * retries. + */ + + if (skb == NULL) + goto out; + + packet_rcv_try_clear_pressure(pkt_sk(sk)); + + if (pkt_sk(sk)->has_vnet_hdr) { + err = packet_rcv_vnet(msg, skb, &len); + if (err) + goto out_free; + vnet_hdr_len = sizeof(struct virtio_net_hdr); + } + + /* You lose any data beyond the buffer you gave. If it worries + * a user program they can ask the device for its MTU + * anyway. + */ + copied = skb->len; + if (copied > len) { + copied = len; + msg->msg_flags |= MSG_TRUNC; + } + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (err) + goto out_free; + + if (sock->type != SOCK_PACKET) { + struct sockaddr_ll *sll = &PACKET_SKB_CB(skb)->sa.ll; + + /* Original length was stored in sockaddr_ll fields */ + origlen = PACKET_SKB_CB(skb)->sa.origlen; + sll->sll_family = AF_PACKET; + sll->sll_protocol = skb->protocol; + } + + sock_recv_cmsgs(msg, sk, skb); + + if (msg->msg_name) { + const size_t max_len = min(sizeof(skb->cb), + sizeof(struct sockaddr_storage)); + int copy_len; + + /* If the address length field is there to be filled + * in, we fill it in now. 
+ */ + if (sock->type == SOCK_PACKET) { + __sockaddr_check_size(sizeof(struct sockaddr_pkt)); + msg->msg_namelen = sizeof(struct sockaddr_pkt); + copy_len = msg->msg_namelen; + } else { + struct sockaddr_ll *sll = &PACKET_SKB_CB(skb)->sa.ll; + + msg->msg_namelen = sll->sll_halen + + offsetof(struct sockaddr_ll, sll_addr); + copy_len = msg->msg_namelen; + if (msg->msg_namelen < sizeof(struct sockaddr_ll)) { + memset(msg->msg_name + + offsetof(struct sockaddr_ll, sll_addr), + 0, sizeof(sll->sll_addr)); + msg->msg_namelen = sizeof(struct sockaddr_ll); + } + } + if (WARN_ON_ONCE(copy_len > max_len)) { + copy_len = max_len; + msg->msg_namelen = copy_len; + } + memcpy(msg->msg_name, &PACKET_SKB_CB(skb)->sa, copy_len); + } + + if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_AUXDATA)) { + struct tpacket_auxdata aux; + + aux.tp_status = TP_STATUS_USER; + if (skb->ip_summed == CHECKSUM_PARTIAL) + aux.tp_status |= TP_STATUS_CSUMNOTREADY; + else if (skb->pkt_type != PACKET_OUTGOING && + skb_csum_unnecessary(skb)) + aux.tp_status |= TP_STATUS_CSUM_VALID; + + aux.tp_len = origlen; + aux.tp_snaplen = skb->len; + aux.tp_mac = 0; + aux.tp_net = skb_network_offset(skb); + if (skb_vlan_tag_present(skb)) { + aux.tp_vlan_tci = skb_vlan_tag_get(skb); + aux.tp_vlan_tpid = ntohs(skb->vlan_proto); + aux.tp_status |= TP_STATUS_VLAN_VALID | TP_STATUS_VLAN_TPID_VALID; + } else { + aux.tp_vlan_tci = 0; + aux.tp_vlan_tpid = 0; + } + put_cmsg(msg, SOL_PACKET, PACKET_AUXDATA, sizeof(aux), &aux); + } + + /* + * Free or return the buffer as appropriate. Again this + * hides all the races and re-entrancy issues from us. + */ + err = vnet_hdr_len + ((flags&MSG_TRUNC) ? skb->len : copied); + +out_free: + skb_free_datagram(sk, skb); +out: + return err; +} + +static int packet_getname_spkt(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct net_device *dev; + struct sock *sk = sock->sk; + + if (peer) + return -EOPNOTSUPP; + + uaddr->sa_family = AF_PACKET; + memset(uaddr->sa_data, 0, sizeof(uaddr->sa_data)); + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), READ_ONCE(pkt_sk(sk)->ifindex)); + if (dev) + strscpy(uaddr->sa_data, dev->name, sizeof(uaddr->sa_data)); + rcu_read_unlock(); + + return sizeof(*uaddr); +} + +static int packet_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct net_device *dev; + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_ll *, sll, uaddr); + int ifindex; + + if (peer) + return -EOPNOTSUPP; + + ifindex = READ_ONCE(po->ifindex); + sll->sll_family = AF_PACKET; + sll->sll_ifindex = ifindex; + sll->sll_protocol = READ_ONCE(po->num); + sll->sll_pkttype = 0; + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(sk), ifindex); + if (dev) { + sll->sll_hatype = dev->type; + sll->sll_halen = dev->addr_len; + memcpy(sll->sll_addr, dev->dev_addr, dev->addr_len); + } else { + sll->sll_hatype = 0; /* Bad: we have no ARPHRD_UNSPEC */ + sll->sll_halen = 0; + } + rcu_read_unlock(); + + return offsetof(struct sockaddr_ll, sll_addr) + sll->sll_halen; +} + +static int packet_dev_mc(struct net_device *dev, struct packet_mclist *i, + int what) +{ + switch (i->type) { + case PACKET_MR_MULTICAST: + if (i->alen != dev->addr_len) + return -EINVAL; + if (what > 0) + return dev_mc_add(dev, i->addr); + else + return dev_mc_del(dev, i->addr); + break; + case PACKET_MR_PROMISC: + return dev_set_promiscuity(dev, what); + case PACKET_MR_ALLMULTI: + return dev_set_allmulti(dev, what); + case PACKET_MR_UNICAST: + if (i->alen != 
dev->addr_len) + return -EINVAL; + if (what > 0) + return dev_uc_add(dev, i->addr); + else + return dev_uc_del(dev, i->addr); + break; + default: + break; + } + return 0; +} + +static void packet_dev_mclist_delete(struct net_device *dev, + struct packet_mclist **mlp) +{ + struct packet_mclist *ml; + + while ((ml = *mlp) != NULL) { + if (ml->ifindex == dev->ifindex) { + packet_dev_mc(dev, ml, -1); + *mlp = ml->next; + kfree(ml); + } else + mlp = &ml->next; + } +} + +static int packet_mc_add(struct sock *sk, struct packet_mreq_max *mreq) +{ + struct packet_sock *po = pkt_sk(sk); + struct packet_mclist *ml, *i; + struct net_device *dev; + int err; + + rtnl_lock(); + + err = -ENODEV; + dev = __dev_get_by_index(sock_net(sk), mreq->mr_ifindex); + if (!dev) + goto done; + + err = -EINVAL; + if (mreq->mr_alen > dev->addr_len) + goto done; + + err = -ENOBUFS; + i = kmalloc(sizeof(*i), GFP_KERNEL); + if (i == NULL) + goto done; + + err = 0; + for (ml = po->mclist; ml; ml = ml->next) { + if (ml->ifindex == mreq->mr_ifindex && + ml->type == mreq->mr_type && + ml->alen == mreq->mr_alen && + memcmp(ml->addr, mreq->mr_address, ml->alen) == 0) { + ml->count++; + /* Free the new element ... */ + kfree(i); + goto done; + } + } + + i->type = mreq->mr_type; + i->ifindex = mreq->mr_ifindex; + i->alen = mreq->mr_alen; + memcpy(i->addr, mreq->mr_address, i->alen); + memset(i->addr + i->alen, 0, sizeof(i->addr) - i->alen); + i->count = 1; + i->next = po->mclist; + po->mclist = i; + err = packet_dev_mc(dev, i, 1); + if (err) { + po->mclist = i->next; + kfree(i); + } + +done: + rtnl_unlock(); + return err; +} + +static int packet_mc_drop(struct sock *sk, struct packet_mreq_max *mreq) +{ + struct packet_mclist *ml, **mlp; + + rtnl_lock(); + + for (mlp = &pkt_sk(sk)->mclist; (ml = *mlp) != NULL; mlp = &ml->next) { + if (ml->ifindex == mreq->mr_ifindex && + ml->type == mreq->mr_type && + ml->alen == mreq->mr_alen && + memcmp(ml->addr, mreq->mr_address, ml->alen) == 0) { + if (--ml->count == 0) { + struct net_device *dev; + *mlp = ml->next; + dev = __dev_get_by_index(sock_net(sk), ml->ifindex); + if (dev) + packet_dev_mc(dev, ml, -1); + kfree(ml); + } + break; + } + } + rtnl_unlock(); + return 0; +} + +static void packet_flush_mclist(struct sock *sk) +{ + struct packet_sock *po = pkt_sk(sk); + struct packet_mclist *ml; + + if (!po->mclist) + return; + + rtnl_lock(); + while ((ml = po->mclist) != NULL) { + struct net_device *dev; + + po->mclist = ml->next; + dev = __dev_get_by_index(sock_net(sk), ml->ifindex); + if (dev != NULL) + packet_dev_mc(dev, ml, -1); + kfree(ml); + } + rtnl_unlock(); +} + +static int +packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval, + unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + int ret; + + if (level != SOL_PACKET) + return -ENOPROTOOPT; + + switch (optname) { + case PACKET_ADD_MEMBERSHIP: + case PACKET_DROP_MEMBERSHIP: + { + struct packet_mreq_max mreq; + int len = optlen; + memset(&mreq, 0, sizeof(mreq)); + if (len < sizeof(struct packet_mreq)) + return -EINVAL; + if (len > sizeof(mreq)) + len = sizeof(mreq); + if (copy_from_sockptr(&mreq, optval, len)) + return -EFAULT; + if (len < (mreq.mr_alen + offsetof(struct packet_mreq, mr_address))) + return -EINVAL; + if (optname == PACKET_ADD_MEMBERSHIP) + ret = packet_mc_add(sk, &mreq); + else + ret = packet_mc_drop(sk, &mreq); + return ret; + } + + case PACKET_RX_RING: + case PACKET_TX_RING: + { + union tpacket_req_u req_u; + int len; + + lock_sock(sk); + switch 
(po->tp_version) { + case TPACKET_V1: + case TPACKET_V2: + len = sizeof(req_u.req); + break; + case TPACKET_V3: + default: + len = sizeof(req_u.req3); + break; + } + if (optlen < len) { + ret = -EINVAL; + } else { + if (copy_from_sockptr(&req_u.req, optval, len)) + ret = -EFAULT; + else + ret = packet_set_ring(sk, &req_u, 0, + optname == PACKET_TX_RING); + } + release_sock(sk); + return ret; + } + case PACKET_COPY_THRESH: + { + int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + pkt_sk(sk)->copy_thresh = val; + return 0; + } + case PACKET_VERSION: + { + int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + switch (val) { + case TPACKET_V1: + case TPACKET_V2: + case TPACKET_V3: + break; + default: + return -EINVAL; + } + lock_sock(sk); + if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { + ret = -EBUSY; + } else { + po->tp_version = val; + ret = 0; + } + release_sock(sk); + return ret; + } + case PACKET_RESERVE: + { + unsigned int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + if (val > INT_MAX) + return -EINVAL; + lock_sock(sk); + if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { + ret = -EBUSY; + } else { + po->tp_reserve = val; + ret = 0; + } + release_sock(sk); + return ret; + } + case PACKET_LOSS: + { + unsigned int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { + ret = -EBUSY; + } else { + po->tp_loss = !!val; + ret = 0; + } + release_sock(sk); + return ret; + } + case PACKET_AUXDATA: + { + int val; + + if (optlen < sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + packet_sock_flag_set(po, PACKET_SOCK_AUXDATA, val); + return 0; + } + case PACKET_ORIGDEV: + { + int val; + + if (optlen < sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + packet_sock_flag_set(po, PACKET_SOCK_ORIGDEV, val); + return 0; + } + case PACKET_VNET_HDR: + { + int val; + + if (sock->type != SOCK_RAW) + return -EINVAL; + if (optlen < sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) { + ret = -EBUSY; + } else { + po->has_vnet_hdr = !!val; + ret = 0; + } + release_sock(sk); + return ret; + } + case PACKET_TIMESTAMP: + { + int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + po->tp_tstamp = val; + return 0; + } + case PACKET_FANOUT: + { + struct fanout_args args = { 0 }; + + if (optlen != sizeof(int) && optlen != sizeof(args)) + return -EINVAL; + if (copy_from_sockptr(&args, optval, optlen)) + return -EFAULT; + + return fanout_add(sk, &args); + } + case PACKET_FANOUT_DATA: + { + /* Paired with the WRITE_ONCE() in fanout_add() */ + if (!READ_ONCE(po->fanout)) + return -EINVAL; + + return fanout_set_data(po, optval, optlen); + } + case PACKET_IGNORE_OUTGOING: + { + int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + if (val < 0 || val > 1) + return -EINVAL; + + po->prot_hook.ignore_outgoing = !!val; + return 0; + } + case PACKET_TX_HAS_OFF: + { + unsigned int val; + + if 
(optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + lock_sock(sk); + if (!po->rx_ring.pg_vec && !po->tx_ring.pg_vec) + po->tp_tx_has_off = !!val; + + release_sock(sk); + return 0; + } + case PACKET_QDISC_BYPASS: + { + int val; + + if (optlen != sizeof(val)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(val))) + return -EFAULT; + + /* Paired with all lockless reads of po->xmit */ + WRITE_ONCE(po->xmit, val ? packet_direct_xmit : dev_queue_xmit); + return 0; + } + default: + return -ENOPROTOOPT; + } +} + +static int packet_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + int len; + int val, lv = sizeof(val); + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + void *data = &val; + union tpacket_stats_u st; + struct tpacket_rollover_stats rstats; + int drops; + + if (level != SOL_PACKET) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < 0) + return -EINVAL; + + switch (optname) { + case PACKET_STATISTICS: + spin_lock_bh(&sk->sk_receive_queue.lock); + memcpy(&st, &po->stats, sizeof(st)); + memset(&po->stats, 0, sizeof(po->stats)); + spin_unlock_bh(&sk->sk_receive_queue.lock); + drops = atomic_xchg(&po->tp_drops, 0); + + if (po->tp_version == TPACKET_V3) { + lv = sizeof(struct tpacket_stats_v3); + st.stats3.tp_drops = drops; + st.stats3.tp_packets += drops; + data = &st.stats3; + } else { + lv = sizeof(struct tpacket_stats); + st.stats1.tp_drops = drops; + st.stats1.tp_packets += drops; + data = &st.stats1; + } + + break; + case PACKET_AUXDATA: + val = packet_sock_flag(po, PACKET_SOCK_AUXDATA); + break; + case PACKET_ORIGDEV: + val = packet_sock_flag(po, PACKET_SOCK_ORIGDEV); + break; + case PACKET_VNET_HDR: + val = po->has_vnet_hdr; + break; + case PACKET_VERSION: + val = po->tp_version; + break; + case PACKET_HDRLEN: + if (len > sizeof(int)) + len = sizeof(int); + if (len < sizeof(int)) + return -EINVAL; + if (copy_from_user(&val, optval, len)) + return -EFAULT; + switch (val) { + case TPACKET_V1: + val = sizeof(struct tpacket_hdr); + break; + case TPACKET_V2: + val = sizeof(struct tpacket2_hdr); + break; + case TPACKET_V3: + val = sizeof(struct tpacket3_hdr); + break; + default: + return -EINVAL; + } + break; + case PACKET_RESERVE: + val = po->tp_reserve; + break; + case PACKET_LOSS: + val = po->tp_loss; + break; + case PACKET_TIMESTAMP: + val = po->tp_tstamp; + break; + case PACKET_FANOUT: + val = (po->fanout ? 
+ ((u32)po->fanout->id | + ((u32)po->fanout->type << 16) | + ((u32)po->fanout->flags << 24)) : + 0); + break; + case PACKET_IGNORE_OUTGOING: + val = po->prot_hook.ignore_outgoing; + break; + case PACKET_ROLLOVER_STATS: + if (!po->rollover) + return -EINVAL; + rstats.tp_all = atomic_long_read(&po->rollover->num); + rstats.tp_huge = atomic_long_read(&po->rollover->num_huge); + rstats.tp_failed = atomic_long_read(&po->rollover->num_failed); + data = &rstats; + lv = sizeof(rstats); + break; + case PACKET_TX_HAS_OFF: + val = po->tp_tx_has_off; + break; + case PACKET_QDISC_BYPASS: + val = packet_use_direct_xmit(po); + break; + default: + return -ENOPROTOOPT; + } + + if (len > lv) + len = lv; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, data, len)) + return -EFAULT; + return 0; +} + +static int packet_notifier(struct notifier_block *this, + unsigned long msg, void *ptr) +{ + struct sock *sk; + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + + rcu_read_lock(); + sk_for_each_rcu(sk, &net->packet.sklist) { + struct packet_sock *po = pkt_sk(sk); + + switch (msg) { + case NETDEV_UNREGISTER: + if (po->mclist) + packet_dev_mclist_delete(dev, &po->mclist); + fallthrough; + + case NETDEV_DOWN: + if (dev->ifindex == po->ifindex) { + spin_lock(&po->bind_lock); + if (po->running) { + __unregister_prot_hook(sk, false); + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + } + if (msg == NETDEV_UNREGISTER) { + packet_cached_dev_reset(po); + WRITE_ONCE(po->ifindex, -1); + netdev_put(po->prot_hook.dev, + &po->prot_hook.dev_tracker); + po->prot_hook.dev = NULL; + } + spin_unlock(&po->bind_lock); + } + break; + case NETDEV_UP: + if (dev->ifindex == po->ifindex) { + spin_lock(&po->bind_lock); + if (po->num) + register_prot_hook(sk); + spin_unlock(&po->bind_lock); + } + break; + } + } + rcu_read_unlock(); + return NOTIFY_DONE; +} + + +static int packet_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct sock *sk = sock->sk; + + switch (cmd) { + case SIOCOUTQ: + { + int amount = sk_wmem_alloc_get(sk); + + return put_user(amount, (int __user *)arg); + } + case SIOCINQ: + { + struct sk_buff *skb; + int amount = 0; + + spin_lock_bh(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + amount = skb->len; + spin_unlock_bh(&sk->sk_receive_queue.lock); + return put_user(amount, (int __user *)arg); + } +#ifdef CONFIG_INET + case SIOCADDRT: + case SIOCDELRT: + case SIOCDARP: + case SIOCGARP: + case SIOCSARP: + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCSIFFLAGS: + return inet_dgram_ops.ioctl(sock, cmd, arg); +#endif + + default: + return -ENOIOCTLCMD; + } + return 0; +} + +static __poll_t packet_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + __poll_t mask = datagram_poll(file, sock, wait); + + spin_lock_bh(&sk->sk_receive_queue.lock); + if (po->rx_ring.pg_vec) { + if (!packet_previous_rx_frame(po, &po->rx_ring, + TP_STATUS_KERNEL)) + mask |= EPOLLIN | EPOLLRDNORM; + } + packet_rcv_try_clear_pressure(po); + spin_unlock_bh(&sk->sk_receive_queue.lock); + spin_lock_bh(&sk->sk_write_queue.lock); + if (po->tx_ring.pg_vec) { + if (packet_current_frame(po, &po->tx_ring, TP_STATUS_AVAILABLE)) + mask |= EPOLLOUT | EPOLLWRNORM; + } + 
spin_unlock_bh(&sk->sk_write_queue.lock); + return mask; +} + + +/* Dirty? Well, I still did not learn better way to account + * for user mmaps. + */ + +static void packet_mm_open(struct vm_area_struct *vma) +{ + struct file *file = vma->vm_file; + struct socket *sock = file->private_data; + struct sock *sk = sock->sk; + + if (sk) + atomic_long_inc(&pkt_sk(sk)->mapped); +} + +static void packet_mm_close(struct vm_area_struct *vma) +{ + struct file *file = vma->vm_file; + struct socket *sock = file->private_data; + struct sock *sk = sock->sk; + + if (sk) + atomic_long_dec(&pkt_sk(sk)->mapped); +} + +static const struct vm_operations_struct packet_mmap_ops = { + .open = packet_mm_open, + .close = packet_mm_close, +}; + +static void free_pg_vec(struct pgv *pg_vec, unsigned int order, + unsigned int len) +{ + int i; + + for (i = 0; i < len; i++) { + if (likely(pg_vec[i].buffer)) { + if (is_vmalloc_addr(pg_vec[i].buffer)) + vfree(pg_vec[i].buffer); + else + free_pages((unsigned long)pg_vec[i].buffer, + order); + pg_vec[i].buffer = NULL; + } + } + kfree(pg_vec); +} + +static char *alloc_one_pg_vec_page(unsigned long order) +{ + char *buffer; + gfp_t gfp_flags = GFP_KERNEL | __GFP_COMP | + __GFP_ZERO | __GFP_NOWARN | __GFP_NORETRY; + + buffer = (char *) __get_free_pages(gfp_flags, order); + if (buffer) + return buffer; + + /* __get_free_pages failed, fall back to vmalloc */ + buffer = vzalloc(array_size((1 << order), PAGE_SIZE)); + if (buffer) + return buffer; + + /* vmalloc failed, lets dig into swap here */ + gfp_flags &= ~__GFP_NORETRY; + buffer = (char *) __get_free_pages(gfp_flags, order); + if (buffer) + return buffer; + + /* complete and utter failure */ + return NULL; +} + +static struct pgv *alloc_pg_vec(struct tpacket_req *req, int order) +{ + unsigned int block_nr = req->tp_block_nr; + struct pgv *pg_vec; + int i; + + pg_vec = kcalloc(block_nr, sizeof(struct pgv), GFP_KERNEL | __GFP_NOWARN); + if (unlikely(!pg_vec)) + goto out; + + for (i = 0; i < block_nr; i++) { + pg_vec[i].buffer = alloc_one_pg_vec_page(order); + if (unlikely(!pg_vec[i].buffer)) + goto out_free_pgvec; + } + +out: + return pg_vec; + +out_free_pgvec: + free_pg_vec(pg_vec, order, block_nr); + pg_vec = NULL; + goto out; +} + +static int packet_set_ring(struct sock *sk, union tpacket_req_u *req_u, + int closing, int tx_ring) +{ + struct pgv *pg_vec = NULL; + struct packet_sock *po = pkt_sk(sk); + unsigned long *rx_owner_map = NULL; + int was_running, order = 0; + struct packet_ring_buffer *rb; + struct sk_buff_head *rb_queue; + __be16 num; + int err; + /* Added to avoid minimal code churn */ + struct tpacket_req *req = &req_u->req; + + rb = tx_ring ? &po->tx_ring : &po->rx_ring; + rb_queue = tx_ring ? 
&sk->sk_write_queue : &sk->sk_receive_queue; + + err = -EBUSY; + if (!closing) { + if (atomic_long_read(&po->mapped)) + goto out; + if (packet_read_pending(rb)) + goto out; + } + + if (req->tp_block_nr) { + unsigned int min_frame_size; + + /* Sanity tests and some calculations */ + err = -EBUSY; + if (unlikely(rb->pg_vec)) + goto out; + + switch (po->tp_version) { + case TPACKET_V1: + po->tp_hdrlen = TPACKET_HDRLEN; + break; + case TPACKET_V2: + po->tp_hdrlen = TPACKET2_HDRLEN; + break; + case TPACKET_V3: + po->tp_hdrlen = TPACKET3_HDRLEN; + break; + } + + err = -EINVAL; + if (unlikely((int)req->tp_block_size <= 0)) + goto out; + if (unlikely(!PAGE_ALIGNED(req->tp_block_size))) + goto out; + min_frame_size = po->tp_hdrlen + po->tp_reserve; + if (po->tp_version >= TPACKET_V3 && + req->tp_block_size < + BLK_PLUS_PRIV((u64)req_u->req3.tp_sizeof_priv) + min_frame_size) + goto out; + if (unlikely(req->tp_frame_size < min_frame_size)) + goto out; + if (unlikely(req->tp_frame_size & (TPACKET_ALIGNMENT - 1))) + goto out; + + rb->frames_per_block = req->tp_block_size / req->tp_frame_size; + if (unlikely(rb->frames_per_block == 0)) + goto out; + if (unlikely(rb->frames_per_block > UINT_MAX / req->tp_block_nr)) + goto out; + if (unlikely((rb->frames_per_block * req->tp_block_nr) != + req->tp_frame_nr)) + goto out; + + err = -ENOMEM; + order = get_order(req->tp_block_size); + pg_vec = alloc_pg_vec(req, order); + if (unlikely(!pg_vec)) + goto out; + switch (po->tp_version) { + case TPACKET_V3: + /* Block transmit is not supported yet */ + if (!tx_ring) { + init_prb_bdqc(po, rb, pg_vec, req_u); + } else { + struct tpacket_req3 *req3 = &req_u->req3; + + if (req3->tp_retire_blk_tov || + req3->tp_sizeof_priv || + req3->tp_feature_req_word) { + err = -EINVAL; + goto out_free_pg_vec; + } + } + break; + default: + if (!tx_ring) { + rx_owner_map = bitmap_alloc(req->tp_frame_nr, + GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO); + if (!rx_owner_map) + goto out_free_pg_vec; + } + break; + } + } + /* Done */ + else { + err = -EINVAL; + if (unlikely(req->tp_frame_nr)) + goto out; + } + + + /* Detach socket from network */ + spin_lock(&po->bind_lock); + was_running = po->running; + num = po->num; + if (was_running) { + WRITE_ONCE(po->num, 0); + __unregister_prot_hook(sk, false); + } + spin_unlock(&po->bind_lock); + + synchronize_net(); + + err = -EBUSY; + mutex_lock(&po->pg_vec_lock); + if (closing || atomic_long_read(&po->mapped) == 0) { + err = 0; + spin_lock_bh(&rb_queue->lock); + swap(rb->pg_vec, pg_vec); + if (po->tp_version <= TPACKET_V2) + swap(rb->rx_owner_map, rx_owner_map); + rb->frame_max = (req->tp_frame_nr - 1); + rb->head = 0; + rb->frame_size = req->tp_frame_size; + spin_unlock_bh(&rb_queue->lock); + + swap(rb->pg_vec_order, order); + swap(rb->pg_vec_len, req->tp_block_nr); + + rb->pg_vec_pages = req->tp_block_size/PAGE_SIZE; + po->prot_hook.func = (po->rx_ring.pg_vec) ? 
+ tpacket_rcv : packet_rcv; + skb_queue_purge(rb_queue); + if (atomic_long_read(&po->mapped)) + pr_err("packet_mmap: vma is busy: %ld\n", + atomic_long_read(&po->mapped)); + } + mutex_unlock(&po->pg_vec_lock); + + spin_lock(&po->bind_lock); + if (was_running) { + WRITE_ONCE(po->num, num); + register_prot_hook(sk); + } + spin_unlock(&po->bind_lock); + if (pg_vec && (po->tp_version > TPACKET_V2)) { + /* Because we don't support block-based V3 on tx-ring */ + if (!tx_ring) + prb_shutdown_retire_blk_timer(po, rb_queue); + } + +out_free_pg_vec: + if (pg_vec) { + bitmap_free(rx_owner_map); + free_pg_vec(pg_vec, order, req->tp_block_nr); + } +out: + return err; +} + +static int packet_mmap(struct file *file, struct socket *sock, + struct vm_area_struct *vma) +{ + struct sock *sk = sock->sk; + struct packet_sock *po = pkt_sk(sk); + unsigned long size, expected_size; + struct packet_ring_buffer *rb; + unsigned long start; + int err = -EINVAL; + int i; + + if (vma->vm_pgoff) + return -EINVAL; + + mutex_lock(&po->pg_vec_lock); + + expected_size = 0; + for (rb = &po->rx_ring; rb <= &po->tx_ring; rb++) { + if (rb->pg_vec) { + expected_size += rb->pg_vec_len + * rb->pg_vec_pages + * PAGE_SIZE; + } + } + + if (expected_size == 0) + goto out; + + size = vma->vm_end - vma->vm_start; + if (size != expected_size) + goto out; + + start = vma->vm_start; + for (rb = &po->rx_ring; rb <= &po->tx_ring; rb++) { + if (rb->pg_vec == NULL) + continue; + + for (i = 0; i < rb->pg_vec_len; i++) { + struct page *page; + void *kaddr = rb->pg_vec[i].buffer; + int pg_num; + + for (pg_num = 0; pg_num < rb->pg_vec_pages; pg_num++) { + page = pgv_to_page(kaddr); + err = vm_insert_page(vma, start, page); + if (unlikely(err)) + goto out; + start += PAGE_SIZE; + kaddr += PAGE_SIZE; + } + } + } + + atomic_long_inc(&po->mapped); + vma->vm_ops = &packet_mmap_ops; + err = 0; + +out: + mutex_unlock(&po->pg_vec_lock); + return err; +} + +static const struct proto_ops packet_ops_spkt = { + .family = PF_PACKET, + .owner = THIS_MODULE, + .release = packet_release, + .bind = packet_bind_spkt, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = packet_getname_spkt, + .poll = datagram_poll, + .ioctl = packet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = packet_sendmsg_spkt, + .recvmsg = packet_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static const struct proto_ops packet_ops = { + .family = PF_PACKET, + .owner = THIS_MODULE, + .release = packet_release, + .bind = packet_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = packet_getname, + .poll = packet_poll, + .ioctl = packet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = packet_setsockopt, + .getsockopt = packet_getsockopt, + .sendmsg = packet_sendmsg, + .recvmsg = packet_recvmsg, + .mmap = packet_mmap, + .sendpage = sock_no_sendpage, +}; + +static const struct net_proto_family packet_family_ops = { + .family = PF_PACKET, + .create = packet_create, + .owner = THIS_MODULE, +}; + +static struct notifier_block packet_netdev_notifier = { + .notifier_call = packet_notifier, +}; + +#ifdef CONFIG_PROC_FS + +static void *packet_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(RCU) +{ + struct net *net = seq_file_net(seq); + + rcu_read_lock(); + return seq_hlist_start_head_rcu(&net->packet.sklist, *pos); +} + +static 
void *packet_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + return seq_hlist_next_rcu(v, &net->packet.sklist, pos); +} + +static void packet_seq_stop(struct seq_file *seq, void *v) + __releases(RCU) +{ + rcu_read_unlock(); +} + +static int packet_seq_show(struct seq_file *seq, void *v) +{ + if (v == SEQ_START_TOKEN) + seq_printf(seq, + "%*sRefCnt Type Proto Iface R Rmem User Inode\n", + IS_ENABLED(CONFIG_64BIT) ? -17 : -9, "sk"); + else { + struct sock *s = sk_entry(v); + const struct packet_sock *po = pkt_sk(s); + + seq_printf(seq, + "%pK %-6d %-4d %04x %-5d %1d %-6u %-6u %-6lu\n", + s, + refcount_read(&s->sk_refcnt), + s->sk_type, + ntohs(READ_ONCE(po->num)), + READ_ONCE(po->ifindex), + po->running, + atomic_read(&s->sk_rmem_alloc), + from_kuid_munged(seq_user_ns(seq), sock_i_uid(s)), + sock_i_ino(s)); + } + + return 0; +} + +static const struct seq_operations packet_seq_ops = { + .start = packet_seq_start, + .next = packet_seq_next, + .stop = packet_seq_stop, + .show = packet_seq_show, +}; +#endif + +static int __net_init packet_net_init(struct net *net) +{ + mutex_init(&net->packet.sklist_lock); + INIT_HLIST_HEAD(&net->packet.sklist); + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("packet", 0, net->proc_net, &packet_seq_ops, + sizeof(struct seq_net_private))) + return -ENOMEM; +#endif /* CONFIG_PROC_FS */ + + return 0; +} + +static void __net_exit packet_net_exit(struct net *net) +{ + remove_proc_entry("packet", net->proc_net); + WARN_ON_ONCE(!hlist_empty(&net->packet.sklist)); +} + +static struct pernet_operations packet_net_ops = { + .init = packet_net_init, + .exit = packet_net_exit, +}; + + +static void __exit packet_exit(void) +{ + sock_unregister(PF_PACKET); + proto_unregister(&packet_proto); + unregister_netdevice_notifier(&packet_netdev_notifier); + unregister_pernet_subsys(&packet_net_ops); +} + +static int __init packet_init(void) +{ + int rc; + + rc = register_pernet_subsys(&packet_net_ops); + if (rc) + goto out; + rc = register_netdevice_notifier(&packet_netdev_notifier); + if (rc) + goto out_pernet; + rc = proto_register(&packet_proto, 0); + if (rc) + goto out_notifier; + rc = sock_register(&packet_family_ops); + if (rc) + goto out_proto; + + return 0; + +out_proto: + proto_unregister(&packet_proto); +out_notifier: + unregister_netdevice_notifier(&packet_netdev_notifier); +out_pernet: + unregister_pernet_subsys(&packet_net_ops); +out: + return rc; +} + +module_init(packet_init); +module_exit(packet_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_PACKET); diff --git a/net/packet/diag.c b/net/packet/diag.c new file mode 100644 index 000000000..a68a84574 --- /dev/null +++ b/net/packet/diag.c @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include + +#include "internal.h" + +static int pdiag_put_info(const struct packet_sock *po, struct sk_buff *nlskb) +{ + struct packet_diag_info pinfo; + + pinfo.pdi_index = po->ifindex; + pinfo.pdi_version = po->tp_version; + pinfo.pdi_reserve = po->tp_reserve; + pinfo.pdi_copy_thresh = po->copy_thresh; + pinfo.pdi_tstamp = po->tp_tstamp; + + pinfo.pdi_flags = 0; + if (po->running) + pinfo.pdi_flags |= PDI_RUNNING; + if (packet_sock_flag(po, PACKET_SOCK_AUXDATA)) + pinfo.pdi_flags |= PDI_AUXDATA; + if (packet_sock_flag(po, PACKET_SOCK_ORIGDEV)) + pinfo.pdi_flags |= PDI_ORIGDEV; + if (po->has_vnet_hdr) + pinfo.pdi_flags |= PDI_VNETHDR; + if (po->tp_loss) + pinfo.pdi_flags |= PDI_LOSS; + + return 
nla_put(nlskb, PACKET_DIAG_INFO, sizeof(pinfo), &pinfo); +} + +static int pdiag_put_mclist(const struct packet_sock *po, struct sk_buff *nlskb) +{ + struct nlattr *mca; + struct packet_mclist *ml; + + mca = nla_nest_start_noflag(nlskb, PACKET_DIAG_MCLIST); + if (!mca) + return -EMSGSIZE; + + rtnl_lock(); + for (ml = po->mclist; ml; ml = ml->next) { + struct packet_diag_mclist *dml; + + dml = nla_reserve_nohdr(nlskb, sizeof(*dml)); + if (!dml) { + rtnl_unlock(); + nla_nest_cancel(nlskb, mca); + return -EMSGSIZE; + } + + dml->pdmc_index = ml->ifindex; + dml->pdmc_type = ml->type; + dml->pdmc_alen = ml->alen; + dml->pdmc_count = ml->count; + BUILD_BUG_ON(sizeof(dml->pdmc_addr) != sizeof(ml->addr)); + memcpy(dml->pdmc_addr, ml->addr, sizeof(ml->addr)); + } + + rtnl_unlock(); + nla_nest_end(nlskb, mca); + + return 0; +} + +static int pdiag_put_ring(struct packet_ring_buffer *ring, int ver, int nl_type, + struct sk_buff *nlskb) +{ + struct packet_diag_ring pdr; + + if (!ring->pg_vec) + return 0; + + pdr.pdr_block_size = ring->pg_vec_pages << PAGE_SHIFT; + pdr.pdr_block_nr = ring->pg_vec_len; + pdr.pdr_frame_size = ring->frame_size; + pdr.pdr_frame_nr = ring->frame_max + 1; + + if (ver > TPACKET_V2) { + pdr.pdr_retire_tmo = ring->prb_bdqc.retire_blk_tov; + pdr.pdr_sizeof_priv = ring->prb_bdqc.blk_sizeof_priv; + pdr.pdr_features = ring->prb_bdqc.feature_req_word; + } else { + pdr.pdr_retire_tmo = 0; + pdr.pdr_sizeof_priv = 0; + pdr.pdr_features = 0; + } + + return nla_put(nlskb, nl_type, sizeof(pdr), &pdr); +} + +static int pdiag_put_rings_cfg(struct packet_sock *po, struct sk_buff *skb) +{ + int ret; + + mutex_lock(&po->pg_vec_lock); + ret = pdiag_put_ring(&po->rx_ring, po->tp_version, + PACKET_DIAG_RX_RING, skb); + if (!ret) + ret = pdiag_put_ring(&po->tx_ring, po->tp_version, + PACKET_DIAG_TX_RING, skb); + mutex_unlock(&po->pg_vec_lock); + + return ret; +} + +static int pdiag_put_fanout(struct packet_sock *po, struct sk_buff *nlskb) +{ + int ret = 0; + + mutex_lock(&fanout_mutex); + if (po->fanout) { + u32 val; + + val = (u32)po->fanout->id | ((u32)po->fanout->type << 16); + ret = nla_put_u32(nlskb, PACKET_DIAG_FANOUT, val); + } + mutex_unlock(&fanout_mutex); + + return ret; +} + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + struct packet_diag_req *req, + bool may_report_filterinfo, + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) +{ + struct nlmsghdr *nlh; + struct packet_diag_msg *rp; + struct packet_sock *po = pkt_sk(sk); + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rp), flags); + if (!nlh) + return -EMSGSIZE; + + rp = nlmsg_data(nlh); + rp->pdiag_family = AF_PACKET; + rp->pdiag_type = sk->sk_type; + rp->pdiag_num = ntohs(READ_ONCE(po->num)); + rp->pdiag_ino = sk_ino; + sock_diag_save_cookie(sk, rp->pdiag_cookie); + + if ((req->pdiag_show & PACKET_SHOW_INFO) && + pdiag_put_info(po, skb)) + goto out_nlmsg_trim; + + if ((req->pdiag_show & PACKET_SHOW_INFO) && + nla_put_u32(skb, PACKET_DIAG_UID, + from_kuid_munged(user_ns, sock_i_uid(sk)))) + goto out_nlmsg_trim; + + if ((req->pdiag_show & PACKET_SHOW_MCLIST) && + pdiag_put_mclist(po, skb)) + goto out_nlmsg_trim; + + if ((req->pdiag_show & PACKET_SHOW_RING_CFG) && + pdiag_put_rings_cfg(po, skb)) + goto out_nlmsg_trim; + + if ((req->pdiag_show & PACKET_SHOW_FANOUT) && + pdiag_put_fanout(po, skb)) + goto out_nlmsg_trim; + + if ((req->pdiag_show & PACKET_SHOW_MEMINFO) && + sock_diag_put_meminfo(sk, skb, PACKET_DIAG_MEMINFO)) + goto out_nlmsg_trim; + + if ((req->pdiag_show & 
PACKET_SHOW_FILTER) && + sock_diag_put_filterinfo(may_report_filterinfo, sk, skb, + PACKET_DIAG_FILTER)) + goto out_nlmsg_trim; + + nlmsg_end(skb, nlh); + return 0; + +out_nlmsg_trim: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int packet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int num = 0, s_num = cb->args[0]; + struct packet_diag_req *req; + struct net *net; + struct sock *sk; + bool may_report_filterinfo; + + net = sock_net(skb->sk); + req = nlmsg_data(cb->nlh); + may_report_filterinfo = netlink_net_capable(cb->skb, CAP_NET_ADMIN); + + mutex_lock(&net->packet.sklist_lock); + sk_for_each(sk, &net->packet.sklist) { + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) + goto next; + + if (sk_diag_fill(sk, skb, req, + may_report_filterinfo, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + sock_i_ino(sk)) < 0) + goto done; +next: + num++; + } +done: + mutex_unlock(&net->packet.sklist_lock); + cb->args[0] = num; + + return skb->len; +} + +static int packet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct packet_diag_req); + struct net *net = sock_net(skb->sk); + struct packet_diag_req *req; + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + req = nlmsg_data(h); + /* Make it possible to support protocol filtering later */ + if (req->sdiag_protocol) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = packet_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } else + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler packet_diag_handler = { + .family = AF_PACKET, + .dump = packet_diag_handler_dump, +}; + +static int __init packet_diag_init(void) +{ + return sock_diag_register(&packet_diag_handler); +} + +static void __exit packet_diag_exit(void) +{ + sock_diag_unregister(&packet_diag_handler); +} + +module_init(packet_diag_init); +module_exit(packet_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 17 /* AF_PACKET */); diff --git a/net/packet/internal.h b/net/packet/internal.h new file mode 100644 index 000000000..b2edfe6fc --- /dev/null +++ b/net/packet/internal.h @@ -0,0 +1,167 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __PACKET_INTERNAL_H__ +#define __PACKET_INTERNAL_H__ + +#include + +struct packet_mclist { + struct packet_mclist *next; + int ifindex; + int count; + unsigned short type; + unsigned short alen; + unsigned char addr[MAX_ADDR_LEN]; +}; + +/* kbdq - kernel block descriptor queue */ +struct tpacket_kbdq_core { + struct pgv *pkbdq; + unsigned int feature_req_word; + unsigned int hdrlen; + unsigned char reset_pending_on_curr_blk; + unsigned char delete_blk_timer; + unsigned short kactive_blk_num; + unsigned short blk_sizeof_priv; + + /* last_kactive_blk_num: + * trick to see if user-space has caught up + * in order to avoid refreshing timer when every single pkt arrives. 
+ */ + unsigned short last_kactive_blk_num; + + char *pkblk_start; + char *pkblk_end; + int kblk_size; + unsigned int max_frame_len; + unsigned int knum_blocks; + uint64_t knxt_seq_num; + char *prev; + char *nxt_offset; + struct sk_buff *skb; + + rwlock_t blk_fill_in_prog_lock; + + /* Default is set to 8ms */ +#define DEFAULT_PRB_RETIRE_TOV (8) + + unsigned short retire_blk_tov; + unsigned short version; + unsigned long tov_in_jiffies; + + /* timer to retire an outstanding block */ + struct timer_list retire_blk_timer; +}; + +struct pgv { + char *buffer; +}; + +struct packet_ring_buffer { + struct pgv *pg_vec; + + unsigned int head; + unsigned int frames_per_block; + unsigned int frame_size; + unsigned int frame_max; + + unsigned int pg_vec_order; + unsigned int pg_vec_pages; + unsigned int pg_vec_len; + + unsigned int __percpu *pending_refcnt; + + union { + unsigned long *rx_owner_map; + struct tpacket_kbdq_core prb_bdqc; + }; +}; + +extern struct mutex fanout_mutex; +#define PACKET_FANOUT_MAX (1 << 16) + +struct packet_fanout { + possible_net_t net; + unsigned int num_members; + u32 max_num_members; + u16 id; + u8 type; + u8 flags; + union { + atomic_t rr_cur; + struct bpf_prog __rcu *bpf_prog; + }; + struct list_head list; + spinlock_t lock; + refcount_t sk_ref; + struct packet_type prot_hook ____cacheline_aligned_in_smp; + struct sock __rcu *arr[]; +}; + +struct packet_rollover { + int sock; + atomic_long_t num; + atomic_long_t num_huge; + atomic_long_t num_failed; +#define ROLLOVER_HLEN (L1_CACHE_BYTES / sizeof(u32)) + u32 history[ROLLOVER_HLEN] ____cacheline_aligned; +} ____cacheline_aligned_in_smp; + +struct packet_sock { + /* struct sock has to be the first member of packet_sock */ + struct sock sk; + struct packet_fanout *fanout; + union tpacket_stats_u stats; + struct packet_ring_buffer rx_ring; + struct packet_ring_buffer tx_ring; + int copy_thresh; + spinlock_t bind_lock; + struct mutex pg_vec_lock; + unsigned long flags; + unsigned int running; /* bind_lock must be held */ + unsigned int has_vnet_hdr:1, /* writer must hold sock lock */ + tp_loss:1, + tp_tx_has_off:1; + int pressure; + int ifindex; /* bound device */ + __be16 num; + struct packet_rollover *rollover; + struct packet_mclist *mclist; + atomic_long_t mapped; + enum tpacket_versions tp_version; + unsigned int tp_hdrlen; + unsigned int tp_reserve; + unsigned int tp_tstamp; + struct completion skb_completion; + struct net_device __rcu *cached_dev; + int (*xmit)(struct sk_buff *skb); + struct packet_type prot_hook ____cacheline_aligned_in_smp; + atomic_t tp_drops ____cacheline_aligned_in_smp; +}; + +static inline struct packet_sock *pkt_sk(struct sock *sk) +{ + return (struct packet_sock *)sk; +} + +enum packet_sock_flags { + PACKET_SOCK_ORIGDEV, + PACKET_SOCK_AUXDATA, +}; + +static inline void packet_sock_flag_set(struct packet_sock *po, + enum packet_sock_flags flag, + bool val) +{ + if (val) + set_bit(flag, &po->flags); + else + clear_bit(flag, &po->flags); +} + +static inline bool packet_sock_flag(const struct packet_sock *po, + enum packet_sock_flags flag) +{ + return test_bit(flag, &po->flags); +} + +#endif diff --git a/net/phonet/Kconfig b/net/phonet/Kconfig new file mode 100644 index 000000000..07f2c2172 --- /dev/null +++ b/net/phonet/Kconfig @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Phonet protocol +# + +config PHONET + tristate "Phonet protocols family" + help + The Phone Network protocol (PhoNet) is a packet-oriented + communication protocol developed by Nokia for use with its modems. 
+ + This is required for Maemo to use cellular data connectivity (if + supported). It can also be used to control Nokia phones + from a Linux computer, although AT commands may be easier to use. + + To compile this driver as a module, choose M here: the module + will be called phonet. If unsure, say N. diff --git a/net/phonet/Makefile b/net/phonet/Makefile new file mode 100644 index 000000000..444f87593 --- /dev/null +++ b/net/phonet/Makefile @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_PHONET) += phonet.o pn_pep.o + +phonet-y := \ + pn_dev.o \ + pn_netlink.o \ + socket.o \ + datagram.o \ + sysctl.o \ + af_phonet.o + +pn_pep-y := pep.o pep-gprs.o diff --git a/net/phonet/af_phonet.c b/net/phonet/af_phonet.c new file mode 100644 index 000000000..2b582da1e --- /dev/null +++ b/net/phonet/af_phonet.c @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: af_phonet.c + * + * Phonet protocols family + * + * Copyright (C) 2008 Nokia Corporation. + * + * Authors: Sakari Ailus + * Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +/* Transport protocol registration */ +static const struct phonet_protocol *proto_tab[PHONET_NPROTO] __read_mostly; + +static const struct phonet_protocol *phonet_proto_get(unsigned int protocol) +{ + const struct phonet_protocol *pp; + + if (protocol >= PHONET_NPROTO) + return NULL; + + rcu_read_lock(); + pp = rcu_dereference(proto_tab[protocol]); + if (pp && !try_module_get(pp->prot->owner)) + pp = NULL; + rcu_read_unlock(); + + return pp; +} + +static inline void phonet_proto_put(const struct phonet_protocol *pp) +{ + module_put(pp->prot->owner); +} + +/* protocol family functions */ + +static int pn_socket_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct pn_sock *pn; + const struct phonet_protocol *pnp; + int err; + + if (!capable(CAP_SYS_ADMIN)) + return -EPERM; + + if (protocol == 0) { + /* Default protocol selection */ + switch (sock->type) { + case SOCK_DGRAM: + protocol = PN_PROTO_PHONET; + break; + case SOCK_SEQPACKET: + protocol = PN_PROTO_PIPE; + break; + default: + return -EPROTONOSUPPORT; + } + } + + pnp = phonet_proto_get(protocol); + if (pnp == NULL && + request_module("net-pf-%d-proto-%d", PF_PHONET, protocol) == 0) + pnp = phonet_proto_get(protocol); + + if (pnp == NULL) + return -EPROTONOSUPPORT; + if (sock->type != pnp->sock_type) { + err = -EPROTONOSUPPORT; + goto out; + } + + sk = sk_alloc(net, PF_PHONET, GFP_KERNEL, pnp->prot, kern); + if (sk == NULL) { + err = -ENOMEM; + goto out; + } + + sock_init_data(sock, sk); + sock->state = SS_UNCONNECTED; + sock->ops = pnp->ops; + sk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; + sk->sk_protocol = protocol; + pn = pn_sk(sk); + pn->sobject = 0; + pn->dobject = 0; + pn->resource = 0; + sk->sk_prot->init(sk); + err = 0; + +out: + phonet_proto_put(pnp); + return err; +} + +static const struct net_proto_family phonet_proto_family = { + .family = PF_PHONET, + .create = pn_socket_create, + .owner = THIS_MODULE, +}; + +/* Phonet device header operations */ +static int pn_header_create(struct sk_buff *skb, struct net_device *dev, + unsigned short type, const void *daddr, + const void *saddr, unsigned int len) +{ + u8 *media = skb_push(skb, 1); + + if (type != ETH_P_PHONET) + return -1; + + if (!saddr) + saddr = dev->dev_addr; + *media = *(const u8 *)saddr; + return 1; +} + +static int pn_header_parse(const struct sk_buff *skb, unsigned char *haddr) +{ + 
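+	/*
+	 * The Phonet link-layer header prepended by pn_header_create() is a
+	 * single byte carrying the sender's media address; report that byte
+	 * back as the hardware address.
+	 */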
const u8 *media = skb_mac_header(skb); + *haddr = *media; + return 1; +} + +const struct header_ops phonet_header_ops = { + .create = pn_header_create, + .parse = pn_header_parse, +}; +EXPORT_SYMBOL(phonet_header_ops); + +/* + * Prepends an ISI header and sends a datagram. + */ +static int pn_send(struct sk_buff *skb, struct net_device *dev, + u16 dst, u16 src, u8 res) +{ + struct phonethdr *ph; + int err; + + if (skb->len + 2 > 0xffff /* Phonet length field limit */ || + skb->len + sizeof(struct phonethdr) > dev->mtu) { + err = -EMSGSIZE; + goto drop; + } + + /* Broadcast sending is not implemented */ + if (pn_addr(dst) == PNADDR_BROADCAST) { + err = -EOPNOTSUPP; + goto drop; + } + + skb_reset_transport_header(skb); + WARN_ON(skb_headroom(skb) & 1); /* HW assumes word alignment */ + skb_push(skb, sizeof(struct phonethdr)); + skb_reset_network_header(skb); + ph = pn_hdr(skb); + ph->pn_rdev = pn_dev(dst); + ph->pn_sdev = pn_dev(src); + ph->pn_res = res; + ph->pn_length = __cpu_to_be16(skb->len + 2 - sizeof(*ph)); + ph->pn_robj = pn_obj(dst); + ph->pn_sobj = pn_obj(src); + + skb->protocol = htons(ETH_P_PHONET); + skb->priority = 0; + skb->dev = dev; + + if (skb->pkt_type == PACKET_LOOPBACK) { + skb_reset_mac_header(skb); + skb_orphan(skb); + err = netif_rx(skb) ? -ENOBUFS : 0; + } else { + err = dev_hard_header(skb, dev, ntohs(skb->protocol), + NULL, NULL, skb->len); + if (err < 0) { + err = -EHOSTUNREACH; + goto drop; + } + err = dev_queue_xmit(skb); + if (unlikely(err > 0)) + err = net_xmit_errno(err); + } + + return err; +drop: + kfree_skb(skb); + return err; +} + +static int pn_raw_send(const void *data, int len, struct net_device *dev, + u16 dst, u16 src, u8 res) +{ + struct sk_buff *skb = alloc_skb(MAX_PHONET_HEADER + len, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + if (phonet_address_lookup(dev_net(dev), pn_addr(dst)) == 0) + skb->pkt_type = PACKET_LOOPBACK; + + skb_reserve(skb, MAX_PHONET_HEADER); + __skb_put(skb, len); + skb_copy_to_linear_data(skb, data, len); + return pn_send(skb, dev, dst, src, res); +} + +/* + * Create a Phonet header for the skb and send it out. Returns + * non-zero error code if failed. The skb is freed then. 
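+ * The output device is chosen from the socket's bound interface, the local
+ * address table (loopback) or the Phonet routing table, and the source
+ * address is completed from that device when the socket has none yet.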
+ */ +int pn_skb_send(struct sock *sk, struct sk_buff *skb, + const struct sockaddr_pn *target) +{ + struct net *net = sock_net(sk); + struct net_device *dev; + struct pn_sock *pn = pn_sk(sk); + int err; + u16 src, dst; + u8 daddr, saddr, res; + + src = pn->sobject; + if (target != NULL) { + dst = pn_sockaddr_get_object(target); + res = pn_sockaddr_get_resource(target); + } else { + dst = pn->dobject; + res = pn->resource; + } + daddr = pn_addr(dst); + + err = -EHOSTUNREACH; + if (sk->sk_bound_dev_if) + dev = dev_get_by_index(net, sk->sk_bound_dev_if); + else if (phonet_address_lookup(net, daddr) == 0) { + dev = phonet_device_get(net); + skb->pkt_type = PACKET_LOOPBACK; + } else if (dst == 0) { + /* Resource routing (small race until phonet_rcv()) */ + struct sock *sk = pn_find_sock_by_res(net, res); + if (sk) { + sock_put(sk); + dev = phonet_device_get(net); + skb->pkt_type = PACKET_LOOPBACK; + } else + dev = phonet_route_output(net, daddr); + } else + dev = phonet_route_output(net, daddr); + + if (!dev || !(dev->flags & IFF_UP)) + goto drop; + + saddr = phonet_address_get(dev, daddr); + if (saddr == PN_NO_ADDR) + goto drop; + + if (!pn_addr(src)) + src = pn_object(saddr, pn_obj(src)); + + err = pn_send(skb, dev, dst, src, res); + dev_put(dev); + return err; + +drop: + kfree_skb(skb); + dev_put(dev); + return err; +} +EXPORT_SYMBOL(pn_skb_send); + +/* Do not send an error message in response to an error message */ +static inline int can_respond(struct sk_buff *skb) +{ + const struct phonethdr *ph; + const struct phonetmsg *pm; + u8 submsg_id; + + if (!pskb_may_pull(skb, 3)) + return 0; + + ph = pn_hdr(skb); + if (ph->pn_res == PN_PREFIX && !pskb_may_pull(skb, 5)) + return 0; + if (ph->pn_res == PN_COMMGR) /* indications */ + return 0; + + ph = pn_hdr(skb); /* re-acquires the pointer */ + pm = pn_msg(skb); + if (pm->pn_msg_id != PN_COMMON_MESSAGE) + return 1; + submsg_id = (ph->pn_res == PN_PREFIX) + ? pm->pn_e_submsg_id : pm->pn_submsg_id; + if (submsg_id != PN_COMM_ISA_ENTITY_NOT_REACHABLE_RESP && + pm->pn_e_submsg_id != PN_COMM_SERVICE_NOT_IDENTIFIED_RESP) + return 1; + return 0; +} + +static int send_obj_unreachable(struct sk_buff *rskb) +{ + const struct phonethdr *oph = pn_hdr(rskb); + const struct phonetmsg *opm = pn_msg(rskb); + struct phonetmsg resp; + + memset(&resp, 0, sizeof(resp)); + resp.pn_trans_id = opm->pn_trans_id; + resp.pn_msg_id = PN_COMMON_MESSAGE; + if (oph->pn_res == PN_PREFIX) { + resp.pn_e_res_id = opm->pn_e_res_id; + resp.pn_e_submsg_id = PN_COMM_ISA_ENTITY_NOT_REACHABLE_RESP; + resp.pn_e_orig_msg_id = opm->pn_msg_id; + resp.pn_e_status = 0; + } else { + resp.pn_submsg_id = PN_COMM_ISA_ENTITY_NOT_REACHABLE_RESP; + resp.pn_orig_msg_id = opm->pn_msg_id; + resp.pn_status = 0; + } + return pn_raw_send(&resp, sizeof(resp), rskb->dev, + pn_object(oph->pn_sdev, oph->pn_sobj), + pn_object(oph->pn_rdev, oph->pn_robj), + oph->pn_res); +} + +static int send_reset_indications(struct sk_buff *rskb) +{ + struct phonethdr *oph = pn_hdr(rskb); + static const u8 data[4] = { + 0x00 /* trans ID */, 0x10 /* subscribe msg */, + 0x00 /* subscription count */, 0x00 /* dummy */ + }; + + return pn_raw_send(data, sizeof(data), rskb->dev, + pn_object(oph->pn_sdev, 0x00), + pn_object(oph->pn_rdev, oph->pn_robj), + PN_COMMGR); +} + + +/* packet type functions */ + +/* + * Stuff received packets to associated sockets. + * On error, returns non-zero and releases the skb. 
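+ * Depending on the destination address, a packet is delivered to a matching
+ * local socket, broadcast to all Phonet sockets, or forwarded out another
+ * Phonet device according to the routing table.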
+ */ +static int phonet_rcv(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pkttype, + struct net_device *orig_dev) +{ + struct net *net = dev_net(dev); + struct phonethdr *ph; + struct sockaddr_pn sa; + u16 len; + + skb = skb_share_check(skb, GFP_ATOMIC); + if (!skb) + return NET_RX_DROP; + + /* check we have at least a full Phonet header */ + if (!pskb_pull(skb, sizeof(struct phonethdr))) + goto out; + + /* check that the advertised length is correct */ + ph = pn_hdr(skb); + len = get_unaligned_be16(&ph->pn_length); + if (len < 2) + goto out; + len -= 2; + if ((len > skb->len) || pskb_trim(skb, len)) + goto out; + skb_reset_transport_header(skb); + + pn_skb_get_dst_sockaddr(skb, &sa); + + /* check if this is broadcasted */ + if (pn_sockaddr_get_addr(&sa) == PNADDR_BROADCAST) { + pn_deliver_sock_broadcast(net, skb); + goto out; + } + + /* resource routing */ + if (pn_sockaddr_get_object(&sa) == 0) { + struct sock *sk = pn_find_sock_by_res(net, sa.spn_resource); + if (sk) + return sk_receive_skb(sk, skb, 0); + } + + /* check if we are the destination */ + if (phonet_address_lookup(net, pn_sockaddr_get_addr(&sa)) == 0) { + /* Phonet packet input */ + struct sock *sk = pn_find_sock_by_sa(net, &sa); + + if (sk) + return sk_receive_skb(sk, skb, 0); + + if (can_respond(skb)) { + send_obj_unreachable(skb); + send_reset_indications(skb); + } + } else if (unlikely(skb->pkt_type == PACKET_LOOPBACK)) + goto out; /* Race between address deletion and loopback */ + else { + /* Phonet packet routing */ + struct net_device *out_dev; + + out_dev = phonet_route_output(net, pn_sockaddr_get_addr(&sa)); + if (!out_dev) { + net_dbg_ratelimited("No Phonet route to %02X\n", + pn_sockaddr_get_addr(&sa)); + goto out; + } + + __skb_push(skb, sizeof(struct phonethdr)); + skb->dev = out_dev; + if (out_dev == dev) { + net_dbg_ratelimited("Phonet loop to %02X on %s\n", + pn_sockaddr_get_addr(&sa), + dev->name); + goto out_dev; + } + /* Some drivers (e.g. 
TUN) do not allocate HW header space */ + if (skb_cow_head(skb, out_dev->hard_header_len)) + goto out_dev; + + if (dev_hard_header(skb, out_dev, ETH_P_PHONET, NULL, NULL, + skb->len) < 0) + goto out_dev; + dev_queue_xmit(skb); + dev_put(out_dev); + return NET_RX_SUCCESS; +out_dev: + dev_put(out_dev); + } + +out: + kfree_skb(skb); + return NET_RX_DROP; +} + +static struct packet_type phonet_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_PHONET), + .func = phonet_rcv, +}; + +static DEFINE_MUTEX(proto_tab_lock); + +int __init_or_module phonet_proto_register(unsigned int protocol, + const struct phonet_protocol *pp) +{ + int err = 0; + + if (protocol >= PHONET_NPROTO) + return -EINVAL; + + err = proto_register(pp->prot, 1); + if (err) + return err; + + mutex_lock(&proto_tab_lock); + if (proto_tab[protocol]) + err = -EBUSY; + else + rcu_assign_pointer(proto_tab[protocol], pp); + mutex_unlock(&proto_tab_lock); + + return err; +} +EXPORT_SYMBOL(phonet_proto_register); + +void phonet_proto_unregister(unsigned int protocol, + const struct phonet_protocol *pp) +{ + mutex_lock(&proto_tab_lock); + BUG_ON(proto_tab[protocol] != pp); + RCU_INIT_POINTER(proto_tab[protocol], NULL); + mutex_unlock(&proto_tab_lock); + synchronize_rcu(); + proto_unregister(pp->prot); +} +EXPORT_SYMBOL(phonet_proto_unregister); + +/* Module registration */ +static int __init phonet_init(void) +{ + int err; + + err = phonet_device_init(); + if (err) + return err; + + pn_sock_init(); + err = sock_register(&phonet_proto_family); + if (err) { + printk(KERN_ALERT + "phonet protocol family initialization failed\n"); + goto err_sock; + } + + dev_add_pack(&phonet_packet_type); + phonet_sysctl_init(); + + err = isi_register(); + if (err) + goto err; + return 0; + +err: + phonet_sysctl_exit(); + sock_unregister(PF_PHONET); + dev_remove_pack(&phonet_packet_type); +err_sock: + phonet_device_exit(); + return err; +} + +static void __exit phonet_exit(void) +{ + isi_unregister(); + phonet_sysctl_exit(); + sock_unregister(PF_PHONET); + dev_remove_pack(&phonet_packet_type); + phonet_device_exit(); +} + +module_init(phonet_init); +module_exit(phonet_exit); +MODULE_DESCRIPTION("Phonet protocol stack for Linux"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_PHONET); diff --git a/net/phonet/datagram.c b/net/phonet/datagram.c new file mode 100644 index 000000000..ff5f49ab2 --- /dev/null +++ b/net/phonet/datagram.c @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: datagram.c + * + * Datagram (ISI) Phonet sockets + * + * Copyright (C) 2008 Nokia Corporation. + * + * Authors: Sakari Ailus + * Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +static int pn_backlog_rcv(struct sock *sk, struct sk_buff *skb); + +/* associated socket ceases to exist */ +static void pn_sock_close(struct sock *sk, long timeout) +{ + sk_common_release(sk); +} + +static int pn_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct sk_buff *skb; + int answ; + + switch (cmd) { + case SIOCINQ: + lock_sock(sk); + skb = skb_peek(&sk->sk_receive_queue); + answ = skb ? skb->len : 0; + release_sock(sk); + return put_user(answ, (int __user *)arg); + + case SIOCPNADDRESOURCE: + case SIOCPNDELRESOURCE: { + u32 res; + if (get_user(res, (u32 __user *)arg)) + return -EFAULT; + if (res >= 256) + return -EINVAL; + if (cmd == SIOCPNADDRESOURCE) + return pn_sock_bind_res(sk, res); + else + return pn_sock_unbind_res(sk, res); + } + } + + return -ENOIOCTLCMD; +} + +/* Destroy socket. 
All references are gone. */ +static void pn_destruct(struct sock *sk) +{ + skb_queue_purge(&sk->sk_receive_queue); +} + +static int pn_init(struct sock *sk) +{ + sk->sk_destruct = pn_destruct; + return 0; +} + +static int pn_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_pn *, target, msg->msg_name); + struct sk_buff *skb; + int err; + + if (msg->msg_flags & ~(MSG_DONTWAIT|MSG_EOR|MSG_NOSIGNAL| + MSG_CMSG_COMPAT)) + return -EOPNOTSUPP; + + if (target == NULL) + return -EDESTADDRREQ; + + if (msg->msg_namelen < sizeof(struct sockaddr_pn)) + return -EINVAL; + + if (target->spn_family != AF_PHONET) + return -EAFNOSUPPORT; + + skb = sock_alloc_send_skb(sk, MAX_PHONET_HEADER + len, + msg->msg_flags & MSG_DONTWAIT, &err); + if (skb == NULL) + return err; + skb_reserve(skb, MAX_PHONET_HEADER); + + err = memcpy_from_msg((void *)skb_put(skb, len), msg, len); + if (err < 0) { + kfree_skb(skb); + return err; + } + + /* + * Fill in the Phonet header and + * finally pass the packet forwards. + */ + err = pn_skb_send(sk, skb, target); + + /* If ok, return len. */ + return (err >= 0) ? len : err; +} + +static int pn_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct sk_buff *skb = NULL; + struct sockaddr_pn sa; + int rval = -EOPNOTSUPP; + int copylen; + + if (flags & ~(MSG_PEEK|MSG_TRUNC|MSG_DONTWAIT|MSG_NOSIGNAL| + MSG_CMSG_COMPAT)) + goto out_nofree; + + skb = skb_recv_datagram(sk, flags, &rval); + if (skb == NULL) + goto out_nofree; + + pn_skb_get_src_sockaddr(skb, &sa); + + copylen = skb->len; + if (len < copylen) { + msg->msg_flags |= MSG_TRUNC; + copylen = len; + } + + rval = skb_copy_datagram_msg(skb, 0, msg, copylen); + if (rval) { + rval = -EFAULT; + goto out; + } + + rval = (flags & MSG_TRUNC) ? skb->len : copylen; + + if (msg->msg_name != NULL) { + __sockaddr_check_size(sizeof(sa)); + memcpy(msg->msg_name, &sa, sizeof(sa)); + *addr_len = sizeof(sa); + } + +out: + skb_free_datagram(sk, skb); + +out_nofree: + return rval; +} + +/* Queue an skb for a sock. */ +static int pn_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + int err = sock_queue_rcv_skb(sk, skb); + + if (err < 0) + kfree_skb(skb); + return err ? NET_RX_DROP : NET_RX_SUCCESS; +} + +/* Module registration */ +static struct proto pn_proto = { + .close = pn_sock_close, + .ioctl = pn_ioctl, + .init = pn_init, + .sendmsg = pn_sendmsg, + .recvmsg = pn_recvmsg, + .backlog_rcv = pn_backlog_rcv, + .hash = pn_sock_hash, + .unhash = pn_sock_unhash, + .get_port = pn_sock_get_port, + .obj_size = sizeof(struct pn_sock), + .owner = THIS_MODULE, + .name = "PHONET", +}; + +static const struct phonet_protocol pn_dgram_proto = { + .ops = &phonet_dgram_ops, + .prot = &pn_proto, + .sock_type = SOCK_DGRAM, +}; + +int __init isi_register(void) +{ + return phonet_proto_register(PN_PROTO_PHONET, &pn_dgram_proto); +} + +void __exit isi_unregister(void) +{ + phonet_proto_unregister(PN_PROTO_PHONET, &pn_dgram_proto); +} diff --git a/net/phonet/pep-gprs.c b/net/phonet/pep-gprs.c new file mode 100644 index 000000000..1f5df0432 --- /dev/null +++ b/net/phonet/pep-gprs.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: pep-gprs.c + * + * GPRS over Phonet pipe end point socket + * + * Copyright (C) 2008 Nokia Corporation. 
+ * + * Author: Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#define GPRS_DEFAULT_MTU 1400 + +struct gprs_dev { + struct sock *sk; + void (*old_state_change)(struct sock *); + void (*old_data_ready)(struct sock *); + void (*old_write_space)(struct sock *); + + struct net_device *dev; +}; + +static __be16 gprs_type_trans(struct sk_buff *skb) +{ + const u8 *pvfc; + u8 buf; + + pvfc = skb_header_pointer(skb, 0, 1, &buf); + if (!pvfc) + return htons(0); + /* Look at IP version field */ + switch (*pvfc >> 4) { + case 4: + return htons(ETH_P_IP); + case 6: + return htons(ETH_P_IPV6); + } + return htons(0); +} + +static void gprs_writeable(struct gprs_dev *gp) +{ + struct net_device *dev = gp->dev; + + if (pep_writeable(gp->sk)) + netif_wake_queue(dev); +} + +/* + * Socket callbacks + */ + +static void gprs_state_change(struct sock *sk) +{ + struct gprs_dev *gp = sk->sk_user_data; + + if (sk->sk_state == TCP_CLOSE_WAIT) { + struct net_device *dev = gp->dev; + + netif_stop_queue(dev); + netif_carrier_off(dev); + } +} + +static int gprs_recv(struct gprs_dev *gp, struct sk_buff *skb) +{ + struct net_device *dev = gp->dev; + int err = 0; + __be16 protocol = gprs_type_trans(skb); + + if (!protocol) { + err = -EINVAL; + goto drop; + } + + if (skb_headroom(skb) & 3) { + struct sk_buff *rskb, *fs; + int flen = 0; + + /* Phonet Pipe data header may be misaligned (3 bytes), + * so wrap the IP packet as a single fragment of an head-less + * socket buffer. The network stack will pull what it needs, + * but at least, the whole IP payload is not memcpy'd. */ + rskb = netdev_alloc_skb(dev, 0); + if (!rskb) { + err = -ENOBUFS; + goto drop; + } + skb_shinfo(rskb)->frag_list = skb; + rskb->len += skb->len; + rskb->data_len += rskb->len; + rskb->truesize += rskb->len; + + /* Avoid nested fragments */ + skb_walk_frags(skb, fs) + flen += fs->len; + skb->next = skb_shinfo(skb)->frag_list; + skb_frag_list_init(skb); + skb->len -= flen; + skb->data_len -= flen; + skb->truesize -= flen; + + skb = rskb; + } + + skb->protocol = protocol; + skb_reset_mac_header(skb); + skb->dev = dev; + + if (likely(dev->flags & IFF_UP)) { + dev->stats.rx_packets++; + dev->stats.rx_bytes += skb->len; + netif_rx(skb); + skb = NULL; + } else + err = -ENODEV; + +drop: + if (skb) { + dev_kfree_skb(skb); + dev->stats.rx_dropped++; + } + return err; +} + +static void gprs_data_ready(struct sock *sk) +{ + struct gprs_dev *gp = sk->sk_user_data; + struct sk_buff *skb; + + while ((skb = pep_read(sk)) != NULL) { + skb_orphan(skb); + gprs_recv(gp, skb); + } +} + +static void gprs_write_space(struct sock *sk) +{ + struct gprs_dev *gp = sk->sk_user_data; + + if (netif_running(gp->dev)) + gprs_writeable(gp); +} + +/* + * Network device callbacks + */ + +static int gprs_open(struct net_device *dev) +{ + struct gprs_dev *gp = netdev_priv(dev); + + gprs_writeable(gp); + return 0; +} + +static int gprs_close(struct net_device *dev) +{ + netif_stop_queue(dev); + return 0; +} + +static netdev_tx_t gprs_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct gprs_dev *gp = netdev_priv(dev); + struct sock *sk = gp->sk; + int len, err; + + switch (skb->protocol) { + case htons(ETH_P_IP): + case htons(ETH_P_IPV6): + break; + default: + dev_kfree_skb(skb); + return NETDEV_TX_OK; + } + + skb_orphan(skb); + skb_set_owner_w(skb, sk); + len = skb->len; + err = pep_write(sk, skb); + if (err) { + net_dbg_ratelimited("%s: TX error (%d)\n", dev->name, err); + dev->stats.tx_aborted_errors++; + 
dev->stats.tx_errors++; + } else { + dev->stats.tx_packets++; + dev->stats.tx_bytes += len; + } + + netif_stop_queue(dev); + if (pep_writeable(sk)) + netif_wake_queue(dev); + return NETDEV_TX_OK; +} + +static const struct net_device_ops gprs_netdev_ops = { + .ndo_open = gprs_open, + .ndo_stop = gprs_close, + .ndo_start_xmit = gprs_xmit, +}; + +static void gprs_setup(struct net_device *dev) +{ + dev->features = NETIF_F_FRAGLIST; + dev->type = ARPHRD_PHONET_PIPE; + dev->flags = IFF_POINTOPOINT | IFF_NOARP; + dev->mtu = GPRS_DEFAULT_MTU; + dev->min_mtu = 576; + dev->max_mtu = (PHONET_MAX_MTU - 11); + dev->hard_header_len = 0; + dev->addr_len = 0; + dev->tx_queue_len = 10; + + dev->netdev_ops = &gprs_netdev_ops; + dev->needs_free_netdev = true; +} + +/* + * External interface + */ + +/* + * Attach a GPRS interface to a datagram socket. + * Returns the interface index on success, negative error code on error. + */ +int gprs_attach(struct sock *sk) +{ + static const char ifname[] = "gprs%d"; + struct gprs_dev *gp; + struct net_device *dev; + int err; + + if (unlikely(sk->sk_type == SOCK_STREAM)) + return -EINVAL; /* need packet boundaries */ + + /* Create net device */ + dev = alloc_netdev(sizeof(*gp), ifname, NET_NAME_UNKNOWN, gprs_setup); + if (!dev) + return -ENOMEM; + gp = netdev_priv(dev); + gp->sk = sk; + gp->dev = dev; + + netif_stop_queue(dev); + err = register_netdev(dev); + if (err) { + free_netdev(dev); + return err; + } + + lock_sock(sk); + if (unlikely(sk->sk_user_data)) { + err = -EBUSY; + goto out_rel; + } + if (unlikely((1 << sk->sk_state & (TCPF_CLOSE|TCPF_LISTEN)) || + sock_flag(sk, SOCK_DEAD))) { + err = -EINVAL; + goto out_rel; + } + sk->sk_user_data = gp; + gp->old_state_change = sk->sk_state_change; + gp->old_data_ready = sk->sk_data_ready; + gp->old_write_space = sk->sk_write_space; + sk->sk_state_change = gprs_state_change; + sk->sk_data_ready = gprs_data_ready; + sk->sk_write_space = gprs_write_space; + release_sock(sk); + sock_hold(sk); + + printk(KERN_DEBUG"%s: attached\n", dev->name); + return dev->ifindex; + +out_rel: + release_sock(sk); + unregister_netdev(dev); + return err; +} + +void gprs_detach(struct sock *sk) +{ + struct gprs_dev *gp = sk->sk_user_data; + struct net_device *dev = gp->dev; + + lock_sock(sk); + sk->sk_user_data = NULL; + sk->sk_state_change = gp->old_state_change; + sk->sk_data_ready = gp->old_data_ready; + sk->sk_write_space = gp->old_write_space; + release_sock(sk); + + printk(KERN_DEBUG"%s: detached\n", dev->name); + unregister_netdev(dev); + sock_put(sk); +} diff --git a/net/phonet/pep.c b/net/phonet/pep.c new file mode 100644 index 000000000..83ea13a50 --- /dev/null +++ b/net/phonet/pep.c @@ -0,0 +1,1366 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: pep.c + * + * Phonet pipe protocol end point socket + * + * Copyright (C) 2008 Nokia Corporation. + * + * Author: Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +/* sk_state values: + * TCP_CLOSE sock not in use yet + * TCP_CLOSE_WAIT disconnected pipe + * TCP_LISTEN listening pipe endpoint + * TCP_SYN_RECV connected pipe in disabled state + * TCP_ESTABLISHED connected pipe in enabled state + * + * pep_sock locking: + * - sk_state, hlist: sock lock needed + * - listener: read only + * - pipe_handle: read only + */ + +#define CREDITS_MAX 10 +#define CREDITS_THR 7 + +#define pep_sb_size(s) (((s) + 5) & ~3) /* 2-bytes head, 32-bits aligned */ + +/* Get the next TLV sub-block. 
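+ * Each sub-block begins with a one-byte type and a one-byte length, where
+ * the length counts the whole sub-block including this two-byte header.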
*/ +static unsigned char *pep_get_sb(struct sk_buff *skb, u8 *ptype, u8 *plen, + void *buf) +{ + void *data = NULL; + struct { + u8 sb_type; + u8 sb_len; + } *ph, h; + int buflen = *plen; + + ph = skb_header_pointer(skb, 0, 2, &h); + if (ph == NULL || ph->sb_len < 2 || !pskb_may_pull(skb, ph->sb_len)) + return NULL; + ph->sb_len -= 2; + *ptype = ph->sb_type; + *plen = ph->sb_len; + + if (buflen > ph->sb_len) + buflen = ph->sb_len; + data = skb_header_pointer(skb, 2, buflen, buf); + __skb_pull(skb, 2 + ph->sb_len); + return data; +} + +static struct sk_buff *pep_alloc_skb(struct sock *sk, const void *payload, + int len, gfp_t priority) +{ + struct sk_buff *skb = alloc_skb(MAX_PNPIPE_HEADER + len, priority); + if (!skb) + return NULL; + skb_set_owner_w(skb, sk); + + skb_reserve(skb, MAX_PNPIPE_HEADER); + __skb_put(skb, len); + skb_copy_to_linear_data(skb, payload, len); + __skb_push(skb, sizeof(struct pnpipehdr)); + skb_reset_transport_header(skb); + return skb; +} + +static int pep_reply(struct sock *sk, struct sk_buff *oskb, u8 code, + const void *data, int len, gfp_t priority) +{ + const struct pnpipehdr *oph = pnp_hdr(oskb); + struct pnpipehdr *ph; + struct sk_buff *skb; + struct sockaddr_pn peer; + + skb = pep_alloc_skb(sk, data, len, priority); + if (!skb) + return -ENOMEM; + + ph = pnp_hdr(skb); + ph->utid = oph->utid; + ph->message_id = oph->message_id + 1; /* REQ -> RESP */ + ph->pipe_handle = oph->pipe_handle; + ph->error_code = code; + + pn_skb_get_src_sockaddr(oskb, &peer); + return pn_skb_send(sk, skb, &peer); +} + +static int pep_indicate(struct sock *sk, u8 id, u8 code, + const void *data, int len, gfp_t priority) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *ph; + struct sk_buff *skb; + + skb = pep_alloc_skb(sk, data, len, priority); + if (!skb) + return -ENOMEM; + + ph = pnp_hdr(skb); + ph->utid = 0; + ph->message_id = id; + ph->pipe_handle = pn->pipe_handle; + ph->error_code = code; + return pn_skb_send(sk, skb, NULL); +} + +#define PAD 0x00 + +static int pipe_handler_request(struct sock *sk, u8 id, u8 code, + const void *data, int len) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *ph; + struct sk_buff *skb; + + skb = pep_alloc_skb(sk, data, len, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + ph = pnp_hdr(skb); + ph->utid = id; /* whatever */ + ph->message_id = id; + ph->pipe_handle = pn->pipe_handle; + ph->error_code = code; + return pn_skb_send(sk, skb, NULL); +} + +static int pipe_handler_send_created_ind(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + u8 data[4] = { + PN_PIPE_SB_NEGOTIATED_FC, pep_sb_size(2), + pn->tx_fc, pn->rx_fc, + }; + + return pep_indicate(sk, PNS_PIPE_CREATED_IND, 1 /* sub-blocks */, + data, 4, GFP_ATOMIC); +} + +static int pep_accept_conn(struct sock *sk, struct sk_buff *skb) +{ + static const u8 data[20] = { + PAD, PAD, PAD, 2 /* sub-blocks */, + PN_PIPE_SB_REQUIRED_FC_TX, pep_sb_size(5), 3, PAD, + PN_MULTI_CREDIT_FLOW_CONTROL, + PN_ONE_CREDIT_FLOW_CONTROL, + PN_LEGACY_FLOW_CONTROL, + PAD, + PN_PIPE_SB_PREFERRED_FC_RX, pep_sb_size(5), 3, PAD, + PN_MULTI_CREDIT_FLOW_CONTROL, + PN_ONE_CREDIT_FLOW_CONTROL, + PN_LEGACY_FLOW_CONTROL, + PAD, + }; + + might_sleep(); + return pep_reply(sk, skb, PN_PIPE_NO_ERROR, data, sizeof(data), + GFP_KERNEL); +} + +static int pep_reject_conn(struct sock *sk, struct sk_buff *skb, u8 code, + gfp_t priority) +{ + static const u8 data[4] = { PAD, PAD, PAD, 0 /* sub-blocks */ }; + WARN_ON(code == PN_PIPE_NO_ERROR); + return pep_reply(sk, skb, code, data, sizeof(data), priority); +} + 
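/*
 * An illustrative userspace-style sketch, not part of the upstream file: it
 * walks the same one-byte-type / one-byte-length sub-block layout that
 * pep_get_sb() parses above, but over a plain contiguous buffer instead of
 * an skb -- the layout used for the sub-blocks that follow the 4-byte
 * preamble of the reply built in pep_accept_conn(). The function name and
 * the example buffer contents are hypothetical.
 */
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

static void walk_pep_sub_blocks(const uint8_t *buf, size_t len)
{
	size_t off = 0;

	while (off + 2 <= len) {
		uint8_t type = buf[off];
		uint8_t sb_len = buf[off + 1];	/* counts the 2-byte header too */

		if (sb_len < 2 || off + sb_len > len)
			break;			/* malformed or truncated sub-block */
		printf("sub-block type 0x%02x, %u payload byte(s)\n",
		       (unsigned int)type, (unsigned int)(sb_len - 2));
		off += sb_len;
	}
}

int main(void)
{
	/* Two made-up sub-blocks, 8 bytes each (2-byte header + 6 bytes). */
	static const uint8_t example[] = {
		0x01, 8, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff,
		0x02, 8, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66,
	};

	walk_pep_sub_blocks(example, sizeof(example));
	return 0;
}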
+/* Control requests are not sent by the pipe service and have a specific + * message format. */ +static int pep_ctrlreq_error(struct sock *sk, struct sk_buff *oskb, u8 code, + gfp_t priority) +{ + const struct pnpipehdr *oph = pnp_hdr(oskb); + struct sk_buff *skb; + struct pnpipehdr *ph; + struct sockaddr_pn dst; + u8 data[4] = { + oph->pep_type, /* PEP type */ + code, /* error code, at an unusual offset */ + PAD, PAD, + }; + + skb = pep_alloc_skb(sk, data, 4, priority); + if (!skb) + return -ENOMEM; + + ph = pnp_hdr(skb); + ph->utid = oph->utid; + ph->message_id = PNS_PEP_CTRL_RESP; + ph->pipe_handle = oph->pipe_handle; + ph->data0 = oph->data[0]; /* CTRL id */ + + pn_skb_get_src_sockaddr(oskb, &dst); + return pn_skb_send(sk, skb, &dst); +} + +static int pipe_snd_status(struct sock *sk, u8 type, u8 status, gfp_t priority) +{ + u8 data[4] = { type, PAD, PAD, status }; + + return pep_indicate(sk, PNS_PEP_STATUS_IND, PN_PEP_TYPE_COMMON, + data, 4, priority); +} + +/* Send our RX flow control information to the sender. + * Socket must be locked. */ +static void pipe_grant_credits(struct sock *sk, gfp_t priority) +{ + struct pep_sock *pn = pep_sk(sk); + + BUG_ON(sk->sk_state != TCP_ESTABLISHED); + + switch (pn->rx_fc) { + case PN_LEGACY_FLOW_CONTROL: /* TODO */ + break; + case PN_ONE_CREDIT_FLOW_CONTROL: + if (pipe_snd_status(sk, PN_PEP_IND_FLOW_CONTROL, + PEP_IND_READY, priority) == 0) + pn->rx_credits = 1; + break; + case PN_MULTI_CREDIT_FLOW_CONTROL: + if ((pn->rx_credits + CREDITS_THR) > CREDITS_MAX) + break; + if (pipe_snd_status(sk, PN_PEP_IND_ID_MCFC_GRANT_CREDITS, + CREDITS_MAX - pn->rx_credits, + priority) == 0) + pn->rx_credits = CREDITS_MAX; + break; + } +} + +static int pipe_rcv_status(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *hdr; + int wake = 0; + + if (!pskb_may_pull(skb, sizeof(*hdr) + 4)) + return -EINVAL; + + hdr = pnp_hdr(skb); + if (hdr->pep_type != PN_PEP_TYPE_COMMON) { + net_dbg_ratelimited("Phonet unknown PEP type: %u\n", + (unsigned int)hdr->pep_type); + return -EOPNOTSUPP; + } + + switch (hdr->data[0]) { + case PN_PEP_IND_FLOW_CONTROL: + switch (pn->tx_fc) { + case PN_LEGACY_FLOW_CONTROL: + switch (hdr->data[3]) { + case PEP_IND_BUSY: + atomic_set(&pn->tx_credits, 0); + break; + case PEP_IND_READY: + atomic_set(&pn->tx_credits, wake = 1); + break; + } + break; + case PN_ONE_CREDIT_FLOW_CONTROL: + if (hdr->data[3] == PEP_IND_READY) + atomic_set(&pn->tx_credits, wake = 1); + break; + } + break; + + case PN_PEP_IND_ID_MCFC_GRANT_CREDITS: + if (pn->tx_fc != PN_MULTI_CREDIT_FLOW_CONTROL) + break; + atomic_add(wake = hdr->data[3], &pn->tx_credits); + break; + + default: + net_dbg_ratelimited("Phonet unknown PEP indication: %u\n", + (unsigned int)hdr->data[0]); + return -EOPNOTSUPP; + } + if (wake) + sk->sk_write_space(sk); + return 0; +} + +static int pipe_rcv_created(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *hdr = pnp_hdr(skb); + u8 n_sb = hdr->data0; + + pn->rx_fc = pn->tx_fc = PN_LEGACY_FLOW_CONTROL; + __skb_pull(skb, sizeof(*hdr)); + while (n_sb > 0) { + u8 type, buf[2], len = sizeof(buf); + u8 *data = pep_get_sb(skb, &type, &len, buf); + + if (data == NULL) + return -EINVAL; + switch (type) { + case PN_PIPE_SB_NEGOTIATED_FC: + if (len < 2 || (data[0] | data[1]) > 3) + break; + pn->tx_fc = data[0] & 3; + pn->rx_fc = data[1] & 3; + break; + } + n_sb--; + } + return 0; +} + +/* Queue an skb to a connected sock. + * Socket lock must be held. 
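+ * Besides queueing pipe data (subject to the negotiated RX flow control and
+ * credits), this answers the peer's disconnect/enable/reset/disable requests
+ * and rejects further connect attempts on an already connected pipe.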
*/ +static int pipe_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *hdr = pnp_hdr(skb); + struct sk_buff_head *queue; + int err = 0; + + BUG_ON(sk->sk_state == TCP_CLOSE_WAIT); + + switch (hdr->message_id) { + case PNS_PEP_CONNECT_REQ: + pep_reject_conn(sk, skb, PN_PIPE_ERR_PEP_IN_USE, GFP_ATOMIC); + break; + + case PNS_PEP_DISCONNECT_REQ: + pep_reply(sk, skb, PN_PIPE_NO_ERROR, NULL, 0, GFP_ATOMIC); + sk->sk_state = TCP_CLOSE_WAIT; + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + break; + + case PNS_PEP_ENABLE_REQ: + /* Wait for PNS_PIPE_(ENABLED|REDIRECTED)_IND */ + pep_reply(sk, skb, PN_PIPE_NO_ERROR, NULL, 0, GFP_ATOMIC); + break; + + case PNS_PEP_RESET_REQ: + switch (hdr->state_after_reset) { + case PN_PIPE_DISABLE: + pn->init_enable = 0; + break; + case PN_PIPE_ENABLE: + pn->init_enable = 1; + break; + default: /* not allowed to send an error here!? */ + err = -EINVAL; + goto out; + } + fallthrough; + case PNS_PEP_DISABLE_REQ: + atomic_set(&pn->tx_credits, 0); + pep_reply(sk, skb, PN_PIPE_NO_ERROR, NULL, 0, GFP_ATOMIC); + break; + + case PNS_PEP_CTRL_REQ: + if (skb_queue_len(&pn->ctrlreq_queue) >= PNPIPE_CTRLREQ_MAX) { + atomic_inc(&sk->sk_drops); + break; + } + __skb_pull(skb, 4); + queue = &pn->ctrlreq_queue; + goto queue; + + case PNS_PIPE_ALIGNED_DATA: + __skb_pull(skb, 1); + fallthrough; + case PNS_PIPE_DATA: + __skb_pull(skb, 3); /* Pipe data header */ + if (!pn_flow_safe(pn->rx_fc)) { + err = sock_queue_rcv_skb(sk, skb); + if (!err) + return NET_RX_SUCCESS; + err = -ENOBUFS; + break; + } + + if (pn->rx_credits == 0) { + atomic_inc(&sk->sk_drops); + err = -ENOBUFS; + break; + } + pn->rx_credits--; + queue = &sk->sk_receive_queue; + goto queue; + + case PNS_PEP_STATUS_IND: + pipe_rcv_status(sk, skb); + break; + + case PNS_PIPE_REDIRECTED_IND: + err = pipe_rcv_created(sk, skb); + break; + + case PNS_PIPE_CREATED_IND: + err = pipe_rcv_created(sk, skb); + if (err) + break; + fallthrough; + case PNS_PIPE_RESET_IND: + if (!pn->init_enable) + break; + fallthrough; + case PNS_PIPE_ENABLED_IND: + if (!pn_flow_safe(pn->tx_fc)) { + atomic_set(&pn->tx_credits, 1); + sk->sk_write_space(sk); + } + if (sk->sk_state == TCP_ESTABLISHED) + break; /* Nothing to do */ + sk->sk_state = TCP_ESTABLISHED; + pipe_grant_credits(sk, GFP_ATOMIC); + break; + + case PNS_PIPE_DISABLED_IND: + sk->sk_state = TCP_SYN_RECV; + pn->rx_credits = 0; + break; + + default: + net_dbg_ratelimited("Phonet unknown PEP message: %u\n", + hdr->message_id); + err = -EINVAL; + } +out: + kfree_skb(skb); + return (err == -ENOBUFS) ? NET_RX_DROP : NET_RX_SUCCESS; + +queue: + skb->dev = NULL; + skb_set_owner_r(skb, sk); + skb_queue_tail(queue, skb); + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + return NET_RX_SUCCESS; +} + +/* Destroy connected sock. 
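+ * Any skbs still queued on the receive and control-request queues are
+ * dropped.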
*/ +static void pipe_destruct(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + + skb_queue_purge(&sk->sk_receive_queue); + skb_queue_purge(&pn->ctrlreq_queue); +} + +static u8 pipe_negotiate_fc(const u8 *fcs, unsigned int n) +{ + unsigned int i; + u8 final_fc = PN_NO_FLOW_CONTROL; + + for (i = 0; i < n; i++) { + u8 fc = fcs[i]; + + if (fc > final_fc && fc < PN_MAX_FLOW_CONTROL) + final_fc = fc; + } + return final_fc; +} + +static int pep_connresp_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *hdr; + u8 n_sb; + + if (!pskb_pull(skb, sizeof(*hdr) + 4)) + return -EINVAL; + + hdr = pnp_hdr(skb); + if (hdr->error_code != PN_PIPE_NO_ERROR) + return -ECONNREFUSED; + + /* Parse sub-blocks */ + n_sb = hdr->data[3]; + while (n_sb > 0) { + u8 type, buf[6], len = sizeof(buf); + const u8 *data = pep_get_sb(skb, &type, &len, buf); + + if (data == NULL) + return -EINVAL; + + switch (type) { + case PN_PIPE_SB_REQUIRED_FC_TX: + if (len < 2 || len < data[0]) + break; + pn->tx_fc = pipe_negotiate_fc(data + 2, len - 2); + break; + + case PN_PIPE_SB_PREFERRED_FC_RX: + if (len < 2 || len < data[0]) + break; + pn->rx_fc = pipe_negotiate_fc(data + 2, len - 2); + break; + + } + n_sb--; + } + + return pipe_handler_send_created_ind(sk); +} + +static int pep_enableresp_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct pnpipehdr *hdr = pnp_hdr(skb); + + if (hdr->error_code != PN_PIPE_NO_ERROR) + return -ECONNREFUSED; + + return pep_indicate(sk, PNS_PIPE_ENABLED_IND, 0 /* sub-blocks */, + NULL, 0, GFP_ATOMIC); + +} + +static void pipe_start_flow_control(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + + if (!pn_flow_safe(pn->tx_fc)) { + atomic_set(&pn->tx_credits, 1); + sk->sk_write_space(sk); + } + pipe_grant_credits(sk, GFP_ATOMIC); +} + +/* Queue an skb to an actively connected sock. + * Socket lock must be held. 
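+ * Handles pipe data plus the CONNECT/ENABLE/DISCONNECT responses and
+ * status indications on the actively connecting side; data queueing
+ * obeys the same receive credit accounting as on the passive side.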
*/ +static int pipe_handler_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *hdr = pnp_hdr(skb); + int err = NET_RX_SUCCESS; + + switch (hdr->message_id) { + case PNS_PIPE_ALIGNED_DATA: + __skb_pull(skb, 1); + fallthrough; + case PNS_PIPE_DATA: + __skb_pull(skb, 3); /* Pipe data header */ + if (!pn_flow_safe(pn->rx_fc)) { + err = sock_queue_rcv_skb(sk, skb); + if (!err) + return NET_RX_SUCCESS; + err = NET_RX_DROP; + break; + } + + if (pn->rx_credits == 0) { + atomic_inc(&sk->sk_drops); + err = NET_RX_DROP; + break; + } + pn->rx_credits--; + skb->dev = NULL; + skb_set_owner_r(skb, sk); + skb_queue_tail(&sk->sk_receive_queue, skb); + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + return NET_RX_SUCCESS; + + case PNS_PEP_CONNECT_RESP: + if (sk->sk_state != TCP_SYN_SENT) + break; + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + if (pep_connresp_rcv(sk, skb)) { + sk->sk_state = TCP_CLOSE_WAIT; + break; + } + if (pn->init_enable == PN_PIPE_DISABLE) + sk->sk_state = TCP_SYN_RECV; + else { + sk->sk_state = TCP_ESTABLISHED; + pipe_start_flow_control(sk); + } + break; + + case PNS_PEP_ENABLE_RESP: + if (sk->sk_state != TCP_SYN_SENT) + break; + + if (pep_enableresp_rcv(sk, skb)) { + sk->sk_state = TCP_CLOSE_WAIT; + break; + } + + sk->sk_state = TCP_ESTABLISHED; + pipe_start_flow_control(sk); + break; + + case PNS_PEP_DISCONNECT_RESP: + /* sock should already be dead, nothing to do */ + break; + + case PNS_PEP_STATUS_IND: + pipe_rcv_status(sk, skb); + break; + } + kfree_skb(skb); + return err; +} + +/* Listening sock must be locked */ +static struct sock *pep_find_pipe(const struct hlist_head *hlist, + const struct sockaddr_pn *dst, + u8 pipe_handle) +{ + struct sock *sknode; + u16 dobj = pn_sockaddr_get_object(dst); + + sk_for_each(sknode, hlist) { + struct pep_sock *pnnode = pep_sk(sknode); + + /* Ports match, but addresses might not: */ + if (pnnode->pn_sk.sobject != dobj) + continue; + if (pnnode->pipe_handle != pipe_handle) + continue; + if (sknode->sk_state == TCP_CLOSE_WAIT) + continue; + + sock_hold(sknode); + return sknode; + } + return NULL; +} + +/* + * Deliver an skb to a listening sock. + * Socket lock must be held. + * We then queue the skb to the right connected sock (if any). + */ +static int pep_do_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct sock *sknode; + struct pnpipehdr *hdr; + struct sockaddr_pn dst; + u8 pipe_handle; + + if (!pskb_may_pull(skb, sizeof(*hdr))) + goto drop; + + hdr = pnp_hdr(skb); + pipe_handle = hdr->pipe_handle; + if (pipe_handle == PN_PIPE_INVALID_HANDLE) + goto drop; + + pn_skb_get_dst_sockaddr(skb, &dst); + + /* Look for an existing pipe handle */ + sknode = pep_find_pipe(&pn->hlist, &dst, pipe_handle); + if (sknode) + return sk_receive_skb(sknode, skb, 1); + + switch (hdr->message_id) { + case PNS_PEP_CONNECT_REQ: + if (sk->sk_state != TCP_LISTEN || sk_acceptq_is_full(sk)) { + pep_reject_conn(sk, skb, PN_PIPE_ERR_PEP_IN_USE, + GFP_ATOMIC); + break; + } + skb_queue_head(&sk->sk_receive_queue, skb); + sk_acceptq_added(sk); + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + return NET_RX_SUCCESS; + + case PNS_PEP_DISCONNECT_REQ: + pep_reply(sk, skb, PN_PIPE_NO_ERROR, NULL, 0, GFP_ATOMIC); + break; + + case PNS_PEP_CTRL_REQ: + pep_ctrlreq_error(sk, skb, PN_PIPE_INVALID_HANDLE, GFP_ATOMIC); + break; + + case PNS_PEP_RESET_REQ: + case PNS_PEP_ENABLE_REQ: + case PNS_PEP_DISABLE_REQ: + /* invalid handle is not even allowed here! 
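+ * These requests need an existing pipe, so with no matching
+ * handle the message is silently discarded below.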
*/ + break; + + default: + if ((1 << sk->sk_state) + & ~(TCPF_CLOSE|TCPF_LISTEN|TCPF_CLOSE_WAIT)) + /* actively connected socket */ + return pipe_handler_do_rcv(sk, skb); + } +drop: + kfree_skb(skb); + return NET_RX_SUCCESS; +} + +static int pipe_do_remove(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *ph; + struct sk_buff *skb; + + skb = pep_alloc_skb(sk, NULL, 0, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + ph = pnp_hdr(skb); + ph->utid = 0; + ph->message_id = PNS_PIPE_REMOVE_REQ; + ph->pipe_handle = pn->pipe_handle; + ph->data0 = PAD; + return pn_skb_send(sk, skb, NULL); +} + +/* associated socket ceases to exist */ +static void pep_sock_close(struct sock *sk, long timeout) +{ + struct pep_sock *pn = pep_sk(sk); + int ifindex = 0; + + sock_hold(sk); /* keep a reference after sk_common_release() */ + sk_common_release(sk); + + lock_sock(sk); + if ((1 << sk->sk_state) & (TCPF_SYN_RECV|TCPF_ESTABLISHED)) { + if (sk->sk_backlog_rcv == pipe_do_rcv) + /* Forcefully remove dangling Phonet pipe */ + pipe_do_remove(sk); + else + pipe_handler_request(sk, PNS_PEP_DISCONNECT_REQ, PAD, + NULL, 0); + } + sk->sk_state = TCP_CLOSE; + + ifindex = pn->ifindex; + pn->ifindex = 0; + release_sock(sk); + + if (ifindex) + gprs_detach(sk); + sock_put(sk); +} + +static struct sock *pep_sock_accept(struct sock *sk, int flags, int *errp, + bool kern) +{ + struct pep_sock *pn = pep_sk(sk), *newpn; + struct sock *newsk = NULL; + struct sk_buff *skb; + struct pnpipehdr *hdr; + struct sockaddr_pn dst, src; + int err; + u16 peer_type; + u8 pipe_handle, enabled, n_sb; + u8 aligned = 0; + + skb = skb_recv_datagram(sk, (flags & O_NONBLOCK) ? MSG_DONTWAIT : 0, + errp); + if (!skb) + return NULL; + + lock_sock(sk); + if (sk->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto drop; + } + sk_acceptq_removed(sk); + + err = -EPROTO; + if (!pskb_may_pull(skb, sizeof(*hdr) + 4)) + goto drop; + + hdr = pnp_hdr(skb); + pipe_handle = hdr->pipe_handle; + switch (hdr->state_after_connect) { + case PN_PIPE_DISABLE: + enabled = 0; + break; + case PN_PIPE_ENABLE: + enabled = 1; + break; + default: + pep_reject_conn(sk, skb, PN_PIPE_ERR_INVALID_PARAM, + GFP_KERNEL); + goto drop; + } + peer_type = hdr->other_pep_type << 8; + + /* Parse sub-blocks (options) */ + n_sb = hdr->data[3]; + while (n_sb > 0) { + u8 type, buf[1], len = sizeof(buf); + const u8 *data = pep_get_sb(skb, &type, &len, buf); + + if (data == NULL) + goto drop; + switch (type) { + case PN_PIPE_SB_CONNECT_REQ_PEP_SUB_TYPE: + if (len < 1) + goto drop; + peer_type = (peer_type & 0xff00) | data[0]; + break; + case PN_PIPE_SB_ALIGNED_DATA: + aligned = data[0] != 0; + break; + } + n_sb--; + } + + /* Check for duplicate pipe handle */ + newsk = pep_find_pipe(&pn->hlist, &dst, pipe_handle); + if (unlikely(newsk)) { + __sock_put(newsk); + newsk = NULL; + pep_reject_conn(sk, skb, PN_PIPE_ERR_PEP_IN_USE, GFP_KERNEL); + goto drop; + } + + /* Create a new to-be-accepted sock */ + newsk = sk_alloc(sock_net(sk), PF_PHONET, GFP_KERNEL, sk->sk_prot, + kern); + if (!newsk) { + pep_reject_conn(sk, skb, PN_PIPE_ERR_OVERLOAD, GFP_KERNEL); + err = -ENOBUFS; + goto drop; + } + + sock_init_data(NULL, newsk); + newsk->sk_state = TCP_SYN_RECV; + newsk->sk_backlog_rcv = pipe_do_rcv; + newsk->sk_protocol = sk->sk_protocol; + newsk->sk_destruct = pipe_destruct; + + newpn = pep_sk(newsk); + pn_skb_get_dst_sockaddr(skb, &dst); + pn_skb_get_src_sockaddr(skb, &src); + newpn->pn_sk.sobject = pn_sockaddr_get_object(&dst); + newpn->pn_sk.dobject = pn_sockaddr_get_object(&src); + 
newpn->pn_sk.resource = pn_sockaddr_get_resource(&dst); + sock_hold(sk); + newpn->listener = sk; + skb_queue_head_init(&newpn->ctrlreq_queue); + newpn->pipe_handle = pipe_handle; + atomic_set(&newpn->tx_credits, 0); + newpn->ifindex = 0; + newpn->peer_type = peer_type; + newpn->rx_credits = 0; + newpn->rx_fc = newpn->tx_fc = PN_LEGACY_FLOW_CONTROL; + newpn->init_enable = enabled; + newpn->aligned = aligned; + + err = pep_accept_conn(newsk, skb); + if (err) { + __sock_put(sk); + sock_put(newsk); + newsk = NULL; + goto drop; + } + sk_add_node(newsk, &pn->hlist); +drop: + release_sock(sk); + kfree_skb(skb); + *errp = err; + return newsk; +} + +static int pep_sock_connect(struct sock *sk, struct sockaddr *addr, int len) +{ + struct pep_sock *pn = pep_sk(sk); + int err; + u8 data[4] = { 0 /* sub-blocks */, PAD, PAD, PAD }; + + if (pn->pipe_handle == PN_PIPE_INVALID_HANDLE) + pn->pipe_handle = 1; /* anything but INVALID_HANDLE */ + + err = pipe_handler_request(sk, PNS_PEP_CONNECT_REQ, + pn->init_enable, data, 4); + if (err) { + pn->pipe_handle = PN_PIPE_INVALID_HANDLE; + return err; + } + + sk->sk_state = TCP_SYN_SENT; + + return 0; +} + +static int pep_sock_enable(struct sock *sk, struct sockaddr *addr, int len) +{ + int err; + + err = pipe_handler_request(sk, PNS_PEP_ENABLE_REQ, PAD, + NULL, 0); + if (err) + return err; + + sk->sk_state = TCP_SYN_SENT; + + return 0; +} + +static int pep_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + struct pep_sock *pn = pep_sk(sk); + int answ; + int ret = -ENOIOCTLCMD; + + switch (cmd) { + case SIOCINQ: + if (sk->sk_state == TCP_LISTEN) { + ret = -EINVAL; + break; + } + + lock_sock(sk); + if (sock_flag(sk, SOCK_URGINLINE) && + !skb_queue_empty(&pn->ctrlreq_queue)) + answ = skb_peek(&pn->ctrlreq_queue)->len; + else if (!skb_queue_empty(&sk->sk_receive_queue)) + answ = skb_peek(&sk->sk_receive_queue)->len; + else + answ = 0; + release_sock(sk); + ret = put_user(answ, (int __user *)arg); + break; + + case SIOCPNENABLEPIPE: + lock_sock(sk); + if (sk->sk_state == TCP_SYN_SENT) + ret = -EBUSY; + else if (sk->sk_state == TCP_ESTABLISHED) + ret = -EISCONN; + else if (!pn->pn_sk.sobject) + ret = -EADDRNOTAVAIL; + else + ret = pep_sock_enable(sk, NULL, 0); + release_sock(sk); + break; + } + + return ret; +} + +static int pep_init(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + + sk->sk_destruct = pipe_destruct; + INIT_HLIST_HEAD(&pn->hlist); + pn->listener = NULL; + skb_queue_head_init(&pn->ctrlreq_queue); + atomic_set(&pn->tx_credits, 0); + pn->ifindex = 0; + pn->peer_type = 0; + pn->pipe_handle = PN_PIPE_INVALID_HANDLE; + pn->rx_credits = 0; + pn->rx_fc = pn->tx_fc = PN_LEGACY_FLOW_CONTROL; + pn->init_enable = 1; + pn->aligned = 0; + return 0; +} + +static int pep_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct pep_sock *pn = pep_sk(sk); + int val = 0, err = 0; + + if (level != SOL_PNPIPE) + return -ENOPROTOOPT; + if (optlen >= sizeof(int)) { + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + } + + lock_sock(sk); + switch (optname) { + case PNPIPE_ENCAP: + if (val && val != PNPIPE_ENCAP_IP) { + err = -EINVAL; + break; + } + if (!pn->ifindex == !val) + break; /* Nothing to do! 
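+ * (the requested encapsulation state already matches the current one)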
*/ + if (!capable(CAP_NET_ADMIN)) { + err = -EPERM; + break; + } + if (val) { + release_sock(sk); + err = gprs_attach(sk); + if (err > 0) { + pn->ifindex = err; + err = 0; + } + } else { + pn->ifindex = 0; + release_sock(sk); + gprs_detach(sk); + err = 0; + } + goto out_norel; + + case PNPIPE_HANDLE: + if ((sk->sk_state == TCP_CLOSE) && + (val >= 0) && (val < PN_PIPE_INVALID_HANDLE)) + pn->pipe_handle = val; + else + err = -EINVAL; + break; + + case PNPIPE_INITSTATE: + pn->init_enable = !!val; + break; + + default: + err = -ENOPROTOOPT; + } + release_sock(sk); + +out_norel: + return err; +} + +static int pep_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct pep_sock *pn = pep_sk(sk); + int len, val; + + if (level != SOL_PNPIPE) + return -ENOPROTOOPT; + if (get_user(len, optlen)) + return -EFAULT; + + switch (optname) { + case PNPIPE_ENCAP: + val = pn->ifindex ? PNPIPE_ENCAP_IP : PNPIPE_ENCAP_NONE; + break; + + case PNPIPE_IFINDEX: + val = pn->ifindex; + break; + + case PNPIPE_HANDLE: + val = pn->pipe_handle; + if (val == PN_PIPE_INVALID_HANDLE) + return -EINVAL; + break; + + case PNPIPE_INITSTATE: + val = pn->init_enable; + break; + + default: + return -ENOPROTOOPT; + } + + len = min_t(unsigned int, sizeof(int), len); + if (put_user(len, optlen)) + return -EFAULT; + if (put_user(val, (int __user *) optval)) + return -EFAULT; + return 0; +} + +static int pipe_skb_send(struct sock *sk, struct sk_buff *skb) +{ + struct pep_sock *pn = pep_sk(sk); + struct pnpipehdr *ph; + int err; + + if (pn_flow_safe(pn->tx_fc) && + !atomic_add_unless(&pn->tx_credits, -1, 0)) { + kfree_skb(skb); + return -ENOBUFS; + } + + skb_push(skb, 3 + pn->aligned); + skb_reset_transport_header(skb); + ph = pnp_hdr(skb); + ph->utid = 0; + if (pn->aligned) { + ph->message_id = PNS_PIPE_ALIGNED_DATA; + ph->data0 = 0; /* padding */ + } else + ph->message_id = PNS_PIPE_DATA; + ph->pipe_handle = pn->pipe_handle; + err = pn_skb_send(sk, skb, NULL); + + if (err && pn_flow_safe(pn->tx_fc)) + atomic_inc(&pn->tx_credits); + return err; + +} + +static int pep_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) +{ + struct pep_sock *pn = pep_sk(sk); + struct sk_buff *skb; + long timeo; + int flags = msg->msg_flags; + int err, done; + + if (len > USHRT_MAX) + return -EMSGSIZE; + + if ((msg->msg_flags & ~(MSG_DONTWAIT|MSG_EOR|MSG_NOSIGNAL| + MSG_CMSG_COMPAT)) || + !(msg->msg_flags & MSG_EOR)) + return -EOPNOTSUPP; + + skb = sock_alloc_send_skb(sk, MAX_PNPIPE_HEADER + len, + flags & MSG_DONTWAIT, &err); + if (!skb) + return err; + + skb_reserve(skb, MAX_PHONET_HEADER + 3 + pn->aligned); + err = memcpy_from_msg(skb_put(skb, len), msg, len); + if (err < 0) + goto outfree; + + lock_sock(sk); + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + if ((1 << sk->sk_state) & (TCPF_LISTEN|TCPF_CLOSE)) { + err = -ENOTCONN; + goto out; + } + if (sk->sk_state != TCP_ESTABLISHED) { + /* Wait until the pipe gets to enabled state */ +disabled: + err = sk_stream_wait_connect(sk, &timeo); + if (err) + goto out; + + if (sk->sk_state == TCP_CLOSE_WAIT) { + err = -ECONNRESET; + goto out; + } + } + BUG_ON(sk->sk_state != TCP_ESTABLISHED); + + /* Wait until flow control allows TX */ + done = atomic_read(&pn->tx_credits); + while (!done) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + if (!timeo) { + err = -EAGAIN; + goto out; + } + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + goto out; + } + + add_wait_queue(sk_sleep(sk), &wait); + done = sk_wait_event(sk, &timeo, 
atomic_read(&pn->tx_credits), &wait); + remove_wait_queue(sk_sleep(sk), &wait); + + if (sk->sk_state != TCP_ESTABLISHED) + goto disabled; + } + + err = pipe_skb_send(sk, skb); + if (err >= 0) + err = len; /* success! */ + skb = NULL; +out: + release_sock(sk); +outfree: + kfree_skb(skb); + return err; +} + +int pep_writeable(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + + return atomic_read(&pn->tx_credits); +} + +int pep_write(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *rskb, *fs; + int flen = 0; + + if (pep_sk(sk)->aligned) + return pipe_skb_send(sk, skb); + + rskb = alloc_skb(MAX_PNPIPE_HEADER, GFP_ATOMIC); + if (!rskb) { + kfree_skb(skb); + return -ENOMEM; + } + skb_shinfo(rskb)->frag_list = skb; + rskb->len += skb->len; + rskb->data_len += rskb->len; + rskb->truesize += rskb->len; + + /* Avoid nested fragments */ + skb_walk_frags(skb, fs) + flen += fs->len; + skb->next = skb_shinfo(skb)->frag_list; + skb_frag_list_init(skb); + skb->len -= flen; + skb->data_len -= flen; + skb->truesize -= flen; + + skb_reserve(rskb, MAX_PHONET_HEADER + 3); + return pipe_skb_send(sk, rskb); +} + +struct sk_buff *pep_read(struct sock *sk) +{ + struct sk_buff *skb = skb_dequeue(&sk->sk_receive_queue); + + if (sk->sk_state == TCP_ESTABLISHED) + pipe_grant_credits(sk, GFP_ATOMIC); + return skb; +} + +static int pep_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct sk_buff *skb; + int err; + + if (flags & ~(MSG_OOB|MSG_PEEK|MSG_TRUNC|MSG_DONTWAIT|MSG_WAITALL| + MSG_NOSIGNAL|MSG_CMSG_COMPAT)) + return -EOPNOTSUPP; + + if (unlikely(1 << sk->sk_state & (TCPF_LISTEN | TCPF_CLOSE))) + return -ENOTCONN; + + if ((flags & MSG_OOB) || sock_flag(sk, SOCK_URGINLINE)) { + /* Dequeue and acknowledge control request */ + struct pep_sock *pn = pep_sk(sk); + + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + skb = skb_dequeue(&pn->ctrlreq_queue); + if (skb) { + pep_ctrlreq_error(sk, skb, PN_PIPE_NO_ERROR, + GFP_KERNEL); + msg->msg_flags |= MSG_OOB; + goto copy; + } + if (flags & MSG_OOB) + return -EINVAL; + } + + skb = skb_recv_datagram(sk, flags, &err); + lock_sock(sk); + if (skb == NULL) { + if (err == -ENOTCONN && sk->sk_state == TCP_CLOSE_WAIT) + err = -ECONNRESET; + release_sock(sk); + return err; + } + + if (sk->sk_state == TCP_ESTABLISHED) + pipe_grant_credits(sk, GFP_KERNEL); + release_sock(sk); +copy: + msg->msg_flags |= MSG_EOR; + if (skb->len > len) + msg->msg_flags |= MSG_TRUNC; + else + len = skb->len; + + err = skb_copy_datagram_msg(skb, 0, msg, len); + if (!err) + err = (flags & MSG_TRUNC) ? skb->len : len; + + skb_free_datagram(sk, skb); + return err; +} + +static void pep_sock_unhash(struct sock *sk) +{ + struct pep_sock *pn = pep_sk(sk); + struct sock *skparent = NULL; + + lock_sock(sk); + + if (pn->listener != NULL) { + skparent = pn->listener; + pn->listener = NULL; + release_sock(sk); + + pn = pep_sk(skparent); + lock_sock(skparent); + sk_del_node_init(sk); + sk = skparent; + } + + /* Unhash a listening sock only when it is closed + * and all of its active connected pipes are closed. 
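+ * A connected child pipe first unlinks itself from its listener;
+ * the emptiness check below is then made against that listener.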
*/ + if (hlist_empty(&pn->hlist)) + pn_sock_unhash(&pn->pn_sk.sk); + release_sock(sk); + + if (skparent) + sock_put(skparent); +} + +static struct proto pep_proto = { + .close = pep_sock_close, + .accept = pep_sock_accept, + .connect = pep_sock_connect, + .ioctl = pep_ioctl, + .init = pep_init, + .setsockopt = pep_setsockopt, + .getsockopt = pep_getsockopt, + .sendmsg = pep_sendmsg, + .recvmsg = pep_recvmsg, + .backlog_rcv = pep_do_rcv, + .hash = pn_sock_hash, + .unhash = pep_sock_unhash, + .get_port = pn_sock_get_port, + .obj_size = sizeof(struct pep_sock), + .owner = THIS_MODULE, + .name = "PNPIPE", +}; + +static const struct phonet_protocol pep_pn_proto = { + .ops = &phonet_stream_ops, + .prot = &pep_proto, + .sock_type = SOCK_SEQPACKET, +}; + +static int __init pep_register(void) +{ + return phonet_proto_register(PN_PROTO_PIPE, &pep_pn_proto); +} + +static void __exit pep_unregister(void) +{ + phonet_proto_unregister(PN_PROTO_PIPE, &pep_pn_proto); +} + +module_init(pep_register); +module_exit(pep_unregister); +MODULE_AUTHOR("Remi Denis-Courmont, Nokia"); +MODULE_DESCRIPTION("Phonet pipe protocol"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO(PF_PHONET, PN_PROTO_PIPE); diff --git a/net/phonet/pn_dev.c b/net/phonet/pn_dev.c new file mode 100644 index 000000000..cde671d29 --- /dev/null +++ b/net/phonet/pn_dev.c @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: pn_dev.c + * + * Phonet network device + * + * Copyright (C) 2008 Nokia Corporation. + * + * Authors: Sakari Ailus + * Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct phonet_routes { + struct mutex lock; + struct net_device __rcu *table[64]; +}; + +struct phonet_net { + struct phonet_device_list pndevs; + struct phonet_routes routes; +}; + +static unsigned int phonet_net_id __read_mostly; + +static struct phonet_net *phonet_pernet(struct net *net) +{ + return net_generic(net, phonet_net_id); +} + +struct phonet_device_list *phonet_device_list(struct net *net) +{ + struct phonet_net *pnn = phonet_pernet(net); + return &pnn->pndevs; +} + +/* Allocate new Phonet device. 
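+ * The caller must hold pndevs->lock; the new entry is linked into the
+ * per-namespace device list under RCU.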
*/ +static struct phonet_device *__phonet_device_alloc(struct net_device *dev) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd = kmalloc(sizeof(*pnd), GFP_ATOMIC); + if (pnd == NULL) + return NULL; + pnd->netdev = dev; + bitmap_zero(pnd->addrs, 64); + + BUG_ON(!mutex_is_locked(&pndevs->lock)); + list_add_rcu(&pnd->list, &pndevs->list); + return pnd; +} + +static struct phonet_device *__phonet_get(struct net_device *dev) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd; + + BUG_ON(!mutex_is_locked(&pndevs->lock)); + list_for_each_entry(pnd, &pndevs->list, list) { + if (pnd->netdev == dev) + return pnd; + } + return NULL; +} + +static struct phonet_device *__phonet_get_rcu(struct net_device *dev) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd; + + list_for_each_entry_rcu(pnd, &pndevs->list, list) { + if (pnd->netdev == dev) + return pnd; + } + return NULL; +} + +static void phonet_device_destroy(struct net_device *dev) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd; + + ASSERT_RTNL(); + + mutex_lock(&pndevs->lock); + pnd = __phonet_get(dev); + if (pnd) + list_del_rcu(&pnd->list); + mutex_unlock(&pndevs->lock); + + if (pnd) { + u8 addr; + + for_each_set_bit(addr, pnd->addrs, 64) + phonet_address_notify(RTM_DELADDR, dev, addr); + kfree(pnd); + } +} + +struct net_device *phonet_device_get(struct net *net) +{ + struct phonet_device_list *pndevs = phonet_device_list(net); + struct phonet_device *pnd; + struct net_device *dev = NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(pnd, &pndevs->list, list) { + dev = pnd->netdev; + BUG_ON(!dev); + + if ((dev->reg_state == NETREG_REGISTERED) && + ((pnd->netdev->flags & IFF_UP)) == IFF_UP) + break; + dev = NULL; + } + dev_hold(dev); + rcu_read_unlock(); + return dev; +} + +int phonet_address_add(struct net_device *dev, u8 addr) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd; + int err = 0; + + mutex_lock(&pndevs->lock); + /* Find or create Phonet-specific device data */ + pnd = __phonet_get(dev); + if (pnd == NULL) + pnd = __phonet_device_alloc(dev); + if (unlikely(pnd == NULL)) + err = -ENOMEM; + else if (test_and_set_bit(addr >> 2, pnd->addrs)) + err = -EEXIST; + mutex_unlock(&pndevs->lock); + return err; +} + +int phonet_address_del(struct net_device *dev, u8 addr) +{ + struct phonet_device_list *pndevs = phonet_device_list(dev_net(dev)); + struct phonet_device *pnd; + int err = 0; + + mutex_lock(&pndevs->lock); + pnd = __phonet_get(dev); + if (!pnd || !test_and_clear_bit(addr >> 2, pnd->addrs)) { + err = -EADDRNOTAVAIL; + pnd = NULL; + } else if (bitmap_empty(pnd->addrs, 64)) + list_del_rcu(&pnd->list); + else + pnd = NULL; + mutex_unlock(&pndevs->lock); + + if (pnd) + kfree_rcu(pnd, rcu); + + return err; +} + +/* Gets a source address toward a destination, through a interface. 
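+ * Prefers the destination address itself when it is configured locally,
+ * otherwise the first address bound to the interface; if the interface
+ * has no Phonet address, falls back to the default Phonet device.
+ * Returns PN_NO_ADDR when no usable source address is found.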
*/ +u8 phonet_address_get(struct net_device *dev, u8 daddr) +{ + struct phonet_device *pnd; + u8 saddr; + + rcu_read_lock(); + pnd = __phonet_get_rcu(dev); + if (pnd) { + BUG_ON(bitmap_empty(pnd->addrs, 64)); + + /* Use same source address as destination, if possible */ + if (test_bit(daddr >> 2, pnd->addrs)) + saddr = daddr; + else + saddr = find_first_bit(pnd->addrs, 64) << 2; + } else + saddr = PN_NO_ADDR; + rcu_read_unlock(); + + if (saddr == PN_NO_ADDR) { + /* Fallback to another device */ + struct net_device *def_dev; + + def_dev = phonet_device_get(dev_net(dev)); + if (def_dev) { + if (def_dev != dev) + saddr = phonet_address_get(def_dev, daddr); + dev_put(def_dev); + } + } + return saddr; +} + +int phonet_address_lookup(struct net *net, u8 addr) +{ + struct phonet_device_list *pndevs = phonet_device_list(net); + struct phonet_device *pnd; + int err = -EADDRNOTAVAIL; + + rcu_read_lock(); + list_for_each_entry_rcu(pnd, &pndevs->list, list) { + /* Don't allow unregistering devices! */ + if ((pnd->netdev->reg_state != NETREG_REGISTERED) || + ((pnd->netdev->flags & IFF_UP)) != IFF_UP) + continue; + + if (test_bit(addr >> 2, pnd->addrs)) { + err = 0; + goto found; + } + } +found: + rcu_read_unlock(); + return err; +} + +/* automatically configure a Phonet device, if supported */ +static int phonet_device_autoconf(struct net_device *dev) +{ + struct if_phonet_req req; + int ret; + + if (!dev->netdev_ops->ndo_siocdevprivate) + return -EOPNOTSUPP; + + ret = dev->netdev_ops->ndo_siocdevprivate(dev, (struct ifreq *)&req, + NULL, SIOCPNGAUTOCONF); + if (ret < 0) + return ret; + + ASSERT_RTNL(); + ret = phonet_address_add(dev, req.ifr_phonet_autoconf.device); + if (ret) + return ret; + phonet_address_notify(RTM_NEWADDR, dev, + req.ifr_phonet_autoconf.device); + return 0; +} + +static void phonet_route_autodel(struct net_device *dev) +{ + struct phonet_net *pnn = phonet_pernet(dev_net(dev)); + unsigned int i; + DECLARE_BITMAP(deleted, 64); + + /* Remove left-over Phonet routes */ + bitmap_zero(deleted, 64); + mutex_lock(&pnn->routes.lock); + for (i = 0; i < 64; i++) + if (rcu_access_pointer(pnn->routes.table[i]) == dev) { + RCU_INIT_POINTER(pnn->routes.table[i], NULL); + set_bit(i, deleted); + } + mutex_unlock(&pnn->routes.lock); + + if (bitmap_empty(deleted, 64)) + return; /* short-circuit RCU */ + synchronize_rcu(); + for_each_set_bit(i, deleted, 64) { + rtm_phonet_notify(RTM_DELROUTE, dev, i); + dev_put(dev); + } +} + +/* notify Phonet of device events */ +static int phonet_device_notify(struct notifier_block *me, unsigned long what, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + switch (what) { + case NETDEV_REGISTER: + if (dev->type == ARPHRD_PHONET) + phonet_device_autoconf(dev); + break; + case NETDEV_UNREGISTER: + phonet_device_destroy(dev); + phonet_route_autodel(dev); + break; + } + return 0; + +} + +static struct notifier_block phonet_device_notifier = { + .notifier_call = phonet_device_notify, + .priority = 0, +}; + +/* Per-namespace Phonet devices handling */ +static int __net_init phonet_init_net(struct net *net) +{ + struct phonet_net *pnn = phonet_pernet(net); + + if (!proc_create_net("phonet", 0, net->proc_net, &pn_sock_seq_ops, + sizeof(struct seq_net_private))) + return -ENOMEM; + + INIT_LIST_HEAD(&pnn->pndevs.list); + mutex_init(&pnn->pndevs.lock); + mutex_init(&pnn->routes.lock); + return 0; +} + +static void __net_exit phonet_exit_net(struct net *net) +{ + struct phonet_net *pnn = phonet_pernet(net); + + remove_proc_entry("phonet", 
net->proc_net); + WARN_ON_ONCE(!list_empty(&pnn->pndevs.list)); +} + +static struct pernet_operations phonet_net_ops = { + .init = phonet_init_net, + .exit = phonet_exit_net, + .id = &phonet_net_id, + .size = sizeof(struct phonet_net), +}; + +/* Initialize Phonet devices list */ +int __init phonet_device_init(void) +{ + int err = register_pernet_subsys(&phonet_net_ops); + if (err) + return err; + + proc_create_net("pnresource", 0, init_net.proc_net, &pn_res_seq_ops, + sizeof(struct seq_net_private)); + register_netdevice_notifier(&phonet_device_notifier); + err = phonet_netlink_register(); + if (err) + phonet_device_exit(); + return err; +} + +void phonet_device_exit(void) +{ + rtnl_unregister_all(PF_PHONET); + unregister_netdevice_notifier(&phonet_device_notifier); + unregister_pernet_subsys(&phonet_net_ops); + remove_proc_entry("pnresource", init_net.proc_net); +} + +int phonet_route_add(struct net_device *dev, u8 daddr) +{ + struct phonet_net *pnn = phonet_pernet(dev_net(dev)); + struct phonet_routes *routes = &pnn->routes; + int err = -EEXIST; + + daddr = daddr >> 2; + mutex_lock(&routes->lock); + if (routes->table[daddr] == NULL) { + rcu_assign_pointer(routes->table[daddr], dev); + dev_hold(dev); + err = 0; + } + mutex_unlock(&routes->lock); + return err; +} + +int phonet_route_del(struct net_device *dev, u8 daddr) +{ + struct phonet_net *pnn = phonet_pernet(dev_net(dev)); + struct phonet_routes *routes = &pnn->routes; + + daddr = daddr >> 2; + mutex_lock(&routes->lock); + if (rcu_access_pointer(routes->table[daddr]) == dev) + RCU_INIT_POINTER(routes->table[daddr], NULL); + else + dev = NULL; + mutex_unlock(&routes->lock); + + if (!dev) + return -ENOENT; + synchronize_rcu(); + dev_put(dev); + return 0; +} + +struct net_device *phonet_route_get_rcu(struct net *net, u8 daddr) +{ + struct phonet_net *pnn = phonet_pernet(net); + struct phonet_routes *routes = &pnn->routes; + struct net_device *dev; + + daddr >>= 2; + dev = rcu_dereference(routes->table[daddr]); + return dev; +} + +struct net_device *phonet_route_output(struct net *net, u8 daddr) +{ + struct phonet_net *pnn = phonet_pernet(net); + struct phonet_routes *routes = &pnn->routes; + struct net_device *dev; + + daddr >>= 2; + rcu_read_lock(); + dev = rcu_dereference(routes->table[daddr]); + dev_hold(dev); + rcu_read_unlock(); + + if (!dev) + dev = phonet_device_get(net); /* Default route */ + return dev; +} diff --git a/net/phonet/pn_netlink.c b/net/phonet/pn_netlink.c new file mode 100644 index 000000000..59aebe296 --- /dev/null +++ b/net/phonet/pn_netlink.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: pn_netlink.c + * + * Phonet netlink interface + * + * Copyright (C) 2008 Nokia Corporation. 
+ * + * Authors: Sakari Ailus + * Remi Denis-Courmont + */ + +#include +#include +#include +#include +#include +#include + +/* Device address handling */ + +static int fill_addr(struct sk_buff *skb, struct net_device *dev, u8 addr, + u32 portid, u32 seq, int event); + +void phonet_address_notify(int event, struct net_device *dev, u8 addr) +{ + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(1), GFP_KERNEL); + if (skb == NULL) + goto errout; + err = fill_addr(skb, dev, addr, 0, 0, event); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, dev_net(dev), 0, + RTNLGRP_PHONET_IFADDR, NULL, GFP_KERNEL); + return; +errout: + rtnl_set_sk_err(dev_net(dev), RTNLGRP_PHONET_IFADDR, err); +} + +static const struct nla_policy ifa_phonet_policy[IFA_MAX+1] = { + [IFA_LOCAL] = { .type = NLA_U8 }, +}; + +static int addr_doit(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[IFA_MAX+1]; + struct net_device *dev; + struct ifaddrmsg *ifm; + int err; + u8 pnaddr; + + if (!netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (!netlink_capable(skb, CAP_SYS_ADMIN)) + return -EPERM; + + ASSERT_RTNL(); + + err = nlmsg_parse_deprecated(nlh, sizeof(*ifm), tb, IFA_MAX, + ifa_phonet_policy, extack); + if (err < 0) + return err; + + ifm = nlmsg_data(nlh); + if (tb[IFA_LOCAL] == NULL) + return -EINVAL; + pnaddr = nla_get_u8(tb[IFA_LOCAL]); + if (pnaddr & 3) + /* Phonet addresses only have 6 high-order bits */ + return -EINVAL; + + dev = __dev_get_by_index(net, ifm->ifa_index); + if (dev == NULL) + return -ENODEV; + + if (nlh->nlmsg_type == RTM_NEWADDR) + err = phonet_address_add(dev, pnaddr); + else + err = phonet_address_del(dev, pnaddr); + if (!err) + phonet_address_notify(nlh->nlmsg_type, dev, pnaddr); + return err; +} + +static int fill_addr(struct sk_buff *skb, struct net_device *dev, u8 addr, + u32 portid, u32 seq, int event) +{ + struct ifaddrmsg *ifm; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*ifm), 0); + if (nlh == NULL) + return -EMSGSIZE; + + ifm = nlmsg_data(nlh); + ifm->ifa_family = AF_PHONET; + ifm->ifa_prefixlen = 0; + ifm->ifa_flags = IFA_F_PERMANENT; + ifm->ifa_scope = RT_SCOPE_LINK; + ifm->ifa_index = dev->ifindex; + if (nla_put_u8(skb, IFA_LOCAL, addr)) + goto nla_put_failure; + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int getaddr_dumpit(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct phonet_device_list *pndevs; + struct phonet_device *pnd; + int dev_idx = 0, dev_start_idx = cb->args[0]; + int addr_idx = 0, addr_start_idx = cb->args[1]; + + pndevs = phonet_device_list(sock_net(skb->sk)); + rcu_read_lock(); + list_for_each_entry_rcu(pnd, &pndevs->list, list) { + u8 addr; + + if (dev_idx > dev_start_idx) + addr_start_idx = 0; + if (dev_idx++ < dev_start_idx) + continue; + + addr_idx = 0; + for_each_set_bit(addr, pnd->addrs, 64) { + if (addr_idx++ < addr_start_idx) + continue; + + if (fill_addr(skb, pnd->netdev, addr << 2, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_NEWADDR) < 0) + goto out; + } + } + +out: + rcu_read_unlock(); + cb->args[0] = dev_idx; + cb->args[1] = addr_idx; + + return skb->len; +} + +/* Routes handling */ + +static int fill_route(struct sk_buff *skb, struct net_device *dev, u8 dst, + u32 portid, u32 seq, int event) +{ + struct rtmsg *rtm; + 
struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*rtm), 0); + if (nlh == NULL) + return -EMSGSIZE; + + rtm = nlmsg_data(nlh); + rtm->rtm_family = AF_PHONET; + rtm->rtm_dst_len = 6; + rtm->rtm_src_len = 0; + rtm->rtm_tos = 0; + rtm->rtm_table = RT_TABLE_MAIN; + rtm->rtm_protocol = RTPROT_STATIC; + rtm->rtm_scope = RT_SCOPE_UNIVERSE; + rtm->rtm_type = RTN_UNICAST; + rtm->rtm_flags = 0; + if (nla_put_u8(skb, RTA_DST, dst) || + nla_put_u32(skb, RTA_OIF, dev->ifindex)) + goto nla_put_failure; + nlmsg_end(skb, nlh); + return 0; + +nla_put_failure: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +void rtm_phonet_notify(int event, struct net_device *dev, u8 dst) +{ + struct sk_buff *skb; + int err = -ENOBUFS; + + skb = nlmsg_new(NLMSG_ALIGN(sizeof(struct ifaddrmsg)) + + nla_total_size(1) + nla_total_size(4), GFP_KERNEL); + if (skb == NULL) + goto errout; + err = fill_route(skb, dev, dst, 0, 0, event); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + kfree_skb(skb); + goto errout; + } + rtnl_notify(skb, dev_net(dev), 0, + RTNLGRP_PHONET_ROUTE, NULL, GFP_KERNEL); + return; +errout: + rtnl_set_sk_err(dev_net(dev), RTNLGRP_PHONET_ROUTE, err); +} + +static const struct nla_policy rtm_phonet_policy[RTA_MAX+1] = { + [RTA_DST] = { .type = NLA_U8 }, + [RTA_OIF] = { .type = NLA_U32 }, +}; + +static int route_doit(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tb[RTA_MAX+1]; + struct net_device *dev; + struct rtmsg *rtm; + int err; + u8 dst; + + if (!netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (!netlink_capable(skb, CAP_SYS_ADMIN)) + return -EPERM; + + ASSERT_RTNL(); + + err = nlmsg_parse_deprecated(nlh, sizeof(*rtm), tb, RTA_MAX, + rtm_phonet_policy, extack); + if (err < 0) + return err; + + rtm = nlmsg_data(nlh); + if (rtm->rtm_table != RT_TABLE_MAIN || rtm->rtm_type != RTN_UNICAST) + return -EINVAL; + if (tb[RTA_DST] == NULL || tb[RTA_OIF] == NULL) + return -EINVAL; + dst = nla_get_u8(tb[RTA_DST]); + if (dst & 3) /* Phonet addresses only have 6 high-order bits */ + return -EINVAL; + + dev = __dev_get_by_index(net, nla_get_u32(tb[RTA_OIF])); + if (dev == NULL) + return -ENODEV; + + if (nlh->nlmsg_type == RTM_NEWROUTE) + err = phonet_route_add(dev, dst); + else + err = phonet_route_del(dev, dst); + if (!err) + rtm_phonet_notify(nlh->nlmsg_type, dev, dst); + return err; +} + +static int route_dumpit(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + u8 addr; + + rcu_read_lock(); + for (addr = cb->args[0]; addr < 64; addr++) { + struct net_device *dev = phonet_route_get_rcu(net, addr << 2); + + if (!dev) + continue; + + if (fill_route(skb, dev, addr << 2, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, RTM_NEWROUTE) < 0) + goto out; + } + +out: + rcu_read_unlock(); + cb->args[0] = addr; + + return skb->len; +} + +int __init phonet_netlink_register(void) +{ + int err = rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_NEWADDR, + addr_doit, NULL, 0); + if (err) + return err; + + /* Further rtnl_register_module() cannot fail */ + rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_DELADDR, + addr_doit, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_GETADDR, + NULL, getaddr_dumpit, 0); + rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_NEWROUTE, + route_doit, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_DELROUTE, + route_doit, NULL, 0); + rtnl_register_module(THIS_MODULE, PF_PHONET, RTM_GETROUTE, + NULL, 
route_dumpit, 0); + return 0; +} diff --git a/net/phonet/socket.c b/net/phonet/socket.c new file mode 100644 index 000000000..71e2caf6a --- /dev/null +++ b/net/phonet/socket.c @@ -0,0 +1,774 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: socket.c + * + * Phonet sockets + * + * Copyright (C) 2008 Nokia Corporation. + * + * Authors: Sakari Ailus + * Rémi Denis-Courmont + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +static int pn_socket_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (sk) { + sock->sk = NULL; + sk->sk_prot->close(sk, 0); + } + return 0; +} + +#define PN_HASHSIZE 16 +#define PN_HASHMASK (PN_HASHSIZE-1) + + +static struct { + struct hlist_head hlist[PN_HASHSIZE]; + struct mutex lock; +} pnsocks; + +void __init pn_sock_init(void) +{ + unsigned int i; + + for (i = 0; i < PN_HASHSIZE; i++) + INIT_HLIST_HEAD(pnsocks.hlist + i); + mutex_init(&pnsocks.lock); +} + +static struct hlist_head *pn_hash_list(u16 obj) +{ + return pnsocks.hlist + (obj & PN_HASHMASK); +} + +/* + * Find address based on socket address, match only certain fields. + * Also grab sock if it was found. Remember to sock_put it later. + */ +struct sock *pn_find_sock_by_sa(struct net *net, const struct sockaddr_pn *spn) +{ + struct sock *sknode; + struct sock *rval = NULL; + u16 obj = pn_sockaddr_get_object(spn); + u8 res = spn->spn_resource; + struct hlist_head *hlist = pn_hash_list(obj); + + rcu_read_lock(); + sk_for_each_rcu(sknode, hlist) { + struct pn_sock *pn = pn_sk(sknode); + BUG_ON(!pn->sobject); /* unbound socket */ + + if (!net_eq(sock_net(sknode), net)) + continue; + if (pn_port(obj)) { + /* Look up socket by port */ + if (pn_port(pn->sobject) != pn_port(obj)) + continue; + } else { + /* If port is zero, look up by resource */ + if (pn->resource != res) + continue; + } + if (pn_addr(pn->sobject) && + pn_addr(pn->sobject) != pn_addr(obj)) + continue; + + rval = sknode; + sock_hold(sknode); + break; + } + rcu_read_unlock(); + + return rval; +} + +/* Deliver a broadcast packet (only in bottom-half) */ +void pn_deliver_sock_broadcast(struct net *net, struct sk_buff *skb) +{ + struct hlist_head *hlist = pnsocks.hlist; + unsigned int h; + + rcu_read_lock(); + for (h = 0; h < PN_HASHSIZE; h++) { + struct sock *sknode; + + sk_for_each(sknode, hlist) { + struct sk_buff *clone; + + if (!net_eq(sock_net(sknode), net)) + continue; + if (!sock_flag(sknode, SOCK_BROADCAST)) + continue; + + clone = skb_clone(skb, GFP_ATOMIC); + if (clone) { + sock_hold(sknode); + sk_receive_skb(sknode, clone, 0); + } + } + hlist++; + } + rcu_read_unlock(); +} + +int pn_sock_hash(struct sock *sk) +{ + struct hlist_head *hlist = pn_hash_list(pn_sk(sk)->sobject); + + mutex_lock(&pnsocks.lock); + sk_add_node_rcu(sk, hlist); + mutex_unlock(&pnsocks.lock); + + return 0; +} +EXPORT_SYMBOL(pn_sock_hash); + +void pn_sock_unhash(struct sock *sk) +{ + mutex_lock(&pnsocks.lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&pnsocks.lock); + pn_sock_unbind_all_res(sk); + synchronize_rcu(); +} +EXPORT_SYMBOL(pn_sock_unhash); + +static DEFINE_MUTEX(port_mutex); + +static int pn_socket_bind(struct socket *sock, struct sockaddr *addr, int len) +{ + struct sock *sk = sock->sk; + struct pn_sock *pn = pn_sk(sk); + struct sockaddr_pn *spn = (struct sockaddr_pn *)addr; + int err; + u16 handle; + u8 saddr; + + if (sk->sk_prot->bind) + return sk->sk_prot->bind(sk, addr, len); + + if (len < sizeof(struct sockaddr_pn)) + return -EINVAL; + if 
(spn->spn_family != AF_PHONET) + return -EAFNOSUPPORT; + + handle = pn_sockaddr_get_object((struct sockaddr_pn *)addr); + saddr = pn_addr(handle); + if (saddr && phonet_address_lookup(sock_net(sk), saddr)) + return -EADDRNOTAVAIL; + + lock_sock(sk); + if (sk->sk_state != TCP_CLOSE || pn_port(pn->sobject)) { + err = -EINVAL; /* attempt to rebind */ + goto out; + } + WARN_ON(sk_hashed(sk)); + mutex_lock(&port_mutex); + err = sk->sk_prot->get_port(sk, pn_port(handle)); + if (err) + goto out_port; + + /* get_port() sets the port, bind() sets the address if applicable */ + pn->sobject = pn_object(saddr, pn_port(pn->sobject)); + pn->resource = spn->spn_resource; + + /* Enable RX on the socket */ + err = sk->sk_prot->hash(sk); +out_port: + mutex_unlock(&port_mutex); +out: + release_sock(sk); + return err; +} + +static int pn_socket_autobind(struct socket *sock) +{ + struct sockaddr_pn sa; + int err; + + memset(&sa, 0, sizeof(sa)); + sa.spn_family = AF_PHONET; + err = pn_socket_bind(sock, (struct sockaddr *)&sa, + sizeof(struct sockaddr_pn)); + if (err != -EINVAL) + return err; + BUG_ON(!pn_port(pn_sk(sock->sk)->sobject)); + return 0; /* socket was already bound */ +} + +static int pn_socket_connect(struct socket *sock, struct sockaddr *addr, + int len, int flags) +{ + struct sock *sk = sock->sk; + struct pn_sock *pn = pn_sk(sk); + struct sockaddr_pn *spn = (struct sockaddr_pn *)addr; + struct task_struct *tsk = current; + long timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + int err; + + if (pn_socket_autobind(sock)) + return -ENOBUFS; + if (len < sizeof(struct sockaddr_pn)) + return -EINVAL; + if (spn->spn_family != AF_PHONET) + return -EAFNOSUPPORT; + + lock_sock(sk); + + switch (sock->state) { + case SS_UNCONNECTED: + if (sk->sk_state != TCP_CLOSE) { + err = -EISCONN; + goto out; + } + break; + case SS_CONNECTING: + err = -EALREADY; + goto out; + default: + err = -EISCONN; + goto out; + } + + pn->dobject = pn_sockaddr_get_object(spn); + pn->resource = pn_sockaddr_get_resource(spn); + sock->state = SS_CONNECTING; + + err = sk->sk_prot->connect(sk, addr, len); + if (err) { + sock->state = SS_UNCONNECTED; + pn->dobject = 0; + goto out; + } + + while (sk->sk_state == TCP_SYN_SENT) { + DEFINE_WAIT(wait); + + if (!timeo) { + err = -EINPROGRESS; + goto out; + } + if (signal_pending(tsk)) { + err = sock_intr_errno(timeo); + goto out; + } + + prepare_to_wait_exclusive(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + finish_wait(sk_sleep(sk), &wait); + } + + if ((1 << sk->sk_state) & (TCPF_SYN_RECV|TCPF_ESTABLISHED)) + err = 0; + else if (sk->sk_state == TCP_CLOSE_WAIT) + err = -ECONNRESET; + else + err = -ECONNREFUSED; + sock->state = err ? SS_UNCONNECTED : SS_CONNECTED; +out: + release_sock(sk); + return err; +} + +static int pn_socket_accept(struct socket *sock, struct socket *newsock, + int flags, bool kern) +{ + struct sock *sk = sock->sk; + struct sock *newsk; + int err; + + if (unlikely(sk->sk_state != TCP_LISTEN)) + return -EINVAL; + + newsk = sk->sk_prot->accept(sk, flags, &err, kern); + if (!newsk) + return err; + + lock_sock(newsk); + sock_graft(newsk, newsock); + newsock->state = SS_CONNECTED; + release_sock(newsk); + return 0; +} + +static int pn_socket_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct sock *sk = sock->sk; + struct pn_sock *pn = pn_sk(sk); + + memset(addr, 0, sizeof(struct sockaddr_pn)); + addr->sa_family = AF_PHONET; + if (!peer) /* Race with bind() here is userland's problem. 
*/ + pn_sockaddr_set_object((struct sockaddr_pn *)addr, + pn->sobject); + + return sizeof(struct sockaddr_pn); +} + +static __poll_t pn_socket_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct pep_sock *pn = pep_sk(sk); + __poll_t mask = 0; + + poll_wait(file, sk_sleep(sk), wait); + + if (sk->sk_state == TCP_CLOSE) + return EPOLLERR; + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + if (!skb_queue_empty_lockless(&pn->ctrlreq_queue)) + mask |= EPOLLPRI; + if (!mask && sk->sk_state == TCP_CLOSE_WAIT) + return EPOLLHUP; + + if (sk->sk_state == TCP_ESTABLISHED && + refcount_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf && + atomic_read(&pn->tx_credits)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + return mask; +} + +static int pn_socket_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + struct sock *sk = sock->sk; + struct pn_sock *pn = pn_sk(sk); + + if (cmd == SIOCPNGETOBJECT) { + struct net_device *dev; + u16 handle; + u8 saddr; + + if (get_user(handle, (__u16 __user *)arg)) + return -EFAULT; + + lock_sock(sk); + if (sk->sk_bound_dev_if) + dev = dev_get_by_index(sock_net(sk), + sk->sk_bound_dev_if); + else + dev = phonet_device_get(sock_net(sk)); + if (dev && (dev->flags & IFF_UP)) + saddr = phonet_address_get(dev, pn_addr(handle)); + else + saddr = PN_NO_ADDR; + release_sock(sk); + + dev_put(dev); + if (saddr == PN_NO_ADDR) + return -EHOSTUNREACH; + + handle = pn_object(saddr, pn_port(pn->sobject)); + return put_user(handle, (__u16 __user *)arg); + } + + return sk->sk_prot->ioctl(sk, cmd, arg); +} + +static int pn_socket_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int err = 0; + + if (pn_socket_autobind(sock)) + return -ENOBUFS; + + lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + err = -EINVAL; + goto out; + } + + if (sk->sk_state != TCP_LISTEN) { + sk->sk_state = TCP_LISTEN; + sk->sk_ack_backlog = 0; + } + sk->sk_max_ack_backlog = backlog; +out: + release_sock(sk); + return err; +} + +static int pn_socket_sendmsg(struct socket *sock, struct msghdr *m, + size_t total_len) +{ + struct sock *sk = sock->sk; + + if (pn_socket_autobind(sock)) + return -EAGAIN; + + return sk->sk_prot->sendmsg(sk, m, total_len); +} + +const struct proto_ops phonet_dgram_ops = { + .family = AF_PHONET, + .owner = THIS_MODULE, + .release = pn_socket_release, + .bind = pn_socket_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = pn_socket_getname, + .poll = datagram_poll, + .ioctl = pn_socket_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .sendmsg = pn_socket_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +const struct proto_ops phonet_stream_ops = { + .family = AF_PHONET, + .owner = THIS_MODULE, + .release = pn_socket_release, + .bind = pn_socket_bind, + .connect = pn_socket_connect, + .socketpair = sock_no_socketpair, + .accept = pn_socket_accept, + .getname = pn_socket_getname, + .poll = pn_socket_poll, + .ioctl = pn_socket_ioctl, + .listen = pn_socket_listen, + .shutdown = sock_no_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = pn_socket_sendmsg, + .recvmsg = sock_common_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; +EXPORT_SYMBOL(phonet_stream_ops); + +/* allocate port for a socket */ +int pn_sock_get_port(struct sock *sk, unsigned short 
sport) +{ + static int port_cur; + struct net *net = sock_net(sk); + struct pn_sock *pn = pn_sk(sk); + struct sockaddr_pn try_sa; + struct sock *tmpsk; + + memset(&try_sa, 0, sizeof(struct sockaddr_pn)); + try_sa.spn_family = AF_PHONET; + WARN_ON(!mutex_is_locked(&port_mutex)); + if (!sport) { + /* search free port */ + int port, pmin, pmax; + + phonet_get_local_port_range(&pmin, &pmax); + for (port = pmin; port <= pmax; port++) { + port_cur++; + if (port_cur < pmin || port_cur > pmax) + port_cur = pmin; + + pn_sockaddr_set_port(&try_sa, port_cur); + tmpsk = pn_find_sock_by_sa(net, &try_sa); + if (tmpsk == NULL) { + sport = port_cur; + goto found; + } else + sock_put(tmpsk); + } + } else { + /* try to find specific port */ + pn_sockaddr_set_port(&try_sa, sport); + tmpsk = pn_find_sock_by_sa(net, &try_sa); + if (tmpsk == NULL) + /* No sock there! We can use that port... */ + goto found; + else + sock_put(tmpsk); + } + /* the port must be in use already */ + return -EADDRINUSE; + +found: + pn->sobject = pn_object(pn_addr(pn->sobject), sport); + return 0; +} +EXPORT_SYMBOL(pn_sock_get_port); + +#ifdef CONFIG_PROC_FS +static struct sock *pn_sock_get_idx(struct seq_file *seq, loff_t pos) +{ + struct net *net = seq_file_net(seq); + struct hlist_head *hlist = pnsocks.hlist; + struct sock *sknode; + unsigned int h; + + for (h = 0; h < PN_HASHSIZE; h++) { + sk_for_each_rcu(sknode, hlist) { + if (!net_eq(net, sock_net(sknode))) + continue; + if (!pos) + return sknode; + pos--; + } + hlist++; + } + return NULL; +} + +static struct sock *pn_sock_get_next(struct seq_file *seq, struct sock *sk) +{ + struct net *net = seq_file_net(seq); + + do + sk = sk_next(sk); + while (sk && !net_eq(net, sock_net(sk))); + + return sk; +} + +static void *pn_sock_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(rcu) +{ + rcu_read_lock(); + return *pos ? pn_sock_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *pn_sock_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock *sk; + + if (v == SEQ_START_TOKEN) + sk = pn_sock_get_idx(seq, 0); + else + sk = pn_sock_get_next(seq, v); + (*pos)++; + return sk; +} + +static void pn_sock_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +static int pn_sock_seq_show(struct seq_file *seq, void *v) +{ + seq_setwidth(seq, 127); + if (v == SEQ_START_TOKEN) + seq_puts(seq, "pt loc rem rs st tx_queue rx_queue " + " uid inode ref pointer drops"); + else { + struct sock *sk = v; + struct pn_sock *pn = pn_sk(sk); + + seq_printf(seq, "%2d %04X:%04X:%02X %02X %08X:%08X %5d %lu " + "%d %pK %u", + sk->sk_protocol, pn->sobject, pn->dobject, + pn->resource, sk->sk_state, + sk_wmem_alloc_get(sk), sk_rmem_alloc_get(sk), + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), + sock_i_ino(sk), + refcount_read(&sk->sk_refcnt), sk, + atomic_read(&sk->sk_drops)); + } + seq_pad(seq, '\n'); + return 0; +} + +const struct seq_operations pn_sock_seq_ops = { + .start = pn_sock_seq_start, + .next = pn_sock_seq_next, + .stop = pn_sock_seq_stop, + .show = pn_sock_seq_show, +}; +#endif + +static struct { + struct sock *sk[256]; +} pnres; + +/* + * Find and hold socket based on resource. 
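+ * Resources are only bound in the initial network namespace; on success
+ * a reference is taken on the socket, which the caller must release
+ * with sock_put().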
+ */ +struct sock *pn_find_sock_by_res(struct net *net, u8 res) +{ + struct sock *sk; + + if (!net_eq(net, &init_net)) + return NULL; + + rcu_read_lock(); + sk = rcu_dereference(pnres.sk[res]); + if (sk) + sock_hold(sk); + rcu_read_unlock(); + return sk; +} + +static DEFINE_MUTEX(resource_mutex); + +int pn_sock_bind_res(struct sock *sk, u8 res) +{ + int ret = -EADDRINUSE; + + if (!net_eq(sock_net(sk), &init_net)) + return -ENOIOCTLCMD; + if (!capable(CAP_SYS_ADMIN)) + return -EPERM; + if (pn_socket_autobind(sk->sk_socket)) + return -EAGAIN; + + mutex_lock(&resource_mutex); + if (pnres.sk[res] == NULL) { + sock_hold(sk); + rcu_assign_pointer(pnres.sk[res], sk); + ret = 0; + } + mutex_unlock(&resource_mutex); + return ret; +} + +int pn_sock_unbind_res(struct sock *sk, u8 res) +{ + int ret = -ENOENT; + + if (!capable(CAP_SYS_ADMIN)) + return -EPERM; + + mutex_lock(&resource_mutex); + if (pnres.sk[res] == sk) { + RCU_INIT_POINTER(pnres.sk[res], NULL); + ret = 0; + } + mutex_unlock(&resource_mutex); + + if (ret == 0) { + synchronize_rcu(); + sock_put(sk); + } + return ret; +} + +void pn_sock_unbind_all_res(struct sock *sk) +{ + unsigned int res, match = 0; + + mutex_lock(&resource_mutex); + for (res = 0; res < 256; res++) { + if (pnres.sk[res] == sk) { + RCU_INIT_POINTER(pnres.sk[res], NULL); + match++; + } + } + mutex_unlock(&resource_mutex); + + while (match > 0) { + __sock_put(sk); + match--; + } + /* Caller is responsible for RCU sync before final sock_put() */ +} + +#ifdef CONFIG_PROC_FS +static struct sock **pn_res_get_idx(struct seq_file *seq, loff_t pos) +{ + struct net *net = seq_file_net(seq); + unsigned int i; + + if (!net_eq(net, &init_net)) + return NULL; + + for (i = 0; i < 256; i++) { + if (pnres.sk[i] == NULL) + continue; + if (!pos) + return pnres.sk + i; + pos--; + } + return NULL; +} + +static struct sock **pn_res_get_next(struct seq_file *seq, struct sock **sk) +{ + struct net *net = seq_file_net(seq); + unsigned int i; + + BUG_ON(!net_eq(net, &init_net)); + + for (i = (sk - pnres.sk) + 1; i < 256; i++) + if (pnres.sk[i]) + return pnres.sk + i; + return NULL; +} + +static void *pn_res_seq_start(struct seq_file *seq, loff_t *pos) + __acquires(resource_mutex) +{ + mutex_lock(&resource_mutex); + return *pos ? 
pn_res_get_idx(seq, *pos - 1) : SEQ_START_TOKEN; +} + +static void *pn_res_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sock **sk; + + if (v == SEQ_START_TOKEN) + sk = pn_res_get_idx(seq, 0); + else + sk = pn_res_get_next(seq, v); + (*pos)++; + return sk; +} + +static void pn_res_seq_stop(struct seq_file *seq, void *v) + __releases(resource_mutex) +{ + mutex_unlock(&resource_mutex); +} + +static int pn_res_seq_show(struct seq_file *seq, void *v) +{ + seq_setwidth(seq, 63); + if (v == SEQ_START_TOKEN) + seq_puts(seq, "rs uid inode"); + else { + struct sock **psk = v; + struct sock *sk = *psk; + + seq_printf(seq, "%02X %5u %lu", + (int) (psk - pnres.sk), + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), + sock_i_ino(sk)); + } + seq_pad(seq, '\n'); + return 0; +} + +const struct seq_operations pn_res_seq_ops = { + .start = pn_res_seq_start, + .next = pn_res_seq_next, + .stop = pn_res_seq_stop, + .show = pn_res_seq_show, +}; +#endif diff --git a/net/phonet/sysctl.c b/net/phonet/sysctl.c new file mode 100644 index 000000000..0d0bf4138 --- /dev/null +++ b/net/phonet/sysctl.c @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * File: sysctl.c + * + * Phonet /proc/sys/net/phonet interface implementation + * + * Copyright (C) 2008 Nokia Corporation. + * + * Author: Rémi Denis-Courmont + */ + +#include +#include +#include +#include + +#include +#include +#include + +#define DYNAMIC_PORT_MIN 0x40 +#define DYNAMIC_PORT_MAX 0x7f + +static DEFINE_SEQLOCK(local_port_range_lock); +static int local_port_range_min[2] = {0, 0}; +static int local_port_range_max[2] = {1023, 1023}; +static int local_port_range[2] = {DYNAMIC_PORT_MIN, DYNAMIC_PORT_MAX}; +static struct ctl_table_header *phonet_table_hrd; + +static void set_local_port_range(int range[2]) +{ + write_seqlock(&local_port_range_lock); + local_port_range[0] = range[0]; + local_port_range[1] = range[1]; + write_sequnlock(&local_port_range_lock); +} + +void phonet_get_local_port_range(int *min, int *max) +{ + unsigned int seq; + + do { + seq = read_seqbegin(&local_port_range_lock); + if (min) + *min = local_port_range[0]; + if (max) + *max = local_port_range[1]; + } while (read_seqretry(&local_port_range_lock, seq)); +} + +static int proc_local_port_range(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + int ret; + int range[2] = {local_port_range[0], local_port_range[1]}; + struct ctl_table tmp = { + .data = &range, + .maxlen = sizeof(range), + .mode = table->mode, + .extra1 = &local_port_range_min, + .extra2 = &local_port_range_max, + }; + + ret = proc_dointvec_minmax(&tmp, write, buffer, lenp, ppos); + + if (write && ret == 0) { + if (range[1] < range[0]) + ret = -EINVAL; + else + set_local_port_range(range); + } + + return ret; +} + +static struct ctl_table phonet_table[] = { + { + .procname = "local_port_range", + .data = &local_port_range, + .maxlen = sizeof(local_port_range), + .mode = 0644, + .proc_handler = proc_local_port_range, + }, + { } +}; + +int __init phonet_sysctl_init(void) +{ + phonet_table_hrd = register_net_sysctl(&init_net, "net/phonet", phonet_table); + return phonet_table_hrd == NULL ? 
-ENOMEM : 0; +} + +void phonet_sysctl_exit(void) +{ + unregister_net_sysctl_table(phonet_table_hrd); +} diff --git a/net/psample/Kconfig b/net/psample/Kconfig new file mode 100644 index 000000000..be0b83920 --- /dev/null +++ b/net/psample/Kconfig @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# psample packet sampling configuration +# + +menuconfig PSAMPLE + tristate "Packet-sampling netlink channel" + default n + help + Say Y here to add support for packet-sampling netlink channel + This netlink channel allows transferring packets alongside some + metadata to userspace. + + To compile this support as a module, choose M here: the module will + be called psample. diff --git a/net/psample/Makefile b/net/psample/Makefile new file mode 100644 index 000000000..a04367b9e --- /dev/null +++ b/net/psample/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the psample netlink channel +# + +obj-$(CONFIG_PSAMPLE) += psample.o diff --git a/net/psample/psample.c b/net/psample/psample.c new file mode 100644 index 000000000..c34e90285 --- /dev/null +++ b/net/psample/psample.c @@ -0,0 +1,513 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/psample/psample.c - Netlink channel for packet sampling + * Copyright (c) 2017 Yotam Gigi + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define PSAMPLE_MAX_PACKET_SIZE 0xffff + +static LIST_HEAD(psample_groups_list); +static DEFINE_SPINLOCK(psample_groups_lock); + +/* multicast groups */ +enum psample_nl_multicast_groups { + PSAMPLE_NL_MCGRP_CONFIG, + PSAMPLE_NL_MCGRP_SAMPLE, +}; + +static const struct genl_multicast_group psample_nl_mcgrps[] = { + [PSAMPLE_NL_MCGRP_CONFIG] = { .name = PSAMPLE_NL_MCGRP_CONFIG_NAME }, + [PSAMPLE_NL_MCGRP_SAMPLE] = { .name = PSAMPLE_NL_MCGRP_SAMPLE_NAME, + .flags = GENL_UNS_ADMIN_PERM }, +}; + +static struct genl_family psample_nl_family __ro_after_init; + +static int psample_group_nl_fill(struct sk_buff *msg, + struct psample_group *group, + enum psample_command cmd, u32 portid, u32 seq, + int flags) +{ + void *hdr; + int ret; + + hdr = genlmsg_put(msg, portid, seq, &psample_nl_family, flags, cmd); + if (!hdr) + return -EMSGSIZE; + + ret = nla_put_u32(msg, PSAMPLE_ATTR_SAMPLE_GROUP, group->group_num); + if (ret < 0) + goto error; + + ret = nla_put_u32(msg, PSAMPLE_ATTR_GROUP_REFCOUNT, group->refcount); + if (ret < 0) + goto error; + + ret = nla_put_u32(msg, PSAMPLE_ATTR_GROUP_SEQ, group->seq); + if (ret < 0) + goto error; + + genlmsg_end(msg, hdr); + return 0; + +error: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int psample_nl_cmd_get_group_dumpit(struct sk_buff *msg, + struct netlink_callback *cb) +{ + struct psample_group *group; + int start = cb->args[0]; + int idx = 0; + int err; + + spin_lock_bh(&psample_groups_lock); + list_for_each_entry(group, &psample_groups_list, list) { + if (!net_eq(group->net, sock_net(msg->sk))) + continue; + if (idx < start) { + idx++; + continue; + } + err = psample_group_nl_fill(msg, group, PSAMPLE_CMD_NEW_GROUP, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI); + if (err) + break; + idx++; + } + + spin_unlock_bh(&psample_groups_lock); + cb->args[0] = idx; + return msg->len; +} + +static const struct genl_small_ops psample_nl_ops[] = { + { + .cmd = PSAMPLE_CMD_GET_GROUP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = psample_nl_cmd_get_group_dumpit, + /* can be retrieved by unprivileged users 
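+ * (no GENL_ADMIN_PERM flag is set on this dump operation)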
*/ + } +}; + +static struct genl_family psample_nl_family __ro_after_init = { + .name = PSAMPLE_GENL_NAME, + .version = PSAMPLE_GENL_VERSION, + .maxattr = PSAMPLE_ATTR_MAX, + .netnsok = true, + .module = THIS_MODULE, + .mcgrps = psample_nl_mcgrps, + .small_ops = psample_nl_ops, + .n_small_ops = ARRAY_SIZE(psample_nl_ops), + .resv_start_op = PSAMPLE_CMD_GET_GROUP + 1, + .n_mcgrps = ARRAY_SIZE(psample_nl_mcgrps), +}; + +static void psample_group_notify(struct psample_group *group, + enum psample_command cmd) +{ + struct sk_buff *msg; + int err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return; + + err = psample_group_nl_fill(msg, group, cmd, 0, 0, NLM_F_MULTI); + if (!err) + genlmsg_multicast_netns(&psample_nl_family, group->net, msg, 0, + PSAMPLE_NL_MCGRP_CONFIG, GFP_ATOMIC); + else + nlmsg_free(msg); +} + +static struct psample_group *psample_group_create(struct net *net, + u32 group_num) +{ + struct psample_group *group; + + group = kzalloc(sizeof(*group), GFP_ATOMIC); + if (!group) + return NULL; + + group->net = net; + group->group_num = group_num; + list_add_tail(&group->list, &psample_groups_list); + + psample_group_notify(group, PSAMPLE_CMD_NEW_GROUP); + return group; +} + +static void psample_group_destroy(struct psample_group *group) +{ + psample_group_notify(group, PSAMPLE_CMD_DEL_GROUP); + list_del(&group->list); + kfree_rcu(group, rcu); +} + +static struct psample_group * +psample_group_lookup(struct net *net, u32 group_num) +{ + struct psample_group *group; + + list_for_each_entry(group, &psample_groups_list, list) + if ((group->group_num == group_num) && (group->net == net)) + return group; + return NULL; +} + +struct psample_group *psample_group_get(struct net *net, u32 group_num) +{ + struct psample_group *group; + + spin_lock_bh(&psample_groups_lock); + + group = psample_group_lookup(net, group_num); + if (!group) { + group = psample_group_create(net, group_num); + if (!group) + goto out; + } + group->refcount++; + +out: + spin_unlock_bh(&psample_groups_lock); + return group; +} +EXPORT_SYMBOL_GPL(psample_group_get); + +void psample_group_take(struct psample_group *group) +{ + spin_lock_bh(&psample_groups_lock); + group->refcount++; + spin_unlock_bh(&psample_groups_lock); +} +EXPORT_SYMBOL_GPL(psample_group_take); + +void psample_group_put(struct psample_group *group) +{ + spin_lock_bh(&psample_groups_lock); + + if (--group->refcount == 0) + psample_group_destroy(group); + + spin_unlock_bh(&psample_groups_lock); +} +EXPORT_SYMBOL_GPL(psample_group_put); + +#ifdef CONFIG_INET +static int __psample_ip_tun_to_nlattr(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + unsigned short tun_proto = ip_tunnel_info_af(tun_info); + const void *tun_opts = ip_tunnel_info_opts(tun_info); + const struct ip_tunnel_key *tun_key = &tun_info->key; + int tun_opts_len = tun_info->options_len; + + if (tun_key->tun_flags & TUNNEL_KEY && + nla_put_be64(skb, PSAMPLE_TUNNEL_KEY_ATTR_ID, tun_key->tun_id, + PSAMPLE_TUNNEL_KEY_ATTR_PAD)) + return -EMSGSIZE; + + if (tun_info->mode & IP_TUNNEL_INFO_BRIDGE && + nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_INFO_BRIDGE)) + return -EMSGSIZE; + + switch (tun_proto) { + case AF_INET: + if (tun_key->u.ipv4.src && + nla_put_in_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_SRC, + tun_key->u.ipv4.src)) + return -EMSGSIZE; + if (tun_key->u.ipv4.dst && + nla_put_in_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV4_DST, + tun_key->u.ipv4.dst)) + return -EMSGSIZE; + break; + case AF_INET6: + if (!ipv6_addr_any(&tun_key->u.ipv6.src) && + 
nla_put_in6_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV6_SRC, + &tun_key->u.ipv6.src)) + return -EMSGSIZE; + if (!ipv6_addr_any(&tun_key->u.ipv6.dst) && + nla_put_in6_addr(skb, PSAMPLE_TUNNEL_KEY_ATTR_IPV6_DST, + &tun_key->u.ipv6.dst)) + return -EMSGSIZE; + break; + } + if (tun_key->tos && + nla_put_u8(skb, PSAMPLE_TUNNEL_KEY_ATTR_TOS, tun_key->tos)) + return -EMSGSIZE; + if (nla_put_u8(skb, PSAMPLE_TUNNEL_KEY_ATTR_TTL, tun_key->ttl)) + return -EMSGSIZE; + if ((tun_key->tun_flags & TUNNEL_DONT_FRAGMENT) && + nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_DONT_FRAGMENT)) + return -EMSGSIZE; + if ((tun_key->tun_flags & TUNNEL_CSUM) && + nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_CSUM)) + return -EMSGSIZE; + if (tun_key->tp_src && + nla_put_be16(skb, PSAMPLE_TUNNEL_KEY_ATTR_TP_SRC, tun_key->tp_src)) + return -EMSGSIZE; + if (tun_key->tp_dst && + nla_put_be16(skb, PSAMPLE_TUNNEL_KEY_ATTR_TP_DST, tun_key->tp_dst)) + return -EMSGSIZE; + if ((tun_key->tun_flags & TUNNEL_OAM) && + nla_put_flag(skb, PSAMPLE_TUNNEL_KEY_ATTR_OAM)) + return -EMSGSIZE; + if (tun_opts_len) { + if (tun_key->tun_flags & TUNNEL_GENEVE_OPT && + nla_put(skb, PSAMPLE_TUNNEL_KEY_ATTR_GENEVE_OPTS, + tun_opts_len, tun_opts)) + return -EMSGSIZE; + else if (tun_key->tun_flags & TUNNEL_ERSPAN_OPT && + nla_put(skb, PSAMPLE_TUNNEL_KEY_ATTR_ERSPAN_OPTS, + tun_opts_len, tun_opts)) + return -EMSGSIZE; + } + + return 0; +} + +static int psample_ip_tun_to_nlattr(struct sk_buff *skb, + struct ip_tunnel_info *tun_info) +{ + struct nlattr *nla; + int err; + + nla = nla_nest_start_noflag(skb, PSAMPLE_ATTR_TUNNEL); + if (!nla) + return -EMSGSIZE; + + err = __psample_ip_tun_to_nlattr(skb, tun_info); + if (err) { + nla_nest_cancel(skb, nla); + return err; + } + + nla_nest_end(skb, nla); + + return 0; +} + +static int psample_tunnel_meta_len(struct ip_tunnel_info *tun_info) +{ + unsigned short tun_proto = ip_tunnel_info_af(tun_info); + const struct ip_tunnel_key *tun_key = &tun_info->key; + int tun_opts_len = tun_info->options_len; + int sum = nla_total_size(0); /* PSAMPLE_ATTR_TUNNEL */ + + if (tun_key->tun_flags & TUNNEL_KEY) + sum += nla_total_size_64bit(sizeof(u64)); + + if (tun_info->mode & IP_TUNNEL_INFO_BRIDGE) + sum += nla_total_size(0); + + switch (tun_proto) { + case AF_INET: + if (tun_key->u.ipv4.src) + sum += nla_total_size(sizeof(u32)); + if (tun_key->u.ipv4.dst) + sum += nla_total_size(sizeof(u32)); + break; + case AF_INET6: + if (!ipv6_addr_any(&tun_key->u.ipv6.src)) + sum += nla_total_size(sizeof(struct in6_addr)); + if (!ipv6_addr_any(&tun_key->u.ipv6.dst)) + sum += nla_total_size(sizeof(struct in6_addr)); + break; + } + if (tun_key->tos) + sum += nla_total_size(sizeof(u8)); + sum += nla_total_size(sizeof(u8)); /* TTL */ + if (tun_key->tun_flags & TUNNEL_DONT_FRAGMENT) + sum += nla_total_size(0); + if (tun_key->tun_flags & TUNNEL_CSUM) + sum += nla_total_size(0); + if (tun_key->tp_src) + sum += nla_total_size(sizeof(u16)); + if (tun_key->tp_dst) + sum += nla_total_size(sizeof(u16)); + if (tun_key->tun_flags & TUNNEL_OAM) + sum += nla_total_size(0); + if (tun_opts_len) { + if (tun_key->tun_flags & TUNNEL_GENEVE_OPT) + sum += nla_total_size(tun_opts_len); + else if (tun_key->tun_flags & TUNNEL_ERSPAN_OPT) + sum += nla_total_size(tun_opts_len); + } + + return sum; +} +#endif + +void psample_sample_packet(struct psample_group *group, struct sk_buff *skb, + u32 sample_rate, const struct psample_metadata *md) +{ + ktime_t tstamp = ktime_get_real(); + int out_ifindex = md->out_ifindex; + int in_ifindex = md->in_ifindex; + u32 trunc_size = 
md->trunc_size; +#ifdef CONFIG_INET + struct ip_tunnel_info *tun_info; +#endif + struct sk_buff *nl_skb; + int data_len; + int meta_len; + void *data; + int ret; + + meta_len = (in_ifindex ? nla_total_size(sizeof(u16)) : 0) + + (out_ifindex ? nla_total_size(sizeof(u16)) : 0) + + (md->out_tc_valid ? nla_total_size(sizeof(u16)) : 0) + + (md->out_tc_occ_valid ? nla_total_size_64bit(sizeof(u64)) : 0) + + (md->latency_valid ? nla_total_size_64bit(sizeof(u64)) : 0) + + nla_total_size(sizeof(u32)) + /* sample_rate */ + nla_total_size(sizeof(u32)) + /* orig_size */ + nla_total_size(sizeof(u32)) + /* group_num */ + nla_total_size(sizeof(u32)) + /* seq */ + nla_total_size_64bit(sizeof(u64)) + /* timestamp */ + nla_total_size(sizeof(u16)); /* protocol */ + +#ifdef CONFIG_INET + tun_info = skb_tunnel_info(skb); + if (tun_info) + meta_len += psample_tunnel_meta_len(tun_info); +#endif + + data_len = min(skb->len, trunc_size); + if (meta_len + nla_total_size(data_len) > PSAMPLE_MAX_PACKET_SIZE) + data_len = PSAMPLE_MAX_PACKET_SIZE - meta_len - NLA_HDRLEN + - NLA_ALIGNTO; + + nl_skb = genlmsg_new(meta_len + nla_total_size(data_len), GFP_ATOMIC); + if (unlikely(!nl_skb)) + return; + + data = genlmsg_put(nl_skb, 0, 0, &psample_nl_family, 0, + PSAMPLE_CMD_SAMPLE); + if (unlikely(!data)) + goto error; + + if (in_ifindex) { + ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_IIFINDEX, in_ifindex); + if (unlikely(ret < 0)) + goto error; + } + + if (out_ifindex) { + ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_OIFINDEX, out_ifindex); + if (unlikely(ret < 0)) + goto error; + } + + ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_SAMPLE_RATE, sample_rate); + if (unlikely(ret < 0)) + goto error; + + ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_ORIGSIZE, skb->len); + if (unlikely(ret < 0)) + goto error; + + ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_SAMPLE_GROUP, group->group_num); + if (unlikely(ret < 0)) + goto error; + + ret = nla_put_u32(nl_skb, PSAMPLE_ATTR_GROUP_SEQ, group->seq++); + if (unlikely(ret < 0)) + goto error; + + if (md->out_tc_valid) { + ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_OUT_TC, md->out_tc); + if (unlikely(ret < 0)) + goto error; + } + + if (md->out_tc_occ_valid) { + ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_OUT_TC_OCC, + md->out_tc_occ, PSAMPLE_ATTR_PAD); + if (unlikely(ret < 0)) + goto error; + } + + if (md->latency_valid) { + ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_LATENCY, + md->latency, PSAMPLE_ATTR_PAD); + if (unlikely(ret < 0)) + goto error; + } + + ret = nla_put_u64_64bit(nl_skb, PSAMPLE_ATTR_TIMESTAMP, + ktime_to_ns(tstamp), PSAMPLE_ATTR_PAD); + if (unlikely(ret < 0)) + goto error; + + ret = nla_put_u16(nl_skb, PSAMPLE_ATTR_PROTO, + be16_to_cpu(skb->protocol)); + if (unlikely(ret < 0)) + goto error; + + if (data_len) { + int nla_len = nla_total_size(data_len); + struct nlattr *nla; + + nla = skb_put(nl_skb, nla_len); + nla->nla_type = PSAMPLE_ATTR_DATA; + nla->nla_len = nla_attr_size(data_len); + + if (skb_copy_bits(skb, 0, nla_data(nla), data_len)) + goto error; + } + +#ifdef CONFIG_INET + if (tun_info) { + ret = psample_ip_tun_to_nlattr(nl_skb, tun_info); + if (unlikely(ret < 0)) + goto error; + } +#endif + + genlmsg_end(nl_skb, data); + genlmsg_multicast_netns(&psample_nl_family, group->net, nl_skb, 0, + PSAMPLE_NL_MCGRP_SAMPLE, GFP_ATOMIC); + + return; +error: + pr_err_ratelimited("Could not create psample log message\n"); + nlmsg_free(nl_skb); +} +EXPORT_SYMBOL_GPL(psample_sample_packet); + +static int __init psample_module_init(void) +{ + return genl_register_family(&psample_nl_family); +} + +static void 
__exit psample_module_exit(void) +{ + genl_unregister_family(&psample_nl_family); +} + +module_init(psample_module_init); +module_exit(psample_module_exit); + +MODULE_AUTHOR("Yotam Gigi "); +MODULE_DESCRIPTION("netlink channel for packet sampling"); +MODULE_LICENSE("GPL v2"); diff --git a/net/qrtr/Kconfig b/net/qrtr/Kconfig new file mode 100644 index 000000000..b4020b847 --- /dev/null +++ b/net/qrtr/Kconfig @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: GPL-2.0-only +# Qualcomm IPC Router configuration +# + +config QRTR + tristate "Qualcomm IPC Router support" + help + Say Y if you intend to use Qualcomm IPC router protocol. The + protocol is used to communicate with services provided by other + hardware blocks in the system. + + In order to do service lookups, a userspace daemon is required to + maintain a service listing. + +if QRTR + +config QRTR_SMD + tristate "SMD IPC Router channels" + depends on RPMSG || (COMPILE_TEST && RPMSG=n) + help + Say Y here to support SMD based ipcrouter channels. SMD is the + most common transport for IPC Router. + +config QRTR_TUN + tristate "TUN device for Qualcomm IPC Router" + help + Say Y here to expose a character device that allows user space to + implement endpoints of QRTR, for purpose of tunneling data to other + hosts or testing purposes. + +config QRTR_MHI + tristate "MHI IPC Router channels" + depends on MHI_BUS + help + Say Y here to support MHI based ipcrouter channels. MHI is the + transport used for communicating to external modems. + +endif # QRTR diff --git a/net/qrtr/Makefile b/net/qrtr/Makefile new file mode 100644 index 000000000..8e0605f88 --- /dev/null +++ b/net/qrtr/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_QRTR) += qrtr.o +qrtr-y := af_qrtr.o ns.o + +obj-$(CONFIG_QRTR_SMD) += qrtr-smd.o +qrtr-smd-y := smd.o +obj-$(CONFIG_QRTR_TUN) += qrtr-tun.o +qrtr-tun-y := tun.o +obj-$(CONFIG_QRTR_MHI) += qrtr-mhi.o +qrtr-mhi-y := mhi.o diff --git a/net/qrtr/af_qrtr.c b/net/qrtr/af_qrtr.c new file mode 100644 index 000000000..76f0434d3 --- /dev/null +++ b/net/qrtr/af_qrtr.c @@ -0,0 +1,1324 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015, Sony Mobile Communications Inc. + * Copyright (c) 2013, The Linux Foundation. All rights reserved. 
+ */ +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include + +#include + +#include "qrtr.h" + +#define QRTR_PROTO_VER_1 1 +#define QRTR_PROTO_VER_2 3 + +/* auto-bind range */ +#define QRTR_MIN_EPH_SOCKET 0x4000 +#define QRTR_MAX_EPH_SOCKET 0x7fff +#define QRTR_EPH_PORT_RANGE \ + XA_LIMIT(QRTR_MIN_EPH_SOCKET, QRTR_MAX_EPH_SOCKET) + +/** + * struct qrtr_hdr_v1 - (I|R)PCrouter packet header version 1 + * @version: protocol version + * @type: packet type; one of QRTR_TYPE_* + * @src_node_id: source node + * @src_port_id: source port + * @confirm_rx: boolean; whether a resume-tx packet should be send in reply + * @size: length of packet, excluding this header + * @dst_node_id: destination node + * @dst_port_id: destination port + */ +struct qrtr_hdr_v1 { + __le32 version; + __le32 type; + __le32 src_node_id; + __le32 src_port_id; + __le32 confirm_rx; + __le32 size; + __le32 dst_node_id; + __le32 dst_port_id; +} __packed; + +/** + * struct qrtr_hdr_v2 - (I|R)PCrouter packet header later versions + * @version: protocol version + * @type: packet type; one of QRTR_TYPE_* + * @flags: bitmask of QRTR_FLAGS_* + * @optlen: length of optional header data + * @size: length of packet, excluding this header and optlen + * @src_node_id: source node + * @src_port_id: source port + * @dst_node_id: destination node + * @dst_port_id: destination port + */ +struct qrtr_hdr_v2 { + u8 version; + u8 type; + u8 flags; + u8 optlen; + __le32 size; + __le16 src_node_id; + __le16 src_port_id; + __le16 dst_node_id; + __le16 dst_port_id; +}; + +#define QRTR_FLAGS_CONFIRM_RX BIT(0) + +struct qrtr_cb { + u32 src_node; + u32 src_port; + u32 dst_node; + u32 dst_port; + + u8 type; + u8 confirm_rx; +}; + +#define QRTR_HDR_MAX_SIZE max_t(size_t, sizeof(struct qrtr_hdr_v1), \ + sizeof(struct qrtr_hdr_v2)) + +struct qrtr_sock { + /* WARNING: sk must be the first member */ + struct sock sk; + struct sockaddr_qrtr us; + struct sockaddr_qrtr peer; +}; + +static inline struct qrtr_sock *qrtr_sk(struct sock *sk) +{ + BUILD_BUG_ON(offsetof(struct qrtr_sock, sk) != 0); + return container_of(sk, struct qrtr_sock, sk); +} + +static unsigned int qrtr_local_nid = 1; + +/* for node ids */ +static RADIX_TREE(qrtr_nodes, GFP_ATOMIC); +static DEFINE_SPINLOCK(qrtr_nodes_lock); +/* broadcast list */ +static LIST_HEAD(qrtr_all_nodes); +/* lock for qrtr_all_nodes and node reference */ +static DEFINE_MUTEX(qrtr_node_lock); + +/* local port allocation management */ +static DEFINE_XARRAY_ALLOC(qrtr_ports); + +/** + * struct qrtr_node - endpoint node + * @ep_lock: lock for endpoint management and callbacks + * @ep: endpoint + * @ref: reference count for node + * @nid: node id + * @qrtr_tx_flow: tree of qrtr_tx_flow, keyed by node << 32 | port + * @qrtr_tx_lock: lock for qrtr_tx_flow inserts + * @rx_queue: receive queue + * @item: list item for broadcast list + */ +struct qrtr_node { + struct mutex ep_lock; + struct qrtr_endpoint *ep; + struct kref ref; + unsigned int nid; + + struct radix_tree_root qrtr_tx_flow; + struct mutex qrtr_tx_lock; /* for qrtr_tx_flow */ + + struct sk_buff_head rx_queue; + struct list_head item; +}; + +/** + * struct qrtr_tx_flow - tx flow control + * @resume_tx: waiters for a resume tx from the remote + * @pending: number of waiting senders + * @tx_failed: indicates that a message with confirm_rx flag was lost + */ +struct qrtr_tx_flow { + struct wait_queue_head resume_tx; + int pending; + int tx_failed; +}; + +#define QRTR_TX_FLOW_HIGH 10 +#define QRTR_TX_FLOW_LOW 5 + +static int 
qrtr_local_enqueue(struct qrtr_node *node, struct sk_buff *skb, + int type, struct sockaddr_qrtr *from, + struct sockaddr_qrtr *to); +static int qrtr_bcast_enqueue(struct qrtr_node *node, struct sk_buff *skb, + int type, struct sockaddr_qrtr *from, + struct sockaddr_qrtr *to); +static struct qrtr_sock *qrtr_port_lookup(int port); +static void qrtr_port_put(struct qrtr_sock *ipc); + +/* Release node resources and free the node. + * + * Do not call directly, use qrtr_node_release. To be used with + * kref_put_mutex. As such, the node mutex is expected to be locked on call. + */ +static void __qrtr_node_release(struct kref *kref) +{ + struct qrtr_node *node = container_of(kref, struct qrtr_node, ref); + struct radix_tree_iter iter; + struct qrtr_tx_flow *flow; + unsigned long flags; + void __rcu **slot; + + spin_lock_irqsave(&qrtr_nodes_lock, flags); + /* If the node is a bridge for other nodes, there are possibly + * multiple entries pointing to our released node, delete them all. + */ + radix_tree_for_each_slot(slot, &qrtr_nodes, &iter, 0) { + if (*slot == node) + radix_tree_iter_delete(&qrtr_nodes, &iter, slot); + } + spin_unlock_irqrestore(&qrtr_nodes_lock, flags); + + list_del(&node->item); + mutex_unlock(&qrtr_node_lock); + + skb_queue_purge(&node->rx_queue); + + /* Free tx flow counters */ + radix_tree_for_each_slot(slot, &node->qrtr_tx_flow, &iter, 0) { + flow = *slot; + radix_tree_iter_delete(&node->qrtr_tx_flow, &iter, slot); + kfree(flow); + } + kfree(node); +} + +/* Increment reference to node. */ +static struct qrtr_node *qrtr_node_acquire(struct qrtr_node *node) +{ + if (node) + kref_get(&node->ref); + return node; +} + +/* Decrement reference to node and release as necessary. */ +static void qrtr_node_release(struct qrtr_node *node) +{ + if (!node) + return; + kref_put_mutex(&node->ref, __qrtr_node_release, &qrtr_node_lock); +} + +/** + * qrtr_tx_resume() - reset flow control counter + * @node: qrtr_node that the QRTR_TYPE_RESUME_TX packet arrived on + * @skb: resume_tx packet + */ +static void qrtr_tx_resume(struct qrtr_node *node, struct sk_buff *skb) +{ + struct qrtr_ctrl_pkt *pkt = (struct qrtr_ctrl_pkt *)skb->data; + u64 remote_node = le32_to_cpu(pkt->client.node); + u32 remote_port = le32_to_cpu(pkt->client.port); + struct qrtr_tx_flow *flow; + unsigned long key; + + key = remote_node << 32 | remote_port; + + rcu_read_lock(); + flow = radix_tree_lookup(&node->qrtr_tx_flow, key); + rcu_read_unlock(); + if (flow) { + spin_lock(&flow->resume_tx.lock); + flow->pending = 0; + spin_unlock(&flow->resume_tx.lock); + wake_up_interruptible_all(&flow->resume_tx); + } + + consume_skb(skb); +} + +/** + * qrtr_tx_wait() - flow control for outgoing packets + * @node: qrtr_node that the packet is to be send to + * @dest_node: node id of the destination + * @dest_port: port number of the destination + * @type: type of message + * + * The flow control scheme is based around the low and high "watermarks". When + * the low watermark is passed the confirm_rx flag is set on the outgoing + * message, which will trigger the remote to send a control message of the type + * QRTR_TYPE_RESUME_TX to reset the counter. If the high watermark is hit + * further transmision should be paused. 
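+ *
+ * For instance, with the defaults defined above (QRTR_TX_FLOW_HIGH == 10,
+ * QRTR_TX_FLOW_LOW == 5): each outgoing QRTR_TYPE_DATA message towards a
+ * given destination increments the flow's pending counter, the message
+ * that takes the counter to 5 is sent with confirm_rx set, the sender
+ * blocks once 10 messages are outstanding, and the counter is reset to 0
+ * when the matching QRTR_TYPE_RESUME_TX arrives (see qrtr_tx_resume()).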
+ * + * Return: 1 if confirm_rx should be set, 0 otherwise or errno failure + */ +static int qrtr_tx_wait(struct qrtr_node *node, int dest_node, int dest_port, + int type) +{ + unsigned long key = (u64)dest_node << 32 | dest_port; + struct qrtr_tx_flow *flow; + int confirm_rx = 0; + int ret; + + /* Never set confirm_rx on non-data packets */ + if (type != QRTR_TYPE_DATA) + return 0; + + mutex_lock(&node->qrtr_tx_lock); + flow = radix_tree_lookup(&node->qrtr_tx_flow, key); + if (!flow) { + flow = kzalloc(sizeof(*flow), GFP_KERNEL); + if (flow) { + init_waitqueue_head(&flow->resume_tx); + if (radix_tree_insert(&node->qrtr_tx_flow, key, flow)) { + kfree(flow); + flow = NULL; + } + } + } + mutex_unlock(&node->qrtr_tx_lock); + + /* Set confirm_rx if we where unable to find and allocate a flow */ + if (!flow) + return 1; + + spin_lock_irq(&flow->resume_tx.lock); + ret = wait_event_interruptible_locked_irq(flow->resume_tx, + flow->pending < QRTR_TX_FLOW_HIGH || + flow->tx_failed || + !node->ep); + if (ret < 0) { + confirm_rx = ret; + } else if (!node->ep) { + confirm_rx = -EPIPE; + } else if (flow->tx_failed) { + flow->tx_failed = 0; + confirm_rx = 1; + } else { + flow->pending++; + confirm_rx = flow->pending == QRTR_TX_FLOW_LOW; + } + spin_unlock_irq(&flow->resume_tx.lock); + + return confirm_rx; +} + +/** + * qrtr_tx_flow_failed() - flag that tx of confirm_rx flagged messages failed + * @node: qrtr_node that the packet is to be send to + * @dest_node: node id of the destination + * @dest_port: port number of the destination + * + * Signal that the transmission of a message with confirm_rx flag failed. The + * flow's "pending" counter will keep incrementing towards QRTR_TX_FLOW_HIGH, + * at which point transmission would stall forever waiting for the resume TX + * message associated with the dropped confirm_rx message. + * Work around this by marking the flow as having a failed transmission and + * cause the next transmission attempt to be sent with the confirm_rx. + */ +static void qrtr_tx_flow_failed(struct qrtr_node *node, int dest_node, + int dest_port) +{ + unsigned long key = (u64)dest_node << 32 | dest_port; + struct qrtr_tx_flow *flow; + + rcu_read_lock(); + flow = radix_tree_lookup(&node->qrtr_tx_flow, key); + rcu_read_unlock(); + if (flow) { + spin_lock_irq(&flow->resume_tx.lock); + flow->tx_failed = 1; + spin_unlock_irq(&flow->resume_tx.lock); + } +} + +/* Pass an outgoing packet socket buffer to the endpoint driver. 
*/ +static int qrtr_node_enqueue(struct qrtr_node *node, struct sk_buff *skb, + int type, struct sockaddr_qrtr *from, + struct sockaddr_qrtr *to) +{ + struct qrtr_hdr_v1 *hdr; + size_t len = skb->len; + int rc, confirm_rx; + + confirm_rx = qrtr_tx_wait(node, to->sq_node, to->sq_port, type); + if (confirm_rx < 0) { + kfree_skb(skb); + return confirm_rx; + } + + hdr = skb_push(skb, sizeof(*hdr)); + hdr->version = cpu_to_le32(QRTR_PROTO_VER_1); + hdr->type = cpu_to_le32(type); + hdr->src_node_id = cpu_to_le32(from->sq_node); + hdr->src_port_id = cpu_to_le32(from->sq_port); + if (to->sq_port == QRTR_PORT_CTRL) { + hdr->dst_node_id = cpu_to_le32(node->nid); + hdr->dst_port_id = cpu_to_le32(QRTR_PORT_CTRL); + } else { + hdr->dst_node_id = cpu_to_le32(to->sq_node); + hdr->dst_port_id = cpu_to_le32(to->sq_port); + } + + hdr->size = cpu_to_le32(len); + hdr->confirm_rx = !!confirm_rx; + + rc = skb_put_padto(skb, ALIGN(len, 4) + sizeof(*hdr)); + + if (!rc) { + mutex_lock(&node->ep_lock); + rc = -ENODEV; + if (node->ep) + rc = node->ep->xmit(node->ep, skb); + else + kfree_skb(skb); + mutex_unlock(&node->ep_lock); + } + /* Need to ensure that a subsequent message carries the otherwise lost + * confirm_rx flag if we dropped this one */ + if (rc && confirm_rx) + qrtr_tx_flow_failed(node, to->sq_node, to->sq_port); + + return rc; +} + +/* Lookup node by id. + * + * callers must release with qrtr_node_release() + */ +static struct qrtr_node *qrtr_node_lookup(unsigned int nid) +{ + struct qrtr_node *node; + unsigned long flags; + + mutex_lock(&qrtr_node_lock); + spin_lock_irqsave(&qrtr_nodes_lock, flags); + node = radix_tree_lookup(&qrtr_nodes, nid); + node = qrtr_node_acquire(node); + spin_unlock_irqrestore(&qrtr_nodes_lock, flags); + mutex_unlock(&qrtr_node_lock); + + return node; +} + +/* Assign node id to node. + * + * This is mostly useful for automatic node id assignment, based on + * the source id in the incoming packet. 
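+ *
+ * In practice this means a node registered with QRTR_EP_NID_AUTO takes its
+ * id from the first packet delivered through it, and a single endpoint may
+ * be entered under several node ids when it bridges traffic for more
+ * distant nodes (see the QRTR_TYPE_NEW_SERVER handling in
+ * qrtr_endpoint_post()).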
+ */ +static void qrtr_node_assign(struct qrtr_node *node, unsigned int nid) +{ + unsigned long flags; + + if (nid == QRTR_EP_NID_AUTO) + return; + + spin_lock_irqsave(&qrtr_nodes_lock, flags); + radix_tree_insert(&qrtr_nodes, nid, node); + if (node->nid == QRTR_EP_NID_AUTO) + node->nid = nid; + spin_unlock_irqrestore(&qrtr_nodes_lock, flags); +} + +/** + * qrtr_endpoint_post() - post incoming data + * @ep: endpoint handle + * @data: data pointer + * @len: size of data in bytes + * + * Return: 0 on success; negative error code on failure + */ +int qrtr_endpoint_post(struct qrtr_endpoint *ep, const void *data, size_t len) +{ + struct qrtr_node *node = ep->node; + const struct qrtr_hdr_v1 *v1; + const struct qrtr_hdr_v2 *v2; + struct qrtr_sock *ipc; + struct sk_buff *skb; + struct qrtr_cb *cb; + size_t size; + unsigned int ver; + size_t hdrlen; + + if (len == 0 || len & 3) + return -EINVAL; + + skb = __netdev_alloc_skb(NULL, len, GFP_ATOMIC | __GFP_NOWARN); + if (!skb) + return -ENOMEM; + + cb = (struct qrtr_cb *)skb->cb; + + /* Version field in v1 is little endian, so this works for both cases */ + ver = *(u8*)data; + + switch (ver) { + case QRTR_PROTO_VER_1: + if (len < sizeof(*v1)) + goto err; + v1 = data; + hdrlen = sizeof(*v1); + + cb->type = le32_to_cpu(v1->type); + cb->src_node = le32_to_cpu(v1->src_node_id); + cb->src_port = le32_to_cpu(v1->src_port_id); + cb->confirm_rx = !!v1->confirm_rx; + cb->dst_node = le32_to_cpu(v1->dst_node_id); + cb->dst_port = le32_to_cpu(v1->dst_port_id); + + size = le32_to_cpu(v1->size); + break; + case QRTR_PROTO_VER_2: + if (len < sizeof(*v2)) + goto err; + v2 = data; + hdrlen = sizeof(*v2) + v2->optlen; + + cb->type = v2->type; + cb->confirm_rx = !!(v2->flags & QRTR_FLAGS_CONFIRM_RX); + cb->src_node = le16_to_cpu(v2->src_node_id); + cb->src_port = le16_to_cpu(v2->src_port_id); + cb->dst_node = le16_to_cpu(v2->dst_node_id); + cb->dst_port = le16_to_cpu(v2->dst_port_id); + + if (cb->src_port == (u16)QRTR_PORT_CTRL) + cb->src_port = QRTR_PORT_CTRL; + if (cb->dst_port == (u16)QRTR_PORT_CTRL) + cb->dst_port = QRTR_PORT_CTRL; + + size = le32_to_cpu(v2->size); + break; + default: + pr_err("qrtr: Invalid version %d\n", ver); + goto err; + } + + if (!size || len != ALIGN(size, 4) + hdrlen) + goto err; + + if ((cb->type == QRTR_TYPE_NEW_SERVER || + cb->type == QRTR_TYPE_RESUME_TX) && + size < sizeof(struct qrtr_ctrl_pkt)) + goto err; + + if (cb->dst_port != QRTR_PORT_CTRL && cb->type != QRTR_TYPE_DATA && + cb->type != QRTR_TYPE_RESUME_TX) + goto err; + + skb_put_data(skb, data + hdrlen, size); + + qrtr_node_assign(node, cb->src_node); + + if (cb->type == QRTR_TYPE_NEW_SERVER) { + /* Remote node endpoint can bridge other distant nodes */ + const struct qrtr_ctrl_pkt *pkt; + + pkt = data + hdrlen; + qrtr_node_assign(node, le32_to_cpu(pkt->server.node)); + } + + if (cb->type == QRTR_TYPE_RESUME_TX) { + qrtr_tx_resume(node, skb); + } else { + ipc = qrtr_port_lookup(cb->dst_port); + if (!ipc) + goto err; + + if (sock_queue_rcv_skb(&ipc->sk, skb)) { + qrtr_port_put(ipc); + goto err; + } + + qrtr_port_put(ipc); + } + + return 0; + +err: + kfree_skb(skb); + return -EINVAL; + +} +EXPORT_SYMBOL_GPL(qrtr_endpoint_post); + +/** + * qrtr_alloc_ctrl_packet() - allocate control packet skb + * @pkt: reference to qrtr_ctrl_pkt pointer + * @flags: the type of memory to allocate + * + * Returns newly allocated sk_buff, or NULL on failure + * + * This function allocates a sk_buff large enough to carry a qrtr_ctrl_pkt and + * on success returns a reference to the control packet in 
@pkt. + */ +static struct sk_buff *qrtr_alloc_ctrl_packet(struct qrtr_ctrl_pkt **pkt, + gfp_t flags) +{ + const int pkt_len = sizeof(struct qrtr_ctrl_pkt); + struct sk_buff *skb; + + skb = alloc_skb(QRTR_HDR_MAX_SIZE + pkt_len, flags); + if (!skb) + return NULL; + + skb_reserve(skb, QRTR_HDR_MAX_SIZE); + *pkt = skb_put_zero(skb, pkt_len); + + return skb; +} + +/** + * qrtr_endpoint_register() - register a new endpoint + * @ep: endpoint to register + * @nid: desired node id; may be QRTR_EP_NID_AUTO for auto-assignment + * Return: 0 on success; negative error code on failure + * + * The specified endpoint must have the xmit function pointer set on call. + */ +int qrtr_endpoint_register(struct qrtr_endpoint *ep, unsigned int nid) +{ + struct qrtr_node *node; + + if (!ep || !ep->xmit) + return -EINVAL; + + node = kzalloc(sizeof(*node), GFP_KERNEL); + if (!node) + return -ENOMEM; + + kref_init(&node->ref); + mutex_init(&node->ep_lock); + skb_queue_head_init(&node->rx_queue); + node->nid = QRTR_EP_NID_AUTO; + node->ep = ep; + + INIT_RADIX_TREE(&node->qrtr_tx_flow, GFP_KERNEL); + mutex_init(&node->qrtr_tx_lock); + + qrtr_node_assign(node, nid); + + mutex_lock(&qrtr_node_lock); + list_add(&node->item, &qrtr_all_nodes); + mutex_unlock(&qrtr_node_lock); + ep->node = node; + + return 0; +} +EXPORT_SYMBOL_GPL(qrtr_endpoint_register); + +/** + * qrtr_endpoint_unregister - unregister endpoint + * @ep: endpoint to unregister + */ +void qrtr_endpoint_unregister(struct qrtr_endpoint *ep) +{ + struct qrtr_node *node = ep->node; + struct sockaddr_qrtr src = {AF_QIPCRTR, node->nid, QRTR_PORT_CTRL}; + struct sockaddr_qrtr dst = {AF_QIPCRTR, qrtr_local_nid, QRTR_PORT_CTRL}; + struct radix_tree_iter iter; + struct qrtr_ctrl_pkt *pkt; + struct qrtr_tx_flow *flow; + struct sk_buff *skb; + unsigned long flags; + void __rcu **slot; + + mutex_lock(&node->ep_lock); + node->ep = NULL; + mutex_unlock(&node->ep_lock); + + /* Notify the local controller about the event */ + spin_lock_irqsave(&qrtr_nodes_lock, flags); + radix_tree_for_each_slot(slot, &qrtr_nodes, &iter, 0) { + if (*slot != node) + continue; + src.sq_node = iter.index; + skb = qrtr_alloc_ctrl_packet(&pkt, GFP_ATOMIC); + if (skb) { + pkt->cmd = cpu_to_le32(QRTR_TYPE_BYE); + qrtr_local_enqueue(NULL, skb, QRTR_TYPE_BYE, &src, &dst); + } + } + spin_unlock_irqrestore(&qrtr_nodes_lock, flags); + + /* Wake up any transmitters waiting for resume-tx from the node */ + mutex_lock(&node->qrtr_tx_lock); + radix_tree_for_each_slot(slot, &node->qrtr_tx_flow, &iter, 0) { + flow = *slot; + wake_up_interruptible_all(&flow->resume_tx); + } + mutex_unlock(&node->qrtr_tx_lock); + + qrtr_node_release(node); + ep->node = NULL; +} +EXPORT_SYMBOL_GPL(qrtr_endpoint_unregister); + +/* Lookup socket by port. + * + * Callers must release with qrtr_port_put() + */ +static struct qrtr_sock *qrtr_port_lookup(int port) +{ + struct qrtr_sock *ipc; + + if (port == QRTR_PORT_CTRL) + port = 0; + + rcu_read_lock(); + ipc = xa_load(&qrtr_ports, port); + if (ipc) + sock_hold(&ipc->sk); + rcu_read_unlock(); + + return ipc; +} + +/* Release acquired socket. */ +static void qrtr_port_put(struct qrtr_sock *ipc) +{ + sock_put(&ipc->sk); +} + +/* Remove port assignment. 
*/ +static void qrtr_port_remove(struct qrtr_sock *ipc) +{ + struct qrtr_ctrl_pkt *pkt; + struct sk_buff *skb; + int port = ipc->us.sq_port; + struct sockaddr_qrtr to; + + to.sq_family = AF_QIPCRTR; + to.sq_node = QRTR_NODE_BCAST; + to.sq_port = QRTR_PORT_CTRL; + + skb = qrtr_alloc_ctrl_packet(&pkt, GFP_KERNEL); + if (skb) { + pkt->cmd = cpu_to_le32(QRTR_TYPE_DEL_CLIENT); + pkt->client.node = cpu_to_le32(ipc->us.sq_node); + pkt->client.port = cpu_to_le32(ipc->us.sq_port); + + skb_set_owner_w(skb, &ipc->sk); + qrtr_bcast_enqueue(NULL, skb, QRTR_TYPE_DEL_CLIENT, &ipc->us, + &to); + } + + if (port == QRTR_PORT_CTRL) + port = 0; + + __sock_put(&ipc->sk); + + xa_erase(&qrtr_ports, port); + + /* Ensure that if qrtr_port_lookup() did enter the RCU read section we + * wait for it to up increment the refcount */ + synchronize_rcu(); +} + +/* Assign port number to socket. + * + * Specify port in the integer pointed to by port, and it will be adjusted + * on return as necesssary. + * + * Port may be: + * 0: Assign ephemeral port in [QRTR_MIN_EPH_SOCKET, QRTR_MAX_EPH_SOCKET] + * QRTR_MIN_EPH_SOCKET: Specified; available to all + */ +static int qrtr_port_assign(struct qrtr_sock *ipc, int *port) +{ + int rc; + + if (!*port) { + rc = xa_alloc(&qrtr_ports, port, ipc, QRTR_EPH_PORT_RANGE, + GFP_KERNEL); + } else if (*port < QRTR_MIN_EPH_SOCKET && !capable(CAP_NET_ADMIN)) { + rc = -EACCES; + } else if (*port == QRTR_PORT_CTRL) { + rc = xa_insert(&qrtr_ports, 0, ipc, GFP_KERNEL); + } else { + rc = xa_insert(&qrtr_ports, *port, ipc, GFP_KERNEL); + } + + if (rc == -EBUSY) + return -EADDRINUSE; + else if (rc < 0) + return rc; + + sock_hold(&ipc->sk); + + return 0; +} + +/* Reset all non-control ports */ +static void qrtr_reset_ports(void) +{ + struct qrtr_sock *ipc; + unsigned long index; + + rcu_read_lock(); + xa_for_each_start(&qrtr_ports, index, ipc, 1) { + sock_hold(&ipc->sk); + ipc->sk.sk_err = ENETRESET; + sk_error_report(&ipc->sk); + sock_put(&ipc->sk); + } + rcu_read_unlock(); +} + +/* Bind socket to address. + * + * Socket should be locked upon call. + */ +static int __qrtr_bind(struct socket *sock, + const struct sockaddr_qrtr *addr, int zapped) +{ + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sock *sk = sock->sk; + int port; + int rc; + + /* rebinding ok */ + if (!zapped && addr->sq_port == ipc->us.sq_port) + return 0; + + port = addr->sq_port; + rc = qrtr_port_assign(ipc, &port); + if (rc) + return rc; + + /* unbind previous, if any */ + if (!zapped) + qrtr_port_remove(ipc); + ipc->us.sq_port = port; + + sock_reset_flag(sk, SOCK_ZAPPED); + + /* Notify all open ports about the new controller */ + if (port == QRTR_PORT_CTRL) + qrtr_reset_ports(); + + return 0; +} + +/* Auto bind to an ephemeral port. */ +static int qrtr_autobind(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct sockaddr_qrtr addr; + + if (!sock_flag(sk, SOCK_ZAPPED)) + return 0; + + addr.sq_family = AF_QIPCRTR; + addr.sq_node = qrtr_local_nid; + addr.sq_port = 0; + + return __qrtr_bind(sock, &addr, 1); +} + +/* Bind socket to specified sockaddr. 
*/ +static int qrtr_bind(struct socket *sock, struct sockaddr *saddr, int len) +{ + DECLARE_SOCKADDR(struct sockaddr_qrtr *, addr, saddr); + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sock *sk = sock->sk; + int rc; + + if (len < sizeof(*addr) || addr->sq_family != AF_QIPCRTR) + return -EINVAL; + + if (addr->sq_node != ipc->us.sq_node) + return -EINVAL; + + lock_sock(sk); + rc = __qrtr_bind(sock, addr, sock_flag(sk, SOCK_ZAPPED)); + release_sock(sk); + + return rc; +} + +/* Queue packet to local peer socket. */ +static int qrtr_local_enqueue(struct qrtr_node *node, struct sk_buff *skb, + int type, struct sockaddr_qrtr *from, + struct sockaddr_qrtr *to) +{ + struct qrtr_sock *ipc; + struct qrtr_cb *cb; + + ipc = qrtr_port_lookup(to->sq_port); + if (!ipc || &ipc->sk == skb->sk) { /* do not send to self */ + if (ipc) + qrtr_port_put(ipc); + kfree_skb(skb); + return -ENODEV; + } + + cb = (struct qrtr_cb *)skb->cb; + cb->src_node = from->sq_node; + cb->src_port = from->sq_port; + + if (sock_queue_rcv_skb(&ipc->sk, skb)) { + qrtr_port_put(ipc); + kfree_skb(skb); + return -ENOSPC; + } + + qrtr_port_put(ipc); + + return 0; +} + +/* Queue packet for broadcast. */ +static int qrtr_bcast_enqueue(struct qrtr_node *node, struct sk_buff *skb, + int type, struct sockaddr_qrtr *from, + struct sockaddr_qrtr *to) +{ + struct sk_buff *skbn; + + mutex_lock(&qrtr_node_lock); + list_for_each_entry(node, &qrtr_all_nodes, item) { + skbn = skb_clone(skb, GFP_KERNEL); + if (!skbn) + break; + skb_set_owner_w(skbn, skb->sk); + qrtr_node_enqueue(node, skbn, type, from, to); + } + mutex_unlock(&qrtr_node_lock); + + qrtr_local_enqueue(NULL, skb, type, from, to); + + return 0; +} + +static int qrtr_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_qrtr *, addr, msg->msg_name); + int (*enqueue_fn)(struct qrtr_node *, struct sk_buff *, int, + struct sockaddr_qrtr *, struct sockaddr_qrtr *); + __le32 qrtr_type = cpu_to_le32(QRTR_TYPE_DATA); + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sock *sk = sock->sk; + struct qrtr_node *node; + struct sk_buff *skb; + size_t plen; + u32 type; + int rc; + + if (msg->msg_flags & ~(MSG_DONTWAIT)) + return -EINVAL; + + if (len > 65535) + return -EMSGSIZE; + + lock_sock(sk); + + if (addr) { + if (msg->msg_namelen < sizeof(*addr)) { + release_sock(sk); + return -EINVAL; + } + + if (addr->sq_family != AF_QIPCRTR) { + release_sock(sk); + return -EINVAL; + } + + rc = qrtr_autobind(sock); + if (rc) { + release_sock(sk); + return rc; + } + } else if (sk->sk_state == TCP_ESTABLISHED) { + addr = &ipc->peer; + } else { + release_sock(sk); + return -ENOTCONN; + } + + node = NULL; + if (addr->sq_node == QRTR_NODE_BCAST) { + if (addr->sq_port != QRTR_PORT_CTRL && + qrtr_local_nid != QRTR_NODE_BCAST) { + release_sock(sk); + return -ENOTCONN; + } + enqueue_fn = qrtr_bcast_enqueue; + } else if (addr->sq_node == ipc->us.sq_node) { + enqueue_fn = qrtr_local_enqueue; + } else { + node = qrtr_node_lookup(addr->sq_node); + if (!node) { + release_sock(sk); + return -ECONNRESET; + } + enqueue_fn = qrtr_node_enqueue; + } + + plen = (len + 3) & ~3; + skb = sock_alloc_send_skb(sk, plen + QRTR_HDR_MAX_SIZE, + msg->msg_flags & MSG_DONTWAIT, &rc); + if (!skb) { + rc = -ENOMEM; + goto out_node; + } + + skb_reserve(skb, QRTR_HDR_MAX_SIZE); + + rc = memcpy_from_msg(skb_put(skb, len), msg, len); + if (rc) { + kfree_skb(skb); + goto out_node; + } + + if (ipc->us.sq_port == QRTR_PORT_CTRL) { + if (len < 4) { + rc = -EINVAL; + kfree_skb(skb); + goto out_node; + 
} + + /* control messages already require the type as 'command' */ + skb_copy_bits(skb, 0, &qrtr_type, 4); + } + + type = le32_to_cpu(qrtr_type); + rc = enqueue_fn(node, skb, type, &ipc->us, addr); + if (rc >= 0) + rc = len; + +out_node: + qrtr_node_release(node); + release_sock(sk); + + return rc; +} + +static int qrtr_send_resume_tx(struct qrtr_cb *cb) +{ + struct sockaddr_qrtr remote = { AF_QIPCRTR, cb->src_node, cb->src_port }; + struct sockaddr_qrtr local = { AF_QIPCRTR, cb->dst_node, cb->dst_port }; + struct qrtr_ctrl_pkt *pkt; + struct qrtr_node *node; + struct sk_buff *skb; + int ret; + + node = qrtr_node_lookup(remote.sq_node); + if (!node) + return -EINVAL; + + skb = qrtr_alloc_ctrl_packet(&pkt, GFP_KERNEL); + if (!skb) + return -ENOMEM; + + pkt->cmd = cpu_to_le32(QRTR_TYPE_RESUME_TX); + pkt->client.node = cpu_to_le32(cb->dst_node); + pkt->client.port = cpu_to_le32(cb->dst_port); + + ret = qrtr_node_enqueue(node, skb, QRTR_TYPE_RESUME_TX, &local, &remote); + + qrtr_node_release(node); + + return ret; +} + +static int qrtr_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_qrtr *, addr, msg->msg_name); + struct sock *sk = sock->sk; + struct sk_buff *skb; + struct qrtr_cb *cb; + int copied, rc; + + lock_sock(sk); + + if (sock_flag(sk, SOCK_ZAPPED)) { + release_sock(sk); + return -EADDRNOTAVAIL; + } + + skb = skb_recv_datagram(sk, flags, &rc); + if (!skb) { + release_sock(sk); + return rc; + } + cb = (struct qrtr_cb *)skb->cb; + + copied = skb->len; + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + rc = skb_copy_datagram_msg(skb, 0, msg, copied); + if (rc < 0) + goto out; + rc = copied; + + if (addr) { + /* There is an anonymous 2-byte hole after sq_family, + * make sure to clear it. 
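+ * (struct sockaddr_qrtr places a 16-bit sq_family in front of the
+ * 32-bit sq_node and sq_port fields, so the compiler pads it); without
+ * the memset() those padding bytes would be copied back to userspace
+ * uninitialized.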
+ */ + memset(addr, 0, sizeof(*addr)); + + addr->sq_family = AF_QIPCRTR; + addr->sq_node = cb->src_node; + addr->sq_port = cb->src_port; + msg->msg_namelen = sizeof(*addr); + } + +out: + if (cb->confirm_rx) + qrtr_send_resume_tx(cb); + + skb_free_datagram(sk, skb); + release_sock(sk); + + return rc; +} + +static int qrtr_connect(struct socket *sock, struct sockaddr *saddr, + int len, int flags) +{ + DECLARE_SOCKADDR(struct sockaddr_qrtr *, addr, saddr); + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sock *sk = sock->sk; + int rc; + + if (len < sizeof(*addr) || addr->sq_family != AF_QIPCRTR) + return -EINVAL; + + lock_sock(sk); + + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + rc = qrtr_autobind(sock); + if (rc) { + release_sock(sk); + return rc; + } + + ipc->peer = *addr; + sock->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + + release_sock(sk); + + return 0; +} + +static int qrtr_getname(struct socket *sock, struct sockaddr *saddr, + int peer) +{ + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sockaddr_qrtr qaddr; + struct sock *sk = sock->sk; + + lock_sock(sk); + if (peer) { + if (sk->sk_state != TCP_ESTABLISHED) { + release_sock(sk); + return -ENOTCONN; + } + + qaddr = ipc->peer; + } else { + qaddr = ipc->us; + } + release_sock(sk); + + qaddr.sq_family = AF_QIPCRTR; + + memcpy(saddr, &qaddr, sizeof(qaddr)); + + return sizeof(qaddr); +} + +static int qrtr_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + void __user *argp = (void __user *)arg; + struct qrtr_sock *ipc = qrtr_sk(sock->sk); + struct sock *sk = sock->sk; + struct sockaddr_qrtr *sq; + struct sk_buff *skb; + struct ifreq ifr; + long len = 0; + int rc = 0; + + lock_sock(sk); + + switch (cmd) { + case TIOCOUTQ: + len = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (len < 0) + len = 0; + rc = put_user(len, (int __user *)argp); + break; + case TIOCINQ: + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + len = skb->len; + rc = put_user(len, (int __user *)argp); + break; + case SIOCGIFADDR: + if (get_user_ifreq(&ifr, NULL, argp)) { + rc = -EFAULT; + break; + } + + sq = (struct sockaddr_qrtr *)&ifr.ifr_addr; + *sq = ipc->us; + if (put_user_ifreq(&ifr, argp)) { + rc = -EFAULT; + break; + } + break; + case SIOCADDRT: + case SIOCDELRT: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + rc = -EINVAL; + break; + default: + rc = -ENOIOCTLCMD; + break; + } + + release_sock(sk); + + return rc; +} + +static int qrtr_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct qrtr_sock *ipc; + + if (!sk) + return 0; + + lock_sock(sk); + + ipc = qrtr_sk(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + + sock_set_flag(sk, SOCK_DEAD); + sock_orphan(sk); + sock->sk = NULL; + + if (!sock_flag(sk, SOCK_ZAPPED)) + qrtr_port_remove(ipc); + + skb_queue_purge(&sk->sk_receive_queue); + + release_sock(sk); + sock_put(sk); + + return 0; +} + +static const struct proto_ops qrtr_proto_ops = { + .owner = THIS_MODULE, + .family = AF_QIPCRTR, + .bind = qrtr_bind, + .connect = qrtr_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .listen = sock_no_listen, + .sendmsg = qrtr_sendmsg, + .recvmsg = qrtr_recvmsg, + .getname = qrtr_getname, + .ioctl = qrtr_ioctl, + .gettstamp = sock_gettstamp, + .poll = datagram_poll, + .shutdown = sock_no_shutdown, + .release = qrtr_release, + .mmap = sock_no_mmap, + .sendpage = 
sock_no_sendpage, +}; + +static struct proto qrtr_proto = { + .name = "QIPCRTR", + .owner = THIS_MODULE, + .obj_size = sizeof(struct qrtr_sock), +}; + +static int qrtr_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct qrtr_sock *ipc; + struct sock *sk; + + if (sock->type != SOCK_DGRAM) + return -EPROTOTYPE; + + sk = sk_alloc(net, AF_QIPCRTR, GFP_KERNEL, &qrtr_proto, kern); + if (!sk) + return -ENOMEM; + + sock_set_flag(sk, SOCK_ZAPPED); + + sock_init_data(sock, sk); + sock->ops = &qrtr_proto_ops; + + ipc = qrtr_sk(sk); + ipc->us.sq_family = AF_QIPCRTR; + ipc->us.sq_node = qrtr_local_nid; + ipc->us.sq_port = 0; + + return 0; +} + +static const struct net_proto_family qrtr_family = { + .owner = THIS_MODULE, + .family = AF_QIPCRTR, + .create = qrtr_create, +}; + +static int __init qrtr_proto_init(void) +{ + int rc; + + rc = proto_register(&qrtr_proto, 1); + if (rc) + return rc; + + rc = sock_register(&qrtr_family); + if (rc) + goto err_proto; + + rc = qrtr_ns_init(); + if (rc) + goto err_sock; + + return 0; + +err_sock: + sock_unregister(qrtr_family.family); +err_proto: + proto_unregister(&qrtr_proto); + return rc; +} +postcore_initcall(qrtr_proto_init); + +static void __exit qrtr_proto_fini(void) +{ + qrtr_ns_remove(); + sock_unregister(qrtr_family.family); + proto_unregister(&qrtr_proto); +} +module_exit(qrtr_proto_fini); + +MODULE_DESCRIPTION("Qualcomm IPC-router driver"); +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS_NETPROTO(PF_QIPCRTR); diff --git a/net/qrtr/mhi.c b/net/qrtr/mhi.c new file mode 100644 index 000000000..9ced13c06 --- /dev/null +++ b/net/qrtr/mhi.c @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2018-2020, The Linux Foundation. All rights reserved. + */ + +#include +#include +#include +#include +#include + +#include "qrtr.h" + +struct qrtr_mhi_dev { + struct qrtr_endpoint ep; + struct mhi_device *mhi_dev; + struct device *dev; +}; + +/* From MHI to QRTR */ +static void qcom_mhi_qrtr_dl_callback(struct mhi_device *mhi_dev, + struct mhi_result *mhi_res) +{ + struct qrtr_mhi_dev *qdev = dev_get_drvdata(&mhi_dev->dev); + int rc; + + if (!qdev || mhi_res->transaction_status) + return; + + rc = qrtr_endpoint_post(&qdev->ep, mhi_res->buf_addr, + mhi_res->bytes_xferd); + if (rc == -EINVAL) + dev_err(qdev->dev, "invalid ipcrouter packet\n"); +} + +/* From QRTR to MHI */ +static void qcom_mhi_qrtr_ul_callback(struct mhi_device *mhi_dev, + struct mhi_result *mhi_res) +{ + struct sk_buff *skb = mhi_res->buf_addr; + + if (skb->sk) + sock_put(skb->sk); + consume_skb(skb); +} + +/* Send data over MHI */ +static int qcom_mhi_qrtr_send(struct qrtr_endpoint *ep, struct sk_buff *skb) +{ + struct qrtr_mhi_dev *qdev = container_of(ep, struct qrtr_mhi_dev, ep); + int rc; + + if (skb->sk) + sock_hold(skb->sk); + + rc = skb_linearize(skb); + if (rc) + goto free_skb; + + rc = mhi_queue_skb(qdev->mhi_dev, DMA_TO_DEVICE, skb, skb->len, + MHI_EOT); + if (rc) + goto free_skb; + + return rc; + +free_skb: + if (skb->sk) + sock_put(skb->sk); + kfree_skb(skb); + + return rc; +} + +static int qcom_mhi_qrtr_probe(struct mhi_device *mhi_dev, + const struct mhi_device_id *id) +{ + struct qrtr_mhi_dev *qdev; + int rc; + + qdev = devm_kzalloc(&mhi_dev->dev, sizeof(*qdev), GFP_KERNEL); + if (!qdev) + return -ENOMEM; + + qdev->mhi_dev = mhi_dev; + qdev->dev = &mhi_dev->dev; + qdev->ep.xmit = qcom_mhi_qrtr_send; + + dev_set_drvdata(&mhi_dev->dev, qdev); + rc = qrtr_endpoint_register(&qdev->ep, QRTR_EP_NID_AUTO); + if (rc) + return rc; + + /* start channels 
*/ + rc = mhi_prepare_for_transfer_autoqueue(mhi_dev); + if (rc) { + qrtr_endpoint_unregister(&qdev->ep); + return rc; + } + + dev_dbg(qdev->dev, "Qualcomm MHI QRTR driver probed\n"); + + return 0; +} + +static void qcom_mhi_qrtr_remove(struct mhi_device *mhi_dev) +{ + struct qrtr_mhi_dev *qdev = dev_get_drvdata(&mhi_dev->dev); + + qrtr_endpoint_unregister(&qdev->ep); + mhi_unprepare_from_transfer(mhi_dev); + dev_set_drvdata(&mhi_dev->dev, NULL); +} + +static const struct mhi_device_id qcom_mhi_qrtr_id_table[] = { + { .chan = "IPCR" }, + {} +}; +MODULE_DEVICE_TABLE(mhi, qcom_mhi_qrtr_id_table); + +static struct mhi_driver qcom_mhi_qrtr_driver = { + .probe = qcom_mhi_qrtr_probe, + .remove = qcom_mhi_qrtr_remove, + .dl_xfer_cb = qcom_mhi_qrtr_dl_callback, + .ul_xfer_cb = qcom_mhi_qrtr_ul_callback, + .id_table = qcom_mhi_qrtr_id_table, + .driver = { + .name = "qcom_mhi_qrtr", + }, +}; + +module_mhi_driver(qcom_mhi_qrtr_driver); + +MODULE_AUTHOR("Chris Lew "); +MODULE_AUTHOR("Manivannan Sadhasivam "); +MODULE_DESCRIPTION("Qualcomm IPC-Router MHI interface driver"); +MODULE_LICENSE("GPL v2"); diff --git a/net/qrtr/ns.c b/net/qrtr/ns.c new file mode 100644 index 000000000..4a13b9f7a --- /dev/null +++ b/net/qrtr/ns.c @@ -0,0 +1,830 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015, Sony Mobile Communications Inc. + * Copyright (c) 2013, The Linux Foundation. All rights reserved. + * Copyright (c) 2020, Linaro Ltd. + */ + +#include +#include +#include +#include + +#include "qrtr.h" + +#define CREATE_TRACE_POINTS +#include + +static RADIX_TREE(nodes, GFP_KERNEL); + +static struct { + struct socket *sock; + struct sockaddr_qrtr bcast_sq; + struct list_head lookups; + struct workqueue_struct *workqueue; + struct work_struct work; + int local_node; +} qrtr_ns; + +static const char * const qrtr_ctrl_pkt_strings[] = { + [QRTR_TYPE_HELLO] = "hello", + [QRTR_TYPE_BYE] = "bye", + [QRTR_TYPE_NEW_SERVER] = "new-server", + [QRTR_TYPE_DEL_SERVER] = "del-server", + [QRTR_TYPE_DEL_CLIENT] = "del-client", + [QRTR_TYPE_RESUME_TX] = "resume-tx", + [QRTR_TYPE_EXIT] = "exit", + [QRTR_TYPE_PING] = "ping", + [QRTR_TYPE_NEW_LOOKUP] = "new-lookup", + [QRTR_TYPE_DEL_LOOKUP] = "del-lookup", +}; + +struct qrtr_server_filter { + unsigned int service; + unsigned int instance; + unsigned int ifilter; +}; + +struct qrtr_lookup { + unsigned int service; + unsigned int instance; + + struct sockaddr_qrtr sq; + struct list_head li; +}; + +struct qrtr_server { + unsigned int service; + unsigned int instance; + + unsigned int node; + unsigned int port; + + struct list_head qli; +}; + +struct qrtr_node { + unsigned int id; + struct radix_tree_root servers; +}; + +static struct qrtr_node *node_get(unsigned int node_id) +{ + struct qrtr_node *node; + + node = radix_tree_lookup(&nodes, node_id); + if (node) + return node; + + /* If node didn't exist, allocate and insert it to the tree */ + node = kzalloc(sizeof(*node), GFP_KERNEL); + if (!node) + return NULL; + + node->id = node_id; + + if (radix_tree_insert(&nodes, node_id, node)) { + kfree(node); + return NULL; + } + + return node; +} + +static int server_match(const struct qrtr_server *srv, + const struct qrtr_server_filter *f) +{ + unsigned int ifilter = f->ifilter; + + if (f->service != 0 && srv->service != f->service) + return 0; + if (!ifilter && f->instance) + ifilter = ~0; + + return (srv->instance & ifilter) == f->instance; +} + +static int service_announce_new(struct sockaddr_qrtr *dest, + struct qrtr_server *srv) +{ + struct qrtr_ctrl_pkt pkt; 
+ struct msghdr msg = { }; + struct kvec iv; + + trace_qrtr_ns_service_announce_new(srv->service, srv->instance, + srv->node, srv->port); + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = cpu_to_le32(QRTR_TYPE_NEW_SERVER); + pkt.server.service = cpu_to_le32(srv->service); + pkt.server.instance = cpu_to_le32(srv->instance); + pkt.server.node = cpu_to_le32(srv->node); + pkt.server.port = cpu_to_le32(srv->port); + + msg.msg_name = (struct sockaddr *)dest; + msg.msg_namelen = sizeof(*dest); + + return kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); +} + +static int service_announce_del(struct sockaddr_qrtr *dest, + struct qrtr_server *srv) +{ + struct qrtr_ctrl_pkt pkt; + struct msghdr msg = { }; + struct kvec iv; + int ret; + + trace_qrtr_ns_service_announce_del(srv->service, srv->instance, + srv->node, srv->port); + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = cpu_to_le32(QRTR_TYPE_DEL_SERVER); + pkt.server.service = cpu_to_le32(srv->service); + pkt.server.instance = cpu_to_le32(srv->instance); + pkt.server.node = cpu_to_le32(srv->node); + pkt.server.port = cpu_to_le32(srv->port); + + msg.msg_name = (struct sockaddr *)dest; + msg.msg_namelen = sizeof(*dest); + + ret = kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); + if (ret < 0) + pr_err("failed to announce del service\n"); + + return ret; +} + +static void lookup_notify(struct sockaddr_qrtr *to, struct qrtr_server *srv, + bool new) +{ + struct qrtr_ctrl_pkt pkt; + struct msghdr msg = { }; + struct kvec iv; + int ret; + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = new ? cpu_to_le32(QRTR_TYPE_NEW_SERVER) : + cpu_to_le32(QRTR_TYPE_DEL_SERVER); + if (srv) { + pkt.server.service = cpu_to_le32(srv->service); + pkt.server.instance = cpu_to_le32(srv->instance); + pkt.server.node = cpu_to_le32(srv->node); + pkt.server.port = cpu_to_le32(srv->port); + } + + msg.msg_name = (struct sockaddr *)to; + msg.msg_namelen = sizeof(*to); + + ret = kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); + if (ret < 0) + pr_err("failed to send lookup notification\n"); +} + +static int announce_servers(struct sockaddr_qrtr *sq) +{ + struct radix_tree_iter iter; + struct qrtr_server *srv; + struct qrtr_node *node; + void __rcu **slot; + int ret; + + node = node_get(qrtr_ns.local_node); + if (!node) + return 0; + + rcu_read_lock(); + /* Announce the list of servers registered in this node */ + radix_tree_for_each_slot(slot, &node->servers, &iter, 0) { + srv = radix_tree_deref_slot(slot); + if (!srv) + continue; + if (radix_tree_deref_retry(srv)) { + slot = radix_tree_iter_retry(&iter); + continue; + } + slot = radix_tree_iter_resume(slot, &iter); + rcu_read_unlock(); + + ret = service_announce_new(sq, srv); + if (ret < 0) { + pr_err("failed to announce new service\n"); + return ret; + } + + rcu_read_lock(); + } + + rcu_read_unlock(); + + return 0; +} + +static struct qrtr_server *server_add(unsigned int service, + unsigned int instance, + unsigned int node_id, + unsigned int port) +{ + struct qrtr_server *srv; + struct qrtr_server *old; + struct qrtr_node *node; + + if (!service || !port) + return NULL; + + srv = kzalloc(sizeof(*srv), GFP_KERNEL); + if (!srv) + return NULL; + + srv->service = service; + srv->instance = instance; + srv->node = node_id; + srv->port = port; + + node = node_get(node_id); + if (!node) + goto err; + + /* Delete the old server on the same port */ + old = 
radix_tree_lookup(&node->servers, port); + if (old) { + radix_tree_delete(&node->servers, port); + kfree(old); + } + + radix_tree_insert(&node->servers, port, srv); + + trace_qrtr_ns_server_add(srv->service, srv->instance, + srv->node, srv->port); + + return srv; + +err: + kfree(srv); + return NULL; +} + +static int server_del(struct qrtr_node *node, unsigned int port, bool bcast) +{ + struct qrtr_lookup *lookup; + struct qrtr_server *srv; + struct list_head *li; + + srv = radix_tree_lookup(&node->servers, port); + if (!srv) + return -ENOENT; + + radix_tree_delete(&node->servers, port); + + /* Broadcast the removal of local servers */ + if (srv->node == qrtr_ns.local_node && bcast) + service_announce_del(&qrtr_ns.bcast_sq, srv); + + /* Announce the service's disappearance to observers */ + list_for_each(li, &qrtr_ns.lookups) { + lookup = container_of(li, struct qrtr_lookup, li); + if (lookup->service && lookup->service != srv->service) + continue; + if (lookup->instance && lookup->instance != srv->instance) + continue; + + lookup_notify(&lookup->sq, srv, false); + } + + kfree(srv); + + return 0; +} + +static int say_hello(struct sockaddr_qrtr *dest) +{ + struct qrtr_ctrl_pkt pkt; + struct msghdr msg = { }; + struct kvec iv; + int ret; + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = cpu_to_le32(QRTR_TYPE_HELLO); + + msg.msg_name = (struct sockaddr *)dest; + msg.msg_namelen = sizeof(*dest); + + ret = kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); + if (ret < 0) + pr_err("failed to send hello msg\n"); + + return ret; +} + +/* Announce the list of servers registered on the local node */ +static int ctrl_cmd_hello(struct sockaddr_qrtr *sq) +{ + int ret; + + ret = say_hello(sq); + if (ret < 0) + return ret; + + return announce_servers(sq); +} + +static int ctrl_cmd_bye(struct sockaddr_qrtr *from) +{ + struct qrtr_node *local_node; + struct radix_tree_iter iter; + struct qrtr_ctrl_pkt pkt; + struct qrtr_server *srv; + struct sockaddr_qrtr sq; + struct msghdr msg = { }; + struct qrtr_node *node; + void __rcu **slot; + struct kvec iv; + int ret; + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + node = node_get(from->sq_node); + if (!node) + return 0; + + rcu_read_lock(); + /* Advertise removal of this client to all servers of remote node */ + radix_tree_for_each_slot(slot, &node->servers, &iter, 0) { + srv = radix_tree_deref_slot(slot); + if (!srv) + continue; + if (radix_tree_deref_retry(srv)) { + slot = radix_tree_iter_retry(&iter); + continue; + } + slot = radix_tree_iter_resume(slot, &iter); + rcu_read_unlock(); + server_del(node, srv->port, true); + rcu_read_lock(); + } + rcu_read_unlock(); + + /* Advertise the removal of this client to all local servers */ + local_node = node_get(qrtr_ns.local_node); + if (!local_node) + return 0; + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = cpu_to_le32(QRTR_TYPE_BYE); + pkt.client.node = cpu_to_le32(from->sq_node); + + rcu_read_lock(); + radix_tree_for_each_slot(slot, &local_node->servers, &iter, 0) { + srv = radix_tree_deref_slot(slot); + if (!srv) + continue; + if (radix_tree_deref_retry(srv)) { + slot = radix_tree_iter_retry(&iter); + continue; + } + slot = radix_tree_iter_resume(slot, &iter); + rcu_read_unlock(); + + sq.sq_family = AF_QIPCRTR; + sq.sq_node = srv->node; + sq.sq_port = srv->port; + + msg.msg_name = (struct sockaddr *)&sq; + msg.msg_namelen = sizeof(sq); + + ret = kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); + if (ret < 0) { + pr_err("failed to send bye cmd\n"); 
+ return ret; + } + rcu_read_lock(); + } + + rcu_read_unlock(); + + return 0; +} + +static int ctrl_cmd_del_client(struct sockaddr_qrtr *from, + unsigned int node_id, unsigned int port) +{ + struct qrtr_node *local_node; + struct radix_tree_iter iter; + struct qrtr_lookup *lookup; + struct qrtr_ctrl_pkt pkt; + struct msghdr msg = { }; + struct qrtr_server *srv; + struct sockaddr_qrtr sq; + struct qrtr_node *node; + struct list_head *tmp; + struct list_head *li; + void __rcu **slot; + struct kvec iv; + int ret; + + iv.iov_base = &pkt; + iv.iov_len = sizeof(pkt); + + /* Don't accept spoofed messages */ + if (from->sq_node != node_id) + return -EINVAL; + + /* Local DEL_CLIENT messages comes from the port being closed */ + if (from->sq_node == qrtr_ns.local_node && from->sq_port != port) + return -EINVAL; + + /* Remove any lookups by this client */ + list_for_each_safe(li, tmp, &qrtr_ns.lookups) { + lookup = container_of(li, struct qrtr_lookup, li); + if (lookup->sq.sq_node != node_id) + continue; + if (lookup->sq.sq_port != port) + continue; + + list_del(&lookup->li); + kfree(lookup); + } + + /* Remove the server belonging to this port but don't broadcast + * DEL_SERVER. Neighbours would've already removed the server belonging + * to this port due to the DEL_CLIENT broadcast from qrtr_port_remove(). + */ + node = node_get(node_id); + if (node) + server_del(node, port, false); + + /* Advertise the removal of this client to all local servers */ + local_node = node_get(qrtr_ns.local_node); + if (!local_node) + return 0; + + memset(&pkt, 0, sizeof(pkt)); + pkt.cmd = cpu_to_le32(QRTR_TYPE_DEL_CLIENT); + pkt.client.node = cpu_to_le32(node_id); + pkt.client.port = cpu_to_le32(port); + + rcu_read_lock(); + radix_tree_for_each_slot(slot, &local_node->servers, &iter, 0) { + srv = radix_tree_deref_slot(slot); + if (!srv) + continue; + if (radix_tree_deref_retry(srv)) { + slot = radix_tree_iter_retry(&iter); + continue; + } + slot = radix_tree_iter_resume(slot, &iter); + rcu_read_unlock(); + + sq.sq_family = AF_QIPCRTR; + sq.sq_node = srv->node; + sq.sq_port = srv->port; + + msg.msg_name = (struct sockaddr *)&sq; + msg.msg_namelen = sizeof(sq); + + ret = kernel_sendmsg(qrtr_ns.sock, &msg, &iv, 1, sizeof(pkt)); + if (ret < 0) { + pr_err("failed to send del client cmd\n"); + return ret; + } + rcu_read_lock(); + } + + rcu_read_unlock(); + + return 0; +} + +static int ctrl_cmd_new_server(struct sockaddr_qrtr *from, + unsigned int service, unsigned int instance, + unsigned int node_id, unsigned int port) +{ + struct qrtr_lookup *lookup; + struct qrtr_server *srv; + struct list_head *li; + int ret = 0; + + /* Ignore specified node and port for local servers */ + if (from->sq_node == qrtr_ns.local_node) { + node_id = from->sq_node; + port = from->sq_port; + } + + srv = server_add(service, instance, node_id, port); + if (!srv) + return -EINVAL; + + if (srv->node == qrtr_ns.local_node) { + ret = service_announce_new(&qrtr_ns.bcast_sq, srv); + if (ret < 0) { + pr_err("failed to announce new service\n"); + return ret; + } + } + + /* Notify any potential lookups about the new server */ + list_for_each(li, &qrtr_ns.lookups) { + lookup = container_of(li, struct qrtr_lookup, li); + if (lookup->service && lookup->service != service) + continue; + if (lookup->instance && lookup->instance != instance) + continue; + + lookup_notify(&lookup->sq, srv, true); + } + + return ret; +} + +static int ctrl_cmd_del_server(struct sockaddr_qrtr *from, + unsigned int service, unsigned int instance, + unsigned int node_id, unsigned int 
port) +{ + struct qrtr_node *node; + + /* Ignore specified node and port for local servers*/ + if (from->sq_node == qrtr_ns.local_node) { + node_id = from->sq_node; + port = from->sq_port; + } + + /* Local servers may only unregister themselves */ + if (from->sq_node == qrtr_ns.local_node && from->sq_port != port) + return -EINVAL; + + node = node_get(node_id); + if (!node) + return -ENOENT; + + server_del(node, port, true); + + return 0; +} + +static int ctrl_cmd_new_lookup(struct sockaddr_qrtr *from, + unsigned int service, unsigned int instance) +{ + struct radix_tree_iter node_iter; + struct qrtr_server_filter filter; + struct radix_tree_iter srv_iter; + struct qrtr_lookup *lookup; + struct qrtr_node *node; + void __rcu **node_slot; + void __rcu **srv_slot; + + /* Accept only local observers */ + if (from->sq_node != qrtr_ns.local_node) + return -EINVAL; + + lookup = kzalloc(sizeof(*lookup), GFP_KERNEL); + if (!lookup) + return -ENOMEM; + + lookup->sq = *from; + lookup->service = service; + lookup->instance = instance; + list_add_tail(&lookup->li, &qrtr_ns.lookups); + + memset(&filter, 0, sizeof(filter)); + filter.service = service; + filter.instance = instance; + + rcu_read_lock(); + radix_tree_for_each_slot(node_slot, &nodes, &node_iter, 0) { + node = radix_tree_deref_slot(node_slot); + if (!node) + continue; + if (radix_tree_deref_retry(node)) { + node_slot = radix_tree_iter_retry(&node_iter); + continue; + } + node_slot = radix_tree_iter_resume(node_slot, &node_iter); + + radix_tree_for_each_slot(srv_slot, &node->servers, + &srv_iter, 0) { + struct qrtr_server *srv; + + srv = radix_tree_deref_slot(srv_slot); + if (!srv) + continue; + if (radix_tree_deref_retry(srv)) { + srv_slot = radix_tree_iter_retry(&srv_iter); + continue; + } + + if (!server_match(srv, &filter)) + continue; + + srv_slot = radix_tree_iter_resume(srv_slot, &srv_iter); + + rcu_read_unlock(); + lookup_notify(from, srv, true); + rcu_read_lock(); + } + } + rcu_read_unlock(); + + /* Empty notification, to indicate end of listing */ + lookup_notify(from, NULL, true); + + return 0; +} + +static void ctrl_cmd_del_lookup(struct sockaddr_qrtr *from, + unsigned int service, unsigned int instance) +{ + struct qrtr_lookup *lookup; + struct list_head *tmp; + struct list_head *li; + + list_for_each_safe(li, tmp, &qrtr_ns.lookups) { + lookup = container_of(li, struct qrtr_lookup, li); + if (lookup->sq.sq_node != from->sq_node) + continue; + if (lookup->sq.sq_port != from->sq_port) + continue; + if (lookup->service != service) + continue; + if (lookup->instance && lookup->instance != instance) + continue; + + list_del(&lookup->li); + kfree(lookup); + } +} + +static void qrtr_ns_worker(struct work_struct *work) +{ + const struct qrtr_ctrl_pkt *pkt; + size_t recv_buf_size = 4096; + struct sockaddr_qrtr sq; + struct msghdr msg = { }; + unsigned int cmd; + ssize_t msglen; + void *recv_buf; + struct kvec iv; + int ret; + + msg.msg_name = (struct sockaddr *)&sq; + msg.msg_namelen = sizeof(sq); + + recv_buf = kzalloc(recv_buf_size, GFP_KERNEL); + if (!recv_buf) + return; + + for (;;) { + iv.iov_base = recv_buf; + iv.iov_len = recv_buf_size; + + msglen = kernel_recvmsg(qrtr_ns.sock, &msg, &iv, 1, + iv.iov_len, MSG_DONTWAIT); + + if (msglen == -EAGAIN) + break; + + if (msglen < 0) { + pr_err("error receiving packet: %zd\n", msglen); + break; + } + + pkt = recv_buf; + cmd = le32_to_cpu(pkt->cmd); + if (cmd < ARRAY_SIZE(qrtr_ctrl_pkt_strings) && + qrtr_ctrl_pkt_strings[cmd]) + trace_qrtr_ns_message(qrtr_ctrl_pkt_strings[cmd], + sq.sq_node, 
sq.sq_port); + + ret = 0; + switch (cmd) { + case QRTR_TYPE_HELLO: + ret = ctrl_cmd_hello(&sq); + break; + case QRTR_TYPE_BYE: + ret = ctrl_cmd_bye(&sq); + break; + case QRTR_TYPE_DEL_CLIENT: + ret = ctrl_cmd_del_client(&sq, + le32_to_cpu(pkt->client.node), + le32_to_cpu(pkt->client.port)); + break; + case QRTR_TYPE_NEW_SERVER: + ret = ctrl_cmd_new_server(&sq, + le32_to_cpu(pkt->server.service), + le32_to_cpu(pkt->server.instance), + le32_to_cpu(pkt->server.node), + le32_to_cpu(pkt->server.port)); + break; + case QRTR_TYPE_DEL_SERVER: + ret = ctrl_cmd_del_server(&sq, + le32_to_cpu(pkt->server.service), + le32_to_cpu(pkt->server.instance), + le32_to_cpu(pkt->server.node), + le32_to_cpu(pkt->server.port)); + break; + case QRTR_TYPE_EXIT: + case QRTR_TYPE_PING: + case QRTR_TYPE_RESUME_TX: + break; + case QRTR_TYPE_NEW_LOOKUP: + ret = ctrl_cmd_new_lookup(&sq, + le32_to_cpu(pkt->server.service), + le32_to_cpu(pkt->server.instance)); + break; + case QRTR_TYPE_DEL_LOOKUP: + ctrl_cmd_del_lookup(&sq, + le32_to_cpu(pkt->server.service), + le32_to_cpu(pkt->server.instance)); + break; + } + + if (ret < 0) + pr_err("failed while handling packet from %d:%d", + sq.sq_node, sq.sq_port); + } + + kfree(recv_buf); +} + +static void qrtr_ns_data_ready(struct sock *sk) +{ + queue_work(qrtr_ns.workqueue, &qrtr_ns.work); +} + +int qrtr_ns_init(void) +{ + struct sockaddr_qrtr sq; + int ret; + + INIT_LIST_HEAD(&qrtr_ns.lookups); + INIT_WORK(&qrtr_ns.work, qrtr_ns_worker); + + ret = sock_create_kern(&init_net, AF_QIPCRTR, SOCK_DGRAM, + PF_QIPCRTR, &qrtr_ns.sock); + if (ret < 0) + return ret; + + ret = kernel_getsockname(qrtr_ns.sock, (struct sockaddr *)&sq); + if (ret < 0) { + pr_err("failed to get socket name\n"); + goto err_sock; + } + + qrtr_ns.workqueue = alloc_workqueue("qrtr_ns_handler", WQ_UNBOUND, 1); + if (!qrtr_ns.workqueue) { + ret = -ENOMEM; + goto err_sock; + } + + qrtr_ns.sock->sk->sk_data_ready = qrtr_ns_data_ready; + + sq.sq_port = QRTR_PORT_CTRL; + qrtr_ns.local_node = sq.sq_node; + + ret = kernel_bind(qrtr_ns.sock, (struct sockaddr *)&sq, sizeof(sq)); + if (ret < 0) { + pr_err("failed to bind to socket\n"); + goto err_wq; + } + + qrtr_ns.bcast_sq.sq_family = AF_QIPCRTR; + qrtr_ns.bcast_sq.sq_node = QRTR_NODE_BCAST; + qrtr_ns.bcast_sq.sq_port = QRTR_PORT_CTRL; + + ret = say_hello(&qrtr_ns.bcast_sq); + if (ret < 0) + goto err_wq; + + return 0; + +err_wq: + destroy_workqueue(qrtr_ns.workqueue); +err_sock: + sock_release(qrtr_ns.sock); + return ret; +} +EXPORT_SYMBOL_GPL(qrtr_ns_init); + +void qrtr_ns_remove(void) +{ + cancel_work_sync(&qrtr_ns.work); + destroy_workqueue(qrtr_ns.workqueue); + sock_release(qrtr_ns.sock); +} +EXPORT_SYMBOL_GPL(qrtr_ns_remove); + +MODULE_AUTHOR("Manivannan Sadhasivam "); +MODULE_DESCRIPTION("Qualcomm IPC Router Nameservice"); +MODULE_LICENSE("Dual BSD/GPL"); diff --git a/net/qrtr/qrtr.h b/net/qrtr/qrtr.h new file mode 100644 index 000000000..3f2d28696 --- /dev/null +++ b/net/qrtr/qrtr.h @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __QRTR_H_ +#define __QRTR_H_ + +#include + +struct sk_buff; + +/* endpoint node id auto assignment */ +#define QRTR_EP_NID_AUTO (-1) + +/** + * struct qrtr_endpoint - endpoint handle + * @xmit: Callback for outgoing packets + * + * The socket buffer passed to the xmit function becomes owned by the endpoint + * driver. As such, when the driver is done with the buffer, it should + * call kfree_skb() on failure, or consume_skb() on success. 
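/*
 * Illustrative sketch (not part of this patch): a minimal ->xmit()
 * implementation honouring the ownership rule described above, modelled on
 * the SMD endpoint later in this patch.  my_hw_send() is a placeholder for
 * whatever mechanism the endpoint driver uses to push the bytes out.
 */
static int my_ep_xmit(struct qrtr_endpoint *ep, struct sk_buff *skb)
{
        int rc;

        rc = skb_linearize(skb);
        if (rc)
                goto out;

        rc = my_hw_send(skb->data, skb->len);   /* hypothetical transmit hook */
out:
        if (rc)
                kfree_skb(skb);         /* failure: release the buffer */
        else
                consume_skb(skb);       /* success: drop our reference */
        return rc;
}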
+ */ +struct qrtr_endpoint { + int (*xmit)(struct qrtr_endpoint *ep, struct sk_buff *skb); + /* private: not for endpoint use */ + struct qrtr_node *node; +}; + +int qrtr_endpoint_register(struct qrtr_endpoint *ep, unsigned int nid); + +void qrtr_endpoint_unregister(struct qrtr_endpoint *ep); + +int qrtr_endpoint_post(struct qrtr_endpoint *ep, const void *data, size_t len); + +int qrtr_ns_init(void); + +void qrtr_ns_remove(void); + +#endif diff --git a/net/qrtr/smd.c b/net/qrtr/smd.c new file mode 100644 index 000000000..c91bf030f --- /dev/null +++ b/net/qrtr/smd.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2015, Sony Mobile Communications Inc. + * Copyright (c) 2013, The Linux Foundation. All rights reserved. + */ + +#include +#include +#include + +#include "qrtr.h" + +struct qrtr_smd_dev { + struct qrtr_endpoint ep; + struct rpmsg_endpoint *channel; + struct device *dev; +}; + +/* from smd to qrtr */ +static int qcom_smd_qrtr_callback(struct rpmsg_device *rpdev, + void *data, int len, void *priv, u32 addr) +{ + struct qrtr_smd_dev *qdev = dev_get_drvdata(&rpdev->dev); + int rc; + + if (!qdev) + return -EAGAIN; + + rc = qrtr_endpoint_post(&qdev->ep, data, len); + if (rc == -EINVAL) { + dev_err(qdev->dev, "invalid ipcrouter packet\n"); + /* return 0 to let smd drop the packet */ + rc = 0; + } + + return rc; +} + +/* from qrtr to smd */ +static int qcom_smd_qrtr_send(struct qrtr_endpoint *ep, struct sk_buff *skb) +{ + struct qrtr_smd_dev *qdev = container_of(ep, struct qrtr_smd_dev, ep); + int rc; + + rc = skb_linearize(skb); + if (rc) + goto out; + + rc = rpmsg_send(qdev->channel, skb->data, skb->len); + +out: + if (rc) + kfree_skb(skb); + else + consume_skb(skb); + return rc; +} + +static int qcom_smd_qrtr_probe(struct rpmsg_device *rpdev) +{ + struct qrtr_smd_dev *qdev; + int rc; + + qdev = devm_kzalloc(&rpdev->dev, sizeof(*qdev), GFP_KERNEL); + if (!qdev) + return -ENOMEM; + + qdev->channel = rpdev->ept; + qdev->dev = &rpdev->dev; + qdev->ep.xmit = qcom_smd_qrtr_send; + + rc = qrtr_endpoint_register(&qdev->ep, QRTR_EP_NID_AUTO); + if (rc) + return rc; + + dev_set_drvdata(&rpdev->dev, qdev); + + dev_dbg(&rpdev->dev, "Qualcomm SMD QRTR driver probed\n"); + + return 0; +} + +static void qcom_smd_qrtr_remove(struct rpmsg_device *rpdev) +{ + struct qrtr_smd_dev *qdev = dev_get_drvdata(&rpdev->dev); + + qrtr_endpoint_unregister(&qdev->ep); + + dev_set_drvdata(&rpdev->dev, NULL); +} + +static const struct rpmsg_device_id qcom_smd_qrtr_smd_match[] = { + { "IPCRTR" }, + {} +}; + +static struct rpmsg_driver qcom_smd_qrtr_driver = { + .probe = qcom_smd_qrtr_probe, + .remove = qcom_smd_qrtr_remove, + .callback = qcom_smd_qrtr_callback, + .id_table = qcom_smd_qrtr_smd_match, + .drv = { + .name = "qcom_smd_qrtr", + }, +}; + +module_rpmsg_driver(qcom_smd_qrtr_driver); + +MODULE_ALIAS("rpmsg:IPCRTR"); +MODULE_DESCRIPTION("Qualcomm IPC-Router SMD interface driver"); +MODULE_LICENSE("GPL v2"); diff --git a/net/qrtr/tun.c b/net/qrtr/tun.c new file mode 100644 index 000000000..304b41fea --- /dev/null +++ b/net/qrtr/tun.c @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2018, Linaro Ltd */ + +#include +#include +#include +#include +#include + +#include "qrtr.h" + +struct qrtr_tun { + struct qrtr_endpoint ep; + + struct sk_buff_head queue; + wait_queue_head_t readq; +}; + +static int qrtr_tun_send(struct qrtr_endpoint *ep, struct sk_buff *skb) +{ + struct qrtr_tun *tun = container_of(ep, struct qrtr_tun, ep); + + skb_queue_tail(&tun->queue, skb); 
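/*
 * Illustrative sketch (not part of this patch): driving the /dev/qrtr-tun
 * character device registered at the end of this file from user space.
 * Writes are injected into the router via qrtr_endpoint_post(); reads return
 * whole packets routed to this endpoint, in the raw IPC-router wire format.
 */
#include <fcntl.h>
#include <poll.h>
#include <unistd.h>

static int tun_pump(void)
{
        unsigned char buf[4096];
        struct pollfd pfd;
        ssize_t len;
        int fd;

        fd = open("/dev/qrtr-tun", O_RDWR);
        if (fd < 0)
                return -1;

        pfd.fd = fd;
        pfd.events = POLLIN;

        /* Block until the router queues a packet for us, then read it. */
        while (poll(&pfd, 1, -1) > 0) {
                len = read(fd, buf, sizeof(buf));
                if (len <= 0)
                        break;
                /* ... parse or forward the packet, or write() one back ... */
        }

        close(fd);
        return 0;
}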
+ + /* wake up any blocking processes, waiting for new data */ + wake_up_interruptible(&tun->readq); + + return 0; +} + +static int qrtr_tun_open(struct inode *inode, struct file *filp) +{ + struct qrtr_tun *tun; + int ret; + + tun = kzalloc(sizeof(*tun), GFP_KERNEL); + if (!tun) + return -ENOMEM; + + skb_queue_head_init(&tun->queue); + init_waitqueue_head(&tun->readq); + + tun->ep.xmit = qrtr_tun_send; + + filp->private_data = tun; + + ret = qrtr_endpoint_register(&tun->ep, QRTR_EP_NID_AUTO); + if (ret) + goto out; + + return 0; + +out: + filp->private_data = NULL; + kfree(tun); + return ret; +} + +static ssize_t qrtr_tun_read_iter(struct kiocb *iocb, struct iov_iter *to) +{ + struct file *filp = iocb->ki_filp; + struct qrtr_tun *tun = filp->private_data; + struct sk_buff *skb; + int count; + + while (!(skb = skb_dequeue(&tun->queue))) { + if (filp->f_flags & O_NONBLOCK) + return -EAGAIN; + + /* Wait until we get data or the endpoint goes away */ + if (wait_event_interruptible(tun->readq, + !skb_queue_empty(&tun->queue))) + return -ERESTARTSYS; + } + + count = min_t(size_t, iov_iter_count(to), skb->len); + if (copy_to_iter(skb->data, count, to) != count) + count = -EFAULT; + + kfree_skb(skb); + + return count; +} + +static ssize_t qrtr_tun_write_iter(struct kiocb *iocb, struct iov_iter *from) +{ + struct file *filp = iocb->ki_filp; + struct qrtr_tun *tun = filp->private_data; + size_t len = iov_iter_count(from); + ssize_t ret; + void *kbuf; + + if (!len) + return -EINVAL; + + if (len > KMALLOC_MAX_SIZE) + return -ENOMEM; + + kbuf = kzalloc(len, GFP_KERNEL); + if (!kbuf) + return -ENOMEM; + + if (!copy_from_iter_full(kbuf, len, from)) { + kfree(kbuf); + return -EFAULT; + } + + ret = qrtr_endpoint_post(&tun->ep, kbuf, len); + + kfree(kbuf); + return ret < 0 ? ret : len; +} + +static __poll_t qrtr_tun_poll(struct file *filp, poll_table *wait) +{ + struct qrtr_tun *tun = filp->private_data; + __poll_t mask = 0; + + poll_wait(filp, &tun->readq, wait); + + if (!skb_queue_empty(&tun->queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + return mask; +} + +static int qrtr_tun_release(struct inode *inode, struct file *filp) +{ + struct qrtr_tun *tun = filp->private_data; + + qrtr_endpoint_unregister(&tun->ep); + + /* Discard all SKBs */ + skb_queue_purge(&tun->queue); + + kfree(tun); + + return 0; +} + +static const struct file_operations qrtr_tun_ops = { + .owner = THIS_MODULE, + .open = qrtr_tun_open, + .poll = qrtr_tun_poll, + .read_iter = qrtr_tun_read_iter, + .write_iter = qrtr_tun_write_iter, + .release = qrtr_tun_release, +}; + +static struct miscdevice qrtr_tun_miscdev = { + MISC_DYNAMIC_MINOR, + "qrtr-tun", + &qrtr_tun_ops, +}; + +static int __init qrtr_tun_init(void) +{ + int ret; + + ret = misc_register(&qrtr_tun_miscdev); + if (ret) + pr_err("failed to register Qualcomm IPC Router tun device\n"); + + return ret; +} + +static void __exit qrtr_tun_exit(void) +{ + misc_deregister(&qrtr_tun_miscdev); +} + +module_init(qrtr_tun_init); +module_exit(qrtr_tun_exit); + +MODULE_DESCRIPTION("Qualcomm IPC Router TUN device"); +MODULE_LICENSE("GPL v2"); diff --git a/net/rds/Kconfig b/net/rds/Kconfig new file mode 100644 index 000000000..75cd69696 --- /dev/null +++ b/net/rds/Kconfig @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: GPL-2.0-only + +config RDS + tristate "The Reliable Datagram Sockets Protocol" + depends on INET + help + The RDS (Reliable Datagram Sockets) protocol provides reliable, + sequenced delivery of datagrams over Infiniband or TCP. 
+ +config RDS_RDMA + tristate "RDS over Infiniband" + depends on RDS && INFINIBAND && INFINIBAND_ADDR_TRANS + help + Allow RDS to use Infiniband as a transport. + This transport supports RDMA operations. + +config RDS_TCP + tristate "RDS over TCP" + depends on RDS + depends on IPV6 || !IPV6 + help + Allow RDS to use TCP as a transport. + This transport does not support RDMA operations. + +config RDS_DEBUG + bool "RDS debugging messages" + depends on RDS + default n diff --git a/net/rds/Makefile b/net/rds/Makefile new file mode 100644 index 000000000..8fdc118e2 --- /dev/null +++ b/net/rds/Makefile @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_RDS) += rds.o +rds-y := af_rds.o bind.o cong.o connection.o info.o message.o \ + recv.o send.o stats.o sysctl.o threads.o transport.o \ + loop.o page.o rdma.o + +obj-$(CONFIG_RDS_RDMA) += rds_rdma.o +rds_rdma-y := rdma_transport.o \ + ib.o ib_cm.o ib_recv.o ib_ring.o ib_send.o ib_stats.o \ + ib_sysctl.o ib_rdma.o ib_frmr.o + + +obj-$(CONFIG_RDS_TCP) += rds_tcp.o +rds_tcp-y := tcp.o tcp_connect.o tcp_listen.o tcp_recv.o \ + tcp_send.o tcp_stats.o + +ccflags-$(CONFIG_RDS_DEBUG) := -DRDS_DEBUG diff --git a/net/rds/af_rds.c b/net/rds/af_rds.c new file mode 100644 index 000000000..d107f7605 --- /dev/null +++ b/net/rds/af_rds.c @@ -0,0 +1,963 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rds.h" + +/* this is just used for stats gathering :/ */ +static DEFINE_SPINLOCK(rds_sock_lock); +static unsigned long rds_sock_count; +static LIST_HEAD(rds_sock_list); +DECLARE_WAIT_QUEUE_HEAD(rds_poll_waitq); + +/* + * This is called as the final descriptor referencing this socket is closed. + * We have to unbind the socket so that another socket can be bound to the + * address it was using. + * + * We have to be careful about racing with the incoming path. sock_orphan() + * sets SOCK_DEAD and we use that as an indicator to the rx path that new + * messages shouldn't be queued. 
+ */ +static int rds_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct rds_sock *rs; + + if (!sk) + goto out; + + rs = rds_sk_to_rs(sk); + + sock_orphan(sk); + /* Note - rds_clear_recv_queue grabs rs_recv_lock, so + * that ensures the recv path has completed messing + * with the socket. */ + rds_clear_recv_queue(rs); + rds_cong_remove_socket(rs); + + rds_remove_bound(rs); + + rds_send_drop_to(rs, NULL); + rds_rdma_drop_keys(rs); + rds_notify_queue_get(rs, NULL); + rds_notify_msg_zcopy_purge(&rs->rs_zcookie_queue); + + spin_lock_bh(&rds_sock_lock); + list_del_init(&rs->rs_item); + rds_sock_count--; + spin_unlock_bh(&rds_sock_lock); + + rds_trans_put(rs->rs_transport); + + sock->sk = NULL; + sock_put(sk); +out: + return 0; +} + +/* + * Careful not to race with rds_release -> sock_orphan which clears sk_sleep. + * _bh() isn't OK here, we're called from interrupt handlers. It's probably OK + * to wake the waitqueue after sk_sleep is clear as we hold a sock ref, but + * this seems more conservative. + * NB - normally, one would use sk_callback_lock for this, but we can + * get here from interrupts, whereas the network code grabs sk_callback_lock + * with _lock_bh only - so relying on sk_callback_lock introduces livelocks. + */ +void rds_wake_sk_sleep(struct rds_sock *rs) +{ + unsigned long flags; + + read_lock_irqsave(&rs->rs_recv_lock, flags); + __rds_wake_sk_sleep(rds_rs_to_sk(rs)); + read_unlock_irqrestore(&rs->rs_recv_lock, flags); +} + +static int rds_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct rds_sock *rs = rds_sk_to_rs(sock->sk); + struct sockaddr_in6 *sin6; + struct sockaddr_in *sin; + int uaddr_len; + + /* racey, don't care */ + if (peer) { + if (ipv6_addr_any(&rs->rs_conn_addr)) + return -ENOTCONN; + + if (ipv6_addr_v4mapped(&rs->rs_conn_addr)) { + sin = (struct sockaddr_in *)uaddr; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + sin->sin_family = AF_INET; + sin->sin_port = rs->rs_conn_port; + sin->sin_addr.s_addr = rs->rs_conn_addr_v4; + uaddr_len = sizeof(*sin); + } else { + sin6 = (struct sockaddr_in6 *)uaddr; + sin6->sin6_family = AF_INET6; + sin6->sin6_port = rs->rs_conn_port; + sin6->sin6_addr = rs->rs_conn_addr; + sin6->sin6_flowinfo = 0; + /* scope_id is the same as in the bound address. */ + sin6->sin6_scope_id = rs->rs_bound_scope_id; + uaddr_len = sizeof(*sin6); + } + } else { + /* If socket is not yet bound and the socket is connected, + * set the return address family to be the same as the + * connected address, but with 0 address value. If it is not + * connected, set the family to be AF_UNSPEC (value 0) and + * the address size to be that of an IPv4 address. 
+ */ + if (ipv6_addr_any(&rs->rs_bound_addr)) { + if (ipv6_addr_any(&rs->rs_conn_addr)) { + sin = (struct sockaddr_in *)uaddr; + memset(sin, 0, sizeof(*sin)); + sin->sin_family = AF_UNSPEC; + return sizeof(*sin); + } + +#if IS_ENABLED(CONFIG_IPV6) + if (!(ipv6_addr_type(&rs->rs_conn_addr) & + IPV6_ADDR_MAPPED)) { + sin6 = (struct sockaddr_in6 *)uaddr; + memset(sin6, 0, sizeof(*sin6)); + sin6->sin6_family = AF_INET6; + return sizeof(*sin6); + } +#endif + + sin = (struct sockaddr_in *)uaddr; + memset(sin, 0, sizeof(*sin)); + sin->sin_family = AF_INET; + return sizeof(*sin); + } + if (ipv6_addr_v4mapped(&rs->rs_bound_addr)) { + sin = (struct sockaddr_in *)uaddr; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + sin->sin_family = AF_INET; + sin->sin_port = rs->rs_bound_port; + sin->sin_addr.s_addr = rs->rs_bound_addr_v4; + uaddr_len = sizeof(*sin); + } else { + sin6 = (struct sockaddr_in6 *)uaddr; + sin6->sin6_family = AF_INET6; + sin6->sin6_port = rs->rs_bound_port; + sin6->sin6_addr = rs->rs_bound_addr; + sin6->sin6_flowinfo = 0; + sin6->sin6_scope_id = rs->rs_bound_scope_id; + uaddr_len = sizeof(*sin6); + } + } + + return uaddr_len; +} + +/* + * RDS' poll is without a doubt the least intuitive part of the interface, + * as EPOLLIN and EPOLLOUT do not behave entirely as you would expect from + * a network protocol. + * + * EPOLLIN is asserted if + * - there is data on the receive queue. + * - to signal that a previously congested destination may have become + * uncongested + * - A notification has been queued to the socket (this can be a congestion + * update, or a RDMA completion, or a MSG_ZEROCOPY completion). + * + * EPOLLOUT is asserted if there is room on the send queue. This does not mean + * however, that the next sendmsg() call will succeed. If the application tries + * to send to a congested destination, the system call may still fail (and + * return ENOBUFS). + */ +static __poll_t rds_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct rds_sock *rs = rds_sk_to_rs(sk); + __poll_t mask = 0; + unsigned long flags; + + poll_wait(file, sk_sleep(sk), wait); + + if (rs->rs_seen_congestion) + poll_wait(file, &rds_poll_waitq, wait); + + read_lock_irqsave(&rs->rs_recv_lock, flags); + if (!rs->rs_cong_monitor) { + /* When a congestion map was updated, we signal EPOLLIN for + * "historical" reasons. Applications can also poll for + * WRBAND instead. 
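/*
 * Illustrative sketch (not part of this patch): how a user-space sender can
 * use the semantics described in the comment above.  A non-blocking send that
 * fails with ENOBUFS (congested destination) is retried after poll() reports
 * POLLIN, which is also asserted when a previously congested destination may
 * have become uncongested; POLLIN fires for ordinary incoming data as well,
 * so the retry may loop a few times.
 */
#include <errno.h>
#include <poll.h>
#include <sys/socket.h>

static ssize_t rds_send_retry(int fd, const struct msghdr *msg)
{
        struct pollfd pfd = { .fd = fd, .events = POLLIN };
        ssize_t ret;

        for (;;) {
                ret = sendmsg(fd, msg, MSG_DONTWAIT);
                if (ret >= 0 || errno != ENOBUFS)
                        return ret;     /* sent, or a different error (e.g. EAGAIN) */
                /* Destination congested: wait for a wakeup, then retry. */
                if (poll(&pfd, 1, -1) < 0)
                        return -1;
        }
}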
*/ + if (rds_cong_updated_since(&rs->rs_cong_track)) + mask |= (EPOLLIN | EPOLLRDNORM | EPOLLWRBAND); + } else { + spin_lock(&rs->rs_lock); + if (rs->rs_cong_notify) + mask |= (EPOLLIN | EPOLLRDNORM); + spin_unlock(&rs->rs_lock); + } + if (!list_empty(&rs->rs_recv_queue) || + !list_empty(&rs->rs_notify_queue) || + !list_empty(&rs->rs_zcookie_queue.zcookie_head)) + mask |= (EPOLLIN | EPOLLRDNORM); + if (rs->rs_snd_bytes < rds_sk_sndbuf(rs)) + mask |= (EPOLLOUT | EPOLLWRNORM); + if (sk->sk_err || !skb_queue_empty(&sk->sk_error_queue)) + mask |= POLLERR; + read_unlock_irqrestore(&rs->rs_recv_lock, flags); + + /* clear state any time we wake a seen-congested socket */ + if (mask) + rs->rs_seen_congestion = 0; + + return mask; +} + +static int rds_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct rds_sock *rs = rds_sk_to_rs(sock->sk); + rds_tos_t utos, tos = 0; + + switch (cmd) { + case SIOCRDSSETTOS: + if (get_user(utos, (rds_tos_t __user *)arg)) + return -EFAULT; + + if (rs->rs_transport && + rs->rs_transport->get_tos_map) + tos = rs->rs_transport->get_tos_map(utos); + else + return -ENOIOCTLCMD; + + spin_lock_bh(&rds_sock_lock); + if (rs->rs_tos || rs->rs_conn) { + spin_unlock_bh(&rds_sock_lock); + return -EINVAL; + } + rs->rs_tos = tos; + spin_unlock_bh(&rds_sock_lock); + break; + case SIOCRDSGETTOS: + spin_lock_bh(&rds_sock_lock); + tos = rs->rs_tos; + spin_unlock_bh(&rds_sock_lock); + if (put_user(tos, (rds_tos_t __user *)arg)) + return -EFAULT; + break; + default: + return -ENOIOCTLCMD; + } + + return 0; +} + +static int rds_cancel_sent_to(struct rds_sock *rs, sockptr_t optval, int len) +{ + struct sockaddr_in6 sin6; + struct sockaddr_in sin; + int ret = 0; + + /* racing with another thread binding seems ok here */ + if (ipv6_addr_any(&rs->rs_bound_addr)) { + ret = -ENOTCONN; /* XXX not a great errno */ + goto out; + } + + if (len < sizeof(struct sockaddr_in)) { + ret = -EINVAL; + goto out; + } else if (len < sizeof(struct sockaddr_in6)) { + /* Assume IPv4 */ + if (copy_from_sockptr(&sin, optval, + sizeof(struct sockaddr_in))) { + ret = -EFAULT; + goto out; + } + ipv6_addr_set_v4mapped(sin.sin_addr.s_addr, &sin6.sin6_addr); + sin6.sin6_port = sin.sin_port; + } else { + if (copy_from_sockptr(&sin6, optval, + sizeof(struct sockaddr_in6))) { + ret = -EFAULT; + goto out; + } + } + + rds_send_drop_to(rs, &sin6); +out: + return ret; +} + +static int rds_set_bool_option(unsigned char *optvar, sockptr_t optval, + int optlen) +{ + int value; + + if (optlen < sizeof(int)) + return -EINVAL; + if (copy_from_sockptr(&value, optval, sizeof(int))) + return -EFAULT; + *optvar = !!value; + return 0; +} + +static int rds_cong_monitor(struct rds_sock *rs, sockptr_t optval, int optlen) +{ + int ret; + + ret = rds_set_bool_option(&rs->rs_cong_monitor, optval, optlen); + if (ret == 0) { + if (rs->rs_cong_monitor) { + rds_cong_add_socket(rs); + } else { + rds_cong_remove_socket(rs); + rs->rs_cong_mask = 0; + rs->rs_cong_notify = 0; + } + } + return ret; +} + +static int rds_set_transport(struct rds_sock *rs, sockptr_t optval, int optlen) +{ + int t_type; + + if (rs->rs_transport) + return -EOPNOTSUPP; /* previously attached to transport */ + + if (optlen != sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&t_type, optval, sizeof(t_type))) + return -EFAULT; + + if (t_type < 0 || t_type >= RDS_TRANS_COUNT) + return -EINVAL; + + rs->rs_transport = rds_trans_get(t_type); + + return rs->rs_transport ? 
0 : -ENOPROTOOPT; +} + +static int rds_enable_recvtstamp(struct sock *sk, sockptr_t optval, + int optlen, int optname) +{ + int val, valbool; + + if (optlen != sizeof(int)) + return -EFAULT; + + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + valbool = val ? 1 : 0; + + if (optname == SO_TIMESTAMP_NEW) + sock_set_flag(sk, SOCK_TSTAMP_NEW); + + if (valbool) + sock_set_flag(sk, SOCK_RCVTSTAMP); + else + sock_reset_flag(sk, SOCK_RCVTSTAMP); + + return 0; +} + +static int rds_recv_track_latency(struct rds_sock *rs, sockptr_t optval, + int optlen) +{ + struct rds_rx_trace_so trace; + int i; + + if (optlen != sizeof(struct rds_rx_trace_so)) + return -EFAULT; + + if (copy_from_sockptr(&trace, optval, sizeof(trace))) + return -EFAULT; + + if (trace.rx_traces > RDS_MSG_RX_DGRAM_TRACE_MAX) + return -EFAULT; + + rs->rs_rx_traces = trace.rx_traces; + for (i = 0; i < rs->rs_rx_traces; i++) { + if (trace.rx_trace_pos[i] >= RDS_MSG_RX_DGRAM_TRACE_MAX) { + rs->rs_rx_traces = 0; + return -EFAULT; + } + rs->rs_rx_trace[i] = trace.rx_trace_pos[i]; + } + + return 0; +} + +static int rds_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct rds_sock *rs = rds_sk_to_rs(sock->sk); + int ret; + + if (level != SOL_RDS) { + ret = -ENOPROTOOPT; + goto out; + } + + switch (optname) { + case RDS_CANCEL_SENT_TO: + ret = rds_cancel_sent_to(rs, optval, optlen); + break; + case RDS_GET_MR: + ret = rds_get_mr(rs, optval, optlen); + break; + case RDS_GET_MR_FOR_DEST: + ret = rds_get_mr_for_dest(rs, optval, optlen); + break; + case RDS_FREE_MR: + ret = rds_free_mr(rs, optval, optlen); + break; + case RDS_RECVERR: + ret = rds_set_bool_option(&rs->rs_recverr, optval, optlen); + break; + case RDS_CONG_MONITOR: + ret = rds_cong_monitor(rs, optval, optlen); + break; + case SO_RDS_TRANSPORT: + lock_sock(sock->sk); + ret = rds_set_transport(rs, optval, optlen); + release_sock(sock->sk); + break; + case SO_TIMESTAMP_OLD: + case SO_TIMESTAMP_NEW: + lock_sock(sock->sk); + ret = rds_enable_recvtstamp(sock->sk, optval, optlen, optname); + release_sock(sock->sk); + break; + case SO_RDS_MSG_RXPATH_LATENCY: + ret = rds_recv_track_latency(rs, optval, optlen); + break; + default: + ret = -ENOPROTOOPT; + } +out: + return ret; +} + +static int rds_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct rds_sock *rs = rds_sk_to_rs(sock->sk); + int ret = -ENOPROTOOPT, len; + int trans; + + if (level != SOL_RDS) + goto out; + + if (get_user(len, optlen)) { + ret = -EFAULT; + goto out; + } + + switch (optname) { + case RDS_INFO_FIRST ... RDS_INFO_LAST: + ret = rds_info_getsockopt(sock, optname, optval, + optlen); + break; + + case RDS_RECVERR: + if (len < sizeof(int)) + ret = -EINVAL; + else + if (put_user(rs->rs_recverr, (int __user *) optval) || + put_user(sizeof(int), optlen)) + ret = -EFAULT; + else + ret = 0; + break; + case SO_RDS_TRANSPORT: + if (len < sizeof(int)) { + ret = -EINVAL; + break; + } + trans = (rs->rs_transport ? 
rs->rs_transport->t_type : + RDS_TRANS_NONE); /* unbound */ + if (put_user(trans, (int __user *)optval) || + put_user(sizeof(int), optlen)) + ret = -EFAULT; + else + ret = 0; + break; + default: + break; + } + +out: + return ret; + +} + +static int rds_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct sockaddr_in *sin; + struct rds_sock *rs = rds_sk_to_rs(sk); + int ret = 0; + + if (addr_len < offsetofend(struct sockaddr, sa_family)) + return -EINVAL; + + lock_sock(sk); + + switch (uaddr->sa_family) { + case AF_INET: + sin = (struct sockaddr_in *)uaddr; + if (addr_len < sizeof(struct sockaddr_in)) { + ret = -EINVAL; + break; + } + if (sin->sin_addr.s_addr == htonl(INADDR_ANY)) { + ret = -EDESTADDRREQ; + break; + } + if (ipv4_is_multicast(sin->sin_addr.s_addr) || + sin->sin_addr.s_addr == htonl(INADDR_BROADCAST)) { + ret = -EINVAL; + break; + } + ipv6_addr_set_v4mapped(sin->sin_addr.s_addr, &rs->rs_conn_addr); + rs->rs_conn_port = sin->sin_port; + break; + +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: { + struct sockaddr_in6 *sin6; + int addr_type; + + sin6 = (struct sockaddr_in6 *)uaddr; + if (addr_len < sizeof(struct sockaddr_in6)) { + ret = -EINVAL; + break; + } + addr_type = ipv6_addr_type(&sin6->sin6_addr); + if (!(addr_type & IPV6_ADDR_UNICAST)) { + __be32 addr4; + + if (!(addr_type & IPV6_ADDR_MAPPED)) { + ret = -EPROTOTYPE; + break; + } + + /* It is a mapped address. Need to do some sanity + * checks. + */ + addr4 = sin6->sin6_addr.s6_addr32[3]; + if (addr4 == htonl(INADDR_ANY) || + addr4 == htonl(INADDR_BROADCAST) || + ipv4_is_multicast(addr4)) { + ret = -EPROTOTYPE; + break; + } + } + + if (addr_type & IPV6_ADDR_LINKLOCAL) { + /* If socket is arleady bound to a link local address, + * the peer address must be on the same link. + */ + if (sin6->sin6_scope_id == 0 || + (!ipv6_addr_any(&rs->rs_bound_addr) && + rs->rs_bound_scope_id && + sin6->sin6_scope_id != rs->rs_bound_scope_id)) { + ret = -EINVAL; + break; + } + /* Remember the connected address scope ID. It will + * be checked against the binding local address when + * the socket is bound. 
+ */ + rs->rs_bound_scope_id = sin6->sin6_scope_id; + } + rs->rs_conn_addr = sin6->sin6_addr; + rs->rs_conn_port = sin6->sin6_port; + break; + } +#endif + + default: + ret = -EAFNOSUPPORT; + break; + } + + release_sock(sk); + return ret; +} + +static struct proto rds_proto = { + .name = "RDS", + .owner = THIS_MODULE, + .obj_size = sizeof(struct rds_sock), +}; + +static const struct proto_ops rds_proto_ops = { + .family = AF_RDS, + .owner = THIS_MODULE, + .release = rds_release, + .bind = rds_bind, + .connect = rds_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = rds_getname, + .poll = rds_poll, + .ioctl = rds_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = rds_setsockopt, + .getsockopt = rds_getsockopt, + .sendmsg = rds_sendmsg, + .recvmsg = rds_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static void rds_sock_destruct(struct sock *sk) +{ + struct rds_sock *rs = rds_sk_to_rs(sk); + + WARN_ON((&rs->rs_item != rs->rs_item.next || + &rs->rs_item != rs->rs_item.prev)); +} + +static int __rds_create(struct socket *sock, struct sock *sk, int protocol) +{ + struct rds_sock *rs; + + sock_init_data(sock, sk); + sock->ops = &rds_proto_ops; + sk->sk_protocol = protocol; + sk->sk_destruct = rds_sock_destruct; + + rs = rds_sk_to_rs(sk); + spin_lock_init(&rs->rs_lock); + rwlock_init(&rs->rs_recv_lock); + INIT_LIST_HEAD(&rs->rs_send_queue); + INIT_LIST_HEAD(&rs->rs_recv_queue); + INIT_LIST_HEAD(&rs->rs_notify_queue); + INIT_LIST_HEAD(&rs->rs_cong_list); + rds_message_zcopy_queue_init(&rs->rs_zcookie_queue); + spin_lock_init(&rs->rs_rdma_lock); + rs->rs_rdma_keys = RB_ROOT; + rs->rs_rx_traces = 0; + rs->rs_tos = 0; + rs->rs_conn = NULL; + + spin_lock_bh(&rds_sock_lock); + list_add_tail(&rs->rs_item, &rds_sock_list); + rds_sock_count++; + spin_unlock_bh(&rds_sock_lock); + + return 0; +} + +static int rds_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + if (sock->type != SOCK_SEQPACKET || protocol) + return -ESOCKTNOSUPPORT; + + sk = sk_alloc(net, AF_RDS, GFP_KERNEL, &rds_proto, kern); + if (!sk) + return -ENOMEM; + + return __rds_create(sock, sk, protocol); +} + +void rds_sock_addref(struct rds_sock *rs) +{ + sock_hold(rds_rs_to_sk(rs)); +} + +void rds_sock_put(struct rds_sock *rs) +{ + sock_put(rds_rs_to_sk(rs)); +} + +static const struct net_proto_family rds_family_ops = { + .family = AF_RDS, + .create = rds_create, + .owner = THIS_MODULE, +}; + +static void rds_sock_inc_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds_sock *rs; + struct rds_incoming *inc; + unsigned int total = 0; + + len /= sizeof(struct rds_info_message); + + spin_lock_bh(&rds_sock_lock); + + list_for_each_entry(rs, &rds_sock_list, rs_item) { + /* This option only supports IPv4 sockets. */ + if (!ipv6_addr_v4mapped(&rs->rs_bound_addr)) + continue; + + read_lock(&rs->rs_recv_lock); + + /* XXX too lazy to maintain counts.. 
*/ + list_for_each_entry(inc, &rs->rs_recv_queue, i_item) { + total++; + if (total <= len) + rds_inc_info_copy(inc, iter, + inc->i_saddr.s6_addr32[3], + rs->rs_bound_addr_v4, + 1); + } + + read_unlock(&rs->rs_recv_lock); + } + + spin_unlock_bh(&rds_sock_lock); + + lens->nr = total; + lens->each = sizeof(struct rds_info_message); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_sock_inc_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds_incoming *inc; + unsigned int total = 0; + struct rds_sock *rs; + + len /= sizeof(struct rds6_info_message); + + spin_lock_bh(&rds_sock_lock); + + list_for_each_entry(rs, &rds_sock_list, rs_item) { + read_lock(&rs->rs_recv_lock); + + list_for_each_entry(inc, &rs->rs_recv_queue, i_item) { + total++; + if (total <= len) + rds6_inc_info_copy(inc, iter, &inc->i_saddr, + &rs->rs_bound_addr, 1); + } + + read_unlock(&rs->rs_recv_lock); + } + + spin_unlock_bh(&rds_sock_lock); + + lens->nr = total; + lens->each = sizeof(struct rds6_info_message); +} +#endif + +static void rds_sock_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds_info_socket sinfo; + unsigned int cnt = 0; + struct rds_sock *rs; + + len /= sizeof(struct rds_info_socket); + + spin_lock_bh(&rds_sock_lock); + + if (len < rds_sock_count) { + cnt = rds_sock_count; + goto out; + } + + list_for_each_entry(rs, &rds_sock_list, rs_item) { + /* This option only supports IPv4 sockets. */ + if (!ipv6_addr_v4mapped(&rs->rs_bound_addr)) + continue; + sinfo.sndbuf = rds_sk_sndbuf(rs); + sinfo.rcvbuf = rds_sk_rcvbuf(rs); + sinfo.bound_addr = rs->rs_bound_addr_v4; + sinfo.connected_addr = rs->rs_conn_addr_v4; + sinfo.bound_port = rs->rs_bound_port; + sinfo.connected_port = rs->rs_conn_port; + sinfo.inum = sock_i_ino(rds_rs_to_sk(rs)); + + rds_info_copy(iter, &sinfo, sizeof(sinfo)); + cnt++; + } + +out: + lens->nr = cnt; + lens->each = sizeof(struct rds_info_socket); + + spin_unlock_bh(&rds_sock_lock); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_sock_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds6_info_socket sinfo6; + struct rds_sock *rs; + + len /= sizeof(struct rds6_info_socket); + + spin_lock_bh(&rds_sock_lock); + + if (len < rds_sock_count) + goto out; + + list_for_each_entry(rs, &rds_sock_list, rs_item) { + sinfo6.sndbuf = rds_sk_sndbuf(rs); + sinfo6.rcvbuf = rds_sk_rcvbuf(rs); + sinfo6.bound_addr = rs->rs_bound_addr; + sinfo6.connected_addr = rs->rs_conn_addr; + sinfo6.bound_port = rs->rs_bound_port; + sinfo6.connected_port = rs->rs_conn_port; + sinfo6.inum = sock_i_ino(rds_rs_to_sk(rs)); + + rds_info_copy(iter, &sinfo6, sizeof(sinfo6)); + } + + out: + lens->nr = rds_sock_count; + lens->each = sizeof(struct rds6_info_socket); + + spin_unlock_bh(&rds_sock_lock); +} +#endif + +static void rds_exit(void) +{ + sock_unregister(rds_family_ops.family); + proto_unregister(&rds_proto); + rds_conn_exit(); + rds_cong_exit(); + rds_sysctl_exit(); + rds_threads_exit(); + rds_stats_exit(); + rds_page_exit(); + rds_bind_lock_destroy(); + rds_info_deregister_func(RDS_INFO_SOCKETS, rds_sock_info); + rds_info_deregister_func(RDS_INFO_RECV_MESSAGES, rds_sock_inc_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_deregister_func(RDS6_INFO_SOCKETS, rds6_sock_info); + rds_info_deregister_func(RDS6_INFO_RECV_MESSAGES, rds6_sock_inc_info); +#endif +} +module_exit(rds_exit); + +u32 rds_gen_num; + 
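/*
 * Illustrative sketch (not part of this patch): minimal user-space use of the
 * address family registered by rds_init() below.  RDS sockets are
 * SOCK_SEQPACKET with protocol 0 (anything else is rejected by rds_create()),
 * must be bound to a specific local interface address before sending
 * (rds_bind() rejects INADDR_ANY), and need a transport module (rds_tcp or
 * rds_rdma) loaded so a transport can be found for the bound address.  The
 * port number 4000 is an arbitrary example value.
 */
#include <arpa/inet.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#ifndef AF_RDS
#define AF_RDS 21                       /* value from <linux/socket.h> */
#endif

static int rds_send_one(const char *local_ip, const char *peer_ip,
                        const void *buf, size_t len)
{
        struct sockaddr_in lsin = { .sin_family = AF_INET };
        struct sockaddr_in psin = { .sin_family = AF_INET };
        int fd, ret;

        fd = socket(AF_RDS, SOCK_SEQPACKET, 0);
        if (fd < 0)
                return -1;

        lsin.sin_addr.s_addr = inet_addr(local_ip);
        lsin.sin_port = htons(4000);
        psin.sin_addr.s_addr = inet_addr(peer_ip);
        psin.sin_port = htons(4000);

        ret = bind(fd, (struct sockaddr *)&lsin, sizeof(lsin));
        if (!ret)
                ret = sendto(fd, buf, len, 0,
                             (struct sockaddr *)&psin, sizeof(psin));

        close(fd);
        return ret;
}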
+static int __init rds_init(void) +{ + int ret; + + net_get_random_once(&rds_gen_num, sizeof(rds_gen_num)); + + ret = rds_bind_lock_init(); + if (ret) + goto out; + + ret = rds_conn_init(); + if (ret) + goto out_bind; + + ret = rds_threads_init(); + if (ret) + goto out_conn; + ret = rds_sysctl_init(); + if (ret) + goto out_threads; + ret = rds_stats_init(); + if (ret) + goto out_sysctl; + ret = proto_register(&rds_proto, 1); + if (ret) + goto out_stats; + ret = sock_register(&rds_family_ops); + if (ret) + goto out_proto; + + rds_info_register_func(RDS_INFO_SOCKETS, rds_sock_info); + rds_info_register_func(RDS_INFO_RECV_MESSAGES, rds_sock_inc_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_register_func(RDS6_INFO_SOCKETS, rds6_sock_info); + rds_info_register_func(RDS6_INFO_RECV_MESSAGES, rds6_sock_inc_info); +#endif + + goto out; + +out_proto: + proto_unregister(&rds_proto); +out_stats: + rds_stats_exit(); +out_sysctl: + rds_sysctl_exit(); +out_threads: + rds_threads_exit(); +out_conn: + rds_conn_exit(); + rds_cong_exit(); + rds_page_exit(); +out_bind: + rds_bind_lock_destroy(); +out: + return ret; +} +module_init(rds_init); + +#define DRV_VERSION "4.0" +#define DRV_RELDATE "Feb 12, 2009" + +MODULE_AUTHOR("Oracle Corporation "); +MODULE_DESCRIPTION("RDS: Reliable Datagram Sockets" + " v" DRV_VERSION " (" DRV_RELDATE ")"); +MODULE_VERSION(DRV_VERSION); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS_NETPROTO(PF_RDS); diff --git a/net/rds/bind.c b/net/rds/bind.c new file mode 100644 index 000000000..97a29172a --- /dev/null +++ b/net/rds/bind.c @@ -0,0 +1,283 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include "rds.h" + +static struct rhashtable bind_hash_table; + +static const struct rhashtable_params ht_parms = { + .nelem_hint = 768, + .key_len = RDS_BOUND_KEY_LEN, + .key_offset = offsetof(struct rds_sock, rs_bound_key), + .head_offset = offsetof(struct rds_sock, rs_bound_node), + .max_size = 16384, + .min_size = 1024, +}; + +/* Create a key for the bind hash table manipulation. Port is in network byte + * order. 
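/*
 * Layout of the resulting lookup key (derived from the memcpy() sequence
 * below; RDS_BOUND_KEY_LEN in rds.h is sized to match):
 *
 *      offset  0: 16-byte IPv6 (or v4-mapped) bound address
 *      offset 16:  2-byte bound port, network byte order
 *      offset 18:  4-byte scope ID
 */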
+ */ +static inline void __rds_create_bind_key(u8 *key, const struct in6_addr *addr, + __be16 port, __u32 scope_id) +{ + memcpy(key, addr, sizeof(*addr)); + key += sizeof(*addr); + memcpy(key, &port, sizeof(port)); + key += sizeof(port); + memcpy(key, &scope_id, sizeof(scope_id)); +} + +/* + * Return the rds_sock bound at the given local address. + * + * The rx path can race with rds_release. We notice if rds_release() has + * marked this socket and don't return a rs ref to the rx path. + */ +struct rds_sock *rds_find_bound(const struct in6_addr *addr, __be16 port, + __u32 scope_id) +{ + u8 key[RDS_BOUND_KEY_LEN]; + struct rds_sock *rs; + + __rds_create_bind_key(key, addr, port, scope_id); + rcu_read_lock(); + rs = rhashtable_lookup(&bind_hash_table, key, ht_parms); + if (rs && (sock_flag(rds_rs_to_sk(rs), SOCK_DEAD) || + !refcount_inc_not_zero(&rds_rs_to_sk(rs)->sk_refcnt))) + rs = NULL; + + rcu_read_unlock(); + + rdsdebug("returning rs %p for %pI6c:%u\n", rs, addr, + ntohs(port)); + + return rs; +} + +/* returns -ve errno or +ve port */ +static int rds_add_bound(struct rds_sock *rs, const struct in6_addr *addr, + __be16 *port, __u32 scope_id) +{ + int ret = -EADDRINUSE; + u16 rover, last; + u8 key[RDS_BOUND_KEY_LEN]; + + if (*port != 0) { + rover = be16_to_cpu(*port); + if (rover == RDS_FLAG_PROBE_PORT) + return -EINVAL; + last = rover; + } else { + rover = max_t(u16, get_random_u16(), 2); + last = rover - 1; + } + + do { + if (rover == 0) + rover++; + + if (rover == RDS_FLAG_PROBE_PORT) + continue; + __rds_create_bind_key(key, addr, cpu_to_be16(rover), + scope_id); + if (rhashtable_lookup_fast(&bind_hash_table, key, ht_parms)) + continue; + + memcpy(rs->rs_bound_key, key, sizeof(rs->rs_bound_key)); + rs->rs_bound_addr = *addr; + net_get_random_once(&rs->rs_hash_initval, + sizeof(rs->rs_hash_initval)); + rs->rs_bound_port = cpu_to_be16(rover); + rs->rs_bound_node.next = NULL; + rds_sock_addref(rs); + if (!rhashtable_insert_fast(&bind_hash_table, + &rs->rs_bound_node, ht_parms)) { + *port = rs->rs_bound_port; + rs->rs_bound_scope_id = scope_id; + ret = 0; + rdsdebug("rs %p binding to %pI6c:%d\n", + rs, addr, (int)ntohs(*port)); + break; + } else { + rs->rs_bound_addr = in6addr_any; + rds_sock_put(rs); + ret = -ENOMEM; + break; + } + } while (rover++ != last); + + return ret; +} + +void rds_remove_bound(struct rds_sock *rs) +{ + + if (ipv6_addr_any(&rs->rs_bound_addr)) + return; + + rdsdebug("rs %p unbinding from %pI6c:%d\n", + rs, &rs->rs_bound_addr, + ntohs(rs->rs_bound_port)); + + rhashtable_remove_fast(&bind_hash_table, &rs->rs_bound_node, ht_parms); + rds_sock_put(rs); + rs->rs_bound_addr = in6addr_any; +} + +int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + struct rds_sock *rs = rds_sk_to_rs(sk); + struct in6_addr v6addr, *binding_addr; + struct rds_transport *trans; + __u32 scope_id = 0; + int ret = 0; + __be16 port; + + /* We allow an RDS socket to be bound to either IPv4 or IPv6 + * address. 
+ */ + if (addr_len < offsetofend(struct sockaddr, sa_family)) + return -EINVAL; + if (uaddr->sa_family == AF_INET) { + struct sockaddr_in *sin = (struct sockaddr_in *)uaddr; + + if (addr_len < sizeof(struct sockaddr_in) || + sin->sin_addr.s_addr == htonl(INADDR_ANY) || + sin->sin_addr.s_addr == htonl(INADDR_BROADCAST) || + ipv4_is_multicast(sin->sin_addr.s_addr)) + return -EINVAL; + ipv6_addr_set_v4mapped(sin->sin_addr.s_addr, &v6addr); + binding_addr = &v6addr; + port = sin->sin_port; +#if IS_ENABLED(CONFIG_IPV6) + } else if (uaddr->sa_family == AF_INET6) { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)uaddr; + int addr_type; + + if (addr_len < sizeof(struct sockaddr_in6)) + return -EINVAL; + addr_type = ipv6_addr_type(&sin6->sin6_addr); + if (!(addr_type & IPV6_ADDR_UNICAST)) { + __be32 addr4; + + if (!(addr_type & IPV6_ADDR_MAPPED)) + return -EINVAL; + + /* It is a mapped address. Need to do some sanity + * checks. + */ + addr4 = sin6->sin6_addr.s6_addr32[3]; + if (addr4 == htonl(INADDR_ANY) || + addr4 == htonl(INADDR_BROADCAST) || + ipv4_is_multicast(addr4)) + return -EINVAL; + } + /* The scope ID must be specified for link local address. */ + if (addr_type & IPV6_ADDR_LINKLOCAL) { + if (sin6->sin6_scope_id == 0) + return -EINVAL; + scope_id = sin6->sin6_scope_id; + } + binding_addr = &sin6->sin6_addr; + port = sin6->sin6_port; +#endif + } else { + return -EINVAL; + } + lock_sock(sk); + + /* RDS socket does not allow re-binding. */ + if (!ipv6_addr_any(&rs->rs_bound_addr)) { + ret = -EINVAL; + goto out; + } + /* Socket is connected. The binding address should have the same + * scope ID as the connected address, except the case when one is + * non-link local address (scope_id is 0). + */ + if (!ipv6_addr_any(&rs->rs_conn_addr) && scope_id && + rs->rs_bound_scope_id && + scope_id != rs->rs_bound_scope_id) { + ret = -EINVAL; + goto out; + } + + /* The transport can be set using SO_RDS_TRANSPORT option before the + * socket is bound. + */ + if (rs->rs_transport) { + trans = rs->rs_transport; + if (!trans->laddr_check || + trans->laddr_check(sock_net(sock->sk), + binding_addr, scope_id) != 0) { + ret = -ENOPROTOOPT; + goto out; + } + } else { + trans = rds_trans_get_preferred(sock_net(sock->sk), + binding_addr, scope_id); + if (!trans) { + ret = -EADDRNOTAVAIL; + pr_info_ratelimited("RDS: %s could not find a transport for %pI6c, load rds_tcp or rds_rdma?\n", + __func__, binding_addr); + goto out; + } + rs->rs_transport = trans; + } + + sock_set_flag(sk, SOCK_RCU_FREE); + ret = rds_add_bound(rs, binding_addr, &port, scope_id); + if (ret) + rs->rs_transport = NULL; + +out: + release_sock(sk); + return ret; +} + +void rds_bind_lock_destroy(void) +{ + rhashtable_destroy(&bind_hash_table); +} + +int rds_bind_lock_init(void) +{ + return rhashtable_init(&bind_hash_table, &ht_parms); +} diff --git a/net/rds/cong.c b/net/rds/cong.c new file mode 100644 index 000000000..8b689ebbd --- /dev/null +++ b/net/rds/cong.c @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2007, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +#include "rds.h" + +/* + * This file implements the receive side of the unconventional congestion + * management in RDS. + * + * Messages waiting in the receive queue on the receiving socket are accounted + * against the sockets SO_RCVBUF option value. Only the payload bytes in the + * message are accounted for. If the number of bytes queued equals or exceeds + * rcvbuf then the socket is congested. All sends attempted to this socket's + * address should return block or return -EWOULDBLOCK. + * + * Applications are expected to be reasonably tuned such that this situation + * very rarely occurs. An application encountering this "back-pressure" is + * considered a bug. + * + * This is implemented by having each node maintain bitmaps which indicate + * which ports on bound addresses are congested. As the bitmap changes it is + * sent through all the connections which terminate in the local address of the + * bitmap which changed. + * + * The bitmaps are allocated as connections are brought up. This avoids + * allocation in the interrupt handling path which queues messages on sockets. + * The dense bitmaps let transports send the entire bitmap on any bitmap change + * reasonably efficiently. This is much easier to implement than some + * finer-grained communication of per-port congestion. The sender does a very + * inexpensive bit test to test if the port it's about to send to is congested + * or not. + */ + +/* + * Interaction with poll is a tad tricky. We want all processes stuck in + * poll to wake up and check whether a congested destination became uncongested. + * The really sad thing is we have no idea which destinations the application + * wants to send to - we don't even know which rds_connections are involved. + * So until we implement a more flexible rds poll interface, we have to make + * do with this: + * We maintain a global counter that is incremented each time a congestion map + * update is received. Each rds socket tracks this value, and if rds_poll + * finds that the saved generation number is smaller than the global generation + * number, it wakes up the process. 
+ */ +static atomic_t rds_cong_generation = ATOMIC_INIT(0); + +/* + * Congestion monitoring + */ +static LIST_HEAD(rds_cong_monitor); +static DEFINE_RWLOCK(rds_cong_monitor_lock); + +/* + * Yes, a global lock. It's used so infrequently that it's worth keeping it + * global to simplify the locking. It's only used in the following + * circumstances: + * + * - on connection buildup to associate a conn with its maps + * - on map changes to inform conns of a new map to send + * + * It's sadly ordered under the socket callback lock and the connection lock. + * Receive paths can mark ports congested from interrupt context so the + * lock masks interrupts. + */ +static DEFINE_SPINLOCK(rds_cong_lock); +static struct rb_root rds_cong_tree = RB_ROOT; + +static struct rds_cong_map *rds_cong_tree_walk(const struct in6_addr *addr, + struct rds_cong_map *insert) +{ + struct rb_node **p = &rds_cong_tree.rb_node; + struct rb_node *parent = NULL; + struct rds_cong_map *map; + + while (*p) { + int diff; + + parent = *p; + map = rb_entry(parent, struct rds_cong_map, m_rb_node); + + diff = rds_addr_cmp(addr, &map->m_addr); + if (diff < 0) + p = &(*p)->rb_left; + else if (diff > 0) + p = &(*p)->rb_right; + else + return map; + } + + if (insert) { + rb_link_node(&insert->m_rb_node, parent, p); + rb_insert_color(&insert->m_rb_node, &rds_cong_tree); + } + return NULL; +} + +/* + * There is only ever one bitmap for any address. Connections try and allocate + * these bitmaps in the process getting pointers to them. The bitmaps are only + * ever freed as the module is removed after all connections have been freed. + */ +static struct rds_cong_map *rds_cong_from_addr(const struct in6_addr *addr) +{ + struct rds_cong_map *map; + struct rds_cong_map *ret = NULL; + unsigned long zp; + unsigned long i; + unsigned long flags; + + map = kzalloc(sizeof(struct rds_cong_map), GFP_KERNEL); + if (!map) + return NULL; + + map->m_addr = *addr; + init_waitqueue_head(&map->m_waitq); + INIT_LIST_HEAD(&map->m_conn_list); + + for (i = 0; i < RDS_CONG_MAP_PAGES; i++) { + zp = get_zeroed_page(GFP_KERNEL); + if (zp == 0) + goto out; + map->m_page_addrs[i] = zp; + } + + spin_lock_irqsave(&rds_cong_lock, flags); + ret = rds_cong_tree_walk(addr, map); + spin_unlock_irqrestore(&rds_cong_lock, flags); + + if (!ret) { + ret = map; + map = NULL; + } + +out: + if (map) { + for (i = 0; i < RDS_CONG_MAP_PAGES && map->m_page_addrs[i]; i++) + free_page(map->m_page_addrs[i]); + kfree(map); + } + + rdsdebug("map %p for addr %pI6c\n", ret, addr); + + return ret; +} + +/* + * Put the conn on its local map's list. This is called when the conn is + * really added to the hash. It's nested under the rds_conn_lock, sadly. 
+ */ +void rds_cong_add_conn(struct rds_connection *conn) +{ + unsigned long flags; + + rdsdebug("conn %p now on map %p\n", conn, conn->c_lcong); + spin_lock_irqsave(&rds_cong_lock, flags); + list_add_tail(&conn->c_map_item, &conn->c_lcong->m_conn_list); + spin_unlock_irqrestore(&rds_cong_lock, flags); +} + +void rds_cong_remove_conn(struct rds_connection *conn) +{ + unsigned long flags; + + rdsdebug("removing conn %p from map %p\n", conn, conn->c_lcong); + spin_lock_irqsave(&rds_cong_lock, flags); + list_del_init(&conn->c_map_item); + spin_unlock_irqrestore(&rds_cong_lock, flags); +} + +int rds_cong_get_maps(struct rds_connection *conn) +{ + conn->c_lcong = rds_cong_from_addr(&conn->c_laddr); + conn->c_fcong = rds_cong_from_addr(&conn->c_faddr); + + if (!(conn->c_lcong && conn->c_fcong)) + return -ENOMEM; + + return 0; +} + +void rds_cong_queue_updates(struct rds_cong_map *map) +{ + struct rds_connection *conn; + unsigned long flags; + + spin_lock_irqsave(&rds_cong_lock, flags); + + list_for_each_entry(conn, &map->m_conn_list, c_map_item) { + struct rds_conn_path *cp = &conn->c_path[0]; + + rcu_read_lock(); + if (!test_and_set_bit(0, &conn->c_map_queued) && + !rds_destroy_pending(cp->cp_conn)) { + rds_stats_inc(s_cong_update_queued); + /* We cannot inline the call to rds_send_xmit() here + * for two reasons (both pertaining to a TCP transport): + * 1. When we get here from the receive path, we + * are already holding the sock_lock (held by + * tcp_v4_rcv()). So inlining calls to + * tcp_setsockopt and/or tcp_sendmsg will deadlock + * when it tries to get the sock_lock()) + * 2. Interrupts are masked so that we can mark the + * port congested from both send and recv paths. + * (See comment around declaration of rdc_cong_lock). + * An attempt to get the sock_lock() here will + * therefore trigger warnings. + * Defer the xmit to rds_send_worker() instead. + */ + queue_delayed_work(rds_wq, &cp->cp_send_w, 0); + } + rcu_read_unlock(); + } + + spin_unlock_irqrestore(&rds_cong_lock, flags); +} + +void rds_cong_map_updated(struct rds_cong_map *map, uint64_t portmask) +{ + rdsdebug("waking map %p for %pI4\n", + map, &map->m_addr); + rds_stats_inc(s_cong_update_received); + atomic_inc(&rds_cong_generation); + if (waitqueue_active(&map->m_waitq)) + wake_up(&map->m_waitq); + if (waitqueue_active(&rds_poll_waitq)) + wake_up_all(&rds_poll_waitq); + + if (portmask && !list_empty(&rds_cong_monitor)) { + unsigned long flags; + struct rds_sock *rs; + + read_lock_irqsave(&rds_cong_monitor_lock, flags); + list_for_each_entry(rs, &rds_cong_monitor, rs_cong_list) { + spin_lock(&rs->rs_lock); + rs->rs_cong_notify |= (rs->rs_cong_mask & portmask); + rs->rs_cong_mask &= ~portmask; + spin_unlock(&rs->rs_lock); + if (rs->rs_cong_notify) + rds_wake_sk_sleep(rs); + } + read_unlock_irqrestore(&rds_cong_monitor_lock, flags); + } +} +EXPORT_SYMBOL_GPL(rds_cong_map_updated); + +int rds_cong_updated_since(unsigned long *recent) +{ + unsigned long gen = atomic_read(&rds_cong_generation); + + if (likely(*recent == gen)) + return 0; + *recent = gen; + return 1; +} + +/* + * We're called under the locking that protects the sockets receive buffer + * consumption. This makes it a lot easier for the caller to only call us + * when it knows that an existing set bit needs to be cleared, and vice versa. + * We can't block and we need to deal with concurrent sockets working against + * the same per-address map. 
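The portmask handling in rds_cong_map_updated() above moves any armed monitor bits into the socket's notify word and disarms them in one pass. The sketch below models just that bit manipulation in userspace; the "one bit per (port mod 64)" layout is an assumption made for the illustration, with RDS_CONG_MONITOR_MASK in the RDS headers being the authoritative definition, and the demo_* names are invented.

#include <stdint.h>
#include <stdio.h>

struct demo_rs {
        uint64_t cong_mask;              /* ports the application monitors    */
        uint64_t cong_notify;            /* ports that changed, to be reported */
};

/* returns nonzero if the socket should be woken up */
static int demo_map_updated(struct demo_rs *rs, uint64_t portmask)
{
        uint64_t fired = rs->cong_mask & portmask;

        rs->cong_notify |= fired;        /* remember what to report */
        rs->cong_mask &= ~portmask;      /* disarm the bits that fired */
        return fired != 0;               /* kernel: rds_wake_sk_sleep(rs) */
}

int main(void)
{
        struct demo_rs rs = { .cong_mask = 1ull << (4000 % 64) };
        int woken = demo_map_updated(&rs, 1ull << (4000 % 64));

        printf("woken=%d notify=%#llx mask=%#llx\n", woken,
               (unsigned long long)rs.cong_notify,
               (unsigned long long)rs.cong_mask);
        return 0;
}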
+ */ +void rds_cong_set_bit(struct rds_cong_map *map, __be16 port) +{ + unsigned long i; + unsigned long off; + + rdsdebug("setting congestion for %pI4:%u in map %p\n", + &map->m_addr, ntohs(port), map); + + i = be16_to_cpu(port) / RDS_CONG_MAP_PAGE_BITS; + off = be16_to_cpu(port) % RDS_CONG_MAP_PAGE_BITS; + + set_bit_le(off, (void *)map->m_page_addrs[i]); +} + +void rds_cong_clear_bit(struct rds_cong_map *map, __be16 port) +{ + unsigned long i; + unsigned long off; + + rdsdebug("clearing congestion for %pI4:%u in map %p\n", + &map->m_addr, ntohs(port), map); + + i = be16_to_cpu(port) / RDS_CONG_MAP_PAGE_BITS; + off = be16_to_cpu(port) % RDS_CONG_MAP_PAGE_BITS; + + clear_bit_le(off, (void *)map->m_page_addrs[i]); +} + +static int rds_cong_test_bit(struct rds_cong_map *map, __be16 port) +{ + unsigned long i; + unsigned long off; + + i = be16_to_cpu(port) / RDS_CONG_MAP_PAGE_BITS; + off = be16_to_cpu(port) % RDS_CONG_MAP_PAGE_BITS; + + return test_bit_le(off, (void *)map->m_page_addrs[i]); +} + +void rds_cong_add_socket(struct rds_sock *rs) +{ + unsigned long flags; + + write_lock_irqsave(&rds_cong_monitor_lock, flags); + if (list_empty(&rs->rs_cong_list)) + list_add(&rs->rs_cong_list, &rds_cong_monitor); + write_unlock_irqrestore(&rds_cong_monitor_lock, flags); +} + +void rds_cong_remove_socket(struct rds_sock *rs) +{ + unsigned long flags; + struct rds_cong_map *map; + + write_lock_irqsave(&rds_cong_monitor_lock, flags); + list_del_init(&rs->rs_cong_list); + write_unlock_irqrestore(&rds_cong_monitor_lock, flags); + + /* update congestion map for now-closed port */ + spin_lock_irqsave(&rds_cong_lock, flags); + map = rds_cong_tree_walk(&rs->rs_bound_addr, NULL); + spin_unlock_irqrestore(&rds_cong_lock, flags); + + if (map && rds_cong_test_bit(map, rs->rs_bound_port)) { + rds_cong_clear_bit(map, rs->rs_bound_port); + rds_cong_queue_updates(map); + } +} + +int rds_cong_wait(struct rds_cong_map *map, __be16 port, int nonblock, + struct rds_sock *rs) +{ + if (!rds_cong_test_bit(map, port)) + return 0; + if (nonblock) { + if (rs && rs->rs_cong_monitor) { + unsigned long flags; + + /* It would have been nice to have an atomic set_bit on + * a uint64_t. */ + spin_lock_irqsave(&rs->rs_lock, flags); + rs->rs_cong_mask |= RDS_CONG_MONITOR_MASK(ntohs(port)); + spin_unlock_irqrestore(&rs->rs_lock, flags); + + /* Test again - a congestion update may have arrived in + * the meantime. */ + if (!rds_cong_test_bit(map, port)) + return 0; + } + rds_stats_inc(s_cong_send_error); + return -ENOBUFS; + } + + rds_stats_inc(s_cong_send_blocked); + rdsdebug("waiting on map %p for port %u\n", map, be16_to_cpu(port)); + + return wait_event_interruptible(map->m_waitq, + !rds_cong_test_bit(map, port)); +} + +void rds_cong_exit(void) +{ + struct rb_node *node; + struct rds_cong_map *map; + unsigned long i; + + while ((node = rb_first(&rds_cong_tree))) { + map = rb_entry(node, struct rds_cong_map, m_rb_node); + rdsdebug("freeing map %p\n", map); + rb_erase(&map->m_rb_node, &rds_cong_tree); + for (i = 0; i < RDS_CONG_MAP_PAGES && map->m_page_addrs[i]; i++) + free_page(map->m_page_addrs[i]); + kfree(map); + } +} + +/* + * Allocate a RDS message containing a congestion update. 
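The set/clear/test helpers above split a 16-bit port number into a page index and a bit offset within that page. As a worked example, assuming 4096-byte pages (so 32768 bits per page and two pages for the whole port space; the kernel derives the real figures from PAGE_SIZE via RDS_CONG_MAP_PAGE_BITS), the sketch below reproduces the arithmetic in plain userspace C with invented demo_* names.

#include <stdint.h>
#include <stdio.h>

#define DEMO_PAGE_SIZE  4096u
#define DEMO_PAGE_BITS  (DEMO_PAGE_SIZE * 8)                       /* 32768 */
#define DEMO_MAP_PAGES  ((65536 + DEMO_PAGE_BITS - 1) / DEMO_PAGE_BITS)

static uint8_t demo_map[DEMO_MAP_PAGES][DEMO_PAGE_SIZE];

static void demo_set_port(uint16_t port)
{
        unsigned int page = port / DEMO_PAGE_BITS;
        unsigned int off  = port % DEMO_PAGE_BITS;

        demo_map[page][off / 8] |= 1u << (off % 8);  /* little-endian bit order */
}

static int demo_test_port(uint16_t port)
{
        unsigned int page = port / DEMO_PAGE_BITS;
        unsigned int off  = port % DEMO_PAGE_BITS;

        return !!(demo_map[page][off / 8] & (1u << (off % 8)));
}

int main(void)
{
        demo_set_port(40000);
        printf("port 40000 -> page %u, bit %u, congested=%d\n",
               40000 / DEMO_PAGE_BITS, 40000 % DEMO_PAGE_BITS,
               demo_test_port(40000));
        return 0;
}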
+ */ +struct rds_message *rds_cong_update_alloc(struct rds_connection *conn) +{ + struct rds_cong_map *map = conn->c_lcong; + struct rds_message *rm; + + rm = rds_message_map_pages(map->m_page_addrs, RDS_CONG_MAP_BYTES); + if (!IS_ERR(rm)) + rm->m_inc.i_hdr.h_flags = RDS_FLAG_CONG_BITMAP; + + return rm; +} diff --git a/net/rds/connection.c b/net/rds/connection.c new file mode 100644 index 000000000..b4cc699c5 --- /dev/null +++ b/net/rds/connection.c @@ -0,0 +1,948 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include "rds.h" +#include "loop.h" + +#define RDS_CONNECTION_HASH_BITS 12 +#define RDS_CONNECTION_HASH_ENTRIES (1 << RDS_CONNECTION_HASH_BITS) +#define RDS_CONNECTION_HASH_MASK (RDS_CONNECTION_HASH_ENTRIES - 1) + +/* converting this to RCU is a chore for another day.. 
*/ +static DEFINE_SPINLOCK(rds_conn_lock); +static unsigned long rds_conn_count; +static struct hlist_head rds_conn_hash[RDS_CONNECTION_HASH_ENTRIES]; +static struct kmem_cache *rds_conn_slab; + +static struct hlist_head *rds_conn_bucket(const struct in6_addr *laddr, + const struct in6_addr *faddr) +{ + static u32 rds6_hash_secret __read_mostly; + static u32 rds_hash_secret __read_mostly; + + u32 lhash, fhash, hash; + + net_get_random_once(&rds_hash_secret, sizeof(rds_hash_secret)); + net_get_random_once(&rds6_hash_secret, sizeof(rds6_hash_secret)); + + lhash = (__force u32)laddr->s6_addr32[3]; +#if IS_ENABLED(CONFIG_IPV6) + fhash = __ipv6_addr_jhash(faddr, rds6_hash_secret); +#else + fhash = (__force u32)faddr->s6_addr32[3]; +#endif + hash = __inet_ehashfn(lhash, 0, fhash, 0, rds_hash_secret); + + return &rds_conn_hash[hash & RDS_CONNECTION_HASH_MASK]; +} + +#define rds_conn_info_set(var, test, suffix) do { \ + if (test) \ + var |= RDS_INFO_CONNECTION_FLAG_##suffix; \ +} while (0) + +/* rcu read lock must be held or the connection spinlock */ +static struct rds_connection *rds_conn_lookup(struct net *net, + struct hlist_head *head, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, + u8 tos, int dev_if) +{ + struct rds_connection *conn, *ret = NULL; + + hlist_for_each_entry_rcu(conn, head, c_hash_node) { + if (ipv6_addr_equal(&conn->c_faddr, faddr) && + ipv6_addr_equal(&conn->c_laddr, laddr) && + conn->c_trans == trans && + conn->c_tos == tos && + net == rds_conn_net(conn) && + conn->c_dev_if == dev_if) { + ret = conn; + break; + } + } + rdsdebug("returning conn %p for %pI6c -> %pI6c\n", ret, + laddr, faddr); + return ret; +} + +/* + * This is called by transports as they're bringing down a connection. + * It clears partial message state so that the transport can start sending + * and receiving over this connection again in the future. It is up to + * the transport to have serialized this call with its send and recv. + */ +static void rds_conn_path_reset(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + + rdsdebug("connection %pI6c to %pI6c reset\n", + &conn->c_laddr, &conn->c_faddr); + + rds_stats_inc(s_conn_reset); + rds_send_path_reset(cp); + cp->cp_flags = 0; + + /* Do not clear next_rx_seq here, else we cannot distinguish + * retransmitted packets from new packets, and will hand all + * of them to the application. That is not consistent with the + * reliability guarantees of RDS. */ +} + +static void __rds_conn_path_init(struct rds_connection *conn, + struct rds_conn_path *cp, bool is_outgoing) +{ + spin_lock_init(&cp->cp_lock); + cp->cp_next_tx_seq = 1; + init_waitqueue_head(&cp->cp_waitq); + INIT_LIST_HEAD(&cp->cp_send_queue); + INIT_LIST_HEAD(&cp->cp_retrans); + + cp->cp_conn = conn; + atomic_set(&cp->cp_state, RDS_CONN_DOWN); + cp->cp_send_gen = 0; + cp->cp_reconnect_jiffies = 0; + cp->cp_conn->c_proposed_version = RDS_PROTOCOL_VERSION; + INIT_DELAYED_WORK(&cp->cp_send_w, rds_send_worker); + INIT_DELAYED_WORK(&cp->cp_recv_w, rds_recv_worker); + INIT_DELAYED_WORK(&cp->cp_conn_w, rds_connect_worker); + INIT_WORK(&cp->cp_down_w, rds_shutdown_worker); + mutex_init(&cp->cp_cm_lock); + cp->cp_flags = 0; +} + +/* + * There is only every one 'conn' for a given pair of addresses in the + * system at a time. They contain messages to be retransmitted and so + * span the lifetime of the actual underlying transport connections. + * + * For now they are not garbage collected once they're created. 
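The bucket selection in rds_conn_bucket() above mixes the local and remote address with a boot-time random secret and masks the result into one of 2^12 hash chains. The userspace sketch below shows only that shape of the computation: the mixing function is an ordinary stand-in for the kernel's __inet_ehashfn(), not a reimplementation of it, and the demo_* names are invented.

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#define DEMO_HASH_BITS     12
#define DEMO_HASH_ENTRIES  (1u << DEMO_HASH_BITS)
#define DEMO_HASH_MASK     (DEMO_HASH_ENTRIES - 1)

static uint32_t demo_secret;

static uint32_t demo_mix(uint32_t a, uint32_t b, uint32_t seed)
{
        /* any reasonable avalanche mix serves the illustration */
        uint32_t h = a ^ (b * 0x9e3779b1u) ^ seed;

        h ^= h >> 16;
        h *= 0x85ebca6bu;
        h ^= h >> 13;
        return h;
}

static unsigned int demo_conn_bucket(uint32_t laddr, uint32_t faddr)
{
        if (!demo_secret)
                demo_secret = (uint32_t)random();  /* net_get_random_once() analogue */
        return demo_mix(laddr, faddr, demo_secret) & DEMO_HASH_MASK;
}

int main(void)
{
        srandom(1);
        printf("bucket %u of %u\n",
               demo_conn_bucket(0x0a000001, 0x0a000002), DEMO_HASH_ENTRIES);
        return 0;
}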
They + * are torn down as the module is removed, if ever. + */ +static struct rds_connection *__rds_conn_create(struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, + gfp_t gfp, u8 tos, + int is_outgoing, + int dev_if) +{ + struct rds_connection *conn, *parent = NULL; + struct hlist_head *head = rds_conn_bucket(laddr, faddr); + struct rds_transport *loop_trans; + unsigned long flags; + int ret, i; + int npaths = (trans->t_mp_capable ? RDS_MPATH_WORKERS : 1); + + rcu_read_lock(); + conn = rds_conn_lookup(net, head, laddr, faddr, trans, tos, dev_if); + if (conn && + conn->c_loopback && + conn->c_trans != &rds_loop_transport && + ipv6_addr_equal(laddr, faddr) && + !is_outgoing) { + /* This is a looped back IB connection, and we're + * called by the code handling the incoming connect. + * We need a second connection object into which we + * can stick the other QP. */ + parent = conn; + conn = parent->c_passive; + } + rcu_read_unlock(); + if (conn) + goto out; + + conn = kmem_cache_zalloc(rds_conn_slab, gfp); + if (!conn) { + conn = ERR_PTR(-ENOMEM); + goto out; + } + conn->c_path = kcalloc(npaths, sizeof(struct rds_conn_path), gfp); + if (!conn->c_path) { + kmem_cache_free(rds_conn_slab, conn); + conn = ERR_PTR(-ENOMEM); + goto out; + } + + INIT_HLIST_NODE(&conn->c_hash_node); + conn->c_laddr = *laddr; + conn->c_isv6 = !ipv6_addr_v4mapped(laddr); + conn->c_faddr = *faddr; + conn->c_dev_if = dev_if; + conn->c_tos = tos; + +#if IS_ENABLED(CONFIG_IPV6) + /* If the local address is link local, set c_bound_if to be the + * index used for this connection. Otherwise, set it to 0 as + * the socket is not bound to an interface. c_bound_if is used + * to look up a socket when a packet is received + */ + if (ipv6_addr_type(laddr) & IPV6_ADDR_LINKLOCAL) + conn->c_bound_if = dev_if; + else +#endif + conn->c_bound_if = 0; + + rds_conn_net_set(conn, net); + + ret = rds_cong_get_maps(conn); + if (ret) { + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + conn = ERR_PTR(ret); + goto out; + } + + /* + * This is where a connection becomes loopback. If *any* RDS sockets + * can bind to the destination address then we'd rather the messages + * flow through loopback rather than either transport. + */ + loop_trans = rds_trans_get_preferred(net, faddr, conn->c_dev_if); + if (loop_trans) { + rds_trans_put(loop_trans); + conn->c_loopback = 1; + if (trans->t_prefer_loopback) { + if (likely(is_outgoing)) { + /* "outgoing" connection to local address. + * Protocol says it wants the connection + * handled by the loopback transport. + * This is what TCP does. + */ + trans = &rds_loop_transport; + } else { + /* No transport currently in use + * should end up here, but if it + * does, reset/destroy the connection. + */ + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + conn = ERR_PTR(-EOPNOTSUPP); + goto out; + } + } + } + + conn->c_trans = trans; + + init_waitqueue_head(&conn->c_hs_waitq); + for (i = 0; i < npaths; i++) { + __rds_conn_path_init(conn, &conn->c_path[i], + is_outgoing); + conn->c_path[i].cp_index = i; + } + rcu_read_lock(); + if (rds_destroy_pending(conn)) + ret = -ENETDOWN; + else + ret = trans->conn_alloc(conn, GFP_ATOMIC); + if (ret) { + rcu_read_unlock(); + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + conn = ERR_PTR(ret); + goto out; + } + + rdsdebug("allocated conn %p for %pI6c -> %pI6c over %s %s\n", + conn, laddr, faddr, + strnlen(trans->t_name, sizeof(trans->t_name)) ? 
+ trans->t_name : "[unknown]", is_outgoing ? "(outgoing)" : ""); + + /* + * Since we ran without holding the conn lock, someone could + * have created the same conn (either normal or passive) in the + * interim. We check while holding the lock. If we won, we complete + * init and return our conn. If we lost, we rollback and return the + * other one. + */ + spin_lock_irqsave(&rds_conn_lock, flags); + if (parent) { + /* Creating passive conn */ + if (parent->c_passive) { + trans->conn_free(conn->c_path[0].cp_transport_data); + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + conn = parent->c_passive; + } else { + parent->c_passive = conn; + rds_cong_add_conn(conn); + rds_conn_count++; + } + } else { + /* Creating normal conn */ + struct rds_connection *found; + + found = rds_conn_lookup(net, head, laddr, faddr, trans, + tos, dev_if); + if (found) { + struct rds_conn_path *cp; + int i; + + for (i = 0; i < npaths; i++) { + cp = &conn->c_path[i]; + /* The ->conn_alloc invocation may have + * allocated resource for all paths, so all + * of them may have to be freed here. + */ + if (cp->cp_transport_data) + trans->conn_free(cp->cp_transport_data); + } + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + conn = found; + } else { + conn->c_my_gen_num = rds_gen_num; + conn->c_peer_gen_num = 0; + hlist_add_head_rcu(&conn->c_hash_node, head); + rds_cong_add_conn(conn); + rds_conn_count++; + } + } + spin_unlock_irqrestore(&rds_conn_lock, flags); + rcu_read_unlock(); + +out: + return conn; +} + +struct rds_connection *rds_conn_create(struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, u8 tos, + gfp_t gfp, int dev_if) +{ + return __rds_conn_create(net, laddr, faddr, trans, gfp, tos, 0, dev_if); +} +EXPORT_SYMBOL_GPL(rds_conn_create); + +struct rds_connection *rds_conn_create_outgoing(struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, + u8 tos, gfp_t gfp, int dev_if) +{ + return __rds_conn_create(net, laddr, faddr, trans, gfp, tos, 1, dev_if); +} +EXPORT_SYMBOL_GPL(rds_conn_create_outgoing); + +void rds_conn_shutdown(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + + /* shut it down unless it's down already */ + if (!rds_conn_path_transition(cp, RDS_CONN_DOWN, RDS_CONN_DOWN)) { + /* + * Quiesce the connection mgmt handlers before we start tearing + * things down. We don't hold the mutex for the entire + * duration of the shutdown operation, else we may be + * deadlocking with the CM handler. Instead, the CM event + * handler is supposed to check for state DISCONNECTING + */ + mutex_lock(&cp->cp_cm_lock); + if (!rds_conn_path_transition(cp, RDS_CONN_UP, + RDS_CONN_DISCONNECTING) && + !rds_conn_path_transition(cp, RDS_CONN_ERROR, + RDS_CONN_DISCONNECTING)) { + rds_conn_path_error(cp, + "shutdown called in state %d\n", + atomic_read(&cp->cp_state)); + mutex_unlock(&cp->cp_cm_lock); + return; + } + mutex_unlock(&cp->cp_cm_lock); + + wait_event(cp->cp_waitq, + !test_bit(RDS_IN_XMIT, &cp->cp_flags)); + wait_event(cp->cp_waitq, + !test_bit(RDS_RECV_REFILL, &cp->cp_flags)); + + conn->c_trans->conn_path_shutdown(cp); + rds_conn_path_reset(cp); + + if (!rds_conn_path_transition(cp, RDS_CONN_DISCONNECTING, + RDS_CONN_DOWN) && + !rds_conn_path_transition(cp, RDS_CONN_ERROR, + RDS_CONN_DOWN)) { + /* This can happen - eg when we're in the middle of tearing + * down the connection, and someone unloads the rds module. 
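The shutdown path above leans on conditional state transitions: rds_conn_path_transition() only moves a path from an expected old state to a new one, so concurrent droppers and connectors cannot overwrite each other's progress. The helper itself is declared in rds.h; the sketch below merely models its observable contract with a C11 compare-exchange, using invented demo_* names.

#include <stdatomic.h>
#include <stdio.h>

enum demo_state {
        DEMO_DOWN, DEMO_CONNECTING, DEMO_UP, DEMO_DISCONNECTING, DEMO_ERROR
};

struct demo_path {
        atomic_int state;
};

/* returns 1 if the path was in @old and is now in @new_state, 0 otherwise */
static int demo_transition(struct demo_path *cp, int old, int new_state)
{
        return atomic_compare_exchange_strong(&cp->state, &old, new_state);
}

int main(void)
{
        struct demo_path cp;

        atomic_init(&cp.state, DEMO_UP);

        /* a shutdown races with itself: UP -> DISCONNECTING succeeds once */
        printf("first:  %d\n", demo_transition(&cp, DEMO_UP, DEMO_DISCONNECTING));
        printf("second: %d\n", demo_transition(&cp, DEMO_UP, DEMO_DISCONNECTING));
        printf("down:   %d\n", demo_transition(&cp, DEMO_DISCONNECTING, DEMO_DOWN));
        return 0;
}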
+ * Quite reproducible with loopback connections. + * Mostly harmless. + * + * Note that this also happens with rds-tcp because + * we could have triggered rds_conn_path_drop in irq + * mode from rds_tcp_state change on the receipt of + * a FIN, thus we need to recheck for RDS_CONN_ERROR + * here. + */ + rds_conn_path_error(cp, "%s: failed to transition " + "to state DOWN, current state " + "is %d\n", __func__, + atomic_read(&cp->cp_state)); + return; + } + } + + /* Then reconnect if it's still live. + * The passive side of an IB loopback connection is never added + * to the conn hash, so we never trigger a reconnect on this + * conn - the reconnect is always triggered by the active peer. */ + cancel_delayed_work_sync(&cp->cp_conn_w); + rcu_read_lock(); + if (!hlist_unhashed(&conn->c_hash_node)) { + rcu_read_unlock(); + rds_queue_reconnect(cp); + } else { + rcu_read_unlock(); + } +} + +/* destroy a single rds_conn_path. rds_conn_destroy() iterates over + * all paths using rds_conn_path_destroy() + */ +static void rds_conn_path_destroy(struct rds_conn_path *cp) +{ + struct rds_message *rm, *rtmp; + + if (!cp->cp_transport_data) + return; + + /* make sure lingering queued work won't try to ref the conn */ + cancel_delayed_work_sync(&cp->cp_send_w); + cancel_delayed_work_sync(&cp->cp_recv_w); + + rds_conn_path_drop(cp, true); + flush_work(&cp->cp_down_w); + + /* tear down queued messages */ + list_for_each_entry_safe(rm, rtmp, + &cp->cp_send_queue, + m_conn_item) { + list_del_init(&rm->m_conn_item); + BUG_ON(!list_empty(&rm->m_sock_item)); + rds_message_put(rm); + } + if (cp->cp_xmit_rm) + rds_message_put(cp->cp_xmit_rm); + + WARN_ON(delayed_work_pending(&cp->cp_send_w)); + WARN_ON(delayed_work_pending(&cp->cp_recv_w)); + WARN_ON(delayed_work_pending(&cp->cp_conn_w)); + WARN_ON(work_pending(&cp->cp_down_w)); + + cp->cp_conn->c_trans->conn_free(cp->cp_transport_data); +} + +/* + * Stop and free a connection. + * + * This can only be used in very limited circumstances. It assumes that once + * the conn has been shutdown that no one else is referencing the connection. + * We can only ensure this in the rmmod path in the current code. + */ +void rds_conn_destroy(struct rds_connection *conn) +{ + unsigned long flags; + int i; + struct rds_conn_path *cp; + int npaths = (conn->c_trans->t_mp_capable ? RDS_MPATH_WORKERS : 1); + + rdsdebug("freeing conn %p for %pI4 -> " + "%pI4\n", conn, &conn->c_laddr, + &conn->c_faddr); + + /* Ensure conn will not be scheduled for reconnect */ + spin_lock_irq(&rds_conn_lock); + hlist_del_init_rcu(&conn->c_hash_node); + spin_unlock_irq(&rds_conn_lock); + synchronize_rcu(); + + /* shut the connection down */ + for (i = 0; i < npaths; i++) { + cp = &conn->c_path[i]; + rds_conn_path_destroy(cp); + BUG_ON(!list_empty(&cp->cp_retrans)); + } + + /* + * The congestion maps aren't freed up here. They're + * freed by rds_cong_exit() after all the connections + * have been freed. 
+ */ + rds_cong_remove_conn(conn); + + kfree(conn->c_path); + kmem_cache_free(rds_conn_slab, conn); + + spin_lock_irqsave(&rds_conn_lock, flags); + rds_conn_count--; + spin_unlock_irqrestore(&rds_conn_lock, flags); +} +EXPORT_SYMBOL_GPL(rds_conn_destroy); + +static void __rds_inc_msg_cp(struct rds_incoming *inc, + struct rds_info_iterator *iter, + void *saddr, void *daddr, int flip, bool isv6) +{ +#if IS_ENABLED(CONFIG_IPV6) + if (isv6) + rds6_inc_info_copy(inc, iter, saddr, daddr, flip); + else +#endif + rds_inc_info_copy(inc, iter, *(__be32 *)saddr, + *(__be32 *)daddr, flip); +} + +static void rds_conn_message_info_cmn(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int want_send, bool isv6) +{ + struct hlist_head *head; + struct list_head *list; + struct rds_connection *conn; + struct rds_message *rm; + unsigned int total = 0; + unsigned long flags; + size_t i; + int j; + + if (isv6) + len /= sizeof(struct rds6_info_message); + else + len /= sizeof(struct rds_info_message); + + rcu_read_lock(); + + for (i = 0, head = rds_conn_hash; i < ARRAY_SIZE(rds_conn_hash); + i++, head++) { + hlist_for_each_entry_rcu(conn, head, c_hash_node) { + struct rds_conn_path *cp; + int npaths; + + if (!isv6 && conn->c_isv6) + continue; + + npaths = (conn->c_trans->t_mp_capable ? + RDS_MPATH_WORKERS : 1); + + for (j = 0; j < npaths; j++) { + cp = &conn->c_path[j]; + if (want_send) + list = &cp->cp_send_queue; + else + list = &cp->cp_retrans; + + spin_lock_irqsave(&cp->cp_lock, flags); + + /* XXX too lazy to maintain counts.. */ + list_for_each_entry(rm, list, m_conn_item) { + total++; + if (total <= len) + __rds_inc_msg_cp(&rm->m_inc, + iter, + &conn->c_laddr, + &conn->c_faddr, + 0, isv6); + } + + spin_unlock_irqrestore(&cp->cp_lock, flags); + } + } + } + rcu_read_unlock(); + + lens->nr = total; + if (isv6) + lens->each = sizeof(struct rds6_info_message); + else + lens->each = sizeof(struct rds_info_message); +} + +static void rds_conn_message_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int want_send) +{ + rds_conn_message_info_cmn(sock, len, iter, lens, want_send, false); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_conn_message_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int want_send) +{ + rds_conn_message_info_cmn(sock, len, iter, lens, want_send, true); +} +#endif + +static void rds_conn_message_info_send(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + rds_conn_message_info(sock, len, iter, lens, 1); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_conn_message_info_send(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + rds6_conn_message_info(sock, len, iter, lens, 1); +} +#endif + +static void rds_conn_message_info_retrans(struct socket *sock, + unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + rds_conn_message_info(sock, len, iter, lens, 0); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_conn_message_info_retrans(struct socket *sock, + unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + rds6_conn_message_info(sock, len, iter, lens, 0); +} +#endif + +void rds_for_each_conn_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int 
(*visitor)(struct rds_connection *, void *), + u64 *buffer, + size_t item_len) +{ + struct hlist_head *head; + struct rds_connection *conn; + size_t i; + + rcu_read_lock(); + + lens->nr = 0; + lens->each = item_len; + + for (i = 0, head = rds_conn_hash; i < ARRAY_SIZE(rds_conn_hash); + i++, head++) { + hlist_for_each_entry_rcu(conn, head, c_hash_node) { + + /* XXX no c_lock usage.. */ + if (!visitor(conn, buffer)) + continue; + + /* We copy as much as we can fit in the buffer, + * but we count all items so that the caller + * can resize the buffer. */ + if (len >= item_len) { + rds_info_copy(iter, buffer, item_len); + len -= item_len; + } + lens->nr++; + } + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rds_for_each_conn_info); + +static void rds_walk_conn_path_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int (*visitor)(struct rds_conn_path *, void *), + u64 *buffer, + size_t item_len) +{ + struct hlist_head *head; + struct rds_connection *conn; + size_t i; + + rcu_read_lock(); + + lens->nr = 0; + lens->each = item_len; + + for (i = 0, head = rds_conn_hash; i < ARRAY_SIZE(rds_conn_hash); + i++, head++) { + hlist_for_each_entry_rcu(conn, head, c_hash_node) { + struct rds_conn_path *cp; + + /* XXX We only copy the information from the first + * path for now. The problem is that if there are + * more than one underlying paths, we cannot report + * information of all of them using the existing + * API. For example, there is only one next_tx_seq, + * which path's next_tx_seq should we report? It is + * a bug in the design of MPRDS. + */ + cp = conn->c_path; + + /* XXX no cp_lock usage.. */ + if (!visitor(cp, buffer)) + continue; + + /* We copy as much as we can fit in the buffer, + * but we count all items so that the caller + * can resize the buffer. 
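The info walkers above follow a "copy what fits, but count everything" convention: items beyond the supplied buffer are not copied, yet the total is still reported so the caller can retry with a larger buffer. A small userspace sketch of that convention, with invented demo_* names, follows.

#include <stddef.h>
#include <stdio.h>
#include <string.h>

struct demo_item { int id; };

/* copies at most @len bytes worth of items into @out and returns the total
 * number of items that exist, which may exceed what was copied
 */
static size_t demo_dump(const struct demo_item *items, size_t nr_items,
                        void *out, size_t len)
{
        size_t total = 0;
        char *dst = out;

        for (size_t i = 0; i < nr_items; i++) {
                if (len >= sizeof(struct demo_item)) {
                        memcpy(dst, &items[i], sizeof(struct demo_item));
                        dst += sizeof(struct demo_item);
                        len -= sizeof(struct demo_item);
                }
                total++;
        }
        return total;
}

int main(void)
{
        struct demo_item items[5] = { {1}, {2}, {3}, {4}, {5} };
        struct demo_item out[2];
        size_t total = demo_dump(items, 5, out, sizeof(out));

        printf("copied %zu of %zu items; retry with %zu bytes\n",
               sizeof(out) / sizeof(out[0]), total,
               total * sizeof(struct demo_item));
        return 0;
}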
+ */ + if (len >= item_len) { + rds_info_copy(iter, buffer, item_len); + len -= item_len; + } + lens->nr++; + } + } + rcu_read_unlock(); +} + +static int rds_conn_info_visitor(struct rds_conn_path *cp, void *buffer) +{ + struct rds_info_connection *cinfo = buffer; + struct rds_connection *conn = cp->cp_conn; + + if (conn->c_isv6) + return 0; + + cinfo->next_tx_seq = cp->cp_next_tx_seq; + cinfo->next_rx_seq = cp->cp_next_rx_seq; + cinfo->laddr = conn->c_laddr.s6_addr32[3]; + cinfo->faddr = conn->c_faddr.s6_addr32[3]; + cinfo->tos = conn->c_tos; + strncpy(cinfo->transport, conn->c_trans->t_name, + sizeof(cinfo->transport)); + cinfo->flags = 0; + + rds_conn_info_set(cinfo->flags, test_bit(RDS_IN_XMIT, &cp->cp_flags), + SENDING); + /* XXX Future: return the state rather than these funky bits */ + rds_conn_info_set(cinfo->flags, + atomic_read(&cp->cp_state) == RDS_CONN_CONNECTING, + CONNECTING); + rds_conn_info_set(cinfo->flags, + atomic_read(&cp->cp_state) == RDS_CONN_UP, + CONNECTED); + return 1; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int rds6_conn_info_visitor(struct rds_conn_path *cp, void *buffer) +{ + struct rds6_info_connection *cinfo6 = buffer; + struct rds_connection *conn = cp->cp_conn; + + cinfo6->next_tx_seq = cp->cp_next_tx_seq; + cinfo6->next_rx_seq = cp->cp_next_rx_seq; + cinfo6->laddr = conn->c_laddr; + cinfo6->faddr = conn->c_faddr; + strncpy(cinfo6->transport, conn->c_trans->t_name, + sizeof(cinfo6->transport)); + cinfo6->flags = 0; + + rds_conn_info_set(cinfo6->flags, test_bit(RDS_IN_XMIT, &cp->cp_flags), + SENDING); + /* XXX Future: return the state rather than these funky bits */ + rds_conn_info_set(cinfo6->flags, + atomic_read(&cp->cp_state) == RDS_CONN_CONNECTING, + CONNECTING); + rds_conn_info_set(cinfo6->flags, + atomic_read(&cp->cp_state) == RDS_CONN_UP, + CONNECTED); + /* Just return 1 as there is no error case. This is a helper function + * for rds_walk_conn_path_info() and it wants a return value. 
+ */ + return 1; +} +#endif + +static void rds_conn_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + u64 buffer[(sizeof(struct rds_info_connection) + 7) / 8]; + + rds_walk_conn_path_info(sock, len, iter, lens, + rds_conn_info_visitor, + buffer, + sizeof(struct rds_info_connection)); +} + +#if IS_ENABLED(CONFIG_IPV6) +static void rds6_conn_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + u64 buffer[(sizeof(struct rds6_info_connection) + 7) / 8]; + + rds_walk_conn_path_info(sock, len, iter, lens, + rds6_conn_info_visitor, + buffer, + sizeof(struct rds6_info_connection)); +} +#endif + +int rds_conn_init(void) +{ + int ret; + + ret = rds_loop_net_init(); /* register pernet callback */ + if (ret) + return ret; + + rds_conn_slab = kmem_cache_create("rds_connection", + sizeof(struct rds_connection), + 0, 0, NULL); + if (!rds_conn_slab) { + rds_loop_net_exit(); + return -ENOMEM; + } + + rds_info_register_func(RDS_INFO_CONNECTIONS, rds_conn_info); + rds_info_register_func(RDS_INFO_SEND_MESSAGES, + rds_conn_message_info_send); + rds_info_register_func(RDS_INFO_RETRANS_MESSAGES, + rds_conn_message_info_retrans); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_register_func(RDS6_INFO_CONNECTIONS, rds6_conn_info); + rds_info_register_func(RDS6_INFO_SEND_MESSAGES, + rds6_conn_message_info_send); + rds_info_register_func(RDS6_INFO_RETRANS_MESSAGES, + rds6_conn_message_info_retrans); +#endif + return 0; +} + +void rds_conn_exit(void) +{ + rds_loop_net_exit(); /* unregister pernet callback */ + rds_loop_exit(); + + WARN_ON(!hlist_empty(rds_conn_hash)); + + kmem_cache_destroy(rds_conn_slab); + + rds_info_deregister_func(RDS_INFO_CONNECTIONS, rds_conn_info); + rds_info_deregister_func(RDS_INFO_SEND_MESSAGES, + rds_conn_message_info_send); + rds_info_deregister_func(RDS_INFO_RETRANS_MESSAGES, + rds_conn_message_info_retrans); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_deregister_func(RDS6_INFO_CONNECTIONS, rds6_conn_info); + rds_info_deregister_func(RDS6_INFO_SEND_MESSAGES, + rds6_conn_message_info_send); + rds_info_deregister_func(RDS6_INFO_RETRANS_MESSAGES, + rds6_conn_message_info_retrans); +#endif +} + +/* + * Force a disconnect + */ +void rds_conn_path_drop(struct rds_conn_path *cp, bool destroy) +{ + atomic_set(&cp->cp_state, RDS_CONN_ERROR); + + rcu_read_lock(); + if (!destroy && rds_destroy_pending(cp->cp_conn)) { + rcu_read_unlock(); + return; + } + queue_work(rds_wq, &cp->cp_down_w); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rds_conn_path_drop); + +void rds_conn_drop(struct rds_connection *conn) +{ + WARN_ON(conn->c_trans->t_mp_capable); + rds_conn_path_drop(&conn->c_path[0], false); +} +EXPORT_SYMBOL_GPL(rds_conn_drop); + +/* + * If the connection is down, trigger a connect. We may have scheduled a + * delayed reconnect however - in this case we should not interfere. 
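rds_conn_path_connect_if_down(), just below, guards the reconnect with a test-and-set bit so that of all the callers who notice a down path, only the first actually queues the connect work. The sketch that follows models that single-winner guard with a C11 atomic_flag; it is an illustration only and every demo_* name is invented.

#include <stdatomic.h>
#include <stdio.h>

struct demo_path {
        int down;
        atomic_flag reconnect_pending;
};

static int demo_queued;                  /* stands in for the workqueue */

static void demo_connect_if_down(struct demo_path *cp)
{
        if (!cp->down)
                return;
        /* atomic_flag_test_and_set() returns the previous value, so only
         * the first caller sees "false" and queues the work
         */
        if (!atomic_flag_test_and_set(&cp->reconnect_pending))
                demo_queued++;
}

int main(void)
{
        struct demo_path cp = {
                .down = 1,
                .reconnect_pending = ATOMIC_FLAG_INIT,
        };

        demo_connect_if_down(&cp);
        demo_connect_if_down(&cp);       /* second call is a no-op */
        printf("connect work queued %d time(s)\n", demo_queued);
        return 0;
}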
+ */ +void rds_conn_path_connect_if_down(struct rds_conn_path *cp) +{ + rcu_read_lock(); + if (rds_destroy_pending(cp->cp_conn)) { + rcu_read_unlock(); + return; + } + if (rds_conn_path_state(cp) == RDS_CONN_DOWN && + !test_and_set_bit(RDS_RECONNECT_PENDING, &cp->cp_flags)) + queue_delayed_work(rds_wq, &cp->cp_conn_w, 0); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rds_conn_path_connect_if_down); + +/* Check connectivity of all paths + */ +void rds_check_all_paths(struct rds_connection *conn) +{ + int i = 0; + + do { + rds_conn_path_connect_if_down(&conn->c_path[i]); + } while (++i < conn->c_npaths); +} + +void rds_conn_connect_if_down(struct rds_connection *conn) +{ + WARN_ON(conn->c_trans->t_mp_capable); + rds_conn_path_connect_if_down(&conn->c_path[0]); +} +EXPORT_SYMBOL_GPL(rds_conn_connect_if_down); + +void +__rds_conn_path_error(struct rds_conn_path *cp, const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + vprintk(fmt, ap); + va_end(ap); + + rds_conn_path_drop(cp, false); +} diff --git a/net/rds/ib.c b/net/rds/ib.c new file mode 100644 index 000000000..9826fe7f9 --- /dev/null +++ b/net/rds/ib.c @@ -0,0 +1,607 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "ib.h" +#include "ib_mr.h" + +static unsigned int rds_ib_mr_1m_pool_size = RDS_MR_1M_POOL_SIZE; +static unsigned int rds_ib_mr_8k_pool_size = RDS_MR_8K_POOL_SIZE; +unsigned int rds_ib_retry_count = RDS_IB_DEFAULT_RETRY_COUNT; +static atomic_t rds_ib_unloading; + +module_param(rds_ib_mr_1m_pool_size, int, 0444); +MODULE_PARM_DESC(rds_ib_mr_1m_pool_size, " Max number of 1M mr per HCA"); +module_param(rds_ib_mr_8k_pool_size, int, 0444); +MODULE_PARM_DESC(rds_ib_mr_8k_pool_size, " Max number of 8K mr per HCA"); +module_param(rds_ib_retry_count, int, 0444); +MODULE_PARM_DESC(rds_ib_retry_count, " Number of hw retries before reporting an error"); + +/* + * we have a clumsy combination of RCU and a rwsem protecting this list + * because it is used both in the get_mr fast path and while blocking in + * the FMR flushing path. + */ +DECLARE_RWSEM(rds_ib_devices_lock); +struct list_head rds_ib_devices; + +/* NOTE: if also grabbing ibdev lock, grab this first */ +DEFINE_SPINLOCK(ib_nodev_conns_lock); +LIST_HEAD(ib_nodev_conns); + +static void rds_ib_nodev_connect(void) +{ + struct rds_ib_connection *ic; + + spin_lock(&ib_nodev_conns_lock); + list_for_each_entry(ic, &ib_nodev_conns, ib_node) + rds_conn_connect_if_down(ic->conn); + spin_unlock(&ib_nodev_conns_lock); +} + +static void rds_ib_dev_shutdown(struct rds_ib_device *rds_ibdev) +{ + struct rds_ib_connection *ic; + unsigned long flags; + + spin_lock_irqsave(&rds_ibdev->spinlock, flags); + list_for_each_entry(ic, &rds_ibdev->conn_list, ib_node) + rds_conn_path_drop(&ic->conn->c_path[0], true); + spin_unlock_irqrestore(&rds_ibdev->spinlock, flags); +} + +/* + * rds_ib_destroy_mr_pool() blocks on a few things and mrs drop references + * from interrupt context so we push freing off into a work struct in krdsd. 
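The device lifetime handling that follows (rds_ib_dev_put() and rds_ib_dev_free()) is a reference count whose final put defers the actual teardown to workqueue context, because the last reference may be dropped from interrupt context. The userspace model below shows only the counting structure; the "deferred" step is collapsed into a direct call, and all demo_* names are invented.

#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>

struct demo_dev {
        atomic_int refcount;
};

static void demo_dev_free(struct demo_dev *dev)
{
        /* in the kernel this body runs later, from krdsd work context */
        printf("freeing device %p\n", (void *)dev);
        free(dev);
}

static void demo_dev_get(struct demo_dev *dev)
{
        atomic_fetch_add(&dev->refcount, 1);
}

static void demo_dev_put(struct demo_dev *dev)
{
        if (atomic_fetch_sub(&dev->refcount, 1) == 1)
                demo_dev_free(dev);      /* kernel: queue_work() instead */
}

int main(void)
{
        struct demo_dev *dev = malloc(sizeof(*dev));

        atomic_init(&dev->refcount, 1);  /* creator's reference */
        demo_dev_get(dev);               /* e.g. a connection takes a ref */
        demo_dev_put(dev);               /* the connection goes away */
        demo_dev_put(dev);               /* creator drops the last ref */
        return 0;
}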
+ */ +static void rds_ib_dev_free(struct work_struct *work) +{ + struct rds_ib_ipaddr *i_ipaddr, *i_next; + struct rds_ib_device *rds_ibdev = container_of(work, + struct rds_ib_device, free_work); + + if (rds_ibdev->mr_8k_pool) + rds_ib_destroy_mr_pool(rds_ibdev->mr_8k_pool); + if (rds_ibdev->mr_1m_pool) + rds_ib_destroy_mr_pool(rds_ibdev->mr_1m_pool); + if (rds_ibdev->pd) + ib_dealloc_pd(rds_ibdev->pd); + + list_for_each_entry_safe(i_ipaddr, i_next, &rds_ibdev->ipaddr_list, list) { + list_del(&i_ipaddr->list); + kfree(i_ipaddr); + } + + kfree(rds_ibdev->vector_load); + + kfree(rds_ibdev); +} + +void rds_ib_dev_put(struct rds_ib_device *rds_ibdev) +{ + BUG_ON(refcount_read(&rds_ibdev->refcount) == 0); + if (refcount_dec_and_test(&rds_ibdev->refcount)) + queue_work(rds_wq, &rds_ibdev->free_work); +} + +static int rds_ib_add_one(struct ib_device *device) +{ + struct rds_ib_device *rds_ibdev; + int ret; + + /* Only handle IB (no iWARP) devices */ + if (device->node_type != RDMA_NODE_IB_CA) + return -EOPNOTSUPP; + + /* Device must support FRWR */ + if (!(device->attrs.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) + return -EOPNOTSUPP; + + rds_ibdev = kzalloc_node(sizeof(struct rds_ib_device), GFP_KERNEL, + ibdev_to_node(device)); + if (!rds_ibdev) + return -ENOMEM; + + spin_lock_init(&rds_ibdev->spinlock); + refcount_set(&rds_ibdev->refcount, 1); + INIT_WORK(&rds_ibdev->free_work, rds_ib_dev_free); + + INIT_LIST_HEAD(&rds_ibdev->ipaddr_list); + INIT_LIST_HEAD(&rds_ibdev->conn_list); + + rds_ibdev->max_wrs = device->attrs.max_qp_wr; + rds_ibdev->max_sge = min(device->attrs.max_send_sge, RDS_IB_MAX_SGE); + + rds_ibdev->odp_capable = + !!(device->attrs.kernel_cap_flags & + IBK_ON_DEMAND_PAGING) && + !!(device->attrs.odp_caps.per_transport_caps.rc_odp_caps & + IB_ODP_SUPPORT_WRITE) && + !!(device->attrs.odp_caps.per_transport_caps.rc_odp_caps & + IB_ODP_SUPPORT_READ); + + rds_ibdev->max_1m_mrs = device->attrs.max_mr ? + min_t(unsigned int, (device->attrs.max_mr / 2), + rds_ib_mr_1m_pool_size) : rds_ib_mr_1m_pool_size; + + rds_ibdev->max_8k_mrs = device->attrs.max_mr ? 
+ min_t(unsigned int, ((device->attrs.max_mr / 2) * RDS_MR_8K_SCALE), + rds_ib_mr_8k_pool_size) : rds_ib_mr_8k_pool_size; + + rds_ibdev->max_initiator_depth = device->attrs.max_qp_init_rd_atom; + rds_ibdev->max_responder_resources = device->attrs.max_qp_rd_atom; + + rds_ibdev->vector_load = kcalloc(device->num_comp_vectors, + sizeof(int), + GFP_KERNEL); + if (!rds_ibdev->vector_load) { + pr_err("RDS/IB: %s failed to allocate vector memory\n", + __func__); + ret = -ENOMEM; + goto put_dev; + } + + rds_ibdev->dev = device; + rds_ibdev->pd = ib_alloc_pd(device, 0); + if (IS_ERR(rds_ibdev->pd)) { + ret = PTR_ERR(rds_ibdev->pd); + rds_ibdev->pd = NULL; + goto put_dev; + } + + rds_ibdev->mr_1m_pool = + rds_ib_create_mr_pool(rds_ibdev, RDS_IB_MR_1M_POOL); + if (IS_ERR(rds_ibdev->mr_1m_pool)) { + ret = PTR_ERR(rds_ibdev->mr_1m_pool); + rds_ibdev->mr_1m_pool = NULL; + goto put_dev; + } + + rds_ibdev->mr_8k_pool = + rds_ib_create_mr_pool(rds_ibdev, RDS_IB_MR_8K_POOL); + if (IS_ERR(rds_ibdev->mr_8k_pool)) { + ret = PTR_ERR(rds_ibdev->mr_8k_pool); + rds_ibdev->mr_8k_pool = NULL; + goto put_dev; + } + + rdsdebug("RDS/IB: max_mr = %d, max_wrs = %d, max_sge = %d, max_1m_mrs = %d, max_8k_mrs = %d\n", + device->attrs.max_mr, rds_ibdev->max_wrs, rds_ibdev->max_sge, + rds_ibdev->max_1m_mrs, rds_ibdev->max_8k_mrs); + + pr_info("RDS/IB: %s: added\n", device->name); + + down_write(&rds_ib_devices_lock); + list_add_tail_rcu(&rds_ibdev->list, &rds_ib_devices); + up_write(&rds_ib_devices_lock); + refcount_inc(&rds_ibdev->refcount); + + ib_set_client_data(device, &rds_ib_client, rds_ibdev); + + rds_ib_nodev_connect(); + return 0; + +put_dev: + rds_ib_dev_put(rds_ibdev); + return ret; +} + +/* + * New connections use this to find the device to associate with the + * connection. It's not in the fast path so we're not concerned about the + * performance of the IB call. (As of this writing, it uses an interrupt + * blocking spinlock to serialize walking a per-device list of all registered + * clients.) + * + * RCU is used to handle incoming connections racing with device teardown. + * Rather than use a lock to serialize removal from the client_data and + * getting a new reference, we use an RCU grace period. The destruction + * path removes the device from client_data and then waits for all RCU + * readers to finish. + * + * A new connection can get NULL from this if its arriving on a + * device that is in the process of being removed. + */ +struct rds_ib_device *rds_ib_get_client_data(struct ib_device *device) +{ + struct rds_ib_device *rds_ibdev; + + rcu_read_lock(); + rds_ibdev = ib_get_client_data(device, &rds_ib_client); + if (rds_ibdev) + refcount_inc(&rds_ibdev->refcount); + rcu_read_unlock(); + return rds_ibdev; +} + +/* + * The IB stack is letting us know that a device is going away. This can + * happen if the underlying HCA driver is removed or if PCI hotplug is removing + * the pci function, for example. + * + * This can be called at any time and can be racing with any other RDS path. + */ +static void rds_ib_remove_one(struct ib_device *device, void *client_data) +{ + struct rds_ib_device *rds_ibdev = client_data; + + rds_ib_dev_shutdown(rds_ibdev); + + /* stop connection attempts from getting a reference to this device. 
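rds_ib_get_client_data() above pairs an RCU read-side section with a reference count so a lookup either sees NULL or obtains both the pointer and its own reference before the removal path can free the device. The sketch below is a functional, not performance-faithful, model: a mutex stands in for rcu_read_lock()/synchronize_rcu(), and every demo_* name is invented.

#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>
#include <stdlib.h>

struct demo_ibdev {
        atomic_int refcount;
};

static struct demo_ibdev *demo_client_data;          /* per ib_device slot */
static pthread_mutex_t demo_client_lock = PTHREAD_MUTEX_INITIALIZER;

static void demo_dev_put(struct demo_ibdev *dev)
{
        if (atomic_fetch_sub(&dev->refcount, 1) == 1)
                free(dev);
}

static struct demo_ibdev *demo_get_client_data(void)
{
        struct demo_ibdev *dev;

        pthread_mutex_lock(&demo_client_lock);        /* rcu_read_lock()   */
        dev = demo_client_data;
        if (dev)
                atomic_fetch_add(&dev->refcount, 1);  /* take our own ref  */
        pthread_mutex_unlock(&demo_client_lock);      /* rcu_read_unlock() */
        return dev;
}

static void demo_remove_one(void)
{
        struct demo_ibdev *dev;

        pthread_mutex_lock(&demo_client_lock);        /* synchronize_rcu() */
        dev = demo_client_data;
        demo_client_data = NULL;         /* new lookups now return NULL */
        pthread_mutex_unlock(&demo_client_lock);
        if (dev)
                demo_dev_put(dev);       /* drop the slot's reference */
}

int main(void)
{
        demo_client_data = calloc(1, sizeof(*demo_client_data));
        atomic_init(&demo_client_data->refcount, 1);

        struct demo_ibdev *dev = demo_get_client_data();

        demo_remove_one();
        printf("lookup after removal: %p\n", (void *)demo_get_client_data());
        if (dev)
                demo_dev_put(dev);       /* lookup caller drops its reference */
        return 0;
}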
*/ + ib_set_client_data(device, &rds_ib_client, NULL); + + down_write(&rds_ib_devices_lock); + list_del_rcu(&rds_ibdev->list); + up_write(&rds_ib_devices_lock); + + /* + * This synchronize rcu is waiting for readers of both the ib + * client data and the devices list to finish before we drop + * both of those references. + */ + synchronize_rcu(); + rds_ib_dev_put(rds_ibdev); + rds_ib_dev_put(rds_ibdev); +} + +struct ib_client rds_ib_client = { + .name = "rds_ib", + .add = rds_ib_add_one, + .remove = rds_ib_remove_one +}; + +static int rds_ib_conn_info_visitor(struct rds_connection *conn, + void *buffer) +{ + struct rds_info_rdma_connection *iinfo = buffer; + struct rds_ib_connection *ic = conn->c_transport_data; + + /* We will only ever look at IB transports */ + if (conn->c_trans != &rds_ib_transport) + return 0; + if (conn->c_isv6) + return 0; + + iinfo->src_addr = conn->c_laddr.s6_addr32[3]; + iinfo->dst_addr = conn->c_faddr.s6_addr32[3]; + if (ic) { + iinfo->tos = conn->c_tos; + iinfo->sl = ic->i_sl; + } + + memset(&iinfo->src_gid, 0, sizeof(iinfo->src_gid)); + memset(&iinfo->dst_gid, 0, sizeof(iinfo->dst_gid)); + if (rds_conn_state(conn) == RDS_CONN_UP) { + struct rds_ib_device *rds_ibdev; + + rdma_read_gids(ic->i_cm_id, (union ib_gid *)&iinfo->src_gid, + (union ib_gid *)&iinfo->dst_gid); + + rds_ibdev = ic->rds_ibdev; + iinfo->max_send_wr = ic->i_send_ring.w_nr; + iinfo->max_recv_wr = ic->i_recv_ring.w_nr; + iinfo->max_send_sge = rds_ibdev->max_sge; + rds_ib_get_mr_info(rds_ibdev, iinfo); + iinfo->cache_allocs = atomic_read(&ic->i_cache_allocs); + } + return 1; +} + +#if IS_ENABLED(CONFIG_IPV6) +/* IPv6 version of rds_ib_conn_info_visitor(). */ +static int rds6_ib_conn_info_visitor(struct rds_connection *conn, + void *buffer) +{ + struct rds6_info_rdma_connection *iinfo6 = buffer; + struct rds_ib_connection *ic = conn->c_transport_data; + + /* We will only ever look at IB transports */ + if (conn->c_trans != &rds_ib_transport) + return 0; + + iinfo6->src_addr = conn->c_laddr; + iinfo6->dst_addr = conn->c_faddr; + if (ic) { + iinfo6->tos = conn->c_tos; + iinfo6->sl = ic->i_sl; + } + + memset(&iinfo6->src_gid, 0, sizeof(iinfo6->src_gid)); + memset(&iinfo6->dst_gid, 0, sizeof(iinfo6->dst_gid)); + + if (rds_conn_state(conn) == RDS_CONN_UP) { + struct rds_ib_device *rds_ibdev; + + rdma_read_gids(ic->i_cm_id, (union ib_gid *)&iinfo6->src_gid, + (union ib_gid *)&iinfo6->dst_gid); + rds_ibdev = ic->rds_ibdev; + iinfo6->max_send_wr = ic->i_send_ring.w_nr; + iinfo6->max_recv_wr = ic->i_recv_ring.w_nr; + iinfo6->max_send_sge = rds_ibdev->max_sge; + rds6_ib_get_mr_info(rds_ibdev, iinfo6); + iinfo6->cache_allocs = atomic_read(&ic->i_cache_allocs); + } + return 1; +} +#endif + +static void rds_ib_ic_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + u64 buffer[(sizeof(struct rds_info_rdma_connection) + 7) / 8]; + + rds_for_each_conn_info(sock, len, iter, lens, + rds_ib_conn_info_visitor, + buffer, + sizeof(struct rds_info_rdma_connection)); +} + +#if IS_ENABLED(CONFIG_IPV6) +/* IPv6 version of rds_ib_ic_info(). 
*/ +static void rds6_ib_ic_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + u64 buffer[(sizeof(struct rds6_info_rdma_connection) + 7) / 8]; + + rds_for_each_conn_info(sock, len, iter, lens, + rds6_ib_conn_info_visitor, + buffer, + sizeof(struct rds6_info_rdma_connection)); +} +#endif + +/* + * Early RDS/IB was built to only bind to an address if there is an IPoIB + * device with that address set. + * + * If it were me, I'd advocate for something more flexible. Sending and + * receiving should be device-agnostic. Transports would try and maintain + * connections between peers who have messages queued. Userspace would be + * allowed to influence which paths have priority. We could call userspace + * asserting this policy "routing". + */ +static int rds_ib_laddr_check(struct net *net, const struct in6_addr *addr, + __u32 scope_id) +{ + int ret; + struct rdma_cm_id *cm_id; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6; +#endif + struct sockaddr_in sin; + struct sockaddr *sa; + bool isv4; + + isv4 = ipv6_addr_v4mapped(addr); + /* Create a CMA ID and try to bind it. This catches both + * IB and iWARP capable NICs. + */ + cm_id = rdma_create_id(&init_net, rds_rdma_cm_event_handler, + NULL, RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(cm_id)) + return PTR_ERR(cm_id); + + if (isv4) { + memset(&sin, 0, sizeof(sin)); + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = addr->s6_addr32[3]; + sa = (struct sockaddr *)&sin; + } else { +#if IS_ENABLED(CONFIG_IPV6) + memset(&sin6, 0, sizeof(sin6)); + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = *addr; + sin6.sin6_scope_id = scope_id; + sa = (struct sockaddr *)&sin6; + + /* XXX Do a special IPv6 link local address check here. The + * reason is that rdma_bind_addr() always succeeds with IPv6 + * link local address regardless it is indeed configured in a + * system. + */ + if (ipv6_addr_type(addr) & IPV6_ADDR_LINKLOCAL) { + struct net_device *dev; + + if (scope_id == 0) { + ret = -EADDRNOTAVAIL; + goto out; + } + + /* Use init_net for now as RDS is not network + * name space aware. + */ + dev = dev_get_by_index(&init_net, scope_id); + if (!dev) { + ret = -EADDRNOTAVAIL; + goto out; + } + if (!ipv6_chk_addr(&init_net, addr, dev, 1)) { + dev_put(dev); + ret = -EADDRNOTAVAIL; + goto out; + } + dev_put(dev); + } +#else + ret = -EADDRNOTAVAIL; + goto out; +#endif + } + + /* rdma_bind_addr will only succeed for IB & iWARP devices */ + ret = rdma_bind_addr(cm_id, sa); + /* due to this, we will claim to support iWARP devices unless we + check node_type. */ + if (ret || !cm_id->device || + cm_id->device->node_type != RDMA_NODE_IB_CA) + ret = -EADDRNOTAVAIL; + + rdsdebug("addr %pI6c%%%u ret %d node type %d\n", + addr, scope_id, ret, + cm_id->device ? 
cm_id->device->node_type : -1); + +out: + rdma_destroy_id(cm_id); + + return ret; +} + +static void rds_ib_unregister_client(void) +{ + ib_unregister_client(&rds_ib_client); + /* wait for rds_ib_dev_free() to complete */ + flush_workqueue(rds_wq); +} + +static void rds_ib_set_unloading(void) +{ + atomic_set(&rds_ib_unloading, 1); +} + +static bool rds_ib_is_unloading(struct rds_connection *conn) +{ + struct rds_conn_path *cp = &conn->c_path[0]; + + return (test_bit(RDS_DESTROY_PENDING, &cp->cp_flags) || + atomic_read(&rds_ib_unloading) != 0); +} + +void rds_ib_exit(void) +{ + rds_ib_set_unloading(); + synchronize_rcu(); + rds_info_deregister_func(RDS_INFO_IB_CONNECTIONS, rds_ib_ic_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_deregister_func(RDS6_INFO_IB_CONNECTIONS, rds6_ib_ic_info); +#endif + rds_ib_unregister_client(); + rds_ib_destroy_nodev_conns(); + rds_ib_sysctl_exit(); + rds_ib_recv_exit(); + rds_trans_unregister(&rds_ib_transport); + rds_ib_mr_exit(); +} + +static u8 rds_ib_get_tos_map(u8 tos) +{ + /* 1:1 user to transport map for RDMA transport. + * In future, if custom map is desired, hook can export + * user configurable map. + */ + return tos; +} + +struct rds_transport rds_ib_transport = { + .laddr_check = rds_ib_laddr_check, + .xmit_path_complete = rds_ib_xmit_path_complete, + .xmit = rds_ib_xmit, + .xmit_rdma = rds_ib_xmit_rdma, + .xmit_atomic = rds_ib_xmit_atomic, + .recv_path = rds_ib_recv_path, + .conn_alloc = rds_ib_conn_alloc, + .conn_free = rds_ib_conn_free, + .conn_path_connect = rds_ib_conn_path_connect, + .conn_path_shutdown = rds_ib_conn_path_shutdown, + .inc_copy_to_user = rds_ib_inc_copy_to_user, + .inc_free = rds_ib_inc_free, + .cm_initiate_connect = rds_ib_cm_initiate_connect, + .cm_handle_connect = rds_ib_cm_handle_connect, + .cm_connect_complete = rds_ib_cm_connect_complete, + .stats_info_copy = rds_ib_stats_info_copy, + .exit = rds_ib_exit, + .get_mr = rds_ib_get_mr, + .sync_mr = rds_ib_sync_mr, + .free_mr = rds_ib_free_mr, + .flush_mrs = rds_ib_flush_mrs, + .get_tos_map = rds_ib_get_tos_map, + .t_owner = THIS_MODULE, + .t_name = "infiniband", + .t_unloading = rds_ib_is_unloading, + .t_type = RDS_TRANS_IB +}; + +int rds_ib_init(void) +{ + int ret; + + INIT_LIST_HEAD(&rds_ib_devices); + + ret = rds_ib_mr_init(); + if (ret) + goto out; + + ret = ib_register_client(&rds_ib_client); + if (ret) + goto out_mr_exit; + + ret = rds_ib_sysctl_init(); + if (ret) + goto out_ibreg; + + ret = rds_ib_recv_init(); + if (ret) + goto out_sysctl; + + rds_trans_register(&rds_ib_transport); + + rds_info_register_func(RDS_INFO_IB_CONNECTIONS, rds_ib_ic_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_register_func(RDS6_INFO_IB_CONNECTIONS, rds6_ib_ic_info); +#endif + + goto out; + +out_sysctl: + rds_ib_sysctl_exit(); +out_ibreg: + rds_ib_unregister_client(); +out_mr_exit: + rds_ib_mr_exit(); +out: + return ret; +} + +MODULE_LICENSE("GPL"); diff --git a/net/rds/ib.h b/net/rds/ib.h new file mode 100644 index 000000000..2ba71102b --- /dev/null +++ b/net/rds/ib.h @@ -0,0 +1,458 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_IB_H +#define _RDS_IB_H + +#include +#include +#include +#include +#include +#include "rds.h" +#include "rdma_transport.h" + +#define RDS_IB_MAX_SGE 8 +#define RDS_IB_RECV_SGE 2 + +#define RDS_IB_DEFAULT_RECV_WR 1024 +#define RDS_IB_DEFAULT_SEND_WR 256 +#define RDS_IB_DEFAULT_FR_WR 512 + +#define RDS_IB_DEFAULT_RETRY_COUNT 1 + +#define RDS_IB_SUPPORTED_PROTOCOLS 0x00000003 /* minor versions supported */ + +#define RDS_IB_RECYCLE_BATCH_COUNT 32 + +#define 
RDS_IB_WC_MAX 32 + +extern struct rw_semaphore rds_ib_devices_lock; +extern struct list_head rds_ib_devices; + +/* + * IB posts RDS_FRAG_SIZE fragments of pages to the receive queues to + * try and minimize the amount of memory tied up both the device and + * socket receive queues. + */ +struct rds_page_frag { + struct list_head f_item; + struct list_head f_cache_entry; + struct scatterlist f_sg; +}; + +struct rds_ib_incoming { + struct list_head ii_frags; + struct list_head ii_cache_entry; + struct rds_incoming ii_inc; +}; + +struct rds_ib_cache_head { + struct list_head *first; + unsigned long count; +}; + +struct rds_ib_refill_cache { + struct rds_ib_cache_head __percpu *percpu; + struct list_head *xfer; + struct list_head *ready; +}; + +/* This is the common structure for the IB private data exchange in setting up + * an RDS connection. The exchange is different for IPv4 and IPv6 connections. + * The reason is that the address size is different and the addresses + * exchanged are in the beginning of the structure. Hence it is not possible + * for interoperability if same structure is used. + */ +struct rds_ib_conn_priv_cmn { + u8 ricpc_protocol_major; + u8 ricpc_protocol_minor; + __be16 ricpc_protocol_minor_mask; /* bitmask */ + u8 ricpc_dp_toss; + u8 ripc_reserved1; + __be16 ripc_reserved2; + __be64 ricpc_ack_seq; + __be32 ricpc_credit; /* non-zero enables flow ctl */ +}; + +struct rds_ib_connect_private { + /* Add new fields at the end, and don't permute existing fields. */ + __be32 dp_saddr; + __be32 dp_daddr; + struct rds_ib_conn_priv_cmn dp_cmn; +}; + +struct rds6_ib_connect_private { + /* Add new fields at the end, and don't permute existing fields. */ + struct in6_addr dp_saddr; + struct in6_addr dp_daddr; + struct rds_ib_conn_priv_cmn dp_cmn; +}; + +#define dp_protocol_major dp_cmn.ricpc_protocol_major +#define dp_protocol_minor dp_cmn.ricpc_protocol_minor +#define dp_protocol_minor_mask dp_cmn.ricpc_protocol_minor_mask +#define dp_ack_seq dp_cmn.ricpc_ack_seq +#define dp_credit dp_cmn.ricpc_credit + +union rds_ib_conn_priv { + struct rds_ib_connect_private ricp_v4; + struct rds6_ib_connect_private ricp_v6; +}; + +struct rds_ib_send_work { + void *s_op; + union { + struct ib_send_wr s_wr; + struct ib_rdma_wr s_rdma_wr; + struct ib_atomic_wr s_atomic_wr; + }; + struct ib_sge s_sge[RDS_IB_MAX_SGE]; + unsigned long s_queued; +}; + +struct rds_ib_recv_work { + struct rds_ib_incoming *r_ibinc; + struct rds_page_frag *r_frag; + struct ib_recv_wr r_wr; + struct ib_sge r_sge[2]; +}; + +struct rds_ib_work_ring { + u32 w_nr; + u32 w_alloc_ptr; + u32 w_alloc_ctr; + u32 w_free_ptr; + atomic_t w_free_ctr; +}; + +/* Rings are posted with all the allocations they'll need to queue the + * incoming message to the receiving socket so this can't fail. + * All fragments start with a header, so we can make sure we're not receiving + * garbage, and we can tell a small 8 byte fragment from an ACK frame. 
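A fixed-size work ring of the kind declared above (struct rds_ib_work_ring) typically tracks occupancy with an allocation counter advanced by the posting path and a free counter advanced by the completion path. The sketch below illustrates that general idea only; it is not the RDS ring implementation, and all demo_* names are invented.

#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

struct demo_ring {
        unsigned int nr;                 /* total entries          */
        unsigned int alloc_ptr;          /* next slot to hand out  */
        unsigned int alloc_ctr;          /* entries handed out     */
        atomic_uint  free_ctr;           /* entries completed (CQ) */
};

static bool demo_ring_alloc(struct demo_ring *ring, unsigned int *pos)
{
        unsigned int used = ring->alloc_ctr - atomic_load(&ring->free_ctr);

        if (used >= ring->nr)
                return false;            /* ring full, caller must wait */
        *pos = ring->alloc_ptr;
        ring->alloc_ptr = (ring->alloc_ptr + 1) % ring->nr;
        ring->alloc_ctr++;
        return true;
}

/* called from the completion path, possibly from another context */
static void demo_ring_free(struct demo_ring *ring, unsigned int n)
{
        atomic_fetch_add(&ring->free_ctr, n);
}

int main(void)
{
        struct demo_ring ring = { .nr = 4 };
        unsigned int pos;

        atomic_init(&ring.free_ctr, 0);
        while (demo_ring_alloc(&ring, &pos))
                printf("posted work in slot %u\n", pos);
        demo_ring_free(&ring, 2);
        printf("after 2 completions, alloc works again: %d\n",
               demo_ring_alloc(&ring, &pos));
        return 0;
}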
+ */ +struct rds_ib_ack_state { + u64 ack_next; + u64 ack_recv; + unsigned int ack_required:1; + unsigned int ack_next_valid:1; + unsigned int ack_recv_valid:1; +}; + + +struct rds_ib_device; + +struct rds_ib_connection { + + struct list_head ib_node; + struct rds_ib_device *rds_ibdev; + struct rds_connection *conn; + + /* alphabet soup, IBTA style */ + struct rdma_cm_id *i_cm_id; + struct ib_pd *i_pd; + struct ib_cq *i_send_cq; + struct ib_cq *i_recv_cq; + struct ib_wc i_send_wc[RDS_IB_WC_MAX]; + struct ib_wc i_recv_wc[RDS_IB_WC_MAX]; + + /* To control the number of wrs from fastreg */ + atomic_t i_fastreg_wrs; + atomic_t i_fastreg_inuse_count; + + /* interrupt handling */ + struct tasklet_struct i_send_tasklet; + struct tasklet_struct i_recv_tasklet; + + /* tx */ + struct rds_ib_work_ring i_send_ring; + struct rm_data_op *i_data_op; + struct rds_header **i_send_hdrs; + dma_addr_t *i_send_hdrs_dma; + struct rds_ib_send_work *i_sends; + atomic_t i_signaled_sends; + + /* rx */ + struct mutex i_recv_mutex; + struct rds_ib_work_ring i_recv_ring; + struct rds_ib_incoming *i_ibinc; + u32 i_recv_data_rem; + struct rds_header **i_recv_hdrs; + dma_addr_t *i_recv_hdrs_dma; + struct rds_ib_recv_work *i_recvs; + u64 i_ack_recv; /* last ACK received */ + struct rds_ib_refill_cache i_cache_incs; + struct rds_ib_refill_cache i_cache_frags; + atomic_t i_cache_allocs; + + /* sending acks */ + unsigned long i_ack_flags; +#ifdef KERNEL_HAS_ATOMIC64 + atomic64_t i_ack_next; /* next ACK to send */ +#else + spinlock_t i_ack_lock; /* protect i_ack_next */ + u64 i_ack_next; /* next ACK to send */ +#endif + struct rds_header *i_ack; + struct ib_send_wr i_ack_wr; + struct ib_sge i_ack_sge; + dma_addr_t i_ack_dma; + unsigned long i_ack_queued; + + /* Flow control related information + * + * Our algorithm uses a pair variables that we need to access + * atomically - one for the send credits, and one posted + * recv credits we need to transfer to remote. 
+ * Rather than protect them using a slow spinlock, we put both into + * a single atomic_t and update it using cmpxchg + */ + atomic_t i_credits; + + /* Protocol version specific information */ + unsigned int i_flowctl:1; /* enable/disable flow ctl */ + + /* Batched completions */ + unsigned int i_unsignaled_wrs; + + /* Endpoint role in connection */ + bool i_active_side; + atomic_t i_cq_quiesce; + + /* Send/Recv vectors */ + int i_scq_vector; + int i_rcq_vector; + u8 i_sl; +}; + +/* This assumes that atomic_t is at least 32 bits */ +#define IB_GET_SEND_CREDITS(v) ((v) & 0xffff) +#define IB_GET_POST_CREDITS(v) ((v) >> 16) +#define IB_SET_SEND_CREDITS(v) ((v) & 0xffff) +#define IB_SET_POST_CREDITS(v) ((v) << 16) + +struct rds_ib_ipaddr { + struct list_head list; + __be32 ipaddr; + struct rcu_head rcu; +}; + +enum { + RDS_IB_MR_8K_POOL, + RDS_IB_MR_1M_POOL, +}; + +struct rds_ib_device { + struct list_head list; + struct list_head ipaddr_list; + struct list_head conn_list; + struct ib_device *dev; + struct ib_pd *pd; + u8 odp_capable:1; + + unsigned int max_mrs; + struct rds_ib_mr_pool *mr_1m_pool; + struct rds_ib_mr_pool *mr_8k_pool; + unsigned int max_8k_mrs; + unsigned int max_1m_mrs; + int max_sge; + unsigned int max_wrs; + unsigned int max_initiator_depth; + unsigned int max_responder_resources; + spinlock_t spinlock; /* protect the above */ + refcount_t refcount; + struct work_struct free_work; + int *vector_load; +}; + +#define rdsibdev_to_node(rdsibdev) ibdev_to_node(rdsibdev->dev) + +/* bits for i_ack_flags */ +#define IB_ACK_IN_FLIGHT 0 +#define IB_ACK_REQUESTED 1 + +/* Magic WR_ID for ACKs */ +#define RDS_IB_ACK_WR_ID (~(u64) 0) + +struct rds_ib_statistics { + uint64_t s_ib_connect_raced; + uint64_t s_ib_listen_closed_stale; + uint64_t s_ib_evt_handler_call; + uint64_t s_ib_tasklet_call; + uint64_t s_ib_tx_cq_event; + uint64_t s_ib_tx_ring_full; + uint64_t s_ib_tx_throttle; + uint64_t s_ib_tx_sg_mapping_failure; + uint64_t s_ib_tx_stalled; + uint64_t s_ib_tx_credit_updates; + uint64_t s_ib_rx_cq_event; + uint64_t s_ib_rx_ring_empty; + uint64_t s_ib_rx_refill_from_cq; + uint64_t s_ib_rx_refill_from_thread; + uint64_t s_ib_rx_alloc_limit; + uint64_t s_ib_rx_total_frags; + uint64_t s_ib_rx_total_incs; + uint64_t s_ib_rx_credit_updates; + uint64_t s_ib_ack_sent; + uint64_t s_ib_ack_send_failure; + uint64_t s_ib_ack_send_delayed; + uint64_t s_ib_ack_send_piggybacked; + uint64_t s_ib_ack_received; + uint64_t s_ib_rdma_mr_8k_alloc; + uint64_t s_ib_rdma_mr_8k_free; + uint64_t s_ib_rdma_mr_8k_used; + uint64_t s_ib_rdma_mr_8k_pool_flush; + uint64_t s_ib_rdma_mr_8k_pool_wait; + uint64_t s_ib_rdma_mr_8k_pool_depleted; + uint64_t s_ib_rdma_mr_1m_alloc; + uint64_t s_ib_rdma_mr_1m_free; + uint64_t s_ib_rdma_mr_1m_used; + uint64_t s_ib_rdma_mr_1m_pool_flush; + uint64_t s_ib_rdma_mr_1m_pool_wait; + uint64_t s_ib_rdma_mr_1m_pool_depleted; + uint64_t s_ib_rdma_mr_8k_reused; + uint64_t s_ib_rdma_mr_1m_reused; + uint64_t s_ib_atomic_cswp; + uint64_t s_ib_atomic_fadd; + uint64_t s_ib_recv_added_to_cache; + uint64_t s_ib_recv_removed_from_cache; +}; + +extern struct workqueue_struct *rds_ib_wq; + +/* + * Fake ib_dma_sync_sg_for_{cpu,device} as long as ib_verbs.h + * doesn't define it. 
+ */ +static inline void rds_ib_dma_sync_sg_for_cpu(struct ib_device *dev, + struct scatterlist *sglist, + unsigned int sg_dma_len, + int direction) +{ + struct scatterlist *sg; + unsigned int i; + + for_each_sg(sglist, sg, sg_dma_len, i) { + ib_dma_sync_single_for_cpu(dev, sg_dma_address(sg), + sg_dma_len(sg), direction); + } +} +#define ib_dma_sync_sg_for_cpu rds_ib_dma_sync_sg_for_cpu + +static inline void rds_ib_dma_sync_sg_for_device(struct ib_device *dev, + struct scatterlist *sglist, + unsigned int sg_dma_len, + int direction) +{ + struct scatterlist *sg; + unsigned int i; + + for_each_sg(sglist, sg, sg_dma_len, i) { + ib_dma_sync_single_for_device(dev, sg_dma_address(sg), + sg_dma_len(sg), direction); + } +} +#define ib_dma_sync_sg_for_device rds_ib_dma_sync_sg_for_device + + +/* ib.c */ +extern struct rds_transport rds_ib_transport; +struct rds_ib_device *rds_ib_get_client_data(struct ib_device *device); +void rds_ib_dev_put(struct rds_ib_device *rds_ibdev); +extern struct ib_client rds_ib_client; + +extern unsigned int rds_ib_retry_count; + +extern spinlock_t ib_nodev_conns_lock; +extern struct list_head ib_nodev_conns; + +/* ib_cm.c */ +int rds_ib_conn_alloc(struct rds_connection *conn, gfp_t gfp); +void rds_ib_conn_free(void *arg); +int rds_ib_conn_path_connect(struct rds_conn_path *cp); +void rds_ib_conn_path_shutdown(struct rds_conn_path *cp); +void rds_ib_state_change(struct sock *sk); +int rds_ib_listen_init(void); +void rds_ib_listen_stop(void); +__printf(2, 3) +void __rds_ib_conn_error(struct rds_connection *conn, const char *, ...); +int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event, bool isv6); +int rds_ib_cm_initiate_connect(struct rdma_cm_id *cm_id, bool isv6); +void rds_ib_cm_connect_complete(struct rds_connection *conn, + struct rdma_cm_event *event); + +#define rds_ib_conn_error(conn, fmt...) 
\ + __rds_ib_conn_error(conn, KERN_WARNING "RDS/IB: " fmt) + +/* ib_rdma.c */ +int rds_ib_update_ipaddr(struct rds_ib_device *rds_ibdev, + struct in6_addr *ipaddr); +void rds_ib_add_conn(struct rds_ib_device *rds_ibdev, struct rds_connection *conn); +void rds_ib_remove_conn(struct rds_ib_device *rds_ibdev, struct rds_connection *conn); +void rds_ib_destroy_nodev_conns(void); +void rds_ib_mr_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc); + +/* ib_recv.c */ +int rds_ib_recv_init(void); +void rds_ib_recv_exit(void); +int rds_ib_recv_path(struct rds_conn_path *conn); +int rds_ib_recv_alloc_caches(struct rds_ib_connection *ic, gfp_t gfp); +void rds_ib_recv_free_caches(struct rds_ib_connection *ic); +void rds_ib_recv_refill(struct rds_connection *conn, int prefill, gfp_t gfp); +void rds_ib_inc_free(struct rds_incoming *inc); +int rds_ib_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to); +void rds_ib_recv_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc, + struct rds_ib_ack_state *state); +void rds_ib_recv_tasklet_fn(unsigned long data); +void rds_ib_recv_init_ring(struct rds_ib_connection *ic); +void rds_ib_recv_clear_ring(struct rds_ib_connection *ic); +void rds_ib_recv_init_ack(struct rds_ib_connection *ic); +void rds_ib_attempt_ack(struct rds_ib_connection *ic); +void rds_ib_ack_send_complete(struct rds_ib_connection *ic); +u64 rds_ib_piggyb_ack(struct rds_ib_connection *ic); +void rds_ib_set_ack(struct rds_ib_connection *ic, u64 seq, int ack_required); + +/* ib_ring.c */ +void rds_ib_ring_init(struct rds_ib_work_ring *ring, u32 nr); +void rds_ib_ring_resize(struct rds_ib_work_ring *ring, u32 nr); +u32 rds_ib_ring_alloc(struct rds_ib_work_ring *ring, u32 val, u32 *pos); +void rds_ib_ring_free(struct rds_ib_work_ring *ring, u32 val); +void rds_ib_ring_unalloc(struct rds_ib_work_ring *ring, u32 val); +int rds_ib_ring_empty(struct rds_ib_work_ring *ring); +int rds_ib_ring_low(struct rds_ib_work_ring *ring); +u32 rds_ib_ring_oldest(struct rds_ib_work_ring *ring); +u32 rds_ib_ring_completed(struct rds_ib_work_ring *ring, u32 wr_id, u32 oldest); +extern wait_queue_head_t rds_ib_ring_empty_wait; + +/* ib_send.c */ +void rds_ib_xmit_path_complete(struct rds_conn_path *cp); +int rds_ib_xmit(struct rds_connection *conn, struct rds_message *rm, + unsigned int hdr_off, unsigned int sg, unsigned int off); +void rds_ib_send_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc); +void rds_ib_send_init_ring(struct rds_ib_connection *ic); +void rds_ib_send_clear_ring(struct rds_ib_connection *ic); +int rds_ib_xmit_rdma(struct rds_connection *conn, struct rm_rdma_op *op); +void rds_ib_send_add_credits(struct rds_connection *conn, unsigned int credits); +void rds_ib_advertise_credits(struct rds_connection *conn, unsigned int posted); +int rds_ib_send_grab_credits(struct rds_ib_connection *ic, u32 wanted, + u32 *adv_credits, int need_posted, int max_posted); +int rds_ib_xmit_atomic(struct rds_connection *conn, struct rm_atomic_op *op); + +/* ib_stats.c */ +DECLARE_PER_CPU_SHARED_ALIGNED(struct rds_ib_statistics, rds_ib_stats); +#define rds_ib_stats_inc(member) rds_stats_inc_which(rds_ib_stats, member) +#define rds_ib_stats_add(member, count) \ + rds_stats_add_which(rds_ib_stats, member, count) +unsigned int rds_ib_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail); + +/* ib_sysctl.c */ +int rds_ib_sysctl_init(void); +void rds_ib_sysctl_exit(void); +extern unsigned long rds_ib_sysctl_max_send_wr; +extern unsigned long rds_ib_sysctl_max_recv_wr; 
+extern unsigned long rds_ib_sysctl_max_unsig_wrs; +extern unsigned long rds_ib_sysctl_max_unsig_bytes; +extern unsigned long rds_ib_sysctl_max_recv_allocation; +extern unsigned int rds_ib_sysctl_flow_control; + +#endif diff --git a/net/rds/ib_cm.c b/net/rds/ib_cm.c new file mode 100644 index 000000000..26b069e19 --- /dev/null +++ b/net/rds/ib_cm.c @@ -0,0 +1,1287 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "ib.h" +#include "ib_mr.h" + +/* + * Set the selected protocol version + */ +static void rds_ib_set_protocol(struct rds_connection *conn, unsigned int version) +{ + conn->c_version = version; +} + +/* + * Set up flow control + */ +static void rds_ib_set_flow_control(struct rds_connection *conn, u32 credits) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + + if (rds_ib_sysctl_flow_control && credits != 0) { + /* We're doing flow control */ + ic->i_flowctl = 1; + rds_ib_send_add_credits(conn, credits); + } else { + ic->i_flowctl = 0; + } +} + +/* + * Connection established. + * We get here for both outgoing and incoming connection. + */ +void rds_ib_cm_connect_complete(struct rds_connection *conn, struct rdma_cm_event *event) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + const union rds_ib_conn_priv *dp = NULL; + __be64 ack_seq = 0; + __be32 credit = 0; + u8 major = 0; + u8 minor = 0; + int err; + + dp = event->param.conn.private_data; + if (conn->c_isv6) { + if (event->param.conn.private_data_len >= + sizeof(struct rds6_ib_connect_private)) { + major = dp->ricp_v6.dp_protocol_major; + minor = dp->ricp_v6.dp_protocol_minor; + credit = dp->ricp_v6.dp_credit; + /* dp structure start is not guaranteed to be 8 bytes + * aligned. Since dp_ack_seq is 64-bit extended load + * operations can be used so go through get_unaligned + * to avoid unaligned errors. 
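+ * A direct load of dp->ricp_v6.dp_ack_seq could fault on
+ * architectures that trap unaligned 64-bit accesses; get_unaligned()
+ * falls back to a safe access pattern there.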
+ */ + ack_seq = get_unaligned(&dp->ricp_v6.dp_ack_seq); + } + } else if (event->param.conn.private_data_len >= + sizeof(struct rds_ib_connect_private)) { + major = dp->ricp_v4.dp_protocol_major; + minor = dp->ricp_v4.dp_protocol_minor; + credit = dp->ricp_v4.dp_credit; + ack_seq = get_unaligned(&dp->ricp_v4.dp_ack_seq); + } + + /* make sure it isn't empty data */ + if (major) { + rds_ib_set_protocol(conn, RDS_PROTOCOL(major, minor)); + rds_ib_set_flow_control(conn, be32_to_cpu(credit)); + } + + if (conn->c_version < RDS_PROTOCOL_VERSION) { + if (conn->c_version != RDS_PROTOCOL_COMPAT_VERSION) { + pr_notice("RDS/IB: Connection <%pI6c,%pI6c> version %u.%u no longer supported\n", + &conn->c_laddr, &conn->c_faddr, + RDS_PROTOCOL_MAJOR(conn->c_version), + RDS_PROTOCOL_MINOR(conn->c_version)); + rds_conn_destroy(conn); + return; + } + } + + pr_notice("RDS/IB: %s conn connected <%pI6c,%pI6c,%d> version %u.%u%s\n", + ic->i_active_side ? "Active" : "Passive", + &conn->c_laddr, &conn->c_faddr, conn->c_tos, + RDS_PROTOCOL_MAJOR(conn->c_version), + RDS_PROTOCOL_MINOR(conn->c_version), + ic->i_flowctl ? ", flow control" : ""); + + /* receive sl from the peer */ + ic->i_sl = ic->i_cm_id->route.path_rec->sl; + + atomic_set(&ic->i_cq_quiesce, 0); + + /* Init rings and fill recv. this needs to wait until protocol + * negotiation is complete, since ring layout is different + * from 3.1 to 4.1. + */ + rds_ib_send_init_ring(ic); + rds_ib_recv_init_ring(ic); + /* Post receive buffers - as a side effect, this will update + * the posted credit count. */ + rds_ib_recv_refill(conn, 1, GFP_KERNEL); + + /* update ib_device with this local ipaddr */ + err = rds_ib_update_ipaddr(ic->rds_ibdev, &conn->c_laddr); + if (err) + printk(KERN_ERR "rds_ib_update_ipaddr failed (%d)\n", + err); + + /* If the peer gave us the last packet it saw, process this as if + * we had received a regular ACK. 
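+ * rds_send_drop_acked() can then release the queued messages the peer
+ * has already seen, i.e. those with sequence numbers up to and
+ * including ack_seq.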
*/ + if (dp) { + if (ack_seq) + rds_send_drop_acked(conn, be64_to_cpu(ack_seq), + NULL); + } + + conn->c_proposed_version = conn->c_version; + rds_connect_complete(conn); +} + +static void rds_ib_cm_fill_conn_param(struct rds_connection *conn, + struct rdma_conn_param *conn_param, + union rds_ib_conn_priv *dp, + u32 protocol_version, + u32 max_responder_resources, + u32 max_initiator_depth, + bool isv6) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct rds_ib_device *rds_ibdev = ic->rds_ibdev; + + memset(conn_param, 0, sizeof(struct rdma_conn_param)); + + conn_param->responder_resources = + min_t(u32, rds_ibdev->max_responder_resources, max_responder_resources); + conn_param->initiator_depth = + min_t(u32, rds_ibdev->max_initiator_depth, max_initiator_depth); + conn_param->retry_count = min_t(unsigned int, rds_ib_retry_count, 7); + conn_param->rnr_retry_count = 7; + + if (dp) { + memset(dp, 0, sizeof(*dp)); + if (isv6) { + dp->ricp_v6.dp_saddr = conn->c_laddr; + dp->ricp_v6.dp_daddr = conn->c_faddr; + dp->ricp_v6.dp_protocol_major = + RDS_PROTOCOL_MAJOR(protocol_version); + dp->ricp_v6.dp_protocol_minor = + RDS_PROTOCOL_MINOR(protocol_version); + dp->ricp_v6.dp_protocol_minor_mask = + cpu_to_be16(RDS_IB_SUPPORTED_PROTOCOLS); + dp->ricp_v6.dp_ack_seq = + cpu_to_be64(rds_ib_piggyb_ack(ic)); + dp->ricp_v6.dp_cmn.ricpc_dp_toss = conn->c_tos; + + conn_param->private_data = &dp->ricp_v6; + conn_param->private_data_len = sizeof(dp->ricp_v6); + } else { + dp->ricp_v4.dp_saddr = conn->c_laddr.s6_addr32[3]; + dp->ricp_v4.dp_daddr = conn->c_faddr.s6_addr32[3]; + dp->ricp_v4.dp_protocol_major = + RDS_PROTOCOL_MAJOR(protocol_version); + dp->ricp_v4.dp_protocol_minor = + RDS_PROTOCOL_MINOR(protocol_version); + dp->ricp_v4.dp_protocol_minor_mask = + cpu_to_be16(RDS_IB_SUPPORTED_PROTOCOLS); + dp->ricp_v4.dp_ack_seq = + cpu_to_be64(rds_ib_piggyb_ack(ic)); + dp->ricp_v4.dp_cmn.ricpc_dp_toss = conn->c_tos; + + conn_param->private_data = &dp->ricp_v4; + conn_param->private_data_len = sizeof(dp->ricp_v4); + } + + /* Advertise flow control */ + if (ic->i_flowctl) { + unsigned int credits; + + credits = IB_GET_POST_CREDITS + (atomic_read(&ic->i_credits)); + if (isv6) + dp->ricp_v6.dp_credit = cpu_to_be32(credits); + else + dp->ricp_v4.dp_credit = cpu_to_be32(credits); + atomic_sub(IB_SET_POST_CREDITS(credits), + &ic->i_credits); + } + } +} + +static void rds_ib_cq_event_handler(struct ib_event *event, void *data) +{ + rdsdebug("event %u (%s) data %p\n", + event->event, ib_event_msg(event->event), data); +} + +/* Plucking the oldest entry from the ring can be done concurrently with + * the thread refilling the ring. Each ring operation is protected by + * spinlocks and the transient state of refilling doesn't change the + * recording of which entry is oldest. + * + * This relies on IB only calling one cq comp_handler for each cq so that + * there will only be one caller of rds_recv_incoming() per RDS connection. 
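+ * The handler below only schedules i_recv_tasklet; the CQ polling and
+ * the eventual rds_recv_incoming() calls happen from
+ * rds_ib_tasklet_fn_recv(), and a given tasklet never runs on more
+ * than one CPU at a time.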
+ */ +static void rds_ib_cq_comp_handler_recv(struct ib_cq *cq, void *context) +{ + struct rds_connection *conn = context; + struct rds_ib_connection *ic = conn->c_transport_data; + + rdsdebug("conn %p cq %p\n", conn, cq); + + rds_ib_stats_inc(s_ib_evt_handler_call); + + tasklet_schedule(&ic->i_recv_tasklet); +} + +static void poll_scq(struct rds_ib_connection *ic, struct ib_cq *cq, + struct ib_wc *wcs) +{ + int nr, i; + struct ib_wc *wc; + + while ((nr = ib_poll_cq(cq, RDS_IB_WC_MAX, wcs)) > 0) { + for (i = 0; i < nr; i++) { + wc = wcs + i; + rdsdebug("wc wr_id 0x%llx status %u byte_len %u imm_data %u\n", + (unsigned long long)wc->wr_id, wc->status, + wc->byte_len, be32_to_cpu(wc->ex.imm_data)); + + if (wc->wr_id <= ic->i_send_ring.w_nr || + wc->wr_id == RDS_IB_ACK_WR_ID) + rds_ib_send_cqe_handler(ic, wc); + else + rds_ib_mr_cqe_handler(ic, wc); + + } + } +} + +static void rds_ib_tasklet_fn_send(unsigned long data) +{ + struct rds_ib_connection *ic = (struct rds_ib_connection *)data; + struct rds_connection *conn = ic->conn; + + rds_ib_stats_inc(s_ib_tasklet_call); + + /* if cq has been already reaped, ignore incoming cq event */ + if (atomic_read(&ic->i_cq_quiesce)) + return; + + poll_scq(ic, ic->i_send_cq, ic->i_send_wc); + ib_req_notify_cq(ic->i_send_cq, IB_CQ_NEXT_COMP); + poll_scq(ic, ic->i_send_cq, ic->i_send_wc); + + if (rds_conn_up(conn) && + (!test_bit(RDS_LL_SEND_FULL, &conn->c_flags) || + test_bit(0, &conn->c_map_queued))) + rds_send_xmit(&ic->conn->c_path[0]); +} + +static void poll_rcq(struct rds_ib_connection *ic, struct ib_cq *cq, + struct ib_wc *wcs, + struct rds_ib_ack_state *ack_state) +{ + int nr, i; + struct ib_wc *wc; + + while ((nr = ib_poll_cq(cq, RDS_IB_WC_MAX, wcs)) > 0) { + for (i = 0; i < nr; i++) { + wc = wcs + i; + rdsdebug("wc wr_id 0x%llx status %u byte_len %u imm_data %u\n", + (unsigned long long)wc->wr_id, wc->status, + wc->byte_len, be32_to_cpu(wc->ex.imm_data)); + + rds_ib_recv_cqe_handler(ic, wc, ack_state); + } + } +} + +static void rds_ib_tasklet_fn_recv(unsigned long data) +{ + struct rds_ib_connection *ic = (struct rds_ib_connection *)data; + struct rds_connection *conn = ic->conn; + struct rds_ib_device *rds_ibdev = ic->rds_ibdev; + struct rds_ib_ack_state state; + + if (!rds_ibdev) + rds_conn_drop(conn); + + rds_ib_stats_inc(s_ib_tasklet_call); + + /* if cq has been already reaped, ignore incoming cq event */ + if (atomic_read(&ic->i_cq_quiesce)) + return; + + memset(&state, 0, sizeof(state)); + poll_rcq(ic, ic->i_recv_cq, ic->i_recv_wc, &state); + ib_req_notify_cq(ic->i_recv_cq, IB_CQ_SOLICITED); + poll_rcq(ic, ic->i_recv_cq, ic->i_recv_wc, &state); + + if (state.ack_next_valid) + rds_ib_set_ack(ic, state.ack_next, state.ack_required); + if (state.ack_recv_valid && state.ack_recv > ic->i_ack_recv) { + rds_send_drop_acked(conn, state.ack_recv, NULL); + ic->i_ack_recv = state.ack_recv; + } + + if (rds_conn_up(conn)) + rds_ib_attempt_ack(ic); +} + +static void rds_ib_qp_event_handler(struct ib_event *event, void *data) +{ + struct rds_connection *conn = data; + struct rds_ib_connection *ic = conn->c_transport_data; + + rdsdebug("conn %p ic %p event %u (%s)\n", conn, ic, event->event, + ib_event_msg(event->event)); + + switch (event->event) { + case IB_EVENT_COMM_EST: + rdma_notify(ic->i_cm_id, IB_EVENT_COMM_EST); + break; + default: + rdsdebug("Fatal QP Event %u (%s) - connection %pI6c->%pI6c, reconnecting\n", + event->event, ib_event_msg(event->event), + &conn->c_laddr, &conn->c_faddr); + rds_conn_drop(conn); + break; + } +} + +static void 
rds_ib_cq_comp_handler_send(struct ib_cq *cq, void *context) +{ + struct rds_connection *conn = context; + struct rds_ib_connection *ic = conn->c_transport_data; + + rdsdebug("conn %p cq %p\n", conn, cq); + + rds_ib_stats_inc(s_ib_evt_handler_call); + + tasklet_schedule(&ic->i_send_tasklet); +} + +static inline int ibdev_get_unused_vector(struct rds_ib_device *rds_ibdev) +{ + int min = rds_ibdev->vector_load[rds_ibdev->dev->num_comp_vectors - 1]; + int index = rds_ibdev->dev->num_comp_vectors - 1; + int i; + + for (i = rds_ibdev->dev->num_comp_vectors - 1; i >= 0; i--) { + if (rds_ibdev->vector_load[i] < min) { + index = i; + min = rds_ibdev->vector_load[i]; + } + } + + rds_ibdev->vector_load[index]++; + return index; +} + +static inline void ibdev_put_vector(struct rds_ib_device *rds_ibdev, int index) +{ + rds_ibdev->vector_load[index]--; +} + +static void rds_dma_hdr_free(struct ib_device *dev, struct rds_header *hdr, + dma_addr_t dma_addr, enum dma_data_direction dir) +{ + ib_dma_unmap_single(dev, dma_addr, sizeof(*hdr), dir); + kfree(hdr); +} + +static struct rds_header *rds_dma_hdr_alloc(struct ib_device *dev, + dma_addr_t *dma_addr, enum dma_data_direction dir) +{ + struct rds_header *hdr; + + hdr = kzalloc_node(sizeof(*hdr), GFP_KERNEL, ibdev_to_node(dev)); + if (!hdr) + return NULL; + + *dma_addr = ib_dma_map_single(dev, hdr, sizeof(*hdr), + DMA_BIDIRECTIONAL); + if (ib_dma_mapping_error(dev, *dma_addr)) { + kfree(hdr); + return NULL; + } + + return hdr; +} + +/* Free the DMA memory used to store struct rds_header. + * + * @dev: the RDS IB device + * @hdrs: pointer to the array storing DMA memory pointers + * @dma_addrs: pointer to the array storing DMA addresses + * @num_hdars: number of headers to free. + */ +static void rds_dma_hdrs_free(struct rds_ib_device *dev, + struct rds_header **hdrs, dma_addr_t *dma_addrs, u32 num_hdrs, + enum dma_data_direction dir) +{ + u32 i; + + for (i = 0; i < num_hdrs; i++) + rds_dma_hdr_free(dev->dev, hdrs[i], dma_addrs[i], dir); + kvfree(hdrs); + kvfree(dma_addrs); +} + + +/* Allocate DMA coherent memory to be used to store struct rds_header for + * sending/receiving packets. The pointers to the DMA memory and the + * associated DMA addresses are stored in two arrays. + * + * @dev: the RDS IB device + * @dma_addrs: pointer to the array for storing DMA addresses + * @num_hdrs: number of headers to allocate + * + * It returns the pointer to the array storing the DMA memory pointers. On + * error, NULL pointer is returned. + */ +static struct rds_header **rds_dma_hdrs_alloc(struct rds_ib_device *dev, + dma_addr_t **dma_addrs, u32 num_hdrs, + enum dma_data_direction dir) +{ + struct rds_header **hdrs; + dma_addr_t *hdr_daddrs; + u32 i; + + hdrs = kvmalloc_node(sizeof(*hdrs) * num_hdrs, GFP_KERNEL, + ibdev_to_node(dev->dev)); + if (!hdrs) + return NULL; + + hdr_daddrs = kvmalloc_node(sizeof(*hdr_daddrs) * num_hdrs, GFP_KERNEL, + ibdev_to_node(dev->dev)); + if (!hdr_daddrs) { + kvfree(hdrs); + return NULL; + } + + for (i = 0; i < num_hdrs; i++) { + hdrs[i] = rds_dma_hdr_alloc(dev->dev, &hdr_daddrs[i], dir); + if (!hdrs[i]) { + rds_dma_hdrs_free(dev, hdrs, hdr_daddrs, i, dir); + return NULL; + } + } + + *dma_addrs = hdr_daddrs; + return hdrs; +} + +/* + * This needs to be very careful to not leave IS_ERR pointers around for + * cleanup to trip over. 
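+ * Each failure path below therefore resets the corresponding ic->i_*
+ * pointer to NULL (e.g. ic->i_send_cq = NULL after a failed
+ * ib_create_cq()) so that the shutdown path only destroys objects
+ * that were actually created.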
+ */ +static int rds_ib_setup_qp(struct rds_connection *conn) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct ib_device *dev = ic->i_cm_id->device; + struct ib_qp_init_attr attr; + struct ib_cq_init_attr cq_attr = {}; + struct rds_ib_device *rds_ibdev; + unsigned long max_wrs; + int ret, fr_queue_space; + + /* + * It's normal to see a null device if an incoming connection races + * with device removal, so we don't print a warning. + */ + rds_ibdev = rds_ib_get_client_data(dev); + if (!rds_ibdev) + return -EOPNOTSUPP; + + /* The fr_queue_space is currently set to 512, to add extra space on + * completion queue and send queue. This extra space is used for FRWR + * registration and invalidation work requests + */ + fr_queue_space = RDS_IB_DEFAULT_FR_WR; + + /* add the conn now so that connection establishment has the dev */ + rds_ib_add_conn(rds_ibdev, conn); + + max_wrs = rds_ibdev->max_wrs < rds_ib_sysctl_max_send_wr + 1 ? + rds_ibdev->max_wrs - 1 : rds_ib_sysctl_max_send_wr; + if (ic->i_send_ring.w_nr != max_wrs) + rds_ib_ring_resize(&ic->i_send_ring, max_wrs); + + max_wrs = rds_ibdev->max_wrs < rds_ib_sysctl_max_recv_wr + 1 ? + rds_ibdev->max_wrs - 1 : rds_ib_sysctl_max_recv_wr; + if (ic->i_recv_ring.w_nr != max_wrs) + rds_ib_ring_resize(&ic->i_recv_ring, max_wrs); + + /* Protection domain and memory range */ + ic->i_pd = rds_ibdev->pd; + + ic->i_scq_vector = ibdev_get_unused_vector(rds_ibdev); + cq_attr.cqe = ic->i_send_ring.w_nr + fr_queue_space + 1; + cq_attr.comp_vector = ic->i_scq_vector; + ic->i_send_cq = ib_create_cq(dev, rds_ib_cq_comp_handler_send, + rds_ib_cq_event_handler, conn, + &cq_attr); + if (IS_ERR(ic->i_send_cq)) { + ret = PTR_ERR(ic->i_send_cq); + ic->i_send_cq = NULL; + ibdev_put_vector(rds_ibdev, ic->i_scq_vector); + rdsdebug("ib_create_cq send failed: %d\n", ret); + goto rds_ibdev_out; + } + + ic->i_rcq_vector = ibdev_get_unused_vector(rds_ibdev); + cq_attr.cqe = ic->i_recv_ring.w_nr; + cq_attr.comp_vector = ic->i_rcq_vector; + ic->i_recv_cq = ib_create_cq(dev, rds_ib_cq_comp_handler_recv, + rds_ib_cq_event_handler, conn, + &cq_attr); + if (IS_ERR(ic->i_recv_cq)) { + ret = PTR_ERR(ic->i_recv_cq); + ic->i_recv_cq = NULL; + ibdev_put_vector(rds_ibdev, ic->i_rcq_vector); + rdsdebug("ib_create_cq recv failed: %d\n", ret); + goto send_cq_out; + } + + ret = ib_req_notify_cq(ic->i_send_cq, IB_CQ_NEXT_COMP); + if (ret) { + rdsdebug("ib_req_notify_cq send failed: %d\n", ret); + goto recv_cq_out; + } + + ret = ib_req_notify_cq(ic->i_recv_cq, IB_CQ_SOLICITED); + if (ret) { + rdsdebug("ib_req_notify_cq recv failed: %d\n", ret); + goto recv_cq_out; + } + + /* XXX negotiate max send/recv with remote? */ + memset(&attr, 0, sizeof(attr)); + attr.event_handler = rds_ib_qp_event_handler; + attr.qp_context = conn; + /* + 1 to allow for the single ack message */ + attr.cap.max_send_wr = ic->i_send_ring.w_nr + fr_queue_space + 1; + attr.cap.max_recv_wr = ic->i_recv_ring.w_nr + 1; + attr.cap.max_send_sge = rds_ibdev->max_sge; + attr.cap.max_recv_sge = RDS_IB_RECV_SGE; + attr.sq_sig_type = IB_SIGNAL_REQ_WR; + attr.qp_type = IB_QPT_RC; + attr.send_cq = ic->i_send_cq; + attr.recv_cq = ic->i_recv_cq; + + /* + * XXX this can fail if max_*_wr is too large? Are we supposed + * to back off until we get a value that the hardware can support? 
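+ * Note that attr.cap.max_send_wr above adds fr_queue_space + 1 on top
+ * of the already clamped ring size, so the total can still exceed
+ * rds_ibdev->max_wrs.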
+ */ + ret = rdma_create_qp(ic->i_cm_id, ic->i_pd, &attr); + if (ret) { + rdsdebug("rdma_create_qp failed: %d\n", ret); + goto recv_cq_out; + } + + ic->i_send_hdrs = rds_dma_hdrs_alloc(rds_ibdev, &ic->i_send_hdrs_dma, + ic->i_send_ring.w_nr, + DMA_TO_DEVICE); + if (!ic->i_send_hdrs) { + ret = -ENOMEM; + rdsdebug("DMA send hdrs alloc failed\n"); + goto qp_out; + } + + ic->i_recv_hdrs = rds_dma_hdrs_alloc(rds_ibdev, &ic->i_recv_hdrs_dma, + ic->i_recv_ring.w_nr, + DMA_FROM_DEVICE); + if (!ic->i_recv_hdrs) { + ret = -ENOMEM; + rdsdebug("DMA recv hdrs alloc failed\n"); + goto send_hdrs_dma_out; + } + + ic->i_ack = rds_dma_hdr_alloc(rds_ibdev->dev, &ic->i_ack_dma, + DMA_TO_DEVICE); + if (!ic->i_ack) { + ret = -ENOMEM; + rdsdebug("DMA ack header alloc failed\n"); + goto recv_hdrs_dma_out; + } + + ic->i_sends = vzalloc_node(array_size(sizeof(struct rds_ib_send_work), + ic->i_send_ring.w_nr), + ibdev_to_node(dev)); + if (!ic->i_sends) { + ret = -ENOMEM; + rdsdebug("send allocation failed\n"); + goto ack_dma_out; + } + + ic->i_recvs = vzalloc_node(array_size(sizeof(struct rds_ib_recv_work), + ic->i_recv_ring.w_nr), + ibdev_to_node(dev)); + if (!ic->i_recvs) { + ret = -ENOMEM; + rdsdebug("recv allocation failed\n"); + goto sends_out; + } + + rds_ib_recv_init_ack(ic); + + rdsdebug("conn %p pd %p cq %p %p\n", conn, ic->i_pd, + ic->i_send_cq, ic->i_recv_cq); + + goto out; + +sends_out: + vfree(ic->i_sends); + +ack_dma_out: + rds_dma_hdr_free(rds_ibdev->dev, ic->i_ack, ic->i_ack_dma, + DMA_TO_DEVICE); + ic->i_ack = NULL; + +recv_hdrs_dma_out: + rds_dma_hdrs_free(rds_ibdev, ic->i_recv_hdrs, ic->i_recv_hdrs_dma, + ic->i_recv_ring.w_nr, DMA_FROM_DEVICE); + ic->i_recv_hdrs = NULL; + ic->i_recv_hdrs_dma = NULL; + +send_hdrs_dma_out: + rds_dma_hdrs_free(rds_ibdev, ic->i_send_hdrs, ic->i_send_hdrs_dma, + ic->i_send_ring.w_nr, DMA_TO_DEVICE); + ic->i_send_hdrs = NULL; + ic->i_send_hdrs_dma = NULL; + +qp_out: + rdma_destroy_qp(ic->i_cm_id); +recv_cq_out: + ib_destroy_cq(ic->i_recv_cq); + ic->i_recv_cq = NULL; +send_cq_out: + ib_destroy_cq(ic->i_send_cq); + ic->i_send_cq = NULL; +rds_ibdev_out: + rds_ib_remove_conn(rds_ibdev, conn); +out: + rds_ib_dev_put(rds_ibdev); + + return ret; +} + +static u32 rds_ib_protocol_compatible(struct rdma_cm_event *event, bool isv6) +{ + const union rds_ib_conn_priv *dp = event->param.conn.private_data; + u8 data_len, major, minor; + u32 version = 0; + __be16 mask; + u16 common; + + /* + * rdma_cm private data is odd - when there is any private data in the + * request, we will be given a pretty large buffer without telling us the + * original size. The only way to tell the difference is by looking at + * the contents, which are initialized to zero. + * If the protocol version fields aren't set, this is a connection attempt + * from an older version. This could be 3.0 or 2.0 - we can't tell. + * We really should have changed this for OFED 1.3 :-( + */ + + /* Be paranoid. RDS always has privdata */ + if (!event->param.conn.private_data_len) { + printk(KERN_NOTICE "RDS incoming connection has no private data, " + "rejecting\n"); + return 0; + } + + if (isv6) { + data_len = sizeof(struct rds6_ib_connect_private); + major = dp->ricp_v6.dp_protocol_major; + minor = dp->ricp_v6.dp_protocol_minor; + mask = dp->ricp_v6.dp_protocol_minor_mask; + } else { + data_len = sizeof(struct rds_ib_connect_private); + major = dp->ricp_v4.dp_protocol_major; + minor = dp->ricp_v4.dp_protocol_minor; + mask = dp->ricp_v4.dp_protocol_minor_mask; + } + + /* Even if len is crap *now* I still want to check it. 
-ASG */ + if (event->param.conn.private_data_len < data_len || major == 0) + return RDS_PROTOCOL_4_0; + + common = be16_to_cpu(mask) & RDS_IB_SUPPORTED_PROTOCOLS; + if (major == 4 && common) { + version = RDS_PROTOCOL_4_0; + while ((common >>= 1) != 0) + version++; + } else if (RDS_PROTOCOL_COMPAT_VERSION == + RDS_PROTOCOL(major, minor)) { + version = RDS_PROTOCOL_COMPAT_VERSION; + } else { + if (isv6) + printk_ratelimited(KERN_NOTICE "RDS: Connection from %pI6c using incompatible protocol version %u.%u\n", + &dp->ricp_v6.dp_saddr, major, minor); + else + printk_ratelimited(KERN_NOTICE "RDS: Connection from %pI4 using incompatible protocol version %u.%u\n", + &dp->ricp_v4.dp_saddr, major, minor); + } + return version; +} + +#if IS_ENABLED(CONFIG_IPV6) +/* Given an IPv6 address, find the net_device which hosts that address and + * return its index. This is used by the rds_ib_cm_handle_connect() code to + * find the interface index of where an incoming request comes from when + * the request is using a link local address. + * + * Note one problem in this search. It is possible that two interfaces have + * the same link local address. Unfortunately, this cannot be solved unless + * the underlying layer gives us the interface which an incoming RDMA connect + * request comes from. + */ +static u32 __rds_find_ifindex(struct net *net, const struct in6_addr *addr) +{ + struct net_device *dev; + int idx = 0; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + if (ipv6_chk_addr(net, addr, dev, 1)) { + idx = dev->ifindex; + break; + } + } + rcu_read_unlock(); + + return idx; +} +#endif + +int rds_ib_cm_handle_connect(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event, bool isv6) +{ + __be64 lguid = cm_id->route.path_rec->sgid.global.interface_id; + __be64 fguid = cm_id->route.path_rec->dgid.global.interface_id; + const struct rds_ib_conn_priv_cmn *dp_cmn; + struct rds_connection *conn = NULL; + struct rds_ib_connection *ic = NULL; + struct rdma_conn_param conn_param; + const union rds_ib_conn_priv *dp; + union rds_ib_conn_priv dp_rep; + struct in6_addr s_mapped_addr; + struct in6_addr d_mapped_addr; + const struct in6_addr *saddr6; + const struct in6_addr *daddr6; + int destroy = 1; + u32 ifindex = 0; + u32 version; + int err = 1; + + /* Check whether the remote protocol version matches ours. */ + version = rds_ib_protocol_compatible(event, isv6); + if (!version) { + err = RDS_RDMA_REJ_INCOMPAT; + goto out; + } + + dp = event->param.conn.private_data; + if (isv6) { +#if IS_ENABLED(CONFIG_IPV6) + dp_cmn = &dp->ricp_v6.dp_cmn; + saddr6 = &dp->ricp_v6.dp_saddr; + daddr6 = &dp->ricp_v6.dp_daddr; + /* If either address is link local, need to find the + * interface index in order to create a proper RDS + * connection. + */ + if (ipv6_addr_type(daddr6) & IPV6_ADDR_LINKLOCAL) { + /* Using init_net for now .. */ + ifindex = __rds_find_ifindex(&init_net, daddr6); + /* No index found... Need to bail out. */ + if (ifindex == 0) { + err = -EOPNOTSUPP; + goto out; + } + } else if (ipv6_addr_type(saddr6) & IPV6_ADDR_LINKLOCAL) { + /* Use our address to find the correct index. */ + ifindex = __rds_find_ifindex(&init_net, daddr6); + /* No index found... Need to bail out. 
*/ + if (ifindex == 0) { + err = -EOPNOTSUPP; + goto out; + } + } +#else + err = -EOPNOTSUPP; + goto out; +#endif + } else { + dp_cmn = &dp->ricp_v4.dp_cmn; + ipv6_addr_set_v4mapped(dp->ricp_v4.dp_saddr, &s_mapped_addr); + ipv6_addr_set_v4mapped(dp->ricp_v4.dp_daddr, &d_mapped_addr); + saddr6 = &s_mapped_addr; + daddr6 = &d_mapped_addr; + } + + rdsdebug("saddr %pI6c daddr %pI6c RDSv%u.%u lguid 0x%llx fguid 0x%llx, tos:%d\n", + saddr6, daddr6, RDS_PROTOCOL_MAJOR(version), + RDS_PROTOCOL_MINOR(version), + (unsigned long long)be64_to_cpu(lguid), + (unsigned long long)be64_to_cpu(fguid), dp_cmn->ricpc_dp_toss); + + /* RDS/IB is not currently netns aware, thus init_net */ + conn = rds_conn_create(&init_net, daddr6, saddr6, + &rds_ib_transport, dp_cmn->ricpc_dp_toss, + GFP_KERNEL, ifindex); + if (IS_ERR(conn)) { + rdsdebug("rds_conn_create failed (%ld)\n", PTR_ERR(conn)); + conn = NULL; + goto out; + } + + /* + * The connection request may occur while the + * previous connection exist, e.g. in case of failover. + * But as connections may be initiated simultaneously + * by both hosts, we have a random backoff mechanism - + * see the comment above rds_queue_reconnect() + */ + mutex_lock(&conn->c_cm_lock); + if (!rds_conn_transition(conn, RDS_CONN_DOWN, RDS_CONN_CONNECTING)) { + if (rds_conn_state(conn) == RDS_CONN_UP) { + rdsdebug("incoming connect while connecting\n"); + rds_conn_drop(conn); + rds_ib_stats_inc(s_ib_listen_closed_stale); + } else + if (rds_conn_state(conn) == RDS_CONN_CONNECTING) { + /* Wait and see - our connect may still be succeeding */ + rds_ib_stats_inc(s_ib_connect_raced); + } + goto out; + } + + ic = conn->c_transport_data; + + rds_ib_set_protocol(conn, version); + rds_ib_set_flow_control(conn, be32_to_cpu(dp_cmn->ricpc_credit)); + + /* If the peer gave us the last packet it saw, process this as if + * we had received a regular ACK. */ + if (dp_cmn->ricpc_ack_seq) + rds_send_drop_acked(conn, be64_to_cpu(dp_cmn->ricpc_ack_seq), + NULL); + + BUG_ON(cm_id->context); + BUG_ON(ic->i_cm_id); + + ic->i_cm_id = cm_id; + cm_id->context = conn; + + /* We got halfway through setting up the ib_connection, if we + * fail now, we have to take the long route out of this mess. 
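+ * From here on the cm_id belongs to the connection: we return
+ * destroy == 0 so the rdma_cm callback does not destroy it, and it is
+ * released later through rds_ib_conn_path_shutdown().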
*/ + destroy = 0; + + err = rds_ib_setup_qp(conn); + if (err) { + rds_ib_conn_error(conn, "rds_ib_setup_qp failed (%d)\n", err); + goto out; + } + + rds_ib_cm_fill_conn_param(conn, &conn_param, &dp_rep, version, + event->param.conn.responder_resources, + event->param.conn.initiator_depth, isv6); + + rdma_set_min_rnr_timer(cm_id, IB_RNR_TIMER_000_32); + /* rdma_accept() calls rdma_reject() internally if it fails */ + if (rdma_accept(cm_id, &conn_param)) + rds_ib_conn_error(conn, "rdma_accept failed\n"); + +out: + if (conn) + mutex_unlock(&conn->c_cm_lock); + if (err) + rdma_reject(cm_id, &err, sizeof(int), + IB_CM_REJ_CONSUMER_DEFINED); + return destroy; +} + + +int rds_ib_cm_initiate_connect(struct rdma_cm_id *cm_id, bool isv6) +{ + struct rds_connection *conn = cm_id->context; + struct rds_ib_connection *ic = conn->c_transport_data; + struct rdma_conn_param conn_param; + union rds_ib_conn_priv dp; + int ret; + + /* If the peer doesn't do protocol negotiation, we must + * default to RDSv3.0 */ + rds_ib_set_protocol(conn, RDS_PROTOCOL_4_1); + ic->i_flowctl = rds_ib_sysctl_flow_control; /* advertise flow control */ + + ret = rds_ib_setup_qp(conn); + if (ret) { + rds_ib_conn_error(conn, "rds_ib_setup_qp failed (%d)\n", ret); + goto out; + } + + rds_ib_cm_fill_conn_param(conn, &conn_param, &dp, + conn->c_proposed_version, + UINT_MAX, UINT_MAX, isv6); + ret = rdma_connect_locked(cm_id, &conn_param); + if (ret) + rds_ib_conn_error(conn, "rdma_connect_locked failed (%d)\n", + ret); + +out: + /* Beware - returning non-zero tells the rdma_cm to destroy + * the cm_id. We should certainly not do it as long as we still + * "own" the cm_id. */ + if (ret) { + if (ic->i_cm_id == cm_id) + ret = 0; + } + ic->i_active_side = true; + return ret; +} + +int rds_ib_conn_path_connect(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + struct sockaddr_storage src, dest; + rdma_cm_event_handler handler; + struct rds_ib_connection *ic; + int ret; + + ic = conn->c_transport_data; + + /* XXX I wonder what affect the port space has */ + /* delegate cm event handler to rdma_transport */ +#if IS_ENABLED(CONFIG_IPV6) + if (conn->c_isv6) + handler = rds6_rdma_cm_event_handler; + else +#endif + handler = rds_rdma_cm_event_handler; + ic->i_cm_id = rdma_create_id(&init_net, handler, conn, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(ic->i_cm_id)) { + ret = PTR_ERR(ic->i_cm_id); + ic->i_cm_id = NULL; + rdsdebug("rdma_create_id() failed: %d\n", ret); + goto out; + } + + rdsdebug("created cm id %p for conn %p\n", ic->i_cm_id, conn); + + if (ipv6_addr_v4mapped(&conn->c_faddr)) { + struct sockaddr_in *sin; + + sin = (struct sockaddr_in *)&src; + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = conn->c_laddr.s6_addr32[3]; + sin->sin_port = 0; + + sin = (struct sockaddr_in *)&dest; + sin->sin_family = AF_INET; + sin->sin_addr.s_addr = conn->c_faddr.s6_addr32[3]; + sin->sin_port = htons(RDS_PORT); + } else { + struct sockaddr_in6 *sin6; + + sin6 = (struct sockaddr_in6 *)&src; + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = conn->c_laddr; + sin6->sin6_port = 0; + sin6->sin6_scope_id = conn->c_dev_if; + + sin6 = (struct sockaddr_in6 *)&dest; + sin6->sin6_family = AF_INET6; + sin6->sin6_addr = conn->c_faddr; + sin6->sin6_port = htons(RDS_CM_PORT); + sin6->sin6_scope_id = conn->c_dev_if; + } + + ret = rdma_resolve_addr(ic->i_cm_id, (struct sockaddr *)&src, + (struct sockaddr *)&dest, + RDS_RDMA_RESOLVE_TIMEOUT_MS); + if (ret) { + rdsdebug("addr resolve failed for cm id %p: %d\n", ic->i_cm_id, + ret); + 
rdma_destroy_id(ic->i_cm_id); + ic->i_cm_id = NULL; + } + +out: + return ret; +} + +/* + * This is so careful about only cleaning up resources that were built up + * so that it can be called at any point during startup. In fact it + * can be called multiple times for a given connection. + */ +void rds_ib_conn_path_shutdown(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + struct rds_ib_connection *ic = conn->c_transport_data; + int err = 0; + + rdsdebug("cm %p pd %p cq %p %p qp %p\n", ic->i_cm_id, + ic->i_pd, ic->i_send_cq, ic->i_recv_cq, + ic->i_cm_id ? ic->i_cm_id->qp : NULL); + + if (ic->i_cm_id) { + rdsdebug("disconnecting cm %p\n", ic->i_cm_id); + err = rdma_disconnect(ic->i_cm_id); + if (err) { + /* Actually this may happen quite frequently, when + * an outgoing connect raced with an incoming connect. + */ + rdsdebug("failed to disconnect, cm: %p err %d\n", + ic->i_cm_id, err); + } + + /* kick off "flush_worker" for all pools in order to reap + * all FRMR registrations that are still marked "FRMR_IS_INUSE" + */ + rds_ib_flush_mrs(); + + /* + * We want to wait for tx and rx completion to finish + * before we tear down the connection, but we have to be + * careful not to get stuck waiting on a send ring that + * only has unsignaled sends in it. We've shutdown new + * sends before getting here so by waiting for signaled + * sends to complete we're ensured that there will be no + * more tx processing. + */ + wait_event(rds_ib_ring_empty_wait, + rds_ib_ring_empty(&ic->i_recv_ring) && + (atomic_read(&ic->i_signaled_sends) == 0) && + (atomic_read(&ic->i_fastreg_inuse_count) == 0) && + (atomic_read(&ic->i_fastreg_wrs) == RDS_IB_DEFAULT_FR_WR)); + tasklet_kill(&ic->i_send_tasklet); + tasklet_kill(&ic->i_recv_tasklet); + + atomic_set(&ic->i_cq_quiesce, 1); + + /* first destroy the ib state that generates callbacks */ + if (ic->i_cm_id->qp) + rdma_destroy_qp(ic->i_cm_id); + if (ic->i_send_cq) { + if (ic->rds_ibdev) + ibdev_put_vector(ic->rds_ibdev, ic->i_scq_vector); + ib_destroy_cq(ic->i_send_cq); + } + + if (ic->i_recv_cq) { + if (ic->rds_ibdev) + ibdev_put_vector(ic->rds_ibdev, ic->i_rcq_vector); + ib_destroy_cq(ic->i_recv_cq); + } + + if (ic->rds_ibdev) { + /* then free the resources that ib callbacks use */ + if (ic->i_send_hdrs) { + rds_dma_hdrs_free(ic->rds_ibdev, + ic->i_send_hdrs, + ic->i_send_hdrs_dma, + ic->i_send_ring.w_nr, + DMA_TO_DEVICE); + ic->i_send_hdrs = NULL; + ic->i_send_hdrs_dma = NULL; + } + + if (ic->i_recv_hdrs) { + rds_dma_hdrs_free(ic->rds_ibdev, + ic->i_recv_hdrs, + ic->i_recv_hdrs_dma, + ic->i_recv_ring.w_nr, + DMA_FROM_DEVICE); + ic->i_recv_hdrs = NULL; + ic->i_recv_hdrs_dma = NULL; + } + + if (ic->i_ack) { + rds_dma_hdr_free(ic->rds_ibdev->dev, ic->i_ack, + ic->i_ack_dma, DMA_TO_DEVICE); + ic->i_ack = NULL; + } + } else { + WARN_ON(ic->i_send_hdrs); + WARN_ON(ic->i_send_hdrs_dma); + WARN_ON(ic->i_recv_hdrs); + WARN_ON(ic->i_recv_hdrs_dma); + WARN_ON(ic->i_ack); + } + + if (ic->i_sends) + rds_ib_send_clear_ring(ic); + if (ic->i_recvs) + rds_ib_recv_clear_ring(ic); + + rdma_destroy_id(ic->i_cm_id); + + /* + * Move connection back to the nodev list. 
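+ * rds_ib_remove_conn() takes ic->ib_node off the device's conn_list,
+ * puts it back on ib_nodev_conns and clears ic->rds_ibdev, which the
+ * BUG_ON(ic->rds_ibdev) below relies on.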
+ */ + if (ic->rds_ibdev) + rds_ib_remove_conn(ic->rds_ibdev, conn); + + ic->i_cm_id = NULL; + ic->i_pd = NULL; + ic->i_send_cq = NULL; + ic->i_recv_cq = NULL; + } + BUG_ON(ic->rds_ibdev); + + /* Clear pending transmit */ + if (ic->i_data_op) { + struct rds_message *rm; + + rm = container_of(ic->i_data_op, struct rds_message, data); + rds_message_put(rm); + ic->i_data_op = NULL; + } + + /* Clear the ACK state */ + clear_bit(IB_ACK_IN_FLIGHT, &ic->i_ack_flags); +#ifdef KERNEL_HAS_ATOMIC64 + atomic64_set(&ic->i_ack_next, 0); +#else + ic->i_ack_next = 0; +#endif + ic->i_ack_recv = 0; + + /* Clear flow control state */ + ic->i_flowctl = 0; + atomic_set(&ic->i_credits, 0); + + /* Re-init rings, but retain sizes. */ + rds_ib_ring_init(&ic->i_send_ring, ic->i_send_ring.w_nr); + rds_ib_ring_init(&ic->i_recv_ring, ic->i_recv_ring.w_nr); + + if (ic->i_ibinc) { + rds_inc_put(&ic->i_ibinc->ii_inc); + ic->i_ibinc = NULL; + } + + vfree(ic->i_sends); + ic->i_sends = NULL; + vfree(ic->i_recvs); + ic->i_recvs = NULL; + ic->i_active_side = false; +} + +int rds_ib_conn_alloc(struct rds_connection *conn, gfp_t gfp) +{ + struct rds_ib_connection *ic; + unsigned long flags; + int ret; + + /* XXX too lazy? */ + ic = kzalloc(sizeof(struct rds_ib_connection), gfp); + if (!ic) + return -ENOMEM; + + ret = rds_ib_recv_alloc_caches(ic, gfp); + if (ret) { + kfree(ic); + return ret; + } + + INIT_LIST_HEAD(&ic->ib_node); + tasklet_init(&ic->i_send_tasklet, rds_ib_tasklet_fn_send, + (unsigned long)ic); + tasklet_init(&ic->i_recv_tasklet, rds_ib_tasklet_fn_recv, + (unsigned long)ic); + mutex_init(&ic->i_recv_mutex); +#ifndef KERNEL_HAS_ATOMIC64 + spin_lock_init(&ic->i_ack_lock); +#endif + atomic_set(&ic->i_signaled_sends, 0); + atomic_set(&ic->i_fastreg_wrs, RDS_IB_DEFAULT_FR_WR); + + /* + * rds_ib_conn_shutdown() waits for these to be emptied so they + * must be initialized before it can be called. + */ + rds_ib_ring_init(&ic->i_send_ring, 0); + rds_ib_ring_init(&ic->i_recv_ring, 0); + + ic->conn = conn; + conn->c_transport_data = ic; + + spin_lock_irqsave(&ib_nodev_conns_lock, flags); + list_add_tail(&ic->ib_node, &ib_nodev_conns); + spin_unlock_irqrestore(&ib_nodev_conns_lock, flags); + + + rdsdebug("conn %p conn ic %p\n", conn, conn->c_transport_data); + return 0; +} + +/* + * Free a connection. Connection must be shut down and not set for reconnect. + */ +void rds_ib_conn_free(void *arg) +{ + struct rds_ib_connection *ic = arg; + spinlock_t *lock_ptr; + + rdsdebug("ic %p\n", ic); + + /* + * Conn is either on a dev's list or on the nodev list. + * A race with shutdown() or connect() would cause problems + * (since rds_ibdev would change) but that should never happen. + */ + lock_ptr = ic->rds_ibdev ? &ic->rds_ibdev->spinlock : &ib_nodev_conns_lock; + + spin_lock_irq(lock_ptr); + list_del(&ic->ib_node); + spin_unlock_irq(lock_ptr); + + rds_ib_recv_free_caches(ic); + + kfree(ic); +} + + +/* + * An error occurred on the connection + */ +void +__rds_ib_conn_error(struct rds_connection *conn, const char *fmt, ...) +{ + va_list ap; + + rds_conn_drop(conn); + + va_start(ap, fmt); + vprintk(fmt, ap); + va_end(ap); +} diff --git a/net/rds/ib_frmr.c b/net/rds/ib_frmr.c new file mode 100644 index 000000000..28c1b0022 --- /dev/null +++ b/net/rds/ib_frmr.c @@ -0,0 +1,446 @@ +/* + * Copyright (c) 2016 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "ib_mr.h" + +static inline void +rds_transition_frwr_state(struct rds_ib_mr *ibmr, + enum rds_ib_fr_state old_state, + enum rds_ib_fr_state new_state) +{ + if (cmpxchg(&ibmr->u.frmr.fr_state, + old_state, new_state) == old_state && + old_state == FRMR_IS_INUSE) { + /* enforce order of ibmr->u.frmr.fr_state update + * before decrementing i_fastreg_inuse_count + */ + smp_mb__before_atomic(); + atomic_dec(&ibmr->ic->i_fastreg_inuse_count); + if (waitqueue_active(&rds_ib_ring_empty_wait)) + wake_up(&rds_ib_ring_empty_wait); + } +} + +static struct rds_ib_mr *rds_ib_alloc_frmr(struct rds_ib_device *rds_ibdev, + int npages) +{ + struct rds_ib_mr_pool *pool; + struct rds_ib_mr *ibmr = NULL; + struct rds_ib_frmr *frmr; + int err = 0; + + if (npages <= RDS_MR_8K_MSG_SIZE) + pool = rds_ibdev->mr_8k_pool; + else + pool = rds_ibdev->mr_1m_pool; + + ibmr = rds_ib_try_reuse_ibmr(pool); + if (ibmr) + return ibmr; + + ibmr = kzalloc_node(sizeof(*ibmr), GFP_KERNEL, + rdsibdev_to_node(rds_ibdev)); + if (!ibmr) { + err = -ENOMEM; + goto out_no_cigar; + } + + frmr = &ibmr->u.frmr; + frmr->mr = ib_alloc_mr(rds_ibdev->pd, IB_MR_TYPE_MEM_REG, + pool->max_pages); + if (IS_ERR(frmr->mr)) { + pr_warn("RDS/IB: %s failed to allocate MR", __func__); + err = PTR_ERR(frmr->mr); + goto out_no_cigar; + } + + ibmr->pool = pool; + if (pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_alloc); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_alloc); + + if (atomic_read(&pool->item_count) > pool->max_items_soft) + pool->max_items_soft = pool->max_items; + + frmr->fr_state = FRMR_IS_FREE; + init_waitqueue_head(&frmr->fr_inv_done); + init_waitqueue_head(&frmr->fr_reg_done); + return ibmr; + +out_no_cigar: + kfree(ibmr); + atomic_dec(&pool->item_count); + return ERR_PTR(err); +} + +static void rds_ib_free_frmr(struct rds_ib_mr *ibmr, bool drop) +{ + struct rds_ib_mr_pool *pool = ibmr->pool; + + if (drop) + llist_add(&ibmr->llnode, &pool->drop_list); + else + llist_add(&ibmr->llnode, &pool->free_list); + atomic_add(ibmr->sg_len, &pool->free_pinned); + atomic_inc(&pool->dirty_count); + + /* If we've pinned too many pages, request a flush */ + if (atomic_read(&pool->free_pinned) >= pool->max_free_pinned || + atomic_read(&pool->dirty_count) >= pool->max_items / 5) + 
queue_delayed_work(rds_ib_mr_wq, &pool->flush_worker, 10); +} + +static int rds_ib_post_reg_frmr(struct rds_ib_mr *ibmr) +{ + struct rds_ib_frmr *frmr = &ibmr->u.frmr; + struct ib_reg_wr reg_wr; + int ret, off = 0; + + while (atomic_dec_return(&ibmr->ic->i_fastreg_wrs) <= 0) { + atomic_inc(&ibmr->ic->i_fastreg_wrs); + cpu_relax(); + } + + ret = ib_map_mr_sg_zbva(frmr->mr, ibmr->sg, ibmr->sg_dma_len, + &off, PAGE_SIZE); + if (unlikely(ret != ibmr->sg_dma_len)) + return ret < 0 ? ret : -EINVAL; + + if (cmpxchg(&frmr->fr_state, + FRMR_IS_FREE, FRMR_IS_INUSE) != FRMR_IS_FREE) + return -EBUSY; + + atomic_inc(&ibmr->ic->i_fastreg_inuse_count); + + /* Perform a WR for the fast_reg_mr. Each individual page + * in the sg list is added to the fast reg page list and placed + * inside the fast_reg_mr WR. The key used is a rolling 8bit + * counter, which should guarantee uniqueness. + */ + ib_update_fast_reg_key(frmr->mr, ibmr->remap_count++); + frmr->fr_reg = true; + + memset(®_wr, 0, sizeof(reg_wr)); + reg_wr.wr.wr_id = (unsigned long)(void *)ibmr; + reg_wr.wr.opcode = IB_WR_REG_MR; + reg_wr.wr.num_sge = 0; + reg_wr.mr = frmr->mr; + reg_wr.key = frmr->mr->rkey; + reg_wr.access = IB_ACCESS_LOCAL_WRITE | + IB_ACCESS_REMOTE_READ | + IB_ACCESS_REMOTE_WRITE; + reg_wr.wr.send_flags = IB_SEND_SIGNALED; + + ret = ib_post_send(ibmr->ic->i_cm_id->qp, ®_wr.wr, NULL); + if (unlikely(ret)) { + /* Failure here can be because of -ENOMEM as well */ + rds_transition_frwr_state(ibmr, FRMR_IS_INUSE, FRMR_IS_STALE); + + atomic_inc(&ibmr->ic->i_fastreg_wrs); + if (printk_ratelimit()) + pr_warn("RDS/IB: %s returned error(%d)\n", + __func__, ret); + goto out; + } + + /* Wait for the registration to complete in order to prevent an invalid + * access error resulting from a race between the memory region already + * being accessed while registration is still pending. 
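+ * fr_reg is cleared and fr_reg_done is woken from
+ * rds_ib_mr_cqe_handler() once the IB_WR_REG_MR completion arrives,
+ * so this wait pairs with that completion handler.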
+ */ + wait_event(frmr->fr_reg_done, !frmr->fr_reg); + +out: + + return ret; +} + +static int rds_ib_map_frmr(struct rds_ib_device *rds_ibdev, + struct rds_ib_mr_pool *pool, + struct rds_ib_mr *ibmr, + struct scatterlist *sg, unsigned int sg_len) +{ + struct ib_device *dev = rds_ibdev->dev; + struct rds_ib_frmr *frmr = &ibmr->u.frmr; + int i; + u32 len; + int ret = 0; + + /* We want to teardown old ibmr values here and fill it up with + * new sg values + */ + rds_ib_teardown_mr(ibmr); + + ibmr->sg = sg; + ibmr->sg_len = sg_len; + ibmr->sg_dma_len = 0; + frmr->sg_byte_len = 0; + WARN_ON(ibmr->sg_dma_len); + ibmr->sg_dma_len = ib_dma_map_sg(dev, ibmr->sg, ibmr->sg_len, + DMA_BIDIRECTIONAL); + if (unlikely(!ibmr->sg_dma_len)) { + pr_warn("RDS/IB: %s failed!\n", __func__); + return -EBUSY; + } + + frmr->sg_byte_len = 0; + frmr->dma_npages = 0; + len = 0; + + ret = -EINVAL; + for (i = 0; i < ibmr->sg_dma_len; ++i) { + unsigned int dma_len = sg_dma_len(&ibmr->sg[i]); + u64 dma_addr = sg_dma_address(&ibmr->sg[i]); + + frmr->sg_byte_len += dma_len; + if (dma_addr & ~PAGE_MASK) { + if (i > 0) + goto out_unmap; + else + ++frmr->dma_npages; + } + + if ((dma_addr + dma_len) & ~PAGE_MASK) { + if (i < ibmr->sg_dma_len - 1) + goto out_unmap; + else + ++frmr->dma_npages; + } + + len += dma_len; + } + frmr->dma_npages += len >> PAGE_SHIFT; + + if (frmr->dma_npages > ibmr->pool->max_pages) { + ret = -EMSGSIZE; + goto out_unmap; + } + + ret = rds_ib_post_reg_frmr(ibmr); + if (ret) + goto out_unmap; + + if (ibmr->pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_used); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_used); + + return ret; + +out_unmap: + ib_dma_unmap_sg(rds_ibdev->dev, ibmr->sg, ibmr->sg_len, + DMA_BIDIRECTIONAL); + ibmr->sg_dma_len = 0; + return ret; +} + +static int rds_ib_post_inv(struct rds_ib_mr *ibmr) +{ + struct ib_send_wr *s_wr; + struct rds_ib_frmr *frmr = &ibmr->u.frmr; + struct rdma_cm_id *i_cm_id = ibmr->ic->i_cm_id; + int ret = -EINVAL; + + if (!i_cm_id || !i_cm_id->qp || !frmr->mr) + goto out; + + if (frmr->fr_state != FRMR_IS_INUSE) + goto out; + + while (atomic_dec_return(&ibmr->ic->i_fastreg_wrs) <= 0) { + atomic_inc(&ibmr->ic->i_fastreg_wrs); + cpu_relax(); + } + + frmr->fr_inv = true; + s_wr = &frmr->fr_wr; + + memset(s_wr, 0, sizeof(*s_wr)); + s_wr->wr_id = (unsigned long)(void *)ibmr; + s_wr->opcode = IB_WR_LOCAL_INV; + s_wr->ex.invalidate_rkey = frmr->mr->rkey; + s_wr->send_flags = IB_SEND_SIGNALED; + + ret = ib_post_send(i_cm_id->qp, s_wr, NULL); + if (unlikely(ret)) { + rds_transition_frwr_state(ibmr, FRMR_IS_INUSE, FRMR_IS_STALE); + frmr->fr_inv = false; + /* enforce order of frmr->fr_inv update + * before incrementing i_fastreg_wrs + */ + smp_mb__before_atomic(); + atomic_inc(&ibmr->ic->i_fastreg_wrs); + pr_err("RDS/IB: %s returned error(%d)\n", __func__, ret); + goto out; + } + + /* Wait for the FRMR_IS_FREE (or FRMR_IS_STALE) transition in order to + * 1) avoid a silly bouncing between "clean_list" and "drop_list" + * triggered by function "rds_ib_reg_frmr" as it is releases frmr + * regions whose state is not "FRMR_IS_FREE" right away. + * 2) prevents an invalid access error in a race + * from a pending "IB_WR_LOCAL_INV" operation + * with a teardown ("dma_unmap_sg", "put_page") + * and de-registration ("ib_dereg_mr") of the corresponding + * memory region. 
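+ * fr_inv_done is likewise woken from rds_ib_mr_cqe_handler(), after
+ * the LOCAL_INV completion moves fr_state to FRMR_IS_FREE (or to
+ * FRMR_IS_STALE on error).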
+ */ + wait_event(frmr->fr_inv_done, frmr->fr_state != FRMR_IS_INUSE); + +out: + return ret; +} + +void rds_ib_mr_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc) +{ + struct rds_ib_mr *ibmr = (void *)(unsigned long)wc->wr_id; + struct rds_ib_frmr *frmr = &ibmr->u.frmr; + + if (wc->status != IB_WC_SUCCESS) { + rds_transition_frwr_state(ibmr, FRMR_IS_INUSE, FRMR_IS_STALE); + if (rds_conn_up(ic->conn)) + rds_ib_conn_error(ic->conn, + "frmr completion <%pI4,%pI4> status %u(%s), vendor_err 0x%x, disconnecting and reconnecting\n", + &ic->conn->c_laddr, + &ic->conn->c_faddr, + wc->status, + ib_wc_status_msg(wc->status), + wc->vendor_err); + } + + if (frmr->fr_inv) { + rds_transition_frwr_state(ibmr, FRMR_IS_INUSE, FRMR_IS_FREE); + frmr->fr_inv = false; + wake_up(&frmr->fr_inv_done); + } + + if (frmr->fr_reg) { + frmr->fr_reg = false; + wake_up(&frmr->fr_reg_done); + } + + /* enforce order of frmr->{fr_reg,fr_inv} update + * before incrementing i_fastreg_wrs + */ + smp_mb__before_atomic(); + atomic_inc(&ic->i_fastreg_wrs); +} + +void rds_ib_unreg_frmr(struct list_head *list, unsigned int *nfreed, + unsigned long *unpinned, unsigned int goal) +{ + struct rds_ib_mr *ibmr, *next; + struct rds_ib_frmr *frmr; + int ret = 0, ret2; + unsigned int freed = *nfreed; + + /* String all ib_mr's onto one list and hand them to ib_unmap_fmr */ + list_for_each_entry(ibmr, list, unmap_list) { + if (ibmr->sg_dma_len) { + ret2 = rds_ib_post_inv(ibmr); + if (ret2 && !ret) + ret = ret2; + } + } + + if (ret) + pr_warn("RDS/IB: %s failed (err=%d)\n", __func__, ret); + + /* Now we can destroy the DMA mapping and unpin any pages */ + list_for_each_entry_safe(ibmr, next, list, unmap_list) { + *unpinned += ibmr->sg_len; + frmr = &ibmr->u.frmr; + __rds_ib_teardown_mr(ibmr); + if (freed < goal || frmr->fr_state == FRMR_IS_STALE) { + /* Don't de-allocate if the MR is not free yet */ + if (frmr->fr_state == FRMR_IS_INUSE) + continue; + + if (ibmr->pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_free); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_free); + list_del(&ibmr->unmap_list); + if (frmr->mr) + ib_dereg_mr(frmr->mr); + kfree(ibmr); + freed++; + } + } + *nfreed = freed; +} + +struct rds_ib_mr *rds_ib_reg_frmr(struct rds_ib_device *rds_ibdev, + struct rds_ib_connection *ic, + struct scatterlist *sg, + unsigned long nents, u32 *key) +{ + struct rds_ib_mr *ibmr = NULL; + struct rds_ib_frmr *frmr; + int ret; + + if (!ic) { + /* TODO: Add FRWR support for RDS_GET_MR using proxy qp*/ + return ERR_PTR(-EOPNOTSUPP); + } + + do { + if (ibmr) + rds_ib_free_frmr(ibmr, true); + ibmr = rds_ib_alloc_frmr(rds_ibdev, nents); + if (IS_ERR(ibmr)) + return ibmr; + frmr = &ibmr->u.frmr; + } while (frmr->fr_state != FRMR_IS_FREE); + + ibmr->ic = ic; + ibmr->device = rds_ibdev; + ret = rds_ib_map_frmr(rds_ibdev, ibmr->pool, ibmr, sg, nents); + if (ret == 0) { + *key = frmr->mr->rkey; + } else { + rds_ib_free_frmr(ibmr, false); + ibmr = ERR_PTR(ret); + } + + return ibmr; +} + +void rds_ib_free_frmr_list(struct rds_ib_mr *ibmr) +{ + struct rds_ib_mr_pool *pool = ibmr->pool; + struct rds_ib_frmr *frmr = &ibmr->u.frmr; + + if (frmr->fr_state == FRMR_IS_STALE) + llist_add(&ibmr->llnode, &pool->drop_list); + else + llist_add(&ibmr->llnode, &pool->free_list); +} diff --git a/net/rds/ib_mr.h b/net/rds/ib_mr.h new file mode 100644 index 000000000..ea5e9aee4 --- /dev/null +++ b/net/rds/ib_mr.h @@ -0,0 +1,143 @@ +/* + * Copyright (c) 2016 Oracle. All rights reserved. 
+ * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef _RDS_IB_MR_H +#define _RDS_IB_MR_H + +#include + +#include "rds.h" +#include "ib.h" + +#define RDS_MR_1M_POOL_SIZE (8192 / 2) +#define RDS_MR_1M_MSG_SIZE 256 +#define RDS_MR_8K_MSG_SIZE 2 +#define RDS_MR_8K_SCALE (256 / (RDS_MR_8K_MSG_SIZE + 1)) +#define RDS_MR_8K_POOL_SIZE (RDS_MR_8K_SCALE * (8192 / 2)) + +enum rds_ib_fr_state { + FRMR_IS_FREE, /* mr invalidated & ready for use */ + FRMR_IS_INUSE, /* mr is in use or used & can be invalidated */ + FRMR_IS_STALE, /* Stale MR and needs to be dropped */ +}; + +struct rds_ib_frmr { + struct ib_mr *mr; + enum rds_ib_fr_state fr_state; + bool fr_inv; + wait_queue_head_t fr_inv_done; + bool fr_reg; + wait_queue_head_t fr_reg_done; + struct ib_send_wr fr_wr; + unsigned int dma_npages; + unsigned int sg_byte_len; +}; + +/* This is stored as mr->r_trans_private. 
*/ +struct rds_ib_mr { + struct delayed_work work; + struct rds_ib_device *device; + struct rds_ib_mr_pool *pool; + struct rds_ib_connection *ic; + + struct llist_node llnode; + + /* unmap_list is for freeing */ + struct list_head unmap_list; + unsigned int remap_count; + + struct scatterlist *sg; + unsigned int sg_len; + int sg_dma_len; + + u8 odp:1; + union { + struct rds_ib_frmr frmr; + struct ib_mr *mr; + } u; +}; + +/* Our own little MR pool */ +struct rds_ib_mr_pool { + unsigned int pool_type; + struct mutex flush_lock; /* serialize fmr invalidate */ + struct delayed_work flush_worker; /* flush worker */ + + atomic_t item_count; /* total # of MRs */ + atomic_t dirty_count; /* # dirty of MRs */ + + struct llist_head drop_list; /* MRs not reached max_maps */ + struct llist_head free_list; /* unused MRs */ + struct llist_head clean_list; /* unused & unmapped MRs */ + wait_queue_head_t flush_wait; + spinlock_t clean_lock; /* "clean_list" concurrency */ + + atomic_t free_pinned; /* memory pinned by free MRs */ + unsigned long max_items; + unsigned long max_items_soft; + unsigned long max_free_pinned; + unsigned int max_pages; +}; + +extern struct workqueue_struct *rds_ib_mr_wq; +extern bool prefer_frmr; + +struct rds_ib_mr_pool *rds_ib_create_mr_pool(struct rds_ib_device *rds_dev, + int npages); +void rds_ib_get_mr_info(struct rds_ib_device *rds_ibdev, + struct rds_info_rdma_connection *iinfo); +void rds6_ib_get_mr_info(struct rds_ib_device *rds_ibdev, + struct rds6_info_rdma_connection *iinfo6); +void rds_ib_destroy_mr_pool(struct rds_ib_mr_pool *); +void *rds_ib_get_mr(struct scatterlist *sg, unsigned long nents, + struct rds_sock *rs, u32 *key_ret, + struct rds_connection *conn, u64 start, u64 length, + int need_odp); +void rds_ib_sync_mr(void *trans_private, int dir); +void rds_ib_free_mr(void *trans_private, int invalidate); +void rds_ib_flush_mrs(void); +int rds_ib_mr_init(void); +void rds_ib_mr_exit(void); +u32 rds_ib_get_lkey(void *trans_private); + +void __rds_ib_teardown_mr(struct rds_ib_mr *); +void rds_ib_teardown_mr(struct rds_ib_mr *); +struct rds_ib_mr *rds_ib_reuse_mr(struct rds_ib_mr_pool *); +int rds_ib_flush_mr_pool(struct rds_ib_mr_pool *, int, struct rds_ib_mr **); +struct rds_ib_mr *rds_ib_try_reuse_ibmr(struct rds_ib_mr_pool *); +struct rds_ib_mr *rds_ib_reg_frmr(struct rds_ib_device *rds_ibdev, + struct rds_ib_connection *ic, + struct scatterlist *sg, + unsigned long nents, u32 *key); +void rds_ib_unreg_frmr(struct list_head *list, unsigned int *nfreed, + unsigned long *unpinned, unsigned int goal); +void rds_ib_free_frmr_list(struct rds_ib_mr *); +#endif diff --git a/net/rds/ib_rdma.c b/net/rds/ib_rdma.c new file mode 100644 index 000000000..8f070ee7e --- /dev/null +++ b/net/rds/ib_rdma.c @@ -0,0 +1,701 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. 
+ * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "ib_mr.h" +#include "rds.h" + +struct workqueue_struct *rds_ib_mr_wq; +struct rds_ib_dereg_odp_mr { + struct work_struct work; + struct ib_mr *mr; +}; + +static void rds_ib_odp_mr_worker(struct work_struct *work); + +static struct rds_ib_device *rds_ib_get_device(__be32 ipaddr) +{ + struct rds_ib_device *rds_ibdev; + struct rds_ib_ipaddr *i_ipaddr; + + rcu_read_lock(); + list_for_each_entry_rcu(rds_ibdev, &rds_ib_devices, list) { + list_for_each_entry_rcu(i_ipaddr, &rds_ibdev->ipaddr_list, list) { + if (i_ipaddr->ipaddr == ipaddr) { + refcount_inc(&rds_ibdev->refcount); + rcu_read_unlock(); + return rds_ibdev; + } + } + } + rcu_read_unlock(); + + return NULL; +} + +static int rds_ib_add_ipaddr(struct rds_ib_device *rds_ibdev, __be32 ipaddr) +{ + struct rds_ib_ipaddr *i_ipaddr; + + i_ipaddr = kmalloc(sizeof *i_ipaddr, GFP_KERNEL); + if (!i_ipaddr) + return -ENOMEM; + + i_ipaddr->ipaddr = ipaddr; + + spin_lock_irq(&rds_ibdev->spinlock); + list_add_tail_rcu(&i_ipaddr->list, &rds_ibdev->ipaddr_list); + spin_unlock_irq(&rds_ibdev->spinlock); + + return 0; +} + +static void rds_ib_remove_ipaddr(struct rds_ib_device *rds_ibdev, __be32 ipaddr) +{ + struct rds_ib_ipaddr *i_ipaddr; + struct rds_ib_ipaddr *to_free = NULL; + + + spin_lock_irq(&rds_ibdev->spinlock); + list_for_each_entry_rcu(i_ipaddr, &rds_ibdev->ipaddr_list, list) { + if (i_ipaddr->ipaddr == ipaddr) { + list_del_rcu(&i_ipaddr->list); + to_free = i_ipaddr; + break; + } + } + spin_unlock_irq(&rds_ibdev->spinlock); + + if (to_free) + kfree_rcu(to_free, rcu); +} + +int rds_ib_update_ipaddr(struct rds_ib_device *rds_ibdev, + struct in6_addr *ipaddr) +{ + struct rds_ib_device *rds_ibdev_old; + + rds_ibdev_old = rds_ib_get_device(ipaddr->s6_addr32[3]); + if (!rds_ibdev_old) + return rds_ib_add_ipaddr(rds_ibdev, ipaddr->s6_addr32[3]); + + if (rds_ibdev_old != rds_ibdev) { + rds_ib_remove_ipaddr(rds_ibdev_old, ipaddr->s6_addr32[3]); + rds_ib_dev_put(rds_ibdev_old); + return rds_ib_add_ipaddr(rds_ibdev, ipaddr->s6_addr32[3]); + } + rds_ib_dev_put(rds_ibdev_old); + + return 0; +} + +void rds_ib_add_conn(struct rds_ib_device *rds_ibdev, struct rds_connection *conn) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + + /* conn was previously on the nodev_conns_list */ + spin_lock_irq(&ib_nodev_conns_lock); + BUG_ON(list_empty(&ib_nodev_conns)); + BUG_ON(list_empty(&ic->ib_node)); + list_del(&ic->ib_node); + + spin_lock(&rds_ibdev->spinlock); + list_add_tail(&ic->ib_node, &rds_ibdev->conn_list); + spin_unlock(&rds_ibdev->spinlock); + spin_unlock_irq(&ib_nodev_conns_lock); + + ic->rds_ibdev = rds_ibdev; + refcount_inc(&rds_ibdev->refcount); +} + +void rds_ib_remove_conn(struct rds_ib_device *rds_ibdev, struct rds_connection *conn) +{ + struct 
rds_ib_connection *ic = conn->c_transport_data; + + /* place conn on nodev_conns_list */ + spin_lock(&ib_nodev_conns_lock); + + spin_lock_irq(&rds_ibdev->spinlock); + BUG_ON(list_empty(&ic->ib_node)); + list_del(&ic->ib_node); + spin_unlock_irq(&rds_ibdev->spinlock); + + list_add_tail(&ic->ib_node, &ib_nodev_conns); + + spin_unlock(&ib_nodev_conns_lock); + + ic->rds_ibdev = NULL; + rds_ib_dev_put(rds_ibdev); +} + +void rds_ib_destroy_nodev_conns(void) +{ + struct rds_ib_connection *ic, *_ic; + LIST_HEAD(tmp_list); + + /* avoid calling conn_destroy with irqs off */ + spin_lock_irq(&ib_nodev_conns_lock); + list_splice(&ib_nodev_conns, &tmp_list); + spin_unlock_irq(&ib_nodev_conns_lock); + + list_for_each_entry_safe(ic, _ic, &tmp_list, ib_node) + rds_conn_destroy(ic->conn); +} + +void rds_ib_get_mr_info(struct rds_ib_device *rds_ibdev, struct rds_info_rdma_connection *iinfo) +{ + struct rds_ib_mr_pool *pool_1m = rds_ibdev->mr_1m_pool; + + iinfo->rdma_mr_max = pool_1m->max_items; + iinfo->rdma_mr_size = pool_1m->max_pages; +} + +#if IS_ENABLED(CONFIG_IPV6) +void rds6_ib_get_mr_info(struct rds_ib_device *rds_ibdev, + struct rds6_info_rdma_connection *iinfo6) +{ + struct rds_ib_mr_pool *pool_1m = rds_ibdev->mr_1m_pool; + + iinfo6->rdma_mr_max = pool_1m->max_items; + iinfo6->rdma_mr_size = pool_1m->max_pages; +} +#endif + +struct rds_ib_mr *rds_ib_reuse_mr(struct rds_ib_mr_pool *pool) +{ + struct rds_ib_mr *ibmr = NULL; + struct llist_node *ret; + unsigned long flags; + + spin_lock_irqsave(&pool->clean_lock, flags); + ret = llist_del_first(&pool->clean_list); + spin_unlock_irqrestore(&pool->clean_lock, flags); + if (ret) { + ibmr = llist_entry(ret, struct rds_ib_mr, llnode); + if (pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_reused); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_reused); + } + + return ibmr; +} + +void rds_ib_sync_mr(void *trans_private, int direction) +{ + struct rds_ib_mr *ibmr = trans_private; + struct rds_ib_device *rds_ibdev = ibmr->device; + + if (ibmr->odp) + return; + + switch (direction) { + case DMA_FROM_DEVICE: + ib_dma_sync_sg_for_cpu(rds_ibdev->dev, ibmr->sg, + ibmr->sg_dma_len, DMA_BIDIRECTIONAL); + break; + case DMA_TO_DEVICE: + ib_dma_sync_sg_for_device(rds_ibdev->dev, ibmr->sg, + ibmr->sg_dma_len, DMA_BIDIRECTIONAL); + break; + } +} + +void __rds_ib_teardown_mr(struct rds_ib_mr *ibmr) +{ + struct rds_ib_device *rds_ibdev = ibmr->device; + + if (ibmr->sg_dma_len) { + ib_dma_unmap_sg(rds_ibdev->dev, + ibmr->sg, ibmr->sg_len, + DMA_BIDIRECTIONAL); + ibmr->sg_dma_len = 0; + } + + /* Release the s/g list */ + if (ibmr->sg_len) { + unsigned int i; + + for (i = 0; i < ibmr->sg_len; ++i) { + struct page *page = sg_page(&ibmr->sg[i]); + + /* FIXME we need a way to tell a r/w MR + * from a r/o MR */ + WARN_ON(!page->mapping && irqs_disabled()); + set_page_dirty(page); + put_page(page); + } + kfree(ibmr->sg); + + ibmr->sg = NULL; + ibmr->sg_len = 0; + } +} + +void rds_ib_teardown_mr(struct rds_ib_mr *ibmr) +{ + unsigned int pinned = ibmr->sg_len; + + __rds_ib_teardown_mr(ibmr); + if (pinned) { + struct rds_ib_mr_pool *pool = ibmr->pool; + + atomic_sub(pinned, &pool->free_pinned); + } +} + +static inline unsigned int rds_ib_flush_goal(struct rds_ib_mr_pool *pool, int free_all) +{ + unsigned int item_count; + + item_count = atomic_read(&pool->item_count); + if (free_all) + return item_count; + + return 0; +} + +/* + * given an llist of mrs, put them all into the list_head for more processing + */ +static unsigned int llist_append_to_list(struct 
llist_head *llist, + struct list_head *list) +{ + struct rds_ib_mr *ibmr; + struct llist_node *node; + struct llist_node *next; + unsigned int count = 0; + + node = llist_del_all(llist); + while (node) { + next = node->next; + ibmr = llist_entry(node, struct rds_ib_mr, llnode); + list_add_tail(&ibmr->unmap_list, list); + node = next; + count++; + } + return count; +} + +/* + * this takes a list head of mrs and turns it into linked llist nodes + * of clusters. Each cluster has linked llist nodes of + * MR_CLUSTER_SIZE mrs that are ready for reuse. + */ +static void list_to_llist_nodes(struct list_head *list, + struct llist_node **nodes_head, + struct llist_node **nodes_tail) +{ + struct rds_ib_mr *ibmr; + struct llist_node *cur = NULL; + struct llist_node **next = nodes_head; + + list_for_each_entry(ibmr, list, unmap_list) { + cur = &ibmr->llnode; + *next = cur; + next = &cur->next; + } + *next = NULL; + *nodes_tail = cur; +} + +/* + * Flush our pool of MRs. + * At a minimum, all currently unused MRs are unmapped. + * If the number of MRs allocated exceeds the limit, we also try + * to free as many MRs as needed to get back to this limit. + */ +int rds_ib_flush_mr_pool(struct rds_ib_mr_pool *pool, + int free_all, struct rds_ib_mr **ibmr_ret) +{ + struct rds_ib_mr *ibmr; + struct llist_node *clean_nodes; + struct llist_node *clean_tail; + LIST_HEAD(unmap_list); + unsigned long unpinned = 0; + unsigned int nfreed = 0, dirty_to_clean = 0, free_goal; + + if (pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_pool_flush); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_pool_flush); + + if (ibmr_ret) { + DEFINE_WAIT(wait); + while (!mutex_trylock(&pool->flush_lock)) { + ibmr = rds_ib_reuse_mr(pool); + if (ibmr) { + *ibmr_ret = ibmr; + finish_wait(&pool->flush_wait, &wait); + goto out_nolock; + } + + prepare_to_wait(&pool->flush_wait, &wait, + TASK_UNINTERRUPTIBLE); + if (llist_empty(&pool->clean_list)) + schedule(); + + ibmr = rds_ib_reuse_mr(pool); + if (ibmr) { + *ibmr_ret = ibmr; + finish_wait(&pool->flush_wait, &wait); + goto out_nolock; + } + } + finish_wait(&pool->flush_wait, &wait); + } else + mutex_lock(&pool->flush_lock); + + if (ibmr_ret) { + ibmr = rds_ib_reuse_mr(pool); + if (ibmr) { + *ibmr_ret = ibmr; + goto out; + } + } + + /* Get the list of all MRs to be dropped. Ordering matters - + * we want to put drop_list ahead of free_list. 
+ */ + dirty_to_clean = llist_append_to_list(&pool->drop_list, &unmap_list); + dirty_to_clean += llist_append_to_list(&pool->free_list, &unmap_list); + if (free_all) { + unsigned long flags; + + spin_lock_irqsave(&pool->clean_lock, flags); + llist_append_to_list(&pool->clean_list, &unmap_list); + spin_unlock_irqrestore(&pool->clean_lock, flags); + } + + free_goal = rds_ib_flush_goal(pool, free_all); + + if (list_empty(&unmap_list)) + goto out; + + rds_ib_unreg_frmr(&unmap_list, &nfreed, &unpinned, free_goal); + + if (!list_empty(&unmap_list)) { + unsigned long flags; + + list_to_llist_nodes(&unmap_list, &clean_nodes, &clean_tail); + if (ibmr_ret) { + *ibmr_ret = llist_entry(clean_nodes, struct rds_ib_mr, llnode); + clean_nodes = clean_nodes->next; + } + /* more than one entry in llist nodes */ + if (clean_nodes) { + spin_lock_irqsave(&pool->clean_lock, flags); + llist_add_batch(clean_nodes, clean_tail, + &pool->clean_list); + spin_unlock_irqrestore(&pool->clean_lock, flags); + } + } + + atomic_sub(unpinned, &pool->free_pinned); + atomic_sub(dirty_to_clean, &pool->dirty_count); + atomic_sub(nfreed, &pool->item_count); + +out: + mutex_unlock(&pool->flush_lock); + if (waitqueue_active(&pool->flush_wait)) + wake_up(&pool->flush_wait); +out_nolock: + return 0; +} + +struct rds_ib_mr *rds_ib_try_reuse_ibmr(struct rds_ib_mr_pool *pool) +{ + struct rds_ib_mr *ibmr = NULL; + int iter = 0; + + while (1) { + ibmr = rds_ib_reuse_mr(pool); + if (ibmr) + return ibmr; + + if (atomic_inc_return(&pool->item_count) <= pool->max_items) + break; + + atomic_dec(&pool->item_count); + + if (++iter > 2) { + if (pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_pool_depleted); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_pool_depleted); + break; + } + + /* We do have some empty MRs. Flush them out. */ + if (pool->pool_type == RDS_IB_MR_8K_POOL) + rds_ib_stats_inc(s_ib_rdma_mr_8k_pool_wait); + else + rds_ib_stats_inc(s_ib_rdma_mr_1m_pool_wait); + + rds_ib_flush_mr_pool(pool, 0, &ibmr); + if (ibmr) + return ibmr; + } + + return NULL; +} + +static void rds_ib_mr_pool_flush_worker(struct work_struct *work) +{ + struct rds_ib_mr_pool *pool = container_of(work, struct rds_ib_mr_pool, flush_worker.work); + + rds_ib_flush_mr_pool(pool, 0, NULL); +} + +void rds_ib_free_mr(void *trans_private, int invalidate) +{ + struct rds_ib_mr *ibmr = trans_private; + struct rds_ib_mr_pool *pool = ibmr->pool; + struct rds_ib_device *rds_ibdev = ibmr->device; + + rdsdebug("RDS/IB: free_mr nents %u\n", ibmr->sg_len); + + if (ibmr->odp) { + /* A MR created and marked as use_once. We use delayed work, + * because there is a change that we are in interrupt and can't + * call to ib_dereg_mr() directly. + */ + INIT_DELAYED_WORK(&ibmr->work, rds_ib_odp_mr_worker); + queue_delayed_work(rds_ib_mr_wq, &ibmr->work, 0); + return; + } + + /* Return it to the pool's free list */ + rds_ib_free_frmr_list(ibmr); + + atomic_add(ibmr->sg_len, &pool->free_pinned); + atomic_inc(&pool->dirty_count); + + /* If we've pinned too many pages, request a flush */ + if (atomic_read(&pool->free_pinned) >= pool->max_free_pinned || + atomic_read(&pool->dirty_count) >= pool->max_items / 5) + queue_delayed_work(rds_ib_mr_wq, &pool->flush_worker, 10); + + if (invalidate) { + if (likely(!in_interrupt())) { + rds_ib_flush_mr_pool(pool, 0, NULL); + } else { + /* We get here if the user created a MR marked + * as use_once and invalidate at the same time. 
+ */ + queue_delayed_work(rds_ib_mr_wq, + &pool->flush_worker, 10); + } + } + + rds_ib_dev_put(rds_ibdev); +} + +void rds_ib_flush_mrs(void) +{ + struct rds_ib_device *rds_ibdev; + + down_read(&rds_ib_devices_lock); + list_for_each_entry(rds_ibdev, &rds_ib_devices, list) { + if (rds_ibdev->mr_8k_pool) + rds_ib_flush_mr_pool(rds_ibdev->mr_8k_pool, 0, NULL); + + if (rds_ibdev->mr_1m_pool) + rds_ib_flush_mr_pool(rds_ibdev->mr_1m_pool, 0, NULL); + } + up_read(&rds_ib_devices_lock); +} + +u32 rds_ib_get_lkey(void *trans_private) +{ + struct rds_ib_mr *ibmr = trans_private; + + return ibmr->u.mr->lkey; +} + +void *rds_ib_get_mr(struct scatterlist *sg, unsigned long nents, + struct rds_sock *rs, u32 *key_ret, + struct rds_connection *conn, + u64 start, u64 length, int need_odp) +{ + struct rds_ib_device *rds_ibdev; + struct rds_ib_mr *ibmr = NULL; + struct rds_ib_connection *ic = NULL; + int ret; + + rds_ibdev = rds_ib_get_device(rs->rs_bound_addr.s6_addr32[3]); + if (!rds_ibdev) { + ret = -ENODEV; + goto out; + } + + if (need_odp == ODP_ZEROBASED || need_odp == ODP_VIRTUAL) { + u64 virt_addr = need_odp == ODP_ZEROBASED ? 0 : start; + int access_flags = + (IB_ACCESS_LOCAL_WRITE | IB_ACCESS_REMOTE_READ | + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_REMOTE_ATOMIC | + IB_ACCESS_ON_DEMAND); + struct ib_sge sge = {}; + struct ib_mr *ib_mr; + + if (!rds_ibdev->odp_capable) { + ret = -EOPNOTSUPP; + goto out; + } + + ib_mr = ib_reg_user_mr(rds_ibdev->pd, start, length, virt_addr, + access_flags); + + if (IS_ERR(ib_mr)) { + rdsdebug("rds_ib_get_user_mr returned %d\n", + IS_ERR(ib_mr)); + ret = PTR_ERR(ib_mr); + goto out; + } + if (key_ret) + *key_ret = ib_mr->rkey; + + ibmr = kzalloc(sizeof(*ibmr), GFP_KERNEL); + if (!ibmr) { + ib_dereg_mr(ib_mr); + ret = -ENOMEM; + goto out; + } + ibmr->u.mr = ib_mr; + ibmr->odp = 1; + + sge.addr = virt_addr; + sge.length = length; + sge.lkey = ib_mr->lkey; + + ib_advise_mr(rds_ibdev->pd, + IB_UVERBS_ADVISE_MR_ADVICE_PREFETCH_WRITE, + IB_UVERBS_ADVISE_MR_FLAG_FLUSH, &sge, 1); + return ibmr; + } + + if (conn) + ic = conn->c_transport_data; + + if (!rds_ibdev->mr_8k_pool || !rds_ibdev->mr_1m_pool) { + ret = -ENODEV; + goto out; + } + + ibmr = rds_ib_reg_frmr(rds_ibdev, ic, sg, nents, key_ret); + if (IS_ERR(ibmr)) { + ret = PTR_ERR(ibmr); + pr_warn("RDS/IB: rds_ib_get_mr failed (errno=%d)\n", ret); + } else { + return ibmr; + } + + out: + if (rds_ibdev) + rds_ib_dev_put(rds_ibdev); + + return ERR_PTR(ret); +} + +void rds_ib_destroy_mr_pool(struct rds_ib_mr_pool *pool) +{ + cancel_delayed_work_sync(&pool->flush_worker); + rds_ib_flush_mr_pool(pool, 1, NULL); + WARN_ON(atomic_read(&pool->item_count)); + WARN_ON(atomic_read(&pool->free_pinned)); + kfree(pool); +} + +struct rds_ib_mr_pool *rds_ib_create_mr_pool(struct rds_ib_device *rds_ibdev, + int pool_type) +{ + struct rds_ib_mr_pool *pool; + + pool = kzalloc(sizeof(*pool), GFP_KERNEL); + if (!pool) + return ERR_PTR(-ENOMEM); + + pool->pool_type = pool_type; + init_llist_head(&pool->free_list); + init_llist_head(&pool->drop_list); + init_llist_head(&pool->clean_list); + spin_lock_init(&pool->clean_lock); + mutex_init(&pool->flush_lock); + init_waitqueue_head(&pool->flush_wait); + INIT_DELAYED_WORK(&pool->flush_worker, rds_ib_mr_pool_flush_worker); + + if (pool_type == RDS_IB_MR_1M_POOL) { + /* +1 allows for unaligned MRs */ + pool->max_pages = RDS_MR_1M_MSG_SIZE + 1; + pool->max_items = rds_ibdev->max_1m_mrs; + } else { + /* pool_type == RDS_IB_MR_8K_POOL */ + pool->max_pages = RDS_MR_8K_MSG_SIZE + 1; + pool->max_items = 
rds_ibdev->max_8k_mrs; + } + + pool->max_free_pinned = pool->max_items * pool->max_pages / 4; + pool->max_items_soft = rds_ibdev->max_mrs * 3 / 4; + + return pool; +} + +int rds_ib_mr_init(void) +{ + rds_ib_mr_wq = alloc_workqueue("rds_mr_flushd", WQ_MEM_RECLAIM, 0); + if (!rds_ib_mr_wq) + return -ENOMEM; + return 0; +} + +/* By the time this is called all the IB devices should have been torn down and + * had their pools freed. As each pool is freed its work struct is waited on, + * so the pool flushing work queue should be idle by the time we get here. + */ +void rds_ib_mr_exit(void) +{ + destroy_workqueue(rds_ib_mr_wq); +} + +static void rds_ib_odp_mr_worker(struct work_struct *work) +{ + struct rds_ib_mr *ibmr; + + ibmr = container_of(work, struct rds_ib_mr, work.work); + ib_dereg_mr(ibmr->u.mr); + kfree(ibmr); +} diff --git a/net/rds/ib_recv.c b/net/rds/ib_recv.c new file mode 100644 index 000000000..cfbf0e129 --- /dev/null +++ b/net/rds/ib_recv.c @@ -0,0 +1,1094 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "ib.h" + +static struct kmem_cache *rds_ib_incoming_slab; +static struct kmem_cache *rds_ib_frag_slab; +static atomic_t rds_ib_allocation = ATOMIC_INIT(0); + +void rds_ib_recv_init_ring(struct rds_ib_connection *ic) +{ + struct rds_ib_recv_work *recv; + u32 i; + + for (i = 0, recv = ic->i_recvs; i < ic->i_recv_ring.w_nr; i++, recv++) { + struct ib_sge *sge; + + recv->r_ibinc = NULL; + recv->r_frag = NULL; + + recv->r_wr.next = NULL; + recv->r_wr.wr_id = i; + recv->r_wr.sg_list = recv->r_sge; + recv->r_wr.num_sge = RDS_IB_RECV_SGE; + + sge = &recv->r_sge[0]; + sge->addr = ic->i_recv_hdrs_dma[i]; + sge->length = sizeof(struct rds_header); + sge->lkey = ic->i_pd->local_dma_lkey; + + sge = &recv->r_sge[1]; + sge->addr = 0; + sge->length = RDS_FRAG_SIZE; + sge->lkey = ic->i_pd->local_dma_lkey; + } +} + +/* + * The entire 'from' list, including the from element itself, is put on + * to the tail of the 'to' list. 
+ */ +static void list_splice_entire_tail(struct list_head *from, + struct list_head *to) +{ + struct list_head *from_last = from->prev; + + list_splice_tail(from_last, to); + list_add_tail(from_last, to); +} + +static void rds_ib_cache_xfer_to_ready(struct rds_ib_refill_cache *cache) +{ + struct list_head *tmp; + + tmp = xchg(&cache->xfer, NULL); + if (tmp) { + if (cache->ready) + list_splice_entire_tail(tmp, cache->ready); + else + cache->ready = tmp; + } +} + +static int rds_ib_recv_alloc_cache(struct rds_ib_refill_cache *cache, gfp_t gfp) +{ + struct rds_ib_cache_head *head; + int cpu; + + cache->percpu = alloc_percpu_gfp(struct rds_ib_cache_head, gfp); + if (!cache->percpu) + return -ENOMEM; + + for_each_possible_cpu(cpu) { + head = per_cpu_ptr(cache->percpu, cpu); + head->first = NULL; + head->count = 0; + } + cache->xfer = NULL; + cache->ready = NULL; + + return 0; +} + +int rds_ib_recv_alloc_caches(struct rds_ib_connection *ic, gfp_t gfp) +{ + int ret; + + ret = rds_ib_recv_alloc_cache(&ic->i_cache_incs, gfp); + if (!ret) { + ret = rds_ib_recv_alloc_cache(&ic->i_cache_frags, gfp); + if (ret) + free_percpu(ic->i_cache_incs.percpu); + } + + return ret; +} + +static void rds_ib_cache_splice_all_lists(struct rds_ib_refill_cache *cache, + struct list_head *caller_list) +{ + struct rds_ib_cache_head *head; + int cpu; + + for_each_possible_cpu(cpu) { + head = per_cpu_ptr(cache->percpu, cpu); + if (head->first) { + list_splice_entire_tail(head->first, caller_list); + head->first = NULL; + } + } + + if (cache->ready) { + list_splice_entire_tail(cache->ready, caller_list); + cache->ready = NULL; + } +} + +void rds_ib_recv_free_caches(struct rds_ib_connection *ic) +{ + struct rds_ib_incoming *inc; + struct rds_ib_incoming *inc_tmp; + struct rds_page_frag *frag; + struct rds_page_frag *frag_tmp; + LIST_HEAD(list); + + rds_ib_cache_xfer_to_ready(&ic->i_cache_incs); + rds_ib_cache_splice_all_lists(&ic->i_cache_incs, &list); + free_percpu(ic->i_cache_incs.percpu); + + list_for_each_entry_safe(inc, inc_tmp, &list, ii_cache_entry) { + list_del(&inc->ii_cache_entry); + WARN_ON(!list_empty(&inc->ii_frags)); + kmem_cache_free(rds_ib_incoming_slab, inc); + atomic_dec(&rds_ib_allocation); + } + + rds_ib_cache_xfer_to_ready(&ic->i_cache_frags); + rds_ib_cache_splice_all_lists(&ic->i_cache_frags, &list); + free_percpu(ic->i_cache_frags.percpu); + + list_for_each_entry_safe(frag, frag_tmp, &list, f_cache_entry) { + list_del(&frag->f_cache_entry); + WARN_ON(!list_empty(&frag->f_item)); + kmem_cache_free(rds_ib_frag_slab, frag); + } +} + +/* fwd decl */ +static void rds_ib_recv_cache_put(struct list_head *new_item, + struct rds_ib_refill_cache *cache); +static struct list_head *rds_ib_recv_cache_get(struct rds_ib_refill_cache *cache); + + +/* Recycle frag and attached recv buffer f_sg */ +static void rds_ib_frag_free(struct rds_ib_connection *ic, + struct rds_page_frag *frag) +{ + rdsdebug("frag %p page %p\n", frag, sg_page(&frag->f_sg)); + + rds_ib_recv_cache_put(&frag->f_cache_entry, &ic->i_cache_frags); + atomic_add(RDS_FRAG_SIZE / SZ_1K, &ic->i_cache_allocs); + rds_ib_stats_add(s_ib_recv_added_to_cache, RDS_FRAG_SIZE); +} + +/* Recycle inc after freeing attached frags */ +void rds_ib_inc_free(struct rds_incoming *inc) +{ + struct rds_ib_incoming *ibinc; + struct rds_page_frag *frag; + struct rds_page_frag *pos; + struct rds_ib_connection *ic = inc->i_conn->c_transport_data; + + ibinc = container_of(inc, struct rds_ib_incoming, ii_inc); + + /* Free attached frags */ + list_for_each_entry_safe(frag, pos, 
&ibinc->ii_frags, f_item) { + list_del_init(&frag->f_item); + rds_ib_frag_free(ic, frag); + } + BUG_ON(!list_empty(&ibinc->ii_frags)); + + rdsdebug("freeing ibinc %p inc %p\n", ibinc, inc); + rds_ib_recv_cache_put(&ibinc->ii_cache_entry, &ic->i_cache_incs); +} + +static void rds_ib_recv_clear_one(struct rds_ib_connection *ic, + struct rds_ib_recv_work *recv) +{ + if (recv->r_ibinc) { + rds_inc_put(&recv->r_ibinc->ii_inc); + recv->r_ibinc = NULL; + } + if (recv->r_frag) { + ib_dma_unmap_sg(ic->i_cm_id->device, &recv->r_frag->f_sg, 1, DMA_FROM_DEVICE); + rds_ib_frag_free(ic, recv->r_frag); + recv->r_frag = NULL; + } +} + +void rds_ib_recv_clear_ring(struct rds_ib_connection *ic) +{ + u32 i; + + for (i = 0; i < ic->i_recv_ring.w_nr; i++) + rds_ib_recv_clear_one(ic, &ic->i_recvs[i]); +} + +static struct rds_ib_incoming *rds_ib_refill_one_inc(struct rds_ib_connection *ic, + gfp_t slab_mask) +{ + struct rds_ib_incoming *ibinc; + struct list_head *cache_item; + int avail_allocs; + + cache_item = rds_ib_recv_cache_get(&ic->i_cache_incs); + if (cache_item) { + ibinc = container_of(cache_item, struct rds_ib_incoming, ii_cache_entry); + } else { + avail_allocs = atomic_add_unless(&rds_ib_allocation, + 1, rds_ib_sysctl_max_recv_allocation); + if (!avail_allocs) { + rds_ib_stats_inc(s_ib_rx_alloc_limit); + return NULL; + } + ibinc = kmem_cache_alloc(rds_ib_incoming_slab, slab_mask); + if (!ibinc) { + atomic_dec(&rds_ib_allocation); + return NULL; + } + rds_ib_stats_inc(s_ib_rx_total_incs); + } + INIT_LIST_HEAD(&ibinc->ii_frags); + rds_inc_init(&ibinc->ii_inc, ic->conn, &ic->conn->c_faddr); + + return ibinc; +} + +static struct rds_page_frag *rds_ib_refill_one_frag(struct rds_ib_connection *ic, + gfp_t slab_mask, gfp_t page_mask) +{ + struct rds_page_frag *frag; + struct list_head *cache_item; + int ret; + + cache_item = rds_ib_recv_cache_get(&ic->i_cache_frags); + if (cache_item) { + frag = container_of(cache_item, struct rds_page_frag, f_cache_entry); + atomic_sub(RDS_FRAG_SIZE / SZ_1K, &ic->i_cache_allocs); + rds_ib_stats_add(s_ib_recv_added_to_cache, RDS_FRAG_SIZE); + } else { + frag = kmem_cache_alloc(rds_ib_frag_slab, slab_mask); + if (!frag) + return NULL; + + sg_init_table(&frag->f_sg, 1); + ret = rds_page_remainder_alloc(&frag->f_sg, + RDS_FRAG_SIZE, page_mask); + if (ret) { + kmem_cache_free(rds_ib_frag_slab, frag); + return NULL; + } + rds_ib_stats_inc(s_ib_rx_total_frags); + } + + INIT_LIST_HEAD(&frag->f_item); + + return frag; +} + +static int rds_ib_recv_refill_one(struct rds_connection *conn, + struct rds_ib_recv_work *recv, gfp_t gfp) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct ib_sge *sge; + int ret = -ENOMEM; + gfp_t slab_mask = gfp; + gfp_t page_mask = gfp; + + if (gfp & __GFP_DIRECT_RECLAIM) { + slab_mask = GFP_KERNEL; + page_mask = GFP_HIGHUSER; + } + + if (!ic->i_cache_incs.ready) + rds_ib_cache_xfer_to_ready(&ic->i_cache_incs); + if (!ic->i_cache_frags.ready) + rds_ib_cache_xfer_to_ready(&ic->i_cache_frags); + + /* + * ibinc was taken from recv if recv contained the start of a message. + * recvs that were continuations will still have this allocated. + */ + if (!recv->r_ibinc) { + recv->r_ibinc = rds_ib_refill_one_inc(ic, slab_mask); + if (!recv->r_ibinc) + goto out; + } + + WARN_ON(recv->r_frag); /* leak! 
*/ + recv->r_frag = rds_ib_refill_one_frag(ic, slab_mask, page_mask); + if (!recv->r_frag) + goto out; + + ret = ib_dma_map_sg(ic->i_cm_id->device, &recv->r_frag->f_sg, + 1, DMA_FROM_DEVICE); + WARN_ON(ret != 1); + + sge = &recv->r_sge[0]; + sge->addr = ic->i_recv_hdrs_dma[recv - ic->i_recvs]; + sge->length = sizeof(struct rds_header); + + sge = &recv->r_sge[1]; + sge->addr = sg_dma_address(&recv->r_frag->f_sg); + sge->length = sg_dma_len(&recv->r_frag->f_sg); + + ret = 0; +out: + return ret; +} + +static int acquire_refill(struct rds_connection *conn) +{ + return test_and_set_bit(RDS_RECV_REFILL, &conn->c_flags) == 0; +} + +static void release_refill(struct rds_connection *conn) +{ + clear_bit(RDS_RECV_REFILL, &conn->c_flags); + smp_mb__after_atomic(); + + /* We don't use wait_on_bit()/wake_up_bit() because our waking is in a + * hot path and finding waiters is very rare. We don't want to walk + * the system-wide hashed waitqueue buckets in the fast path only to + * almost never find waiters. + */ + if (waitqueue_active(&conn->c_waitq)) + wake_up_all(&conn->c_waitq); +} + +/* + * This tries to allocate and post unused work requests after making sure that + * they have all the allocations they need to queue received fragments into + * sockets. + */ +void rds_ib_recv_refill(struct rds_connection *conn, int prefill, gfp_t gfp) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct rds_ib_recv_work *recv; + unsigned int posted = 0; + int ret = 0; + bool can_wait = !!(gfp & __GFP_DIRECT_RECLAIM); + bool must_wake = false; + u32 pos; + + /* the goal here is to just make sure that someone, somewhere + * is posting buffers. If we can't get the refill lock, + * let them do their thing + */ + if (!acquire_refill(conn)) + return; + + while ((prefill || rds_conn_up(conn)) && + rds_ib_ring_alloc(&ic->i_recv_ring, 1, &pos)) { + if (pos >= ic->i_recv_ring.w_nr) { + printk(KERN_NOTICE "Argh - ring alloc returned pos=%u\n", + pos); + break; + } + + recv = &ic->i_recvs[pos]; + ret = rds_ib_recv_refill_one(conn, recv, gfp); + if (ret) { + must_wake = true; + break; + } + + rdsdebug("recv %p ibinc %p page %p addr %lu\n", recv, + recv->r_ibinc, sg_page(&recv->r_frag->f_sg), + (long)sg_dma_address(&recv->r_frag->f_sg)); + + /* XXX when can this fail? */ + ret = ib_post_recv(ic->i_cm_id->qp, &recv->r_wr, NULL); + if (ret) { + rds_ib_conn_error(conn, "recv post on " + "%pI6c returned %d, disconnecting and " + "reconnecting\n", &conn->c_faddr, + ret); + break; + } + + posted++; + + if ((posted > 128 && need_resched()) || posted > 8192) { + must_wake = true; + break; + } + } + + /* We're doing flow control - update the window. */ + if (ic->i_flowctl && posted) + rds_ib_advertise_credits(conn, posted); + + if (ret) + rds_ib_ring_unalloc(&ic->i_recv_ring, 1); + + release_refill(conn); + + /* if we're called from the softirq handler, we'll be GFP_NOWAIT. + * in this case the ring being low is going to lead to more interrupts + * and we can safely let the softirq code take care of it unless the + * ring is completely empty. + * + * if we're called from krdsd, we'll be GFP_KERNEL. In this case + * we might have raced with the softirq code while we had the refill + * lock held. Use rds_ib_ring_low() instead of ring_empty to decide + * if we should requeue. 
+ */ + if (rds_conn_up(conn) && + (must_wake || + (can_wait && rds_ib_ring_low(&ic->i_recv_ring)) || + rds_ib_ring_empty(&ic->i_recv_ring))) { + queue_delayed_work(rds_wq, &conn->c_recv_w, 1); + } + if (can_wait) + cond_resched(); +} + +/* + * We want to recycle several types of recv allocations, like incs and frags. + * To use this, the *_free() function passes in the ptr to a list_head within + * the recyclee, as well as the cache to put it on. + * + * First, we put the memory on a percpu list. When this reaches a certain size, + * We move it to an intermediate non-percpu list in a lockless manner, with some + * xchg/compxchg wizardry. + * + * N.B. Instead of a list_head as the anchor, we use a single pointer, which can + * be NULL and xchg'd. The list is actually empty when the pointer is NULL, and + * list_empty() will return true with one element is actually present. + */ +static void rds_ib_recv_cache_put(struct list_head *new_item, + struct rds_ib_refill_cache *cache) +{ + unsigned long flags; + struct list_head *old, *chpfirst; + + local_irq_save(flags); + + chpfirst = __this_cpu_read(cache->percpu->first); + if (!chpfirst) + INIT_LIST_HEAD(new_item); + else /* put on front */ + list_add_tail(new_item, chpfirst); + + __this_cpu_write(cache->percpu->first, new_item); + __this_cpu_inc(cache->percpu->count); + + if (__this_cpu_read(cache->percpu->count) < RDS_IB_RECYCLE_BATCH_COUNT) + goto end; + + /* + * Return our per-cpu first list to the cache's xfer by atomically + * grabbing the current xfer list, appending it to our per-cpu list, + * and then atomically returning that entire list back to the + * cache's xfer list as long as it's still empty. + */ + do { + old = xchg(&cache->xfer, NULL); + if (old) + list_splice_entire_tail(old, chpfirst); + old = cmpxchg(&cache->xfer, NULL, chpfirst); + } while (old); + + + __this_cpu_write(cache->percpu->first, NULL); + __this_cpu_write(cache->percpu->count, 0); +end: + local_irq_restore(flags); +} + +static struct list_head *rds_ib_recv_cache_get(struct rds_ib_refill_cache *cache) +{ + struct list_head *head = cache->ready; + + if (head) { + if (!list_empty(head)) { + cache->ready = head->next; + list_del_init(head); + } else + cache->ready = NULL; + } + + return head; +} + +int rds_ib_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) +{ + struct rds_ib_incoming *ibinc; + struct rds_page_frag *frag; + unsigned long to_copy; + unsigned long frag_off = 0; + int copied = 0; + int ret; + u32 len; + + ibinc = container_of(inc, struct rds_ib_incoming, ii_inc); + frag = list_entry(ibinc->ii_frags.next, struct rds_page_frag, f_item); + len = be32_to_cpu(inc->i_hdr.h_len); + + while (iov_iter_count(to) && copied < len) { + if (frag_off == RDS_FRAG_SIZE) { + frag = list_entry(frag->f_item.next, + struct rds_page_frag, f_item); + frag_off = 0; + } + to_copy = min_t(unsigned long, iov_iter_count(to), + RDS_FRAG_SIZE - frag_off); + to_copy = min_t(unsigned long, to_copy, len - copied); + + /* XXX needs + offset for multiple recvs per page */ + rds_stats_add(s_copy_to_user, to_copy); + ret = copy_page_to_iter(sg_page(&frag->f_sg), + frag->f_sg.offset + frag_off, + to_copy, + to); + if (ret != to_copy) + return -EFAULT; + + frag_off += to_copy; + copied += to_copy; + } + + return copied; +} + +/* ic starts out kzalloc()ed */ +void rds_ib_recv_init_ack(struct rds_ib_connection *ic) +{ + struct ib_send_wr *wr = &ic->i_ack_wr; + struct ib_sge *sge = &ic->i_ack_sge; + + sge->addr = ic->i_ack_dma; + sge->length = sizeof(struct rds_header); + 
sge->lkey = ic->i_pd->local_dma_lkey; + + wr->sg_list = sge; + wr->num_sge = 1; + wr->opcode = IB_WR_SEND; + wr->wr_id = RDS_IB_ACK_WR_ID; + wr->send_flags = IB_SEND_SIGNALED | IB_SEND_SOLICITED; +} + +/* + * You'd think that with reliable IB connections you wouldn't need to ack + * messages that have been received. The problem is that IB hardware generates + * an ack message before it has DMAed the message into memory. This creates a + * potential message loss if the HCA is disabled for any reason between when it + * sends the ack and before the message is DMAed and processed. This is only a + * potential issue if another HCA is available for fail-over. + * + * When the remote host receives our ack they'll free the sent message from + * their send queue. To decrease the latency of this we always send an ack + * immediately after we've received messages. + * + * For simplicity, we only have one ack in flight at a time. This puts + * pressure on senders to have deep enough send queues to absorb the latency of + * a single ack frame being in flight. This might not be good enough. + * + * This is implemented by have a long-lived send_wr and sge which point to a + * statically allocated ack frame. This ack wr does not fall under the ring + * accounting that the tx and rx wrs do. The QP attribute specifically makes + * room for it beyond the ring size. Send completion notices its special + * wr_id and avoids working with the ring in that case. + */ +#ifndef KERNEL_HAS_ATOMIC64 +void rds_ib_set_ack(struct rds_ib_connection *ic, u64 seq, int ack_required) +{ + unsigned long flags; + + spin_lock_irqsave(&ic->i_ack_lock, flags); + ic->i_ack_next = seq; + if (ack_required) + set_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + spin_unlock_irqrestore(&ic->i_ack_lock, flags); +} + +static u64 rds_ib_get_ack(struct rds_ib_connection *ic) +{ + unsigned long flags; + u64 seq; + + clear_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + + spin_lock_irqsave(&ic->i_ack_lock, flags); + seq = ic->i_ack_next; + spin_unlock_irqrestore(&ic->i_ack_lock, flags); + + return seq; +} +#else +void rds_ib_set_ack(struct rds_ib_connection *ic, u64 seq, int ack_required) +{ + atomic64_set(&ic->i_ack_next, seq); + if (ack_required) { + smp_mb__before_atomic(); + set_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + } +} + +static u64 rds_ib_get_ack(struct rds_ib_connection *ic) +{ + clear_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + smp_mb__after_atomic(); + + return atomic64_read(&ic->i_ack_next); +} +#endif + + +static void rds_ib_send_ack(struct rds_ib_connection *ic, unsigned int adv_credits) +{ + struct rds_header *hdr = ic->i_ack; + u64 seq; + int ret; + + seq = rds_ib_get_ack(ic); + + rdsdebug("send_ack: ic %p ack %llu\n", ic, (unsigned long long) seq); + + ib_dma_sync_single_for_cpu(ic->rds_ibdev->dev, ic->i_ack_dma, + sizeof(*hdr), DMA_TO_DEVICE); + rds_message_populate_header(hdr, 0, 0, 0); + hdr->h_ack = cpu_to_be64(seq); + hdr->h_credit = adv_credits; + rds_message_make_checksum(hdr); + ib_dma_sync_single_for_device(ic->rds_ibdev->dev, ic->i_ack_dma, + sizeof(*hdr), DMA_TO_DEVICE); + + ic->i_ack_queued = jiffies; + + ret = ib_post_send(ic->i_cm_id->qp, &ic->i_ack_wr, NULL); + if (unlikely(ret)) { + /* Failed to send. Release the WR, and + * force another ACK. 
+ */ + clear_bit(IB_ACK_IN_FLIGHT, &ic->i_ack_flags); + set_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + + rds_ib_stats_inc(s_ib_ack_send_failure); + + rds_ib_conn_error(ic->conn, "sending ack failed\n"); + } else + rds_ib_stats_inc(s_ib_ack_sent); +} + +/* + * There are 3 ways of getting acknowledgements to the peer: + * 1. We call rds_ib_attempt_ack from the recv completion handler + * to send an ACK-only frame. + * However, there can be only one such frame in the send queue + * at any time, so we may have to postpone it. + * 2. When another (data) packet is transmitted while there's + * an ACK in the queue, we piggyback the ACK sequence number + * on the data packet. + * 3. If the ACK WR is done sending, we get called from the + * send queue completion handler, and check whether there's + * another ACK pending (postponed because the WR was on the + * queue). If so, we transmit it. + * + * We maintain 2 variables: + * - i_ack_flags, which keeps track of whether the ACK WR + * is currently in the send queue or not (IB_ACK_IN_FLIGHT) + * - i_ack_next, which is the last sequence number we received + * + * Potentially, send queue and receive queue handlers can run concurrently. + * It would be nice to not have to use a spinlock to synchronize things, + * but the one problem that rules this out is that 64bit updates are + * not atomic on all platforms. Things would be a lot simpler if + * we had atomic64 or maybe cmpxchg64 everywhere. + * + * Reconnecting complicates this picture just slightly. When we + * reconnect, we may be seeing duplicate packets. The peer + * is retransmitting them, because it hasn't seen an ACK for + * them. It is important that we ACK these. + * + * ACK mitigation adds a header flag "ACK_REQUIRED"; any packet with + * this flag set *MUST* be acknowledged immediately. + */ + +/* + * When we get here, we're called from the recv queue handler. + * Check whether we ought to transmit an ACK. + */ +void rds_ib_attempt_ack(struct rds_ib_connection *ic) +{ + unsigned int adv_credits; + + if (!test_bit(IB_ACK_REQUESTED, &ic->i_ack_flags)) + return; + + if (test_and_set_bit(IB_ACK_IN_FLIGHT, &ic->i_ack_flags)) { + rds_ib_stats_inc(s_ib_ack_send_delayed); + return; + } + + /* Can we get a send credit? */ + if (!rds_ib_send_grab_credits(ic, 1, &adv_credits, 0, RDS_MAX_ADV_CREDIT)) { + rds_ib_stats_inc(s_ib_tx_throttle); + clear_bit(IB_ACK_IN_FLIGHT, &ic->i_ack_flags); + return; + } + + clear_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); + rds_ib_send_ack(ic, adv_credits); +} + +/* + * We get here from the send completion handler, when the + * adapter tells us the ACK frame was sent. + */ +void rds_ib_ack_send_complete(struct rds_ib_connection *ic) +{ + clear_bit(IB_ACK_IN_FLIGHT, &ic->i_ack_flags); + rds_ib_attempt_ack(ic); +} + +/* + * This is called by the regular xmit code when it wants to piggyback + * an ACK on an outgoing frame. + */ +u64 rds_ib_piggyb_ack(struct rds_ib_connection *ic) +{ + if (test_and_clear_bit(IB_ACK_REQUESTED, &ic->i_ack_flags)) + rds_ib_stats_inc(s_ib_ack_send_piggybacked); + return rds_ib_get_ack(ic); +} + +/* + * It's kind of lame that we're copying from the posted receive pages into + * long-lived bitmaps. We could have posted the bitmaps and rdma written into + * them. But receiving new congestion bitmaps should be a *rare* event, so + * hopefully we won't need to invest that complexity in making it more + * efficient. By copying we can share a simpler core with TCP which has to + * copy. 
+ */ +static void rds_ib_cong_recv(struct rds_connection *conn, + struct rds_ib_incoming *ibinc) +{ + struct rds_cong_map *map; + unsigned int map_off; + unsigned int map_page; + struct rds_page_frag *frag; + unsigned long frag_off; + unsigned long to_copy; + unsigned long copied; + __le64 uncongested = 0; + void *addr; + + /* catch completely corrupt packets */ + if (be32_to_cpu(ibinc->ii_inc.i_hdr.h_len) != RDS_CONG_MAP_BYTES) + return; + + map = conn->c_fcong; + map_page = 0; + map_off = 0; + + frag = list_entry(ibinc->ii_frags.next, struct rds_page_frag, f_item); + frag_off = 0; + + copied = 0; + + while (copied < RDS_CONG_MAP_BYTES) { + __le64 *src, *dst; + unsigned int k; + + to_copy = min(RDS_FRAG_SIZE - frag_off, PAGE_SIZE - map_off); + BUG_ON(to_copy & 7); /* Must be 64bit aligned. */ + + addr = kmap_atomic(sg_page(&frag->f_sg)); + + src = addr + frag->f_sg.offset + frag_off; + dst = (void *)map->m_page_addrs[map_page] + map_off; + for (k = 0; k < to_copy; k += 8) { + /* Record ports that became uncongested, ie + * bits that changed from 0 to 1. */ + uncongested |= ~(*src) & *dst; + *dst++ = *src++; + } + kunmap_atomic(addr); + + copied += to_copy; + + map_off += to_copy; + if (map_off == PAGE_SIZE) { + map_off = 0; + map_page++; + } + + frag_off += to_copy; + if (frag_off == RDS_FRAG_SIZE) { + frag = list_entry(frag->f_item.next, + struct rds_page_frag, f_item); + frag_off = 0; + } + } + + /* the congestion map is in little endian order */ + rds_cong_map_updated(map, le64_to_cpu(uncongested)); +} + +static void rds_ib_process_recv(struct rds_connection *conn, + struct rds_ib_recv_work *recv, u32 data_len, + struct rds_ib_ack_state *state) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct rds_ib_incoming *ibinc = ic->i_ibinc; + struct rds_header *ihdr, *hdr; + dma_addr_t dma_addr = ic->i_recv_hdrs_dma[recv - ic->i_recvs]; + + /* XXX shut down the connection if port 0,0 are seen? */ + + rdsdebug("ic %p ibinc %p recv %p byte len %u\n", ic, ibinc, recv, + data_len); + + if (data_len < sizeof(struct rds_header)) { + rds_ib_conn_error(conn, "incoming message " + "from %pI6c didn't include a " + "header, disconnecting and " + "reconnecting\n", + &conn->c_faddr); + return; + } + data_len -= sizeof(struct rds_header); + + ihdr = ic->i_recv_hdrs[recv - ic->i_recvs]; + + ib_dma_sync_single_for_cpu(ic->rds_ibdev->dev, dma_addr, + sizeof(*ihdr), DMA_FROM_DEVICE); + /* Validate the checksum. */ + if (!rds_message_verify_checksum(ihdr)) { + rds_ib_conn_error(conn, "incoming message " + "from %pI6c has corrupted header - " + "forcing a reconnect\n", + &conn->c_faddr); + rds_stats_inc(s_recv_drop_bad_checksum); + goto done; + } + + /* Process the ACK sequence which comes with every packet */ + state->ack_recv = be64_to_cpu(ihdr->h_ack); + state->ack_recv_valid = 1; + + /* Process the credits update if there was one */ + if (ihdr->h_credit) + rds_ib_send_add_credits(conn, ihdr->h_credit); + + if (ihdr->h_sport == 0 && ihdr->h_dport == 0 && data_len == 0) { + /* This is an ACK-only packet. The fact that it gets + * special treatment here is that historically, ACKs + * were rather special beasts. + */ + rds_ib_stats_inc(s_ib_ack_received); + + /* + * Usually the frags make their way on to incs and are then freed as + * the inc is freed. We don't go that route, so we have to drop the + * page ref ourselves. We can't just leave the page on the recv + * because that confuses the dma mapping of pages and each recv's use + * of a partial page. 
+ * + * FIXME: Fold this into the code path below. + */ + rds_ib_frag_free(ic, recv->r_frag); + recv->r_frag = NULL; + goto done; + } + + /* + * If we don't already have an inc on the connection then this + * fragment has a header and starts a message.. copy its header + * into the inc and save the inc so we can hang upcoming fragments + * off its list. + */ + if (!ibinc) { + ibinc = recv->r_ibinc; + recv->r_ibinc = NULL; + ic->i_ibinc = ibinc; + + hdr = &ibinc->ii_inc.i_hdr; + ibinc->ii_inc.i_rx_lat_trace[RDS_MSG_RX_HDR] = + local_clock(); + memcpy(hdr, ihdr, sizeof(*hdr)); + ic->i_recv_data_rem = be32_to_cpu(hdr->h_len); + ibinc->ii_inc.i_rx_lat_trace[RDS_MSG_RX_START] = + local_clock(); + + rdsdebug("ic %p ibinc %p rem %u flag 0x%x\n", ic, ibinc, + ic->i_recv_data_rem, hdr->h_flags); + } else { + hdr = &ibinc->ii_inc.i_hdr; + /* We can't just use memcmp here; fragments of a + * single message may carry different ACKs */ + if (hdr->h_sequence != ihdr->h_sequence || + hdr->h_len != ihdr->h_len || + hdr->h_sport != ihdr->h_sport || + hdr->h_dport != ihdr->h_dport) { + rds_ib_conn_error(conn, + "fragment header mismatch; forcing reconnect\n"); + goto done; + } + } + + list_add_tail(&recv->r_frag->f_item, &ibinc->ii_frags); + recv->r_frag = NULL; + + if (ic->i_recv_data_rem > RDS_FRAG_SIZE) + ic->i_recv_data_rem -= RDS_FRAG_SIZE; + else { + ic->i_recv_data_rem = 0; + ic->i_ibinc = NULL; + + if (ibinc->ii_inc.i_hdr.h_flags == RDS_FLAG_CONG_BITMAP) { + rds_ib_cong_recv(conn, ibinc); + } else { + rds_recv_incoming(conn, &conn->c_faddr, &conn->c_laddr, + &ibinc->ii_inc, GFP_ATOMIC); + state->ack_next = be64_to_cpu(hdr->h_sequence); + state->ack_next_valid = 1; + } + + /* Evaluate the ACK_REQUIRED flag *after* we received + * the complete frame, and after bumping the next_rx + * sequence. */ + if (hdr->h_flags & RDS_FLAG_ACK_REQUIRED) { + rds_stats_inc(s_recv_ack_required); + state->ack_required = 1; + } + + rds_inc_put(&ibinc->ii_inc); + } +done: + ib_dma_sync_single_for_device(ic->rds_ibdev->dev, dma_addr, + sizeof(*ihdr), DMA_FROM_DEVICE); +} + +void rds_ib_recv_cqe_handler(struct rds_ib_connection *ic, + struct ib_wc *wc, + struct rds_ib_ack_state *state) +{ + struct rds_connection *conn = ic->conn; + struct rds_ib_recv_work *recv; + + rdsdebug("wc wr_id 0x%llx status %u (%s) byte_len %u imm_data %u\n", + (unsigned long long)wc->wr_id, wc->status, + ib_wc_status_msg(wc->status), wc->byte_len, + be32_to_cpu(wc->ex.imm_data)); + + rds_ib_stats_inc(s_ib_rx_cq_event); + recv = &ic->i_recvs[rds_ib_ring_oldest(&ic->i_recv_ring)]; + ib_dma_unmap_sg(ic->i_cm_id->device, &recv->r_frag->f_sg, 1, + DMA_FROM_DEVICE); + + /* Also process recvs in connecting state because it is possible + * to get a recv completion _before_ the rdmacm ESTABLISHED + * event is processed. + */ + if (wc->status == IB_WC_SUCCESS) { + rds_ib_process_recv(conn, recv, wc->byte_len, state); + } else { + /* We expect errors as the qp is drained during shutdown */ + if (rds_conn_up(conn) || rds_conn_connecting(conn)) + rds_ib_conn_error(conn, "recv completion on <%pI6c,%pI6c, %d> had status %u (%s), vendor err 0x%x, disconnecting and reconnecting\n", + &conn->c_laddr, &conn->c_faddr, + conn->c_tos, wc->status, + ib_wc_status_msg(wc->status), + wc->vendor_err); + } + + /* rds_ib_process_recv() doesn't always consume the frag, and + * we might not have called it at all if the wc didn't indicate + * success. 
We already unmapped the frag's pages, though, and + * the following rds_ib_ring_free() call tells the refill path + * that it will not find an allocated frag here. Make sure we + * keep that promise by freeing a frag that's still on the ring. + */ + if (recv->r_frag) { + rds_ib_frag_free(ic, recv->r_frag); + recv->r_frag = NULL; + } + rds_ib_ring_free(&ic->i_recv_ring, 1); + + /* If we ever end up with a really empty receive ring, we're + * in deep trouble, as the sender will definitely see RNR + * timeouts. */ + if (rds_ib_ring_empty(&ic->i_recv_ring)) + rds_ib_stats_inc(s_ib_rx_ring_empty); + + if (rds_ib_ring_low(&ic->i_recv_ring)) { + rds_ib_recv_refill(conn, 0, GFP_NOWAIT | __GFP_NOWARN); + rds_ib_stats_inc(s_ib_rx_refill_from_cq); + } +} + +int rds_ib_recv_path(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + struct rds_ib_connection *ic = conn->c_transport_data; + + rdsdebug("conn %p\n", conn); + if (rds_conn_up(conn)) { + rds_ib_attempt_ack(ic); + rds_ib_recv_refill(conn, 0, GFP_KERNEL); + rds_ib_stats_inc(s_ib_rx_refill_from_thread); + } + + return 0; +} + +int rds_ib_recv_init(void) +{ + struct sysinfo si; + int ret = -ENOMEM; + + /* Default to 30% of all available RAM for recv memory */ + si_meminfo(&si); + rds_ib_sysctl_max_recv_allocation = si.totalram / 3 * PAGE_SIZE / RDS_FRAG_SIZE; + + rds_ib_incoming_slab = + kmem_cache_create_usercopy("rds_ib_incoming", + sizeof(struct rds_ib_incoming), + 0, SLAB_HWCACHE_ALIGN, + offsetof(struct rds_ib_incoming, + ii_inc.i_usercopy), + sizeof(struct rds_inc_usercopy), + NULL); + if (!rds_ib_incoming_slab) + goto out; + + rds_ib_frag_slab = kmem_cache_create("rds_ib_frag", + sizeof(struct rds_page_frag), + 0, SLAB_HWCACHE_ALIGN, NULL); + if (!rds_ib_frag_slab) { + kmem_cache_destroy(rds_ib_incoming_slab); + rds_ib_incoming_slab = NULL; + } else + ret = 0; +out: + return ret; +} + +void rds_ib_recv_exit(void) +{ + WARN_ON(atomic_read(&rds_ib_allocation)); + + kmem_cache_destroy(rds_ib_incoming_slab); + kmem_cache_destroy(rds_ib_frag_slab); +} diff --git a/net/rds/ib_ring.c b/net/rds/ib_ring.c new file mode 100644 index 000000000..006b2e441 --- /dev/null +++ b/net/rds/ib_ring.c @@ -0,0 +1,168 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include + +#include "rds.h" +#include "ib.h" + +/* + * Locking for IB rings. + * We assume that allocation is always protected by a mutex + * in the caller (this is a valid assumption for the current + * implementation). + * + * Freeing always happens in an interrupt, and hence only + * races with allocations, but not with other free()s. + * + * The interaction between allocation and freeing is that + * the alloc code has to determine the number of free entries. + * To this end, we maintain two counters; an allocation counter + * and a free counter. Both are allowed to run freely, and wrap + * around. + * The number of used entries is always (alloc_ctr - free_ctr) % NR. + * + * The current implementation makes free_ctr atomic. When the + * caller finds an allocation fails, it should set an "alloc fail" + * bit and retry the allocation. The "alloc fail" bit essentially tells + * the CQ completion handlers to wake it up after freeing some + * more entries. + */ + +/* + * This only happens on shutdown. + */ +DECLARE_WAIT_QUEUE_HEAD(rds_ib_ring_empty_wait); + +void rds_ib_ring_init(struct rds_ib_work_ring *ring, u32 nr) +{ + memset(ring, 0, sizeof(*ring)); + ring->w_nr = nr; + rdsdebug("ring %p nr %u\n", ring, ring->w_nr); +} + +static inline u32 __rds_ib_ring_used(struct rds_ib_work_ring *ring) +{ + u32 diff; + + /* This assumes that atomic_t has at least as many bits as u32 */ + diff = ring->w_alloc_ctr - (u32) atomic_read(&ring->w_free_ctr); + BUG_ON(diff > ring->w_nr); + + return diff; +} + +void rds_ib_ring_resize(struct rds_ib_work_ring *ring, u32 nr) +{ + /* We only ever get called from the connection setup code, + * prior to creating the QP. */ + BUG_ON(__rds_ib_ring_used(ring)); + ring->w_nr = nr; +} + +static int __rds_ib_ring_empty(struct rds_ib_work_ring *ring) +{ + return __rds_ib_ring_used(ring) == 0; +} + +u32 rds_ib_ring_alloc(struct rds_ib_work_ring *ring, u32 val, u32 *pos) +{ + u32 ret = 0, avail; + + avail = ring->w_nr - __rds_ib_ring_used(ring); + + rdsdebug("ring %p val %u next %u free %u\n", ring, val, + ring->w_alloc_ptr, avail); + + if (val && avail) { + ret = min(val, avail); + *pos = ring->w_alloc_ptr; + + ring->w_alloc_ptr = (ring->w_alloc_ptr + ret) % ring->w_nr; + ring->w_alloc_ctr += ret; + } + + return ret; +} + +void rds_ib_ring_free(struct rds_ib_work_ring *ring, u32 val) +{ + ring->w_free_ptr = (ring->w_free_ptr + val) % ring->w_nr; + atomic_add(val, &ring->w_free_ctr); + + if (__rds_ib_ring_empty(ring) && + waitqueue_active(&rds_ib_ring_empty_wait)) + wake_up(&rds_ib_ring_empty_wait); +} + +void rds_ib_ring_unalloc(struct rds_ib_work_ring *ring, u32 val) +{ + ring->w_alloc_ptr = (ring->w_alloc_ptr - val) % ring->w_nr; + ring->w_alloc_ctr -= val; +} + +int rds_ib_ring_empty(struct rds_ib_work_ring *ring) +{ + return __rds_ib_ring_empty(ring); +} + +int rds_ib_ring_low(struct rds_ib_work_ring *ring) +{ + return __rds_ib_ring_used(ring) <= (ring->w_nr >> 1); +} + +/* + * returns the oldest allocated ring entry. This will be the next one + * freed. This can't be called if there are none allocated. + */ +u32 rds_ib_ring_oldest(struct rds_ib_work_ring *ring) +{ + return ring->w_free_ptr; +} + +/* + * returns the number of completed work requests. 
+ */ + +u32 rds_ib_ring_completed(struct rds_ib_work_ring *ring, u32 wr_id, u32 oldest) +{ + u32 ret; + + if (oldest <= (unsigned long long)wr_id) + ret = (unsigned long long)wr_id - oldest + 1; + else + ret = ring->w_nr - oldest + (unsigned long long)wr_id + 1; + + rdsdebug("ring %p ret %u wr_id %u oldest %u\n", ring, ret, + wr_id, oldest); + return ret; +} diff --git a/net/rds/ib_send.c b/net/rds/ib_send.c new file mode 100644 index 000000000..4190b90ff --- /dev/null +++ b/net/rds/ib_send.c @@ -0,0 +1,1017 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "ib.h" +#include "ib_mr.h" + +/* + * Convert IB-specific error message to RDS error message and call core + * completion handler. + */ +static void rds_ib_send_complete(struct rds_message *rm, + int wc_status, + void (*complete)(struct rds_message *rm, int status)) +{ + int notify_status; + + switch (wc_status) { + case IB_WC_WR_FLUSH_ERR: + return; + + case IB_WC_SUCCESS: + notify_status = RDS_RDMA_SUCCESS; + break; + + case IB_WC_REM_ACCESS_ERR: + notify_status = RDS_RDMA_REMOTE_ERROR; + break; + + default: + notify_status = RDS_RDMA_OTHER_ERROR; + break; + } + complete(rm, notify_status); +} + +static void rds_ib_send_unmap_data(struct rds_ib_connection *ic, + struct rm_data_op *op, + int wc_status) +{ + if (op->op_nents) + ib_dma_unmap_sg(ic->i_cm_id->device, + op->op_sg, op->op_nents, + DMA_TO_DEVICE); +} + +static void rds_ib_send_unmap_rdma(struct rds_ib_connection *ic, + struct rm_rdma_op *op, + int wc_status) +{ + if (op->op_mapped) { + ib_dma_unmap_sg(ic->i_cm_id->device, + op->op_sg, op->op_nents, + op->op_write ? DMA_TO_DEVICE : DMA_FROM_DEVICE); + op->op_mapped = 0; + } + + /* If the user asked for a completion notification on this + * message, we can implement three different semantics: + * 1. Notify when we received the ACK on the RDS message + * that was queued with the RDMA. This provides reliable + * notification of RDMA status at the expense of a one-way + * packet delay. + * 2. Notify when the IB stack gives us the completion event for + * the RDMA operation. 
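rds_ib_ring_completed() above derives how many work requests finished from the oldest allocated slot and the completed wr_id, handling the ring wrap. A small standalone model of the same arithmetic, with hypothetical values:

#include <stdint.h>
#include <stdio.h>

/* same two-case formula as rds_ib_ring_completed() */
static uint32_t ring_completed(uint32_t w_nr, uint32_t wr_id, uint32_t oldest)
{
	if (oldest <= wr_id)
		return wr_id - oldest + 1;
	return w_nr - oldest + wr_id + 1;
}

int main(void)
{
	/* 16-entry ring, oldest outstanding slot 14, completion for slot 2 */
	printf("completed = %u\n", ring_completed(16, 2, 14)); /* 14,15,0,1,2 -> 5 */
	return 0;
}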
+ * 3. Notify when the IB stack gives us the completion event for + * the accompanying RDS messages. + * Here, we implement approach #3. To implement approach #2, + * we would need to take an event for the rdma WR. To implement #1, + * don't call rds_rdma_send_complete at all, and fall back to the notify + * handling in the ACK processing code. + * + * Note: There's no need to explicitly sync any RDMA buffers using + * ib_dma_sync_sg_for_cpu - the completion for the RDMA + * operation itself unmapped the RDMA buffers, which takes care + * of synching. + */ + rds_ib_send_complete(container_of(op, struct rds_message, rdma), + wc_status, rds_rdma_send_complete); + + if (op->op_write) + rds_stats_add(s_send_rdma_bytes, op->op_bytes); + else + rds_stats_add(s_recv_rdma_bytes, op->op_bytes); +} + +static void rds_ib_send_unmap_atomic(struct rds_ib_connection *ic, + struct rm_atomic_op *op, + int wc_status) +{ + /* unmap atomic recvbuf */ + if (op->op_mapped) { + ib_dma_unmap_sg(ic->i_cm_id->device, op->op_sg, 1, + DMA_FROM_DEVICE); + op->op_mapped = 0; + } + + rds_ib_send_complete(container_of(op, struct rds_message, atomic), + wc_status, rds_atomic_send_complete); + + if (op->op_type == RDS_ATOMIC_TYPE_CSWP) + rds_ib_stats_inc(s_ib_atomic_cswp); + else + rds_ib_stats_inc(s_ib_atomic_fadd); +} + +/* + * Unmap the resources associated with a struct send_work. + * + * Returns the rm for no good reason other than it is unobtainable + * other than by switching on wr.opcode, currently, and the caller, + * the event handler, needs it. + */ +static struct rds_message *rds_ib_send_unmap_op(struct rds_ib_connection *ic, + struct rds_ib_send_work *send, + int wc_status) +{ + struct rds_message *rm = NULL; + + /* In the error case, wc.opcode sometimes contains garbage */ + switch (send->s_wr.opcode) { + case IB_WR_SEND: + if (send->s_op) { + rm = container_of(send->s_op, struct rds_message, data); + rds_ib_send_unmap_data(ic, send->s_op, wc_status); + } + break; + case IB_WR_RDMA_WRITE: + case IB_WR_RDMA_READ: + if (send->s_op) { + rm = container_of(send->s_op, struct rds_message, rdma); + rds_ib_send_unmap_rdma(ic, send->s_op, wc_status); + } + break; + case IB_WR_ATOMIC_FETCH_AND_ADD: + case IB_WR_ATOMIC_CMP_AND_SWP: + if (send->s_op) { + rm = container_of(send->s_op, struct rds_message, atomic); + rds_ib_send_unmap_atomic(ic, send->s_op, wc_status); + } + break; + default: + printk_ratelimited(KERN_NOTICE + "RDS/IB: %s: unexpected opcode 0x%x in WR!\n", + __func__, send->s_wr.opcode); + break; + } + + send->s_wr.opcode = 0xdead; + + return rm; +} + +void rds_ib_send_init_ring(struct rds_ib_connection *ic) +{ + struct rds_ib_send_work *send; + u32 i; + + for (i = 0, send = ic->i_sends; i < ic->i_send_ring.w_nr; i++, send++) { + struct ib_sge *sge; + + send->s_op = NULL; + + send->s_wr.wr_id = i; + send->s_wr.sg_list = send->s_sge; + send->s_wr.ex.imm_data = 0; + + sge = &send->s_sge[0]; + sge->addr = ic->i_send_hdrs_dma[i]; + + sge->length = sizeof(struct rds_header); + sge->lkey = ic->i_pd->local_dma_lkey; + + send->s_sge[1].lkey = ic->i_pd->local_dma_lkey; + } +} + +void rds_ib_send_clear_ring(struct rds_ib_connection *ic) +{ + struct rds_ib_send_work *send; + u32 i; + + for (i = 0, send = ic->i_sends; i < ic->i_send_ring.w_nr; i++, send++) { + if (send->s_op && send->s_wr.opcode != 0xdead) + rds_ib_send_unmap_op(ic, send, IB_WC_WR_FLUSH_ERR); + } +} + +/* + * The only fast path caller always has a non-zero nr, so we don't + * bother testing nr before performing the atomic sub. 
+ */ +static void rds_ib_sub_signaled(struct rds_ib_connection *ic, int nr) +{ + if ((atomic_sub_return(nr, &ic->i_signaled_sends) == 0) && + waitqueue_active(&rds_ib_ring_empty_wait)) + wake_up(&rds_ib_ring_empty_wait); + BUG_ON(atomic_read(&ic->i_signaled_sends) < 0); +} + +/* + * The _oldest/_free ring operations here race cleanly with the alloc/unalloc + * operations performed in the send path. As the sender allocs and potentially + * unallocs the next free entry in the ring it doesn't alter which is + * the next to be freed, which is what this is concerned with. + */ +void rds_ib_send_cqe_handler(struct rds_ib_connection *ic, struct ib_wc *wc) +{ + struct rds_message *rm = NULL; + struct rds_connection *conn = ic->conn; + struct rds_ib_send_work *send; + u32 completed; + u32 oldest; + u32 i = 0; + int nr_sig = 0; + + + rdsdebug("wc wr_id 0x%llx status %u (%s) byte_len %u imm_data %u\n", + (unsigned long long)wc->wr_id, wc->status, + ib_wc_status_msg(wc->status), wc->byte_len, + be32_to_cpu(wc->ex.imm_data)); + rds_ib_stats_inc(s_ib_tx_cq_event); + + if (wc->wr_id == RDS_IB_ACK_WR_ID) { + if (time_after(jiffies, ic->i_ack_queued + HZ / 2)) + rds_ib_stats_inc(s_ib_tx_stalled); + rds_ib_ack_send_complete(ic); + return; + } + + oldest = rds_ib_ring_oldest(&ic->i_send_ring); + + completed = rds_ib_ring_completed(&ic->i_send_ring, wc->wr_id, oldest); + + for (i = 0; i < completed; i++) { + send = &ic->i_sends[oldest]; + if (send->s_wr.send_flags & IB_SEND_SIGNALED) + nr_sig++; + + rm = rds_ib_send_unmap_op(ic, send, wc->status); + + if (time_after(jiffies, send->s_queued + HZ / 2)) + rds_ib_stats_inc(s_ib_tx_stalled); + + if (send->s_op) { + if (send->s_op == rm->m_final_op) { + /* If anyone waited for this message to get + * flushed out, wake them up now + */ + rds_message_unmapped(rm); + } + rds_message_put(rm); + send->s_op = NULL; + } + + oldest = (oldest + 1) % ic->i_send_ring.w_nr; + } + + rds_ib_ring_free(&ic->i_send_ring, completed); + rds_ib_sub_signaled(ic, nr_sig); + + if (test_and_clear_bit(RDS_LL_SEND_FULL, &conn->c_flags) || + test_bit(0, &conn->c_map_queued)) + queue_delayed_work(rds_wq, &conn->c_send_w, 0); + + /* We expect errors as the qp is drained during shutdown */ + if (wc->status != IB_WC_SUCCESS && rds_conn_up(conn)) { + rds_ib_conn_error(conn, "send completion on <%pI6c,%pI6c,%d> had status %u (%s), vendor err 0x%x, disconnecting and reconnecting\n", + &conn->c_laddr, &conn->c_faddr, + conn->c_tos, wc->status, + ib_wc_status_msg(wc->status), wc->vendor_err); + } +} + +/* + * This is the main function for allocating credits when sending + * messages. + * + * Conceptually, we have two counters: + * - send credits: this tells us how many WRs we're allowed + * to submit without overruning the receiver's queue. For + * each SEND WR we post, we decrement this by one. + * + * - posted credits: this tells us how many WRs we recently + * posted to the receive queue. This value is transferred + * to the peer as a "credit update" in a RDS header field. + * Every time we transmit credits to the peer, we subtract + * the amount of transferred credits from this counter. + * + * It is essential that we avoid situations where both sides have + * exhausted their send credits, and are unable to send new credits + * to the peer. We achieve this by requiring that we send at least + * one credit update to the peer before exhausting our credits. 
+ * When new credits arrive, we subtract one credit that is withheld + * until we've posted new buffers and are ready to transmit these + * credits (see rds_ib_send_add_credits below). + * + * The RDS send code is essentially single-threaded; rds_send_xmit + * sets RDS_IN_XMIT to ensure exclusive access to the send ring. + * However, the ACK sending code is independent and can race with + * message SENDs. + * + * In the send path, we need to update the counters for send credits + * and the counter of posted buffers atomically - when we use the + * last available credit, we cannot allow another thread to race us + * and grab the posted credits counter. Hence, we have to use a + * spinlock to protect the credit counter, or use atomics. + * + * Spinlocks shared between the send and the receive path are bad, + * because they create unnecessary delays. An early implementation + * using a spinlock showed a 5% degradation in throughput at some + * loads. + * + * This implementation avoids spinlocks completely, putting both + * counters into a single atomic, and updating that atomic using + * atomic_add (in the receive path, when receiving fresh credits), + * and using atomic_cmpxchg when updating the two counters. + */ +int rds_ib_send_grab_credits(struct rds_ib_connection *ic, + u32 wanted, u32 *adv_credits, int need_posted, int max_posted) +{ + unsigned int avail, posted, got = 0, advertise; + long oldval, newval; + + *adv_credits = 0; + if (!ic->i_flowctl) + return wanted; + +try_again: + advertise = 0; + oldval = newval = atomic_read(&ic->i_credits); + posted = IB_GET_POST_CREDITS(oldval); + avail = IB_GET_SEND_CREDITS(oldval); + + rdsdebug("wanted=%u credits=%u posted=%u\n", + wanted, avail, posted); + + /* The last credit must be used to send a credit update. */ + if (avail && !posted) + avail--; + + if (avail < wanted) { + struct rds_connection *conn = ic->i_cm_id->context; + + /* Oops, there aren't that many credits left! */ + set_bit(RDS_LL_SEND_FULL, &conn->c_flags); + got = avail; + } else { + /* Sometimes you get what you want, lalala. */ + got = wanted; + } + newval -= IB_SET_SEND_CREDITS(got); + + /* + * If need_posted is non-zero, then the caller wants + * the posted regardless of whether any send credits are + * available. + */ + if (posted && (got || need_posted)) { + advertise = min_t(unsigned int, posted, max_posted); + newval -= IB_SET_POST_CREDITS(advertise); + } + + /* Finally bill everything */ + if (atomic_cmpxchg(&ic->i_credits, oldval, newval) != oldval) + goto try_again; + + *adv_credits = advertise; + return got; +} + +void rds_ib_send_add_credits(struct rds_connection *conn, unsigned int credits) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + + if (credits == 0) + return; + + rdsdebug("credits=%u current=%u%s\n", + credits, + IB_GET_SEND_CREDITS(atomic_read(&ic->i_credits)), + test_bit(RDS_LL_SEND_FULL, &conn->c_flags) ? ", ll_send_full" : ""); + + atomic_add(IB_SET_SEND_CREDITS(credits), &ic->i_credits); + if (test_and_clear_bit(RDS_LL_SEND_FULL, &conn->c_flags)) + queue_delayed_work(rds_wq, &conn->c_send_w, 0); + + WARN_ON(IB_GET_SEND_CREDITS(credits) >= 16384); + + rds_ib_stats_inc(s_ib_rx_credit_updates); +} + +void rds_ib_advertise_credits(struct rds_connection *conn, unsigned int posted) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + + if (posted == 0) + return; + + atomic_add(IB_SET_POST_CREDITS(posted), &ic->i_credits); + + /* Decide whether to send an update to the peer now. 
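rds_ib_send_grab_credits() above keeps the send-credit and posted-credit counters in a single atomic word and retries with atomic_cmpxchg when it races. A minimal userspace sketch of that scheme, assuming the two counters occupy the low and high 16 bits (the real IB_GET_*/IB_SET_* macros live in ib.h, which is not part of this hunk):

#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

#define GET_SEND(v)	((v) & 0xffff)
#define GET_POST(v)	(((v) >> 16) & 0xffff)
#define SET_SEND(v)	((v) & 0xffff)
#define SET_POST(v)	(((v) & 0xffff) << 16)

static _Atomic uint32_t credits;

/* grab up to "wanted" send credits, retrying the CAS on races */
static uint32_t grab_credits(uint32_t wanted)
{
	uint32_t oldval, newval, avail, got;

	do {
		oldval = atomic_load(&credits);
		avail = GET_SEND(oldval);
		/* withhold the last credit until posted credits are pending */
		if (avail && !GET_POST(oldval))
			avail--;
		got = wanted < avail ? wanted : avail;
		newval = oldval - SET_SEND(got);
	} while (!atomic_compare_exchange_weak(&credits, &oldval, newval));

	return got;
}

int main(void)
{
	atomic_store(&credits, SET_SEND(3) | SET_POST(1));
	printf("got %u of 5 wanted credits\n", grab_credits(5));	/* got 3 */
	return 0;
}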
+ * If we would send a credit update for every single buffer we + * post, we would end up with an ACK storm (ACK arrives, + * consumes buffer, we refill the ring, send ACK to remote + * advertising the newly posted buffer... ad inf) + * + * Performance pretty much depends on how often we send + * credit updates - too frequent updates mean lots of ACKs. + * Too infrequent updates, and the peer will run out of + * credits and has to throttle. + * For the time being, 16 seems to be a good compromise. + */ + if (IB_GET_POST_CREDITS(atomic_read(&ic->i_credits)) >= 16) + set_bit(IB_ACK_REQUESTED, &ic->i_ack_flags); +} + +static inline int rds_ib_set_wr_signal_state(struct rds_ib_connection *ic, + struct rds_ib_send_work *send, + bool notify) +{ + /* + * We want to delay signaling completions just enough to get + * the batching benefits but not so much that we create dead time + * on the wire. + */ + if (ic->i_unsignaled_wrs-- == 0 || notify) { + ic->i_unsignaled_wrs = rds_ib_sysctl_max_unsig_wrs; + send->s_wr.send_flags |= IB_SEND_SIGNALED; + return 1; + } + return 0; +} + +/* + * This can be called multiple times for a given message. The first time + * we see a message we map its scatterlist into the IB device so that + * we can provide that mapped address to the IB scatter gather entries + * in the IB work requests. We translate the scatterlist into a series + * of work requests that fragment the message. These work requests complete + * in order so we pass ownership of the message to the completion handler + * once we send the final fragment. + * + * The RDS core uses the c_send_lock to only enter this function once + * per connection. This makes sure that the tx ring alloc/unalloc pairs + * don't get out of sync and confuse the ring. + */ +int rds_ib_xmit(struct rds_connection *conn, struct rds_message *rm, + unsigned int hdr_off, unsigned int sg, unsigned int off) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct ib_device *dev = ic->i_cm_id->device; + struct rds_ib_send_work *send = NULL; + struct rds_ib_send_work *first; + struct rds_ib_send_work *prev; + const struct ib_send_wr *failed_wr; + struct scatterlist *scat; + u32 pos; + u32 i; + u32 work_alloc; + u32 credit_alloc = 0; + u32 posted; + u32 adv_credits = 0; + int send_flags = 0; + int bytes_sent = 0; + int ret; + int flow_controlled = 0; + int nr_sig = 0; + + BUG_ON(off % RDS_FRAG_SIZE); + BUG_ON(hdr_off != 0 && hdr_off != sizeof(struct rds_header)); + + /* Do not send cong updates to IB loopback */ + if (conn->c_loopback + && rm->m_inc.i_hdr.h_flags & RDS_FLAG_CONG_BITMAP) { + rds_cong_map_updated(conn->c_fcong, ~(u64) 0); + scat = &rm->data.op_sg[sg]; + ret = max_t(int, RDS_CONG_MAP_BYTES, scat->length); + return sizeof(struct rds_header) + ret; + } + + /* FIXME we may overallocate here */ + if (be32_to_cpu(rm->m_inc.i_hdr.h_len) == 0) + i = 1; + else + i = DIV_ROUND_UP(be32_to_cpu(rm->m_inc.i_hdr.h_len), RDS_FRAG_SIZE); + + work_alloc = rds_ib_ring_alloc(&ic->i_send_ring, i, &pos); + if (work_alloc == 0) { + set_bit(RDS_LL_SEND_FULL, &conn->c_flags); + rds_ib_stats_inc(s_ib_tx_ring_full); + ret = -ENOMEM; + goto out; + } + + if (ic->i_flowctl) { + credit_alloc = rds_ib_send_grab_credits(ic, work_alloc, &posted, 0, RDS_MAX_ADV_CREDIT); + adv_credits += posted; + if (credit_alloc < work_alloc) { + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc - credit_alloc); + work_alloc = credit_alloc; + flow_controlled = 1; + } + if (work_alloc == 0) { + set_bit(RDS_LL_SEND_FULL, &conn->c_flags); + 
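rds_ib_advertise_credits() above only requests an ACK carrying a credit update once at least 16 new receive buffers have been posted, to avoid an ACK storm. A toy illustration of how that batching bounds the number of updates (the threshold of 16 is taken from the code above; everything else is hypothetical):

#include <stdio.h>

int main(void)
{
	unsigned int posted = 0, updates = 0;

	/* post 100 receive buffers one at a time, advertising in batches */
	for (int i = 0; i < 100; i++) {
		posted++;
		if (posted >= 16) {	/* batching threshold used above */
			updates++;	/* would set IB_ACK_REQUESTED */
			posted = 0;
		}
	}
	printf("100 refills -> %u credit updates\n", updates);	/* prints 6 */
	return 0;
}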
rds_ib_stats_inc(s_ib_tx_throttle); + ret = -ENOMEM; + goto out; + } + } + + /* map the message the first time we see it */ + if (!ic->i_data_op) { + if (rm->data.op_nents) { + rm->data.op_count = ib_dma_map_sg(dev, + rm->data.op_sg, + rm->data.op_nents, + DMA_TO_DEVICE); + rdsdebug("ic %p mapping rm %p: %d\n", ic, rm, rm->data.op_count); + if (rm->data.op_count == 0) { + rds_ib_stats_inc(s_ib_tx_sg_mapping_failure); + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + ret = -ENOMEM; /* XXX ? */ + goto out; + } + } else { + rm->data.op_count = 0; + } + + rds_message_addref(rm); + rm->data.op_dmasg = 0; + rm->data.op_dmaoff = 0; + ic->i_data_op = &rm->data; + + /* Finalize the header */ + if (test_bit(RDS_MSG_ACK_REQUIRED, &rm->m_flags)) + rm->m_inc.i_hdr.h_flags |= RDS_FLAG_ACK_REQUIRED; + if (test_bit(RDS_MSG_RETRANSMITTED, &rm->m_flags)) + rm->m_inc.i_hdr.h_flags |= RDS_FLAG_RETRANSMITTED; + + /* If it has a RDMA op, tell the peer we did it. This is + * used by the peer to release use-once RDMA MRs. */ + if (rm->rdma.op_active) { + struct rds_ext_header_rdma ext_hdr; + + ext_hdr.h_rdma_rkey = cpu_to_be32(rm->rdma.op_rkey); + rds_message_add_extension(&rm->m_inc.i_hdr, + RDS_EXTHDR_RDMA, &ext_hdr, sizeof(ext_hdr)); + } + if (rm->m_rdma_cookie) { + rds_message_add_rdma_dest_extension(&rm->m_inc.i_hdr, + rds_rdma_cookie_key(rm->m_rdma_cookie), + rds_rdma_cookie_offset(rm->m_rdma_cookie)); + } + + /* Note - rds_ib_piggyb_ack clears the ACK_REQUIRED bit, so + * we should not do this unless we have a chance of at least + * sticking the header into the send ring. Which is why we + * should call rds_ib_ring_alloc first. */ + rm->m_inc.i_hdr.h_ack = cpu_to_be64(rds_ib_piggyb_ack(ic)); + rds_message_make_checksum(&rm->m_inc.i_hdr); + + /* + * Update adv_credits since we reset the ACK_REQUIRED bit. + */ + if (ic->i_flowctl) { + rds_ib_send_grab_credits(ic, 0, &posted, 1, RDS_MAX_ADV_CREDIT - adv_credits); + adv_credits += posted; + BUG_ON(adv_credits > 255); + } + } + + /* Sometimes you want to put a fence between an RDMA + * READ and the following SEND. + * We could either do this all the time + * or when requested by the user. Right now, we let + * the application choose. + */ + if (rm->rdma.op_active && rm->rdma.op_fence) + send_flags = IB_SEND_FENCE; + + /* Each frag gets a header. 
Msgs may be 0 bytes */ + send = &ic->i_sends[pos]; + first = send; + prev = NULL; + scat = &ic->i_data_op->op_sg[rm->data.op_dmasg]; + i = 0; + do { + unsigned int len = 0; + + /* Set up the header */ + send->s_wr.send_flags = send_flags; + send->s_wr.opcode = IB_WR_SEND; + send->s_wr.num_sge = 1; + send->s_wr.next = NULL; + send->s_queued = jiffies; + send->s_op = NULL; + + send->s_sge[0].addr = ic->i_send_hdrs_dma[pos]; + + send->s_sge[0].length = sizeof(struct rds_header); + send->s_sge[0].lkey = ic->i_pd->local_dma_lkey; + + ib_dma_sync_single_for_cpu(ic->rds_ibdev->dev, + ic->i_send_hdrs_dma[pos], + sizeof(struct rds_header), + DMA_TO_DEVICE); + memcpy(ic->i_send_hdrs[pos], &rm->m_inc.i_hdr, + sizeof(struct rds_header)); + + + /* Set up the data, if present */ + if (i < work_alloc + && scat != &rm->data.op_sg[rm->data.op_count]) { + len = min(RDS_FRAG_SIZE, + sg_dma_len(scat) - rm->data.op_dmaoff); + send->s_wr.num_sge = 2; + + send->s_sge[1].addr = sg_dma_address(scat); + send->s_sge[1].addr += rm->data.op_dmaoff; + send->s_sge[1].length = len; + send->s_sge[1].lkey = ic->i_pd->local_dma_lkey; + + bytes_sent += len; + rm->data.op_dmaoff += len; + if (rm->data.op_dmaoff == sg_dma_len(scat)) { + scat++; + rm->data.op_dmasg++; + rm->data.op_dmaoff = 0; + } + } + + rds_ib_set_wr_signal_state(ic, send, false); + + /* + * Always signal the last one if we're stopping due to flow control. + */ + if (ic->i_flowctl && flow_controlled && i == (work_alloc - 1)) { + rds_ib_set_wr_signal_state(ic, send, true); + send->s_wr.send_flags |= IB_SEND_SOLICITED; + } + + if (send->s_wr.send_flags & IB_SEND_SIGNALED) + nr_sig++; + + rdsdebug("send %p wr %p num_sge %u next %p\n", send, + &send->s_wr, send->s_wr.num_sge, send->s_wr.next); + + if (ic->i_flowctl && adv_credits) { + struct rds_header *hdr = ic->i_send_hdrs[pos]; + + /* add credit and redo the header checksum */ + hdr->h_credit = adv_credits; + rds_message_make_checksum(hdr); + adv_credits = 0; + rds_ib_stats_inc(s_ib_tx_credit_updates); + } + ib_dma_sync_single_for_device(ic->rds_ibdev->dev, + ic->i_send_hdrs_dma[pos], + sizeof(struct rds_header), + DMA_TO_DEVICE); + + if (prev) + prev->s_wr.next = &send->s_wr; + prev = send; + + pos = (pos + 1) % ic->i_send_ring.w_nr; + send = &ic->i_sends[pos]; + i++; + + } while (i < work_alloc + && scat != &rm->data.op_sg[rm->data.op_count]); + + /* Account the RDS header in the number of bytes we sent, but just once. + * The caller has no concept of fragmentation. */ + if (hdr_off == 0) + bytes_sent += sizeof(struct rds_header); + + /* if we finished the message then send completion owns it */ + if (scat == &rm->data.op_sg[rm->data.op_count]) { + prev->s_op = ic->i_data_op; + prev->s_wr.send_flags |= IB_SEND_SOLICITED; + if (!(prev->s_wr.send_flags & IB_SEND_SIGNALED)) + nr_sig += rds_ib_set_wr_signal_state(ic, prev, true); + ic->i_data_op = NULL; + } + + /* Put back wrs & credits we didn't use */ + if (i < work_alloc) { + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc - i); + work_alloc = i; + } + if (ic->i_flowctl && i < credit_alloc) + rds_ib_send_add_credits(conn, credit_alloc - i); + + if (nr_sig) + atomic_add(nr_sig, &ic->i_signaled_sends); + + /* XXX need to worry about failed_wr and partial sends. 
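The send loop above emits one work request per RDS_FRAG_SIZE fragment, each carrying a header SGE plus at most one data SGE. A small standalone sketch of that fragmentation, assuming a 4 KiB fragment size and a hypothetical 10000-byte payload:

#include <stdint.h>
#include <stdio.h>

#define FRAG_SIZE 4096			/* assumed RDS_FRAG_SIZE */
#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

int main(void)
{
	uint32_t h_len = 10000;		/* hypothetical payload length */
	uint32_t wrs = h_len ? DIV_ROUND_UP(h_len, FRAG_SIZE) : 1;

	printf("%u-byte message -> %u send WRs\n", h_len, wrs);
	for (uint32_t i = 0; i < wrs; i++) {
		uint32_t len = h_len - i * FRAG_SIZE;

		if (len > FRAG_SIZE)
			len = FRAG_SIZE;
		/* every WR carries a header SGE; non-empty frags add one data SGE */
		printf("  wr %u: header + %u data bytes\n", i, len);
	}
	return 0;
}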
*/ + failed_wr = &first->s_wr; + ret = ib_post_send(ic->i_cm_id->qp, &first->s_wr, &failed_wr); + rdsdebug("ic %p first %p (wr %p) ret %d wr %p\n", ic, + first, &first->s_wr, ret, failed_wr); + BUG_ON(failed_wr != &first->s_wr); + if (ret) { + printk(KERN_WARNING "RDS/IB: ib_post_send to %pI6c " + "returned %d\n", &conn->c_faddr, ret); + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + rds_ib_sub_signaled(ic, nr_sig); + if (prev->s_op) { + ic->i_data_op = prev->s_op; + prev->s_op = NULL; + } + + rds_ib_conn_error(ic->conn, "ib_post_send failed\n"); + goto out; + } + + ret = bytes_sent; +out: + BUG_ON(adv_credits); + return ret; +} + +/* + * Issue atomic operation. + * A simplified version of the rdma case, we always map 1 SG, and + * only 8 bytes, for the return value from the atomic operation. + */ +int rds_ib_xmit_atomic(struct rds_connection *conn, struct rm_atomic_op *op) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct rds_ib_send_work *send = NULL; + const struct ib_send_wr *failed_wr; + u32 pos; + u32 work_alloc; + int ret; + int nr_sig = 0; + + work_alloc = rds_ib_ring_alloc(&ic->i_send_ring, 1, &pos); + if (work_alloc != 1) { + rds_ib_stats_inc(s_ib_tx_ring_full); + ret = -ENOMEM; + goto out; + } + + /* address of send request in ring */ + send = &ic->i_sends[pos]; + send->s_queued = jiffies; + + if (op->op_type == RDS_ATOMIC_TYPE_CSWP) { + send->s_atomic_wr.wr.opcode = IB_WR_MASKED_ATOMIC_CMP_AND_SWP; + send->s_atomic_wr.compare_add = op->op_m_cswp.compare; + send->s_atomic_wr.swap = op->op_m_cswp.swap; + send->s_atomic_wr.compare_add_mask = op->op_m_cswp.compare_mask; + send->s_atomic_wr.swap_mask = op->op_m_cswp.swap_mask; + } else { /* FADD */ + send->s_atomic_wr.wr.opcode = IB_WR_MASKED_ATOMIC_FETCH_AND_ADD; + send->s_atomic_wr.compare_add = op->op_m_fadd.add; + send->s_atomic_wr.swap = 0; + send->s_atomic_wr.compare_add_mask = op->op_m_fadd.nocarry_mask; + send->s_atomic_wr.swap_mask = 0; + } + send->s_wr.send_flags = 0; + nr_sig = rds_ib_set_wr_signal_state(ic, send, op->op_notify); + send->s_atomic_wr.wr.num_sge = 1; + send->s_atomic_wr.wr.next = NULL; + send->s_atomic_wr.remote_addr = op->op_remote_addr; + send->s_atomic_wr.rkey = op->op_rkey; + send->s_op = op; + rds_message_addref(container_of(send->s_op, struct rds_message, atomic)); + + /* map 8 byte retval buffer to the device */ + ret = ib_dma_map_sg(ic->i_cm_id->device, op->op_sg, 1, DMA_FROM_DEVICE); + rdsdebug("ic %p mapping atomic op %p. mapped %d pg\n", ic, op, ret); + if (ret != 1) { + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + rds_ib_stats_inc(s_ib_tx_sg_mapping_failure); + ret = -ENOMEM; /* XXX ? 
*/ + goto out; + } + + /* Convert our struct scatterlist to struct ib_sge */ + send->s_sge[0].addr = sg_dma_address(op->op_sg); + send->s_sge[0].length = sg_dma_len(op->op_sg); + send->s_sge[0].lkey = ic->i_pd->local_dma_lkey; + + rdsdebug("rva %Lx rpa %Lx len %u\n", op->op_remote_addr, + send->s_sge[0].addr, send->s_sge[0].length); + + if (nr_sig) + atomic_add(nr_sig, &ic->i_signaled_sends); + + failed_wr = &send->s_atomic_wr.wr; + ret = ib_post_send(ic->i_cm_id->qp, &send->s_atomic_wr.wr, &failed_wr); + rdsdebug("ic %p send %p (wr %p) ret %d wr %p\n", ic, + send, &send->s_atomic_wr, ret, failed_wr); + BUG_ON(failed_wr != &send->s_atomic_wr.wr); + if (ret) { + printk(KERN_WARNING "RDS/IB: atomic ib_post_send to %pI6c " + "returned %d\n", &conn->c_faddr, ret); + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + rds_ib_sub_signaled(ic, nr_sig); + goto out; + } + + if (unlikely(failed_wr != &send->s_atomic_wr.wr)) { + printk(KERN_WARNING "RDS/IB: atomic ib_post_send() rc=%d, but failed_wqe updated!\n", ret); + BUG_ON(failed_wr != &send->s_atomic_wr.wr); + } + +out: + return ret; +} + +int rds_ib_xmit_rdma(struct rds_connection *conn, struct rm_rdma_op *op) +{ + struct rds_ib_connection *ic = conn->c_transport_data; + struct rds_ib_send_work *send = NULL; + struct rds_ib_send_work *first; + struct rds_ib_send_work *prev; + const struct ib_send_wr *failed_wr; + struct scatterlist *scat; + unsigned long len; + u64 remote_addr = op->op_remote_addr; + u32 max_sge = ic->rds_ibdev->max_sge; + u32 pos; + u32 work_alloc; + u32 i; + u32 j; + int sent; + int ret; + int num_sge; + int nr_sig = 0; + u64 odp_addr = op->op_odp_addr; + u32 odp_lkey = 0; + + /* map the op the first time we see it */ + if (!op->op_odp_mr) { + if (!op->op_mapped) { + op->op_count = + ib_dma_map_sg(ic->i_cm_id->device, op->op_sg, + op->op_nents, + (op->op_write) ? DMA_TO_DEVICE : + DMA_FROM_DEVICE); + rdsdebug("ic %p mapping op %p: %d\n", ic, op, + op->op_count); + if (op->op_count == 0) { + rds_ib_stats_inc(s_ib_tx_sg_mapping_failure); + ret = -ENOMEM; /* XXX ? */ + goto out; + } + op->op_mapped = 1; + } + } else { + op->op_count = op->op_nents; + odp_lkey = rds_ib_get_lkey(op->op_odp_mr->r_trans_private); + } + + /* + * Instead of knowing how to return a partial rdma read/write we insist that there + * be enough work requests to send the entire message. + */ + i = DIV_ROUND_UP(op->op_count, max_sge); + + work_alloc = rds_ib_ring_alloc(&ic->i_send_ring, i, &pos); + if (work_alloc != i) { + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + rds_ib_stats_inc(s_ib_tx_ring_full); + ret = -ENOMEM; + goto out; + } + + send = &ic->i_sends[pos]; + first = send; + prev = NULL; + scat = &op->op_sg[0]; + sent = 0; + num_sge = op->op_count; + + for (i = 0; i < work_alloc && scat != &op->op_sg[op->op_count]; i++) { + send->s_wr.send_flags = 0; + send->s_queued = jiffies; + send->s_op = NULL; + + if (!op->op_notify) + nr_sig += rds_ib_set_wr_signal_state(ic, send, + op->op_notify); + + send->s_wr.opcode = op->op_write ? 
IB_WR_RDMA_WRITE : IB_WR_RDMA_READ; + send->s_rdma_wr.remote_addr = remote_addr; + send->s_rdma_wr.rkey = op->op_rkey; + + if (num_sge > max_sge) { + send->s_rdma_wr.wr.num_sge = max_sge; + num_sge -= max_sge; + } else { + send->s_rdma_wr.wr.num_sge = num_sge; + } + + send->s_rdma_wr.wr.next = NULL; + + if (prev) + prev->s_rdma_wr.wr.next = &send->s_rdma_wr.wr; + + for (j = 0; j < send->s_rdma_wr.wr.num_sge && + scat != &op->op_sg[op->op_count]; j++) { + len = sg_dma_len(scat); + if (!op->op_odp_mr) { + send->s_sge[j].addr = sg_dma_address(scat); + send->s_sge[j].lkey = ic->i_pd->local_dma_lkey; + } else { + send->s_sge[j].addr = odp_addr; + send->s_sge[j].lkey = odp_lkey; + } + send->s_sge[j].length = len; + + sent += len; + rdsdebug("ic %p sent %d remote_addr %llu\n", ic, sent, remote_addr); + + remote_addr += len; + odp_addr += len; + scat++; + } + + rdsdebug("send %p wr %p num_sge %u next %p\n", send, + &send->s_rdma_wr.wr, + send->s_rdma_wr.wr.num_sge, + send->s_rdma_wr.wr.next); + + prev = send; + if (++send == &ic->i_sends[ic->i_send_ring.w_nr]) + send = ic->i_sends; + } + + /* give a reference to the last op */ + if (scat == &op->op_sg[op->op_count]) { + prev->s_op = op; + rds_message_addref(container_of(op, struct rds_message, rdma)); + } + + if (i < work_alloc) { + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc - i); + work_alloc = i; + } + + if (nr_sig) + atomic_add(nr_sig, &ic->i_signaled_sends); + + failed_wr = &first->s_rdma_wr.wr; + ret = ib_post_send(ic->i_cm_id->qp, &first->s_rdma_wr.wr, &failed_wr); + rdsdebug("ic %p first %p (wr %p) ret %d wr %p\n", ic, + first, &first->s_rdma_wr.wr, ret, failed_wr); + BUG_ON(failed_wr != &first->s_rdma_wr.wr); + if (ret) { + printk(KERN_WARNING "RDS/IB: rdma ib_post_send to %pI6c " + "returned %d\n", &conn->c_faddr, ret); + rds_ib_ring_unalloc(&ic->i_send_ring, work_alloc); + rds_ib_sub_signaled(ic, nr_sig); + goto out; + } + + if (unlikely(failed_wr != &first->s_rdma_wr.wr)) { + printk(KERN_WARNING "RDS/IB: ib_post_send() rc=%d, but failed_wqe updated!\n", ret); + BUG_ON(failed_wr != &first->s_rdma_wr.wr); + } + + +out: + return ret; +} + +void rds_ib_xmit_path_complete(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + struct rds_ib_connection *ic = conn->c_transport_data; + + /* We may have a pending ACK or window update we were unable + * to send previously (due to flow control). Try again. */ + rds_ib_attempt_ack(ic); +} diff --git a/net/rds/ib_stats.c b/net/rds/ib_stats.c new file mode 100644 index 000000000..ac46d8961 --- /dev/null +++ b/net/rds/ib_stats.c @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. 
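rds_ib_xmit_rdma() above refuses partial RDMAs and instead splits the mapped scatterlist across enough work requests, each limited to the device's max_sge. A small standalone model of that split, with hypothetical op_count and max_sge values:

#include <stdio.h>

#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

int main(void)
{
	unsigned int op_count = 70;	/* hypothetical mapped SGE count */
	unsigned int max_sge = 32;	/* hypothetical device limit */
	unsigned int wrs = DIV_ROUND_UP(op_count, max_sge);
	unsigned int left = op_count;

	printf("%u SGEs -> %u RDMA WRs\n", op_count, wrs);
	for (unsigned int i = 0; i < wrs; i++) {
		unsigned int n = left > max_sge ? max_sge : left;

		printf("  wr %u: num_sge = %u\n", i, n);
		left -= n;
	}
	return 0;
}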
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "rds.h" +#include "ib.h" + +DEFINE_PER_CPU_SHARED_ALIGNED(struct rds_ib_statistics, rds_ib_stats); + +static const char *const rds_ib_stat_names[] = { + "ib_connect_raced", + "ib_listen_closed_stale", + "ib_evt_handler_call", + "ib_tasklet_call", + "ib_tx_cq_event", + "ib_tx_ring_full", + "ib_tx_throttle", + "ib_tx_sg_mapping_failure", + "ib_tx_stalled", + "ib_tx_credit_updates", + "ib_rx_cq_event", + "ib_rx_ring_empty", + "ib_rx_refill_from_cq", + "ib_rx_refill_from_thread", + "ib_rx_alloc_limit", + "ib_rx_total_frags", + "ib_rx_total_incs", + "ib_rx_credit_updates", + "ib_ack_sent", + "ib_ack_send_failure", + "ib_ack_send_delayed", + "ib_ack_send_piggybacked", + "ib_ack_received", + "ib_rdma_mr_8k_alloc", + "ib_rdma_mr_8k_free", + "ib_rdma_mr_8k_used", + "ib_rdma_mr_8k_pool_flush", + "ib_rdma_mr_8k_pool_wait", + "ib_rdma_mr_8k_pool_depleted", + "ib_rdma_mr_1m_alloc", + "ib_rdma_mr_1m_free", + "ib_rdma_mr_1m_used", + "ib_rdma_mr_1m_pool_flush", + "ib_rdma_mr_1m_pool_wait", + "ib_rdma_mr_1m_pool_depleted", + "ib_rdma_mr_8k_reused", + "ib_rdma_mr_1m_reused", + "ib_atomic_cswp", + "ib_atomic_fadd", +}; + +unsigned int rds_ib_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail) +{ + struct rds_ib_statistics stats = {0, }; + uint64_t *src; + uint64_t *sum; + size_t i; + int cpu; + + if (avail < ARRAY_SIZE(rds_ib_stat_names)) + goto out; + + for_each_online_cpu(cpu) { + src = (uint64_t *)&(per_cpu(rds_ib_stats, cpu)); + sum = (uint64_t *)&stats; + for (i = 0; i < sizeof(stats) / sizeof(uint64_t); i++) + *(sum++) += *(src++); + } + + rds_stats_info_copy(iter, (uint64_t *)&stats, rds_ib_stat_names, + ARRAY_SIZE(rds_ib_stat_names)); +out: + return ARRAY_SIZE(rds_ib_stat_names); +} diff --git a/net/rds/ib_sysctl.c b/net/rds/ib_sysctl.c new file mode 100644 index 000000000..e4e41b3af --- /dev/null +++ b/net/rds/ib_sysctl.c @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "ib.h" + +static struct ctl_table_header *rds_ib_sysctl_hdr; + +unsigned long rds_ib_sysctl_max_send_wr = RDS_IB_DEFAULT_SEND_WR; +unsigned long rds_ib_sysctl_max_recv_wr = RDS_IB_DEFAULT_RECV_WR; +unsigned long rds_ib_sysctl_max_recv_allocation = (128 * 1024 * 1024) / RDS_FRAG_SIZE; +static unsigned long rds_ib_sysctl_max_wr_min = 1; +/* hardware will fail CQ creation long before this */ +static unsigned long rds_ib_sysctl_max_wr_max = (u32)~0; + +unsigned long rds_ib_sysctl_max_unsig_wrs = 16; +static unsigned long rds_ib_sysctl_max_unsig_wr_min = 1; +static unsigned long rds_ib_sysctl_max_unsig_wr_max = 64; + +/* + * This sysctl does nothing. + * + * Backwards compatibility with RDS 3.0 wire protocol + * disables initial FC credit exchange. + * If it's ever possible to drop 3.0 support, + * setting this to 1 and moving init/refill of send/recv + * rings from ib_cm_connect_complete() back into ib_setup_qp() + * will cause credits to be added before protocol negotiation. + */ +unsigned int rds_ib_sysctl_flow_control = 0; + +static struct ctl_table rds_ib_sysctl_table[] = { + { + .procname = "max_send_wr", + .data = &rds_ib_sysctl_max_send_wr, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + .extra1 = &rds_ib_sysctl_max_wr_min, + .extra2 = &rds_ib_sysctl_max_wr_max, + }, + { + .procname = "max_recv_wr", + .data = &rds_ib_sysctl_max_recv_wr, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + .extra1 = &rds_ib_sysctl_max_wr_min, + .extra2 = &rds_ib_sysctl_max_wr_max, + }, + { + .procname = "max_unsignaled_wr", + .data = &rds_ib_sysctl_max_unsig_wrs, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + .extra1 = &rds_ib_sysctl_max_unsig_wr_min, + .extra2 = &rds_ib_sysctl_max_unsig_wr_max, + }, + { + .procname = "max_recv_allocation", + .data = &rds_ib_sysctl_max_recv_allocation, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + { + .procname = "flow_control", + .data = &rds_ib_sysctl_flow_control, + .maxlen = sizeof(rds_ib_sysctl_flow_control), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +void rds_ib_sysctl_exit(void) +{ + if (rds_ib_sysctl_hdr) + unregister_net_sysctl_table(rds_ib_sysctl_hdr); +} + +int rds_ib_sysctl_init(void) +{ + rds_ib_sysctl_hdr = register_net_sysctl(&init_net, "net/rds/ib", rds_ib_sysctl_table); + if (!rds_ib_sysctl_hdr) + return -ENOMEM; + return 0; +} diff --git a/net/rds/info.c b/net/rds/info.c new file mode 100644 index 000000000..b6b46a821 --- /dev/null +++ b/net/rds/info.c @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +#include "rds.h" + +/* + * This file implements a getsockopt() call which copies a set of fixed + * sized structs into a user-specified buffer as a means of providing + * read-only information about RDS. + * + * For a given information source there are a given number of fixed sized + * structs at a given time. The structs are only copied if the user-specified + * buffer is big enough. The destination pages that make up the buffer + * are pinned for the duration of the copy. + * + * This gives us the following benefits: + * + * - simple implementation, no copy "position" across multiple calls + * - consistent snapshot of an info source + * - atomic copy works well with whatever locking info source has + * - one portable tool to get rds info across implementations + * - long-lived tool can get info without allocating + * + * at the following costs: + * + * - info source copy must be pinned, may be "large" + */ + +struct rds_info_iterator { + struct page **pages; + void *addr; + unsigned long offset; +}; + +static DEFINE_SPINLOCK(rds_info_lock); +static rds_info_func rds_info_funcs[RDS_INFO_LAST - RDS_INFO_FIRST + 1]; + +void rds_info_register_func(int optname, rds_info_func func) +{ + int offset = optname - RDS_INFO_FIRST; + + BUG_ON(optname < RDS_INFO_FIRST || optname > RDS_INFO_LAST); + + spin_lock(&rds_info_lock); + BUG_ON(rds_info_funcs[offset]); + rds_info_funcs[offset] = func; + spin_unlock(&rds_info_lock); +} +EXPORT_SYMBOL_GPL(rds_info_register_func); + +void rds_info_deregister_func(int optname, rds_info_func func) +{ + int offset = optname - RDS_INFO_FIRST; + + BUG_ON(optname < RDS_INFO_FIRST || optname > RDS_INFO_LAST); + + spin_lock(&rds_info_lock); + BUG_ON(rds_info_funcs[offset] != func); + rds_info_funcs[offset] = NULL; + spin_unlock(&rds_info_lock); +} +EXPORT_SYMBOL_GPL(rds_info_deregister_func); + +/* + * Typically we hold an atomic kmap across multiple rds_info_copy() calls + * because the kmap is so expensive. This must be called before using blocking + * operations while holding the mapping and as the iterator is torn down. 
+ */ +void rds_info_iter_unmap(struct rds_info_iterator *iter) +{ + if (iter->addr) { + kunmap_atomic(iter->addr); + iter->addr = NULL; + } +} + +/* + * get_user_pages() called flush_dcache_page() on the pages for us. + */ +void rds_info_copy(struct rds_info_iterator *iter, void *data, + unsigned long bytes) +{ + unsigned long this; + + while (bytes) { + if (!iter->addr) + iter->addr = kmap_atomic(*iter->pages); + + this = min(bytes, PAGE_SIZE - iter->offset); + + rdsdebug("page %p addr %p offset %lu this %lu data %p " + "bytes %lu\n", *iter->pages, iter->addr, + iter->offset, this, data, bytes); + + memcpy(iter->addr + iter->offset, data, this); + + data += this; + bytes -= this; + iter->offset += this; + + if (iter->offset == PAGE_SIZE) { + kunmap_atomic(iter->addr); + iter->addr = NULL; + iter->offset = 0; + iter->pages++; + } + } +} +EXPORT_SYMBOL_GPL(rds_info_copy); + +/* + * @optval points to the userspace buffer that the information snapshot + * will be copied into. + * + * @optlen on input is the size of the buffer in userspace. @optlen + * on output is the size of the requested snapshot in bytes. + * + * This function returns -errno if there is a failure, particularly -ENOSPC + * if the given userspace buffer was not large enough to fit the snapshot. + * On success it returns the positive number of bytes of each array element + * in the snapshot. + */ +int rds_info_getsockopt(struct socket *sock, int optname, char __user *optval, + int __user *optlen) +{ + struct rds_info_iterator iter; + struct rds_info_lengths lens; + unsigned long nr_pages = 0; + unsigned long start; + rds_info_func func; + struct page **pages = NULL; + int ret; + int len; + int total; + + if (get_user(len, optlen)) { + ret = -EFAULT; + goto out; + } + + /* check for all kinds of wrapping and the like */ + start = (unsigned long)optval; + if (len < 0 || len > INT_MAX - PAGE_SIZE + 1 || start + len < start) { + ret = -EINVAL; + goto out; + } + + /* a 0 len call is just trying to probe its length */ + if (len == 0) + goto call_func; + + nr_pages = (PAGE_ALIGN(start + len) - (start & PAGE_MASK)) + >> PAGE_SHIFT; + + pages = kmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL); + if (!pages) { + ret = -ENOMEM; + goto out; + } + ret = pin_user_pages_fast(start, nr_pages, FOLL_WRITE, pages); + if (ret != nr_pages) { + if (ret > 0) + nr_pages = ret; + else + nr_pages = 0; + ret = -EAGAIN; /* XXX ? */ + goto out; + } + + rdsdebug("len %d nr_pages %lu\n", len, nr_pages); + +call_func: + func = rds_info_funcs[optname - RDS_INFO_FIRST]; + if (!func) { + ret = -ENOPROTOOPT; + goto out; + } + + iter.pages = pages; + iter.addr = NULL; + iter.offset = start & (PAGE_SIZE - 1); + + func(sock, len, &iter, &lens); + BUG_ON(lens.each == 0); + + total = lens.nr * lens.each; + + rds_info_iter_unmap(&iter); + + if (total > len) { + len = total; + ret = -ENOSPC; + } else { + len = total; + ret = lens.each; + } + + if (put_user(len, optlen)) + ret = -EFAULT; + +out: + if (pages) + unpin_user_pages(pages, nr_pages); + kfree(pages); + + return ret; +} diff --git a/net/rds/info.h b/net/rds/info.h new file mode 100644 index 000000000..a069b51c4 --- /dev/null +++ b/net/rds/info.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_INFO_H +#define _RDS_INFO_H + +struct rds_info_lengths { + unsigned int nr; + unsigned int each; +}; + +struct rds_info_iterator; + +/* + * These functions must fill in the fields of @lens to reflect the size + * of the available info source. 
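rds_info_getsockopt() above pins every user page touched by the destination buffer; the page count comes from aligning the buffer's end up and its start down to page boundaries. A standalone example of that computation, assuming 4 KiB pages and a hypothetical user address:

#include <stdio.h>

#define PAGE_SHIFT 12			/* assumed 4 KiB pages */
#define PAGE_SIZE (1UL << PAGE_SHIFT)
#define PAGE_MASK (~(PAGE_SIZE - 1))
#define PAGE_ALIGN(x) (((x) + PAGE_SIZE - 1) & PAGE_MASK)

int main(void)
{
	unsigned long start = 0x1000ff0;	/* hypothetical user address */
	unsigned long len = 0x40;		/* 64-byte buffer */
	unsigned long nr_pages =
		(PAGE_ALIGN(start + len) - (start & PAGE_MASK)) >> PAGE_SHIFT;

	/* the 64-byte buffer straddles a page boundary, so two pages */
	printf("nr_pages = %lu\n", nr_pages);
	return 0;
}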
If the snapshot fits in @len then it + * should be copied using @iter. The caller will deduce if it was copied + * or not by comparing the lengths. + */ +typedef void (*rds_info_func)(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens); + +void rds_info_register_func(int optname, rds_info_func func); +void rds_info_deregister_func(int optname, rds_info_func func); +int rds_info_getsockopt(struct socket *sock, int optname, char __user *optval, + int __user *optlen); +void rds_info_copy(struct rds_info_iterator *iter, void *data, + unsigned long bytes); +void rds_info_iter_unmap(struct rds_info_iterator *iter); + + +#endif diff --git a/net/rds/loop.c b/net/rds/loop.c new file mode 100644 index 000000000..1d73ad79c --- /dev/null +++ b/net/rds/loop.c @@ -0,0 +1,254 @@ +/* + * Copyright (c) 2006, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "loop.h" + +static DEFINE_SPINLOCK(loop_conns_lock); +static LIST_HEAD(loop_conns); +static atomic_t rds_loop_unloading = ATOMIC_INIT(0); + +static void rds_loop_set_unloading(void) +{ + atomic_set(&rds_loop_unloading, 1); +} + +static bool rds_loop_is_unloading(struct rds_connection *conn) +{ + return atomic_read(&rds_loop_unloading) != 0; +} + +/* + * This 'loopback' transport is a special case for flows that originate + * and terminate on the same machine. + * + * Connection build-up notices if the destination address is thought of + * as a local address by a transport. At that time it decides to use the + * loopback transport instead of the bound transport of the sending socket. + * + * The loopback transport's sending path just hands the sent rds_message + * straight to the receiving path via an embedded rds_incoming. + */ + +/* + * Usually a message transits both the sender and receiver's conns as it + * flows to the receiver. In the loopback case, though, the receive path + * is handed the sending conn so the sense of the addresses is reversed. 
+ */ +static int rds_loop_xmit(struct rds_connection *conn, struct rds_message *rm, + unsigned int hdr_off, unsigned int sg, + unsigned int off) +{ + struct scatterlist *sgp = &rm->data.op_sg[sg]; + int ret = sizeof(struct rds_header) + + be32_to_cpu(rm->m_inc.i_hdr.h_len); + + /* Do not send cong updates to loopback */ + if (rm->m_inc.i_hdr.h_flags & RDS_FLAG_CONG_BITMAP) { + rds_cong_map_updated(conn->c_fcong, ~(u64) 0); + ret = min_t(int, ret, sgp->length - conn->c_xmit_data_off); + goto out; + } + + BUG_ON(hdr_off || sg || off); + + rds_inc_init(&rm->m_inc, conn, &conn->c_laddr); + /* For the embedded inc. Matching put is in loop_inc_free() */ + rds_message_addref(rm); + + rds_recv_incoming(conn, &conn->c_laddr, &conn->c_faddr, &rm->m_inc, + GFP_KERNEL); + + rds_send_drop_acked(conn, be64_to_cpu(rm->m_inc.i_hdr.h_sequence), + NULL); + + rds_inc_put(&rm->m_inc); +out: + return ret; +} + +/* + * See rds_loop_xmit(). Since our inc is embedded in the rm, we + * make sure the rm lives at least until the inc is done. + */ +static void rds_loop_inc_free(struct rds_incoming *inc) +{ + struct rds_message *rm = container_of(inc, struct rds_message, m_inc); + + rds_message_put(rm); +} + +/* we need to at least give the thread something to succeed */ +static int rds_loop_recv_path(struct rds_conn_path *cp) +{ + return 0; +} + +struct rds_loop_connection { + struct list_head loop_node; + struct rds_connection *conn; +}; + +/* + * Even the loopback transport needs to keep track of its connections, + * so it can call rds_conn_destroy() on them on exit. N.B. there are + * 1+ loopback addresses (127.*.*.*) so it's not a bug to have + * multiple loopback conns allocated, although rather useless. + */ +static int rds_loop_conn_alloc(struct rds_connection *conn, gfp_t gfp) +{ + struct rds_loop_connection *lc; + unsigned long flags; + + lc = kzalloc(sizeof(struct rds_loop_connection), gfp); + if (!lc) + return -ENOMEM; + + INIT_LIST_HEAD(&lc->loop_node); + lc->conn = conn; + conn->c_transport_data = lc; + + spin_lock_irqsave(&loop_conns_lock, flags); + list_add_tail(&lc->loop_node, &loop_conns); + spin_unlock_irqrestore(&loop_conns_lock, flags); + + return 0; +} + +static void rds_loop_conn_free(void *arg) +{ + struct rds_loop_connection *lc = arg; + unsigned long flags; + + rdsdebug("lc %p\n", lc); + spin_lock_irqsave(&loop_conns_lock, flags); + list_del(&lc->loop_node); + spin_unlock_irqrestore(&loop_conns_lock, flags); + kfree(lc); +} + +static int rds_loop_conn_path_connect(struct rds_conn_path *cp) +{ + rds_connect_complete(cp->cp_conn); + return 0; +} + +static void rds_loop_conn_path_shutdown(struct rds_conn_path *cp) +{ +} + +void rds_loop_exit(void) +{ + struct rds_loop_connection *lc, *_lc; + LIST_HEAD(tmp_list); + + rds_loop_set_unloading(); + synchronize_rcu(); + /* avoid calling conn_destroy with irqs off */ + spin_lock_irq(&loop_conns_lock); + list_splice(&loop_conns, &tmp_list); + INIT_LIST_HEAD(&loop_conns); + spin_unlock_irq(&loop_conns_lock); + + list_for_each_entry_safe(lc, _lc, &tmp_list, loop_node) { + WARN_ON(lc->conn->c_passive); + rds_conn_destroy(lc->conn); + } +} + +static void rds_loop_kill_conns(struct net *net) +{ + struct rds_loop_connection *lc, *_lc; + LIST_HEAD(tmp_list); + + spin_lock_irq(&loop_conns_lock); + list_for_each_entry_safe(lc, _lc, &loop_conns, loop_node) { + struct net *c_net = read_pnet(&lc->conn->c_net); + + if (net != c_net) + continue; + list_move_tail(&lc->loop_node, &tmp_list); + } + spin_unlock_irq(&loop_conns_lock); + + list_for_each_entry_safe(lc, _lc, 
&tmp_list, loop_node) { + WARN_ON(lc->conn->c_passive); + rds_conn_destroy(lc->conn); + } +} + +static void __net_exit rds_loop_exit_net(struct net *net) +{ + rds_loop_kill_conns(net); +} + +static struct pernet_operations rds_loop_net_ops = { + .exit = rds_loop_exit_net, +}; + +int rds_loop_net_init(void) +{ + return register_pernet_device(&rds_loop_net_ops); +} + +void rds_loop_net_exit(void) +{ + unregister_pernet_device(&rds_loop_net_ops); +} + +/* + * This is missing .xmit_* because loop doesn't go through generic + * rds_send_xmit() and doesn't call rds_recv_incoming(). .listen_stop and + * .laddr_check are missing because transport.c doesn't iterate over + * rds_loop_transport. + */ +struct rds_transport rds_loop_transport = { + .xmit = rds_loop_xmit, + .recv_path = rds_loop_recv_path, + .conn_alloc = rds_loop_conn_alloc, + .conn_free = rds_loop_conn_free, + .conn_path_connect = rds_loop_conn_path_connect, + .conn_path_shutdown = rds_loop_conn_path_shutdown, + .inc_copy_to_user = rds_message_inc_copy_to_user, + .inc_free = rds_loop_inc_free, + .t_name = "loopback", + .t_type = RDS_TRANS_LOOP, + .t_unloading = rds_loop_is_unloading, +}; diff --git a/net/rds/loop.h b/net/rds/loop.h new file mode 100644 index 000000000..bbc8cdd03 --- /dev/null +++ b/net/rds/loop.h @@ -0,0 +1,12 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_LOOP_H +#define _RDS_LOOP_H + +/* loop.c */ +extern struct rds_transport rds_loop_transport; + +int rds_loop_net_init(void); +void rds_loop_net_exit(void); +void rds_loop_exit(void); + +#endif diff --git a/net/rds/message.c b/net/rds/message.c new file mode 100644 index 000000000..f71e1237e --- /dev/null +++ b/net/rds/message.c @@ -0,0 +1,521 @@ +/* + * Copyright (c) 2006, 2020 Oracle and/or its affiliates. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + */ +#include +#include +#include +#include +#include +#include + +#include "rds.h" + +static unsigned int rds_exthdr_size[__RDS_EXTHDR_MAX] = { +[RDS_EXTHDR_NONE] = 0, +[RDS_EXTHDR_VERSION] = sizeof(struct rds_ext_header_version), +[RDS_EXTHDR_RDMA] = sizeof(struct rds_ext_header_rdma), +[RDS_EXTHDR_RDMA_DEST] = sizeof(struct rds_ext_header_rdma_dest), +[RDS_EXTHDR_NPATHS] = sizeof(u16), +[RDS_EXTHDR_GEN_NUM] = sizeof(u32), +}; + +void rds_message_addref(struct rds_message *rm) +{ + rdsdebug("addref rm %p ref %d\n", rm, refcount_read(&rm->m_refcount)); + refcount_inc(&rm->m_refcount); +} +EXPORT_SYMBOL_GPL(rds_message_addref); + +static inline bool rds_zcookie_add(struct rds_msg_zcopy_info *info, u32 cookie) +{ + struct rds_zcopy_cookies *ck = &info->zcookies; + int ncookies = ck->num; + + if (ncookies == RDS_MAX_ZCOOKIES) + return false; + ck->cookies[ncookies] = cookie; + ck->num = ++ncookies; + return true; +} + +static struct rds_msg_zcopy_info *rds_info_from_znotifier(struct rds_znotifier *znotif) +{ + return container_of(znotif, struct rds_msg_zcopy_info, znotif); +} + +void rds_notify_msg_zcopy_purge(struct rds_msg_zcopy_queue *q) +{ + unsigned long flags; + LIST_HEAD(copy); + struct rds_msg_zcopy_info *info, *tmp; + + spin_lock_irqsave(&q->lock, flags); + list_splice(&q->zcookie_head, ©); + INIT_LIST_HEAD(&q->zcookie_head); + spin_unlock_irqrestore(&q->lock, flags); + + list_for_each_entry_safe(info, tmp, ©, rs_zcookie_next) { + list_del(&info->rs_zcookie_next); + kfree(info); + } +} + +static void rds_rm_zerocopy_callback(struct rds_sock *rs, + struct rds_znotifier *znotif) +{ + struct rds_msg_zcopy_info *info; + struct rds_msg_zcopy_queue *q; + u32 cookie = znotif->z_cookie; + struct rds_zcopy_cookies *ck; + struct list_head *head; + unsigned long flags; + + mm_unaccount_pinned_pages(&znotif->z_mmp); + q = &rs->rs_zcookie_queue; + spin_lock_irqsave(&q->lock, flags); + head = &q->zcookie_head; + if (!list_empty(head)) { + info = list_first_entry(head, struct rds_msg_zcopy_info, + rs_zcookie_next); + if (rds_zcookie_add(info, cookie)) { + spin_unlock_irqrestore(&q->lock, flags); + kfree(rds_info_from_znotifier(znotif)); + /* caller invokes rds_wake_sk_sleep() */ + return; + } + } + + info = rds_info_from_znotifier(znotif); + ck = &info->zcookies; + memset(ck, 0, sizeof(*ck)); + WARN_ON(!rds_zcookie_add(info, cookie)); + list_add_tail(&info->rs_zcookie_next, &q->zcookie_head); + + spin_unlock_irqrestore(&q->lock, flags); + /* caller invokes rds_wake_sk_sleep() */ +} + +/* + * This relies on dma_map_sg() not touching sg[].page during merging. 
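rds_zcookie_add() above batches zerocopy completion cookies into the most recent notification until its array is full, so one wakeup can report many completions. A standalone sketch of that batching, assuming an eight-entry batch (the real RDS_MAX_ZCOOKIES value is defined in the uapi header, not shown here):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define MAX_ZCOOKIES 8		/* assumed batch limit */

struct zcookies {
	uint32_t num;
	uint32_t cookies[MAX_ZCOOKIES];
};

/* append a completion cookie to the current batch; false when full */
static bool zcookie_add(struct zcookies *ck, uint32_t cookie)
{
	if (ck->num == MAX_ZCOOKIES)
		return false;
	ck->cookies[ck->num++] = cookie;
	return true;
}

int main(void)
{
	struct zcookies ck = { 0 };
	uint32_t cookie;

	for (cookie = 1; zcookie_add(&ck, cookie); cookie++)
		;
	printf("batched %u cookies before needing a new notification\n",
	       ck.num);		/* prints 8 */
	return 0;
}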
+ */ +static void rds_message_purge(struct rds_message *rm) +{ + unsigned long i, flags; + bool zcopy = false; + + if (unlikely(test_bit(RDS_MSG_PAGEVEC, &rm->m_flags))) + return; + + spin_lock_irqsave(&rm->m_rs_lock, flags); + if (rm->m_rs) { + struct rds_sock *rs = rm->m_rs; + + if (rm->data.op_mmp_znotifier) { + zcopy = true; + rds_rm_zerocopy_callback(rs, rm->data.op_mmp_znotifier); + rds_wake_sk_sleep(rs); + rm->data.op_mmp_znotifier = NULL; + } + sock_put(rds_rs_to_sk(rs)); + rm->m_rs = NULL; + } + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + + for (i = 0; i < rm->data.op_nents; i++) { + /* XXX will have to put_page for page refs */ + if (!zcopy) + __free_page(sg_page(&rm->data.op_sg[i])); + else + put_page(sg_page(&rm->data.op_sg[i])); + } + rm->data.op_nents = 0; + + if (rm->rdma.op_active) + rds_rdma_free_op(&rm->rdma); + if (rm->rdma.op_rdma_mr) + kref_put(&rm->rdma.op_rdma_mr->r_kref, __rds_put_mr_final); + + if (rm->atomic.op_active) + rds_atomic_free_op(&rm->atomic); + if (rm->atomic.op_rdma_mr) + kref_put(&rm->atomic.op_rdma_mr->r_kref, __rds_put_mr_final); +} + +void rds_message_put(struct rds_message *rm) +{ + rdsdebug("put rm %p ref %d\n", rm, refcount_read(&rm->m_refcount)); + WARN(!refcount_read(&rm->m_refcount), "danger refcount zero on %p\n", rm); + if (refcount_dec_and_test(&rm->m_refcount)) { + BUG_ON(!list_empty(&rm->m_sock_item)); + BUG_ON(!list_empty(&rm->m_conn_item)); + rds_message_purge(rm); + + kfree(rm); + } +} +EXPORT_SYMBOL_GPL(rds_message_put); + +void rds_message_populate_header(struct rds_header *hdr, __be16 sport, + __be16 dport, u64 seq) +{ + hdr->h_flags = 0; + hdr->h_sport = sport; + hdr->h_dport = dport; + hdr->h_sequence = cpu_to_be64(seq); + hdr->h_exthdr[0] = RDS_EXTHDR_NONE; +} +EXPORT_SYMBOL_GPL(rds_message_populate_header); + +int rds_message_add_extension(struct rds_header *hdr, unsigned int type, + const void *data, unsigned int len) +{ + unsigned int ext_len = sizeof(u8) + len; + unsigned char *dst; + + /* For now, refuse to add more than one extension header */ + if (hdr->h_exthdr[0] != RDS_EXTHDR_NONE) + return 0; + + if (type >= __RDS_EXTHDR_MAX || len != rds_exthdr_size[type]) + return 0; + + if (ext_len >= RDS_HEADER_EXT_SPACE) + return 0; + dst = hdr->h_exthdr; + + *dst++ = type; + memcpy(dst, data, len); + + dst[len] = RDS_EXTHDR_NONE; + return 1; +} +EXPORT_SYMBOL_GPL(rds_message_add_extension); + +/* + * If a message has extension headers, retrieve them here. + * Call like this: + * + * unsigned int pos = 0; + * + * while (1) { + * buflen = sizeof(buffer); + * type = rds_message_next_extension(hdr, &pos, buffer, &buflen); + * if (type == RDS_EXTHDR_NONE) + * break; + * ... + * } + */ +int rds_message_next_extension(struct rds_header *hdr, + unsigned int *pos, void *buf, unsigned int *buflen) +{ + unsigned int offset, ext_type, ext_len; + u8 *src = hdr->h_exthdr; + + offset = *pos; + if (offset >= RDS_HEADER_EXT_SPACE) + goto none; + + /* Get the extension type and length. For now, the + * length is implied by the extension type. 
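Illustrative sketch: rds_message_add_extension() only accepts a type whose length matches rds_exthdr_size[] and refuses a second extension once one is present. Announcing a path count through RDS_EXTHDR_NPATHS would look roughly like this; the big-endian u16 encoding is an assumption based on the size table above, and the function name is hypothetical:

        static void example_announce_npaths(struct rds_header *hdr, u16 npaths)
        {
                __be16 ext = cpu_to_be16(npaths);  /* RDS_EXTHDR_NPATHS is sizeof(u16) */

                if (!rds_message_add_extension(hdr, RDS_EXTHDR_NPATHS,
                                               &ext, sizeof(ext)))
                        pr_debug("extension space already in use\n");
        }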
*/ + ext_type = src[offset++]; + + if (ext_type == RDS_EXTHDR_NONE || ext_type >= __RDS_EXTHDR_MAX) + goto none; + ext_len = rds_exthdr_size[ext_type]; + if (offset + ext_len > RDS_HEADER_EXT_SPACE) + goto none; + + *pos = offset + ext_len; + if (ext_len < *buflen) + *buflen = ext_len; + memcpy(buf, src + offset, *buflen); + return ext_type; + +none: + *pos = RDS_HEADER_EXT_SPACE; + *buflen = 0; + return RDS_EXTHDR_NONE; +} + +int rds_message_add_rdma_dest_extension(struct rds_header *hdr, u32 r_key, u32 offset) +{ + struct rds_ext_header_rdma_dest ext_hdr; + + ext_hdr.h_rdma_rkey = cpu_to_be32(r_key); + ext_hdr.h_rdma_offset = cpu_to_be32(offset); + return rds_message_add_extension(hdr, RDS_EXTHDR_RDMA_DEST, &ext_hdr, sizeof(ext_hdr)); +} +EXPORT_SYMBOL_GPL(rds_message_add_rdma_dest_extension); + +/* + * Each rds_message is allocated with extra space for the scatterlist entries + * rds ops will need. This is to minimize memory allocation count. Then, each rds op + * can grab SGs when initializing its part of the rds_message. + */ +struct rds_message *rds_message_alloc(unsigned int extra_len, gfp_t gfp) +{ + struct rds_message *rm; + + if (extra_len > KMALLOC_MAX_SIZE - sizeof(struct rds_message)) + return NULL; + + rm = kzalloc(sizeof(struct rds_message) + extra_len, gfp); + if (!rm) + goto out; + + rm->m_used_sgs = 0; + rm->m_total_sgs = extra_len / sizeof(struct scatterlist); + + refcount_set(&rm->m_refcount, 1); + INIT_LIST_HEAD(&rm->m_sock_item); + INIT_LIST_HEAD(&rm->m_conn_item); + spin_lock_init(&rm->m_rs_lock); + init_waitqueue_head(&rm->m_flush_wait); + +out: + return rm; +} + +/* + * RDS ops use this to grab SG entries from the rm's sg pool. + */ +struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents) +{ + struct scatterlist *sg_first = (struct scatterlist *) &rm[1]; + struct scatterlist *sg_ret; + + if (nents <= 0) { + pr_warn("rds: alloc sgs failed! nents <= 0\n"); + return ERR_PTR(-EINVAL); + } + + if (rm->m_used_sgs + nents > rm->m_total_sgs) { + pr_warn("rds: alloc sgs failed! total %d used %d nents %d\n", + rm->m_total_sgs, rm->m_used_sgs, nents); + return ERR_PTR(-ENOMEM); + } + + sg_ret = &sg_first[rm->m_used_sgs]; + sg_init_table(sg_ret, nents); + rm->m_used_sgs += nents; + + return sg_ret; +} + +struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned int total_len) +{ + struct rds_message *rm; + unsigned int i; + int num_sgs = DIV_ROUND_UP(total_len, PAGE_SIZE); + int extra_bytes = num_sgs * sizeof(struct scatterlist); + + rm = rds_message_alloc(extra_bytes, GFP_NOWAIT); + if (!rm) + return ERR_PTR(-ENOMEM); + + set_bit(RDS_MSG_PAGEVEC, &rm->m_flags); + rm->m_inc.i_hdr.h_len = cpu_to_be32(total_len); + rm->data.op_nents = DIV_ROUND_UP(total_len, PAGE_SIZE); + rm->data.op_sg = rds_message_alloc_sgs(rm, num_sgs); + if (IS_ERR(rm->data.op_sg)) { + void *err = ERR_CAST(rm->data.op_sg); + rds_message_put(rm); + return err; + } + + for (i = 0; i < rm->data.op_nents; ++i) { + sg_set_page(&rm->data.op_sg[i], + virt_to_page((void *)page_addrs[i]), + PAGE_SIZE, 0); + } + + return rm; +} + +static int rds_message_zcopy_from_user(struct rds_message *rm, struct iov_iter *from) +{ + struct scatterlist *sg; + int ret = 0; + int length = iov_iter_count(from); + int total_copied = 0; + struct rds_msg_zcopy_info *info; + + rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); + + /* + * now allocate and copy in the data payload. 
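Illustrative sketch: callers size the trailing scatterlist pool when allocating the rds_message and then carve entries out of it, which is the same pattern rds_message_map_pages() follows above. A minimal, hedged version for a payload of total_len bytes (the function name is hypothetical):

        static struct rds_message *example_alloc_with_sgs(unsigned int total_len)
        {
                int num_sgs = DIV_ROUND_UP(total_len, PAGE_SIZE);
                struct rds_message *rm;
                struct scatterlist *sg;

                rm = rds_message_alloc(num_sgs * sizeof(struct scatterlist), GFP_KERNEL);
                if (!rm)
                        return ERR_PTR(-ENOMEM);

                sg = rds_message_alloc_sgs(rm, num_sgs);
                if (IS_ERR(sg)) {
                        rds_message_put(rm);
                        return ERR_CAST(sg);
                }
                rm->data.op_sg = sg;
                return rm;
        }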
+ */ + sg = rm->data.op_sg; + + info = kzalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return -ENOMEM; + INIT_LIST_HEAD(&info->rs_zcookie_next); + rm->data.op_mmp_znotifier = &info->znotif; + if (mm_account_pinned_pages(&rm->data.op_mmp_znotifier->z_mmp, + length)) { + ret = -ENOMEM; + goto err; + } + while (iov_iter_count(from)) { + struct page *pages; + size_t start; + ssize_t copied; + + copied = iov_iter_get_pages2(from, &pages, PAGE_SIZE, + 1, &start); + if (copied < 0) { + struct mmpin *mmp; + int i; + + for (i = 0; i < rm->data.op_nents; i++) + put_page(sg_page(&rm->data.op_sg[i])); + mmp = &rm->data.op_mmp_znotifier->z_mmp; + mm_unaccount_pinned_pages(mmp); + ret = -EFAULT; + goto err; + } + total_copied += copied; + length -= copied; + sg_set_page(sg, pages, copied, start); + rm->data.op_nents++; + sg++; + } + WARN_ON_ONCE(length != 0); + return ret; +err: + kfree(info); + rm->data.op_mmp_znotifier = NULL; + return ret; +} + +int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, + bool zcopy) +{ + unsigned long to_copy, nbytes; + unsigned long sg_off; + struct scatterlist *sg; + int ret = 0; + + rm->m_inc.i_hdr.h_len = cpu_to_be32(iov_iter_count(from)); + + /* now allocate and copy in the data payload. */ + sg = rm->data.op_sg; + sg_off = 0; /* Dear gcc, sg->page will be null from kzalloc. */ + + if (zcopy) + return rds_message_zcopy_from_user(rm, from); + + while (iov_iter_count(from)) { + if (!sg_page(sg)) { + ret = rds_page_remainder_alloc(sg, iov_iter_count(from), + GFP_HIGHUSER); + if (ret) + return ret; + rm->data.op_nents++; + sg_off = 0; + } + + to_copy = min_t(unsigned long, iov_iter_count(from), + sg->length - sg_off); + + rds_stats_add(s_copy_from_user, to_copy); + nbytes = copy_page_from_iter(sg_page(sg), sg->offset + sg_off, + to_copy, from); + if (nbytes != to_copy) + return -EFAULT; + + sg_off += to_copy; + + if (sg_off == sg->length) + sg++; + } + + return ret; +} + +int rds_message_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) +{ + struct rds_message *rm; + struct scatterlist *sg; + unsigned long to_copy; + unsigned long vec_off; + int copied; + int ret; + u32 len; + + rm = container_of(inc, struct rds_message, m_inc); + len = be32_to_cpu(rm->m_inc.i_hdr.h_len); + + sg = rm->data.op_sg; + vec_off = 0; + copied = 0; + + while (iov_iter_count(to) && copied < len) { + to_copy = min_t(unsigned long, iov_iter_count(to), + sg->length - vec_off); + to_copy = min_t(unsigned long, to_copy, len - copied); + + rds_stats_add(s_copy_to_user, to_copy); + ret = copy_page_to_iter(sg_page(sg), sg->offset + vec_off, + to_copy, to); + if (ret != to_copy) + return -EFAULT; + + vec_off += to_copy; + copied += to_copy; + + if (vec_off == sg->length) { + vec_off = 0; + sg++; + } + } + + return copied; +} + +/* + * If the message is still on the send queue, wait until the transport + * is done with it. This is particularly important for RDMA operations. + */ +void rds_message_wait(struct rds_message *rm) +{ + wait_event_interruptible(rm->m_flush_wait, + !test_bit(RDS_MSG_MAPPED, &rm->m_flags)); +} + +void rds_message_unmapped(struct rds_message *rm) +{ + clear_bit(RDS_MSG_MAPPED, &rm->m_flags); + wake_up_interruptible(&rm->m_flush_wait); +} +EXPORT_SYMBOL_GPL(rds_message_unmapped); diff --git a/net/rds/page.c b/net/rds/page.c new file mode 100644 index 000000000..7cc57e098 --- /dev/null +++ b/net/rds/page.c @@ -0,0 +1,167 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. 
+ * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include + +#include "rds.h" + +struct rds_page_remainder { + struct page *r_page; + unsigned long r_offset; +}; + +static +DEFINE_PER_CPU_SHARED_ALIGNED(struct rds_page_remainder, rds_page_remainders); + +/** + * rds_page_remainder_alloc - build up regions of a message. + * + * @scat: Scatter list for message + * @bytes: the number of bytes needed. + * @gfp: the waiting behaviour of the allocation + * + * @gfp is always ored with __GFP_HIGHMEM. Callers must be prepared to + * kmap the pages, etc. + * + * If @bytes is at least a full page then this just returns a page from + * alloc_page(). + * + * If @bytes is a partial page then this stores the unused region of the + * page in a per-cpu structure. Future partial-page allocations may be + * satisfied from that cached region. This lets us waste less memory on + * small allocations with minimal complexity. It works because the transmit + * path passes read-only page regions down to devices. They hold a page + * reference until they are done with the region. 
+ */ +int rds_page_remainder_alloc(struct scatterlist *scat, unsigned long bytes, + gfp_t gfp) +{ + struct rds_page_remainder *rem; + unsigned long flags; + struct page *page; + int ret; + + gfp |= __GFP_HIGHMEM; + + /* jump straight to allocation if we're trying for a huge page */ + if (bytes >= PAGE_SIZE) { + page = alloc_page(gfp); + if (!page) { + ret = -ENOMEM; + } else { + sg_set_page(scat, page, PAGE_SIZE, 0); + ret = 0; + } + goto out; + } + + rem = &per_cpu(rds_page_remainders, get_cpu()); + local_irq_save(flags); + + while (1) { + /* avoid a tiny region getting stuck by tossing it */ + if (rem->r_page && bytes > (PAGE_SIZE - rem->r_offset)) { + rds_stats_inc(s_page_remainder_miss); + __free_page(rem->r_page); + rem->r_page = NULL; + } + + /* hand out a fragment from the cached page */ + if (rem->r_page && bytes <= (PAGE_SIZE - rem->r_offset)) { + sg_set_page(scat, rem->r_page, bytes, rem->r_offset); + get_page(sg_page(scat)); + + if (rem->r_offset != 0) + rds_stats_inc(s_page_remainder_hit); + + rem->r_offset += ALIGN(bytes, 8); + if (rem->r_offset >= PAGE_SIZE) { + __free_page(rem->r_page); + rem->r_page = NULL; + } + ret = 0; + break; + } + + /* alloc if there is nothing for us to use */ + local_irq_restore(flags); + put_cpu(); + + page = alloc_page(gfp); + + rem = &per_cpu(rds_page_remainders, get_cpu()); + local_irq_save(flags); + + if (!page) { + ret = -ENOMEM; + break; + } + + /* did someone race to fill the remainder before us? */ + if (rem->r_page) { + __free_page(page); + continue; + } + + /* otherwise install our page and loop around to alloc */ + rem->r_page = page; + rem->r_offset = 0; + } + + local_irq_restore(flags); + put_cpu(); +out: + rdsdebug("bytes %lu ret %d %p %u %u\n", bytes, ret, + ret ? NULL : sg_page(scat), ret ? 0 : scat->offset, + ret ? 0 : scat->length); + return ret; +} +EXPORT_SYMBOL_GPL(rds_page_remainder_alloc); + +void rds_page_exit(void) +{ + unsigned int cpu; + + for_each_possible_cpu(cpu) { + struct rds_page_remainder *rem; + + rem = &per_cpu(rds_page_remainders, cpu); + rdsdebug("cpu %u\n", cpu); + + if (rem->r_page) + __free_page(rem->r_page); + rem->r_page = NULL; + } +} diff --git a/net/rds/rdma.c b/net/rds/rdma.c new file mode 100644 index 000000000..fba82d365 --- /dev/null +++ b/net/rds/rdma.c @@ -0,0 +1,958 @@ +/* + * Copyright (c) 2007, 2020 Oracle and/or its affiliates. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
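Illustrative sketch: rds_page_remainder_alloc() hands back either a whole page (for bytes >= PAGE_SIZE) or a referenced fragment of the per-CPU remainder page, and the caller owns a page reference either way. Filling one scatterlist entry per chunk of a small payload might look like this (hypothetical helper, error unwinding left to the caller):

        static int example_fill_sgs(struct scatterlist *sg, unsigned int nents,
                                    unsigned long chunk_bytes)
        {
                unsigned int i;
                int ret;

                for (i = 0; i < nents; i++) {
                        ret = rds_page_remainder_alloc(&sg[i], chunk_bytes,
                                                       GFP_HIGHUSER);
                        if (ret)
                                return ret;  /* pages already set must be freed by caller */
                }
                return 0;
        }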
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include /* for DMA_*_DEVICE */ + +#include "rds.h" + +/* + * XXX + * - build with sparse + * - should we detect duplicate keys on a socket? hmm. + * - an rdma is an mlock, apply rlimit? + */ + +/* + * get the number of pages by looking at the page indices that the start and + * end addresses fall in. + * + * Returns 0 if the vec is invalid. It is invalid if the number of bytes + * causes the address to wrap or overflows an unsigned int. This comes + * from being stored in the 'length' member of 'struct scatterlist'. + */ +static unsigned int rds_pages_in_vec(struct rds_iovec *vec) +{ + if ((vec->addr + vec->bytes <= vec->addr) || + (vec->bytes > (u64)UINT_MAX)) + return 0; + + return ((vec->addr + vec->bytes + PAGE_SIZE - 1) >> PAGE_SHIFT) - + (vec->addr >> PAGE_SHIFT); +} + +static struct rds_mr *rds_mr_tree_walk(struct rb_root *root, u64 key, + struct rds_mr *insert) +{ + struct rb_node **p = &root->rb_node; + struct rb_node *parent = NULL; + struct rds_mr *mr; + + while (*p) { + parent = *p; + mr = rb_entry(parent, struct rds_mr, r_rb_node); + + if (key < mr->r_key) + p = &(*p)->rb_left; + else if (key > mr->r_key) + p = &(*p)->rb_right; + else + return mr; + } + + if (insert) { + rb_link_node(&insert->r_rb_node, parent, p); + rb_insert_color(&insert->r_rb_node, root); + kref_get(&insert->r_kref); + } + return NULL; +} + +/* + * Destroy the transport-specific part of a MR. + */ +static void rds_destroy_mr(struct rds_mr *mr) +{ + struct rds_sock *rs = mr->r_sock; + void *trans_private = NULL; + unsigned long flags; + + rdsdebug("RDS: destroy mr key is %x refcnt %u\n", + mr->r_key, kref_read(&mr->r_kref)); + + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + if (!RB_EMPTY_NODE(&mr->r_rb_node)) + rb_erase(&mr->r_rb_node, &rs->rs_rdma_keys); + trans_private = mr->r_trans_private; + mr->r_trans_private = NULL; + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + if (trans_private) + mr->r_trans->free_mr(trans_private, mr->r_invalidate); +} + +void __rds_put_mr_final(struct kref *kref) +{ + struct rds_mr *mr = container_of(kref, struct rds_mr, r_kref); + + rds_destroy_mr(mr); + kfree(mr); +} + +/* + * By the time this is called we can't have any more ioctls called on + * the socket so we don't need to worry about racing with others. + */ +void rds_rdma_drop_keys(struct rds_sock *rs) +{ + struct rds_mr *mr; + struct rb_node *node; + unsigned long flags; + + /* Release any MRs associated with this socket */ + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + while ((node = rb_first(&rs->rs_rdma_keys))) { + mr = rb_entry(node, struct rds_mr, r_rb_node); + if (mr->r_trans == rs->rs_transport) + mr->r_invalidate = 0; + rb_erase(&mr->r_rb_node, &rs->rs_rdma_keys); + RB_CLEAR_NODE(&mr->r_rb_node); + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + kref_put(&mr->r_kref, __rds_put_mr_final); + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + } + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + if (rs->rs_transport && rs->rs_transport->flush_mrs) + rs->rs_transport->flush_mrs(); +} + +/* + * Helper function to pin user pages. 
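Illustrative sketch: rds_pages_in_vec() counts pages by the page indices the first and last byte fall in, so an unaligned vector that straddles a page boundary costs an extra page. The same arithmetic on fixed numbers, with the validity checks omitted (helper name hypothetical):

        static unsigned int example_pages(u64 addr, u64 bytes)
        {
                return ((addr + bytes + PAGE_SIZE - 1) >> PAGE_SHIFT) -
                       (addr >> PAGE_SHIFT);
        }
        /* with 4 KiB pages: example_pages(0x1ff0, 0x20) == 2,
         * while example_pages(0x2000, 0x20) == 1
         */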
+ */ +static int rds_pin_pages(unsigned long user_addr, unsigned int nr_pages, + struct page **pages, int write) +{ + unsigned int gup_flags = FOLL_LONGTERM; + int ret; + + if (write) + gup_flags |= FOLL_WRITE; + + ret = pin_user_pages_fast(user_addr, nr_pages, gup_flags, pages); + if (ret >= 0 && ret < nr_pages) { + unpin_user_pages(pages, ret); + ret = -EFAULT; + } + + return ret; +} + +static int __rds_rdma_map(struct rds_sock *rs, struct rds_get_mr_args *args, + u64 *cookie_ret, struct rds_mr **mr_ret, + struct rds_conn_path *cp) +{ + struct rds_mr *mr = NULL, *found; + struct scatterlist *sg = NULL; + unsigned int nr_pages; + struct page **pages = NULL; + void *trans_private; + unsigned long flags; + rds_rdma_cookie_t cookie; + unsigned int nents = 0; + int need_odp = 0; + long i; + int ret; + + if (ipv6_addr_any(&rs->rs_bound_addr) || !rs->rs_transport) { + ret = -ENOTCONN; /* XXX not a great errno */ + goto out; + } + + if (!rs->rs_transport->get_mr) { + ret = -EOPNOTSUPP; + goto out; + } + + /* If the combination of the addr and size requested for this memory + * region causes an integer overflow, return error. + */ + if (((args->vec.addr + args->vec.bytes) < args->vec.addr) || + PAGE_ALIGN(args->vec.addr + args->vec.bytes) < + (args->vec.addr + args->vec.bytes)) { + ret = -EINVAL; + goto out; + } + + if (!can_do_mlock()) { + ret = -EPERM; + goto out; + } + + nr_pages = rds_pages_in_vec(&args->vec); + if (nr_pages == 0) { + ret = -EINVAL; + goto out; + } + + /* Restrict the size of mr irrespective of underlying transport + * To account for unaligned mr regions, subtract one from nr_pages + */ + if ((nr_pages - 1) > (RDS_MAX_MSG_SIZE >> PAGE_SHIFT)) { + ret = -EMSGSIZE; + goto out; + } + + rdsdebug("RDS: get_mr addr %llx len %llu nr_pages %u\n", + args->vec.addr, args->vec.bytes, nr_pages); + + /* XXX clamp nr_pages to limit the size of this alloc? */ + pages = kcalloc(nr_pages, sizeof(struct page *), GFP_KERNEL); + if (!pages) { + ret = -ENOMEM; + goto out; + } + + mr = kzalloc(sizeof(struct rds_mr), GFP_KERNEL); + if (!mr) { + ret = -ENOMEM; + goto out; + } + + kref_init(&mr->r_kref); + RB_CLEAR_NODE(&mr->r_rb_node); + mr->r_trans = rs->rs_transport; + mr->r_sock = rs; + + if (args->flags & RDS_RDMA_USE_ONCE) + mr->r_use_once = 1; + if (args->flags & RDS_RDMA_INVALIDATE) + mr->r_invalidate = 1; + if (args->flags & RDS_RDMA_READWRITE) + mr->r_write = 1; + + /* + * Pin the pages that make up the user buffer and transfer the page + * pointers to the mr's sg array. We check to see if we've mapped + * the whole region after transferring the partial page references + * to the sg array so that we can have one page ref cleanup path. + * + * For now we have no flag that tells us whether the mapping is + * r/o or r/w. We need to assume r/w, or we'll do a lot of RDMA to + * the zero page. + */ + ret = rds_pin_pages(args->vec.addr, nr_pages, pages, 1); + if (ret == -EOPNOTSUPP) { + need_odp = 1; + } else if (ret <= 0) { + goto out; + } else { + nents = ret; + sg = kmalloc_array(nents, sizeof(*sg), GFP_KERNEL); + if (!sg) { + ret = -ENOMEM; + goto out; + } + WARN_ON(!nents); + sg_init_table(sg, nents); + + /* Stick all pages into the scatterlist */ + for (i = 0 ; i < nents; i++) + sg_set_page(&sg[i], pages[i], PAGE_SIZE, 0); + + rdsdebug("RDS: trans_private nents is %u\n", nents); + } + /* Obtain a transport specific MR. If this succeeds, the + * s/g list is now owned by the MR. + * Note that dma_map() implies that pending writes are + * flushed to RAM, so no dma_sync is needed here. 
*/ + trans_private = rs->rs_transport->get_mr( + sg, nents, rs, &mr->r_key, cp ? cp->cp_conn : NULL, + args->vec.addr, args->vec.bytes, + need_odp ? ODP_ZEROBASED : ODP_NOT_NEEDED); + + if (IS_ERR(trans_private)) { + /* In ODP case, we don't GUP pages, so don't need + * to release anything. + */ + if (!need_odp) { + unpin_user_pages(pages, nr_pages); + kfree(sg); + } + ret = PTR_ERR(trans_private); + goto out; + } + + mr->r_trans_private = trans_private; + + rdsdebug("RDS: get_mr put_user key is %x cookie_addr %p\n", + mr->r_key, (void *)(unsigned long) args->cookie_addr); + + /* The user may pass us an unaligned address, but we can only + * map page aligned regions. So we keep the offset, and build + * a 64bit cookie containing and pass that + * around. */ + if (need_odp) + cookie = rds_rdma_make_cookie(mr->r_key, 0); + else + cookie = rds_rdma_make_cookie(mr->r_key, + args->vec.addr & ~PAGE_MASK); + if (cookie_ret) + *cookie_ret = cookie; + + if (args->cookie_addr && + put_user(cookie, (u64 __user *)(unsigned long)args->cookie_addr)) { + if (!need_odp) { + unpin_user_pages(pages, nr_pages); + kfree(sg); + } + ret = -EFAULT; + goto out; + } + + /* Inserting the new MR into the rbtree bumps its + * reference count. */ + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + found = rds_mr_tree_walk(&rs->rs_rdma_keys, mr->r_key, mr); + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + BUG_ON(found && found != mr); + + rdsdebug("RDS: get_mr key is %x\n", mr->r_key); + if (mr_ret) { + kref_get(&mr->r_kref); + *mr_ret = mr; + } + + ret = 0; +out: + kfree(pages); + if (mr) + kref_put(&mr->r_kref, __rds_put_mr_final); + return ret; +} + +int rds_get_mr(struct rds_sock *rs, sockptr_t optval, int optlen) +{ + struct rds_get_mr_args args; + + if (optlen != sizeof(struct rds_get_mr_args)) + return -EINVAL; + + if (copy_from_sockptr(&args, optval, sizeof(struct rds_get_mr_args))) + return -EFAULT; + + return __rds_rdma_map(rs, &args, NULL, NULL, NULL); +} + +int rds_get_mr_for_dest(struct rds_sock *rs, sockptr_t optval, int optlen) +{ + struct rds_get_mr_for_dest_args args; + struct rds_get_mr_args new_args; + + if (optlen != sizeof(struct rds_get_mr_for_dest_args)) + return -EINVAL; + + if (copy_from_sockptr(&args, optval, + sizeof(struct rds_get_mr_for_dest_args))) + return -EFAULT; + + /* + * Initially, just behave like get_mr(). + * TODO: Implement get_mr as wrapper around this + * and deprecate it. + */ + new_args.vec = args.vec; + new_args.cookie_addr = args.cookie_addr; + new_args.flags = args.flags; + + return __rds_rdma_map(rs, &new_args, NULL, NULL, NULL); +} + +/* + * Free the MR indicated by the given R_Key + */ +int rds_free_mr(struct rds_sock *rs, sockptr_t optval, int optlen) +{ + struct rds_free_mr_args args; + struct rds_mr *mr; + unsigned long flags; + + if (optlen != sizeof(struct rds_free_mr_args)) + return -EINVAL; + + if (copy_from_sockptr(&args, optval, sizeof(struct rds_free_mr_args))) + return -EFAULT; + + /* Special case - a null cookie means flush all unused MRs */ + if (args.cookie == 0) { + if (!rs->rs_transport || !rs->rs_transport->flush_mrs) + return -EINVAL; + rs->rs_transport->flush_mrs(); + return 0; + } + + /* Look up the MR given its R_key and remove it from the rbtree + * so nobody else finds it. + * This should also prevent races with rds_rdma_unuse. 
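Illustrative sketch, from the application side: rds_get_mr() and rds_free_mr() sit behind RDS socket options, so registering a buffer and receiving its cookie is a setsockopt() call. The option and structure names below come from the RDS uapi header and are assumptions here, as is the helper name; error handling is omitted:

        /* user space; assumes <sys/socket.h>, <linux/rds.h>, and a bound RDS socket */
        static int example_register_mr(int fd, void *buf, size_t len, uint64_t *cookie)
        {
                struct rds_get_mr_args args = {
                        .vec         = { .addr  = (uint64_t)(unsigned long)buf,
                                         .bytes = len },
                        .cookie_addr = (uint64_t)(unsigned long)cookie,
                        .flags       = RDS_RDMA_USE_ONCE,
                };

                return setsockopt(fd, SOL_RDS, RDS_GET_MR, &args, sizeof(args));
        }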
+ */ + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + mr = rds_mr_tree_walk(&rs->rs_rdma_keys, rds_rdma_cookie_key(args.cookie), NULL); + if (mr) { + rb_erase(&mr->r_rb_node, &rs->rs_rdma_keys); + RB_CLEAR_NODE(&mr->r_rb_node); + if (args.flags & RDS_RDMA_INVALIDATE) + mr->r_invalidate = 1; + } + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + if (!mr) + return -EINVAL; + + kref_put(&mr->r_kref, __rds_put_mr_final); + return 0; +} + +/* + * This is called when we receive an extension header that + * tells us this MR was used. It allows us to implement + * use_once semantics + */ +void rds_rdma_unuse(struct rds_sock *rs, u32 r_key, int force) +{ + struct rds_mr *mr; + unsigned long flags; + int zot_me = 0; + + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + mr = rds_mr_tree_walk(&rs->rs_rdma_keys, r_key, NULL); + if (!mr) { + pr_debug("rds: trying to unuse MR with unknown r_key %u!\n", + r_key); + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + return; + } + + /* Get a reference so that the MR won't go away before calling + * sync_mr() below. + */ + kref_get(&mr->r_kref); + + /* If it is going to be freed, remove it from the tree now so + * that no other thread can find it and free it. + */ + if (mr->r_use_once || force) { + rb_erase(&mr->r_rb_node, &rs->rs_rdma_keys); + RB_CLEAR_NODE(&mr->r_rb_node); + zot_me = 1; + } + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + /* May have to issue a dma_sync on this memory region. + * Note we could avoid this if the operation was a RDMA READ, + * but at this point we can't tell. */ + if (mr->r_trans->sync_mr) + mr->r_trans->sync_mr(mr->r_trans_private, DMA_FROM_DEVICE); + + /* Release the reference held above. */ + kref_put(&mr->r_kref, __rds_put_mr_final); + + /* If the MR was marked as invalidate, this will + * trigger an async flush. */ + if (zot_me) + kref_put(&mr->r_kref, __rds_put_mr_final); +} + +void rds_rdma_free_op(struct rm_rdma_op *ro) +{ + unsigned int i; + + if (ro->op_odp_mr) { + kref_put(&ro->op_odp_mr->r_kref, __rds_put_mr_final); + } else { + for (i = 0; i < ro->op_nents; i++) { + struct page *page = sg_page(&ro->op_sg[i]); + + /* Mark page dirty if it was possibly modified, which + * is the case for a RDMA_READ which copies from remote + * to local memory + */ + unpin_user_pages_dirty_lock(&page, 1, !ro->op_write); + } + } + + kfree(ro->op_notifier); + ro->op_notifier = NULL; + ro->op_active = 0; + ro->op_odp_mr = NULL; +} + +void rds_atomic_free_op(struct rm_atomic_op *ao) +{ + struct page *page = sg_page(ao->op_sg); + + /* Mark page dirty if it was possibly modified, which + * is the case for a RDMA_READ which copies from remote + * to local memory */ + unpin_user_pages_dirty_lock(&page, 1, true); + + kfree(ao->op_notifier); + ao->op_notifier = NULL; + ao->op_active = 0; +} + + +/* + * Count the number of pages needed to describe an incoming iovec array. + */ +static int rds_rdma_pages(struct rds_iovec iov[], int nr_iovecs) +{ + int tot_pages = 0; + unsigned int nr_pages; + unsigned int i; + + /* figure out the number of pages in the vector */ + for (i = 0; i < nr_iovecs; i++) { + nr_pages = rds_pages_in_vec(&iov[i]); + if (nr_pages == 0) + return -EINVAL; + + tot_pages += nr_pages; + + /* + * nr_pages for one entry is limited to (UINT_MAX>>PAGE_SHIFT)+1, + * so tot_pages cannot overflow without first going negative. 
+ */ + if (tot_pages < 0) + return -EINVAL; + } + + return tot_pages; +} + +int rds_rdma_extra_size(struct rds_rdma_args *args, + struct rds_iov_vector *iov) +{ + struct rds_iovec *vec; + struct rds_iovec __user *local_vec; + int tot_pages = 0; + unsigned int nr_pages; + unsigned int i; + + local_vec = (struct rds_iovec __user *)(unsigned long) args->local_vec_addr; + + if (args->nr_local == 0) + return -EINVAL; + + if (args->nr_local > UIO_MAXIOV) + return -EMSGSIZE; + + iov->iov = kcalloc(args->nr_local, + sizeof(struct rds_iovec), + GFP_KERNEL); + if (!iov->iov) + return -ENOMEM; + + vec = &iov->iov[0]; + + if (copy_from_user(vec, local_vec, args->nr_local * + sizeof(struct rds_iovec))) + return -EFAULT; + iov->len = args->nr_local; + + /* figure out the number of pages in the vector */ + for (i = 0; i < args->nr_local; i++, vec++) { + + nr_pages = rds_pages_in_vec(vec); + if (nr_pages == 0) + return -EINVAL; + + tot_pages += nr_pages; + + /* + * nr_pages for one entry is limited to (UINT_MAX>>PAGE_SHIFT)+1, + * so tot_pages cannot overflow without first going negative. + */ + if (tot_pages < 0) + return -EINVAL; + } + + return tot_pages * sizeof(struct scatterlist); +} + +/* + * The application asks for a RDMA transfer. + * Extract all arguments and set up the rdma_op + */ +int rds_cmsg_rdma_args(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg, + struct rds_iov_vector *vec) +{ + struct rds_rdma_args *args; + struct rm_rdma_op *op = &rm->rdma; + int nr_pages; + unsigned int nr_bytes; + struct page **pages = NULL; + struct rds_iovec *iovs; + unsigned int i, j; + int ret = 0; + bool odp_supported = true; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct rds_rdma_args)) + || rm->rdma.op_active) + return -EINVAL; + + args = CMSG_DATA(cmsg); + + if (ipv6_addr_any(&rs->rs_bound_addr)) { + ret = -ENOTCONN; /* XXX not a great errno */ + goto out_ret; + } + + if (args->nr_local > UIO_MAXIOV) { + ret = -EMSGSIZE; + goto out_ret; + } + + if (vec->len != args->nr_local) { + ret = -EINVAL; + goto out_ret; + } + /* odp-mr is not supported for multiple requests within one message */ + if (args->nr_local != 1) + odp_supported = false; + + iovs = vec->iov; + + nr_pages = rds_rdma_pages(iovs, args->nr_local); + if (nr_pages < 0) { + ret = -EINVAL; + goto out_ret; + } + + pages = kcalloc(nr_pages, sizeof(struct page *), GFP_KERNEL); + if (!pages) { + ret = -ENOMEM; + goto out_ret; + } + + op->op_write = !!(args->flags & RDS_RDMA_READWRITE); + op->op_fence = !!(args->flags & RDS_RDMA_FENCE); + op->op_notify = !!(args->flags & RDS_RDMA_NOTIFY_ME); + op->op_silent = !!(args->flags & RDS_RDMA_SILENT); + op->op_active = 1; + op->op_recverr = rs->rs_recverr; + op->op_odp_mr = NULL; + + WARN_ON(!nr_pages); + op->op_sg = rds_message_alloc_sgs(rm, nr_pages); + if (IS_ERR(op->op_sg)) { + ret = PTR_ERR(op->op_sg); + goto out_pages; + } + + if (op->op_notify || op->op_recverr) { + /* We allocate an uninitialized notifier here, because + * we don't want to do that in the completion handler. We + * would have to use GFP_ATOMIC there, and don't want to deal + * with failed allocations. + */ + op->op_notifier = kmalloc(sizeof(struct rds_notifier), GFP_KERNEL); + if (!op->op_notifier) { + ret = -ENOMEM; + goto out_pages; + } + op->op_notifier->n_user_token = args->user_token; + op->op_notifier->n_status = RDS_RDMA_SUCCESS; + } + + /* The cookie contains the R_Key of the remote memory region, and + * optionally an offset into it. This is how we implement RDMA into + * unaligned memory. 
+ * When setting up the RDMA, we need to add that offset to the + * destination address (which is really an offset into the MR) + * FIXME: We may want to move this into ib_rdma.c + */ + op->op_rkey = rds_rdma_cookie_key(args->cookie); + op->op_remote_addr = args->remote_vec.addr + rds_rdma_cookie_offset(args->cookie); + + nr_bytes = 0; + + rdsdebug("RDS: rdma prepare nr_local %llu rva %llx rkey %x\n", + (unsigned long long)args->nr_local, + (unsigned long long)args->remote_vec.addr, + op->op_rkey); + + for (i = 0; i < args->nr_local; i++) { + struct rds_iovec *iov = &iovs[i]; + /* don't need to check, rds_rdma_pages() verified nr will be +nonzero */ + unsigned int nr = rds_pages_in_vec(iov); + + rs->rs_user_addr = iov->addr; + rs->rs_user_bytes = iov->bytes; + + /* If it's a WRITE operation, we want to pin the pages for reading. + * If it's a READ operation, we need to pin the pages for writing. + */ + ret = rds_pin_pages(iov->addr, nr, pages, !op->op_write); + if ((!odp_supported && ret <= 0) || + (odp_supported && ret <= 0 && ret != -EOPNOTSUPP)) + goto out_pages; + + if (ret == -EOPNOTSUPP) { + struct rds_mr *local_odp_mr; + + if (!rs->rs_transport->get_mr) { + ret = -EOPNOTSUPP; + goto out_pages; + } + local_odp_mr = + kzalloc(sizeof(*local_odp_mr), GFP_KERNEL); + if (!local_odp_mr) { + ret = -ENOMEM; + goto out_pages; + } + RB_CLEAR_NODE(&local_odp_mr->r_rb_node); + kref_init(&local_odp_mr->r_kref); + local_odp_mr->r_trans = rs->rs_transport; + local_odp_mr->r_sock = rs; + local_odp_mr->r_trans_private = + rs->rs_transport->get_mr( + NULL, 0, rs, &local_odp_mr->r_key, NULL, + iov->addr, iov->bytes, ODP_VIRTUAL); + if (IS_ERR(local_odp_mr->r_trans_private)) { + ret = PTR_ERR(local_odp_mr->r_trans_private); + rdsdebug("get_mr ret %d %p\"", ret, + local_odp_mr->r_trans_private); + kfree(local_odp_mr); + ret = -EOPNOTSUPP; + goto out_pages; + } + rdsdebug("Need odp; local_odp_mr %p trans_private %p\n", + local_odp_mr, local_odp_mr->r_trans_private); + op->op_odp_mr = local_odp_mr; + op->op_odp_addr = iov->addr; + } + + rdsdebug("RDS: nr_bytes %u nr %u iov->bytes %llu iov->addr %llx\n", + nr_bytes, nr, iov->bytes, iov->addr); + + nr_bytes += iov->bytes; + + for (j = 0; j < nr; j++) { + unsigned int offset = iov->addr & ~PAGE_MASK; + struct scatterlist *sg; + + sg = &op->op_sg[op->op_nents + j]; + sg_set_page(sg, pages[j], + min_t(unsigned int, iov->bytes, PAGE_SIZE - offset), + offset); + + sg_dma_len(sg) = sg->length; + rdsdebug("RDS: sg->offset %x sg->len %x iov->addr %llx iov->bytes %llu\n", + sg->offset, sg->length, iov->addr, iov->bytes); + + iov->addr += sg->length; + iov->bytes -= sg->length; + } + + op->op_nents += nr; + } + + if (nr_bytes > args->remote_vec.bytes) { + rdsdebug("RDS nr_bytes %u remote_bytes %u do not match\n", + nr_bytes, + (unsigned int) args->remote_vec.bytes); + ret = -EINVAL; + goto out_pages; + } + op->op_bytes = nr_bytes; + ret = 0; + +out_pages: + kfree(pages); +out_ret: + if (ret) + rds_rdma_free_op(op); + else + rds_stats_inc(s_send_rdma); + + return ret; +} + +/* + * The application wants us to pass an RDMA destination (aka MR) + * to the remote + */ +int rds_cmsg_rdma_dest(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg) +{ + unsigned long flags; + struct rds_mr *mr; + u32 r_key; + int err = 0; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(rds_rdma_cookie_t)) || + rm->m_rdma_cookie != 0) + return -EINVAL; + + memcpy(&rm->m_rdma_cookie, CMSG_DATA(cmsg), sizeof(rm->m_rdma_cookie)); + + /* We are reusing a previously mapped MR here. 
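Illustrative sketch, from the application side: rds_cmsg_rdma_args() is fed by an RDS_CMSG_RDMA_ARGS control message carried on sendmsg(). The cmsg and flag names come from the RDS uapi header and are assumptions here; buf, len, cookie and token are placeholders, and msg is expected to already carry the destination, payload iovec and a control buffer of at least CMSG_SPACE(sizeof(args)) bytes:

        static int example_rdma_write(int fd, void *buf, size_t len,
                                      uint64_t cookie, uint64_t token,
                                      struct msghdr *msg)
        {
                struct rds_iovec local = {
                        .addr  = (uint64_t)(unsigned long)buf,
                        .bytes = len,
                };
                struct rds_rdma_args args = {
                        .cookie         = cookie,
                        .remote_vec     = { .addr = 0, .bytes = len },
                        .local_vec_addr = (uint64_t)(unsigned long)&local,
                        .nr_local       = 1,
                        .flags          = RDS_RDMA_READWRITE | RDS_RDMA_NOTIFY_ME,
                        .user_token     = token,
                };
                struct cmsghdr *cmsg = CMSG_FIRSTHDR(msg);

                cmsg->cmsg_level = SOL_RDS;
                cmsg->cmsg_type  = RDS_CMSG_RDMA_ARGS;
                cmsg->cmsg_len   = CMSG_LEN(sizeof(args));
                memcpy(CMSG_DATA(cmsg), &args, sizeof(args));
                msg->msg_controllen = CMSG_SPACE(sizeof(args));

                return sendmsg(fd, msg, 0);
        }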
Most likely, the + * application has written to the buffer, so we need to explicitly + * flush those writes to RAM. Otherwise the HCA may not see them + * when doing a DMA from that buffer. + */ + r_key = rds_rdma_cookie_key(rm->m_rdma_cookie); + + spin_lock_irqsave(&rs->rs_rdma_lock, flags); + mr = rds_mr_tree_walk(&rs->rs_rdma_keys, r_key, NULL); + if (!mr) + err = -EINVAL; /* invalid r_key */ + else + kref_get(&mr->r_kref); + spin_unlock_irqrestore(&rs->rs_rdma_lock, flags); + + if (mr) { + mr->r_trans->sync_mr(mr->r_trans_private, + DMA_TO_DEVICE); + rm->rdma.op_rdma_mr = mr; + } + return err; +} + +/* + * The application passes us an address range it wants to enable RDMA + * to/from. We map the area, and save the pair + * in rm->m_rdma_cookie. This causes it to be sent along to the peer + * in an extension header. + */ +int rds_cmsg_rdma_map(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg) +{ + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct rds_get_mr_args)) || + rm->m_rdma_cookie != 0) + return -EINVAL; + + return __rds_rdma_map(rs, CMSG_DATA(cmsg), &rm->m_rdma_cookie, + &rm->rdma.op_rdma_mr, rm->m_conn_path); +} + +/* + * Fill in rds_message for an atomic request. + */ +int rds_cmsg_atomic(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg) +{ + struct page *page = NULL; + struct rds_atomic_args *args; + int ret = 0; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(struct rds_atomic_args)) + || rm->atomic.op_active) + return -EINVAL; + + args = CMSG_DATA(cmsg); + + /* Nonmasked & masked cmsg ops converted to masked hw ops */ + switch (cmsg->cmsg_type) { + case RDS_CMSG_ATOMIC_FADD: + rm->atomic.op_type = RDS_ATOMIC_TYPE_FADD; + rm->atomic.op_m_fadd.add = args->fadd.add; + rm->atomic.op_m_fadd.nocarry_mask = 0; + break; + case RDS_CMSG_MASKED_ATOMIC_FADD: + rm->atomic.op_type = RDS_ATOMIC_TYPE_FADD; + rm->atomic.op_m_fadd.add = args->m_fadd.add; + rm->atomic.op_m_fadd.nocarry_mask = args->m_fadd.nocarry_mask; + break; + case RDS_CMSG_ATOMIC_CSWP: + rm->atomic.op_type = RDS_ATOMIC_TYPE_CSWP; + rm->atomic.op_m_cswp.compare = args->cswp.compare; + rm->atomic.op_m_cswp.swap = args->cswp.swap; + rm->atomic.op_m_cswp.compare_mask = ~0; + rm->atomic.op_m_cswp.swap_mask = ~0; + break; + case RDS_CMSG_MASKED_ATOMIC_CSWP: + rm->atomic.op_type = RDS_ATOMIC_TYPE_CSWP; + rm->atomic.op_m_cswp.compare = args->m_cswp.compare; + rm->atomic.op_m_cswp.swap = args->m_cswp.swap; + rm->atomic.op_m_cswp.compare_mask = args->m_cswp.compare_mask; + rm->atomic.op_m_cswp.swap_mask = args->m_cswp.swap_mask; + break; + default: + BUG(); /* should never happen */ + } + + rm->atomic.op_notify = !!(args->flags & RDS_RDMA_NOTIFY_ME); + rm->atomic.op_silent = !!(args->flags & RDS_RDMA_SILENT); + rm->atomic.op_active = 1; + rm->atomic.op_recverr = rs->rs_recverr; + rm->atomic.op_sg = rds_message_alloc_sgs(rm, 1); + if (IS_ERR(rm->atomic.op_sg)) { + ret = PTR_ERR(rm->atomic.op_sg); + goto err; + } + + /* verify 8 byte-aligned */ + if (args->local_addr & 0x7) { + ret = -EFAULT; + goto err; + } + + ret = rds_pin_pages(args->local_addr, 1, &page, 1); + if (ret != 1) + goto err; + ret = 0; + + sg_set_page(rm->atomic.op_sg, page, 8, offset_in_page(args->local_addr)); + + if (rm->atomic.op_notify || rm->atomic.op_recverr) { + /* We allocate an uninitialized notifier here, because + * we don't want to do that in the completion handler. We + * would have to use GFP_ATOMIC there, and don't want to deal + * with failed allocations. 
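Illustrative sketch: the switch above folds the plain atomics into their masked hardware forms, so a plain fetch-and-add becomes a masked FADD with a zero nocarry mask, and a plain compare-and-swap becomes a masked CSWP whose masks are all ones. What a masked CSWP computes, shown as ordinary C (hypothetical helper, not the HCA implementation):

        static bool example_masked_cswp(u64 *val, u64 compare, u64 swap,
                                        u64 compare_mask, u64 swap_mask)
        {
                /* with both masks == ~0ULL this is a plain 64-bit CSWP */
                if ((*val & compare_mask) != (compare & compare_mask))
                        return false;
                *val = (*val & ~swap_mask) | (swap & swap_mask);
                return true;
        }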
+ */ + rm->atomic.op_notifier = kmalloc(sizeof(*rm->atomic.op_notifier), GFP_KERNEL); + if (!rm->atomic.op_notifier) { + ret = -ENOMEM; + goto err; + } + + rm->atomic.op_notifier->n_user_token = args->user_token; + rm->atomic.op_notifier->n_status = RDS_RDMA_SUCCESS; + } + + rm->atomic.op_rkey = rds_rdma_cookie_key(args->cookie); + rm->atomic.op_remote_addr = args->remote_addr + rds_rdma_cookie_offset(args->cookie); + + return ret; +err: + if (page) + unpin_user_page(page); + rm->atomic.op_active = 0; + kfree(rm->atomic.op_notifier); + + return ret; +} diff --git a/net/rds/rdma_transport.c b/net/rds/rdma_transport.c new file mode 100644 index 000000000..b15cf316b --- /dev/null +++ b/net/rds/rdma_transport.c @@ -0,0 +1,322 @@ +/* + * Copyright (c) 2009, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include + +#include "rds_single_path.h" +#include "rdma_transport.h" +#include "ib.h" + +/* Global IPv4 and IPv6 RDS RDMA listener cm_id */ +static struct rdma_cm_id *rds_rdma_listen_id; +#if IS_ENABLED(CONFIG_IPV6) +static struct rdma_cm_id *rds6_rdma_listen_id; +#endif + +/* Per IB specification 7.7.3, service level is a 4-bit field. */ +#define TOS_TO_SL(tos) ((tos) & 0xF) + +static int rds_rdma_cm_event_handler_cmn(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event, + bool isv6) +{ + /* this can be null in the listening path */ + struct rds_connection *conn = cm_id->context; + struct rds_transport *trans; + int ret = 0; + int *err; + u8 len; + + rdsdebug("conn %p id %p handling event %u (%s)\n", conn, cm_id, + event->event, rdma_event_msg(event->event)); + + if (cm_id->device->node_type == RDMA_NODE_IB_CA) + trans = &rds_ib_transport; + + /* Prevent shutdown from tearing down the connection + * while we're executing. */ + if (conn) { + mutex_lock(&conn->c_cm_lock); + + /* If the connection is being shut down, bail out + * right away. We return 0 so cm_id doesn't get + * destroyed prematurely */ + if (rds_conn_state(conn) == RDS_CONN_DISCONNECTING) { + /* Reject incoming connections while we're tearing + * down an existing one. 
*/ + if (event->event == RDMA_CM_EVENT_CONNECT_REQUEST) + ret = 1; + goto out; + } + } + + switch (event->event) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + ret = trans->cm_handle_connect(cm_id, event, isv6); + break; + + case RDMA_CM_EVENT_ADDR_RESOLVED: + if (conn) { + rdma_set_service_type(cm_id, conn->c_tos); + rdma_set_min_rnr_timer(cm_id, IB_RNR_TIMER_000_32); + /* XXX do we need to clean up if this fails? */ + ret = rdma_resolve_route(cm_id, + RDS_RDMA_RESOLVE_TIMEOUT_MS); + } + break; + + case RDMA_CM_EVENT_ROUTE_RESOLVED: + /* Connection could have been dropped so make sure the + * cm_id is valid before proceeding + */ + if (conn) { + struct rds_ib_connection *ibic; + + ibic = conn->c_transport_data; + if (ibic && ibic->i_cm_id == cm_id) { + cm_id->route.path_rec[0].sl = + TOS_TO_SL(conn->c_tos); + ret = trans->cm_initiate_connect(cm_id, isv6); + } else { + rds_conn_drop(conn); + } + } + break; + + case RDMA_CM_EVENT_ESTABLISHED: + if (conn) + trans->cm_connect_complete(conn, event); + break; + + case RDMA_CM_EVENT_REJECTED: + if (!conn) + break; + err = (int *)rdma_consumer_reject_data(cm_id, event, &len); + if (!err || + (err && len >= sizeof(*err) && + ((*err) <= RDS_RDMA_REJ_INCOMPAT))) { + pr_warn("RDS/RDMA: conn <%pI6c, %pI6c> rejected, dropping connection\n", + &conn->c_laddr, &conn->c_faddr); + + if (!conn->c_tos) + conn->c_proposed_version = RDS_PROTOCOL_COMPAT_VERSION; + + rds_conn_drop(conn); + } + rdsdebug("Connection rejected: %s\n", + rdma_reject_msg(cm_id, event->status)); + break; + case RDMA_CM_EVENT_ADDR_ERROR: + case RDMA_CM_EVENT_ROUTE_ERROR: + case RDMA_CM_EVENT_CONNECT_ERROR: + case RDMA_CM_EVENT_UNREACHABLE: + case RDMA_CM_EVENT_DEVICE_REMOVAL: + case RDMA_CM_EVENT_ADDR_CHANGE: + if (conn) + rds_conn_drop(conn); + break; + + case RDMA_CM_EVENT_DISCONNECTED: + if (!conn) + break; + rdsdebug("DISCONNECT event - dropping connection " + "%pI6c->%pI6c\n", &conn->c_laddr, + &conn->c_faddr); + rds_conn_drop(conn); + break; + + case RDMA_CM_EVENT_TIMEWAIT_EXIT: + if (conn) { + pr_info("RDS: RDMA_CM_EVENT_TIMEWAIT_EXIT event: dropping connection %pI6c->%pI6c\n", + &conn->c_laddr, &conn->c_faddr); + rds_conn_drop(conn); + } + break; + + default: + /* things like device disconnect? */ + printk(KERN_ERR "RDS: unknown event %u (%s)!\n", + event->event, rdma_event_msg(event->event)); + break; + } + +out: + if (conn) + mutex_unlock(&conn->c_cm_lock); + + rdsdebug("id %p event %u (%s) handling ret %d\n", cm_id, event->event, + rdma_event_msg(event->event), ret); + + return ret; +} + +int rds_rdma_cm_event_handler(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event) +{ + return rds_rdma_cm_event_handler_cmn(cm_id, event, false); +} + +#if IS_ENABLED(CONFIG_IPV6) +int rds6_rdma_cm_event_handler(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event) +{ + return rds_rdma_cm_event_handler_cmn(cm_id, event, true); +} +#endif + +static int rds_rdma_listen_init_common(rdma_cm_event_handler handler, + struct sockaddr *sa, + struct rdma_cm_id **ret_cm_id) +{ + struct rdma_cm_id *cm_id; + int ret; + + cm_id = rdma_create_id(&init_net, handler, NULL, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(cm_id)) { + ret = PTR_ERR(cm_id); + printk(KERN_ERR "RDS/RDMA: failed to setup listener, " + "rdma_create_id() returned %d\n", ret); + return ret; + } + + /* + * XXX I bet this binds the cm_id to a device. If we want to support + * fail-over we'll have to take this into consideration. 
+ */ + ret = rdma_bind_addr(cm_id, sa); + if (ret) { + printk(KERN_ERR "RDS/RDMA: failed to setup listener, " + "rdma_bind_addr() returned %d\n", ret); + goto out; + } + + ret = rdma_listen(cm_id, 128); + if (ret) { + printk(KERN_ERR "RDS/RDMA: failed to setup listener, " + "rdma_listen() returned %d\n", ret); + goto out; + } + + rdsdebug("cm %p listening on port %u\n", cm_id, RDS_PORT); + + *ret_cm_id = cm_id; + cm_id = NULL; +out: + if (cm_id) + rdma_destroy_id(cm_id); + return ret; +} + +/* Initialize the RDS RDMA listeners. We create two listeners for + * compatibility reason. The one on RDS_PORT is used for IPv4 + * requests only. The one on RDS_CM_PORT is used for IPv6 requests + * only. So only IPv6 enabled RDS module will communicate using this + * port. + */ +static int rds_rdma_listen_init(void) +{ + int ret; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6; +#endif + struct sockaddr_in sin; + + sin.sin_family = PF_INET; + sin.sin_addr.s_addr = htonl(INADDR_ANY); + sin.sin_port = htons(RDS_PORT); + ret = rds_rdma_listen_init_common(rds_rdma_cm_event_handler, + (struct sockaddr *)&sin, + &rds_rdma_listen_id); + if (ret != 0) + return ret; + +#if IS_ENABLED(CONFIG_IPV6) + sin6.sin6_family = PF_INET6; + sin6.sin6_addr = in6addr_any; + sin6.sin6_port = htons(RDS_CM_PORT); + sin6.sin6_scope_id = 0; + sin6.sin6_flowinfo = 0; + ret = rds_rdma_listen_init_common(rds6_rdma_cm_event_handler, + (struct sockaddr *)&sin6, + &rds6_rdma_listen_id); + /* Keep going even when IPv6 is not enabled in the system. */ + if (ret != 0) + rdsdebug("Cannot set up IPv6 RDMA listener\n"); +#endif + return 0; +} + +static void rds_rdma_listen_stop(void) +{ + if (rds_rdma_listen_id) { + rdsdebug("cm %p\n", rds_rdma_listen_id); + rdma_destroy_id(rds_rdma_listen_id); + rds_rdma_listen_id = NULL; + } +#if IS_ENABLED(CONFIG_IPV6) + if (rds6_rdma_listen_id) { + rdsdebug("cm %p\n", rds6_rdma_listen_id); + rdma_destroy_id(rds6_rdma_listen_id); + rds6_rdma_listen_id = NULL; + } +#endif +} + +static int __init rds_rdma_init(void) +{ + int ret; + + ret = rds_ib_init(); + if (ret) + goto out; + + ret = rds_rdma_listen_init(); + if (ret) + rds_ib_exit(); +out: + return ret; +} +module_init(rds_rdma_init); + +static void __exit rds_rdma_exit(void) +{ + /* stop listening first to ensure no new connections are attempted */ + rds_rdma_listen_stop(); + rds_ib_exit(); +} +module_exit(rds_rdma_exit); + +MODULE_AUTHOR("Oracle Corporation "); +MODULE_DESCRIPTION("RDS: IB transport"); +MODULE_LICENSE("Dual BSD/GPL"); diff --git a/net/rds/rdma_transport.h b/net/rds/rdma_transport.h new file mode 100644 index 000000000..ca4c3a667 --- /dev/null +++ b/net/rds/rdma_transport.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDMA_TRANSPORT_H +#define _RDMA_TRANSPORT_H + +#include +#include +#include "rds.h" + +/* RDMA_CM also uses 16385 as the listener port. */ +#define RDS_CM_PORT 16385 + +#define RDS_RDMA_RESOLVE_TIMEOUT_MS 5000 + +/* Below reject reason is for legacy interoperability issue with non-linux + * RDS endpoints where older version incompatibility is conveyed via value 1. + * For future version(s), proper encoded reject reason should be used. 
+ */ +#define RDS_RDMA_REJ_INCOMPAT 1 + +int rds_rdma_conn_connect(struct rds_connection *conn); +int rds_rdma_cm_event_handler(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event); +int rds6_rdma_cm_event_handler(struct rdma_cm_id *cm_id, + struct rdma_cm_event *event); + +/* from ib.c */ +extern struct rds_transport rds_ib_transport; +int rds_ib_init(void); +void rds_ib_exit(void); + +#endif diff --git a/net/rds/rds.h b/net/rds/rds.h new file mode 100644 index 000000000..d35d1fc39 --- /dev/null +++ b/net/rds/rds.h @@ -0,0 +1,1019 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_RDS_H +#define _RDS_RDS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "info.h" + +/* + * RDS Network protocol version + */ +#define RDS_PROTOCOL_3_0 0x0300 +#define RDS_PROTOCOL_3_1 0x0301 +#define RDS_PROTOCOL_4_0 0x0400 +#define RDS_PROTOCOL_4_1 0x0401 +#define RDS_PROTOCOL_VERSION RDS_PROTOCOL_3_1 +#define RDS_PROTOCOL_MAJOR(v) ((v) >> 8) +#define RDS_PROTOCOL_MINOR(v) ((v) & 255) +#define RDS_PROTOCOL(maj, min) (((maj) << 8) | min) +#define RDS_PROTOCOL_COMPAT_VERSION RDS_PROTOCOL_3_1 + +/* The following ports, 16385, 18634, 18635, are registered with IANA as + * the ports to be used for RDS over TCP and UDP. Currently, only RDS over + * TCP and RDS over IB/RDMA are implemented. 18634 is the historical value + * used for the RDMA_CM listener port. RDS/TCP uses port 16385. After + * IPv6 work, RDMA_CM also uses 16385 as the listener port. 18634 is kept + * to ensure compatibility with older RDS modules. Those ports are defined + * in each transport's header file. + */ +#define RDS_PORT 18634 + +#ifdef ATOMIC64_INIT +#define KERNEL_HAS_ATOMIC64 +#endif +#ifdef RDS_DEBUG +#define rdsdebug(fmt, args...) pr_debug("%s(): " fmt, __func__ , ##args) +#else +/* sigh, pr_debug() causes unused variable warnings */ +static inline __printf(1, 2) +void rdsdebug(char *fmt, ...) +{ +} +#endif + +#define RDS_FRAG_SHIFT 12 +#define RDS_FRAG_SIZE ((unsigned int)(1 << RDS_FRAG_SHIFT)) + +/* Used to limit both RDMA and non-RDMA RDS message to 1MB */ +#define RDS_MAX_MSG_SIZE ((unsigned int)(1 << 20)) + +#define RDS_CONG_MAP_BYTES (65536 / 8) +#define RDS_CONG_MAP_PAGES (PAGE_ALIGN(RDS_CONG_MAP_BYTES) / PAGE_SIZE) +#define RDS_CONG_MAP_PAGE_BITS (PAGE_SIZE * 8) + +struct rds_cong_map { + struct rb_node m_rb_node; + struct in6_addr m_addr; + wait_queue_head_t m_waitq; + struct list_head m_conn_list; + unsigned long m_page_addrs[RDS_CONG_MAP_PAGES]; +}; + + +/* + * This is how we will track the connection state: + * A connection is always in one of the following + * states. Updates to the state are atomic and imply + * a memory barrier. + */ +enum { + RDS_CONN_DOWN = 0, + RDS_CONN_CONNECTING, + RDS_CONN_DISCONNECTING, + RDS_CONN_UP, + RDS_CONN_RESETTING, + RDS_CONN_ERROR, +}; + +/* Bits for c_flags */ +#define RDS_LL_SEND_FULL 0 +#define RDS_RECONNECT_PENDING 1 +#define RDS_IN_XMIT 2 +#define RDS_RECV_REFILL 3 +#define RDS_DESTROY_PENDING 4 + +/* Max number of multipaths per RDS connection. 
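Illustrative check: the protocol version packs the major number in the high byte and the minor in the low byte, so RDS_PROTOCOL(4, 1) is 0x0401 and the accessors simply undo the shift and mask. Compile-time assertions restating that arithmetic:

        /* illustrative only; the values follow directly from the macros above */
        static_assert(RDS_PROTOCOL(4, 1) == RDS_PROTOCOL_4_1);
        static_assert(RDS_PROTOCOL_MAJOR(RDS_PROTOCOL_3_1) == 3);
        static_assert(RDS_PROTOCOL_MINOR(RDS_PROTOCOL_3_1) == 1);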
Must be a power of 2 */ +#define RDS_MPATH_WORKERS 8 +#define RDS_MPATH_HASH(rs, n) (jhash_1word((rs)->rs_bound_port, \ + (rs)->rs_hash_initval) & ((n) - 1)) + +#define IS_CANONICAL(laddr, faddr) (htonl(laddr) < htonl(faddr)) + +/* Per mpath connection state */ +struct rds_conn_path { + struct rds_connection *cp_conn; + struct rds_message *cp_xmit_rm; + unsigned long cp_xmit_sg; + unsigned int cp_xmit_hdr_off; + unsigned int cp_xmit_data_off; + unsigned int cp_xmit_atomic_sent; + unsigned int cp_xmit_rdma_sent; + unsigned int cp_xmit_data_sent; + + spinlock_t cp_lock; /* protect msg queues */ + u64 cp_next_tx_seq; + struct list_head cp_send_queue; + struct list_head cp_retrans; + + u64 cp_next_rx_seq; + + void *cp_transport_data; + + atomic_t cp_state; + unsigned long cp_send_gen; + unsigned long cp_flags; + unsigned long cp_reconnect_jiffies; + struct delayed_work cp_send_w; + struct delayed_work cp_recv_w; + struct delayed_work cp_conn_w; + struct work_struct cp_down_w; + struct mutex cp_cm_lock; /* protect cp_state & cm */ + wait_queue_head_t cp_waitq; + + unsigned int cp_unacked_packets; + unsigned int cp_unacked_bytes; + unsigned int cp_index; +}; + +/* One rds_connection per RDS address pair */ +struct rds_connection { + struct hlist_node c_hash_node; + struct in6_addr c_laddr; + struct in6_addr c_faddr; + int c_dev_if; /* ifindex used for this conn */ + int c_bound_if; /* ifindex of c_laddr */ + unsigned int c_loopback:1, + c_isv6:1, + c_ping_triggered:1, + c_pad_to_32:29; + int c_npaths; + struct rds_connection *c_passive; + struct rds_transport *c_trans; + + struct rds_cong_map *c_lcong; + struct rds_cong_map *c_fcong; + + /* Protocol version */ + unsigned int c_proposed_version; + unsigned int c_version; + possible_net_t c_net; + + /* TOS */ + u8 c_tos; + + struct list_head c_map_item; + unsigned long c_map_queued; + + struct rds_conn_path *c_path; + wait_queue_head_t c_hs_waitq; /* handshake waitq */ + + u32 c_my_gen_num; + u32 c_peer_gen_num; +}; + +static inline +struct net *rds_conn_net(struct rds_connection *conn) +{ + return read_pnet(&conn->c_net); +} + +static inline +void rds_conn_net_set(struct rds_connection *conn, struct net *net) +{ + write_pnet(&conn->c_net, net); +} + +#define RDS_FLAG_CONG_BITMAP 0x01 +#define RDS_FLAG_ACK_REQUIRED 0x02 +#define RDS_FLAG_RETRANSMITTED 0x04 +#define RDS_MAX_ADV_CREDIT 255 + +/* RDS_FLAG_PROBE_PORT is the reserved sport used for sending a ping + * probe to exchange control information before establishing a connection. + * Currently the control information that is exchanged is the number of + * supported paths. If the peer is a legacy (older kernel revision) peer, + * it would return a pong message without additional control information + * that would then alert the sender that the peer was an older rev. + */ +#define RDS_FLAG_PROBE_PORT 1 +#define RDS_HS_PROBE(sport, dport) \ + ((sport == RDS_FLAG_PROBE_PORT && dport == 0) || \ + (sport == 0 && dport == RDS_FLAG_PROBE_PORT)) +/* + * Maximum space available for extension headers. + */ +#define RDS_HEADER_EXT_SPACE 16 + +struct rds_header { + __be64 h_sequence; + __be64 h_ack; + __be32 h_len; + __be16 h_sport; + __be16 h_dport; + u8 h_flags; + u8 h_credit; + u8 h_padding[4]; + __sum16 h_csum; + + u8 h_exthdr[RDS_HEADER_EXT_SPACE]; +}; + +/* + * Reserved - indicates end of extensions + */ +#define RDS_EXTHDR_NONE 0 + +/* + * This extension header is included in the very + * first message that is sent on a new connection, + * and identifies the protocol level. 
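Illustrative check: struct rds_header is a wire format whose fields add up to 8 + 8 + 4 + 2 + 2 + 1 + 1 + 4 + 2 + 16 = 48 bytes with no compiler padding, so every RDS fragment starts with a fixed 48-byte header. A compile-time restatement of that arithmetic (an assumption derived from the field sizes above):

        static_assert(sizeof(struct rds_header) == 48);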
This will help + * rolling updates if a future change requires breaking + * the protocol. + * NB: This is no longer true for IB, where we do a version + * negotiation during the connection setup phase (protocol + * version information is included in the RDMA CM private data). + */ +#define RDS_EXTHDR_VERSION 1 +struct rds_ext_header_version { + __be32 h_version; +}; + +/* + * This extension header is included in the RDS message + * chasing an RDMA operation. + */ +#define RDS_EXTHDR_RDMA 2 +struct rds_ext_header_rdma { + __be32 h_rdma_rkey; +}; + +/* + * This extension header tells the peer about the + * destination of the requested RDMA + * operation. + */ +#define RDS_EXTHDR_RDMA_DEST 3 +struct rds_ext_header_rdma_dest { + __be32 h_rdma_rkey; + __be32 h_rdma_offset; +}; + +/* Extension header announcing number of paths. + * Implicit length = 2 bytes. + */ +#define RDS_EXTHDR_NPATHS 5 +#define RDS_EXTHDR_GEN_NUM 6 + +#define __RDS_EXTHDR_MAX 16 /* for now */ +#define RDS_RX_MAX_TRACES (RDS_MSG_RX_DGRAM_TRACE_MAX + 1) +#define RDS_MSG_RX_HDR 0 +#define RDS_MSG_RX_START 1 +#define RDS_MSG_RX_END 2 +#define RDS_MSG_RX_CMSG 3 + +/* The following values are whitelisted for usercopy */ +struct rds_inc_usercopy { + rds_rdma_cookie_t rdma_cookie; + ktime_t rx_tstamp; +}; + +struct rds_incoming { + refcount_t i_refcount; + struct list_head i_item; + struct rds_connection *i_conn; + struct rds_conn_path *i_conn_path; + struct rds_header i_hdr; + unsigned long i_rx_jiffies; + struct in6_addr i_saddr; + + struct rds_inc_usercopy i_usercopy; + u64 i_rx_lat_trace[RDS_RX_MAX_TRACES]; +}; + +struct rds_mr { + struct rb_node r_rb_node; + struct kref r_kref; + u32 r_key; + + /* A copy of the creation flags */ + unsigned int r_use_once:1; + unsigned int r_invalidate:1; + unsigned int r_write:1; + + struct rds_sock *r_sock; /* back pointer to the socket that owns us */ + struct rds_transport *r_trans; + void *r_trans_private; +}; + +static inline rds_rdma_cookie_t rds_rdma_make_cookie(u32 r_key, u32 offset) +{ + return r_key | (((u64) offset) << 32); +} + +static inline u32 rds_rdma_cookie_key(rds_rdma_cookie_t cookie) +{ + return cookie; +} + +static inline u32 rds_rdma_cookie_offset(rds_rdma_cookie_t cookie) +{ + return cookie >> 32; +} + +/* atomic operation types */ +#define RDS_ATOMIC_TYPE_CSWP 0 +#define RDS_ATOMIC_TYPE_FADD 1 + +/* + * m_sock_item and m_conn_item are on lists that are serialized under + * conn->c_lock. m_sock_item has additional meaning in that once it is empty + * the message will not be put back on the retransmit list after being sent. + * messages that are canceled while being sent rely on this. + * + * m_inc is used by loopback so that it can pass an incoming message straight + * back up into the rx path. It embeds a wire header which is also used by + * the send path, which is kind of awkward. + * + * m_sock_item indicates the message's presence on a socket's send or receive + * queue. m_rs will point to that socket. + * + * m_daddr is used by cancellation to prune messages to a given destination. + * + * The RDS_MSG_ON_SOCK and RDS_MSG_ON_CONN flags are used to avoid lock + * nesting. As paths iterate over messages on a sock, or conn, they must + * also lock the conn, or sock, to remove the message from those lists too. + * Testing the flag to determine if the message is still on the lists lets + * us avoid testing the list_head directly. That means each path can use + * the message's list_head to keep it on a local list while juggling locks + * without confusing the other path. 
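Illustrative sketch: an rds_rdma_cookie_t carries the R_Key in its low 32 bits and the byte offset in its high 32 bits, so a key/offset pair survives a round trip through the helpers above. A small check (function name hypothetical):

        static void example_cookie_roundtrip(void)
        {
                rds_rdma_cookie_t c = rds_rdma_make_cookie(0x1234, 0x0ff8);

                WARN_ON(rds_rdma_cookie_key(c)    != 0x1234);
                WARN_ON(rds_rdma_cookie_offset(c) != 0x0ff8);
        }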
+ * + * m_ack_seq is an optional field set by transports who need a different + * sequence number range to invalidate. They can use this in a callback + * that they pass to rds_send_drop_acked() to see if each message has been + * acked. The HAS_ACK_SEQ flag can be used to detect messages which haven't + * had ack_seq set yet. + */ +#define RDS_MSG_ON_SOCK 1 +#define RDS_MSG_ON_CONN 2 +#define RDS_MSG_HAS_ACK_SEQ 3 +#define RDS_MSG_ACK_REQUIRED 4 +#define RDS_MSG_RETRANSMITTED 5 +#define RDS_MSG_MAPPED 6 +#define RDS_MSG_PAGEVEC 7 +#define RDS_MSG_FLUSH 8 + +struct rds_znotifier { + struct mmpin z_mmp; + u32 z_cookie; +}; + +struct rds_msg_zcopy_info { + struct list_head rs_zcookie_next; + union { + struct rds_znotifier znotif; + struct rds_zcopy_cookies zcookies; + }; +}; + +struct rds_msg_zcopy_queue { + struct list_head zcookie_head; + spinlock_t lock; /* protects zcookie_head queue */ +}; + +static inline void rds_message_zcopy_queue_init(struct rds_msg_zcopy_queue *q) +{ + spin_lock_init(&q->lock); + INIT_LIST_HEAD(&q->zcookie_head); +} + +struct rds_iov_vector { + struct rds_iovec *iov; + int len; +}; + +struct rds_iov_vector_arr { + struct rds_iov_vector *vec; + int len; + int indx; + int incr; +}; + +struct rds_message { + refcount_t m_refcount; + struct list_head m_sock_item; + struct list_head m_conn_item; + struct rds_incoming m_inc; + u64 m_ack_seq; + struct in6_addr m_daddr; + unsigned long m_flags; + + /* Never access m_rs without holding m_rs_lock. + * Lock nesting is + * rm->m_rs_lock + * -> rs->rs_lock + */ + spinlock_t m_rs_lock; + wait_queue_head_t m_flush_wait; + + struct rds_sock *m_rs; + + /* cookie to send to remote, in rds header */ + rds_rdma_cookie_t m_rdma_cookie; + + unsigned int m_used_sgs; + unsigned int m_total_sgs; + + void *m_final_op; + + struct { + struct rm_atomic_op { + int op_type; + union { + struct { + uint64_t compare; + uint64_t swap; + uint64_t compare_mask; + uint64_t swap_mask; + } op_m_cswp; + struct { + uint64_t add; + uint64_t nocarry_mask; + } op_m_fadd; + }; + + u32 op_rkey; + u64 op_remote_addr; + unsigned int op_notify:1; + unsigned int op_recverr:1; + unsigned int op_mapped:1; + unsigned int op_silent:1; + unsigned int op_active:1; + struct scatterlist *op_sg; + struct rds_notifier *op_notifier; + + struct rds_mr *op_rdma_mr; + } atomic; + struct rm_rdma_op { + u32 op_rkey; + u64 op_remote_addr; + unsigned int op_write:1; + unsigned int op_fence:1; + unsigned int op_notify:1; + unsigned int op_recverr:1; + unsigned int op_mapped:1; + unsigned int op_silent:1; + unsigned int op_active:1; + unsigned int op_bytes; + unsigned int op_nents; + unsigned int op_count; + struct scatterlist *op_sg; + struct rds_notifier *op_notifier; + + struct rds_mr *op_rdma_mr; + + u64 op_odp_addr; + struct rds_mr *op_odp_mr; + } rdma; + struct rm_data_op { + unsigned int op_active:1; + unsigned int op_nents; + unsigned int op_count; + unsigned int op_dmasg; + unsigned int op_dmaoff; + struct rds_znotifier *op_mmp_znotifier; + struct scatterlist *op_sg; + } data; + }; + + struct rds_conn_path *m_conn_path; +}; + +/* + * The RDS notifier is used (optionally) to tell the application about + * completed RDMA operations. Rather than keeping the whole rds message + * around on the queue, we allocate a small notifier that is put on the + * socket's notifier_list. Notifications are delivered to the application + * through control messages. 
+ */
+struct rds_notifier {
+ struct list_head n_list;
+ uint64_t n_user_token;
+ int n_status;
+};
+
+/* Available as part of RDS core, so doesn't need to participate
+ * in get_preferred transport etc
+ */
+#define RDS_TRANS_LOOP 3
+
+/**
+ * struct rds_transport - transport specific behavioural hooks
+ *
+ * @xmit: .xmit is called by rds_send_xmit() to tell the transport to send
+ * part of a message. The caller serializes on the send_sem so this
+ * doesn't need to be reentrant for a given conn. The header must be
+ * sent before the data payload. .xmit must be prepared to send a
+ * message with no data payload. .xmit should return the number of
+ * bytes that were sent down the connection, including header bytes.
+ * Returning 0 tells the caller that it doesn't need to perform any
+ * additional work now. This is usually the case when the transport has
+ * filled the sending queue for its connection and will handle
+ * triggering the rds thread to continue the send when space becomes
+ * available. Returning -EAGAIN tells the caller to retry the send
+ * immediately. Returning -ENOMEM tells the caller to retry the send at
+ * some point in the future.
+ *
+ * @conn_shutdown: conn_shutdown stops traffic on the given connection. Once
+ * it returns the connection can not call rds_recv_incoming().
+ * This will only be called once after conn_connect returns
+ * non-zero success. The caller serializes this with
+ * the send and connecting paths (xmit_* and conn_*). The
+ * transport is responsible for other serialization, including
+ * rds_recv_incoming(). This is called in process context but
+ * should try hard not to block.
+ */
+
+struct rds_transport {
+ char t_name[TRANSNAMSIZ];
+ struct list_head t_item;
+ struct module *t_owner;
+ unsigned int t_prefer_loopback:1,
+ t_mp_capable:1;
+ unsigned int t_type;
+
+ int (*laddr_check)(struct net *net, const struct in6_addr *addr,
+ __u32 scope_id);
+ int (*conn_alloc)(struct rds_connection *conn, gfp_t gfp);
+ void (*conn_free)(void *data);
+ int (*conn_path_connect)(struct rds_conn_path *cp);
+ void (*conn_path_shutdown)(struct rds_conn_path *conn);
+ void (*xmit_path_prepare)(struct rds_conn_path *cp);
+ void (*xmit_path_complete)(struct rds_conn_path *cp);
+ int (*xmit)(struct rds_connection *conn, struct rds_message *rm,
+ unsigned int hdr_off, unsigned int sg, unsigned int off);
+ int (*xmit_rdma)(struct rds_connection *conn, struct rm_rdma_op *op);
+ int (*xmit_atomic)(struct rds_connection *conn, struct rm_atomic_op *op);
+ int (*recv_path)(struct rds_conn_path *cp);
+ int (*inc_copy_to_user)(struct rds_incoming *inc, struct iov_iter *to);
+ void (*inc_free)(struct rds_incoming *inc);
+
+ int (*cm_handle_connect)(struct rdma_cm_id *cm_id,
+ struct rdma_cm_event *event, bool isv6);
+ int (*cm_initiate_connect)(struct rdma_cm_id *cm_id, bool isv6);
+ void (*cm_connect_complete)(struct rds_connection *conn,
+ struct rdma_cm_event *event);
+
+ unsigned int (*stats_info_copy)(struct rds_info_iterator *iter,
+ unsigned int avail);
+ void (*exit)(void);
+ void *(*get_mr)(struct scatterlist *sg, unsigned long nr_sg,
+ struct rds_sock *rs, u32 *key_ret,
+ struct rds_connection *conn,
+ u64 start, u64 length, int need_odp);
+ void (*sync_mr)(void *trans_private, int direction);
+ void (*free_mr)(void *trans_private, int invalidate);
+ void (*flush_mrs)(void);
+ bool (*t_unloading)(struct rds_connection *conn);
+ u8 (*get_tos_map)(u8 tos);
+};
+
+/* Bind hash table key length.
It is the sum of the size of a struct + * in6_addr, a scope_id and a port. + */ +#define RDS_BOUND_KEY_LEN \ + (sizeof(struct in6_addr) + sizeof(__u32) + sizeof(__be16)) + +struct rds_sock { + struct sock rs_sk; + + u64 rs_user_addr; + u64 rs_user_bytes; + + /* + * bound_addr used for both incoming and outgoing, no INADDR_ANY + * support. + */ + struct rhash_head rs_bound_node; + u8 rs_bound_key[RDS_BOUND_KEY_LEN]; + struct sockaddr_in6 rs_bound_sin6; +#define rs_bound_addr rs_bound_sin6.sin6_addr +#define rs_bound_addr_v4 rs_bound_sin6.sin6_addr.s6_addr32[3] +#define rs_bound_port rs_bound_sin6.sin6_port +#define rs_bound_scope_id rs_bound_sin6.sin6_scope_id + struct in6_addr rs_conn_addr; +#define rs_conn_addr_v4 rs_conn_addr.s6_addr32[3] + __be16 rs_conn_port; + struct rds_transport *rs_transport; + + /* + * rds_sendmsg caches the conn it used the last time around. + * This helps avoid costly lookups. + */ + struct rds_connection *rs_conn; + + /* flag indicating we were congested or not */ + int rs_congested; + /* seen congestion (ENOBUFS) when sending? */ + int rs_seen_congestion; + + /* rs_lock protects all these adjacent members before the newline */ + spinlock_t rs_lock; + struct list_head rs_send_queue; + u32 rs_snd_bytes; + int rs_rcv_bytes; + struct list_head rs_notify_queue; /* currently used for failed RDMAs */ + + /* Congestion wake_up. If rs_cong_monitor is set, we use cong_mask + * to decide whether the application should be woken up. + * If not set, we use rs_cong_track to find out whether a cong map + * update arrived. + */ + uint64_t rs_cong_mask; + uint64_t rs_cong_notify; + struct list_head rs_cong_list; + unsigned long rs_cong_track; + + /* + * rs_recv_lock protects the receive queue, and is + * used to serialize with rds_release. + */ + rwlock_t rs_recv_lock; + struct list_head rs_recv_queue; + + /* just for stats reporting */ + struct list_head rs_item; + + /* these have their own lock */ + spinlock_t rs_rdma_lock; + struct rb_root rs_rdma_keys; + + /* Socket options - in case there will be more */ + unsigned char rs_recverr, + rs_cong_monitor; + u32 rs_hash_initval; + + /* Socket receive path trace points*/ + u8 rs_rx_traces; + u8 rs_rx_trace[RDS_MSG_RX_DGRAM_TRACE_MAX]; + struct rds_msg_zcopy_queue rs_zcookie_queue; + u8 rs_tos; +}; + +static inline struct rds_sock *rds_sk_to_rs(const struct sock *sk) +{ + return container_of(sk, struct rds_sock, rs_sk); +} +static inline struct sock *rds_rs_to_sk(struct rds_sock *rs) +{ + return &rs->rs_sk; +} + +/* + * The stack assigns sk_sndbuf and sk_rcvbuf to twice the specified value + * to account for overhead. We don't account for overhead, we just apply + * the number of payload bytes to the specified value. 
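RDS_BOUND_KEY_LEN above sizes the flat key used by the bind hash table: an in6_addr, a scope id and a port laid out back to back. The sketch below shows one way such a key could be assembled; the field order and the use of memcpy are illustrative assumptions, since the real key construction lives in net/rds/bind.c, which is not part of this hunk.

#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdint.h>
#include <string.h>

#define RDS_BOUND_KEY_LEN \
        (sizeof(struct in6_addr) + sizeof(uint32_t) + sizeof(uint16_t))

/* Concatenate address, scope_id and port (network byte order) into one
 * flat lookup key. The ordering of the three fields is illustrative only. */
static void make_bound_key(uint8_t key[RDS_BOUND_KEY_LEN],
                           const struct in6_addr *addr,
                           uint32_t scope_id, uint16_t port_be)
{
        memcpy(key, addr, sizeof(*addr));
        memcpy(key + sizeof(*addr), &scope_id, sizeof(scope_id));
        memcpy(key + sizeof(*addr) + sizeof(scope_id), &port_be, sizeof(port_be));
}

int main(void)
{
        uint8_t key[RDS_BOUND_KEY_LEN];

        make_bound_key(key, &in6addr_loopback, 0, htons(4000));
        return 0;
}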
+ */ +static inline int rds_sk_sndbuf(struct rds_sock *rs) +{ + return rds_rs_to_sk(rs)->sk_sndbuf / 2; +} +static inline int rds_sk_rcvbuf(struct rds_sock *rs) +{ + return rds_rs_to_sk(rs)->sk_rcvbuf / 2; +} + +struct rds_statistics { + uint64_t s_conn_reset; + uint64_t s_recv_drop_bad_checksum; + uint64_t s_recv_drop_old_seq; + uint64_t s_recv_drop_no_sock; + uint64_t s_recv_drop_dead_sock; + uint64_t s_recv_deliver_raced; + uint64_t s_recv_delivered; + uint64_t s_recv_queued; + uint64_t s_recv_immediate_retry; + uint64_t s_recv_delayed_retry; + uint64_t s_recv_ack_required; + uint64_t s_recv_rdma_bytes; + uint64_t s_recv_ping; + uint64_t s_send_queue_empty; + uint64_t s_send_queue_full; + uint64_t s_send_lock_contention; + uint64_t s_send_lock_queue_raced; + uint64_t s_send_immediate_retry; + uint64_t s_send_delayed_retry; + uint64_t s_send_drop_acked; + uint64_t s_send_ack_required; + uint64_t s_send_queued; + uint64_t s_send_rdma; + uint64_t s_send_rdma_bytes; + uint64_t s_send_pong; + uint64_t s_page_remainder_hit; + uint64_t s_page_remainder_miss; + uint64_t s_copy_to_user; + uint64_t s_copy_from_user; + uint64_t s_cong_update_queued; + uint64_t s_cong_update_received; + uint64_t s_cong_send_error; + uint64_t s_cong_send_blocked; + uint64_t s_recv_bytes_added_to_socket; + uint64_t s_recv_bytes_removed_from_socket; + uint64_t s_send_stuck_rm; +}; + +/* af_rds.c */ +void rds_sock_addref(struct rds_sock *rs); +void rds_sock_put(struct rds_sock *rs); +void rds_wake_sk_sleep(struct rds_sock *rs); +static inline void __rds_wake_sk_sleep(struct sock *sk) +{ + wait_queue_head_t *waitq = sk_sleep(sk); + + if (!sock_flag(sk, SOCK_DEAD) && waitq) + wake_up(waitq); +} +extern wait_queue_head_t rds_poll_waitq; + + +/* bind.c */ +int rds_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len); +void rds_remove_bound(struct rds_sock *rs); +struct rds_sock *rds_find_bound(const struct in6_addr *addr, __be16 port, + __u32 scope_id); +int rds_bind_lock_init(void); +void rds_bind_lock_destroy(void); + +/* cong.c */ +int rds_cong_get_maps(struct rds_connection *conn); +void rds_cong_add_conn(struct rds_connection *conn); +void rds_cong_remove_conn(struct rds_connection *conn); +void rds_cong_set_bit(struct rds_cong_map *map, __be16 port); +void rds_cong_clear_bit(struct rds_cong_map *map, __be16 port); +int rds_cong_wait(struct rds_cong_map *map, __be16 port, int nonblock, struct rds_sock *rs); +void rds_cong_queue_updates(struct rds_cong_map *map); +void rds_cong_map_updated(struct rds_cong_map *map, uint64_t); +int rds_cong_updated_since(unsigned long *recent); +void rds_cong_add_socket(struct rds_sock *); +void rds_cong_remove_socket(struct rds_sock *); +void rds_cong_exit(void); +struct rds_message *rds_cong_update_alloc(struct rds_connection *conn); + +/* connection.c */ +extern u32 rds_gen_num; +int rds_conn_init(void); +void rds_conn_exit(void); +struct rds_connection *rds_conn_create(struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, + u8 tos, gfp_t gfp, + int dev_if); +struct rds_connection *rds_conn_create_outgoing(struct net *net, + const struct in6_addr *laddr, + const struct in6_addr *faddr, + struct rds_transport *trans, + u8 tos, gfp_t gfp, int dev_if); +void rds_conn_shutdown(struct rds_conn_path *cpath); +void rds_conn_destroy(struct rds_connection *conn); +void rds_conn_drop(struct rds_connection *conn); +void rds_conn_path_drop(struct rds_conn_path *cpath, bool destroy); +void rds_conn_connect_if_down(struct 
rds_connection *conn); +void rds_conn_path_connect_if_down(struct rds_conn_path *cp); +void rds_check_all_paths(struct rds_connection *conn); +void rds_for_each_conn_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens, + int (*visitor)(struct rds_connection *, void *), + u64 *buffer, + size_t item_len); + +__printf(2, 3) +void __rds_conn_path_error(struct rds_conn_path *cp, const char *, ...); +#define rds_conn_path_error(cp, fmt...) \ + __rds_conn_path_error(cp, KERN_WARNING "RDS: " fmt) + +static inline int +rds_conn_path_transition(struct rds_conn_path *cp, int old, int new) +{ + return atomic_cmpxchg(&cp->cp_state, old, new) == old; +} + +static inline int +rds_conn_transition(struct rds_connection *conn, int old, int new) +{ + WARN_ON(conn->c_trans->t_mp_capable); + return rds_conn_path_transition(&conn->c_path[0], old, new); +} + +static inline int +rds_conn_path_state(struct rds_conn_path *cp) +{ + return atomic_read(&cp->cp_state); +} + +static inline int +rds_conn_state(struct rds_connection *conn) +{ + WARN_ON(conn->c_trans->t_mp_capable); + return rds_conn_path_state(&conn->c_path[0]); +} + +static inline int +rds_conn_path_up(struct rds_conn_path *cp) +{ + return atomic_read(&cp->cp_state) == RDS_CONN_UP; +} + +static inline int +rds_conn_path_down(struct rds_conn_path *cp) +{ + return atomic_read(&cp->cp_state) == RDS_CONN_DOWN; +} + +static inline int +rds_conn_up(struct rds_connection *conn) +{ + WARN_ON(conn->c_trans->t_mp_capable); + return rds_conn_path_up(&conn->c_path[0]); +} + +static inline int +rds_conn_path_connecting(struct rds_conn_path *cp) +{ + return atomic_read(&cp->cp_state) == RDS_CONN_CONNECTING; +} + +static inline int +rds_conn_connecting(struct rds_connection *conn) +{ + WARN_ON(conn->c_trans->t_mp_capable); + return rds_conn_path_connecting(&conn->c_path[0]); +} + +/* message.c */ +struct rds_message *rds_message_alloc(unsigned int nents, gfp_t gfp); +struct scatterlist *rds_message_alloc_sgs(struct rds_message *rm, int nents); +int rds_message_copy_from_user(struct rds_message *rm, struct iov_iter *from, + bool zcopy); +struct rds_message *rds_message_map_pages(unsigned long *page_addrs, unsigned int total_len); +void rds_message_populate_header(struct rds_header *hdr, __be16 sport, + __be16 dport, u64 seq); +int rds_message_add_extension(struct rds_header *hdr, + unsigned int type, const void *data, unsigned int len); +int rds_message_next_extension(struct rds_header *hdr, + unsigned int *pos, void *buf, unsigned int *buflen); +int rds_message_add_rdma_dest_extension(struct rds_header *hdr, u32 r_key, u32 offset); +int rds_message_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to); +void rds_message_inc_free(struct rds_incoming *inc); +void rds_message_addref(struct rds_message *rm); +void rds_message_put(struct rds_message *rm); +void rds_message_wait(struct rds_message *rm); +void rds_message_unmapped(struct rds_message *rm); +void rds_notify_msg_zcopy_purge(struct rds_msg_zcopy_queue *info); + +static inline void rds_message_make_checksum(struct rds_header *hdr) +{ + hdr->h_csum = 0; + hdr->h_csum = ip_fast_csum((void *) hdr, sizeof(*hdr) >> 2); +} + +static inline int rds_message_verify_checksum(const struct rds_header *hdr) +{ + return !hdr->h_csum || ip_fast_csum((void *) hdr, sizeof(*hdr) >> 2) == 0; +} + + +/* page.c */ +int rds_page_remainder_alloc(struct scatterlist *scat, unsigned long bytes, + gfp_t gfp); +void rds_page_exit(void); + +/* recv.c */ +void 
rds_inc_init(struct rds_incoming *inc, struct rds_connection *conn, + struct in6_addr *saddr); +void rds_inc_path_init(struct rds_incoming *inc, struct rds_conn_path *conn, + struct in6_addr *saddr); +void rds_inc_put(struct rds_incoming *inc); +void rds_recv_incoming(struct rds_connection *conn, struct in6_addr *saddr, + struct in6_addr *daddr, + struct rds_incoming *inc, gfp_t gfp); +int rds_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int msg_flags); +void rds_clear_recv_queue(struct rds_sock *rs); +int rds_notify_queue_get(struct rds_sock *rs, struct msghdr *msg); +void rds_inc_info_copy(struct rds_incoming *inc, + struct rds_info_iterator *iter, + __be32 saddr, __be32 daddr, int flip); +void rds6_inc_info_copy(struct rds_incoming *inc, + struct rds_info_iterator *iter, + struct in6_addr *saddr, struct in6_addr *daddr, + int flip); + +/* send.c */ +int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len); +void rds_send_path_reset(struct rds_conn_path *conn); +int rds_send_xmit(struct rds_conn_path *cp); +struct sockaddr_in; +void rds_send_drop_to(struct rds_sock *rs, struct sockaddr_in6 *dest); +typedef int (*is_acked_func)(struct rds_message *rm, uint64_t ack); +void rds_send_drop_acked(struct rds_connection *conn, u64 ack, + is_acked_func is_acked); +void rds_send_path_drop_acked(struct rds_conn_path *cp, u64 ack, + is_acked_func is_acked); +void rds_send_ping(struct rds_connection *conn, int cp_index); +int rds_send_pong(struct rds_conn_path *cp, __be16 dport); + +/* rdma.c */ +void rds_rdma_unuse(struct rds_sock *rs, u32 r_key, int force); +int rds_get_mr(struct rds_sock *rs, sockptr_t optval, int optlen); +int rds_get_mr_for_dest(struct rds_sock *rs, sockptr_t optval, int optlen); +int rds_free_mr(struct rds_sock *rs, sockptr_t optval, int optlen); +void rds_rdma_drop_keys(struct rds_sock *rs); +int rds_rdma_extra_size(struct rds_rdma_args *args, + struct rds_iov_vector *iov); +int rds_cmsg_rdma_dest(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg); +int rds_cmsg_rdma_args(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg, + struct rds_iov_vector *vec); +int rds_cmsg_rdma_map(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg); +void rds_rdma_free_op(struct rm_rdma_op *ro); +void rds_atomic_free_op(struct rm_atomic_op *ao); +void rds_rdma_send_complete(struct rds_message *rm, int wc_status); +void rds_atomic_send_complete(struct rds_message *rm, int wc_status); +int rds_cmsg_atomic(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg); + +void __rds_put_mr_final(struct kref *kref); + +static inline bool rds_destroy_pending(struct rds_connection *conn) +{ + return !check_net(rds_conn_net(conn)) || + (conn->c_trans->t_unloading && conn->c_trans->t_unloading(conn)); +} + +enum { + ODP_NOT_NEEDED, + ODP_ZEROBASED, + ODP_VIRTUAL +}; + +/* stats.c */ +DECLARE_PER_CPU_SHARED_ALIGNED(struct rds_statistics, rds_stats); +#define rds_stats_inc_which(which, member) do { \ + per_cpu(which, get_cpu()).member++; \ + put_cpu(); \ +} while (0) +#define rds_stats_inc(member) rds_stats_inc_which(rds_stats, member) +#define rds_stats_add_which(which, member, count) do { \ + per_cpu(which, get_cpu()).member += count; \ + put_cpu(); \ +} while (0) +#define rds_stats_add(member, count) rds_stats_add_which(rds_stats, member, count) +int rds_stats_init(void); +void rds_stats_exit(void); +void rds_stats_info_copy(struct rds_info_iterator *iter, + uint64_t *values, const char *const *names, + size_t 
nr); + +/* sysctl.c */ +int rds_sysctl_init(void); +void rds_sysctl_exit(void); +extern unsigned long rds_sysctl_sndbuf_min; +extern unsigned long rds_sysctl_sndbuf_default; +extern unsigned long rds_sysctl_sndbuf_max; +extern unsigned long rds_sysctl_reconnect_min_jiffies; +extern unsigned long rds_sysctl_reconnect_max_jiffies; +extern unsigned int rds_sysctl_max_unacked_packets; +extern unsigned int rds_sysctl_max_unacked_bytes; +extern unsigned int rds_sysctl_ping_enable; +extern unsigned long rds_sysctl_trace_flags; +extern unsigned int rds_sysctl_trace_level; + +/* threads.c */ +int rds_threads_init(void); +void rds_threads_exit(void); +extern struct workqueue_struct *rds_wq; +void rds_queue_reconnect(struct rds_conn_path *cp); +void rds_connect_worker(struct work_struct *); +void rds_shutdown_worker(struct work_struct *); +void rds_send_worker(struct work_struct *); +void rds_recv_worker(struct work_struct *); +void rds_connect_path_complete(struct rds_conn_path *conn, int curr); +void rds_connect_complete(struct rds_connection *conn); +int rds_addr_cmp(const struct in6_addr *a1, const struct in6_addr *a2); + +/* transport.c */ +void rds_trans_register(struct rds_transport *trans); +void rds_trans_unregister(struct rds_transport *trans); +struct rds_transport *rds_trans_get_preferred(struct net *net, + const struct in6_addr *addr, + __u32 scope_id); +void rds_trans_put(struct rds_transport *trans); +unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail); +struct rds_transport *rds_trans_get(int t_type); +int rds_trans_init(void); +void rds_trans_exit(void); + +#endif diff --git a/net/rds/rds_single_path.h b/net/rds/rds_single_path.h new file mode 100644 index 000000000..9521f6e99 --- /dev/null +++ b/net/rds/rds_single_path.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_RDS_SINGLE_H +#define _RDS_RDS_SINGLE_H + +#define c_xmit_rm c_path[0].cp_xmit_rm +#define c_xmit_sg c_path[0].cp_xmit_sg +#define c_xmit_hdr_off c_path[0].cp_xmit_hdr_off +#define c_xmit_data_off c_path[0].cp_xmit_data_off +#define c_xmit_atomic_sent c_path[0].cp_xmit_atomic_sent +#define c_xmit_rdma_sent c_path[0].cp_xmit_rdma_sent +#define c_xmit_data_sent c_path[0].cp_xmit_data_sent +#define c_lock c_path[0].cp_lock +#define c_next_tx_seq c_path[0].cp_next_tx_seq +#define c_send_queue c_path[0].cp_send_queue +#define c_retrans c_path[0].cp_retrans +#define c_next_rx_seq c_path[0].cp_next_rx_seq +#define c_transport_data c_path[0].cp_transport_data +#define c_state c_path[0].cp_state +#define c_send_gen c_path[0].cp_send_gen +#define c_flags c_path[0].cp_flags +#define c_reconnect_jiffies c_path[0].cp_reconnect_jiffies +#define c_send_w c_path[0].cp_send_w +#define c_recv_w c_path[0].cp_recv_w +#define c_conn_w c_path[0].cp_conn_w +#define c_down_w c_path[0].cp_down_w +#define c_cm_lock c_path[0].cp_cm_lock +#define c_waitq c_path[0].cp_waitq +#define c_unacked_packets c_path[0].cp_unacked_packets +#define c_unacked_bytes c_path[0].cp_unacked_bytes + +#endif /* _RDS_RDS_SINGLE_H */ diff --git a/net/rds/recv.c b/net/rds/recv.c new file mode 100644 index 000000000..5b426dc36 --- /dev/null +++ b/net/rds/recv.c @@ -0,0 +1,831 @@ +/* + * Copyright (c) 2006, 2019 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include + +#include "rds.h" + +void rds_inc_init(struct rds_incoming *inc, struct rds_connection *conn, + struct in6_addr *saddr) +{ + refcount_set(&inc->i_refcount, 1); + INIT_LIST_HEAD(&inc->i_item); + inc->i_conn = conn; + inc->i_saddr = *saddr; + inc->i_usercopy.rdma_cookie = 0; + inc->i_usercopy.rx_tstamp = ktime_set(0, 0); + + memset(inc->i_rx_lat_trace, 0, sizeof(inc->i_rx_lat_trace)); +} +EXPORT_SYMBOL_GPL(rds_inc_init); + +void rds_inc_path_init(struct rds_incoming *inc, struct rds_conn_path *cp, + struct in6_addr *saddr) +{ + refcount_set(&inc->i_refcount, 1); + INIT_LIST_HEAD(&inc->i_item); + inc->i_conn = cp->cp_conn; + inc->i_conn_path = cp; + inc->i_saddr = *saddr; + inc->i_usercopy.rdma_cookie = 0; + inc->i_usercopy.rx_tstamp = ktime_set(0, 0); +} +EXPORT_SYMBOL_GPL(rds_inc_path_init); + +static void rds_inc_addref(struct rds_incoming *inc) +{ + rdsdebug("addref inc %p ref %d\n", inc, refcount_read(&inc->i_refcount)); + refcount_inc(&inc->i_refcount); +} + +void rds_inc_put(struct rds_incoming *inc) +{ + rdsdebug("put inc %p ref %d\n", inc, refcount_read(&inc->i_refcount)); + if (refcount_dec_and_test(&inc->i_refcount)) { + BUG_ON(!list_empty(&inc->i_item)); + + inc->i_conn->c_trans->inc_free(inc); + } +} +EXPORT_SYMBOL_GPL(rds_inc_put); + +static void rds_recv_rcvbuf_delta(struct rds_sock *rs, struct sock *sk, + struct rds_cong_map *map, + int delta, __be16 port) +{ + int now_congested; + + if (delta == 0) + return; + + rs->rs_rcv_bytes += delta; + if (delta > 0) + rds_stats_add(s_recv_bytes_added_to_socket, delta); + else + rds_stats_add(s_recv_bytes_removed_from_socket, -delta); + + /* loop transport doesn't send/recv congestion updates */ + if (rs->rs_transport->t_type == RDS_TRANS_LOOP) + return; + + now_congested = rs->rs_rcv_bytes > rds_sk_rcvbuf(rs); + + rdsdebug("rs %p (%pI6c:%u) recv bytes %d buf %d " + "now_cong %d delta %d\n", + rs, &rs->rs_bound_addr, + ntohs(rs->rs_bound_port), rs->rs_rcv_bytes, + rds_sk_rcvbuf(rs), now_congested, delta); + + /* wasn't -> am congested */ + if (!rs->rs_congested && now_congested) { + rs->rs_congested = 1; + rds_cong_set_bit(map, port); + rds_cong_queue_updates(map); + } + /* was -> aren't congested */ + /* Require more free space before 
reporting uncongested to prevent + bouncing cong/uncong state too often */ + else if (rs->rs_congested && (rs->rs_rcv_bytes < (rds_sk_rcvbuf(rs)/2))) { + rs->rs_congested = 0; + rds_cong_clear_bit(map, port); + rds_cong_queue_updates(map); + } + + /* do nothing if no change in cong state */ +} + +static void rds_conn_peer_gen_update(struct rds_connection *conn, + u32 peer_gen_num) +{ + int i; + struct rds_message *rm, *tmp; + unsigned long flags; + + WARN_ON(conn->c_trans->t_type != RDS_TRANS_TCP); + if (peer_gen_num != 0) { + if (conn->c_peer_gen_num != 0 && + peer_gen_num != conn->c_peer_gen_num) { + for (i = 0; i < RDS_MPATH_WORKERS; i++) { + struct rds_conn_path *cp; + + cp = &conn->c_path[i]; + spin_lock_irqsave(&cp->cp_lock, flags); + cp->cp_next_tx_seq = 1; + cp->cp_next_rx_seq = 0; + list_for_each_entry_safe(rm, tmp, + &cp->cp_retrans, + m_conn_item) { + set_bit(RDS_MSG_FLUSH, &rm->m_flags); + } + spin_unlock_irqrestore(&cp->cp_lock, flags); + } + } + conn->c_peer_gen_num = peer_gen_num; + } +} + +/* + * Process all extension headers that come with this message. + */ +static void rds_recv_incoming_exthdrs(struct rds_incoming *inc, struct rds_sock *rs) +{ + struct rds_header *hdr = &inc->i_hdr; + unsigned int pos = 0, type, len; + union { + struct rds_ext_header_version version; + struct rds_ext_header_rdma rdma; + struct rds_ext_header_rdma_dest rdma_dest; + } buffer; + + while (1) { + len = sizeof(buffer); + type = rds_message_next_extension(hdr, &pos, &buffer, &len); + if (type == RDS_EXTHDR_NONE) + break; + /* Process extension header here */ + switch (type) { + case RDS_EXTHDR_RDMA: + rds_rdma_unuse(rs, be32_to_cpu(buffer.rdma.h_rdma_rkey), 0); + break; + + case RDS_EXTHDR_RDMA_DEST: + /* We ignore the size for now. We could stash it + * somewhere and use it for error checking. */ + inc->i_usercopy.rdma_cookie = rds_rdma_make_cookie( + be32_to_cpu(buffer.rdma_dest.h_rdma_rkey), + be32_to_cpu(buffer.rdma_dest.h_rdma_offset)); + + break; + } + } +} + +static void rds_recv_hs_exthdrs(struct rds_header *hdr, + struct rds_connection *conn) +{ + unsigned int pos = 0, type, len; + union { + struct rds_ext_header_version version; + u16 rds_npaths; + u32 rds_gen_num; + } buffer; + u32 new_peer_gen_num = 0; + + while (1) { + len = sizeof(buffer); + type = rds_message_next_extension(hdr, &pos, &buffer, &len); + if (type == RDS_EXTHDR_NONE) + break; + /* Process extension header here */ + switch (type) { + case RDS_EXTHDR_NPATHS: + conn->c_npaths = min_t(int, RDS_MPATH_WORKERS, + be16_to_cpu(buffer.rds_npaths)); + break; + case RDS_EXTHDR_GEN_NUM: + new_peer_gen_num = be32_to_cpu(buffer.rds_gen_num); + break; + default: + pr_warn_ratelimited("ignoring unknown exthdr type " + "0x%x\n", type); + } + } + /* if RDS_EXTHDR_NPATHS was not found, default to a single-path */ + conn->c_npaths = max_t(int, conn->c_npaths, 1); + conn->c_ping_triggered = 0; + rds_conn_peer_gen_update(conn, new_peer_gen_num); +} + +/* rds_start_mprds() will synchronously start multiple paths when appropriate. + * The scheme is based on the following rules: + * + * 1. rds_sendmsg on first connect attempt sends the probe ping, with the + * sender's npaths (s_npaths) + * 2. rcvr of probe-ping knows the mprds_paths = min(s_npaths, r_npaths). It + * sends back a probe-pong with r_npaths. After that, if rcvr is the + * smaller ip addr, it starts rds_conn_path_connect_if_down on all + * mprds_paths. + * 3. sender gets woken up, and can move to rds_conn_path_connect_if_down. 
+ * If it is the smaller ipaddr, rds_conn_path_connect_if_down can be + * called after reception of the probe-pong on all mprds_paths. + * Otherwise (sender of probe-ping is not the smaller ip addr): just call + * rds_conn_path_connect_if_down on the hashed path. (see rule 4) + * 4. rds_connect_worker must only trigger a connection if laddr < faddr. + * 5. sender may end up queuing the packet on the cp. will get sent out later. + * when connection is completed. + */ +static void rds_start_mprds(struct rds_connection *conn) +{ + int i; + struct rds_conn_path *cp; + + if (conn->c_npaths > 1 && + rds_addr_cmp(&conn->c_laddr, &conn->c_faddr) < 0) { + for (i = 0; i < conn->c_npaths; i++) { + cp = &conn->c_path[i]; + rds_conn_path_connect_if_down(cp); + } + } +} + +/* + * The transport must make sure that this is serialized against other + * rx and conn reset on this specific conn. + * + * We currently assert that only one fragmented message will be sent + * down a connection at a time. This lets us reassemble in the conn + * instead of per-flow which means that we don't have to go digging through + * flows to tear down partial reassembly progress on conn failure and + * we save flow lookup and locking for each frag arrival. It does mean + * that small messages will wait behind large ones. Fragmenting at all + * is only to reduce the memory consumption of pre-posted buffers. + * + * The caller passes in saddr and daddr instead of us getting it from the + * conn. This lets loopback, who only has one conn for both directions, + * tell us which roles the addrs in the conn are playing for this message. + */ +void rds_recv_incoming(struct rds_connection *conn, struct in6_addr *saddr, + struct in6_addr *daddr, + struct rds_incoming *inc, gfp_t gfp) +{ + struct rds_sock *rs = NULL; + struct sock *sk; + unsigned long flags; + struct rds_conn_path *cp; + + inc->i_conn = conn; + inc->i_rx_jiffies = jiffies; + if (conn->c_trans->t_mp_capable) + cp = inc->i_conn_path; + else + cp = &conn->c_path[0]; + + rdsdebug("conn %p next %llu inc %p seq %llu len %u sport %u dport %u " + "flags 0x%x rx_jiffies %lu\n", conn, + (unsigned long long)cp->cp_next_rx_seq, + inc, + (unsigned long long)be64_to_cpu(inc->i_hdr.h_sequence), + be32_to_cpu(inc->i_hdr.h_len), + be16_to_cpu(inc->i_hdr.h_sport), + be16_to_cpu(inc->i_hdr.h_dport), + inc->i_hdr.h_flags, + inc->i_rx_jiffies); + + /* + * Sequence numbers should only increase. Messages get their + * sequence number as they're queued in a sending conn. They + * can be dropped, though, if the sending socket is closed before + * they hit the wire. So sequence numbers can skip forward + * under normal operation. They can also drop back in the conn + * failover case as previously sent messages are resent down the + * new instance of a conn. We drop those, otherwise we have + * to assume that the next valid seq does not come after a + * hole in the fragment stream. + * + * The headers don't give us a way to realize if fragments of + * a message have been dropped. We assume that frags that arrive + * to a flow are part of the current message on the flow that is + * being reassembled. This means that senders can't drop messages + * from the sending conn until all their frags are sent. + * + * XXX we could spend more on the wire to get more robust failure + * detection, arguably worth it to avoid data corruption. 
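The comment above explains why sequence numbers may legitimately skip forward (sends dropped before hitting the wire) or step backward (failover retransmits); the check that follows therefore drops a fragment only when it is both older than cp_next_rx_seq and flagged as a retransmission. The rule in isolation, as a small userspace sketch (the 0x04 flag value matches RDS_FLAG_RETRANSMITTED in net/rds/rds.h but is only illustrative here):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define RDS_FLAG_RETRANSMITTED 0x04 /* illustrative copy of the header flag */

/* Returns true if the incoming fragment should be delivered, and advances
 * *next_rx_seq the same way rds_recv_incoming() does. */
static bool accept_seq(uint64_t *next_rx_seq, uint64_t h_sequence, uint8_t h_flags)
{
        if (h_sequence < *next_rx_seq && (h_flags & RDS_FLAG_RETRANSMITTED))
                return false;          /* old retransmission: drop */
        *next_rx_seq = h_sequence + 1; /* accept, even across gaps */
        return true;
}

int main(void)
{
        uint64_t next = 10;

        printf("%d\n", accept_seq(&next, 12, 0));                      /* 1: gap is fine */
        printf("%d\n", accept_seq(&next, 11, RDS_FLAG_RETRANSMITTED)); /* 0: stale resend */
        printf("%d\n", accept_seq(&next, 13, 0));                      /* 1 */
        return 0;
}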
+ */ + if (be64_to_cpu(inc->i_hdr.h_sequence) < cp->cp_next_rx_seq && + (inc->i_hdr.h_flags & RDS_FLAG_RETRANSMITTED)) { + rds_stats_inc(s_recv_drop_old_seq); + goto out; + } + cp->cp_next_rx_seq = be64_to_cpu(inc->i_hdr.h_sequence) + 1; + + if (rds_sysctl_ping_enable && inc->i_hdr.h_dport == 0) { + if (inc->i_hdr.h_sport == 0) { + rdsdebug("ignore ping with 0 sport from %pI6c\n", + saddr); + goto out; + } + rds_stats_inc(s_recv_ping); + rds_send_pong(cp, inc->i_hdr.h_sport); + /* if this is a handshake ping, start multipath if necessary */ + if (RDS_HS_PROBE(be16_to_cpu(inc->i_hdr.h_sport), + be16_to_cpu(inc->i_hdr.h_dport))) { + rds_recv_hs_exthdrs(&inc->i_hdr, cp->cp_conn); + rds_start_mprds(cp->cp_conn); + } + goto out; + } + + if (be16_to_cpu(inc->i_hdr.h_dport) == RDS_FLAG_PROBE_PORT && + inc->i_hdr.h_sport == 0) { + rds_recv_hs_exthdrs(&inc->i_hdr, cp->cp_conn); + /* if this is a handshake pong, start multipath if necessary */ + rds_start_mprds(cp->cp_conn); + wake_up(&cp->cp_conn->c_hs_waitq); + goto out; + } + + rs = rds_find_bound(daddr, inc->i_hdr.h_dport, conn->c_bound_if); + if (!rs) { + rds_stats_inc(s_recv_drop_no_sock); + goto out; + } + + /* Process extension headers */ + rds_recv_incoming_exthdrs(inc, rs); + + /* We can be racing with rds_release() which marks the socket dead. */ + sk = rds_rs_to_sk(rs); + + /* serialize with rds_release -> sock_orphan */ + write_lock_irqsave(&rs->rs_recv_lock, flags); + if (!sock_flag(sk, SOCK_DEAD)) { + rdsdebug("adding inc %p to rs %p's recv queue\n", inc, rs); + rds_stats_inc(s_recv_queued); + rds_recv_rcvbuf_delta(rs, sk, inc->i_conn->c_lcong, + be32_to_cpu(inc->i_hdr.h_len), + inc->i_hdr.h_dport); + if (sock_flag(sk, SOCK_RCVTSTAMP)) + inc->i_usercopy.rx_tstamp = ktime_get_real(); + rds_inc_addref(inc); + inc->i_rx_lat_trace[RDS_MSG_RX_END] = local_clock(); + list_add_tail(&inc->i_item, &rs->rs_recv_queue); + __rds_wake_sk_sleep(sk); + } else { + rds_stats_inc(s_recv_drop_dead_sock); + } + write_unlock_irqrestore(&rs->rs_recv_lock, flags); + +out: + if (rs) + rds_sock_put(rs); +} +EXPORT_SYMBOL_GPL(rds_recv_incoming); + +/* + * be very careful here. This is being called as the condition in + * wait_event_*() needs to cope with being called many times. + */ +static int rds_next_incoming(struct rds_sock *rs, struct rds_incoming **inc) +{ + unsigned long flags; + + if (!*inc) { + read_lock_irqsave(&rs->rs_recv_lock, flags); + if (!list_empty(&rs->rs_recv_queue)) { + *inc = list_entry(rs->rs_recv_queue.next, + struct rds_incoming, + i_item); + rds_inc_addref(*inc); + } + read_unlock_irqrestore(&rs->rs_recv_lock, flags); + } + + return *inc != NULL; +} + +static int rds_still_queued(struct rds_sock *rs, struct rds_incoming *inc, + int drop) +{ + struct sock *sk = rds_rs_to_sk(rs); + int ret = 0; + unsigned long flags; + + write_lock_irqsave(&rs->rs_recv_lock, flags); + if (!list_empty(&inc->i_item)) { + ret = 1; + if (drop) { + /* XXX make sure this i_conn is reliable */ + rds_recv_rcvbuf_delta(rs, sk, inc->i_conn->c_lcong, + -be32_to_cpu(inc->i_hdr.h_len), + inc->i_hdr.h_dport); + list_del_init(&inc->i_item); + rds_inc_put(inc); + } + } + write_unlock_irqrestore(&rs->rs_recv_lock, flags); + + rdsdebug("inc %p rs %p still %d dropped %d\n", inc, rs, ret, drop); + return ret; +} + +/* + * Pull errors off the error queue. + * If msghdr is NULL, we will just purge the error queue. 
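rds_notify_queue_get() below copies each queued rds_notifier into an RDS_CMSG_RDMA_STATUS control message. On the application side the same information comes back through recvmsg()'s ancillary data, roughly as sketched here; SOL_RDS, RDS_CMSG_RDMA_STATUS and struct rds_rdma_notify come from the <linux/rds.h> UAPI header, and error handling is deliberately elided.

#include <linux/rds.h>     /* SOL_RDS, RDS_CMSG_RDMA_STATUS, struct rds_rdma_notify */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

/* Drain any RDMA completion notifications queued on an RDS socket fd. */
static void read_rdma_notifications(int fd)
{
        char ctrl[1024];
        char data[1];
        struct iovec iov = { .iov_base = data, .iov_len = sizeof(data) };
        struct msghdr msg = {
                .msg_iov = &iov,
                .msg_iovlen = 1,
                .msg_control = ctrl,
                .msg_controllen = sizeof(ctrl),
        };
        struct cmsghdr *cmsg;

        if (recvmsg(fd, &msg, MSG_DONTWAIT) < 0)
                return;

        for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
                struct rds_rdma_notify note;

                if (cmsg->cmsg_level != SOL_RDS ||
                    cmsg->cmsg_type != RDS_CMSG_RDMA_STATUS)
                        continue;
                memcpy(&note, CMSG_DATA(cmsg), sizeof(note));
                printf("RDMA op %llu completed, status %d\n",
                       (unsigned long long)note.user_token, (int)note.status);
        }
}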
+ */ +int rds_notify_queue_get(struct rds_sock *rs, struct msghdr *msghdr) +{ + struct rds_notifier *notifier; + struct rds_rdma_notify cmsg; + unsigned int count = 0, max_messages = ~0U; + unsigned long flags; + LIST_HEAD(copy); + int err = 0; + + memset(&cmsg, 0, sizeof(cmsg)); /* fill holes with zero */ + + /* put_cmsg copies to user space and thus may sleep. We can't do this + * with rs_lock held, so first grab as many notifications as we can stuff + * in the user provided cmsg buffer. We don't try to copy more, to avoid + * losing notifications - except when the buffer is so small that it wouldn't + * even hold a single notification. Then we give him as much of this single + * msg as we can squeeze in, and set MSG_CTRUNC. + */ + if (msghdr) { + max_messages = msghdr->msg_controllen / CMSG_SPACE(sizeof(cmsg)); + if (!max_messages) + max_messages = 1; + } + + spin_lock_irqsave(&rs->rs_lock, flags); + while (!list_empty(&rs->rs_notify_queue) && count < max_messages) { + notifier = list_entry(rs->rs_notify_queue.next, + struct rds_notifier, n_list); + list_move(¬ifier->n_list, ©); + count++; + } + spin_unlock_irqrestore(&rs->rs_lock, flags); + + if (!count) + return 0; + + while (!list_empty(©)) { + notifier = list_entry(copy.next, struct rds_notifier, n_list); + + if (msghdr) { + cmsg.user_token = notifier->n_user_token; + cmsg.status = notifier->n_status; + + err = put_cmsg(msghdr, SOL_RDS, RDS_CMSG_RDMA_STATUS, + sizeof(cmsg), &cmsg); + if (err) + break; + } + + list_del_init(¬ifier->n_list); + kfree(notifier); + } + + /* If we bailed out because of an error in put_cmsg, + * we may be left with one or more notifications that we + * didn't process. Return them to the head of the list. */ + if (!list_empty(©)) { + spin_lock_irqsave(&rs->rs_lock, flags); + list_splice(©, &rs->rs_notify_queue); + spin_unlock_irqrestore(&rs->rs_lock, flags); + } + + return err; +} + +/* + * Queue a congestion notification + */ +static int rds_notify_cong(struct rds_sock *rs, struct msghdr *msghdr) +{ + uint64_t notify = rs->rs_cong_notify; + unsigned long flags; + int err; + + err = put_cmsg(msghdr, SOL_RDS, RDS_CMSG_CONG_UPDATE, + sizeof(notify), ¬ify); + if (err) + return err; + + spin_lock_irqsave(&rs->rs_lock, flags); + rs->rs_cong_notify &= ~notify; + spin_unlock_irqrestore(&rs->rs_lock, flags); + + return 0; +} + +/* + * Receive any control messages. 
+ */ +static int rds_cmsg_recv(struct rds_incoming *inc, struct msghdr *msg, + struct rds_sock *rs) +{ + int ret = 0; + + if (inc->i_usercopy.rdma_cookie) { + ret = put_cmsg(msg, SOL_RDS, RDS_CMSG_RDMA_DEST, + sizeof(inc->i_usercopy.rdma_cookie), + &inc->i_usercopy.rdma_cookie); + if (ret) + goto out; + } + + if ((inc->i_usercopy.rx_tstamp != 0) && + sock_flag(rds_rs_to_sk(rs), SOCK_RCVTSTAMP)) { + struct __kernel_old_timeval tv = + ns_to_kernel_old_timeval(inc->i_usercopy.rx_tstamp); + + if (!sock_flag(rds_rs_to_sk(rs), SOCK_TSTAMP_NEW)) { + ret = put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } else { + struct __kernel_sock_timeval sk_tv; + + sk_tv.tv_sec = tv.tv_sec; + sk_tv.tv_usec = tv.tv_usec; + + ret = put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(sk_tv), &sk_tv); + } + + if (ret) + goto out; + } + + if (rs->rs_rx_traces) { + struct rds_cmsg_rx_trace t; + int i, j; + + memset(&t, 0, sizeof(t)); + inc->i_rx_lat_trace[RDS_MSG_RX_CMSG] = local_clock(); + t.rx_traces = rs->rs_rx_traces; + for (i = 0; i < rs->rs_rx_traces; i++) { + j = rs->rs_rx_trace[i]; + t.rx_trace_pos[i] = j; + t.rx_trace[i] = inc->i_rx_lat_trace[j + 1] - + inc->i_rx_lat_trace[j]; + } + + ret = put_cmsg(msg, SOL_RDS, RDS_CMSG_RXPATH_LATENCY, + sizeof(t), &t); + if (ret) + goto out; + } + +out: + return ret; +} + +static bool rds_recvmsg_zcookie(struct rds_sock *rs, struct msghdr *msg) +{ + struct rds_msg_zcopy_queue *q = &rs->rs_zcookie_queue; + struct rds_msg_zcopy_info *info = NULL; + struct rds_zcopy_cookies *done; + unsigned long flags; + + if (!msg->msg_control) + return false; + + if (!sock_flag(rds_rs_to_sk(rs), SOCK_ZEROCOPY) || + msg->msg_controllen < CMSG_SPACE(sizeof(*done))) + return false; + + spin_lock_irqsave(&q->lock, flags); + if (!list_empty(&q->zcookie_head)) { + info = list_entry(q->zcookie_head.next, + struct rds_msg_zcopy_info, rs_zcookie_next); + list_del(&info->rs_zcookie_next); + } + spin_unlock_irqrestore(&q->lock, flags); + if (!info) + return false; + done = &info->zcookies; + if (put_cmsg(msg, SOL_RDS, RDS_CMSG_ZCOPY_COMPLETION, sizeof(*done), + done)) { + spin_lock_irqsave(&q->lock, flags); + list_add(&info->rs_zcookie_next, &q->zcookie_head); + spin_unlock_irqrestore(&q->lock, flags); + return false; + } + kfree(info); + return true; +} + +int rds_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int msg_flags) +{ + struct sock *sk = sock->sk; + struct rds_sock *rs = rds_sk_to_rs(sk); + long timeo; + int ret = 0, nonblock = msg_flags & MSG_DONTWAIT; + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name); + struct rds_incoming *inc = NULL; + + /* udp_recvmsg()->sock_recvtimeo() gets away without locking too.. */ + timeo = sock_rcvtimeo(sk, nonblock); + + rdsdebug("size %zu flags 0x%x timeo %ld\n", size, msg_flags, timeo); + + if (msg_flags & MSG_OOB) + goto out; + if (msg_flags & MSG_ERRQUEUE) + return sock_recv_errqueue(sk, msg, size, SOL_IP, IP_RECVERR); + + while (1) { + /* If there are pending notifications, do those - and nothing else */ + if (!list_empty(&rs->rs_notify_queue)) { + ret = rds_notify_queue_get(rs, msg); + break; + } + + if (rs->rs_cong_notify) { + ret = rds_notify_cong(rs, msg); + break; + } + + if (!rds_next_incoming(rs, &inc)) { + if (nonblock) { + bool reaped = rds_recvmsg_zcookie(rs, msg); + + ret = reaped ? 
0 : -EAGAIN; + break; + } + + timeo = wait_event_interruptible_timeout(*sk_sleep(sk), + (!list_empty(&rs->rs_notify_queue) || + rs->rs_cong_notify || + rds_next_incoming(rs, &inc)), timeo); + rdsdebug("recvmsg woke inc %p timeo %ld\n", inc, + timeo); + if (timeo > 0 || timeo == MAX_SCHEDULE_TIMEOUT) + continue; + + ret = timeo; + if (ret == 0) + ret = -ETIMEDOUT; + break; + } + + rdsdebug("copying inc %p from %pI6c:%u to user\n", inc, + &inc->i_conn->c_faddr, + ntohs(inc->i_hdr.h_sport)); + ret = inc->i_conn->c_trans->inc_copy_to_user(inc, &msg->msg_iter); + if (ret < 0) + break; + + /* + * if the message we just copied isn't at the head of the + * recv queue then someone else raced us to return it, try + * to get the next message. + */ + if (!rds_still_queued(rs, inc, !(msg_flags & MSG_PEEK))) { + rds_inc_put(inc); + inc = NULL; + rds_stats_inc(s_recv_deliver_raced); + iov_iter_revert(&msg->msg_iter, ret); + continue; + } + + if (ret < be32_to_cpu(inc->i_hdr.h_len)) { + if (msg_flags & MSG_TRUNC) + ret = be32_to_cpu(inc->i_hdr.h_len); + msg->msg_flags |= MSG_TRUNC; + } + + if (rds_cmsg_recv(inc, msg, rs)) { + ret = -EFAULT; + break; + } + rds_recvmsg_zcookie(rs, msg); + + rds_stats_inc(s_recv_delivered); + + if (msg->msg_name) { + if (ipv6_addr_v4mapped(&inc->i_saddr)) { + sin->sin_family = AF_INET; + sin->sin_port = inc->i_hdr.h_sport; + sin->sin_addr.s_addr = + inc->i_saddr.s6_addr32[3]; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); + msg->msg_namelen = sizeof(*sin); + } else { + sin6->sin6_family = AF_INET6; + sin6->sin6_port = inc->i_hdr.h_sport; + sin6->sin6_addr = inc->i_saddr; + sin6->sin6_flowinfo = 0; + sin6->sin6_scope_id = rs->rs_bound_scope_id; + msg->msg_namelen = sizeof(*sin6); + } + } + break; + } + + if (inc) + rds_inc_put(inc); + +out: + return ret; +} + +/* + * The socket is being shut down and we're asked to drop messages that were + * queued for recvmsg. The caller has unbound the socket so the receive path + * won't queue any more incoming fragments or messages on the socket. + */ +void rds_clear_recv_queue(struct rds_sock *rs) +{ + struct sock *sk = rds_rs_to_sk(rs); + struct rds_incoming *inc, *tmp; + unsigned long flags; + + write_lock_irqsave(&rs->rs_recv_lock, flags); + list_for_each_entry_safe(inc, tmp, &rs->rs_recv_queue, i_item) { + rds_recv_rcvbuf_delta(rs, sk, inc->i_conn->c_lcong, + -be32_to_cpu(inc->i_hdr.h_len), + inc->i_hdr.h_dport); + list_del_init(&inc->i_item); + rds_inc_put(inc); + } + write_unlock_irqrestore(&rs->rs_recv_lock, flags); +} + +/* + * inc->i_saddr isn't used here because it is only set in the receive + * path. 
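For context on what the receive path above (rds_recvmsg and friends) serves, a minimal and deliberately simplified userspace consumer of an RDS socket might look like the following; the address and port are placeholders, error paths are abbreviated, and the AF_RDS fallback definition is only there for libcs that do not expose it.

#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

#ifndef AF_RDS
#define AF_RDS 21
#endif

int main(void)
{
        int fd = socket(AF_RDS, SOCK_SEQPACKET, 0);
        struct sockaddr_in laddr = {
                .sin_family = AF_INET,
                .sin_port = htons(4000),                 /* placeholder port */
                .sin_addr.s_addr = inet_addr("127.0.0.1"),
        };
        struct sockaddr_in6 from;
        socklen_t fromlen = sizeof(from);
        char buf[1024];
        ssize_t n;

        if (fd < 0 || bind(fd, (struct sockaddr *)&laddr, sizeof(laddr)) < 0) {
                perror("rds socket/bind");
                return 1;
        }

        /* rds_recvmsg() fills msg_name with sockaddr_in or sockaddr_in6
         * depending on whether the sender address is IPv4-mapped. */
        n = recvfrom(fd, buf, sizeof(buf), 0, (struct sockaddr *)&from, &fromlen);
        if (n >= 0)
                printf("received %zd bytes (sender family %d)\n", n, from.sin6_family);

        close(fd);
        return 0;
}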
+ */ +void rds_inc_info_copy(struct rds_incoming *inc, + struct rds_info_iterator *iter, + __be32 saddr, __be32 daddr, int flip) +{ + struct rds_info_message minfo; + + minfo.seq = be64_to_cpu(inc->i_hdr.h_sequence); + minfo.len = be32_to_cpu(inc->i_hdr.h_len); + minfo.tos = inc->i_conn->c_tos; + + if (flip) { + minfo.laddr = daddr; + minfo.faddr = saddr; + minfo.lport = inc->i_hdr.h_dport; + minfo.fport = inc->i_hdr.h_sport; + } else { + minfo.laddr = saddr; + minfo.faddr = daddr; + minfo.lport = inc->i_hdr.h_sport; + minfo.fport = inc->i_hdr.h_dport; + } + + minfo.flags = 0; + + rds_info_copy(iter, &minfo, sizeof(minfo)); +} + +#if IS_ENABLED(CONFIG_IPV6) +void rds6_inc_info_copy(struct rds_incoming *inc, + struct rds_info_iterator *iter, + struct in6_addr *saddr, struct in6_addr *daddr, + int flip) +{ + struct rds6_info_message minfo6; + + minfo6.seq = be64_to_cpu(inc->i_hdr.h_sequence); + minfo6.len = be32_to_cpu(inc->i_hdr.h_len); + minfo6.tos = inc->i_conn->c_tos; + + if (flip) { + minfo6.laddr = *daddr; + minfo6.faddr = *saddr; + minfo6.lport = inc->i_hdr.h_dport; + minfo6.fport = inc->i_hdr.h_sport; + } else { + minfo6.laddr = *saddr; + minfo6.faddr = *daddr; + minfo6.lport = inc->i_hdr.h_sport; + minfo6.fport = inc->i_hdr.h_dport; + } + + minfo6.flags = 0; + + rds_info_copy(iter, &minfo6, sizeof(minfo6)); +} +#endif diff --git a/net/rds/send.c b/net/rds/send.c new file mode 100644 index 000000000..0c5504068 --- /dev/null +++ b/net/rds/send.c @@ -0,0 +1,1515 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rds.h" + +/* When transmitting messages in rds_send_xmit, we need to emerge from + * time to time and briefly release the CPU. Otherwise the softlock watchdog + * will kick our shin. + * Also, it seems fairer to not let one busy connection stall all the + * others. + * + * send_batch_count is the number of times we'll loop in send_xmit. Setting + * it to 0 will restore the old behavior (where we looped until we had + * drained the queue). 
+ */ +static int send_batch_count = SZ_1K; +module_param(send_batch_count, int, 0444); +MODULE_PARM_DESC(send_batch_count, " batch factor when working the send queue"); + +static void rds_send_remove_from_sock(struct list_head *messages, int status); + +/* + * Reset the send state. Callers must ensure that this doesn't race with + * rds_send_xmit(). + */ +void rds_send_path_reset(struct rds_conn_path *cp) +{ + struct rds_message *rm, *tmp; + unsigned long flags; + + if (cp->cp_xmit_rm) { + rm = cp->cp_xmit_rm; + cp->cp_xmit_rm = NULL; + /* Tell the user the RDMA op is no longer mapped by the + * transport. This isn't entirely true (it's flushed out + * independently) but as the connection is down, there's + * no ongoing RDMA to/from that memory */ + rds_message_unmapped(rm); + rds_message_put(rm); + } + + cp->cp_xmit_sg = 0; + cp->cp_xmit_hdr_off = 0; + cp->cp_xmit_data_off = 0; + cp->cp_xmit_atomic_sent = 0; + cp->cp_xmit_rdma_sent = 0; + cp->cp_xmit_data_sent = 0; + + cp->cp_conn->c_map_queued = 0; + + cp->cp_unacked_packets = rds_sysctl_max_unacked_packets; + cp->cp_unacked_bytes = rds_sysctl_max_unacked_bytes; + + /* Mark messages as retransmissions, and move them to the send q */ + spin_lock_irqsave(&cp->cp_lock, flags); + list_for_each_entry_safe(rm, tmp, &cp->cp_retrans, m_conn_item) { + set_bit(RDS_MSG_ACK_REQUIRED, &rm->m_flags); + set_bit(RDS_MSG_RETRANSMITTED, &rm->m_flags); + } + list_splice_init(&cp->cp_retrans, &cp->cp_send_queue); + spin_unlock_irqrestore(&cp->cp_lock, flags); +} +EXPORT_SYMBOL_GPL(rds_send_path_reset); + +static int acquire_in_xmit(struct rds_conn_path *cp) +{ + return test_and_set_bit(RDS_IN_XMIT, &cp->cp_flags) == 0; +} + +static void release_in_xmit(struct rds_conn_path *cp) +{ + clear_bit(RDS_IN_XMIT, &cp->cp_flags); + smp_mb__after_atomic(); + /* + * We don't use wait_on_bit()/wake_up_bit() because our waking is in a + * hot path and finding waiters is very rare. We don't want to walk + * the system-wide hashed waitqueue buckets in the fast path only to + * almost never find waiters. + */ + if (waitqueue_active(&cp->cp_waitq)) + wake_up_all(&cp->cp_waitq); +} + +/* + * We're making the conscious trade-off here to only send one message + * down the connection at a time. + * Pro: + * - tx queueing is a simple fifo list + * - reassembly is optional and easily done by transports per conn + * - no per flow rx lookup at all, straight to the socket + * - less per-frag memory and wire overhead + * Con: + * - queued acks can be delayed behind large messages + * Depends: + * - small message latency is higher behind queued large messages + * - large message latency isn't starved by intervening small sends + */ +int rds_send_xmit(struct rds_conn_path *cp) +{ + struct rds_connection *conn = cp->cp_conn; + struct rds_message *rm; + unsigned long flags; + unsigned int tmp; + struct scatterlist *sg; + int ret = 0; + LIST_HEAD(to_be_dropped); + int batch_count; + unsigned long send_gen = 0; + int same_rm = 0; + +restart: + batch_count = 0; + + /* + * sendmsg calls here after having queued its message on the send + * queue. We only have one task feeding the connection at a time. If + * another thread is already feeding the queue then we back off. This + * avoids blocking the caller and trading per-connection data between + * caches per message. 
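acquire_in_xmit()/release_in_xmit() above implement the "one task feeding the connection at a time" rule with a single atomic bit instead of a lock: losers back off rather than block. The same pattern expressed in portable C11 atomics (illustrative only; the kernel version operates on cp_flags with test_and_set_bit and additionally wakes cp_waitq on release):

#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

static atomic_flag in_xmit = ATOMIC_FLAG_INIT;

/* Mirrors acquire_in_xmit(): returns true only for the one caller that
 * grabs the send path; everyone else backs off instead of blocking. */
static bool try_acquire_in_xmit(void)
{
        return !atomic_flag_test_and_set_explicit(&in_xmit, memory_order_acquire);
}

static void drop_in_xmit(void)
{
        atomic_flag_clear_explicit(&in_xmit, memory_order_release);
        /* The kernel additionally wakes cp_waitq here so a waiting
         * shutdown path can make progress. */
}

int main(void)
{
        printf("first caller:  %d\n", try_acquire_in_xmit()); /* 1 */
        printf("second caller: %d\n", try_acquire_in_xmit()); /* 0: backs off */
        drop_in_xmit();
        printf("after release: %d\n", try_acquire_in_xmit()); /* 1 */
        return 0;
}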
+ */ + if (!acquire_in_xmit(cp)) { + rds_stats_inc(s_send_lock_contention); + ret = -ENOMEM; + goto out; + } + + if (rds_destroy_pending(cp->cp_conn)) { + release_in_xmit(cp); + ret = -ENETUNREACH; /* dont requeue send work */ + goto out; + } + + /* + * we record the send generation after doing the xmit acquire. + * if someone else manages to jump in and do some work, we'll use + * this to avoid a goto restart farther down. + * + * The acquire_in_xmit() check above ensures that only one + * caller can increment c_send_gen at any time. + */ + send_gen = READ_ONCE(cp->cp_send_gen) + 1; + WRITE_ONCE(cp->cp_send_gen, send_gen); + + /* + * rds_conn_shutdown() sets the conn state and then tests RDS_IN_XMIT, + * we do the opposite to avoid races. + */ + if (!rds_conn_path_up(cp)) { + release_in_xmit(cp); + ret = 0; + goto out; + } + + if (conn->c_trans->xmit_path_prepare) + conn->c_trans->xmit_path_prepare(cp); + + /* + * spin trying to push headers and data down the connection until + * the connection doesn't make forward progress. + */ + while (1) { + + rm = cp->cp_xmit_rm; + + if (!rm) { + same_rm = 0; + } else { + same_rm++; + if (same_rm >= 4096) { + rds_stats_inc(s_send_stuck_rm); + ret = -EAGAIN; + break; + } + } + + /* + * If between sending messages, we can send a pending congestion + * map update. + */ + if (!rm && test_and_clear_bit(0, &conn->c_map_queued)) { + rm = rds_cong_update_alloc(conn); + if (IS_ERR(rm)) { + ret = PTR_ERR(rm); + break; + } + rm->data.op_active = 1; + rm->m_inc.i_conn_path = cp; + rm->m_inc.i_conn = cp->cp_conn; + + cp->cp_xmit_rm = rm; + } + + /* + * If not already working on one, grab the next message. + * + * cp_xmit_rm holds a ref while we're sending this message down + * the connction. We can use this ref while holding the + * send_sem.. rds_send_reset() is serialized with it. + */ + if (!rm) { + unsigned int len; + + batch_count++; + + /* we want to process as big a batch as we can, but + * we also want to avoid softlockups. If we've been + * through a lot of messages, lets back off and see + * if anyone else jumps in + */ + if (batch_count >= send_batch_count) + goto over_batch; + + spin_lock_irqsave(&cp->cp_lock, flags); + + if (!list_empty(&cp->cp_send_queue)) { + rm = list_entry(cp->cp_send_queue.next, + struct rds_message, + m_conn_item); + rds_message_addref(rm); + + /* + * Move the message from the send queue to the retransmit + * list right away. + */ + list_move_tail(&rm->m_conn_item, + &cp->cp_retrans); + } + + spin_unlock_irqrestore(&cp->cp_lock, flags); + + if (!rm) + break; + + /* Unfortunately, the way Infiniband deals with + * RDMA to a bad MR key is by moving the entire + * queue pair to error state. We could possibly + * recover from that, but right now we drop the + * connection. + * Therefore, we never retransmit messages with RDMA ops. 
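In the code just below, the send loop asks the peer for an ACK "every once in a while" by counting down packets and bytes since the last forced ACK. That accounting, pulled out into a userspace sketch with arbitrary limits standing in for the rds_sysctl_max_unacked_* knobs:

#include <stdbool.h>
#include <stdio.h>

/* These limits play the role of rds_sysctl_max_unacked_{packets,bytes};
 * the numbers here are arbitrary. */
#define MAX_UNACKED_PACKETS 8
#define MAX_UNACKED_BYTES   (16 * 1024)

static unsigned int unacked_packets = MAX_UNACKED_PACKETS;
static unsigned int unacked_bytes   = MAX_UNACKED_BYTES;

/* Returns true when this message must carry RDS_FLAG_ACK_REQUIRED. */
static bool ack_required(unsigned int len)
{
        if (unacked_packets == 0 || unacked_bytes < len) {
                unacked_packets = MAX_UNACKED_PACKETS;
                unacked_bytes   = MAX_UNACKED_BYTES;
                return true;
        }
        unacked_bytes -= len;
        unacked_packets--;
        return false;
}

int main(void)
{
        for (int i = 0; i < 10; i++)
                printf("msg %d ack_required=%d\n", i, ack_required(4096));
        return 0;
}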
+ */ + if (test_bit(RDS_MSG_FLUSH, &rm->m_flags) || + (rm->rdma.op_active && + test_bit(RDS_MSG_RETRANSMITTED, &rm->m_flags))) { + spin_lock_irqsave(&cp->cp_lock, flags); + if (test_and_clear_bit(RDS_MSG_ON_CONN, &rm->m_flags)) + list_move(&rm->m_conn_item, &to_be_dropped); + spin_unlock_irqrestore(&cp->cp_lock, flags); + continue; + } + + /* Require an ACK every once in a while */ + len = ntohl(rm->m_inc.i_hdr.h_len); + if (cp->cp_unacked_packets == 0 || + cp->cp_unacked_bytes < len) { + set_bit(RDS_MSG_ACK_REQUIRED, &rm->m_flags); + + cp->cp_unacked_packets = + rds_sysctl_max_unacked_packets; + cp->cp_unacked_bytes = + rds_sysctl_max_unacked_bytes; + rds_stats_inc(s_send_ack_required); + } else { + cp->cp_unacked_bytes -= len; + cp->cp_unacked_packets--; + } + + cp->cp_xmit_rm = rm; + } + + /* The transport either sends the whole rdma or none of it */ + if (rm->rdma.op_active && !cp->cp_xmit_rdma_sent) { + rm->m_final_op = &rm->rdma; + /* The transport owns the mapped memory for now. + * You can't unmap it while it's on the send queue + */ + set_bit(RDS_MSG_MAPPED, &rm->m_flags); + ret = conn->c_trans->xmit_rdma(conn, &rm->rdma); + if (ret) { + clear_bit(RDS_MSG_MAPPED, &rm->m_flags); + wake_up_interruptible(&rm->m_flush_wait); + break; + } + cp->cp_xmit_rdma_sent = 1; + + } + + if (rm->atomic.op_active && !cp->cp_xmit_atomic_sent) { + rm->m_final_op = &rm->atomic; + /* The transport owns the mapped memory for now. + * You can't unmap it while it's on the send queue + */ + set_bit(RDS_MSG_MAPPED, &rm->m_flags); + ret = conn->c_trans->xmit_atomic(conn, &rm->atomic); + if (ret) { + clear_bit(RDS_MSG_MAPPED, &rm->m_flags); + wake_up_interruptible(&rm->m_flush_wait); + break; + } + cp->cp_xmit_atomic_sent = 1; + + } + + /* + * A number of cases require an RDS header to be sent + * even if there is no data. + * We permit 0-byte sends; rds-ping depends on this. + * However, if there are exclusively attached silent ops, + * we skip the hdr/data send, to enable silent operation. + */ + if (rm->data.op_nents == 0) { + int ops_present; + int all_ops_are_silent = 1; + + ops_present = (rm->atomic.op_active || rm->rdma.op_active); + if (rm->atomic.op_active && !rm->atomic.op_silent) + all_ops_are_silent = 0; + if (rm->rdma.op_active && !rm->rdma.op_silent) + all_ops_are_silent = 0; + + if (ops_present && all_ops_are_silent + && !rm->m_rdma_cookie) + rm->data.op_active = 0; + } + + if (rm->data.op_active && !cp->cp_xmit_data_sent) { + rm->m_final_op = &rm->data; + + ret = conn->c_trans->xmit(conn, rm, + cp->cp_xmit_hdr_off, + cp->cp_xmit_sg, + cp->cp_xmit_data_off); + if (ret <= 0) + break; + + if (cp->cp_xmit_hdr_off < sizeof(struct rds_header)) { + tmp = min_t(int, ret, + sizeof(struct rds_header) - + cp->cp_xmit_hdr_off); + cp->cp_xmit_hdr_off += tmp; + ret -= tmp; + } + + sg = &rm->data.op_sg[cp->cp_xmit_sg]; + while (ret) { + tmp = min_t(int, ret, sg->length - + cp->cp_xmit_data_off); + cp->cp_xmit_data_off += tmp; + ret -= tmp; + if (cp->cp_xmit_data_off == sg->length) { + cp->cp_xmit_data_off = 0; + sg++; + cp->cp_xmit_sg++; + BUG_ON(ret != 0 && cp->cp_xmit_sg == + rm->data.op_nents); + } + } + + if (cp->cp_xmit_hdr_off == sizeof(struct rds_header) && + (cp->cp_xmit_sg == rm->data.op_nents)) + cp->cp_xmit_data_sent = 1; + } + + /* + * A rm will only take multiple times through this loop + * if there is a data op. Thus, if the data is sent (or there was + * none), then we're done with the rm. 
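The bookkeeping just above spreads the byte count returned by ->xmit across the header first and then the data scatterlist, recording (sg index, intra-segment offset) so the next pass resumes where this one stopped. The same walk over a plain array of segment lengths, as a sketch (48 is only a stand-in for sizeof(struct rds_header)):

#include <stdio.h>

#define HDR_LEN 48 /* stand-in for sizeof(struct rds_header) */

struct xmit_pos {
        unsigned int hdr_off;   /* cp_xmit_hdr_off  */
        unsigned int sg;        /* cp_xmit_sg       */
        unsigned int data_off;  /* cp_xmit_data_off */
};

/* Consume 'sent' bytes: header first, then the data segments. */
static void advance(struct xmit_pos *p, const unsigned int *seg_len,
                    unsigned int nsegs, unsigned int sent)
{
        if (p->hdr_off < HDR_LEN) {
                unsigned int n = sent < HDR_LEN - p->hdr_off ?
                                 sent : HDR_LEN - p->hdr_off;
                p->hdr_off += n;
                sent -= n;
        }
        while (sent && p->sg < nsegs) {
                unsigned int room = seg_len[p->sg] - p->data_off;
                unsigned int n = sent < room ? sent : room;

                p->data_off += n;
                sent -= n;
                if (p->data_off == seg_len[p->sg]) {
                        p->data_off = 0;
                        p->sg++;
                }
        }
}

int main(void)
{
        const unsigned int segs[] = { 100, 200 };
        struct xmit_pos pos = { 0, 0, 0 };

        advance(&pos, segs, 2, 120);  /* 48 header bytes + 72 into segment 0 */
        printf("hdr=%u sg=%u off=%u\n", pos.hdr_off, pos.sg, pos.data_off);
        advance(&pos, segs, 2, 128);  /* finishes segment 0, 100 into segment 1 */
        printf("hdr=%u sg=%u off=%u\n", pos.hdr_off, pos.sg, pos.data_off);
        return 0;
}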
+ */ + if (!rm->data.op_active || cp->cp_xmit_data_sent) { + cp->cp_xmit_rm = NULL; + cp->cp_xmit_sg = 0; + cp->cp_xmit_hdr_off = 0; + cp->cp_xmit_data_off = 0; + cp->cp_xmit_rdma_sent = 0; + cp->cp_xmit_atomic_sent = 0; + cp->cp_xmit_data_sent = 0; + + rds_message_put(rm); + } + } + +over_batch: + if (conn->c_trans->xmit_path_complete) + conn->c_trans->xmit_path_complete(cp); + release_in_xmit(cp); + + /* Nuke any messages we decided not to retransmit. */ + if (!list_empty(&to_be_dropped)) { + /* irqs on here, so we can put(), unlike above */ + list_for_each_entry(rm, &to_be_dropped, m_conn_item) + rds_message_put(rm); + rds_send_remove_from_sock(&to_be_dropped, RDS_RDMA_DROPPED); + } + + /* + * Other senders can queue a message after we last test the send queue + * but before we clear RDS_IN_XMIT. In that case they'd back off and + * not try and send their newly queued message. We need to check the + * send queue after having cleared RDS_IN_XMIT so that their message + * doesn't get stuck on the send queue. + * + * If the transport cannot continue (i.e ret != 0), then it must + * call us when more room is available, such as from the tx + * completion handler. + * + * We have an extra generation check here so that if someone manages + * to jump in after our release_in_xmit, we'll see that they have done + * some work and we will skip our goto + */ + if (ret == 0) { + bool raced; + + smp_mb(); + raced = send_gen != READ_ONCE(cp->cp_send_gen); + + if ((test_bit(0, &conn->c_map_queued) || + !list_empty(&cp->cp_send_queue)) && !raced) { + if (batch_count < send_batch_count) + goto restart; + rcu_read_lock(); + if (rds_destroy_pending(cp->cp_conn)) + ret = -ENETUNREACH; + else + queue_delayed_work(rds_wq, &cp->cp_send_w, 1); + rcu_read_unlock(); + } else if (raced) { + rds_stats_inc(s_send_lock_queue_raced); + } + } +out: + return ret; +} +EXPORT_SYMBOL_GPL(rds_send_xmit); + +static void rds_send_sndbuf_remove(struct rds_sock *rs, struct rds_message *rm) +{ + u32 len = be32_to_cpu(rm->m_inc.i_hdr.h_len); + + assert_spin_locked(&rs->rs_lock); + + BUG_ON(rs->rs_snd_bytes < len); + rs->rs_snd_bytes -= len; + + if (rs->rs_snd_bytes == 0) + rds_stats_inc(s_send_queue_empty); +} + +static inline int rds_send_is_acked(struct rds_message *rm, u64 ack, + is_acked_func is_acked) +{ + if (is_acked) + return is_acked(rm, ack); + return be64_to_cpu(rm->m_inc.i_hdr.h_sequence) <= ack; +} + +/* + * This is pretty similar to what happens below in the ACK + * handling code - except that we call here as soon as we get + * the IB send completion on the RDMA op and the accompanying + * message. 
+ */ +void rds_rdma_send_complete(struct rds_message *rm, int status) +{ + struct rds_sock *rs = NULL; + struct rm_rdma_op *ro; + struct rds_notifier *notifier; + unsigned long flags; + + spin_lock_irqsave(&rm->m_rs_lock, flags); + + ro = &rm->rdma; + if (test_bit(RDS_MSG_ON_SOCK, &rm->m_flags) && + ro->op_active && ro->op_notify && ro->op_notifier) { + notifier = ro->op_notifier; + rs = rm->m_rs; + sock_hold(rds_rs_to_sk(rs)); + + notifier->n_status = status; + spin_lock(&rs->rs_lock); + list_add_tail(¬ifier->n_list, &rs->rs_notify_queue); + spin_unlock(&rs->rs_lock); + + ro->op_notifier = NULL; + } + + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + + if (rs) { + rds_wake_sk_sleep(rs); + sock_put(rds_rs_to_sk(rs)); + } +} +EXPORT_SYMBOL_GPL(rds_rdma_send_complete); + +/* + * Just like above, except looks at atomic op + */ +void rds_atomic_send_complete(struct rds_message *rm, int status) +{ + struct rds_sock *rs = NULL; + struct rm_atomic_op *ao; + struct rds_notifier *notifier; + unsigned long flags; + + spin_lock_irqsave(&rm->m_rs_lock, flags); + + ao = &rm->atomic; + if (test_bit(RDS_MSG_ON_SOCK, &rm->m_flags) + && ao->op_active && ao->op_notify && ao->op_notifier) { + notifier = ao->op_notifier; + rs = rm->m_rs; + sock_hold(rds_rs_to_sk(rs)); + + notifier->n_status = status; + spin_lock(&rs->rs_lock); + list_add_tail(¬ifier->n_list, &rs->rs_notify_queue); + spin_unlock(&rs->rs_lock); + + ao->op_notifier = NULL; + } + + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + + if (rs) { + rds_wake_sk_sleep(rs); + sock_put(rds_rs_to_sk(rs)); + } +} +EXPORT_SYMBOL_GPL(rds_atomic_send_complete); + +/* + * This is the same as rds_rdma_send_complete except we + * don't do any locking - we have all the ingredients (message, + * socket, socket lock) and can just move the notifier. + */ +static inline void +__rds_send_complete(struct rds_sock *rs, struct rds_message *rm, int status) +{ + struct rm_rdma_op *ro; + struct rm_atomic_op *ao; + + ro = &rm->rdma; + if (ro->op_active && ro->op_notify && ro->op_notifier) { + ro->op_notifier->n_status = status; + list_add_tail(&ro->op_notifier->n_list, &rs->rs_notify_queue); + ro->op_notifier = NULL; + } + + ao = &rm->atomic; + if (ao->op_active && ao->op_notify && ao->op_notifier) { + ao->op_notifier->n_status = status; + list_add_tail(&ao->op_notifier->n_list, &rs->rs_notify_queue); + ao->op_notifier = NULL; + } + + /* No need to wake the app - caller does this */ +} + +/* + * This removes messages from the socket's list if they're on it. The list + * argument must be private to the caller, we must be able to modify it + * without locks. The messages must have a reference held for their + * position on the list. This function will drop that reference after + * removing the messages from the 'messages' list regardless of if it found + * the messages on the socket list or not. + */ +static void rds_send_remove_from_sock(struct list_head *messages, int status) +{ + unsigned long flags; + struct rds_sock *rs = NULL; + struct rds_message *rm; + + while (!list_empty(messages)) { + int was_on_sock = 0; + + rm = list_entry(messages->next, struct rds_message, + m_conn_item); + list_del_init(&rm->m_conn_item); + + /* + * If we see this flag cleared then we're *sure* that someone + * else beat us to removing it from the sock. If we race + * with their flag update we'll get the lock and then really + * see that the flag has been cleared. + * + * The message spinlock makes sure nobody clears rm->m_rs + * while we're messing with it. 
It does not prevent the + * message from being removed from the socket, though. + */ + spin_lock_irqsave(&rm->m_rs_lock, flags); + if (!test_bit(RDS_MSG_ON_SOCK, &rm->m_flags)) + goto unlock_and_drop; + + if (rs != rm->m_rs) { + if (rs) { + rds_wake_sk_sleep(rs); + sock_put(rds_rs_to_sk(rs)); + } + rs = rm->m_rs; + if (rs) + sock_hold(rds_rs_to_sk(rs)); + } + if (!rs) + goto unlock_and_drop; + spin_lock(&rs->rs_lock); + + if (test_and_clear_bit(RDS_MSG_ON_SOCK, &rm->m_flags)) { + struct rm_rdma_op *ro = &rm->rdma; + struct rds_notifier *notifier; + + list_del_init(&rm->m_sock_item); + rds_send_sndbuf_remove(rs, rm); + + if (ro->op_active && ro->op_notifier && + (ro->op_notify || (ro->op_recverr && status))) { + notifier = ro->op_notifier; + list_add_tail(¬ifier->n_list, + &rs->rs_notify_queue); + if (!notifier->n_status) + notifier->n_status = status; + rm->rdma.op_notifier = NULL; + } + was_on_sock = 1; + } + spin_unlock(&rs->rs_lock); + +unlock_and_drop: + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + rds_message_put(rm); + if (was_on_sock) + rds_message_put(rm); + } + + if (rs) { + rds_wake_sk_sleep(rs); + sock_put(rds_rs_to_sk(rs)); + } +} + +/* + * Transports call here when they've determined that the receiver queued + * messages up to, and including, the given sequence number. Messages are + * moved to the retrans queue when rds_send_xmit picks them off the send + * queue. This means that in the TCP case, the message may not have been + * assigned the m_ack_seq yet - but that's fine as long as tcp_is_acked + * checks the RDS_MSG_HAS_ACK_SEQ bit. + */ +void rds_send_path_drop_acked(struct rds_conn_path *cp, u64 ack, + is_acked_func is_acked) +{ + struct rds_message *rm, *tmp; + unsigned long flags; + LIST_HEAD(list); + + spin_lock_irqsave(&cp->cp_lock, flags); + + list_for_each_entry_safe(rm, tmp, &cp->cp_retrans, m_conn_item) { + if (!rds_send_is_acked(rm, ack, is_acked)) + break; + + list_move(&rm->m_conn_item, &list); + clear_bit(RDS_MSG_ON_CONN, &rm->m_flags); + } + + /* order flag updates with spin locks */ + if (!list_empty(&list)) + smp_mb__after_atomic(); + + spin_unlock_irqrestore(&cp->cp_lock, flags); + + /* now remove the messages from the sock list as needed */ + rds_send_remove_from_sock(&list, RDS_RDMA_SUCCESS); +} +EXPORT_SYMBOL_GPL(rds_send_path_drop_acked); + +void rds_send_drop_acked(struct rds_connection *conn, u64 ack, + is_acked_func is_acked) +{ + WARN_ON(conn->c_trans->t_mp_capable); + rds_send_path_drop_acked(&conn->c_path[0], ack, is_acked); +} +EXPORT_SYMBOL_GPL(rds_send_drop_acked); + +void rds_send_drop_to(struct rds_sock *rs, struct sockaddr_in6 *dest) +{ + struct rds_message *rm, *tmp; + struct rds_connection *conn; + struct rds_conn_path *cp; + unsigned long flags; + LIST_HEAD(list); + + /* get all the messages we're dropping under the rs lock */ + spin_lock_irqsave(&rs->rs_lock, flags); + + list_for_each_entry_safe(rm, tmp, &rs->rs_send_queue, m_sock_item) { + if (dest && + (!ipv6_addr_equal(&dest->sin6_addr, &rm->m_daddr) || + dest->sin6_port != rm->m_inc.i_hdr.h_dport)) + continue; + + list_move(&rm->m_sock_item, &list); + rds_send_sndbuf_remove(rs, rm); + clear_bit(RDS_MSG_ON_SOCK, &rm->m_flags); + } + + /* order flag updates with the rs lock */ + smp_mb__after_atomic(); + + spin_unlock_irqrestore(&rs->rs_lock, flags); + + if (list_empty(&list)) + return; + + /* Remove the messages from the conn */ + list_for_each_entry(rm, &list, m_sock_item) { + + conn = rm->m_inc.i_conn; + if (conn->c_trans->t_mp_capable) + cp = rm->m_inc.i_conn_path; + else + 
cp = &conn->c_path[0]; + + spin_lock_irqsave(&cp->cp_lock, flags); + /* + * Maybe someone else beat us to removing rm from the conn. + * If we race with their flag update we'll get the lock and + * then really see that the flag has been cleared. + */ + if (!test_and_clear_bit(RDS_MSG_ON_CONN, &rm->m_flags)) { + spin_unlock_irqrestore(&cp->cp_lock, flags); + continue; + } + list_del_init(&rm->m_conn_item); + spin_unlock_irqrestore(&cp->cp_lock, flags); + + /* + * Couldn't grab m_rs_lock in top loop (lock ordering), + * but we can now. + */ + spin_lock_irqsave(&rm->m_rs_lock, flags); + + spin_lock(&rs->rs_lock); + __rds_send_complete(rs, rm, RDS_RDMA_CANCELED); + spin_unlock(&rs->rs_lock); + + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + + rds_message_put(rm); + } + + rds_wake_sk_sleep(rs); + + while (!list_empty(&list)) { + rm = list_entry(list.next, struct rds_message, m_sock_item); + list_del_init(&rm->m_sock_item); + rds_message_wait(rm); + + /* just in case the code above skipped this message + * because RDS_MSG_ON_CONN wasn't set, run it again here + * taking m_rs_lock is the only thing that keeps us + * from racing with ack processing. + */ + spin_lock_irqsave(&rm->m_rs_lock, flags); + + spin_lock(&rs->rs_lock); + __rds_send_complete(rs, rm, RDS_RDMA_CANCELED); + spin_unlock(&rs->rs_lock); + + spin_unlock_irqrestore(&rm->m_rs_lock, flags); + + rds_message_put(rm); + } +} + +/* + * we only want this to fire once so we use the callers 'queued'. It's + * possible that another thread can race with us and remove the + * message from the flow with RDS_CANCEL_SENT_TO. + */ +static int rds_send_queue_rm(struct rds_sock *rs, struct rds_connection *conn, + struct rds_conn_path *cp, + struct rds_message *rm, __be16 sport, + __be16 dport, int *queued) +{ + unsigned long flags; + u32 len; + + if (*queued) + goto out; + + len = be32_to_cpu(rm->m_inc.i_hdr.h_len); + + /* this is the only place which holds both the socket's rs_lock + * and the connection's c_lock */ + spin_lock_irqsave(&rs->rs_lock, flags); + + /* + * If there is a little space in sndbuf, we don't queue anything, + * and userspace gets -EAGAIN. But poll() indicates there's send + * room. This can lead to bad behavior (spinning) if snd_bytes isn't + * freed up by incoming acks. So we check the *old* value of + * rs_snd_bytes here to allow the last msg to exceed the buffer, + * and poll() now knows no more data can be sent. + */ + if (rs->rs_snd_bytes < rds_sk_sndbuf(rs)) { + rs->rs_snd_bytes += len; + + /* let recv side know we are close to send space exhaustion. + * This is probably not the optimal way to do it, as this + * means we set the flag on *all* messages as soon as our + * throughput hits a certain threshold. 
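+ * Requesting acks sooner lets rds_send_drop_acked() free rs_snd_bytes
+ * and wake anyone blocked in poll() or sendmsg().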
+ */ + if (rs->rs_snd_bytes >= rds_sk_sndbuf(rs) / 2) + set_bit(RDS_MSG_ACK_REQUIRED, &rm->m_flags); + + list_add_tail(&rm->m_sock_item, &rs->rs_send_queue); + set_bit(RDS_MSG_ON_SOCK, &rm->m_flags); + rds_message_addref(rm); + sock_hold(rds_rs_to_sk(rs)); + rm->m_rs = rs; + + /* The code ordering is a little weird, but we're + trying to minimize the time we hold c_lock */ + rds_message_populate_header(&rm->m_inc.i_hdr, sport, dport, 0); + rm->m_inc.i_conn = conn; + rm->m_inc.i_conn_path = cp; + rds_message_addref(rm); + + spin_lock(&cp->cp_lock); + rm->m_inc.i_hdr.h_sequence = cpu_to_be64(cp->cp_next_tx_seq++); + list_add_tail(&rm->m_conn_item, &cp->cp_send_queue); + set_bit(RDS_MSG_ON_CONN, &rm->m_flags); + spin_unlock(&cp->cp_lock); + + rdsdebug("queued msg %p len %d, rs %p bytes %d seq %llu\n", + rm, len, rs, rs->rs_snd_bytes, + (unsigned long long)be64_to_cpu(rm->m_inc.i_hdr.h_sequence)); + + *queued = 1; + } + + spin_unlock_irqrestore(&rs->rs_lock, flags); +out: + return *queued; +} + +/* + * rds_message is getting to be quite complicated, and we'd like to allocate + * it all in one go. This figures out how big it needs to be up front. + */ +static int rds_rm_size(struct msghdr *msg, int num_sgs, + struct rds_iov_vector_arr *vct) +{ + struct cmsghdr *cmsg; + int size = 0; + int cmsg_groups = 0; + int retval; + bool zcopy_cookie = false; + struct rds_iov_vector *iov, *tmp_iov; + + if (num_sgs < 0) + return -EINVAL; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_RDS) + continue; + + switch (cmsg->cmsg_type) { + case RDS_CMSG_RDMA_ARGS: + if (vct->indx >= vct->len) { + vct->len += vct->incr; + tmp_iov = + krealloc(vct->vec, + vct->len * + sizeof(struct rds_iov_vector), + GFP_KERNEL); + if (!tmp_iov) { + vct->len -= vct->incr; + return -ENOMEM; + } + vct->vec = tmp_iov; + } + iov = &vct->vec[vct->indx]; + memset(iov, 0, sizeof(struct rds_iov_vector)); + vct->indx++; + cmsg_groups |= 1; + retval = rds_rdma_extra_size(CMSG_DATA(cmsg), iov); + if (retval < 0) + return retval; + size += retval; + + break; + + case RDS_CMSG_ZCOPY_COOKIE: + zcopy_cookie = true; + fallthrough; + + case RDS_CMSG_RDMA_DEST: + case RDS_CMSG_RDMA_MAP: + cmsg_groups |= 2; + /* these are valid but do no add any size */ + break; + + case RDS_CMSG_ATOMIC_CSWP: + case RDS_CMSG_ATOMIC_FADD: + case RDS_CMSG_MASKED_ATOMIC_CSWP: + case RDS_CMSG_MASKED_ATOMIC_FADD: + cmsg_groups |= 1; + size += sizeof(struct scatterlist); + break; + + default: + return -EINVAL; + } + + } + + if ((msg->msg_flags & MSG_ZEROCOPY) && !zcopy_cookie) + return -EINVAL; + + size += num_sgs * sizeof(struct scatterlist); + + /* Ensure (DEST, MAP) are never used with (ARGS, ATOMIC) */ + if (cmsg_groups == 3) + return -EINVAL; + + return size; +} + +static int rds_cmsg_zcopy(struct rds_sock *rs, struct rds_message *rm, + struct cmsghdr *cmsg) +{ + u32 *cookie; + + if (cmsg->cmsg_len < CMSG_LEN(sizeof(*cookie)) || + !rm->data.op_mmp_znotifier) + return -EINVAL; + cookie = CMSG_DATA(cmsg); + rm->data.op_mmp_znotifier->z_cookie = *cookie; + return 0; +} + +static int rds_cmsg_send(struct rds_sock *rs, struct rds_message *rm, + struct msghdr *msg, int *allocated_mr, + struct rds_iov_vector_arr *vct) +{ + struct cmsghdr *cmsg; + int ret = 0, ind = 0; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_RDS) + continue; + + /* As a side effect, RDMA_DEST and RDMA_MAP will set + * rm->rdma.m_rdma_cookie and rm->rdma.m_rdma_mr. 
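+ * Each RDS_CMSG_RDMA_ARGS below consumes the next rds_iov_vector slot
+ * that rds_rm_size() sized for it, in cmsg order.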
+ */ + switch (cmsg->cmsg_type) { + case RDS_CMSG_RDMA_ARGS: + if (ind >= vct->indx) + return -ENOMEM; + ret = rds_cmsg_rdma_args(rs, rm, cmsg, &vct->vec[ind]); + ind++; + break; + + case RDS_CMSG_RDMA_DEST: + ret = rds_cmsg_rdma_dest(rs, rm, cmsg); + break; + + case RDS_CMSG_RDMA_MAP: + ret = rds_cmsg_rdma_map(rs, rm, cmsg); + if (!ret) + *allocated_mr = 1; + else if (ret == -ENODEV) + /* Accommodate the get_mr() case which can fail + * if connection isn't established yet. + */ + ret = -EAGAIN; + break; + case RDS_CMSG_ATOMIC_CSWP: + case RDS_CMSG_ATOMIC_FADD: + case RDS_CMSG_MASKED_ATOMIC_CSWP: + case RDS_CMSG_MASKED_ATOMIC_FADD: + ret = rds_cmsg_atomic(rs, rm, cmsg); + break; + + case RDS_CMSG_ZCOPY_COOKIE: + ret = rds_cmsg_zcopy(rs, rm, cmsg); + break; + + default: + return -EINVAL; + } + + if (ret) + break; + } + + return ret; +} + +static int rds_send_mprds_hash(struct rds_sock *rs, + struct rds_connection *conn, int nonblock) +{ + int hash; + + if (conn->c_npaths == 0) + hash = RDS_MPATH_HASH(rs, RDS_MPATH_WORKERS); + else + hash = RDS_MPATH_HASH(rs, conn->c_npaths); + if (conn->c_npaths == 0 && hash != 0) { + rds_send_ping(conn, 0); + + /* The underlying connection is not up yet. Need to wait + * until it is up to be sure that the non-zero c_path can be + * used. But if we are interrupted, we have to use the zero + * c_path in case the connection ends up being non-MP capable. + */ + if (conn->c_npaths == 0) { + /* Cannot wait for the connection be made, so just use + * the base c_path. + */ + if (nonblock) + return 0; + if (wait_event_interruptible(conn->c_hs_waitq, + conn->c_npaths != 0)) + hash = 0; + } + if (conn->c_npaths == 1) + hash = 0; + } + return hash; +} + +static int rds_rdma_bytes(struct msghdr *msg, size_t *rdma_bytes) +{ + struct rds_rdma_args *args; + struct cmsghdr *cmsg; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + if (cmsg->cmsg_level != SOL_RDS) + continue; + + if (cmsg->cmsg_type == RDS_CMSG_RDMA_ARGS) { + if (cmsg->cmsg_len < + CMSG_LEN(sizeof(struct rds_rdma_args))) + return -EINVAL; + args = CMSG_DATA(cmsg); + *rdma_bytes += args->remote_vec.bytes; + } + } + return 0; +} + +int rds_sendmsg(struct socket *sock, struct msghdr *msg, size_t payload_len) +{ + struct sock *sk = sock->sk; + struct rds_sock *rs = rds_sk_to_rs(sk); + DECLARE_SOCKADDR(struct sockaddr_in6 *, sin6, msg->msg_name); + DECLARE_SOCKADDR(struct sockaddr_in *, usin, msg->msg_name); + __be16 dport; + struct rds_message *rm = NULL; + struct rds_connection *conn; + int ret = 0; + int queued = 0, allocated_mr = 0; + int nonblock = msg->msg_flags & MSG_DONTWAIT; + long timeo = sock_sndtimeo(sk, nonblock); + struct rds_conn_path *cpath; + struct in6_addr daddr; + __u32 scope_id = 0; + size_t total_payload_len = payload_len, rdma_payload_len = 0; + bool zcopy = ((msg->msg_flags & MSG_ZEROCOPY) && + sock_flag(rds_rs_to_sk(rs), SOCK_ZEROCOPY)); + int num_sgs = DIV_ROUND_UP(payload_len, PAGE_SIZE); + int namelen; + struct rds_iov_vector_arr vct; + int ind; + + memset(&vct, 0, sizeof(vct)); + + /* expect 1 RDMA CMSG per rds_sendmsg. can still grow if more needed. 
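+ * vct.incr is the step by which rds_rm_size() grows the vector with
+ * krealloc() when more RDMA cmsgs are present.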
*/ + vct.incr = 1; + + /* Mirror Linux UDP mirror of BSD error message compatibility */ + /* XXX: Perhaps MSG_MORE someday */ + if (msg->msg_flags & ~(MSG_DONTWAIT | MSG_CMSG_COMPAT | MSG_ZEROCOPY)) { + ret = -EOPNOTSUPP; + goto out; + } + + namelen = msg->msg_namelen; + if (namelen != 0) { + if (namelen < sizeof(*usin)) { + ret = -EINVAL; + goto out; + } + switch (usin->sin_family) { + case AF_INET: + if (usin->sin_addr.s_addr == htonl(INADDR_ANY) || + usin->sin_addr.s_addr == htonl(INADDR_BROADCAST) || + ipv4_is_multicast(usin->sin_addr.s_addr)) { + ret = -EINVAL; + goto out; + } + ipv6_addr_set_v4mapped(usin->sin_addr.s_addr, &daddr); + dport = usin->sin_port; + break; + +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: { + int addr_type; + + if (namelen < sizeof(*sin6)) { + ret = -EINVAL; + goto out; + } + addr_type = ipv6_addr_type(&sin6->sin6_addr); + if (!(addr_type & IPV6_ADDR_UNICAST)) { + __be32 addr4; + + if (!(addr_type & IPV6_ADDR_MAPPED)) { + ret = -EINVAL; + goto out; + } + + /* It is a mapped address. Need to do some + * sanity checks. + */ + addr4 = sin6->sin6_addr.s6_addr32[3]; + if (addr4 == htonl(INADDR_ANY) || + addr4 == htonl(INADDR_BROADCAST) || + ipv4_is_multicast(addr4)) { + ret = -EINVAL; + goto out; + } + } + if (addr_type & IPV6_ADDR_LINKLOCAL) { + if (sin6->sin6_scope_id == 0) { + ret = -EINVAL; + goto out; + } + scope_id = sin6->sin6_scope_id; + } + + daddr = sin6->sin6_addr; + dport = sin6->sin6_port; + break; + } +#endif + + default: + ret = -EINVAL; + goto out; + } + } else { + /* We only care about consistency with ->connect() */ + lock_sock(sk); + daddr = rs->rs_conn_addr; + dport = rs->rs_conn_port; + scope_id = rs->rs_bound_scope_id; + release_sock(sk); + } + + lock_sock(sk); + if (ipv6_addr_any(&rs->rs_bound_addr) || ipv6_addr_any(&daddr)) { + release_sock(sk); + ret = -ENOTCONN; + goto out; + } else if (namelen != 0) { + /* Cannot send to an IPv4 address using an IPv6 source + * address and cannot send to an IPv6 address using an + * IPv4 source address. + */ + if (ipv6_addr_v4mapped(&daddr) ^ + ipv6_addr_v4mapped(&rs->rs_bound_addr)) { + release_sock(sk); + ret = -EOPNOTSUPP; + goto out; + } + /* If the socket is already bound to a link local address, + * it can only send to peers on the same link. But allow + * communicating between link local and non-link local address. 
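+ * If the sender did not supply a scope_id, the socket's bound scope_id
+ * is adopted; two different non-zero scope_ids are rejected with -EINVAL.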
+ */ + if (scope_id != rs->rs_bound_scope_id) { + if (!scope_id) { + scope_id = rs->rs_bound_scope_id; + } else if (rs->rs_bound_scope_id) { + release_sock(sk); + ret = -EINVAL; + goto out; + } + } + } + release_sock(sk); + + ret = rds_rdma_bytes(msg, &rdma_payload_len); + if (ret) + goto out; + + total_payload_len += rdma_payload_len; + if (max_t(size_t, payload_len, rdma_payload_len) > RDS_MAX_MSG_SIZE) { + ret = -EMSGSIZE; + goto out; + } + + if (payload_len > rds_sk_sndbuf(rs)) { + ret = -EMSGSIZE; + goto out; + } + + if (zcopy) { + if (rs->rs_transport->t_type != RDS_TRANS_TCP) { + ret = -EOPNOTSUPP; + goto out; + } + num_sgs = iov_iter_npages(&msg->msg_iter, INT_MAX); + } + /* size of rm including all sgs */ + ret = rds_rm_size(msg, num_sgs, &vct); + if (ret < 0) + goto out; + + rm = rds_message_alloc(ret, GFP_KERNEL); + if (!rm) { + ret = -ENOMEM; + goto out; + } + + /* Attach data to the rm */ + if (payload_len) { + rm->data.op_sg = rds_message_alloc_sgs(rm, num_sgs); + if (IS_ERR(rm->data.op_sg)) { + ret = PTR_ERR(rm->data.op_sg); + goto out; + } + ret = rds_message_copy_from_user(rm, &msg->msg_iter, zcopy); + if (ret) + goto out; + } + rm->data.op_active = 1; + + rm->m_daddr = daddr; + + /* rds_conn_create has a spinlock that runs with IRQ off. + * Caching the conn in the socket helps a lot. */ + if (rs->rs_conn && ipv6_addr_equal(&rs->rs_conn->c_faddr, &daddr) && + rs->rs_tos == rs->rs_conn->c_tos) { + conn = rs->rs_conn; + } else { + conn = rds_conn_create_outgoing(sock_net(sock->sk), + &rs->rs_bound_addr, &daddr, + rs->rs_transport, rs->rs_tos, + sock->sk->sk_allocation, + scope_id); + if (IS_ERR(conn)) { + ret = PTR_ERR(conn); + goto out; + } + rs->rs_conn = conn; + } + + if (conn->c_trans->t_mp_capable) + cpath = &conn->c_path[rds_send_mprds_hash(rs, conn, nonblock)]; + else + cpath = &conn->c_path[0]; + + rm->m_conn_path = cpath; + + /* Parse any control messages the user may have included. */ + ret = rds_cmsg_send(rs, rm, msg, &allocated_mr, &vct); + if (ret) { + /* Trigger connection so that its ready for the next retry */ + if (ret == -EAGAIN) + rds_conn_connect_if_down(conn); + goto out; + } + + if (rm->rdma.op_active && !conn->c_trans->xmit_rdma) { + printk_ratelimited(KERN_NOTICE "rdma_op %p conn xmit_rdma %p\n", + &rm->rdma, conn->c_trans->xmit_rdma); + ret = -EOPNOTSUPP; + goto out; + } + + if (rm->atomic.op_active && !conn->c_trans->xmit_atomic) { + printk_ratelimited(KERN_NOTICE "atomic_op %p conn xmit_atomic %p\n", + &rm->atomic, conn->c_trans->xmit_atomic); + ret = -EOPNOTSUPP; + goto out; + } + + if (rds_destroy_pending(conn)) { + ret = -EAGAIN; + goto out; + } + + if (rds_conn_path_down(cpath)) + rds_check_all_paths(conn); + + ret = rds_cong_wait(conn->c_fcong, dport, nonblock, rs); + if (ret) { + rs->rs_seen_congestion = 1; + goto out; + } + while (!rds_send_queue_rm(rs, conn, cpath, rm, rs->rs_bound_port, + dport, &queued)) { + rds_stats_inc(s_send_queue_full); + + if (nonblock) { + ret = -EAGAIN; + goto out; + } + + timeo = wait_event_interruptible_timeout(*sk_sleep(sk), + rds_send_queue_rm(rs, conn, cpath, rm, + rs->rs_bound_port, + dport, + &queued), + timeo); + rdsdebug("sendmsg woke queued %d timeo %ld\n", queued, timeo); + if (timeo > 0 || timeo == MAX_SCHEDULE_TIMEOUT) + continue; + + ret = timeo; + if (ret == 0) + ret = -ETIMEDOUT; + goto out; + } + + /* + * By now we've committed to the send. We reuse rds_send_worker() + * to retry sends in the rds thread if the transport asks us to. 
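+ * An -ENOMEM or -EAGAIN from rds_send_xmit() is therefore not returned
+ * to the caller; the send is requeued on rds_wq unless the connection is
+ * being torn down.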
+ */ + rds_stats_inc(s_send_queued); + + ret = rds_send_xmit(cpath); + if (ret == -ENOMEM || ret == -EAGAIN) { + ret = 0; + rcu_read_lock(); + if (rds_destroy_pending(cpath->cp_conn)) + ret = -ENETUNREACH; + else + queue_delayed_work(rds_wq, &cpath->cp_send_w, 1); + rcu_read_unlock(); + } + if (ret) + goto out; + rds_message_put(rm); + + for (ind = 0; ind < vct.indx; ind++) + kfree(vct.vec[ind].iov); + kfree(vct.vec); + + return payload_len; + +out: + for (ind = 0; ind < vct.indx; ind++) + kfree(vct.vec[ind].iov); + kfree(vct.vec); + + /* If the user included a RDMA_MAP cmsg, we allocated a MR on the fly. + * If the sendmsg goes through, we keep the MR. If it fails with EAGAIN + * or in any other way, we need to destroy the MR again */ + if (allocated_mr) + rds_rdma_unuse(rs, rds_rdma_cookie_key(rm->m_rdma_cookie), 1); + + if (rm) + rds_message_put(rm); + return ret; +} + +/* + * send out a probe. Can be shared by rds_send_ping, + * rds_send_pong, rds_send_hb. + * rds_send_hb should use h_flags + * RDS_FLAG_HB_PING|RDS_FLAG_ACK_REQUIRED + * or + * RDS_FLAG_HB_PONG|RDS_FLAG_ACK_REQUIRED + */ +static int +rds_send_probe(struct rds_conn_path *cp, __be16 sport, + __be16 dport, u8 h_flags) +{ + struct rds_message *rm; + unsigned long flags; + int ret = 0; + + rm = rds_message_alloc(0, GFP_ATOMIC); + if (!rm) { + ret = -ENOMEM; + goto out; + } + + rm->m_daddr = cp->cp_conn->c_faddr; + rm->data.op_active = 1; + + rds_conn_path_connect_if_down(cp); + + ret = rds_cong_wait(cp->cp_conn->c_fcong, dport, 1, NULL); + if (ret) + goto out; + + spin_lock_irqsave(&cp->cp_lock, flags); + list_add_tail(&rm->m_conn_item, &cp->cp_send_queue); + set_bit(RDS_MSG_ON_CONN, &rm->m_flags); + rds_message_addref(rm); + rm->m_inc.i_conn = cp->cp_conn; + rm->m_inc.i_conn_path = cp; + + rds_message_populate_header(&rm->m_inc.i_hdr, sport, dport, + cp->cp_next_tx_seq); + rm->m_inc.i_hdr.h_flags |= h_flags; + cp->cp_next_tx_seq++; + + if (RDS_HS_PROBE(be16_to_cpu(sport), be16_to_cpu(dport)) && + cp->cp_conn->c_trans->t_mp_capable) { + u16 npaths = cpu_to_be16(RDS_MPATH_WORKERS); + u32 my_gen_num = cpu_to_be32(cp->cp_conn->c_my_gen_num); + + rds_message_add_extension(&rm->m_inc.i_hdr, + RDS_EXTHDR_NPATHS, &npaths, + sizeof(npaths)); + rds_message_add_extension(&rm->m_inc.i_hdr, + RDS_EXTHDR_GEN_NUM, + &my_gen_num, + sizeof(u32)); + } + spin_unlock_irqrestore(&cp->cp_lock, flags); + + rds_stats_inc(s_send_queued); + rds_stats_inc(s_send_pong); + + /* schedule the send work on rds_wq */ + rcu_read_lock(); + if (!rds_destroy_pending(cp->cp_conn)) + queue_delayed_work(rds_wq, &cp->cp_send_w, 1); + rcu_read_unlock(); + + rds_message_put(rm); + return 0; + +out: + if (rm) + rds_message_put(rm); + return ret; +} + +int +rds_send_pong(struct rds_conn_path *cp, __be16 dport) +{ + return rds_send_probe(cp, 0, dport, 0); +} + +void +rds_send_ping(struct rds_connection *conn, int cp_index) +{ + unsigned long flags; + struct rds_conn_path *cp = &conn->c_path[cp_index]; + + spin_lock_irqsave(&cp->cp_lock, flags); + if (conn->c_ping_triggered) { + spin_unlock_irqrestore(&cp->cp_lock, flags); + return; + } + conn->c_ping_triggered = 1; + spin_unlock_irqrestore(&cp->cp_lock, flags); + rds_send_probe(cp, cpu_to_be16(RDS_FLAG_PROBE_PORT), 0, 0); +} +EXPORT_SYMBOL_GPL(rds_send_ping); diff --git a/net/rds/stats.c b/net/rds/stats.c new file mode 100644 index 000000000..9e87da43c --- /dev/null +++ b/net/rds/stats.c @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. 
+ * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include + +#include "rds.h" + +DEFINE_PER_CPU_SHARED_ALIGNED(struct rds_statistics, rds_stats); +EXPORT_PER_CPU_SYMBOL_GPL(rds_stats); + +/* :.,$s/unsigned long\>.*\= sizeof(ctr.name)); + strncpy(ctr.name, names[i], sizeof(ctr.name) - 1); + ctr.name[sizeof(ctr.name) - 1] = '\0'; + ctr.value = values[i]; + + rds_info_copy(iter, &ctr, sizeof(ctr)); + } +} +EXPORT_SYMBOL_GPL(rds_stats_info_copy); + +/* + * This gives global counters across all the transports. The strings + * are copied in so that the tool doesn't need knowledge of the specific + * stats that we're exporting. Some are pretty implementation dependent + * and may change over time. That doesn't stop them from being useful. + * + * This is the only function in the chain that knows about the byte granular + * length in userspace. It converts it to number of stat entries that the + * rest of the functions operate in. + */ +static void rds_stats_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds_statistics stats = {0, }; + uint64_t *src; + uint64_t *sum; + size_t i; + int cpu; + unsigned int avail; + + avail = len / sizeof(struct rds_info_counter); + + if (avail < ARRAY_SIZE(rds_stat_names)) { + avail = 0; + goto trans; + } + + for_each_online_cpu(cpu) { + src = (uint64_t *)&(per_cpu(rds_stats, cpu)); + sum = (uint64_t *)&stats; + for (i = 0; i < sizeof(stats) / sizeof(uint64_t); i++) + *(sum++) += *(src++); + } + + rds_stats_info_copy(iter, (uint64_t *)&stats, rds_stat_names, + ARRAY_SIZE(rds_stat_names)); + avail -= ARRAY_SIZE(rds_stat_names); + +trans: + lens->each = sizeof(struct rds_info_counter); + lens->nr = rds_trans_stats_info_copy(iter, avail) + + ARRAY_SIZE(rds_stat_names); +} + +void rds_stats_exit(void) +{ + rds_info_deregister_func(RDS_INFO_COUNTERS, rds_stats_info); +} + +int rds_stats_init(void) +{ + rds_info_register_func(RDS_INFO_COUNTERS, rds_stats_info); + return 0; +} diff --git a/net/rds/sysctl.c b/net/rds/sysctl.c new file mode 100644 index 000000000..e381bbcd9 --- /dev/null +++ b/net/rds/sysctl.c @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. 
+ * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "rds.h" + +static struct ctl_table_header *rds_sysctl_reg_table; + +static unsigned long rds_sysctl_reconnect_min = 1; +static unsigned long rds_sysctl_reconnect_max = ~0UL; + +unsigned long rds_sysctl_reconnect_min_jiffies; +unsigned long rds_sysctl_reconnect_max_jiffies = HZ; + +unsigned int rds_sysctl_max_unacked_packets = 8; +unsigned int rds_sysctl_max_unacked_bytes = (16 << 20); + +unsigned int rds_sysctl_ping_enable = 1; + +static struct ctl_table rds_sysctl_rds_table[] = { + { + .procname = "reconnect_min_delay_ms", + .data = &rds_sysctl_reconnect_min_jiffies, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = &rds_sysctl_reconnect_min, + .extra2 = &rds_sysctl_reconnect_max_jiffies, + }, + { + .procname = "reconnect_max_delay_ms", + .data = &rds_sysctl_reconnect_max_jiffies, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = &rds_sysctl_reconnect_min_jiffies, + .extra2 = &rds_sysctl_reconnect_max, + }, + { + .procname = "max_unacked_packets", + .data = &rds_sysctl_max_unacked_packets, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "max_unacked_bytes", + .data = &rds_sysctl_max_unacked_bytes, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ping_enable", + .data = &rds_sysctl_ping_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { } +}; + +void rds_sysctl_exit(void) +{ + unregister_net_sysctl_table(rds_sysctl_reg_table); +} + +int rds_sysctl_init(void) +{ + rds_sysctl_reconnect_min = msecs_to_jiffies(1); + rds_sysctl_reconnect_min_jiffies = rds_sysctl_reconnect_min; + + rds_sysctl_reg_table = + register_net_sysctl(&init_net, "net/rds", rds_sysctl_rds_table); + if (!rds_sysctl_reg_table) + return -ENOMEM; + return 0; +} diff --git a/net/rds/tcp.c b/net/rds/tcp.c new file mode 100644 index 000000000..4444fd82b --- /dev/null +++ b/net/rds/tcp.c @@ -0,0 +1,754 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. 
+ * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rds.h" +#include "tcp.h" + +/* only for info exporting */ +static DEFINE_SPINLOCK(rds_tcp_tc_list_lock); +static LIST_HEAD(rds_tcp_tc_list); + +/* rds_tcp_tc_count counts only IPv4 connections. + * rds6_tcp_tc_count counts both IPv4 and IPv6 connections. + */ +static unsigned int rds_tcp_tc_count; +#if IS_ENABLED(CONFIG_IPV6) +static unsigned int rds6_tcp_tc_count; +#endif + +/* Track rds_tcp_connection structs so they can be cleaned up */ +static DEFINE_SPINLOCK(rds_tcp_conn_lock); +static LIST_HEAD(rds_tcp_conn_list); +static atomic_t rds_tcp_unloading = ATOMIC_INIT(0); + +static struct kmem_cache *rds_tcp_conn_slab; + +static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *fpos); + +static int rds_tcp_min_sndbuf = SOCK_MIN_SNDBUF; +static int rds_tcp_min_rcvbuf = SOCK_MIN_RCVBUF; + +static struct ctl_table rds_tcp_sysctl_table[] = { +#define RDS_TCP_SNDBUF 0 + { + .procname = "rds_tcp_sndbuf", + /* data is per-net pointer */ + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = rds_tcp_skbuf_handler, + .extra1 = &rds_tcp_min_sndbuf, + }, +#define RDS_TCP_RCVBUF 1 + { + .procname = "rds_tcp_rcvbuf", + /* data is per-net pointer */ + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = rds_tcp_skbuf_handler, + .extra1 = &rds_tcp_min_rcvbuf, + }, + { } +}; + +u32 rds_tcp_write_seq(struct rds_tcp_connection *tc) +{ + /* seq# of the last byte of data in tcp send buffer */ + return tcp_sk(tc->t_sock->sk)->write_seq; +} + +u32 rds_tcp_snd_una(struct rds_tcp_connection *tc) +{ + return tcp_sk(tc->t_sock->sk)->snd_una; +} + +void rds_tcp_restore_callbacks(struct socket *sock, + struct rds_tcp_connection *tc) +{ + rdsdebug("restoring sock %p callbacks from tc %p\n", sock, tc); + write_lock_bh(&sock->sk->sk_callback_lock); + + /* done under the callback_lock to serialize with write_space */ + spin_lock(&rds_tcp_tc_list_lock); + list_del_init(&tc->t_list_item); +#if IS_ENABLED(CONFIG_IPV6) + rds6_tcp_tc_count--; +#endif + if (!tc->t_cpath->cp_conn->c_isv6) + rds_tcp_tc_count--; + spin_unlock(&rds_tcp_tc_list_lock); + + tc->t_sock = 
NULL; + + sock->sk->sk_write_space = tc->t_orig_write_space; + sock->sk->sk_data_ready = tc->t_orig_data_ready; + sock->sk->sk_state_change = tc->t_orig_state_change; + sock->sk->sk_user_data = NULL; + + write_unlock_bh(&sock->sk->sk_callback_lock); +} + +/* + * rds_tcp_reset_callbacks() switches the connection path to the new sock and + * returns the existing tc->t_sock. + * + * The only functions that set tc->t_sock are rds_tcp_set_callbacks + * and rds_tcp_reset_callbacks. Send and receive trust that + * it is set. The absence of the RDS_CONN_UP bit protects those paths + * from being called while it isn't set. + */ +void rds_tcp_reset_callbacks(struct socket *sock, + struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + struct socket *osock = tc->t_sock; + + if (!osock) + goto newsock; + + /* Need to resolve a duelling SYN between peers. + * We have an outstanding SYN to this peer, which may + * potentially have transitioned to the RDS_CONN_UP state, + * so we must quiesce any send threads before resetting + * cp_transport_data. We quiesce these threads by setting + * cp_state to something other than RDS_CONN_UP, and then + * waiting for any existing threads in rds_send_xmit to + * complete release_in_xmit(). (Subsequent threads entering + * rds_send_xmit() will bail on !rds_conn_up().) + * + * However an incoming syn-ack at this point would end up + * marking the conn as RDS_CONN_UP, and would again permit + * rds_send_xmit() threads through, so ideally we would + * synchronize on RDS_CONN_UP after lock_sock(), but cannot + * do that: waiting on !RDS_IN_XMIT after lock_sock() may + * end up deadlocking with tcp_sendmsg(), and the RDS_IN_XMIT + * would not get set. As a result, we set c_state to + * RDS_CONN_RESETTING, to ensure that rds_tcp_state_change + * cannot mark rds_conn_path_up() in the window before lock_sock(). + */ + atomic_set(&cp->cp_state, RDS_CONN_RESETTING); + wait_event(cp->cp_waitq, !test_bit(RDS_IN_XMIT, &cp->cp_flags)); + /* reset receive side state for rds_tcp_data_recv() for osock */ + cancel_delayed_work_sync(&cp->cp_send_w); + cancel_delayed_work_sync(&cp->cp_recv_w); + lock_sock(osock->sk); + if (tc->t_tinc) { + rds_inc_put(&tc->t_tinc->ti_inc); + tc->t_tinc = NULL; + } + tc->t_tinc_hdr_rem = sizeof(struct rds_header); + tc->t_tinc_data_rem = 0; + rds_tcp_restore_callbacks(osock, tc); + release_sock(osock->sk); + sock_release(osock); +newsock: + rds_send_path_reset(cp); + lock_sock(sock->sk); + rds_tcp_set_callbacks(sock, cp); + release_sock(sock->sk); +} + +/* Add tc to rds_tcp_tc_list and set tc->t_sock.
See comments + * above rds_tcp_reset_callbacks for notes about synchronization + * with data path + */ +void rds_tcp_set_callbacks(struct socket *sock, struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + + rdsdebug("setting sock %p callbacks to tc %p\n", sock, tc); + write_lock_bh(&sock->sk->sk_callback_lock); + + /* done under the callback_lock to serialize with write_space */ + spin_lock(&rds_tcp_tc_list_lock); + list_add_tail(&tc->t_list_item, &rds_tcp_tc_list); +#if IS_ENABLED(CONFIG_IPV6) + rds6_tcp_tc_count++; +#endif + if (!tc->t_cpath->cp_conn->c_isv6) + rds_tcp_tc_count++; + spin_unlock(&rds_tcp_tc_list_lock); + + /* accepted sockets need our listen data ready undone */ + if (sock->sk->sk_data_ready == rds_tcp_listen_data_ready) + sock->sk->sk_data_ready = sock->sk->sk_user_data; + + tc->t_sock = sock; + tc->t_cpath = cp; + tc->t_orig_data_ready = sock->sk->sk_data_ready; + tc->t_orig_write_space = sock->sk->sk_write_space; + tc->t_orig_state_change = sock->sk->sk_state_change; + + sock->sk->sk_user_data = cp; + sock->sk->sk_data_ready = rds_tcp_data_ready; + sock->sk->sk_write_space = rds_tcp_write_space; + sock->sk->sk_state_change = rds_tcp_state_change; + + write_unlock_bh(&sock->sk->sk_callback_lock); +} + +/* Handle RDS_INFO_TCP_SOCKETS socket option. It only returns IPv4 + * connections for backward compatibility. + */ +static void rds_tcp_tc_info(struct socket *rds_sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds_info_tcp_socket tsinfo; + struct rds_tcp_connection *tc; + unsigned long flags; + + spin_lock_irqsave(&rds_tcp_tc_list_lock, flags); + + if (len / sizeof(tsinfo) < rds_tcp_tc_count) + goto out; + + list_for_each_entry(tc, &rds_tcp_tc_list, t_list_item) { + struct inet_sock *inet = inet_sk(tc->t_sock->sk); + + if (tc->t_cpath->cp_conn->c_isv6) + continue; + + tsinfo.local_addr = inet->inet_saddr; + tsinfo.local_port = inet->inet_sport; + tsinfo.peer_addr = inet->inet_daddr; + tsinfo.peer_port = inet->inet_dport; + + tsinfo.hdr_rem = tc->t_tinc_hdr_rem; + tsinfo.data_rem = tc->t_tinc_data_rem; + tsinfo.last_sent_nxt = tc->t_last_sent_nxt; + tsinfo.last_expected_una = tc->t_last_expected_una; + tsinfo.last_seen_una = tc->t_last_seen_una; + tsinfo.tos = tc->t_cpath->cp_conn->c_tos; + + rds_info_copy(iter, &tsinfo, sizeof(tsinfo)); + } + +out: + lens->nr = rds_tcp_tc_count; + lens->each = sizeof(tsinfo); + + spin_unlock_irqrestore(&rds_tcp_tc_list_lock, flags); +} + +#if IS_ENABLED(CONFIG_IPV6) +/* Handle RDS6_INFO_TCP_SOCKETS socket option. It returns both IPv4 and + * IPv6 connections. IPv4 connection address is returned in an IPv4 mapped + * address. 
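+ * Unlike rds_tcp_tc_info(), no c_isv6 filtering is done here;
+ * rds6_tcp_tc_count covers every entry on rds_tcp_tc_list.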
+ */ +static void rds6_tcp_tc_info(struct socket *sock, unsigned int len, + struct rds_info_iterator *iter, + struct rds_info_lengths *lens) +{ + struct rds6_info_tcp_socket tsinfo6; + struct rds_tcp_connection *tc; + unsigned long flags; + + spin_lock_irqsave(&rds_tcp_tc_list_lock, flags); + + if (len / sizeof(tsinfo6) < rds6_tcp_tc_count) + goto out; + + list_for_each_entry(tc, &rds_tcp_tc_list, t_list_item) { + struct sock *sk = tc->t_sock->sk; + struct inet_sock *inet = inet_sk(sk); + + tsinfo6.local_addr = sk->sk_v6_rcv_saddr; + tsinfo6.local_port = inet->inet_sport; + tsinfo6.peer_addr = sk->sk_v6_daddr; + tsinfo6.peer_port = inet->inet_dport; + + tsinfo6.hdr_rem = tc->t_tinc_hdr_rem; + tsinfo6.data_rem = tc->t_tinc_data_rem; + tsinfo6.last_sent_nxt = tc->t_last_sent_nxt; + tsinfo6.last_expected_una = tc->t_last_expected_una; + tsinfo6.last_seen_una = tc->t_last_seen_una; + + rds_info_copy(iter, &tsinfo6, sizeof(tsinfo6)); + } + +out: + lens->nr = rds6_tcp_tc_count; + lens->each = sizeof(tsinfo6); + + spin_unlock_irqrestore(&rds_tcp_tc_list_lock, flags); +} +#endif + +int rds_tcp_laddr_check(struct net *net, const struct in6_addr *addr, + __u32 scope_id) +{ + struct net_device *dev = NULL; +#if IS_ENABLED(CONFIG_IPV6) + int ret; +#endif + + if (ipv6_addr_v4mapped(addr)) { + if (inet_addr_type(net, addr->s6_addr32[3]) == RTN_LOCAL) + return 0; + return -EADDRNOTAVAIL; + } + + /* If the scope_id is specified, check only those addresses + * hosted on the specified interface. + */ + if (scope_id != 0) { + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, scope_id); + /* scope_id is not valid... */ + if (!dev) { + rcu_read_unlock(); + return -EADDRNOTAVAIL; + } + rcu_read_unlock(); + } +#if IS_ENABLED(CONFIG_IPV6) + ret = ipv6_chk_addr(net, addr, dev, 0); + if (ret) + return 0; +#endif + return -EADDRNOTAVAIL; +} + +static void rds_tcp_conn_free(void *arg) +{ + struct rds_tcp_connection *tc = arg; + unsigned long flags; + + rdsdebug("freeing tc %p\n", tc); + + spin_lock_irqsave(&rds_tcp_conn_lock, flags); + if (!tc->t_tcp_node_detached) + list_del(&tc->t_tcp_node); + spin_unlock_irqrestore(&rds_tcp_conn_lock, flags); + + kmem_cache_free(rds_tcp_conn_slab, tc); +} + +static int rds_tcp_conn_alloc(struct rds_connection *conn, gfp_t gfp) +{ + struct rds_tcp_connection *tc; + int i, j; + int ret = 0; + + for (i = 0; i < RDS_MPATH_WORKERS; i++) { + tc = kmem_cache_alloc(rds_tcp_conn_slab, gfp); + if (!tc) { + ret = -ENOMEM; + goto fail; + } + mutex_init(&tc->t_conn_path_lock); + tc->t_sock = NULL; + tc->t_tinc = NULL; + tc->t_tinc_hdr_rem = sizeof(struct rds_header); + tc->t_tinc_data_rem = 0; + + conn->c_path[i].cp_transport_data = tc; + tc->t_cpath = &conn->c_path[i]; + tc->t_tcp_node_detached = true; + + rdsdebug("rds_conn_path [%d] tc %p\n", i, + conn->c_path[i].cp_transport_data); + } + spin_lock_irq(&rds_tcp_conn_lock); + for (i = 0; i < RDS_MPATH_WORKERS; i++) { + tc = conn->c_path[i].cp_transport_data; + tc->t_tcp_node_detached = false; + list_add_tail(&tc->t_tcp_node, &rds_tcp_conn_list); + } + spin_unlock_irq(&rds_tcp_conn_lock); +fail: + if (ret) { + for (j = 0; j < i; j++) + rds_tcp_conn_free(conn->c_path[j].cp_transport_data); + } + return ret; +} + +static bool list_has_conn(struct list_head *list, struct rds_connection *conn) +{ + struct rds_tcp_connection *tc, *_tc; + + list_for_each_entry_safe(tc, _tc, list, t_tcp_node) { + if (tc->t_cpath->cp_conn == conn) + return true; + } + return false; +} + +static void rds_tcp_set_unloading(void) +{ + atomic_set(&rds_tcp_unloading, 1); 
+} + +static bool rds_tcp_is_unloading(struct rds_connection *conn) +{ + return atomic_read(&rds_tcp_unloading) != 0; +} + +static void rds_tcp_destroy_conns(void) +{ + struct rds_tcp_connection *tc, *_tc; + LIST_HEAD(tmp_list); + + /* avoid calling conn_destroy with irqs off */ + spin_lock_irq(&rds_tcp_conn_lock); + list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { + if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn)) + list_move_tail(&tc->t_tcp_node, &tmp_list); + } + spin_unlock_irq(&rds_tcp_conn_lock); + + list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) + rds_conn_destroy(tc->t_cpath->cp_conn); +} + +static void rds_tcp_exit(void); + +static u8 rds_tcp_get_tos_map(u8 tos) +{ + /* all user tos mapped to default 0 for TCP transport */ + return 0; +} + +struct rds_transport rds_tcp_transport = { + .laddr_check = rds_tcp_laddr_check, + .xmit_path_prepare = rds_tcp_xmit_path_prepare, + .xmit_path_complete = rds_tcp_xmit_path_complete, + .xmit = rds_tcp_xmit, + .recv_path = rds_tcp_recv_path, + .conn_alloc = rds_tcp_conn_alloc, + .conn_free = rds_tcp_conn_free, + .conn_path_connect = rds_tcp_conn_path_connect, + .conn_path_shutdown = rds_tcp_conn_path_shutdown, + .inc_copy_to_user = rds_tcp_inc_copy_to_user, + .inc_free = rds_tcp_inc_free, + .stats_info_copy = rds_tcp_stats_info_copy, + .exit = rds_tcp_exit, + .get_tos_map = rds_tcp_get_tos_map, + .t_owner = THIS_MODULE, + .t_name = "tcp", + .t_type = RDS_TRANS_TCP, + .t_prefer_loopback = 1, + .t_mp_capable = 1, + .t_unloading = rds_tcp_is_unloading, +}; + +static unsigned int rds_tcp_netid; + +/* per-network namespace private data for this module */ +struct rds_tcp_net { + struct socket *rds_tcp_listen_sock; + struct work_struct rds_tcp_accept_w; + struct ctl_table_header *rds_tcp_sysctl; + struct ctl_table *ctl_table; + int sndbuf_size; + int rcvbuf_size; +}; + +/* All module specific customizations to the RDS-TCP socket should be done in + * rds_tcp_tune() and applied after socket creation. + */ +bool rds_tcp_tune(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct rds_tcp_net *rtn; + + tcp_sock_set_nodelay(sock->sk); + lock_sock(sk); + /* TCP timer functions might access net namespace even after + * a process which created this net namespace terminated. 
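+ * If the socket does not already hold a counted netns reference
+ * (sk_net_refcnt), take a tracked one here so the namespace stays alive.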
+ */ + if (!sk->sk_net_refcnt) { + if (!maybe_get_net(net)) { + release_sock(sk); + return false; + } + sk->sk_net_refcnt = 1; + netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL); + sock_inuse_add(net, 1); + } + rtn = net_generic(net, rds_tcp_netid); + if (rtn->sndbuf_size > 0) { + sk->sk_sndbuf = rtn->sndbuf_size; + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + } + if (rtn->rcvbuf_size > 0) { + sk->sk_rcvbuf = rtn->rcvbuf_size; + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + } + release_sock(sk); + return true; +} + +static void rds_tcp_accept_worker(struct work_struct *work) +{ + struct rds_tcp_net *rtn = container_of(work, + struct rds_tcp_net, + rds_tcp_accept_w); + + while (rds_tcp_accept_one(rtn->rds_tcp_listen_sock) == 0) + cond_resched(); +} + +void rds_tcp_accept_work(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + + queue_work(rds_wq, &rtn->rds_tcp_accept_w); +} + +static __net_init int rds_tcp_init_net(struct net *net) +{ + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + struct ctl_table *tbl; + int err = 0; + + memset(rtn, 0, sizeof(*rtn)); + + /* {snd, rcv}buf_size default to 0, which implies we let the + * stack pick the value, and permit auto-tuning of buffer size. + */ + if (net == &init_net) { + tbl = rds_tcp_sysctl_table; + } else { + tbl = kmemdup(rds_tcp_sysctl_table, + sizeof(rds_tcp_sysctl_table), GFP_KERNEL); + if (!tbl) { + pr_warn("could not set allocate sysctl table\n"); + return -ENOMEM; + } + rtn->ctl_table = tbl; + } + tbl[RDS_TCP_SNDBUF].data = &rtn->sndbuf_size; + tbl[RDS_TCP_RCVBUF].data = &rtn->rcvbuf_size; + rtn->rds_tcp_sysctl = register_net_sysctl(net, "net/rds/tcp", tbl); + if (!rtn->rds_tcp_sysctl) { + pr_warn("could not register sysctl\n"); + err = -ENOMEM; + goto fail; + } + +#if IS_ENABLED(CONFIG_IPV6) + rtn->rds_tcp_listen_sock = rds_tcp_listen_init(net, true); +#else + rtn->rds_tcp_listen_sock = rds_tcp_listen_init(net, false); +#endif + if (!rtn->rds_tcp_listen_sock) { + pr_warn("could not set up IPv6 listen sock\n"); + +#if IS_ENABLED(CONFIG_IPV6) + /* Try IPv4 as some systems disable IPv6 */ + rtn->rds_tcp_listen_sock = rds_tcp_listen_init(net, false); + if (!rtn->rds_tcp_listen_sock) { +#endif + unregister_net_sysctl_table(rtn->rds_tcp_sysctl); + rtn->rds_tcp_sysctl = NULL; + err = -EAFNOSUPPORT; + goto fail; +#if IS_ENABLED(CONFIG_IPV6) + } +#endif + } + INIT_WORK(&rtn->rds_tcp_accept_w, rds_tcp_accept_worker); + return 0; + +fail: + if (net != &init_net) + kfree(tbl); + return err; +} + +static void rds_tcp_kill_sock(struct net *net) +{ + struct rds_tcp_connection *tc, *_tc; + LIST_HEAD(tmp_list); + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + struct socket *lsock = rtn->rds_tcp_listen_sock; + + rtn->rds_tcp_listen_sock = NULL; + rds_tcp_listen_stop(lsock, &rtn->rds_tcp_accept_w); + spin_lock_irq(&rds_tcp_conn_lock); + list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { + struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net); + + if (net != c_net) + continue; + if (!list_has_conn(&tmp_list, tc->t_cpath->cp_conn)) { + list_move_tail(&tc->t_tcp_node, &tmp_list); + } else { + list_del(&tc->t_tcp_node); + tc->t_tcp_node_detached = true; + } + } + spin_unlock_irq(&rds_tcp_conn_lock); + list_for_each_entry_safe(tc, _tc, &tmp_list, t_tcp_node) + rds_conn_destroy(tc->t_cpath->cp_conn); +} + +static void __net_exit rds_tcp_exit_net(struct net *net) +{ + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + + rds_tcp_kill_sock(net); + 
+ if (rtn->rds_tcp_sysctl) + unregister_net_sysctl_table(rtn->rds_tcp_sysctl); + + if (net != &init_net) + kfree(rtn->ctl_table); +} + +static struct pernet_operations rds_tcp_net_ops = { + .init = rds_tcp_init_net, + .exit = rds_tcp_exit_net, + .id = &rds_tcp_netid, + .size = sizeof(struct rds_tcp_net), +}; + +void *rds_tcp_listen_sock_def_readable(struct net *net) +{ + struct rds_tcp_net *rtn = net_generic(net, rds_tcp_netid); + struct socket *lsock = rtn->rds_tcp_listen_sock; + + if (!lsock) + return NULL; + + return lsock->sk->sk_user_data; +} + +/* when sysctl is used to modify some kernel socket parameters,this + * function resets the RDS connections in that netns so that we can + * restart with new parameters. The assumption is that such reset + * events are few and far-between. + */ +static void rds_tcp_sysctl_reset(struct net *net) +{ + struct rds_tcp_connection *tc, *_tc; + + spin_lock_irq(&rds_tcp_conn_lock); + list_for_each_entry_safe(tc, _tc, &rds_tcp_conn_list, t_tcp_node) { + struct net *c_net = read_pnet(&tc->t_cpath->cp_conn->c_net); + + if (net != c_net || !tc->t_sock) + continue; + + /* reconnect with new parameters */ + rds_conn_path_drop(tc->t_cpath, false); + } + spin_unlock_irq(&rds_tcp_conn_lock); +} + +static int rds_tcp_skbuf_handler(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *fpos) +{ + struct net *net = current->nsproxy->net_ns; + int err; + + err = proc_dointvec_minmax(ctl, write, buffer, lenp, fpos); + if (err < 0) { + pr_warn("Invalid input. Must be >= %d\n", + *(int *)(ctl->extra1)); + return err; + } + if (write) + rds_tcp_sysctl_reset(net); + return 0; +} + +static void rds_tcp_exit(void) +{ + rds_tcp_set_unloading(); + synchronize_rcu(); + rds_info_deregister_func(RDS_INFO_TCP_SOCKETS, rds_tcp_tc_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_deregister_func(RDS6_INFO_TCP_SOCKETS, rds6_tcp_tc_info); +#endif + unregister_pernet_device(&rds_tcp_net_ops); + rds_tcp_destroy_conns(); + rds_trans_unregister(&rds_tcp_transport); + rds_tcp_recv_exit(); + kmem_cache_destroy(rds_tcp_conn_slab); +} +module_exit(rds_tcp_exit); + +static int __init rds_tcp_init(void) +{ + int ret; + + rds_tcp_conn_slab = kmem_cache_create("rds_tcp_connection", + sizeof(struct rds_tcp_connection), + 0, 0, NULL); + if (!rds_tcp_conn_slab) { + ret = -ENOMEM; + goto out; + } + + ret = rds_tcp_recv_init(); + if (ret) + goto out_slab; + + ret = register_pernet_device(&rds_tcp_net_ops); + if (ret) + goto out_recv; + + rds_trans_register(&rds_tcp_transport); + + rds_info_register_func(RDS_INFO_TCP_SOCKETS, rds_tcp_tc_info); +#if IS_ENABLED(CONFIG_IPV6) + rds_info_register_func(RDS6_INFO_TCP_SOCKETS, rds6_tcp_tc_info); +#endif + + goto out; +out_recv: + rds_tcp_recv_exit(); +out_slab: + kmem_cache_destroy(rds_tcp_conn_slab); +out: + return ret; +} +module_init(rds_tcp_init); + +MODULE_AUTHOR("Oracle Corporation "); +MODULE_DESCRIPTION("RDS: TCP transport"); +MODULE_LICENSE("Dual BSD/GPL"); diff --git a/net/rds/tcp.h b/net/rds/tcp.h new file mode 100644 index 000000000..f8b5930d7 --- /dev/null +++ b/net/rds/tcp.h @@ -0,0 +1,98 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _RDS_TCP_H +#define _RDS_TCP_H + +#define RDS_TCP_PORT 16385 + +struct rds_tcp_incoming { + struct rds_incoming ti_inc; + struct sk_buff_head ti_skb_list; +}; + +struct rds_tcp_connection { + + struct list_head t_tcp_node; + bool t_tcp_node_detached; + struct rds_conn_path *t_cpath; + /* t_conn_path_lock synchronizes the connection establishment between + * rds_tcp_accept_one and 
rds_tcp_conn_path_connect + */ + struct mutex t_conn_path_lock; + struct socket *t_sock; + void *t_orig_write_space; + void *t_orig_data_ready; + void *t_orig_state_change; + + struct rds_tcp_incoming *t_tinc; + size_t t_tinc_hdr_rem; + size_t t_tinc_data_rem; + + /* XXX error report? */ + struct work_struct t_conn_w; + struct work_struct t_send_w; + struct work_struct t_down_w; + struct work_struct t_recv_w; + + /* for info exporting only */ + struct list_head t_list_item; + u32 t_last_sent_nxt; + u32 t_last_expected_una; + u32 t_last_seen_una; +}; + +struct rds_tcp_statistics { + uint64_t s_tcp_data_ready_calls; + uint64_t s_tcp_write_space_calls; + uint64_t s_tcp_sndbuf_full; + uint64_t s_tcp_connect_raced; + uint64_t s_tcp_listen_closed_stale; +}; + +/* tcp.c */ +bool rds_tcp_tune(struct socket *sock); +void rds_tcp_set_callbacks(struct socket *sock, struct rds_conn_path *cp); +void rds_tcp_reset_callbacks(struct socket *sock, struct rds_conn_path *cp); +void rds_tcp_restore_callbacks(struct socket *sock, + struct rds_tcp_connection *tc); +u32 rds_tcp_write_seq(struct rds_tcp_connection *tc); +u32 rds_tcp_snd_una(struct rds_tcp_connection *tc); +u64 rds_tcp_map_seq(struct rds_tcp_connection *tc, u32 seq); +extern struct rds_transport rds_tcp_transport; +void rds_tcp_accept_work(struct sock *sk); +int rds_tcp_laddr_check(struct net *net, const struct in6_addr *addr, + __u32 scope_id); +/* tcp_connect.c */ +int rds_tcp_conn_path_connect(struct rds_conn_path *cp); +void rds_tcp_conn_path_shutdown(struct rds_conn_path *conn); +void rds_tcp_state_change(struct sock *sk); + +/* tcp_listen.c */ +struct socket *rds_tcp_listen_init(struct net *net, bool isv6); +void rds_tcp_listen_stop(struct socket *sock, struct work_struct *acceptor); +void rds_tcp_listen_data_ready(struct sock *sk); +int rds_tcp_accept_one(struct socket *sock); +void rds_tcp_keepalive(struct socket *sock); +void *rds_tcp_listen_sock_def_readable(struct net *net); + +/* tcp_recv.c */ +int rds_tcp_recv_init(void); +void rds_tcp_recv_exit(void); +void rds_tcp_data_ready(struct sock *sk); +int rds_tcp_recv_path(struct rds_conn_path *cp); +void rds_tcp_inc_free(struct rds_incoming *inc); +int rds_tcp_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to); + +/* tcp_send.c */ +void rds_tcp_xmit_path_prepare(struct rds_conn_path *cp); +void rds_tcp_xmit_path_complete(struct rds_conn_path *cp); +int rds_tcp_xmit(struct rds_connection *conn, struct rds_message *rm, + unsigned int hdr_off, unsigned int sg, unsigned int off); +void rds_tcp_write_space(struct sock *sk); + +/* tcp_stats.c */ +DECLARE_PER_CPU(struct rds_tcp_statistics, rds_tcp_stats); +#define rds_tcp_stats_inc(member) rds_stats_inc_which(rds_tcp_stats, member) +unsigned int rds_tcp_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail); + +#endif diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c new file mode 100644 index 000000000..a0046e99d --- /dev/null +++ b/net/rds/tcp_connect.c @@ -0,0 +1,229 @@ +/* + * Copyright (c) 2006, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "rds.h" +#include "tcp.h" + +void rds_tcp_state_change(struct sock *sk) +{ + void (*state_change)(struct sock *sk); + struct rds_conn_path *cp; + struct rds_tcp_connection *tc; + + read_lock_bh(&sk->sk_callback_lock); + cp = sk->sk_user_data; + if (!cp) { + state_change = sk->sk_state_change; + goto out; + } + tc = cp->cp_transport_data; + state_change = tc->t_orig_state_change; + + rdsdebug("sock %p state_change to %d\n", tc->t_sock, sk->sk_state); + + switch (sk->sk_state) { + /* ignore connecting sockets as they make progress */ + case TCP_SYN_SENT: + case TCP_SYN_RECV: + break; + case TCP_ESTABLISHED: + /* Force the peer to reconnect so that we have the + * TCP ports going from . to + * .. We avoid marking the + * RDS connection as RDS_CONN_UP until the reconnect, + * to avoid RDS datagram loss. + */ + if (rds_addr_cmp(&cp->cp_conn->c_laddr, + &cp->cp_conn->c_faddr) >= 0 && + rds_conn_path_transition(cp, RDS_CONN_CONNECTING, + RDS_CONN_ERROR)) { + rds_conn_path_drop(cp, false); + } else { + rds_connect_path_complete(cp, RDS_CONN_CONNECTING); + } + break; + case TCP_CLOSE_WAIT: + case TCP_CLOSE: + rds_conn_path_drop(cp, false); + break; + default: + break; + } +out: + read_unlock_bh(&sk->sk_callback_lock); + state_change(sk); +} + +int rds_tcp_conn_path_connect(struct rds_conn_path *cp) +{ + struct socket *sock = NULL; + struct sockaddr_in6 sin6; + struct sockaddr_in sin; + struct sockaddr *addr; + int addrlen; + bool isv6; + int ret; + struct rds_connection *conn = cp->cp_conn; + struct rds_tcp_connection *tc = cp->cp_transport_data; + + /* for multipath rds,we only trigger the connection after + * the handshake probe has determined the number of paths. 
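[Editorial illustration, not part of the patch.] The TCP_ESTABLISHED case in rds_tcp_state_change() above compares the local and peer addresses with rds_addr_cmp() and, when the local address is not the smaller one, drops the path so the connection is re-established in the canonical direction (smaller address connecting to larger). As a rough user-space sketch of that ordering rule (helper name and sample addresses are invented), network-byte-order IPv6 addresses can be ranked with a plain memcmp():

#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdio.h>
#include <string.h>

/* network byte order means byte-wise comparison equals numeric comparison */
static int addr6_cmp(const struct in6_addr *a, const struct in6_addr *b)
{
    return memcmp(a, b, sizeof(*a));
}

int main(void)
{
    struct in6_addr local, peer;

    inet_pton(AF_INET6, "::ffff:10.0.0.1", &local);
    inet_pton(AF_INET6, "::ffff:10.0.0.2", &peer);

    if (addr6_cmp(&local, &peer) < 0)
        printf("local address is smaller: this side may initiate\n");
    else
        printf("peer is smaller: wait for the peer to (re)connect\n");
    return 0;
}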
+ */ + if (cp->cp_index > 0 && cp->cp_conn->c_npaths < 2) + return -EAGAIN; + + mutex_lock(&tc->t_conn_path_lock); + + if (rds_conn_path_up(cp)) { + mutex_unlock(&tc->t_conn_path_lock); + return 0; + } + if (ipv6_addr_v4mapped(&conn->c_laddr)) { + ret = sock_create_kern(rds_conn_net(conn), PF_INET, + SOCK_STREAM, IPPROTO_TCP, &sock); + isv6 = false; + } else { + ret = sock_create_kern(rds_conn_net(conn), PF_INET6, + SOCK_STREAM, IPPROTO_TCP, &sock); + isv6 = true; + } + + if (ret < 0) + goto out; + + if (!rds_tcp_tune(sock)) { + ret = -EINVAL; + goto out; + } + + if (isv6) { + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = conn->c_laddr; + sin6.sin6_port = 0; + sin6.sin6_flowinfo = 0; + sin6.sin6_scope_id = conn->c_dev_if; + addr = (struct sockaddr *)&sin6; + addrlen = sizeof(sin6); + } else { + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = conn->c_laddr.s6_addr32[3]; + sin.sin_port = 0; + addr = (struct sockaddr *)&sin; + addrlen = sizeof(sin); + } + + ret = kernel_bind(sock, addr, addrlen); + if (ret) { + rdsdebug("bind failed with %d at address %pI6c\n", + ret, &conn->c_laddr); + goto out; + } + + if (isv6) { + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = conn->c_faddr; + sin6.sin6_port = htons(RDS_TCP_PORT); + sin6.sin6_flowinfo = 0; + sin6.sin6_scope_id = conn->c_dev_if; + addr = (struct sockaddr *)&sin6; + addrlen = sizeof(sin6); + } else { + sin.sin_family = AF_INET; + sin.sin_addr.s_addr = conn->c_faddr.s6_addr32[3]; + sin.sin_port = htons(RDS_TCP_PORT); + addr = (struct sockaddr *)&sin; + addrlen = sizeof(sin); + } + + /* + * once we call connect() we can start getting callbacks and they + * own the socket + */ + rds_tcp_set_callbacks(sock, cp); + ret = kernel_connect(sock, addr, addrlen, O_NONBLOCK); + + rdsdebug("connect to address %pI6c returned %d\n", &conn->c_faddr, ret); + if (ret == -EINPROGRESS) + ret = 0; + if (ret == 0) { + rds_tcp_keepalive(sock); + sock = NULL; + } else { + rds_tcp_restore_callbacks(sock, cp->cp_transport_data); + } + +out: + mutex_unlock(&tc->t_conn_path_lock); + if (sock) + sock_release(sock); + return ret; +} + +/* + * Before killing the tcp socket this needs to serialize with callbacks. The + * caller has already grabbed the sending sem so we're serialized with other + * senders. + * + * TCP calls the callbacks with the sock lock so we hold it while we reset the + * callbacks to those set by TCP. Our callbacks won't execute again once we + * hold the sock lock. + */ +void rds_tcp_conn_path_shutdown(struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + struct socket *sock = tc->t_sock; + + rdsdebug("shutting down conn %p tc %p sock %p\n", + cp->cp_conn, tc, sock); + + if (sock) { + if (rds_destroy_pending(cp->cp_conn)) + sock_no_linger(sock->sk); + sock->ops->shutdown(sock, RCV_SHUTDOWN | SEND_SHUTDOWN); + lock_sock(sock->sk); + rds_tcp_restore_callbacks(sock, tc); /* tc->tc_sock = NULL */ + + release_sock(sock->sk); + sock_release(sock); + } + + if (tc->t_tinc) { + rds_inc_put(&tc->t_tinc->ti_inc); + tc->t_tinc = NULL; + } + tc->t_tinc_hdr_rem = sizeof(struct rds_header); + tc->t_tinc_data_rem = 0; +} diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c new file mode 100644 index 000000000..b576bd252 --- /dev/null +++ b/net/rds/tcp_listen.c @@ -0,0 +1,348 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
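[Editorial illustration, not part of the patch.] rds_tcp_conn_path_connect() above hands the socket to kernel_connect() with O_NONBLOCK and treats -EINPROGRESS as success, leaving the state-change callback to finish the handshake. A minimal user-space sketch of the same non-blocking connect pattern (the peer is a documentation address; the 16385 port mirrors RDS_TCP_PORT):

#include <arpa/inet.h>
#include <errno.h>
#include <netinet/in.h>
#include <poll.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
    struct sockaddr_in peer = { .sin_family = AF_INET,
                                .sin_port = htons(16385) };
    struct pollfd pfd;
    socklen_t len = sizeof(int);
    int fd, rc, err = 0;

    inet_pton(AF_INET, "192.0.2.1", &peer.sin_addr);    /* example peer */

    fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
    if (fd < 0)
        return 1;

    if (connect(fd, (struct sockaddr *)&peer, sizeof(peer)) < 0 &&
        errno != EINPROGRESS) {
        perror("connect");              /* immediate, fatal failure */
        close(fd);
        return 1;
    }

    /* EINPROGRESS: the handshake continues asynchronously */
    pfd.fd = fd;
    pfd.events = POLLOUT;
    rc = poll(&pfd, 1, 5000);
    if (rc == 1)
        getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &len);
    else
        err = rc ? errno : ETIMEDOUT;

    printf("connect result: %s\n", err ? strerror(err) : "established");
    close(fd);
    return 0;
}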
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include +#include + +#include "rds.h" +#include "tcp.h" + +void rds_tcp_keepalive(struct socket *sock) +{ + /* values below based on xs_udp_default_timeout */ + int keepidle = 5; /* send a probe 'keepidle' secs after last data */ + int keepcnt = 5; /* number of unack'ed probes before declaring dead */ + + sock_set_keepalive(sock->sk); + tcp_sock_set_keepcnt(sock->sk, keepcnt); + tcp_sock_set_keepidle(sock->sk, keepidle); + /* KEEPINTVL is the interval between successive probes. We follow + * the model in xs_tcp_finish_connecting() and re-use keepidle. + */ + tcp_sock_set_keepintvl(sock->sk, keepidle); +} + +/* rds_tcp_accept_one_path(): if accepting on cp_index > 0, make sure the + * client's ipaddr < server's ipaddr. Otherwise, close the accepted + * socket and force a reconneect from smaller -> larger ip addr. The reason + * we special case cp_index 0 is to allow the rds probe ping itself to itself + * get through efficiently. + * Since reconnects are only initiated from the node with the numerically + * smaller ip address, we recycle conns in RDS_CONN_ERROR on the passive side + * by moving them to CONNECTING in this function. + */ +static +struct rds_tcp_connection *rds_tcp_accept_one_path(struct rds_connection *conn) +{ + int i; + int npaths = max_t(int, 1, conn->c_npaths); + + /* for mprds, all paths MUST be initiated by the peer + * with the smaller address. + */ + if (rds_addr_cmp(&conn->c_faddr, &conn->c_laddr) >= 0) { + /* Make sure we initiate at least one path if this + * has not already been done; rds_start_mprds() will + * take care of additional paths, if necessary. 
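[Editorial illustration, not part of the patch.] rds_tcp_keepalive() above enables keepalive with a 5-second idle time, five probes and a 5-second probe interval through the in-kernel tcp_sock_set_*() helpers. Purely for illustration, the same configuration from user space uses the corresponding setsockopt() options:

#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

/* user-space counterpart of the keepalive tuning done above */
static int set_keepalive(int fd, int idle, int cnt, int intvl)
{
    int on = 1;

    if (setsockopt(fd, SOL_SOCKET, SO_KEEPALIVE, &on, sizeof(on)) ||
        setsockopt(fd, IPPROTO_TCP, TCP_KEEPIDLE, &idle, sizeof(idle)) ||
        setsockopt(fd, IPPROTO_TCP, TCP_KEEPCNT, &cnt, sizeof(cnt)) ||
        setsockopt(fd, IPPROTO_TCP, TCP_KEEPINTVL, &intvl, sizeof(intvl)))
        return -1;
    return 0;
}

int main(void)
{
    int fd = socket(AF_INET, SOCK_STREAM, 0);

    if (fd < 0 || set_keepalive(fd, 5, 5, 5))
        perror("keepalive setup");
    else
        printf("keepalive: 5s idle, 5 probes, 5s apart\n");
    if (fd >= 0)
        close(fd);
    return 0;
}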
+ */ + if (npaths == 1) + rds_conn_path_connect_if_down(&conn->c_path[0]); + return NULL; + } + + for (i = 0; i < npaths; i++) { + struct rds_conn_path *cp = &conn->c_path[i]; + + if (rds_conn_path_transition(cp, RDS_CONN_DOWN, + RDS_CONN_CONNECTING) || + rds_conn_path_transition(cp, RDS_CONN_ERROR, + RDS_CONN_CONNECTING)) { + return cp->cp_transport_data; + } + } + return NULL; +} + +int rds_tcp_accept_one(struct socket *sock) +{ + struct socket *new_sock = NULL; + struct rds_connection *conn; + int ret; + struct inet_sock *inet; + struct rds_tcp_connection *rs_tcp = NULL; + int conn_state; + struct rds_conn_path *cp; + struct in6_addr *my_addr, *peer_addr; +#if !IS_ENABLED(CONFIG_IPV6) + struct in6_addr saddr, daddr; +#endif + int dev_if = 0; + + if (!sock) /* module unload or netns delete in progress */ + return -ENETUNREACH; + + ret = sock_create_lite(sock->sk->sk_family, + sock->sk->sk_type, sock->sk->sk_protocol, + &new_sock); + if (ret) + goto out; + + ret = sock->ops->accept(sock, new_sock, O_NONBLOCK, true); + if (ret < 0) + goto out; + + /* sock_create_lite() does not get a hold on the owner module so we + * need to do it here. Note that sock_release() uses sock->ops to + * determine if it needs to decrement the reference count. So set + * sock->ops after calling accept() in case that fails. And there's + * no need to do try_module_get() as the listener should have a hold + * already. + */ + new_sock->ops = sock->ops; + __module_get(new_sock->ops->owner); + + rds_tcp_keepalive(new_sock); + if (!rds_tcp_tune(new_sock)) { + ret = -EINVAL; + goto out; + } + + inet = inet_sk(new_sock->sk); + +#if IS_ENABLED(CONFIG_IPV6) + my_addr = &new_sock->sk->sk_v6_rcv_saddr; + peer_addr = &new_sock->sk->sk_v6_daddr; +#else + ipv6_addr_set_v4mapped(inet->inet_saddr, &saddr); + ipv6_addr_set_v4mapped(inet->inet_daddr, &daddr); + my_addr = &saddr; + peer_addr = &daddr; +#endif + rdsdebug("accepted family %d tcp %pI6c:%u -> %pI6c:%u\n", + sock->sk->sk_family, + my_addr, ntohs(inet->inet_sport), + peer_addr, ntohs(inet->inet_dport)); + +#if IS_ENABLED(CONFIG_IPV6) + /* sk_bound_dev_if is not set if the peer address is not link local + * address. In this case, it happens that mcast_oif is set. So + * just use it. + */ + if ((ipv6_addr_type(my_addr) & IPV6_ADDR_LINKLOCAL) && + !(ipv6_addr_type(peer_addr) & IPV6_ADDR_LINKLOCAL)) { + struct ipv6_pinfo *inet6; + + inet6 = inet6_sk(new_sock->sk); + dev_if = inet6->mcast_oif; + } else { + dev_if = new_sock->sk->sk_bound_dev_if; + } +#endif + + if (!rds_tcp_laddr_check(sock_net(sock->sk), peer_addr, dev_if)) { + /* local address connection is only allowed via loopback */ + ret = -EOPNOTSUPP; + goto out; + } + + conn = rds_conn_create(sock_net(sock->sk), + my_addr, peer_addr, + &rds_tcp_transport, 0, GFP_KERNEL, dev_if); + + if (IS_ERR(conn)) { + ret = PTR_ERR(conn); + goto out; + } + /* An incoming SYN request came in, and TCP just accepted it. + * + * If the client reboots, this conn will need to be cleaned up. 
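[Editorial illustration, not part of the patch.] rds_tcp_accept_one() above accepts a single pending connection with a non-blocking accept, and the accept worker shown earlier in tcp.c keeps calling it until it fails, which drains the listen backlog. A compressed user-space sketch of that listen-and-drain shape (the port and backlog values are arbitrary):

#define _GNU_SOURCE                     /* for accept4() */
#include <errno.h>
#include <netinet/in.h>
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
    struct sockaddr_in addr = { .sin_family = AF_INET,
                                .sin_addr.s_addr = htonl(INADDR_ANY),
                                .sin_port = htons(16385) };
    int one = 1, lfd;

    lfd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
    if (lfd < 0)
        return 1;
    setsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one));
    if (bind(lfd, (struct sockaddr *)&addr, sizeof(addr)) || listen(lfd, 64)) {
        perror("listen setup");
        close(lfd);
        return 1;
    }

    /* drain everything currently queued, then stop (a worker would requeue) */
    for (;;) {
        int fd = accept4(lfd, NULL, NULL, SOCK_NONBLOCK);

        if (fd < 0) {
            if (errno != EAGAIN && errno != EWOULDBLOCK)
                perror("accept4");
            break;                      /* backlog empty (or a real error) */
        }
        printf("accepted fd %d\n", fd);
        close(fd);                      /* a real server hands the fd off */
    }
    close(lfd);
    return 0;
}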
+ * rds_tcp_state_change() will do that cleanup + */ + rs_tcp = rds_tcp_accept_one_path(conn); + if (!rs_tcp) + goto rst_nsk; + mutex_lock(&rs_tcp->t_conn_path_lock); + cp = rs_tcp->t_cpath; + conn_state = rds_conn_path_state(cp); + WARN_ON(conn_state == RDS_CONN_UP); + if (conn_state != RDS_CONN_CONNECTING && conn_state != RDS_CONN_ERROR) + goto rst_nsk; + if (rs_tcp->t_sock) { + /* Duelling SYN has been handled in rds_tcp_accept_one() */ + rds_tcp_reset_callbacks(new_sock, cp); + /* rds_connect_path_complete() marks RDS_CONN_UP */ + rds_connect_path_complete(cp, RDS_CONN_RESETTING); + } else { + rds_tcp_set_callbacks(new_sock, cp); + rds_connect_path_complete(cp, RDS_CONN_CONNECTING); + } + new_sock = NULL; + ret = 0; + if (conn->c_npaths == 0) + rds_send_ping(cp->cp_conn, cp->cp_index); + goto out; +rst_nsk: + /* reset the newly returned accept sock and bail. + * It is safe to set linger on new_sock because the RDS connection + * has not been brought up on new_sock, so no RDS-level data could + * be pending on it. By setting linger, we achieve the side-effect + * of avoiding TIME_WAIT state on new_sock. + */ + sock_no_linger(new_sock->sk); + kernel_sock_shutdown(new_sock, SHUT_RDWR); + ret = 0; +out: + if (rs_tcp) + mutex_unlock(&rs_tcp->t_conn_path_lock); + if (new_sock) + sock_release(new_sock); + return ret; +} + +void rds_tcp_listen_data_ready(struct sock *sk) +{ + void (*ready)(struct sock *sk); + + rdsdebug("listen data ready sk %p\n", sk); + + read_lock_bh(&sk->sk_callback_lock); + ready = sk->sk_user_data; + if (!ready) { /* check for teardown race */ + ready = sk->sk_data_ready; + goto out; + } + + /* + * ->sk_data_ready is also called for a newly established child socket + * before it has been accepted and the accepter has set up their + * data_ready.. we only want to queue listen work for our listening + * socket + * + * (*ready)() may be null if we are racing with netns delete, and + * the listen socket is being torn down. + */ + if (sk->sk_state == TCP_LISTEN) + rds_tcp_accept_work(sk); + else + ready = rds_tcp_listen_sock_def_readable(sock_net(sk)); + +out: + read_unlock_bh(&sk->sk_callback_lock); + if (ready) + ready(sk); +} + +struct socket *rds_tcp_listen_init(struct net *net, bool isv6) +{ + struct socket *sock = NULL; + struct sockaddr_storage ss; + struct sockaddr_in6 *sin6; + struct sockaddr_in *sin; + int addr_len; + int ret; + + ret = sock_create_kern(net, isv6 ? PF_INET6 : PF_INET, SOCK_STREAM, + IPPROTO_TCP, &sock); + if (ret < 0) { + rdsdebug("could not create %s listener socket: %d\n", + isv6 ? "IPv6" : "IPv4", ret); + goto out; + } + + sock->sk->sk_reuse = SK_CAN_REUSE; + tcp_sock_set_nodelay(sock->sk); + + write_lock_bh(&sock->sk->sk_callback_lock); + sock->sk->sk_user_data = sock->sk->sk_data_ready; + sock->sk->sk_data_ready = rds_tcp_listen_data_ready; + write_unlock_bh(&sock->sk->sk_callback_lock); + + if (isv6) { + sin6 = (struct sockaddr_in6 *)&ss; + sin6->sin6_family = PF_INET6; + sin6->sin6_addr = in6addr_any; + sin6->sin6_port = (__force u16)htons(RDS_TCP_PORT); + sin6->sin6_scope_id = 0; + sin6->sin6_flowinfo = 0; + addr_len = sizeof(*sin6); + } else { + sin = (struct sockaddr_in *)&ss; + sin->sin_family = PF_INET; + sin->sin_addr.s_addr = INADDR_ANY; + sin->sin_port = (__force u16)htons(RDS_TCP_PORT); + addr_len = sizeof(*sin); + } + + ret = kernel_bind(sock, (struct sockaddr *)&ss, addr_len); + if (ret < 0) { + rdsdebug("could not bind %s listener socket: %d\n", + isv6 ? 
"IPv6" : "IPv4", ret); + goto out; + } + + ret = sock->ops->listen(sock, 64); + if (ret < 0) + goto out; + + return sock; +out: + if (sock) + sock_release(sock); + return NULL; +} + +void rds_tcp_listen_stop(struct socket *sock, struct work_struct *acceptor) +{ + struct sock *sk; + + if (!sock) + return; + + sk = sock->sk; + + /* serialize with and prevent further callbacks */ + lock_sock(sk); + write_lock_bh(&sk->sk_callback_lock); + if (sk->sk_user_data) { + sk->sk_data_ready = sk->sk_user_data; + sk->sk_user_data = NULL; + } + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + + /* wait for accepts to stop and close the socket */ + flush_workqueue(rds_wq); + flush_work(acceptor); + sock_release(sock); +} diff --git a/net/rds/tcp_recv.c b/net/rds/tcp_recv.c new file mode 100644 index 000000000..f4ee13da9 --- /dev/null +++ b/net/rds/tcp_recv.c @@ -0,0 +1,349 @@ +/* + * Copyright (c) 2006, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "rds.h" +#include "tcp.h" + +static struct kmem_cache *rds_tcp_incoming_slab; + +static void rds_tcp_inc_purge(struct rds_incoming *inc) +{ + struct rds_tcp_incoming *tinc; + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + rdsdebug("purging tinc %p inc %p\n", tinc, inc); + skb_queue_purge(&tinc->ti_skb_list); +} + +void rds_tcp_inc_free(struct rds_incoming *inc) +{ + struct rds_tcp_incoming *tinc; + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + rds_tcp_inc_purge(inc); + rdsdebug("freeing tinc %p inc %p\n", tinc, inc); + kmem_cache_free(rds_tcp_incoming_slab, tinc); +} + +/* + * this is pretty lame, but, whatever. 
+ */ +int rds_tcp_inc_copy_to_user(struct rds_incoming *inc, struct iov_iter *to) +{ + struct rds_tcp_incoming *tinc; + struct sk_buff *skb; + int ret = 0; + + if (!iov_iter_count(to)) + goto out; + + tinc = container_of(inc, struct rds_tcp_incoming, ti_inc); + + skb_queue_walk(&tinc->ti_skb_list, skb) { + unsigned long to_copy, skb_off; + for (skb_off = 0; skb_off < skb->len; skb_off += to_copy) { + to_copy = iov_iter_count(to); + to_copy = min(to_copy, skb->len - skb_off); + + if (skb_copy_datagram_iter(skb, skb_off, to, to_copy)) + return -EFAULT; + + rds_stats_add(s_copy_to_user, to_copy); + ret += to_copy; + + if (!iov_iter_count(to)) + goto out; + } + } +out: + return ret; +} + +/* + * We have a series of skbs that have fragmented pieces of the congestion + * bitmap. They must add up to the exact size of the congestion bitmap. We + * use the skb helpers to copy those into the pages that make up the in-memory + * congestion bitmap for the remote address of this connection. We then tell + * the congestion core that the bitmap has been changed so that it can wake up + * sleepers. + * + * This is racing with sending paths which are using test_bit to see if the + * bitmap indicates that their recipient is congested. + */ + +static void rds_tcp_cong_recv(struct rds_connection *conn, + struct rds_tcp_incoming *tinc) +{ + struct sk_buff *skb; + unsigned int to_copy, skb_off; + unsigned int map_off; + unsigned int map_page; + struct rds_cong_map *map; + int ret; + + /* catch completely corrupt packets */ + if (be32_to_cpu(tinc->ti_inc.i_hdr.h_len) != RDS_CONG_MAP_BYTES) + return; + + map_page = 0; + map_off = 0; + map = conn->c_fcong; + + skb_queue_walk(&tinc->ti_skb_list, skb) { + skb_off = 0; + while (skb_off < skb->len) { + to_copy = min_t(unsigned int, PAGE_SIZE - map_off, + skb->len - skb_off); + + BUG_ON(map_page >= RDS_CONG_MAP_PAGES); + + /* only returns 0 or -error */ + ret = skb_copy_bits(skb, skb_off, + (void *)map->m_page_addrs[map_page] + map_off, + to_copy); + BUG_ON(ret != 0); + + skb_off += to_copy; + map_off += to_copy; + if (map_off == PAGE_SIZE) { + map_off = 0; + map_page++; + } + } + } + + rds_cong_map_updated(map, ~(u64) 0); +} + +struct rds_tcp_desc_arg { + struct rds_conn_path *conn_path; + gfp_t gfp; +}; + +static int rds_tcp_data_recv(read_descriptor_t *desc, struct sk_buff *skb, + unsigned int offset, size_t len) +{ + struct rds_tcp_desc_arg *arg = desc->arg.data; + struct rds_conn_path *cp = arg->conn_path; + struct rds_tcp_connection *tc = cp->cp_transport_data; + struct rds_tcp_incoming *tinc = tc->t_tinc; + struct sk_buff *clone; + size_t left = len, to_copy; + + rdsdebug("tcp data tc %p skb %p offset %u len %zu\n", tc, skb, offset, + len); + + /* + * tcp_read_sock() interprets partial progress as an indication to stop + * processing. + */ + while (left) { + if (!tinc) { + tinc = kmem_cache_alloc(rds_tcp_incoming_slab, + arg->gfp); + if (!tinc) { + desc->error = -ENOMEM; + goto out; + } + tc->t_tinc = tinc; + rdsdebug("allocated tinc %p\n", tinc); + rds_inc_path_init(&tinc->ti_inc, cp, + &cp->cp_conn->c_faddr); + tinc->ti_inc.i_rx_lat_trace[RDS_MSG_RX_HDR] = + local_clock(); + + /* + * XXX * we might be able to use the __ variants when + * we've already serialized at a higher level. 
+ */ + skb_queue_head_init(&tinc->ti_skb_list); + } + + if (left && tc->t_tinc_hdr_rem) { + to_copy = min(tc->t_tinc_hdr_rem, left); + rdsdebug("copying %zu header from skb %p\n", to_copy, + skb); + skb_copy_bits(skb, offset, + (char *)&tinc->ti_inc.i_hdr + + sizeof(struct rds_header) - + tc->t_tinc_hdr_rem, + to_copy); + tc->t_tinc_hdr_rem -= to_copy; + left -= to_copy; + offset += to_copy; + + if (tc->t_tinc_hdr_rem == 0) { + /* could be 0 for a 0 len message */ + tc->t_tinc_data_rem = + be32_to_cpu(tinc->ti_inc.i_hdr.h_len); + tinc->ti_inc.i_rx_lat_trace[RDS_MSG_RX_START] = + local_clock(); + } + } + + if (left && tc->t_tinc_data_rem) { + to_copy = min(tc->t_tinc_data_rem, left); + + clone = pskb_extract(skb, offset, to_copy, arg->gfp); + if (!clone) { + desc->error = -ENOMEM; + goto out; + } + + skb_queue_tail(&tinc->ti_skb_list, clone); + + rdsdebug("skb %p data %p len %d off %u to_copy %zu -> " + "clone %p data %p len %d\n", + skb, skb->data, skb->len, offset, to_copy, + clone, clone->data, clone->len); + + tc->t_tinc_data_rem -= to_copy; + left -= to_copy; + offset += to_copy; + } + + if (tc->t_tinc_hdr_rem == 0 && tc->t_tinc_data_rem == 0) { + struct rds_connection *conn = cp->cp_conn; + + if (tinc->ti_inc.i_hdr.h_flags == RDS_FLAG_CONG_BITMAP) + rds_tcp_cong_recv(conn, tinc); + else + rds_recv_incoming(conn, &conn->c_faddr, + &conn->c_laddr, + &tinc->ti_inc, + arg->gfp); + + tc->t_tinc_hdr_rem = sizeof(struct rds_header); + tc->t_tinc_data_rem = 0; + tc->t_tinc = NULL; + rds_inc_put(&tinc->ti_inc); + tinc = NULL; + } + } +out: + rdsdebug("returning len %zu left %zu skb len %d rx queue depth %d\n", + len, left, skb->len, + skb_queue_len(&tc->t_sock->sk->sk_receive_queue)); + return len - left; +} + +/* the caller has to hold the sock lock */ +static int rds_tcp_read_sock(struct rds_conn_path *cp, gfp_t gfp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + struct socket *sock = tc->t_sock; + read_descriptor_t desc; + struct rds_tcp_desc_arg arg; + + /* It's like glib in the kernel! */ + arg.conn_path = cp; + arg.gfp = gfp; + desc.arg.data = &arg; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + tcp_read_sock(sock->sk, &desc, rds_tcp_data_recv); + rdsdebug("tcp_read_sock for tc %p gfp 0x%x returned %d\n", tc, gfp, + desc.error); + + return desc.error; +} + +/* + * We hold the sock lock to serialize our rds_tcp_recv->tcp_read_sock from + * data_ready. + * + * if we fail to allocate we're in trouble.. blindly wait some time before + * trying again to see if the VM can free up something for us. 
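[Editorial illustration, not part of the patch.] rds_tcp_data_recv() above reassembles one RDS message at a time from the byte stream: it first fills the fixed-size header (tracking t_tinc_hdr_rem), takes the payload length from the header, then collects t_tinc_data_rem bytes of data, carrying both counters across skbs and calls. The same state machine over ordinary buffers, as a sketch with an invented 4-byte length-prefixed framing:

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* carry "header bytes missing" / "data bytes missing" across reads, the
 * way the RDS code carries t_tinc_hdr_rem / t_tinc_data_rem across skbs */
struct reasm {
    uint8_t hdr[4];                 /* toy header: big-endian payload length */
    size_t hdr_rem;
    size_t data_rem;
    uint8_t payload[65536];
    size_t payload_off;
};

static void reasm_init(struct reasm *r)
{
    memset(r, 0, sizeof(*r));
    r->hdr_rem = sizeof(r->hdr);    /* a message always starts with a header */
}

static void reasm_feed(struct reasm *r, const uint8_t *buf, size_t len)
{
    while (len) {
        size_t n;

        if (r->hdr_rem) {           /* still collecting the header */
            n = len < r->hdr_rem ? len : r->hdr_rem;
            memcpy(r->hdr + sizeof(r->hdr) - r->hdr_rem, buf, n);
            r->hdr_rem -= n;
            if (!r->hdr_rem) {      /* header complete: learn the length */
                uint32_t plen;

                memcpy(&plen, r->hdr, sizeof(plen));
                r->data_rem = ntohl(plen);
                r->payload_off = 0;
                if (r->data_rem > sizeof(r->payload)) {
                    /* toy error handling: drop the frame and resync */
                    fprintf(stderr, "oversized frame\n");
                    r->data_rem = 0;
                    r->hdr_rem = sizeof(r->hdr);
                }
            }
        } else {                    /* collecting the payload */
            n = len < r->data_rem ? len : r->data_rem;
            memcpy(r->payload + r->payload_off, buf, n);
            r->payload_off += n;
            r->data_rem -= n;
        }
        buf += n;
        len -= n;
        if (!r->hdr_rem && !r->data_rem) {
            printf("message of %zu bytes complete\n", r->payload_off);
            r->hdr_rem = sizeof(r->hdr);    /* expect the next header */
        }
    }
}

int main(void)
{
    struct reasm r;
    uint32_t plen = htonl(5);
    uint8_t msg[4 + 5];

    memcpy(msg, &plen, 4);
    memcpy(msg + 4, "hello", 5);

    reasm_init(&r);
    reasm_feed(&r, msg, 3);                 /* deliberately split mid-header */
    reasm_feed(&r, msg + 3, sizeof(msg) - 3);
    return 0;
}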
+ */ +int rds_tcp_recv_path(struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + struct socket *sock = tc->t_sock; + int ret = 0; + + rdsdebug("recv worker path [%d] tc %p sock %p\n", + cp->cp_index, tc, sock); + + lock_sock(sock->sk); + ret = rds_tcp_read_sock(cp, GFP_KERNEL); + release_sock(sock->sk); + + return ret; +} + +void rds_tcp_data_ready(struct sock *sk) +{ + void (*ready)(struct sock *sk); + struct rds_conn_path *cp; + struct rds_tcp_connection *tc; + + rdsdebug("data ready sk %p\n", sk); + + read_lock_bh(&sk->sk_callback_lock); + cp = sk->sk_user_data; + if (!cp) { /* check for teardown race */ + ready = sk->sk_data_ready; + goto out; + } + + tc = cp->cp_transport_data; + ready = tc->t_orig_data_ready; + rds_tcp_stats_inc(s_tcp_data_ready_calls); + + if (rds_tcp_read_sock(cp, GFP_ATOMIC) == -ENOMEM) { + rcu_read_lock(); + if (!rds_destroy_pending(cp->cp_conn)) + queue_delayed_work(rds_wq, &cp->cp_recv_w, 0); + rcu_read_unlock(); + } +out: + read_unlock_bh(&sk->sk_callback_lock); + ready(sk); +} + +int rds_tcp_recv_init(void) +{ + rds_tcp_incoming_slab = kmem_cache_create("rds_tcp_incoming", + sizeof(struct rds_tcp_incoming), + 0, 0, NULL); + if (!rds_tcp_incoming_slab) + return -ENOMEM; + return 0; +} + +void rds_tcp_recv_exit(void) +{ + kmem_cache_destroy(rds_tcp_incoming_slab); +} diff --git a/net/rds/tcp_send.c b/net/rds/tcp_send.c new file mode 100644 index 000000000..8c4d1d6e9 --- /dev/null +++ b/net/rds/tcp_send.c @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2006, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
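[Editorial illustration, not part of the patch.] rds_tcp_recv_init()/rds_tcp_recv_exit() above are a textbook kmem_cache lifecycle: create the slab at module init, allocate and free rds_tcp_incoming objects from it on the receive path, destroy it at module exit. A minimal, hypothetical out-of-tree module following the same shape (the object type and all names are invented):

// SPDX-License-Identifier: GPL-2.0
#include <linux/module.h>
#include <linux/slab.h>

struct demo_obj {
        int value;
};

static struct kmem_cache *demo_slab;

static int __init demo_init(void)
{
        struct demo_obj *obj;

        demo_slab = kmem_cache_create("demo_obj", sizeof(struct demo_obj),
                                      0, 0, NULL);
        if (!demo_slab)
                return -ENOMEM;

        /* a hot path would do this once per message */
        obj = kmem_cache_alloc(demo_slab, GFP_KERNEL);
        if (obj) {
                obj->value = 42;
                kmem_cache_free(demo_slab, obj);
        }
        return 0;
}

static void __exit demo_exit(void)
{
        kmem_cache_destroy(demo_slab);
}

module_init(demo_init);
module_exit(demo_exit);
MODULE_LICENSE("GPL");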
+ * + */ +#include +#include +#include + +#include "rds_single_path.h" +#include "rds.h" +#include "tcp.h" + +void rds_tcp_xmit_path_prepare(struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + + tcp_sock_set_cork(tc->t_sock->sk, true); +} + +void rds_tcp_xmit_path_complete(struct rds_conn_path *cp) +{ + struct rds_tcp_connection *tc = cp->cp_transport_data; + + tcp_sock_set_cork(tc->t_sock->sk, false); +} + +/* the core send_sem serializes this with other xmit and shutdown */ +static int rds_tcp_sendmsg(struct socket *sock, void *data, unsigned int len) +{ + struct kvec vec = { + .iov_base = data, + .iov_len = len, + }; + struct msghdr msg = { + .msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL, + }; + + return kernel_sendmsg(sock, &msg, &vec, 1, vec.iov_len); +} + +/* the core send_sem serializes this with other xmit and shutdown */ +int rds_tcp_xmit(struct rds_connection *conn, struct rds_message *rm, + unsigned int hdr_off, unsigned int sg, unsigned int off) +{ + struct rds_conn_path *cp = rm->m_inc.i_conn_path; + struct rds_tcp_connection *tc = cp->cp_transport_data; + int done = 0; + int ret = 0; + int more; + + if (hdr_off == 0) { + /* + * m_ack_seq is set to the sequence number of the last byte of + * header and data. see rds_tcp_is_acked(). + */ + tc->t_last_sent_nxt = rds_tcp_write_seq(tc); + rm->m_ack_seq = tc->t_last_sent_nxt + + sizeof(struct rds_header) + + be32_to_cpu(rm->m_inc.i_hdr.h_len) - 1; + smp_mb__before_atomic(); + set_bit(RDS_MSG_HAS_ACK_SEQ, &rm->m_flags); + tc->t_last_expected_una = rm->m_ack_seq + 1; + + if (test_bit(RDS_MSG_RETRANSMITTED, &rm->m_flags)) + rm->m_inc.i_hdr.h_flags |= RDS_FLAG_RETRANSMITTED; + + rdsdebug("rm %p tcp nxt %u ack_seq %llu\n", + rm, rds_tcp_write_seq(tc), + (unsigned long long)rm->m_ack_seq); + } + + if (hdr_off < sizeof(struct rds_header)) { + /* see rds_tcp_write_space() */ + set_bit(SOCK_NOSPACE, &tc->t_sock->sk->sk_socket->flags); + + ret = rds_tcp_sendmsg(tc->t_sock, + (void *)&rm->m_inc.i_hdr + hdr_off, + sizeof(rm->m_inc.i_hdr) - hdr_off); + if (ret < 0) + goto out; + done += ret; + if (hdr_off + done != sizeof(struct rds_header)) + goto out; + } + + more = rm->data.op_nents > 1 ? (MSG_MORE | MSG_SENDPAGE_NOTLAST) : 0; + while (sg < rm->data.op_nents) { + int flags = MSG_DONTWAIT | MSG_NOSIGNAL | more; + + ret = tc->t_sock->ops->sendpage(tc->t_sock, + sg_page(&rm->data.op_sg[sg]), + rm->data.op_sg[sg].offset + off, + rm->data.op_sg[sg].length - off, + flags); + rdsdebug("tcp sendpage %p:%u:%u ret %d\n", (void *)sg_page(&rm->data.op_sg[sg]), + rm->data.op_sg[sg].offset + off, rm->data.op_sg[sg].length - off, + ret); + if (ret <= 0) + break; + + off += ret; + done += ret; + if (off == rm->data.op_sg[sg].length) { + off = 0; + sg++; + } + if (sg == rm->data.op_nents - 1) + more = 0; + } + +out: + if (ret <= 0) { + /* write_space will hit after EAGAIN, all else fatal */ + if (ret == -EAGAIN) { + rds_tcp_stats_inc(s_tcp_sndbuf_full); + ret = 0; + } else { + /* No need to disconnect/reconnect if path_drop + * has already been triggered, because, e.g., of + * an incoming RST. + */ + if (rds_conn_path_up(cp)) { + pr_warn("RDS/tcp: send to %pI6c on cp [%d]" + "returned %d, " + "disconnecting and reconnecting\n", + &conn->c_faddr, cp->cp_index, ret); + rds_conn_path_drop(cp, false); + } + } + } + if (done == 0) + done = ret; + return done; +} + +/* + * rm->m_ack_seq is set to the tcp sequence number that corresponds to the + * last byte of the message, including the header. 
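[Editorial illustration, not part of the patch.] rds_tcp_xmit() above pushes the RDS header with kernel_sendmsg() and the payload pages with ->sendpage(), keeping MSG_MORE/MSG_SENDPAGE_NOTLAST set until the last fragment while the prepare/complete hooks cork and uncork the socket, so TCP can coalesce header and data. A loose user-space approximation of the header-then-payload pattern (sketch only; short writes are simply treated as errors here):

#include <sys/socket.h>
#include <sys/types.h>

/* send a small header and its payload, hinting that more data follows the
 * header so TCP may coalesce both into one segment where possible */
static int send_record(int fd, const void *hdr, size_t hdr_len,
                       const void *data, size_t data_len)
{
    if (send(fd, hdr, hdr_len, MSG_MORE | MSG_NOSIGNAL) != (ssize_t)hdr_len)
        return -1;      /* sketch: treat short writes as failure */
    if (send(fd, data, data_len, MSG_NOSIGNAL) != (ssize_t)data_len)
        return -1;
    return 0;
}

tcp_sock_set_cork() in rds_tcp_xmit_path_prepare()/_complete() plays the same role as the TCP_CORK socket option would in user space.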
This means that the + * entire message has been received if rm->m_ack_seq is "before" the next + * unacked byte of the TCP sequence space. We have to do very careful + * wrapping 32bit comparisons here. + */ +static int rds_tcp_is_acked(struct rds_message *rm, uint64_t ack) +{ + if (!test_bit(RDS_MSG_HAS_ACK_SEQ, &rm->m_flags)) + return 0; + return (__s32)((u32)rm->m_ack_seq - (u32)ack) < 0; +} + +void rds_tcp_write_space(struct sock *sk) +{ + void (*write_space)(struct sock *sk); + struct rds_conn_path *cp; + struct rds_tcp_connection *tc; + + read_lock_bh(&sk->sk_callback_lock); + cp = sk->sk_user_data; + if (!cp) { + write_space = sk->sk_write_space; + goto out; + } + + tc = cp->cp_transport_data; + rdsdebug("write_space for tc %p\n", tc); + write_space = tc->t_orig_write_space; + rds_tcp_stats_inc(s_tcp_write_space_calls); + + rdsdebug("tcp una %u\n", rds_tcp_snd_una(tc)); + tc->t_last_seen_una = rds_tcp_snd_una(tc); + rds_send_path_drop_acked(cp, rds_tcp_snd_una(tc), rds_tcp_is_acked); + + rcu_read_lock(); + if ((refcount_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf && + !rds_destroy_pending(cp->cp_conn)) + queue_delayed_work(rds_wq, &cp->cp_send_w, 0); + rcu_read_unlock(); + +out: + read_unlock_bh(&sk->sk_callback_lock); + + /* + * write_space is only called when data leaves tcp's send queue if + * SOCK_NOSPACE is set. We set SOCK_NOSPACE every time we put + * data in tcp's send queue because we use write_space to parse the + * sequence numbers and notice that rds messages have been fully + * received. + * + * tcp's write_space clears SOCK_NOSPACE if the send queue has more + * than a certain amount of space. So we need to set it again *after* + * we call tcp's write_space or else we might only get called on the + * first of a series of incoming tcp acks. + */ + write_space(sk); + + if (sk->sk_socket) + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); +} diff --git a/net/rds/tcp_stats.c b/net/rds/tcp_stats.c new file mode 100644 index 000000000..f8a7954f1 --- /dev/null +++ b/net/rds/tcp_stats.c @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2006 Oracle. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
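[Editorial illustration, not part of the patch.] rds_tcp_is_acked() above compares 32-bit sequence numbers that wrap, using the usual serial-number trick of subtracting as unsigned and reading the difference as signed. A standalone sketch with a few worked cases:

#include <stdint.h>
#include <stdio.h>

/* true if sequence number a is "before" b, even across a 2^32 wrap */
static int seq_before(uint32_t a, uint32_t b)
{
    return (int32_t)(a - b) < 0;
}

int main(void)
{
    /* plain case: 100 is before 200 */
    printf("%d\n", seq_before(100, 200));               /* 1 */
    /* wrapped case: 0xfffffff0 is 0x20 older than 0x10 */
    printf("%d\n", seq_before(0xfffffff0u, 0x10));      /* 1 */
    /* and not the other way around */
    printf("%d\n", seq_before(0x10, 0xfffffff0u));      /* 0 */
    return 0;
}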
+ * + */ +#include +#include +#include + +#include "rds.h" +#include "tcp.h" + +DEFINE_PER_CPU(struct rds_tcp_statistics, rds_tcp_stats) + ____cacheline_aligned; + +static const char * const rds_tcp_stat_names[] = { + "tcp_data_ready_calls", + "tcp_write_space_calls", + "tcp_sndbuf_full", + "tcp_connect_raced", + "tcp_listen_closed_stale", +}; + +unsigned int rds_tcp_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail) +{ + struct rds_tcp_statistics stats = {0, }; + uint64_t *src; + uint64_t *sum; + size_t i; + int cpu; + + if (avail < ARRAY_SIZE(rds_tcp_stat_names)) + goto out; + + for_each_online_cpu(cpu) { + src = (uint64_t *)&(per_cpu(rds_tcp_stats, cpu)); + sum = (uint64_t *)&stats; + for (i = 0; i < sizeof(stats) / sizeof(uint64_t); i++) + *(sum++) += *(src++); + } + + rds_stats_info_copy(iter, (uint64_t *)&stats, rds_tcp_stat_names, + ARRAY_SIZE(rds_tcp_stat_names)); +out: + return ARRAY_SIZE(rds_tcp_stat_names); +} diff --git a/net/rds/threads.c b/net/rds/threads.c new file mode 100644 index 000000000..1f424cbfc --- /dev/null +++ b/net/rds/threads.c @@ -0,0 +1,311 @@ +/* + * Copyright (c) 2006, 2018 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + */ +#include +#include +#include + +#include "rds.h" + +/* + * All of connection management is simplified by serializing it through + * work queues that execute in a connection managing thread. + * + * TCP wants to send acks through sendpage() in response to data_ready(), + * but it needs a process context to do so. + * + * The receive paths need to allocate but can't drop packets (!) so we have + * a thread around to block allocating if the receive fast path sees an + * allocation failure. 
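[Editorial illustration, not part of the patch.] rds_tcp_stats_info_copy() above folds the per-CPU counters into one snapshot by walking each per-CPU struct as a flat array of u64s, which only works because the struct holds nothing but u64 counters. A small standalone rendering of that trick (struct layout and values invented):

#include <stdint.h>
#include <stdio.h>

struct stats {          /* must contain only uint64_t counters */
    uint64_t data_ready_calls;
    uint64_t write_space_calls;
    uint64_t sndbuf_full;
};

int main(void)
{
    struct stats percpu[4] = {
        { 10, 2, 0 }, { 7, 1, 1 }, { 0, 0, 0 }, { 3, 5, 2 }
    };
    struct stats sum = { 0 };
    size_t nfields = sizeof(struct stats) / sizeof(uint64_t);

    for (size_t cpu = 0; cpu < 4; cpu++) {
        const uint64_t *src = (const uint64_t *)&percpu[cpu];
        uint64_t *dst = (uint64_t *)&sum;

        for (size_t i = 0; i < nfields; i++)
            dst[i] += src[i];
    }

    printf("data_ready %llu, write_space %llu, sndbuf_full %llu\n",
           (unsigned long long)sum.data_ready_calls,
           (unsigned long long)sum.write_space_calls,
           (unsigned long long)sum.sndbuf_full);
    return 0;
}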
+ */ + +/* Grand Unified Theory of connection life cycle: + * At any point in time, the connection can be in one of these states: + * DOWN, CONNECTING, UP, DISCONNECTING, ERROR + * + * The following transitions are possible: + * ANY -> ERROR + * UP -> DISCONNECTING + * ERROR -> DISCONNECTING + * DISCONNECTING -> DOWN + * DOWN -> CONNECTING + * CONNECTING -> UP + * + * Transition to state DISCONNECTING/DOWN: + * - Inside the shutdown worker; synchronizes with xmit path + * through RDS_IN_XMIT, and with connection management callbacks + * via c_cm_lock. + * + * For receive callbacks, we rely on the underlying transport + * (TCP, IB/RDMA) to provide the necessary synchronisation. + */ +struct workqueue_struct *rds_wq; +EXPORT_SYMBOL_GPL(rds_wq); + +void rds_connect_path_complete(struct rds_conn_path *cp, int curr) +{ + if (!rds_conn_path_transition(cp, curr, RDS_CONN_UP)) { + printk(KERN_WARNING "%s: Cannot transition to state UP, " + "current state is %d\n", + __func__, + atomic_read(&cp->cp_state)); + rds_conn_path_drop(cp, false); + return; + } + + rdsdebug("conn %p for %pI6c to %pI6c complete\n", + cp->cp_conn, &cp->cp_conn->c_laddr, &cp->cp_conn->c_faddr); + + cp->cp_reconnect_jiffies = 0; + set_bit(0, &cp->cp_conn->c_map_queued); + rcu_read_lock(); + if (!rds_destroy_pending(cp->cp_conn)) { + queue_delayed_work(rds_wq, &cp->cp_send_w, 0); + queue_delayed_work(rds_wq, &cp->cp_recv_w, 0); + } + rcu_read_unlock(); + cp->cp_conn->c_proposed_version = RDS_PROTOCOL_VERSION; +} +EXPORT_SYMBOL_GPL(rds_connect_path_complete); + +void rds_connect_complete(struct rds_connection *conn) +{ + rds_connect_path_complete(&conn->c_path[0], RDS_CONN_CONNECTING); +} +EXPORT_SYMBOL_GPL(rds_connect_complete); + +/* + * This random exponential backoff is relied on to eventually resolve racing + * connects. + * + * If connect attempts race then both parties drop both connections and come + * here to wait for a random amount of time before trying again. Eventually + * the backoff range will be so much greater than the time it takes to + * establish a connection that one of the pair will establish the connection + * before the other's random delay fires. + * + * Connection attempts that arrive while a connection is already established + * are also considered to be racing connects. This lets a connection from + * a rebooted machine replace an existing stale connection before the transport + * notices that the connection has failed. + * + * We should *always* start with a random backoff; otherwise a broken connection + * will always take several iterations to be re-established. 
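[Editorial illustration, not part of the patch.] rds_queue_reconnect() below implements the policy this comment describes: the actual delay is a random value inside the current backoff window, and the window doubles after each attempt up to a sysctl-controlled maximum. A compact user-space rendering of the same policy, with invented bounds:

#include <stdio.h>
#include <stdlib.h>
#include <time.h>

int main(void)
{
    unsigned long backoff = 1;                  /* seconds; "reconnect_min" */
    const unsigned long backoff_max = 300;      /* "reconnect_max" */

    srand((unsigned int)time(NULL));

    for (int attempt = 0; attempt < 8; attempt++) {
        /* pick a delay anywhere inside the current window, then widen it */
        unsigned long delay = (unsigned long)rand() % backoff;

        printf("attempt %d: delay %lus (window %lus)\n",
               attempt, delay, backoff);
        backoff = backoff * 2 > backoff_max ? backoff_max : backoff * 2;
    }
    return 0;
}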
+ */ +void rds_queue_reconnect(struct rds_conn_path *cp) +{ + unsigned long rand; + struct rds_connection *conn = cp->cp_conn; + + rdsdebug("conn %p for %pI6c to %pI6c reconnect jiffies %lu\n", + conn, &conn->c_laddr, &conn->c_faddr, + cp->cp_reconnect_jiffies); + + /* let peer with smaller addr initiate reconnect, to avoid duels */ + if (conn->c_trans->t_type == RDS_TRANS_TCP && + rds_addr_cmp(&conn->c_laddr, &conn->c_faddr) >= 0) + return; + + set_bit(RDS_RECONNECT_PENDING, &cp->cp_flags); + if (cp->cp_reconnect_jiffies == 0) { + cp->cp_reconnect_jiffies = rds_sysctl_reconnect_min_jiffies; + rcu_read_lock(); + if (!rds_destroy_pending(cp->cp_conn)) + queue_delayed_work(rds_wq, &cp->cp_conn_w, 0); + rcu_read_unlock(); + return; + } + + get_random_bytes(&rand, sizeof(rand)); + rdsdebug("%lu delay %lu ceil conn %p for %pI6c -> %pI6c\n", + rand % cp->cp_reconnect_jiffies, cp->cp_reconnect_jiffies, + conn, &conn->c_laddr, &conn->c_faddr); + rcu_read_lock(); + if (!rds_destroy_pending(cp->cp_conn)) + queue_delayed_work(rds_wq, &cp->cp_conn_w, + rand % cp->cp_reconnect_jiffies); + rcu_read_unlock(); + + cp->cp_reconnect_jiffies = min(cp->cp_reconnect_jiffies * 2, + rds_sysctl_reconnect_max_jiffies); +} + +void rds_connect_worker(struct work_struct *work) +{ + struct rds_conn_path *cp = container_of(work, + struct rds_conn_path, + cp_conn_w.work); + struct rds_connection *conn = cp->cp_conn; + int ret; + + if (cp->cp_index > 0 && + rds_addr_cmp(&cp->cp_conn->c_laddr, &cp->cp_conn->c_faddr) >= 0) + return; + clear_bit(RDS_RECONNECT_PENDING, &cp->cp_flags); + ret = rds_conn_path_transition(cp, RDS_CONN_DOWN, RDS_CONN_CONNECTING); + if (ret) { + ret = conn->c_trans->conn_path_connect(cp); + rdsdebug("conn %p for %pI6c to %pI6c dispatched, ret %d\n", + conn, &conn->c_laddr, &conn->c_faddr, ret); + + if (ret) { + if (rds_conn_path_transition(cp, + RDS_CONN_CONNECTING, + RDS_CONN_DOWN)) + rds_queue_reconnect(cp); + else + rds_conn_path_error(cp, "connect failed\n"); + } + } +} + +void rds_send_worker(struct work_struct *work) +{ + struct rds_conn_path *cp = container_of(work, + struct rds_conn_path, + cp_send_w.work); + int ret; + + if (rds_conn_path_state(cp) == RDS_CONN_UP) { + clear_bit(RDS_LL_SEND_FULL, &cp->cp_flags); + ret = rds_send_xmit(cp); + cond_resched(); + rdsdebug("conn %p ret %d\n", cp->cp_conn, ret); + switch (ret) { + case -EAGAIN: + rds_stats_inc(s_send_immediate_retry); + queue_delayed_work(rds_wq, &cp->cp_send_w, 0); + break; + case -ENOMEM: + rds_stats_inc(s_send_delayed_retry); + queue_delayed_work(rds_wq, &cp->cp_send_w, 2); + break; + default: + break; + } + } +} + +void rds_recv_worker(struct work_struct *work) +{ + struct rds_conn_path *cp = container_of(work, + struct rds_conn_path, + cp_recv_w.work); + int ret; + + if (rds_conn_path_state(cp) == RDS_CONN_UP) { + ret = cp->cp_conn->c_trans->recv_path(cp); + rdsdebug("conn %p ret %d\n", cp->cp_conn, ret); + switch (ret) { + case -EAGAIN: + rds_stats_inc(s_recv_immediate_retry); + queue_delayed_work(rds_wq, &cp->cp_recv_w, 0); + break; + case -ENOMEM: + rds_stats_inc(s_recv_delayed_retry); + queue_delayed_work(rds_wq, &cp->cp_recv_w, 2); + break; + default: + break; + } + } +} + +void rds_shutdown_worker(struct work_struct *work) +{ + struct rds_conn_path *cp = container_of(work, + struct rds_conn_path, + cp_down_w); + + rds_conn_shutdown(cp); +} + +void rds_threads_exit(void) +{ + destroy_workqueue(rds_wq); +} + +int rds_threads_init(void) +{ + rds_wq = create_singlethread_workqueue("krdsd"); + if (!rds_wq) + return -ENOMEM; 
+ + return 0; +} + +/* Compare two IPv6 addresses. Return 0 if the two addresses are equal. + * Return 1 if the first is greater. Return -1 if the second is greater. + */ +int rds_addr_cmp(const struct in6_addr *addr1, + const struct in6_addr *addr2) +{ +#if defined(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) && BITS_PER_LONG == 64 + const __be64 *a1, *a2; + u64 x, y; + + a1 = (__be64 *)addr1; + a2 = (__be64 *)addr2; + + if (*a1 != *a2) { + if (be64_to_cpu(*a1) < be64_to_cpu(*a2)) + return -1; + else + return 1; + } else { + x = be64_to_cpu(*++a1); + y = be64_to_cpu(*++a2); + if (x < y) + return -1; + else if (x > y) + return 1; + else + return 0; + } +#else + u32 a, b; + int i; + + for (i = 0; i < 4; i++) { + if (addr1->s6_addr32[i] != addr2->s6_addr32[i]) { + a = ntohl(addr1->s6_addr32[i]); + b = ntohl(addr2->s6_addr32[i]); + if (a < b) + return -1; + else if (a > b) + return 1; + } + } + return 0; +#endif +} +EXPORT_SYMBOL_GPL(rds_addr_cmp); diff --git a/net/rds/transport.c b/net/rds/transport.c new file mode 100644 index 000000000..f8001ec80 --- /dev/null +++ b/net/rds/transport.c @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2006, 2017 Oracle and/or its affiliates. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
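[Editorial illustration, not part of the patch.] rds_addr_cmp() above orders IPv6 addresses numerically, either as two big-endian 64-bit words when unaligned loads are cheap or as four 32-bit words otherwise; both reduce to byte-lexicographic order because the bytes are already big-endian. The standalone sketch below (sample addresses are arbitrary) mirrors the 64-bit variant in portable user-space C:

#define _DEFAULT_SOURCE         /* for be64toh() */
#include <arpa/inet.h>
#include <endian.h>
#include <netinet/in.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* user-space mirror of the "two big-endian 64-bit halves" comparison */
static int addr_cmp64(const struct in6_addr *x, const struct in6_addr *y)
{
    uint64_t xh, xl, yh, yl;

    memcpy(&xh, &x->s6_addr[0], 8);
    memcpy(&xl, &x->s6_addr[8], 8);
    memcpy(&yh, &y->s6_addr[0], 8);
    memcpy(&yl, &y->s6_addr[8], 8);
    xh = be64toh(xh);
    xl = be64toh(xl);
    yh = be64toh(yh);
    yl = be64toh(yl);

    if (xh != yh)
        return xh < yh ? -1 : 1;
    if (xl != yl)
        return xl < yl ? -1 : 1;
    return 0;
}

int main(void)
{
    struct in6_addr a, b;

    inet_pton(AF_INET6, "2001:db8::1", &a);
    inet_pton(AF_INET6, "2001:db8::1:0", &b);   /* differs in the low half */

    /* same answers a byte-wise memcmp() would give: -1 1 0 */
    printf("%d %d %d\n", addr_cmp64(&a, &b), addr_cmp64(&b, &a),
           addr_cmp64(&a, &a));
    return 0;
}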
+ * + */ +#include +#include +#include +#include + +#include "rds.h" +#include "loop.h" + +static char * const rds_trans_modules[] = { + [RDS_TRANS_IB] = "rds_rdma", + [RDS_TRANS_GAP] = NULL, + [RDS_TRANS_TCP] = "rds_tcp", +}; + +static struct rds_transport *transports[RDS_TRANS_COUNT]; +static DECLARE_RWSEM(rds_trans_sem); + +void rds_trans_register(struct rds_transport *trans) +{ + BUG_ON(strlen(trans->t_name) + 1 > TRANSNAMSIZ); + + down_write(&rds_trans_sem); + + if (transports[trans->t_type]) + printk(KERN_ERR "RDS Transport type %d already registered\n", + trans->t_type); + else { + transports[trans->t_type] = trans; + printk(KERN_INFO "Registered RDS/%s transport\n", trans->t_name); + } + + up_write(&rds_trans_sem); +} +EXPORT_SYMBOL_GPL(rds_trans_register); + +void rds_trans_unregister(struct rds_transport *trans) +{ + down_write(&rds_trans_sem); + + transports[trans->t_type] = NULL; + printk(KERN_INFO "Unregistered RDS/%s transport\n", trans->t_name); + + up_write(&rds_trans_sem); +} +EXPORT_SYMBOL_GPL(rds_trans_unregister); + +void rds_trans_put(struct rds_transport *trans) +{ + if (trans) + module_put(trans->t_owner); +} + +struct rds_transport *rds_trans_get_preferred(struct net *net, + const struct in6_addr *addr, + __u32 scope_id) +{ + struct rds_transport *ret = NULL; + struct rds_transport *trans; + unsigned int i; + + if (ipv6_addr_v4mapped(addr)) { + if (*(u_int8_t *)&addr->s6_addr32[3] == IN_LOOPBACKNET) + return &rds_loop_transport; + } else if (ipv6_addr_loopback(addr)) { + return &rds_loop_transport; + } + + down_read(&rds_trans_sem); + for (i = 0; i < RDS_TRANS_COUNT; i++) { + trans = transports[i]; + + if (trans && (trans->laddr_check(net, addr, scope_id) == 0) && + (!trans->t_owner || try_module_get(trans->t_owner))) { + ret = trans; + break; + } + } + up_read(&rds_trans_sem); + + return ret; +} + +struct rds_transport *rds_trans_get(int t_type) +{ + struct rds_transport *ret = NULL; + struct rds_transport *trans; + + down_read(&rds_trans_sem); + trans = transports[t_type]; + if (!trans) { + up_read(&rds_trans_sem); + if (rds_trans_modules[t_type]) + request_module(rds_trans_modules[t_type]); + down_read(&rds_trans_sem); + trans = transports[t_type]; + } + if (trans && trans->t_type == t_type && + (!trans->t_owner || try_module_get(trans->t_owner))) + ret = trans; + + up_read(&rds_trans_sem); + + return ret; +} + +/* + * This returns the number of stats entries in the snapshot and only + * copies them using the iter if there is enough space for them. The + * caller passes in the global stats so that we can size and copy while + * holding the lock. + */ +unsigned int rds_trans_stats_info_copy(struct rds_info_iterator *iter, + unsigned int avail) + +{ + struct rds_transport *trans; + unsigned int total = 0; + unsigned int part; + int i; + + rds_info_iter_unmap(iter); + down_read(&rds_trans_sem); + + for (i = 0; i < RDS_TRANS_COUNT; i++) { + trans = transports[i]; + if (!trans || !trans->stats_info_copy) + continue; + + part = trans->stats_info_copy(iter, avail); + avail -= min(avail, part); + total += part; + } + + up_read(&rds_trans_sem); + + return total; +} diff --git a/net/rfkill/Kconfig b/net/rfkill/Kconfig new file mode 100644 index 000000000..83a7af898 --- /dev/null +++ b/net/rfkill/Kconfig @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# RF switch subsystem configuration +# +menuconfig RFKILL + tristate "RF switch subsystem support" + help + Say Y here if you want to have control over RF switches + found on many WiFi and Bluetooth cards. 
+ + To compile this driver as a module, choose M here: the + module will be called rfkill. + +# LED trigger support +config RFKILL_LEDS + bool + depends on RFKILL + depends on LEDS_TRIGGERS = y || RFKILL = LEDS_TRIGGERS + default y + +config RFKILL_INPUT + bool "RF switch input support" if EXPERT + depends on RFKILL + depends on INPUT = y || RFKILL = INPUT + default y if !EXPERT + +config RFKILL_GPIO + tristate "GPIO RFKILL driver" + depends on RFKILL + depends on GPIOLIB || COMPILE_TEST + default n + help + If you say yes here you get support of a generic gpio RFKILL + driver. diff --git a/net/rfkill/Makefile b/net/rfkill/Makefile new file mode 100644 index 000000000..dc47b6174 --- /dev/null +++ b/net/rfkill/Makefile @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the RF switch subsystem. +# + +rfkill-y += core.o +rfkill-$(CONFIG_RFKILL_INPUT) += input.o +obj-$(CONFIG_RFKILL) += rfkill.o +obj-$(CONFIG_RFKILL_GPIO) += rfkill-gpio.o diff --git a/net/rfkill/core.c b/net/rfkill/core.c new file mode 100644 index 000000000..dac4fdc74 --- /dev/null +++ b/net/rfkill/core.c @@ -0,0 +1,1444 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (C) 2006 - 2007 Ivo van Doorn + * Copyright (C) 2007 Dmitry Torokhov + * Copyright 2009 Johannes Berg + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "rfkill.h" + +#define POLL_INTERVAL (5 * HZ) + +#define RFKILL_BLOCK_HW BIT(0) +#define RFKILL_BLOCK_SW BIT(1) +#define RFKILL_BLOCK_SW_PREV BIT(2) +#define RFKILL_BLOCK_ANY (RFKILL_BLOCK_HW |\ + RFKILL_BLOCK_SW |\ + RFKILL_BLOCK_SW_PREV) +#define RFKILL_BLOCK_SW_SETCALL BIT(31) + +struct rfkill { + spinlock_t lock; + + enum rfkill_type type; + + unsigned long state; + unsigned long hard_block_reasons; + + u32 idx; + + bool registered; + bool persistent; + bool polling_paused; + bool suspended; + + const struct rfkill_ops *ops; + void *data; + +#ifdef CONFIG_RFKILL_LEDS + struct led_trigger led_trigger; + const char *ledtrigname; +#endif + + struct device dev; + struct list_head node; + + struct delayed_work poll_work; + struct work_struct uevent_work; + struct work_struct sync_work; + char name[]; +}; +#define to_rfkill(d) container_of(d, struct rfkill, dev) + +struct rfkill_int_event { + struct list_head list; + struct rfkill_event_ext ev; +}; + +struct rfkill_data { + struct list_head list; + struct list_head events; + struct mutex mtx; + wait_queue_head_t read_wait; + bool input_handler; + u8 max_size; +}; + + +MODULE_AUTHOR("Ivo van Doorn "); +MODULE_AUTHOR("Johannes Berg "); +MODULE_DESCRIPTION("RF switch support"); +MODULE_LICENSE("GPL"); + + +/* + * The locking here should be made much smarter, we currently have + * a bit of a stupid situation because drivers might want to register + * the rfkill struct under their own lock, and take this lock during + * rfkill method calls -- which will cause an AB-BA deadlock situation. + * + * To fix that, we need to rework this code here to be mostly lock-free + * and only use the mutex for list manipulations, not to protect the + * various other global variables. Then we can avoid holding the mutex + * around driver operations, and all is happy. 
+ */ +static LIST_HEAD(rfkill_list); /* list of registered rf switches */ +static DEFINE_MUTEX(rfkill_global_mutex); +static LIST_HEAD(rfkill_fds); /* list of open fds of /dev/rfkill */ + +static unsigned int rfkill_default_state = 1; +module_param_named(default_state, rfkill_default_state, uint, 0444); +MODULE_PARM_DESC(default_state, + "Default initial state for all radio types, 0 = radio off"); + +static struct { + bool cur, sav; +} rfkill_global_states[NUM_RFKILL_TYPES]; + +static bool rfkill_epo_lock_active; + + +#ifdef CONFIG_RFKILL_LEDS +static void rfkill_led_trigger_event(struct rfkill *rfkill) +{ + struct led_trigger *trigger; + + if (!rfkill->registered) + return; + + trigger = &rfkill->led_trigger; + + if (rfkill->state & RFKILL_BLOCK_ANY) + led_trigger_event(trigger, LED_OFF); + else + led_trigger_event(trigger, LED_FULL); +} + +static int rfkill_led_trigger_activate(struct led_classdev *led) +{ + struct rfkill *rfkill; + + rfkill = container_of(led->trigger, struct rfkill, led_trigger); + + rfkill_led_trigger_event(rfkill); + + return 0; +} + +const char *rfkill_get_led_trigger_name(struct rfkill *rfkill) +{ + return rfkill->led_trigger.name; +} +EXPORT_SYMBOL(rfkill_get_led_trigger_name); + +void rfkill_set_led_trigger_name(struct rfkill *rfkill, const char *name) +{ + BUG_ON(!rfkill); + + rfkill->ledtrigname = name; +} +EXPORT_SYMBOL(rfkill_set_led_trigger_name); + +static int rfkill_led_trigger_register(struct rfkill *rfkill) +{ + rfkill->led_trigger.name = rfkill->ledtrigname + ? : dev_name(&rfkill->dev); + rfkill->led_trigger.activate = rfkill_led_trigger_activate; + return led_trigger_register(&rfkill->led_trigger); +} + +static void rfkill_led_trigger_unregister(struct rfkill *rfkill) +{ + led_trigger_unregister(&rfkill->led_trigger); +} + +static struct led_trigger rfkill_any_led_trigger; +static struct led_trigger rfkill_none_led_trigger; +static struct work_struct rfkill_global_led_trigger_work; + +static void rfkill_global_led_trigger_worker(struct work_struct *work) +{ + enum led_brightness brightness = LED_OFF; + struct rfkill *rfkill; + + mutex_lock(&rfkill_global_mutex); + list_for_each_entry(rfkill, &rfkill_list, node) { + if (!(rfkill->state & RFKILL_BLOCK_ANY)) { + brightness = LED_FULL; + break; + } + } + mutex_unlock(&rfkill_global_mutex); + + led_trigger_event(&rfkill_any_led_trigger, brightness); + led_trigger_event(&rfkill_none_led_trigger, + brightness == LED_OFF ? 
LED_FULL : LED_OFF); +} + +static void rfkill_global_led_trigger_event(void) +{ + schedule_work(&rfkill_global_led_trigger_work); +} + +static int rfkill_global_led_trigger_register(void) +{ + int ret; + + INIT_WORK(&rfkill_global_led_trigger_work, + rfkill_global_led_trigger_worker); + + rfkill_any_led_trigger.name = "rfkill-any"; + ret = led_trigger_register(&rfkill_any_led_trigger); + if (ret) + return ret; + + rfkill_none_led_trigger.name = "rfkill-none"; + ret = led_trigger_register(&rfkill_none_led_trigger); + if (ret) + led_trigger_unregister(&rfkill_any_led_trigger); + else + /* Delay activation until all global triggers are registered */ + rfkill_global_led_trigger_event(); + + return ret; +} + +static void rfkill_global_led_trigger_unregister(void) +{ + led_trigger_unregister(&rfkill_none_led_trigger); + led_trigger_unregister(&rfkill_any_led_trigger); + cancel_work_sync(&rfkill_global_led_trigger_work); +} +#else +static void rfkill_led_trigger_event(struct rfkill *rfkill) +{ +} + +static inline int rfkill_led_trigger_register(struct rfkill *rfkill) +{ + return 0; +} + +static inline void rfkill_led_trigger_unregister(struct rfkill *rfkill) +{ +} + +static void rfkill_global_led_trigger_event(void) +{ +} + +static int rfkill_global_led_trigger_register(void) +{ + return 0; +} + +static void rfkill_global_led_trigger_unregister(void) +{ +} +#endif /* CONFIG_RFKILL_LEDS */ + +static void rfkill_fill_event(struct rfkill_event_ext *ev, + struct rfkill *rfkill, + enum rfkill_operation op) +{ + unsigned long flags; + + ev->idx = rfkill->idx; + ev->type = rfkill->type; + ev->op = op; + + spin_lock_irqsave(&rfkill->lock, flags); + ev->hard = !!(rfkill->state & RFKILL_BLOCK_HW); + ev->soft = !!(rfkill->state & (RFKILL_BLOCK_SW | + RFKILL_BLOCK_SW_PREV)); + ev->hard_block_reasons = rfkill->hard_block_reasons; + spin_unlock_irqrestore(&rfkill->lock, flags); +} + +static void rfkill_send_events(struct rfkill *rfkill, enum rfkill_operation op) +{ + struct rfkill_data *data; + struct rfkill_int_event *ev; + + list_for_each_entry(data, &rfkill_fds, list) { + ev = kzalloc(sizeof(*ev), GFP_KERNEL); + if (!ev) + continue; + rfkill_fill_event(&ev->ev, rfkill, op); + mutex_lock(&data->mtx); + list_add_tail(&ev->list, &data->events); + mutex_unlock(&data->mtx); + wake_up_interruptible(&data->read_wait); + } +} + +static void rfkill_event(struct rfkill *rfkill) +{ + if (!rfkill->registered) + return; + + kobject_uevent(&rfkill->dev.kobj, KOBJ_CHANGE); + + /* also send event to /dev/rfkill */ + rfkill_send_events(rfkill, RFKILL_OP_CHANGE); +} + +/** + * rfkill_set_block - wrapper for set_block method + * + * @rfkill: the rfkill struct to use + * @blocked: the new software state + * + * Calls the set_block method (when applicable) and handles notifications + * etc. as well. + */ +static void rfkill_set_block(struct rfkill *rfkill, bool blocked) +{ + unsigned long flags; + bool prev, curr; + int err; + + if (unlikely(rfkill->dev.power.power_state.event & PM_EVENT_SLEEP)) + return; + + /* + * Some platforms (...!) generate input events which affect the + * _hard_ kill state -- whenever something tries to change the + * current software state query the hardware state too. 
+ */ + if (rfkill->ops->query) + rfkill->ops->query(rfkill, rfkill->data); + + spin_lock_irqsave(&rfkill->lock, flags); + prev = rfkill->state & RFKILL_BLOCK_SW; + + if (prev) + rfkill->state |= RFKILL_BLOCK_SW_PREV; + else + rfkill->state &= ~RFKILL_BLOCK_SW_PREV; + + if (blocked) + rfkill->state |= RFKILL_BLOCK_SW; + else + rfkill->state &= ~RFKILL_BLOCK_SW; + + rfkill->state |= RFKILL_BLOCK_SW_SETCALL; + spin_unlock_irqrestore(&rfkill->lock, flags); + + err = rfkill->ops->set_block(rfkill->data, blocked); + + spin_lock_irqsave(&rfkill->lock, flags); + if (err) { + /* + * Failed -- reset status to _PREV, which may be different + * from what we have set _PREV to earlier in this function + * if rfkill_set_sw_state was invoked. + */ + if (rfkill->state & RFKILL_BLOCK_SW_PREV) + rfkill->state |= RFKILL_BLOCK_SW; + else + rfkill->state &= ~RFKILL_BLOCK_SW; + } + rfkill->state &= ~RFKILL_BLOCK_SW_SETCALL; + rfkill->state &= ~RFKILL_BLOCK_SW_PREV; + curr = rfkill->state & RFKILL_BLOCK_SW; + spin_unlock_irqrestore(&rfkill->lock, flags); + + rfkill_led_trigger_event(rfkill); + rfkill_global_led_trigger_event(); + + if (prev != curr) + rfkill_event(rfkill); +} + +static void rfkill_update_global_state(enum rfkill_type type, bool blocked) +{ + int i; + + if (type != RFKILL_TYPE_ALL) { + rfkill_global_states[type].cur = blocked; + return; + } + + for (i = 0; i < NUM_RFKILL_TYPES; i++) + rfkill_global_states[i].cur = blocked; +} + +#ifdef CONFIG_RFKILL_INPUT +static atomic_t rfkill_input_disabled = ATOMIC_INIT(0); + +/** + * __rfkill_switch_all - Toggle state of all switches of given type + * @type: type of interfaces to be affected + * @blocked: the new state + * + * This function sets the state of all switches of given type, + * unless a specific switch is suspended. + * + * Caller must have acquired rfkill_global_mutex. + */ +static void __rfkill_switch_all(const enum rfkill_type type, bool blocked) +{ + struct rfkill *rfkill; + + rfkill_update_global_state(type, blocked); + list_for_each_entry(rfkill, &rfkill_list, node) { + if (rfkill->type != type && type != RFKILL_TYPE_ALL) + continue; + + rfkill_set_block(rfkill, blocked); + } +} + +/** + * rfkill_switch_all - Toggle state of all switches of given type + * @type: type of interfaces to be affected + * @blocked: the new state + * + * Acquires rfkill_global_mutex and calls __rfkill_switch_all(@type, @state). + * Please refer to __rfkill_switch_all() for details. + * + * Does nothing if the EPO lock is active. + */ +void rfkill_switch_all(enum rfkill_type type, bool blocked) +{ + if (atomic_read(&rfkill_input_disabled)) + return; + + mutex_lock(&rfkill_global_mutex); + + if (!rfkill_epo_lock_active) + __rfkill_switch_all(type, blocked); + + mutex_unlock(&rfkill_global_mutex); +} + +/** + * rfkill_epo - emergency power off all transmitters + * + * This kicks all non-suspended rfkill devices to RFKILL_STATE_SOFT_BLOCKED, + * ignoring everything in its path but rfkill_global_mutex and rfkill->mutex. + * + * The global state before the EPO is saved and can be restored later + * using rfkill_restore_states(). 
+ */ +void rfkill_epo(void) +{ + struct rfkill *rfkill; + int i; + + if (atomic_read(&rfkill_input_disabled)) + return; + + mutex_lock(&rfkill_global_mutex); + + rfkill_epo_lock_active = true; + list_for_each_entry(rfkill, &rfkill_list, node) + rfkill_set_block(rfkill, true); + + for (i = 0; i < NUM_RFKILL_TYPES; i++) { + rfkill_global_states[i].sav = rfkill_global_states[i].cur; + rfkill_global_states[i].cur = true; + } + + mutex_unlock(&rfkill_global_mutex); +} + +/** + * rfkill_restore_states - restore global states + * + * Restore (and sync switches to) the global state from the + * states in rfkill_default_states. This can undo the effects of + * a call to rfkill_epo(). + */ +void rfkill_restore_states(void) +{ + int i; + + if (atomic_read(&rfkill_input_disabled)) + return; + + mutex_lock(&rfkill_global_mutex); + + rfkill_epo_lock_active = false; + for (i = 0; i < NUM_RFKILL_TYPES; i++) + __rfkill_switch_all(i, rfkill_global_states[i].sav); + mutex_unlock(&rfkill_global_mutex); +} + +/** + * rfkill_remove_epo_lock - unlock state changes + * + * Used by rfkill-input manually unlock state changes, when + * the EPO switch is deactivated. + */ +void rfkill_remove_epo_lock(void) +{ + if (atomic_read(&rfkill_input_disabled)) + return; + + mutex_lock(&rfkill_global_mutex); + rfkill_epo_lock_active = false; + mutex_unlock(&rfkill_global_mutex); +} + +/** + * rfkill_is_epo_lock_active - returns true EPO is active + * + * Returns 0 (false) if there is NOT an active EPO condition, + * and 1 (true) if there is an active EPO condition, which + * locks all radios in one of the BLOCKED states. + * + * Can be called in atomic context. + */ +bool rfkill_is_epo_lock_active(void) +{ + return rfkill_epo_lock_active; +} + +/** + * rfkill_get_global_sw_state - returns global state for a type + * @type: the type to get the global state of + * + * Returns the current global state for a given wireless + * device type. 
+ */ +bool rfkill_get_global_sw_state(const enum rfkill_type type) +{ + return rfkill_global_states[type].cur; +} +#endif + +bool rfkill_set_hw_state_reason(struct rfkill *rfkill, + bool blocked, unsigned long reason) +{ + unsigned long flags; + bool ret, prev; + + BUG_ON(!rfkill); + + if (WARN(reason & + ~(RFKILL_HARD_BLOCK_SIGNAL | RFKILL_HARD_BLOCK_NOT_OWNER), + "hw_state reason not supported: 0x%lx", reason)) + return blocked; + + spin_lock_irqsave(&rfkill->lock, flags); + prev = !!(rfkill->hard_block_reasons & reason); + if (blocked) { + rfkill->state |= RFKILL_BLOCK_HW; + rfkill->hard_block_reasons |= reason; + } else { + rfkill->hard_block_reasons &= ~reason; + if (!rfkill->hard_block_reasons) + rfkill->state &= ~RFKILL_BLOCK_HW; + } + ret = !!(rfkill->state & RFKILL_BLOCK_ANY); + spin_unlock_irqrestore(&rfkill->lock, flags); + + rfkill_led_trigger_event(rfkill); + rfkill_global_led_trigger_event(); + + if (rfkill->registered && prev != blocked) + schedule_work(&rfkill->uevent_work); + + return ret; +} +EXPORT_SYMBOL(rfkill_set_hw_state_reason); + +static void __rfkill_set_sw_state(struct rfkill *rfkill, bool blocked) +{ + u32 bit = RFKILL_BLOCK_SW; + + /* if in a ops->set_block right now, use other bit */ + if (rfkill->state & RFKILL_BLOCK_SW_SETCALL) + bit = RFKILL_BLOCK_SW_PREV; + + if (blocked) + rfkill->state |= bit; + else + rfkill->state &= ~bit; +} + +bool rfkill_set_sw_state(struct rfkill *rfkill, bool blocked) +{ + unsigned long flags; + bool prev, hwblock; + + BUG_ON(!rfkill); + + spin_lock_irqsave(&rfkill->lock, flags); + prev = !!(rfkill->state & RFKILL_BLOCK_SW); + __rfkill_set_sw_state(rfkill, blocked); + hwblock = !!(rfkill->state & RFKILL_BLOCK_HW); + blocked = blocked || hwblock; + spin_unlock_irqrestore(&rfkill->lock, flags); + + if (!rfkill->registered) + return blocked; + + if (prev != blocked && !hwblock) + schedule_work(&rfkill->uevent_work); + + rfkill_led_trigger_event(rfkill); + rfkill_global_led_trigger_event(); + + return blocked; +} +EXPORT_SYMBOL(rfkill_set_sw_state); + +void rfkill_init_sw_state(struct rfkill *rfkill, bool blocked) +{ + unsigned long flags; + + BUG_ON(!rfkill); + BUG_ON(rfkill->registered); + + spin_lock_irqsave(&rfkill->lock, flags); + __rfkill_set_sw_state(rfkill, blocked); + rfkill->persistent = true; + spin_unlock_irqrestore(&rfkill->lock, flags); +} +EXPORT_SYMBOL(rfkill_init_sw_state); + +void rfkill_set_states(struct rfkill *rfkill, bool sw, bool hw) +{ + unsigned long flags; + bool swprev, hwprev; + + BUG_ON(!rfkill); + + spin_lock_irqsave(&rfkill->lock, flags); + + /* + * No need to care about prev/setblock ... this is for uevent only + * and that will get triggered by rfkill_set_block anyway. 
+ */ + swprev = !!(rfkill->state & RFKILL_BLOCK_SW); + hwprev = !!(rfkill->state & RFKILL_BLOCK_HW); + __rfkill_set_sw_state(rfkill, sw); + if (hw) + rfkill->state |= RFKILL_BLOCK_HW; + else + rfkill->state &= ~RFKILL_BLOCK_HW; + + spin_unlock_irqrestore(&rfkill->lock, flags); + + if (!rfkill->registered) { + rfkill->persistent = true; + } else { + if (swprev != sw || hwprev != hw) + schedule_work(&rfkill->uevent_work); + + rfkill_led_trigger_event(rfkill); + rfkill_global_led_trigger_event(); + } +} +EXPORT_SYMBOL(rfkill_set_states); + +static const char * const rfkill_types[] = { + NULL, /* RFKILL_TYPE_ALL */ + "wlan", + "bluetooth", + "ultrawideband", + "wimax", + "wwan", + "gps", + "fm", + "nfc", +}; + +enum rfkill_type rfkill_find_type(const char *name) +{ + int i; + + BUILD_BUG_ON(ARRAY_SIZE(rfkill_types) != NUM_RFKILL_TYPES); + + if (!name) + return RFKILL_TYPE_ALL; + + for (i = 1; i < NUM_RFKILL_TYPES; i++) + if (!strcmp(name, rfkill_types[i])) + return i; + return RFKILL_TYPE_ALL; +} +EXPORT_SYMBOL(rfkill_find_type); + +static ssize_t name_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%s\n", rfkill->name); +} +static DEVICE_ATTR_RO(name); + +static ssize_t type_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%s\n", rfkill_types[rfkill->type]); +} +static DEVICE_ATTR_RO(type); + +static ssize_t index_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%d\n", rfkill->idx); +} +static DEVICE_ATTR_RO(index); + +static ssize_t persistent_show(struct device *dev, + struct device_attribute *attr, char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%d\n", rfkill->persistent); +} +static DEVICE_ATTR_RO(persistent); + +static ssize_t hard_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%d\n", (rfkill->state & RFKILL_BLOCK_HW) ? 1 : 0 ); +} +static DEVICE_ATTR_RO(hard); + +static ssize_t soft_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%d\n", (rfkill->state & RFKILL_BLOCK_SW) ? 
1 : 0 ); +} + +static ssize_t soft_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct rfkill *rfkill = to_rfkill(dev); + unsigned long state; + int err; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + err = kstrtoul(buf, 0, &state); + if (err) + return err; + + if (state > 1 ) + return -EINVAL; + + mutex_lock(&rfkill_global_mutex); + rfkill_set_block(rfkill, state); + mutex_unlock(&rfkill_global_mutex); + + return count; +} +static DEVICE_ATTR_RW(soft); + +static ssize_t hard_block_reasons_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "0x%lx\n", rfkill->hard_block_reasons); +} +static DEVICE_ATTR_RO(hard_block_reasons); + +static u8 user_state_from_blocked(unsigned long state) +{ + if (state & RFKILL_BLOCK_HW) + return RFKILL_USER_STATE_HARD_BLOCKED; + if (state & RFKILL_BLOCK_SW) + return RFKILL_USER_STATE_SOFT_BLOCKED; + + return RFKILL_USER_STATE_UNBLOCKED; +} + +static ssize_t state_show(struct device *dev, struct device_attribute *attr, + char *buf) +{ + struct rfkill *rfkill = to_rfkill(dev); + + return sprintf(buf, "%d\n", user_state_from_blocked(rfkill->state)); +} + +static ssize_t state_store(struct device *dev, struct device_attribute *attr, + const char *buf, size_t count) +{ + struct rfkill *rfkill = to_rfkill(dev); + unsigned long state; + int err; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + err = kstrtoul(buf, 0, &state); + if (err) + return err; + + if (state != RFKILL_USER_STATE_SOFT_BLOCKED && + state != RFKILL_USER_STATE_UNBLOCKED) + return -EINVAL; + + mutex_lock(&rfkill_global_mutex); + rfkill_set_block(rfkill, state == RFKILL_USER_STATE_SOFT_BLOCKED); + mutex_unlock(&rfkill_global_mutex); + + return count; +} +static DEVICE_ATTR_RW(state); + +static struct attribute *rfkill_dev_attrs[] = { + &dev_attr_name.attr, + &dev_attr_type.attr, + &dev_attr_index.attr, + &dev_attr_persistent.attr, + &dev_attr_state.attr, + &dev_attr_soft.attr, + &dev_attr_hard.attr, + &dev_attr_hard_block_reasons.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rfkill_dev); + +static void rfkill_release(struct device *dev) +{ + struct rfkill *rfkill = to_rfkill(dev); + + kfree(rfkill); +} + +static int rfkill_dev_uevent(struct device *dev, struct kobj_uevent_env *env) +{ + struct rfkill *rfkill = to_rfkill(dev); + unsigned long flags; + unsigned long reasons; + u32 state; + int error; + + error = add_uevent_var(env, "RFKILL_NAME=%s", rfkill->name); + if (error) + return error; + error = add_uevent_var(env, "RFKILL_TYPE=%s", + rfkill_types[rfkill->type]); + if (error) + return error; + spin_lock_irqsave(&rfkill->lock, flags); + state = rfkill->state; + reasons = rfkill->hard_block_reasons; + spin_unlock_irqrestore(&rfkill->lock, flags); + error = add_uevent_var(env, "RFKILL_STATE=%d", + user_state_from_blocked(state)); + if (error) + return error; + return add_uevent_var(env, "RFKILL_HW_BLOCK_REASON=0x%lx", reasons); +} + +void rfkill_pause_polling(struct rfkill *rfkill) +{ + BUG_ON(!rfkill); + + if (!rfkill->ops->poll) + return; + + rfkill->polling_paused = true; + cancel_delayed_work_sync(&rfkill->poll_work); +} +EXPORT_SYMBOL(rfkill_pause_polling); + +void rfkill_resume_polling(struct rfkill *rfkill) +{ + BUG_ON(!rfkill); + + if (!rfkill->ops->poll) + return; + + rfkill->polling_paused = false; + + if (rfkill->suspended) + return; + + queue_delayed_work(system_power_efficient_wq, + &rfkill->poll_work, 0); +} +EXPORT_SYMBOL(rfkill_resume_polling); + 
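
Illustrative sketch, not part of the patch: how a hypothetical driver ("foo") would consume the rfkill core API defined in this file -- rfkill_alloc()/rfkill_register()/rfkill_unregister()/rfkill_destroy() (added further down), rfkill_set_hw_state_reason(), and the rfkill_ops->set_block() callback that rfkill_set_block() invokes. The foo_* names and helper structure are invented for the example; the rfkill_* calls, RFKILL_TYPE_WLAN and RFKILL_HARD_BLOCK_SIGNAL are the ones this patch implements, exposed to drivers via <linux/rfkill.h>.

#include <linux/device.h>
#include <linux/rfkill.h>

struct foo_radio {
	struct rfkill *rfkill;
	bool radio_on;
};

/* ops->set_block(): called by the core to apply a new soft-block state */
static int foo_set_block(void *data, bool blocked)
{
	struct foo_radio *foo = data;

	/* program the hardware here; on a non-zero return the core
	 * restores the previous software state in rfkill_set_block() */
	foo->radio_on = !blocked;
	return 0;
}

static const struct rfkill_ops foo_rfkill_ops = {
	.set_block = foo_set_block,
};

static int foo_rfkill_setup(struct device *dev, struct foo_radio *foo)
{
	int err;

	foo->rfkill = rfkill_alloc("foo-wlan", dev, RFKILL_TYPE_WLAN,
				   &foo_rfkill_ops, foo);
	if (!foo->rfkill)
		return -ENOMEM;

	err = rfkill_register(foo->rfkill);
	if (err) {
		rfkill_destroy(foo->rfkill);
		return err;
	}
	return 0;
}

/* e.g. from the handler that notices the hardware kill-switch line */
static void foo_hw_switch_changed(struct foo_radio *foo, bool killed)
{
	rfkill_set_hw_state_reason(foo->rfkill, killed,
				   RFKILL_HARD_BLOCK_SIGNAL);
}

static void foo_rfkill_teardown(struct foo_radio *foo)
{
	rfkill_unregister(foo->rfkill);
	rfkill_destroy(foo->rfkill);
}

Soft-block bookkeeping (the RFKILL_BLOCK_SW* bits, uevents and /dev/rfkill notifications) stays in the core; the driver callback only programs the hardware and reports failure, much like the rfkill-gpio driver added later in this patch.
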
+#ifdef CONFIG_PM_SLEEP +static int rfkill_suspend(struct device *dev) +{ + struct rfkill *rfkill = to_rfkill(dev); + + rfkill->suspended = true; + cancel_delayed_work_sync(&rfkill->poll_work); + + return 0; +} + +static int rfkill_resume(struct device *dev) +{ + struct rfkill *rfkill = to_rfkill(dev); + bool cur; + + rfkill->suspended = false; + + if (!rfkill->registered) + return 0; + + if (!rfkill->persistent) { + cur = !!(rfkill->state & RFKILL_BLOCK_SW); + rfkill_set_block(rfkill, cur); + } + + if (rfkill->ops->poll && !rfkill->polling_paused) + queue_delayed_work(system_power_efficient_wq, + &rfkill->poll_work, 0); + + return 0; +} + +static SIMPLE_DEV_PM_OPS(rfkill_pm_ops, rfkill_suspend, rfkill_resume); +#define RFKILL_PM_OPS (&rfkill_pm_ops) +#else +#define RFKILL_PM_OPS NULL +#endif + +static struct class rfkill_class = { + .name = "rfkill", + .dev_release = rfkill_release, + .dev_groups = rfkill_dev_groups, + .dev_uevent = rfkill_dev_uevent, + .pm = RFKILL_PM_OPS, +}; + +bool rfkill_blocked(struct rfkill *rfkill) +{ + unsigned long flags; + u32 state; + + spin_lock_irqsave(&rfkill->lock, flags); + state = rfkill->state; + spin_unlock_irqrestore(&rfkill->lock, flags); + + return !!(state & RFKILL_BLOCK_ANY); +} +EXPORT_SYMBOL(rfkill_blocked); + +bool rfkill_soft_blocked(struct rfkill *rfkill) +{ + unsigned long flags; + u32 state; + + spin_lock_irqsave(&rfkill->lock, flags); + state = rfkill->state; + spin_unlock_irqrestore(&rfkill->lock, flags); + + return !!(state & RFKILL_BLOCK_SW); +} +EXPORT_SYMBOL(rfkill_soft_blocked); + +struct rfkill * __must_check rfkill_alloc(const char *name, + struct device *parent, + const enum rfkill_type type, + const struct rfkill_ops *ops, + void *ops_data) +{ + struct rfkill *rfkill; + struct device *dev; + + if (WARN_ON(!ops)) + return NULL; + + if (WARN_ON(!ops->set_block)) + return NULL; + + if (WARN_ON(!name)) + return NULL; + + if (WARN_ON(type == RFKILL_TYPE_ALL || type >= NUM_RFKILL_TYPES)) + return NULL; + + rfkill = kzalloc(sizeof(*rfkill) + strlen(name) + 1, GFP_KERNEL); + if (!rfkill) + return NULL; + + spin_lock_init(&rfkill->lock); + INIT_LIST_HEAD(&rfkill->node); + rfkill->type = type; + strcpy(rfkill->name, name); + rfkill->ops = ops; + rfkill->data = ops_data; + + dev = &rfkill->dev; + dev->class = &rfkill_class; + dev->parent = parent; + device_initialize(dev); + + return rfkill; +} +EXPORT_SYMBOL(rfkill_alloc); + +static void rfkill_poll(struct work_struct *work) +{ + struct rfkill *rfkill; + + rfkill = container_of(work, struct rfkill, poll_work.work); + + /* + * Poll hardware state -- driver will use one of the + * rfkill_set{,_hw,_sw}_state functions and use its + * return value to update the current status. 
+ */ + rfkill->ops->poll(rfkill, rfkill->data); + + queue_delayed_work(system_power_efficient_wq, + &rfkill->poll_work, + round_jiffies_relative(POLL_INTERVAL)); +} + +static void rfkill_uevent_work(struct work_struct *work) +{ + struct rfkill *rfkill; + + rfkill = container_of(work, struct rfkill, uevent_work); + + mutex_lock(&rfkill_global_mutex); + rfkill_event(rfkill); + mutex_unlock(&rfkill_global_mutex); +} + +static void rfkill_sync_work(struct work_struct *work) +{ + struct rfkill *rfkill; + bool cur; + + rfkill = container_of(work, struct rfkill, sync_work); + + mutex_lock(&rfkill_global_mutex); + cur = rfkill_global_states[rfkill->type].cur; + rfkill_set_block(rfkill, cur); + mutex_unlock(&rfkill_global_mutex); +} + +int __must_check rfkill_register(struct rfkill *rfkill) +{ + static unsigned long rfkill_no; + struct device *dev; + int error; + + if (!rfkill) + return -EINVAL; + + dev = &rfkill->dev; + + mutex_lock(&rfkill_global_mutex); + + if (rfkill->registered) { + error = -EALREADY; + goto unlock; + } + + rfkill->idx = rfkill_no; + dev_set_name(dev, "rfkill%lu", rfkill_no); + rfkill_no++; + + list_add_tail(&rfkill->node, &rfkill_list); + + error = device_add(dev); + if (error) + goto remove; + + error = rfkill_led_trigger_register(rfkill); + if (error) + goto devdel; + + rfkill->registered = true; + + INIT_DELAYED_WORK(&rfkill->poll_work, rfkill_poll); + INIT_WORK(&rfkill->uevent_work, rfkill_uevent_work); + INIT_WORK(&rfkill->sync_work, rfkill_sync_work); + + if (rfkill->ops->poll) + queue_delayed_work(system_power_efficient_wq, + &rfkill->poll_work, + round_jiffies_relative(POLL_INTERVAL)); + + if (!rfkill->persistent || rfkill_epo_lock_active) { + schedule_work(&rfkill->sync_work); + } else { +#ifdef CONFIG_RFKILL_INPUT + bool soft_blocked = !!(rfkill->state & RFKILL_BLOCK_SW); + + if (!atomic_read(&rfkill_input_disabled)) + __rfkill_switch_all(rfkill->type, soft_blocked); +#endif + } + + rfkill_global_led_trigger_event(); + rfkill_send_events(rfkill, RFKILL_OP_ADD); + + mutex_unlock(&rfkill_global_mutex); + return 0; + + devdel: + device_del(&rfkill->dev); + remove: + list_del_init(&rfkill->node); + unlock: + mutex_unlock(&rfkill_global_mutex); + return error; +} +EXPORT_SYMBOL(rfkill_register); + +void rfkill_unregister(struct rfkill *rfkill) +{ + BUG_ON(!rfkill); + + if (rfkill->ops->poll) + cancel_delayed_work_sync(&rfkill->poll_work); + + cancel_work_sync(&rfkill->uevent_work); + cancel_work_sync(&rfkill->sync_work); + + rfkill->registered = false; + + device_del(&rfkill->dev); + + mutex_lock(&rfkill_global_mutex); + rfkill_send_events(rfkill, RFKILL_OP_DEL); + list_del_init(&rfkill->node); + rfkill_global_led_trigger_event(); + mutex_unlock(&rfkill_global_mutex); + + rfkill_led_trigger_unregister(rfkill); +} +EXPORT_SYMBOL(rfkill_unregister); + +void rfkill_destroy(struct rfkill *rfkill) +{ + if (rfkill) + put_device(&rfkill->dev); +} +EXPORT_SYMBOL(rfkill_destroy); + +static int rfkill_fop_open(struct inode *inode, struct file *file) +{ + struct rfkill_data *data; + struct rfkill *rfkill; + struct rfkill_int_event *ev, *tmp; + + data = kzalloc(sizeof(*data), GFP_KERNEL); + if (!data) + return -ENOMEM; + + data->max_size = RFKILL_EVENT_SIZE_V1; + + INIT_LIST_HEAD(&data->events); + mutex_init(&data->mtx); + init_waitqueue_head(&data->read_wait); + + mutex_lock(&rfkill_global_mutex); + mutex_lock(&data->mtx); + /* + * start getting events from elsewhere but hold mtx to get + * startup events added first + */ + + list_for_each_entry(rfkill, &rfkill_list, node) { + ev = 
kzalloc(sizeof(*ev), GFP_KERNEL); + if (!ev) + goto free; + rfkill_fill_event(&ev->ev, rfkill, RFKILL_OP_ADD); + list_add_tail(&ev->list, &data->events); + } + list_add(&data->list, &rfkill_fds); + mutex_unlock(&data->mtx); + mutex_unlock(&rfkill_global_mutex); + + file->private_data = data; + + return stream_open(inode, file); + + free: + mutex_unlock(&data->mtx); + mutex_unlock(&rfkill_global_mutex); + mutex_destroy(&data->mtx); + list_for_each_entry_safe(ev, tmp, &data->events, list) + kfree(ev); + kfree(data); + return -ENOMEM; +} + +static __poll_t rfkill_fop_poll(struct file *file, poll_table *wait) +{ + struct rfkill_data *data = file->private_data; + __poll_t res = EPOLLOUT | EPOLLWRNORM; + + poll_wait(file, &data->read_wait, wait); + + mutex_lock(&data->mtx); + if (!list_empty(&data->events)) + res = EPOLLIN | EPOLLRDNORM; + mutex_unlock(&data->mtx); + + return res; +} + +static ssize_t rfkill_fop_read(struct file *file, char __user *buf, + size_t count, loff_t *pos) +{ + struct rfkill_data *data = file->private_data; + struct rfkill_int_event *ev; + unsigned long sz; + int ret; + + mutex_lock(&data->mtx); + + while (list_empty(&data->events)) { + if (file->f_flags & O_NONBLOCK) { + ret = -EAGAIN; + goto out; + } + mutex_unlock(&data->mtx); + /* since we re-check and it just compares pointers, + * using !list_empty() without locking isn't a problem + */ + ret = wait_event_interruptible(data->read_wait, + !list_empty(&data->events)); + mutex_lock(&data->mtx); + + if (ret) + goto out; + } + + ev = list_first_entry(&data->events, struct rfkill_int_event, + list); + + sz = min_t(unsigned long, sizeof(ev->ev), count); + sz = min_t(unsigned long, sz, data->max_size); + ret = sz; + if (copy_to_user(buf, &ev->ev, sz)) + ret = -EFAULT; + + list_del(&ev->list); + kfree(ev); + out: + mutex_unlock(&data->mtx); + return ret; +} + +static ssize_t rfkill_fop_write(struct file *file, const char __user *buf, + size_t count, loff_t *pos) +{ + struct rfkill_data *data = file->private_data; + struct rfkill *rfkill; + struct rfkill_event_ext ev; + int ret; + + /* we don't need the 'hard' variable but accept it */ + if (count < RFKILL_EVENT_SIZE_V1 - 1) + return -EINVAL; + + /* + * Copy as much data as we can accept into our 'ev' buffer, + * but tell userspace how much we've copied so it can determine + * our API version even in a write() call, if it cares. 
+ */ + count = min(count, sizeof(ev)); + count = min_t(size_t, count, data->max_size); + if (copy_from_user(&ev, buf, count)) + return -EFAULT; + + if (ev.type >= NUM_RFKILL_TYPES) + return -EINVAL; + + mutex_lock(&rfkill_global_mutex); + + switch (ev.op) { + case RFKILL_OP_CHANGE_ALL: + rfkill_update_global_state(ev.type, ev.soft); + list_for_each_entry(rfkill, &rfkill_list, node) + if (rfkill->type == ev.type || + ev.type == RFKILL_TYPE_ALL) + rfkill_set_block(rfkill, ev.soft); + ret = 0; + break; + case RFKILL_OP_CHANGE: + list_for_each_entry(rfkill, &rfkill_list, node) + if (rfkill->idx == ev.idx && + (rfkill->type == ev.type || + ev.type == RFKILL_TYPE_ALL)) + rfkill_set_block(rfkill, ev.soft); + ret = 0; + break; + default: + ret = -EINVAL; + break; + } + + mutex_unlock(&rfkill_global_mutex); + + return ret ?: count; +} + +static int rfkill_fop_release(struct inode *inode, struct file *file) +{ + struct rfkill_data *data = file->private_data; + struct rfkill_int_event *ev, *tmp; + + mutex_lock(&rfkill_global_mutex); + list_del(&data->list); + mutex_unlock(&rfkill_global_mutex); + + mutex_destroy(&data->mtx); + list_for_each_entry_safe(ev, tmp, &data->events, list) + kfree(ev); + +#ifdef CONFIG_RFKILL_INPUT + if (data->input_handler) + if (atomic_dec_return(&rfkill_input_disabled) == 0) + printk(KERN_DEBUG "rfkill: input handler enabled\n"); +#endif + + kfree(data); + + return 0; +} + +static long rfkill_fop_ioctl(struct file *file, unsigned int cmd, + unsigned long arg) +{ + struct rfkill_data *data = file->private_data; + int ret = -ENOSYS; + u32 size; + + if (_IOC_TYPE(cmd) != RFKILL_IOC_MAGIC) + return -ENOSYS; + + mutex_lock(&data->mtx); + switch (_IOC_NR(cmd)) { +#ifdef CONFIG_RFKILL_INPUT + case RFKILL_IOC_NOINPUT: + if (!data->input_handler) { + if (atomic_inc_return(&rfkill_input_disabled) == 1) + printk(KERN_DEBUG "rfkill: input handler disabled\n"); + data->input_handler = true; + } + ret = 0; + break; +#endif + case RFKILL_IOC_MAX_SIZE: + if (get_user(size, (__u32 __user *)arg)) { + ret = -EFAULT; + break; + } + if (size < RFKILL_EVENT_SIZE_V1 || size > U8_MAX) { + ret = -EINVAL; + break; + } + data->max_size = size; + ret = 0; + break; + default: + break; + } + mutex_unlock(&data->mtx); + + return ret; +} + +static const struct file_operations rfkill_fops = { + .owner = THIS_MODULE, + .open = rfkill_fop_open, + .read = rfkill_fop_read, + .write = rfkill_fop_write, + .poll = rfkill_fop_poll, + .release = rfkill_fop_release, + .unlocked_ioctl = rfkill_fop_ioctl, + .compat_ioctl = compat_ptr_ioctl, + .llseek = no_llseek, +}; + +#define RFKILL_NAME "rfkill" + +static struct miscdevice rfkill_miscdev = { + .fops = &rfkill_fops, + .name = RFKILL_NAME, + .minor = RFKILL_MINOR, +}; + +static int __init rfkill_init(void) +{ + int error; + + rfkill_update_global_state(RFKILL_TYPE_ALL, !rfkill_default_state); + + error = class_register(&rfkill_class); + if (error) + goto error_class; + + error = misc_register(&rfkill_miscdev); + if (error) + goto error_misc; + + error = rfkill_global_led_trigger_register(); + if (error) + goto error_led_trigger; + +#ifdef CONFIG_RFKILL_INPUT + error = rfkill_handler_init(); + if (error) + goto error_input; +#endif + + return 0; + +#ifdef CONFIG_RFKILL_INPUT +error_input: + rfkill_global_led_trigger_unregister(); +#endif +error_led_trigger: + misc_deregister(&rfkill_miscdev); +error_misc: + class_unregister(&rfkill_class); +error_class: + return error; +} +subsys_initcall(rfkill_init); + +static void __exit rfkill_exit(void) +{ +#ifdef 
CONFIG_RFKILL_INPUT + rfkill_handler_exit(); +#endif + rfkill_global_led_trigger_unregister(); + misc_deregister(&rfkill_miscdev); + class_unregister(&rfkill_class); +} +module_exit(rfkill_exit); + +MODULE_ALIAS_MISCDEV(RFKILL_MINOR); +MODULE_ALIAS("devname:" RFKILL_NAME); diff --git a/net/rfkill/input.c b/net/rfkill/input.c new file mode 100644 index 000000000..598d0a61b --- /dev/null +++ b/net/rfkill/input.c @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Input layer to RF Kill interface connector + * + * Copyright (c) 2007 Dmitry Torokhov + * Copyright 2009 Johannes Berg + * + * If you ever run into a situation in which you have a SW_ type rfkill + * input device, then you can revive code that was removed in the patch + * "rfkill-input: remove unused code". + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "rfkill.h" + +enum rfkill_input_master_mode { + RFKILL_INPUT_MASTER_UNLOCK = 0, + RFKILL_INPUT_MASTER_RESTORE = 1, + RFKILL_INPUT_MASTER_UNBLOCKALL = 2, + NUM_RFKILL_INPUT_MASTER_MODES +}; + +/* Delay (in ms) between consecutive switch ops */ +#define RFKILL_OPS_DELAY 200 + +static enum rfkill_input_master_mode rfkill_master_switch_mode = + RFKILL_INPUT_MASTER_UNBLOCKALL; +module_param_named(master_switch_mode, rfkill_master_switch_mode, uint, 0); +MODULE_PARM_DESC(master_switch_mode, + "SW_RFKILL_ALL ON should: 0=do nothing (only unlock); 1=restore; 2=unblock all"); + +static DEFINE_SPINLOCK(rfkill_op_lock); +static bool rfkill_op_pending; +static unsigned long rfkill_sw_pending[BITS_TO_LONGS(NUM_RFKILL_TYPES)]; +static unsigned long rfkill_sw_state[BITS_TO_LONGS(NUM_RFKILL_TYPES)]; + +enum rfkill_sched_op { + RFKILL_GLOBAL_OP_EPO = 0, + RFKILL_GLOBAL_OP_RESTORE, + RFKILL_GLOBAL_OP_UNLOCK, + RFKILL_GLOBAL_OP_UNBLOCK, +}; + +static enum rfkill_sched_op rfkill_master_switch_op; +static enum rfkill_sched_op rfkill_op; + +static void __rfkill_handle_global_op(enum rfkill_sched_op op) +{ + unsigned int i; + + switch (op) { + case RFKILL_GLOBAL_OP_EPO: + rfkill_epo(); + break; + case RFKILL_GLOBAL_OP_RESTORE: + rfkill_restore_states(); + break; + case RFKILL_GLOBAL_OP_UNLOCK: + rfkill_remove_epo_lock(); + break; + case RFKILL_GLOBAL_OP_UNBLOCK: + rfkill_remove_epo_lock(); + for (i = 0; i < NUM_RFKILL_TYPES; i++) + rfkill_switch_all(i, false); + break; + default: + /* memory corruption or bug, fail safely */ + rfkill_epo(); + WARN(1, "Unknown requested operation %d! " + "rfkill Emergency Power Off activated\n", + op); + } +} + +static void __rfkill_handle_normal_op(const enum rfkill_type type, + const bool complement) +{ + bool blocked; + + blocked = rfkill_get_global_sw_state(type); + if (complement) + blocked = !blocked; + + rfkill_switch_all(type, blocked); +} + +static void rfkill_op_handler(struct work_struct *work) +{ + unsigned int i; + bool c; + + spin_lock_irq(&rfkill_op_lock); + do { + if (rfkill_op_pending) { + enum rfkill_sched_op op = rfkill_op; + rfkill_op_pending = false; + memset(rfkill_sw_pending, 0, + sizeof(rfkill_sw_pending)); + spin_unlock_irq(&rfkill_op_lock); + + __rfkill_handle_global_op(op); + + spin_lock_irq(&rfkill_op_lock); + + /* + * handle global ops first -- during unlocked period + * we might have gotten a new global op. 
+ */ + if (rfkill_op_pending) + continue; + } + + if (rfkill_is_epo_lock_active()) + continue; + + for (i = 0; i < NUM_RFKILL_TYPES; i++) { + if (__test_and_clear_bit(i, rfkill_sw_pending)) { + c = __test_and_clear_bit(i, rfkill_sw_state); + spin_unlock_irq(&rfkill_op_lock); + + __rfkill_handle_normal_op(i, c); + + spin_lock_irq(&rfkill_op_lock); + } + } + } while (rfkill_op_pending); + spin_unlock_irq(&rfkill_op_lock); +} + +static DECLARE_DELAYED_WORK(rfkill_op_work, rfkill_op_handler); +static unsigned long rfkill_last_scheduled; + +static unsigned long rfkill_ratelimit(const unsigned long last) +{ + const unsigned long delay = msecs_to_jiffies(RFKILL_OPS_DELAY); + return time_after(jiffies, last + delay) ? 0 : delay; +} + +static void rfkill_schedule_ratelimited(void) +{ + if (schedule_delayed_work(&rfkill_op_work, + rfkill_ratelimit(rfkill_last_scheduled))) + rfkill_last_scheduled = jiffies; +} + +static void rfkill_schedule_global_op(enum rfkill_sched_op op) +{ + unsigned long flags; + + spin_lock_irqsave(&rfkill_op_lock, flags); + rfkill_op = op; + rfkill_op_pending = true; + if (op == RFKILL_GLOBAL_OP_EPO && !rfkill_is_epo_lock_active()) { + /* bypass the limiter for EPO */ + mod_delayed_work(system_wq, &rfkill_op_work, 0); + rfkill_last_scheduled = jiffies; + } else + rfkill_schedule_ratelimited(); + spin_unlock_irqrestore(&rfkill_op_lock, flags); +} + +static void rfkill_schedule_toggle(enum rfkill_type type) +{ + unsigned long flags; + + if (rfkill_is_epo_lock_active()) + return; + + spin_lock_irqsave(&rfkill_op_lock, flags); + if (!rfkill_op_pending) { + __set_bit(type, rfkill_sw_pending); + __change_bit(type, rfkill_sw_state); + rfkill_schedule_ratelimited(); + } + spin_unlock_irqrestore(&rfkill_op_lock, flags); +} + +static void rfkill_schedule_evsw_rfkillall(int state) +{ + if (state) + rfkill_schedule_global_op(rfkill_master_switch_op); + else + rfkill_schedule_global_op(RFKILL_GLOBAL_OP_EPO); +} + +static void rfkill_event(struct input_handle *handle, unsigned int type, + unsigned int code, int data) +{ + if (type == EV_KEY && data == 1) { + switch (code) { + case KEY_WLAN: + rfkill_schedule_toggle(RFKILL_TYPE_WLAN); + break; + case KEY_BLUETOOTH: + rfkill_schedule_toggle(RFKILL_TYPE_BLUETOOTH); + break; + case KEY_UWB: + rfkill_schedule_toggle(RFKILL_TYPE_UWB); + break; + case KEY_WIMAX: + rfkill_schedule_toggle(RFKILL_TYPE_WIMAX); + break; + case KEY_RFKILL: + rfkill_schedule_toggle(RFKILL_TYPE_ALL); + break; + } + } else if (type == EV_SW && code == SW_RFKILL_ALL) + rfkill_schedule_evsw_rfkillall(data); +} + +static int rfkill_connect(struct input_handler *handler, struct input_dev *dev, + const struct input_device_id *id) +{ + struct input_handle *handle; + int error; + + handle = kzalloc(sizeof(struct input_handle), GFP_KERNEL); + if (!handle) + return -ENOMEM; + + handle->dev = dev; + handle->handler = handler; + handle->name = "rfkill"; + + /* causes rfkill_start() to be called */ + error = input_register_handle(handle); + if (error) + goto err_free_handle; + + error = input_open_device(handle); + if (error) + goto err_unregister_handle; + + return 0; + + err_unregister_handle: + input_unregister_handle(handle); + err_free_handle: + kfree(handle); + return error; +} + +static void rfkill_start(struct input_handle *handle) +{ + /* + * Take event_lock to guard against configuration changes, we + * should be able to deal with concurrency with rfkill_event() + * just fine (which event_lock will also avoid). 
+ */ + spin_lock_irq(&handle->dev->event_lock); + + if (test_bit(EV_SW, handle->dev->evbit) && + test_bit(SW_RFKILL_ALL, handle->dev->swbit)) + rfkill_schedule_evsw_rfkillall(test_bit(SW_RFKILL_ALL, + handle->dev->sw)); + + spin_unlock_irq(&handle->dev->event_lock); +} + +static void rfkill_disconnect(struct input_handle *handle) +{ + input_close_device(handle); + input_unregister_handle(handle); + kfree(handle); +} + +static const struct input_device_id rfkill_ids[] = { + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_KEYBIT, + .evbit = { BIT_MASK(EV_KEY) }, + .keybit = { [BIT_WORD(KEY_WLAN)] = BIT_MASK(KEY_WLAN) }, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_KEYBIT, + .evbit = { BIT_MASK(EV_KEY) }, + .keybit = { [BIT_WORD(KEY_BLUETOOTH)] = BIT_MASK(KEY_BLUETOOTH) }, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_KEYBIT, + .evbit = { BIT_MASK(EV_KEY) }, + .keybit = { [BIT_WORD(KEY_UWB)] = BIT_MASK(KEY_UWB) }, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_KEYBIT, + .evbit = { BIT_MASK(EV_KEY) }, + .keybit = { [BIT_WORD(KEY_WIMAX)] = BIT_MASK(KEY_WIMAX) }, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_KEYBIT, + .evbit = { BIT_MASK(EV_KEY) }, + .keybit = { [BIT_WORD(KEY_RFKILL)] = BIT_MASK(KEY_RFKILL) }, + }, + { + .flags = INPUT_DEVICE_ID_MATCH_EVBIT | INPUT_DEVICE_ID_MATCH_SWBIT, + .evbit = { BIT(EV_SW) }, + .swbit = { [BIT_WORD(SW_RFKILL_ALL)] = BIT_MASK(SW_RFKILL_ALL) }, + }, + { } +}; + +static struct input_handler rfkill_handler = { + .name = "rfkill", + .event = rfkill_event, + .connect = rfkill_connect, + .start = rfkill_start, + .disconnect = rfkill_disconnect, + .id_table = rfkill_ids, +}; + +int __init rfkill_handler_init(void) +{ + switch (rfkill_master_switch_mode) { + case RFKILL_INPUT_MASTER_UNBLOCKALL: + rfkill_master_switch_op = RFKILL_GLOBAL_OP_UNBLOCK; + break; + case RFKILL_INPUT_MASTER_RESTORE: + rfkill_master_switch_op = RFKILL_GLOBAL_OP_RESTORE; + break; + case RFKILL_INPUT_MASTER_UNLOCK: + rfkill_master_switch_op = RFKILL_GLOBAL_OP_UNLOCK; + break; + default: + return -EINVAL; + } + + /* Avoid delay at first schedule */ + rfkill_last_scheduled = + jiffies - msecs_to_jiffies(RFKILL_OPS_DELAY) - 1; + return input_register_handler(&rfkill_handler); +} + +void __exit rfkill_handler_exit(void) +{ + input_unregister_handler(&rfkill_handler); + cancel_delayed_work_sync(&rfkill_op_work); +} diff --git a/net/rfkill/rfkill-gpio.c b/net/rfkill/rfkill-gpio.c new file mode 100644 index 000000000..f74baefd8 --- /dev/null +++ b/net/rfkill/rfkill-gpio.c @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2011, NVIDIA Corporation. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct rfkill_gpio_data { + const char *name; + enum rfkill_type type; + struct gpio_desc *reset_gpio; + struct gpio_desc *shutdown_gpio; + + struct rfkill *rfkill_dev; + struct clk *clk; + + bool clk_enabled; +}; + +static int rfkill_gpio_set_power(void *data, bool blocked) +{ + struct rfkill_gpio_data *rfkill = data; + + if (!blocked && !IS_ERR(rfkill->clk) && !rfkill->clk_enabled) + clk_enable(rfkill->clk); + + gpiod_set_value_cansleep(rfkill->shutdown_gpio, !blocked); + gpiod_set_value_cansleep(rfkill->reset_gpio, !blocked); + + if (blocked && !IS_ERR(rfkill->clk) && rfkill->clk_enabled) + clk_disable(rfkill->clk); + + rfkill->clk_enabled = !blocked; + + return 0; +} + +static const struct rfkill_ops rfkill_gpio_ops = { + .set_block = rfkill_gpio_set_power, +}; + +static const struct acpi_gpio_params reset_gpios = { 0, 0, false }; +static const struct acpi_gpio_params shutdown_gpios = { 1, 0, false }; + +static const struct acpi_gpio_mapping acpi_rfkill_default_gpios[] = { + { "reset-gpios", &reset_gpios, 1 }, + { "shutdown-gpios", &shutdown_gpios, 1 }, + { }, +}; + +static int rfkill_gpio_acpi_probe(struct device *dev, + struct rfkill_gpio_data *rfkill) +{ + const struct acpi_device_id *id; + + id = acpi_match_device(dev->driver->acpi_match_table, dev); + if (!id) + return -ENODEV; + + rfkill->type = (unsigned)id->driver_data; + + return devm_acpi_dev_add_driver_gpios(dev, acpi_rfkill_default_gpios); +} + +static int rfkill_gpio_probe(struct platform_device *pdev) +{ + struct rfkill_gpio_data *rfkill; + struct gpio_desc *gpio; + const char *type_name; + int ret; + + rfkill = devm_kzalloc(&pdev->dev, sizeof(*rfkill), GFP_KERNEL); + if (!rfkill) + return -ENOMEM; + + device_property_read_string(&pdev->dev, "name", &rfkill->name); + device_property_read_string(&pdev->dev, "type", &type_name); + + if (!rfkill->name) + rfkill->name = dev_name(&pdev->dev); + + rfkill->type = rfkill_find_type(type_name); + + if (ACPI_HANDLE(&pdev->dev)) { + ret = rfkill_gpio_acpi_probe(&pdev->dev, rfkill); + if (ret) + return ret; + } + + rfkill->clk = devm_clk_get(&pdev->dev, NULL); + + gpio = devm_gpiod_get_optional(&pdev->dev, "reset", GPIOD_ASIS); + if (IS_ERR(gpio)) + return PTR_ERR(gpio); + + rfkill->reset_gpio = gpio; + + gpio = devm_gpiod_get_optional(&pdev->dev, "shutdown", GPIOD_ASIS); + if (IS_ERR(gpio)) + return PTR_ERR(gpio); + + rfkill->shutdown_gpio = gpio; + + /* Make sure at-least one GPIO is defined for this instance */ + if (!rfkill->reset_gpio && !rfkill->shutdown_gpio) { + dev_err(&pdev->dev, "invalid platform data\n"); + return -EINVAL; + } + + ret = gpiod_direction_output(rfkill->reset_gpio, true); + if (ret) + return ret; + + ret = gpiod_direction_output(rfkill->shutdown_gpio, true); + if (ret) + return ret; + + rfkill->rfkill_dev = rfkill_alloc(rfkill->name, &pdev->dev, + rfkill->type, &rfkill_gpio_ops, + rfkill); + if (!rfkill->rfkill_dev) + return -ENOMEM; + + ret = rfkill_register(rfkill->rfkill_dev); + if (ret < 0) + goto err_destroy; + + platform_set_drvdata(pdev, rfkill); + + dev_info(&pdev->dev, "%s device registered.\n", rfkill->name); + + return 0; + +err_destroy: + rfkill_destroy(rfkill->rfkill_dev); + + return ret; +} + +static int rfkill_gpio_remove(struct platform_device *pdev) +{ + struct rfkill_gpio_data *rfkill = platform_get_drvdata(pdev); + + rfkill_unregister(rfkill->rfkill_dev); + rfkill_destroy(rfkill->rfkill_dev); + + return 0; +} + +#ifdef CONFIG_ACPI 
+static const struct acpi_device_id rfkill_acpi_match[] = { + { "BCM4752", RFKILL_TYPE_GPS }, + { "LNV4752", RFKILL_TYPE_GPS }, + { }, +}; +MODULE_DEVICE_TABLE(acpi, rfkill_acpi_match); +#endif + +static struct platform_driver rfkill_gpio_driver = { + .probe = rfkill_gpio_probe, + .remove = rfkill_gpio_remove, + .driver = { + .name = "rfkill_gpio", + .acpi_match_table = ACPI_PTR(rfkill_acpi_match), + }, +}; + +module_platform_driver(rfkill_gpio_driver); + +MODULE_DESCRIPTION("gpio rfkill"); +MODULE_AUTHOR("NVIDIA"); +MODULE_LICENSE("GPL"); diff --git a/net/rfkill/rfkill.h b/net/rfkill/rfkill.h new file mode 100644 index 000000000..001c40caa --- /dev/null +++ b/net/rfkill/rfkill.h @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * Copyright (C) 2007 Ivo van Doorn + * Copyright 2009 Johannes Berg + */ + + +#ifndef __RFKILL_INPUT_H +#define __RFKILL_INPUT_H + +/* core code */ +void rfkill_switch_all(const enum rfkill_type type, bool blocked); +void rfkill_epo(void); +void rfkill_restore_states(void); +void rfkill_remove_epo_lock(void); +bool rfkill_is_epo_lock_active(void); +bool rfkill_get_global_sw_state(const enum rfkill_type type); + +/* input handler */ +int rfkill_handler_init(void); +void rfkill_handler_exit(void); + +#endif /* __RFKILL_INPUT_H */ diff --git a/net/rose/Makefile b/net/rose/Makefile new file mode 100644 index 000000000..3e6638f5b --- /dev/null +++ b/net/rose/Makefile @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Linux Rose (X.25 PLP) layer. +# + +obj-$(CONFIG_ROSE) += rose.o + +rose-y := af_rose.o rose_dev.o rose_in.o rose_link.o rose_loopback.o \ + rose_out.o rose_route.o rose_subr.o rose_timer.o +rose-$(CONFIG_SYSCTL) += sysctl_net_rose.o diff --git a/net/rose/af_rose.c b/net/rose/af_rose.c new file mode 100644 index 000000000..29b74a569 --- /dev/null +++ b/net/rose/af_rose.c @@ -0,0 +1,1674 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Alan Cox GW4PTS (alan@lxorguk.ukuu.org.uk) + * Copyright (C) Terry Dawson VK2KTJ (terry@animats.net) + * Copyright (C) Tomi Manninen OH2BNS (oh2bns@sral.fi) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int rose_ndevs = 10; + +int sysctl_rose_restart_request_timeout = ROSE_DEFAULT_T0; +int sysctl_rose_call_request_timeout = ROSE_DEFAULT_T1; +int sysctl_rose_reset_request_timeout = ROSE_DEFAULT_T2; +int sysctl_rose_clear_request_timeout = ROSE_DEFAULT_T3; +int sysctl_rose_no_activity_timeout = ROSE_DEFAULT_IDLE; +int sysctl_rose_ack_hold_back_timeout = ROSE_DEFAULT_HB; +int sysctl_rose_routing_control = ROSE_DEFAULT_ROUTING; +int sysctl_rose_link_fail_timeout = ROSE_DEFAULT_FAIL_TIMEOUT; +int sysctl_rose_maximum_vcs = ROSE_DEFAULT_MAXVC; +int sysctl_rose_window_size = ROSE_DEFAULT_WINDOW_SIZE; + +static HLIST_HEAD(rose_list); +static DEFINE_SPINLOCK(rose_list_lock); + +static const struct proto_ops rose_proto_ops; + +ax25_address rose_callsign; + +/* + * ROSE network devices are virtual network devices encapsulating ROSE + * frames into AX.25 which will be sent through an AX.25 device, so form a + * special "super class" of normal net devices; split their locks off 
into a + * separate class since they always nest. + */ +static struct lock_class_key rose_netdev_xmit_lock_key; +static struct lock_class_key rose_netdev_addr_lock_key; + +static void rose_set_lockdep_one(struct net_device *dev, + struct netdev_queue *txq, + void *_unused) +{ + lockdep_set_class(&txq->_xmit_lock, &rose_netdev_xmit_lock_key); +} + +static void rose_set_lockdep_key(struct net_device *dev) +{ + lockdep_set_class(&dev->addr_list_lock, &rose_netdev_addr_lock_key); + netdev_for_each_tx_queue(dev, rose_set_lockdep_one, NULL); +} + +/* + * Convert a ROSE address into text. + */ +char *rose2asc(char *buf, const rose_address *addr) +{ + if (addr->rose_addr[0] == 0x00 && addr->rose_addr[1] == 0x00 && + addr->rose_addr[2] == 0x00 && addr->rose_addr[3] == 0x00 && + addr->rose_addr[4] == 0x00) { + strcpy(buf, "*"); + } else { + sprintf(buf, "%02X%02X%02X%02X%02X", addr->rose_addr[0] & 0xFF, + addr->rose_addr[1] & 0xFF, + addr->rose_addr[2] & 0xFF, + addr->rose_addr[3] & 0xFF, + addr->rose_addr[4] & 0xFF); + } + + return buf; +} + +/* + * Compare two ROSE addresses, 0 == equal. + */ +int rosecmp(const rose_address *addr1, const rose_address *addr2) +{ + int i; + + for (i = 0; i < 5; i++) + if (addr1->rose_addr[i] != addr2->rose_addr[i]) + return 1; + + return 0; +} + +/* + * Compare two ROSE addresses for only mask digits, 0 == equal. + */ +int rosecmpm(const rose_address *addr1, const rose_address *addr2, + unsigned short mask) +{ + unsigned int i, j; + + if (mask > 10) + return 1; + + for (i = 0; i < mask; i++) { + j = i / 2; + + if ((i % 2) != 0) { + if ((addr1->rose_addr[j] & 0x0F) != (addr2->rose_addr[j] & 0x0F)) + return 1; + } else { + if ((addr1->rose_addr[j] & 0xF0) != (addr2->rose_addr[j] & 0xF0)) + return 1; + } + } + + return 0; +} + +/* + * Socket removal during an interrupt is now safe. + */ +static void rose_remove_socket(struct sock *sk) +{ + spin_lock_bh(&rose_list_lock); + sk_del_node_init(sk); + spin_unlock_bh(&rose_list_lock); +} + +/* + * Kill all bound sockets on a broken link layer connection to a + * particular neighbour. + */ +void rose_kill_by_neigh(struct rose_neigh *neigh) +{ + struct sock *s; + + spin_lock_bh(&rose_list_lock); + sk_for_each(s, &rose_list) { + struct rose_sock *rose = rose_sk(s); + + if (rose->neighbour == neigh) { + rose_disconnect(s, ENETUNREACH, ROSE_OUT_OF_ORDER, 0); + rose->neighbour->use--; + rose->neighbour = NULL; + } + } + spin_unlock_bh(&rose_list_lock); +} + +/* + * Kill all bound sockets on a dropped device. + */ +static void rose_kill_by_device(struct net_device *dev) +{ + struct sock *sk, *array[16]; + struct rose_sock *rose; + bool rescan; + int i, cnt; + +start: + rescan = false; + cnt = 0; + spin_lock_bh(&rose_list_lock); + sk_for_each(sk, &rose_list) { + rose = rose_sk(sk); + if (rose->device == dev) { + if (cnt == ARRAY_SIZE(array)) { + rescan = true; + break; + } + sock_hold(sk); + array[cnt++] = sk; + } + } + spin_unlock_bh(&rose_list_lock); + + for (i = 0; i < cnt; i++) { + sk = array[cnt]; + rose = rose_sk(sk); + lock_sock(sk); + spin_lock_bh(&rose_list_lock); + if (rose->device == dev) { + rose_disconnect(sk, ENETUNREACH, ROSE_OUT_OF_ORDER, 0); + if (rose->neighbour) + rose->neighbour->use--; + netdev_put(rose->device, &rose->dev_tracker); + rose->device = NULL; + } + spin_unlock_bh(&rose_list_lock); + release_sock(sk); + sock_put(sk); + cond_resched(); + } + if (rescan) + goto start; +} + +/* + * Handle device status changes. 
+ */ +static int rose_device_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (event != NETDEV_DOWN) + return NOTIFY_DONE; + + switch (dev->type) { + case ARPHRD_ROSE: + rose_kill_by_device(dev); + break; + case ARPHRD_AX25: + rose_link_device_down(dev); + rose_rt_device_down(dev); + break; + } + + return NOTIFY_DONE; +} + +/* + * Add a socket to the bound sockets list. + */ +static void rose_insert_socket(struct sock *sk) +{ + + spin_lock_bh(&rose_list_lock); + sk_add_node(sk, &rose_list); + spin_unlock_bh(&rose_list_lock); +} + +/* + * Find a socket that wants to accept the Call Request we just + * received. + */ +static struct sock *rose_find_listener(rose_address *addr, ax25_address *call) +{ + struct sock *s; + + spin_lock_bh(&rose_list_lock); + sk_for_each(s, &rose_list) { + struct rose_sock *rose = rose_sk(s); + + if (!rosecmp(&rose->source_addr, addr) && + !ax25cmp(&rose->source_call, call) && + !rose->source_ndigis && s->sk_state == TCP_LISTEN) + goto found; + } + + sk_for_each(s, &rose_list) { + struct rose_sock *rose = rose_sk(s); + + if (!rosecmp(&rose->source_addr, addr) && + !ax25cmp(&rose->source_call, &null_ax25_address) && + s->sk_state == TCP_LISTEN) + goto found; + } + s = NULL; +found: + spin_unlock_bh(&rose_list_lock); + return s; +} + +/* + * Find a connected ROSE socket given my LCI and device. + */ +struct sock *rose_find_socket(unsigned int lci, struct rose_neigh *neigh) +{ + struct sock *s; + + spin_lock_bh(&rose_list_lock); + sk_for_each(s, &rose_list) { + struct rose_sock *rose = rose_sk(s); + + if (rose->lci == lci && rose->neighbour == neigh) + goto found; + } + s = NULL; +found: + spin_unlock_bh(&rose_list_lock); + return s; +} + +/* + * Find a unique LCI for a given device. + */ +unsigned int rose_new_lci(struct rose_neigh *neigh) +{ + int lci; + + if (neigh->dce_mode) { + for (lci = 1; lci <= sysctl_rose_maximum_vcs; lci++) + if (rose_find_socket(lci, neigh) == NULL && rose_route_free_lci(lci, neigh) == NULL) + return lci; + } else { + for (lci = sysctl_rose_maximum_vcs; lci > 0; lci--) + if (rose_find_socket(lci, neigh) == NULL && rose_route_free_lci(lci, neigh) == NULL) + return lci; + } + + return 0; +} + +/* + * Deferred destroy. + */ +void rose_destroy_socket(struct sock *); + +/* + * Handler for deferred kills. + */ +static void rose_destroy_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + rose_destroy_socket(sk); +} + +/* + * This is called from user mode and the timers. Thus it protects itself + * against interrupt users but doesn't worry about being called during + * work. Once it is removed from the queue no interrupt or bottom half + * will touch it and we are (fairly 8-) ) safe. 
+ */ +void rose_destroy_socket(struct sock *sk) +{ + struct sk_buff *skb; + + rose_remove_socket(sk); + rose_stop_heartbeat(sk); + rose_stop_idletimer(sk); + rose_stop_timer(sk); + + rose_clear_queues(sk); /* Flush the queues */ + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + if (skb->sk != sk) { /* A pending connection */ + /* Queue the unaccepted socket for death */ + sock_set_flag(skb->sk, SOCK_DEAD); + rose_start_heartbeat(skb->sk); + rose_sk(skb->sk)->state = ROSE_STATE_0; + } + + kfree_skb(skb); + } + + if (sk_has_allocations(sk)) { + /* Defer: outstanding buffers */ + timer_setup(&sk->sk_timer, rose_destroy_timer, 0); + sk->sk_timer.expires = jiffies + 10 * HZ; + add_timer(&sk->sk_timer); + } else + sock_put(sk); +} + +/* + * Handling for system calls applied via the various interfaces to a + * ROSE socket object. + */ + +static int rose_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + int opt; + + if (level != SOL_ROSE) + return -ENOPROTOOPT; + + if (optlen < sizeof(int)) + return -EINVAL; + + if (copy_from_sockptr(&opt, optval, sizeof(int))) + return -EFAULT; + + switch (optname) { + case ROSE_DEFER: + rose->defer = opt ? 1 : 0; + return 0; + + case ROSE_T1: + if (opt < 1) + return -EINVAL; + rose->t1 = opt * HZ; + return 0; + + case ROSE_T2: + if (opt < 1) + return -EINVAL; + rose->t2 = opt * HZ; + return 0; + + case ROSE_T3: + if (opt < 1) + return -EINVAL; + rose->t3 = opt * HZ; + return 0; + + case ROSE_HOLDBACK: + if (opt < 1) + return -EINVAL; + rose->hb = opt * HZ; + return 0; + + case ROSE_IDLE: + if (opt < 0) + return -EINVAL; + rose->idle = opt * 60 * HZ; + return 0; + + case ROSE_QBITINCL: + rose->qbitincl = opt ? 1 : 0; + return 0; + + default: + return -ENOPROTOOPT; + } +} + +static int rose_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + int val = 0; + int len; + + if (level != SOL_ROSE) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < 0) + return -EINVAL; + + switch (optname) { + case ROSE_DEFER: + val = rose->defer; + break; + + case ROSE_T1: + val = rose->t1 / HZ; + break; + + case ROSE_T2: + val = rose->t2 / HZ; + break; + + case ROSE_T3: + val = rose->t3 / HZ; + break; + + case ROSE_HOLDBACK: + val = rose->hb / HZ; + break; + + case ROSE_IDLE: + val = rose->idle / (60 * HZ); + break; + + case ROSE_QBITINCL: + val = rose->qbitincl; + break; + + default: + return -ENOPROTOOPT; + } + + len = min_t(unsigned int, len, sizeof(int)); + + if (put_user(len, optlen)) + return -EFAULT; + + return copy_to_user(optval, &val, len) ? 
-EFAULT : 0; +} + +static int rose_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + + lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + release_sock(sk); + return -EINVAL; + } + + if (sk->sk_state != TCP_LISTEN) { + struct rose_sock *rose = rose_sk(sk); + + rose->dest_ndigis = 0; + memset(&rose->dest_addr, 0, ROSE_ADDR_LEN); + memset(&rose->dest_call, 0, AX25_ADDR_LEN); + memset(rose->dest_digis, 0, AX25_ADDR_LEN * ROSE_MAX_DIGIS); + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + release_sock(sk); + return 0; + } + release_sock(sk); + + return -EOPNOTSUPP; +} + +static struct proto rose_proto = { + .name = "ROSE", + .owner = THIS_MODULE, + .obj_size = sizeof(struct rose_sock), +}; + +static int rose_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct rose_sock *rose; + + if (!net_eq(net, &init_net)) + return -EAFNOSUPPORT; + + if (sock->type != SOCK_SEQPACKET || protocol != 0) + return -ESOCKTNOSUPPORT; + + sk = sk_alloc(net, PF_ROSE, GFP_ATOMIC, &rose_proto, kern); + if (sk == NULL) + return -ENOMEM; + + rose = rose_sk(sk); + + sock_init_data(sock, sk); + + skb_queue_head_init(&rose->ack_queue); +#ifdef M_BIT + skb_queue_head_init(&rose->frag_queue); + rose->fraglen = 0; +#endif + + sock->ops = &rose_proto_ops; + sk->sk_protocol = protocol; + + timer_setup(&rose->timer, NULL, 0); + timer_setup(&rose->idletimer, NULL, 0); + + rose->t1 = msecs_to_jiffies(sysctl_rose_call_request_timeout); + rose->t2 = msecs_to_jiffies(sysctl_rose_reset_request_timeout); + rose->t3 = msecs_to_jiffies(sysctl_rose_clear_request_timeout); + rose->hb = msecs_to_jiffies(sysctl_rose_ack_hold_back_timeout); + rose->idle = msecs_to_jiffies(sysctl_rose_no_activity_timeout); + + rose->state = ROSE_STATE_0; + + return 0; +} + +static struct sock *rose_make_new(struct sock *osk) +{ + struct sock *sk; + struct rose_sock *rose, *orose; + + if (osk->sk_type != SOCK_SEQPACKET) + return NULL; + + sk = sk_alloc(sock_net(osk), PF_ROSE, GFP_ATOMIC, &rose_proto, 0); + if (sk == NULL) + return NULL; + + rose = rose_sk(sk); + + sock_init_data(NULL, sk); + + skb_queue_head_init(&rose->ack_queue); +#ifdef M_BIT + skb_queue_head_init(&rose->frag_queue); + rose->fraglen = 0; +#endif + + sk->sk_type = osk->sk_type; + sk->sk_priority = osk->sk_priority; + sk->sk_protocol = osk->sk_protocol; + sk->sk_rcvbuf = osk->sk_rcvbuf; + sk->sk_sndbuf = osk->sk_sndbuf; + sk->sk_state = TCP_ESTABLISHED; + sock_copy_flags(sk, osk); + + timer_setup(&rose->timer, NULL, 0); + timer_setup(&rose->idletimer, NULL, 0); + + orose = rose_sk(osk); + rose->t1 = orose->t1; + rose->t2 = orose->t2; + rose->t3 = orose->t3; + rose->hb = orose->hb; + rose->idle = orose->idle; + rose->defer = orose->defer; + rose->device = orose->device; + if (rose->device) + netdev_hold(rose->device, &rose->dev_tracker, GFP_ATOMIC); + rose->qbitincl = orose->qbitincl; + + return sk; +} + +static int rose_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose; + + if (sk == NULL) return 0; + + sock_hold(sk); + sock_orphan(sk); + lock_sock(sk); + rose = rose_sk(sk); + + switch (rose->state) { + case ROSE_STATE_0: + release_sock(sk); + rose_disconnect(sk, 0, -1, -1); + lock_sock(sk); + rose_destroy_socket(sk); + break; + + case ROSE_STATE_2: + rose->neighbour->use--; + release_sock(sk); + rose_disconnect(sk, 0, -1, -1); + lock_sock(sk); + rose_destroy_socket(sk); + break; + + case ROSE_STATE_1: + case ROSE_STATE_3: + case ROSE_STATE_4: + case ROSE_STATE_5: 
+ rose_clear_queues(sk); + rose_stop_idletimer(sk); + rose_write_internal(sk, ROSE_CLEAR_REQUEST); + rose_start_t3timer(sk); + rose->state = ROSE_STATE_2; + sk->sk_state = TCP_CLOSE; + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + sock_set_flag(sk, SOCK_DESTROY); + break; + + default: + break; + } + + spin_lock_bh(&rose_list_lock); + netdev_put(rose->device, &rose->dev_tracker); + rose->device = NULL; + spin_unlock_bh(&rose_list_lock); + sock->sk = NULL; + release_sock(sk); + sock_put(sk); + + return 0; +} + +static int rose_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + struct sockaddr_rose *addr = (struct sockaddr_rose *)uaddr; + struct net_device *dev; + ax25_address *source; + ax25_uid_assoc *user; + int n; + + if (!sock_flag(sk, SOCK_ZAPPED)) + return -EINVAL; + + if (addr_len != sizeof(struct sockaddr_rose) && addr_len != sizeof(struct full_sockaddr_rose)) + return -EINVAL; + + if (addr->srose_family != AF_ROSE) + return -EINVAL; + + if (addr_len == sizeof(struct sockaddr_rose) && addr->srose_ndigis > 1) + return -EINVAL; + + if ((unsigned int) addr->srose_ndigis > ROSE_MAX_DIGIS) + return -EINVAL; + + if ((dev = rose_dev_get(&addr->srose_addr)) == NULL) + return -EADDRNOTAVAIL; + + source = &addr->srose_call; + + user = ax25_findbyuid(current_euid()); + if (user) { + rose->source_call = user->call; + ax25_uid_put(user); + } else { + if (ax25_uid_policy && !capable(CAP_NET_BIND_SERVICE)) { + dev_put(dev); + return -EACCES; + } + rose->source_call = *source; + } + + rose->source_addr = addr->srose_addr; + rose->device = dev; + netdev_tracker_alloc(rose->device, &rose->dev_tracker, GFP_KERNEL); + rose->source_ndigis = addr->srose_ndigis; + + if (addr_len == sizeof(struct full_sockaddr_rose)) { + struct full_sockaddr_rose *full_addr = (struct full_sockaddr_rose *)uaddr; + for (n = 0 ; n < addr->srose_ndigis ; n++) + rose->source_digis[n] = full_addr->srose_digis[n]; + } else { + if (rose->source_ndigis == 1) { + rose->source_digis[0] = addr->srose_digi; + } + } + + rose_insert_socket(sk); + + sock_reset_flag(sk, SOCK_ZAPPED); + + return 0; +} + +static int rose_connect(struct socket *sock, struct sockaddr *uaddr, int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + struct sockaddr_rose *addr = (struct sockaddr_rose *)uaddr; + unsigned char cause, diagnostic; + ax25_uid_assoc *user; + int n, err = 0; + + if (addr_len != sizeof(struct sockaddr_rose) && addr_len != sizeof(struct full_sockaddr_rose)) + return -EINVAL; + + if (addr->srose_family != AF_ROSE) + return -EINVAL; + + if (addr_len == sizeof(struct sockaddr_rose) && addr->srose_ndigis > 1) + return -EINVAL; + + if ((unsigned int) addr->srose_ndigis > ROSE_MAX_DIGIS) + return -EINVAL; + + /* Source + Destination digis should not exceed ROSE_MAX_DIGIS */ + if ((rose->source_ndigis + addr->srose_ndigis) > ROSE_MAX_DIGIS) + return -EINVAL; + + lock_sock(sk); + + if (sk->sk_state == TCP_ESTABLISHED && sock->state == SS_CONNECTING) { + /* Connect completed during a ERESTARTSYS event */ + sock->state = SS_CONNECTED; + goto out_release; + } + + if (sk->sk_state == TCP_CLOSE && sock->state == SS_CONNECTING) { + sock->state = SS_UNCONNECTED; + err = -ECONNREFUSED; + goto out_release; + } + + if (sk->sk_state == TCP_ESTABLISHED) { + /* No reconnect on a seqpacket socket */ + err = -EISCONN; + goto out_release; + } + + sk->sk_state = TCP_CLOSE; + sock->state 
= SS_UNCONNECTED; + + rose->neighbour = rose_get_neigh(&addr->srose_addr, &cause, + &diagnostic, 0); + if (!rose->neighbour) { + err = -ENETUNREACH; + goto out_release; + } + + rose->lci = rose_new_lci(rose->neighbour); + if (!rose->lci) { + err = -ENETUNREACH; + goto out_release; + } + + if (sock_flag(sk, SOCK_ZAPPED)) { /* Must bind first - autobinding in this may or may not work */ + struct net_device *dev; + + sock_reset_flag(sk, SOCK_ZAPPED); + + dev = rose_dev_first(); + if (!dev) { + err = -ENETUNREACH; + goto out_release; + } + + user = ax25_findbyuid(current_euid()); + if (!user) { + err = -EINVAL; + dev_put(dev); + goto out_release; + } + + memcpy(&rose->source_addr, dev->dev_addr, ROSE_ADDR_LEN); + rose->source_call = user->call; + rose->device = dev; + netdev_tracker_alloc(rose->device, &rose->dev_tracker, + GFP_KERNEL); + ax25_uid_put(user); + + rose_insert_socket(sk); /* Finish the bind */ + } + rose->dest_addr = addr->srose_addr; + rose->dest_call = addr->srose_call; + rose->rand = ((long)rose & 0xFFFF) + rose->lci; + rose->dest_ndigis = addr->srose_ndigis; + + if (addr_len == sizeof(struct full_sockaddr_rose)) { + struct full_sockaddr_rose *full_addr = (struct full_sockaddr_rose *)uaddr; + for (n = 0 ; n < addr->srose_ndigis ; n++) + rose->dest_digis[n] = full_addr->srose_digis[n]; + } else { + if (rose->dest_ndigis == 1) { + rose->dest_digis[0] = addr->srose_digi; + } + } + + /* Move to connecting socket, start sending Connect Requests */ + sock->state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; + + rose->state = ROSE_STATE_1; + + rose->neighbour->use++; + + rose_write_internal(sk, ROSE_CALL_REQUEST); + rose_start_heartbeat(sk); + rose_start_t1timer(sk); + + /* Now the loop */ + if (sk->sk_state != TCP_ESTABLISHED && (flags & O_NONBLOCK)) { + err = -EINPROGRESS; + goto out_release; + } + + /* + * A Connect Ack with Choke or timeout or failed routing will go to + * closed. 
+ */ + if (sk->sk_state == TCP_SYN_SENT) { + DEFINE_WAIT(wait); + + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + if (sk->sk_state != TCP_SYN_SENT) + break; + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + + if (err) + goto out_release; + } + + if (sk->sk_state != TCP_ESTABLISHED) { + sock->state = SS_UNCONNECTED; + err = sock_error(sk); /* Always set at this point */ + goto out_release; + } + + sock->state = SS_CONNECTED; + +out_release: + release_sock(sk); + + return err; +} + +static int rose_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sk_buff *skb; + struct sock *newsk; + DEFINE_WAIT(wait); + struct sock *sk; + int err = 0; + + if ((sk = sock->sk) == NULL) + return -EINVAL; + + lock_sock(sk); + if (sk->sk_type != SOCK_SEQPACKET) { + err = -EOPNOTSUPP; + goto out_release; + } + + if (sk->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto out_release; + } + + /* + * The write queue this time is holding sockets ready to use + * hooked into the SABM we saved + */ + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + skb = skb_dequeue(&sk->sk_receive_queue); + if (skb) + break; + + if (flags & O_NONBLOCK) { + err = -EWOULDBLOCK; + break; + } + if (!signal_pending(current)) { + release_sock(sk); + schedule(); + lock_sock(sk); + continue; + } + err = -ERESTARTSYS; + break; + } + finish_wait(sk_sleep(sk), &wait); + if (err) + goto out_release; + + newsk = skb->sk; + sock_graft(newsk, newsock); + + /* Now attach up the new socket */ + skb->sk = NULL; + kfree_skb(skb); + sk_acceptq_removed(sk); + +out_release: + release_sock(sk); + + return err; +} + +static int rose_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct full_sockaddr_rose *srose = (struct full_sockaddr_rose *)uaddr; + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + int n; + + memset(srose, 0, sizeof(*srose)); + if (peer != 0) { + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + srose->srose_family = AF_ROSE; + srose->srose_addr = rose->dest_addr; + srose->srose_call = rose->dest_call; + srose->srose_ndigis = rose->dest_ndigis; + for (n = 0; n < rose->dest_ndigis; n++) + srose->srose_digis[n] = rose->dest_digis[n]; + } else { + srose->srose_family = AF_ROSE; + srose->srose_addr = rose->source_addr; + srose->srose_call = rose->source_call; + srose->srose_ndigis = rose->source_ndigis; + for (n = 0; n < rose->source_ndigis; n++) + srose->srose_digis[n] = rose->source_digis[n]; + } + + return sizeof(struct full_sockaddr_rose); +} + +int rose_rx_call_request(struct sk_buff *skb, struct net_device *dev, struct rose_neigh *neigh, unsigned int lci) +{ + struct sock *sk; + struct sock *make; + struct rose_sock *make_rose; + struct rose_facilities_struct facilities; + int n; + + skb->sk = NULL; /* Initially we don't know who it's for */ + + /* + * skb->data points to the rose frame start + */ + memset(&facilities, 0x00, sizeof(struct rose_facilities_struct)); + + if (!rose_parse_facilities(skb->data + ROSE_CALL_REQ_FACILITIES_OFF, + skb->len - ROSE_CALL_REQ_FACILITIES_OFF, + &facilities)) { + rose_transmit_clear_request(neigh, lci, ROSE_INVALID_FACILITY, 76); + return 0; + } + + sk = rose_find_listener(&facilities.source_addr, &facilities.source_call); + + /* + * We can't accept the Call Request. 
+ */ + if (sk == NULL || sk_acceptq_is_full(sk) || + (make = rose_make_new(sk)) == NULL) { + rose_transmit_clear_request(neigh, lci, ROSE_NETWORK_CONGESTION, 120); + return 0; + } + + skb->sk = make; + make->sk_state = TCP_ESTABLISHED; + make_rose = rose_sk(make); + + make_rose->lci = lci; + make_rose->dest_addr = facilities.dest_addr; + make_rose->dest_call = facilities.dest_call; + make_rose->dest_ndigis = facilities.dest_ndigis; + for (n = 0 ; n < facilities.dest_ndigis ; n++) + make_rose->dest_digis[n] = facilities.dest_digis[n]; + make_rose->source_addr = facilities.source_addr; + make_rose->source_call = facilities.source_call; + make_rose->source_ndigis = facilities.source_ndigis; + for (n = 0 ; n < facilities.source_ndigis ; n++) + make_rose->source_digis[n] = facilities.source_digis[n]; + make_rose->neighbour = neigh; + make_rose->device = dev; + /* Caller got a reference for us. */ + netdev_tracker_alloc(make_rose->device, &make_rose->dev_tracker, + GFP_ATOMIC); + make_rose->facilities = facilities; + + make_rose->neighbour->use++; + + if (rose_sk(sk)->defer) { + make_rose->state = ROSE_STATE_5; + } else { + rose_write_internal(make, ROSE_CALL_ACCEPTED); + make_rose->state = ROSE_STATE_3; + rose_start_idletimer(make); + } + + make_rose->condition = 0x00; + make_rose->vs = 0; + make_rose->va = 0; + make_rose->vr = 0; + make_rose->vl = 0; + sk_acceptq_added(sk); + + rose_insert_socket(make); + + skb_queue_head(&sk->sk_receive_queue, skb); + + rose_start_heartbeat(make); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + + return 1; +} + +static int rose_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_rose *, usrose, msg->msg_name); + int err; + struct full_sockaddr_rose srose; + struct sk_buff *skb; + unsigned char *asmptr; + int n, size, qbit = 0; + + if (msg->msg_flags & ~(MSG_DONTWAIT|MSG_EOR|MSG_CMSG_COMPAT)) + return -EINVAL; + + if (sock_flag(sk, SOCK_ZAPPED)) + return -EADDRNOTAVAIL; + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + send_sig(SIGPIPE, current, 0); + return -EPIPE; + } + + if (rose->neighbour == NULL || rose->device == NULL) + return -ENETUNREACH; + + if (usrose != NULL) { + if (msg->msg_namelen != sizeof(struct sockaddr_rose) && msg->msg_namelen != sizeof(struct full_sockaddr_rose)) + return -EINVAL; + memset(&srose, 0, sizeof(struct full_sockaddr_rose)); + memcpy(&srose, usrose, msg->msg_namelen); + if (rosecmp(&rose->dest_addr, &srose.srose_addr) != 0 || + ax25cmp(&rose->dest_call, &srose.srose_call) != 0) + return -EISCONN; + if (srose.srose_ndigis != rose->dest_ndigis) + return -EISCONN; + if (srose.srose_ndigis == rose->dest_ndigis) { + for (n = 0 ; n < srose.srose_ndigis ; n++) + if (ax25cmp(&rose->dest_digis[n], + &srose.srose_digis[n])) + return -EISCONN; + } + if (srose.srose_family != AF_ROSE) + return -EINVAL; + } else { + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + srose.srose_family = AF_ROSE; + srose.srose_addr = rose->dest_addr; + srose.srose_call = rose->dest_call; + srose.srose_ndigis = rose->dest_ndigis; + for (n = 0 ; n < rose->dest_ndigis ; n++) + srose.srose_digis[n] = rose->dest_digis[n]; + } + + /* Build a packet */ + /* Sanity check the packet size */ + if (len > 65535) + return -EMSGSIZE; + + size = len + AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN; + + if ((skb = sock_alloc_send_skb(sk, size, msg->msg_flags & MSG_DONTWAIT, &err)) == NULL) + return err; + + skb_reserve(skb, 
AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN); + + /* + * Put the data on the end + */ + + skb_reset_transport_header(skb); + skb_put(skb, len); + + err = memcpy_from_msg(skb_transport_header(skb), msg, len); + if (err) { + kfree_skb(skb); + return err; + } + + /* + * If the Q BIT Include socket option is in force, the first + * byte of the user data is the logical value of the Q Bit. + */ + if (rose->qbitincl) { + qbit = skb->data[0]; + skb_pull(skb, 1); + } + + /* + * Push down the ROSE header + */ + asmptr = skb_push(skb, ROSE_MIN_LEN); + + /* Build a ROSE Network header */ + asmptr[0] = ((rose->lci >> 8) & 0x0F) | ROSE_GFI; + asmptr[1] = (rose->lci >> 0) & 0xFF; + asmptr[2] = ROSE_DATA; + + if (qbit) + asmptr[0] |= ROSE_Q_BIT; + + if (sk->sk_state != TCP_ESTABLISHED) { + kfree_skb(skb); + return -ENOTCONN; + } + +#ifdef M_BIT +#define ROSE_PACLEN (256-ROSE_MIN_LEN) + if (skb->len - ROSE_MIN_LEN > ROSE_PACLEN) { + unsigned char header[ROSE_MIN_LEN]; + struct sk_buff *skbn; + int frontlen; + int lg; + + /* Save a copy of the Header */ + skb_copy_from_linear_data(skb, header, ROSE_MIN_LEN); + skb_pull(skb, ROSE_MIN_LEN); + + frontlen = skb_headroom(skb); + + while (skb->len > 0) { + if ((skbn = sock_alloc_send_skb(sk, frontlen + ROSE_PACLEN, 0, &err)) == NULL) { + kfree_skb(skb); + return err; + } + + skbn->sk = sk; + skbn->free = 1; + skbn->arp = 1; + + skb_reserve(skbn, frontlen); + + lg = (ROSE_PACLEN > skb->len) ? skb->len : ROSE_PACLEN; + + /* Copy the user data */ + skb_copy_from_linear_data(skb, skb_put(skbn, lg), lg); + skb_pull(skb, lg); + + /* Duplicate the Header */ + skb_push(skbn, ROSE_MIN_LEN); + skb_copy_to_linear_data(skbn, header, ROSE_MIN_LEN); + + if (skb->len > 0) + skbn->data[2] |= M_BIT; + + skb_queue_tail(&sk->sk_write_queue, skbn); /* Throw it on the queue */ + } + + skb->free = 1; + kfree_skb(skb); + } else { + skb_queue_tail(&sk->sk_write_queue, skb); /* Throw it on the queue */ + } +#else + skb_queue_tail(&sk->sk_write_queue, skb); /* Shove it onto the queue */ +#endif + + rose_kick(sk); + + return len; +} + + +static int rose_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + size_t copied; + unsigned char *asmptr; + struct sk_buff *skb; + int n, er, qbit; + + /* + * This works for seqpacket too. The receiver has ordered the queue for + * us! 
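[Editorial sketch] rose_sendmsg() above pushes a fixed three-byte network header (ROSE_MIN_LEN) in front of the user data: the GFI plus the top nibble of the LCI, the low LCI byte, and the frame type, with the Q bit OR-ed into the first octet when the qbitincl option is set; rose_recvmsg() later strips it again. The standalone sketch below mirrors that packing and the LCI extraction done in rose_loopback_timer(); the numeric values of ROSE_GFI, ROSE_Q_BIT and ROSE_DATA are assumptions, since include/net/rose.h is not part of this hunk.

#include <stdio.h>

#define ROSE_MIN_LEN 3
#define ROSE_GFI     0x10   /* assumed value: modulo-8 general format id */
#define ROSE_Q_BIT   0x80   /* assumed value */
#define ROSE_DATA    0x00   /* assumed value of the data frame type */

static void rose_pack_header(unsigned char *buf, unsigned int lci,
                             unsigned char frametype, int qbit)
{
        buf[0] = ((lci >> 8) & 0x0F) | ROSE_GFI;   /* GFI + LCI high nibble */
        buf[1] = (lci >> 0) & 0xFF;                /* LCI low byte          */
        buf[2] = frametype;                        /* e.g. ROSE_DATA        */
        if (qbit)
                buf[0] |= ROSE_Q_BIT;              /* as in rose_sendmsg()  */
}

static unsigned int rose_unpack_lci(const unsigned char *buf)
{
        /* Same arithmetic as rose_loopback_timer(): a 12-bit LCI. */
        return ((buf[0] << 8) & 0xF00) + ((buf[1] << 0) & 0x0FF);
}

int main(void)
{
        unsigned char hdr[ROSE_MIN_LEN];

        rose_pack_header(hdr, 0x123, ROSE_DATA, 1);
        printf("%02X %02X %02X -> lci %03X\n",
               hdr[0], hdr[1], hdr[2], rose_unpack_lci(hdr));
        return 0;
}

Running it prints 91 23 00 -> lci 123, i.e. the Q bit and GFI share the first octet with the top four LCI bits, exactly the layout the send and receive paths above agree on.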
We do one quick check first though + */ + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + /* Now we can treat all alike */ + skb = skb_recv_datagram(sk, flags, &er); + if (!skb) + return er; + + qbit = (skb->data[0] & ROSE_Q_BIT) == ROSE_Q_BIT; + + skb_pull(skb, ROSE_MIN_LEN); + + if (rose->qbitincl) { + asmptr = skb_push(skb, 1); + *asmptr = qbit; + } + + skb_reset_transport_header(skb); + copied = skb->len; + + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + skb_copy_datagram_msg(skb, 0, msg, copied); + + if (msg->msg_name) { + struct sockaddr_rose *srose; + DECLARE_SOCKADDR(struct full_sockaddr_rose *, full_srose, + msg->msg_name); + + memset(msg->msg_name, 0, sizeof(struct full_sockaddr_rose)); + srose = msg->msg_name; + srose->srose_family = AF_ROSE; + srose->srose_addr = rose->dest_addr; + srose->srose_call = rose->dest_call; + srose->srose_ndigis = rose->dest_ndigis; + for (n = 0 ; n < rose->dest_ndigis ; n++) + full_srose->srose_digis[n] = rose->dest_digis[n]; + msg->msg_namelen = sizeof(struct full_sockaddr_rose); + } + + skb_free_datagram(sk, skb); + + return copied; +} + + +static int rose_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + struct rose_sock *rose = rose_sk(sk); + void __user *argp = (void __user *)arg; + + switch (cmd) { + case TIOCOUTQ: { + long amount; + + amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (amount < 0) + amount = 0; + return put_user(amount, (unsigned int __user *) argp); + } + + case TIOCINQ: { + struct sk_buff *skb; + long amount = 0L; + + spin_lock_irq(&sk->sk_receive_queue.lock); + if ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) + amount = skb->len; + spin_unlock_irq(&sk->sk_receive_queue.lock); + return put_user(amount, (unsigned int __user *) argp); + } + + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + return -EINVAL; + + case SIOCADDRT: + case SIOCDELRT: + case SIOCRSCLRRT: + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + return rose_rt_ioctl(cmd, argp); + + case SIOCRSGCAUSE: { + struct rose_cause_struct rose_cause; + rose_cause.cause = rose->cause; + rose_cause.diagnostic = rose->diagnostic; + return copy_to_user(argp, &rose_cause, sizeof(struct rose_cause_struct)) ? -EFAULT : 0; + } + + case SIOCRSSCAUSE: { + struct rose_cause_struct rose_cause; + if (copy_from_user(&rose_cause, argp, sizeof(struct rose_cause_struct))) + return -EFAULT; + rose->cause = rose_cause.cause; + rose->diagnostic = rose_cause.diagnostic; + return 0; + } + + case SIOCRSSL2CALL: + if (!capable(CAP_NET_ADMIN)) return -EPERM; + if (ax25cmp(&rose_callsign, &null_ax25_address) != 0) + ax25_listen_release(&rose_callsign, NULL); + if (copy_from_user(&rose_callsign, argp, sizeof(ax25_address))) + return -EFAULT; + if (ax25cmp(&rose_callsign, &null_ax25_address) != 0) + return ax25_listen_register(&rose_callsign, NULL); + + return 0; + + case SIOCRSGL2CALL: + return copy_to_user(argp, &rose_callsign, sizeof(ax25_address)) ? 
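[Editorial sketch] The SIOCRSGCAUSE/SIOCRSSCAUSE branches of rose_ioctl() above expose the last clear cause and diagnostic through struct rose_cause_struct. A minimal userspace sketch of the read side could look as follows; it assumes the usual uapi headers (<linux/sockios.h> and <linux/rose.h>) supply SIOCRSGCAUSE and the structure, and it simply prints whatever the kernel last recorded on the socket.

#include <stdio.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <linux/sockios.h>
#include <linux/rose.h>

int main(void)
{
        struct rose_cause_struct rc;
        int fd = socket(AF_ROSE, SOCK_SEQPACKET, 0);

        if (fd < 0) {
                perror("socket(AF_ROSE)");      /* rose module not loaded? */
                return 1;
        }

        /* Mirrors the SIOCRSGCAUSE branch: the kernel copies out the last
         * cause/diagnostic recorded for this socket (0/0 on a fresh one). */
        if (ioctl(fd, SIOCRSGCAUSE, &rc) < 0)
                perror("SIOCRSGCAUSE");
        else
                printf("cause %u, diagnostic %u\n", rc.cause, rc.diagnostic);

        close(fd);
        return 0;
}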
-EFAULT : 0; + + case SIOCRSACCEPT: + if (rose->state == ROSE_STATE_5) { + rose_write_internal(sk, ROSE_CALL_ACCEPTED); + rose_start_idletimer(sk); + rose->condition = 0x00; + rose->vs = 0; + rose->va = 0; + rose->vr = 0; + rose->vl = 0; + rose->state = ROSE_STATE_3; + } + return 0; + + default: + return -ENOIOCTLCMD; + } + + return 0; +} + +#ifdef CONFIG_PROC_FS +static void *rose_info_start(struct seq_file *seq, loff_t *pos) + __acquires(rose_list_lock) +{ + spin_lock_bh(&rose_list_lock); + return seq_hlist_start_head(&rose_list, *pos); +} + +static void *rose_info_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &rose_list, pos); +} + +static void rose_info_stop(struct seq_file *seq, void *v) + __releases(rose_list_lock) +{ + spin_unlock_bh(&rose_list_lock); +} + +static int rose_info_show(struct seq_file *seq, void *v) +{ + char buf[11], rsbuf[11]; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "dest_addr dest_call src_addr src_call dev lci neigh st vs vr va t t1 t2 t3 hb idle Snd-Q Rcv-Q inode\n"); + + else { + struct sock *s = sk_entry(v); + struct rose_sock *rose = rose_sk(s); + const char *devname, *callsign; + const struct net_device *dev = rose->device; + + if (!dev) + devname = "???"; + else + devname = dev->name; + + seq_printf(seq, "%-10s %-9s ", + rose2asc(rsbuf, &rose->dest_addr), + ax2asc(buf, &rose->dest_call)); + + if (ax25cmp(&rose->source_call, &null_ax25_address) == 0) + callsign = "??????-?"; + else + callsign = ax2asc(buf, &rose->source_call); + + seq_printf(seq, + "%-10s %-9s %-5s %3.3X %05d %d %d %d %d %3lu %3lu %3lu %3lu %3lu %3lu/%03lu %5d %5d %ld\n", + rose2asc(rsbuf, &rose->source_addr), + callsign, + devname, + rose->lci & 0x0FFF, + (rose->neighbour) ? rose->neighbour->number : 0, + rose->state, + rose->vs, + rose->vr, + rose->va, + ax25_display_timer(&rose->timer) / HZ, + rose->t1 / HZ, + rose->t2 / HZ, + rose->t3 / HZ, + rose->hb / HZ, + ax25_display_timer(&rose->idletimer) / (60 * HZ), + rose->idle / (60 * HZ), + sk_wmem_alloc_get(s), + sk_rmem_alloc_get(s), + s->sk_socket ? 
SOCK_INODE(s->sk_socket)->i_ino : 0L); + } + + return 0; +} + +static const struct seq_operations rose_info_seqops = { + .start = rose_info_start, + .next = rose_info_next, + .stop = rose_info_stop, + .show = rose_info_show, +}; +#endif /* CONFIG_PROC_FS */ + +static const struct net_proto_family rose_family_ops = { + .family = PF_ROSE, + .create = rose_create, + .owner = THIS_MODULE, +}; + +static const struct proto_ops rose_proto_ops = { + .family = PF_ROSE, + .owner = THIS_MODULE, + .release = rose_release, + .bind = rose_bind, + .connect = rose_connect, + .socketpair = sock_no_socketpair, + .accept = rose_accept, + .getname = rose_getname, + .poll = datagram_poll, + .ioctl = rose_ioctl, + .gettstamp = sock_gettstamp, + .listen = rose_listen, + .shutdown = sock_no_shutdown, + .setsockopt = rose_setsockopt, + .getsockopt = rose_getsockopt, + .sendmsg = rose_sendmsg, + .recvmsg = rose_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct notifier_block rose_dev_notifier = { + .notifier_call = rose_device_event, +}; + +static struct net_device **dev_rose; + +static struct ax25_protocol rose_pid = { + .pid = AX25_P_ROSE, + .func = rose_route_frame +}; + +static struct ax25_linkfail rose_linkfail_notifier = { + .func = rose_link_failed +}; + +static int __init rose_proto_init(void) +{ + int i; + int rc; + + if (rose_ndevs > 0x7FFFFFFF/sizeof(struct net_device *)) { + printk(KERN_ERR "ROSE: rose_proto_init - rose_ndevs parameter too large\n"); + rc = -EINVAL; + goto out; + } + + rc = proto_register(&rose_proto, 0); + if (rc != 0) + goto out; + + rose_callsign = null_ax25_address; + + dev_rose = kcalloc(rose_ndevs, sizeof(struct net_device *), + GFP_KERNEL); + if (dev_rose == NULL) { + printk(KERN_ERR "ROSE: rose_proto_init - unable to allocate device structure\n"); + rc = -ENOMEM; + goto out_proto_unregister; + } + + for (i = 0; i < rose_ndevs; i++) { + struct net_device *dev; + char name[IFNAMSIZ]; + + sprintf(name, "rose%d", i); + dev = alloc_netdev(0, name, NET_NAME_UNKNOWN, rose_setup); + if (!dev) { + printk(KERN_ERR "ROSE: rose_proto_init - unable to allocate memory\n"); + rc = -ENOMEM; + goto fail; + } + rc = register_netdev(dev); + if (rc) { + printk(KERN_ERR "ROSE: netdevice registration failed\n"); + free_netdev(dev); + goto fail; + } + rose_set_lockdep_key(dev); + dev_rose[i] = dev; + } + + sock_register(&rose_family_ops); + register_netdevice_notifier(&rose_dev_notifier); + + ax25_register_pid(&rose_pid); + ax25_linkfail_register(&rose_linkfail_notifier); + +#ifdef CONFIG_SYSCTL + rose_register_sysctl(); +#endif + rose_loopback_init(); + + rose_add_loopback_neigh(); + + proc_create_seq("rose", 0444, init_net.proc_net, &rose_info_seqops); + proc_create_seq("rose_neigh", 0444, init_net.proc_net, + &rose_neigh_seqops); + proc_create_seq("rose_nodes", 0444, init_net.proc_net, + &rose_node_seqops); + proc_create_seq("rose_routes", 0444, init_net.proc_net, + &rose_route_seqops); +out: + return rc; +fail: + while (--i >= 0) { + unregister_netdev(dev_rose[i]); + free_netdev(dev_rose[i]); + } + kfree(dev_rose); +out_proto_unregister: + proto_unregister(&rose_proto); + goto out; +} +module_init(rose_proto_init); + +module_param(rose_ndevs, int, 0); +MODULE_PARM_DESC(rose_ndevs, "number of ROSE devices"); + +MODULE_AUTHOR("Jonathan Naylor G4KLX "); +MODULE_DESCRIPTION("The amateur radio ROSE network layer protocol"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_ROSE); + +static void __exit rose_exit(void) +{ + int i; + + remove_proc_entry("rose", 
init_net.proc_net); + remove_proc_entry("rose_neigh", init_net.proc_net); + remove_proc_entry("rose_nodes", init_net.proc_net); + remove_proc_entry("rose_routes", init_net.proc_net); + rose_loopback_clear(); + + rose_rt_free(); + + ax25_protocol_release(AX25_P_ROSE); + ax25_linkfail_release(&rose_linkfail_notifier); + + if (ax25cmp(&rose_callsign, &null_ax25_address) != 0) + ax25_listen_release(&rose_callsign, NULL); + +#ifdef CONFIG_SYSCTL + rose_unregister_sysctl(); +#endif + unregister_netdevice_notifier(&rose_dev_notifier); + + sock_unregister(PF_ROSE); + + for (i = 0; i < rose_ndevs; i++) { + struct net_device *dev = dev_rose[i]; + + if (dev) { + unregister_netdev(dev); + free_netdev(dev); + } + } + + kfree(dev_rose); + proto_unregister(&rose_proto); +} + +module_exit(rose_exit); diff --git a/net/rose/rose_dev.c b/net/rose/rose_dev.c new file mode 100644 index 000000000..f1a76a582 --- /dev/null +++ b/net/rose/rose_dev.c @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +static int rose_header(struct sk_buff *skb, struct net_device *dev, + unsigned short type, + const void *daddr, const void *saddr, unsigned int len) +{ + unsigned char *buff = skb_push(skb, ROSE_MIN_LEN + 2); + + if (daddr) + memcpy(buff + 7, daddr, dev->addr_len); + + *buff++ = ROSE_GFI | ROSE_Q_BIT; + *buff++ = 0x00; + *buff++ = ROSE_DATA; + *buff++ = 0x7F; + *buff++ = AX25_P_IP; + + if (daddr != NULL) + return 37; + + return -37; +} + +static int rose_set_mac_address(struct net_device *dev, void *addr) +{ + struct sockaddr *sa = addr; + int err; + + if (!memcmp(dev->dev_addr, sa->sa_data, dev->addr_len)) + return 0; + + if (dev->flags & IFF_UP) { + err = rose_add_loopback_node((rose_address *)sa->sa_data); + if (err) + return err; + + rose_del_loopback_node((const rose_address *)dev->dev_addr); + } + + dev_addr_set(dev, sa->sa_data); + + return 0; +} + +static int rose_open(struct net_device *dev) +{ + int err; + + err = rose_add_loopback_node((const rose_address *)dev->dev_addr); + if (err) + return err; + + netif_start_queue(dev); + + return 0; +} + +static int rose_close(struct net_device *dev) +{ + netif_stop_queue(dev); + rose_del_loopback_node((const rose_address *)dev->dev_addr); + return 0; +} + +static netdev_tx_t rose_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct net_device_stats *stats = &dev->stats; + unsigned int len = skb->len; + + if (!netif_running(dev)) { + printk(KERN_ERR "ROSE: rose_xmit - called when iface is down\n"); + return NETDEV_TX_BUSY; + } + + if (!rose_route_frame(skb, NULL)) { + dev_kfree_skb(skb); + stats->tx_errors++; + return NETDEV_TX_OK; + } + + stats->tx_packets++; + stats->tx_bytes += len; + return NETDEV_TX_OK; +} + +static const struct header_ops rose_header_ops = { + .create = rose_header, +}; + +static const struct net_device_ops rose_netdev_ops = { + .ndo_open = rose_open, + .ndo_stop = rose_close, + .ndo_start_xmit = rose_xmit, + .ndo_set_mac_address = rose_set_mac_address, +}; + +void rose_setup(struct net_device *dev) +{ + dev->mtu = ROSE_MAX_PACKET_SIZE - 2; + dev->netdev_ops = &rose_netdev_ops; + + dev->header_ops = &rose_header_ops; + dev->hard_header_len = AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN; + 
dev->addr_len = ROSE_ADDR_LEN; + dev->type = ARPHRD_ROSE; + + /* New-style flags. */ + dev->flags = IFF_NOARP; +} diff --git a/net/rose/rose_in.c b/net/rose/rose_in.c new file mode 100644 index 000000000..4d67f36dc --- /dev/null +++ b/net/rose/rose_in.c @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * + * Most of this code is based on the SDL diagrams published in the 7th ARRL + * Computer Networking Conference papers. The diagrams have mistakes in them, + * but are mostly correct. Before you modify the code could you read the SDL + * diagrams as the code is not obvious and probably very easy to break. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * State machine for state 1, Awaiting Call Accepted State. + * The handling of the timer(s) is in file rose_timer.c. + * Handling of state 0 and connection release is in af_rose.c. + */ +static int rose_state1_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct rose_sock *rose = rose_sk(sk); + + switch (frametype) { + case ROSE_CALL_ACCEPTED: + rose_stop_timer(sk); + rose_start_idletimer(sk); + rose->condition = 0x00; + rose->vs = 0; + rose->va = 0; + rose->vr = 0; + rose->vl = 0; + rose->state = ROSE_STATE_3; + sk->sk_state = TCP_ESTABLISHED; + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + break; + + case ROSE_CLEAR_REQUEST: + rose_write_internal(sk, ROSE_CLEAR_CONFIRMATION); + rose_disconnect(sk, ECONNREFUSED, skb->data[3], skb->data[4]); + rose->neighbour->use--; + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 2, Awaiting Clear Confirmation State. + * The handling of the timer(s) is in file rose_timer.c + * Handling of state 0 and connection release is in af_rose.c. + */ +static int rose_state2_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct rose_sock *rose = rose_sk(sk); + + switch (frametype) { + case ROSE_CLEAR_REQUEST: + rose_write_internal(sk, ROSE_CLEAR_CONFIRMATION); + rose_disconnect(sk, 0, skb->data[3], skb->data[4]); + rose->neighbour->use--; + break; + + case ROSE_CLEAR_CONFIRMATION: + rose_disconnect(sk, 0, -1, -1); + rose->neighbour->use--; + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 3, Connected State. + * The handling of the timer(s) is in file rose_timer.c + * Handling of state 0 and connection release is in af_rose.c. 
+ */ +static int rose_state3_machine(struct sock *sk, struct sk_buff *skb, int frametype, int ns, int nr, int q, int d, int m) +{ + struct rose_sock *rose = rose_sk(sk); + int queued = 0; + + switch (frametype) { + case ROSE_RESET_REQUEST: + rose_stop_timer(sk); + rose_start_idletimer(sk); + rose_write_internal(sk, ROSE_RESET_CONFIRMATION); + rose->condition = 0x00; + rose->vs = 0; + rose->vr = 0; + rose->va = 0; + rose->vl = 0; + rose_requeue_frames(sk); + break; + + case ROSE_CLEAR_REQUEST: + rose_write_internal(sk, ROSE_CLEAR_CONFIRMATION); + rose_disconnect(sk, 0, skb->data[3], skb->data[4]); + rose->neighbour->use--; + break; + + case ROSE_RR: + case ROSE_RNR: + if (!rose_validate_nr(sk, nr)) { + rose_write_internal(sk, ROSE_RESET_REQUEST); + rose->condition = 0x00; + rose->vs = 0; + rose->vr = 0; + rose->va = 0; + rose->vl = 0; + rose->state = ROSE_STATE_4; + rose_start_t2timer(sk); + rose_stop_idletimer(sk); + } else { + rose_frames_acked(sk, nr); + if (frametype == ROSE_RNR) { + rose->condition |= ROSE_COND_PEER_RX_BUSY; + } else { + rose->condition &= ~ROSE_COND_PEER_RX_BUSY; + } + } + break; + + case ROSE_DATA: /* XXX */ + rose->condition &= ~ROSE_COND_PEER_RX_BUSY; + if (!rose_validate_nr(sk, nr)) { + rose_write_internal(sk, ROSE_RESET_REQUEST); + rose->condition = 0x00; + rose->vs = 0; + rose->vr = 0; + rose->va = 0; + rose->vl = 0; + rose->state = ROSE_STATE_4; + rose_start_t2timer(sk); + rose_stop_idletimer(sk); + break; + } + rose_frames_acked(sk, nr); + if (ns == rose->vr) { + rose_start_idletimer(sk); + if (sk_filter_trim_cap(sk, skb, ROSE_MIN_LEN) == 0 && + __sock_queue_rcv_skb(sk, skb) == 0) { + rose->vr = (rose->vr + 1) % ROSE_MODULUS; + queued = 1; + } else { + /* Should never happen ! */ + rose_write_internal(sk, ROSE_RESET_REQUEST); + rose->condition = 0x00; + rose->vs = 0; + rose->vr = 0; + rose->va = 0; + rose->vl = 0; + rose->state = ROSE_STATE_4; + rose_start_t2timer(sk); + rose_stop_idletimer(sk); + break; + } + if (atomic_read(&sk->sk_rmem_alloc) > + (sk->sk_rcvbuf >> 1)) + rose->condition |= ROSE_COND_OWN_RX_BUSY; + } + /* + * If the window is full, ack the frame, else start the + * acknowledge hold back timer. + */ + if (((rose->vl + sysctl_rose_window_size) % ROSE_MODULUS) == rose->vr) { + rose->condition &= ~ROSE_COND_ACK_PENDING; + rose_stop_timer(sk); + rose_enquiry_response(sk); + } else { + rose->condition |= ROSE_COND_ACK_PENDING; + rose_start_hbtimer(sk); + } + break; + + default: + printk(KERN_WARNING "ROSE: unknown %02X in state 3\n", frametype); + break; + } + + return queued; +} + +/* + * State machine for state 4, Awaiting Reset Confirmation State. + * The handling of the timer(s) is in file rose_timer.c + * Handling of state 0 and connection release is in af_rose.c. + */ +static int rose_state4_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct rose_sock *rose = rose_sk(sk); + + switch (frametype) { + case ROSE_RESET_REQUEST: + rose_write_internal(sk, ROSE_RESET_CONFIRMATION); + fallthrough; + case ROSE_RESET_CONFIRMATION: + rose_stop_timer(sk); + rose_start_idletimer(sk); + rose->condition = 0x00; + rose->va = 0; + rose->vr = 0; + rose->vs = 0; + rose->vl = 0; + rose->state = ROSE_STATE_3; + rose_requeue_frames(sk); + break; + + case ROSE_CLEAR_REQUEST: + rose_write_internal(sk, ROSE_CLEAR_CONFIRMATION); + rose_disconnect(sk, 0, skb->data[3], skb->data[4]); + rose->neighbour->use--; + break; + + default: + break; + } + + return 0; +} + +/* + * State machine for state 5, Awaiting Call Acceptance State. 
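[Editorial sketch] rose_state3_machine() above resets the virtual circuit whenever rose_validate_nr() rejects a received N(R). That helper lives in rose_subr.c and is not part of this hunk; the sketch below assumes the conventional rule the reset-on-failure handling relies on, namely that N(R) must fall between V(A) and V(S) inclusive, modulo ROSE_MODULUS (taken here as 8).

#include <stdio.h>

#define ROSE_MODULUS 8

static int nr_in_window(unsigned int va, unsigned int vs, unsigned int nr)
{
        unsigned int vc = va;

        while (vc != vs) {
                if (nr == vc)
                        return 1;
                vc = (vc + 1) % ROSE_MODULUS;
        }
        return nr == vs;        /* N(R) == V(S) acknowledges everything sent */
}

int main(void)
{
        /* With V(A)=6 and V(S)=1 the valid N(R) values wrap: 6, 7, 0, 1. */
        for (unsigned int nr = 0; nr < ROSE_MODULUS; nr++)
                printf("nr=%u valid=%d\n", nr, nr_in_window(6, 1, nr));
        return 0;
}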
+ * The handling of the timer(s) is in file rose_timer.c + * Handling of state 0 and connection release is in af_rose.c. + */ +static int rose_state5_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + if (frametype == ROSE_CLEAR_REQUEST) { + rose_write_internal(sk, ROSE_CLEAR_CONFIRMATION); + rose_disconnect(sk, 0, skb->data[3], skb->data[4]); + rose_sk(sk)->neighbour->use--; + } + + return 0; +} + +/* Higher level upcall for a LAPB frame */ +int rose_process_rx_frame(struct sock *sk, struct sk_buff *skb) +{ + struct rose_sock *rose = rose_sk(sk); + int queued = 0, frametype, ns, nr, q, d, m; + + if (rose->state == ROSE_STATE_0) + return 0; + + frametype = rose_decode(skb, &ns, &nr, &q, &d, &m); + + switch (rose->state) { + case ROSE_STATE_1: + queued = rose_state1_machine(sk, skb, frametype); + break; + case ROSE_STATE_2: + queued = rose_state2_machine(sk, skb, frametype); + break; + case ROSE_STATE_3: + queued = rose_state3_machine(sk, skb, frametype, ns, nr, q, d, m); + break; + case ROSE_STATE_4: + queued = rose_state4_machine(sk, skb, frametype); + break; + case ROSE_STATE_5: + queued = rose_state5_machine(sk, skb, frametype); + break; + } + + rose_kick(sk); + + return queued; +} diff --git a/net/rose/rose_link.c b/net/rose/rose_link.c new file mode 100644 index 000000000..0f77ae8ef --- /dev/null +++ b/net/rose/rose_link.c @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void rose_ftimer_expiry(struct timer_list *); +static void rose_t0timer_expiry(struct timer_list *); + +static void rose_transmit_restart_confirmation(struct rose_neigh *neigh); +static void rose_transmit_restart_request(struct rose_neigh *neigh); + +void rose_start_ftimer(struct rose_neigh *neigh) +{ + del_timer(&neigh->ftimer); + + neigh->ftimer.function = rose_ftimer_expiry; + neigh->ftimer.expires = + jiffies + msecs_to_jiffies(sysctl_rose_link_fail_timeout); + + add_timer(&neigh->ftimer); +} + +static void rose_start_t0timer(struct rose_neigh *neigh) +{ + del_timer(&neigh->t0timer); + + neigh->t0timer.function = rose_t0timer_expiry; + neigh->t0timer.expires = + jiffies + msecs_to_jiffies(sysctl_rose_restart_request_timeout); + + add_timer(&neigh->t0timer); +} + +void rose_stop_ftimer(struct rose_neigh *neigh) +{ + del_timer(&neigh->ftimer); +} + +void rose_stop_t0timer(struct rose_neigh *neigh) +{ + del_timer(&neigh->t0timer); +} + +int rose_ftimer_running(struct rose_neigh *neigh) +{ + return timer_pending(&neigh->ftimer); +} + +static int rose_t0timer_running(struct rose_neigh *neigh) +{ + return timer_pending(&neigh->t0timer); +} + +static void rose_ftimer_expiry(struct timer_list *t) +{ +} + +static void rose_t0timer_expiry(struct timer_list *t) +{ + struct rose_neigh *neigh = from_timer(neigh, t, t0timer); + + rose_transmit_restart_request(neigh); + + neigh->dce_mode = 0; + + rose_start_t0timer(neigh); +} + +/* + * Interface to ax25_send_frame. Changes my level 2 callsign depending + * on whether we have a global ROSE callsign or use the default port + * callsign. 
+ */ +static int rose_send_frame(struct sk_buff *skb, struct rose_neigh *neigh) +{ + const ax25_address *rose_call; + ax25_cb *ax25s; + + if (ax25cmp(&rose_callsign, &null_ax25_address) == 0) + rose_call = (const ax25_address *)neigh->dev->dev_addr; + else + rose_call = &rose_callsign; + + ax25s = neigh->ax25; + neigh->ax25 = ax25_send_frame(skb, 260, rose_call, &neigh->callsign, neigh->digipeat, neigh->dev); + if (ax25s) + ax25_cb_put(ax25s); + + return neigh->ax25 != NULL; +} + +/* + * Interface to ax25_link_up. Changes my level 2 callsign depending + * on whether we have a global ROSE callsign or use the default port + * callsign. + */ +static int rose_link_up(struct rose_neigh *neigh) +{ + const ax25_address *rose_call; + ax25_cb *ax25s; + + if (ax25cmp(&rose_callsign, &null_ax25_address) == 0) + rose_call = (const ax25_address *)neigh->dev->dev_addr; + else + rose_call = &rose_callsign; + + ax25s = neigh->ax25; + neigh->ax25 = ax25_find_cb(rose_call, &neigh->callsign, neigh->digipeat, neigh->dev); + if (ax25s) + ax25_cb_put(ax25s); + + return neigh->ax25 != NULL; +} + +/* + * This handles all restart and diagnostic frames. + */ +void rose_link_rx_restart(struct sk_buff *skb, struct rose_neigh *neigh, unsigned short frametype) +{ + struct sk_buff *skbn; + + switch (frametype) { + case ROSE_RESTART_REQUEST: + rose_stop_t0timer(neigh); + neigh->restarted = 1; + neigh->dce_mode = (skb->data[3] == ROSE_DTE_ORIGINATED); + rose_transmit_restart_confirmation(neigh); + break; + + case ROSE_RESTART_CONFIRMATION: + rose_stop_t0timer(neigh); + neigh->restarted = 1; + break; + + case ROSE_DIAGNOSTIC: + pr_warn("ROSE: received diagnostic #%d - %3ph\n", skb->data[3], + skb->data + 4); + break; + + default: + printk(KERN_WARNING "ROSE: received unknown %02X with LCI 000\n", frametype); + break; + } + + if (neigh->restarted) { + while ((skbn = skb_dequeue(&neigh->queue)) != NULL) + if (!rose_send_frame(skbn, neigh)) + kfree_skb(skbn); + } +} + +/* + * This routine is called when a Restart Request is needed + */ +static void rose_transmit_restart_request(struct rose_neigh *neigh) +{ + struct sk_buff *skb; + unsigned char *dptr; + int len; + + len = AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN + 3; + + if ((skb = alloc_skb(len, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN); + + dptr = skb_put(skb, ROSE_MIN_LEN + 3); + + *dptr++ = AX25_P_ROSE; + *dptr++ = ROSE_GFI; + *dptr++ = 0x00; + *dptr++ = ROSE_RESTART_REQUEST; + *dptr++ = ROSE_DTE_ORIGINATED; + *dptr++ = 0; + + if (!rose_send_frame(skb, neigh)) + kfree_skb(skb); +} + +/* + * This routine is called when a Restart Confirmation is needed + */ +static void rose_transmit_restart_confirmation(struct rose_neigh *neigh) +{ + struct sk_buff *skb; + unsigned char *dptr; + int len; + + len = AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN + 1; + + if ((skb = alloc_skb(len, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN); + + dptr = skb_put(skb, ROSE_MIN_LEN + 1); + + *dptr++ = AX25_P_ROSE; + *dptr++ = ROSE_GFI; + *dptr++ = 0x00; + *dptr++ = ROSE_RESTART_CONFIRMATION; + + if (!rose_send_frame(skb, neigh)) + kfree_skb(skb); +} + +/* + * This routine is called when a Clear Request is needed outside of the context + * of a connected socket. 
+ */ +void rose_transmit_clear_request(struct rose_neigh *neigh, unsigned int lci, unsigned char cause, unsigned char diagnostic) +{ + struct sk_buff *skb; + unsigned char *dptr; + int len; + + if (!neigh->dev) + return; + + len = AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + ROSE_MIN_LEN + 3; + + if ((skb = alloc_skb(len, GFP_ATOMIC)) == NULL) + return; + + skb_reserve(skb, AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN); + + dptr = skb_put(skb, ROSE_MIN_LEN + 3); + + *dptr++ = AX25_P_ROSE; + *dptr++ = ((lci >> 8) & 0x0F) | ROSE_GFI; + *dptr++ = ((lci >> 0) & 0xFF); + *dptr++ = ROSE_CLEAR_REQUEST; + *dptr++ = cause; + *dptr++ = diagnostic; + + if (!rose_send_frame(skb, neigh)) + kfree_skb(skb); +} + +void rose_transmit_link(struct sk_buff *skb, struct rose_neigh *neigh) +{ + unsigned char *dptr; + + if (neigh->loopback) { + rose_loopback_queue(skb, neigh); + return; + } + + if (!rose_link_up(neigh)) + neigh->restarted = 0; + + dptr = skb_push(skb, 1); + *dptr++ = AX25_P_ROSE; + + if (neigh->restarted) { + if (!rose_send_frame(skb, neigh)) + kfree_skb(skb); + } else { + skb_queue_tail(&neigh->queue, skb); + + if (!rose_t0timer_running(neigh)) { + rose_transmit_restart_request(neigh); + neigh->dce_mode = 0; + rose_start_t0timer(neigh); + } + } +} diff --git a/net/rose/rose_loopback.c b/net/rose/rose_loopback.c new file mode 100644 index 000000000..036d92c0a --- /dev/null +++ b/net/rose/rose_loopback.c @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include + +static struct sk_buff_head loopback_queue; +#define ROSE_LOOPBACK_LIMIT 1000 +static struct timer_list loopback_timer; + +static void rose_set_loopback_timer(void); +static void rose_loopback_timer(struct timer_list *unused); + +void rose_loopback_init(void) +{ + skb_queue_head_init(&loopback_queue); + + timer_setup(&loopback_timer, rose_loopback_timer, 0); +} + +static int rose_loopback_running(void) +{ + return timer_pending(&loopback_timer); +} + +int rose_loopback_queue(struct sk_buff *skb, struct rose_neigh *neigh) +{ + struct sk_buff *skbn = NULL; + + if (skb_queue_len(&loopback_queue) < ROSE_LOOPBACK_LIMIT) + skbn = skb_clone(skb, GFP_ATOMIC); + + if (skbn) { + consume_skb(skb); + skb_queue_tail(&loopback_queue, skbn); + + if (!rose_loopback_running()) + rose_set_loopback_timer(); + } else { + kfree_skb(skb); + } + + return 1; +} + +static void rose_set_loopback_timer(void) +{ + mod_timer(&loopback_timer, jiffies + 10); +} + +static void rose_loopback_timer(struct timer_list *unused) +{ + struct sk_buff *skb; + struct net_device *dev; + rose_address *dest; + struct sock *sk; + unsigned short frametype; + unsigned int lci_i, lci_o; + int count; + + for (count = 0; count < ROSE_LOOPBACK_LIMIT; count++) { + skb = skb_dequeue(&loopback_queue); + if (!skb) + return; + if (skb->len < ROSE_MIN_LEN) { + kfree_skb(skb); + continue; + } + lci_i = ((skb->data[0] << 8) & 0xF00) + ((skb->data[1] << 0) & 0x0FF); + frametype = skb->data[2]; + if (frametype == ROSE_CALL_REQUEST && + (skb->len <= ROSE_CALL_REQ_FACILITIES_OFF || + skb->data[ROSE_CALL_REQ_ADDR_LEN_OFF] != + ROSE_CALL_REQ_ADDR_LEN_VAL)) { + kfree_skb(skb); + continue; + } + dest = (rose_address *)(skb->data + ROSE_CALL_REQ_DEST_ADDR_OFF); + lci_o = ROSE_DEFAULT_MAXVC + 1 - lci_i; + + skb_reset_transport_header(skb); + + sk = rose_find_socket(lci_o, rose_loopback_neigh); + if (sk) { + if (rose_process_rx_frame(sk, skb) == 0) 
+ kfree_skb(skb); + continue; + } + + if (frametype == ROSE_CALL_REQUEST) { + if (!rose_loopback_neigh->dev && + !rose_loopback_neigh->loopback) { + kfree_skb(skb); + continue; + } + + dev = rose_dev_get(dest); + if (!dev) { + kfree_skb(skb); + continue; + } + + if (rose_rx_call_request(skb, dev, rose_loopback_neigh, lci_o) == 0) { + dev_put(dev); + kfree_skb(skb); + } + } else { + kfree_skb(skb); + } + } + if (!skb_queue_empty(&loopback_queue)) + mod_timer(&loopback_timer, jiffies + 1); +} + +void __exit rose_loopback_clear(void) +{ + struct sk_buff *skb; + + del_timer(&loopback_timer); + + while ((skb = skb_dequeue(&loopback_queue)) != NULL) { + skb->sk = NULL; + kfree_skb(skb); + } +} diff --git a/net/rose/rose_out.c b/net/rose/rose_out.c new file mode 100644 index 000000000..9050e33c9 --- /dev/null +++ b/net/rose/rose_out.c @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * This procedure is passed a buffer descriptor for an iframe. It builds + * the rest of the control part of the frame and then writes it out. + */ +static void rose_send_iframe(struct sock *sk, struct sk_buff *skb) +{ + struct rose_sock *rose = rose_sk(sk); + + if (skb == NULL) + return; + + skb->data[2] |= (rose->vr << 5) & 0xE0; + skb->data[2] |= (rose->vs << 1) & 0x0E; + + rose_start_idletimer(sk); + + rose_transmit_link(skb, rose->neighbour); +} + +void rose_kick(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + struct sk_buff *skb, *skbn; + unsigned short start, end; + + if (rose->state != ROSE_STATE_3) + return; + + if (rose->condition & ROSE_COND_PEER_RX_BUSY) + return; + + if (!skb_peek(&sk->sk_write_queue)) + return; + + start = (skb_peek(&rose->ack_queue) == NULL) ? rose->va : rose->vs; + end = (rose->va + sysctl_rose_window_size) % ROSE_MODULUS; + + if (start == end) + return; + + rose->vs = start; + + /* + * Transmit data until either we're out of data to send or + * the window is full. + */ + + skb = skb_dequeue(&sk->sk_write_queue); + + do { + if ((skbn = skb_clone(skb, GFP_ATOMIC)) == NULL) { + skb_queue_head(&sk->sk_write_queue, skb); + break; + } + + skb_set_owner_w(skbn, sk); + + /* + * Transmit the frame copy. + */ + rose_send_iframe(sk, skbn); + + rose->vs = (rose->vs + 1) % ROSE_MODULUS; + + /* + * Requeue the original data frame. + */ + skb_queue_tail(&rose->ack_queue, skb); + + } while (rose->vs != end && + (skb = skb_dequeue(&sk->sk_write_queue)) != NULL); + + rose->vl = rose->vr; + rose->condition &= ~ROSE_COND_ACK_PENDING; + + rose_stop_timer(sk); +} + +/* + * The following routines are taken from page 170 of the 7th ARRL Computer + * Networking Conference paper, as is the whole state machine. 
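[Editorial sketch] rose_kick() above derives the transmit window edge as (V(A) + window) mod ROSE_MODULUS and clones I-frames until V(S) reaches it. The standalone sketch below reproduces that arithmetic to show how many frames may still go out; ROSE_MODULUS is taken as 8 and the window is assumed to stay below the modulus, as the ROSE window-size sysctl is understood to guarantee.

#include <stdio.h>

#define ROSE_MODULUS 8

static unsigned int rose_tx_quota(unsigned int va, unsigned int vs,
                                  unsigned int window)
{
        unsigned int end = (va + window) % ROSE_MODULUS;
        unsigned int n = 0;

        /* Count how many I-frames rose_kick() may still send before
         * V(S) reaches the window edge derived from V(A). */
        while (vs != end) {
                vs = (vs + 1) % ROSE_MODULUS;
                n++;
        }
        return n;
}

int main(void)
{
        /* V(A)=2 (oldest unacked), V(S)=4 (next to send), window=3:
         * sequence number 4 may still be sent; 5 would exceed the window. */
        printf("quota=%u\n", rose_tx_quota(2, 4, 3));
        return 0;
}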
+ */ + +void rose_enquiry_response(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + if (rose->condition & ROSE_COND_OWN_RX_BUSY) + rose_write_internal(sk, ROSE_RNR); + else + rose_write_internal(sk, ROSE_RR); + + rose->vl = rose->vr; + rose->condition &= ~ROSE_COND_ACK_PENDING; + + rose_stop_timer(sk); +} diff --git a/net/rose/rose_route.c b/net/rose/rose_route.c new file mode 100644 index 000000000..fee772b46 --- /dev/null +++ b/net/rose/rose_route.c @@ -0,0 +1,1326 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) Terry Dawson VK2KTJ (terry@animats.net) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include +#include +#include +#include + +static unsigned int rose_neigh_no = 1; + +static struct rose_node *rose_node_list; +static DEFINE_SPINLOCK(rose_node_list_lock); +static struct rose_neigh *rose_neigh_list; +static DEFINE_SPINLOCK(rose_neigh_list_lock); +static struct rose_route *rose_route_list; +static DEFINE_SPINLOCK(rose_route_list_lock); + +struct rose_neigh *rose_loopback_neigh; + +/* + * Add a new route to a node, and in the process add the node and the + * neighbour if it is new. + */ +static int __must_check rose_add_node(struct rose_route_struct *rose_route, + struct net_device *dev) +{ + struct rose_node *rose_node, *rose_tmpn, *rose_tmpp; + struct rose_neigh *rose_neigh; + int i, res = 0; + + spin_lock_bh(&rose_node_list_lock); + spin_lock_bh(&rose_neigh_list_lock); + + rose_node = rose_node_list; + while (rose_node != NULL) { + if ((rose_node->mask == rose_route->mask) && + (rosecmpm(&rose_route->address, &rose_node->address, + rose_route->mask) == 0)) + break; + rose_node = rose_node->next; + } + + if (rose_node != NULL && rose_node->loopback) { + res = -EINVAL; + goto out; + } + + rose_neigh = rose_neigh_list; + while (rose_neigh != NULL) { + if (ax25cmp(&rose_route->neighbour, + &rose_neigh->callsign) == 0 && + rose_neigh->dev == dev) + break; + rose_neigh = rose_neigh->next; + } + + if (rose_neigh == NULL) { + rose_neigh = kmalloc(sizeof(*rose_neigh), GFP_ATOMIC); + if (rose_neigh == NULL) { + res = -ENOMEM; + goto out; + } + + rose_neigh->callsign = rose_route->neighbour; + rose_neigh->digipeat = NULL; + rose_neigh->ax25 = NULL; + rose_neigh->dev = dev; + rose_neigh->count = 0; + rose_neigh->use = 0; + rose_neigh->dce_mode = 0; + rose_neigh->loopback = 0; + rose_neigh->number = rose_neigh_no++; + rose_neigh->restarted = 0; + + skb_queue_head_init(&rose_neigh->queue); + + timer_setup(&rose_neigh->ftimer, NULL, 0); + timer_setup(&rose_neigh->t0timer, NULL, 0); + + if (rose_route->ndigis != 0) { + rose_neigh->digipeat = + kmalloc(sizeof(ax25_digi), GFP_ATOMIC); + if (rose_neigh->digipeat == NULL) { + kfree(rose_neigh); + res = -ENOMEM; + goto out; + } + + rose_neigh->digipeat->ndigi = rose_route->ndigis; + rose_neigh->digipeat->lastrepeat = -1; + + for (i = 0; i < rose_route->ndigis; i++) { + rose_neigh->digipeat->calls[i] = + rose_route->digipeaters[i]; + rose_neigh->digipeat->repeated[i] = 0; + } + } + + rose_neigh->next = rose_neigh_list; + rose_neigh_list = rose_neigh; + } + + /* + * This is a new node to be inserted into the list. Find where it needs + * to be inserted into the list, and insert it. 
We want to be sure + * to order the list in descending order of mask size to ensure that + * later when we are searching this list the first match will be the + * best match. + */ + if (rose_node == NULL) { + rose_tmpn = rose_node_list; + rose_tmpp = NULL; + + while (rose_tmpn != NULL) { + if (rose_tmpn->mask > rose_route->mask) { + rose_tmpp = rose_tmpn; + rose_tmpn = rose_tmpn->next; + } else { + break; + } + } + + /* create new node */ + rose_node = kmalloc(sizeof(*rose_node), GFP_ATOMIC); + if (rose_node == NULL) { + res = -ENOMEM; + goto out; + } + + rose_node->address = rose_route->address; + rose_node->mask = rose_route->mask; + rose_node->count = 1; + rose_node->loopback = 0; + rose_node->neighbour[0] = rose_neigh; + + if (rose_tmpn == NULL) { + if (rose_tmpp == NULL) { /* Empty list */ + rose_node_list = rose_node; + rose_node->next = NULL; + } else { + rose_tmpp->next = rose_node; + rose_node->next = NULL; + } + } else { + if (rose_tmpp == NULL) { /* 1st node */ + rose_node->next = rose_node_list; + rose_node_list = rose_node; + } else { + rose_tmpp->next = rose_node; + rose_node->next = rose_tmpn; + } + } + rose_neigh->count++; + + goto out; + } + + /* We have space, slot it in */ + if (rose_node->count < 3) { + rose_node->neighbour[rose_node->count] = rose_neigh; + rose_node->count++; + rose_neigh->count++; + } + +out: + spin_unlock_bh(&rose_neigh_list_lock); + spin_unlock_bh(&rose_node_list_lock); + + return res; +} + +/* + * Caller is holding rose_node_list_lock. + */ +static void rose_remove_node(struct rose_node *rose_node) +{ + struct rose_node *s; + + if ((s = rose_node_list) == rose_node) { + rose_node_list = rose_node->next; + kfree(rose_node); + return; + } + + while (s != NULL && s->next != NULL) { + if (s->next == rose_node) { + s->next = rose_node->next; + kfree(rose_node); + return; + } + + s = s->next; + } +} + +/* + * Caller is holding rose_neigh_list_lock. + */ +static void rose_remove_neigh(struct rose_neigh *rose_neigh) +{ + struct rose_neigh *s; + + del_timer_sync(&rose_neigh->ftimer); + del_timer_sync(&rose_neigh->t0timer); + + skb_queue_purge(&rose_neigh->queue); + + if ((s = rose_neigh_list) == rose_neigh) { + rose_neigh_list = rose_neigh->next; + if (rose_neigh->ax25) + ax25_cb_put(rose_neigh->ax25); + kfree(rose_neigh->digipeat); + kfree(rose_neigh); + return; + } + + while (s != NULL && s->next != NULL) { + if (s->next == rose_neigh) { + s->next = rose_neigh->next; + if (rose_neigh->ax25) + ax25_cb_put(rose_neigh->ax25); + kfree(rose_neigh->digipeat); + kfree(rose_neigh); + return; + } + + s = s->next; + } +} + +/* + * Caller is holding rose_route_list_lock. + */ +static void rose_remove_route(struct rose_route *rose_route) +{ + struct rose_route *s; + + if (rose_route->neigh1 != NULL) + rose_route->neigh1->use--; + + if (rose_route->neigh2 != NULL) + rose_route->neigh2->use--; + + if ((s = rose_route_list) == rose_route) { + rose_route_list = rose_route->next; + kfree(rose_route); + return; + } + + while (s != NULL && s->next != NULL) { + if (s->next == rose_route) { + s->next = rose_route->next; + kfree(rose_route); + return; + } + + s = s->next; + } +} + +/* + * "Delete" a node. Strictly speaking remove a route to a node. The node + * is only deleted if no routes are left to it. 
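[Editorial sketch] rose_add_node() above keeps the node list sorted by descending mask precisely so that the first rosecmpm() match found while routing is also the most specific one. The sketch below illustrates that ordering with a nibble-wise prefix comparison; it assumes ROSE addresses are ten BCD digits packed two per byte, and the helper is written in the spirit of rosecmpm() rather than copying it.

#include <stdio.h>

struct node {
        unsigned char addr[5];  /* ten BCD digits, two per byte (assumed) */
        int mask;               /* number of leading digits that must match */
};

/* Compare the first "mask" digits of two packed addresses. */
static int prefix_match(const unsigned char *a, const unsigned char *b, int mask)
{
        for (int i = 0; i < mask; i++) {
                int da = (i & 1) ? (a[i / 2] & 0x0F) : (a[i / 2] >> 4);
                int db = (i & 1) ? (b[i / 2] & 0x0F) : (b[i / 2] >> 4);

                if (da != db)
                        return 0;
        }
        return 1;
}

int main(void)
{
        /* Already sorted by descending mask, as rose_add_node() keeps it. */
        struct node table[] = {
                { { 0x23, 0x42, 0x31, 0x00, 0x00 }, 6 },
                { { 0x23, 0x42, 0x00, 0x00, 0x00 }, 4 },
                { { 0x00, 0x00, 0x00, 0x00, 0x00 }, 0 },        /* catch-all */
        };
        unsigned char dest[5] = { 0x23, 0x42, 0x31, 0x57, 0x89 };

        for (size_t i = 0; i < sizeof(table) / sizeof(table[0]); i++) {
                if (prefix_match(dest, table[i].addr, table[i].mask)) {
                        printf("first (best) match: entry %zu, mask %d\n",
                               i, table[i].mask);
                        break;  /* first match is the best match */
                }
        }
        return 0;
}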
+ */ +static int rose_del_node(struct rose_route_struct *rose_route, + struct net_device *dev) +{ + struct rose_node *rose_node; + struct rose_neigh *rose_neigh; + int i, err = 0; + + spin_lock_bh(&rose_node_list_lock); + spin_lock_bh(&rose_neigh_list_lock); + + rose_node = rose_node_list; + while (rose_node != NULL) { + if ((rose_node->mask == rose_route->mask) && + (rosecmpm(&rose_route->address, &rose_node->address, + rose_route->mask) == 0)) + break; + rose_node = rose_node->next; + } + + if (rose_node == NULL || rose_node->loopback) { + err = -EINVAL; + goto out; + } + + rose_neigh = rose_neigh_list; + while (rose_neigh != NULL) { + if (ax25cmp(&rose_route->neighbour, + &rose_neigh->callsign) == 0 && + rose_neigh->dev == dev) + break; + rose_neigh = rose_neigh->next; + } + + if (rose_neigh == NULL) { + err = -EINVAL; + goto out; + } + + for (i = 0; i < rose_node->count; i++) { + if (rose_node->neighbour[i] == rose_neigh) { + rose_neigh->count--; + + if (rose_neigh->count == 0 && rose_neigh->use == 0) + rose_remove_neigh(rose_neigh); + + rose_node->count--; + + if (rose_node->count == 0) { + rose_remove_node(rose_node); + } else { + switch (i) { + case 0: + rose_node->neighbour[0] = + rose_node->neighbour[1]; + fallthrough; + case 1: + rose_node->neighbour[1] = + rose_node->neighbour[2]; + break; + case 2: + break; + } + } + goto out; + } + } + err = -EINVAL; + +out: + spin_unlock_bh(&rose_neigh_list_lock); + spin_unlock_bh(&rose_node_list_lock); + + return err; +} + +/* + * Add the loopback neighbour. + */ +void rose_add_loopback_neigh(void) +{ + struct rose_neigh *sn; + + rose_loopback_neigh = kmalloc(sizeof(struct rose_neigh), GFP_KERNEL); + if (!rose_loopback_neigh) + return; + sn = rose_loopback_neigh; + + sn->callsign = null_ax25_address; + sn->digipeat = NULL; + sn->ax25 = NULL; + sn->dev = NULL; + sn->count = 0; + sn->use = 0; + sn->dce_mode = 1; + sn->loopback = 1; + sn->number = rose_neigh_no++; + sn->restarted = 1; + + skb_queue_head_init(&sn->queue); + + timer_setup(&sn->ftimer, NULL, 0); + timer_setup(&sn->t0timer, NULL, 0); + + spin_lock_bh(&rose_neigh_list_lock); + sn->next = rose_neigh_list; + rose_neigh_list = sn; + spin_unlock_bh(&rose_neigh_list_lock); +} + +/* + * Add a loopback node. + */ +int rose_add_loopback_node(const rose_address *address) +{ + struct rose_node *rose_node; + int err = 0; + + spin_lock_bh(&rose_node_list_lock); + + rose_node = rose_node_list; + while (rose_node != NULL) { + if ((rose_node->mask == 10) && + (rosecmpm(address, &rose_node->address, 10) == 0) && + rose_node->loopback) + break; + rose_node = rose_node->next; + } + + if (rose_node != NULL) + goto out; + + if ((rose_node = kmalloc(sizeof(*rose_node), GFP_ATOMIC)) == NULL) { + err = -ENOMEM; + goto out; + } + + rose_node->address = *address; + rose_node->mask = 10; + rose_node->count = 1; + rose_node->loopback = 1; + rose_node->neighbour[0] = rose_loopback_neigh; + + /* Insert at the head of list. Address is always mask=10 */ + rose_node->next = rose_node_list; + rose_node_list = rose_node; + + rose_loopback_neigh->count++; + +out: + spin_unlock_bh(&rose_node_list_lock); + + return err; +} + +/* + * Delete a loopback node. 
+ */ +void rose_del_loopback_node(const rose_address *address) +{ + struct rose_node *rose_node; + + spin_lock_bh(&rose_node_list_lock); + + rose_node = rose_node_list; + while (rose_node != NULL) { + if ((rose_node->mask == 10) && + (rosecmpm(address, &rose_node->address, 10) == 0) && + rose_node->loopback) + break; + rose_node = rose_node->next; + } + + if (rose_node == NULL) + goto out; + + rose_remove_node(rose_node); + + rose_loopback_neigh->count--; + +out: + spin_unlock_bh(&rose_node_list_lock); +} + +/* + * A device has been removed. Remove its routes and neighbours. + */ +void rose_rt_device_down(struct net_device *dev) +{ + struct rose_neigh *s, *rose_neigh; + struct rose_node *t, *rose_node; + int i; + + spin_lock_bh(&rose_node_list_lock); + spin_lock_bh(&rose_neigh_list_lock); + rose_neigh = rose_neigh_list; + while (rose_neigh != NULL) { + s = rose_neigh; + rose_neigh = rose_neigh->next; + + if (s->dev != dev) + continue; + + rose_node = rose_node_list; + + while (rose_node != NULL) { + t = rose_node; + rose_node = rose_node->next; + + for (i = 0; i < t->count; i++) { + if (t->neighbour[i] != s) + continue; + + t->count--; + + switch (i) { + case 0: + t->neighbour[0] = t->neighbour[1]; + fallthrough; + case 1: + t->neighbour[1] = t->neighbour[2]; + break; + case 2: + break; + } + } + + if (t->count <= 0) + rose_remove_node(t); + } + + rose_remove_neigh(s); + } + spin_unlock_bh(&rose_neigh_list_lock); + spin_unlock_bh(&rose_node_list_lock); +} + +#if 0 /* Currently unused */ +/* + * A device has been removed. Remove its links. + */ +void rose_route_device_down(struct net_device *dev) +{ + struct rose_route *s, *rose_route; + + spin_lock_bh(&rose_route_list_lock); + rose_route = rose_route_list; + while (rose_route != NULL) { + s = rose_route; + rose_route = rose_route->next; + + if (s->neigh1->dev == dev || s->neigh2->dev == dev) + rose_remove_route(s); + } + spin_unlock_bh(&rose_route_list_lock); +} +#endif + +/* + * Clear all nodes and neighbours out, except for neighbours with + * active connections going through them. + * Do not clear loopback neighbour and nodes. + */ +static int rose_clear_routes(void) +{ + struct rose_neigh *s, *rose_neigh; + struct rose_node *t, *rose_node; + + spin_lock_bh(&rose_node_list_lock); + spin_lock_bh(&rose_neigh_list_lock); + + rose_neigh = rose_neigh_list; + rose_node = rose_node_list; + + while (rose_node != NULL) { + t = rose_node; + rose_node = rose_node->next; + if (!t->loopback) + rose_remove_node(t); + } + + while (rose_neigh != NULL) { + s = rose_neigh; + rose_neigh = rose_neigh->next; + + if (s->use == 0 && !s->loopback) { + s->count = 0; + rose_remove_neigh(s); + } + } + + spin_unlock_bh(&rose_neigh_list_lock); + spin_unlock_bh(&rose_node_list_lock); + + return 0; +} + +/* + * Check that the device given is a valid AX.25 interface that is "up". + * called with RTNL + */ +static struct net_device *rose_ax25_dev_find(char *devname) +{ + struct net_device *dev; + + if ((dev = __dev_get_by_name(&init_net, devname)) == NULL) + return NULL; + + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_AX25) + return dev; + + return NULL; +} + +/* + * Find the first active ROSE device, usually "rose0". 
+ */ +struct net_device *rose_dev_first(void) +{ + struct net_device *dev, *first = NULL; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_ROSE) + if (first == NULL || strncmp(dev->name, first->name, 3) < 0) + first = dev; + } + if (first) + dev_hold(first); + rcu_read_unlock(); + + return first; +} + +/* + * Find the ROSE device for the given address. + */ +struct net_device *rose_dev_get(rose_address *addr) +{ + struct net_device *dev; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_ROSE && + rosecmp(addr, (const rose_address *)dev->dev_addr) == 0) { + dev_hold(dev); + goto out; + } + } + dev = NULL; +out: + rcu_read_unlock(); + return dev; +} + +static int rose_dev_exists(rose_address *addr) +{ + struct net_device *dev; + + rcu_read_lock(); + for_each_netdev_rcu(&init_net, dev) { + if ((dev->flags & IFF_UP) && dev->type == ARPHRD_ROSE && + rosecmp(addr, (const rose_address *)dev->dev_addr) == 0) + goto out; + } + dev = NULL; +out: + rcu_read_unlock(); + return dev != NULL; +} + + + + +struct rose_route *rose_route_free_lci(unsigned int lci, struct rose_neigh *neigh) +{ + struct rose_route *rose_route; + + for (rose_route = rose_route_list; rose_route != NULL; rose_route = rose_route->next) + if ((rose_route->neigh1 == neigh && rose_route->lci1 == lci) || + (rose_route->neigh2 == neigh && rose_route->lci2 == lci)) + return rose_route; + + return NULL; +} + +/* + * Find a neighbour or a route given a ROSE address. + */ +struct rose_neigh *rose_get_neigh(rose_address *addr, unsigned char *cause, + unsigned char *diagnostic, int route_frame) +{ + struct rose_neigh *res = NULL; + struct rose_node *node; + int failed = 0; + int i; + + if (!route_frame) spin_lock_bh(&rose_node_list_lock); + for (node = rose_node_list; node != NULL; node = node->next) { + if (rosecmpm(addr, &node->address, node->mask) == 0) { + for (i = 0; i < node->count; i++) { + if (node->neighbour[i]->restarted) { + res = node->neighbour[i]; + goto out; + } + } + } + } + if (!route_frame) { /* connect request */ + for (node = rose_node_list; node != NULL; node = node->next) { + if (rosecmpm(addr, &node->address, node->mask) == 0) { + for (i = 0; i < node->count; i++) { + if (!rose_ftimer_running(node->neighbour[i])) { + res = node->neighbour[i]; + goto out; + } + failed = 1; + } + } + } + } + + if (failed) { + *cause = ROSE_OUT_OF_ORDER; + *diagnostic = 0; + } else { + *cause = ROSE_NOT_OBTAINABLE; + *diagnostic = 0; + } + +out: + if (!route_frame) spin_unlock_bh(&rose_node_list_lock); + return res; +} + +/* + * Handle the ioctls that control the routing functions. 
+ */ +int rose_rt_ioctl(unsigned int cmd, void __user *arg) +{ + struct rose_route_struct rose_route; + struct net_device *dev; + int err; + + switch (cmd) { + case SIOCADDRT: + if (copy_from_user(&rose_route, arg, sizeof(struct rose_route_struct))) + return -EFAULT; + if ((dev = rose_ax25_dev_find(rose_route.device)) == NULL) + return -EINVAL; + if (rose_dev_exists(&rose_route.address)) /* Can't add routes to ourself */ + return -EINVAL; + if (rose_route.mask > 10) /* Mask can't be more than 10 digits */ + return -EINVAL; + if (rose_route.ndigis > AX25_MAX_DIGIS) + return -EINVAL; + err = rose_add_node(&rose_route, dev); + return err; + + case SIOCDELRT: + if (copy_from_user(&rose_route, arg, sizeof(struct rose_route_struct))) + return -EFAULT; + if ((dev = rose_ax25_dev_find(rose_route.device)) == NULL) + return -EINVAL; + err = rose_del_node(&rose_route, dev); + return err; + + case SIOCRSCLRRT: + return rose_clear_routes(); + + default: + return -EINVAL; + } + + return 0; +} + +static void rose_del_route_by_neigh(struct rose_neigh *rose_neigh) +{ + struct rose_route *rose_route, *s; + + rose_neigh->restarted = 0; + + rose_stop_t0timer(rose_neigh); + rose_start_ftimer(rose_neigh); + + skb_queue_purge(&rose_neigh->queue); + + spin_lock_bh(&rose_route_list_lock); + + rose_route = rose_route_list; + + while (rose_route != NULL) { + if ((rose_route->neigh1 == rose_neigh && rose_route->neigh2 == rose_neigh) || + (rose_route->neigh1 == rose_neigh && rose_route->neigh2 == NULL) || + (rose_route->neigh2 == rose_neigh && rose_route->neigh1 == NULL)) { + s = rose_route->next; + rose_remove_route(rose_route); + rose_route = s; + continue; + } + + if (rose_route->neigh1 == rose_neigh) { + rose_route->neigh1->use--; + rose_route->neigh1 = NULL; + rose_transmit_clear_request(rose_route->neigh2, rose_route->lci2, ROSE_OUT_OF_ORDER, 0); + } + + if (rose_route->neigh2 == rose_neigh) { + rose_route->neigh2->use--; + rose_route->neigh2 = NULL; + rose_transmit_clear_request(rose_route->neigh1, rose_route->lci1, ROSE_OUT_OF_ORDER, 0); + } + + rose_route = rose_route->next; + } + spin_unlock_bh(&rose_route_list_lock); +} + +/* + * A level 2 link has timed out, therefore it appears to be a poor link, + * then don't use that neighbour until it is reset. Blow away all through + * routes and connections using this route. + */ +void rose_link_failed(ax25_cb *ax25, int reason) +{ + struct rose_neigh *rose_neigh; + + spin_lock_bh(&rose_neigh_list_lock); + rose_neigh = rose_neigh_list; + while (rose_neigh != NULL) { + if (rose_neigh->ax25 == ax25) + break; + rose_neigh = rose_neigh->next; + } + + if (rose_neigh != NULL) { + rose_neigh->ax25 = NULL; + ax25_cb_put(ax25); + + rose_del_route_by_neigh(rose_neigh); + rose_kill_by_neigh(rose_neigh); + } + spin_unlock_bh(&rose_neigh_list_lock); +} + +/* + * A device has been "downed" remove its link status. Blow away all + * through routes and connections that use this device. + */ +void rose_link_device_down(struct net_device *dev) +{ + struct rose_neigh *rose_neigh; + + for (rose_neigh = rose_neigh_list; rose_neigh != NULL; rose_neigh = rose_neigh->next) { + if (rose_neigh->dev == dev) { + rose_del_route_by_neigh(rose_neigh); + rose_kill_by_neigh(rose_neigh); + } + } +} + +/* + * Route a frame to an appropriate AX.25 connection. + * A NULL ax25_cb indicates an internally generated frame. 
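The SIOCADDRT/SIOCDELRT cases above are reached from userspace through an ioctl on an AF_ROSE socket. A minimal userspace sketch of adding a static route through this interface, assuming the uapi headers and addresses already encoded in their binary on-air forms (the helper name add_rose_route is illustrative, not part of this code):

	#include <string.h>
	#include <sys/ioctl.h>
	#include <sys/socket.h>
	#include <linux/sockios.h>	/* SIOCADDRT */
	#include <linux/ax25.h>		/* ax25_address */
	#include <linux/rose.h>		/* rose_address, struct rose_route_struct */

	/* sk is an AF_ROSE socket, e.g. socket(AF_ROSE, SOCK_SEQPACKET, 0). */
	static int add_rose_route(int sk, const rose_address *dest, unsigned short mask,
				  const ax25_address *neigh, const char *axport)
	{
		struct rose_route_struct rt;

		memset(&rt, 0, sizeof(rt));
		rt.address   = *dest;		/* destination ROSE address */
		rt.mask      = mask;		/* significant digits, at most 10 */
		rt.neighbour = *neigh;		/* AX.25 callsign of the next hop */
		strncpy(rt.device, axport, sizeof(rt.device) - 1);	/* e.g. "ax0" */
		rt.ndigis    = 0;		/* no digipeaters */

		return ioctl(sk, SIOCADDRT, &rt);	/* dispatched to rose_rt_ioctl() */
	}

The kernel side then validates the mask (<= 10 digits) and digipeater count (<= AX25_MAX_DIGIS) before handing the entry to rose_add_node().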
+ */ +int rose_route_frame(struct sk_buff *skb, ax25_cb *ax25) +{ + struct rose_neigh *rose_neigh, *new_neigh; + struct rose_route *rose_route; + struct rose_facilities_struct facilities; + rose_address *src_addr, *dest_addr; + struct sock *sk; + unsigned short frametype; + unsigned int lci, new_lci; + unsigned char cause, diagnostic; + struct net_device *dev; + int res = 0; + char buf[11]; + + if (skb->len < ROSE_MIN_LEN) + return res; + + if (!ax25) + return rose_loopback_queue(skb, NULL); + + frametype = skb->data[2]; + lci = ((skb->data[0] << 8) & 0xF00) + ((skb->data[1] << 0) & 0x0FF); + if (frametype == ROSE_CALL_REQUEST && + (skb->len <= ROSE_CALL_REQ_FACILITIES_OFF || + skb->data[ROSE_CALL_REQ_ADDR_LEN_OFF] != + ROSE_CALL_REQ_ADDR_LEN_VAL)) + return res; + src_addr = (rose_address *)(skb->data + ROSE_CALL_REQ_SRC_ADDR_OFF); + dest_addr = (rose_address *)(skb->data + ROSE_CALL_REQ_DEST_ADDR_OFF); + + spin_lock_bh(&rose_neigh_list_lock); + spin_lock_bh(&rose_route_list_lock); + + rose_neigh = rose_neigh_list; + while (rose_neigh != NULL) { + if (ax25cmp(&ax25->dest_addr, &rose_neigh->callsign) == 0 && + ax25->ax25_dev->dev == rose_neigh->dev) + break; + rose_neigh = rose_neigh->next; + } + + if (rose_neigh == NULL) { + printk("rose_route : unknown neighbour or device %s\n", + ax2asc(buf, &ax25->dest_addr)); + goto out; + } + + /* + * Obviously the link is working, halt the ftimer. + */ + rose_stop_ftimer(rose_neigh); + + /* + * LCI of zero is always for us, and its always a restart + * frame. + */ + if (lci == 0) { + rose_link_rx_restart(skb, rose_neigh, frametype); + goto out; + } + + /* + * Find an existing socket. + */ + if ((sk = rose_find_socket(lci, rose_neigh)) != NULL) { + if (frametype == ROSE_CALL_REQUEST) { + struct rose_sock *rose = rose_sk(sk); + + /* Remove an existing unused socket */ + rose_clear_queues(sk); + rose->cause = ROSE_NETWORK_CONGESTION; + rose->diagnostic = 0; + rose->neighbour->use--; + rose->neighbour = NULL; + rose->lci = 0; + rose->state = ROSE_STATE_0; + sk->sk_state = TCP_CLOSE; + sk->sk_err = 0; + sk->sk_shutdown |= SEND_SHUTDOWN; + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } + } + else { + skb_reset_transport_header(skb); + res = rose_process_rx_frame(sk, skb); + goto out; + } + } + + /* + * Is is a Call Request and is it for us ? + */ + if (frametype == ROSE_CALL_REQUEST) + if ((dev = rose_dev_get(dest_addr)) != NULL) { + res = rose_rx_call_request(skb, dev, rose_neigh, lci); + dev_put(dev); + goto out; + } + + if (!sysctl_rose_routing_control) { + rose_transmit_clear_request(rose_neigh, lci, ROSE_NOT_OBTAINABLE, 0); + goto out; + } + + /* + * Route it to the next in line if we have an entry for it. 
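The header parsing above keeps the 12-bit LCI in the low nibble of byte 0 (the high nibble carries the GFI) plus all of byte 1, and the forwarding path below rewrites exactly those bits when it switches a frame onto the outgoing LCI. Two small helpers, not defined in this file, that mirror the shifts used here:

	/* Sketch of the LCI byte layout used by rose_route_frame():
	 * p[0] = GFI (high nibble) | LCI bits 11..8, p[1] = LCI bits 7..0.
	 */
	static inline unsigned int rose_hdr_get_lci(const unsigned char *p)
	{
		return ((p[0] << 8) & 0xF00) | (p[1] & 0xFF);
	}

	static inline void rose_hdr_set_lci(unsigned char *p, unsigned int lci)
	{
		p[0] = (p[0] & 0xF0) | ((lci >> 8) & 0x0F);
		p[1] = lci & 0xFF;
	}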
+ */ + rose_route = rose_route_list; + while (rose_route != NULL) { + if (rose_route->lci1 == lci && + rose_route->neigh1 == rose_neigh) { + if (frametype == ROSE_CALL_REQUEST) { + /* F6FBB - Remove an existing unused route */ + rose_remove_route(rose_route); + break; + } else if (rose_route->neigh2 != NULL) { + skb->data[0] &= 0xF0; + skb->data[0] |= (rose_route->lci2 >> 8) & 0x0F; + skb->data[1] = (rose_route->lci2 >> 0) & 0xFF; + rose_transmit_link(skb, rose_route->neigh2); + if (frametype == ROSE_CLEAR_CONFIRMATION) + rose_remove_route(rose_route); + res = 1; + goto out; + } else { + if (frametype == ROSE_CLEAR_CONFIRMATION) + rose_remove_route(rose_route); + goto out; + } + } + if (rose_route->lci2 == lci && + rose_route->neigh2 == rose_neigh) { + if (frametype == ROSE_CALL_REQUEST) { + /* F6FBB - Remove an existing unused route */ + rose_remove_route(rose_route); + break; + } else if (rose_route->neigh1 != NULL) { + skb->data[0] &= 0xF0; + skb->data[0] |= (rose_route->lci1 >> 8) & 0x0F; + skb->data[1] = (rose_route->lci1 >> 0) & 0xFF; + rose_transmit_link(skb, rose_route->neigh1); + if (frametype == ROSE_CLEAR_CONFIRMATION) + rose_remove_route(rose_route); + res = 1; + goto out; + } else { + if (frametype == ROSE_CLEAR_CONFIRMATION) + rose_remove_route(rose_route); + goto out; + } + } + rose_route = rose_route->next; + } + + /* + * We know that: + * 1. The frame isn't for us, + * 2. It isn't "owned" by any existing route. + */ + if (frametype != ROSE_CALL_REQUEST) { /* XXX */ + res = 0; + goto out; + } + + memset(&facilities, 0x00, sizeof(struct rose_facilities_struct)); + + if (!rose_parse_facilities(skb->data + ROSE_CALL_REQ_FACILITIES_OFF, + skb->len - ROSE_CALL_REQ_FACILITIES_OFF, + &facilities)) { + rose_transmit_clear_request(rose_neigh, lci, ROSE_INVALID_FACILITY, 76); + goto out; + } + + /* + * Check for routing loops. 
+ */ + rose_route = rose_route_list; + while (rose_route != NULL) { + if (rose_route->rand == facilities.rand && + rosecmp(src_addr, &rose_route->src_addr) == 0 && + ax25cmp(&facilities.dest_call, &rose_route->src_call) == 0 && + ax25cmp(&facilities.source_call, &rose_route->dest_call) == 0) { + rose_transmit_clear_request(rose_neigh, lci, ROSE_NOT_OBTAINABLE, 120); + goto out; + } + rose_route = rose_route->next; + } + + if ((new_neigh = rose_get_neigh(dest_addr, &cause, &diagnostic, 1)) == NULL) { + rose_transmit_clear_request(rose_neigh, lci, cause, diagnostic); + goto out; + } + + if ((new_lci = rose_new_lci(new_neigh)) == 0) { + rose_transmit_clear_request(rose_neigh, lci, ROSE_NETWORK_CONGESTION, 71); + goto out; + } + + if ((rose_route = kmalloc(sizeof(*rose_route), GFP_ATOMIC)) == NULL) { + rose_transmit_clear_request(rose_neigh, lci, ROSE_NETWORK_CONGESTION, 120); + goto out; + } + + rose_route->lci1 = lci; + rose_route->src_addr = *src_addr; + rose_route->dest_addr = *dest_addr; + rose_route->src_call = facilities.dest_call; + rose_route->dest_call = facilities.source_call; + rose_route->rand = facilities.rand; + rose_route->neigh1 = rose_neigh; + rose_route->lci2 = new_lci; + rose_route->neigh2 = new_neigh; + + rose_route->neigh1->use++; + rose_route->neigh2->use++; + + rose_route->next = rose_route_list; + rose_route_list = rose_route; + + skb->data[0] &= 0xF0; + skb->data[0] |= (rose_route->lci2 >> 8) & 0x0F; + skb->data[1] = (rose_route->lci2 >> 0) & 0xFF; + + rose_transmit_link(skb, rose_route->neigh2); + res = 1; + +out: + spin_unlock_bh(&rose_route_list_lock); + spin_unlock_bh(&rose_neigh_list_lock); + + return res; +} + +#ifdef CONFIG_PROC_FS + +static void *rose_node_start(struct seq_file *seq, loff_t *pos) + __acquires(rose_node_list_lock) +{ + struct rose_node *rose_node; + int i = 1; + + spin_lock_bh(&rose_node_list_lock); + if (*pos == 0) + return SEQ_START_TOKEN; + + for (rose_node = rose_node_list; rose_node && i < *pos; + rose_node = rose_node->next, ++i); + + return (i == *pos) ? rose_node : NULL; +} + +static void *rose_node_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + + return (v == SEQ_START_TOKEN) ? rose_node_list + : ((struct rose_node *)v)->next; +} + +static void rose_node_stop(struct seq_file *seq, void *v) + __releases(rose_node_list_lock) +{ + spin_unlock_bh(&rose_node_list_lock); +} + +static int rose_node_show(struct seq_file *seq, void *v) +{ + char rsbuf[11]; + int i; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, "address mask n neigh neigh neigh\n"); + else { + const struct rose_node *rose_node = v; + seq_printf(seq, "%-10s %04d %d", + rose2asc(rsbuf, &rose_node->address), + rose_node->mask, + rose_node->count); + + for (i = 0; i < rose_node->count; i++) + seq_printf(seq, " %05d", rose_node->neighbour[i]->number); + + seq_puts(seq, "\n"); + } + return 0; +} + +const struct seq_operations rose_node_seqops = { + .start = rose_node_start, + .next = rose_node_next, + .stop = rose_node_stop, + .show = rose_node_show, +}; + +static void *rose_neigh_start(struct seq_file *seq, loff_t *pos) + __acquires(rose_neigh_list_lock) +{ + struct rose_neigh *rose_neigh; + int i = 1; + + spin_lock_bh(&rose_neigh_list_lock); + if (*pos == 0) + return SEQ_START_TOKEN; + + for (rose_neigh = rose_neigh_list; rose_neigh && i < *pos; + rose_neigh = rose_neigh->next, ++i); + + return (i == *pos) ? rose_neigh : NULL; +} + +static void *rose_neigh_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + + return (v == SEQ_START_TOKEN) ? 
rose_neigh_list + : ((struct rose_neigh *)v)->next; +} + +static void rose_neigh_stop(struct seq_file *seq, void *v) + __releases(rose_neigh_list_lock) +{ + spin_unlock_bh(&rose_neigh_list_lock); +} + +static int rose_neigh_show(struct seq_file *seq, void *v) +{ + char buf[11]; + int i; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "addr callsign dev count use mode restart t0 tf digipeaters\n"); + else { + struct rose_neigh *rose_neigh = v; + + /* if (!rose_neigh->loopback) { */ + seq_printf(seq, "%05d %-9s %-4s %3d %3d %3s %3s %3lu %3lu", + rose_neigh->number, + (rose_neigh->loopback) ? "RSLOOP-0" : ax2asc(buf, &rose_neigh->callsign), + rose_neigh->dev ? rose_neigh->dev->name : "???", + rose_neigh->count, + rose_neigh->use, + (rose_neigh->dce_mode) ? "DCE" : "DTE", + (rose_neigh->restarted) ? "yes" : "no", + ax25_display_timer(&rose_neigh->t0timer) / HZ, + ax25_display_timer(&rose_neigh->ftimer) / HZ); + + if (rose_neigh->digipeat != NULL) { + for (i = 0; i < rose_neigh->digipeat->ndigi; i++) + seq_printf(seq, " %s", ax2asc(buf, &rose_neigh->digipeat->calls[i])); + } + + seq_puts(seq, "\n"); + } + return 0; +} + + +const struct seq_operations rose_neigh_seqops = { + .start = rose_neigh_start, + .next = rose_neigh_next, + .stop = rose_neigh_stop, + .show = rose_neigh_show, +}; + +static void *rose_route_start(struct seq_file *seq, loff_t *pos) + __acquires(rose_route_list_lock) +{ + struct rose_route *rose_route; + int i = 1; + + spin_lock_bh(&rose_route_list_lock); + if (*pos == 0) + return SEQ_START_TOKEN; + + for (rose_route = rose_route_list; rose_route && i < *pos; + rose_route = rose_route->next, ++i); + + return (i == *pos) ? rose_route : NULL; +} + +static void *rose_route_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + + return (v == SEQ_START_TOKEN) ? rose_route_list + : ((struct rose_route *)v)->next; +} + +static void rose_route_stop(struct seq_file *seq, void *v) + __releases(rose_route_list_lock) +{ + spin_unlock_bh(&rose_route_list_lock); +} + +static int rose_route_show(struct seq_file *seq, void *v) +{ + char buf[11], rsbuf[11]; + + if (v == SEQ_START_TOKEN) + seq_puts(seq, + "lci address callsign neigh <-> lci address callsign neigh\n"); + else { + struct rose_route *rose_route = v; + + if (rose_route->neigh1) + seq_printf(seq, + "%3.3X %-10s %-9s %05d ", + rose_route->lci1, + rose2asc(rsbuf, &rose_route->src_addr), + ax2asc(buf, &rose_route->src_call), + rose_route->neigh1->number); + else + seq_puts(seq, + "000 * * 00000 "); + + if (rose_route->neigh2) + seq_printf(seq, + "%3.3X %-10s %-9s %05d\n", + rose_route->lci2, + rose2asc(rsbuf, &rose_route->dest_addr), + ax2asc(buf, &rose_route->dest_call), + rose_route->neigh2->number); + else + seq_puts(seq, + "000 * * 00000\n"); + } + return 0; +} + +struct seq_operations rose_route_seqops = { + .start = rose_route_start, + .next = rose_route_next, + .stop = rose_route_stop, + .show = rose_route_show, +}; +#endif /* CONFIG_PROC_FS */ + +/* + * Release all memory associated with ROSE routing structures. 
+ */ +void __exit rose_rt_free(void) +{ + struct rose_neigh *s, *rose_neigh = rose_neigh_list; + struct rose_node *t, *rose_node = rose_node_list; + struct rose_route *u, *rose_route = rose_route_list; + + while (rose_neigh != NULL) { + s = rose_neigh; + rose_neigh = rose_neigh->next; + + rose_remove_neigh(s); + } + + while (rose_node != NULL) { + t = rose_node; + rose_node = rose_node->next; + + rose_remove_node(t); + } + + while (rose_route != NULL) { + u = rose_route; + rose_route = rose_route->next; + + rose_remove_route(u); + } +} diff --git a/net/rose/rose_subr.c b/net/rose/rose_subr.c new file mode 100644 index 000000000..4dbc437a9 --- /dev/null +++ b/net/rose/rose_subr.c @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int rose_create_facilities(unsigned char *buffer, struct rose_sock *rose); + +/* + * This routine purges all of the queues of frames. + */ +void rose_clear_queues(struct sock *sk) +{ + skb_queue_purge(&sk->sk_write_queue); + skb_queue_purge(&rose_sk(sk)->ack_queue); +} + +/* + * This routine purges the input queue of those frames that have been + * acknowledged. This replaces the boxes labelled "V(a) <- N(r)" on the + * SDL diagram. + */ +void rose_frames_acked(struct sock *sk, unsigned short nr) +{ + struct sk_buff *skb; + struct rose_sock *rose = rose_sk(sk); + + /* + * Remove all the ack-ed frames from the ack queue. + */ + if (rose->va != nr) { + while (skb_peek(&rose->ack_queue) != NULL && rose->va != nr) { + skb = skb_dequeue(&rose->ack_queue); + kfree_skb(skb); + rose->va = (rose->va + 1) % ROSE_MODULUS; + } + } +} + +void rose_requeue_frames(struct sock *sk) +{ + struct sk_buff *skb, *skb_prev = NULL; + + /* + * Requeue all the un-ack-ed frames on the output queue to be picked + * up by rose_kick. This arrangement handles the possibility of an + * empty output queue. + */ + while ((skb = skb_dequeue(&rose_sk(sk)->ack_queue)) != NULL) { + if (skb_prev == NULL) + skb_queue_head(&sk->sk_write_queue, skb); + else + skb_append(skb_prev, skb, &sk->sk_write_queue); + skb_prev = skb; + } +} + +/* + * Validate that the value of nr is between va and vs. Return true or + * false for testing. + */ +int rose_validate_nr(struct sock *sk, unsigned short nr) +{ + struct rose_sock *rose = rose_sk(sk); + unsigned short vc = rose->va; + + while (vc != rose->vs) { + if (nr == vc) return 1; + vc = (vc + 1) % ROSE_MODULUS; + } + + return nr == rose->vs; +} + +/* + * This routine is called when the packet layer internally generates a + * control frame. + */ +void rose_write_internal(struct sock *sk, int frametype) +{ + struct rose_sock *rose = rose_sk(sk); + struct sk_buff *skb; + unsigned char *dptr; + unsigned char lci1, lci2; + int maxfaclen = 0; + int len, faclen; + int reserve; + + reserve = AX25_BPQ_HEADER_LEN + AX25_MAX_HEADER_LEN + 1; + len = ROSE_MIN_LEN; + + switch (frametype) { + case ROSE_CALL_REQUEST: + len += 1 + ROSE_ADDR_LEN + ROSE_ADDR_LEN; + maxfaclen = 256; + break; + case ROSE_CALL_ACCEPTED: + case ROSE_CLEAR_REQUEST: + case ROSE_RESET_REQUEST: + len += 2; + break; + } + + skb = alloc_skb(reserve + len + maxfaclen, GFP_ATOMIC); + if (!skb) + return; + + /* + * Space for AX.25 header and PID. 
+ */ + skb_reserve(skb, reserve); + + dptr = skb_put(skb, len); + + lci1 = (rose->lci >> 8) & 0x0F; + lci2 = (rose->lci >> 0) & 0xFF; + + switch (frametype) { + case ROSE_CALL_REQUEST: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr++ = frametype; + *dptr++ = ROSE_CALL_REQ_ADDR_LEN_VAL; + memcpy(dptr, &rose->dest_addr, ROSE_ADDR_LEN); + dptr += ROSE_ADDR_LEN; + memcpy(dptr, &rose->source_addr, ROSE_ADDR_LEN); + dptr += ROSE_ADDR_LEN; + faclen = rose_create_facilities(dptr, rose); + skb_put(skb, faclen); + dptr += faclen; + break; + + case ROSE_CALL_ACCEPTED: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr++ = frametype; + *dptr++ = 0x00; /* Address length */ + *dptr++ = 0; /* Facilities length */ + break; + + case ROSE_CLEAR_REQUEST: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr++ = frametype; + *dptr++ = rose->cause; + *dptr++ = rose->diagnostic; + break; + + case ROSE_RESET_REQUEST: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr++ = frametype; + *dptr++ = ROSE_DTE_ORIGINATED; + *dptr++ = 0; + break; + + case ROSE_RR: + case ROSE_RNR: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr = frametype; + *dptr++ |= (rose->vr << 5) & 0xE0; + break; + + case ROSE_CLEAR_CONFIRMATION: + case ROSE_RESET_CONFIRMATION: + *dptr++ = ROSE_GFI | lci1; + *dptr++ = lci2; + *dptr++ = frametype; + break; + + default: + printk(KERN_ERR "ROSE: rose_write_internal - invalid frametype %02X\n", frametype); + kfree_skb(skb); + return; + } + + rose_transmit_link(skb, rose->neighbour); +} + +int rose_decode(struct sk_buff *skb, int *ns, int *nr, int *q, int *d, int *m) +{ + unsigned char *frame; + + frame = skb->data; + + *ns = *nr = *q = *d = *m = 0; + + switch (frame[2]) { + case ROSE_CALL_REQUEST: + case ROSE_CALL_ACCEPTED: + case ROSE_CLEAR_REQUEST: + case ROSE_CLEAR_CONFIRMATION: + case ROSE_RESET_REQUEST: + case ROSE_RESET_CONFIRMATION: + return frame[2]; + default: + break; + } + + if ((frame[2] & 0x1F) == ROSE_RR || + (frame[2] & 0x1F) == ROSE_RNR) { + *nr = (frame[2] >> 5) & 0x07; + return frame[2] & 0x1F; + } + + if ((frame[2] & 0x01) == ROSE_DATA) { + *q = (frame[0] & ROSE_Q_BIT) == ROSE_Q_BIT; + *d = (frame[0] & ROSE_D_BIT) == ROSE_D_BIT; + *m = (frame[2] & ROSE_M_BIT) == ROSE_M_BIT; + *nr = (frame[2] >> 5) & 0x07; + *ns = (frame[2] >> 1) & 0x07; + return ROSE_DATA; + } + + return ROSE_ILLEGAL; +} + +static int rose_parse_national(unsigned char *p, struct rose_facilities_struct *facilities, int len) +{ + unsigned char *pt; + unsigned char l, lg, n = 0; + int fac_national_digis_received = 0; + + do { + switch (*p & 0xC0) { + case 0x00: + if (len < 2) + return -1; + p += 2; + n += 2; + len -= 2; + break; + + case 0x40: + if (len < 3) + return -1; + if (*p == FAC_NATIONAL_RAND) + facilities->rand = ((p[1] << 8) & 0xFF00) + ((p[2] << 0) & 0x00FF); + p += 3; + n += 3; + len -= 3; + break; + + case 0x80: + if (len < 4) + return -1; + p += 4; + n += 4; + len -= 4; + break; + + case 0xC0: + if (len < 2) + return -1; + l = p[1]; + if (len < 2 + l) + return -1; + if (*p == FAC_NATIONAL_DEST_DIGI) { + if (!fac_national_digis_received) { + if (l < AX25_ADDR_LEN) + return -1; + memcpy(&facilities->source_digis[0], p + 2, AX25_ADDR_LEN); + facilities->source_ndigis = 1; + } + } + else if (*p == FAC_NATIONAL_SRC_DIGI) { + if (!fac_national_digis_received) { + if (l < AX25_ADDR_LEN) + return -1; + memcpy(&facilities->dest_digis[0], p + 2, AX25_ADDR_LEN); + facilities->dest_ndigis = 1; + } + } + else if (*p == FAC_NATIONAL_FAIL_CALL) { + if (l < AX25_ADDR_LEN) + return -1; + 
memcpy(&facilities->fail_call, p + 2, AX25_ADDR_LEN); + } + else if (*p == FAC_NATIONAL_FAIL_ADD) { + if (l < 1 + ROSE_ADDR_LEN) + return -1; + memcpy(&facilities->fail_addr, p + 3, ROSE_ADDR_LEN); + } + else if (*p == FAC_NATIONAL_DIGIS) { + if (l % AX25_ADDR_LEN) + return -1; + fac_national_digis_received = 1; + facilities->source_ndigis = 0; + facilities->dest_ndigis = 0; + for (pt = p + 2, lg = 0 ; lg < l ; pt += AX25_ADDR_LEN, lg += AX25_ADDR_LEN) { + if (pt[6] & AX25_HBIT) { + if (facilities->dest_ndigis >= ROSE_MAX_DIGIS) + return -1; + memcpy(&facilities->dest_digis[facilities->dest_ndigis++], pt, AX25_ADDR_LEN); + } else { + if (facilities->source_ndigis >= ROSE_MAX_DIGIS) + return -1; + memcpy(&facilities->source_digis[facilities->source_ndigis++], pt, AX25_ADDR_LEN); + } + } + } + p += l + 2; + n += l + 2; + len -= l + 2; + break; + } + } while (*p != 0x00 && len > 0); + + return n; +} + +static int rose_parse_ccitt(unsigned char *p, struct rose_facilities_struct *facilities, int len) +{ + unsigned char l, n = 0; + char callsign[11]; + + do { + switch (*p & 0xC0) { + case 0x00: + if (len < 2) + return -1; + p += 2; + n += 2; + len -= 2; + break; + + case 0x40: + if (len < 3) + return -1; + p += 3; + n += 3; + len -= 3; + break; + + case 0x80: + if (len < 4) + return -1; + p += 4; + n += 4; + len -= 4; + break; + + case 0xC0: + if (len < 2) + return -1; + l = p[1]; + + /* Prevent overflows*/ + if (l < 10 || l > 20) + return -1; + + if (*p == FAC_CCITT_DEST_NSAP) { + memcpy(&facilities->source_addr, p + 7, ROSE_ADDR_LEN); + memcpy(callsign, p + 12, l - 10); + callsign[l - 10] = '\0'; + asc2ax(&facilities->source_call, callsign); + } + if (*p == FAC_CCITT_SRC_NSAP) { + memcpy(&facilities->dest_addr, p + 7, ROSE_ADDR_LEN); + memcpy(callsign, p + 12, l - 10); + callsign[l - 10] = '\0'; + asc2ax(&facilities->dest_call, callsign); + } + p += l + 2; + n += l + 2; + len -= l + 2; + break; + } + } while (*p != 0x00 && len > 0); + + return n; +} + +int rose_parse_facilities(unsigned char *p, unsigned packet_len, + struct rose_facilities_struct *facilities) +{ + int facilities_len, len; + + facilities_len = *p++; + + if (facilities_len == 0 || (unsigned int)facilities_len > packet_len) + return 0; + + while (facilities_len >= 3 && *p == 0x00) { + facilities_len--; + p++; + + switch (*p) { + case FAC_NATIONAL: /* National */ + len = rose_parse_national(p + 1, facilities, facilities_len - 1); + break; + + case FAC_CCITT: /* CCITT */ + len = rose_parse_ccitt(p + 1, facilities, facilities_len - 1); + break; + + default: + printk(KERN_DEBUG "ROSE: rose_parse_facilities - unknown facilities family %02X\n", *p); + len = 1; + break; + } + + if (len < 0) + return 0; + if (WARN_ON(len >= facilities_len)) + return 0; + facilities_len -= len + 1; + p += len + 1; + } + + return facilities_len == 0; +} + +static int rose_create_facilities(unsigned char *buffer, struct rose_sock *rose) +{ + unsigned char *p = buffer + 1; + char *callsign; + char buf[11]; + int len, nb; + + /* National Facilities */ + if (rose->rand != 0 || rose->source_ndigis == 1 || rose->dest_ndigis == 1) { + *p++ = 0x00; + *p++ = FAC_NATIONAL; + + if (rose->rand != 0) { + *p++ = FAC_NATIONAL_RAND; + *p++ = (rose->rand >> 8) & 0xFF; + *p++ = (rose->rand >> 0) & 0xFF; + } + + /* Sent before older facilities */ + if ((rose->source_ndigis > 0) || (rose->dest_ndigis > 0)) { + int maxdigi = 0; + *p++ = FAC_NATIONAL_DIGIS; + *p++ = AX25_ADDR_LEN * (rose->source_ndigis + rose->dest_ndigis); + for (nb = 0 ; nb < rose->source_ndigis ; nb++) { + 
if (++maxdigi >= ROSE_MAX_DIGIS) + break; + memcpy(p, &rose->source_digis[nb], AX25_ADDR_LEN); + p[6] |= AX25_HBIT; + p += AX25_ADDR_LEN; + } + for (nb = 0 ; nb < rose->dest_ndigis ; nb++) { + if (++maxdigi >= ROSE_MAX_DIGIS) + break; + memcpy(p, &rose->dest_digis[nb], AX25_ADDR_LEN); + p[6] &= ~AX25_HBIT; + p += AX25_ADDR_LEN; + } + } + + /* For compatibility */ + if (rose->source_ndigis > 0) { + *p++ = FAC_NATIONAL_SRC_DIGI; + *p++ = AX25_ADDR_LEN; + memcpy(p, &rose->source_digis[0], AX25_ADDR_LEN); + p += AX25_ADDR_LEN; + } + + /* For compatibility */ + if (rose->dest_ndigis > 0) { + *p++ = FAC_NATIONAL_DEST_DIGI; + *p++ = AX25_ADDR_LEN; + memcpy(p, &rose->dest_digis[0], AX25_ADDR_LEN); + p += AX25_ADDR_LEN; + } + } + + *p++ = 0x00; + *p++ = FAC_CCITT; + + *p++ = FAC_CCITT_DEST_NSAP; + + callsign = ax2asc(buf, &rose->dest_call); + + *p++ = strlen(callsign) + 10; + *p++ = (strlen(callsign) + 9) * 2; /* ??? */ + + *p++ = 0x47; *p++ = 0x00; *p++ = 0x11; + *p++ = ROSE_ADDR_LEN * 2; + memcpy(p, &rose->dest_addr, ROSE_ADDR_LEN); + p += ROSE_ADDR_LEN; + + memcpy(p, callsign, strlen(callsign)); + p += strlen(callsign); + + *p++ = FAC_CCITT_SRC_NSAP; + + callsign = ax2asc(buf, &rose->source_call); + + *p++ = strlen(callsign) + 10; + *p++ = (strlen(callsign) + 9) * 2; /* ??? */ + + *p++ = 0x47; *p++ = 0x00; *p++ = 0x11; + *p++ = ROSE_ADDR_LEN * 2; + memcpy(p, &rose->source_addr, ROSE_ADDR_LEN); + p += ROSE_ADDR_LEN; + + memcpy(p, callsign, strlen(callsign)); + p += strlen(callsign); + + len = p - buffer; + buffer[0] = len - 1; + + return len; +} + +void rose_disconnect(struct sock *sk, int reason, int cause, int diagnostic) +{ + struct rose_sock *rose = rose_sk(sk); + + rose_stop_timer(sk); + rose_stop_idletimer(sk); + + rose_clear_queues(sk); + + rose->lci = 0; + rose->state = ROSE_STATE_0; + + if (cause != -1) + rose->cause = cause; + + if (diagnostic != -1) + rose->diagnostic = diagnostic; + + sk->sk_state = TCP_CLOSE; + sk->sk_err = reason; + sk->sk_shutdown |= SEND_SHUTDOWN; + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } +} diff --git a/net/rose/rose_timer.c b/net/rose/rose_timer.c new file mode 100644 index 000000000..f06ddbed3 --- /dev/null +++ b/net/rose/rose_timer.c @@ -0,0 +1,212 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) Jonathan Naylor G4KLX (g4klx@g4klx.demon.co.uk) + * Copyright (C) 2002 Ralf Baechle DO1GRB (ralf@gnu.org) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void rose_heartbeat_expiry(struct timer_list *t); +static void rose_timer_expiry(struct timer_list *); +static void rose_idletimer_expiry(struct timer_list *); + +void rose_start_heartbeat(struct sock *sk) +{ + sk_stop_timer(sk, &sk->sk_timer); + + sk->sk_timer.function = rose_heartbeat_expiry; + sk->sk_timer.expires = jiffies + 5 * HZ; + + sk_reset_timer(sk, &sk->sk_timer, sk->sk_timer.expires); +} + +void rose_start_t1timer(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + sk_stop_timer(sk, &rose->timer); + + rose->timer.function = rose_timer_expiry; + rose->timer.expires = jiffies + rose->t1; + + sk_reset_timer(sk, &rose->timer, rose->timer.expires); +} + +void rose_start_t2timer(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + sk_stop_timer(sk, &rose->timer); + + rose->timer.function = rose_timer_expiry; + rose->timer.expires = jiffies + 
rose->t2; + + sk_reset_timer(sk, &rose->timer, rose->timer.expires); +} + +void rose_start_t3timer(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + sk_stop_timer(sk, &rose->timer); + + rose->timer.function = rose_timer_expiry; + rose->timer.expires = jiffies + rose->t3; + + sk_reset_timer(sk, &rose->timer, rose->timer.expires); +} + +void rose_start_hbtimer(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + sk_stop_timer(sk, &rose->timer); + + rose->timer.function = rose_timer_expiry; + rose->timer.expires = jiffies + rose->hb; + + sk_reset_timer(sk, &rose->timer, rose->timer.expires); +} + +void rose_start_idletimer(struct sock *sk) +{ + struct rose_sock *rose = rose_sk(sk); + + sk_stop_timer(sk, &rose->idletimer); + + if (rose->idle > 0) { + rose->idletimer.function = rose_idletimer_expiry; + rose->idletimer.expires = jiffies + rose->idle; + + sk_reset_timer(sk, &rose->idletimer, rose->idletimer.expires); + } +} + +void rose_stop_heartbeat(struct sock *sk) +{ + sk_stop_timer(sk, &sk->sk_timer); +} + +void rose_stop_timer(struct sock *sk) +{ + sk_stop_timer(sk, &rose_sk(sk)->timer); +} + +void rose_stop_idletimer(struct sock *sk) +{ + sk_stop_timer(sk, &rose_sk(sk)->idletimer); +} + +static void rose_heartbeat_expiry(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + struct rose_sock *rose = rose_sk(sk); + + bh_lock_sock(sk); + switch (rose->state) { + case ROSE_STATE_0: + /* Magic here: If we listen() and a new link dies before it + is accepted() it isn't 'dead' so doesn't get removed. */ + if (sock_flag(sk, SOCK_DESTROY) || + (sk->sk_state == TCP_LISTEN && sock_flag(sk, SOCK_DEAD))) { + bh_unlock_sock(sk); + rose_destroy_socket(sk); + sock_put(sk); + return; + } + break; + + case ROSE_STATE_3: + /* + * Check for the state of the receive buffer. 
+ */ + if (atomic_read(&sk->sk_rmem_alloc) < (sk->sk_rcvbuf / 2) && + (rose->condition & ROSE_COND_OWN_RX_BUSY)) { + rose->condition &= ~ROSE_COND_OWN_RX_BUSY; + rose->condition &= ~ROSE_COND_ACK_PENDING; + rose->vl = rose->vr; + rose_write_internal(sk, ROSE_RR); + rose_stop_timer(sk); /* HB */ + break; + } + break; + } + + rose_start_heartbeat(sk); + bh_unlock_sock(sk); + sock_put(sk); +} + +static void rose_timer_expiry(struct timer_list *t) +{ + struct rose_sock *rose = from_timer(rose, t, timer); + struct sock *sk = &rose->sock; + + bh_lock_sock(sk); + switch (rose->state) { + case ROSE_STATE_1: /* T1 */ + case ROSE_STATE_4: /* T2 */ + rose_write_internal(sk, ROSE_CLEAR_REQUEST); + rose->state = ROSE_STATE_2; + rose_start_t3timer(sk); + break; + + case ROSE_STATE_2: /* T3 */ + rose->neighbour->use--; + rose_disconnect(sk, ETIMEDOUT, -1, -1); + break; + + case ROSE_STATE_3: /* HB */ + if (rose->condition & ROSE_COND_ACK_PENDING) { + rose->condition &= ~ROSE_COND_ACK_PENDING; + rose_enquiry_response(sk); + } + break; + } + bh_unlock_sock(sk); + sock_put(sk); +} + +static void rose_idletimer_expiry(struct timer_list *t) +{ + struct rose_sock *rose = from_timer(rose, t, idletimer); + struct sock *sk = &rose->sock; + + bh_lock_sock(sk); + rose_clear_queues(sk); + + rose_write_internal(sk, ROSE_CLEAR_REQUEST); + rose_sk(sk)->state = ROSE_STATE_2; + + rose_start_t3timer(sk); + + sk->sk_state = TCP_CLOSE; + sk->sk_err = 0; + sk->sk_shutdown |= SEND_SHUTDOWN; + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } + bh_unlock_sock(sk); + sock_put(sk); +} diff --git a/net/rose/sysctl_net_rose.c b/net/rose/sysctl_net_rose.c new file mode 100644 index 000000000..d391d7758 --- /dev/null +++ b/net/rose/sysctl_net_rose.c @@ -0,0 +1,126 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * + * Copyright (C) 1996 Mike Shaver (shaver@zeroknowledge.com) + */ +#include +#include +#include +#include +#include + +static int min_timer[] = {1 * HZ}; +static int max_timer[] = {300 * HZ}; +static int min_idle[] = {0 * HZ}; +static int max_idle[] = {65535 * HZ}; +static int min_route[1], max_route[] = {1}; +static int min_ftimer[] = {60 * HZ}; +static int max_ftimer[] = {600 * HZ}; +static int min_maxvcs[] = {1}, max_maxvcs[] = {254}; +static int min_window[] = {1}, max_window[] = {7}; + +static struct ctl_table_header *rose_table_header; + +static struct ctl_table rose_table[] = { + { + .procname = "restart_request_timeout", + .data = &sysctl_rose_restart_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer + }, + { + .procname = "call_request_timeout", + .data = &sysctl_rose_call_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer + }, + { + .procname = "reset_request_timeout", + .data = &sysctl_rose_reset_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer + }, + { + .procname = "clear_request_timeout", + .data = &sysctl_rose_clear_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer + }, + { + .procname = "no_activity_timeout", + .data = &sysctl_rose_no_activity_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_idle, + .extra2 = &max_idle + }, 
+ { + .procname = "acknowledge_hold_back_timeout", + .data = &sysctl_rose_ack_hold_back_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer + }, + { + .procname = "routing_control", + .data = &sysctl_rose_routing_control, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_route, + .extra2 = &max_route + }, + { + .procname = "link_fail_timeout", + .data = &sysctl_rose_link_fail_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ftimer, + .extra2 = &max_ftimer + }, + { + .procname = "maximum_virtual_circuits", + .data = &sysctl_rose_maximum_vcs, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_maxvcs, + .extra2 = &max_maxvcs + }, + { + .procname = "window_size", + .data = &sysctl_rose_window_size, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_window, + .extra2 = &max_window + }, + { } +}; + +void __init rose_register_sysctl(void) +{ + rose_table_header = register_net_sysctl(&init_net, "net/rose", rose_table); +} + +void rose_unregister_sysctl(void) +{ + unregister_net_sysctl_table(rose_table_header); +} diff --git a/net/rxrpc/Kconfig b/net/rxrpc/Kconfig new file mode 100644 index 000000000..accd35c05 --- /dev/null +++ b/net/rxrpc/Kconfig @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# RxRPC session sockets +# + +config AF_RXRPC + tristate "RxRPC session sockets" + depends on INET + select CRYPTO + select KEYS + select NET_UDP_TUNNEL + help + Say Y or M here to include support for RxRPC session sockets (just + the transport part, not the presentation part: (un)marshalling is + left to the application). + + These are used for AFS kernel filesystem and userspace utilities. + + This module at the moment only supports client operations and is + currently incomplete. + + See Documentation/networking/rxrpc.rst. + +if AF_RXRPC + +config AF_RXRPC_IPV6 + bool "IPv6 support for RxRPC" + depends on (IPV6 = m && AF_RXRPC = m) || (IPV6 = y && AF_RXRPC) + help + Say Y here to allow AF_RXRPC to use IPV6 UDP as well as IPV4 UDP as + its network transport. + +config AF_RXRPC_INJECT_LOSS + bool "Inject packet loss into RxRPC packet stream" + help + Say Y here to inject packet loss by discarding some received and some + transmitted packets. + + +config AF_RXRPC_DEBUG + bool "RxRPC dynamic debugging" + help + Say Y here to make runtime controllable debugging messages appear. + + See Documentation/networking/rxrpc.rst. + + +config RXKAD + bool "RxRPC Kerberos security" + select CRYPTO + select CRYPTO_MANAGER + select CRYPTO_SKCIPHER + select CRYPTO_PCBC + select CRYPTO_FCRYPT + help + Provide kerberos 4 and AFS kaserver security handling for AF_RXRPC + through the use of the key retention service. + + See Documentation/networking/rxrpc.rst. 
+ +endif diff --git a/net/rxrpc/Makefile b/net/rxrpc/Makefile new file mode 100644 index 000000000..b11281bed --- /dev/null +++ b/net/rxrpc/Makefile @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Linux kernel RxRPC +# + +obj-$(CONFIG_AF_RXRPC) += rxrpc.o + +rxrpc-y := \ + af_rxrpc.o \ + call_accept.o \ + call_event.o \ + call_object.o \ + conn_client.o \ + conn_event.o \ + conn_object.o \ + conn_service.o \ + input.o \ + insecure.o \ + key.o \ + local_event.o \ + local_object.o \ + misc.o \ + net_ns.o \ + output.o \ + peer_event.o \ + peer_object.o \ + recvmsg.o \ + rtt.o \ + security.o \ + sendmsg.o \ + server_key.o \ + skbuff.o \ + utils.o + +rxrpc-$(CONFIG_PROC_FS) += proc.o +rxrpc-$(CONFIG_RXKAD) += rxkad.o +rxrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/rxrpc/af_rxrpc.c b/net/rxrpc/af_rxrpc.c new file mode 100644 index 000000000..ceba28e9d --- /dev/null +++ b/net/rxrpc/af_rxrpc.c @@ -0,0 +1,1078 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* AF_RXRPC implementation + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define CREATE_TRACE_POINTS +#include "ar-internal.h" + +MODULE_DESCRIPTION("RxRPC network protocol"); +MODULE_AUTHOR("Red Hat, Inc."); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_RXRPC); + +unsigned int rxrpc_debug; // = RXRPC_DEBUG_KPROTO; +module_param_named(debug, rxrpc_debug, uint, 0644); +MODULE_PARM_DESC(debug, "RxRPC debugging mask"); + +static struct proto rxrpc_proto; +static const struct proto_ops rxrpc_rpc_ops; + +/* current debugging ID */ +atomic_t rxrpc_debug_id; +EXPORT_SYMBOL(rxrpc_debug_id); + +/* count of skbs currently in use */ +atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs; + +struct workqueue_struct *rxrpc_workqueue; + +static void rxrpc_sock_destructor(struct sock *); + +/* + * see if an RxRPC socket is currently writable + */ +static inline int rxrpc_writable(struct sock *sk) +{ + return refcount_read(&sk->sk_wmem_alloc) < (size_t) sk->sk_sndbuf; +} + +/* + * wait for write bufferage to become available + */ +static void rxrpc_write_space(struct sock *sk) +{ + _enter("%p", sk); + rcu_read_lock(); + if (rxrpc_writable(sk)) { + struct socket_wq *wq = rcu_dereference(sk->sk_wq); + + if (skwq_has_sleeper(wq)) + wake_up_interruptible(&wq->wait); + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + } + rcu_read_unlock(); +} + +/* + * validate an RxRPC address + */ +static int rxrpc_validate_address(struct rxrpc_sock *rx, + struct sockaddr_rxrpc *srx, + int len) +{ + unsigned int tail; + + if (len < sizeof(struct sockaddr_rxrpc)) + return -EINVAL; + + if (srx->srx_family != AF_RXRPC) + return -EAFNOSUPPORT; + + if (srx->transport_type != SOCK_DGRAM) + return -ESOCKTNOSUPPORT; + + len -= offsetof(struct sockaddr_rxrpc, transport); + if (srx->transport_len < sizeof(sa_family_t) || + srx->transport_len > len) + return -EINVAL; + + if (srx->transport.family != rx->family && + srx->transport.family == AF_INET && rx->family != AF_INET6) + return -EAFNOSUPPORT; + + switch (srx->transport.family) { + case AF_INET: + if (srx->transport_len < sizeof(struct sockaddr_in)) + return -EINVAL; + tail = offsetof(struct sockaddr_rxrpc, transport.sin.__pad); + break; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + if (srx->transport_len < sizeof(struct sockaddr_in6)) + return -EINVAL; + tail = 
offsetof(struct sockaddr_rxrpc, transport) + + sizeof(struct sockaddr_in6); + break; +#endif + + default: + return -EAFNOSUPPORT; + } + + if (tail < len) + memset((void *)srx + tail, 0, len - tail); + _debug("INET: %pISp", &srx->transport); + return 0; +} + +/* + * bind a local address to an RxRPC socket + */ +static int rxrpc_bind(struct socket *sock, struct sockaddr *saddr, int len) +{ + struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)saddr; + struct rxrpc_local *local; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + u16 service_id; + int ret; + + _enter("%p,%p,%d", rx, saddr, len); + + ret = rxrpc_validate_address(rx, srx, len); + if (ret < 0) + goto error; + service_id = srx->srx_service; + + lock_sock(&rx->sk); + + switch (rx->sk.sk_state) { + case RXRPC_UNBOUND: + rx->srx = *srx; + local = rxrpc_lookup_local(sock_net(&rx->sk), &rx->srx); + if (IS_ERR(local)) { + ret = PTR_ERR(local); + goto error_unlock; + } + + if (service_id) { + write_lock(&local->services_lock); + if (rcu_access_pointer(local->service)) + goto service_in_use; + rx->local = local; + rcu_assign_pointer(local->service, rx); + write_unlock(&local->services_lock); + + rx->sk.sk_state = RXRPC_SERVER_BOUND; + } else { + rx->local = local; + rx->sk.sk_state = RXRPC_CLIENT_BOUND; + } + break; + + case RXRPC_SERVER_BOUND: + ret = -EINVAL; + if (service_id == 0) + goto error_unlock; + ret = -EADDRINUSE; + if (service_id == rx->srx.srx_service) + goto error_unlock; + ret = -EINVAL; + srx->srx_service = rx->srx.srx_service; + if (memcmp(srx, &rx->srx, sizeof(*srx)) != 0) + goto error_unlock; + rx->second_service = service_id; + rx->sk.sk_state = RXRPC_SERVER_BOUND2; + break; + + default: + ret = -EINVAL; + goto error_unlock; + } + + release_sock(&rx->sk); + _leave(" = 0"); + return 0; + +service_in_use: + write_unlock(&local->services_lock); + rxrpc_unuse_local(local); + rxrpc_put_local(local); + ret = -EADDRINUSE; +error_unlock: + release_sock(&rx->sk); +error: + _leave(" = %d", ret); + return ret; +} + +/* + * set the number of pending calls permitted on a listening socket + */ +static int rxrpc_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + struct rxrpc_sock *rx = rxrpc_sk(sk); + unsigned int max, old; + int ret; + + _enter("%p,%d", rx, backlog); + + lock_sock(&rx->sk); + + switch (rx->sk.sk_state) { + case RXRPC_UNBOUND: + ret = -EADDRNOTAVAIL; + break; + case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_BOUND2: + ASSERT(rx->local != NULL); + max = READ_ONCE(rxrpc_max_backlog); + ret = -EINVAL; + if (backlog == INT_MAX) + backlog = max; + else if (backlog < 0 || backlog > max) + break; + old = sk->sk_max_ack_backlog; + sk->sk_max_ack_backlog = backlog; + ret = rxrpc_service_prealloc(rx, GFP_KERNEL); + if (ret == 0) + rx->sk.sk_state = RXRPC_SERVER_LISTENING; + else + sk->sk_max_ack_backlog = old; + break; + case RXRPC_SERVER_LISTENING: + if (backlog == 0) { + rx->sk.sk_state = RXRPC_SERVER_LISTEN_DISABLED; + sk->sk_max_ack_backlog = 0; + rxrpc_discard_prealloc(rx); + ret = 0; + break; + } + fallthrough; + default: + ret = -EBUSY; + break; + } + + release_sock(&rx->sk); + _leave(" = %d", ret); + return ret; +} + +/** + * rxrpc_kernel_begin_call - Allow a kernel service to begin a call + * @sock: The socket on which to make the call + * @srx: The address of the peer to contact + * @key: The security context to use (defaults to socket setting) + * @user_call_ID: The ID to use + * @tx_total_len: Total length of data to transmit during the call (or -1) + * @gfp: The allocation constraints + * 
@notify_rx: Where to send notifications instead of socket queue + * @upgrade: Request service upgrade for call + * @interruptibility: The call is interruptible, or can be canceled. + * @debug_id: The debug ID for tracing to be assigned to the call + * + * Allow a kernel service to begin a call on the nominated socket. This just + * sets up all the internal tracking structures and allocates connection and + * call IDs as appropriate. The call to be used is returned. + * + * The default socket destination address and security may be overridden by + * supplying @srx and @key. + */ +struct rxrpc_call *rxrpc_kernel_begin_call(struct socket *sock, + struct sockaddr_rxrpc *srx, + struct key *key, + unsigned long user_call_ID, + s64 tx_total_len, + gfp_t gfp, + rxrpc_notify_rx_t notify_rx, + bool upgrade, + enum rxrpc_interruptibility interruptibility, + unsigned int debug_id) +{ + struct rxrpc_conn_parameters cp; + struct rxrpc_call_params p; + struct rxrpc_call *call; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + int ret; + + _enter(",,%x,%lx", key_serial(key), user_call_ID); + + ret = rxrpc_validate_address(rx, srx, sizeof(*srx)); + if (ret < 0) + return ERR_PTR(ret); + + lock_sock(&rx->sk); + + if (!key) + key = rx->key; + if (key && !key->payload.data[0]) + key = NULL; /* a no-security key */ + + memset(&p, 0, sizeof(p)); + p.user_call_ID = user_call_ID; + p.tx_total_len = tx_total_len; + p.interruptibility = interruptibility; + p.kernel = true; + + memset(&cp, 0, sizeof(cp)); + cp.local = rx->local; + cp.key = key; + cp.security_level = rx->min_sec_level; + cp.exclusive = false; + cp.upgrade = upgrade; + cp.service_id = srx->srx_service; + call = rxrpc_new_client_call(rx, &cp, srx, &p, gfp, debug_id); + /* The socket has been unlocked. */ + if (!IS_ERR(call)) { + call->notify_rx = notify_rx; + mutex_unlock(&call->user_mutex); + } + + rxrpc_put_peer(cp.peer); + _leave(" = %p", call); + return call; +} +EXPORT_SYMBOL(rxrpc_kernel_begin_call); + +/* + * Dummy function used to stop the notifier talking to recvmsg(). + */ +static void rxrpc_dummy_notify_rx(struct sock *sk, struct rxrpc_call *rxcall, + unsigned long call_user_ID) +{ +} + +/** + * rxrpc_kernel_end_call - Allow a kernel service to end a call it was using + * @sock: The socket the call is on + * @call: The call to end + * + * Allow a kernel service to end a call it was using. The call must be + * complete before this is called (the call should be aborted if necessary). + */ +void rxrpc_kernel_end_call(struct socket *sock, struct rxrpc_call *call) +{ + _enter("%d{%d}", call->debug_id, refcount_read(&call->ref)); + + mutex_lock(&call->user_mutex); + rxrpc_release_call(rxrpc_sk(sock->sk), call); + + /* Make sure we're not going to call back into a kernel service */ + if (call->notify_rx) { + spin_lock_bh(&call->notify_lock); + call->notify_rx = rxrpc_dummy_notify_rx; + spin_unlock_bh(&call->notify_lock); + } + + mutex_unlock(&call->user_mutex); + rxrpc_put_call(call, rxrpc_call_put_kernel); +} +EXPORT_SYMBOL(rxrpc_kernel_end_call); + +/** + * rxrpc_kernel_check_life - Check to see whether a call is still alive + * @sock: The socket the call is on + * @call: The call to check + * + * Allow a kernel service to find out whether a call is still alive - + * ie. whether it has completed. 
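rxrpc_kernel_begin_call() and rxrpc_kernel_end_call() above form the core of the in-kernel client API (the kAFS filesystem is the main in-tree user). A compressed sketch of a kernel service issuing a call on an already bound AF_RXRPC kernel socket; the function names, service id, port, user_call_ID and the RXRPC_PREINTERRUPTIBLE choice are picked here purely for illustration:

	#include <linux/atomic.h>
	#include <linux/gfp.h>
	#include <linux/in.h>
	#include <linux/key.h>
	#include <net/af_rxrpc.h>

	static void my_notify_rx(struct sock *sk, struct rxrpc_call *call,
				 unsigned long user_call_ID)
	{
		/* Typically queues work so the reply is read from process context. */
	}

	static struct rxrpc_call *my_service_begin_call(struct socket *rxsock,
							struct key *key)
	{
		struct sockaddr_rxrpc srx = {
			.srx_family			= AF_RXRPC,
			.srx_service			= 52,	/* example service id */
			.transport_type			= SOCK_DGRAM,
			.transport_len			= sizeof(struct sockaddr_in),
			.transport.sin.sin_family	= AF_INET,
			.transport.sin.sin_port		= htons(7003),	/* example port */
		};

		return rxrpc_kernel_begin_call(rxsock, &srx, key,
					       0x1234,		/* user_call_ID */
					       -1,		/* tx_total_len unknown */
					       GFP_KERNEL,
					       my_notify_rx,
					       false,		/* no service upgrade */
					       RXRPC_PREINTERRUPTIBLE,
					       atomic_inc_return(&rxrpc_debug_id));
	}

Once the reply has been consumed or the call aborted, the service ends it with rxrpc_kernel_end_call(rxsock, call), after which the user_call_ID may be reused.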
+ */ +bool rxrpc_kernel_check_life(const struct socket *sock, + const struct rxrpc_call *call) +{ + return call->state != RXRPC_CALL_COMPLETE; +} +EXPORT_SYMBOL(rxrpc_kernel_check_life); + +/** + * rxrpc_kernel_get_epoch - Retrieve the epoch value from a call. + * @sock: The socket the call is on + * @call: The call to query + * + * Allow a kernel service to retrieve the epoch value from a service call to + * see if the client at the other end rebooted. + */ +u32 rxrpc_kernel_get_epoch(struct socket *sock, struct rxrpc_call *call) +{ + return call->conn->proto.epoch; +} +EXPORT_SYMBOL(rxrpc_kernel_get_epoch); + +/** + * rxrpc_kernel_new_call_notification - Get notifications of new calls + * @sock: The socket to intercept received messages on + * @notify_new_call: Function to be called when new calls appear + * @discard_new_call: Function to discard preallocated calls + * + * Allow a kernel service to be given notifications about new calls. + */ +void rxrpc_kernel_new_call_notification( + struct socket *sock, + rxrpc_notify_new_call_t notify_new_call, + rxrpc_discard_new_call_t discard_new_call) +{ + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + + rx->notify_new_call = notify_new_call; + rx->discard_new_call = discard_new_call; +} +EXPORT_SYMBOL(rxrpc_kernel_new_call_notification); + +/** + * rxrpc_kernel_set_max_life - Set maximum lifespan on a call + * @sock: The socket the call is on + * @call: The call to configure + * @hard_timeout: The maximum lifespan of the call in jiffies + * + * Set the maximum lifespan of a call. The call will end with ETIME or + * ETIMEDOUT if it takes longer than this. + */ +void rxrpc_kernel_set_max_life(struct socket *sock, struct rxrpc_call *call, + unsigned long hard_timeout) +{ + unsigned long now; + + mutex_lock(&call->user_mutex); + + now = jiffies; + hard_timeout += now; + WRITE_ONCE(call->expect_term_by, hard_timeout); + rxrpc_reduce_call_timer(call, hard_timeout, now, rxrpc_timer_set_for_hard); + + mutex_unlock(&call->user_mutex); +} +EXPORT_SYMBOL(rxrpc_kernel_set_max_life); + +/* + * connect an RxRPC socket + * - this just targets it at a specific destination; no actual connection + * negotiation takes place + */ +static int rxrpc_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + struct sockaddr_rxrpc *srx = (struct sockaddr_rxrpc *)addr; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + int ret; + + _enter("%p,%p,%d,%d", rx, addr, addr_len, flags); + + ret = rxrpc_validate_address(rx, srx, addr_len); + if (ret < 0) { + _leave(" = %d [bad addr]", ret); + return ret; + } + + lock_sock(&rx->sk); + + ret = -EISCONN; + if (test_bit(RXRPC_SOCK_CONNECTED, &rx->flags)) + goto error; + + switch (rx->sk.sk_state) { + case RXRPC_UNBOUND: + rx->sk.sk_state = RXRPC_CLIENT_UNBOUND; + break; + case RXRPC_CLIENT_UNBOUND: + case RXRPC_CLIENT_BOUND: + break; + default: + ret = -EBUSY; + goto error; + } + + rx->connect_srx = *srx; + set_bit(RXRPC_SOCK_CONNECTED, &rx->flags); + ret = 0; + +error: + release_sock(&rx->sk); + return ret; +} + +/* + * send a message through an RxRPC socket + * - in a client this does a number of things: + * - finds/sets up a connection for the security specified (if any) + * - initiates a call (ID in control data) + * - ends the request phase of a call (if MSG_MORE is not set) + * - sends a call data packet + * - may send an abort (abort code in control data) + */ +static int rxrpc_sendmsg(struct socket *sock, struct msghdr *m, size_t len) +{ + struct rxrpc_local *local; + struct rxrpc_sock *rx = 
rxrpc_sk(sock->sk); + int ret; + + _enter(",{%d},,%zu", rx->sk.sk_state, len); + + if (m->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + if (m->msg_name) { + ret = rxrpc_validate_address(rx, m->msg_name, m->msg_namelen); + if (ret < 0) { + _leave(" = %d [bad addr]", ret); + return ret; + } + } + + lock_sock(&rx->sk); + + switch (rx->sk.sk_state) { + case RXRPC_UNBOUND: + case RXRPC_CLIENT_UNBOUND: + rx->srx.srx_family = AF_RXRPC; + rx->srx.srx_service = 0; + rx->srx.transport_type = SOCK_DGRAM; + rx->srx.transport.family = rx->family; + switch (rx->family) { + case AF_INET: + rx->srx.transport_len = sizeof(struct sockaddr_in); + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + rx->srx.transport_len = sizeof(struct sockaddr_in6); + break; +#endif + default: + ret = -EAFNOSUPPORT; + goto error_unlock; + } + local = rxrpc_lookup_local(sock_net(sock->sk), &rx->srx); + if (IS_ERR(local)) { + ret = PTR_ERR(local); + goto error_unlock; + } + + rx->local = local; + rx->sk.sk_state = RXRPC_CLIENT_BOUND; + fallthrough; + + case RXRPC_CLIENT_BOUND: + if (!m->msg_name && + test_bit(RXRPC_SOCK_CONNECTED, &rx->flags)) { + m->msg_name = &rx->connect_srx; + m->msg_namelen = sizeof(rx->connect_srx); + } + fallthrough; + case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_LISTENING: + ret = rxrpc_do_sendmsg(rx, m, len); + /* The socket has been unlocked */ + goto out; + default: + ret = -EINVAL; + goto error_unlock; + } + +error_unlock: + release_sock(&rx->sk); +out: + _leave(" = %d", ret); + return ret; +} + +int rxrpc_sock_set_min_security_level(struct sock *sk, unsigned int val) +{ + if (sk->sk_state != RXRPC_UNBOUND) + return -EISCONN; + if (val > RXRPC_SECURITY_MAX) + return -EINVAL; + lock_sock(sk); + rxrpc_sk(sk)->min_sec_level = val; + release_sock(sk); + return 0; +} +EXPORT_SYMBOL(rxrpc_sock_set_min_security_level); + +/* + * set RxRPC socket options + */ +static int rxrpc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + unsigned int min_sec_level; + u16 service_upgrade[2]; + int ret; + + _enter(",%d,%d,,%d", level, optname, optlen); + + lock_sock(&rx->sk); + ret = -EOPNOTSUPP; + + if (level == SOL_RXRPC) { + switch (optname) { + case RXRPC_EXCLUSIVE_CONNECTION: + ret = -EINVAL; + if (optlen != 0) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_UNBOUND) + goto error; + rx->exclusive = true; + goto success; + + case RXRPC_SECURITY_KEY: + ret = -EINVAL; + if (rx->key) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_UNBOUND) + goto error; + ret = rxrpc_request_key(rx, optval, optlen); + goto error; + + case RXRPC_SECURITY_KEYRING: + ret = -EINVAL; + if (rx->key) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_UNBOUND) + goto error; + ret = rxrpc_server_keyring(rx, optval, optlen); + goto error; + + case RXRPC_MIN_SECURITY_LEVEL: + ret = -EINVAL; + if (optlen != sizeof(unsigned int)) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_UNBOUND) + goto error; + ret = copy_from_sockptr(&min_sec_level, optval, + sizeof(unsigned int)); + if (ret < 0) + goto error; + ret = -EINVAL; + if (min_sec_level > RXRPC_SECURITY_MAX) + goto error; + rx->min_sec_level = min_sec_level; + goto success; + + case RXRPC_UPGRADEABLE_SERVICE: + ret = -EINVAL; + if (optlen != sizeof(service_upgrade) || + rx->service_upgrade.from != 0) + goto error; + ret = -EISCONN; + if (rx->sk.sk_state != RXRPC_SERVER_BOUND2) + goto error; + ret = -EFAULT; + if 
(copy_from_sockptr(service_upgrade, optval, + sizeof(service_upgrade)) != 0) + goto error; + ret = -EINVAL; + if ((service_upgrade[0] != rx->srx.srx_service || + service_upgrade[1] != rx->second_service) && + (service_upgrade[0] != rx->second_service || + service_upgrade[1] != rx->srx.srx_service)) + goto error; + rx->service_upgrade.from = service_upgrade[0]; + rx->service_upgrade.to = service_upgrade[1]; + goto success; + + default: + break; + } + } + +success: + ret = 0; +error: + release_sock(&rx->sk); + return ret; +} + +/* + * Get socket options. + */ +static int rxrpc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *_optlen) +{ + int optlen; + + if (level != SOL_RXRPC) + return -EOPNOTSUPP; + + if (get_user(optlen, _optlen)) + return -EFAULT; + + switch (optname) { + case RXRPC_SUPPORTED_CMSG: + if (optlen < sizeof(int)) + return -ETOOSMALL; + if (put_user(RXRPC__SUPPORTED - 1, (int __user *)optval) || + put_user(sizeof(int), _optlen)) + return -EFAULT; + return 0; + + default: + return -EOPNOTSUPP; + } +} + +/* + * permit an RxRPC socket to be polled + */ +static __poll_t rxrpc_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct rxrpc_sock *rx = rxrpc_sk(sk); + __poll_t mask; + + sock_poll_wait(file, sock, wait); + mask = 0; + + /* the socket is readable if there are any messages waiting on the Rx + * queue */ + if (!list_empty(&rx->recvmsg_q)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* the socket is writable if there is space to add new data to the + * socket; there is no guarantee that any particular call in progress + * on the socket may have space in the Tx ACK window */ + if (rxrpc_writable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM; + + return mask; +} + +/* + * create an RxRPC socket + */ +static int rxrpc_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct rxrpc_net *rxnet; + struct rxrpc_sock *rx; + struct sock *sk; + + _enter("%p,%d", sock, protocol); + + /* we support transport protocol UDP/UDP6 only */ + if (protocol != PF_INET && + IS_ENABLED(CONFIG_AF_RXRPC_IPV6) && protocol != PF_INET6) + return -EPROTONOSUPPORT; + + if (sock->type != SOCK_DGRAM) + return -ESOCKTNOSUPPORT; + + sock->ops = &rxrpc_rpc_ops; + sock->state = SS_UNCONNECTED; + + sk = sk_alloc(net, PF_RXRPC, GFP_KERNEL, &rxrpc_proto, kern); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + sock_set_flag(sk, SOCK_RCU_FREE); + sk->sk_state = RXRPC_UNBOUND; + sk->sk_write_space = rxrpc_write_space; + sk->sk_max_ack_backlog = 0; + sk->sk_destruct = rxrpc_sock_destructor; + + rx = rxrpc_sk(sk); + rx->family = protocol; + rx->calls = RB_ROOT; + + spin_lock_init(&rx->incoming_lock); + INIT_LIST_HEAD(&rx->sock_calls); + INIT_LIST_HEAD(&rx->to_be_accepted); + INIT_LIST_HEAD(&rx->recvmsg_q); + rwlock_init(&rx->recvmsg_lock); + rwlock_init(&rx->call_lock); + memset(&rx->srx, 0, sizeof(rx->srx)); + + rxnet = rxrpc_net(sock_net(&rx->sk)); + timer_reduce(&rxnet->peer_keepalive_timer, jiffies + 1); + + _leave(" = 0 [%p]", rx); + return 0; +} + +/* + * Kill all the calls on a socket and shut it down. 
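+ * Only SHUT_RDWR is supported: the socket is moved to RXRPC_CLOSE and any + * preallocated service calls are discarded.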
+ */ +static int rxrpc_shutdown(struct socket *sock, int flags) +{ + struct sock *sk = sock->sk; + struct rxrpc_sock *rx = rxrpc_sk(sk); + int ret = 0; + + _enter("%p,%d", sk, flags); + + if (flags != SHUT_RDWR) + return -EOPNOTSUPP; + if (sk->sk_state == RXRPC_CLOSE) + return -ESHUTDOWN; + + lock_sock(sk); + + spin_lock_bh(&sk->sk_receive_queue.lock); + if (sk->sk_state < RXRPC_CLOSE) { + sk->sk_state = RXRPC_CLOSE; + sk->sk_shutdown = SHUTDOWN_MASK; + } else { + ret = -ESHUTDOWN; + } + spin_unlock_bh(&sk->sk_receive_queue.lock); + + rxrpc_discard_prealloc(rx); + + release_sock(sk); + return ret; +} + +/* + * RxRPC socket destructor + */ +static void rxrpc_sock_destructor(struct sock *sk) +{ + _enter("%p", sk); + + rxrpc_purge_queue(&sk->sk_receive_queue); + + WARN_ON(refcount_read(&sk->sk_wmem_alloc)); + WARN_ON(!sk_unhashed(sk)); + WARN_ON(sk->sk_socket); + + if (!sock_flag(sk, SOCK_DEAD)) { + printk("Attempt to release alive rxrpc socket: %p\n", sk); + return; + } +} + +/* + * release an RxRPC socket + */ +static int rxrpc_release_sock(struct sock *sk) +{ + struct rxrpc_sock *rx = rxrpc_sk(sk); + + _enter("%p{%d,%d}", sk, sk->sk_state, refcount_read(&sk->sk_refcnt)); + + /* declare the socket closed for business */ + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + + /* We want to kill off all connections from a service socket + * as fast as possible because we can't share these; client + * sockets, on the other hand, can share an endpoint. + */ + switch (sk->sk_state) { + case RXRPC_SERVER_BOUND: + case RXRPC_SERVER_BOUND2: + case RXRPC_SERVER_LISTENING: + case RXRPC_SERVER_LISTEN_DISABLED: + rx->local->service_closed = true; + break; + } + + spin_lock_bh(&sk->sk_receive_queue.lock); + sk->sk_state = RXRPC_CLOSE; + spin_unlock_bh(&sk->sk_receive_queue.lock); + + if (rx->local && rcu_access_pointer(rx->local->service) == rx) { + write_lock(&rx->local->services_lock); + rcu_assign_pointer(rx->local->service, NULL); + write_unlock(&rx->local->services_lock); + } + + /* try to flush out this socket */ + rxrpc_discard_prealloc(rx); + rxrpc_release_calls_on_socket(rx); + flush_workqueue(rxrpc_workqueue); + rxrpc_purge_queue(&sk->sk_receive_queue); + + rxrpc_unuse_local(rx->local); + rxrpc_put_local(rx->local); + rx->local = NULL; + key_put(rx->key); + rx->key = NULL; + key_put(rx->securities); + rx->securities = NULL; + sock_put(sk); + + _leave(" = 0"); + return 0; +} + +/* + * release an RxRPC BSD socket on close() or equivalent + */ +static int rxrpc_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + _enter("%p{%p}", sock, sk); + + if (!sk) + return 0; + + sock->sk = NULL; + + return rxrpc_release_sock(sk); +} + +/* + * RxRPC network protocol + */ +static const struct proto_ops rxrpc_rpc_ops = { + .family = PF_RXRPC, + .owner = THIS_MODULE, + .release = rxrpc_release, + .bind = rxrpc_bind, + .connect = rxrpc_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = rxrpc_poll, + .ioctl = sock_no_ioctl, + .listen = rxrpc_listen, + .shutdown = rxrpc_shutdown, + .setsockopt = rxrpc_setsockopt, + .getsockopt = rxrpc_getsockopt, + .sendmsg = rxrpc_sendmsg, + .recvmsg = rxrpc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct proto rxrpc_proto = { + .name = "RXRPC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct rxrpc_sock), + .max_header = sizeof(struct rxrpc_wire_header), +}; + +static const struct net_proto_family rxrpc_family_ops = { + .family = PF_RXRPC, + .create = 
rxrpc_create, + .owner = THIS_MODULE, +}; + +/* + * initialise and register the RxRPC protocol + */ +static int __init af_rxrpc_init(void) +{ + int ret = -1; + unsigned int tmp; + + BUILD_BUG_ON(sizeof(struct rxrpc_skb_priv) > sizeof_field(struct sk_buff, cb)); + + get_random_bytes(&tmp, sizeof(tmp)); + tmp &= 0x3fffffff; + if (tmp == 0) + tmp = 1; + idr_set_cursor(&rxrpc_client_conn_ids, tmp); + + ret = -ENOMEM; + rxrpc_call_jar = kmem_cache_create( + "rxrpc_call_jar", sizeof(struct rxrpc_call), 0, + SLAB_HWCACHE_ALIGN, NULL); + if (!rxrpc_call_jar) { + pr_notice("Failed to allocate call jar\n"); + goto error_call_jar; + } + + rxrpc_workqueue = alloc_workqueue("krxrpcd", 0, 1); + if (!rxrpc_workqueue) { + pr_notice("Failed to allocate work queue\n"); + goto error_work_queue; + } + + ret = rxrpc_init_security(); + if (ret < 0) { + pr_crit("Cannot initialise security\n"); + goto error_security; + } + + ret = register_pernet_device(&rxrpc_net_ops); + if (ret) + goto error_pernet; + + ret = proto_register(&rxrpc_proto, 1); + if (ret < 0) { + pr_crit("Cannot register protocol\n"); + goto error_proto; + } + + ret = sock_register(&rxrpc_family_ops); + if (ret < 0) { + pr_crit("Cannot register socket family\n"); + goto error_sock; + } + + ret = register_key_type(&key_type_rxrpc); + if (ret < 0) { + pr_crit("Cannot register client key type\n"); + goto error_key_type; + } + + ret = register_key_type(&key_type_rxrpc_s); + if (ret < 0) { + pr_crit("Cannot register server key type\n"); + goto error_key_type_s; + } + + ret = rxrpc_sysctl_init(); + if (ret < 0) { + pr_crit("Cannot register sysctls\n"); + goto error_sysctls; + } + + return 0; + +error_sysctls: + unregister_key_type(&key_type_rxrpc_s); +error_key_type_s: + unregister_key_type(&key_type_rxrpc); +error_key_type: + sock_unregister(PF_RXRPC); +error_sock: + proto_unregister(&rxrpc_proto); +error_proto: + unregister_pernet_device(&rxrpc_net_ops); +error_pernet: + rxrpc_exit_security(); +error_security: + destroy_workqueue(rxrpc_workqueue); +error_work_queue: + kmem_cache_destroy(rxrpc_call_jar); +error_call_jar: + return ret; +} + +/* + * unregister the RxRPC protocol + */ +static void __exit af_rxrpc_exit(void) +{ + _enter(""); + rxrpc_sysctl_exit(); + unregister_key_type(&key_type_rxrpc_s); + unregister_key_type(&key_type_rxrpc); + sock_unregister(PF_RXRPC); + proto_unregister(&rxrpc_proto); + unregister_pernet_device(&rxrpc_net_ops); + ASSERTCMP(atomic_read(&rxrpc_n_tx_skbs), ==, 0); + ASSERTCMP(atomic_read(&rxrpc_n_rx_skbs), ==, 0); + + /* Make sure the local and peer records pinned by any dying connections + * are released. + */ + rcu_barrier(); + rxrpc_destroy_client_conn_ids(); + + destroy_workqueue(rxrpc_workqueue); + rxrpc_exit_security(); + kmem_cache_destroy(rxrpc_call_jar); + _leave(""); +} + +module_init(af_rxrpc_init); +module_exit(af_rxrpc_exit); diff --git a/net/rxrpc/ar-internal.h b/net/rxrpc/ar-internal.h new file mode 100644 index 000000000..8499ceb77 --- /dev/null +++ b/net/rxrpc/ar-internal.h @@ -0,0 +1,1258 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* AF_RXRPC internal definitions + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "protocol.h" + +#define FCRYPT_BSIZE 8 +struct rxrpc_crypt { + union { + u8 x[FCRYPT_BSIZE]; + __be32 n[2]; + }; +} __attribute__((aligned(8))); + +#define rxrpc_queue_work(WS) queue_work(rxrpc_workqueue, (WS)) +#define rxrpc_queue_delayed_work(WS,D) \ + queue_delayed_work(rxrpc_workqueue, (WS), (D)) + +struct key_preparsed_payload; +struct rxrpc_connection; + +/* + * Mark applied to socket buffers in skb->mark. skb->priority is used + * to pass supplementary information. + */ +enum rxrpc_skb_mark { + RXRPC_SKB_MARK_REJECT_BUSY, /* Reject with BUSY */ + RXRPC_SKB_MARK_REJECT_ABORT, /* Reject with ABORT (code in skb->priority) */ +}; + +/* + * sk_state for RxRPC sockets + */ +enum { + RXRPC_UNBOUND = 0, + RXRPC_CLIENT_UNBOUND, /* Unbound socket used as client */ + RXRPC_CLIENT_BOUND, /* client local address bound */ + RXRPC_SERVER_BOUND, /* server local address bound */ + RXRPC_SERVER_BOUND2, /* second server local address bound */ + RXRPC_SERVER_LISTENING, /* server listening for connections */ + RXRPC_SERVER_LISTEN_DISABLED, /* server listening disabled */ + RXRPC_CLOSE, /* socket is being closed */ +}; + +/* + * Per-network namespace data. + */ +struct rxrpc_net { + struct proc_dir_entry *proc_net; /* Subdir in /proc/net */ + u32 epoch; /* Local epoch for detecting local-end reset */ + struct list_head calls; /* List of calls active in this namespace */ + spinlock_t call_lock; /* Lock for ->calls */ + atomic_t nr_calls; /* Count of allocated calls */ + + atomic_t nr_conns; + struct list_head conn_proc_list; /* List of conns in this namespace for proc */ + struct list_head service_conns; /* Service conns in this namespace */ + rwlock_t conn_lock; /* Lock for ->conn_proc_list, ->service_conns */ + struct work_struct service_conn_reaper; + struct timer_list service_conn_reap_timer; + + bool live; + + bool kill_all_client_conns; + atomic_t nr_client_conns; + spinlock_t client_conn_cache_lock; /* Lock for ->*_client_conns */ + spinlock_t client_conn_discard_lock; /* Prevent multiple discarders */ + struct list_head idle_client_conns; + struct work_struct client_conn_reaper; + struct timer_list client_conn_reap_timer; + + struct hlist_head local_endpoints; + struct mutex local_mutex; /* Lock for ->local_endpoints */ + + DECLARE_HASHTABLE (peer_hash, 10); + spinlock_t peer_hash_lock; /* Lock for ->peer_hash */ + +#define RXRPC_KEEPALIVE_TIME 20 /* NAT keepalive time in seconds */ + u8 peer_keepalive_cursor; + time64_t peer_keepalive_base; + struct list_head peer_keepalive[32]; + struct list_head peer_keepalive_new; + struct timer_list peer_keepalive_timer; + struct work_struct peer_keepalive_work; +}; + +/* + * Service backlog preallocation. + * + * This contains circular buffers of preallocated peers, connections and calls + * for incoming service calls and their head and tail pointers. This allows + * calls to be set up in the data_ready handler, thereby avoiding the need to + * shuffle packets around so much. 
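+ * + * Each ring holds RXRPC_BACKLOG_MAX slots (a power of two); the head is + * advanced with (head + 1) & (RXRPC_BACKLOG_MAX - 1) and the tail likewise, so + * CIRC_CNT(head, tail, RXRPC_BACKLOG_MAX) gives the number of entries currently + * available (e.g. head 3, tail 30 gives (3 - 30) & 31 = 5).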
+ */ +struct rxrpc_backlog { + unsigned short peer_backlog_head; + unsigned short peer_backlog_tail; + unsigned short conn_backlog_head; + unsigned short conn_backlog_tail; + unsigned short call_backlog_head; + unsigned short call_backlog_tail; +#define RXRPC_BACKLOG_MAX 32 + struct rxrpc_peer *peer_backlog[RXRPC_BACKLOG_MAX]; + struct rxrpc_connection *conn_backlog[RXRPC_BACKLOG_MAX]; + struct rxrpc_call *call_backlog[RXRPC_BACKLOG_MAX]; +}; + +/* + * RxRPC socket definition + */ +struct rxrpc_sock { + /* WARNING: sk has to be the first member */ + struct sock sk; + rxrpc_notify_new_call_t notify_new_call; /* Func to notify of new call */ + rxrpc_discard_new_call_t discard_new_call; /* Func to discard a new call */ + struct rxrpc_local *local; /* local endpoint */ + struct rxrpc_backlog *backlog; /* Preallocation for services */ + spinlock_t incoming_lock; /* Incoming call vs service shutdown lock */ + struct list_head sock_calls; /* List of calls owned by this socket */ + struct list_head to_be_accepted; /* calls awaiting acceptance */ + struct list_head recvmsg_q; /* Calls awaiting recvmsg's attention */ + rwlock_t recvmsg_lock; /* Lock for recvmsg_q */ + struct key *key; /* security for this socket */ + struct key *securities; /* list of server security descriptors */ + struct rb_root calls; /* User ID -> call mapping */ + unsigned long flags; +#define RXRPC_SOCK_CONNECTED 0 /* connect_srx is set */ + rwlock_t call_lock; /* lock for calls */ + u32 min_sec_level; /* minimum security level */ +#define RXRPC_SECURITY_MAX RXRPC_SECURITY_ENCRYPT + bool exclusive; /* Exclusive connection for a client socket */ + u16 second_service; /* Additional service bound to the endpoint */ + struct { + /* Service upgrade information */ + u16 from; /* Service ID to upgrade (if not 0) */ + u16 to; /* service ID to upgrade to */ + } service_upgrade; + sa_family_t family; /* Protocol family created with */ + struct sockaddr_rxrpc srx; /* Primary Service/local addresses */ + struct sockaddr_rxrpc connect_srx; /* Default client address from connect() */ +}; + +#define rxrpc_sk(__sk) container_of((__sk), struct rxrpc_sock, sk) + +/* + * CPU-byteorder normalised Rx packet header. 
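+ * The fields mirror the on-wire header (struct rxrpc_wire_header), which + * carries them in network byte order; they are converted to host order on + * receipt.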
+ */ +struct rxrpc_host_header { + u32 epoch; /* client boot timestamp */ + u32 cid; /* connection and channel ID */ + u32 callNumber; /* call ID (0 for connection-level packets) */ + u32 seq; /* sequence number of pkt in call stream */ + u32 serial; /* serial number of pkt sent to network */ + u8 type; /* packet type */ + u8 flags; /* packet flags */ + u8 userStatus; /* app-layer defined status */ + u8 securityIndex; /* security protocol ID */ + union { + u16 _rsvd; /* reserved */ + u16 cksum; /* kerberos security checksum */ + }; + u16 serviceId; /* service ID */ +} __packed; + +/* + * RxRPC socket buffer private variables + * - max 48 bytes (struct sk_buff::cb) + */ +struct rxrpc_skb_priv { + atomic_t nr_ring_pins; /* Number of rxtx ring pins */ + u8 nr_subpackets; /* Number of subpackets */ + u8 rx_flags; /* Received packet flags */ +#define RXRPC_SKB_INCL_LAST 0x01 /* - Includes last packet */ +#define RXRPC_SKB_TX_BUFFER 0x02 /* - Is transmit buffer */ + union { + int remain; /* amount of space remaining for next write */ + + /* List of requested ACKs on subpackets */ + unsigned long rx_req_ack[(RXRPC_MAX_NR_JUMBO + BITS_PER_LONG - 1) / + BITS_PER_LONG]; + }; + + struct rxrpc_host_header hdr; /* RxRPC packet header from this packet */ +}; + +#define rxrpc_skb(__skb) ((struct rxrpc_skb_priv *) &(__skb)->cb) + +/* + * RxRPC security module interface + */ +struct rxrpc_security { + const char *name; /* name of this service */ + u8 security_index; /* security type provided */ + u32 no_key_abort; /* Abort code indicating no key */ + + /* Initialise a security service */ + int (*init)(void); + + /* Clean up a security service */ + void (*exit)(void); + + /* Parse the information from a server key */ + int (*preparse_server_key)(struct key_preparsed_payload *); + + /* Clean up the preparse buffer after parsing a server key */ + void (*free_preparse_server_key)(struct key_preparsed_payload *); + + /* Destroy the payload of a server key */ + void (*destroy_server_key)(struct key *); + + /* Describe a server key */ + void (*describe_server_key)(const struct key *, struct seq_file *); + + /* initialise a connection's security */ + int (*init_connection_security)(struct rxrpc_connection *, + struct rxrpc_key_token *); + + /* Work out how much data we can store in a packet, given an estimate + * of the amount of data remaining. + */ + int (*how_much_data)(struct rxrpc_call *, size_t, + size_t *, size_t *, size_t *); + + /* impose security on a packet */ + int (*secure_packet)(struct rxrpc_call *, struct sk_buff *, size_t); + + /* verify the security on a received packet */ + int (*verify_packet)(struct rxrpc_call *, struct sk_buff *, + unsigned int, unsigned int, rxrpc_seq_t, u16); + + /* Free crypto request on a call */ + void (*free_call_crypto)(struct rxrpc_call *); + + /* Locate the data in a received packet that has been verified. 
*/ + void (*locate_data)(struct rxrpc_call *, struct sk_buff *, + unsigned int *, unsigned int *); + + /* issue a challenge */ + int (*issue_challenge)(struct rxrpc_connection *); + + /* respond to a challenge */ + int (*respond_to_challenge)(struct rxrpc_connection *, + struct sk_buff *, + u32 *); + + /* verify a response */ + int (*verify_response)(struct rxrpc_connection *, + struct sk_buff *, + u32 *); + + /* clear connection security */ + void (*clear)(struct rxrpc_connection *); +}; + +/* + * RxRPC local transport endpoint description + * - owned by a single AF_RXRPC socket + * - pointed to by transport socket struct sk_user_data + */ +struct rxrpc_local { + struct rcu_head rcu; + atomic_t active_users; /* Number of users of the local endpoint */ + refcount_t ref; /* Number of references to the structure */ + struct rxrpc_net *rxnet; /* The network ns in which this resides */ + struct hlist_node link; + struct socket *socket; /* my UDP socket */ + struct work_struct processor; + struct rxrpc_sock __rcu *service; /* Service(s) listening on this endpoint */ + struct rw_semaphore defrag_sem; /* control re-enablement of IP DF bit */ + struct sk_buff_head reject_queue; /* packets awaiting rejection */ + struct sk_buff_head event_queue; /* endpoint event packets awaiting processing */ + struct rb_root client_bundles; /* Client connection bundles by socket params */ + spinlock_t client_bundles_lock; /* Lock for client_bundles */ + spinlock_t lock; /* access lock */ + rwlock_t services_lock; /* lock for services list */ + int debug_id; /* debug ID for printks */ + bool dead; + bool service_closed; /* Service socket closed */ + struct sockaddr_rxrpc srx; /* local address */ +}; + +/* + * RxRPC remote transport endpoint definition + * - matched by local endpoint, remote port, address and protocol type + */ +struct rxrpc_peer { + struct rcu_head rcu; /* This must be first */ + refcount_t ref; + unsigned long hash_key; + struct hlist_node hash_link; + struct rxrpc_local *local; + struct hlist_head error_targets; /* targets for net error distribution */ + struct rb_root service_conns; /* Service connections */ + struct list_head keepalive_link; /* Link in net->peer_keepalive[] */ + time64_t last_tx_at; /* Last time packet sent here */ + seqlock_t service_conn_lock; + spinlock_t lock; /* access lock */ + unsigned int if_mtu; /* interface MTU for this peer */ + unsigned int mtu; /* network MTU for this peer */ + unsigned int maxdata; /* data size (MTU - hdrsize) */ + unsigned short hdrsize; /* header size (IP + UDP + RxRPC) */ + int debug_id; /* debug ID for printks */ + struct sockaddr_rxrpc srx; /* remote address */ + + /* calculated RTT cache */ +#define RXRPC_RTT_CACHE_SIZE 32 + spinlock_t rtt_input_lock; /* RTT lock for input routine */ + ktime_t rtt_last_req; /* Time of last RTT request */ + unsigned int rtt_count; /* Number of samples we've got */ + + u32 srtt_us; /* smoothed round trip time << 3 in usecs */ + u32 mdev_us; /* medium deviation */ + u32 mdev_max_us; /* maximal mdev for the last rtt period */ + u32 rttvar_us; /* smoothed mdev_max */ + u32 rto_j; /* Retransmission timeout in jiffies */ + u8 backoff; /* Backoff timeout */ + + u8 cong_cwnd; /* Congestion window size */ +}; + +/* + * Keys for matching a connection. 
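+ * The epoch and connection ID are overlaid on a single 64-bit index_key so + * that a connection can be matched with one comparison.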
+ */ +struct rxrpc_conn_proto { + union { + struct { + u32 epoch; /* epoch of this connection */ + u32 cid; /* connection ID */ + }; + u64 index_key; + }; +}; + +struct rxrpc_conn_parameters { + struct rxrpc_local *local; /* Representation of local endpoint */ + struct rxrpc_peer *peer; /* Remote endpoint */ + struct key *key; /* Security details */ + bool exclusive; /* T if conn is exclusive */ + bool upgrade; /* T if service ID can be upgraded */ + u16 service_id; /* Service ID for this connection */ + u32 security_level; /* Security level selected */ +}; + +/* + * Bits in the connection flags. + */ +enum rxrpc_conn_flag { + RXRPC_CONN_HAS_IDR, /* Has a client conn ID assigned */ + RXRPC_CONN_IN_SERVICE_CONNS, /* Conn is in peer->service_conns */ + RXRPC_CONN_DONT_REUSE, /* Don't reuse this connection */ + RXRPC_CONN_PROBING_FOR_UPGRADE, /* Probing for service upgrade */ + RXRPC_CONN_FINAL_ACK_0, /* Need final ACK for channel 0 */ + RXRPC_CONN_FINAL_ACK_1, /* Need final ACK for channel 1 */ + RXRPC_CONN_FINAL_ACK_2, /* Need final ACK for channel 2 */ + RXRPC_CONN_FINAL_ACK_3, /* Need final ACK for channel 3 */ +}; + +#define RXRPC_CONN_FINAL_ACK_MASK ((1UL << RXRPC_CONN_FINAL_ACK_0) | \ + (1UL << RXRPC_CONN_FINAL_ACK_1) | \ + (1UL << RXRPC_CONN_FINAL_ACK_2) | \ + (1UL << RXRPC_CONN_FINAL_ACK_3)) + +/* + * Events that can be raised upon a connection. + */ +enum rxrpc_conn_event { + RXRPC_CONN_EV_CHALLENGE, /* Send challenge packet */ +}; + +/* + * The connection protocol state. + */ +enum rxrpc_conn_proto_state { + RXRPC_CONN_UNUSED, /* Connection not yet attempted */ + RXRPC_CONN_CLIENT, /* Client connection */ + RXRPC_CONN_SERVICE_PREALLOC, /* Service connection preallocation */ + RXRPC_CONN_SERVICE_UNSECURED, /* Service unsecured connection */ + RXRPC_CONN_SERVICE_CHALLENGING, /* Service challenging for security */ + RXRPC_CONN_SERVICE, /* Service secured connection */ + RXRPC_CONN_REMOTELY_ABORTED, /* Conn aborted by peer */ + RXRPC_CONN_LOCALLY_ABORTED, /* Conn aborted locally */ + RXRPC_CONN__NR_STATES +}; + +/* + * RxRPC client connection bundle. 
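+ * A bundle groups up to four client connections created with the same + * parameters; as each connection carries four call channels, a bundle can + * multiplex up to sixteen concurrent calls to the same service.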
+ */ +struct rxrpc_bundle { + struct rxrpc_conn_parameters params; + refcount_t ref; + atomic_t active; /* Number of active users */ + unsigned int debug_id; + bool try_upgrade; /* True if the bundle is attempting upgrade */ + bool alloc_conn; /* True if someone's getting a conn */ + short alloc_error; /* Error from last conn allocation */ + spinlock_t channel_lock; + struct rb_node local_node; /* Node in local->client_conns */ + struct list_head waiting_calls; /* Calls waiting for channels */ + unsigned long avail_chans; /* Mask of available channels */ + struct rxrpc_connection *conns[4]; /* The connections in the bundle (max 4) */ +}; + +/* + * RxRPC connection definition + * - matched by { local, peer, epoch, conn_id, direction } + * - each connection can only handle four simultaneous calls + */ +struct rxrpc_connection { + struct rxrpc_conn_proto proto; + struct rxrpc_conn_parameters params; + + refcount_t ref; + struct rcu_head rcu; + struct list_head cache_link; + + unsigned char act_chans; /* Mask of active channels */ + struct rxrpc_channel { + unsigned long final_ack_at; /* Time at which to issue final ACK */ + struct rxrpc_call __rcu *call; /* Active call */ + unsigned int call_debug_id; /* call->debug_id */ + u32 call_id; /* ID of current call */ + u32 call_counter; /* Call ID counter */ + u32 last_call; /* ID of last call */ + u8 last_type; /* Type of last packet */ + union { + u32 last_seq; + u32 last_abort; + }; + } channels[RXRPC_MAXCALLS]; + + struct timer_list timer; /* Conn event timer */ + struct work_struct processor; /* connection event processor */ + struct rxrpc_bundle *bundle; /* Client connection bundle */ + struct rb_node service_node; /* Node in peer->service_conns */ + struct list_head proc_link; /* link in procfs list */ + struct list_head link; /* link in master connection list */ + struct sk_buff_head rx_queue; /* received conn-level packets */ + + const struct rxrpc_security *security; /* applied security module */ + union { + struct { + struct crypto_sync_skcipher *cipher; /* encryption handle */ + struct rxrpc_crypt csum_iv; /* packet checksum base */ + u32 nonce; /* response re-use preventer */ + } rxkad; + }; + unsigned long flags; + unsigned long events; + unsigned long idle_timestamp; /* Time at which last became idle */ + spinlock_t state_lock; /* state-change lock */ + enum rxrpc_conn_proto_state state; /* current state of connection */ + u32 abort_code; /* Abort code of connection abort */ + int debug_id; /* debug ID for printks */ + atomic_t serial; /* packet serial number counter */ + unsigned int hi_serial; /* highest serial number received */ + u32 service_id; /* Service ID, possibly upgraded */ + u8 security_ix; /* security type */ + u8 out_clientflag; /* RXRPC_CLIENT_INITIATED if we are client */ + u8 bundle_shift; /* Index into bundle->avail_chans */ + short error; /* Local error code */ +}; + +static inline bool rxrpc_to_server(const struct rxrpc_skb_priv *sp) +{ + return sp->hdr.flags & RXRPC_CLIENT_INITIATED; +} + +static inline bool rxrpc_to_client(const struct rxrpc_skb_priv *sp) +{ + return !rxrpc_to_server(sp); +} + +/* + * Flags in call->flags. 
+ */ +enum rxrpc_call_flag { + RXRPC_CALL_RELEASED, /* call has been released - no more message to userspace */ + RXRPC_CALL_HAS_USERID, /* has a user ID attached */ + RXRPC_CALL_IS_SERVICE, /* Call is service call */ + RXRPC_CALL_EXPOSED, /* The call was exposed to the world */ + RXRPC_CALL_RX_LAST, /* Received the last packet (at rxtx_top) */ + RXRPC_CALL_TX_LAST, /* Last packet in Tx buffer (at rxtx_top) */ + RXRPC_CALL_SEND_PING, /* A ping will need to be sent */ + RXRPC_CALL_RETRANS_TIMEOUT, /* Retransmission due to timeout occurred */ + RXRPC_CALL_BEGAN_RX_TIMER, /* We began the expect_rx_by timer */ + RXRPC_CALL_RX_HEARD, /* The peer responded at least once to this call */ + RXRPC_CALL_RX_UNDERRUN, /* Got data underrun */ + RXRPC_CALL_DISCONNECTED, /* The call has been disconnected */ + RXRPC_CALL_KERNEL, /* The call was made by the kernel */ + RXRPC_CALL_UPGRADE, /* Service upgrade was requested for the call */ +}; + +/* + * Events that can be raised on a call. + */ +enum rxrpc_call_event { + RXRPC_CALL_EV_ACK, /* need to generate ACK */ + RXRPC_CALL_EV_ABORT, /* need to generate abort */ + RXRPC_CALL_EV_RESEND, /* Tx resend required */ + RXRPC_CALL_EV_PING, /* Ping send required */ + RXRPC_CALL_EV_EXPIRED, /* Expiry occurred */ + RXRPC_CALL_EV_ACK_LOST, /* ACK may be lost, send ping */ +}; + +/* + * The states that a call can be in. + */ +enum rxrpc_call_state { + RXRPC_CALL_UNINITIALISED, + RXRPC_CALL_CLIENT_AWAIT_CONN, /* - client waiting for connection to become available */ + RXRPC_CALL_CLIENT_SEND_REQUEST, /* - client sending request phase */ + RXRPC_CALL_CLIENT_AWAIT_REPLY, /* - client awaiting reply */ + RXRPC_CALL_CLIENT_RECV_REPLY, /* - client receiving reply phase */ + RXRPC_CALL_SERVER_PREALLOC, /* - service preallocation */ + RXRPC_CALL_SERVER_SECURING, /* - server securing request connection */ + RXRPC_CALL_SERVER_RECV_REQUEST, /* - server receiving request */ + RXRPC_CALL_SERVER_ACK_REQUEST, /* - server pending ACK of request */ + RXRPC_CALL_SERVER_SEND_REPLY, /* - server sending reply */ + RXRPC_CALL_SERVER_AWAIT_ACK, /* - server awaiting final ACK */ + RXRPC_CALL_COMPLETE, /* - call complete */ + NR__RXRPC_CALL_STATES +}; + +/* + * Call completion condition (state == RXRPC_CALL_COMPLETE). + */ +enum rxrpc_call_completion { + RXRPC_CALL_SUCCEEDED, /* - Normal termination */ + RXRPC_CALL_REMOTELY_ABORTED, /* - call aborted by peer */ + RXRPC_CALL_LOCALLY_ABORTED, /* - call aborted locally on error or close */ + RXRPC_CALL_LOCAL_ERROR, /* - call failed due to local error */ + RXRPC_CALL_NETWORK_ERROR, /* - call terminated by network error */ + NR__RXRPC_CALL_COMPLETIONS +}; + +/* + * Call Tx congestion management modes. 
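+ * These follow the TCP scheme of RFC 5681 (slow start, congestion avoidance + * and the loss/fast-retransmit responses), with the window counted in DATA + * packets rather than bytes.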
+ */ +enum rxrpc_congest_mode { + RXRPC_CALL_SLOW_START, + RXRPC_CALL_CONGEST_AVOIDANCE, + RXRPC_CALL_PACKET_LOSS, + RXRPC_CALL_FAST_RETRANSMIT, + NR__RXRPC_CONGEST_MODES +}; + +/* + * RxRPC call definition + * - matched by { connection, call_id } + */ +struct rxrpc_call { + struct rcu_head rcu; + struct rxrpc_connection *conn; /* connection carrying call */ + struct rxrpc_peer *peer; /* Peer record for remote address */ + struct rxrpc_sock __rcu *socket; /* socket responsible */ + struct rxrpc_net *rxnet; /* Network namespace to which call belongs */ + const struct rxrpc_security *security; /* applied security module */ + struct mutex user_mutex; /* User access mutex */ + unsigned long ack_at; /* When deferred ACK needs to happen */ + unsigned long ack_lost_at; /* When ACK is figured as lost */ + unsigned long resend_at; /* When next resend needs to happen */ + unsigned long ping_at; /* When next to send a ping */ + unsigned long keepalive_at; /* When next to send a keepalive ping */ + unsigned long expect_rx_by; /* When we expect to get a packet by */ + unsigned long expect_req_by; /* When we expect to get a request DATA packet by */ + unsigned long expect_term_by; /* When we expect call termination by */ + u32 next_rx_timo; /* Timeout for next Rx packet (jif) */ + u32 next_req_timo; /* Timeout for next Rx request packet (jif) */ + struct skcipher_request *cipher_req; /* Packet cipher request buffer */ + struct timer_list timer; /* Combined event timer */ + struct work_struct processor; /* Event processor */ + rxrpc_notify_rx_t notify_rx; /* kernel service Rx notification function */ + struct list_head link; /* link in master call list */ + struct list_head chan_wait_link; /* Link in conn->bundle->waiting_calls */ + struct hlist_node error_link; /* link in error distribution list */ + struct list_head accept_link; /* Link in rx->acceptq */ + struct list_head recvmsg_link; /* Link in rx->recvmsg_q */ + struct list_head sock_link; /* Link in rx->sock_calls */ + struct rb_node sock_node; /* Node in rx->calls */ + struct sk_buff *tx_pending; /* Tx socket buffer being filled */ + wait_queue_head_t waitq; /* Wait queue for channel or Tx */ + s64 tx_total_len; /* Total length left to be transmitted (or -1) */ + __be32 crypto_buf[2]; /* Temporary packet crypto buffer */ + unsigned long user_call_ID; /* user-defined call ID */ + unsigned long flags; + unsigned long events; + spinlock_t lock; + spinlock_t notify_lock; /* Kernel notification lock */ + rwlock_t state_lock; /* lock for state transition */ + u32 abort_code; /* Local/remote abort code */ + int error; /* Local error incurred */ + enum rxrpc_call_state state; /* current state of call */ + enum rxrpc_call_completion completion; /* Call completion condition */ + refcount_t ref; + u16 service_id; /* service ID */ + u8 security_ix; /* Security type */ + enum rxrpc_interruptibility interruptibility; /* At what point call may be interrupted */ + u32 call_id; /* call ID on connection */ + u32 cid; /* connection ID plus channel index */ + int debug_id; /* debug ID for printks */ + unsigned short rx_pkt_offset; /* Current recvmsg packet offset */ + unsigned short rx_pkt_len; /* Current recvmsg packet len */ + bool rx_pkt_last; /* Current recvmsg packet is last */ + + /* Rx/Tx circular buffer, depending on phase. + * + * In the Rx phase, packets are annotated with 0 or the number of the + * segment of a jumbo packet each buffer refers to. There can be up to + * 47 segments in a maximum-size UDP packet. 
+ * + * In the Tx phase, packets are annotated with which buffers have been + * acked. + */ +#define RXRPC_RXTX_BUFF_SIZE 64 +#define RXRPC_RXTX_BUFF_MASK (RXRPC_RXTX_BUFF_SIZE - 1) +#define RXRPC_INIT_RX_WINDOW_SIZE 63 + struct sk_buff **rxtx_buffer; + u8 *rxtx_annotations; +#define RXRPC_TX_ANNO_ACK 0 +#define RXRPC_TX_ANNO_UNACK 1 +#define RXRPC_TX_ANNO_NAK 2 +#define RXRPC_TX_ANNO_RETRANS 3 +#define RXRPC_TX_ANNO_MASK 0x03 +#define RXRPC_TX_ANNO_LAST 0x04 +#define RXRPC_TX_ANNO_RESENT 0x08 + +#define RXRPC_RX_ANNO_SUBPACKET 0x3f /* Subpacket number in jumbogram */ +#define RXRPC_RX_ANNO_VERIFIED 0x80 /* Set if verified and decrypted */ + rxrpc_seq_t tx_hard_ack; /* Dead slot in buffer; the first transmitted but + * not hard-ACK'd packet follows this. + */ + rxrpc_seq_t tx_top; /* Highest Tx slot allocated. */ + u16 tx_backoff; /* Delay to insert due to Tx failure */ + + /* TCP-style slow-start congestion control [RFC5681]. Since the SMSS + * is fixed, we keep these numbers in terms of segments (ie. DATA + * packets) rather than bytes. + */ +#define RXRPC_TX_SMSS RXRPC_JUMBO_DATALEN + u8 cong_cwnd; /* Congestion window size */ + u8 cong_extra; /* Extra to send for congestion management */ + u8 cong_ssthresh; /* Slow-start threshold */ + enum rxrpc_congest_mode cong_mode:8; /* Congestion management mode */ + u8 cong_dup_acks; /* Count of ACKs showing missing packets */ + u8 cong_cumul_acks; /* Cumulative ACK count */ + ktime_t cong_tstamp; /* Last time cwnd was changed */ + + rxrpc_seq_t rx_hard_ack; /* Dead slot in buffer; the first received but not + * consumed packet follows this. + */ + rxrpc_seq_t rx_top; /* Highest Rx slot allocated. */ + rxrpc_seq_t rx_expect_next; /* Expected next packet sequence number */ + rxrpc_serial_t rx_serial; /* Highest serial received for this call */ + u8 rx_winsize; /* Size of Rx window */ + u8 tx_winsize; /* Maximum size of Tx window */ + bool tx_phase; /* T if transmission phase, F if receive phase */ + u8 nr_jumbo_bad; /* Number of jumbo dups/exceeds-windows */ + + spinlock_t input_lock; /* Lock for packet input to this call */ + + /* Receive-phase ACK management (ACKs we send). */ + u8 ackr_reason; /* reason to ACK */ + rxrpc_serial_t ackr_serial; /* serial of packet being ACK'd */ + rxrpc_seq_t ackr_highest_seq; /* Highest sequence number received */ + atomic_t ackr_nr_unacked; /* Number of unacked packets */ + atomic_t ackr_nr_consumed; /* Number of packets needing hard ACK */ + + /* RTT management */ + rxrpc_serial_t rtt_serial[4]; /* Serial number of DATA or PING sent */ + ktime_t rtt_sent_at[4]; /* Time packet sent */ + unsigned long rtt_avail; /* Mask of available slots in bits 0-3, + * Mask of pending samples in 8-11 */ +#define RXRPC_CALL_RTT_AVAIL_MASK 0xf +#define RXRPC_CALL_RTT_PEND_SHIFT 8 + + /* Transmission-phase ACK management (ACKs we've received). */ + ktime_t acks_latest_ts; /* Timestamp of latest ACK received */ + rxrpc_seq_t acks_first_seq; /* first sequence number received */ + rxrpc_seq_t acks_prev_seq; /* Highest previousPacket received */ + rxrpc_seq_t acks_lowest_nak; /* Lowest NACK in the buffer (or ==tx_hard_ack) */ + rxrpc_seq_t acks_lost_top; /* tx_top at the time lost-ack ping sent */ + rxrpc_serial_t acks_lost_ping; /* Serial number of probe ACK */ +}; + +/* + * Summary of a new ACK and the changes it made to the Tx buffer packet states. 
+ */ +struct rxrpc_ack_summary { + u8 ack_reason; + u8 nr_acks; /* Number of ACKs in packet */ + u8 nr_nacks; /* Number of NACKs in packet */ + u8 nr_new_acks; /* Number of new ACKs in packet */ + u8 nr_new_nacks; /* Number of new NACKs in packet */ + u8 nr_rot_new_acks; /* Number of rotated new ACKs */ + bool new_low_nack; /* T if new low NACK found */ + bool retrans_timeo; /* T if reTx due to timeout happened */ + u8 flight_size; /* Number of unreceived transmissions */ + /* Place to stash values for tracing */ + enum rxrpc_congest_mode mode:8; + u8 cwnd; + u8 ssthresh; + u8 dup_acks; + u8 cumulative_acks; +}; + +/* + * sendmsg() cmsg-specified parameters. + */ +enum rxrpc_command { + RXRPC_CMD_SEND_DATA, /* send data message */ + RXRPC_CMD_SEND_ABORT, /* request abort generation */ + RXRPC_CMD_REJECT_BUSY, /* [server] reject a call as busy */ + RXRPC_CMD_CHARGE_ACCEPT, /* [server] charge accept preallocation */ +}; + +struct rxrpc_call_params { + s64 tx_total_len; /* Total Tx data length (if send data) */ + unsigned long user_call_ID; /* User's call ID */ + struct { + u32 hard; /* Maximum lifetime (sec) */ + u32 idle; /* Max time since last data packet (msec) */ + u32 normal; /* Max time since last call packet (msec) */ + } timeouts; + u8 nr_timeouts; /* Number of timeouts specified */ + bool kernel; /* T if kernel is making the call */ + enum rxrpc_interruptibility interruptibility; /* How interruptible is the call? */ +}; + +struct rxrpc_send_params { + struct rxrpc_call_params call; + u32 abort_code; /* Abort code to Tx (if abort) */ + enum rxrpc_command command : 8; /* The command to implement */ + bool exclusive; /* Shared or exclusive call */ + bool upgrade; /* If the connection is upgradeable */ +}; + +#include + +/* + * af_rxrpc.c + */ +extern atomic_t rxrpc_n_tx_skbs, rxrpc_n_rx_skbs; +extern struct workqueue_struct *rxrpc_workqueue; + +/* + * call_accept.c + */ +int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t); +void rxrpc_discard_prealloc(struct rxrpc_sock *); +struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *, + struct rxrpc_sock *, + struct sk_buff *); +void rxrpc_accept_incoming_calls(struct rxrpc_local *); +int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long); + +/* + * call_event.c + */ +void rxrpc_propose_ACK(struct rxrpc_call *, u8, u32, bool, bool, + enum rxrpc_propose_ack_trace); +void rxrpc_process_call(struct work_struct *); + +void rxrpc_reduce_call_timer(struct rxrpc_call *call, + unsigned long expire_at, + unsigned long now, + enum rxrpc_timer_trace why); + +void rxrpc_delete_call_timer(struct rxrpc_call *call); + +/* + * call_object.c + */ +extern const char *const rxrpc_call_states[]; +extern const char *const rxrpc_call_completions[]; +extern struct kmem_cache *rxrpc_call_jar; + +struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *, unsigned long); +struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *, gfp_t, unsigned int); +struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *, + struct rxrpc_conn_parameters *, + struct sockaddr_rxrpc *, + struct rxrpc_call_params *, gfp_t, + unsigned int); +void rxrpc_incoming_call(struct rxrpc_sock *, struct rxrpc_call *, + struct sk_buff *); +void rxrpc_release_call(struct rxrpc_sock *, struct rxrpc_call *); +void rxrpc_release_calls_on_socket(struct rxrpc_sock *); +bool __rxrpc_queue_call(struct rxrpc_call *); +bool rxrpc_queue_call(struct rxrpc_call *); +void rxrpc_see_call(struct rxrpc_call *); +bool rxrpc_try_get_call(struct rxrpc_call *call, enum 
rxrpc_call_trace op); +void rxrpc_get_call(struct rxrpc_call *, enum rxrpc_call_trace); +void rxrpc_put_call(struct rxrpc_call *, enum rxrpc_call_trace); +void rxrpc_cleanup_call(struct rxrpc_call *); +void rxrpc_destroy_all_calls(struct rxrpc_net *); + +static inline bool rxrpc_is_service_call(const struct rxrpc_call *call) +{ + return test_bit(RXRPC_CALL_IS_SERVICE, &call->flags); +} + +static inline bool rxrpc_is_client_call(const struct rxrpc_call *call) +{ + return !rxrpc_is_service_call(call); +} + +/* + * conn_client.c + */ +extern unsigned int rxrpc_reap_client_connections; +extern unsigned long rxrpc_conn_idle_client_expiry; +extern unsigned long rxrpc_conn_idle_client_fast_expiry; +extern struct idr rxrpc_client_conn_ids; + +void rxrpc_destroy_client_conn_ids(void); +struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *); +void rxrpc_put_bundle(struct rxrpc_bundle *); +int rxrpc_connect_call(struct rxrpc_sock *, struct rxrpc_call *, + struct rxrpc_conn_parameters *, struct sockaddr_rxrpc *, + gfp_t); +void rxrpc_expose_client_call(struct rxrpc_call *); +void rxrpc_disconnect_client_call(struct rxrpc_bundle *, struct rxrpc_call *); +void rxrpc_put_client_conn(struct rxrpc_connection *); +void rxrpc_discard_expired_client_conns(struct work_struct *); +void rxrpc_destroy_all_client_connections(struct rxrpc_net *); +void rxrpc_clean_up_local_conns(struct rxrpc_local *); + +/* + * conn_event.c + */ +void rxrpc_process_connection(struct work_struct *); +void rxrpc_process_delayed_final_acks(struct rxrpc_connection *, bool); + +/* + * conn_object.c + */ +extern unsigned int rxrpc_connection_expiry; +extern unsigned int rxrpc_closed_conn_expiry; + +struct rxrpc_connection *rxrpc_alloc_connection(gfp_t); +struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *, + struct sk_buff *, + struct rxrpc_peer **); +void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *); +void rxrpc_disconnect_call(struct rxrpc_call *); +void rxrpc_kill_connection(struct rxrpc_connection *); +bool rxrpc_queue_conn(struct rxrpc_connection *); +void rxrpc_see_connection(struct rxrpc_connection *); +struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *); +struct rxrpc_connection *rxrpc_get_connection_maybe(struct rxrpc_connection *); +void rxrpc_put_service_conn(struct rxrpc_connection *); +void rxrpc_service_connection_reaper(struct work_struct *); +void rxrpc_destroy_all_connections(struct rxrpc_net *); + +static inline bool rxrpc_conn_is_client(const struct rxrpc_connection *conn) +{ + return conn->out_clientflag; +} + +static inline bool rxrpc_conn_is_service(const struct rxrpc_connection *conn) +{ + return !rxrpc_conn_is_client(conn); +} + +static inline void rxrpc_put_connection(struct rxrpc_connection *conn) +{ + if (!conn) + return; + + if (rxrpc_conn_is_client(conn)) + rxrpc_put_client_conn(conn); + else + rxrpc_put_service_conn(conn); +} + +static inline void rxrpc_reduce_conn_timer(struct rxrpc_connection *conn, + unsigned long expire_at) +{ + timer_reduce(&conn->timer, expire_at); +} + +/* + * conn_service.c + */ +struct rxrpc_connection *rxrpc_find_service_conn_rcu(struct rxrpc_peer *, + struct sk_buff *); +struct rxrpc_connection *rxrpc_prealloc_service_connection(struct rxrpc_net *, gfp_t); +void rxrpc_new_incoming_connection(struct rxrpc_sock *, struct rxrpc_connection *, + const struct rxrpc_security *, struct sk_buff *); +void rxrpc_unpublish_service_conn(struct rxrpc_connection *); + +/* + * input.c + */ +int rxrpc_input_packet(struct 
sock *, struct sk_buff *); + +/* + * insecure.c + */ +extern const struct rxrpc_security rxrpc_no_security; + +/* + * key.c + */ +extern struct key_type key_type_rxrpc; + +int rxrpc_request_key(struct rxrpc_sock *, sockptr_t , int); +int rxrpc_get_server_data_key(struct rxrpc_connection *, const void *, time64_t, + u32); + +/* + * local_event.c + */ +extern void rxrpc_process_local_events(struct rxrpc_local *); + +/* + * local_object.c + */ +struct rxrpc_local *rxrpc_lookup_local(struct net *, const struct sockaddr_rxrpc *); +struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *); +struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *); +void rxrpc_put_local(struct rxrpc_local *); +struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *); +void rxrpc_unuse_local(struct rxrpc_local *); +void rxrpc_queue_local(struct rxrpc_local *); +void rxrpc_destroy_all_locals(struct rxrpc_net *); + +static inline bool __rxrpc_unuse_local(struct rxrpc_local *local) +{ + return atomic_dec_return(&local->active_users) == 0; +} + +static inline bool __rxrpc_use_local(struct rxrpc_local *local) +{ + return atomic_fetch_add_unless(&local->active_users, 1, 0) != 0; +} + +/* + * misc.c + */ +extern unsigned int rxrpc_max_backlog __read_mostly; +extern unsigned long rxrpc_requested_ack_delay; +extern unsigned long rxrpc_soft_ack_delay; +extern unsigned long rxrpc_idle_ack_delay; +extern unsigned int rxrpc_rx_window_size; +extern unsigned int rxrpc_rx_mtu; +extern unsigned int rxrpc_rx_jumbo_max; + +extern const s8 rxrpc_ack_priority[]; + +/* + * net_ns.c + */ +extern unsigned int rxrpc_net_id; +extern struct pernet_operations rxrpc_net_ops; + +static inline struct rxrpc_net *rxrpc_net(struct net *net) +{ + return net_generic(net, rxrpc_net_id); +} + +/* + * output.c + */ +int rxrpc_send_ack_packet(struct rxrpc_call *, bool, rxrpc_serial_t *); +int rxrpc_send_abort_packet(struct rxrpc_call *); +int rxrpc_send_data_packet(struct rxrpc_call *, struct sk_buff *, bool); +void rxrpc_reject_packets(struct rxrpc_local *); +void rxrpc_send_keepalive(struct rxrpc_peer *); + +/* + * peer_event.c + */ +void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, unsigned int udp_offset); +void rxrpc_error_report(struct sock *); +void rxrpc_peer_keepalive_worker(struct work_struct *); + +/* + * peer_object.c + */ +struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *, + const struct sockaddr_rxrpc *); +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *, struct rxrpc_local *, + struct sockaddr_rxrpc *, gfp_t); +struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *, gfp_t); +void rxrpc_new_incoming_peer(struct rxrpc_sock *, struct rxrpc_local *, + struct rxrpc_peer *); +void rxrpc_destroy_all_peers(struct rxrpc_net *); +struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *); +struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *); +void rxrpc_put_peer(struct rxrpc_peer *); +void rxrpc_put_peer_locked(struct rxrpc_peer *); + +/* + * proc.c + */ +extern const struct seq_operations rxrpc_call_seq_ops; +extern const struct seq_operations rxrpc_connection_seq_ops; +extern const struct seq_operations rxrpc_peer_seq_ops; +extern const struct seq_operations rxrpc_local_seq_ops; + +/* + * recvmsg.c + */ +void rxrpc_notify_socket(struct rxrpc_call *); +bool __rxrpc_set_call_completion(struct rxrpc_call *, enum rxrpc_call_completion, u32, int); +bool rxrpc_set_call_completion(struct rxrpc_call *, enum rxrpc_call_completion, u32, int); +bool __rxrpc_call_completed(struct rxrpc_call *); +bool 
rxrpc_call_completed(struct rxrpc_call *); +bool __rxrpc_abort_call(const char *, struct rxrpc_call *, rxrpc_seq_t, u32, int); +bool rxrpc_abort_call(const char *, struct rxrpc_call *, rxrpc_seq_t, u32, int); +int rxrpc_recvmsg(struct socket *, struct msghdr *, size_t, int); + +/* + * Abort a call due to a protocol error. + */ +static inline bool __rxrpc_abort_eproto(struct rxrpc_call *call, + struct sk_buff *skb, + const char *eproto_why, + const char *why, + u32 abort_code) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + trace_rxrpc_rx_eproto(call, sp->hdr.serial, eproto_why); + return rxrpc_abort_call(why, call, sp->hdr.seq, abort_code, -EPROTO); +} + +#define rxrpc_abort_eproto(call, skb, eproto_why, abort_why, abort_code) \ + __rxrpc_abort_eproto((call), (skb), tracepoint_string(eproto_why), \ + (abort_why), (abort_code)) + +/* + * rtt.c + */ +void rxrpc_peer_add_rtt(struct rxrpc_call *, enum rxrpc_rtt_rx_trace, int, + rxrpc_serial_t, rxrpc_serial_t, ktime_t, ktime_t); +unsigned long rxrpc_get_rto_backoff(struct rxrpc_peer *, bool); +void rxrpc_peer_init_rtt(struct rxrpc_peer *); + +/* + * rxkad.c + */ +#ifdef CONFIG_RXKAD +extern const struct rxrpc_security rxkad; +#endif + +/* + * security.c + */ +int __init rxrpc_init_security(void); +const struct rxrpc_security *rxrpc_security_lookup(u8); +void rxrpc_exit_security(void); +int rxrpc_init_client_conn_security(struct rxrpc_connection *); +const struct rxrpc_security *rxrpc_get_incoming_security(struct rxrpc_sock *, + struct sk_buff *); +struct key *rxrpc_look_up_server_security(struct rxrpc_connection *, + struct sk_buff *, u32, u32); + +/* + * sendmsg.c + */ +int rxrpc_do_sendmsg(struct rxrpc_sock *, struct msghdr *, size_t); + +/* + * server_key.c + */ +extern struct key_type key_type_rxrpc_s; + +int rxrpc_server_keyring(struct rxrpc_sock *, sockptr_t, int); + +/* + * skbuff.c + */ +void rxrpc_kernel_data_consumed(struct rxrpc_call *, struct sk_buff *); +void rxrpc_packet_destructor(struct sk_buff *); +void rxrpc_new_skb(struct sk_buff *, enum rxrpc_skb_trace); +void rxrpc_see_skb(struct sk_buff *, enum rxrpc_skb_trace); +void rxrpc_eaten_skb(struct sk_buff *, enum rxrpc_skb_trace); +void rxrpc_get_skb(struct sk_buff *, enum rxrpc_skb_trace); +void rxrpc_free_skb(struct sk_buff *, enum rxrpc_skb_trace); +void rxrpc_purge_queue(struct sk_buff_head *); + +/* + * sysctl.c + */ +#ifdef CONFIG_SYSCTL +extern int __init rxrpc_sysctl_init(void); +extern void rxrpc_sysctl_exit(void); +#else +static inline int __init rxrpc_sysctl_init(void) { return 0; } +static inline void rxrpc_sysctl_exit(void) {} +#endif + +/* + * utils.c + */ +int rxrpc_extract_addr_from_skb(struct sockaddr_rxrpc *, struct sk_buff *); + +static inline bool before(u32 seq1, u32 seq2) +{ + return (s32)(seq1 - seq2) < 0; +} +static inline bool before_eq(u32 seq1, u32 seq2) +{ + return (s32)(seq1 - seq2) <= 0; +} +static inline bool after(u32 seq1, u32 seq2) +{ + return (s32)(seq1 - seq2) > 0; +} +static inline bool after_eq(u32 seq1, u32 seq2) +{ + return (s32)(seq1 - seq2) >= 0; +} + +/* + * debug tracing + */ +extern unsigned int rxrpc_debug; + +#define dbgprintk(FMT,...) \ + printk("[%-6.6s] "FMT"\n", current->comm ,##__VA_ARGS__) + +#define kenter(FMT,...) dbgprintk("==> %s("FMT")",__func__ ,##__VA_ARGS__) +#define kleave(FMT,...) dbgprintk("<== %s()"FMT"",__func__ ,##__VA_ARGS__) +#define kdebug(FMT,...) dbgprintk(" "FMT ,##__VA_ARGS__) +#define kproto(FMT,...) dbgprintk("### "FMT ,##__VA_ARGS__) +#define knet(FMT,...) 
dbgprintk("@@@ "FMT ,##__VA_ARGS__) + + +#if defined(__KDEBUG) +#define _enter(FMT,...) kenter(FMT,##__VA_ARGS__) +#define _leave(FMT,...) kleave(FMT,##__VA_ARGS__) +#define _debug(FMT,...) kdebug(FMT,##__VA_ARGS__) +#define _proto(FMT,...) kproto(FMT,##__VA_ARGS__) +#define _net(FMT,...) knet(FMT,##__VA_ARGS__) + +#elif defined(CONFIG_AF_RXRPC_DEBUG) +#define RXRPC_DEBUG_KENTER 0x01 +#define RXRPC_DEBUG_KLEAVE 0x02 +#define RXRPC_DEBUG_KDEBUG 0x04 +#define RXRPC_DEBUG_KPROTO 0x08 +#define RXRPC_DEBUG_KNET 0x10 + +#define _enter(FMT,...) \ +do { \ + if (unlikely(rxrpc_debug & RXRPC_DEBUG_KENTER)) \ + kenter(FMT,##__VA_ARGS__); \ +} while (0) + +#define _leave(FMT,...) \ +do { \ + if (unlikely(rxrpc_debug & RXRPC_DEBUG_KLEAVE)) \ + kleave(FMT,##__VA_ARGS__); \ +} while (0) + +#define _debug(FMT,...) \ +do { \ + if (unlikely(rxrpc_debug & RXRPC_DEBUG_KDEBUG)) \ + kdebug(FMT,##__VA_ARGS__); \ +} while (0) + +#define _proto(FMT,...) \ +do { \ + if (unlikely(rxrpc_debug & RXRPC_DEBUG_KPROTO)) \ + kproto(FMT,##__VA_ARGS__); \ +} while (0) + +#define _net(FMT,...) \ +do { \ + if (unlikely(rxrpc_debug & RXRPC_DEBUG_KNET)) \ + knet(FMT,##__VA_ARGS__); \ +} while (0) + +#else +#define _enter(FMT,...) no_printk("==> %s("FMT")",__func__ ,##__VA_ARGS__) +#define _leave(FMT,...) no_printk("<== %s()"FMT"",__func__ ,##__VA_ARGS__) +#define _debug(FMT,...) no_printk(" "FMT ,##__VA_ARGS__) +#define _proto(FMT,...) no_printk("### "FMT ,##__VA_ARGS__) +#define _net(FMT,...) no_printk("@@@ "FMT ,##__VA_ARGS__) +#endif + +/* + * debug assertion checking + */ +#if 1 // defined(__KDEBUGALL) + +#define ASSERT(X) \ +do { \ + if (unlikely(!(X))) { \ + pr_err("Assertion failed\n"); \ + BUG(); \ + } \ +} while (0) + +#define ASSERTCMP(X, OP, Y) \ +do { \ + __typeof__(X) _x = (X); \ + __typeof__(Y) _y = (__typeof__(X))(Y); \ + if (unlikely(!(_x OP _y))) { \ + pr_err("Assertion failed - %lu(0x%lx) %s %lu(0x%lx) is false\n", \ + (unsigned long)_x, (unsigned long)_x, #OP, \ + (unsigned long)_y, (unsigned long)_y); \ + BUG(); \ + } \ +} while (0) + +#define ASSERTIF(C, X) \ +do { \ + if (unlikely((C) && !(X))) { \ + pr_err("Assertion failed\n"); \ + BUG(); \ + } \ +} while (0) + +#define ASSERTIFCMP(C, X, OP, Y) \ +do { \ + __typeof__(X) _x = (X); \ + __typeof__(Y) _y = (__typeof__(X))(Y); \ + if (unlikely((C) && !(_x OP _y))) { \ + pr_err("Assertion failed - %lu(0x%lx) %s %lu(0x%lx) is false\n", \ + (unsigned long)_x, (unsigned long)_x, #OP, \ + (unsigned long)_y, (unsigned long)_y); \ + BUG(); \ + } \ +} while (0) + +#else + +#define ASSERT(X) \ +do { \ +} while (0) + +#define ASSERTCMP(X, OP, Y) \ +do { \ +} while (0) + +#define ASSERTIF(C, X) \ +do { \ +} while (0) + +#define ASSERTIFCMP(C, X, OP, Y) \ +do { \ +} while (0) + +#endif /* __KDEBUGALL */ diff --git a/net/rxrpc/call_accept.c b/net/rxrpc/call_accept.c new file mode 100644 index 000000000..99e10eea3 --- /dev/null +++ b/net/rxrpc/call_accept.c @@ -0,0 +1,491 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* incoming call handling + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static void rxrpc_dummy_notify(struct sock *sk, struct rxrpc_call *call, + unsigned long user_call_ID) +{ +} + +/* + * Preallocate a single service call, connection and peer and, if possible, + * give them a user ID and attach the user's side of the ID to them. + */ +static int rxrpc_service_prealloc_one(struct rxrpc_sock *rx, + struct rxrpc_backlog *b, + rxrpc_notify_rx_t notify_rx, + rxrpc_user_attach_call_t user_attach_call, + unsigned long user_call_ID, gfp_t gfp, + unsigned int debug_id) +{ + const void *here = __builtin_return_address(0); + struct rxrpc_call *call, *xcall; + struct rxrpc_net *rxnet = rxrpc_net(sock_net(&rx->sk)); + struct rb_node *parent, **pp; + int max, tmp; + unsigned int size = RXRPC_BACKLOG_MAX; + unsigned int head, tail, call_head, call_tail; + + max = rx->sk.sk_max_ack_backlog; + tmp = rx->sk.sk_ack_backlog; + if (tmp >= max) { + _leave(" = -ENOBUFS [full %u]", max); + return -ENOBUFS; + } + max -= tmp; + + /* We don't need more conns and peers than we have calls, but on the + * other hand, we shouldn't ever use more peers than conns or conns + * than calls. + */ + call_head = b->call_backlog_head; + call_tail = READ_ONCE(b->call_backlog_tail); + tmp = CIRC_CNT(call_head, call_tail, size); + if (tmp >= max) { + _leave(" = -ENOBUFS [enough %u]", tmp); + return -ENOBUFS; + } + max = tmp + 1; + + head = b->peer_backlog_head; + tail = READ_ONCE(b->peer_backlog_tail); + if (CIRC_CNT(head, tail, size) < max) { + struct rxrpc_peer *peer = rxrpc_alloc_peer(rx->local, gfp); + if (!peer) + return -ENOMEM; + b->peer_backlog[head] = peer; + smp_store_release(&b->peer_backlog_head, + (head + 1) & (size - 1)); + } + + head = b->conn_backlog_head; + tail = READ_ONCE(b->conn_backlog_tail); + if (CIRC_CNT(head, tail, size) < max) { + struct rxrpc_connection *conn; + + conn = rxrpc_prealloc_service_connection(rxnet, gfp); + if (!conn) + return -ENOMEM; + b->conn_backlog[head] = conn; + smp_store_release(&b->conn_backlog_head, + (head + 1) & (size - 1)); + + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_service, + refcount_read(&conn->ref), here); + } + + /* Now it gets complicated, because calls get registered with the + * socket here, with a user ID preassigned by the user. 
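+ * The call is indexed in rx->calls by that user ID; a duplicate ID is rejected + * with -EBADSLT.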
+ */ + call = rxrpc_alloc_call(rx, gfp, debug_id); + if (!call) + return -ENOMEM; + call->flags |= (1 << RXRPC_CALL_IS_SERVICE); + call->state = RXRPC_CALL_SERVER_PREALLOC; + + trace_rxrpc_call(call->debug_id, rxrpc_call_new_service, + refcount_read(&call->ref), + here, (const void *)user_call_ID); + + write_lock(&rx->call_lock); + + /* Check the user ID isn't already in use */ + pp = &rx->calls.rb_node; + parent = NULL; + while (*pp) { + parent = *pp; + xcall = rb_entry(parent, struct rxrpc_call, sock_node); + if (user_call_ID < xcall->user_call_ID) + pp = &(*pp)->rb_left; + else if (user_call_ID > xcall->user_call_ID) + pp = &(*pp)->rb_right; + else + goto id_in_use; + } + + call->user_call_ID = user_call_ID; + call->notify_rx = notify_rx; + if (user_attach_call) { + rxrpc_get_call(call, rxrpc_call_got_kernel); + user_attach_call(call, user_call_ID); + } + + rxrpc_get_call(call, rxrpc_call_got_userid); + rb_link_node(&call->sock_node, parent, pp); + rb_insert_color(&call->sock_node, &rx->calls); + set_bit(RXRPC_CALL_HAS_USERID, &call->flags); + + list_add(&call->sock_link, &rx->sock_calls); + + write_unlock(&rx->call_lock); + + rxnet = call->rxnet; + spin_lock_bh(&rxnet->call_lock); + list_add_tail_rcu(&call->link, &rxnet->calls); + spin_unlock_bh(&rxnet->call_lock); + + b->call_backlog[call_head] = call; + smp_store_release(&b->call_backlog_head, (call_head + 1) & (size - 1)); + _leave(" = 0 [%d -> %lx]", call->debug_id, user_call_ID); + return 0; + +id_in_use: + write_unlock(&rx->call_lock); + rxrpc_cleanup_call(call); + _leave(" = -EBADSLT"); + return -EBADSLT; +} + +/* + * Allocate the preallocation buffers for incoming service calls. These must + * be charged manually. + */ +int rxrpc_service_prealloc(struct rxrpc_sock *rx, gfp_t gfp) +{ + struct rxrpc_backlog *b = rx->backlog; + + if (!b) { + b = kzalloc(sizeof(struct rxrpc_backlog), gfp); + if (!b) + return -ENOMEM; + rx->backlog = b; + } + + return 0; +} + +/* + * Discard the preallocation on a service. + */ +void rxrpc_discard_prealloc(struct rxrpc_sock *rx) +{ + struct rxrpc_backlog *b = rx->backlog; + struct rxrpc_net *rxnet = rxrpc_net(sock_net(&rx->sk)); + unsigned int size = RXRPC_BACKLOG_MAX, head, tail; + + if (!b) + return; + rx->backlog = NULL; + + /* Make sure that there aren't any incoming calls in progress before we + * clear the preallocation buffers. 
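+ * Taking and immediately dropping incoming_lock acts as a barrier against + * rxrpc_new_incoming_call(), which charges calls from these buffers under the + * same lock.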
+ */ + spin_lock_bh(&rx->incoming_lock); + spin_unlock_bh(&rx->incoming_lock); + + head = b->peer_backlog_head; + tail = b->peer_backlog_tail; + while (CIRC_CNT(head, tail, size) > 0) { + struct rxrpc_peer *peer = b->peer_backlog[tail]; + rxrpc_put_local(peer->local); + kfree(peer); + tail = (tail + 1) & (size - 1); + } + + head = b->conn_backlog_head; + tail = b->conn_backlog_tail; + while (CIRC_CNT(head, tail, size) > 0) { + struct rxrpc_connection *conn = b->conn_backlog[tail]; + write_lock(&rxnet->conn_lock); + list_del(&conn->link); + list_del(&conn->proc_link); + write_unlock(&rxnet->conn_lock); + kfree(conn); + if (atomic_dec_and_test(&rxnet->nr_conns)) + wake_up_var(&rxnet->nr_conns); + tail = (tail + 1) & (size - 1); + } + + head = b->call_backlog_head; + tail = b->call_backlog_tail; + while (CIRC_CNT(head, tail, size) > 0) { + struct rxrpc_call *call = b->call_backlog[tail]; + rcu_assign_pointer(call->socket, rx); + if (rx->discard_new_call) { + _debug("discard %lx", call->user_call_ID); + rx->discard_new_call(call, call->user_call_ID); + if (call->notify_rx) + call->notify_rx = rxrpc_dummy_notify; + rxrpc_put_call(call, rxrpc_call_put_kernel); + } + rxrpc_call_completed(call); + rxrpc_release_call(rx, call); + rxrpc_put_call(call, rxrpc_call_put); + tail = (tail + 1) & (size - 1); + } + + kfree(b); +} + +/* + * Ping the other end to fill our RTT cache and to retrieve the rwind + * and MTU parameters. + */ +static void rxrpc_send_ping(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + ktime_t now = skb->tstamp; + + if (call->peer->rtt_count < 3 || + ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), now)) + rxrpc_propose_ACK(call, RXRPC_ACK_PING, sp->hdr.serial, + true, true, + rxrpc_propose_ack_ping_for_params); +} + +/* + * Allocate a new incoming call from the prealloc pool, along with a connection + * and a peer as necessary. + */ +static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, + struct rxrpc_local *local, + struct rxrpc_peer *peer, + struct rxrpc_connection *conn, + const struct rxrpc_security *sec, + struct sk_buff *skb) +{ + struct rxrpc_backlog *b = rx->backlog; + struct rxrpc_call *call; + unsigned short call_head, conn_head, peer_head; + unsigned short call_tail, conn_tail, peer_tail; + unsigned short call_count, conn_count; + + /* #calls >= #conns >= #peers must hold true. 
*/ + call_head = smp_load_acquire(&b->call_backlog_head); + call_tail = b->call_backlog_tail; + call_count = CIRC_CNT(call_head, call_tail, RXRPC_BACKLOG_MAX); + conn_head = smp_load_acquire(&b->conn_backlog_head); + conn_tail = b->conn_backlog_tail; + conn_count = CIRC_CNT(conn_head, conn_tail, RXRPC_BACKLOG_MAX); + ASSERTCMP(conn_count, >=, call_count); + peer_head = smp_load_acquire(&b->peer_backlog_head); + peer_tail = b->peer_backlog_tail; + ASSERTCMP(CIRC_CNT(peer_head, peer_tail, RXRPC_BACKLOG_MAX), >=, + conn_count); + + if (call_count == 0) + return NULL; + + if (!conn) { + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + if (!peer) { + peer = b->peer_backlog[peer_tail]; + if (rxrpc_extract_addr_from_skb(&peer->srx, skb) < 0) + return NULL; + b->peer_backlog[peer_tail] = NULL; + smp_store_release(&b->peer_backlog_tail, + (peer_tail + 1) & + (RXRPC_BACKLOG_MAX - 1)); + + rxrpc_new_incoming_peer(rx, local, peer); + } + + /* Now allocate and set up the connection */ + conn = b->conn_backlog[conn_tail]; + b->conn_backlog[conn_tail] = NULL; + smp_store_release(&b->conn_backlog_tail, + (conn_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); + conn->params.local = rxrpc_get_local(local); + conn->params.peer = peer; + rxrpc_see_connection(conn); + rxrpc_new_incoming_connection(rx, conn, sec, skb); + } else { + rxrpc_get_connection(conn); + } + + /* And now we can allocate and set up a new call */ + call = b->call_backlog[call_tail]; + b->call_backlog[call_tail] = NULL; + smp_store_release(&b->call_backlog_tail, + (call_tail + 1) & (RXRPC_BACKLOG_MAX - 1)); + + rxrpc_see_call(call); + call->conn = conn; + call->security = conn->security; + call->security_ix = conn->security_ix; + call->peer = rxrpc_get_peer(conn->params.peer); + call->cong_cwnd = call->peer->cong_cwnd; + return call; +} + +/* + * Set up a new incoming call. Called in BH context with the RCU read lock + * held. + * + * If this is for a kernel service, when we allocate the call, it will have + * three refs on it: (1) the kernel service, (2) the user_call_ID tree, (3) the + * retainer ref obtained from the backlog buffer. Prealloc calls for userspace + * services only have the ref from the backlog buffer. We want to pass this + * ref to non-BH context to dispose of. + * + * If we want to report an error, we mark the skb with the packet type and + * abort code and return NULL. + * + * The call is returned with the user access mutex held. + */ +struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, + struct rxrpc_sock *rx, + struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + const struct rxrpc_security *sec = NULL; + struct rxrpc_connection *conn; + struct rxrpc_peer *peer = NULL; + struct rxrpc_call *call = NULL; + + _enter(""); + + spin_lock(&rx->incoming_lock); + if (rx->sk.sk_state == RXRPC_SERVER_LISTEN_DISABLED || + rx->sk.sk_state == RXRPC_CLOSE) { + trace_rxrpc_abort(0, "CLS", sp->hdr.cid, sp->hdr.callNumber, + sp->hdr.seq, RX_INVALID_OPERATION, ESHUTDOWN); + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; + skb->priority = RX_INVALID_OPERATION; + goto no_call; + } + + /* The peer, connection and call may all have sprung into existence due + * to a duplicate packet being handled on another CPU in parallel, so + * we have to recheck the routing. However, we're now holding + * rx->incoming_lock, so the values should remain stable. 
+ */ + conn = rxrpc_find_connection_rcu(local, skb, &peer); + + if (!conn) { + sec = rxrpc_get_incoming_security(rx, skb); + if (!sec) + goto no_call; + } + + call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, skb); + if (!call) { + skb->mark = RXRPC_SKB_MARK_REJECT_BUSY; + goto no_call; + } + + trace_rxrpc_receive(call, rxrpc_receive_incoming, + sp->hdr.serial, sp->hdr.seq); + + /* Make the call live. */ + rxrpc_incoming_call(rx, call, skb); + conn = call->conn; + + if (rx->notify_new_call) + rx->notify_new_call(&rx->sk, call, call->user_call_ID); + + spin_lock(&conn->state_lock); + switch (conn->state) { + case RXRPC_CONN_SERVICE_UNSECURED: + conn->state = RXRPC_CONN_SERVICE_CHALLENGING; + set_bit(RXRPC_CONN_EV_CHALLENGE, &call->conn->events); + rxrpc_queue_conn(call->conn); + break; + + case RXRPC_CONN_SERVICE: + write_lock(&call->state_lock); + if (call->state < RXRPC_CALL_COMPLETE) + call->state = RXRPC_CALL_SERVER_RECV_REQUEST; + write_unlock(&call->state_lock); + break; + + case RXRPC_CONN_REMOTELY_ABORTED: + rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, + conn->abort_code, conn->error); + break; + case RXRPC_CONN_LOCALLY_ABORTED: + rxrpc_abort_call("CON", call, sp->hdr.seq, + conn->abort_code, conn->error); + break; + default: + BUG(); + } + spin_unlock(&conn->state_lock); + spin_unlock(&rx->incoming_lock); + + rxrpc_send_ping(call, skb); + + /* We have to discard the prealloc queue's ref here and rely on a + * combination of the RCU read lock and refs held either by the socket + * (recvmsg queue, to-be-accepted queue or user ID tree) or the kernel + * service to prevent the call from being deallocated too early. + */ + rxrpc_put_call(call, rxrpc_call_put); + + _leave(" = %p{%d}", call, call->debug_id); + return call; + +no_call: + spin_unlock(&rx->incoming_lock); + _leave(" = NULL [%u]", skb->mark); + return NULL; +} + +/* + * Charge up socket with preallocated calls, attaching user call IDs. + */ +int rxrpc_user_charge_accept(struct rxrpc_sock *rx, unsigned long user_call_ID) +{ + struct rxrpc_backlog *b = rx->backlog; + + if (rx->sk.sk_state == RXRPC_CLOSE) + return -ESHUTDOWN; + + return rxrpc_service_prealloc_one(rx, b, NULL, NULL, user_call_ID, + GFP_KERNEL, + atomic_inc_return(&rxrpc_debug_id)); +} + +/* + * rxrpc_kernel_charge_accept - Charge up socket with preallocated calls + * @sock: The socket on which to preallocate + * @notify_rx: Event notification function for the call + * @user_attach_call: Func to attach call to user_call_ID + * @user_call_ID: The tag to attach to the preallocated call + * @gfp: The allocation conditions. + * @debug_id: The tracing debug ID. + * + * Charge up the socket with preallocated calls, each with a user ID. A + * function should be provided to effect the attachment from the user's side. + * The user is given a ref to hold on the call. + * + * Note that the call may be come connected before this function returns. 
+ */ +int rxrpc_kernel_charge_accept(struct socket *sock, + rxrpc_notify_rx_t notify_rx, + rxrpc_user_attach_call_t user_attach_call, + unsigned long user_call_ID, gfp_t gfp, + unsigned int debug_id) +{ + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + struct rxrpc_backlog *b = rx->backlog; + + if (sock->sk->sk_state == RXRPC_CLOSE) + return -ESHUTDOWN; + + return rxrpc_service_prealloc_one(rx, b, notify_rx, + user_attach_call, user_call_ID, + gfp, debug_id); +} +EXPORT_SYMBOL(rxrpc_kernel_charge_accept); diff --git a/net/rxrpc/call_event.c b/net/rxrpc/call_event.c new file mode 100644 index 000000000..2a93e7b5f --- /dev/null +++ b/net/rxrpc/call_event.c @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Management of Tx window, Tx resend, ACKs and out-of-sequence reception + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +/* + * Propose a PING ACK be sent. + */ +static void rxrpc_propose_ping(struct rxrpc_call *call, + bool immediate, bool background) +{ + if (immediate) { + if (background && + !test_and_set_bit(RXRPC_CALL_EV_PING, &call->events)) + rxrpc_queue_call(call); + } else { + unsigned long now = jiffies; + unsigned long ping_at = now + rxrpc_idle_ack_delay; + + if (time_before(ping_at, call->ping_at)) { + WRITE_ONCE(call->ping_at, ping_at); + rxrpc_reduce_call_timer(call, ping_at, now, + rxrpc_timer_set_for_ping); + } + } +} + +/* + * propose an ACK be sent + */ +static void __rxrpc_propose_ACK(struct rxrpc_call *call, u8 ack_reason, + u32 serial, bool immediate, bool background, + enum rxrpc_propose_ack_trace why) +{ + enum rxrpc_propose_ack_outcome outcome = rxrpc_propose_ack_use; + unsigned long expiry = rxrpc_soft_ack_delay; + s8 prior = rxrpc_ack_priority[ack_reason]; + + /* Pings are handled specially because we don't want to accidentally + * lose a ping response by subsuming it into a ping. + */ + if (ack_reason == RXRPC_ACK_PING) { + rxrpc_propose_ping(call, immediate, background); + goto trace; + } + + /* Update DELAY, IDLE, REQUESTED and PING_RESPONSE ACK serial + * numbers, but we don't alter the timeout. 
+ */ + _debug("prior %u %u vs %u %u", + ack_reason, prior, + call->ackr_reason, rxrpc_ack_priority[call->ackr_reason]); + if (ack_reason == call->ackr_reason) { + if (RXRPC_ACK_UPDATEABLE & (1 << ack_reason)) { + outcome = rxrpc_propose_ack_update; + call->ackr_serial = serial; + } + if (!immediate) + goto trace; + } else if (prior > rxrpc_ack_priority[call->ackr_reason]) { + call->ackr_reason = ack_reason; + call->ackr_serial = serial; + } else { + outcome = rxrpc_propose_ack_subsume; + } + + switch (ack_reason) { + case RXRPC_ACK_REQUESTED: + if (rxrpc_requested_ack_delay < expiry) + expiry = rxrpc_requested_ack_delay; + if (serial == 1) + immediate = false; + break; + + case RXRPC_ACK_DELAY: + if (rxrpc_soft_ack_delay < expiry) + expiry = rxrpc_soft_ack_delay; + break; + + case RXRPC_ACK_IDLE: + if (rxrpc_idle_ack_delay < expiry) + expiry = rxrpc_idle_ack_delay; + break; + + default: + immediate = true; + break; + } + + if (test_bit(RXRPC_CALL_EV_ACK, &call->events)) { + _debug("already scheduled"); + } else if (immediate || expiry == 0) { + _debug("immediate ACK %lx", call->events); + if (!test_and_set_bit(RXRPC_CALL_EV_ACK, &call->events) && + background) + rxrpc_queue_call(call); + } else { + unsigned long now = jiffies, ack_at; + + if (call->peer->srtt_us != 0) + ack_at = usecs_to_jiffies(call->peer->srtt_us >> 3); + else + ack_at = expiry; + + ack_at += READ_ONCE(call->tx_backoff); + ack_at += now; + if (time_before(ack_at, call->ack_at)) { + WRITE_ONCE(call->ack_at, ack_at); + rxrpc_reduce_call_timer(call, ack_at, now, + rxrpc_timer_set_for_ack); + } + } + +trace: + trace_rxrpc_propose_ack(call, why, ack_reason, serial, immediate, + background, outcome); +} + +/* + * propose an ACK be sent, locking the call structure + */ +void rxrpc_propose_ACK(struct rxrpc_call *call, u8 ack_reason, + u32 serial, bool immediate, bool background, + enum rxrpc_propose_ack_trace why) +{ + spin_lock_bh(&call->lock); + __rxrpc_propose_ACK(call, ack_reason, serial, + immediate, background, why); + spin_unlock_bh(&call->lock); +} + +/* + * Handle congestion being detected by the retransmit timeout. + */ +static void rxrpc_congestion_timeout(struct rxrpc_call *call) +{ + set_bit(RXRPC_CALL_RETRANS_TIMEOUT, &call->flags); +} + +/* + * Perform retransmission of NAK'd and unack'd packets. + */ +static void rxrpc_resend(struct rxrpc_call *call, unsigned long now_j) +{ + struct sk_buff *skb; + unsigned long resend_at; + rxrpc_seq_t cursor, seq, top; + ktime_t now, max_age, oldest, ack_ts; + int ix; + u8 annotation, anno_type, retrans = 0, unacked = 0; + + _enter("{%d,%d}", call->tx_hard_ack, call->tx_top); + + now = ktime_get_real(); + max_age = ktime_sub_us(now, jiffies_to_usecs(call->peer->rto_j)); + + spin_lock_bh(&call->lock); + + cursor = call->tx_hard_ack; + top = call->tx_top; + ASSERT(before_eq(cursor, top)); + if (cursor == top) + goto out_unlock; + + /* Scan the packet list without dropping the lock and decide which of + * the packets in the Tx buffer we're going to resend and what the new + * resend timeout will be. 
+ */ + trace_rxrpc_resend(call, (cursor + 1) & RXRPC_RXTX_BUFF_MASK); + oldest = now; + for (seq = cursor + 1; before_eq(seq, top); seq++) { + ix = seq & RXRPC_RXTX_BUFF_MASK; + annotation = call->rxtx_annotations[ix]; + anno_type = annotation & RXRPC_TX_ANNO_MASK; + annotation &= ~RXRPC_TX_ANNO_MASK; + if (anno_type == RXRPC_TX_ANNO_ACK) + continue; + + skb = call->rxtx_buffer[ix]; + rxrpc_see_skb(skb, rxrpc_skb_seen); + + if (anno_type == RXRPC_TX_ANNO_UNACK) { + if (ktime_after(skb->tstamp, max_age)) { + if (ktime_before(skb->tstamp, oldest)) + oldest = skb->tstamp; + continue; + } + if (!(annotation & RXRPC_TX_ANNO_RESENT)) + unacked++; + } + + /* Okay, we need to retransmit a packet. */ + call->rxtx_annotations[ix] = RXRPC_TX_ANNO_RETRANS | annotation; + retrans++; + trace_rxrpc_retransmit(call, seq, annotation | anno_type, + ktime_to_ns(ktime_sub(skb->tstamp, max_age))); + } + + resend_at = nsecs_to_jiffies(ktime_to_ns(ktime_sub(now, oldest))); + resend_at += jiffies + rxrpc_get_rto_backoff(call->peer, retrans); + WRITE_ONCE(call->resend_at, resend_at); + + if (unacked) + rxrpc_congestion_timeout(call); + + /* If there was nothing that needed retransmission then it's likely + * that an ACK got lost somewhere. Send a ping to find out instead of + * retransmitting data. + */ + if (!retrans) { + rxrpc_reduce_call_timer(call, resend_at, now_j, + rxrpc_timer_set_for_resend); + spin_unlock_bh(&call->lock); + ack_ts = ktime_sub(now, call->acks_latest_ts); + if (ktime_to_us(ack_ts) < (call->peer->srtt_us >> 3)) + goto out; + rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, false, + rxrpc_propose_ack_ping_for_lost_ack); + rxrpc_send_ack_packet(call, true, NULL); + goto out; + } + + /* Now go through the Tx window and perform the retransmissions. We + * have to drop the lock for each send. If an ACK comes in whilst the + * lock is dropped, it may clear some of the retransmission markers for + * packets that it soft-ACKs. + */ + for (seq = cursor + 1; before_eq(seq, top); seq++) { + ix = seq & RXRPC_RXTX_BUFF_MASK; + annotation = call->rxtx_annotations[ix]; + anno_type = annotation & RXRPC_TX_ANNO_MASK; + if (anno_type != RXRPC_TX_ANNO_RETRANS) + continue; + + /* We need to reset the retransmission state, but we need to do + * so before we drop the lock as a new ACK/NAK may come in and + * confuse things + */ + annotation &= ~RXRPC_TX_ANNO_MASK; + annotation |= RXRPC_TX_ANNO_UNACK | RXRPC_TX_ANNO_RESENT; + call->rxtx_annotations[ix] = annotation; + + skb = call->rxtx_buffer[ix]; + if (!skb) + continue; + + rxrpc_get_skb(skb, rxrpc_skb_got); + spin_unlock_bh(&call->lock); + + if (rxrpc_send_data_packet(call, skb, true) < 0) { + rxrpc_free_skb(skb, rxrpc_skb_freed); + return; + } + + if (rxrpc_is_client_call(call)) + rxrpc_expose_client_call(call); + + rxrpc_free_skb(skb, rxrpc_skb_freed); + spin_lock_bh(&call->lock); + if (after(call->tx_hard_ack, seq)) + seq = call->tx_hard_ack; + } + +out_unlock: + spin_unlock_bh(&call->lock); +out: + _leave(""); +} + +/* + * Handle retransmission and deferred ACK/abort generation. 
+ */ +void rxrpc_process_call(struct work_struct *work) +{ + struct rxrpc_call *call = + container_of(work, struct rxrpc_call, processor); + rxrpc_serial_t *send_ack; + unsigned long now, next, t; + unsigned int iterations = 0; + + rxrpc_see_call(call); + + //printk("\n--------------------\n"); + _enter("{%d,%s,%lx}", + call->debug_id, rxrpc_call_states[call->state], call->events); + +recheck_state: + /* Limit the number of times we do this before returning to the manager */ + iterations++; + if (iterations > 5) + goto requeue; + + if (test_and_clear_bit(RXRPC_CALL_EV_ABORT, &call->events)) { + rxrpc_send_abort_packet(call); + goto recheck_state; + } + + if (call->state == RXRPC_CALL_COMPLETE) { + rxrpc_delete_call_timer(call); + goto out_put; + } + + /* Work out if any timeouts tripped */ + now = jiffies; + t = READ_ONCE(call->expect_rx_by); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_normal, now); + set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + } + + t = READ_ONCE(call->expect_req_by); + if (call->state == RXRPC_CALL_SERVER_RECV_REQUEST && + time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_idle, now); + set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + } + + t = READ_ONCE(call->expect_term_by); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_hard, now); + set_bit(RXRPC_CALL_EV_EXPIRED, &call->events); + } + + t = READ_ONCE(call->ack_at); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_ack, now); + cmpxchg(&call->ack_at, t, now + MAX_JIFFY_OFFSET); + set_bit(RXRPC_CALL_EV_ACK, &call->events); + } + + t = READ_ONCE(call->ack_lost_at); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_lost_ack, now); + cmpxchg(&call->ack_lost_at, t, now + MAX_JIFFY_OFFSET); + set_bit(RXRPC_CALL_EV_ACK_LOST, &call->events); + } + + t = READ_ONCE(call->keepalive_at); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_keepalive, now); + cmpxchg(&call->keepalive_at, t, now + MAX_JIFFY_OFFSET); + rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, true, + rxrpc_propose_ack_ping_for_keepalive); + set_bit(RXRPC_CALL_EV_PING, &call->events); + } + + t = READ_ONCE(call->ping_at); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_ping, now); + cmpxchg(&call->ping_at, t, now + MAX_JIFFY_OFFSET); + set_bit(RXRPC_CALL_EV_PING, &call->events); + } + + t = READ_ONCE(call->resend_at); + if (time_after_eq(now, t)) { + trace_rxrpc_timer(call, rxrpc_timer_exp_resend, now); + cmpxchg(&call->resend_at, t, now + MAX_JIFFY_OFFSET); + set_bit(RXRPC_CALL_EV_RESEND, &call->events); + } + + /* Process events */ + if (test_and_clear_bit(RXRPC_CALL_EV_EXPIRED, &call->events)) { + if (test_bit(RXRPC_CALL_RX_HEARD, &call->flags) && + (int)call->conn->hi_serial - (int)call->rx_serial > 0) { + trace_rxrpc_call_reset(call); + rxrpc_abort_call("EXP", call, 0, RX_CALL_DEAD, -ECONNRESET); + } else { + rxrpc_abort_call("EXP", call, 0, RX_CALL_TIMEOUT, -ETIME); + } + set_bit(RXRPC_CALL_EV_ABORT, &call->events); + goto recheck_state; + } + + send_ack = NULL; + if (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events)) { + call->acks_lost_top = call->tx_top; + rxrpc_propose_ACK(call, RXRPC_ACK_PING, 0, true, false, + rxrpc_propose_ack_ping_for_lost_ack); + send_ack = &call->acks_lost_ping; + } + + if (test_and_clear_bit(RXRPC_CALL_EV_ACK, &call->events) || + send_ack) { + if (call->ackr_reason) { + rxrpc_send_ack_packet(call, false, send_ack); + goto recheck_state; + } + } + + if 
(test_and_clear_bit(RXRPC_CALL_EV_PING, &call->events)) { + rxrpc_send_ack_packet(call, true, NULL); + goto recheck_state; + } + + if (test_and_clear_bit(RXRPC_CALL_EV_RESEND, &call->events) && + call->state != RXRPC_CALL_CLIENT_RECV_REPLY) { + rxrpc_resend(call, now); + goto recheck_state; + } + + /* Make sure the timer is restarted */ + next = call->expect_rx_by; + +#define set(T) { t = READ_ONCE(T); if (time_before(t, next)) next = t; } + + set(call->expect_req_by); + set(call->expect_term_by); + set(call->ack_at); + set(call->ack_lost_at); + set(call->resend_at); + set(call->keepalive_at); + set(call->ping_at); + + now = jiffies; + if (time_after_eq(now, next)) + goto recheck_state; + + rxrpc_reduce_call_timer(call, next, now, rxrpc_timer_restart); + + /* other events may have been raised since we started checking */ + if (call->events && call->state < RXRPC_CALL_COMPLETE) + goto requeue; + +out_put: + rxrpc_put_call(call, rxrpc_call_put); +out: + _leave(""); + return; + +requeue: + __rxrpc_queue_call(call); + goto out; +} diff --git a/net/rxrpc/call_object.c b/net/rxrpc/call_object.c new file mode 100644 index 000000000..6401cdf7a --- /dev/null +++ b/net/rxrpc/call_object.c @@ -0,0 +1,737 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC individual remote procedure call handling + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +const char *const rxrpc_call_states[NR__RXRPC_CALL_STATES] = { + [RXRPC_CALL_UNINITIALISED] = "Uninit ", + [RXRPC_CALL_CLIENT_AWAIT_CONN] = "ClWtConn", + [RXRPC_CALL_CLIENT_SEND_REQUEST] = "ClSndReq", + [RXRPC_CALL_CLIENT_AWAIT_REPLY] = "ClAwtRpl", + [RXRPC_CALL_CLIENT_RECV_REPLY] = "ClRcvRpl", + [RXRPC_CALL_SERVER_PREALLOC] = "SvPrealc", + [RXRPC_CALL_SERVER_SECURING] = "SvSecure", + [RXRPC_CALL_SERVER_RECV_REQUEST] = "SvRcvReq", + [RXRPC_CALL_SERVER_ACK_REQUEST] = "SvAckReq", + [RXRPC_CALL_SERVER_SEND_REPLY] = "SvSndRpl", + [RXRPC_CALL_SERVER_AWAIT_ACK] = "SvAwtACK", + [RXRPC_CALL_COMPLETE] = "Complete", +}; + +const char *const rxrpc_call_completions[NR__RXRPC_CALL_COMPLETIONS] = { + [RXRPC_CALL_SUCCEEDED] = "Complete", + [RXRPC_CALL_REMOTELY_ABORTED] = "RmtAbort", + [RXRPC_CALL_LOCALLY_ABORTED] = "LocAbort", + [RXRPC_CALL_LOCAL_ERROR] = "LocError", + [RXRPC_CALL_NETWORK_ERROR] = "NetError", +}; + +struct kmem_cache *rxrpc_call_jar; + +static struct semaphore rxrpc_call_limiter = + __SEMAPHORE_INITIALIZER(rxrpc_call_limiter, 1000); +static struct semaphore rxrpc_kernel_call_limiter = + __SEMAPHORE_INITIALIZER(rxrpc_kernel_call_limiter, 1000); + +static void rxrpc_call_timer_expired(struct timer_list *t) +{ + struct rxrpc_call *call = from_timer(call, t, timer); + + _enter("%d", call->debug_id); + + if (call->state < RXRPC_CALL_COMPLETE) { + trace_rxrpc_timer(call, rxrpc_timer_expired, jiffies); + __rxrpc_queue_call(call); + } else { + rxrpc_put_call(call, rxrpc_call_put); + } +} + +void rxrpc_reduce_call_timer(struct rxrpc_call *call, + unsigned long expire_at, + unsigned long now, + enum rxrpc_timer_trace why) +{ + if (rxrpc_try_get_call(call, rxrpc_call_got_timer)) { + trace_rxrpc_timer(call, why, now); + if (timer_reduce(&call->timer, expire_at)) + rxrpc_put_call(call, rxrpc_call_put_notimer); + } +} + +void rxrpc_delete_call_timer(struct rxrpc_call *call) +{ + if (del_timer_sync(&call->timer)) + rxrpc_put_call(call, rxrpc_call_put_timer); +} + 
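[Illustrative aside, not part of the patch.] The timer handling on either side of this point follows one idiom: a call tracks several independent deadlines (ack_at, ack_lost_at, resend_at, ping_at, keepalive_at, expect_*_by) but arms only a single timer, and that timer is only ever pulled earlier (timer_reduce() via rxrpc_reduce_call_timer(), and the set(T) macro in rxrpc_process_call() folding all deadlines into the next expiry), never pushed later. The standalone userspace sketch below shows the shape of that pattern under invented names (toy_call, toy_reduce_timer, deadline_before, toy_restart_timer); the wrap-safe comparison merely mimics the style of the kernel's time_before(). Treat it as a sketch of the idea, not the kernel implementation.

/*
 * Minimal userspace sketch (assumed names, not kernel code): several
 * per-call deadlines are folded down to one timer expiry, and a new
 * request may only move the timer earlier, never later.
 */
#include <stdio.h>

/* Wrap-safe "is a earlier than b?" in the style of time_before(). */
static int deadline_before(unsigned long a, unsigned long b)
{
        return (long)(a - b) < 0;
}

struct toy_call {
        unsigned long ack_at;
        unsigned long resend_at;
        unsigned long ping_at;
        unsigned long timer_expires;    /* the single armed timer */
};

/* Only ever move the timer earlier, mirroring timer_reduce() semantics. */
static void toy_reduce_timer(struct toy_call *call, unsigned long expire_at)
{
        if (deadline_before(expire_at, call->timer_expires))
                call->timer_expires = expire_at;
}

/* Fold all pending deadlines into the timer, like the set(T) macro above. */
static void toy_restart_timer(struct toy_call *call, unsigned long now)
{
        unsigned long next = call->ack_at;

        if (deadline_before(call->resend_at, next))
                next = call->resend_at;
        if (deadline_before(call->ping_at, next))
                next = call->ping_at;

        if (deadline_before(next, now))
                next = now;             /* already due: expire immediately */
        toy_reduce_timer(call, next);
}

int main(void)
{
        struct toy_call call = {
                .ack_at         = 1100,
                .resend_at      = 1050,
                .ping_at        = 1400,
                .timer_expires  = 2000,
        };

        toy_restart_timer(&call, 1000);
        printf("timer now expires at %lu (earliest deadline)\n",
               call.timer_expires);
        return 0;
}

Run against the values in main(), this arms the toy timer at 1050 (the earliest pending deadline), leaving it untouched if a later deadline were proposed; that one-way reduction is what lets the kernel code rearm the timer from several places without ever delaying an already-scheduled event.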
+static struct lock_class_key rxrpc_call_user_mutex_lock_class_key; + +/* + * find an extant server call + * - called in process context with IRQs enabled + */ +struct rxrpc_call *rxrpc_find_call_by_user_ID(struct rxrpc_sock *rx, + unsigned long user_call_ID) +{ + struct rxrpc_call *call; + struct rb_node *p; + + _enter("%p,%lx", rx, user_call_ID); + + read_lock(&rx->call_lock); + + p = rx->calls.rb_node; + while (p) { + call = rb_entry(p, struct rxrpc_call, sock_node); + + if (user_call_ID < call->user_call_ID) + p = p->rb_left; + else if (user_call_ID > call->user_call_ID) + p = p->rb_right; + else + goto found_extant_call; + } + + read_unlock(&rx->call_lock); + _leave(" = NULL"); + return NULL; + +found_extant_call: + rxrpc_get_call(call, rxrpc_call_got); + read_unlock(&rx->call_lock); + _leave(" = %p [%d]", call, refcount_read(&call->ref)); + return call; +} + +/* + * allocate a new call + */ +struct rxrpc_call *rxrpc_alloc_call(struct rxrpc_sock *rx, gfp_t gfp, + unsigned int debug_id) +{ + struct rxrpc_call *call; + struct rxrpc_net *rxnet = rxrpc_net(sock_net(&rx->sk)); + + call = kmem_cache_zalloc(rxrpc_call_jar, gfp); + if (!call) + return NULL; + + call->rxtx_buffer = kcalloc(RXRPC_RXTX_BUFF_SIZE, + sizeof(struct sk_buff *), + gfp); + if (!call->rxtx_buffer) + goto nomem; + + call->rxtx_annotations = kcalloc(RXRPC_RXTX_BUFF_SIZE, sizeof(u8), gfp); + if (!call->rxtx_annotations) + goto nomem_2; + + mutex_init(&call->user_mutex); + + /* Prevent lockdep reporting a deadlock false positive between the afs + * filesystem and sys_sendmsg() via the mmap sem. + */ + if (rx->sk.sk_kern_sock) + lockdep_set_class(&call->user_mutex, + &rxrpc_call_user_mutex_lock_class_key); + + timer_setup(&call->timer, rxrpc_call_timer_expired, 0); + INIT_WORK(&call->processor, &rxrpc_process_call); + INIT_LIST_HEAD(&call->link); + INIT_LIST_HEAD(&call->chan_wait_link); + INIT_LIST_HEAD(&call->accept_link); + INIT_LIST_HEAD(&call->recvmsg_link); + INIT_LIST_HEAD(&call->sock_link); + init_waitqueue_head(&call->waitq); + spin_lock_init(&call->lock); + spin_lock_init(&call->notify_lock); + spin_lock_init(&call->input_lock); + rwlock_init(&call->state_lock); + refcount_set(&call->ref, 1); + call->debug_id = debug_id; + call->tx_total_len = -1; + call->next_rx_timo = 20 * HZ; + call->next_req_timo = 1 * HZ; + + memset(&call->sock_node, 0xed, sizeof(call->sock_node)); + + /* Leave space in the ring to handle a maxed-out jumbo packet */ + call->rx_winsize = rxrpc_rx_window_size; + call->tx_winsize = 16; + call->rx_expect_next = 1; + + call->cong_cwnd = 2; + call->cong_ssthresh = RXRPC_RXTX_BUFF_SIZE - 1; + + call->rxnet = rxnet; + call->rtt_avail = RXRPC_CALL_RTT_AVAIL_MASK; + atomic_inc(&rxnet->nr_calls); + return call; + +nomem_2: + kfree(call->rxtx_buffer); +nomem: + kmem_cache_free(rxrpc_call_jar, call); + return NULL; +} + +/* + * Allocate a new client call. + */ +static struct rxrpc_call *rxrpc_alloc_client_call(struct rxrpc_sock *rx, + struct sockaddr_rxrpc *srx, + gfp_t gfp, + unsigned int debug_id) +{ + struct rxrpc_call *call; + ktime_t now; + + _enter(""); + + call = rxrpc_alloc_call(rx, gfp, debug_id); + if (!call) + return ERR_PTR(-ENOMEM); + call->state = RXRPC_CALL_CLIENT_AWAIT_CONN; + call->service_id = srx->srx_service; + call->tx_phase = true; + now = ktime_get_real(); + call->acks_latest_ts = now; + call->cong_tstamp = now; + + _leave(" = %p", call); + return call; +} + +/* + * Initiate the call ack/resend/expiry timer. 
+ */ +static void rxrpc_start_call_timer(struct rxrpc_call *call) +{ + unsigned long now = jiffies; + unsigned long j = now + MAX_JIFFY_OFFSET; + + call->ack_at = j; + call->ack_lost_at = j; + call->resend_at = j; + call->ping_at = j; + call->expect_rx_by = j; + call->expect_req_by = j; + call->expect_term_by = j; + call->timer.expires = now; +} + +/* + * Wait for a call slot to become available. + */ +static struct semaphore *rxrpc_get_call_slot(struct rxrpc_call_params *p, gfp_t gfp) +{ + struct semaphore *limiter = &rxrpc_call_limiter; + + if (p->kernel) + limiter = &rxrpc_kernel_call_limiter; + if (p->interruptibility == RXRPC_UNINTERRUPTIBLE) { + down(limiter); + return limiter; + } + return down_interruptible(limiter) < 0 ? NULL : limiter; +} + +/* + * Release a call slot. + */ +static void rxrpc_put_call_slot(struct rxrpc_call *call) +{ + struct semaphore *limiter = &rxrpc_call_limiter; + + if (test_bit(RXRPC_CALL_KERNEL, &call->flags)) + limiter = &rxrpc_kernel_call_limiter; + up(limiter); +} + +/* + * Set up a call for the given parameters. + * - Called with the socket lock held, which it must release. + * - If it returns a call, the call's lock will need releasing by the caller. + */ +struct rxrpc_call *rxrpc_new_client_call(struct rxrpc_sock *rx, + struct rxrpc_conn_parameters *cp, + struct sockaddr_rxrpc *srx, + struct rxrpc_call_params *p, + gfp_t gfp, + unsigned int debug_id) + __releases(&rx->sk.sk_lock.slock) + __acquires(&call->user_mutex) +{ + struct rxrpc_call *call, *xcall; + struct rxrpc_net *rxnet; + struct semaphore *limiter; + struct rb_node *parent, **pp; + const void *here = __builtin_return_address(0); + int ret; + + _enter("%p,%lx", rx, p->user_call_ID); + + limiter = rxrpc_get_call_slot(p, gfp); + if (!limiter) { + release_sock(&rx->sk); + return ERR_PTR(-ERESTARTSYS); + } + + call = rxrpc_alloc_client_call(rx, srx, gfp, debug_id); + if (IS_ERR(call)) { + release_sock(&rx->sk); + up(limiter); + _leave(" = %ld", PTR_ERR(call)); + return call; + } + + call->interruptibility = p->interruptibility; + call->tx_total_len = p->tx_total_len; + trace_rxrpc_call(call->debug_id, rxrpc_call_new_client, + refcount_read(&call->ref), + here, (const void *)p->user_call_ID); + if (p->kernel) + __set_bit(RXRPC_CALL_KERNEL, &call->flags); + + /* We need to protect a partially set up call against the user as we + * will be acting outside the socket lock. + */ + mutex_lock(&call->user_mutex); + + /* Publish the call, even though it is incompletely set up as yet */ + write_lock(&rx->call_lock); + + pp = &rx->calls.rb_node; + parent = NULL; + while (*pp) { + parent = *pp; + xcall = rb_entry(parent, struct rxrpc_call, sock_node); + + if (p->user_call_ID < xcall->user_call_ID) + pp = &(*pp)->rb_left; + else if (p->user_call_ID > xcall->user_call_ID) + pp = &(*pp)->rb_right; + else + goto error_dup_user_ID; + } + + rcu_assign_pointer(call->socket, rx); + call->user_call_ID = p->user_call_ID; + __set_bit(RXRPC_CALL_HAS_USERID, &call->flags); + rxrpc_get_call(call, rxrpc_call_got_userid); + rb_link_node(&call->sock_node, parent, pp); + rb_insert_color(&call->sock_node, &rx->calls); + list_add(&call->sock_link, &rx->sock_calls); + + write_unlock(&rx->call_lock); + + rxnet = call->rxnet; + spin_lock_bh(&rxnet->call_lock); + list_add_tail_rcu(&call->link, &rxnet->calls); + spin_unlock_bh(&rxnet->call_lock); + + /* From this point on, the call is protected by its own lock. 
*/ + release_sock(&rx->sk); + + /* Set up or get a connection record and set the protocol parameters, + * including channel number and call ID. + */ + ret = rxrpc_connect_call(rx, call, cp, srx, gfp); + if (ret < 0) + goto error_attached_to_socket; + + trace_rxrpc_call(call->debug_id, rxrpc_call_connected, + refcount_read(&call->ref), here, NULL); + + rxrpc_start_call_timer(call); + + _net("CALL new %d on CONN %d", call->debug_id, call->conn->debug_id); + + _leave(" = %p [new]", call); + return call; + + /* We unexpectedly found the user ID in the list after taking + * the call_lock. This shouldn't happen unless the user races + * with itself and tries to add the same user ID twice at the + * same time in different threads. + */ +error_dup_user_ID: + write_unlock(&rx->call_lock); + release_sock(&rx->sk); + __rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, + RX_CALL_DEAD, -EEXIST); + trace_rxrpc_call(call->debug_id, rxrpc_call_error, + refcount_read(&call->ref), here, ERR_PTR(-EEXIST)); + rxrpc_release_call(rx, call); + mutex_unlock(&call->user_mutex); + rxrpc_put_call(call, rxrpc_call_put); + _leave(" = -EEXIST"); + return ERR_PTR(-EEXIST); + + /* We got an error, but the call is attached to the socket and is in + * need of release. However, we might now race with recvmsg() when + * completing the call queues it. Return 0 from sys_sendmsg() and + * leave the error to recvmsg() to deal with. + */ +error_attached_to_socket: + trace_rxrpc_call(call->debug_id, rxrpc_call_error, + refcount_read(&call->ref), here, ERR_PTR(ret)); + set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); + __rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, + RX_CALL_DEAD, ret); + _leave(" = c=%08x [err]", call->debug_id); + return call; +} + +/* + * Set up an incoming call. call->conn points to the connection. + * This is called in BH context and isn't allowed to fail. + */ +void rxrpc_incoming_call(struct rxrpc_sock *rx, + struct rxrpc_call *call, + struct sk_buff *skb) +{ + struct rxrpc_connection *conn = call->conn; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + u32 chan; + + _enter(",%d", call->conn->debug_id); + + rcu_assign_pointer(call->socket, rx); + call->call_id = sp->hdr.callNumber; + call->service_id = sp->hdr.serviceId; + call->cid = sp->hdr.cid; + call->state = RXRPC_CALL_SERVER_SECURING; + call->cong_tstamp = skb->tstamp; + + /* Set the channel for this call. We don't get channel_lock as we're + * only defending against the data_ready handler (which we're called + * from) and the RESPONSE packet parser (which is only really + * interested in call_counter and can cope with a disagreement with the + * call pointer). + */ + chan = sp->hdr.cid & RXRPC_CHANNELMASK; + conn->channels[chan].call_counter = call->call_id; + conn->channels[chan].call_id = call->call_id; + rcu_assign_pointer(conn->channels[chan].call, call); + + spin_lock(&conn->params.peer->lock); + hlist_add_head_rcu(&call->error_link, &conn->params.peer->error_targets); + spin_unlock(&conn->params.peer->lock); + + _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id); + + rxrpc_start_call_timer(call); + _leave(""); +} + +/* + * Queue a call's work processor, getting a ref to pass to the work queue. 
+ */ +bool rxrpc_queue_call(struct rxrpc_call *call) +{ + const void *here = __builtin_return_address(0); + int n; + + if (!__refcount_inc_not_zero(&call->ref, &n)) + return false; + if (rxrpc_queue_work(&call->processor)) + trace_rxrpc_call(call->debug_id, rxrpc_call_queued, n + 1, + here, NULL); + else + rxrpc_put_call(call, rxrpc_call_put_noqueue); + return true; +} + +/* + * Queue a call's work processor, passing the callers ref to the work queue. + */ +bool __rxrpc_queue_call(struct rxrpc_call *call) +{ + const void *here = __builtin_return_address(0); + int n = refcount_read(&call->ref); + ASSERTCMP(n, >=, 1); + if (rxrpc_queue_work(&call->processor)) + trace_rxrpc_call(call->debug_id, rxrpc_call_queued_ref, n, + here, NULL); + else + rxrpc_put_call(call, rxrpc_call_put_noqueue); + return true; +} + +/* + * Note the re-emergence of a call. + */ +void rxrpc_see_call(struct rxrpc_call *call) +{ + const void *here = __builtin_return_address(0); + if (call) { + int n = refcount_read(&call->ref); + + trace_rxrpc_call(call->debug_id, rxrpc_call_seen, n, + here, NULL); + } +} + +bool rxrpc_try_get_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +{ + const void *here = __builtin_return_address(0); + int n; + + if (!__refcount_inc_not_zero(&call->ref, &n)) + return false; + trace_rxrpc_call(call->debug_id, op, n + 1, here, NULL); + return true; +} + +/* + * Note the addition of a ref on a call. + */ +void rxrpc_get_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +{ + const void *here = __builtin_return_address(0); + int n; + + __refcount_inc(&call->ref, &n); + trace_rxrpc_call(call->debug_id, op, n + 1, here, NULL); +} + +/* + * Clean up the RxTx skb ring. + */ +static void rxrpc_cleanup_ring(struct rxrpc_call *call) +{ + int i; + + for (i = 0; i < RXRPC_RXTX_BUFF_SIZE; i++) { + rxrpc_free_skb(call->rxtx_buffer[i], rxrpc_skb_cleaned); + call->rxtx_buffer[i] = NULL; + } +} + +/* + * Detach a call from its owning socket. 
+ */ +void rxrpc_release_call(struct rxrpc_sock *rx, struct rxrpc_call *call) +{ + const void *here = __builtin_return_address(0); + struct rxrpc_connection *conn = call->conn; + bool put = false; + + _enter("{%d,%d}", call->debug_id, refcount_read(&call->ref)); + + trace_rxrpc_call(call->debug_id, rxrpc_call_release, + refcount_read(&call->ref), + here, (const void *)call->flags); + + ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); + + spin_lock_bh(&call->lock); + if (test_and_set_bit(RXRPC_CALL_RELEASED, &call->flags)) + BUG(); + spin_unlock_bh(&call->lock); + + rxrpc_put_call_slot(call); + rxrpc_delete_call_timer(call); + + /* Make sure we don't get any more notifications */ + write_lock_bh(&rx->recvmsg_lock); + + if (!list_empty(&call->recvmsg_link)) { + _debug("unlinking once-pending call %p { e=%lx f=%lx }", + call, call->events, call->flags); + list_del(&call->recvmsg_link); + put = true; + } + + /* list_empty() must return false in rxrpc_notify_socket() */ + call->recvmsg_link.next = NULL; + call->recvmsg_link.prev = NULL; + + write_unlock_bh(&rx->recvmsg_lock); + if (put) + rxrpc_put_call(call, rxrpc_call_put); + + write_lock(&rx->call_lock); + + if (test_and_clear_bit(RXRPC_CALL_HAS_USERID, &call->flags)) { + rb_erase(&call->sock_node, &rx->calls); + memset(&call->sock_node, 0xdd, sizeof(call->sock_node)); + rxrpc_put_call(call, rxrpc_call_put_userid); + } + + list_del(&call->sock_link); + write_unlock(&rx->call_lock); + + _debug("RELEASE CALL %p (%d CONN %p)", call, call->debug_id, conn); + + if (conn && !test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) + rxrpc_disconnect_call(call); + if (call->security) + call->security->free_call_crypto(call); + _leave(""); +} + +/* + * release all the calls associated with a socket + */ +void rxrpc_release_calls_on_socket(struct rxrpc_sock *rx) +{ + struct rxrpc_call *call; + + _enter("%p", rx); + + while (!list_empty(&rx->to_be_accepted)) { + call = list_entry(rx->to_be_accepted.next, + struct rxrpc_call, accept_link); + list_del(&call->accept_link); + rxrpc_abort_call("SKR", call, 0, RX_CALL_DEAD, -ECONNRESET); + rxrpc_put_call(call, rxrpc_call_put); + } + + while (!list_empty(&rx->sock_calls)) { + call = list_entry(rx->sock_calls.next, + struct rxrpc_call, sock_link); + rxrpc_get_call(call, rxrpc_call_got); + rxrpc_abort_call("SKT", call, 0, RX_CALL_DEAD, -ECONNRESET); + rxrpc_send_abort_packet(call); + rxrpc_release_call(rx, call); + rxrpc_put_call(call, rxrpc_call_put); + } + + _leave(""); +} + +/* + * release a call + */ +void rxrpc_put_call(struct rxrpc_call *call, enum rxrpc_call_trace op) +{ + struct rxrpc_net *rxnet = call->rxnet; + const void *here = __builtin_return_address(0); + unsigned int debug_id = call->debug_id; + bool dead; + int n; + + ASSERT(call != NULL); + + dead = __refcount_dec_and_test(&call->ref, &n); + trace_rxrpc_call(debug_id, op, n, here, NULL); + if (dead) { + _debug("call %d dead", call->debug_id); + ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); + + if (!list_empty(&call->link)) { + spin_lock_bh(&rxnet->call_lock); + list_del_init(&call->link); + spin_unlock_bh(&rxnet->call_lock); + } + + rxrpc_cleanup_call(call); + } +} + +/* + * Final call destruction - but must be done in process context. 
+ */ +static void rxrpc_destroy_call(struct work_struct *work) +{ + struct rxrpc_call *call = container_of(work, struct rxrpc_call, processor); + struct rxrpc_net *rxnet = call->rxnet; + + rxrpc_delete_call_timer(call); + + rxrpc_put_connection(call->conn); + rxrpc_put_peer(call->peer); + kfree(call->rxtx_buffer); + kfree(call->rxtx_annotations); + kmem_cache_free(rxrpc_call_jar, call); + if (atomic_dec_and_test(&rxnet->nr_calls)) + wake_up_var(&rxnet->nr_calls); +} + +/* + * Final call destruction under RCU. + */ +static void rxrpc_rcu_destroy_call(struct rcu_head *rcu) +{ + struct rxrpc_call *call = container_of(rcu, struct rxrpc_call, rcu); + + if (in_softirq()) { + INIT_WORK(&call->processor, rxrpc_destroy_call); + if (!rxrpc_queue_work(&call->processor)) + BUG(); + } else { + rxrpc_destroy_call(&call->processor); + } +} + +/* + * clean up a call + */ +void rxrpc_cleanup_call(struct rxrpc_call *call) +{ + _net("DESTROY CALL %d", call->debug_id); + + memset(&call->sock_node, 0xcd, sizeof(call->sock_node)); + + ASSERTCMP(call->state, ==, RXRPC_CALL_COMPLETE); + ASSERT(test_bit(RXRPC_CALL_RELEASED, &call->flags)); + + rxrpc_cleanup_ring(call); + rxrpc_free_skb(call->tx_pending, rxrpc_skb_cleaned); + + call_rcu(&call->rcu, rxrpc_rcu_destroy_call); +} + +/* + * Make sure that all calls are gone from a network namespace. To reach this + * point, any open UDP sockets in that namespace must have been closed, so any + * outstanding calls cannot be doing I/O. + */ +void rxrpc_destroy_all_calls(struct rxrpc_net *rxnet) +{ + struct rxrpc_call *call; + + _enter(""); + + if (!list_empty(&rxnet->calls)) { + spin_lock_bh(&rxnet->call_lock); + + while (!list_empty(&rxnet->calls)) { + call = list_entry(rxnet->calls.next, + struct rxrpc_call, link); + _debug("Zapping call %p", call); + + rxrpc_see_call(call); + list_del_init(&call->link); + + pr_err("Call %p still in use (%d,%s,%lx,%lx)!\n", + call, refcount_read(&call->ref), + rxrpc_call_states[call->state], + call->flags, call->events); + + spin_unlock_bh(&rxnet->call_lock); + cond_resched(); + spin_lock_bh(&rxnet->call_lock); + } + + spin_unlock_bh(&rxnet->call_lock); + } + + atomic_dec(&rxnet->nr_calls); + wait_var_event(&rxnet->nr_calls, !atomic_read(&rxnet->nr_calls)); +} diff --git a/net/rxrpc/conn_client.c b/net/rxrpc/conn_client.c new file mode 100644 index 000000000..bdb335cb2 --- /dev/null +++ b/net/rxrpc/conn_client.c @@ -0,0 +1,1131 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Client connection-specific management code. + * + * Copyright (C) 2016, 2020 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + * + * Client connections need to be cached for a little while after they've made a + * call so as to handle retransmitted DATA packets in case the server didn't + * receive the final ACK or terminating ABORT we sent it. + * + * There are flags of relevance to the cache: + * + * (2) DONT_REUSE - The connection should be discarded as soon as possible and + * should not be reused. This is set when an exclusive connection is used + * or a call ID counter overflows. + * + * The caching state may only be changed if the cache lock is held. + * + * There are two idle client connection expiry durations. If the total number + * of connections is below the reap threshold, we use the normal duration; if + * it's above, we use the fast duration. 
+ */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include "ar-internal.h" + +__read_mostly unsigned int rxrpc_reap_client_connections = 900; +__read_mostly unsigned long rxrpc_conn_idle_client_expiry = 2 * 60 * HZ; +__read_mostly unsigned long rxrpc_conn_idle_client_fast_expiry = 2 * HZ; + +/* + * We use machine-unique IDs for our client connections. + */ +DEFINE_IDR(rxrpc_client_conn_ids); +static DEFINE_SPINLOCK(rxrpc_conn_id_lock); + +static void rxrpc_deactivate_bundle(struct rxrpc_bundle *bundle); + +/* + * Get a connection ID and epoch for a client connection from the global pool. + * The connection struct pointer is then recorded in the idr radix tree. The + * epoch doesn't change until the client is rebooted (or, at least, unless the + * module is unloaded). + */ +static int rxrpc_get_client_connection_id(struct rxrpc_connection *conn, + gfp_t gfp) +{ + struct rxrpc_net *rxnet = conn->params.local->rxnet; + int id; + + _enter(""); + + idr_preload(gfp); + spin_lock(&rxrpc_conn_id_lock); + + id = idr_alloc_cyclic(&rxrpc_client_conn_ids, conn, + 1, 0x40000000, GFP_NOWAIT); + if (id < 0) + goto error; + + spin_unlock(&rxrpc_conn_id_lock); + idr_preload_end(); + + conn->proto.epoch = rxnet->epoch; + conn->proto.cid = id << RXRPC_CIDSHIFT; + set_bit(RXRPC_CONN_HAS_IDR, &conn->flags); + _leave(" [CID %x]", conn->proto.cid); + return 0; + +error: + spin_unlock(&rxrpc_conn_id_lock); + idr_preload_end(); + _leave(" = %d", id); + return id; +} + +/* + * Release a connection ID for a client connection from the global pool. + */ +static void rxrpc_put_client_connection_id(struct rxrpc_connection *conn) +{ + if (test_bit(RXRPC_CONN_HAS_IDR, &conn->flags)) { + spin_lock(&rxrpc_conn_id_lock); + idr_remove(&rxrpc_client_conn_ids, + conn->proto.cid >> RXRPC_CIDSHIFT); + spin_unlock(&rxrpc_conn_id_lock); + } +} + +/* + * Destroy the client connection ID tree. + */ +void rxrpc_destroy_client_conn_ids(void) +{ + struct rxrpc_connection *conn; + int id; + + if (!idr_is_empty(&rxrpc_client_conn_ids)) { + idr_for_each_entry(&rxrpc_client_conn_ids, conn, id) { + pr_err("AF_RXRPC: Leaked client conn %p {%d}\n", + conn, refcount_read(&conn->ref)); + } + BUG(); + } + + idr_destroy(&rxrpc_client_conn_ids); +} + +/* + * Allocate a connection bundle. + */ +static struct rxrpc_bundle *rxrpc_alloc_bundle(struct rxrpc_conn_parameters *cp, + gfp_t gfp) +{ + struct rxrpc_bundle *bundle; + + bundle = kzalloc(sizeof(*bundle), gfp); + if (bundle) { + bundle->params = *cp; + rxrpc_get_peer(bundle->params.peer); + refcount_set(&bundle->ref, 1); + atomic_set(&bundle->active, 1); + spin_lock_init(&bundle->channel_lock); + INIT_LIST_HEAD(&bundle->waiting_calls); + } + return bundle; +} + +struct rxrpc_bundle *rxrpc_get_bundle(struct rxrpc_bundle *bundle) +{ + refcount_inc(&bundle->ref); + return bundle; +} + +static void rxrpc_free_bundle(struct rxrpc_bundle *bundle) +{ + rxrpc_put_peer(bundle->params.peer); + kfree(bundle); +} + +void rxrpc_put_bundle(struct rxrpc_bundle *bundle) +{ + unsigned int d = bundle->debug_id; + bool dead; + int r; + + dead = __refcount_dec_and_test(&bundle->ref, &r); + + _debug("PUT B=%x %d", d, r - 1); + if (dead) + rxrpc_free_bundle(bundle); +} + +/* + * Allocate a client connection. 
+ */ +static struct rxrpc_connection * +rxrpc_alloc_client_connection(struct rxrpc_bundle *bundle, gfp_t gfp) +{ + struct rxrpc_connection *conn; + struct rxrpc_net *rxnet = bundle->params.local->rxnet; + int ret; + + _enter(""); + + conn = rxrpc_alloc_connection(gfp); + if (!conn) { + _leave(" = -ENOMEM"); + return ERR_PTR(-ENOMEM); + } + + refcount_set(&conn->ref, 1); + conn->bundle = bundle; + conn->params = bundle->params; + conn->out_clientflag = RXRPC_CLIENT_INITIATED; + conn->state = RXRPC_CONN_CLIENT; + conn->service_id = conn->params.service_id; + + ret = rxrpc_get_client_connection_id(conn, gfp); + if (ret < 0) + goto error_0; + + ret = rxrpc_init_client_conn_security(conn); + if (ret < 0) + goto error_1; + + atomic_inc(&rxnet->nr_conns); + write_lock(&rxnet->conn_lock); + list_add_tail(&conn->proc_link, &rxnet->conn_proc_list); + write_unlock(&rxnet->conn_lock); + + rxrpc_get_bundle(bundle); + rxrpc_get_peer(conn->params.peer); + rxrpc_get_local(conn->params.local); + key_get(conn->params.key); + + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_client, + refcount_read(&conn->ref), + __builtin_return_address(0)); + + atomic_inc(&rxnet->nr_client_conns); + trace_rxrpc_client(conn, -1, rxrpc_client_alloc); + _leave(" = %p", conn); + return conn; + +error_1: + rxrpc_put_client_connection_id(conn); +error_0: + kfree(conn); + _leave(" = %d", ret); + return ERR_PTR(ret); +} + +/* + * Determine if a connection may be reused. + */ +static bool rxrpc_may_reuse_conn(struct rxrpc_connection *conn) +{ + struct rxrpc_net *rxnet; + int id_cursor, id, distance, limit; + + if (!conn) + goto dont_reuse; + + rxnet = conn->params.local->rxnet; + if (test_bit(RXRPC_CONN_DONT_REUSE, &conn->flags)) + goto dont_reuse; + + if (conn->state != RXRPC_CONN_CLIENT || + conn->proto.epoch != rxnet->epoch) + goto mark_dont_reuse; + + /* The IDR tree gets very expensive on memory if the connection IDs are + * widely scattered throughout the number space, so we shall want to + * kill off connections that, say, have an ID more than about four + * times the maximum number of client conns away from the current + * allocation point to try and keep the IDs concentrated. + */ + id_cursor = idr_get_cursor(&rxrpc_client_conn_ids); + id = conn->proto.cid >> RXRPC_CIDSHIFT; + distance = id - id_cursor; + if (distance < 0) + distance = -distance; + limit = max_t(unsigned long, atomic_read(&rxnet->nr_conns) * 4, 1024); + if (distance > limit) + goto mark_dont_reuse; + + return true; + +mark_dont_reuse: + set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); +dont_reuse: + return false; +} + +/* + * Look up the conn bundle that matches the connection parameters, adding it if + * it doesn't yet exist. + */ +static struct rxrpc_bundle *rxrpc_look_up_bundle(struct rxrpc_conn_parameters *cp, + gfp_t gfp) +{ + static atomic_t rxrpc_bundle_id; + struct rxrpc_bundle *bundle, *candidate; + struct rxrpc_local *local = cp->local; + struct rb_node *p, **pp, *parent; + long diff; + + _enter("{%px,%x,%u,%u}", + cp->peer, key_serial(cp->key), cp->security_level, cp->upgrade); + + if (cp->exclusive) + return rxrpc_alloc_bundle(cp, gfp); + + /* First, see if the bundle is already there. 
*/ + _debug("search 1"); + spin_lock(&local->client_bundles_lock); + p = local->client_bundles.rb_node; + while (p) { + bundle = rb_entry(p, struct rxrpc_bundle, local_node); + +#define cmp(X) ((long)bundle->params.X - (long)cp->X) + diff = (cmp(peer) ?: + cmp(key) ?: + cmp(security_level) ?: + cmp(upgrade)); +#undef cmp + if (diff < 0) + p = p->rb_left; + else if (diff > 0) + p = p->rb_right; + else + goto found_bundle; + } + spin_unlock(&local->client_bundles_lock); + _debug("not found"); + + /* It wasn't. We need to add one. */ + candidate = rxrpc_alloc_bundle(cp, gfp); + if (!candidate) + return NULL; + + _debug("search 2"); + spin_lock(&local->client_bundles_lock); + pp = &local->client_bundles.rb_node; + parent = NULL; + while (*pp) { + parent = *pp; + bundle = rb_entry(parent, struct rxrpc_bundle, local_node); + +#define cmp(X) ((long)bundle->params.X - (long)cp->X) + diff = (cmp(peer) ?: + cmp(key) ?: + cmp(security_level) ?: + cmp(upgrade)); +#undef cmp + if (diff < 0) + pp = &(*pp)->rb_left; + else if (diff > 0) + pp = &(*pp)->rb_right; + else + goto found_bundle_free; + } + + _debug("new bundle"); + candidate->debug_id = atomic_inc_return(&rxrpc_bundle_id); + rb_link_node(&candidate->local_node, parent, pp); + rb_insert_color(&candidate->local_node, &local->client_bundles); + rxrpc_get_bundle(candidate); + spin_unlock(&local->client_bundles_lock); + _leave(" = %u [new]", candidate->debug_id); + return candidate; + +found_bundle_free: + rxrpc_free_bundle(candidate); +found_bundle: + rxrpc_get_bundle(bundle); + atomic_inc(&bundle->active); + spin_unlock(&local->client_bundles_lock); + _leave(" = %u [found]", bundle->debug_id); + return bundle; +} + +/* + * Create or find a client bundle to use for a call. + * + * If we return with a connection, the call will be on its waiting list. It's + * left to the caller to assign a channel and wake up the call. + */ +static struct rxrpc_bundle *rxrpc_prep_call(struct rxrpc_sock *rx, + struct rxrpc_call *call, + struct rxrpc_conn_parameters *cp, + struct sockaddr_rxrpc *srx, + gfp_t gfp) +{ + struct rxrpc_bundle *bundle; + + _enter("{%d,%lx},", call->debug_id, call->user_call_ID); + + cp->peer = rxrpc_lookup_peer(rx, cp->local, srx, gfp); + if (!cp->peer) + goto error; + + call->cong_cwnd = cp->peer->cong_cwnd; + if (call->cong_cwnd >= call->cong_ssthresh) + call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; + else + call->cong_mode = RXRPC_CALL_SLOW_START; + if (cp->upgrade) + __set_bit(RXRPC_CALL_UPGRADE, &call->flags); + + /* Find the client connection bundle. */ + bundle = rxrpc_look_up_bundle(cp, gfp); + if (!bundle) + goto error; + + /* Get this call queued. Someone else may activate it whilst we're + * lining up a new connection, but that's fine. + */ + spin_lock(&bundle->channel_lock); + list_add_tail(&call->chan_wait_link, &bundle->waiting_calls); + spin_unlock(&bundle->channel_lock); + + _leave(" = [B=%x]", bundle->debug_id); + return bundle; + +error: + _leave(" = -ENOMEM"); + return ERR_PTR(-ENOMEM); +} + +/* + * Allocate a new connection and add it into a bundle. 
+ */ +static void rxrpc_add_conn_to_bundle(struct rxrpc_bundle *bundle, gfp_t gfp) + __releases(bundle->channel_lock) +{ + struct rxrpc_connection *candidate = NULL, *old = NULL; + bool conflict; + int i; + + _enter(""); + + conflict = bundle->alloc_conn; + if (!conflict) + bundle->alloc_conn = true; + spin_unlock(&bundle->channel_lock); + if (conflict) { + _leave(" [conf]"); + return; + } + + candidate = rxrpc_alloc_client_connection(bundle, gfp); + + spin_lock(&bundle->channel_lock); + bundle->alloc_conn = false; + + if (IS_ERR(candidate)) { + bundle->alloc_error = PTR_ERR(candidate); + spin_unlock(&bundle->channel_lock); + _leave(" [err %ld]", PTR_ERR(candidate)); + return; + } + + bundle->alloc_error = 0; + + for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) { + unsigned int shift = i * RXRPC_MAXCALLS; + int j; + + old = bundle->conns[i]; + if (!rxrpc_may_reuse_conn(old)) { + if (old) + trace_rxrpc_client(old, -1, rxrpc_client_replace); + candidate->bundle_shift = shift; + atomic_inc(&bundle->active); + bundle->conns[i] = candidate; + for (j = 0; j < RXRPC_MAXCALLS; j++) + set_bit(shift + j, &bundle->avail_chans); + candidate = NULL; + break; + } + + old = NULL; + } + + spin_unlock(&bundle->channel_lock); + + if (candidate) { + _debug("discard C=%x", candidate->debug_id); + trace_rxrpc_client(candidate, -1, rxrpc_client_duplicate); + rxrpc_put_connection(candidate); + } + + rxrpc_put_connection(old); + _leave(""); +} + +/* + * Add a connection to a bundle if there are no usable connections or we have + * connections waiting for extra capacity. + */ +static void rxrpc_maybe_add_conn(struct rxrpc_bundle *bundle, gfp_t gfp) +{ + struct rxrpc_call *call; + int i, usable; + + _enter(""); + + spin_lock(&bundle->channel_lock); + + /* See if there are any usable connections. */ + usable = 0; + for (i = 0; i < ARRAY_SIZE(bundle->conns); i++) + if (rxrpc_may_reuse_conn(bundle->conns[i])) + usable++; + + if (!usable && !list_empty(&bundle->waiting_calls)) { + call = list_first_entry(&bundle->waiting_calls, + struct rxrpc_call, chan_wait_link); + if (test_bit(RXRPC_CALL_UPGRADE, &call->flags)) + bundle->try_upgrade = true; + } + + if (!usable) + goto alloc_conn; + + if (!bundle->avail_chans && + !bundle->try_upgrade && + !list_empty(&bundle->waiting_calls) && + usable < ARRAY_SIZE(bundle->conns)) + goto alloc_conn; + + spin_unlock(&bundle->channel_lock); + _leave(""); + return; + +alloc_conn: + return rxrpc_add_conn_to_bundle(bundle, gfp); +} + +/* + * Assign a channel to the call at the front of the queue and wake the call up. + * We don't increment the callNumber counter until this number has been exposed + * to the world. + */ +static void rxrpc_activate_one_channel(struct rxrpc_connection *conn, + unsigned int channel) +{ + struct rxrpc_channel *chan = &conn->channels[channel]; + struct rxrpc_bundle *bundle = conn->bundle; + struct rxrpc_call *call = list_entry(bundle->waiting_calls.next, + struct rxrpc_call, chan_wait_link); + u32 call_id = chan->call_counter + 1; + + _enter("C=%x,%u", conn->debug_id, channel); + + trace_rxrpc_client(conn, channel, rxrpc_client_chan_activate); + + /* Cancel the final ACK on the previous call if it hasn't been sent yet + * as the DATA packet will implicitly ACK it. 
+ */ + clear_bit(RXRPC_CONN_FINAL_ACK_0 + channel, &conn->flags); + clear_bit(conn->bundle_shift + channel, &bundle->avail_chans); + + rxrpc_see_call(call); + list_del_init(&call->chan_wait_link); + call->peer = rxrpc_get_peer(conn->params.peer); + call->conn = rxrpc_get_connection(conn); + call->cid = conn->proto.cid | channel; + call->call_id = call_id; + call->security = conn->security; + call->security_ix = conn->security_ix; + call->service_id = conn->service_id; + + trace_rxrpc_connect_call(call); + _net("CONNECT call %08x:%08x as call %d on conn %d", + call->cid, call->call_id, call->debug_id, conn->debug_id); + + write_lock_bh(&call->state_lock); + call->state = RXRPC_CALL_CLIENT_SEND_REQUEST; + write_unlock_bh(&call->state_lock); + + /* Paired with the read barrier in rxrpc_connect_call(). This orders + * cid and epoch in the connection wrt to call_id without the need to + * take the channel_lock. + * + * We provisionally assign a callNumber at this point, but we don't + * confirm it until the call is about to be exposed. + * + * TODO: Pair with a barrier in the data_ready handler when that looks + * at the call ID through a connection channel. + */ + smp_wmb(); + + chan->call_id = call_id; + chan->call_debug_id = call->debug_id; + rcu_assign_pointer(chan->call, call); + wake_up(&call->waitq); +} + +/* + * Remove a connection from the idle list if it's on it. + */ +static void rxrpc_unidle_conn(struct rxrpc_bundle *bundle, struct rxrpc_connection *conn) +{ + struct rxrpc_net *rxnet = bundle->params.local->rxnet; + bool drop_ref; + + if (!list_empty(&conn->cache_link)) { + drop_ref = false; + spin_lock(&rxnet->client_conn_cache_lock); + if (!list_empty(&conn->cache_link)) { + list_del_init(&conn->cache_link); + drop_ref = true; + } + spin_unlock(&rxnet->client_conn_cache_lock); + if (drop_ref) + rxrpc_put_connection(conn); + } +} + +/* + * Assign channels and callNumbers to waiting calls with channel_lock + * held by caller. + */ +static void rxrpc_activate_channels_locked(struct rxrpc_bundle *bundle) +{ + struct rxrpc_connection *conn; + unsigned long avail, mask; + unsigned int channel, slot; + + if (bundle->try_upgrade) + mask = 1; + else + mask = ULONG_MAX; + + while (!list_empty(&bundle->waiting_calls)) { + avail = bundle->avail_chans & mask; + if (!avail) + break; + channel = __ffs(avail); + clear_bit(channel, &bundle->avail_chans); + + slot = channel / RXRPC_MAXCALLS; + conn = bundle->conns[slot]; + if (!conn) + break; + + if (bundle->try_upgrade) + set_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags); + rxrpc_unidle_conn(bundle, conn); + + channel &= (RXRPC_MAXCALLS - 1); + conn->act_chans |= 1 << channel; + rxrpc_activate_one_channel(conn, channel); + } +} + +/* + * Assign channels and callNumbers to waiting calls. + */ +static void rxrpc_activate_channels(struct rxrpc_bundle *bundle) +{ + _enter("B=%x", bundle->debug_id); + + trace_rxrpc_client(NULL, -1, rxrpc_client_activate_chans); + + if (!bundle->avail_chans) + return; + + spin_lock(&bundle->channel_lock); + rxrpc_activate_channels_locked(bundle); + spin_unlock(&bundle->channel_lock); + _leave(""); +} + +/* + * Wait for a callNumber and a channel to be granted to a call. 
+ */ +static int rxrpc_wait_for_channel(struct rxrpc_bundle *bundle, + struct rxrpc_call *call, gfp_t gfp) +{ + DECLARE_WAITQUEUE(myself, current); + int ret = 0; + + _enter("%d", call->debug_id); + + if (!gfpflags_allow_blocking(gfp)) { + rxrpc_maybe_add_conn(bundle, gfp); + rxrpc_activate_channels(bundle); + ret = bundle->alloc_error ?: -EAGAIN; + goto out; + } + + add_wait_queue_exclusive(&call->waitq, &myself); + for (;;) { + rxrpc_maybe_add_conn(bundle, gfp); + rxrpc_activate_channels(bundle); + ret = bundle->alloc_error; + if (ret < 0) + break; + + switch (call->interruptibility) { + case RXRPC_INTERRUPTIBLE: + case RXRPC_PREINTERRUPTIBLE: + set_current_state(TASK_INTERRUPTIBLE); + break; + case RXRPC_UNINTERRUPTIBLE: + default: + set_current_state(TASK_UNINTERRUPTIBLE); + break; + } + if (READ_ONCE(call->state) != RXRPC_CALL_CLIENT_AWAIT_CONN) + break; + if ((call->interruptibility == RXRPC_INTERRUPTIBLE || + call->interruptibility == RXRPC_PREINTERRUPTIBLE) && + signal_pending(current)) { + ret = -ERESTARTSYS; + break; + } + schedule(); + } + remove_wait_queue(&call->waitq, &myself); + __set_current_state(TASK_RUNNING); + +out: + _leave(" = %d", ret); + return ret; +} + +/* + * find a connection for a call + * - called in process context with IRQs enabled + */ +int rxrpc_connect_call(struct rxrpc_sock *rx, + struct rxrpc_call *call, + struct rxrpc_conn_parameters *cp, + struct sockaddr_rxrpc *srx, + gfp_t gfp) +{ + struct rxrpc_bundle *bundle; + struct rxrpc_net *rxnet = cp->local->rxnet; + int ret = 0; + + _enter("{%d,%lx},", call->debug_id, call->user_call_ID); + + rxrpc_discard_expired_client_conns(&rxnet->client_conn_reaper); + + bundle = rxrpc_prep_call(rx, call, cp, srx, gfp); + if (IS_ERR(bundle)) { + ret = PTR_ERR(bundle); + goto out; + } + + if (call->state == RXRPC_CALL_CLIENT_AWAIT_CONN) { + ret = rxrpc_wait_for_channel(bundle, call, gfp); + if (ret < 0) + goto wait_failed; + } + +granted_channel: + /* Paired with the write barrier in rxrpc_activate_one_channel(). */ + smp_rmb(); + +out_put_bundle: + rxrpc_deactivate_bundle(bundle); + rxrpc_put_bundle(bundle); +out: + _leave(" = %d", ret); + return ret; + +wait_failed: + spin_lock(&bundle->channel_lock); + list_del_init(&call->chan_wait_link); + spin_unlock(&bundle->channel_lock); + + if (call->state != RXRPC_CALL_CLIENT_AWAIT_CONN) { + ret = 0; + goto granted_channel; + } + + trace_rxrpc_client(call->conn, ret, rxrpc_client_chan_wait_failed); + rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, 0, ret); + rxrpc_disconnect_client_call(bundle, call); + goto out_put_bundle; +} + +/* + * Note that a call, and thus a connection, is about to be exposed to the + * world. + */ +void rxrpc_expose_client_call(struct rxrpc_call *call) +{ + unsigned int channel = call->cid & RXRPC_CHANNELMASK; + struct rxrpc_connection *conn = call->conn; + struct rxrpc_channel *chan = &conn->channels[channel]; + + if (!test_and_set_bit(RXRPC_CALL_EXPOSED, &call->flags)) { + /* Mark the call ID as being used. If the callNumber counter + * exceeds ~2 billion, we kill the connection after its + * outstanding calls have finished so that the counter doesn't + * wrap. + */ + chan->call_counter++; + if (chan->call_counter >= INT_MAX) + set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); + trace_rxrpc_client(conn, channel, rxrpc_client_exposed); + } +} + +/* + * Set the reap timer. 
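+ *
+ * timer_reduce() only ever brings the expiry time forward, so the reap timer
+ * is set to fire no later than one idle-client expiry period from now.  The
+ * timer isn't touched if the network namespace is no longer live or if all
+ * client conns are being killed off anyway.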
+ */ +static void rxrpc_set_client_reap_timer(struct rxrpc_net *rxnet) +{ + if (!rxnet->kill_all_client_conns) { + unsigned long now = jiffies; + unsigned long reap_at = now + rxrpc_conn_idle_client_expiry; + + if (rxnet->live) + timer_reduce(&rxnet->client_conn_reap_timer, reap_at); + } +} + +/* + * Disconnect a client call. + */ +void rxrpc_disconnect_client_call(struct rxrpc_bundle *bundle, struct rxrpc_call *call) +{ + struct rxrpc_connection *conn; + struct rxrpc_channel *chan = NULL; + struct rxrpc_net *rxnet = bundle->params.local->rxnet; + unsigned int channel; + bool may_reuse; + u32 cid; + + _enter("c=%x", call->debug_id); + + spin_lock(&bundle->channel_lock); + set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); + + /* Calls that have never actually been assigned a channel can simply be + * discarded. + */ + conn = call->conn; + if (!conn) { + _debug("call is waiting"); + ASSERTCMP(call->call_id, ==, 0); + ASSERT(!test_bit(RXRPC_CALL_EXPOSED, &call->flags)); + list_del_init(&call->chan_wait_link); + goto out; + } + + cid = call->cid; + channel = cid & RXRPC_CHANNELMASK; + chan = &conn->channels[channel]; + trace_rxrpc_client(conn, channel, rxrpc_client_chan_disconnect); + + if (rcu_access_pointer(chan->call) != call) { + spin_unlock(&bundle->channel_lock); + BUG(); + } + + may_reuse = rxrpc_may_reuse_conn(conn); + + /* If a client call was exposed to the world, we save the result for + * retransmission. + * + * We use a barrier here so that the call number and abort code can be + * read without needing to take a lock. + * + * TODO: Make the incoming packet handler check this and handle + * terminal retransmission without requiring access to the call. + */ + if (test_bit(RXRPC_CALL_EXPOSED, &call->flags)) { + _debug("exposed %u,%u", call->call_id, call->abort_code); + __rxrpc_disconnect_call(conn, call); + + if (test_and_clear_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags)) { + trace_rxrpc_client(conn, channel, rxrpc_client_to_active); + bundle->try_upgrade = false; + if (may_reuse) + rxrpc_activate_channels_locked(bundle); + } + + } + + /* See if we can pass the channel directly to another call. */ + if (may_reuse && !list_empty(&bundle->waiting_calls)) { + trace_rxrpc_client(conn, channel, rxrpc_client_chan_pass); + rxrpc_activate_one_channel(conn, channel); + goto out; + } + + /* Schedule the final ACK to be transmitted in a short while so that it + * can be skipped if we find a follow-on call. The first DATA packet + * of the follow on call will implicitly ACK this call. + */ + if (call->completion == RXRPC_CALL_SUCCEEDED && + test_bit(RXRPC_CALL_EXPOSED, &call->flags)) { + unsigned long final_ack_at = jiffies + 2; + + WRITE_ONCE(chan->final_ack_at, final_ack_at); + smp_wmb(); /* vs rxrpc_process_delayed_final_acks() */ + set_bit(RXRPC_CONN_FINAL_ACK_0 + channel, &conn->flags); + rxrpc_reduce_conn_timer(conn, final_ack_at); + } + + /* Deactivate the channel. */ + rcu_assign_pointer(chan->call, NULL); + set_bit(conn->bundle_shift + channel, &conn->bundle->avail_chans); + conn->act_chans &= ~(1 << channel); + + /* If no channels remain active, then put the connection on the idle + * list for a short while. Give it a ref to stop it going away if it + * becomes unbundled. 
+ */ + if (!conn->act_chans) { + trace_rxrpc_client(conn, channel, rxrpc_client_to_idle); + conn->idle_timestamp = jiffies; + + rxrpc_get_connection(conn); + spin_lock(&rxnet->client_conn_cache_lock); + list_move_tail(&conn->cache_link, &rxnet->idle_client_conns); + spin_unlock(&rxnet->client_conn_cache_lock); + + rxrpc_set_client_reap_timer(rxnet); + } + +out: + spin_unlock(&bundle->channel_lock); + _leave(""); + return; +} + +/* + * Remove a connection from a bundle. + */ +static void rxrpc_unbundle_conn(struct rxrpc_connection *conn) +{ + struct rxrpc_bundle *bundle = conn->bundle; + unsigned int bindex; + bool need_drop = false; + int i; + + _enter("C=%x", conn->debug_id); + + if (conn->flags & RXRPC_CONN_FINAL_ACK_MASK) + rxrpc_process_delayed_final_acks(conn, true); + + spin_lock(&bundle->channel_lock); + bindex = conn->bundle_shift / RXRPC_MAXCALLS; + if (bundle->conns[bindex] == conn) { + _debug("clear slot %u", bindex); + bundle->conns[bindex] = NULL; + for (i = 0; i < RXRPC_MAXCALLS; i++) + clear_bit(conn->bundle_shift + i, &bundle->avail_chans); + need_drop = true; + } + spin_unlock(&bundle->channel_lock); + + if (need_drop) { + rxrpc_deactivate_bundle(bundle); + rxrpc_put_connection(conn); + } +} + +/* + * Drop the active count on a bundle. + */ +static void rxrpc_deactivate_bundle(struct rxrpc_bundle *bundle) +{ + struct rxrpc_local *local = bundle->params.local; + bool need_put = false; + + if (atomic_dec_and_lock(&bundle->active, &local->client_bundles_lock)) { + if (!bundle->params.exclusive) { + _debug("erase bundle"); + rb_erase(&bundle->local_node, &local->client_bundles); + need_put = true; + } + + spin_unlock(&local->client_bundles_lock); + if (need_put) + rxrpc_put_bundle(bundle); + } +} + +/* + * Clean up a dead client connection. + */ +static void rxrpc_kill_client_conn(struct rxrpc_connection *conn) +{ + struct rxrpc_local *local = conn->params.local; + struct rxrpc_net *rxnet = local->rxnet; + + _enter("C=%x", conn->debug_id); + + trace_rxrpc_client(conn, -1, rxrpc_client_cleanup); + atomic_dec(&rxnet->nr_client_conns); + + rxrpc_put_client_connection_id(conn); + rxrpc_kill_connection(conn); +} + +/* + * Clean up a dead client connections. + */ +void rxrpc_put_client_conn(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id = conn->debug_id; + bool dead; + int r; + + dead = __refcount_dec_and_test(&conn->ref, &r); + trace_rxrpc_conn(debug_id, rxrpc_conn_put_client, r - 1, here); + if (dead) + rxrpc_kill_client_conn(conn); +} + +/* + * Discard expired client connections from the idle list. Each conn in the + * idle list has been exposed and holds an extra ref because of that. + * + * This may be called from conn setup or from a work item so cannot be + * considered non-reentrant. + */ +void rxrpc_discard_expired_client_conns(struct work_struct *work) +{ + struct rxrpc_connection *conn; + struct rxrpc_net *rxnet = + container_of(work, struct rxrpc_net, client_conn_reaper); + unsigned long expiry, conn_expires_at, now; + unsigned int nr_conns; + + _enter(""); + + if (list_empty(&rxnet->idle_client_conns)) { + _leave(" [empty]"); + return; + } + + /* Don't double up on the discarding */ + if (!spin_trylock(&rxnet->client_conn_discard_lock)) { + _leave(" [already]"); + return; + } + + /* We keep an estimate of what the number of conns ought to be after + * we've discarded some so that we don't overdo the discarding. 
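+ * Conns are added to the tail of the idle list when they go idle, so the
+ * list is kept in order of increasing idle time and we can stop scanning
+ * at the first conn that hasn't yet expired.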
+ */ + nr_conns = atomic_read(&rxnet->nr_client_conns); + +next: + spin_lock(&rxnet->client_conn_cache_lock); + + if (list_empty(&rxnet->idle_client_conns)) + goto out; + + conn = list_entry(rxnet->idle_client_conns.next, + struct rxrpc_connection, cache_link); + + if (!rxnet->kill_all_client_conns) { + /* If the number of connections is over the reap limit, we + * expedite discard by reducing the expiry timeout. We must, + * however, have at least a short grace period to be able to do + * final-ACK or ABORT retransmission. + */ + expiry = rxrpc_conn_idle_client_expiry; + if (nr_conns > rxrpc_reap_client_connections) + expiry = rxrpc_conn_idle_client_fast_expiry; + if (conn->params.local->service_closed) + expiry = rxrpc_closed_conn_expiry * HZ; + + conn_expires_at = conn->idle_timestamp + expiry; + + now = READ_ONCE(jiffies); + if (time_after(conn_expires_at, now)) + goto not_yet_expired; + } + + trace_rxrpc_client(conn, -1, rxrpc_client_discard); + list_del_init(&conn->cache_link); + + spin_unlock(&rxnet->client_conn_cache_lock); + + rxrpc_unbundle_conn(conn); + rxrpc_put_connection(conn); /* Drop the ->cache_link ref */ + + nr_conns--; + goto next; + +not_yet_expired: + /* The connection at the front of the queue hasn't yet expired, so + * schedule the work item for that point if we discarded something. + * + * We don't worry if the work item is already scheduled - it can look + * after rescheduling itself at a later time. We could cancel it, but + * then things get messier. + */ + _debug("not yet"); + if (!rxnet->kill_all_client_conns) + timer_reduce(&rxnet->client_conn_reap_timer, conn_expires_at); + +out: + spin_unlock(&rxnet->client_conn_cache_lock); + spin_unlock(&rxnet->client_conn_discard_lock); + _leave(""); +} + +/* + * Preemptively destroy all the client connection records rather than waiting + * for them to time out + */ +void rxrpc_destroy_all_client_connections(struct rxrpc_net *rxnet) +{ + _enter(""); + + spin_lock(&rxnet->client_conn_cache_lock); + rxnet->kill_all_client_conns = true; + spin_unlock(&rxnet->client_conn_cache_lock); + + del_timer_sync(&rxnet->client_conn_reap_timer); + + if (!rxrpc_queue_work(&rxnet->client_conn_reaper)) + _debug("destroy: queue failed"); + + _leave(""); +} + +/* + * Clean up the client connections on a local endpoint. + */ +void rxrpc_clean_up_local_conns(struct rxrpc_local *local) +{ + struct rxrpc_connection *conn, *tmp; + struct rxrpc_net *rxnet = local->rxnet; + LIST_HEAD(graveyard); + + _enter(""); + + spin_lock(&rxnet->client_conn_cache_lock); + + list_for_each_entry_safe(conn, tmp, &rxnet->idle_client_conns, + cache_link) { + if (conn->params.local == local) { + trace_rxrpc_client(conn, -1, rxrpc_client_discard); + list_move(&conn->cache_link, &graveyard); + } + } + + spin_unlock(&rxnet->client_conn_cache_lock); + + while (!list_empty(&graveyard)) { + conn = list_entry(graveyard.next, + struct rxrpc_connection, cache_link); + list_del_init(&conn->cache_link); + rxrpc_unbundle_conn(conn); + rxrpc_put_connection(conn); + } + + _leave(" [culled]"); +} diff --git a/net/rxrpc/conn_event.c b/net/rxrpc/conn_event.c new file mode 100644 index 000000000..aab069701 --- /dev/null +++ b/net/rxrpc/conn_event.c @@ -0,0 +1,499 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* connection-level event handling + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +/* + * Retransmit terminal ACK or ABORT of the previous call. + */ +static void rxrpc_conn_retransmit_call(struct rxrpc_connection *conn, + struct sk_buff *skb, + unsigned int channel) +{ + struct rxrpc_skb_priv *sp = skb ? rxrpc_skb(skb) : NULL; + struct rxrpc_channel *chan; + struct msghdr msg; + struct kvec iov[3]; + struct { + struct rxrpc_wire_header whdr; + union { + __be32 abort_code; + struct rxrpc_ackpacket ack; + }; + } __attribute__((packed)) pkt; + struct rxrpc_ackinfo ack_info; + size_t len; + int ret, ioc; + u32 serial, mtu, call_id, padding; + + _enter("%d", conn->debug_id); + + chan = &conn->channels[channel]; + + /* If the last call got moved on whilst we were waiting to run, just + * ignore this packet. + */ + call_id = READ_ONCE(chan->last_call); + /* Sync with __rxrpc_disconnect_call() */ + smp_rmb(); + if (skb && call_id != sp->hdr.callNumber) + return; + + msg.msg_name = &conn->params.peer->srx.transport; + msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + iov[0].iov_base = &pkt; + iov[0].iov_len = sizeof(pkt.whdr); + iov[1].iov_base = &padding; + iov[1].iov_len = 3; + iov[2].iov_base = &ack_info; + iov[2].iov_len = sizeof(ack_info); + + pkt.whdr.epoch = htonl(conn->proto.epoch); + pkt.whdr.cid = htonl(conn->proto.cid | channel); + pkt.whdr.callNumber = htonl(call_id); + pkt.whdr.seq = 0; + pkt.whdr.type = chan->last_type; + pkt.whdr.flags = conn->out_clientflag; + pkt.whdr.userStatus = 0; + pkt.whdr.securityIndex = conn->security_ix; + pkt.whdr._rsvd = 0; + pkt.whdr.serviceId = htons(conn->service_id); + + len = sizeof(pkt.whdr); + switch (chan->last_type) { + case RXRPC_PACKET_TYPE_ABORT: + pkt.abort_code = htonl(chan->last_abort); + iov[0].iov_len += sizeof(pkt.abort_code); + len += sizeof(pkt.abort_code); + ioc = 1; + break; + + case RXRPC_PACKET_TYPE_ACK: + mtu = conn->params.peer->if_mtu; + mtu -= conn->params.peer->hdrsize; + pkt.ack.bufferSpace = 0; + pkt.ack.maxSkew = htons(skb ? skb->priority : 0); + pkt.ack.firstPacket = htonl(chan->last_seq + 1); + pkt.ack.previousPacket = htonl(chan->last_seq); + pkt.ack.serial = htonl(skb ? sp->hdr.serial : 0); + pkt.ack.reason = skb ? RXRPC_ACK_DUPLICATE : RXRPC_ACK_IDLE; + pkt.ack.nAcks = 0; + ack_info.rxMTU = htonl(rxrpc_rx_mtu); + ack_info.maxMTU = htonl(mtu); + ack_info.rwind = htonl(rxrpc_rx_window_size); + ack_info.jumbo_max = htonl(rxrpc_rx_jumbo_max); + pkt.whdr.flags |= RXRPC_SLOW_START_OK; + padding = 0; + iov[0].iov_len += sizeof(pkt.ack); + len += sizeof(pkt.ack) + 3 + sizeof(ack_info); + ioc = 3; + break; + + default: + return; + } + + /* Resync with __rxrpc_disconnect_call() and check that the last call + * didn't get advanced whilst we were filling out the packets. 
+ */ + smp_rmb(); + if (READ_ONCE(chan->last_call) != call_id) + return; + + serial = atomic_inc_return(&conn->serial); + pkt.whdr.serial = htonl(serial); + + switch (chan->last_type) { + case RXRPC_PACKET_TYPE_ABORT: + _proto("Tx ABORT %%%u { %d } [re]", serial, conn->abort_code); + break; + case RXRPC_PACKET_TYPE_ACK: + trace_rxrpc_tx_ack(chan->call_debug_id, serial, + ntohl(pkt.ack.firstPacket), + ntohl(pkt.ack.serial), + pkt.ack.reason, 0); + _proto("Tx ACK %%%u [re]", serial); + break; + } + + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, ioc, len); + conn->params.peer->last_tx_at = ktime_get_seconds(); + if (ret < 0) + trace_rxrpc_tx_fail(chan->call_debug_id, serial, ret, + rxrpc_tx_point_call_final_resend); + else + trace_rxrpc_tx_packet(chan->call_debug_id, &pkt.whdr, + rxrpc_tx_point_call_final_resend); + + _leave(""); +} + +/* + * pass a connection-level abort onto all calls on that connection + */ +static void rxrpc_abort_calls(struct rxrpc_connection *conn, + enum rxrpc_call_completion compl, + rxrpc_serial_t serial) +{ + struct rxrpc_call *call; + int i; + + _enter("{%d},%x", conn->debug_id, conn->abort_code); + + spin_lock(&conn->bundle->channel_lock); + + for (i = 0; i < RXRPC_MAXCALLS; i++) { + call = rcu_dereference_protected( + conn->channels[i].call, + lockdep_is_held(&conn->bundle->channel_lock)); + if (call) { + if (compl == RXRPC_CALL_LOCALLY_ABORTED) + trace_rxrpc_abort(call->debug_id, + "CON", call->cid, + call->call_id, 0, + conn->abort_code, + conn->error); + else + trace_rxrpc_rx_abort(call, serial, + conn->abort_code); + rxrpc_set_call_completion(call, compl, + conn->abort_code, + conn->error); + } + } + + spin_unlock(&conn->bundle->channel_lock); + _leave(""); +} + +/* + * generate a connection-level abort + */ +static int rxrpc_abort_connection(struct rxrpc_connection *conn, + int error, u32 abort_code) +{ + struct rxrpc_wire_header whdr; + struct msghdr msg; + struct kvec iov[2]; + __be32 word; + size_t len; + u32 serial; + int ret; + + _enter("%d,,%u,%u", conn->debug_id, error, abort_code); + + /* generate a connection-level abort */ + spin_lock_bh(&conn->state_lock); + if (conn->state >= RXRPC_CONN_REMOTELY_ABORTED) { + spin_unlock_bh(&conn->state_lock); + _leave(" = 0 [already dead]"); + return 0; + } + + conn->error = error; + conn->abort_code = abort_code; + conn->state = RXRPC_CONN_LOCALLY_ABORTED; + set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); + spin_unlock_bh(&conn->state_lock); + + msg.msg_name = &conn->params.peer->srx.transport; + msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + whdr.epoch = htonl(conn->proto.epoch); + whdr.cid = htonl(conn->proto.cid); + whdr.callNumber = 0; + whdr.seq = 0; + whdr.type = RXRPC_PACKET_TYPE_ABORT; + whdr.flags = conn->out_clientflag; + whdr.userStatus = 0; + whdr.securityIndex = conn->security_ix; + whdr._rsvd = 0; + whdr.serviceId = htons(conn->service_id); + + word = htonl(conn->abort_code); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = &word; + iov[1].iov_len = sizeof(word); + + len = iov[0].iov_len + iov[1].iov_len; + + serial = atomic_inc_return(&conn->serial); + rxrpc_abort_calls(conn, RXRPC_CALL_LOCALLY_ABORTED, serial); + whdr.serial = htonl(serial); + _proto("Tx CONN ABORT %%%u { %d }", serial, conn->abort_code); + + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); + if (ret < 0) { + trace_rxrpc_tx_fail(conn->debug_id, serial, ret, + 
rxrpc_tx_point_conn_abort); + _debug("sendmsg failed: %d", ret); + return -EAGAIN; + } + + trace_rxrpc_tx_packet(conn->debug_id, &whdr, rxrpc_tx_point_conn_abort); + + conn->params.peer->last_tx_at = ktime_get_seconds(); + + _leave(" = 0"); + return 0; +} + +/* + * mark a call as being on a now-secured channel + * - must be called with BH's disabled. + */ +static void rxrpc_call_is_secure(struct rxrpc_call *call) +{ + _enter("%p", call); + if (call) { + write_lock_bh(&call->state_lock); + if (call->state == RXRPC_CALL_SERVER_SECURING) { + call->state = RXRPC_CALL_SERVER_RECV_REQUEST; + rxrpc_notify_socket(call); + } + write_unlock_bh(&call->state_lock); + } +} + +/* + * connection-level Rx packet processor + */ +static int rxrpc_process_event(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 *_abort_code) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + __be32 wtmp; + u32 abort_code; + int loop, ret; + + if (conn->state >= RXRPC_CONN_REMOTELY_ABORTED) { + _leave(" = -ECONNABORTED [%u]", conn->state); + return -ECONNABORTED; + } + + _enter("{%d},{%u,%%%u},", conn->debug_id, sp->hdr.type, sp->hdr.serial); + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_DATA: + case RXRPC_PACKET_TYPE_ACK: + rxrpc_conn_retransmit_call(conn, skb, + sp->hdr.cid & RXRPC_CHANNELMASK); + return 0; + + case RXRPC_PACKET_TYPE_BUSY: + /* Just ignore BUSY packets for now. */ + return 0; + + case RXRPC_PACKET_TYPE_ABORT: + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + &wtmp, sizeof(wtmp)) < 0) { + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, + tracepoint_string("bad_abort")); + return -EPROTO; + } + abort_code = ntohl(wtmp); + _proto("Rx ABORT %%%u { ac=%d }", sp->hdr.serial, abort_code); + + conn->error = -ECONNABORTED; + conn->abort_code = abort_code; + conn->state = RXRPC_CONN_REMOTELY_ABORTED; + set_bit(RXRPC_CONN_DONT_REUSE, &conn->flags); + rxrpc_abort_calls(conn, RXRPC_CALL_REMOTELY_ABORTED, sp->hdr.serial); + return -ECONNABORTED; + + case RXRPC_PACKET_TYPE_CHALLENGE: + return conn->security->respond_to_challenge(conn, skb, + _abort_code); + + case RXRPC_PACKET_TYPE_RESPONSE: + ret = conn->security->verify_response(conn, skb, _abort_code); + if (ret < 0) + return ret; + + ret = conn->security->init_connection_security( + conn, conn->params.key->payload.data[0]); + if (ret < 0) + return ret; + + spin_lock(&conn->bundle->channel_lock); + spin_lock_bh(&conn->state_lock); + + if (conn->state == RXRPC_CONN_SERVICE_CHALLENGING) { + conn->state = RXRPC_CONN_SERVICE; + spin_unlock_bh(&conn->state_lock); + for (loop = 0; loop < RXRPC_MAXCALLS; loop++) + rxrpc_call_is_secure( + rcu_dereference_protected( + conn->channels[loop].call, + lockdep_is_held(&conn->bundle->channel_lock))); + } else { + spin_unlock_bh(&conn->state_lock); + } + + spin_unlock(&conn->bundle->channel_lock); + return 0; + + default: + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, + tracepoint_string("bad_conn_pkt")); + return -EPROTO; + } +} + +/* + * set up security and issue a challenge + */ +static void rxrpc_secure_connection(struct rxrpc_connection *conn) +{ + u32 abort_code; + int ret; + + _enter("{%d}", conn->debug_id); + + ASSERT(conn->security_ix != 0); + + if (conn->security->issue_challenge(conn) < 0) { + abort_code = RX_CALL_DEAD; + ret = -ENOMEM; + goto abort; + } + + _leave(""); + return; + +abort: + _debug("abort %d, %d", ret, abort_code); + rxrpc_abort_connection(conn, ret, abort_code); + _leave(" [aborted]"); +} + +/* + * Process delayed final ACKs that we haven't subsumed into a subsequent call. 
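+ *
+ * For each channel that has a final ACK outstanding, we either transmit it
+ * now, if its timeout has passed or @force is set, or we note the earliest
+ * outstanding expiry time so that the connection timer can be rearmed to
+ * cover it.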
+ */ +void rxrpc_process_delayed_final_acks(struct rxrpc_connection *conn, bool force) +{ + unsigned long j = jiffies, next_j; + unsigned int channel; + bool set; + +again: + next_j = j + LONG_MAX; + set = false; + for (channel = 0; channel < RXRPC_MAXCALLS; channel++) { + struct rxrpc_channel *chan = &conn->channels[channel]; + unsigned long ack_at; + + if (!test_bit(RXRPC_CONN_FINAL_ACK_0 + channel, &conn->flags)) + continue; + + smp_rmb(); /* vs rxrpc_disconnect_client_call */ + ack_at = READ_ONCE(chan->final_ack_at); + + if (time_before(j, ack_at) && !force) { + if (time_before(ack_at, next_j)) { + next_j = ack_at; + set = true; + } + continue; + } + + if (test_and_clear_bit(RXRPC_CONN_FINAL_ACK_0 + channel, + &conn->flags)) + rxrpc_conn_retransmit_call(conn, NULL, channel); + } + + j = jiffies; + if (time_before_eq(next_j, j)) + goto again; + if (set) + rxrpc_reduce_conn_timer(conn, next_j); +} + +/* + * connection-level event processor + */ +static void rxrpc_do_process_connection(struct rxrpc_connection *conn) +{ + struct sk_buff *skb; + u32 abort_code = RX_PROTOCOL_ERROR; + int ret; + + if (test_and_clear_bit(RXRPC_CONN_EV_CHALLENGE, &conn->events)) + rxrpc_secure_connection(conn); + + /* Process delayed ACKs whose time has come. */ + if (conn->flags & RXRPC_CONN_FINAL_ACK_MASK) + rxrpc_process_delayed_final_acks(conn, false); + + /* go through the conn-level event packets, releasing the ref on this + * connection that each one has when we've finished with it */ + while ((skb = skb_dequeue(&conn->rx_queue))) { + rxrpc_see_skb(skb, rxrpc_skb_seen); + ret = rxrpc_process_event(conn, skb, &abort_code); + switch (ret) { + case -EPROTO: + case -EKEYEXPIRED: + case -EKEYREJECTED: + goto protocol_error; + case -ENOMEM: + case -EAGAIN: + goto requeue_and_leave; + case -ECONNABORTED: + default: + rxrpc_free_skb(skb, rxrpc_skb_freed); + break; + } + } + + return; + +requeue_and_leave: + skb_queue_head(&conn->rx_queue, skb); + return; + +protocol_error: + if (rxrpc_abort_connection(conn, ret, abort_code) < 0) + goto requeue_and_leave; + rxrpc_free_skb(skb, rxrpc_skb_freed); + return; +} + +void rxrpc_process_connection(struct work_struct *work) +{ + struct rxrpc_connection *conn = + container_of(work, struct rxrpc_connection, processor); + + rxrpc_see_connection(conn); + + if (__rxrpc_use_local(conn->params.local)) { + rxrpc_do_process_connection(conn); + rxrpc_unuse_local(conn->params.local); + } + + rxrpc_put_connection(conn); + _leave(""); + return; +} diff --git a/net/rxrpc/conn_object.c b/net/rxrpc/conn_object.c new file mode 100644 index 000000000..22089e37e --- /dev/null +++ b/net/rxrpc/conn_object.c @@ -0,0 +1,487 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC virtual connection handler, common bits. + * + * Copyright (C) 2007, 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include "ar-internal.h" + +/* + * Time till a connection expires after last use (in seconds). 
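+ *
+ * The defaults are ten minutes of idleness for an ordinary connection and
+ * ten seconds once the local service has been closed.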
+ */ +unsigned int __read_mostly rxrpc_connection_expiry = 10 * 60; +unsigned int __read_mostly rxrpc_closed_conn_expiry = 10; + +static void rxrpc_destroy_connection(struct rcu_head *); + +static void rxrpc_connection_timer(struct timer_list *timer) +{ + struct rxrpc_connection *conn = + container_of(timer, struct rxrpc_connection, timer); + + rxrpc_queue_conn(conn); +} + +/* + * allocate a new connection + */ +struct rxrpc_connection *rxrpc_alloc_connection(gfp_t gfp) +{ + struct rxrpc_connection *conn; + + _enter(""); + + conn = kzalloc(sizeof(struct rxrpc_connection), gfp); + if (conn) { + INIT_LIST_HEAD(&conn->cache_link); + timer_setup(&conn->timer, &rxrpc_connection_timer, 0); + INIT_WORK(&conn->processor, &rxrpc_process_connection); + INIT_LIST_HEAD(&conn->proc_link); + INIT_LIST_HEAD(&conn->link); + skb_queue_head_init(&conn->rx_queue); + conn->security = &rxrpc_no_security; + spin_lock_init(&conn->state_lock); + conn->debug_id = atomic_inc_return(&rxrpc_debug_id); + conn->idle_timestamp = jiffies; + } + + _leave(" = %p{%d}", conn, conn ? conn->debug_id : 0); + return conn; +} + +/* + * Look up a connection in the cache by protocol parameters. + * + * If successful, a pointer to the connection is returned, but no ref is taken. + * NULL is returned if there is no match. + * + * When searching for a service call, if we find a peer but no connection, we + * return that through *_peer in case we need to create a new service call. + * + * The caller must be holding the RCU read lock. + */ +struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, + struct sk_buff *skb, + struct rxrpc_peer **_peer) +{ + struct rxrpc_connection *conn; + struct rxrpc_conn_proto k; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct sockaddr_rxrpc srx; + struct rxrpc_peer *peer; + + _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK); + + if (rxrpc_extract_addr_from_skb(&srx, skb) < 0) + goto not_found; + + if (srx.transport.family != local->srx.transport.family && + (srx.transport.family == AF_INET && + local->srx.transport.family != AF_INET6)) { + pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n", + srx.transport.family, + local->srx.transport.family); + goto not_found; + } + + k.epoch = sp->hdr.epoch; + k.cid = sp->hdr.cid & RXRPC_CIDMASK; + + if (rxrpc_to_server(sp)) { + /* We need to look up service connections by the full protocol + * parameter set. We look up the peer first as an intermediate + * step and then the connection from the peer's tree. + */ + peer = rxrpc_lookup_peer_rcu(local, &srx); + if (!peer) + goto not_found; + *_peer = peer; + conn = rxrpc_find_service_conn_rcu(peer, skb); + if (!conn || refcount_read(&conn->ref) == 0) + goto not_found; + _leave(" = %p", conn); + return conn; + } else { + /* Look up client connections by connection ID alone as their + * IDs are unique for this machine. 
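+ * The channel number lives in the bottom bits of the CID, so the CID is
+ * shifted down to get the connection ID for the lookup and the epoch,
+ * local endpoint and peer address are then cross-checked against the
+ * packet (e.g. with four channels per conn, CID 0x1234567b is channel 3
+ * of client conn ID 0x48d159e).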
+ */ + conn = idr_find(&rxrpc_client_conn_ids, + sp->hdr.cid >> RXRPC_CIDSHIFT); + if (!conn || refcount_read(&conn->ref) == 0) { + _debug("no conn"); + goto not_found; + } + + if (conn->proto.epoch != k.epoch || + conn->params.local != local) + goto not_found; + + peer = conn->params.peer; + switch (srx.transport.family) { + case AF_INET: + if (peer->srx.transport.sin.sin_port != + srx.transport.sin.sin_port || + peer->srx.transport.sin.sin_addr.s_addr != + srx.transport.sin.sin_addr.s_addr) + goto not_found; + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + if (peer->srx.transport.sin6.sin6_port != + srx.transport.sin6.sin6_port || + memcmp(&peer->srx.transport.sin6.sin6_addr, + &srx.transport.sin6.sin6_addr, + sizeof(struct in6_addr)) != 0) + goto not_found; + break; +#endif + default: + BUG(); + } + + _leave(" = %p", conn); + return conn; + } + +not_found: + _leave(" = NULL"); + return NULL; +} + +/* + * Disconnect a call and clear any channel it occupies when that call + * terminates. The caller must hold the channel_lock and must release the + * call's ref on the connection. + */ +void __rxrpc_disconnect_call(struct rxrpc_connection *conn, + struct rxrpc_call *call) +{ + struct rxrpc_channel *chan = + &conn->channels[call->cid & RXRPC_CHANNELMASK]; + + _enter("%d,%x", conn->debug_id, call->cid); + + if (rcu_access_pointer(chan->call) == call) { + /* Save the result of the call so that we can repeat it if necessary + * through the channel, whilst disposing of the actual call record. + */ + trace_rxrpc_disconnect_call(call); + switch (call->completion) { + case RXRPC_CALL_SUCCEEDED: + chan->last_seq = call->rx_hard_ack; + chan->last_type = RXRPC_PACKET_TYPE_ACK; + break; + case RXRPC_CALL_LOCALLY_ABORTED: + chan->last_abort = call->abort_code; + chan->last_type = RXRPC_PACKET_TYPE_ABORT; + break; + default: + chan->last_abort = RX_CALL_DEAD; + chan->last_type = RXRPC_PACKET_TYPE_ABORT; + break; + } + + /* Sync with rxrpc_conn_retransmit(). */ + smp_wmb(); + chan->last_call = chan->call_id; + chan->call_id = chan->call_counter; + + rcu_assign_pointer(chan->call, NULL); + } + + _leave(""); +} + +/* + * Disconnect a call and clear any channel it occupies when that call + * terminates. + */ +void rxrpc_disconnect_call(struct rxrpc_call *call) +{ + struct rxrpc_connection *conn = call->conn; + + call->peer->cong_cwnd = call->cong_cwnd; + + if (!hlist_unhashed(&call->error_link)) { + spin_lock_bh(&call->peer->lock); + hlist_del_rcu(&call->error_link); + spin_unlock_bh(&call->peer->lock); + } + + if (rxrpc_is_client_call(call)) + return rxrpc_disconnect_client_call(conn->bundle, call); + + spin_lock(&conn->bundle->channel_lock); + __rxrpc_disconnect_call(conn, call); + spin_unlock(&conn->bundle->channel_lock); + + set_bit(RXRPC_CALL_DISCONNECTED, &call->flags); + conn->idle_timestamp = jiffies; +} + +/* + * Kill off a connection. + */ +void rxrpc_kill_connection(struct rxrpc_connection *conn) +{ + struct rxrpc_net *rxnet = conn->params.local->rxnet; + + ASSERT(!rcu_access_pointer(conn->channels[0].call) && + !rcu_access_pointer(conn->channels[1].call) && + !rcu_access_pointer(conn->channels[2].call) && + !rcu_access_pointer(conn->channels[3].call)); + ASSERT(list_empty(&conn->cache_link)); + + write_lock(&rxnet->conn_lock); + list_del_init(&conn->proc_link); + write_unlock(&rxnet->conn_lock); + + /* Drain the Rx queue. 
Note that even though we've unpublished, an + * incoming packet could still be being added to our Rx queue, so we + * will need to drain it again in the RCU cleanup handler. + */ + rxrpc_purge_queue(&conn->rx_queue); + + /* Leave final destruction to RCU. The connection processor work item + * must carry a ref on the connection to prevent us getting here whilst + * it is queued or running. + */ + call_rcu(&conn->rcu, rxrpc_destroy_connection); +} + +/* + * Queue a connection's work processor, getting a ref to pass to the work + * queue. + */ +bool rxrpc_queue_conn(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + int r; + + if (!__refcount_inc_not_zero(&conn->ref, &r)) + return false; + if (rxrpc_queue_work(&conn->processor)) + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_queued, r + 1, here); + else + rxrpc_put_connection(conn); + return true; +} + +/* + * Note the re-emergence of a connection. + */ +void rxrpc_see_connection(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + if (conn) { + int n = refcount_read(&conn->ref); + + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_seen, n, here); + } +} + +/* + * Get a ref on a connection. + */ +struct rxrpc_connection *rxrpc_get_connection(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + int r; + + __refcount_inc(&conn->ref, &r); + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_got, r, here); + return conn; +} + +/* + * Try to get a ref on a connection. + */ +struct rxrpc_connection * +rxrpc_get_connection_maybe(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + int r; + + if (conn) { + if (__refcount_inc_not_zero(&conn->ref, &r)) + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_got, r + 1, here); + else + conn = NULL; + } + return conn; +} + +/* + * Set the service connection reap timer. 
+ */ +static void rxrpc_set_service_reap_timer(struct rxrpc_net *rxnet, + unsigned long reap_at) +{ + if (rxnet->live) + timer_reduce(&rxnet->service_conn_reap_timer, reap_at); +} + +/* + * Release a service connection + */ +void rxrpc_put_service_conn(struct rxrpc_connection *conn) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id = conn->debug_id; + int r; + + __refcount_dec(&conn->ref, &r); + trace_rxrpc_conn(debug_id, rxrpc_conn_put_service, r - 1, here); + if (r - 1 == 1) + rxrpc_set_service_reap_timer(conn->params.local->rxnet, + jiffies + rxrpc_connection_expiry); +} + +/* + * destroy a virtual connection + */ +static void rxrpc_destroy_connection(struct rcu_head *rcu) +{ + struct rxrpc_connection *conn = + container_of(rcu, struct rxrpc_connection, rcu); + + _enter("{%d,u=%d}", conn->debug_id, refcount_read(&conn->ref)); + + ASSERTCMP(refcount_read(&conn->ref), ==, 0); + + _net("DESTROY CONN %d", conn->debug_id); + + del_timer_sync(&conn->timer); + rxrpc_purge_queue(&conn->rx_queue); + + conn->security->clear(conn); + key_put(conn->params.key); + rxrpc_put_bundle(conn->bundle); + rxrpc_put_peer(conn->params.peer); + + if (atomic_dec_and_test(&conn->params.local->rxnet->nr_conns)) + wake_up_var(&conn->params.local->rxnet->nr_conns); + rxrpc_put_local(conn->params.local); + + kfree(conn); + _leave(""); +} + +/* + * reap dead service connections + */ +void rxrpc_service_connection_reaper(struct work_struct *work) +{ + struct rxrpc_connection *conn, *_p; + struct rxrpc_net *rxnet = + container_of(work, struct rxrpc_net, service_conn_reaper); + unsigned long expire_at, earliest, idle_timestamp, now; + + LIST_HEAD(graveyard); + + _enter(""); + + now = jiffies; + earliest = now + MAX_JIFFY_OFFSET; + + write_lock(&rxnet->conn_lock); + list_for_each_entry_safe(conn, _p, &rxnet->service_conns, link) { + ASSERTCMP(refcount_read(&conn->ref), >, 0); + if (likely(refcount_read(&conn->ref) > 1)) + continue; + if (conn->state == RXRPC_CONN_SERVICE_PREALLOC) + continue; + + if (rxnet->live && !conn->params.local->dead) { + idle_timestamp = READ_ONCE(conn->idle_timestamp); + expire_at = idle_timestamp + rxrpc_connection_expiry * HZ; + if (conn->params.local->service_closed) + expire_at = idle_timestamp + rxrpc_closed_conn_expiry * HZ; + + _debug("reap CONN %d { u=%d,t=%ld }", + conn->debug_id, refcount_read(&conn->ref), + (long)expire_at - (long)now); + + if (time_before(now, expire_at)) { + if (time_before(expire_at, earliest)) + earliest = expire_at; + continue; + } + } + + /* The usage count sits at 1 whilst the object is unused on the + * list; we reduce that to 0 to make the object unavailable. 
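+ * refcount_dec_if_one() only drops the count if it is exactly 1, so a
+ * conn that gains a new user whilst we're looking at it is simply
+ * skipped.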
+ */ + if (!refcount_dec_if_one(&conn->ref)) + continue; + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_reap_service, 0, NULL); + + if (rxrpc_conn_is_client(conn)) + BUG(); + else + rxrpc_unpublish_service_conn(conn); + + list_move_tail(&conn->link, &graveyard); + } + write_unlock(&rxnet->conn_lock); + + if (earliest != now + MAX_JIFFY_OFFSET) { + _debug("reschedule reaper %ld", (long)earliest - (long)now); + ASSERT(time_after(earliest, now)); + rxrpc_set_service_reap_timer(rxnet, earliest); + } + + while (!list_empty(&graveyard)) { + conn = list_entry(graveyard.next, struct rxrpc_connection, + link); + list_del_init(&conn->link); + + ASSERTCMP(refcount_read(&conn->ref), ==, 0); + rxrpc_kill_connection(conn); + } + + _leave(""); +} + +/* + * preemptively destroy all the service connection records rather than + * waiting for them to time out + */ +void rxrpc_destroy_all_connections(struct rxrpc_net *rxnet) +{ + struct rxrpc_connection *conn, *_p; + bool leak = false; + + _enter(""); + + atomic_dec(&rxnet->nr_conns); + rxrpc_destroy_all_client_connections(rxnet); + + del_timer_sync(&rxnet->service_conn_reap_timer); + rxrpc_queue_work(&rxnet->service_conn_reaper); + flush_workqueue(rxrpc_workqueue); + + write_lock(&rxnet->conn_lock); + list_for_each_entry_safe(conn, _p, &rxnet->service_conns, link) { + pr_err("AF_RXRPC: Leaked conn %p {%d}\n", + conn, refcount_read(&conn->ref)); + leak = true; + } + write_unlock(&rxnet->conn_lock); + BUG_ON(leak); + + ASSERT(list_empty(&rxnet->conn_proc_list)); + + /* We need to wait for the connections to be destroyed by RCU as they + * pin things that we still need to get rid of. + */ + wait_var_event(&rxnet->nr_conns, !atomic_read(&rxnet->nr_conns)); + _leave(""); +} diff --git a/net/rxrpc/conn_service.c b/net/rxrpc/conn_service.c new file mode 100644 index 000000000..6e6aa02c6 --- /dev/null +++ b/net/rxrpc/conn_service.c @@ -0,0 +1,203 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Service connection management + * + * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include "ar-internal.h" + +static struct rxrpc_bundle rxrpc_service_dummy_bundle = { + .ref = REFCOUNT_INIT(1), + .debug_id = UINT_MAX, + .channel_lock = __SPIN_LOCK_UNLOCKED(&rxrpc_service_dummy_bundle.channel_lock), +}; + +/* + * Find a service connection under RCU conditions. + * + * We could use a hash table, but that is subject to bucket stuffing by an + * attacker as the client gets to pick the epoch and cid values and would know + * the hash function. So, instead, we use a hash table for the peer and from + * that an rbtree to find the service connection. Under ordinary circumstances + * it might be slower than a large hash table, but it is at least limited in + * depth. + */ +struct rxrpc_connection *rxrpc_find_service_conn_rcu(struct rxrpc_peer *peer, + struct sk_buff *skb) +{ + struct rxrpc_connection *conn = NULL; + struct rxrpc_conn_proto k; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rb_node *p; + unsigned int seq = 0; + + k.epoch = sp->hdr.epoch; + k.cid = sp->hdr.cid & RXRPC_CIDMASK; + + do { + /* Unfortunately, rbtree walking doesn't give reliable results + * under just the RCU read lock, so we have to check for + * changes. 
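+ * read_seqbegin_or_lock() lets the first pass run locklessly; if
+ * need_seqretry() sees that the tree changed under us, the walk is
+ * retried, this time under the lock.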
+ */ + read_seqbegin_or_lock(&peer->service_conn_lock, &seq); + + p = rcu_dereference_raw(peer->service_conns.rb_node); + while (p) { + conn = rb_entry(p, struct rxrpc_connection, service_node); + + if (conn->proto.index_key < k.index_key) + p = rcu_dereference_raw(p->rb_left); + else if (conn->proto.index_key > k.index_key) + p = rcu_dereference_raw(p->rb_right); + else + break; + conn = NULL; + } + } while (need_seqretry(&peer->service_conn_lock, seq)); + + done_seqretry(&peer->service_conn_lock, seq); + _leave(" = %d", conn ? conn->debug_id : -1); + return conn; +} + +/* + * Insert a service connection into a peer's tree, thereby making it a target + * for incoming packets. + */ +static void rxrpc_publish_service_conn(struct rxrpc_peer *peer, + struct rxrpc_connection *conn) +{ + struct rxrpc_connection *cursor = NULL; + struct rxrpc_conn_proto k = conn->proto; + struct rb_node **pp, *parent; + + write_seqlock_bh(&peer->service_conn_lock); + + pp = &peer->service_conns.rb_node; + parent = NULL; + while (*pp) { + parent = *pp; + cursor = rb_entry(parent, + struct rxrpc_connection, service_node); + + if (cursor->proto.index_key < k.index_key) + pp = &(*pp)->rb_left; + else if (cursor->proto.index_key > k.index_key) + pp = &(*pp)->rb_right; + else + goto found_extant_conn; + } + + rb_link_node_rcu(&conn->service_node, parent, pp); + rb_insert_color(&conn->service_node, &peer->service_conns); +conn_published: + set_bit(RXRPC_CONN_IN_SERVICE_CONNS, &conn->flags); + write_sequnlock_bh(&peer->service_conn_lock); + _leave(" = %d [new]", conn->debug_id); + return; + +found_extant_conn: + if (refcount_read(&cursor->ref) == 0) + goto replace_old_connection; + write_sequnlock_bh(&peer->service_conn_lock); + /* We should not be able to get here. rxrpc_incoming_connection() is + * called in a non-reentrant context, so there can't be a race to + * insert a new connection. + */ + BUG(); + +replace_old_connection: + /* The old connection is from an outdated epoch. */ + _debug("replace conn"); + rb_replace_node_rcu(&cursor->service_node, + &conn->service_node, + &peer->service_conns); + clear_bit(RXRPC_CONN_IN_SERVICE_CONNS, &cursor->flags); + goto conn_published; +} + +/* + * Preallocate a service connection. The connection is placed on the proc and + * reap lists so that we don't have to get the lock from BH context. + */ +struct rxrpc_connection *rxrpc_prealloc_service_connection(struct rxrpc_net *rxnet, + gfp_t gfp) +{ + struct rxrpc_connection *conn = rxrpc_alloc_connection(gfp); + + if (conn) { + /* We maintain an extra ref on the connection whilst it is on + * the rxrpc_connections list. + */ + conn->state = RXRPC_CONN_SERVICE_PREALLOC; + refcount_set(&conn->ref, 2); + conn->bundle = rxrpc_get_bundle(&rxrpc_service_dummy_bundle); + + atomic_inc(&rxnet->nr_conns); + write_lock(&rxnet->conn_lock); + list_add_tail(&conn->link, &rxnet->service_conns); + list_add_tail(&conn->proc_link, &rxnet->conn_proc_list); + write_unlock(&rxnet->conn_lock); + + trace_rxrpc_conn(conn->debug_id, rxrpc_conn_new_service, + refcount_read(&conn->ref), + __builtin_return_address(0)); + } + + return conn; +} + +/* + * Set up an incoming connection. This is called in BH context with the RCU + * read lock held. 
+ */ +void rxrpc_new_incoming_connection(struct rxrpc_sock *rx, + struct rxrpc_connection *conn, + const struct rxrpc_security *sec, + struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + _enter(""); + + conn->proto.epoch = sp->hdr.epoch; + conn->proto.cid = sp->hdr.cid & RXRPC_CIDMASK; + conn->params.service_id = sp->hdr.serviceId; + conn->service_id = sp->hdr.serviceId; + conn->security_ix = sp->hdr.securityIndex; + conn->out_clientflag = 0; + conn->security = sec; + if (conn->security_ix) + conn->state = RXRPC_CONN_SERVICE_UNSECURED; + else + conn->state = RXRPC_CONN_SERVICE; + + /* See if we should upgrade the service. This can only happen on the + * first packet on a new connection. Once done, it applies to all + * subsequent calls on that connection. + */ + if (sp->hdr.userStatus == RXRPC_USERSTATUS_SERVICE_UPGRADE && + conn->service_id == rx->service_upgrade.from) + conn->service_id = rx->service_upgrade.to; + + /* Make the connection a target for incoming packets. */ + rxrpc_publish_service_conn(conn->params.peer, conn); + + _net("CONNECTION new %d {%x}", conn->debug_id, conn->proto.cid); +} + +/* + * Remove the service connection from the peer's tree, thereby removing it as a + * target for incoming packets. + */ +void rxrpc_unpublish_service_conn(struct rxrpc_connection *conn) +{ + struct rxrpc_peer *peer = conn->params.peer; + + write_seqlock_bh(&peer->service_conn_lock); + if (test_and_clear_bit(RXRPC_CONN_IN_SERVICE_CONNS, &conn->flags)) + rb_erase(&conn->service_node, &peer->service_conns); + write_sequnlock_bh(&peer->service_conn_lock); +} diff --git a/net/rxrpc/input.c b/net/rxrpc/input.c new file mode 100644 index 000000000..721d847ba --- /dev/null +++ b/net/rxrpc/input.c @@ -0,0 +1,1502 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC packet reception + * + * Copyright (C) 2007, 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static void rxrpc_proto_abort(const char *why, + struct rxrpc_call *call, rxrpc_seq_t seq) +{ + if (rxrpc_abort_call(why, call, seq, RX_PROTOCOL_ERROR, -EBADMSG)) { + set_bit(RXRPC_CALL_EV_ABORT, &call->events); + rxrpc_queue_call(call); + } +} + +/* + * Do TCP-style congestion management [RFC 5681]. 
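+ *
+ * The call moves between four modes, broadly following the TCP scheme: slow
+ * start (the window opens by one per ACK), congestion avoidance (the window
+ * opens by roughly one per RTT), packet loss (entered when a soft NACK is
+ * seen) and fast retransmission (entered after three duplicate ACKs).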
+ */ +static void rxrpc_congestion_management(struct rxrpc_call *call, + struct sk_buff *skb, + struct rxrpc_ack_summary *summary, + rxrpc_serial_t acked_serial) +{ + enum rxrpc_congest_change change = rxrpc_cong_no_change; + unsigned int cumulative_acks = call->cong_cumul_acks; + unsigned int cwnd = call->cong_cwnd; + bool resend = false; + + summary->flight_size = + (call->tx_top - call->tx_hard_ack) - summary->nr_acks; + + if (test_and_clear_bit(RXRPC_CALL_RETRANS_TIMEOUT, &call->flags)) { + summary->retrans_timeo = true; + call->cong_ssthresh = max_t(unsigned int, + summary->flight_size / 2, 2); + cwnd = 1; + if (cwnd >= call->cong_ssthresh && + call->cong_mode == RXRPC_CALL_SLOW_START) { + call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; + call->cong_tstamp = skb->tstamp; + cumulative_acks = 0; + } + } + + cumulative_acks += summary->nr_new_acks; + cumulative_acks += summary->nr_rot_new_acks; + if (cumulative_acks > 255) + cumulative_acks = 255; + + summary->mode = call->cong_mode; + summary->cwnd = call->cong_cwnd; + summary->ssthresh = call->cong_ssthresh; + summary->cumulative_acks = cumulative_acks; + summary->dup_acks = call->cong_dup_acks; + + switch (call->cong_mode) { + case RXRPC_CALL_SLOW_START: + if (summary->nr_nacks > 0) + goto packet_loss_detected; + if (summary->cumulative_acks > 0) + cwnd += 1; + if (cwnd >= call->cong_ssthresh) { + call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; + call->cong_tstamp = skb->tstamp; + } + goto out; + + case RXRPC_CALL_CONGEST_AVOIDANCE: + if (summary->nr_nacks > 0) + goto packet_loss_detected; + + /* We analyse the number of packets that get ACK'd per RTT + * period and increase the window if we managed to fill it. + */ + if (call->peer->rtt_count == 0) + goto out; + if (ktime_before(skb->tstamp, + ktime_add_us(call->cong_tstamp, + call->peer->srtt_us >> 3))) + goto out_no_clear_ca; + change = rxrpc_cong_rtt_window_end; + call->cong_tstamp = skb->tstamp; + if (cumulative_acks >= cwnd) + cwnd++; + goto out; + + case RXRPC_CALL_PACKET_LOSS: + if (summary->nr_nacks == 0) + goto resume_normality; + + if (summary->new_low_nack) { + change = rxrpc_cong_new_low_nack; + call->cong_dup_acks = 1; + if (call->cong_extra > 1) + call->cong_extra = 1; + goto send_extra_data; + } + + call->cong_dup_acks++; + if (call->cong_dup_acks < 3) + goto send_extra_data; + + change = rxrpc_cong_begin_retransmission; + call->cong_mode = RXRPC_CALL_FAST_RETRANSMIT; + call->cong_ssthresh = max_t(unsigned int, + summary->flight_size / 2, 2); + cwnd = call->cong_ssthresh + 3; + call->cong_extra = 0; + call->cong_dup_acks = 0; + resend = true; + goto out; + + case RXRPC_CALL_FAST_RETRANSMIT: + if (!summary->new_low_nack) { + if (summary->nr_new_acks == 0) + cwnd += 1; + call->cong_dup_acks++; + if (call->cong_dup_acks == 2) { + change = rxrpc_cong_retransmit_again; + call->cong_dup_acks = 0; + resend = true; + } + } else { + change = rxrpc_cong_progress; + cwnd = call->cong_ssthresh; + if (summary->nr_nacks == 0) + goto resume_normality; + } + goto out; + + default: + BUG(); + goto out; + } + +resume_normality: + change = rxrpc_cong_cleared_nacks; + call->cong_dup_acks = 0; + call->cong_extra = 0; + call->cong_tstamp = skb->tstamp; + if (cwnd < call->cong_ssthresh) + call->cong_mode = RXRPC_CALL_SLOW_START; + else + call->cong_mode = RXRPC_CALL_CONGEST_AVOIDANCE; +out: + cumulative_acks = 0; +out_no_clear_ca: + if (cwnd >= RXRPC_RXTX_BUFF_SIZE - 1) + cwnd = RXRPC_RXTX_BUFF_SIZE - 1; + call->cong_cwnd = cwnd; + call->cong_cumul_acks = cumulative_acks; + 
trace_rxrpc_congest(call, summary, acked_serial, change); + if (resend && !test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) + rxrpc_queue_call(call); + return; + +packet_loss_detected: + change = rxrpc_cong_saw_nack; + call->cong_mode = RXRPC_CALL_PACKET_LOSS; + call->cong_dup_acks = 0; + goto send_extra_data; + +send_extra_data: + /* Send some previously unsent DATA if we have some to advance the ACK + * state. + */ + if (call->rxtx_annotations[call->tx_top & RXRPC_RXTX_BUFF_MASK] & + RXRPC_TX_ANNO_LAST || + summary->nr_acks != call->tx_top - call->tx_hard_ack) { + call->cong_extra++; + wake_up(&call->waitq); + } + goto out_no_clear_ca; +} + +/* + * Apply a hard ACK by advancing the Tx window. + */ +static bool rxrpc_rotate_tx_window(struct rxrpc_call *call, rxrpc_seq_t to, + struct rxrpc_ack_summary *summary) +{ + struct sk_buff *skb, *list = NULL; + bool rot_last = false; + int ix; + u8 annotation; + + if (call->acks_lowest_nak == call->tx_hard_ack) { + call->acks_lowest_nak = to; + } else if (before_eq(call->acks_lowest_nak, to)) { + summary->new_low_nack = true; + call->acks_lowest_nak = to; + } + + spin_lock(&call->lock); + + while (before(call->tx_hard_ack, to)) { + call->tx_hard_ack++; + ix = call->tx_hard_ack & RXRPC_RXTX_BUFF_MASK; + skb = call->rxtx_buffer[ix]; + annotation = call->rxtx_annotations[ix]; + rxrpc_see_skb(skb, rxrpc_skb_rotated); + call->rxtx_buffer[ix] = NULL; + call->rxtx_annotations[ix] = 0; + skb->next = list; + list = skb; + + if (annotation & RXRPC_TX_ANNO_LAST) { + set_bit(RXRPC_CALL_TX_LAST, &call->flags); + rot_last = true; + } + if ((annotation & RXRPC_TX_ANNO_MASK) != RXRPC_TX_ANNO_ACK) + summary->nr_rot_new_acks++; + } + + spin_unlock(&call->lock); + + trace_rxrpc_transmit(call, (rot_last ? + rxrpc_transmit_rotate_last : + rxrpc_transmit_rotate)); + wake_up(&call->waitq); + + while (list) { + skb = list; + list = skb->next; + skb_mark_not_on_list(skb); + rxrpc_free_skb(skb, rxrpc_skb_freed); + } + + return rot_last; +} + +/* + * End the transmission phase of a call. + * + * This occurs when we get an ACKALL packet, the first DATA packet of a reply, + * or a final ACK packet. + */ +static bool rxrpc_end_tx_phase(struct rxrpc_call *call, bool reply_begun, + const char *abort_why) +{ + unsigned int state; + + ASSERT(test_bit(RXRPC_CALL_TX_LAST, &call->flags)); + + write_lock(&call->state_lock); + + state = call->state; + switch (state) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + case RXRPC_CALL_CLIENT_AWAIT_REPLY: + if (reply_begun) + call->state = state = RXRPC_CALL_CLIENT_RECV_REPLY; + else + call->state = state = RXRPC_CALL_CLIENT_AWAIT_REPLY; + break; + + case RXRPC_CALL_SERVER_AWAIT_ACK: + __rxrpc_call_completed(call); + state = call->state; + break; + + default: + goto bad_state; + } + + write_unlock(&call->state_lock); + if (state == RXRPC_CALL_CLIENT_AWAIT_REPLY) + trace_rxrpc_transmit(call, rxrpc_transmit_await_reply); + else + trace_rxrpc_transmit(call, rxrpc_transmit_end); + _leave(" = ok"); + return true; + +bad_state: + write_unlock(&call->state_lock); + kdebug("end_tx %s", rxrpc_call_states[call->state]); + rxrpc_proto_abort(abort_why, call, call->tx_top); + return false; +} + +/* + * Begin the reply reception phase of a call. 
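+ *
+ * If an ACK was pending, it is cancelled and the resend and ACK timers are
+ * pushed out to the far future; the rest of the Tx window is then rotated
+ * out (the reply implicitly hard-ACKs our request) and the call leaves its
+ * transmission phase.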
+ */ +static bool rxrpc_receiving_reply(struct rxrpc_call *call) +{ + struct rxrpc_ack_summary summary = { 0 }; + unsigned long now, timo; + rxrpc_seq_t top = READ_ONCE(call->tx_top); + + if (call->ackr_reason) { + spin_lock_bh(&call->lock); + call->ackr_reason = 0; + spin_unlock_bh(&call->lock); + now = jiffies; + timo = now + MAX_JIFFY_OFFSET; + WRITE_ONCE(call->resend_at, timo); + WRITE_ONCE(call->ack_at, timo); + trace_rxrpc_timer(call, rxrpc_timer_init_for_reply, now); + } + + if (!test_bit(RXRPC_CALL_TX_LAST, &call->flags)) { + if (!rxrpc_rotate_tx_window(call, top, &summary)) { + rxrpc_proto_abort("TXL", call, top); + return false; + } + } + if (!rxrpc_end_tx_phase(call, true, "ETD")) + return false; + call->tx_phase = false; + return true; +} + +/* + * Scan a data packet to validate its structure and to work out how many + * subpackets it contains. + * + * A jumbo packet is a collection of consecutive packets glued together with + * little headers between that indicate how to change the initial header for + * each subpacket. + * + * RXRPC_JUMBO_PACKET must be set on all but the last subpacket - and all but + * the last are RXRPC_JUMBO_DATALEN in size. The last subpacket may be of any + * size. + */ +static bool rxrpc_validate_data(struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + unsigned int offset = sizeof(struct rxrpc_wire_header); + unsigned int len = skb->len; + u8 flags = sp->hdr.flags; + + for (;;) { + if (flags & RXRPC_REQUEST_ACK) + __set_bit(sp->nr_subpackets, sp->rx_req_ack); + sp->nr_subpackets++; + + if (!(flags & RXRPC_JUMBO_PACKET)) + break; + + if (len - offset < RXRPC_JUMBO_SUBPKTLEN) + goto protocol_error; + if (flags & RXRPC_LAST_PACKET) + goto protocol_error; + offset += RXRPC_JUMBO_DATALEN; + if (skb_copy_bits(skb, offset, &flags, 1) < 0) + goto protocol_error; + offset += sizeof(struct rxrpc_jumbo_header); + } + + if (flags & RXRPC_LAST_PACKET) + sp->rx_flags |= RXRPC_SKB_INCL_LAST; + return true; + +protocol_error: + return false; +} + +/* + * Handle reception of a duplicate packet. + * + * We have to take care to avoid an attack here whereby we're given a series of + * jumbograms, each with a sequence number one before the preceding one and + * filled up to maximum UDP size. If they never send us the first packet in + * the sequence, they can cause us to have to hold on to around 2MiB of kernel + * space until the call times out. + * + * We limit the space usage by only accepting three duplicate jumbo packets per + * call. After that, we tell the other side we're no longer accepting jumbos + * (that information is encoded in the ACK packet). + */ +static void rxrpc_input_dup_data(struct rxrpc_call *call, rxrpc_seq_t seq, + bool is_jumbo, bool *_jumbo_bad) +{ + /* Discard normal packets that are duplicates. */ + if (is_jumbo) + return; + + /* Skip jumbo subpackets that are duplicates. When we've had three or + * more partially duplicate jumbo packets, we refuse to take any more + * jumbos for this call. + */ + if (!*_jumbo_bad) { + call->nr_jumbo_bad++; + *_jumbo_bad = true; + } +} + +/* + * Process a DATA packet, adding the packet to the Rx ring. The caller's + * packet ref must be passed on or discarded. 
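+ *
+ * Each subpacket of a jumbo packet gets its own slot in the Rx ring.
+ * Duplicates and packets that fall outside the receive window are ignored
+ * and an appropriate ACK (DUPLICATE, EXCEEDS_WINDOW, OUT_OF_SEQUENCE, etc.)
+ * is proposed in reply.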
+ */ +static void rxrpc_input_data(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + enum rxrpc_call_state state; + unsigned int j, nr_subpackets, nr_unacked = 0; + rxrpc_serial_t serial = sp->hdr.serial, ack_serial = serial; + rxrpc_seq_t seq0 = sp->hdr.seq, hard_ack; + bool immediate_ack = false, jumbo_bad = false; + u8 ack = 0; + + _enter("{%u,%u},{%u,%u}", + call->rx_hard_ack, call->rx_top, skb->len, seq0); + + _proto("Rx DATA %%%u { #%u f=%02x n=%u }", + sp->hdr.serial, seq0, sp->hdr.flags, sp->nr_subpackets); + + state = READ_ONCE(call->state); + if (state >= RXRPC_CALL_COMPLETE) { + rxrpc_free_skb(skb, rxrpc_skb_freed); + return; + } + + if (state == RXRPC_CALL_SERVER_RECV_REQUEST) { + unsigned long timo = READ_ONCE(call->next_req_timo); + unsigned long now, expect_req_by; + + if (timo) { + now = jiffies; + expect_req_by = now + timo; + WRITE_ONCE(call->expect_req_by, expect_req_by); + rxrpc_reduce_call_timer(call, expect_req_by, now, + rxrpc_timer_set_for_idle); + } + } + + spin_lock(&call->input_lock); + + /* Received data implicitly ACKs all of the request packets we sent + * when we're acting as a client. + */ + if ((state == RXRPC_CALL_CLIENT_SEND_REQUEST || + state == RXRPC_CALL_CLIENT_AWAIT_REPLY) && + !rxrpc_receiving_reply(call)) + goto unlock; + + hard_ack = READ_ONCE(call->rx_hard_ack); + + nr_subpackets = sp->nr_subpackets; + if (nr_subpackets > 1) { + if (call->nr_jumbo_bad > 3) { + ack = RXRPC_ACK_NOSPACE; + ack_serial = serial; + goto ack; + } + } + + for (j = 0; j < nr_subpackets; j++) { + rxrpc_serial_t serial = sp->hdr.serial + j; + rxrpc_seq_t seq = seq0 + j; + unsigned int ix = seq & RXRPC_RXTX_BUFF_MASK; + bool terminal = (j == nr_subpackets - 1); + bool last = terminal && (sp->rx_flags & RXRPC_SKB_INCL_LAST); + u8 flags, annotation = j; + + _proto("Rx DATA+%u %%%u { #%x t=%u l=%u }", + j, serial, seq, terminal, last); + + if (last) { + if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && + seq != call->rx_top) { + rxrpc_proto_abort("LSN", call, seq); + goto unlock; + } + } else { + if (test_bit(RXRPC_CALL_RX_LAST, &call->flags) && + after_eq(seq, call->rx_top)) { + rxrpc_proto_abort("LSA", call, seq); + goto unlock; + } + } + + flags = 0; + if (last) + flags |= RXRPC_LAST_PACKET; + if (!terminal) + flags |= RXRPC_JUMBO_PACKET; + if (test_bit(j, sp->rx_req_ack)) + flags |= RXRPC_REQUEST_ACK; + trace_rxrpc_rx_data(call->debug_id, seq, serial, flags, annotation); + + if (before_eq(seq, hard_ack)) { + ack = RXRPC_ACK_DUPLICATE; + ack_serial = serial; + continue; + } + + if (call->rxtx_buffer[ix]) { + rxrpc_input_dup_data(call, seq, nr_subpackets > 1, + &jumbo_bad); + if (ack != RXRPC_ACK_DUPLICATE) { + ack = RXRPC_ACK_DUPLICATE; + ack_serial = serial; + } + immediate_ack = true; + continue; + } + + if (after(seq, hard_ack + call->rx_winsize)) { + ack = RXRPC_ACK_EXCEEDS_WINDOW; + ack_serial = serial; + if (flags & RXRPC_JUMBO_PACKET) { + if (!jumbo_bad) { + call->nr_jumbo_bad++; + jumbo_bad = true; + } + } + + goto ack; + } + + if (flags & RXRPC_REQUEST_ACK && !ack) { + ack = RXRPC_ACK_REQUESTED; + ack_serial = serial; + } + + if (after(seq0, call->ackr_highest_seq)) + call->ackr_highest_seq = seq0; + + /* Queue the packet. We use a couple of memory barriers here as need + * to make sure that rx_top is perceived to be set after the buffer + * pointer and that the buffer pointer is set after the annotation and + * the skb data. 
+ * + * Barriers against rxrpc_recvmsg_data() and rxrpc_rotate_rx_window() + * and also rxrpc_fill_out_ack(). + */ + if (!terminal) + rxrpc_get_skb(skb, rxrpc_skb_got); + call->rxtx_annotations[ix] = annotation; + smp_wmb(); + call->rxtx_buffer[ix] = skb; + if (after(seq, call->rx_top)) { + smp_store_release(&call->rx_top, seq); + } else if (before(seq, call->rx_top)) { + /* Send an immediate ACK if we fill in a hole */ + if (!ack) { + ack = RXRPC_ACK_DELAY; + ack_serial = serial; + } + immediate_ack = true; + } + + if (terminal) { + /* From this point on, we're not allowed to touch the + * packet any longer as its ref now belongs to the Rx + * ring. + */ + skb = NULL; + sp = NULL; + } + + nr_unacked++; + + if (last) { + set_bit(RXRPC_CALL_RX_LAST, &call->flags); + if (!ack) { + ack = RXRPC_ACK_DELAY; + ack_serial = serial; + } + trace_rxrpc_receive(call, rxrpc_receive_queue_last, serial, seq); + } else { + trace_rxrpc_receive(call, rxrpc_receive_queue, serial, seq); + } + + if (after_eq(seq, call->rx_expect_next)) { + if (after(seq, call->rx_expect_next)) { + _net("OOS %u > %u", seq, call->rx_expect_next); + ack = RXRPC_ACK_OUT_OF_SEQUENCE; + ack_serial = serial; + } + call->rx_expect_next = seq + 1; + } + if (!ack) + ack_serial = serial; + } + +ack: + if (atomic_add_return(nr_unacked, &call->ackr_nr_unacked) > 2 && !ack) + ack = RXRPC_ACK_IDLE; + + if (ack) + rxrpc_propose_ACK(call, ack, ack_serial, + immediate_ack, true, + rxrpc_propose_ack_input_data); + else + rxrpc_propose_ACK(call, RXRPC_ACK_DELAY, serial, + false, true, + rxrpc_propose_ack_input_data); + + trace_rxrpc_notify_socket(call->debug_id, serial); + rxrpc_notify_socket(call); + +unlock: + spin_unlock(&call->input_lock); + rxrpc_free_skb(skb, rxrpc_skb_freed); + _leave(" [queued]"); +} + +/* + * See if there's a cached RTT probe to complete. + */ +static void rxrpc_complete_rtt_probe(struct rxrpc_call *call, + ktime_t resp_time, + rxrpc_serial_t acked_serial, + rxrpc_serial_t ack_serial, + enum rxrpc_rtt_rx_trace type) +{ + rxrpc_serial_t orig_serial; + unsigned long avail; + ktime_t sent_at; + bool matched = false; + int i; + + avail = READ_ONCE(call->rtt_avail); + smp_rmb(); /* Read avail bits before accessing data. */ + + for (i = 0; i < ARRAY_SIZE(call->rtt_serial); i++) { + if (!test_bit(i + RXRPC_CALL_RTT_PEND_SHIFT, &avail)) + continue; + + sent_at = call->rtt_sent_at[i]; + orig_serial = call->rtt_serial[i]; + + if (orig_serial == acked_serial) { + clear_bit(i + RXRPC_CALL_RTT_PEND_SHIFT, &call->rtt_avail); + smp_mb(); /* Read data before setting avail bit */ + set_bit(i, &call->rtt_avail); + if (type != rxrpc_rtt_rx_cancel) + rxrpc_peer_add_rtt(call, type, i, acked_serial, ack_serial, + sent_at, resp_time); + else + trace_rxrpc_rtt_rx(call, rxrpc_rtt_rx_cancel, i, + orig_serial, acked_serial, 0, 0); + matched = true; + } + + /* If a later serial is being acked, then mark this slot as + * being available. + */ + if (after(acked_serial, orig_serial)) { + trace_rxrpc_rtt_rx(call, rxrpc_rtt_rx_obsolete, i, + orig_serial, acked_serial, 0, 0); + clear_bit(i + RXRPC_CALL_RTT_PEND_SHIFT, &call->rtt_avail); + smp_wmb(); + set_bit(i, &call->rtt_avail); + } + } + + if (!matched) + trace_rxrpc_rtt_rx(call, rxrpc_rtt_rx_lost, 9, 0, acked_serial, 0, 0); +} + +/* + * Process the response to a ping that we sent to find out if we lost an ACK. 
+ * + * If we got back a ping response that indicates a lower tx_top than what we + * had at the time of the ping transmission, we adjudge all the DATA packets + * sent between the response tx_top and the ping-time tx_top to have been lost. + */ +static void rxrpc_input_check_for_lost_ack(struct rxrpc_call *call) +{ + rxrpc_seq_t top, bottom, seq; + bool resend = false; + + spin_lock_bh(&call->lock); + + bottom = call->tx_hard_ack + 1; + top = call->acks_lost_top; + if (before(bottom, top)) { + for (seq = bottom; before_eq(seq, top); seq++) { + int ix = seq & RXRPC_RXTX_BUFF_MASK; + u8 annotation = call->rxtx_annotations[ix]; + u8 anno_type = annotation & RXRPC_TX_ANNO_MASK; + + if (anno_type != RXRPC_TX_ANNO_UNACK) + continue; + annotation &= ~RXRPC_TX_ANNO_MASK; + annotation |= RXRPC_TX_ANNO_RETRANS; + call->rxtx_annotations[ix] = annotation; + resend = true; + } + } + + spin_unlock_bh(&call->lock); + + if (resend && !test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) + rxrpc_queue_call(call); +} + +/* + * Process a ping response. + */ +static void rxrpc_input_ping_response(struct rxrpc_call *call, + ktime_t resp_time, + rxrpc_serial_t acked_serial, + rxrpc_serial_t ack_serial) +{ + if (acked_serial == call->acks_lost_ping) + rxrpc_input_check_for_lost_ack(call); +} + +/* + * Process the extra information that may be appended to an ACK packet + */ +static void rxrpc_input_ackinfo(struct rxrpc_call *call, struct sk_buff *skb, + struct rxrpc_ackinfo *ackinfo) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxrpc_peer *peer; + unsigned int mtu; + bool wake = false; + u32 rwind = ntohl(ackinfo->rwind); + + _proto("Rx ACK %%%u Info { rx=%u max=%u rwin=%u jm=%u }", + sp->hdr.serial, + ntohl(ackinfo->rxMTU), ntohl(ackinfo->maxMTU), + rwind, ntohl(ackinfo->jumbo_max)); + + if (rwind > RXRPC_RXTX_BUFF_SIZE - 1) + rwind = RXRPC_RXTX_BUFF_SIZE - 1; + if (call->tx_winsize != rwind) { + if (rwind > call->tx_winsize) + wake = true; + trace_rxrpc_rx_rwind_change(call, sp->hdr.serial, rwind, wake); + call->tx_winsize = rwind; + } + + if (call->cong_ssthresh > rwind) + call->cong_ssthresh = rwind; + + mtu = min(ntohl(ackinfo->rxMTU), ntohl(ackinfo->maxMTU)); + + peer = call->peer; + if (mtu < peer->maxdata) { + spin_lock_bh(&peer->lock); + peer->maxdata = mtu; + peer->mtu = mtu + peer->hdrsize; + spin_unlock_bh(&peer->lock); + _net("Net MTU %u (maxdata %u)", peer->mtu, peer->maxdata); + } + + if (wake) + wake_up(&call->waitq); +} + +/* + * Process individual soft ACKs. + * + * Each ACK in the array corresponds to one packet and can be either an ACK or + * a NAK. If we get find an explicitly NAK'd packet we resend immediately; + * packets that lie beyond the end of the ACK list are scheduled for resend by + * the timer on the basis that the peer might just not have processed them at + * the time the ACK was sent. 
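+ * + * A packet whose annotation already records the same ACK/NAK state is not counted again as new; a packet already marked for retransmission keeps that annotation even when NAK'd.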
+ */ +static void rxrpc_input_soft_acks(struct rxrpc_call *call, u8 *acks, + rxrpc_seq_t seq, int nr_acks, + struct rxrpc_ack_summary *summary) +{ + int ix; + u8 annotation, anno_type; + + for (; nr_acks > 0; nr_acks--, seq++) { + ix = seq & RXRPC_RXTX_BUFF_MASK; + annotation = call->rxtx_annotations[ix]; + anno_type = annotation & RXRPC_TX_ANNO_MASK; + annotation &= ~RXRPC_TX_ANNO_MASK; + switch (*acks++) { + case RXRPC_ACK_TYPE_ACK: + summary->nr_acks++; + if (anno_type == RXRPC_TX_ANNO_ACK) + continue; + summary->nr_new_acks++; + call->rxtx_annotations[ix] = + RXRPC_TX_ANNO_ACK | annotation; + break; + case RXRPC_ACK_TYPE_NACK: + if (!summary->nr_nacks && + call->acks_lowest_nak != seq) { + call->acks_lowest_nak = seq; + summary->new_low_nack = true; + } + summary->nr_nacks++; + if (anno_type == RXRPC_TX_ANNO_NAK) + continue; + summary->nr_new_nacks++; + if (anno_type == RXRPC_TX_ANNO_RETRANS) + continue; + call->rxtx_annotations[ix] = + RXRPC_TX_ANNO_NAK | annotation; + break; + default: + return rxrpc_proto_abort("SFT", call, 0); + } + } +} + +/* + * Return true if the ACK is valid - ie. it doesn't appear to have regressed + * with respect to the ack state conveyed by preceding ACKs. + */ +static bool rxrpc_is_ack_valid(struct rxrpc_call *call, + rxrpc_seq_t first_pkt, rxrpc_seq_t prev_pkt) +{ + rxrpc_seq_t base = READ_ONCE(call->acks_first_seq); + + if (after(first_pkt, base)) + return true; /* The window advanced */ + + if (before(first_pkt, base)) + return false; /* firstPacket regressed */ + + if (after_eq(prev_pkt, call->acks_prev_seq)) + return true; /* previousPacket hasn't regressed. */ + + /* Some rx implementations put a serial number in previousPacket. */ + if (after_eq(prev_pkt, base + call->tx_winsize)) + return false; + return true; +} + +/* + * Process an ACK packet. + * + * ack.firstPacket is the sequence number of the first soft-ACK'd/NAK'd packet + * in the ACK array. Anything before that is hard-ACK'd and may be discarded. + * + * A hard-ACK means that a packet has been processed and may be discarded; a + * soft-ACK means that the packet may be discarded and retransmission + * requested. A phase is complete when all packets are hard-ACK'd. + */ +static void rxrpc_input_ack(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_ack_summary summary = { 0 }; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + union { + struct rxrpc_ackpacket ack; + struct rxrpc_ackinfo info; + u8 acks[RXRPC_MAXACKS]; + } buf; + rxrpc_serial_t ack_serial, acked_serial; + rxrpc_seq_t first_soft_ack, hard_ack, prev_pkt; + int nr_acks, offset, ioffset; + + _enter(""); + + offset = sizeof(struct rxrpc_wire_header); + if (skb_copy_bits(skb, offset, &buf.ack, sizeof(buf.ack)) < 0) { + _debug("extraction failure"); + return rxrpc_proto_abort("XAK", call, 0); + } + offset += sizeof(buf.ack); + + ack_serial = sp->hdr.serial; + acked_serial = ntohl(buf.ack.serial); + first_soft_ack = ntohl(buf.ack.firstPacket); + prev_pkt = ntohl(buf.ack.previousPacket); + hard_ack = first_soft_ack - 1; + nr_acks = buf.ack.nAcks; + summary.ack_reason = (buf.ack.reason < RXRPC_ACK__INVALID ? 
+ buf.ack.reason : RXRPC_ACK__INVALID); + + trace_rxrpc_rx_ack(call, ack_serial, acked_serial, + first_soft_ack, prev_pkt, + summary.ack_reason, nr_acks); + + switch (buf.ack.reason) { + case RXRPC_ACK_PING_RESPONSE: + rxrpc_input_ping_response(call, skb->tstamp, acked_serial, + ack_serial); + rxrpc_complete_rtt_probe(call, skb->tstamp, acked_serial, ack_serial, + rxrpc_rtt_rx_ping_response); + break; + case RXRPC_ACK_REQUESTED: + rxrpc_complete_rtt_probe(call, skb->tstamp, acked_serial, ack_serial, + rxrpc_rtt_rx_requested_ack); + break; + default: + if (acked_serial != 0) + rxrpc_complete_rtt_probe(call, skb->tstamp, acked_serial, ack_serial, + rxrpc_rtt_rx_cancel); + break; + } + + if (buf.ack.reason == RXRPC_ACK_PING) { + _proto("Rx ACK %%%u PING Request", ack_serial); + rxrpc_propose_ACK(call, RXRPC_ACK_PING_RESPONSE, + ack_serial, true, true, + rxrpc_propose_ack_respond_to_ping); + } else if (sp->hdr.flags & RXRPC_REQUEST_ACK) { + rxrpc_propose_ACK(call, RXRPC_ACK_REQUESTED, + ack_serial, true, true, + rxrpc_propose_ack_respond_to_ack); + } + + /* If we get an EXCEEDS_WINDOW ACK from the server, it probably + * indicates that the client address changed due to NAT. The server + * lost the call because it switched to a different peer. + */ + if (unlikely(buf.ack.reason == RXRPC_ACK_EXCEEDS_WINDOW) && + first_soft_ack == 1 && + prev_pkt == 0 && + rxrpc_is_client_call(call)) { + rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, + 0, -ENETRESET); + return; + } + + /* If we get an OUT_OF_SEQUENCE ACK from the server, that can also + * indicate a change of address. However, we can retransmit the call + * if we still have it buffered to the beginning. + */ + if (unlikely(buf.ack.reason == RXRPC_ACK_OUT_OF_SEQUENCE) && + first_soft_ack == 1 && + prev_pkt == 0 && + call->tx_hard_ack == 0 && + rxrpc_is_client_call(call)) { + rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, + 0, -ENETRESET); + return; + } + + /* Discard any out-of-order or duplicate ACKs (outside lock). */ + if (!rxrpc_is_ack_valid(call, first_soft_ack, prev_pkt)) { + trace_rxrpc_rx_discard_ack(call->debug_id, ack_serial, + first_soft_ack, call->acks_first_seq, + prev_pkt, call->acks_prev_seq); + return; + } + + buf.info.rxMTU = 0; + ioffset = offset + nr_acks + 3; + if (skb->len >= ioffset + sizeof(buf.info) && + skb_copy_bits(skb, ioffset, &buf.info, sizeof(buf.info)) < 0) + return rxrpc_proto_abort("XAI", call, 0); + + spin_lock(&call->input_lock); + + /* Discard any out-of-order or duplicate ACKs (inside lock). */ + if (!rxrpc_is_ack_valid(call, first_soft_ack, prev_pkt)) { + trace_rxrpc_rx_discard_ack(call->debug_id, ack_serial, + first_soft_ack, call->acks_first_seq, + prev_pkt, call->acks_prev_seq); + goto out; + } + call->acks_latest_ts = skb->tstamp; + + call->acks_first_seq = first_soft_ack; + call->acks_prev_seq = prev_pkt; + + /* Parse rwind and mtu sizes if provided. */ + if (buf.info.rxMTU) + rxrpc_input_ackinfo(call, skb, &buf.info); + + if (first_soft_ack == 0) { + rxrpc_proto_abort("AK0", call, 0); + goto out; + } + + /* Ignore ACKs unless we are or have just been transmitting. 
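+ * That is, the call must be in one of the client send-request/await-reply or server send-reply/await-ACK states; otherwise the ACK is ignored.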
*/ + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + case RXRPC_CALL_CLIENT_AWAIT_REPLY: + case RXRPC_CALL_SERVER_SEND_REPLY: + case RXRPC_CALL_SERVER_AWAIT_ACK: + break; + default: + goto out; + } + + if (before(hard_ack, call->tx_hard_ack) || + after(hard_ack, call->tx_top)) { + rxrpc_proto_abort("AKW", call, 0); + goto out; + } + if (nr_acks > call->tx_top - hard_ack) { + rxrpc_proto_abort("AKN", call, 0); + goto out; + } + + if (after(hard_ack, call->tx_hard_ack)) { + if (rxrpc_rotate_tx_window(call, hard_ack, &summary)) { + rxrpc_end_tx_phase(call, false, "ETA"); + goto out; + } + } + + if (nr_acks > 0) { + if (skb_copy_bits(skb, offset, buf.acks, nr_acks) < 0) { + rxrpc_proto_abort("XSA", call, 0); + goto out; + } + rxrpc_input_soft_acks(call, buf.acks, first_soft_ack, nr_acks, + &summary); + } + + if (call->rxtx_annotations[call->tx_top & RXRPC_RXTX_BUFF_MASK] & + RXRPC_TX_ANNO_LAST && + summary.nr_acks == call->tx_top - hard_ack && + rxrpc_is_client_call(call)) + rxrpc_propose_ACK(call, RXRPC_ACK_PING, ack_serial, + false, true, + rxrpc_propose_ack_ping_for_lost_reply); + + rxrpc_congestion_management(call, skb, &summary, acked_serial); +out: + spin_unlock(&call->input_lock); +} + +/* + * Process an ACKALL packet. + */ +static void rxrpc_input_ackall(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_ack_summary summary = { 0 }; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + _proto("Rx ACKALL %%%u", sp->hdr.serial); + + spin_lock(&call->input_lock); + + if (rxrpc_rotate_tx_window(call, call->tx_top, &summary)) + rxrpc_end_tx_phase(call, false, "ETL"); + + spin_unlock(&call->input_lock); +} + +/* + * Process an ABORT packet directed at a call. + */ +static void rxrpc_input_abort(struct rxrpc_call *call, struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + __be32 wtmp; + u32 abort_code = RX_CALL_DEAD; + + _enter(""); + + if (skb->len >= 4 && + skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + &wtmp, sizeof(wtmp)) >= 0) + abort_code = ntohl(wtmp); + + trace_rxrpc_rx_abort(call, sp->hdr.serial, abort_code); + + _proto("Rx ABORT %%%u { %x }", sp->hdr.serial, abort_code); + + rxrpc_set_call_completion(call, RXRPC_CALL_REMOTELY_ABORTED, + abort_code, -ECONNABORTED); +} + +/* + * Process an incoming call packet. + */ +static void rxrpc_input_call_packet(struct rxrpc_call *call, + struct sk_buff *skb) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + unsigned long timo; + + _enter("%p,%p", call, skb); + + timo = READ_ONCE(call->next_rx_timo); + if (timo) { + unsigned long now = jiffies, expect_rx_by; + + expect_rx_by = now + timo; + WRITE_ONCE(call->expect_rx_by, expect_rx_by); + rxrpc_reduce_call_timer(call, expect_rx_by, now, + rxrpc_timer_set_for_normal); + } + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_DATA: + rxrpc_input_data(call, skb); + goto no_free; + + case RXRPC_PACKET_TYPE_ACK: + rxrpc_input_ack(call, skb); + break; + + case RXRPC_PACKET_TYPE_BUSY: + _proto("Rx BUSY %%%u", sp->hdr.serial); + + /* Just ignore BUSY packets from the server; the retry and + * lifespan timers will take care of business. BUSY packets + * from the client don't make sense. 
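+ * Either way the packet is simply dropped here and freed at the bottom of this function.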
+ */ + break; + + case RXRPC_PACKET_TYPE_ABORT: + rxrpc_input_abort(call, skb); + break; + + case RXRPC_PACKET_TYPE_ACKALL: + rxrpc_input_ackall(call, skb); + break; + + default: + break; + } + + rxrpc_free_skb(skb, rxrpc_skb_freed); +no_free: + _leave(""); +} + +/* + * Handle a new service call on a channel implicitly completing the preceding + * call on that channel. This does not apply to client conns. + * + * TODO: If callNumber > call_id + 1, renegotiate security. + */ +static void rxrpc_input_implicit_end_call(struct rxrpc_sock *rx, + struct rxrpc_connection *conn, + struct rxrpc_call *call) +{ + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_SERVER_AWAIT_ACK: + rxrpc_call_completed(call); + fallthrough; + case RXRPC_CALL_COMPLETE: + break; + default: + if (rxrpc_abort_call("IMP", call, 0, RX_CALL_DEAD, -ESHUTDOWN)) { + set_bit(RXRPC_CALL_EV_ABORT, &call->events); + rxrpc_queue_call(call); + } + trace_rxrpc_improper_term(call); + break; + } + + spin_lock(&rx->incoming_lock); + __rxrpc_disconnect_call(conn, call); + spin_unlock(&rx->incoming_lock); +} + +/* + * post connection-level events to the connection + * - this includes challenges, responses, some aborts and call terminal packet + * retransmission. + */ +static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn, + struct sk_buff *skb) +{ + _enter("%p,%p", conn, skb); + + skb_queue_tail(&conn->rx_queue, skb); + rxrpc_queue_conn(conn); +} + +/* + * post endpoint-level events to the local endpoint + * - this includes debug and version messages + */ +static void rxrpc_post_packet_to_local(struct rxrpc_local *local, + struct sk_buff *skb) +{ + _enter("%p,%p", local, skb); + + if (rxrpc_get_local_maybe(local)) { + skb_queue_tail(&local->event_queue, skb); + rxrpc_queue_local(local); + } else { + rxrpc_free_skb(skb, rxrpc_skb_freed); + } +} + +/* + * put a packet up for transport-level abort + */ +static void rxrpc_reject_packet(struct rxrpc_local *local, struct sk_buff *skb) +{ + if (rxrpc_get_local_maybe(local)) { + skb_queue_tail(&local->reject_queue, skb); + rxrpc_queue_local(local); + } else { + rxrpc_free_skb(skb, rxrpc_skb_freed); + } +} + +/* + * Extract the wire header from a packet and translate the byte order. + */ +static noinline +int rxrpc_extract_header(struct rxrpc_skb_priv *sp, struct sk_buff *skb) +{ + struct rxrpc_wire_header whdr; + + /* dig out the RxRPC connection details */ + if (skb_copy_bits(skb, 0, &whdr, sizeof(whdr)) < 0) { + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, + tracepoint_string("bad_hdr")); + return -EBADMSG; + } + + memset(sp, 0, sizeof(*sp)); + sp->hdr.epoch = ntohl(whdr.epoch); + sp->hdr.cid = ntohl(whdr.cid); + sp->hdr.callNumber = ntohl(whdr.callNumber); + sp->hdr.seq = ntohl(whdr.seq); + sp->hdr.serial = ntohl(whdr.serial); + sp->hdr.flags = whdr.flags; + sp->hdr.type = whdr.type; + sp->hdr.userStatus = whdr.userStatus; + sp->hdr.securityIndex = whdr.securityIndex; + sp->hdr._rsvd = ntohs(whdr._rsvd); + sp->hdr.serviceId = ntohs(whdr.serviceId); + return 0; +} + +/* + * handle data received on the local endpoint + * - may be called in interrupt context + * + * [!] Note that as this is called from the encap_rcv hook, the socket is not + * held locked by the caller and nothing prevents sk_user_data on the UDP from + * being cleared in the middle of processing this function. + * + * Called with the RCU read lock held from the IP layer via UDP. 
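+ * + * Always returns 0: the packet is consumed here (queued, passed on or freed) and is never handed back to the UDP stack.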
+ */ +int rxrpc_input_packet(struct sock *udp_sk, struct sk_buff *skb) +{ + struct rxrpc_local *local = rcu_dereference_sk_user_data(udp_sk); + struct rxrpc_connection *conn; + struct rxrpc_channel *chan; + struct rxrpc_call *call = NULL; + struct rxrpc_skb_priv *sp; + struct rxrpc_peer *peer = NULL; + struct rxrpc_sock *rx = NULL; + unsigned int channel; + + _enter("%p", udp_sk); + + if (unlikely(!local)) { + kfree_skb(skb); + return 0; + } + if (skb->tstamp == 0) + skb->tstamp = ktime_get_real(); + + rxrpc_new_skb(skb, rxrpc_skb_received); + + skb_pull(skb, sizeof(struct udphdr)); + + /* The UDP protocol already released all skb resources; + * we are free to add our own data there. + */ + sp = rxrpc_skb(skb); + + /* dig out the RxRPC connection details */ + if (rxrpc_extract_header(sp, skb) < 0) + goto bad_message; + + if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) { + static int lose; + if ((lose++ & 7) == 7) { + trace_rxrpc_rx_lose(sp); + rxrpc_free_skb(skb, rxrpc_skb_lost); + return 0; + } + } + + if (skb->tstamp == 0) + skb->tstamp = ktime_get_real(); + trace_rxrpc_rx_packet(sp); + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_VERSION: + if (rxrpc_to_client(sp)) + goto discard; + rxrpc_post_packet_to_local(local, skb); + goto out; + + case RXRPC_PACKET_TYPE_BUSY: + if (rxrpc_to_server(sp)) + goto discard; + fallthrough; + case RXRPC_PACKET_TYPE_ACK: + case RXRPC_PACKET_TYPE_ACKALL: + if (sp->hdr.callNumber == 0) + goto bad_message; + fallthrough; + case RXRPC_PACKET_TYPE_ABORT: + break; + + case RXRPC_PACKET_TYPE_DATA: + if (sp->hdr.callNumber == 0 || + sp->hdr.seq == 0) + goto bad_message; + if (!rxrpc_validate_data(skb)) + goto bad_message; + + /* Unshare the packet so that it can be modified for in-place + * decryption. + */ + if (sp->hdr.securityIndex != 0) { + struct sk_buff *nskb = skb_unshare(skb, GFP_ATOMIC); + if (!nskb) { + rxrpc_eaten_skb(skb, rxrpc_skb_unshared_nomem); + goto out; + } + + if (nskb != skb) { + rxrpc_eaten_skb(skb, rxrpc_skb_received); + skb = nskb; + rxrpc_new_skb(skb, rxrpc_skb_unshared); + sp = rxrpc_skb(skb); + } + } + break; + + case RXRPC_PACKET_TYPE_CHALLENGE: + if (rxrpc_to_server(sp)) + goto discard; + break; + case RXRPC_PACKET_TYPE_RESPONSE: + if (rxrpc_to_client(sp)) + goto discard; + break; + + /* Packet types 9-11 should just be ignored. */ + case RXRPC_PACKET_TYPE_PARAMS: + case RXRPC_PACKET_TYPE_10: + case RXRPC_PACKET_TYPE_11: + goto discard; + + default: + _proto("Rx Bad Packet Type %u", sp->hdr.type); + goto bad_message; + } + + if (sp->hdr.serviceId == 0) + goto bad_message; + + if (rxrpc_to_server(sp)) { + /* Weed out packets to services we're not offering. Packets + * that would begin a call are explicitly rejected and the rest + * are just discarded. 
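+ * (Only the DATA packet that would begin a call - seq 1 - gets the explicit RX_INVALID_OPERATION rejection; see unsupported_service below.)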
+ */ + rx = rcu_dereference(local->service); + if (!rx || (sp->hdr.serviceId != rx->srx.srx_service && + sp->hdr.serviceId != rx->second_service)) { + if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA && + sp->hdr.seq == 1) + goto unsupported_service; + goto discard; + } + } + + conn = rxrpc_find_connection_rcu(local, skb, &peer); + if (conn) { + if (sp->hdr.securityIndex != conn->security_ix) + goto wrong_security; + + if (sp->hdr.serviceId != conn->service_id) { + int old_id; + + if (!test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags)) + goto reupgrade; + old_id = cmpxchg(&conn->service_id, conn->params.service_id, + sp->hdr.serviceId); + + if (old_id != conn->params.service_id && + old_id != sp->hdr.serviceId) + goto reupgrade; + } + + if (sp->hdr.callNumber == 0) { + /* Connection-level packet */ + _debug("CONN %p {%d}", conn, conn->debug_id); + rxrpc_post_packet_to_conn(conn, skb); + goto out; + } + + if ((int)sp->hdr.serial - (int)conn->hi_serial > 0) + conn->hi_serial = sp->hdr.serial; + + /* Call-bound packets are routed by connection channel. */ + channel = sp->hdr.cid & RXRPC_CHANNELMASK; + chan = &conn->channels[channel]; + + /* Ignore really old calls */ + if (sp->hdr.callNumber < chan->last_call) + goto discard; + + if (sp->hdr.callNumber == chan->last_call) { + if (chan->call || + sp->hdr.type == RXRPC_PACKET_TYPE_ABORT) + goto discard; + + /* For the previous service call, if completed + * successfully, we discard all further packets. + */ + if (rxrpc_conn_is_service(conn) && + chan->last_type == RXRPC_PACKET_TYPE_ACK) + goto discard; + + /* But otherwise we need to retransmit the final packet + * from data cached in the connection record. + */ + if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA) + trace_rxrpc_rx_data(chan->call_debug_id, + sp->hdr.seq, + sp->hdr.serial, + sp->hdr.flags, 0); + rxrpc_post_packet_to_conn(conn, skb); + goto out; + } + + call = rcu_dereference(chan->call); + + if (sp->hdr.callNumber > chan->call_id) { + if (rxrpc_to_client(sp)) + goto reject_packet; + if (call) + rxrpc_input_implicit_end_call(rx, conn, call); + call = NULL; + } + + if (call) { + if (sp->hdr.serviceId != call->service_id) + call->service_id = sp->hdr.serviceId; + if ((int)sp->hdr.serial - (int)call->rx_serial > 0) + call->rx_serial = sp->hdr.serial; + if (!test_bit(RXRPC_CALL_RX_HEARD, &call->flags)) + set_bit(RXRPC_CALL_RX_HEARD, &call->flags); + } + } + + if (!call || refcount_read(&call->ref) == 0) { + if (rxrpc_to_client(sp) || + sp->hdr.type != RXRPC_PACKET_TYPE_DATA) + goto bad_message; + if (sp->hdr.seq != 1) + goto discard; + call = rxrpc_new_incoming_call(local, rx, skb); + if (!call) + goto reject_packet; + } + + /* Process a call packet; this either discards or passes on the ref + * elsewhere. 
+ */ + rxrpc_input_call_packet(call, skb); + goto out; + +discard: + rxrpc_free_skb(skb, rxrpc_skb_freed); +out: + trace_rxrpc_rx_done(0, 0); + return 0; + +wrong_security: + trace_rxrpc_abort(0, "SEC", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RXKADINCONSISTENCY, EBADMSG); + skb->priority = RXKADINCONSISTENCY; + goto post_abort; + +unsupported_service: + trace_rxrpc_abort(0, "INV", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_INVALID_OPERATION, EOPNOTSUPP); + skb->priority = RX_INVALID_OPERATION; + goto post_abort; + +reupgrade: + trace_rxrpc_abort(0, "UPG", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_PROTOCOL_ERROR, EBADMSG); + goto protocol_error; + +bad_message: + trace_rxrpc_abort(0, "BAD", sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_PROTOCOL_ERROR, EBADMSG); +protocol_error: + skb->priority = RX_PROTOCOL_ERROR; +post_abort: + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; +reject_packet: + trace_rxrpc_rx_done(skb->mark, skb->priority); + rxrpc_reject_packet(local, skb); + _leave(" [badmsg]"); + return 0; +} diff --git a/net/rxrpc/insecure.c b/net/rxrpc/insecure.c new file mode 100644 index 000000000..9aae99d67 --- /dev/null +++ b/net/rxrpc/insecure.c @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Null security operations. + * + * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include "ar-internal.h" + +static int none_init_connection_security(struct rxrpc_connection *conn, + struct rxrpc_key_token *token) +{ + return 0; +} + +/* + * Work out how much data we can put in an unsecured packet. + */ +static int none_how_much_data(struct rxrpc_call *call, size_t remain, + size_t *_buf_size, size_t *_data_size, size_t *_offset) +{ + *_buf_size = *_data_size = min_t(size_t, remain, RXRPC_JUMBO_DATALEN); + *_offset = 0; + return 0; +} + +static int none_secure_packet(struct rxrpc_call *call, struct sk_buff *skb, + size_t data_size) +{ + return 0; +} + +static int none_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int offset, unsigned int len, + rxrpc_seq_t seq, u16 expected_cksum) +{ + return 0; +} + +static void none_free_call_crypto(struct rxrpc_call *call) +{ +} + +static void none_locate_data(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int *_offset, unsigned int *_len) +{ +} + +static int none_respond_to_challenge(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 *_abort_code) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, + tracepoint_string("chall_none")); + return -EPROTO; +} + +static int none_verify_response(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 *_abort_code) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, + tracepoint_string("resp_none")); + return -EPROTO; +} + +static void none_clear(struct rxrpc_connection *conn) +{ +} + +static int none_init(void) +{ + return 0; +} + +static void none_exit(void) +{ +} + +/* + * RxRPC Kerberos-based security + */ +const struct rxrpc_security rxrpc_no_security = { + .name = "none", + .security_index = RXRPC_SECURITY_NONE, + .init = none_init, + .exit = none_exit, + .init_connection_security = none_init_connection_security, + .free_call_crypto = none_free_call_crypto, + .how_much_data = none_how_much_data, + .secure_packet = none_secure_packet, + .verify_packet = none_verify_packet, + .locate_data = none_locate_data, + .respond_to_challenge = 
none_respond_to_challenge, + .verify_response = none_verify_response, + .clear = none_clear, +}; diff --git a/net/rxrpc/key.c b/net/rxrpc/key.c new file mode 100644 index 000000000..8d2073e0e --- /dev/null +++ b/net/rxrpc/key.c @@ -0,0 +1,695 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC key management + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + * + * RxRPC keys should have a description of describing their purpose: + * "afs@example.com" + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static int rxrpc_preparse(struct key_preparsed_payload *); +static void rxrpc_free_preparse(struct key_preparsed_payload *); +static void rxrpc_destroy(struct key *); +static void rxrpc_describe(const struct key *, struct seq_file *); +static long rxrpc_read(const struct key *, char *, size_t); + +/* + * rxrpc defined keys take an arbitrary string as the description and an + * arbitrary blob of data as the payload + */ +struct key_type key_type_rxrpc = { + .name = "rxrpc", + .flags = KEY_TYPE_NET_DOMAIN, + .preparse = rxrpc_preparse, + .free_preparse = rxrpc_free_preparse, + .instantiate = generic_key_instantiate, + .destroy = rxrpc_destroy, + .describe = rxrpc_describe, + .read = rxrpc_read, +}; +EXPORT_SYMBOL(key_type_rxrpc); + +/* + * parse an RxKAD type XDR format token + * - the caller guarantees we have at least 4 words + */ +static int rxrpc_preparse_xdr_rxkad(struct key_preparsed_payload *prep, + size_t datalen, + const __be32 *xdr, unsigned int toklen) +{ + struct rxrpc_key_token *token, **pptoken; + time64_t expiry; + size_t plen; + u32 tktlen; + + _enter(",{%x,%x,%x,%x},%u", + ntohl(xdr[0]), ntohl(xdr[1]), ntohl(xdr[2]), ntohl(xdr[3]), + toklen); + + if (toklen <= 8 * 4) + return -EKEYREJECTED; + tktlen = ntohl(xdr[7]); + _debug("tktlen: %x", tktlen); + if (tktlen > AFSTOKEN_RK_TIX_MAX) + return -EKEYREJECTED; + if (toklen < 8 * 4 + tktlen) + return -EKEYREJECTED; + + plen = sizeof(*token) + sizeof(*token->kad) + tktlen; + prep->quotalen = datalen + plen; + + plen -= sizeof(*token); + token = kzalloc(sizeof(*token), GFP_KERNEL); + if (!token) + return -ENOMEM; + + token->kad = kzalloc(plen, GFP_KERNEL); + if (!token->kad) { + kfree(token); + return -ENOMEM; + } + + token->security_index = RXRPC_SECURITY_RXKAD; + token->kad->ticket_len = tktlen; + token->kad->vice_id = ntohl(xdr[0]); + token->kad->kvno = ntohl(xdr[1]); + token->kad->start = ntohl(xdr[4]); + token->kad->expiry = ntohl(xdr[5]); + token->kad->primary_flag = ntohl(xdr[6]); + memcpy(&token->kad->session_key, &xdr[2], 8); + memcpy(&token->kad->ticket, &xdr[8], tktlen); + + _debug("SCIX: %u", token->security_index); + _debug("TLEN: %u", token->kad->ticket_len); + _debug("EXPY: %x", token->kad->expiry); + _debug("KVNO: %u", token->kad->kvno); + _debug("PRIM: %u", token->kad->primary_flag); + _debug("SKEY: %02x%02x%02x%02x%02x%02x%02x%02x", + token->kad->session_key[0], token->kad->session_key[1], + token->kad->session_key[2], token->kad->session_key[3], + token->kad->session_key[4], token->kad->session_key[5], + token->kad->session_key[6], token->kad->session_key[7]); + if (token->kad->ticket_len >= 8) + _debug("TCKT: %02x%02x%02x%02x%02x%02x%02x%02x", + token->kad->ticket[0], token->kad->ticket[1], + token->kad->ticket[2], token->kad->ticket[3], + token->kad->ticket[4], token->kad->ticket[5], + 
token->kad->ticket[6], token->kad->ticket[7]); + + /* count the number of tokens attached */ + prep->payload.data[1] = (void *)((unsigned long)prep->payload.data[1] + 1); + + /* attach the data */ + for (pptoken = (struct rxrpc_key_token **)&prep->payload.data[0]; + *pptoken; + pptoken = &(*pptoken)->next) + continue; + *pptoken = token; + expiry = rxrpc_u32_to_time64(token->kad->expiry); + if (expiry < prep->expiry) + prep->expiry = expiry; + + _leave(" = 0"); + return 0; +} + +/* + * attempt to parse the data as the XDR format + * - the caller guarantees we have more than 7 words + */ +static int rxrpc_preparse_xdr(struct key_preparsed_payload *prep) +{ + const __be32 *xdr = prep->data, *token, *p; + const char *cp; + unsigned int len, paddedlen, loop, ntoken, toklen, sec_ix; + size_t datalen = prep->datalen; + int ret, ret2; + + _enter(",{%x,%x,%x,%x},%zu", + ntohl(xdr[0]), ntohl(xdr[1]), ntohl(xdr[2]), ntohl(xdr[3]), + prep->datalen); + + if (datalen > AFSTOKEN_LENGTH_MAX) + goto not_xdr; + + /* XDR is an array of __be32's */ + if (datalen & 3) + goto not_xdr; + + /* the flags should be 0 (the setpag bit must be handled by + * userspace) */ + if (ntohl(*xdr++) != 0) + goto not_xdr; + datalen -= 4; + + /* check the cell name */ + len = ntohl(*xdr++); + if (len < 1 || len > AFSTOKEN_CELL_MAX) + goto not_xdr; + datalen -= 4; + paddedlen = (len + 3) & ~3; + if (paddedlen > datalen) + goto not_xdr; + + cp = (const char *) xdr; + for (loop = 0; loop < len; loop++) + if (!isprint(cp[loop])) + goto not_xdr; + for (; loop < paddedlen; loop++) + if (cp[loop]) + goto not_xdr; + _debug("cellname: [%u/%u] '%*.*s'", + len, paddedlen, len, len, (const char *) xdr); + datalen -= paddedlen; + xdr += paddedlen >> 2; + + /* get the token count */ + if (datalen < 12) + goto not_xdr; + ntoken = ntohl(*xdr++); + datalen -= 4; + _debug("ntoken: %x", ntoken); + if (ntoken < 1 || ntoken > AFSTOKEN_MAX) + goto not_xdr; + + /* check each token wrapper */ + p = xdr; + loop = ntoken; + do { + if (datalen < 8) + goto not_xdr; + toklen = ntohl(*p++); + sec_ix = ntohl(*p); + datalen -= 4; + _debug("token: [%x/%zx] %x", toklen, datalen, sec_ix); + paddedlen = (toklen + 3) & ~3; + if (toklen < 20 || toklen > datalen || paddedlen > datalen) + goto not_xdr; + datalen -= paddedlen; + p += paddedlen >> 2; + + } while (--loop > 0); + + _debug("remainder: %zu", datalen); + if (datalen != 0) + goto not_xdr; + + /* okay: we're going to assume it's valid XDR format + * - we ignore the cellname, relying on the key to be correctly named + */ + ret = -EPROTONOSUPPORT; + do { + toklen = ntohl(*xdr++); + token = xdr; + xdr += (toklen + 3) / 4; + + sec_ix = ntohl(*token++); + toklen -= 4; + + _debug("TOKEN type=%x len=%x", sec_ix, toklen); + + switch (sec_ix) { + case RXRPC_SECURITY_RXKAD: + ret2 = rxrpc_preparse_xdr_rxkad(prep, datalen, token, toklen); + break; + default: + ret2 = -EPROTONOSUPPORT; + break; + } + + switch (ret2) { + case 0: + ret = 0; + break; + case -EPROTONOSUPPORT: + break; + case -ENOPKG: + if (ret != 0) + ret = -ENOPKG; + break; + default: + ret = ret2; + goto error; + } + + } while (--ntoken > 0); + +error: + _leave(" = %d", ret); + return ret; + +not_xdr: + _leave(" = -EPROTO"); + return -EPROTO; +} + +/* + * Preparse an rxrpc defined key. 
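+ * + * A payload longer than 28 bytes is first tried as the XDR token format; if that doesn't match (or the payload is shorter) it is parsed as the version-1 binary format described below.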
+ * + * Data should be of the form: + * OFFSET LEN CONTENT + * 0 4 key interface version number + * 4 2 security index (type) + * 6 2 ticket length + * 8 4 key expiry time (time_t) + * 12 4 kvno + * 16 8 session key + * 24 [len] ticket + * + * if no data is provided, then a no-security key is made + */ +static int rxrpc_preparse(struct key_preparsed_payload *prep) +{ + const struct rxrpc_key_data_v1 *v1; + struct rxrpc_key_token *token, **pp; + time64_t expiry; + size_t plen; + u32 kver; + int ret; + + _enter("%zu", prep->datalen); + + /* handle a no-security key */ + if (!prep->data && prep->datalen == 0) + return 0; + + /* determine if the XDR payload format is being used */ + if (prep->datalen > 7 * 4) { + ret = rxrpc_preparse_xdr(prep); + if (ret != -EPROTO) + return ret; + } + + /* get the key interface version number */ + ret = -EINVAL; + if (prep->datalen <= 4 || !prep->data) + goto error; + memcpy(&kver, prep->data, sizeof(kver)); + prep->data += sizeof(kver); + prep->datalen -= sizeof(kver); + + _debug("KEY I/F VERSION: %u", kver); + + ret = -EKEYREJECTED; + if (kver != 1) + goto error; + + /* deal with a version 1 key */ + ret = -EINVAL; + if (prep->datalen < sizeof(*v1)) + goto error; + + v1 = prep->data; + if (prep->datalen != sizeof(*v1) + v1->ticket_length) + goto error; + + _debug("SCIX: %u", v1->security_index); + _debug("TLEN: %u", v1->ticket_length); + _debug("EXPY: %x", v1->expiry); + _debug("KVNO: %u", v1->kvno); + _debug("SKEY: %02x%02x%02x%02x%02x%02x%02x%02x", + v1->session_key[0], v1->session_key[1], + v1->session_key[2], v1->session_key[3], + v1->session_key[4], v1->session_key[5], + v1->session_key[6], v1->session_key[7]); + if (v1->ticket_length >= 8) + _debug("TCKT: %02x%02x%02x%02x%02x%02x%02x%02x", + v1->ticket[0], v1->ticket[1], + v1->ticket[2], v1->ticket[3], + v1->ticket[4], v1->ticket[5], + v1->ticket[6], v1->ticket[7]); + + ret = -EPROTONOSUPPORT; + if (v1->security_index != RXRPC_SECURITY_RXKAD) + goto error; + + plen = sizeof(*token->kad) + v1->ticket_length; + prep->quotalen = plen + sizeof(*token); + + ret = -ENOMEM; + token = kzalloc(sizeof(*token), GFP_KERNEL); + if (!token) + goto error; + token->kad = kzalloc(plen, GFP_KERNEL); + if (!token->kad) + goto error_free; + + token->security_index = RXRPC_SECURITY_RXKAD; + token->kad->ticket_len = v1->ticket_length; + token->kad->expiry = v1->expiry; + token->kad->kvno = v1->kvno; + memcpy(&token->kad->session_key, &v1->session_key, 8); + memcpy(&token->kad->ticket, v1->ticket, v1->ticket_length); + + /* count the number of tokens attached */ + prep->payload.data[1] = (void *)((unsigned long)prep->payload.data[1] + 1); + + /* attach the data */ + pp = (struct rxrpc_key_token **)&prep->payload.data[0]; + while (*pp) + pp = &(*pp)->next; + *pp = token; + expiry = rxrpc_u32_to_time64(token->kad->expiry); + if (expiry < prep->expiry) + prep->expiry = expiry; + token = NULL; + ret = 0; + +error_free: + kfree(token); +error: + return ret; +} + +/* + * Free token list. + */ +static void rxrpc_free_token_list(struct rxrpc_key_token *token) +{ + struct rxrpc_key_token *next; + + for (; token; token = next) { + next = token->next; + switch (token->security_index) { + case RXRPC_SECURITY_RXKAD: + kfree(token->kad); + break; + default: + pr_err("Unknown token type %x on rxrpc key\n", + token->security_index); + BUG(); + } + + kfree(token); + } +} + +/* + * Clean up preparse data. 
+ */ +static void rxrpc_free_preparse(struct key_preparsed_payload *prep) +{ + rxrpc_free_token_list(prep->payload.data[0]); +} + +/* + * dispose of the data dangling from the corpse of a rxrpc key + */ +static void rxrpc_destroy(struct key *key) +{ + rxrpc_free_token_list(key->payload.data[0]); +} + +/* + * describe the rxrpc key + */ +static void rxrpc_describe(const struct key *key, struct seq_file *m) +{ + const struct rxrpc_key_token *token; + const char *sep = ": "; + + seq_puts(m, key->description); + + for (token = key->payload.data[0]; token; token = token->next) { + seq_puts(m, sep); + + switch (token->security_index) { + case RXRPC_SECURITY_RXKAD: + seq_puts(m, "ka"); + break; + default: /* we have a ticket we can't encode */ + seq_printf(m, "%u", token->security_index); + break; + } + + sep = " "; + } +} + +/* + * grab the security key for a socket + */ +int rxrpc_request_key(struct rxrpc_sock *rx, sockptr_t optval, int optlen) +{ + struct key *key; + char *description; + + _enter(""); + + if (optlen <= 0 || optlen > PAGE_SIZE - 1 || rx->securities) + return -EINVAL; + + description = memdup_sockptr_nul(optval, optlen); + if (IS_ERR(description)) + return PTR_ERR(description); + + key = request_key_net(&key_type_rxrpc, description, sock_net(&rx->sk), NULL); + if (IS_ERR(key)) { + kfree(description); + _leave(" = %ld", PTR_ERR(key)); + return PTR_ERR(key); + } + + rx->key = key; + kfree(description); + _leave(" = 0 [key %x]", key->serial); + return 0; +} + +/* + * generate a server data key + */ +int rxrpc_get_server_data_key(struct rxrpc_connection *conn, + const void *session_key, + time64_t expiry, + u32 kvno) +{ + const struct cred *cred = current_cred(); + struct key *key; + int ret; + + struct { + u32 kver; + struct rxrpc_key_data_v1 v1; + } data; + + _enter(""); + + key = key_alloc(&key_type_rxrpc, "x", + GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, cred, 0, + KEY_ALLOC_NOT_IN_QUOTA, NULL); + if (IS_ERR(key)) { + _leave(" = -ENOMEM [alloc %ld]", PTR_ERR(key)); + return -ENOMEM; + } + + _debug("key %d", key_serial(key)); + + data.kver = 1; + data.v1.security_index = RXRPC_SECURITY_RXKAD; + data.v1.ticket_length = 0; + data.v1.expiry = rxrpc_time64_to_u32(expiry); + data.v1.kvno = 0; + + memcpy(&data.v1.session_key, session_key, sizeof(data.v1.session_key)); + + ret = key_instantiate_and_link(key, &data, sizeof(data), NULL, NULL); + if (ret < 0) + goto error; + + conn->params.key = key; + _leave(" = 0 [%d]", key_serial(key)); + return 0; + +error: + key_revoke(key); + key_put(key); + _leave(" = -ENOMEM [ins %d]", ret); + return -ENOMEM; +} +EXPORT_SYMBOL(rxrpc_get_server_data_key); + +/** + * rxrpc_get_null_key - Generate a null RxRPC key + * @keyname: The name to give the key. + * + * Generate a null RxRPC key that can be used to indicate anonymous security is + * required for a particular domain. 
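+ * + * Return: the new key on success or an ERR_PTR() if allocation or instantiation fails.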
+ */ +struct key *rxrpc_get_null_key(const char *keyname) +{ + const struct cred *cred = current_cred(); + struct key *key; + int ret; + + key = key_alloc(&key_type_rxrpc, keyname, + GLOBAL_ROOT_UID, GLOBAL_ROOT_GID, cred, + KEY_POS_SEARCH, KEY_ALLOC_NOT_IN_QUOTA, NULL); + if (IS_ERR(key)) + return key; + + ret = key_instantiate_and_link(key, NULL, 0, NULL, NULL); + if (ret < 0) { + key_revoke(key); + key_put(key); + return ERR_PTR(ret); + } + + return key; +} +EXPORT_SYMBOL(rxrpc_get_null_key); + +/* + * read the contents of an rxrpc key + * - this returns the result in XDR form + */ +static long rxrpc_read(const struct key *key, + char *buffer, size_t buflen) +{ + const struct rxrpc_key_token *token; + size_t size; + __be32 *xdr, *oldxdr; + u32 cnlen, toksize, ntoks, tok, zero; + u16 toksizes[AFSTOKEN_MAX]; + + _enter(""); + + /* we don't know what form we should return non-AFS keys in */ + if (memcmp(key->description, "afs@", 4) != 0) + return -EOPNOTSUPP; + cnlen = strlen(key->description + 4); + +#define RND(X) (((X) + 3) & ~3) + + /* AFS keys we return in XDR form, so we need to work out the size of + * the XDR */ + size = 2 * 4; /* flags, cellname len */ + size += RND(cnlen); /* cellname */ + size += 1 * 4; /* token count */ + + ntoks = 0; + for (token = key->payload.data[0]; token; token = token->next) { + toksize = 4; /* sec index */ + + switch (token->security_index) { + case RXRPC_SECURITY_RXKAD: + toksize += 8 * 4; /* viceid, kvno, key*2, begin, + * end, primary, tktlen */ + if (!token->no_leak_key) + toksize += RND(token->kad->ticket_len); + break; + + default: /* we have a ticket we can't encode */ + pr_err("Unsupported key token type (%u)\n", + token->security_index); + return -ENOPKG; + } + + _debug("token[%u]: toksize=%u", ntoks, toksize); + ASSERTCMP(toksize, <=, AFSTOKEN_LENGTH_MAX); + + toksizes[ntoks++] = toksize; + size += toksize + 4; /* each token has a length word */ + } + +#undef RND + + if (!buffer || buflen < size) + return size; + + xdr = (__be32 *)buffer; + zero = 0; +#define ENCODE(x) \ + do { \ + *xdr++ = htonl(x); \ + } while(0) +#define ENCODE_DATA(l, s) \ + do { \ + u32 _l = (l); \ + ENCODE(l); \ + memcpy(xdr, (s), _l); \ + if (_l & 3) \ + memcpy((u8 *)xdr + _l, &zero, 4 - (_l & 3)); \ + xdr += (_l + 3) >> 2; \ + } while(0) +#define ENCODE_BYTES(l, s) \ + do { \ + u32 _l = (l); \ + memcpy(xdr, (s), _l); \ + if (_l & 3) \ + memcpy((u8 *)xdr + _l, &zero, 4 - (_l & 3)); \ + xdr += (_l + 3) >> 2; \ + } while(0) +#define ENCODE64(x) \ + do { \ + __be64 y = cpu_to_be64(x); \ + memcpy(xdr, &y, 8); \ + xdr += 8 >> 2; \ + } while(0) +#define ENCODE_STR(s) \ + do { \ + const char *_s = (s); \ + ENCODE_DATA(strlen(_s), _s); \ + } while(0) + + ENCODE(0); /* flags */ + ENCODE_DATA(cnlen, key->description + 4); /* cellname */ + ENCODE(ntoks); + + tok = 0; + for (token = key->payload.data[0]; token; token = token->next) { + toksize = toksizes[tok++]; + ENCODE(toksize); + oldxdr = xdr; + ENCODE(token->security_index); + + switch (token->security_index) { + case RXRPC_SECURITY_RXKAD: + ENCODE(token->kad->vice_id); + ENCODE(token->kad->kvno); + ENCODE_BYTES(8, token->kad->session_key); + ENCODE(token->kad->start); + ENCODE(token->kad->expiry); + ENCODE(token->kad->primary_flag); + if (token->no_leak_key) + ENCODE(0); + else + ENCODE_DATA(token->kad->ticket_len, token->kad->ticket); + break; + + default: + pr_err("Unsupported key token type (%u)\n", + token->security_index); + return -ENOPKG; + } + + ASSERTCMP((unsigned long)xdr - (unsigned long)oldxdr, ==, + toksize); + } 
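+ + /* All tokens encoded; the assertions below check the token count and that the cursor landed exactly on the precomputed size. */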
+ +#undef ENCODE_STR +#undef ENCODE_DATA +#undef ENCODE64 +#undef ENCODE + + ASSERTCMP(tok, ==, ntoks); + ASSERTCMP((char __user *) xdr - buffer, ==, size); + _leave(" = %zu", size); + return size; +} diff --git a/net/rxrpc/local_event.c b/net/rxrpc/local_event.c new file mode 100644 index 000000000..19e929c7c --- /dev/null +++ b/net/rxrpc/local_event.c @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* AF_RXRPC local endpoint management + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static const char rxrpc_version_string[65] = "linux-" UTS_RELEASE " AF_RXRPC"; + +/* + * Reply to a version request + */ +static void rxrpc_send_version_request(struct rxrpc_local *local, + struct rxrpc_host_header *hdr, + struct sk_buff *skb) +{ + struct rxrpc_wire_header whdr; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct sockaddr_rxrpc srx; + struct msghdr msg; + struct kvec iov[2]; + size_t len; + int ret; + + _enter(""); + + if (rxrpc_extract_addr_from_skb(&srx, skb) < 0) + return; + + msg.msg_name = &srx.transport; + msg.msg_namelen = srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + whdr.epoch = htonl(sp->hdr.epoch); + whdr.cid = htonl(sp->hdr.cid); + whdr.callNumber = htonl(sp->hdr.callNumber); + whdr.seq = 0; + whdr.serial = 0; + whdr.type = RXRPC_PACKET_TYPE_VERSION; + whdr.flags = RXRPC_LAST_PACKET | (~hdr->flags & RXRPC_CLIENT_INITIATED); + whdr.userStatus = 0; + whdr.securityIndex = 0; + whdr._rsvd = 0; + whdr.serviceId = htons(sp->hdr.serviceId); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = (char *)rxrpc_version_string; + iov[1].iov_len = sizeof(rxrpc_version_string); + + len = iov[0].iov_len + iov[1].iov_len; + + _proto("Tx VERSION (reply)"); + + ret = kernel_sendmsg(local->socket, &msg, iov, 2, len); + if (ret < 0) + trace_rxrpc_tx_fail(local->debug_id, 0, ret, + rxrpc_tx_point_version_reply); + else + trace_rxrpc_tx_packet(local->debug_id, &whdr, + rxrpc_tx_point_version_reply); + + _leave(""); +} + +/* + * Process event packets targeted at a local endpoint. + */ +void rxrpc_process_local_events(struct rxrpc_local *local) +{ + struct sk_buff *skb; + char v; + + _enter(""); + + skb = skb_dequeue(&local->event_queue); + if (skb) { + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + rxrpc_see_skb(skb, rxrpc_skb_seen); + _debug("{%d},{%u}", local->debug_id, sp->hdr.type); + + switch (sp->hdr.type) { + case RXRPC_PACKET_TYPE_VERSION: + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + &v, 1) < 0) + return; + _proto("Rx VERSION { %02x }", v); + if (v == 0) + rxrpc_send_version_request(local, &sp->hdr, skb); + break; + + default: + /* Just ignore anything we don't understand */ + break; + } + + rxrpc_free_skb(skb, rxrpc_skb_freed); + } + + _leave(""); +} diff --git a/net/rxrpc/local_object.c b/net/rxrpc/local_object.c new file mode 100644 index 000000000..38ea98ff4 --- /dev/null +++ b/net/rxrpc/local_object.c @@ -0,0 +1,474 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Local endpoint object management + * + * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static void rxrpc_local_processor(struct work_struct *); +static void rxrpc_local_rcu(struct rcu_head *); + +/* + * Compare a local to an address. Return -ve, 0 or +ve to indicate less than, + * same or greater than. + * + * We explicitly don't compare the RxRPC service ID as we want to reject + * conflicting uses by differing services. Further, we don't want to share + * addresses with different options (IPv6), so we don't compare those bits + * either. + */ +static long rxrpc_local_cmp_key(const struct rxrpc_local *local, + const struct sockaddr_rxrpc *srx) +{ + long diff; + + diff = ((local->srx.transport_type - srx->transport_type) ?: + (local->srx.transport_len - srx->transport_len) ?: + (local->srx.transport.family - srx->transport.family)); + if (diff != 0) + return diff; + + switch (srx->transport.family) { + case AF_INET: + /* If the choice of UDP port is left up to the transport, then + * the endpoint record doesn't match. + */ + return ((u16 __force)local->srx.transport.sin.sin_port - + (u16 __force)srx->transport.sin.sin_port) ?: + memcmp(&local->srx.transport.sin.sin_addr, + &srx->transport.sin.sin_addr, + sizeof(struct in_addr)); +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + /* If the choice of UDP6 port is left up to the transport, then + * the endpoint record doesn't match. + */ + return ((u16 __force)local->srx.transport.sin6.sin6_port - + (u16 __force)srx->transport.sin6.sin6_port) ?: + memcmp(&local->srx.transport.sin6.sin6_addr, + &srx->transport.sin6.sin6_addr, + sizeof(struct in6_addr)); +#endif + default: + BUG(); + } +} + +/* + * Allocate a new local endpoint. 
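+ * + * The new endpoint starts with one reference and one active user, both owned by the caller, and the service ID in the copied address is cleared.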
+ */ +static struct rxrpc_local *rxrpc_alloc_local(struct rxrpc_net *rxnet, + const struct sockaddr_rxrpc *srx) +{ + struct rxrpc_local *local; + + local = kzalloc(sizeof(struct rxrpc_local), GFP_KERNEL); + if (local) { + refcount_set(&local->ref, 1); + atomic_set(&local->active_users, 1); + local->rxnet = rxnet; + INIT_HLIST_NODE(&local->link); + INIT_WORK(&local->processor, rxrpc_local_processor); + init_rwsem(&local->defrag_sem); + skb_queue_head_init(&local->reject_queue); + skb_queue_head_init(&local->event_queue); + local->client_bundles = RB_ROOT; + spin_lock_init(&local->client_bundles_lock); + spin_lock_init(&local->lock); + rwlock_init(&local->services_lock); + local->debug_id = atomic_inc_return(&rxrpc_debug_id); + memcpy(&local->srx, srx, sizeof(*srx)); + local->srx.srx_service = 0; + trace_rxrpc_local(local->debug_id, rxrpc_local_new, 1, NULL); + } + + _leave(" = %p", local); + return local; +} + +/* + * create the local socket + * - must be called with rxrpc_local_mutex locked + */ +static int rxrpc_open_socket(struct rxrpc_local *local, struct net *net) +{ + struct udp_tunnel_sock_cfg tuncfg = {NULL}; + struct sockaddr_rxrpc *srx = &local->srx; + struct udp_port_cfg udp_conf = {0}; + struct sock *usk; + int ret; + + _enter("%p{%d,%d}", + local, srx->transport_type, srx->transport.family); + + udp_conf.family = srx->transport.family; + udp_conf.use_udp_checksums = true; + if (udp_conf.family == AF_INET) { + udp_conf.local_ip = srx->transport.sin.sin_addr; + udp_conf.local_udp_port = srx->transport.sin.sin_port; +#if IS_ENABLED(CONFIG_AF_RXRPC_IPV6) + } else { + udp_conf.local_ip6 = srx->transport.sin6.sin6_addr; + udp_conf.local_udp_port = srx->transport.sin6.sin6_port; + udp_conf.use_udp6_tx_checksums = true; + udp_conf.use_udp6_rx_checksums = true; +#endif + } + ret = udp_sock_create(net, &udp_conf, &local->socket); + if (ret < 0) { + _leave(" = %d [socket]", ret); + return ret; + } + + tuncfg.encap_type = UDP_ENCAP_RXRPC; + tuncfg.encap_rcv = rxrpc_input_packet; + tuncfg.encap_err_rcv = rxrpc_encap_err_rcv; + tuncfg.sk_user_data = local; + setup_udp_tunnel_sock(net, local->socket, &tuncfg); + + /* set the socket up */ + usk = local->socket->sk; + usk->sk_error_report = rxrpc_error_report; + + switch (srx->transport.family) { + case AF_INET6: + /* we want to receive ICMPv6 errors */ + ip6_sock_set_recverr(usk); + + /* Fall through and set IPv4 options too otherwise we don't get + * errors from IPv4 packets sent through the IPv6 socket. + */ + fallthrough; + case AF_INET: + /* we want to receive ICMP errors */ + ip_sock_set_recverr(usk); + + /* we want to set the don't fragment bit */ + ip_sock_set_mtu_discover(usk, IP_PMTUDISC_DO); + + /* We want receive timestamps. */ + sock_enable_timestamps(usk); + break; + + default: + BUG(); + } + + _leave(" = 0"); + return 0; +} + +/* + * Look up or create a new local endpoint using the specified local address. 
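+ * + * Returns the endpoint on success or an ERR_PTR(); -EADDRINUSE is returned if a service tries to bind an address for which an endpoint already exists.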
+ */ +struct rxrpc_local *rxrpc_lookup_local(struct net *net, + const struct sockaddr_rxrpc *srx) +{ + struct rxrpc_local *local; + struct rxrpc_net *rxnet = rxrpc_net(net); + struct hlist_node *cursor; + const char *age; + long diff; + int ret; + + _enter("{%d,%d,%pISp}", + srx->transport_type, srx->transport.family, &srx->transport); + + mutex_lock(&rxnet->local_mutex); + + hlist_for_each(cursor, &rxnet->local_endpoints) { + local = hlist_entry(cursor, struct rxrpc_local, link); + + diff = rxrpc_local_cmp_key(local, srx); + if (diff != 0) + continue; + + /* Services aren't allowed to share transport sockets, so + * reject that here. It is possible that the object is dying - + * but it may also still have the local transport address that + * we want bound. + */ + if (srx->srx_service) { + local = NULL; + goto addr_in_use; + } + + /* Found a match. We want to replace a dying object. + * Attempting to bind the transport socket may still fail if + * we're attempting to use a local address that the dying + * object is still using. + */ + if (!rxrpc_use_local(local)) + break; + + age = "old"; + goto found; + } + + local = rxrpc_alloc_local(rxnet, srx); + if (!local) + goto nomem; + + ret = rxrpc_open_socket(local, net); + if (ret < 0) + goto sock_error; + + if (cursor) { + hlist_replace_rcu(cursor, &local->link); + cursor->pprev = NULL; + } else { + hlist_add_head_rcu(&local->link, &rxnet->local_endpoints); + } + age = "new"; + +found: + mutex_unlock(&rxnet->local_mutex); + + _net("LOCAL %s %d {%pISp}", + age, local->debug_id, &local->srx.transport); + + _leave(" = %p", local); + return local; + +nomem: + ret = -ENOMEM; +sock_error: + mutex_unlock(&rxnet->local_mutex); + if (local) + call_rcu(&local->rcu, rxrpc_local_rcu); + _leave(" = %d", ret); + return ERR_PTR(ret); + +addr_in_use: + mutex_unlock(&rxnet->local_mutex); + _leave(" = -EADDRINUSE"); + return ERR_PTR(-EADDRINUSE); +} + +/* + * Get a ref on a local endpoint. + */ +struct rxrpc_local *rxrpc_get_local(struct rxrpc_local *local) +{ + const void *here = __builtin_return_address(0); + int r; + + __refcount_inc(&local->ref, &r); + trace_rxrpc_local(local->debug_id, rxrpc_local_got, r + 1, here); + return local; +} + +/* + * Get a ref on a local endpoint unless its usage has already reached 0. + */ +struct rxrpc_local *rxrpc_get_local_maybe(struct rxrpc_local *local) +{ + const void *here = __builtin_return_address(0); + int r; + + if (local) { + if (__refcount_inc_not_zero(&local->ref, &r)) + trace_rxrpc_local(local->debug_id, rxrpc_local_got, + r + 1, here); + else + local = NULL; + } + return local; +} + +/* + * Queue a local endpoint and pass the caller's reference to the work item. + */ +void rxrpc_queue_local(struct rxrpc_local *local) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id = local->debug_id; + int r = refcount_read(&local->ref); + + if (rxrpc_queue_work(&local->processor)) + trace_rxrpc_local(debug_id, rxrpc_local_queued, r + 1, here); + else + rxrpc_put_local(local); +} + +/* + * Drop a ref on a local endpoint. + */ +void rxrpc_put_local(struct rxrpc_local *local) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id; + bool dead; + int r; + + if (local) { + debug_id = local->debug_id; + + dead = __refcount_dec_and_test(&local->ref, &r); + trace_rxrpc_local(debug_id, rxrpc_local_put, r, here); + + if (dead) + call_rcu(&local->rcu, rxrpc_local_rcu); + } +} + +/* + * Start using a local endpoint. 
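+ * + * Returns NULL if the endpoint's refcount has already reached zero or it is no longer accepting new users.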
+ */ +struct rxrpc_local *rxrpc_use_local(struct rxrpc_local *local) +{ + local = rxrpc_get_local_maybe(local); + if (!local) + return NULL; + + if (!__rxrpc_use_local(local)) { + rxrpc_put_local(local); + return NULL; + } + + return local; +} + +/* + * Cease using a local endpoint. Once the number of active users reaches 0, we + * start the closure of the transport in the work processor. + */ +void rxrpc_unuse_local(struct rxrpc_local *local) +{ + if (local) { + if (__rxrpc_unuse_local(local)) { + rxrpc_get_local(local); + rxrpc_queue_local(local); + } + } +} + +/* + * Destroy a local endpoint's socket and then hand the record to RCU to dispose + * of. + * + * Closing the socket cannot be done from bottom half context or RCU callback + * context because it might sleep. + */ +static void rxrpc_local_destroyer(struct rxrpc_local *local) +{ + struct socket *socket = local->socket; + struct rxrpc_net *rxnet = local->rxnet; + + _enter("%d", local->debug_id); + + local->dead = true; + + mutex_lock(&rxnet->local_mutex); + hlist_del_init_rcu(&local->link); + mutex_unlock(&rxnet->local_mutex); + + rxrpc_clean_up_local_conns(local); + rxrpc_service_connection_reaper(&rxnet->service_conn_reaper); + ASSERT(!local->service); + + if (socket) { + local->socket = NULL; + kernel_sock_shutdown(socket, SHUT_RDWR); + socket->sk->sk_user_data = NULL; + sock_release(socket); + } + + /* At this point, there should be no more packets coming in to the + * local endpoint. + */ + rxrpc_purge_queue(&local->reject_queue); + rxrpc_purge_queue(&local->event_queue); +} + +/* + * Process events on an endpoint. The work item carries a ref which + * we must release. + */ +static void rxrpc_local_processor(struct work_struct *work) +{ + struct rxrpc_local *local = + container_of(work, struct rxrpc_local, processor); + bool again; + + if (local->dead) + return; + + trace_rxrpc_local(local->debug_id, rxrpc_local_processing, + refcount_read(&local->ref), NULL); + + do { + again = false; + if (!__rxrpc_use_local(local)) { + rxrpc_local_destroyer(local); + break; + } + + if (!skb_queue_empty(&local->reject_queue)) { + rxrpc_reject_packets(local); + again = true; + } + + if (!skb_queue_empty(&local->event_queue)) { + rxrpc_process_local_events(local); + again = true; + } + + __rxrpc_unuse_local(local); + } while (again); + + rxrpc_put_local(local); +} + +/* + * Destroy a local endpoint after the RCU grace period expires. + */ +static void rxrpc_local_rcu(struct rcu_head *rcu) +{ + struct rxrpc_local *local = container_of(rcu, struct rxrpc_local, rcu); + + _enter("%d", local->debug_id); + + ASSERT(!work_pending(&local->processor)); + + _net("DESTROY LOCAL %d", local->debug_id); + kfree(local); + _leave(""); +} + +/* + * Verify the local endpoint list is empty by this point. + */ +void rxrpc_destroy_all_locals(struct rxrpc_net *rxnet) +{ + struct rxrpc_local *local; + + _enter(""); + + flush_workqueue(rxrpc_workqueue); + + if (!hlist_empty(&rxnet->local_endpoints)) { + mutex_lock(&rxnet->local_mutex); + hlist_for_each_entry(local, &rxnet->local_endpoints, link) { + pr_err("AF_RXRPC: Leaked local %p {%d}\n", + local, refcount_read(&local->ref)); + } + mutex_unlock(&rxnet->local_mutex); + BUG(); + } +} diff --git a/net/rxrpc/misc.c b/net/rxrpc/misc.c new file mode 100644 index 000000000..d4144fd86 --- /dev/null +++ b/net/rxrpc/misc.c @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Miscellaneous bits + * + * Copyright (C) 2016 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include "ar-internal.h" + +/* + * The maximum listening backlog queue size that may be set on a socket by + * listen(). + */ +unsigned int rxrpc_max_backlog __read_mostly = 10; + +/* + * How long to wait before scheduling ACK generation after seeing a + * packet with RXRPC_REQUEST_ACK set (in jiffies). + */ +unsigned long rxrpc_requested_ack_delay = 1; + +/* + * How long to wait before scheduling an ACK with subtype DELAY (in jiffies). + * + * We use this when we've received new data packets. If those packets aren't + * all consumed within this time we will send a DELAY ACK if an ACK was not + * requested to let the sender know it doesn't need to resend. + */ +unsigned long rxrpc_soft_ack_delay = HZ; + +/* + * How long to wait before scheduling an ACK with subtype IDLE (in jiffies). + * + * We use this when we've consumed some previously soft-ACK'd packets when + * further packets aren't immediately received to decide when to send an IDLE + * ACK let the other end know that it can free up its Tx buffer space. + */ +unsigned long rxrpc_idle_ack_delay = HZ / 2; + +/* + * Receive window size in packets. This indicates the maximum number of + * unconsumed received packets we're willing to retain in memory. Once this + * limit is hit, we should generate an EXCEEDS_WINDOW ACK and discard further + * packets. + */ +unsigned int rxrpc_rx_window_size = RXRPC_INIT_RX_WINDOW_SIZE; +#if (RXRPC_RXTX_BUFF_SIZE - 1) < RXRPC_INIT_RX_WINDOW_SIZE +#error Need to reduce RXRPC_INIT_RX_WINDOW_SIZE +#endif + +/* + * Maximum Rx MTU size. This indicates to the sender the size of jumbo packet + * made by gluing normal packets together that we're willing to handle. + */ +unsigned int rxrpc_rx_mtu = 5692; + +/* + * The maximum number of fragments in a received jumbo packet that we tell the + * sender that we're willing to handle. + */ +unsigned int rxrpc_rx_jumbo_max = 4; + +const s8 rxrpc_ack_priority[] = { + [0] = 0, + [RXRPC_ACK_DELAY] = 1, + [RXRPC_ACK_REQUESTED] = 2, + [RXRPC_ACK_IDLE] = 3, + [RXRPC_ACK_DUPLICATE] = 4, + [RXRPC_ACK_OUT_OF_SEQUENCE] = 5, + [RXRPC_ACK_EXCEEDS_WINDOW] = 6, + [RXRPC_ACK_NOSPACE] = 7, + [RXRPC_ACK_PING_RESPONSE] = 8, +}; diff --git a/net/rxrpc/net_ns.c b/net/rxrpc/net_ns.c new file mode 100644 index 000000000..bb4c25d6d --- /dev/null +++ b/net/rxrpc/net_ns.c @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* rxrpc network namespace handling. + * + * Copyright (C) 2017 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include "ar-internal.h" + +unsigned int rxrpc_net_id; + +static void rxrpc_client_conn_reap_timeout(struct timer_list *timer) +{ + struct rxrpc_net *rxnet = + container_of(timer, struct rxrpc_net, client_conn_reap_timer); + + if (rxnet->live) + rxrpc_queue_work(&rxnet->client_conn_reaper); +} + +static void rxrpc_service_conn_reap_timeout(struct timer_list *timer) +{ + struct rxrpc_net *rxnet = + container_of(timer, struct rxrpc_net, service_conn_reap_timer); + + if (rxnet->live) + rxrpc_queue_work(&rxnet->service_conn_reaper); +} + +static void rxrpc_peer_keepalive_timeout(struct timer_list *timer) +{ + struct rxrpc_net *rxnet = + container_of(timer, struct rxrpc_net, peer_keepalive_timer); + + if (rxnet->live) + rxrpc_queue_work(&rxnet->peer_keepalive_work); +} + +/* + * Initialise a per-network namespace record. 
+ */ +static __net_init int rxrpc_init_net(struct net *net) +{ + struct rxrpc_net *rxnet = rxrpc_net(net); + int ret, i; + + rxnet->live = true; + get_random_bytes(&rxnet->epoch, sizeof(rxnet->epoch)); + rxnet->epoch |= RXRPC_RANDOM_EPOCH; + + INIT_LIST_HEAD(&rxnet->calls); + spin_lock_init(&rxnet->call_lock); + atomic_set(&rxnet->nr_calls, 1); + + atomic_set(&rxnet->nr_conns, 1); + INIT_LIST_HEAD(&rxnet->conn_proc_list); + INIT_LIST_HEAD(&rxnet->service_conns); + rwlock_init(&rxnet->conn_lock); + INIT_WORK(&rxnet->service_conn_reaper, + rxrpc_service_connection_reaper); + timer_setup(&rxnet->service_conn_reap_timer, + rxrpc_service_conn_reap_timeout, 0); + + atomic_set(&rxnet->nr_client_conns, 0); + rxnet->kill_all_client_conns = false; + spin_lock_init(&rxnet->client_conn_cache_lock); + spin_lock_init(&rxnet->client_conn_discard_lock); + INIT_LIST_HEAD(&rxnet->idle_client_conns); + INIT_WORK(&rxnet->client_conn_reaper, + rxrpc_discard_expired_client_conns); + timer_setup(&rxnet->client_conn_reap_timer, + rxrpc_client_conn_reap_timeout, 0); + + INIT_HLIST_HEAD(&rxnet->local_endpoints); + mutex_init(&rxnet->local_mutex); + + hash_init(rxnet->peer_hash); + spin_lock_init(&rxnet->peer_hash_lock); + for (i = 0; i < ARRAY_SIZE(rxnet->peer_keepalive); i++) + INIT_LIST_HEAD(&rxnet->peer_keepalive[i]); + INIT_LIST_HEAD(&rxnet->peer_keepalive_new); + timer_setup(&rxnet->peer_keepalive_timer, + rxrpc_peer_keepalive_timeout, 0); + INIT_WORK(&rxnet->peer_keepalive_work, rxrpc_peer_keepalive_worker); + rxnet->peer_keepalive_base = ktime_get_seconds(); + + ret = -ENOMEM; + rxnet->proc_net = proc_net_mkdir(net, "rxrpc", net->proc_net); + if (!rxnet->proc_net) + goto err_proc; + + proc_create_net("calls", 0444, rxnet->proc_net, &rxrpc_call_seq_ops, + sizeof(struct seq_net_private)); + proc_create_net("conns", 0444, rxnet->proc_net, + &rxrpc_connection_seq_ops, + sizeof(struct seq_net_private)); + proc_create_net("peers", 0444, rxnet->proc_net, + &rxrpc_peer_seq_ops, + sizeof(struct seq_net_private)); + proc_create_net("locals", 0444, rxnet->proc_net, + &rxrpc_local_seq_ops, + sizeof(struct seq_net_private)); + return 0; + +err_proc: + rxnet->live = false; + return ret; +} + +/* + * Clean up a per-network namespace record. + */ +static __net_exit void rxrpc_exit_net(struct net *net) +{ + struct rxrpc_net *rxnet = rxrpc_net(net); + + rxnet->live = false; + del_timer_sync(&rxnet->peer_keepalive_timer); + cancel_work_sync(&rxnet->peer_keepalive_work); + /* Remove the timer again as the worker may have restarted it. */ + del_timer_sync(&rxnet->peer_keepalive_timer); + rxrpc_destroy_all_calls(rxnet); + rxrpc_destroy_all_connections(rxnet); + rxrpc_destroy_all_peers(rxnet); + rxrpc_destroy_all_locals(rxnet); + proc_remove(rxnet->proc_net); +} + +struct pernet_operations rxrpc_net_ops = { + .init = rxrpc_init_net, + .exit = rxrpc_exit_net, + .id = &rxrpc_net_id, + .size = sizeof(struct rxrpc_net), +}; diff --git a/net/rxrpc/output.c b/net/rxrpc/output.c new file mode 100644 index 000000000..08c117bc0 --- /dev/null +++ b/net/rxrpc/output.c @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC packet transmission + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +struct rxrpc_ack_buffer { + struct rxrpc_wire_header whdr; + struct rxrpc_ackpacket ack; + u8 acks[255]; + u8 pad[3]; + struct rxrpc_ackinfo ackinfo; +}; + +struct rxrpc_abort_buffer { + struct rxrpc_wire_header whdr; + __be32 abort_code; +}; + +static const char rxrpc_keepalive_string[] = ""; + +/* + * Increase Tx backoff on transmission failure and clear it on success. + */ +static void rxrpc_tx_backoff(struct rxrpc_call *call, int ret) +{ + if (ret < 0) { + u16 tx_backoff = READ_ONCE(call->tx_backoff); + + if (tx_backoff < HZ) + WRITE_ONCE(call->tx_backoff, tx_backoff + 1); + } else { + WRITE_ONCE(call->tx_backoff, 0); + } +} + +/* + * Arrange for a keepalive ping a certain time after we last transmitted. This + * lets the far side know we're still interested in this call and helps keep + * the route through any intervening firewall open. + * + * Receiving a response to the ping will prevent the ->expect_rx_by timer from + * expiring. + */ +static void rxrpc_set_keepalive(struct rxrpc_call *call) +{ + unsigned long now = jiffies, keepalive_at = call->next_rx_timo / 6; + + keepalive_at += now; + WRITE_ONCE(call->keepalive_at, keepalive_at); + rxrpc_reduce_call_timer(call, keepalive_at, now, + rxrpc_timer_set_for_keepalive); +} + +/* + * Fill out an ACK packet. + */ +static size_t rxrpc_fill_out_ack(struct rxrpc_connection *conn, + struct rxrpc_call *call, + struct rxrpc_ack_buffer *pkt, + rxrpc_seq_t *_hard_ack, + rxrpc_seq_t *_top, + u8 reason) +{ + rxrpc_serial_t serial; + unsigned int tmp; + rxrpc_seq_t hard_ack, top, seq; + int ix; + u32 mtu, jmax; + u8 *ackp = pkt->acks; + + tmp = atomic_xchg(&call->ackr_nr_unacked, 0); + tmp |= atomic_xchg(&call->ackr_nr_consumed, 0); + if (!tmp && (reason == RXRPC_ACK_DELAY || + reason == RXRPC_ACK_IDLE)) + return 0; + + /* Barrier against rxrpc_input_data(). */ + serial = call->ackr_serial; + hard_ack = READ_ONCE(call->rx_hard_ack); + top = smp_load_acquire(&call->rx_top); + *_hard_ack = hard_ack; + *_top = top; + + pkt->ack.bufferSpace = htons(0); + pkt->ack.maxSkew = htons(0); + pkt->ack.firstPacket = htonl(hard_ack + 1); + pkt->ack.previousPacket = htonl(call->ackr_highest_seq); + pkt->ack.serial = htonl(serial); + pkt->ack.reason = reason; + pkt->ack.nAcks = top - hard_ack; + + if (reason == RXRPC_ACK_PING) + pkt->whdr.flags |= RXRPC_REQUEST_ACK; + + if (after(top, hard_ack)) { + seq = hard_ack + 1; + do { + ix = seq & RXRPC_RXTX_BUFF_MASK; + if (call->rxtx_buffer[ix]) + *ackp++ = RXRPC_ACK_TYPE_ACK; + else + *ackp++ = RXRPC_ACK_TYPE_NACK; + seq++; + } while (before_eq(seq, top)); + } + + mtu = conn->params.peer->if_mtu; + mtu -= conn->params.peer->hdrsize; + jmax = (call->nr_jumbo_bad > 3) ? 1 : rxrpc_rx_jumbo_max; + pkt->ackinfo.rxMTU = htonl(rxrpc_rx_mtu); + pkt->ackinfo.maxMTU = htonl(mtu); + pkt->ackinfo.rwind = htonl(call->rx_winsize); + pkt->ackinfo.jumbo_max = htonl(jmax); + + *ackp++ = 0; + *ackp++ = 0; + *ackp++ = 0; + return top - hard_ack + 3; +} + +/* + * Record the beginning of an RTT probe. 
+ */ +static int rxrpc_begin_rtt_probe(struct rxrpc_call *call, rxrpc_serial_t serial, + enum rxrpc_rtt_tx_trace why) +{ + unsigned long avail = call->rtt_avail; + int rtt_slot = 9; + + if (!(avail & RXRPC_CALL_RTT_AVAIL_MASK)) + goto no_slot; + + rtt_slot = __ffs(avail & RXRPC_CALL_RTT_AVAIL_MASK); + if (!test_and_clear_bit(rtt_slot, &call->rtt_avail)) + goto no_slot; + + call->rtt_serial[rtt_slot] = serial; + call->rtt_sent_at[rtt_slot] = ktime_get_real(); + smp_wmb(); /* Write data before avail bit */ + set_bit(rtt_slot + RXRPC_CALL_RTT_PEND_SHIFT, &call->rtt_avail); + + trace_rxrpc_rtt_tx(call, why, rtt_slot, serial); + return rtt_slot; + +no_slot: + trace_rxrpc_rtt_tx(call, rxrpc_rtt_tx_no_slot, rtt_slot, serial); + return -1; +} + +/* + * Cancel an RTT probe. + */ +static void rxrpc_cancel_rtt_probe(struct rxrpc_call *call, + rxrpc_serial_t serial, int rtt_slot) +{ + if (rtt_slot != -1) { + clear_bit(rtt_slot + RXRPC_CALL_RTT_PEND_SHIFT, &call->rtt_avail); + smp_wmb(); /* Clear pending bit before setting slot */ + set_bit(rtt_slot, &call->rtt_avail); + trace_rxrpc_rtt_tx(call, rxrpc_rtt_tx_cancel, rtt_slot, serial); + } +} + +/* + * Send an ACK call packet. + */ +int rxrpc_send_ack_packet(struct rxrpc_call *call, bool ping, + rxrpc_serial_t *_serial) +{ + struct rxrpc_connection *conn; + struct rxrpc_ack_buffer *pkt; + struct msghdr msg; + struct kvec iov[2]; + rxrpc_serial_t serial; + rxrpc_seq_t hard_ack, top; + size_t len, n; + int ret, rtt_slot = -1; + u8 reason; + + if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) + return -ECONNRESET; + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return -ENOMEM; + + conn = call->conn; + + msg.msg_name = &call->peer->srx.transport; + msg.msg_namelen = call->peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + pkt->whdr.epoch = htonl(conn->proto.epoch); + pkt->whdr.cid = htonl(call->cid); + pkt->whdr.callNumber = htonl(call->call_id); + pkt->whdr.seq = 0; + pkt->whdr.type = RXRPC_PACKET_TYPE_ACK; + pkt->whdr.flags = RXRPC_SLOW_START_OK | conn->out_clientflag; + pkt->whdr.userStatus = 0; + pkt->whdr.securityIndex = call->security_ix; + pkt->whdr._rsvd = 0; + pkt->whdr.serviceId = htons(call->service_id); + + spin_lock_bh(&call->lock); + if (ping) { + reason = RXRPC_ACK_PING; + } else { + reason = call->ackr_reason; + if (!call->ackr_reason) { + spin_unlock_bh(&call->lock); + ret = 0; + goto out; + } + call->ackr_reason = 0; + } + n = rxrpc_fill_out_ack(conn, call, pkt, &hard_ack, &top, reason); + + spin_unlock_bh(&call->lock); + if (n == 0) { + kfree(pkt); + return 0; + } + + iov[0].iov_base = pkt; + iov[0].iov_len = sizeof(pkt->whdr) + sizeof(pkt->ack) + n; + iov[1].iov_base = &pkt->ackinfo; + iov[1].iov_len = sizeof(pkt->ackinfo); + len = iov[0].iov_len + iov[1].iov_len; + + serial = atomic_inc_return(&conn->serial); + pkt->whdr.serial = htonl(serial); + trace_rxrpc_tx_ack(call->debug_id, serial, + ntohl(pkt->ack.firstPacket), + ntohl(pkt->ack.serial), + pkt->ack.reason, pkt->ack.nAcks); + if (_serial) + *_serial = serial; + + if (ping) + rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_ping); + + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); + conn->params.peer->last_tx_at = ktime_get_seconds(); + if (ret < 0) + trace_rxrpc_tx_fail(call->debug_id, serial, ret, + rxrpc_tx_point_call_ack); + else + trace_rxrpc_tx_packet(call->debug_id, &pkt->whdr, + rxrpc_tx_point_call_ack); + rxrpc_tx_backoff(call, ret); + + if (call->state < RXRPC_CALL_COMPLETE) 
{ + if (ret < 0) { + rxrpc_cancel_rtt_probe(call, serial, rtt_slot); + rxrpc_propose_ACK(call, pkt->ack.reason, + ntohl(pkt->ack.serial), + false, true, + rxrpc_propose_ack_retry_tx); + } + + rxrpc_set_keepalive(call); + } + +out: + kfree(pkt); + return ret; +} + +/* + * Send an ABORT call packet. + */ +int rxrpc_send_abort_packet(struct rxrpc_call *call) +{ + struct rxrpc_connection *conn; + struct rxrpc_abort_buffer pkt; + struct msghdr msg; + struct kvec iov[1]; + rxrpc_serial_t serial; + int ret; + + /* Don't bother sending aborts for a client call once the server has + * hard-ACK'd all of its request data. After that point, we're not + * going to stop the operation proceeding, and whilst we might limit + * the reply, it's not worth it if we can send a new call on the same + * channel instead, thereby closing off this call. + */ + if (rxrpc_is_client_call(call) && + test_bit(RXRPC_CALL_TX_LAST, &call->flags)) + return 0; + + if (test_bit(RXRPC_CALL_DISCONNECTED, &call->flags)) + return -ECONNRESET; + + conn = call->conn; + + msg.msg_name = &call->peer->srx.transport; + msg.msg_namelen = call->peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + pkt.whdr.epoch = htonl(conn->proto.epoch); + pkt.whdr.cid = htonl(call->cid); + pkt.whdr.callNumber = htonl(call->call_id); + pkt.whdr.seq = 0; + pkt.whdr.type = RXRPC_PACKET_TYPE_ABORT; + pkt.whdr.flags = conn->out_clientflag; + pkt.whdr.userStatus = 0; + pkt.whdr.securityIndex = call->security_ix; + pkt.whdr._rsvd = 0; + pkt.whdr.serviceId = htons(call->service_id); + pkt.abort_code = htonl(call->abort_code); + + iov[0].iov_base = &pkt; + iov[0].iov_len = sizeof(pkt); + + serial = atomic_inc_return(&conn->serial); + pkt.whdr.serial = htonl(serial); + + ret = kernel_sendmsg(conn->params.local->socket, + &msg, iov, 1, sizeof(pkt)); + conn->params.peer->last_tx_at = ktime_get_seconds(); + if (ret < 0) + trace_rxrpc_tx_fail(call->debug_id, serial, ret, + rxrpc_tx_point_call_abort); + else + trace_rxrpc_tx_packet(call->debug_id, &pkt.whdr, + rxrpc_tx_point_call_abort); + rxrpc_tx_backoff(call, ret); + return ret; +} + +/* + * send a packet through the transport endpoint + */ +int rxrpc_send_data_packet(struct rxrpc_call *call, struct sk_buff *skb, + bool retrans) +{ + struct rxrpc_connection *conn = call->conn; + struct rxrpc_wire_header whdr; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct msghdr msg; + struct kvec iov[2]; + rxrpc_serial_t serial; + size_t len; + int ret, rtt_slot = -1; + + _enter(",{%d}", skb->len); + + if (hlist_unhashed(&call->error_link)) { + spin_lock_bh(&call->peer->lock); + hlist_add_head_rcu(&call->error_link, &call->peer->error_targets); + spin_unlock_bh(&call->peer->lock); + } + + /* Each transmission of a Tx packet needs a new serial number */ + serial = atomic_inc_return(&conn->serial); + + whdr.epoch = htonl(conn->proto.epoch); + whdr.cid = htonl(call->cid); + whdr.callNumber = htonl(call->call_id); + whdr.seq = htonl(sp->hdr.seq); + whdr.serial = htonl(serial); + whdr.type = RXRPC_PACKET_TYPE_DATA; + whdr.flags = sp->hdr.flags; + whdr.userStatus = 0; + whdr.securityIndex = call->security_ix; + whdr._rsvd = htons(sp->hdr._rsvd); + whdr.serviceId = htons(call->service_id); + + if (test_bit(RXRPC_CONN_PROBING_FOR_UPGRADE, &conn->flags) && + sp->hdr.seq == 1) + whdr.userStatus = RXRPC_USERSTATUS_SERVICE_UPGRADE; + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = skb->head; + iov[1].iov_len = skb->len; + len = iov[0].iov_len + 
iov[1].iov_len; + + msg.msg_name = &call->peer->srx.transport; + msg.msg_namelen = call->peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + /* If our RTT cache needs working on, request an ACK. Also request + * ACKs if a DATA packet appears to have been lost. + * + * However, we mustn't request an ACK on the last reply packet of a + * service call, lest OpenAFS incorrectly send us an ACK with some + * soft-ACKs in it and then never follow up with a proper hard ACK. + */ + if ((!(sp->hdr.flags & RXRPC_LAST_PACKET) || + rxrpc_to_server(sp) + ) && + (test_and_clear_bit(RXRPC_CALL_EV_ACK_LOST, &call->events) || + retrans || + call->cong_mode == RXRPC_CALL_SLOW_START || + (call->peer->rtt_count < 3 && sp->hdr.seq & 1) || + ktime_before(ktime_add_ms(call->peer->rtt_last_req, 1000), + ktime_get_real()))) + whdr.flags |= RXRPC_REQUEST_ACK; + + if (IS_ENABLED(CONFIG_AF_RXRPC_INJECT_LOSS)) { + static int lose; + if ((lose++ & 7) == 7) { + ret = 0; + trace_rxrpc_tx_data(call, sp->hdr.seq, serial, + whdr.flags, retrans, true); + goto done; + } + } + + trace_rxrpc_tx_data(call, sp->hdr.seq, serial, whdr.flags, retrans, + false); + + /* send the packet with the don't fragment bit set if we currently + * think it's small enough */ + if (iov[1].iov_len >= call->peer->maxdata) + goto send_fragmentable; + + down_read(&conn->params.local->defrag_sem); + + sp->hdr.serial = serial; + smp_wmb(); /* Set serial before timestamp */ + skb->tstamp = ktime_get_real(); + if (whdr.flags & RXRPC_REQUEST_ACK) + rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_data); + + /* send the packet by UDP + * - returns -EMSGSIZE if UDP would have to fragment the packet + * to go out of the interface + * - in which case, we'll have processed the ICMP error + * message and update the peer record + */ + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); + conn->params.peer->last_tx_at = ktime_get_seconds(); + + up_read(&conn->params.local->defrag_sem); + if (ret < 0) { + rxrpc_cancel_rtt_probe(call, serial, rtt_slot); + trace_rxrpc_tx_fail(call->debug_id, serial, ret, + rxrpc_tx_point_call_data_nofrag); + } else { + trace_rxrpc_tx_packet(call->debug_id, &whdr, + rxrpc_tx_point_call_data_nofrag); + } + + rxrpc_tx_backoff(call, ret); + if (ret == -EMSGSIZE) + goto send_fragmentable; + +done: + if (ret >= 0) { + if (whdr.flags & RXRPC_REQUEST_ACK) { + call->peer->rtt_last_req = skb->tstamp; + if (call->peer->rtt_count > 1) { + unsigned long nowj = jiffies, ack_lost_at; + + ack_lost_at = rxrpc_get_rto_backoff(call->peer, false); + ack_lost_at += nowj; + WRITE_ONCE(call->ack_lost_at, ack_lost_at); + rxrpc_reduce_call_timer(call, ack_lost_at, nowj, + rxrpc_timer_set_for_lost_ack); + } + } + + if (sp->hdr.seq == 1 && + !test_and_set_bit(RXRPC_CALL_BEGAN_RX_TIMER, + &call->flags)) { + unsigned long nowj = jiffies, expect_rx_by; + + expect_rx_by = nowj + call->next_rx_timo; + WRITE_ONCE(call->expect_rx_by, expect_rx_by); + rxrpc_reduce_call_timer(call, expect_rx_by, nowj, + rxrpc_timer_set_for_normal); + } + + rxrpc_set_keepalive(call); + } else { + /* Cancel the call if the initial transmission fails, + * particularly if that's due to network routing issues that + * aren't going away anytime soon. The layer above can arrange + * the retransmission. 
+ */ + if (!test_and_set_bit(RXRPC_CALL_BEGAN_RX_TIMER, &call->flags)) + rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, + RX_USER_ABORT, ret); + } + + _leave(" = %d [%u]", ret, call->peer->maxdata); + return ret; + +send_fragmentable: + /* attempt to send this message with fragmentation enabled */ + _debug("send fragment"); + + down_write(&conn->params.local->defrag_sem); + + sp->hdr.serial = serial; + smp_wmb(); /* Set serial before timestamp */ + skb->tstamp = ktime_get_real(); + if (whdr.flags & RXRPC_REQUEST_ACK) + rtt_slot = rxrpc_begin_rtt_probe(call, serial, rxrpc_rtt_tx_data); + + switch (conn->params.local->srx.transport.family) { + case AF_INET6: + case AF_INET: + ip_sock_set_mtu_discover(conn->params.local->socket->sk, + IP_PMTUDISC_DONT); + ret = kernel_sendmsg(conn->params.local->socket, &msg, + iov, 2, len); + conn->params.peer->last_tx_at = ktime_get_seconds(); + + ip_sock_set_mtu_discover(conn->params.local->socket->sk, + IP_PMTUDISC_DO); + break; + + default: + BUG(); + } + + if (ret < 0) { + rxrpc_cancel_rtt_probe(call, serial, rtt_slot); + trace_rxrpc_tx_fail(call->debug_id, serial, ret, + rxrpc_tx_point_call_data_frag); + } else { + trace_rxrpc_tx_packet(call->debug_id, &whdr, + rxrpc_tx_point_call_data_frag); + } + rxrpc_tx_backoff(call, ret); + + up_write(&conn->params.local->defrag_sem); + goto done; +} + +/* + * reject packets through the local endpoint + */ +void rxrpc_reject_packets(struct rxrpc_local *local) +{ + struct sockaddr_rxrpc srx; + struct rxrpc_skb_priv *sp; + struct rxrpc_wire_header whdr; + struct sk_buff *skb; + struct msghdr msg; + struct kvec iov[2]; + size_t size; + __be32 code; + int ret, ioc; + + _enter("%d", local->debug_id); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = &code; + iov[1].iov_len = sizeof(code); + + msg.msg_name = &srx.transport; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + memset(&whdr, 0, sizeof(whdr)); + + while ((skb = skb_dequeue(&local->reject_queue))) { + rxrpc_see_skb(skb, rxrpc_skb_seen); + sp = rxrpc_skb(skb); + + switch (skb->mark) { + case RXRPC_SKB_MARK_REJECT_BUSY: + whdr.type = RXRPC_PACKET_TYPE_BUSY; + size = sizeof(whdr); + ioc = 1; + break; + case RXRPC_SKB_MARK_REJECT_ABORT: + whdr.type = RXRPC_PACKET_TYPE_ABORT; + code = htonl(skb->priority); + size = sizeof(whdr) + sizeof(code); + ioc = 2; + break; + default: + rxrpc_free_skb(skb, rxrpc_skb_freed); + continue; + } + + if (rxrpc_extract_addr_from_skb(&srx, skb) == 0) { + msg.msg_namelen = srx.transport_len; + + whdr.epoch = htonl(sp->hdr.epoch); + whdr.cid = htonl(sp->hdr.cid); + whdr.callNumber = htonl(sp->hdr.callNumber); + whdr.serviceId = htons(sp->hdr.serviceId); + whdr.flags = sp->hdr.flags; + whdr.flags ^= RXRPC_CLIENT_INITIATED; + whdr.flags &= RXRPC_CLIENT_INITIATED; + + ret = kernel_sendmsg(local->socket, &msg, + iov, ioc, size); + if (ret < 0) + trace_rxrpc_tx_fail(local->debug_id, 0, ret, + rxrpc_tx_point_reject); + else + trace_rxrpc_tx_packet(local->debug_id, &whdr, + rxrpc_tx_point_reject); + } + + rxrpc_free_skb(skb, rxrpc_skb_freed); + } + + _leave(""); +} + +/* + * Send a VERSION reply to a peer as a keepalive. 
+ */ +void rxrpc_send_keepalive(struct rxrpc_peer *peer) +{ + struct rxrpc_wire_header whdr; + struct msghdr msg; + struct kvec iov[2]; + size_t len; + int ret; + + _enter(""); + + msg.msg_name = &peer->srx.transport; + msg.msg_namelen = peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + whdr.epoch = htonl(peer->local->rxnet->epoch); + whdr.cid = 0; + whdr.callNumber = 0; + whdr.seq = 0; + whdr.serial = 0; + whdr.type = RXRPC_PACKET_TYPE_VERSION; /* Not client-initiated */ + whdr.flags = RXRPC_LAST_PACKET; + whdr.userStatus = 0; + whdr.securityIndex = 0; + whdr._rsvd = 0; + whdr.serviceId = 0; + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = (char *)rxrpc_keepalive_string; + iov[1].iov_len = sizeof(rxrpc_keepalive_string); + + len = iov[0].iov_len + iov[1].iov_len; + + _proto("Tx VERSION (keepalive)"); + + ret = kernel_sendmsg(peer->local->socket, &msg, iov, 2, len); + if (ret < 0) + trace_rxrpc_tx_fail(peer->debug_id, 0, ret, + rxrpc_tx_point_version_keepalive); + else + trace_rxrpc_tx_packet(peer->debug_id, &whdr, + rxrpc_tx_point_version_keepalive); + + peer->last_tx_at = ktime_get_seconds(); + _leave(""); +} diff --git a/net/rxrpc/peer_event.c b/net/rxrpc/peer_event.c new file mode 100644 index 000000000..32561e956 --- /dev/null +++ b/net/rxrpc/peer_event.c @@ -0,0 +1,636 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Peer event handling, typically ICMP messages. + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static void rxrpc_adjust_mtu(struct rxrpc_peer *, unsigned int); +static void rxrpc_store_error(struct rxrpc_peer *, struct sock_exterr_skb *); +static void rxrpc_distribute_error(struct rxrpc_peer *, int, + enum rxrpc_call_completion); + +/* + * Find the peer associated with an ICMPv4 packet. + */ +static struct rxrpc_peer *rxrpc_lookup_peer_icmp_rcu(struct rxrpc_local *local, + struct sk_buff *skb, + unsigned int udp_offset, + unsigned int *info, + struct sockaddr_rxrpc *srx) +{ + struct iphdr *ip, *ip0 = ip_hdr(skb); + struct icmphdr *icmp = icmp_hdr(skb); + struct udphdr *udp = (struct udphdr *)(skb->data + udp_offset); + + _enter("%u,%u,%u", ip0->protocol, icmp->type, icmp->code); + + switch (icmp->type) { + case ICMP_DEST_UNREACH: + *info = ntohs(icmp->un.frag.mtu); + fallthrough; + case ICMP_TIME_EXCEEDED: + case ICMP_PARAMETERPROB: + ip = (struct iphdr *)((void *)icmp + 8); + break; + default: + return NULL; + } + + memset(srx, 0, sizeof(*srx)); + srx->transport_type = local->srx.transport_type; + srx->transport_len = local->srx.transport_len; + srx->transport.family = local->srx.transport.family; + + /* Can we see an ICMP4 packet on an ICMP6 listening socket? and vice + * versa? 
+ */ + switch (srx->transport.family) { + case AF_INET: + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = udp->dest; + memcpy(&srx->transport.sin.sin_addr, &ip->daddr, + sizeof(struct in_addr)); + break; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = udp->dest; + memcpy(&srx->transport.sin.sin_addr, &ip->daddr, + sizeof(struct in_addr)); + break; +#endif + + default: + WARN_ON_ONCE(1); + return NULL; + } + + _net("ICMP {%pISp}", &srx->transport); + return rxrpc_lookup_peer_rcu(local, srx); +} + +#ifdef CONFIG_AF_RXRPC_IPV6 +/* + * Find the peer associated with an ICMPv6 packet. + */ +static struct rxrpc_peer *rxrpc_lookup_peer_icmp6_rcu(struct rxrpc_local *local, + struct sk_buff *skb, + unsigned int udp_offset, + unsigned int *info, + struct sockaddr_rxrpc *srx) +{ + struct icmp6hdr *icmp = icmp6_hdr(skb); + struct ipv6hdr *ip, *ip0 = ipv6_hdr(skb); + struct udphdr *udp = (struct udphdr *)(skb->data + udp_offset); + + _enter("%u,%u,%u", ip0->nexthdr, icmp->icmp6_type, icmp->icmp6_code); + + switch (icmp->icmp6_type) { + case ICMPV6_DEST_UNREACH: + *info = ntohl(icmp->icmp6_mtu); + fallthrough; + case ICMPV6_PKT_TOOBIG: + case ICMPV6_TIME_EXCEED: + case ICMPV6_PARAMPROB: + ip = (struct ipv6hdr *)((void *)icmp + 8); + break; + default: + return NULL; + } + + memset(srx, 0, sizeof(*srx)); + srx->transport_type = local->srx.transport_type; + srx->transport_len = local->srx.transport_len; + srx->transport.family = local->srx.transport.family; + + /* Can we see an ICMP4 packet on an ICMP6 listening socket? and vice + * versa? + */ + switch (srx->transport.family) { + case AF_INET: + _net("Rx ICMP6 on v4 sock"); + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = udp->dest; + memcpy(&srx->transport.sin.sin_addr, + &ip->daddr.s6_addr32[3], sizeof(struct in_addr)); + break; + case AF_INET6: + _net("Rx ICMP6"); + srx->transport.sin.sin_port = udp->dest; + memcpy(&srx->transport.sin6.sin6_addr, &ip->daddr, + sizeof(struct in6_addr)); + break; + default: + WARN_ON_ONCE(1); + return NULL; + } + + _net("ICMP {%pISp}", &srx->transport); + return rxrpc_lookup_peer_rcu(local, srx); +} +#endif /* CONFIG_AF_RXRPC_IPV6 */ + +/* + * Handle an error received on the local endpoint as a tunnel. 
+ */ +void rxrpc_encap_err_rcv(struct sock *sk, struct sk_buff *skb, + unsigned int udp_offset) +{ + struct sock_extended_err ee; + struct sockaddr_rxrpc srx; + struct rxrpc_local *local; + struct rxrpc_peer *peer; + unsigned int info = 0; + int err; + u8 version = ip_hdr(skb)->version; + u8 type = icmp_hdr(skb)->type; + u8 code = icmp_hdr(skb)->code; + + rcu_read_lock(); + local = rcu_dereference_sk_user_data(sk); + if (unlikely(!local)) { + rcu_read_unlock(); + return; + } + + rxrpc_new_skb(skb, rxrpc_skb_received); + + switch (ip_hdr(skb)->version) { + case IPVERSION: + peer = rxrpc_lookup_peer_icmp_rcu(local, skb, udp_offset, + &info, &srx); + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case 6: + peer = rxrpc_lookup_peer_icmp6_rcu(local, skb, udp_offset, + &info, &srx); + break; +#endif + default: + rcu_read_unlock(); + return; + } + + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + if (!peer) { + rcu_read_unlock(); + return; + } + + memset(&ee, 0, sizeof(ee)); + + switch (version) { + case IPVERSION: + switch (type) { + case ICMP_DEST_UNREACH: + switch (code) { + case ICMP_FRAG_NEEDED: + rxrpc_adjust_mtu(peer, info); + rcu_read_unlock(); + rxrpc_put_peer(peer); + return; + default: + break; + } + + err = EHOSTUNREACH; + if (code <= NR_ICMP_UNREACH) { + /* Might want to do something different with + * non-fatal errors + */ + //harderr = icmp_err_convert[code].fatal; + err = icmp_err_convert[code].errno; + } + break; + + case ICMP_TIME_EXCEEDED: + err = EHOSTUNREACH; + break; + default: + err = EPROTO; + break; + } + + ee.ee_origin = SO_EE_ORIGIN_ICMP; + ee.ee_type = type; + ee.ee_code = code; + ee.ee_errno = err; + break; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case 6: + switch (type) { + case ICMPV6_PKT_TOOBIG: + rxrpc_adjust_mtu(peer, info); + rcu_read_unlock(); + rxrpc_put_peer(peer); + return; + } + + icmpv6_err_convert(type, code, &err); + + if (err == EACCES) + err = EHOSTUNREACH; + + ee.ee_origin = SO_EE_ORIGIN_ICMP6; + ee.ee_type = type; + ee.ee_code = code; + ee.ee_errno = err; + break; +#endif + } + + trace_rxrpc_rx_icmp(peer, &ee, &srx); + + rxrpc_distribute_error(peer, err, RXRPC_CALL_NETWORK_ERROR); + rcu_read_unlock(); + rxrpc_put_peer(peer); +} + +/* + * Find the peer associated with a local error. 
+ */ +static struct rxrpc_peer *rxrpc_lookup_peer_local_rcu(struct rxrpc_local *local, + const struct sk_buff *skb, + struct sockaddr_rxrpc *srx) +{ + struct sock_exterr_skb *serr = SKB_EXT_ERR(skb); + + _enter(""); + + memset(srx, 0, sizeof(*srx)); + srx->transport_type = local->srx.transport_type; + srx->transport_len = local->srx.transport_len; + srx->transport.family = local->srx.transport.family; + + switch (srx->transport.family) { + case AF_INET: + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = serr->port; + switch (serr->ee.ee_origin) { + case SO_EE_ORIGIN_ICMP: + _net("Rx ICMP"); + memcpy(&srx->transport.sin.sin_addr, + skb_network_header(skb) + serr->addr_offset, + sizeof(struct in_addr)); + break; + case SO_EE_ORIGIN_ICMP6: + _net("Rx ICMP6 on v4 sock"); + memcpy(&srx->transport.sin.sin_addr, + skb_network_header(skb) + serr->addr_offset + 12, + sizeof(struct in_addr)); + break; + default: + memcpy(&srx->transport.sin.sin_addr, &ip_hdr(skb)->saddr, + sizeof(struct in_addr)); + break; + } + break; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + switch (serr->ee.ee_origin) { + case SO_EE_ORIGIN_ICMP6: + _net("Rx ICMP6"); + srx->transport.sin6.sin6_port = serr->port; + memcpy(&srx->transport.sin6.sin6_addr, + skb_network_header(skb) + serr->addr_offset, + sizeof(struct in6_addr)); + break; + case SO_EE_ORIGIN_ICMP: + _net("Rx ICMP on v6 sock"); + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.family = AF_INET; + srx->transport.sin.sin_port = serr->port; + memcpy(&srx->transport.sin.sin_addr, + skb_network_header(skb) + serr->addr_offset, + sizeof(struct in_addr)); + break; + default: + memcpy(&srx->transport.sin6.sin6_addr, + &ipv6_hdr(skb)->saddr, + sizeof(struct in6_addr)); + break; + } + break; +#endif + + default: + BUG(); + } + + return rxrpc_lookup_peer_rcu(local, srx); +} + +/* + * Handle an MTU/fragmentation problem. + */ +static void rxrpc_adjust_mtu(struct rxrpc_peer *peer, unsigned int mtu) +{ + _net("Rx ICMP Fragmentation Needed (%d)", mtu); + + /* wind down the local interface MTU */ + if (mtu > 0 && peer->if_mtu == 65535 && mtu < peer->if_mtu) { + peer->if_mtu = mtu; + _net("I/F MTU %u", mtu); + } + + if (mtu == 0) { + /* they didn't give us a size, estimate one */ + mtu = peer->if_mtu; + if (mtu > 1500) { + mtu >>= 1; + if (mtu < 1500) + mtu = 1500; + } else { + mtu -= 100; + if (mtu < peer->hdrsize) + mtu = peer->hdrsize + 4; + } + } + + if (mtu < peer->mtu) { + spin_lock_bh(&peer->lock); + peer->mtu = mtu; + peer->maxdata = peer->mtu - peer->hdrsize; + spin_unlock_bh(&peer->lock); + _net("Net MTU %u (maxdata %u)", + peer->mtu, peer->maxdata); + } +} + +/* + * Handle an error received on the local endpoint. + */ +void rxrpc_error_report(struct sock *sk) +{ + struct sock_exterr_skb *serr; + struct sockaddr_rxrpc srx; + struct rxrpc_local *local; + struct rxrpc_peer *peer = NULL; + struct sk_buff *skb; + + rcu_read_lock(); + local = rcu_dereference_sk_user_data(sk); + if (unlikely(!local)) { + rcu_read_unlock(); + return; + } + _enter("%p{%d}", sk, local->debug_id); + + /* Clear the outstanding error value on the socket so that it doesn't + * cause kernel_sendmsg() to return it later. 
+ */ + sock_error(sk); + + skb = sock_dequeue_err_skb(sk); + if (!skb) { + rcu_read_unlock(); + _leave("UDP socket errqueue empty"); + return; + } + rxrpc_new_skb(skb, rxrpc_skb_received); + serr = SKB_EXT_ERR(skb); + + if (serr->ee.ee_origin == SO_EE_ORIGIN_LOCAL) { + peer = rxrpc_lookup_peer_local_rcu(local, skb, &srx); + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + if (peer) { + trace_rxrpc_rx_icmp(peer, &serr->ee, &srx); + rxrpc_store_error(peer, serr); + } + } + + rcu_read_unlock(); + rxrpc_free_skb(skb, rxrpc_skb_freed); + rxrpc_put_peer(peer); + _leave(""); +} + +/* + * Map an error report to error codes on the peer record. + */ +static void rxrpc_store_error(struct rxrpc_peer *peer, + struct sock_exterr_skb *serr) +{ + enum rxrpc_call_completion compl = RXRPC_CALL_NETWORK_ERROR; + struct sock_extended_err *ee; + int err; + + _enter(""); + + ee = &serr->ee; + + err = ee->ee_errno; + + switch (ee->ee_origin) { + case SO_EE_ORIGIN_ICMP: + switch (ee->ee_type) { + case ICMP_DEST_UNREACH: + switch (ee->ee_code) { + case ICMP_NET_UNREACH: + _net("Rx Received ICMP Network Unreachable"); + break; + case ICMP_HOST_UNREACH: + _net("Rx Received ICMP Host Unreachable"); + break; + case ICMP_PORT_UNREACH: + _net("Rx Received ICMP Port Unreachable"); + break; + case ICMP_NET_UNKNOWN: + _net("Rx Received ICMP Unknown Network"); + break; + case ICMP_HOST_UNKNOWN: + _net("Rx Received ICMP Unknown Host"); + break; + default: + _net("Rx Received ICMP DestUnreach code=%u", + ee->ee_code); + break; + } + break; + + case ICMP_TIME_EXCEEDED: + _net("Rx Received ICMP TTL Exceeded"); + break; + + default: + _proto("Rx Received ICMP error { type=%u code=%u }", + ee->ee_type, ee->ee_code); + break; + } + break; + + case SO_EE_ORIGIN_NONE: + case SO_EE_ORIGIN_LOCAL: + _proto("Rx Received local error { error=%d }", err); + compl = RXRPC_CALL_LOCAL_ERROR; + break; + + case SO_EE_ORIGIN_ICMP6: + if (err == EACCES) + err = EHOSTUNREACH; + fallthrough; + default: + _proto("Rx Received error report { orig=%u }", ee->ee_origin); + break; + } + + rxrpc_distribute_error(peer, err, compl); +} + +/* + * Distribute an error that occurred on a peer. + */ +static void rxrpc_distribute_error(struct rxrpc_peer *peer, int error, + enum rxrpc_call_completion compl) +{ + struct rxrpc_call *call; + + hlist_for_each_entry_rcu(call, &peer->error_targets, error_link) { + rxrpc_see_call(call); + rxrpc_set_call_completion(call, compl, 0, -error); + } +} + +/* + * Perform keep-alive pings. + */ +static void rxrpc_peer_keepalive_dispatch(struct rxrpc_net *rxnet, + struct list_head *collector, + time64_t base, + u8 cursor) +{ + struct rxrpc_peer *peer; + const u8 mask = ARRAY_SIZE(rxnet->peer_keepalive) - 1; + time64_t keepalive_at; + int slot; + + spin_lock_bh(&rxnet->peer_hash_lock); + + while (!list_empty(collector)) { + peer = list_entry(collector->next, + struct rxrpc_peer, keepalive_link); + + list_del_init(&peer->keepalive_link); + if (!rxrpc_get_peer_maybe(peer)) + continue; + + if (__rxrpc_use_local(peer->local)) { + spin_unlock_bh(&rxnet->peer_hash_lock); + + keepalive_at = peer->last_tx_at + RXRPC_KEEPALIVE_TIME; + slot = keepalive_at - base; + _debug("%02x peer %u t=%d {%pISp}", + cursor, peer->debug_id, slot, &peer->srx.transport); + + if (keepalive_at <= base || + keepalive_at > base + RXRPC_KEEPALIVE_TIME) { + rxrpc_send_keepalive(peer); + slot = RXRPC_KEEPALIVE_TIME; + } + + /* A transmission to this peer occurred since last we + * examined it so put it into the appropriate future + * bucket. 
+ */ + slot += cursor; + slot &= mask; + spin_lock_bh(&rxnet->peer_hash_lock); + list_add_tail(&peer->keepalive_link, + &rxnet->peer_keepalive[slot & mask]); + rxrpc_unuse_local(peer->local); + } + rxrpc_put_peer_locked(peer); + } + + spin_unlock_bh(&rxnet->peer_hash_lock); +} + +/* + * Perform keep-alive pings with VERSION packets to keep any NAT alive. + */ +void rxrpc_peer_keepalive_worker(struct work_struct *work) +{ + struct rxrpc_net *rxnet = + container_of(work, struct rxrpc_net, peer_keepalive_work); + const u8 mask = ARRAY_SIZE(rxnet->peer_keepalive) - 1; + time64_t base, now, delay; + u8 cursor, stop; + LIST_HEAD(collector); + + now = ktime_get_seconds(); + base = rxnet->peer_keepalive_base; + cursor = rxnet->peer_keepalive_cursor; + _enter("%lld,%u", base - now, cursor); + + if (!rxnet->live) + return; + + /* Remove to a temporary list all the peers that are currently lodged + * in expired buckets plus all new peers. + * + * Everything in the bucket at the cursor is processed this + * second; the bucket at cursor + 1 goes at now + 1s and so + * on... + */ + spin_lock_bh(&rxnet->peer_hash_lock); + list_splice_init(&rxnet->peer_keepalive_new, &collector); + + stop = cursor + ARRAY_SIZE(rxnet->peer_keepalive); + while (base <= now && (s8)(cursor - stop) < 0) { + list_splice_tail_init(&rxnet->peer_keepalive[cursor & mask], + &collector); + base++; + cursor++; + } + + base = now; + spin_unlock_bh(&rxnet->peer_hash_lock); + + rxnet->peer_keepalive_base = base; + rxnet->peer_keepalive_cursor = cursor; + rxrpc_peer_keepalive_dispatch(rxnet, &collector, base, cursor); + ASSERT(list_empty(&collector)); + + /* Schedule the timer for the next occupied timeslot. */ + cursor = rxnet->peer_keepalive_cursor; + stop = cursor + RXRPC_KEEPALIVE_TIME - 1; + for (; (s8)(cursor - stop) < 0; cursor++) { + if (!list_empty(&rxnet->peer_keepalive[cursor & mask])) + break; + base++; + } + + now = ktime_get_seconds(); + delay = base - now; + if (delay < 1) + delay = 1; + delay *= HZ; + if (rxnet->live) + timer_reduce(&rxnet->peer_keepalive_timer, jiffies + delay); + + _leave(""); +} diff --git a/net/rxrpc/peer_object.c b/net/rxrpc/peer_object.c new file mode 100644 index 000000000..26d2ae9ba --- /dev/null +++ b/net/rxrpc/peer_object.c @@ -0,0 +1,528 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC remote transport endpoint record management + * + * Copyright (C) 2007, 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +/* + * Hash a peer key. 
+ */ +static unsigned long rxrpc_peer_hash_key(struct rxrpc_local *local, + const struct sockaddr_rxrpc *srx) +{ + const u16 *p; + unsigned int i, size; + unsigned long hash_key; + + _enter(""); + + hash_key = (unsigned long)local / __alignof__(*local); + hash_key += srx->transport_type; + hash_key += srx->transport_len; + hash_key += srx->transport.family; + + switch (srx->transport.family) { + case AF_INET: + hash_key += (u16 __force)srx->transport.sin.sin_port; + size = sizeof(srx->transport.sin.sin_addr); + p = (u16 *)&srx->transport.sin.sin_addr; + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + hash_key += (u16 __force)srx->transport.sin.sin_port; + size = sizeof(srx->transport.sin6.sin6_addr); + p = (u16 *)&srx->transport.sin6.sin6_addr; + break; +#endif + default: + WARN(1, "AF_RXRPC: Unsupported transport address family\n"); + return 0; + } + + /* Step through the peer address in 16-bit portions for speed */ + for (i = 0; i < size; i += sizeof(*p), p++) + hash_key += *p; + + _leave(" 0x%lx", hash_key); + return hash_key; +} + +/* + * Compare a peer to a key. Return -ve, 0 or +ve to indicate less than, same + * or greater than. + * + * Unfortunately, the primitives in linux/hashtable.h don't allow for sorted + * buckets and mid-bucket insertion, so we don't make full use of this + * information at this point. + */ +static long rxrpc_peer_cmp_key(const struct rxrpc_peer *peer, + struct rxrpc_local *local, + const struct sockaddr_rxrpc *srx, + unsigned long hash_key) +{ + long diff; + + diff = ((peer->hash_key - hash_key) ?: + ((unsigned long)peer->local - (unsigned long)local) ?: + (peer->srx.transport_type - srx->transport_type) ?: + (peer->srx.transport_len - srx->transport_len) ?: + (peer->srx.transport.family - srx->transport.family)); + if (diff != 0) + return diff; + + switch (srx->transport.family) { + case AF_INET: + return ((u16 __force)peer->srx.transport.sin.sin_port - + (u16 __force)srx->transport.sin.sin_port) ?: + memcmp(&peer->srx.transport.sin.sin_addr, + &srx->transport.sin.sin_addr, + sizeof(struct in_addr)); +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + return ((u16 __force)peer->srx.transport.sin6.sin6_port - + (u16 __force)srx->transport.sin6.sin6_port) ?: + memcmp(&peer->srx.transport.sin6.sin6_addr, + &srx->transport.sin6.sin6_addr, + sizeof(struct in6_addr)); +#endif + default: + BUG(); + } +} + +/* + * Look up a remote transport endpoint for the specified address using RCU. + */ +static struct rxrpc_peer *__rxrpc_lookup_peer_rcu( + struct rxrpc_local *local, + const struct sockaddr_rxrpc *srx, + unsigned long hash_key) +{ + struct rxrpc_peer *peer; + struct rxrpc_net *rxnet = local->rxnet; + + hash_for_each_possible_rcu(rxnet->peer_hash, peer, hash_link, hash_key) { + if (rxrpc_peer_cmp_key(peer, local, srx, hash_key) == 0 && + refcount_read(&peer->ref) > 0) + return peer; + } + + return NULL; +} + +/* + * Look up a remote transport endpoint for the specified address using RCU. 
+ */ +struct rxrpc_peer *rxrpc_lookup_peer_rcu(struct rxrpc_local *local, + const struct sockaddr_rxrpc *srx) +{ + struct rxrpc_peer *peer; + unsigned long hash_key = rxrpc_peer_hash_key(local, srx); + + peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); + if (peer) { + _net("PEER %d {%pISp}", peer->debug_id, &peer->srx.transport); + _leave(" = %p {u=%d}", peer, refcount_read(&peer->ref)); + } + return peer; +} + +/* + * assess the MTU size for the network interface through which this peer is + * reached + */ +static void rxrpc_assess_MTU_size(struct rxrpc_sock *rx, + struct rxrpc_peer *peer) +{ + struct net *net = sock_net(&rx->sk); + struct dst_entry *dst; + struct rtable *rt; + struct flowi fl; + struct flowi4 *fl4 = &fl.u.ip4; +#ifdef CONFIG_AF_RXRPC_IPV6 + struct flowi6 *fl6 = &fl.u.ip6; +#endif + + peer->if_mtu = 1500; + + memset(&fl, 0, sizeof(fl)); + switch (peer->srx.transport.family) { + case AF_INET: + rt = ip_route_output_ports( + net, fl4, NULL, + peer->srx.transport.sin.sin_addr.s_addr, 0, + htons(7000), htons(7001), IPPROTO_UDP, 0, 0); + if (IS_ERR(rt)) { + _leave(" [route err %ld]", PTR_ERR(rt)); + return; + } + dst = &rt->dst; + break; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + fl6->flowi6_iif = LOOPBACK_IFINDEX; + fl6->flowi6_scope = RT_SCOPE_UNIVERSE; + fl6->flowi6_proto = IPPROTO_UDP; + memcpy(&fl6->daddr, &peer->srx.transport.sin6.sin6_addr, + sizeof(struct in6_addr)); + fl6->fl6_dport = htons(7001); + fl6->fl6_sport = htons(7000); + dst = ip6_route_output(net, NULL, fl6); + if (dst->error) { + _leave(" [route err %d]", dst->error); + return; + } + break; +#endif + + default: + BUG(); + } + + peer->if_mtu = dst_mtu(dst); + dst_release(dst); + + _leave(" [if_mtu %u]", peer->if_mtu); +} + +/* + * Allocate a peer. + */ +struct rxrpc_peer *rxrpc_alloc_peer(struct rxrpc_local *local, gfp_t gfp) +{ + const void *here = __builtin_return_address(0); + struct rxrpc_peer *peer; + + _enter(""); + + peer = kzalloc(sizeof(struct rxrpc_peer), gfp); + if (peer) { + refcount_set(&peer->ref, 1); + peer->local = rxrpc_get_local(local); + INIT_HLIST_HEAD(&peer->error_targets); + peer->service_conns = RB_ROOT; + seqlock_init(&peer->service_conn_lock); + spin_lock_init(&peer->lock); + spin_lock_init(&peer->rtt_input_lock); + peer->debug_id = atomic_inc_return(&rxrpc_debug_id); + + rxrpc_peer_init_rtt(peer); + + if (RXRPC_TX_SMSS > 2190) + peer->cong_cwnd = 2; + else if (RXRPC_TX_SMSS > 1095) + peer->cong_cwnd = 3; + else + peer->cong_cwnd = 4; + trace_rxrpc_peer(peer->debug_id, rxrpc_peer_new, 1, here); + } + + _leave(" = %p", peer); + return peer; +} + +/* + * Initialise peer record. + */ +static void rxrpc_init_peer(struct rxrpc_sock *rx, struct rxrpc_peer *peer, + unsigned long hash_key) +{ + peer->hash_key = hash_key; + rxrpc_assess_MTU_size(rx, peer); + peer->mtu = peer->if_mtu; + peer->rtt_last_req = ktime_get_real(); + + switch (peer->srx.transport.family) { + case AF_INET: + peer->hdrsize = sizeof(struct iphdr); + break; +#ifdef CONFIG_AF_RXRPC_IPV6 + case AF_INET6: + peer->hdrsize = sizeof(struct ipv6hdr); + break; +#endif + default: + BUG(); + } + + switch (peer->srx.transport_type) { + case SOCK_DGRAM: + peer->hdrsize += sizeof(struct udphdr); + break; + default: + BUG(); + } + + peer->hdrsize += sizeof(struct rxrpc_wire_header); + peer->maxdata = peer->mtu - peer->hdrsize; +} + +/* + * Set up a new peer. 
+ */ +static struct rxrpc_peer *rxrpc_create_peer(struct rxrpc_sock *rx, + struct rxrpc_local *local, + struct sockaddr_rxrpc *srx, + unsigned long hash_key, + gfp_t gfp) +{ + struct rxrpc_peer *peer; + + _enter(""); + + peer = rxrpc_alloc_peer(local, gfp); + if (peer) { + memcpy(&peer->srx, srx, sizeof(*srx)); + rxrpc_init_peer(rx, peer, hash_key); + } + + _leave(" = %p", peer); + return peer; +} + +static void rxrpc_free_peer(struct rxrpc_peer *peer) +{ + rxrpc_put_local(peer->local); + kfree_rcu(peer, rcu); +} + +/* + * Set up a new incoming peer. There shouldn't be any other matching peers + * since we've already done a search in the list from the non-reentrant context + * (the data_ready handler) that is the only place we can add new peers. + */ +void rxrpc_new_incoming_peer(struct rxrpc_sock *rx, struct rxrpc_local *local, + struct rxrpc_peer *peer) +{ + struct rxrpc_net *rxnet = local->rxnet; + unsigned long hash_key; + + hash_key = rxrpc_peer_hash_key(local, &peer->srx); + rxrpc_init_peer(rx, peer, hash_key); + + spin_lock(&rxnet->peer_hash_lock); + hash_add_rcu(rxnet->peer_hash, &peer->hash_link, hash_key); + list_add_tail(&peer->keepalive_link, &rxnet->peer_keepalive_new); + spin_unlock(&rxnet->peer_hash_lock); +} + +/* + * obtain a remote transport endpoint for the specified address + */ +struct rxrpc_peer *rxrpc_lookup_peer(struct rxrpc_sock *rx, + struct rxrpc_local *local, + struct sockaddr_rxrpc *srx, gfp_t gfp) +{ + struct rxrpc_peer *peer, *candidate; + struct rxrpc_net *rxnet = local->rxnet; + unsigned long hash_key = rxrpc_peer_hash_key(local, srx); + + _enter("{%pISp}", &srx->transport); + + /* search the peer list first */ + rcu_read_lock(); + peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + rcu_read_unlock(); + + if (!peer) { + /* The peer is not yet present in hash - create a candidate + * for a new record and then redo the search. + */ + candidate = rxrpc_create_peer(rx, local, srx, hash_key, gfp); + if (!candidate) { + _leave(" = NULL [nomem]"); + return NULL; + } + + spin_lock_bh(&rxnet->peer_hash_lock); + + /* Need to check that we aren't racing with someone else */ + peer = __rxrpc_lookup_peer_rcu(local, srx, hash_key); + if (peer && !rxrpc_get_peer_maybe(peer)) + peer = NULL; + if (!peer) { + hash_add_rcu(rxnet->peer_hash, + &candidate->hash_link, hash_key); + list_add_tail(&candidate->keepalive_link, + &rxnet->peer_keepalive_new); + } + + spin_unlock_bh(&rxnet->peer_hash_lock); + + if (peer) + rxrpc_free_peer(candidate); + else + peer = candidate; + } + + _net("PEER %d {%pISp}", peer->debug_id, &peer->srx.transport); + + _leave(" = %p {u=%d}", peer, refcount_read(&peer->ref)); + return peer; +} + +/* + * Get a ref on a peer record. + */ +struct rxrpc_peer *rxrpc_get_peer(struct rxrpc_peer *peer) +{ + const void *here = __builtin_return_address(0); + int r; + + __refcount_inc(&peer->ref, &r); + trace_rxrpc_peer(peer->debug_id, rxrpc_peer_got, r + 1, here); + return peer; +} + +/* + * Get a ref on a peer record unless its usage has already reached 0. + */ +struct rxrpc_peer *rxrpc_get_peer_maybe(struct rxrpc_peer *peer) +{ + const void *here = __builtin_return_address(0); + int r; + + if (peer) { + if (__refcount_inc_not_zero(&peer->ref, &r)) + trace_rxrpc_peer(peer->debug_id, rxrpc_peer_got, r + 1, here); + else + peer = NULL; + } + return peer; +} + +/* + * Discard a peer record. 
+ */ +static void __rxrpc_put_peer(struct rxrpc_peer *peer) +{ + struct rxrpc_net *rxnet = peer->local->rxnet; + + ASSERT(hlist_empty(&peer->error_targets)); + + spin_lock_bh(&rxnet->peer_hash_lock); + hash_del_rcu(&peer->hash_link); + list_del_init(&peer->keepalive_link); + spin_unlock_bh(&rxnet->peer_hash_lock); + + rxrpc_free_peer(peer); +} + +/* + * Drop a ref on a peer record. + */ +void rxrpc_put_peer(struct rxrpc_peer *peer) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id; + bool dead; + int r; + + if (peer) { + debug_id = peer->debug_id; + dead = __refcount_dec_and_test(&peer->ref, &r); + trace_rxrpc_peer(debug_id, rxrpc_peer_put, r - 1, here); + if (dead) + __rxrpc_put_peer(peer); + } +} + +/* + * Drop a ref on a peer record where the caller already holds the + * peer_hash_lock. + */ +void rxrpc_put_peer_locked(struct rxrpc_peer *peer) +{ + const void *here = __builtin_return_address(0); + unsigned int debug_id = peer->debug_id; + bool dead; + int r; + + dead = __refcount_dec_and_test(&peer->ref, &r); + trace_rxrpc_peer(debug_id, rxrpc_peer_put, r - 1, here); + if (dead) { + hash_del_rcu(&peer->hash_link); + list_del_init(&peer->keepalive_link); + rxrpc_free_peer(peer); + } +} + +/* + * Make sure all peer records have been discarded. + */ +void rxrpc_destroy_all_peers(struct rxrpc_net *rxnet) +{ + struct rxrpc_peer *peer; + int i; + + for (i = 0; i < HASH_SIZE(rxnet->peer_hash); i++) { + if (hlist_empty(&rxnet->peer_hash[i])) + continue; + + hlist_for_each_entry(peer, &rxnet->peer_hash[i], hash_link) { + pr_err("Leaked peer %u {%u} %pISp\n", + peer->debug_id, + refcount_read(&peer->ref), + &peer->srx.transport); + } + } +} + +/** + * rxrpc_kernel_get_peer - Get the peer address of a call + * @sock: The socket on which the call is in progress. + * @call: The call to query + * @_srx: Where to place the result + * + * Get the address of the remote peer in a call. + */ +void rxrpc_kernel_get_peer(struct socket *sock, struct rxrpc_call *call, + struct sockaddr_rxrpc *_srx) +{ + *_srx = call->peer->srx; +} +EXPORT_SYMBOL(rxrpc_kernel_get_peer); + +/** + * rxrpc_kernel_get_srtt - Get a call's peer smoothed RTT + * @sock: The socket on which the call is in progress. + * @call: The call to query + * @_srtt: Where to store the SRTT value. + * + * Get the call's peer smoothed RTT in uS. + */ +bool rxrpc_kernel_get_srtt(struct socket *sock, struct rxrpc_call *call, + u32 *_srtt) +{ + struct rxrpc_peer *peer = call->peer; + + if (peer->rtt_count == 0) { + *_srtt = 1000000; /* 1S */ + return false; + } + + *_srtt = call->peer->srtt_us >> 3; + return true; +} +EXPORT_SYMBOL(rxrpc_kernel_get_srtt); diff --git a/net/rxrpc/proc.c b/net/rxrpc/proc.c new file mode 100644 index 000000000..245418943 --- /dev/null +++ b/net/rxrpc/proc.c @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* /proc/net/ support for AF_RXRPC + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include "ar-internal.h" + +static const char *const rxrpc_conn_states[RXRPC_CONN__NR_STATES] = { + [RXRPC_CONN_UNUSED] = "Unused ", + [RXRPC_CONN_CLIENT] = "Client ", + [RXRPC_CONN_SERVICE_PREALLOC] = "SvPrealc", + [RXRPC_CONN_SERVICE_UNSECURED] = "SvUnsec ", + [RXRPC_CONN_SERVICE_CHALLENGING] = "SvChall ", + [RXRPC_CONN_SERVICE] = "SvSecure", + [RXRPC_CONN_REMOTELY_ABORTED] = "RmtAbort", + [RXRPC_CONN_LOCALLY_ABORTED] = "LocAbort", +}; + +/* + * generate a list of extant and dead calls in /proc/net/rxrpc_calls + */ +static void *rxrpc_call_seq_start(struct seq_file *seq, loff_t *_pos) + __acquires(rcu) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + rcu_read_lock(); + return seq_list_start_head_rcu(&rxnet->calls, *_pos); +} + +static void *rxrpc_call_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + return seq_list_next_rcu(v, &rxnet->calls, pos); +} + +static void rxrpc_call_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +static int rxrpc_call_seq_show(struct seq_file *seq, void *v) +{ + struct rxrpc_local *local; + struct rxrpc_sock *rx; + struct rxrpc_peer *peer; + struct rxrpc_call *call; + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned long timeout = 0; + rxrpc_seq_t tx_hard_ack, rx_hard_ack; + char lbuff[50], rbuff[50]; + + if (v == &rxnet->calls) { + seq_puts(seq, + "Proto Local " + " Remote " + " SvID ConnID CallID End Use State Abort " + " DebugId TxSeq TW RxSeq RW RxSerial RxTimo\n"); + return 0; + } + + call = list_entry(v, struct rxrpc_call, link); + + rx = rcu_dereference(call->socket); + if (rx) { + local = READ_ONCE(rx->local); + if (local) + sprintf(lbuff, "%pISpc", &local->srx.transport); + else + strcpy(lbuff, "no_local"); + } else { + strcpy(lbuff, "no_socket"); + } + + peer = call->peer; + if (peer) + sprintf(rbuff, "%pISpc", &peer->srx.transport); + else + strcpy(rbuff, "no_connection"); + + if (call->state != RXRPC_CALL_SERVER_PREALLOC) { + timeout = READ_ONCE(call->expect_rx_by); + timeout -= jiffies; + } + + tx_hard_ack = READ_ONCE(call->tx_hard_ack); + rx_hard_ack = READ_ONCE(call->rx_hard_ack); + seq_printf(seq, + "UDP %-47.47s %-47.47s %4x %08x %08x %s %3u" + " %-8.8s %08x %08x %08x %02x %08x %02x %08x %06lx\n", + lbuff, + rbuff, + call->service_id, + call->cid, + call->call_id, + rxrpc_is_service_call(call) ? 
"Svc" : "Clt", + refcount_read(&call->ref), + rxrpc_call_states[call->state], + call->abort_code, + call->debug_id, + tx_hard_ack, READ_ONCE(call->tx_top) - tx_hard_ack, + rx_hard_ack, READ_ONCE(call->rx_top) - rx_hard_ack, + call->rx_serial, + timeout); + + return 0; +} + +const struct seq_operations rxrpc_call_seq_ops = { + .start = rxrpc_call_seq_start, + .next = rxrpc_call_seq_next, + .stop = rxrpc_call_seq_stop, + .show = rxrpc_call_seq_show, +}; + +/* + * generate a list of extant virtual connections in /proc/net/rxrpc_conns + */ +static void *rxrpc_connection_seq_start(struct seq_file *seq, loff_t *_pos) + __acquires(rxnet->conn_lock) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + read_lock(&rxnet->conn_lock); + return seq_list_start_head(&rxnet->conn_proc_list, *_pos); +} + +static void *rxrpc_connection_seq_next(struct seq_file *seq, void *v, + loff_t *pos) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + return seq_list_next(v, &rxnet->conn_proc_list, pos); +} + +static void rxrpc_connection_seq_stop(struct seq_file *seq, void *v) + __releases(rxnet->conn_lock) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + read_unlock(&rxnet->conn_lock); +} + +static int rxrpc_connection_seq_show(struct seq_file *seq, void *v) +{ + struct rxrpc_connection *conn; + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + char lbuff[50], rbuff[50]; + + if (v == &rxnet->conn_proc_list) { + seq_puts(seq, + "Proto Local " + " Remote " + " SvID ConnID End Use State Key " + " Serial ISerial CallId0 CallId1 CallId2 CallId3\n" + ); + return 0; + } + + conn = list_entry(v, struct rxrpc_connection, proc_link); + if (conn->state == RXRPC_CONN_SERVICE_PREALLOC) { + strcpy(lbuff, "no_local"); + strcpy(rbuff, "no_connection"); + goto print; + } + + sprintf(lbuff, "%pISpc", &conn->params.local->srx.transport); + + sprintf(rbuff, "%pISpc", &conn->params.peer->srx.transport); +print: + seq_printf(seq, + "UDP %-47.47s %-47.47s %4x %08x %s %3u" + " %s %08x %08x %08x %08x %08x %08x %08x\n", + lbuff, + rbuff, + conn->service_id, + conn->proto.cid, + rxrpc_conn_is_service(conn) ? 
"Svc" : "Clt", + refcount_read(&conn->ref), + rxrpc_conn_states[conn->state], + key_serial(conn->params.key), + atomic_read(&conn->serial), + conn->hi_serial, + conn->channels[0].call_id, + conn->channels[1].call_id, + conn->channels[2].call_id, + conn->channels[3].call_id); + + return 0; +} + +const struct seq_operations rxrpc_connection_seq_ops = { + .start = rxrpc_connection_seq_start, + .next = rxrpc_connection_seq_next, + .stop = rxrpc_connection_seq_stop, + .show = rxrpc_connection_seq_show, +}; + +/* + * generate a list of extant virtual peers in /proc/net/rxrpc/peers + */ +static int rxrpc_peer_seq_show(struct seq_file *seq, void *v) +{ + struct rxrpc_peer *peer; + time64_t now; + char lbuff[50], rbuff[50]; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Proto Local " + " Remote " + " Use CW MTU LastUse RTT RTO\n" + ); + return 0; + } + + peer = list_entry(v, struct rxrpc_peer, hash_link); + + sprintf(lbuff, "%pISpc", &peer->local->srx.transport); + + sprintf(rbuff, "%pISpc", &peer->srx.transport); + + now = ktime_get_seconds(); + seq_printf(seq, + "UDP %-47.47s %-47.47s %3u" + " %3u %5u %6llus %8u %8u\n", + lbuff, + rbuff, + refcount_read(&peer->ref), + peer->cong_cwnd, + peer->mtu, + now - peer->last_tx_at, + peer->srtt_us >> 3, + jiffies_to_usecs(peer->rto_j)); + + return 0; +} + +static void *rxrpc_peer_seq_start(struct seq_file *seq, loff_t *_pos) + __acquires(rcu) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned int bucket, n; + unsigned int shift = 32 - HASH_BITS(rxnet->peer_hash); + void *p; + + rcu_read_lock(); + + if (*_pos >= UINT_MAX) + return NULL; + + n = *_pos & ((1U << shift) - 1); + bucket = *_pos >> shift; + for (;;) { + if (bucket >= HASH_SIZE(rxnet->peer_hash)) { + *_pos = UINT_MAX; + return NULL; + } + if (n == 0) { + if (bucket == 0) + return SEQ_START_TOKEN; + *_pos += 1; + n++; + } + + p = seq_hlist_start_rcu(&rxnet->peer_hash[bucket], n - 1); + if (p) + return p; + bucket++; + n = 1; + *_pos = (bucket << shift) | n; + } +} + +static void *rxrpc_peer_seq_next(struct seq_file *seq, void *v, loff_t *_pos) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned int bucket, n; + unsigned int shift = 32 - HASH_BITS(rxnet->peer_hash); + void *p; + + if (*_pos >= UINT_MAX) + return NULL; + + bucket = *_pos >> shift; + + p = seq_hlist_next_rcu(v, &rxnet->peer_hash[bucket], _pos); + if (p) + return p; + + for (;;) { + bucket++; + n = 1; + *_pos = (bucket << shift) | n; + + if (bucket >= HASH_SIZE(rxnet->peer_hash)) { + *_pos = UINT_MAX; + return NULL; + } + if (n == 0) { + *_pos += 1; + n++; + } + + p = seq_hlist_start_rcu(&rxnet->peer_hash[bucket], n - 1); + if (p) + return p; + } +} + +static void rxrpc_peer_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + + +const struct seq_operations rxrpc_peer_seq_ops = { + .start = rxrpc_peer_seq_start, + .next = rxrpc_peer_seq_next, + .stop = rxrpc_peer_seq_stop, + .show = rxrpc_peer_seq_show, +}; + +/* + * Generate a list of extant virtual local endpoints in /proc/net/rxrpc/locals + */ +static int rxrpc_local_seq_show(struct seq_file *seq, void *v) +{ + struct rxrpc_local *local; + char lbuff[50]; + + if (v == SEQ_START_TOKEN) { + seq_puts(seq, + "Proto Local " + " Use Act\n"); + return 0; + } + + local = hlist_entry(v, struct rxrpc_local, link); + + sprintf(lbuff, "%pISpc", &local->srx.transport); + + seq_printf(seq, + "UDP %-47.47s %3u %3u\n", + lbuff, + refcount_read(&local->ref), + atomic_read(&local->active_users)); + + return 0; +} + 
+static void *rxrpc_local_seq_start(struct seq_file *seq, loff_t *_pos) + __acquires(rcu) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + unsigned int n; + + rcu_read_lock(); + + if (*_pos >= UINT_MAX) + return NULL; + + n = *_pos; + if (n == 0) + return SEQ_START_TOKEN; + + return seq_hlist_start_rcu(&rxnet->local_endpoints, n - 1); +} + +static void *rxrpc_local_seq_next(struct seq_file *seq, void *v, loff_t *_pos) +{ + struct rxrpc_net *rxnet = rxrpc_net(seq_file_net(seq)); + + if (*_pos >= UINT_MAX) + return NULL; + + return seq_hlist_next_rcu(v, &rxnet->local_endpoints, _pos); +} + +static void rxrpc_local_seq_stop(struct seq_file *seq, void *v) + __releases(rcu) +{ + rcu_read_unlock(); +} + +const struct seq_operations rxrpc_local_seq_ops = { + .start = rxrpc_local_seq_start, + .next = rxrpc_local_seq_next, + .stop = rxrpc_local_seq_stop, + .show = rxrpc_local_seq_show, +}; diff --git a/net/rxrpc/protocol.h b/net/rxrpc/protocol.h new file mode 100644 index 000000000..d2cf8e1d2 --- /dev/null +++ b/net/rxrpc/protocol.h @@ -0,0 +1,186 @@ +/* SPDX-License-Identifier: GPL-2.0-or-later */ +/* packet.h: Rx packet layout and definitions + * + * Copyright (C) 2002, 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#ifndef _LINUX_RXRPC_PACKET_H +#define _LINUX_RXRPC_PACKET_H + +typedef u32 rxrpc_seq_t; /* Rx message sequence number */ +typedef u32 rxrpc_serial_t; /* Rx message serial number */ +typedef __be32 rxrpc_seq_net_t; /* on-the-wire Rx message sequence number */ +typedef __be32 rxrpc_serial_net_t; /* on-the-wire Rx message serial number */ + +/*****************************************************************************/ +/* + * on-the-wire Rx packet header + * - all multibyte fields should be in network byte order + */ +struct rxrpc_wire_header { + __be32 epoch; /* client boot timestamp */ +#define RXRPC_RANDOM_EPOCH 0x80000000 /* Random if set, date-based if not */ + + __be32 cid; /* connection and channel ID */ +#define RXRPC_MAXCALLS 4 /* max active calls per conn */ +#define RXRPC_CHANNELMASK (RXRPC_MAXCALLS-1) /* mask for channel ID */ +#define RXRPC_CIDMASK (~RXRPC_CHANNELMASK) /* mask for connection ID */ +#define RXRPC_CIDSHIFT ilog2(RXRPC_MAXCALLS) /* shift for connection ID */ +#define RXRPC_CID_INC (1 << RXRPC_CIDSHIFT) /* connection ID increment */ + + __be32 callNumber; /* call ID (0 for connection-level packets) */ + __be32 seq; /* sequence number of pkt in call stream */ + __be32 serial; /* serial number of pkt sent to network */ + + uint8_t type; /* packet type */ +#define RXRPC_PACKET_TYPE_DATA 1 /* data */ +#define RXRPC_PACKET_TYPE_ACK 2 /* ACK */ +#define RXRPC_PACKET_TYPE_BUSY 3 /* call reject */ +#define RXRPC_PACKET_TYPE_ABORT 4 /* call/connection abort */ +#define RXRPC_PACKET_TYPE_ACKALL 5 /* ACK all outstanding packets on call */ +#define RXRPC_PACKET_TYPE_CHALLENGE 6 /* connection security challenge (SRVR->CLNT) */ +#define RXRPC_PACKET_TYPE_RESPONSE 7 /* connection secutity response (CLNT->SRVR) */ +#define RXRPC_PACKET_TYPE_DEBUG 8 /* debug info request */ +#define RXRPC_PACKET_TYPE_PARAMS 9 /* Parameter negotiation (unspec'd, ignore) */ +#define RXRPC_PACKET_TYPE_10 10 /* Ignored */ +#define RXRPC_PACKET_TYPE_11 11 /* Ignored */ +#define RXRPC_PACKET_TYPE_VERSION 13 /* version string request */ + + uint8_t flags; /* packet flags */ +#define RXRPC_CLIENT_INITIATED 0x01 /* signifies a packet generated by a client */ +#define RXRPC_REQUEST_ACK 0x02 /* request an unconditional ACK of this packet 
*/ +#define RXRPC_LAST_PACKET 0x04 /* the last packet from this side for this call */ +#define RXRPC_MORE_PACKETS 0x08 /* more packets to come */ +#define RXRPC_JUMBO_PACKET 0x20 /* [DATA] this is a jumbo packet */ +#define RXRPC_SLOW_START_OK 0x20 /* [ACK] slow start supported */ + + uint8_t userStatus; /* app-layer defined status */ +#define RXRPC_USERSTATUS_SERVICE_UPGRADE 0x01 /* AuriStor service upgrade request */ + + uint8_t securityIndex; /* security protocol ID */ + union { + __be16 _rsvd; /* reserved */ + __be16 cksum; /* kerberos security checksum */ + }; + __be16 serviceId; /* service ID */ + +} __packed; + +/*****************************************************************************/ +/* + * jumbo packet secondary header + * - can be mapped to read header by: + * - new_serial = serial + 1 + * - new_seq = seq + 1 + * - new_flags = j_flags + * - new__rsvd = j__rsvd + * - duplicating all other fields + */ +struct rxrpc_jumbo_header { + uint8_t flags; /* packet flags (as per rxrpc_header) */ + uint8_t pad; + union { + __be16 _rsvd; /* reserved */ + __be16 cksum; /* kerberos security checksum */ + }; +}; + +#define RXRPC_JUMBO_DATALEN 1412 /* non-terminal jumbo packet data length */ +#define RXRPC_JUMBO_SUBPKTLEN (RXRPC_JUMBO_DATALEN + sizeof(struct rxrpc_jumbo_header)) + +/* + * The maximum number of subpackets that can possibly fit in a UDP packet is: + * + * ((max_IP - IP_hdr - UDP_hdr) / RXRPC_JUMBO_SUBPKTLEN) + 1 + * = ((65535 - 28 - 28) / 1416) + 1 + * = 46 non-terminal packets and 1 terminal packet. + */ +#define RXRPC_MAX_NR_JUMBO 47 + +/*****************************************************************************/ +/* + * on-the-wire Rx ACK packet data payload + * - all multibyte fields should be in network byte order + */ +struct rxrpc_ackpacket { + __be16 bufferSpace; /* number of packet buffers available */ + __be16 maxSkew; /* diff between serno being ACK'd and highest serial no + * received */ + __be32 firstPacket; /* sequence no of first ACK'd packet in attached list */ + __be32 previousPacket; /* sequence no of previous packet received */ + __be32 serial; /* serial no of packet that prompted this ACK */ + + uint8_t reason; /* reason for ACK */ +#define RXRPC_ACK_REQUESTED 1 /* ACK was requested on packet */ +#define RXRPC_ACK_DUPLICATE 2 /* duplicate packet received */ +#define RXRPC_ACK_OUT_OF_SEQUENCE 3 /* out of sequence packet received */ +#define RXRPC_ACK_EXCEEDS_WINDOW 4 /* packet received beyond end of ACK window */ +#define RXRPC_ACK_NOSPACE 5 /* packet discarded due to lack of buffer space */ +#define RXRPC_ACK_PING 6 /* keep alive ACK */ +#define RXRPC_ACK_PING_RESPONSE 7 /* response to RXRPC_ACK_PING */ +#define RXRPC_ACK_DELAY 8 /* nothing happened since received packet */ +#define RXRPC_ACK_IDLE 9 /* ACK due to fully received ACK window */ +#define RXRPC_ACK__INVALID 10 /* Representation of invalid ACK reason */ + + uint8_t nAcks; /* number of ACKs */ +#define RXRPC_MAXACKS 255 + + uint8_t acks[0]; /* list of ACK/NAKs */ +#define RXRPC_ACK_TYPE_NACK 0 +#define RXRPC_ACK_TYPE_ACK 1 + +} __packed; + +/* Some ACKs refer to specific packets and some are general and can be updated. 
*/ +#define RXRPC_ACK_UPDATEABLE ((1 << RXRPC_ACK_REQUESTED) | \ + (1 << RXRPC_ACK_PING_RESPONSE) | \ + (1 << RXRPC_ACK_DELAY) | \ + (1 << RXRPC_ACK_IDLE)) + + +/* + * ACK packets can have a further piece of information tagged on the end + */ +struct rxrpc_ackinfo { + __be32 rxMTU; /* maximum Rx MTU size (bytes) [AFS 3.3] */ + __be32 maxMTU; /* maximum interface MTU size (bytes) [AFS 3.3] */ + __be32 rwind; /* Rx window size (packets) [AFS 3.4] */ + __be32 jumbo_max; /* max packets to stick into a jumbo packet [AFS 3.5] */ +}; + +/*****************************************************************************/ +/* + * Kerberos security type-2 challenge packet + */ +struct rxkad_challenge { + __be32 version; /* version of this challenge type */ + __be32 nonce; /* encrypted random number */ + __be32 min_level; /* minimum security level */ + __be32 __padding; /* padding to 8-byte boundary */ +} __packed; + +/*****************************************************************************/ +/* + * Kerberos security type-2 response packet + */ +struct rxkad_response { + __be32 version; /* version of this response type */ + __be32 __pad; + + /* encrypted bit of the response */ + struct { + __be32 epoch; /* current epoch */ + __be32 cid; /* parent connection ID */ + __be32 checksum; /* checksum */ + __be32 securityIndex; /* security type */ + __be32 call_id[4]; /* encrypted call IDs */ + __be32 inc_nonce; /* challenge nonce + 1 */ + __be32 level; /* desired level */ + } encrypted; + + __be32 kvno; /* Kerberos key version number */ + __be32 ticket_len; /* Kerberos ticket length */ +} __packed; + +#endif /* _LINUX_RXRPC_PACKET_H */ diff --git a/net/rxrpc/recvmsg.c b/net/rxrpc/recvmsg.c new file mode 100644 index 000000000..7e39c262f --- /dev/null +++ b/net/rxrpc/recvmsg.c @@ -0,0 +1,773 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC recvmsg() implementation + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include +#include "ar-internal.h" + +/* + * Post a call for attention by the socket or kernel service. Further + * notifications are suppressed by putting recvmsg_link on a dummy queue. + */ +void rxrpc_notify_socket(struct rxrpc_call *call) +{ + struct rxrpc_sock *rx; + struct sock *sk; + + _enter("%d", call->debug_id); + + if (!list_empty(&call->recvmsg_link)) + return; + + rcu_read_lock(); + + rx = rcu_dereference(call->socket); + sk = &rx->sk; + if (rx && sk->sk_state < RXRPC_CLOSE) { + if (call->notify_rx) { + spin_lock_bh(&call->notify_lock); + call->notify_rx(sk, call, call->user_call_ID); + spin_unlock_bh(&call->notify_lock); + } else { + write_lock_bh(&rx->recvmsg_lock); + if (list_empty(&call->recvmsg_link)) { + rxrpc_get_call(call, rxrpc_call_got); + list_add_tail(&call->recvmsg_link, &rx->recvmsg_q); + } + write_unlock_bh(&rx->recvmsg_lock); + + if (!sock_flag(sk, SOCK_DEAD)) { + _debug("call %ps", sk->sk_data_ready); + sk->sk_data_ready(sk); + } + } + } + + rcu_read_unlock(); + _leave(""); +} + +/* + * Transition a call to the complete state. 
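+ *
+ * Only the first transition takes effect: the completion reason, abort code
+ * and error are recorded, anyone waiting on the call is woken and the socket
+ * is notified; if the call is already complete, nothing is changed and false
+ * is returned.  The __-prefixed variant expects the caller to hold the call's
+ * state lock; rxrpc_set_call_completion() takes it itself.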
+ */ +bool __rxrpc_set_call_completion(struct rxrpc_call *call, + enum rxrpc_call_completion compl, + u32 abort_code, + int error) +{ + if (call->state < RXRPC_CALL_COMPLETE) { + call->abort_code = abort_code; + call->error = error; + call->completion = compl; + call->state = RXRPC_CALL_COMPLETE; + trace_rxrpc_call_complete(call); + wake_up(&call->waitq); + rxrpc_notify_socket(call); + return true; + } + return false; +} + +bool rxrpc_set_call_completion(struct rxrpc_call *call, + enum rxrpc_call_completion compl, + u32 abort_code, + int error) +{ + bool ret = false; + + if (call->state < RXRPC_CALL_COMPLETE) { + write_lock_bh(&call->state_lock); + ret = __rxrpc_set_call_completion(call, compl, abort_code, error); + write_unlock_bh(&call->state_lock); + } + return ret; +} + +/* + * Record that a call successfully completed. + */ +bool __rxrpc_call_completed(struct rxrpc_call *call) +{ + return __rxrpc_set_call_completion(call, RXRPC_CALL_SUCCEEDED, 0, 0); +} + +bool rxrpc_call_completed(struct rxrpc_call *call) +{ + bool ret = false; + + if (call->state < RXRPC_CALL_COMPLETE) { + write_lock_bh(&call->state_lock); + ret = __rxrpc_call_completed(call); + write_unlock_bh(&call->state_lock); + } + return ret; +} + +/* + * Record that a call is locally aborted. + */ +bool __rxrpc_abort_call(const char *why, struct rxrpc_call *call, + rxrpc_seq_t seq, u32 abort_code, int error) +{ + trace_rxrpc_abort(call->debug_id, why, call->cid, call->call_id, seq, + abort_code, error); + return __rxrpc_set_call_completion(call, RXRPC_CALL_LOCALLY_ABORTED, + abort_code, error); +} + +bool rxrpc_abort_call(const char *why, struct rxrpc_call *call, + rxrpc_seq_t seq, u32 abort_code, int error) +{ + bool ret; + + write_lock_bh(&call->state_lock); + ret = __rxrpc_abort_call(why, call, seq, abort_code, error); + write_unlock_bh(&call->state_lock); + return ret; +} + +/* + * Pass a call terminating message to userspace. + */ +static int rxrpc_recvmsg_term(struct rxrpc_call *call, struct msghdr *msg) +{ + u32 tmp = 0; + int ret; + + switch (call->completion) { + case RXRPC_CALL_SUCCEEDED: + ret = 0; + if (rxrpc_is_service_call(call)) + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_ACK, 0, &tmp); + break; + case RXRPC_CALL_REMOTELY_ABORTED: + tmp = call->abort_code; + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_ABORT, 4, &tmp); + break; + case RXRPC_CALL_LOCALLY_ABORTED: + tmp = call->abort_code; + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_ABORT, 4, &tmp); + break; + case RXRPC_CALL_NETWORK_ERROR: + tmp = -call->error; + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_NET_ERROR, 4, &tmp); + break; + case RXRPC_CALL_LOCAL_ERROR: + tmp = -call->error; + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_LOCAL_ERROR, 4, &tmp); + break; + default: + pr_err("Invalid terminal call state %u\n", call->state); + BUG(); + break; + } + + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_terminal, call->rx_hard_ack, + call->rx_pkt_offset, call->rx_pkt_len, ret); + return ret; +} + +/* + * End the packet reception phase. 
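+ *
+ * This is only reached once the receive window has been fully consumed, i.e.
+ * when rx_hard_ack has caught up with rx_top (asserted below).  A client call
+ * is then marked as completed; a server call moves on to acknowledging the
+ * request it has just received.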
+ */ +static void rxrpc_end_rx_phase(struct rxrpc_call *call, rxrpc_serial_t serial) +{ + _enter("%d,%s", call->debug_id, rxrpc_call_states[call->state]); + + trace_rxrpc_receive(call, rxrpc_receive_end, 0, call->rx_top); + ASSERTCMP(call->rx_hard_ack, ==, call->rx_top); + + if (call->state == RXRPC_CALL_CLIENT_RECV_REPLY) { + rxrpc_propose_ACK(call, RXRPC_ACK_IDLE, serial, false, true, + rxrpc_propose_ack_terminal_ack); + //rxrpc_send_ack_packet(call, false, NULL); + } + + write_lock_bh(&call->state_lock); + + switch (call->state) { + case RXRPC_CALL_CLIENT_RECV_REPLY: + __rxrpc_call_completed(call); + write_unlock_bh(&call->state_lock); + break; + + case RXRPC_CALL_SERVER_RECV_REQUEST: + call->tx_phase = true; + call->state = RXRPC_CALL_SERVER_ACK_REQUEST; + call->expect_req_by = jiffies + MAX_JIFFY_OFFSET; + write_unlock_bh(&call->state_lock); + rxrpc_propose_ACK(call, RXRPC_ACK_DELAY, serial, false, true, + rxrpc_propose_ack_processing_op); + break; + default: + write_unlock_bh(&call->state_lock); + break; + } +} + +/* + * Discard a packet we've used up and advance the Rx window by one. + */ +static void rxrpc_rotate_rx_window(struct rxrpc_call *call) +{ + struct rxrpc_skb_priv *sp; + struct sk_buff *skb; + rxrpc_serial_t serial; + rxrpc_seq_t hard_ack, top; + bool last = false; + u8 subpacket; + int ix; + + _enter("%d", call->debug_id); + + hard_ack = call->rx_hard_ack; + top = smp_load_acquire(&call->rx_top); + ASSERT(before(hard_ack, top)); + + hard_ack++; + ix = hard_ack & RXRPC_RXTX_BUFF_MASK; + skb = call->rxtx_buffer[ix]; + rxrpc_see_skb(skb, rxrpc_skb_rotated); + sp = rxrpc_skb(skb); + + subpacket = call->rxtx_annotations[ix] & RXRPC_RX_ANNO_SUBPACKET; + serial = sp->hdr.serial + subpacket; + + if (subpacket == sp->nr_subpackets - 1 && + sp->rx_flags & RXRPC_SKB_INCL_LAST) + last = true; + + call->rxtx_buffer[ix] = NULL; + call->rxtx_annotations[ix] = 0; + /* Barrier against rxrpc_input_data(). */ + smp_store_release(&call->rx_hard_ack, hard_ack); + + rxrpc_free_skb(skb, rxrpc_skb_freed); + + trace_rxrpc_receive(call, rxrpc_receive_rotate, serial, hard_ack); + if (last) { + rxrpc_end_rx_phase(call, serial); + } else { + /* Check to see if there's an ACK that needs sending. */ + if (atomic_inc_return(&call->ackr_nr_consumed) > 2) + rxrpc_propose_ACK(call, RXRPC_ACK_IDLE, serial, + true, false, + rxrpc_propose_ack_rotate_rx); + if (call->ackr_reason && call->ackr_reason != RXRPC_ACK_DELAY) + rxrpc_send_ack_packet(call, false, NULL); + } +} + +/* + * Decrypt and verify a (sub)packet. The packet's length may be changed due to + * padding, but if this is the case, the packet length will be resident in the + * socket buffer. Note that we can't modify the master skb info as the skb may + * be the home to multiple subpackets. + */ +static int rxrpc_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, + u8 annotation, + unsigned int offset, unsigned int len) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + rxrpc_seq_t seq = sp->hdr.seq; + u16 cksum = sp->hdr.cksum; + u8 subpacket = annotation & RXRPC_RX_ANNO_SUBPACKET; + + _enter(""); + + /* For all but the head jumbo subpacket, the security checksum is in a + * jumbo header immediately prior to the data. + */ + if (subpacket > 0) { + __be16 tmp; + if (skb_copy_bits(skb, offset - 2, &tmp, 2) < 0) + BUG(); + cksum = ntohs(tmp); + seq += subpacket; + } + + return call->security->verify_packet(call, skb, offset, len, + seq, cksum); +} + +/* + * Locate the data within a packet. 
This is complicated by: + * + * (1) An skb may contain a jumbo packet - so we have to find the appropriate + * subpacket. + * + * (2) The (sub)packets may be encrypted and, if so, the encrypted portion + * contains an extra header which includes the true length of the data, + * excluding any encrypted padding. + */ +static int rxrpc_locate_data(struct rxrpc_call *call, struct sk_buff *skb, + u8 *_annotation, + unsigned int *_offset, unsigned int *_len, + bool *_last) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + unsigned int offset = sizeof(struct rxrpc_wire_header); + unsigned int len; + bool last = false; + int ret; + u8 annotation = *_annotation; + u8 subpacket = annotation & RXRPC_RX_ANNO_SUBPACKET; + + /* Locate the subpacket */ + offset += subpacket * RXRPC_JUMBO_SUBPKTLEN; + len = skb->len - offset; + if (subpacket < sp->nr_subpackets - 1) + len = RXRPC_JUMBO_DATALEN; + else if (sp->rx_flags & RXRPC_SKB_INCL_LAST) + last = true; + + if (!(annotation & RXRPC_RX_ANNO_VERIFIED)) { + ret = rxrpc_verify_packet(call, skb, annotation, offset, len); + if (ret < 0) + return ret; + *_annotation |= RXRPC_RX_ANNO_VERIFIED; + } + + *_offset = offset; + *_len = len; + *_last = last; + call->security->locate_data(call, skb, _offset, _len); + return 0; +} + +/* + * Deliver messages to a call. This keeps processing packets until the buffer + * is filled and we find either more DATA (returns 0) or the end of the DATA + * (returns 1). If more packets are required, it returns -EAGAIN. + */ +static int rxrpc_recvmsg_data(struct socket *sock, struct rxrpc_call *call, + struct msghdr *msg, struct iov_iter *iter, + size_t len, int flags, size_t *_offset) +{ + struct rxrpc_skb_priv *sp; + struct sk_buff *skb; + rxrpc_serial_t serial; + rxrpc_seq_t hard_ack, top, seq; + size_t remain; + bool rx_pkt_last; + unsigned int rx_pkt_offset, rx_pkt_len; + int ix, copy, ret = -EAGAIN, ret2; + + if (test_and_clear_bit(RXRPC_CALL_RX_UNDERRUN, &call->flags) && + call->ackr_reason) + rxrpc_send_ack_packet(call, false, NULL); + + rx_pkt_offset = call->rx_pkt_offset; + rx_pkt_len = call->rx_pkt_len; + rx_pkt_last = call->rx_pkt_last; + + if (call->state >= RXRPC_CALL_SERVER_ACK_REQUEST) { + seq = call->rx_hard_ack; + ret = 1; + goto done; + } + + /* Barriers against rxrpc_input_data(). */ + hard_ack = call->rx_hard_ack; + seq = hard_ack + 1; + + while (top = smp_load_acquire(&call->rx_top), + before_eq(seq, top) + ) { + ix = seq & RXRPC_RXTX_BUFF_MASK; + skb = call->rxtx_buffer[ix]; + if (!skb) { + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_hole, seq, + rx_pkt_offset, rx_pkt_len, 0); + break; + } + smp_rmb(); + rxrpc_see_skb(skb, rxrpc_skb_seen); + sp = rxrpc_skb(skb); + + if (!(flags & MSG_PEEK)) { + serial = sp->hdr.serial; + serial += call->rxtx_annotations[ix] & RXRPC_RX_ANNO_SUBPACKET; + trace_rxrpc_receive(call, rxrpc_receive_front, + serial, seq); + } + + if (msg) + sock_recv_timestamp(msg, sock->sk, skb); + + if (rx_pkt_offset == 0) { + ret2 = rxrpc_locate_data(call, skb, + &call->rxtx_annotations[ix], + &rx_pkt_offset, &rx_pkt_len, + &rx_pkt_last); + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_next, seq, + rx_pkt_offset, rx_pkt_len, ret2); + if (ret2 < 0) { + ret = ret2; + goto out; + } + } else { + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_cont, seq, + rx_pkt_offset, rx_pkt_len, 0); + } + + /* We have to handle short, empty and used-up DATA packets. 
*/ + remain = len - *_offset; + copy = rx_pkt_len; + if (copy > remain) + copy = remain; + if (copy > 0) { + ret2 = skb_copy_datagram_iter(skb, rx_pkt_offset, iter, + copy); + if (ret2 < 0) { + ret = ret2; + goto out; + } + + /* handle piecemeal consumption of data packets */ + rx_pkt_offset += copy; + rx_pkt_len -= copy; + *_offset += copy; + } + + if (rx_pkt_len > 0) { + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_full, seq, + rx_pkt_offset, rx_pkt_len, 0); + ASSERTCMP(*_offset, ==, len); + ret = 0; + break; + } + + /* The whole packet has been transferred. */ + if (!(flags & MSG_PEEK)) + rxrpc_rotate_rx_window(call); + rx_pkt_offset = 0; + rx_pkt_len = 0; + + if (rx_pkt_last) { + ASSERTCMP(seq, ==, READ_ONCE(call->rx_top)); + ret = 1; + goto out; + } + + seq++; + } + +out: + if (!(flags & MSG_PEEK)) { + call->rx_pkt_offset = rx_pkt_offset; + call->rx_pkt_len = rx_pkt_len; + call->rx_pkt_last = rx_pkt_last; + } +done: + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_data_return, seq, + rx_pkt_offset, rx_pkt_len, ret); + if (ret == -EAGAIN) + set_bit(RXRPC_CALL_RX_UNDERRUN, &call->flags); + return ret; +} + +/* + * Receive a message from an RxRPC socket + * - we need to be careful about two or more threads calling recvmsg + * simultaneously + */ +int rxrpc_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct rxrpc_call *call; + struct rxrpc_sock *rx = rxrpc_sk(sock->sk); + struct list_head *l; + size_t copied = 0; + long timeo; + int ret; + + DEFINE_WAIT(wait); + + trace_rxrpc_recvmsg(NULL, rxrpc_recvmsg_enter, 0, 0, 0, 0); + + if (flags & (MSG_OOB | MSG_TRUNC)) + return -EOPNOTSUPP; + + timeo = sock_rcvtimeo(&rx->sk, flags & MSG_DONTWAIT); + +try_again: + lock_sock(&rx->sk); + + /* Return immediately if a client socket has no outstanding calls */ + if (RB_EMPTY_ROOT(&rx->calls) && + list_empty(&rx->recvmsg_q) && + rx->sk.sk_state != RXRPC_SERVER_LISTENING) { + release_sock(&rx->sk); + return -EAGAIN; + } + + if (list_empty(&rx->recvmsg_q)) { + ret = -EWOULDBLOCK; + if (timeo == 0) { + call = NULL; + goto error_no_call; + } + + release_sock(&rx->sk); + + /* Wait for something to happen */ + prepare_to_wait_exclusive(sk_sleep(&rx->sk), &wait, + TASK_INTERRUPTIBLE); + ret = sock_error(&rx->sk); + if (ret) + goto wait_error; + + if (list_empty(&rx->recvmsg_q)) { + if (signal_pending(current)) + goto wait_interrupted; + trace_rxrpc_recvmsg(NULL, rxrpc_recvmsg_wait, + 0, 0, 0, 0); + timeo = schedule_timeout(timeo); + } + finish_wait(sk_sleep(&rx->sk), &wait); + goto try_again; + } + + /* Find the next call and dequeue it if we're not just peeking. If we + * do dequeue it, that comes with a ref that we will need to release. + */ + write_lock_bh(&rx->recvmsg_lock); + l = rx->recvmsg_q.next; + call = list_entry(l, struct rxrpc_call, recvmsg_link); + if (!(flags & MSG_PEEK)) + list_del_init(&call->recvmsg_link); + else + rxrpc_get_call(call, rxrpc_call_got); + write_unlock_bh(&rx->recvmsg_lock); + + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_dequeue, 0, 0, 0, 0); + + /* We're going to drop the socket lock, so we need to lock the call + * against interference by sendmsg. 
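+ *
+ * A trylock is attempted first; if that fails, MSG_DONTWAIT callers give up
+ * with -EWOULDBLOCK and everyone else takes the mutex interruptibly.  On
+ * failure the call is handed back (requeued, or just released if we were
+ * only peeking).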
+ */ + if (!mutex_trylock(&call->user_mutex)) { + ret = -EWOULDBLOCK; + if (flags & MSG_DONTWAIT) + goto error_requeue_call; + ret = -ERESTARTSYS; + if (mutex_lock_interruptible(&call->user_mutex) < 0) + goto error_requeue_call; + } + + release_sock(&rx->sk); + + if (test_bit(RXRPC_CALL_RELEASED, &call->flags)) + BUG(); + + if (test_bit(RXRPC_CALL_HAS_USERID, &call->flags)) { + if (flags & MSG_CMSG_COMPAT) { + unsigned int id32 = call->user_call_ID; + + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_USER_CALL_ID, + sizeof(unsigned int), &id32); + } else { + unsigned long idl = call->user_call_ID; + + ret = put_cmsg(msg, SOL_RXRPC, RXRPC_USER_CALL_ID, + sizeof(unsigned long), &idl); + } + if (ret < 0) + goto error_unlock_call; + } + + if (msg->msg_name && call->peer) { + struct sockaddr_rxrpc *srx = msg->msg_name; + size_t len = sizeof(call->peer->srx); + + memcpy(msg->msg_name, &call->peer->srx, len); + srx->srx_service = call->service_id; + msg->msg_namelen = len; + } + + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_CLIENT_RECV_REPLY: + case RXRPC_CALL_SERVER_RECV_REQUEST: + case RXRPC_CALL_SERVER_ACK_REQUEST: + ret = rxrpc_recvmsg_data(sock, call, msg, &msg->msg_iter, len, + flags, &copied); + if (ret == -EAGAIN) + ret = 0; + + if (after(call->rx_top, call->rx_hard_ack) && + call->rxtx_buffer[(call->rx_hard_ack + 1) & RXRPC_RXTX_BUFF_MASK]) + rxrpc_notify_socket(call); + break; + default: + ret = 0; + break; + } + + if (ret < 0) + goto error_unlock_call; + + if (call->state == RXRPC_CALL_COMPLETE) { + ret = rxrpc_recvmsg_term(call, msg); + if (ret < 0) + goto error_unlock_call; + if (!(flags & MSG_PEEK)) + rxrpc_release_call(rx, call); + msg->msg_flags |= MSG_EOR; + ret = 1; + } + + if (ret == 0) + msg->msg_flags |= MSG_MORE; + else + msg->msg_flags &= ~MSG_MORE; + ret = copied; + +error_unlock_call: + mutex_unlock(&call->user_mutex); + rxrpc_put_call(call, rxrpc_call_put); + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_return, 0, 0, 0, ret); + return ret; + +error_requeue_call: + if (!(flags & MSG_PEEK)) { + write_lock_bh(&rx->recvmsg_lock); + list_add(&call->recvmsg_link, &rx->recvmsg_q); + write_unlock_bh(&rx->recvmsg_lock); + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_requeue, 0, 0, 0, 0); + } else { + rxrpc_put_call(call, rxrpc_call_put); + } +error_no_call: + release_sock(&rx->sk); +error_trace: + trace_rxrpc_recvmsg(call, rxrpc_recvmsg_return, 0, 0, 0, ret); + return ret; + +wait_interrupted: + ret = sock_intr_errno(timeo); +wait_error: + finish_wait(sk_sleep(&rx->sk), &wait); + call = NULL; + goto error_trace; +} + +/** + * rxrpc_kernel_recv_data - Allow a kernel service to receive data/info + * @sock: The socket that the call exists on + * @call: The call to send data through + * @iter: The buffer to receive into + * @_len: The amount of data we want to receive (decreased on return) + * @want_more: True if more data is expected to be read + * @_abort: Where the abort code is stored if -ECONNABORTED is returned + * @_service: Where to store the actual service ID (may be upgraded) + * + * Allow a kernel service to receive data and pick up information about the + * state of a call. Returns 0 if got what was asked for and there's more + * available, 1 if we got what was asked for and we're at the end of the data + * and -EAGAIN if we need more data. + * + * Note that we may return -EAGAIN to drain empty packets at the end of the + * data, even if we've already copied over the requested data. + * + * *_abort should also be initialised to 0. 
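+ *
+ * In-kernel users (the AFS filesystem, for instance) typically call this from
+ * their data-delivery handler each time the call is notified, treating
+ * -EAGAIN as "wait for more data to arrive" rather than as a failure.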
+ */ +int rxrpc_kernel_recv_data(struct socket *sock, struct rxrpc_call *call, + struct iov_iter *iter, size_t *_len, + bool want_more, u32 *_abort, u16 *_service) +{ + size_t offset = 0; + int ret; + + _enter("{%d,%s},%zu,%d", + call->debug_id, rxrpc_call_states[call->state], + *_len, want_more); + + ASSERTCMP(call->state, !=, RXRPC_CALL_SERVER_SECURING); + + mutex_lock(&call->user_mutex); + + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_CLIENT_RECV_REPLY: + case RXRPC_CALL_SERVER_RECV_REQUEST: + case RXRPC_CALL_SERVER_ACK_REQUEST: + ret = rxrpc_recvmsg_data(sock, call, NULL, iter, + *_len, 0, &offset); + *_len -= offset; + if (ret < 0) + goto out; + + /* We can only reach here with a partially full buffer if we + * have reached the end of the data. We must otherwise have a + * full buffer or have been given -EAGAIN. + */ + if (ret == 1) { + if (iov_iter_count(iter) > 0) + goto short_data; + if (!want_more) + goto read_phase_complete; + ret = 0; + goto out; + } + + if (!want_more) + goto excess_data; + goto out; + + case RXRPC_CALL_COMPLETE: + goto call_complete; + + default: + ret = -EINPROGRESS; + goto out; + } + +read_phase_complete: + ret = 1; +out: + switch (call->ackr_reason) { + case RXRPC_ACK_IDLE: + break; + case RXRPC_ACK_DELAY: + if (ret != -EAGAIN) + break; + fallthrough; + default: + rxrpc_send_ack_packet(call, false, NULL); + } + + if (_service) + *_service = call->service_id; + mutex_unlock(&call->user_mutex); + _leave(" = %d [%zu,%d]", ret, iov_iter_count(iter), *_abort); + return ret; + +short_data: + trace_rxrpc_rx_eproto(call, 0, tracepoint_string("short_data")); + ret = -EBADMSG; + goto out; +excess_data: + trace_rxrpc_rx_eproto(call, 0, tracepoint_string("excess_data")); + ret = -EMSGSIZE; + goto out; +call_complete: + *_abort = call->abort_code; + ret = call->error; + if (call->completion == RXRPC_CALL_SUCCEEDED) { + ret = 1; + if (iov_iter_count(iter) > 0) + ret = -ECONNRESET; + } + goto out; +} +EXPORT_SYMBOL(rxrpc_kernel_recv_data); diff --git a/net/rxrpc/rtt.c b/net/rxrpc/rtt.c new file mode 100644 index 000000000..be61d6f5b --- /dev/null +++ b/net/rxrpc/rtt.c @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: GPL-2.0 +/* RTT/RTO calculation. + * + * Adapted from TCP for AF_RXRPC by David Howells (dhowells@redhat.com) + * + * https://tools.ietf.org/html/rfc6298 + * https://tools.ietf.org/html/rfc1122#section-4.2.3.1 + * http://ccr.sigcomm.org/archive/1995/jan95/ccr-9501-partridge87.pdf + */ + +#include +#include "ar-internal.h" + +#define RXRPC_RTO_MAX ((unsigned)(120 * HZ)) +#define RXRPC_TIMEOUT_INIT ((unsigned)(1*HZ)) /* RFC6298 2.1 initial RTO value */ +#define rxrpc_jiffies32 ((u32)jiffies) /* As rxrpc_jiffies32 */ + +static u32 rxrpc_rto_min_us(struct rxrpc_peer *peer) +{ + return 200; +} + +static u32 __rxrpc_set_rto(const struct rxrpc_peer *peer) +{ + return usecs_to_jiffies((peer->srtt_us >> 3) + peer->rttvar_us); +} + +static u32 rxrpc_bound_rto(u32 rto) +{ + return min(rto, RXRPC_RTO_MAX); +} + +/* + * Called to compute a smoothed rtt estimate. The data fed to this + * routine either comes from timestamps, or from segments that were + * known _not_ to have been retransmitted [see Karn/Partridge + * Proceedings SIGCOMM 87]. The algorithm is from the SIGCOMM 88 + * piece by Van Jacobson. + * NOTE: the next three routines used to be one big routine. + * To save cycles in the RFC 1323 implementation it was better to break + * it up into three procedures. 
-- erics + */ +static void rxrpc_rtt_estimator(struct rxrpc_peer *peer, long sample_rtt_us) +{ + long m = sample_rtt_us; /* RTT */ + u32 srtt = peer->srtt_us; + + /* The following amusing code comes from Jacobson's + * article in SIGCOMM '88. Note that rtt and mdev + * are scaled versions of rtt and mean deviation. + * This is designed to be as fast as possible + * m stands for "measurement". + * + * On a 1990 paper the rto value is changed to: + * RTO = rtt + 4 * mdev + * + * Funny. This algorithm seems to be very broken. + * These formulae increase RTO, when it should be decreased, increase + * too slowly, when it should be increased quickly, decrease too quickly + * etc. I guess in BSD RTO takes ONE value, so that it is absolutely + * does not matter how to _calculate_ it. Seems, it was trap + * that VJ failed to avoid. 8) + */ + if (srtt != 0) { + m -= (srtt >> 3); /* m is now error in rtt est */ + srtt += m; /* rtt = 7/8 rtt + 1/8 new */ + if (m < 0) { + m = -m; /* m is now abs(error) */ + m -= (peer->mdev_us >> 2); /* similar update on mdev */ + /* This is similar to one of Eifel findings. + * Eifel blocks mdev updates when rtt decreases. + * This solution is a bit different: we use finer gain + * for mdev in this case (alpha*beta). + * Like Eifel it also prevents growth of rto, + * but also it limits too fast rto decreases, + * happening in pure Eifel. + */ + if (m > 0) + m >>= 3; + } else { + m -= (peer->mdev_us >> 2); /* similar update on mdev */ + } + + peer->mdev_us += m; /* mdev = 3/4 mdev + 1/4 new */ + if (peer->mdev_us > peer->mdev_max_us) { + peer->mdev_max_us = peer->mdev_us; + if (peer->mdev_max_us > peer->rttvar_us) + peer->rttvar_us = peer->mdev_max_us; + } + } else { + /* no previous measure. */ + srtt = m << 3; /* take the measured time to be rtt */ + peer->mdev_us = m << 1; /* make sure rto = 3*rtt */ + peer->rttvar_us = max(peer->mdev_us, rxrpc_rto_min_us(peer)); + peer->mdev_max_us = peer->rttvar_us; + } + + peer->srtt_us = max(1U, srtt); +} + +/* + * Calculate rto without backoff. This is the second half of Van Jacobson's + * routine referred to above. + */ +static void rxrpc_set_rto(struct rxrpc_peer *peer) +{ + u32 rto; + + /* 1. If rtt variance happened to be less 50msec, it is hallucination. + * It cannot be less due to utterly erratic ACK generation made + * at least by solaris and freebsd. "Erratic ACKs" has _nothing_ + * to do with delayed acks, because at cwnd>2 true delack timeout + * is invisible. Actually, Linux-2.4 also generates erratic + * ACKs in some circumstances. + */ + rto = __rxrpc_set_rto(peer); + + /* 2. Fixups made earlier cannot be right. + * If we do not estimate RTO correctly without them, + * all the algo is pure shit and should be replaced + * with correct one. It is exactly, which we pretend to do. + */ + + /* NOTE: clamping at RXRPC_RTO_MIN is not required, current algo + * guarantees that rto is higher. + */ + peer->rto_j = rxrpc_bound_rto(rto); +} + +static void rxrpc_ack_update_rtt(struct rxrpc_peer *peer, long rtt_us) +{ + if (rtt_us < 0) + return; + + //rxrpc_update_rtt_min(peer, rtt_us); + rxrpc_rtt_estimator(peer, rtt_us); + rxrpc_set_rto(peer); + + /* RFC6298: only reset backoff on valid RTT measurement. */ + peer->backoff = 0; +} + +/* + * Add RTT information to cache. This is called in softirq mode and has + * exclusive access to the peer RTT data. 
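+ *
+ * Each valid sample feeds the Jacobson estimator above, after which the RTO
+ * is recomputed as usecs_to_jiffies((srtt_us >> 3) + rttvar_us), bounded by
+ * RXRPC_RTO_MAX, and the retransmission backoff is reset.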
+ */ +void rxrpc_peer_add_rtt(struct rxrpc_call *call, enum rxrpc_rtt_rx_trace why, + int rtt_slot, + rxrpc_serial_t send_serial, rxrpc_serial_t resp_serial, + ktime_t send_time, ktime_t resp_time) +{ + struct rxrpc_peer *peer = call->peer; + s64 rtt_us; + + rtt_us = ktime_to_us(ktime_sub(resp_time, send_time)); + if (rtt_us < 0) + return; + + spin_lock(&peer->rtt_input_lock); + rxrpc_ack_update_rtt(peer, rtt_us); + if (peer->rtt_count < 3) + peer->rtt_count++; + spin_unlock(&peer->rtt_input_lock); + + trace_rxrpc_rtt_rx(call, why, rtt_slot, send_serial, resp_serial, + peer->srtt_us >> 3, peer->rto_j); +} + +/* + * Get the retransmission timeout to set in jiffies, backing it off each time + * we retransmit. + */ +unsigned long rxrpc_get_rto_backoff(struct rxrpc_peer *peer, bool retrans) +{ + u64 timo_j; + u8 backoff = READ_ONCE(peer->backoff); + + timo_j = peer->rto_j; + timo_j <<= backoff; + if (retrans && timo_j * 2 <= RXRPC_RTO_MAX) + WRITE_ONCE(peer->backoff, backoff + 1); + + if (timo_j < 1) + timo_j = 1; + + return timo_j; +} + +void rxrpc_peer_init_rtt(struct rxrpc_peer *peer) +{ + peer->rto_j = RXRPC_TIMEOUT_INIT; + peer->mdev_us = jiffies_to_usecs(RXRPC_TIMEOUT_INIT); + peer->backoff = 0; + //minmax_reset(&peer->rtt_min, rxrpc_jiffies32, ~0U); +} diff --git a/net/rxrpc/rxkad.c b/net/rxrpc/rxkad.c new file mode 100644 index 000000000..78fa05241 --- /dev/null +++ b/net/rxrpc/rxkad.c @@ -0,0 +1,1405 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Kerberos-based RxRPC security + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +#define RXKAD_VERSION 2 +#define MAXKRB5TICKETLEN 1024 +#define RXKAD_TKT_TYPE_KERBEROS_V5 256 +#define ANAME_SZ 40 /* size of authentication name */ +#define INST_SZ 40 /* size of principal's instance */ +#define REALM_SZ 40 /* size of principal's auth domain */ +#define SNAME_SZ 40 /* size of service name */ +#define RXKAD_ALIGN 8 + +struct rxkad_level1_hdr { + __be32 data_size; /* true data size (excluding padding) */ +}; + +struct rxkad_level2_hdr { + __be32 data_size; /* true data size (excluding padding) */ + __be32 checksum; /* decrypted data checksum */ +}; + +static int rxkad_prime_packet_security(struct rxrpc_connection *conn, + struct crypto_sync_skcipher *ci); + +/* + * this holds a pinned cipher so that keventd doesn't get called by the cipher + * alloc routine, but since we have it to hand, we use it to decrypt RESPONSE + * packets + */ +static struct crypto_sync_skcipher *rxkad_ci; +static struct skcipher_request *rxkad_ci_req; +static DEFINE_MUTEX(rxkad_ci_mutex); + +/* + * Parse the information from a server key + * + * The data should be the 8-byte secret key. 
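+ *
+ * That is a single DES key, as used by the rxkad security class.  The raw
+ * octets are kept in the key payload (where they also serve as the IV for
+ * ticket decryption) and a pcbc(des) transform keyed with them is stashed in
+ * payload.data[0] for decrypting the ticket in a RESPONSE packet.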
+ */ +static int rxkad_preparse_server_key(struct key_preparsed_payload *prep) +{ + struct crypto_skcipher *ci; + + if (prep->datalen != 8) + return -EINVAL; + + memcpy(&prep->payload.data[2], prep->data, 8); + + ci = crypto_alloc_skcipher("pcbc(des)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(ci)) { + _leave(" = %ld", PTR_ERR(ci)); + return PTR_ERR(ci); + } + + if (crypto_skcipher_setkey(ci, prep->data, 8) < 0) + BUG(); + + prep->payload.data[0] = ci; + _leave(" = 0"); + return 0; +} + +static void rxkad_free_preparse_server_key(struct key_preparsed_payload *prep) +{ + + if (prep->payload.data[0]) + crypto_free_skcipher(prep->payload.data[0]); +} + +static void rxkad_destroy_server_key(struct key *key) +{ + if (key->payload.data[0]) { + crypto_free_skcipher(key->payload.data[0]); + key->payload.data[0] = NULL; + } +} + +/* + * initialise connection security + */ +static int rxkad_init_connection_security(struct rxrpc_connection *conn, + struct rxrpc_key_token *token) +{ + struct crypto_sync_skcipher *ci; + int ret; + + _enter("{%d},{%x}", conn->debug_id, key_serial(conn->params.key)); + + conn->security_ix = token->security_index; + + ci = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0); + if (IS_ERR(ci)) { + _debug("no cipher"); + ret = PTR_ERR(ci); + goto error; + } + + if (crypto_sync_skcipher_setkey(ci, token->kad->session_key, + sizeof(token->kad->session_key)) < 0) + BUG(); + + switch (conn->params.security_level) { + case RXRPC_SECURITY_PLAIN: + case RXRPC_SECURITY_AUTH: + case RXRPC_SECURITY_ENCRYPT: + break; + default: + ret = -EKEYREJECTED; + goto error; + } + + ret = rxkad_prime_packet_security(conn, ci); + if (ret < 0) + goto error_ci; + + conn->rxkad.cipher = ci; + return 0; + +error_ci: + crypto_free_sync_skcipher(ci); +error: + _leave(" = %d", ret); + return ret; +} + +/* + * Work out how much data we can put in a packet. 
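+ *
+ * For example, at the ENCRYPT level a full-sized chunk is
+ * round_down(RXRPC_JUMBO_DATALEN, RXKAD_ALIGN) - sizeof(struct rxkad_level2_hdr)
+ * = 1408 - 8 = 1400 bytes of caller data; at the AUTH level only a 4-byte
+ * level-1 header is reserved, giving 1404 bytes.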
+ */ +static int rxkad_how_much_data(struct rxrpc_call *call, size_t remain, + size_t *_buf_size, size_t *_data_size, size_t *_offset) +{ + size_t shdr, buf_size, chunk; + + switch (call->conn->params.security_level) { + default: + buf_size = chunk = min_t(size_t, remain, RXRPC_JUMBO_DATALEN); + shdr = 0; + goto out; + case RXRPC_SECURITY_AUTH: + shdr = sizeof(struct rxkad_level1_hdr); + break; + case RXRPC_SECURITY_ENCRYPT: + shdr = sizeof(struct rxkad_level2_hdr); + break; + } + + buf_size = round_down(RXRPC_JUMBO_DATALEN, RXKAD_ALIGN); + + chunk = buf_size - shdr; + if (remain < chunk) + buf_size = round_up(shdr + remain, RXKAD_ALIGN); + +out: + *_buf_size = buf_size; + *_data_size = chunk; + *_offset = shdr; + return 0; +} + +/* + * prime the encryption state with the invariant parts of a connection's + * description + */ +static int rxkad_prime_packet_security(struct rxrpc_connection *conn, + struct crypto_sync_skcipher *ci) +{ + struct skcipher_request *req; + struct rxrpc_key_token *token; + struct scatterlist sg; + struct rxrpc_crypt iv; + __be32 *tmpbuf; + size_t tmpsize = 4 * sizeof(__be32); + + _enter(""); + + if (!conn->params.key) + return 0; + + tmpbuf = kmalloc(tmpsize, GFP_KERNEL); + if (!tmpbuf) + return -ENOMEM; + + req = skcipher_request_alloc(&ci->base, GFP_NOFS); + if (!req) { + kfree(tmpbuf); + return -ENOMEM; + } + + token = conn->params.key->payload.data[0]; + memcpy(&iv, token->kad->session_key, sizeof(iv)); + + tmpbuf[0] = htonl(conn->proto.epoch); + tmpbuf[1] = htonl(conn->proto.cid); + tmpbuf[2] = 0; + tmpbuf[3] = htonl(conn->security_ix); + + sg_init_one(&sg, tmpbuf, tmpsize); + skcipher_request_set_sync_tfm(req, ci); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &sg, &sg, tmpsize, iv.x); + crypto_skcipher_encrypt(req); + skcipher_request_free(req); + + memcpy(&conn->rxkad.csum_iv, tmpbuf + 2, sizeof(conn->rxkad.csum_iv)); + kfree(tmpbuf); + _leave(" = 0"); + return 0; +} + +/* + * Allocate and prepare the crypto request on a call. For any particular call, + * this is called serially for the packets, so no lock should be necessary. + */ +static struct skcipher_request *rxkad_get_call_crypto(struct rxrpc_call *call) +{ + struct crypto_skcipher *tfm = &call->conn->rxkad.cipher->base; + struct skcipher_request *cipher_req = call->cipher_req; + + if (!cipher_req) { + cipher_req = skcipher_request_alloc(tfm, GFP_NOFS); + if (!cipher_req) + return NULL; + call->cipher_req = cipher_req; + } + + return cipher_req; +} + +/* + * Clean up the crypto on a call. 
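+ *
+ * This frees the skcipher request cached on the call by
+ * rxkad_get_call_crypto(), if one was ever allocated.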
+ */ +static void rxkad_free_call_crypto(struct rxrpc_call *call) +{ + if (call->cipher_req) + skcipher_request_free(call->cipher_req); + call->cipher_req = NULL; +} + +/* + * partially encrypt a packet (level 1 security) + */ +static int rxkad_secure_packet_auth(const struct rxrpc_call *call, + struct sk_buff *skb, u32 data_size, + struct skcipher_request *req) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxkad_level1_hdr hdr; + struct rxrpc_crypt iv; + struct scatterlist sg; + size_t pad; + u16 check; + + _enter(""); + + check = sp->hdr.seq ^ call->call_id; + data_size |= (u32)check << 16; + + hdr.data_size = htonl(data_size); + memcpy(skb->head, &hdr, sizeof(hdr)); + + pad = sizeof(struct rxkad_level1_hdr) + data_size; + pad = RXKAD_ALIGN - pad; + pad &= RXKAD_ALIGN - 1; + if (pad) + skb_put_zero(skb, pad); + + /* start the encryption afresh */ + memset(&iv, 0, sizeof(iv)); + + sg_init_one(&sg, skb->head, 8); + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); + crypto_skcipher_encrypt(req); + skcipher_request_zero(req); + + _leave(" = 0"); + return 0; +} + +/* + * wholly encrypt a packet (level 2 security) + */ +static int rxkad_secure_packet_encrypt(const struct rxrpc_call *call, + struct sk_buff *skb, + u32 data_size, + struct skcipher_request *req) +{ + const struct rxrpc_key_token *token; + struct rxkad_level2_hdr rxkhdr; + struct rxrpc_skb_priv *sp; + struct rxrpc_crypt iv; + struct scatterlist sg[16]; + unsigned int len; + size_t pad; + u16 check; + int err; + + sp = rxrpc_skb(skb); + + _enter(""); + + check = sp->hdr.seq ^ call->call_id; + + rxkhdr.data_size = htonl(data_size | (u32)check << 16); + rxkhdr.checksum = 0; + memcpy(skb->head, &rxkhdr, sizeof(rxkhdr)); + + pad = sizeof(struct rxkad_level2_hdr) + data_size; + pad = RXKAD_ALIGN - pad; + pad &= RXKAD_ALIGN - 1; + if (pad) + skb_put_zero(skb, pad); + + /* encrypt from the session key */ + token = call->conn->params.key->payload.data[0]; + memcpy(&iv, token->kad->session_key, sizeof(iv)); + + sg_init_one(&sg[0], skb->head, sizeof(rxkhdr)); + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &sg[0], &sg[0], sizeof(rxkhdr), iv.x); + crypto_skcipher_encrypt(req); + + /* we want to encrypt the skbuff in-place */ + err = -EMSGSIZE; + if (skb_shinfo(skb)->nr_frags > 16) + goto out; + + len = round_up(data_size, RXKAD_ALIGN); + + sg_init_table(sg, ARRAY_SIZE(sg)); + err = skb_to_sgvec(skb, sg, 8, len); + if (unlikely(err < 0)) + goto out; + skcipher_request_set_crypt(req, sg, sg, len, iv.x); + crypto_skcipher_encrypt(req); + + _leave(" = 0"); + err = 0; + +out: + skcipher_request_zero(req); + return err; +} + +/* + * checksum an RxRPC packet header + */ +static int rxkad_secure_packet(struct rxrpc_call *call, + struct sk_buff *skb, + size_t data_size) +{ + struct rxrpc_skb_priv *sp; + struct skcipher_request *req; + struct rxrpc_crypt iv; + struct scatterlist sg; + u32 x, y; + int ret; + + sp = rxrpc_skb(skb); + + _enter("{%d{%x}},{#%u},%zu,", + call->debug_id, key_serial(call->conn->params.key), + sp->hdr.seq, data_size); + + if (!call->conn->rxkad.cipher) + return 0; + + ret = key_validate(call->conn->params.key); + if (ret < 0) + return ret; + + req = rxkad_get_call_crypto(call); + if (!req) + return -ENOMEM; + + /* continue encrypting from where we left off */ + memcpy(&iv, 
call->conn->rxkad.csum_iv.x, sizeof(iv)); + + /* calculate the security checksum */ + x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT); + x |= sp->hdr.seq & 0x3fffffff; + call->crypto_buf[0] = htonl(call->call_id); + call->crypto_buf[1] = htonl(x); + + sg_init_one(&sg, call->crypto_buf, 8); + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); + crypto_skcipher_encrypt(req); + skcipher_request_zero(req); + + y = ntohl(call->crypto_buf[1]); + y = (y >> 16) & 0xffff; + if (y == 0) + y = 1; /* zero checksums are not permitted */ + sp->hdr.cksum = y; + + switch (call->conn->params.security_level) { + case RXRPC_SECURITY_PLAIN: + ret = 0; + break; + case RXRPC_SECURITY_AUTH: + ret = rxkad_secure_packet_auth(call, skb, data_size, req); + break; + case RXRPC_SECURITY_ENCRYPT: + ret = rxkad_secure_packet_encrypt(call, skb, data_size, req); + break; + default: + ret = -EPERM; + break; + } + + _leave(" = %d [set %x]", ret, y); + return ret; +} + +/* + * decrypt partial encryption on a packet (level 1 security) + */ +static int rxkad_verify_packet_1(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int offset, unsigned int len, + rxrpc_seq_t seq, + struct skcipher_request *req) +{ + struct rxkad_level1_hdr sechdr; + struct rxrpc_crypt iv; + struct scatterlist sg[16]; + bool aborted; + u32 data_size, buf; + u16 check; + int ret; + + _enter(""); + + if (len < 8) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_hdr", "V1H", + RXKADSEALEDINCON); + goto protocol_error; + } + + /* Decrypt the skbuff in-place. TODO: We really want to decrypt + * directly into the target buffer. + */ + sg_init_table(sg, ARRAY_SIZE(sg)); + ret = skb_to_sgvec(skb, sg, offset, 8); + if (unlikely(ret < 0)) + return ret; + + /* start the decryption afresh */ + memset(&iv, 0, sizeof(iv)); + + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, 8, iv.x); + crypto_skcipher_decrypt(req); + skcipher_request_zero(req); + + /* Extract the decrypted packet length */ + if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_len", "XV1", + RXKADDATALEN); + goto protocol_error; + } + len -= sizeof(sechdr); + + buf = ntohl(sechdr.data_size); + data_size = buf & 0xffff; + + check = buf >> 16; + check ^= seq ^ call->call_id; + check &= 0xffff; + if (check != 0) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_check", "V1C", + RXKADSEALEDINCON); + goto protocol_error; + } + + if (data_size > len) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_1_datalen", "V1L", + RXKADDATALEN); + goto protocol_error; + } + + _leave(" = 0 [dlen=%x]", data_size); + return 0; + +protocol_error: + if (aborted) + rxrpc_send_abort_packet(call); + return -EPROTO; +} + +/* + * wholly decrypt a packet (level 2 security) + */ +static int rxkad_verify_packet_2(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int offset, unsigned int len, + rxrpc_seq_t seq, + struct skcipher_request *req) +{ + const struct rxrpc_key_token *token; + struct rxkad_level2_hdr sechdr; + struct rxrpc_crypt iv; + struct scatterlist _sg[4], *sg; + bool aborted; + u32 data_size, buf; + u16 check; + int nsg, ret; + + _enter(",{%d}", skb->len); + + if (len < 8) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_hdr", "V2H", + RXKADSEALEDINCON); + goto protocol_error; + } + + /* Decrypt 
the skbuff in-place. TODO: We really want to decrypt + * directly into the target buffer. + */ + sg = _sg; + nsg = skb_shinfo(skb)->nr_frags + 1; + if (nsg <= 4) { + nsg = 4; + } else { + sg = kmalloc_array(nsg, sizeof(*sg), GFP_NOIO); + if (!sg) + goto nomem; + } + + sg_init_table(sg, nsg); + ret = skb_to_sgvec(skb, sg, offset, len); + if (unlikely(ret < 0)) { + if (sg != _sg) + kfree(sg); + return ret; + } + + /* decrypt from the session key */ + token = call->conn->params.key->payload.data[0]; + memcpy(&iv, token->kad->session_key, sizeof(iv)); + + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, len, iv.x); + crypto_skcipher_decrypt(req); + skcipher_request_zero(req); + if (sg != _sg) + kfree(sg); + + /* Extract the decrypted packet length */ + if (skb_copy_bits(skb, offset, &sechdr, sizeof(sechdr)) < 0) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_len", "XV2", + RXKADDATALEN); + goto protocol_error; + } + len -= sizeof(sechdr); + + buf = ntohl(sechdr.data_size); + data_size = buf & 0xffff; + + check = buf >> 16; + check ^= seq ^ call->call_id; + check &= 0xffff; + if (check != 0) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_check", "V2C", + RXKADSEALEDINCON); + goto protocol_error; + } + + if (data_size > len) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_2_datalen", "V2L", + RXKADDATALEN); + goto protocol_error; + } + + _leave(" = 0 [dlen=%x]", data_size); + return 0; + +protocol_error: + if (aborted) + rxrpc_send_abort_packet(call); + return -EPROTO; + +nomem: + _leave(" = -ENOMEM"); + return -ENOMEM; +} + +/* + * Verify the security on a received packet or subpacket (if part of a + * jumbo packet). + */ +static int rxkad_verify_packet(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int offset, unsigned int len, + rxrpc_seq_t seq, u16 expected_cksum) +{ + struct skcipher_request *req; + struct rxrpc_crypt iv; + struct scatterlist sg; + bool aborted; + u16 cksum; + u32 x, y; + + _enter("{%d{%x}},{#%u}", + call->debug_id, key_serial(call->conn->params.key), seq); + + if (!call->conn->rxkad.cipher) + return 0; + + req = rxkad_get_call_crypto(call); + if (!req) + return -ENOMEM; + + /* continue encrypting from where we left off */ + memcpy(&iv, call->conn->rxkad.csum_iv.x, sizeof(iv)); + + /* validate the security checksum */ + x = (call->cid & RXRPC_CHANNELMASK) << (32 - RXRPC_CIDSHIFT); + x |= seq & 0x3fffffff; + call->crypto_buf[0] = htonl(call->call_id); + call->crypto_buf[1] = htonl(x); + + sg_init_one(&sg, call->crypto_buf, 8); + skcipher_request_set_sync_tfm(req, call->conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, &sg, &sg, 8, iv.x); + crypto_skcipher_encrypt(req); + skcipher_request_zero(req); + + y = ntohl(call->crypto_buf[1]); + cksum = (y >> 16) & 0xffff; + if (cksum == 0) + cksum = 1; /* zero checksums are not permitted */ + + if (cksum != expected_cksum) { + aborted = rxrpc_abort_eproto(call, skb, "rxkad_csum", "VCK", + RXKADSEALEDINCON); + goto protocol_error; + } + + switch (call->conn->params.security_level) { + case RXRPC_SECURITY_PLAIN: + return 0; + case RXRPC_SECURITY_AUTH: + return rxkad_verify_packet_1(call, skb, offset, len, seq, req); + case RXRPC_SECURITY_ENCRYPT: + return rxkad_verify_packet_2(call, skb, offset, len, seq, req); + default: + return -ENOANO; + } + +protocol_error: + if (aborted) + rxrpc_send_abort_packet(call); + return -EPROTO; +} + +/* + * Locate 
the data contained in a packet that was partially encrypted. + */ +static void rxkad_locate_data_1(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int *_offset, unsigned int *_len) +{ + struct rxkad_level1_hdr sechdr; + + if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0) + BUG(); + *_offset += sizeof(sechdr); + *_len = ntohl(sechdr.data_size) & 0xffff; +} + +/* + * Locate the data contained in a packet that was completely encrypted. + */ +static void rxkad_locate_data_2(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int *_offset, unsigned int *_len) +{ + struct rxkad_level2_hdr sechdr; + + if (skb_copy_bits(skb, *_offset, &sechdr, sizeof(sechdr)) < 0) + BUG(); + *_offset += sizeof(sechdr); + *_len = ntohl(sechdr.data_size) & 0xffff; +} + +/* + * Locate the data contained in an already decrypted packet. + */ +static void rxkad_locate_data(struct rxrpc_call *call, struct sk_buff *skb, + unsigned int *_offset, unsigned int *_len) +{ + switch (call->conn->params.security_level) { + case RXRPC_SECURITY_AUTH: + rxkad_locate_data_1(call, skb, _offset, _len); + return; + case RXRPC_SECURITY_ENCRYPT: + rxkad_locate_data_2(call, skb, _offset, _len); + return; + default: + return; + } +} + +/* + * issue a challenge + */ +static int rxkad_issue_challenge(struct rxrpc_connection *conn) +{ + struct rxkad_challenge challenge; + struct rxrpc_wire_header whdr; + struct msghdr msg; + struct kvec iov[2]; + size_t len; + u32 serial; + int ret; + + _enter("{%d}", conn->debug_id); + + get_random_bytes(&conn->rxkad.nonce, sizeof(conn->rxkad.nonce)); + + challenge.version = htonl(2); + challenge.nonce = htonl(conn->rxkad.nonce); + challenge.min_level = htonl(0); + challenge.__padding = 0; + + msg.msg_name = &conn->params.peer->srx.transport; + msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + whdr.epoch = htonl(conn->proto.epoch); + whdr.cid = htonl(conn->proto.cid); + whdr.callNumber = 0; + whdr.seq = 0; + whdr.type = RXRPC_PACKET_TYPE_CHALLENGE; + whdr.flags = conn->out_clientflag; + whdr.userStatus = 0; + whdr.securityIndex = conn->security_ix; + whdr._rsvd = 0; + whdr.serviceId = htons(conn->service_id); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = &challenge; + iov[1].iov_len = sizeof(challenge); + + len = iov[0].iov_len + iov[1].iov_len; + + serial = atomic_inc_return(&conn->serial); + whdr.serial = htonl(serial); + _proto("Tx CHALLENGE %%%u", serial); + + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 2, len); + if (ret < 0) { + trace_rxrpc_tx_fail(conn->debug_id, serial, ret, + rxrpc_tx_point_rxkad_challenge); + return -EAGAIN; + } + + conn->params.peer->last_tx_at = ktime_get_seconds(); + trace_rxrpc_tx_packet(conn->debug_id, &whdr, + rxrpc_tx_point_rxkad_challenge); + _leave(" = 0"); + return 0; +} + +/* + * send a Kerberos security response + */ +static int rxkad_send_response(struct rxrpc_connection *conn, + struct rxrpc_host_header *hdr, + struct rxkad_response *resp, + const struct rxkad_key *s2) +{ + struct rxrpc_wire_header whdr; + struct msghdr msg; + struct kvec iov[3]; + size_t len; + u32 serial; + int ret; + + _enter(""); + + msg.msg_name = &conn->params.peer->srx.transport; + msg.msg_namelen = conn->params.peer->srx.transport_len; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_flags = 0; + + memset(&whdr, 0, sizeof(whdr)); + whdr.epoch = htonl(hdr->epoch); + whdr.cid = htonl(hdr->cid); + whdr.type = 
RXRPC_PACKET_TYPE_RESPONSE; + whdr.flags = conn->out_clientflag; + whdr.securityIndex = hdr->securityIndex; + whdr.serviceId = htons(hdr->serviceId); + + iov[0].iov_base = &whdr; + iov[0].iov_len = sizeof(whdr); + iov[1].iov_base = resp; + iov[1].iov_len = sizeof(*resp); + iov[2].iov_base = (void *)s2->ticket; + iov[2].iov_len = s2->ticket_len; + + len = iov[0].iov_len + iov[1].iov_len + iov[2].iov_len; + + serial = atomic_inc_return(&conn->serial); + whdr.serial = htonl(serial); + _proto("Tx RESPONSE %%%u", serial); + + ret = kernel_sendmsg(conn->params.local->socket, &msg, iov, 3, len); + if (ret < 0) { + trace_rxrpc_tx_fail(conn->debug_id, serial, ret, + rxrpc_tx_point_rxkad_response); + return -EAGAIN; + } + + conn->params.peer->last_tx_at = ktime_get_seconds(); + _leave(" = 0"); + return 0; +} + +/* + * calculate the response checksum + */ +static void rxkad_calc_response_checksum(struct rxkad_response *response) +{ + u32 csum = 1000003; + int loop; + u8 *p = (u8 *) response; + + for (loop = sizeof(*response); loop > 0; loop--) + csum = csum * 0x10204081 + *p++; + + response->encrypted.checksum = htonl(csum); +} + +/* + * encrypt the response packet + */ +static int rxkad_encrypt_response(struct rxrpc_connection *conn, + struct rxkad_response *resp, + const struct rxkad_key *s2) +{ + struct skcipher_request *req; + struct rxrpc_crypt iv; + struct scatterlist sg[1]; + + req = skcipher_request_alloc(&conn->rxkad.cipher->base, GFP_NOFS); + if (!req) + return -ENOMEM; + + /* continue encrypting from where we left off */ + memcpy(&iv, s2->session_key, sizeof(iv)); + + sg_init_table(sg, 1); + sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted)); + skcipher_request_set_sync_tfm(req, conn->rxkad.cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x); + crypto_skcipher_encrypt(req); + skcipher_request_free(req); + return 0; +} + +/* + * respond to a challenge packet + */ +static int rxkad_respond_to_challenge(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 *_abort_code) +{ + const struct rxrpc_key_token *token; + struct rxkad_challenge challenge; + struct rxkad_response *resp; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + const char *eproto; + u32 version, nonce, min_level, abort_code; + int ret; + + _enter("{%d,%x}", conn->debug_id, key_serial(conn->params.key)); + + eproto = tracepoint_string("chall_no_key"); + abort_code = RX_PROTOCOL_ERROR; + if (!conn->params.key) + goto protocol_error; + + abort_code = RXKADEXPIRED; + ret = key_validate(conn->params.key); + if (ret < 0) + goto other_error; + + eproto = tracepoint_string("chall_short"); + abort_code = RXKADPACKETSHORT; + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + &challenge, sizeof(challenge)) < 0) + goto protocol_error; + + version = ntohl(challenge.version); + nonce = ntohl(challenge.nonce); + min_level = ntohl(challenge.min_level); + + _proto("Rx CHALLENGE %%%u { v=%u n=%u ml=%u }", + sp->hdr.serial, version, nonce, min_level); + + eproto = tracepoint_string("chall_ver"); + abort_code = RXKADINCONSISTENCY; + if (version != RXKAD_VERSION) + goto protocol_error; + + abort_code = RXKADLEVELFAIL; + ret = -EACCES; + if (conn->params.security_level < min_level) + goto other_error; + + token = conn->params.key->payload.data[0]; + + /* build the response packet */ + resp = kzalloc(sizeof(struct rxkad_response), GFP_NOFS); + if (!resp) + return -ENOMEM; + + resp->version = htonl(RXKAD_VERSION); + resp->encrypted.epoch = 
htonl(conn->proto.epoch); + resp->encrypted.cid = htonl(conn->proto.cid); + resp->encrypted.securityIndex = htonl(conn->security_ix); + resp->encrypted.inc_nonce = htonl(nonce + 1); + resp->encrypted.level = htonl(conn->params.security_level); + resp->kvno = htonl(token->kad->kvno); + resp->ticket_len = htonl(token->kad->ticket_len); + resp->encrypted.call_id[0] = htonl(conn->channels[0].call_counter); + resp->encrypted.call_id[1] = htonl(conn->channels[1].call_counter); + resp->encrypted.call_id[2] = htonl(conn->channels[2].call_counter); + resp->encrypted.call_id[3] = htonl(conn->channels[3].call_counter); + + /* calculate the response checksum and then do the encryption */ + rxkad_calc_response_checksum(resp); + ret = rxkad_encrypt_response(conn, resp, token->kad); + if (ret == 0) + ret = rxkad_send_response(conn, &sp->hdr, resp, token->kad); + kfree(resp); + return ret; + +protocol_error: + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); + ret = -EPROTO; +other_error: + *_abort_code = abort_code; + return ret; +} + +/* + * decrypt the kerberos IV ticket in the response + */ +static int rxkad_decrypt_ticket(struct rxrpc_connection *conn, + struct key *server_key, + struct sk_buff *skb, + void *ticket, size_t ticket_len, + struct rxrpc_crypt *_session_key, + time64_t *_expiry, + u32 *_abort_code) +{ + struct skcipher_request *req; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxrpc_crypt iv, key; + struct scatterlist sg[1]; + struct in_addr addr; + unsigned int life; + const char *eproto; + time64_t issue, now; + bool little_endian; + int ret; + u32 abort_code; + u8 *p, *q, *name, *end; + + _enter("{%d},{%x}", conn->debug_id, key_serial(server_key)); + + *_expiry = 0; + + ASSERT(server_key->payload.data[0] != NULL); + ASSERTCMP((unsigned long) ticket & 7UL, ==, 0); + + memcpy(&iv, &server_key->payload.data[2], sizeof(iv)); + + ret = -ENOMEM; + req = skcipher_request_alloc(server_key->payload.data[0], GFP_NOFS); + if (!req) + goto temporary_error; + + sg_init_one(&sg[0], ticket, ticket_len); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, ticket_len, iv.x); + crypto_skcipher_decrypt(req); + skcipher_request_free(req); + + p = ticket; + end = p + ticket_len; + +#define Z(field) \ + ({ \ + u8 *__str = p; \ + eproto = tracepoint_string("rxkad_bad_"#field); \ + q = memchr(p, 0, end - p); \ + if (!q || q - p > (field##_SZ)) \ + goto bad_ticket; \ + for (; p < q; p++) \ + if (!isprint(*p)) \ + goto bad_ticket; \ + p++; \ + __str; \ + }) + + /* extract the ticket flags */ + _debug("KIV FLAGS: %x", *p); + little_endian = *p & 1; + p++; + + /* extract the authentication name */ + name = Z(ANAME); + _debug("KIV ANAME: %s", name); + + /* extract the principal's instance */ + name = Z(INST); + _debug("KIV INST : %s", name); + + /* extract the principal's authentication domain */ + name = Z(REALM); + _debug("KIV REALM: %s", name); + + eproto = tracepoint_string("rxkad_bad_len"); + if (end - p < 4 + 8 + 4 + 2) + goto bad_ticket; + + /* get the IPv4 address of the entity that requested the ticket */ + memcpy(&addr, p, sizeof(addr)); + p += 4; + _debug("KIV ADDR : %pI4", &addr); + + /* get the session key from the ticket */ + memcpy(&key, p, sizeof(key)); + p += 8; + _debug("KIV KEY : %08x %08x", ntohl(key.n[0]), ntohl(key.n[1])); + memcpy(_session_key, &key, sizeof(key)); + + /* get the ticket's lifetime */ + life = *p++ * 5 * 60; + _debug("KIV LIFE : %u", life); + + /* get the issue time of the ticket */ + if (little_endian) { + __le32 stamp; 
+ memcpy(&stamp, p, 4); + issue = rxrpc_u32_to_time64(le32_to_cpu(stamp)); + } else { + __be32 stamp; + memcpy(&stamp, p, 4); + issue = rxrpc_u32_to_time64(be32_to_cpu(stamp)); + } + p += 4; + now = ktime_get_real_seconds(); + _debug("KIV ISSUE: %llx [%llx]", issue, now); + + /* check the ticket is in date */ + if (issue > now) { + abort_code = RXKADNOAUTH; + ret = -EKEYREJECTED; + goto other_error; + } + + if (issue < now - life) { + abort_code = RXKADEXPIRED; + ret = -EKEYEXPIRED; + goto other_error; + } + + *_expiry = issue + life; + + /* get the service name */ + name = Z(SNAME); + _debug("KIV SNAME: %s", name); + + /* get the service instance name */ + name = Z(INST); + _debug("KIV SINST: %s", name); + return 0; + +bad_ticket: + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); + abort_code = RXKADBADTICKET; + ret = -EPROTO; +other_error: + *_abort_code = abort_code; + return ret; +temporary_error: + return ret; +} + +/* + * decrypt the response packet + */ +static void rxkad_decrypt_response(struct rxrpc_connection *conn, + struct rxkad_response *resp, + const struct rxrpc_crypt *session_key) +{ + struct skcipher_request *req = rxkad_ci_req; + struct scatterlist sg[1]; + struct rxrpc_crypt iv; + + _enter(",,%08x%08x", + ntohl(session_key->n[0]), ntohl(session_key->n[1])); + + mutex_lock(&rxkad_ci_mutex); + if (crypto_sync_skcipher_setkey(rxkad_ci, session_key->x, + sizeof(*session_key)) < 0) + BUG(); + + memcpy(&iv, session_key, sizeof(iv)); + + sg_init_table(sg, 1); + sg_set_buf(sg, &resp->encrypted, sizeof(resp->encrypted)); + skcipher_request_set_sync_tfm(req, rxkad_ci); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, sizeof(resp->encrypted), iv.x); + crypto_skcipher_decrypt(req); + skcipher_request_zero(req); + + mutex_unlock(&rxkad_ci_mutex); + + _leave(""); +} + +/* + * verify a response + */ +static int rxkad_verify_response(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 *_abort_code) +{ + struct rxkad_response *response; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxrpc_crypt session_key; + struct key *server_key; + const char *eproto; + time64_t expiry; + void *ticket; + u32 abort_code, version, kvno, ticket_len, level; + __be32 csum; + int ret, i; + + _enter("{%d}", conn->debug_id); + + server_key = rxrpc_look_up_server_security(conn, skb, 0, 0); + if (IS_ERR(server_key)) { + switch (PTR_ERR(server_key)) { + case -ENOKEY: + abort_code = RXKADUNKNOWNKEY; + break; + case -EKEYEXPIRED: + abort_code = RXKADEXPIRED; + break; + default: + abort_code = RXKADNOAUTH; + break; + } + trace_rxrpc_abort(0, "SVK", + sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + abort_code, PTR_ERR(server_key)); + *_abort_code = abort_code; + return -EPROTO; + } + + ret = -ENOMEM; + response = kzalloc(sizeof(struct rxkad_response), GFP_NOFS); + if (!response) + goto temporary_error; + + eproto = tracepoint_string("rxkad_rsp_short"); + abort_code = RXKADPACKETSHORT; + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header), + response, sizeof(*response)) < 0) + goto protocol_error; + + version = ntohl(response->version); + ticket_len = ntohl(response->ticket_len); + kvno = ntohl(response->kvno); + _proto("Rx RESPONSE %%%u { v=%u kv=%u tl=%u }", + sp->hdr.serial, version, kvno, ticket_len); + + eproto = tracepoint_string("rxkad_rsp_ver"); + abort_code = RXKADINCONSISTENCY; + if (version != RXKAD_VERSION) + goto protocol_error; + + eproto = tracepoint_string("rxkad_rsp_tktlen"); + abort_code = RXKADTICKETLEN; + if (ticket_len < 4 || 
ticket_len > MAXKRB5TICKETLEN) + goto protocol_error; + + eproto = tracepoint_string("rxkad_rsp_unkkey"); + abort_code = RXKADUNKNOWNKEY; + if (kvno >= RXKAD_TKT_TYPE_KERBEROS_V5) + goto protocol_error; + + /* extract the kerberos ticket and decrypt and decode it */ + ret = -ENOMEM; + ticket = kmalloc(ticket_len, GFP_NOFS); + if (!ticket) + goto temporary_error_free_resp; + + eproto = tracepoint_string("rxkad_tkt_short"); + abort_code = RXKADPACKETSHORT; + if (skb_copy_bits(skb, sizeof(struct rxrpc_wire_header) + sizeof(*response), + ticket, ticket_len) < 0) + goto protocol_error_free; + + ret = rxkad_decrypt_ticket(conn, server_key, skb, ticket, ticket_len, + &session_key, &expiry, _abort_code); + if (ret < 0) + goto temporary_error_free_ticket; + + /* use the session key from inside the ticket to decrypt the + * response */ + rxkad_decrypt_response(conn, response, &session_key); + + eproto = tracepoint_string("rxkad_rsp_param"); + abort_code = RXKADSEALEDINCON; + if (ntohl(response->encrypted.epoch) != conn->proto.epoch) + goto protocol_error_free; + if (ntohl(response->encrypted.cid) != conn->proto.cid) + goto protocol_error_free; + if (ntohl(response->encrypted.securityIndex) != conn->security_ix) + goto protocol_error_free; + csum = response->encrypted.checksum; + response->encrypted.checksum = 0; + rxkad_calc_response_checksum(response); + eproto = tracepoint_string("rxkad_rsp_csum"); + if (response->encrypted.checksum != csum) + goto protocol_error_free; + + spin_lock(&conn->bundle->channel_lock); + for (i = 0; i < RXRPC_MAXCALLS; i++) { + struct rxrpc_call *call; + u32 call_id = ntohl(response->encrypted.call_id[i]); + + eproto = tracepoint_string("rxkad_rsp_callid"); + if (call_id > INT_MAX) + goto protocol_error_unlock; + + eproto = tracepoint_string("rxkad_rsp_callctr"); + if (call_id < conn->channels[i].call_counter) + goto protocol_error_unlock; + + eproto = tracepoint_string("rxkad_rsp_callst"); + if (call_id > conn->channels[i].call_counter) { + call = rcu_dereference_protected( + conn->channels[i].call, + lockdep_is_held(&conn->bundle->channel_lock)); + if (call && call->state < RXRPC_CALL_COMPLETE) + goto protocol_error_unlock; + conn->channels[i].call_counter = call_id; + } + } + spin_unlock(&conn->bundle->channel_lock); + + eproto = tracepoint_string("rxkad_rsp_seq"); + abort_code = RXKADOUTOFSEQUENCE; + if (ntohl(response->encrypted.inc_nonce) != conn->rxkad.nonce + 1) + goto protocol_error_free; + + eproto = tracepoint_string("rxkad_rsp_level"); + abort_code = RXKADLEVELFAIL; + level = ntohl(response->encrypted.level); + if (level > RXRPC_SECURITY_ENCRYPT) + goto protocol_error_free; + conn->params.security_level = level; + + /* create a key to hold the security data and expiration time - after + * this the connection security can be handled in exactly the same way + * as for a client connection */ + ret = rxrpc_get_server_data_key(conn, &session_key, expiry, kvno); + if (ret < 0) + goto temporary_error_free_ticket; + + kfree(ticket); + kfree(response); + _leave(" = 0"); + return 0; + +protocol_error_unlock: + spin_unlock(&conn->bundle->channel_lock); +protocol_error_free: + kfree(ticket); +protocol_error: + kfree(response); + trace_rxrpc_rx_eproto(NULL, sp->hdr.serial, eproto); + key_put(server_key); + *_abort_code = abort_code; + return -EPROTO; + +temporary_error_free_ticket: + kfree(ticket); +temporary_error_free_resp: + kfree(response); +temporary_error: + /* Ignore the response packet if we got a temporary error such as + * ENOMEM. 
We just want to send the challenge again. Note that we + * also come out this way if the ticket decryption fails. + */ + key_put(server_key); + return ret; +} + +/* + * clear the connection security + */ +static void rxkad_clear(struct rxrpc_connection *conn) +{ + _enter(""); + + if (conn->rxkad.cipher) + crypto_free_sync_skcipher(conn->rxkad.cipher); +} + +/* + * Initialise the rxkad security service. + */ +static int rxkad_init(void) +{ + struct crypto_sync_skcipher *tfm; + struct skcipher_request *req; + + /* pin the cipher we need so that the crypto layer doesn't invoke + * keventd to go get it */ + tfm = crypto_alloc_sync_skcipher("pcbc(fcrypt)", 0, 0); + if (IS_ERR(tfm)) + return PTR_ERR(tfm); + + req = skcipher_request_alloc(&tfm->base, GFP_KERNEL); + if (!req) + goto nomem_tfm; + + rxkad_ci_req = req; + rxkad_ci = tfm; + return 0; + +nomem_tfm: + crypto_free_sync_skcipher(tfm); + return -ENOMEM; +} + +/* + * Clean up the rxkad security service. + */ +static void rxkad_exit(void) +{ + crypto_free_sync_skcipher(rxkad_ci); + skcipher_request_free(rxkad_ci_req); +} + +/* + * RxRPC Kerberos-based security + */ +const struct rxrpc_security rxkad = { + .name = "rxkad", + .security_index = RXRPC_SECURITY_RXKAD, + .no_key_abort = RXKADUNKNOWNKEY, + .init = rxkad_init, + .exit = rxkad_exit, + .preparse_server_key = rxkad_preparse_server_key, + .free_preparse_server_key = rxkad_free_preparse_server_key, + .destroy_server_key = rxkad_destroy_server_key, + .init_connection_security = rxkad_init_connection_security, + .how_much_data = rxkad_how_much_data, + .secure_packet = rxkad_secure_packet, + .verify_packet = rxkad_verify_packet, + .free_call_crypto = rxkad_free_call_crypto, + .locate_data = rxkad_locate_data, + .issue_challenge = rxkad_issue_challenge, + .respond_to_challenge = rxkad_respond_to_challenge, + .verify_response = rxkad_verify_response, + .clear = rxkad_clear, +}; diff --git a/net/rxrpc/security.c b/net/rxrpc/security.c new file mode 100644 index 000000000..50cb5f1ee --- /dev/null +++ b/net/rxrpc/security.c @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC security handling + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static const struct rxrpc_security *rxrpc_security_types[] = { + [RXRPC_SECURITY_NONE] = &rxrpc_no_security, +#ifdef CONFIG_RXKAD + [RXRPC_SECURITY_RXKAD] = &rxkad, +#endif +}; + +int __init rxrpc_init_security(void) +{ + int i, ret; + + for (i = 0; i < ARRAY_SIZE(rxrpc_security_types); i++) { + if (rxrpc_security_types[i]) { + ret = rxrpc_security_types[i]->init(); + if (ret < 0) + goto failed; + } + } + + return 0; + +failed: + for (i--; i >= 0; i--) + if (rxrpc_security_types[i]) + rxrpc_security_types[i]->exit(); + return ret; +} + +void rxrpc_exit_security(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(rxrpc_security_types); i++) + if (rxrpc_security_types[i]) + rxrpc_security_types[i]->exit(); +} + +/* + * look up an rxrpc security module + */ +const struct rxrpc_security *rxrpc_security_lookup(u8 security_index) +{ + if (security_index >= ARRAY_SIZE(rxrpc_security_types)) + return NULL; + return rxrpc_security_types[security_index]; +} + +/* + * initialise the security on a client connection + */ +int rxrpc_init_client_conn_security(struct rxrpc_connection *conn) +{ + const struct rxrpc_security *sec; + struct rxrpc_key_token *token; + struct key *key = conn->params.key; + int ret; + + _enter("{%d},{%x}", conn->debug_id, key_serial(key)); + + if (!key) + return 0; + + ret = key_validate(key); + if (ret < 0) + return ret; + + for (token = key->payload.data[0]; token; token = token->next) { + sec = rxrpc_security_lookup(token->security_index); + if (sec) + goto found; + } + return -EKEYREJECTED; + +found: + conn->security = sec; + + ret = conn->security->init_connection_security(conn, token); + if (ret < 0) { + conn->security = &rxrpc_no_security; + return ret; + } + + _leave(" = 0"); + return 0; +} + +/* + * Set the ops a server connection. + */ +const struct rxrpc_security *rxrpc_get_incoming_security(struct rxrpc_sock *rx, + struct sk_buff *skb) +{ + const struct rxrpc_security *sec; + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + + _enter(""); + + sec = rxrpc_security_lookup(sp->hdr.securityIndex); + if (!sec) { + trace_rxrpc_abort(0, "SVS", + sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_INVALID_OPERATION, EKEYREJECTED); + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; + skb->priority = RX_INVALID_OPERATION; + return NULL; + } + + if (sp->hdr.securityIndex != RXRPC_SECURITY_NONE && + !rx->securities) { + trace_rxrpc_abort(0, "SVR", + sp->hdr.cid, sp->hdr.callNumber, sp->hdr.seq, + RX_INVALID_OPERATION, EKEYREJECTED); + skb->mark = RXRPC_SKB_MARK_REJECT_ABORT; + skb->priority = sec->no_key_abort; + return NULL; + } + + return sec; +} + +/* + * Find the security key for a server connection. 
+ */ +struct key *rxrpc_look_up_server_security(struct rxrpc_connection *conn, + struct sk_buff *skb, + u32 kvno, u32 enctype) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + struct rxrpc_sock *rx; + struct key *key = ERR_PTR(-EKEYREJECTED); + key_ref_t kref = NULL; + char kdesc[5 + 1 + 3 + 1 + 12 + 1 + 12 + 1]; + int ret; + + _enter(""); + + if (enctype) + sprintf(kdesc, "%u:%u:%u:%u", + sp->hdr.serviceId, sp->hdr.securityIndex, kvno, enctype); + else if (kvno) + sprintf(kdesc, "%u:%u:%u", + sp->hdr.serviceId, sp->hdr.securityIndex, kvno); + else + sprintf(kdesc, "%u:%u", + sp->hdr.serviceId, sp->hdr.securityIndex); + + rcu_read_lock(); + + rx = rcu_dereference(conn->params.local->service); + if (!rx) + goto out; + + /* look through the service's keyring */ + kref = keyring_search(make_key_ref(rx->securities, 1UL), + &key_type_rxrpc_s, kdesc, true); + if (IS_ERR(kref)) { + key = ERR_CAST(kref); + goto out; + } + + key = key_ref_to_ptr(kref); + + ret = key_validate(key); + if (ret < 0) { + key_put(key); + key = ERR_PTR(ret); + goto out; + } + +out: + rcu_read_unlock(); + return key; +} diff --git a/net/rxrpc/sendmsg.c b/net/rxrpc/sendmsg.c new file mode 100644 index 000000000..71e40f91d --- /dev/null +++ b/net/rxrpc/sendmsg.c @@ -0,0 +1,882 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* AF_RXRPC sendmsg() implementation. + * + * Copyright (C) 2007, 2016 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +#include +#include +#include "ar-internal.h" + +/* + * Return true if there's sufficient Tx queue space. + */ +static bool rxrpc_check_tx_space(struct rxrpc_call *call, rxrpc_seq_t *_tx_win) +{ + unsigned int win_size = + min_t(unsigned int, call->tx_winsize, + call->cong_cwnd + call->cong_extra); + rxrpc_seq_t tx_win = READ_ONCE(call->tx_hard_ack); + + if (_tx_win) + *_tx_win = tx_win; + return call->tx_top - tx_win < win_size; +} + +/* + * Wait for space to appear in the Tx queue or a signal to occur. + */ +static int rxrpc_wait_for_tx_window_intr(struct rxrpc_sock *rx, + struct rxrpc_call *call, + long *timeo) +{ + for (;;) { + set_current_state(TASK_INTERRUPTIBLE); + if (rxrpc_check_tx_space(call, NULL)) + return 0; + + if (call->state >= RXRPC_CALL_COMPLETE) + return call->error; + + if (signal_pending(current)) + return sock_intr_errno(*timeo); + + trace_rxrpc_transmit(call, rxrpc_transmit_wait); + *timeo = schedule_timeout(*timeo); + } +} + +/* + * Wait for space to appear in the Tx queue uninterruptibly, but with + * a timeout of 2*RTT if no progress was made and a signal occurred. 
+ */ +static int rxrpc_wait_for_tx_window_waitall(struct rxrpc_sock *rx, + struct rxrpc_call *call) +{ + rxrpc_seq_t tx_start, tx_win; + signed long rtt, timeout; + + rtt = READ_ONCE(call->peer->srtt_us) >> 3; + rtt = usecs_to_jiffies(rtt) * 2; + if (rtt < 2) + rtt = 2; + + timeout = rtt; + tx_start = READ_ONCE(call->tx_hard_ack); + + for (;;) { + set_current_state(TASK_UNINTERRUPTIBLE); + + tx_win = READ_ONCE(call->tx_hard_ack); + if (rxrpc_check_tx_space(call, &tx_win)) + return 0; + + if (call->state >= RXRPC_CALL_COMPLETE) + return call->error; + + if (timeout == 0 && + tx_win == tx_start && signal_pending(current)) + return -EINTR; + + if (tx_win != tx_start) { + timeout = rtt; + tx_start = tx_win; + } + + trace_rxrpc_transmit(call, rxrpc_transmit_wait); + timeout = schedule_timeout(timeout); + } +} + +/* + * Wait for space to appear in the Tx queue uninterruptibly. + */ +static int rxrpc_wait_for_tx_window_nonintr(struct rxrpc_sock *rx, + struct rxrpc_call *call, + long *timeo) +{ + for (;;) { + set_current_state(TASK_UNINTERRUPTIBLE); + if (rxrpc_check_tx_space(call, NULL)) + return 0; + + if (call->state >= RXRPC_CALL_COMPLETE) + return call->error; + + trace_rxrpc_transmit(call, rxrpc_transmit_wait); + *timeo = schedule_timeout(*timeo); + } +} + +/* + * wait for space to appear in the transmit/ACK window + * - caller holds the socket locked + */ +static int rxrpc_wait_for_tx_window(struct rxrpc_sock *rx, + struct rxrpc_call *call, + long *timeo, + bool waitall) +{ + DECLARE_WAITQUEUE(myself, current); + int ret; + + _enter(",{%u,%u,%u}", + call->tx_hard_ack, call->tx_top, call->tx_winsize); + + add_wait_queue(&call->waitq, &myself); + + switch (call->interruptibility) { + case RXRPC_INTERRUPTIBLE: + if (waitall) + ret = rxrpc_wait_for_tx_window_waitall(rx, call); + else + ret = rxrpc_wait_for_tx_window_intr(rx, call, timeo); + break; + case RXRPC_PREINTERRUPTIBLE: + case RXRPC_UNINTERRUPTIBLE: + default: + ret = rxrpc_wait_for_tx_window_nonintr(rx, call, timeo); + break; + } + + remove_wait_queue(&call->waitq, &myself); + set_current_state(TASK_RUNNING); + _leave(" = %d", ret); + return ret; +} + +/* + * Schedule an instant Tx resend. + */ +static inline void rxrpc_instant_resend(struct rxrpc_call *call, int ix) +{ + spin_lock_bh(&call->lock); + + if (call->state < RXRPC_CALL_COMPLETE) { + call->rxtx_annotations[ix] = + (call->rxtx_annotations[ix] & RXRPC_TX_ANNO_LAST) | + RXRPC_TX_ANNO_RETRANS; + if (!test_and_set_bit(RXRPC_CALL_EV_RESEND, &call->events)) + rxrpc_queue_call(call); + } + + spin_unlock_bh(&call->lock); +} + +/* + * Notify the owner of the call that the transmit phase is ended and the last + * packet has been queued. + */ +static void rxrpc_notify_end_tx(struct rxrpc_sock *rx, struct rxrpc_call *call, + rxrpc_notify_end_tx_t notify_end_tx) +{ + if (notify_end_tx) + notify_end_tx(&rx->sk, call, call->user_call_ID); +} + +/* + * Queue a DATA packet for transmission, set the resend timeout and send + * the packet immediately. Returns the error from rxrpc_send_data_packet() + * in case the caller wants to do something with it. 
+ */ +static int rxrpc_queue_packet(struct rxrpc_sock *rx, struct rxrpc_call *call, + struct sk_buff *skb, bool last, + rxrpc_notify_end_tx_t notify_end_tx) +{ + struct rxrpc_skb_priv *sp = rxrpc_skb(skb); + unsigned long now; + rxrpc_seq_t seq = sp->hdr.seq; + int ret, ix; + u8 annotation = RXRPC_TX_ANNO_UNACK; + + _net("queue skb %p [%d]", skb, seq); + + ASSERTCMP(seq, ==, call->tx_top + 1); + + if (last) + annotation |= RXRPC_TX_ANNO_LAST; + + /* We have to set the timestamp before queueing as the retransmit + * algorithm can see the packet as soon as we queue it. + */ + skb->tstamp = ktime_get_real(); + + ix = seq & RXRPC_RXTX_BUFF_MASK; + rxrpc_get_skb(skb, rxrpc_skb_got); + call->rxtx_annotations[ix] = annotation; + smp_wmb(); + call->rxtx_buffer[ix] = skb; + call->tx_top = seq; + if (last) + trace_rxrpc_transmit(call, rxrpc_transmit_queue_last); + else + trace_rxrpc_transmit(call, rxrpc_transmit_queue); + + if (last || call->state == RXRPC_CALL_SERVER_ACK_REQUEST) { + _debug("________awaiting reply/ACK__________"); + write_lock_bh(&call->state_lock); + switch (call->state) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + call->state = RXRPC_CALL_CLIENT_AWAIT_REPLY; + rxrpc_notify_end_tx(rx, call, notify_end_tx); + break; + case RXRPC_CALL_SERVER_ACK_REQUEST: + call->state = RXRPC_CALL_SERVER_SEND_REPLY; + now = jiffies; + WRITE_ONCE(call->ack_at, now + MAX_JIFFY_OFFSET); + if (call->ackr_reason == RXRPC_ACK_DELAY) + call->ackr_reason = 0; + trace_rxrpc_timer(call, rxrpc_timer_init_for_send_reply, now); + if (!last) + break; + fallthrough; + case RXRPC_CALL_SERVER_SEND_REPLY: + call->state = RXRPC_CALL_SERVER_AWAIT_ACK; + rxrpc_notify_end_tx(rx, call, notify_end_tx); + break; + default: + break; + } + write_unlock_bh(&call->state_lock); + } + + if (seq == 1 && rxrpc_is_client_call(call)) + rxrpc_expose_client_call(call); + + ret = rxrpc_send_data_packet(call, skb, false); + if (ret < 0) { + switch (ret) { + case -ENETUNREACH: + case -EHOSTUNREACH: + case -ECONNREFUSED: + rxrpc_set_call_completion(call, RXRPC_CALL_LOCAL_ERROR, + 0, ret); + goto out; + } + _debug("need instant resend %d", ret); + rxrpc_instant_resend(call, ix); + } else { + unsigned long now = jiffies; + unsigned long resend_at = now + call->peer->rto_j; + + WRITE_ONCE(call->resend_at, resend_at); + rxrpc_reduce_call_timer(call, resend_at, now, + rxrpc_timer_set_for_send); + } + +out: + rxrpc_free_skb(skb, rxrpc_skb_freed); + _leave(" = %d", ret); + return ret; +} + +/* + * send data through a socket + * - must be called in process context + * - The caller holds the call user access mutex, but not the socket lock. 
+ */ +static int rxrpc_send_data(struct rxrpc_sock *rx, + struct rxrpc_call *call, + struct msghdr *msg, size_t len, + rxrpc_notify_end_tx_t notify_end_tx, + bool *_dropped_lock) +{ + struct rxrpc_skb_priv *sp; + struct sk_buff *skb; + struct sock *sk = &rx->sk; + enum rxrpc_call_state state; + long timeo; + bool more = msg->msg_flags & MSG_MORE; + int ret, copied = 0; + + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + + /* this should be in poll */ + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + +reload: + ret = -EPIPE; + if (sk->sk_shutdown & SEND_SHUTDOWN) + goto maybe_error; + state = READ_ONCE(call->state); + ret = -ESHUTDOWN; + if (state >= RXRPC_CALL_COMPLETE) + goto maybe_error; + ret = -EPROTO; + if (state != RXRPC_CALL_CLIENT_SEND_REQUEST && + state != RXRPC_CALL_SERVER_ACK_REQUEST && + state != RXRPC_CALL_SERVER_SEND_REPLY) + goto maybe_error; + + ret = -EMSGSIZE; + if (call->tx_total_len != -1) { + if (len - copied > call->tx_total_len) + goto maybe_error; + if (!more && len - copied != call->tx_total_len) + goto maybe_error; + } + + skb = call->tx_pending; + call->tx_pending = NULL; + rxrpc_see_skb(skb, rxrpc_skb_seen); + + do { + /* Check to see if there's a ping ACK to reply to. */ + if (call->ackr_reason == RXRPC_ACK_PING_RESPONSE) + rxrpc_send_ack_packet(call, false, NULL); + + if (!skb) { + size_t remain, bufsize, chunk, offset; + + _debug("alloc"); + + if (!rxrpc_check_tx_space(call, NULL)) + goto wait_for_space; + + /* Work out the maximum size of a packet. Assume that + * the security header is going to be in the padded + * region (enc blocksize), but the trailer is not. + */ + remain = more ? INT_MAX : msg_data_left(msg); + ret = call->conn->security->how_much_data(call, remain, + &bufsize, &chunk, &offset); + if (ret < 0) + goto maybe_error; + + _debug("SIZE: %zu/%zu @%zu", chunk, bufsize, offset); + + /* create a buffer that we can retain until it's ACK'd */ + skb = sock_alloc_send_skb( + sk, bufsize, msg->msg_flags & MSG_DONTWAIT, &ret); + if (!skb) + goto maybe_error; + + sp = rxrpc_skb(skb); + sp->rx_flags |= RXRPC_SKB_TX_BUFFER; + rxrpc_new_skb(skb, rxrpc_skb_new); + + _debug("ALLOC SEND %p", skb); + + ASSERTCMP(skb->mark, ==, 0); + + __skb_put(skb, offset); + + sp->remain = chunk; + if (sp->remain > skb_tailroom(skb)) + sp->remain = skb_tailroom(skb); + + _net("skb: hr %d, tr %d, hl %d, rm %d", + skb_headroom(skb), + skb_tailroom(skb), + skb_headlen(skb), + sp->remain); + + skb->ip_summed = CHECKSUM_UNNECESSARY; + } + + _debug("append"); + sp = rxrpc_skb(skb); + + /* append next segment of data to the current buffer */ + if (msg_data_left(msg) > 0) { + int copy = skb_tailroom(skb); + ASSERTCMP(copy, >, 0); + if (copy > msg_data_left(msg)) + copy = msg_data_left(msg); + if (copy > sp->remain) + copy = sp->remain; + + _debug("add"); + ret = skb_add_data(skb, &msg->msg_iter, copy); + _debug("added"); + if (ret < 0) + goto efault; + sp->remain -= copy; + skb->mark += copy; + copied += copy; + if (call->tx_total_len != -1) + call->tx_total_len -= copy; + } + + /* check for the far side aborting the call or a network error + * occurring */ + if (call->state == RXRPC_CALL_COMPLETE) + goto call_terminated; + + /* add the packet to the send queue if it's now full */ + if (sp->remain <= 0 || + (msg_data_left(msg) == 0 && !more)) { + struct rxrpc_connection *conn = call->conn; + uint32_t seq; + + seq = call->tx_top + 1; + + sp->hdr.seq = seq; + sp->hdr._rsvd = 0; + sp->hdr.flags = conn->out_clientflag; + + if (msg_data_left(msg) == 0 && !more) + sp->hdr.flags |= 
RXRPC_LAST_PACKET; + else if (call->tx_top - call->tx_hard_ack < + call->tx_winsize) + sp->hdr.flags |= RXRPC_MORE_PACKETS; + + ret = call->security->secure_packet(call, skb, skb->mark); + if (ret < 0) + goto out; + + ret = rxrpc_queue_packet(rx, call, skb, + !msg_data_left(msg) && !more, + notify_end_tx); + /* Should check for failure here */ + skb = NULL; + } + } while (msg_data_left(msg) > 0); + +success: + ret = copied; + if (READ_ONCE(call->state) == RXRPC_CALL_COMPLETE) { + read_lock_bh(&call->state_lock); + if (call->error < 0) + ret = call->error; + read_unlock_bh(&call->state_lock); + } +out: + call->tx_pending = skb; + _leave(" = %d", ret); + return ret; + +call_terminated: + rxrpc_free_skb(skb, rxrpc_skb_freed); + _leave(" = %d", call->error); + return call->error; + +maybe_error: + if (copied) + goto success; + goto out; + +efault: + ret = -EFAULT; + goto out; + +wait_for_space: + ret = -EAGAIN; + if (msg->msg_flags & MSG_DONTWAIT) + goto maybe_error; + mutex_unlock(&call->user_mutex); + *_dropped_lock = true; + ret = rxrpc_wait_for_tx_window(rx, call, &timeo, + msg->msg_flags & MSG_WAITALL); + if (ret < 0) + goto maybe_error; + if (call->interruptibility == RXRPC_INTERRUPTIBLE) { + if (mutex_lock_interruptible(&call->user_mutex) < 0) { + ret = sock_intr_errno(timeo); + goto maybe_error; + } + } else { + mutex_lock(&call->user_mutex); + } + *_dropped_lock = false; + goto reload; +} + +/* + * extract control messages from the sendmsg() control buffer + */ +static int rxrpc_sendmsg_cmsg(struct msghdr *msg, struct rxrpc_send_params *p) +{ + struct cmsghdr *cmsg; + bool got_user_ID = false; + int len; + + if (msg->msg_controllen == 0) + return -EINVAL; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + + len = cmsg->cmsg_len - sizeof(struct cmsghdr); + _debug("CMSG %d, %d, %d", + cmsg->cmsg_level, cmsg->cmsg_type, len); + + if (cmsg->cmsg_level != SOL_RXRPC) + continue; + + switch (cmsg->cmsg_type) { + case RXRPC_USER_CALL_ID: + if (msg->msg_flags & MSG_CMSG_COMPAT) { + if (len != sizeof(u32)) + return -EINVAL; + p->call.user_call_ID = *(u32 *)CMSG_DATA(cmsg); + } else { + if (len != sizeof(unsigned long)) + return -EINVAL; + p->call.user_call_ID = *(unsigned long *) + CMSG_DATA(cmsg); + } + got_user_ID = true; + break; + + case RXRPC_ABORT: + if (p->command != RXRPC_CMD_SEND_DATA) + return -EINVAL; + p->command = RXRPC_CMD_SEND_ABORT; + if (len != sizeof(p->abort_code)) + return -EINVAL; + p->abort_code = *(unsigned int *)CMSG_DATA(cmsg); + if (p->abort_code == 0) + return -EINVAL; + break; + + case RXRPC_CHARGE_ACCEPT: + if (p->command != RXRPC_CMD_SEND_DATA) + return -EINVAL; + p->command = RXRPC_CMD_CHARGE_ACCEPT; + if (len != 0) + return -EINVAL; + break; + + case RXRPC_EXCLUSIVE_CALL: + p->exclusive = true; + if (len != 0) + return -EINVAL; + break; + + case RXRPC_UPGRADE_SERVICE: + p->upgrade = true; + if (len != 0) + return -EINVAL; + break; + + case RXRPC_TX_LENGTH: + if (p->call.tx_total_len != -1 || len != sizeof(__s64)) + return -EINVAL; + p->call.tx_total_len = *(__s64 *)CMSG_DATA(cmsg); + if (p->call.tx_total_len < 0) + return -EINVAL; + break; + + case RXRPC_SET_CALL_TIMEOUT: + if (len & 3 || len < 4 || len > 12) + return -EINVAL; + memcpy(&p->call.timeouts, CMSG_DATA(cmsg), len); + p->call.nr_timeouts = len / 4; + if (p->call.timeouts.hard > INT_MAX / HZ) + return -ERANGE; + if (p->call.nr_timeouts >= 2 && p->call.timeouts.idle > 60 * 60 * 1000) + return -ERANGE; + if (p->call.nr_timeouts >= 3 && p->call.timeouts.normal > 60 * 60 * 
1000) + return -ERANGE; + break; + + default: + return -EINVAL; + } + } + + if (!got_user_ID) + return -EINVAL; + if (p->call.tx_total_len != -1 && p->command != RXRPC_CMD_SEND_DATA) + return -EINVAL; + _leave(" = 0"); + return 0; +} + +/* + * Create a new client call for sendmsg(). + * - Called with the socket lock held, which it must release. + * - If it returns a call, the call's lock will need releasing by the caller. + */ +static struct rxrpc_call * +rxrpc_new_client_call_for_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, + struct rxrpc_send_params *p) + __releases(&rx->sk.sk_lock.slock) + __acquires(&call->user_mutex) +{ + struct rxrpc_conn_parameters cp; + struct rxrpc_call *call; + struct key *key; + + DECLARE_SOCKADDR(struct sockaddr_rxrpc *, srx, msg->msg_name); + + _enter(""); + + if (!msg->msg_name) { + release_sock(&rx->sk); + return ERR_PTR(-EDESTADDRREQ); + } + + key = rx->key; + if (key && !rx->key->payload.data[0]) + key = NULL; + + memset(&cp, 0, sizeof(cp)); + cp.local = rx->local; + cp.key = rx->key; + cp.security_level = rx->min_sec_level; + cp.exclusive = rx->exclusive | p->exclusive; + cp.upgrade = p->upgrade; + cp.service_id = srx->srx_service; + call = rxrpc_new_client_call(rx, &cp, srx, &p->call, GFP_KERNEL, + atomic_inc_return(&rxrpc_debug_id)); + /* The socket is now unlocked */ + + rxrpc_put_peer(cp.peer); + _leave(" = %p\n", call); + return call; +} + +/* + * send a message forming part of a client call through an RxRPC socket + * - caller holds the socket locked + * - the socket may be either a client socket or a server socket + */ +int rxrpc_do_sendmsg(struct rxrpc_sock *rx, struct msghdr *msg, size_t len) + __releases(&rx->sk.sk_lock.slock) + __releases(&call->user_mutex) +{ + enum rxrpc_call_state state; + struct rxrpc_call *call; + unsigned long now, j; + bool dropped_lock = false; + int ret; + + struct rxrpc_send_params p = { + .call.tx_total_len = -1, + .call.user_call_ID = 0, + .call.nr_timeouts = 0, + .call.interruptibility = RXRPC_INTERRUPTIBLE, + .abort_code = 0, + .command = RXRPC_CMD_SEND_DATA, + .exclusive = false, + .upgrade = false, + }; + + _enter(""); + + ret = rxrpc_sendmsg_cmsg(msg, &p); + if (ret < 0) + goto error_release_sock; + + if (p.command == RXRPC_CMD_CHARGE_ACCEPT) { + ret = -EINVAL; + if (rx->sk.sk_state != RXRPC_SERVER_LISTENING) + goto error_release_sock; + ret = rxrpc_user_charge_accept(rx, p.call.user_call_ID); + goto error_release_sock; + } + + call = rxrpc_find_call_by_user_ID(rx, p.call.user_call_ID); + if (!call) { + ret = -EBADSLT; + if (p.command != RXRPC_CMD_SEND_DATA) + goto error_release_sock; + call = rxrpc_new_client_call_for_sendmsg(rx, msg, &p); + /* The socket is now unlocked... */ + if (IS_ERR(call)) + return PTR_ERR(call); + /* ... and we have the call lock. 
*/ + ret = 0; + if (READ_ONCE(call->state) == RXRPC_CALL_COMPLETE) + goto out_put_unlock; + } else { + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_UNINITIALISED: + case RXRPC_CALL_CLIENT_AWAIT_CONN: + case RXRPC_CALL_SERVER_PREALLOC: + case RXRPC_CALL_SERVER_SECURING: + rxrpc_put_call(call, rxrpc_call_put); + ret = -EBUSY; + goto error_release_sock; + default: + break; + } + + ret = mutex_lock_interruptible(&call->user_mutex); + release_sock(&rx->sk); + if (ret < 0) { + ret = -ERESTARTSYS; + goto error_put; + } + + if (p.call.tx_total_len != -1) { + ret = -EINVAL; + if (call->tx_total_len != -1 || + call->tx_pending || + call->tx_top != 0) + goto out_put_unlock; + call->tx_total_len = p.call.tx_total_len; + } + } + + switch (p.call.nr_timeouts) { + case 3: + j = msecs_to_jiffies(p.call.timeouts.normal); + if (p.call.timeouts.normal > 0 && j == 0) + j = 1; + WRITE_ONCE(call->next_rx_timo, j); + fallthrough; + case 2: + j = msecs_to_jiffies(p.call.timeouts.idle); + if (p.call.timeouts.idle > 0 && j == 0) + j = 1; + WRITE_ONCE(call->next_req_timo, j); + fallthrough; + case 1: + if (p.call.timeouts.hard > 0) { + j = p.call.timeouts.hard * HZ; + now = jiffies; + j += now; + WRITE_ONCE(call->expect_term_by, j); + rxrpc_reduce_call_timer(call, j, now, + rxrpc_timer_set_for_hard); + } + break; + } + + state = READ_ONCE(call->state); + _debug("CALL %d USR %lx ST %d on CONN %p", + call->debug_id, call->user_call_ID, state, call->conn); + + if (state >= RXRPC_CALL_COMPLETE) { + /* it's too late for this call */ + ret = -ESHUTDOWN; + } else if (p.command == RXRPC_CMD_SEND_ABORT) { + ret = 0; + if (rxrpc_abort_call("CMD", call, 0, p.abort_code, -ECONNABORTED)) + ret = rxrpc_send_abort_packet(call); + } else if (p.command != RXRPC_CMD_SEND_DATA) { + ret = -EINVAL; + } else { + ret = rxrpc_send_data(rx, call, msg, len, NULL, &dropped_lock); + } + +out_put_unlock: + if (!dropped_lock) + mutex_unlock(&call->user_mutex); +error_put: + rxrpc_put_call(call, rxrpc_call_put); + _leave(" = %d", ret); + return ret; + +error_release_sock: + release_sock(&rx->sk); + return ret; +} + +/** + * rxrpc_kernel_send_data - Allow a kernel service to send data on a call + * @sock: The socket the call is on + * @call: The call to send data through + * @msg: The data to send + * @len: The amount of data to send + * @notify_end_tx: Notification that the last packet is queued. + * + * Allow a kernel service to send data on a call. The call must be in an state + * appropriate to sending data. No control data should be supplied in @msg, + * nor should an address be supplied. MSG_MORE should be flagged if there's + * more data to come, otherwise this data will end the transmission phase. 
+ */ +int rxrpc_kernel_send_data(struct socket *sock, struct rxrpc_call *call, + struct msghdr *msg, size_t len, + rxrpc_notify_end_tx_t notify_end_tx) +{ + bool dropped_lock = false; + int ret; + + _enter("{%d,%s},", call->debug_id, rxrpc_call_states[call->state]); + + ASSERTCMP(msg->msg_name, ==, NULL); + ASSERTCMP(msg->msg_control, ==, NULL); + + mutex_lock(&call->user_mutex); + + _debug("CALL %d USR %lx ST %d on CONN %p", + call->debug_id, call->user_call_ID, call->state, call->conn); + + switch (READ_ONCE(call->state)) { + case RXRPC_CALL_CLIENT_SEND_REQUEST: + case RXRPC_CALL_SERVER_ACK_REQUEST: + case RXRPC_CALL_SERVER_SEND_REPLY: + ret = rxrpc_send_data(rxrpc_sk(sock->sk), call, msg, len, + notify_end_tx, &dropped_lock); + break; + case RXRPC_CALL_COMPLETE: + read_lock_bh(&call->state_lock); + ret = call->error; + read_unlock_bh(&call->state_lock); + break; + default: + /* Request phase complete for this client call */ + trace_rxrpc_rx_eproto(call, 0, tracepoint_string("late_send")); + ret = -EPROTO; + break; + } + + if (!dropped_lock) + mutex_unlock(&call->user_mutex); + _leave(" = %d", ret); + return ret; +} +EXPORT_SYMBOL(rxrpc_kernel_send_data); + +/** + * rxrpc_kernel_abort_call - Allow a kernel service to abort a call + * @sock: The socket the call is on + * @call: The call to be aborted + * @abort_code: The abort code to stick into the ABORT packet + * @error: Local error value + * @why: 3-char string indicating why. + * + * Allow a kernel service to abort a call, if it's still in an abortable state + * and return true if the call was aborted, false if it was already complete. + */ +bool rxrpc_kernel_abort_call(struct socket *sock, struct rxrpc_call *call, + u32 abort_code, int error, const char *why) +{ + bool aborted; + + _enter("{%d},%d,%d,%s", call->debug_id, abort_code, error, why); + + mutex_lock(&call->user_mutex); + + aborted = rxrpc_abort_call(why, call, 0, abort_code, error); + if (aborted) + rxrpc_send_abort_packet(call); + + mutex_unlock(&call->user_mutex); + return aborted; +} +EXPORT_SYMBOL(rxrpc_kernel_abort_call); + +/** + * rxrpc_kernel_set_tx_length - Set the total Tx length on a call + * @sock: The socket the call is on + * @call: The call to be informed + * @tx_total_len: The amount of data to be transmitted for this call + * + * Allow a kernel service to set the total transmit length on a call. This + * allows buffer-to-packet encrypt-and-copy to be performed. + * + * This function is primarily for use for setting the reply length since the + * request length can be set when beginning the call. + */ +void rxrpc_kernel_set_tx_length(struct socket *sock, struct rxrpc_call *call, + s64 tx_total_len) +{ + WARN_ON(call->tx_total_len != -1); + call->tx_total_len = tx_total_len; +} +EXPORT_SYMBOL(rxrpc_kernel_set_tx_length); diff --git a/net/rxrpc/server_key.c b/net/rxrpc/server_key.c new file mode 100644 index 000000000..ee269e0e6 --- /dev/null +++ b/net/rxrpc/server_key.c @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* RxRPC key management + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. 
+ * Written by David Howells (dhowells@redhat.com) + * + * RxRPC keys should have a description of describing their purpose: + * "afs@CAMBRIDGE.REDHAT.COM> + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "ar-internal.h" + +static int rxrpc_vet_description_s(const char *); +static int rxrpc_preparse_s(struct key_preparsed_payload *); +static void rxrpc_free_preparse_s(struct key_preparsed_payload *); +static void rxrpc_destroy_s(struct key *); +static void rxrpc_describe_s(const struct key *, struct seq_file *); + +/* + * rxrpc server keys take ":[:]" as the + * description and the key material as the payload. + */ +struct key_type key_type_rxrpc_s = { + .name = "rxrpc_s", + .flags = KEY_TYPE_NET_DOMAIN, + .vet_description = rxrpc_vet_description_s, + .preparse = rxrpc_preparse_s, + .free_preparse = rxrpc_free_preparse_s, + .instantiate = generic_key_instantiate, + .destroy = rxrpc_destroy_s, + .describe = rxrpc_describe_s, +}; + +/* + * Vet the description for an RxRPC server key. + */ +static int rxrpc_vet_description_s(const char *desc) +{ + unsigned long service, sec_class; + char *p; + + service = simple_strtoul(desc, &p, 10); + if (*p != ':' || service > 65535) + return -EINVAL; + sec_class = simple_strtoul(p + 1, &p, 10); + if ((*p && *p != ':') || sec_class < 1 || sec_class > 255) + return -EINVAL; + return 0; +} + +/* + * Preparse a server secret key. + */ +static int rxrpc_preparse_s(struct key_preparsed_payload *prep) +{ + const struct rxrpc_security *sec; + unsigned int service, sec_class; + int n; + + _enter("%zu", prep->datalen); + + if (!prep->orig_description) + return -EINVAL; + + if (sscanf(prep->orig_description, "%u:%u%n", &service, &sec_class, &n) != 2) + return -EINVAL; + + sec = rxrpc_security_lookup(sec_class); + if (!sec) + return -ENOPKG; + + prep->payload.data[1] = (struct rxrpc_security *)sec; + + if (!sec->preparse_server_key) + return -EINVAL; + + return sec->preparse_server_key(prep); +} + +static void rxrpc_free_preparse_s(struct key_preparsed_payload *prep) +{ + const struct rxrpc_security *sec = prep->payload.data[1]; + + if (sec && sec->free_preparse_server_key) + sec->free_preparse_server_key(prep); +} + +static void rxrpc_destroy_s(struct key *key) +{ + const struct rxrpc_security *sec = key->payload.data[1]; + + if (sec && sec->destroy_server_key) + sec->destroy_server_key(key); +} + +static void rxrpc_describe_s(const struct key *key, struct seq_file *m) +{ + const struct rxrpc_security *sec = key->payload.data[1]; + + seq_puts(m, key->description); + if (sec && sec->describe_server_key) + sec->describe_server_key(key, m); +} + +/* + * grab the security keyring for a server socket + */ +int rxrpc_server_keyring(struct rxrpc_sock *rx, sockptr_t optval, int optlen) +{ + struct key *key; + char *description; + + _enter(""); + + if (optlen <= 0 || optlen > PAGE_SIZE - 1) + return -EINVAL; + + description = memdup_sockptr_nul(optval, optlen); + if (IS_ERR(description)) + return PTR_ERR(description); + + key = request_key(&key_type_keyring, description, NULL); + if (IS_ERR(key)) { + kfree(description); + _leave(" = %ld", PTR_ERR(key)); + return PTR_ERR(key); + } + + rx->securities = key; + kfree(description); + _leave(" = 0 [key %x]", key->serial); + return 0; +} diff --git a/net/rxrpc/skbuff.c b/net/rxrpc/skbuff.c new file mode 100644 index 000000000..580a5acff --- /dev/null +++ b/net/rxrpc/skbuff.c @@ -0,0 +1,95 @@ +// 
SPDX-License-Identifier: GPL-2.0-or-later +/* ar-skbuff.c: socket buffer destruction handling + * + * Copyright (C) 2007 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include "ar-internal.h" + +#define is_tx_skb(skb) (rxrpc_skb(skb)->rx_flags & RXRPC_SKB_TX_BUFFER) +#define select_skb_count(skb) (is_tx_skb(skb) ? &rxrpc_n_tx_skbs : &rxrpc_n_rx_skbs) + +/* + * Note the allocation or reception of a socket buffer. + */ +void rxrpc_new_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +{ + const void *here = __builtin_return_address(0); + int n = atomic_inc_return(select_skb_count(skb)); + trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, + rxrpc_skb(skb)->rx_flags, here); +} + +/* + * Note the re-emergence of a socket buffer from a queue or buffer. + */ +void rxrpc_see_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +{ + const void *here = __builtin_return_address(0); + if (skb) { + int n = atomic_read(select_skb_count(skb)); + trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, + rxrpc_skb(skb)->rx_flags, here); + } +} + +/* + * Note the addition of a ref on a socket buffer. + */ +void rxrpc_get_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +{ + const void *here = __builtin_return_address(0); + int n = atomic_inc_return(select_skb_count(skb)); + trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, + rxrpc_skb(skb)->rx_flags, here); + skb_get(skb); +} + +/* + * Note the dropping of a ref on a socket buffer by the core. + */ +void rxrpc_eaten_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +{ + const void *here = __builtin_return_address(0); + int n = atomic_inc_return(&rxrpc_n_rx_skbs); + trace_rxrpc_skb(skb, op, 0, n, 0, here); +} + +/* + * Note the destruction of a socket buffer. + */ +void rxrpc_free_skb(struct sk_buff *skb, enum rxrpc_skb_trace op) +{ + const void *here = __builtin_return_address(0); + if (skb) { + int n; + n = atomic_dec_return(select_skb_count(skb)); + trace_rxrpc_skb(skb, op, refcount_read(&skb->users), n, + rxrpc_skb(skb)->rx_flags, here); + kfree_skb(skb); + } +} + +/* + * Clear a queue of socket buffers. + */ +void rxrpc_purge_queue(struct sk_buff_head *list) +{ + const void *here = __builtin_return_address(0); + struct sk_buff *skb; + while ((skb = skb_dequeue((list))) != NULL) { + int n = atomic_dec_return(select_skb_count(skb)); + trace_rxrpc_skb(skb, rxrpc_skb_purged, + refcount_read(&skb->users), n, + rxrpc_skb(skb)->rx_flags, here); + kfree_skb(skb); + } +} diff --git a/net/rxrpc/sysctl.c b/net/rxrpc/sysctl.c new file mode 100644 index 000000000..555e09107 --- /dev/null +++ b/net/rxrpc/sysctl.c @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* sysctls for configuring RxRPC operating parameters + * + * Copyright (C) 2014 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include "ar-internal.h" + +static struct ctl_table_header *rxrpc_sysctl_reg_table; +static const unsigned int four = 4; +static const unsigned int max_backlog = RXRPC_BACKLOG_MAX - 1; +static const unsigned int n_65535 = 65535; +static const unsigned int n_max_acks = RXRPC_RXTX_BUFF_SIZE - 1; +static const unsigned long one_jiffy = 1; +static const unsigned long max_jiffies = MAX_JIFFY_OFFSET; + +/* + * RxRPC operating parameters. 
+ * + * See Documentation/networking/rxrpc.rst and the variable definitions for more + * information on the individual parameters. + */ +static struct ctl_table rxrpc_sysctl_table[] = { + /* Values measured in milliseconds but used in jiffies */ + { + .procname = "req_ack_delay", + .data = &rxrpc_requested_ack_delay, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = (void *)&one_jiffy, + .extra2 = (void *)&max_jiffies, + }, + { + .procname = "soft_ack_delay", + .data = &rxrpc_soft_ack_delay, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = (void *)&one_jiffy, + .extra2 = (void *)&max_jiffies, + }, + { + .procname = "idle_ack_delay", + .data = &rxrpc_idle_ack_delay, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = (void *)&one_jiffy, + .extra2 = (void *)&max_jiffies, + }, + { + .procname = "idle_conn_expiry", + .data = &rxrpc_conn_idle_client_expiry, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = (void *)&one_jiffy, + .extra2 = (void *)&max_jiffies, + }, + { + .procname = "idle_conn_fast_expiry", + .data = &rxrpc_conn_idle_client_fast_expiry, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = proc_doulongvec_ms_jiffies_minmax, + .extra1 = (void *)&one_jiffy, + .extra2 = (void *)&max_jiffies, + }, + + /* Non-time values */ + { + .procname = "reap_client_conns", + .data = &rxrpc_reap_client_connections, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + .extra2 = (void *)&n_65535, + }, + { + .procname = "max_backlog", + .data = &rxrpc_max_backlog, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)&four, + .extra2 = (void *)&max_backlog, + }, + { + .procname = "rx_window_size", + .data = &rxrpc_rx_window_size, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + .extra2 = (void *)&n_max_acks, + }, + { + .procname = "rx_mtu", + .data = &rxrpc_rx_mtu, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + .extra2 = (void *)&n_65535, + }, + { + .procname = "rx_jumbo_max", + .data = &rxrpc_rx_jumbo_max, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = (void *)SYSCTL_ONE, + .extra2 = (void *)&four, + }, + + { } +}; + +int __init rxrpc_sysctl_init(void) +{ + rxrpc_sysctl_reg_table = register_net_sysctl(&init_net, "net/rxrpc", + rxrpc_sysctl_table); + if (!rxrpc_sysctl_reg_table) + return -ENOMEM; + return 0; +} + +void rxrpc_sysctl_exit(void) +{ + if (rxrpc_sysctl_reg_table) + unregister_net_sysctl_table(rxrpc_sysctl_reg_table); +} diff --git a/net/rxrpc/utils.c b/net/rxrpc/utils.c new file mode 100644 index 000000000..2e4b9d86e --- /dev/null +++ b/net/rxrpc/utils.c @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Utility routines + * + * Copyright (C) 2015 Red Hat, Inc. All Rights Reserved. + * Written by David Howells (dhowells@redhat.com) + */ + +#include +#include +#include +#include "ar-internal.h" + +/* + * Fill out a peer address from a socket buffer containing a packet. 
+ */ +int rxrpc_extract_addr_from_skb(struct sockaddr_rxrpc *srx, struct sk_buff *skb) +{ + memset(srx, 0, sizeof(*srx)); + + switch (ntohs(skb->protocol)) { + case ETH_P_IP: + srx->transport_type = SOCK_DGRAM; + srx->transport_len = sizeof(srx->transport.sin); + srx->transport.sin.sin_family = AF_INET; + srx->transport.sin.sin_port = udp_hdr(skb)->source; + srx->transport.sin.sin_addr.s_addr = ip_hdr(skb)->saddr; + return 0; + +#ifdef CONFIG_AF_RXRPC_IPV6 + case ETH_P_IPV6: + srx->transport_type = SOCK_DGRAM; + srx->transport_len = sizeof(srx->transport.sin6); + srx->transport.sin6.sin6_family = AF_INET6; + srx->transport.sin6.sin6_port = udp_hdr(skb)->source; + srx->transport.sin6.sin6_addr = ipv6_hdr(skb)->saddr; + return 0; +#endif + + default: + pr_warn_ratelimited("AF_RXRPC: Unknown eth protocol %u\n", + ntohs(skb->protocol)); + return -EAFNOSUPPORT; + } +} diff --git a/net/sched/Kconfig b/net/sched/Kconfig new file mode 100644 index 000000000..24cf0bf7c --- /dev/null +++ b/net/sched/Kconfig @@ -0,0 +1,988 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Traffic control configuration. +# + +menuconfig NET_SCHED + bool "QoS and/or fair queueing" + select NET_SCH_FIFO + help + When the kernel has several packets to send out over a network + device, it has to decide which ones to send first, which ones to + delay, and which ones to drop. This is the job of the queueing + disciplines, several different algorithms for how to do this + "fairly" have been proposed. + + If you say N here, you will get the standard packet scheduler, which + is a FIFO (first come, first served). If you say Y here, you will be + able to choose from among several alternative algorithms which can + then be attached to different network devices. This is useful for + example if some of your network devices are real time devices that + need a certain minimum data flow rate, or if you need to limit the + maximum data flow rate for traffic which matches specified criteria. + This code is considered to be experimental. + + To administer these schedulers, you'll need the user-level utilities + from the package iproute2+tc at + . That package + also contains some documentation; for more, check out + . + + This Quality of Service (QoS) support will enable you to use + Differentiated Services (diffserv) and Resource Reservation Protocol + (RSVP) on your Linux router if you also say Y to the corresponding + classifiers below. Documentation and software is at + . + + If you say Y here and to "/proc file system" below, you will be able + to read status information about packet schedulers from the file + /proc/net/psched. + + The available schedulers are listed in the following questions; you + can say Y to as many as you like. If unsure, say N now. + +if NET_SCHED + +comment "Queueing/Scheduling" + +config NET_SCH_CBQ + tristate "Class Based Queueing (CBQ)" + help + Say Y here if you want to use the Class-Based Queueing (CBQ) packet + scheduling algorithm. This algorithm classifies the waiting packets + into a tree-like hierarchy of classes; the leaves of this tree are + in turn scheduled by separate algorithms. + + See the top of for more details. + + CBQ is a commonly used scheduler, so if you're unsure, you should + say Y here. Then say Y to all the queueing algorithms below that you + want to use as leaf disciplines. + + To compile this code as a module, choose M here: the + module will be called sch_cbq. 
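[Editorial aside, not part of the patch: the NET_SCHED help above notes that, with procfs enabled, packet-scheduler status can be read from /proc/net/psched. A minimal userspace sketch that dumps that file is shown below; it assumes only a kernel built with NET_SCHED and a mounted /proc, and is illustrative rather than anything shipped with iproute2 or this patch.

/* Sketch: print the raw contents of /proc/net/psched. */
#include <stdio.h>

int main(void)
{
	char line[256];
	FILE *f = fopen("/proc/net/psched", "r");

	if (!f) {
		perror("/proc/net/psched");
		return 1;
	}
	while (fgets(line, sizeof(line), f))
		fputs(line, stdout);	/* four hex fields describing clock resolution */
	fclose(f);
	return 0;
}

In practice the tc utility from iproute2 parses this file itself, so the sketch is only useful for confirming that the scheduler infrastructure is present.]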
+ +config NET_SCH_HTB + tristate "Hierarchical Token Bucket (HTB)" + help + Say Y here if you want to use the Hierarchical Token Buckets (HTB) + packet scheduling algorithm. See + for complete manual and + in-depth articles. + + HTB is very similar to CBQ regarding its goals however is has + different properties and different algorithm. + + To compile this code as a module, choose M here: the + module will be called sch_htb. + +config NET_SCH_HFSC + tristate "Hierarchical Fair Service Curve (HFSC)" + help + Say Y here if you want to use the Hierarchical Fair Service Curve + (HFSC) packet scheduling algorithm. + + To compile this code as a module, choose M here: the + module will be called sch_hfsc. + +config NET_SCH_ATM + tristate "ATM Virtual Circuits (ATM)" + depends on ATM + help + Say Y here if you want to use the ATM pseudo-scheduler. This + provides a framework for invoking classifiers, which in turn + select classes of this queuing discipline. Each class maps + the flow(s) it is handling to a given virtual circuit. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_atm. + +config NET_SCH_PRIO + tristate "Multi Band Priority Queueing (PRIO)" + help + Say Y here if you want to use an n-band priority queue packet + scheduler. + + To compile this code as a module, choose M here: the + module will be called sch_prio. + +config NET_SCH_MULTIQ + tristate "Hardware Multiqueue-aware Multi Band Queuing (MULTIQ)" + help + Say Y here if you want to use an n-band queue packet scheduler + to support devices that have multiple hardware transmit queues. + + To compile this code as a module, choose M here: the + module will be called sch_multiq. + +config NET_SCH_RED + tristate "Random Early Detection (RED)" + help + Say Y here if you want to use the Random Early Detection (RED) + packet scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_red. + +config NET_SCH_SFB + tristate "Stochastic Fair Blue (SFB)" + help + Say Y here if you want to use the Stochastic Fair Blue (SFB) + packet scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_sfb. + +config NET_SCH_SFQ + tristate "Stochastic Fairness Queueing (SFQ)" + help + Say Y here if you want to use the Stochastic Fairness Queueing (SFQ) + packet scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_sfq. + +config NET_SCH_TEQL + tristate "True Link Equalizer (TEQL)" + help + Say Y here if you want to use the True Link Equalizer (TLE) packet + scheduling algorithm. This queueing discipline allows the combination + of several physical devices into one virtual device. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_teql. + +config NET_SCH_TBF + tristate "Token Bucket Filter (TBF)" + help + Say Y here if you want to use the Token Bucket Filter (TBF) packet + scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_tbf. + +config NET_SCH_CBS + tristate "Credit Based Shaper (CBS)" + help + Say Y here if you want to use the Credit Based Shaper (CBS) packet + scheduling algorithm. + + See the top of for more details. 
+ + To compile this code as a module, choose M here: the + module will be called sch_cbs. + +config NET_SCH_ETF + tristate "Earliest TxTime First (ETF)" + help + Say Y here if you want to use the Earliest TxTime First (ETF) packet + scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_etf. + +config NET_SCH_TAPRIO + tristate "Time Aware Priority (taprio) Scheduler" + help + Say Y here if you want to use the Time Aware Priority (taprio) packet + scheduling algorithm. + + See the top of for more details. + + To compile this code as a module, choose M here: the + module will be called sch_taprio. + +config NET_SCH_GRED + tristate "Generic Random Early Detection (GRED)" + help + Say Y here if you want to use the Generic Random Early Detection + (GRED) packet scheduling algorithm for some of your network devices + (see the top of for details and + references about the algorithm). + + To compile this code as a module, choose M here: the + module will be called sch_gred. + +config NET_SCH_DSMARK + tristate "Differentiated Services marker (DSMARK)" + help + Say Y if you want to schedule packets according to the + Differentiated Services architecture proposed in RFC 2475. + Technical information on this method, with pointers to associated + RFCs, is available at . + + To compile this code as a module, choose M here: the + module will be called sch_dsmark. + +config NET_SCH_NETEM + tristate "Network emulator (NETEM)" + help + Say Y if you want to emulate network delay, loss, and packet + re-ordering. This is often useful to simulate networks when + testing applications or protocols. + + To compile this driver as a module, choose M here: the module + will be called sch_netem. + + If unsure, say N. + +config NET_SCH_DRR + tristate "Deficit Round Robin scheduler (DRR)" + help + Say Y here if you want to use the Deficit Round Robin (DRR) packet + scheduling algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_drr. + + If unsure, say N. + +config NET_SCH_MQPRIO + tristate "Multi-queue priority scheduler (MQPRIO)" + help + Say Y here if you want to use the Multi-queue Priority scheduler. + This scheduler allows QOS to be offloaded on NICs that have support + for offloading QOS schedulers. + + To compile this driver as a module, choose M here: the module will + be called sch_mqprio. + + If unsure, say N. + +config NET_SCH_SKBPRIO + tristate "SKB priority queue scheduler (SKBPRIO)" + help + Say Y here if you want to use the SKB priority queue + scheduler. This schedules packets according to skb->priority, + which is useful for request packets in DoS mitigation systems such + as Gatekeeper. + + To compile this driver as a module, choose M here: the module will + be called sch_skbprio. + + If unsure, say N. + +config NET_SCH_CHOKE + tristate "CHOose and Keep responsive flow scheduler (CHOKE)" + help + Say Y here if you want to use the CHOKe packet scheduler (CHOose + and Keep for responsive flows, CHOose and Kill for unresponsive + flows). This is a variation of RED which tries to penalize flows + that monopolize the queue. + + To compile this code as a module, choose M here: the + module will be called sch_choke. + +config NET_SCH_QFQ + tristate "Quick Fair Queueing scheduler (QFQ)" + help + Say Y here if you want to use the Quick Fair Queueing Scheduler (QFQ) + packet scheduling algorithm. 
+ + To compile this driver as a module, choose M here: the module + will be called sch_qfq. + + If unsure, say N. + +config NET_SCH_CODEL + tristate "Controlled Delay AQM (CODEL)" + help + Say Y here if you want to use the Controlled Delay (CODEL) + packet scheduling algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_codel. + + If unsure, say N. + +config NET_SCH_FQ_CODEL + tristate "Fair Queue Controlled Delay AQM (FQ_CODEL)" + help + Say Y here if you want to use the FQ Controlled Delay (FQ_CODEL) + packet scheduling algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_fq_codel. + + If unsure, say N. + +config NET_SCH_CAKE + tristate "Common Applications Kept Enhanced (CAKE)" + help + Say Y here if you want to use the Common Applications Kept Enhanced + (CAKE) queue management algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_cake. + + If unsure, say N. + +config NET_SCH_FQ + tristate "Fair Queue" + help + Say Y here if you want to use the FQ packet scheduling algorithm. + + FQ does flow separation, and is able to respect pacing requirements + set by TCP stack into sk->sk_pacing_rate (for localy generated + traffic) + + To compile this driver as a module, choose M here: the module + will be called sch_fq. + + If unsure, say N. + +config NET_SCH_HHF + tristate "Heavy-Hitter Filter (HHF)" + help + Say Y here if you want to use the Heavy-Hitter Filter (HHF) + packet scheduling algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_hhf. + +config NET_SCH_PIE + tristate "Proportional Integral controller Enhanced (PIE) scheduler" + help + Say Y here if you want to use the Proportional Integral controller + Enhanced scheduler packet scheduling algorithm. + For more information, please see https://tools.ietf.org/html/rfc8033 + + To compile this driver as a module, choose M here: the module + will be called sch_pie. + + If unsure, say N. + +config NET_SCH_FQ_PIE + depends on NET_SCH_PIE + tristate "Flow Queue Proportional Integral controller Enhanced (FQ-PIE)" + help + Say Y here if you want to use the Flow Queue Proportional Integral + controller Enhanced (FQ-PIE) packet scheduling algorithm. + For more information, please see https://tools.ietf.org/html/rfc8033 + + To compile this driver as a module, choose M here: the module + will be called sch_fq_pie. + + If unsure, say N. + +config NET_SCH_INGRESS + tristate "Ingress/classifier-action Qdisc" + depends on NET_CLS_ACT + select NET_INGRESS + select NET_EGRESS + help + Say Y here if you want to use classifiers for incoming and/or outgoing + packets. This qdisc doesn't do anything else besides running classifiers, + which can also have actions attached to them. In case of outgoing packets, + classifiers that this qdisc holds are executed in the transmit path + before real enqueuing to an egress qdisc happens. + + If unsure, say Y. + + To compile this code as a module, choose M here: the module will be + called sch_ingress with alias of sch_clsact. + +config NET_SCH_PLUG + tristate "Plug network traffic until release (PLUG)" + help + + This queuing discipline allows userspace to plug/unplug a network + output queue, using the netlink interface. 
When it receives an + enqueue command it inserts a plug into the outbound queue that + causes following packets to enqueue until a dequeue command arrives + over netlink, causing the plug to be removed and resuming the normal + packet flow. + + This module also provides a generic "network output buffering" + functionality (aka output commit), wherein upon arrival of a dequeue + command, only packets up to the first plug are released for delivery. + The Remus HA project uses this module to enable speculative execution + of virtual machines by allowing the generated network output to be rolled + back if needed. + + For more information, please refer to + + Say Y here if you are using this kernel for Xen dom0 and + want to protect Xen guests with Remus. + + To compile this code as a module, choose M here: the + module will be called sch_plug. + +config NET_SCH_ETS + tristate "Enhanced transmission selection scheduler (ETS)" + help + The Enhanced Transmission Selection scheduler is a classful + queuing discipline that merges functionality of PRIO and DRR + qdiscs in one scheduler. ETS makes it easy to configure a set of + strict and bandwidth-sharing bands to implement the transmission + selection described in 802.1Qaz. + + Say Y here if you want to use the ETS packet scheduling + algorithm. + + To compile this driver as a module, choose M here: the module + will be called sch_ets. + + If unsure, say N. + +menuconfig NET_SCH_DEFAULT + bool "Allow override default queue discipline" + help + Support for selection of default queuing discipline. + + Nearly all users can safely say no here, and the default + of pfifo_fast will be used. Many distributions already set + the default value via /proc/sys/net/core/default_qdisc. + + If unsure, say N. + +if NET_SCH_DEFAULT + +choice + prompt "Default queuing discipline" + default DEFAULT_PFIFO_FAST + help + Select the queueing discipline that will be used by default + for all network devices. + + config DEFAULT_FQ + bool "Fair Queue" if NET_SCH_FQ + + config DEFAULT_CODEL + bool "Controlled Delay" if NET_SCH_CODEL + + config DEFAULT_FQ_CODEL + bool "Fair Queue Controlled Delay" if NET_SCH_FQ_CODEL + + config DEFAULT_FQ_PIE + bool "Flow Queue Proportional Integral controller Enhanced" if NET_SCH_FQ_PIE + + config DEFAULT_SFQ + bool "Stochastic Fair Queue" if NET_SCH_SFQ + + config DEFAULT_PFIFO_FAST + bool "Priority FIFO Fast" +endchoice + +config DEFAULT_NET_SCH + string + default "pfifo_fast" if DEFAULT_PFIFO_FAST + default "fq" if DEFAULT_FQ + default "fq_codel" if DEFAULT_FQ_CODEL + default "fq_pie" if DEFAULT_FQ_PIE + default "sfq" if DEFAULT_SFQ + default "pfifo_fast" +endif + +comment "Classification" + +config NET_CLS + bool + +config NET_CLS_BASIC + tristate "Elementary classification (BASIC)" + select NET_CLS + help + Say Y here if you want to be able to classify packets using + only extended matches and actions. + + To compile this code as a module, choose M here: the + module will be called cls_basic. + +config NET_CLS_ROUTE4 + tristate "Routing decision (ROUTE)" + depends on INET + select IP_ROUTE_CLASSID + select NET_CLS + help + If you say Y here, you will be able to classify packets + according to the route table entry they matched. + + To compile this code as a module, choose M here: the + module will be called cls_route. + +config NET_CLS_FW + tristate "Netfilter mark (FW)" + select NET_CLS + help + If you say Y here, you will be able to classify packets + according to netfilter/firewall marks. 
+ + To compile this code as a module, choose M here: the + module will be called cls_fw. + +config NET_CLS_U32 + tristate "Universal 32bit comparisons w/ hashing (U32)" + select NET_CLS + help + Say Y here to be able to classify packets using a universal + 32bit pieces based comparison scheme. + + To compile this code as a module, choose M here: the + module will be called cls_u32. + +config CLS_U32_PERF + bool "Performance counters support" + depends on NET_CLS_U32 + help + Say Y here to make u32 gather additional statistics useful for + fine tuning u32 classifiers. + +config CLS_U32_MARK + bool "Netfilter marks support" + depends on NET_CLS_U32 + help + Say Y here to be able to use netfilter marks as u32 key. + +config NET_CLS_FLOW + tristate "Flow classifier" + select NET_CLS + help + If you say Y here, you will be able to classify packets based on + a configurable combination of packet keys. This is mostly useful + in combination with SFQ. + + To compile this code as a module, choose M here: the + module will be called cls_flow. + +config NET_CLS_CGROUP + tristate "Control Group Classifier" + select NET_CLS + select CGROUP_NET_CLASSID + depends on CGROUPS + help + Say Y here if you want to classify packets based on the control + cgroup of their process. + + To compile this code as a module, choose M here: the + module will be called cls_cgroup. + +config NET_CLS_BPF + tristate "BPF-based classifier" + select NET_CLS + help + If you say Y here, you will be able to classify packets based on + programmable BPF (JIT'ed) filters as an alternative to ematches. + + To compile this code as a module, choose M here: the module will + be called cls_bpf. + +config NET_CLS_FLOWER + tristate "Flower classifier" + select NET_CLS + help + If you say Y here, you will be able to classify packets based on + a configurable combination of packet keys and masks. + + To compile this code as a module, choose M here: the module will + be called cls_flower. + +config NET_CLS_MATCHALL + tristate "Match-all classifier" + select NET_CLS + help + If you say Y here, you will be able to classify packets based on + nothing. Every packet will match. + + To compile this code as a module, choose M here: the module will + be called cls_matchall. + +config NET_EMATCH + bool "Extended Matches" + select NET_CLS + help + Say Y here if you want to use extended matches on top of classifiers + and select the extended matches below. + + Extended matches are small classification helpers not worth writing + a separate classifier for. + + A recent version of the iproute2 package is required to use + extended matches. + +config NET_EMATCH_STACK + int "Stack size" + depends on NET_EMATCH + default "32" + help + Size of the local stack variable used while evaluating the tree of + ematches. Limits the depth of the tree, i.e. the number of + encapsulated precedences. Every level requires 4 bytes of additional + stack space. + +config NET_EMATCH_CMP + tristate "Simple packet data comparison" + depends on NET_EMATCH + help + Say Y here if you want to be able to classify packets based on + simple packet data comparisons for 8, 16, and 32bit values. + + To compile this code as a module, choose M here: the + module will be called em_cmp. + +config NET_EMATCH_NBYTE + tristate "Multi byte comparison" + depends on NET_EMATCH + help + Say Y here if you want to be able to classify packets based on + multiple byte comparisons mainly useful for IPv6 address comparisons. 
+ + To compile this code as a module, choose M here: the + module will be called em_nbyte. + +config NET_EMATCH_U32 + tristate "U32 key" + depends on NET_EMATCH + help + Say Y here if you want to be able to classify packets using + the famous u32 key in combination with logic relations. + + To compile this code as a module, choose M here: the + module will be called em_u32. + +config NET_EMATCH_META + tristate "Metadata" + depends on NET_EMATCH + help + Say Y here if you want to be able to classify packets based on + metadata such as load average, netfilter attributes, socket + attributes and routing decisions. + + To compile this code as a module, choose M here: the + module will be called em_meta. + +config NET_EMATCH_TEXT + tristate "Textsearch" + depends on NET_EMATCH + select TEXTSEARCH + select TEXTSEARCH_KMP + select TEXTSEARCH_BM + select TEXTSEARCH_FSM + help + Say Y here if you want to be able to classify packets based on + textsearch comparisons. + + To compile this code as a module, choose M here: the + module will be called em_text. + +config NET_EMATCH_CANID + tristate "CAN Identifier" + depends on NET_EMATCH && (CAN=y || CAN=m) + help + Say Y here if you want to be able to classify CAN frames based + on CAN Identifier. + + To compile this code as a module, choose M here: the + module will be called em_canid. + +config NET_EMATCH_IPSET + tristate "IPset" + depends on NET_EMATCH && IP_SET + help + Say Y here if you want to be able to classify packets based on + ipset membership. + + To compile this code as a module, choose M here: the + module will be called em_ipset. + +config NET_EMATCH_IPT + tristate "IPtables Matches" + depends on NET_EMATCH && NETFILTER && NETFILTER_XTABLES + help + Say Y here to be able to classify packets based on iptables + matches. + Current supported match is "policy" which allows packet classification + based on IPsec policy that was used during decapsulation + + To compile this code as a module, choose M here: the + module will be called em_ipt. + +config NET_CLS_ACT + bool "Actions" + select NET_CLS + help + Say Y here if you want to use traffic control actions. Actions + get attached to classifiers and are invoked after a successful + classification. They are used to overwrite the classification + result, instantly drop or redirect packets, etc. + + A recent version of the iproute2 package is required to use + extended matches. + +config NET_ACT_POLICE + tristate "Traffic Policing" + depends on NET_CLS_ACT + help + Say Y here if you want to do traffic policing, i.e. strict + bandwidth limiting. This action replaces the existing policing + module. + + To compile this code as a module, choose M here: the + module will be called act_police. + +config NET_ACT_GACT + tristate "Generic actions" + depends on NET_CLS_ACT + help + Say Y here to take generic actions such as dropping and + accepting packets. + + To compile this code as a module, choose M here: the + module will be called act_gact. + +config GACT_PROB + bool "Probability support" + depends on NET_ACT_GACT + help + Say Y here to use the generic action randomly or deterministically. + +config NET_ACT_MIRRED + tristate "Redirecting and Mirroring" + depends on NET_CLS_ACT + help + Say Y here to allow packets to be mirrored or redirected to + other devices. + + To compile this code as a module, choose M here: the + module will be called act_mirred. 
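
An action, once attached to a classifier as described above, is driven through a single act() callback whose return code overrides the classifier's verdict: it can drop the packet, let it pass, or hand it on to the next action. The sketch below is illustrative only and not part of this patch; the example_ name is hypothetical, and a complete action module additionally supplies init() and dump() callbacks, per-netns state and a tcf_register_action() call.

/*
 * Illustrative sketch, not from this patch: the core callback an action
 * module implements, matching the act() member of struct tc_action_ops.
 */
#include <linux/if_ether.h>
#include <linux/pkt_cls.h>
#include <linux/skbuff.h>
#include <net/act_api.h>
#include <net/pkt_cls.h>
#include <net/sch_generic.h>

static int example_act(struct sk_buff *skb, const struct tc_action *a,
		       struct tcf_result *res)
{
	/* Drop runt frames; let everything else continue down the
	 * action list (TC_ACT_PIPE) so any later action still runs.
	 */
	if (skb->len < ETH_HLEN)
		return TC_ACT_SHOT;
	return TC_ACT_PIPE;
}

tcf_action_exec() in net/sched/act_api.c, added later in this patch, walks the configured actions in order and interprets these return codes: TC_ACT_PIPE falls through to the next action in the list, while verdicts such as TC_ACT_SHOT or TC_ACT_OK end the walk (the jump, repeat and goto-chain opcodes receive special handling there).
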
+ +config NET_ACT_SAMPLE + tristate "Traffic Sampling" + depends on NET_CLS_ACT + select PSAMPLE + help + Say Y here to allow packet sampling tc action. The packet sample + action consists of statistically choosing packets and sampling + them using the psample module. + + To compile this code as a module, choose M here: the + module will be called act_sample. + +config NET_ACT_IPT + tristate "IPtables targets" + depends on NET_CLS_ACT && NETFILTER && NETFILTER_XTABLES + help + Say Y here to be able to invoke iptables targets after successful + classification. + + To compile this code as a module, choose M here: the + module will be called act_ipt. + +config NET_ACT_NAT + tristate "Stateless NAT" + depends on NET_CLS_ACT + help + Say Y here to do stateless NAT on IPv4 packets. You should use + netfilter for NAT unless you know what you are doing. + + To compile this code as a module, choose M here: the + module will be called act_nat. + +config NET_ACT_PEDIT + tristate "Packet Editing" + depends on NET_CLS_ACT + help + Say Y here if you want to mangle the content of packets. + + To compile this code as a module, choose M here: the + module will be called act_pedit. + +config NET_ACT_SIMP + tristate "Simple Example (Debug)" + depends on NET_CLS_ACT + help + Say Y here to add a simple action for demonstration purposes. + It is meant as an example and for debugging purposes. It will + print a configured policy string followed by the packet count + to the console for every packet that passes by. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_simple. + +config NET_ACT_SKBEDIT + tristate "SKB Editing" + depends on NET_CLS_ACT + help + Say Y here to change skb priority or queue_mapping settings. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_skbedit. + +config NET_ACT_CSUM + tristate "Checksum Updating" + depends on NET_CLS_ACT && INET + select LIBCRC32C + help + Say Y here to update some common checksum after some direct + packet alterations. + + To compile this code as a module, choose M here: the + module will be called act_csum. + +config NET_ACT_MPLS + tristate "MPLS manipulation" + depends on NET_CLS_ACT + help + Say Y here to push or pop MPLS headers. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_mpls. + +config NET_ACT_VLAN + tristate "Vlan manipulation" + depends on NET_CLS_ACT + help + Say Y here to push or pop vlan headers. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_vlan. + +config NET_ACT_BPF + tristate "BPF based action" + depends on NET_CLS_ACT + help + Say Y here to execute BPF code on packets. The BPF code will decide + if the packet should be dropped or not. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_bpf. + +config NET_ACT_CONNMARK + tristate "Netfilter Connection Mark Retriever" + depends on NET_CLS_ACT && NETFILTER + depends on NF_CONNTRACK && NF_CONNTRACK_MARK + help + Say Y here to allow retrieving of conn mark + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_connmark. + +config NET_ACT_CTINFO + tristate "Netfilter Connection Mark Actions" + depends on NET_CLS_ACT && NETFILTER + depends on NF_CONNTRACK && NF_CONNTRACK_MARK + help + Say Y here to allow transfer of a connmark stored information. 
+ Current actions transfer connmark stored DSCP into + ipv4/v6 diffserv and/or to transfer connmark to packet + mark. Both are useful for restoring egress based marks + back onto ingress connections for qdisc priority mapping + purposes. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_ctinfo. + +config NET_ACT_SKBMOD + tristate "skb data modification action" + depends on NET_CLS_ACT + help + Say Y here to allow modification of skb data + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_skbmod. + +config NET_ACT_IFE + tristate "Inter-FE action based on IETF ForCES InterFE LFB" + depends on NET_CLS_ACT + select NET_IFE + help + Say Y here to allow for sourcing and terminating metadata + For details refer to netdev01 paper: + "Distributing Linux Traffic Control Classifier-Action Subsystem" + Authors: Jamal Hadi Salim and Damascene M. Joachimpillai + + To compile this code as a module, choose M here: the + module will be called act_ife. + +config NET_ACT_TUNNEL_KEY + tristate "IP tunnel metadata manipulation" + depends on NET_CLS_ACT + help + Say Y here to set/release ip tunnel metadata. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_tunnel_key. + +config NET_ACT_CT + tristate "connection tracking tc action" + depends on NET_CLS_ACT && NF_CONNTRACK && (!NF_NAT || NF_NAT) && NF_FLOW_TABLE + help + Say Y here to allow sending the packets to conntrack module. + + If unsure, say N. + + To compile this code as a module, choose M here: the + module will be called act_ct. + +config NET_ACT_GATE + tristate "Frame gate entry list control tc action" + depends on NET_CLS_ACT + help + Say Y here to allow to control the ingress flow to be passed at + specific time slot and be dropped at other specific time slot by + the gate entry list. + + If unsure, say N. + To compile this code as a module, choose M here: the + module will be called act_gate. + +config NET_IFE_SKBMARK + tristate "Support to encoding decoding skb mark on IFE action" + depends on NET_ACT_IFE + +config NET_IFE_SKBPRIO + tristate "Support to encoding decoding skb prio on IFE action" + depends on NET_ACT_IFE + +config NET_IFE_SKBTCINDEX + tristate "Support to encoding decoding skb tcindex on IFE action" + depends on NET_ACT_IFE + +config NET_TC_SKB_EXT + bool "TC recirculation support" + depends on NET_CLS_ACT + select SKB_EXTENSIONS + + help + Say Y here to allow tc chain misses to continue in OvS datapath in + the correct recirc_id, and hardware chain misses to continue in + the correct chain in tc software datapath. + + Say N here if you won't be using tc<->ovs offload or tc chains offload. + +endif # NET_SCHED + +config NET_SCH_FIFO + bool diff --git a/net/sched/Makefile b/net/sched/Makefile new file mode 100644 index 000000000..8a33a35fc --- /dev/null +++ b/net/sched/Makefile @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux Traffic Control Unit. 
+# + +obj-y := sch_generic.o sch_mq.o + +obj-$(CONFIG_INET) += sch_frag.o +obj-$(CONFIG_NET_SCHED) += sch_api.o sch_blackhole.o +obj-$(CONFIG_NET_CLS) += cls_api.o +obj-$(CONFIG_NET_CLS_ACT) += act_api.o +obj-$(CONFIG_NET_ACT_POLICE) += act_police.o +obj-$(CONFIG_NET_ACT_GACT) += act_gact.o +obj-$(CONFIG_NET_ACT_MIRRED) += act_mirred.o +obj-$(CONFIG_NET_ACT_SAMPLE) += act_sample.o +obj-$(CONFIG_NET_ACT_IPT) += act_ipt.o +obj-$(CONFIG_NET_ACT_NAT) += act_nat.o +obj-$(CONFIG_NET_ACT_PEDIT) += act_pedit.o +obj-$(CONFIG_NET_ACT_SIMP) += act_simple.o +obj-$(CONFIG_NET_ACT_SKBEDIT) += act_skbedit.o +obj-$(CONFIG_NET_ACT_CSUM) += act_csum.o +obj-$(CONFIG_NET_ACT_MPLS) += act_mpls.o +obj-$(CONFIG_NET_ACT_VLAN) += act_vlan.o +obj-$(CONFIG_NET_ACT_BPF) += act_bpf.o +obj-$(CONFIG_NET_ACT_CONNMARK) += act_connmark.o +obj-$(CONFIG_NET_ACT_CTINFO) += act_ctinfo.o +obj-$(CONFIG_NET_ACT_SKBMOD) += act_skbmod.o +obj-$(CONFIG_NET_ACT_IFE) += act_ife.o +obj-$(CONFIG_NET_IFE_SKBMARK) += act_meta_mark.o +obj-$(CONFIG_NET_IFE_SKBPRIO) += act_meta_skbprio.o +obj-$(CONFIG_NET_IFE_SKBTCINDEX) += act_meta_skbtcindex.o +obj-$(CONFIG_NET_ACT_TUNNEL_KEY)+= act_tunnel_key.o +obj-$(CONFIG_NET_ACT_CT) += act_ct.o +obj-$(CONFIG_NET_ACT_GATE) += act_gate.o +obj-$(CONFIG_NET_SCH_FIFO) += sch_fifo.o +obj-$(CONFIG_NET_SCH_CBQ) += sch_cbq.o +obj-$(CONFIG_NET_SCH_HTB) += sch_htb.o +obj-$(CONFIG_NET_SCH_HFSC) += sch_hfsc.o +obj-$(CONFIG_NET_SCH_RED) += sch_red.o +obj-$(CONFIG_NET_SCH_GRED) += sch_gred.o +obj-$(CONFIG_NET_SCH_INGRESS) += sch_ingress.o +obj-$(CONFIG_NET_SCH_DSMARK) += sch_dsmark.o +obj-$(CONFIG_NET_SCH_SFB) += sch_sfb.o +obj-$(CONFIG_NET_SCH_SFQ) += sch_sfq.o +obj-$(CONFIG_NET_SCH_TBF) += sch_tbf.o +obj-$(CONFIG_NET_SCH_TEQL) += sch_teql.o +obj-$(CONFIG_NET_SCH_PRIO) += sch_prio.o +obj-$(CONFIG_NET_SCH_MULTIQ) += sch_multiq.o +obj-$(CONFIG_NET_SCH_ATM) += sch_atm.o +obj-$(CONFIG_NET_SCH_NETEM) += sch_netem.o +obj-$(CONFIG_NET_SCH_DRR) += sch_drr.o +obj-$(CONFIG_NET_SCH_PLUG) += sch_plug.o +obj-$(CONFIG_NET_SCH_ETS) += sch_ets.o +obj-$(CONFIG_NET_SCH_MQPRIO) += sch_mqprio.o +obj-$(CONFIG_NET_SCH_SKBPRIO) += sch_skbprio.o +obj-$(CONFIG_NET_SCH_CHOKE) += sch_choke.o +obj-$(CONFIG_NET_SCH_QFQ) += sch_qfq.o +obj-$(CONFIG_NET_SCH_CODEL) += sch_codel.o +obj-$(CONFIG_NET_SCH_FQ_CODEL) += sch_fq_codel.o +obj-$(CONFIG_NET_SCH_CAKE) += sch_cake.o +obj-$(CONFIG_NET_SCH_FQ) += sch_fq.o +obj-$(CONFIG_NET_SCH_HHF) += sch_hhf.o +obj-$(CONFIG_NET_SCH_PIE) += sch_pie.o +obj-$(CONFIG_NET_SCH_FQ_PIE) += sch_fq_pie.o +obj-$(CONFIG_NET_SCH_CBS) += sch_cbs.o +obj-$(CONFIG_NET_SCH_ETF) += sch_etf.o +obj-$(CONFIG_NET_SCH_TAPRIO) += sch_taprio.o + +obj-$(CONFIG_NET_CLS_U32) += cls_u32.o +obj-$(CONFIG_NET_CLS_ROUTE4) += cls_route.o +obj-$(CONFIG_NET_CLS_FW) += cls_fw.o +obj-$(CONFIG_NET_CLS_BASIC) += cls_basic.o +obj-$(CONFIG_NET_CLS_FLOW) += cls_flow.o +obj-$(CONFIG_NET_CLS_CGROUP) += cls_cgroup.o +obj-$(CONFIG_NET_CLS_BPF) += cls_bpf.o +obj-$(CONFIG_NET_CLS_FLOWER) += cls_flower.o +obj-$(CONFIG_NET_CLS_MATCHALL) += cls_matchall.o +obj-$(CONFIG_NET_EMATCH) += ematch.o +obj-$(CONFIG_NET_EMATCH_CMP) += em_cmp.o +obj-$(CONFIG_NET_EMATCH_NBYTE) += em_nbyte.o +obj-$(CONFIG_NET_EMATCH_U32) += em_u32.o +obj-$(CONFIG_NET_EMATCH_META) += em_meta.o +obj-$(CONFIG_NET_EMATCH_TEXT) += em_text.o +obj-$(CONFIG_NET_EMATCH_CANID) += em_canid.o +obj-$(CONFIG_NET_EMATCH_IPSET) += em_ipset.o +obj-$(CONFIG_NET_EMATCH_IPT) += em_ipt.o diff --git a/net/sched/act_api.c b/net/sched/act_api.c new file mode 100644 index 000000000..b33f88e50 --- /dev/null +++ 
b/net/sched/act_api.c @@ -0,0 +1,2189 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_api.c Packet action API. + * + * Author: Jamal Hadi Salim + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_INET +DEFINE_STATIC_KEY_FALSE(tcf_frag_xmit_count); +EXPORT_SYMBOL_GPL(tcf_frag_xmit_count); +#endif + +int tcf_dev_queue_xmit(struct sk_buff *skb, int (*xmit)(struct sk_buff *skb)) +{ +#ifdef CONFIG_INET + if (static_branch_unlikely(&tcf_frag_xmit_count)) + return sch_frag_xmit_hook(skb, xmit); +#endif + + return xmit(skb); +} +EXPORT_SYMBOL_GPL(tcf_dev_queue_xmit); + +static void tcf_action_goto_chain_exec(const struct tc_action *a, + struct tcf_result *res) +{ + const struct tcf_chain *chain = rcu_dereference_bh(a->goto_chain); + + res->goto_tp = rcu_dereference_bh(chain->filter_chain); +} + +static void tcf_free_cookie_rcu(struct rcu_head *p) +{ + struct tc_cookie *cookie = container_of(p, struct tc_cookie, rcu); + + kfree(cookie->data); + kfree(cookie); +} + +static void tcf_set_action_cookie(struct tc_cookie __rcu **old_cookie, + struct tc_cookie *new_cookie) +{ + struct tc_cookie *old; + + old = xchg((__force struct tc_cookie **)old_cookie, new_cookie); + if (old) + call_rcu(&old->rcu, tcf_free_cookie_rcu); +} + +int tcf_action_check_ctrlact(int action, struct tcf_proto *tp, + struct tcf_chain **newchain, + struct netlink_ext_ack *extack) +{ + int opcode = TC_ACT_EXT_OPCODE(action), ret = -EINVAL; + u32 chain_index; + + if (!opcode) + ret = action > TC_ACT_VALUE_MAX ? -EINVAL : 0; + else if (opcode <= TC_ACT_EXT_OPCODE_MAX || action == TC_ACT_UNSPEC) + ret = 0; + if (ret) { + NL_SET_ERR_MSG(extack, "invalid control action"); + goto end; + } + + if (TC_ACT_EXT_CMP(action, TC_ACT_GOTO_CHAIN)) { + chain_index = action & TC_ACT_EXT_VAL_MASK; + if (!tp || !newchain) { + ret = -EINVAL; + NL_SET_ERR_MSG(extack, + "can't goto NULL proto/chain"); + goto end; + } + *newchain = tcf_chain_get_by_act(tp->chain->block, chain_index); + if (!*newchain) { + ret = -ENOMEM; + NL_SET_ERR_MSG(extack, + "can't allocate goto_chain"); + } + } +end: + return ret; +} +EXPORT_SYMBOL(tcf_action_check_ctrlact); + +struct tcf_chain *tcf_action_set_ctrlact(struct tc_action *a, int action, + struct tcf_chain *goto_chain) +{ + a->tcfa_action = action; + goto_chain = rcu_replace_pointer(a->goto_chain, goto_chain, 1); + return goto_chain; +} +EXPORT_SYMBOL(tcf_action_set_ctrlact); + +/* XXX: For standalone actions, we don't need a RCU grace period either, because + * actions are always connected to filters and filters are already destroyed in + * RCU callbacks, so after a RCU grace period actions are already disconnected + * from filters. Readers later can not find us. 
+ */ +static void free_tcf(struct tc_action *p) +{ + struct tcf_chain *chain = rcu_dereference_protected(p->goto_chain, 1); + + free_percpu(p->cpu_bstats); + free_percpu(p->cpu_bstats_hw); + free_percpu(p->cpu_qstats); + + tcf_set_action_cookie(&p->act_cookie, NULL); + if (chain) + tcf_chain_put_by_act(chain); + + kfree(p); +} + +static void offload_action_hw_count_set(struct tc_action *act, + u32 hw_count) +{ + act->in_hw_count = hw_count; +} + +static void offload_action_hw_count_inc(struct tc_action *act, + u32 hw_count) +{ + act->in_hw_count += hw_count; +} + +static void offload_action_hw_count_dec(struct tc_action *act, + u32 hw_count) +{ + act->in_hw_count = act->in_hw_count > hw_count ? + act->in_hw_count - hw_count : 0; +} + +static unsigned int tcf_offload_act_num_actions_single(struct tc_action *act) +{ + if (is_tcf_pedit(act)) + return tcf_pedit_nkeys(act); + else + return 1; +} + +static bool tc_act_skip_hw(u32 flags) +{ + return (flags & TCA_ACT_FLAGS_SKIP_HW) ? true : false; +} + +static bool tc_act_skip_sw(u32 flags) +{ + return (flags & TCA_ACT_FLAGS_SKIP_SW) ? true : false; +} + +static bool tc_act_in_hw(struct tc_action *act) +{ + return !!act->in_hw_count; +} + +/* SKIP_HW and SKIP_SW are mutually exclusive flags. */ +static bool tc_act_flags_valid(u32 flags) +{ + flags &= TCA_ACT_FLAGS_SKIP_HW | TCA_ACT_FLAGS_SKIP_SW; + + return flags ^ (TCA_ACT_FLAGS_SKIP_HW | TCA_ACT_FLAGS_SKIP_SW); +} + +static int offload_action_init(struct flow_offload_action *fl_action, + struct tc_action *act, + enum offload_act_command cmd, + struct netlink_ext_ack *extack) +{ + int err; + + fl_action->extack = extack; + fl_action->command = cmd; + fl_action->index = act->tcfa_index; + + if (act->ops->offload_act_setup) { + spin_lock_bh(&act->tcfa_lock); + err = act->ops->offload_act_setup(act, fl_action, NULL, + false, extack); + spin_unlock_bh(&act->tcfa_lock); + return err; + } + + return -EOPNOTSUPP; +} + +static int tcf_action_offload_cmd_ex(struct flow_offload_action *fl_act, + u32 *hw_count) +{ + int err; + + err = flow_indr_dev_setup_offload(NULL, NULL, TC_SETUP_ACT, + fl_act, NULL, NULL); + if (err < 0) + return err; + + if (hw_count) + *hw_count = err; + + return 0; +} + +static int tcf_action_offload_cmd_cb_ex(struct flow_offload_action *fl_act, + u32 *hw_count, + flow_indr_block_bind_cb_t *cb, + void *cb_priv) +{ + int err; + + err = cb(NULL, NULL, cb_priv, TC_SETUP_ACT, NULL, fl_act, NULL); + if (err < 0) + return err; + + if (hw_count) + *hw_count = 1; + + return 0; +} + +static int tcf_action_offload_cmd(struct flow_offload_action *fl_act, + u32 *hw_count, + flow_indr_block_bind_cb_t *cb, + void *cb_priv) +{ + return cb ? 
tcf_action_offload_cmd_cb_ex(fl_act, hw_count, + cb, cb_priv) : + tcf_action_offload_cmd_ex(fl_act, hw_count); +} + +static int tcf_action_offload_add_ex(struct tc_action *action, + struct netlink_ext_ack *extack, + flow_indr_block_bind_cb_t *cb, + void *cb_priv) +{ + bool skip_sw = tc_act_skip_sw(action->tcfa_flags); + struct tc_action *actions[TCA_ACT_MAX_PRIO] = { + [0] = action, + }; + struct flow_offload_action *fl_action; + u32 in_hw_count = 0; + int num, err = 0; + + if (tc_act_skip_hw(action->tcfa_flags)) + return 0; + + num = tcf_offload_act_num_actions_single(action); + fl_action = offload_action_alloc(num); + if (!fl_action) + return -ENOMEM; + + err = offload_action_init(fl_action, action, FLOW_ACT_REPLACE, extack); + if (err) + goto fl_err; + + err = tc_setup_action(&fl_action->action, actions, extack); + if (err) { + NL_SET_ERR_MSG_MOD(extack, + "Failed to setup tc actions for offload"); + goto fl_err; + } + + err = tcf_action_offload_cmd(fl_action, &in_hw_count, cb, cb_priv); + if (!err) + cb ? offload_action_hw_count_inc(action, in_hw_count) : + offload_action_hw_count_set(action, in_hw_count); + + if (skip_sw && !tc_act_in_hw(action)) + err = -EINVAL; + + tc_cleanup_offload_action(&fl_action->action); + +fl_err: + kfree(fl_action); + + return err; +} + +/* offload the tc action after it is inserted */ +static int tcf_action_offload_add(struct tc_action *action, + struct netlink_ext_ack *extack) +{ + return tcf_action_offload_add_ex(action, extack, NULL, NULL); +} + +int tcf_action_update_hw_stats(struct tc_action *action) +{ + struct flow_offload_action fl_act = {}; + int err; + + if (!tc_act_in_hw(action)) + return -EOPNOTSUPP; + + err = offload_action_init(&fl_act, action, FLOW_ACT_STATS, NULL); + if (err) + return err; + + err = tcf_action_offload_cmd(&fl_act, NULL, NULL, NULL); + if (!err) { + preempt_disable(); + tcf_action_stats_update(action, fl_act.stats.bytes, + fl_act.stats.pkts, + fl_act.stats.drops, + fl_act.stats.lastused, + true); + preempt_enable(); + action->used_hw_stats = fl_act.stats.used_hw_stats; + action->used_hw_stats_valid = true; + } else { + return -EOPNOTSUPP; + } + + return 0; +} +EXPORT_SYMBOL(tcf_action_update_hw_stats); + +static int tcf_action_offload_del_ex(struct tc_action *action, + flow_indr_block_bind_cb_t *cb, + void *cb_priv) +{ + struct flow_offload_action fl_act = {}; + u32 in_hw_count = 0; + int err = 0; + + if (!tc_act_in_hw(action)) + return 0; + + err = offload_action_init(&fl_act, action, FLOW_ACT_DESTROY, NULL); + if (err) + return err; + + err = tcf_action_offload_cmd(&fl_act, &in_hw_count, cb, cb_priv); + if (err < 0) + return err; + + if (!cb && action->in_hw_count != in_hw_count) + return -EINVAL; + + /* do not need to update hw state when deleting action */ + if (cb && in_hw_count) + offload_action_hw_count_dec(action, in_hw_count); + + return 0; +} + +static int tcf_action_offload_del(struct tc_action *action) +{ + return tcf_action_offload_del_ex(action, NULL, NULL); +} + +static void tcf_action_cleanup(struct tc_action *p) +{ + tcf_action_offload_del(p); + if (p->ops->cleanup) + p->ops->cleanup(p); + + gen_kill_estimator(&p->tcfa_rate_est); + free_tcf(p); +} + +static int __tcf_action_put(struct tc_action *p, bool bind) +{ + struct tcf_idrinfo *idrinfo = p->idrinfo; + + if (refcount_dec_and_mutex_lock(&p->tcfa_refcnt, &idrinfo->lock)) { + if (bind) + atomic_dec(&p->tcfa_bindcnt); + idr_remove(&idrinfo->action_idr, p->tcfa_index); + mutex_unlock(&idrinfo->lock); + + tcf_action_cleanup(p); + return 1; + } + + if (bind) + 
atomic_dec(&p->tcfa_bindcnt); + + return 0; +} + +static int __tcf_idr_release(struct tc_action *p, bool bind, bool strict) +{ + int ret = 0; + + /* Release with strict==1 and bind==0 is only called through act API + * interface (classifiers always bind). Only case when action with + * positive reference count and zero bind count can exist is when it was + * also created with act API (unbinding last classifier will destroy the + * action if it was created by classifier). So only case when bind count + * can be changed after initial check is when unbound action is + * destroyed by act API while classifier binds to action with same id + * concurrently. This result either creation of new action(same behavior + * as before), or reusing existing action if concurrent process + * increments reference count before action is deleted. Both scenarios + * are acceptable. + */ + if (p) { + if (!bind && strict && atomic_read(&p->tcfa_bindcnt) > 0) + return -EPERM; + + if (__tcf_action_put(p, bind)) + ret = ACT_P_DELETED; + } + + return ret; +} + +int tcf_idr_release(struct tc_action *a, bool bind) +{ + const struct tc_action_ops *ops = a->ops; + int ret; + + ret = __tcf_idr_release(a, bind, false); + if (ret == ACT_P_DELETED) + module_put(ops->owner); + return ret; +} +EXPORT_SYMBOL(tcf_idr_release); + +static size_t tcf_action_shared_attrs_size(const struct tc_action *act) +{ + struct tc_cookie *act_cookie; + u32 cookie_len = 0; + + rcu_read_lock(); + act_cookie = rcu_dereference(act->act_cookie); + + if (act_cookie) + cookie_len = nla_total_size(act_cookie->len); + rcu_read_unlock(); + + return nla_total_size(0) /* action number nested */ + + nla_total_size(IFNAMSIZ) /* TCA_ACT_KIND */ + + cookie_len /* TCA_ACT_COOKIE */ + + nla_total_size(sizeof(struct nla_bitfield32)) /* TCA_ACT_HW_STATS */ + + nla_total_size(0) /* TCA_ACT_STATS nested */ + + nla_total_size(sizeof(struct nla_bitfield32)) /* TCA_ACT_FLAGS */ + /* TCA_STATS_BASIC */ + + nla_total_size_64bit(sizeof(struct gnet_stats_basic)) + /* TCA_STATS_PKT64 */ + + nla_total_size_64bit(sizeof(u64)) + /* TCA_STATS_QUEUE */ + + nla_total_size_64bit(sizeof(struct gnet_stats_queue)) + + nla_total_size(0) /* TCA_OPTIONS nested */ + + nla_total_size(sizeof(struct tcf_t)); /* TCA_GACT_TM */ +} + +static size_t tcf_action_full_attrs_size(size_t sz) +{ + return NLMSG_HDRLEN /* struct nlmsghdr */ + + sizeof(struct tcamsg) + + nla_total_size(0) /* TCA_ACT_TAB nested */ + + sz; +} + +static size_t tcf_action_fill_size(const struct tc_action *act) +{ + size_t sz = tcf_action_shared_attrs_size(act); + + if (act->ops->get_fill_size) + return act->ops->get_fill_size(act) + sz; + return sz; +} + +static int +tcf_action_dump_terse(struct sk_buff *skb, struct tc_action *a, bool from_act) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tc_cookie *cookie; + + if (nla_put_string(skb, TCA_KIND, a->ops->kind)) + goto nla_put_failure; + if (tcf_action_copy_stats(skb, a, 0)) + goto nla_put_failure; + if (from_act && nla_put_u32(skb, TCA_ACT_INDEX, a->tcfa_index)) + goto nla_put_failure; + + rcu_read_lock(); + cookie = rcu_dereference(a->act_cookie); + if (cookie) { + if (nla_put(skb, TCA_ACT_COOKIE, cookie->len, cookie->data)) { + rcu_read_unlock(); + goto nla_put_failure; + } + } + rcu_read_unlock(); + + return 0; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int tcf_dump_walker(struct tcf_idrinfo *idrinfo, struct sk_buff *skb, + struct netlink_callback *cb) +{ + int err = 0, index = -1, s_i = 0, n_i = 0; + u32 act_flags = cb->args[2]; + 
unsigned long jiffy_since = cb->args[3]; + struct nlattr *nest; + struct idr *idr = &idrinfo->action_idr; + struct tc_action *p; + unsigned long id = 1; + unsigned long tmp; + + mutex_lock(&idrinfo->lock); + + s_i = cb->args[0]; + + idr_for_each_entry_ul(idr, p, tmp, id) { + index++; + if (index < s_i) + continue; + if (IS_ERR(p)) + continue; + + if (jiffy_since && + time_after(jiffy_since, + (unsigned long)p->tcfa_tm.lastuse)) + continue; + + nest = nla_nest_start_noflag(skb, n_i); + if (!nest) { + index--; + goto nla_put_failure; + } + err = (act_flags & TCA_ACT_FLAG_TERSE_DUMP) ? + tcf_action_dump_terse(skb, p, true) : + tcf_action_dump_1(skb, p, 0, 0); + if (err < 0) { + index--; + nlmsg_trim(skb, nest); + goto done; + } + nla_nest_end(skb, nest); + n_i++; + if (!(act_flags & TCA_ACT_FLAG_LARGE_DUMP_ON) && + n_i >= TCA_ACT_MAX_PRIO) + goto done; + } +done: + if (index >= 0) + cb->args[0] = index + 1; + + mutex_unlock(&idrinfo->lock); + if (n_i) { + if (act_flags & TCA_ACT_FLAG_LARGE_DUMP_ON) + cb->args[1] = n_i; + } + return n_i; + +nla_put_failure: + nla_nest_cancel(skb, nest); + goto done; +} + +static int tcf_idr_release_unsafe(struct tc_action *p) +{ + if (atomic_read(&p->tcfa_bindcnt) > 0) + return -EPERM; + + if (refcount_dec_and_test(&p->tcfa_refcnt)) { + idr_remove(&p->idrinfo->action_idr, p->tcfa_index); + tcf_action_cleanup(p); + return ACT_P_DELETED; + } + + return 0; +} + +static int tcf_del_walker(struct tcf_idrinfo *idrinfo, struct sk_buff *skb, + const struct tc_action_ops *ops, + struct netlink_ext_ack *extack) +{ + struct nlattr *nest; + int n_i = 0; + int ret = -EINVAL; + struct idr *idr = &idrinfo->action_idr; + struct tc_action *p; + unsigned long id = 1; + unsigned long tmp; + + nest = nla_nest_start_noflag(skb, 0); + if (nest == NULL) + goto nla_put_failure; + if (nla_put_string(skb, TCA_KIND, ops->kind)) + goto nla_put_failure; + + ret = 0; + mutex_lock(&idrinfo->lock); + idr_for_each_entry_ul(idr, p, tmp, id) { + if (IS_ERR(p)) + continue; + ret = tcf_idr_release_unsafe(p); + if (ret == ACT_P_DELETED) + module_put(ops->owner); + else if (ret < 0) + break; + n_i++; + } + mutex_unlock(&idrinfo->lock); + if (ret < 0) { + if (n_i) + NL_SET_ERR_MSG(extack, "Unable to flush all TC actions"); + else + goto nla_put_failure; + } + + ret = nla_put_u32(skb, TCA_FCNT, n_i); + if (ret) + goto nla_put_failure; + nla_nest_end(skb, nest); + + return n_i; +nla_put_failure: + nla_nest_cancel(skb, nest); + return ret; +} + +int tcf_generic_walker(struct tc_action_net *tn, struct sk_buff *skb, + struct netlink_callback *cb, int type, + const struct tc_action_ops *ops, + struct netlink_ext_ack *extack) +{ + struct tcf_idrinfo *idrinfo = tn->idrinfo; + + if (type == RTM_DELACTION) { + return tcf_del_walker(idrinfo, skb, ops, extack); + } else if (type == RTM_GETACTION) { + return tcf_dump_walker(idrinfo, skb, cb); + } else { + WARN(1, "tcf_generic_walker: unknown command %d\n", type); + NL_SET_ERR_MSG(extack, "tcf_generic_walker: unknown command"); + return -EINVAL; + } +} +EXPORT_SYMBOL(tcf_generic_walker); + +int tcf_idr_search(struct tc_action_net *tn, struct tc_action **a, u32 index) +{ + struct tcf_idrinfo *idrinfo = tn->idrinfo; + struct tc_action *p; + + mutex_lock(&idrinfo->lock); + p = idr_find(&idrinfo->action_idr, index); + if (IS_ERR(p)) + p = NULL; + else if (p) + refcount_inc(&p->tcfa_refcnt); + mutex_unlock(&idrinfo->lock); + + if (p) { + *a = p; + return true; + } + return false; +} +EXPORT_SYMBOL(tcf_idr_search); + +static int __tcf_generic_walker(struct net *net, 
struct sk_buff *skb, + struct netlink_callback *cb, int type, + const struct tc_action_ops *ops, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, ops->net_id); + + if (unlikely(ops->walk)) + return ops->walk(net, skb, cb, type, ops, extack); + + return tcf_generic_walker(tn, skb, cb, type, ops, extack); +} + +static int __tcf_idr_search(struct net *net, + const struct tc_action_ops *ops, + struct tc_action **a, u32 index) +{ + struct tc_action_net *tn = net_generic(net, ops->net_id); + + if (unlikely(ops->lookup)) + return ops->lookup(net, a, index); + + return tcf_idr_search(tn, a, index); +} + +static int tcf_idr_delete_index(struct tcf_idrinfo *idrinfo, u32 index) +{ + struct tc_action *p; + int ret = 0; + + mutex_lock(&idrinfo->lock); + p = idr_find(&idrinfo->action_idr, index); + if (!p) { + mutex_unlock(&idrinfo->lock); + return -ENOENT; + } + + if (!atomic_read(&p->tcfa_bindcnt)) { + if (refcount_dec_and_test(&p->tcfa_refcnt)) { + struct module *owner = p->ops->owner; + + WARN_ON(p != idr_remove(&idrinfo->action_idr, + p->tcfa_index)); + mutex_unlock(&idrinfo->lock); + + tcf_action_cleanup(p); + module_put(owner); + return 0; + } + ret = 0; + } else { + ret = -EPERM; + } + + mutex_unlock(&idrinfo->lock); + return ret; +} + +int tcf_idr_create(struct tc_action_net *tn, u32 index, struct nlattr *est, + struct tc_action **a, const struct tc_action_ops *ops, + int bind, bool cpustats, u32 flags) +{ + struct tc_action *p = kzalloc(ops->size, GFP_KERNEL); + struct tcf_idrinfo *idrinfo = tn->idrinfo; + int err = -ENOMEM; + + if (unlikely(!p)) + return -ENOMEM; + refcount_set(&p->tcfa_refcnt, 1); + if (bind) + atomic_set(&p->tcfa_bindcnt, 1); + + if (cpustats) { + p->cpu_bstats = netdev_alloc_pcpu_stats(struct gnet_stats_basic_sync); + if (!p->cpu_bstats) + goto err1; + p->cpu_bstats_hw = netdev_alloc_pcpu_stats(struct gnet_stats_basic_sync); + if (!p->cpu_bstats_hw) + goto err2; + p->cpu_qstats = alloc_percpu(struct gnet_stats_queue); + if (!p->cpu_qstats) + goto err3; + } + gnet_stats_basic_sync_init(&p->tcfa_bstats); + gnet_stats_basic_sync_init(&p->tcfa_bstats_hw); + spin_lock_init(&p->tcfa_lock); + p->tcfa_index = index; + p->tcfa_tm.install = jiffies; + p->tcfa_tm.lastuse = jiffies; + p->tcfa_tm.firstuse = 0; + p->tcfa_flags = flags; + if (est) { + err = gen_new_estimator(&p->tcfa_bstats, p->cpu_bstats, + &p->tcfa_rate_est, + &p->tcfa_lock, false, est); + if (err) + goto err4; + } + + p->idrinfo = idrinfo; + __module_get(ops->owner); + p->ops = ops; + *a = p; + return 0; +err4: + free_percpu(p->cpu_qstats); +err3: + free_percpu(p->cpu_bstats_hw); +err2: + free_percpu(p->cpu_bstats); +err1: + kfree(p); + return err; +} +EXPORT_SYMBOL(tcf_idr_create); + +int tcf_idr_create_from_flags(struct tc_action_net *tn, u32 index, + struct nlattr *est, struct tc_action **a, + const struct tc_action_ops *ops, int bind, + u32 flags) +{ + /* Set cpustats according to actions flags. */ + return tcf_idr_create(tn, index, est, a, ops, bind, + !(flags & TCA_ACT_FLAGS_NO_PERCPU_STATS), flags); +} +EXPORT_SYMBOL(tcf_idr_create_from_flags); + +/* Cleanup idr index that was allocated but not initialized. 
*/ + +void tcf_idr_cleanup(struct tc_action_net *tn, u32 index) +{ + struct tcf_idrinfo *idrinfo = tn->idrinfo; + + mutex_lock(&idrinfo->lock); + /* Remove ERR_PTR(-EBUSY) allocated by tcf_idr_check_alloc */ + WARN_ON(!IS_ERR(idr_remove(&idrinfo->action_idr, index))); + mutex_unlock(&idrinfo->lock); +} +EXPORT_SYMBOL(tcf_idr_cleanup); + +/* Check if action with specified index exists. If actions is found, increments + * its reference and bind counters, and return 1. Otherwise insert temporary + * error pointer (to prevent concurrent users from inserting actions with same + * index) and return 0. + */ + +int tcf_idr_check_alloc(struct tc_action_net *tn, u32 *index, + struct tc_action **a, int bind) +{ + struct tcf_idrinfo *idrinfo = tn->idrinfo; + struct tc_action *p; + int ret; + +again: + mutex_lock(&idrinfo->lock); + if (*index) { + p = idr_find(&idrinfo->action_idr, *index); + if (IS_ERR(p)) { + /* This means that another process allocated + * index but did not assign the pointer yet. + */ + mutex_unlock(&idrinfo->lock); + goto again; + } + + if (p) { + refcount_inc(&p->tcfa_refcnt); + if (bind) + atomic_inc(&p->tcfa_bindcnt); + *a = p; + ret = 1; + } else { + *a = NULL; + ret = idr_alloc_u32(&idrinfo->action_idr, NULL, index, + *index, GFP_KERNEL); + if (!ret) + idr_replace(&idrinfo->action_idr, + ERR_PTR(-EBUSY), *index); + } + } else { + *index = 1; + *a = NULL; + ret = idr_alloc_u32(&idrinfo->action_idr, NULL, index, + UINT_MAX, GFP_KERNEL); + if (!ret) + idr_replace(&idrinfo->action_idr, ERR_PTR(-EBUSY), + *index); + } + mutex_unlock(&idrinfo->lock); + return ret; +} +EXPORT_SYMBOL(tcf_idr_check_alloc); + +void tcf_idrinfo_destroy(const struct tc_action_ops *ops, + struct tcf_idrinfo *idrinfo) +{ + struct idr *idr = &idrinfo->action_idr; + struct tc_action *p; + int ret; + unsigned long id = 1; + unsigned long tmp; + + idr_for_each_entry_ul(idr, p, tmp, id) { + ret = __tcf_idr_release(p, false, true); + if (ret == ACT_P_DELETED) + module_put(ops->owner); + else if (ret < 0) + return; + } + idr_destroy(&idrinfo->action_idr); +} +EXPORT_SYMBOL(tcf_idrinfo_destroy); + +static LIST_HEAD(act_base); +static DEFINE_RWLOCK(act_mod_lock); +/* since act ops id is stored in pernet subsystem list, + * then there is no way to walk through only all the action + * subsystem, so we keep tc action pernet ops id for + * reoffload to walk through. 
+ */ +static LIST_HEAD(act_pernet_id_list); +static DEFINE_MUTEX(act_id_mutex); +struct tc_act_pernet_id { + struct list_head list; + unsigned int id; +}; + +static int tcf_pernet_add_id_list(unsigned int id) +{ + struct tc_act_pernet_id *id_ptr; + int ret = 0; + + mutex_lock(&act_id_mutex); + list_for_each_entry(id_ptr, &act_pernet_id_list, list) { + if (id_ptr->id == id) { + ret = -EEXIST; + goto err_out; + } + } + + id_ptr = kzalloc(sizeof(*id_ptr), GFP_KERNEL); + if (!id_ptr) { + ret = -ENOMEM; + goto err_out; + } + id_ptr->id = id; + + list_add_tail(&id_ptr->list, &act_pernet_id_list); + +err_out: + mutex_unlock(&act_id_mutex); + return ret; +} + +static void tcf_pernet_del_id_list(unsigned int id) +{ + struct tc_act_pernet_id *id_ptr; + + mutex_lock(&act_id_mutex); + list_for_each_entry(id_ptr, &act_pernet_id_list, list) { + if (id_ptr->id == id) { + list_del(&id_ptr->list); + kfree(id_ptr); + break; + } + } + mutex_unlock(&act_id_mutex); +} + +int tcf_register_action(struct tc_action_ops *act, + struct pernet_operations *ops) +{ + struct tc_action_ops *a; + int ret; + + if (!act->act || !act->dump || !act->init) + return -EINVAL; + + /* We have to register pernet ops before making the action ops visible, + * otherwise tcf_action_init_1() could get a partially initialized + * netns. + */ + ret = register_pernet_subsys(ops); + if (ret) + return ret; + + if (ops->id) { + ret = tcf_pernet_add_id_list(*ops->id); + if (ret) + goto err_id; + } + + write_lock(&act_mod_lock); + list_for_each_entry(a, &act_base, head) { + if (act->id == a->id || (strcmp(act->kind, a->kind) == 0)) { + ret = -EEXIST; + goto err_out; + } + } + list_add_tail(&act->head, &act_base); + write_unlock(&act_mod_lock); + + return 0; + +err_out: + write_unlock(&act_mod_lock); + if (ops->id) + tcf_pernet_del_id_list(*ops->id); +err_id: + unregister_pernet_subsys(ops); + return ret; +} +EXPORT_SYMBOL(tcf_register_action); + +int tcf_unregister_action(struct tc_action_ops *act, + struct pernet_operations *ops) +{ + struct tc_action_ops *a; + int err = -ENOENT; + + write_lock(&act_mod_lock); + list_for_each_entry(a, &act_base, head) { + if (a == act) { + list_del(&act->head); + err = 0; + break; + } + } + write_unlock(&act_mod_lock); + if (!err) { + unregister_pernet_subsys(ops); + if (ops->id) + tcf_pernet_del_id_list(*ops->id); + } + return err; +} +EXPORT_SYMBOL(tcf_unregister_action); + +/* lookup by name */ +static struct tc_action_ops *tc_lookup_action_n(char *kind) +{ + struct tc_action_ops *a, *res = NULL; + + if (kind) { + read_lock(&act_mod_lock); + list_for_each_entry(a, &act_base, head) { + if (strcmp(kind, a->kind) == 0) { + if (try_module_get(a->owner)) + res = a; + break; + } + } + read_unlock(&act_mod_lock); + } + return res; +} + +/* lookup by nlattr */ +static struct tc_action_ops *tc_lookup_action(struct nlattr *kind) +{ + struct tc_action_ops *a, *res = NULL; + + if (kind) { + read_lock(&act_mod_lock); + list_for_each_entry(a, &act_base, head) { + if (nla_strcmp(kind, a->kind) == 0) { + if (try_module_get(a->owner)) + res = a; + break; + } + } + read_unlock(&act_mod_lock); + } + return res; +} + +/*TCA_ACT_MAX_PRIO is 32, there count up to 32 */ +#define TCA_ACT_MAX_PRIO_MASK 0x1FF +int tcf_action_exec(struct sk_buff *skb, struct tc_action **actions, + int nr_actions, struct tcf_result *res) +{ + u32 jmp_prgcnt = 0; + u32 jmp_ttl = TCA_ACT_MAX_PRIO; /*matches actions per filter */ + int i; + int ret = TC_ACT_OK; + + if (skb_skip_tc_classify(skb)) + return TC_ACT_OK; + +restart_act_graph: + for (i = 0; i < 
nr_actions; i++) { + const struct tc_action *a = actions[i]; + int repeat_ttl; + + if (jmp_prgcnt > 0) { + jmp_prgcnt -= 1; + continue; + } + + if (tc_act_skip_sw(a->tcfa_flags)) + continue; + + repeat_ttl = 32; +repeat: + ret = a->ops->act(skb, a, res); + if (unlikely(ret == TC_ACT_REPEAT)) { + if (--repeat_ttl != 0) + goto repeat; + /* suspicious opcode, stop pipeline */ + net_warn_ratelimited("TC_ACT_REPEAT abuse ?\n"); + return TC_ACT_OK; + } + if (TC_ACT_EXT_CMP(ret, TC_ACT_JUMP)) { + jmp_prgcnt = ret & TCA_ACT_MAX_PRIO_MASK; + if (!jmp_prgcnt || (jmp_prgcnt > nr_actions)) { + /* faulty opcode, stop pipeline */ + return TC_ACT_OK; + } else { + jmp_ttl -= 1; + if (jmp_ttl > 0) + goto restart_act_graph; + else /* faulty graph, stop pipeline */ + return TC_ACT_OK; + } + } else if (TC_ACT_EXT_CMP(ret, TC_ACT_GOTO_CHAIN)) { + if (unlikely(!rcu_access_pointer(a->goto_chain))) { + net_warn_ratelimited("can't go to NULL chain!\n"); + return TC_ACT_SHOT; + } + tcf_action_goto_chain_exec(a, res); + } + + if (ret != TC_ACT_PIPE) + break; + } + + return ret; +} +EXPORT_SYMBOL(tcf_action_exec); + +int tcf_action_destroy(struct tc_action *actions[], int bind) +{ + const struct tc_action_ops *ops; + struct tc_action *a; + int ret = 0, i; + + for (i = 0; i < TCA_ACT_MAX_PRIO && actions[i]; i++) { + a = actions[i]; + actions[i] = NULL; + ops = a->ops; + ret = __tcf_idr_release(a, bind, true); + if (ret == ACT_P_DELETED) + module_put(ops->owner); + else if (ret < 0) + return ret; + } + return ret; +} + +static int tcf_action_put(struct tc_action *p) +{ + return __tcf_action_put(p, false); +} + +/* Put all actions in this array, skip those NULL's. */ +static void tcf_action_put_many(struct tc_action *actions[]) +{ + int i; + + for (i = 0; i < TCA_ACT_MAX_PRIO; i++) { + struct tc_action *a = actions[i]; + const struct tc_action_ops *ops; + + if (!a) + continue; + ops = a->ops; + if (tcf_action_put(a)) + module_put(ops->owner); + } +} + +int +tcf_action_dump_old(struct sk_buff *skb, struct tc_action *a, int bind, int ref) +{ + return a->ops->dump(skb, a, bind, ref); +} + +int +tcf_action_dump_1(struct sk_buff *skb, struct tc_action *a, int bind, int ref) +{ + int err = -EINVAL; + unsigned char *b = skb_tail_pointer(skb); + struct nlattr *nest; + u32 flags; + + if (tcf_action_dump_terse(skb, a, false)) + goto nla_put_failure; + + if (a->hw_stats != TCA_ACT_HW_STATS_ANY && + nla_put_bitfield32(skb, TCA_ACT_HW_STATS, + a->hw_stats, TCA_ACT_HW_STATS_ANY)) + goto nla_put_failure; + + if (a->used_hw_stats_valid && + nla_put_bitfield32(skb, TCA_ACT_USED_HW_STATS, + a->used_hw_stats, TCA_ACT_HW_STATS_ANY)) + goto nla_put_failure; + + flags = a->tcfa_flags & TCA_ACT_FLAGS_USER_MASK; + if (flags && + nla_put_bitfield32(skb, TCA_ACT_FLAGS, + flags, flags)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_ACT_IN_HW_COUNT, a->in_hw_count)) + goto nla_put_failure; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + err = tcf_action_dump_old(skb, a, bind, ref); + if (err > 0) { + nla_nest_end(skb, nest); + return err; + } + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} +EXPORT_SYMBOL(tcf_action_dump_1); + +int tcf_action_dump(struct sk_buff *skb, struct tc_action *actions[], + int bind, int ref, bool terse) +{ + struct tc_action *a; + int err = -EINVAL, i; + struct nlattr *nest; + + for (i = 0; i < TCA_ACT_MAX_PRIO && actions[i]; i++) { + a = actions[i]; + nest = nla_nest_start_noflag(skb, i + 1); + if (nest == NULL) + goto nla_put_failure; + err = terse ? 
tcf_action_dump_terse(skb, a, false) : + tcf_action_dump_1(skb, a, bind, ref); + if (err < 0) + goto errout; + nla_nest_end(skb, nest); + } + + return 0; + +nla_put_failure: + err = -EINVAL; +errout: + nla_nest_cancel(skb, nest); + return err; +} + +static struct tc_cookie *nla_memdup_cookie(struct nlattr **tb) +{ + struct tc_cookie *c = kzalloc(sizeof(*c), GFP_KERNEL); + if (!c) + return NULL; + + c->data = nla_memdup(tb[TCA_ACT_COOKIE], GFP_KERNEL); + if (!c->data) { + kfree(c); + return NULL; + } + c->len = nla_len(tb[TCA_ACT_COOKIE]); + + return c; +} + +static u8 tcf_action_hw_stats_get(struct nlattr *hw_stats_attr) +{ + struct nla_bitfield32 hw_stats_bf; + + /* If the user did not pass the attr, that means he does + * not care about the type. Return "any" in that case + * which is setting on all supported types. + */ + if (!hw_stats_attr) + return TCA_ACT_HW_STATS_ANY; + hw_stats_bf = nla_get_bitfield32(hw_stats_attr); + return hw_stats_bf.value; +} + +static const struct nla_policy tcf_action_policy[TCA_ACT_MAX + 1] = { + [TCA_ACT_KIND] = { .type = NLA_STRING }, + [TCA_ACT_INDEX] = { .type = NLA_U32 }, + [TCA_ACT_COOKIE] = { .type = NLA_BINARY, + .len = TC_COOKIE_MAX_SIZE }, + [TCA_ACT_OPTIONS] = { .type = NLA_NESTED }, + [TCA_ACT_FLAGS] = NLA_POLICY_BITFIELD32(TCA_ACT_FLAGS_NO_PERCPU_STATS | + TCA_ACT_FLAGS_SKIP_HW | + TCA_ACT_FLAGS_SKIP_SW), + [TCA_ACT_HW_STATS] = NLA_POLICY_BITFIELD32(TCA_ACT_HW_STATS_ANY), +}; + +void tcf_idr_insert_many(struct tc_action *actions[]) +{ + int i; + + for (i = 0; i < TCA_ACT_MAX_PRIO; i++) { + struct tc_action *a = actions[i]; + struct tcf_idrinfo *idrinfo; + + if (!a) + continue; + idrinfo = a->idrinfo; + mutex_lock(&idrinfo->lock); + /* Replace ERR_PTR(-EBUSY) allocated by tcf_idr_check_alloc if + * it is just created, otherwise this is just a nop. + */ + idr_replace(&idrinfo->action_idr, a, a->tcfa_index); + mutex_unlock(&idrinfo->lock); + } +} + +struct tc_action_ops *tc_action_load_ops(struct nlattr *nla, bool police, + bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_ACT_MAX + 1]; + struct tc_action_ops *a_o; + char act_name[IFNAMSIZ]; + struct nlattr *kind; + int err; + + if (!police) { + err = nla_parse_nested_deprecated(tb, TCA_ACT_MAX, nla, + tcf_action_policy, extack); + if (err < 0) + return ERR_PTR(err); + err = -EINVAL; + kind = tb[TCA_ACT_KIND]; + if (!kind) { + NL_SET_ERR_MSG(extack, "TC action kind must be specified"); + return ERR_PTR(err); + } + if (nla_strscpy(act_name, kind, IFNAMSIZ) < 0) { + NL_SET_ERR_MSG(extack, "TC action name too long"); + return ERR_PTR(err); + } + } else { + if (strlcpy(act_name, "police", IFNAMSIZ) >= IFNAMSIZ) { + NL_SET_ERR_MSG(extack, "TC action name too long"); + return ERR_PTR(-EINVAL); + } + } + + a_o = tc_lookup_action_n(act_name); + if (a_o == NULL) { +#ifdef CONFIG_MODULES + if (rtnl_held) + rtnl_unlock(); + request_module("act_%s", act_name); + if (rtnl_held) + rtnl_lock(); + + a_o = tc_lookup_action_n(act_name); + + /* We dropped the RTNL semaphore in order to + * perform the module load. So, even if we + * succeeded in loading the module we have to + * tell the caller to replay the request. We + * indicate this using -EAGAIN. 
+ */ + if (a_o != NULL) { + module_put(a_o->owner); + return ERR_PTR(-EAGAIN); + } +#endif + NL_SET_ERR_MSG(extack, "Failed to load TC action module"); + return ERR_PTR(-ENOENT); + } + + return a_o; +} + +struct tc_action *tcf_action_init_1(struct net *net, struct tcf_proto *tp, + struct nlattr *nla, struct nlattr *est, + struct tc_action_ops *a_o, int *init_res, + u32 flags, struct netlink_ext_ack *extack) +{ + bool police = flags & TCA_ACT_FLAGS_POLICE; + struct nla_bitfield32 userflags = { 0, 0 }; + u8 hw_stats = TCA_ACT_HW_STATS_ANY; + struct nlattr *tb[TCA_ACT_MAX + 1]; + struct tc_cookie *cookie = NULL; + struct tc_action *a; + int err; + + /* backward compatibility for policer */ + if (!police) { + err = nla_parse_nested_deprecated(tb, TCA_ACT_MAX, nla, + tcf_action_policy, extack); + if (err < 0) + return ERR_PTR(err); + if (tb[TCA_ACT_COOKIE]) { + cookie = nla_memdup_cookie(tb); + if (!cookie) { + NL_SET_ERR_MSG(extack, "No memory to generate TC cookie"); + err = -ENOMEM; + goto err_out; + } + } + hw_stats = tcf_action_hw_stats_get(tb[TCA_ACT_HW_STATS]); + if (tb[TCA_ACT_FLAGS]) { + userflags = nla_get_bitfield32(tb[TCA_ACT_FLAGS]); + if (!tc_act_flags_valid(userflags.value)) { + err = -EINVAL; + goto err_out; + } + } + + err = a_o->init(net, tb[TCA_ACT_OPTIONS], est, &a, tp, + userflags.value | flags, extack); + } else { + err = a_o->init(net, nla, est, &a, tp, userflags.value | flags, + extack); + } + if (err < 0) + goto err_out; + *init_res = err; + + if (!police && tb[TCA_ACT_COOKIE]) + tcf_set_action_cookie(&a->act_cookie, cookie); + + if (!police) + a->hw_stats = hw_stats; + + return a; + +err_out: + if (cookie) { + kfree(cookie->data); + kfree(cookie); + } + return ERR_PTR(err); +} + +static bool tc_act_bind(u32 flags) +{ + return !!(flags & TCA_ACT_FLAGS_BIND); +} + +/* Returns numbers of initialized actions or negative error. 
*/ + +int tcf_action_init(struct net *net, struct tcf_proto *tp, struct nlattr *nla, + struct nlattr *est, struct tc_action *actions[], + int init_res[], size_t *attr_size, + u32 flags, u32 fl_flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_ops *ops[TCA_ACT_MAX_PRIO] = {}; + struct nlattr *tb[TCA_ACT_MAX_PRIO + 1]; + struct tc_action *act; + size_t sz = 0; + int err; + int i; + + err = nla_parse_nested_deprecated(tb, TCA_ACT_MAX_PRIO, nla, NULL, + extack); + if (err < 0) + return err; + + for (i = 1; i <= TCA_ACT_MAX_PRIO && tb[i]; i++) { + struct tc_action_ops *a_o; + + a_o = tc_action_load_ops(tb[i], flags & TCA_ACT_FLAGS_POLICE, + !(flags & TCA_ACT_FLAGS_NO_RTNL), + extack); + if (IS_ERR(a_o)) { + err = PTR_ERR(a_o); + goto err_mod; + } + ops[i - 1] = a_o; + } + + for (i = 1; i <= TCA_ACT_MAX_PRIO && tb[i]; i++) { + act = tcf_action_init_1(net, tp, tb[i], est, ops[i - 1], + &init_res[i - 1], flags, extack); + if (IS_ERR(act)) { + err = PTR_ERR(act); + goto err; + } + sz += tcf_action_fill_size(act); + /* Start from index 0 */ + actions[i - 1] = act; + if (tc_act_bind(flags)) { + bool skip_sw = tc_skip_sw(fl_flags); + bool skip_hw = tc_skip_hw(fl_flags); + + if (tc_act_bind(act->tcfa_flags)) + continue; + if (skip_sw != tc_act_skip_sw(act->tcfa_flags) || + skip_hw != tc_act_skip_hw(act->tcfa_flags)) { + NL_SET_ERR_MSG(extack, + "Mismatch between action and filter offload flags"); + err = -EINVAL; + goto err; + } + } else { + err = tcf_action_offload_add(act, extack); + if (tc_act_skip_sw(act->tcfa_flags) && err) + goto err; + } + } + + /* We have to commit them all together, because if any error happened in + * between, we could not handle the failure gracefully. + */ + tcf_idr_insert_many(actions); + + *attr_size = tcf_action_full_attrs_size(sz); + err = i - 1; + goto err_mod; + +err: + tcf_action_destroy(actions, flags & TCA_ACT_FLAGS_BIND); +err_mod: + for (i = 0; i < TCA_ACT_MAX_PRIO; i++) { + if (ops[i]) + module_put(ops[i]->owner); + } + return err; +} + +void tcf_action_update_stats(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, bool hw) +{ + if (a->cpu_bstats) { + _bstats_update(this_cpu_ptr(a->cpu_bstats), bytes, packets); + + this_cpu_ptr(a->cpu_qstats)->drops += drops; + + if (hw) + _bstats_update(this_cpu_ptr(a->cpu_bstats_hw), + bytes, packets); + return; + } + + _bstats_update(&a->tcfa_bstats, bytes, packets); + a->tcfa_qstats.drops += drops; + if (hw) + _bstats_update(&a->tcfa_bstats_hw, bytes, packets); +} +EXPORT_SYMBOL(tcf_action_update_stats); + +int tcf_action_copy_stats(struct sk_buff *skb, struct tc_action *p, + int compat_mode) +{ + int err = 0; + struct gnet_dump d; + + if (p == NULL) + goto errout; + + /* update hw stats for this action */ + tcf_action_update_hw_stats(p); + + /* compat_mode being true specifies a call that is supposed + * to add additional backward compatibility statistic TLVs. 
+ */ + if (compat_mode) { + if (p->type == TCA_OLD_COMPAT) + err = gnet_stats_start_copy_compat(skb, 0, + TCA_STATS, + TCA_XSTATS, + &p->tcfa_lock, &d, + TCA_PAD); + else + return 0; + } else + err = gnet_stats_start_copy(skb, TCA_ACT_STATS, + &p->tcfa_lock, &d, TCA_ACT_PAD); + + if (err < 0) + goto errout; + + if (gnet_stats_copy_basic(&d, p->cpu_bstats, + &p->tcfa_bstats, false) < 0 || + gnet_stats_copy_basic_hw(&d, p->cpu_bstats_hw, + &p->tcfa_bstats_hw, false) < 0 || + gnet_stats_copy_rate_est(&d, &p->tcfa_rate_est) < 0 || + gnet_stats_copy_queue(&d, p->cpu_qstats, + &p->tcfa_qstats, + p->tcfa_qstats.qlen) < 0) + goto errout; + + if (gnet_stats_finish_copy(&d) < 0) + goto errout; + + return 0; + +errout: + return -1; +} + +static int tca_get_fill(struct sk_buff *skb, struct tc_action *actions[], + u32 portid, u32 seq, u16 flags, int event, int bind, + int ref, struct netlink_ext_ack *extack) +{ + struct tcamsg *t; + struct nlmsghdr *nlh; + unsigned char *b = skb_tail_pointer(skb); + struct nlattr *nest; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*t), flags); + if (!nlh) + goto out_nlmsg_trim; + t = nlmsg_data(nlh); + t->tca_family = AF_UNSPEC; + t->tca__pad1 = 0; + t->tca__pad2 = 0; + + if (extack && extack->_msg && + nla_put_string(skb, TCA_ROOT_EXT_WARN_MSG, extack->_msg)) + goto out_nlmsg_trim; + + nest = nla_nest_start_noflag(skb, TCA_ACT_TAB); + if (!nest) + goto out_nlmsg_trim; + + if (tcf_action_dump(skb, actions, bind, ref, false) < 0) + goto out_nlmsg_trim; + + nla_nest_end(skb, nest); + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + + return skb->len; + +out_nlmsg_trim: + nlmsg_trim(skb, b); + return -1; +} + +static int +tcf_get_notify(struct net *net, u32 portid, struct nlmsghdr *n, + struct tc_action *actions[], int event, + struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + if (tca_get_fill(skb, actions, portid, n->nlmsg_seq, 0, event, + 0, 1, NULL) <= 0) { + NL_SET_ERR_MSG(extack, "Failed to fill netlink attributes while adding TC action"); + kfree_skb(skb); + return -EINVAL; + } + + return rtnl_unicast(skb, net, portid); +} + +static struct tc_action *tcf_action_get_1(struct net *net, struct nlattr *nla, + struct nlmsghdr *n, u32 portid, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_ACT_MAX + 1]; + const struct tc_action_ops *ops; + struct tc_action *a; + int index; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_ACT_MAX, nla, + tcf_action_policy, extack); + if (err < 0) + goto err_out; + + err = -EINVAL; + if (tb[TCA_ACT_INDEX] == NULL || + nla_len(tb[TCA_ACT_INDEX]) < sizeof(index)) { + NL_SET_ERR_MSG(extack, "Invalid TC action index value"); + goto err_out; + } + index = nla_get_u32(tb[TCA_ACT_INDEX]); + + err = -EINVAL; + ops = tc_lookup_action(tb[TCA_ACT_KIND]); + if (!ops) { /* could happen in batch of actions */ + NL_SET_ERR_MSG(extack, "Specified TC action kind not found"); + goto err_out; + } + err = -ENOENT; + if (__tcf_idr_search(net, ops, &a, index) == 0) { + NL_SET_ERR_MSG(extack, "TC action with specified index not found"); + goto err_mod; + } + + module_put(ops->owner); + return a; + +err_mod: + module_put(ops->owner); +err_out: + return ERR_PTR(err); +} + +static int tca_action_flush(struct net *net, struct nlattr *nla, + struct nlmsghdr *n, u32 portid, + struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + unsigned char *b; + struct nlmsghdr *nlh; + struct tcamsg *t; + struct netlink_callback dcb; + struct nlattr *nest; + 
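/*
 * Shape of the message that tca_get_fill() above builds (and that
 * tcf_get_notify(), tcf_add_notify() and tcf_del_notify() send):
 *
 *   struct tcamsg            tca_family = AF_UNSPEC
 *   [TCA_ROOT_EXT_WARN_MSG]  optional extack warning string
 *   [TCA_ACT_TAB]            nest
 *     [1]                    one nest per action, keyed by its position
 *       ... terse identification attributes (kind, index, ...)
 *       [TCA_ACT_HW_STATS] / [TCA_ACT_FLAGS] / [TCA_ACT_IN_HW_COUNT]
 *       [TCA_OPTIONS]        filled in by the action's own ->dump()
 *     [2]
 *       ...
 *
 * The per-action nests come from tcf_action_dump() and tcf_action_dump_1()
 * earlier in this file; in terse mode only the identification attributes
 * are emitted.
 */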
struct nlattr *tb[TCA_ACT_MAX + 1]; + const struct tc_action_ops *ops; + struct nlattr *kind; + int err = -ENOMEM; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return err; + + b = skb_tail_pointer(skb); + + err = nla_parse_nested_deprecated(tb, TCA_ACT_MAX, nla, + tcf_action_policy, extack); + if (err < 0) + goto err_out; + + err = -EINVAL; + kind = tb[TCA_ACT_KIND]; + ops = tc_lookup_action(kind); + if (!ops) { /*some idjot trying to flush unknown action */ + NL_SET_ERR_MSG(extack, "Cannot flush unknown TC action"); + goto err_out; + } + + nlh = nlmsg_put(skb, portid, n->nlmsg_seq, RTM_DELACTION, + sizeof(*t), 0); + if (!nlh) { + NL_SET_ERR_MSG(extack, "Failed to create TC action flush notification"); + goto out_module_put; + } + t = nlmsg_data(nlh); + t->tca_family = AF_UNSPEC; + t->tca__pad1 = 0; + t->tca__pad2 = 0; + + nest = nla_nest_start_noflag(skb, TCA_ACT_TAB); + if (!nest) { + NL_SET_ERR_MSG(extack, "Failed to add new netlink message"); + goto out_module_put; + } + + err = __tcf_generic_walker(net, skb, &dcb, RTM_DELACTION, ops, extack); + if (err <= 0) { + nla_nest_cancel(skb, nest); + goto out_module_put; + } + + nla_nest_end(skb, nest); + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + nlh->nlmsg_flags |= NLM_F_ROOT; + module_put(ops->owner); + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + if (err < 0) + NL_SET_ERR_MSG(extack, "Failed to send TC action flush notification"); + + return err; + +out_module_put: + module_put(ops->owner); +err_out: + kfree_skb(skb); + return err; +} + +static int tcf_action_delete(struct net *net, struct tc_action *actions[]) +{ + int i; + + for (i = 0; i < TCA_ACT_MAX_PRIO && actions[i]; i++) { + struct tc_action *a = actions[i]; + const struct tc_action_ops *ops = a->ops; + /* Actions can be deleted concurrently so we must save their + * type and id to search again after reference is released. + */ + struct tcf_idrinfo *idrinfo = a->idrinfo; + u32 act_index = a->tcfa_index; + + actions[i] = NULL; + if (tcf_action_put(a)) { + /* last reference, action was deleted concurrently */ + module_put(ops->owner); + } else { + int ret; + + /* now do the delete */ + ret = tcf_idr_delete_index(idrinfo, act_index); + if (ret < 0) + return ret; + } + } + return 0; +} + +static int +tcf_reoffload_del_notify(struct net *net, struct tc_action *action) +{ + size_t attr_size = tcf_action_fill_size(action); + struct tc_action *actions[TCA_ACT_MAX_PRIO] = { + [0] = action, + }; + const struct tc_action_ops *ops = action->ops; + struct sk_buff *skb; + int ret; + + skb = alloc_skb(attr_size <= NLMSG_GOODSIZE ? 
NLMSG_GOODSIZE : attr_size, + GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tca_get_fill(skb, actions, 0, 0, 0, RTM_DELACTION, 0, 1, NULL) <= 0) { + kfree_skb(skb); + return -EINVAL; + } + + ret = tcf_idr_release_unsafe(action); + if (ret == ACT_P_DELETED) { + module_put(ops->owner); + ret = rtnetlink_send(skb, net, 0, RTNLGRP_TC, 0); + } else { + kfree_skb(skb); + } + + return ret; +} + +int tcf_action_reoffload_cb(flow_indr_block_bind_cb_t *cb, + void *cb_priv, bool add) +{ + struct tc_act_pernet_id *id_ptr; + struct tcf_idrinfo *idrinfo; + struct tc_action_net *tn; + struct tc_action *p; + unsigned int act_id; + unsigned long tmp; + unsigned long id; + struct idr *idr; + struct net *net; + int ret; + + if (!cb) + return -EINVAL; + + down_read(&net_rwsem); + mutex_lock(&act_id_mutex); + + for_each_net(net) { + list_for_each_entry(id_ptr, &act_pernet_id_list, list) { + act_id = id_ptr->id; + tn = net_generic(net, act_id); + if (!tn) + continue; + idrinfo = tn->idrinfo; + if (!idrinfo) + continue; + + mutex_lock(&idrinfo->lock); + idr = &idrinfo->action_idr; + idr_for_each_entry_ul(idr, p, tmp, id) { + if (IS_ERR(p) || tc_act_bind(p->tcfa_flags)) + continue; + if (add) { + tcf_action_offload_add_ex(p, NULL, cb, + cb_priv); + continue; + } + + /* cb unregister to update hw count */ + ret = tcf_action_offload_del_ex(p, cb, cb_priv); + if (ret < 0) + continue; + if (tc_act_skip_sw(p->tcfa_flags) && + !tc_act_in_hw(p)) + tcf_reoffload_del_notify(net, p); + } + mutex_unlock(&idrinfo->lock); + } + } + mutex_unlock(&act_id_mutex); + up_read(&net_rwsem); + + return 0; +} + +static int +tcf_del_notify(struct net *net, struct nlmsghdr *n, struct tc_action *actions[], + u32 portid, size_t attr_size, struct netlink_ext_ack *extack) +{ + int ret; + struct sk_buff *skb; + + skb = alloc_skb(attr_size <= NLMSG_GOODSIZE ? 
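/*
 * tcf_action_reoffload_cb() above walks every action instance in the system
 * when an indirect block callback is (un)registered, roughly:
 *
 *   for_each_net(net)                          under net_rwsem
 *     for each id on act_pernet_id_list        under act_id_mutex
 *       tn = net_generic(net, id)
 *       idr_for_each_entry_ul(&tn->idrinfo->action_idr, p, ...)
 *         add:    tcf_action_offload_add_ex(p, NULL, cb, cb_priv)
 *         remove: tcf_action_offload_del_ex(p, cb, cb_priv); if the action
 *                 is skip_sw and is no longer in hardware at all,
 *                 tcf_reoffload_del_notify() deletes it and sends
 *                 RTM_DELACTION to userspace.
 *
 * Actions bound to classifiers (TCA_ACT_FLAGS_BIND) are skipped in this walk.
 */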
NLMSG_GOODSIZE : attr_size, + GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tca_get_fill(skb, actions, portid, n->nlmsg_seq, 0, RTM_DELACTION, + 0, 2, extack) <= 0) { + NL_SET_ERR_MSG(extack, "Failed to fill netlink TC action attributes"); + kfree_skb(skb); + return -EINVAL; + } + + /* now do the delete */ + ret = tcf_action_delete(net, actions); + if (ret < 0) { + NL_SET_ERR_MSG(extack, "Failed to delete TC action"); + kfree_skb(skb); + return ret; + } + + ret = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + return ret; +} + +static int +tca_action_gd(struct net *net, struct nlattr *nla, struct nlmsghdr *n, + u32 portid, int event, struct netlink_ext_ack *extack) +{ + int i, ret; + struct nlattr *tb[TCA_ACT_MAX_PRIO + 1]; + struct tc_action *act; + size_t attr_size = 0; + struct tc_action *actions[TCA_ACT_MAX_PRIO] = {}; + + ret = nla_parse_nested_deprecated(tb, TCA_ACT_MAX_PRIO, nla, NULL, + extack); + if (ret < 0) + return ret; + + if (event == RTM_DELACTION && n->nlmsg_flags & NLM_F_ROOT) { + if (tb[1]) + return tca_action_flush(net, tb[1], n, portid, extack); + + NL_SET_ERR_MSG(extack, "Invalid netlink attributes while flushing TC action"); + return -EINVAL; + } + + for (i = 1; i <= TCA_ACT_MAX_PRIO && tb[i]; i++) { + act = tcf_action_get_1(net, tb[i], n, portid, extack); + if (IS_ERR(act)) { + ret = PTR_ERR(act); + goto err; + } + attr_size += tcf_action_fill_size(act); + actions[i - 1] = act; + } + + attr_size = tcf_action_full_attrs_size(attr_size); + + if (event == RTM_GETACTION) + ret = tcf_get_notify(net, portid, n, actions, event, extack); + else { /* delete */ + ret = tcf_del_notify(net, n, actions, portid, attr_size, extack); + if (ret) + goto err; + return 0; + } +err: + tcf_action_put_many(actions); + return ret; +} + +static int +tcf_add_notify(struct net *net, struct nlmsghdr *n, struct tc_action *actions[], + u32 portid, size_t attr_size, struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + + skb = alloc_skb(attr_size <= NLMSG_GOODSIZE ? 
NLMSG_GOODSIZE : attr_size, + GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tca_get_fill(skb, actions, portid, n->nlmsg_seq, n->nlmsg_flags, + RTM_NEWACTION, 0, 0, extack) <= 0) { + NL_SET_ERR_MSG(extack, "Failed to fill netlink attributes while adding TC action"); + kfree_skb(skb); + return -EINVAL; + } + + return rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); +} + +static int tcf_action_add(struct net *net, struct nlattr *nla, + struct nlmsghdr *n, u32 portid, u32 flags, + struct netlink_ext_ack *extack) +{ + size_t attr_size = 0; + int loop, ret, i; + struct tc_action *actions[TCA_ACT_MAX_PRIO] = {}; + int init_res[TCA_ACT_MAX_PRIO] = {}; + + for (loop = 0; loop < 10; loop++) { + ret = tcf_action_init(net, NULL, nla, NULL, actions, init_res, + &attr_size, flags, 0, extack); + if (ret != -EAGAIN) + break; + } + + if (ret < 0) + return ret; + ret = tcf_add_notify(net, n, actions, portid, attr_size, extack); + + /* only put existing actions */ + for (i = 0; i < TCA_ACT_MAX_PRIO; i++) + if (init_res[i] == ACT_P_CREATED) + actions[i] = NULL; + tcf_action_put_many(actions); + + return ret; +} + +static const struct nla_policy tcaa_policy[TCA_ROOT_MAX + 1] = { + [TCA_ROOT_FLAGS] = NLA_POLICY_BITFIELD32(TCA_ACT_FLAG_LARGE_DUMP_ON | + TCA_ACT_FLAG_TERSE_DUMP), + [TCA_ROOT_TIME_DELTA] = { .type = NLA_U32 }, +}; + +static int tc_ctl_action(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_ROOT_MAX + 1]; + u32 portid = NETLINK_CB(skb).portid; + u32 flags = 0; + int ret = 0; + + if ((n->nlmsg_type != RTM_GETACTION) && + !netlink_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + ret = nlmsg_parse_deprecated(n, sizeof(struct tcamsg), tca, + TCA_ROOT_MAX, NULL, extack); + if (ret < 0) + return ret; + + if (tca[TCA_ACT_TAB] == NULL) { + NL_SET_ERR_MSG(extack, "Netlink action attributes missing"); + return -EINVAL; + } + + /* n->nlmsg_flags & NLM_F_CREATE */ + switch (n->nlmsg_type) { + case RTM_NEWACTION: + /* we are going to assume all other flags + * imply create only if it doesn't exist + * Note that CREATE | EXCL implies that + * but since we want avoid ambiguity (eg when flags + * is zero) then just set this + */ + if (n->nlmsg_flags & NLM_F_REPLACE) + flags = TCA_ACT_FLAGS_REPLACE; + ret = tcf_action_add(net, tca[TCA_ACT_TAB], n, portid, flags, + extack); + break; + case RTM_DELACTION: + ret = tca_action_gd(net, tca[TCA_ACT_TAB], n, + portid, RTM_DELACTION, extack); + break; + case RTM_GETACTION: + ret = tca_action_gd(net, tca[TCA_ACT_TAB], n, + portid, RTM_GETACTION, extack); + break; + default: + BUG(); + } + + return ret; +} + +static struct nlattr *find_dump_kind(struct nlattr **nla) +{ + struct nlattr *tb1, *tb2[TCA_ACT_MAX + 1]; + struct nlattr *tb[TCA_ACT_MAX_PRIO + 1]; + struct nlattr *kind; + + tb1 = nla[TCA_ACT_TAB]; + if (tb1 == NULL) + return NULL; + + if (nla_parse_deprecated(tb, TCA_ACT_MAX_PRIO, nla_data(tb1), NLMSG_ALIGN(nla_len(tb1)), NULL, NULL) < 0) + return NULL; + + if (tb[1] == NULL) + return NULL; + if (nla_parse_nested_deprecated(tb2, TCA_ACT_MAX, tb[1], tcf_action_policy, NULL) < 0) + return NULL; + kind = tb2[TCA_ACT_KIND]; + + return kind; +} + +static int tc_dump_action(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nlmsghdr *nlh; + unsigned char *b = skb_tail_pointer(skb); + struct nlattr *nest; + struct tc_action_ops *a_o; + int ret = 0; + struct tcamsg *t = (struct tcamsg *) 
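/*
 * Note on the RTM_NEWACTION path above: tcf_action_init() can return -EAGAIN
 * when tc_action_load_ops() had to drop the RTNL lock to request_module() an
 * absent act_* module, so tcf_action_add() simply replays the whole request
 * (up to 10 times).  In practice this is hit on the first use of an action
 * kind after boot; roughly, the first "tc actions add action connmark"
 * autoloads act_connmark and the replay then finds the ops registered
 * (exact iproute2 syntax may vary between versions).
 */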
nlmsg_data(cb->nlh); + struct nlattr *tb[TCA_ROOT_MAX + 1]; + struct nlattr *count_attr = NULL; + unsigned long jiffy_since = 0; + struct nlattr *kind = NULL; + struct nla_bitfield32 bf; + u32 msecs_since = 0; + u32 act_count = 0; + + ret = nlmsg_parse_deprecated(cb->nlh, sizeof(struct tcamsg), tb, + TCA_ROOT_MAX, tcaa_policy, cb->extack); + if (ret < 0) + return ret; + + kind = find_dump_kind(tb); + if (kind == NULL) { + pr_info("tc_dump_action: action bad kind\n"); + return 0; + } + + a_o = tc_lookup_action(kind); + if (a_o == NULL) + return 0; + + cb->args[2] = 0; + if (tb[TCA_ROOT_FLAGS]) { + bf = nla_get_bitfield32(tb[TCA_ROOT_FLAGS]); + cb->args[2] = bf.value; + } + + if (tb[TCA_ROOT_TIME_DELTA]) { + msecs_since = nla_get_u32(tb[TCA_ROOT_TIME_DELTA]); + } + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + cb->nlh->nlmsg_type, sizeof(*t), 0); + if (!nlh) + goto out_module_put; + + if (msecs_since) + jiffy_since = jiffies - msecs_to_jiffies(msecs_since); + + t = nlmsg_data(nlh); + t->tca_family = AF_UNSPEC; + t->tca__pad1 = 0; + t->tca__pad2 = 0; + cb->args[3] = jiffy_since; + count_attr = nla_reserve(skb, TCA_ROOT_COUNT, sizeof(u32)); + if (!count_attr) + goto out_module_put; + + nest = nla_nest_start_noflag(skb, TCA_ACT_TAB); + if (nest == NULL) + goto out_module_put; + + ret = __tcf_generic_walker(net, skb, cb, RTM_GETACTION, a_o, NULL); + if (ret < 0) + goto out_module_put; + + if (ret > 0) { + nla_nest_end(skb, nest); + ret = skb->len; + act_count = cb->args[1]; + memcpy(nla_data(count_attr), &act_count, sizeof(u32)); + cb->args[1] = 0; + } else + nlmsg_trim(skb, b); + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + if (NETLINK_CB(cb->skb).portid && ret) + nlh->nlmsg_flags |= NLM_F_MULTI; + module_put(a_o->owner); + return skb->len; + +out_module_put: + module_put(a_o->owner); + nlmsg_trim(skb, b); + return skb->len; +} + +static int __init tc_action_init(void) +{ + rtnl_register(PF_UNSPEC, RTM_NEWACTION, tc_ctl_action, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELACTION, tc_ctl_action, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETACTION, tc_ctl_action, tc_dump_action, + 0); + + return 0; +} + +subsys_initcall(tc_action_init); diff --git a/net/sched/act_bpf.c b/net/sched/act_bpf.c new file mode 100644 index 000000000..b79eee44e --- /dev/null +++ b/net/sched/act_bpf.c @@ -0,0 +1,437 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2015 Jiri Pirko + */ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +#define ACT_BPF_NAME_LEN 256 + +struct tcf_bpf_cfg { + struct bpf_prog *filter; + struct sock_filter *bpf_ops; + const char *bpf_name; + u16 bpf_num_ops; + bool is_ebpf; +}; + +static struct tc_action_ops act_bpf_ops; + +static int tcf_bpf_act(struct sk_buff *skb, const struct tc_action *act, + struct tcf_result *res) +{ + bool at_ingress = skb_at_tc_ingress(skb); + struct tcf_bpf *prog = to_bpf(act); + struct bpf_prog *filter; + int action, filter_res; + + tcf_lastuse_update(&prog->tcf_tm); + bstats_update(this_cpu_ptr(prog->common.cpu_bstats), skb); + + filter = rcu_dereference(prog->filter); + if (at_ingress) { + __skb_push(skb, skb->mac_len); + bpf_compute_data_pointers(skb); + filter_res = bpf_prog_run(filter, skb); + __skb_pull(skb, skb->mac_len); + } else { + bpf_compute_data_pointers(skb); + filter_res = bpf_prog_run(filter, skb); + } + if (unlikely(!skb->tstamp && skb->mono_delivery_time)) + skb->mono_delivery_time = 0; + if 
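/*
 * Because tcf_bpf_act() above pushes skb->mac_len back on at ingress before
 * running the program (and pulls it again afterwards), the attached (e)BPF
 * code sees the packet starting at the Ethernet header on ingress as well as
 * on egress.  A classic BPF instruction such as
 *
 *   BPF_STMT(BPF_LD | BPF_H | BPF_ABS, 12)    // load the EtherType
 *
 * therefore reads the same offset regardless of direction (offset 12 assumes
 * an untagged Ethernet frame).
 */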
(skb_sk_is_prefetched(skb) && filter_res != TC_ACT_OK) + skb_orphan(skb); + + /* A BPF program may overwrite the default action opcode. + * Similarly as in cls_bpf, if filter_res == -1 we use the + * default action specified from tc. + * + * In case a different well-known TC_ACT opcode has been + * returned, it will overwrite the default one. + * + * For everything else that is unknown, TC_ACT_UNSPEC is + * returned. + */ + switch (filter_res) { + case TC_ACT_PIPE: + case TC_ACT_RECLASSIFY: + case TC_ACT_OK: + case TC_ACT_REDIRECT: + action = filter_res; + break; + case TC_ACT_SHOT: + action = filter_res; + qstats_drop_inc(this_cpu_ptr(prog->common.cpu_qstats)); + break; + case TC_ACT_UNSPEC: + action = prog->tcf_action; + break; + default: + action = TC_ACT_UNSPEC; + break; + } + + return action; +} + +static bool tcf_bpf_is_ebpf(const struct tcf_bpf *prog) +{ + return !prog->bpf_ops; +} + +static int tcf_bpf_dump_bpf_info(const struct tcf_bpf *prog, + struct sk_buff *skb) +{ + struct nlattr *nla; + + if (nla_put_u16(skb, TCA_ACT_BPF_OPS_LEN, prog->bpf_num_ops)) + return -EMSGSIZE; + + nla = nla_reserve(skb, TCA_ACT_BPF_OPS, prog->bpf_num_ops * + sizeof(struct sock_filter)); + if (nla == NULL) + return -EMSGSIZE; + + memcpy(nla_data(nla), prog->bpf_ops, nla_len(nla)); + + return 0; +} + +static int tcf_bpf_dump_ebpf_info(const struct tcf_bpf *prog, + struct sk_buff *skb) +{ + struct nlattr *nla; + + if (prog->bpf_name && + nla_put_string(skb, TCA_ACT_BPF_NAME, prog->bpf_name)) + return -EMSGSIZE; + + if (nla_put_u32(skb, TCA_ACT_BPF_ID, prog->filter->aux->id)) + return -EMSGSIZE; + + nla = nla_reserve(skb, TCA_ACT_BPF_TAG, sizeof(prog->filter->tag)); + if (nla == NULL) + return -EMSGSIZE; + + memcpy(nla_data(nla), prog->filter->tag, nla_len(nla)); + + return 0; +} + +static int tcf_bpf_dump(struct sk_buff *skb, struct tc_action *act, + int bind, int ref) +{ + unsigned char *tp = skb_tail_pointer(skb); + struct tcf_bpf *prog = to_bpf(act); + struct tc_act_bpf opt = { + .index = prog->tcf_index, + .refcnt = refcount_read(&prog->tcf_refcnt) - ref, + .bindcnt = atomic_read(&prog->tcf_bindcnt) - bind, + }; + struct tcf_t tm; + int ret; + + spin_lock_bh(&prog->tcf_lock); + opt.action = prog->tcf_action; + if (nla_put(skb, TCA_ACT_BPF_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + if (tcf_bpf_is_ebpf(prog)) + ret = tcf_bpf_dump_ebpf_info(prog, skb); + else + ret = tcf_bpf_dump_bpf_info(prog, skb); + if (ret) + goto nla_put_failure; + + tcf_tm_dump(&tm, &prog->tcf_tm); + if (nla_put_64bit(skb, TCA_ACT_BPF_TM, sizeof(tm), &tm, + TCA_ACT_BPF_PAD)) + goto nla_put_failure; + + spin_unlock_bh(&prog->tcf_lock); + return skb->len; + +nla_put_failure: + spin_unlock_bh(&prog->tcf_lock); + nlmsg_trim(skb, tp); + return -1; +} + +static const struct nla_policy act_bpf_policy[TCA_ACT_BPF_MAX + 1] = { + [TCA_ACT_BPF_PARMS] = { .len = sizeof(struct tc_act_bpf) }, + [TCA_ACT_BPF_FD] = { .type = NLA_U32 }, + [TCA_ACT_BPF_NAME] = { .type = NLA_NUL_STRING, + .len = ACT_BPF_NAME_LEN }, + [TCA_ACT_BPF_OPS_LEN] = { .type = NLA_U16 }, + [TCA_ACT_BPF_OPS] = { .type = NLA_BINARY, + .len = sizeof(struct sock_filter) * BPF_MAXINSNS }, +}; + +static int tcf_bpf_init_from_ops(struct nlattr **tb, struct tcf_bpf_cfg *cfg) +{ + struct sock_filter *bpf_ops; + struct sock_fprog_kern fprog_tmp; + struct bpf_prog *fp; + u16 bpf_size, bpf_num_ops; + int ret; + + bpf_num_ops = nla_get_u16(tb[TCA_ACT_BPF_OPS_LEN]); + if (bpf_num_ops > BPF_MAXINSNS || bpf_num_ops == 0) + return -EINVAL; + + bpf_size = bpf_num_ops * 
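/*
 * Return-value mapping in tcf_bpf_act() above, in short: a program returning
 * one of the well-known opcodes (TC_ACT_OK, TC_ACT_SHOT, TC_ACT_PIPE,
 * TC_ACT_RECLASSIFY, TC_ACT_REDIRECT) overrides the configured default;
 * TC_ACT_UNSPEC (-1) means "use the default action given at configuration
 * time" (prog->tcf_action); any other value is mapped to TC_ACT_UNSPEC.
 * For example, a classic BPF program ending in
 *
 *   BPF_STMT(BPF_RET | BPF_K, ~0U)            // return -1
 *
 * always falls through to the configured default, while a program returning
 * TC_ACT_SHOT also bumps the per-cpu drop counter.
 */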
sizeof(*bpf_ops); + if (bpf_size != nla_len(tb[TCA_ACT_BPF_OPS])) + return -EINVAL; + + bpf_ops = kmemdup(nla_data(tb[TCA_ACT_BPF_OPS]), bpf_size, GFP_KERNEL); + if (bpf_ops == NULL) + return -ENOMEM; + + fprog_tmp.len = bpf_num_ops; + fprog_tmp.filter = bpf_ops; + + ret = bpf_prog_create(&fp, &fprog_tmp); + if (ret < 0) { + kfree(bpf_ops); + return ret; + } + + cfg->bpf_ops = bpf_ops; + cfg->bpf_num_ops = bpf_num_ops; + cfg->filter = fp; + cfg->is_ebpf = false; + + return 0; +} + +static int tcf_bpf_init_from_efd(struct nlattr **tb, struct tcf_bpf_cfg *cfg) +{ + struct bpf_prog *fp; + char *name = NULL; + u32 bpf_fd; + + bpf_fd = nla_get_u32(tb[TCA_ACT_BPF_FD]); + + fp = bpf_prog_get_type(bpf_fd, BPF_PROG_TYPE_SCHED_ACT); + if (IS_ERR(fp)) + return PTR_ERR(fp); + + if (tb[TCA_ACT_BPF_NAME]) { + name = nla_memdup(tb[TCA_ACT_BPF_NAME], GFP_KERNEL); + if (!name) { + bpf_prog_put(fp); + return -ENOMEM; + } + } + + cfg->bpf_name = name; + cfg->filter = fp; + cfg->is_ebpf = true; + + return 0; +} + +static void tcf_bpf_cfg_cleanup(const struct tcf_bpf_cfg *cfg) +{ + struct bpf_prog *filter = cfg->filter; + + if (filter) { + if (cfg->is_ebpf) + bpf_prog_put(filter); + else + bpf_prog_destroy(filter); + } + + kfree(cfg->bpf_ops); + kfree(cfg->bpf_name); +} + +static void tcf_bpf_prog_fill_cfg(const struct tcf_bpf *prog, + struct tcf_bpf_cfg *cfg) +{ + cfg->is_ebpf = tcf_bpf_is_ebpf(prog); + /* updates to prog->filter are prevented, since it's called either + * with tcf lock or during final cleanup in rcu callback + */ + cfg->filter = rcu_dereference_protected(prog->filter, 1); + + cfg->bpf_ops = prog->bpf_ops; + cfg->bpf_name = prog->bpf_name; +} + +static int tcf_bpf_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **act, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_bpf_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_ACT_BPF_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tcf_bpf_cfg cfg, old; + struct tc_act_bpf *parm; + struct tcf_bpf *prog; + bool is_bpf, is_ebpf; + int ret, res = 0; + u32 index; + + if (!nla) + return -EINVAL; + + ret = nla_parse_nested_deprecated(tb, TCA_ACT_BPF_MAX, nla, + act_bpf_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[TCA_ACT_BPF_PARMS]) + return -EINVAL; + + parm = nla_data(tb[TCA_ACT_BPF_PARMS]); + index = parm->index; + ret = tcf_idr_check_alloc(tn, &index, act, bind); + if (!ret) { + ret = tcf_idr_create(tn, index, est, act, + &act_bpf_ops, bind, true, flags); + if (ret < 0) { + tcf_idr_cleanup(tn, index); + return ret; + } + + res = ACT_P_CREATED; + } else if (ret > 0) { + /* Don't override defaults. */ + if (bind) + return 0; + + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*act, bind); + return -EEXIST; + } + } else { + return ret; + } + + ret = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (ret < 0) + goto release_idr; + + is_bpf = tb[TCA_ACT_BPF_OPS_LEN] && tb[TCA_ACT_BPF_OPS]; + is_ebpf = tb[TCA_ACT_BPF_FD]; + + if (is_bpf == is_ebpf) { + ret = -EINVAL; + goto put_chain; + } + + memset(&cfg, 0, sizeof(cfg)); + + ret = is_bpf ? 
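/*
 * Configuration of this action is strictly either/or: a classic BPF program
 * is passed as an opcode array (TCA_ACT_BPF_OPS_LEN + TCA_ACT_BPF_OPS,
 * handled by tcf_bpf_init_from_ops()), an eBPF program as a file descriptor
 * (TCA_ACT_BPF_FD, optionally TCA_ACT_BPF_NAME, handled by
 * tcf_bpf_init_from_efd(), which requires a program of type
 * BPF_PROG_TYPE_SCHED_ACT).  Supplying both or neither fails with -EINVAL,
 * which is what the "is_bpf == is_ebpf" check below enforces.
 */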
tcf_bpf_init_from_ops(tb, &cfg) : + tcf_bpf_init_from_efd(tb, &cfg); + if (ret < 0) + goto put_chain; + + prog = to_bpf(*act); + + spin_lock_bh(&prog->tcf_lock); + if (res != ACT_P_CREATED) + tcf_bpf_prog_fill_cfg(prog, &old); + + prog->bpf_ops = cfg.bpf_ops; + prog->bpf_name = cfg.bpf_name; + + if (cfg.bpf_num_ops) + prog->bpf_num_ops = cfg.bpf_num_ops; + + goto_ch = tcf_action_set_ctrlact(*act, parm->action, goto_ch); + rcu_assign_pointer(prog->filter, cfg.filter); + spin_unlock_bh(&prog->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + if (res != ACT_P_CREATED) { + /* make sure the program being replaced is no longer executing */ + synchronize_rcu(); + tcf_bpf_cfg_cleanup(&old); + } + + return res; + +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + +release_idr: + tcf_idr_release(*act, bind); + return ret; +} + +static void tcf_bpf_cleanup(struct tc_action *act) +{ + struct tcf_bpf_cfg tmp; + + tcf_bpf_prog_fill_cfg(to_bpf(act), &tmp); + tcf_bpf_cfg_cleanup(&tmp); +} + +static struct tc_action_ops act_bpf_ops __read_mostly = { + .kind = "bpf", + .id = TCA_ID_BPF, + .owner = THIS_MODULE, + .act = tcf_bpf_act, + .dump = tcf_bpf_dump, + .cleanup = tcf_bpf_cleanup, + .init = tcf_bpf_init, + .size = sizeof(struct tcf_bpf), +}; + +static __net_init int bpf_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_bpf_ops.net_id); + + return tc_action_net_init(net, tn, &act_bpf_ops); +} + +static void __net_exit bpf_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_bpf_ops.net_id); +} + +static struct pernet_operations bpf_net_ops = { + .init = bpf_init_net, + .exit_batch = bpf_exit_net, + .id = &act_bpf_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init bpf_init_module(void) +{ + return tcf_register_action(&act_bpf_ops, &bpf_net_ops); +} + +static void __exit bpf_cleanup_module(void) +{ + tcf_unregister_action(&act_bpf_ops, &bpf_net_ops); +} + +module_init(bpf_init_module); +module_exit(bpf_cleanup_module); + +MODULE_AUTHOR("Jiri Pirko "); +MODULE_DESCRIPTION("TC BPF based action"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/act_connmark.c b/net/sched/act_connmark.c new file mode 100644 index 000000000..d41002e46 --- /dev/null +++ b/net/sched/act_connmark.c @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_connmark.c netfilter connmark retriever action + * skb mark is over-written + * + * Copyright (c) 2011 Felix Fietkau +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +static struct tc_action_ops act_connmark_ops; + +static int tcf_connmark_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + const struct nf_conntrack_tuple_hash *thash; + struct nf_conntrack_tuple tuple; + enum ip_conntrack_info ctinfo; + struct tcf_connmark_info *ca = to_connmark(a); + struct nf_conntrack_zone zone; + struct nf_conn *c; + int proto; + + spin_lock(&ca->tcf_lock); + tcf_lastuse_update(&ca->tcf_tm); + bstats_update(&ca->tcf_bstats, skb); + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + if (skb->len < sizeof(struct iphdr)) + goto out; + + proto = NFPROTO_IPV4; + break; + case htons(ETH_P_IPV6): + if (skb->len < sizeof(struct ipv6hdr)) + goto out; + + proto = NFPROTO_IPV6; + break; + default: + goto out; + } + + c = nf_ct_get(skb, &ctinfo); + if (c) { + skb->mark = READ_ONCE(c->mark); + /* using 
overlimits stats to count how many packets marked */ + ca->tcf_qstats.overlimits++; + goto out; + } + + if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), + proto, ca->net, &tuple)) + goto out; + + zone.id = ca->zone; + zone.dir = NF_CT_DEFAULT_ZONE_DIR; + + thash = nf_conntrack_find_get(ca->net, &zone, &tuple); + if (!thash) + goto out; + + c = nf_ct_tuplehash_to_ctrack(thash); + /* using overlimits stats to count how many packets marked */ + ca->tcf_qstats.overlimits++; + skb->mark = READ_ONCE(c->mark); + nf_ct_put(c); + +out: + spin_unlock(&ca->tcf_lock); + return ca->tcf_action; +} + +static const struct nla_policy connmark_policy[TCA_CONNMARK_MAX + 1] = { + [TCA_CONNMARK_PARMS] = { .len = sizeof(struct tc_connmark) }, +}; + +static int tcf_connmark_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_connmark_ops.net_id); + struct nlattr *tb[TCA_CONNMARK_MAX + 1]; + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct tcf_chain *goto_ch = NULL; + struct tcf_connmark_info *ci; + struct tc_connmark *parm; + int ret = 0, err; + u32 index; + + if (!nla) + return -EINVAL; + + ret = nla_parse_nested_deprecated(tb, TCA_CONNMARK_MAX, nla, + connmark_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[TCA_CONNMARK_PARMS]) + return -EINVAL; + + parm = nla_data(tb[TCA_CONNMARK_PARMS]); + index = parm->index; + ret = tcf_idr_check_alloc(tn, &index, a, bind); + if (!ret) { + ret = tcf_idr_create(tn, index, est, a, + &act_connmark_ops, bind, false, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + ci = to_connmark(*a); + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, + extack); + if (err < 0) + goto release_idr; + tcf_action_set_ctrlact(*a, parm->action, goto_ch); + ci->net = net; + ci->zone = parm->zone; + + ret = ACT_P_CREATED; + } else if (ret > 0) { + ci = to_connmark(*a); + if (bind) + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, + extack); + if (err < 0) + goto release_idr; + /* replacing action and zone */ + spin_lock_bh(&ci->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + ci->zone = parm->zone; + spin_unlock_bh(&ci->tcf_lock); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + ret = 0; + } + + return ret; +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static inline int tcf_connmark_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_connmark_info *ci = to_connmark(a); + struct tc_connmark opt = { + .index = ci->tcf_index, + .refcnt = refcount_read(&ci->tcf_refcnt) - ref, + .bindcnt = atomic_read(&ci->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&ci->tcf_lock); + opt.action = ci->tcf_action; + opt.zone = ci->zone; + if (nla_put(skb, TCA_CONNMARK_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &ci->tcf_tm); + if (nla_put_64bit(skb, TCA_CONNMARK_TM, sizeof(t), &t, + TCA_CONNMARK_PAD)) + goto nla_put_failure; + spin_unlock_bh(&ci->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&ci->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static struct tc_action_ops act_connmark_ops = { + .kind = "connmark", + .id = TCA_ID_CONNMARK, + .owner = THIS_MODULE, + .act = tcf_connmark_act, + .dump = tcf_connmark_dump, + .init 
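/*
 * The connmark action above restores the conntrack mark of the packet's
 * connection into skb->mark, either from an already attached conntrack entry
 * or via an explicit tuple lookup in the configured zone; the qstats
 * "overlimits" counter doubles as a count of packets marked this way.
 * An illustrative iproute2 invocation (exact syntax may vary by version):
 *
 *   tc actions add action connmark zone 1 index 10
 *
 * which is typically paired with a netfilter rule that sets a connection
 * mark elsewhere.
 */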
= tcf_connmark_init, + .size = sizeof(struct tcf_connmark_info), +}; + +static __net_init int connmark_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_connmark_ops.net_id); + + return tc_action_net_init(net, tn, &act_connmark_ops); +} + +static void __net_exit connmark_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_connmark_ops.net_id); +} + +static struct pernet_operations connmark_net_ops = { + .init = connmark_init_net, + .exit_batch = connmark_exit_net, + .id = &act_connmark_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init connmark_init_module(void) +{ + return tcf_register_action(&act_connmark_ops, &connmark_net_ops); +} + +static void __exit connmark_cleanup_module(void) +{ + tcf_unregister_action(&act_connmark_ops, &connmark_net_ops); +} + +module_init(connmark_init_module); +module_exit(connmark_cleanup_module); +MODULE_AUTHOR("Felix Fietkau "); +MODULE_DESCRIPTION("Connection tracking mark restoring"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/act_csum.c b/net/sched/act_csum.c new file mode 100644 index 000000000..1366adf9b --- /dev/null +++ b/net/sched/act_csum.c @@ -0,0 +1,745 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Checksum updating actions + * + * Copyright (c) 2010 Gregoire Baron + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = { + [TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), }, +}; + +static struct tc_action_ops act_csum_ops; + +static int tcf_csum_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_csum_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct tcf_csum_params *params_new; + struct nlattr *tb[TCA_CSUM_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_csum *parm; + struct tcf_csum *p; + int ret = 0, err; + u32 index; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_CSUM_MAX, nla, csum_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_CSUM_PARMS] == NULL) + return -EINVAL; + parm = nla_data(tb[TCA_CSUM_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (!err) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_csum_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (err > 0) { + if (bind)/* dont override defaults */ + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } else { + return err; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + p = to_tcf_csum(*a); + + params_new = kzalloc(sizeof(*params_new), GFP_KERNEL); + if (unlikely(!params_new)) { + err = -ENOMEM; + goto put_chain; + } + params_new->update_flags = parm->update_flags; + + spin_lock_bh(&p->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + params_new = rcu_replace_pointer(p->params, params_new, + lockdep_is_held(&p->tcf_lock)); + spin_unlock_bh(&p->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (params_new) + kfree_rcu(params_new, rcu); + + return 
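/*
 * The update path above follows the usual RCU parameter-swap pattern for
 * actions: allocate a fresh tcf_csum_params, publish it with
 * rcu_replace_pointer() under tcf_lock, and free the previous copy with
 * kfree_rcu() once readers are done.  The datapath reads it locklessly:
 *
 *   params = rcu_dereference_bh(p->params);
 *   update_flags = params->update_flags;
 *
 * which is exactly what tcf_csum_act() further down does.
 */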
ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +/** + * tcf_csum_skb_nextlayer - Get next layer pointer + * @skb: sk_buff to use + * @ihl: previous summed headers length + * @ipl: complete packet length + * @jhl: next header length + * + * Check the expected next layer availability in the specified sk_buff. + * Return the next layer pointer if pass, NULL otherwise. + */ +static void *tcf_csum_skb_nextlayer(struct sk_buff *skb, + unsigned int ihl, unsigned int ipl, + unsigned int jhl) +{ + int ntkoff = skb_network_offset(skb); + int hl = ihl + jhl; + + if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) || + skb_try_make_writable(skb, hl + ntkoff)) + return NULL; + else + return (void *)(skb_network_header(skb) + ihl); +} + +static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl) +{ + struct icmphdr *icmph; + + icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph)); + if (icmph == NULL) + return 0; + + icmph->checksum = 0; + skb->csum = csum_partial(icmph, ipl - ihl, 0); + icmph->checksum = csum_fold(skb->csum); + + skb->ip_summed = CHECKSUM_NONE; + + return 1; +} + +static int tcf_csum_ipv4_igmp(struct sk_buff *skb, + unsigned int ihl, unsigned int ipl) +{ + struct igmphdr *igmph; + + igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph)); + if (igmph == NULL) + return 0; + + igmph->csum = 0; + skb->csum = csum_partial(igmph, ipl - ihl, 0); + igmph->csum = csum_fold(skb->csum); + + skb->ip_summed = CHECKSUM_NONE; + + return 1; +} + +static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl) +{ + struct icmp6hdr *icmp6h; + const struct ipv6hdr *ip6h; + + icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h)); + if (icmp6h == NULL) + return 0; + + ip6h = ipv6_hdr(skb); + icmp6h->icmp6_cksum = 0; + skb->csum = csum_partial(icmp6h, ipl - ihl, 0); + icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, + ipl - ihl, IPPROTO_ICMPV6, + skb->csum); + + skb->ip_summed = CHECKSUM_NONE; + + return 1; +} + +static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl) +{ + struct tcphdr *tcph; + const struct iphdr *iph; + + if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4) + return 1; + + tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); + if (tcph == NULL) + return 0; + + iph = ip_hdr(skb); + tcph->check = 0; + skb->csum = csum_partial(tcph, ipl - ihl, 0); + tcph->check = tcp_v4_check(ipl - ihl, + iph->saddr, iph->daddr, skb->csum); + + skb->ip_summed = CHECKSUM_NONE; + + return 1; +} + +static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl) +{ + struct tcphdr *tcph; + const struct ipv6hdr *ip6h; + + if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6) + return 1; + + tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph)); + if (tcph == NULL) + return 0; + + ip6h = ipv6_hdr(skb); + tcph->check = 0; + skb->csum = csum_partial(tcph, ipl - ihl, 0); + tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, + ipl - ihl, IPPROTO_TCP, + skb->csum); + + skb->ip_summed = CHECKSUM_NONE; + + return 1; +} + +static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl, int udplite) +{ + struct udphdr *udph; + const struct iphdr *iph; + u16 ul; + + if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP) + return 1; + + /* + * Support both UDP and UDPLITE checksum algorithms, Don't use + * udph->len to 
get the real length without any protocol check, + * UDPLITE uses udph->len for another thing, + * Use iph->tot_len, or just ipl. + */ + + udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); + if (udph == NULL) + return 0; + + iph = ip_hdr(skb); + ul = ntohs(udph->len); + + if (udplite || udph->check) { + + udph->check = 0; + + if (udplite) { + if (ul == 0) + skb->csum = csum_partial(udph, ipl - ihl, 0); + else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) + skb->csum = csum_partial(udph, ul, 0); + else + goto ignore_obscure_skb; + } else { + if (ul != ipl - ihl) + goto ignore_obscure_skb; + + skb->csum = csum_partial(udph, ul, 0); + } + + udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr, + ul, iph->protocol, + skb->csum); + + if (!udph->check) + udph->check = CSUM_MANGLED_0; + } + + skb->ip_summed = CHECKSUM_NONE; + +ignore_obscure_skb: + return 1; +} + +static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl, int udplite) +{ + struct udphdr *udph; + const struct ipv6hdr *ip6h; + u16 ul; + + if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP) + return 1; + + /* + * Support both UDP and UDPLITE checksum algorithms, Don't use + * udph->len to get the real length without any protocol check, + * UDPLITE uses udph->len for another thing, + * Use ip6h->payload_len + sizeof(*ip6h) ... , or just ipl. + */ + + udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph)); + if (udph == NULL) + return 0; + + ip6h = ipv6_hdr(skb); + ul = ntohs(udph->len); + + udph->check = 0; + + if (udplite) { + if (ul == 0) + skb->csum = csum_partial(udph, ipl - ihl, 0); + + else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl)) + skb->csum = csum_partial(udph, ul, 0); + + else + goto ignore_obscure_skb; + } else { + if (ul != ipl - ihl) + goto ignore_obscure_skb; + + skb->csum = csum_partial(udph, ul, 0); + } + + udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul, + udplite ? IPPROTO_UDPLITE : IPPROTO_UDP, + skb->csum); + + if (!udph->check) + udph->check = CSUM_MANGLED_0; + + skb->ip_summed = CHECKSUM_NONE; + +ignore_obscure_skb: + return 1; +} + +static int tcf_csum_sctp(struct sk_buff *skb, unsigned int ihl, + unsigned int ipl) +{ + struct sctphdr *sctph; + + if (skb_is_gso(skb) && skb_is_gso_sctp(skb)) + return 1; + + sctph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*sctph)); + if (!sctph) + return 0; + + sctph->checksum = sctp_compute_cksum(skb, + skb_network_offset(skb) + ihl); + skb->ip_summed = CHECKSUM_NONE; + skb->csum_not_inet = 0; + + return 1; +} + +static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags) +{ + const struct iphdr *iph; + int ntkoff; + + ntkoff = skb_network_offset(skb); + + if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff)) + goto fail; + + iph = ip_hdr(skb); + + switch (iph->frag_off & htons(IP_OFFSET) ? 
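/*
 * This selector intentionally maps non-first IPv4 fragments to protocol 0:
 * when iph->frag_off carries a non-zero fragment offset, the L4 header is
 * not present in this packet, so none of the ICMP/IGMP/TCP/UDP/UDPLITE/SCTP
 * cases can match and only the IPv4 header checksum
 * (TCA_CSUM_UPDATE_FLAG_IPV4HDR) may still be refreshed.
 */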
0 : iph->protocol) { + case IPPROTO_ICMP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) + if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4, + ntohs(iph->tot_len))) + goto fail; + break; + case IPPROTO_IGMP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP) + if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4, + ntohs(iph->tot_len))) + goto fail; + break; + case IPPROTO_TCP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) + if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4, + ntohs(iph->tot_len))) + goto fail; + break; + case IPPROTO_UDP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) + if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, + ntohs(iph->tot_len), 0)) + goto fail; + break; + case IPPROTO_UDPLITE: + if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) + if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4, + ntohs(iph->tot_len), 1)) + goto fail; + break; + case IPPROTO_SCTP: + if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && + !tcf_csum_sctp(skb, iph->ihl * 4, ntohs(iph->tot_len))) + goto fail; + break; + } + + if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) { + if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff)) + goto fail; + + ip_send_check(ip_hdr(skb)); + } + + return 1; + +fail: + return 0; +} + +static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl, + unsigned int *pl) +{ + int off, len, optlen; + unsigned char *xh = (void *)ip6xh; + + off = sizeof(*ip6xh); + len = ixhl - off; + + while (len > 1) { + switch (xh[off]) { + case IPV6_TLV_PAD1: + optlen = 1; + break; + case IPV6_TLV_JUMBO: + optlen = xh[off + 1] + 2; + if (optlen != 6 || len < 6 || (off & 3) != 2) + /* wrong jumbo option length/alignment */ + return 0; + *pl = ntohl(*(__be32 *)(xh + off + 2)); + goto done; + default: + optlen = xh[off + 1] + 2; + if (optlen > len) + /* ignore obscure options */ + goto done; + break; + } + off += optlen; + len -= optlen; + } + +done: + return 1; +} + +static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags) +{ + struct ipv6hdr *ip6h; + struct ipv6_opt_hdr *ip6xh; + unsigned int hl, ixhl; + unsigned int pl; + int ntkoff; + u8 nexthdr; + + ntkoff = skb_network_offset(skb); + + hl = sizeof(*ip6h); + + if (!pskb_may_pull(skb, hl + ntkoff)) + goto fail; + + ip6h = ipv6_hdr(skb); + + pl = ntohs(ip6h->payload_len); + nexthdr = ip6h->nexthdr; + + do { + switch (nexthdr) { + case NEXTHDR_FRAGMENT: + goto ignore_skb; + case NEXTHDR_ROUTING: + case NEXTHDR_HOP: + case NEXTHDR_DEST: + if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff)) + goto fail; + ip6xh = (void *)(skb_network_header(skb) + hl); + ixhl = ipv6_optlen(ip6xh); + if (!pskb_may_pull(skb, hl + ixhl + ntkoff)) + goto fail; + ip6xh = (void *)(skb_network_header(skb) + hl); + if ((nexthdr == NEXTHDR_HOP) && + !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl))) + goto fail; + nexthdr = ip6xh->nexthdr; + hl += ixhl; + break; + case IPPROTO_ICMPV6: + if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP) + if (!tcf_csum_ipv6_icmp(skb, + hl, pl + sizeof(*ip6h))) + goto fail; + goto done; + case IPPROTO_TCP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP) + if (!tcf_csum_ipv6_tcp(skb, + hl, pl + sizeof(*ip6h))) + goto fail; + goto done; + case IPPROTO_UDP: + if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP) + if (!tcf_csum_ipv6_udp(skb, hl, + pl + sizeof(*ip6h), 0)) + goto fail; + goto done; + case IPPROTO_UDPLITE: + if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE) + if (!tcf_csum_ipv6_udp(skb, hl, + pl + sizeof(*ip6h), 1)) + goto fail; + goto done; + case IPPROTO_SCTP: + if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) && + !tcf_csum_sctp(skb, hl, pl + sizeof(*ip6h))) + 
goto fail; + goto done; + default: + goto ignore_skb; + } + } while (pskb_may_pull(skb, hl + 1 + ntkoff)); + +done: +ignore_skb: + return 1; + +fail: + return 0; +} + +static int tcf_csum_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_csum *p = to_tcf_csum(a); + bool orig_vlan_tag_present = false; + unsigned int vlan_hdr_count = 0; + struct tcf_csum_params *params; + u32 update_flags; + __be16 protocol; + int action; + + params = rcu_dereference_bh(p->params); + + tcf_lastuse_update(&p->tcf_tm); + tcf_action_update_bstats(&p->common, skb); + + action = READ_ONCE(p->tcf_action); + if (unlikely(action == TC_ACT_SHOT)) + goto drop; + + update_flags = params->update_flags; + protocol = skb_protocol(skb, false); +again: + switch (protocol) { + case cpu_to_be16(ETH_P_IP): + if (!tcf_csum_ipv4(skb, update_flags)) + goto drop; + break; + case cpu_to_be16(ETH_P_IPV6): + if (!tcf_csum_ipv6(skb, update_flags)) + goto drop; + break; + case cpu_to_be16(ETH_P_8021AD): + fallthrough; + case cpu_to_be16(ETH_P_8021Q): + if (skb_vlan_tag_present(skb) && !orig_vlan_tag_present) { + protocol = skb->protocol; + orig_vlan_tag_present = true; + } else { + struct vlan_hdr *vlan = (struct vlan_hdr *)skb->data; + + protocol = vlan->h_vlan_encapsulated_proto; + skb_pull(skb, VLAN_HLEN); + skb_reset_network_header(skb); + vlan_hdr_count++; + } + goto again; + } + +out: + /* Restore the skb for the pulled VLAN tags */ + while (vlan_hdr_count--) { + skb_push(skb, VLAN_HLEN); + skb_reset_network_header(skb); + } + + return action; + +drop: + tcf_action_inc_drop_qstats(&p->common); + action = TC_ACT_SHOT; + goto out; +} + +static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind, + int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_csum *p = to_tcf_csum(a); + struct tcf_csum_params *params; + struct tc_csum opt = { + .index = p->tcf_index, + .refcnt = refcount_read(&p->tcf_refcnt) - ref, + .bindcnt = atomic_read(&p->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&p->tcf_lock); + params = rcu_dereference_protected(p->params, + lockdep_is_held(&p->tcf_lock)); + opt.action = p->tcf_action; + opt.update_flags = params->update_flags; + + if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &p->tcf_tm); + if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD)) + goto nla_put_failure; + spin_unlock_bh(&p->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&p->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_csum_cleanup(struct tc_action *a) +{ + struct tcf_csum *p = to_tcf_csum(a); + struct tcf_csum_params *params; + + params = rcu_dereference_protected(p->params, 1); + if (params) + kfree_rcu(params, rcu); +} + +static size_t tcf_csum_get_fill_size(const struct tc_action *act) +{ + return nla_total_size(sizeof(struct tc_csum)); +} + +static int tcf_csum_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + entry->id = FLOW_ACTION_CSUM; + entry->csum_flags = tcf_csum_update_flags(act); + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + fl_action->id = FLOW_ACTION_CSUM; + } + + return 0; +} + +static struct tc_action_ops act_csum_ops = { + .kind = "csum", + .id = TCA_ID_CSUM, + .owner = THIS_MODULE, + .act = tcf_csum_act, + .dump = tcf_csum_dump, + .init = tcf_csum_init, + 
.cleanup = tcf_csum_cleanup, + .get_fill_size = tcf_csum_get_fill_size, + .offload_act_setup = tcf_csum_offload_act_setup, + .size = sizeof(struct tcf_csum), +}; + +static __net_init int csum_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_csum_ops.net_id); + + return tc_action_net_init(net, tn, &act_csum_ops); +} + +static void __net_exit csum_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_csum_ops.net_id); +} + +static struct pernet_operations csum_net_ops = { + .init = csum_init_net, + .exit_batch = csum_exit_net, + .id = &act_csum_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_DESCRIPTION("Checksum updating actions"); +MODULE_LICENSE("GPL"); + +static int __init csum_init_module(void) +{ + return tcf_register_action(&act_csum_ops, &csum_net_ops); +} + +static void __exit csum_cleanup_module(void) +{ + tcf_unregister_action(&act_csum_ops, &csum_net_ops); +} + +module_init(csum_init_module); +module_exit(csum_cleanup_module); diff --git a/net/sched/act_ct.c b/net/sched/act_ct.c new file mode 100644 index 000000000..84e15116f --- /dev/null +++ b/net/sched/act_ct.c @@ -0,0 +1,1763 @@ +// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB +/* - + * net/sched/act_ct.c Connection Tracking action + * + * Authors: Paul Blakey + * Yossi Kuperman + * Marcelo Ricardo Leitner + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct workqueue_struct *act_ct_wq; +static struct rhashtable zones_ht; +static DEFINE_MUTEX(zones_mutex); + +struct tcf_ct_flow_table { + struct rhash_head node; /* In zones tables */ + + struct rcu_work rwork; + struct nf_flowtable nf_ft; + refcount_t ref; + u16 zone; + + bool dying; +}; + +static const struct rhashtable_params zones_params = { + .head_offset = offsetof(struct tcf_ct_flow_table, node), + .key_offset = offsetof(struct tcf_ct_flow_table, zone), + .key_len = sizeof_field(struct tcf_ct_flow_table, zone), + .automatic_shrinking = true, +}; + +static struct flow_action_entry * +tcf_ct_flow_table_flow_action_get_next(struct flow_action *flow_action) +{ + int i = flow_action->num_entries++; + + return &flow_action->entries[i]; +} + +static void tcf_ct_add_mangle_action(struct flow_action *action, + enum flow_action_mangle_base htype, + u32 offset, + u32 mask, + u32 val) +{ + struct flow_action_entry *entry; + + entry = tcf_ct_flow_table_flow_action_get_next(action); + entry->id = FLOW_ACTION_MANGLE; + entry->mangle.htype = htype; + entry->mangle.mask = ~mask; + entry->mangle.offset = offset; + entry->mangle.val = val; +} + +/* The following nat helper functions check if the inverted reverse tuple + * (target) is different then the current dir tuple - meaning nat for ports + * and/or ip is needed, and add the relevant mangle actions. 
+ */ +static void +tcf_ct_flow_table_add_action_nat_ipv4(const struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple target, + struct flow_action *action) +{ + if (memcmp(&target.src.u3, &tuple->src.u3, sizeof(target.src.u3))) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP4, + offsetof(struct iphdr, saddr), + 0xFFFFFFFF, + be32_to_cpu(target.src.u3.ip)); + if (memcmp(&target.dst.u3, &tuple->dst.u3, sizeof(target.dst.u3))) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP4, + offsetof(struct iphdr, daddr), + 0xFFFFFFFF, + be32_to_cpu(target.dst.u3.ip)); +} + +static void +tcf_ct_add_ipv6_addr_mangle_action(struct flow_action *action, + union nf_inet_addr *addr, + u32 offset) +{ + int i; + + for (i = 0; i < sizeof(struct in6_addr) / sizeof(u32); i++) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_IP6, + i * sizeof(u32) + offset, + 0xFFFFFFFF, be32_to_cpu(addr->ip6[i])); +} + +static void +tcf_ct_flow_table_add_action_nat_ipv6(const struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple target, + struct flow_action *action) +{ + if (memcmp(&target.src.u3, &tuple->src.u3, sizeof(target.src.u3))) + tcf_ct_add_ipv6_addr_mangle_action(action, &target.src.u3, + offsetof(struct ipv6hdr, + saddr)); + if (memcmp(&target.dst.u3, &tuple->dst.u3, sizeof(target.dst.u3))) + tcf_ct_add_ipv6_addr_mangle_action(action, &target.dst.u3, + offsetof(struct ipv6hdr, + daddr)); +} + +static void +tcf_ct_flow_table_add_action_nat_tcp(const struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple target, + struct flow_action *action) +{ + __be16 target_src = target.src.u.tcp.port; + __be16 target_dst = target.dst.u.tcp.port; + + if (target_src != tuple->src.u.tcp.port) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_TCP, + offsetof(struct tcphdr, source), + 0xFFFF, be16_to_cpu(target_src)); + if (target_dst != tuple->dst.u.tcp.port) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_TCP, + offsetof(struct tcphdr, dest), + 0xFFFF, be16_to_cpu(target_dst)); +} + +static void +tcf_ct_flow_table_add_action_nat_udp(const struct nf_conntrack_tuple *tuple, + struct nf_conntrack_tuple target, + struct flow_action *action) +{ + __be16 target_src = target.src.u.udp.port; + __be16 target_dst = target.dst.u.udp.port; + + if (target_src != tuple->src.u.udp.port) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_UDP, + offsetof(struct udphdr, source), + 0xFFFF, be16_to_cpu(target_src)); + if (target_dst != tuple->dst.u.udp.port) + tcf_ct_add_mangle_action(action, FLOW_ACT_MANGLE_HDR_TYPE_UDP, + offsetof(struct udphdr, dest), + 0xFFFF, be16_to_cpu(target_dst)); +} + +static void tcf_ct_flow_table_add_action_meta(struct nf_conn *ct, + enum ip_conntrack_dir dir, + enum ip_conntrack_info ctinfo, + struct flow_action *action) +{ + struct nf_conn_labels *ct_labels; + struct flow_action_entry *entry; + u32 *act_ct_labels; + + entry = tcf_ct_flow_table_flow_action_get_next(action); + entry->id = FLOW_ACTION_CT_METADATA; +#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) + entry->ct_metadata.mark = READ_ONCE(ct->mark); +#endif + /* aligns with the CT reference on the SKB nf_ct_set */ + entry->ct_metadata.cookie = (unsigned long)ct | ctinfo; + entry->ct_metadata.orig_dir = dir == IP_CT_DIR_ORIGINAL; + + act_ct_labels = entry->ct_metadata.labels; + ct_labels = nf_ct_labels_find(ct); + if (ct_labels) + memcpy(act_ct_labels, ct_labels->bits, NF_CT_LABELS_MAX_SIZE); + else + memset(act_ct_labels, 0, NF_CT_LABELS_MAX_SIZE); +} + +static int 
tcf_ct_flow_table_add_action_nat(struct net *net, + struct nf_conn *ct, + enum ip_conntrack_dir dir, + struct flow_action *action) +{ + const struct nf_conntrack_tuple *tuple = &ct->tuplehash[dir].tuple; + struct nf_conntrack_tuple target; + + if (!(ct->status & IPS_NAT_MASK)) + return 0; + + nf_ct_invert_tuple(&target, &ct->tuplehash[!dir].tuple); + + switch (tuple->src.l3num) { + case NFPROTO_IPV4: + tcf_ct_flow_table_add_action_nat_ipv4(tuple, target, + action); + break; + case NFPROTO_IPV6: + tcf_ct_flow_table_add_action_nat_ipv6(tuple, target, + action); + break; + default: + return -EOPNOTSUPP; + } + + switch (nf_ct_protonum(ct)) { + case IPPROTO_TCP: + tcf_ct_flow_table_add_action_nat_tcp(tuple, target, action); + break; + case IPPROTO_UDP: + tcf_ct_flow_table_add_action_nat_udp(tuple, target, action); + break; + default: + return -EOPNOTSUPP; + } + + return 0; +} + +static int tcf_ct_flow_table_fill_actions(struct net *net, + struct flow_offload *flow, + enum flow_offload_tuple_dir tdir, + struct nf_flow_rule *flow_rule) +{ + struct flow_action *action = &flow_rule->rule->action; + int num_entries = action->num_entries; + struct nf_conn *ct = flow->ct; + enum ip_conntrack_info ctinfo; + enum ip_conntrack_dir dir; + int i, err; + + switch (tdir) { + case FLOW_OFFLOAD_DIR_ORIGINAL: + dir = IP_CT_DIR_ORIGINAL; + ctinfo = IP_CT_ESTABLISHED; + set_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags); + break; + case FLOW_OFFLOAD_DIR_REPLY: + dir = IP_CT_DIR_REPLY; + ctinfo = IP_CT_ESTABLISHED_REPLY; + break; + default: + return -EOPNOTSUPP; + } + + err = tcf_ct_flow_table_add_action_nat(net, ct, dir, action); + if (err) + goto err_nat; + + tcf_ct_flow_table_add_action_meta(ct, dir, ctinfo, action); + return 0; + +err_nat: + /* Clear filled actions */ + for (i = num_entries; i < action->num_entries; i++) + memset(&action->entries[i], 0, sizeof(action->entries[i])); + action->num_entries = num_entries; + + return err; +} + +static bool tcf_ct_flow_is_outdated(const struct flow_offload *flow) +{ + return test_bit(IPS_SEEN_REPLY_BIT, &flow->ct->status) && + test_bit(IPS_HW_OFFLOAD_BIT, &flow->ct->status) && + !test_bit(NF_FLOW_HW_PENDING, &flow->flags) && + !test_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags); +} + +static void tcf_ct_flow_table_get_ref(struct tcf_ct_flow_table *ct_ft); + +static void tcf_ct_nf_get(struct nf_flowtable *ft) +{ + struct tcf_ct_flow_table *ct_ft = + container_of(ft, struct tcf_ct_flow_table, nf_ft); + + tcf_ct_flow_table_get_ref(ct_ft); +} + +static void tcf_ct_flow_table_put(struct tcf_ct_flow_table *ct_ft); + +static void tcf_ct_nf_put(struct nf_flowtable *ft) +{ + struct tcf_ct_flow_table *ct_ft = + container_of(ft, struct tcf_ct_flow_table, nf_ft); + + tcf_ct_flow_table_put(ct_ft); +} + +static struct nf_flowtable_type flowtable_ct = { + .gc = tcf_ct_flow_is_outdated, + .action = tcf_ct_flow_table_fill_actions, + .get = tcf_ct_nf_get, + .put = tcf_ct_nf_put, + .owner = THIS_MODULE, +}; + +static int tcf_ct_flow_table_get(struct net *net, struct tcf_ct_params *params) +{ + struct tcf_ct_flow_table *ct_ft; + int err = -ENOMEM; + + mutex_lock(&zones_mutex); + ct_ft = rhashtable_lookup_fast(&zones_ht, ¶ms->zone, zones_params); + if (ct_ft && refcount_inc_not_zero(&ct_ft->ref)) + goto out_unlock; + + ct_ft = kzalloc(sizeof(*ct_ft), GFP_KERNEL); + if (!ct_ft) + goto err_alloc; + refcount_set(&ct_ft->ref, 1); + + ct_ft->zone = params->zone; + err = rhashtable_insert_fast(&zones_ht, &ct_ft->node, zones_params); + if (err) + goto err_insert; + + ct_ft->nf_ft.type = &flowtable_ct; 
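The comment ahead of the NAT helpers above captures the whole translation: the reply-direction tuple already reflects any NAT applied to the connection, so inverting it yields the headers an original-direction packet must carry once rewritten, and every field that differs from the current direction's tuple becomes one mangle entry. Below is a minimal user-space model of that comparison, using simplified IPv4/single-port tuples and invented addresses; it illustrates the idea only and does not reuse the kernel data structures.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

struct tuple {
	uint32_t src_ip, dst_ip;
	uint16_t src_port, dst_port;
};

/* Swap the endpoints, as tuple inversion does for the address/port part. */
static struct tuple invert(struct tuple t)
{
	struct tuple r = {
		.src_ip   = t.dst_ip,   .dst_ip   = t.src_ip,
		.src_port = t.dst_port, .dst_port = t.src_port,
	};
	return r;
}

int main(void)
{
	/* Original direction 10.0.0.1:40000 -> 192.0.2.1:80, source-NATed to
	 * 198.51.100.7, so the reply runs 192.0.2.1:80 -> 198.51.100.7:40000.
	 */
	struct tuple orig   = { 0x0a000001, 0xc0000201, 40000, 80 };
	struct tuple reply  = { 0xc0000201, 0xc6336407, 80, 40000 };
	struct tuple target = invert(reply);	/* what orig packets must become */

	if (target.src_ip != orig.src_ip)
		printf("mangle IPv4 saddr -> 0x%08x\n", (unsigned int)target.src_ip);
	if (target.dst_ip != orig.dst_ip)
		printf("mangle IPv4 daddr -> 0x%08x\n", (unsigned int)target.dst_ip);
	if (target.src_port != orig.src_port)
		printf("mangle L4 source port -> %u\n", target.src_port);
	if (target.dst_port != orig.dst_port)
		printf("mangle L4 dest port -> %u\n", target.dst_port);
	if (!memcmp(&target, &orig, sizeof(orig)))
		printf("no rewrite needed\n");
	return 0;
}

With the source-NAT example above only the saddr rewrite is emitted, which matches the one mangle action the kernel helper would add for that connection.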
+ ct_ft->nf_ft.flags |= NF_FLOWTABLE_HW_OFFLOAD | + NF_FLOWTABLE_COUNTER; + err = nf_flow_table_init(&ct_ft->nf_ft); + if (err) + goto err_init; + write_pnet(&ct_ft->nf_ft.net, net); + + __module_get(THIS_MODULE); +out_unlock: + params->ct_ft = ct_ft; + params->nf_ft = &ct_ft->nf_ft; + mutex_unlock(&zones_mutex); + + return 0; + +err_init: + rhashtable_remove_fast(&zones_ht, &ct_ft->node, zones_params); +err_insert: + kfree(ct_ft); +err_alloc: + mutex_unlock(&zones_mutex); + return err; +} + +static void tcf_ct_flow_table_get_ref(struct tcf_ct_flow_table *ct_ft) +{ + refcount_inc(&ct_ft->ref); +} + +static void tcf_ct_flow_table_cleanup_work(struct work_struct *work) +{ + struct tcf_ct_flow_table *ct_ft; + struct flow_block *block; + + ct_ft = container_of(to_rcu_work(work), struct tcf_ct_flow_table, + rwork); + nf_flow_table_free(&ct_ft->nf_ft); + + block = &ct_ft->nf_ft.flow_block; + down_write(&ct_ft->nf_ft.flow_block_lock); + WARN_ON(!list_empty(&block->cb_list)); + up_write(&ct_ft->nf_ft.flow_block_lock); + kfree(ct_ft); + + module_put(THIS_MODULE); +} + +static void tcf_ct_flow_table_put(struct tcf_ct_flow_table *ct_ft) +{ + if (refcount_dec_and_test(&ct_ft->ref)) { + rhashtable_remove_fast(&zones_ht, &ct_ft->node, zones_params); + INIT_RCU_WORK(&ct_ft->rwork, tcf_ct_flow_table_cleanup_work); + queue_rcu_work(act_ct_wq, &ct_ft->rwork); + } +} + +static void tcf_ct_flow_tc_ifidx(struct flow_offload *entry, + struct nf_conn_act_ct_ext *act_ct_ext, u8 dir) +{ + entry->tuplehash[dir].tuple.xmit_type = FLOW_OFFLOAD_XMIT_TC; + entry->tuplehash[dir].tuple.tc.iifidx = act_ct_ext->ifindex[dir]; +} + +static void tcf_ct_flow_ct_ext_ifidx_update(struct flow_offload *entry) +{ + struct nf_conn_act_ct_ext *act_ct_ext; + + act_ct_ext = nf_conn_act_ct_ext_find(entry->ct); + if (act_ct_ext) { + tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_ORIGINAL); + tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_REPLY); + } +} + +static void tcf_ct_flow_table_add(struct tcf_ct_flow_table *ct_ft, + struct nf_conn *ct, + bool tcp, bool bidirectional) +{ + struct nf_conn_act_ct_ext *act_ct_ext; + struct flow_offload *entry; + int err; + + if (test_and_set_bit(IPS_OFFLOAD_BIT, &ct->status)) + return; + + entry = flow_offload_alloc(ct); + if (!entry) { + WARN_ON_ONCE(1); + goto err_alloc; + } + + if (tcp) { + ct->proto.tcp.seen[0].flags |= IP_CT_TCP_FLAG_BE_LIBERAL; + ct->proto.tcp.seen[1].flags |= IP_CT_TCP_FLAG_BE_LIBERAL; + } + if (bidirectional) + __set_bit(NF_FLOW_HW_BIDIRECTIONAL, &entry->flags); + + act_ct_ext = nf_conn_act_ct_ext_find(ct); + if (act_ct_ext) { + tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_ORIGINAL); + tcf_ct_flow_tc_ifidx(entry, act_ct_ext, FLOW_OFFLOAD_DIR_REPLY); + } + + err = flow_offload_add(&ct_ft->nf_ft, entry); + if (err) + goto err_add; + + return; + +err_add: + flow_offload_free(entry); +err_alloc: + clear_bit(IPS_OFFLOAD_BIT, &ct->status); +} + +static void tcf_ct_flow_table_process_conn(struct tcf_ct_flow_table *ct_ft, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo) +{ + bool tcp = false, bidirectional = true; + + switch (nf_ct_protonum(ct)) { + case IPPROTO_TCP: + if ((ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY) || + !test_bit(IPS_ASSURED_BIT, &ct->status) || + ct->proto.tcp.state != TCP_CONNTRACK_ESTABLISHED) + return; + + tcp = true; + break; + case IPPROTO_UDP: + if (!nf_ct_is_confirmed(ct)) + return; + if (!test_bit(IPS_ASSURED_BIT, &ct->status)) + bidirectional = false; + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + 
case IPPROTO_GRE: { + struct nf_conntrack_tuple *tuple; + + if ((ctinfo != IP_CT_ESTABLISHED && + ctinfo != IP_CT_ESTABLISHED_REPLY) || + !test_bit(IPS_ASSURED_BIT, &ct->status) || + ct->status & IPS_NAT_MASK) + return; + + tuple = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple; + /* No support for GRE v1 */ + if (tuple->src.u.gre.key || tuple->dst.u.gre.key) + return; + break; + } +#endif + default: + return; + } + + if (nf_ct_ext_exist(ct, NF_CT_EXT_HELPER) || + ct->status & IPS_SEQ_ADJUST) + return; + + tcf_ct_flow_table_add(ct_ft, ct, tcp, bidirectional); +} + +static bool +tcf_ct_flow_table_fill_tuple_ipv4(struct sk_buff *skb, + struct flow_offload_tuple *tuple, + struct tcphdr **tcph) +{ + struct flow_ports *ports; + unsigned int thoff; + struct iphdr *iph; + size_t hdrsize; + u8 ipproto; + + if (!pskb_network_may_pull(skb, sizeof(*iph))) + return false; + + iph = ip_hdr(skb); + thoff = iph->ihl * 4; + + if (ip_is_fragment(iph) || + unlikely(thoff != sizeof(struct iphdr))) + return false; + + ipproto = iph->protocol; + switch (ipproto) { + case IPPROTO_TCP: + hdrsize = sizeof(struct tcphdr); + break; + case IPPROTO_UDP: + hdrsize = sizeof(*ports); + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + hdrsize = sizeof(struct gre_base_hdr); + break; +#endif + default: + return false; + } + + if (iph->ttl <= 1) + return false; + + if (!pskb_network_may_pull(skb, thoff + hdrsize)) + return false; + + switch (ipproto) { + case IPPROTO_TCP: + *tcph = (void *)(skb_network_header(skb) + thoff); + fallthrough; + case IPPROTO_UDP: + ports = (struct flow_ports *)(skb_network_header(skb) + thoff); + tuple->src_port = ports->source; + tuple->dst_port = ports->dest; + break; + case IPPROTO_GRE: { + struct gre_base_hdr *greh; + + greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff); + if ((greh->flags & GRE_VERSION) != GRE_VERSION_0) + return false; + break; + } + } + + iph = ip_hdr(skb); + + tuple->src_v4.s_addr = iph->saddr; + tuple->dst_v4.s_addr = iph->daddr; + tuple->l3proto = AF_INET; + tuple->l4proto = ipproto; + + return true; +} + +static bool +tcf_ct_flow_table_fill_tuple_ipv6(struct sk_buff *skb, + struct flow_offload_tuple *tuple, + struct tcphdr **tcph) +{ + struct flow_ports *ports; + struct ipv6hdr *ip6h; + unsigned int thoff; + size_t hdrsize; + u8 nexthdr; + + if (!pskb_network_may_pull(skb, sizeof(*ip6h))) + return false; + + ip6h = ipv6_hdr(skb); + thoff = sizeof(*ip6h); + + nexthdr = ip6h->nexthdr; + switch (nexthdr) { + case IPPROTO_TCP: + hdrsize = sizeof(struct tcphdr); + break; + case IPPROTO_UDP: + hdrsize = sizeof(*ports); + break; +#ifdef CONFIG_NF_CT_PROTO_GRE + case IPPROTO_GRE: + hdrsize = sizeof(struct gre_base_hdr); + break; +#endif + default: + return false; + } + + if (ip6h->hop_limit <= 1) + return false; + + if (!pskb_network_may_pull(skb, thoff + hdrsize)) + return false; + + switch (nexthdr) { + case IPPROTO_TCP: + *tcph = (void *)(skb_network_header(skb) + thoff); + fallthrough; + case IPPROTO_UDP: + ports = (struct flow_ports *)(skb_network_header(skb) + thoff); + tuple->src_port = ports->source; + tuple->dst_port = ports->dest; + break; + case IPPROTO_GRE: { + struct gre_base_hdr *greh; + + greh = (struct gre_base_hdr *)(skb_network_header(skb) + thoff); + if ((greh->flags & GRE_VERSION) != GRE_VERSION_0) + return false; + break; + } + } + + ip6h = ipv6_hdr(skb); + + tuple->src_v6 = ip6h->saddr; + tuple->dst_v6 = ip6h->daddr; + tuple->l3proto = AF_INET6; + tuple->l4proto = nexthdr; + + return true; +} + +static bool 
tcf_ct_flow_table_lookup(struct tcf_ct_params *p, + struct sk_buff *skb, + u8 family) +{ + struct nf_flowtable *nf_ft = &p->ct_ft->nf_ft; + struct flow_offload_tuple_rhash *tuplehash; + struct flow_offload_tuple tuple = {}; + enum ip_conntrack_info ctinfo; + struct tcphdr *tcph = NULL; + bool force_refresh = false; + struct flow_offload *flow; + struct nf_conn *ct; + u8 dir; + + switch (family) { + case NFPROTO_IPV4: + if (!tcf_ct_flow_table_fill_tuple_ipv4(skb, &tuple, &tcph)) + return false; + break; + case NFPROTO_IPV6: + if (!tcf_ct_flow_table_fill_tuple_ipv6(skb, &tuple, &tcph)) + return false; + break; + default: + return false; + } + + tuplehash = flow_offload_lookup(nf_ft, &tuple); + if (!tuplehash) + return false; + + dir = tuplehash->tuple.dir; + flow = container_of(tuplehash, struct flow_offload, tuplehash[dir]); + ct = flow->ct; + + if (dir == FLOW_OFFLOAD_DIR_REPLY && + !test_bit(NF_FLOW_HW_BIDIRECTIONAL, &flow->flags)) { + /* Only offload reply direction after connection became + * assured. + */ + if (test_bit(IPS_ASSURED_BIT, &ct->status)) + set_bit(NF_FLOW_HW_BIDIRECTIONAL, &flow->flags); + else if (test_bit(NF_FLOW_HW_ESTABLISHED, &flow->flags)) + /* If flow_table flow has already been updated to the + * established state, then don't refresh. + */ + return false; + force_refresh = true; + } + + if (tcph && (unlikely(tcph->fin || tcph->rst))) { + flow_offload_teardown(flow); + return false; + } + + if (dir == FLOW_OFFLOAD_DIR_ORIGINAL) + ctinfo = test_bit(IPS_SEEN_REPLY_BIT, &ct->status) ? + IP_CT_ESTABLISHED : IP_CT_NEW; + else + ctinfo = IP_CT_ESTABLISHED_REPLY; + + nf_conn_act_ct_ext_fill(skb, ct, ctinfo); + tcf_ct_flow_ct_ext_ifidx_update(flow); + flow_offload_refresh(nf_ft, flow, force_refresh); + if (!test_bit(IPS_ASSURED_BIT, &ct->status)) { + /* Process this flow in SW to allow promoting to ASSURED */ + return false; + } + + nf_conntrack_get(&ct->ct_general); + nf_ct_set(skb, ct, ctinfo); + if (nf_ft->flags & NF_FLOWTABLE_COUNTER) + nf_ct_acct_update(ct, dir, skb->len); + + return true; +} + +static int tcf_ct_flow_tables_init(void) +{ + return rhashtable_init(&zones_ht, &zones_params); +} + +static void tcf_ct_flow_tables_uninit(void) +{ + rhashtable_destroy(&zones_ht); +} + +static struct tc_action_ops act_ct_ops; + +struct tc_ct_action_net { + struct tc_action_net tn; /* Must be first */ + bool labels; +}; + +/* Determine whether skb->_nfct is equal to the result of conntrack lookup. */ +static bool tcf_ct_skb_nfct_cached(struct net *net, struct sk_buff *skb, + u16 zone_id, bool force) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + return false; + if (!net_eq(net, read_pnet(&ct->ct_net))) + goto drop_ct; + if (nf_ct_zone(ct)->id != zone_id) + goto drop_ct; + + /* Force conntrack entry direction. */ + if (force && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) { + if (nf_ct_is_confirmed(ct)) + nf_ct_kill(ct); + + goto drop_ct; + } + + return true; + +drop_ct: + nf_ct_put(ct); + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + + return false; +} + +/* Trim the skb to the length specified by the IP/IPv6 header, + * removing any trailing lower-layer padding. This prepares the skb + * for higher-layer processing that assumes skb->len excludes padding + * (such as nf_ip_checksum). The caller needs to pull the skb to the + * network header, and ensure ip_hdr/ipv6_hdr points to valid data. 
+ */ +static int tcf_ct_skb_network_trim(struct sk_buff *skb, int family) +{ + unsigned int len; + + switch (family) { + case NFPROTO_IPV4: + len = ntohs(ip_hdr(skb)->tot_len); + break; + case NFPROTO_IPV6: + len = sizeof(struct ipv6hdr) + + ntohs(ipv6_hdr(skb)->payload_len); + break; + default: + len = skb->len; + } + + return pskb_trim_rcsum(skb, len); +} + +static u8 tcf_ct_skb_nf_family(struct sk_buff *skb) +{ + u8 family = NFPROTO_UNSPEC; + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + family = NFPROTO_IPV4; + break; + case htons(ETH_P_IPV6): + family = NFPROTO_IPV6; + break; + default: + break; + } + + return family; +} + +static int tcf_ct_ipv4_is_fragment(struct sk_buff *skb, bool *frag) +{ + unsigned int len; + + len = skb_network_offset(skb) + sizeof(struct iphdr); + if (unlikely(skb->len < len)) + return -EINVAL; + if (unlikely(!pskb_may_pull(skb, len))) + return -ENOMEM; + + *frag = ip_is_fragment(ip_hdr(skb)); + return 0; +} + +static int tcf_ct_ipv6_is_fragment(struct sk_buff *skb, bool *frag) +{ + unsigned int flags = 0, len, payload_ofs = 0; + unsigned short frag_off; + int nexthdr; + + len = skb_network_offset(skb) + sizeof(struct ipv6hdr); + if (unlikely(skb->len < len)) + return -EINVAL; + if (unlikely(!pskb_may_pull(skb, len))) + return -ENOMEM; + + nexthdr = ipv6_find_hdr(skb, &payload_ofs, -1, &frag_off, &flags); + if (unlikely(nexthdr < 0)) + return -EPROTO; + + *frag = flags & IP6_FH_F_FRAG; + return 0; +} + +static int tcf_ct_handle_fragments(struct net *net, struct sk_buff *skb, + u8 family, u16 zone, bool *defrag) +{ + enum ip_conntrack_info ctinfo; + struct nf_conn *ct; + int err = 0; + bool frag; + u16 mru; + + /* Previously seen (loopback)? Ignore. */ + ct = nf_ct_get(skb, &ctinfo); + if ((ct && !nf_ct_is_template(ct)) || ctinfo == IP_CT_UNTRACKED) + return 0; + + if (family == NFPROTO_IPV4) + err = tcf_ct_ipv4_is_fragment(skb, &frag); + else + err = tcf_ct_ipv6_is_fragment(skb, &frag); + if (err || !frag) + return err; + + mru = tc_skb_cb(skb)->mru; + + if (family == NFPROTO_IPV4) { + enum ip_defrag_users user = IP_DEFRAG_CONNTRACK_IN + zone; + + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + local_bh_disable(); + err = ip_defrag(net, skb, user); + local_bh_enable(); + if (err && err != -EINPROGRESS) + return err; + + if (!err) { + *defrag = true; + mru = IPCB(skb)->frag_max_size; + } + } else { /* NFPROTO_IPV6 */ +#if IS_ENABLED(CONFIG_NF_DEFRAG_IPV6) + enum ip6_defrag_users user = IP6_DEFRAG_CONNTRACK_IN + zone; + + memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm)); + err = nf_ct_frag6_gather(net, skb, user); + if (err && err != -EINPROGRESS) + goto out_free; + + if (!err) { + *defrag = true; + mru = IP6CB(skb)->frag_max_size; + } +#else + err = -EOPNOTSUPP; + goto out_free; +#endif + } + + if (err != -EINPROGRESS) + tc_skb_cb(skb)->mru = mru; + skb_clear_hash(skb); + skb->ignore_df = 1; + return err; + +out_free: + kfree_skb(skb); + return err; +} + +static void tcf_ct_params_free(struct tcf_ct_params *params) +{ + if (params->ct_ft) + tcf_ct_flow_table_put(params->ct_ft); + if (params->tmpl) + nf_ct_put(params->tmpl); + kfree(params); +} + +static void tcf_ct_params_free_rcu(struct rcu_head *head) +{ + struct tcf_ct_params *params; + + params = container_of(head, struct tcf_ct_params, rcu); + tcf_ct_params_free(params); +} + +#if IS_ENABLED(CONFIG_NF_NAT) +/* Modelled after nf_nat_ipv[46]_fn(). + * range is only used for new, uninitialized NAT state. + * Returns either NF_ACCEPT or NF_DROP. 
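The "network trim" above boils down to one length calculation per family: IPv4 packets are cut back to tot_len, IPv6 packets to the fixed 40-byte header plus payload_len, and anything the driver delivered beyond that (typically Ethernet padding on minimum-size frames) is removed before conntrack and the checksum code look at the packet. A trivial worked example with made-up header values:

#include <stdio.h>

int main(void)
{
	/* A minimum-size Ethernet frame delivers 46 bytes of L3 payload even
	 * when it carries only a 28-byte IPv4 datagram (20-byte IP header +
	 * 8-byte UDP header, no data), so 18 bytes of padding must go.
	 */
	unsigned int skb_len = 46;		/* what skb->len would show */
	unsigned int ipv4_tot_len = 28;		/* ntohs(ip_hdr(skb)->tot_len) */
	unsigned int trim_to = ipv4_tot_len;

	printf("IPv4: trim %u -> %u bytes (%u bytes of padding)\n",
	       skb_len, trim_to, skb_len - trim_to);

	/* IPv6 has no total-length field; it is the 40-byte fixed header
	 * plus payload_len.
	 */
	unsigned int ipv6_payload_len = 8;	/* ntohs(ipv6_hdr(skb)->payload_len) */
	trim_to = 40 + ipv6_payload_len;
	printf("IPv6: trim to %u bytes\n", trim_to);
	return 0;
}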
+ */ +static int ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + const struct nf_nat_range2 *range, + enum nf_nat_manip_type maniptype) +{ + __be16 proto = skb_protocol(skb, true); + int hooknum, err = NF_ACCEPT; + + /* See HOOK2MANIP(). */ + if (maniptype == NF_NAT_MANIP_SRC) + hooknum = NF_INET_LOCAL_IN; /* Source NAT */ + else + hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */ + + switch (ctinfo) { + case IP_CT_RELATED: + case IP_CT_RELATED_REPLY: + if (proto == htons(ETH_P_IP) && + ip_hdr(skb)->protocol == IPPROTO_ICMP) { + if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo, + hooknum)) + err = NF_DROP; + goto out; + } else if (IS_ENABLED(CONFIG_IPV6) && proto == htons(ETH_P_IPV6)) { + __be16 frag_off; + u8 nexthdr = ipv6_hdr(skb)->nexthdr; + int hdrlen = ipv6_skip_exthdr(skb, + sizeof(struct ipv6hdr), + &nexthdr, &frag_off); + + if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) { + if (!nf_nat_icmpv6_reply_translation(skb, ct, + ctinfo, + hooknum, + hdrlen)) + err = NF_DROP; + goto out; + } + } + /* Non-ICMP, fall thru to initialize if needed. */ + fallthrough; + case IP_CT_NEW: + /* Seen it before? This can happen for loopback, retrans, + * or local packets. + */ + if (!nf_nat_initialized(ct, maniptype)) { + /* Initialize according to the NAT action. */ + err = (range && range->flags & NF_NAT_RANGE_MAP_IPS) + /* Action is set up to establish a new + * mapping. + */ + ? nf_nat_setup_info(ct, range, maniptype) + : nf_nat_alloc_null_binding(ct, hooknum); + if (err != NF_ACCEPT) + goto out; + } + break; + + case IP_CT_ESTABLISHED: + case IP_CT_ESTABLISHED_REPLY: + break; + + default: + err = NF_DROP; + goto out; + } + + err = nf_nat_packet(ct, ctinfo, hooknum, skb); + if (err == NF_ACCEPT) { + if (maniptype == NF_NAT_MANIP_SRC) + tc_skb_cb(skb)->post_ct_snat = 1; + if (maniptype == NF_NAT_MANIP_DST) + tc_skb_cb(skb)->post_ct_dnat = 1; + } +out: + return err; +} +#endif /* CONFIG_NF_NAT */ + +static void tcf_ct_act_set_mark(struct nf_conn *ct, u32 mark, u32 mask) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) + u32 new_mark; + + if (!mask) + return; + + new_mark = mark | (READ_ONCE(ct->mark) & ~(mask)); + if (READ_ONCE(ct->mark) != new_mark) { + WRITE_ONCE(ct->mark, new_mark); + if (nf_ct_is_confirmed(ct)) + nf_conntrack_event_cache(IPCT_MARK, ct); + } +#endif +} + +static void tcf_ct_act_set_labels(struct nf_conn *ct, + u32 *labels, + u32 *labels_m) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) + size_t labels_sz = sizeof_field(struct tcf_ct_params, labels); + + if (!memchr_inv(labels_m, 0, labels_sz)) + return; + + nf_connlabels_replace(ct, labels, labels_m, 4); +#endif +} + +static int tcf_ct_act_nat(struct sk_buff *skb, + struct nf_conn *ct, + enum ip_conntrack_info ctinfo, + int ct_action, + struct nf_nat_range2 *range, + bool commit) +{ +#if IS_ENABLED(CONFIG_NF_NAT) + int err; + enum nf_nat_manip_type maniptype; + + if (!(ct_action & TCA_CT_ACT_NAT)) + return NF_ACCEPT; + + /* Add NAT extension if not confirmed yet. */ + if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct)) + return NF_DROP; /* Can't NAT. */ + + if (ctinfo != IP_CT_NEW && (ct->status & IPS_NAT_MASK) && + (ctinfo != IP_CT_RELATED || commit)) { + /* NAT an established or related connection like before. */ + if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY) + /* This is the REPLY direction for a connection + * for which NAT was applied in the forward + * direction. Do the reverse NAT. + */ + maniptype = ct->status & IPS_SRC_NAT + ? 
NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC; + else + maniptype = ct->status & IPS_SRC_NAT + ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST; + } else if (ct_action & TCA_CT_ACT_NAT_SRC) { + maniptype = NF_NAT_MANIP_SRC; + } else if (ct_action & TCA_CT_ACT_NAT_DST) { + maniptype = NF_NAT_MANIP_DST; + } else { + return NF_ACCEPT; + } + + err = ct_nat_execute(skb, ct, ctinfo, range, maniptype); + if (err == NF_ACCEPT && ct->status & IPS_DST_NAT) { + if (ct->status & IPS_SRC_NAT) { + if (maniptype == NF_NAT_MANIP_SRC) + maniptype = NF_NAT_MANIP_DST; + else + maniptype = NF_NAT_MANIP_SRC; + + err = ct_nat_execute(skb, ct, ctinfo, range, + maniptype); + } else if (CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL) { + err = ct_nat_execute(skb, ct, ctinfo, NULL, + NF_NAT_MANIP_SRC); + } + } + return err; +#else + return NF_ACCEPT; +#endif +} + +static int tcf_ct_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct net *net = dev_net(skb->dev); + bool cached, commit, clear, force; + enum ip_conntrack_info ctinfo; + struct tcf_ct *c = to_ct(a); + struct nf_conn *tmpl = NULL; + struct nf_hook_state state; + int nh_ofs, err, retval; + struct tcf_ct_params *p; + bool skip_add = false; + bool defrag = false; + struct nf_conn *ct; + u8 family; + + p = rcu_dereference_bh(c->params); + + retval = READ_ONCE(c->tcf_action); + commit = p->ct_action & TCA_CT_ACT_COMMIT; + clear = p->ct_action & TCA_CT_ACT_CLEAR; + force = p->ct_action & TCA_CT_ACT_FORCE; + tmpl = p->tmpl; + + tcf_lastuse_update(&c->tcf_tm); + tcf_action_update_bstats(&c->common, skb); + + if (clear) { + tc_skb_cb(skb)->post_ct = false; + ct = nf_ct_get(skb, &ctinfo); + if (ct) { + nf_ct_put(ct); + nf_ct_set(skb, NULL, IP_CT_UNTRACKED); + } + + goto out_clear; + } + + family = tcf_ct_skb_nf_family(skb); + if (family == NFPROTO_UNSPEC) + goto drop; + + /* The conntrack module expects to be working at L3. + * We also try to pull the IPv4/6 header to linear area + */ + nh_ofs = skb_network_offset(skb); + skb_pull_rcsum(skb, nh_ofs); + err = tcf_ct_handle_fragments(net, skb, family, p->zone, &defrag); + if (err) + goto out_frag; + + err = tcf_ct_skb_network_trim(skb, family); + if (err) + goto drop; + + /* If we are recirculating packets to match on ct fields and + * committing with a separate ct action, then we don't need to + * actually run the packet through conntrack twice unless it's for a + * different zone. + */ + cached = tcf_ct_skb_nfct_cached(net, skb, p->zone, force); + if (!cached) { + if (tcf_ct_flow_table_lookup(p, skb, family)) { + skip_add = true; + goto do_nat; + } + + /* Associate skb with specified zone. */ + if (tmpl) { + nf_conntrack_put(skb_nfct(skb)); + nf_conntrack_get(&tmpl->ct_general); + nf_ct_set(skb, tmpl, IP_CT_NEW); + } + + state.hook = NF_INET_PRE_ROUTING; + state.net = net; + state.pf = family; + err = nf_conntrack_in(skb, &state); + if (err != NF_ACCEPT) + goto out_push; + } + +do_nat: + ct = nf_ct_get(skb, &ctinfo); + if (!ct) + goto out_push; + nf_ct_deliver_cached_events(ct); + nf_conn_act_ct_ext_fill(skb, ct, ctinfo); + + err = tcf_ct_act_nat(skb, ct, ctinfo, p->ct_action, &p->range, commit); + if (err != NF_ACCEPT) + goto drop; + + if (commit) { + tcf_ct_act_set_mark(ct, p->mark, p->mark_mask); + tcf_ct_act_set_labels(ct, p->labels, p->labels_mask); + + if (!nf_ct_is_confirmed(ct)) + nf_conn_act_ct_ext_add(skb, ct, ctinfo); + + /* This will take care of sending queued events + * even if the connection is already confirmed. 
+ */ + if (nf_conntrack_confirm(skb) != NF_ACCEPT) + goto drop; + } + + if (!skip_add) + tcf_ct_flow_table_process_conn(p->ct_ft, ct, ctinfo); + +out_push: + skb_push_rcsum(skb, nh_ofs); + + tc_skb_cb(skb)->post_ct = true; + tc_skb_cb(skb)->zone = p->zone; +out_clear: + if (defrag) + qdisc_skb_cb(skb)->pkt_len = skb->len; + return retval; + +out_frag: + if (err != -EINPROGRESS) + tcf_action_inc_drop_qstats(&c->common); + return TC_ACT_CONSUMED; + +drop: + tcf_action_inc_drop_qstats(&c->common); + return TC_ACT_SHOT; +} + +static const struct nla_policy ct_policy[TCA_CT_MAX + 1] = { + [TCA_CT_ACTION] = { .type = NLA_U16 }, + [TCA_CT_PARMS] = NLA_POLICY_EXACT_LEN(sizeof(struct tc_ct)), + [TCA_CT_ZONE] = { .type = NLA_U16 }, + [TCA_CT_MARK] = { .type = NLA_U32 }, + [TCA_CT_MARK_MASK] = { .type = NLA_U32 }, + [TCA_CT_LABELS] = { .type = NLA_BINARY, + .len = 128 / BITS_PER_BYTE }, + [TCA_CT_LABELS_MASK] = { .type = NLA_BINARY, + .len = 128 / BITS_PER_BYTE }, + [TCA_CT_NAT_IPV4_MIN] = { .type = NLA_U32 }, + [TCA_CT_NAT_IPV4_MAX] = { .type = NLA_U32 }, + [TCA_CT_NAT_IPV6_MIN] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [TCA_CT_NAT_IPV6_MAX] = NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), + [TCA_CT_NAT_PORT_MIN] = { .type = NLA_U16 }, + [TCA_CT_NAT_PORT_MAX] = { .type = NLA_U16 }, +}; + +static int tcf_ct_fill_params_nat(struct tcf_ct_params *p, + struct tc_ct *parm, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct nf_nat_range2 *range; + + if (!(p->ct_action & TCA_CT_ACT_NAT)) + return 0; + + if (!IS_ENABLED(CONFIG_NF_NAT)) { + NL_SET_ERR_MSG_MOD(extack, "Netfilter nat isn't enabled in kernel"); + return -EOPNOTSUPP; + } + + if (!(p->ct_action & (TCA_CT_ACT_NAT_SRC | TCA_CT_ACT_NAT_DST))) + return 0; + + if ((p->ct_action & TCA_CT_ACT_NAT_SRC) && + (p->ct_action & TCA_CT_ACT_NAT_DST)) { + NL_SET_ERR_MSG_MOD(extack, "dnat and snat can't be enabled at the same time"); + return -EOPNOTSUPP; + } + + range = &p->range; + if (tb[TCA_CT_NAT_IPV4_MIN]) { + struct nlattr *max_attr = tb[TCA_CT_NAT_IPV4_MAX]; + + p->ipv4_range = true; + range->flags |= NF_NAT_RANGE_MAP_IPS; + range->min_addr.ip = + nla_get_in_addr(tb[TCA_CT_NAT_IPV4_MIN]); + + range->max_addr.ip = max_attr ? + nla_get_in_addr(max_attr) : + range->min_addr.ip; + } else if (tb[TCA_CT_NAT_IPV6_MIN]) { + struct nlattr *max_attr = tb[TCA_CT_NAT_IPV6_MAX]; + + p->ipv4_range = false; + range->flags |= NF_NAT_RANGE_MAP_IPS; + range->min_addr.in6 = + nla_get_in6_addr(tb[TCA_CT_NAT_IPV6_MIN]); + + range->max_addr.in6 = max_attr ? + nla_get_in6_addr(max_attr) : + range->min_addr.in6; + } + + if (tb[TCA_CT_NAT_PORT_MIN]) { + range->flags |= NF_NAT_RANGE_PROTO_SPECIFIED; + range->min_proto.all = nla_get_be16(tb[TCA_CT_NAT_PORT_MIN]); + + range->max_proto.all = tb[TCA_CT_NAT_PORT_MAX] ? 
+ nla_get_be16(tb[TCA_CT_NAT_PORT_MAX]) : + range->min_proto.all; + } + + return 0; +} + +static void tcf_ct_set_key_val(struct nlattr **tb, + void *val, int val_type, + void *mask, int mask_type, + int len) +{ + if (!tb[val_type]) + return; + nla_memcpy(val, tb[val_type], len); + + if (!mask) + return; + + if (mask_type == TCA_CT_UNSPEC || !tb[mask_type]) + memset(mask, 0xff, len); + else + nla_memcpy(mask, tb[mask_type], len); +} + +static int tcf_ct_fill_params(struct net *net, + struct tcf_ct_params *p, + struct tc_ct *parm, + struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id); + struct nf_conntrack_zone zone; + struct nf_conn *tmpl; + int err; + + p->zone = NF_CT_DEFAULT_ZONE_ID; + + tcf_ct_set_key_val(tb, + &p->ct_action, TCA_CT_ACTION, + NULL, TCA_CT_UNSPEC, + sizeof(p->ct_action)); + + if (p->ct_action & TCA_CT_ACT_CLEAR) + return 0; + + err = tcf_ct_fill_params_nat(p, parm, tb, extack); + if (err) + return err; + + if (tb[TCA_CT_MARK]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)) { + NL_SET_ERR_MSG_MOD(extack, "Conntrack mark isn't enabled."); + return -EOPNOTSUPP; + } + tcf_ct_set_key_val(tb, + &p->mark, TCA_CT_MARK, + &p->mark_mask, TCA_CT_MARK_MASK, + sizeof(p->mark)); + } + + if (tb[TCA_CT_LABELS]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS)) { + NL_SET_ERR_MSG_MOD(extack, "Conntrack labels isn't enabled."); + return -EOPNOTSUPP; + } + + if (!tn->labels) { + NL_SET_ERR_MSG_MOD(extack, "Failed to set connlabel length"); + return -EOPNOTSUPP; + } + tcf_ct_set_key_val(tb, + p->labels, TCA_CT_LABELS, + p->labels_mask, TCA_CT_LABELS_MASK, + sizeof(p->labels)); + } + + if (tb[TCA_CT_ZONE]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES)) { + NL_SET_ERR_MSG_MOD(extack, "Conntrack zones isn't enabled."); + return -EOPNOTSUPP; + } + + tcf_ct_set_key_val(tb, + &p->zone, TCA_CT_ZONE, + NULL, TCA_CT_UNSPEC, + sizeof(p->zone)); + } + + nf_ct_zone_init(&zone, p->zone, NF_CT_DEFAULT_ZONE_DIR, 0); + tmpl = nf_ct_tmpl_alloc(net, &zone, GFP_KERNEL); + if (!tmpl) { + NL_SET_ERR_MSG_MOD(extack, "Failed to allocate conntrack template"); + return -ENOMEM; + } + __set_bit(IPS_CONFIRMED_BIT, &tmpl->status); + p->tmpl = tmpl; + + return 0; +} + +static int tcf_ct_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_ct_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct tcf_ct_params *params = NULL; + struct nlattr *tb[TCA_CT_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_ct *parm; + struct tcf_ct *c; + int err, res = 0; + u32 index; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "Ct requires attributes to be passed"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_CT_MAX, nla, ct_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_CT_PARMS]) { + NL_SET_ERR_MSG_MOD(extack, "Missing required ct parameters"); + return -EINVAL; + } + parm = nla_data(tb[TCA_CT_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + + if (!err) { + err = tcf_idr_create_from_flags(tn, index, est, a, + &act_ct_ops, bind, flags); + if (err) { + tcf_idr_cleanup(tn, index); + return err; + } + res = ACT_P_CREATED; + } else { + if (bind) + return 0; + + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } + err = tcf_action_check_ctrlact(parm->action, tp, 
&goto_ch, extack); + if (err < 0) + goto cleanup; + + c = to_ct(*a); + + params = kzalloc(sizeof(*params), GFP_KERNEL); + if (unlikely(!params)) { + err = -ENOMEM; + goto cleanup; + } + + err = tcf_ct_fill_params(net, params, parm, tb, extack); + if (err) + goto cleanup; + + err = tcf_ct_flow_table_get(net, params); + if (err) + goto cleanup; + + spin_lock_bh(&c->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + params = rcu_replace_pointer(c->params, params, + lockdep_is_held(&c->tcf_lock)); + spin_unlock_bh(&c->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (params) + call_rcu(¶ms->rcu, tcf_ct_params_free_rcu); + + return res; + +cleanup: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (params) + tcf_ct_params_free(params); + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_ct_cleanup(struct tc_action *a) +{ + struct tcf_ct_params *params; + struct tcf_ct *c = to_ct(a); + + params = rcu_dereference_protected(c->params, 1); + if (params) + call_rcu(¶ms->rcu, tcf_ct_params_free_rcu); +} + +static int tcf_ct_dump_key_val(struct sk_buff *skb, + void *val, int val_type, + void *mask, int mask_type, + int len) +{ + int err; + + if (mask && !memchr_inv(mask, 0, len)) + return 0; + + err = nla_put(skb, val_type, len, val); + if (err) + return err; + + if (mask_type != TCA_CT_UNSPEC) { + err = nla_put(skb, mask_type, len, mask); + if (err) + return err; + } + + return 0; +} + +static int tcf_ct_dump_nat(struct sk_buff *skb, struct tcf_ct_params *p) +{ + struct nf_nat_range2 *range = &p->range; + + if (!(p->ct_action & TCA_CT_ACT_NAT)) + return 0; + + if (!(p->ct_action & (TCA_CT_ACT_NAT_SRC | TCA_CT_ACT_NAT_DST))) + return 0; + + if (range->flags & NF_NAT_RANGE_MAP_IPS) { + if (p->ipv4_range) { + if (nla_put_in_addr(skb, TCA_CT_NAT_IPV4_MIN, + range->min_addr.ip)) + return -1; + if (nla_put_in_addr(skb, TCA_CT_NAT_IPV4_MAX, + range->max_addr.ip)) + return -1; + } else { + if (nla_put_in6_addr(skb, TCA_CT_NAT_IPV6_MIN, + &range->min_addr.in6)) + return -1; + if (nla_put_in6_addr(skb, TCA_CT_NAT_IPV6_MAX, + &range->max_addr.in6)) + return -1; + } + } + + if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) { + if (nla_put_be16(skb, TCA_CT_NAT_PORT_MIN, + range->min_proto.all)) + return -1; + if (nla_put_be16(skb, TCA_CT_NAT_PORT_MAX, + range->max_proto.all)) + return -1; + } + + return 0; +} + +static inline int tcf_ct_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_ct *c = to_ct(a); + struct tcf_ct_params *p; + + struct tc_ct opt = { + .index = c->tcf_index, + .refcnt = refcount_read(&c->tcf_refcnt) - ref, + .bindcnt = atomic_read(&c->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&c->tcf_lock); + p = rcu_dereference_protected(c->params, + lockdep_is_held(&c->tcf_lock)); + opt.action = c->tcf_action; + + if (tcf_ct_dump_key_val(skb, + &p->ct_action, TCA_CT_ACTION, + NULL, TCA_CT_UNSPEC, + sizeof(p->ct_action))) + goto nla_put_failure; + + if (p->ct_action & TCA_CT_ACT_CLEAR) + goto skip_dump; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && + tcf_ct_dump_key_val(skb, + &p->mark, TCA_CT_MARK, + &p->mark_mask, TCA_CT_MARK_MASK, + sizeof(p->mark))) + goto nla_put_failure; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + tcf_ct_dump_key_val(skb, + p->labels, TCA_CT_LABELS, + p->labels_mask, TCA_CT_LABELS_MASK, + sizeof(p->labels))) + goto nla_put_failure; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) && + tcf_ct_dump_key_val(skb, + &p->zone, 
TCA_CT_ZONE, + NULL, TCA_CT_UNSPEC, + sizeof(p->zone))) + goto nla_put_failure; + + if (tcf_ct_dump_nat(skb, p)) + goto nla_put_failure; + +skip_dump: + if (nla_put(skb, TCA_CT_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &c->tcf_tm); + if (nla_put_64bit(skb, TCA_CT_TM, sizeof(t), &t, TCA_CT_PAD)) + goto nla_put_failure; + spin_unlock_bh(&c->tcf_lock); + + return skb->len; +nla_put_failure: + spin_unlock_bh(&c->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_ct *c = to_ct(a); + + tcf_action_update_stats(a, bytes, packets, drops, hw); + c->tcf_tm.lastuse = max_t(u64, c->tcf_tm.lastuse, lastuse); +} + +static int tcf_ct_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + entry->id = FLOW_ACTION_CT; + entry->ct.action = tcf_ct_action(act); + entry->ct.zone = tcf_ct_zone(act); + entry->ct.flow_table = tcf_ct_ft(act); + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + fl_action->id = FLOW_ACTION_CT; + } + + return 0; +} + +static struct tc_action_ops act_ct_ops = { + .kind = "ct", + .id = TCA_ID_CT, + .owner = THIS_MODULE, + .act = tcf_ct_act, + .dump = tcf_ct_dump, + .init = tcf_ct_init, + .cleanup = tcf_ct_cleanup, + .stats_update = tcf_stats_update, + .offload_act_setup = tcf_ct_offload_act_setup, + .size = sizeof(struct tcf_ct), +}; + +static __net_init int ct_init_net(struct net *net) +{ + unsigned int n_bits = sizeof_field(struct tcf_ct_params, labels) * 8; + struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id); + + if (nf_connlabels_get(net, n_bits - 1)) { + tn->labels = false; + pr_err("act_ct: Failed to set connlabels length"); + } else { + tn->labels = true; + } + + return tc_action_net_init(net, &tn->tn, &act_ct_ops); +} + +static void __net_exit ct_exit_net(struct list_head *net_list) +{ + struct net *net; + + rtnl_lock(); + list_for_each_entry(net, net_list, exit_list) { + struct tc_ct_action_net *tn = net_generic(net, act_ct_ops.net_id); + + if (tn->labels) + nf_connlabels_put(net); + } + rtnl_unlock(); + + tc_action_net_exit(net_list, act_ct_ops.net_id); +} + +static struct pernet_operations ct_net_ops = { + .init = ct_init_net, + .exit_batch = ct_exit_net, + .id = &act_ct_ops.net_id, + .size = sizeof(struct tc_ct_action_net), +}; + +static int __init ct_init_module(void) +{ + int err; + + act_ct_wq = alloc_ordered_workqueue("act_ct_workqueue", 0); + if (!act_ct_wq) + return -ENOMEM; + + err = tcf_ct_flow_tables_init(); + if (err) + goto err_tbl_init; + + err = tcf_register_action(&act_ct_ops, &ct_net_ops); + if (err) + goto err_register; + + static_branch_inc(&tcf_frag_xmit_count); + + return 0; + +err_register: + tcf_ct_flow_tables_uninit(); +err_tbl_init: + destroy_workqueue(act_ct_wq); + return err; +} + +static void __exit ct_cleanup_module(void) +{ + static_branch_dec(&tcf_frag_xmit_count); + tcf_unregister_action(&act_ct_ops, &ct_net_ops); + tcf_ct_flow_tables_uninit(); + destroy_workqueue(act_ct_wq); +} + +module_init(ct_init_module); +module_exit(ct_cleanup_module); +MODULE_AUTHOR("Paul Blakey "); +MODULE_AUTHOR("Yossi Kuperman "); +MODULE_AUTHOR("Marcelo Ricardo Leitner "); +MODULE_DESCRIPTION("Connection tracking action"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/act_ctinfo.c b/net/sched/act_ctinfo.c new file mode 100644 
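ct_init_module() above is the usual ordered bring-up with reverse-order unwinding: the workqueue is created first, then the zones hashtable, then the action is registered, and each failure label undoes exactly the steps that already succeeded (err_register tears the flow tables back down, err_tbl_init destroys the workqueue). A stripped-down, stand-alone illustration of the same goto pattern, with invented step names:

#include <stdio.h>

/* Three made-up setup steps; each returns 0 on success. */
static int setup_wq(void)    { puts("workqueue up");   return 0; }
static int setup_table(void) { puts("table up");       return 0; }
static int do_register(void) { puts("register fails"); return -1; } /* forced error */

static void destroy_table(void) { puts("table down");     }
static void destroy_wq(void)    { puts("workqueue down"); }

int main(void)
{
	int err;

	err = setup_wq();
	if (err)
		return err;

	err = setup_table();
	if (err)
		goto err_table;

	err = do_register();
	if (err)
		goto err_register;

	return 0;

err_register:
	destroy_table();	/* unwind in reverse order of setup */
err_table:
	destroy_wq();
	return err;
}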
index 000000000..7275ad869 --- /dev/null +++ b/net/sched/act_ctinfo.c @@ -0,0 +1,398 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* net/sched/act_ctinfo.c netfilter ctinfo connmark actions + * + * Copyright (c) 2019 Kevin Darbyshire-Bryant + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static struct tc_action_ops act_ctinfo_ops; + +static void tcf_ctinfo_dscp_set(struct nf_conn *ct, struct tcf_ctinfo *ca, + struct tcf_ctinfo_params *cp, + struct sk_buff *skb, int wlen, int proto) +{ + u8 dscp, newdscp; + + newdscp = (((READ_ONCE(ct->mark) & cp->dscpmask) >> cp->dscpmaskshift) << 2) & + ~INET_ECN_MASK; + + switch (proto) { + case NFPROTO_IPV4: + dscp = ipv4_get_dsfield(ip_hdr(skb)) & ~INET_ECN_MASK; + if (dscp != newdscp) { + if (likely(!skb_try_make_writable(skb, wlen))) { + ipv4_change_dsfield(ip_hdr(skb), + INET_ECN_MASK, + newdscp); + ca->stats_dscp_set++; + } else { + ca->stats_dscp_error++; + } + } + break; + case NFPROTO_IPV6: + dscp = ipv6_get_dsfield(ipv6_hdr(skb)) & ~INET_ECN_MASK; + if (dscp != newdscp) { + if (likely(!skb_try_make_writable(skb, wlen))) { + ipv6_change_dsfield(ipv6_hdr(skb), + INET_ECN_MASK, + newdscp); + ca->stats_dscp_set++; + } else { + ca->stats_dscp_error++; + } + } + break; + default: + break; + } +} + +static void tcf_ctinfo_cpmark_set(struct nf_conn *ct, struct tcf_ctinfo *ca, + struct tcf_ctinfo_params *cp, + struct sk_buff *skb) +{ + ca->stats_cpmark_set++; + skb->mark = READ_ONCE(ct->mark) & cp->cpmarkmask; +} + +static int tcf_ctinfo_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + const struct nf_conntrack_tuple_hash *thash = NULL; + struct tcf_ctinfo *ca = to_ctinfo(a); + struct nf_conntrack_tuple tuple; + struct nf_conntrack_zone zone; + enum ip_conntrack_info ctinfo; + struct tcf_ctinfo_params *cp; + struct nf_conn *ct; + int proto, wlen; + int action; + + cp = rcu_dereference_bh(ca->params); + + tcf_lastuse_update(&ca->tcf_tm); + tcf_action_update_bstats(&ca->common, skb); + action = READ_ONCE(ca->tcf_action); + + wlen = skb_network_offset(skb); + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + wlen += sizeof(struct iphdr); + if (!pskb_may_pull(skb, wlen)) + goto out; + + proto = NFPROTO_IPV4; + break; + case htons(ETH_P_IPV6): + wlen += sizeof(struct ipv6hdr); + if (!pskb_may_pull(skb, wlen)) + goto out; + + proto = NFPROTO_IPV6; + break; + default: + goto out; + } + + ct = nf_ct_get(skb, &ctinfo); + if (!ct) { /* look harder, usually ingress */ + if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), + proto, cp->net, &tuple)) + goto out; + zone.id = cp->zone; + zone.dir = NF_CT_DEFAULT_ZONE_DIR; + + thash = nf_conntrack_find_get(cp->net, &zone, &tuple); + if (!thash) + goto out; + + ct = nf_ct_tuplehash_to_ctrack(thash); + } + + if (cp->mode & CTINFO_MODE_DSCP) + if (!cp->dscpstatemask || (READ_ONCE(ct->mark) & cp->dscpstatemask)) + tcf_ctinfo_dscp_set(ct, ca, cp, skb, wlen, proto); + + if (cp->mode & CTINFO_MODE_CPMARK) + tcf_ctinfo_cpmark_set(ct, ca, cp, skb); + + if (thash) + nf_ct_put(ct); +out: + return action; +} + +static const struct nla_policy ctinfo_policy[TCA_CTINFO_MAX + 1] = { + [TCA_CTINFO_ACT] = + NLA_POLICY_EXACT_LEN(sizeof(struct tc_ctinfo)), + [TCA_CTINFO_ZONE] = { .type = NLA_U16 }, + [TCA_CTINFO_PARMS_DSCP_MASK] = { .type = NLA_U32 }, + [TCA_CTINFO_PARMS_DSCP_STATEMASK] = { .type = NLA_U32 }, + [TCA_CTINFO_PARMS_CPMARK_MASK] = { .type = 
NLA_U32 }, +}; + +static int tcf_ctinfo_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_ctinfo_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + u32 dscpmask = 0, dscpstatemask, index; + struct nlattr *tb[TCA_CTINFO_MAX + 1]; + struct tcf_ctinfo_params *cp_new; + struct tcf_chain *goto_ch = NULL; + struct tc_ctinfo *actparm; + struct tcf_ctinfo *ci; + u8 dscpmaskshift; + int ret = 0, err; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "ctinfo requires attributes to be passed"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_CTINFO_MAX, nla, ctinfo_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_CTINFO_ACT]) { + NL_SET_ERR_MSG_MOD(extack, + "Missing required TCA_CTINFO_ACT attribute"); + return -EINVAL; + } + actparm = nla_data(tb[TCA_CTINFO_ACT]); + + /* do some basic validation here before dynamically allocating things */ + /* that we would otherwise have to clean up. */ + if (tb[TCA_CTINFO_PARMS_DSCP_MASK]) { + dscpmask = nla_get_u32(tb[TCA_CTINFO_PARMS_DSCP_MASK]); + /* need contiguous 6 bit mask */ + dscpmaskshift = dscpmask ? __ffs(dscpmask) : 0; + if ((~0 & (dscpmask >> dscpmaskshift)) != 0x3f) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_CTINFO_PARMS_DSCP_MASK], + "dscp mask must be 6 contiguous bits"); + return -EINVAL; + } + dscpstatemask = tb[TCA_CTINFO_PARMS_DSCP_STATEMASK] ? + nla_get_u32(tb[TCA_CTINFO_PARMS_DSCP_STATEMASK]) : 0; + /* mask & statemask must not overlap */ + if (dscpmask & dscpstatemask) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_CTINFO_PARMS_DSCP_STATEMASK], + "dscp statemask must not overlap dscp mask"); + return -EINVAL; + } + } + + /* done the validation:now to the actual action allocation */ + index = actparm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (!err) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_ctinfo_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (err > 0) { + if (bind) /* don't override defaults */ + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } else { + return err; + } + + err = tcf_action_check_ctrlact(actparm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + ci = to_ctinfo(*a); + + cp_new = kzalloc(sizeof(*cp_new), GFP_KERNEL); + if (unlikely(!cp_new)) { + err = -ENOMEM; + goto put_chain; + } + + cp_new->net = net; + cp_new->zone = tb[TCA_CTINFO_ZONE] ? 
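Two bits of arithmetic in the ctinfo code are worth spelling out: the DSCP mask is accepted only if shifting it down by its lowest set bit leaves exactly 0x3f (six contiguous bits), and the restored DSCP is the masked conntrack mark shifted down by that same amount and then up into bits 7..2 of the ToS/Traffic Class byte, leaving the two ECN bits untouched. A small self-contained check with made-up mask and mark values:

#include <stdio.h>

#define INET_ECN_MASK 3u	/* the two ECN bits at the bottom of the ToS byte */

/* Index of the lowest set bit, like __ffs() on a non-zero word
 * (__builtin_ctz is a GCC/Clang builtin).
 */
static unsigned int lowest_bit(unsigned int v)
{
	return (unsigned int)__builtin_ctz(v);
}

/* Valid iff the mask is exactly six contiguous bits. */
static int dscp_mask_ok(unsigned int dscpmask)
{
	unsigned int shift = dscpmask ? lowest_bit(dscpmask) : 0;

	return (dscpmask >> shift) == 0x3f;
}

int main(void)
{
	unsigned int masks[] = { 0xfc000000, 0x0000fc00, 0x000000ff, 0x00003c00 };
	unsigned int mark = 0xb8000000;		/* DSCP EF (46) stored in bits 31..26 */
	unsigned int dscpmask = 0xfc000000;
	unsigned int shift = lowest_bit(dscpmask);
	unsigned int newdscp;
	unsigned int i;

	for (i = 0; i < sizeof(masks) / sizeof(masks[0]); i++)
		printf("mask 0x%08x -> %s\n", masks[i],
		       dscp_mask_ok(masks[i]) ? "ok" : "rejected");

	/* Same arithmetic as tcf_ctinfo_dscp_set(). */
	newdscp = (((mark & dscpmask) >> shift) << 2) & ~INET_ECN_MASK;
	printf("mark 0x%08x, mask 0x%08x -> ToS byte 0x%02x (DSCP %u)\n",
	       mark, dscpmask, newdscp, newdscp >> 2);
	return 0;
}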
+ nla_get_u16(tb[TCA_CTINFO_ZONE]) : 0; + if (dscpmask) { + cp_new->dscpmask = dscpmask; + cp_new->dscpmaskshift = dscpmaskshift; + cp_new->dscpstatemask = dscpstatemask; + cp_new->mode |= CTINFO_MODE_DSCP; + } + + if (tb[TCA_CTINFO_PARMS_CPMARK_MASK]) { + cp_new->cpmarkmask = + nla_get_u32(tb[TCA_CTINFO_PARMS_CPMARK_MASK]); + cp_new->mode |= CTINFO_MODE_CPMARK; + } + + spin_lock_bh(&ci->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, actparm->action, goto_ch); + cp_new = rcu_replace_pointer(ci->params, cp_new, + lockdep_is_held(&ci->tcf_lock)); + spin_unlock_bh(&ci->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (cp_new) + kfree_rcu(cp_new, rcu); + + return ret; + +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_ctinfo_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + struct tcf_ctinfo *ci = to_ctinfo(a); + struct tc_ctinfo opt = { + .index = ci->tcf_index, + .refcnt = refcount_read(&ci->tcf_refcnt) - ref, + .bindcnt = atomic_read(&ci->tcf_bindcnt) - bind, + }; + unsigned char *b = skb_tail_pointer(skb); + struct tcf_ctinfo_params *cp; + struct tcf_t t; + + spin_lock_bh(&ci->tcf_lock); + cp = rcu_dereference_protected(ci->params, + lockdep_is_held(&ci->tcf_lock)); + + tcf_tm_dump(&t, &ci->tcf_tm); + if (nla_put_64bit(skb, TCA_CTINFO_TM, sizeof(t), &t, TCA_CTINFO_PAD)) + goto nla_put_failure; + + opt.action = ci->tcf_action; + if (nla_put(skb, TCA_CTINFO_ACT, sizeof(opt), &opt)) + goto nla_put_failure; + + if (nla_put_u16(skb, TCA_CTINFO_ZONE, cp->zone)) + goto nla_put_failure; + + if (cp->mode & CTINFO_MODE_DSCP) { + if (nla_put_u32(skb, TCA_CTINFO_PARMS_DSCP_MASK, + cp->dscpmask)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_CTINFO_PARMS_DSCP_STATEMASK, + cp->dscpstatemask)) + goto nla_put_failure; + } + + if (cp->mode & CTINFO_MODE_CPMARK) { + if (nla_put_u32(skb, TCA_CTINFO_PARMS_CPMARK_MASK, + cp->cpmarkmask)) + goto nla_put_failure; + } + + if (nla_put_u64_64bit(skb, TCA_CTINFO_STATS_DSCP_SET, + ci->stats_dscp_set, TCA_CTINFO_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_CTINFO_STATS_DSCP_ERROR, + ci->stats_dscp_error, TCA_CTINFO_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_CTINFO_STATS_CPMARK_SET, + ci->stats_cpmark_set, TCA_CTINFO_PAD)) + goto nla_put_failure; + + spin_unlock_bh(&ci->tcf_lock); + return skb->len; + +nla_put_failure: + spin_unlock_bh(&ci->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_ctinfo_cleanup(struct tc_action *a) +{ + struct tcf_ctinfo *ci = to_ctinfo(a); + struct tcf_ctinfo_params *cp; + + cp = rcu_dereference_protected(ci->params, 1); + if (cp) + kfree_rcu(cp, rcu); +} + +static struct tc_action_ops act_ctinfo_ops = { + .kind = "ctinfo", + .id = TCA_ID_CTINFO, + .owner = THIS_MODULE, + .act = tcf_ctinfo_act, + .dump = tcf_ctinfo_dump, + .init = tcf_ctinfo_init, + .cleanup= tcf_ctinfo_cleanup, + .size = sizeof(struct tcf_ctinfo), +}; + +static __net_init int ctinfo_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_ctinfo_ops.net_id); + + return tc_action_net_init(net, tn, &act_ctinfo_ops); +} + +static void __net_exit ctinfo_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_ctinfo_ops.net_id); +} + +static struct pernet_operations ctinfo_net_ops = { + .init = ctinfo_init_net, + .exit_batch = ctinfo_exit_net, + .id = &act_ctinfo_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init 
ctinfo_init_module(void) +{ + return tcf_register_action(&act_ctinfo_ops, &ctinfo_net_ops); +} + +static void __exit ctinfo_cleanup_module(void) +{ + tcf_unregister_action(&act_ctinfo_ops, &ctinfo_net_ops); +} + +module_init(ctinfo_init_module); +module_exit(ctinfo_cleanup_module); +MODULE_AUTHOR("Kevin Darbyshire-Bryant "); +MODULE_DESCRIPTION("Connection tracking mark actions"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/act_gact.c b/net/sched/act_gact.c new file mode 100644 index 000000000..62d682b96 --- /dev/null +++ b/net/sched/act_gact.c @@ -0,0 +1,338 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_gact.c Generic actions + * + * copyright Jamal Hadi Salim (2002-4) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct tc_action_ops act_gact_ops; + +#ifdef CONFIG_GACT_PROB +static int gact_net_rand(struct tcf_gact *gact) +{ + smp_rmb(); /* coupled with smp_wmb() in tcf_gact_init() */ + if (prandom_u32_max(gact->tcfg_pval)) + return gact->tcf_action; + return gact->tcfg_paction; +} + +static int gact_determ(struct tcf_gact *gact) +{ + u32 pack = atomic_inc_return(&gact->packets); + + smp_rmb(); /* coupled with smp_wmb() in tcf_gact_init() */ + if (pack % gact->tcfg_pval) + return gact->tcf_action; + return gact->tcfg_paction; +} + +typedef int (*g_rand)(struct tcf_gact *gact); +static g_rand gact_rand[MAX_RAND] = { NULL, gact_net_rand, gact_determ }; +#endif /* CONFIG_GACT_PROB */ + +static const struct nla_policy gact_policy[TCA_GACT_MAX + 1] = { + [TCA_GACT_PARMS] = { .len = sizeof(struct tc_gact) }, + [TCA_GACT_PROB] = { .len = sizeof(struct tc_gact_p) }, +}; + +static int tcf_gact_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_gact_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_GACT_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_gact *parm; + struct tcf_gact *gact; + int ret = 0; + u32 index; + int err; +#ifdef CONFIG_GACT_PROB + struct tc_gact_p *p_parm = NULL; +#endif + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_GACT_MAX, nla, gact_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_GACT_PARMS] == NULL) + return -EINVAL; + parm = nla_data(tb[TCA_GACT_PARMS]); + index = parm->index; + +#ifndef CONFIG_GACT_PROB + if (tb[TCA_GACT_PROB] != NULL) + return -EOPNOTSUPP; +#else + if (tb[TCA_GACT_PROB]) { + p_parm = nla_data(tb[TCA_GACT_PROB]); + if (p_parm->ptype >= MAX_RAND) + return -EINVAL; + if (TC_ACT_EXT_CMP(p_parm->paction, TC_ACT_GOTO_CHAIN)) { + NL_SET_ERR_MSG(extack, + "goto chain not allowed on fallback"); + return -EINVAL; + } + } +#endif + + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (!err) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_gact_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (err > 0) { + if (bind)/* dont override defaults */ + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } else { + return err; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + gact = to_gact(*a); + + spin_lock_bh(&gact->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); +#ifdef 
CONFIG_GACT_PROB + if (p_parm) { + gact->tcfg_paction = p_parm->paction; + gact->tcfg_pval = max_t(u16, 1, p_parm->pval); + /* Make sure tcfg_pval is written before tcfg_ptype + * coupled with smp_rmb() in gact_net_rand() & gact_determ() + */ + smp_wmb(); + gact->tcfg_ptype = p_parm->ptype; + } +#endif + spin_unlock_bh(&gact->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_gact_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_gact *gact = to_gact(a); + int action = READ_ONCE(gact->tcf_action); + +#ifdef CONFIG_GACT_PROB + { + u32 ptype = READ_ONCE(gact->tcfg_ptype); + + if (ptype) + action = gact_rand[ptype](gact); + } +#endif + tcf_action_update_bstats(&gact->common, skb); + if (action == TC_ACT_SHOT) + tcf_action_inc_drop_qstats(&gact->common); + + tcf_lastuse_update(&gact->tcf_tm); + + return action; +} + +static void tcf_gact_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_gact *gact = to_gact(a); + int action = READ_ONCE(gact->tcf_action); + struct tcf_t *tm = &gact->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, + action == TC_ACT_SHOT ? packets : drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static int tcf_gact_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_gact *gact = to_gact(a); + struct tc_gact opt = { + .index = gact->tcf_index, + .refcnt = refcount_read(&gact->tcf_refcnt) - ref, + .bindcnt = atomic_read(&gact->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&gact->tcf_lock); + opt.action = gact->tcf_action; + if (nla_put(skb, TCA_GACT_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; +#ifdef CONFIG_GACT_PROB + if (gact->tcfg_ptype) { + struct tc_gact_p p_opt = { + .paction = gact->tcfg_paction, + .pval = gact->tcfg_pval, + .ptype = gact->tcfg_ptype, + }; + + if (nla_put(skb, TCA_GACT_PROB, sizeof(p_opt), &p_opt)) + goto nla_put_failure; + } +#endif + tcf_tm_dump(&t, &gact->tcf_tm); + if (nla_put_64bit(skb, TCA_GACT_TM, sizeof(t), &t, TCA_GACT_PAD)) + goto nla_put_failure; + spin_unlock_bh(&gact->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&gact->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static size_t tcf_gact_get_fill_size(const struct tc_action *act) +{ + size_t sz = nla_total_size(sizeof(struct tc_gact)); /* TCA_GACT_PARMS */ + +#ifdef CONFIG_GACT_PROB + if (to_gact(act)->tcfg_ptype) + /* TCA_GACT_PROB */ + sz += nla_total_size(sizeof(struct tc_gact_p)); +#endif + + return sz; +} + +static int tcf_gact_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + if (is_tcf_gact_ok(act)) { + entry->id = FLOW_ACTION_ACCEPT; + } else if (is_tcf_gact_shot(act)) { + entry->id = FLOW_ACTION_DROP; + } else if (is_tcf_gact_trap(act)) { + entry->id = FLOW_ACTION_TRAP; + } else if (is_tcf_gact_goto_chain(act)) { + entry->id = FLOW_ACTION_GOTO; + entry->chain_index = tcf_gact_goto_chain_index(act); + } else if (is_tcf_gact_continue(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload of \"continue\" action is not supported"); + return -EOPNOTSUPP; + } else if (is_tcf_gact_reclassify(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload of \"reclassify\" action is not supported"); + return -EOPNOTSUPP; + } 
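The probability hooks above work out to this: in the random mode (gact_net_rand) the fallback action fires whenever the draw over [0, pval) comes up zero, i.e. on average once every pval packets, while the deterministic mode (gact_determ) fires it on exactly every pval-th packet. A quick user-space simulation of both, with rand() standing in for prandom_u32_max() and an arbitrarily chosen pval:

#include <stdio.h>
#include <stdlib.h>

int main(void)
{
	const unsigned int pval = 10;		/* arbitrary example value */
	const unsigned int packets = 1000000;
	unsigned int i, rand_hits = 0, determ_hits = 0, counter = 0;

	srand(1);
	for (i = 0; i < packets; i++) {
		/* gact_net_rand(): fallback when the draw over [0, pval) is 0. */
		if ((unsigned int)rand() % pval == 0)
			rand_hits++;

		/* gact_determ(): fallback on every pval-th packet. */
		if (++counter % pval == 0)
			determ_hits++;
	}

	printf("random mode:        %u/%u packets hit the fallback (~1/%u)\n",
	       rand_hits, packets, pval);
	printf("deterministic mode: %u/%u packets hit the fallback (exactly 1/%u)\n",
	       determ_hits, packets, pval);
	return 0;
}

Both modes converge on the same long-run rate; the deterministic variant trades randomness for an exact, evenly spaced sampling pattern.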
else if (is_tcf_gact_pipe(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload of \"pipe\" action is not supported"); + return -EOPNOTSUPP; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported generic action offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + if (is_tcf_gact_ok(act)) + fl_action->id = FLOW_ACTION_ACCEPT; + else if (is_tcf_gact_shot(act)) + fl_action->id = FLOW_ACTION_DROP; + else if (is_tcf_gact_trap(act)) + fl_action->id = FLOW_ACTION_TRAP; + else if (is_tcf_gact_goto_chain(act)) + fl_action->id = FLOW_ACTION_GOTO; + else + return -EOPNOTSUPP; + } + + return 0; +} + +static struct tc_action_ops act_gact_ops = { + .kind = "gact", + .id = TCA_ID_GACT, + .owner = THIS_MODULE, + .act = tcf_gact_act, + .stats_update = tcf_gact_stats_update, + .dump = tcf_gact_dump, + .init = tcf_gact_init, + .get_fill_size = tcf_gact_get_fill_size, + .offload_act_setup = tcf_gact_offload_act_setup, + .size = sizeof(struct tcf_gact), +}; + +static __net_init int gact_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_gact_ops.net_id); + + return tc_action_net_init(net, tn, &act_gact_ops); +} + +static void __net_exit gact_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_gact_ops.net_id); +} + +static struct pernet_operations gact_net_ops = { + .init = gact_init_net, + .exit_batch = gact_exit_net, + .id = &act_gact_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim(2002-4)"); +MODULE_DESCRIPTION("Generic Classifier actions"); +MODULE_LICENSE("GPL"); + +static int __init gact_init_module(void) +{ +#ifdef CONFIG_GACT_PROB + pr_info("GACT probability on\n"); +#else + pr_info("GACT probability NOT on\n"); +#endif + + return tcf_register_action(&act_gact_ops, &gact_net_ops); +} + +static void __exit gact_cleanup_module(void) +{ + tcf_unregister_action(&act_gact_ops, &gact_net_ops); +} + +module_init(gact_init_module); +module_exit(gact_cleanup_module); diff --git a/net/sched/act_gate.c b/net/sched/act_gate.c new file mode 100644 index 000000000..3049878e7 --- /dev/null +++ b/net/sched/act_gate.c @@ -0,0 +1,676 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* Copyright 2020 NXP */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct tc_action_ops act_gate_ops; + +static ktime_t gate_get_time(struct tcf_gate *gact) +{ + ktime_t mono = ktime_get(); + + switch (gact->tk_offset) { + case TK_OFFS_MAX: + return mono; + default: + return ktime_mono_to_any(mono, gact->tk_offset); + } + + return KTIME_MAX; +} + +static void gate_get_start_time(struct tcf_gate *gact, ktime_t *start) +{ + struct tcf_gate_params *param = &gact->param; + ktime_t now, base, cycle; + u64 n; + + base = ns_to_ktime(param->tcfg_basetime); + now = gate_get_time(gact); + + if (ktime_after(base, now)) { + *start = base; + return; + } + + cycle = param->tcfg_cycletime; + + n = div64_u64(ktime_sub_ns(now, base), cycle); + *start = ktime_add_ns(base, (n + 1) * cycle); +} + +static void gate_start_timer(struct tcf_gate *gact, ktime_t start) +{ + ktime_t expires; + + expires = hrtimer_get_expires(&gact->hitimer); + if (expires == 0) + expires = KTIME_MAX; + + start = min_t(ktime_t, start, expires); + + hrtimer_start(&gact->hitimer, start, HRTIMER_MODE_ABS_SOFT); +} + +static enum hrtimer_restart gate_timer_func(struct hrtimer *timer) +{ + struct tcf_gate *gact = container_of(timer, struct 
tcf_gate, + hitimer); + struct tcf_gate_params *p = &gact->param; + struct tcfg_gate_entry *next; + ktime_t close_time, now; + + spin_lock(&gact->tcf_lock); + + next = gact->next_entry; + + /* cycle start, clear pending bit, clear total octets */ + gact->current_gate_status = next->gate_state ? GATE_ACT_GATE_OPEN : 0; + gact->current_entry_octets = 0; + gact->current_max_octets = next->maxoctets; + + gact->current_close_time = ktime_add_ns(gact->current_close_time, + next->interval); + + close_time = gact->current_close_time; + + if (list_is_last(&next->list, &p->entries)) + next = list_first_entry(&p->entries, + struct tcfg_gate_entry, list); + else + next = list_next_entry(next, list); + + now = gate_get_time(gact); + + if (ktime_after(now, close_time)) { + ktime_t cycle, base; + u64 n; + + cycle = p->tcfg_cycletime; + base = ns_to_ktime(p->tcfg_basetime); + n = div64_u64(ktime_sub_ns(now, base), cycle); + close_time = ktime_add_ns(base, (n + 1) * cycle); + } + + gact->next_entry = next; + + hrtimer_set_expires(&gact->hitimer, close_time); + + spin_unlock(&gact->tcf_lock); + + return HRTIMER_RESTART; +} + +static int tcf_gate_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_gate *gact = to_gate(a); + + spin_lock(&gact->tcf_lock); + + tcf_lastuse_update(&gact->tcf_tm); + bstats_update(&gact->tcf_bstats, skb); + + if (unlikely(gact->current_gate_status & GATE_ACT_PENDING)) { + spin_unlock(&gact->tcf_lock); + return gact->tcf_action; + } + + if (!(gact->current_gate_status & GATE_ACT_GATE_OPEN)) + goto drop; + + if (gact->current_max_octets >= 0) { + gact->current_entry_octets += qdisc_pkt_len(skb); + if (gact->current_entry_octets > gact->current_max_octets) { + gact->tcf_qstats.overlimits++; + goto drop; + } + } + + spin_unlock(&gact->tcf_lock); + + return gact->tcf_action; +drop: + gact->tcf_qstats.drops++; + spin_unlock(&gact->tcf_lock); + + return TC_ACT_SHOT; +} + +static const struct nla_policy entry_policy[TCA_GATE_ENTRY_MAX + 1] = { + [TCA_GATE_ENTRY_INDEX] = { .type = NLA_U32 }, + [TCA_GATE_ENTRY_GATE] = { .type = NLA_FLAG }, + [TCA_GATE_ENTRY_INTERVAL] = { .type = NLA_U32 }, + [TCA_GATE_ENTRY_IPV] = { .type = NLA_S32 }, + [TCA_GATE_ENTRY_MAX_OCTETS] = { .type = NLA_S32 }, +}; + +static const struct nla_policy gate_policy[TCA_GATE_MAX + 1] = { + [TCA_GATE_PARMS] = + NLA_POLICY_EXACT_LEN(sizeof(struct tc_gate)), + [TCA_GATE_PRIORITY] = { .type = NLA_S32 }, + [TCA_GATE_ENTRY_LIST] = { .type = NLA_NESTED }, + [TCA_GATE_BASE_TIME] = { .type = NLA_U64 }, + [TCA_GATE_CYCLE_TIME] = { .type = NLA_U64 }, + [TCA_GATE_CYCLE_TIME_EXT] = { .type = NLA_U64 }, + [TCA_GATE_FLAGS] = { .type = NLA_U32 }, + [TCA_GATE_CLOCKID] = { .type = NLA_S32 }, +}; + +static int fill_gate_entry(struct nlattr **tb, struct tcfg_gate_entry *entry, + struct netlink_ext_ack *extack) +{ + u32 interval = 0; + + entry->gate_state = nla_get_flag(tb[TCA_GATE_ENTRY_GATE]); + + if (tb[TCA_GATE_ENTRY_INTERVAL]) + interval = nla_get_u32(tb[TCA_GATE_ENTRY_INTERVAL]); + + if (interval == 0) { + NL_SET_ERR_MSG(extack, "Invalid interval for schedule entry"); + return -EINVAL; + } + + entry->interval = interval; + + if (tb[TCA_GATE_ENTRY_IPV]) + entry->ipv = nla_get_s32(tb[TCA_GATE_ENTRY_IPV]); + else + entry->ipv = -1; + + if (tb[TCA_GATE_ENTRY_MAX_OCTETS]) + entry->maxoctets = nla_get_s32(tb[TCA_GATE_ENTRY_MAX_OCTETS]); + else + entry->maxoctets = -1; + + return 0; +} + +static int parse_gate_entry(struct nlattr *n, struct tcfg_gate_entry *entry, + int index, struct netlink_ext_ack 
*extack) +{ + struct nlattr *tb[TCA_GATE_ENTRY_MAX + 1] = { }; + int err; + + err = nla_parse_nested(tb, TCA_GATE_ENTRY_MAX, n, entry_policy, extack); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Could not parse nested entry"); + return -EINVAL; + } + + entry->index = index; + + return fill_gate_entry(tb, entry, extack); +} + +static void release_entry_list(struct list_head *entries) +{ + struct tcfg_gate_entry *entry, *e; + + list_for_each_entry_safe(entry, e, entries, list) { + list_del(&entry->list); + kfree(entry); + } +} + +static int parse_gate_list(struct nlattr *list_attr, + struct tcf_gate_params *sched, + struct netlink_ext_ack *extack) +{ + struct tcfg_gate_entry *entry; + struct nlattr *n; + int err, rem; + int i = 0; + + if (!list_attr) + return -EINVAL; + + nla_for_each_nested(n, list_attr, rem) { + if (nla_type(n) != TCA_GATE_ONE_ENTRY) { + NL_SET_ERR_MSG(extack, "Attribute isn't type 'entry'"); + continue; + } + + entry = kzalloc(sizeof(*entry), GFP_ATOMIC); + if (!entry) { + NL_SET_ERR_MSG(extack, "Not enough memory for entry"); + err = -ENOMEM; + goto release_list; + } + + err = parse_gate_entry(n, entry, i, extack); + if (err < 0) { + kfree(entry); + goto release_list; + } + + list_add_tail(&entry->list, &sched->entries); + i++; + } + + sched->num_entries = i; + + return i; + +release_list: + release_entry_list(&sched->entries); + + return err; +} + +static void gate_setup_timer(struct tcf_gate *gact, u64 basetime, + enum tk_offsets tko, s32 clockid, + bool do_init) +{ + if (!do_init) { + if (basetime == gact->param.tcfg_basetime && + tko == gact->tk_offset && + clockid == gact->param.tcfg_clockid) + return; + + spin_unlock_bh(&gact->tcf_lock); + hrtimer_cancel(&gact->hitimer); + spin_lock_bh(&gact->tcf_lock); + } + gact->param.tcfg_basetime = basetime; + gact->param.tcfg_clockid = clockid; + gact->tk_offset = tko; + hrtimer_init(&gact->hitimer, clockid, HRTIMER_MODE_ABS_SOFT); + gact->hitimer.function = gate_timer_func; +} + +static int tcf_gate_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_gate_ops.net_id); + enum tk_offsets tk_offset = TK_OFFS_TAI; + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_GATE_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + u64 cycletime = 0, basetime = 0; + struct tcf_gate_params *p; + s32 clockid = CLOCK_TAI; + struct tcf_gate *gact; + struct tc_gate *parm; + int ret = 0, err; + u32 gflags = 0; + s32 prio = -1; + ktime_t start; + u32 index; + + if (!nla) + return -EINVAL; + + err = nla_parse_nested(tb, TCA_GATE_MAX, nla, gate_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_GATE_PARMS]) + return -EINVAL; + + if (tb[TCA_GATE_CLOCKID]) { + clockid = nla_get_s32(tb[TCA_GATE_CLOCKID]); + switch (clockid) { + case CLOCK_REALTIME: + tk_offset = TK_OFFS_REAL; + break; + case CLOCK_MONOTONIC: + tk_offset = TK_OFFS_MAX; + break; + case CLOCK_BOOTTIME: + tk_offset = TK_OFFS_BOOT; + break; + case CLOCK_TAI: + tk_offset = TK_OFFS_TAI; + break; + default: + NL_SET_ERR_MSG(extack, "Invalid 'clockid'"); + return -EINVAL; + } + } + + parm = nla_data(tb[TCA_GATE_PARMS]); + index = parm->index; + + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + + if (err && bind) + return 0; + + if (!err) { + ret = tcf_idr_create(tn, index, est, a, + &act_gate_ops, bind, false, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + ret = 
ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + + if (tb[TCA_GATE_PRIORITY]) + prio = nla_get_s32(tb[TCA_GATE_PRIORITY]); + + if (tb[TCA_GATE_BASE_TIME]) + basetime = nla_get_u64(tb[TCA_GATE_BASE_TIME]); + + if (tb[TCA_GATE_FLAGS]) + gflags = nla_get_u32(tb[TCA_GATE_FLAGS]); + + gact = to_gate(*a); + if (ret == ACT_P_CREATED) + INIT_LIST_HEAD(&gact->param.entries); + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + spin_lock_bh(&gact->tcf_lock); + p = &gact->param; + + if (tb[TCA_GATE_CYCLE_TIME]) + cycletime = nla_get_u64(tb[TCA_GATE_CYCLE_TIME]); + + if (tb[TCA_GATE_ENTRY_LIST]) { + err = parse_gate_list(tb[TCA_GATE_ENTRY_LIST], p, extack); + if (err < 0) + goto chain_put; + } + + if (!cycletime) { + struct tcfg_gate_entry *entry; + ktime_t cycle = 0; + + list_for_each_entry(entry, &p->entries, list) + cycle = ktime_add_ns(cycle, entry->interval); + cycletime = cycle; + if (!cycletime) { + err = -EINVAL; + goto chain_put; + } + } + p->tcfg_cycletime = cycletime; + + if (tb[TCA_GATE_CYCLE_TIME_EXT]) + p->tcfg_cycletime_ext = + nla_get_u64(tb[TCA_GATE_CYCLE_TIME_EXT]); + + gate_setup_timer(gact, basetime, tk_offset, clockid, + ret == ACT_P_CREATED); + p->tcfg_priority = prio; + p->tcfg_flags = gflags; + gate_get_start_time(gact, &start); + + gact->current_close_time = start; + gact->current_gate_status = GATE_ACT_GATE_OPEN | GATE_ACT_PENDING; + + gact->next_entry = list_first_entry(&p->entries, + struct tcfg_gate_entry, list); + + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + + gate_start_timer(gact, start); + + spin_unlock_bh(&gact->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; + +chain_put: + spin_unlock_bh(&gact->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + /* action is not inserted in any list: it's safe to init hitimer + * without taking tcf_lock. 
+ */ + if (ret == ACT_P_CREATED) + gate_setup_timer(gact, gact->param.tcfg_basetime, + gact->tk_offset, gact->param.tcfg_clockid, + true); + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_gate_cleanup(struct tc_action *a) +{ + struct tcf_gate *gact = to_gate(a); + struct tcf_gate_params *p; + + p = &gact->param; + hrtimer_cancel(&gact->hitimer); + release_entry_list(&p->entries); +} + +static int dumping_entry(struct sk_buff *skb, + struct tcfg_gate_entry *entry) +{ + struct nlattr *item; + + item = nla_nest_start_noflag(skb, TCA_GATE_ONE_ENTRY); + if (!item) + return -ENOSPC; + + if (nla_put_u32(skb, TCA_GATE_ENTRY_INDEX, entry->index)) + goto nla_put_failure; + + if (entry->gate_state && nla_put_flag(skb, TCA_GATE_ENTRY_GATE)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_GATE_ENTRY_INTERVAL, entry->interval)) + goto nla_put_failure; + + if (nla_put_s32(skb, TCA_GATE_ENTRY_MAX_OCTETS, entry->maxoctets)) + goto nla_put_failure; + + if (nla_put_s32(skb, TCA_GATE_ENTRY_IPV, entry->ipv)) + goto nla_put_failure; + + return nla_nest_end(skb, item); + +nla_put_failure: + nla_nest_cancel(skb, item); + return -1; +} + +static int tcf_gate_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_gate *gact = to_gate(a); + struct tc_gate opt = { + .index = gact->tcf_index, + .refcnt = refcount_read(&gact->tcf_refcnt) - ref, + .bindcnt = atomic_read(&gact->tcf_bindcnt) - bind, + }; + struct tcfg_gate_entry *entry; + struct tcf_gate_params *p; + struct nlattr *entry_list; + struct tcf_t t; + + spin_lock_bh(&gact->tcf_lock); + opt.action = gact->tcf_action; + + p = &gact->param; + + if (nla_put(skb, TCA_GATE_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_GATE_BASE_TIME, + p->tcfg_basetime, TCA_GATE_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_GATE_CYCLE_TIME, + p->tcfg_cycletime, TCA_GATE_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_GATE_CYCLE_TIME_EXT, + p->tcfg_cycletime_ext, TCA_GATE_PAD)) + goto nla_put_failure; + + if (nla_put_s32(skb, TCA_GATE_CLOCKID, p->tcfg_clockid)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_GATE_FLAGS, p->tcfg_flags)) + goto nla_put_failure; + + if (nla_put_s32(skb, TCA_GATE_PRIORITY, p->tcfg_priority)) + goto nla_put_failure; + + entry_list = nla_nest_start_noflag(skb, TCA_GATE_ENTRY_LIST); + if (!entry_list) + goto nla_put_failure; + + list_for_each_entry(entry, &p->entries, list) { + if (dumping_entry(skb, entry) < 0) + goto nla_put_failure; + } + + nla_nest_end(skb, entry_list); + + tcf_tm_dump(&t, &gact->tcf_tm); + if (nla_put_64bit(skb, TCA_GATE_TM, sizeof(t), &t, TCA_GATE_PAD)) + goto nla_put_failure; + spin_unlock_bh(&gact->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&gact->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_gate_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_gate *gact = to_gate(a); + struct tcf_t *tm = &gact->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static size_t tcf_gate_get_fill_size(const struct tc_action *act) +{ + return nla_total_size(sizeof(struct tc_gate)); +} + +static void tcf_gate_entry_destructor(void *priv) +{ + struct action_gate_entry *oe = priv; + + kfree(oe); +} + +static int tcf_gate_get_entries(struct flow_action_entry *entry, + const struct tc_action *act) +{ + 
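+	/* Stash the software gate schedule in the offload entry and register
+	 * tcf_gate_entry_destructor() so the list returned by
+	 * tcf_gate_get_list() can be freed once the offload side is done with it.
+	 */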
entry->gate.entries = tcf_gate_get_list(act); + + if (!entry->gate.entries) + return -EINVAL; + + entry->destructor = tcf_gate_entry_destructor; + entry->destructor_priv = entry->gate.entries; + + return 0; +} + +static int tcf_gate_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + int err; + + if (bind) { + struct flow_action_entry *entry = entry_data; + + entry->id = FLOW_ACTION_GATE; + entry->gate.prio = tcf_gate_prio(act); + entry->gate.basetime = tcf_gate_basetime(act); + entry->gate.cycletime = tcf_gate_cycletime(act); + entry->gate.cycletimeext = tcf_gate_cycletimeext(act); + entry->gate.num_entries = tcf_gate_num_entries(act); + err = tcf_gate_get_entries(entry, act); + if (err) + return err; + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + fl_action->id = FLOW_ACTION_GATE; + } + + return 0; +} + +static struct tc_action_ops act_gate_ops = { + .kind = "gate", + .id = TCA_ID_GATE, + .owner = THIS_MODULE, + .act = tcf_gate_act, + .dump = tcf_gate_dump, + .init = tcf_gate_init, + .cleanup = tcf_gate_cleanup, + .stats_update = tcf_gate_stats_update, + .get_fill_size = tcf_gate_get_fill_size, + .offload_act_setup = tcf_gate_offload_act_setup, + .size = sizeof(struct tcf_gate), +}; + +static __net_init int gate_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_gate_ops.net_id); + + return tc_action_net_init(net, tn, &act_gate_ops); +} + +static void __net_exit gate_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_gate_ops.net_id); +} + +static struct pernet_operations gate_net_ops = { + .init = gate_init_net, + .exit_batch = gate_exit_net, + .id = &act_gate_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init gate_init_module(void) +{ + return tcf_register_action(&act_gate_ops, &gate_net_ops); +} + +static void __exit gate_cleanup_module(void) +{ + tcf_unregister_action(&act_gate_ops, &gate_net_ops); +} + +module_init(gate_init_module); +module_exit(gate_cleanup_module); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/act_ife.c b/net/sched/act_ife.c new file mode 100644 index 000000000..41d63b334 --- /dev/null +++ b/net/sched/act_ife.c @@ -0,0 +1,925 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/ife.c Inter-FE action based on ForCES WG InterFE LFB + * + * Refer to: + * draft-ietf-forces-interfelfb-03 + * and + * netdev01 paper: + * "Distributing Linux Traffic Control Classifier-Action + * Subsystem" + * Authors: Jamal Hadi Salim and Damascene M. 
Joachimpillai + * + * copyright Jamal Hadi Salim (2015) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int max_metacnt = IFE_META_MAX + 1; +static struct tc_action_ops act_ife_ops; + +static const struct nla_policy ife_policy[TCA_IFE_MAX + 1] = { + [TCA_IFE_PARMS] = { .len = sizeof(struct tc_ife)}, + [TCA_IFE_DMAC] = { .len = ETH_ALEN}, + [TCA_IFE_SMAC] = { .len = ETH_ALEN}, + [TCA_IFE_TYPE] = { .type = NLA_U16}, +}; + +int ife_encode_meta_u16(u16 metaval, void *skbdata, struct tcf_meta_info *mi) +{ + u16 edata = 0; + + if (mi->metaval) + edata = *(u16 *)mi->metaval; + else if (metaval) + edata = metaval; + + if (!edata) /* will not encode */ + return 0; + + edata = htons(edata); + return ife_tlv_meta_encode(skbdata, mi->metaid, 2, &edata); +} +EXPORT_SYMBOL_GPL(ife_encode_meta_u16); + +int ife_get_meta_u32(struct sk_buff *skb, struct tcf_meta_info *mi) +{ + if (mi->metaval) + return nla_put_u32(skb, mi->metaid, *(u32 *)mi->metaval); + else + return nla_put(skb, mi->metaid, 0, NULL); +} +EXPORT_SYMBOL_GPL(ife_get_meta_u32); + +int ife_check_meta_u32(u32 metaval, struct tcf_meta_info *mi) +{ + if (metaval || mi->metaval) + return 8; /* T+L+V == 2+2+4 */ + + return 0; +} +EXPORT_SYMBOL_GPL(ife_check_meta_u32); + +int ife_check_meta_u16(u16 metaval, struct tcf_meta_info *mi) +{ + if (metaval || mi->metaval) + return 8; /* T+L+(V) == 2+2+(2+2bytepad) */ + + return 0; +} +EXPORT_SYMBOL_GPL(ife_check_meta_u16); + +int ife_encode_meta_u32(u32 metaval, void *skbdata, struct tcf_meta_info *mi) +{ + u32 edata = metaval; + + if (mi->metaval) + edata = *(u32 *)mi->metaval; + else if (metaval) + edata = metaval; + + if (!edata) /* will not encode */ + return 0; + + edata = htonl(edata); + return ife_tlv_meta_encode(skbdata, mi->metaid, 4, &edata); +} +EXPORT_SYMBOL_GPL(ife_encode_meta_u32); + +int ife_get_meta_u16(struct sk_buff *skb, struct tcf_meta_info *mi) +{ + if (mi->metaval) + return nla_put_u16(skb, mi->metaid, *(u16 *)mi->metaval); + else + return nla_put(skb, mi->metaid, 0, NULL); +} +EXPORT_SYMBOL_GPL(ife_get_meta_u16); + +int ife_alloc_meta_u32(struct tcf_meta_info *mi, void *metaval, gfp_t gfp) +{ + mi->metaval = kmemdup(metaval, sizeof(u32), gfp); + if (!mi->metaval) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL_GPL(ife_alloc_meta_u32); + +int ife_alloc_meta_u16(struct tcf_meta_info *mi, void *metaval, gfp_t gfp) +{ + mi->metaval = kmemdup(metaval, sizeof(u16), gfp); + if (!mi->metaval) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL_GPL(ife_alloc_meta_u16); + +void ife_release_meta_gen(struct tcf_meta_info *mi) +{ + kfree(mi->metaval); +} +EXPORT_SYMBOL_GPL(ife_release_meta_gen); + +int ife_validate_meta_u32(void *val, int len) +{ + if (len == sizeof(u32)) + return 0; + + return -EINVAL; +} +EXPORT_SYMBOL_GPL(ife_validate_meta_u32); + +int ife_validate_meta_u16(void *val, int len) +{ + /* length will not include padding */ + if (len == sizeof(u16)) + return 0; + + return -EINVAL; +} +EXPORT_SYMBOL_GPL(ife_validate_meta_u16); + +static LIST_HEAD(ifeoplist); +static DEFINE_RWLOCK(ife_mod_lock); + +static struct tcf_meta_ops *find_ife_oplist(u16 metaid) +{ + struct tcf_meta_ops *o; + + read_lock(&ife_mod_lock); + list_for_each_entry(o, &ifeoplist, list) { + if (o->metaid == metaid) { + if (!try_module_get(o->owner)) + o = NULL; + read_unlock(&ife_mod_lock); + return o; + } + } + read_unlock(&ife_mod_lock); + + return NULL; +} + +int register_ife_op(struct 
tcf_meta_ops *mops) +{ + struct tcf_meta_ops *m; + + if (!mops->metaid || !mops->metatype || !mops->name || + !mops->check_presence || !mops->encode || !mops->decode || + !mops->get || !mops->alloc) + return -EINVAL; + + write_lock(&ife_mod_lock); + + list_for_each_entry(m, &ifeoplist, list) { + if (m->metaid == mops->metaid || + (strcmp(mops->name, m->name) == 0)) { + write_unlock(&ife_mod_lock); + return -EEXIST; + } + } + + if (!mops->release) + mops->release = ife_release_meta_gen; + + list_add_tail(&mops->list, &ifeoplist); + write_unlock(&ife_mod_lock); + return 0; +} +EXPORT_SYMBOL_GPL(unregister_ife_op); + +int unregister_ife_op(struct tcf_meta_ops *mops) +{ + struct tcf_meta_ops *m; + int err = -ENOENT; + + write_lock(&ife_mod_lock); + list_for_each_entry(m, &ifeoplist, list) { + if (m->metaid == mops->metaid) { + list_del(&mops->list); + err = 0; + break; + } + } + write_unlock(&ife_mod_lock); + + return err; +} +EXPORT_SYMBOL_GPL(register_ife_op); + +static int ife_validate_metatype(struct tcf_meta_ops *ops, void *val, int len) +{ + int ret = 0; + /* XXX: unfortunately cant use nla_policy at this point + * because a length of 0 is valid in the case of + * "allow". "use" semantics do enforce for proper + * length and i couldve use nla_policy but it makes it hard + * to use it just for that.. + */ + if (ops->validate) + return ops->validate(val, len); + + if (ops->metatype == NLA_U32) + ret = ife_validate_meta_u32(val, len); + else if (ops->metatype == NLA_U16) + ret = ife_validate_meta_u16(val, len); + + return ret; +} + +#ifdef CONFIG_MODULES +static const char *ife_meta_id2name(u32 metaid) +{ + switch (metaid) { + case IFE_META_SKBMARK: + return "skbmark"; + case IFE_META_PRIO: + return "skbprio"; + case IFE_META_TCINDEX: + return "tcindex"; + default: + return "unknown"; + } +} +#endif + +/* called when adding new meta information +*/ +static int load_metaops_and_vet(u32 metaid, void *val, int len, bool rtnl_held) +{ + struct tcf_meta_ops *ops = find_ife_oplist(metaid); + int ret = 0; + + if (!ops) { + ret = -ENOENT; +#ifdef CONFIG_MODULES + if (rtnl_held) + rtnl_unlock(); + request_module("ife-meta-%s", ife_meta_id2name(metaid)); + if (rtnl_held) + rtnl_lock(); + ops = find_ife_oplist(metaid); +#endif + } + + if (ops) { + ret = 0; + if (len) + ret = ife_validate_metatype(ops, val, len); + + module_put(ops->owner); + } + + return ret; +} + +/* called when adding new meta information +*/ +static int __add_metainfo(const struct tcf_meta_ops *ops, + struct tcf_ife_info *ife, u32 metaid, void *metaval, + int len, bool atomic, bool exists) +{ + struct tcf_meta_info *mi = NULL; + int ret = 0; + + mi = kzalloc(sizeof(*mi), atomic ? GFP_ATOMIC : GFP_KERNEL); + if (!mi) + return -ENOMEM; + + mi->metaid = metaid; + mi->ops = ops; + if (len > 0) { + ret = ops->alloc(mi, metaval, atomic ? 
GFP_ATOMIC : GFP_KERNEL); + if (ret != 0) { + kfree(mi); + return ret; + } + } + + if (exists) + spin_lock_bh(&ife->tcf_lock); + list_add_tail(&mi->metalist, &ife->metalist); + if (exists) + spin_unlock_bh(&ife->tcf_lock); + + return ret; +} + +static int add_metainfo_and_get_ops(const struct tcf_meta_ops *ops, + struct tcf_ife_info *ife, u32 metaid, + bool exists) +{ + int ret; + + if (!try_module_get(ops->owner)) + return -ENOENT; + ret = __add_metainfo(ops, ife, metaid, NULL, 0, true, exists); + if (ret) + module_put(ops->owner); + return ret; +} + +static int add_metainfo(struct tcf_ife_info *ife, u32 metaid, void *metaval, + int len, bool exists) +{ + const struct tcf_meta_ops *ops = find_ife_oplist(metaid); + int ret; + + if (!ops) + return -ENOENT; + ret = __add_metainfo(ops, ife, metaid, metaval, len, false, exists); + if (ret) + /*put back what find_ife_oplist took */ + module_put(ops->owner); + return ret; +} + +static int use_all_metadata(struct tcf_ife_info *ife, bool exists) +{ + struct tcf_meta_ops *o; + int rc = 0; + int installed = 0; + + read_lock(&ife_mod_lock); + list_for_each_entry(o, &ifeoplist, list) { + rc = add_metainfo_and_get_ops(o, ife, o->metaid, exists); + if (rc == 0) + installed += 1; + } + read_unlock(&ife_mod_lock); + + if (installed) + return 0; + else + return -EINVAL; +} + +static int dump_metalist(struct sk_buff *skb, struct tcf_ife_info *ife) +{ + struct tcf_meta_info *e; + struct nlattr *nest; + unsigned char *b = skb_tail_pointer(skb); + int total_encoded = 0; + + /*can only happen on decode */ + if (list_empty(&ife->metalist)) + return 0; + + nest = nla_nest_start_noflag(skb, TCA_IFE_METALST); + if (!nest) + goto out_nlmsg_trim; + + list_for_each_entry(e, &ife->metalist, metalist) { + if (!e->ops->get(skb, e)) + total_encoded += 1; + } + + if (!total_encoded) + goto out_nlmsg_trim; + + nla_nest_end(skb, nest); + + return 0; + +out_nlmsg_trim: + nlmsg_trim(skb, b); + return -1; +} + +/* under ife->tcf_lock */ +static void _tcf_ife_cleanup(struct tc_action *a) +{ + struct tcf_ife_info *ife = to_ife(a); + struct tcf_meta_info *e, *n; + + list_for_each_entry_safe(e, n, &ife->metalist, metalist) { + list_del(&e->metalist); + if (e->metaval) { + if (e->ops->release) + e->ops->release(e); + else + kfree(e->metaval); + } + module_put(e->ops->owner); + kfree(e); + } +} + +static void tcf_ife_cleanup(struct tc_action *a) +{ + struct tcf_ife_info *ife = to_ife(a); + struct tcf_ife_params *p; + + spin_lock_bh(&ife->tcf_lock); + _tcf_ife_cleanup(a); + spin_unlock_bh(&ife->tcf_lock); + + p = rcu_dereference_protected(ife->params, 1); + if (p) + kfree_rcu(p, rcu); +} + +static int load_metalist(struct nlattr **tb, bool rtnl_held) +{ + int i; + + for (i = 1; i < max_metacnt; i++) { + if (tb[i]) { + void *val = nla_data(tb[i]); + int len = nla_len(tb[i]); + int rc; + + rc = load_metaops_and_vet(i, val, len, rtnl_held); + if (rc != 0) + return rc; + } + } + + return 0; +} + +static int populate_metalist(struct tcf_ife_info *ife, struct nlattr **tb, + bool exists, bool rtnl_held) +{ + int len = 0; + int rc = 0; + int i = 0; + void *val; + + for (i = 1; i < max_metacnt; i++) { + if (tb[i]) { + val = nla_data(tb[i]); + len = nla_len(tb[i]); + + rc = add_metainfo(ife, i, val, len, exists); + if (rc) + return rc; + } + } + + return rc; +} + +static int tcf_ife_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, 
act_ife_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_IFE_MAX + 1]; + struct nlattr *tb2[IFE_META_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tcf_ife_params *p; + struct tcf_ife_info *ife; + u16 ife_type = ETH_P_IFE; + struct tc_ife *parm; + u8 *daddr = NULL; + u8 *saddr = NULL; + bool exists = false; + int ret = 0; + u32 index; + int err; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "IFE requires attributes to be passed"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_IFE_MAX, nla, ife_policy, + NULL); + if (err < 0) + return err; + + if (!tb[TCA_IFE_PARMS]) + return -EINVAL; + + parm = nla_data(tb[TCA_IFE_PARMS]); + + /* IFE_DECODE is 0 and indicates the opposite of IFE_ENCODE because + * they cannot run as the same time. Check on all other values which + * are not supported right now. + */ + if (parm->flags & ~IFE_ENCODE) + return -EINVAL; + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (!p) + return -ENOMEM; + + if (tb[TCA_IFE_METALST]) { + err = nla_parse_nested_deprecated(tb2, IFE_META_MAX, + tb[TCA_IFE_METALST], NULL, + NULL); + if (err) { + kfree(p); + return err; + } + err = load_metalist(tb2, !(flags & TCA_ACT_FLAGS_NO_RTNL)); + if (err) { + kfree(p); + return err; + } + } + + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) { + kfree(p); + return err; + } + exists = err; + if (exists && bind) { + kfree(p); + return 0; + } + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, &act_ife_ops, + bind, true, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + kfree(p); + return ret; + } + ret = ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + kfree(p); + return -EEXIST; + } + + ife = to_ife(*a); + if (ret == ACT_P_CREATED) + INIT_LIST_HEAD(&ife->metalist); + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + p->flags = parm->flags; + + if (parm->flags & IFE_ENCODE) { + if (tb[TCA_IFE_TYPE]) + ife_type = nla_get_u16(tb[TCA_IFE_TYPE]); + if (tb[TCA_IFE_DMAC]) + daddr = nla_data(tb[TCA_IFE_DMAC]); + if (tb[TCA_IFE_SMAC]) + saddr = nla_data(tb[TCA_IFE_SMAC]); + } + + if (parm->flags & IFE_ENCODE) { + if (daddr) + ether_addr_copy(p->eth_dst, daddr); + else + eth_zero_addr(p->eth_dst); + + if (saddr) + ether_addr_copy(p->eth_src, saddr); + else + eth_zero_addr(p->eth_src); + + p->eth_type = ife_type; + } + + if (tb[TCA_IFE_METALST]) { + err = populate_metalist(ife, tb2, exists, + !(flags & TCA_ACT_FLAGS_NO_RTNL)); + if (err) + goto metadata_parse_err; + } else { + /* if no passed metadata allow list or passed allow-all + * then here we process by adding as many supported metadatum + * as we can. 
You better have at least one else we are + * going to bail out + */ + err = use_all_metadata(ife, exists); + if (err) + goto metadata_parse_err; + } + + if (exists) + spin_lock_bh(&ife->tcf_lock); + /* protected by tcf_lock when modifying existing action */ + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + p = rcu_replace_pointer(ife->params, p, 1); + + if (exists) + spin_unlock_bh(&ife->tcf_lock); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (p) + kfree_rcu(p, rcu); + + return ret; +metadata_parse_err: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + kfree(p); + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_ife_dump(struct sk_buff *skb, struct tc_action *a, int bind, + int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_ife_info *ife = to_ife(a); + struct tcf_ife_params *p; + struct tc_ife opt = { + .index = ife->tcf_index, + .refcnt = refcount_read(&ife->tcf_refcnt) - ref, + .bindcnt = atomic_read(&ife->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&ife->tcf_lock); + opt.action = ife->tcf_action; + p = rcu_dereference_protected(ife->params, + lockdep_is_held(&ife->tcf_lock)); + opt.flags = p->flags; + + if (nla_put(skb, TCA_IFE_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &ife->tcf_tm); + if (nla_put_64bit(skb, TCA_IFE_TM, sizeof(t), &t, TCA_IFE_PAD)) + goto nla_put_failure; + + if (!is_zero_ether_addr(p->eth_dst)) { + if (nla_put(skb, TCA_IFE_DMAC, ETH_ALEN, p->eth_dst)) + goto nla_put_failure; + } + + if (!is_zero_ether_addr(p->eth_src)) { + if (nla_put(skb, TCA_IFE_SMAC, ETH_ALEN, p->eth_src)) + goto nla_put_failure; + } + + if (nla_put(skb, TCA_IFE_TYPE, 2, &p->eth_type)) + goto nla_put_failure; + + if (dump_metalist(skb, ife)) { + /*ignore failure to dump metalist */ + pr_info("Failed to dump metalist\n"); + } + + spin_unlock_bh(&ife->tcf_lock); + return skb->len; + +nla_put_failure: + spin_unlock_bh(&ife->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static int find_decode_metaid(struct sk_buff *skb, struct tcf_ife_info *ife, + u16 metaid, u16 mlen, void *mdata) +{ + struct tcf_meta_info *e; + + /* XXX: use hash to speed up */ + list_for_each_entry(e, &ife->metalist, metalist) { + if (metaid == e->metaid) { + if (e->ops) { + /* We check for decode presence already */ + return e->ops->decode(skb, mdata, mlen); + } + } + } + + return -ENOENT; +} + +static int tcf_ife_decode(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_ife_info *ife = to_ife(a); + int action = ife->tcf_action; + u8 *ifehdr_end; + u8 *tlv_data; + u16 metalen; + + bstats_update(this_cpu_ptr(ife->common.cpu_bstats), skb); + tcf_lastuse_update(&ife->tcf_tm); + + if (skb_at_tc_ingress(skb)) + skb_push(skb, skb->dev->hard_header_len); + + tlv_data = ife_decode(skb, &metalen); + if (unlikely(!tlv_data)) { + qstats_drop_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return TC_ACT_SHOT; + } + + ifehdr_end = tlv_data + metalen; + for (; tlv_data < ifehdr_end; tlv_data = ife_tlv_meta_next(tlv_data)) { + u8 *curr_data; + u16 mtype; + u16 dlen; + + curr_data = ife_tlv_meta_decode(tlv_data, ifehdr_end, &mtype, + &dlen, NULL); + if (!curr_data) { + qstats_drop_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return TC_ACT_SHOT; + } + + if (find_decode_metaid(skb, ife, mtype, dlen, curr_data)) { + /* abuse overlimits to count when we receive metadata + * but dont have an ops for it + */ + pr_info_ratelimited("Unknown metaid %d dlen %d\n", + mtype, dlen); + 
qstats_overlimit_inc(this_cpu_ptr(ife->common.cpu_qstats)); + } + } + + if (WARN_ON(tlv_data != ifehdr_end)) { + qstats_drop_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return TC_ACT_SHOT; + } + + skb->protocol = eth_type_trans(skb, skb->dev); + skb_reset_network_header(skb); + + return action; +} + +/*XXX: check if we can do this at install time instead of current + * send data path +**/ +static int ife_get_sz(struct sk_buff *skb, struct tcf_ife_info *ife) +{ + struct tcf_meta_info *e, *n; + int tot_run_sz = 0, run_sz = 0; + + list_for_each_entry_safe(e, n, &ife->metalist, metalist) { + if (e->ops->check_presence) { + run_sz = e->ops->check_presence(skb, e); + tot_run_sz += run_sz; + } + } + + return tot_run_sz; +} + +static int tcf_ife_encode(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res, struct tcf_ife_params *p) +{ + struct tcf_ife_info *ife = to_ife(a); + int action = ife->tcf_action; + struct ethhdr *oethh; /* outer ether header */ + struct tcf_meta_info *e; + /* + OUTERHDR:TOTMETALEN:{TLVHDR:Metadatum:TLVHDR..}:ORIGDATA + where ORIGDATA = original ethernet header ... + */ + u16 metalen = ife_get_sz(skb, ife); + int hdrm = metalen + skb->dev->hard_header_len + IFE_METAHDRLEN; + unsigned int skboff = 0; + int new_len = skb->len + hdrm; + bool exceed_mtu = false; + void *ife_meta; + int err = 0; + + if (!skb_at_tc_ingress(skb)) { + if (new_len > skb->dev->mtu) + exceed_mtu = true; + } + + bstats_update(this_cpu_ptr(ife->common.cpu_bstats), skb); + tcf_lastuse_update(&ife->tcf_tm); + + if (!metalen) { /* no metadata to send */ + /* abuse overlimits to count when we allow packet + * with no metadata + */ + qstats_overlimit_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return action; + } + /* could be stupid policy setup or mtu config + * so lets be conservative.. */ + if ((action == TC_ACT_SHOT) || exceed_mtu) { + qstats_drop_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return TC_ACT_SHOT; + } + + if (skb_at_tc_ingress(skb)) + skb_push(skb, skb->dev->hard_header_len); + + ife_meta = ife_encode(skb, metalen); + + spin_lock(&ife->tcf_lock); + + /* XXX: we dont have a clever way of telling encode to + * not repeat some of the computations that are done by + * ops->presence_check... 
+ */ + list_for_each_entry(e, &ife->metalist, metalist) { + if (e->ops->encode) { + err = e->ops->encode(skb, (void *)(ife_meta + skboff), + e); + } + if (err < 0) { + /* too corrupt to keep around if overwritten */ + spin_unlock(&ife->tcf_lock); + qstats_drop_inc(this_cpu_ptr(ife->common.cpu_qstats)); + return TC_ACT_SHOT; + } + skboff += err; + } + spin_unlock(&ife->tcf_lock); + oethh = (struct ethhdr *)skb->data; + + if (!is_zero_ether_addr(p->eth_src)) + ether_addr_copy(oethh->h_source, p->eth_src); + if (!is_zero_ether_addr(p->eth_dst)) + ether_addr_copy(oethh->h_dest, p->eth_dst); + oethh->h_proto = htons(p->eth_type); + + if (skb_at_tc_ingress(skb)) + skb_pull(skb, skb->dev->hard_header_len); + + return action; +} + +static int tcf_ife_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_ife_info *ife = to_ife(a); + struct tcf_ife_params *p; + int ret; + + p = rcu_dereference_bh(ife->params); + if (p->flags & IFE_ENCODE) { + ret = tcf_ife_encode(skb, a, res, p); + return ret; + } + + return tcf_ife_decode(skb, a, res); +} + +static struct tc_action_ops act_ife_ops = { + .kind = "ife", + .id = TCA_ID_IFE, + .owner = THIS_MODULE, + .act = tcf_ife_act, + .dump = tcf_ife_dump, + .cleanup = tcf_ife_cleanup, + .init = tcf_ife_init, + .size = sizeof(struct tcf_ife_info), +}; + +static __net_init int ife_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_ife_ops.net_id); + + return tc_action_net_init(net, tn, &act_ife_ops); +} + +static void __net_exit ife_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_ife_ops.net_id); +} + +static struct pernet_operations ife_net_ops = { + .init = ife_init_net, + .exit_batch = ife_exit_net, + .id = &act_ife_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init ife_init_module(void) +{ + return tcf_register_action(&act_ife_ops, &ife_net_ops); +} + +static void __exit ife_cleanup_module(void) +{ + tcf_unregister_action(&act_ife_ops, &ife_net_ops); +} + +module_init(ife_init_module); +module_exit(ife_cleanup_module); + +MODULE_AUTHOR("Jamal Hadi Salim(2015)"); +MODULE_DESCRIPTION("Inter-FE LFB action"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/act_ipt.c b/net/sched/act_ipt.c new file mode 100644 index 000000000..29974de68 --- /dev/null +++ b/net/sched/act_ipt.c @@ -0,0 +1,452 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_ipt.c iptables target interface + * + *TODO: Add other tables. 
For now we only support the ipv4 table targets + * + * Copyright: Jamal Hadi Salim (2002-13) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + +static struct tc_action_ops act_ipt_ops; +static struct tc_action_ops act_xt_ops; + +static int ipt_init_target(struct net *net, struct xt_entry_target *t, + char *table, unsigned int hook) +{ + struct xt_tgchk_param par; + struct xt_target *target; + struct ipt_entry e = {}; + int ret = 0; + + target = xt_request_find_target(AF_INET, t->u.user.name, + t->u.user.revision); + if (IS_ERR(target)) + return PTR_ERR(target); + + t->u.kernel.target = target; + memset(&par, 0, sizeof(par)); + par.net = net; + par.table = table; + par.entryinfo = &e; + par.target = target; + par.targinfo = t->data; + par.hook_mask = 1 << hook; + par.family = NFPROTO_IPV4; + + ret = xt_check_target(&par, t->u.target_size - sizeof(*t), 0, false); + if (ret < 0) { + module_put(t->u.kernel.target->me); + return ret; + } + return 0; +} + +static void ipt_destroy_target(struct xt_entry_target *t, struct net *net) +{ + struct xt_tgdtor_param par = { + .target = t->u.kernel.target, + .targinfo = t->data, + .family = NFPROTO_IPV4, + .net = net, + }; + if (par.target->destroy != NULL) + par.target->destroy(&par); + module_put(par.target->me); +} + +static void tcf_ipt_release(struct tc_action *a) +{ + struct tcf_ipt *ipt = to_ipt(a); + + if (ipt->tcfi_t) { + ipt_destroy_target(ipt->tcfi_t, a->idrinfo->net); + kfree(ipt->tcfi_t); + } + kfree(ipt->tcfi_tname); +} + +static const struct nla_policy ipt_policy[TCA_IPT_MAX + 1] = { + [TCA_IPT_TABLE] = { .type = NLA_STRING, .len = IFNAMSIZ }, + [TCA_IPT_HOOK] = NLA_POLICY_RANGE(NLA_U32, NF_INET_PRE_ROUTING, + NF_INET_NUMHOOKS), + [TCA_IPT_INDEX] = { .type = NLA_U32 }, + [TCA_IPT_TARG] = { .len = sizeof(struct xt_entry_target) }, +}; + +static int __tcf_ipt_init(struct net *net, unsigned int id, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + const struct tc_action_ops *ops, + struct tcf_proto *tp, u32 flags) +{ + struct tc_action_net *tn = net_generic(net, id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_IPT_MAX + 1]; + struct tcf_ipt *ipt; + struct xt_entry_target *td, *t; + char *tname; + bool exists = false; + int ret = 0, err; + u32 hook = 0; + u32 index = 0; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_IPT_MAX, nla, ipt_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_IPT_INDEX] != NULL) + index = nla_get_u32(tb[TCA_IPT_INDEX]); + + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (tb[TCA_IPT_HOOK] == NULL || tb[TCA_IPT_TARG] == NULL) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + + td = (struct xt_entry_target *)nla_data(tb[TCA_IPT_TARG]); + if (nla_len(tb[TCA_IPT_TARG]) != td->u.target_size) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, ops, bind, + false, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else { + if (bind)/* dont override defaults */ + return 0; + + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } + + err = -EINVAL; + hook = nla_get_u32(tb[TCA_IPT_HOOK]); + switch (hook) { + 
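+	/* Only the IPv4 PRE_ROUTING and POST_ROUTING hooks are accepted; any
+	 * other hook value falls through to err1 and the action is rejected
+	 * with the -EINVAL set above.
+	 */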
case NF_INET_PRE_ROUTING: + break; + case NF_INET_POST_ROUTING: + break; + default: + goto err1; + } + + if (tb[TCA_IPT_TABLE]) { + /* mangle only for now */ + if (nla_strcmp(tb[TCA_IPT_TABLE], "mangle")) + goto err1; + } + + tname = kstrdup("mangle", GFP_KERNEL); + if (unlikely(!tname)) + goto err1; + + t = kmemdup(td, td->u.target_size, GFP_KERNEL); + if (unlikely(!t)) + goto err2; + + err = ipt_init_target(net, t, tname, hook); + if (err < 0) + goto err3; + + ipt = to_ipt(*a); + + spin_lock_bh(&ipt->tcf_lock); + if (ret != ACT_P_CREATED) { + ipt_destroy_target(ipt->tcfi_t, net); + kfree(ipt->tcfi_tname); + kfree(ipt->tcfi_t); + } + ipt->tcfi_tname = tname; + ipt->tcfi_t = t; + ipt->tcfi_hook = hook; + spin_unlock_bh(&ipt->tcf_lock); + return ret; + +err3: + kfree(t); +err2: + kfree(tname); +err1: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_ipt_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + return __tcf_ipt_init(net, act_ipt_ops.net_id, nla, est, + a, &act_ipt_ops, tp, flags); +} + +static int tcf_xt_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + return __tcf_ipt_init(net, act_xt_ops.net_id, nla, est, + a, &act_xt_ops, tp, flags); +} + +static bool tcf_ipt_act_check(struct sk_buff *skb) +{ + const struct iphdr *iph; + unsigned int nhoff, len; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + return false; + + nhoff = skb_network_offset(skb); + iph = ip_hdr(skb); + if (iph->ihl < 5 || iph->version != 4) + return false; + + len = skb_ip_totlen(skb); + if (skb->len < nhoff + len || len < (iph->ihl * 4u)) + return false; + + return pskb_may_pull(skb, iph->ihl * 4u); +} + +static int tcf_ipt_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + int ret = 0, result = 0; + struct tcf_ipt *ipt = to_ipt(a); + struct xt_action_param par; + struct nf_hook_state state = { + .net = dev_net(skb->dev), + .in = skb->dev, + .hook = ipt->tcfi_hook, + .pf = NFPROTO_IPV4, + }; + + if (skb_protocol(skb, false) != htons(ETH_P_IP)) + return TC_ACT_UNSPEC; + + if (skb_unclone(skb, GFP_ATOMIC)) + return TC_ACT_UNSPEC; + + if (!tcf_ipt_act_check(skb)) + return TC_ACT_UNSPEC; + + if (state.hook == NF_INET_POST_ROUTING) { + if (!skb_dst(skb)) + return TC_ACT_UNSPEC; + + state.out = skb->dev; + } + + spin_lock(&ipt->tcf_lock); + + tcf_lastuse_update(&ipt->tcf_tm); + bstats_update(&ipt->tcf_bstats, skb); + + /* yes, we have to worry about both in and out dev + * worry later - danger - this API seems to have changed + * from earlier kernels + */ + par.state = &state; + par.target = ipt->tcfi_t->u.kernel.target; + par.targinfo = ipt->tcfi_t->data; + ret = par.target->target(skb, &par); + + switch (ret) { + case NF_ACCEPT: + result = TC_ACT_OK; + break; + case NF_DROP: + result = TC_ACT_SHOT; + ipt->tcf_qstats.drops++; + break; + case XT_CONTINUE: + result = TC_ACT_PIPE; + break; + default: + net_notice_ratelimited("tc filter: Bogus netfilter code %d assume ACCEPT\n", + ret); + result = TC_ACT_OK; + break; + } + spin_unlock(&ipt->tcf_lock); + return result; + +} + +static int tcf_ipt_dump(struct sk_buff *skb, struct tc_action *a, int bind, + int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_ipt *ipt = to_ipt(a); + struct xt_entry_target *t; + struct tcf_t tm; + struct tc_cnt c; + + /* for simple targets kernel size == user size 
+ * user name = target name + * for foolproof you need to not assume this + */ + + spin_lock_bh(&ipt->tcf_lock); + t = kmemdup(ipt->tcfi_t, ipt->tcfi_t->u.user.target_size, GFP_ATOMIC); + if (unlikely(!t)) + goto nla_put_failure; + + c.bindcnt = atomic_read(&ipt->tcf_bindcnt) - bind; + c.refcnt = refcount_read(&ipt->tcf_refcnt) - ref; + strcpy(t->u.user.name, ipt->tcfi_t->u.kernel.target->name); + + if (nla_put(skb, TCA_IPT_TARG, ipt->tcfi_t->u.user.target_size, t) || + nla_put_u32(skb, TCA_IPT_INDEX, ipt->tcf_index) || + nla_put_u32(skb, TCA_IPT_HOOK, ipt->tcfi_hook) || + nla_put(skb, TCA_IPT_CNT, sizeof(struct tc_cnt), &c) || + nla_put_string(skb, TCA_IPT_TABLE, ipt->tcfi_tname)) + goto nla_put_failure; + + tcf_tm_dump(&tm, &ipt->tcf_tm); + if (nla_put_64bit(skb, TCA_IPT_TM, sizeof(tm), &tm, TCA_IPT_PAD)) + goto nla_put_failure; + + spin_unlock_bh(&ipt->tcf_lock); + kfree(t); + return skb->len; + +nla_put_failure: + spin_unlock_bh(&ipt->tcf_lock); + nlmsg_trim(skb, b); + kfree(t); + return -1; +} + +static struct tc_action_ops act_ipt_ops = { + .kind = "ipt", + .id = TCA_ID_IPT, + .owner = THIS_MODULE, + .act = tcf_ipt_act, + .dump = tcf_ipt_dump, + .cleanup = tcf_ipt_release, + .init = tcf_ipt_init, + .size = sizeof(struct tcf_ipt), +}; + +static __net_init int ipt_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_ipt_ops.net_id); + + return tc_action_net_init(net, tn, &act_ipt_ops); +} + +static void __net_exit ipt_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_ipt_ops.net_id); +} + +static struct pernet_operations ipt_net_ops = { + .init = ipt_init_net, + .exit_batch = ipt_exit_net, + .id = &act_ipt_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static struct tc_action_ops act_xt_ops = { + .kind = "xt", + .id = TCA_ID_XT, + .owner = THIS_MODULE, + .act = tcf_ipt_act, + .dump = tcf_ipt_dump, + .cleanup = tcf_ipt_release, + .init = tcf_xt_init, + .size = sizeof(struct tcf_ipt), +}; + +static __net_init int xt_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_xt_ops.net_id); + + return tc_action_net_init(net, tn, &act_xt_ops); +} + +static void __net_exit xt_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_xt_ops.net_id); +} + +static struct pernet_operations xt_net_ops = { + .init = xt_init_net, + .exit_batch = xt_exit_net, + .id = &act_xt_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim(2002-13)"); +MODULE_DESCRIPTION("Iptables target actions"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("act_xt"); + +static int __init ipt_init_module(void) +{ + int ret1, ret2; + + ret1 = tcf_register_action(&act_xt_ops, &xt_net_ops); + if (ret1 < 0) + pr_err("Failed to load xt action\n"); + + ret2 = tcf_register_action(&act_ipt_ops, &ipt_net_ops); + if (ret2 < 0) + pr_err("Failed to load ipt action\n"); + + if (ret1 < 0 && ret2 < 0) { + return ret1; + } else + return 0; +} + +static void __exit ipt_cleanup_module(void) +{ + tcf_unregister_action(&act_ipt_ops, &ipt_net_ops); + tcf_unregister_action(&act_xt_ops, &xt_net_ops); +} + +module_init(ipt_init_module); +module_exit(ipt_cleanup_module); diff --git a/net/sched/act_meta_mark.c b/net/sched/act_meta_mark.c new file mode 100644 index 000000000..ea0573cb8 --- /dev/null +++ b/net/sched/act_meta_mark.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_meta_mark.c IFE skb->mark metadata module + * + * copyright Jamal Hadi Salim (2015) +*/ + +#include +#include 
+#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int skbmark_encode(struct sk_buff *skb, void *skbdata, + struct tcf_meta_info *e) +{ + u32 ifemark = skb->mark; + + return ife_encode_meta_u32(ifemark, skbdata, e); +} + +static int skbmark_decode(struct sk_buff *skb, void *data, u16 len) +{ + u32 ifemark = *(u32 *)data; + + skb->mark = ntohl(ifemark); + return 0; +} + +static int skbmark_check(struct sk_buff *skb, struct tcf_meta_info *e) +{ + return ife_check_meta_u32(skb->mark, e); +} + +static struct tcf_meta_ops ife_skbmark_ops = { + .metaid = IFE_META_SKBMARK, + .metatype = NLA_U32, + .name = "skbmark", + .synopsis = "skb mark 32 bit metadata", + .check_presence = skbmark_check, + .encode = skbmark_encode, + .decode = skbmark_decode, + .get = ife_get_meta_u32, + .alloc = ife_alloc_meta_u32, + .release = ife_release_meta_gen, + .validate = ife_validate_meta_u32, + .owner = THIS_MODULE, +}; + +static int __init ifemark_init_module(void) +{ + return register_ife_op(&ife_skbmark_ops); +} + +static void __exit ifemark_cleanup_module(void) +{ + unregister_ife_op(&ife_skbmark_ops); +} + +module_init(ifemark_init_module); +module_exit(ifemark_cleanup_module); + +MODULE_AUTHOR("Jamal Hadi Salim(2015)"); +MODULE_DESCRIPTION("Inter-FE skb mark metadata module"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_IFE_META("skbmark"); diff --git a/net/sched/act_meta_skbprio.c b/net/sched/act_meta_skbprio.c new file mode 100644 index 000000000..2df3133ce --- /dev/null +++ b/net/sched/act_meta_skbprio.c @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_meta_prio.c IFE skb->priority metadata module + * + * copyright Jamal Hadi Salim (2015) +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int skbprio_check(struct sk_buff *skb, struct tcf_meta_info *e) +{ + return ife_check_meta_u32(skb->priority, e); +} + +static int skbprio_encode(struct sk_buff *skb, void *skbdata, + struct tcf_meta_info *e) +{ + u32 ifeprio = skb->priority; /* avoid having to cast skb->priority*/ + + return ife_encode_meta_u32(ifeprio, skbdata, e); +} + +static int skbprio_decode(struct sk_buff *skb, void *data, u16 len) +{ + u32 ifeprio = *(u32 *)data; + + skb->priority = ntohl(ifeprio); + return 0; +} + +static struct tcf_meta_ops ife_prio_ops = { + .metaid = IFE_META_PRIO, + .metatype = NLA_U32, + .name = "skbprio", + .synopsis = "skb prio metadata", + .check_presence = skbprio_check, + .encode = skbprio_encode, + .decode = skbprio_decode, + .get = ife_get_meta_u32, + .alloc = ife_alloc_meta_u32, + .owner = THIS_MODULE, +}; + +static int __init ifeprio_init_module(void) +{ + return register_ife_op(&ife_prio_ops); +} + +static void __exit ifeprio_cleanup_module(void) +{ + unregister_ife_op(&ife_prio_ops); +} + +module_init(ifeprio_init_module); +module_exit(ifeprio_cleanup_module); + +MODULE_AUTHOR("Jamal Hadi Salim(2015)"); +MODULE_DESCRIPTION("Inter-FE skb prio metadata action"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_IFE_META("skbprio"); diff --git a/net/sched/act_meta_skbtcindex.c b/net/sched/act_meta_skbtcindex.c new file mode 100644 index 000000000..44547caea --- /dev/null +++ b/net/sched/act_meta_skbtcindex.c @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_meta_tc_index.c IFE skb->tc_index metadata module + * + * copyright Jamal Hadi Salim (2016) +*/ + +#include +#include +#include +#include +#include +#include +#include 
+#include +#include +#include +#include +#include + +static int skbtcindex_encode(struct sk_buff *skb, void *skbdata, + struct tcf_meta_info *e) +{ + u32 ifetc_index = skb->tc_index; + + return ife_encode_meta_u16(ifetc_index, skbdata, e); +} + +static int skbtcindex_decode(struct sk_buff *skb, void *data, u16 len) +{ + u16 ifetc_index = *(u16 *)data; + + skb->tc_index = ntohs(ifetc_index); + return 0; +} + +static int skbtcindex_check(struct sk_buff *skb, struct tcf_meta_info *e) +{ + return ife_check_meta_u16(skb->tc_index, e); +} + +static struct tcf_meta_ops ife_skbtcindex_ops = { + .metaid = IFE_META_TCINDEX, + .metatype = NLA_U16, + .name = "tc_index", + .synopsis = "skb tc_index 16 bit metadata", + .check_presence = skbtcindex_check, + .encode = skbtcindex_encode, + .decode = skbtcindex_decode, + .get = ife_get_meta_u16, + .alloc = ife_alloc_meta_u16, + .release = ife_release_meta_gen, + .validate = ife_validate_meta_u16, + .owner = THIS_MODULE, +}; + +static int __init ifetc_index_init_module(void) +{ + return register_ife_op(&ife_skbtcindex_ops); +} + +static void __exit ifetc_index_cleanup_module(void) +{ + unregister_ife_op(&ife_skbtcindex_ops); +} + +module_init(ifetc_index_init_module); +module_exit(ifetc_index_cleanup_module); + +MODULE_AUTHOR("Jamal Hadi Salim(2016)"); +MODULE_DESCRIPTION("Inter-FE skb tc_index metadata module"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_IFE_META("tcindex"); diff --git a/net/sched/act_mirred.c b/net/sched/act_mirred.c new file mode 100644 index 000000000..36395e5db --- /dev/null +++ b/net/sched/act_mirred.c @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_mirred.c packet mirroring and redirect actions + * + * Authors: Jamal Hadi Salim (2002-4) + * + * TODO: Add ingress support (and socket redirect support) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(mirred_list); +static DEFINE_SPINLOCK(mirred_list_lock); + +#define MIRRED_NEST_LIMIT 4 +static DEFINE_PER_CPU(unsigned int, mirred_nest_level); + +static bool tcf_mirred_is_act_redirect(int action) +{ + return action == TCA_EGRESS_REDIR || action == TCA_INGRESS_REDIR; +} + +static bool tcf_mirred_act_wants_ingress(int action) +{ + switch (action) { + case TCA_EGRESS_REDIR: + case TCA_EGRESS_MIRROR: + return false; + case TCA_INGRESS_REDIR: + case TCA_INGRESS_MIRROR: + return true; + default: + BUG(); + } +} + +static bool tcf_mirred_can_reinsert(int action) +{ + switch (action) { + case TC_ACT_SHOT: + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + return true; + } + return false; +} + +static struct net_device *tcf_mirred_dev_dereference(struct tcf_mirred *m) +{ + return rcu_dereference_protected(m->tcfm_dev, + lockdep_is_held(&m->tcf_lock)); +} + +static void tcf_mirred_release(struct tc_action *a) +{ + struct tcf_mirred *m = to_mirred(a); + struct net_device *dev; + + spin_lock(&mirred_list_lock); + list_del(&m->tcfm_list); + spin_unlock(&mirred_list_lock); + + /* last reference to action, no need to lock */ + dev = rcu_dereference_protected(m->tcfm_dev, 1); + netdev_put(dev, &m->tcfm_dev_tracker); +} + +static const struct nla_policy mirred_policy[TCA_MIRRED_MAX + 1] = { + [TCA_MIRRED_PARMS] = { .len = sizeof(struct tc_mirred) }, +}; + +static struct tc_action_ops act_mirred_ops; + +static int tcf_mirred_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + 
struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_mirred_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_MIRRED_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + bool mac_header_xmit = false; + struct tc_mirred *parm; + struct tcf_mirred *m; + bool exists = false; + int ret, err; + u32 index; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "Mirred requires attributes to be passed"); + return -EINVAL; + } + ret = nla_parse_nested_deprecated(tb, TCA_MIRRED_MAX, nla, + mirred_policy, extack); + if (ret < 0) + return ret; + if (!tb[TCA_MIRRED_PARMS]) { + NL_SET_ERR_MSG_MOD(extack, "Missing required mirred parameters"); + return -EINVAL; + } + parm = nla_data(tb[TCA_MIRRED_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + switch (parm->eaction) { + case TCA_EGRESS_MIRROR: + case TCA_EGRESS_REDIR: + case TCA_INGRESS_REDIR: + case TCA_INGRESS_MIRROR: + break; + default: + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + NL_SET_ERR_MSG_MOD(extack, "Unknown mirred option"); + return -EINVAL; + } + + if (!exists) { + if (!parm->ifindex) { + tcf_idr_cleanup(tn, index); + NL_SET_ERR_MSG_MOD(extack, "Specified device does not exist"); + return -EINVAL; + } + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_mirred_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + + m = to_mirred(*a); + if (ret == ACT_P_CREATED) + INIT_LIST_HEAD(&m->tcfm_list); + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + spin_lock_bh(&m->tcf_lock); + + if (parm->ifindex) { + struct net_device *odev, *ndev; + + ndev = dev_get_by_index(net, parm->ifindex); + if (!ndev) { + spin_unlock_bh(&m->tcf_lock); + err = -ENODEV; + goto put_chain; + } + mac_header_xmit = dev_is_mac_header_xmit(ndev); + odev = rcu_replace_pointer(m->tcfm_dev, ndev, + lockdep_is_held(&m->tcf_lock)); + netdev_put(odev, &m->tcfm_dev_tracker); + netdev_tracker_alloc(ndev, &m->tcfm_dev_tracker, GFP_ATOMIC); + m->tcfm_mac_header_xmit = mac_header_xmit; + } + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + m->tcfm_eaction = parm->eaction; + spin_unlock_bh(&m->tcf_lock); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + if (ret == ACT_P_CREATED) { + spin_lock(&mirred_list_lock); + list_add(&m->tcfm_list, &mirred_list); + spin_unlock(&mirred_list_lock); + } + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static bool is_mirred_nested(void) +{ + return unlikely(__this_cpu_read(mirred_nest_level) > 1); +} + +static int tcf_mirred_forward(bool want_ingress, struct sk_buff *skb) +{ + int err; + + if (!want_ingress) + err = tcf_dev_queue_xmit(skb, dev_queue_xmit); + else if (is_mirred_nested()) + err = netif_rx(skb); + else + err = netif_receive_skb(skb); + + return err; +} + +static int tcf_mirred_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_mirred *m = to_mirred(a); + struct sk_buff *skb2 = skb; + bool m_mac_header_xmit; + struct net_device *dev; + unsigned int nest_level; + int retval, err = 0; + bool use_reinsert; + bool want_ingress; + bool is_redirect; + bool 
expects_nh; + bool at_ingress; + int m_eaction; + int mac_len; + bool at_nh; + + nest_level = __this_cpu_inc_return(mirred_nest_level); + if (unlikely(nest_level > MIRRED_NEST_LIMIT)) { + net_warn_ratelimited("Packet exceeded mirred recursion limit on dev %s\n", + netdev_name(skb->dev)); + __this_cpu_dec(mirred_nest_level); + return TC_ACT_SHOT; + } + + tcf_lastuse_update(&m->tcf_tm); + tcf_action_update_bstats(&m->common, skb); + + m_mac_header_xmit = READ_ONCE(m->tcfm_mac_header_xmit); + m_eaction = READ_ONCE(m->tcfm_eaction); + retval = READ_ONCE(m->tcf_action); + dev = rcu_dereference_bh(m->tcfm_dev); + if (unlikely(!dev)) { + pr_notice_once("tc mirred: target device is gone\n"); + goto out; + } + + if (unlikely(!(dev->flags & IFF_UP)) || !netif_carrier_ok(dev)) { + net_notice_ratelimited("tc mirred to Houston: device %s is down\n", + dev->name); + goto out; + } + + /* we could easily avoid the clone only if called by ingress and clsact; + * since we can't easily detect the clsact caller, skip clone only for + * ingress - that covers the TC S/W datapath. + */ + is_redirect = tcf_mirred_is_act_redirect(m_eaction); + at_ingress = skb_at_tc_ingress(skb); + use_reinsert = at_ingress && is_redirect && + tcf_mirred_can_reinsert(retval); + if (!use_reinsert) { + skb2 = skb_clone(skb, GFP_ATOMIC); + if (!skb2) + goto out; + } + + want_ingress = tcf_mirred_act_wants_ingress(m_eaction); + + /* All mirred/redirected skbs should clear previous ct info */ + nf_reset_ct(skb2); + if (want_ingress && !at_ingress) /* drop dst for egress -> ingress */ + skb_dst_drop(skb2); + + expects_nh = want_ingress || !m_mac_header_xmit; + at_nh = skb->data == skb_network_header(skb); + if (at_nh != expects_nh) { + mac_len = skb_at_tc_ingress(skb) ? skb->mac_len : + skb_network_header(skb) - skb_mac_header(skb); + if (expects_nh) { + /* target device/action expect data at nh */ + skb_pull_rcsum(skb2, mac_len); + } else { + /* target device/action expect data at mac */ + skb_push_rcsum(skb2, mac_len); + } + } + + skb2->skb_iif = skb->dev->ifindex; + skb2->dev = dev; + + /* mirror is always swallowed */ + if (is_redirect) { + skb_set_redirected(skb2, skb2->tc_at_ingress); + + /* let's the caller reinsert the packet, if possible */ + if (use_reinsert) { + err = tcf_mirred_forward(want_ingress, skb); + if (err) + tcf_action_inc_overlimit_qstats(&m->common); + __this_cpu_dec(mirred_nest_level); + return TC_ACT_CONSUMED; + } + } + + err = tcf_mirred_forward(want_ingress, skb2); + if (err) { +out: + tcf_action_inc_overlimit_qstats(&m->common); + if (tcf_mirred_is_act_redirect(m_eaction)) + retval = TC_ACT_SHOT; + } + __this_cpu_dec(mirred_nest_level); + + return retval; +} + +static void tcf_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_mirred *m = to_mirred(a); + struct tcf_t *tm = &m->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static int tcf_mirred_dump(struct sk_buff *skb, struct tc_action *a, int bind, + int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_mirred *m = to_mirred(a); + struct tc_mirred opt = { + .index = m->tcf_index, + .refcnt = refcount_read(&m->tcf_refcnt) - ref, + .bindcnt = atomic_read(&m->tcf_bindcnt) - bind, + }; + struct net_device *dev; + struct tcf_t t; + + spin_lock_bh(&m->tcf_lock); + opt.action = m->tcf_action; + opt.eaction = m->tcfm_eaction; + dev = tcf_mirred_dev_dereference(m); + if (dev) + opt.ifindex = dev->ifindex; + + if 
(nla_put(skb, TCA_MIRRED_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &m->tcf_tm); + if (nla_put_64bit(skb, TCA_MIRRED_TM, sizeof(t), &t, TCA_MIRRED_PAD)) + goto nla_put_failure; + spin_unlock_bh(&m->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&m->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static int mirred_device_event(struct notifier_block *unused, + unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct tcf_mirred *m; + + ASSERT_RTNL(); + if (event == NETDEV_UNREGISTER) { + spin_lock(&mirred_list_lock); + list_for_each_entry(m, &mirred_list, tcfm_list) { + spin_lock_bh(&m->tcf_lock); + if (tcf_mirred_dev_dereference(m) == dev) { + netdev_put(dev, &m->tcfm_dev_tracker); + /* Note : no rcu grace period necessary, as + * net_device are already rcu protected. + */ + RCU_INIT_POINTER(m->tcfm_dev, NULL); + } + spin_unlock_bh(&m->tcf_lock); + } + spin_unlock(&mirred_list_lock); + } + + return NOTIFY_DONE; +} + +static struct notifier_block mirred_device_notifier = { + .notifier_call = mirred_device_event, +}; + +static void tcf_mirred_dev_put(void *priv) +{ + struct net_device *dev = priv; + + dev_put(dev); +} + +static struct net_device * +tcf_mirred_get_dev(const struct tc_action *a, + tc_action_priv_destructor *destructor) +{ + struct tcf_mirred *m = to_mirred(a); + struct net_device *dev; + + rcu_read_lock(); + dev = rcu_dereference(m->tcfm_dev); + if (dev) { + dev_hold(dev); + *destructor = tcf_mirred_dev_put; + } + rcu_read_unlock(); + + return dev; +} + +static size_t tcf_mirred_get_fill_size(const struct tc_action *act) +{ + return nla_total_size(sizeof(struct tc_mirred)); +} + +static void tcf_offload_mirred_get_dev(struct flow_action_entry *entry, + const struct tc_action *act) +{ + entry->dev = act->ops->get_dev(act, &entry->destructor); + if (!entry->dev) + return; + entry->destructor_priv = entry->dev; +} + +static int tcf_mirred_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + if (is_tcf_mirred_egress_redirect(act)) { + entry->id = FLOW_ACTION_REDIRECT; + tcf_offload_mirred_get_dev(entry, act); + } else if (is_tcf_mirred_egress_mirror(act)) { + entry->id = FLOW_ACTION_MIRRED; + tcf_offload_mirred_get_dev(entry, act); + } else if (is_tcf_mirred_ingress_redirect(act)) { + entry->id = FLOW_ACTION_REDIRECT_INGRESS; + tcf_offload_mirred_get_dev(entry, act); + } else if (is_tcf_mirred_ingress_mirror(act)) { + entry->id = FLOW_ACTION_MIRRED_INGRESS; + tcf_offload_mirred_get_dev(entry, act); + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported mirred offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + if (is_tcf_mirred_egress_redirect(act)) + fl_action->id = FLOW_ACTION_REDIRECT; + else if (is_tcf_mirred_egress_mirror(act)) + fl_action->id = FLOW_ACTION_MIRRED; + else if (is_tcf_mirred_ingress_redirect(act)) + fl_action->id = FLOW_ACTION_REDIRECT_INGRESS; + else if (is_tcf_mirred_ingress_mirror(act)) + fl_action->id = FLOW_ACTION_MIRRED_INGRESS; + else + return -EOPNOTSUPP; + } + + return 0; +} + +static struct tc_action_ops act_mirred_ops = { + .kind = "mirred", + .id = TCA_ID_MIRRED, + .owner = THIS_MODULE, + .act = tcf_mirred_act, + .stats_update = tcf_stats_update, + .dump = tcf_mirred_dump, + .cleanup = tcf_mirred_release, + .init = tcf_mirred_init, + 
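+	/* Optional ops below: get_fill_size lets the action core pre-size
+	 * netlink dump messages, and get_dev exposes the mirror/redirect
+	 * target device (with a destructor) to offload users.
+	 */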
.get_fill_size = tcf_mirred_get_fill_size, + .offload_act_setup = tcf_mirred_offload_act_setup, + .size = sizeof(struct tcf_mirred), + .get_dev = tcf_mirred_get_dev, +}; + +static __net_init int mirred_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_mirred_ops.net_id); + + return tc_action_net_init(net, tn, &act_mirred_ops); +} + +static void __net_exit mirred_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_mirred_ops.net_id); +} + +static struct pernet_operations mirred_net_ops = { + .init = mirred_init_net, + .exit_batch = mirred_exit_net, + .id = &act_mirred_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim(2002)"); +MODULE_DESCRIPTION("Device Mirror/redirect actions"); +MODULE_LICENSE("GPL"); + +static int __init mirred_init_module(void) +{ + int err = register_netdevice_notifier(&mirred_device_notifier); + if (err) + return err; + + pr_info("Mirror/redirect action on\n"); + err = tcf_register_action(&act_mirred_ops, &mirred_net_ops); + if (err) + unregister_netdevice_notifier(&mirred_device_notifier); + + return err; +} + +static void __exit mirred_cleanup_module(void) +{ + tcf_unregister_action(&act_mirred_ops, &mirred_net_ops); + unregister_netdevice_notifier(&mirred_device_notifier); +} + +module_init(mirred_init_module); +module_exit(mirred_cleanup_module); diff --git a/net/sched/act_mpls.c b/net/sched/act_mpls.c new file mode 100644 index 000000000..f24f997a7 --- /dev/null +++ b/net/sched/act_mpls.c @@ -0,0 +1,489 @@ +// SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause) +/* Copyright (C) 2019 Netronome Systems, Inc. */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct tc_action_ops act_mpls_ops; + +#define ACT_MPLS_TTL_DEFAULT 255 + +static __be32 tcf_mpls_get_lse(struct mpls_shim_hdr *lse, + struct tcf_mpls_params *p, bool set_bos) +{ + u32 new_lse = 0; + + if (lse) + new_lse = be32_to_cpu(lse->label_stack_entry); + + if (p->tcfm_label != ACT_MPLS_LABEL_NOT_SET) { + new_lse &= ~MPLS_LS_LABEL_MASK; + new_lse |= p->tcfm_label << MPLS_LS_LABEL_SHIFT; + } + if (p->tcfm_ttl) { + new_lse &= ~MPLS_LS_TTL_MASK; + new_lse |= p->tcfm_ttl << MPLS_LS_TTL_SHIFT; + } + if (p->tcfm_tc != ACT_MPLS_TC_NOT_SET) { + new_lse &= ~MPLS_LS_TC_MASK; + new_lse |= p->tcfm_tc << MPLS_LS_TC_SHIFT; + } + if (p->tcfm_bos != ACT_MPLS_BOS_NOT_SET) { + new_lse &= ~MPLS_LS_S_MASK; + new_lse |= p->tcfm_bos << MPLS_LS_S_SHIFT; + } else if (set_bos) { + new_lse |= 1 << MPLS_LS_S_SHIFT; + } + + return cpu_to_be32(new_lse); +} + +static int tcf_mpls_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_mpls *m = to_mpls(a); + struct tcf_mpls_params *p; + __be32 new_lse; + int ret, mac_len; + + tcf_lastuse_update(&m->tcf_tm); + bstats_update(this_cpu_ptr(m->common.cpu_bstats), skb); + + /* Ensure 'data' points at mac_header prior calling mpls manipulating + * functions. 
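+	 * At TC ingress skb->data sits at the network header, so the MAC
+	 * header is pushed back on (checksum-adjusted) before the MPLS edit
+	 * and pulled off again once the edit is done.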
+ */ + if (skb_at_tc_ingress(skb)) { + skb_push_rcsum(skb, skb->mac_len); + mac_len = skb->mac_len; + } else { + mac_len = skb_network_header(skb) - skb_mac_header(skb); + } + + ret = READ_ONCE(m->tcf_action); + + p = rcu_dereference_bh(m->mpls_p); + + switch (p->tcfm_action) { + case TCA_MPLS_ACT_POP: + if (skb_mpls_pop(skb, p->tcfm_proto, mac_len, + skb->dev && skb->dev->type == ARPHRD_ETHER)) + goto drop; + break; + case TCA_MPLS_ACT_PUSH: + new_lse = tcf_mpls_get_lse(NULL, p, !eth_p_mpls(skb_protocol(skb, true))); + if (skb_mpls_push(skb, new_lse, p->tcfm_proto, mac_len, + skb->dev && skb->dev->type == ARPHRD_ETHER)) + goto drop; + break; + case TCA_MPLS_ACT_MAC_PUSH: + if (skb_vlan_tag_present(skb)) { + if (__vlan_insert_inner_tag(skb, skb->vlan_proto, + skb_vlan_tag_get(skb), + ETH_HLEN) < 0) + goto drop; + + skb->protocol = skb->vlan_proto; + __vlan_hwaccel_clear_tag(skb); + } + + new_lse = tcf_mpls_get_lse(NULL, p, mac_len || + !eth_p_mpls(skb->protocol)); + + if (skb_mpls_push(skb, new_lse, p->tcfm_proto, 0, false)) + goto drop; + break; + case TCA_MPLS_ACT_MODIFY: + if (!pskb_may_pull(skb, + skb_network_offset(skb) + MPLS_HLEN)) + goto drop; + new_lse = tcf_mpls_get_lse(mpls_hdr(skb), p, false); + if (skb_mpls_update_lse(skb, new_lse)) + goto drop; + break; + case TCA_MPLS_ACT_DEC_TTL: + if (skb_mpls_dec_ttl(skb)) + goto drop; + break; + } + + if (skb_at_tc_ingress(skb)) + skb_pull_rcsum(skb, skb->mac_len); + + return ret; + +drop: + qstats_drop_inc(this_cpu_ptr(m->common.cpu_qstats)); + return TC_ACT_SHOT; +} + +static int valid_label(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const u32 *label = nla_data(attr); + + if (nla_len(attr) != sizeof(*label)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid MPLS label length"); + return -EINVAL; + } + + if (*label & ~MPLS_LABEL_MASK || *label == MPLS_LABEL_IMPLNULL) { + NL_SET_ERR_MSG_MOD(extack, "MPLS label out of range"); + return -EINVAL; + } + + return 0; +} + +static const struct nla_policy mpls_policy[TCA_MPLS_MAX + 1] = { + [TCA_MPLS_PARMS] = NLA_POLICY_EXACT_LEN(sizeof(struct tc_mpls)), + [TCA_MPLS_PROTO] = { .type = NLA_U16 }, + [TCA_MPLS_LABEL] = NLA_POLICY_VALIDATE_FN(NLA_BINARY, + valid_label), + [TCA_MPLS_TC] = NLA_POLICY_RANGE(NLA_U8, 0, 7), + [TCA_MPLS_TTL] = NLA_POLICY_MIN(NLA_U8, 1), + [TCA_MPLS_BOS] = NLA_POLICY_RANGE(NLA_U8, 0, 1), +}; + +static int tcf_mpls_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_mpls_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_MPLS_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tcf_mpls_params *p; + struct tc_mpls *parm; + bool exists = false; + struct tcf_mpls *m; + int ret = 0, err; + u8 mpls_ttl = 0; + u32 index; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "Missing netlink attributes"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_MPLS_MAX, nla, mpls_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_MPLS_PARMS]) { + NL_SET_ERR_MSG_MOD(extack, "No MPLS params"); + return -EINVAL; + } + parm = nla_data(tb[TCA_MPLS_PARMS]); + index = parm->index; + + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, &act_mpls_ops, bind, + true, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + ret = 
ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + + /* Verify parameters against action type. */ + switch (parm->m_action) { + case TCA_MPLS_ACT_POP: + if (!tb[TCA_MPLS_PROTO]) { + NL_SET_ERR_MSG_MOD(extack, "Protocol must be set for MPLS pop"); + err = -EINVAL; + goto release_idr; + } + if (!eth_proto_is_802_3(nla_get_be16(tb[TCA_MPLS_PROTO]))) { + NL_SET_ERR_MSG_MOD(extack, "Invalid protocol type for MPLS pop"); + err = -EINVAL; + goto release_idr; + } + if (tb[TCA_MPLS_LABEL] || tb[TCA_MPLS_TTL] || tb[TCA_MPLS_TC] || + tb[TCA_MPLS_BOS]) { + NL_SET_ERR_MSG_MOD(extack, "Label, TTL, TC or BOS cannot be used with MPLS pop"); + err = -EINVAL; + goto release_idr; + } + break; + case TCA_MPLS_ACT_DEC_TTL: + if (tb[TCA_MPLS_PROTO] || tb[TCA_MPLS_LABEL] || + tb[TCA_MPLS_TTL] || tb[TCA_MPLS_TC] || tb[TCA_MPLS_BOS]) { + NL_SET_ERR_MSG_MOD(extack, "Label, TTL, TC, BOS or protocol cannot be used with MPLS dec_ttl"); + err = -EINVAL; + goto release_idr; + } + break; + case TCA_MPLS_ACT_PUSH: + case TCA_MPLS_ACT_MAC_PUSH: + if (!tb[TCA_MPLS_LABEL]) { + NL_SET_ERR_MSG_MOD(extack, "Label is required for MPLS push"); + err = -EINVAL; + goto release_idr; + } + if (tb[TCA_MPLS_PROTO] && + !eth_p_mpls(nla_get_be16(tb[TCA_MPLS_PROTO]))) { + NL_SET_ERR_MSG_MOD(extack, "Protocol must be an MPLS type for MPLS push"); + err = -EPROTONOSUPPORT; + goto release_idr; + } + /* Push needs a TTL - if not specified, set a default value. */ + if (!tb[TCA_MPLS_TTL]) { +#if IS_ENABLED(CONFIG_MPLS) + mpls_ttl = net->mpls.default_ttl ? + net->mpls.default_ttl : ACT_MPLS_TTL_DEFAULT; +#else + mpls_ttl = ACT_MPLS_TTL_DEFAULT; +#endif + } + break; + case TCA_MPLS_ACT_MODIFY: + if (tb[TCA_MPLS_PROTO]) { + NL_SET_ERR_MSG_MOD(extack, "Protocol cannot be used with MPLS modify"); + err = -EINVAL; + goto release_idr; + } + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unknown MPLS action"); + err = -EINVAL; + goto release_idr; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + m = to_mpls(*a); + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (!p) { + err = -ENOMEM; + goto put_chain; + } + + p->tcfm_action = parm->m_action; + p->tcfm_label = tb[TCA_MPLS_LABEL] ? nla_get_u32(tb[TCA_MPLS_LABEL]) : + ACT_MPLS_LABEL_NOT_SET; + p->tcfm_tc = tb[TCA_MPLS_TC] ? nla_get_u8(tb[TCA_MPLS_TC]) : + ACT_MPLS_TC_NOT_SET; + p->tcfm_ttl = tb[TCA_MPLS_TTL] ? nla_get_u8(tb[TCA_MPLS_TTL]) : + mpls_ttl; + p->tcfm_bos = tb[TCA_MPLS_BOS] ? nla_get_u8(tb[TCA_MPLS_BOS]) : + ACT_MPLS_BOS_NOT_SET; + p->tcfm_proto = tb[TCA_MPLS_PROTO] ? 
nla_get_be16(tb[TCA_MPLS_PROTO]) : + htons(ETH_P_MPLS_UC); + + spin_lock_bh(&m->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + p = rcu_replace_pointer(m->mpls_p, p, lockdep_is_held(&m->tcf_lock)); + spin_unlock_bh(&m->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (p) + kfree_rcu(p, rcu); + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_mpls_cleanup(struct tc_action *a) +{ + struct tcf_mpls *m = to_mpls(a); + struct tcf_mpls_params *p; + + p = rcu_dereference_protected(m->mpls_p, 1); + if (p) + kfree_rcu(p, rcu); +} + +static int tcf_mpls_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_mpls *m = to_mpls(a); + struct tcf_mpls_params *p; + struct tc_mpls opt = { + .index = m->tcf_index, + .refcnt = refcount_read(&m->tcf_refcnt) - ref, + .bindcnt = atomic_read(&m->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&m->tcf_lock); + opt.action = m->tcf_action; + p = rcu_dereference_protected(m->mpls_p, lockdep_is_held(&m->tcf_lock)); + opt.m_action = p->tcfm_action; + + if (nla_put(skb, TCA_MPLS_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + if (p->tcfm_label != ACT_MPLS_LABEL_NOT_SET && + nla_put_u32(skb, TCA_MPLS_LABEL, p->tcfm_label)) + goto nla_put_failure; + + if (p->tcfm_tc != ACT_MPLS_TC_NOT_SET && + nla_put_u8(skb, TCA_MPLS_TC, p->tcfm_tc)) + goto nla_put_failure; + + if (p->tcfm_ttl && nla_put_u8(skb, TCA_MPLS_TTL, p->tcfm_ttl)) + goto nla_put_failure; + + if (p->tcfm_bos != ACT_MPLS_BOS_NOT_SET && + nla_put_u8(skb, TCA_MPLS_BOS, p->tcfm_bos)) + goto nla_put_failure; + + if (nla_put_be16(skb, TCA_MPLS_PROTO, p->tcfm_proto)) + goto nla_put_failure; + + tcf_tm_dump(&t, &m->tcf_tm); + + if (nla_put_64bit(skb, TCA_MPLS_TM, sizeof(t), &t, TCA_MPLS_PAD)) + goto nla_put_failure; + + spin_unlock_bh(&m->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&m->tcf_lock); + nlmsg_trim(skb, b); + return -EMSGSIZE; +} + +static int tcf_mpls_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + switch (tcf_mpls_action(act)) { + case TCA_MPLS_ACT_PUSH: + entry->id = FLOW_ACTION_MPLS_PUSH; + entry->mpls_push.proto = tcf_mpls_proto(act); + entry->mpls_push.label = tcf_mpls_label(act); + entry->mpls_push.tc = tcf_mpls_tc(act); + entry->mpls_push.bos = tcf_mpls_bos(act); + entry->mpls_push.ttl = tcf_mpls_ttl(act); + break; + case TCA_MPLS_ACT_POP: + entry->id = FLOW_ACTION_MPLS_POP; + entry->mpls_pop.proto = tcf_mpls_proto(act); + break; + case TCA_MPLS_ACT_MODIFY: + entry->id = FLOW_ACTION_MPLS_MANGLE; + entry->mpls_mangle.label = tcf_mpls_label(act); + entry->mpls_mangle.tc = tcf_mpls_tc(act); + entry->mpls_mangle.bos = tcf_mpls_bos(act); + entry->mpls_mangle.ttl = tcf_mpls_ttl(act); + break; + case TCA_MPLS_ACT_DEC_TTL: + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"dec_ttl\" option is used"); + return -EOPNOTSUPP; + case TCA_MPLS_ACT_MAC_PUSH: + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"mac_push\" option is used"); + return -EOPNOTSUPP; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported MPLS mode offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + switch (tcf_mpls_action(act)) { + case TCA_MPLS_ACT_PUSH: 
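+			/* no bind: callers on this path only need the flow action id */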
+ fl_action->id = FLOW_ACTION_MPLS_PUSH; + break; + case TCA_MPLS_ACT_POP: + fl_action->id = FLOW_ACTION_MPLS_POP; + break; + case TCA_MPLS_ACT_MODIFY: + fl_action->id = FLOW_ACTION_MPLS_MANGLE; + break; + default: + return -EOPNOTSUPP; + } + } + + return 0; +} + +static struct tc_action_ops act_mpls_ops = { + .kind = "mpls", + .id = TCA_ID_MPLS, + .owner = THIS_MODULE, + .act = tcf_mpls_act, + .dump = tcf_mpls_dump, + .init = tcf_mpls_init, + .cleanup = tcf_mpls_cleanup, + .offload_act_setup = tcf_mpls_offload_act_setup, + .size = sizeof(struct tcf_mpls), +}; + +static __net_init int mpls_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_mpls_ops.net_id); + + return tc_action_net_init(net, tn, &act_mpls_ops); +} + +static void __net_exit mpls_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_mpls_ops.net_id); +} + +static struct pernet_operations mpls_net_ops = { + .init = mpls_init_net, + .exit_batch = mpls_exit_net, + .id = &act_mpls_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init mpls_init_module(void) +{ + return tcf_register_action(&act_mpls_ops, &mpls_net_ops); +} + +static void __exit mpls_cleanup_module(void) +{ + tcf_unregister_action(&act_mpls_ops, &mpls_net_ops); +} + +module_init(mpls_init_module); +module_exit(mpls_cleanup_module); + +MODULE_SOFTDEP("post: mpls_gso"); +MODULE_AUTHOR("Netronome Systems "); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("MPLS manipulation actions"); diff --git a/net/sched/act_nat.c b/net/sched/act_nat.c new file mode 100644 index 000000000..9265145f1 --- /dev/null +++ b/net/sched/act_nat.c @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Stateless NAT actions + * + * Copyright (c) 2007 Herbert Xu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +static struct tc_action_ops act_nat_ops; + +static const struct nla_policy nat_policy[TCA_NAT_MAX + 1] = { + [TCA_NAT_PARMS] = { .len = sizeof(struct tc_nat) }, +}; + +static int tcf_nat_init(struct net *net, struct nlattr *nla, struct nlattr *est, + struct tc_action **a, struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_NAT_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_nat *parm; + int ret = 0, err; + struct tcf_nat *p; + u32 index; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_NAT_MAX, nla, nat_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_NAT_PARMS] == NULL) + return -EINVAL; + parm = nla_data(tb[TCA_NAT_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (!err) { + ret = tcf_idr_create(tn, index, est, a, + &act_nat_ops, bind, false, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (err > 0) { + if (bind) + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } else { + return err; + } + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + p = to_tcf_nat(*a); + + spin_lock_bh(&p->tcf_lock); + p->old_addr = parm->old_addr; + p->new_addr = parm->new_addr; + p->mask = parm->mask; + p->flags = parm->flags; + + goto_ch = tcf_action_set_ctrlact(*a, 
parm->action, goto_ch); + spin_unlock_bh(&p->tcf_lock); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_nat_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_nat *p = to_tcf_nat(a); + struct iphdr *iph; + __be32 old_addr; + __be32 new_addr; + __be32 mask; + __be32 addr; + int egress; + int action; + int ihl; + int noff; + + spin_lock(&p->tcf_lock); + + tcf_lastuse_update(&p->tcf_tm); + old_addr = p->old_addr; + new_addr = p->new_addr; + mask = p->mask; + egress = p->flags & TCA_NAT_FLAG_EGRESS; + action = p->tcf_action; + + bstats_update(&p->tcf_bstats, skb); + + spin_unlock(&p->tcf_lock); + + if (unlikely(action == TC_ACT_SHOT)) + goto drop; + + noff = skb_network_offset(skb); + if (!pskb_may_pull(skb, sizeof(*iph) + noff)) + goto drop; + + iph = ip_hdr(skb); + + if (egress) + addr = iph->saddr; + else + addr = iph->daddr; + + if (!((old_addr ^ addr) & mask)) { + if (skb_try_make_writable(skb, sizeof(*iph) + noff)) + goto drop; + + new_addr &= mask; + new_addr |= addr & ~mask; + + /* Rewrite IP header */ + iph = ip_hdr(skb); + if (egress) + iph->saddr = new_addr; + else + iph->daddr = new_addr; + + csum_replace4(&iph->check, addr, new_addr); + } else if ((iph->frag_off & htons(IP_OFFSET)) || + iph->protocol != IPPROTO_ICMP) { + goto out; + } + + ihl = iph->ihl * 4; + + /* It would be nice to share code with stateful NAT. */ + switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) { + case IPPROTO_TCP: + { + struct tcphdr *tcph; + + if (!pskb_may_pull(skb, ihl + sizeof(*tcph) + noff) || + skb_try_make_writable(skb, ihl + sizeof(*tcph) + noff)) + goto drop; + + tcph = (void *)(skb_network_header(skb) + ihl); + inet_proto_csum_replace4(&tcph->check, skb, addr, new_addr, + true); + break; + } + case IPPROTO_UDP: + { + struct udphdr *udph; + + if (!pskb_may_pull(skb, ihl + sizeof(*udph) + noff) || + skb_try_make_writable(skb, ihl + sizeof(*udph) + noff)) + goto drop; + + udph = (void *)(skb_network_header(skb) + ihl); + if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) { + inet_proto_csum_replace4(&udph->check, skb, addr, + new_addr, true); + if (!udph->check) + udph->check = CSUM_MANGLED_0; + } + break; + } + case IPPROTO_ICMP: + { + struct icmphdr *icmph; + + if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + noff)) + goto drop; + + icmph = (void *)(skb_network_header(skb) + ihl); + + if (!icmp_is_err(icmph->type)) + break; + + if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + sizeof(*iph) + + noff)) + goto drop; + + icmph = (void *)(skb_network_header(skb) + ihl); + iph = (void *)(icmph + 1); + if (egress) + addr = iph->daddr; + else + addr = iph->saddr; + + if ((old_addr ^ addr) & mask) + break; + + if (skb_try_make_writable(skb, ihl + sizeof(*icmph) + + sizeof(*iph) + noff)) + goto drop; + + icmph = (void *)(skb_network_header(skb) + ihl); + iph = (void *)(icmph + 1); + + new_addr &= mask; + new_addr |= addr & ~mask; + + /* XXX Fix up the inner checksums. 
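+		 * The address in the IP header embedded in the ICMP error is
+		 * rewritten below, but that header's own checksum is not
+		 * recomputed here.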
*/ + if (egress) + iph->daddr = new_addr; + else + iph->saddr = new_addr; + + inet_proto_csum_replace4(&icmph->checksum, skb, addr, new_addr, + false); + break; + } + default: + break; + } + +out: + return action; + +drop: + spin_lock(&p->tcf_lock); + p->tcf_qstats.drops++; + spin_unlock(&p->tcf_lock); + return TC_ACT_SHOT; +} + +static int tcf_nat_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_nat *p = to_tcf_nat(a); + struct tc_nat opt = { + .index = p->tcf_index, + .refcnt = refcount_read(&p->tcf_refcnt) - ref, + .bindcnt = atomic_read(&p->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&p->tcf_lock); + opt.old_addr = p->old_addr; + opt.new_addr = p->new_addr; + opt.mask = p->mask; + opt.flags = p->flags; + opt.action = p->tcf_action; + + if (nla_put(skb, TCA_NAT_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &p->tcf_tm); + if (nla_put_64bit(skb, TCA_NAT_TM, sizeof(t), &t, TCA_NAT_PAD)) + goto nla_put_failure; + spin_unlock_bh(&p->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&p->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static struct tc_action_ops act_nat_ops = { + .kind = "nat", + .id = TCA_ID_NAT, + .owner = THIS_MODULE, + .act = tcf_nat_act, + .dump = tcf_nat_dump, + .init = tcf_nat_init, + .size = sizeof(struct tcf_nat), +}; + +static __net_init int nat_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id); + + return tc_action_net_init(net, tn, &act_nat_ops); +} + +static void __net_exit nat_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_nat_ops.net_id); +} + +static struct pernet_operations nat_net_ops = { + .init = nat_init_net, + .exit_batch = nat_exit_net, + .id = &act_nat_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_DESCRIPTION("Stateless NAT actions"); +MODULE_LICENSE("GPL"); + +static int __init nat_init_module(void) +{ + return tcf_register_action(&act_nat_ops, &nat_net_ops); +} + +static void __exit nat_cleanup_module(void) +{ + tcf_unregister_action(&act_nat_ops, &nat_net_ops); +} + +module_init(nat_init_module); +module_exit(nat_cleanup_module); diff --git a/net/sched/act_pedit.c b/net/sched/act_pedit.c new file mode 100644 index 000000000..aee2e13f1 --- /dev/null +++ b/net/sched/act_pedit.c @@ -0,0 +1,627 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_pedit.c Generic packet editor + * + * Authors: Jamal Hadi Salim (2002-4) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct tc_action_ops act_pedit_ops; + +static const struct nla_policy pedit_policy[TCA_PEDIT_MAX + 1] = { + [TCA_PEDIT_PARMS] = { .len = sizeof(struct tc_pedit) }, + [TCA_PEDIT_PARMS_EX] = { .len = sizeof(struct tc_pedit) }, + [TCA_PEDIT_KEYS_EX] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy pedit_key_ex_policy[TCA_PEDIT_KEY_EX_MAX + 1] = { + [TCA_PEDIT_KEY_EX_HTYPE] = { .type = NLA_U16 }, + [TCA_PEDIT_KEY_EX_CMD] = { .type = NLA_U16 }, +}; + +static struct tcf_pedit_key_ex *tcf_pedit_keys_ex_parse(struct nlattr *nla, + u8 n) +{ + struct tcf_pedit_key_ex *keys_ex; + struct tcf_pedit_key_ex *k; + const struct nlattr *ka; + int err = -EINVAL; + int rem; + + if (!nla) + return NULL; + + keys_ex = kcalloc(n, sizeof(*k), GFP_KERNEL); + if (!keys_ex) + return ERR_PTR(-ENOMEM); + + k = 
keys_ex; + + nla_for_each_nested(ka, nla, rem) { + struct nlattr *tb[TCA_PEDIT_KEY_EX_MAX + 1]; + + if (!n) { + err = -EINVAL; + goto err_out; + } + n--; + + if (nla_type(ka) != TCA_PEDIT_KEY_EX) { + err = -EINVAL; + goto err_out; + } + + err = nla_parse_nested_deprecated(tb, TCA_PEDIT_KEY_EX_MAX, + ka, pedit_key_ex_policy, + NULL); + if (err) + goto err_out; + + if (!tb[TCA_PEDIT_KEY_EX_HTYPE] || + !tb[TCA_PEDIT_KEY_EX_CMD]) { + err = -EINVAL; + goto err_out; + } + + k->htype = nla_get_u16(tb[TCA_PEDIT_KEY_EX_HTYPE]); + k->cmd = nla_get_u16(tb[TCA_PEDIT_KEY_EX_CMD]); + + if (k->htype > TCA_PEDIT_HDR_TYPE_MAX || + k->cmd > TCA_PEDIT_CMD_MAX) { + err = -EINVAL; + goto err_out; + } + + k++; + } + + if (n) { + err = -EINVAL; + goto err_out; + } + + return keys_ex; + +err_out: + kfree(keys_ex); + return ERR_PTR(err); +} + +static int tcf_pedit_key_ex_dump(struct sk_buff *skb, + struct tcf_pedit_key_ex *keys_ex, int n) +{ + struct nlattr *keys_start = nla_nest_start_noflag(skb, + TCA_PEDIT_KEYS_EX); + + if (!keys_start) + goto nla_failure; + for (; n > 0; n--) { + struct nlattr *key_start; + + key_start = nla_nest_start_noflag(skb, TCA_PEDIT_KEY_EX); + if (!key_start) + goto nla_failure; + + if (nla_put_u16(skb, TCA_PEDIT_KEY_EX_HTYPE, keys_ex->htype) || + nla_put_u16(skb, TCA_PEDIT_KEY_EX_CMD, keys_ex->cmd)) + goto nla_failure; + + nla_nest_end(skb, key_start); + + keys_ex++; + } + + nla_nest_end(skb, keys_start); + + return 0; +nla_failure: + nla_nest_cancel(skb, keys_start); + return -EINVAL; +} + +static void tcf_pedit_cleanup_rcu(struct rcu_head *head) +{ + struct tcf_pedit_parms *parms = + container_of(head, struct tcf_pedit_parms, rcu); + + kfree(parms->tcfp_keys_ex); + kfree(parms->tcfp_keys); + + kfree(parms); +} + +static int tcf_pedit_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_pedit_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct tcf_chain *goto_ch = NULL; + struct tcf_pedit_parms *oparms, *nparms; + struct nlattr *tb[TCA_PEDIT_MAX + 1]; + struct tc_pedit *parm; + struct nlattr *pattr; + struct tcf_pedit *p; + int ret = 0, err; + int i, ksize; + u32 index; + + if (!nla) { + NL_SET_ERR_MSG_MOD(extack, "Pedit requires attributes to be passed"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_PEDIT_MAX, nla, + pedit_policy, NULL); + if (err < 0) + return err; + + pattr = tb[TCA_PEDIT_PARMS]; + if (!pattr) + pattr = tb[TCA_PEDIT_PARMS_EX]; + if (!pattr) { + NL_SET_ERR_MSG_MOD(extack, "Missing required TCA_PEDIT_PARMS or TCA_PEDIT_PARMS_EX pedit attribute"); + return -EINVAL; + } + + parm = nla_data(pattr); + + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (!err) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_pedit_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (err > 0) { + if (bind) + return 0; + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + ret = -EEXIST; + goto out_release; + } + } else { + return err; + } + + if (!parm->nkeys) { + NL_SET_ERR_MSG_MOD(extack, "Pedit requires keys to be passed"); + ret = -EINVAL; + goto out_release; + } + ksize = parm->nkeys * sizeof(struct tc_pedit_key); + if (nla_len(pattr) < sizeof(*parm) + ksize) { + NL_SET_ERR_MSG_ATTR(extack, pattr, "Length of TCA_PEDIT_PARMS or TCA_PEDIT_PARMS_EX pedit attribute is invalid"); + ret = -EINVAL; + goto 
out_release; + } + + nparms = kzalloc(sizeof(*nparms), GFP_KERNEL); + if (!nparms) { + ret = -ENOMEM; + goto out_release; + } + + nparms->tcfp_keys_ex = + tcf_pedit_keys_ex_parse(tb[TCA_PEDIT_KEYS_EX], parm->nkeys); + if (IS_ERR(nparms->tcfp_keys_ex)) { + ret = PTR_ERR(nparms->tcfp_keys_ex); + goto out_free; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) { + ret = err; + goto out_free_ex; + } + + nparms->tcfp_off_max_hint = 0; + nparms->tcfp_flags = parm->flags; + nparms->tcfp_nkeys = parm->nkeys; + + nparms->tcfp_keys = kmalloc(ksize, GFP_KERNEL); + if (!nparms->tcfp_keys) { + ret = -ENOMEM; + goto put_chain; + } + + memcpy(nparms->tcfp_keys, parm->keys, ksize); + + for (i = 0; i < nparms->tcfp_nkeys; ++i) { + u32 cur = nparms->tcfp_keys[i].off; + + /* sanitize the shift value for any later use */ + nparms->tcfp_keys[i].shift = min_t(size_t, + BITS_PER_TYPE(int) - 1, + nparms->tcfp_keys[i].shift); + + /* The AT option can read a single byte, we can bound the actual + * value with uchar max. + */ + cur += (0xff & nparms->tcfp_keys[i].offmask) >> nparms->tcfp_keys[i].shift; + + /* Each key touches 4 bytes starting from the computed offset */ + nparms->tcfp_off_max_hint = + max(nparms->tcfp_off_max_hint, cur + 4); + } + + p = to_pedit(*a); + + spin_lock_bh(&p->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + oparms = rcu_replace_pointer(p->parms, nparms, 1); + spin_unlock_bh(&p->tcf_lock); + + if (oparms) + call_rcu(&oparms->rcu, tcf_pedit_cleanup_rcu); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; + +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +out_free_ex: + kfree(nparms->tcfp_keys_ex); +out_free: + kfree(nparms); +out_release: + tcf_idr_release(*a, bind); + return ret; +} + +static void tcf_pedit_cleanup(struct tc_action *a) +{ + struct tcf_pedit *p = to_pedit(a); + struct tcf_pedit_parms *parms; + + parms = rcu_dereference_protected(p->parms, 1); + + if (parms) + call_rcu(&parms->rcu, tcf_pedit_cleanup_rcu); +} + +static bool offset_valid(struct sk_buff *skb, int offset) +{ + if (offset > 0 && offset > skb->len) + return false; + + if (offset < 0 && -offset > skb_headroom(skb)) + return false; + + return true; +} + +static int pedit_l4_skb_offset(struct sk_buff *skb, int *hoffset, const int header_type) +{ + const int noff = skb_network_offset(skb); + int ret = -EINVAL; + struct iphdr _iph; + + switch (skb->protocol) { + case htons(ETH_P_IP): { + const struct iphdr *iph = skb_header_pointer(skb, noff, sizeof(_iph), &_iph); + + if (!iph) + goto out; + *hoffset = noff + iph->ihl * 4; + ret = 0; + break; + } + case htons(ETH_P_IPV6): + ret = ipv6_find_hdr(skb, hoffset, header_type, NULL, NULL) == header_type ? 
0 : -EINVAL; + break; + } +out: + return ret; +} + +static int pedit_skb_hdr_offset(struct sk_buff *skb, + enum pedit_header_type htype, int *hoffset) +{ + int ret = -EINVAL; + /* 'htype' is validated in the netlink parsing */ + switch (htype) { + case TCA_PEDIT_KEY_EX_HDR_TYPE_ETH: + if (skb_mac_header_was_set(skb)) { + *hoffset = skb_mac_offset(skb); + ret = 0; + } + break; + case TCA_PEDIT_KEY_EX_HDR_TYPE_NETWORK: + case TCA_PEDIT_KEY_EX_HDR_TYPE_IP4: + case TCA_PEDIT_KEY_EX_HDR_TYPE_IP6: + *hoffset = skb_network_offset(skb); + ret = 0; + break; + case TCA_PEDIT_KEY_EX_HDR_TYPE_TCP: + ret = pedit_l4_skb_offset(skb, hoffset, IPPROTO_TCP); + break; + case TCA_PEDIT_KEY_EX_HDR_TYPE_UDP: + ret = pedit_l4_skb_offset(skb, hoffset, IPPROTO_UDP); + break; + default: + break; + } + return ret; +} + +static int tcf_pedit_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + enum pedit_header_type htype = TCA_PEDIT_KEY_EX_HDR_TYPE_NETWORK; + enum pedit_cmd cmd = TCA_PEDIT_KEY_EX_CMD_SET; + struct tcf_pedit *p = to_pedit(a); + struct tcf_pedit_key_ex *tkey_ex; + struct tcf_pedit_parms *parms; + struct tc_pedit_key *tkey; + u32 max_offset; + int i; + + parms = rcu_dereference_bh(p->parms); + + max_offset = (skb_transport_header_was_set(skb) ? + skb_transport_offset(skb) : + skb_network_offset(skb)) + + parms->tcfp_off_max_hint; + if (skb_ensure_writable(skb, min(skb->len, max_offset))) + goto done; + + tcf_lastuse_update(&p->tcf_tm); + tcf_action_update_bstats(&p->common, skb); + + tkey = parms->tcfp_keys; + tkey_ex = parms->tcfp_keys_ex; + + for (i = parms->tcfp_nkeys; i > 0; i--, tkey++) { + int offset = tkey->off; + int hoffset = 0; + u32 *ptr, hdata; + u32 val; + int rc; + + if (tkey_ex) { + htype = tkey_ex->htype; + cmd = tkey_ex->cmd; + + tkey_ex++; + } + + rc = pedit_skb_hdr_offset(skb, htype, &hoffset); + if (rc) { + pr_info_ratelimited("tc action pedit unable to extract header offset for header type (0x%x)\n", htype); + goto bad; + } + + if (tkey->offmask) { + u8 *d, _d; + + if (!offset_valid(skb, hoffset + tkey->at)) { + pr_info("tc action pedit 'at' offset %d out of bounds\n", + hoffset + tkey->at); + goto bad; + } + d = skb_header_pointer(skb, hoffset + tkey->at, + sizeof(_d), &_d); + if (!d) + goto bad; + offset += (*d & tkey->offmask) >> tkey->shift; + } + + if (offset % 4) { + pr_info("tc action pedit offset must be on 32 bit boundaries\n"); + goto bad; + } + + if (!offset_valid(skb, hoffset + offset)) { + pr_info("tc action pedit offset %d out of bounds\n", + hoffset + offset); + goto bad; + } + + ptr = skb_header_pointer(skb, hoffset + offset, + sizeof(hdata), &hdata); + if (!ptr) + goto bad; + /* just do it, baby */ + switch (cmd) { + case TCA_PEDIT_KEY_EX_CMD_SET: + val = tkey->val; + break; + case TCA_PEDIT_KEY_EX_CMD_ADD: + val = (*ptr + tkey->val) & ~tkey->mask; + break; + default: + pr_info("tc action pedit bad command (%d)\n", + cmd); + goto bad; + } + + *ptr = ((*ptr & tkey->mask) ^ val); + if (ptr == &hdata) + skb_store_bits(skb, hoffset + offset, ptr, 4); + } + + goto done; + +bad: + spin_lock(&p->tcf_lock); + p->tcf_qstats.overlimits++; + spin_unlock(&p->tcf_lock); +done: + return p->tcf_action; +} + +static void tcf_pedit_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_pedit *d = to_pedit(a); + struct tcf_t *tm = &d->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static int tcf_pedit_dump(struct sk_buff 
*skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_pedit *p = to_pedit(a); + struct tcf_pedit_parms *parms; + struct tc_pedit *opt; + struct tcf_t t; + int s; + + spin_lock_bh(&p->tcf_lock); + parms = rcu_dereference_protected(p->parms, 1); + s = struct_size(opt, keys, parms->tcfp_nkeys); + + opt = kzalloc(s, GFP_ATOMIC); + if (unlikely(!opt)) { + spin_unlock_bh(&p->tcf_lock); + return -ENOBUFS; + } + + memcpy(opt->keys, parms->tcfp_keys, + flex_array_size(opt, keys, parms->tcfp_nkeys)); + opt->index = p->tcf_index; + opt->nkeys = parms->tcfp_nkeys; + opt->flags = parms->tcfp_flags; + opt->action = p->tcf_action; + opt->refcnt = refcount_read(&p->tcf_refcnt) - ref; + opt->bindcnt = atomic_read(&p->tcf_bindcnt) - bind; + + if (parms->tcfp_keys_ex) { + if (tcf_pedit_key_ex_dump(skb, parms->tcfp_keys_ex, + parms->tcfp_nkeys)) + goto nla_put_failure; + + if (nla_put(skb, TCA_PEDIT_PARMS_EX, s, opt)) + goto nla_put_failure; + } else { + if (nla_put(skb, TCA_PEDIT_PARMS, s, opt)) + goto nla_put_failure; + } + + tcf_tm_dump(&t, &p->tcf_tm); + if (nla_put_64bit(skb, TCA_PEDIT_TM, sizeof(t), &t, TCA_PEDIT_PAD)) + goto nla_put_failure; + spin_unlock_bh(&p->tcf_lock); + + kfree(opt); + return skb->len; + +nla_put_failure: + spin_unlock_bh(&p->tcf_lock); + nlmsg_trim(skb, b); + kfree(opt); + return -1; +} + +static int tcf_pedit_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + int k; + + for (k = 0; k < tcf_pedit_nkeys(act); k++) { + switch (tcf_pedit_cmd(act, k)) { + case TCA_PEDIT_KEY_EX_CMD_SET: + entry->id = FLOW_ACTION_MANGLE; + break; + case TCA_PEDIT_KEY_EX_CMD_ADD: + entry->id = FLOW_ACTION_ADD; + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported pedit command offload"); + return -EOPNOTSUPP; + } + entry->mangle.htype = tcf_pedit_htype(act, k); + entry->mangle.mask = tcf_pedit_mask(act, k); + entry->mangle.val = tcf_pedit_val(act, k); + entry->mangle.offset = tcf_pedit_offset(act, k); + entry->hw_stats = tc_act_hw_stats(act->hw_stats); + entry++; + } + *index_inc = k; + } else { + return -EOPNOTSUPP; + } + + return 0; +} + +static struct tc_action_ops act_pedit_ops = { + .kind = "pedit", + .id = TCA_ID_PEDIT, + .owner = THIS_MODULE, + .act = tcf_pedit_act, + .stats_update = tcf_pedit_stats_update, + .dump = tcf_pedit_dump, + .cleanup = tcf_pedit_cleanup, + .init = tcf_pedit_init, + .offload_act_setup = tcf_pedit_offload_act_setup, + .size = sizeof(struct tcf_pedit), +}; + +static __net_init int pedit_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_pedit_ops.net_id); + + return tc_action_net_init(net, tn, &act_pedit_ops); +} + +static void __net_exit pedit_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_pedit_ops.net_id); +} + +static struct pernet_operations pedit_net_ops = { + .init = pedit_init_net, + .exit_batch = pedit_exit_net, + .id = &act_pedit_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim(2002-4)"); +MODULE_DESCRIPTION("Generic Packet Editor actions"); +MODULE_LICENSE("GPL"); + +static int __init pedit_init_module(void) +{ + return tcf_register_action(&act_pedit_ops, &pedit_net_ops); +} + +static void __exit pedit_cleanup_module(void) +{ + tcf_unregister_action(&act_pedit_ops, &pedit_net_ops); +} + +module_init(pedit_init_module); +module_exit(pedit_cleanup_module); diff --git 
a/net/sched/act_police.c b/net/sched/act_police.c new file mode 100644 index 000000000..94be21378 --- /dev/null +++ b/net/sched/act_police.c @@ -0,0 +1,533 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_police.c Input police filter + * + * Authors: Alexey Kuznetsov, + * J Hadi Salim (action changes) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Each policer is serialized by its individual spinlock */ + +static struct tc_action_ops act_police_ops; + +static const struct nla_policy police_policy[TCA_POLICE_MAX + 1] = { + [TCA_POLICE_RATE] = { .len = TC_RTAB_SIZE }, + [TCA_POLICE_PEAKRATE] = { .len = TC_RTAB_SIZE }, + [TCA_POLICE_AVRATE] = { .type = NLA_U32 }, + [TCA_POLICE_RESULT] = { .type = NLA_U32 }, + [TCA_POLICE_RATE64] = { .type = NLA_U64 }, + [TCA_POLICE_PEAKRATE64] = { .type = NLA_U64 }, + [TCA_POLICE_PKTRATE64] = { .type = NLA_U64, .min = 1 }, + [TCA_POLICE_PKTBURST64] = { .type = NLA_U64, .min = 1 }, +}; + +static int tcf_police_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + int ret = 0, tcfp_result = TC_ACT_OK, err, size; + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_POLICE_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_police *parm; + struct tcf_police *police; + struct qdisc_rate_table *R_tab = NULL, *P_tab = NULL; + struct tc_action_net *tn = net_generic(net, act_police_ops.net_id); + struct tcf_police_params *new; + bool exists = false; + u32 index; + u64 rate64, prate64; + u64 pps, ppsburst; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_POLICE_MAX, nla, + police_policy, NULL); + if (err < 0) + return err; + + if (tb[TCA_POLICE_TBF] == NULL) + return -EINVAL; + size = nla_len(tb[TCA_POLICE_TBF]); + if (size != sizeof(*parm) && size != sizeof(struct tc_police_compat)) + return -EINVAL; + + parm = nla_data(tb[TCA_POLICE_TBF]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (!exists) { + ret = tcf_idr_create(tn, index, NULL, a, + &act_police_ops, bind, true, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + spin_lock_init(&(to_police(*a)->tcfp_lock)); + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + police = to_police(*a); + if (parm->rate.rate) { + err = -ENOMEM; + R_tab = qdisc_get_rtab(&parm->rate, tb[TCA_POLICE_RATE], NULL); + if (R_tab == NULL) + goto failure; + + if (parm->peakrate.rate) { + P_tab = qdisc_get_rtab(&parm->peakrate, + tb[TCA_POLICE_PEAKRATE], NULL); + if (P_tab == NULL) + goto failure; + } + } + + if (est) { + err = gen_replace_estimator(&police->tcf_bstats, + police->common.cpu_bstats, + &police->tcf_rate_est, + &police->tcf_lock, + false, est); + if (err) + goto failure; + } else if (tb[TCA_POLICE_AVRATE] && + (ret == ACT_P_CREATED || + !gen_estimator_active(&police->tcf_rate_est))) { + err = -EINVAL; + goto failure; + } + + if (tb[TCA_POLICE_RESULT]) { + tcfp_result = nla_get_u32(tb[TCA_POLICE_RESULT]); + if (TC_ACT_EXT_CMP(tcfp_result, TC_ACT_GOTO_CHAIN)) { + NL_SET_ERR_MSG(extack, + "goto chain not allowed on fallback"); + err = -EINVAL; + goto 
failure; + } + } + + if ((tb[TCA_POLICE_PKTRATE64] && !tb[TCA_POLICE_PKTBURST64]) || + (!tb[TCA_POLICE_PKTRATE64] && tb[TCA_POLICE_PKTBURST64])) { + NL_SET_ERR_MSG(extack, + "Both or neither packet-per-second burst and rate must be provided"); + err = -EINVAL; + goto failure; + } + + if (tb[TCA_POLICE_PKTRATE64] && R_tab) { + NL_SET_ERR_MSG(extack, + "packet-per-second and byte-per-second rate limits not allowed in same action"); + err = -EINVAL; + goto failure; + } + + new = kzalloc(sizeof(*new), GFP_KERNEL); + if (unlikely(!new)) { + err = -ENOMEM; + goto failure; + } + + /* No failure allowed after this point */ + new->tcfp_result = tcfp_result; + new->tcfp_mtu = parm->mtu; + if (!new->tcfp_mtu) { + new->tcfp_mtu = ~0; + if (R_tab) + new->tcfp_mtu = 255 << R_tab->rate.cell_log; + } + if (R_tab) { + new->rate_present = true; + rate64 = tb[TCA_POLICE_RATE64] ? + nla_get_u64(tb[TCA_POLICE_RATE64]) : 0; + psched_ratecfg_precompute(&new->rate, &R_tab->rate, rate64); + qdisc_put_rtab(R_tab); + } else { + new->rate_present = false; + } + if (P_tab) { + new->peak_present = true; + prate64 = tb[TCA_POLICE_PEAKRATE64] ? + nla_get_u64(tb[TCA_POLICE_PEAKRATE64]) : 0; + psched_ratecfg_precompute(&new->peak, &P_tab->rate, prate64); + qdisc_put_rtab(P_tab); + } else { + new->peak_present = false; + } + + new->tcfp_burst = PSCHED_TICKS2NS(parm->burst); + if (new->peak_present) + new->tcfp_mtu_ptoks = (s64)psched_l2t_ns(&new->peak, + new->tcfp_mtu); + + if (tb[TCA_POLICE_AVRATE]) + new->tcfp_ewma_rate = nla_get_u32(tb[TCA_POLICE_AVRATE]); + + if (tb[TCA_POLICE_PKTRATE64]) { + pps = nla_get_u64(tb[TCA_POLICE_PKTRATE64]); + ppsburst = nla_get_u64(tb[TCA_POLICE_PKTBURST64]); + new->pps_present = true; + new->tcfp_pkt_burst = PSCHED_TICKS2NS(ppsburst); + psched_ppscfg_precompute(&new->ppsrate, pps); + } + + spin_lock_bh(&police->tcf_lock); + spin_lock_bh(&police->tcfp_lock); + police->tcfp_t_c = ktime_get_ns(); + police->tcfp_toks = new->tcfp_burst; + if (new->peak_present) + police->tcfp_ptoks = new->tcfp_mtu_ptoks; + spin_unlock_bh(&police->tcfp_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + new = rcu_replace_pointer(police->params, + new, + lockdep_is_held(&police->tcf_lock)); + spin_unlock_bh(&police->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (new) + kfree_rcu(new, rcu); + + return ret; + +failure: + qdisc_put_rtab(P_tab); + qdisc_put_rtab(R_tab); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static bool tcf_police_mtu_check(struct sk_buff *skb, u32 limit) +{ + u32 len; + + if (skb_is_gso(skb)) + return skb_gso_validate_mac_len(skb, limit); + + len = qdisc_pkt_len(skb); + if (skb_at_tc_ingress(skb)) + len += skb->mac_len; + + return len <= limit; +} + +static int tcf_police_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_police *police = to_police(a); + s64 now, toks, ppstoks = 0, ptoks = 0; + struct tcf_police_params *p; + int ret; + + tcf_lastuse_update(&police->tcf_tm); + bstats_update(this_cpu_ptr(police->common.cpu_bstats), skb); + + ret = READ_ONCE(police->tcf_action); + p = rcu_dereference_bh(police->params); + + if (p->tcfp_ewma_rate) { + struct gnet_stats_rate_est64 sample; + + if (!gen_estimator_read(&police->tcf_rate_est, &sample) || + sample.bps >= p->tcfp_ewma_rate) + goto inc_overlimits; + } + + if (tcf_police_mtu_check(skb, p->tcfp_mtu)) { + if (!p->rate_present && !p->pps_present) { + ret = p->tcfp_result; + goto end; + } + + now 
= ktime_get_ns(); + spin_lock_bh(&police->tcfp_lock); + toks = min_t(s64, now - police->tcfp_t_c, p->tcfp_burst); + if (p->peak_present) { + ptoks = toks + police->tcfp_ptoks; + if (ptoks > p->tcfp_mtu_ptoks) + ptoks = p->tcfp_mtu_ptoks; + ptoks -= (s64)psched_l2t_ns(&p->peak, + qdisc_pkt_len(skb)); + } + if (p->rate_present) { + toks += police->tcfp_toks; + if (toks > p->tcfp_burst) + toks = p->tcfp_burst; + toks -= (s64)psched_l2t_ns(&p->rate, qdisc_pkt_len(skb)); + } else if (p->pps_present) { + ppstoks = min_t(s64, now - police->tcfp_t_c, p->tcfp_pkt_burst); + ppstoks += police->tcfp_pkttoks; + if (ppstoks > p->tcfp_pkt_burst) + ppstoks = p->tcfp_pkt_burst; + ppstoks -= (s64)psched_pkt2t_ns(&p->ppsrate, 1); + } + if ((toks | ptoks | ppstoks) >= 0) { + police->tcfp_t_c = now; + police->tcfp_toks = toks; + police->tcfp_ptoks = ptoks; + police->tcfp_pkttoks = ppstoks; + spin_unlock_bh(&police->tcfp_lock); + ret = p->tcfp_result; + goto inc_drops; + } + spin_unlock_bh(&police->tcfp_lock); + } + +inc_overlimits: + qstats_overlimit_inc(this_cpu_ptr(police->common.cpu_qstats)); +inc_drops: + if (ret == TC_ACT_SHOT) + qstats_drop_inc(this_cpu_ptr(police->common.cpu_qstats)); +end: + return ret; +} + +static void tcf_police_cleanup(struct tc_action *a) +{ + struct tcf_police *police = to_police(a); + struct tcf_police_params *p; + + p = rcu_dereference_protected(police->params, 1); + if (p) + kfree_rcu(p, rcu); +} + +static void tcf_police_stats_update(struct tc_action *a, + u64 bytes, u64 packets, u64 drops, + u64 lastuse, bool hw) +{ + struct tcf_police *police = to_police(a); + struct tcf_t *tm = &police->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static int tcf_police_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_police *police = to_police(a); + struct tcf_police_params *p; + struct tc_police opt = { + .index = police->tcf_index, + .refcnt = refcount_read(&police->tcf_refcnt) - ref, + .bindcnt = atomic_read(&police->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&police->tcf_lock); + opt.action = police->tcf_action; + p = rcu_dereference_protected(police->params, + lockdep_is_held(&police->tcf_lock)); + opt.mtu = p->tcfp_mtu; + opt.burst = PSCHED_NS2TICKS(p->tcfp_burst); + if (p->rate_present) { + psched_ratecfg_getrate(&opt.rate, &p->rate); + if ((p->rate.rate_bytes_ps >= (1ULL << 32)) && + nla_put_u64_64bit(skb, TCA_POLICE_RATE64, + p->rate.rate_bytes_ps, + TCA_POLICE_PAD)) + goto nla_put_failure; + } + if (p->peak_present) { + psched_ratecfg_getrate(&opt.peakrate, &p->peak); + if ((p->peak.rate_bytes_ps >= (1ULL << 32)) && + nla_put_u64_64bit(skb, TCA_POLICE_PEAKRATE64, + p->peak.rate_bytes_ps, + TCA_POLICE_PAD)) + goto nla_put_failure; + } + if (p->pps_present) { + if (nla_put_u64_64bit(skb, TCA_POLICE_PKTRATE64, + p->ppsrate.rate_pkts_ps, + TCA_POLICE_PAD)) + goto nla_put_failure; + if (nla_put_u64_64bit(skb, TCA_POLICE_PKTBURST64, + PSCHED_NS2TICKS(p->tcfp_pkt_burst), + TCA_POLICE_PAD)) + goto nla_put_failure; + } + if (nla_put(skb, TCA_POLICE_TBF, sizeof(opt), &opt)) + goto nla_put_failure; + if (p->tcfp_result && + nla_put_u32(skb, TCA_POLICE_RESULT, p->tcfp_result)) + goto nla_put_failure; + if (p->tcfp_ewma_rate && + nla_put_u32(skb, TCA_POLICE_AVRATE, p->tcfp_ewma_rate)) + goto nla_put_failure; + + tcf_tm_dump(&t, &police->tcf_tm); + if (nla_put_64bit(skb, TCA_POLICE_TM, sizeof(t), &t, TCA_POLICE_PAD)) + 
goto nla_put_failure; + spin_unlock_bh(&police->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&police->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static int tcf_police_act_to_flow_act(int tc_act, u32 *extval, + struct netlink_ext_ack *extack) +{ + int act_id = -EOPNOTSUPP; + + if (!TC_ACT_EXT_OPCODE(tc_act)) { + if (tc_act == TC_ACT_OK) + act_id = FLOW_ACTION_ACCEPT; + else if (tc_act == TC_ACT_SHOT) + act_id = FLOW_ACTION_DROP; + else if (tc_act == TC_ACT_PIPE) + act_id = FLOW_ACTION_PIPE; + else if (tc_act == TC_ACT_RECLASSIFY) + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when conform/exceed action is \"reclassify\""); + else + NL_SET_ERR_MSG_MOD(extack, "Unsupported conform/exceed action offload"); + } else if (TC_ACT_EXT_CMP(tc_act, TC_ACT_GOTO_CHAIN)) { + act_id = FLOW_ACTION_GOTO; + *extval = tc_act & TC_ACT_EXT_VAL_MASK; + } else if (TC_ACT_EXT_CMP(tc_act, TC_ACT_JUMP)) { + act_id = FLOW_ACTION_JUMP; + *extval = tc_act & TC_ACT_EXT_VAL_MASK; + } else if (tc_act == TC_ACT_UNSPEC) { + act_id = FLOW_ACTION_CONTINUE; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported conform/exceed action offload"); + } + + return act_id; +} + +static int tcf_police_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + struct tcf_police *police = to_police(act); + struct tcf_police_params *p; + int act_id; + + p = rcu_dereference_protected(police->params, + lockdep_is_held(&police->tcf_lock)); + + entry->id = FLOW_ACTION_POLICE; + entry->police.burst = tcf_police_burst(act); + entry->police.rate_bytes_ps = + tcf_police_rate_bytes_ps(act); + entry->police.peakrate_bytes_ps = tcf_police_peakrate_bytes_ps(act); + entry->police.avrate = tcf_police_tcfp_ewma_rate(act); + entry->police.overhead = tcf_police_rate_overhead(act); + entry->police.burst_pkt = tcf_police_burst_pkt(act); + entry->police.rate_pkt_ps = + tcf_police_rate_pkt_ps(act); + entry->police.mtu = tcf_police_tcfp_mtu(act); + + act_id = tcf_police_act_to_flow_act(police->tcf_action, + &entry->police.exceed.extval, + extack); + if (act_id < 0) + return act_id; + + entry->police.exceed.act_id = act_id; + + act_id = tcf_police_act_to_flow_act(p->tcfp_result, + &entry->police.notexceed.extval, + extack); + if (act_id < 0) + return act_id; + + entry->police.notexceed.act_id = act_id; + + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + fl_action->id = FLOW_ACTION_POLICE; + } + + return 0; +} + +MODULE_AUTHOR("Alexey Kuznetsov"); +MODULE_DESCRIPTION("Policing actions"); +MODULE_LICENSE("GPL"); + +static struct tc_action_ops act_police_ops = { + .kind = "police", + .id = TCA_ID_POLICE, + .owner = THIS_MODULE, + .stats_update = tcf_police_stats_update, + .act = tcf_police_act, + .dump = tcf_police_dump, + .init = tcf_police_init, + .cleanup = tcf_police_cleanup, + .offload_act_setup = tcf_police_offload_act_setup, + .size = sizeof(struct tcf_police), +}; + +static __net_init int police_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_police_ops.net_id); + + return tc_action_net_init(net, tn, &act_police_ops); +} + +static void __net_exit police_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_police_ops.net_id); +} + +static struct pernet_operations police_net_ops = { + .init = police_init_net, + .exit_batch = police_exit_net, + .id = &act_police_ops.net_id, + .size = sizeof(struct 
tc_action_net), +}; + +static int __init police_init_module(void) +{ + return tcf_register_action(&act_police_ops, &police_net_ops); +} + +static void __exit police_cleanup_module(void) +{ + tcf_unregister_action(&act_police_ops, &police_net_ops); +} + +module_init(police_init_module); +module_exit(police_cleanup_module); diff --git a/net/sched/act_sample.c b/net/sched/act_sample.c new file mode 100644 index 000000000..09735a33e --- /dev/null +++ b/net/sched/act_sample.c @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/act_sample.c - Packet sampling tc action + * Copyright (c) 2017 Yotam Gigi + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static struct tc_action_ops act_sample_ops; + +static const struct nla_policy sample_policy[TCA_SAMPLE_MAX + 1] = { + [TCA_SAMPLE_PARMS] = { .len = sizeof(struct tc_sample) }, + [TCA_SAMPLE_RATE] = { .type = NLA_U32 }, + [TCA_SAMPLE_TRUNC_SIZE] = { .type = NLA_U32 }, + [TCA_SAMPLE_PSAMPLE_GROUP] = { .type = NLA_U32 }, +}; + +static int tcf_sample_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, + u32 flags, struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_sample_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_SAMPLE_MAX + 1]; + struct psample_group *psample_group; + u32 psample_group_num, rate, index; + struct tcf_chain *goto_ch = NULL; + struct tc_sample *parm; + struct tcf_sample *s; + bool exists = false; + int ret, err; + + if (!nla) + return -EINVAL; + ret = nla_parse_nested_deprecated(tb, TCA_SAMPLE_MAX, nla, + sample_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[TCA_SAMPLE_PARMS]) + return -EINVAL; + + parm = nla_data(tb[TCA_SAMPLE_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, + &act_sample_ops, bind, true, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + ret = ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + + if (!tb[TCA_SAMPLE_RATE] || !tb[TCA_SAMPLE_PSAMPLE_GROUP]) { + NL_SET_ERR_MSG(extack, "sample rate and group are required"); + err = -EINVAL; + goto release_idr; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + rate = nla_get_u32(tb[TCA_SAMPLE_RATE]); + if (!rate) { + NL_SET_ERR_MSG(extack, "invalid sample rate"); + err = -EINVAL; + goto put_chain; + } + psample_group_num = nla_get_u32(tb[TCA_SAMPLE_PSAMPLE_GROUP]); + psample_group = psample_group_get(net, psample_group_num); + if (!psample_group) { + err = -ENOMEM; + goto put_chain; + } + + s = to_sample(*a); + + spin_lock_bh(&s->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + s->rate = rate; + s->psample_group_num = psample_group_num; + psample_group = rcu_replace_pointer(s->psample_group, psample_group, + lockdep_is_held(&s->tcf_lock)); + + if (tb[TCA_SAMPLE_TRUNC_SIZE]) { + s->truncate = true; + s->trunc_size = nla_get_u32(tb[TCA_SAMPLE_TRUNC_SIZE]); + } + spin_unlock_bh(&s->tcf_lock); + + if (psample_group) + psample_group_put(psample_group); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; +put_chain: + if (goto_ch) + 
tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_sample_cleanup(struct tc_action *a) +{ + struct tcf_sample *s = to_sample(a); + struct psample_group *psample_group; + + /* last reference to action, no need to lock */ + psample_group = rcu_dereference_protected(s->psample_group, 1); + RCU_INIT_POINTER(s->psample_group, NULL); + if (psample_group) + psample_group_put(psample_group); +} + +static bool tcf_sample_dev_ok_push(struct net_device *dev) +{ + switch (dev->type) { + case ARPHRD_TUNNEL: + case ARPHRD_TUNNEL6: + case ARPHRD_SIT: + case ARPHRD_IPGRE: + case ARPHRD_IP6GRE: + case ARPHRD_VOID: + case ARPHRD_NONE: + return false; + default: + return true; + } +} + +static int tcf_sample_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_sample *s = to_sample(a); + struct psample_group *psample_group; + struct psample_metadata md = {}; + int retval; + + tcf_lastuse_update(&s->tcf_tm); + bstats_update(this_cpu_ptr(s->common.cpu_bstats), skb); + retval = READ_ONCE(s->tcf_action); + + psample_group = rcu_dereference_bh(s->psample_group); + + /* randomly sample packets according to rate */ + if (psample_group && (prandom_u32_max(s->rate) == 0)) { + if (!skb_at_tc_ingress(skb)) { + md.in_ifindex = skb->skb_iif; + md.out_ifindex = skb->dev->ifindex; + } else { + md.in_ifindex = skb->dev->ifindex; + } + + /* on ingress, the mac header gets popped, so push it back */ + if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev)) + skb_push(skb, skb->mac_len); + + md.trunc_size = s->truncate ? s->trunc_size : skb->len; + psample_sample_packet(psample_group, skb, s->rate, &md); + + if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev)) + skb_pull(skb, skb->mac_len); + } + + return retval; +} + +static void tcf_sample_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_sample *s = to_sample(a); + struct tcf_t *tm = &s->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static int tcf_sample_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_sample *s = to_sample(a); + struct tc_sample opt = { + .index = s->tcf_index, + .refcnt = refcount_read(&s->tcf_refcnt) - ref, + .bindcnt = atomic_read(&s->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&s->tcf_lock); + opt.action = s->tcf_action; + if (nla_put(skb, TCA_SAMPLE_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + tcf_tm_dump(&t, &s->tcf_tm); + if (nla_put_64bit(skb, TCA_SAMPLE_TM, sizeof(t), &t, TCA_SAMPLE_PAD)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_SAMPLE_RATE, s->rate)) + goto nla_put_failure; + + if (s->truncate) + if (nla_put_u32(skb, TCA_SAMPLE_TRUNC_SIZE, s->trunc_size)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_SAMPLE_PSAMPLE_GROUP, s->psample_group_num)) + goto nla_put_failure; + spin_unlock_bh(&s->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&s->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_psample_group_put(void *priv) +{ + struct psample_group *group = priv; + + psample_group_put(group); +} + +static struct psample_group * +tcf_sample_get_group(const struct tc_action *a, + tc_action_priv_destructor *destructor) +{ + struct tcf_sample *s = to_sample(a); + struct psample_group *group; + + group = 
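/*
 * Illustrative sketch, not part of this patch: tcf_sample_act() above hands
 * a packet to the psample group when a uniform random draw below 'rate'
 * hits zero, i.e. on average one packet in every 'rate'. This is a
 * userspace analogue; rand_u32_below() is an assumption standing in for the
 * kernel's prandom_u32_max(), not a kernel API.
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

static uint32_t rand_u32_below(uint32_t ceil)
{
	/* modulo bias is irrelevant for a sketch; ceil must be non-zero */
	return (uint32_t)rand() % ceil;
}

static bool should_sample(uint32_t rate)
{
	/* rate == 1 samples every packet, rate == N roughly one in N */
	return rate && rand_u32_below(rate) == 0;
}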
rcu_dereference_protected(s->psample_group, + lockdep_is_held(&s->tcf_lock)); + if (group) { + psample_group_take(group); + *destructor = tcf_psample_group_put; + } + + return group; +} + +static void tcf_offload_sample_get_group(struct flow_action_entry *entry, + const struct tc_action *act) +{ + entry->sample.psample_group = + act->ops->get_psample_group(act, &entry->destructor); + entry->destructor_priv = entry->sample.psample_group; +} + +static int tcf_sample_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + entry->id = FLOW_ACTION_SAMPLE; + entry->sample.trunc_size = tcf_sample_trunc_size(act); + entry->sample.truncate = tcf_sample_truncate(act); + entry->sample.rate = tcf_sample_rate(act); + tcf_offload_sample_get_group(entry, act); + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + fl_action->id = FLOW_ACTION_SAMPLE; + } + + return 0; +} + +static struct tc_action_ops act_sample_ops = { + .kind = "sample", + .id = TCA_ID_SAMPLE, + .owner = THIS_MODULE, + .act = tcf_sample_act, + .stats_update = tcf_sample_stats_update, + .dump = tcf_sample_dump, + .init = tcf_sample_init, + .cleanup = tcf_sample_cleanup, + .get_psample_group = tcf_sample_get_group, + .offload_act_setup = tcf_sample_offload_act_setup, + .size = sizeof(struct tcf_sample), +}; + +static __net_init int sample_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_sample_ops.net_id); + + return tc_action_net_init(net, tn, &act_sample_ops); +} + +static void __net_exit sample_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_sample_ops.net_id); +} + +static struct pernet_operations sample_net_ops = { + .init = sample_init_net, + .exit_batch = sample_exit_net, + .id = &act_sample_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init sample_init_module(void) +{ + return tcf_register_action(&act_sample_ops, &sample_net_ops); +} + +static void __exit sample_cleanup_module(void) +{ + tcf_unregister_action(&act_sample_ops, &sample_net_ops); +} + +module_init(sample_init_module); +module_exit(sample_cleanup_module); + +MODULE_AUTHOR("Yotam Gigi "); +MODULE_DESCRIPTION("Packet sampling action"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/act_simple.c b/net/sched/act_simple.c new file mode 100644 index 000000000..18d376135 --- /dev/null +++ b/net/sched/act_simple.c @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_simple.c Simple example of an action + * + * Authors: Jamal Hadi Salim (2005-8) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static struct tc_action_ops act_simp_ops; + +#define SIMP_MAX_DATA 32 +static int tcf_simp_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_defact *d = to_defact(a); + + spin_lock(&d->tcf_lock); + tcf_lastuse_update(&d->tcf_tm); + bstats_update(&d->tcf_bstats, skb); + + /* print policy string followed by _ then packet count + * Example if this was the 3rd packet and the string was "hello" + * then it would look like "hello_3" (without quotes) + */ + pr_info("simple: %s_%llu\n", + (char *)d->tcfd_defdata, + u64_stats_read(&d->tcf_bstats.packets)); + spin_unlock(&d->tcf_lock); + return d->tcf_action; +} + +static void tcf_simp_release(struct tc_action *a) +{ + struct tcf_defact *d = 
to_defact(a); + kfree(d->tcfd_defdata); +} + +static int alloc_defdata(struct tcf_defact *d, const struct nlattr *defdata) +{ + d->tcfd_defdata = kzalloc(SIMP_MAX_DATA, GFP_KERNEL); + if (unlikely(!d->tcfd_defdata)) + return -ENOMEM; + nla_strscpy(d->tcfd_defdata, defdata, SIMP_MAX_DATA); + return 0; +} + +static int reset_policy(struct tc_action *a, const struct nlattr *defdata, + struct tc_defact *p, struct tcf_proto *tp, + struct netlink_ext_ack *extack) +{ + struct tcf_chain *goto_ch = NULL; + struct tcf_defact *d; + int err; + + err = tcf_action_check_ctrlact(p->action, tp, &goto_ch, extack); + if (err < 0) + return err; + d = to_defact(a); + spin_lock_bh(&d->tcf_lock); + goto_ch = tcf_action_set_ctrlact(a, p->action, goto_ch); + memset(d->tcfd_defdata, 0, SIMP_MAX_DATA); + nla_strscpy(d->tcfd_defdata, defdata, SIMP_MAX_DATA); + spin_unlock_bh(&d->tcf_lock); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + return 0; +} + +static const struct nla_policy simple_policy[TCA_DEF_MAX + 1] = { + [TCA_DEF_PARMS] = { .len = sizeof(struct tc_defact) }, + [TCA_DEF_DATA] = { .type = NLA_STRING, .len = SIMP_MAX_DATA }, +}; + +static int tcf_simp_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_simp_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_DEF_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_defact *parm; + struct tcf_defact *d; + bool exists = false; + int ret = 0, err; + u32 index; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_DEF_MAX, nla, simple_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_DEF_PARMS] == NULL) + return -EINVAL; + + parm = nla_data(tb[TCA_DEF_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (tb[TCA_DEF_DATA] == NULL) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, + &act_simp_ops, bind, false, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + d = to_defact(*a); + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, + extack); + if (err < 0) + goto release_idr; + + err = alloc_defdata(d, tb[TCA_DEF_DATA]); + if (err < 0) + goto put_chain; + + tcf_action_set_ctrlact(*a, parm->action, goto_ch); + ret = ACT_P_CREATED; + } else { + if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + err = -EEXIST; + goto release_idr; + } + + err = reset_policy(*a, tb[TCA_DEF_DATA], parm, tp, extack); + if (err) + goto release_idr; + } + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_simp_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_defact *d = to_defact(a); + struct tc_defact opt = { + .index = d->tcf_index, + .refcnt = refcount_read(&d->tcf_refcnt) - ref, + .bindcnt = atomic_read(&d->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&d->tcf_lock); + opt.action = d->tcf_action; + if (nla_put(skb, TCA_DEF_PARMS, sizeof(opt), &opt) || + nla_put_string(skb, TCA_DEF_DATA, d->tcfd_defdata)) + goto nla_put_failure; + + tcf_tm_dump(&t, &d->tcf_tm); + if (nla_put_64bit(skb, TCA_DEF_TM, sizeof(t), &t, 
TCA_DEF_PAD)) + goto nla_put_failure; + spin_unlock_bh(&d->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&d->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static struct tc_action_ops act_simp_ops = { + .kind = "simple", + .id = TCA_ID_SIMP, + .owner = THIS_MODULE, + .act = tcf_simp_act, + .dump = tcf_simp_dump, + .cleanup = tcf_simp_release, + .init = tcf_simp_init, + .size = sizeof(struct tcf_defact), +}; + +static __net_init int simp_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_simp_ops.net_id); + + return tc_action_net_init(net, tn, &act_simp_ops); +} + +static void __net_exit simp_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_simp_ops.net_id); +} + +static struct pernet_operations simp_net_ops = { + .init = simp_init_net, + .exit_batch = simp_exit_net, + .id = &act_simp_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim(2005)"); +MODULE_DESCRIPTION("Simple example action"); +MODULE_LICENSE("GPL"); + +static int __init simp_init_module(void) +{ + int ret = tcf_register_action(&act_simp_ops, &simp_net_ops); + if (!ret) + pr_info("Simple TC action Loaded\n"); + return ret; +} + +static void __exit simp_cleanup_module(void) +{ + tcf_unregister_action(&act_simp_ops, &simp_net_ops); +} + +module_init(simp_init_module); +module_exit(simp_cleanup_module); diff --git a/net/sched/act_skbedit.c b/net/sched/act_skbedit.c new file mode 100644 index 000000000..7f598784f --- /dev/null +++ b/net/sched/act_skbedit.c @@ -0,0 +1,452 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, Intel Corporation. + * + * Author: Alexander Duyck + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static struct tc_action_ops act_skbedit_ops; + +static u16 tcf_skbedit_hash(struct tcf_skbedit_params *params, + struct sk_buff *skb) +{ + u16 queue_mapping = params->queue_mapping; + + if (params->flags & SKBEDIT_F_TXQ_SKBHASH) { + u32 hash = skb_get_hash(skb); + + queue_mapping += hash % params->mapping_mod; + } + + return netdev_cap_txqueue(skb->dev, queue_mapping); +} + +static int tcf_skbedit_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_skbedit *d = to_skbedit(a); + struct tcf_skbedit_params *params; + int action; + + tcf_lastuse_update(&d->tcf_tm); + bstats_update(this_cpu_ptr(d->common.cpu_bstats), skb); + + params = rcu_dereference_bh(d->params); + action = READ_ONCE(d->tcf_action); + + if (params->flags & SKBEDIT_F_PRIORITY) + skb->priority = params->priority; + if (params->flags & SKBEDIT_F_INHERITDSFIELD) { + int wlen = skb_network_offset(skb); + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + wlen += sizeof(struct iphdr); + if (!pskb_may_pull(skb, wlen)) + goto err; + skb->priority = ipv4_get_dsfield(ip_hdr(skb)) >> 2; + break; + + case htons(ETH_P_IPV6): + wlen += sizeof(struct ipv6hdr); + if (!pskb_may_pull(skb, wlen)) + goto err; + skb->priority = ipv6_get_dsfield(ipv6_hdr(skb)) >> 2; + break; + } + } + if (params->flags & SKBEDIT_F_QUEUE_MAPPING && + skb->dev->real_num_tx_queues > params->queue_mapping) { +#ifdef CONFIG_NET_EGRESS + netdev_xmit_skip_txqueue(true); +#endif + skb_set_queue_mapping(skb, tcf_skbedit_hash(params, skb)); + } + if (params->flags & SKBEDIT_F_MARK) { + skb->mark &= ~params->mask; + skb->mark |= params->mark & params->mask; + } + if (params->flags & SKBEDIT_F_PTYPE) + skb->pkt_type = 
params->ptype; + return action; + +err: + qstats_drop_inc(this_cpu_ptr(d->common.cpu_qstats)); + return TC_ACT_SHOT; +} + +static void tcf_skbedit_stats_update(struct tc_action *a, u64 bytes, + u64 packets, u64 drops, + u64 lastuse, bool hw) +{ + struct tcf_skbedit *d = to_skbedit(a); + struct tcf_t *tm = &d->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static const struct nla_policy skbedit_policy[TCA_SKBEDIT_MAX + 1] = { + [TCA_SKBEDIT_PARMS] = { .len = sizeof(struct tc_skbedit) }, + [TCA_SKBEDIT_PRIORITY] = { .len = sizeof(u32) }, + [TCA_SKBEDIT_QUEUE_MAPPING] = { .len = sizeof(u16) }, + [TCA_SKBEDIT_MARK] = { .len = sizeof(u32) }, + [TCA_SKBEDIT_PTYPE] = { .len = sizeof(u16) }, + [TCA_SKBEDIT_MASK] = { .len = sizeof(u32) }, + [TCA_SKBEDIT_FLAGS] = { .len = sizeof(u64) }, + [TCA_SKBEDIT_QUEUE_MAPPING_MAX] = { .len = sizeof(u16) }, +}; + +static int tcf_skbedit_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 act_flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_skbedit_ops.net_id); + bool bind = act_flags & TCA_ACT_FLAGS_BIND; + struct tcf_skbedit_params *params_new; + struct nlattr *tb[TCA_SKBEDIT_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + struct tc_skbedit *parm; + struct tcf_skbedit *d; + u32 flags = 0, *priority = NULL, *mark = NULL, *mask = NULL; + u16 *queue_mapping = NULL, *ptype = NULL; + u16 mapping_mod = 1; + bool exists = false; + int ret = 0, err; + u32 index; + + if (nla == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_SKBEDIT_MAX, nla, + skbedit_policy, NULL); + if (err < 0) + return err; + + if (tb[TCA_SKBEDIT_PARMS] == NULL) + return -EINVAL; + + if (tb[TCA_SKBEDIT_PRIORITY] != NULL) { + flags |= SKBEDIT_F_PRIORITY; + priority = nla_data(tb[TCA_SKBEDIT_PRIORITY]); + } + + if (tb[TCA_SKBEDIT_QUEUE_MAPPING] != NULL) { + flags |= SKBEDIT_F_QUEUE_MAPPING; + queue_mapping = nla_data(tb[TCA_SKBEDIT_QUEUE_MAPPING]); + } + + if (tb[TCA_SKBEDIT_PTYPE] != NULL) { + ptype = nla_data(tb[TCA_SKBEDIT_PTYPE]); + if (!skb_pkt_type_ok(*ptype)) + return -EINVAL; + flags |= SKBEDIT_F_PTYPE; + } + + if (tb[TCA_SKBEDIT_MARK] != NULL) { + flags |= SKBEDIT_F_MARK; + mark = nla_data(tb[TCA_SKBEDIT_MARK]); + } + + if (tb[TCA_SKBEDIT_MASK] != NULL) { + flags |= SKBEDIT_F_MASK; + mask = nla_data(tb[TCA_SKBEDIT_MASK]); + } + + if (tb[TCA_SKBEDIT_FLAGS] != NULL) { + u64 *pure_flags = nla_data(tb[TCA_SKBEDIT_FLAGS]); + + if (*pure_flags & SKBEDIT_F_TXQ_SKBHASH) { + u16 *queue_mapping_max; + + if (!tb[TCA_SKBEDIT_QUEUE_MAPPING] || + !tb[TCA_SKBEDIT_QUEUE_MAPPING_MAX]) { + NL_SET_ERR_MSG_MOD(extack, "Missing required range of queue_mapping."); + return -EINVAL; + } + + queue_mapping_max = + nla_data(tb[TCA_SKBEDIT_QUEUE_MAPPING_MAX]); + if (*queue_mapping_max < *queue_mapping) { + NL_SET_ERR_MSG_MOD(extack, "The range of queue_mapping is invalid, max < min."); + return -EINVAL; + } + + mapping_mod = *queue_mapping_max - *queue_mapping + 1; + flags |= SKBEDIT_F_TXQ_SKBHASH; + } + if (*pure_flags & SKBEDIT_F_INHERITDSFIELD) + flags |= SKBEDIT_F_INHERITDSFIELD; + } + + parm = nla_data(tb[TCA_SKBEDIT_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (!flags) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + + if (!exists) { 
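/*
 * Illustrative sketch, not part of this patch: with SKBEDIT_F_TXQ_SKBHASH,
 * tcf_skbedit_hash() above spreads flows across the configured queue range
 * [queue_mapping, queue_mapping_max] by adding (flow hash % mapping_mod),
 * where mapping_mod = queue_mapping_max - queue_mapping + 1 as computed in
 * the init path just shown. pick_txq() and skb_hash are names assumed for
 * the sketch; the kernel additionally clamps the result with
 * netdev_cap_txqueue().
 */
#include <stdint.h>

static uint16_t pick_txq(uint16_t queue_mapping, uint16_t mapping_mod,
			 uint32_t skb_hash)
{
	/* mapping_mod is at least 1, so the modulo is always defined */
	return queue_mapping + (uint16_t)(skb_hash % mapping_mod);
}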
+ ret = tcf_idr_create(tn, index, est, a, + &act_skbedit_ops, bind, true, act_flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + d = to_skbedit(*a); + ret = ACT_P_CREATED; + } else { + d = to_skbedit(*a); + if (!(act_flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + } + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + params_new = kzalloc(sizeof(*params_new), GFP_KERNEL); + if (unlikely(!params_new)) { + err = -ENOMEM; + goto put_chain; + } + + params_new->flags = flags; + if (flags & SKBEDIT_F_PRIORITY) + params_new->priority = *priority; + if (flags & SKBEDIT_F_QUEUE_MAPPING) { + params_new->queue_mapping = *queue_mapping; + params_new->mapping_mod = mapping_mod; + } + if (flags & SKBEDIT_F_MARK) + params_new->mark = *mark; + if (flags & SKBEDIT_F_PTYPE) + params_new->ptype = *ptype; + /* default behaviour is to use all the bits */ + params_new->mask = 0xffffffff; + if (flags & SKBEDIT_F_MASK) + params_new->mask = *mask; + + spin_lock_bh(&d->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + params_new = rcu_replace_pointer(d->params, params_new, + lockdep_is_held(&d->tcf_lock)); + spin_unlock_bh(&d->tcf_lock); + if (params_new) + kfree_rcu(params_new, rcu); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static int tcf_skbedit_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_skbedit *d = to_skbedit(a); + struct tcf_skbedit_params *params; + struct tc_skbedit opt = { + .index = d->tcf_index, + .refcnt = refcount_read(&d->tcf_refcnt) - ref, + .bindcnt = atomic_read(&d->tcf_bindcnt) - bind, + }; + u64 pure_flags = 0; + struct tcf_t t; + + spin_lock_bh(&d->tcf_lock); + params = rcu_dereference_protected(d->params, + lockdep_is_held(&d->tcf_lock)); + opt.action = d->tcf_action; + + if (nla_put(skb, TCA_SKBEDIT_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + if ((params->flags & SKBEDIT_F_PRIORITY) && + nla_put_u32(skb, TCA_SKBEDIT_PRIORITY, params->priority)) + goto nla_put_failure; + if ((params->flags & SKBEDIT_F_QUEUE_MAPPING) && + nla_put_u16(skb, TCA_SKBEDIT_QUEUE_MAPPING, params->queue_mapping)) + goto nla_put_failure; + if ((params->flags & SKBEDIT_F_MARK) && + nla_put_u32(skb, TCA_SKBEDIT_MARK, params->mark)) + goto nla_put_failure; + if ((params->flags & SKBEDIT_F_PTYPE) && + nla_put_u16(skb, TCA_SKBEDIT_PTYPE, params->ptype)) + goto nla_put_failure; + if ((params->flags & SKBEDIT_F_MASK) && + nla_put_u32(skb, TCA_SKBEDIT_MASK, params->mask)) + goto nla_put_failure; + if (params->flags & SKBEDIT_F_INHERITDSFIELD) + pure_flags |= SKBEDIT_F_INHERITDSFIELD; + if (params->flags & SKBEDIT_F_TXQ_SKBHASH) { + if (nla_put_u16(skb, TCA_SKBEDIT_QUEUE_MAPPING_MAX, + params->queue_mapping + params->mapping_mod - 1)) + goto nla_put_failure; + + pure_flags |= SKBEDIT_F_TXQ_SKBHASH; + } + if (pure_flags != 0 && + nla_put(skb, TCA_SKBEDIT_FLAGS, sizeof(pure_flags), &pure_flags)) + goto nla_put_failure; + + tcf_tm_dump(&t, &d->tcf_tm); + if (nla_put_64bit(skb, TCA_SKBEDIT_TM, sizeof(t), &t, TCA_SKBEDIT_PAD)) + goto nla_put_failure; + spin_unlock_bh(&d->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&d->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_skbedit_cleanup(struct tc_action *a) +{ + struct 
tcf_skbedit *d = to_skbedit(a); + struct tcf_skbedit_params *params; + + params = rcu_dereference_protected(d->params, 1); + if (params) + kfree_rcu(params, rcu); +} + +static size_t tcf_skbedit_get_fill_size(const struct tc_action *act) +{ + return nla_total_size(sizeof(struct tc_skbedit)) + + nla_total_size(sizeof(u32)) /* TCA_SKBEDIT_PRIORITY */ + + nla_total_size(sizeof(u16)) /* TCA_SKBEDIT_QUEUE_MAPPING */ + + nla_total_size(sizeof(u16)) /* TCA_SKBEDIT_QUEUE_MAPPING_MAX */ + + nla_total_size(sizeof(u32)) /* TCA_SKBEDIT_MARK */ + + nla_total_size(sizeof(u16)) /* TCA_SKBEDIT_PTYPE */ + + nla_total_size(sizeof(u32)) /* TCA_SKBEDIT_MASK */ + + nla_total_size_64bit(sizeof(u64)); /* TCA_SKBEDIT_FLAGS */ +} + +static int tcf_skbedit_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + if (is_tcf_skbedit_mark(act)) { + entry->id = FLOW_ACTION_MARK; + entry->mark = tcf_skbedit_mark(act); + } else if (is_tcf_skbedit_ptype(act)) { + entry->id = FLOW_ACTION_PTYPE; + entry->ptype = tcf_skbedit_ptype(act); + } else if (is_tcf_skbedit_priority(act)) { + entry->id = FLOW_ACTION_PRIORITY; + entry->priority = tcf_skbedit_priority(act); + } else if (is_tcf_skbedit_queue_mapping(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"queue_mapping\" option is used"); + return -EOPNOTSUPP; + } else if (is_tcf_skbedit_inheritdsfield(act)) { + NL_SET_ERR_MSG_MOD(extack, "Offload not supported when \"inheritdsfield\" option is used"); + return -EOPNOTSUPP; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported skbedit option offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + if (is_tcf_skbedit_mark(act)) + fl_action->id = FLOW_ACTION_MARK; + else if (is_tcf_skbedit_ptype(act)) + fl_action->id = FLOW_ACTION_PTYPE; + else if (is_tcf_skbedit_priority(act)) + fl_action->id = FLOW_ACTION_PRIORITY; + else + return -EOPNOTSUPP; + } + + return 0; +} + +static struct tc_action_ops act_skbedit_ops = { + .kind = "skbedit", + .id = TCA_ID_SKBEDIT, + .owner = THIS_MODULE, + .act = tcf_skbedit_act, + .stats_update = tcf_skbedit_stats_update, + .dump = tcf_skbedit_dump, + .init = tcf_skbedit_init, + .cleanup = tcf_skbedit_cleanup, + .get_fill_size = tcf_skbedit_get_fill_size, + .offload_act_setup = tcf_skbedit_offload_act_setup, + .size = sizeof(struct tcf_skbedit), +}; + +static __net_init int skbedit_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_skbedit_ops.net_id); + + return tc_action_net_init(net, tn, &act_skbedit_ops); +} + +static void __net_exit skbedit_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_skbedit_ops.net_id); +} + +static struct pernet_operations skbedit_net_ops = { + .init = skbedit_init_net, + .exit_batch = skbedit_exit_net, + .id = &act_skbedit_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Alexander Duyck, "); +MODULE_DESCRIPTION("SKB Editing"); +MODULE_LICENSE("GPL"); + +static int __init skbedit_init_module(void) +{ + return tcf_register_action(&act_skbedit_ops, &skbedit_net_ops); +} + +static void __exit skbedit_cleanup_module(void) +{ + tcf_unregister_action(&act_skbedit_ops, &skbedit_net_ops); +} + +module_init(skbedit_init_module); +module_exit(skbedit_cleanup_module); diff --git a/net/sched/act_skbmod.c b/net/sched/act_skbmod.c new file mode 100644 index 000000000..d98758a63 --- /dev/null 
+++ b/net/sched/act_skbmod.c @@ -0,0 +1,323 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/act_skbmod.c skb data modifier + * + * Copyright (c) 2016 Jamal Hadi Salim +*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static struct tc_action_ops act_skbmod_ops; + +static int tcf_skbmod_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_skbmod *d = to_skbmod(a); + int action, max_edit_len, err; + struct tcf_skbmod_params *p; + u64 flags; + + tcf_lastuse_update(&d->tcf_tm); + bstats_update(this_cpu_ptr(d->common.cpu_bstats), skb); + + action = READ_ONCE(d->tcf_action); + if (unlikely(action == TC_ACT_SHOT)) + goto drop; + + max_edit_len = skb_mac_header_len(skb); + p = rcu_dereference_bh(d->skbmod_p); + flags = p->flags; + + /* tcf_skbmod_init() guarantees "flags" to be one of the following: + * 1. a combination of SKBMOD_F_{DMAC,SMAC,ETYPE} + * 2. SKBMOD_F_SWAPMAC + * 3. SKBMOD_F_ECN + * SKBMOD_F_ECN only works with IP packets; all other flags only work with Ethernet + * packets. + */ + if (flags == SKBMOD_F_ECN) { + switch (skb_protocol(skb, true)) { + case cpu_to_be16(ETH_P_IP): + case cpu_to_be16(ETH_P_IPV6): + max_edit_len += skb_network_header_len(skb); + break; + default: + goto out; + } + } else if (!skb->dev || skb->dev->type != ARPHRD_ETHER) { + goto out; + } + + err = skb_ensure_writable(skb, max_edit_len); + if (unlikely(err)) /* best policy is to drop on the floor */ + goto drop; + + if (flags & SKBMOD_F_DMAC) + ether_addr_copy(eth_hdr(skb)->h_dest, p->eth_dst); + if (flags & SKBMOD_F_SMAC) + ether_addr_copy(eth_hdr(skb)->h_source, p->eth_src); + if (flags & SKBMOD_F_ETYPE) + eth_hdr(skb)->h_proto = p->eth_type; + + if (flags & SKBMOD_F_SWAPMAC) { + u16 tmpaddr[ETH_ALEN / 2]; /* ether_addr_copy() requirement */ + /*XXX: I am sure we can come up with more efficient swapping*/ + ether_addr_copy((u8 *)tmpaddr, eth_hdr(skb)->h_dest); + ether_addr_copy(eth_hdr(skb)->h_dest, eth_hdr(skb)->h_source); + ether_addr_copy(eth_hdr(skb)->h_source, (u8 *)tmpaddr); + } + + if (flags & SKBMOD_F_ECN) + INET_ECN_set_ce(skb); + +out: + return action; + +drop: + qstats_overlimit_inc(this_cpu_ptr(d->common.cpu_qstats)); + return TC_ACT_SHOT; +} + +static const struct nla_policy skbmod_policy[TCA_SKBMOD_MAX + 1] = { + [TCA_SKBMOD_PARMS] = { .len = sizeof(struct tc_skbmod) }, + [TCA_SKBMOD_DMAC] = { .len = ETH_ALEN }, + [TCA_SKBMOD_SMAC] = { .len = ETH_ALEN }, + [TCA_SKBMOD_ETYPE] = { .type = NLA_U16 }, +}; + +static int tcf_skbmod_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_skbmod_ops.net_id); + bool ovr = flags & TCA_ACT_FLAGS_REPLACE; + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_SKBMOD_MAX + 1]; + struct tcf_skbmod_params *p, *p_old; + struct tcf_chain *goto_ch = NULL; + struct tc_skbmod *parm; + u32 lflags = 0, index; + struct tcf_skbmod *d; + bool exists = false; + u8 *daddr = NULL; + u8 *saddr = NULL; + u16 eth_type = 0; + int ret = 0, err; + + if (!nla) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_SKBMOD_MAX, nla, + skbmod_policy, NULL); + if (err < 0) + return err; + + if (!tb[TCA_SKBMOD_PARMS]) + return -EINVAL; + + if (tb[TCA_SKBMOD_DMAC]) { + daddr = nla_data(tb[TCA_SKBMOD_DMAC]); + lflags |= SKBMOD_F_DMAC; + } + + if (tb[TCA_SKBMOD_SMAC]) { + saddr = 
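/*
 * Illustrative sketch, not part of this patch: the SKBMOD_F_SWAPMAC branch
 * in tcf_skbmod_act() above swaps the 6-byte destination and source MAC
 * addresses through a temporary buffer. A plain userspace equivalent:
 */
#include <stdint.h>
#include <string.h>

#define MAC_LEN 6	/* stands in for ETH_ALEN */

static void swap_macs(uint8_t *h_dest, uint8_t *h_source)
{
	uint8_t tmp[MAC_LEN];

	memcpy(tmp, h_dest, MAC_LEN);
	memcpy(h_dest, h_source, MAC_LEN);
	memcpy(h_source, tmp, MAC_LEN);
}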
nla_data(tb[TCA_SKBMOD_SMAC]); + lflags |= SKBMOD_F_SMAC; + } + + if (tb[TCA_SKBMOD_ETYPE]) { + eth_type = nla_get_u16(tb[TCA_SKBMOD_ETYPE]); + lflags |= SKBMOD_F_ETYPE; + } + + parm = nla_data(tb[TCA_SKBMOD_PARMS]); + index = parm->index; + if (parm->flags & SKBMOD_F_SWAPMAC) + lflags = SKBMOD_F_SWAPMAC; + if (parm->flags & SKBMOD_F_ECN) + lflags = SKBMOD_F_ECN; + + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + if (!lflags) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + + if (!exists) { + ret = tcf_idr_create(tn, index, est, a, + &act_skbmod_ops, bind, true, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + ret = ACT_P_CREATED; + } else if (!ovr) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + d = to_skbmod(*a); + + p = kzalloc(sizeof(struct tcf_skbmod_params), GFP_KERNEL); + if (unlikely(!p)) { + err = -ENOMEM; + goto put_chain; + } + + p->flags = lflags; + + if (ovr) + spin_lock_bh(&d->tcf_lock); + /* Protected by tcf_lock if overwriting existing action. */ + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + p_old = rcu_dereference_protected(d->skbmod_p, 1); + + if (lflags & SKBMOD_F_DMAC) + ether_addr_copy(p->eth_dst, daddr); + if (lflags & SKBMOD_F_SMAC) + ether_addr_copy(p->eth_src, saddr); + if (lflags & SKBMOD_F_ETYPE) + p->eth_type = htons(eth_type); + + rcu_assign_pointer(d->skbmod_p, p); + if (ovr) + spin_unlock_bh(&d->tcf_lock); + + if (p_old) + kfree_rcu(p_old, rcu); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_skbmod_cleanup(struct tc_action *a) +{ + struct tcf_skbmod *d = to_skbmod(a); + struct tcf_skbmod_params *p; + + p = rcu_dereference_protected(d->skbmod_p, 1); + if (p) + kfree_rcu(p, rcu); +} + +static int tcf_skbmod_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + struct tcf_skbmod *d = to_skbmod(a); + unsigned char *b = skb_tail_pointer(skb); + struct tcf_skbmod_params *p; + struct tc_skbmod opt = { + .index = d->tcf_index, + .refcnt = refcount_read(&d->tcf_refcnt) - ref, + .bindcnt = atomic_read(&d->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&d->tcf_lock); + opt.action = d->tcf_action; + p = rcu_dereference_protected(d->skbmod_p, + lockdep_is_held(&d->tcf_lock)); + opt.flags = p->flags; + if (nla_put(skb, TCA_SKBMOD_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + if ((p->flags & SKBMOD_F_DMAC) && + nla_put(skb, TCA_SKBMOD_DMAC, ETH_ALEN, p->eth_dst)) + goto nla_put_failure; + if ((p->flags & SKBMOD_F_SMAC) && + nla_put(skb, TCA_SKBMOD_SMAC, ETH_ALEN, p->eth_src)) + goto nla_put_failure; + if ((p->flags & SKBMOD_F_ETYPE) && + nla_put_u16(skb, TCA_SKBMOD_ETYPE, ntohs(p->eth_type))) + goto nla_put_failure; + + tcf_tm_dump(&t, &d->tcf_tm); + if (nla_put_64bit(skb, TCA_SKBMOD_TM, sizeof(t), &t, TCA_SKBMOD_PAD)) + goto nla_put_failure; + + spin_unlock_bh(&d->tcf_lock); + return skb->len; +nla_put_failure: + spin_unlock_bh(&d->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static struct tc_action_ops act_skbmod_ops = { + .kind = "skbmod", + .id = TCA_ACT_SKBMOD, + .owner = THIS_MODULE, + .act = tcf_skbmod_act, + .dump = tcf_skbmod_dump, + .init = tcf_skbmod_init, + .cleanup = 
tcf_skbmod_cleanup, + .size = sizeof(struct tcf_skbmod), +}; + +static __net_init int skbmod_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_skbmod_ops.net_id); + + return tc_action_net_init(net, tn, &act_skbmod_ops); +} + +static void __net_exit skbmod_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_skbmod_ops.net_id); +} + +static struct pernet_operations skbmod_net_ops = { + .init = skbmod_init_net, + .exit_batch = skbmod_exit_net, + .id = &act_skbmod_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +MODULE_AUTHOR("Jamal Hadi Salim, "); +MODULE_DESCRIPTION("SKB data mod-ing"); +MODULE_LICENSE("GPL"); + +static int __init skbmod_init_module(void) +{ + return tcf_register_action(&act_skbmod_ops, &skbmod_net_ops); +} + +static void __exit skbmod_cleanup_module(void) +{ + tcf_unregister_action(&act_skbmod_ops, &skbmod_net_ops); +} + +module_init(skbmod_init_module); +module_exit(skbmod_cleanup_module); diff --git a/net/sched/act_tunnel_key.c b/net/sched/act_tunnel_key.c new file mode 100644 index 000000000..2691a3d8e --- /dev/null +++ b/net/sched/act_tunnel_key.c @@ -0,0 +1,873 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2016, Amir Vadai + * Copyright (c) 2016, Mellanox Technologies. All rights reserved. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static struct tc_action_ops act_tunnel_key_ops; + +static int tunnel_key_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_tunnel_key *t = to_tunnel_key(a); + struct tcf_tunnel_key_params *params; + int action; + + params = rcu_dereference_bh(t->params); + + tcf_lastuse_update(&t->tcf_tm); + tcf_action_update_bstats(&t->common, skb); + action = READ_ONCE(t->tcf_action); + + switch (params->tcft_action) { + case TCA_TUNNEL_KEY_ACT_RELEASE: + skb_dst_drop(skb); + break; + case TCA_TUNNEL_KEY_ACT_SET: + skb_dst_drop(skb); + skb_dst_set(skb, dst_clone(¶ms->tcft_enc_metadata->dst)); + break; + default: + WARN_ONCE(1, "Bad tunnel_key action %d.\n", + params->tcft_action); + break; + } + + return action; +} + +static const struct nla_policy +enc_opts_policy[TCA_TUNNEL_KEY_ENC_OPTS_MAX + 1] = { + [TCA_TUNNEL_KEY_ENC_OPTS_UNSPEC] = { + .strict_start_type = TCA_TUNNEL_KEY_ENC_OPTS_VXLAN }, + [TCA_TUNNEL_KEY_ENC_OPTS_GENEVE] = { .type = NLA_NESTED }, + [TCA_TUNNEL_KEY_ENC_OPTS_VXLAN] = { .type = NLA_NESTED }, + [TCA_TUNNEL_KEY_ENC_OPTS_ERSPAN] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy +geneve_opt_policy[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_MAX + 1] = { + [TCA_TUNNEL_KEY_ENC_OPT_GENEVE_CLASS] = { .type = NLA_U16 }, + [TCA_TUNNEL_KEY_ENC_OPT_GENEVE_TYPE] = { .type = NLA_U8 }, + [TCA_TUNNEL_KEY_ENC_OPT_GENEVE_DATA] = { .type = NLA_BINARY, + .len = 128 }, +}; + +static const struct nla_policy +vxlan_opt_policy[TCA_TUNNEL_KEY_ENC_OPT_VXLAN_MAX + 1] = { + [TCA_TUNNEL_KEY_ENC_OPT_VXLAN_GBP] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +erspan_opt_policy[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_MAX + 1] = { + [TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_VER] = { .type = NLA_U8 }, + [TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_INDEX] = { .type = NLA_U32 }, + [TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_DIR] = { .type = NLA_U8 }, + [TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_HWID] = { .type = NLA_U8 }, +}; + +static int +tunnel_key_copy_geneve_opt(const struct nlattr *nla, void *dst, int dst_len, + struct netlink_ext_ack *extack) +{ + struct nlattr 
*tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_MAX + 1]; + int err, data_len, opt_len; + u8 *data; + + err = nla_parse_nested_deprecated(tb, + TCA_TUNNEL_KEY_ENC_OPT_GENEVE_MAX, + nla, geneve_opt_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_CLASS] || + !tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_TYPE] || + !tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_DATA]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key geneve option class, type or data"); + return -EINVAL; + } + + data = nla_data(tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_DATA]); + data_len = nla_len(tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_DATA]); + if (data_len < 4) { + NL_SET_ERR_MSG(extack, "Tunnel key geneve option data is less than 4 bytes long"); + return -ERANGE; + } + if (data_len % 4) { + NL_SET_ERR_MSG(extack, "Tunnel key geneve option data is not a multiple of 4 bytes long"); + return -ERANGE; + } + + opt_len = sizeof(struct geneve_opt) + data_len; + if (dst) { + struct geneve_opt *opt = dst; + + WARN_ON(dst_len < opt_len); + + opt->opt_class = + nla_get_be16(tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_CLASS]); + opt->type = nla_get_u8(tb[TCA_TUNNEL_KEY_ENC_OPT_GENEVE_TYPE]); + opt->length = data_len / 4; /* length is in units of 4 bytes */ + opt->r1 = 0; + opt->r2 = 0; + opt->r3 = 0; + + memcpy(opt + 1, data, data_len); + } + + return opt_len; +} + +static int +tunnel_key_copy_vxlan_opt(const struct nlattr *nla, void *dst, int dst_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TUNNEL_KEY_ENC_OPT_VXLAN_MAX + 1]; + int err; + + err = nla_parse_nested(tb, TCA_TUNNEL_KEY_ENC_OPT_VXLAN_MAX, nla, + vxlan_opt_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_TUNNEL_KEY_ENC_OPT_VXLAN_GBP]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key vxlan option gbp"); + return -EINVAL; + } + + if (dst) { + struct vxlan_metadata *md = dst; + + md->gbp = nla_get_u32(tb[TCA_TUNNEL_KEY_ENC_OPT_VXLAN_GBP]); + md->gbp &= VXLAN_GBP_MASK; + } + + return sizeof(struct vxlan_metadata); +} + +static int +tunnel_key_copy_erspan_opt(const struct nlattr *nla, void *dst, int dst_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_MAX + 1]; + int err; + u8 ver; + + err = nla_parse_nested(tb, TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_MAX, nla, + erspan_opt_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_VER]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option ver"); + return -EINVAL; + } + + ver = nla_get_u8(tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_VER]); + if (ver == 1) { + if (!tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_INDEX]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option index"); + return -EINVAL; + } + } else if (ver == 2) { + if (!tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_DIR] || + !tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_HWID]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option dir or hwid"); + return -EINVAL; + } + } else { + NL_SET_ERR_MSG(extack, "Tunnel key erspan option ver is incorrect"); + return -EINVAL; + } + + if (dst) { + struct erspan_metadata *md = dst; + + md->version = ver; + if (ver == 1) { + nla = tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_INDEX]; + md->u.index = nla_get_be32(nla); + } else { + nla = tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_DIR]; + md->u.md2.dir = nla_get_u8(nla); + nla = tb[TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_HWID]; + set_hwid(&md->u.md2, nla_get_u8(nla)); + } + } + + return sizeof(struct erspan_metadata); +} + +static int tunnel_key_copy_opts(const struct nlattr *nla, u8 *dst, + int dst_len, struct netlink_ext_ack *extack) +{ + int err, rem, opt_len, len = 
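/*
 * Illustrative sketch, not part of this patch: tunnel_key_copy_geneve_opt()
 * above accepts GENEVE option data only if it is at least 4 bytes long and
 * a multiple of 4, and encodes the on-wire length field in 4-byte words on
 * top of the fixed option header. The 4-byte header constant below is an
 * assumption standing in for sizeof(struct geneve_opt).
 */
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#define GENEVE_OPT_HDR_LEN 4	/* class (2) + type (1) + flags/length (1) */

static bool geneve_opt_size(size_t data_len, size_t *opt_len, uint8_t *wire_len)
{
	if (data_len < 4 || data_len % 4)
		return false;			/* rejected with -ERANGE above */
	*opt_len = GENEVE_OPT_HDR_LEN + data_len;
	*wire_len = (uint8_t)(data_len / 4);	/* length is in 4-byte units */
	return true;
}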
nla_len(nla), opts_len = 0, type = 0; + const struct nlattr *attr, *head = nla_data(nla); + + err = nla_validate_deprecated(head, len, TCA_TUNNEL_KEY_ENC_OPTS_MAX, + enc_opts_policy, extack); + if (err) + return err; + + nla_for_each_attr(attr, head, len, rem) { + switch (nla_type(attr)) { + case TCA_TUNNEL_KEY_ENC_OPTS_GENEVE: + if (type && type != TUNNEL_GENEVE_OPT) { + NL_SET_ERR_MSG(extack, "Duplicate type for geneve options"); + return -EINVAL; + } + opt_len = tunnel_key_copy_geneve_opt(attr, dst, + dst_len, extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + if (opts_len > IP_TUNNEL_OPTS_MAX) { + NL_SET_ERR_MSG(extack, "Tunnel options exceeds max size"); + return -EINVAL; + } + if (dst) { + dst_len -= opt_len; + dst += opt_len; + } + type = TUNNEL_GENEVE_OPT; + break; + case TCA_TUNNEL_KEY_ENC_OPTS_VXLAN: + if (type) { + NL_SET_ERR_MSG(extack, "Duplicate type for vxlan options"); + return -EINVAL; + } + opt_len = tunnel_key_copy_vxlan_opt(attr, dst, + dst_len, extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + type = TUNNEL_VXLAN_OPT; + break; + case TCA_TUNNEL_KEY_ENC_OPTS_ERSPAN: + if (type) { + NL_SET_ERR_MSG(extack, "Duplicate type for erspan options"); + return -EINVAL; + } + opt_len = tunnel_key_copy_erspan_opt(attr, dst, + dst_len, extack); + if (opt_len < 0) + return opt_len; + opts_len += opt_len; + type = TUNNEL_ERSPAN_OPT; + break; + } + } + + if (!opts_len) { + NL_SET_ERR_MSG(extack, "Empty list of tunnel options"); + return -EINVAL; + } + + if (rem > 0) { + NL_SET_ERR_MSG(extack, "Trailing data after parsing tunnel key options attributes"); + return -EINVAL; + } + + return opts_len; +} + +static int tunnel_key_get_opts_len(struct nlattr *nla, + struct netlink_ext_ack *extack) +{ + return tunnel_key_copy_opts(nla, NULL, 0, extack); +} + +static int tunnel_key_opts_set(struct nlattr *nla, struct ip_tunnel_info *info, + int opts_len, struct netlink_ext_ack *extack) +{ + info->options_len = opts_len; + switch (nla_type(nla_data(nla))) { + case TCA_TUNNEL_KEY_ENC_OPTS_GENEVE: +#if IS_ENABLED(CONFIG_INET) + info->key.tun_flags |= TUNNEL_GENEVE_OPT; + return tunnel_key_copy_opts(nla, ip_tunnel_info_opts(info), + opts_len, extack); +#else + return -EAFNOSUPPORT; +#endif + case TCA_TUNNEL_KEY_ENC_OPTS_VXLAN: +#if IS_ENABLED(CONFIG_INET) + info->key.tun_flags |= TUNNEL_VXLAN_OPT; + return tunnel_key_copy_opts(nla, ip_tunnel_info_opts(info), + opts_len, extack); +#else + return -EAFNOSUPPORT; +#endif + case TCA_TUNNEL_KEY_ENC_OPTS_ERSPAN: +#if IS_ENABLED(CONFIG_INET) + info->key.tun_flags |= TUNNEL_ERSPAN_OPT; + return tunnel_key_copy_opts(nla, ip_tunnel_info_opts(info), + opts_len, extack); +#else + return -EAFNOSUPPORT; +#endif + default: + NL_SET_ERR_MSG(extack, "Cannot set tunnel options for unknown tunnel type"); + return -EINVAL; + } +} + +static const struct nla_policy tunnel_key_policy[TCA_TUNNEL_KEY_MAX + 1] = { + [TCA_TUNNEL_KEY_PARMS] = { .len = sizeof(struct tc_tunnel_key) }, + [TCA_TUNNEL_KEY_ENC_IPV4_SRC] = { .type = NLA_U32 }, + [TCA_TUNNEL_KEY_ENC_IPV4_DST] = { .type = NLA_U32 }, + [TCA_TUNNEL_KEY_ENC_IPV6_SRC] = { .len = sizeof(struct in6_addr) }, + [TCA_TUNNEL_KEY_ENC_IPV6_DST] = { .len = sizeof(struct in6_addr) }, + [TCA_TUNNEL_KEY_ENC_KEY_ID] = { .type = NLA_U32 }, + [TCA_TUNNEL_KEY_ENC_DST_PORT] = {.type = NLA_U16}, + [TCA_TUNNEL_KEY_NO_CSUM] = { .type = NLA_U8 }, + [TCA_TUNNEL_KEY_ENC_OPTS] = { .type = NLA_NESTED }, + [TCA_TUNNEL_KEY_ENC_TOS] = { .type = NLA_U8 }, + [TCA_TUNNEL_KEY_ENC_TTL] = { .type = NLA_U8 }, +}; + 
+static void tunnel_key_release_params(struct tcf_tunnel_key_params *p) +{ + if (!p) + return; + if (p->tcft_action == TCA_TUNNEL_KEY_ACT_SET) + dst_release(&p->tcft_enc_metadata->dst); + + kfree_rcu(p, rcu); +} + +static int tunnel_key_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 act_flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_tunnel_key_ops.net_id); + bool bind = act_flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_TUNNEL_KEY_MAX + 1]; + struct tcf_tunnel_key_params *params_new; + struct metadata_dst *metadata = NULL; + struct tcf_chain *goto_ch = NULL; + struct tc_tunnel_key *parm; + struct tcf_tunnel_key *t; + bool exists = false; + __be16 dst_port = 0; + __be64 key_id = 0; + int opts_len = 0; + __be16 flags = 0; + u8 tos, ttl; + int ret = 0; + u32 index; + int err; + + if (!nla) { + NL_SET_ERR_MSG(extack, "Tunnel requires attributes to be passed"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_TUNNEL_KEY_MAX, nla, + tunnel_key_policy, extack); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Failed to parse nested tunnel key attributes"); + return err; + } + + if (!tb[TCA_TUNNEL_KEY_PARMS]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key parameters"); + return -EINVAL; + } + + parm = nla_data(tb[TCA_TUNNEL_KEY_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + switch (parm->t_action) { + case TCA_TUNNEL_KEY_ACT_RELEASE: + break; + case TCA_TUNNEL_KEY_ACT_SET: + if (tb[TCA_TUNNEL_KEY_ENC_KEY_ID]) { + __be32 key32; + + key32 = nla_get_be32(tb[TCA_TUNNEL_KEY_ENC_KEY_ID]); + key_id = key32_to_tunnel_id(key32); + flags = TUNNEL_KEY; + } + + flags |= TUNNEL_CSUM; + if (tb[TCA_TUNNEL_KEY_NO_CSUM] && + nla_get_u8(tb[TCA_TUNNEL_KEY_NO_CSUM])) + flags &= ~TUNNEL_CSUM; + + if (tb[TCA_TUNNEL_KEY_ENC_DST_PORT]) + dst_port = nla_get_be16(tb[TCA_TUNNEL_KEY_ENC_DST_PORT]); + + if (tb[TCA_TUNNEL_KEY_ENC_OPTS]) { + opts_len = tunnel_key_get_opts_len(tb[TCA_TUNNEL_KEY_ENC_OPTS], + extack); + if (opts_len < 0) { + ret = opts_len; + goto err_out; + } + } + + tos = 0; + if (tb[TCA_TUNNEL_KEY_ENC_TOS]) + tos = nla_get_u8(tb[TCA_TUNNEL_KEY_ENC_TOS]); + ttl = 0; + if (tb[TCA_TUNNEL_KEY_ENC_TTL]) + ttl = nla_get_u8(tb[TCA_TUNNEL_KEY_ENC_TTL]); + + if (tb[TCA_TUNNEL_KEY_ENC_IPV4_SRC] && + tb[TCA_TUNNEL_KEY_ENC_IPV4_DST]) { + __be32 saddr; + __be32 daddr; + + saddr = nla_get_in_addr(tb[TCA_TUNNEL_KEY_ENC_IPV4_SRC]); + daddr = nla_get_in_addr(tb[TCA_TUNNEL_KEY_ENC_IPV4_DST]); + + metadata = __ip_tun_set_dst(saddr, daddr, tos, ttl, + dst_port, flags, + key_id, opts_len); + } else if (tb[TCA_TUNNEL_KEY_ENC_IPV6_SRC] && + tb[TCA_TUNNEL_KEY_ENC_IPV6_DST]) { + struct in6_addr saddr; + struct in6_addr daddr; + + saddr = nla_get_in6_addr(tb[TCA_TUNNEL_KEY_ENC_IPV6_SRC]); + daddr = nla_get_in6_addr(tb[TCA_TUNNEL_KEY_ENC_IPV6_DST]); + + metadata = __ipv6_tun_set_dst(&saddr, &daddr, tos, ttl, dst_port, + 0, flags, + key_id, opts_len); + } else { + NL_SET_ERR_MSG(extack, "Missing either ipv4 or ipv6 src and dst"); + ret = -EINVAL; + goto err_out; + } + + if (!metadata) { + NL_SET_ERR_MSG(extack, "Cannot allocate tunnel metadata dst"); + ret = -ENOMEM; + goto err_out; + } + +#ifdef CONFIG_DST_CACHE + ret = dst_cache_init(&metadata->u.tun_info.dst_cache, GFP_KERNEL); + if (ret) + goto release_tun_meta; +#endif + + if (opts_len) { + ret = 
tunnel_key_opts_set(tb[TCA_TUNNEL_KEY_ENC_OPTS], + &metadata->u.tun_info, + opts_len, extack); + if (ret < 0) + goto release_tun_meta; + } + + metadata->u.tun_info.mode |= IP_TUNNEL_INFO_TX; + break; + default: + NL_SET_ERR_MSG(extack, "Unknown tunnel key action"); + ret = -EINVAL; + goto err_out; + } + + if (!exists) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_tunnel_key_ops, bind, + act_flags); + if (ret) { + NL_SET_ERR_MSG(extack, "Cannot create TC IDR"); + goto release_tun_meta; + } + + ret = ACT_P_CREATED; + } else if (!(act_flags & TCA_ACT_FLAGS_REPLACE)) { + NL_SET_ERR_MSG(extack, "TC IDR already exists"); + ret = -EEXIST; + goto release_tun_meta; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) { + ret = err; + exists = true; + goto release_tun_meta; + } + t = to_tunnel_key(*a); + + params_new = kzalloc(sizeof(*params_new), GFP_KERNEL); + if (unlikely(!params_new)) { + NL_SET_ERR_MSG(extack, "Cannot allocate tunnel key parameters"); + ret = -ENOMEM; + exists = true; + goto put_chain; + } + params_new->tcft_action = parm->t_action; + params_new->tcft_enc_metadata = metadata; + + spin_lock_bh(&t->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + params_new = rcu_replace_pointer(t->params, params_new, + lockdep_is_held(&t->tcf_lock)); + spin_unlock_bh(&t->tcf_lock); + tunnel_key_release_params(params_new); + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + + return ret; + +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + +release_tun_meta: + if (metadata) + dst_release(&metadata->dst); + +err_out: + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return ret; +} + +static void tunnel_key_release(struct tc_action *a) +{ + struct tcf_tunnel_key *t = to_tunnel_key(a); + struct tcf_tunnel_key_params *params; + + params = rcu_dereference_protected(t->params, 1); + tunnel_key_release_params(params); +} + +static int tunnel_key_geneve_opts_dump(struct sk_buff *skb, + const struct ip_tunnel_info *info) +{ + int len = info->options_len; + u8 *src = (u8 *)(info + 1); + struct nlattr *start; + + start = nla_nest_start_noflag(skb, TCA_TUNNEL_KEY_ENC_OPTS_GENEVE); + if (!start) + return -EMSGSIZE; + + while (len > 0) { + struct geneve_opt *opt = (struct geneve_opt *)src; + + if (nla_put_be16(skb, TCA_TUNNEL_KEY_ENC_OPT_GENEVE_CLASS, + opt->opt_class) || + nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_OPT_GENEVE_TYPE, + opt->type) || + nla_put(skb, TCA_TUNNEL_KEY_ENC_OPT_GENEVE_DATA, + opt->length * 4, opt + 1)) { + nla_nest_cancel(skb, start); + return -EMSGSIZE; + } + + len -= sizeof(struct geneve_opt) + opt->length * 4; + src += sizeof(struct geneve_opt) + opt->length * 4; + } + + nla_nest_end(skb, start); + return 0; +} + +static int tunnel_key_vxlan_opts_dump(struct sk_buff *skb, + const struct ip_tunnel_info *info) +{ + struct vxlan_metadata *md = (struct vxlan_metadata *)(info + 1); + struct nlattr *start; + + start = nla_nest_start_noflag(skb, TCA_TUNNEL_KEY_ENC_OPTS_VXLAN); + if (!start) + return -EMSGSIZE; + + if (nla_put_u32(skb, TCA_TUNNEL_KEY_ENC_OPT_VXLAN_GBP, md->gbp)) { + nla_nest_cancel(skb, start); + return -EMSGSIZE; + } + + nla_nest_end(skb, start); + return 0; +} + +static int tunnel_key_erspan_opts_dump(struct sk_buff *skb, + const struct ip_tunnel_info *info) +{ + struct erspan_metadata *md = (struct erspan_metadata *)(info + 1); + struct nlattr *start; + + start = nla_nest_start_noflag(skb, TCA_TUNNEL_KEY_ENC_OPTS_ERSPAN); + if (!start) + return -EMSGSIZE; + + 
if (nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_VER, md->version)) + goto err; + + if (md->version == 1 && + nla_put_be32(skb, TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_INDEX, md->u.index)) + goto err; + + if (md->version == 2 && + (nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_DIR, + md->u.md2.dir) || + nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_OPT_ERSPAN_HWID, + get_hwid(&md->u.md2)))) + goto err; + + nla_nest_end(skb, start); + return 0; +err: + nla_nest_cancel(skb, start); + return -EMSGSIZE; +} + +static int tunnel_key_opts_dump(struct sk_buff *skb, + const struct ip_tunnel_info *info) +{ + struct nlattr *start; + int err = -EINVAL; + + if (!info->options_len) + return 0; + + start = nla_nest_start_noflag(skb, TCA_TUNNEL_KEY_ENC_OPTS); + if (!start) + return -EMSGSIZE; + + if (info->key.tun_flags & TUNNEL_GENEVE_OPT) { + err = tunnel_key_geneve_opts_dump(skb, info); + if (err) + goto err_out; + } else if (info->key.tun_flags & TUNNEL_VXLAN_OPT) { + err = tunnel_key_vxlan_opts_dump(skb, info); + if (err) + goto err_out; + } else if (info->key.tun_flags & TUNNEL_ERSPAN_OPT) { + err = tunnel_key_erspan_opts_dump(skb, info); + if (err) + goto err_out; + } else { +err_out: + nla_nest_cancel(skb, start); + return err; + } + + nla_nest_end(skb, start); + return 0; +} + +static int tunnel_key_dump_addresses(struct sk_buff *skb, + const struct ip_tunnel_info *info) +{ + unsigned short family = ip_tunnel_info_af(info); + + if (family == AF_INET) { + __be32 saddr = info->key.u.ipv4.src; + __be32 daddr = info->key.u.ipv4.dst; + + if (!nla_put_in_addr(skb, TCA_TUNNEL_KEY_ENC_IPV4_SRC, saddr) && + !nla_put_in_addr(skb, TCA_TUNNEL_KEY_ENC_IPV4_DST, daddr)) + return 0; + } + + if (family == AF_INET6) { + const struct in6_addr *saddr6 = &info->key.u.ipv6.src; + const struct in6_addr *daddr6 = &info->key.u.ipv6.dst; + + if (!nla_put_in6_addr(skb, + TCA_TUNNEL_KEY_ENC_IPV6_SRC, saddr6) && + !nla_put_in6_addr(skb, + TCA_TUNNEL_KEY_ENC_IPV6_DST, daddr6)) + return 0; + } + + return -EINVAL; +} + +static int tunnel_key_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_tunnel_key *t = to_tunnel_key(a); + struct tcf_tunnel_key_params *params; + struct tc_tunnel_key opt = { + .index = t->tcf_index, + .refcnt = refcount_read(&t->tcf_refcnt) - ref, + .bindcnt = atomic_read(&t->tcf_bindcnt) - bind, + }; + struct tcf_t tm; + + spin_lock_bh(&t->tcf_lock); + params = rcu_dereference_protected(t->params, + lockdep_is_held(&t->tcf_lock)); + opt.action = t->tcf_action; + opt.t_action = params->tcft_action; + + if (nla_put(skb, TCA_TUNNEL_KEY_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + if (params->tcft_action == TCA_TUNNEL_KEY_ACT_SET) { + struct ip_tunnel_info *info = + ¶ms->tcft_enc_metadata->u.tun_info; + struct ip_tunnel_key *key = &info->key; + __be32 key_id = tunnel_id_to_key32(key->tun_id); + + if (((key->tun_flags & TUNNEL_KEY) && + nla_put_be32(skb, TCA_TUNNEL_KEY_ENC_KEY_ID, key_id)) || + tunnel_key_dump_addresses(skb, + ¶ms->tcft_enc_metadata->u.tun_info) || + (key->tp_dst && + nla_put_be16(skb, TCA_TUNNEL_KEY_ENC_DST_PORT, + key->tp_dst)) || + nla_put_u8(skb, TCA_TUNNEL_KEY_NO_CSUM, + !(key->tun_flags & TUNNEL_CSUM)) || + tunnel_key_opts_dump(skb, info)) + goto nla_put_failure; + + if (key->tos && nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_TOS, key->tos)) + goto nla_put_failure; + + if (key->ttl && nla_put_u8(skb, TCA_TUNNEL_KEY_ENC_TTL, key->ttl)) + goto nla_put_failure; + } + + tcf_tm_dump(&tm, &t->tcf_tm); + if (nla_put_64bit(skb, 
TCA_TUNNEL_KEY_TM, sizeof(tm), + &tm, TCA_TUNNEL_KEY_PAD)) + goto nla_put_failure; + spin_unlock_bh(&t->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&t->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_tunnel_encap_put_tunnel(void *priv) +{ + struct ip_tunnel_info *tunnel = priv; + + kfree(tunnel); +} + +static int tcf_tunnel_encap_get_tunnel(struct flow_action_entry *entry, + const struct tc_action *act) +{ + entry->tunnel = tcf_tunnel_info_copy(act); + if (!entry->tunnel) + return -ENOMEM; + entry->destructor = tcf_tunnel_encap_put_tunnel; + entry->destructor_priv = entry->tunnel; + return 0; +} + +static int tcf_tunnel_key_offload_act_setup(struct tc_action *act, + void *entry_data, + u32 *index_inc, + bool bind, + struct netlink_ext_ack *extack) +{ + int err; + + if (bind) { + struct flow_action_entry *entry = entry_data; + + if (is_tcf_tunnel_set(act)) { + entry->id = FLOW_ACTION_TUNNEL_ENCAP; + err = tcf_tunnel_encap_get_tunnel(entry, act); + if (err) + return err; + } else if (is_tcf_tunnel_release(act)) { + entry->id = FLOW_ACTION_TUNNEL_DECAP; + } else { + NL_SET_ERR_MSG_MOD(extack, "Unsupported tunnel key mode offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + if (is_tcf_tunnel_set(act)) + fl_action->id = FLOW_ACTION_TUNNEL_ENCAP; + else if (is_tcf_tunnel_release(act)) + fl_action->id = FLOW_ACTION_TUNNEL_DECAP; + else + return -EOPNOTSUPP; + } + + return 0; +} + +static struct tc_action_ops act_tunnel_key_ops = { + .kind = "tunnel_key", + .id = TCA_ID_TUNNEL_KEY, + .owner = THIS_MODULE, + .act = tunnel_key_act, + .dump = tunnel_key_dump, + .init = tunnel_key_init, + .cleanup = tunnel_key_release, + .offload_act_setup = tcf_tunnel_key_offload_act_setup, + .size = sizeof(struct tcf_tunnel_key), +}; + +static __net_init int tunnel_key_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_tunnel_key_ops.net_id); + + return tc_action_net_init(net, tn, &act_tunnel_key_ops); +} + +static void __net_exit tunnel_key_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_tunnel_key_ops.net_id); +} + +static struct pernet_operations tunnel_key_net_ops = { + .init = tunnel_key_init_net, + .exit_batch = tunnel_key_exit_net, + .id = &act_tunnel_key_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int __init tunnel_key_init_module(void) +{ + return tcf_register_action(&act_tunnel_key_ops, &tunnel_key_net_ops); +} + +static void __exit tunnel_key_cleanup_module(void) +{ + tcf_unregister_action(&act_tunnel_key_ops, &tunnel_key_net_ops); +} + +module_init(tunnel_key_init_module); +module_exit(tunnel_key_cleanup_module); + +MODULE_AUTHOR("Amir Vadai "); +MODULE_DESCRIPTION("ip tunnel manipulation actions"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/act_vlan.c b/net/sched/act_vlan.c new file mode 100644 index 000000000..7b24e898a --- /dev/null +++ b/net/sched/act_vlan.c @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Copyright (c) 2014 Jiri Pirko + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static struct tc_action_ops act_vlan_ops; + +static int tcf_vlan_act(struct sk_buff *skb, const struct tc_action *a, + struct tcf_result *res) +{ + struct tcf_vlan *v = to_vlan(a); + struct tcf_vlan_params *p; + int action; + int err; + u16 tci; + + tcf_lastuse_update(&v->tcf_tm); + tcf_action_update_bstats(&v->common, skb); + + 
/* Ensure 'data' points at mac_header prior calling vlan manipulating + * functions. + */ + if (skb_at_tc_ingress(skb)) + skb_push_rcsum(skb, skb->mac_len); + + action = READ_ONCE(v->tcf_action); + + p = rcu_dereference_bh(v->vlan_p); + + switch (p->tcfv_action) { + case TCA_VLAN_ACT_POP: + err = skb_vlan_pop(skb); + if (err) + goto drop; + break; + case TCA_VLAN_ACT_PUSH: + err = skb_vlan_push(skb, p->tcfv_push_proto, p->tcfv_push_vid | + (p->tcfv_push_prio << VLAN_PRIO_SHIFT)); + if (err) + goto drop; + break; + case TCA_VLAN_ACT_MODIFY: + /* No-op if no vlan tag (either hw-accel or in-payload) */ + if (!skb_vlan_tagged(skb)) + goto out; + /* extract existing tag (and guarantee no hw-accel tag) */ + if (skb_vlan_tag_present(skb)) { + tci = skb_vlan_tag_get(skb); + __vlan_hwaccel_clear_tag(skb); + } else { + /* in-payload vlan tag, pop it */ + err = __skb_vlan_pop(skb, &tci); + if (err) + goto drop; + } + /* replace the vid */ + tci = (tci & ~VLAN_VID_MASK) | p->tcfv_push_vid; + /* replace prio bits, if tcfv_push_prio specified */ + if (p->tcfv_push_prio_exists) { + tci &= ~VLAN_PRIO_MASK; + tci |= p->tcfv_push_prio << VLAN_PRIO_SHIFT; + } + /* put updated tci as hwaccel tag */ + __vlan_hwaccel_put_tag(skb, p->tcfv_push_proto, tci); + break; + case TCA_VLAN_ACT_POP_ETH: + err = skb_eth_pop(skb); + if (err) + goto drop; + break; + case TCA_VLAN_ACT_PUSH_ETH: + err = skb_eth_push(skb, p->tcfv_push_dst, p->tcfv_push_src); + if (err) + goto drop; + break; + default: + BUG(); + } + +out: + if (skb_at_tc_ingress(skb)) + skb_pull_rcsum(skb, skb->mac_len); + + return action; + +drop: + tcf_action_inc_drop_qstats(&v->common); + return TC_ACT_SHOT; +} + +static const struct nla_policy vlan_policy[TCA_VLAN_MAX + 1] = { + [TCA_VLAN_UNSPEC] = { .strict_start_type = TCA_VLAN_PUSH_ETH_DST }, + [TCA_VLAN_PARMS] = { .len = sizeof(struct tc_vlan) }, + [TCA_VLAN_PUSH_VLAN_ID] = { .type = NLA_U16 }, + [TCA_VLAN_PUSH_VLAN_PROTOCOL] = { .type = NLA_U16 }, + [TCA_VLAN_PUSH_VLAN_PRIORITY] = { .type = NLA_U8 }, + [TCA_VLAN_PUSH_ETH_DST] = NLA_POLICY_ETH_ADDR, + [TCA_VLAN_PUSH_ETH_SRC] = NLA_POLICY_ETH_ADDR, +}; + +static int tcf_vlan_init(struct net *net, struct nlattr *nla, + struct nlattr *est, struct tc_action **a, + struct tcf_proto *tp, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_action_net *tn = net_generic(net, act_vlan_ops.net_id); + bool bind = flags & TCA_ACT_FLAGS_BIND; + struct nlattr *tb[TCA_VLAN_MAX + 1]; + struct tcf_chain *goto_ch = NULL; + bool push_prio_exists = false; + struct tcf_vlan_params *p; + struct tc_vlan *parm; + struct tcf_vlan *v; + int action; + u16 push_vid = 0; + __be16 push_proto = 0; + u8 push_prio = 0; + bool exists = false; + int ret = 0, err; + u32 index; + + if (!nla) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_VLAN_MAX, nla, vlan_policy, + NULL); + if (err < 0) + return err; + + if (!tb[TCA_VLAN_PARMS]) + return -EINVAL; + parm = nla_data(tb[TCA_VLAN_PARMS]); + index = parm->index; + err = tcf_idr_check_alloc(tn, &index, a, bind); + if (err < 0) + return err; + exists = err; + if (exists && bind) + return 0; + + switch (parm->v_action) { + case TCA_VLAN_ACT_POP: + break; + case TCA_VLAN_ACT_PUSH: + case TCA_VLAN_ACT_MODIFY: + if (!tb[TCA_VLAN_PUSH_VLAN_ID]) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + push_vid = nla_get_u16(tb[TCA_VLAN_PUSH_VLAN_ID]); + if (push_vid >= VLAN_VID_MASK) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return 
-ERANGE; + } + + if (tb[TCA_VLAN_PUSH_VLAN_PROTOCOL]) { + push_proto = nla_get_be16(tb[TCA_VLAN_PUSH_VLAN_PROTOCOL]); + switch (push_proto) { + case htons(ETH_P_8021Q): + case htons(ETH_P_8021AD): + break; + default: + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EPROTONOSUPPORT; + } + } else { + push_proto = htons(ETH_P_8021Q); + } + + push_prio_exists = !!tb[TCA_VLAN_PUSH_VLAN_PRIORITY]; + if (push_prio_exists) + push_prio = nla_get_u8(tb[TCA_VLAN_PUSH_VLAN_PRIORITY]); + break; + case TCA_VLAN_ACT_POP_ETH: + break; + case TCA_VLAN_ACT_PUSH_ETH: + if (!tb[TCA_VLAN_PUSH_ETH_DST] || !tb[TCA_VLAN_PUSH_ETH_SRC]) { + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + break; + default: + if (exists) + tcf_idr_release(*a, bind); + else + tcf_idr_cleanup(tn, index); + return -EINVAL; + } + action = parm->v_action; + + if (!exists) { + ret = tcf_idr_create_from_flags(tn, index, est, a, + &act_vlan_ops, bind, flags); + if (ret) { + tcf_idr_cleanup(tn, index); + return ret; + } + + ret = ACT_P_CREATED; + } else if (!(flags & TCA_ACT_FLAGS_REPLACE)) { + tcf_idr_release(*a, bind); + return -EEXIST; + } + + err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack); + if (err < 0) + goto release_idr; + + v = to_vlan(*a); + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (!p) { + err = -ENOMEM; + goto put_chain; + } + + p->tcfv_action = action; + p->tcfv_push_vid = push_vid; + p->tcfv_push_prio = push_prio; + p->tcfv_push_prio_exists = push_prio_exists || action == TCA_VLAN_ACT_PUSH; + p->tcfv_push_proto = push_proto; + + if (action == TCA_VLAN_ACT_PUSH_ETH) { + nla_memcpy(&p->tcfv_push_dst, tb[TCA_VLAN_PUSH_ETH_DST], + ETH_ALEN); + nla_memcpy(&p->tcfv_push_src, tb[TCA_VLAN_PUSH_ETH_SRC], + ETH_ALEN); + } + + spin_lock_bh(&v->tcf_lock); + goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch); + p = rcu_replace_pointer(v->vlan_p, p, lockdep_is_held(&v->tcf_lock)); + spin_unlock_bh(&v->tcf_lock); + + if (goto_ch) + tcf_chain_put_by_act(goto_ch); + if (p) + kfree_rcu(p, rcu); + + return ret; +put_chain: + if (goto_ch) + tcf_chain_put_by_act(goto_ch); +release_idr: + tcf_idr_release(*a, bind); + return err; +} + +static void tcf_vlan_cleanup(struct tc_action *a) +{ + struct tcf_vlan *v = to_vlan(a); + struct tcf_vlan_params *p; + + p = rcu_dereference_protected(v->vlan_p, 1); + if (p) + kfree_rcu(p, rcu); +} + +static int tcf_vlan_dump(struct sk_buff *skb, struct tc_action *a, + int bind, int ref) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tcf_vlan *v = to_vlan(a); + struct tcf_vlan_params *p; + struct tc_vlan opt = { + .index = v->tcf_index, + .refcnt = refcount_read(&v->tcf_refcnt) - ref, + .bindcnt = atomic_read(&v->tcf_bindcnt) - bind, + }; + struct tcf_t t; + + spin_lock_bh(&v->tcf_lock); + opt.action = v->tcf_action; + p = rcu_dereference_protected(v->vlan_p, lockdep_is_held(&v->tcf_lock)); + opt.v_action = p->tcfv_action; + if (nla_put(skb, TCA_VLAN_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + if ((p->tcfv_action == TCA_VLAN_ACT_PUSH || + p->tcfv_action == TCA_VLAN_ACT_MODIFY) && + (nla_put_u16(skb, TCA_VLAN_PUSH_VLAN_ID, p->tcfv_push_vid) || + nla_put_be16(skb, TCA_VLAN_PUSH_VLAN_PROTOCOL, + p->tcfv_push_proto) || + (p->tcfv_push_prio_exists && + nla_put_u8(skb, TCA_VLAN_PUSH_VLAN_PRIORITY, p->tcfv_push_prio)))) + goto nla_put_failure; + + if (p->tcfv_action == TCA_VLAN_ACT_PUSH_ETH) { + if (nla_put(skb, TCA_VLAN_PUSH_ETH_DST, ETH_ALEN, + p->tcfv_push_dst)) + goto 
nla_put_failure; + if (nla_put(skb, TCA_VLAN_PUSH_ETH_SRC, ETH_ALEN, + p->tcfv_push_src)) + goto nla_put_failure; + } + + tcf_tm_dump(&t, &v->tcf_tm); + if (nla_put_64bit(skb, TCA_VLAN_TM, sizeof(t), &t, TCA_VLAN_PAD)) + goto nla_put_failure; + spin_unlock_bh(&v->tcf_lock); + + return skb->len; + +nla_put_failure: + spin_unlock_bh(&v->tcf_lock); + nlmsg_trim(skb, b); + return -1; +} + +static void tcf_vlan_stats_update(struct tc_action *a, u64 bytes, u64 packets, + u64 drops, u64 lastuse, bool hw) +{ + struct tcf_vlan *v = to_vlan(a); + struct tcf_t *tm = &v->tcf_tm; + + tcf_action_update_stats(a, bytes, packets, drops, hw); + tm->lastuse = max_t(u64, tm->lastuse, lastuse); +} + +static size_t tcf_vlan_get_fill_size(const struct tc_action *act) +{ + return nla_total_size(sizeof(struct tc_vlan)) + + nla_total_size(sizeof(u16)) /* TCA_VLAN_PUSH_VLAN_ID */ + + nla_total_size(sizeof(u16)) /* TCA_VLAN_PUSH_VLAN_PROTOCOL */ + + nla_total_size(sizeof(u8)); /* TCA_VLAN_PUSH_VLAN_PRIORITY */ +} + +static int tcf_vlan_offload_act_setup(struct tc_action *act, void *entry_data, + u32 *index_inc, bool bind, + struct netlink_ext_ack *extack) +{ + if (bind) { + struct flow_action_entry *entry = entry_data; + + switch (tcf_vlan_action(act)) { + case TCA_VLAN_ACT_PUSH: + entry->id = FLOW_ACTION_VLAN_PUSH; + entry->vlan.vid = tcf_vlan_push_vid(act); + entry->vlan.proto = tcf_vlan_push_proto(act); + entry->vlan.prio = tcf_vlan_push_prio(act); + break; + case TCA_VLAN_ACT_POP: + entry->id = FLOW_ACTION_VLAN_POP; + break; + case TCA_VLAN_ACT_MODIFY: + entry->id = FLOW_ACTION_VLAN_MANGLE; + entry->vlan.vid = tcf_vlan_push_vid(act); + entry->vlan.proto = tcf_vlan_push_proto(act); + entry->vlan.prio = tcf_vlan_push_prio(act); + break; + case TCA_VLAN_ACT_POP_ETH: + entry->id = FLOW_ACTION_VLAN_POP_ETH; + break; + case TCA_VLAN_ACT_PUSH_ETH: + entry->id = FLOW_ACTION_VLAN_PUSH_ETH; + tcf_vlan_push_eth(entry->vlan_push_eth.src, entry->vlan_push_eth.dst, act); + break; + default: + NL_SET_ERR_MSG_MOD(extack, "Unsupported vlan action mode offload"); + return -EOPNOTSUPP; + } + *index_inc = 1; + } else { + struct flow_offload_action *fl_action = entry_data; + + switch (tcf_vlan_action(act)) { + case TCA_VLAN_ACT_PUSH: + fl_action->id = FLOW_ACTION_VLAN_PUSH; + break; + case TCA_VLAN_ACT_POP: + fl_action->id = FLOW_ACTION_VLAN_POP; + break; + case TCA_VLAN_ACT_MODIFY: + fl_action->id = FLOW_ACTION_VLAN_MANGLE; + break; + case TCA_VLAN_ACT_POP_ETH: + fl_action->id = FLOW_ACTION_VLAN_POP_ETH; + break; + case TCA_VLAN_ACT_PUSH_ETH: + fl_action->id = FLOW_ACTION_VLAN_PUSH_ETH; + break; + default: + return -EOPNOTSUPP; + } + } + + return 0; +} + +static struct tc_action_ops act_vlan_ops = { + .kind = "vlan", + .id = TCA_ID_VLAN, + .owner = THIS_MODULE, + .act = tcf_vlan_act, + .dump = tcf_vlan_dump, + .init = tcf_vlan_init, + .cleanup = tcf_vlan_cleanup, + .stats_update = tcf_vlan_stats_update, + .get_fill_size = tcf_vlan_get_fill_size, + .offload_act_setup = tcf_vlan_offload_act_setup, + .size = sizeof(struct tcf_vlan), +}; + +static __net_init int vlan_init_net(struct net *net) +{ + struct tc_action_net *tn = net_generic(net, act_vlan_ops.net_id); + + return tc_action_net_init(net, tn, &act_vlan_ops); +} + +static void __net_exit vlan_exit_net(struct list_head *net_list) +{ + tc_action_net_exit(net_list, act_vlan_ops.net_id); +} + +static struct pernet_operations vlan_net_ops = { + .init = vlan_init_net, + .exit_batch = vlan_exit_net, + .id = &act_vlan_ops.net_id, + .size = sizeof(struct tc_action_net), +}; + +static int 
__init vlan_init_module(void) +{ + return tcf_register_action(&act_vlan_ops, &vlan_net_ops); +} + +static void __exit vlan_cleanup_module(void) +{ + tcf_unregister_action(&act_vlan_ops, &vlan_net_ops); +} + +module_init(vlan_init_module); +module_exit(vlan_cleanup_module); + +MODULE_AUTHOR("Jiri Pirko "); +MODULE_DESCRIPTION("vlan manipulation actions"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c new file mode 100644 index 000000000..445ab1b05 --- /dev/null +++ b/net/sched/cls_api.c @@ -0,0 +1,3785 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_api.c Packet classifier API. + * + * Authors: Alexey Kuznetsov, + * + * Changes: + * + * Eduardo J. Blanco :990222: kmod support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* The list of all installed classifier types */ +static LIST_HEAD(tcf_proto_base); + +/* Protects list of registered TC modules. It is pure SMP lock. */ +static DEFINE_RWLOCK(cls_mod_lock); + +#ifdef CONFIG_NET_CLS_ACT +DEFINE_STATIC_KEY_FALSE(tc_skb_ext_tc); +EXPORT_SYMBOL(tc_skb_ext_tc); + +void tc_skb_ext_tc_enable(void) +{ + static_branch_inc(&tc_skb_ext_tc); +} +EXPORT_SYMBOL(tc_skb_ext_tc_enable); + +void tc_skb_ext_tc_disable(void) +{ + static_branch_dec(&tc_skb_ext_tc); +} +EXPORT_SYMBOL(tc_skb_ext_tc_disable); +#endif + +static u32 destroy_obj_hashfn(const struct tcf_proto *tp) +{ + return jhash_3words(tp->chain->index, tp->prio, + (__force __u32)tp->protocol, 0); +} + +static void tcf_proto_signal_destroying(struct tcf_chain *chain, + struct tcf_proto *tp) +{ + struct tcf_block *block = chain->block; + + mutex_lock(&block->proto_destroy_lock); + hash_add_rcu(block->proto_destroy_ht, &tp->destroy_ht_node, + destroy_obj_hashfn(tp)); + mutex_unlock(&block->proto_destroy_lock); +} + +static bool tcf_proto_cmp(const struct tcf_proto *tp1, + const struct tcf_proto *tp2) +{ + return tp1->chain->index == tp2->chain->index && + tp1->prio == tp2->prio && + tp1->protocol == tp2->protocol; +} + +static bool tcf_proto_exists_destroying(struct tcf_chain *chain, + struct tcf_proto *tp) +{ + u32 hash = destroy_obj_hashfn(tp); + struct tcf_proto *iter; + bool found = false; + + rcu_read_lock(); + hash_for_each_possible_rcu(chain->block->proto_destroy_ht, iter, + destroy_ht_node, hash) { + if (tcf_proto_cmp(tp, iter)) { + found = true; + break; + } + } + rcu_read_unlock(); + + return found; +} + +static void +tcf_proto_signal_destroyed(struct tcf_chain *chain, struct tcf_proto *tp) +{ + struct tcf_block *block = chain->block; + + mutex_lock(&block->proto_destroy_lock); + if (hash_hashed(&tp->destroy_ht_node)) + hash_del_rcu(&tp->destroy_ht_node); + mutex_unlock(&block->proto_destroy_lock); +} + +/* Find classifier type by string name */ + +static const struct tcf_proto_ops *__tcf_proto_lookup_ops(const char *kind) +{ + const struct tcf_proto_ops *t, *res = NULL; + + if (kind) { + read_lock(&cls_mod_lock); + list_for_each_entry(t, &tcf_proto_base, head) { + if (strcmp(kind, t->kind) == 0) { + if (try_module_get(t->owner)) + res = t; + break; + } + } + read_unlock(&cls_mod_lock); + } + return res; +} + +static const struct tcf_proto_ops * +tcf_proto_lookup_ops(const char *kind, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + const struct 
tcf_proto_ops *ops; + + ops = __tcf_proto_lookup_ops(kind); + if (ops) + return ops; +#ifdef CONFIG_MODULES + if (rtnl_held) + rtnl_unlock(); + request_module("cls_%s", kind); + if (rtnl_held) + rtnl_lock(); + ops = __tcf_proto_lookup_ops(kind); + /* We dropped the RTNL semaphore in order to perform + * the module load. So, even if we succeeded in loading + * the module we have to replay the request. We indicate + * this using -EAGAIN. + */ + if (ops) { + module_put(ops->owner); + return ERR_PTR(-EAGAIN); + } +#endif + NL_SET_ERR_MSG(extack, "TC classifier not found"); + return ERR_PTR(-ENOENT); +} + +/* Register(unregister) new classifier type */ + +int register_tcf_proto_ops(struct tcf_proto_ops *ops) +{ + struct tcf_proto_ops *t; + int rc = -EEXIST; + + write_lock(&cls_mod_lock); + list_for_each_entry(t, &tcf_proto_base, head) + if (!strcmp(ops->kind, t->kind)) + goto out; + + list_add_tail(&ops->head, &tcf_proto_base); + rc = 0; +out: + write_unlock(&cls_mod_lock); + return rc; +} +EXPORT_SYMBOL(register_tcf_proto_ops); + +static struct workqueue_struct *tc_filter_wq; + +void unregister_tcf_proto_ops(struct tcf_proto_ops *ops) +{ + struct tcf_proto_ops *t; + int rc = -ENOENT; + + /* Wait for outstanding call_rcu()s, if any, from a + * tcf_proto_ops's destroy() handler. + */ + rcu_barrier(); + flush_workqueue(tc_filter_wq); + + write_lock(&cls_mod_lock); + list_for_each_entry(t, &tcf_proto_base, head) { + if (t == ops) { + list_del(&t->head); + rc = 0; + break; + } + } + write_unlock(&cls_mod_lock); + + WARN(rc, "unregister tc filter kind(%s) failed %d\n", ops->kind, rc); +} +EXPORT_SYMBOL(unregister_tcf_proto_ops); + +bool tcf_queue_work(struct rcu_work *rwork, work_func_t func) +{ + INIT_RCU_WORK(rwork, func); + return queue_rcu_work(tc_filter_wq, rwork); +} +EXPORT_SYMBOL(tcf_queue_work); + +/* Select new prio value from the range, managed by kernel. */ + +static inline u32 tcf_auto_prio(struct tcf_proto *tp) +{ + u32 first = TC_H_MAKE(0xC0000000U, 0U); + + if (tp) + first = tp->prio - 1; + + return TC_H_MAJ(first); +} + +static bool tcf_proto_check_kind(struct nlattr *kind, char *name) +{ + if (kind) + return nla_strscpy(name, kind, IFNAMSIZ) < 0; + memset(name, 0, IFNAMSIZ); + return false; +} + +static bool tcf_proto_is_unlocked(const char *kind) +{ + const struct tcf_proto_ops *ops; + bool ret; + + if (strlen(kind) == 0) + return false; + + ops = tcf_proto_lookup_ops(kind, false, NULL); + /* On error return false to take rtnl lock. Proto lookup/create + * functions will perform lookup again and properly handle errors. 
+ */ + if (IS_ERR(ops)) + return false; + + ret = !!(ops->flags & TCF_PROTO_OPS_DOIT_UNLOCKED); + module_put(ops->owner); + return ret; +} + +static struct tcf_proto *tcf_proto_create(const char *kind, u32 protocol, + u32 prio, struct tcf_chain *chain, + bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tcf_proto *tp; + int err; + + tp = kzalloc(sizeof(*tp), GFP_KERNEL); + if (!tp) + return ERR_PTR(-ENOBUFS); + + tp->ops = tcf_proto_lookup_ops(kind, rtnl_held, extack); + if (IS_ERR(tp->ops)) { + err = PTR_ERR(tp->ops); + goto errout; + } + tp->classify = tp->ops->classify; + tp->protocol = protocol; + tp->prio = prio; + tp->chain = chain; + spin_lock_init(&tp->lock); + refcount_set(&tp->refcnt, 1); + + err = tp->ops->init(tp); + if (err) { + module_put(tp->ops->owner); + goto errout; + } + return tp; + +errout: + kfree(tp); + return ERR_PTR(err); +} + +static void tcf_proto_get(struct tcf_proto *tp) +{ + refcount_inc(&tp->refcnt); +} + +static void tcf_chain_put(struct tcf_chain *chain); + +static void tcf_proto_destroy(struct tcf_proto *tp, bool rtnl_held, + bool sig_destroy, struct netlink_ext_ack *extack) +{ + tp->ops->destroy(tp, rtnl_held, extack); + if (sig_destroy) + tcf_proto_signal_destroyed(tp->chain, tp); + tcf_chain_put(tp->chain); + module_put(tp->ops->owner); + kfree_rcu(tp, rcu); +} + +static void tcf_proto_put(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + if (refcount_dec_and_test(&tp->refcnt)) + tcf_proto_destroy(tp, rtnl_held, true, extack); +} + +static bool tcf_proto_check_delete(struct tcf_proto *tp) +{ + if (tp->ops->delete_empty) + return tp->ops->delete_empty(tp); + + tp->deleting = true; + return tp->deleting; +} + +static void tcf_proto_mark_delete(struct tcf_proto *tp) +{ + spin_lock(&tp->lock); + tp->deleting = true; + spin_unlock(&tp->lock); +} + +static bool tcf_proto_is_deleting(struct tcf_proto *tp) +{ + bool deleting; + + spin_lock(&tp->lock); + deleting = tp->deleting; + spin_unlock(&tp->lock); + + return deleting; +} + +#define ASSERT_BLOCK_LOCKED(block) \ + lockdep_assert_held(&(block)->lock) + +struct tcf_filter_chain_list_item { + struct list_head list; + tcf_chain_head_change_t *chain_head_change; + void *chain_head_change_priv; +}; + +static struct tcf_chain *tcf_chain_create(struct tcf_block *block, + u32 chain_index) +{ + struct tcf_chain *chain; + + ASSERT_BLOCK_LOCKED(block); + + chain = kzalloc(sizeof(*chain), GFP_KERNEL); + if (!chain) + return NULL; + list_add_tail_rcu(&chain->list, &block->chain_list); + mutex_init(&chain->filter_chain_lock); + chain->block = block; + chain->index = chain_index; + chain->refcnt = 1; + if (!chain->index) + block->chain0.chain = chain; + return chain; +} + +static void tcf_chain_head_change_item(struct tcf_filter_chain_list_item *item, + struct tcf_proto *tp_head) +{ + if (item->chain_head_change) + item->chain_head_change(tp_head, item->chain_head_change_priv); +} + +static void tcf_chain0_head_change(struct tcf_chain *chain, + struct tcf_proto *tp_head) +{ + struct tcf_filter_chain_list_item *item; + struct tcf_block *block = chain->block; + + if (chain->index) + return; + + mutex_lock(&block->lock); + list_for_each_entry(item, &block->chain0.filter_chain_list, list) + tcf_chain_head_change_item(item, tp_head); + mutex_unlock(&block->lock); +} + +/* Returns true if block can be safely freed. 
*/ + +static bool tcf_chain_detach(struct tcf_chain *chain) +{ + struct tcf_block *block = chain->block; + + ASSERT_BLOCK_LOCKED(block); + + list_del_rcu(&chain->list); + if (!chain->index) + block->chain0.chain = NULL; + + if (list_empty(&block->chain_list) && + refcount_read(&block->refcnt) == 0) + return true; + + return false; +} + +static void tcf_block_destroy(struct tcf_block *block) +{ + mutex_destroy(&block->lock); + mutex_destroy(&block->proto_destroy_lock); + kfree_rcu(block, rcu); +} + +static void tcf_chain_destroy(struct tcf_chain *chain, bool free_block) +{ + struct tcf_block *block = chain->block; + + mutex_destroy(&chain->filter_chain_lock); + kfree_rcu(chain, rcu); + if (free_block) + tcf_block_destroy(block); +} + +static void tcf_chain_hold(struct tcf_chain *chain) +{ + ASSERT_BLOCK_LOCKED(chain->block); + + ++chain->refcnt; +} + +static bool tcf_chain_held_by_acts_only(struct tcf_chain *chain) +{ + ASSERT_BLOCK_LOCKED(chain->block); + + /* In case all the references are action references, this + * chain should not be shown to the user. + */ + return chain->refcnt == chain->action_refcnt; +} + +static struct tcf_chain *tcf_chain_lookup(struct tcf_block *block, + u32 chain_index) +{ + struct tcf_chain *chain; + + ASSERT_BLOCK_LOCKED(block); + + list_for_each_entry(chain, &block->chain_list, list) { + if (chain->index == chain_index) + return chain; + } + return NULL; +} + +#if IS_ENABLED(CONFIG_NET_TC_SKB_EXT) +static struct tcf_chain *tcf_chain_lookup_rcu(const struct tcf_block *block, + u32 chain_index) +{ + struct tcf_chain *chain; + + list_for_each_entry_rcu(chain, &block->chain_list, list) { + if (chain->index == chain_index) + return chain; + } + return NULL; +} +#endif + +static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb, + u32 seq, u16 flags, int event, bool unicast, + struct netlink_ext_ack *extack); + +static struct tcf_chain *__tcf_chain_get(struct tcf_block *block, + u32 chain_index, bool create, + bool by_act) +{ + struct tcf_chain *chain = NULL; + bool is_first_reference; + + mutex_lock(&block->lock); + chain = tcf_chain_lookup(block, chain_index); + if (chain) { + tcf_chain_hold(chain); + } else { + if (!create) + goto errout; + chain = tcf_chain_create(block, chain_index); + if (!chain) + goto errout; + } + + if (by_act) + ++chain->action_refcnt; + is_first_reference = chain->refcnt - chain->action_refcnt == 1; + mutex_unlock(&block->lock); + + /* Send notification only in case we got the first + * non-action reference. Until then, the chain acts only as + * a placeholder for actions pointing to it and user ought + * not know about them. 
+ */ + if (is_first_reference && !by_act) + tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL, + RTM_NEWCHAIN, false, NULL); + + return chain; + +errout: + mutex_unlock(&block->lock); + return chain; +} + +static struct tcf_chain *tcf_chain_get(struct tcf_block *block, u32 chain_index, + bool create) +{ + return __tcf_chain_get(block, chain_index, create, false); +} + +struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block, u32 chain_index) +{ + return __tcf_chain_get(block, chain_index, true, true); +} +EXPORT_SYMBOL(tcf_chain_get_by_act); + +static void tc_chain_tmplt_del(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv); +static int tc_chain_notify_delete(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct tcf_block *block, struct sk_buff *oskb, + u32 seq, u16 flags, bool unicast); + +static void __tcf_chain_put(struct tcf_chain *chain, bool by_act, + bool explicitly_created) +{ + struct tcf_block *block = chain->block; + const struct tcf_proto_ops *tmplt_ops; + unsigned int refcnt, non_act_refcnt; + bool free_block = false; + void *tmplt_priv; + + mutex_lock(&block->lock); + if (explicitly_created) { + if (!chain->explicitly_created) { + mutex_unlock(&block->lock); + return; + } + chain->explicitly_created = false; + } + + if (by_act) + chain->action_refcnt--; + + /* tc_chain_notify_delete can't be called while holding block lock. + * However, when block is unlocked chain can be changed concurrently, so + * save these to temporary variables. + */ + refcnt = --chain->refcnt; + non_act_refcnt = refcnt - chain->action_refcnt; + tmplt_ops = chain->tmplt_ops; + tmplt_priv = chain->tmplt_priv; + + if (non_act_refcnt == chain->explicitly_created && !by_act) { + if (non_act_refcnt == 0) + tc_chain_notify_delete(tmplt_ops, tmplt_priv, + chain->index, block, NULL, 0, 0, + false); + /* Last reference to chain, no need to lock. 
*/ + chain->flushing = false; + } + + if (refcnt == 0) + free_block = tcf_chain_detach(chain); + mutex_unlock(&block->lock); + + if (refcnt == 0) { + tc_chain_tmplt_del(tmplt_ops, tmplt_priv); + tcf_chain_destroy(chain, free_block); + } +} + +static void tcf_chain_put(struct tcf_chain *chain) +{ + __tcf_chain_put(chain, false, false); +} + +void tcf_chain_put_by_act(struct tcf_chain *chain) +{ + __tcf_chain_put(chain, true, false); +} +EXPORT_SYMBOL(tcf_chain_put_by_act); + +static void tcf_chain_put_explicitly_created(struct tcf_chain *chain) +{ + __tcf_chain_put(chain, false, true); +} + +static void tcf_chain_flush(struct tcf_chain *chain, bool rtnl_held) +{ + struct tcf_proto *tp, *tp_next; + + mutex_lock(&chain->filter_chain_lock); + tp = tcf_chain_dereference(chain->filter_chain, chain); + while (tp) { + tp_next = rcu_dereference_protected(tp->next, 1); + tcf_proto_signal_destroying(chain, tp); + tp = tp_next; + } + tp = tcf_chain_dereference(chain->filter_chain, chain); + RCU_INIT_POINTER(chain->filter_chain, NULL); + tcf_chain0_head_change(chain, NULL); + chain->flushing = true; + mutex_unlock(&chain->filter_chain_lock); + + while (tp) { + tp_next = rcu_dereference_protected(tp->next, 1); + tcf_proto_put(tp, rtnl_held, NULL); + tp = tp_next; + } +} + +static int tcf_block_setup(struct tcf_block *block, + struct flow_block_offload *bo); + +static void tcf_block_offload_init(struct flow_block_offload *bo, + struct net_device *dev, struct Qdisc *sch, + enum flow_block_command command, + enum flow_block_binder_type binder_type, + struct flow_block *flow_block, + bool shared, struct netlink_ext_ack *extack) +{ + bo->net = dev_net(dev); + bo->command = command; + bo->binder_type = binder_type; + bo->block = flow_block; + bo->block_shared = shared; + bo->extack = extack; + bo->sch = sch; + bo->cb_list_head = &flow_block->cb_list; + INIT_LIST_HEAD(&bo->cb_list); +} + +static void tcf_block_unbind(struct tcf_block *block, + struct flow_block_offload *bo); + +static void tc_block_indr_cleanup(struct flow_block_cb *block_cb) +{ + struct tcf_block *block = block_cb->indr.data; + struct net_device *dev = block_cb->indr.dev; + struct Qdisc *sch = block_cb->indr.sch; + struct netlink_ext_ack extack = {}; + struct flow_block_offload bo = {}; + + tcf_block_offload_init(&bo, dev, sch, FLOW_BLOCK_UNBIND, + block_cb->indr.binder_type, + &block->flow_block, tcf_block_shared(block), + &extack); + rtnl_lock(); + down_write(&block->cb_lock); + list_del(&block_cb->driver_list); + list_move(&block_cb->list, &bo.cb_list); + tcf_block_unbind(block, &bo); + up_write(&block->cb_lock); + rtnl_unlock(); +} + +static bool tcf_block_offload_in_use(struct tcf_block *block) +{ + return atomic_read(&block->offloadcnt); +} + +static int tcf_block_offload_cmd(struct tcf_block *block, + struct net_device *dev, struct Qdisc *sch, + struct tcf_block_ext_info *ei, + enum flow_block_command command, + struct netlink_ext_ack *extack) +{ + struct flow_block_offload bo = {}; + + tcf_block_offload_init(&bo, dev, sch, command, ei->binder_type, + &block->flow_block, tcf_block_shared(block), + extack); + + if (dev->netdev_ops->ndo_setup_tc) { + int err; + + err = dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_BLOCK, &bo); + if (err < 0) { + if (err != -EOPNOTSUPP) + NL_SET_ERR_MSG(extack, "Driver ndo_setup_tc failed"); + return err; + } + + return tcf_block_setup(block, &bo); + } + + flow_indr_dev_setup_offload(dev, sch, TC_SETUP_BLOCK, block, &bo, + tc_block_indr_cleanup); + tcf_block_setup(block, &bo); + + return -EOPNOTSUPP; +} + 
+static int tcf_block_offload_bind(struct tcf_block *block, struct Qdisc *q, + struct tcf_block_ext_info *ei, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = q->dev_queue->dev; + int err; + + down_write(&block->cb_lock); + + /* If tc offload feature is disabled and the block we try to bind + * to already has some offloaded filters, forbid to bind. + */ + if (dev->netdev_ops->ndo_setup_tc && + !tc_can_offload(dev) && + tcf_block_offload_in_use(block)) { + NL_SET_ERR_MSG(extack, "Bind to offloaded block failed as dev has offload disabled"); + err = -EOPNOTSUPP; + goto err_unlock; + } + + err = tcf_block_offload_cmd(block, dev, q, ei, FLOW_BLOCK_BIND, extack); + if (err == -EOPNOTSUPP) + goto no_offload_dev_inc; + if (err) + goto err_unlock; + + up_write(&block->cb_lock); + return 0; + +no_offload_dev_inc: + if (tcf_block_offload_in_use(block)) + goto err_unlock; + + err = 0; + block->nooffloaddevcnt++; +err_unlock: + up_write(&block->cb_lock); + return err; +} + +static void tcf_block_offload_unbind(struct tcf_block *block, struct Qdisc *q, + struct tcf_block_ext_info *ei) +{ + struct net_device *dev = q->dev_queue->dev; + int err; + + down_write(&block->cb_lock); + err = tcf_block_offload_cmd(block, dev, q, ei, FLOW_BLOCK_UNBIND, NULL); + if (err == -EOPNOTSUPP) + goto no_offload_dev_dec; + up_write(&block->cb_lock); + return; + +no_offload_dev_dec: + WARN_ON(block->nooffloaddevcnt-- == 0); + up_write(&block->cb_lock); +} + +static int +tcf_chain0_head_change_cb_add(struct tcf_block *block, + struct tcf_block_ext_info *ei, + struct netlink_ext_ack *extack) +{ + struct tcf_filter_chain_list_item *item; + struct tcf_chain *chain0; + + item = kmalloc(sizeof(*item), GFP_KERNEL); + if (!item) { + NL_SET_ERR_MSG(extack, "Memory allocation for head change callback item failed"); + return -ENOMEM; + } + item->chain_head_change = ei->chain_head_change; + item->chain_head_change_priv = ei->chain_head_change_priv; + + mutex_lock(&block->lock); + chain0 = block->chain0.chain; + if (chain0) + tcf_chain_hold(chain0); + else + list_add(&item->list, &block->chain0.filter_chain_list); + mutex_unlock(&block->lock); + + if (chain0) { + struct tcf_proto *tp_head; + + mutex_lock(&chain0->filter_chain_lock); + + tp_head = tcf_chain_dereference(chain0->filter_chain, chain0); + if (tp_head) + tcf_chain_head_change_item(item, tp_head); + + mutex_lock(&block->lock); + list_add(&item->list, &block->chain0.filter_chain_list); + mutex_unlock(&block->lock); + + mutex_unlock(&chain0->filter_chain_lock); + tcf_chain_put(chain0); + } + + return 0; +} + +static void +tcf_chain0_head_change_cb_del(struct tcf_block *block, + struct tcf_block_ext_info *ei) +{ + struct tcf_filter_chain_list_item *item; + + mutex_lock(&block->lock); + list_for_each_entry(item, &block->chain0.filter_chain_list, list) { + if ((!ei->chain_head_change && !ei->chain_head_change_priv) || + (item->chain_head_change == ei->chain_head_change && + item->chain_head_change_priv == ei->chain_head_change_priv)) { + if (block->chain0.chain) + tcf_chain_head_change_item(item, NULL); + list_del(&item->list); + mutex_unlock(&block->lock); + + kfree(item); + return; + } + } + mutex_unlock(&block->lock); + WARN_ON(1); +} + +struct tcf_net { + spinlock_t idr_lock; /* Protects idr */ + struct idr idr; +}; + +static unsigned int tcf_net_id; + +static int tcf_block_insert(struct tcf_block *block, struct net *net, + struct netlink_ext_ack *extack) +{ + struct tcf_net *tn = net_generic(net, tcf_net_id); + int err; + + idr_preload(GFP_KERNEL); + 
spin_lock(&tn->idr_lock); + err = idr_alloc_u32(&tn->idr, block, &block->index, block->index, + GFP_NOWAIT); + spin_unlock(&tn->idr_lock); + idr_preload_end(); + + return err; +} + +static void tcf_block_remove(struct tcf_block *block, struct net *net) +{ + struct tcf_net *tn = net_generic(net, tcf_net_id); + + spin_lock(&tn->idr_lock); + idr_remove(&tn->idr, block->index); + spin_unlock(&tn->idr_lock); +} + +static struct tcf_block *tcf_block_create(struct net *net, struct Qdisc *q, + u32 block_index, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block; + + block = kzalloc(sizeof(*block), GFP_KERNEL); + if (!block) { + NL_SET_ERR_MSG(extack, "Memory allocation for block failed"); + return ERR_PTR(-ENOMEM); + } + mutex_init(&block->lock); + mutex_init(&block->proto_destroy_lock); + init_rwsem(&block->cb_lock); + flow_block_init(&block->flow_block); + INIT_LIST_HEAD(&block->chain_list); + INIT_LIST_HEAD(&block->owner_list); + INIT_LIST_HEAD(&block->chain0.filter_chain_list); + + refcount_set(&block->refcnt, 1); + block->net = net; + block->index = block_index; + + /* Don't store q pointer for blocks which are shared */ + if (!tcf_block_shared(block)) + block->q = q; + return block; +} + +static struct tcf_block *tcf_block_lookup(struct net *net, u32 block_index) +{ + struct tcf_net *tn = net_generic(net, tcf_net_id); + + return idr_find(&tn->idr, block_index); +} + +static struct tcf_block *tcf_block_refcnt_get(struct net *net, u32 block_index) +{ + struct tcf_block *block; + + rcu_read_lock(); + block = tcf_block_lookup(net, block_index); + if (block && !refcount_inc_not_zero(&block->refcnt)) + block = NULL; + rcu_read_unlock(); + + return block; +} + +static struct tcf_chain * +__tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) +{ + mutex_lock(&block->lock); + if (chain) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + else + chain = list_first_entry_or_null(&block->chain_list, + struct tcf_chain, list); + + /* skip all action-only chains */ + while (chain && tcf_chain_held_by_acts_only(chain)) + chain = list_is_last(&chain->list, &block->chain_list) ? + NULL : list_next_entry(chain, list); + + if (chain) + tcf_chain_hold(chain); + mutex_unlock(&block->lock); + + return chain; +} + +/* Function to be used by all clients that want to iterate over all chains on + * block. It properly obtains block->lock and takes reference to chain before + * returning it. Users of this function must be tolerant to concurrent chain + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_chain * +tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) +{ + struct tcf_chain *chain_next = __tcf_get_next_chain(block, chain); + + if (chain) + tcf_chain_put(chain); + + return chain_next; +} +EXPORT_SYMBOL(tcf_get_next_chain); + +static struct tcf_proto * +__tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp) +{ + u32 prio = 0; + + ASSERT_RTNL(); + mutex_lock(&chain->filter_chain_lock); + + if (!tp) { + tp = tcf_chain_dereference(chain->filter_chain, chain); + } else if (tcf_proto_is_deleting(tp)) { + /* 'deleting' flag is set and chain->filter_chain_lock was + * unlocked, which means next pointer could be invalid. Restart + * search. 
+ */ + prio = tp->prio + 1; + tp = tcf_chain_dereference(chain->filter_chain, chain); + + for (; tp; tp = tcf_chain_dereference(tp->next, chain)) + if (!tp->deleting && tp->prio >= prio) + break; + } else { + tp = tcf_chain_dereference(tp->next, chain); + } + + if (tp) + tcf_proto_get(tp); + + mutex_unlock(&chain->filter_chain_lock); + + return tp; +} + +/* Function to be used by all clients that want to iterate over all tp's on + * chain. Users of this function must be tolerant to concurrent tp + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_proto * +tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp) +{ + struct tcf_proto *tp_next = __tcf_get_next_proto(chain, tp); + + if (tp) + tcf_proto_put(tp, true, NULL); + + return tp_next; +} +EXPORT_SYMBOL(tcf_get_next_proto); + +static void tcf_block_flush_all_chains(struct tcf_block *block, bool rtnl_held) +{ + struct tcf_chain *chain; + + /* Last reference to block. At this point chains cannot be added or + * removed concurrently. + */ + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { + tcf_chain_put_explicitly_created(chain); + tcf_chain_flush(chain, rtnl_held); + } +} + +/* Lookup Qdisc and increments its reference counter. + * Set parent, if necessary. + */ + +static int __tcf_qdisc_find(struct net *net, struct Qdisc **q, + u32 *parent, int ifindex, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + const struct Qdisc_class_ops *cops; + struct net_device *dev; + int err = 0; + + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) + return 0; + + rcu_read_lock(); + + /* Find link */ + dev = dev_get_by_index_rcu(net, ifindex); + if (!dev) { + rcu_read_unlock(); + return -ENODEV; + } + + /* Find qdisc */ + if (!*parent) { + *q = rcu_dereference(dev->qdisc); + *parent = (*q)->handle; + } else { + *q = qdisc_lookup_rcu(dev, TC_H_MAJ(*parent)); + if (!*q) { + NL_SET_ERR_MSG(extack, "Parent Qdisc doesn't exists"); + err = -EINVAL; + goto errout_rcu; + } + } + + *q = qdisc_refcount_inc_nz(*q); + if (!*q) { + NL_SET_ERR_MSG(extack, "Parent Qdisc doesn't exists"); + err = -EINVAL; + goto errout_rcu; + } + + /* Is it classful? */ + cops = (*q)->ops->cl_ops; + if (!cops) { + NL_SET_ERR_MSG(extack, "Qdisc not classful"); + err = -EINVAL; + goto errout_qdisc; + } + + if (!cops->tcf_block) { + NL_SET_ERR_MSG(extack, "Class doesn't support blocks"); + err = -EOPNOTSUPP; + goto errout_qdisc; + } + +errout_rcu: + /* At this point we know that qdisc is not noop_qdisc, + * which means that qdisc holds a reference to net_device + * and we hold a reference to qdisc, so it is safe to release + * rcu read lock. + */ + rcu_read_unlock(); + return err; + +errout_qdisc: + rcu_read_unlock(); + + if (rtnl_held) + qdisc_put(*q); + else + qdisc_put_unlocked(*q); + *q = NULL; + + return err; +} + +static int __tcf_qdisc_cl_find(struct Qdisc *q, u32 parent, unsigned long *cl, + int ifindex, struct netlink_ext_ack *extack) +{ + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) + return 0; + + /* Do we search for filter, attached to class? 
*/ + if (TC_H_MIN(parent)) { + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + + *cl = cops->find(q, parent); + if (*cl == 0) { + NL_SET_ERR_MSG(extack, "Specified class doesn't exist"); + return -ENOENT; + } + } + + return 0; +} + +static struct tcf_block *__tcf_block_find(struct net *net, struct Qdisc *q, + unsigned long cl, int ifindex, + u32 block_index, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block; + + if (ifindex == TCM_IFINDEX_MAGIC_BLOCK) { + block = tcf_block_refcnt_get(net, block_index); + if (!block) { + NL_SET_ERR_MSG(extack, "Block of given index was not found"); + return ERR_PTR(-EINVAL); + } + } else { + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + + block = cops->tcf_block(q, cl, extack); + if (!block) + return ERR_PTR(-EINVAL); + + if (tcf_block_shared(block)) { + NL_SET_ERR_MSG(extack, "This filter block is shared. Please use the block index to manipulate the filters"); + return ERR_PTR(-EOPNOTSUPP); + } + + /* Always take reference to block in order to support execution + * of rules update path of cls API without rtnl lock. Caller + * must release block when it is finished using it. 'if' block + * of this conditional obtain reference to block by calling + * tcf_block_refcnt_get(). + */ + refcount_inc(&block->refcnt); + } + + return block; +} + +static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, + struct tcf_block_ext_info *ei, bool rtnl_held) +{ + if (refcount_dec_and_mutex_lock(&block->refcnt, &block->lock)) { + /* Flushing/putting all chains will cause the block to be + * deallocated when last chain is freed. However, if chain_list + * is empty, block has to be manually deallocated. After block + * reference counter reached 0, it is no longer possible to + * increment it or add new chains to block. + */ + bool free_block = list_empty(&block->chain_list); + + mutex_unlock(&block->lock); + if (tcf_block_shared(block)) + tcf_block_remove(block, block->net); + + if (q) + tcf_block_offload_unbind(block, q, ei); + + if (free_block) + tcf_block_destroy(block); + else + tcf_block_flush_all_chains(block, rtnl_held); + } else if (q) { + tcf_block_offload_unbind(block, q, ei); + } +} + +static void tcf_block_refcnt_put(struct tcf_block *block, bool rtnl_held) +{ + __tcf_block_put(block, NULL, NULL, rtnl_held); +} + +/* Find tcf block. + * Set q, parent, cl when appropriate. 
+ */ + +static struct tcf_block *tcf_block_find(struct net *net, struct Qdisc **q, + u32 *parent, unsigned long *cl, + int ifindex, u32 block_index, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block; + int err = 0; + + ASSERT_RTNL(); + + err = __tcf_qdisc_find(net, q, parent, ifindex, true, extack); + if (err) + goto errout; + + err = __tcf_qdisc_cl_find(*q, *parent, cl, ifindex, extack); + if (err) + goto errout_qdisc; + + block = __tcf_block_find(net, *q, *cl, ifindex, block_index, extack); + if (IS_ERR(block)) { + err = PTR_ERR(block); + goto errout_qdisc; + } + + return block; + +errout_qdisc: + if (*q) + qdisc_put(*q); +errout: + *q = NULL; + return ERR_PTR(err); +} + +static void tcf_block_release(struct Qdisc *q, struct tcf_block *block, + bool rtnl_held) +{ + if (!IS_ERR_OR_NULL(block)) + tcf_block_refcnt_put(block, rtnl_held); + + if (q) { + if (rtnl_held) + qdisc_put(q); + else + qdisc_put_unlocked(q); + } +} + +struct tcf_block_owner_item { + struct list_head list; + struct Qdisc *q; + enum flow_block_binder_type binder_type; +}; + +static void +tcf_block_owner_netif_keep_dst(struct tcf_block *block, + struct Qdisc *q, + enum flow_block_binder_type binder_type) +{ + if (block->keep_dst && + binder_type != FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS && + binder_type != FLOW_BLOCK_BINDER_TYPE_CLSACT_EGRESS) + netif_keep_dst(qdisc_dev(q)); +} + +void tcf_block_netif_keep_dst(struct tcf_block *block) +{ + struct tcf_block_owner_item *item; + + block->keep_dst = true; + list_for_each_entry(item, &block->owner_list, list) + tcf_block_owner_netif_keep_dst(block, item->q, + item->binder_type); +} +EXPORT_SYMBOL(tcf_block_netif_keep_dst); + +static int tcf_block_owner_add(struct tcf_block *block, + struct Qdisc *q, + enum flow_block_binder_type binder_type) +{ + struct tcf_block_owner_item *item; + + item = kmalloc(sizeof(*item), GFP_KERNEL); + if (!item) + return -ENOMEM; + item->q = q; + item->binder_type = binder_type; + list_add(&item->list, &block->owner_list); + return 0; +} + +static void tcf_block_owner_del(struct tcf_block *block, + struct Qdisc *q, + enum flow_block_binder_type binder_type) +{ + struct tcf_block_owner_item *item; + + list_for_each_entry(item, &block->owner_list, list) { + if (item->q == q && item->binder_type == binder_type) { + list_del(&item->list); + kfree(item); + return; + } + } + WARN_ON(1); +} + +int tcf_block_get_ext(struct tcf_block **p_block, struct Qdisc *q, + struct tcf_block_ext_info *ei, + struct netlink_ext_ack *extack) +{ + struct net *net = qdisc_net(q); + struct tcf_block *block = NULL; + int err; + + if (ei->block_index) + /* block_index not 0 means the shared block is requested */ + block = tcf_block_refcnt_get(net, ei->block_index); + + if (!block) { + block = tcf_block_create(net, q, ei->block_index, extack); + if (IS_ERR(block)) + return PTR_ERR(block); + if (tcf_block_shared(block)) { + err = tcf_block_insert(block, net, extack); + if (err) + goto err_block_insert; + } + } + + err = tcf_block_owner_add(block, q, ei->binder_type); + if (err) + goto err_block_owner_add; + + tcf_block_owner_netif_keep_dst(block, q, ei->binder_type); + + err = tcf_chain0_head_change_cb_add(block, ei, extack); + if (err) + goto err_chain0_head_change_cb_add; + + err = tcf_block_offload_bind(block, q, ei, extack); + if (err) + goto err_block_offload_bind; + + *p_block = block; + return 0; + +err_block_offload_bind: + tcf_chain0_head_change_cb_del(block, ei); +err_chain0_head_change_cb_add: + tcf_block_owner_del(block, q, ei->binder_type); 
+err_block_owner_add: +err_block_insert: + tcf_block_refcnt_put(block, true); + return err; +} +EXPORT_SYMBOL(tcf_block_get_ext); + +static void tcf_chain_head_change_dflt(struct tcf_proto *tp_head, void *priv) +{ + struct tcf_proto __rcu **p_filter_chain = priv; + + rcu_assign_pointer(*p_filter_chain, tp_head); +} + +int tcf_block_get(struct tcf_block **p_block, + struct tcf_proto __rcu **p_filter_chain, struct Qdisc *q, + struct netlink_ext_ack *extack) +{ + struct tcf_block_ext_info ei = { + .chain_head_change = tcf_chain_head_change_dflt, + .chain_head_change_priv = p_filter_chain, + }; + + WARN_ON(!p_filter_chain); + return tcf_block_get_ext(p_block, q, &ei, extack); +} +EXPORT_SYMBOL(tcf_block_get); + +/* XXX: Standalone actions are not allowed to jump to any chain, and bound + * actions should be all removed after flushing. + */ +void tcf_block_put_ext(struct tcf_block *block, struct Qdisc *q, + struct tcf_block_ext_info *ei) +{ + if (!block) + return; + tcf_chain0_head_change_cb_del(block, ei); + tcf_block_owner_del(block, q, ei->binder_type); + + __tcf_block_put(block, q, ei, true); +} +EXPORT_SYMBOL(tcf_block_put_ext); + +void tcf_block_put(struct tcf_block *block) +{ + struct tcf_block_ext_info ei = {0, }; + + if (!block) + return; + tcf_block_put_ext(block, block->q, &ei); +} + +EXPORT_SYMBOL(tcf_block_put); + +static int +tcf_block_playback_offloads(struct tcf_block *block, flow_setup_cb_t *cb, + void *cb_priv, bool add, bool offload_in_use, + struct netlink_ext_ack *extack) +{ + struct tcf_chain *chain, *chain_prev; + struct tcf_proto *tp, *tp_prev; + int err; + + lockdep_assert_held(&block->cb_lock); + + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { + for (tp = __tcf_get_next_proto(chain, NULL); tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, true, NULL)) { + if (tp->ops->reoffload) { + err = tp->ops->reoffload(tp, add, cb, cb_priv, + extack); + if (err && add) + goto err_playback_remove; + } else if (add && offload_in_use) { + err = -EOPNOTSUPP; + NL_SET_ERR_MSG(extack, "Filter HW offload failed - classifier without re-offloading support"); + goto err_playback_remove; + } + } + } + + return 0; + +err_playback_remove: + tcf_proto_put(tp, true, NULL); + tcf_chain_put(chain); + tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use, + extack); + return err; +} + +static int tcf_block_bind(struct tcf_block *block, + struct flow_block_offload *bo) +{ + struct flow_block_cb *block_cb, *next; + int err, i = 0; + + lockdep_assert_held(&block->cb_lock); + + list_for_each_entry(block_cb, &bo->cb_list, list) { + err = tcf_block_playback_offloads(block, block_cb->cb, + block_cb->cb_priv, true, + tcf_block_offload_in_use(block), + bo->extack); + if (err) + goto err_unroll; + if (!bo->unlocked_driver_cb) + block->lockeddevcnt++; + + i++; + } + list_splice(&bo->cb_list, &block->flow_block.cb_list); + + return 0; + +err_unroll: + list_for_each_entry_safe(block_cb, next, &bo->cb_list, list) { + list_del(&block_cb->driver_list); + if (i-- > 0) { + list_del(&block_cb->list); + tcf_block_playback_offloads(block, block_cb->cb, + block_cb->cb_priv, false, + tcf_block_offload_in_use(block), + NULL); + if (!bo->unlocked_driver_cb) + block->lockeddevcnt--; + } + flow_block_cb_free(block_cb); + } + + return err; +} + +static void tcf_block_unbind(struct tcf_block *block, + struct flow_block_offload *bo) +{ + struct flow_block_cb *block_cb, 
*next; + + lockdep_assert_held(&block->cb_lock); + + list_for_each_entry_safe(block_cb, next, &bo->cb_list, list) { + tcf_block_playback_offloads(block, block_cb->cb, + block_cb->cb_priv, false, + tcf_block_offload_in_use(block), + NULL); + list_del(&block_cb->list); + flow_block_cb_free(block_cb); + if (!bo->unlocked_driver_cb) + block->lockeddevcnt--; + } +} + +static int tcf_block_setup(struct tcf_block *block, + struct flow_block_offload *bo) +{ + int err; + + switch (bo->command) { + case FLOW_BLOCK_BIND: + err = tcf_block_bind(block, bo); + break; + case FLOW_BLOCK_UNBIND: + err = 0; + tcf_block_unbind(block, bo); + break; + default: + WARN_ON_ONCE(1); + err = -EOPNOTSUPP; + } + + return err; +} + +/* Main classifier routine: scans classifier chain attached + * to this qdisc, (optionally) tests for protocol and asks + * specific classifiers. + */ +static inline int __tcf_classify(struct sk_buff *skb, + const struct tcf_proto *tp, + const struct tcf_proto *orig_tp, + struct tcf_result *res, + bool compat_mode, + u32 *last_executed_chain) +{ +#ifdef CONFIG_NET_CLS_ACT + const int max_reclassify_loop = 16; + const struct tcf_proto *first_tp; + int limit = 0; + +reclassify: +#endif + for (; tp; tp = rcu_dereference_bh(tp->next)) { + __be16 protocol = skb_protocol(skb, false); + int err; + + if (tp->protocol != protocol && + tp->protocol != htons(ETH_P_ALL)) + continue; + + err = tp->classify(skb, tp, res); +#ifdef CONFIG_NET_CLS_ACT + if (unlikely(err == TC_ACT_RECLASSIFY && !compat_mode)) { + first_tp = orig_tp; + *last_executed_chain = first_tp->chain->index; + goto reset; + } else if (unlikely(TC_ACT_EXT_CMP(err, TC_ACT_GOTO_CHAIN))) { + first_tp = res->goto_tp; + *last_executed_chain = err & TC_ACT_EXT_VAL_MASK; + goto reset; + } +#endif + if (err >= 0) + return err; + } + + return TC_ACT_UNSPEC; /* signal: continue lookup */ +#ifdef CONFIG_NET_CLS_ACT +reset: + if (unlikely(limit++ >= max_reclassify_loop)) { + net_notice_ratelimited("%u: reclassify loop, rule prio %u, protocol %02x\n", + tp->chain->block->index, + tp->prio & 0xffff, + ntohs(tp->protocol)); + return TC_ACT_SHOT; + } + + tp = first_tp; + goto reclassify; +#endif +} + +int tcf_classify(struct sk_buff *skb, + const struct tcf_block *block, + const struct tcf_proto *tp, + struct tcf_result *res, bool compat_mode) +{ +#if !IS_ENABLED(CONFIG_NET_TC_SKB_EXT) + u32 last_executed_chain = 0; + + return __tcf_classify(skb, tp, tp, res, compat_mode, + &last_executed_chain); +#else + u32 last_executed_chain = tp ? 
tp->chain->index : 0; + const struct tcf_proto *orig_tp = tp; + struct tc_skb_ext *ext; + int ret; + + if (block) { + ext = skb_ext_find(skb, TC_SKB_EXT); + + if (ext && ext->chain) { + struct tcf_chain *fchain; + + fchain = tcf_chain_lookup_rcu(block, ext->chain); + if (!fchain) + return TC_ACT_SHOT; + + /* Consume, so cloned/redirect skbs won't inherit ext */ + skb_ext_del(skb, TC_SKB_EXT); + + tp = rcu_dereference_bh(fchain->filter_chain); + last_executed_chain = fchain->index; + } + } + + ret = __tcf_classify(skb, tp, orig_tp, res, compat_mode, + &last_executed_chain); + + if (tc_skb_ext_tc_enabled()) { + /* If we missed on some chain */ + if (ret == TC_ACT_UNSPEC && last_executed_chain) { + struct tc_skb_cb *cb = tc_skb_cb(skb); + + ext = tc_skb_ext_alloc(skb); + if (WARN_ON_ONCE(!ext)) + return TC_ACT_SHOT; + ext->chain = last_executed_chain; + ext->mru = cb->mru; + ext->post_ct = cb->post_ct; + ext->post_ct_snat = cb->post_ct_snat; + ext->post_ct_dnat = cb->post_ct_dnat; + ext->zone = cb->zone; + } + } + + return ret; +#endif +} +EXPORT_SYMBOL(tcf_classify); + +struct tcf_chain_info { + struct tcf_proto __rcu **pprev; + struct tcf_proto __rcu *next; +}; + +static struct tcf_proto *tcf_chain_tp_prev(struct tcf_chain *chain, + struct tcf_chain_info *chain_info) +{ + return tcf_chain_dereference(*chain_info->pprev, chain); +} + +static int tcf_chain_tp_insert(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + struct tcf_proto *tp) +{ + if (chain->flushing) + return -EAGAIN; + + RCU_INIT_POINTER(tp->next, tcf_chain_tp_prev(chain, chain_info)); + if (*chain_info->pprev == chain->filter_chain) + tcf_chain0_head_change(chain, tp); + tcf_proto_get(tp); + rcu_assign_pointer(*chain_info->pprev, tp); + + return 0; +} + +static void tcf_chain_tp_remove(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + struct tcf_proto *tp) +{ + struct tcf_proto *next = tcf_chain_dereference(chain_info->next, chain); + + tcf_proto_mark_delete(tp); + if (tp == chain->filter_chain) + tcf_chain0_head_change(chain, next); + RCU_INIT_POINTER(*chain_info->pprev, next); +} + +static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + u32 protocol, u32 prio, + bool prio_allocate); + +/* Try to insert new proto. + * If proto with specified priority already exists, free new proto + * and return existing one. 
+ */ + +static struct tcf_proto *tcf_chain_tp_insert_unique(struct tcf_chain *chain, + struct tcf_proto *tp_new, + u32 protocol, u32 prio, + bool rtnl_held) +{ + struct tcf_chain_info chain_info; + struct tcf_proto *tp; + int err = 0; + + mutex_lock(&chain->filter_chain_lock); + + if (tcf_proto_exists_destroying(chain, tp_new)) { + mutex_unlock(&chain->filter_chain_lock); + tcf_proto_destroy(tp_new, rtnl_held, false, NULL); + return ERR_PTR(-EAGAIN); + } + + tp = tcf_chain_tp_find(chain, &chain_info, + protocol, prio, false); + if (!tp) + err = tcf_chain_tp_insert(chain, &chain_info, tp_new); + mutex_unlock(&chain->filter_chain_lock); + + if (tp) { + tcf_proto_destroy(tp_new, rtnl_held, false, NULL); + tp_new = tp; + } else if (err) { + tcf_proto_destroy(tp_new, rtnl_held, false, NULL); + tp_new = ERR_PTR(err); + } + + return tp_new; +} + +static void tcf_chain_tp_delete_empty(struct tcf_chain *chain, + struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tcf_chain_info chain_info; + struct tcf_proto *tp_iter; + struct tcf_proto **pprev; + struct tcf_proto *next; + + mutex_lock(&chain->filter_chain_lock); + + /* Atomically find and remove tp from chain. */ + for (pprev = &chain->filter_chain; + (tp_iter = tcf_chain_dereference(*pprev, chain)); + pprev = &tp_iter->next) { + if (tp_iter == tp) { + chain_info.pprev = pprev; + chain_info.next = tp_iter->next; + WARN_ON(tp_iter->deleting); + break; + } + } + /* Verify that tp still exists and no new filters were inserted + * concurrently. + * Mark tp for deletion if it is empty. + */ + if (!tp_iter || !tcf_proto_check_delete(tp)) { + mutex_unlock(&chain->filter_chain_lock); + return; + } + + tcf_proto_signal_destroying(chain, tp); + next = tcf_chain_dereference(chain_info.next, chain); + if (tp == chain->filter_chain) + tcf_chain0_head_change(chain, next); + RCU_INIT_POINTER(*chain_info.pprev, next); + mutex_unlock(&chain->filter_chain_lock); + + tcf_proto_put(tp, rtnl_held, extack); +} + +static struct tcf_proto *tcf_chain_tp_find(struct tcf_chain *chain, + struct tcf_chain_info *chain_info, + u32 protocol, u32 prio, + bool prio_allocate) +{ + struct tcf_proto **pprev; + struct tcf_proto *tp; + + /* Check the chain for existence of proto-tcf with this priority */ + for (pprev = &chain->filter_chain; + (tp = tcf_chain_dereference(*pprev, chain)); + pprev = &tp->next) { + if (tp->prio >= prio) { + if (tp->prio == prio) { + if (prio_allocate || + (tp->protocol != protocol && protocol)) + return ERR_PTR(-EINVAL); + } else { + tp = NULL; + } + break; + } + } + chain_info->pprev = pprev; + if (tp) { + chain_info->next = tp->next; + tcf_proto_get(tp); + } else { + chain_info->next = NULL; + } + return tp; +} + +static int tcf_fill_node(struct net *net, struct sk_buff *skb, + struct tcf_proto *tp, struct tcf_block *block, + struct Qdisc *q, u32 parent, void *fh, + u32 portid, u32 seq, u16 flags, int event, + bool terse_dump, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tcmsg *tcm; + struct nlmsghdr *nlh; + unsigned char *b = skb_tail_pointer(skb); + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*tcm), flags); + if (!nlh) + goto out_nlmsg_trim; + tcm = nlmsg_data(nlh); + tcm->tcm_family = AF_UNSPEC; + tcm->tcm__pad1 = 0; + tcm->tcm__pad2 = 0; + if (q) { + tcm->tcm_ifindex = qdisc_dev(q)->ifindex; + tcm->tcm_parent = parent; + } else { + tcm->tcm_ifindex = TCM_IFINDEX_MAGIC_BLOCK; + tcm->tcm_block_index = block->index; + } + tcm->tcm_info = TC_H_MAKE(tp->prio, tp->protocol); + if (nla_put_string(skb, 
TCA_KIND, tp->ops->kind)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_CHAIN, tp->chain->index)) + goto nla_put_failure; + if (!fh) { + tcm->tcm_handle = 0; + } else if (terse_dump) { + if (tp->ops->terse_dump) { + if (tp->ops->terse_dump(net, tp, fh, skb, tcm, + rtnl_held) < 0) + goto nla_put_failure; + } else { + goto cls_op_not_supp; + } + } else { + if (tp->ops->dump && + tp->ops->dump(net, tp, fh, skb, tcm, rtnl_held) < 0) + goto nla_put_failure; + } + + if (extack && extack->_msg && + nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg)) + goto nla_put_failure; + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + + return skb->len; + +out_nlmsg_trim: +nla_put_failure: +cls_op_not_supp: + nlmsg_trim(skb, b); + return -1; +} + +static int tfilter_notify(struct net *net, struct sk_buff *oskb, + struct nlmsghdr *n, struct tcf_proto *tp, + struct tcf_block *block, struct Qdisc *q, + u32 parent, void *fh, int event, bool unicast, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + int err = 0; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid, + n->nlmsg_seq, n->nlmsg_flags, event, + false, rtnl_held, extack) <= 0) { + kfree_skb(skb); + return -EINVAL; + } + + if (unicast) + err = rtnl_unicast(skb, net, portid); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + return err; +} + +static int tfilter_del_notify(struct net *net, struct sk_buff *oskb, + struct nlmsghdr *n, struct tcf_proto *tp, + struct tcf_block *block, struct Qdisc *q, + u32 parent, void *fh, bool unicast, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + u32 portid = oskb ? 
NETLINK_CB(oskb).portid : 0; + int err; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tcf_fill_node(net, skb, tp, block, q, parent, fh, portid, + n->nlmsg_seq, n->nlmsg_flags, RTM_DELTFILTER, + false, rtnl_held, extack) <= 0) { + NL_SET_ERR_MSG(extack, "Failed to build del event notification"); + kfree_skb(skb); + return -EINVAL; + } + + err = tp->ops->delete(tp, fh, last, rtnl_held, extack); + if (err) { + kfree_skb(skb); + return err; + } + + if (unicast) + err = rtnl_unicast(skb, net, portid); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + if (err < 0) + NL_SET_ERR_MSG(extack, "Failed to send filter delete notification"); + + return err; +} + +static void tfilter_notify_chain(struct net *net, struct sk_buff *oskb, + struct tcf_block *block, struct Qdisc *q, + u32 parent, struct nlmsghdr *n, + struct tcf_chain *chain, int event, + struct netlink_ext_ack *extack) +{ + struct tcf_proto *tp; + + for (tp = tcf_get_next_proto(chain, NULL); + tp; tp = tcf_get_next_proto(chain, tp)) + tfilter_notify(net, oskb, n, tp, block, q, parent, NULL, + event, false, true, extack); +} + +static void tfilter_put(struct tcf_proto *tp, void *fh) +{ + if (tp->ops->put && fh) + tp->ops->put(tp, fh); +} + +static int tc_new_tfilter(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + char name[IFNAMSIZ]; + struct tcmsg *t; + u32 protocol; + u32 prio; + bool prio_allocate; + u32 parent; + u32 chain_index; + struct Qdisc *q; + struct tcf_chain_info chain_info; + struct tcf_chain *chain; + struct tcf_block *block; + struct tcf_proto *tp; + unsigned long cl; + void *fh; + int err; + int tp_created; + bool rtnl_held = false; + u32 flags; + +replay: + tp_created = 0; + + err = nlmsg_parse_deprecated(n, sizeof(*t), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + t = nlmsg_data(n); + protocol = TC_H_MIN(t->tcm_info); + prio = TC_H_MAJ(t->tcm_info); + prio_allocate = false; + parent = t->tcm_parent; + tp = NULL; + cl = 0; + block = NULL; + q = NULL; + chain = NULL; + flags = 0; + + if (prio == 0) { + /* If no priority is provided by the user, + * we allocate one. + */ + if (n->nlmsg_flags & NLM_F_CREATE) { + prio = TC_H_MAKE(0x80000000U, 0U); + prio_allocate = true; + } else { + NL_SET_ERR_MSG(extack, "Invalid filter command with priority of zero"); + return -ENOENT; + } + } + + /* Find head of filter chain. */ + + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + if (tcf_proto_check_kind(tca[TCA_KIND], name)) { + NL_SET_ERR_MSG(extack, "Specified TC filter name too long"); + err = -EINVAL; + goto errout; + } + + /* Take rtnl mutex if rtnl_held was set to true on previous iteration, + * block is shared (no qdisc found), qdisc is not unlocked, classifier + * type is not specified, classifier is not unlocked. + */ + if (rtnl_held || + (q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tcf_proto_is_unlocked(name)) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); + if (IS_ERR(block)) { + err = PTR_ERR(block); + goto errout; + } + block->classid = parent; + + chain_index = tca[TCA_CHAIN] ? 
nla_get_u32(tca[TCA_CHAIN]) : 0; + if (chain_index > TC_ACT_EXT_VAL_MASK) { + NL_SET_ERR_MSG(extack, "Specified chain index exceeds upper limit"); + err = -EINVAL; + goto errout; + } + chain = tcf_chain_get(block, chain_index, true); + if (!chain) { + NL_SET_ERR_MSG(extack, "Cannot create specified filter chain"); + err = -ENOMEM; + goto errout; + } + + mutex_lock(&chain->filter_chain_lock); + tp = tcf_chain_tp_find(chain, &chain_info, protocol, + prio, prio_allocate); + if (IS_ERR(tp)) { + NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); + err = PTR_ERR(tp); + goto errout_locked; + } + + if (tp == NULL) { + struct tcf_proto *tp_new = NULL; + + if (chain->flushing) { + err = -EAGAIN; + goto errout_locked; + } + + /* Proto-tcf does not exist, create new one */ + + if (tca[TCA_KIND] == NULL || !protocol) { + NL_SET_ERR_MSG(extack, "Filter kind and protocol must be specified"); + err = -EINVAL; + goto errout_locked; + } + + if (!(n->nlmsg_flags & NLM_F_CREATE)) { + NL_SET_ERR_MSG(extack, "Need both RTM_NEWTFILTER and NLM_F_CREATE to create a new filter"); + err = -ENOENT; + goto errout_locked; + } + + if (prio_allocate) + prio = tcf_auto_prio(tcf_chain_tp_prev(chain, + &chain_info)); + + mutex_unlock(&chain->filter_chain_lock); + tp_new = tcf_proto_create(name, protocol, prio, chain, + rtnl_held, extack); + if (IS_ERR(tp_new)) { + err = PTR_ERR(tp_new); + goto errout_tp; + } + + tp_created = 1; + tp = tcf_chain_tp_insert_unique(chain, tp_new, protocol, prio, + rtnl_held); + if (IS_ERR(tp)) { + err = PTR_ERR(tp); + goto errout_tp; + } + } else { + mutex_unlock(&chain->filter_chain_lock); + } + + if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { + NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); + err = -EINVAL; + goto errout; + } + + fh = tp->ops->get(tp, t->tcm_handle); + + if (!fh) { + if (!(n->nlmsg_flags & NLM_F_CREATE)) { + NL_SET_ERR_MSG(extack, "Need both RTM_NEWTFILTER and NLM_F_CREATE to create a new filter"); + err = -ENOENT; + goto errout; + } + } else if (n->nlmsg_flags & NLM_F_EXCL) { + tfilter_put(tp, fh); + NL_SET_ERR_MSG(extack, "Filter already exists"); + err = -EEXIST; + goto errout; + } + + if (chain->tmplt_ops && chain->tmplt_ops != tp->ops) { + tfilter_put(tp, fh); + NL_SET_ERR_MSG(extack, "Chain template is set to a different filter kind"); + err = -EINVAL; + goto errout; + } + + if (!(n->nlmsg_flags & NLM_F_CREATE)) + flags |= TCA_ACT_FLAGS_REPLACE; + if (!rtnl_held) + flags |= TCA_ACT_FLAGS_NO_RTNL; + err = tp->ops->change(net, skb, tp, cl, t->tcm_handle, tca, &fh, + flags, extack); + if (err == 0) { + tfilter_notify(net, skb, n, tp, block, q, parent, fh, + RTM_NEWTFILTER, false, rtnl_held, extack); + tfilter_put(tp, fh); + /* q pointer is NULL for shared blocks */ + if (q) + q->flags &= ~TCQ_F_CAN_BYPASS; + } + +errout: + if (err && tp_created) + tcf_chain_tp_delete_empty(chain, tp, rtnl_held, NULL); +errout_tp: + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); + if (!tp_created) + tcf_chain_put(chain); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + + if (err == -EAGAIN) { + /* Take rtnl lock in case EAGAIN is caused by concurrent flush + * of target chain. + */ + rtnl_held = true; + /* Replay the request. 
*/ + goto replay; + } + return err; + +errout_locked: + mutex_unlock(&chain->filter_chain_lock); + goto errout; +} + +static int tc_del_tfilter(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + char name[IFNAMSIZ]; + struct tcmsg *t; + u32 protocol; + u32 prio; + u32 parent; + u32 chain_index; + struct Qdisc *q = NULL; + struct tcf_chain_info chain_info; + struct tcf_chain *chain = NULL; + struct tcf_block *block = NULL; + struct tcf_proto *tp = NULL; + unsigned long cl = 0; + void *fh = NULL; + int err; + bool rtnl_held = false; + + err = nlmsg_parse_deprecated(n, sizeof(*t), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + t = nlmsg_data(n); + protocol = TC_H_MIN(t->tcm_info); + prio = TC_H_MAJ(t->tcm_info); + parent = t->tcm_parent; + + if (prio == 0 && (protocol || t->tcm_handle || tca[TCA_KIND])) { + NL_SET_ERR_MSG(extack, "Cannot flush filters with protocol, handle or kind set"); + return -ENOENT; + } + + /* Find head of filter chain. */ + + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + if (tcf_proto_check_kind(tca[TCA_KIND], name)) { + NL_SET_ERR_MSG(extack, "Specified TC filter name too long"); + err = -EINVAL; + goto errout; + } + /* Take rtnl mutex if flushing whole chain, block is shared (no qdisc + * found), qdisc is not unlocked, classifier type is not specified, + * classifier is not unlocked. + */ + if (!prio || + (q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tcf_proto_is_unlocked(name)) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); + if (IS_ERR(block)) { + err = PTR_ERR(block); + goto errout; + } + + chain_index = tca[TCA_CHAIN] ? nla_get_u32(tca[TCA_CHAIN]) : 0; + if (chain_index > TC_ACT_EXT_VAL_MASK) { + NL_SET_ERR_MSG(extack, "Specified chain index exceeds upper limit"); + err = -EINVAL; + goto errout; + } + chain = tcf_chain_get(block, chain_index, false); + if (!chain) { + /* User requested flush on non-existent chain. Nothing to do, + * so just return success. + */ + if (prio == 0) { + err = 0; + goto errout; + } + NL_SET_ERR_MSG(extack, "Cannot find specified filter chain"); + err = -ENOENT; + goto errout; + } + + if (prio == 0) { + tfilter_notify_chain(net, skb, block, q, parent, n, + chain, RTM_DELTFILTER, extack); + tcf_chain_flush(chain, rtnl_held); + err = 0; + goto errout; + } + + mutex_lock(&chain->filter_chain_lock); + tp = tcf_chain_tp_find(chain, &chain_info, protocol, + prio, false); + if (!tp || IS_ERR(tp)) { + NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); + err = tp ? 
PTR_ERR(tp) : -ENOENT; + goto errout_locked; + } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { + NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); + err = -EINVAL; + goto errout_locked; + } else if (t->tcm_handle == 0) { + tcf_proto_signal_destroying(chain, tp); + tcf_chain_tp_remove(chain, &chain_info, tp); + mutex_unlock(&chain->filter_chain_lock); + + tcf_proto_put(tp, rtnl_held, NULL); + tfilter_notify(net, skb, n, tp, block, q, parent, fh, + RTM_DELTFILTER, false, rtnl_held, extack); + err = 0; + goto errout; + } + mutex_unlock(&chain->filter_chain_lock); + + fh = tp->ops->get(tp, t->tcm_handle); + + if (!fh) { + NL_SET_ERR_MSG(extack, "Specified filter handle not found"); + err = -ENOENT; + } else { + bool last; + + err = tfilter_del_notify(net, skb, n, tp, block, + q, parent, fh, false, &last, + rtnl_held, extack); + + if (err) + goto errout; + if (last) + tcf_chain_tp_delete_empty(chain, tp, rtnl_held, extack); + } + +errout: + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); + tcf_chain_put(chain); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + + return err; + +errout_locked: + mutex_unlock(&chain->filter_chain_lock); + goto errout; +} + +static int tc_get_tfilter(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + char name[IFNAMSIZ]; + struct tcmsg *t; + u32 protocol; + u32 prio; + u32 parent; + u32 chain_index; + struct Qdisc *q = NULL; + struct tcf_chain_info chain_info; + struct tcf_chain *chain = NULL; + struct tcf_block *block = NULL; + struct tcf_proto *tp = NULL; + unsigned long cl = 0; + void *fh = NULL; + int err; + bool rtnl_held = false; + + err = nlmsg_parse_deprecated(n, sizeof(*t), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + t = nlmsg_data(n); + protocol = TC_H_MIN(t->tcm_info); + prio = TC_H_MAJ(t->tcm_info); + parent = t->tcm_parent; + + if (prio == 0) { + NL_SET_ERR_MSG(extack, "Invalid filter command with priority of zero"); + return -ENOENT; + } + + /* Find head of filter chain. */ + + err = __tcf_qdisc_find(net, &q, &parent, t->tcm_ifindex, false, extack); + if (err) + return err; + + if (tcf_proto_check_kind(tca[TCA_KIND], name)) { + NL_SET_ERR_MSG(extack, "Specified TC filter name too long"); + err = -EINVAL; + goto errout; + } + /* Take rtnl mutex if block is shared (no qdisc found), qdisc is not + * unlocked, classifier type is not specified, classifier is not + * unlocked. + */ + if ((q && !(q->ops->cl_ops->flags & QDISC_CLASS_OPS_DOIT_UNLOCKED)) || + !tcf_proto_is_unlocked(name)) { + rtnl_held = true; + rtnl_lock(); + } + + err = __tcf_qdisc_cl_find(q, parent, &cl, t->tcm_ifindex, extack); + if (err) + goto errout; + + block = __tcf_block_find(net, q, cl, t->tcm_ifindex, t->tcm_block_index, + extack); + if (IS_ERR(block)) { + err = PTR_ERR(block); + goto errout; + } + + chain_index = tca[TCA_CHAIN] ? 
nla_get_u32(tca[TCA_CHAIN]) : 0; + if (chain_index > TC_ACT_EXT_VAL_MASK) { + NL_SET_ERR_MSG(extack, "Specified chain index exceeds upper limit"); + err = -EINVAL; + goto errout; + } + chain = tcf_chain_get(block, chain_index, false); + if (!chain) { + NL_SET_ERR_MSG(extack, "Cannot find specified filter chain"); + err = -EINVAL; + goto errout; + } + + mutex_lock(&chain->filter_chain_lock); + tp = tcf_chain_tp_find(chain, &chain_info, protocol, + prio, false); + mutex_unlock(&chain->filter_chain_lock); + if (!tp || IS_ERR(tp)) { + NL_SET_ERR_MSG(extack, "Filter with specified priority/protocol not found"); + err = tp ? PTR_ERR(tp) : -ENOENT; + goto errout; + } else if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], tp->ops->kind)) { + NL_SET_ERR_MSG(extack, "Specified filter kind does not match existing one"); + err = -EINVAL; + goto errout; + } + + fh = tp->ops->get(tp, t->tcm_handle); + + if (!fh) { + NL_SET_ERR_MSG(extack, "Specified filter handle not found"); + err = -ENOENT; + } else { + err = tfilter_notify(net, skb, n, tp, block, q, parent, + fh, RTM_NEWTFILTER, true, rtnl_held, NULL); + if (err < 0) + NL_SET_ERR_MSG(extack, "Failed to send filter notify message"); + } + + tfilter_put(tp, fh); +errout: + if (chain) { + if (tp && !IS_ERR(tp)) + tcf_proto_put(tp, rtnl_held, NULL); + tcf_chain_put(chain); + } + tcf_block_release(q, block, rtnl_held); + + if (rtnl_held) + rtnl_unlock(); + + return err; +} + +struct tcf_dump_args { + struct tcf_walker w; + struct sk_buff *skb; + struct netlink_callback *cb; + struct tcf_block *block; + struct Qdisc *q; + u32 parent; + bool terse_dump; +}; + +static int tcf_node_dump(struct tcf_proto *tp, void *n, struct tcf_walker *arg) +{ + struct tcf_dump_args *a = (void *)arg; + struct net *net = sock_net(a->skb->sk); + + return tcf_fill_node(net, a->skb, tp, a->block, a->q, a->parent, + n, NETLINK_CB(a->cb->skb).portid, + a->cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWTFILTER, a->terse_dump, true, NULL); +} + +static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, + struct sk_buff *skb, struct netlink_callback *cb, + long index_start, long *p_index, bool terse) +{ + struct net *net = sock_net(skb->sk); + struct tcf_block *block = chain->block; + struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct tcf_proto *tp, *tp_prev; + struct tcf_dump_args arg; + + for (tp = __tcf_get_next_proto(chain, NULL); + tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, true, NULL), + (*p_index)++) { + if (*p_index < index_start) + continue; + if (TC_H_MAJ(tcm->tcm_info) && + TC_H_MAJ(tcm->tcm_info) != tp->prio) + continue; + if (TC_H_MIN(tcm->tcm_info) && + TC_H_MIN(tcm->tcm_info) != tp->protocol) + continue; + if (*p_index > index_start) + memset(&cb->args[1], 0, + sizeof(cb->args) - sizeof(cb->args[0])); + if (cb->args[1] == 0) { + if (tcf_fill_node(net, skb, tp, block, q, parent, NULL, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWTFILTER, false, true, NULL) <= 0) + goto errout; + cb->args[1] = 1; + } + if (!tp->ops->walk) + continue; + arg.w.fn = tcf_node_dump; + arg.skb = skb; + arg.cb = cb; + arg.block = block; + arg.q = q; + arg.parent = parent; + arg.w.stop = 0; + arg.w.skip = cb->args[1] - 1; + arg.w.count = 0; + arg.w.cookie = cb->args[2]; + arg.terse_dump = terse; + tp->ops->walk(tp, &arg.w, true); + cb->args[2] = arg.w.cookie; + cb->args[1] = arg.w.count + 1; + if (arg.w.stop) + goto errout; + } + return true; + +errout: + tcf_proto_put(tp, true, NULL); + return false; +} + 
+static const struct nla_policy tcf_tfilter_dump_policy[TCA_MAX + 1] = { + [TCA_DUMP_FLAGS] = NLA_POLICY_BITFIELD32(TCA_DUMP_FLAGS_TERSE), +}; + +/* called with RTNL */ +static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct tcf_chain *chain, *chain_prev; + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + struct Qdisc *q = NULL; + struct tcf_block *block; + struct tcmsg *tcm = nlmsg_data(cb->nlh); + bool terse_dump = false; + long index_start; + long index; + u32 parent; + int err; + + if (nlmsg_len(cb->nlh) < sizeof(*tcm)) + return skb->len; + + err = nlmsg_parse_deprecated(cb->nlh, sizeof(*tcm), tca, TCA_MAX, + tcf_tfilter_dump_policy, cb->extack); + if (err) + return err; + + if (tca[TCA_DUMP_FLAGS]) { + struct nla_bitfield32 flags = + nla_get_bitfield32(tca[TCA_DUMP_FLAGS]); + + terse_dump = flags.value & TCA_DUMP_FLAGS_TERSE; + } + + if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) { + block = tcf_block_refcnt_get(net, tcm->tcm_block_index); + if (!block) + goto out; + /* If we work with block index, q is NULL and parent value + * will never be used in the following code. The check + * in tcf_fill_node prevents it. However, compiler does not + * see that far, so set parent to zero to silence the warning + * about parent being uninitialized. + */ + parent = 0; + } else { + const struct Qdisc_class_ops *cops; + struct net_device *dev; + unsigned long cl = 0; + + dev = __dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return skb->len; + + parent = tcm->tcm_parent; + if (!parent) + q = rtnl_dereference(dev->qdisc); + else + q = qdisc_lookup(dev, TC_H_MAJ(tcm->tcm_parent)); + if (!q) + goto out; + cops = q->ops->cl_ops; + if (!cops) + goto out; + if (!cops->tcf_block) + goto out; + if (TC_H_MIN(tcm->tcm_parent)) { + cl = cops->find(q, tcm->tcm_parent); + if (cl == 0) + goto out; + } + block = cops->tcf_block(q, cl, NULL); + if (!block) + goto out; + parent = block->classid; + if (tcf_block_shared(block)) + q = NULL; + } + + index_start = cb->args[0]; + index = 0; + + for (chain = __tcf_get_next_chain(block, NULL); + chain; + chain_prev = chain, + chain = __tcf_get_next_chain(block, chain), + tcf_chain_put(chain_prev)) { + if (tca[TCA_CHAIN] && + nla_get_u32(tca[TCA_CHAIN]) != chain->index) + continue; + if (!tcf_chain_dump(chain, q, parent, skb, cb, + index_start, &index, terse_dump)) { + tcf_chain_put(chain); + err = -EMSGSIZE; + break; + } + } + + if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) + tcf_block_refcnt_put(block, true); + cb->args[0] = index; + +out: + /* If we did no progress, the error (EMSGSIZE) is real */ + if (skb->len == 0 && err) + return err; + return skb->len; +} + +static int tc_chain_fill_node(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct net *net, struct sk_buff *skb, + struct tcf_block *block, + u32 portid, u32 seq, u16 flags, int event, + struct netlink_ext_ack *extack) +{ + unsigned char *b = skb_tail_pointer(skb); + const struct tcf_proto_ops *ops; + struct nlmsghdr *nlh; + struct tcmsg *tcm; + void *priv; + + ops = tmplt_ops; + priv = tmplt_priv; + + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*tcm), flags); + if (!nlh) + goto out_nlmsg_trim; + tcm = nlmsg_data(nlh); + tcm->tcm_family = AF_UNSPEC; + tcm->tcm__pad1 = 0; + tcm->tcm__pad2 = 0; + tcm->tcm_handle = 0; + if (block->q) { + tcm->tcm_ifindex = qdisc_dev(block->q)->ifindex; + tcm->tcm_parent = block->q->handle; + } else { + tcm->tcm_ifindex = TCM_IFINDEX_MAGIC_BLOCK; + tcm->tcm_block_index = 
block->index; + } + + if (nla_put_u32(skb, TCA_CHAIN, chain_index)) + goto nla_put_failure; + + if (ops) { + if (nla_put_string(skb, TCA_KIND, ops->kind)) + goto nla_put_failure; + if (ops->tmplt_dump(skb, net, priv) < 0) + goto nla_put_failure; + } + + if (extack && extack->_msg && + nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg)) + goto out_nlmsg_trim; + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + + return skb->len; + +out_nlmsg_trim: +nla_put_failure: + nlmsg_trim(skb, b); + return -EMSGSIZE; +} + +static int tc_chain_notify(struct tcf_chain *chain, struct sk_buff *oskb, + u32 seq, u16 flags, int event, bool unicast, + struct netlink_ext_ack *extack) +{ + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + struct tcf_block *block = chain->block; + struct net *net = block->net; + struct sk_buff *skb; + int err = 0; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tc_chain_fill_node(chain->tmplt_ops, chain->tmplt_priv, + chain->index, net, skb, block, portid, + seq, flags, event, extack) <= 0) { + kfree_skb(skb); + return -EINVAL; + } + + if (unicast) + err = rtnl_unicast(skb, net, portid); + else + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + flags & NLM_F_ECHO); + + return err; +} + +static int tc_chain_notify_delete(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv, u32 chain_index, + struct tcf_block *block, struct sk_buff *oskb, + u32 seq, u16 flags, bool unicast) +{ + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + struct net *net = block->net; + struct sk_buff *skb; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tc_chain_fill_node(tmplt_ops, tmplt_priv, chain_index, net, skb, + block, portid, seq, flags, RTM_DELCHAIN, NULL) <= 0) { + kfree_skb(skb); + return -EINVAL; + } + + if (unicast) + return rtnl_unicast(skb, net, portid); + + return rtnetlink_send(skb, net, portid, RTNLGRP_TC, flags & NLM_F_ECHO); +} + +static int tc_chain_tmplt_add(struct tcf_chain *chain, struct net *net, + struct nlattr **tca, + struct netlink_ext_ack *extack) +{ + const struct tcf_proto_ops *ops; + char name[IFNAMSIZ]; + void *tmplt_priv; + + /* If kind is not set, user did not specify template. */ + if (!tca[TCA_KIND]) + return 0; + + if (tcf_proto_check_kind(tca[TCA_KIND], name)) { + NL_SET_ERR_MSG(extack, "Specified TC chain template name too long"); + return -EINVAL; + } + + ops = tcf_proto_lookup_ops(name, true, extack); + if (IS_ERR(ops)) + return PTR_ERR(ops); + if (!ops->tmplt_create || !ops->tmplt_destroy || !ops->tmplt_dump) { + NL_SET_ERR_MSG(extack, "Chain templates are not supported with specified classifier"); + module_put(ops->owner); + return -EOPNOTSUPP; + } + + tmplt_priv = ops->tmplt_create(net, chain, tca, extack); + if (IS_ERR(tmplt_priv)) { + module_put(ops->owner); + return PTR_ERR(tmplt_priv); + } + chain->tmplt_ops = ops; + chain->tmplt_priv = tmplt_priv; + return 0; +} + +static void tc_chain_tmplt_del(const struct tcf_proto_ops *tmplt_ops, + void *tmplt_priv) +{ + /* If template ops are set, no work to do for us. 
*/ + if (!tmplt_ops) + return; + + tmplt_ops->tmplt_destroy(tmplt_priv); + module_put(tmplt_ops->owner); +} + +/* Add/delete/get a chain */ + +static int tc_ctl_chain(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + struct tcmsg *t; + u32 parent; + u32 chain_index; + struct Qdisc *q; + struct tcf_chain *chain; + struct tcf_block *block; + unsigned long cl; + int err; + +replay: + q = NULL; + err = nlmsg_parse_deprecated(n, sizeof(*t), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + t = nlmsg_data(n); + parent = t->tcm_parent; + cl = 0; + + block = tcf_block_find(net, &q, &parent, &cl, + t->tcm_ifindex, t->tcm_block_index, extack); + if (IS_ERR(block)) + return PTR_ERR(block); + + chain_index = tca[TCA_CHAIN] ? nla_get_u32(tca[TCA_CHAIN]) : 0; + if (chain_index > TC_ACT_EXT_VAL_MASK) { + NL_SET_ERR_MSG(extack, "Specified chain index exceeds upper limit"); + err = -EINVAL; + goto errout_block; + } + + mutex_lock(&block->lock); + chain = tcf_chain_lookup(block, chain_index); + if (n->nlmsg_type == RTM_NEWCHAIN) { + if (chain) { + if (tcf_chain_held_by_acts_only(chain)) { + /* The chain exists only because there is + * some action referencing it. + */ + tcf_chain_hold(chain); + } else { + NL_SET_ERR_MSG(extack, "Filter chain already exists"); + err = -EEXIST; + goto errout_block_locked; + } + } else { + if (!(n->nlmsg_flags & NLM_F_CREATE)) { + NL_SET_ERR_MSG(extack, "Need both RTM_NEWCHAIN and NLM_F_CREATE to create a new chain"); + err = -ENOENT; + goto errout_block_locked; + } + chain = tcf_chain_create(block, chain_index); + if (!chain) { + NL_SET_ERR_MSG(extack, "Failed to create filter chain"); + err = -ENOMEM; + goto errout_block_locked; + } + } + } else { + if (!chain || tcf_chain_held_by_acts_only(chain)) { + NL_SET_ERR_MSG(extack, "Cannot find specified filter chain"); + err = -EINVAL; + goto errout_block_locked; + } + tcf_chain_hold(chain); + } + + if (n->nlmsg_type == RTM_NEWCHAIN) { + /* Modifying chain requires holding parent block lock. In case + * the chain was successfully added, take a reference to the + * chain. This ensures that an empty chain does not disappear at + * the end of this function. + */ + tcf_chain_hold(chain); + chain->explicitly_created = true; + } + mutex_unlock(&block->lock); + + switch (n->nlmsg_type) { + case RTM_NEWCHAIN: + err = tc_chain_tmplt_add(chain, net, tca, extack); + if (err) { + tcf_chain_put_explicitly_created(chain); + goto errout; + } + + tc_chain_notify(chain, NULL, 0, NLM_F_CREATE | NLM_F_EXCL, + RTM_NEWCHAIN, false, extack); + break; + case RTM_DELCHAIN: + tfilter_notify_chain(net, skb, block, q, parent, n, + chain, RTM_DELTFILTER, extack); + /* Flush the chain first as the user requested chain removal. */ + tcf_chain_flush(chain, true); + /* In case the chain was successfully deleted, put a reference + * to the chain previously taken during addition. + */ + tcf_chain_put_explicitly_created(chain); + break; + case RTM_GETCHAIN: + err = tc_chain_notify(chain, skb, n->nlmsg_seq, + n->nlmsg_flags, n->nlmsg_type, true, extack); + if (err < 0) + NL_SET_ERR_MSG(extack, "Failed to send chain notify message"); + break; + default: + err = -EOPNOTSUPP; + NL_SET_ERR_MSG(extack, "Unsupported message type"); + goto errout; + } + +errout: + tcf_chain_put(chain); +errout_block: + tcf_block_release(q, block, true); + if (err == -EAGAIN) + /* Replay the request. 
*/ + goto replay; + return err; + +errout_block_locked: + mutex_unlock(&block->lock); + goto errout_block; +} + +/* called with RTNL */ +static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *tca[TCA_MAX + 1]; + struct Qdisc *q = NULL; + struct tcf_block *block; + struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct tcf_chain *chain; + long index_start; + long index; + int err; + + if (nlmsg_len(cb->nlh) < sizeof(*tcm)) + return skb->len; + + err = nlmsg_parse_deprecated(cb->nlh, sizeof(*tcm), tca, TCA_MAX, + rtm_tca_policy, cb->extack); + if (err) + return err; + + if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) { + block = tcf_block_refcnt_get(net, tcm->tcm_block_index); + if (!block) + goto out; + } else { + const struct Qdisc_class_ops *cops; + struct net_device *dev; + unsigned long cl = 0; + + dev = __dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return skb->len; + + if (!tcm->tcm_parent) + q = rtnl_dereference(dev->qdisc); + else + q = qdisc_lookup(dev, TC_H_MAJ(tcm->tcm_parent)); + + if (!q) + goto out; + cops = q->ops->cl_ops; + if (!cops) + goto out; + if (!cops->tcf_block) + goto out; + if (TC_H_MIN(tcm->tcm_parent)) { + cl = cops->find(q, tcm->tcm_parent); + if (cl == 0) + goto out; + } + block = cops->tcf_block(q, cl, NULL); + if (!block) + goto out; + if (tcf_block_shared(block)) + q = NULL; + } + + index_start = cb->args[0]; + index = 0; + + mutex_lock(&block->lock); + list_for_each_entry(chain, &block->chain_list, list) { + if ((tca[TCA_CHAIN] && + nla_get_u32(tca[TCA_CHAIN]) != chain->index)) + continue; + if (index < index_start) { + index++; + continue; + } + if (tcf_chain_held_by_acts_only(chain)) + continue; + err = tc_chain_fill_node(chain->tmplt_ops, chain->tmplt_priv, + chain->index, net, skb, block, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWCHAIN, NULL); + if (err <= 0) + break; + index++; + } + mutex_unlock(&block->lock); + + if (tcm->tcm_ifindex == TCM_IFINDEX_MAGIC_BLOCK) + tcf_block_refcnt_put(block, true); + cb->args[0] = index; + +out: + /* If we did no progress, the error (EMSGSIZE) is real */ + if (skb->len == 0 && err) + return err; + return skb->len; +} + +void tcf_exts_destroy(struct tcf_exts *exts) +{ +#ifdef CONFIG_NET_CLS_ACT + if (exts->actions) { + tcf_action_destroy(exts->actions, TCA_ACT_UNBIND); + kfree(exts->actions); + } + exts->nr_actions = 0; +#endif +} +EXPORT_SYMBOL(tcf_exts_destroy); + +int tcf_exts_validate_ex(struct net *net, struct tcf_proto *tp, struct nlattr **tb, + struct nlattr *rate_tlv, struct tcf_exts *exts, + u32 flags, u32 fl_flags, struct netlink_ext_ack *extack) +{ +#ifdef CONFIG_NET_CLS_ACT + { + int init_res[TCA_ACT_MAX_PRIO] = {}; + struct tc_action *act; + size_t attr_size = 0; + + if (exts->police && tb[exts->police]) { + struct tc_action_ops *a_o; + + a_o = tc_action_load_ops(tb[exts->police], true, + !(flags & TCA_ACT_FLAGS_NO_RTNL), + extack); + if (IS_ERR(a_o)) + return PTR_ERR(a_o); + flags |= TCA_ACT_FLAGS_POLICE | TCA_ACT_FLAGS_BIND; + act = tcf_action_init_1(net, tp, tb[exts->police], + rate_tlv, a_o, init_res, flags, + extack); + module_put(a_o->owner); + if (IS_ERR(act)) + return PTR_ERR(act); + + act->type = exts->type = TCA_OLD_COMPAT; + exts->actions[0] = act; + exts->nr_actions = 1; + tcf_idr_insert_many(exts->actions); + } else if (exts->action && tb[exts->action]) { + int err; + + flags |= TCA_ACT_FLAGS_BIND; + err = tcf_action_init(net, tp, tb[exts->action], + rate_tlv, exts->actions, 
init_res, + &attr_size, flags, fl_flags, + extack); + if (err < 0) + return err; + exts->nr_actions = err; + } + } +#else + if ((exts->action && tb[exts->action]) || + (exts->police && tb[exts->police])) { + NL_SET_ERR_MSG(extack, "Classifier actions are not supported per compile options (CONFIG_NET_CLS_ACT)"); + return -EOPNOTSUPP; + } +#endif + + return 0; +} +EXPORT_SYMBOL(tcf_exts_validate_ex); + +int tcf_exts_validate(struct net *net, struct tcf_proto *tp, struct nlattr **tb, + struct nlattr *rate_tlv, struct tcf_exts *exts, + u32 flags, struct netlink_ext_ack *extack) +{ + return tcf_exts_validate_ex(net, tp, tb, rate_tlv, exts, + flags, 0, extack); +} +EXPORT_SYMBOL(tcf_exts_validate); + +void tcf_exts_change(struct tcf_exts *dst, struct tcf_exts *src) +{ +#ifdef CONFIG_NET_CLS_ACT + struct tcf_exts old = *dst; + + *dst = *src; + tcf_exts_destroy(&old); +#endif +} +EXPORT_SYMBOL(tcf_exts_change); + +#ifdef CONFIG_NET_CLS_ACT +static struct tc_action *tcf_exts_first_act(struct tcf_exts *exts) +{ + if (exts->nr_actions == 0) + return NULL; + else + return exts->actions[0]; +} +#endif + +int tcf_exts_dump(struct sk_buff *skb, struct tcf_exts *exts) +{ +#ifdef CONFIG_NET_CLS_ACT + struct nlattr *nest; + + if (exts->action && tcf_exts_has_actions(exts)) { + /* + * again for backward compatible mode - we want + * to work with both old and new modes of entering + * tc data even if iproute2 was newer - jhs + */ + if (exts->type != TCA_OLD_COMPAT) { + nest = nla_nest_start_noflag(skb, exts->action); + if (nest == NULL) + goto nla_put_failure; + + if (tcf_action_dump(skb, exts->actions, 0, 0, false) + < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + } else if (exts->police) { + struct tc_action *act = tcf_exts_first_act(exts); + nest = nla_nest_start_noflag(skb, exts->police); + if (nest == NULL || !act) + goto nla_put_failure; + if (tcf_action_dump_old(skb, act, 0, 0) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + } + } + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +#else + return 0; +#endif +} +EXPORT_SYMBOL(tcf_exts_dump); + +int tcf_exts_terse_dump(struct sk_buff *skb, struct tcf_exts *exts) +{ +#ifdef CONFIG_NET_CLS_ACT + struct nlattr *nest; + + if (!exts->action || !tcf_exts_has_actions(exts)) + return 0; + + nest = nla_nest_start_noflag(skb, exts->action); + if (!nest) + goto nla_put_failure; + + if (tcf_action_dump(skb, exts->actions, 0, 0, true) < 0) + goto nla_put_failure; + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +#else + return 0; +#endif +} +EXPORT_SYMBOL(tcf_exts_terse_dump); + +int tcf_exts_dump_stats(struct sk_buff *skb, struct tcf_exts *exts) +{ +#ifdef CONFIG_NET_CLS_ACT + struct tc_action *a = tcf_exts_first_act(exts); + if (a != NULL && tcf_action_copy_stats(skb, a, 1) < 0) + return -1; +#endif + return 0; +} +EXPORT_SYMBOL(tcf_exts_dump_stats); + +static void tcf_block_offload_inc(struct tcf_block *block, u32 *flags) +{ + if (*flags & TCA_CLS_FLAGS_IN_HW) + return; + *flags |= TCA_CLS_FLAGS_IN_HW; + atomic_inc(&block->offloadcnt); +} + +static void tcf_block_offload_dec(struct tcf_block *block, u32 *flags) +{ + if (!(*flags & TCA_CLS_FLAGS_IN_HW)) + return; + *flags &= ~TCA_CLS_FLAGS_IN_HW; + atomic_dec(&block->offloadcnt); +} + +static void tc_cls_offload_cnt_update(struct tcf_block *block, + struct tcf_proto *tp, u32 *cnt, + u32 *flags, u32 diff, bool add) +{ + lockdep_assert_held(&block->cb_lock); + + spin_lock(&tp->lock); + if (add) { + if (!*cnt) + 
tcf_block_offload_inc(block, flags); + *cnt += diff; + } else { + *cnt -= diff; + if (!*cnt) + tcf_block_offload_dec(block, flags); + } + spin_unlock(&tp->lock); +} + +static void +tc_cls_offload_cnt_reset(struct tcf_block *block, struct tcf_proto *tp, + u32 *cnt, u32 *flags) +{ + lockdep_assert_held(&block->cb_lock); + + spin_lock(&tp->lock); + tcf_block_offload_dec(block, flags); + *cnt = 0; + spin_unlock(&tp->lock); +} + +static int +__tc_setup_cb_call(struct tcf_block *block, enum tc_setup_type type, + void *type_data, bool err_stop) +{ + struct flow_block_cb *block_cb; + int ok_count = 0; + int err; + + list_for_each_entry(block_cb, &block->flow_block.cb_list, list) { + err = block_cb->cb(type, type_data, block_cb->cb_priv); + if (err) { + if (err_stop) + return err; + } else { + ok_count++; + } + } + return ok_count; +} + +int tc_setup_cb_call(struct tcf_block *block, enum tc_setup_type type, + void *type_data, bool err_stop, bool rtnl_held) +{ + bool take_rtnl = READ_ONCE(block->lockeddevcnt) && !rtnl_held; + int ok_count; + +retry: + if (take_rtnl) + rtnl_lock(); + down_read(&block->cb_lock); + /* Need to obtain rtnl lock if block is bound to devs that require it. + * In block bind code cb_lock is obtained while holding rtnl, so we must + * obtain the locks in same order here. + */ + if (!rtnl_held && !take_rtnl && block->lockeddevcnt) { + up_read(&block->cb_lock); + take_rtnl = true; + goto retry; + } + + ok_count = __tc_setup_cb_call(block, type, type_data, err_stop); + + up_read(&block->cb_lock); + if (take_rtnl) + rtnl_unlock(); + return ok_count; +} +EXPORT_SYMBOL(tc_setup_cb_call); + +/* Non-destructive filter add. If filter that wasn't already in hardware is + * successfully offloaded, increment block offloads counter. On failure, + * previously offloaded filter is considered to be intact and offloads counter + * is not decremented. + */ + +int tc_setup_cb_add(struct tcf_block *block, struct tcf_proto *tp, + enum tc_setup_type type, void *type_data, bool err_stop, + u32 *flags, unsigned int *in_hw_count, bool rtnl_held) +{ + bool take_rtnl = READ_ONCE(block->lockeddevcnt) && !rtnl_held; + int ok_count; + +retry: + if (take_rtnl) + rtnl_lock(); + down_read(&block->cb_lock); + /* Need to obtain rtnl lock if block is bound to devs that require it. + * In block bind code cb_lock is obtained while holding rtnl, so we must + * obtain the locks in same order here. + */ + if (!rtnl_held && !take_rtnl && block->lockeddevcnt) { + up_read(&block->cb_lock); + take_rtnl = true; + goto retry; + } + + /* Make sure all netdevs sharing this block are offload-capable. */ + if (block->nooffloaddevcnt && err_stop) { + ok_count = -EOPNOTSUPP; + goto err_unlock; + } + + ok_count = __tc_setup_cb_call(block, type, type_data, err_stop); + if (ok_count < 0) + goto err_unlock; + + if (tp->ops->hw_add) + tp->ops->hw_add(tp, type_data); + if (ok_count > 0) + tc_cls_offload_cnt_update(block, tp, in_hw_count, flags, + ok_count, true); +err_unlock: + up_read(&block->cb_lock); + if (take_rtnl) + rtnl_unlock(); + return min(ok_count, 0); +} +EXPORT_SYMBOL(tc_setup_cb_add); + +/* Destructive filter replace. If filter that wasn't already in hardware is + * successfully offloaded, increment block offload counter. On failure, + * previously offloaded filter is considered to be destroyed and offload counter + * is decremented. 
+ */ + +int tc_setup_cb_replace(struct tcf_block *block, struct tcf_proto *tp, + enum tc_setup_type type, void *type_data, bool err_stop, + u32 *old_flags, unsigned int *old_in_hw_count, + u32 *new_flags, unsigned int *new_in_hw_count, + bool rtnl_held) +{ + bool take_rtnl = READ_ONCE(block->lockeddevcnt) && !rtnl_held; + int ok_count; + +retry: + if (take_rtnl) + rtnl_lock(); + down_read(&block->cb_lock); + /* Need to obtain rtnl lock if block is bound to devs that require it. + * In block bind code cb_lock is obtained while holding rtnl, so we must + * obtain the locks in same order here. + */ + if (!rtnl_held && !take_rtnl && block->lockeddevcnt) { + up_read(&block->cb_lock); + take_rtnl = true; + goto retry; + } + + /* Make sure all netdevs sharing this block are offload-capable. */ + if (block->nooffloaddevcnt && err_stop) { + ok_count = -EOPNOTSUPP; + goto err_unlock; + } + + tc_cls_offload_cnt_reset(block, tp, old_in_hw_count, old_flags); + if (tp->ops->hw_del) + tp->ops->hw_del(tp, type_data); + + ok_count = __tc_setup_cb_call(block, type, type_data, err_stop); + if (ok_count < 0) + goto err_unlock; + + if (tp->ops->hw_add) + tp->ops->hw_add(tp, type_data); + if (ok_count > 0) + tc_cls_offload_cnt_update(block, tp, new_in_hw_count, + new_flags, ok_count, true); +err_unlock: + up_read(&block->cb_lock); + if (take_rtnl) + rtnl_unlock(); + return min(ok_count, 0); +} +EXPORT_SYMBOL(tc_setup_cb_replace); + +/* Destroy filter and decrement block offload counter, if filter was previously + * offloaded. + */ + +int tc_setup_cb_destroy(struct tcf_block *block, struct tcf_proto *tp, + enum tc_setup_type type, void *type_data, bool err_stop, + u32 *flags, unsigned int *in_hw_count, bool rtnl_held) +{ + bool take_rtnl = READ_ONCE(block->lockeddevcnt) && !rtnl_held; + int ok_count; + +retry: + if (take_rtnl) + rtnl_lock(); + down_read(&block->cb_lock); + /* Need to obtain rtnl lock if block is bound to devs that require it. + * In block bind code cb_lock is obtained while holding rtnl, so we must + * obtain the locks in same order here. 
+ */ + if (!rtnl_held && !take_rtnl && block->lockeddevcnt) { + up_read(&block->cb_lock); + take_rtnl = true; + goto retry; + } + + ok_count = __tc_setup_cb_call(block, type, type_data, err_stop); + + tc_cls_offload_cnt_reset(block, tp, in_hw_count, flags); + if (tp->ops->hw_del) + tp->ops->hw_del(tp, type_data); + + up_read(&block->cb_lock); + if (take_rtnl) + rtnl_unlock(); + return min(ok_count, 0); +} +EXPORT_SYMBOL(tc_setup_cb_destroy); + +int tc_setup_cb_reoffload(struct tcf_block *block, struct tcf_proto *tp, + bool add, flow_setup_cb_t *cb, + enum tc_setup_type type, void *type_data, + void *cb_priv, u32 *flags, unsigned int *in_hw_count) +{ + int err = cb(type, type_data, cb_priv); + + if (err) { + if (add && tc_skip_sw(*flags)) + return err; + } else { + tc_cls_offload_cnt_update(block, tp, in_hw_count, flags, 1, + add); + } + + return 0; +} +EXPORT_SYMBOL(tc_setup_cb_reoffload); + +static int tcf_act_get_cookie(struct flow_action_entry *entry, + const struct tc_action *act) +{ + struct tc_cookie *cookie; + int err = 0; + + rcu_read_lock(); + cookie = rcu_dereference(act->act_cookie); + if (cookie) { + entry->cookie = flow_action_cookie_create(cookie->data, + cookie->len, + GFP_ATOMIC); + if (!entry->cookie) + err = -ENOMEM; + } + rcu_read_unlock(); + return err; +} + +static void tcf_act_put_cookie(struct flow_action_entry *entry) +{ + flow_action_cookie_destroy(entry->cookie); +} + +void tc_cleanup_offload_action(struct flow_action *flow_action) +{ + struct flow_action_entry *entry; + int i; + + flow_action_for_each(i, entry, flow_action) { + tcf_act_put_cookie(entry); + if (entry->destructor) + entry->destructor(entry->destructor_priv); + } +} +EXPORT_SYMBOL(tc_cleanup_offload_action); + +static int tc_setup_offload_act(struct tc_action *act, + struct flow_action_entry *entry, + u32 *index_inc, + struct netlink_ext_ack *extack) +{ +#ifdef CONFIG_NET_CLS_ACT + if (act->ops->offload_act_setup) { + return act->ops->offload_act_setup(act, entry, index_inc, true, + extack); + } else { + NL_SET_ERR_MSG(extack, "Action does not support offload"); + return -EOPNOTSUPP; + } +#else + return 0; +#endif +} + +int tc_setup_action(struct flow_action *flow_action, + struct tc_action *actions[], + struct netlink_ext_ack *extack) +{ + int i, j, k, index, err = 0; + struct tc_action *act; + + BUILD_BUG_ON(TCA_ACT_HW_STATS_ANY != FLOW_ACTION_HW_STATS_ANY); + BUILD_BUG_ON(TCA_ACT_HW_STATS_IMMEDIATE != FLOW_ACTION_HW_STATS_IMMEDIATE); + BUILD_BUG_ON(TCA_ACT_HW_STATS_DELAYED != FLOW_ACTION_HW_STATS_DELAYED); + + if (!actions) + return 0; + + j = 0; + tcf_act_for_each_action(i, act, actions) { + struct flow_action_entry *entry; + + entry = &flow_action->entries[j]; + spin_lock_bh(&act->tcfa_lock); + err = tcf_act_get_cookie(entry, act); + if (err) + goto err_out_locked; + + index = 0; + err = tc_setup_offload_act(act, entry, &index, extack); + if (err) + goto err_out_locked; + + for (k = 0; k < index ; k++) { + entry[k].hw_stats = tc_act_hw_stats(act->hw_stats); + entry[k].hw_index = act->tcfa_index; + } + + j += index; + + spin_unlock_bh(&act->tcfa_lock); + } + +err_out: + if (err) + tc_cleanup_offload_action(flow_action); + + return err; +err_out_locked: + spin_unlock_bh(&act->tcfa_lock); + goto err_out; +} + +int tc_setup_offload_action(struct flow_action *flow_action, + const struct tcf_exts *exts, + struct netlink_ext_ack *extack) +{ +#ifdef CONFIG_NET_CLS_ACT + if (!exts) + return 0; + + return tc_setup_action(flow_action, exts->actions, extack); +#else + return 0; +#endif +} 
+EXPORT_SYMBOL(tc_setup_offload_action); + +unsigned int tcf_exts_num_actions(struct tcf_exts *exts) +{ + unsigned int num_acts = 0; + struct tc_action *act; + int i; + + tcf_exts_for_each_action(i, act, exts) { + if (is_tcf_pedit(act)) + num_acts += tcf_pedit_nkeys(act); + else + num_acts++; + } + return num_acts; +} +EXPORT_SYMBOL(tcf_exts_num_actions); + +#ifdef CONFIG_NET_CLS_ACT +static int tcf_qevent_parse_block_index(struct nlattr *block_index_attr, + u32 *p_block_index, + struct netlink_ext_ack *extack) +{ + *p_block_index = nla_get_u32(block_index_attr); + if (!*p_block_index) { + NL_SET_ERR_MSG(extack, "Block number may not be zero"); + return -EINVAL; + } + + return 0; +} + +int tcf_qevent_init(struct tcf_qevent *qe, struct Qdisc *sch, + enum flow_block_binder_type binder_type, + struct nlattr *block_index_attr, + struct netlink_ext_ack *extack) +{ + u32 block_index; + int err; + + if (!block_index_attr) + return 0; + + err = tcf_qevent_parse_block_index(block_index_attr, &block_index, extack); + if (err) + return err; + + qe->info.binder_type = binder_type; + qe->info.chain_head_change = tcf_chain_head_change_dflt; + qe->info.chain_head_change_priv = &qe->filter_chain; + qe->info.block_index = block_index; + + return tcf_block_get_ext(&qe->block, sch, &qe->info, extack); +} +EXPORT_SYMBOL(tcf_qevent_init); + +void tcf_qevent_destroy(struct tcf_qevent *qe, struct Qdisc *sch) +{ + if (qe->info.block_index) + tcf_block_put_ext(qe->block, sch, &qe->info); +} +EXPORT_SYMBOL(tcf_qevent_destroy); + +int tcf_qevent_validate_change(struct tcf_qevent *qe, struct nlattr *block_index_attr, + struct netlink_ext_ack *extack) +{ + u32 block_index; + int err; + + if (!block_index_attr) + return 0; + + err = tcf_qevent_parse_block_index(block_index_attr, &block_index, extack); + if (err) + return err; + + /* Bounce newly-configured block or change in block. 
*/ + if (block_index != qe->info.block_index) { + NL_SET_ERR_MSG(extack, "Change of blocks is not supported"); + return -EINVAL; + } + + return 0; +} +EXPORT_SYMBOL(tcf_qevent_validate_change); + +struct sk_buff *tcf_qevent_handle(struct tcf_qevent *qe, struct Qdisc *sch, struct sk_buff *skb, + struct sk_buff **to_free, int *ret) +{ + struct tcf_result cl_res; + struct tcf_proto *fl; + + if (!qe->info.block_index) + return skb; + + fl = rcu_dereference_bh(qe->filter_chain); + + switch (tcf_classify(skb, NULL, fl, &cl_res, false)) { + case TC_ACT_SHOT: + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + *ret = __NET_XMIT_BYPASS; + return NULL; + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + __qdisc_drop(skb, to_free); + *ret = __NET_XMIT_STOLEN; + return NULL; + case TC_ACT_REDIRECT: + skb_do_redirect(skb); + *ret = __NET_XMIT_STOLEN; + return NULL; + } + + return skb; +} +EXPORT_SYMBOL(tcf_qevent_handle); + +int tcf_qevent_dump(struct sk_buff *skb, int attr_name, struct tcf_qevent *qe) +{ + if (!qe->info.block_index) + return 0; + return nla_put_u32(skb, attr_name, qe->info.block_index); +} +EXPORT_SYMBOL(tcf_qevent_dump); +#endif + +static __net_init int tcf_net_init(struct net *net) +{ + struct tcf_net *tn = net_generic(net, tcf_net_id); + + spin_lock_init(&tn->idr_lock); + idr_init(&tn->idr); + return 0; +} + +static void __net_exit tcf_net_exit(struct net *net) +{ + struct tcf_net *tn = net_generic(net, tcf_net_id); + + idr_destroy(&tn->idr); +} + +static struct pernet_operations tcf_net_ops = { + .init = tcf_net_init, + .exit = tcf_net_exit, + .id = &tcf_net_id, + .size = sizeof(struct tcf_net), +}; + +static int __init tc_filter_init(void) +{ + int err; + + tc_filter_wq = alloc_ordered_workqueue("tc_filter_workqueue", 0); + if (!tc_filter_wq) + return -ENOMEM; + + err = register_pernet_subsys(&tcf_net_ops); + if (err) + goto err_register_pernet_subsys; + + rtnl_register(PF_UNSPEC, RTM_NEWTFILTER, tc_new_tfilter, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register(PF_UNSPEC, RTM_DELTFILTER, tc_del_tfilter, NULL, + RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register(PF_UNSPEC, RTM_GETTFILTER, tc_get_tfilter, + tc_dump_tfilter, RTNL_FLAG_DOIT_UNLOCKED); + rtnl_register(PF_UNSPEC, RTM_NEWCHAIN, tc_ctl_chain, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELCHAIN, tc_ctl_chain, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETCHAIN, tc_ctl_chain, + tc_dump_chain, 0); + + return 0; + +err_register_pernet_subsys: + destroy_workqueue(tc_filter_wq); + return err; +} + +subsys_initcall(tc_filter_init); diff --git a/net/sched/cls_basic.c b/net/sched/cls_basic.c new file mode 100644 index 000000000..d229ce99e --- /dev/null +++ b/net/sched/cls_basic.c @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_basic.c Basic Packet Classifier. 
+ * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct basic_head { + struct list_head flist; + struct idr handle_idr; + struct rcu_head rcu; +}; + +struct basic_filter { + u32 handle; + struct tcf_exts exts; + struct tcf_ematch_tree ematches; + struct tcf_result res; + struct tcf_proto *tp; + struct list_head link; + struct tc_basic_pcnt __percpu *pf; + struct rcu_work rwork; +}; + +static int basic_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + int r; + struct basic_head *head = rcu_dereference_bh(tp->root); + struct basic_filter *f; + + list_for_each_entry_rcu(f, &head->flist, link) { + __this_cpu_inc(f->pf->rcnt); + if (!tcf_em_tree_match(skb, &f->ematches, NULL)) + continue; + __this_cpu_inc(f->pf->rhit); + *res = f->res; + r = tcf_exts_exec(skb, &f->exts, res); + if (r < 0) + continue; + return r; + } + return -1; +} + +static void *basic_get(struct tcf_proto *tp, u32 handle) +{ + struct basic_head *head = rtnl_dereference(tp->root); + struct basic_filter *f; + + list_for_each_entry(f, &head->flist, link) { + if (f->handle == handle) { + return f; + } + } + + return NULL; +} + +static int basic_init(struct tcf_proto *tp) +{ + struct basic_head *head; + + head = kzalloc(sizeof(*head), GFP_KERNEL); + if (head == NULL) + return -ENOBUFS; + INIT_LIST_HEAD(&head->flist); + idr_init(&head->handle_idr); + rcu_assign_pointer(tp->root, head); + return 0; +} + +static void __basic_delete_filter(struct basic_filter *f) +{ + tcf_exts_destroy(&f->exts); + tcf_em_tree_destroy(&f->ematches); + tcf_exts_put_net(&f->exts); + free_percpu(f->pf); + kfree(f); +} + +static void basic_delete_filter_work(struct work_struct *work) +{ + struct basic_filter *f = container_of(to_rcu_work(work), + struct basic_filter, + rwork); + rtnl_lock(); + __basic_delete_filter(f); + rtnl_unlock(); +} + +static void basic_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct basic_head *head = rtnl_dereference(tp->root); + struct basic_filter *f, *n; + + list_for_each_entry_safe(f, n, &head->flist, link) { + list_del_rcu(&f->link); + tcf_unbind_filter(tp, &f->res); + idr_remove(&head->handle_idr, f->handle); + if (tcf_exts_get_net(&f->exts)) + tcf_queue_work(&f->rwork, basic_delete_filter_work); + else + __basic_delete_filter(f); + } + idr_destroy(&head->handle_idr); + kfree_rcu(head, rcu); +} + +static int basic_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct basic_head *head = rtnl_dereference(tp->root); + struct basic_filter *f = arg; + + list_del_rcu(&f->link); + tcf_unbind_filter(tp, &f->res); + idr_remove(&head->handle_idr, f->handle); + tcf_exts_get_net(&f->exts); + tcf_queue_work(&f->rwork, basic_delete_filter_work); + *last = list_empty(&head->flist); + return 0; +} + +static const struct nla_policy basic_policy[TCA_BASIC_MAX + 1] = { + [TCA_BASIC_CLASSID] = { .type = NLA_U32 }, + [TCA_BASIC_EMATCHES] = { .type = NLA_NESTED }, +}; + +static int basic_set_parms(struct net *net, struct tcf_proto *tp, + struct basic_filter *f, unsigned long base, + struct nlattr **tb, + struct nlattr *est, u32 flags, + struct netlink_ext_ack *extack) +{ + int err; + + err = tcf_exts_validate(net, tp, tb, est, &f->exts, flags, extack); + if (err < 0) + return err; + + err = tcf_em_tree_validate(tp, tb[TCA_BASIC_EMATCHES], &f->ematches); + if (err < 0) + return err; + + if 
(tb[TCA_BASIC_CLASSID]) { + f->res.classid = nla_get_u32(tb[TCA_BASIC_CLASSID]); + tcf_bind_filter(tp, &f->res, base); + } + + f->tp = tp; + return 0; +} + +static int basic_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, u32 handle, + struct nlattr **tca, void **arg, + u32 flags, struct netlink_ext_ack *extack) +{ + int err; + struct basic_head *head = rtnl_dereference(tp->root); + struct nlattr *tb[TCA_BASIC_MAX + 1]; + struct basic_filter *fold = (struct basic_filter *) *arg; + struct basic_filter *fnew; + + if (tca[TCA_OPTIONS] == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_BASIC_MAX, tca[TCA_OPTIONS], + basic_policy, NULL); + if (err < 0) + return err; + + if (fold != NULL) { + if (handle && fold->handle != handle) + return -EINVAL; + } + + fnew = kzalloc(sizeof(*fnew), GFP_KERNEL); + if (!fnew) + return -ENOBUFS; + + err = tcf_exts_init(&fnew->exts, net, TCA_BASIC_ACT, TCA_BASIC_POLICE); + if (err < 0) + goto errout; + + if (!handle) { + handle = 1; + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + INT_MAX, GFP_KERNEL); + } else if (!fold) { + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + handle, GFP_KERNEL); + } + if (err) + goto errout; + fnew->handle = handle; + fnew->pf = alloc_percpu(struct tc_basic_pcnt); + if (!fnew->pf) { + err = -ENOMEM; + goto errout; + } + + err = basic_set_parms(net, tp, fnew, base, tb, tca[TCA_RATE], flags, + extack); + if (err < 0) { + if (!fold) + idr_remove(&head->handle_idr, fnew->handle); + goto errout; + } + + *arg = fnew; + + if (fold) { + idr_replace(&head->handle_idr, fnew, fnew->handle); + list_replace_rcu(&fold->link, &fnew->link); + tcf_unbind_filter(tp, &fold->res); + tcf_exts_get_net(&fold->exts); + tcf_queue_work(&fold->rwork, basic_delete_filter_work); + } else { + list_add_rcu(&fnew->link, &head->flist); + } + + return 0; +errout: + free_percpu(fnew->pf); + tcf_exts_destroy(&fnew->exts); + kfree(fnew); + return err; +} + +static void basic_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct basic_head *head = rtnl_dereference(tp->root); + struct basic_filter *f; + + list_for_each_entry(f, &head->flist, link) { + if (!tc_cls_stats_dump(tp, arg, f)) + break; + } +} + +static void basic_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct basic_filter *f = fh; + + tc_cls_bind_class(classid, cl, q, &f->res, base); +} + +static int basic_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct tc_basic_pcnt gpf = {}; + struct basic_filter *f = fh; + struct nlattr *nest; + int cpu; + + if (f == NULL) + return skb->len; + + t->tcm_handle = f->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (f->res.classid && + nla_put_u32(skb, TCA_BASIC_CLASSID, f->res.classid)) + goto nla_put_failure; + + for_each_possible_cpu(cpu) { + struct tc_basic_pcnt *pf = per_cpu_ptr(f->pf, cpu); + + gpf.rcnt += pf->rcnt; + gpf.rhit += pf->rhit; + } + + if (nla_put_64bit(skb, TCA_BASIC_PCNT, + sizeof(struct tc_basic_pcnt), + &gpf, TCA_BASIC_PAD)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &f->exts) < 0 || + tcf_em_tree_dump(skb, &f->ematches, TCA_BASIC_EMATCHES) < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &f->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + 
+static struct tcf_proto_ops cls_basic_ops __read_mostly = { + .kind = "basic", + .classify = basic_classify, + .init = basic_init, + .destroy = basic_destroy, + .get = basic_get, + .change = basic_change, + .delete = basic_delete, + .walk = basic_walk, + .dump = basic_dump, + .bind_class = basic_bind_class, + .owner = THIS_MODULE, +}; + +static int __init init_basic(void) +{ + return register_tcf_proto_ops(&cls_basic_ops); +} + +static void __exit exit_basic(void) +{ + unregister_tcf_proto_ops(&cls_basic_ops); +} + +module_init(init_basic) +module_exit(exit_basic) +MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_bpf.c b/net/sched/cls_bpf.c new file mode 100644 index 000000000..0320e11eb --- /dev/null +++ b/net/sched/cls_bpf.c @@ -0,0 +1,706 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Berkeley Packet Filter based traffic classifier + * + * Might be used to classify traffic through flexible, user-defined and + * possibly JIT-ed BPF filters for traffic control as an alternative to + * ematches. + * + * (C) 2013 Daniel Borkmann + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Daniel Borkmann "); +MODULE_DESCRIPTION("TC BPF based classifier"); + +#define CLS_BPF_NAME_LEN 256 +#define CLS_BPF_SUPPORTED_GEN_FLAGS \ + (TCA_CLS_FLAGS_SKIP_HW | TCA_CLS_FLAGS_SKIP_SW) + +struct cls_bpf_head { + struct list_head plist; + struct idr handle_idr; + struct rcu_head rcu; +}; + +struct cls_bpf_prog { + struct bpf_prog *filter; + struct list_head link; + struct tcf_result res; + bool exts_integrated; + u32 gen_flags; + unsigned int in_hw_count; + struct tcf_exts exts; + u32 handle; + u16 bpf_num_ops; + struct sock_filter *bpf_ops; + const char *bpf_name; + struct tcf_proto *tp; + struct rcu_work rwork; +}; + +static const struct nla_policy bpf_policy[TCA_BPF_MAX + 1] = { + [TCA_BPF_CLASSID] = { .type = NLA_U32 }, + [TCA_BPF_FLAGS] = { .type = NLA_U32 }, + [TCA_BPF_FLAGS_GEN] = { .type = NLA_U32 }, + [TCA_BPF_FD] = { .type = NLA_U32 }, + [TCA_BPF_NAME] = { .type = NLA_NUL_STRING, + .len = CLS_BPF_NAME_LEN }, + [TCA_BPF_OPS_LEN] = { .type = NLA_U16 }, + [TCA_BPF_OPS] = { .type = NLA_BINARY, + .len = sizeof(struct sock_filter) * BPF_MAXINSNS }, +}; + +static int cls_bpf_exec_opcode(int code) +{ + switch (code) { + case TC_ACT_OK: + case TC_ACT_SHOT: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + case TC_ACT_REDIRECT: + case TC_ACT_UNSPEC: + return code; + default: + return TC_ACT_UNSPEC; + } +} + +static int cls_bpf_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct cls_bpf_head *head = rcu_dereference_bh(tp->root); + bool at_ingress = skb_at_tc_ingress(skb); + struct cls_bpf_prog *prog; + int ret = -1; + + list_for_each_entry_rcu(prog, &head->plist, link) { + int filter_res; + + qdisc_skb_cb(skb)->tc_classid = prog->res.classid; + + if (tc_skip_sw(prog->gen_flags)) { + filter_res = prog->exts_integrated ? 
TC_ACT_UNSPEC : 0; + } else if (at_ingress) { + /* It is safe to push/pull even if skb_shared() */ + __skb_push(skb, skb->mac_len); + bpf_compute_data_pointers(skb); + filter_res = bpf_prog_run(prog->filter, skb); + __skb_pull(skb, skb->mac_len); + } else { + bpf_compute_data_pointers(skb); + filter_res = bpf_prog_run(prog->filter, skb); + } + if (unlikely(!skb->tstamp && skb->mono_delivery_time)) + skb->mono_delivery_time = 0; + + if (prog->exts_integrated) { + res->class = 0; + res->classid = TC_H_MAJ(prog->res.classid) | + qdisc_skb_cb(skb)->tc_classid; + + ret = cls_bpf_exec_opcode(filter_res); + if (ret == TC_ACT_UNSPEC) + continue; + break; + } + + if (filter_res == 0) + continue; + if (filter_res != -1) { + res->class = 0; + res->classid = filter_res; + } else { + *res = prog->res; + } + + ret = tcf_exts_exec(skb, &prog->exts, res); + if (ret < 0) + continue; + + break; + } + + return ret; +} + +static bool cls_bpf_is_ebpf(const struct cls_bpf_prog *prog) +{ + return !prog->bpf_ops; +} + +static int cls_bpf_offload_cmd(struct tcf_proto *tp, struct cls_bpf_prog *prog, + struct cls_bpf_prog *oldprog, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct tc_cls_bpf_offload cls_bpf = {}; + struct cls_bpf_prog *obj; + bool skip_sw; + int err; + + skip_sw = prog && tc_skip_sw(prog->gen_flags); + obj = prog ?: oldprog; + + tc_cls_common_offload_init(&cls_bpf.common, tp, obj->gen_flags, extack); + cls_bpf.command = TC_CLSBPF_OFFLOAD; + cls_bpf.exts = &obj->exts; + cls_bpf.prog = prog ? prog->filter : NULL; + cls_bpf.oldprog = oldprog ? oldprog->filter : NULL; + cls_bpf.name = obj->bpf_name; + cls_bpf.exts_integrated = obj->exts_integrated; + + if (oldprog && prog) + err = tc_setup_cb_replace(block, tp, TC_SETUP_CLSBPF, &cls_bpf, + skip_sw, &oldprog->gen_flags, + &oldprog->in_hw_count, + &prog->gen_flags, &prog->in_hw_count, + true); + else if (prog) + err = tc_setup_cb_add(block, tp, TC_SETUP_CLSBPF, &cls_bpf, + skip_sw, &prog->gen_flags, + &prog->in_hw_count, true); + else + err = tc_setup_cb_destroy(block, tp, TC_SETUP_CLSBPF, &cls_bpf, + skip_sw, &oldprog->gen_flags, + &oldprog->in_hw_count, true); + + if (prog && err) { + cls_bpf_offload_cmd(tp, oldprog, prog, extack); + return err; + } + + if (prog && skip_sw && !(prog->gen_flags & TCA_CLS_FLAGS_IN_HW)) + return -EINVAL; + + return 0; +} + +static u32 cls_bpf_flags(u32 flags) +{ + return flags & CLS_BPF_SUPPORTED_GEN_FLAGS; +} + +static int cls_bpf_offload(struct tcf_proto *tp, struct cls_bpf_prog *prog, + struct cls_bpf_prog *oldprog, + struct netlink_ext_ack *extack) +{ + if (prog && oldprog && + cls_bpf_flags(prog->gen_flags) != + cls_bpf_flags(oldprog->gen_flags)) + return -EINVAL; + + if (prog && tc_skip_hw(prog->gen_flags)) + prog = NULL; + if (oldprog && tc_skip_hw(oldprog->gen_flags)) + oldprog = NULL; + if (!prog && !oldprog) + return 0; + + return cls_bpf_offload_cmd(tp, prog, oldprog, extack); +} + +static void cls_bpf_stop_offload(struct tcf_proto *tp, + struct cls_bpf_prog *prog, + struct netlink_ext_ack *extack) +{ + int err; + + err = cls_bpf_offload_cmd(tp, NULL, prog, extack); + if (err) + pr_err("Stopping hardware offload failed: %d\n", err); +} + +static void cls_bpf_offload_update_stats(struct tcf_proto *tp, + struct cls_bpf_prog *prog) +{ + struct tcf_block *block = tp->chain->block; + struct tc_cls_bpf_offload cls_bpf = {}; + + tc_cls_common_offload_init(&cls_bpf.common, tp, prog->gen_flags, NULL); + cls_bpf.command = TC_CLSBPF_STATS; + cls_bpf.exts = &prog->exts; + 
cls_bpf.prog = prog->filter; + cls_bpf.name = prog->bpf_name; + cls_bpf.exts_integrated = prog->exts_integrated; + + tc_setup_cb_call(block, TC_SETUP_CLSBPF, &cls_bpf, false, true); +} + +static int cls_bpf_init(struct tcf_proto *tp) +{ + struct cls_bpf_head *head; + + head = kzalloc(sizeof(*head), GFP_KERNEL); + if (head == NULL) + return -ENOBUFS; + + INIT_LIST_HEAD_RCU(&head->plist); + idr_init(&head->handle_idr); + rcu_assign_pointer(tp->root, head); + + return 0; +} + +static void cls_bpf_free_parms(struct cls_bpf_prog *prog) +{ + if (cls_bpf_is_ebpf(prog)) + bpf_prog_put(prog->filter); + else + bpf_prog_destroy(prog->filter); + + kfree(prog->bpf_name); + kfree(prog->bpf_ops); +} + +static void __cls_bpf_delete_prog(struct cls_bpf_prog *prog) +{ + tcf_exts_destroy(&prog->exts); + tcf_exts_put_net(&prog->exts); + + cls_bpf_free_parms(prog); + kfree(prog); +} + +static void cls_bpf_delete_prog_work(struct work_struct *work) +{ + struct cls_bpf_prog *prog = container_of(to_rcu_work(work), + struct cls_bpf_prog, + rwork); + rtnl_lock(); + __cls_bpf_delete_prog(prog); + rtnl_unlock(); +} + +static void __cls_bpf_delete(struct tcf_proto *tp, struct cls_bpf_prog *prog, + struct netlink_ext_ack *extack) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + + idr_remove(&head->handle_idr, prog->handle); + cls_bpf_stop_offload(tp, prog, extack); + list_del_rcu(&prog->link); + tcf_unbind_filter(tp, &prog->res); + if (tcf_exts_get_net(&prog->exts)) + tcf_queue_work(&prog->rwork, cls_bpf_delete_prog_work); + else + __cls_bpf_delete_prog(prog); +} + +static int cls_bpf_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + + __cls_bpf_delete(tp, arg, extack); + *last = list_empty(&head->plist); + return 0; +} + +static void cls_bpf_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + struct cls_bpf_prog *prog, *tmp; + + list_for_each_entry_safe(prog, tmp, &head->plist, link) + __cls_bpf_delete(tp, prog, extack); + + idr_destroy(&head->handle_idr); + kfree_rcu(head, rcu); +} + +static void *cls_bpf_get(struct tcf_proto *tp, u32 handle) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + struct cls_bpf_prog *prog; + + list_for_each_entry(prog, &head->plist, link) { + if (prog->handle == handle) + return prog; + } + + return NULL; +} + +static int cls_bpf_prog_from_ops(struct nlattr **tb, struct cls_bpf_prog *prog) +{ + struct sock_filter *bpf_ops; + struct sock_fprog_kern fprog_tmp; + struct bpf_prog *fp; + u16 bpf_size, bpf_num_ops; + int ret; + + bpf_num_ops = nla_get_u16(tb[TCA_BPF_OPS_LEN]); + if (bpf_num_ops > BPF_MAXINSNS || bpf_num_ops == 0) + return -EINVAL; + + bpf_size = bpf_num_ops * sizeof(*bpf_ops); + if (bpf_size != nla_len(tb[TCA_BPF_OPS])) + return -EINVAL; + + bpf_ops = kmemdup(nla_data(tb[TCA_BPF_OPS]), bpf_size, GFP_KERNEL); + if (bpf_ops == NULL) + return -ENOMEM; + + fprog_tmp.len = bpf_num_ops; + fprog_tmp.filter = bpf_ops; + + ret = bpf_prog_create(&fp, &fprog_tmp); + if (ret < 0) { + kfree(bpf_ops); + return ret; + } + + prog->bpf_ops = bpf_ops; + prog->bpf_num_ops = bpf_num_ops; + prog->bpf_name = NULL; + prog->filter = fp; + + return 0; +} + +static int cls_bpf_prog_from_efd(struct nlattr **tb, struct cls_bpf_prog *prog, + u32 gen_flags, const struct tcf_proto *tp) +{ + struct bpf_prog *fp; + char *name = NULL; + bool skip_sw; + u32 bpf_fd; + + bpf_fd = 
nla_get_u32(tb[TCA_BPF_FD]); + skip_sw = gen_flags & TCA_CLS_FLAGS_SKIP_SW; + + fp = bpf_prog_get_type_dev(bpf_fd, BPF_PROG_TYPE_SCHED_CLS, skip_sw); + if (IS_ERR(fp)) + return PTR_ERR(fp); + + if (tb[TCA_BPF_NAME]) { + name = nla_memdup(tb[TCA_BPF_NAME], GFP_KERNEL); + if (!name) { + bpf_prog_put(fp); + return -ENOMEM; + } + } + + prog->bpf_ops = NULL; + prog->bpf_name = name; + prog->filter = fp; + + if (fp->dst_needed) + tcf_block_netif_keep_dst(tp->chain->block); + + return 0; +} + +static int cls_bpf_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 handle, struct nlattr **tca, + void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + bool is_bpf, is_ebpf, have_exts = false; + struct cls_bpf_prog *oldprog = *arg; + struct nlattr *tb[TCA_BPF_MAX + 1]; + bool bound_to_filter = false; + struct cls_bpf_prog *prog; + u32 gen_flags = 0; + int ret; + + if (tca[TCA_OPTIONS] == NULL) + return -EINVAL; + + ret = nla_parse_nested_deprecated(tb, TCA_BPF_MAX, tca[TCA_OPTIONS], + bpf_policy, NULL); + if (ret < 0) + return ret; + + prog = kzalloc(sizeof(*prog), GFP_KERNEL); + if (!prog) + return -ENOBUFS; + + ret = tcf_exts_init(&prog->exts, net, TCA_BPF_ACT, TCA_BPF_POLICE); + if (ret < 0) + goto errout; + + if (oldprog) { + if (handle && oldprog->handle != handle) { + ret = -EINVAL; + goto errout; + } + } + + if (handle == 0) { + handle = 1; + ret = idr_alloc_u32(&head->handle_idr, prog, &handle, + INT_MAX, GFP_KERNEL); + } else if (!oldprog) { + ret = idr_alloc_u32(&head->handle_idr, prog, &handle, + handle, GFP_KERNEL); + } + + if (ret) + goto errout; + prog->handle = handle; + + is_bpf = tb[TCA_BPF_OPS_LEN] && tb[TCA_BPF_OPS]; + is_ebpf = tb[TCA_BPF_FD]; + if ((!is_bpf && !is_ebpf) || (is_bpf && is_ebpf)) { + ret = -EINVAL; + goto errout_idr; + } + + ret = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &prog->exts, + flags, extack); + if (ret < 0) + goto errout_idr; + + if (tb[TCA_BPF_FLAGS]) { + u32 bpf_flags = nla_get_u32(tb[TCA_BPF_FLAGS]); + + if (bpf_flags & ~TCA_BPF_FLAG_ACT_DIRECT) { + ret = -EINVAL; + goto errout_idr; + } + + have_exts = bpf_flags & TCA_BPF_FLAG_ACT_DIRECT; + } + if (tb[TCA_BPF_FLAGS_GEN]) { + gen_flags = nla_get_u32(tb[TCA_BPF_FLAGS_GEN]); + if (gen_flags & ~CLS_BPF_SUPPORTED_GEN_FLAGS || + !tc_flags_valid(gen_flags)) { + ret = -EINVAL; + goto errout_idr; + } + } + + prog->exts_integrated = have_exts; + prog->gen_flags = gen_flags; + + ret = is_bpf ? 
cls_bpf_prog_from_ops(tb, prog) : + cls_bpf_prog_from_efd(tb, prog, gen_flags, tp); + if (ret < 0) + goto errout_idr; + + if (tb[TCA_BPF_CLASSID]) { + prog->res.classid = nla_get_u32(tb[TCA_BPF_CLASSID]); + tcf_bind_filter(tp, &prog->res, base); + bound_to_filter = true; + } + + ret = cls_bpf_offload(tp, prog, oldprog, extack); + if (ret) + goto errout_parms; + + if (!tc_in_hw(prog->gen_flags)) + prog->gen_flags |= TCA_CLS_FLAGS_NOT_IN_HW; + + if (oldprog) { + idr_replace(&head->handle_idr, prog, handle); + list_replace_rcu(&oldprog->link, &prog->link); + tcf_unbind_filter(tp, &oldprog->res); + tcf_exts_get_net(&oldprog->exts); + tcf_queue_work(&oldprog->rwork, cls_bpf_delete_prog_work); + } else { + list_add_rcu(&prog->link, &head->plist); + } + + *arg = prog; + return 0; + +errout_parms: + if (bound_to_filter) + tcf_unbind_filter(tp, &prog->res); + cls_bpf_free_parms(prog); +errout_idr: + if (!oldprog) + idr_remove(&head->handle_idr, prog->handle); +errout: + tcf_exts_destroy(&prog->exts); + kfree(prog); + return ret; +} + +static int cls_bpf_dump_bpf_info(const struct cls_bpf_prog *prog, + struct sk_buff *skb) +{ + struct nlattr *nla; + + if (nla_put_u16(skb, TCA_BPF_OPS_LEN, prog->bpf_num_ops)) + return -EMSGSIZE; + + nla = nla_reserve(skb, TCA_BPF_OPS, prog->bpf_num_ops * + sizeof(struct sock_filter)); + if (nla == NULL) + return -EMSGSIZE; + + memcpy(nla_data(nla), prog->bpf_ops, nla_len(nla)); + + return 0; +} + +static int cls_bpf_dump_ebpf_info(const struct cls_bpf_prog *prog, + struct sk_buff *skb) +{ + struct nlattr *nla; + + if (prog->bpf_name && + nla_put_string(skb, TCA_BPF_NAME, prog->bpf_name)) + return -EMSGSIZE; + + if (nla_put_u32(skb, TCA_BPF_ID, prog->filter->aux->id)) + return -EMSGSIZE; + + nla = nla_reserve(skb, TCA_BPF_TAG, sizeof(prog->filter->tag)); + if (nla == NULL) + return -EMSGSIZE; + + memcpy(nla_data(nla), prog->filter->tag, nla_len(nla)); + + return 0; +} + +static int cls_bpf_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *tm, bool rtnl_held) +{ + struct cls_bpf_prog *prog = fh; + struct nlattr *nest; + u32 bpf_flags = 0; + int ret; + + if (prog == NULL) + return skb->len; + + tm->tcm_handle = prog->handle; + + cls_bpf_offload_update_stats(tp, prog); + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (prog->res.classid && + nla_put_u32(skb, TCA_BPF_CLASSID, prog->res.classid)) + goto nla_put_failure; + + if (cls_bpf_is_ebpf(prog)) + ret = cls_bpf_dump_ebpf_info(prog, skb); + else + ret = cls_bpf_dump_bpf_info(prog, skb); + if (ret) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &prog->exts) < 0) + goto nla_put_failure; + + if (prog->exts_integrated) + bpf_flags |= TCA_BPF_FLAG_ACT_DIRECT; + if (bpf_flags && nla_put_u32(skb, TCA_BPF_FLAGS, bpf_flags)) + goto nla_put_failure; + if (prog->gen_flags && + nla_put_u32(skb, TCA_BPF_FLAGS_GEN, prog->gen_flags)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &prog->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void cls_bpf_bind_class(void *fh, u32 classid, unsigned long cl, + void *q, unsigned long base) +{ + struct cls_bpf_prog *prog = fh; + + tc_cls_bind_class(classid, cl, q, &prog->res, base); +} + +static void cls_bpf_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + struct cls_bpf_prog *prog; + + 
list_for_each_entry(prog, &head->plist, link) { + if (!tc_cls_stats_dump(tp, arg, prog)) + break; + } +} + +static int cls_bpf_reoffload(struct tcf_proto *tp, bool add, flow_setup_cb_t *cb, + void *cb_priv, struct netlink_ext_ack *extack) +{ + struct cls_bpf_head *head = rtnl_dereference(tp->root); + struct tcf_block *block = tp->chain->block; + struct tc_cls_bpf_offload cls_bpf = {}; + struct cls_bpf_prog *prog; + int err; + + list_for_each_entry(prog, &head->plist, link) { + if (tc_skip_hw(prog->gen_flags)) + continue; + + tc_cls_common_offload_init(&cls_bpf.common, tp, prog->gen_flags, + extack); + cls_bpf.command = TC_CLSBPF_OFFLOAD; + cls_bpf.exts = &prog->exts; + cls_bpf.prog = add ? prog->filter : NULL; + cls_bpf.oldprog = add ? NULL : prog->filter; + cls_bpf.name = prog->bpf_name; + cls_bpf.exts_integrated = prog->exts_integrated; + + err = tc_setup_cb_reoffload(block, tp, add, cb, TC_SETUP_CLSBPF, + &cls_bpf, cb_priv, &prog->gen_flags, + &prog->in_hw_count); + if (err) + return err; + } + + return 0; +} + +static struct tcf_proto_ops cls_bpf_ops __read_mostly = { + .kind = "bpf", + .owner = THIS_MODULE, + .classify = cls_bpf_classify, + .init = cls_bpf_init, + .destroy = cls_bpf_destroy, + .get = cls_bpf_get, + .change = cls_bpf_change, + .delete = cls_bpf_delete, + .walk = cls_bpf_walk, + .reoffload = cls_bpf_reoffload, + .dump = cls_bpf_dump, + .bind_class = cls_bpf_bind_class, +}; + +static int __init cls_bpf_init_mod(void) +{ + return register_tcf_proto_ops(&cls_bpf_ops); +} + +static void __exit cls_bpf_exit_mod(void) +{ + unregister_tcf_proto_ops(&cls_bpf_ops); +} + +module_init(cls_bpf_init_mod); +module_exit(cls_bpf_exit_mod); diff --git a/net/sched/cls_cgroup.c b/net/sched/cls_cgroup.c new file mode 100644 index 000000000..ed00001b5 --- /dev/null +++ b/net/sched/cls_cgroup.c @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_cgroup.c Control Group Classifier + * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct cls_cgroup_head { + u32 handle; + struct tcf_exts exts; + struct tcf_ematch_tree ematches; + struct tcf_proto *tp; + struct rcu_work rwork; +}; + +static int cls_cgroup_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct cls_cgroup_head *head = rcu_dereference_bh(tp->root); + u32 classid = task_get_classid(skb); + + if (unlikely(!head)) + return -1; + if (!classid) + return -1; + if (!tcf_em_tree_match(skb, &head->ematches, NULL)) + return -1; + + res->classid = classid; + res->class = 0; + + return tcf_exts_exec(skb, &head->exts, res); +} + +static void *cls_cgroup_get(struct tcf_proto *tp, u32 handle) +{ + return NULL; +} + +static int cls_cgroup_init(struct tcf_proto *tp) +{ + return 0; +} + +static const struct nla_policy cgroup_policy[TCA_CGROUP_MAX + 1] = { + [TCA_CGROUP_EMATCHES] = { .type = NLA_NESTED }, +}; + +static void __cls_cgroup_destroy(struct cls_cgroup_head *head) +{ + tcf_exts_destroy(&head->exts); + tcf_em_tree_destroy(&head->ematches); + tcf_exts_put_net(&head->exts); + kfree(head); +} + +static void cls_cgroup_destroy_work(struct work_struct *work) +{ + struct cls_cgroup_head *head = container_of(to_rcu_work(work), + struct cls_cgroup_head, + rwork); + rtnl_lock(); + __cls_cgroup_destroy(head); + rtnl_unlock(); +} + +static int cls_cgroup_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 handle, struct nlattr **tca, + void **arg, u32 flags, + struct 
netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_CGROUP_MAX + 1]; + struct cls_cgroup_head *head = rtnl_dereference(tp->root); + struct cls_cgroup_head *new; + int err; + + if (!tca[TCA_OPTIONS]) + return -EINVAL; + + if (!head && !handle) + return -EINVAL; + + if (head && handle != head->handle) + return -ENOENT; + + new = kzalloc(sizeof(*head), GFP_KERNEL); + if (!new) + return -ENOBUFS; + + err = tcf_exts_init(&new->exts, net, TCA_CGROUP_ACT, TCA_CGROUP_POLICE); + if (err < 0) + goto errout; + new->handle = handle; + new->tp = tp; + err = nla_parse_nested_deprecated(tb, TCA_CGROUP_MAX, + tca[TCA_OPTIONS], cgroup_policy, + NULL); + if (err < 0) + goto errout; + + err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &new->exts, flags, + extack); + if (err < 0) + goto errout; + + err = tcf_em_tree_validate(tp, tb[TCA_CGROUP_EMATCHES], &new->ematches); + if (err < 0) + goto errout; + + rcu_assign_pointer(tp->root, new); + if (head) { + tcf_exts_get_net(&head->exts); + tcf_queue_work(&head->rwork, cls_cgroup_destroy_work); + } + return 0; +errout: + tcf_exts_destroy(&new->exts); + kfree(new); + return err; +} + +static void cls_cgroup_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct cls_cgroup_head *head = rtnl_dereference(tp->root); + + /* Head can still be NULL due to cls_cgroup_init(). */ + if (head) { + if (tcf_exts_get_net(&head->exts)) + tcf_queue_work(&head->rwork, cls_cgroup_destroy_work); + else + __cls_cgroup_destroy(head); + } +} + +static int cls_cgroup_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + return -EOPNOTSUPP; +} + +static void cls_cgroup_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct cls_cgroup_head *head = rtnl_dereference(tp->root); + + if (arg->count < arg->skip) + goto skip; + + if (!head) + return; + if (arg->fn(tp, head, arg) < 0) { + arg->stop = 1; + return; + } +skip: + arg->count++; +} + +static int cls_cgroup_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct cls_cgroup_head *head = rtnl_dereference(tp->root); + struct nlattr *nest; + + t->tcm_handle = head->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &head->exts) < 0 || + tcf_em_tree_dump(skb, &head->ematches, TCA_CGROUP_EMATCHES) < 0) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &head->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static struct tcf_proto_ops cls_cgroup_ops __read_mostly = { + .kind = "cgroup", + .init = cls_cgroup_init, + .change = cls_cgroup_change, + .classify = cls_cgroup_classify, + .destroy = cls_cgroup_destroy, + .get = cls_cgroup_get, + .delete = cls_cgroup_delete, + .walk = cls_cgroup_walk, + .dump = cls_cgroup_dump, + .owner = THIS_MODULE, +}; + +static int __init init_cgroup_cls(void) +{ + return register_tcf_proto_ops(&cls_cgroup_ops); +} + +static void __exit exit_cgroup_cls(void) +{ + unregister_tcf_proto_ops(&cls_cgroup_ops); +} + +module_init(init_cgroup_cls); +module_exit(exit_cgroup_cls); +MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_flow.c b/net/sched/cls_flow.c new file mode 100644 index 000000000..014cd3de7 --- /dev/null +++ b/net/sched/cls_flow.c @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_flow.c Generic flow 
classifier + * + * Copyright (c) 2007, 2008 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +struct flow_head { + struct list_head filters; + struct rcu_head rcu; +}; + +struct flow_filter { + struct list_head list; + struct tcf_exts exts; + struct tcf_ematch_tree ematches; + struct tcf_proto *tp; + struct timer_list perturb_timer; + u32 perturb_period; + u32 handle; + + u32 nkeys; + u32 keymask; + u32 mode; + u32 mask; + u32 xor; + u32 rshift; + u32 addend; + u32 divisor; + u32 baseclass; + u32 hashrnd; + struct rcu_work rwork; +}; + +static inline u32 addr_fold(void *addr) +{ + unsigned long a = (unsigned long)addr; + + return (a & 0xFFFFFFFF) ^ (BITS_PER_LONG > 32 ? a >> 32 : 0); +} + +static u32 flow_get_src(const struct sk_buff *skb, const struct flow_keys *flow) +{ + __be32 src = flow_get_u32_src(flow); + + if (src) + return ntohl(src); + + return addr_fold(skb->sk); +} + +static u32 flow_get_dst(const struct sk_buff *skb, const struct flow_keys *flow) +{ + __be32 dst = flow_get_u32_dst(flow); + + if (dst) + return ntohl(dst); + + return addr_fold(skb_dst(skb)) ^ (__force u16)skb_protocol(skb, true); +} + +static u32 flow_get_proto(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + return flow->basic.ip_proto; +} + +static u32 flow_get_proto_src(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + if (flow->ports.ports) + return ntohs(flow->ports.src); + + return addr_fold(skb->sk); +} + +static u32 flow_get_proto_dst(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + if (flow->ports.ports) + return ntohs(flow->ports.dst); + + return addr_fold(skb_dst(skb)) ^ (__force u16)skb_protocol(skb, true); +} + +static u32 flow_get_iif(const struct sk_buff *skb) +{ + return skb->skb_iif; +} + +static u32 flow_get_priority(const struct sk_buff *skb) +{ + return skb->priority; +} + +static u32 flow_get_mark(const struct sk_buff *skb) +{ + return skb->mark; +} + +static u32 flow_get_nfct(const struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + return addr_fold(skb_nfct(skb)); +#else + return 0; +#endif +} + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#define CTTUPLE(skb, member) \ +({ \ + enum ip_conntrack_info ctinfo; \ + const struct nf_conn *ct = nf_ct_get(skb, &ctinfo); \ + if (ct == NULL) \ + goto fallback; \ + ct->tuplehash[CTINFO2DIR(ctinfo)].tuple.member; \ +}) +#else +#define CTTUPLE(skb, member) \ +({ \ + goto fallback; \ + 0; \ +}) +#endif + +static u32 flow_get_nfct_src(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + return ntohl(CTTUPLE(skb, src.u3.ip)); + case htons(ETH_P_IPV6): + return ntohl(CTTUPLE(skb, src.u3.ip6[3])); + } +fallback: + return flow_get_src(skb, flow); +} + +static u32 flow_get_nfct_dst(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + return ntohl(CTTUPLE(skb, dst.u3.ip)); + case htons(ETH_P_IPV6): + return ntohl(CTTUPLE(skb, dst.u3.ip6[3])); + } +fallback: + return flow_get_dst(skb, flow); +} + +static u32 flow_get_nfct_proto_src(const struct sk_buff *skb, + const struct flow_keys *flow) +{ + return ntohs(CTTUPLE(skb, src.u.all)); +fallback: + return flow_get_proto_src(skb, flow); +} + +static u32 flow_get_nfct_proto_dst(const struct sk_buff *skb, + const 
struct flow_keys *flow) +{ + return ntohs(CTTUPLE(skb, dst.u.all)); +fallback: + return flow_get_proto_dst(skb, flow); +} + +static u32 flow_get_rtclassid(const struct sk_buff *skb) +{ +#ifdef CONFIG_IP_ROUTE_CLASSID + if (skb_dst(skb)) + return skb_dst(skb)->tclassid; +#endif + return 0; +} + +static u32 flow_get_skuid(const struct sk_buff *skb) +{ + struct sock *sk = skb_to_full_sk(skb); + + if (sk && sk->sk_socket && sk->sk_socket->file) { + kuid_t skuid = sk->sk_socket->file->f_cred->fsuid; + + return from_kuid(&init_user_ns, skuid); + } + return 0; +} + +static u32 flow_get_skgid(const struct sk_buff *skb) +{ + struct sock *sk = skb_to_full_sk(skb); + + if (sk && sk->sk_socket && sk->sk_socket->file) { + kgid_t skgid = sk->sk_socket->file->f_cred->fsgid; + + return from_kgid(&init_user_ns, skgid); + } + return 0; +} + +static u32 flow_get_vlan_tag(const struct sk_buff *skb) +{ + u16 tag; + + if (vlan_get_tag(skb, &tag) < 0) + return 0; + return tag & VLAN_VID_MASK; +} + +static u32 flow_get_rxhash(struct sk_buff *skb) +{ + return skb_get_hash(skb); +} + +static u32 flow_key_get(struct sk_buff *skb, int key, struct flow_keys *flow) +{ + switch (key) { + case FLOW_KEY_SRC: + return flow_get_src(skb, flow); + case FLOW_KEY_DST: + return flow_get_dst(skb, flow); + case FLOW_KEY_PROTO: + return flow_get_proto(skb, flow); + case FLOW_KEY_PROTO_SRC: + return flow_get_proto_src(skb, flow); + case FLOW_KEY_PROTO_DST: + return flow_get_proto_dst(skb, flow); + case FLOW_KEY_IIF: + return flow_get_iif(skb); + case FLOW_KEY_PRIORITY: + return flow_get_priority(skb); + case FLOW_KEY_MARK: + return flow_get_mark(skb); + case FLOW_KEY_NFCT: + return flow_get_nfct(skb); + case FLOW_KEY_NFCT_SRC: + return flow_get_nfct_src(skb, flow); + case FLOW_KEY_NFCT_DST: + return flow_get_nfct_dst(skb, flow); + case FLOW_KEY_NFCT_PROTO_SRC: + return flow_get_nfct_proto_src(skb, flow); + case FLOW_KEY_NFCT_PROTO_DST: + return flow_get_nfct_proto_dst(skb, flow); + case FLOW_KEY_RTCLASSID: + return flow_get_rtclassid(skb); + case FLOW_KEY_SKUID: + return flow_get_skuid(skb); + case FLOW_KEY_SKGID: + return flow_get_skgid(skb); + case FLOW_KEY_VLAN_TAG: + return flow_get_vlan_tag(skb); + case FLOW_KEY_RXHASH: + return flow_get_rxhash(skb); + default: + WARN_ON(1); + return 0; + } +} + +#define FLOW_KEYS_NEEDED ((1 << FLOW_KEY_SRC) | \ + (1 << FLOW_KEY_DST) | \ + (1 << FLOW_KEY_PROTO) | \ + (1 << FLOW_KEY_PROTO_SRC) | \ + (1 << FLOW_KEY_PROTO_DST) | \ + (1 << FLOW_KEY_NFCT_SRC) | \ + (1 << FLOW_KEY_NFCT_DST) | \ + (1 << FLOW_KEY_NFCT_PROTO_SRC) | \ + (1 << FLOW_KEY_NFCT_PROTO_DST)) + +static int flow_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct flow_head *head = rcu_dereference_bh(tp->root); + struct flow_filter *f; + u32 keymask; + u32 classid; + unsigned int n, key; + int r; + + list_for_each_entry_rcu(f, &head->filters, list) { + u32 keys[FLOW_KEY_MAX + 1]; + struct flow_keys flow_keys; + + if (!tcf_em_tree_match(skb, &f->ematches, NULL)) + continue; + + keymask = f->keymask; + if (keymask & FLOW_KEYS_NEEDED) + skb_flow_dissect_flow_keys(skb, &flow_keys, 0); + + for (n = 0; n < f->nkeys; n++) { + key = ffs(keymask) - 1; + keymask &= ~(1 << key); + keys[n] = flow_key_get(skb, key, &flow_keys); + } + + if (f->mode == FLOW_MODE_HASH) + classid = jhash2(keys, f->nkeys, f->hashrnd); + else { + classid = keys[0]; + classid = (classid & f->mask) ^ f->xor; + classid = (classid >> f->rshift) + f->addend; + } + + if (f->divisor) + classid %= f->divisor; + + res->class = 0; + 
res->classid = TC_H_MAKE(f->baseclass, f->baseclass + classid); + + r = tcf_exts_exec(skb, &f->exts, res); + if (r < 0) + continue; + return r; + } + return -1; +} + +static void flow_perturbation(struct timer_list *t) +{ + struct flow_filter *f = from_timer(f, t, perturb_timer); + + get_random_bytes(&f->hashrnd, 4); + if (f->perturb_period) + mod_timer(&f->perturb_timer, jiffies + f->perturb_period); +} + +static const struct nla_policy flow_policy[TCA_FLOW_MAX + 1] = { + [TCA_FLOW_KEYS] = { .type = NLA_U32 }, + [TCA_FLOW_MODE] = { .type = NLA_U32 }, + [TCA_FLOW_BASECLASS] = { .type = NLA_U32 }, + [TCA_FLOW_RSHIFT] = { .type = NLA_U32 }, + [TCA_FLOW_ADDEND] = { .type = NLA_U32 }, + [TCA_FLOW_MASK] = { .type = NLA_U32 }, + [TCA_FLOW_XOR] = { .type = NLA_U32 }, + [TCA_FLOW_DIVISOR] = { .type = NLA_U32 }, + [TCA_FLOW_ACT] = { .type = NLA_NESTED }, + [TCA_FLOW_POLICE] = { .type = NLA_NESTED }, + [TCA_FLOW_EMATCHES] = { .type = NLA_NESTED }, + [TCA_FLOW_PERTURB] = { .type = NLA_U32 }, +}; + +static void __flow_destroy_filter(struct flow_filter *f) +{ + del_timer_sync(&f->perturb_timer); + tcf_exts_destroy(&f->exts); + tcf_em_tree_destroy(&f->ematches); + tcf_exts_put_net(&f->exts); + kfree(f); +} + +static void flow_destroy_filter_work(struct work_struct *work) +{ + struct flow_filter *f = container_of(to_rcu_work(work), + struct flow_filter, + rwork); + rtnl_lock(); + __flow_destroy_filter(f); + rtnl_unlock(); +} + +static int flow_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 handle, struct nlattr **tca, + void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct flow_head *head = rtnl_dereference(tp->root); + struct flow_filter *fold, *fnew; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_FLOW_MAX + 1]; + unsigned int nkeys = 0; + unsigned int perturb_period = 0; + u32 baseclass = 0; + u32 keymask = 0; + u32 mode; + int err; + + if (opt == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_FLOW_MAX, opt, flow_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_FLOW_BASECLASS]) { + baseclass = nla_get_u32(tb[TCA_FLOW_BASECLASS]); + if (TC_H_MIN(baseclass) == 0) + return -EINVAL; + } + + if (tb[TCA_FLOW_KEYS]) { + keymask = nla_get_u32(tb[TCA_FLOW_KEYS]); + + nkeys = hweight32(keymask); + if (nkeys == 0) + return -EINVAL; + + if (fls(keymask) - 1 > FLOW_KEY_MAX) + return -EOPNOTSUPP; + + if ((keymask & (FLOW_KEY_SKUID|FLOW_KEY_SKGID)) && + sk_user_ns(NETLINK_CB(in_skb).sk) != &init_user_ns) + return -EOPNOTSUPP; + } + + fnew = kzalloc(sizeof(*fnew), GFP_KERNEL); + if (!fnew) + return -ENOBUFS; + + err = tcf_em_tree_validate(tp, tb[TCA_FLOW_EMATCHES], &fnew->ematches); + if (err < 0) + goto err1; + + err = tcf_exts_init(&fnew->exts, net, TCA_FLOW_ACT, TCA_FLOW_POLICE); + if (err < 0) + goto err2; + + err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &fnew->exts, flags, + extack); + if (err < 0) + goto err2; + + fold = *arg; + if (fold) { + err = -EINVAL; + if (fold->handle != handle && handle) + goto err2; + + /* Copy fold into fnew */ + fnew->tp = fold->tp; + fnew->handle = fold->handle; + fnew->nkeys = fold->nkeys; + fnew->keymask = fold->keymask; + fnew->mode = fold->mode; + fnew->mask = fold->mask; + fnew->xor = fold->xor; + fnew->rshift = fold->rshift; + fnew->addend = fold->addend; + fnew->divisor = fold->divisor; + fnew->baseclass = fold->baseclass; + fnew->hashrnd = fold->hashrnd; + + mode = fold->mode; + if (tb[TCA_FLOW_MODE]) + mode = nla_get_u32(tb[TCA_FLOW_MODE]); + if 
(mode != FLOW_MODE_HASH && nkeys > 1) + goto err2; + + if (mode == FLOW_MODE_HASH) + perturb_period = fold->perturb_period; + if (tb[TCA_FLOW_PERTURB]) { + if (mode != FLOW_MODE_HASH) + goto err2; + perturb_period = nla_get_u32(tb[TCA_FLOW_PERTURB]) * HZ; + } + } else { + err = -EINVAL; + if (!handle) + goto err2; + if (!tb[TCA_FLOW_KEYS]) + goto err2; + + mode = FLOW_MODE_MAP; + if (tb[TCA_FLOW_MODE]) + mode = nla_get_u32(tb[TCA_FLOW_MODE]); + if (mode != FLOW_MODE_HASH && nkeys > 1) + goto err2; + + if (tb[TCA_FLOW_PERTURB]) { + if (mode != FLOW_MODE_HASH) + goto err2; + perturb_period = nla_get_u32(tb[TCA_FLOW_PERTURB]) * HZ; + } + + if (TC_H_MAJ(baseclass) == 0) { + struct Qdisc *q = tcf_block_q(tp->chain->block); + + baseclass = TC_H_MAKE(q->handle, baseclass); + } + if (TC_H_MIN(baseclass) == 0) + baseclass = TC_H_MAKE(baseclass, 1); + + fnew->handle = handle; + fnew->mask = ~0U; + fnew->tp = tp; + get_random_bytes(&fnew->hashrnd, 4); + } + + timer_setup(&fnew->perturb_timer, flow_perturbation, TIMER_DEFERRABLE); + + tcf_block_netif_keep_dst(tp->chain->block); + + if (tb[TCA_FLOW_KEYS]) { + fnew->keymask = keymask; + fnew->nkeys = nkeys; + } + + fnew->mode = mode; + + if (tb[TCA_FLOW_MASK]) + fnew->mask = nla_get_u32(tb[TCA_FLOW_MASK]); + if (tb[TCA_FLOW_XOR]) + fnew->xor = nla_get_u32(tb[TCA_FLOW_XOR]); + if (tb[TCA_FLOW_RSHIFT]) + fnew->rshift = nla_get_u32(tb[TCA_FLOW_RSHIFT]); + if (tb[TCA_FLOW_ADDEND]) + fnew->addend = nla_get_u32(tb[TCA_FLOW_ADDEND]); + + if (tb[TCA_FLOW_DIVISOR]) + fnew->divisor = nla_get_u32(tb[TCA_FLOW_DIVISOR]); + if (baseclass) + fnew->baseclass = baseclass; + + fnew->perturb_period = perturb_period; + if (perturb_period) + mod_timer(&fnew->perturb_timer, jiffies + perturb_period); + + if (!*arg) + list_add_tail_rcu(&fnew->list, &head->filters); + else + list_replace_rcu(&fold->list, &fnew->list); + + *arg = fnew; + + if (fold) { + tcf_exts_get_net(&fold->exts); + tcf_queue_work(&fold->rwork, flow_destroy_filter_work); + } + return 0; + +err2: + tcf_exts_destroy(&fnew->exts); + tcf_em_tree_destroy(&fnew->ematches); +err1: + kfree(fnew); + return err; +} + +static int flow_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct flow_head *head = rtnl_dereference(tp->root); + struct flow_filter *f = arg; + + list_del_rcu(&f->list); + tcf_exts_get_net(&f->exts); + tcf_queue_work(&f->rwork, flow_destroy_filter_work); + *last = list_empty(&head->filters); + return 0; +} + +static int flow_init(struct tcf_proto *tp) +{ + struct flow_head *head; + + head = kzalloc(sizeof(*head), GFP_KERNEL); + if (head == NULL) + return -ENOBUFS; + INIT_LIST_HEAD(&head->filters); + rcu_assign_pointer(tp->root, head); + return 0; +} + +static void flow_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct flow_head *head = rtnl_dereference(tp->root); + struct flow_filter *f, *next; + + list_for_each_entry_safe(f, next, &head->filters, list) { + list_del_rcu(&f->list); + if (tcf_exts_get_net(&f->exts)) + tcf_queue_work(&f->rwork, flow_destroy_filter_work); + else + __flow_destroy_filter(f); + } + kfree_rcu(head, rcu); +} + +static void *flow_get(struct tcf_proto *tp, u32 handle) +{ + struct flow_head *head = rtnl_dereference(tp->root); + struct flow_filter *f; + + list_for_each_entry(f, &head->filters, list) + if (f->handle == handle) + return f; + return NULL; +} + +static int flow_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool 
rtnl_held) +{ + struct flow_filter *f = fh; + struct nlattr *nest; + + if (f == NULL) + return skb->len; + + t->tcm_handle = f->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_FLOW_KEYS, f->keymask) || + nla_put_u32(skb, TCA_FLOW_MODE, f->mode)) + goto nla_put_failure; + + if (f->mask != ~0 || f->xor != 0) { + if (nla_put_u32(skb, TCA_FLOW_MASK, f->mask) || + nla_put_u32(skb, TCA_FLOW_XOR, f->xor)) + goto nla_put_failure; + } + if (f->rshift && + nla_put_u32(skb, TCA_FLOW_RSHIFT, f->rshift)) + goto nla_put_failure; + if (f->addend && + nla_put_u32(skb, TCA_FLOW_ADDEND, f->addend)) + goto nla_put_failure; + + if (f->divisor && + nla_put_u32(skb, TCA_FLOW_DIVISOR, f->divisor)) + goto nla_put_failure; + if (f->baseclass && + nla_put_u32(skb, TCA_FLOW_BASECLASS, f->baseclass)) + goto nla_put_failure; + + if (f->perturb_period && + nla_put_u32(skb, TCA_FLOW_PERTURB, f->perturb_period / HZ)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &f->exts) < 0) + goto nla_put_failure; +#ifdef CONFIG_NET_EMATCH + if (f->ematches.hdr.nmatches && + tcf_em_tree_dump(skb, &f->ematches, TCA_FLOW_EMATCHES) < 0) + goto nla_put_failure; +#endif + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &f->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void flow_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct flow_head *head = rtnl_dereference(tp->root); + struct flow_filter *f; + + list_for_each_entry(f, &head->filters, list) { + if (!tc_cls_stats_dump(tp, arg, f)) + break; + } +} + +static struct tcf_proto_ops cls_flow_ops __read_mostly = { + .kind = "flow", + .classify = flow_classify, + .init = flow_init, + .destroy = flow_destroy, + .change = flow_change, + .delete = flow_delete, + .get = flow_get, + .dump = flow_dump, + .walk = flow_walk, + .owner = THIS_MODULE, +}; + +static int __init cls_flow_init(void) +{ + return register_tcf_proto_ops(&cls_flow_ops); +} + +static void __exit cls_flow_exit(void) +{ + unregister_tcf_proto_ops(&cls_flow_ops); +} + +module_init(cls_flow_init); +module_exit(cls_flow_exit); + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Patrick McHardy "); +MODULE_DESCRIPTION("TC flow classifier"); diff --git a/net/sched/cls_flower.c b/net/sched/cls_flower.c new file mode 100644 index 000000000..10e6ec0f9 --- /dev/null +++ b/net/sched/cls_flower.c @@ -0,0 +1,3474 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_flower.c Flower classifier + * + * Copyright (c) 2015 Jiri Pirko + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#define TCA_FLOWER_KEY_CT_FLAGS_MAX \ + ((__TCA_FLOWER_KEY_CT_FLAGS_MAX - 1) << 1) +#define TCA_FLOWER_KEY_CT_FLAGS_MASK \ + (TCA_FLOWER_KEY_CT_FLAGS_MAX - 1) + +struct fl_flow_key { + struct flow_dissector_key_meta meta; + struct flow_dissector_key_control control; + struct flow_dissector_key_control enc_control; + struct flow_dissector_key_basic basic; + struct flow_dissector_key_eth_addrs eth; + struct flow_dissector_key_vlan vlan; + struct flow_dissector_key_vlan cvlan; + union { + struct flow_dissector_key_ipv4_addrs ipv4; + struct flow_dissector_key_ipv6_addrs ipv6; + }; + struct flow_dissector_key_ports tp; + struct flow_dissector_key_icmp icmp; + 
struct flow_dissector_key_arp arp; + struct flow_dissector_key_keyid enc_key_id; + union { + struct flow_dissector_key_ipv4_addrs enc_ipv4; + struct flow_dissector_key_ipv6_addrs enc_ipv6; + }; + struct flow_dissector_key_ports enc_tp; + struct flow_dissector_key_mpls mpls; + struct flow_dissector_key_tcp tcp; + struct flow_dissector_key_ip ip; + struct flow_dissector_key_ip enc_ip; + struct flow_dissector_key_enc_opts enc_opts; + struct flow_dissector_key_ports_range tp_range; + struct flow_dissector_key_ct ct; + struct flow_dissector_key_hash hash; + struct flow_dissector_key_num_of_vlans num_of_vlans; + struct flow_dissector_key_pppoe pppoe; + struct flow_dissector_key_l2tpv3 l2tpv3; +} __aligned(BITS_PER_LONG / 8); /* Ensure that we can do comparisons as longs. */ + +struct fl_flow_mask_range { + unsigned short int start; + unsigned short int end; +}; + +struct fl_flow_mask { + struct fl_flow_key key; + struct fl_flow_mask_range range; + u32 flags; + struct rhash_head ht_node; + struct rhashtable ht; + struct rhashtable_params filter_ht_params; + struct flow_dissector dissector; + struct list_head filters; + struct rcu_work rwork; + struct list_head list; + refcount_t refcnt; +}; + +struct fl_flow_tmplt { + struct fl_flow_key dummy_key; + struct fl_flow_key mask; + struct flow_dissector dissector; + struct tcf_chain *chain; +}; + +struct cls_fl_head { + struct rhashtable ht; + spinlock_t masks_lock; /* Protect masks list */ + struct list_head masks; + struct list_head hw_filters; + struct rcu_work rwork; + struct idr handle_idr; +}; + +struct cls_fl_filter { + struct fl_flow_mask *mask; + struct rhash_head ht_node; + struct fl_flow_key mkey; + struct tcf_exts exts; + struct tcf_result res; + struct fl_flow_key key; + struct list_head list; + struct list_head hw_list; + u32 handle; + u32 flags; + u32 in_hw_count; + struct rcu_work rwork; + struct net_device *hw_dev; + /* Flower classifier is unlocked, which means that its reference counter + * can be changed concurrently without any kind of external + * synchronization. Use atomic reference counter to be concurrency-safe. 
+ */ + refcount_t refcnt; + bool deleted; +}; + +static const struct rhashtable_params mask_ht_params = { + .key_offset = offsetof(struct fl_flow_mask, key), + .key_len = sizeof(struct fl_flow_key), + .head_offset = offsetof(struct fl_flow_mask, ht_node), + .automatic_shrinking = true, +}; + +static unsigned short int fl_mask_range(const struct fl_flow_mask *mask) +{ + return mask->range.end - mask->range.start; +} + +static void fl_mask_update_range(struct fl_flow_mask *mask) +{ + const u8 *bytes = (const u8 *) &mask->key; + size_t size = sizeof(mask->key); + size_t i, first = 0, last; + + for (i = 0; i < size; i++) { + if (bytes[i]) { + first = i; + break; + } + } + last = first; + for (i = size - 1; i != first; i--) { + if (bytes[i]) { + last = i; + break; + } + } + mask->range.start = rounddown(first, sizeof(long)); + mask->range.end = roundup(last + 1, sizeof(long)); +} + +static void *fl_key_get_start(struct fl_flow_key *key, + const struct fl_flow_mask *mask) +{ + return (u8 *) key + mask->range.start; +} + +static void fl_set_masked_key(struct fl_flow_key *mkey, struct fl_flow_key *key, + struct fl_flow_mask *mask) +{ + const long *lkey = fl_key_get_start(key, mask); + const long *lmask = fl_key_get_start(&mask->key, mask); + long *lmkey = fl_key_get_start(mkey, mask); + int i; + + for (i = 0; i < fl_mask_range(mask); i += sizeof(long)) + *lmkey++ = *lkey++ & *lmask++; +} + +static bool fl_mask_fits_tmplt(struct fl_flow_tmplt *tmplt, + struct fl_flow_mask *mask) +{ + const long *lmask = fl_key_get_start(&mask->key, mask); + const long *ltmplt; + int i; + + if (!tmplt) + return true; + ltmplt = fl_key_get_start(&tmplt->mask, mask); + for (i = 0; i < fl_mask_range(mask); i += sizeof(long)) { + if (~*ltmplt++ & *lmask++) + return false; + } + return true; +} + +static void fl_clear_masked_range(struct fl_flow_key *key, + struct fl_flow_mask *mask) +{ + memset(fl_key_get_start(key, mask), 0, fl_mask_range(mask)); +} + +static bool fl_range_port_dst_cmp(struct cls_fl_filter *filter, + struct fl_flow_key *key, + struct fl_flow_key *mkey) +{ + u16 min_mask, max_mask, min_val, max_val; + + min_mask = ntohs(filter->mask->key.tp_range.tp_min.dst); + max_mask = ntohs(filter->mask->key.tp_range.tp_max.dst); + min_val = ntohs(filter->key.tp_range.tp_min.dst); + max_val = ntohs(filter->key.tp_range.tp_max.dst); + + if (min_mask && max_mask) { + if (ntohs(key->tp_range.tp.dst) < min_val || + ntohs(key->tp_range.tp.dst) > max_val) + return false; + + /* skb does not have min and max values */ + mkey->tp_range.tp_min.dst = filter->mkey.tp_range.tp_min.dst; + mkey->tp_range.tp_max.dst = filter->mkey.tp_range.tp_max.dst; + } + return true; +} + +static bool fl_range_port_src_cmp(struct cls_fl_filter *filter, + struct fl_flow_key *key, + struct fl_flow_key *mkey) +{ + u16 min_mask, max_mask, min_val, max_val; + + min_mask = ntohs(filter->mask->key.tp_range.tp_min.src); + max_mask = ntohs(filter->mask->key.tp_range.tp_max.src); + min_val = ntohs(filter->key.tp_range.tp_min.src); + max_val = ntohs(filter->key.tp_range.tp_max.src); + + if (min_mask && max_mask) { + if (ntohs(key->tp_range.tp.src) < min_val || + ntohs(key->tp_range.tp.src) > max_val) + return false; + + /* skb does not have min and max values */ + mkey->tp_range.tp_min.src = filter->mkey.tp_range.tp_min.src; + mkey->tp_range.tp_max.src = filter->mkey.tp_range.tp_max.src; + } + return true; +} + +static struct cls_fl_filter *__fl_lookup(struct fl_flow_mask *mask, + struct fl_flow_key *mkey) +{ + return rhashtable_lookup_fast(&mask->ht, 
fl_key_get_start(mkey, mask), + mask->filter_ht_params); +} + +static struct cls_fl_filter *fl_lookup_range(struct fl_flow_mask *mask, + struct fl_flow_key *mkey, + struct fl_flow_key *key) +{ + struct cls_fl_filter *filter, *f; + + list_for_each_entry_rcu(filter, &mask->filters, list) { + if (!fl_range_port_dst_cmp(filter, key, mkey)) + continue; + + if (!fl_range_port_src_cmp(filter, key, mkey)) + continue; + + f = __fl_lookup(mask, mkey); + if (f) + return f; + } + return NULL; +} + +static noinline_for_stack +struct cls_fl_filter *fl_mask_lookup(struct fl_flow_mask *mask, struct fl_flow_key *key) +{ + struct fl_flow_key mkey; + + fl_set_masked_key(&mkey, key, mask); + if ((mask->flags & TCA_FLOWER_MASK_FLAGS_RANGE)) + return fl_lookup_range(mask, &mkey, key); + + return __fl_lookup(mask, &mkey); +} + +static u16 fl_ct_info_to_flower_map[] = { + [IP_CT_ESTABLISHED] = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_ESTABLISHED, + [IP_CT_RELATED] = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_RELATED, + [IP_CT_ESTABLISHED_REPLY] = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_ESTABLISHED | + TCA_FLOWER_KEY_CT_FLAGS_REPLY, + [IP_CT_RELATED_REPLY] = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_RELATED | + TCA_FLOWER_KEY_CT_FLAGS_REPLY, + [IP_CT_NEW] = TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_NEW, +}; + +static int fl_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct cls_fl_head *head = rcu_dereference_bh(tp->root); + bool post_ct = tc_skb_cb(skb)->post_ct; + u16 zone = tc_skb_cb(skb)->zone; + struct fl_flow_key skb_key; + struct fl_flow_mask *mask; + struct cls_fl_filter *f; + + list_for_each_entry_rcu(mask, &head->masks, list) { + flow_dissector_init_keys(&skb_key.control, &skb_key.basic); + fl_clear_masked_range(&skb_key, mask); + + skb_flow_dissect_meta(skb, &mask->dissector, &skb_key); + /* skb_flow_dissect() does not set n_proto in case an unknown + * protocol, so do it rather here. 
+ */ + skb_key.basic.n_proto = skb_protocol(skb, false); + skb_flow_dissect_tunnel_info(skb, &mask->dissector, &skb_key); + skb_flow_dissect_ct(skb, &mask->dissector, &skb_key, + fl_ct_info_to_flower_map, + ARRAY_SIZE(fl_ct_info_to_flower_map), + post_ct, zone); + skb_flow_dissect_hash(skb, &mask->dissector, &skb_key); + skb_flow_dissect(skb, &mask->dissector, &skb_key, + FLOW_DISSECTOR_F_STOP_BEFORE_ENCAP); + + f = fl_mask_lookup(mask, &skb_key); + if (f && !tc_skip_sw(f->flags)) { + *res = f->res; + return tcf_exts_exec(skb, &f->exts, res); + } + } + return -1; +} + +static int fl_init(struct tcf_proto *tp) +{ + struct cls_fl_head *head; + + head = kzalloc(sizeof(*head), GFP_KERNEL); + if (!head) + return -ENOBUFS; + + spin_lock_init(&head->masks_lock); + INIT_LIST_HEAD_RCU(&head->masks); + INIT_LIST_HEAD(&head->hw_filters); + rcu_assign_pointer(tp->root, head); + idr_init(&head->handle_idr); + + return rhashtable_init(&head->ht, &mask_ht_params); +} + +static void fl_mask_free(struct fl_flow_mask *mask, bool mask_init_done) +{ + /* temporary masks don't have their filters list and ht initialized */ + if (mask_init_done) { + WARN_ON(!list_empty(&mask->filters)); + rhashtable_destroy(&mask->ht); + } + kfree(mask); +} + +static void fl_mask_free_work(struct work_struct *work) +{ + struct fl_flow_mask *mask = container_of(to_rcu_work(work), + struct fl_flow_mask, rwork); + + fl_mask_free(mask, true); +} + +static void fl_uninit_mask_free_work(struct work_struct *work) +{ + struct fl_flow_mask *mask = container_of(to_rcu_work(work), + struct fl_flow_mask, rwork); + + fl_mask_free(mask, false); +} + +static bool fl_mask_put(struct cls_fl_head *head, struct fl_flow_mask *mask) +{ + if (!refcount_dec_and_test(&mask->refcnt)) + return false; + + rhashtable_remove_fast(&head->ht, &mask->ht_node, mask_ht_params); + + spin_lock(&head->masks_lock); + list_del_rcu(&mask->list); + spin_unlock(&head->masks_lock); + + tcf_queue_work(&mask->rwork, fl_mask_free_work); + + return true; +} + +static struct cls_fl_head *fl_head_dereference(struct tcf_proto *tp) +{ + /* Flower classifier only changes root pointer during init and destroy. + * Users must obtain reference to tcf_proto instance before calling its + * API, so tp->root pointer is protected from concurrent call to + * fl_destroy() by reference counting. 
+ */ + return rcu_dereference_raw(tp->root); +} + +static void __fl_destroy_filter(struct cls_fl_filter *f) +{ + tcf_exts_destroy(&f->exts); + tcf_exts_put_net(&f->exts); + kfree(f); +} + +static void fl_destroy_filter_work(struct work_struct *work) +{ + struct cls_fl_filter *f = container_of(to_rcu_work(work), + struct cls_fl_filter, rwork); + + __fl_destroy_filter(f); +} + +static void fl_hw_destroy_filter(struct tcf_proto *tp, struct cls_fl_filter *f, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct flow_cls_offload cls_flower = {}; + + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack); + cls_flower.command = FLOW_CLS_DESTROY; + cls_flower.cookie = (unsigned long) f; + + tc_setup_cb_destroy(block, tp, TC_SETUP_CLSFLOWER, &cls_flower, false, + &f->flags, &f->in_hw_count, rtnl_held); + +} + +static int fl_hw_replace_filter(struct tcf_proto *tp, + struct cls_fl_filter *f, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct flow_cls_offload cls_flower = {}; + bool skip_sw = tc_skip_sw(f->flags); + int err = 0; + + cls_flower.rule = flow_rule_alloc(tcf_exts_num_actions(&f->exts)); + if (!cls_flower.rule) + return -ENOMEM; + + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack); + cls_flower.command = FLOW_CLS_REPLACE; + cls_flower.cookie = (unsigned long) f; + cls_flower.rule->match.dissector = &f->mask->dissector; + cls_flower.rule->match.mask = &f->mask->key; + cls_flower.rule->match.key = &f->mkey; + cls_flower.classid = f->res.classid; + + err = tc_setup_offload_action(&cls_flower.rule->action, &f->exts, + cls_flower.common.extack); + if (err) { + kfree(cls_flower.rule); + + return skip_sw ? err : 0; + } + + err = tc_setup_cb_add(block, tp, TC_SETUP_CLSFLOWER, &cls_flower, + skip_sw, &f->flags, &f->in_hw_count, rtnl_held); + tc_cleanup_offload_action(&cls_flower.rule->action); + kfree(cls_flower.rule); + + if (err) { + fl_hw_destroy_filter(tp, f, rtnl_held, NULL); + return err; + } + + if (skip_sw && !(f->flags & TCA_CLS_FLAGS_IN_HW)) + return -EINVAL; + + return 0; +} + +static void fl_hw_update_stats(struct tcf_proto *tp, struct cls_fl_filter *f, + bool rtnl_held) +{ + struct tcf_block *block = tp->chain->block; + struct flow_cls_offload cls_flower = {}; + + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, NULL); + cls_flower.command = FLOW_CLS_STATS; + cls_flower.cookie = (unsigned long) f; + cls_flower.classid = f->res.classid; + + tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false, + rtnl_held); + + tcf_exts_hw_stats_update(&f->exts, cls_flower.stats.bytes, + cls_flower.stats.pkts, + cls_flower.stats.drops, + cls_flower.stats.lastused, + cls_flower.stats.used_hw_stats, + cls_flower.stats.used_hw_stats_valid); +} + +static void __fl_put(struct cls_fl_filter *f) +{ + if (!refcount_dec_and_test(&f->refcnt)) + return; + + if (tcf_exts_get_net(&f->exts)) + tcf_queue_work(&f->rwork, fl_destroy_filter_work); + else + __fl_destroy_filter(f); +} + +static struct cls_fl_filter *__fl_get(struct cls_fl_head *head, u32 handle) +{ + struct cls_fl_filter *f; + + rcu_read_lock(); + f = idr_find(&head->handle_idr, handle); + if (f && !refcount_inc_not_zero(&f->refcnt)) + f = NULL; + rcu_read_unlock(); + + return f; +} + +static int __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f, + bool *last, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + + *last = false; 
+ + spin_lock(&tp->lock); + if (f->deleted) { + spin_unlock(&tp->lock); + return -ENOENT; + } + + f->deleted = true; + rhashtable_remove_fast(&f->mask->ht, &f->ht_node, + f->mask->filter_ht_params); + idr_remove(&head->handle_idr, f->handle); + list_del_rcu(&f->list); + spin_unlock(&tp->lock); + + *last = fl_mask_put(head, f->mask); + if (!tc_skip_hw(f->flags)) + fl_hw_destroy_filter(tp, f, rtnl_held, extack); + tcf_unbind_filter(tp, &f->res); + __fl_put(f); + + return 0; +} + +static void fl_destroy_sleepable(struct work_struct *work) +{ + struct cls_fl_head *head = container_of(to_rcu_work(work), + struct cls_fl_head, + rwork); + + rhashtable_destroy(&head->ht); + kfree(head); + module_put(THIS_MODULE); +} + +static void fl_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + struct fl_flow_mask *mask, *next_mask; + struct cls_fl_filter *f, *next; + bool last; + + list_for_each_entry_safe(mask, next_mask, &head->masks, list) { + list_for_each_entry_safe(f, next, &mask->filters, list) { + __fl_delete(tp, f, &last, rtnl_held, extack); + if (last) + break; + } + } + idr_destroy(&head->handle_idr); + + __module_get(THIS_MODULE); + tcf_queue_work(&head->rwork, fl_destroy_sleepable); +} + +static void fl_put(struct tcf_proto *tp, void *arg) +{ + struct cls_fl_filter *f = arg; + + __fl_put(f); +} + +static void *fl_get(struct tcf_proto *tp, u32 handle) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + + return __fl_get(head, handle); +} + +static const struct nla_policy fl_policy[TCA_FLOWER_MAX + 1] = { + [TCA_FLOWER_UNSPEC] = { .type = NLA_UNSPEC }, + [TCA_FLOWER_CLASSID] = { .type = NLA_U32 }, + [TCA_FLOWER_INDEV] = { .type = NLA_STRING, + .len = IFNAMSIZ }, + [TCA_FLOWER_KEY_ETH_DST] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ETH_DST_MASK] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ETH_SRC] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ETH_SRC_MASK] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ETH_TYPE] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_IP_PROTO] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_IPV4_SRC] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_IPV4_SRC_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_IPV4_DST] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_IPV4_DST_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_IPV6_SRC] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_IPV6_SRC_MASK] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_IPV6_DST] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_IPV6_DST_MASK] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_TCP_SRC] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_TCP_DST] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_UDP_SRC] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_UDP_DST] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_VLAN_ID] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_VLAN_PRIO] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_VLAN_ETH_TYPE] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_KEY_ID] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_IPV4_SRC] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_IPV4_SRC_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_IPV4_DST] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_IPV4_DST_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_IPV6_SRC] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_ENC_IPV6_SRC_MASK] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_ENC_IPV6_DST] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_ENC_IPV6_DST_MASK] = { .len = sizeof(struct in6_addr) }, + [TCA_FLOWER_KEY_TCP_SRC_MASK] = { .type = 
NLA_U16 }, + [TCA_FLOWER_KEY_TCP_DST_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_UDP_SRC_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_UDP_DST_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_SCTP_SRC_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_SCTP_DST_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_SCTP_SRC] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_SCTP_DST] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_UDP_SRC_PORT] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_UDP_SRC_PORT_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_UDP_DST_PORT] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_UDP_DST_PORT_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_FLAGS] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_FLAGS_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ICMPV4_TYPE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV4_TYPE_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV4_CODE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV4_CODE_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV6_TYPE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV6_TYPE_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV6_CODE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ICMPV6_CODE_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ARP_SIP] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ARP_SIP_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ARP_TIP] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ARP_TIP_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ARP_OP] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ARP_OP_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ARP_SHA] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ARP_SHA_MASK] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ARP_THA] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_ARP_THA_MASK] = { .len = ETH_ALEN }, + [TCA_FLOWER_KEY_MPLS_TTL] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_BOS] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_TC] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_LABEL] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_MPLS_OPTS] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_TCP_FLAGS] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_TCP_FLAGS_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_IP_TOS] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_IP_TOS_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_IP_TTL] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_IP_TTL_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_CVLAN_ID] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_CVLAN_PRIO] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_CVLAN_ETH_TYPE] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_IP_TOS] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_IP_TOS_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_IP_TTL] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_IP_TTL_MASK] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_OPTS] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_ENC_OPTS_MASK] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_CT_STATE] = + NLA_POLICY_MASK(NLA_U16, TCA_FLOWER_KEY_CT_FLAGS_MASK), + [TCA_FLOWER_KEY_CT_STATE_MASK] = + NLA_POLICY_MASK(NLA_U16, TCA_FLOWER_KEY_CT_FLAGS_MASK), + [TCA_FLOWER_KEY_CT_ZONE] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_CT_ZONE_MASK] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_CT_MARK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_CT_MARK_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_CT_LABELS] = { .type = NLA_BINARY, + .len = 128 / BITS_PER_BYTE }, + [TCA_FLOWER_KEY_CT_LABELS_MASK] = { .type = NLA_BINARY, + .len = 128 / BITS_PER_BYTE }, + [TCA_FLOWER_FLAGS] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_HASH] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_HASH_MASK] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_NUM_OF_VLANS] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_PPPOE_SID] = 
{ .type = NLA_U16 }, + [TCA_FLOWER_KEY_PPP_PROTO] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_L2TPV3_SID] = { .type = NLA_U32 }, + +}; + +static const struct nla_policy +enc_opts_policy[TCA_FLOWER_KEY_ENC_OPTS_MAX + 1] = { + [TCA_FLOWER_KEY_ENC_OPTS_UNSPEC] = { + .strict_start_type = TCA_FLOWER_KEY_ENC_OPTS_VXLAN }, + [TCA_FLOWER_KEY_ENC_OPTS_GENEVE] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_ENC_OPTS_VXLAN] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_ENC_OPTS_ERSPAN] = { .type = NLA_NESTED }, + [TCA_FLOWER_KEY_ENC_OPTS_GTP] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy +geneve_opt_policy[TCA_FLOWER_KEY_ENC_OPT_GENEVE_MAX + 1] = { + [TCA_FLOWER_KEY_ENC_OPT_GENEVE_CLASS] = { .type = NLA_U16 }, + [TCA_FLOWER_KEY_ENC_OPT_GENEVE_TYPE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_OPT_GENEVE_DATA] = { .type = NLA_BINARY, + .len = 128 }, +}; + +static const struct nla_policy +vxlan_opt_policy[TCA_FLOWER_KEY_ENC_OPT_VXLAN_MAX + 1] = { + [TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +erspan_opt_policy[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_MAX + 1] = { + [TCA_FLOWER_KEY_ENC_OPT_ERSPAN_VER] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_OPT_ERSPAN_INDEX] = { .type = NLA_U32 }, + [TCA_FLOWER_KEY_ENC_OPT_ERSPAN_DIR] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_OPT_ERSPAN_HWID] = { .type = NLA_U8 }, +}; + +static const struct nla_policy +gtp_opt_policy[TCA_FLOWER_KEY_ENC_OPT_GTP_MAX + 1] = { + [TCA_FLOWER_KEY_ENC_OPT_GTP_PDU_TYPE] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_ENC_OPT_GTP_QFI] = { .type = NLA_U8 }, +}; + +static const struct nla_policy +mpls_stack_entry_policy[TCA_FLOWER_KEY_MPLS_OPT_LSE_MAX + 1] = { + [TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_OPT_LSE_TTL] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_OPT_LSE_BOS] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_OPT_LSE_TC] = { .type = NLA_U8 }, + [TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL] = { .type = NLA_U32 }, +}; + +static void fl_set_key_val(struct nlattr **tb, + void *val, int val_type, + void *mask, int mask_type, int len) +{ + if (!tb[val_type]) + return; + nla_memcpy(val, tb[val_type], len); + if (mask_type == TCA_FLOWER_UNSPEC || !tb[mask_type]) + memset(mask, 0xff, len); + else + nla_memcpy(mask, tb[mask_type], len); +} + +static int fl_set_key_port_range(struct nlattr **tb, struct fl_flow_key *key, + struct fl_flow_key *mask, + struct netlink_ext_ack *extack) +{ + fl_set_key_val(tb, &key->tp_range.tp_min.dst, + TCA_FLOWER_KEY_PORT_DST_MIN, &mask->tp_range.tp_min.dst, + TCA_FLOWER_UNSPEC, sizeof(key->tp_range.tp_min.dst)); + fl_set_key_val(tb, &key->tp_range.tp_max.dst, + TCA_FLOWER_KEY_PORT_DST_MAX, &mask->tp_range.tp_max.dst, + TCA_FLOWER_UNSPEC, sizeof(key->tp_range.tp_max.dst)); + fl_set_key_val(tb, &key->tp_range.tp_min.src, + TCA_FLOWER_KEY_PORT_SRC_MIN, &mask->tp_range.tp_min.src, + TCA_FLOWER_UNSPEC, sizeof(key->tp_range.tp_min.src)); + fl_set_key_val(tb, &key->tp_range.tp_max.src, + TCA_FLOWER_KEY_PORT_SRC_MAX, &mask->tp_range.tp_max.src, + TCA_FLOWER_UNSPEC, sizeof(key->tp_range.tp_max.src)); + + if (mask->tp_range.tp_min.dst != mask->tp_range.tp_max.dst) { + NL_SET_ERR_MSG(extack, + "Both min and max destination ports must be specified"); + return -EINVAL; + } + if (mask->tp_range.tp_min.src != mask->tp_range.tp_max.src) { + NL_SET_ERR_MSG(extack, + "Both min and max source ports must be specified"); + return -EINVAL; + } + if (mask->tp_range.tp_min.dst && mask->tp_range.tp_max.dst && + ntohs(key->tp_range.tp_max.dst) <= + 
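/* fl_set_key_val() above is the building block for almost every match
 * field: it copies the value attribute and, when the corresponding
 * *_MASK attribute is absent (or the field has no mask attribute at
 * all, i.e. TCA_FLOWER_UNSPEC), fills the mask with 0xff so the field
 * becomes an exact match.  A minimal illustrative call, assuming only
 * TCA_FLOWER_KEY_TCP_DST was sent by userspace:
 *
 *	fl_set_key_val(tb, &key->tp.dst, TCA_FLOWER_KEY_TCP_DST,
 *		       &mask->tp.dst, TCA_FLOWER_KEY_TCP_DST_MASK,
 *		       sizeof(key->tp.dst));
 *	// mask->tp.dst is now 0xffff: match the destination port exactly
 *
 * With TCA_FLOWER_KEY_TCP_DST_MASK also present, that mask is copied
 * verbatim instead.
 */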
ntohs(key->tp_range.tp_min.dst)) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_PORT_DST_MIN], + "Invalid destination port range (min must be strictly smaller than max)"); + return -EINVAL; + } + if (mask->tp_range.tp_min.src && mask->tp_range.tp_max.src && + ntohs(key->tp_range.tp_max.src) <= + ntohs(key->tp_range.tp_min.src)) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_PORT_SRC_MIN], + "Invalid source port range (min must be strictly smaller than max)"); + return -EINVAL; + } + + return 0; +} + +static int fl_set_key_mpls_lse(const struct nlattr *nla_lse, + struct flow_dissector_key_mpls *key_val, + struct flow_dissector_key_mpls *key_mask, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_MAX + 1]; + struct flow_dissector_mpls_lse *lse_mask; + struct flow_dissector_mpls_lse *lse_val; + u8 lse_index; + u8 depth; + int err; + + err = nla_parse_nested(tb, TCA_FLOWER_KEY_MPLS_OPT_LSE_MAX, nla_lse, + mpls_stack_entry_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH]) { + NL_SET_ERR_MSG(extack, "Missing MPLS option \"depth\""); + return -EINVAL; + } + + depth = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH]); + + /* LSE depth starts at 1, for consistency with terminology used by + * RFC 3031 (section 3.9), where depth 0 refers to unlabeled packets. + */ + if (depth < 1 || depth > FLOW_DIS_MPLS_MAX) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH], + "Invalid MPLS depth"); + return -EINVAL; + } + lse_index = depth - 1; + + dissector_set_mpls_lse(key_val, lse_index); + dissector_set_mpls_lse(key_mask, lse_index); + + lse_val = &key_val->ls[lse_index]; + lse_mask = &key_mask->ls[lse_index]; + + if (tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_TTL]) { + lse_val->mpls_ttl = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_TTL]); + lse_mask->mpls_ttl = MPLS_TTL_MASK; + } + if (tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_BOS]) { + u8 bos = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_BOS]); + + if (bos & ~MPLS_BOS_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_BOS], + "Bottom Of Stack (BOS) must be 0 or 1"); + return -EINVAL; + } + lse_val->mpls_bos = bos; + lse_mask->mpls_bos = MPLS_BOS_MASK; + } + if (tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_TC]) { + u8 tc = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_TC]); + + if (tc & ~MPLS_TC_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_TC], + "Traffic Class (TC) must be between 0 and 7"); + return -EINVAL; + } + lse_val->mpls_tc = tc; + lse_mask->mpls_tc = MPLS_TC_MASK; + } + if (tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL]) { + u32 label = nla_get_u32(tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL]); + + if (label & ~MPLS_LABEL_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL], + "Label must be between 0 and 1048575"); + return -EINVAL; + } + lse_val->mpls_label = label; + lse_mask->mpls_label = MPLS_LABEL_MASK; + } + + return 0; +} + +static int fl_set_key_mpls_opts(const struct nlattr *nla_mpls_opts, + struct flow_dissector_key_mpls *key_val, + struct flow_dissector_key_mpls *key_mask, + struct netlink_ext_ack *extack) +{ + struct nlattr *nla_lse; + int rem; + int err; + + if (!(nla_mpls_opts->nla_type & NLA_F_NESTED)) { + NL_SET_ERR_MSG_ATTR(extack, nla_mpls_opts, + "NLA_F_NESTED is missing"); + return -EINVAL; + } + + nla_for_each_nested(nla_lse, nla_mpls_opts, rem) { + if (nla_type(nla_lse) != TCA_FLOWER_KEY_MPLS_OPTS_LSE) { + NL_SET_ERR_MSG_ATTR(extack, nla_lse, + "Invalid MPLS option type"); + return -EINVAL; + } + + err = 
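/* Each TCA_FLOWER_KEY_MPLS_OPTS_LSE nest parsed below describes one
 * label stack entry, addressed by a 1-based depth (depth 1 is the top
 * of the stack, matching RFC 3031 terminology).  A sketch of the
 * attribute layout userspace would send to match label 100 in the
 * second LSE (values are only an example):
 *
 *	TCA_FLOWER_KEY_MPLS_OPTS (NLA_F_NESTED)
 *	  TCA_FLOWER_KEY_MPLS_OPTS_LSE
 *	    TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH = 2
 *	    TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL = 100
 *
 * fl_set_key_mpls_lse() converts the depth to lse_index = depth - 1
 * and marks that entry in both the key and the mask.
 */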
fl_set_key_mpls_lse(nla_lse, key_val, key_mask, extack); + if (err < 0) + return err; + } + if (rem) { + NL_SET_ERR_MSG(extack, + "Bytes leftover after parsing MPLS options"); + return -EINVAL; + } + + return 0; +} + +static int fl_set_key_mpls(struct nlattr **tb, + struct flow_dissector_key_mpls *key_val, + struct flow_dissector_key_mpls *key_mask, + struct netlink_ext_ack *extack) +{ + struct flow_dissector_mpls_lse *lse_mask; + struct flow_dissector_mpls_lse *lse_val; + + if (tb[TCA_FLOWER_KEY_MPLS_OPTS]) { + if (tb[TCA_FLOWER_KEY_MPLS_TTL] || + tb[TCA_FLOWER_KEY_MPLS_BOS] || + tb[TCA_FLOWER_KEY_MPLS_TC] || + tb[TCA_FLOWER_KEY_MPLS_LABEL]) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_OPTS], + "MPLS label, Traffic Class, Bottom Of Stack and Time To Live must be encapsulated in the MPLS options attribute"); + return -EBADMSG; + } + + return fl_set_key_mpls_opts(tb[TCA_FLOWER_KEY_MPLS_OPTS], + key_val, key_mask, extack); + } + + lse_val = &key_val->ls[0]; + lse_mask = &key_mask->ls[0]; + + if (tb[TCA_FLOWER_KEY_MPLS_TTL]) { + lse_val->mpls_ttl = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_TTL]); + lse_mask->mpls_ttl = MPLS_TTL_MASK; + dissector_set_mpls_lse(key_val, 0); + dissector_set_mpls_lse(key_mask, 0); + } + if (tb[TCA_FLOWER_KEY_MPLS_BOS]) { + u8 bos = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_BOS]); + + if (bos & ~MPLS_BOS_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_BOS], + "Bottom Of Stack (BOS) must be 0 or 1"); + return -EINVAL; + } + lse_val->mpls_bos = bos; + lse_mask->mpls_bos = MPLS_BOS_MASK; + dissector_set_mpls_lse(key_val, 0); + dissector_set_mpls_lse(key_mask, 0); + } + if (tb[TCA_FLOWER_KEY_MPLS_TC]) { + u8 tc = nla_get_u8(tb[TCA_FLOWER_KEY_MPLS_TC]); + + if (tc & ~MPLS_TC_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_TC], + "Traffic Class (TC) must be between 0 and 7"); + return -EINVAL; + } + lse_val->mpls_tc = tc; + lse_mask->mpls_tc = MPLS_TC_MASK; + dissector_set_mpls_lse(key_val, 0); + dissector_set_mpls_lse(key_mask, 0); + } + if (tb[TCA_FLOWER_KEY_MPLS_LABEL]) { + u32 label = nla_get_u32(tb[TCA_FLOWER_KEY_MPLS_LABEL]); + + if (label & ~MPLS_LABEL_MASK) { + NL_SET_ERR_MSG_ATTR(extack, + tb[TCA_FLOWER_KEY_MPLS_LABEL], + "Label must be between 0 and 1048575"); + return -EINVAL; + } + lse_val->mpls_label = label; + lse_mask->mpls_label = MPLS_LABEL_MASK; + dissector_set_mpls_lse(key_val, 0); + dissector_set_mpls_lse(key_mask, 0); + } + return 0; +} + +static void fl_set_key_vlan(struct nlattr **tb, + __be16 ethertype, + int vlan_id_key, int vlan_prio_key, + int vlan_next_eth_type_key, + struct flow_dissector_key_vlan *key_val, + struct flow_dissector_key_vlan *key_mask) +{ +#define VLAN_PRIORITY_MASK 0x7 + + if (tb[vlan_id_key]) { + key_val->vlan_id = + nla_get_u16(tb[vlan_id_key]) & VLAN_VID_MASK; + key_mask->vlan_id = VLAN_VID_MASK; + } + if (tb[vlan_prio_key]) { + key_val->vlan_priority = + nla_get_u8(tb[vlan_prio_key]) & + VLAN_PRIORITY_MASK; + key_mask->vlan_priority = VLAN_PRIORITY_MASK; + } + if (ethertype) { + key_val->vlan_tpid = ethertype; + key_mask->vlan_tpid = cpu_to_be16(~0); + } + if (tb[vlan_next_eth_type_key]) { + key_val->vlan_eth_type = + nla_get_be16(tb[vlan_next_eth_type_key]); + key_mask->vlan_eth_type = cpu_to_be16(~0); + } +} + +static void fl_set_key_pppoe(struct nlattr **tb, + struct flow_dissector_key_pppoe *key_val, + struct flow_dissector_key_pppoe *key_mask, + struct fl_flow_key *key, + struct fl_flow_key *mask) +{ + /* key_val::type must be set to ETH_P_PPP_SES + * because ETH_P_PPP_SES was stored in 
basic.n_proto + * which might get overwritten by ppp_proto + * or might be set to 0, the role of key_val::type + * is simmilar to vlan_key::tpid + */ + key_val->type = htons(ETH_P_PPP_SES); + key_mask->type = cpu_to_be16(~0); + + if (tb[TCA_FLOWER_KEY_PPPOE_SID]) { + key_val->session_id = + nla_get_be16(tb[TCA_FLOWER_KEY_PPPOE_SID]); + key_mask->session_id = cpu_to_be16(~0); + } + if (tb[TCA_FLOWER_KEY_PPP_PROTO]) { + key_val->ppp_proto = + nla_get_be16(tb[TCA_FLOWER_KEY_PPP_PROTO]); + key_mask->ppp_proto = cpu_to_be16(~0); + + if (key_val->ppp_proto == htons(PPP_IP)) { + key->basic.n_proto = htons(ETH_P_IP); + mask->basic.n_proto = cpu_to_be16(~0); + } else if (key_val->ppp_proto == htons(PPP_IPV6)) { + key->basic.n_proto = htons(ETH_P_IPV6); + mask->basic.n_proto = cpu_to_be16(~0); + } else if (key_val->ppp_proto == htons(PPP_MPLS_UC)) { + key->basic.n_proto = htons(ETH_P_MPLS_UC); + mask->basic.n_proto = cpu_to_be16(~0); + } else if (key_val->ppp_proto == htons(PPP_MPLS_MC)) { + key->basic.n_proto = htons(ETH_P_MPLS_MC); + mask->basic.n_proto = cpu_to_be16(~0); + } + } else { + key->basic.n_proto = 0; + mask->basic.n_proto = cpu_to_be16(0); + } +} + +static void fl_set_key_flag(u32 flower_key, u32 flower_mask, + u32 *dissector_key, u32 *dissector_mask, + u32 flower_flag_bit, u32 dissector_flag_bit) +{ + if (flower_mask & flower_flag_bit) { + *dissector_mask |= dissector_flag_bit; + if (flower_key & flower_flag_bit) + *dissector_key |= dissector_flag_bit; + } +} + +static int fl_set_key_flags(struct nlattr **tb, u32 *flags_key, + u32 *flags_mask, struct netlink_ext_ack *extack) +{ + u32 key, mask; + + /* mask is mandatory for flags */ + if (!tb[TCA_FLOWER_KEY_FLAGS_MASK]) { + NL_SET_ERR_MSG(extack, "Missing flags mask"); + return -EINVAL; + } + + key = be32_to_cpu(nla_get_be32(tb[TCA_FLOWER_KEY_FLAGS])); + mask = be32_to_cpu(nla_get_be32(tb[TCA_FLOWER_KEY_FLAGS_MASK])); + + *flags_key = 0; + *flags_mask = 0; + + fl_set_key_flag(key, mask, flags_key, flags_mask, + TCA_FLOWER_KEY_FLAGS_IS_FRAGMENT, FLOW_DIS_IS_FRAGMENT); + fl_set_key_flag(key, mask, flags_key, flags_mask, + TCA_FLOWER_KEY_FLAGS_FRAG_IS_FIRST, + FLOW_DIS_FIRST_FRAG); + + return 0; +} + +static void fl_set_key_ip(struct nlattr **tb, bool encap, + struct flow_dissector_key_ip *key, + struct flow_dissector_key_ip *mask) +{ + int tos_key = encap ? TCA_FLOWER_KEY_ENC_IP_TOS : TCA_FLOWER_KEY_IP_TOS; + int ttl_key = encap ? TCA_FLOWER_KEY_ENC_IP_TTL : TCA_FLOWER_KEY_IP_TTL; + int tos_mask = encap ? TCA_FLOWER_KEY_ENC_IP_TOS_MASK : TCA_FLOWER_KEY_IP_TOS_MASK; + int ttl_mask = encap ? TCA_FLOWER_KEY_ENC_IP_TTL_MASK : TCA_FLOWER_KEY_IP_TTL_MASK; + + fl_set_key_val(tb, &key->tos, tos_key, &mask->tos, tos_mask, sizeof(key->tos)); + fl_set_key_val(tb, &key->ttl, ttl_key, &mask->ttl, ttl_mask, sizeof(key->ttl)); +} + +static int fl_set_geneve_opt(const struct nlattr *nla, struct fl_flow_key *key, + int depth, int option_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_MAX + 1]; + struct nlattr *class = NULL, *type = NULL, *data = NULL; + struct geneve_opt *opt; + int err, data_len = 0; + + if (option_len > sizeof(struct geneve_opt)) + data_len = option_len - sizeof(struct geneve_opt); + + if (key->enc_opts.len > FLOW_DIS_TUN_OPTS_MAX - 4) + return -ERANGE; + + opt = (struct geneve_opt *)&key->enc_opts.data[key->enc_opts.len]; + memset(opt, 0xff, option_len); + opt->length = data_len / 4; + opt->r1 = 0; + opt->r2 = 0; + opt->r3 = 0; + + /* If no mask has been prodived we assume an exact match. 
*/ + if (!depth) + return sizeof(struct geneve_opt) + data_len; + + if (nla_type(nla) != TCA_FLOWER_KEY_ENC_OPTS_GENEVE) { + NL_SET_ERR_MSG(extack, "Non-geneve option type for mask"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, + TCA_FLOWER_KEY_ENC_OPT_GENEVE_MAX, + nla, geneve_opt_policy, extack); + if (err < 0) + return err; + + /* We are not allowed to omit any of CLASS, TYPE or DATA + * fields from the key. + */ + if (!option_len && + (!tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_CLASS] || + !tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_TYPE] || + !tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_DATA])) { + NL_SET_ERR_MSG(extack, "Missing tunnel key geneve option class, type or data"); + return -EINVAL; + } + + /* Omitting any of CLASS, TYPE or DATA fields is allowed + * for the mask. + */ + if (tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_DATA]) { + int new_len = key->enc_opts.len; + + data = tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_DATA]; + data_len = nla_len(data); + if (data_len < 4) { + NL_SET_ERR_MSG(extack, "Tunnel key geneve option data is less than 4 bytes long"); + return -ERANGE; + } + if (data_len % 4) { + NL_SET_ERR_MSG(extack, "Tunnel key geneve option data is not a multiple of 4 bytes long"); + return -ERANGE; + } + + new_len += sizeof(struct geneve_opt) + data_len; + BUILD_BUG_ON(FLOW_DIS_TUN_OPTS_MAX != IP_TUNNEL_OPTS_MAX); + if (new_len > FLOW_DIS_TUN_OPTS_MAX) { + NL_SET_ERR_MSG(extack, "Tunnel options exceeds max size"); + return -ERANGE; + } + opt->length = data_len / 4; + memcpy(opt->opt_data, nla_data(data), data_len); + } + + if (tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_CLASS]) { + class = tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_CLASS]; + opt->opt_class = nla_get_be16(class); + } + + if (tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_TYPE]) { + type = tb[TCA_FLOWER_KEY_ENC_OPT_GENEVE_TYPE]; + opt->type = nla_get_u8(type); + } + + return sizeof(struct geneve_opt) + data_len; +} + +static int fl_set_vxlan_opt(const struct nlattr *nla, struct fl_flow_key *key, + int depth, int option_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_FLOWER_KEY_ENC_OPT_VXLAN_MAX + 1]; + struct vxlan_metadata *md; + int err; + + md = (struct vxlan_metadata *)&key->enc_opts.data[key->enc_opts.len]; + memset(md, 0xff, sizeof(*md)); + + if (!depth) + return sizeof(*md); + + if (nla_type(nla) != TCA_FLOWER_KEY_ENC_OPTS_VXLAN) { + NL_SET_ERR_MSG(extack, "Non-vxlan option type for mask"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_FLOWER_KEY_ENC_OPT_VXLAN_MAX, nla, + vxlan_opt_policy, extack); + if (err < 0) + return err; + + if (!option_len && !tb[TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key vxlan option gbp"); + return -EINVAL; + } + + if (tb[TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP]) { + md->gbp = nla_get_u32(tb[TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP]); + md->gbp &= VXLAN_GBP_MASK; + } + + return sizeof(*md); +} + +static int fl_set_erspan_opt(const struct nlattr *nla, struct fl_flow_key *key, + int depth, int option_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_MAX + 1]; + struct erspan_metadata *md; + int err; + + md = (struct erspan_metadata *)&key->enc_opts.data[key->enc_opts.len]; + memset(md, 0xff, sizeof(*md)); + md->version = 1; + + if (!depth) + return sizeof(*md); + + if (nla_type(nla) != TCA_FLOWER_KEY_ENC_OPTS_ERSPAN) { + NL_SET_ERR_MSG(extack, "Non-erspan option type for mask"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_FLOWER_KEY_ENC_OPT_ERSPAN_MAX, nla, + erspan_opt_policy, extack); + if (err < 0) + return err; + + if 
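/* The fl_set_*_opt() helpers in this file follow one convention,
 * visible above for geneve and vxlan and below for erspan: each helper
 * pre-fills the option area it is writing with 0xff and, when invoked
 * for the mask side with depth == 0 (userspace supplied no
 * TCA_FLOWER_KEY_ENC_OPTS_MASK), returns immediately, leaving an
 * all-ones mask, i.e. an exact match on the tunnel option.  Roughly
 * (vxlan/erspan shape):
 *
 *	memset(md, 0xff, sizeof(*md));	// assume exact match
 *	if (!depth)			// no mask attribute at all
 *		return sizeof(*md);
 *	// otherwise parse the user-supplied mask and overwrite it
 */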
(!option_len && !tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_VER]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option ver"); + return -EINVAL; + } + + if (tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_VER]) + md->version = nla_get_u8(tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_VER]); + + if (md->version == 1) { + if (!option_len && !tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_INDEX]) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option index"); + return -EINVAL; + } + if (tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_INDEX]) { + nla = tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_INDEX]; + memset(&md->u, 0x00, sizeof(md->u)); + md->u.index = nla_get_be32(nla); + } + } else if (md->version == 2) { + if (!option_len && (!tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_DIR] || + !tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_HWID])) { + NL_SET_ERR_MSG(extack, "Missing tunnel key erspan option dir or hwid"); + return -EINVAL; + } + if (tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_DIR]) { + nla = tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_DIR]; + md->u.md2.dir = nla_get_u8(nla); + } + if (tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_HWID]) { + nla = tb[TCA_FLOWER_KEY_ENC_OPT_ERSPAN_HWID]; + set_hwid(&md->u.md2, nla_get_u8(nla)); + } + } else { + NL_SET_ERR_MSG(extack, "Tunnel key erspan option ver is incorrect"); + return -EINVAL; + } + + return sizeof(*md); +} + +static int fl_set_gtp_opt(const struct nlattr *nla, struct fl_flow_key *key, + int depth, int option_len, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_FLOWER_KEY_ENC_OPT_GTP_MAX + 1]; + struct gtp_pdu_session_info *sinfo; + u8 len = key->enc_opts.len; + int err; + + sinfo = (struct gtp_pdu_session_info *)&key->enc_opts.data[len]; + memset(sinfo, 0xff, option_len); + + if (!depth) + return sizeof(*sinfo); + + if (nla_type(nla) != TCA_FLOWER_KEY_ENC_OPTS_GTP) { + NL_SET_ERR_MSG_MOD(extack, "Non-gtp option type for mask"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_FLOWER_KEY_ENC_OPT_GTP_MAX, nla, + gtp_opt_policy, extack); + if (err < 0) + return err; + + if (!option_len && + (!tb[TCA_FLOWER_KEY_ENC_OPT_GTP_PDU_TYPE] || + !tb[TCA_FLOWER_KEY_ENC_OPT_GTP_QFI])) { + NL_SET_ERR_MSG_MOD(extack, + "Missing tunnel key gtp option pdu type or qfi"); + return -EINVAL; + } + + if (tb[TCA_FLOWER_KEY_ENC_OPT_GTP_PDU_TYPE]) + sinfo->pdu_type = + nla_get_u8(tb[TCA_FLOWER_KEY_ENC_OPT_GTP_PDU_TYPE]); + + if (tb[TCA_FLOWER_KEY_ENC_OPT_GTP_QFI]) + sinfo->qfi = nla_get_u8(tb[TCA_FLOWER_KEY_ENC_OPT_GTP_QFI]); + + return sizeof(*sinfo); +} + +static int fl_set_enc_opt(struct nlattr **tb, struct fl_flow_key *key, + struct fl_flow_key *mask, + struct netlink_ext_ack *extack) +{ + const struct nlattr *nla_enc_key, *nla_opt_key, *nla_opt_msk = NULL; + int err, option_len, key_depth, msk_depth = 0; + + err = nla_validate_nested_deprecated(tb[TCA_FLOWER_KEY_ENC_OPTS], + TCA_FLOWER_KEY_ENC_OPTS_MAX, + enc_opts_policy, extack); + if (err) + return err; + + nla_enc_key = nla_data(tb[TCA_FLOWER_KEY_ENC_OPTS]); + + if (tb[TCA_FLOWER_KEY_ENC_OPTS_MASK]) { + err = nla_validate_nested_deprecated(tb[TCA_FLOWER_KEY_ENC_OPTS_MASK], + TCA_FLOWER_KEY_ENC_OPTS_MAX, + enc_opts_policy, extack); + if (err) + return err; + + nla_opt_msk = nla_data(tb[TCA_FLOWER_KEY_ENC_OPTS_MASK]); + msk_depth = nla_len(tb[TCA_FLOWER_KEY_ENC_OPTS_MASK]); + if (!nla_ok(nla_opt_msk, msk_depth)) { + NL_SET_ERR_MSG(extack, "Invalid nested attribute for masks"); + return -EINVAL; + } + } + + nla_for_each_attr(nla_opt_key, nla_enc_key, + nla_len(tb[TCA_FLOWER_KEY_ENC_OPTS]), key_depth) { + switch (nla_type(nla_opt_key)) { + case TCA_FLOWER_KEY_ENC_OPTS_GENEVE: + if 
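/* TCA_FLOWER_KEY_ENC_OPTS carries one tunnel-specific nest per option
 * (geneve, vxlan, erspan or gtp, as listed in enc_opts_policy above);
 * the switch below allows several geneve TLVs to be concatenated but
 * rejects mixing option types within one filter.  A sketch of the
 * layout for a VXLAN GBP match with an explicit mask (values are
 * illustrative only):
 *
 *	TCA_FLOWER_KEY_ENC_OPTS (NLA_F_NESTED)
 *	  TCA_FLOWER_KEY_ENC_OPTS_VXLAN
 *	    TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP = 0x100
 *	TCA_FLOWER_KEY_ENC_OPTS_MASK (NLA_F_NESTED)
 *	  TCA_FLOWER_KEY_ENC_OPTS_VXLAN
 *	    TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP = 0xffffffff
 *
 * The key nest and the mask nest are walked in lockstep and must
 * decode to the same total length, otherwise -EINVAL is returned.
 */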
(key->enc_opts.dst_opt_type && + key->enc_opts.dst_opt_type != TUNNEL_GENEVE_OPT) { + NL_SET_ERR_MSG(extack, "Duplicate type for geneve options"); + return -EINVAL; + } + option_len = 0; + key->enc_opts.dst_opt_type = TUNNEL_GENEVE_OPT; + option_len = fl_set_geneve_opt(nla_opt_key, key, + key_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + key->enc_opts.len += option_len; + /* At the same time we need to parse through the mask + * in order to verify exact and mask attribute lengths. + */ + mask->enc_opts.dst_opt_type = TUNNEL_GENEVE_OPT; + option_len = fl_set_geneve_opt(nla_opt_msk, mask, + msk_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + mask->enc_opts.len += option_len; + if (key->enc_opts.len != mask->enc_opts.len) { + NL_SET_ERR_MSG(extack, "Key and mask miss aligned"); + return -EINVAL; + } + break; + case TCA_FLOWER_KEY_ENC_OPTS_VXLAN: + if (key->enc_opts.dst_opt_type) { + NL_SET_ERR_MSG(extack, "Duplicate type for vxlan options"); + return -EINVAL; + } + option_len = 0; + key->enc_opts.dst_opt_type = TUNNEL_VXLAN_OPT; + option_len = fl_set_vxlan_opt(nla_opt_key, key, + key_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + key->enc_opts.len += option_len; + /* At the same time we need to parse through the mask + * in order to verify exact and mask attribute lengths. + */ + mask->enc_opts.dst_opt_type = TUNNEL_VXLAN_OPT; + option_len = fl_set_vxlan_opt(nla_opt_msk, mask, + msk_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + mask->enc_opts.len += option_len; + if (key->enc_opts.len != mask->enc_opts.len) { + NL_SET_ERR_MSG(extack, "Key and mask miss aligned"); + return -EINVAL; + } + break; + case TCA_FLOWER_KEY_ENC_OPTS_ERSPAN: + if (key->enc_opts.dst_opt_type) { + NL_SET_ERR_MSG(extack, "Duplicate type for erspan options"); + return -EINVAL; + } + option_len = 0; + key->enc_opts.dst_opt_type = TUNNEL_ERSPAN_OPT; + option_len = fl_set_erspan_opt(nla_opt_key, key, + key_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + key->enc_opts.len += option_len; + /* At the same time we need to parse through the mask + * in order to verify exact and mask attribute lengths. + */ + mask->enc_opts.dst_opt_type = TUNNEL_ERSPAN_OPT; + option_len = fl_set_erspan_opt(nla_opt_msk, mask, + msk_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + mask->enc_opts.len += option_len; + if (key->enc_opts.len != mask->enc_opts.len) { + NL_SET_ERR_MSG(extack, "Key and mask miss aligned"); + return -EINVAL; + } + break; + case TCA_FLOWER_KEY_ENC_OPTS_GTP: + if (key->enc_opts.dst_opt_type) { + NL_SET_ERR_MSG_MOD(extack, + "Duplicate type for gtp options"); + return -EINVAL; + } + option_len = 0; + key->enc_opts.dst_opt_type = TUNNEL_GTP_OPT; + option_len = fl_set_gtp_opt(nla_opt_key, key, + key_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + key->enc_opts.len += option_len; + /* At the same time we need to parse through the mask + * in order to verify exact and mask attribute lengths. 
+ */ + mask->enc_opts.dst_opt_type = TUNNEL_GTP_OPT; + option_len = fl_set_gtp_opt(nla_opt_msk, mask, + msk_depth, option_len, + extack); + if (option_len < 0) + return option_len; + + mask->enc_opts.len += option_len; + if (key->enc_opts.len != mask->enc_opts.len) { + NL_SET_ERR_MSG_MOD(extack, + "Key and mask miss aligned"); + return -EINVAL; + } + break; + default: + NL_SET_ERR_MSG(extack, "Unknown tunnel option type"); + return -EINVAL; + } + + if (!msk_depth) + continue; + + if (!nla_ok(nla_opt_msk, msk_depth)) { + NL_SET_ERR_MSG(extack, "A mask attribute is invalid"); + return -EINVAL; + } + nla_opt_msk = nla_next(nla_opt_msk, &msk_depth); + } + + return 0; +} + +static int fl_validate_ct_state(u16 state, struct nlattr *tb, + struct netlink_ext_ack *extack) +{ + if (state && !(state & TCA_FLOWER_KEY_CT_FLAGS_TRACKED)) { + NL_SET_ERR_MSG_ATTR(extack, tb, + "no trk, so no other flag can be set"); + return -EINVAL; + } + + if (state & TCA_FLOWER_KEY_CT_FLAGS_NEW && + state & TCA_FLOWER_KEY_CT_FLAGS_ESTABLISHED) { + NL_SET_ERR_MSG_ATTR(extack, tb, + "new and est are mutually exclusive"); + return -EINVAL; + } + + if (state & TCA_FLOWER_KEY_CT_FLAGS_INVALID && + state & ~(TCA_FLOWER_KEY_CT_FLAGS_TRACKED | + TCA_FLOWER_KEY_CT_FLAGS_INVALID)) { + NL_SET_ERR_MSG_ATTR(extack, tb, + "when inv is set, only trk may be set"); + return -EINVAL; + } + + if (state & TCA_FLOWER_KEY_CT_FLAGS_NEW && + state & TCA_FLOWER_KEY_CT_FLAGS_REPLY) { + NL_SET_ERR_MSG_ATTR(extack, tb, + "new and rpl are mutually exclusive"); + return -EINVAL; + } + + return 0; +} + +static int fl_set_key_ct(struct nlattr **tb, + struct flow_dissector_key_ct *key, + struct flow_dissector_key_ct *mask, + struct netlink_ext_ack *extack) +{ + if (tb[TCA_FLOWER_KEY_CT_STATE]) { + int err; + + if (!IS_ENABLED(CONFIG_NF_CONNTRACK)) { + NL_SET_ERR_MSG(extack, "Conntrack isn't enabled"); + return -EOPNOTSUPP; + } + fl_set_key_val(tb, &key->ct_state, TCA_FLOWER_KEY_CT_STATE, + &mask->ct_state, TCA_FLOWER_KEY_CT_STATE_MASK, + sizeof(key->ct_state)); + + err = fl_validate_ct_state(key->ct_state & mask->ct_state, + tb[TCA_FLOWER_KEY_CT_STATE_MASK], + extack); + if (err) + return err; + + } + if (tb[TCA_FLOWER_KEY_CT_ZONE]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES)) { + NL_SET_ERR_MSG(extack, "Conntrack zones isn't enabled"); + return -EOPNOTSUPP; + } + fl_set_key_val(tb, &key->ct_zone, TCA_FLOWER_KEY_CT_ZONE, + &mask->ct_zone, TCA_FLOWER_KEY_CT_ZONE_MASK, + sizeof(key->ct_zone)); + } + if (tb[TCA_FLOWER_KEY_CT_MARK]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)) { + NL_SET_ERR_MSG(extack, "Conntrack mark isn't enabled"); + return -EOPNOTSUPP; + } + fl_set_key_val(tb, &key->ct_mark, TCA_FLOWER_KEY_CT_MARK, + &mask->ct_mark, TCA_FLOWER_KEY_CT_MARK_MASK, + sizeof(key->ct_mark)); + } + if (tb[TCA_FLOWER_KEY_CT_LABELS]) { + if (!IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS)) { + NL_SET_ERR_MSG(extack, "Conntrack labels aren't enabled"); + return -EOPNOTSUPP; + } + fl_set_key_val(tb, key->ct_labels, TCA_FLOWER_KEY_CT_LABELS, + mask->ct_labels, TCA_FLOWER_KEY_CT_LABELS_MASK, + sizeof(key->ct_labels)); + } + + return 0; +} + +static bool is_vlan_key(struct nlattr *tb, __be16 *ethertype, + struct fl_flow_key *key, struct fl_flow_key *mask, + int vthresh) +{ + const bool good_num_of_vlans = key->num_of_vlans.num_of_vlans > vthresh; + + if (!tb) { + *ethertype = 0; + return good_num_of_vlans; + } + + *ethertype = nla_get_be16(tb); + if (good_num_of_vlans || eth_type_vlan(*ethertype)) + return true; + + key->basic.n_proto = *ethertype; + 
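/* fl_validate_ct_state() above enforces the conntrack flag grammar on
 * the bits that are actually masked: every other flag requires
 * TCA_FLOWER_KEY_CT_FLAGS_TRACKED, "new" excludes "established" and
 * "reply", and "invalid" may only be combined with "tracked".  A
 * sketch of the effective checks:
 *
 *	state = TCA_FLOWER_KEY_CT_FLAGS_TRACKED |
 *		TCA_FLOWER_KEY_CT_FLAGS_ESTABLISHED;	// accepted
 *	state = TCA_FLOWER_KEY_CT_FLAGS_TRACKED |
 *		TCA_FLOWER_KEY_CT_FLAGS_NEW |
 *		TCA_FLOWER_KEY_CT_FLAGS_ESTABLISHED;	// -EINVAL
 *	state = TCA_FLOWER_KEY_CT_FLAGS_NEW;		// -EINVAL, no "trk"
 */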
mask->basic.n_proto = cpu_to_be16(~0); + return false; +} + +static int fl_set_key(struct net *net, struct nlattr **tb, + struct fl_flow_key *key, struct fl_flow_key *mask, + struct netlink_ext_ack *extack) +{ + __be16 ethertype; + int ret = 0; + + if (tb[TCA_FLOWER_INDEV]) { + int err = tcf_change_indev(net, tb[TCA_FLOWER_INDEV], extack); + if (err < 0) + return err; + key->meta.ingress_ifindex = err; + mask->meta.ingress_ifindex = 0xffffffff; + } + + fl_set_key_val(tb, key->eth.dst, TCA_FLOWER_KEY_ETH_DST, + mask->eth.dst, TCA_FLOWER_KEY_ETH_DST_MASK, + sizeof(key->eth.dst)); + fl_set_key_val(tb, key->eth.src, TCA_FLOWER_KEY_ETH_SRC, + mask->eth.src, TCA_FLOWER_KEY_ETH_SRC_MASK, + sizeof(key->eth.src)); + fl_set_key_val(tb, &key->num_of_vlans, + TCA_FLOWER_KEY_NUM_OF_VLANS, + &mask->num_of_vlans, + TCA_FLOWER_UNSPEC, + sizeof(key->num_of_vlans)); + + if (is_vlan_key(tb[TCA_FLOWER_KEY_ETH_TYPE], ðertype, key, mask, 0)) { + fl_set_key_vlan(tb, ethertype, TCA_FLOWER_KEY_VLAN_ID, + TCA_FLOWER_KEY_VLAN_PRIO, + TCA_FLOWER_KEY_VLAN_ETH_TYPE, + &key->vlan, &mask->vlan); + + if (is_vlan_key(tb[TCA_FLOWER_KEY_VLAN_ETH_TYPE], + ðertype, key, mask, 1)) { + fl_set_key_vlan(tb, ethertype, + TCA_FLOWER_KEY_CVLAN_ID, + TCA_FLOWER_KEY_CVLAN_PRIO, + TCA_FLOWER_KEY_CVLAN_ETH_TYPE, + &key->cvlan, &mask->cvlan); + fl_set_key_val(tb, &key->basic.n_proto, + TCA_FLOWER_KEY_CVLAN_ETH_TYPE, + &mask->basic.n_proto, + TCA_FLOWER_UNSPEC, + sizeof(key->basic.n_proto)); + } + } + + if (key->basic.n_proto == htons(ETH_P_PPP_SES)) + fl_set_key_pppoe(tb, &key->pppoe, &mask->pppoe, key, mask); + + if (key->basic.n_proto == htons(ETH_P_IP) || + key->basic.n_proto == htons(ETH_P_IPV6)) { + fl_set_key_val(tb, &key->basic.ip_proto, TCA_FLOWER_KEY_IP_PROTO, + &mask->basic.ip_proto, TCA_FLOWER_UNSPEC, + sizeof(key->basic.ip_proto)); + fl_set_key_ip(tb, false, &key->ip, &mask->ip); + } + + if (tb[TCA_FLOWER_KEY_IPV4_SRC] || tb[TCA_FLOWER_KEY_IPV4_DST]) { + key->control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + mask->control.addr_type = ~0; + fl_set_key_val(tb, &key->ipv4.src, TCA_FLOWER_KEY_IPV4_SRC, + &mask->ipv4.src, TCA_FLOWER_KEY_IPV4_SRC_MASK, + sizeof(key->ipv4.src)); + fl_set_key_val(tb, &key->ipv4.dst, TCA_FLOWER_KEY_IPV4_DST, + &mask->ipv4.dst, TCA_FLOWER_KEY_IPV4_DST_MASK, + sizeof(key->ipv4.dst)); + } else if (tb[TCA_FLOWER_KEY_IPV6_SRC] || tb[TCA_FLOWER_KEY_IPV6_DST]) { + key->control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + mask->control.addr_type = ~0; + fl_set_key_val(tb, &key->ipv6.src, TCA_FLOWER_KEY_IPV6_SRC, + &mask->ipv6.src, TCA_FLOWER_KEY_IPV6_SRC_MASK, + sizeof(key->ipv6.src)); + fl_set_key_val(tb, &key->ipv6.dst, TCA_FLOWER_KEY_IPV6_DST, + &mask->ipv6.dst, TCA_FLOWER_KEY_IPV6_DST_MASK, + sizeof(key->ipv6.dst)); + } + + if (key->basic.ip_proto == IPPROTO_TCP) { + fl_set_key_val(tb, &key->tp.src, TCA_FLOWER_KEY_TCP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_TCP_SRC_MASK, + sizeof(key->tp.src)); + fl_set_key_val(tb, &key->tp.dst, TCA_FLOWER_KEY_TCP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_TCP_DST_MASK, + sizeof(key->tp.dst)); + fl_set_key_val(tb, &key->tcp.flags, TCA_FLOWER_KEY_TCP_FLAGS, + &mask->tcp.flags, TCA_FLOWER_KEY_TCP_FLAGS_MASK, + sizeof(key->tcp.flags)); + } else if (key->basic.ip_proto == IPPROTO_UDP) { + fl_set_key_val(tb, &key->tp.src, TCA_FLOWER_KEY_UDP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_UDP_SRC_MASK, + sizeof(key->tp.src)); + fl_set_key_val(tb, &key->tp.dst, TCA_FLOWER_KEY_UDP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_UDP_DST_MASK, + sizeof(key->tp.dst)); + } else if (key->basic.ip_proto == 
IPPROTO_SCTP) { + fl_set_key_val(tb, &key->tp.src, TCA_FLOWER_KEY_SCTP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_SCTP_SRC_MASK, + sizeof(key->tp.src)); + fl_set_key_val(tb, &key->tp.dst, TCA_FLOWER_KEY_SCTP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_SCTP_DST_MASK, + sizeof(key->tp.dst)); + } else if (key->basic.n_proto == htons(ETH_P_IP) && + key->basic.ip_proto == IPPROTO_ICMP) { + fl_set_key_val(tb, &key->icmp.type, TCA_FLOWER_KEY_ICMPV4_TYPE, + &mask->icmp.type, + TCA_FLOWER_KEY_ICMPV4_TYPE_MASK, + sizeof(key->icmp.type)); + fl_set_key_val(tb, &key->icmp.code, TCA_FLOWER_KEY_ICMPV4_CODE, + &mask->icmp.code, + TCA_FLOWER_KEY_ICMPV4_CODE_MASK, + sizeof(key->icmp.code)); + } else if (key->basic.n_proto == htons(ETH_P_IPV6) && + key->basic.ip_proto == IPPROTO_ICMPV6) { + fl_set_key_val(tb, &key->icmp.type, TCA_FLOWER_KEY_ICMPV6_TYPE, + &mask->icmp.type, + TCA_FLOWER_KEY_ICMPV6_TYPE_MASK, + sizeof(key->icmp.type)); + fl_set_key_val(tb, &key->icmp.code, TCA_FLOWER_KEY_ICMPV6_CODE, + &mask->icmp.code, + TCA_FLOWER_KEY_ICMPV6_CODE_MASK, + sizeof(key->icmp.code)); + } else if (key->basic.n_proto == htons(ETH_P_MPLS_UC) || + key->basic.n_proto == htons(ETH_P_MPLS_MC)) { + ret = fl_set_key_mpls(tb, &key->mpls, &mask->mpls, extack); + if (ret) + return ret; + } else if (key->basic.n_proto == htons(ETH_P_ARP) || + key->basic.n_proto == htons(ETH_P_RARP)) { + fl_set_key_val(tb, &key->arp.sip, TCA_FLOWER_KEY_ARP_SIP, + &mask->arp.sip, TCA_FLOWER_KEY_ARP_SIP_MASK, + sizeof(key->arp.sip)); + fl_set_key_val(tb, &key->arp.tip, TCA_FLOWER_KEY_ARP_TIP, + &mask->arp.tip, TCA_FLOWER_KEY_ARP_TIP_MASK, + sizeof(key->arp.tip)); + fl_set_key_val(tb, &key->arp.op, TCA_FLOWER_KEY_ARP_OP, + &mask->arp.op, TCA_FLOWER_KEY_ARP_OP_MASK, + sizeof(key->arp.op)); + fl_set_key_val(tb, key->arp.sha, TCA_FLOWER_KEY_ARP_SHA, + mask->arp.sha, TCA_FLOWER_KEY_ARP_SHA_MASK, + sizeof(key->arp.sha)); + fl_set_key_val(tb, key->arp.tha, TCA_FLOWER_KEY_ARP_THA, + mask->arp.tha, TCA_FLOWER_KEY_ARP_THA_MASK, + sizeof(key->arp.tha)); + } else if (key->basic.ip_proto == IPPROTO_L2TP) { + fl_set_key_val(tb, &key->l2tpv3.session_id, + TCA_FLOWER_KEY_L2TPV3_SID, + &mask->l2tpv3.session_id, TCA_FLOWER_UNSPEC, + sizeof(key->l2tpv3.session_id)); + } + + if (key->basic.ip_proto == IPPROTO_TCP || + key->basic.ip_proto == IPPROTO_UDP || + key->basic.ip_proto == IPPROTO_SCTP) { + ret = fl_set_key_port_range(tb, key, mask, extack); + if (ret) + return ret; + } + + if (tb[TCA_FLOWER_KEY_ENC_IPV4_SRC] || + tb[TCA_FLOWER_KEY_ENC_IPV4_DST]) { + key->enc_control.addr_type = FLOW_DISSECTOR_KEY_IPV4_ADDRS; + mask->enc_control.addr_type = ~0; + fl_set_key_val(tb, &key->enc_ipv4.src, + TCA_FLOWER_KEY_ENC_IPV4_SRC, + &mask->enc_ipv4.src, + TCA_FLOWER_KEY_ENC_IPV4_SRC_MASK, + sizeof(key->enc_ipv4.src)); + fl_set_key_val(tb, &key->enc_ipv4.dst, + TCA_FLOWER_KEY_ENC_IPV4_DST, + &mask->enc_ipv4.dst, + TCA_FLOWER_KEY_ENC_IPV4_DST_MASK, + sizeof(key->enc_ipv4.dst)); + } + + if (tb[TCA_FLOWER_KEY_ENC_IPV6_SRC] || + tb[TCA_FLOWER_KEY_ENC_IPV6_DST]) { + key->enc_control.addr_type = FLOW_DISSECTOR_KEY_IPV6_ADDRS; + mask->enc_control.addr_type = ~0; + fl_set_key_val(tb, &key->enc_ipv6.src, + TCA_FLOWER_KEY_ENC_IPV6_SRC, + &mask->enc_ipv6.src, + TCA_FLOWER_KEY_ENC_IPV6_SRC_MASK, + sizeof(key->enc_ipv6.src)); + fl_set_key_val(tb, &key->enc_ipv6.dst, + TCA_FLOWER_KEY_ENC_IPV6_DST, + &mask->enc_ipv6.dst, + TCA_FLOWER_KEY_ENC_IPV6_DST_MASK, + sizeof(key->enc_ipv6.dst)); + } + + fl_set_key_val(tb, &key->enc_key_id.keyid, TCA_FLOWER_KEY_ENC_KEY_ID, + &mask->enc_key_id.keyid, 
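/* The chain of "else if" blocks above only honours L4 attributes that
 * are consistent with the basic keys already parsed: TCP/UDP/SCTP port
 * keys require the matching ip_proto, ICMP keys additionally require
 * the matching n_proto, and so on.  A minimal, illustrative attribute
 * set for "match TCP traffic to port 80" would therefore be:
 *
 *	TCA_FLOWER_KEY_ETH_TYPE = ETH_P_IP
 *	TCA_FLOWER_KEY_IP_PROTO = IPPROTO_TCP
 *	TCA_FLOWER_KEY_TCP_DST  = htons(80)
 *	(TCA_FLOWER_KEY_TCP_DST_MASK may be omitted for an exact match)
 *
 * Without the first two attributes the TCP_DST key is silently ignored
 * because the corresponding branch is never taken.
 */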
TCA_FLOWER_UNSPEC, + sizeof(key->enc_key_id.keyid)); + + fl_set_key_val(tb, &key->enc_tp.src, TCA_FLOWER_KEY_ENC_UDP_SRC_PORT, + &mask->enc_tp.src, TCA_FLOWER_KEY_ENC_UDP_SRC_PORT_MASK, + sizeof(key->enc_tp.src)); + + fl_set_key_val(tb, &key->enc_tp.dst, TCA_FLOWER_KEY_ENC_UDP_DST_PORT, + &mask->enc_tp.dst, TCA_FLOWER_KEY_ENC_UDP_DST_PORT_MASK, + sizeof(key->enc_tp.dst)); + + fl_set_key_ip(tb, true, &key->enc_ip, &mask->enc_ip); + + fl_set_key_val(tb, &key->hash.hash, TCA_FLOWER_KEY_HASH, + &mask->hash.hash, TCA_FLOWER_KEY_HASH_MASK, + sizeof(key->hash.hash)); + + if (tb[TCA_FLOWER_KEY_ENC_OPTS]) { + ret = fl_set_enc_opt(tb, key, mask, extack); + if (ret) + return ret; + } + + ret = fl_set_key_ct(tb, &key->ct, &mask->ct, extack); + if (ret) + return ret; + + if (tb[TCA_FLOWER_KEY_FLAGS]) + ret = fl_set_key_flags(tb, &key->control.flags, + &mask->control.flags, extack); + + return ret; +} + +static void fl_mask_copy(struct fl_flow_mask *dst, + struct fl_flow_mask *src) +{ + const void *psrc = fl_key_get_start(&src->key, src); + void *pdst = fl_key_get_start(&dst->key, src); + + memcpy(pdst, psrc, fl_mask_range(src)); + dst->range = src->range; +} + +static const struct rhashtable_params fl_ht_params = { + .key_offset = offsetof(struct cls_fl_filter, mkey), /* base offset */ + .head_offset = offsetof(struct cls_fl_filter, ht_node), + .automatic_shrinking = true, +}; + +static int fl_init_mask_hashtable(struct fl_flow_mask *mask) +{ + mask->filter_ht_params = fl_ht_params; + mask->filter_ht_params.key_len = fl_mask_range(mask); + mask->filter_ht_params.key_offset += mask->range.start; + + return rhashtable_init(&mask->ht, &mask->filter_ht_params); +} + +#define FL_KEY_MEMBER_OFFSET(member) offsetof(struct fl_flow_key, member) +#define FL_KEY_MEMBER_SIZE(member) sizeof_field(struct fl_flow_key, member) + +#define FL_KEY_IS_MASKED(mask, member) \ + memchr_inv(((char *)mask) + FL_KEY_MEMBER_OFFSET(member), \ + 0, FL_KEY_MEMBER_SIZE(member)) \ + +#define FL_KEY_SET(keys, cnt, id, member) \ + do { \ + keys[cnt].key_id = id; \ + keys[cnt].offset = FL_KEY_MEMBER_OFFSET(member); \ + cnt++; \ + } while(0); + +#define FL_KEY_SET_IF_MASKED(mask, keys, cnt, id, member) \ + do { \ + if (FL_KEY_IS_MASKED(mask, member)) \ + FL_KEY_SET(keys, cnt, id, member); \ + } while(0); + +static void fl_init_dissector(struct flow_dissector *dissector, + struct fl_flow_key *mask) +{ + struct flow_dissector_key keys[FLOW_DISSECTOR_KEY_MAX]; + size_t cnt = 0; + + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_META, meta); + FL_KEY_SET(keys, cnt, FLOW_DISSECTOR_KEY_CONTROL, control); + FL_KEY_SET(keys, cnt, FLOW_DISSECTOR_KEY_BASIC, basic); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ETH_ADDRS, eth); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_IPV4_ADDRS, ipv4); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_IPV6_ADDRS, ipv6); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_PORTS, tp); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_PORTS_RANGE, tp_range); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_IP, ip); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_TCP, tcp); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ICMP, icmp); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ARP, arp); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_MPLS, mpls); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_VLAN, vlan); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + 
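/* FL_KEY_SET_IF_MASKED() only registers a flow dissector key when the
 * corresponding member of the mask contains at least one non-zero
 * byte (memchr_inv() above), so packets are only dissected as deeply
 * as the installed filters require.  For instance, a mask that sets
 * nothing but eth.dst would end up with just three keys:
 *
 *	FLOW_DISSECTOR_KEY_CONTROL    (always present)
 *	FLOW_DISSECTOR_KEY_BASIC      (always present)
 *	FLOW_DISSECTOR_KEY_ETH_ADDRS  (because mask->eth is non-zero)
 *
 * before skb_flow_dissector_init() is finally called with that list.
 */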
FLOW_DISSECTOR_KEY_CVLAN, cvlan); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_KEYID, enc_key_id); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_IPV4_ADDRS, enc_ipv4); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_IPV6_ADDRS, enc_ipv6); + if (FL_KEY_IS_MASKED(mask, enc_ipv4) || + FL_KEY_IS_MASKED(mask, enc_ipv6)) + FL_KEY_SET(keys, cnt, FLOW_DISSECTOR_KEY_ENC_CONTROL, + enc_control); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_PORTS, enc_tp); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_IP, enc_ip); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_ENC_OPTS, enc_opts); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_CT, ct); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_HASH, hash); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_NUM_OF_VLANS, num_of_vlans); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_PPPOE, pppoe); + FL_KEY_SET_IF_MASKED(mask, keys, cnt, + FLOW_DISSECTOR_KEY_L2TPV3, l2tpv3); + + skb_flow_dissector_init(dissector, keys, cnt); +} + +static struct fl_flow_mask *fl_create_new_mask(struct cls_fl_head *head, + struct fl_flow_mask *mask) +{ + struct fl_flow_mask *newmask; + int err; + + newmask = kzalloc(sizeof(*newmask), GFP_KERNEL); + if (!newmask) + return ERR_PTR(-ENOMEM); + + fl_mask_copy(newmask, mask); + + if ((newmask->key.tp_range.tp_min.dst && + newmask->key.tp_range.tp_max.dst) || + (newmask->key.tp_range.tp_min.src && + newmask->key.tp_range.tp_max.src)) + newmask->flags |= TCA_FLOWER_MASK_FLAGS_RANGE; + + err = fl_init_mask_hashtable(newmask); + if (err) + goto errout_free; + + fl_init_dissector(&newmask->dissector, &newmask->key); + + INIT_LIST_HEAD_RCU(&newmask->filters); + + refcount_set(&newmask->refcnt, 1); + err = rhashtable_replace_fast(&head->ht, &mask->ht_node, + &newmask->ht_node, mask_ht_params); + if (err) + goto errout_destroy; + + spin_lock(&head->masks_lock); + list_add_tail_rcu(&newmask->list, &head->masks); + spin_unlock(&head->masks_lock); + + return newmask; + +errout_destroy: + rhashtable_destroy(&newmask->ht); +errout_free: + kfree(newmask); + + return ERR_PTR(err); +} + +static int fl_check_assign_mask(struct cls_fl_head *head, + struct cls_fl_filter *fnew, + struct cls_fl_filter *fold, + struct fl_flow_mask *mask) +{ + struct fl_flow_mask *newmask; + int ret = 0; + + rcu_read_lock(); + + /* Insert mask as temporary node to prevent concurrent creation of mask + * with same key. Any concurrent lookups with same key will return + * -EAGAIN because mask's refcnt is zero. 
+ */ + fnew->mask = rhashtable_lookup_get_insert_fast(&head->ht, + &mask->ht_node, + mask_ht_params); + if (!fnew->mask) { + rcu_read_unlock(); + + if (fold) { + ret = -EINVAL; + goto errout_cleanup; + } + + newmask = fl_create_new_mask(head, mask); + if (IS_ERR(newmask)) { + ret = PTR_ERR(newmask); + goto errout_cleanup; + } + + fnew->mask = newmask; + return 0; + } else if (IS_ERR(fnew->mask)) { + ret = PTR_ERR(fnew->mask); + } else if (fold && fold->mask != fnew->mask) { + ret = -EINVAL; + } else if (!refcount_inc_not_zero(&fnew->mask->refcnt)) { + /* Mask was deleted concurrently, try again */ + ret = -EAGAIN; + } + rcu_read_unlock(); + return ret; + +errout_cleanup: + rhashtable_remove_fast(&head->ht, &mask->ht_node, + mask_ht_params); + return ret; +} + +static int fl_set_parms(struct net *net, struct tcf_proto *tp, + struct cls_fl_filter *f, struct fl_flow_mask *mask, + unsigned long base, struct nlattr **tb, + struct nlattr *est, + struct fl_flow_tmplt *tmplt, + u32 flags, u32 fl_flags, + struct netlink_ext_ack *extack) +{ + int err; + + err = tcf_exts_validate_ex(net, tp, tb, est, &f->exts, flags, + fl_flags, extack); + if (err < 0) + return err; + + if (tb[TCA_FLOWER_CLASSID]) { + f->res.classid = nla_get_u32(tb[TCA_FLOWER_CLASSID]); + if (flags & TCA_ACT_FLAGS_NO_RTNL) + rtnl_lock(); + tcf_bind_filter(tp, &f->res, base); + if (flags & TCA_ACT_FLAGS_NO_RTNL) + rtnl_unlock(); + } + + err = fl_set_key(net, tb, &f->key, &mask->key, extack); + if (err) + return err; + + fl_mask_update_range(mask); + fl_set_masked_key(&f->mkey, &f->key, mask); + + if (!fl_mask_fits_tmplt(tmplt, mask)) { + NL_SET_ERR_MSG_MOD(extack, "Mask does not fit the template"); + return -EINVAL; + } + + return 0; +} + +static int fl_ht_insert_unique(struct cls_fl_filter *fnew, + struct cls_fl_filter *fold, + bool *in_ht) +{ + struct fl_flow_mask *mask = fnew->mask; + int err; + + err = rhashtable_lookup_insert_fast(&mask->ht, + &fnew->ht_node, + mask->filter_ht_params); + if (err) { + *in_ht = false; + /* It is okay if filter with same key exists when + * overwriting. + */ + return fold && err == -EEXIST ? 
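/* Filters that use byte-for-byte identical masks share one
 * struct fl_flow_mask: fl_check_assign_mask() above first calls
 * rhashtable_lookup_get_insert_fast() so the caller's mask acts as a
 * placeholder; if a mask with the same key already exists, a reference
 * to it is taken instead, and -EAGAIN is returned when its refcount
 * has already dropped to zero so the caller retries.  Roughly:
 *
 *	lookup/insert mask key in head->ht
 *	-> not found:          create real mask, refcount = 1
 *	-> found, refcnt > 0:  reuse it, refcount++
 *	-> found, refcnt == 0: being destroyed, return -EAGAIN
 */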
0 : err; + } + + *in_ht = true; + return 0; +} + +static int fl_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 handle, struct nlattr **tca, + void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + bool rtnl_held = !(flags & TCA_ACT_FLAGS_NO_RTNL); + struct cls_fl_filter *fold = *arg; + struct cls_fl_filter *fnew; + struct fl_flow_mask *mask; + struct nlattr **tb; + bool in_ht; + int err; + + if (!tca[TCA_OPTIONS]) { + err = -EINVAL; + goto errout_fold; + } + + mask = kzalloc(sizeof(struct fl_flow_mask), GFP_KERNEL); + if (!mask) { + err = -ENOBUFS; + goto errout_fold; + } + + tb = kcalloc(TCA_FLOWER_MAX + 1, sizeof(struct nlattr *), GFP_KERNEL); + if (!tb) { + err = -ENOBUFS; + goto errout_mask_alloc; + } + + err = nla_parse_nested_deprecated(tb, TCA_FLOWER_MAX, + tca[TCA_OPTIONS], fl_policy, NULL); + if (err < 0) + goto errout_tb; + + if (fold && handle && fold->handle != handle) { + err = -EINVAL; + goto errout_tb; + } + + fnew = kzalloc(sizeof(*fnew), GFP_KERNEL); + if (!fnew) { + err = -ENOBUFS; + goto errout_tb; + } + INIT_LIST_HEAD(&fnew->hw_list); + refcount_set(&fnew->refcnt, 1); + + err = tcf_exts_init(&fnew->exts, net, TCA_FLOWER_ACT, 0); + if (err < 0) + goto errout; + + if (tb[TCA_FLOWER_FLAGS]) { + fnew->flags = nla_get_u32(tb[TCA_FLOWER_FLAGS]); + + if (!tc_flags_valid(fnew->flags)) { + err = -EINVAL; + goto errout; + } + } + + err = fl_set_parms(net, tp, fnew, mask, base, tb, tca[TCA_RATE], + tp->chain->tmplt_priv, flags, fnew->flags, + extack); + if (err) + goto errout; + + err = fl_check_assign_mask(head, fnew, fold, mask); + if (err) + goto errout; + + err = fl_ht_insert_unique(fnew, fold, &in_ht); + if (err) + goto errout_mask; + + if (!tc_skip_hw(fnew->flags)) { + err = fl_hw_replace_filter(tp, fnew, rtnl_held, extack); + if (err) + goto errout_ht; + } + + if (!tc_in_hw(fnew->flags)) + fnew->flags |= TCA_CLS_FLAGS_NOT_IN_HW; + + spin_lock(&tp->lock); + + /* tp was deleted concurrently. -EAGAIN will cause caller to lookup + * proto again or create new one, if necessary. + */ + if (tp->deleting) { + err = -EAGAIN; + goto errout_hw; + } + + if (fold) { + /* Fold filter was deleted concurrently. Retry lookup. */ + if (fold->deleted) { + err = -EAGAIN; + goto errout_hw; + } + + fnew->handle = handle; + + if (!in_ht) { + struct rhashtable_params params = + fnew->mask->filter_ht_params; + + err = rhashtable_insert_fast(&fnew->mask->ht, + &fnew->ht_node, + params); + if (err) + goto errout_hw; + in_ht = true; + } + + refcount_inc(&fnew->refcnt); + rhashtable_remove_fast(&fold->mask->ht, + &fold->ht_node, + fold->mask->filter_ht_params); + idr_replace(&head->handle_idr, fnew, fnew->handle); + list_replace_rcu(&fold->list, &fnew->list); + fold->deleted = true; + + spin_unlock(&tp->lock); + + fl_mask_put(head, fold->mask); + if (!tc_skip_hw(fold->flags)) + fl_hw_destroy_filter(tp, fold, rtnl_held, NULL); + tcf_unbind_filter(tp, &fold->res); + /* Caller holds reference to fold, so refcnt is always > 0 + * after this. + */ + refcount_dec(&fold->refcnt); + __fl_put(fold); + } else { + if (handle) { + /* user specifies a handle and it doesn't exist */ + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + handle, GFP_ATOMIC); + + /* Filter with specified handle was concurrently + * inserted after initial check in cls_api. This is not + * necessarily an error if NLM_F_EXCL is not set in + * message flags. 
Returning EAGAIN will cause cls_api to + * try to update concurrently inserted rule. + */ + if (err == -ENOSPC) + err = -EAGAIN; + } else { + handle = 1; + err = idr_alloc_u32(&head->handle_idr, fnew, &handle, + INT_MAX, GFP_ATOMIC); + } + if (err) + goto errout_hw; + + refcount_inc(&fnew->refcnt); + fnew->handle = handle; + list_add_tail_rcu(&fnew->list, &fnew->mask->filters); + spin_unlock(&tp->lock); + } + + *arg = fnew; + + kfree(tb); + tcf_queue_work(&mask->rwork, fl_uninit_mask_free_work); + return 0; + +errout_ht: + spin_lock(&tp->lock); +errout_hw: + fnew->deleted = true; + spin_unlock(&tp->lock); + if (!tc_skip_hw(fnew->flags)) + fl_hw_destroy_filter(tp, fnew, rtnl_held, NULL); + if (in_ht) + rhashtable_remove_fast(&fnew->mask->ht, &fnew->ht_node, + fnew->mask->filter_ht_params); +errout_mask: + fl_mask_put(head, fnew->mask); +errout: + __fl_put(fnew); +errout_tb: + kfree(tb); +errout_mask_alloc: + tcf_queue_work(&mask->rwork, fl_uninit_mask_free_work); +errout_fold: + if (fold) + __fl_put(fold); + return err; +} + +static int fl_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + struct cls_fl_filter *f = arg; + bool last_on_mask; + int err = 0; + + err = __fl_delete(tp, f, &last_on_mask, rtnl_held, extack); + *last = list_empty(&head->masks); + __fl_put(f); + + return err; +} + +static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + unsigned long id = arg->cookie, tmp; + struct cls_fl_filter *f; + + arg->count = arg->skip; + + rcu_read_lock(); + idr_for_each_entry_continue_ul(&head->handle_idr, f, tmp, id) { + /* don't return filters that are being deleted */ + if (!refcount_inc_not_zero(&f->refcnt)) + continue; + rcu_read_unlock(); + + if (arg->fn(tp, f, arg) < 0) { + __fl_put(f); + arg->stop = 1; + rcu_read_lock(); + break; + } + __fl_put(f); + arg->count++; + rcu_read_lock(); + } + rcu_read_unlock(); + arg->cookie = id; +} + +static struct cls_fl_filter * +fl_get_next_hw_filter(struct tcf_proto *tp, struct cls_fl_filter *f, bool add) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + + spin_lock(&tp->lock); + if (list_empty(&head->hw_filters)) { + spin_unlock(&tp->lock); + return NULL; + } + + if (!f) + f = list_entry(&head->hw_filters, struct cls_fl_filter, + hw_list); + list_for_each_entry_continue(f, &head->hw_filters, hw_list) { + if (!(add && f->deleted) && refcount_inc_not_zero(&f->refcnt)) { + spin_unlock(&tp->lock); + return f; + } + } + + spin_unlock(&tp->lock); + return NULL; +} + +static int fl_reoffload(struct tcf_proto *tp, bool add, flow_setup_cb_t *cb, + void *cb_priv, struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct flow_cls_offload cls_flower = {}; + struct cls_fl_filter *f = NULL; + int err; + + /* hw_filters list can only be changed by hw offload functions after + * obtaining rtnl lock. Make sure it is not changed while reoffload is + * iterating it. + */ + ASSERT_RTNL(); + + while ((f = fl_get_next_hw_filter(tp, f, add))) { + cls_flower.rule = + flow_rule_alloc(tcf_exts_num_actions(&f->exts)); + if (!cls_flower.rule) { + __fl_put(f); + return -ENOMEM; + } + + tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, + extack); + cls_flower.command = add ? 
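/* fl_walk() above iterates the handle IDR under RCU but never hands a
 * filter to the callback without first taking a reference, so the
 * callback may sleep and the filter cannot be freed underneath it.
 * Schematically:
 *
 *	rcu_read_lock();
 *	idr_for_each_entry_continue_ul(...) {
 *		if (!refcount_inc_not_zero(&f->refcnt))
 *			continue;	// filter is being deleted
 *		rcu_read_unlock();
 *		... invoke arg->fn(tp, f, arg) ...
 *		__fl_put(f);
 *		rcu_read_lock();
 *	}
 *	rcu_read_unlock();
 *
 * fl_get_next_hw_filter() applies the same refcount trick, under
 * tp->lock, to the hw_filters list used for reoffload.
 */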
+ FLOW_CLS_REPLACE : FLOW_CLS_DESTROY; + cls_flower.cookie = (unsigned long)f; + cls_flower.rule->match.dissector = &f->mask->dissector; + cls_flower.rule->match.mask = &f->mask->key; + cls_flower.rule->match.key = &f->mkey; + + err = tc_setup_offload_action(&cls_flower.rule->action, &f->exts, + cls_flower.common.extack); + if (err) { + kfree(cls_flower.rule); + if (tc_skip_sw(f->flags)) { + __fl_put(f); + return err; + } + goto next_flow; + } + + cls_flower.classid = f->res.classid; + + err = tc_setup_cb_reoffload(block, tp, add, cb, + TC_SETUP_CLSFLOWER, &cls_flower, + cb_priv, &f->flags, + &f->in_hw_count); + tc_cleanup_offload_action(&cls_flower.rule->action); + kfree(cls_flower.rule); + + if (err) { + __fl_put(f); + return err; + } +next_flow: + __fl_put(f); + } + + return 0; +} + +static void fl_hw_add(struct tcf_proto *tp, void *type_data) +{ + struct flow_cls_offload *cls_flower = type_data; + struct cls_fl_filter *f = + (struct cls_fl_filter *) cls_flower->cookie; + struct cls_fl_head *head = fl_head_dereference(tp); + + spin_lock(&tp->lock); + list_add(&f->hw_list, &head->hw_filters); + spin_unlock(&tp->lock); +} + +static void fl_hw_del(struct tcf_proto *tp, void *type_data) +{ + struct flow_cls_offload *cls_flower = type_data; + struct cls_fl_filter *f = + (struct cls_fl_filter *) cls_flower->cookie; + + spin_lock(&tp->lock); + if (!list_empty(&f->hw_list)) + list_del_init(&f->hw_list); + spin_unlock(&tp->lock); +} + +static int fl_hw_create_tmplt(struct tcf_chain *chain, + struct fl_flow_tmplt *tmplt) +{ + struct flow_cls_offload cls_flower = {}; + struct tcf_block *block = chain->block; + + cls_flower.rule = flow_rule_alloc(0); + if (!cls_flower.rule) + return -ENOMEM; + + cls_flower.common.chain_index = chain->index; + cls_flower.command = FLOW_CLS_TMPLT_CREATE; + cls_flower.cookie = (unsigned long) tmplt; + cls_flower.rule->match.dissector = &tmplt->dissector; + cls_flower.rule->match.mask = &tmplt->mask; + cls_flower.rule->match.key = &tmplt->dummy_key; + + /* We don't care if driver (any of them) fails to handle this + * call. It serves just as a hint for it. 
+ */ + tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false, true); + kfree(cls_flower.rule); + + return 0; +} + +static void fl_hw_destroy_tmplt(struct tcf_chain *chain, + struct fl_flow_tmplt *tmplt) +{ + struct flow_cls_offload cls_flower = {}; + struct tcf_block *block = chain->block; + + cls_flower.common.chain_index = chain->index; + cls_flower.command = FLOW_CLS_TMPLT_DESTROY; + cls_flower.cookie = (unsigned long) tmplt; + + tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false, true); +} + +static void *fl_tmplt_create(struct net *net, struct tcf_chain *chain, + struct nlattr **tca, + struct netlink_ext_ack *extack) +{ + struct fl_flow_tmplt *tmplt; + struct nlattr **tb; + int err; + + if (!tca[TCA_OPTIONS]) + return ERR_PTR(-EINVAL); + + tb = kcalloc(TCA_FLOWER_MAX + 1, sizeof(struct nlattr *), GFP_KERNEL); + if (!tb) + return ERR_PTR(-ENOBUFS); + err = nla_parse_nested_deprecated(tb, TCA_FLOWER_MAX, + tca[TCA_OPTIONS], fl_policy, NULL); + if (err) + goto errout_tb; + + tmplt = kzalloc(sizeof(*tmplt), GFP_KERNEL); + if (!tmplt) { + err = -ENOMEM; + goto errout_tb; + } + tmplt->chain = chain; + err = fl_set_key(net, tb, &tmplt->dummy_key, &tmplt->mask, extack); + if (err) + goto errout_tmplt; + + fl_init_dissector(&tmplt->dissector, &tmplt->mask); + + err = fl_hw_create_tmplt(chain, tmplt); + if (err) + goto errout_tmplt; + + kfree(tb); + return tmplt; + +errout_tmplt: + kfree(tmplt); +errout_tb: + kfree(tb); + return ERR_PTR(err); +} + +static void fl_tmplt_destroy(void *tmplt_priv) +{ + struct fl_flow_tmplt *tmplt = tmplt_priv; + + fl_hw_destroy_tmplt(tmplt->chain, tmplt); + kfree(tmplt); +} + +static int fl_dump_key_val(struct sk_buff *skb, + void *val, int val_type, + void *mask, int mask_type, int len) +{ + int err; + + if (!memchr_inv(mask, 0, len)) + return 0; + err = nla_put(skb, val_type, len, val); + if (err) + return err; + if (mask_type != TCA_FLOWER_UNSPEC) { + err = nla_put(skb, mask_type, len, mask); + if (err) + return err; + } + return 0; +} + +static int fl_dump_key_port_range(struct sk_buff *skb, struct fl_flow_key *key, + struct fl_flow_key *mask) +{ + if (fl_dump_key_val(skb, &key->tp_range.tp_min.dst, + TCA_FLOWER_KEY_PORT_DST_MIN, + &mask->tp_range.tp_min.dst, TCA_FLOWER_UNSPEC, + sizeof(key->tp_range.tp_min.dst)) || + fl_dump_key_val(skb, &key->tp_range.tp_max.dst, + TCA_FLOWER_KEY_PORT_DST_MAX, + &mask->tp_range.tp_max.dst, TCA_FLOWER_UNSPEC, + sizeof(key->tp_range.tp_max.dst)) || + fl_dump_key_val(skb, &key->tp_range.tp_min.src, + TCA_FLOWER_KEY_PORT_SRC_MIN, + &mask->tp_range.tp_min.src, TCA_FLOWER_UNSPEC, + sizeof(key->tp_range.tp_min.src)) || + fl_dump_key_val(skb, &key->tp_range.tp_max.src, + TCA_FLOWER_KEY_PORT_SRC_MAX, + &mask->tp_range.tp_max.src, TCA_FLOWER_UNSPEC, + sizeof(key->tp_range.tp_max.src))) + return -1; + + return 0; +} + +static int fl_dump_key_mpls_opt_lse(struct sk_buff *skb, + struct flow_dissector_key_mpls *mpls_key, + struct flow_dissector_key_mpls *mpls_mask, + u8 lse_index) +{ + struct flow_dissector_mpls_lse *lse_mask = &mpls_mask->ls[lse_index]; + struct flow_dissector_mpls_lse *lse_key = &mpls_key->ls[lse_index]; + int err; + + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_OPT_LSE_DEPTH, + lse_index + 1); + if (err) + return err; + + if (lse_mask->mpls_ttl) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_OPT_LSE_TTL, + lse_key->mpls_ttl); + if (err) + return err; + } + if (lse_mask->mpls_bos) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_OPT_LSE_BOS, + lse_key->mpls_bos); + if (err) + return err; + } + if 
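/* fl_dump_key_val() above is the mirror image of fl_set_key_val(): a
 * field whose mask is all-zero is skipped entirely, and the *_MASK
 * attribute is only emitted when the field has a dedicated mask type
 * (i.e. mask_type != TCA_FLOWER_UNSPEC).  So only fields with a
 * non-zero mask appear in a dump, e.g.:
 *
 *	fl_dump_key_val(skb, &key->tp.dst, TCA_FLOWER_KEY_TCP_DST,
 *			&mask->tp.dst, TCA_FLOWER_KEY_TCP_DST_MASK,
 *			sizeof(key->tp.dst));
 *	// emits nothing if mask->tp.dst == 0,
 *	// TCP_DST plus TCP_DST_MASK otherwise
 */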
(lse_mask->mpls_tc) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_OPT_LSE_TC, + lse_key->mpls_tc); + if (err) + return err; + } + if (lse_mask->mpls_label) { + err = nla_put_u32(skb, TCA_FLOWER_KEY_MPLS_OPT_LSE_LABEL, + lse_key->mpls_label); + if (err) + return err; + } + + return 0; +} + +static int fl_dump_key_mpls_opts(struct sk_buff *skb, + struct flow_dissector_key_mpls *mpls_key, + struct flow_dissector_key_mpls *mpls_mask) +{ + struct nlattr *opts; + struct nlattr *lse; + u8 lse_index; + int err; + + opts = nla_nest_start(skb, TCA_FLOWER_KEY_MPLS_OPTS); + if (!opts) + return -EMSGSIZE; + + for (lse_index = 0; lse_index < FLOW_DIS_MPLS_MAX; lse_index++) { + if (!(mpls_mask->used_lses & 1 << lse_index)) + continue; + + lse = nla_nest_start(skb, TCA_FLOWER_KEY_MPLS_OPTS_LSE); + if (!lse) { + err = -EMSGSIZE; + goto err_opts; + } + + err = fl_dump_key_mpls_opt_lse(skb, mpls_key, mpls_mask, + lse_index); + if (err) + goto err_opts_lse; + nla_nest_end(skb, lse); + } + nla_nest_end(skb, opts); + + return 0; + +err_opts_lse: + nla_nest_cancel(skb, lse); +err_opts: + nla_nest_cancel(skb, opts); + + return err; +} + +static int fl_dump_key_mpls(struct sk_buff *skb, + struct flow_dissector_key_mpls *mpls_key, + struct flow_dissector_key_mpls *mpls_mask) +{ + struct flow_dissector_mpls_lse *lse_mask; + struct flow_dissector_mpls_lse *lse_key; + int err; + + if (!mpls_mask->used_lses) + return 0; + + lse_mask = &mpls_mask->ls[0]; + lse_key = &mpls_key->ls[0]; + + /* For backward compatibility, don't use the MPLS nested attributes if + * the rule can be expressed using the old attributes. + */ + if (mpls_mask->used_lses & ~1 || + (!lse_mask->mpls_ttl && !lse_mask->mpls_bos && + !lse_mask->mpls_tc && !lse_mask->mpls_label)) + return fl_dump_key_mpls_opts(skb, mpls_key, mpls_mask); + + if (lse_mask->mpls_ttl) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_TTL, + lse_key->mpls_ttl); + if (err) + return err; + } + if (lse_mask->mpls_tc) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_TC, + lse_key->mpls_tc); + if (err) + return err; + } + if (lse_mask->mpls_label) { + err = nla_put_u32(skb, TCA_FLOWER_KEY_MPLS_LABEL, + lse_key->mpls_label); + if (err) + return err; + } + if (lse_mask->mpls_bos) { + err = nla_put_u8(skb, TCA_FLOWER_KEY_MPLS_BOS, + lse_key->mpls_bos); + if (err) + return err; + } + return 0; +} + +static int fl_dump_key_ip(struct sk_buff *skb, bool encap, + struct flow_dissector_key_ip *key, + struct flow_dissector_key_ip *mask) +{ + int tos_key = encap ? TCA_FLOWER_KEY_ENC_IP_TOS : TCA_FLOWER_KEY_IP_TOS; + int ttl_key = encap ? TCA_FLOWER_KEY_ENC_IP_TTL : TCA_FLOWER_KEY_IP_TTL; + int tos_mask = encap ? TCA_FLOWER_KEY_ENC_IP_TOS_MASK : TCA_FLOWER_KEY_IP_TOS_MASK; + int ttl_mask = encap ? 
TCA_FLOWER_KEY_ENC_IP_TTL_MASK : TCA_FLOWER_KEY_IP_TTL_MASK; + + if (fl_dump_key_val(skb, &key->tos, tos_key, &mask->tos, tos_mask, sizeof(key->tos)) || + fl_dump_key_val(skb, &key->ttl, ttl_key, &mask->ttl, ttl_mask, sizeof(key->ttl))) + return -1; + + return 0; +} + +static int fl_dump_key_vlan(struct sk_buff *skb, + int vlan_id_key, int vlan_prio_key, + struct flow_dissector_key_vlan *vlan_key, + struct flow_dissector_key_vlan *vlan_mask) +{ + int err; + + if (!memchr_inv(vlan_mask, 0, sizeof(*vlan_mask))) + return 0; + if (vlan_mask->vlan_id) { + err = nla_put_u16(skb, vlan_id_key, + vlan_key->vlan_id); + if (err) + return err; + } + if (vlan_mask->vlan_priority) { + err = nla_put_u8(skb, vlan_prio_key, + vlan_key->vlan_priority); + if (err) + return err; + } + return 0; +} + +static void fl_get_key_flag(u32 dissector_key, u32 dissector_mask, + u32 *flower_key, u32 *flower_mask, + u32 flower_flag_bit, u32 dissector_flag_bit) +{ + if (dissector_mask & dissector_flag_bit) { + *flower_mask |= flower_flag_bit; + if (dissector_key & dissector_flag_bit) + *flower_key |= flower_flag_bit; + } +} + +static int fl_dump_key_flags(struct sk_buff *skb, u32 flags_key, u32 flags_mask) +{ + u32 key, mask; + __be32 _key, _mask; + int err; + + if (!memchr_inv(&flags_mask, 0, sizeof(flags_mask))) + return 0; + + key = 0; + mask = 0; + + fl_get_key_flag(flags_key, flags_mask, &key, &mask, + TCA_FLOWER_KEY_FLAGS_IS_FRAGMENT, FLOW_DIS_IS_FRAGMENT); + fl_get_key_flag(flags_key, flags_mask, &key, &mask, + TCA_FLOWER_KEY_FLAGS_FRAG_IS_FIRST, + FLOW_DIS_FIRST_FRAG); + + _key = cpu_to_be32(key); + _mask = cpu_to_be32(mask); + + err = nla_put(skb, TCA_FLOWER_KEY_FLAGS, 4, &_key); + if (err) + return err; + + return nla_put(skb, TCA_FLOWER_KEY_FLAGS_MASK, 4, &_mask); +} + +static int fl_dump_key_geneve_opt(struct sk_buff *skb, + struct flow_dissector_key_enc_opts *enc_opts) +{ + struct geneve_opt *opt; + struct nlattr *nest; + int opt_off = 0; + + nest = nla_nest_start_noflag(skb, TCA_FLOWER_KEY_ENC_OPTS_GENEVE); + if (!nest) + goto nla_put_failure; + + while (enc_opts->len > opt_off) { + opt = (struct geneve_opt *)&enc_opts->data[opt_off]; + + if (nla_put_be16(skb, TCA_FLOWER_KEY_ENC_OPT_GENEVE_CLASS, + opt->opt_class)) + goto nla_put_failure; + if (nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_GENEVE_TYPE, + opt->type)) + goto nla_put_failure; + if (nla_put(skb, TCA_FLOWER_KEY_ENC_OPT_GENEVE_DATA, + opt->length * 4, opt->opt_data)) + goto nla_put_failure; + + opt_off += sizeof(struct geneve_opt) + opt->length * 4; + } + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fl_dump_key_vxlan_opt(struct sk_buff *skb, + struct flow_dissector_key_enc_opts *enc_opts) +{ + struct vxlan_metadata *md; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_FLOWER_KEY_ENC_OPTS_VXLAN); + if (!nest) + goto nla_put_failure; + + md = (struct vxlan_metadata *)&enc_opts->data[0]; + if (nla_put_u32(skb, TCA_FLOWER_KEY_ENC_OPT_VXLAN_GBP, md->gbp)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fl_dump_key_erspan_opt(struct sk_buff *skb, + struct flow_dissector_key_enc_opts *enc_opts) +{ + struct erspan_metadata *md; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_FLOWER_KEY_ENC_OPTS_ERSPAN); + if (!nest) + goto nla_put_failure; + + md = (struct erspan_metadata *)&enc_opts->data[0]; + if (nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_ERSPAN_VER, 
md->version)) + goto nla_put_failure; + + if (md->version == 1 && + nla_put_be32(skb, TCA_FLOWER_KEY_ENC_OPT_ERSPAN_INDEX, md->u.index)) + goto nla_put_failure; + + if (md->version == 2 && + (nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_ERSPAN_DIR, + md->u.md2.dir) || + nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_ERSPAN_HWID, + get_hwid(&md->u.md2)))) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fl_dump_key_gtp_opt(struct sk_buff *skb, + struct flow_dissector_key_enc_opts *enc_opts) + +{ + struct gtp_pdu_session_info *session_info; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_FLOWER_KEY_ENC_OPTS_GTP); + if (!nest) + goto nla_put_failure; + + session_info = (struct gtp_pdu_session_info *)&enc_opts->data[0]; + + if (nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_GTP_PDU_TYPE, + session_info->pdu_type)) + goto nla_put_failure; + + if (nla_put_u8(skb, TCA_FLOWER_KEY_ENC_OPT_GTP_QFI, session_info->qfi)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fl_dump_key_ct(struct sk_buff *skb, + struct flow_dissector_key_ct *key, + struct flow_dissector_key_ct *mask) +{ + if (IS_ENABLED(CONFIG_NF_CONNTRACK) && + fl_dump_key_val(skb, &key->ct_state, TCA_FLOWER_KEY_CT_STATE, + &mask->ct_state, TCA_FLOWER_KEY_CT_STATE_MASK, + sizeof(key->ct_state))) + goto nla_put_failure; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) && + fl_dump_key_val(skb, &key->ct_zone, TCA_FLOWER_KEY_CT_ZONE, + &mask->ct_zone, TCA_FLOWER_KEY_CT_ZONE_MASK, + sizeof(key->ct_zone))) + goto nla_put_failure; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && + fl_dump_key_val(skb, &key->ct_mark, TCA_FLOWER_KEY_CT_MARK, + &mask->ct_mark, TCA_FLOWER_KEY_CT_MARK_MASK, + sizeof(key->ct_mark))) + goto nla_put_failure; + + if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) && + fl_dump_key_val(skb, &key->ct_labels, TCA_FLOWER_KEY_CT_LABELS, + &mask->ct_labels, TCA_FLOWER_KEY_CT_LABELS_MASK, + sizeof(key->ct_labels))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int fl_dump_key_options(struct sk_buff *skb, int enc_opt_type, + struct flow_dissector_key_enc_opts *enc_opts) +{ + struct nlattr *nest; + int err; + + if (!enc_opts->len) + return 0; + + nest = nla_nest_start_noflag(skb, enc_opt_type); + if (!nest) + goto nla_put_failure; + + switch (enc_opts->dst_opt_type) { + case TUNNEL_GENEVE_OPT: + err = fl_dump_key_geneve_opt(skb, enc_opts); + if (err) + goto nla_put_failure; + break; + case TUNNEL_VXLAN_OPT: + err = fl_dump_key_vxlan_opt(skb, enc_opts); + if (err) + goto nla_put_failure; + break; + case TUNNEL_ERSPAN_OPT: + err = fl_dump_key_erspan_opt(skb, enc_opts); + if (err) + goto nla_put_failure; + break; + case TUNNEL_GTP_OPT: + err = fl_dump_key_gtp_opt(skb, enc_opts); + if (err) + goto nla_put_failure; + break; + default: + goto nla_put_failure; + } + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int fl_dump_key_enc_opt(struct sk_buff *skb, + struct flow_dissector_key_enc_opts *key_opts, + struct flow_dissector_key_enc_opts *msk_opts) +{ + int err; + + err = fl_dump_key_options(skb, TCA_FLOWER_KEY_ENC_OPTS, key_opts); + if (err) + return err; + + return fl_dump_key_options(skb, TCA_FLOWER_KEY_ENC_OPTS_MASK, msk_opts); +} + +static int fl_dump_key(struct sk_buff *skb, struct net *net, + struct fl_flow_key *key, struct 
fl_flow_key *mask) +{ + if (mask->meta.ingress_ifindex) { + struct net_device *dev; + + dev = __dev_get_by_index(net, key->meta.ingress_ifindex); + if (dev && nla_put_string(skb, TCA_FLOWER_INDEV, dev->name)) + goto nla_put_failure; + } + + if (fl_dump_key_val(skb, key->eth.dst, TCA_FLOWER_KEY_ETH_DST, + mask->eth.dst, TCA_FLOWER_KEY_ETH_DST_MASK, + sizeof(key->eth.dst)) || + fl_dump_key_val(skb, key->eth.src, TCA_FLOWER_KEY_ETH_SRC, + mask->eth.src, TCA_FLOWER_KEY_ETH_SRC_MASK, + sizeof(key->eth.src)) || + fl_dump_key_val(skb, &key->basic.n_proto, TCA_FLOWER_KEY_ETH_TYPE, + &mask->basic.n_proto, TCA_FLOWER_UNSPEC, + sizeof(key->basic.n_proto))) + goto nla_put_failure; + + if (mask->num_of_vlans.num_of_vlans) { + if (nla_put_u8(skb, TCA_FLOWER_KEY_NUM_OF_VLANS, key->num_of_vlans.num_of_vlans)) + goto nla_put_failure; + } + + if (fl_dump_key_mpls(skb, &key->mpls, &mask->mpls)) + goto nla_put_failure; + + if (fl_dump_key_vlan(skb, TCA_FLOWER_KEY_VLAN_ID, + TCA_FLOWER_KEY_VLAN_PRIO, &key->vlan, &mask->vlan)) + goto nla_put_failure; + + if (fl_dump_key_vlan(skb, TCA_FLOWER_KEY_CVLAN_ID, + TCA_FLOWER_KEY_CVLAN_PRIO, + &key->cvlan, &mask->cvlan) || + (mask->cvlan.vlan_tpid && + nla_put_be16(skb, TCA_FLOWER_KEY_VLAN_ETH_TYPE, + key->cvlan.vlan_tpid))) + goto nla_put_failure; + + if (mask->basic.n_proto) { + if (mask->cvlan.vlan_eth_type) { + if (nla_put_be16(skb, TCA_FLOWER_KEY_CVLAN_ETH_TYPE, + key->basic.n_proto)) + goto nla_put_failure; + } else if (mask->vlan.vlan_eth_type) { + if (nla_put_be16(skb, TCA_FLOWER_KEY_VLAN_ETH_TYPE, + key->vlan.vlan_eth_type)) + goto nla_put_failure; + } + } + + if ((key->basic.n_proto == htons(ETH_P_IP) || + key->basic.n_proto == htons(ETH_P_IPV6)) && + (fl_dump_key_val(skb, &key->basic.ip_proto, TCA_FLOWER_KEY_IP_PROTO, + &mask->basic.ip_proto, TCA_FLOWER_UNSPEC, + sizeof(key->basic.ip_proto)) || + fl_dump_key_ip(skb, false, &key->ip, &mask->ip))) + goto nla_put_failure; + + if (mask->pppoe.session_id) { + if (nla_put_be16(skb, TCA_FLOWER_KEY_PPPOE_SID, + key->pppoe.session_id)) + goto nla_put_failure; + } + if (mask->basic.n_proto && mask->pppoe.ppp_proto) { + if (nla_put_be16(skb, TCA_FLOWER_KEY_PPP_PROTO, + key->pppoe.ppp_proto)) + goto nla_put_failure; + } + + if (key->control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS && + (fl_dump_key_val(skb, &key->ipv4.src, TCA_FLOWER_KEY_IPV4_SRC, + &mask->ipv4.src, TCA_FLOWER_KEY_IPV4_SRC_MASK, + sizeof(key->ipv4.src)) || + fl_dump_key_val(skb, &key->ipv4.dst, TCA_FLOWER_KEY_IPV4_DST, + &mask->ipv4.dst, TCA_FLOWER_KEY_IPV4_DST_MASK, + sizeof(key->ipv4.dst)))) + goto nla_put_failure; + else if (key->control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS && + (fl_dump_key_val(skb, &key->ipv6.src, TCA_FLOWER_KEY_IPV6_SRC, + &mask->ipv6.src, TCA_FLOWER_KEY_IPV6_SRC_MASK, + sizeof(key->ipv6.src)) || + fl_dump_key_val(skb, &key->ipv6.dst, TCA_FLOWER_KEY_IPV6_DST, + &mask->ipv6.dst, TCA_FLOWER_KEY_IPV6_DST_MASK, + sizeof(key->ipv6.dst)))) + goto nla_put_failure; + + if (key->basic.ip_proto == IPPROTO_TCP && + (fl_dump_key_val(skb, &key->tp.src, TCA_FLOWER_KEY_TCP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_TCP_SRC_MASK, + sizeof(key->tp.src)) || + fl_dump_key_val(skb, &key->tp.dst, TCA_FLOWER_KEY_TCP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_TCP_DST_MASK, + sizeof(key->tp.dst)) || + fl_dump_key_val(skb, &key->tcp.flags, TCA_FLOWER_KEY_TCP_FLAGS, + &mask->tcp.flags, TCA_FLOWER_KEY_TCP_FLAGS_MASK, + sizeof(key->tcp.flags)))) + goto nla_put_failure; + else if (key->basic.ip_proto == IPPROTO_UDP && + (fl_dump_key_val(skb, &key->tp.src, 
TCA_FLOWER_KEY_UDP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_UDP_SRC_MASK, + sizeof(key->tp.src)) || + fl_dump_key_val(skb, &key->tp.dst, TCA_FLOWER_KEY_UDP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_UDP_DST_MASK, + sizeof(key->tp.dst)))) + goto nla_put_failure; + else if (key->basic.ip_proto == IPPROTO_SCTP && + (fl_dump_key_val(skb, &key->tp.src, TCA_FLOWER_KEY_SCTP_SRC, + &mask->tp.src, TCA_FLOWER_KEY_SCTP_SRC_MASK, + sizeof(key->tp.src)) || + fl_dump_key_val(skb, &key->tp.dst, TCA_FLOWER_KEY_SCTP_DST, + &mask->tp.dst, TCA_FLOWER_KEY_SCTP_DST_MASK, + sizeof(key->tp.dst)))) + goto nla_put_failure; + else if (key->basic.n_proto == htons(ETH_P_IP) && + key->basic.ip_proto == IPPROTO_ICMP && + (fl_dump_key_val(skb, &key->icmp.type, + TCA_FLOWER_KEY_ICMPV4_TYPE, &mask->icmp.type, + TCA_FLOWER_KEY_ICMPV4_TYPE_MASK, + sizeof(key->icmp.type)) || + fl_dump_key_val(skb, &key->icmp.code, + TCA_FLOWER_KEY_ICMPV4_CODE, &mask->icmp.code, + TCA_FLOWER_KEY_ICMPV4_CODE_MASK, + sizeof(key->icmp.code)))) + goto nla_put_failure; + else if (key->basic.n_proto == htons(ETH_P_IPV6) && + key->basic.ip_proto == IPPROTO_ICMPV6 && + (fl_dump_key_val(skb, &key->icmp.type, + TCA_FLOWER_KEY_ICMPV6_TYPE, &mask->icmp.type, + TCA_FLOWER_KEY_ICMPV6_TYPE_MASK, + sizeof(key->icmp.type)) || + fl_dump_key_val(skb, &key->icmp.code, + TCA_FLOWER_KEY_ICMPV6_CODE, &mask->icmp.code, + TCA_FLOWER_KEY_ICMPV6_CODE_MASK, + sizeof(key->icmp.code)))) + goto nla_put_failure; + else if ((key->basic.n_proto == htons(ETH_P_ARP) || + key->basic.n_proto == htons(ETH_P_RARP)) && + (fl_dump_key_val(skb, &key->arp.sip, + TCA_FLOWER_KEY_ARP_SIP, &mask->arp.sip, + TCA_FLOWER_KEY_ARP_SIP_MASK, + sizeof(key->arp.sip)) || + fl_dump_key_val(skb, &key->arp.tip, + TCA_FLOWER_KEY_ARP_TIP, &mask->arp.tip, + TCA_FLOWER_KEY_ARP_TIP_MASK, + sizeof(key->arp.tip)) || + fl_dump_key_val(skb, &key->arp.op, + TCA_FLOWER_KEY_ARP_OP, &mask->arp.op, + TCA_FLOWER_KEY_ARP_OP_MASK, + sizeof(key->arp.op)) || + fl_dump_key_val(skb, key->arp.sha, TCA_FLOWER_KEY_ARP_SHA, + mask->arp.sha, TCA_FLOWER_KEY_ARP_SHA_MASK, + sizeof(key->arp.sha)) || + fl_dump_key_val(skb, key->arp.tha, TCA_FLOWER_KEY_ARP_THA, + mask->arp.tha, TCA_FLOWER_KEY_ARP_THA_MASK, + sizeof(key->arp.tha)))) + goto nla_put_failure; + else if (key->basic.ip_proto == IPPROTO_L2TP && + fl_dump_key_val(skb, &key->l2tpv3.session_id, + TCA_FLOWER_KEY_L2TPV3_SID, + &mask->l2tpv3.session_id, + TCA_FLOWER_UNSPEC, + sizeof(key->l2tpv3.session_id))) + goto nla_put_failure; + + if ((key->basic.ip_proto == IPPROTO_TCP || + key->basic.ip_proto == IPPROTO_UDP || + key->basic.ip_proto == IPPROTO_SCTP) && + fl_dump_key_port_range(skb, key, mask)) + goto nla_put_failure; + + if (key->enc_control.addr_type == FLOW_DISSECTOR_KEY_IPV4_ADDRS && + (fl_dump_key_val(skb, &key->enc_ipv4.src, + TCA_FLOWER_KEY_ENC_IPV4_SRC, &mask->enc_ipv4.src, + TCA_FLOWER_KEY_ENC_IPV4_SRC_MASK, + sizeof(key->enc_ipv4.src)) || + fl_dump_key_val(skb, &key->enc_ipv4.dst, + TCA_FLOWER_KEY_ENC_IPV4_DST, &mask->enc_ipv4.dst, + TCA_FLOWER_KEY_ENC_IPV4_DST_MASK, + sizeof(key->enc_ipv4.dst)))) + goto nla_put_failure; + else if (key->enc_control.addr_type == FLOW_DISSECTOR_KEY_IPV6_ADDRS && + (fl_dump_key_val(skb, &key->enc_ipv6.src, + TCA_FLOWER_KEY_ENC_IPV6_SRC, &mask->enc_ipv6.src, + TCA_FLOWER_KEY_ENC_IPV6_SRC_MASK, + sizeof(key->enc_ipv6.src)) || + fl_dump_key_val(skb, &key->enc_ipv6.dst, + TCA_FLOWER_KEY_ENC_IPV6_DST, + &mask->enc_ipv6.dst, + TCA_FLOWER_KEY_ENC_IPV6_DST_MASK, + sizeof(key->enc_ipv6.dst)))) + goto nla_put_failure; + + if (fl_dump_key_val(skb, 
&key->enc_key_id, TCA_FLOWER_KEY_ENC_KEY_ID, + &mask->enc_key_id, TCA_FLOWER_UNSPEC, + sizeof(key->enc_key_id)) || + fl_dump_key_val(skb, &key->enc_tp.src, + TCA_FLOWER_KEY_ENC_UDP_SRC_PORT, + &mask->enc_tp.src, + TCA_FLOWER_KEY_ENC_UDP_SRC_PORT_MASK, + sizeof(key->enc_tp.src)) || + fl_dump_key_val(skb, &key->enc_tp.dst, + TCA_FLOWER_KEY_ENC_UDP_DST_PORT, + &mask->enc_tp.dst, + TCA_FLOWER_KEY_ENC_UDP_DST_PORT_MASK, + sizeof(key->enc_tp.dst)) || + fl_dump_key_ip(skb, true, &key->enc_ip, &mask->enc_ip) || + fl_dump_key_enc_opt(skb, &key->enc_opts, &mask->enc_opts)) + goto nla_put_failure; + + if (fl_dump_key_ct(skb, &key->ct, &mask->ct)) + goto nla_put_failure; + + if (fl_dump_key_flags(skb, key->control.flags, mask->control.flags)) + goto nla_put_failure; + + if (fl_dump_key_val(skb, &key->hash.hash, TCA_FLOWER_KEY_HASH, + &mask->hash.hash, TCA_FLOWER_KEY_HASH_MASK, + sizeof(key->hash.hash))) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct cls_fl_filter *f = fh; + struct nlattr *nest; + struct fl_flow_key *key, *mask; + bool skip_hw; + + if (!f) + return skb->len; + + t->tcm_handle = f->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + spin_lock(&tp->lock); + + if (f->res.classid && + nla_put_u32(skb, TCA_FLOWER_CLASSID, f->res.classid)) + goto nla_put_failure_locked; + + key = &f->key; + mask = &f->mask->key; + skip_hw = tc_skip_hw(f->flags); + + if (fl_dump_key(skb, net, key, mask)) + goto nla_put_failure_locked; + + if (f->flags && nla_put_u32(skb, TCA_FLOWER_FLAGS, f->flags)) + goto nla_put_failure_locked; + + spin_unlock(&tp->lock); + + if (!skip_hw) + fl_hw_update_stats(tp, f, rtnl_held); + + if (nla_put_u32(skb, TCA_FLOWER_IN_HW_COUNT, f->in_hw_count)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &f->exts)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &f->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure_locked: + spin_unlock(&tp->lock); +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int fl_terse_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct cls_fl_filter *f = fh; + struct nlattr *nest; + bool skip_hw; + + if (!f) + return skb->len; + + t->tcm_handle = f->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + spin_lock(&tp->lock); + + skip_hw = tc_skip_hw(f->flags); + + if (f->flags && nla_put_u32(skb, TCA_FLOWER_FLAGS, f->flags)) + goto nla_put_failure_locked; + + spin_unlock(&tp->lock); + + if (!skip_hw) + fl_hw_update_stats(tp, f, rtnl_held); + + if (tcf_exts_terse_dump(skb, &f->exts)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + return skb->len; + +nla_put_failure_locked: + spin_unlock(&tp->lock); +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int fl_tmplt_dump(struct sk_buff *skb, struct net *net, void *tmplt_priv) +{ + struct fl_flow_tmplt *tmplt = tmplt_priv; + struct fl_flow_key *key, *mask; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + key = &tmplt->dummy_key; + mask = &tmplt->mask; + + if (fl_dump_key(skb, net, key, mask)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + return skb->len; + +nla_put_failure: + 
nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static void fl_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct cls_fl_filter *f = fh; + + tc_cls_bind_class(classid, cl, q, &f->res, base); +} + +static bool fl_delete_empty(struct tcf_proto *tp) +{ + struct cls_fl_head *head = fl_head_dereference(tp); + + spin_lock(&tp->lock); + tp->deleting = idr_is_empty(&head->handle_idr); + spin_unlock(&tp->lock); + + return tp->deleting; +} + +static struct tcf_proto_ops cls_fl_ops __read_mostly = { + .kind = "flower", + .classify = fl_classify, + .init = fl_init, + .destroy = fl_destroy, + .get = fl_get, + .put = fl_put, + .change = fl_change, + .delete = fl_delete, + .delete_empty = fl_delete_empty, + .walk = fl_walk, + .reoffload = fl_reoffload, + .hw_add = fl_hw_add, + .hw_del = fl_hw_del, + .dump = fl_dump, + .terse_dump = fl_terse_dump, + .bind_class = fl_bind_class, + .tmplt_create = fl_tmplt_create, + .tmplt_destroy = fl_tmplt_destroy, + .tmplt_dump = fl_tmplt_dump, + .owner = THIS_MODULE, + .flags = TCF_PROTO_OPS_DOIT_UNLOCKED, +}; + +static int __init cls_fl_init(void) +{ + return register_tcf_proto_ops(&cls_fl_ops); +} + +static void __exit cls_fl_exit(void) +{ + unregister_tcf_proto_ops(&cls_fl_ops); +} + +module_init(cls_fl_init); +module_exit(cls_fl_exit); + +MODULE_AUTHOR("Jiri Pirko "); +MODULE_DESCRIPTION("Flower classifier"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/cls_fw.c b/net/sched/cls_fw.c new file mode 100644 index 000000000..6160ef7d6 --- /dev/null +++ b/net/sched/cls_fw.c @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_fw.c Classifier mapping ipchains' fwmark to traffic class. + * + * Authors: Alexey Kuznetsov, + * + * Changes: + * Karlis Peisenieks : 990415 : fw_walk off by one + * Karlis Peisenieks : 990415 : fw_delete killed all the filter (and kernel). + * Alex : 2004xxyy: Added Action extension + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define HTSIZE 256 + +struct fw_head { + u32 mask; + struct fw_filter __rcu *ht[HTSIZE]; + struct rcu_head rcu; +}; + +struct fw_filter { + struct fw_filter __rcu *next; + u32 id; + struct tcf_result res; + int ifindex; + struct tcf_exts exts; + struct tcf_proto *tp; + struct rcu_work rwork; +}; + +static u32 fw_hash(u32 handle) +{ + handle ^= (handle >> 16); + handle ^= (handle >> 8); + return handle % HTSIZE; +} + +static int fw_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct fw_head *head = rcu_dereference_bh(tp->root); + struct fw_filter *f; + int r; + u32 id = skb->mark; + + if (head != NULL) { + id &= head->mask; + + for (f = rcu_dereference_bh(head->ht[fw_hash(id)]); f; + f = rcu_dereference_bh(f->next)) { + if (f->id == id) { + *res = f->res; + if (!tcf_match_indev(skb, f->ifindex)) + continue; + r = tcf_exts_exec(skb, &f->exts, res); + if (r < 0) + continue; + + return r; + } + } + } else { + struct Qdisc *q = tcf_block_q(tp->chain->block); + + /* Old method: classify the packet using its skb mark. 
*/ + if (id && (TC_H_MAJ(id) == 0 || + !(TC_H_MAJ(id ^ q->handle)))) { + res->classid = id; + res->class = 0; + return 0; + } + } + + return -1; +} + +static void *fw_get(struct tcf_proto *tp, u32 handle) +{ + struct fw_head *head = rtnl_dereference(tp->root); + struct fw_filter *f; + + if (head == NULL) + return NULL; + + f = rtnl_dereference(head->ht[fw_hash(handle)]); + for (; f; f = rtnl_dereference(f->next)) { + if (f->id == handle) + return f; + } + return NULL; +} + +static int fw_init(struct tcf_proto *tp) +{ + /* We don't allocate fw_head here, because in the old method + * we don't need it at all. + */ + return 0; +} + +static void __fw_delete_filter(struct fw_filter *f) +{ + tcf_exts_destroy(&f->exts); + tcf_exts_put_net(&f->exts); + kfree(f); +} + +static void fw_delete_filter_work(struct work_struct *work) +{ + struct fw_filter *f = container_of(to_rcu_work(work), + struct fw_filter, + rwork); + rtnl_lock(); + __fw_delete_filter(f); + rtnl_unlock(); +} + +static void fw_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct fw_head *head = rtnl_dereference(tp->root); + struct fw_filter *f; + int h; + + if (head == NULL) + return; + + for (h = 0; h < HTSIZE; h++) { + while ((f = rtnl_dereference(head->ht[h])) != NULL) { + RCU_INIT_POINTER(head->ht[h], + rtnl_dereference(f->next)); + tcf_unbind_filter(tp, &f->res); + if (tcf_exts_get_net(&f->exts)) + tcf_queue_work(&f->rwork, fw_delete_filter_work); + else + __fw_delete_filter(f); + } + } + kfree_rcu(head, rcu); +} + +static int fw_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct fw_head *head = rtnl_dereference(tp->root); + struct fw_filter *f = arg; + struct fw_filter __rcu **fp; + struct fw_filter *pfp; + int ret = -EINVAL; + int h; + + if (head == NULL || f == NULL) + goto out; + + fp = &head->ht[fw_hash(f->id)]; + + for (pfp = rtnl_dereference(*fp); pfp; + fp = &pfp->next, pfp = rtnl_dereference(*fp)) { + if (pfp == f) { + RCU_INIT_POINTER(*fp, rtnl_dereference(f->next)); + tcf_unbind_filter(tp, &f->res); + tcf_exts_get_net(&f->exts); + tcf_queue_work(&f->rwork, fw_delete_filter_work); + ret = 0; + break; + } + } + + *last = true; + for (h = 0; h < HTSIZE; h++) { + if (rcu_access_pointer(head->ht[h])) { + *last = false; + break; + } + } + +out: + return ret; +} + +static const struct nla_policy fw_policy[TCA_FW_MAX + 1] = { + [TCA_FW_CLASSID] = { .type = NLA_U32 }, + [TCA_FW_INDEV] = { .type = NLA_STRING, .len = IFNAMSIZ }, + [TCA_FW_MASK] = { .type = NLA_U32 }, +}; + +static int fw_set_parms(struct net *net, struct tcf_proto *tp, + struct fw_filter *f, struct nlattr **tb, + struct nlattr **tca, unsigned long base, u32 flags, + struct netlink_ext_ack *extack) +{ + struct fw_head *head = rtnl_dereference(tp->root); + u32 mask; + int err; + + err = tcf_exts_validate(net, tp, tb, tca[TCA_RATE], &f->exts, flags, + extack); + if (err < 0) + return err; + + if (tb[TCA_FW_INDEV]) { + int ret; + ret = tcf_change_indev(net, tb[TCA_FW_INDEV], extack); + if (ret < 0) + return ret; + f->ifindex = ret; + } + + err = -EINVAL; + if (tb[TCA_FW_MASK]) { + mask = nla_get_u32(tb[TCA_FW_MASK]); + if (mask != head->mask) + return err; + } else if (head->mask != 0xFFFFFFFF) + return err; + + if (tb[TCA_FW_CLASSID]) { + f->res.classid = nla_get_u32(tb[TCA_FW_CLASSID]); + tcf_bind_filter(tp, &f->res, base); + } + + return 0; +} + +static int fw_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 
handle, struct nlattr **tca, void **arg, + u32 flags, struct netlink_ext_ack *extack) +{ + struct fw_head *head = rtnl_dereference(tp->root); + struct fw_filter *f = *arg; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_FW_MAX + 1]; + int err; + + if (!opt) + return handle ? -EINVAL : 0; /* Succeed if it is old method. */ + + err = nla_parse_nested_deprecated(tb, TCA_FW_MAX, opt, fw_policy, + NULL); + if (err < 0) + return err; + + if (f) { + struct fw_filter *pfp, *fnew; + struct fw_filter __rcu **fp; + + if (f->id != handle && handle) + return -EINVAL; + + fnew = kzalloc(sizeof(struct fw_filter), GFP_KERNEL); + if (!fnew) + return -ENOBUFS; + + fnew->id = f->id; + fnew->ifindex = f->ifindex; + fnew->tp = f->tp; + + err = tcf_exts_init(&fnew->exts, net, TCA_FW_ACT, + TCA_FW_POLICE); + if (err < 0) { + kfree(fnew); + return err; + } + + err = fw_set_parms(net, tp, fnew, tb, tca, base, flags, extack); + if (err < 0) { + tcf_exts_destroy(&fnew->exts); + kfree(fnew); + return err; + } + + fp = &head->ht[fw_hash(fnew->id)]; + for (pfp = rtnl_dereference(*fp); pfp; + fp = &pfp->next, pfp = rtnl_dereference(*fp)) + if (pfp == f) + break; + + RCU_INIT_POINTER(fnew->next, rtnl_dereference(pfp->next)); + rcu_assign_pointer(*fp, fnew); + tcf_unbind_filter(tp, &f->res); + tcf_exts_get_net(&f->exts); + tcf_queue_work(&f->rwork, fw_delete_filter_work); + + *arg = fnew; + return err; + } + + if (!handle) + return -EINVAL; + + if (!head) { + u32 mask = 0xFFFFFFFF; + if (tb[TCA_FW_MASK]) + mask = nla_get_u32(tb[TCA_FW_MASK]); + + head = kzalloc(sizeof(*head), GFP_KERNEL); + if (!head) + return -ENOBUFS; + head->mask = mask; + + rcu_assign_pointer(tp->root, head); + } + + f = kzalloc(sizeof(struct fw_filter), GFP_KERNEL); + if (f == NULL) + return -ENOBUFS; + + err = tcf_exts_init(&f->exts, net, TCA_FW_ACT, TCA_FW_POLICE); + if (err < 0) + goto errout; + f->id = handle; + f->tp = tp; + + err = fw_set_parms(net, tp, f, tb, tca, base, flags, extack); + if (err < 0) + goto errout; + + RCU_INIT_POINTER(f->next, head->ht[fw_hash(handle)]); + rcu_assign_pointer(head->ht[fw_hash(handle)], f); + + *arg = f; + return 0; + +errout: + tcf_exts_destroy(&f->exts); + kfree(f); + return err; +} + +static void fw_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct fw_head *head = rtnl_dereference(tp->root); + int h; + + if (head == NULL) + arg->stop = 1; + + if (arg->stop) + return; + + for (h = 0; h < HTSIZE; h++) { + struct fw_filter *f; + + for (f = rtnl_dereference(head->ht[h]); f; + f = rtnl_dereference(f->next)) { + if (!tc_cls_stats_dump(tp, arg, f)) + return; + } + } +} + +static int fw_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct fw_head *head = rtnl_dereference(tp->root); + struct fw_filter *f = fh; + struct nlattr *nest; + + if (f == NULL) + return skb->len; + + t->tcm_handle = f->id; + + if (!f->res.classid && !tcf_exts_has_actions(&f->exts)) + return skb->len; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (f->res.classid && + nla_put_u32(skb, TCA_FW_CLASSID, f->res.classid)) + goto nla_put_failure; + if (f->ifindex) { + struct net_device *dev; + dev = __dev_get_by_index(net, f->ifindex); + if (dev && nla_put_string(skb, TCA_FW_INDEV, dev->name)) + goto nla_put_failure; + } + if (head->mask != 0xFFFFFFFF && + nla_put_u32(skb, TCA_FW_MASK, head->mask)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &f->exts) < 0) + goto 
nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &f->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void fw_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct fw_filter *f = fh; + + tc_cls_bind_class(classid, cl, q, &f->res, base); +} + +static struct tcf_proto_ops cls_fw_ops __read_mostly = { + .kind = "fw", + .classify = fw_classify, + .init = fw_init, + .destroy = fw_destroy, + .get = fw_get, + .change = fw_change, + .delete = fw_delete, + .walk = fw_walk, + .dump = fw_dump, + .bind_class = fw_bind_class, + .owner = THIS_MODULE, +}; + +static int __init init_fw(void) +{ + return register_tcf_proto_ops(&cls_fw_ops); +} + +static void __exit exit_fw(void) +{ + unregister_tcf_proto_ops(&cls_fw_ops); +} + +module_init(init_fw) +module_exit(exit_fw) +MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_matchall.c b/net/sched/cls_matchall.c new file mode 100644 index 000000000..43f8df584 --- /dev/null +++ b/net/sched/cls_matchall.c @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_matchll.c Match-all classifier + * + * Copyright (c) 2016 Jiri Pirko + */ + +#include +#include +#include +#include + +#include +#include + +struct cls_mall_head { + struct tcf_exts exts; + struct tcf_result res; + u32 handle; + u32 flags; + unsigned int in_hw_count; + struct tc_matchall_pcnt __percpu *pf; + struct rcu_work rwork; + bool deleting; +}; + +static int mall_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct cls_mall_head *head = rcu_dereference_bh(tp->root); + + if (unlikely(!head)) + return -1; + + if (tc_skip_sw(head->flags)) + return -1; + + *res = head->res; + __this_cpu_inc(head->pf->rhit); + return tcf_exts_exec(skb, &head->exts, res); +} + +static int mall_init(struct tcf_proto *tp) +{ + return 0; +} + +static void __mall_destroy(struct cls_mall_head *head) +{ + tcf_exts_destroy(&head->exts); + tcf_exts_put_net(&head->exts); + free_percpu(head->pf); + kfree(head); +} + +static void mall_destroy_work(struct work_struct *work) +{ + struct cls_mall_head *head = container_of(to_rcu_work(work), + struct cls_mall_head, + rwork); + rtnl_lock(); + __mall_destroy(head); + rtnl_unlock(); +} + +static void mall_destroy_hw_filter(struct tcf_proto *tp, + struct cls_mall_head *head, + unsigned long cookie, + struct netlink_ext_ack *extack) +{ + struct tc_cls_matchall_offload cls_mall = {}; + struct tcf_block *block = tp->chain->block; + + tc_cls_common_offload_init(&cls_mall.common, tp, head->flags, extack); + cls_mall.command = TC_CLSMATCHALL_DESTROY; + cls_mall.cookie = cookie; + + tc_setup_cb_destroy(block, tp, TC_SETUP_CLSMATCHALL, &cls_mall, false, + &head->flags, &head->in_hw_count, true); +} + +static int mall_replace_hw_filter(struct tcf_proto *tp, + struct cls_mall_head *head, + unsigned long cookie, + struct netlink_ext_ack *extack) +{ + struct tc_cls_matchall_offload cls_mall = {}; + struct tcf_block *block = tp->chain->block; + bool skip_sw = tc_skip_sw(head->flags); + int err; + + cls_mall.rule = flow_rule_alloc(tcf_exts_num_actions(&head->exts)); + if (!cls_mall.rule) + return -ENOMEM; + + tc_cls_common_offload_init(&cls_mall.common, tp, head->flags, extack); + cls_mall.command = TC_CLSMATCHALL_REPLACE; + cls_mall.cookie = cookie; + + err = tc_setup_offload_action(&cls_mall.rule->action, &head->exts, + cls_mall.common.extack); + if (err) { + kfree(cls_mall.rule); 
+ mall_destroy_hw_filter(tp, head, cookie, NULL); + + return skip_sw ? err : 0; + } + + err = tc_setup_cb_add(block, tp, TC_SETUP_CLSMATCHALL, &cls_mall, + skip_sw, &head->flags, &head->in_hw_count, true); + tc_cleanup_offload_action(&cls_mall.rule->action); + kfree(cls_mall.rule); + + if (err) { + mall_destroy_hw_filter(tp, head, cookie, NULL); + return err; + } + + if (skip_sw && !(head->flags & TCA_CLS_FLAGS_IN_HW)) + return -EINVAL; + + return 0; +} + +static void mall_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + + if (!head) + return; + + tcf_unbind_filter(tp, &head->res); + + if (!tc_skip_hw(head->flags)) + mall_destroy_hw_filter(tp, head, (unsigned long) head, extack); + + if (tcf_exts_get_net(&head->exts)) + tcf_queue_work(&head->rwork, mall_destroy_work); + else + __mall_destroy(head); +} + +static void *mall_get(struct tcf_proto *tp, u32 handle) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + + if (head && head->handle == handle) + return head; + + return NULL; +} + +static const struct nla_policy mall_policy[TCA_MATCHALL_MAX + 1] = { + [TCA_MATCHALL_UNSPEC] = { .type = NLA_UNSPEC }, + [TCA_MATCHALL_CLASSID] = { .type = NLA_U32 }, + [TCA_MATCHALL_FLAGS] = { .type = NLA_U32 }, +}; + +static int mall_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, + u32 handle, struct nlattr **tca, + void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + struct nlattr *tb[TCA_MATCHALL_MAX + 1]; + bool bound_to_filter = false; + struct cls_mall_head *new; + u32 userflags = 0; + int err; + + if (!tca[TCA_OPTIONS]) + return -EINVAL; + + if (head) + return -EEXIST; + + err = nla_parse_nested_deprecated(tb, TCA_MATCHALL_MAX, + tca[TCA_OPTIONS], mall_policy, NULL); + if (err < 0) + return err; + + if (tb[TCA_MATCHALL_FLAGS]) { + userflags = nla_get_u32(tb[TCA_MATCHALL_FLAGS]); + if (!tc_flags_valid(userflags)) + return -EINVAL; + } + + new = kzalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOBUFS; + + err = tcf_exts_init(&new->exts, net, TCA_MATCHALL_ACT, 0); + if (err) + goto err_exts_init; + + if (!handle) + handle = 1; + new->handle = handle; + new->flags = userflags; + new->pf = alloc_percpu(struct tc_matchall_pcnt); + if (!new->pf) { + err = -ENOMEM; + goto err_alloc_percpu; + } + + err = tcf_exts_validate_ex(net, tp, tb, tca[TCA_RATE], + &new->exts, flags, new->flags, extack); + if (err < 0) + goto err_set_parms; + + if (tb[TCA_MATCHALL_CLASSID]) { + new->res.classid = nla_get_u32(tb[TCA_MATCHALL_CLASSID]); + tcf_bind_filter(tp, &new->res, base); + bound_to_filter = true; + } + + if (!tc_skip_hw(new->flags)) { + err = mall_replace_hw_filter(tp, new, (unsigned long)new, + extack); + if (err) + goto err_replace_hw_filter; + } + + if (!tc_in_hw(new->flags)) + new->flags |= TCA_CLS_FLAGS_NOT_IN_HW; + + *arg = head; + rcu_assign_pointer(tp->root, new); + return 0; + +err_replace_hw_filter: + if (bound_to_filter) + tcf_unbind_filter(tp, &new->res); +err_set_parms: + free_percpu(new->pf); +err_alloc_percpu: + tcf_exts_destroy(&new->exts); +err_exts_init: + kfree(new); + return err; +} + +static int mall_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + + head->deleting = true; + *last = true; + return 0; +} + +static void mall_walk(struct tcf_proto *tp, struct 
tcf_walker *arg, + bool rtnl_held) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + + if (arg->count < arg->skip) + goto skip; + + if (!head || head->deleting) + return; + if (arg->fn(tp, head, arg) < 0) + arg->stop = 1; +skip: + arg->count++; +} + +static int mall_reoffload(struct tcf_proto *tp, bool add, flow_setup_cb_t *cb, + void *cb_priv, struct netlink_ext_ack *extack) +{ + struct cls_mall_head *head = rtnl_dereference(tp->root); + struct tc_cls_matchall_offload cls_mall = {}; + struct tcf_block *block = tp->chain->block; + int err; + + if (tc_skip_hw(head->flags)) + return 0; + + cls_mall.rule = flow_rule_alloc(tcf_exts_num_actions(&head->exts)); + if (!cls_mall.rule) + return -ENOMEM; + + tc_cls_common_offload_init(&cls_mall.common, tp, head->flags, extack); + cls_mall.command = add ? + TC_CLSMATCHALL_REPLACE : TC_CLSMATCHALL_DESTROY; + cls_mall.cookie = (unsigned long)head; + + err = tc_setup_offload_action(&cls_mall.rule->action, &head->exts, + cls_mall.common.extack); + if (err) { + kfree(cls_mall.rule); + + return add && tc_skip_sw(head->flags) ? err : 0; + } + + err = tc_setup_cb_reoffload(block, tp, add, cb, TC_SETUP_CLSMATCHALL, + &cls_mall, cb_priv, &head->flags, + &head->in_hw_count); + tc_cleanup_offload_action(&cls_mall.rule->action); + kfree(cls_mall.rule); + + return err; +} + +static void mall_stats_hw_filter(struct tcf_proto *tp, + struct cls_mall_head *head, + unsigned long cookie) +{ + struct tc_cls_matchall_offload cls_mall = {}; + struct tcf_block *block = tp->chain->block; + + tc_cls_common_offload_init(&cls_mall.common, tp, head->flags, NULL); + cls_mall.command = TC_CLSMATCHALL_STATS; + cls_mall.cookie = cookie; + + tc_setup_cb_call(block, TC_SETUP_CLSMATCHALL, &cls_mall, false, true); + + tcf_exts_hw_stats_update(&head->exts, cls_mall.stats.bytes, + cls_mall.stats.pkts, cls_mall.stats.drops, + cls_mall.stats.lastused, + cls_mall.stats.used_hw_stats, + cls_mall.stats.used_hw_stats_valid); +} + +static int mall_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct tc_matchall_pcnt gpf = {}; + struct cls_mall_head *head = fh; + struct nlattr *nest; + int cpu; + + if (!head) + return skb->len; + + if (!tc_skip_hw(head->flags)) + mall_stats_hw_filter(tp, head, (unsigned long)head); + + t->tcm_handle = head->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + if (head->res.classid && + nla_put_u32(skb, TCA_MATCHALL_CLASSID, head->res.classid)) + goto nla_put_failure; + + if (head->flags && nla_put_u32(skb, TCA_MATCHALL_FLAGS, head->flags)) + goto nla_put_failure; + + for_each_possible_cpu(cpu) { + struct tc_matchall_pcnt *pf = per_cpu_ptr(head->pf, cpu); + + gpf.rhit += pf->rhit; + } + + if (nla_put_64bit(skb, TCA_MATCHALL_PCNT, + sizeof(struct tc_matchall_pcnt), + &gpf, TCA_MATCHALL_PAD)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &head->exts)) + goto nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &head->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void mall_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct cls_mall_head *head = fh; + + tc_cls_bind_class(classid, cl, q, &head->res, base); +} + +static struct tcf_proto_ops cls_mall_ops __read_mostly = { + .kind = "matchall", + .classify = mall_classify, + .init = mall_init, + .destroy = mall_destroy, + .get = 
mall_get, + .change = mall_change, + .delete = mall_delete, + .walk = mall_walk, + .reoffload = mall_reoffload, + .dump = mall_dump, + .bind_class = mall_bind_class, + .owner = THIS_MODULE, +}; + +static int __init cls_mall_init(void) +{ + return register_tcf_proto_ops(&cls_mall_ops); +} + +static void __exit cls_mall_exit(void) +{ + unregister_tcf_proto_ops(&cls_mall_ops); +} + +module_init(cls_mall_init); +module_exit(cls_mall_exit); + +MODULE_AUTHOR("Jiri Pirko "); +MODULE_DESCRIPTION("Match-all classifier"); +MODULE_LICENSE("GPL v2"); diff --git a/net/sched/cls_route.c b/net/sched/cls_route.c new file mode 100644 index 000000000..306188bf2 --- /dev/null +++ b/net/sched/cls_route.c @@ -0,0 +1,680 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_route.c ROUTE4 classifier. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * 1. For now we assume that route tags < 256. + * It allows to use direct table lookups, instead of hash tables. + * 2. For now we assume that "from TAG" and "fromdev DEV" statements + * are mutually exclusive. + * 3. "to TAG from ANY" has higher priority, than "to ANY from XXX" + */ +struct route4_fastmap { + struct route4_filter *filter; + u32 id; + int iif; +}; + +struct route4_head { + struct route4_fastmap fastmap[16]; + struct route4_bucket __rcu *table[256 + 1]; + struct rcu_head rcu; +}; + +struct route4_bucket { + /* 16 FROM buckets + 16 IIF buckets + 1 wildcard bucket */ + struct route4_filter __rcu *ht[16 + 16 + 1]; + struct rcu_head rcu; +}; + +struct route4_filter { + struct route4_filter __rcu *next; + u32 id; + int iif; + + struct tcf_result res; + struct tcf_exts exts; + u32 handle; + struct route4_bucket *bkt; + struct tcf_proto *tp; + struct rcu_work rwork; +}; + +#define ROUTE4_FAILURE ((struct route4_filter *)(-1L)) + +static inline int route4_fastmap_hash(u32 id, int iif) +{ + return id & 0xF; +} + +static DEFINE_SPINLOCK(fastmap_lock); +static void +route4_reset_fastmap(struct route4_head *head) +{ + spin_lock_bh(&fastmap_lock); + memset(head->fastmap, 0, sizeof(head->fastmap)); + spin_unlock_bh(&fastmap_lock); +} + +static void +route4_set_fastmap(struct route4_head *head, u32 id, int iif, + struct route4_filter *f) +{ + int h = route4_fastmap_hash(id, iif); + + /* fastmap updates must look atomic to aling id, iff, filter */ + spin_lock_bh(&fastmap_lock); + head->fastmap[h].id = id; + head->fastmap[h].iif = iif; + head->fastmap[h].filter = f; + spin_unlock_bh(&fastmap_lock); +} + +static inline int route4_hash_to(u32 id) +{ + return id & 0xFF; +} + +static inline int route4_hash_from(u32 id) +{ + return (id >> 16) & 0xF; +} + +static inline int route4_hash_iif(int iif) +{ + return 16 + ((iif >> 16) & 0xF); +} + +static inline int route4_hash_wild(void) +{ + return 32; +} + +#define ROUTE4_APPLY_RESULT() \ +{ \ + *res = f->res; \ + if (tcf_exts_has_actions(&f->exts)) { \ + int r = tcf_exts_exec(skb, &f->exts, res); \ + if (r < 0) { \ + dont_cache = 1; \ + continue; \ + } \ + return r; \ + } else if (!dont_cache) \ + route4_set_fastmap(head, id, iif, f); \ + return 0; \ +} + +static int route4_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct route4_head *head = rcu_dereference_bh(tp->root); + struct dst_entry *dst; + struct route4_bucket *b; + struct route4_filter *f; + u32 id, h; + int iif, dont_cache = 0; + + dst = skb_dst(skb); + if (!dst) + goto failure; + + id 
= dst->tclassid; + + iif = inet_iif(skb); + + h = route4_fastmap_hash(id, iif); + + spin_lock(&fastmap_lock); + if (id == head->fastmap[h].id && + iif == head->fastmap[h].iif && + (f = head->fastmap[h].filter) != NULL) { + if (f == ROUTE4_FAILURE) { + spin_unlock(&fastmap_lock); + goto failure; + } + + *res = f->res; + spin_unlock(&fastmap_lock); + return 0; + } + spin_unlock(&fastmap_lock); + + h = route4_hash_to(id); + +restart: + b = rcu_dereference_bh(head->table[h]); + if (b) { + for (f = rcu_dereference_bh(b->ht[route4_hash_from(id)]); + f; + f = rcu_dereference_bh(f->next)) + if (f->id == id) + ROUTE4_APPLY_RESULT(); + + for (f = rcu_dereference_bh(b->ht[route4_hash_iif(iif)]); + f; + f = rcu_dereference_bh(f->next)) + if (f->iif == iif) + ROUTE4_APPLY_RESULT(); + + for (f = rcu_dereference_bh(b->ht[route4_hash_wild()]); + f; + f = rcu_dereference_bh(f->next)) + ROUTE4_APPLY_RESULT(); + } + if (h < 256) { + h = 256; + id &= ~0xFFFF; + goto restart; + } + + if (!dont_cache) + route4_set_fastmap(head, id, iif, ROUTE4_FAILURE); +failure: + return -1; +} + +static inline u32 to_hash(u32 id) +{ + u32 h = id & 0xFF; + + if (id & 0x8000) + h += 256; + return h; +} + +static inline u32 from_hash(u32 id) +{ + id &= 0xFFFF; + if (id == 0xFFFF) + return 32; + if (!(id & 0x8000)) { + if (id > 255) + return 256; + return id & 0xF; + } + return 16 + (id & 0xF); +} + +static void *route4_get(struct tcf_proto *tp, u32 handle) +{ + struct route4_head *head = rtnl_dereference(tp->root); + struct route4_bucket *b; + struct route4_filter *f; + unsigned int h1, h2; + + h1 = to_hash(handle); + if (h1 > 256) + return NULL; + + h2 = from_hash(handle >> 16); + if (h2 > 32) + return NULL; + + b = rtnl_dereference(head->table[h1]); + if (b) { + for (f = rtnl_dereference(b->ht[h2]); + f; + f = rtnl_dereference(f->next)) + if (f->handle == handle) + return f; + } + return NULL; +} + +static int route4_init(struct tcf_proto *tp) +{ + struct route4_head *head; + + head = kzalloc(sizeof(struct route4_head), GFP_KERNEL); + if (head == NULL) + return -ENOBUFS; + + rcu_assign_pointer(tp->root, head); + return 0; +} + +static void __route4_delete_filter(struct route4_filter *f) +{ + tcf_exts_destroy(&f->exts); + tcf_exts_put_net(&f->exts); + kfree(f); +} + +static void route4_delete_filter_work(struct work_struct *work) +{ + struct route4_filter *f = container_of(to_rcu_work(work), + struct route4_filter, + rwork); + rtnl_lock(); + __route4_delete_filter(f); + rtnl_unlock(); +} + +static void route4_queue_work(struct route4_filter *f) +{ + tcf_queue_work(&f->rwork, route4_delete_filter_work); +} + +static void route4_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct route4_head *head = rtnl_dereference(tp->root); + int h1, h2; + + if (head == NULL) + return; + + for (h1 = 0; h1 <= 256; h1++) { + struct route4_bucket *b; + + b = rtnl_dereference(head->table[h1]); + if (b) { + for (h2 = 0; h2 <= 32; h2++) { + struct route4_filter *f; + + while ((f = rtnl_dereference(b->ht[h2])) != NULL) { + struct route4_filter *next; + + next = rtnl_dereference(f->next); + RCU_INIT_POINTER(b->ht[h2], next); + tcf_unbind_filter(tp, &f->res); + if (tcf_exts_get_net(&f->exts)) + route4_queue_work(f); + else + __route4_delete_filter(f); + } + } + RCU_INIT_POINTER(head->table[h1], NULL); + kfree_rcu(b, rcu); + } + } + kfree_rcu(head, rcu); +} + +static int route4_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct route4_head *head = 
rtnl_dereference(tp->root); + struct route4_filter *f = arg; + struct route4_filter __rcu **fp; + struct route4_filter *nf; + struct route4_bucket *b; + unsigned int h = 0; + int i, h1; + + if (!head || !f) + return -EINVAL; + + h = f->handle; + b = f->bkt; + + fp = &b->ht[from_hash(h >> 16)]; + for (nf = rtnl_dereference(*fp); nf; + fp = &nf->next, nf = rtnl_dereference(*fp)) { + if (nf == f) { + /* unlink it */ + RCU_INIT_POINTER(*fp, rtnl_dereference(f->next)); + + /* Remove any fastmap lookups that might ref filter + * notice we unlink'd the filter so we can't get it + * back in the fastmap. + */ + route4_reset_fastmap(head); + + /* Delete it */ + tcf_unbind_filter(tp, &f->res); + tcf_exts_get_net(&f->exts); + tcf_queue_work(&f->rwork, route4_delete_filter_work); + + /* Strip RTNL protected tree */ + for (i = 0; i <= 32; i++) { + struct route4_filter *rt; + + rt = rtnl_dereference(b->ht[i]); + if (rt) + goto out; + } + + /* OK, session has no flows */ + RCU_INIT_POINTER(head->table[to_hash(h)], NULL); + kfree_rcu(b, rcu); + break; + } + } + +out: + *last = true; + for (h1 = 0; h1 <= 256; h1++) { + if (rcu_access_pointer(head->table[h1])) { + *last = false; + break; + } + } + + return 0; +} + +static const struct nla_policy route4_policy[TCA_ROUTE4_MAX + 1] = { + [TCA_ROUTE4_CLASSID] = { .type = NLA_U32 }, + [TCA_ROUTE4_TO] = { .type = NLA_U32 }, + [TCA_ROUTE4_FROM] = { .type = NLA_U32 }, + [TCA_ROUTE4_IIF] = { .type = NLA_U32 }, +}; + +static int route4_set_parms(struct net *net, struct tcf_proto *tp, + unsigned long base, struct route4_filter *f, + u32 handle, struct route4_head *head, + struct nlattr **tb, struct nlattr *est, int new, + u32 flags, struct netlink_ext_ack *extack) +{ + u32 id = 0, to = 0, nhandle = 0x8000; + struct route4_filter *fp; + unsigned int h1; + struct route4_bucket *b; + int err; + + err = tcf_exts_validate(net, tp, tb, est, &f->exts, flags, extack); + if (err < 0) + return err; + + if (tb[TCA_ROUTE4_TO]) { + if (new && handle & 0x8000) + return -EINVAL; + to = nla_get_u32(tb[TCA_ROUTE4_TO]); + if (to > 0xFF) + return -EINVAL; + nhandle = to; + } + + if (tb[TCA_ROUTE4_FROM]) { + if (tb[TCA_ROUTE4_IIF]) + return -EINVAL; + id = nla_get_u32(tb[TCA_ROUTE4_FROM]); + if (id > 0xFF) + return -EINVAL; + nhandle |= id << 16; + } else if (tb[TCA_ROUTE4_IIF]) { + id = nla_get_u32(tb[TCA_ROUTE4_IIF]); + if (id > 0x7FFF) + return -EINVAL; + nhandle |= (id | 0x8000) << 16; + } else + nhandle |= 0xFFFF << 16; + + if (handle && new) { + nhandle |= handle & 0x7F00; + if (nhandle != handle) + return -EINVAL; + } + + if (!nhandle) { + NL_SET_ERR_MSG(extack, "Replacing with handle of 0 is invalid"); + return -EINVAL; + } + + h1 = to_hash(nhandle); + b = rtnl_dereference(head->table[h1]); + if (!b) { + b = kzalloc(sizeof(struct route4_bucket), GFP_KERNEL); + if (b == NULL) + return -ENOBUFS; + + rcu_assign_pointer(head->table[h1], b); + } else { + unsigned int h2 = from_hash(nhandle >> 16); + + for (fp = rtnl_dereference(b->ht[h2]); + fp; + fp = rtnl_dereference(fp->next)) + if (fp->handle == f->handle) + return -EEXIST; + } + + if (tb[TCA_ROUTE4_TO]) + f->id = to; + + if (tb[TCA_ROUTE4_FROM]) + f->id = to | id<<16; + else if (tb[TCA_ROUTE4_IIF]) + f->iif = id; + + f->handle = nhandle; + f->bkt = b; + f->tp = tp; + + if (tb[TCA_ROUTE4_CLASSID]) { + f->res.classid = nla_get_u32(tb[TCA_ROUTE4_CLASSID]); + tcf_bind_filter(tp, &f->res, base); + } + + return 0; +} + +static int route4_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, u32 
handle, + struct nlattr **tca, void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct route4_head *head = rtnl_dereference(tp->root); + struct route4_filter __rcu **fp; + struct route4_filter *fold, *f1, *pfp, *f = NULL; + struct route4_bucket *b; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_ROUTE4_MAX + 1]; + unsigned int h, th; + int err; + bool new = true; + + if (!handle) { + NL_SET_ERR_MSG(extack, "Creating with handle of 0 is invalid"); + return -EINVAL; + } + + if (opt == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_ROUTE4_MAX, opt, + route4_policy, NULL); + if (err < 0) + return err; + + fold = *arg; + if (fold && fold->handle != handle) + return -EINVAL; + + err = -ENOBUFS; + f = kzalloc(sizeof(struct route4_filter), GFP_KERNEL); + if (!f) + goto errout; + + err = tcf_exts_init(&f->exts, net, TCA_ROUTE4_ACT, TCA_ROUTE4_POLICE); + if (err < 0) + goto errout; + + if (fold) { + f->id = fold->id; + f->iif = fold->iif; + f->handle = fold->handle; + + f->tp = fold->tp; + f->bkt = fold->bkt; + new = false; + } + + err = route4_set_parms(net, tp, base, f, handle, head, tb, + tca[TCA_RATE], new, flags, extack); + if (err < 0) + goto errout; + + h = from_hash(f->handle >> 16); + fp = &f->bkt->ht[h]; + for (pfp = rtnl_dereference(*fp); + (f1 = rtnl_dereference(*fp)) != NULL; + fp = &f1->next) + if (f->handle < f1->handle) + break; + + tcf_block_netif_keep_dst(tp->chain->block); + rcu_assign_pointer(f->next, f1); + rcu_assign_pointer(*fp, f); + + if (fold) { + th = to_hash(fold->handle); + h = from_hash(fold->handle >> 16); + b = rtnl_dereference(head->table[th]); + if (b) { + fp = &b->ht[h]; + for (pfp = rtnl_dereference(*fp); pfp; + fp = &pfp->next, pfp = rtnl_dereference(*fp)) { + if (pfp == fold) { + rcu_assign_pointer(*fp, fold->next); + break; + } + } + } + } + + route4_reset_fastmap(head); + *arg = f; + if (fold) { + tcf_unbind_filter(tp, &fold->res); + tcf_exts_get_net(&fold->exts); + tcf_queue_work(&fold->rwork, route4_delete_filter_work); + } + return 0; + +errout: + if (f) + tcf_exts_destroy(&f->exts); + kfree(f); + return err; +} + +static void route4_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct route4_head *head = rtnl_dereference(tp->root); + unsigned int h, h1; + + if (head == NULL || arg->stop) + return; + + for (h = 0; h <= 256; h++) { + struct route4_bucket *b = rtnl_dereference(head->table[h]); + + if (b) { + for (h1 = 0; h1 <= 32; h1++) { + struct route4_filter *f; + + for (f = rtnl_dereference(b->ht[h1]); + f; + f = rtnl_dereference(f->next)) { + if (!tc_cls_stats_dump(tp, arg, f)) + return; + } + } + } + } +} + +static int route4_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct route4_filter *f = fh; + struct nlattr *nest; + u32 id; + + if (f == NULL) + return skb->len; + + t->tcm_handle = f->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (!(f->handle & 0x8000)) { + id = f->id & 0xFF; + if (nla_put_u32(skb, TCA_ROUTE4_TO, id)) + goto nla_put_failure; + } + if (f->handle & 0x80000000) { + if ((f->handle >> 16) != 0xFFFF && + nla_put_u32(skb, TCA_ROUTE4_IIF, f->iif)) + goto nla_put_failure; + } else { + id = f->id >> 16; + if (nla_put_u32(skb, TCA_ROUTE4_FROM, id)) + goto nla_put_failure; + } + if (f->res.classid && + nla_put_u32(skb, TCA_ROUTE4_CLASSID, f->res.classid)) + goto nla_put_failure; + + if (tcf_exts_dump(skb, &f->exts) < 0) + goto 
nla_put_failure; + + nla_nest_end(skb, nest); + + if (tcf_exts_dump_stats(skb, &f->exts) < 0) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void route4_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct route4_filter *f = fh; + + tc_cls_bind_class(classid, cl, q, &f->res, base); +} + +static struct tcf_proto_ops cls_route4_ops __read_mostly = { + .kind = "route", + .classify = route4_classify, + .init = route4_init, + .destroy = route4_destroy, + .get = route4_get, + .change = route4_change, + .delete = route4_delete, + .walk = route4_walk, + .dump = route4_dump, + .bind_class = route4_bind_class, + .owner = THIS_MODULE, +}; + +static int __init init_route4(void) +{ + return register_tcf_proto_ops(&cls_route4_ops); +} + +static void __exit exit_route4(void) +{ + unregister_tcf_proto_ops(&cls_route4_ops); +} + +module_init(init_route4) +module_exit(exit_route4) +MODULE_LICENSE("GPL"); diff --git a/net/sched/cls_u32.c b/net/sched/cls_u32.c new file mode 100644 index 000000000..04448bfb4 --- /dev/null +++ b/net/sched/cls_u32.c @@ -0,0 +1,1490 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/cls_u32.c Ugly (or Universal) 32bit key Packet Classifier. + * + * Authors: Alexey Kuznetsov, + * + * The filters are packed to hash tables of key nodes + * with a set of 32bit key/mask pairs at every node. + * Nodes reference next level hash tables etc. + * + * This scheme is the best universal classifier I managed to + * invent; it is not super-fast, but it is not slow (provided you + * program it correctly), and general enough. And its relative + * speed grows as the number of rules becomes larger. + * + * It seems that it represents the best middle point between + * speed and manageability both by human and by machine. + * + * It is especially useful for link sharing combined with QoS; + * pure RSVP doesn't need such a general approach and can use + * much simpler (and faster) schemes, sort of cls_rsvp.c. + * + * nfmark match added by Catalin(ux aka Dino) BOIE + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct tc_u_knode { + struct tc_u_knode __rcu *next; + u32 handle; + struct tc_u_hnode __rcu *ht_up; + struct tcf_exts exts; + int ifindex; + u8 fshift; + struct tcf_result res; + struct tc_u_hnode __rcu *ht_down; +#ifdef CONFIG_CLS_U32_PERF + struct tc_u32_pcnt __percpu *pf; +#endif + u32 flags; + unsigned int in_hw_count; +#ifdef CONFIG_CLS_U32_MARK + u32 val; + u32 mask; + u32 __percpu *pcpu_success; +#endif + struct rcu_work rwork; + /* The 'sel' field MUST be the last field in structure to allow for + * tc_u32_keys allocated at end of structure. + */ + struct tc_u32_sel sel; +}; + +struct tc_u_hnode { + struct tc_u_hnode __rcu *next; + u32 handle; + u32 prio; + int refcnt; + unsigned int divisor; + struct idr handle_idr; + bool is_root; + struct rcu_head rcu; + u32 flags; + /* The 'ht' field MUST be the last field in structure to allow for + * more entries allocated at end of structure. 
+ */ + struct tc_u_knode __rcu *ht[]; +}; + +struct tc_u_common { + struct tc_u_hnode __rcu *hlist; + void *ptr; + int refcnt; + struct idr handle_idr; + struct hlist_node hnode; + long knodes; +}; + +static inline unsigned int u32_hash_fold(__be32 key, + const struct tc_u32_sel *sel, + u8 fshift) +{ + unsigned int h = ntohl(key & sel->hmask) >> fshift; + + return h; +} + +static int u32_classify(struct sk_buff *skb, const struct tcf_proto *tp, + struct tcf_result *res) +{ + struct { + struct tc_u_knode *knode; + unsigned int off; + } stack[TC_U32_MAXDEPTH]; + + struct tc_u_hnode *ht = rcu_dereference_bh(tp->root); + unsigned int off = skb_network_offset(skb); + struct tc_u_knode *n; + int sdepth = 0; + int off2 = 0; + int sel = 0; +#ifdef CONFIG_CLS_U32_PERF + int j; +#endif + int i, r; + +next_ht: + n = rcu_dereference_bh(ht->ht[sel]); + +next_knode: + if (n) { + struct tc_u32_key *key = n->sel.keys; + +#ifdef CONFIG_CLS_U32_PERF + __this_cpu_inc(n->pf->rcnt); + j = 0; +#endif + + if (tc_skip_sw(n->flags)) { + n = rcu_dereference_bh(n->next); + goto next_knode; + } + +#ifdef CONFIG_CLS_U32_MARK + if ((skb->mark & n->mask) != n->val) { + n = rcu_dereference_bh(n->next); + goto next_knode; + } else { + __this_cpu_inc(*n->pcpu_success); + } +#endif + + for (i = n->sel.nkeys; i > 0; i--, key++) { + int toff = off + key->off + (off2 & key->offmask); + __be32 *data, hdata; + + if (skb_headroom(skb) + toff > INT_MAX) + goto out; + + data = skb_header_pointer(skb, toff, 4, &hdata); + if (!data) + goto out; + if ((*data ^ key->val) & key->mask) { + n = rcu_dereference_bh(n->next); + goto next_knode; + } +#ifdef CONFIG_CLS_U32_PERF + __this_cpu_inc(n->pf->kcnts[j]); + j++; +#endif + } + + ht = rcu_dereference_bh(n->ht_down); + if (!ht) { +check_terminal: + if (n->sel.flags & TC_U32_TERMINAL) { + + *res = n->res; + if (!tcf_match_indev(skb, n->ifindex)) { + n = rcu_dereference_bh(n->next); + goto next_knode; + } +#ifdef CONFIG_CLS_U32_PERF + __this_cpu_inc(n->pf->rhit); +#endif + r = tcf_exts_exec(skb, &n->exts, res); + if (r < 0) { + n = rcu_dereference_bh(n->next); + goto next_knode; + } + + return r; + } + n = rcu_dereference_bh(n->next); + goto next_knode; + } + + /* PUSH */ + if (sdepth >= TC_U32_MAXDEPTH) + goto deadloop; + stack[sdepth].knode = n; + stack[sdepth].off = off; + sdepth++; + + ht = rcu_dereference_bh(n->ht_down); + sel = 0; + if (ht->divisor) { + __be32 *data, hdata; + + data = skb_header_pointer(skb, off + n->sel.hoff, 4, + &hdata); + if (!data) + goto out; + sel = ht->divisor & u32_hash_fold(*data, &n->sel, + n->fshift); + } + if (!(n->sel.flags & (TC_U32_VAROFFSET | TC_U32_OFFSET | TC_U32_EAT))) + goto next_ht; + + if (n->sel.flags & (TC_U32_OFFSET | TC_U32_VAROFFSET)) { + off2 = n->sel.off + 3; + if (n->sel.flags & TC_U32_VAROFFSET) { + __be16 *data, hdata; + + data = skb_header_pointer(skb, + off + n->sel.offoff, + 2, &hdata); + if (!data) + goto out; + off2 += ntohs(n->sel.offmask & *data) >> + n->sel.offshift; + } + off2 &= ~3; + } + if (n->sel.flags & TC_U32_EAT) { + off += off2; + off2 = 0; + } + + if (off < skb->len) + goto next_ht; + } + + /* POP */ + if (sdepth--) { + n = stack[sdepth].knode; + ht = rcu_dereference_bh(n->ht_up); + off = stack[sdepth].off; + goto check_terminal; + } +out: + return -1; + +deadloop: + net_warn_ratelimited("cls_u32: dead loop\n"); + return -1; +} + +static struct tc_u_hnode *u32_lookup_ht(struct tc_u_common *tp_c, u32 handle) +{ + struct tc_u_hnode *ht; + + for (ht = rtnl_dereference(tp_c->hlist); + ht; + ht = rtnl_dereference(ht->next)) 
+ if (ht->handle == handle) + break; + + return ht; +} + +static struct tc_u_knode *u32_lookup_key(struct tc_u_hnode *ht, u32 handle) +{ + unsigned int sel; + struct tc_u_knode *n = NULL; + + sel = TC_U32_HASH(handle); + if (sel > ht->divisor) + goto out; + + for (n = rtnl_dereference(ht->ht[sel]); + n; + n = rtnl_dereference(n->next)) + if (n->handle == handle) + break; +out: + return n; +} + + +static void *u32_get(struct tcf_proto *tp, u32 handle) +{ + struct tc_u_hnode *ht; + struct tc_u_common *tp_c = tp->data; + + if (TC_U32_HTID(handle) == TC_U32_ROOT) + ht = rtnl_dereference(tp->root); + else + ht = u32_lookup_ht(tp_c, TC_U32_HTID(handle)); + + if (!ht) + return NULL; + + if (TC_U32_KEY(handle) == 0) + return ht; + + return u32_lookup_key(ht, handle); +} + +/* Protected by rtnl lock */ +static u32 gen_new_htid(struct tc_u_common *tp_c, struct tc_u_hnode *ptr) +{ + int id = idr_alloc_cyclic(&tp_c->handle_idr, ptr, 1, 0x7FF, GFP_KERNEL); + if (id < 0) + return 0; + return (id | 0x800U) << 20; +} + +static struct hlist_head *tc_u_common_hash; + +#define U32_HASH_SHIFT 10 +#define U32_HASH_SIZE (1 << U32_HASH_SHIFT) + +static void *tc_u_common_ptr(const struct tcf_proto *tp) +{ + struct tcf_block *block = tp->chain->block; + + /* The block sharing is currently supported only + * for classless qdiscs. In that case we use block + * for tc_u_common identification. In case the + * block is not shared, block->q is a valid pointer + * and we can use that. That works for classful qdiscs. + */ + if (tcf_block_shared(block)) + return block; + else + return block->q; +} + +static struct hlist_head *tc_u_hash(void *key) +{ + return tc_u_common_hash + hash_ptr(key, U32_HASH_SHIFT); +} + +static struct tc_u_common *tc_u_common_find(void *key) +{ + struct tc_u_common *tc; + hlist_for_each_entry(tc, tc_u_hash(key), hnode) { + if (tc->ptr == key) + return tc; + } + return NULL; +} + +static int u32_init(struct tcf_proto *tp) +{ + struct tc_u_hnode *root_ht; + void *key = tc_u_common_ptr(tp); + struct tc_u_common *tp_c = tc_u_common_find(key); + + root_ht = kzalloc(struct_size(root_ht, ht, 1), GFP_KERNEL); + if (root_ht == NULL) + return -ENOBUFS; + + root_ht->refcnt++; + root_ht->handle = tp_c ? gen_new_htid(tp_c, root_ht) : 0x80000000; + root_ht->prio = tp->prio; + root_ht->is_root = true; + idr_init(&root_ht->handle_idr); + + if (tp_c == NULL) { + tp_c = kzalloc(sizeof(*tp_c), GFP_KERNEL); + if (tp_c == NULL) { + kfree(root_ht); + return -ENOBUFS; + } + tp_c->ptr = key; + INIT_HLIST_NODE(&tp_c->hnode); + idr_init(&tp_c->handle_idr); + + hlist_add_head(&tp_c->hnode, tc_u_hash(key)); + } + + tp_c->refcnt++; + RCU_INIT_POINTER(root_ht->next, tp_c->hlist); + rcu_assign_pointer(tp_c->hlist, root_ht); + + root_ht->refcnt++; + rcu_assign_pointer(tp->root, root_ht); + tp->data = tp_c; + return 0; +} + +static void __u32_destroy_key(struct tc_u_knode *n) +{ + struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); + + tcf_exts_destroy(&n->exts); + if (ht && --ht->refcnt == 0) + kfree(ht); + kfree(n); +} + +static void u32_destroy_key(struct tc_u_knode *n, bool free_pf) +{ + tcf_exts_put_net(&n->exts); +#ifdef CONFIG_CLS_U32_PERF + if (free_pf) + free_percpu(n->pf); +#endif +#ifdef CONFIG_CLS_U32_MARK + if (free_pf) + free_percpu(n->pcpu_success); +#endif + __u32_destroy_key(n); +} + +/* u32_delete_key_rcu should be called when free'ing a copied + * version of a tc_u_knode obtained from u32_init_knode(). 
When + * copies are obtained from u32_init_knode() the statistics are + * shared between the old and new copies to allow readers to + * continue to update the statistics during the copy. To support + * this the u32_delete_key_rcu variant does not free the percpu + * statistics. + */ +static void u32_delete_key_work(struct work_struct *work) +{ + struct tc_u_knode *key = container_of(to_rcu_work(work), + struct tc_u_knode, + rwork); + rtnl_lock(); + u32_destroy_key(key, false); + rtnl_unlock(); +} + +/* u32_delete_key_freepf_rcu is the rcu callback variant + * that free's the entire structure including the statistics + * percpu variables. Only use this if the key is not a copy + * returned by u32_init_knode(). See u32_delete_key_rcu() + * for the variant that should be used with keys return from + * u32_init_knode() + */ +static void u32_delete_key_freepf_work(struct work_struct *work) +{ + struct tc_u_knode *key = container_of(to_rcu_work(work), + struct tc_u_knode, + rwork); + rtnl_lock(); + u32_destroy_key(key, true); + rtnl_unlock(); +} + +static int u32_delete_key(struct tcf_proto *tp, struct tc_u_knode *key) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_knode __rcu **kp; + struct tc_u_knode *pkp; + struct tc_u_hnode *ht = rtnl_dereference(key->ht_up); + + if (ht) { + kp = &ht->ht[TC_U32_HASH(key->handle)]; + for (pkp = rtnl_dereference(*kp); pkp; + kp = &pkp->next, pkp = rtnl_dereference(*kp)) { + if (pkp == key) { + RCU_INIT_POINTER(*kp, key->next); + tp_c->knodes--; + + tcf_unbind_filter(tp, &key->res); + idr_remove(&ht->handle_idr, key->handle); + tcf_exts_get_net(&key->exts); + tcf_queue_work(&key->rwork, u32_delete_key_freepf_work); + return 0; + } + } + } + WARN_ON(1); + return 0; +} + +static void u32_clear_hw_hnode(struct tcf_proto *tp, struct tc_u_hnode *h, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct tc_cls_u32_offload cls_u32 = {}; + + tc_cls_common_offload_init(&cls_u32.common, tp, h->flags, extack); + cls_u32.command = TC_CLSU32_DELETE_HNODE; + cls_u32.hnode.divisor = h->divisor; + cls_u32.hnode.handle = h->handle; + cls_u32.hnode.prio = h->prio; + + tc_setup_cb_call(block, TC_SETUP_CLSU32, &cls_u32, false, true); +} + +static int u32_replace_hw_hnode(struct tcf_proto *tp, struct tc_u_hnode *h, + u32 flags, struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct tc_cls_u32_offload cls_u32 = {}; + bool skip_sw = tc_skip_sw(flags); + bool offloaded = false; + int err; + + tc_cls_common_offload_init(&cls_u32.common, tp, flags, extack); + cls_u32.command = TC_CLSU32_NEW_HNODE; + cls_u32.hnode.divisor = h->divisor; + cls_u32.hnode.handle = h->handle; + cls_u32.hnode.prio = h->prio; + + err = tc_setup_cb_call(block, TC_SETUP_CLSU32, &cls_u32, skip_sw, true); + if (err < 0) { + u32_clear_hw_hnode(tp, h, NULL); + return err; + } else if (err > 0) { + offloaded = true; + } + + if (skip_sw && !offloaded) + return -EINVAL; + + return 0; +} + +static void u32_remove_hw_knode(struct tcf_proto *tp, struct tc_u_knode *n, + struct netlink_ext_ack *extack) +{ + struct tcf_block *block = tp->chain->block; + struct tc_cls_u32_offload cls_u32 = {}; + + tc_cls_common_offload_init(&cls_u32.common, tp, n->flags, extack); + cls_u32.command = TC_CLSU32_DELETE_KNODE; + cls_u32.knode.handle = n->handle; + + tc_setup_cb_destroy(block, tp, TC_SETUP_CLSU32, &cls_u32, false, + &n->flags, &n->in_hw_count, true); +} + +static int u32_replace_hw_knode(struct tcf_proto *tp, struct tc_u_knode *n, + u32 flags, struct 
netlink_ext_ack *extack) +{ + struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); + struct tcf_block *block = tp->chain->block; + struct tc_cls_u32_offload cls_u32 = {}; + bool skip_sw = tc_skip_sw(flags); + int err; + + tc_cls_common_offload_init(&cls_u32.common, tp, flags, extack); + cls_u32.command = TC_CLSU32_REPLACE_KNODE; + cls_u32.knode.handle = n->handle; + cls_u32.knode.fshift = n->fshift; +#ifdef CONFIG_CLS_U32_MARK + cls_u32.knode.val = n->val; + cls_u32.knode.mask = n->mask; +#else + cls_u32.knode.val = 0; + cls_u32.knode.mask = 0; +#endif + cls_u32.knode.sel = &n->sel; + cls_u32.knode.res = &n->res; + cls_u32.knode.exts = &n->exts; + if (n->ht_down) + cls_u32.knode.link_handle = ht->handle; + + err = tc_setup_cb_add(block, tp, TC_SETUP_CLSU32, &cls_u32, skip_sw, + &n->flags, &n->in_hw_count, true); + if (err) { + u32_remove_hw_knode(tp, n, NULL); + return err; + } + + if (skip_sw && !(n->flags & TCA_CLS_FLAGS_IN_HW)) + return -EINVAL; + + return 0; +} + +static void u32_clear_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, + struct netlink_ext_ack *extack) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_knode *n; + unsigned int h; + + for (h = 0; h <= ht->divisor; h++) { + while ((n = rtnl_dereference(ht->ht[h])) != NULL) { + RCU_INIT_POINTER(ht->ht[h], + rtnl_dereference(n->next)); + tp_c->knodes--; + tcf_unbind_filter(tp, &n->res); + u32_remove_hw_knode(tp, n, extack); + idr_remove(&ht->handle_idr, n->handle); + if (tcf_exts_get_net(&n->exts)) + tcf_queue_work(&n->rwork, u32_delete_key_freepf_work); + else + u32_destroy_key(n, true); + } + } +} + +static int u32_destroy_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, + struct netlink_ext_ack *extack) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_hnode __rcu **hn; + struct tc_u_hnode *phn; + + WARN_ON(--ht->refcnt); + + u32_clear_hnode(tp, ht, extack); + + hn = &tp_c->hlist; + for (phn = rtnl_dereference(*hn); + phn; + hn = &phn->next, phn = rtnl_dereference(*hn)) { + if (phn == ht) { + u32_clear_hw_hnode(tp, ht, extack); + idr_destroy(&ht->handle_idr); + idr_remove(&tp_c->handle_idr, ht->handle); + RCU_INIT_POINTER(*hn, ht->next); + kfree_rcu(ht, rcu); + return 0; + } + } + + return -ENOENT; +} + +static void u32_destroy(struct tcf_proto *tp, bool rtnl_held, + struct netlink_ext_ack *extack) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_hnode *root_ht = rtnl_dereference(tp->root); + + WARN_ON(root_ht == NULL); + + if (root_ht && --root_ht->refcnt == 1) + u32_destroy_hnode(tp, root_ht, extack); + + if (--tp_c->refcnt == 0) { + struct tc_u_hnode *ht; + + hlist_del(&tp_c->hnode); + + while ((ht = rtnl_dereference(tp_c->hlist)) != NULL) { + u32_clear_hnode(tp, ht, extack); + RCU_INIT_POINTER(tp_c->hlist, ht->next); + + /* u32_destroy_key() will later free ht for us, if it's + * still referenced by some knode + */ + if (--ht->refcnt == 0) + kfree_rcu(ht, rcu); + } + + idr_destroy(&tp_c->handle_idr); + kfree(tp_c); + } + + tp->data = NULL; +} + +static int u32_delete(struct tcf_proto *tp, void *arg, bool *last, + bool rtnl_held, struct netlink_ext_ack *extack) +{ + struct tc_u_hnode *ht = arg; + struct tc_u_common *tp_c = tp->data; + int ret = 0; + + if (TC_U32_KEY(ht->handle)) { + u32_remove_hw_knode(tp, (struct tc_u_knode *)ht, extack); + ret = u32_delete_key(tp, (struct tc_u_knode *)ht); + goto out; + } + + if (ht->is_root) { + NL_SET_ERR_MSG_MOD(extack, "Not allowed to delete root node"); + return -EINVAL; + } + + if (ht->refcnt == 1) { + u32_destroy_hnode(tp, ht, extack); + } else { + 
NL_SET_ERR_MSG_MOD(extack, "Can not delete in-use filter"); + return -EBUSY; + } + +out: + *last = tp_c->refcnt == 1 && tp_c->knodes == 0; + return ret; +} + +static u32 gen_new_kid(struct tc_u_hnode *ht, u32 htid) +{ + u32 index = htid | 0x800; + u32 max = htid | 0xFFF; + + if (idr_alloc_u32(&ht->handle_idr, NULL, &index, max, GFP_KERNEL)) { + index = htid + 1; + if (idr_alloc_u32(&ht->handle_idr, NULL, &index, max, + GFP_KERNEL)) + index = max; + } + + return index; +} + +static const struct nla_policy u32_policy[TCA_U32_MAX + 1] = { + [TCA_U32_CLASSID] = { .type = NLA_U32 }, + [TCA_U32_HASH] = { .type = NLA_U32 }, + [TCA_U32_LINK] = { .type = NLA_U32 }, + [TCA_U32_DIVISOR] = { .type = NLA_U32 }, + [TCA_U32_SEL] = { .len = sizeof(struct tc_u32_sel) }, + [TCA_U32_INDEV] = { .type = NLA_STRING, .len = IFNAMSIZ }, + [TCA_U32_MARK] = { .len = sizeof(struct tc_u32_mark) }, + [TCA_U32_FLAGS] = { .type = NLA_U32 }, +}; + +static void u32_unbind_filter(struct tcf_proto *tp, struct tc_u_knode *n, + struct nlattr **tb) +{ + if (tb[TCA_U32_CLASSID]) + tcf_unbind_filter(tp, &n->res); +} + +static void u32_bind_filter(struct tcf_proto *tp, struct tc_u_knode *n, + unsigned long base, struct nlattr **tb) +{ + if (tb[TCA_U32_CLASSID]) { + n->res.classid = nla_get_u32(tb[TCA_U32_CLASSID]); + tcf_bind_filter(tp, &n->res, base); + } +} + +static int u32_set_parms(struct net *net, struct tcf_proto *tp, + struct tc_u_knode *n, struct nlattr **tb, + struct nlattr *est, u32 flags, u32 fl_flags, + struct netlink_ext_ack *extack) +{ + int err, ifindex = -1; + + err = tcf_exts_validate_ex(net, tp, tb, est, &n->exts, flags, + fl_flags, extack); + if (err < 0) + return err; + + if (tb[TCA_U32_INDEV]) { + ifindex = tcf_change_indev(net, tb[TCA_U32_INDEV], extack); + if (ifindex < 0) + return -EINVAL; + } + + if (tb[TCA_U32_LINK]) { + u32 handle = nla_get_u32(tb[TCA_U32_LINK]); + struct tc_u_hnode *ht_down = NULL, *ht_old; + + if (TC_U32_KEY(handle)) { + NL_SET_ERR_MSG_MOD(extack, "u32 Link handle must be a hash table"); + return -EINVAL; + } + + if (handle) { + ht_down = u32_lookup_ht(tp->data, handle); + + if (!ht_down) { + NL_SET_ERR_MSG_MOD(extack, "Link hash table not found"); + return -EINVAL; + } + if (ht_down->is_root) { + NL_SET_ERR_MSG_MOD(extack, "Not linking to root node"); + return -EINVAL; + } + ht_down->refcnt++; + } + + ht_old = rtnl_dereference(n->ht_down); + rcu_assign_pointer(n->ht_down, ht_down); + + if (ht_old) + ht_old->refcnt--; + } + + if (ifindex >= 0) + n->ifindex = ifindex; + + return 0; +} + +static void u32_replace_knode(struct tcf_proto *tp, struct tc_u_common *tp_c, + struct tc_u_knode *n) +{ + struct tc_u_knode __rcu **ins; + struct tc_u_knode *pins; + struct tc_u_hnode *ht; + + if (TC_U32_HTID(n->handle) == TC_U32_ROOT) + ht = rtnl_dereference(tp->root); + else + ht = u32_lookup_ht(tp_c, TC_U32_HTID(n->handle)); + + ins = &ht->ht[TC_U32_HASH(n->handle)]; + + /* The node must always exist for it to be replaced if this is not the + * case then something went very wrong elsewhere. 
+ */ + for (pins = rtnl_dereference(*ins); ; + ins = &pins->next, pins = rtnl_dereference(*ins)) + if (pins->handle == n->handle) + break; + + idr_replace(&ht->handle_idr, n, n->handle); + RCU_INIT_POINTER(n->next, pins->next); + rcu_assign_pointer(*ins, n); +} + +static struct tc_u_knode *u32_init_knode(struct net *net, struct tcf_proto *tp, + struct tc_u_knode *n) +{ + struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); + struct tc_u32_sel *s = &n->sel; + struct tc_u_knode *new; + + new = kzalloc(struct_size(new, sel.keys, s->nkeys), GFP_KERNEL); + if (!new) + return NULL; + + RCU_INIT_POINTER(new->next, n->next); + new->handle = n->handle; + RCU_INIT_POINTER(new->ht_up, n->ht_up); + + new->ifindex = n->ifindex; + new->fshift = n->fshift; + new->flags = n->flags; + RCU_INIT_POINTER(new->ht_down, ht); + +#ifdef CONFIG_CLS_U32_PERF + /* Statistics may be incremented by readers during update + * so we must keep them in tact. When the node is later destroyed + * a special destroy call must be made to not free the pf memory. + */ + new->pf = n->pf; +#endif + +#ifdef CONFIG_CLS_U32_MARK + new->val = n->val; + new->mask = n->mask; + /* Similarly success statistics must be moved as pointers */ + new->pcpu_success = n->pcpu_success; +#endif + memcpy(&new->sel, s, struct_size(s, keys, s->nkeys)); + + if (tcf_exts_init(&new->exts, net, TCA_U32_ACT, TCA_U32_POLICE)) { + kfree(new); + return NULL; + } + + /* bump reference count as long as we hold pointer to structure */ + if (ht) + ht->refcnt++; + + return new; +} + +static int u32_change(struct net *net, struct sk_buff *in_skb, + struct tcf_proto *tp, unsigned long base, u32 handle, + struct nlattr **tca, void **arg, u32 flags, + struct netlink_ext_ack *extack) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_hnode *ht; + struct tc_u_knode *n; + struct tc_u32_sel *s; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_U32_MAX + 1]; + u32 htid, userflags = 0; + size_t sel_size; + int err; + + if (!opt) { + if (handle) { + NL_SET_ERR_MSG_MOD(extack, "Filter handle requires options"); + return -EINVAL; + } else { + return 0; + } + } + + err = nla_parse_nested_deprecated(tb, TCA_U32_MAX, opt, u32_policy, + extack); + if (err < 0) + return err; + + if (tb[TCA_U32_FLAGS]) { + userflags = nla_get_u32(tb[TCA_U32_FLAGS]); + if (!tc_flags_valid(userflags)) { + NL_SET_ERR_MSG_MOD(extack, "Invalid filter flags"); + return -EINVAL; + } + } + + n = *arg; + if (n) { + struct tc_u_knode *new; + + if (TC_U32_KEY(n->handle) == 0) { + NL_SET_ERR_MSG_MOD(extack, "Key node id cannot be zero"); + return -EINVAL; + } + + if ((n->flags ^ userflags) & + ~(TCA_CLS_FLAGS_IN_HW | TCA_CLS_FLAGS_NOT_IN_HW)) { + NL_SET_ERR_MSG_MOD(extack, "Key node flags do not match passed flags"); + return -EINVAL; + } + + new = u32_init_knode(net, tp, n); + if (!new) + return -ENOMEM; + + err = u32_set_parms(net, tp, new, tb, tca[TCA_RATE], + flags, new->flags, extack); + + if (err) { + __u32_destroy_key(new); + return err; + } + + u32_bind_filter(tp, new, base, tb); + + err = u32_replace_hw_knode(tp, new, flags, extack); + if (err) { + u32_unbind_filter(tp, new, tb); + + if (tb[TCA_U32_LINK]) { + struct tc_u_hnode *ht_old; + + ht_old = rtnl_dereference(n->ht_down); + if (ht_old) + ht_old->refcnt++; + } + __u32_destroy_key(new); + return err; + } + + if (!tc_in_hw(new->flags)) + new->flags |= TCA_CLS_FLAGS_NOT_IN_HW; + + u32_replace_knode(tp, tp_c, new); + tcf_unbind_filter(tp, &n->res); + tcf_exts_get_net(&n->exts); + tcf_queue_work(&n->rwork, u32_delete_key_work); + 
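/* Note on the work queued just above: the replaced knode is released through
 * u32_delete_key_work(), the variant that keeps the percpu counters, because
 * u32_init_knode() handed those counters over to the new copy and only the
 * copy may eventually free them.
 */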
return 0; + } + + if (tb[TCA_U32_DIVISOR]) { + unsigned int divisor = nla_get_u32(tb[TCA_U32_DIVISOR]); + + if (!is_power_of_2(divisor)) { + NL_SET_ERR_MSG_MOD(extack, "Divisor is not a power of 2"); + return -EINVAL; + } + if (divisor-- > 0x100) { + NL_SET_ERR_MSG_MOD(extack, "Exceeded maximum 256 hash buckets"); + return -EINVAL; + } + if (TC_U32_KEY(handle)) { + NL_SET_ERR_MSG_MOD(extack, "Divisor can only be used on a hash table"); + return -EINVAL; + } + ht = kzalloc(struct_size(ht, ht, divisor + 1), GFP_KERNEL); + if (ht == NULL) + return -ENOBUFS; + if (handle == 0) { + handle = gen_new_htid(tp->data, ht); + if (handle == 0) { + kfree(ht); + return -ENOMEM; + } + } else { + err = idr_alloc_u32(&tp_c->handle_idr, ht, &handle, + handle, GFP_KERNEL); + if (err) { + kfree(ht); + return err; + } + } + ht->refcnt = 1; + ht->divisor = divisor; + ht->handle = handle; + ht->prio = tp->prio; + idr_init(&ht->handle_idr); + ht->flags = userflags; + + err = u32_replace_hw_hnode(tp, ht, userflags, extack); + if (err) { + idr_remove(&tp_c->handle_idr, handle); + kfree(ht); + return err; + } + + RCU_INIT_POINTER(ht->next, tp_c->hlist); + rcu_assign_pointer(tp_c->hlist, ht); + *arg = ht; + + return 0; + } + + if (tb[TCA_U32_HASH]) { + htid = nla_get_u32(tb[TCA_U32_HASH]); + if (TC_U32_HTID(htid) == TC_U32_ROOT) { + ht = rtnl_dereference(tp->root); + htid = ht->handle; + } else { + ht = u32_lookup_ht(tp->data, TC_U32_HTID(htid)); + if (!ht) { + NL_SET_ERR_MSG_MOD(extack, "Specified hash table not found"); + return -EINVAL; + } + } + } else { + ht = rtnl_dereference(tp->root); + htid = ht->handle; + } + + if (ht->divisor < TC_U32_HASH(htid)) { + NL_SET_ERR_MSG_MOD(extack, "Specified hash table buckets exceed configured value"); + return -EINVAL; + } + + /* At this point, we need to derive the new handle that will be used to + * uniquely map the identity of this table match entry. The + * identity of the entry that we need to construct is 32 bits made of: + * htid(12b):bucketid(8b):node/entryid(12b) + * + * At this point _we have the table(ht)_ in which we will insert this + * entry. We carry the table's id in variable "htid". + * Note that earlier code picked the ht selection either by a) the user + * providing the htid specified via TCA_U32_HASH attribute or b) when + * no such attribute is passed then the root ht, is default to at ID + * 0x[800][00][000]. Rule: the root table has a single bucket with ID 0. + * If OTOH the user passed us the htid, they may also pass a bucketid of + * choice. 0 is fine. For example a user htid is 0x[600][01][000] it is + * indicating hash bucketid of 1. Rule: the entry/node ID _cannot_ be + * passed via the htid, so even if it was non-zero it will be ignored. + * + * We may also have a handle, if the user passed one. The handle also + * carries the same addressing of htid(12b):bucketid(8b):node/entryid(12b). + * Rule: the bucketid on the handle is ignored even if one was passed; + * rather the value on "htid" is always assumed to be the bucketid. + */ + if (handle) { + /* Rule: The htid from handle and tableid from htid must match */ + if (TC_U32_HTID(handle) && TC_U32_HTID(handle ^ htid)) { + NL_SET_ERR_MSG_MOD(extack, "Handle specified hash table address mismatch"); + return -EINVAL; + } + /* Ok, so far we have a valid htid(12b):bucketid(8b) but we + * need to finalize the table entry identification with the last + * part - the node/entryid(12b)). Rule: Nodeid _cannot be 0_ for + * entries. 
Rule: nodeid of 0 is reserved only for tables(see + * earlier code which processes TC_U32_DIVISOR attribute). + * Rule: The nodeid can only be derived from the handle (and not + * htid). + * Rule: if the handle specified zero for the node id example + * 0x60000000, then pick a new nodeid from the pool of IDs + * this hash table has been allocating from. + * If OTOH it is specified (i.e for example the user passed a + * handle such as 0x60000123), then we use it generate our final + * handle which is used to uniquely identify the match entry. + */ + if (!TC_U32_NODE(handle)) { + handle = gen_new_kid(ht, htid); + } else { + handle = htid | TC_U32_NODE(handle); + err = idr_alloc_u32(&ht->handle_idr, NULL, &handle, + handle, GFP_KERNEL); + if (err) + return err; + } + } else { + /* The user did not give us a handle; lets just generate one + * from the table's pool of nodeids. + */ + handle = gen_new_kid(ht, htid); + } + + if (tb[TCA_U32_SEL] == NULL) { + NL_SET_ERR_MSG_MOD(extack, "Selector not specified"); + err = -EINVAL; + goto erridr; + } + + s = nla_data(tb[TCA_U32_SEL]); + sel_size = struct_size(s, keys, s->nkeys); + if (nla_len(tb[TCA_U32_SEL]) < sel_size) { + err = -EINVAL; + goto erridr; + } + + n = kzalloc(struct_size(n, sel.keys, s->nkeys), GFP_KERNEL); + if (n == NULL) { + err = -ENOBUFS; + goto erridr; + } + +#ifdef CONFIG_CLS_U32_PERF + n->pf = __alloc_percpu(struct_size(n->pf, kcnts, s->nkeys), + __alignof__(struct tc_u32_pcnt)); + if (!n->pf) { + err = -ENOBUFS; + goto errfree; + } +#endif + + unsafe_memcpy(&n->sel, s, sel_size, + /* A composite flex-array structure destination, + * which was correctly sized with struct_size(), + * bounds-checked against nla_len(), and allocated + * above. */); + RCU_INIT_POINTER(n->ht_up, ht); + n->handle = handle; + n->fshift = s->hmask ? 
ffs(ntohl(s->hmask)) - 1 : 0; + n->flags = userflags; + + err = tcf_exts_init(&n->exts, net, TCA_U32_ACT, TCA_U32_POLICE); + if (err < 0) + goto errout; + +#ifdef CONFIG_CLS_U32_MARK + n->pcpu_success = alloc_percpu(u32); + if (!n->pcpu_success) { + err = -ENOMEM; + goto errout; + } + + if (tb[TCA_U32_MARK]) { + struct tc_u32_mark *mark; + + mark = nla_data(tb[TCA_U32_MARK]); + n->val = mark->val; + n->mask = mark->mask; + } +#endif + + err = u32_set_parms(net, tp, n, tb, tca[TCA_RATE], + flags, n->flags, extack); + + u32_bind_filter(tp, n, base, tb); + + if (err == 0) { + struct tc_u_knode __rcu **ins; + struct tc_u_knode *pins; + + err = u32_replace_hw_knode(tp, n, flags, extack); + if (err) + goto errunbind; + + if (!tc_in_hw(n->flags)) + n->flags |= TCA_CLS_FLAGS_NOT_IN_HW; + + ins = &ht->ht[TC_U32_HASH(handle)]; + for (pins = rtnl_dereference(*ins); pins; + ins = &pins->next, pins = rtnl_dereference(*ins)) + if (TC_U32_NODE(handle) < TC_U32_NODE(pins->handle)) + break; + + RCU_INIT_POINTER(n->next, pins); + rcu_assign_pointer(*ins, n); + tp_c->knodes++; + *arg = n; + return 0; + } + +errunbind: + u32_unbind_filter(tp, n, tb); + +#ifdef CONFIG_CLS_U32_MARK + free_percpu(n->pcpu_success); +#endif + +errout: + tcf_exts_destroy(&n->exts); +#ifdef CONFIG_CLS_U32_PERF +errfree: + free_percpu(n->pf); +#endif + kfree(n); +erridr: + idr_remove(&ht->handle_idr, handle); + return err; +} + +static void u32_walk(struct tcf_proto *tp, struct tcf_walker *arg, + bool rtnl_held) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_hnode *ht; + struct tc_u_knode *n; + unsigned int h; + + if (arg->stop) + return; + + for (ht = rtnl_dereference(tp_c->hlist); + ht; + ht = rtnl_dereference(ht->next)) { + if (ht->prio != tp->prio) + continue; + + if (!tc_cls_stats_dump(tp, arg, ht)) + return; + + for (h = 0; h <= ht->divisor; h++) { + for (n = rtnl_dereference(ht->ht[h]); + n; + n = rtnl_dereference(n->next)) { + if (!tc_cls_stats_dump(tp, arg, n)) + return; + } + } + } +} + +static int u32_reoffload_hnode(struct tcf_proto *tp, struct tc_u_hnode *ht, + bool add, flow_setup_cb_t *cb, void *cb_priv, + struct netlink_ext_ack *extack) +{ + struct tc_cls_u32_offload cls_u32 = {}; + int err; + + tc_cls_common_offload_init(&cls_u32.common, tp, ht->flags, extack); + cls_u32.command = add ? TC_CLSU32_NEW_HNODE : TC_CLSU32_DELETE_HNODE; + cls_u32.hnode.divisor = ht->divisor; + cls_u32.hnode.handle = ht->handle; + cls_u32.hnode.prio = ht->prio; + + err = cb(TC_SETUP_CLSU32, &cls_u32, cb_priv); + if (err && add && tc_skip_sw(ht->flags)) + return err; + + return 0; +} + +static int u32_reoffload_knode(struct tcf_proto *tp, struct tc_u_knode *n, + bool add, flow_setup_cb_t *cb, void *cb_priv, + struct netlink_ext_ack *extack) +{ + struct tc_u_hnode *ht = rtnl_dereference(n->ht_down); + struct tcf_block *block = tp->chain->block; + struct tc_cls_u32_offload cls_u32 = {}; + + tc_cls_common_offload_init(&cls_u32.common, tp, n->flags, extack); + cls_u32.command = add ? 
+ TC_CLSU32_REPLACE_KNODE : TC_CLSU32_DELETE_KNODE; + cls_u32.knode.handle = n->handle; + + if (add) { + cls_u32.knode.fshift = n->fshift; +#ifdef CONFIG_CLS_U32_MARK + cls_u32.knode.val = n->val; + cls_u32.knode.mask = n->mask; +#else + cls_u32.knode.val = 0; + cls_u32.knode.mask = 0; +#endif + cls_u32.knode.sel = &n->sel; + cls_u32.knode.res = &n->res; + cls_u32.knode.exts = &n->exts; + if (n->ht_down) + cls_u32.knode.link_handle = ht->handle; + } + + return tc_setup_cb_reoffload(block, tp, add, cb, TC_SETUP_CLSU32, + &cls_u32, cb_priv, &n->flags, + &n->in_hw_count); +} + +static int u32_reoffload(struct tcf_proto *tp, bool add, flow_setup_cb_t *cb, + void *cb_priv, struct netlink_ext_ack *extack) +{ + struct tc_u_common *tp_c = tp->data; + struct tc_u_hnode *ht; + struct tc_u_knode *n; + unsigned int h; + int err; + + for (ht = rtnl_dereference(tp_c->hlist); + ht; + ht = rtnl_dereference(ht->next)) { + if (ht->prio != tp->prio) + continue; + + /* When adding filters to a new dev, try to offload the + * hashtable first. When removing, do the filters before the + * hashtable. + */ + if (add && !tc_skip_hw(ht->flags)) { + err = u32_reoffload_hnode(tp, ht, add, cb, cb_priv, + extack); + if (err) + return err; + } + + for (h = 0; h <= ht->divisor; h++) { + for (n = rtnl_dereference(ht->ht[h]); + n; + n = rtnl_dereference(n->next)) { + if (tc_skip_hw(n->flags)) + continue; + + err = u32_reoffload_knode(tp, n, add, cb, + cb_priv, extack); + if (err) + return err; + } + } + + if (!add && !tc_skip_hw(ht->flags)) + u32_reoffload_hnode(tp, ht, add, cb, cb_priv, extack); + } + + return 0; +} + +static void u32_bind_class(void *fh, u32 classid, unsigned long cl, void *q, + unsigned long base) +{ + struct tc_u_knode *n = fh; + + tc_cls_bind_class(classid, cl, q, &n->res, base); +} + +static int u32_dump(struct net *net, struct tcf_proto *tp, void *fh, + struct sk_buff *skb, struct tcmsg *t, bool rtnl_held) +{ + struct tc_u_knode *n = fh; + struct tc_u_hnode *ht_up, *ht_down; + struct nlattr *nest; + + if (n == NULL) + return skb->len; + + t->tcm_handle = n->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (TC_U32_KEY(n->handle) == 0) { + struct tc_u_hnode *ht = fh; + u32 divisor = ht->divisor + 1; + + if (nla_put_u32(skb, TCA_U32_DIVISOR, divisor)) + goto nla_put_failure; + } else { +#ifdef CONFIG_CLS_U32_PERF + struct tc_u32_pcnt *gpf; + int cpu; +#endif + + if (nla_put(skb, TCA_U32_SEL, struct_size(&n->sel, keys, n->sel.nkeys), + &n->sel)) + goto nla_put_failure; + + ht_up = rtnl_dereference(n->ht_up); + if (ht_up) { + u32 htid = n->handle & 0xFFFFF000; + if (nla_put_u32(skb, TCA_U32_HASH, htid)) + goto nla_put_failure; + } + if (n->res.classid && + nla_put_u32(skb, TCA_U32_CLASSID, n->res.classid)) + goto nla_put_failure; + + ht_down = rtnl_dereference(n->ht_down); + if (ht_down && + nla_put_u32(skb, TCA_U32_LINK, ht_down->handle)) + goto nla_put_failure; + + if (n->flags && nla_put_u32(skb, TCA_U32_FLAGS, n->flags)) + goto nla_put_failure; + +#ifdef CONFIG_CLS_U32_MARK + if ((n->val || n->mask)) { + struct tc_u32_mark mark = {.val = n->val, + .mask = n->mask, + .success = 0}; + int cpum; + + for_each_possible_cpu(cpum) { + __u32 cnt = *per_cpu_ptr(n->pcpu_success, cpum); + + mark.success += cnt; + } + + if (nla_put(skb, TCA_U32_MARK, sizeof(mark), &mark)) + goto nla_put_failure; + } +#endif + + if (tcf_exts_dump(skb, &n->exts) < 0) + goto nla_put_failure; + + if (n->ifindex) { + struct net_device *dev; + dev = __dev_get_by_index(net, 
n->ifindex); + if (dev && nla_put_string(skb, TCA_U32_INDEV, dev->name)) + goto nla_put_failure; + } +#ifdef CONFIG_CLS_U32_PERF + gpf = kzalloc(struct_size(gpf, kcnts, n->sel.nkeys), GFP_KERNEL); + if (!gpf) + goto nla_put_failure; + + for_each_possible_cpu(cpu) { + int i; + struct tc_u32_pcnt *pf = per_cpu_ptr(n->pf, cpu); + + gpf->rcnt += pf->rcnt; + gpf->rhit += pf->rhit; + for (i = 0; i < n->sel.nkeys; i++) + gpf->kcnts[i] += pf->kcnts[i]; + } + + if (nla_put_64bit(skb, TCA_U32_PCNT, struct_size(gpf, kcnts, n->sel.nkeys), + gpf, TCA_U32_PAD)) { + kfree(gpf); + goto nla_put_failure; + } + kfree(gpf); +#endif + } + + nla_nest_end(skb, nest); + + if (TC_U32_KEY(n->handle)) + if (tcf_exts_dump_stats(skb, &n->exts) < 0) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static struct tcf_proto_ops cls_u32_ops __read_mostly = { + .kind = "u32", + .classify = u32_classify, + .init = u32_init, + .destroy = u32_destroy, + .get = u32_get, + .change = u32_change, + .delete = u32_delete, + .walk = u32_walk, + .reoffload = u32_reoffload, + .dump = u32_dump, + .bind_class = u32_bind_class, + .owner = THIS_MODULE, +}; + +static int __init init_u32(void) +{ + int i, ret; + + pr_info("u32 classifier\n"); +#ifdef CONFIG_CLS_U32_PERF + pr_info(" Performance counters on\n"); +#endif + pr_info(" input device check on\n"); +#ifdef CONFIG_NET_CLS_ACT + pr_info(" Actions configured\n"); +#endif + tc_u_common_hash = kvmalloc_array(U32_HASH_SIZE, + sizeof(struct hlist_head), + GFP_KERNEL); + if (!tc_u_common_hash) + return -ENOMEM; + + for (i = 0; i < U32_HASH_SIZE; i++) + INIT_HLIST_HEAD(&tc_u_common_hash[i]); + + ret = register_tcf_proto_ops(&cls_u32_ops); + if (ret) + kvfree(tc_u_common_hash); + return ret; +} + +static void __exit exit_u32(void) +{ + unregister_tcf_proto_ops(&cls_u32_ops); + kvfree(tc_u_common_hash); +} + +module_init(init_u32) +module_exit(exit_u32) +MODULE_LICENSE("GPL"); diff --git a/net/sched/em_canid.c b/net/sched/em_canid.c new file mode 100644 index 000000000..5ea84dece --- /dev/null +++ b/net/sched/em_canid.c @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * em_canid.c Ematch rule to match CAN frames according to their CAN IDs + * + * Idea: Oliver Hartkopp + * Copyright: (c) 2011 Czech Technical University in Prague + * (c) 2011 Volkswagen Group Research + * Authors: Michal Sojka + * Pavel Pisa + * Rostislav Lisovy + * Funded by: Volkswagen Group Research + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define EM_CAN_RULES_MAX 500 + +struct canid_match { + /* For each SFF CAN ID (11 bit) there is one record in this bitfield */ + DECLARE_BITMAP(match_sff, (1 << CAN_SFF_ID_BITS)); + + int rules_count; + int sff_rules_count; + int eff_rules_count; + + /* + * Raw rules copied from netlink message; Used for sending + * information to userspace (when 'tc filter show' is invoked) + * AND when matching EFF frames + */ + struct can_filter rules_raw[]; +}; + +/** + * em_canid_get_id() - Extracts Can ID out of the sk_buff structure. 
+ * @skb: buffer to extract Can ID from + */ +static canid_t em_canid_get_id(struct sk_buff *skb) +{ + /* CAN ID is stored within the data field */ + struct can_frame *cf = (struct can_frame *)skb->data; + + return cf->can_id; +} + +static void em_canid_sff_match_add(struct canid_match *cm, u32 can_id, + u32 can_mask) +{ + int i; + + /* + * Limit can_mask and can_id to SFF range to + * protect against write after end of array + */ + can_mask &= CAN_SFF_MASK; + can_id &= can_mask; + + /* Single frame */ + if (can_mask == CAN_SFF_MASK) { + set_bit(can_id, cm->match_sff); + return; + } + + /* All frames */ + if (can_mask == 0) { + bitmap_fill(cm->match_sff, (1 << CAN_SFF_ID_BITS)); + return; + } + + /* + * Individual frame filter. + * Add record (set bit to 1) for each ID that + * conforms particular rule + */ + for (i = 0; i < (1 << CAN_SFF_ID_BITS); i++) { + if ((i & can_mask) == can_id) + set_bit(i, cm->match_sff); + } +} + +static inline struct canid_match *em_canid_priv(struct tcf_ematch *m) +{ + return (struct canid_match *)m->data; +} + +static int em_canid_match(struct sk_buff *skb, struct tcf_ematch *m, + struct tcf_pkt_info *info) +{ + struct canid_match *cm = em_canid_priv(m); + canid_t can_id; + int match = 0; + int i; + const struct can_filter *lp; + + can_id = em_canid_get_id(skb); + + if (can_id & CAN_EFF_FLAG) { + for (i = 0, lp = cm->rules_raw; + i < cm->eff_rules_count; i++, lp++) { + if (!(((lp->can_id ^ can_id) & lp->can_mask))) { + match = 1; + break; + } + } + } else { /* SFF */ + can_id &= CAN_SFF_MASK; + match = (test_bit(can_id, cm->match_sff) ? 1 : 0); + } + + return match; +} + +static int em_canid_change(struct net *net, void *data, int len, + struct tcf_ematch *m) +{ + struct can_filter *conf = data; /* Array with rules */ + struct canid_match *cm; + int i; + + if (!len) + return -EINVAL; + + if (len % sizeof(struct can_filter)) + return -EINVAL; + + if (len > sizeof(struct can_filter) * EM_CAN_RULES_MAX) + return -EINVAL; + + cm = kzalloc(sizeof(struct canid_match) + len, GFP_KERNEL); + if (!cm) + return -ENOMEM; + + cm->rules_count = len / sizeof(struct can_filter); + + /* + * We need two for() loops for copying rules into two contiguous + * areas in rules_raw to process all eff rules with a simple loop. + * NB: The configuration interface supports sff and eff rules. + * We do not support filters here that match for the same can_id + * provided in a SFF and EFF frame (e.g. 0x123 / 0x80000123). + * For this (unusual case) two filters have to be specified. The + * SFF/EFF separation is done with the CAN_EFF_FLAG in the can_id. 
+ */ + + /* Fill rules_raw with EFF rules first */ + for (i = 0; i < cm->rules_count; i++) { + if (conf[i].can_id & CAN_EFF_FLAG) { + memcpy(cm->rules_raw + cm->eff_rules_count, + &conf[i], + sizeof(struct can_filter)); + + cm->eff_rules_count++; + } + } + + /* append SFF frame rules */ + for (i = 0; i < cm->rules_count; i++) { + if (!(conf[i].can_id & CAN_EFF_FLAG)) { + memcpy(cm->rules_raw + + cm->eff_rules_count + + cm->sff_rules_count, + &conf[i], sizeof(struct can_filter)); + + cm->sff_rules_count++; + + em_canid_sff_match_add(cm, + conf[i].can_id, conf[i].can_mask); + } + } + + m->datalen = sizeof(struct canid_match) + len; + m->data = (unsigned long)cm; + return 0; +} + +static void em_canid_destroy(struct tcf_ematch *m) +{ + struct canid_match *cm = em_canid_priv(m); + + kfree(cm); +} + +static int em_canid_dump(struct sk_buff *skb, struct tcf_ematch *m) +{ + struct canid_match *cm = em_canid_priv(m); + + /* + * When configuring this ematch 'rules_count' is set not to exceed + * 'rules_raw' array size + */ + if (nla_put_nohdr(skb, sizeof(struct can_filter) * cm->rules_count, + &cm->rules_raw) < 0) + return -EMSGSIZE; + + return 0; +} + +static struct tcf_ematch_ops em_canid_ops = { + .kind = TCF_EM_CANID, + .change = em_canid_change, + .match = em_canid_match, + .destroy = em_canid_destroy, + .dump = em_canid_dump, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_canid_ops.link) +}; + +static int __init init_em_canid(void) +{ + return tcf_em_register(&em_canid_ops); +} + +static void __exit exit_em_canid(void) +{ + tcf_em_unregister(&em_canid_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_canid); +module_exit(exit_em_canid); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_CANID); diff --git a/net/sched/em_cmp.c b/net/sched/em_cmp.c new file mode 100644 index 000000000..f17b049ea --- /dev/null +++ b/net/sched/em_cmp.c @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_cmp.c Simple packet data comparison ematch + * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include + +static inline int cmp_needs_transformation(struct tcf_em_cmp *cmp) +{ + return unlikely(cmp->flags & TCF_EM_CMP_TRANS); +} + +static int em_cmp_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + struct tcf_em_cmp *cmp = (struct tcf_em_cmp *) em->data; + unsigned char *ptr = tcf_get_base_ptr(skb, cmp->layer) + cmp->off; + u32 val = 0; + + if (!tcf_valid_offset(skb, ptr, cmp->align)) + return 0; + + switch (cmp->align) { + case TCF_EM_ALIGN_U8: + val = *ptr; + break; + + case TCF_EM_ALIGN_U16: + val = get_unaligned_be16(ptr); + + if (cmp_needs_transformation(cmp)) + val = be16_to_cpu(val); + break; + + case TCF_EM_ALIGN_U32: + /* Worth checking boundaries? The branching seems + * to get worse. Visit again. 
+ */ + val = get_unaligned_be32(ptr); + + if (cmp_needs_transformation(cmp)) + val = be32_to_cpu(val); + break; + + default: + return 0; + } + + if (cmp->mask) + val &= cmp->mask; + + switch (cmp->opnd) { + case TCF_EM_OPND_EQ: + return val == cmp->val; + case TCF_EM_OPND_LT: + return val < cmp->val; + case TCF_EM_OPND_GT: + return val > cmp->val; + } + + return 0; +} + +static struct tcf_ematch_ops em_cmp_ops = { + .kind = TCF_EM_CMP, + .datalen = sizeof(struct tcf_em_cmp), + .match = em_cmp_match, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_cmp_ops.link) +}; + +static int __init init_em_cmp(void) +{ + return tcf_em_register(&em_cmp_ops); +} + +static void __exit exit_em_cmp(void) +{ + tcf_em_unregister(&em_cmp_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_cmp); +module_exit(exit_em_cmp); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_CMP); diff --git a/net/sched/em_ipset.c b/net/sched/em_ipset.c new file mode 100644 index 000000000..c95cf86fb --- /dev/null +++ b/net/sched/em_ipset.c @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/em_ipset.c ipset ematch + * + * Copyright (c) 2012 Florian Westphal + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int em_ipset_change(struct net *net, void *data, int data_len, + struct tcf_ematch *em) +{ + struct xt_set_info *set = data; + ip_set_id_t index; + + if (data_len != sizeof(*set)) + return -EINVAL; + + index = ip_set_nfnl_get_byindex(net, set->index); + if (index == IPSET_INVALID_ID) + return -ENOENT; + + em->datalen = sizeof(*set); + em->data = (unsigned long)kmemdup(data, em->datalen, GFP_KERNEL); + if (em->data) + return 0; + + ip_set_nfnl_put(net, index); + return -ENOMEM; +} + +static void em_ipset_destroy(struct tcf_ematch *em) +{ + const struct xt_set_info *set = (const void *) em->data; + if (set) { + ip_set_nfnl_put(em->net, set->index); + kfree((void *) em->data); + } +} + +static int em_ipset_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + struct ip_set_adt_opt opt; + struct xt_action_param acpar; + const struct xt_set_info *set = (const void *) em->data; + struct net_device *dev, *indev = NULL; + struct nf_hook_state state = { + .net = em->net, + }; + int ret, network_offset; + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + state.pf = NFPROTO_IPV4; + if (!pskb_network_may_pull(skb, sizeof(struct iphdr))) + return 0; + acpar.thoff = ip_hdrlen(skb); + break; + case htons(ETH_P_IPV6): + state.pf = NFPROTO_IPV6; + if (!pskb_network_may_pull(skb, sizeof(struct ipv6hdr))) + return 0; + /* doesn't call ipv6_find_hdr() because ipset doesn't use thoff, yet */ + acpar.thoff = sizeof(struct ipv6hdr); + break; + default: + return 0; + } + + opt.family = state.pf; + opt.dim = set->dim; + opt.flags = set->flags; + opt.cmdflags = 0; + opt.ext.timeout = ~0u; + + network_offset = skb_network_offset(skb); + skb_pull(skb, network_offset); + + dev = skb->dev; + + rcu_read_lock(); + + if (skb->skb_iif) + indev = dev_get_by_index_rcu(em->net, skb->skb_iif); + + state.in = indev ? 
indev : dev; + state.out = dev; + acpar.state = &state; + + ret = ip_set_test(set->index, skb, &acpar, &opt); + + rcu_read_unlock(); + + skb_push(skb, network_offset); + return ret; +} + +static struct tcf_ematch_ops em_ipset_ops = { + .kind = TCF_EM_IPSET, + .change = em_ipset_change, + .destroy = em_ipset_destroy, + .match = em_ipset_match, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_ipset_ops.link) +}; + +static int __init init_em_ipset(void) +{ + return tcf_em_register(&em_ipset_ops); +} + +static void __exit exit_em_ipset(void) +{ + tcf_em_unregister(&em_ipset_ops); +} + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Florian Westphal "); +MODULE_DESCRIPTION("TC extended match for IP sets"); + +module_init(init_em_ipset); +module_exit(exit_em_ipset); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_IPSET); diff --git a/net/sched/em_ipt.c b/net/sched/em_ipt.c new file mode 100644 index 000000000..3650117da --- /dev/null +++ b/net/sched/em_ipt.c @@ -0,0 +1,297 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_ipt.c IPtables matches Ematch + * + * (c) 2018 Eyal Birger + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct em_ipt_match { + const struct xt_match *match; + u32 hook; + u8 nfproto; + u8 match_data[] __aligned(8); +}; + +struct em_ipt_xt_match { + char *match_name; + int (*validate_match_data)(struct nlattr **tb, u8 mrev); +}; + +static const struct nla_policy em_ipt_policy[TCA_EM_IPT_MAX + 1] = { + [TCA_EM_IPT_MATCH_NAME] = { .type = NLA_STRING, + .len = XT_EXTENSION_MAXNAMELEN }, + [TCA_EM_IPT_MATCH_REVISION] = { .type = NLA_U8 }, + [TCA_EM_IPT_HOOK] = { .type = NLA_U32 }, + [TCA_EM_IPT_NFPROTO] = { .type = NLA_U8 }, + [TCA_EM_IPT_MATCH_DATA] = { .type = NLA_UNSPEC }, +}; + +static int check_match(struct net *net, struct em_ipt_match *im, int mdata_len) +{ + struct xt_mtchk_param mtpar = {}; + union { + struct ipt_entry e4; + struct ip6t_entry e6; + } e = {}; + + mtpar.net = net; + mtpar.table = "filter"; + mtpar.hook_mask = 1 << im->hook; + mtpar.family = im->match->family; + mtpar.match = im->match; + mtpar.entryinfo = &e; + mtpar.matchinfo = (void *)im->match_data; + return xt_check_match(&mtpar, mdata_len, 0, 0); +} + +static int policy_validate_match_data(struct nlattr **tb, u8 mrev) +{ + if (mrev != 0) { + pr_err("only policy match revision 0 supported"); + return -EINVAL; + } + + if (nla_get_u32(tb[TCA_EM_IPT_HOOK]) != NF_INET_PRE_ROUTING) { + pr_err("policy can only be matched on NF_INET_PRE_ROUTING"); + return -EINVAL; + } + + return 0; +} + +static int addrtype_validate_match_data(struct nlattr **tb, u8 mrev) +{ + if (mrev != 1) { + pr_err("only addrtype match revision 1 supported"); + return -EINVAL; + } + + return 0; +} + +static const struct em_ipt_xt_match em_ipt_xt_matches[] = { + { + .match_name = "policy", + .validate_match_data = policy_validate_match_data + }, + { + .match_name = "addrtype", + .validate_match_data = addrtype_validate_match_data + }, + {} +}; + +static struct xt_match *get_xt_match(struct nlattr **tb) +{ + const struct em_ipt_xt_match *m; + struct nlattr *mname_attr; + u8 nfproto, mrev = 0; + int ret; + + mname_attr = tb[TCA_EM_IPT_MATCH_NAME]; + for (m = em_ipt_xt_matches; m->match_name; m++) { + if (!nla_strcmp(mname_attr, m->match_name)) + break; + } + + if (!m->match_name) { + pr_err("Unsupported xt match"); + return ERR_PTR(-EINVAL); + } + + if (tb[TCA_EM_IPT_MATCH_REVISION]) + mrev = nla_get_u8(tb[TCA_EM_IPT_MATCH_REVISION]); + + ret = 
m->validate_match_data(tb, mrev); + if (ret < 0) + return ERR_PTR(ret); + + nfproto = nla_get_u8(tb[TCA_EM_IPT_NFPROTO]); + return xt_request_find_match(nfproto, m->match_name, mrev); +} + +static int em_ipt_change(struct net *net, void *data, int data_len, + struct tcf_ematch *em) +{ + struct nlattr *tb[TCA_EM_IPT_MAX + 1]; + struct em_ipt_match *im = NULL; + struct xt_match *match; + int mdata_len, ret; + u8 nfproto; + + ret = nla_parse_deprecated(tb, TCA_EM_IPT_MAX, data, data_len, + em_ipt_policy, NULL); + if (ret < 0) + return ret; + + if (!tb[TCA_EM_IPT_HOOK] || !tb[TCA_EM_IPT_MATCH_NAME] || + !tb[TCA_EM_IPT_MATCH_DATA] || !tb[TCA_EM_IPT_NFPROTO]) + return -EINVAL; + + nfproto = nla_get_u8(tb[TCA_EM_IPT_NFPROTO]); + switch (nfproto) { + case NFPROTO_IPV4: + case NFPROTO_IPV6: + break; + default: + return -EINVAL; + } + + match = get_xt_match(tb); + if (IS_ERR(match)) { + pr_err("unable to load match\n"); + return PTR_ERR(match); + } + + mdata_len = XT_ALIGN(nla_len(tb[TCA_EM_IPT_MATCH_DATA])); + im = kzalloc(sizeof(*im) + mdata_len, GFP_KERNEL); + if (!im) { + ret = -ENOMEM; + goto err; + } + + im->match = match; + im->hook = nla_get_u32(tb[TCA_EM_IPT_HOOK]); + im->nfproto = nfproto; + nla_memcpy(im->match_data, tb[TCA_EM_IPT_MATCH_DATA], mdata_len); + + ret = check_match(net, im, mdata_len); + if (ret) + goto err; + + em->datalen = sizeof(*im) + mdata_len; + em->data = (unsigned long)im; + return 0; + +err: + kfree(im); + module_put(match->me); + return ret; +} + +static void em_ipt_destroy(struct tcf_ematch *em) +{ + struct em_ipt_match *im = (void *)em->data; + + if (!im) + return; + + if (im->match->destroy) { + struct xt_mtdtor_param par = { + .net = em->net, + .match = im->match, + .matchinfo = im->match_data, + .family = im->match->family + }; + im->match->destroy(&par); + } + module_put(im->match->me); + kfree(im); +} + +static int em_ipt_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + const struct em_ipt_match *im = (const void *)em->data; + struct xt_action_param acpar = {}; + struct net_device *indev = NULL; + u8 nfproto = im->match->family; + struct nf_hook_state state; + int ret; + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + if (!pskb_network_may_pull(skb, sizeof(struct iphdr))) + return 0; + if (nfproto == NFPROTO_UNSPEC) + nfproto = NFPROTO_IPV4; + break; + case htons(ETH_P_IPV6): + if (!pskb_network_may_pull(skb, sizeof(struct ipv6hdr))) + return 0; + if (nfproto == NFPROTO_UNSPEC) + nfproto = NFPROTO_IPV6; + break; + default: + return 0; + } + + rcu_read_lock(); + + if (skb->skb_iif) + indev = dev_get_by_index_rcu(em->net, skb->skb_iif); + + nf_hook_state_init(&state, im->hook, nfproto, + indev ?: skb->dev, skb->dev, NULL, em->net, NULL); + + acpar.match = im->match; + acpar.matchinfo = im->match_data; + acpar.state = &state; + + ret = im->match->match(skb, &acpar); + + rcu_read_unlock(); + return ret; +} + +static int em_ipt_dump(struct sk_buff *skb, struct tcf_ematch *em) +{ + struct em_ipt_match *im = (void *)em->data; + + if (nla_put_string(skb, TCA_EM_IPT_MATCH_NAME, im->match->name) < 0) + return -EMSGSIZE; + if (nla_put_u32(skb, TCA_EM_IPT_HOOK, im->hook) < 0) + return -EMSGSIZE; + if (nla_put_u8(skb, TCA_EM_IPT_MATCH_REVISION, im->match->revision) < 0) + return -EMSGSIZE; + if (nla_put_u8(skb, TCA_EM_IPT_NFPROTO, im->nfproto) < 0) + return -EMSGSIZE; + if (nla_put(skb, TCA_EM_IPT_MATCH_DATA, + im->match->usersize ?: im->match->matchsize, + im->match_data) < 0) + return -EMSGSIZE; + + return 0; +} + 
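All of the em_* sources in this patch (em_canid, em_cmp, em_ipset, em_ipt, and em_meta below) plug into the ematch framework the same way: fill in a struct tcf_ematch_ops with at least a .match callback (plus optional .change/.destroy/.dump), register it with tcf_em_register() from module init, and remove it with tcf_em_unregister() on exit. The sketch below condenses that pattern into the smallest possible module; the em_skel name and the TCF_EM_SKEL kind value are placeholders invented for illustration, not an ematch that exists in the tree.

// SPDX-License-Identifier: GPL-2.0-or-later
/* Minimal ematch skeleton (illustrative only): "matches" every packet. */
#include <linux/module.h>
#include <linux/skbuff.h>
#include <net/pkt_cls.h>

#define TCF_EM_SKEL 99	/* placeholder kind; real ematches use TCF_EM_* ids from the uapi header */

static int em_skel_match(struct sk_buff *skb, struct tcf_ematch *em,
			 struct tcf_pkt_info *info)
{
	/* A real ematch would inspect skb/em->data here; this one always matches. */
	return 1;
}

static struct tcf_ematch_ops em_skel_ops = {
	.kind	= TCF_EM_SKEL,
	.match	= em_skel_match,
	.owner	= THIS_MODULE,
	.link	= LIST_HEAD_INIT(em_skel_ops.link)
};

static int __init init_em_skel(void)
{
	return tcf_em_register(&em_skel_ops);
}

static void __exit exit_em_skel(void)
{
	tcf_em_unregister(&em_skel_ops);
}

module_init(init_em_skel);
module_exit(exit_em_skel);
MODULE_LICENSE("GPL");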
+static struct tcf_ematch_ops em_ipt_ops = { + .kind = TCF_EM_IPT, + .change = em_ipt_change, + .destroy = em_ipt_destroy, + .match = em_ipt_match, + .dump = em_ipt_dump, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_ipt_ops.link) +}; + +static int __init init_em_ipt(void) +{ + return tcf_em_register(&em_ipt_ops); +} + +static void __exit exit_em_ipt(void) +{ + tcf_em_unregister(&em_ipt_ops); +} + +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Eyal Birger "); +MODULE_DESCRIPTION("TC extended match for IPtables matches"); + +module_init(init_em_ipt); +module_exit(exit_em_ipt); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_IPT); diff --git a/net/sched/em_meta.c b/net/sched/em_meta.c new file mode 100644 index 000000000..6f2f135aa --- /dev/null +++ b/net/sched/em_meta.c @@ -0,0 +1,1014 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_meta.c Metadata ematch + * + * Authors: Thomas Graf + * + * ========================================================================== + * + * The metadata ematch compares two meta objects where each object + * represents either a meta value stored in the kernel or a static + * value provided by userspace. The objects are not provided by + * userspace itself but rather a definition providing the information + * to build them. Every object is of a certain type which must be + * equal to the object it is being compared to. + * + * The definition of a objects conists of the type (meta type), a + * identifier (meta id) and additional type specific information. + * The meta id is either TCF_META_TYPE_VALUE for values provided by + * userspace or a index to the meta operations table consisting of + * function pointers to type specific meta data collectors returning + * the value of the requested meta value. + * + * lvalue rvalue + * +-----------+ +-----------+ + * | type: INT | | type: INT | + * def | id: DEV | | id: VALUE | + * | data: | | data: 3 | + * +-----------+ +-----------+ + * | | + * ---> meta_ops[INT][DEV](...) | + * | | + * ----------- | + * V V + * +-----------+ +-----------+ + * | type: INT | | type: INT | + * obj | id: DEV | | id: VALUE | + * | data: 2 |<--data got filled out | data: 3 | + * +-----------+ +-----------+ + * | | + * --------------> 2 equals 3 <-------------- + * + * This is a simplified schema, the complexity varies depending + * on the meta type. Obviously, the length of the data must also + * be provided for non-numeric types. + * + * Additionally, type dependent modifiers such as shift operators + * or mask may be applied to extend the functionaliy. As of now, + * the variable length type supports shifting the byte string to + * the right, eating up any number of octets and thus supporting + * wildcard interface name comparisons such as "ppp%" matching + * ppp0..9. + * + * NOTE: Certain meta values depend on other subsystems and are + * only available if that subsystem is enabled in the kernel. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct meta_obj { + unsigned long value; + unsigned int len; +}; + +struct meta_value { + struct tcf_meta_val hdr; + unsigned long val; + unsigned int len; +}; + +struct meta_match { + struct meta_value lvalue; + struct meta_value rvalue; +}; + +static inline int meta_id(struct meta_value *v) +{ + return TCF_META_ID(v->hdr.kind); +} + +static inline int meta_type(struct meta_value *v) +{ + return TCF_META_TYPE(v->hdr.kind); +} + +#define META_COLLECTOR(FUNC) static void meta_##FUNC(struct sk_buff *skb, \ + struct tcf_pkt_info *info, struct meta_value *v, \ + struct meta_obj *dst, int *err) + +/************************************************************************** + * System status & misc + **************************************************************************/ + +META_COLLECTOR(int_random) +{ + get_random_bytes(&dst->value, sizeof(dst->value)); +} + +static inline unsigned long fixed_loadavg(int load) +{ + int rnd_load = load + (FIXED_1/200); + int rnd_frac = ((rnd_load & (FIXED_1-1)) * 100) >> FSHIFT; + + return ((rnd_load >> FSHIFT) * 100) + rnd_frac; +} + +META_COLLECTOR(int_loadavg_0) +{ + dst->value = fixed_loadavg(avenrun[0]); +} + +META_COLLECTOR(int_loadavg_1) +{ + dst->value = fixed_loadavg(avenrun[1]); +} + +META_COLLECTOR(int_loadavg_2) +{ + dst->value = fixed_loadavg(avenrun[2]); +} + +/************************************************************************** + * Device names & indices + **************************************************************************/ + +static inline int int_dev(struct net_device *dev, struct meta_obj *dst) +{ + if (unlikely(dev == NULL)) + return -1; + + dst->value = dev->ifindex; + return 0; +} + +static inline int var_dev(struct net_device *dev, struct meta_obj *dst) +{ + if (unlikely(dev == NULL)) + return -1; + + dst->value = (unsigned long) dev->name; + dst->len = strlen(dev->name); + return 0; +} + +META_COLLECTOR(int_dev) +{ + *err = int_dev(skb->dev, dst); +} + +META_COLLECTOR(var_dev) +{ + *err = var_dev(skb->dev, dst); +} + +/************************************************************************** + * vlan tag + **************************************************************************/ + +META_COLLECTOR(int_vlan_tag) +{ + unsigned short tag; + + if (skb_vlan_tag_present(skb)) + dst->value = skb_vlan_tag_get(skb); + else if (!__vlan_get_tag(skb, &tag)) + dst->value = tag; + else + *err = -1; +} + + + +/************************************************************************** + * skb attributes + **************************************************************************/ + +META_COLLECTOR(int_priority) +{ + dst->value = skb->priority; +} + +META_COLLECTOR(int_protocol) +{ + /* Let userspace take care of the byte ordering */ + dst->value = skb_protocol(skb, false); +} + +META_COLLECTOR(int_pkttype) +{ + dst->value = skb->pkt_type; +} + +META_COLLECTOR(int_pktlen) +{ + dst->value = skb->len; +} + +META_COLLECTOR(int_datalen) +{ + dst->value = skb->data_len; +} + +META_COLLECTOR(int_maclen) +{ + dst->value = skb->mac_len; +} + +META_COLLECTOR(int_rxhash) +{ + dst->value = skb_get_hash(skb); +} + +/************************************************************************** + * Netfilter + **************************************************************************/ + +META_COLLECTOR(int_mark) +{ + dst->value = skb->mark; +} + 
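For reference: every META_COLLECTOR(name) above is shorthand for a collector function definition, per the #define earlier in this file. Expanding the last one, META_COLLECTOR(int_mark), gives:

/* what the preprocessor produces for META_COLLECTOR(int_mark) { ... } */
static void meta_int_mark(struct sk_buff *skb, struct tcf_pkt_info *info,
			  struct meta_value *v, struct meta_obj *dst, int *err)
{
	dst->value = skb->mark;
}

This is why the __meta_ops table further down can reference each collector through META_FUNC(name), i.e. { .get = meta_##name }.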
+/************************************************************************** + * Traffic Control + **************************************************************************/ + +META_COLLECTOR(int_tcindex) +{ + dst->value = skb->tc_index; +} + +/************************************************************************** + * Routing + **************************************************************************/ + +META_COLLECTOR(int_rtclassid) +{ + if (unlikely(skb_dst(skb) == NULL)) + *err = -1; + else +#ifdef CONFIG_IP_ROUTE_CLASSID + dst->value = skb_dst(skb)->tclassid; +#else + dst->value = 0; +#endif +} + +META_COLLECTOR(int_rtiif) +{ + if (unlikely(skb_rtable(skb) == NULL)) + *err = -1; + else + dst->value = inet_iif(skb); +} + +/************************************************************************** + * Socket Attributes + **************************************************************************/ + +#define skip_nonlocal(skb) \ + (unlikely(skb->sk == NULL)) + +META_COLLECTOR(int_sk_family) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + dst->value = skb->sk->sk_family; +} + +META_COLLECTOR(int_sk_state) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + dst->value = skb->sk->sk_state; +} + +META_COLLECTOR(int_sk_reuse) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + dst->value = skb->sk->sk_reuse; +} + +META_COLLECTOR(int_sk_bound_if) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + /* No error if bound_dev_if is 0, legal userspace check */ + dst->value = skb->sk->sk_bound_dev_if; +} + +META_COLLECTOR(var_sk_bound_if) +{ + int bound_dev_if; + + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + + bound_dev_if = READ_ONCE(skb->sk->sk_bound_dev_if); + if (bound_dev_if == 0) { + dst->value = (unsigned long) "any"; + dst->len = 3; + } else { + struct net_device *dev; + + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(skb->sk), + bound_dev_if); + *err = var_dev(dev, dst); + rcu_read_unlock(); + } +} + +META_COLLECTOR(int_sk_refcnt) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + dst->value = refcount_read(&skb->sk->sk_refcnt); +} + +META_COLLECTOR(int_sk_rcvbuf) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_rcvbuf; +} + +META_COLLECTOR(int_sk_shutdown) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_shutdown; +} + +META_COLLECTOR(int_sk_proto) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_protocol; +} + +META_COLLECTOR(int_sk_type) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_type; +} + +META_COLLECTOR(int_sk_rmem_alloc) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk_rmem_alloc_get(sk); +} + +META_COLLECTOR(int_sk_wmem_alloc) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk_wmem_alloc_get(sk); +} + +META_COLLECTOR(int_sk_omem_alloc) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = atomic_read(&sk->sk_omem_alloc); +} + +META_COLLECTOR(int_sk_rcv_qlen) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_receive_queue.qlen; +} + +META_COLLECTOR(int_sk_snd_qlen) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if 
(!sk) { + *err = -1; + return; + } + dst->value = sk->sk_write_queue.qlen; +} + +META_COLLECTOR(int_sk_wmem_queued) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_wmem_queued); +} + +META_COLLECTOR(int_sk_fwd_alloc) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk_forward_alloc_get(sk); +} + +META_COLLECTOR(int_sk_sndbuf) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_sndbuf; +} + +META_COLLECTOR(int_sk_alloc) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = (__force int) sk->sk_allocation; +} + +META_COLLECTOR(int_sk_hash) +{ + if (skip_nonlocal(skb)) { + *err = -1; + return; + } + dst->value = skb->sk->sk_hash; +} + +META_COLLECTOR(int_sk_lingertime) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_lingertime) / HZ; +} + +META_COLLECTOR(int_sk_err_qlen) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_error_queue.qlen; +} + +META_COLLECTOR(int_sk_ack_bl) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_ack_backlog); +} + +META_COLLECTOR(int_sk_max_ack_bl) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_max_ack_backlog); +} + +META_COLLECTOR(int_sk_prio) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_priority; +} + +META_COLLECTOR(int_sk_rcvlowat) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_rcvlowat); +} + +META_COLLECTOR(int_sk_rcvtimeo) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_rcvtimeo) / HZ; +} + +META_COLLECTOR(int_sk_sndtimeo) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = READ_ONCE(sk->sk_sndtimeo) / HZ; +} + +META_COLLECTOR(int_sk_sendmsg_off) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_frag.offset; +} + +META_COLLECTOR(int_sk_write_pend) +{ + const struct sock *sk = skb_to_full_sk(skb); + + if (!sk) { + *err = -1; + return; + } + dst->value = sk->sk_write_pending; +} + +/************************************************************************** + * Meta value collectors assignment table + **************************************************************************/ + +struct meta_ops { + void (*get)(struct sk_buff *, struct tcf_pkt_info *, + struct meta_value *, struct meta_obj *, int *); +}; + +#define META_ID(name) TCF_META_ID_##name +#define META_FUNC(name) { .get = meta_##name } + +/* Meta value operations table listing all meta value collectors and + * assigns them to a type and meta id. 
*/ +static struct meta_ops __meta_ops[TCF_META_TYPE_MAX + 1][TCF_META_ID_MAX + 1] = { + [TCF_META_TYPE_VAR] = { + [META_ID(DEV)] = META_FUNC(var_dev), + [META_ID(SK_BOUND_IF)] = META_FUNC(var_sk_bound_if), + }, + [TCF_META_TYPE_INT] = { + [META_ID(RANDOM)] = META_FUNC(int_random), + [META_ID(LOADAVG_0)] = META_FUNC(int_loadavg_0), + [META_ID(LOADAVG_1)] = META_FUNC(int_loadavg_1), + [META_ID(LOADAVG_2)] = META_FUNC(int_loadavg_2), + [META_ID(DEV)] = META_FUNC(int_dev), + [META_ID(PRIORITY)] = META_FUNC(int_priority), + [META_ID(PROTOCOL)] = META_FUNC(int_protocol), + [META_ID(PKTTYPE)] = META_FUNC(int_pkttype), + [META_ID(PKTLEN)] = META_FUNC(int_pktlen), + [META_ID(DATALEN)] = META_FUNC(int_datalen), + [META_ID(MACLEN)] = META_FUNC(int_maclen), + [META_ID(NFMARK)] = META_FUNC(int_mark), + [META_ID(TCINDEX)] = META_FUNC(int_tcindex), + [META_ID(RTCLASSID)] = META_FUNC(int_rtclassid), + [META_ID(RTIIF)] = META_FUNC(int_rtiif), + [META_ID(SK_FAMILY)] = META_FUNC(int_sk_family), + [META_ID(SK_STATE)] = META_FUNC(int_sk_state), + [META_ID(SK_REUSE)] = META_FUNC(int_sk_reuse), + [META_ID(SK_BOUND_IF)] = META_FUNC(int_sk_bound_if), + [META_ID(SK_REFCNT)] = META_FUNC(int_sk_refcnt), + [META_ID(SK_RCVBUF)] = META_FUNC(int_sk_rcvbuf), + [META_ID(SK_SNDBUF)] = META_FUNC(int_sk_sndbuf), + [META_ID(SK_SHUTDOWN)] = META_FUNC(int_sk_shutdown), + [META_ID(SK_PROTO)] = META_FUNC(int_sk_proto), + [META_ID(SK_TYPE)] = META_FUNC(int_sk_type), + [META_ID(SK_RMEM_ALLOC)] = META_FUNC(int_sk_rmem_alloc), + [META_ID(SK_WMEM_ALLOC)] = META_FUNC(int_sk_wmem_alloc), + [META_ID(SK_OMEM_ALLOC)] = META_FUNC(int_sk_omem_alloc), + [META_ID(SK_WMEM_QUEUED)] = META_FUNC(int_sk_wmem_queued), + [META_ID(SK_RCV_QLEN)] = META_FUNC(int_sk_rcv_qlen), + [META_ID(SK_SND_QLEN)] = META_FUNC(int_sk_snd_qlen), + [META_ID(SK_ERR_QLEN)] = META_FUNC(int_sk_err_qlen), + [META_ID(SK_FORWARD_ALLOCS)] = META_FUNC(int_sk_fwd_alloc), + [META_ID(SK_ALLOCS)] = META_FUNC(int_sk_alloc), + [META_ID(SK_HASH)] = META_FUNC(int_sk_hash), + [META_ID(SK_LINGERTIME)] = META_FUNC(int_sk_lingertime), + [META_ID(SK_ACK_BACKLOG)] = META_FUNC(int_sk_ack_bl), + [META_ID(SK_MAX_ACK_BACKLOG)] = META_FUNC(int_sk_max_ack_bl), + [META_ID(SK_PRIO)] = META_FUNC(int_sk_prio), + [META_ID(SK_RCVLOWAT)] = META_FUNC(int_sk_rcvlowat), + [META_ID(SK_RCVTIMEO)] = META_FUNC(int_sk_rcvtimeo), + [META_ID(SK_SNDTIMEO)] = META_FUNC(int_sk_sndtimeo), + [META_ID(SK_SENDMSG_OFF)] = META_FUNC(int_sk_sendmsg_off), + [META_ID(SK_WRITE_PENDING)] = META_FUNC(int_sk_write_pend), + [META_ID(VLAN_TAG)] = META_FUNC(int_vlan_tag), + [META_ID(RXHASH)] = META_FUNC(int_rxhash), + } +}; + +static inline struct meta_ops *meta_ops(struct meta_value *val) +{ + return &__meta_ops[meta_type(val)][meta_id(val)]; +} + +/************************************************************************** + * Type specific operations for TCF_META_TYPE_VAR + **************************************************************************/ + +static int meta_var_compare(struct meta_obj *a, struct meta_obj *b) +{ + int r = a->len - b->len; + + if (r == 0) + r = memcmp((void *) a->value, (void *) b->value, a->len); + + return r; +} + +static int meta_var_change(struct meta_value *dst, struct nlattr *nla) +{ + int len = nla_len(nla); + + dst->val = (unsigned long)kmemdup(nla_data(nla), len, GFP_KERNEL); + if (dst->val == 0UL) + return -ENOMEM; + dst->len = len; + return 0; +} + +static void meta_var_destroy(struct meta_value *v) +{ + kfree((void *) v->val); +} + +static void meta_var_apply_extras(struct meta_value *v, + struct 
meta_obj *dst) +{ + int shift = v->hdr.shift; + + if (shift && shift < dst->len) + dst->len -= shift; +} + +static int meta_var_dump(struct sk_buff *skb, struct meta_value *v, int tlv) +{ + if (v->val && v->len && + nla_put(skb, tlv, v->len, (void *) v->val)) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +/************************************************************************** + * Type specific operations for TCF_META_TYPE_INT + **************************************************************************/ + +static int meta_int_compare(struct meta_obj *a, struct meta_obj *b) +{ + /* Let gcc optimize it, the unlikely is not really based on + * some numbers but jump free code for mismatches seems + * more logical. */ + if (unlikely(a->value == b->value)) + return 0; + else if (a->value < b->value) + return -1; + else + return 1; +} + +static int meta_int_change(struct meta_value *dst, struct nlattr *nla) +{ + if (nla_len(nla) >= sizeof(unsigned long)) { + dst->val = *(unsigned long *) nla_data(nla); + dst->len = sizeof(unsigned long); + } else if (nla_len(nla) == sizeof(u32)) { + dst->val = nla_get_u32(nla); + dst->len = sizeof(u32); + } else + return -EINVAL; + + return 0; +} + +static void meta_int_apply_extras(struct meta_value *v, + struct meta_obj *dst) +{ + if (v->hdr.shift) + dst->value >>= v->hdr.shift; + + if (v->val) + dst->value &= v->val; +} + +static int meta_int_dump(struct sk_buff *skb, struct meta_value *v, int tlv) +{ + if (v->len == sizeof(unsigned long)) { + if (nla_put(skb, tlv, sizeof(unsigned long), &v->val)) + goto nla_put_failure; + } else if (v->len == sizeof(u32)) { + if (nla_put_u32(skb, tlv, v->val)) + goto nla_put_failure; + } + + return 0; + +nla_put_failure: + return -1; +} + +/************************************************************************** + * Type specific operations table + **************************************************************************/ + +struct meta_type_ops { + void (*destroy)(struct meta_value *); + int (*compare)(struct meta_obj *, struct meta_obj *); + int (*change)(struct meta_value *, struct nlattr *); + void (*apply_extras)(struct meta_value *, struct meta_obj *); + int (*dump)(struct sk_buff *, struct meta_value *, int); +}; + +static const struct meta_type_ops __meta_type_ops[TCF_META_TYPE_MAX + 1] = { + [TCF_META_TYPE_VAR] = { + .destroy = meta_var_destroy, + .compare = meta_var_compare, + .change = meta_var_change, + .apply_extras = meta_var_apply_extras, + .dump = meta_var_dump + }, + [TCF_META_TYPE_INT] = { + .compare = meta_int_compare, + .change = meta_int_change, + .apply_extras = meta_int_apply_extras, + .dump = meta_int_dump + } +}; + +static inline const struct meta_type_ops *meta_type_ops(struct meta_value *v) +{ + return &__meta_type_ops[meta_type(v)]; +} + +/************************************************************************** + * Core + **************************************************************************/ + +static int meta_get(struct sk_buff *skb, struct tcf_pkt_info *info, + struct meta_value *v, struct meta_obj *dst) +{ + int err = 0; + + if (meta_id(v) == TCF_META_ID_VALUE) { + dst->value = v->val; + dst->len = v->len; + return 0; + } + + meta_ops(v)->get(skb, info, v, dst, &err); + if (err < 0) + return err; + + if (meta_type_ops(v)->apply_extras) + meta_type_ops(v)->apply_extras(v, dst); + + return 0; +} + +static int em_meta_match(struct sk_buff *skb, struct tcf_ematch *m, + struct tcf_pkt_info *info) +{ + int r; + struct meta_match *meta = (struct meta_match *) 
m->data; + struct meta_obj l_value, r_value; + + if (meta_get(skb, info, &meta->lvalue, &l_value) < 0 || + meta_get(skb, info, &meta->rvalue, &r_value) < 0) + return 0; + + r = meta_type_ops(&meta->lvalue)->compare(&l_value, &r_value); + + switch (meta->lvalue.hdr.op) { + case TCF_EM_OPND_EQ: + return !r; + case TCF_EM_OPND_LT: + return r < 0; + case TCF_EM_OPND_GT: + return r > 0; + } + + return 0; +} + +static void meta_delete(struct meta_match *meta) +{ + if (meta) { + const struct meta_type_ops *ops = meta_type_ops(&meta->lvalue); + + if (ops && ops->destroy) { + ops->destroy(&meta->lvalue); + ops->destroy(&meta->rvalue); + } + } + + kfree(meta); +} + +static inline int meta_change_data(struct meta_value *dst, struct nlattr *nla) +{ + if (nla) { + if (nla_len(nla) == 0) + return -EINVAL; + + return meta_type_ops(dst)->change(dst, nla); + } + + return 0; +} + +static inline int meta_is_supported(struct meta_value *val) +{ + return !meta_id(val) || meta_ops(val)->get; +} + +static const struct nla_policy meta_policy[TCA_EM_META_MAX + 1] = { + [TCA_EM_META_HDR] = { .len = sizeof(struct tcf_meta_hdr) }, +}; + +static int em_meta_change(struct net *net, void *data, int len, + struct tcf_ematch *m) +{ + int err; + struct nlattr *tb[TCA_EM_META_MAX + 1]; + struct tcf_meta_hdr *hdr; + struct meta_match *meta = NULL; + + err = nla_parse_deprecated(tb, TCA_EM_META_MAX, data, len, + meta_policy, NULL); + if (err < 0) + goto errout; + + err = -EINVAL; + if (tb[TCA_EM_META_HDR] == NULL) + goto errout; + hdr = nla_data(tb[TCA_EM_META_HDR]); + + if (TCF_META_TYPE(hdr->left.kind) != TCF_META_TYPE(hdr->right.kind) || + TCF_META_TYPE(hdr->left.kind) > TCF_META_TYPE_MAX || + TCF_META_ID(hdr->left.kind) > TCF_META_ID_MAX || + TCF_META_ID(hdr->right.kind) > TCF_META_ID_MAX) + goto errout; + + meta = kzalloc(sizeof(*meta), GFP_KERNEL); + if (meta == NULL) { + err = -ENOMEM; + goto errout; + } + + memcpy(&meta->lvalue.hdr, &hdr->left, sizeof(hdr->left)); + memcpy(&meta->rvalue.hdr, &hdr->right, sizeof(hdr->right)); + + if (!meta_is_supported(&meta->lvalue) || + !meta_is_supported(&meta->rvalue)) { + err = -EOPNOTSUPP; + goto errout; + } + + if (meta_change_data(&meta->lvalue, tb[TCA_EM_META_LVALUE]) < 0 || + meta_change_data(&meta->rvalue, tb[TCA_EM_META_RVALUE]) < 0) + goto errout; + + m->datalen = sizeof(*meta); + m->data = (unsigned long) meta; + + err = 0; +errout: + if (err && meta) + meta_delete(meta); + return err; +} + +static void em_meta_destroy(struct tcf_ematch *m) +{ + if (m) + meta_delete((struct meta_match *) m->data); +} + +static int em_meta_dump(struct sk_buff *skb, struct tcf_ematch *em) +{ + struct meta_match *meta = (struct meta_match *) em->data; + struct tcf_meta_hdr hdr; + const struct meta_type_ops *ops; + + memset(&hdr, 0, sizeof(hdr)); + memcpy(&hdr.left, &meta->lvalue.hdr, sizeof(hdr.left)); + memcpy(&hdr.right, &meta->rvalue.hdr, sizeof(hdr.right)); + + if (nla_put(skb, TCA_EM_META_HDR, sizeof(hdr), &hdr)) + goto nla_put_failure; + + ops = meta_type_ops(&meta->lvalue); + if (ops->dump(skb, &meta->lvalue, TCA_EM_META_LVALUE) < 0 || + ops->dump(skb, &meta->rvalue, TCA_EM_META_RVALUE) < 0) + goto nla_put_failure; + + return 0; + +nla_put_failure: + return -1; +} + +static struct tcf_ematch_ops em_meta_ops = { + .kind = TCF_EM_META, + .change = em_meta_change, + .match = em_meta_match, + .destroy = em_meta_destroy, + .dump = em_meta_dump, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_meta_ops.link) +}; + +static int __init init_em_meta(void) +{ + return 
tcf_em_register(&em_meta_ops); +} + +static void __exit exit_em_meta(void) +{ + tcf_em_unregister(&em_meta_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_meta); +module_exit(exit_em_meta); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_META); diff --git a/net/sched/em_nbyte.c b/net/sched/em_nbyte.c new file mode 100644 index 000000000..a83b237cb --- /dev/null +++ b/net/sched/em_nbyte.c @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_nbyte.c N-Byte ematch + * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +struct nbyte_data { + struct tcf_em_nbyte hdr; + char pattern[]; +}; + +static int em_nbyte_change(struct net *net, void *data, int data_len, + struct tcf_ematch *em) +{ + struct tcf_em_nbyte *nbyte = data; + + if (data_len < sizeof(*nbyte) || + data_len < (sizeof(*nbyte) + nbyte->len)) + return -EINVAL; + + em->datalen = sizeof(*nbyte) + nbyte->len; + em->data = (unsigned long)kmemdup(data, em->datalen, GFP_KERNEL); + if (em->data == 0UL) + return -ENOMEM; + + return 0; +} + +static int em_nbyte_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + struct nbyte_data *nbyte = (struct nbyte_data *) em->data; + unsigned char *ptr = tcf_get_base_ptr(skb, nbyte->hdr.layer); + + ptr += nbyte->hdr.off; + + if (!tcf_valid_offset(skb, ptr, nbyte->hdr.len)) + return 0; + + return !memcmp(ptr, nbyte->pattern, nbyte->hdr.len); +} + +static struct tcf_ematch_ops em_nbyte_ops = { + .kind = TCF_EM_NBYTE, + .change = em_nbyte_change, + .match = em_nbyte_match, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_nbyte_ops.link) +}; + +static int __init init_em_nbyte(void) +{ + return tcf_em_register(&em_nbyte_ops); +} + +static void __exit exit_em_nbyte(void) +{ + tcf_em_unregister(&em_nbyte_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_nbyte); +module_exit(exit_em_nbyte); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_NBYTE); diff --git a/net/sched/em_text.c b/net/sched/em_text.c new file mode 100644 index 000000000..f176afb70 --- /dev/null +++ b/net/sched/em_text.c @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_text.c Textsearch ematch + * + * Authors: Thomas Graf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct text_match { + u16 from_offset; + u16 to_offset; + u8 from_layer; + u8 to_layer; + struct ts_config *config; +}; + +#define EM_TEXT_PRIV(m) ((struct text_match *) (m)->data) + +static int em_text_match(struct sk_buff *skb, struct tcf_ematch *m, + struct tcf_pkt_info *info) +{ + struct text_match *tm = EM_TEXT_PRIV(m); + int from, to; + + from = tcf_get_base_ptr(skb, tm->from_layer) - skb->data; + from += tm->from_offset; + + to = tcf_get_base_ptr(skb, tm->to_layer) - skb->data; + to += tm->to_offset; + + return skb_find_text(skb, from, to, tm->config) != UINT_MAX; +} + +static int em_text_change(struct net *net, void *data, int len, + struct tcf_ematch *m) +{ + struct text_match *tm; + struct tcf_em_text *conf = data; + struct ts_config *ts_conf; + int flags = 0; + + if (len < sizeof(*conf) || len < (sizeof(*conf) + conf->pattern_len)) + return -EINVAL; + + if (conf->from_layer > conf->to_layer) + return -EINVAL; + + if (conf->from_layer == conf->to_layer && + conf->from_offset > conf->to_offset) + return -EINVAL; + +retry: + ts_conf = textsearch_prepare(conf->algo, (u8 *) conf + sizeof(*conf), + conf->pattern_len, GFP_KERNEL, flags); + + if (flags & TS_AUTOLOAD) + 
rtnl_lock(); + + if (IS_ERR(ts_conf)) { + if (PTR_ERR(ts_conf) == -ENOENT && !(flags & TS_AUTOLOAD)) { + rtnl_unlock(); + flags |= TS_AUTOLOAD; + goto retry; + } else + return PTR_ERR(ts_conf); + } else if (flags & TS_AUTOLOAD) { + textsearch_destroy(ts_conf); + return -EAGAIN; + } + + tm = kmalloc(sizeof(*tm), GFP_KERNEL); + if (tm == NULL) { + textsearch_destroy(ts_conf); + return -ENOBUFS; + } + + tm->from_offset = conf->from_offset; + tm->to_offset = conf->to_offset; + tm->from_layer = conf->from_layer; + tm->to_layer = conf->to_layer; + tm->config = ts_conf; + + m->datalen = sizeof(*tm); + m->data = (unsigned long) tm; + + return 0; +} + +static void em_text_destroy(struct tcf_ematch *m) +{ + if (EM_TEXT_PRIV(m) && EM_TEXT_PRIV(m)->config) { + textsearch_destroy(EM_TEXT_PRIV(m)->config); + kfree(EM_TEXT_PRIV(m)); + } +} + +static int em_text_dump(struct sk_buff *skb, struct tcf_ematch *m) +{ + struct text_match *tm = EM_TEXT_PRIV(m); + struct tcf_em_text conf; + + strncpy(conf.algo, tm->config->ops->name, sizeof(conf.algo) - 1); + conf.from_offset = tm->from_offset; + conf.to_offset = tm->to_offset; + conf.from_layer = tm->from_layer; + conf.to_layer = tm->to_layer; + conf.pattern_len = textsearch_get_pattern_len(tm->config); + conf.pad = 0; + + if (nla_put_nohdr(skb, sizeof(conf), &conf) < 0) + goto nla_put_failure; + if (nla_append(skb, conf.pattern_len, + textsearch_get_pattern(tm->config)) < 0) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -1; +} + +static struct tcf_ematch_ops em_text_ops = { + .kind = TCF_EM_TEXT, + .change = em_text_change, + .match = em_text_match, + .destroy = em_text_destroy, + .dump = em_text_dump, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_text_ops.link) +}; + +static int __init init_em_text(void) +{ + return tcf_em_register(&em_text_ops); +} + +static void __exit exit_em_text(void) +{ + tcf_em_unregister(&em_text_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_text); +module_exit(exit_em_text); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_TEXT); diff --git a/net/sched/em_u32.c b/net/sched/em_u32.c new file mode 100644 index 000000000..71b070da0 --- /dev/null +++ b/net/sched/em_u32.c @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/em_u32.c U32 Ematch + * + * Authors: Thomas Graf + * Alexey Kuznetsov, + * + * Based on net/sched/cls_u32.c + */ + +#include +#include +#include +#include +#include + +static int em_u32_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + struct tc_u32_key *key = (struct tc_u32_key *) em->data; + const unsigned char *ptr = skb_network_header(skb); + + if (info) { + if (info->ptr) + ptr = info->ptr; + ptr += (info->nexthdr & key->offmask); + } + + ptr += key->off; + + if (!tcf_valid_offset(skb, ptr, sizeof(u32))) + return 0; + + return !(((*(__be32 *) ptr) ^ key->val) & key->mask); +} + +static struct tcf_ematch_ops em_u32_ops = { + .kind = TCF_EM_U32, + .datalen = sizeof(struct tc_u32_key), + .match = em_u32_match, + .owner = THIS_MODULE, + .link = LIST_HEAD_INIT(em_u32_ops.link) +}; + +static int __init init_em_u32(void) +{ + return tcf_em_register(&em_u32_ops); +} + +static void __exit exit_em_u32(void) +{ + tcf_em_unregister(&em_u32_ops); +} + +MODULE_LICENSE("GPL"); + +module_init(init_em_u32); +module_exit(exit_em_u32); + +MODULE_ALIAS_TCF_EMATCH(TCF_EM_U32); diff --git a/net/sched/ematch.c b/net/sched/ematch.c new file mode 100644 index 000000000..5c1235e60 --- /dev/null +++ b/net/sched/ematch.c @@ -0,0 +1,550 @@ +// 
SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/ematch.c Extended Match API + * + * Authors: Thomas Graf + * + * ========================================================================== + * + * An extended match (ematch) is a small classification tool not worth + * writing a full classifier for. Ematches can be interconnected to form + * a logic expression and get attached to classifiers to extend their + * functionatlity. + * + * The userspace part transforms the logic expressions into an array + * consisting of multiple sequences of interconnected ematches separated + * by markers. Precedence is implemented by a special ematch kind + * referencing a sequence beyond the marker of the current sequence + * causing the current position in the sequence to be pushed onto a stack + * to allow the current position to be overwritten by the position referenced + * in the special ematch. Matching continues in the new sequence until a + * marker is reached causing the position to be restored from the stack. + * + * Example: + * A AND (B1 OR B2) AND C AND D + * + * ------->-PUSH------- + * -->-- / -->-- \ -->-- + * / \ / / \ \ / \ + * +-------+-------+-------+-------+-------+--------+ + * | A AND | B AND | C AND | D END | B1 OR | B2 END | + * +-------+-------+-------+-------+-------+--------+ + * \ / + * --------<-POP--------- + * + * where B is a virtual ematch referencing to sequence starting with B1. + * + * ========================================================================== + * + * How to write an ematch in 60 seconds + * ------------------------------------ + * + * 1) Provide a matcher function: + * static int my_match(struct sk_buff *skb, struct tcf_ematch *m, + * struct tcf_pkt_info *info) + * { + * struct mydata *d = (struct mydata *) m->data; + * + * if (...matching goes here...) + * return 1; + * else + * return 0; + * } + * + * 2) Fill out a struct tcf_ematch_ops: + * static struct tcf_ematch_ops my_ops = { + * .kind = unique id, + * .datalen = sizeof(struct mydata), + * .match = my_match, + * .owner = THIS_MODULE, + * }; + * + * 3) Register/Unregister your ematch: + * static int __init init_my_ematch(void) + * { + * return tcf_em_register(&my_ops); + * } + * + * static void __exit exit_my_ematch(void) + * { + * tcf_em_unregister(&my_ops); + * } + * + * module_init(init_my_ematch); + * module_exit(exit_my_ematch); + * + * 4) By now you should have two more seconds left, barely enough to + * open up a beer to watch the compilation going. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(ematch_ops); +static DEFINE_RWLOCK(ematch_mod_lock); + +static struct tcf_ematch_ops *tcf_em_lookup(u16 kind) +{ + struct tcf_ematch_ops *e = NULL; + + read_lock(&ematch_mod_lock); + list_for_each_entry(e, &ematch_ops, link) { + if (kind == e->kind) { + if (!try_module_get(e->owner)) + e = NULL; + read_unlock(&ematch_mod_lock); + return e; + } + } + read_unlock(&ematch_mod_lock); + + return NULL; +} + +/** + * tcf_em_register - register an extended match + * + * @ops: ematch operations lookup table + * + * This function must be called by ematches to announce their presence. + * The given @ops must have kind set to a unique identifier and the + * callback match() must be implemented. All other callbacks are optional + * and a fallback implementation is used instead. + * + * Returns -EEXISTS if an ematch of the same kind has already registered. 
+ */ +int tcf_em_register(struct tcf_ematch_ops *ops) +{ + int err = -EEXIST; + struct tcf_ematch_ops *e; + + if (ops->match == NULL) + return -EINVAL; + + write_lock(&ematch_mod_lock); + list_for_each_entry(e, &ematch_ops, link) + if (ops->kind == e->kind) + goto errout; + + list_add_tail(&ops->link, &ematch_ops); + err = 0; +errout: + write_unlock(&ematch_mod_lock); + return err; +} +EXPORT_SYMBOL(tcf_em_register); + +/** + * tcf_em_unregister - unregister and extended match + * + * @ops: ematch operations lookup table + * + * This function must be called by ematches to announce their disappearance + * for examples when the module gets unloaded. The @ops parameter must be + * the same as the one used for registration. + * + * Returns -ENOENT if no matching ematch was found. + */ +void tcf_em_unregister(struct tcf_ematch_ops *ops) +{ + write_lock(&ematch_mod_lock); + list_del(&ops->link); + write_unlock(&ematch_mod_lock); +} +EXPORT_SYMBOL(tcf_em_unregister); + +static inline struct tcf_ematch *tcf_em_get_match(struct tcf_ematch_tree *tree, + int index) +{ + return &tree->matches[index]; +} + + +static int tcf_em_validate(struct tcf_proto *tp, + struct tcf_ematch_tree_hdr *tree_hdr, + struct tcf_ematch *em, struct nlattr *nla, int idx) +{ + int err = -EINVAL; + struct tcf_ematch_hdr *em_hdr = nla_data(nla); + int data_len = nla_len(nla) - sizeof(*em_hdr); + void *data = (void *) em_hdr + sizeof(*em_hdr); + struct net *net = tp->chain->block->net; + + if (!TCF_EM_REL_VALID(em_hdr->flags)) + goto errout; + + if (em_hdr->kind == TCF_EM_CONTAINER) { + /* Special ematch called "container", carries an index + * referencing an external ematch sequence. + */ + u32 ref; + + if (data_len < sizeof(ref)) + goto errout; + ref = *(u32 *) data; + + if (ref >= tree_hdr->nmatches) + goto errout; + + /* We do not allow backward jumps to avoid loops and jumps + * to our own position are of course illegal. + */ + if (ref <= idx) + goto errout; + + + em->data = ref; + } else { + /* Note: This lookup will increase the module refcnt + * of the ematch module referenced. In case of a failure, + * a destroy function is called by the underlying layer + * which automatically releases the reference again, therefore + * the module MUST not be given back under any circumstances + * here. Be aware, the destroy function assumes that the + * module is held if the ops field is non zero. + */ + em->ops = tcf_em_lookup(em_hdr->kind); + + if (em->ops == NULL) { + err = -ENOENT; +#ifdef CONFIG_MODULES + __rtnl_unlock(); + request_module("ematch-kind-%u", em_hdr->kind); + rtnl_lock(); + em->ops = tcf_em_lookup(em_hdr->kind); + if (em->ops) { + /* We dropped the RTNL mutex in order to + * perform the module load. Tell the caller + * to replay the request. + */ + module_put(em->ops->owner); + em->ops = NULL; + err = -EAGAIN; + } +#endif + goto errout; + } + + /* ematch module provides expected length of data, so we + * can do a basic sanity check. + */ + if (em->ops->datalen && data_len < em->ops->datalen) + goto errout; + + if (em->ops->change) { + err = -EINVAL; + if (em_hdr->flags & TCF_EM_SIMPLE) + goto errout; + err = em->ops->change(net, data, data_len, em); + if (err < 0) + goto errout; + } else if (data_len > 0) { + /* ematch module doesn't provide an own change + * procedure and expects us to allocate and copy + * the ematch data. + * + * TCF_EM_SIMPLE may be specified stating that the + * data only consists of a u32 integer and the module + * does not expected a memory reference but rather + * the value carried. 
+ */ + if (em_hdr->flags & TCF_EM_SIMPLE) { + if (em->ops->datalen > 0) + goto errout; + if (data_len < sizeof(u32)) + goto errout; + em->data = *(u32 *) data; + } else { + void *v = kmemdup(data, data_len, GFP_KERNEL); + if (v == NULL) { + err = -ENOBUFS; + goto errout; + } + em->data = (unsigned long) v; + } + em->datalen = data_len; + } + } + + em->matchid = em_hdr->matchid; + em->flags = em_hdr->flags; + em->net = net; + + err = 0; +errout: + return err; +} + +static const struct nla_policy em_policy[TCA_EMATCH_TREE_MAX + 1] = { + [TCA_EMATCH_TREE_HDR] = { .len = sizeof(struct tcf_ematch_tree_hdr) }, + [TCA_EMATCH_TREE_LIST] = { .type = NLA_NESTED }, +}; + +/** + * tcf_em_tree_validate - validate ematch config TLV and build ematch tree + * + * @tp: classifier kind handle + * @nla: ematch tree configuration TLV + * @tree: destination ematch tree variable to store the resulting + * ematch tree. + * + * This function validates the given configuration TLV @nla and builds an + * ematch tree in @tree. The resulting tree must later be copied into + * the private classifier data using tcf_em_tree_change(). You MUST NOT + * provide the ematch tree variable of the private classifier data directly, + * the changes would not be locked properly. + * + * Returns a negative error code if the configuration TLV contains errors. + */ +int tcf_em_tree_validate(struct tcf_proto *tp, struct nlattr *nla, + struct tcf_ematch_tree *tree) +{ + int idx, list_len, matches_len, err; + struct nlattr *tb[TCA_EMATCH_TREE_MAX + 1]; + struct nlattr *rt_match, *rt_hdr, *rt_list; + struct tcf_ematch_tree_hdr *tree_hdr; + struct tcf_ematch *em; + + memset(tree, 0, sizeof(*tree)); + if (!nla) + return 0; + + err = nla_parse_nested_deprecated(tb, TCA_EMATCH_TREE_MAX, nla, + em_policy, NULL); + if (err < 0) + goto errout; + + err = -EINVAL; + rt_hdr = tb[TCA_EMATCH_TREE_HDR]; + rt_list = tb[TCA_EMATCH_TREE_LIST]; + + if (rt_hdr == NULL || rt_list == NULL) + goto errout; + + tree_hdr = nla_data(rt_hdr); + memcpy(&tree->hdr, tree_hdr, sizeof(*tree_hdr)); + + rt_match = nla_data(rt_list); + list_len = nla_len(rt_list); + matches_len = tree_hdr->nmatches * sizeof(*em); + + tree->matches = kzalloc(matches_len, GFP_KERNEL); + if (tree->matches == NULL) + goto errout; + + /* We do not use nla_parse_nested here because the maximum + * number of attributes is unknown. This saves us the allocation + * for a tb buffer which would serve no purpose at all. + * + * The array of rt attributes is parsed in the order as they are + * provided, their type must be incremental from 1 to n. Even + * if it does not serve any real purpose, a failure of sticking + * to this policy will result in parsing failure. + */ + for (idx = 0; nla_ok(rt_match, list_len); idx++) { + err = -EINVAL; + + if (rt_match->nla_type != (idx + 1)) + goto errout_abort; + + if (idx >= tree_hdr->nmatches) + goto errout_abort; + + if (nla_len(rt_match) < sizeof(struct tcf_ematch_hdr)) + goto errout_abort; + + em = tcf_em_get_match(tree, idx); + + err = tcf_em_validate(tp, tree_hdr, em, rt_match, idx); + if (err < 0) + goto errout_abort; + + rt_match = nla_next(rt_match, &list_len); + } + + /* Check if the number of matches provided by userspace actually + * complies with the array of matches. The number was used for + * the validation of references and a mismatch could lead to + * undefined references during the matching process. 
+ */ + if (idx != tree_hdr->nmatches) { + err = -EINVAL; + goto errout_abort; + } + + err = 0; +errout: + return err; + +errout_abort: + tcf_em_tree_destroy(tree); + return err; +} +EXPORT_SYMBOL(tcf_em_tree_validate); + +/** + * tcf_em_tree_destroy - destroy an ematch tree + * + * @tree: ematch tree to be deleted + * + * This functions destroys an ematch tree previously created by + * tcf_em_tree_validate()/tcf_em_tree_change(). You must ensure that + * the ematch tree is not in use before calling this function. + */ +void tcf_em_tree_destroy(struct tcf_ematch_tree *tree) +{ + int i; + + if (tree->matches == NULL) + return; + + for (i = 0; i < tree->hdr.nmatches; i++) { + struct tcf_ematch *em = tcf_em_get_match(tree, i); + + if (em->ops) { + if (em->ops->destroy) + em->ops->destroy(em); + else if (!tcf_em_is_simple(em)) + kfree((void *) em->data); + module_put(em->ops->owner); + } + } + + tree->hdr.nmatches = 0; + kfree(tree->matches); + tree->matches = NULL; +} +EXPORT_SYMBOL(tcf_em_tree_destroy); + +/** + * tcf_em_tree_dump - dump ematch tree into a rtnl message + * + * @skb: skb holding the rtnl message + * @tree: ematch tree to be dumped + * @tlv: TLV type to be used to encapsulate the tree + * + * This function dumps a ematch tree into a rtnl message. It is valid to + * call this function while the ematch tree is in use. + * + * Returns -1 if the skb tailroom is insufficient. + */ +int tcf_em_tree_dump(struct sk_buff *skb, struct tcf_ematch_tree *tree, int tlv) +{ + int i; + u8 *tail; + struct nlattr *top_start; + struct nlattr *list_start; + + top_start = nla_nest_start_noflag(skb, tlv); + if (top_start == NULL) + goto nla_put_failure; + + if (nla_put(skb, TCA_EMATCH_TREE_HDR, sizeof(tree->hdr), &tree->hdr)) + goto nla_put_failure; + + list_start = nla_nest_start_noflag(skb, TCA_EMATCH_TREE_LIST); + if (list_start == NULL) + goto nla_put_failure; + + tail = skb_tail_pointer(skb); + for (i = 0; i < tree->hdr.nmatches; i++) { + struct nlattr *match_start = (struct nlattr *)tail; + struct tcf_ematch *em = tcf_em_get_match(tree, i); + struct tcf_ematch_hdr em_hdr = { + .kind = em->ops ? em->ops->kind : TCF_EM_CONTAINER, + .matchid = em->matchid, + .flags = em->flags + }; + + if (nla_put(skb, i + 1, sizeof(em_hdr), &em_hdr)) + goto nla_put_failure; + + if (em->ops && em->ops->dump) { + if (em->ops->dump(skb, em) < 0) + goto nla_put_failure; + } else if (tcf_em_is_container(em) || tcf_em_is_simple(em)) { + u32 u = em->data; + nla_put_nohdr(skb, sizeof(u), &u); + } else if (em->datalen > 0) + nla_put_nohdr(skb, em->datalen, (void *) em->data); + + tail = skb_tail_pointer(skb); + match_start->nla_len = tail - (u8 *)match_start; + } + + nla_nest_end(skb, list_start); + nla_nest_end(skb, top_start); + + return 0; + +nla_put_failure: + return -1; +} +EXPORT_SYMBOL(tcf_em_tree_dump); + +static inline int tcf_em_match(struct sk_buff *skb, struct tcf_ematch *em, + struct tcf_pkt_info *info) +{ + int r = em->ops->match(skb, em, info); + + return tcf_em_is_inverted(em) ? 
!r : r; +} + +/* Do not use this function directly, use tcf_em_tree_match instead */ +int __tcf_em_tree_match(struct sk_buff *skb, struct tcf_ematch_tree *tree, + struct tcf_pkt_info *info) +{ + int stackp = 0, match_idx = 0, res = 0; + struct tcf_ematch *cur_match; + int stack[CONFIG_NET_EMATCH_STACK]; + +proceed: + while (match_idx < tree->hdr.nmatches) { + cur_match = tcf_em_get_match(tree, match_idx); + + if (tcf_em_is_container(cur_match)) { + if (unlikely(stackp >= CONFIG_NET_EMATCH_STACK)) + goto stack_overflow; + + stack[stackp++] = match_idx; + match_idx = cur_match->data; + goto proceed; + } + + res = tcf_em_match(skb, cur_match, info); + + if (tcf_em_early_end(cur_match, res)) + break; + + match_idx++; + } + +pop_stack: + if (stackp > 0) { + match_idx = stack[--stackp]; + cur_match = tcf_em_get_match(tree, match_idx); + + if (tcf_em_is_inverted(cur_match)) + res = !res; + + if (tcf_em_early_end(cur_match, res)) { + goto pop_stack; + } else { + match_idx++; + goto proceed; + } + } + + return res; + +stack_overflow: + net_warn_ratelimited("tc ematch: local stack overflow, increase NET_EMATCH_STACK\n"); + return -1; +} +EXPORT_SYMBOL(__tcf_em_tree_match); diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c new file mode 100644 index 000000000..e8f988e1c --- /dev/null +++ b/net/sched/sch_api.c @@ -0,0 +1,2389 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_api.c Packet scheduler API. + * + * Authors: Alexey Kuznetsov, + * + * Fixes: + * + * Rani Assaf :980802: JIFFIES and CPU clock sources are repaired. + * Eduardo J. Blanco :990222: kmod support + * Jamal Hadi Salim : 990601: ingress support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +/* + + Short review. + ------------- + + This file consists of two interrelated parts: + + 1. queueing disciplines manager frontend. + 2. traffic classes manager frontend. + + Generally, queueing discipline ("qdisc") is a black box, + which is able to enqueue packets and to dequeue them (when + device is ready to send something) in order and at times + determined by algorithm hidden in it. + + qdisc's are divided to two categories: + - "queues", which have no internal structure visible from outside. + - "schedulers", which split all the packets to "traffic classes", + using "packet classifiers" (look at cls_api.c) + + In turn, classes may have child qdiscs (as rule, queues) + attached to them etc. etc. etc. + + The goal of the routines in this file is to translate + information supplied by user in the form of handles + to more intelligible for kernel form, to make some sanity + checks and part of work, which is common to all qdiscs + and to provide rtnetlink notifications. + + All real intelligent work is done inside qdisc modules. + + + + Every discipline has two major routines: enqueue and dequeue. + + ---dequeue + + dequeue usually returns a skb to send. It is allowed to return NULL, + but it does not mean that queue is empty, it just means that + discipline does not want to send anything this time. + Queue is really empty if q->q.qlen == 0. + For complicated disciplines with multiple queues q->q is not + real packet queue, but however q->q.qlen must be valid. + + ---enqueue + + enqueue returns 0, if packet was enqueued successfully. + If packet (this one or another one) was dropped, it returns + not zero error code. 
+ NET_XMIT_DROP - this packet dropped + Expected action: do not backoff, but wait until queue will clear. + NET_XMIT_CN - probably this packet enqueued, but another one dropped. + Expected action: backoff or ignore + + Auxiliary routines: + + ---peek + + like dequeue but without removing a packet from the queue + + ---reset + + returns qdisc to initial state: purge all buffers, clear all + timers, counters (except for statistics) etc. + + ---init + + initializes newly created qdisc. + + ---destroy + + destroys resources allocated by init and during lifetime of qdisc. + + ---change + + changes qdisc parameters. + */ + +/* Protects list of registered TC modules. It is pure SMP lock. */ +static DEFINE_RWLOCK(qdisc_mod_lock); + + +/************************************************ + * Queueing disciplines manipulation. * + ************************************************/ + + +/* The list of all installed queueing disciplines. */ + +static struct Qdisc_ops *qdisc_base; + +/* Register/unregister queueing discipline */ + +int register_qdisc(struct Qdisc_ops *qops) +{ + struct Qdisc_ops *q, **qp; + int rc = -EEXIST; + + write_lock(&qdisc_mod_lock); + for (qp = &qdisc_base; (q = *qp) != NULL; qp = &q->next) + if (!strcmp(qops->id, q->id)) + goto out; + + if (qops->enqueue == NULL) + qops->enqueue = noop_qdisc_ops.enqueue; + if (qops->peek == NULL) { + if (qops->dequeue == NULL) + qops->peek = noop_qdisc_ops.peek; + else + goto out_einval; + } + if (qops->dequeue == NULL) + qops->dequeue = noop_qdisc_ops.dequeue; + + if (qops->cl_ops) { + const struct Qdisc_class_ops *cops = qops->cl_ops; + + if (!(cops->find && cops->walk && cops->leaf)) + goto out_einval; + + if (cops->tcf_block && !(cops->bind_tcf && cops->unbind_tcf)) + goto out_einval; + } + + qops->next = NULL; + *qp = qops; + rc = 0; +out: + write_unlock(&qdisc_mod_lock); + return rc; + +out_einval: + rc = -EINVAL; + goto out; +} +EXPORT_SYMBOL(register_qdisc); + +void unregister_qdisc(struct Qdisc_ops *qops) +{ + struct Qdisc_ops *q, **qp; + int err = -ENOENT; + + write_lock(&qdisc_mod_lock); + for (qp = &qdisc_base; (q = *qp) != NULL; qp = &q->next) + if (q == qops) + break; + if (q) { + *qp = q->next; + q->next = NULL; + err = 0; + } + write_unlock(&qdisc_mod_lock); + + WARN(err, "unregister qdisc(%s) failed\n", qops->id); +} +EXPORT_SYMBOL(unregister_qdisc); + +/* Get default qdisc if not otherwise specified */ +void qdisc_get_default(char *name, size_t len) +{ + read_lock(&qdisc_mod_lock); + strscpy(name, default_qdisc_ops->id, len); + read_unlock(&qdisc_mod_lock); +} + +static struct Qdisc_ops *qdisc_lookup_default(const char *name) +{ + struct Qdisc_ops *q = NULL; + + for (q = qdisc_base; q; q = q->next) { + if (!strcmp(name, q->id)) { + if (!try_module_get(q->owner)) + q = NULL; + break; + } + } + + return q; +} + +/* Set new default qdisc to use */ +int qdisc_set_default(const char *name) +{ + const struct Qdisc_ops *ops; + + if (!capable(CAP_NET_ADMIN)) + return -EPERM; + + write_lock(&qdisc_mod_lock); + ops = qdisc_lookup_default(name); + if (!ops) { + /* Not found, drop lock and try to load module */ + write_unlock(&qdisc_mod_lock); + request_module("sch_%s", name); + write_lock(&qdisc_mod_lock); + + ops = qdisc_lookup_default(name); + } + + if (ops) { + /* Set new default */ + module_put(default_qdisc_ops->owner); + default_qdisc_ops = ops; + } + write_unlock(&qdisc_mod_lock); + + return ops ? 
0 : -ENOENT; +} + +#ifdef CONFIG_NET_SCH_DEFAULT +/* Set default value from kernel config */ +static int __init sch_default_qdisc(void) +{ + return qdisc_set_default(CONFIG_DEFAULT_NET_SCH); +} +late_initcall(sch_default_qdisc); +#endif + +/* We know handle. Find qdisc among all qdisc's attached to device + * (root qdisc, all its children, children of children etc.) + * Note: caller either uses rtnl or rcu_read_lock() + */ + +static struct Qdisc *qdisc_match_from_root(struct Qdisc *root, u32 handle) +{ + struct Qdisc *q; + + if (!qdisc_dev(root)) + return (root->handle == handle ? root : NULL); + + if (!(root->flags & TCQ_F_BUILTIN) && + root->handle == handle) + return root; + + hash_for_each_possible_rcu(qdisc_dev(root)->qdisc_hash, q, hash, handle, + lockdep_rtnl_is_held()) { + if (q->handle == handle) + return q; + } + return NULL; +} + +void qdisc_hash_add(struct Qdisc *q, bool invisible) +{ + if ((q->parent != TC_H_ROOT) && !(q->flags & TCQ_F_INGRESS)) { + ASSERT_RTNL(); + hash_add_rcu(qdisc_dev(q)->qdisc_hash, &q->hash, q->handle); + if (invisible) + q->flags |= TCQ_F_INVISIBLE; + } +} +EXPORT_SYMBOL(qdisc_hash_add); + +void qdisc_hash_del(struct Qdisc *q) +{ + if ((q->parent != TC_H_ROOT) && !(q->flags & TCQ_F_INGRESS)) { + ASSERT_RTNL(); + hash_del_rcu(&q->hash); + } +} +EXPORT_SYMBOL(qdisc_hash_del); + +struct Qdisc *qdisc_lookup(struct net_device *dev, u32 handle) +{ + struct Qdisc *q; + + if (!handle) + return NULL; + q = qdisc_match_from_root(rtnl_dereference(dev->qdisc), handle); + if (q) + goto out; + + if (dev_ingress_queue(dev)) + q = qdisc_match_from_root( + rtnl_dereference(dev_ingress_queue(dev)->qdisc_sleeping), + handle); +out: + return q; +} + +struct Qdisc *qdisc_lookup_rcu(struct net_device *dev, u32 handle) +{ + struct netdev_queue *nq; + struct Qdisc *q; + + if (!handle) + return NULL; + q = qdisc_match_from_root(rcu_dereference(dev->qdisc), handle); + if (q) + goto out; + + nq = dev_ingress_queue_rcu(dev); + if (nq) + q = qdisc_match_from_root(rcu_dereference(nq->qdisc_sleeping), + handle); +out: + return q; +} + +static struct Qdisc *qdisc_leaf(struct Qdisc *p, u32 classid) +{ + unsigned long cl; + const struct Qdisc_class_ops *cops = p->ops->cl_ops; + + if (cops == NULL) + return NULL; + cl = cops->find(p, classid); + + if (cl == 0) + return NULL; + return cops->leaf(p, cl); +} + +/* Find queueing discipline by name */ + +static struct Qdisc_ops *qdisc_lookup_ops(struct nlattr *kind) +{ + struct Qdisc_ops *q = NULL; + + if (kind) { + read_lock(&qdisc_mod_lock); + for (q = qdisc_base; q; q = q->next) { + if (nla_strcmp(kind, q->id) == 0) { + if (!try_module_get(q->owner)) + q = NULL; + break; + } + } + read_unlock(&qdisc_mod_lock); + } + return q; +} + +/* The linklayer setting were not transferred from iproute2, in older + * versions, and the rate tables lookup systems have been dropped in + * the kernel. To keep backward compatible with older iproute2 tc + * utils, we detect the linklayer setting by detecting if the rate + * table were modified. + * + * For linklayer ATM table entries, the rate table will be aligned to + * 48 bytes, thus some table entries will contain the same value. The + * mpu (min packet unit) is also encoded into the old rate table, thus + * starting from the mpu, we find low and high table entries for + * mapping this cell. If these entries contain the same value, when + * the rate tables have been modified for linklayer ATM. 
+ * + * This is done by rounding mpu to the nearest 48 bytes cell/entry, + * and then roundup to the next cell, calc the table entry one below, + * and compare. + */ +static __u8 __detect_linklayer(struct tc_ratespec *r, __u32 *rtab) +{ + int low = roundup(r->mpu, 48); + int high = roundup(low+1, 48); + int cell_low = low >> r->cell_log; + int cell_high = (high >> r->cell_log) - 1; + + /* rtab is too inaccurate at rates > 100Mbit/s */ + if ((r->rate > (100000000/8)) || (rtab[0] == 0)) { + pr_debug("TC linklayer: Giving up ATM detection\n"); + return TC_LINKLAYER_ETHERNET; + } + + if ((cell_high > cell_low) && (cell_high < 256) + && (rtab[cell_low] == rtab[cell_high])) { + pr_debug("TC linklayer: Detected ATM, low(%d)=high(%d)=%u\n", + cell_low, cell_high, rtab[cell_high]); + return TC_LINKLAYER_ATM; + } + return TC_LINKLAYER_ETHERNET; +} + +static struct qdisc_rate_table *qdisc_rtab_list; + +struct qdisc_rate_table *qdisc_get_rtab(struct tc_ratespec *r, + struct nlattr *tab, + struct netlink_ext_ack *extack) +{ + struct qdisc_rate_table *rtab; + + if (tab == NULL || r->rate == 0 || + r->cell_log == 0 || r->cell_log >= 32 || + nla_len(tab) != TC_RTAB_SIZE) { + NL_SET_ERR_MSG(extack, "Invalid rate table parameters for searching"); + return NULL; + } + + for (rtab = qdisc_rtab_list; rtab; rtab = rtab->next) { + if (!memcmp(&rtab->rate, r, sizeof(struct tc_ratespec)) && + !memcmp(&rtab->data, nla_data(tab), 1024)) { + rtab->refcnt++; + return rtab; + } + } + + rtab = kmalloc(sizeof(*rtab), GFP_KERNEL); + if (rtab) { + rtab->rate = *r; + rtab->refcnt = 1; + memcpy(rtab->data, nla_data(tab), 1024); + if (r->linklayer == TC_LINKLAYER_UNAWARE) + r->linklayer = __detect_linklayer(r, rtab->data); + rtab->next = qdisc_rtab_list; + qdisc_rtab_list = rtab; + } else { + NL_SET_ERR_MSG(extack, "Failed to allocate new qdisc rate table"); + } + return rtab; +} +EXPORT_SYMBOL(qdisc_get_rtab); + +void qdisc_put_rtab(struct qdisc_rate_table *tab) +{ + struct qdisc_rate_table *rtab, **rtabp; + + if (!tab || --tab->refcnt) + return; + + for (rtabp = &qdisc_rtab_list; + (rtab = *rtabp) != NULL; + rtabp = &rtab->next) { + if (rtab == tab) { + *rtabp = rtab->next; + kfree(rtab); + return; + } + } +} +EXPORT_SYMBOL(qdisc_put_rtab); + +static LIST_HEAD(qdisc_stab_list); + +static const struct nla_policy stab_policy[TCA_STAB_MAX + 1] = { + [TCA_STAB_BASE] = { .len = sizeof(struct tc_sizespec) }, + [TCA_STAB_DATA] = { .type = NLA_BINARY }, +}; + +static struct qdisc_size_table *qdisc_get_stab(struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_STAB_MAX + 1]; + struct qdisc_size_table *stab; + struct tc_sizespec *s; + unsigned int tsize = 0; + u16 *tab = NULL; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_STAB_MAX, opt, stab_policy, + extack); + if (err < 0) + return ERR_PTR(err); + if (!tb[TCA_STAB_BASE]) { + NL_SET_ERR_MSG(extack, "Size table base attribute is missing"); + return ERR_PTR(-EINVAL); + } + + s = nla_data(tb[TCA_STAB_BASE]); + + if (s->tsize > 0) { + if (!tb[TCA_STAB_DATA]) { + NL_SET_ERR_MSG(extack, "Size table data attribute is missing"); + return ERR_PTR(-EINVAL); + } + tab = nla_data(tb[TCA_STAB_DATA]); + tsize = nla_len(tb[TCA_STAB_DATA]) / sizeof(u16); + } + + if (tsize != s->tsize || (!tab && tsize > 0)) { + NL_SET_ERR_MSG(extack, "Invalid size of size table"); + return ERR_PTR(-EINVAL); + } + + list_for_each_entry(stab, &qdisc_stab_list, list) { + if (memcmp(&stab->szopts, s, sizeof(*s))) + continue; + if (tsize > 0 && + memcmp(stab->data, tab, 
flex_array_size(stab, data, tsize))) + continue; + stab->refcnt++; + return stab; + } + + if (s->size_log > STAB_SIZE_LOG_MAX || + s->cell_log > STAB_SIZE_LOG_MAX) { + NL_SET_ERR_MSG(extack, "Invalid logarithmic size of size table"); + return ERR_PTR(-EINVAL); + } + + stab = kmalloc(struct_size(stab, data, tsize), GFP_KERNEL); + if (!stab) + return ERR_PTR(-ENOMEM); + + stab->refcnt = 1; + stab->szopts = *s; + if (tsize > 0) + memcpy(stab->data, tab, flex_array_size(stab, data, tsize)); + + list_add_tail(&stab->list, &qdisc_stab_list); + + return stab; +} + +void qdisc_put_stab(struct qdisc_size_table *tab) +{ + if (!tab) + return; + + if (--tab->refcnt == 0) { + list_del(&tab->list); + kfree_rcu(tab, rcu); + } +} +EXPORT_SYMBOL(qdisc_put_stab); + +static int qdisc_dump_stab(struct sk_buff *skb, struct qdisc_size_table *stab) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_STAB); + if (nest == NULL) + goto nla_put_failure; + if (nla_put(skb, TCA_STAB_BASE, sizeof(stab->szopts), &stab->szopts)) + goto nla_put_failure; + nla_nest_end(skb, nest); + + return skb->len; + +nla_put_failure: + return -1; +} + +void __qdisc_calculate_pkt_len(struct sk_buff *skb, + const struct qdisc_size_table *stab) +{ + int pkt_len, slot; + + pkt_len = skb->len + stab->szopts.overhead; + if (unlikely(!stab->szopts.tsize)) + goto out; + + slot = pkt_len + stab->szopts.cell_align; + if (unlikely(slot < 0)) + slot = 0; + + slot >>= stab->szopts.cell_log; + if (likely(slot < stab->szopts.tsize)) + pkt_len = stab->data[slot]; + else + pkt_len = stab->data[stab->szopts.tsize - 1] * + (slot / stab->szopts.tsize) + + stab->data[slot % stab->szopts.tsize]; + + pkt_len <<= stab->szopts.size_log; +out: + if (unlikely(pkt_len < 1)) + pkt_len = 1; + qdisc_skb_cb(skb)->pkt_len = pkt_len; +} +EXPORT_SYMBOL(__qdisc_calculate_pkt_len); + +void qdisc_warn_nonwc(const char *txt, struct Qdisc *qdisc) +{ + if (!(qdisc->flags & TCQ_F_WARN_NONWC)) { + pr_warn("%s: %s qdisc %X: is non-work-conserving?\n", + txt, qdisc->ops->id, qdisc->handle >> 16); + qdisc->flags |= TCQ_F_WARN_NONWC; + } +} +EXPORT_SYMBOL(qdisc_warn_nonwc); + +static enum hrtimer_restart qdisc_watchdog(struct hrtimer *timer) +{ + struct qdisc_watchdog *wd = container_of(timer, struct qdisc_watchdog, + timer); + + rcu_read_lock(); + __netif_schedule(qdisc_root(wd->qdisc)); + rcu_read_unlock(); + + return HRTIMER_NORESTART; +} + +void qdisc_watchdog_init_clockid(struct qdisc_watchdog *wd, struct Qdisc *qdisc, + clockid_t clockid) +{ + hrtimer_init(&wd->timer, clockid, HRTIMER_MODE_ABS_PINNED); + wd->timer.function = qdisc_watchdog; + wd->qdisc = qdisc; +} +EXPORT_SYMBOL(qdisc_watchdog_init_clockid); + +void qdisc_watchdog_init(struct qdisc_watchdog *wd, struct Qdisc *qdisc) +{ + qdisc_watchdog_init_clockid(wd, qdisc, CLOCK_MONOTONIC); +} +EXPORT_SYMBOL(qdisc_watchdog_init); + +void qdisc_watchdog_schedule_range_ns(struct qdisc_watchdog *wd, u64 expires, + u64 delta_ns) +{ + bool deactivated; + + rcu_read_lock(); + deactivated = test_bit(__QDISC_STATE_DEACTIVATED, + &qdisc_root_sleeping(wd->qdisc)->state); + rcu_read_unlock(); + if (deactivated) + return; + + if (hrtimer_is_queued(&wd->timer)) { + /* If timer is already set in [expires, expires + delta_ns], + * do not reprogram it. 
+ */ + if (wd->last_expires - expires <= delta_ns) + return; + } + + wd->last_expires = expires; + hrtimer_start_range_ns(&wd->timer, + ns_to_ktime(expires), + delta_ns, + HRTIMER_MODE_ABS_PINNED); +} +EXPORT_SYMBOL(qdisc_watchdog_schedule_range_ns); + +void qdisc_watchdog_cancel(struct qdisc_watchdog *wd) +{ + hrtimer_cancel(&wd->timer); +} +EXPORT_SYMBOL(qdisc_watchdog_cancel); + +static struct hlist_head *qdisc_class_hash_alloc(unsigned int n) +{ + struct hlist_head *h; + unsigned int i; + + h = kvmalloc_array(n, sizeof(struct hlist_head), GFP_KERNEL); + + if (h != NULL) { + for (i = 0; i < n; i++) + INIT_HLIST_HEAD(&h[i]); + } + return h; +} + +void qdisc_class_hash_grow(struct Qdisc *sch, struct Qdisc_class_hash *clhash) +{ + struct Qdisc_class_common *cl; + struct hlist_node *next; + struct hlist_head *nhash, *ohash; + unsigned int nsize, nmask, osize; + unsigned int i, h; + + /* Rehash when load factor exceeds 0.75 */ + if (clhash->hashelems * 4 <= clhash->hashsize * 3) + return; + nsize = clhash->hashsize * 2; + nmask = nsize - 1; + nhash = qdisc_class_hash_alloc(nsize); + if (nhash == NULL) + return; + + ohash = clhash->hash; + osize = clhash->hashsize; + + sch_tree_lock(sch); + for (i = 0; i < osize; i++) { + hlist_for_each_entry_safe(cl, next, &ohash[i], hnode) { + h = qdisc_class_hash(cl->classid, nmask); + hlist_add_head(&cl->hnode, &nhash[h]); + } + } + clhash->hash = nhash; + clhash->hashsize = nsize; + clhash->hashmask = nmask; + sch_tree_unlock(sch); + + kvfree(ohash); +} +EXPORT_SYMBOL(qdisc_class_hash_grow); + +int qdisc_class_hash_init(struct Qdisc_class_hash *clhash) +{ + unsigned int size = 4; + + clhash->hash = qdisc_class_hash_alloc(size); + if (!clhash->hash) + return -ENOMEM; + clhash->hashsize = size; + clhash->hashmask = size - 1; + clhash->hashelems = 0; + return 0; +} +EXPORT_SYMBOL(qdisc_class_hash_init); + +void qdisc_class_hash_destroy(struct Qdisc_class_hash *clhash) +{ + kvfree(clhash->hash); +} +EXPORT_SYMBOL(qdisc_class_hash_destroy); + +void qdisc_class_hash_insert(struct Qdisc_class_hash *clhash, + struct Qdisc_class_common *cl) +{ + unsigned int h; + + INIT_HLIST_NODE(&cl->hnode); + h = qdisc_class_hash(cl->classid, clhash->hashmask); + hlist_add_head(&cl->hnode, &clhash->hash[h]); + clhash->hashelems++; +} +EXPORT_SYMBOL(qdisc_class_hash_insert); + +void qdisc_class_hash_remove(struct Qdisc_class_hash *clhash, + struct Qdisc_class_common *cl) +{ + hlist_del(&cl->hnode); + clhash->hashelems--; +} +EXPORT_SYMBOL(qdisc_class_hash_remove); + +/* Allocate an unique handle from space managed by kernel + * Possible range is [8000-FFFF]:0000 (0x8000 values) + */ +static u32 qdisc_alloc_handle(struct net_device *dev) +{ + int i = 0x8000; + static u32 autohandle = TC_H_MAKE(0x80000000U, 0); + + do { + autohandle += TC_H_MAKE(0x10000U, 0); + if (autohandle == TC_H_MAKE(TC_H_ROOT, 0)) + autohandle = TC_H_MAKE(0x80000000U, 0); + if (!qdisc_lookup(dev, autohandle)) + return autohandle; + cond_resched(); + } while (--i > 0); + + return 0; +} + +void qdisc_tree_reduce_backlog(struct Qdisc *sch, int n, int len) +{ + bool qdisc_is_offloaded = sch->flags & TCQ_F_OFFLOADED; + const struct Qdisc_class_ops *cops; + unsigned long cl; + u32 parentid; + bool notify; + int drops; + + if (n == 0 && len == 0) + return; + drops = max_t(int, n, 0); + rcu_read_lock(); + while ((parentid = sch->parent)) { + if (TC_H_MAJ(parentid) == TC_H_MAJ(TC_H_INGRESS)) + break; + + if (sch->flags & TCQ_F_NOPARENT) + break; + /* Notify parent qdisc only if child qdisc becomes empty. 
+ * + * If child was empty even before update then backlog + * counter is screwed and we skip notification because + * parent class is already passive. + * + * If the original child was offloaded then it is allowed + * to be seem as empty, so the parent is notified anyway. + */ + notify = !sch->q.qlen && !WARN_ON_ONCE(!n && + !qdisc_is_offloaded); + /* TODO: perform the search on a per txq basis */ + sch = qdisc_lookup(qdisc_dev(sch), TC_H_MAJ(parentid)); + if (sch == NULL) { + WARN_ON_ONCE(parentid != TC_H_ROOT); + break; + } + cops = sch->ops->cl_ops; + if (notify && cops->qlen_notify) { + cl = cops->find(sch, parentid); + cops->qlen_notify(sch, cl); + } + sch->q.qlen -= n; + sch->qstats.backlog -= len; + __qdisc_qstats_drop(sch, drops); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(qdisc_tree_reduce_backlog); + +int qdisc_offload_dump_helper(struct Qdisc *sch, enum tc_setup_type type, + void *type_data) +{ + struct net_device *dev = qdisc_dev(sch); + int err; + + sch->flags &= ~TCQ_F_OFFLOADED; + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return 0; + + err = dev->netdev_ops->ndo_setup_tc(dev, type, type_data); + if (err == -EOPNOTSUPP) + return 0; + + if (!err) + sch->flags |= TCQ_F_OFFLOADED; + + return err; +} +EXPORT_SYMBOL(qdisc_offload_dump_helper); + +void qdisc_offload_graft_helper(struct net_device *dev, struct Qdisc *sch, + struct Qdisc *new, struct Qdisc *old, + enum tc_setup_type type, void *type_data, + struct netlink_ext_ack *extack) +{ + bool any_qdisc_is_offloaded; + int err; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + err = dev->netdev_ops->ndo_setup_tc(dev, type, type_data); + + /* Don't report error if the graft is part of destroy operation. */ + if (!err || !new || new == &noop_qdisc) + return; + + /* Don't report error if the parent, the old child and the new + * one are not offloaded. + */ + any_qdisc_is_offloaded = new->flags & TCQ_F_OFFLOADED; + any_qdisc_is_offloaded |= sch && sch->flags & TCQ_F_OFFLOADED; + any_qdisc_is_offloaded |= old && old->flags & TCQ_F_OFFLOADED; + + if (any_qdisc_is_offloaded) + NL_SET_ERR_MSG(extack, "Offloading graft operation failed."); +} +EXPORT_SYMBOL(qdisc_offload_graft_helper); + +void qdisc_offload_query_caps(struct net_device *dev, + enum tc_setup_type type, + void *caps, size_t caps_len) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct tc_query_caps_base base = { + .type = type, + .caps = caps, + }; + + memset(caps, 0, caps_len); + + if (ops->ndo_setup_tc) + ops->ndo_setup_tc(dev, TC_QUERY_CAPS, &base); +} +EXPORT_SYMBOL(qdisc_offload_query_caps); + +static void qdisc_offload_graft_root(struct net_device *dev, + struct Qdisc *new, struct Qdisc *old, + struct netlink_ext_ack *extack) +{ + struct tc_root_qopt_offload graft_offload = { + .command = TC_ROOT_GRAFT, + .handle = new ? 
new->handle : 0, + .ingress = (new && new->flags & TCQ_F_INGRESS) || + (old && old->flags & TCQ_F_INGRESS), + }; + + qdisc_offload_graft_helper(dev, NULL, new, old, + TC_SETUP_ROOT_QDISC, &graft_offload, extack); +} + +static int tc_fill_qdisc(struct sk_buff *skb, struct Qdisc *q, u32 clid, + u32 portid, u32 seq, u16 flags, int event, + struct netlink_ext_ack *extack) +{ + struct gnet_stats_basic_sync __percpu *cpu_bstats = NULL; + struct gnet_stats_queue __percpu *cpu_qstats = NULL; + struct tcmsg *tcm; + struct nlmsghdr *nlh; + unsigned char *b = skb_tail_pointer(skb); + struct gnet_dump d; + struct qdisc_size_table *stab; + u32 block_index; + __u32 qlen; + + cond_resched(); + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*tcm), flags); + if (!nlh) + goto out_nlmsg_trim; + tcm = nlmsg_data(nlh); + tcm->tcm_family = AF_UNSPEC; + tcm->tcm__pad1 = 0; + tcm->tcm__pad2 = 0; + tcm->tcm_ifindex = qdisc_dev(q)->ifindex; + tcm->tcm_parent = clid; + tcm->tcm_handle = q->handle; + tcm->tcm_info = refcount_read(&q->refcnt); + if (nla_put_string(skb, TCA_KIND, q->ops->id)) + goto nla_put_failure; + if (q->ops->ingress_block_get) { + block_index = q->ops->ingress_block_get(q); + if (block_index && + nla_put_u32(skb, TCA_INGRESS_BLOCK, block_index)) + goto nla_put_failure; + } + if (q->ops->egress_block_get) { + block_index = q->ops->egress_block_get(q); + if (block_index && + nla_put_u32(skb, TCA_EGRESS_BLOCK, block_index)) + goto nla_put_failure; + } + if (q->ops->dump && q->ops->dump(q, skb) < 0) + goto nla_put_failure; + if (nla_put_u8(skb, TCA_HW_OFFLOAD, !!(q->flags & TCQ_F_OFFLOADED))) + goto nla_put_failure; + qlen = qdisc_qlen_sum(q); + + stab = rtnl_dereference(q->stab); + if (stab && qdisc_dump_stab(skb, stab) < 0) + goto nla_put_failure; + + if (gnet_stats_start_copy_compat(skb, TCA_STATS2, TCA_STATS, TCA_XSTATS, + NULL, &d, TCA_PAD) < 0) + goto nla_put_failure; + + if (q->ops->dump_stats && q->ops->dump_stats(q, &d) < 0) + goto nla_put_failure; + + if (qdisc_is_percpu_stats(q)) { + cpu_bstats = q->cpu_bstats; + cpu_qstats = q->cpu_qstats; + } + + if (gnet_stats_copy_basic(&d, cpu_bstats, &q->bstats, true) < 0 || + gnet_stats_copy_rate_est(&d, &q->rate_est) < 0 || + gnet_stats_copy_queue(&d, cpu_qstats, &q->qstats, qlen) < 0) + goto nla_put_failure; + + if (gnet_stats_finish_copy(&d) < 0) + goto nla_put_failure; + + if (extack && extack->_msg && + nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg)) + goto out_nlmsg_trim; + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + + return skb->len; + +out_nlmsg_trim: +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static bool tc_qdisc_dump_ignore(struct Qdisc *q, bool dump_invisible) +{ + if (q->flags & TCQ_F_BUILTIN) + return true; + if ((q->flags & TCQ_F_INVISIBLE) && !dump_invisible) + return true; + + return false; +} + +static int qdisc_notify(struct net *net, struct sk_buff *oskb, + struct nlmsghdr *n, u32 clid, + struct Qdisc *old, struct Qdisc *new, + struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (old && !tc_qdisc_dump_ignore(old, false)) { + if (tc_fill_qdisc(skb, old, clid, portid, n->nlmsg_seq, + 0, RTM_DELQDISC, extack) < 0) + goto err_out; + } + if (new && !tc_qdisc_dump_ignore(new, false)) { + if (tc_fill_qdisc(skb, new, clid, portid, n->nlmsg_seq, + old ? 
NLM_F_REPLACE : 0, RTM_NEWQDISC, extack) < 0) + goto err_out; + } + + if (skb->len) + return rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + +err_out: + kfree_skb(skb); + return -EINVAL; +} + +static void notify_and_destroy(struct net *net, struct sk_buff *skb, + struct nlmsghdr *n, u32 clid, + struct Qdisc *old, struct Qdisc *new, + struct netlink_ext_ack *extack) +{ + if (new || old) + qdisc_notify(net, skb, n, clid, old, new, extack); + + if (old) + qdisc_put(old); +} + +static void qdisc_clear_nolock(struct Qdisc *sch) +{ + sch->flags &= ~TCQ_F_NOLOCK; + if (!(sch->flags & TCQ_F_CPUSTATS)) + return; + + free_percpu(sch->cpu_bstats); + free_percpu(sch->cpu_qstats); + sch->cpu_bstats = NULL; + sch->cpu_qstats = NULL; + sch->flags &= ~TCQ_F_CPUSTATS; +} + +/* Graft qdisc "new" to class "classid" of qdisc "parent" or + * to device "dev". + * + * When appropriate send a netlink notification using 'skb' + * and "n". + * + * On success, destroy old qdisc. + */ + +static int qdisc_graft(struct net_device *dev, struct Qdisc *parent, + struct sk_buff *skb, struct nlmsghdr *n, u32 classid, + struct Qdisc *new, struct Qdisc *old, + struct netlink_ext_ack *extack) +{ + struct Qdisc *q = old; + struct net *net = dev_net(dev); + + if (parent == NULL) { + unsigned int i, num_q, ingress; + struct netdev_queue *dev_queue; + + ingress = 0; + num_q = dev->num_tx_queues; + if ((q && q->flags & TCQ_F_INGRESS) || + (new && new->flags & TCQ_F_INGRESS)) { + ingress = 1; + dev_queue = dev_ingress_queue(dev); + if (!dev_queue) { + NL_SET_ERR_MSG(extack, "Device does not have an ingress queue"); + return -ENOENT; + } + + q = rtnl_dereference(dev_queue->qdisc_sleeping); + + /* This is the counterpart of that qdisc_refcount_inc_nz() call in + * __tcf_qdisc_find() for filter requests. + */ + if (!qdisc_refcount_dec_if_one(q)) { + NL_SET_ERR_MSG(extack, + "Current ingress or clsact Qdisc has ongoing filter requests"); + return -EBUSY; + } + } + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + + qdisc_offload_graft_root(dev, new, old, extack); + + if (new && new->ops->attach && !ingress) + goto skip; + + if (!ingress) { + for (i = 0; i < num_q; i++) { + dev_queue = netdev_get_tx_queue(dev, i); + old = dev_graft_qdisc(dev_queue, new); + + if (new && i > 0) + qdisc_refcount_inc(new); + qdisc_put(old); + } + } else { + old = dev_graft_qdisc(dev_queue, NULL); + + /* {ingress,clsact}_destroy() @old before grafting @new to avoid + * unprotected concurrent accesses to net_device::miniq_{in,e}gress + * pointer(s) in mini_qdisc_pair_swap(). + */ + qdisc_notify(net, skb, n, classid, old, new, extack); + qdisc_destroy(old); + + dev_graft_qdisc(dev_queue, new); + } + +skip: + if (!ingress) { + old = rtnl_dereference(dev->qdisc); + if (new && !new->ops->attach) + qdisc_refcount_inc(new); + rcu_assign_pointer(dev->qdisc, new ? 
: &noop_qdisc); + + notify_and_destroy(net, skb, n, classid, old, new, extack); + + if (new && new->ops->attach) + new->ops->attach(new); + } + + if (dev->flags & IFF_UP) + dev_activate(dev); + } else { + const struct Qdisc_class_ops *cops = parent->ops->cl_ops; + unsigned long cl; + int err; + + /* Only support running class lockless if parent is lockless */ + if (new && (new->flags & TCQ_F_NOLOCK) && !(parent->flags & TCQ_F_NOLOCK)) + qdisc_clear_nolock(new); + + if (!cops || !cops->graft) + return -EOPNOTSUPP; + + cl = cops->find(parent, classid); + if (!cl) { + NL_SET_ERR_MSG(extack, "Specified class not found"); + return -ENOENT; + } + + if (new && new->ops == &noqueue_qdisc_ops) { + NL_SET_ERR_MSG(extack, "Cannot assign noqueue to a class"); + return -EINVAL; + } + + err = cops->graft(parent, cl, new, &old, extack); + if (err) + return err; + notify_and_destroy(net, skb, n, classid, old, new, extack); + } + return 0; +} + +static int qdisc_block_indexes_set(struct Qdisc *sch, struct nlattr **tca, + struct netlink_ext_ack *extack) +{ + u32 block_index; + + if (tca[TCA_INGRESS_BLOCK]) { + block_index = nla_get_u32(tca[TCA_INGRESS_BLOCK]); + + if (!block_index) { + NL_SET_ERR_MSG(extack, "Ingress block index cannot be 0"); + return -EINVAL; + } + if (!sch->ops->ingress_block_set) { + NL_SET_ERR_MSG(extack, "Ingress block sharing is not supported"); + return -EOPNOTSUPP; + } + sch->ops->ingress_block_set(sch, block_index); + } + if (tca[TCA_EGRESS_BLOCK]) { + block_index = nla_get_u32(tca[TCA_EGRESS_BLOCK]); + + if (!block_index) { + NL_SET_ERR_MSG(extack, "Egress block index cannot be 0"); + return -EINVAL; + } + if (!sch->ops->egress_block_set) { + NL_SET_ERR_MSG(extack, "Egress block sharing is not supported"); + return -EOPNOTSUPP; + } + sch->ops->egress_block_set(sch, block_index); + } + return 0; +} + +/* + Allocate and initialize new qdisc. + + Parameters are passed via opt. + */ + +static struct Qdisc *qdisc_create(struct net_device *dev, + struct netdev_queue *dev_queue, + u32 parent, u32 handle, + struct nlattr **tca, int *errp, + struct netlink_ext_ack *extack) +{ + int err; + struct nlattr *kind = tca[TCA_KIND]; + struct Qdisc *sch; + struct Qdisc_ops *ops; + struct qdisc_size_table *stab; + + ops = qdisc_lookup_ops(kind); +#ifdef CONFIG_MODULES + if (ops == NULL && kind != NULL) { + char name[IFNAMSIZ]; + if (nla_strscpy(name, kind, IFNAMSIZ) >= 0) { + /* We dropped the RTNL semaphore in order to + * perform the module load. So, even if we + * succeeded in loading the module we have to + * tell the caller to replay the request. We + * indicate this using -EAGAIN. + * We replay the request because the device may + * go away in the mean time. + */ + rtnl_unlock(); + request_module("sch_%s", name); + rtnl_lock(); + ops = qdisc_lookup_ops(kind); + if (ops != NULL) { + /* We will try again qdisc_lookup_ops, + * so don't keep a reference. 
+ */ + module_put(ops->owner); + err = -EAGAIN; + goto err_out; + } + } + } +#endif + + err = -ENOENT; + if (!ops) { + NL_SET_ERR_MSG(extack, "Specified qdisc kind is unknown"); + goto err_out; + } + + sch = qdisc_alloc(dev_queue, ops, extack); + if (IS_ERR(sch)) { + err = PTR_ERR(sch); + goto err_out2; + } + + sch->parent = parent; + + if (handle == TC_H_INGRESS) { + if (!(sch->flags & TCQ_F_INGRESS)) { + NL_SET_ERR_MSG(extack, + "Specified parent ID is reserved for ingress and clsact Qdiscs"); + err = -EINVAL; + goto err_out3; + } + handle = TC_H_MAKE(TC_H_INGRESS, 0); + } else { + if (handle == 0) { + handle = qdisc_alloc_handle(dev); + if (handle == 0) { + NL_SET_ERR_MSG(extack, "Maximum number of qdisc handles was exceeded"); + err = -ENOSPC; + goto err_out3; + } + } + if (!netif_is_multiqueue(dev)) + sch->flags |= TCQ_F_ONETXQUEUE; + } + + sch->handle = handle; + + /* This exist to keep backward compatible with a userspace + * loophole, what allowed userspace to get IFF_NO_QUEUE + * facility on older kernels by setting tx_queue_len=0 (prior + * to qdisc init), and then forgot to reinit tx_queue_len + * before again attaching a qdisc. + */ + if ((dev->priv_flags & IFF_NO_QUEUE) && (dev->tx_queue_len == 0)) { + dev->tx_queue_len = DEFAULT_TX_QUEUE_LEN; + netdev_info(dev, "Caught tx_queue_len zero misconfig\n"); + } + + err = qdisc_block_indexes_set(sch, tca, extack); + if (err) + goto err_out3; + + if (ops->init) { + err = ops->init(sch, tca[TCA_OPTIONS], extack); + if (err != 0) + goto err_out5; + } + + if (tca[TCA_STAB]) { + stab = qdisc_get_stab(tca[TCA_STAB], extack); + if (IS_ERR(stab)) { + err = PTR_ERR(stab); + goto err_out4; + } + rcu_assign_pointer(sch->stab, stab); + } + if (tca[TCA_RATE]) { + err = -EOPNOTSUPP; + if (sch->flags & TCQ_F_MQROOT) { + NL_SET_ERR_MSG(extack, "Cannot attach rate estimator to a multi-queue root qdisc"); + goto err_out4; + } + + err = gen_new_estimator(&sch->bstats, + sch->cpu_bstats, + &sch->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to generate new estimator"); + goto err_out4; + } + } + + qdisc_hash_add(sch, false); + trace_qdisc_create(ops, dev, parent); + + return sch; + +err_out5: + /* ops->init() failed, we call ->destroy() like qdisc_create_dflt() */ + if (ops->destroy) + ops->destroy(sch); +err_out3: + netdev_put(dev, &sch->dev_tracker); + qdisc_free(sch); +err_out2: + module_put(ops->owner); +err_out: + *errp = err; + return NULL; + +err_out4: + /* + * Any broken qdiscs that would require a ops->reset() here? + * The qdisc was never in action so it shouldn't be necessary. 
+ */ + qdisc_put_stab(rtnl_dereference(sch->stab)); + if (ops->destroy) + ops->destroy(sch); + goto err_out3; +} + +static int qdisc_change(struct Qdisc *sch, struct nlattr **tca, + struct netlink_ext_ack *extack) +{ + struct qdisc_size_table *ostab, *stab = NULL; + int err = 0; + + if (tca[TCA_OPTIONS]) { + if (!sch->ops->change) { + NL_SET_ERR_MSG(extack, "Change operation not supported by specified qdisc"); + return -EINVAL; + } + if (tca[TCA_INGRESS_BLOCK] || tca[TCA_EGRESS_BLOCK]) { + NL_SET_ERR_MSG(extack, "Change of blocks is not supported"); + return -EOPNOTSUPP; + } + err = sch->ops->change(sch, tca[TCA_OPTIONS], extack); + if (err) + return err; + } + + if (tca[TCA_STAB]) { + stab = qdisc_get_stab(tca[TCA_STAB], extack); + if (IS_ERR(stab)) + return PTR_ERR(stab); + } + + ostab = rtnl_dereference(sch->stab); + rcu_assign_pointer(sch->stab, stab); + qdisc_put_stab(ostab); + + if (tca[TCA_RATE]) { + /* NB: ignores errors from replace_estimator + because change can't be undone. */ + if (sch->flags & TCQ_F_MQROOT) + goto out; + gen_replace_estimator(&sch->bstats, + sch->cpu_bstats, + &sch->rate_est, + NULL, + true, + tca[TCA_RATE]); + } +out: + return 0; +} + +struct check_loop_arg { + struct qdisc_walker w; + struct Qdisc *p; + int depth; +}; + +static int check_loop_fn(struct Qdisc *q, unsigned long cl, + struct qdisc_walker *w); + +static int check_loop(struct Qdisc *q, struct Qdisc *p, int depth) +{ + struct check_loop_arg arg; + + if (q->ops->cl_ops == NULL) + return 0; + + arg.w.stop = arg.w.skip = arg.w.count = 0; + arg.w.fn = check_loop_fn; + arg.depth = depth; + arg.p = p; + q->ops->cl_ops->walk(q, &arg.w); + return arg.w.stop ? -ELOOP : 0; +} + +static int +check_loop_fn(struct Qdisc *q, unsigned long cl, struct qdisc_walker *w) +{ + struct Qdisc *leaf; + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + struct check_loop_arg *arg = (struct check_loop_arg *)w; + + leaf = cops->leaf(q, cl); + if (leaf) { + if (leaf == arg->p || arg->depth > 7) + return -ELOOP; + return check_loop(leaf, arg->p, arg->depth + 1); + } + return 0; +} + +const struct nla_policy rtm_tca_policy[TCA_MAX + 1] = { + [TCA_KIND] = { .type = NLA_STRING }, + [TCA_RATE] = { .type = NLA_BINARY, + .len = sizeof(struct tc_estimator) }, + [TCA_STAB] = { .type = NLA_NESTED }, + [TCA_DUMP_INVISIBLE] = { .type = NLA_FLAG }, + [TCA_CHAIN] = { .type = NLA_U32 }, + [TCA_INGRESS_BLOCK] = { .type = NLA_U32 }, + [TCA_EGRESS_BLOCK] = { .type = NLA_U32 }, +}; + +/* + * Delete/get qdisc. 
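+ * (These are the RTM_DELQDISC and RTM_GETQDISC requests; both are
+ * served by tc_get_qdisc() below.)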
+ */ + +static int tc_get_qdisc(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct tcmsg *tcm = nlmsg_data(n); + struct nlattr *tca[TCA_MAX + 1]; + struct net_device *dev; + u32 clid; + struct Qdisc *q = NULL; + struct Qdisc *p = NULL; + int err; + + err = nlmsg_parse_deprecated(n, sizeof(*tcm), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + dev = __dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return -ENODEV; + + clid = tcm->tcm_parent; + if (clid) { + if (clid != TC_H_ROOT) { + if (TC_H_MAJ(clid) != TC_H_MAJ(TC_H_INGRESS)) { + p = qdisc_lookup(dev, TC_H_MAJ(clid)); + if (!p) { + NL_SET_ERR_MSG(extack, "Failed to find qdisc with specified classid"); + return -ENOENT; + } + q = qdisc_leaf(p, clid); + } else if (dev_ingress_queue(dev)) { + q = rtnl_dereference(dev_ingress_queue(dev)->qdisc_sleeping); + } + } else { + q = rtnl_dereference(dev->qdisc); + } + if (!q) { + NL_SET_ERR_MSG(extack, "Cannot find specified qdisc on specified device"); + return -ENOENT; + } + + if (tcm->tcm_handle && q->handle != tcm->tcm_handle) { + NL_SET_ERR_MSG(extack, "Invalid handle"); + return -EINVAL; + } + } else { + q = qdisc_lookup(dev, tcm->tcm_handle); + if (!q) { + NL_SET_ERR_MSG(extack, "Failed to find qdisc with specified handle"); + return -ENOENT; + } + } + + if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], q->ops->id)) { + NL_SET_ERR_MSG(extack, "Invalid qdisc name"); + return -EINVAL; + } + + if (n->nlmsg_type == RTM_DELQDISC) { + if (!clid) { + NL_SET_ERR_MSG(extack, "Classid cannot be zero"); + return -EINVAL; + } + if (q->handle == 0) { + NL_SET_ERR_MSG(extack, "Cannot delete qdisc with handle of zero"); + return -ENOENT; + } + err = qdisc_graft(dev, p, skb, n, clid, NULL, q, extack); + if (err != 0) + return err; + } else { + qdisc_notify(net, skb, n, clid, NULL, q, NULL); + } + return 0; +} + +static bool req_create_or_replace(struct nlmsghdr *n) +{ + return (n->nlmsg_flags & NLM_F_CREATE && + n->nlmsg_flags & NLM_F_REPLACE); +} + +static bool req_create_exclusive(struct nlmsghdr *n) +{ + return (n->nlmsg_flags & NLM_F_CREATE && + n->nlmsg_flags & NLM_F_EXCL); +} + +static bool req_change(struct nlmsghdr *n) +{ + return (!(n->nlmsg_flags & NLM_F_CREATE) && + !(n->nlmsg_flags & NLM_F_REPLACE) && + !(n->nlmsg_flags & NLM_F_EXCL)); +} + +/* + * Create/change qdisc. + */ +static int tc_modify_qdisc(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct tcmsg *tcm; + struct nlattr *tca[TCA_MAX + 1]; + struct net_device *dev; + u32 clid; + struct Qdisc *q, *p; + int err; + +replay: + /* Reinit, just in case something touches this. 
*/ + err = nlmsg_parse_deprecated(n, sizeof(*tcm), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + tcm = nlmsg_data(n); + clid = tcm->tcm_parent; + q = p = NULL; + + dev = __dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return -ENODEV; + + + if (clid) { + if (clid != TC_H_ROOT) { + if (clid != TC_H_INGRESS) { + p = qdisc_lookup(dev, TC_H_MAJ(clid)); + if (!p) { + NL_SET_ERR_MSG(extack, "Failed to find specified qdisc"); + return -ENOENT; + } + q = qdisc_leaf(p, clid); + } else if (dev_ingress_queue_create(dev)) { + q = rtnl_dereference(dev_ingress_queue(dev)->qdisc_sleeping); + } + } else { + q = rtnl_dereference(dev->qdisc); + } + + /* It may be default qdisc, ignore it */ + if (q && q->handle == 0) + q = NULL; + + if (!q || !tcm->tcm_handle || q->handle != tcm->tcm_handle) { + if (tcm->tcm_handle) { + if (q && !(n->nlmsg_flags & NLM_F_REPLACE)) { + NL_SET_ERR_MSG(extack, "NLM_F_REPLACE needed to override"); + return -EEXIST; + } + if (TC_H_MIN(tcm->tcm_handle)) { + NL_SET_ERR_MSG(extack, "Invalid minor handle"); + return -EINVAL; + } + q = qdisc_lookup(dev, tcm->tcm_handle); + if (!q) + goto create_n_graft; + if (n->nlmsg_flags & NLM_F_EXCL) { + NL_SET_ERR_MSG(extack, "Exclusivity flag on, cannot override"); + return -EEXIST; + } + if (tca[TCA_KIND] && + nla_strcmp(tca[TCA_KIND], q->ops->id)) { + NL_SET_ERR_MSG(extack, "Invalid qdisc name"); + return -EINVAL; + } + if (q->flags & TCQ_F_INGRESS) { + NL_SET_ERR_MSG(extack, + "Cannot regraft ingress or clsact Qdiscs"); + return -EINVAL; + } + if (q == p || + (p && check_loop(q, p, 0))) { + NL_SET_ERR_MSG(extack, "Qdisc parent/child loop detected"); + return -ELOOP; + } + if (clid == TC_H_INGRESS) { + NL_SET_ERR_MSG(extack, "Ingress cannot graft directly"); + return -EINVAL; + } + qdisc_refcount_inc(q); + goto graft; + } else { + if (!q) + goto create_n_graft; + + /* This magic test requires explanation. + * + * We know, that some child q is already + * attached to this parent and have choice: + * 1) change it or 2) create/graft new one. + * If the requested qdisc kind is different + * than the existing one, then we choose graft. + * If they are the same then this is "change" + * operation - just let it fallthrough.. + * + * 1. We are allowed to create/graft only + * if the request is explicitly stating + * "please create if it doesn't exist". + * + * 2. If the request is to exclusive create + * then the qdisc tcm_handle is not expected + * to exist, so that we choose create/graft too. + * + * 3. The last case is when no flags are set. + * This will happen when for example tc + * utility issues a "change" command. + * Alas, it is sort of hole in API, we + * cannot decide what to do unambiguously. + * For now we select create/graft. 
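+ *
+ * As a concrete illustration (assuming the standard iproute2 tc
+ * front end): "tc qdisc replace" sends NLM_F_CREATE | NLM_F_REPLACE
+ * (case 1), "tc qdisc add" sends NLM_F_CREATE | NLM_F_EXCL
+ * (case 2), and "tc qdisc change" sends none of these flags
+ * (case 3).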
+ */ + if (tca[TCA_KIND] && + nla_strcmp(tca[TCA_KIND], q->ops->id)) { + if (req_create_or_replace(n) || + req_create_exclusive(n)) + goto create_n_graft; + else if (req_change(n)) + goto create_n_graft2; + } + } + } + } else { + if (!tcm->tcm_handle) { + NL_SET_ERR_MSG(extack, "Handle cannot be zero"); + return -EINVAL; + } + q = qdisc_lookup(dev, tcm->tcm_handle); + } + + /* Change qdisc parameters */ + if (!q) { + NL_SET_ERR_MSG(extack, "Specified qdisc not found"); + return -ENOENT; + } + if (n->nlmsg_flags & NLM_F_EXCL) { + NL_SET_ERR_MSG(extack, "Exclusivity flag on, cannot modify"); + return -EEXIST; + } + if (tca[TCA_KIND] && nla_strcmp(tca[TCA_KIND], q->ops->id)) { + NL_SET_ERR_MSG(extack, "Invalid qdisc name"); + return -EINVAL; + } + err = qdisc_change(q, tca, extack); + if (err == 0) + qdisc_notify(net, skb, n, clid, NULL, q, extack); + return err; + +create_n_graft: + if (!(n->nlmsg_flags & NLM_F_CREATE)) { + NL_SET_ERR_MSG(extack, "Qdisc not found. To create specify NLM_F_CREATE flag"); + return -ENOENT; + } +create_n_graft2: + if (clid == TC_H_INGRESS) { + if (dev_ingress_queue(dev)) { + q = qdisc_create(dev, dev_ingress_queue(dev), + tcm->tcm_parent, tcm->tcm_parent, + tca, &err, extack); + } else { + NL_SET_ERR_MSG(extack, "Cannot find ingress queue for specified device"); + err = -ENOENT; + } + } else { + struct netdev_queue *dev_queue; + + if (p && p->ops->cl_ops && p->ops->cl_ops->select_queue) + dev_queue = p->ops->cl_ops->select_queue(p, tcm); + else if (p) + dev_queue = p->dev_queue; + else + dev_queue = netdev_get_tx_queue(dev, 0); + + q = qdisc_create(dev, dev_queue, + tcm->tcm_parent, tcm->tcm_handle, + tca, &err, extack); + } + if (q == NULL) { + if (err == -EAGAIN) + goto replay; + return err; + } + +graft: + err = qdisc_graft(dev, p, skb, n, clid, q, NULL, extack); + if (err) { + if (q) + qdisc_put(q); + return err; + } + + return 0; +} + +static int tc_dump_qdisc_root(struct Qdisc *root, struct sk_buff *skb, + struct netlink_callback *cb, + int *q_idx_p, int s_q_idx, bool recur, + bool dump_invisible) +{ + int ret = 0, q_idx = *q_idx_p; + struct Qdisc *q; + int b; + + if (!root) + return 0; + + q = root; + if (q_idx < s_q_idx) { + q_idx++; + } else { + if (!tc_qdisc_dump_ignore(q, dump_invisible) && + tc_fill_qdisc(skb, q, q->parent, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWQDISC, NULL) <= 0) + goto done; + q_idx++; + } + + /* If dumping singletons, there is no qdisc_dev(root) and the singleton + * itself has already been dumped. 
+ * + * If we've already dumped the top-level (ingress) qdisc above and the global + * qdisc hashtable, we don't want to hit it again + */ + if (!qdisc_dev(root) || !recur) + goto out; + + hash_for_each(qdisc_dev(root)->qdisc_hash, b, q, hash) { + if (q_idx < s_q_idx) { + q_idx++; + continue; + } + if (!tc_qdisc_dump_ignore(q, dump_invisible) && + tc_fill_qdisc(skb, q, q->parent, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWQDISC, NULL) <= 0) + goto done; + q_idx++; + } + +out: + *q_idx_p = q_idx; + return ret; +done: + ret = -1; + goto out; +} + +static int tc_dump_qdisc(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int idx, q_idx; + int s_idx, s_q_idx; + struct net_device *dev; + const struct nlmsghdr *nlh = cb->nlh; + struct nlattr *tca[TCA_MAX + 1]; + int err; + + s_idx = cb->args[0]; + s_q_idx = q_idx = cb->args[1]; + + idx = 0; + ASSERT_RTNL(); + + err = nlmsg_parse_deprecated(nlh, sizeof(struct tcmsg), tca, TCA_MAX, + rtm_tca_policy, cb->extack); + if (err < 0) + return err; + + for_each_netdev(net, dev) { + struct netdev_queue *dev_queue; + + if (idx < s_idx) + goto cont; + if (idx > s_idx) + s_q_idx = 0; + q_idx = 0; + + if (tc_dump_qdisc_root(rtnl_dereference(dev->qdisc), + skb, cb, &q_idx, s_q_idx, + true, tca[TCA_DUMP_INVISIBLE]) < 0) + goto done; + + dev_queue = dev_ingress_queue(dev); + if (dev_queue && + tc_dump_qdisc_root(rtnl_dereference(dev_queue->qdisc_sleeping), + skb, cb, &q_idx, s_q_idx, false, + tca[TCA_DUMP_INVISIBLE]) < 0) + goto done; + +cont: + idx++; + } + +done: + cb->args[0] = idx; + cb->args[1] = q_idx; + + return skb->len; +} + + + +/************************************************ + * Traffic classes manipulation. * + ************************************************/ + +static int tc_fill_tclass(struct sk_buff *skb, struct Qdisc *q, + unsigned long cl, u32 portid, u32 seq, u16 flags, + int event, struct netlink_ext_ack *extack) +{ + struct tcmsg *tcm; + struct nlmsghdr *nlh; + unsigned char *b = skb_tail_pointer(skb); + struct gnet_dump d; + const struct Qdisc_class_ops *cl_ops = q->ops->cl_ops; + + cond_resched(); + nlh = nlmsg_put(skb, portid, seq, event, sizeof(*tcm), flags); + if (!nlh) + goto out_nlmsg_trim; + tcm = nlmsg_data(nlh); + tcm->tcm_family = AF_UNSPEC; + tcm->tcm__pad1 = 0; + tcm->tcm__pad2 = 0; + tcm->tcm_ifindex = qdisc_dev(q)->ifindex; + tcm->tcm_parent = q->handle; + tcm->tcm_handle = q->handle; + tcm->tcm_info = 0; + if (nla_put_string(skb, TCA_KIND, q->ops->id)) + goto nla_put_failure; + if (cl_ops->dump && cl_ops->dump(q, cl, skb, tcm) < 0) + goto nla_put_failure; + + if (gnet_stats_start_copy_compat(skb, TCA_STATS2, TCA_STATS, TCA_XSTATS, + NULL, &d, TCA_PAD) < 0) + goto nla_put_failure; + + if (cl_ops->dump_stats && cl_ops->dump_stats(q, cl, &d) < 0) + goto nla_put_failure; + + if (gnet_stats_finish_copy(&d) < 0) + goto nla_put_failure; + + if (extack && extack->_msg && + nla_put_string(skb, TCA_EXT_WARN_MSG, extack->_msg)) + goto out_nlmsg_trim; + + nlh->nlmsg_len = skb_tail_pointer(skb) - b; + + return skb->len; + +out_nlmsg_trim: +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int tclass_notify(struct net *net, struct sk_buff *oskb, + struct nlmsghdr *n, struct Qdisc *q, + unsigned long cl, int event, struct netlink_ext_ack *extack) +{ + struct sk_buff *skb; + u32 portid = oskb ? 
NETLINK_CB(oskb).portid : 0; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tc_fill_tclass(skb, q, cl, portid, n->nlmsg_seq, 0, event, extack) < 0) { + kfree_skb(skb); + return -EINVAL; + } + + return rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); +} + +static int tclass_del_notify(struct net *net, + const struct Qdisc_class_ops *cops, + struct sk_buff *oskb, struct nlmsghdr *n, + struct Qdisc *q, unsigned long cl, + struct netlink_ext_ack *extack) +{ + u32 portid = oskb ? NETLINK_CB(oskb).portid : 0; + struct sk_buff *skb; + int err = 0; + + if (!cops->delete) + return -EOPNOTSUPP; + + skb = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!skb) + return -ENOBUFS; + + if (tc_fill_tclass(skb, q, cl, portid, n->nlmsg_seq, 0, + RTM_DELTCLASS, extack) < 0) { + kfree_skb(skb); + return -EINVAL; + } + + err = cops->delete(q, cl, extack); + if (err) { + kfree_skb(skb); + return err; + } + + err = rtnetlink_send(skb, net, portid, RTNLGRP_TC, + n->nlmsg_flags & NLM_F_ECHO); + return err; +} + +#ifdef CONFIG_NET_CLS + +struct tcf_bind_args { + struct tcf_walker w; + unsigned long base; + unsigned long cl; + u32 classid; +}; + +static int tcf_node_bind(struct tcf_proto *tp, void *n, struct tcf_walker *arg) +{ + struct tcf_bind_args *a = (void *)arg; + + if (n && tp->ops->bind_class) { + struct Qdisc *q = tcf_block_q(tp->chain->block); + + sch_tree_lock(q); + tp->ops->bind_class(n, a->classid, a->cl, q, a->base); + sch_tree_unlock(q); + } + return 0; +} + +struct tc_bind_class_args { + struct qdisc_walker w; + unsigned long new_cl; + u32 portid; + u32 clid; +}; + +static int tc_bind_class_walker(struct Qdisc *q, unsigned long cl, + struct qdisc_walker *w) +{ + struct tc_bind_class_args *a = (struct tc_bind_class_args *)w; + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + struct tcf_block *block; + struct tcf_chain *chain; + + block = cops->tcf_block(q, cl, NULL); + if (!block) + return 0; + for (chain = tcf_get_next_chain(block, NULL); + chain; + chain = tcf_get_next_chain(block, chain)) { + struct tcf_proto *tp; + + for (tp = tcf_get_next_proto(chain, NULL); + tp; tp = tcf_get_next_proto(chain, tp)) { + struct tcf_bind_args arg = {}; + + arg.w.fn = tcf_node_bind; + arg.classid = a->clid; + arg.base = cl; + arg.cl = a->new_cl; + tp->ops->walk(tp, &arg.w, true); + } + } + + return 0; +} + +static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, + unsigned long new_cl) +{ + const struct Qdisc_class_ops *cops = q->ops->cl_ops; + struct tc_bind_class_args args = {}; + + if (!cops->tcf_block) + return; + args.portid = portid; + args.clid = clid; + args.new_cl = new_cl; + args.w.fn = tc_bind_class_walker; + q->ops->cl_ops->walk(q, &args.w); +} + +#else + +static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, + unsigned long new_cl) +{ +} + +#endif + +static int tc_ctl_tclass(struct sk_buff *skb, struct nlmsghdr *n, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct tcmsg *tcm = nlmsg_data(n); + struct nlattr *tca[TCA_MAX + 1]; + struct net_device *dev; + struct Qdisc *q = NULL; + const struct Qdisc_class_ops *cops; + unsigned long cl = 0; + unsigned long new_cl; + u32 portid; + u32 clid; + u32 qid; + int err; + + err = nlmsg_parse_deprecated(n, sizeof(*tcm), tca, TCA_MAX, + rtm_tca_policy, extack); + if (err < 0) + return err; + + dev = __dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return -ENODEV; + + /* + parent == TC_H_UNSPEC - unspecified parent. 
+ parent == TC_H_ROOT - class is root, which has no parent. + parent == X:0 - parent is root class. + parent == X:Y - parent is a node in hierarchy. + parent == 0:Y - parent is X:Y, where X:0 is qdisc. + + handle == 0:0 - generate handle from kernel pool. + handle == 0:Y - class is X:Y, where X:0 is qdisc. + handle == X:Y - clear. + handle == X:0 - root class. + */ + + /* Step 1. Determine qdisc handle X:0 */ + + portid = tcm->tcm_parent; + clid = tcm->tcm_handle; + qid = TC_H_MAJ(clid); + + if (portid != TC_H_ROOT) { + u32 qid1 = TC_H_MAJ(portid); + + if (qid && qid1) { + /* If both majors are known, they must be identical. */ + if (qid != qid1) + return -EINVAL; + } else if (qid1) { + qid = qid1; + } else if (qid == 0) + qid = rtnl_dereference(dev->qdisc)->handle; + + /* Now qid is genuine qdisc handle consistent + * both with parent and child. + * + * TC_H_MAJ(portid) still may be unspecified, complete it now. + */ + if (portid) + portid = TC_H_MAKE(qid, portid); + } else { + if (qid == 0) + qid = rtnl_dereference(dev->qdisc)->handle; + } + + /* OK. Locate qdisc */ + q = qdisc_lookup(dev, qid); + if (!q) + return -ENOENT; + + /* An check that it supports classes */ + cops = q->ops->cl_ops; + if (cops == NULL) + return -EINVAL; + + /* Now try to get class */ + if (clid == 0) { + if (portid == TC_H_ROOT) + clid = qid; + } else + clid = TC_H_MAKE(qid, clid); + + if (clid) + cl = cops->find(q, clid); + + if (cl == 0) { + err = -ENOENT; + if (n->nlmsg_type != RTM_NEWTCLASS || + !(n->nlmsg_flags & NLM_F_CREATE)) + goto out; + } else { + switch (n->nlmsg_type) { + case RTM_NEWTCLASS: + err = -EEXIST; + if (n->nlmsg_flags & NLM_F_EXCL) + goto out; + break; + case RTM_DELTCLASS: + err = tclass_del_notify(net, cops, skb, n, q, cl, extack); + /* Unbind the class with flilters with 0 */ + tc_bind_tclass(q, portid, clid, 0); + goto out; + case RTM_GETTCLASS: + err = tclass_notify(net, skb, n, q, cl, RTM_NEWTCLASS, extack); + goto out; + default: + err = -EINVAL; + goto out; + } + } + + if (tca[TCA_INGRESS_BLOCK] || tca[TCA_EGRESS_BLOCK]) { + NL_SET_ERR_MSG(extack, "Shared blocks are not supported for classes"); + return -EOPNOTSUPP; + } + + new_cl = cl; + err = -EOPNOTSUPP; + if (cops->change) + err = cops->change(q, clid, portid, tca, &new_cl, extack); + if (err == 0) { + tclass_notify(net, skb, n, q, new_cl, RTM_NEWTCLASS, extack); + /* We just create a new class, need to do reverse binding. 
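+ * ("Reverse binding" here means re-binding any filters that still
+ * point at this class id to the class that was just created, which
+ * is what tc_bind_tclass() does below.)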
*/ + if (cl != new_cl) + tc_bind_tclass(q, portid, clid, new_cl); + } +out: + return err; +} + +struct qdisc_dump_args { + struct qdisc_walker w; + struct sk_buff *skb; + struct netlink_callback *cb; +}; + +static int qdisc_class_dump(struct Qdisc *q, unsigned long cl, + struct qdisc_walker *arg) +{ + struct qdisc_dump_args *a = (struct qdisc_dump_args *)arg; + + return tc_fill_tclass(a->skb, q, cl, NETLINK_CB(a->cb->skb).portid, + a->cb->nlh->nlmsg_seq, NLM_F_MULTI, + RTM_NEWTCLASS, NULL); +} + +static int tc_dump_tclass_qdisc(struct Qdisc *q, struct sk_buff *skb, + struct tcmsg *tcm, struct netlink_callback *cb, + int *t_p, int s_t) +{ + struct qdisc_dump_args arg; + + if (tc_qdisc_dump_ignore(q, false) || + *t_p < s_t || !q->ops->cl_ops || + (tcm->tcm_parent && + TC_H_MAJ(tcm->tcm_parent) != q->handle)) { + (*t_p)++; + return 0; + } + if (*t_p > s_t) + memset(&cb->args[1], 0, sizeof(cb->args)-sizeof(cb->args[0])); + arg.w.fn = qdisc_class_dump; + arg.skb = skb; + arg.cb = cb; + arg.w.stop = 0; + arg.w.skip = cb->args[1]; + arg.w.count = 0; + q->ops->cl_ops->walk(q, &arg.w); + cb->args[1] = arg.w.count; + if (arg.w.stop) + return -1; + (*t_p)++; + return 0; +} + +static int tc_dump_tclass_root(struct Qdisc *root, struct sk_buff *skb, + struct tcmsg *tcm, struct netlink_callback *cb, + int *t_p, int s_t, bool recur) +{ + struct Qdisc *q; + int b; + + if (!root) + return 0; + + if (tc_dump_tclass_qdisc(root, skb, tcm, cb, t_p, s_t) < 0) + return -1; + + if (!qdisc_dev(root) || !recur) + return 0; + + if (tcm->tcm_parent) { + q = qdisc_match_from_root(root, TC_H_MAJ(tcm->tcm_parent)); + if (q && q != root && + tc_dump_tclass_qdisc(q, skb, tcm, cb, t_p, s_t) < 0) + return -1; + return 0; + } + hash_for_each(qdisc_dev(root)->qdisc_hash, b, q, hash) { + if (tc_dump_tclass_qdisc(q, skb, tcm, cb, t_p, s_t) < 0) + return -1; + } + + return 0; +} + +static int tc_dump_tclass(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct net *net = sock_net(skb->sk); + struct netdev_queue *dev_queue; + struct net_device *dev; + int t, s_t; + + if (nlmsg_len(cb->nlh) < sizeof(*tcm)) + return 0; + dev = dev_get_by_index(net, tcm->tcm_ifindex); + if (!dev) + return 0; + + s_t = cb->args[0]; + t = 0; + + if (tc_dump_tclass_root(rtnl_dereference(dev->qdisc), + skb, tcm, cb, &t, s_t, true) < 0) + goto done; + + dev_queue = dev_ingress_queue(dev); + if (dev_queue && + tc_dump_tclass_root(rtnl_dereference(dev_queue->qdisc_sleeping), + skb, tcm, cb, &t, s_t, false) < 0) + goto done; + +done: + cb->args[0] = t; + + dev_put(dev); + return skb->len; +} + +#ifdef CONFIG_PROC_FS +static int psched_show(struct seq_file *seq, void *v) +{ + seq_printf(seq, "%08x %08x %08x %08x\n", + (u32)NSEC_PER_USEC, (u32)PSCHED_TICKS2NS(1), + 1000000, + (u32)NSEC_PER_SEC / hrtimer_resolution); + + return 0; +} + +static int __net_init psched_net_init(struct net *net) +{ + struct proc_dir_entry *e; + + e = proc_create_single("psched", 0, net->proc_net, psched_show); + if (e == NULL) + return -ENOMEM; + + return 0; +} + +static void __net_exit psched_net_exit(struct net *net) +{ + remove_proc_entry("psched", net->proc_net); +} +#else +static int __net_init psched_net_init(struct net *net) +{ + return 0; +} + +static void __net_exit psched_net_exit(struct net *net) +{ +} +#endif + +static struct pernet_operations psched_net_ops = { + .init = psched_net_init, + .exit = psched_net_exit, +}; + +static int __init pktsched_init(void) +{ + int err; + + err = 
register_pernet_subsys(&psched_net_ops); + if (err) { + pr_err("pktsched_init: " + "cannot initialize per netns operations\n"); + return err; + } + + register_qdisc(&pfifo_fast_ops); + register_qdisc(&pfifo_qdisc_ops); + register_qdisc(&bfifo_qdisc_ops); + register_qdisc(&pfifo_head_drop_qdisc_ops); + register_qdisc(&mq_qdisc_ops); + register_qdisc(&noqueue_qdisc_ops); + + rtnl_register(PF_UNSPEC, RTM_NEWQDISC, tc_modify_qdisc, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELQDISC, tc_get_qdisc, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETQDISC, tc_get_qdisc, tc_dump_qdisc, + 0); + rtnl_register(PF_UNSPEC, RTM_NEWTCLASS, tc_ctl_tclass, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_DELTCLASS, tc_ctl_tclass, NULL, 0); + rtnl_register(PF_UNSPEC, RTM_GETTCLASS, tc_ctl_tclass, tc_dump_tclass, + 0); + + return 0; +} + +subsys_initcall(pktsched_init); diff --git a/net/sched/sch_atm.c b/net/sched/sch_atm.c new file mode 100644 index 000000000..4a981ca90 --- /dev/null +++ b/net/sched/sch_atm.c @@ -0,0 +1,706 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/sched/sch_atm.c - ATM VC selection "queueing discipline" */ + +/* Written 1998-2000 by Werner Almesberger, EPFL ICA */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* for fput */ +#include +#include +#include + +/* + * The ATM queuing discipline provides a framework for invoking classifiers + * (aka "filters"), which in turn select classes of this queuing discipline. + * Each class maps the flow(s) it is handling to a given VC. Multiple classes + * may share the same VC. + * + * When creating a class, VCs are specified by passing the number of the open + * socket descriptor by which the calling process references the VC. The kernel + * keeps the VC open at least until all classes using it are removed. + * + * In this file, most functions are named atm_tc_* to avoid confusion with all + * the atm_* in net/atm. This naming convention differs from what's used in the + * rest of net/sched. + * + * Known bugs: + * - sometimes messes up the IP stack + * - any manipulations besides the few operations described in the README, are + * untested and likely to crash the system + * - should lock the flow while there is data in the queue (?) + */ + +#define VCC2FLOW(vcc) ((struct atm_flow_data *) ((vcc)->user_back)) + +struct atm_flow_data { + struct Qdisc_class_common common; + struct Qdisc *q; /* FIFO, TBF, etc. 
*/ + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + struct atm_vcc *vcc; /* VCC; NULL if VCC is closed */ + void (*old_pop)(struct atm_vcc *vcc, + struct sk_buff *skb); /* chaining */ + struct atm_qdisc_data *parent; /* parent qdisc */ + struct socket *sock; /* for closing */ + int ref; /* reference count */ + struct gnet_stats_basic_sync bstats; + struct gnet_stats_queue qstats; + struct list_head list; + struct atm_flow_data *excess; /* flow for excess traffic; + NULL to set CLP instead */ + int hdr_len; + unsigned char hdr[]; /* header data; MUST BE LAST */ +}; + +struct atm_qdisc_data { + struct atm_flow_data link; /* unclassified skbs go here */ + struct list_head flows; /* NB: "link" is also on this + list */ + struct tasklet_struct task; /* dequeue tasklet */ +}; + +/* ------------------------- Class/flow operations ------------------------- */ + +static inline struct atm_flow_data *lookup_flow(struct Qdisc *sch, u32 classid) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow; + + list_for_each_entry(flow, &p->flows, list) { + if (flow->common.classid == classid) + return flow; + } + return NULL; +} + +static int atm_tc_graft(struct Qdisc *sch, unsigned long arg, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)arg; + + pr_debug("atm_tc_graft(sch %p,[qdisc %p],flow %p,new %p,old %p)\n", + sch, p, flow, new, old); + if (list_empty(&flow->list)) + return -EINVAL; + if (!new) + new = &noop_qdisc; + *old = flow->q; + flow->q = new; + if (*old) + qdisc_reset(*old); + return 0; +} + +static struct Qdisc *atm_tc_leaf(struct Qdisc *sch, unsigned long cl) +{ + struct atm_flow_data *flow = (struct atm_flow_data *)cl; + + pr_debug("atm_tc_leaf(sch %p,flow %p)\n", sch, flow); + return flow ? flow->q : NULL; +} + +static unsigned long atm_tc_find(struct Qdisc *sch, u32 classid) +{ + struct atm_qdisc_data *p __maybe_unused = qdisc_priv(sch); + struct atm_flow_data *flow; + + pr_debug("%s(sch %p,[qdisc %p],classid %x)\n", __func__, sch, p, classid); + flow = lookup_flow(sch, classid); + pr_debug("%s: flow %p\n", __func__, flow); + return (unsigned long)flow; +} + +static unsigned long atm_tc_bind_filter(struct Qdisc *sch, + unsigned long parent, u32 classid) +{ + struct atm_qdisc_data *p __maybe_unused = qdisc_priv(sch); + struct atm_flow_data *flow; + + pr_debug("%s(sch %p,[qdisc %p],classid %x)\n", __func__, sch, p, classid); + flow = lookup_flow(sch, classid); + if (flow) + flow->ref++; + pr_debug("%s: flow %p\n", __func__, flow); + return (unsigned long)flow; +} + +/* + * atm_tc_put handles all destructions, including the ones that are explicitly + * requested (atm_tc_destroy, etc.). The assumption here is that we never drop + * anything that still seems to be in use. 
+ */ +static void atm_tc_put(struct Qdisc *sch, unsigned long cl) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)cl; + + pr_debug("atm_tc_put(sch %p,[qdisc %p],flow %p)\n", sch, p, flow); + if (--flow->ref) + return; + pr_debug("atm_tc_put: destroying\n"); + list_del_init(&flow->list); + pr_debug("atm_tc_put: qdisc %p\n", flow->q); + qdisc_put(flow->q); + tcf_block_put(flow->block); + if (flow->sock) { + pr_debug("atm_tc_put: f_count %ld\n", + file_count(flow->sock->file)); + flow->vcc->pop = flow->old_pop; + sockfd_put(flow->sock); + } + if (flow->excess) + atm_tc_put(sch, (unsigned long)flow->excess); + if (flow != &p->link) + kfree(flow); + /* + * If flow == &p->link, the qdisc no longer works at this point and + * needs to be removed. (By the caller of atm_tc_put.) + */ +} + +static void sch_atm_pop(struct atm_vcc *vcc, struct sk_buff *skb) +{ + struct atm_qdisc_data *p = VCC2FLOW(vcc)->parent; + + pr_debug("sch_atm_pop(vcc %p,skb %p,[qdisc %p])\n", vcc, skb, p); + VCC2FLOW(vcc)->old_pop(vcc, skb); + tasklet_schedule(&p->task); +} + +static const u8 llc_oui_ip[] = { + 0xaa, /* DSAP: non-ISO */ + 0xaa, /* SSAP: non-ISO */ + 0x03, /* Ctrl: Unnumbered Information Command PDU */ + 0x00, /* OUI: EtherType */ + 0x00, 0x00, + 0x08, 0x00 +}; /* Ethertype IP (0800) */ + +static const struct nla_policy atm_policy[TCA_ATM_MAX + 1] = { + [TCA_ATM_FD] = { .type = NLA_U32 }, + [TCA_ATM_EXCESS] = { .type = NLA_U32 }, +}; + +static int atm_tc_change(struct Qdisc *sch, u32 classid, u32 parent, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)*arg; + struct atm_flow_data *excess = NULL; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_ATM_MAX + 1]; + struct socket *sock; + int fd, error, hdr_len; + void *hdr; + + pr_debug("atm_tc_change(sch %p,[qdisc %p],classid %x,parent %x," + "flow %p,opt %p)\n", sch, p, classid, parent, flow, opt); + /* + * The concept of parents doesn't apply for this qdisc. + */ + if (parent && parent != TC_H_ROOT && parent != sch->handle) + return -EINVAL; + /* + * ATM classes cannot be changed. In order to change properties of the + * ATM connection, that socket needs to be modified directly (via the + * native ATM API. In order to send a flow to a different VC, the old + * class needs to be removed and a new one added. (This may be changed + * later.) 
+ */ + if (flow) + return -EBUSY; + if (opt == NULL) + return -EINVAL; + + error = nla_parse_nested_deprecated(tb, TCA_ATM_MAX, opt, atm_policy, + NULL); + if (error < 0) + return error; + + if (!tb[TCA_ATM_FD]) + return -EINVAL; + fd = nla_get_u32(tb[TCA_ATM_FD]); + pr_debug("atm_tc_change: fd %d\n", fd); + if (tb[TCA_ATM_HDR]) { + hdr_len = nla_len(tb[TCA_ATM_HDR]); + hdr = nla_data(tb[TCA_ATM_HDR]); + } else { + hdr_len = RFC1483LLC_LEN; + hdr = NULL; /* default LLC/SNAP for IP */ + } + if (!tb[TCA_ATM_EXCESS]) + excess = NULL; + else { + excess = (struct atm_flow_data *) + atm_tc_find(sch, nla_get_u32(tb[TCA_ATM_EXCESS])); + if (!excess) + return -ENOENT; + } + pr_debug("atm_tc_change: type %d, payload %d, hdr_len %d\n", + opt->nla_type, nla_len(opt), hdr_len); + sock = sockfd_lookup(fd, &error); + if (!sock) + return error; /* f_count++ */ + pr_debug("atm_tc_change: f_count %ld\n", file_count(sock->file)); + if (sock->ops->family != PF_ATMSVC && sock->ops->family != PF_ATMPVC) { + error = -EPROTOTYPE; + goto err_out; + } + /* @@@ should check if the socket is really operational or we'll crash + on vcc->send */ + if (classid) { + if (TC_H_MAJ(classid ^ sch->handle)) { + pr_debug("atm_tc_change: classid mismatch\n"); + error = -EINVAL; + goto err_out; + } + } else { + int i; + unsigned long cl; + + for (i = 1; i < 0x8000; i++) { + classid = TC_H_MAKE(sch->handle, 0x8000 | i); + cl = atm_tc_find(sch, classid); + if (!cl) + break; + } + } + pr_debug("atm_tc_change: new id %x\n", classid); + flow = kzalloc(sizeof(struct atm_flow_data) + hdr_len, GFP_KERNEL); + pr_debug("atm_tc_change: flow %p\n", flow); + if (!flow) { + error = -ENOBUFS; + goto err_out; + } + + error = tcf_block_get(&flow->block, &flow->filter_list, sch, + extack); + if (error) { + kfree(flow); + goto err_out; + } + + flow->q = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, classid, + extack); + if (!flow->q) + flow->q = &noop_qdisc; + pr_debug("atm_tc_change: qdisc %p\n", flow->q); + flow->sock = sock; + flow->vcc = ATM_SD(sock); /* speedup */ + flow->vcc->user_back = flow; + pr_debug("atm_tc_change: vcc %p\n", flow->vcc); + flow->old_pop = flow->vcc->pop; + flow->parent = p; + flow->vcc->pop = sch_atm_pop; + flow->common.classid = classid; + flow->ref = 1; + flow->excess = excess; + list_add(&flow->list, &p->link.list); + flow->hdr_len = hdr_len; + if (hdr) + memcpy(flow->hdr, hdr, hdr_len); + else + memcpy(flow->hdr, llc_oui_ip, sizeof(llc_oui_ip)); + *arg = (unsigned long)flow; + return 0; +err_out: + sockfd_put(sock); + return error; +} + +static int atm_tc_delete(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)arg; + + pr_debug("atm_tc_delete(sch %p,[qdisc %p],flow %p)\n", sch, p, flow); + if (list_empty(&flow->list)) + return -EINVAL; + if (rcu_access_pointer(flow->filter_list) || flow == &p->link) + return -EBUSY; + /* + * Reference count must be 2: one for "keepalive" (set at class + * creation), and one for the reference held when calling delete. + */ + if (flow->ref < 2) { + pr_err("atm_tc_delete: flow->ref == %d\n", flow->ref); + return -EINVAL; + } + if (flow->ref > 2) + return -EBUSY; /* catch references via excess, etc. 
*/ + atm_tc_put(sch, arg); + return 0; +} + +static void atm_tc_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow; + + pr_debug("atm_tc_walk(sch %p,[qdisc %p],walker %p)\n", sch, p, walker); + if (walker->stop) + return; + list_for_each_entry(flow, &p->flows, list) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)flow, walker)) + break; + } +} + +static struct tcf_block *atm_tc_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)cl; + + pr_debug("atm_tc_find_tcf(sch %p,[qdisc %p],flow %p)\n", sch, p, flow); + return flow ? flow->block : p->link.block; +} + +/* --------------------------- Qdisc operations ---------------------------- */ + +static int atm_tc_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow; + struct tcf_result res; + int result; + int ret = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + + pr_debug("atm_tc_enqueue(skb %p,sch %p,[qdisc %p])\n", skb, sch, p); + result = TC_ACT_OK; /* be nice to gcc */ + flow = NULL; + if (TC_H_MAJ(skb->priority) != sch->handle || + !(flow = (struct atm_flow_data *)atm_tc_find(sch, skb->priority))) { + struct tcf_proto *fl; + + list_for_each_entry(flow, &p->flows, list) { + fl = rcu_dereference_bh(flow->filter_list); + if (fl) { + result = tcf_classify(skb, NULL, fl, &res, true); + if (result < 0) + continue; + if (result == TC_ACT_SHOT) + goto done; + + flow = (struct atm_flow_data *)res.class; + if (!flow) + flow = lookup_flow(sch, res.classid); + goto drop; + } + } + flow = NULL; +done: + ; + } + if (!flow) { + flow = &p->link; + } else { + if (flow->vcc) + ATM_SKB(skb)->atm_options = flow->vcc->atm_options; + /*@@@ looks good ... but it's not supposed to work :-) */ +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + __qdisc_drop(skb, to_free); + return NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + case TC_ACT_SHOT: + __qdisc_drop(skb, to_free); + goto drop; + case TC_ACT_RECLASSIFY: + if (flow->excess) + flow = flow->excess; + else + ATM_SKB(skb)->atm_options |= ATM_ATMOPT_CLP; + break; + } +#endif + } + + ret = qdisc_enqueue(skb, flow->q, to_free); + if (ret != NET_XMIT_SUCCESS) { +drop: __maybe_unused + if (net_xmit_drop_count(ret)) { + qdisc_qstats_drop(sch); + if (flow) + flow->qstats.drops++; + } + return ret; + } + /* + * Okay, this may seem weird. We pretend we've dropped the packet if + * it goes via ATM. The reason for this is that the outer qdisc + * expects to be able to q->dequeue the packet later on if we return + * success at this place. Also, sch->q.qdisc needs to reflect whether + * there is a packet egligible for dequeuing or not. Note that the + * statistics of the outer qdisc are necessarily wrong because of all + * this. There's currently no correct solution for this. + */ + if (flow == &p->link) { + sch->q.qlen++; + return NET_XMIT_SUCCESS; + } + tasklet_schedule(&p->task); + return NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; +} + +/* + * Dequeue packets and send them over ATM. Note that we quite deliberately + * avoid checking net_device's flow control here, simply because sch_atm + * uses its own channels, which have nothing to do with any CLIP/LANE/or + * non-ATM interfaces. 
+ */ + +static void sch_atm_dequeue(struct tasklet_struct *t) +{ + struct atm_qdisc_data *p = from_tasklet(p, t, task); + struct Qdisc *sch = qdisc_from_priv(p); + struct atm_flow_data *flow; + struct sk_buff *skb; + + pr_debug("sch_atm_dequeue(sch %p,[qdisc %p])\n", sch, p); + list_for_each_entry(flow, &p->flows, list) { + if (flow == &p->link) + continue; + /* + * If traffic is properly shaped, this won't generate nasty + * little bursts. Otherwise, it may ... (but that's okay) + */ + while ((skb = flow->q->ops->peek(flow->q))) { + if (!atm_may_send(flow->vcc, skb->truesize)) + break; + + skb = qdisc_dequeue_peeked(flow->q); + if (unlikely(!skb)) + break; + + qdisc_bstats_update(sch, skb); + bstats_update(&flow->bstats, skb); + pr_debug("atm_tc_dequeue: sending on class %p\n", flow); + /* remove any LL header somebody else has attached */ + skb_pull(skb, skb_network_offset(skb)); + if (skb_headroom(skb) < flow->hdr_len) { + struct sk_buff *new; + + new = skb_realloc_headroom(skb, flow->hdr_len); + dev_kfree_skb(skb); + if (!new) + continue; + skb = new; + } + pr_debug("sch_atm_dequeue: ip %p, data %p\n", + skb_network_header(skb), skb->data); + ATM_SKB(skb)->vcc = flow->vcc; + memcpy(skb_push(skb, flow->hdr_len), flow->hdr, + flow->hdr_len); + refcount_add(skb->truesize, + &sk_atm(flow->vcc)->sk_wmem_alloc); + /* atm.atm_options are already set by atm_tc_enqueue */ + flow->vcc->send(flow->vcc, skb); + } + } +} + +static struct sk_buff *atm_tc_dequeue(struct Qdisc *sch) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct sk_buff *skb; + + pr_debug("atm_tc_dequeue(sch %p,[qdisc %p])\n", sch, p); + tasklet_schedule(&p->task); + skb = qdisc_dequeue_peeked(p->link.q); + if (skb) + sch->q.qlen--; + return skb; +} + +static struct sk_buff *atm_tc_peek(struct Qdisc *sch) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + + pr_debug("atm_tc_peek(sch %p,[qdisc %p])\n", sch, p); + + return p->link.q->ops->peek(p->link.q); +} + +static int atm_tc_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + int err; + + pr_debug("atm_tc_init(sch %p,[qdisc %p],opt %p)\n", sch, p, opt); + INIT_LIST_HEAD(&p->flows); + INIT_LIST_HEAD(&p->link.list); + gnet_stats_basic_sync_init(&p->link.bstats); + list_add(&p->link.list, &p->flows); + p->link.q = qdisc_create_dflt(sch->dev_queue, + &pfifo_qdisc_ops, sch->handle, extack); + if (!p->link.q) + p->link.q = &noop_qdisc; + pr_debug("atm_tc_init: link (%p) qdisc %p\n", &p->link, p->link.q); + p->link.vcc = NULL; + p->link.sock = NULL; + p->link.common.classid = sch->handle; + p->link.ref = 1; + + err = tcf_block_get(&p->link.block, &p->link.filter_list, sch, + extack); + if (err) + return err; + + tasklet_setup(&p->task, sch_atm_dequeue); + return 0; +} + +static void atm_tc_reset(struct Qdisc *sch) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow; + + pr_debug("atm_tc_reset(sch %p,[qdisc %p])\n", sch, p); + list_for_each_entry(flow, &p->flows, list) + qdisc_reset(flow->q); +} + +static void atm_tc_destroy(struct Qdisc *sch) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow, *tmp; + + pr_debug("atm_tc_destroy(sch %p,[qdisc %p])\n", sch, p); + list_for_each_entry(flow, &p->flows, list) { + tcf_block_put(flow->block); + flow->block = NULL; + } + + list_for_each_entry_safe(flow, tmp, &p->flows, list) { + if (flow->ref > 1) + pr_err("atm_destroy: %p->ref = %d\n", flow, flow->ref); + atm_tc_put(sch, (unsigned long)flow); + } + 
tasklet_kill(&p->task); +} + +static int atm_tc_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct atm_qdisc_data *p = qdisc_priv(sch); + struct atm_flow_data *flow = (struct atm_flow_data *)cl; + struct nlattr *nest; + + pr_debug("atm_tc_dump_class(sch %p,[qdisc %p],flow %p,skb %p,tcm %p)\n", + sch, p, flow, skb, tcm); + if (list_empty(&flow->list)) + return -EINVAL; + tcm->tcm_handle = flow->common.classid; + tcm->tcm_info = flow->q->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + if (nla_put(skb, TCA_ATM_HDR, flow->hdr_len, flow->hdr)) + goto nla_put_failure; + if (flow->vcc) { + struct sockaddr_atmpvc pvc; + int state; + + memset(&pvc, 0, sizeof(pvc)); + pvc.sap_family = AF_ATMPVC; + pvc.sap_addr.itf = flow->vcc->dev ? flow->vcc->dev->number : -1; + pvc.sap_addr.vpi = flow->vcc->vpi; + pvc.sap_addr.vci = flow->vcc->vci; + if (nla_put(skb, TCA_ATM_ADDR, sizeof(pvc), &pvc)) + goto nla_put_failure; + state = ATM_VF2VS(flow->vcc->flags); + if (nla_put_u32(skb, TCA_ATM_STATE, state)) + goto nla_put_failure; + } + if (flow->excess) { + if (nla_put_u32(skb, TCA_ATM_EXCESS, flow->common.classid)) + goto nla_put_failure; + } else { + if (nla_put_u32(skb, TCA_ATM_EXCESS, 0)) + goto nla_put_failure; + } + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} +static int +atm_tc_dump_class_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct atm_flow_data *flow = (struct atm_flow_data *)arg; + + if (gnet_stats_copy_basic(d, NULL, &flow->bstats, true) < 0 || + gnet_stats_copy_queue(d, NULL, &flow->qstats, flow->q->q.qlen) < 0) + return -1; + + return 0; +} + +static int atm_tc_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + return 0; +} + +static const struct Qdisc_class_ops atm_class_ops = { + .graft = atm_tc_graft, + .leaf = atm_tc_leaf, + .find = atm_tc_find, + .change = atm_tc_change, + .delete = atm_tc_delete, + .walk = atm_tc_walk, + .tcf_block = atm_tc_tcf_block, + .bind_tcf = atm_tc_bind_filter, + .unbind_tcf = atm_tc_put, + .dump = atm_tc_dump_class, + .dump_stats = atm_tc_dump_class_stats, +}; + +static struct Qdisc_ops atm_qdisc_ops __read_mostly = { + .cl_ops = &atm_class_ops, + .id = "atm", + .priv_size = sizeof(struct atm_qdisc_data), + .enqueue = atm_tc_enqueue, + .dequeue = atm_tc_dequeue, + .peek = atm_tc_peek, + .init = atm_tc_init, + .reset = atm_tc_reset, + .destroy = atm_tc_destroy, + .dump = atm_tc_dump, + .owner = THIS_MODULE, +}; + +static int __init atm_init(void) +{ + return register_qdisc(&atm_qdisc_ops); +} + +static void __exit atm_exit(void) +{ + unregister_qdisc(&atm_qdisc_ops); +} + +module_init(atm_init) +module_exit(atm_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_blackhole.c b/net/sched/sch_blackhole.c new file mode 100644 index 000000000..a7f7667ae --- /dev/null +++ b/net/sched/sch_blackhole.c @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_blackhole.c Black hole queue + * + * Authors: Thomas Graf + * + * Note: Quantum tunneling is not supported. 
+ */ + +#include +#include +#include +#include +#include + +static int blackhole_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + qdisc_drop(skb, sch, to_free); + return NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; +} + +static struct sk_buff *blackhole_dequeue(struct Qdisc *sch) +{ + return NULL; +} + +static struct Qdisc_ops blackhole_qdisc_ops __read_mostly = { + .id = "blackhole", + .priv_size = 0, + .enqueue = blackhole_enqueue, + .dequeue = blackhole_dequeue, + .peek = blackhole_dequeue, + .owner = THIS_MODULE, +}; + +static int __init blackhole_init(void) +{ + return register_qdisc(&blackhole_qdisc_ops); +} +device_initcall(blackhole_init) diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c new file mode 100644 index 000000000..3ed0c3342 --- /dev/null +++ b/net/sched/sch_cake.c @@ -0,0 +1,3120 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause + +/* COMMON Applications Kept Enhanced (CAKE) discipline + * + * Copyright (C) 2014-2018 Jonathan Morton + * Copyright (C) 2015-2018 Toke Høiland-Jørgensen + * Copyright (C) 2014-2018 Dave Täht + * Copyright (C) 2015-2018 Sebastian Moeller + * (C) 2015-2018 Kevin Darbyshire-Bryant + * Copyright (C) 2017-2018 Ryan Mounce + * + * The CAKE Principles: + * (or, how to have your cake and eat it too) + * + * This is a combination of several shaping, AQM and FQ techniques into one + * easy-to-use package: + * + * - An overall bandwidth shaper, to move the bottleneck away from dumb CPE + * equipment and bloated MACs. This operates in deficit mode (as in sch_fq), + * eliminating the need for any sort of burst parameter (eg. token bucket + * depth). Burst support is limited to that necessary to overcome scheduling + * latency. + * + * - A Diffserv-aware priority queue, giving more priority to certain classes, + * up to a specified fraction of bandwidth. Above that bandwidth threshold, + * the priority is reduced to avoid starving other tins. + * + * - Each priority tin has a separate Flow Queue system, to isolate traffic + * flows from each other. This prevents a burst on one flow from increasing + * the delay to another. Flows are distributed to queues using a + * set-associative hash function. + * + * - Each queue is actively managed by Cobalt, which is a combination of the + * Codel and Blue AQM algorithms. This serves flows fairly, and signals + * congestion early via ECN (if available) and/or packet drops, to keep + * latency low. The codel parameters are auto-tuned based on the bandwidth + * setting, as is necessary at low bandwidths. + * + * The configuration parameters are kept deliberately simple for ease of use. + * Everything has sane defaults. Complete generality of configuration is *not* + * a goal. + * + * The priority queue operates according to a weighted DRR scheme, combined with + * a bandwidth tracker which reuses the shaper logic to detect which side of the + * bandwidth sharing threshold the tin is operating. This determines whether a + * priority-based weight (high) or a bandwidth-based weight (low) is used for + * that tin in the current pass. + * + * This qdisc was inspired by Eric Dumazet's fq_codel code, which he kindly + * granted us permission to leverage. 
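+ *
+ * A minimal usage sketch (assuming the iproute2 tc front end with
+ * CAKE support) would be something like:
+ *
+ *   tc qdisc replace dev eth0 root cake bandwidth 20mbit diffserv4 nat
+ *
+ * where only the bandwidth normally needs to be supplied; the other
+ * keywords select the four-tin Diffserv profile and NAT-aware flow
+ * hashing, respectively.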
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_NF_CONNTRACK) +#include +#endif + +#define CAKE_SET_WAYS (8) +#define CAKE_MAX_TINS (8) +#define CAKE_QUEUES (1024) +#define CAKE_FLOW_MASK 63 +#define CAKE_FLOW_NAT_FLAG 64 + +/* struct cobalt_params - contains codel and blue parameters + * @interval: codel initial drop rate + * @target: maximum persistent sojourn time & blue update rate + * @mtu_time: serialisation delay of maximum-size packet + * @p_inc: increment of blue drop probability (0.32 fxp) + * @p_dec: decrement of blue drop probability (0.32 fxp) + */ +struct cobalt_params { + u64 interval; + u64 target; + u64 mtu_time; + u32 p_inc; + u32 p_dec; +}; + +/* struct cobalt_vars - contains codel and blue variables + * @count: codel dropping frequency + * @rec_inv_sqrt: reciprocal value of sqrt(count) >> 1 + * @drop_next: time to drop next packet, or when we dropped last + * @blue_timer: Blue time to next drop + * @p_drop: BLUE drop probability (0.32 fxp) + * @dropping: set if in dropping state + * @ecn_marked: set if marked + */ +struct cobalt_vars { + u32 count; + u32 rec_inv_sqrt; + ktime_t drop_next; + ktime_t blue_timer; + u32 p_drop; + bool dropping; + bool ecn_marked; +}; + +enum { + CAKE_SET_NONE = 0, + CAKE_SET_SPARSE, + CAKE_SET_SPARSE_WAIT, /* counted in SPARSE, actually in BULK */ + CAKE_SET_BULK, + CAKE_SET_DECAYING +}; + +struct cake_flow { + /* this stuff is all needed per-flow at dequeue time */ + struct sk_buff *head; + struct sk_buff *tail; + struct list_head flowchain; + s32 deficit; + u32 dropped; + struct cobalt_vars cvars; + u16 srchost; /* index into cake_host table */ + u16 dsthost; + u8 set; +}; /* please try to keep this structure <= 64 bytes */ + +struct cake_host { + u32 srchost_tag; + u32 dsthost_tag; + u16 srchost_bulk_flow_count; + u16 dsthost_bulk_flow_count; +}; + +struct cake_heap_entry { + u16 t:3, b:10; +}; + +struct cake_tin_data { + struct cake_flow flows[CAKE_QUEUES]; + u32 backlogs[CAKE_QUEUES]; + u32 tags[CAKE_QUEUES]; /* for set association */ + u16 overflow_idx[CAKE_QUEUES]; + struct cake_host hosts[CAKE_QUEUES]; /* for triple isolation */ + u16 flow_quantum; + + struct cobalt_params cparams; + u32 drop_overlimit; + u16 bulk_flow_count; + u16 sparse_flow_count; + u16 decaying_flow_count; + u16 unresponsive_flow_count; + + u32 max_skblen; + + struct list_head new_flows; + struct list_head old_flows; + struct list_head decaying_flows; + + /* time_next = time_this + ((len * rate_ns) >> rate_shft) */ + ktime_t time_next_packet; + u64 tin_rate_ns; + u64 tin_rate_bps; + u16 tin_rate_shft; + + u16 tin_quantum; + s32 tin_deficit; + u32 tin_backlog; + u32 tin_dropped; + u32 tin_ecn_mark; + + u32 packets; + u64 bytes; + + u32 ack_drops; + + /* moving averages */ + u64 avge_delay; + u64 peak_delay; + u64 base_delay; + + /* hash function stats */ + u32 way_directs; + u32 way_hits; + u32 way_misses; + u32 way_collisions; +}; /* number of tins is small, so size of this struct doesn't matter much */ + +struct cake_sched_data { + struct tcf_proto __rcu *filter_list; /* optional external classifier */ + struct tcf_block *block; + struct cake_tin_data *tins; + + struct cake_heap_entry overflow_heap[CAKE_QUEUES * CAKE_MAX_TINS]; + u16 overflow_timeout; + + u16 tin_cnt; + u8 tin_mode; + u8 flow_mode; + u8 ack_filter; + u8 atm_mode; + + u32 fwmark_mask; + u16 fwmark_shft; + + /* time_next = 
time_this + ((len * rate_ns) >> rate_shft) */ + u16 rate_shft; + ktime_t time_next_packet; + ktime_t failsafe_next_packet; + u64 rate_ns; + u64 rate_bps; + u16 rate_flags; + s16 rate_overhead; + u16 rate_mpu; + u64 interval; + u64 target; + + /* resource tracking */ + u32 buffer_used; + u32 buffer_max_used; + u32 buffer_limit; + u32 buffer_config_limit; + + /* indices for dequeue */ + u16 cur_tin; + u16 cur_flow; + + struct qdisc_watchdog watchdog; + const u8 *tin_index; + const u8 *tin_order; + + /* bandwidth capacity estimate */ + ktime_t last_packet_time; + ktime_t avg_window_begin; + u64 avg_packet_interval; + u64 avg_window_bytes; + u64 avg_peak_bandwidth; + ktime_t last_reconfig_time; + + /* packet length stats */ + u32 avg_netoff; + u16 max_netlen; + u16 max_adjlen; + u16 min_netlen; + u16 min_adjlen; +}; + +enum { + CAKE_FLAG_OVERHEAD = BIT(0), + CAKE_FLAG_AUTORATE_INGRESS = BIT(1), + CAKE_FLAG_INGRESS = BIT(2), + CAKE_FLAG_WASH = BIT(3), + CAKE_FLAG_SPLIT_GSO = BIT(4) +}; + +/* COBALT operates the Codel and BLUE algorithms in parallel, in order to + * obtain the best features of each. Codel is excellent on flows which + * respond to congestion signals in a TCP-like way. BLUE is more effective on + * unresponsive flows. + */ + +struct cobalt_skb_cb { + ktime_t enqueue_time; + u32 adjusted_len; +}; + +static u64 us_to_ns(u64 us) +{ + return us * NSEC_PER_USEC; +} + +static struct cobalt_skb_cb *get_cobalt_cb(const struct sk_buff *skb) +{ + qdisc_cb_private_validate(skb, sizeof(struct cobalt_skb_cb)); + return (struct cobalt_skb_cb *)qdisc_skb_cb(skb)->data; +} + +static ktime_t cobalt_get_enqueue_time(const struct sk_buff *skb) +{ + return get_cobalt_cb(skb)->enqueue_time; +} + +static void cobalt_set_enqueue_time(struct sk_buff *skb, + ktime_t now) +{ + get_cobalt_cb(skb)->enqueue_time = now; +} + +static u16 quantum_div[CAKE_QUEUES + 1] = {0}; + +/* Diffserv lookup tables */ + +static const u8 precedence[] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, +}; + +static const u8 diffserv8[] = { + 2, 0, 1, 2, 4, 2, 2, 2, + 1, 2, 1, 2, 1, 2, 1, 2, + 5, 2, 4, 2, 4, 2, 4, 2, + 3, 2, 3, 2, 3, 2, 3, 2, + 6, 2, 3, 2, 3, 2, 3, 2, + 6, 2, 2, 2, 6, 2, 6, 2, + 7, 2, 2, 2, 2, 2, 2, 2, + 7, 2, 2, 2, 2, 2, 2, 2, +}; + +static const u8 diffserv4[] = { + 0, 1, 0, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 2, 0, 2, 0, 2, 0, + 2, 0, 2, 0, 2, 0, 2, 0, + 3, 0, 2, 0, 2, 0, 2, 0, + 3, 0, 0, 0, 3, 0, 3, 0, + 3, 0, 0, 0, 0, 0, 0, 0, + 3, 0, 0, 0, 0, 0, 0, 0, +}; + +static const u8 diffserv3[] = { + 0, 1, 0, 0, 2, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 2, 0, 2, 0, + 2, 0, 0, 0, 0, 0, 0, 0, + 2, 0, 0, 0, 0, 0, 0, 0, +}; + +static const u8 besteffort[] = { + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, +}; + +/* tin priority order for stats dumping */ + +static const u8 normal_order[] = {0, 1, 2, 3, 4, 5, 6, 7}; +static const u8 bulk_order[] = {1, 0, 2, 3}; + +#define REC_INV_SQRT_CACHE (16) +static u32 cobalt_rec_inv_sqrt_cache[REC_INV_SQRT_CACHE] = {0}; + +/* http://en.wikipedia.org/wiki/Methods_of_computing_square_roots + * new_invsqrt = (invsqrt / 2) * (3 - count * invsqrt^2) + * + * Here, 
invsqrt is a fixed point number (< 1.0), 32bit mantissa, aka Q0.32 + */ + +static void cobalt_newton_step(struct cobalt_vars *vars) +{ + u32 invsqrt, invsqrt2; + u64 val; + + invsqrt = vars->rec_inv_sqrt; + invsqrt2 = ((u64)invsqrt * invsqrt) >> 32; + val = (3LL << 32) - ((u64)vars->count * invsqrt2); + + val >>= 2; /* avoid overflow in following multiply */ + val = (val * invsqrt) >> (32 - 2 + 1); + + vars->rec_inv_sqrt = val; +} + +static void cobalt_invsqrt(struct cobalt_vars *vars) +{ + if (vars->count < REC_INV_SQRT_CACHE) + vars->rec_inv_sqrt = cobalt_rec_inv_sqrt_cache[vars->count]; + else + cobalt_newton_step(vars); +} + +/* There is a big difference in timing between the accurate values placed in + * the cache and the approximations given by a single Newton step for small + * count values, particularly when stepping from count 1 to 2 or vice versa. + * Above 16, a single Newton step gives sufficient accuracy in either + * direction, given the precision stored. + * + * The magnitude of the error when stepping up to count 2 is such as to give + * the value that *should* have been produced at count 4. + */ + +static void cobalt_cache_init(void) +{ + struct cobalt_vars v; + + memset(&v, 0, sizeof(v)); + v.rec_inv_sqrt = ~0U; + cobalt_rec_inv_sqrt_cache[0] = v.rec_inv_sqrt; + + for (v.count = 1; v.count < REC_INV_SQRT_CACHE; v.count++) { + cobalt_newton_step(&v); + cobalt_newton_step(&v); + cobalt_newton_step(&v); + cobalt_newton_step(&v); + + cobalt_rec_inv_sqrt_cache[v.count] = v.rec_inv_sqrt; + } +} + +static void cobalt_vars_init(struct cobalt_vars *vars) +{ + memset(vars, 0, sizeof(*vars)); + + if (!cobalt_rec_inv_sqrt_cache[0]) { + cobalt_cache_init(); + cobalt_rec_inv_sqrt_cache[0] = ~0; + } +} + +/* CoDel control_law is t + interval/sqrt(count) + * We maintain in rec_inv_sqrt the reciprocal value of sqrt(count) to avoid + * both sqrt() and divide operation. + */ +static ktime_t cobalt_control(ktime_t t, + u64 interval, + u32 rec_inv_sqrt) +{ + return ktime_add_ns(t, reciprocal_scale(interval, + rec_inv_sqrt)); +} + +/* Call this when a packet had to be dropped due to queue overflow. Returns + * true if the BLUE state was quiescent before but active after this call. + */ +static bool cobalt_queue_full(struct cobalt_vars *vars, + struct cobalt_params *p, + ktime_t now) +{ + bool up = false; + + if (ktime_to_ns(ktime_sub(now, vars->blue_timer)) > p->target) { + up = !vars->p_drop; + vars->p_drop += p->p_inc; + if (vars->p_drop < p->p_inc) + vars->p_drop = ~0; + vars->blue_timer = now; + } + vars->dropping = true; + vars->drop_next = now; + if (!vars->count) + vars->count = 1; + + return up; +} + +/* Call this when the queue was serviced but turned out to be empty. Returns + * true if the BLUE state was active before but quiescent after this call. + */ +static bool cobalt_queue_empty(struct cobalt_vars *vars, + struct cobalt_params *p, + ktime_t now) +{ + bool down = false; + + if (vars->p_drop && + ktime_to_ns(ktime_sub(now, vars->blue_timer)) > p->target) { + if (vars->p_drop < p->p_dec) + vars->p_drop = 0; + else + vars->p_drop -= p->p_dec; + vars->blue_timer = now; + down = !vars->p_drop; + } + vars->dropping = false; + + if (vars->count && ktime_to_ns(ktime_sub(now, vars->drop_next)) >= 0) { + vars->count--; + cobalt_invsqrt(vars); + vars->drop_next = cobalt_control(vars->drop_next, + p->interval, + vars->rec_inv_sqrt); + } + + return down; +} + +/* Call this with a freshly dequeued packet for possible congestion marking. 
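The Q0.32 Newton step above is easy to check outside the kernel. Below is a standalone sketch (editor's illustration; all names are local to the example) that mirrors the arithmetic of cobalt_newton_step() and the upward walk of cobalt_cache_init(), comparing the fixed-point result against 2^32/sqrt(count):

#include <math.h>
#include <stdint.h>
#include <stdio.h>

/* Same arithmetic as cobalt_newton_step(), on plain integers. */
static uint32_t newton_step(uint32_t invsqrt, uint32_t count)
{
	uint32_t invsqrt2 = ((uint64_t)invsqrt * invsqrt) >> 32;	/* invsqrt^2, Q0.32 */
	uint64_t val = (3ULL << 32) - ((uint64_t)count * invsqrt2);	/* 3 - count*invsqrt^2 */

	val >>= 2;				/* avoid overflow in the next multiply */
	val = (val * invsqrt) >> (32 - 2 + 1);	/* times invsqrt/2, back to Q0.32 */
	return (uint32_t)val;
}

int main(void)
{
	uint32_t inv = ~0U;	/* ~1.0 in Q0.32; exact answer for count == 1 */
	uint32_t count;

	/* Walk count upward with four steps per value, as cobalt_cache_init()
	 * does; starting each count from the previous result keeps Newton's
	 * method inside its region of convergence.
	 */
	for (count = 1; count <= 16; count++) {
		int i;

		for (i = 0; i < 4; i++)
			inv = newton_step(inv, count);

		printf("count %2u: fixed 0x%08x  ideal %.0f\n",
		       (unsigned)count, (unsigned)inv,
		       4294967296.0 / sqrt((double)count));
	}
	return 0;
}

Built with -lm, the fixed-point column tracks the ideal value to within roughly 0.1 per cent, which is why the cached entries and the single-step updates above give consistent results.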
+ * Returns true as an instruction to drop the packet, false for delivery. + */ +static bool cobalt_should_drop(struct cobalt_vars *vars, + struct cobalt_params *p, + ktime_t now, + struct sk_buff *skb, + u32 bulk_flows) +{ + bool next_due, over_target, drop = false; + ktime_t schedule; + u64 sojourn; + +/* The 'schedule' variable records, in its sign, whether 'now' is before or + * after 'drop_next'. This allows 'drop_next' to be updated before the next + * scheduling decision is actually branched, without destroying that + * information. Similarly, the first 'schedule' value calculated is preserved + * in the boolean 'next_due'. + * + * As for 'drop_next', we take advantage of the fact that 'interval' is both + * the delay between first exceeding 'target' and the first signalling event, + * *and* the scaling factor for the signalling frequency. It's therefore very + * natural to use a single mechanism for both purposes, and eliminates a + * significant amount of reference Codel's spaghetti code. To help with this, + * both the '0' and '1' entries in the invsqrt cache are 0xFFFFFFFF, as close + * as possible to 1.0 in fixed-point. + */ + + sojourn = ktime_to_ns(ktime_sub(now, cobalt_get_enqueue_time(skb))); + schedule = ktime_sub(now, vars->drop_next); + over_target = sojourn > p->target && + sojourn > p->mtu_time * bulk_flows * 2 && + sojourn > p->mtu_time * 4; + next_due = vars->count && ktime_to_ns(schedule) >= 0; + + vars->ecn_marked = false; + + if (over_target) { + if (!vars->dropping) { + vars->dropping = true; + vars->drop_next = cobalt_control(now, + p->interval, + vars->rec_inv_sqrt); + } + if (!vars->count) + vars->count = 1; + } else if (vars->dropping) { + vars->dropping = false; + } + + if (next_due && vars->dropping) { + /* Use ECN mark if possible, otherwise drop */ + drop = !(vars->ecn_marked = INET_ECN_set_ce(skb)); + + vars->count++; + if (!vars->count) + vars->count--; + cobalt_invsqrt(vars); + vars->drop_next = cobalt_control(vars->drop_next, + p->interval, + vars->rec_inv_sqrt); + schedule = ktime_sub(now, vars->drop_next); + } else { + while (next_due) { + vars->count--; + cobalt_invsqrt(vars); + vars->drop_next = cobalt_control(vars->drop_next, + p->interval, + vars->rec_inv_sqrt); + schedule = ktime_sub(now, vars->drop_next); + next_due = vars->count && ktime_to_ns(schedule) >= 0; + } + } + + /* Simple BLUE implementation. Lack of ECN is deliberate. */ + if (vars->p_drop) + drop |= (get_random_u32() < vars->p_drop); + + /* Overload the drop_next field as an activity timeout */ + if (!vars->count) + vars->drop_next = ktime_add_ns(now, p->interval); + else if (ktime_to_ns(schedule) > 0 && !drop) + vars->drop_next = now; + + return drop; +} + +static bool cake_update_flowkeys(struct flow_keys *keys, + const struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + struct nf_conntrack_tuple tuple = {}; + bool rev = !skb->_nfct, upd = false; + __be32 ip; + + if (skb_protocol(skb, true) != htons(ETH_P_IP)) + return false; + + if (!nf_ct_get_tuple_skb(&tuple, skb)) + return false; + + ip = rev ? tuple.dst.u3.ip : tuple.src.u3.ip; + if (ip != keys->addrs.v4addrs.src) { + keys->addrs.v4addrs.src = ip; + upd = true; + } + ip = rev ? tuple.src.u3.ip : tuple.dst.u3.ip; + if (ip != keys->addrs.v4addrs.dst) { + keys->addrs.v4addrs.dst = ip; + upd = true; + } + + if (keys->ports.ports) { + __be16 port; + + port = rev ? tuple.dst.u.all : tuple.src.u.all; + if (port != keys->ports.src) { + keys->ports.src = port; + upd = true; + } + port = rev ? 
tuple.src.u.all : tuple.dst.u.all; + if (port != keys->ports.dst) { + port = keys->ports.dst; + upd = true; + } + } + return upd; +#else + return false; +#endif +} + +/* Cake has several subtle multiple bit settings. In these cases you + * would be matching triple isolate mode as well. + */ + +static bool cake_dsrc(int flow_mode) +{ + return (flow_mode & CAKE_FLOW_DUAL_SRC) == CAKE_FLOW_DUAL_SRC; +} + +static bool cake_ddst(int flow_mode) +{ + return (flow_mode & CAKE_FLOW_DUAL_DST) == CAKE_FLOW_DUAL_DST; +} + +static u32 cake_hash(struct cake_tin_data *q, const struct sk_buff *skb, + int flow_mode, u16 flow_override, u16 host_override) +{ + bool hash_flows = (!flow_override && !!(flow_mode & CAKE_FLOW_FLOWS)); + bool hash_hosts = (!host_override && !!(flow_mode & CAKE_FLOW_HOSTS)); + bool nat_enabled = !!(flow_mode & CAKE_FLOW_NAT_FLAG); + u32 flow_hash = 0, srchost_hash = 0, dsthost_hash = 0; + u16 reduced_hash, srchost_idx, dsthost_idx; + struct flow_keys keys, host_keys; + bool use_skbhash = skb->l4_hash; + + if (unlikely(flow_mode == CAKE_FLOW_NONE)) + return 0; + + /* If both overrides are set, or we can use the SKB hash and nat mode is + * disabled, we can skip packet dissection entirely. If nat mode is + * enabled there's another check below after doing the conntrack lookup. + */ + if ((!hash_flows || (use_skbhash && !nat_enabled)) && !hash_hosts) + goto skip_hash; + + skb_flow_dissect_flow_keys(skb, &keys, + FLOW_DISSECTOR_F_STOP_AT_FLOW_LABEL); + + /* Don't use the SKB hash if we change the lookup keys from conntrack */ + if (nat_enabled && cake_update_flowkeys(&keys, skb)) + use_skbhash = false; + + /* If we can still use the SKB hash and don't need the host hash, we can + * skip the rest of the hashing procedure + */ + if (use_skbhash && !hash_hosts) + goto skip_hash; + + /* flow_hash_from_keys() sorts the addresses by value, so we have + * to preserve their order in a separate data structure to treat + * src and dst host addresses as independently selectable. + */ + host_keys = keys; + host_keys.ports.ports = 0; + host_keys.basic.ip_proto = 0; + host_keys.keyid.keyid = 0; + host_keys.tags.flow_label = 0; + + switch (host_keys.control.addr_type) { + case FLOW_DISSECTOR_KEY_IPV4_ADDRS: + host_keys.addrs.v4addrs.src = 0; + dsthost_hash = flow_hash_from_keys(&host_keys); + host_keys.addrs.v4addrs.src = keys.addrs.v4addrs.src; + host_keys.addrs.v4addrs.dst = 0; + srchost_hash = flow_hash_from_keys(&host_keys); + break; + + case FLOW_DISSECTOR_KEY_IPV6_ADDRS: + memset(&host_keys.addrs.v6addrs.src, 0, + sizeof(host_keys.addrs.v6addrs.src)); + dsthost_hash = flow_hash_from_keys(&host_keys); + host_keys.addrs.v6addrs.src = keys.addrs.v6addrs.src; + memset(&host_keys.addrs.v6addrs.dst, 0, + sizeof(host_keys.addrs.v6addrs.dst)); + srchost_hash = flow_hash_from_keys(&host_keys); + break; + + default: + dsthost_hash = 0; + srchost_hash = 0; + } + + /* This *must* be after the above switch, since as a + * side-effect it sorts the src and dst addresses. 
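The remainder of cake_hash() just below resolves collisions with an 8-way set-associative table: the reduced hash selects a set, the eight "ways" of that set are searched for a matching tag, and only then is a free way claimed. A toy user-space model of that lookup (editor's sketch; the activity check and collision handling are simplified, and all names are local to the example):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define QUEUES 1024
#define WAYS   8

static uint32_t tags[QUEUES];
static bool     used[QUEUES];

/* Return the queue reserved for flow_hash, claiming a free way in its set on
 * a miss; if the whole set is occupied, fall back to the original way.
 */
static uint32_t toy_lookup(uint32_t flow_hash)
{
	uint32_t reduced = flow_hash % QUEUES;
	uint32_t inner = reduced % WAYS;
	uint32_t outer = reduced - inner;	/* first slot of this set */
	uint32_t i, k = inner;

	for (i = 0; i < WAYS; i++, k = (k + 1) % WAYS)
		if (used[outer + k] && tags[outer + k] == flow_hash)
			return outer + k;	/* way hit */

	for (i = 0, k = inner; i < WAYS; i++, k = (k + 1) % WAYS)
		if (!used[outer + k])
			break;			/* way miss: free slot found */

	used[outer + k] = true;			/* claim it (or accept collision) */
	tags[outer + k] = flow_hash;
	return outer + k;
}

int main(void)
{
	/* prints "42 43 42": the second flow collides into the next way of the
	 * same set, repeat lookups of the first flow hit its reserved queue.
	 */
	printf("%u %u %u\n", toy_lookup(42), toy_lookup(42 + QUEUES), toy_lookup(42));
	return 0;
}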
+ */ + if (hash_flows && !use_skbhash) + flow_hash = flow_hash_from_keys(&keys); + +skip_hash: + if (flow_override) + flow_hash = flow_override - 1; + else if (use_skbhash && (flow_mode & CAKE_FLOW_FLOWS)) + flow_hash = skb->hash; + if (host_override) { + dsthost_hash = host_override - 1; + srchost_hash = host_override - 1; + } + + if (!(flow_mode & CAKE_FLOW_FLOWS)) { + if (flow_mode & CAKE_FLOW_SRC_IP) + flow_hash ^= srchost_hash; + + if (flow_mode & CAKE_FLOW_DST_IP) + flow_hash ^= dsthost_hash; + } + + reduced_hash = flow_hash % CAKE_QUEUES; + + /* set-associative hashing */ + /* fast path if no hash collision (direct lookup succeeds) */ + if (likely(q->tags[reduced_hash] == flow_hash && + q->flows[reduced_hash].set)) { + q->way_directs++; + } else { + u32 inner_hash = reduced_hash % CAKE_SET_WAYS; + u32 outer_hash = reduced_hash - inner_hash; + bool allocate_src = false; + bool allocate_dst = false; + u32 i, k; + + /* check if any active queue in the set is reserved for + * this flow. + */ + for (i = 0, k = inner_hash; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (q->tags[outer_hash + k] == flow_hash) { + if (i) + q->way_hits++; + + if (!q->flows[outer_hash + k].set) { + /* need to increment host refcnts */ + allocate_src = cake_dsrc(flow_mode); + allocate_dst = cake_ddst(flow_mode); + } + + goto found; + } + } + + /* no queue is reserved for this flow, look for an + * empty one. + */ + for (i = 0; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (!q->flows[outer_hash + k].set) { + q->way_misses++; + allocate_src = cake_dsrc(flow_mode); + allocate_dst = cake_ddst(flow_mode); + goto found; + } + } + + /* With no empty queues, default to the original + * queue, accept the collision, update the host tags. + */ + q->way_collisions++; + if (q->flows[outer_hash + k].set == CAKE_SET_BULK) { + q->hosts[q->flows[reduced_hash].srchost].srchost_bulk_flow_count--; + q->hosts[q->flows[reduced_hash].dsthost].dsthost_bulk_flow_count--; + } + allocate_src = cake_dsrc(flow_mode); + allocate_dst = cake_ddst(flow_mode); +found: + /* reserve queue for future packets in same flow */ + reduced_hash = outer_hash + k; + q->tags[reduced_hash] = flow_hash; + + if (allocate_src) { + srchost_idx = srchost_hash % CAKE_QUEUES; + inner_hash = srchost_idx % CAKE_SET_WAYS; + outer_hash = srchost_idx - inner_hash; + for (i = 0, k = inner_hash; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (q->hosts[outer_hash + k].srchost_tag == + srchost_hash) + goto found_src; + } + for (i = 0; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (!q->hosts[outer_hash + k].srchost_bulk_flow_count) + break; + } + q->hosts[outer_hash + k].srchost_tag = srchost_hash; +found_src: + srchost_idx = outer_hash + k; + if (q->flows[reduced_hash].set == CAKE_SET_BULK) + q->hosts[srchost_idx].srchost_bulk_flow_count++; + q->flows[reduced_hash].srchost = srchost_idx; + } + + if (allocate_dst) { + dsthost_idx = dsthost_hash % CAKE_QUEUES; + inner_hash = dsthost_idx % CAKE_SET_WAYS; + outer_hash = dsthost_idx - inner_hash; + for (i = 0, k = inner_hash; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (q->hosts[outer_hash + k].dsthost_tag == + dsthost_hash) + goto found_dst; + } + for (i = 0; i < CAKE_SET_WAYS; + i++, k = (k + 1) % CAKE_SET_WAYS) { + if (!q->hosts[outer_hash + k].dsthost_bulk_flow_count) + break; + } + q->hosts[outer_hash + k].dsthost_tag = dsthost_hash; +found_dst: + dsthost_idx = outer_hash + k; + if (q->flows[reduced_hash].set == CAKE_SET_BULK) + 
q->hosts[dsthost_idx].dsthost_bulk_flow_count++; + q->flows[reduced_hash].dsthost = dsthost_idx; + } + } + + return reduced_hash; +} + +/* helper functions : might be changed when/if skb use a standard list_head */ +/* remove one skb from head of slot queue */ + +static struct sk_buff *dequeue_head(struct cake_flow *flow) +{ + struct sk_buff *skb = flow->head; + + if (skb) { + flow->head = skb->next; + skb_mark_not_on_list(skb); + } + + return skb; +} + +/* add skb to flow queue (tail add) */ + +static void flow_queue_add(struct cake_flow *flow, struct sk_buff *skb) +{ + if (!flow->head) + flow->head = skb; + else + flow->tail->next = skb; + flow->tail = skb; + skb->next = NULL; +} + +static struct iphdr *cake_get_iphdr(const struct sk_buff *skb, + struct ipv6hdr *buf) +{ + unsigned int offset = skb_network_offset(skb); + struct iphdr *iph; + + iph = skb_header_pointer(skb, offset, sizeof(struct iphdr), buf); + + if (!iph) + return NULL; + + if (iph->version == 4 && iph->protocol == IPPROTO_IPV6) + return skb_header_pointer(skb, offset + iph->ihl * 4, + sizeof(struct ipv6hdr), buf); + + else if (iph->version == 4) + return iph; + + else if (iph->version == 6) + return skb_header_pointer(skb, offset, sizeof(struct ipv6hdr), + buf); + + return NULL; +} + +static struct tcphdr *cake_get_tcphdr(const struct sk_buff *skb, + void *buf, unsigned int bufsize) +{ + unsigned int offset = skb_network_offset(skb); + const struct ipv6hdr *ipv6h; + const struct tcphdr *tcph; + const struct iphdr *iph; + struct ipv6hdr _ipv6h; + struct tcphdr _tcph; + + ipv6h = skb_header_pointer(skb, offset, sizeof(_ipv6h), &_ipv6h); + + if (!ipv6h) + return NULL; + + if (ipv6h->version == 4) { + iph = (struct iphdr *)ipv6h; + offset += iph->ihl * 4; + + /* special-case 6in4 tunnelling, as that is a common way to get + * v6 connectivity in the home + */ + if (iph->protocol == IPPROTO_IPV6) { + ipv6h = skb_header_pointer(skb, offset, + sizeof(_ipv6h), &_ipv6h); + + if (!ipv6h || ipv6h->nexthdr != IPPROTO_TCP) + return NULL; + + offset += sizeof(struct ipv6hdr); + + } else if (iph->protocol != IPPROTO_TCP) { + return NULL; + } + + } else if (ipv6h->version == 6) { + if (ipv6h->nexthdr != IPPROTO_TCP) + return NULL; + + offset += sizeof(struct ipv6hdr); + } else { + return NULL; + } + + tcph = skb_header_pointer(skb, offset, sizeof(_tcph), &_tcph); + if (!tcph || tcph->doff < 5) + return NULL; + + return skb_header_pointer(skb, offset, + min(__tcp_hdrlen(tcph), bufsize), buf); +} + +static const void *cake_get_tcpopt(const struct tcphdr *tcph, + int code, int *oplen) +{ + /* inspired by tcp_parse_options in tcp_input.c */ + int length = __tcp_hdrlen(tcph) - sizeof(struct tcphdr); + const u8 *ptr = (const u8 *)(tcph + 1); + + while (length > 0) { + int opcode = *ptr++; + int opsize; + + if (opcode == TCPOPT_EOL) + break; + if (opcode == TCPOPT_NOP) { + length--; + continue; + } + if (length < 2) + break; + opsize = *ptr++; + if (opsize < 2 || opsize > length) + break; + + if (opcode == code) { + *oplen = opsize; + return ptr; + } + + ptr += opsize - 2; + length -= opsize; + } + + return NULL; +} + +/* Compare two SACK sequences. A sequence is considered greater if it SACKs more + * bytes than the other. In the case where both sequences ACKs bytes that the + * other doesn't, A is considered greater. DSACKs in A also makes A be + * considered greater. 
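cake_get_tcpopt() above walks the TCP option TLVs by hand. The following standalone sketch (editor's illustration) performs the same walk over a plain byte buffer, looking for the timestamp option (kind 8, length 10); the constants and the sample buffer are invented for the example:

#include <stdio.h>

#define OPT_EOL 0
#define OPT_NOP 1

/* Return a pointer to the payload of option 'code', or NULL if absent. */
static const unsigned char *find_opt(const unsigned char *ptr, int length,
				     int code, int *oplen)
{
	while (length > 0) {
		int opcode = *ptr++;
		int opsize;

		if (opcode == OPT_EOL)
			break;
		if (opcode == OPT_NOP) {	/* one-byte padding option */
			length--;
			continue;
		}
		if (length < 2)
			break;
		opsize = *ptr++;
		if (opsize < 2 || opsize > length)
			break;			/* malformed option: stop parsing */
		if (opcode == code) {
			*oplen = opsize;
			return ptr;		/* points at the option payload */
		}
		ptr += opsize - 2;
		length -= opsize;
	}
	return NULL;
}

int main(void)
{
	/* NOP, NOP, timestamp option (kind 8, len 10) with an 8-byte payload */
	const unsigned char opts[] = { 1, 1, 8, 10, 0, 0, 0, 1, 0, 0, 0, 2 };
	int len = 0;
	const unsigned char *ts = find_opt(opts, sizeof(opts), 8, &len);

	printf("timestamp option: %s, length %d\n", ts ? "found" : "absent", len);
	return 0;
}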
+ * + * @return -1, 0 or 1 as normal compare functions + */ +static int cake_tcph_sack_compare(const struct tcphdr *tcph_a, + const struct tcphdr *tcph_b) +{ + const struct tcp_sack_block_wire *sack_a, *sack_b; + u32 ack_seq_a = ntohl(tcph_a->ack_seq); + u32 bytes_a = 0, bytes_b = 0; + int oplen_a, oplen_b; + bool first = true; + + sack_a = cake_get_tcpopt(tcph_a, TCPOPT_SACK, &oplen_a); + sack_b = cake_get_tcpopt(tcph_b, TCPOPT_SACK, &oplen_b); + + /* pointers point to option contents */ + oplen_a -= TCPOLEN_SACK_BASE; + oplen_b -= TCPOLEN_SACK_BASE; + + if (sack_a && oplen_a >= sizeof(*sack_a) && + (!sack_b || oplen_b < sizeof(*sack_b))) + return -1; + else if (sack_b && oplen_b >= sizeof(*sack_b) && + (!sack_a || oplen_a < sizeof(*sack_a))) + return 1; + else if ((!sack_a || oplen_a < sizeof(*sack_a)) && + (!sack_b || oplen_b < sizeof(*sack_b))) + return 0; + + while (oplen_a >= sizeof(*sack_a)) { + const struct tcp_sack_block_wire *sack_tmp = sack_b; + u32 start_a = get_unaligned_be32(&sack_a->start_seq); + u32 end_a = get_unaligned_be32(&sack_a->end_seq); + int oplen_tmp = oplen_b; + bool found = false; + + /* DSACK; always considered greater to prevent dropping */ + if (before(start_a, ack_seq_a)) + return -1; + + bytes_a += end_a - start_a; + + while (oplen_tmp >= sizeof(*sack_tmp)) { + u32 start_b = get_unaligned_be32(&sack_tmp->start_seq); + u32 end_b = get_unaligned_be32(&sack_tmp->end_seq); + + /* first time through we count the total size */ + if (first) + bytes_b += end_b - start_b; + + if (!after(start_b, start_a) && !before(end_b, end_a)) { + found = true; + if (!first) + break; + } + oplen_tmp -= sizeof(*sack_tmp); + sack_tmp++; + } + + if (!found) + return -1; + + oplen_a -= sizeof(*sack_a); + sack_a++; + first = false; + } + + /* If we made it this far, all ranges SACKed by A are covered by B, so + * either the SACKs are equal, or B SACKs more bytes. + */ + return bytes_b > bytes_a ? 
1 : 0; +} + +static void cake_tcph_get_tstamp(const struct tcphdr *tcph, + u32 *tsval, u32 *tsecr) +{ + const u8 *ptr; + int opsize; + + ptr = cake_get_tcpopt(tcph, TCPOPT_TIMESTAMP, &opsize); + + if (ptr && opsize == TCPOLEN_TIMESTAMP) { + *tsval = get_unaligned_be32(ptr); + *tsecr = get_unaligned_be32(ptr + 4); + } +} + +static bool cake_tcph_may_drop(const struct tcphdr *tcph, + u32 tstamp_new, u32 tsecr_new) +{ + /* inspired by tcp_parse_options in tcp_input.c */ + int length = __tcp_hdrlen(tcph) - sizeof(struct tcphdr); + const u8 *ptr = (const u8 *)(tcph + 1); + u32 tstamp, tsecr; + + /* 3 reserved flags must be unset to avoid future breakage + * ACK must be set + * ECE/CWR are handled separately + * All other flags URG/PSH/RST/SYN/FIN must be unset + * 0x0FFF0000 = all TCP flags (confirm ACK=1, others zero) + * 0x00C00000 = CWR/ECE (handled separately) + * 0x0F3F0000 = 0x0FFF0000 & ~0x00C00000 + */ + if (((tcp_flag_word(tcph) & + cpu_to_be32(0x0F3F0000)) != TCP_FLAG_ACK)) + return false; + + while (length > 0) { + int opcode = *ptr++; + int opsize; + + if (opcode == TCPOPT_EOL) + break; + if (opcode == TCPOPT_NOP) { + length--; + continue; + } + if (length < 2) + break; + opsize = *ptr++; + if (opsize < 2 || opsize > length) + break; + + switch (opcode) { + case TCPOPT_MD5SIG: /* doesn't influence state */ + break; + + case TCPOPT_SACK: /* stricter checking performed later */ + if (opsize % 8 != 2) + return false; + break; + + case TCPOPT_TIMESTAMP: + /* only drop timestamps lower than new */ + if (opsize != TCPOLEN_TIMESTAMP) + return false; + tstamp = get_unaligned_be32(ptr); + tsecr = get_unaligned_be32(ptr + 4); + if (after(tstamp, tstamp_new) || + after(tsecr, tsecr_new)) + return false; + break; + + case TCPOPT_MSS: /* these should only be set on SYN */ + case TCPOPT_WINDOW: + case TCPOPT_SACK_PERM: + case TCPOPT_FASTOPEN: + case TCPOPT_EXP: + default: /* don't drop if any unknown options are present */ + return false; + } + + ptr += opsize - 2; + length -= opsize; + } + + return true; +} + +static struct sk_buff *cake_ack_filter(struct cake_sched_data *q, + struct cake_flow *flow) +{ + bool aggressive = q->ack_filter == CAKE_ACK_AGGRESSIVE; + struct sk_buff *elig_ack = NULL, *elig_ack_prev = NULL; + struct sk_buff *skb_check, *skb_prev = NULL; + const struct ipv6hdr *ipv6h, *ipv6h_check; + unsigned char _tcph[64], _tcph_check[64]; + const struct tcphdr *tcph, *tcph_check; + const struct iphdr *iph, *iph_check; + struct ipv6hdr _iph, _iph_check; + const struct sk_buff *skb; + int seglen, num_found = 0; + u32 tstamp = 0, tsecr = 0; + __be32 elig_flags = 0; + int sack_comp; + + /* no other possible ACKs to filter */ + if (flow->head == flow->tail) + return NULL; + + skb = flow->tail; + tcph = cake_get_tcphdr(skb, _tcph, sizeof(_tcph)); + iph = cake_get_iphdr(skb, &_iph); + if (!tcph) + return NULL; + + cake_tcph_get_tstamp(tcph, &tstamp, &tsecr); + + /* the 'triggering' packet need only have the ACK flag set. + * also check that SYN is not set, as there won't be any previous ACKs. + */ + if ((tcp_flag_word(tcph) & + (TCP_FLAG_ACK | TCP_FLAG_SYN)) != TCP_FLAG_ACK) + return NULL; + + /* the 'triggering' ACK is at the tail of the queue, we have already + * returned if it is the only packet in the flow. loop through the rest + * of the queue looking for pure ACKs with the same 5-tuple as the + * triggering one. 
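The before()/after() checks used by cake_tcph_sack_compare() and cake_tcph_may_drop() above compare 32-bit TCP sequence numbers modulo 2^32. A minimal user-space equivalent of that signed-difference trick (editor's sketch):

#include <stdint.h>
#include <stdio.h>

static int seq_before(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;	/* true if a precedes b, wrap-safe */
}

static int seq_after(uint32_t a, uint32_t b)
{
	return seq_before(b, a);
}

int main(void)
{
	/* 0xFFFFFFF0 precedes 0x00000010 even though it is numerically larger */
	printf("%d %d\n", seq_before(0xFFFFFFF0u, 0x10u), seq_after(0x10u, 0xFFFFFFF0u));
	return 0;
}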
+ */ + for (skb_check = flow->head; + skb_check && skb_check != skb; + skb_prev = skb_check, skb_check = skb_check->next) { + iph_check = cake_get_iphdr(skb_check, &_iph_check); + tcph_check = cake_get_tcphdr(skb_check, &_tcph_check, + sizeof(_tcph_check)); + + /* only TCP packets with matching 5-tuple are eligible, and only + * drop safe headers + */ + if (!tcph_check || iph->version != iph_check->version || + tcph_check->source != tcph->source || + tcph_check->dest != tcph->dest) + continue; + + if (iph_check->version == 4) { + if (iph_check->saddr != iph->saddr || + iph_check->daddr != iph->daddr) + continue; + + seglen = ntohs(iph_check->tot_len) - + (4 * iph_check->ihl); + } else if (iph_check->version == 6) { + ipv6h = (struct ipv6hdr *)iph; + ipv6h_check = (struct ipv6hdr *)iph_check; + + if (ipv6_addr_cmp(&ipv6h_check->saddr, &ipv6h->saddr) || + ipv6_addr_cmp(&ipv6h_check->daddr, &ipv6h->daddr)) + continue; + + seglen = ntohs(ipv6h_check->payload_len); + } else { + WARN_ON(1); /* shouldn't happen */ + continue; + } + + /* If the ECE/CWR flags changed from the previous eligible + * packet in the same flow, we should no longer be dropping that + * previous packet as this would lose information. + */ + if (elig_ack && (tcp_flag_word(tcph_check) & + (TCP_FLAG_ECE | TCP_FLAG_CWR)) != elig_flags) { + elig_ack = NULL; + elig_ack_prev = NULL; + num_found--; + } + + /* Check TCP options and flags, don't drop ACKs with segment + * data, and don't drop ACKs with a higher cumulative ACK + * counter than the triggering packet. Check ACK seqno here to + * avoid parsing SACK options of packets we are going to exclude + * anyway. + */ + if (!cake_tcph_may_drop(tcph_check, tstamp, tsecr) || + (seglen - __tcp_hdrlen(tcph_check)) != 0 || + after(ntohl(tcph_check->ack_seq), ntohl(tcph->ack_seq))) + continue; + + /* Check SACK options. The triggering packet must SACK more data + * than the ACK under consideration, or SACK the same range but + * have a larger cumulative ACK counter. The latter is a + * pathological case, but is contained in the following check + * anyway, just to be safe. + */ + sack_comp = cake_tcph_sack_compare(tcph_check, tcph); + + if (sack_comp < 0 || + (ntohl(tcph_check->ack_seq) == ntohl(tcph->ack_seq) && + sack_comp == 0)) + continue; + + /* At this point we have found an eligible pure ACK to drop; if + * we are in aggressive mode, we are done. Otherwise, keep + * searching unless this is the second eligible ACK we + * found. + * + * Since we want to drop ACK closest to the head of the queue, + * save the first eligible ACK we find, even if we need to loop + * again. + */ + if (!elig_ack) { + elig_ack = skb_check; + elig_ack_prev = skb_prev; + elig_flags = (tcp_flag_word(tcph_check) + & (TCP_FLAG_ECE | TCP_FLAG_CWR)); + } + + if (num_found++ > 0) + goto found; + } + + /* We made it through the queue without finding two eligible ACKs . If + * we found a single eligible ACK we can drop it in aggressive mode if + * we can guarantee that this does not interfere with ECN flag + * information. We ensure this by dropping it only if the enqueued + * packet is consecutive with the eligible ACK, and their flags match. 
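cake_calc_overhead(), a little further down, compensates for link-layer framing: in ATM mode each packet is rounded up to whole 53-byte cells carrying 48 payload bytes, and in PTM mode one byte is added per started 64-byte block, which the source notes is a conservative approximation. A sketch of just that rounding (editor's illustration; packet sizes are arbitrary examples and the per-packet overhead/MPU adjustments are omitted):

#include <stdio.h>

static unsigned int atm_len(unsigned int len)
{
	return (len + 47) / 48 * 53;	/* whole cells, 53 bytes on the wire each */
}

static unsigned int ptm_len(unsigned int len)
{
	return len + (len + 63) / 64;	/* one extra byte per started 64-byte block */
}

int main(void)
{
	unsigned int sizes[] = { 64, 128, 1500 };
	unsigned int i;

	for (i = 0; i < 3; i++)
		printf("%4u bytes -> ATM %4u, PTM %4u\n",
		       sizes[i], atm_len(sizes[i]), ptm_len(sizes[i]));
	return 0;
}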
+ */ + if (elig_ack && aggressive && elig_ack->next == skb && + (elig_flags == (tcp_flag_word(tcph) & + (TCP_FLAG_ECE | TCP_FLAG_CWR)))) + goto found; + + return NULL; + +found: + if (elig_ack_prev) + elig_ack_prev->next = elig_ack->next; + else + flow->head = elig_ack->next; + + skb_mark_not_on_list(elig_ack); + + return elig_ack; +} + +static u64 cake_ewma(u64 avg, u64 sample, u32 shift) +{ + avg -= avg >> shift; + avg += sample >> shift; + return avg; +} + +static u32 cake_calc_overhead(struct cake_sched_data *q, u32 len, u32 off) +{ + if (q->rate_flags & CAKE_FLAG_OVERHEAD) + len -= off; + + if (q->max_netlen < len) + q->max_netlen = len; + if (q->min_netlen > len) + q->min_netlen = len; + + len += q->rate_overhead; + + if (len < q->rate_mpu) + len = q->rate_mpu; + + if (q->atm_mode == CAKE_ATM_ATM) { + len += 47; + len /= 48; + len *= 53; + } else if (q->atm_mode == CAKE_ATM_PTM) { + /* Add one byte per 64 bytes or part thereof. + * This is conservative and easier to calculate than the + * precise value. + */ + len += (len + 63) / 64; + } + + if (q->max_adjlen < len) + q->max_adjlen = len; + if (q->min_adjlen > len) + q->min_adjlen = len; + + return len; +} + +static u32 cake_overhead(struct cake_sched_data *q, const struct sk_buff *skb) +{ + const struct skb_shared_info *shinfo = skb_shinfo(skb); + unsigned int hdr_len, last_len = 0; + u32 off = skb_network_offset(skb); + u32 len = qdisc_pkt_len(skb); + u16 segs = 1; + + q->avg_netoff = cake_ewma(q->avg_netoff, off << 16, 8); + + if (!shinfo->gso_size) + return cake_calc_overhead(q, len, off); + + /* borrowed from qdisc_pkt_len_init() */ + hdr_len = skb_transport_header(skb) - skb_mac_header(skb); + + /* + transport layer */ + if (likely(shinfo->gso_type & (SKB_GSO_TCPV4 | + SKB_GSO_TCPV6))) { + const struct tcphdr *th; + struct tcphdr _tcphdr; + + th = skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_tcphdr), &_tcphdr); + if (likely(th)) + hdr_len += __tcp_hdrlen(th); + } else { + struct udphdr _udphdr; + + if (skb_header_pointer(skb, skb_transport_offset(skb), + sizeof(_udphdr), &_udphdr)) + hdr_len += sizeof(struct udphdr); + } + + if (unlikely(shinfo->gso_type & SKB_GSO_DODGY)) + segs = DIV_ROUND_UP(skb->len - hdr_len, + shinfo->gso_size); + else + segs = shinfo->gso_segs; + + len = shinfo->gso_size + hdr_len; + last_len = skb->len - shinfo->gso_size * (segs - 1); + + return (cake_calc_overhead(q, len, off) * (segs - 1) + + cake_calc_overhead(q, last_len, off)); +} + +static void cake_heap_swap(struct cake_sched_data *q, u16 i, u16 j) +{ + struct cake_heap_entry ii = q->overflow_heap[i]; + struct cake_heap_entry jj = q->overflow_heap[j]; + + q->overflow_heap[i] = jj; + q->overflow_heap[j] = ii; + + q->tins[ii.t].overflow_idx[ii.b] = j; + q->tins[jj.t].overflow_idx[jj.b] = i; +} + +static u32 cake_heap_get_backlog(const struct cake_sched_data *q, u16 i) +{ + struct cake_heap_entry ii = q->overflow_heap[i]; + + return q->tins[ii.t].backlogs[ii.b]; +} + +static void cake_heapify(struct cake_sched_data *q, u16 i) +{ + static const u32 a = CAKE_MAX_TINS * CAKE_QUEUES; + u32 mb = cake_heap_get_backlog(q, i); + u32 m = i; + + while (m < a) { + u32 l = m + m + 1; + u32 r = l + 1; + + if (l < a) { + u32 lb = cake_heap_get_backlog(q, l); + + if (lb > mb) { + m = l; + mb = lb; + } + } + + if (r < a) { + u32 rb = cake_heap_get_backlog(q, r); + + if (rb > mb) { + m = r; + mb = rb; + } + } + + if (m != i) { + cake_heap_swap(q, i, m); + i = m; + } else { + break; + } + } +} + +static void cake_heapify_up(struct cake_sched_data *q, 
u16 i) +{ + while (i > 0 && i < CAKE_MAX_TINS * CAKE_QUEUES) { + u16 p = (i - 1) >> 1; + u32 ib = cake_heap_get_backlog(q, i); + u32 pb = cake_heap_get_backlog(q, p); + + if (ib > pb) { + cake_heap_swap(q, i, p); + i = p; + } else { + break; + } + } +} + +static int cake_advance_shaper(struct cake_sched_data *q, + struct cake_tin_data *b, + struct sk_buff *skb, + ktime_t now, bool drop) +{ + u32 len = get_cobalt_cb(skb)->adjusted_len; + + /* charge packet bandwidth to this tin + * and to the global shaper. + */ + if (q->rate_ns) { + u64 tin_dur = (len * b->tin_rate_ns) >> b->tin_rate_shft; + u64 global_dur = (len * q->rate_ns) >> q->rate_shft; + u64 failsafe_dur = global_dur + (global_dur >> 1); + + if (ktime_before(b->time_next_packet, now)) + b->time_next_packet = ktime_add_ns(b->time_next_packet, + tin_dur); + + else if (ktime_before(b->time_next_packet, + ktime_add_ns(now, tin_dur))) + b->time_next_packet = ktime_add_ns(now, tin_dur); + + q->time_next_packet = ktime_add_ns(q->time_next_packet, + global_dur); + if (!drop) + q->failsafe_next_packet = \ + ktime_add_ns(q->failsafe_next_packet, + failsafe_dur); + } + return len; +} + +static unsigned int cake_drop(struct Qdisc *sch, struct sk_buff **to_free) +{ + struct cake_sched_data *q = qdisc_priv(sch); + ktime_t now = ktime_get(); + u32 idx = 0, tin = 0, len; + struct cake_heap_entry qq; + struct cake_tin_data *b; + struct cake_flow *flow; + struct sk_buff *skb; + + if (!q->overflow_timeout) { + int i; + /* Build fresh max-heap */ + for (i = CAKE_MAX_TINS * CAKE_QUEUES / 2; i >= 0; i--) + cake_heapify(q, i); + } + q->overflow_timeout = 65535; + + /* select longest queue for pruning */ + qq = q->overflow_heap[0]; + tin = qq.t; + idx = qq.b; + + b = &q->tins[tin]; + flow = &b->flows[idx]; + skb = dequeue_head(flow); + if (unlikely(!skb)) { + /* heap has gone wrong, rebuild it next time */ + q->overflow_timeout = 0; + return idx + (tin << 16); + } + + if (cobalt_queue_full(&flow->cvars, &b->cparams, now)) + b->unresponsive_flow_count++; + + len = qdisc_pkt_len(skb); + q->buffer_used -= skb->truesize; + b->backlogs[idx] -= len; + b->tin_backlog -= len; + sch->qstats.backlog -= len; + qdisc_tree_reduce_backlog(sch, 1, len); + + flow->dropped++; + b->tin_dropped++; + sch->qstats.drops++; + + if (q->rate_flags & CAKE_FLAG_INGRESS) + cake_advance_shaper(q, b, skb, now, true); + + __qdisc_drop(skb, to_free); + sch->q.qlen--; + + cake_heapify(q, 0); + + return idx + (tin << 16); +} + +static u8 cake_handle_diffserv(struct sk_buff *skb, bool wash) +{ + const int offset = skb_network_offset(skb); + u16 *buf, buf_; + u8 dscp; + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_); + if (unlikely(!buf)) + return 0; + + /* ToS is in the second byte of iphdr */ + dscp = ipv4_get_dsfield((struct iphdr *)buf) >> 2; + + if (wash && dscp) { + const int wlen = offset + sizeof(struct iphdr); + + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + return 0; + + ipv4_change_dsfield(ip_hdr(skb), INET_ECN_MASK, 0); + } + + return dscp; + + case htons(ETH_P_IPV6): + buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_); + if (unlikely(!buf)) + return 0; + + /* Traffic class is in the first and second bytes of ipv6hdr */ + dscp = ipv6_get_dsfield((struct ipv6hdr *)buf) >> 2; + + if (wash && dscp) { + const int wlen = offset + sizeof(struct ipv6hdr); + + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + return 0; + + ipv6_change_dsfield(ipv6_hdr(skb), 
INET_ECN_MASK, 0); + } + + return dscp; + + case htons(ETH_P_ARP): + return 0x38; /* CS7 - Net Control */ + + default: + /* If there is no Diffserv field, treat as best-effort */ + return 0; + } +} + +static struct cake_tin_data *cake_select_tin(struct Qdisc *sch, + struct sk_buff *skb) +{ + struct cake_sched_data *q = qdisc_priv(sch); + u32 tin, mark; + bool wash; + u8 dscp; + + /* Tin selection: Default to diffserv-based selection, allow overriding + * using firewall marks or skb->priority. Call DSCP parsing early if + * wash is enabled, otherwise defer to below to skip unneeded parsing. + */ + mark = (skb->mark & q->fwmark_mask) >> q->fwmark_shft; + wash = !!(q->rate_flags & CAKE_FLAG_WASH); + if (wash) + dscp = cake_handle_diffserv(skb, wash); + + if (q->tin_mode == CAKE_DIFFSERV_BESTEFFORT) + tin = 0; + + else if (mark && mark <= q->tin_cnt) + tin = q->tin_order[mark - 1]; + + else if (TC_H_MAJ(skb->priority) == sch->handle && + TC_H_MIN(skb->priority) > 0 && + TC_H_MIN(skb->priority) <= q->tin_cnt) + tin = q->tin_order[TC_H_MIN(skb->priority) - 1]; + + else { + if (!wash) + dscp = cake_handle_diffserv(skb, wash); + tin = q->tin_index[dscp]; + + if (unlikely(tin >= q->tin_cnt)) + tin = 0; + } + + return &q->tins[tin]; +} + +static u32 cake_classify(struct Qdisc *sch, struct cake_tin_data **t, + struct sk_buff *skb, int flow_mode, int *qerr) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct tcf_proto *filter; + struct tcf_result res; + u16 flow = 0, host = 0; + int result; + + filter = rcu_dereference_bh(q->filter_list); + if (!filter) + goto hash; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + result = tcf_classify(skb, NULL, filter, &res, false); + + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return 0; + } +#endif + if (TC_H_MIN(res.classid) <= CAKE_QUEUES) + flow = TC_H_MIN(res.classid); + if (TC_H_MAJ(res.classid) <= (CAKE_QUEUES << 16)) + host = TC_H_MAJ(res.classid) >> 16; + } +hash: + *t = cake_select_tin(sch, skb); + return cake_hash(*t, skb, flow_mode, flow, host) + 1; +} + +static void cake_reconfigure(struct Qdisc *sch); + +static s32 cake_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct cake_sched_data *q = qdisc_priv(sch); + int len = qdisc_pkt_len(skb); + int ret; + struct sk_buff *ack = NULL; + ktime_t now = ktime_get(); + struct cake_tin_data *b; + struct cake_flow *flow; + u32 idx; + + /* choose flow to insert into */ + idx = cake_classify(sch, &b, skb, q->flow_mode, &ret); + if (idx == 0) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } + idx--; + flow = &b->flows[idx]; + + /* ensure shaper state isn't stale */ + if (!b->tin_backlog) { + if (ktime_before(b->time_next_packet, now)) + b->time_next_packet = now; + + if (!sch->q.qlen) { + if (ktime_before(q->time_next_packet, now)) { + q->failsafe_next_packet = now; + q->time_next_packet = now; + } else if (ktime_after(q->time_next_packet, now) && + ktime_after(q->failsafe_next_packet, now)) { + u64 next = \ + min(ktime_to_ns(q->time_next_packet), + ktime_to_ns( + q->failsafe_next_packet)); + sch->qstats.overlimits++; + qdisc_watchdog_schedule_ns(&q->watchdog, next); + } + } + } + + if (unlikely(len > b->max_skblen)) + b->max_skblen = len; + + if (skb_is_gso(skb) && q->rate_flags & CAKE_FLAG_SPLIT_GSO) { + struct sk_buff *segs, *nskb; + 
netdev_features_t features = netif_skb_features(skb); + unsigned int slen = 0, numsegs = 0; + + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + if (IS_ERR_OR_NULL(segs)) + return qdisc_drop(skb, sch, to_free); + + skb_list_walk_safe(segs, segs, nskb) { + skb_mark_not_on_list(segs); + qdisc_skb_cb(segs)->pkt_len = segs->len; + cobalt_set_enqueue_time(segs, now); + get_cobalt_cb(segs)->adjusted_len = cake_overhead(q, + segs); + flow_queue_add(flow, segs); + + sch->q.qlen++; + numsegs++; + slen += segs->len; + q->buffer_used += segs->truesize; + b->packets++; + } + + /* stats */ + b->bytes += slen; + b->backlogs[idx] += slen; + b->tin_backlog += slen; + sch->qstats.backlog += slen; + q->avg_window_bytes += slen; + + qdisc_tree_reduce_backlog(sch, 1-numsegs, len-slen); + consume_skb(skb); + } else { + /* not splitting */ + cobalt_set_enqueue_time(skb, now); + get_cobalt_cb(skb)->adjusted_len = cake_overhead(q, skb); + flow_queue_add(flow, skb); + + if (q->ack_filter) + ack = cake_ack_filter(q, flow); + + if (ack) { + b->ack_drops++; + sch->qstats.drops++; + b->bytes += qdisc_pkt_len(ack); + len -= qdisc_pkt_len(ack); + q->buffer_used += skb->truesize - ack->truesize; + if (q->rate_flags & CAKE_FLAG_INGRESS) + cake_advance_shaper(q, b, ack, now, true); + + qdisc_tree_reduce_backlog(sch, 1, qdisc_pkt_len(ack)); + consume_skb(ack); + } else { + sch->q.qlen++; + q->buffer_used += skb->truesize; + } + + /* stats */ + b->packets++; + b->bytes += len; + b->backlogs[idx] += len; + b->tin_backlog += len; + sch->qstats.backlog += len; + q->avg_window_bytes += len; + } + + if (q->overflow_timeout) + cake_heapify_up(q, b->overflow_idx[idx]); + + /* incoming bandwidth capacity estimate */ + if (q->rate_flags & CAKE_FLAG_AUTORATE_INGRESS) { + u64 packet_interval = \ + ktime_to_ns(ktime_sub(now, q->last_packet_time)); + + if (packet_interval > NSEC_PER_SEC) + packet_interval = NSEC_PER_SEC; + + /* filter out short-term bursts, eg. wifi aggregation */ + q->avg_packet_interval = \ + cake_ewma(q->avg_packet_interval, + packet_interval, + (packet_interval > q->avg_packet_interval ? + 2 : 8)); + + q->last_packet_time = now; + + if (packet_interval > q->avg_packet_interval) { + u64 window_interval = \ + ktime_to_ns(ktime_sub(now, + q->avg_window_begin)); + u64 b = q->avg_window_bytes * (u64)NSEC_PER_SEC; + + b = div64_u64(b, window_interval); + q->avg_peak_bandwidth = + cake_ewma(q->avg_peak_bandwidth, b, + b > q->avg_peak_bandwidth ? 
2 : 8); + q->avg_window_bytes = 0; + q->avg_window_begin = now; + + if (ktime_after(now, + ktime_add_ms(q->last_reconfig_time, + 250))) { + q->rate_bps = (q->avg_peak_bandwidth * 15) >> 4; + cake_reconfigure(sch); + } + } + } else { + q->avg_window_bytes = 0; + q->last_packet_time = now; + } + + /* flowchain */ + if (!flow->set || flow->set == CAKE_SET_DECAYING) { + struct cake_host *srchost = &b->hosts[flow->srchost]; + struct cake_host *dsthost = &b->hosts[flow->dsthost]; + u16 host_load = 1; + + if (!flow->set) { + list_add_tail(&flow->flowchain, &b->new_flows); + } else { + b->decaying_flow_count--; + list_move_tail(&flow->flowchain, &b->new_flows); + } + flow->set = CAKE_SET_SPARSE; + b->sparse_flow_count++; + + if (cake_dsrc(q->flow_mode)) + host_load = max(host_load, srchost->srchost_bulk_flow_count); + + if (cake_ddst(q->flow_mode)) + host_load = max(host_load, dsthost->dsthost_bulk_flow_count); + + flow->deficit = (b->flow_quantum * + quantum_div[host_load]) >> 16; + } else if (flow->set == CAKE_SET_SPARSE_WAIT) { + struct cake_host *srchost = &b->hosts[flow->srchost]; + struct cake_host *dsthost = &b->hosts[flow->dsthost]; + + /* this flow was empty, accounted as a sparse flow, but actually + * in the bulk rotation. + */ + flow->set = CAKE_SET_BULK; + b->sparse_flow_count--; + b->bulk_flow_count++; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count++; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count++; + + } + + if (q->buffer_used > q->buffer_max_used) + q->buffer_max_used = q->buffer_used; + + if (q->buffer_used > q->buffer_limit) { + u32 dropped = 0; + + while (q->buffer_used > q->buffer_limit) { + dropped++; + cake_drop(sch, to_free); + } + b->drop_overlimit += dropped; + } + return NET_XMIT_SUCCESS; +} + +static struct sk_buff *cake_dequeue_one(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct cake_tin_data *b = &q->tins[q->cur_tin]; + struct cake_flow *flow = &b->flows[q->cur_flow]; + struct sk_buff *skb = NULL; + u32 len; + + if (flow->head) { + skb = dequeue_head(flow); + len = qdisc_pkt_len(skb); + b->backlogs[q->cur_flow] -= len; + b->tin_backlog -= len; + sch->qstats.backlog -= len; + q->buffer_used -= skb->truesize; + sch->q.qlen--; + + if (q->overflow_timeout) + cake_heapify(q, b->overflow_idx[q->cur_flow]); + } + return skb; +} + +/* Discard leftover packets from a tin no longer in use. */ +static void cake_clear_tin(struct Qdisc *sch, u16 tin) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + + q->cur_tin = tin; + for (q->cur_flow = 0; q->cur_flow < CAKE_QUEUES; q->cur_flow++) + while (!!(skb = cake_dequeue_one(sch))) + kfree_skb(skb); +} + +static struct sk_buff *cake_dequeue(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct cake_tin_data *b = &q->tins[q->cur_tin]; + struct cake_host *srchost, *dsthost; + ktime_t now = ktime_get(); + struct cake_flow *flow; + struct list_head *head; + bool first_flow = true; + struct sk_buff *skb; + u16 host_load; + u64 delay; + u32 len; + +begin: + if (!sch->q.qlen) + return NULL; + + /* global hard shaper */ + if (ktime_after(q->time_next_packet, now) && + ktime_after(q->failsafe_next_packet, now)) { + u64 next = min(ktime_to_ns(q->time_next_packet), + ktime_to_ns(q->failsafe_next_packet)); + + sch->qstats.overlimits++; + qdisc_watchdog_schedule_ns(&q->watchdog, next); + return NULL; + } + + /* Choose a class to work on. 
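The unlimited-mode branch below balances the tins with a deficit round robin. A toy standalone DRR loop (editor's sketch; weights and packet size are invented, and unlike the tin logic below, which tops up only once a tin's deficit has gone negative, this is textbook per-round DRR, but the proportionality argument is the same):

#include <stdio.h>

#define RINGS 3

int main(void)
{
	int quantum[RINGS] = { 300, 150, 50 };	/* relative weights, bytes per round */
	int deficit[RINGS] = { 0, 0, 0 };
	int sent[RINGS] = { 0, 0, 0 };
	int round, i;

	for (round = 0; round < 100; round++) {
		for (i = 0; i < RINGS; i++) {
			deficit[i] += quantum[i];	/* top up once per round */
			while (deficit[i] >= 100) {	/* every ring always backlogged */
				deficit[i] -= 100;	/* "send" one 100-byte packet */
				sent[i]++;
			}
		}
	}
	/* prints "packets sent: 300 150 50", i.e. 6:3:1 service as per the weights */
	printf("packets sent: %d %d %d\n", sent[0], sent[1], sent[2]);
	return 0;
}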
*/ + if (!q->rate_ns) { + /* In unlimited mode, can't rely on shaper timings, just balance + * with DRR + */ + bool wrapped = false, empty = true; + + while (b->tin_deficit < 0 || + !(b->sparse_flow_count + b->bulk_flow_count)) { + if (b->tin_deficit <= 0) + b->tin_deficit += b->tin_quantum; + if (b->sparse_flow_count + b->bulk_flow_count) + empty = false; + + q->cur_tin++; + b++; + if (q->cur_tin >= q->tin_cnt) { + q->cur_tin = 0; + b = q->tins; + + if (wrapped) { + /* It's possible for q->qlen to be + * nonzero when we actually have no + * packets anywhere. + */ + if (empty) + return NULL; + } else { + wrapped = true; + } + } + } + } else { + /* In shaped mode, choose: + * - Highest-priority tin with queue and meeting schedule, or + * - The earliest-scheduled tin with queue. + */ + ktime_t best_time = KTIME_MAX; + int tin, best_tin = 0; + + for (tin = 0; tin < q->tin_cnt; tin++) { + b = q->tins + tin; + if ((b->sparse_flow_count + b->bulk_flow_count) > 0) { + ktime_t time_to_pkt = \ + ktime_sub(b->time_next_packet, now); + + if (ktime_to_ns(time_to_pkt) <= 0 || + ktime_compare(time_to_pkt, + best_time) <= 0) { + best_time = time_to_pkt; + best_tin = tin; + } + } + } + + q->cur_tin = best_tin; + b = q->tins + best_tin; + + /* No point in going further if no packets to deliver. */ + if (unlikely(!(b->sparse_flow_count + b->bulk_flow_count))) + return NULL; + } + +retry: + /* service this class */ + head = &b->decaying_flows; + if (!first_flow || list_empty(head)) { + head = &b->new_flows; + if (list_empty(head)) { + head = &b->old_flows; + if (unlikely(list_empty(head))) { + head = &b->decaying_flows; + if (unlikely(list_empty(head))) + goto begin; + } + } + } + flow = list_first_entry(head, struct cake_flow, flowchain); + q->cur_flow = flow - b->flows; + first_flow = false; + + /* triple isolation (modified DRR++) */ + srchost = &b->hosts[flow->srchost]; + dsthost = &b->hosts[flow->dsthost]; + host_load = 1; + + /* flow isolation (DRR++) */ + if (flow->deficit <= 0) { + /* Keep all flows with deficits out of the sparse and decaying + * rotations. No non-empty flow can go into the decaying + * rotation, so they can't get deficits + */ + if (flow->set == CAKE_SET_SPARSE) { + if (flow->head) { + b->sparse_flow_count--; + b->bulk_flow_count++; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count++; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count++; + + flow->set = CAKE_SET_BULK; + } else { + /* we've moved it to the bulk rotation for + * correct deficit accounting but we still want + * to count it as a sparse flow, not a bulk one. 
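The deficit refill just below divides flow_quantum by host_load via the quantum_div[] reciprocal table plus a random 16-bit dither before the >> 16. A sketch of why the dither matters (editor's illustration, assuming quantum_div[i] is initialised to roughly 65535/i elsewhere in this file): averaged over all dither values, the refill comes out to flow_quantum / host_load instead of always rounding down.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
	uint32_t flow_quantum = 1514, host_load = 3;	/* invented example values */
	uint16_t quantum_div = 65535 / host_load;	/* reciprocal of host_load */
	uint64_t sum = 0;
	uint32_t d;

	/* enumerate every possible 16-bit dither instead of drawing random ones */
	for (d = 0; d < 65536; d++)
		sum += (flow_quantum * quantum_div + d) >> 16;

	printf("average refill %.2f vs exact %.2f\n",
	       sum / 65536.0, flow_quantum / (double)host_load);
	return 0;
}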
+ */ + flow->set = CAKE_SET_SPARSE_WAIT; + } + } + + if (cake_dsrc(q->flow_mode)) + host_load = max(host_load, srchost->srchost_bulk_flow_count); + + if (cake_ddst(q->flow_mode)) + host_load = max(host_load, dsthost->dsthost_bulk_flow_count); + + WARN_ON(host_load > CAKE_QUEUES); + + /* The get_random_u16() is a way to apply dithering to avoid + * accumulating roundoff errors + */ + flow->deficit += (b->flow_quantum * quantum_div[host_load] + + get_random_u16()) >> 16; + list_move_tail(&flow->flowchain, &b->old_flows); + + goto retry; + } + + /* Retrieve a packet via the AQM */ + while (1) { + skb = cake_dequeue_one(sch); + if (!skb) { + /* this queue was actually empty */ + if (cobalt_queue_empty(&flow->cvars, &b->cparams, now)) + b->unresponsive_flow_count--; + + if (flow->cvars.p_drop || flow->cvars.count || + ktime_before(now, flow->cvars.drop_next)) { + /* keep in the flowchain until the state has + * decayed to rest + */ + list_move_tail(&flow->flowchain, + &b->decaying_flows); + if (flow->set == CAKE_SET_BULK) { + b->bulk_flow_count--; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count--; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count--; + + b->decaying_flow_count++; + } else if (flow->set == CAKE_SET_SPARSE || + flow->set == CAKE_SET_SPARSE_WAIT) { + b->sparse_flow_count--; + b->decaying_flow_count++; + } + flow->set = CAKE_SET_DECAYING; + } else { + /* remove empty queue from the flowchain */ + list_del_init(&flow->flowchain); + if (flow->set == CAKE_SET_SPARSE || + flow->set == CAKE_SET_SPARSE_WAIT) + b->sparse_flow_count--; + else if (flow->set == CAKE_SET_BULK) { + b->bulk_flow_count--; + + if (cake_dsrc(q->flow_mode)) + srchost->srchost_bulk_flow_count--; + + if (cake_ddst(q->flow_mode)) + dsthost->dsthost_bulk_flow_count--; + + } else + b->decaying_flow_count--; + + flow->set = CAKE_SET_NONE; + } + goto begin; + } + + /* Last packet in queue may be marked, shouldn't be dropped */ + if (!cobalt_should_drop(&flow->cvars, &b->cparams, now, skb, + (b->bulk_flow_count * + !!(q->rate_flags & + CAKE_FLAG_INGRESS))) || + !flow->head) + break; + + /* drop this packet, get another one */ + if (q->rate_flags & CAKE_FLAG_INGRESS) { + len = cake_advance_shaper(q, b, skb, + now, true); + flow->deficit -= len; + b->tin_deficit -= len; + } + flow->dropped++; + b->tin_dropped++; + qdisc_tree_reduce_backlog(sch, 1, qdisc_pkt_len(skb)); + qdisc_qstats_drop(sch); + kfree_skb(skb); + if (q->rate_flags & CAKE_FLAG_INGRESS) + goto retry; + } + + b->tin_ecn_mark += !!flow->cvars.ecn_marked; + qdisc_bstats_update(sch, skb); + + /* collect delay stats */ + delay = ktime_to_ns(ktime_sub(now, cobalt_get_enqueue_time(skb))); + b->avge_delay = cake_ewma(b->avge_delay, delay, 8); + b->peak_delay = cake_ewma(b->peak_delay, delay, + delay > b->peak_delay ? 2 : 8); + b->base_delay = cake_ewma(b->base_delay, delay, + delay < b->base_delay ? 
2 : 8); + + len = cake_advance_shaper(q, b, skb, now, false); + flow->deficit -= len; + b->tin_deficit -= len; + + if (ktime_after(q->time_next_packet, now) && sch->q.qlen) { + u64 next = min(ktime_to_ns(q->time_next_packet), + ktime_to_ns(q->failsafe_next_packet)); + + qdisc_watchdog_schedule_ns(&q->watchdog, next); + } else if (!sch->q.qlen) { + int i; + + for (i = 0; i < q->tin_cnt; i++) { + if (q->tins[i].decaying_flow_count) { + ktime_t next = \ + ktime_add_ns(now, + q->tins[i].cparams.target); + + qdisc_watchdog_schedule_ns(&q->watchdog, + ktime_to_ns(next)); + break; + } + } + } + + if (q->overflow_timeout) + q->overflow_timeout--; + + return skb; +} + +static void cake_reset(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + u32 c; + + if (!q->tins) + return; + + for (c = 0; c < CAKE_MAX_TINS; c++) + cake_clear_tin(sch, c); +} + +static const struct nla_policy cake_policy[TCA_CAKE_MAX + 1] = { + [TCA_CAKE_BASE_RATE64] = { .type = NLA_U64 }, + [TCA_CAKE_DIFFSERV_MODE] = { .type = NLA_U32 }, + [TCA_CAKE_ATM] = { .type = NLA_U32 }, + [TCA_CAKE_FLOW_MODE] = { .type = NLA_U32 }, + [TCA_CAKE_OVERHEAD] = { .type = NLA_S32 }, + [TCA_CAKE_RTT] = { .type = NLA_U32 }, + [TCA_CAKE_TARGET] = { .type = NLA_U32 }, + [TCA_CAKE_AUTORATE] = { .type = NLA_U32 }, + [TCA_CAKE_MEMORY] = { .type = NLA_U32 }, + [TCA_CAKE_NAT] = { .type = NLA_U32 }, + [TCA_CAKE_RAW] = { .type = NLA_U32 }, + [TCA_CAKE_WASH] = { .type = NLA_U32 }, + [TCA_CAKE_MPU] = { .type = NLA_U32 }, + [TCA_CAKE_INGRESS] = { .type = NLA_U32 }, + [TCA_CAKE_ACK_FILTER] = { .type = NLA_U32 }, + [TCA_CAKE_SPLIT_GSO] = { .type = NLA_U32 }, + [TCA_CAKE_FWMARK] = { .type = NLA_U32 }, +}; + +static void cake_set_rate(struct cake_tin_data *b, u64 rate, u32 mtu, + u64 target_ns, u64 rtt_est_ns) +{ + /* convert byte-rate into time-per-byte + * so it will always unwedge in reasonable time. + */ + static const u64 MIN_RATE = 64; + u32 byte_target = mtu; + u64 byte_target_ns; + u8 rate_shft = 0; + u64 rate_ns = 0; + + b->flow_quantum = 1514; + if (rate) { + b->flow_quantum = max(min(rate >> 12, 1514ULL), 300ULL); + rate_shft = 34; + rate_ns = ((u64)NSEC_PER_SEC) << rate_shft; + rate_ns = div64_u64(rate_ns, max(MIN_RATE, rate)); + while (!!(rate_ns >> 34)) { + rate_ns >>= 1; + rate_shft--; + } + } /* else unlimited, ie. 
zero delay */ + + b->tin_rate_bps = rate; + b->tin_rate_ns = rate_ns; + b->tin_rate_shft = rate_shft; + + byte_target_ns = (byte_target * rate_ns) >> rate_shft; + + b->cparams.target = max((byte_target_ns * 3) / 2, target_ns); + b->cparams.interval = max(rtt_est_ns + + b->cparams.target - target_ns, + b->cparams.target * 2); + b->cparams.mtu_time = byte_target_ns; + b->cparams.p_inc = 1 << 24; /* 1/256 */ + b->cparams.p_dec = 1 << 20; /* 1/4096 */ +} + +static int cake_config_besteffort(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct cake_tin_data *b = &q->tins[0]; + u32 mtu = psched_mtu(qdisc_dev(sch)); + u64 rate = q->rate_bps; + + q->tin_cnt = 1; + + q->tin_index = besteffort; + q->tin_order = normal_order; + + cake_set_rate(b, rate, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + b->tin_quantum = 65535; + + return 0; +} + +static int cake_config_precedence(struct Qdisc *sch) +{ + /* convert high-level (user visible) parameters into internal format */ + struct cake_sched_data *q = qdisc_priv(sch); + u32 mtu = psched_mtu(qdisc_dev(sch)); + u64 rate = q->rate_bps; + u32 quantum = 256; + u32 i; + + q->tin_cnt = 8; + q->tin_index = precedence; + q->tin_order = normal_order; + + for (i = 0; i < q->tin_cnt; i++) { + struct cake_tin_data *b = &q->tins[i]; + + cake_set_rate(b, rate, mtu, us_to_ns(q->target), + us_to_ns(q->interval)); + + b->tin_quantum = max_t(u16, 1U, quantum); + + /* calculate next class's parameters */ + rate *= 7; + rate >>= 3; + + quantum *= 7; + quantum >>= 3; + } + + return 0; +} + +/* List of known Diffserv codepoints: + * + * Default Forwarding (DF/CS0) - Best Effort + * Max Throughput (TOS2) + * Min Delay (TOS4) + * LLT "La" (TOS5) + * Assured Forwarding 1 (AF1x) - x3 + * Assured Forwarding 2 (AF2x) - x3 + * Assured Forwarding 3 (AF3x) - x3 + * Assured Forwarding 4 (AF4x) - x3 + * Precedence Class 1 (CS1) + * Precedence Class 2 (CS2) + * Precedence Class 3 (CS3) + * Precedence Class 4 (CS4) + * Precedence Class 5 (CS5) + * Precedence Class 6 (CS6) + * Precedence Class 7 (CS7) + * Voice Admit (VA) + * Expedited Forwarding (EF) + * Lower Effort (LE) + * + * Total 26 codepoints. + */ + +/* List of traffic classes in RFC 4594, updated by RFC 8622: + * (roughly descending order of contended priority) + * (roughly ascending order of uncontended throughput) + * + * Network Control (CS6,CS7) - routing traffic + * Telephony (EF,VA) - aka. VoIP streams + * Signalling (CS5) - VoIP setup + * Multimedia Conferencing (AF4x) - aka. video calls + * Realtime Interactive (CS4) - eg. games + * Multimedia Streaming (AF3x) - eg. YouTube, NetFlix, Twitch + * Broadcast Video (CS3) + * Low-Latency Data (AF2x,TOS4) - eg. database + * Ops, Admin, Management (CS2) - eg. ssh + * Standard Service (DF & unrecognised codepoints) + * High-Throughput Data (AF1x,TOS2) - eg. web traffic + * Low-Priority Data (LE,CS1) - eg. BitTorrent + * + * Total 12 traffic classes. + */ + +static int cake_config_diffserv8(struct Qdisc *sch) +{ +/* Pruned list of traffic classes for typical applications: + * + * Network Control (CS6, CS7) + * Minimum Latency (EF, VA, CS5, CS4) + * Interactive Shell (CS2) + * Low Latency Transactions (AF2x, TOS4) + * Video Streaming (AF4x, AF3x, CS3) + * Bog Standard (DF etc.) + * High Throughput (AF1x, TOS2, CS1) + * Background Traffic (LE) + * + * Total 8 traffic classes. 
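The q->tin_index tables assigned here (besteffort, precedence, and the diffserv tables below) are DSCP-indexed lookup arrays defined earlier in this file and not visible in this hunk. A standalone illustration of the mechanism only, with invented entries:

    #include <stdio.h>

    /* Entries invented for the example; the real diffserv8[]/diffserv4[]
     * tables live earlier in sch_cake.c and cover all 64 codepoints. */
    static const unsigned char demo_tin_of_dscp[64] = {
        [46] = 6,   /* EF  -> a low-latency tin (example value) */
        [8]  = 1,   /* CS1 -> a background tin  (example value) */
    };

    int main(void)
    {
        unsigned char tos = 0xb8;   /* IPv4 TOS byte carrying DSCP EF (46) */

        /* DSCP is the upper six bits of the TOS / traffic-class byte */
        printf("tin %d\n", demo_tin_of_dscp[tos >> 2]);
        return 0;
    }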
+ */ + + struct cake_sched_data *q = qdisc_priv(sch); + u32 mtu = psched_mtu(qdisc_dev(sch)); + u64 rate = q->rate_bps; + u32 quantum = 256; + u32 i; + + q->tin_cnt = 8; + + /* codepoint to class mapping */ + q->tin_index = diffserv8; + q->tin_order = normal_order; + + /* class characteristics */ + for (i = 0; i < q->tin_cnt; i++) { + struct cake_tin_data *b = &q->tins[i]; + + cake_set_rate(b, rate, mtu, us_to_ns(q->target), + us_to_ns(q->interval)); + + b->tin_quantum = max_t(u16, 1U, quantum); + + /* calculate next class's parameters */ + rate *= 7; + rate >>= 3; + + quantum *= 7; + quantum >>= 3; + } + + return 0; +} + +static int cake_config_diffserv4(struct Qdisc *sch) +{ +/* Further pruned list of traffic classes for four-class system: + * + * Latency Sensitive (CS7, CS6, EF, VA, CS5, CS4) + * Streaming Media (AF4x, AF3x, CS3, AF2x, TOS4, CS2) + * Best Effort (DF, AF1x, TOS2, and those not specified) + * Background Traffic (LE, CS1) + * + * Total 4 traffic classes. + */ + + struct cake_sched_data *q = qdisc_priv(sch); + u32 mtu = psched_mtu(qdisc_dev(sch)); + u64 rate = q->rate_bps; + u32 quantum = 1024; + + q->tin_cnt = 4; + + /* codepoint to class mapping */ + q->tin_index = diffserv4; + q->tin_order = bulk_order; + + /* class characteristics */ + cake_set_rate(&q->tins[0], rate, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + cake_set_rate(&q->tins[1], rate >> 4, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + cake_set_rate(&q->tins[2], rate >> 1, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + cake_set_rate(&q->tins[3], rate >> 2, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + + /* bandwidth-sharing weights */ + q->tins[0].tin_quantum = quantum; + q->tins[1].tin_quantum = quantum >> 4; + q->tins[2].tin_quantum = quantum >> 1; + q->tins[3].tin_quantum = quantum >> 2; + + return 0; +} + +static int cake_config_diffserv3(struct Qdisc *sch) +{ +/* Simplified Diffserv structure with 3 tins. 
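As a concrete reading of the shifts in cake_config_diffserv4() above, assuming a 100 Mbit/s base rate:

    tin rate thresholds:  rate, rate/16, rate/2, rate/4  ->  100, 6.25, 50 and 25 Mbit/s
    WRR quanta:           1024,  64,     512,    256     (the same 16 : 1 : 8 : 4 proportions)

so each tin's scheduling weight tracks its shaping threshold.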
+ * Latency Sensitive (CS7, CS6, EF, VA, TOS4) + * Best Effort + * Low Priority (LE, CS1) + */ + struct cake_sched_data *q = qdisc_priv(sch); + u32 mtu = psched_mtu(qdisc_dev(sch)); + u64 rate = q->rate_bps; + u32 quantum = 1024; + + q->tin_cnt = 3; + + /* codepoint to class mapping */ + q->tin_index = diffserv3; + q->tin_order = bulk_order; + + /* class characteristics */ + cake_set_rate(&q->tins[0], rate, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + cake_set_rate(&q->tins[1], rate >> 4, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + cake_set_rate(&q->tins[2], rate >> 2, mtu, + us_to_ns(q->target), us_to_ns(q->interval)); + + /* bandwidth-sharing weights */ + q->tins[0].tin_quantum = quantum; + q->tins[1].tin_quantum = quantum >> 4; + q->tins[2].tin_quantum = quantum >> 2; + + return 0; +} + +static void cake_reconfigure(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + int c, ft; + + switch (q->tin_mode) { + case CAKE_DIFFSERV_BESTEFFORT: + ft = cake_config_besteffort(sch); + break; + + case CAKE_DIFFSERV_PRECEDENCE: + ft = cake_config_precedence(sch); + break; + + case CAKE_DIFFSERV_DIFFSERV8: + ft = cake_config_diffserv8(sch); + break; + + case CAKE_DIFFSERV_DIFFSERV4: + ft = cake_config_diffserv4(sch); + break; + + case CAKE_DIFFSERV_DIFFSERV3: + default: + ft = cake_config_diffserv3(sch); + break; + } + + for (c = q->tin_cnt; c < CAKE_MAX_TINS; c++) { + cake_clear_tin(sch, c); + q->tins[c].cparams.mtu_time = q->tins[ft].cparams.mtu_time; + } + + q->rate_ns = q->tins[ft].tin_rate_ns; + q->rate_shft = q->tins[ft].tin_rate_shft; + + if (q->buffer_config_limit) { + q->buffer_limit = q->buffer_config_limit; + } else if (q->rate_bps) { + u64 t = q->rate_bps * q->interval; + + do_div(t, USEC_PER_SEC / 4); + q->buffer_limit = max_t(u32, t, 4U << 20); + } else { + q->buffer_limit = ~0; + } + + sch->flags &= ~TCQ_F_CAN_BYPASS; + + q->buffer_limit = min(q->buffer_limit, + max(sch->limit * psched_mtu(qdisc_dev(sch)), + q->buffer_config_limit)); +} + +static int cake_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_CAKE_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_CAKE_MAX, opt, cake_policy, + extack); + if (err < 0) + return err; + + if (tb[TCA_CAKE_NAT]) { +#if IS_ENABLED(CONFIG_NF_CONNTRACK) + q->flow_mode &= ~CAKE_FLOW_NAT_FLAG; + q->flow_mode |= CAKE_FLOW_NAT_FLAG * + !!nla_get_u32(tb[TCA_CAKE_NAT]); +#else + NL_SET_ERR_MSG_ATTR(extack, tb[TCA_CAKE_NAT], + "No conntrack support in kernel"); + return -EOPNOTSUPP; +#endif + } + + if (tb[TCA_CAKE_BASE_RATE64]) + q->rate_bps = nla_get_u64(tb[TCA_CAKE_BASE_RATE64]); + + if (tb[TCA_CAKE_DIFFSERV_MODE]) + q->tin_mode = nla_get_u32(tb[TCA_CAKE_DIFFSERV_MODE]); + + if (tb[TCA_CAKE_WASH]) { + if (!!nla_get_u32(tb[TCA_CAKE_WASH])) + q->rate_flags |= CAKE_FLAG_WASH; + else + q->rate_flags &= ~CAKE_FLAG_WASH; + } + + if (tb[TCA_CAKE_FLOW_MODE]) + q->flow_mode = ((q->flow_mode & CAKE_FLOW_NAT_FLAG) | + (nla_get_u32(tb[TCA_CAKE_FLOW_MODE]) & + CAKE_FLOW_MASK)); + + if (tb[TCA_CAKE_ATM]) + q->atm_mode = nla_get_u32(tb[TCA_CAKE_ATM]); + + if (tb[TCA_CAKE_OVERHEAD]) { + q->rate_overhead = nla_get_s32(tb[TCA_CAKE_OVERHEAD]); + q->rate_flags |= CAKE_FLAG_OVERHEAD; + + q->max_netlen = 0; + q->max_adjlen = 0; + q->min_netlen = ~0; + q->min_adjlen = ~0; + } + + if (tb[TCA_CAKE_RAW]) { + q->rate_flags &= ~CAKE_FLAG_OVERHEAD; + + q->max_netlen = 0; + q->max_adjlen = 0; + q->min_netlen = ~0; + 
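A worked example for the default buffer sizing in cake_reconfigure() above (the rate is carried in bytes per second, per the byte-rate comment in cake_set_rate(), and q->interval is in microseconds): at 12.5 MB/s (roughly 100 Mbit/s) with the default 100 ms interval,

    t = 12,500,000 * 100,000 / (1,000,000 / 4) = 5,000,000 bytes

which is four times the 1.25 MB bandwidth-delay product, so buffer_limit = max(5,000,000, 4 MiB) = 5,000,000 bytes.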
q->min_adjlen = ~0; + } + + if (tb[TCA_CAKE_MPU]) + q->rate_mpu = nla_get_u32(tb[TCA_CAKE_MPU]); + + if (tb[TCA_CAKE_RTT]) { + q->interval = nla_get_u32(tb[TCA_CAKE_RTT]); + + if (!q->interval) + q->interval = 1; + } + + if (tb[TCA_CAKE_TARGET]) { + q->target = nla_get_u32(tb[TCA_CAKE_TARGET]); + + if (!q->target) + q->target = 1; + } + + if (tb[TCA_CAKE_AUTORATE]) { + if (!!nla_get_u32(tb[TCA_CAKE_AUTORATE])) + q->rate_flags |= CAKE_FLAG_AUTORATE_INGRESS; + else + q->rate_flags &= ~CAKE_FLAG_AUTORATE_INGRESS; + } + + if (tb[TCA_CAKE_INGRESS]) { + if (!!nla_get_u32(tb[TCA_CAKE_INGRESS])) + q->rate_flags |= CAKE_FLAG_INGRESS; + else + q->rate_flags &= ~CAKE_FLAG_INGRESS; + } + + if (tb[TCA_CAKE_ACK_FILTER]) + q->ack_filter = nla_get_u32(tb[TCA_CAKE_ACK_FILTER]); + + if (tb[TCA_CAKE_MEMORY]) + q->buffer_config_limit = nla_get_u32(tb[TCA_CAKE_MEMORY]); + + if (tb[TCA_CAKE_SPLIT_GSO]) { + if (!!nla_get_u32(tb[TCA_CAKE_SPLIT_GSO])) + q->rate_flags |= CAKE_FLAG_SPLIT_GSO; + else + q->rate_flags &= ~CAKE_FLAG_SPLIT_GSO; + } + + if (tb[TCA_CAKE_FWMARK]) { + q->fwmark_mask = nla_get_u32(tb[TCA_CAKE_FWMARK]); + q->fwmark_shft = q->fwmark_mask ? __ffs(q->fwmark_mask) : 0; + } + + if (q->tins) { + sch_tree_lock(sch); + cake_reconfigure(sch); + sch_tree_unlock(sch); + } + + return 0; +} + +static void cake_destroy(struct Qdisc *sch) +{ + struct cake_sched_data *q = qdisc_priv(sch); + + qdisc_watchdog_cancel(&q->watchdog); + tcf_block_put(q->block); + kvfree(q->tins); +} + +static int cake_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct cake_sched_data *q = qdisc_priv(sch); + int i, j, err; + + sch->limit = 10240; + q->tin_mode = CAKE_DIFFSERV_DIFFSERV3; + q->flow_mode = CAKE_FLOW_TRIPLE; + + q->rate_bps = 0; /* unlimited by default */ + + q->interval = 100000; /* 100ms default */ + q->target = 5000; /* 5ms: codel RFC argues + * for 5 to 10% of interval + */ + q->rate_flags |= CAKE_FLAG_SPLIT_GSO; + q->cur_tin = 0; + q->cur_flow = 0; + + qdisc_watchdog_init(&q->watchdog, sch); + + if (opt) { + err = cake_change(sch, opt, extack); + + if (err) + return err; + } + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + quantum_div[0] = ~0; + for (i = 1; i <= CAKE_QUEUES; i++) + quantum_div[i] = 65535 / i; + + q->tins = kvcalloc(CAKE_MAX_TINS, sizeof(struct cake_tin_data), + GFP_KERNEL); + if (!q->tins) + return -ENOMEM; + + for (i = 0; i < CAKE_MAX_TINS; i++) { + struct cake_tin_data *b = q->tins + i; + + INIT_LIST_HEAD(&b->new_flows); + INIT_LIST_HEAD(&b->old_flows); + INIT_LIST_HEAD(&b->decaying_flows); + b->sparse_flow_count = 0; + b->bulk_flow_count = 0; + b->decaying_flow_count = 0; + + for (j = 0; j < CAKE_QUEUES; j++) { + struct cake_flow *flow = b->flows + j; + u32 k = j * CAKE_MAX_TINS + i; + + INIT_LIST_HEAD(&flow->flowchain); + cobalt_vars_init(&flow->cvars); + + q->overflow_heap[k].t = i; + q->overflow_heap[k].b = j; + b->overflow_idx[j] = k; + } + } + + cake_reconfigure(sch); + q->avg_peak_bandwidth = q->rate_bps; + q->min_netlen = ~0; + q->min_adjlen = ~0; + return 0; +} + +static int cake_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct cake_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!opts) + goto nla_put_failure; + + if (nla_put_u64_64bit(skb, TCA_CAKE_BASE_RATE64, q->rate_bps, + TCA_CAKE_PAD)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_FLOW_MODE, + q->flow_mode & CAKE_FLOW_MASK)) + goto nla_put_failure; + + if 
(nla_put_u32(skb, TCA_CAKE_RTT, q->interval)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_TARGET, q->target)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_MEMORY, q->buffer_config_limit)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_AUTORATE, + !!(q->rate_flags & CAKE_FLAG_AUTORATE_INGRESS))) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_INGRESS, + !!(q->rate_flags & CAKE_FLAG_INGRESS))) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_ACK_FILTER, q->ack_filter)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_NAT, + !!(q->flow_mode & CAKE_FLOW_NAT_FLAG))) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_DIFFSERV_MODE, q->tin_mode)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_WASH, + !!(q->rate_flags & CAKE_FLAG_WASH))) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_OVERHEAD, q->rate_overhead)) + goto nla_put_failure; + + if (!(q->rate_flags & CAKE_FLAG_OVERHEAD)) + if (nla_put_u32(skb, TCA_CAKE_RAW, 0)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_ATM, q->atm_mode)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_MPU, q->rate_mpu)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_SPLIT_GSO, + !!(q->rate_flags & CAKE_FLAG_SPLIT_GSO))) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CAKE_FWMARK, q->fwmark_mask)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + return -1; +} + +static int cake_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct nlattr *stats = nla_nest_start_noflag(d->skb, TCA_STATS_APP); + struct cake_sched_data *q = qdisc_priv(sch); + struct nlattr *tstats, *ts; + int i; + + if (!stats) + return -1; + +#define PUT_STAT_U32(attr, data) do { \ + if (nla_put_u32(d->skb, TCA_CAKE_STATS_ ## attr, data)) \ + goto nla_put_failure; \ + } while (0) +#define PUT_STAT_U64(attr, data) do { \ + if (nla_put_u64_64bit(d->skb, TCA_CAKE_STATS_ ## attr, \ + data, TCA_CAKE_STATS_PAD)) \ + goto nla_put_failure; \ + } while (0) + + PUT_STAT_U64(CAPACITY_ESTIMATE64, q->avg_peak_bandwidth); + PUT_STAT_U32(MEMORY_LIMIT, q->buffer_limit); + PUT_STAT_U32(MEMORY_USED, q->buffer_max_used); + PUT_STAT_U32(AVG_NETOFF, ((q->avg_netoff + 0x8000) >> 16)); + PUT_STAT_U32(MAX_NETLEN, q->max_netlen); + PUT_STAT_U32(MAX_ADJLEN, q->max_adjlen); + PUT_STAT_U32(MIN_NETLEN, q->min_netlen); + PUT_STAT_U32(MIN_ADJLEN, q->min_adjlen); + +#undef PUT_STAT_U32 +#undef PUT_STAT_U64 + + tstats = nla_nest_start_noflag(d->skb, TCA_CAKE_STATS_TIN_STATS); + if (!tstats) + goto nla_put_failure; + +#define PUT_TSTAT_U32(attr, data) do { \ + if (nla_put_u32(d->skb, TCA_CAKE_TIN_STATS_ ## attr, data)) \ + goto nla_put_failure; \ + } while (0) +#define PUT_TSTAT_U64(attr, data) do { \ + if (nla_put_u64_64bit(d->skb, TCA_CAKE_TIN_STATS_ ## attr, \ + data, TCA_CAKE_TIN_STATS_PAD)) \ + goto nla_put_failure; \ + } while (0) + + for (i = 0; i < q->tin_cnt; i++) { + struct cake_tin_data *b = &q->tins[q->tin_order[i]]; + + ts = nla_nest_start_noflag(d->skb, i + 1); + if (!ts) + goto nla_put_failure; + + PUT_TSTAT_U64(THRESHOLD_RATE64, b->tin_rate_bps); + PUT_TSTAT_U64(SENT_BYTES64, b->bytes); + PUT_TSTAT_U32(BACKLOG_BYTES, b->tin_backlog); + + PUT_TSTAT_U32(TARGET_US, + ktime_to_us(ns_to_ktime(b->cparams.target))); + PUT_TSTAT_U32(INTERVAL_US, + ktime_to_us(ns_to_ktime(b->cparams.interval))); + + PUT_TSTAT_U32(SENT_PACKETS, b->packets); + PUT_TSTAT_U32(DROPPED_PACKETS, b->tin_dropped); + PUT_TSTAT_U32(ECN_MARKED_PACKETS, b->tin_ecn_mark); + 
PUT_TSTAT_U32(ACKS_DROPPED_PACKETS, b->ack_drops); + + PUT_TSTAT_U32(PEAK_DELAY_US, + ktime_to_us(ns_to_ktime(b->peak_delay))); + PUT_TSTAT_U32(AVG_DELAY_US, + ktime_to_us(ns_to_ktime(b->avge_delay))); + PUT_TSTAT_U32(BASE_DELAY_US, + ktime_to_us(ns_to_ktime(b->base_delay))); + + PUT_TSTAT_U32(WAY_INDIRECT_HITS, b->way_hits); + PUT_TSTAT_U32(WAY_MISSES, b->way_misses); + PUT_TSTAT_U32(WAY_COLLISIONS, b->way_collisions); + + PUT_TSTAT_U32(SPARSE_FLOWS, b->sparse_flow_count + + b->decaying_flow_count); + PUT_TSTAT_U32(BULK_FLOWS, b->bulk_flow_count); + PUT_TSTAT_U32(UNRESPONSIVE_FLOWS, b->unresponsive_flow_count); + PUT_TSTAT_U32(MAX_SKBLEN, b->max_skblen); + + PUT_TSTAT_U32(FLOW_QUANTUM, b->flow_quantum); + nla_nest_end(d->skb, ts); + } + +#undef PUT_TSTAT_U32 +#undef PUT_TSTAT_U64 + + nla_nest_end(d->skb, tstats); + return nla_nest_end(d->skb, stats); + +nla_put_failure: + nla_nest_cancel(d->skb, stats); + return -1; +} + +static struct Qdisc *cake_leaf(struct Qdisc *sch, unsigned long arg) +{ + return NULL; +} + +static unsigned long cake_find(struct Qdisc *sch, u32 classid) +{ + return 0; +} + +static unsigned long cake_bind(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return 0; +} + +static void cake_unbind(struct Qdisc *q, unsigned long cl) +{ +} + +static struct tcf_block *cake_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct cake_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static int cake_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + tcm->tcm_handle |= TC_H_MIN(cl); + return 0; +} + +static int cake_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct cake_sched_data *q = qdisc_priv(sch); + const struct cake_flow *flow = NULL; + struct gnet_stats_queue qs = { 0 }; + struct nlattr *stats; + u32 idx = cl - 1; + + if (idx < CAKE_QUEUES * q->tin_cnt) { + const struct cake_tin_data *b = \ + &q->tins[q->tin_order[idx / CAKE_QUEUES]]; + const struct sk_buff *skb; + + flow = &b->flows[idx % CAKE_QUEUES]; + + if (flow->head) { + sch_tree_lock(sch); + skb = flow->head; + while (skb) { + qs.qlen++; + skb = skb->next; + } + sch_tree_unlock(sch); + } + qs.backlog = b->backlogs[idx % CAKE_QUEUES]; + qs.drops = flow->dropped; + } + if (gnet_stats_copy_queue(d, NULL, &qs, qs.qlen) < 0) + return -1; + if (flow) { + ktime_t now = ktime_get(); + + stats = nla_nest_start_noflag(d->skb, TCA_STATS_APP); + if (!stats) + return -1; + +#define PUT_STAT_U32(attr, data) do { \ + if (nla_put_u32(d->skb, TCA_CAKE_STATS_ ## attr, data)) \ + goto nla_put_failure; \ + } while (0) +#define PUT_STAT_S32(attr, data) do { \ + if (nla_put_s32(d->skb, TCA_CAKE_STATS_ ## attr, data)) \ + goto nla_put_failure; \ + } while (0) + + PUT_STAT_S32(DEFICIT, flow->deficit); + PUT_STAT_U32(DROPPING, flow->cvars.dropping); + PUT_STAT_U32(COBALT_COUNT, flow->cvars.count); + PUT_STAT_U32(P_DROP, flow->cvars.p_drop); + if (flow->cvars.p_drop) { + PUT_STAT_S32(BLUE_TIMER_US, + ktime_to_us( + ktime_sub(now, + flow->cvars.blue_timer))); + } + if (flow->cvars.dropping) { + PUT_STAT_S32(DROP_NEXT_US, + ktime_to_us( + ktime_sub(now, + flow->cvars.drop_next))); + } + + if (nla_nest_end(d->skb, stats) < 0) + return -1; + } + + return 0; + +nla_put_failure: + nla_nest_cancel(d->skb, stats); + return -1; +} + +static void cake_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct cake_sched_data *q = qdisc_priv(sch); + unsigned int i, j; + + if 
(arg->stop) + return; + + for (i = 0; i < q->tin_cnt; i++) { + struct cake_tin_data *b = &q->tins[q->tin_order[i]]; + + for (j = 0; j < CAKE_QUEUES; j++) { + if (list_empty(&b->flows[j].flowchain)) { + arg->count++; + continue; + } + if (!tc_qdisc_stats_dump(sch, i * CAKE_QUEUES + j + 1, + arg)) + break; + } + } +} + +static const struct Qdisc_class_ops cake_class_ops = { + .leaf = cake_leaf, + .find = cake_find, + .tcf_block = cake_tcf_block, + .bind_tcf = cake_bind, + .unbind_tcf = cake_unbind, + .dump = cake_dump_class, + .dump_stats = cake_dump_class_stats, + .walk = cake_walk, +}; + +static struct Qdisc_ops cake_qdisc_ops __read_mostly = { + .cl_ops = &cake_class_ops, + .id = "cake", + .priv_size = sizeof(struct cake_sched_data), + .enqueue = cake_enqueue, + .dequeue = cake_dequeue, + .peek = qdisc_peek_dequeued, + .init = cake_init, + .reset = cake_reset, + .destroy = cake_destroy, + .change = cake_change, + .dump = cake_dump, + .dump_stats = cake_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init cake_module_init(void) +{ + return register_qdisc(&cake_qdisc_ops); +} + +static void __exit cake_module_exit(void) +{ + unregister_qdisc(&cake_qdisc_ops); +} + +module_init(cake_module_init) +module_exit(cake_module_exit) +MODULE_AUTHOR("Jonathan Morton"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_DESCRIPTION("The CAKE shaper."); diff --git a/net/sched/sch_cbq.c b/net/sched/sch_cbq.c new file mode 100644 index 000000000..36db5f678 --- /dev/null +++ b/net/sched/sch_cbq.c @@ -0,0 +1,1727 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_cbq.c Class-Based Queueing discipline. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +/* Class-Based Queueing (CBQ) algorithm. + ======================================= + + Sources: [1] Sally Floyd and Van Jacobson, "Link-sharing and Resource + Management Models for Packet Networks", + IEEE/ACM Transactions on Networking, Vol.3, No.4, 1995 + + [2] Sally Floyd, "Notes on CBQ and Guaranteed Service", 1995 + + [3] Sally Floyd, "Notes on Class-Based Queueing: Setting + Parameters", 1996 + + [4] Sally Floyd and Michael Speer, "Experimental Results + for Class-Based Queueing", 1998, not published. + + ----------------------------------------------------------------------- + + Algorithm skeleton was taken from NS simulator cbq.cc. + If someone wants to check this code against the LBL version, + he should take into account that ONLY the skeleton was borrowed, + the implementation is different. Particularly: + + --- The WRR algorithm is different. Our version looks more + reasonable (I hope) and works when quanta are allowed to be + less than MTU, which is always the case when real time classes + have small rates. Note, that the statement of [3] is + incomplete, delay may actually be estimated even if class + per-round allotment is less than MTU. Namely, if per-round + allotment is W*r_i, and r_1+...+r_k = r < 1 + + delay_i <= ([MTU/(W*r_i)]*W*r + W*r + k*MTU)/B + + In the worst case we have IntServ estimate with D = W*r+k*MTU + and C = MTU*r. The proof (if correct at all) is trivial. + + + --- It seems that cbq-2.0 is not very accurate. At least, I cannot + interpret some places, which look like wrong translations + from NS. Anyone is advised to find these differences + and explain to me, why I am wrong 8). + + --- Linux has no EOI event, so that we cannot estimate true class + idle time. 
Workaround is to consider the next dequeue event + as sign that previous packet is finished. This is wrong because of + internal device queueing, but on a permanently loaded link it is true. + Moreover, combined with clock integrator, this scheme looks + very close to an ideal solution. */ + +struct cbq_sched_data; + + +struct cbq_class { + struct Qdisc_class_common common; + struct cbq_class *next_alive; /* next class with backlog in this priority band */ + +/* Parameters */ + unsigned char priority; /* class priority */ + unsigned char priority2; /* priority to be used after overlimit */ + unsigned char ewma_log; /* time constant for idle time calculation */ + + u32 defmap; + + /* Link-sharing scheduler parameters */ + long maxidle; /* Class parameters: see below. */ + long offtime; + long minidle; + u32 avpkt; + struct qdisc_rate_table *R_tab; + + /* General scheduler (WRR) parameters */ + long allot; + long quantum; /* Allotment per WRR round */ + long weight; /* Relative allotment: see below */ + + struct Qdisc *qdisc; /* Ptr to CBQ discipline */ + struct cbq_class *split; /* Ptr to split node */ + struct cbq_class *share; /* Ptr to LS parent in the class tree */ + struct cbq_class *tparent; /* Ptr to tree parent in the class tree */ + struct cbq_class *borrow; /* NULL if class is bandwidth limited; + parent otherwise */ + struct cbq_class *sibling; /* Sibling chain */ + struct cbq_class *children; /* Pointer to children chain */ + + struct Qdisc *q; /* Elementary queueing discipline */ + + +/* Variables */ + unsigned char cpriority; /* Effective priority */ + unsigned char delayed; + unsigned char level; /* level of the class in hierarchy: + 0 for leaf classes, and maximal + level of children + 1 for nodes. + */ + + psched_time_t last; /* Last end of service */ + psched_time_t undertime; + long avgidle; + long deficit; /* Saved deficit for WRR */ + psched_time_t penalized; + struct gnet_stats_basic_sync bstats; + struct gnet_stats_queue qstats; + struct net_rate_estimator __rcu *rate_est; + struct tc_cbq_xstats xstats; + + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + + int filters; + + struct cbq_class *defaults[TC_PRIO_MAX + 1]; +}; + +struct cbq_sched_data { + struct Qdisc_class_hash clhash; /* Hash table of all classes */ + int nclasses[TC_CBQ_MAXPRIO + 1]; + unsigned int quanta[TC_CBQ_MAXPRIO + 1]; + + struct cbq_class link; + + unsigned int activemask; + struct cbq_class *active[TC_CBQ_MAXPRIO + 1]; /* List of all classes + with backlog */ + +#ifdef CONFIG_NET_CLS_ACT + struct cbq_class *rx_class; +#endif + struct cbq_class *tx_class; + struct cbq_class *tx_borrowed; + int tx_len; + psched_time_t now; /* Cached timestamp */ + unsigned int pmask; + + struct qdisc_watchdog watchdog; /* Watchdog timer, + started when CBQ has + backlog, but cannot + transmit just now */ + psched_tdiff_t wd_expires; + int toplevel; + u32 hgenerator; +}; + + +#define L2T(cl, len) qdisc_l2t((cl)->R_tab, len) + +static inline struct cbq_class * +cbq_class_lookup(struct cbq_sched_data *q, u32 classid) +{ + struct Qdisc_class_common *clc; + + clc = qdisc_class_find(&q->clhash, classid); + if (clc == NULL) + return NULL; + return container_of(clc, struct cbq_class, common); +} + +#ifdef CONFIG_NET_CLS_ACT + +static struct cbq_class * +cbq_reclassify(struct sk_buff *skb, struct cbq_class *this) +{ + struct cbq_class *cl; + + for (cl = this->tparent; cl; cl = cl->tparent) { + struct cbq_class *new = cl->defaults[TC_PRIO_BESTEFFORT]; + + if (new != NULL && new != this) + return new; + } + 
return NULL; +} + +#endif + +/* Classify packet. The procedure is pretty complicated, but + * it allows us to combine link sharing and priority scheduling + * transparently. + * + * Namely, you can put link sharing rules (f.e. route based) at root of CBQ, + * so that it resolves to split nodes. Then packets are classified + * by logical priority, or a more specific classifier may be attached + * to the split node. + */ + +static struct cbq_class * +cbq_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *head = &q->link; + struct cbq_class **defmap; + struct cbq_class *cl = NULL; + u32 prio = skb->priority; + struct tcf_proto *fl; + struct tcf_result res; + + /* + * Step 1. If skb->priority points to one of our classes, use it. + */ + if (TC_H_MAJ(prio ^ sch->handle) == 0 && + (cl = cbq_class_lookup(q, prio)) != NULL) + return cl; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + for (;;) { + int result = 0; + defmap = head->defaults; + + fl = rcu_dereference_bh(head->filter_list); + /* + * Step 2+n. Apply classifier. + */ + result = tcf_classify(skb, NULL, fl, &res, true); + if (!fl || result < 0) + goto fallback; + if (result == TC_ACT_SHOT) + return NULL; + + cl = (void *)res.class; + if (!cl) { + if (TC_H_MAJ(res.classid)) + cl = cbq_class_lookup(q, res.classid); + else if ((cl = defmap[res.classid & TC_PRIO_MAX]) == NULL) + cl = defmap[TC_PRIO_BESTEFFORT]; + + if (cl == NULL) + goto fallback; + } + if (cl->level >= head->level) + goto fallback; +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_RECLASSIFY: + return cbq_reclassify(skb, cl); + } +#endif + if (cl->level == 0) + return cl; + + /* + * Step 3+n. If classifier selected a link sharing class, + * apply agency specific classifier. + * Repeat this procedure until we hit a leaf node. + */ + head = cl; + } + +fallback: + cl = head; + + /* + * Step 4. No success... + */ + if (TC_H_MAJ(prio) == 0 && + !(cl = head->defaults[prio & TC_PRIO_MAX]) && + !(cl = head->defaults[TC_PRIO_BESTEFFORT])) + return head; + + return cl; +} + +/* + * A packet has just been enqueued on the empty class. + * cbq_activate_class adds it to the tail of active class list + * of its priority band. 
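Step 1 of cbq_classify() above matches skb->priority directly against this qdisc's classes using the standard 16:16 major:minor split of a tc handle (a priority of that form can be set from userspace, for instance via SO_PRIORITY). A minimal sketch of the test, with the two masks written out locally so it builds standalone:

    #include <stdint.h>
    #include <stdio.h>

    #define DEMO_TC_H_MAJ(h) ((h) & 0xFFFF0000U)   /* upper 16 bits: qdisc */
    #define DEMO_TC_H_MIN(h) ((h) & 0x0000FFFFU)   /* lower 16 bits: class */

    int main(void)
    {
        uint32_t handle = 0x00010000;   /* this qdisc: 1:0 */
        uint32_t prio   = 0x00010002;   /* skb->priority naming class 1:2 */

        if (DEMO_TC_H_MAJ(prio ^ handle) == 0)
            printf("direct hit on class 1:%x\n", (unsigned int)DEMO_TC_H_MIN(prio));
        return 0;
    }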
+ */ + +static inline void cbq_activate_class(struct cbq_class *cl) +{ + struct cbq_sched_data *q = qdisc_priv(cl->qdisc); + int prio = cl->cpriority; + struct cbq_class *cl_tail; + + cl_tail = q->active[prio]; + q->active[prio] = cl; + + if (cl_tail != NULL) { + cl->next_alive = cl_tail->next_alive; + cl_tail->next_alive = cl; + } else { + cl->next_alive = cl; + q->activemask |= (1<qdisc); + int prio = this->cpriority; + struct cbq_class *cl; + struct cbq_class *cl_prev = q->active[prio]; + + do { + cl = cl_prev->next_alive; + if (cl == this) { + cl_prev->next_alive = cl->next_alive; + cl->next_alive = NULL; + + if (cl == q->active[prio]) { + q->active[prio] = cl_prev; + if (cl == q->active[prio]) { + q->active[prio] = NULL; + q->activemask &= ~(1<active[prio]); +} + +static void +cbq_mark_toplevel(struct cbq_sched_data *q, struct cbq_class *cl) +{ + int toplevel = q->toplevel; + + if (toplevel > cl->level) { + psched_time_t now = psched_get_time(); + + do { + if (cl->undertime < now) { + q->toplevel = cl->level; + return; + } + } while ((cl = cl->borrow) != NULL && toplevel > cl->level); + } +} + +static int +cbq_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + int ret; + struct cbq_class *cl = cbq_classify(skb, sch, &ret); + +#ifdef CONFIG_NET_CLS_ACT + q->rx_class = cl; +#endif + if (cl == NULL) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } + + ret = qdisc_enqueue(skb, cl->q, to_free); + if (ret == NET_XMIT_SUCCESS) { + sch->q.qlen++; + cbq_mark_toplevel(q, cl); + if (!cl->next_alive) + cbq_activate_class(cl); + return ret; + } + + if (net_xmit_drop_count(ret)) { + qdisc_qstats_drop(sch); + cbq_mark_toplevel(q, cl); + cl->qstats.drops++; + } + return ret; +} + +/* Overlimit action: penalize leaf class by adding offtime */ +static void cbq_overlimit(struct cbq_class *cl) +{ + struct cbq_sched_data *q = qdisc_priv(cl->qdisc); + psched_tdiff_t delay = cl->undertime - q->now; + + if (!cl->delayed) { + delay += cl->offtime; + + /* + * Class goes to sleep, so that it will have no + * chance to work avgidle. Let's forgive it 8) + * + * BTW cbq-2.0 has a crap in this + * place, apparently they forgot to shift it by cl->ewma_log. + */ + if (cl->avgidle < 0) + delay -= (-cl->avgidle) - ((-cl->avgidle) >> cl->ewma_log); + if (cl->avgidle < cl->minidle) + cl->avgidle = cl->minidle; + if (delay <= 0) + delay = 1; + cl->undertime = q->now + delay; + + cl->xstats.overactions++; + cl->delayed = 1; + } + if (q->wd_expires == 0 || q->wd_expires > delay) + q->wd_expires = delay; + + /* Dirty work! We must schedule wakeups based on + * real available rate, rather than leaf rate, + * which may be tiny (even zero). + */ + if (q->toplevel == TC_CBQ_MAXLEVEL) { + struct cbq_class *b; + psched_tdiff_t base_delay = q->wd_expires; + + for (b = cl->borrow; b; b = b->borrow) { + delay = b->undertime - q->now; + if (delay < base_delay) { + if (delay <= 0) + delay = 1; + base_delay = delay; + } + } + + q->wd_expires = base_delay; + } +} + +/* + * It is mission critical procedure. + * + * We "regenerate" toplevel cutoff, if transmitting class + * has backlog and it is not regulated. It is not part of + * original CBQ description, but looks more reasonable. + * Probably, it is wrong. This question needs further investigation. 
+ */ + +static inline void +cbq_update_toplevel(struct cbq_sched_data *q, struct cbq_class *cl, + struct cbq_class *borrowed) +{ + if (cl && q->toplevel >= borrowed->level) { + if (cl->q->q.qlen > 1) { + do { + if (borrowed->undertime == PSCHED_PASTPERFECT) { + q->toplevel = borrowed->level; + return; + } + } while ((borrowed = borrowed->borrow) != NULL); + } +#if 0 + /* It is not necessary now. Uncommenting it + will save CPU cycles, but decrease fairness. + */ + q->toplevel = TC_CBQ_MAXLEVEL; +#endif + } +} + +static void +cbq_update(struct cbq_sched_data *q) +{ + struct cbq_class *this = q->tx_class; + struct cbq_class *cl = this; + int len = q->tx_len; + psched_time_t now; + + q->tx_class = NULL; + /* Time integrator. We calculate EOS time + * by adding expected packet transmission time. + */ + now = q->now + L2T(&q->link, len); + + for ( ; cl; cl = cl->share) { + long avgidle = cl->avgidle; + long idle; + + _bstats_update(&cl->bstats, len, 1); + + /* + * (now - last) is total time between packet right edges. + * (last_pktlen/rate) is "virtual" busy time, so that + * + * idle = (now - last) - last_pktlen/rate + */ + + idle = now - cl->last; + if ((unsigned long)idle > 128*1024*1024) { + avgidle = cl->maxidle; + } else { + idle -= L2T(cl, len); + + /* true_avgidle := (1-W)*true_avgidle + W*idle, + * where W=2^{-ewma_log}. But cl->avgidle is scaled: + * cl->avgidle == true_avgidle/W, + * hence: + */ + avgidle += idle - (avgidle>>cl->ewma_log); + } + + if (avgidle <= 0) { + /* Overlimit or at-limit */ + + if (avgidle < cl->minidle) + avgidle = cl->minidle; + + cl->avgidle = avgidle; + + /* Calculate expected time, when this class + * will be allowed to send. + * It will occur, when: + * (1-W)*true_avgidle + W*delay = 0, i.e. + * idle = (1/W - 1)*(-true_avgidle) + * or + * idle = (1 - W)*(-cl->avgidle); + */ + idle = (-avgidle) - ((-avgidle) >> cl->ewma_log); + + /* + * That is not all. + * To maintain the rate allocated to the class, + * we add to undertime virtual clock, + * necessary to complete transmitted packet. + * (len/phys_bandwidth has been already passed + * to the moment of cbq_update) + */ + + idle -= L2T(&q->link, len); + idle += L2T(cl, len); + + cl->undertime = now + idle; + } else { + /* Underlimit */ + + cl->undertime = PSCHED_PASTPERFECT; + if (avgidle > cl->maxidle) + cl->avgidle = cl->maxidle; + else + cl->avgidle = avgidle; + } + if ((s64)(now - cl->last) > 0) + cl->last = now; + } + + cbq_update_toplevel(q, this, q->tx_borrowed); +} + +static inline struct cbq_class * +cbq_under_limit(struct cbq_class *cl) +{ + struct cbq_sched_data *q = qdisc_priv(cl->qdisc); + struct cbq_class *this_cl = cl; + + if (cl->tparent == NULL) + return cl; + + if (cl->undertime == PSCHED_PASTPERFECT || q->now >= cl->undertime) { + cl->delayed = 0; + return cl; + } + + do { + /* It is very suspicious place. Now overlimit + * action is generated for not bounded classes + * only if link is completely congested. + * Though it is in agree with ancestor-only paradigm, + * it looks very stupid. Particularly, + * it means that this chunk of code will either + * never be called or result in strong amplification + * of burstiness. Dangerous, silly, and, however, + * no another solution exists. 
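The scaled-EWMA identity used in cbq_update() above (cl->avgidle == true_avgidle / W with W = 2^-ewma_log, so the update avgidle += idle - (avgidle >> ewma_log) is exactly true_avgidle = (1-W)*true_avgidle + W*idle) can be checked numerically; a small standalone example:

    #include <stdio.h>

    int main(void)
    {
        const int ewma_log = 5;   /* W = 1/32 */
        long avgidle = 3200;      /* scaled form of true_avgidle = 100 */
        long idle = 64;

        avgidle += idle - (avgidle >> ewma_log);   /* 3200 + 64 - 100 = 3164 */

        /* unscaled: (1 - 1/32) * 100 + (1/32) * 64 = 98.875 == 3164 / 32 */
        printf("%ld scaled, %.3f true\n", avgidle, avgidle / 32.0);
        return 0;
    }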
+ */ + cl = cl->borrow; + if (!cl) { + this_cl->qstats.overlimits++; + cbq_overlimit(this_cl); + return NULL; + } + if (cl->level > q->toplevel) + return NULL; + } while (cl->undertime != PSCHED_PASTPERFECT && q->now < cl->undertime); + + cl->delayed = 0; + return cl; +} + +static inline struct sk_buff * +cbq_dequeue_prio(struct Qdisc *sch, int prio) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl_tail, *cl_prev, *cl; + struct sk_buff *skb; + int deficit; + + cl_tail = cl_prev = q->active[prio]; + cl = cl_prev->next_alive; + + do { + deficit = 0; + + /* Start round */ + do { + struct cbq_class *borrow = cl; + + if (cl->q->q.qlen && + (borrow = cbq_under_limit(cl)) == NULL) + goto skip_class; + + if (cl->deficit <= 0) { + /* Class exhausted its allotment per + * this round. Switch to the next one. + */ + deficit = 1; + cl->deficit += cl->quantum; + goto next_class; + } + + skb = cl->q->dequeue(cl->q); + + /* Class did not give us any skb :-( + * It could occur even if cl->q->q.qlen != 0 + * f.e. if cl->q == "tbf" + */ + if (skb == NULL) + goto skip_class; + + cl->deficit -= qdisc_pkt_len(skb); + q->tx_class = cl; + q->tx_borrowed = borrow; + if (borrow != cl) { +#ifndef CBQ_XSTATS_BORROWS_BYTES + borrow->xstats.borrows++; + cl->xstats.borrows++; +#else + borrow->xstats.borrows += qdisc_pkt_len(skb); + cl->xstats.borrows += qdisc_pkt_len(skb); +#endif + } + q->tx_len = qdisc_pkt_len(skb); + + if (cl->deficit <= 0) { + q->active[prio] = cl; + cl = cl->next_alive; + cl->deficit += cl->quantum; + } + return skb; + +skip_class: + if (cl->q->q.qlen == 0 || prio != cl->cpriority) { + /* Class is empty or penalized. + * Unlink it from active chain. + */ + cl_prev->next_alive = cl->next_alive; + cl->next_alive = NULL; + + /* Did cl_tail point to it? */ + if (cl == cl_tail) { + /* Repair it! */ + cl_tail = cl_prev; + + /* Was it the last class in this band? */ + if (cl == cl_tail) { + /* Kill the band! */ + q->active[prio] = NULL; + q->activemask &= ~(1<q->q.qlen) + cbq_activate_class(cl); + return NULL; + } + + q->active[prio] = cl_tail; + } + if (cl->q->q.qlen) + cbq_activate_class(cl); + + cl = cl_prev; + } + +next_class: + cl_prev = cl; + cl = cl->next_alive; + } while (cl_prev != cl_tail); + } while (deficit); + + q->active[prio] = cl_prev; + + return NULL; +} + +static inline struct sk_buff * +cbq_dequeue_1(struct Qdisc *sch) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + unsigned int activemask; + + activemask = q->activemask & 0xFF; + while (activemask) { + int prio = ffz(~activemask); + activemask &= ~(1<tx_class) + cbq_update(q); + + q->now = now; + + for (;;) { + q->wd_expires = 0; + + skb = cbq_dequeue_1(sch); + if (skb) { + qdisc_bstats_update(sch, skb); + sch->q.qlen--; + return skb; + } + + /* All the classes are overlimit. + * + * It is possible, if: + * + * 1. Scheduler is empty. + * 2. Toplevel cutoff inhibited borrowing. + * 3. Root class is overlimit. + * + * Reset 2d and 3d conditions and retry. + * + * Note, that NS and cbq-2.0 are buggy, peeking + * an arbitrary class is appropriate for ancestor-only + * sharing, but not for toplevel algorithm. + * + * Our version is better, but slower, because it requires + * two passes, but it is unavoidable with top-level sharing. + */ + + if (q->toplevel == TC_CBQ_MAXLEVEL && + q->link.undertime == PSCHED_PASTPERFECT) + break; + + q->toplevel = TC_CBQ_MAXLEVEL; + q->link.undertime = PSCHED_PASTPERFECT; + } + + /* No packets in scheduler or nobody wants to give them to us :-( + * Sigh... 
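The inner loop of cbq_dequeue_prio() above is a deficit round-robin: a class pays for each packet out of its deficit, and only once the deficit has gone non-positive is it recharged by one quantum and skipped for that visit. Letting the deficit go negative is what makes quanta smaller than the MTU workable, as the header comment notes. A minimal sketch of that accounting (illustrative only, not kernel API):

    #include <stdio.h>

    struct demo_class {
        long deficit;
        long quantum;
    };

    /* Bytes the class may send on this visit; recharge when exhausted. */
    static long demo_visit(struct demo_class *cl, long pkt_len)
    {
        if (cl->deficit <= 0) {
            cl->deficit += cl->quantum;   /* skip this visit */
            return 0;
        }
        cl->deficit -= pkt_len;           /* may go negative; repaid later */
        return pkt_len;
    }

    int main(void)
    {
        struct demo_class cl = { .deficit = 1000, .quantum = 1000 };
        int i;

        for (i = 0; i < 3; i++) {
            long sent = demo_visit(&cl, 1514);

            printf("visit %d: sent %ld, deficit now %ld\n", i, sent, cl.deficit);
        }
        return 0;   /* sends, skips, sends: 1514 / 0 / 1514 */
    }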
start watchdog timer in the last case. + */ + + if (sch->q.qlen) { + qdisc_qstats_overlimit(sch); + if (q->wd_expires) + qdisc_watchdog_schedule(&q->watchdog, + now + q->wd_expires); + } + return NULL; +} + +/* CBQ class maintenance routines */ + +static void cbq_adjust_levels(struct cbq_class *this) +{ + if (this == NULL) + return; + + do { + int level = 0; + struct cbq_class *cl; + + cl = this->children; + if (cl) { + do { + if (cl->level > level) + level = cl->level; + } while ((cl = cl->sibling) != this->children); + } + this->level = level + 1; + } while ((this = this->tparent) != NULL); +} + +static void cbq_normalize_quanta(struct cbq_sched_data *q, int prio) +{ + struct cbq_class *cl; + unsigned int h; + + if (q->quanta[prio] == 0) + return; + + for (h = 0; h < q->clhash.hashsize; h++) { + hlist_for_each_entry(cl, &q->clhash.hash[h], common.hnode) { + /* BUGGGG... Beware! This expression suffer of + * arithmetic overflows! + */ + if (cl->priority == prio) { + cl->quantum = (cl->weight*cl->allot*q->nclasses[prio])/ + q->quanta[prio]; + } + if (cl->quantum <= 0 || + cl->quantum > 32*qdisc_dev(cl->qdisc)->mtu) { + pr_warn("CBQ: class %08x has bad quantum==%ld, repaired.\n", + cl->common.classid, cl->quantum); + cl->quantum = qdisc_dev(cl->qdisc)->mtu/2 + 1; + } + } + } +} + +static void cbq_sync_defmap(struct cbq_class *cl) +{ + struct cbq_sched_data *q = qdisc_priv(cl->qdisc); + struct cbq_class *split = cl->split; + unsigned int h; + int i; + + if (split == NULL) + return; + + for (i = 0; i <= TC_PRIO_MAX; i++) { + if (split->defaults[i] == cl && !(cl->defmap & (1<defaults[i] = NULL; + } + + for (i = 0; i <= TC_PRIO_MAX; i++) { + int level = split->level; + + if (split->defaults[i]) + continue; + + for (h = 0; h < q->clhash.hashsize; h++) { + struct cbq_class *c; + + hlist_for_each_entry(c, &q->clhash.hash[h], + common.hnode) { + if (c->split == split && c->level < level && + c->defmap & (1<defaults[i] = c; + level = c->level; + } + } + } + } +} + +static void cbq_change_defmap(struct cbq_class *cl, u32 splitid, u32 def, u32 mask) +{ + struct cbq_class *split = NULL; + + if (splitid == 0) { + split = cl->split; + if (!split) + return; + splitid = split->common.classid; + } + + if (split == NULL || split->common.classid != splitid) { + for (split = cl->tparent; split; split = split->tparent) + if (split->common.classid == splitid) + break; + } + + if (split == NULL) + return; + + if (cl->split != split) { + cl->defmap = 0; + cbq_sync_defmap(cl); + cl->split = split; + cl->defmap = def & mask; + } else + cl->defmap = (cl->defmap & ~mask) | (def & mask); + + cbq_sync_defmap(cl); +} + +static void cbq_unlink_class(struct cbq_class *this) +{ + struct cbq_class *cl, **clp; + struct cbq_sched_data *q = qdisc_priv(this->qdisc); + + qdisc_class_hash_remove(&q->clhash, &this->common); + + if (this->tparent) { + clp = &this->sibling; + cl = *clp; + do { + if (cl == this) { + *clp = cl->sibling; + break; + } + clp = &cl->sibling; + } while ((cl = *clp) != this->sibling); + + if (this->tparent->children == this) { + this->tparent->children = this->sibling; + if (this->sibling == this) + this->tparent->children = NULL; + } + } else { + WARN_ON(this->sibling != this); + } +} + +static void cbq_link_class(struct cbq_class *this) +{ + struct cbq_sched_data *q = qdisc_priv(this->qdisc); + struct cbq_class *parent = this->tparent; + + this->sibling = this; + qdisc_class_hash_insert(&q->clhash, &this->common); + + if (parent == NULL) + return; + + if (parent->children == NULL) { + parent->children = 
this; + } else { + this->sibling = parent->children->sibling; + parent->children->sibling = this; + } +} + +static void +cbq_reset(struct Qdisc *sch) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl; + int prio; + unsigned int h; + + q->activemask = 0; + q->pmask = 0; + q->tx_class = NULL; + q->tx_borrowed = NULL; + qdisc_watchdog_cancel(&q->watchdog); + q->toplevel = TC_CBQ_MAXLEVEL; + q->now = psched_get_time(); + + for (prio = 0; prio <= TC_CBQ_MAXPRIO; prio++) + q->active[prio] = NULL; + + for (h = 0; h < q->clhash.hashsize; h++) { + hlist_for_each_entry(cl, &q->clhash.hash[h], common.hnode) { + qdisc_reset(cl->q); + + cl->next_alive = NULL; + cl->undertime = PSCHED_PASTPERFECT; + cl->avgidle = cl->maxidle; + cl->deficit = cl->quantum; + cl->cpriority = cl->priority; + } + } +} + + +static void cbq_set_lss(struct cbq_class *cl, struct tc_cbq_lssopt *lss) +{ + if (lss->change & TCF_CBQ_LSS_FLAGS) { + cl->share = (lss->flags & TCF_CBQ_LSS_ISOLATED) ? NULL : cl->tparent; + cl->borrow = (lss->flags & TCF_CBQ_LSS_BOUNDED) ? NULL : cl->tparent; + } + if (lss->change & TCF_CBQ_LSS_EWMA) + cl->ewma_log = lss->ewma_log; + if (lss->change & TCF_CBQ_LSS_AVPKT) + cl->avpkt = lss->avpkt; + if (lss->change & TCF_CBQ_LSS_MINIDLE) + cl->minidle = -(long)lss->minidle; + if (lss->change & TCF_CBQ_LSS_MAXIDLE) { + cl->maxidle = lss->maxidle; + cl->avgidle = lss->maxidle; + } + if (lss->change & TCF_CBQ_LSS_OFFTIME) + cl->offtime = lss->offtime; +} + +static void cbq_rmprio(struct cbq_sched_data *q, struct cbq_class *cl) +{ + q->nclasses[cl->priority]--; + q->quanta[cl->priority] -= cl->weight; + cbq_normalize_quanta(q, cl->priority); +} + +static void cbq_addprio(struct cbq_sched_data *q, struct cbq_class *cl) +{ + q->nclasses[cl->priority]++; + q->quanta[cl->priority] += cl->weight; + cbq_normalize_quanta(q, cl->priority); +} + +static int cbq_set_wrr(struct cbq_class *cl, struct tc_cbq_wrropt *wrr) +{ + struct cbq_sched_data *q = qdisc_priv(cl->qdisc); + + if (wrr->allot) + cl->allot = wrr->allot; + if (wrr->weight) + cl->weight = wrr->weight; + if (wrr->priority) { + cl->priority = wrr->priority - 1; + cl->cpriority = cl->priority; + if (cl->priority >= cl->priority2) + cl->priority2 = TC_CBQ_MAXPRIO - 1; + } + + cbq_addprio(q, cl); + return 0; +} + +static int cbq_set_fopt(struct cbq_class *cl, struct tc_cbq_fopt *fopt) +{ + cbq_change_defmap(cl, fopt->split, fopt->defmap, fopt->defchange); + return 0; +} + +static const struct nla_policy cbq_policy[TCA_CBQ_MAX + 1] = { + [TCA_CBQ_LSSOPT] = { .len = sizeof(struct tc_cbq_lssopt) }, + [TCA_CBQ_WRROPT] = { .len = sizeof(struct tc_cbq_wrropt) }, + [TCA_CBQ_FOPT] = { .len = sizeof(struct tc_cbq_fopt) }, + [TCA_CBQ_OVL_STRATEGY] = { .len = sizeof(struct tc_cbq_ovl) }, + [TCA_CBQ_RATE] = { .len = sizeof(struct tc_ratespec) }, + [TCA_CBQ_RTAB] = { .type = NLA_BINARY, .len = TC_RTAB_SIZE }, + [TCA_CBQ_POLICE] = { .len = sizeof(struct tc_cbq_police) }, +}; + +static int cbq_opt_parse(struct nlattr *tb[TCA_CBQ_MAX + 1], + struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + int err; + + if (!opt) { + NL_SET_ERR_MSG(extack, "CBQ options are required for this operation"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_CBQ_MAX, opt, + cbq_policy, extack); + if (err < 0) + return err; + + if (tb[TCA_CBQ_WRROPT]) { + const struct tc_cbq_wrropt *wrr = nla_data(tb[TCA_CBQ_WRROPT]); + + if (wrr->priority > TC_CBQ_MAXPRIO) { + NL_SET_ERR_MSG(extack, "priority is bigger than TC_CBQ_MAXPRIO"); + err = -EINVAL; + } + } + 
return err; +} + +static int cbq_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_CBQ_MAX + 1]; + struct tc_ratespec *r; + int err; + + qdisc_watchdog_init(&q->watchdog, sch); + + err = cbq_opt_parse(tb, opt, extack); + if (err < 0) + return err; + + if (!tb[TCA_CBQ_RTAB] || !tb[TCA_CBQ_RATE]) { + NL_SET_ERR_MSG(extack, "Rate specification missing or incomplete"); + return -EINVAL; + } + + r = nla_data(tb[TCA_CBQ_RATE]); + + q->link.R_tab = qdisc_get_rtab(r, tb[TCA_CBQ_RTAB], extack); + if (!q->link.R_tab) + return -EINVAL; + + err = tcf_block_get(&q->link.block, &q->link.filter_list, sch, extack); + if (err) + goto put_rtab; + + err = qdisc_class_hash_init(&q->clhash); + if (err < 0) + goto put_block; + + q->link.sibling = &q->link; + q->link.common.classid = sch->handle; + q->link.qdisc = sch; + q->link.q = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + sch->handle, NULL); + if (!q->link.q) + q->link.q = &noop_qdisc; + else + qdisc_hash_add(q->link.q, true); + + q->link.priority = TC_CBQ_MAXPRIO - 1; + q->link.priority2 = TC_CBQ_MAXPRIO - 1; + q->link.cpriority = TC_CBQ_MAXPRIO - 1; + q->link.allot = psched_mtu(qdisc_dev(sch)); + q->link.quantum = q->link.allot; + q->link.weight = q->link.R_tab->rate.rate; + + q->link.ewma_log = TC_CBQ_DEF_EWMA; + q->link.avpkt = q->link.allot/2; + q->link.minidle = -0x7FFFFFFF; + + q->toplevel = TC_CBQ_MAXLEVEL; + q->now = psched_get_time(); + + cbq_link_class(&q->link); + + if (tb[TCA_CBQ_LSSOPT]) + cbq_set_lss(&q->link, nla_data(tb[TCA_CBQ_LSSOPT])); + + cbq_addprio(q, &q->link); + return 0; + +put_block: + tcf_block_put(q->link.block); + +put_rtab: + qdisc_put_rtab(q->link.R_tab); + return err; +} + +static int cbq_dump_rate(struct sk_buff *skb, struct cbq_class *cl) +{ + unsigned char *b = skb_tail_pointer(skb); + + if (nla_put(skb, TCA_CBQ_RATE, sizeof(cl->R_tab->rate), &cl->R_tab->rate)) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int cbq_dump_lss(struct sk_buff *skb, struct cbq_class *cl) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tc_cbq_lssopt opt; + + opt.flags = 0; + if (cl->borrow == NULL) + opt.flags |= TCF_CBQ_LSS_BOUNDED; + if (cl->share == NULL) + opt.flags |= TCF_CBQ_LSS_ISOLATED; + opt.ewma_log = cl->ewma_log; + opt.level = cl->level; + opt.avpkt = cl->avpkt; + opt.maxidle = cl->maxidle; + opt.minidle = (u32)(-cl->minidle); + opt.offtime = cl->offtime; + opt.change = ~0; + if (nla_put(skb, TCA_CBQ_LSSOPT, sizeof(opt), &opt)) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int cbq_dump_wrr(struct sk_buff *skb, struct cbq_class *cl) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tc_cbq_wrropt opt; + + memset(&opt, 0, sizeof(opt)); + opt.flags = 0; + opt.allot = cl->allot; + opt.priority = cl->priority + 1; + opt.cpriority = cl->cpriority + 1; + opt.weight = cl->weight; + if (nla_put(skb, TCA_CBQ_WRROPT, sizeof(opt), &opt)) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int cbq_dump_fopt(struct sk_buff *skb, struct cbq_class *cl) +{ + unsigned char *b = skb_tail_pointer(skb); + struct tc_cbq_fopt opt; + + if (cl->split || cl->defmap) { + opt.split = cl->split ? 
cl->split->common.classid : 0; + opt.defmap = cl->defmap; + opt.defchange = ~0; + if (nla_put(skb, TCA_CBQ_FOPT, sizeof(opt), &opt)) + goto nla_put_failure; + } + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int cbq_dump_attr(struct sk_buff *skb, struct cbq_class *cl) +{ + if (cbq_dump_lss(skb, cl) < 0 || + cbq_dump_rate(skb, cl) < 0 || + cbq_dump_wrr(skb, cl) < 0 || + cbq_dump_fopt(skb, cl) < 0) + return -1; + return 0; +} + +static int cbq_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (cbq_dump_attr(skb, &q->link) < 0) + goto nla_put_failure; + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int +cbq_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + + q->link.xstats.avgidle = q->link.avgidle; + return gnet_stats_copy_app(d, &q->link.xstats, sizeof(q->link.xstats)); +} + +static int +cbq_dump_class(struct Qdisc *sch, unsigned long arg, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct cbq_class *cl = (struct cbq_class *)arg; + struct nlattr *nest; + + if (cl->tparent) + tcm->tcm_parent = cl->tparent->common.classid; + else + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle = cl->common.classid; + tcm->tcm_info = cl->q->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (cbq_dump_attr(skb, cl) < 0) + goto nla_put_failure; + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int +cbq_dump_class_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl = (struct cbq_class *)arg; + __u32 qlen; + + cl->xstats.avgidle = cl->avgidle; + cl->xstats.undertime = 0; + qdisc_qstats_qlen_backlog(cl->q, &qlen, &cl->qstats.backlog); + + if (cl->undertime != PSCHED_PASTPERFECT) + cl->xstats.undertime = cl->undertime - q->now; + + if (gnet_stats_copy_basic(d, NULL, &cl->bstats, true) < 0 || + gnet_stats_copy_rate_est(d, &cl->rate_est) < 0 || + gnet_stats_copy_queue(d, NULL, &cl->qstats, qlen) < 0) + return -1; + + return gnet_stats_copy_app(d, &cl->xstats, sizeof(cl->xstats)); +} + +static int cbq_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct cbq_class *cl = (struct cbq_class *)arg; + + if (new == NULL) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + cl->common.classid, extack); + if (new == NULL) + return -ENOBUFS; + } + + *old = qdisc_replace(sch, new, &cl->q); + return 0; +} + +static struct Qdisc *cbq_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct cbq_class *cl = (struct cbq_class *)arg; + + return cl->q; +} + +static void cbq_qlen_notify(struct Qdisc *sch, unsigned long arg) +{ + struct cbq_class *cl = (struct cbq_class *)arg; + + cbq_deactivate_class(cl); +} + +static unsigned long cbq_find(struct Qdisc *sch, u32 classid) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + + return (unsigned long)cbq_class_lookup(q, classid); +} + +static void cbq_destroy_class(struct Qdisc *sch, struct cbq_class *cl) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + + WARN_ON(cl->filters); + + tcf_block_put(cl->block); + qdisc_put(cl->q); + qdisc_put_rtab(cl->R_tab); + 
gen_kill_estimator(&cl->rate_est); + if (cl != &q->link) + kfree(cl); +} + +static void cbq_destroy(struct Qdisc *sch) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct hlist_node *next; + struct cbq_class *cl; + unsigned int h; + +#ifdef CONFIG_NET_CLS_ACT + q->rx_class = NULL; +#endif + /* + * Filters must be destroyed first because we don't destroy the + * classes from root to leafs which means that filters can still + * be bound to classes which have been destroyed already. --TGR '04 + */ + for (h = 0; h < q->clhash.hashsize; h++) { + hlist_for_each_entry(cl, &q->clhash.hash[h], common.hnode) { + tcf_block_put(cl->block); + cl->block = NULL; + } + } + for (h = 0; h < q->clhash.hashsize; h++) { + hlist_for_each_entry_safe(cl, next, &q->clhash.hash[h], + common.hnode) + cbq_destroy_class(sch, cl); + } + qdisc_class_hash_destroy(&q->clhash); +} + +static int +cbq_change_class(struct Qdisc *sch, u32 classid, u32 parentid, struct nlattr **tca, + unsigned long *arg, struct netlink_ext_ack *extack) +{ + int err; + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl = (struct cbq_class *)*arg; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_CBQ_MAX + 1]; + struct cbq_class *parent; + struct qdisc_rate_table *rtab = NULL; + + err = cbq_opt_parse(tb, opt, extack); + if (err < 0) + return err; + + if (tb[TCA_CBQ_OVL_STRATEGY] || tb[TCA_CBQ_POLICE]) { + NL_SET_ERR_MSG(extack, "Neither overlimit strategy nor policing attributes can be used for changing class params"); + return -EOPNOTSUPP; + } + + if (cl) { + /* Check parent */ + if (parentid) { + if (cl->tparent && + cl->tparent->common.classid != parentid) { + NL_SET_ERR_MSG(extack, "Invalid parent id"); + return -EINVAL; + } + if (!cl->tparent && parentid != TC_H_ROOT) { + NL_SET_ERR_MSG(extack, "Parent must be root"); + return -EINVAL; + } + } + + if (tb[TCA_CBQ_RATE]) { + rtab = qdisc_get_rtab(nla_data(tb[TCA_CBQ_RATE]), + tb[TCA_CBQ_RTAB], extack); + if (rtab == NULL) + return -EINVAL; + } + + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to replace specified rate estimator"); + qdisc_put_rtab(rtab); + return err; + } + } + + /* Change class parameters */ + sch_tree_lock(sch); + + if (cl->next_alive != NULL) + cbq_deactivate_class(cl); + + if (rtab) { + qdisc_put_rtab(cl->R_tab); + cl->R_tab = rtab; + } + + if (tb[TCA_CBQ_LSSOPT]) + cbq_set_lss(cl, nla_data(tb[TCA_CBQ_LSSOPT])); + + if (tb[TCA_CBQ_WRROPT]) { + cbq_rmprio(q, cl); + cbq_set_wrr(cl, nla_data(tb[TCA_CBQ_WRROPT])); + } + + if (tb[TCA_CBQ_FOPT]) + cbq_set_fopt(cl, nla_data(tb[TCA_CBQ_FOPT])); + + if (cl->q->q.qlen) + cbq_activate_class(cl); + + sch_tree_unlock(sch); + + return 0; + } + + if (parentid == TC_H_ROOT) + return -EINVAL; + + if (!tb[TCA_CBQ_WRROPT] || !tb[TCA_CBQ_RATE] || !tb[TCA_CBQ_LSSOPT]) { + NL_SET_ERR_MSG(extack, "One of the following attributes MUST be specified: WRR, rate or link sharing"); + return -EINVAL; + } + + rtab = qdisc_get_rtab(nla_data(tb[TCA_CBQ_RATE]), tb[TCA_CBQ_RTAB], + extack); + if (rtab == NULL) + return -EINVAL; + + if (classid) { + err = -EINVAL; + if (TC_H_MAJ(classid ^ sch->handle) || + cbq_class_lookup(q, classid)) { + NL_SET_ERR_MSG(extack, "Specified class not found"); + goto failure; + } + } else { + int i; + classid = TC_H_MAKE(sch->handle, 0x8000); + + for (i = 0; i < 0x8000; i++) { + if (++q->hgenerator >= 0x8000) + q->hgenerator = 1; + if (cbq_class_lookup(q, 
classid|q->hgenerator) == NULL) + break; + } + err = -ENOSR; + if (i >= 0x8000) { + NL_SET_ERR_MSG(extack, "Unable to generate classid"); + goto failure; + } + classid = classid|q->hgenerator; + } + + parent = &q->link; + if (parentid) { + parent = cbq_class_lookup(q, parentid); + err = -EINVAL; + if (!parent) { + NL_SET_ERR_MSG(extack, "Failed to find parentid"); + goto failure; + } + } + + err = -ENOBUFS; + cl = kzalloc(sizeof(*cl), GFP_KERNEL); + if (cl == NULL) + goto failure; + + gnet_stats_basic_sync_init(&cl->bstats); + err = tcf_block_get(&cl->block, &cl->filter_list, sch, extack); + if (err) { + kfree(cl); + goto failure; + } + + if (tca[TCA_RATE]) { + err = gen_new_estimator(&cl->bstats, NULL, &cl->rate_est, + NULL, true, tca[TCA_RATE]); + if (err) { + NL_SET_ERR_MSG(extack, "Couldn't create new estimator"); + tcf_block_put(cl->block); + kfree(cl); + goto failure; + } + } + + cl->R_tab = rtab; + rtab = NULL; + cl->q = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, classid, + NULL); + if (!cl->q) + cl->q = &noop_qdisc; + else + qdisc_hash_add(cl->q, true); + + cl->common.classid = classid; + cl->tparent = parent; + cl->qdisc = sch; + cl->allot = parent->allot; + cl->quantum = cl->allot; + cl->weight = cl->R_tab->rate.rate; + + sch_tree_lock(sch); + cbq_link_class(cl); + cl->borrow = cl->tparent; + if (cl->tparent != &q->link) + cl->share = cl->tparent; + cbq_adjust_levels(parent); + cl->minidle = -0x7FFFFFFF; + cbq_set_lss(cl, nla_data(tb[TCA_CBQ_LSSOPT])); + cbq_set_wrr(cl, nla_data(tb[TCA_CBQ_WRROPT])); + if (cl->ewma_log == 0) + cl->ewma_log = q->link.ewma_log; + if (cl->maxidle == 0) + cl->maxidle = q->link.maxidle; + if (cl->avpkt == 0) + cl->avpkt = q->link.avpkt; + if (tb[TCA_CBQ_FOPT]) + cbq_set_fopt(cl, nla_data(tb[TCA_CBQ_FOPT])); + sch_tree_unlock(sch); + + qdisc_class_hash_grow(sch, &q->clhash); + + *arg = (unsigned long)cl; + return 0; + +failure: + qdisc_put_rtab(rtab); + return err; +} + +static int cbq_delete(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl = (struct cbq_class *)arg; + + if (cl->filters || cl->children || cl == &q->link) + return -EBUSY; + + sch_tree_lock(sch); + + qdisc_purge_queue(cl->q); + + if (cl->next_alive) + cbq_deactivate_class(cl); + + if (q->tx_borrowed == cl) + q->tx_borrowed = q->tx_class; + if (q->tx_class == cl) { + q->tx_class = NULL; + q->tx_borrowed = NULL; + } +#ifdef CONFIG_NET_CLS_ACT + if (q->rx_class == cl) + q->rx_class = NULL; +#endif + + cbq_unlink_class(cl); + cbq_adjust_levels(cl->tparent); + cl->defmap = 0; + cbq_sync_defmap(cl); + + cbq_rmprio(q, cl); + sch_tree_unlock(sch); + + cbq_destroy_class(sch, cl); + return 0; +} + +static struct tcf_block *cbq_tcf_block(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl = (struct cbq_class *)arg; + + if (cl == NULL) + cl = &q->link; + + return cl->block; +} + +static unsigned long cbq_bind_filter(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *p = (struct cbq_class *)parent; + struct cbq_class *cl = cbq_class_lookup(q, classid); + + if (cl) { + if (p && p->level <= cl->level) + return 0; + cl->filters++; + return (unsigned long)cl; + } + return 0; +} + +static void cbq_unbind_filter(struct Qdisc *sch, unsigned long arg) +{ + struct cbq_class *cl = (struct cbq_class *)arg; + + cl->filters--; +} + +static void 
cbq_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct cbq_sched_data *q = qdisc_priv(sch); + struct cbq_class *cl; + unsigned int h; + + if (arg->stop) + return; + + for (h = 0; h < q->clhash.hashsize; h++) { + hlist_for_each_entry(cl, &q->clhash.hash[h], common.hnode) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)cl, arg)) + return; + } + } +} + +static const struct Qdisc_class_ops cbq_class_ops = { + .graft = cbq_graft, + .leaf = cbq_leaf, + .qlen_notify = cbq_qlen_notify, + .find = cbq_find, + .change = cbq_change_class, + .delete = cbq_delete, + .walk = cbq_walk, + .tcf_block = cbq_tcf_block, + .bind_tcf = cbq_bind_filter, + .unbind_tcf = cbq_unbind_filter, + .dump = cbq_dump_class, + .dump_stats = cbq_dump_class_stats, +}; + +static struct Qdisc_ops cbq_qdisc_ops __read_mostly = { + .next = NULL, + .cl_ops = &cbq_class_ops, + .id = "cbq", + .priv_size = sizeof(struct cbq_sched_data), + .enqueue = cbq_enqueue, + .dequeue = cbq_dequeue, + .peek = qdisc_peek_dequeued, + .init = cbq_init, + .reset = cbq_reset, + .destroy = cbq_destroy, + .change = NULL, + .dump = cbq_dump, + .dump_stats = cbq_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init cbq_module_init(void) +{ + return register_qdisc(&cbq_qdisc_ops); +} +static void __exit cbq_module_exit(void) +{ + unregister_qdisc(&cbq_qdisc_ops); +} +module_init(cbq_module_init) +module_exit(cbq_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_cbs.c b/net/sched/sch_cbs.c new file mode 100644 index 000000000..cac870eb7 --- /dev/null +++ b/net/sched/sch_cbs.c @@ -0,0 +1,576 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_cbs.c Credit Based Shaper + * + * Authors: Vinicius Costa Gomes + */ + +/* Credit Based Shaper (CBS) + * ========================= + * + * This is a simple rate-limiting shaper aimed at TSN applications on + * systems with known traffic workloads. + * + * Its algorithm is defined by the IEEE 802.1Q-2014 Specification, + * Section 8.6.8.2, and explained in more detail in the Annex L of the + * same specification. + * + * There are four tunables to be considered: + * + * 'idleslope': Idleslope is the rate of credits that is + * accumulated (in kilobits per second) when there is at least + * one packet waiting for transmission. Packets are transmitted + * when the current value of credits is equal or greater than + * zero. When there is no packet to be transmitted the amount of + * credits is set to zero. This is the main tunable of the CBS + * algorithm. + * + * 'sendslope': + * Sendslope is the rate of credits that is depleted (it should be a + * negative number of kilobits per second) when a transmission is + * ocurring. It can be calculated as follows, (IEEE 802.1Q-2014 Section + * 8.6.8.2 item g): + * + * sendslope = idleslope - port_transmit_rate + * + * 'hicredit': Hicredit defines the maximum amount of credits (in + * bytes) that can be accumulated. Hicredit depends on the + * characteristics of interfering traffic, + * 'max_interference_size' is the maximum size of any burst of + * traffic that can delay the transmission of a frame that is + * available for transmission for this traffic class, (IEEE + * 802.1Q-2014 Annex L, Equation L-3): + * + * hicredit = max_interference_size * (idleslope / port_transmit_rate) + * + * 'locredit': Locredit is the minimum amount of credits that can + * be reached. 
It is a function of the traffic flowing through + * this qdisc (IEEE 802.1Q-2014 Annex L, Equation L-2): + * + * locredit = max_frame_size * (sendslope / port_transmit_rate) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(cbs_list); +static DEFINE_SPINLOCK(cbs_list_lock); + +#define BYTES_PER_KBIT (1000LL / 8) + +struct cbs_sched_data { + bool offload; + int queue; + atomic64_t port_rate; /* in bytes/s */ + s64 last; /* timestamp in ns */ + s64 credits; /* in bytes */ + s32 locredit; /* in bytes */ + s32 hicredit; /* in bytes */ + s64 sendslope; /* in bytes/s */ + s64 idleslope; /* in bytes/s */ + struct qdisc_watchdog watchdog; + int (*enqueue)(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free); + struct sk_buff *(*dequeue)(struct Qdisc *sch); + struct Qdisc *qdisc; + struct list_head cbs_list; +}; + +static int cbs_child_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct Qdisc *child, + struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + int err; + + err = child->ops->enqueue(skb, child, to_free); + if (err != NET_XMIT_SUCCESS) + return err; + + sch->qstats.backlog += len; + sch->q.qlen++; + + return NET_XMIT_SUCCESS; +} + +static int cbs_enqueue_offload(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct Qdisc *qdisc = q->qdisc; + + return cbs_child_enqueue(skb, sch, qdisc, to_free); +} + +static int cbs_enqueue_soft(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct Qdisc *qdisc = q->qdisc; + + if (sch->q.qlen == 0 && q->credits > 0) { + /* We need to stop accumulating credits when there's + * no enqueued packets and q->credits is positive. 
+ */ + q->credits = 0; + q->last = ktime_get_ns(); + } + + return cbs_child_enqueue(skb, sch, qdisc, to_free); +} + +static int cbs_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + + return q->enqueue(skb, sch, to_free); +} + +/* timediff is in ns, slope is in bytes/s */ +static s64 timediff_to_credits(s64 timediff, s64 slope) +{ + return div64_s64(timediff * slope, NSEC_PER_SEC); +} + +static s64 delay_from_credits(s64 credits, s64 slope) +{ + if (unlikely(slope == 0)) + return S64_MAX; + + return div64_s64(-credits * NSEC_PER_SEC, slope); +} + +static s64 credits_from_len(unsigned int len, s64 slope, s64 port_rate) +{ + if (unlikely(port_rate == 0)) + return S64_MAX; + + return div64_s64(len * slope, port_rate); +} + +static struct sk_buff *cbs_child_dequeue(struct Qdisc *sch, struct Qdisc *child) +{ + struct sk_buff *skb; + + skb = child->ops->dequeue(child); + if (!skb) + return NULL; + + qdisc_qstats_backlog_dec(sch, skb); + qdisc_bstats_update(sch, skb); + sch->q.qlen--; + + return skb; +} + +static struct sk_buff *cbs_dequeue_soft(struct Qdisc *sch) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct Qdisc *qdisc = q->qdisc; + s64 now = ktime_get_ns(); + struct sk_buff *skb; + s64 credits; + int len; + + /* The previous packet is still being sent */ + if (now < q->last) { + qdisc_watchdog_schedule_ns(&q->watchdog, q->last); + return NULL; + } + if (q->credits < 0) { + credits = timediff_to_credits(now - q->last, q->idleslope); + + credits = q->credits + credits; + q->credits = min_t(s64, credits, q->hicredit); + + if (q->credits < 0) { + s64 delay; + + delay = delay_from_credits(q->credits, q->idleslope); + qdisc_watchdog_schedule_ns(&q->watchdog, now + delay); + + q->last = now; + + return NULL; + } + } + skb = cbs_child_dequeue(sch, qdisc); + if (!skb) + return NULL; + + len = qdisc_pkt_len(skb); + + /* As sendslope is a negative number, this will decrease the + * amount of q->credits. 
+ */ + credits = credits_from_len(len, q->sendslope, + atomic64_read(&q->port_rate)); + credits += q->credits; + + q->credits = max_t(s64, credits, q->locredit); + /* Estimate of the transmission of the last byte of the packet in ns */ + if (unlikely(atomic64_read(&q->port_rate) == 0)) + q->last = now; + else + q->last = now + div64_s64(len * NSEC_PER_SEC, + atomic64_read(&q->port_rate)); + + return skb; +} + +static struct sk_buff *cbs_dequeue_offload(struct Qdisc *sch) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct Qdisc *qdisc = q->qdisc; + + return cbs_child_dequeue(sch, qdisc); +} + +static struct sk_buff *cbs_dequeue(struct Qdisc *sch) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + + return q->dequeue(sch); +} + +static const struct nla_policy cbs_policy[TCA_CBS_MAX + 1] = { + [TCA_CBS_PARMS] = { .len = sizeof(struct tc_cbs_qopt) }, +}; + +static void cbs_disable_offload(struct net_device *dev, + struct cbs_sched_data *q) +{ + struct tc_cbs_qopt_offload cbs = { }; + const struct net_device_ops *ops; + int err; + + if (!q->offload) + return; + + q->enqueue = cbs_enqueue_soft; + q->dequeue = cbs_dequeue_soft; + + ops = dev->netdev_ops; + if (!ops->ndo_setup_tc) + return; + + cbs.queue = q->queue; + cbs.enable = 0; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_CBS, &cbs); + if (err < 0) + pr_warn("Couldn't disable CBS offload for queue %d\n", + cbs.queue); +} + +static int cbs_enable_offload(struct net_device *dev, struct cbs_sched_data *q, + const struct tc_cbs_qopt *opt, + struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct tc_cbs_qopt_offload cbs = { }; + int err; + + if (!ops->ndo_setup_tc) { + NL_SET_ERR_MSG(extack, "Specified device does not support cbs offload"); + return -EOPNOTSUPP; + } + + cbs.queue = q->queue; + + cbs.enable = 1; + cbs.hicredit = opt->hicredit; + cbs.locredit = opt->locredit; + cbs.idleslope = opt->idleslope; + cbs.sendslope = opt->sendslope; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_CBS, &cbs); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Specified device failed to setup cbs hardware offload"); + return err; + } + + q->enqueue = cbs_enqueue_offload; + q->dequeue = cbs_dequeue_offload; + + return 0; +} + +static void cbs_set_port_rate(struct net_device *dev, struct cbs_sched_data *q) +{ + struct ethtool_link_ksettings ecmd; + int speed = SPEED_10; + int port_rate; + int err; + + err = __ethtool_get_link_ksettings(dev, &ecmd); + if (err < 0) + goto skip; + + if (ecmd.base.speed && ecmd.base.speed != SPEED_UNKNOWN) + speed = ecmd.base.speed; + +skip: + port_rate = speed * 1000 * BYTES_PER_KBIT; + + atomic64_set(&q->port_rate, port_rate); + netdev_dbg(dev, "cbs: set %s's port_rate to: %lld, linkspeed: %d\n", + dev->name, (long long)atomic64_read(&q->port_rate), + ecmd.base.speed); +} + +static int cbs_dev_notifier(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct cbs_sched_data *q; + struct net_device *qdev; + bool found = false; + + ASSERT_RTNL(); + + if (event != NETDEV_UP && event != NETDEV_CHANGE) + return NOTIFY_DONE; + + spin_lock(&cbs_list_lock); + list_for_each_entry(q, &cbs_list, cbs_list) { + qdev = qdisc_dev(q->qdisc); + if (qdev == dev) { + found = true; + break; + } + } + spin_unlock(&cbs_list_lock); + + if (found) + cbs_set_port_rate(dev, q); + + return NOTIFY_DONE; +} + +static int cbs_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct cbs_sched_data *q = 
qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct nlattr *tb[TCA_CBS_MAX + 1]; + struct tc_cbs_qopt *qopt; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_CBS_MAX, opt, cbs_policy, + extack); + if (err < 0) + return err; + + if (!tb[TCA_CBS_PARMS]) { + NL_SET_ERR_MSG(extack, "Missing CBS parameter which are mandatory"); + return -EINVAL; + } + + qopt = nla_data(tb[TCA_CBS_PARMS]); + + if (!qopt->offload) { + cbs_set_port_rate(dev, q); + cbs_disable_offload(dev, q); + } else { + err = cbs_enable_offload(dev, q, qopt, extack); + if (err < 0) + return err; + } + + /* Everything went OK, save the parameters used. */ + q->hicredit = qopt->hicredit; + q->locredit = qopt->locredit; + q->idleslope = qopt->idleslope * BYTES_PER_KBIT; + q->sendslope = qopt->sendslope * BYTES_PER_KBIT; + q->offload = qopt->offload; + + return 0; +} + +static int cbs_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + + if (!opt) { + NL_SET_ERR_MSG(extack, "Missing CBS qdisc options which are mandatory"); + return -EINVAL; + } + + q->qdisc = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + sch->handle, extack); + if (!q->qdisc) + return -ENOMEM; + + spin_lock(&cbs_list_lock); + list_add(&q->cbs_list, &cbs_list); + spin_unlock(&cbs_list_lock); + + qdisc_hash_add(q->qdisc, false); + + q->queue = sch->dev_queue - netdev_get_tx_queue(dev, 0); + + q->enqueue = cbs_enqueue_soft; + q->dequeue = cbs_dequeue_soft; + + qdisc_watchdog_init(&q->watchdog, sch); + + return cbs_change(sch, opt, extack); +} + +static void cbs_destroy(struct Qdisc *sch) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + + /* Nothing to do if we couldn't create the underlying qdisc */ + if (!q->qdisc) + return; + + qdisc_watchdog_cancel(&q->watchdog); + cbs_disable_offload(dev, q); + + spin_lock(&cbs_list_lock); + list_del(&q->cbs_list); + spin_unlock(&cbs_list_lock); + + qdisc_put(q->qdisc); +} + +static int cbs_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + struct tc_cbs_qopt opt = { }; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + opt.hicredit = q->hicredit; + opt.locredit = q->locredit; + opt.sendslope = div64_s64(q->sendslope, BYTES_PER_KBIT); + opt.idleslope = div64_s64(q->idleslope, BYTES_PER_KBIT); + opt.offload = q->offload; + + if (nla_put(skb, TCA_CBS_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int cbs_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + + if (cl != 1 || !q->qdisc) /* only one class */ + return -ENOENT; + + tcm->tcm_handle |= TC_H_MIN(1); + tcm->tcm_info = q->qdisc->handle; + + return 0; +} + +static int cbs_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + + if (!new) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + sch->handle, NULL); + if (!new) + new = &noop_qdisc; + } + + *old = qdisc_replace(sch, new, &q->qdisc); + return 0; +} + +static struct Qdisc *cbs_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct cbs_sched_data *q = qdisc_priv(sch); + + return q->qdisc; +} + 
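[Editorial note, not part of the patch] For readers following the credit arithmetic in sch_cbs.c above, the following stand-alone user-space sketch evaluates the four CBS tunables from the formulas quoted in the file's header comment (IEEE 802.1Q-2014, 8.6.8.2 and Annex L) and replays one credit-accrual step in the same way cbs_dequeue_soft() does. The port speed, idleslope and frame sizes are made-up example values, and the helper arithmetic mirrors but does not reuse the kernel code.

/* Stand-alone sketch; the numeric inputs are assumptions, not kernel defaults. */
#include <stdio.h>
#include <stdint.h>

#define BYTES_PER_KBIT (1000LL / 8)
#define NSEC_PER_SEC   1000000000LL

int main(void)
{
	/* Assumed link and stream parameters (kbit/s and bytes). */
	int64_t port_rate_kbit  = 100000;	/* 100 Mbit/s port */
	int64_t idleslope_kbit  = 20000;	/* reserve 20 Mbit/s for this class */
	int64_t max_interference = 1522;	/* worst-case interfering burst, bytes */
	int64_t max_frame        = 1522;	/* largest frame of this class, bytes */

	/* Formulas as quoted in the sch_cbs.c header comment. */
	int64_t sendslope_kbit = idleslope_kbit - port_rate_kbit;
	int64_t hicredit = max_interference * idleslope_kbit / port_rate_kbit;
	int64_t locredit = max_frame * sendslope_kbit / port_rate_kbit;

	printf("sendslope = %lld kbit/s, hicredit = %lld B, locredit = %lld B\n",
	       (long long)sendslope_kbit, (long long)hicredit, (long long)locredit);

	/* One credit-accrual step, mirroring timediff_to_credits(): with a
	 * negative balance, waiting timediff ns earns
	 * timediff * idleslope(bytes/s) / NSEC_PER_SEC bytes of credit. */
	int64_t idleslope_Bps = idleslope_kbit * BYTES_PER_KBIT;
	int64_t timediff_ns = 100000;		/* 100 us of idle waiting */
	int64_t gained = timediff_ns * idleslope_Bps / NSEC_PER_SEC;

	printf("credits gained over 100 us: %lld bytes\n", (long long)gained);
	return 0;
}

With these assumed inputs the sketch prints sendslope = -80000 kbit/s, hicredit = 304 B, locredit = -1217 B and a gain of 250 bytes of credit per 100 us, which is the kind of budget the soft dequeue path compares against hicredit/locredit before scheduling its watchdog.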
+static unsigned long cbs_find(struct Qdisc *sch, u32 classid) +{ + return 1; +} + +static void cbs_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + if (!walker->stop) { + tc_qdisc_stats_dump(sch, 1, walker); + } +} + +static const struct Qdisc_class_ops cbs_class_ops = { + .graft = cbs_graft, + .leaf = cbs_leaf, + .find = cbs_find, + .walk = cbs_walk, + .dump = cbs_dump_class, +}; + +static struct Qdisc_ops cbs_qdisc_ops __read_mostly = { + .id = "cbs", + .cl_ops = &cbs_class_ops, + .priv_size = sizeof(struct cbs_sched_data), + .enqueue = cbs_enqueue, + .dequeue = cbs_dequeue, + .peek = qdisc_peek_dequeued, + .init = cbs_init, + .reset = qdisc_reset_queue, + .destroy = cbs_destroy, + .change = cbs_change, + .dump = cbs_dump, + .owner = THIS_MODULE, +}; + +static struct notifier_block cbs_device_notifier = { + .notifier_call = cbs_dev_notifier, +}; + +static int __init cbs_module_init(void) +{ + int err; + + err = register_netdevice_notifier(&cbs_device_notifier); + if (err) + return err; + + err = register_qdisc(&cbs_qdisc_ops); + if (err) + unregister_netdevice_notifier(&cbs_device_notifier); + + return err; +} + +static void __exit cbs_module_exit(void) +{ + unregister_qdisc(&cbs_qdisc_ops); + unregister_netdevice_notifier(&cbs_device_notifier); +} +module_init(cbs_module_init) +module_exit(cbs_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_choke.c b/net/sched/sch_choke.c new file mode 100644 index 000000000..3ac3e5c80 --- /dev/null +++ b/net/sched/sch_choke.c @@ -0,0 +1,515 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_choke.c CHOKE scheduler + * + * Copyright (c) 2011 Stephen Hemminger + * Copyright (c) 2011 Eric Dumazet + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + CHOKe stateless AQM for fair bandwidth allocation + ================================================= + + CHOKe (CHOose and Keep for responsive flows, CHOose and Kill for + unresponsive flows) is a variant of RED that penalizes misbehaving flows but + maintains no flow state. The difference from RED is an additional step + during the enqueuing process. If average queue size is over the + low threshold (qmin), a packet is chosen at random from the queue. + If both the new and chosen packet are from the same flow, both + are dropped. Unlike RED, CHOKe is not really a "classful" qdisc because it + needs to access packets in queue randomly. It has a minimal class + interface to allow overriding the builtin flow classifier with + filters. + + Source: + R. Pan, B. Prabhakar, and K. Psounis, "CHOKe, A Stateless + Active Queue Management Scheme for Approximating Fair Bandwidth Allocation", + IEEE INFOCOM, 2000. + + A. Tang, J. Wang, S. 
Low, "Understanding CHOKe: Throughput and Spatial + Characteristics", IEEE/ACM Transactions on Networking, 2004 + + */ + +/* Upper bound on size of sk_buff table (packets) */ +#define CHOKE_MAX_QUEUE (128*1024 - 1) + +struct choke_sched_data { +/* Parameters */ + u32 limit; + unsigned char flags; + + struct red_parms parms; + +/* Variables */ + struct red_vars vars; + struct { + u32 prob_drop; /* Early probability drops */ + u32 prob_mark; /* Early probability marks */ + u32 forced_drop; /* Forced drops, qavg > max_thresh */ + u32 forced_mark; /* Forced marks, qavg > max_thresh */ + u32 pdrop; /* Drops due to queue limits */ + u32 matched; /* Drops to flow match */ + } stats; + + unsigned int head; + unsigned int tail; + + unsigned int tab_mask; /* size - 1 */ + + struct sk_buff **tab; +}; + +/* number of elements in queue including holes */ +static unsigned int choke_len(const struct choke_sched_data *q) +{ + return (q->tail - q->head) & q->tab_mask; +} + +/* Is ECN parameter configured */ +static int use_ecn(const struct choke_sched_data *q) +{ + return q->flags & TC_RED_ECN; +} + +/* Should packets over max just be dropped (versus marked) */ +static int use_harddrop(const struct choke_sched_data *q) +{ + return q->flags & TC_RED_HARDDROP; +} + +/* Move head pointer forward to skip over holes */ +static void choke_zap_head_holes(struct choke_sched_data *q) +{ + do { + q->head = (q->head + 1) & q->tab_mask; + if (q->head == q->tail) + break; + } while (q->tab[q->head] == NULL); +} + +/* Move tail pointer backwards to reuse holes */ +static void choke_zap_tail_holes(struct choke_sched_data *q) +{ + do { + q->tail = (q->tail - 1) & q->tab_mask; + if (q->head == q->tail) + break; + } while (q->tab[q->tail] == NULL); +} + +/* Drop packet from queue array by creating a "hole" */ +static void choke_drop_by_idx(struct Qdisc *sch, unsigned int idx, + struct sk_buff **to_free) +{ + struct choke_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb = q->tab[idx]; + + q->tab[idx] = NULL; + + if (idx == q->head) + choke_zap_head_holes(q); + if (idx == q->tail) + choke_zap_tail_holes(q); + + qdisc_qstats_backlog_dec(sch, skb); + qdisc_tree_reduce_backlog(sch, 1, qdisc_pkt_len(skb)); + qdisc_drop(skb, sch, to_free); + --sch->q.qlen; +} + +struct choke_skb_cb { + u8 keys_valid; + struct flow_keys_digest keys; +}; + +static inline struct choke_skb_cb *choke_skb_cb(const struct sk_buff *skb) +{ + qdisc_cb_private_validate(skb, sizeof(struct choke_skb_cb)); + return (struct choke_skb_cb *)qdisc_skb_cb(skb)->data; +} + +/* + * Compare flow of two packets + * Returns true only if source and destination address and port match. 
+ * false for special cases + */ +static bool choke_match_flow(struct sk_buff *skb1, + struct sk_buff *skb2) +{ + struct flow_keys temp; + + if (skb1->protocol != skb2->protocol) + return false; + + if (!choke_skb_cb(skb1)->keys_valid) { + choke_skb_cb(skb1)->keys_valid = 1; + skb_flow_dissect_flow_keys(skb1, &temp, 0); + make_flow_keys_digest(&choke_skb_cb(skb1)->keys, &temp); + } + + if (!choke_skb_cb(skb2)->keys_valid) { + choke_skb_cb(skb2)->keys_valid = 1; + skb_flow_dissect_flow_keys(skb2, &temp, 0); + make_flow_keys_digest(&choke_skb_cb(skb2)->keys, &temp); + } + + return !memcmp(&choke_skb_cb(skb1)->keys, + &choke_skb_cb(skb2)->keys, + sizeof(choke_skb_cb(skb1)->keys)); +} + +/* + * Select a packet at random from queue + * HACK: since queue can have holes from previous deletion; retry several + * times to find a random skb but then just give up and return the head + * Will return NULL if queue is empty (q->head == q->tail) + */ +static struct sk_buff *choke_peek_random(const struct choke_sched_data *q, + unsigned int *pidx) +{ + struct sk_buff *skb; + int retrys = 3; + + do { + *pidx = (q->head + prandom_u32_max(choke_len(q))) & q->tab_mask; + skb = q->tab[*pidx]; + if (skb) + return skb; + } while (--retrys > 0); + + return q->tab[*pidx = q->head]; +} + +/* + * Compare new packet with random packet in queue + * returns true if matched and sets *pidx + */ +static bool choke_match_random(const struct choke_sched_data *q, + struct sk_buff *nskb, + unsigned int *pidx) +{ + struct sk_buff *oskb; + + if (q->head == q->tail) + return false; + + oskb = choke_peek_random(q, pidx); + return choke_match_flow(oskb, nskb); +} + +static int choke_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct choke_sched_data *q = qdisc_priv(sch); + const struct red_parms *p = &q->parms; + + choke_skb_cb(skb)->keys_valid = 0; + /* Compute average queue usage (see RED) */ + q->vars.qavg = red_calc_qavg(p, &q->vars, sch->q.qlen); + if (red_is_idling(&q->vars)) + red_end_of_idle_period(&q->vars); + + /* Is queue small? 
*/ + if (q->vars.qavg <= p->qth_min) + q->vars.qcount = -1; + else { + unsigned int idx; + + /* Draw a packet at random from queue and compare flow */ + if (choke_match_random(q, skb, &idx)) { + q->stats.matched++; + choke_drop_by_idx(sch, idx, to_free); + goto congestion_drop; + } + + /* Queue is large, always mark/drop */ + if (q->vars.qavg > p->qth_max) { + q->vars.qcount = -1; + + qdisc_qstats_overlimit(sch); + if (use_harddrop(q) || !use_ecn(q) || + !INET_ECN_set_ce(skb)) { + q->stats.forced_drop++; + goto congestion_drop; + } + + q->stats.forced_mark++; + } else if (++q->vars.qcount) { + if (red_mark_probability(p, &q->vars, q->vars.qavg)) { + q->vars.qcount = 0; + q->vars.qR = red_random(p); + + qdisc_qstats_overlimit(sch); + if (!use_ecn(q) || !INET_ECN_set_ce(skb)) { + q->stats.prob_drop++; + goto congestion_drop; + } + + q->stats.prob_mark++; + } + } else + q->vars.qR = red_random(p); + } + + /* Admit new packet */ + if (sch->q.qlen < q->limit) { + q->tab[q->tail] = skb; + q->tail = (q->tail + 1) & q->tab_mask; + ++sch->q.qlen; + qdisc_qstats_backlog_inc(sch, skb); + return NET_XMIT_SUCCESS; + } + + q->stats.pdrop++; + return qdisc_drop(skb, sch, to_free); + +congestion_drop: + qdisc_drop(skb, sch, to_free); + return NET_XMIT_CN; +} + +static struct sk_buff *choke_dequeue(struct Qdisc *sch) +{ + struct choke_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + + if (q->head == q->tail) { + if (!red_is_idling(&q->vars)) + red_start_of_idle_period(&q->vars); + return NULL; + } + + skb = q->tab[q->head]; + q->tab[q->head] = NULL; + choke_zap_head_holes(q); + --sch->q.qlen; + qdisc_qstats_backlog_dec(sch, skb); + qdisc_bstats_update(sch, skb); + + return skb; +} + +static void choke_reset(struct Qdisc *sch) +{ + struct choke_sched_data *q = qdisc_priv(sch); + + while (q->head != q->tail) { + struct sk_buff *skb = q->tab[q->head]; + + q->head = (q->head + 1) & q->tab_mask; + if (!skb) + continue; + rtnl_qdisc_drop(skb, sch); + } + + if (q->tab) + memset(q->tab, 0, (q->tab_mask + 1) * sizeof(struct sk_buff *)); + q->head = q->tail = 0; + red_restart(&q->vars); +} + +static const struct nla_policy choke_policy[TCA_CHOKE_MAX + 1] = { + [TCA_CHOKE_PARMS] = { .len = sizeof(struct tc_red_qopt) }, + [TCA_CHOKE_STAB] = { .len = RED_STAB_SIZE }, + [TCA_CHOKE_MAX_P] = { .type = NLA_U32 }, +}; + + +static void choke_free(void *addr) +{ + kvfree(addr); +} + +static int choke_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct choke_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_CHOKE_MAX + 1]; + const struct tc_red_qopt *ctl; + int err; + struct sk_buff **old = NULL; + unsigned int mask; + u32 max_P; + u8 *stab; + + if (opt == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_CHOKE_MAX, opt, + choke_policy, NULL); + if (err < 0) + return err; + + if (tb[TCA_CHOKE_PARMS] == NULL || + tb[TCA_CHOKE_STAB] == NULL) + return -EINVAL; + + max_P = tb[TCA_CHOKE_MAX_P] ? 
nla_get_u32(tb[TCA_CHOKE_MAX_P]) : 0; + + ctl = nla_data(tb[TCA_CHOKE_PARMS]); + stab = nla_data(tb[TCA_CHOKE_STAB]); + if (!red_check_params(ctl->qth_min, ctl->qth_max, ctl->Wlog, ctl->Scell_log, stab)) + return -EINVAL; + + if (ctl->limit > CHOKE_MAX_QUEUE) + return -EINVAL; + + mask = roundup_pow_of_two(ctl->limit + 1) - 1; + if (mask != q->tab_mask) { + struct sk_buff **ntab; + + ntab = kvcalloc(mask + 1, sizeof(struct sk_buff *), GFP_KERNEL); + if (!ntab) + return -ENOMEM; + + sch_tree_lock(sch); + old = q->tab; + if (old) { + unsigned int oqlen = sch->q.qlen, tail = 0; + unsigned dropped = 0; + + while (q->head != q->tail) { + struct sk_buff *skb = q->tab[q->head]; + + q->head = (q->head + 1) & q->tab_mask; + if (!skb) + continue; + if (tail < mask) { + ntab[tail++] = skb; + continue; + } + dropped += qdisc_pkt_len(skb); + qdisc_qstats_backlog_dec(sch, skb); + --sch->q.qlen; + rtnl_qdisc_drop(skb, sch); + } + qdisc_tree_reduce_backlog(sch, oqlen - sch->q.qlen, dropped); + q->head = 0; + q->tail = tail; + } + + q->tab_mask = mask; + q->tab = ntab; + } else + sch_tree_lock(sch); + + q->flags = ctl->flags; + q->limit = ctl->limit; + + red_set_parms(&q->parms, ctl->qth_min, ctl->qth_max, ctl->Wlog, + ctl->Plog, ctl->Scell_log, + stab, + max_P); + red_set_vars(&q->vars); + + if (q->head == q->tail) + red_end_of_idle_period(&q->vars); + + sch_tree_unlock(sch); + choke_free(old); + return 0; +} + +static int choke_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + return choke_change(sch, opt, extack); +} + +static int choke_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct choke_sched_data *q = qdisc_priv(sch); + struct nlattr *opts = NULL; + struct tc_red_qopt opt = { + .limit = q->limit, + .flags = q->flags, + .qth_min = q->parms.qth_min >> q->parms.Wlog, + .qth_max = q->parms.qth_max >> q->parms.Wlog, + .Wlog = q->parms.Wlog, + .Plog = q->parms.Plog, + .Scell_log = q->parms.Scell_log, + }; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + + if (nla_put(skb, TCA_CHOKE_PARMS, sizeof(opt), &opt) || + nla_put_u32(skb, TCA_CHOKE_MAX_P, q->parms.max_P)) + goto nla_put_failure; + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static int choke_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct choke_sched_data *q = qdisc_priv(sch); + struct tc_choke_xstats st = { + .early = q->stats.prob_drop + q->stats.forced_drop, + .marked = q->stats.prob_mark + q->stats.forced_mark, + .pdrop = q->stats.pdrop, + .matched = q->stats.matched, + }; + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static void choke_destroy(struct Qdisc *sch) +{ + struct choke_sched_data *q = qdisc_priv(sch); + + choke_free(q->tab); +} + +static struct sk_buff *choke_peek_head(struct Qdisc *sch) +{ + struct choke_sched_data *q = qdisc_priv(sch); + + return (q->head != q->tail) ? 
q->tab[q->head] : NULL; +} + +static struct Qdisc_ops choke_qdisc_ops __read_mostly = { + .id = "choke", + .priv_size = sizeof(struct choke_sched_data), + + .enqueue = choke_enqueue, + .dequeue = choke_dequeue, + .peek = choke_peek_head, + .init = choke_init, + .destroy = choke_destroy, + .reset = choke_reset, + .change = choke_change, + .dump = choke_dump, + .dump_stats = choke_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init choke_module_init(void) +{ + return register_qdisc(&choke_qdisc_ops); +} + +static void __exit choke_module_exit(void) +{ + unregister_qdisc(&choke_qdisc_ops); +} + +module_init(choke_module_init) +module_exit(choke_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_codel.c b/net/sched/sch_codel.c new file mode 100644 index 000000000..d7a487454 --- /dev/null +++ b/net/sched/sch_codel.c @@ -0,0 +1,307 @@ +/* + * Codel - The Controlled-Delay Active Queue Management algorithm + * + * Copyright (C) 2011-2012 Kathleen Nichols + * Copyright (C) 2011-2012 Van Jacobson + * + * Implemented on linux by : + * Copyright (C) 2012 Michael D. Taht + * Copyright (C) 2012,2015 Eric Dumazet + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions, and the following disclaimer, + * without modification. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. The names of the authors may not be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * Alternatively, provided that this notice is retained in full, this + * software may be distributed under the terms of the GNU General + * Public License ("GPL") version 2, in which case the provisions of the + * GPL apply INSTEAD OF those given above. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#define DEFAULT_CODEL_LIMIT 1000 + +struct codel_sched_data { + struct codel_params params; + struct codel_vars vars; + struct codel_stats stats; + u32 drop_overlimit; +}; + +/* This is the specific function called from codel_dequeue() + * to dequeue a packet from queue. Note: backlog is handled in + * codel, we dont need to reduce it here. 
+ */ +static struct sk_buff *dequeue_func(struct codel_vars *vars, void *ctx) +{ + struct Qdisc *sch = ctx; + struct sk_buff *skb = __qdisc_dequeue_head(&sch->q); + + if (skb) { + sch->qstats.backlog -= qdisc_pkt_len(skb); + prefetch(&skb->end); /* we'll need skb_shinfo() */ + } + return skb; +} + +static void drop_func(struct sk_buff *skb, void *ctx) +{ + struct Qdisc *sch = ctx; + + kfree_skb(skb); + qdisc_qstats_drop(sch); +} + +static struct sk_buff *codel_qdisc_dequeue(struct Qdisc *sch) +{ + struct codel_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + + skb = codel_dequeue(sch, &sch->qstats.backlog, &q->params, &q->vars, + &q->stats, qdisc_pkt_len, codel_get_enqueue_time, + drop_func, dequeue_func); + + /* We cant call qdisc_tree_reduce_backlog() if our qlen is 0, + * or HTB crashes. Defer it for next round. + */ + if (q->stats.drop_count && sch->q.qlen) { + qdisc_tree_reduce_backlog(sch, q->stats.drop_count, q->stats.drop_len); + q->stats.drop_count = 0; + q->stats.drop_len = 0; + } + if (skb) + qdisc_bstats_update(sch, skb); + return skb; +} + +static int codel_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct codel_sched_data *q; + + if (likely(qdisc_qlen(sch) < sch->limit)) { + codel_set_enqueue_time(skb); + return qdisc_enqueue_tail(skb, sch); + } + q = qdisc_priv(sch); + q->drop_overlimit++; + return qdisc_drop(skb, sch, to_free); +} + +static const struct nla_policy codel_policy[TCA_CODEL_MAX + 1] = { + [TCA_CODEL_TARGET] = { .type = NLA_U32 }, + [TCA_CODEL_LIMIT] = { .type = NLA_U32 }, + [TCA_CODEL_INTERVAL] = { .type = NLA_U32 }, + [TCA_CODEL_ECN] = { .type = NLA_U32 }, + [TCA_CODEL_CE_THRESHOLD]= { .type = NLA_U32 }, +}; + +static int codel_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct codel_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_CODEL_MAX + 1]; + unsigned int qlen, dropped = 0; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_CODEL_MAX, opt, + codel_policy, NULL); + if (err < 0) + return err; + + sch_tree_lock(sch); + + if (tb[TCA_CODEL_TARGET]) { + u32 target = nla_get_u32(tb[TCA_CODEL_TARGET]); + + q->params.target = ((u64)target * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_CODEL_CE_THRESHOLD]) { + u64 val = nla_get_u32(tb[TCA_CODEL_CE_THRESHOLD]); + + q->params.ce_threshold = (val * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_CODEL_INTERVAL]) { + u32 interval = nla_get_u32(tb[TCA_CODEL_INTERVAL]); + + q->params.interval = ((u64)interval * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_CODEL_LIMIT]) + sch->limit = nla_get_u32(tb[TCA_CODEL_LIMIT]); + + if (tb[TCA_CODEL_ECN]) + q->params.ecn = !!nla_get_u32(tb[TCA_CODEL_ECN]); + + qlen = sch->q.qlen; + while (sch->q.qlen > sch->limit) { + struct sk_buff *skb = __qdisc_dequeue_head(&sch->q); + + dropped += qdisc_pkt_len(skb); + qdisc_qstats_backlog_dec(sch, skb); + rtnl_qdisc_drop(skb, sch); + } + qdisc_tree_reduce_backlog(sch, qlen - sch->q.qlen, dropped); + + sch_tree_unlock(sch); + return 0; +} + +static int codel_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct codel_sched_data *q = qdisc_priv(sch); + + sch->limit = DEFAULT_CODEL_LIMIT; + + codel_params_init(&q->params); + codel_vars_init(&q->vars); + codel_stats_init(&q->stats); + q->params.mtu = psched_mtu(qdisc_dev(sch)); + + if (opt) { + int err = codel_change(sch, opt, extack); + + if (err) + return err; + } + + if (sch->limit >= 1) + sch->flags |= TCQ_F_CAN_BYPASS; + else + sch->flags 
&= ~TCQ_F_CAN_BYPASS; + + return 0; +} + +static int codel_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct codel_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_CODEL_TARGET, + codel_time_to_us(q->params.target)) || + nla_put_u32(skb, TCA_CODEL_LIMIT, + sch->limit) || + nla_put_u32(skb, TCA_CODEL_INTERVAL, + codel_time_to_us(q->params.interval)) || + nla_put_u32(skb, TCA_CODEL_ECN, + q->params.ecn)) + goto nla_put_failure; + if (q->params.ce_threshold != CODEL_DISABLED_THRESHOLD && + nla_put_u32(skb, TCA_CODEL_CE_THRESHOLD, + codel_time_to_us(q->params.ce_threshold))) + goto nla_put_failure; + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -1; +} + +static int codel_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + const struct codel_sched_data *q = qdisc_priv(sch); + struct tc_codel_xstats st = { + .maxpacket = q->stats.maxpacket, + .count = q->vars.count, + .lastcount = q->vars.lastcount, + .drop_overlimit = q->drop_overlimit, + .ldelay = codel_time_to_us(q->vars.ldelay), + .dropping = q->vars.dropping, + .ecn_mark = q->stats.ecn_mark, + .ce_mark = q->stats.ce_mark, + }; + + if (q->vars.dropping) { + codel_tdiff_t delta = q->vars.drop_next - codel_get_time(); + + if (delta >= 0) + st.drop_next = codel_time_to_us(delta); + else + st.drop_next = -codel_time_to_us(-delta); + } + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static void codel_reset(struct Qdisc *sch) +{ + struct codel_sched_data *q = qdisc_priv(sch); + + qdisc_reset_queue(sch); + codel_vars_init(&q->vars); +} + +static struct Qdisc_ops codel_qdisc_ops __read_mostly = { + .id = "codel", + .priv_size = sizeof(struct codel_sched_data), + + .enqueue = codel_qdisc_enqueue, + .dequeue = codel_qdisc_dequeue, + .peek = qdisc_peek_dequeued, + .init = codel_init, + .reset = codel_reset, + .change = codel_change, + .dump = codel_dump, + .dump_stats = codel_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init codel_module_init(void) +{ + return register_qdisc(&codel_qdisc_ops); +} + +static void __exit codel_module_exit(void) +{ + unregister_qdisc(&codel_qdisc_ops); +} + +module_init(codel_module_init) +module_exit(codel_module_exit) + +MODULE_DESCRIPTION("Controlled Delay queue discipline"); +MODULE_AUTHOR("Dave Taht"); +MODULE_AUTHOR("Eric Dumazet"); +MODULE_LICENSE("Dual BSD/GPL"); diff --git a/net/sched/sch_drr.c b/net/sched/sch_drr.c new file mode 100644 index 000000000..e35a4e90f --- /dev/null +++ b/net/sched/sch_drr.c @@ -0,0 +1,496 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_drr.c Deficit Round Robin scheduler + * + * Copyright (c) 2008 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct drr_class { + struct Qdisc_class_common common; + unsigned int filter_cnt; + + struct gnet_stats_basic_sync bstats; + struct gnet_stats_queue qstats; + struct net_rate_estimator __rcu *rate_est; + struct list_head alist; + struct Qdisc *qdisc; + + u32 quantum; + u32 deficit; +}; + +struct drr_sched { + struct list_head active; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + struct Qdisc_class_hash clhash; +}; + +static struct drr_class *drr_find_class(struct Qdisc *sch, u32 classid) +{ + struct drr_sched *q = qdisc_priv(sch); + struct Qdisc_class_common *clc; + + clc = qdisc_class_find(&q->clhash, classid); + if (clc == NULL) + 
return NULL; + return container_of(clc, struct drr_class, common); +} + +static const struct nla_policy drr_policy[TCA_DRR_MAX + 1] = { + [TCA_DRR_QUANTUM] = { .type = NLA_U32 }, +}; + +static int drr_change_class(struct Qdisc *sch, u32 classid, u32 parentid, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl = (struct drr_class *)*arg; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_DRR_MAX + 1]; + u32 quantum; + int err; + + if (!opt) { + NL_SET_ERR_MSG(extack, "DRR options are required for this operation"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_DRR_MAX, opt, drr_policy, + extack); + if (err < 0) + return err; + + if (tb[TCA_DRR_QUANTUM]) { + quantum = nla_get_u32(tb[TCA_DRR_QUANTUM]); + if (quantum == 0) { + NL_SET_ERR_MSG(extack, "Specified DRR quantum cannot be zero"); + return -EINVAL; + } + } else + quantum = psched_mtu(qdisc_dev(sch)); + + if (cl != NULL) { + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, true, + tca[TCA_RATE]); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to replace estimator"); + return err; + } + } + + sch_tree_lock(sch); + if (tb[TCA_DRR_QUANTUM]) + cl->quantum = quantum; + sch_tree_unlock(sch); + + return 0; + } + + cl = kzalloc(sizeof(struct drr_class), GFP_KERNEL); + if (cl == NULL) + return -ENOBUFS; + + gnet_stats_basic_sync_init(&cl->bstats); + cl->common.classid = classid; + cl->quantum = quantum; + cl->qdisc = qdisc_create_dflt(sch->dev_queue, + &pfifo_qdisc_ops, classid, + NULL); + if (cl->qdisc == NULL) + cl->qdisc = &noop_qdisc; + else + qdisc_hash_add(cl->qdisc, true); + + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, &cl->rate_est, + NULL, true, tca[TCA_RATE]); + if (err) { + NL_SET_ERR_MSG(extack, "Failed to replace estimator"); + qdisc_put(cl->qdisc); + kfree(cl); + return err; + } + } + + sch_tree_lock(sch); + qdisc_class_hash_insert(&q->clhash, &cl->common); + sch_tree_unlock(sch); + + qdisc_class_hash_grow(sch, &q->clhash); + + *arg = (unsigned long)cl; + return 0; +} + +static void drr_destroy_class(struct Qdisc *sch, struct drr_class *cl) +{ + gen_kill_estimator(&cl->rate_est); + qdisc_put(cl->qdisc); + kfree(cl); +} + +static int drr_delete_class(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl = (struct drr_class *)arg; + + if (cl->filter_cnt > 0) + return -EBUSY; + + sch_tree_lock(sch); + + qdisc_purge_queue(cl->qdisc); + qdisc_class_hash_remove(&q->clhash, &cl->common); + + sch_tree_unlock(sch); + + drr_destroy_class(sch, cl); + return 0; +} + +static unsigned long drr_search_class(struct Qdisc *sch, u32 classid) +{ + return (unsigned long)drr_find_class(sch, classid); +} + +static struct tcf_block *drr_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct drr_sched *q = qdisc_priv(sch); + + if (cl) { + NL_SET_ERR_MSG(extack, "DRR classid must be zero"); + return NULL; + } + + return q->block; +} + +static unsigned long drr_bind_tcf(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + struct drr_class *cl = drr_find_class(sch, classid); + + if (cl != NULL) + cl->filter_cnt++; + + return (unsigned long)cl; +} + +static void drr_unbind_tcf(struct Qdisc *sch, unsigned long arg) +{ + struct drr_class *cl = (struct drr_class *)arg; + + cl->filter_cnt--; +} + +static int drr_graft_class(struct Qdisc 
*sch, unsigned long arg, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct drr_class *cl = (struct drr_class *)arg; + + if (new == NULL) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + cl->common.classid, NULL); + if (new == NULL) + new = &noop_qdisc; + } + + *old = qdisc_replace(sch, new, &cl->qdisc); + return 0; +} + +static struct Qdisc *drr_class_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct drr_class *cl = (struct drr_class *)arg; + + return cl->qdisc; +} + +static void drr_qlen_notify(struct Qdisc *csh, unsigned long arg) +{ + struct drr_class *cl = (struct drr_class *)arg; + + list_del(&cl->alist); +} + +static int drr_dump_class(struct Qdisc *sch, unsigned long arg, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct drr_class *cl = (struct drr_class *)arg; + struct nlattr *nest; + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle = cl->common.classid; + tcm->tcm_info = cl->qdisc->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_DRR_QUANTUM, cl->quantum)) + goto nla_put_failure; + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int drr_dump_class_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct drr_class *cl = (struct drr_class *)arg; + __u32 qlen = qdisc_qlen_sum(cl->qdisc); + struct Qdisc *cl_q = cl->qdisc; + struct tc_drr_stats xstats; + + memset(&xstats, 0, sizeof(xstats)); + if (qlen) + xstats.deficit = cl->deficit; + + if (gnet_stats_copy_basic(d, NULL, &cl->bstats, true) < 0 || + gnet_stats_copy_rate_est(d, &cl->rate_est) < 0 || + gnet_stats_copy_queue(d, cl_q->cpu_qstats, &cl_q->qstats, qlen) < 0) + return -1; + + return gnet_stats_copy_app(d, &xstats, sizeof(xstats)); +} + +static void drr_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)cl, arg)) + return; + } + } +} + +static struct drr_class *drr_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + struct tcf_result res; + struct tcf_proto *fl; + int result; + + if (TC_H_MAJ(skb->priority ^ sch->handle) == 0) { + cl = drr_find_class(sch, skb->priority); + if (cl != NULL) + return cl; + } + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + fl = rcu_dereference_bh(q->filter_list); + result = tcf_classify(skb, NULL, fl, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + cl = (struct drr_class *)res.class; + if (cl == NULL) + cl = drr_find_class(sch, res.classid); + return cl; + } + return NULL; +} + +static int drr_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + int err = 0; + bool first; + + cl = drr_classify(skb, sch, &err); + if (cl == NULL) { + if (err & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return err; + } + + first = !cl->qdisc->q.qlen; + err = 
qdisc_enqueue(skb, cl->qdisc, to_free); + if (unlikely(err != NET_XMIT_SUCCESS)) { + if (net_xmit_drop_count(err)) { + cl->qstats.drops++; + qdisc_qstats_drop(sch); + } + return err; + } + + if (first) { + list_add_tail(&cl->alist, &q->active); + cl->deficit = cl->quantum; + } + + sch->qstats.backlog += len; + sch->q.qlen++; + return err; +} + +static struct sk_buff *drr_dequeue(struct Qdisc *sch) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + struct sk_buff *skb; + unsigned int len; + + if (list_empty(&q->active)) + goto out; + while (1) { + cl = list_first_entry(&q->active, struct drr_class, alist); + skb = cl->qdisc->ops->peek(cl->qdisc); + if (skb == NULL) { + qdisc_warn_nonwc(__func__, cl->qdisc); + goto out; + } + + len = qdisc_pkt_len(skb); + if (len <= cl->deficit) { + cl->deficit -= len; + skb = qdisc_dequeue_peeked(cl->qdisc); + if (unlikely(skb == NULL)) + goto out; + if (cl->qdisc->q.qlen == 0) + list_del(&cl->alist); + + bstats_update(&cl->bstats, skb); + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + return skb; + } + + cl->deficit += cl->quantum; + list_move_tail(&cl->alist, &q->active); + } +out: + return NULL; +} + +static int drr_init_qdisc(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct drr_sched *q = qdisc_priv(sch); + int err; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + err = qdisc_class_hash_init(&q->clhash); + if (err < 0) + return err; + INIT_LIST_HEAD(&q->active); + return 0; +} + +static void drr_reset_qdisc(struct Qdisc *sch) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + unsigned int i; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (cl->qdisc->q.qlen) + list_del(&cl->alist); + qdisc_reset(cl->qdisc); + } + } +} + +static void drr_destroy_qdisc(struct Qdisc *sch) +{ + struct drr_sched *q = qdisc_priv(sch); + struct drr_class *cl; + struct hlist_node *next; + unsigned int i; + + tcf_block_put(q->block); + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry_safe(cl, next, &q->clhash.hash[i], + common.hnode) + drr_destroy_class(sch, cl); + } + qdisc_class_hash_destroy(&q->clhash); +} + +static const struct Qdisc_class_ops drr_class_ops = { + .change = drr_change_class, + .delete = drr_delete_class, + .find = drr_search_class, + .tcf_block = drr_tcf_block, + .bind_tcf = drr_bind_tcf, + .unbind_tcf = drr_unbind_tcf, + .graft = drr_graft_class, + .leaf = drr_class_leaf, + .qlen_notify = drr_qlen_notify, + .dump = drr_dump_class, + .dump_stats = drr_dump_class_stats, + .walk = drr_walk, +}; + +static struct Qdisc_ops drr_qdisc_ops __read_mostly = { + .cl_ops = &drr_class_ops, + .id = "drr", + .priv_size = sizeof(struct drr_sched), + .enqueue = drr_enqueue, + .dequeue = drr_dequeue, + .peek = qdisc_peek_dequeued, + .init = drr_init_qdisc, + .reset = drr_reset_qdisc, + .destroy = drr_destroy_qdisc, + .owner = THIS_MODULE, +}; + +static int __init drr_init(void) +{ + return register_qdisc(&drr_qdisc_ops); +} + +static void __exit drr_exit(void) +{ + unregister_qdisc(&drr_qdisc_ops); +} + +module_init(drr_init); +module_exit(drr_exit); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_dsmark.c b/net/sched/sch_dsmark.c new file mode 100644 index 000000000..401ffaf87 --- /dev/null +++ b/net/sched/sch_dsmark.c @@ -0,0 +1,518 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/sched/sch_dsmark.c - Differentiated Services 
field marker */ + +/* Written 1998-2000 by Werner Almesberger, EPFL ICA */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * classid class marking + * ------- ----- ------- + * n/a 0 n/a + * x:0 1 use entry [0] + * ... ... ... + * x:y y>0 y+1 use entry [y] + * ... ... ... + * x:indices-1 indices use entry [indices-1] + * ... ... ... + * x:y y+1 use entry [y & (indices-1)] + * ... ... ... + * 0xffff 0x10000 use entry [indices-1] + */ + + +#define NO_DEFAULT_INDEX (1 << 16) + +struct mask_value { + u8 mask; + u8 value; +}; + +struct dsmark_qdisc_data { + struct Qdisc *q; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + struct mask_value *mv; + u16 indices; + u8 set_tc_index; + u32 default_index; /* index range is 0...0xffff */ +#define DSMARK_EMBEDDED_SZ 16 + struct mask_value embedded[DSMARK_EMBEDDED_SZ]; +}; + +static inline int dsmark_valid_index(struct dsmark_qdisc_data *p, u16 index) +{ + return index <= p->indices && index > 0; +} + +/* ------------------------- Class/flow operations ------------------------- */ + +static int dsmark_graft(struct Qdisc *sch, unsigned long arg, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + pr_debug("%s(sch %p,[qdisc %p],new %p,old %p)\n", + __func__, sch, p, new, old); + + if (new == NULL) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + sch->handle, NULL); + if (new == NULL) + new = &noop_qdisc; + } + + *old = qdisc_replace(sch, new, &p->q); + return 0; +} + +static struct Qdisc *dsmark_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + return p->q; +} + +static unsigned long dsmark_find(struct Qdisc *sch, u32 classid) +{ + return TC_H_MIN(classid) + 1; +} + +static unsigned long dsmark_bind_filter(struct Qdisc *sch, + unsigned long parent, u32 classid) +{ + pr_debug("%s(sch %p,[qdisc %p],classid %x)\n", + __func__, sch, qdisc_priv(sch), classid); + + return dsmark_find(sch, classid); +} + +static void dsmark_unbind_filter(struct Qdisc *sch, unsigned long cl) +{ +} + +static const struct nla_policy dsmark_policy[TCA_DSMARK_MAX + 1] = { + [TCA_DSMARK_INDICES] = { .type = NLA_U16 }, + [TCA_DSMARK_DEFAULT_INDEX] = { .type = NLA_U16 }, + [TCA_DSMARK_SET_TC_INDEX] = { .type = NLA_FLAG }, + [TCA_DSMARK_MASK] = { .type = NLA_U8 }, + [TCA_DSMARK_VALUE] = { .type = NLA_U8 }, +}; + +static int dsmark_change(struct Qdisc *sch, u32 classid, u32 parent, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_DSMARK_MAX + 1]; + int err = -EINVAL; + + pr_debug("%s(sch %p,[qdisc %p],classid %x,parent %x), arg 0x%lx\n", + __func__, sch, p, classid, parent, *arg); + + if (!dsmark_valid_index(p, *arg)) { + err = -ENOENT; + goto errout; + } + + if (!opt) + goto errout; + + err = nla_parse_nested_deprecated(tb, TCA_DSMARK_MAX, opt, + dsmark_policy, NULL); + if (err < 0) + goto errout; + + if (tb[TCA_DSMARK_VALUE]) + p->mv[*arg - 1].value = nla_get_u8(tb[TCA_DSMARK_VALUE]); + + if (tb[TCA_DSMARK_MASK]) + p->mv[*arg - 1].mask = nla_get_u8(tb[TCA_DSMARK_MASK]); + + err = 0; + +errout: + return err; +} + +static int dsmark_delete(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + if 
(!dsmark_valid_index(p, arg)) + return -EINVAL; + + p->mv[arg - 1].mask = 0xff; + p->mv[arg - 1].value = 0; + + return 0; +} + +static void dsmark_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + int i; + + pr_debug("%s(sch %p,[qdisc %p],walker %p)\n", + __func__, sch, p, walker); + + if (walker->stop) + return; + + for (i = 0; i < p->indices; i++) { + if (p->mv[i].mask == 0xff && !p->mv[i].value) { + walker->count++; + continue; + } + if (!tc_qdisc_stats_dump(sch, i + 1, walker)) + break; + } +} + +static struct tcf_block *dsmark_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + return p->block; +} + +/* --------------------------- Qdisc operations ---------------------------- */ + +static int dsmark_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + struct dsmark_qdisc_data *p = qdisc_priv(sch); + int err; + + pr_debug("%s(skb %p,sch %p,[qdisc %p])\n", __func__, skb, sch, p); + + if (p->set_tc_index) { + int wlen = skb_network_offset(skb); + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + wlen += sizeof(struct iphdr); + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + goto drop; + + skb->tc_index = ipv4_get_dsfield(ip_hdr(skb)) + & ~INET_ECN_MASK; + break; + + case htons(ETH_P_IPV6): + wlen += sizeof(struct ipv6hdr); + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + goto drop; + + skb->tc_index = ipv6_get_dsfield(ipv6_hdr(skb)) + & ~INET_ECN_MASK; + break; + default: + skb->tc_index = 0; + break; + } + } + + if (TC_H_MAJ(skb->priority) == sch->handle) + skb->tc_index = TC_H_MIN(skb->priority); + else { + struct tcf_result res; + struct tcf_proto *fl = rcu_dereference_bh(p->filter_list); + int result = tcf_classify(skb, NULL, fl, &res, false); + + pr_debug("result %d class 0x%04x\n", result, res.classid); + + switch (result) { +#ifdef CONFIG_NET_CLS_ACT + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + __qdisc_drop(skb, to_free); + return NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + + case TC_ACT_SHOT: + goto drop; +#endif + case TC_ACT_OK: + skb->tc_index = TC_H_MIN(res.classid); + break; + + default: + if (p->default_index != NO_DEFAULT_INDEX) + skb->tc_index = p->default_index; + break; + } + } + + err = qdisc_enqueue(skb, p->q, to_free); + if (err != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(err)) + qdisc_qstats_drop(sch); + return err; + } + + sch->qstats.backlog += len; + sch->q.qlen++; + + return NET_XMIT_SUCCESS; + +drop: + qdisc_drop(skb, sch, to_free); + return NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; +} + +static struct sk_buff *dsmark_dequeue(struct Qdisc *sch) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + struct sk_buff *skb; + u32 index; + + pr_debug("%s(sch %p,[qdisc %p])\n", __func__, sch, p); + + skb = qdisc_dequeue_peeked(p->q); + if (skb == NULL) + return NULL; + + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + + index = skb->tc_index & (p->indices - 1); + pr_debug("index %d->%d\n", skb->tc_index, index); + + switch (skb_protocol(skb, true)) { + case htons(ETH_P_IP): + ipv4_change_dsfield(ip_hdr(skb), p->mv[index].mask, + p->mv[index].value); + break; + case htons(ETH_P_IPV6): + ipv6_change_dsfield(ipv6_hdr(skb), p->mv[index].mask, + p->mv[index].value); + break; + default: + /* + * Only complain if a change was actually attempted. 
+ * This way, we can send non-IP traffic through dsmark + * and don't need yet another qdisc as a bypass. + */ + if (p->mv[index].mask != 0xff || p->mv[index].value) + pr_warn("%s: unsupported protocol %d\n", + __func__, ntohs(skb_protocol(skb, true))); + break; + } + + return skb; +} + +static struct sk_buff *dsmark_peek(struct Qdisc *sch) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + pr_debug("%s(sch %p,[qdisc %p])\n", __func__, sch, p); + + return p->q->ops->peek(p->q); +} + +static int dsmark_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + struct nlattr *tb[TCA_DSMARK_MAX + 1]; + int err = -EINVAL; + u32 default_index = NO_DEFAULT_INDEX; + u16 indices; + int i; + + pr_debug("%s(sch %p,[qdisc %p],opt %p)\n", __func__, sch, p, opt); + + if (!opt) + goto errout; + + err = tcf_block_get(&p->block, &p->filter_list, sch, extack); + if (err) + return err; + + err = nla_parse_nested_deprecated(tb, TCA_DSMARK_MAX, opt, + dsmark_policy, NULL); + if (err < 0) + goto errout; + + err = -EINVAL; + if (!tb[TCA_DSMARK_INDICES]) + goto errout; + indices = nla_get_u16(tb[TCA_DSMARK_INDICES]); + + if (hweight32(indices) != 1) + goto errout; + + if (tb[TCA_DSMARK_DEFAULT_INDEX]) + default_index = nla_get_u16(tb[TCA_DSMARK_DEFAULT_INDEX]); + + if (indices <= DSMARK_EMBEDDED_SZ) + p->mv = p->embedded; + else + p->mv = kmalloc_array(indices, sizeof(*p->mv), GFP_KERNEL); + if (!p->mv) { + err = -ENOMEM; + goto errout; + } + for (i = 0; i < indices; i++) { + p->mv[i].mask = 0xff; + p->mv[i].value = 0; + } + p->indices = indices; + p->default_index = default_index; + p->set_tc_index = nla_get_flag(tb[TCA_DSMARK_SET_TC_INDEX]); + + p->q = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, sch->handle, + NULL); + if (p->q == NULL) + p->q = &noop_qdisc; + else + qdisc_hash_add(p->q, true); + + pr_debug("%s: qdisc %p\n", __func__, p->q); + + err = 0; +errout: + return err; +} + +static void dsmark_reset(struct Qdisc *sch) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + pr_debug("%s(sch %p,[qdisc %p])\n", __func__, sch, p); + if (p->q) + qdisc_reset(p->q); +} + +static void dsmark_destroy(struct Qdisc *sch) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + + pr_debug("%s(sch %p,[qdisc %p])\n", __func__, sch, p); + + tcf_block_put(p->block); + qdisc_put(p->q); + if (p->mv != p->embedded) + kfree(p->mv); +} + +static int dsmark_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + struct nlattr *opts = NULL; + + pr_debug("%s(sch %p,[qdisc %p],class %ld\n", __func__, sch, p, cl); + + if (!dsmark_valid_index(p, cl)) + return -EINVAL; + + tcm->tcm_handle = TC_H_MAKE(TC_H_MAJ(sch->handle), cl - 1); + tcm->tcm_info = p->q->handle; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + if (nla_put_u8(skb, TCA_DSMARK_MASK, p->mv[cl - 1].mask) || + nla_put_u8(skb, TCA_DSMARK_VALUE, p->mv[cl - 1].value)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static int dsmark_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct dsmark_qdisc_data *p = qdisc_priv(sch); + struct nlattr *opts = NULL; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + if (nla_put_u16(skb, TCA_DSMARK_INDICES, p->indices)) + goto nla_put_failure; + + if (p->default_index != 
NO_DEFAULT_INDEX && + nla_put_u16(skb, TCA_DSMARK_DEFAULT_INDEX, p->default_index)) + goto nla_put_failure; + + if (p->set_tc_index && + nla_put_flag(skb, TCA_DSMARK_SET_TC_INDEX)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static const struct Qdisc_class_ops dsmark_class_ops = { + .graft = dsmark_graft, + .leaf = dsmark_leaf, + .find = dsmark_find, + .change = dsmark_change, + .delete = dsmark_delete, + .walk = dsmark_walk, + .tcf_block = dsmark_tcf_block, + .bind_tcf = dsmark_bind_filter, + .unbind_tcf = dsmark_unbind_filter, + .dump = dsmark_dump_class, +}; + +static struct Qdisc_ops dsmark_qdisc_ops __read_mostly = { + .next = NULL, + .cl_ops = &dsmark_class_ops, + .id = "dsmark", + .priv_size = sizeof(struct dsmark_qdisc_data), + .enqueue = dsmark_enqueue, + .dequeue = dsmark_dequeue, + .peek = dsmark_peek, + .init = dsmark_init, + .reset = dsmark_reset, + .destroy = dsmark_destroy, + .change = NULL, + .dump = dsmark_dump, + .owner = THIS_MODULE, +}; + +static int __init dsmark_module_init(void) +{ + return register_qdisc(&dsmark_qdisc_ops); +} + +static void __exit dsmark_module_exit(void) +{ + unregister_qdisc(&dsmark_qdisc_ops); +} + +module_init(dsmark_module_init) +module_exit(dsmark_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_etf.c b/net/sched/sch_etf.c new file mode 100644 index 000000000..61d1f0e32 --- /dev/null +++ b/net/sched/sch_etf.c @@ -0,0 +1,515 @@ +// SPDX-License-Identifier: GPL-2.0 + +/* net/sched/sch_etf.c Earliest TxTime First queueing discipline. + * + * Authors: Jesus Sanchez-Palencia + * Vinicius Costa Gomes + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define DEADLINE_MODE_IS_ON(x) ((x)->flags & TC_ETF_DEADLINE_MODE_ON) +#define OFFLOAD_IS_ON(x) ((x)->flags & TC_ETF_OFFLOAD_ON) +#define SKIP_SOCK_CHECK_IS_SET(x) ((x)->flags & TC_ETF_SKIP_SOCK_CHECK) + +struct etf_sched_data { + bool offload; + bool deadline_mode; + bool skip_sock_check; + int clockid; + int queue; + s32 delta; /* in ns */ + ktime_t last; /* The txtime of the last skb sent to the netdevice. */ + struct rb_root_cached head; + struct qdisc_watchdog watchdog; + ktime_t (*get_time)(void); +}; + +static const struct nla_policy etf_policy[TCA_ETF_MAX + 1] = { + [TCA_ETF_PARMS] = { .len = sizeof(struct tc_etf_qopt) }, +}; + +static inline int validate_input_params(struct tc_etf_qopt *qopt, + struct netlink_ext_ack *extack) +{ + /* Check if params comply to the following rules: + * * Clockid and delta must be valid. + * + * * Dynamic clockids are not supported. + * + * * Delta must be a positive integer. + * + * Also note that for the HW offload case, we must + * expect that system clocks have been synchronized to PHC. + */ + if (qopt->clockid < 0) { + NL_SET_ERR_MSG(extack, "Dynamic clockids are not supported"); + return -ENOTSUPP; + } + + if (qopt->clockid != CLOCK_TAI) { + NL_SET_ERR_MSG(extack, "Invalid clockid. 
CLOCK_TAI must be used"); + return -EINVAL; + } + + if (qopt->delta < 0) { + NL_SET_ERR_MSG(extack, "Delta must be positive"); + return -EINVAL; + } + + return 0; +} + +static bool is_packet_valid(struct Qdisc *sch, struct sk_buff *nskb) +{ + struct etf_sched_data *q = qdisc_priv(sch); + ktime_t txtime = nskb->tstamp; + struct sock *sk = nskb->sk; + ktime_t now; + + if (q->skip_sock_check) + goto skip; + + if (!sk || !sk_fullsock(sk)) + return false; + + if (!sock_flag(sk, SOCK_TXTIME)) + return false; + + /* We don't perform crosstimestamping. + * Drop if packet's clockid differs from qdisc's. + */ + if (sk->sk_clockid != q->clockid) + return false; + + if (sk->sk_txtime_deadline_mode != q->deadline_mode) + return false; + +skip: + now = q->get_time(); + if (ktime_before(txtime, now) || ktime_before(txtime, q->last)) + return false; + + return true; +} + +static struct sk_buff *etf_peek_timesortedlist(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct rb_node *p; + + p = rb_first_cached(&q->head); + if (!p) + return NULL; + + return rb_to_skb(p); +} + +static void reset_watchdog(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb = etf_peek_timesortedlist(sch); + ktime_t next; + + if (!skb) { + qdisc_watchdog_cancel(&q->watchdog); + return; + } + + next = ktime_sub_ns(skb->tstamp, q->delta); + qdisc_watchdog_schedule_ns(&q->watchdog, ktime_to_ns(next)); +} + +static void report_sock_error(struct sk_buff *skb, u32 err, u8 code) +{ + struct sock_exterr_skb *serr; + struct sk_buff *clone; + ktime_t txtime = skb->tstamp; + struct sock *sk = skb->sk; + + if (!sk || !sk_fullsock(sk) || !(sk->sk_txtime_report_errors)) + return; + + clone = skb_clone(skb, GFP_ATOMIC); + if (!clone) + return; + + serr = SKB_EXT_ERR(clone); + serr->ee.ee_errno = err; + serr->ee.ee_origin = SO_EE_ORIGIN_TXTIME; + serr->ee.ee_type = 0; + serr->ee.ee_code = code; + serr->ee.ee_pad = 0; + serr->ee.ee_data = (txtime >> 32); /* high part of tstamp */ + serr->ee.ee_info = txtime; /* low part of tstamp */ + + if (sock_queue_err_skb(sk, clone)) + kfree_skb(clone); +} + +static int etf_enqueue_timesortedlist(struct sk_buff *nskb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct rb_node **p = &q->head.rb_root.rb_node, *parent = NULL; + ktime_t txtime = nskb->tstamp; + bool leftmost = true; + + if (!is_packet_valid(sch, nskb)) { + report_sock_error(nskb, EINVAL, + SO_EE_CODE_TXTIME_INVALID_PARAM); + return qdisc_drop(nskb, sch, to_free); + } + + while (*p) { + struct sk_buff *skb; + + parent = *p; + skb = rb_to_skb(parent); + if (ktime_compare(txtime, skb->tstamp) >= 0) { + p = &parent->rb_right; + leftmost = false; + } else { + p = &parent->rb_left; + } + } + rb_link_node(&nskb->rbnode, parent, p); + rb_insert_color_cached(&nskb->rbnode, &q->head, leftmost); + + qdisc_qstats_backlog_inc(sch, nskb); + sch->q.qlen++; + + /* Now we may need to re-arm the qdisc watchdog for the next packet. */ + reset_watchdog(sch); + + return NET_XMIT_SUCCESS; +} + +static void timesortedlist_drop(struct Qdisc *sch, struct sk_buff *skb, + ktime_t now) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct sk_buff *to_free = NULL; + struct sk_buff *tmp = NULL; + + skb_rbtree_walk_from_safe(skb, tmp) { + if (ktime_after(skb->tstamp, now)) + break; + + rb_erase_cached(&skb->rbnode, &q->head); + + /* The rbnode field in the skb re-uses these fields, now that + * we are done with the rbnode, reset them. 
+ */ + skb->next = NULL; + skb->prev = NULL; + skb->dev = qdisc_dev(sch); + + report_sock_error(skb, ECANCELED, SO_EE_CODE_TXTIME_MISSED); + + qdisc_qstats_backlog_dec(sch, skb); + qdisc_drop(skb, sch, &to_free); + qdisc_qstats_overlimit(sch); + sch->q.qlen--; + } + + kfree_skb_list(to_free); +} + +static void timesortedlist_remove(struct Qdisc *sch, struct sk_buff *skb) +{ + struct etf_sched_data *q = qdisc_priv(sch); + + rb_erase_cached(&skb->rbnode, &q->head); + + /* The rbnode field in the skb re-uses these fields, now that + * we are done with the rbnode, reset them. + */ + skb->next = NULL; + skb->prev = NULL; + skb->dev = qdisc_dev(sch); + + qdisc_qstats_backlog_dec(sch, skb); + + qdisc_bstats_update(sch, skb); + + q->last = skb->tstamp; + + sch->q.qlen--; +} + +static struct sk_buff *etf_dequeue_timesortedlist(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + ktime_t now, next; + + skb = etf_peek_timesortedlist(sch); + if (!skb) + return NULL; + + now = q->get_time(); + + /* Drop if packet has expired while in queue. */ + if (ktime_before(skb->tstamp, now)) { + timesortedlist_drop(sch, skb, now); + skb = NULL; + goto out; + } + + /* When in deadline mode, dequeue as soon as possible and change the + * txtime from deadline to (now + delta). + */ + if (q->deadline_mode) { + timesortedlist_remove(sch, skb); + skb->tstamp = now; + goto out; + } + + next = ktime_sub_ns(skb->tstamp, q->delta); + + /* Dequeue only if now is within the [txtime - delta, txtime] range. */ + if (ktime_after(now, next)) + timesortedlist_remove(sch, skb); + else + skb = NULL; + +out: + /* Now we may need to re-arm the qdisc watchdog for the next packet. */ + reset_watchdog(sch); + + return skb; +} + +static void etf_disable_offload(struct net_device *dev, + struct etf_sched_data *q) +{ + struct tc_etf_qopt_offload etf = { }; + const struct net_device_ops *ops; + int err; + + if (!q->offload) + return; + + ops = dev->netdev_ops; + if (!ops->ndo_setup_tc) + return; + + etf.queue = q->queue; + etf.enable = 0; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_ETF, &etf); + if (err < 0) + pr_warn("Couldn't disable ETF offload for queue %d\n", + etf.queue); +} + +static int etf_enable_offload(struct net_device *dev, struct etf_sched_data *q, + struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct tc_etf_qopt_offload etf = { }; + int err; + + if (!ops->ndo_setup_tc) { + NL_SET_ERR_MSG(extack, "Specified device does not support ETF offload"); + return -EOPNOTSUPP; + } + + etf.queue = q->queue; + etf.enable = 1; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_ETF, &etf); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Specified device failed to setup ETF hardware offload"); + return err; + } + + return 0; +} + +static int etf_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct nlattr *tb[TCA_ETF_MAX + 1]; + struct tc_etf_qopt *qopt; + int err; + + if (!opt) { + NL_SET_ERR_MSG(extack, + "Missing ETF qdisc options which are mandatory"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_ETF_MAX, opt, etf_policy, + extack); + if (err < 0) + return err; + + if (!tb[TCA_ETF_PARMS]) { + NL_SET_ERR_MSG(extack, "Missing mandatory ETF parameters"); + return -EINVAL; + } + + qopt = nla_data(tb[TCA_ETF_PARMS]); + + pr_debug("delta %d clockid %d offload %s deadline %s\n", + qopt->delta, qopt->clockid, + 
OFFLOAD_IS_ON(qopt) ? "on" : "off", + DEADLINE_MODE_IS_ON(qopt) ? "on" : "off"); + + err = validate_input_params(qopt, extack); + if (err < 0) + return err; + + q->queue = sch->dev_queue - netdev_get_tx_queue(dev, 0); + + if (OFFLOAD_IS_ON(qopt)) { + err = etf_enable_offload(dev, q, extack); + if (err < 0) + return err; + } + + /* Everything went OK, save the parameters used. */ + q->delta = qopt->delta; + q->clockid = qopt->clockid; + q->offload = OFFLOAD_IS_ON(qopt); + q->deadline_mode = DEADLINE_MODE_IS_ON(qopt); + q->skip_sock_check = SKIP_SOCK_CHECK_IS_SET(qopt); + + switch (q->clockid) { + case CLOCK_REALTIME: + q->get_time = ktime_get_real; + break; + case CLOCK_MONOTONIC: + q->get_time = ktime_get; + break; + case CLOCK_BOOTTIME: + q->get_time = ktime_get_boottime; + break; + case CLOCK_TAI: + q->get_time = ktime_get_clocktai; + break; + default: + NL_SET_ERR_MSG(extack, "Clockid is not supported"); + return -ENOTSUPP; + } + + qdisc_watchdog_init_clockid(&q->watchdog, sch, q->clockid); + + return 0; +} + +static void timesortedlist_clear(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct rb_node *p = rb_first_cached(&q->head); + + while (p) { + struct sk_buff *skb = rb_to_skb(p); + + p = rb_next(p); + + rb_erase_cached(&skb->rbnode, &q->head); + rtnl_kfree_skbs(skb, skb); + sch->q.qlen--; + } +} + +static void etf_reset(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + + /* Only cancel watchdog if it's been initialized. */ + if (q->watchdog.qdisc == sch) + qdisc_watchdog_cancel(&q->watchdog); + + /* No matter which mode we are on, it's safe to clear both lists. */ + timesortedlist_clear(sch); + __qdisc_reset_queue(&sch->q); + + q->last = 0; +} + +static void etf_destroy(struct Qdisc *sch) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + + /* Only cancel watchdog if it's been initialized. 
*/ + if (q->watchdog.qdisc == sch) + qdisc_watchdog_cancel(&q->watchdog); + + etf_disable_offload(dev, q); +} + +static int etf_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct etf_sched_data *q = qdisc_priv(sch); + struct tc_etf_qopt opt = { }; + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + + opt.delta = q->delta; + opt.clockid = q->clockid; + if (q->offload) + opt.flags |= TC_ETF_OFFLOAD_ON; + + if (q->deadline_mode) + opt.flags |= TC_ETF_DEADLINE_MODE_ON; + + if (q->skip_sock_check) + opt.flags |= TC_ETF_SKIP_SOCK_CHECK; + + if (nla_put(skb, TCA_ETF_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static struct Qdisc_ops etf_qdisc_ops __read_mostly = { + .id = "etf", + .priv_size = sizeof(struct etf_sched_data), + .enqueue = etf_enqueue_timesortedlist, + .dequeue = etf_dequeue_timesortedlist, + .peek = etf_peek_timesortedlist, + .init = etf_init, + .reset = etf_reset, + .destroy = etf_destroy, + .dump = etf_dump, + .owner = THIS_MODULE, +}; + +static int __init etf_module_init(void) +{ + return register_qdisc(&etf_qdisc_ops); +} + +static void __exit etf_module_exit(void) +{ + unregister_qdisc(&etf_qdisc_ops); +} +module_init(etf_module_init) +module_exit(etf_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_ets.c b/net/sched/sch_ets.c new file mode 100644 index 000000000..b10efeaf0 --- /dev/null +++ b/net/sched/sch_ets.c @@ -0,0 +1,828 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_ets.c Enhanced Transmission Selection scheduler + * + * Description + * ----------- + * + * The Enhanced Transmission Selection scheduler is a classful queuing + * discipline that merges functionality of PRIO and DRR qdiscs in one scheduler. + * ETS makes it easy to configure a set of strict and bandwidth-sharing bands to + * implement the transmission selection described in 802.1Qaz. + * + * Although ETS is technically classful, it's not possible to add and remove + * classes at will. Instead one specifies number of classes, how many are + * PRIO-like and how many DRR-like, and quanta for the latter. + * + * Algorithm + * --------- + * + * The strict classes, if any, are tried for traffic first: first band 0, if it + * has no traffic then band 1, etc. + * + * When there is no traffic in any of the strict queues, the bandwidth-sharing + * ones are tried next. Each band is assigned a deficit counter, initialized to + * "quantum" of that band. ETS maintains a list of active bandwidth-sharing + * bands whose qdiscs are non-empty. A packet is dequeued from the band at the + * head of the list if the packet size is smaller or equal to the deficit + * counter. If the counter is too small, it is increased by "quantum" and the + * scheduler moves on to the next band in the active list. + */ + +#include +#include +#include +#include +#include +#include + +struct ets_class { + struct list_head alist; /* In struct ets_sched.active. 
*/ + struct Qdisc *qdisc; + u32 quantum; + u32 deficit; + struct gnet_stats_basic_sync bstats; + struct gnet_stats_queue qstats; +}; + +struct ets_sched { + struct list_head active; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + unsigned int nbands; + unsigned int nstrict; + u8 prio2band[TC_PRIO_MAX + 1]; + struct ets_class classes[TCQ_ETS_MAX_BANDS]; +}; + +static const struct nla_policy ets_policy[TCA_ETS_MAX + 1] = { + [TCA_ETS_NBANDS] = { .type = NLA_U8 }, + [TCA_ETS_NSTRICT] = { .type = NLA_U8 }, + [TCA_ETS_QUANTA] = { .type = NLA_NESTED }, + [TCA_ETS_PRIOMAP] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy ets_priomap_policy[TCA_ETS_MAX + 1] = { + [TCA_ETS_PRIOMAP_BAND] = { .type = NLA_U8 }, +}; + +static const struct nla_policy ets_quanta_policy[TCA_ETS_MAX + 1] = { + [TCA_ETS_QUANTA_BAND] = { .type = NLA_U32 }, +}; + +static const struct nla_policy ets_class_policy[TCA_ETS_MAX + 1] = { + [TCA_ETS_QUANTA_BAND] = { .type = NLA_U32 }, +}; + +static int ets_quantum_parse(struct Qdisc *sch, const struct nlattr *attr, + unsigned int *quantum, + struct netlink_ext_ack *extack) +{ + *quantum = nla_get_u32(attr); + if (!*quantum) { + NL_SET_ERR_MSG(extack, "ETS quantum cannot be zero"); + return -EINVAL; + } + return 0; +} + +static struct ets_class * +ets_class_from_arg(struct Qdisc *sch, unsigned long arg) +{ + struct ets_sched *q = qdisc_priv(sch); + + return &q->classes[arg - 1]; +} + +static u32 ets_class_id(struct Qdisc *sch, const struct ets_class *cl) +{ + struct ets_sched *q = qdisc_priv(sch); + int band = cl - q->classes; + + return TC_H_MAKE(sch->handle, band + 1); +} + +static void ets_offload_change(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct ets_sched *q = qdisc_priv(sch); + struct tc_ets_qopt_offload qopt; + unsigned int w_psum_prev = 0; + unsigned int q_psum = 0; + unsigned int q_sum = 0; + unsigned int quantum; + unsigned int w_psum; + unsigned int weight; + unsigned int i; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_ETS_REPLACE; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.replace_params.bands = q->nbands; + qopt.replace_params.qstats = &sch->qstats; + memcpy(&qopt.replace_params.priomap, + q->prio2band, sizeof(q->prio2band)); + + for (i = 0; i < q->nbands; i++) + q_sum += q->classes[i].quantum; + + for (i = 0; i < q->nbands; i++) { + quantum = q->classes[i].quantum; + q_psum += quantum; + w_psum = quantum ? 
q_psum * 100 / q_sum : 0; + weight = w_psum - w_psum_prev; + w_psum_prev = w_psum; + + qopt.replace_params.quanta[i] = quantum; + qopt.replace_params.weights[i] = weight; + } + + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_ETS, &qopt); +} + +static void ets_offload_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_ets_qopt_offload qopt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_ETS_DESTROY; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_ETS, &qopt); +} + +static void ets_offload_graft(struct Qdisc *sch, struct Qdisc *new, + struct Qdisc *old, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_ets_qopt_offload qopt; + + qopt.command = TC_ETS_GRAFT; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.graft_params.band = arg - 1; + qopt.graft_params.child_handle = new->handle; + + qdisc_offload_graft_helper(dev, sch, new, old, TC_SETUP_QDISC_ETS, + &qopt, extack); +} + +static int ets_offload_dump(struct Qdisc *sch) +{ + struct tc_ets_qopt_offload qopt; + + qopt.command = TC_ETS_STATS; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.stats.bstats = &sch->bstats; + qopt.stats.qstats = &sch->qstats; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_ETS, &qopt); +} + +static bool ets_class_is_strict(struct ets_sched *q, const struct ets_class *cl) +{ + unsigned int band = cl - q->classes; + + return band < q->nstrict; +} + +static int ets_class_change(struct Qdisc *sch, u32 classid, u32 parentid, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct ets_class *cl = ets_class_from_arg(sch, *arg); + struct ets_sched *q = qdisc_priv(sch); + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_ETS_MAX + 1]; + unsigned int quantum; + int err; + + /* Classes can be added and removed only through Qdisc_ops.change + * interface. + */ + if (!cl) { + NL_SET_ERR_MSG(extack, "Fine-grained class addition and removal is not supported"); + return -EOPNOTSUPP; + } + + if (!opt) { + NL_SET_ERR_MSG(extack, "ETS options are required for this operation"); + return -EINVAL; + } + + err = nla_parse_nested(tb, TCA_ETS_MAX, opt, ets_class_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_ETS_QUANTA_BAND]) + /* Nothing to configure. 
*/ + return 0; + + if (ets_class_is_strict(q, cl)) { + NL_SET_ERR_MSG(extack, "Strict bands do not have a configurable quantum"); + return -EINVAL; + } + + err = ets_quantum_parse(sch, tb[TCA_ETS_QUANTA_BAND], &quantum, + extack); + if (err) + return err; + + sch_tree_lock(sch); + cl->quantum = quantum; + sch_tree_unlock(sch); + + ets_offload_change(sch); + return 0; +} + +static int ets_class_graft(struct Qdisc *sch, unsigned long arg, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct ets_class *cl = ets_class_from_arg(sch, arg); + + if (!new) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + ets_class_id(sch, cl), NULL); + if (!new) + new = &noop_qdisc; + else + qdisc_hash_add(new, true); + } + + *old = qdisc_replace(sch, new, &cl->qdisc); + ets_offload_graft(sch, new, *old, arg, extack); + return 0; +} + +static struct Qdisc *ets_class_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct ets_class *cl = ets_class_from_arg(sch, arg); + + return cl->qdisc; +} + +static unsigned long ets_class_find(struct Qdisc *sch, u32 classid) +{ + unsigned long band = TC_H_MIN(classid); + struct ets_sched *q = qdisc_priv(sch); + + if (band - 1 >= q->nbands) + return 0; + return band; +} + +static void ets_class_qlen_notify(struct Qdisc *sch, unsigned long arg) +{ + struct ets_class *cl = ets_class_from_arg(sch, arg); + struct ets_sched *q = qdisc_priv(sch); + + /* We get notified about zero-length child Qdiscs as well if they are + * offloaded. Those aren't on the active list though, so don't attempt + * to remove them. + */ + if (!ets_class_is_strict(q, cl) && sch->q.qlen) + list_del(&cl->alist); +} + +static int ets_class_dump(struct Qdisc *sch, unsigned long arg, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct ets_class *cl = ets_class_from_arg(sch, arg); + struct ets_sched *q = qdisc_priv(sch); + struct nlattr *nest; + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle = ets_class_id(sch, cl); + tcm->tcm_info = cl->qdisc->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto nla_put_failure; + if (!ets_class_is_strict(q, cl)) { + if (nla_put_u32(skb, TCA_ETS_QUANTA_BAND, cl->quantum)) + goto nla_put_failure; + } + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int ets_class_dump_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct ets_class *cl = ets_class_from_arg(sch, arg); + struct Qdisc *cl_q = cl->qdisc; + + if (gnet_stats_copy_basic(d, NULL, &cl_q->bstats, true) < 0 || + qdisc_qstats_copy(d, cl_q) < 0) + return -1; + + return 0; +} + +static void ets_qdisc_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct ets_sched *q = qdisc_priv(sch); + int i; + + if (arg->stop) + return; + + for (i = 0; i < q->nbands; i++) { + if (!tc_qdisc_stats_dump(sch, i + 1, arg)) + break; + } +} + +static struct tcf_block * +ets_qdisc_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct ets_sched *q = qdisc_priv(sch); + + if (cl) { + NL_SET_ERR_MSG(extack, "ETS classid must be zero"); + return NULL; + } + + return q->block; +} + +static unsigned long ets_qdisc_bind_tcf(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return ets_class_find(sch, classid); +} + +static void ets_qdisc_unbind_tcf(struct Qdisc *sch, unsigned long arg) +{ +} + +static struct ets_class *ets_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct ets_sched *q = qdisc_priv(sch); + u32 band 
= skb->priority; + struct tcf_result res; + struct tcf_proto *fl; + int err; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + if (TC_H_MAJ(skb->priority) != sch->handle) { + fl = rcu_dereference_bh(q->filter_list); + err = tcf_classify(skb, NULL, fl, &res, false); +#ifdef CONFIG_NET_CLS_ACT + switch (err) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + if (!fl || err < 0) { + if (TC_H_MAJ(band)) + band = 0; + return &q->classes[q->prio2band[band & TC_PRIO_MAX]]; + } + band = res.classid; + } + band = TC_H_MIN(band) - 1; + if (band >= q->nbands) + return &q->classes[q->prio2band[0]]; + return &q->classes[band]; +} + +static int ets_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + struct ets_sched *q = qdisc_priv(sch); + struct ets_class *cl; + int err = 0; + bool first; + + cl = ets_classify(skb, sch, &err); + if (!cl) { + if (err & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return err; + } + + first = !cl->qdisc->q.qlen; + err = qdisc_enqueue(skb, cl->qdisc, to_free); + if (unlikely(err != NET_XMIT_SUCCESS)) { + if (net_xmit_drop_count(err)) { + cl->qstats.drops++; + qdisc_qstats_drop(sch); + } + return err; + } + + if (first && !ets_class_is_strict(q, cl)) { + list_add_tail(&cl->alist, &q->active); + cl->deficit = cl->quantum; + } + + sch->qstats.backlog += len; + sch->q.qlen++; + return err; +} + +static struct sk_buff * +ets_qdisc_dequeue_skb(struct Qdisc *sch, struct sk_buff *skb) +{ + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + return skb; +} + +static struct sk_buff *ets_qdisc_dequeue(struct Qdisc *sch) +{ + struct ets_sched *q = qdisc_priv(sch); + struct ets_class *cl; + struct sk_buff *skb; + unsigned int band; + unsigned int len; + + while (1) { + for (band = 0; band < q->nstrict; band++) { + cl = &q->classes[band]; + skb = qdisc_dequeue_peeked(cl->qdisc); + if (skb) + return ets_qdisc_dequeue_skb(sch, skb); + } + + if (list_empty(&q->active)) + goto out; + + cl = list_first_entry(&q->active, struct ets_class, alist); + skb = cl->qdisc->ops->peek(cl->qdisc); + if (!skb) { + qdisc_warn_nonwc(__func__, cl->qdisc); + goto out; + } + + len = qdisc_pkt_len(skb); + if (len <= cl->deficit) { + cl->deficit -= len; + skb = qdisc_dequeue_peeked(cl->qdisc); + if (unlikely(!skb)) + goto out; + if (cl->qdisc->q.qlen == 0) + list_del(&cl->alist); + return ets_qdisc_dequeue_skb(sch, skb); + } + + cl->deficit += cl->quantum; + list_move_tail(&cl->alist, &q->active); + } +out: + return NULL; +} + +static int ets_qdisc_priomap_parse(struct nlattr *priomap_attr, + unsigned int nbands, u8 *priomap, + struct netlink_ext_ack *extack) +{ + const struct nlattr *attr; + int prio = 0; + u8 band; + int rem; + int err; + + err = __nla_validate_nested(priomap_attr, TCA_ETS_MAX, + ets_priomap_policy, NL_VALIDATE_STRICT, + extack); + if (err) + return err; + + nla_for_each_nested(attr, priomap_attr, rem) { + switch (nla_type(attr)) { + case TCA_ETS_PRIOMAP_BAND: + if (prio > TC_PRIO_MAX) { + NL_SET_ERR_MSG_MOD(extack, "Too many priorities in ETS priomap"); + return -EINVAL; + } + band = nla_get_u8(attr); + if (band >= nbands) { + NL_SET_ERR_MSG_MOD(extack, "Invalid band number in ETS priomap"); + return -EINVAL; + } + priomap[prio++] = band; + break; + default: + WARN_ON_ONCE(1); /* Validate should have caught this. 
*/ + return -EINVAL; + } + } + + return 0; +} + +static int ets_qdisc_quanta_parse(struct Qdisc *sch, struct nlattr *quanta_attr, + unsigned int nbands, unsigned int nstrict, + unsigned int *quanta, + struct netlink_ext_ack *extack) +{ + const struct nlattr *attr; + int band = nstrict; + int rem; + int err; + + err = __nla_validate_nested(quanta_attr, TCA_ETS_MAX, + ets_quanta_policy, NL_VALIDATE_STRICT, + extack); + if (err < 0) + return err; + + nla_for_each_nested(attr, quanta_attr, rem) { + switch (nla_type(attr)) { + case TCA_ETS_QUANTA_BAND: + if (band >= nbands) { + NL_SET_ERR_MSG_MOD(extack, "ETS quanta has more values than bands"); + return -EINVAL; + } + err = ets_quantum_parse(sch, attr, &quanta[band++], + extack); + if (err) + return err; + break; + default: + WARN_ON_ONCE(1); /* Validate should have caught this. */ + return -EINVAL; + } + } + + return 0; +} + +static int ets_qdisc_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + unsigned int quanta[TCQ_ETS_MAX_BANDS] = {0}; + struct Qdisc *queues[TCQ_ETS_MAX_BANDS]; + struct ets_sched *q = qdisc_priv(sch); + struct nlattr *tb[TCA_ETS_MAX + 1]; + unsigned int oldbands = q->nbands; + u8 priomap[TC_PRIO_MAX + 1]; + unsigned int nstrict = 0; + unsigned int nbands; + unsigned int i; + int err; + + err = nla_parse_nested(tb, TCA_ETS_MAX, opt, ets_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_ETS_NBANDS]) { + NL_SET_ERR_MSG_MOD(extack, "Number of bands is a required argument"); + return -EINVAL; + } + nbands = nla_get_u8(tb[TCA_ETS_NBANDS]); + if (nbands < 1 || nbands > TCQ_ETS_MAX_BANDS) { + NL_SET_ERR_MSG_MOD(extack, "Invalid number of bands"); + return -EINVAL; + } + /* Unless overridden, traffic goes to the last band. */ + memset(priomap, nbands - 1, sizeof(priomap)); + + if (tb[TCA_ETS_NSTRICT]) { + nstrict = nla_get_u8(tb[TCA_ETS_NSTRICT]); + if (nstrict > nbands) { + NL_SET_ERR_MSG_MOD(extack, "Invalid number of strict bands"); + return -EINVAL; + } + } + + if (tb[TCA_ETS_PRIOMAP]) { + err = ets_qdisc_priomap_parse(tb[TCA_ETS_PRIOMAP], + nbands, priomap, extack); + if (err) + return err; + } + + if (tb[TCA_ETS_QUANTA]) { + err = ets_qdisc_quanta_parse(sch, tb[TCA_ETS_QUANTA], + nbands, nstrict, quanta, extack); + if (err) + return err; + } + /* If there are more bands than strict + quanta provided, the remaining + * ones are ETS with quantum of MTU. Initialize the missing values here. 
+ */ + for (i = nstrict; i < nbands; i++) { + if (!quanta[i]) + quanta[i] = psched_mtu(qdisc_dev(sch)); + } + + /* Before commit, make sure we can allocate all new qdiscs */ + for (i = oldbands; i < nbands; i++) { + queues[i] = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + ets_class_id(sch, &q->classes[i]), + extack); + if (!queues[i]) { + while (i > oldbands) + qdisc_put(queues[--i]); + return -ENOMEM; + } + } + + sch_tree_lock(sch); + + q->nbands = nbands; + for (i = nstrict; i < q->nstrict; i++) { + if (q->classes[i].qdisc->q.qlen) { + list_add_tail(&q->classes[i].alist, &q->active); + q->classes[i].deficit = quanta[i]; + } + } + for (i = q->nbands; i < oldbands; i++) { + if (i >= q->nstrict && q->classes[i].qdisc->q.qlen) + list_del(&q->classes[i].alist); + qdisc_tree_flush_backlog(q->classes[i].qdisc); + } + q->nstrict = nstrict; + memcpy(q->prio2band, priomap, sizeof(priomap)); + + for (i = 0; i < q->nbands; i++) + q->classes[i].quantum = quanta[i]; + + for (i = oldbands; i < q->nbands; i++) { + q->classes[i].qdisc = queues[i]; + if (q->classes[i].qdisc != &noop_qdisc) + qdisc_hash_add(q->classes[i].qdisc, true); + } + + sch_tree_unlock(sch); + + ets_offload_change(sch); + for (i = q->nbands; i < oldbands; i++) { + qdisc_put(q->classes[i].qdisc); + q->classes[i].qdisc = NULL; + q->classes[i].quantum = 0; + q->classes[i].deficit = 0; + gnet_stats_basic_sync_init(&q->classes[i].bstats); + memset(&q->classes[i].qstats, 0, sizeof(q->classes[i].qstats)); + } + return 0; +} + +static int ets_qdisc_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct ets_sched *q = qdisc_priv(sch); + int err, i; + + if (!opt) + return -EINVAL; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + INIT_LIST_HEAD(&q->active); + for (i = 0; i < TCQ_ETS_MAX_BANDS; i++) + INIT_LIST_HEAD(&q->classes[i].alist); + + return ets_qdisc_change(sch, opt, extack); +} + +static void ets_qdisc_reset(struct Qdisc *sch) +{ + struct ets_sched *q = qdisc_priv(sch); + int band; + + for (band = q->nstrict; band < q->nbands; band++) { + if (q->classes[band].qdisc->q.qlen) + list_del(&q->classes[band].alist); + } + for (band = 0; band < q->nbands; band++) + qdisc_reset(q->classes[band].qdisc); +} + +static void ets_qdisc_destroy(struct Qdisc *sch) +{ + struct ets_sched *q = qdisc_priv(sch); + int band; + + ets_offload_destroy(sch); + tcf_block_put(q->block); + for (band = 0; band < q->nbands; band++) + qdisc_put(q->classes[band].qdisc); +} + +static int ets_qdisc_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct ets_sched *q = qdisc_priv(sch); + struct nlattr *opts; + struct nlattr *nest; + int band; + int prio; + int err; + + err = ets_offload_dump(sch); + if (err) + return err; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!opts) + goto nla_err; + + if (nla_put_u8(skb, TCA_ETS_NBANDS, q->nbands)) + goto nla_err; + + if (q->nstrict && + nla_put_u8(skb, TCA_ETS_NSTRICT, q->nstrict)) + goto nla_err; + + if (q->nbands > q->nstrict) { + nest = nla_nest_start(skb, TCA_ETS_QUANTA); + if (!nest) + goto nla_err; + + for (band = q->nstrict; band < q->nbands; band++) { + if (nla_put_u32(skb, TCA_ETS_QUANTA_BAND, + q->classes[band].quantum)) + goto nla_err; + } + + nla_nest_end(skb, nest); + } + + nest = nla_nest_start(skb, TCA_ETS_PRIOMAP); + if (!nest) + goto nla_err; + + for (prio = 0; prio <= TC_PRIO_MAX; prio++) { + if (nla_put_u8(skb, TCA_ETS_PRIOMAP_BAND, q->prio2band[prio])) + goto nla_err; + } + + nla_nest_end(skb, nest); + + 
return nla_nest_end(skb, opts); + +nla_err: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static const struct Qdisc_class_ops ets_class_ops = { + .change = ets_class_change, + .graft = ets_class_graft, + .leaf = ets_class_leaf, + .find = ets_class_find, + .qlen_notify = ets_class_qlen_notify, + .dump = ets_class_dump, + .dump_stats = ets_class_dump_stats, + .walk = ets_qdisc_walk, + .tcf_block = ets_qdisc_tcf_block, + .bind_tcf = ets_qdisc_bind_tcf, + .unbind_tcf = ets_qdisc_unbind_tcf, +}; + +static struct Qdisc_ops ets_qdisc_ops __read_mostly = { + .cl_ops = &ets_class_ops, + .id = "ets", + .priv_size = sizeof(struct ets_sched), + .enqueue = ets_qdisc_enqueue, + .dequeue = ets_qdisc_dequeue, + .peek = qdisc_peek_dequeued, + .change = ets_qdisc_change, + .init = ets_qdisc_init, + .reset = ets_qdisc_reset, + .destroy = ets_qdisc_destroy, + .dump = ets_qdisc_dump, + .owner = THIS_MODULE, +}; + +static int __init ets_init(void) +{ + return register_qdisc(&ets_qdisc_ops); +} + +static void __exit ets_exit(void) +{ + unregister_qdisc(&ets_qdisc_ops); +} + +module_init(ets_init); +module_exit(ets_exit); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_fifo.c b/net/sched/sch_fifo.c new file mode 100644 index 000000000..e1040421b --- /dev/null +++ b/net/sched/sch_fifo.c @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_fifo.c The simplest FIFO queue. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* 1 band FIFO pseudo-"scheduler" */ + +static int bfifo_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + if (likely(sch->qstats.backlog + qdisc_pkt_len(skb) <= sch->limit)) + return qdisc_enqueue_tail(skb, sch); + + return qdisc_drop(skb, sch, to_free); +} + +static int pfifo_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + if (likely(sch->q.qlen < sch->limit)) + return qdisc_enqueue_tail(skb, sch); + + return qdisc_drop(skb, sch, to_free); +} + +static int pfifo_tail_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + unsigned int prev_backlog; + + if (likely(sch->q.qlen < sch->limit)) + return qdisc_enqueue_tail(skb, sch); + + prev_backlog = sch->qstats.backlog; + /* queue full, remove one skb to fulfill the limit */ + __qdisc_queue_drop_head(sch, &sch->q, to_free); + qdisc_qstats_drop(sch); + qdisc_enqueue_tail(skb, sch); + + qdisc_tree_reduce_backlog(sch, 0, prev_backlog - sch->qstats.backlog); + return NET_XMIT_CN; +} + +static void fifo_offload_init(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_fifo_qopt_offload qopt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_FIFO_REPLACE; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_FIFO, &qopt); +} + +static void fifo_offload_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_fifo_qopt_offload qopt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_FIFO_DESTROY; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_FIFO, &qopt); +} + +static int fifo_offload_dump(struct Qdisc *sch) +{ + struct tc_fifo_qopt_offload qopt; + + qopt.command = TC_FIFO_STATS; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.stats.bstats = &sch->bstats; + qopt.stats.qstats = 
&sch->qstats; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_FIFO, &qopt); +} + +static int __fifo_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + bool bypass; + bool is_bfifo = sch->ops == &bfifo_qdisc_ops; + + if (opt == NULL) { + u32 limit = qdisc_dev(sch)->tx_queue_len; + + if (is_bfifo) + limit *= psched_mtu(qdisc_dev(sch)); + + sch->limit = limit; + } else { + struct tc_fifo_qopt *ctl = nla_data(opt); + + if (nla_len(opt) < sizeof(*ctl)) + return -EINVAL; + + sch->limit = ctl->limit; + } + + if (is_bfifo) + bypass = sch->limit >= psched_mtu(qdisc_dev(sch)); + else + bypass = sch->limit >= 1; + + if (bypass) + sch->flags |= TCQ_F_CAN_BYPASS; + else + sch->flags &= ~TCQ_F_CAN_BYPASS; + + return 0; +} + +static int fifo_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + int err; + + err = __fifo_init(sch, opt, extack); + if (err) + return err; + + fifo_offload_init(sch); + return 0; +} + +static int fifo_hd_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + return __fifo_init(sch, opt, extack); +} + +static void fifo_destroy(struct Qdisc *sch) +{ + fifo_offload_destroy(sch); +} + +static int __fifo_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct tc_fifo_qopt opt = { .limit = sch->limit }; + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + return -1; +} + +static int fifo_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + int err; + + err = fifo_offload_dump(sch); + if (err) + return err; + + return __fifo_dump(sch, skb); +} + +static int fifo_hd_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + return __fifo_dump(sch, skb); +} + +struct Qdisc_ops pfifo_qdisc_ops __read_mostly = { + .id = "pfifo", + .priv_size = 0, + .enqueue = pfifo_enqueue, + .dequeue = qdisc_dequeue_head, + .peek = qdisc_peek_head, + .init = fifo_init, + .destroy = fifo_destroy, + .reset = qdisc_reset_queue, + .change = fifo_init, + .dump = fifo_dump, + .owner = THIS_MODULE, +}; +EXPORT_SYMBOL(pfifo_qdisc_ops); + +struct Qdisc_ops bfifo_qdisc_ops __read_mostly = { + .id = "bfifo", + .priv_size = 0, + .enqueue = bfifo_enqueue, + .dequeue = qdisc_dequeue_head, + .peek = qdisc_peek_head, + .init = fifo_init, + .destroy = fifo_destroy, + .reset = qdisc_reset_queue, + .change = fifo_init, + .dump = fifo_dump, + .owner = THIS_MODULE, +}; +EXPORT_SYMBOL(bfifo_qdisc_ops); + +struct Qdisc_ops pfifo_head_drop_qdisc_ops __read_mostly = { + .id = "pfifo_head_drop", + .priv_size = 0, + .enqueue = pfifo_tail_enqueue, + .dequeue = qdisc_dequeue_head, + .peek = qdisc_peek_head, + .init = fifo_hd_init, + .reset = qdisc_reset_queue, + .change = fifo_hd_init, + .dump = fifo_hd_dump, + .owner = THIS_MODULE, +}; + +/* Pass size change message down to embedded FIFO */ +int fifo_set_limit(struct Qdisc *q, unsigned int limit) +{ + struct nlattr *nla; + int ret = -ENOMEM; + + /* Hack to avoid sending change message to non-FIFO */ + if (strncmp(q->ops->id + 1, "fifo", 4) != 0) + return 0; + + if (!q->ops->change) + return 0; + + nla = kmalloc(nla_attr_size(sizeof(struct tc_fifo_qopt)), GFP_KERNEL); + if (nla) { + nla->nla_type = RTM_NEWQDISC; + nla->nla_len = nla_attr_size(sizeof(struct tc_fifo_qopt)); + ((struct tc_fifo_qopt *)nla_data(nla))->limit = limit; + + ret = q->ops->change(q, nla, NULL); + kfree(nla); + } + return ret; +} +EXPORT_SYMBOL(fifo_set_limit); + +struct Qdisc *fifo_create_dflt(struct Qdisc *sch, struct Qdisc_ops *ops, + unsigned int limit, + 
struct netlink_ext_ack *extack) +{ + struct Qdisc *q; + int err = -ENOMEM; + + q = qdisc_create_dflt(sch->dev_queue, ops, TC_H_MAKE(sch->handle, 1), + extack); + if (q) { + err = fifo_set_limit(q, limit); + if (err < 0) { + qdisc_put(q); + q = NULL; + } + } + + return q ? : ERR_PTR(err); +} +EXPORT_SYMBOL(fifo_create_dflt); diff --git a/net/sched/sch_fq.c b/net/sched/sch_fq.c new file mode 100644 index 000000000..f59a2cb2c --- /dev/null +++ b/net/sched/sch_fq.c @@ -0,0 +1,1079 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_fq.c Fair Queue Packet Scheduler (per flow pacing) + * + * Copyright (C) 2013-2015 Eric Dumazet + * + * Meant to be mostly used for locally generated traffic : + * Fast classification depends on skb->sk being set before reaching us. + * If not, (router workload), we use rxhash as fallback, with 32 bits wide hash. + * All packets belonging to a socket are considered as a 'flow'. + * + * Flows are dynamically allocated and stored in a hash table of RB trees + * They are also part of one Round Robin 'queues' (new or old flows) + * + * Burst avoidance (aka pacing) capability : + * + * Transport (eg TCP) can set in sk->sk_pacing_rate a rate, enqueue a + * bunch of packets, and this packet scheduler adds delay between + * packets to respect rate limitation. + * + * enqueue() : + * - lookup one RB tree (out of 1024 or more) to find the flow. + * If non existent flow, create it, add it to the tree. + * Add skb to the per flow list of skb (fifo). + * - Use a special fifo for high prio packets + * + * dequeue() : serves flows in Round Robin + * Note : When a flow becomes empty, we do not immediately remove it from + * rb trees, for performance reasons (its expected to send additional packets, + * or SLAB cache will reuse socket for another flow) + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct fq_skb_cb { + u64 time_to_send; +}; + +static inline struct fq_skb_cb *fq_skb_cb(struct sk_buff *skb) +{ + qdisc_cb_private_validate(skb, sizeof(struct fq_skb_cb)); + return (struct fq_skb_cb *)qdisc_skb_cb(skb)->data; +} + +/* + * Per flow structure, dynamically allocated. + * If packets have monotically increasing time_to_send, they are placed in O(1) + * in linear list (head,tail), otherwise are placed in a rbtree (t_root). 
+ */ +struct fq_flow { +/* First cache line : used in fq_gc(), fq_enqueue(), fq_dequeue() */ + struct rb_root t_root; + struct sk_buff *head; /* list of skbs for this flow : first skb */ + union { + struct sk_buff *tail; /* last skb in the list */ + unsigned long age; /* (jiffies | 1UL) when flow was emptied, for gc */ + }; + struct rb_node fq_node; /* anchor in fq_root[] trees */ + struct sock *sk; + u32 socket_hash; /* sk_hash */ + int qlen; /* number of packets in flow queue */ + +/* Second cache line, used in fq_dequeue() */ + int credit; + /* 32bit hole on 64bit arches */ + + struct fq_flow *next; /* next pointer in RR lists */ + + struct rb_node rate_node; /* anchor in q->delayed tree */ + u64 time_next_packet; +} ____cacheline_aligned_in_smp; + +struct fq_flow_head { + struct fq_flow *first; + struct fq_flow *last; +}; + +struct fq_sched_data { + struct fq_flow_head new_flows; + + struct fq_flow_head old_flows; + + struct rb_root delayed; /* for rate limited flows */ + u64 time_next_delayed_flow; + u64 ktime_cache; /* copy of last ktime_get_ns() */ + unsigned long unthrottle_latency_ns; + + struct fq_flow internal; /* for non classified or high prio packets */ + u32 quantum; + u32 initial_quantum; + u32 flow_refill_delay; + u32 flow_plimit; /* max packets per flow */ + unsigned long flow_max_rate; /* optional max rate per flow */ + u64 ce_threshold; + u64 horizon; /* horizon in ns */ + u32 orphan_mask; /* mask for orphaned skb */ + u32 low_rate_threshold; + struct rb_root *fq_root; + u8 rate_enable; + u8 fq_trees_log; + u8 horizon_drop; + u32 flows; + u32 inactive_flows; + u32 throttled_flows; + + u64 stat_gc_flows; + u64 stat_internal_packets; + u64 stat_throttled; + u64 stat_ce_mark; + u64 stat_horizon_drops; + u64 stat_horizon_caps; + u64 stat_flows_plimit; + u64 stat_pkts_too_long; + u64 stat_allocation_errors; + + u32 timer_slack; /* hrtimer slack in ns */ + struct qdisc_watchdog watchdog; +}; + +/* + * f->tail and f->age share the same location. + * We can use the low order bit to differentiate if this location points + * to a sk_buff or contains a jiffies value, if we force this value to be odd. 
+ * This assumes f->tail low order bit must be 0 since alignof(struct sk_buff) >= 2 + */ +static void fq_flow_set_detached(struct fq_flow *f) +{ + f->age = jiffies | 1UL; +} + +static bool fq_flow_is_detached(const struct fq_flow *f) +{ + return !!(f->age & 1UL); +} + +/* special value to mark a throttled flow (not on old/new list) */ +static struct fq_flow throttled; + +static bool fq_flow_is_throttled(const struct fq_flow *f) +{ + return f->next == &throttled; +} + +static void fq_flow_add_tail(struct fq_flow_head *head, struct fq_flow *flow) +{ + if (head->first) + head->last->next = flow; + else + head->first = flow; + head->last = flow; + flow->next = NULL; +} + +static void fq_flow_unset_throttled(struct fq_sched_data *q, struct fq_flow *f) +{ + rb_erase(&f->rate_node, &q->delayed); + q->throttled_flows--; + fq_flow_add_tail(&q->old_flows, f); +} + +static void fq_flow_set_throttled(struct fq_sched_data *q, struct fq_flow *f) +{ + struct rb_node **p = &q->delayed.rb_node, *parent = NULL; + + while (*p) { + struct fq_flow *aux; + + parent = *p; + aux = rb_entry(parent, struct fq_flow, rate_node); + if (f->time_next_packet >= aux->time_next_packet) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&f->rate_node, parent, p); + rb_insert_color(&f->rate_node, &q->delayed); + q->throttled_flows++; + q->stat_throttled++; + + f->next = &throttled; + if (q->time_next_delayed_flow > f->time_next_packet) + q->time_next_delayed_flow = f->time_next_packet; +} + + +static struct kmem_cache *fq_flow_cachep __read_mostly; + + +/* limit number of collected flows per round */ +#define FQ_GC_MAX 8 +#define FQ_GC_AGE (3*HZ) + +static bool fq_gc_candidate(const struct fq_flow *f) +{ + return fq_flow_is_detached(f) && + time_after(jiffies, f->age + FQ_GC_AGE); +} + +static void fq_gc(struct fq_sched_data *q, + struct rb_root *root, + struct sock *sk) +{ + struct rb_node **p, *parent; + void *tofree[FQ_GC_MAX]; + struct fq_flow *f; + int i, fcnt = 0; + + p = &root->rb_node; + parent = NULL; + while (*p) { + parent = *p; + + f = rb_entry(parent, struct fq_flow, fq_node); + if (f->sk == sk) + break; + + if (fq_gc_candidate(f)) { + tofree[fcnt++] = f; + if (fcnt == FQ_GC_MAX) + break; + } + + if (f->sk > sk) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + + if (!fcnt) + return; + + for (i = fcnt; i > 0; ) { + f = tofree[--i]; + rb_erase(&f->fq_node, root); + } + q->flows -= fcnt; + q->inactive_flows -= fcnt; + q->stat_gc_flows += fcnt; + + kmem_cache_free_bulk(fq_flow_cachep, fcnt, tofree); +} + +static struct fq_flow *fq_classify(struct sk_buff *skb, struct fq_sched_data *q) +{ + struct rb_node **p, *parent; + struct sock *sk = skb->sk; + struct rb_root *root; + struct fq_flow *f; + + /* warning: no starvation prevention... 
*/ + if (unlikely((skb->priority & TC_PRIO_MAX) == TC_PRIO_CONTROL)) + return &q->internal; + + /* SYNACK messages are attached to a TCP_NEW_SYN_RECV request socket + * or a listener (SYNCOOKIE mode) + * 1) request sockets are not full blown, + * they do not contain sk_pacing_rate + * 2) They are not part of a 'flow' yet + * 3) We do not want to rate limit them (eg SYNFLOOD attack), + * especially if the listener set SO_MAX_PACING_RATE + * 4) We pretend they are orphaned + */ + if (!sk || sk_listener(sk)) { + unsigned long hash = skb_get_hash(skb) & q->orphan_mask; + + /* By forcing low order bit to 1, we make sure to not + * collide with a local flow (socket pointers are word aligned) + */ + sk = (struct sock *)((hash << 1) | 1UL); + skb_orphan(skb); + } else if (sk->sk_state == TCP_CLOSE) { + unsigned long hash = skb_get_hash(skb) & q->orphan_mask; + /* + * Sockets in TCP_CLOSE are non connected. + * Typical use case is UDP sockets, they can send packets + * with sendto() to many different destinations. + * We probably could use a generic bit advertising + * non connected sockets, instead of sk_state == TCP_CLOSE, + * if we care enough. + */ + sk = (struct sock *)((hash << 1) | 1UL); + } + + root = &q->fq_root[hash_ptr(sk, q->fq_trees_log)]; + + if (q->flows >= (2U << q->fq_trees_log) && + q->inactive_flows > q->flows/2) + fq_gc(q, root, sk); + + p = &root->rb_node; + parent = NULL; + while (*p) { + parent = *p; + + f = rb_entry(parent, struct fq_flow, fq_node); + if (f->sk == sk) { + /* socket might have been reallocated, so check + * if its sk_hash is the same. + * It not, we need to refill credit with + * initial quantum + */ + if (unlikely(skb->sk == sk && + f->socket_hash != sk->sk_hash)) { + f->credit = q->initial_quantum; + f->socket_hash = sk->sk_hash; + if (q->rate_enable) + smp_store_release(&sk->sk_pacing_status, + SK_PACING_FQ); + if (fq_flow_is_throttled(f)) + fq_flow_unset_throttled(q, f); + f->time_next_packet = 0ULL; + } + return f; + } + if (f->sk > sk) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + + f = kmem_cache_zalloc(fq_flow_cachep, GFP_ATOMIC | __GFP_NOWARN); + if (unlikely(!f)) { + q->stat_allocation_errors++; + return &q->internal; + } + /* f->t_root is already zeroed after kmem_cache_zalloc() */ + + fq_flow_set_detached(f); + f->sk = sk; + if (skb->sk == sk) { + f->socket_hash = sk->sk_hash; + if (q->rate_enable) + smp_store_release(&sk->sk_pacing_status, + SK_PACING_FQ); + } + f->credit = q->initial_quantum; + + rb_link_node(&f->fq_node, parent, p); + rb_insert_color(&f->fq_node, root); + + q->flows++; + q->inactive_flows++; + return f; +} + +static struct sk_buff *fq_peek(struct fq_flow *flow) +{ + struct sk_buff *skb = skb_rb_first(&flow->t_root); + struct sk_buff *head = flow->head; + + if (!skb) + return head; + + if (!head) + return skb; + + if (fq_skb_cb(skb)->time_to_send < fq_skb_cb(head)->time_to_send) + return skb; + return head; +} + +static void fq_erase_head(struct Qdisc *sch, struct fq_flow *flow, + struct sk_buff *skb) +{ + if (skb == flow->head) { + flow->head = skb->next; + } else { + rb_erase(&skb->rbnode, &flow->t_root); + skb->dev = qdisc_dev(sch); + } +} + +/* Remove one skb from flow queue. + * This skb must be the return value of prior fq_peek(). 
+ */ +static void fq_dequeue_skb(struct Qdisc *sch, struct fq_flow *flow, + struct sk_buff *skb) +{ + fq_erase_head(sch, flow, skb); + skb_mark_not_on_list(skb); + flow->qlen--; + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; +} + +static void flow_queue_add(struct fq_flow *flow, struct sk_buff *skb) +{ + struct rb_node **p, *parent; + struct sk_buff *head, *aux; + + head = flow->head; + if (!head || + fq_skb_cb(skb)->time_to_send >= fq_skb_cb(flow->tail)->time_to_send) { + if (!head) + flow->head = skb; + else + flow->tail->next = skb; + flow->tail = skb; + skb->next = NULL; + return; + } + + p = &flow->t_root.rb_node; + parent = NULL; + + while (*p) { + parent = *p; + aux = rb_to_skb(parent); + if (fq_skb_cb(skb)->time_to_send >= fq_skb_cb(aux)->time_to_send) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&skb->rbnode, parent, p); + rb_insert_color(&skb->rbnode, &flow->t_root); +} + +static bool fq_packet_beyond_horizon(const struct sk_buff *skb, + const struct fq_sched_data *q) +{ + return unlikely((s64)skb->tstamp > (s64)(q->ktime_cache + q->horizon)); +} + +static int fq_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct fq_flow *f; + + if (unlikely(sch->q.qlen >= sch->limit)) + return qdisc_drop(skb, sch, to_free); + + if (!skb->tstamp) { + fq_skb_cb(skb)->time_to_send = q->ktime_cache = ktime_get_ns(); + } else { + /* Check if packet timestamp is too far in the future. + * Try first if our cached value, to avoid ktime_get_ns() + * cost in most cases. + */ + if (fq_packet_beyond_horizon(skb, q)) { + /* Refresh our cache and check another time */ + q->ktime_cache = ktime_get_ns(); + if (fq_packet_beyond_horizon(skb, q)) { + if (q->horizon_drop) { + q->stat_horizon_drops++; + return qdisc_drop(skb, sch, to_free); + } + q->stat_horizon_caps++; + skb->tstamp = q->ktime_cache + q->horizon; + } + } + fq_skb_cb(skb)->time_to_send = skb->tstamp; + } + + f = fq_classify(skb, q); + if (unlikely(f->qlen >= q->flow_plimit && f != &q->internal)) { + q->stat_flows_plimit++; + return qdisc_drop(skb, sch, to_free); + } + + f->qlen++; + qdisc_qstats_backlog_inc(sch, skb); + if (fq_flow_is_detached(f)) { + fq_flow_add_tail(&q->new_flows, f); + if (time_after(jiffies, f->age + q->flow_refill_delay)) + f->credit = max_t(u32, f->credit, q->quantum); + q->inactive_flows--; + } + + /* Note: this overwrites f->age */ + flow_queue_add(f, skb); + + if (unlikely(f == &q->internal)) { + q->stat_internal_packets++; + } + sch->q.qlen++; + + return NET_XMIT_SUCCESS; +} + +static void fq_check_throttled(struct fq_sched_data *q, u64 now) +{ + unsigned long sample; + struct rb_node *p; + + if (q->time_next_delayed_flow > now) + return; + + /* Update unthrottle latency EWMA. + * This is cheap and can help diagnosing timer/latency problems. 
+ */ + sample = (unsigned long)(now - q->time_next_delayed_flow); + q->unthrottle_latency_ns -= q->unthrottle_latency_ns >> 3; + q->unthrottle_latency_ns += sample >> 3; + + q->time_next_delayed_flow = ~0ULL; + while ((p = rb_first(&q->delayed)) != NULL) { + struct fq_flow *f = rb_entry(p, struct fq_flow, rate_node); + + if (f->time_next_packet > now) { + q->time_next_delayed_flow = f->time_next_packet; + break; + } + fq_flow_unset_throttled(q, f); + } +} + +static struct sk_buff *fq_dequeue(struct Qdisc *sch) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct fq_flow_head *head; + struct sk_buff *skb; + struct fq_flow *f; + unsigned long rate; + u32 plen; + u64 now; + + if (!sch->q.qlen) + return NULL; + + skb = fq_peek(&q->internal); + if (unlikely(skb)) { + fq_dequeue_skb(sch, &q->internal, skb); + goto out; + } + + q->ktime_cache = now = ktime_get_ns(); + fq_check_throttled(q, now); +begin: + head = &q->new_flows; + if (!head->first) { + head = &q->old_flows; + if (!head->first) { + if (q->time_next_delayed_flow != ~0ULL) + qdisc_watchdog_schedule_range_ns(&q->watchdog, + q->time_next_delayed_flow, + q->timer_slack); + return NULL; + } + } + f = head->first; + + if (f->credit <= 0) { + f->credit += q->quantum; + head->first = f->next; + fq_flow_add_tail(&q->old_flows, f); + goto begin; + } + + skb = fq_peek(f); + if (skb) { + u64 time_next_packet = max_t(u64, fq_skb_cb(skb)->time_to_send, + f->time_next_packet); + + if (now < time_next_packet) { + head->first = f->next; + f->time_next_packet = time_next_packet; + fq_flow_set_throttled(q, f); + goto begin; + } + prefetch(&skb->end); + if ((s64)(now - time_next_packet - q->ce_threshold) > 0) { + INET_ECN_set_ce(skb); + q->stat_ce_mark++; + } + fq_dequeue_skb(sch, f, skb); + } else { + head->first = f->next; + /* force a pass through old_flows to prevent starvation */ + if ((head == &q->new_flows) && q->old_flows.first) { + fq_flow_add_tail(&q->old_flows, f); + } else { + fq_flow_set_detached(f); + q->inactive_flows++; + } + goto begin; + } + plen = qdisc_pkt_len(skb); + f->credit -= plen; + + if (!q->rate_enable) + goto out; + + rate = q->flow_max_rate; + + /* If EDT time was provided for this skb, we need to + * update f->time_next_packet only if this qdisc enforces + * a flow max rate. + */ + if (!skb->tstamp) { + if (skb->sk) + rate = min(skb->sk->sk_pacing_rate, rate); + + if (rate <= q->low_rate_threshold) { + f->credit = 0; + } else { + plen = max(plen, q->quantum); + if (f->credit > 0) + goto out; + } + } + if (rate != ~0UL) { + u64 len = (u64)plen * NSEC_PER_SEC; + + if (likely(rate)) + len = div64_ul(len, rate); + /* Since socket rate can change later, + * clamp the delay to 1 second. + * Really, providers of too big packets should be fixed ! + */ + if (unlikely(len > NSEC_PER_SEC)) { + len = NSEC_PER_SEC; + q->stat_pkts_too_long++; + } + /* Account for schedule/timers drifts. + * f->time_next_packet was set when prior packet was sent, + * and current time (@now) can be too late by tens of us. 
+ */ + if (f->time_next_packet) + len -= min(len/2, now - f->time_next_packet); + f->time_next_packet = now + len; + } +out: + qdisc_bstats_update(sch, skb); + return skb; +} + +static void fq_flow_purge(struct fq_flow *flow) +{ + struct rb_node *p = rb_first(&flow->t_root); + + while (p) { + struct sk_buff *skb = rb_to_skb(p); + + p = rb_next(p); + rb_erase(&skb->rbnode, &flow->t_root); + rtnl_kfree_skbs(skb, skb); + } + rtnl_kfree_skbs(flow->head, flow->tail); + flow->head = NULL; + flow->qlen = 0; +} + +static void fq_reset(struct Qdisc *sch) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct rb_root *root; + struct rb_node *p; + struct fq_flow *f; + unsigned int idx; + + sch->q.qlen = 0; + sch->qstats.backlog = 0; + + fq_flow_purge(&q->internal); + + if (!q->fq_root) + return; + + for (idx = 0; idx < (1U << q->fq_trees_log); idx++) { + root = &q->fq_root[idx]; + while ((p = rb_first(root)) != NULL) { + f = rb_entry(p, struct fq_flow, fq_node); + rb_erase(p, root); + + fq_flow_purge(f); + + kmem_cache_free(fq_flow_cachep, f); + } + } + q->new_flows.first = NULL; + q->old_flows.first = NULL; + q->delayed = RB_ROOT; + q->flows = 0; + q->inactive_flows = 0; + q->throttled_flows = 0; +} + +static void fq_rehash(struct fq_sched_data *q, + struct rb_root *old_array, u32 old_log, + struct rb_root *new_array, u32 new_log) +{ + struct rb_node *op, **np, *parent; + struct rb_root *oroot, *nroot; + struct fq_flow *of, *nf; + int fcnt = 0; + u32 idx; + + for (idx = 0; idx < (1U << old_log); idx++) { + oroot = &old_array[idx]; + while ((op = rb_first(oroot)) != NULL) { + rb_erase(op, oroot); + of = rb_entry(op, struct fq_flow, fq_node); + if (fq_gc_candidate(of)) { + fcnt++; + kmem_cache_free(fq_flow_cachep, of); + continue; + } + nroot = &new_array[hash_ptr(of->sk, new_log)]; + + np = &nroot->rb_node; + parent = NULL; + while (*np) { + parent = *np; + + nf = rb_entry(parent, struct fq_flow, fq_node); + BUG_ON(nf->sk == of->sk); + + if (nf->sk > of->sk) + np = &parent->rb_right; + else + np = &parent->rb_left; + } + + rb_link_node(&of->fq_node, parent, np); + rb_insert_color(&of->fq_node, nroot); + } + } + q->flows -= fcnt; + q->inactive_flows -= fcnt; + q->stat_gc_flows += fcnt; +} + +static void fq_free(void *addr) +{ + kvfree(addr); +} + +static int fq_resize(struct Qdisc *sch, u32 log) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct rb_root *array; + void *old_fq_root; + u32 idx; + + if (q->fq_root && log == q->fq_trees_log) + return 0; + + /* If XPS was setup, we can allocate memory on right NUMA node */ + array = kvmalloc_node(sizeof(struct rb_root) << log, GFP_KERNEL | __GFP_RETRY_MAYFAIL, + netdev_queue_numa_node_read(sch->dev_queue)); + if (!array) + return -ENOMEM; + + for (idx = 0; idx < (1U << log); idx++) + array[idx] = RB_ROOT; + + sch_tree_lock(sch); + + old_fq_root = q->fq_root; + if (old_fq_root) + fq_rehash(q, old_fq_root, q->fq_trees_log, array, log); + + q->fq_root = array; + q->fq_trees_log = log; + + sch_tree_unlock(sch); + + fq_free(old_fq_root); + + return 0; +} + +static struct netlink_range_validation iq_range = { + .max = INT_MAX, +}; + +static const struct nla_policy fq_policy[TCA_FQ_MAX + 1] = { + [TCA_FQ_UNSPEC] = { .strict_start_type = TCA_FQ_TIMER_SLACK }, + + [TCA_FQ_PLIMIT] = { .type = NLA_U32 }, + [TCA_FQ_FLOW_PLIMIT] = { .type = NLA_U32 }, + [TCA_FQ_QUANTUM] = { .type = NLA_U32 }, + [TCA_FQ_INITIAL_QUANTUM] = NLA_POLICY_FULL_RANGE(NLA_U32, &iq_range), + [TCA_FQ_RATE_ENABLE] = { .type = NLA_U32 }, + [TCA_FQ_FLOW_DEFAULT_RATE] = { .type = NLA_U32 }, + 
[TCA_FQ_FLOW_MAX_RATE] = { .type = NLA_U32 }, + [TCA_FQ_BUCKETS_LOG] = { .type = NLA_U32 }, + [TCA_FQ_FLOW_REFILL_DELAY] = { .type = NLA_U32 }, + [TCA_FQ_ORPHAN_MASK] = { .type = NLA_U32 }, + [TCA_FQ_LOW_RATE_THRESHOLD] = { .type = NLA_U32 }, + [TCA_FQ_CE_THRESHOLD] = { .type = NLA_U32 }, + [TCA_FQ_TIMER_SLACK] = { .type = NLA_U32 }, + [TCA_FQ_HORIZON] = { .type = NLA_U32 }, + [TCA_FQ_HORIZON_DROP] = { .type = NLA_U8 }, +}; + +static int fq_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_FQ_MAX + 1]; + int err, drop_count = 0; + unsigned drop_len = 0; + u32 fq_log; + + err = nla_parse_nested_deprecated(tb, TCA_FQ_MAX, opt, fq_policy, + NULL); + if (err < 0) + return err; + + sch_tree_lock(sch); + + fq_log = q->fq_trees_log; + + if (tb[TCA_FQ_BUCKETS_LOG]) { + u32 nval = nla_get_u32(tb[TCA_FQ_BUCKETS_LOG]); + + if (nval >= 1 && nval <= ilog2(256*1024)) + fq_log = nval; + else + err = -EINVAL; + } + if (tb[TCA_FQ_PLIMIT]) + sch->limit = nla_get_u32(tb[TCA_FQ_PLIMIT]); + + if (tb[TCA_FQ_FLOW_PLIMIT]) + q->flow_plimit = nla_get_u32(tb[TCA_FQ_FLOW_PLIMIT]); + + if (tb[TCA_FQ_QUANTUM]) { + u32 quantum = nla_get_u32(tb[TCA_FQ_QUANTUM]); + + if (quantum > 0 && quantum <= (1 << 20)) { + q->quantum = quantum; + } else { + NL_SET_ERR_MSG_MOD(extack, "invalid quantum"); + err = -EINVAL; + } + } + + if (tb[TCA_FQ_INITIAL_QUANTUM]) + q->initial_quantum = nla_get_u32(tb[TCA_FQ_INITIAL_QUANTUM]); + + if (tb[TCA_FQ_FLOW_DEFAULT_RATE]) + pr_warn_ratelimited("sch_fq: defrate %u ignored.\n", + nla_get_u32(tb[TCA_FQ_FLOW_DEFAULT_RATE])); + + if (tb[TCA_FQ_FLOW_MAX_RATE]) { + u32 rate = nla_get_u32(tb[TCA_FQ_FLOW_MAX_RATE]); + + q->flow_max_rate = (rate == ~0U) ? ~0UL : rate; + } + if (tb[TCA_FQ_LOW_RATE_THRESHOLD]) + q->low_rate_threshold = + nla_get_u32(tb[TCA_FQ_LOW_RATE_THRESHOLD]); + + if (tb[TCA_FQ_RATE_ENABLE]) { + u32 enable = nla_get_u32(tb[TCA_FQ_RATE_ENABLE]); + + if (enable <= 1) + q->rate_enable = enable; + else + err = -EINVAL; + } + + if (tb[TCA_FQ_FLOW_REFILL_DELAY]) { + u32 usecs_delay = nla_get_u32(tb[TCA_FQ_FLOW_REFILL_DELAY]) ; + + q->flow_refill_delay = usecs_to_jiffies(usecs_delay); + } + + if (tb[TCA_FQ_ORPHAN_MASK]) + q->orphan_mask = nla_get_u32(tb[TCA_FQ_ORPHAN_MASK]); + + if (tb[TCA_FQ_CE_THRESHOLD]) + q->ce_threshold = (u64)NSEC_PER_USEC * + nla_get_u32(tb[TCA_FQ_CE_THRESHOLD]); + + if (tb[TCA_FQ_TIMER_SLACK]) + q->timer_slack = nla_get_u32(tb[TCA_FQ_TIMER_SLACK]); + + if (tb[TCA_FQ_HORIZON]) + q->horizon = (u64)NSEC_PER_USEC * + nla_get_u32(tb[TCA_FQ_HORIZON]); + + if (tb[TCA_FQ_HORIZON_DROP]) + q->horizon_drop = nla_get_u8(tb[TCA_FQ_HORIZON_DROP]); + + if (!err) { + + sch_tree_unlock(sch); + err = fq_resize(sch, fq_log); + sch_tree_lock(sch); + } + while (sch->q.qlen > sch->limit) { + struct sk_buff *skb = fq_dequeue(sch); + + if (!skb) + break; + drop_len += qdisc_pkt_len(skb); + rtnl_kfree_skbs(skb, skb); + drop_count++; + } + qdisc_tree_reduce_backlog(sch, drop_count, drop_len); + + sch_tree_unlock(sch); + return err; +} + +static void fq_destroy(struct Qdisc *sch) +{ + struct fq_sched_data *q = qdisc_priv(sch); + + fq_reset(sch); + fq_free(q->fq_root); + qdisc_watchdog_cancel(&q->watchdog); +} + +static int fq_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_sched_data *q = qdisc_priv(sch); + int err; + + sch->limit = 10000; + q->flow_plimit = 100; + q->quantum = 2 * psched_mtu(qdisc_dev(sch)); + q->initial_quantum = 10 * 
psched_mtu(qdisc_dev(sch)); + q->flow_refill_delay = msecs_to_jiffies(40); + q->flow_max_rate = ~0UL; + q->time_next_delayed_flow = ~0ULL; + q->rate_enable = 1; + q->new_flows.first = NULL; + q->old_flows.first = NULL; + q->delayed = RB_ROOT; + q->fq_root = NULL; + q->fq_trees_log = ilog2(1024); + q->orphan_mask = 1024 - 1; + q->low_rate_threshold = 550000 / 8; + + q->timer_slack = 10 * NSEC_PER_USEC; /* 10 usec of hrtimer slack */ + + q->horizon = 10ULL * NSEC_PER_SEC; /* 10 seconds */ + q->horizon_drop = 1; /* by default, drop packets beyond horizon */ + + /* Default ce_threshold of 4294 seconds */ + q->ce_threshold = (u64)NSEC_PER_USEC * ~0U; + + qdisc_watchdog_init_clockid(&q->watchdog, sch, CLOCK_MONOTONIC); + + if (opt) + err = fq_change(sch, opt, extack); + else + err = fq_resize(sch, q->fq_trees_log); + + return err; +} + +static int fq_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct fq_sched_data *q = qdisc_priv(sch); + u64 ce_threshold = q->ce_threshold; + u64 horizon = q->horizon; + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + + /* TCA_FQ_FLOW_DEFAULT_RATE is not used anymore */ + + do_div(ce_threshold, NSEC_PER_USEC); + do_div(horizon, NSEC_PER_USEC); + + if (nla_put_u32(skb, TCA_FQ_PLIMIT, sch->limit) || + nla_put_u32(skb, TCA_FQ_FLOW_PLIMIT, q->flow_plimit) || + nla_put_u32(skb, TCA_FQ_QUANTUM, q->quantum) || + nla_put_u32(skb, TCA_FQ_INITIAL_QUANTUM, q->initial_quantum) || + nla_put_u32(skb, TCA_FQ_RATE_ENABLE, q->rate_enable) || + nla_put_u32(skb, TCA_FQ_FLOW_MAX_RATE, + min_t(unsigned long, q->flow_max_rate, ~0U)) || + nla_put_u32(skb, TCA_FQ_FLOW_REFILL_DELAY, + jiffies_to_usecs(q->flow_refill_delay)) || + nla_put_u32(skb, TCA_FQ_ORPHAN_MASK, q->orphan_mask) || + nla_put_u32(skb, TCA_FQ_LOW_RATE_THRESHOLD, + q->low_rate_threshold) || + nla_put_u32(skb, TCA_FQ_CE_THRESHOLD, (u32)ce_threshold) || + nla_put_u32(skb, TCA_FQ_BUCKETS_LOG, q->fq_trees_log) || + nla_put_u32(skb, TCA_FQ_TIMER_SLACK, q->timer_slack) || + nla_put_u32(skb, TCA_FQ_HORIZON, (u32)horizon) || + nla_put_u8(skb, TCA_FQ_HORIZON_DROP, q->horizon_drop)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + return -1; +} + +static int fq_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct fq_sched_data *q = qdisc_priv(sch); + struct tc_fq_qd_stats st; + + sch_tree_lock(sch); + + st.gc_flows = q->stat_gc_flows; + st.highprio_packets = q->stat_internal_packets; + st.tcp_retrans = 0; + st.throttled = q->stat_throttled; + st.flows_plimit = q->stat_flows_plimit; + st.pkts_too_long = q->stat_pkts_too_long; + st.allocation_errors = q->stat_allocation_errors; + st.time_next_delayed_flow = q->time_next_delayed_flow + q->timer_slack - + ktime_get_ns(); + st.flows = q->flows; + st.inactive_flows = q->inactive_flows; + st.throttled_flows = q->throttled_flows; + st.unthrottle_latency_ns = min_t(unsigned long, + q->unthrottle_latency_ns, ~0U); + st.ce_mark = q->stat_ce_mark; + st.horizon_drops = q->stat_horizon_drops; + st.horizon_caps = q->stat_horizon_caps; + sch_tree_unlock(sch); + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static struct Qdisc_ops fq_qdisc_ops __read_mostly = { + .id = "fq", + .priv_size = sizeof(struct fq_sched_data), + + .enqueue = fq_enqueue, + .dequeue = fq_dequeue, + .peek = qdisc_peek_dequeued, + .init = fq_init, + .reset = fq_reset, + .destroy = fq_destroy, + .change = fq_change, + .dump = fq_dump, + .dump_stats = fq_dump_stats, + .owner = THIS_MODULE, +}; + 
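[Editorial aside on the pacing arithmetic used by fq_dequeue() above: each flow keeps time_next_packet, the earliest time it may transmit again. The gap added per packet is len * NSEC_PER_SEC / rate, clamped to one second, and up to half of that gap is credited back when the previous deadline has already passed, which absorbs hrtimer slack and scheduling drift. The standalone sketch below restates that rule outside the kernel; the function name, the stdint types and the explicit zero-rate guard are illustrative additions, not part of this patch.]

/* Editorial sketch, not part of the upstream file. */
#include <stdint.h>

#define EX_NSEC_PER_SEC 1000000000ULL

/* Earliest time (ns) a paced flow may send its next packet.
 * prev_deadline is the flow's previous time_next_packet (0 if none),
 * plen the packet length in bytes, rate_bps the pacing rate in bytes/sec.
 */
uint64_t ex_fq_next_tx_time(uint64_t now, uint64_t prev_deadline,
			    uint32_t plen, uint64_t rate_bps)
{
	uint64_t gap;

	if (!rate_bps)			/* guard for the standalone sketch */
		return now;

	gap = (uint64_t)plen * EX_NSEC_PER_SEC / rate_bps;
	if (gap > EX_NSEC_PER_SEC)	/* never delay a packet by more than 1s */
		gap = EX_NSEC_PER_SEC;

	/* If the previous deadline already passed, give back up to half of
	 * the new gap to compensate for timer/scheduler drift, mirroring
	 * "len -= min(len/2, now - f->time_next_packet)" above.
	 */
	if (prev_deadline && now > prev_deadline) {
		uint64_t late = now - prev_deadline;

		gap -= (late < gap / 2) ? late : gap / 2;
	}
	return now + gap;
}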
+static int __init fq_module_init(void) +{ + int ret; + + fq_flow_cachep = kmem_cache_create("fq_flow_cache", + sizeof(struct fq_flow), + 0, 0, NULL); + if (!fq_flow_cachep) + return -ENOMEM; + + ret = register_qdisc(&fq_qdisc_ops); + if (ret) + kmem_cache_destroy(fq_flow_cachep); + return ret; +} + +static void __exit fq_module_exit(void) +{ + unregister_qdisc(&fq_qdisc_ops); + kmem_cache_destroy(fq_flow_cachep); +} + +module_init(fq_module_init) +module_exit(fq_module_exit) +MODULE_AUTHOR("Eric Dumazet"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Fair Queue Packet Scheduler"); diff --git a/net/sched/sch_fq_codel.c b/net/sched/sch_fq_codel.c new file mode 100644 index 000000000..8c4fee063 --- /dev/null +++ b/net/sched/sch_fq_codel.c @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Fair Queue CoDel discipline + * + * Copyright (C) 2012,2015 Eric Dumazet + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Fair Queue CoDel. + * + * Principles : + * Packets are classified (internal classifier or external) on flows. + * This is a Stochastic model (as we use a hash, several flows + * might be hashed on same slot) + * Each flow has a CoDel managed queue. + * Flows are linked onto two (Round Robin) lists, + * so that new flows have priority on old ones. + * + * For a given flow, packets are not reordered (CoDel uses a FIFO) + * head drops only. + * ECN capability is on by default. + * Low memory footprint (64 bytes per flow) + */ + +struct fq_codel_flow { + struct sk_buff *head; + struct sk_buff *tail; + struct list_head flowchain; + int deficit; + struct codel_vars cvars; +}; /* please try to keep this structure <= 64 bytes */ + +struct fq_codel_sched_data { + struct tcf_proto __rcu *filter_list; /* optional external classifier */ + struct tcf_block *block; + struct fq_codel_flow *flows; /* Flows table [flows_cnt] */ + u32 *backlogs; /* backlog table [flows_cnt] */ + u32 flows_cnt; /* number of flows */ + u32 quantum; /* psched_mtu(qdisc_dev(sch)); */ + u32 drop_batch_size; + u32 memory_limit; + struct codel_params cparams; + struct codel_stats cstats; + u32 memory_usage; + u32 drop_overmemory; + u32 drop_overlimit; + u32 new_flow_count; + + struct list_head new_flows; /* list of new flows */ + struct list_head old_flows; /* list of old flows */ +}; + +static unsigned int fq_codel_hash(const struct fq_codel_sched_data *q, + struct sk_buff *skb) +{ + return reciprocal_scale(skb_get_hash(skb), q->flows_cnt); +} + +static unsigned int fq_codel_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct tcf_proto *filter; + struct tcf_result res; + int result; + + if (TC_H_MAJ(skb->priority) == sch->handle && + TC_H_MIN(skb->priority) > 0 && + TC_H_MIN(skb->priority) <= q->flows_cnt) + return TC_H_MIN(skb->priority); + + filter = rcu_dereference_bh(q->filter_list); + if (!filter) + return fq_codel_hash(q, skb) + 1; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + result = tcf_classify(skb, NULL, filter, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return 0; + } +#endif + if (TC_H_MIN(res.classid) <= q->flows_cnt) + return TC_H_MIN(res.classid); + } + return 0; +} + +/* helper functions : might be 
changed when/if skb use a standard list_head */ + +/* remove one skb from head of slot queue */ +static inline struct sk_buff *dequeue_head(struct fq_codel_flow *flow) +{ + struct sk_buff *skb = flow->head; + + flow->head = skb->next; + skb_mark_not_on_list(skb); + return skb; +} + +/* add skb to flow queue (tail add) */ +static inline void flow_queue_add(struct fq_codel_flow *flow, + struct sk_buff *skb) +{ + if (flow->head == NULL) + flow->head = skb; + else + flow->tail->next = skb; + flow->tail = skb; + skb->next = NULL; +} + +static unsigned int fq_codel_drop(struct Qdisc *sch, unsigned int max_packets, + struct sk_buff **to_free) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + unsigned int maxbacklog = 0, idx = 0, i, len; + struct fq_codel_flow *flow; + unsigned int threshold; + unsigned int mem = 0; + + /* Queue is full! Find the fat flow and drop packet(s) from it. + * This might sound expensive, but with 1024 flows, we scan + * 4KB of memory, and we dont need to handle a complex tree + * in fast path (packet queue/enqueue) with many cache misses. + * In stress mode, we'll try to drop 64 packets from the flow, + * amortizing this linear lookup to one cache line per drop. + */ + for (i = 0; i < q->flows_cnt; i++) { + if (q->backlogs[i] > maxbacklog) { + maxbacklog = q->backlogs[i]; + idx = i; + } + } + + /* Our goal is to drop half of this fat flow backlog */ + threshold = maxbacklog >> 1; + + flow = &q->flows[idx]; + len = 0; + i = 0; + do { + skb = dequeue_head(flow); + len += qdisc_pkt_len(skb); + mem += get_codel_cb(skb)->mem_usage; + __qdisc_drop(skb, to_free); + } while (++i < max_packets && len < threshold); + + /* Tell codel to increase its signal strength also */ + flow->cvars.count += i; + q->backlogs[idx] -= len; + q->memory_usage -= mem; + sch->qstats.drops += i; + sch->qstats.backlog -= len; + sch->q.qlen -= i; + return idx; +} + +static int fq_codel_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + unsigned int idx, prev_backlog, prev_qlen; + struct fq_codel_flow *flow; + int ret; + unsigned int pkt_len; + bool memory_limited; + + idx = fq_codel_classify(skb, sch, &ret); + if (idx == 0) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } + idx--; + + codel_set_enqueue_time(skb); + flow = &q->flows[idx]; + flow_queue_add(flow, skb); + q->backlogs[idx] += qdisc_pkt_len(skb); + qdisc_qstats_backlog_inc(sch, skb); + + if (list_empty(&flow->flowchain)) { + list_add_tail(&flow->flowchain, &q->new_flows); + q->new_flow_count++; + flow->deficit = q->quantum; + } + get_codel_cb(skb)->mem_usage = skb->truesize; + q->memory_usage += get_codel_cb(skb)->mem_usage; + memory_limited = q->memory_usage > q->memory_limit; + if (++sch->q.qlen <= sch->limit && !memory_limited) + return NET_XMIT_SUCCESS; + + prev_backlog = sch->qstats.backlog; + prev_qlen = sch->q.qlen; + + /* save this packet length as it might be dropped by fq_codel_drop() */ + pkt_len = qdisc_pkt_len(skb); + /* fq_codel_drop() is quite expensive, as it performs a linear search + * in q->backlogs[] to find a fat flow. + * So instead of dropping a single packet, drop half of its backlog + * with a 64 packets limit to not add a too big cpu spike here. 
+ */ + ret = fq_codel_drop(sch, q->drop_batch_size, to_free); + + prev_qlen -= sch->q.qlen; + prev_backlog -= sch->qstats.backlog; + q->drop_overlimit += prev_qlen; + if (memory_limited) + q->drop_overmemory += prev_qlen; + + /* As we dropped packet(s), better let upper stack know this. + * If we dropped a packet for this flow, return NET_XMIT_CN, + * but in this case, our parents wont increase their backlogs. + */ + if (ret == idx) { + qdisc_tree_reduce_backlog(sch, prev_qlen - 1, + prev_backlog - pkt_len); + return NET_XMIT_CN; + } + qdisc_tree_reduce_backlog(sch, prev_qlen, prev_backlog); + return NET_XMIT_SUCCESS; +} + +/* This is the specific function called from codel_dequeue() + * to dequeue a packet from queue. Note: backlog is handled in + * codel, we dont need to reduce it here. + */ +static struct sk_buff *dequeue_func(struct codel_vars *vars, void *ctx) +{ + struct Qdisc *sch = ctx; + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct fq_codel_flow *flow; + struct sk_buff *skb = NULL; + + flow = container_of(vars, struct fq_codel_flow, cvars); + if (flow->head) { + skb = dequeue_head(flow); + q->backlogs[flow - q->flows] -= qdisc_pkt_len(skb); + q->memory_usage -= get_codel_cb(skb)->mem_usage; + sch->q.qlen--; + sch->qstats.backlog -= qdisc_pkt_len(skb); + } + return skb; +} + +static void drop_func(struct sk_buff *skb, void *ctx) +{ + struct Qdisc *sch = ctx; + + kfree_skb(skb); + qdisc_qstats_drop(sch); +} + +static struct sk_buff *fq_codel_dequeue(struct Qdisc *sch) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + struct fq_codel_flow *flow; + struct list_head *head; + +begin: + head = &q->new_flows; + if (list_empty(head)) { + head = &q->old_flows; + if (list_empty(head)) + return NULL; + } + flow = list_first_entry(head, struct fq_codel_flow, flowchain); + + if (flow->deficit <= 0) { + flow->deficit += q->quantum; + list_move_tail(&flow->flowchain, &q->old_flows); + goto begin; + } + + skb = codel_dequeue(sch, &sch->qstats.backlog, &q->cparams, + &flow->cvars, &q->cstats, qdisc_pkt_len, + codel_get_enqueue_time, drop_func, dequeue_func); + + if (!skb) { + /* force a pass through old_flows to prevent starvation */ + if ((head == &q->new_flows) && !list_empty(&q->old_flows)) + list_move_tail(&flow->flowchain, &q->old_flows); + else + list_del_init(&flow->flowchain); + goto begin; + } + qdisc_bstats_update(sch, skb); + flow->deficit -= qdisc_pkt_len(skb); + /* We cant call qdisc_tree_reduce_backlog() if our qlen is 0, + * or HTB crashes. Defer it for next round. 
+ */ + if (q->cstats.drop_count && sch->q.qlen) { + qdisc_tree_reduce_backlog(sch, q->cstats.drop_count, + q->cstats.drop_len); + q->cstats.drop_count = 0; + q->cstats.drop_len = 0; + } + return skb; +} + +static void fq_codel_flow_purge(struct fq_codel_flow *flow) +{ + rtnl_kfree_skbs(flow->head, flow->tail); + flow->head = NULL; +} + +static void fq_codel_reset(struct Qdisc *sch) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + int i; + + INIT_LIST_HEAD(&q->new_flows); + INIT_LIST_HEAD(&q->old_flows); + for (i = 0; i < q->flows_cnt; i++) { + struct fq_codel_flow *flow = q->flows + i; + + fq_codel_flow_purge(flow); + INIT_LIST_HEAD(&flow->flowchain); + codel_vars_init(&flow->cvars); + } + memset(q->backlogs, 0, q->flows_cnt * sizeof(u32)); + q->memory_usage = 0; +} + +static const struct nla_policy fq_codel_policy[TCA_FQ_CODEL_MAX + 1] = { + [TCA_FQ_CODEL_TARGET] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_LIMIT] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_INTERVAL] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_ECN] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_FLOWS] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_QUANTUM] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_CE_THRESHOLD] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_DROP_BATCH_SIZE] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_MEMORY_LIMIT] = { .type = NLA_U32 }, + [TCA_FQ_CODEL_CE_THRESHOLD_SELECTOR] = { .type = NLA_U8 }, + [TCA_FQ_CODEL_CE_THRESHOLD_MASK] = { .type = NLA_U8 }, +}; + +static int fq_codel_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_FQ_CODEL_MAX + 1]; + u32 quantum = 0; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_FQ_CODEL_MAX, opt, + fq_codel_policy, NULL); + if (err < 0) + return err; + if (tb[TCA_FQ_CODEL_FLOWS]) { + if (q->flows) + return -EINVAL; + q->flows_cnt = nla_get_u32(tb[TCA_FQ_CODEL_FLOWS]); + if (!q->flows_cnt || + q->flows_cnt > 65536) + return -EINVAL; + } + if (tb[TCA_FQ_CODEL_QUANTUM]) { + quantum = max(256U, nla_get_u32(tb[TCA_FQ_CODEL_QUANTUM])); + if (quantum > FQ_CODEL_QUANTUM_MAX) { + NL_SET_ERR_MSG(extack, "Invalid quantum"); + return -EINVAL; + } + } + sch_tree_lock(sch); + + if (tb[TCA_FQ_CODEL_TARGET]) { + u64 target = nla_get_u32(tb[TCA_FQ_CODEL_TARGET]); + + q->cparams.target = (target * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_FQ_CODEL_CE_THRESHOLD]) { + u64 val = nla_get_u32(tb[TCA_FQ_CODEL_CE_THRESHOLD]); + + q->cparams.ce_threshold = (val * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_FQ_CODEL_CE_THRESHOLD_SELECTOR]) + q->cparams.ce_threshold_selector = nla_get_u8(tb[TCA_FQ_CODEL_CE_THRESHOLD_SELECTOR]); + if (tb[TCA_FQ_CODEL_CE_THRESHOLD_MASK]) + q->cparams.ce_threshold_mask = nla_get_u8(tb[TCA_FQ_CODEL_CE_THRESHOLD_MASK]); + + if (tb[TCA_FQ_CODEL_INTERVAL]) { + u64 interval = nla_get_u32(tb[TCA_FQ_CODEL_INTERVAL]); + + q->cparams.interval = (interval * NSEC_PER_USEC) >> CODEL_SHIFT; + } + + if (tb[TCA_FQ_CODEL_LIMIT]) + sch->limit = nla_get_u32(tb[TCA_FQ_CODEL_LIMIT]); + + if (tb[TCA_FQ_CODEL_ECN]) + q->cparams.ecn = !!nla_get_u32(tb[TCA_FQ_CODEL_ECN]); + + if (quantum) + q->quantum = quantum; + + if (tb[TCA_FQ_CODEL_DROP_BATCH_SIZE]) + q->drop_batch_size = max(1U, nla_get_u32(tb[TCA_FQ_CODEL_DROP_BATCH_SIZE])); + + if (tb[TCA_FQ_CODEL_MEMORY_LIMIT]) + q->memory_limit = min(1U << 31, nla_get_u32(tb[TCA_FQ_CODEL_MEMORY_LIMIT])); + + while (sch->q.qlen > sch->limit || + q->memory_usage > q->memory_limit) { + struct sk_buff *skb = fq_codel_dequeue(sch); + + q->cstats.drop_len += 
qdisc_pkt_len(skb); + rtnl_kfree_skbs(skb, skb); + q->cstats.drop_count++; + } + qdisc_tree_reduce_backlog(sch, q->cstats.drop_count, q->cstats.drop_len); + q->cstats.drop_count = 0; + q->cstats.drop_len = 0; + + sch_tree_unlock(sch); + return 0; +} + +static void fq_codel_destroy(struct Qdisc *sch) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + kvfree(q->backlogs); + kvfree(q->flows); +} + +static int fq_codel_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + int i; + int err; + + sch->limit = 10*1024; + q->flows_cnt = 1024; + q->memory_limit = 32 << 20; /* 32 MBytes */ + q->drop_batch_size = 64; + q->quantum = psched_mtu(qdisc_dev(sch)); + INIT_LIST_HEAD(&q->new_flows); + INIT_LIST_HEAD(&q->old_flows); + codel_params_init(&q->cparams); + codel_stats_init(&q->cstats); + q->cparams.ecn = true; + q->cparams.mtu = psched_mtu(qdisc_dev(sch)); + + if (opt) { + err = fq_codel_change(sch, opt, extack); + if (err) + goto init_failure; + } + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + goto init_failure; + + if (!q->flows) { + q->flows = kvcalloc(q->flows_cnt, + sizeof(struct fq_codel_flow), + GFP_KERNEL); + if (!q->flows) { + err = -ENOMEM; + goto init_failure; + } + q->backlogs = kvcalloc(q->flows_cnt, sizeof(u32), GFP_KERNEL); + if (!q->backlogs) { + err = -ENOMEM; + goto alloc_failure; + } + for (i = 0; i < q->flows_cnt; i++) { + struct fq_codel_flow *flow = q->flows + i; + + INIT_LIST_HEAD(&flow->flowchain); + codel_vars_init(&flow->cvars); + } + } + if (sch->limit >= 1) + sch->flags |= TCQ_F_CAN_BYPASS; + else + sch->flags &= ~TCQ_F_CAN_BYPASS; + return 0; + +alloc_failure: + kvfree(q->flows); + q->flows = NULL; +init_failure: + q->flows_cnt = 0; + return err; +} + +static int fq_codel_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_FQ_CODEL_TARGET, + codel_time_to_us(q->cparams.target)) || + nla_put_u32(skb, TCA_FQ_CODEL_LIMIT, + sch->limit) || + nla_put_u32(skb, TCA_FQ_CODEL_INTERVAL, + codel_time_to_us(q->cparams.interval)) || + nla_put_u32(skb, TCA_FQ_CODEL_ECN, + q->cparams.ecn) || + nla_put_u32(skb, TCA_FQ_CODEL_QUANTUM, + q->quantum) || + nla_put_u32(skb, TCA_FQ_CODEL_DROP_BATCH_SIZE, + q->drop_batch_size) || + nla_put_u32(skb, TCA_FQ_CODEL_MEMORY_LIMIT, + q->memory_limit) || + nla_put_u32(skb, TCA_FQ_CODEL_FLOWS, + q->flows_cnt)) + goto nla_put_failure; + + if (q->cparams.ce_threshold != CODEL_DISABLED_THRESHOLD) { + if (nla_put_u32(skb, TCA_FQ_CODEL_CE_THRESHOLD, + codel_time_to_us(q->cparams.ce_threshold))) + goto nla_put_failure; + if (nla_put_u8(skb, TCA_FQ_CODEL_CE_THRESHOLD_SELECTOR, q->cparams.ce_threshold_selector)) + goto nla_put_failure; + if (nla_put_u8(skb, TCA_FQ_CODEL_CE_THRESHOLD_MASK, q->cparams.ce_threshold_mask)) + goto nla_put_failure; + } + + return nla_nest_end(skb, opts); + +nla_put_failure: + return -1; +} + +static int fq_codel_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + struct tc_fq_codel_xstats st = { + .type = TCA_FQ_CODEL_XSTATS_QDISC, + }; + struct list_head *pos; + + st.qdisc_stats.maxpacket = q->cstats.maxpacket; + st.qdisc_stats.drop_overlimit = q->drop_overlimit; + st.qdisc_stats.ecn_mark = q->cstats.ecn_mark; + st.qdisc_stats.new_flow_count = 
q->new_flow_count; + st.qdisc_stats.ce_mark = q->cstats.ce_mark; + st.qdisc_stats.memory_usage = q->memory_usage; + st.qdisc_stats.drop_overmemory = q->drop_overmemory; + + sch_tree_lock(sch); + list_for_each(pos, &q->new_flows) + st.qdisc_stats.new_flows_len++; + + list_for_each(pos, &q->old_flows) + st.qdisc_stats.old_flows_len++; + sch_tree_unlock(sch); + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static struct Qdisc *fq_codel_leaf(struct Qdisc *sch, unsigned long arg) +{ + return NULL; +} + +static unsigned long fq_codel_find(struct Qdisc *sch, u32 classid) +{ + return 0; +} + +static unsigned long fq_codel_bind(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return 0; +} + +static void fq_codel_unbind(struct Qdisc *q, unsigned long cl) +{ +} + +static struct tcf_block *fq_codel_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static int fq_codel_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + tcm->tcm_handle |= TC_H_MIN(cl); + return 0; +} + +static int fq_codel_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + u32 idx = cl - 1; + struct gnet_stats_queue qs = { 0 }; + struct tc_fq_codel_xstats xstats; + + if (idx < q->flows_cnt) { + const struct fq_codel_flow *flow = &q->flows[idx]; + const struct sk_buff *skb; + + memset(&xstats, 0, sizeof(xstats)); + xstats.type = TCA_FQ_CODEL_XSTATS_CLASS; + xstats.class_stats.deficit = flow->deficit; + xstats.class_stats.ldelay = + codel_time_to_us(flow->cvars.ldelay); + xstats.class_stats.count = flow->cvars.count; + xstats.class_stats.lastcount = flow->cvars.lastcount; + xstats.class_stats.dropping = flow->cvars.dropping; + if (flow->cvars.dropping) { + codel_tdiff_t delta = flow->cvars.drop_next - + codel_get_time(); + + xstats.class_stats.drop_next = (delta >= 0) ? 
+ codel_time_to_us(delta) : + -codel_time_to_us(-delta); + } + if (flow->head) { + sch_tree_lock(sch); + skb = flow->head; + while (skb) { + qs.qlen++; + skb = skb->next; + } + sch_tree_unlock(sch); + } + qs.backlog = q->backlogs[idx]; + qs.drops = 0; + } + if (gnet_stats_copy_queue(d, NULL, &qs, qs.qlen) < 0) + return -1; + if (idx < q->flows_cnt) + return gnet_stats_copy_app(d, &xstats, sizeof(xstats)); + return 0; +} + +static void fq_codel_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct fq_codel_sched_data *q = qdisc_priv(sch); + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->flows_cnt; i++) { + if (list_empty(&q->flows[i].flowchain)) { + arg->count++; + continue; + } + if (!tc_qdisc_stats_dump(sch, i + 1, arg)) + break; + } +} + +static const struct Qdisc_class_ops fq_codel_class_ops = { + .leaf = fq_codel_leaf, + .find = fq_codel_find, + .tcf_block = fq_codel_tcf_block, + .bind_tcf = fq_codel_bind, + .unbind_tcf = fq_codel_unbind, + .dump = fq_codel_dump_class, + .dump_stats = fq_codel_dump_class_stats, + .walk = fq_codel_walk, +}; + +static struct Qdisc_ops fq_codel_qdisc_ops __read_mostly = { + .cl_ops = &fq_codel_class_ops, + .id = "fq_codel", + .priv_size = sizeof(struct fq_codel_sched_data), + .enqueue = fq_codel_enqueue, + .dequeue = fq_codel_dequeue, + .peek = qdisc_peek_dequeued, + .init = fq_codel_init, + .reset = fq_codel_reset, + .destroy = fq_codel_destroy, + .change = fq_codel_change, + .dump = fq_codel_dump, + .dump_stats = fq_codel_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init fq_codel_module_init(void) +{ + return register_qdisc(&fq_codel_qdisc_ops); +} + +static void __exit fq_codel_module_exit(void) +{ + unregister_qdisc(&fq_codel_qdisc_ops); +} + +module_init(fq_codel_module_init) +module_exit(fq_codel_module_exit) +MODULE_AUTHOR("Eric Dumazet"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Fair Queue CoDel discipline"); diff --git a/net/sched/sch_fq_pie.c b/net/sched/sch_fq_pie.c new file mode 100644 index 000000000..68e6acd0f --- /dev/null +++ b/net/sched/sch_fq_pie.c @@ -0,0 +1,582 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Flow Queue PIE discipline + * + * Copyright (C) 2019 Mohit P. Tahiliani + * Copyright (C) 2019 Sachin D. Patil + * Copyright (C) 2019 V. Saicharan + * Copyright (C) 2019 Mohit Bhasi + * Copyright (C) 2019 Leslie Monis + * Copyright (C) 2019 Gautam Ramakrishnan + */ + +#include +#include +#include +#include +#include + +/* Flow Queue PIE + * + * Principles: + * - Packets are classified on flows. + * - This is a Stochastic model (as we use a hash, several flows might + * be hashed to the same slot) + * - Each flow has a PIE managed queue. + * - Flows are linked onto two (Round Robin) lists, + * so that new flows have priority on old ones. + * - For a given flow, packets are not reordered. + * - Drops during enqueue only. + * - ECN capability is off by default. + * - ECN threshold (if ECN is enabled) is at 10% by default. + * - Uses timestamps to calculate queue delay by default. 
+ */ + +/** + * struct fq_pie_flow - contains data for each flow + * @vars: pie vars associated with the flow + * @deficit: number of remaining byte credits + * @backlog: size of data in the flow + * @qlen: number of packets in the flow + * @flowchain: flowchain for the flow + * @head: first packet in the flow + * @tail: last packet in the flow + */ +struct fq_pie_flow { + struct pie_vars vars; + s32 deficit; + u32 backlog; + u32 qlen; + struct list_head flowchain; + struct sk_buff *head; + struct sk_buff *tail; +}; + +struct fq_pie_sched_data { + struct tcf_proto __rcu *filter_list; /* optional external classifier */ + struct tcf_block *block; + struct fq_pie_flow *flows; + struct Qdisc *sch; + struct list_head old_flows; + struct list_head new_flows; + struct pie_params p_params; + u32 ecn_prob; + u32 flows_cnt; + u32 flows_cursor; + u32 quantum; + u32 memory_limit; + u32 new_flow_count; + u32 memory_usage; + u32 overmemory; + struct pie_stats stats; + struct timer_list adapt_timer; +}; + +static unsigned int fq_pie_hash(const struct fq_pie_sched_data *q, + struct sk_buff *skb) +{ + return reciprocal_scale(skb_get_hash(skb), q->flows_cnt); +} + +static unsigned int fq_pie_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct tcf_proto *filter; + struct tcf_result res; + int result; + + if (TC_H_MAJ(skb->priority) == sch->handle && + TC_H_MIN(skb->priority) > 0 && + TC_H_MIN(skb->priority) <= q->flows_cnt) + return TC_H_MIN(skb->priority); + + filter = rcu_dereference_bh(q->filter_list); + if (!filter) + return fq_pie_hash(q, skb) + 1; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + result = tcf_classify(skb, NULL, filter, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return 0; + } +#endif + if (TC_H_MIN(res.classid) <= q->flows_cnt) + return TC_H_MIN(res.classid); + } + return 0; +} + +/* add skb to flow queue (tail add) */ +static inline void flow_queue_add(struct fq_pie_flow *flow, + struct sk_buff *skb) +{ + if (!flow->head) + flow->head = skb; + else + flow->tail->next = skb; + flow->tail = skb; + skb->next = NULL; +} + +static int fq_pie_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct fq_pie_flow *sel_flow; + int ret; + u8 memory_limited = false; + u8 enqueue = false; + u32 pkt_len; + u32 idx; + + /* Classifies packet into corresponding flow */ + idx = fq_pie_classify(skb, sch, &ret); + if (idx == 0) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } + idx--; + + sel_flow = &q->flows[idx]; + /* Checks whether adding a new packet would exceed memory limit */ + get_pie_cb(skb)->mem_usage = skb->truesize; + memory_limited = q->memory_usage > q->memory_limit + skb->truesize; + + /* Checks if the qdisc is full */ + if (unlikely(qdisc_qlen(sch) >= sch->limit)) { + q->stats.overlimit++; + goto out; + } else if (unlikely(memory_limited)) { + q->overmemory++; + } + + if (!pie_drop_early(sch, &q->p_params, &sel_flow->vars, + sel_flow->backlog, skb->len)) { + enqueue = true; + } else if (q->p_params.ecn && + sel_flow->vars.prob <= (MAX_PROB / 100) * q->ecn_prob && + INET_ECN_set_ce(skb)) { + /* If packet is ecn capable, mark it if drop probability + * is lower than the parameter ecn_prob, 
else drop it. + */ + q->stats.ecn_mark++; + enqueue = true; + } + if (enqueue) { + /* Set enqueue time only when dq_rate_estimator is disabled. */ + if (!q->p_params.dq_rate_estimator) + pie_set_enqueue_time(skb); + + pkt_len = qdisc_pkt_len(skb); + q->stats.packets_in++; + q->memory_usage += skb->truesize; + sch->qstats.backlog += pkt_len; + sch->q.qlen++; + flow_queue_add(sel_flow, skb); + if (list_empty(&sel_flow->flowchain)) { + list_add_tail(&sel_flow->flowchain, &q->new_flows); + q->new_flow_count++; + sel_flow->deficit = q->quantum; + sel_flow->qlen = 0; + sel_flow->backlog = 0; + } + sel_flow->qlen++; + sel_flow->backlog += pkt_len; + return NET_XMIT_SUCCESS; + } +out: + q->stats.dropped++; + sel_flow->vars.accu_prob = 0; + __qdisc_drop(skb, to_free); + qdisc_qstats_drop(sch); + return NET_XMIT_CN; +} + +static struct netlink_range_validation fq_pie_q_range = { + .min = 1, + .max = 1 << 20, +}; + +static const struct nla_policy fq_pie_policy[TCA_FQ_PIE_MAX + 1] = { + [TCA_FQ_PIE_LIMIT] = {.type = NLA_U32}, + [TCA_FQ_PIE_FLOWS] = {.type = NLA_U32}, + [TCA_FQ_PIE_TARGET] = {.type = NLA_U32}, + [TCA_FQ_PIE_TUPDATE] = {.type = NLA_U32}, + [TCA_FQ_PIE_ALPHA] = {.type = NLA_U32}, + [TCA_FQ_PIE_BETA] = {.type = NLA_U32}, + [TCA_FQ_PIE_QUANTUM] = + NLA_POLICY_FULL_RANGE(NLA_U32, &fq_pie_q_range), + [TCA_FQ_PIE_MEMORY_LIMIT] = {.type = NLA_U32}, + [TCA_FQ_PIE_ECN_PROB] = {.type = NLA_U32}, + [TCA_FQ_PIE_ECN] = {.type = NLA_U32}, + [TCA_FQ_PIE_BYTEMODE] = {.type = NLA_U32}, + [TCA_FQ_PIE_DQ_RATE_ESTIMATOR] = {.type = NLA_U32}, +}; + +static inline struct sk_buff *dequeue_head(struct fq_pie_flow *flow) +{ + struct sk_buff *skb = flow->head; + + flow->head = skb->next; + skb->next = NULL; + return skb; +} + +static struct sk_buff *fq_pie_qdisc_dequeue(struct Qdisc *sch) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb = NULL; + struct fq_pie_flow *flow; + struct list_head *head; + u32 pkt_len; + +begin: + head = &q->new_flows; + if (list_empty(head)) { + head = &q->old_flows; + if (list_empty(head)) + return NULL; + } + + flow = list_first_entry(head, struct fq_pie_flow, flowchain); + /* Flow has exhausted all its credits */ + if (flow->deficit <= 0) { + flow->deficit += q->quantum; + list_move_tail(&flow->flowchain, &q->old_flows); + goto begin; + } + + if (flow->head) { + skb = dequeue_head(flow); + pkt_len = qdisc_pkt_len(skb); + sch->qstats.backlog -= pkt_len; + sch->q.qlen--; + qdisc_bstats_update(sch, skb); + } + + if (!skb) { + /* force a pass through old_flows to prevent starvation */ + if (head == &q->new_flows && !list_empty(&q->old_flows)) + list_move_tail(&flow->flowchain, &q->old_flows); + else + list_del_init(&flow->flowchain); + goto begin; + } + + flow->qlen--; + flow->deficit -= pkt_len; + flow->backlog -= pkt_len; + q->memory_usage -= get_pie_cb(skb)->mem_usage; + pie_process_dequeue(skb, &q->p_params, &flow->vars, flow->backlog); + return skb; +} + +static int fq_pie_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_FQ_PIE_MAX + 1]; + unsigned int len_dropped = 0; + unsigned int num_dropped = 0; + int err; + + err = nla_parse_nested(tb, TCA_FQ_PIE_MAX, opt, fq_pie_policy, extack); + if (err < 0) + return err; + + sch_tree_lock(sch); + if (tb[TCA_FQ_PIE_LIMIT]) { + u32 limit = nla_get_u32(tb[TCA_FQ_PIE_LIMIT]); + + q->p_params.limit = limit; + sch->limit = limit; + } + if (tb[TCA_FQ_PIE_FLOWS]) { + if (q->flows) { + NL_SET_ERR_MSG_MOD(extack, + 
"Number of flows cannot be changed"); + goto flow_error; + } + q->flows_cnt = nla_get_u32(tb[TCA_FQ_PIE_FLOWS]); + if (!q->flows_cnt || q->flows_cnt > 65536) { + NL_SET_ERR_MSG_MOD(extack, + "Number of flows must range in [1..65536]"); + goto flow_error; + } + } + + /* convert from microseconds to pschedtime */ + if (tb[TCA_FQ_PIE_TARGET]) { + /* target is in us */ + u32 target = nla_get_u32(tb[TCA_FQ_PIE_TARGET]); + + /* convert to pschedtime */ + q->p_params.target = + PSCHED_NS2TICKS((u64)target * NSEC_PER_USEC); + } + + /* tupdate is in jiffies */ + if (tb[TCA_FQ_PIE_TUPDATE]) + q->p_params.tupdate = + usecs_to_jiffies(nla_get_u32(tb[TCA_FQ_PIE_TUPDATE])); + + if (tb[TCA_FQ_PIE_ALPHA]) + q->p_params.alpha = nla_get_u32(tb[TCA_FQ_PIE_ALPHA]); + + if (tb[TCA_FQ_PIE_BETA]) + q->p_params.beta = nla_get_u32(tb[TCA_FQ_PIE_BETA]); + + if (tb[TCA_FQ_PIE_QUANTUM]) + q->quantum = nla_get_u32(tb[TCA_FQ_PIE_QUANTUM]); + + if (tb[TCA_FQ_PIE_MEMORY_LIMIT]) + q->memory_limit = nla_get_u32(tb[TCA_FQ_PIE_MEMORY_LIMIT]); + + if (tb[TCA_FQ_PIE_ECN_PROB]) + q->ecn_prob = nla_get_u32(tb[TCA_FQ_PIE_ECN_PROB]); + + if (tb[TCA_FQ_PIE_ECN]) + q->p_params.ecn = nla_get_u32(tb[TCA_FQ_PIE_ECN]); + + if (tb[TCA_FQ_PIE_BYTEMODE]) + q->p_params.bytemode = nla_get_u32(tb[TCA_FQ_PIE_BYTEMODE]); + + if (tb[TCA_FQ_PIE_DQ_RATE_ESTIMATOR]) + q->p_params.dq_rate_estimator = + nla_get_u32(tb[TCA_FQ_PIE_DQ_RATE_ESTIMATOR]); + + /* Drop excess packets if new limit is lower */ + while (sch->q.qlen > sch->limit) { + struct sk_buff *skb = fq_pie_qdisc_dequeue(sch); + + len_dropped += qdisc_pkt_len(skb); + num_dropped += 1; + rtnl_kfree_skbs(skb, skb); + } + qdisc_tree_reduce_backlog(sch, num_dropped, len_dropped); + + sch_tree_unlock(sch); + return 0; + +flow_error: + sch_tree_unlock(sch); + return -EINVAL; +} + +static void fq_pie_timer(struct timer_list *t) +{ + struct fq_pie_sched_data *q = from_timer(q, t, adapt_timer); + unsigned long next, tupdate; + struct Qdisc *sch = q->sch; + spinlock_t *root_lock; /* to lock qdisc for probability calculations */ + int max_cnt, i; + + rcu_read_lock(); + root_lock = qdisc_lock(qdisc_root_sleeping(sch)); + spin_lock(root_lock); + + /* Limit this expensive loop to 2048 flows per round. 
*/ + max_cnt = min_t(int, q->flows_cnt - q->flows_cursor, 2048); + for (i = 0; i < max_cnt; i++) { + pie_calculate_probability(&q->p_params, + &q->flows[q->flows_cursor].vars, + q->flows[q->flows_cursor].backlog); + q->flows_cursor++; + } + + tupdate = q->p_params.tupdate; + next = 0; + if (q->flows_cursor >= q->flows_cnt) { + q->flows_cursor = 0; + next = tupdate; + } + if (tupdate) + mod_timer(&q->adapt_timer, jiffies + next); + spin_unlock(root_lock); + rcu_read_unlock(); +} + +static int fq_pie_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + int err; + u32 idx; + + pie_params_init(&q->p_params); + sch->limit = 10 * 1024; + q->p_params.limit = sch->limit; + q->quantum = psched_mtu(qdisc_dev(sch)); + q->sch = sch; + q->ecn_prob = 10; + q->flows_cnt = 1024; + q->memory_limit = SZ_32M; + + INIT_LIST_HEAD(&q->new_flows); + INIT_LIST_HEAD(&q->old_flows); + timer_setup(&q->adapt_timer, fq_pie_timer, 0); + + if (opt) { + err = fq_pie_change(sch, opt, extack); + + if (err) + return err; + } + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + goto init_failure; + + q->flows = kvcalloc(q->flows_cnt, sizeof(struct fq_pie_flow), + GFP_KERNEL); + if (!q->flows) { + err = -ENOMEM; + goto init_failure; + } + for (idx = 0; idx < q->flows_cnt; idx++) { + struct fq_pie_flow *flow = q->flows + idx; + + INIT_LIST_HEAD(&flow->flowchain); + pie_vars_init(&flow->vars); + } + + mod_timer(&q->adapt_timer, jiffies + HZ / 2); + + return 0; + +init_failure: + q->flows_cnt = 0; + + return err; +} + +static int fq_pie_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start(skb, TCA_OPTIONS); + if (!opts) + return -EMSGSIZE; + + /* convert target from pschedtime to us */ + if (nla_put_u32(skb, TCA_FQ_PIE_LIMIT, sch->limit) || + nla_put_u32(skb, TCA_FQ_PIE_FLOWS, q->flows_cnt) || + nla_put_u32(skb, TCA_FQ_PIE_TARGET, + ((u32)PSCHED_TICKS2NS(q->p_params.target)) / + NSEC_PER_USEC) || + nla_put_u32(skb, TCA_FQ_PIE_TUPDATE, + jiffies_to_usecs(q->p_params.tupdate)) || + nla_put_u32(skb, TCA_FQ_PIE_ALPHA, q->p_params.alpha) || + nla_put_u32(skb, TCA_FQ_PIE_BETA, q->p_params.beta) || + nla_put_u32(skb, TCA_FQ_PIE_QUANTUM, q->quantum) || + nla_put_u32(skb, TCA_FQ_PIE_MEMORY_LIMIT, q->memory_limit) || + nla_put_u32(skb, TCA_FQ_PIE_ECN_PROB, q->ecn_prob) || + nla_put_u32(skb, TCA_FQ_PIE_ECN, q->p_params.ecn) || + nla_put_u32(skb, TCA_FQ_PIE_BYTEMODE, q->p_params.bytemode) || + nla_put_u32(skb, TCA_FQ_PIE_DQ_RATE_ESTIMATOR, + q->p_params.dq_rate_estimator)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static int fq_pie_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + struct tc_fq_pie_xstats st = { + .packets_in = q->stats.packets_in, + .overlimit = q->stats.overlimit, + .overmemory = q->overmemory, + .dropped = q->stats.dropped, + .ecn_mark = q->stats.ecn_mark, + .new_flow_count = q->new_flow_count, + .memory_usage = q->memory_usage, + }; + struct list_head *pos; + + sch_tree_lock(sch); + list_for_each(pos, &q->new_flows) + st.new_flows_len++; + + list_for_each(pos, &q->old_flows) + st.old_flows_len++; + sch_tree_unlock(sch); + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static void fq_pie_reset(struct Qdisc *sch) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + u32 
idx; + + INIT_LIST_HEAD(&q->new_flows); + INIT_LIST_HEAD(&q->old_flows); + for (idx = 0; idx < q->flows_cnt; idx++) { + struct fq_pie_flow *flow = q->flows + idx; + + /* Removes all packets from flow */ + rtnl_kfree_skbs(flow->head, flow->tail); + flow->head = NULL; + + INIT_LIST_HEAD(&flow->flowchain); + pie_vars_init(&flow->vars); + } +} + +static void fq_pie_destroy(struct Qdisc *sch) +{ + struct fq_pie_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + q->p_params.tupdate = 0; + del_timer_sync(&q->adapt_timer); + kvfree(q->flows); +} + +static struct Qdisc_ops fq_pie_qdisc_ops __read_mostly = { + .id = "fq_pie", + .priv_size = sizeof(struct fq_pie_sched_data), + .enqueue = fq_pie_qdisc_enqueue, + .dequeue = fq_pie_qdisc_dequeue, + .peek = qdisc_peek_dequeued, + .init = fq_pie_init, + .destroy = fq_pie_destroy, + .reset = fq_pie_reset, + .change = fq_pie_change, + .dump = fq_pie_dump, + .dump_stats = fq_pie_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init fq_pie_module_init(void) +{ + return register_qdisc(&fq_pie_qdisc_ops); +} + +static void __exit fq_pie_module_exit(void) +{ + unregister_qdisc(&fq_pie_qdisc_ops); +} + +module_init(fq_pie_module_init); +module_exit(fq_pie_module_exit); + +MODULE_DESCRIPTION("Flow Queue Proportional Integral controller Enhanced (FQ-PIE)"); +MODULE_AUTHOR("Mohit P. Tahiliani"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_frag.c b/net/sched/sch_frag.c new file mode 100644 index 000000000..a9bd0a235 --- /dev/null +++ b/net/sched/sch_frag.c @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: GPL-2.0 OR Linux-OpenIB +#include +#include +#include +#include +#include +#include +#include + +struct sch_frag_data { + unsigned long dst; + struct qdisc_skb_cb cb; + __be16 inner_protocol; + u16 vlan_tci; + __be16 vlan_proto; + unsigned int l2_len; + u8 l2_data[VLAN_ETH_HLEN]; + int (*xmit)(struct sk_buff *skb); +}; + +static DEFINE_PER_CPU(struct sch_frag_data, sch_frag_data_storage); + +static int sch_frag_xmit(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct sch_frag_data *data = this_cpu_ptr(&sch_frag_data_storage); + + if (skb_cow_head(skb, data->l2_len) < 0) { + kfree_skb(skb); + return -ENOMEM; + } + + __skb_dst_copy(skb, data->dst); + *qdisc_skb_cb(skb) = data->cb; + skb->inner_protocol = data->inner_protocol; + if (data->vlan_tci & VLAN_CFI_MASK) + __vlan_hwaccel_put_tag(skb, data->vlan_proto, + data->vlan_tci & ~VLAN_CFI_MASK); + else + __vlan_hwaccel_clear_tag(skb); + + /* Reconstruct the MAC header. 
*/ + skb_push(skb, data->l2_len); + memcpy(skb->data, &data->l2_data, data->l2_len); + skb_postpush_rcsum(skb, skb->data, data->l2_len); + skb_reset_mac_header(skb); + + return data->xmit(skb); +} + +static void sch_frag_prepare_frag(struct sk_buff *skb, + int (*xmit)(struct sk_buff *skb)) +{ + unsigned int hlen = skb_network_offset(skb); + struct sch_frag_data *data; + + data = this_cpu_ptr(&sch_frag_data_storage); + data->dst = skb->_skb_refdst; + data->cb = *qdisc_skb_cb(skb); + data->xmit = xmit; + data->inner_protocol = skb->inner_protocol; + if (skb_vlan_tag_present(skb)) + data->vlan_tci = skb_vlan_tag_get(skb) | VLAN_CFI_MASK; + else + data->vlan_tci = 0; + data->vlan_proto = skb->vlan_proto; + data->l2_len = hlen; + memcpy(&data->l2_data, skb->data, hlen); + + memset(IPCB(skb), 0, sizeof(struct inet_skb_parm)); + skb_pull(skb, hlen); +} + +static unsigned int +sch_frag_dst_get_mtu(const struct dst_entry *dst) +{ + return dst->dev->mtu; +} + +static struct dst_ops sch_frag_dst_ops = { + .family = AF_UNSPEC, + .mtu = sch_frag_dst_get_mtu, +}; + +static int sch_fragment(struct net *net, struct sk_buff *skb, + u16 mru, int (*xmit)(struct sk_buff *skb)) +{ + int ret = -1; + + if (skb_network_offset(skb) > VLAN_ETH_HLEN) { + net_warn_ratelimited("L2 header too long to fragment\n"); + goto err; + } + + if (skb_protocol(skb, true) == htons(ETH_P_IP)) { + struct rtable sch_frag_rt = { 0 }; + unsigned long orig_dst; + + sch_frag_prepare_frag(skb, xmit); + dst_init(&sch_frag_rt.dst, &sch_frag_dst_ops, NULL, 1, + DST_OBSOLETE_NONE, DST_NOCOUNT); + sch_frag_rt.dst.dev = skb->dev; + + orig_dst = skb->_skb_refdst; + skb_dst_set_noref(skb, &sch_frag_rt.dst); + IPCB(skb)->frag_max_size = mru; + + ret = ip_do_fragment(net, skb->sk, skb, sch_frag_xmit); + refdst_drop(orig_dst); + } else if (skb_protocol(skb, true) == htons(ETH_P_IPV6)) { + unsigned long orig_dst; + struct rt6_info sch_frag_rt; + + sch_frag_prepare_frag(skb, xmit); + memset(&sch_frag_rt, 0, sizeof(sch_frag_rt)); + dst_init(&sch_frag_rt.dst, &sch_frag_dst_ops, NULL, 1, + DST_OBSOLETE_NONE, DST_NOCOUNT); + sch_frag_rt.dst.dev = skb->dev; + + orig_dst = skb->_skb_refdst; + skb_dst_set_noref(skb, &sch_frag_rt.dst); + IP6CB(skb)->frag_max_size = mru; + + ret = ipv6_stub->ipv6_fragment(net, skb->sk, skb, + sch_frag_xmit); + refdst_drop(orig_dst); + } else { + net_warn_ratelimited("Fail frag %s: eth=%x, MRU=%d, MTU=%d\n", + netdev_name(skb->dev), + ntohs(skb_protocol(skb, true)), mru, + skb->dev->mtu); + goto err; + } + + return ret; +err: + kfree_skb(skb); + return ret; +} + +int sch_frag_xmit_hook(struct sk_buff *skb, int (*xmit)(struct sk_buff *skb)) +{ + u16 mru = tc_skb_cb(skb)->mru; + int err; + + if (mru && skb->len > mru + skb->dev->hard_header_len) + err = sch_fragment(dev_net(skb->dev), skb, mru, xmit); + else + err = xmit(skb); + + return err; +} +EXPORT_SYMBOL_GPL(sch_frag_xmit_hook); diff --git a/net/sched/sch_generic.c b/net/sched/sch_generic.c new file mode 100644 index 000000000..a5693e25b --- /dev/null +++ b/net/sched/sch_generic.c @@ -0,0 +1,1604 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_generic.c Generic packet scheduler routines. 
+ * + * Authors: Alexey Kuznetsov, + * Jamal Hadi Salim, 990601 + * - Ingress support + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Qdisc to use by default */ +const struct Qdisc_ops *default_qdisc_ops = &pfifo_fast_ops; +EXPORT_SYMBOL(default_qdisc_ops); + +static void qdisc_maybe_clear_missed(struct Qdisc *q, + const struct netdev_queue *txq) +{ + clear_bit(__QDISC_STATE_MISSED, &q->state); + + /* Make sure the below netif_xmit_frozen_or_stopped() + * checking happens after clearing STATE_MISSED. + */ + smp_mb__after_atomic(); + + /* Checking netif_xmit_frozen_or_stopped() again to + * make sure STATE_MISSED is set if the STATE_MISSED + * set by netif_tx_wake_queue()'s rescheduling of + * net_tx_action() is cleared by the above clear_bit(). + */ + if (!netif_xmit_frozen_or_stopped(txq)) + set_bit(__QDISC_STATE_MISSED, &q->state); + else + set_bit(__QDISC_STATE_DRAINING, &q->state); +} + +/* Main transmission queue. */ + +/* Modifications to data participating in scheduling must be protected with + * qdisc_lock(qdisc) spinlock. + * + * The idea is the following: + * - enqueue, dequeue are serialized via qdisc root lock + * - ingress filtering is also serialized via qdisc root lock + * - updates to tree and tree walking are only done under the rtnl mutex. + */ + +#define SKB_XOFF_MAGIC ((struct sk_buff *)1UL) + +static inline struct sk_buff *__skb_dequeue_bad_txq(struct Qdisc *q) +{ + const struct netdev_queue *txq = q->dev_queue; + spinlock_t *lock = NULL; + struct sk_buff *skb; + + if (q->flags & TCQ_F_NOLOCK) { + lock = qdisc_lock(q); + spin_lock(lock); + } + + skb = skb_peek(&q->skb_bad_txq); + if (skb) { + /* check the reason of requeuing without tx lock first */ + txq = skb_get_tx_queue(txq->dev, skb); + if (!netif_xmit_frozen_or_stopped(txq)) { + skb = __skb_dequeue(&q->skb_bad_txq); + if (qdisc_is_percpu_stats(q)) { + qdisc_qstats_cpu_backlog_dec(q, skb); + qdisc_qstats_cpu_qlen_dec(q); + } else { + qdisc_qstats_backlog_dec(q, skb); + q->q.qlen--; + } + } else { + skb = SKB_XOFF_MAGIC; + qdisc_maybe_clear_missed(q, txq); + } + } + + if (lock) + spin_unlock(lock); + + return skb; +} + +static inline struct sk_buff *qdisc_dequeue_skb_bad_txq(struct Qdisc *q) +{ + struct sk_buff *skb = skb_peek(&q->skb_bad_txq); + + if (unlikely(skb)) + skb = __skb_dequeue_bad_txq(q); + + return skb; +} + +static inline void qdisc_enqueue_skb_bad_txq(struct Qdisc *q, + struct sk_buff *skb) +{ + spinlock_t *lock = NULL; + + if (q->flags & TCQ_F_NOLOCK) { + lock = qdisc_lock(q); + spin_lock(lock); + } + + __skb_queue_tail(&q->skb_bad_txq, skb); + + if (qdisc_is_percpu_stats(q)) { + qdisc_qstats_cpu_backlog_inc(q, skb); + qdisc_qstats_cpu_qlen_inc(q); + } else { + qdisc_qstats_backlog_inc(q, skb); + q->q.qlen++; + } + + if (lock) + spin_unlock(lock); +} + +static inline void dev_requeue_skb(struct sk_buff *skb, struct Qdisc *q) +{ + spinlock_t *lock = NULL; + + if (q->flags & TCQ_F_NOLOCK) { + lock = qdisc_lock(q); + spin_lock(lock); + } + + while (skb) { + struct sk_buff *next = skb->next; + + __skb_queue_tail(&q->gso_skb, skb); + + /* it's still part of the queue */ + if (qdisc_is_percpu_stats(q)) { + qdisc_qstats_cpu_requeues_inc(q); + qdisc_qstats_cpu_backlog_inc(q, skb); + qdisc_qstats_cpu_qlen_inc(q); + } else { + q->qstats.requeues++; + qdisc_qstats_backlog_inc(q, skb); + q->q.qlen++; + } + + skb = next; 
+ } + + if (lock) { + spin_unlock(lock); + set_bit(__QDISC_STATE_MISSED, &q->state); + } else { + __netif_schedule(q); + } +} + +static void try_bulk_dequeue_skb(struct Qdisc *q, + struct sk_buff *skb, + const struct netdev_queue *txq, + int *packets) +{ + int bytelimit = qdisc_avail_bulklimit(txq) - skb->len; + + while (bytelimit > 0) { + struct sk_buff *nskb = q->dequeue(q); + + if (!nskb) + break; + + bytelimit -= nskb->len; /* covers GSO len */ + skb->next = nskb; + skb = nskb; + (*packets)++; /* GSO counts as one pkt */ + } + skb_mark_not_on_list(skb); +} + +/* This variant of try_bulk_dequeue_skb() makes sure + * all skbs in the chain are for the same txq + */ +static void try_bulk_dequeue_skb_slow(struct Qdisc *q, + struct sk_buff *skb, + int *packets) +{ + int mapping = skb_get_queue_mapping(skb); + struct sk_buff *nskb; + int cnt = 0; + + do { + nskb = q->dequeue(q); + if (!nskb) + break; + if (unlikely(skb_get_queue_mapping(nskb) != mapping)) { + qdisc_enqueue_skb_bad_txq(q, nskb); + break; + } + skb->next = nskb; + skb = nskb; + } while (++cnt < 8); + (*packets) += cnt; + skb_mark_not_on_list(skb); +} + +/* Note that dequeue_skb can possibly return a SKB list (via skb->next). + * A requeued skb (via q->gso_skb) can also be a SKB list. + */ +static struct sk_buff *dequeue_skb(struct Qdisc *q, bool *validate, + int *packets) +{ + const struct netdev_queue *txq = q->dev_queue; + struct sk_buff *skb = NULL; + + *packets = 1; + if (unlikely(!skb_queue_empty(&q->gso_skb))) { + spinlock_t *lock = NULL; + + if (q->flags & TCQ_F_NOLOCK) { + lock = qdisc_lock(q); + spin_lock(lock); + } + + skb = skb_peek(&q->gso_skb); + + /* skb may be null if another cpu pulls gso_skb off in between + * empty check and lock. + */ + if (!skb) { + if (lock) + spin_unlock(lock); + goto validate; + } + + /* skb in gso_skb were already validated */ + *validate = false; + if (xfrm_offload(skb)) + *validate = true; + /* check the reason of requeuing without tx lock first */ + txq = skb_get_tx_queue(txq->dev, skb); + if (!netif_xmit_frozen_or_stopped(txq)) { + skb = __skb_dequeue(&q->gso_skb); + if (qdisc_is_percpu_stats(q)) { + qdisc_qstats_cpu_backlog_dec(q, skb); + qdisc_qstats_cpu_qlen_dec(q); + } else { + qdisc_qstats_backlog_dec(q, skb); + q->q.qlen--; + } + } else { + skb = NULL; + qdisc_maybe_clear_missed(q, txq); + } + if (lock) + spin_unlock(lock); + goto trace; + } +validate: + *validate = true; + + if ((q->flags & TCQ_F_ONETXQUEUE) && + netif_xmit_frozen_or_stopped(txq)) { + qdisc_maybe_clear_missed(q, txq); + return skb; + } + + skb = qdisc_dequeue_skb_bad_txq(q); + if (unlikely(skb)) { + if (skb == SKB_XOFF_MAGIC) + return NULL; + goto bulk; + } + skb = q->dequeue(q); + if (skb) { +bulk: + if (qdisc_may_bulk(q)) + try_bulk_dequeue_skb(q, skb, txq, packets); + else + try_bulk_dequeue_skb_slow(q, skb, packets); + } +trace: + trace_qdisc_dequeue(q, txq, *packets, skb); + return skb; +} + +/* + * Transmit possibly several skbs, and handle the return status as + * required. Owning qdisc running bit guarantees that only one CPU + * can execute this function. + * + * Returns to the caller: + * false - hardware queue frozen backoff + * true - feel free to send more pkts + */ +bool sch_direct_xmit(struct sk_buff *skb, struct Qdisc *q, + struct net_device *dev, struct netdev_queue *txq, + spinlock_t *root_lock, bool validate) +{ + int ret = NETDEV_TX_BUSY; + bool again = false; + + /* And release qdisc */ + if (root_lock) + spin_unlock(root_lock); + + /* Note that we validate skb (GSO, checksum, ...) 
outside of locks */ + if (validate) + skb = validate_xmit_skb_list(skb, dev, &again); + +#ifdef CONFIG_XFRM_OFFLOAD + if (unlikely(again)) { + if (root_lock) + spin_lock(root_lock); + + dev_requeue_skb(skb, q); + return false; + } +#endif + + if (likely(skb)) { + HARD_TX_LOCK(dev, txq, smp_processor_id()); + if (!netif_xmit_frozen_or_stopped(txq)) + skb = dev_hard_start_xmit(skb, dev, txq, &ret); + else + qdisc_maybe_clear_missed(q, txq); + + HARD_TX_UNLOCK(dev, txq); + } else { + if (root_lock) + spin_lock(root_lock); + return true; + } + + if (root_lock) + spin_lock(root_lock); + + if (!dev_xmit_complete(ret)) { + /* Driver returned NETDEV_TX_BUSY - requeue skb */ + if (unlikely(ret != NETDEV_TX_BUSY)) + net_warn_ratelimited("BUG %s code %d qlen %d\n", + dev->name, ret, q->q.qlen); + + dev_requeue_skb(skb, q); + return false; + } + + return true; +} + +/* + * NOTE: Called under qdisc_lock(q) with locally disabled BH. + * + * running seqcount guarantees only one CPU can process + * this qdisc at a time. qdisc_lock(q) serializes queue accesses for + * this queue. + * + * netif_tx_lock serializes accesses to device driver. + * + * qdisc_lock(q) and netif_tx_lock are mutually exclusive, + * if one is grabbed, another must be free. + * + * Note, that this procedure can be called by a watchdog timer + * + * Returns to the caller: + * 0 - queue is empty or throttled. + * >0 - queue is not empty. + * + */ +static inline bool qdisc_restart(struct Qdisc *q, int *packets) +{ + spinlock_t *root_lock = NULL; + struct netdev_queue *txq; + struct net_device *dev; + struct sk_buff *skb; + bool validate; + + /* Dequeue packet */ + skb = dequeue_skb(q, &validate, packets); + if (unlikely(!skb)) + return false; + + if (!(q->flags & TCQ_F_NOLOCK)) + root_lock = qdisc_lock(q); + + dev = qdisc_dev(q); + txq = skb_get_tx_queue(dev, skb); + + return sch_direct_xmit(skb, q, dev, txq, root_lock, validate); +} + +void __qdisc_run(struct Qdisc *q) +{ + int quota = READ_ONCE(dev_tx_weight); + int packets; + + while (qdisc_restart(q, &packets)) { + quota -= packets; + if (quota <= 0) { + if (q->flags & TCQ_F_NOLOCK) + set_bit(__QDISC_STATE_MISSED, &q->state); + else + __netif_schedule(q); + + break; + } + } +} + +unsigned long dev_trans_start(struct net_device *dev) +{ + unsigned long res = READ_ONCE(netdev_get_tx_queue(dev, 0)->trans_start); + unsigned long val; + unsigned int i; + + for (i = 1; i < dev->num_tx_queues; i++) { + val = READ_ONCE(netdev_get_tx_queue(dev, i)->trans_start); + if (val && time_after(val, res)) + res = val; + } + + return res; +} +EXPORT_SYMBOL(dev_trans_start); + +static void netif_freeze_queues(struct net_device *dev) +{ + unsigned int i; + int cpu; + + cpu = smp_processor_id(); + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *txq = netdev_get_tx_queue(dev, i); + + /* We are the only thread of execution doing a + * freeze, but we have to grab the _xmit_lock in + * order to synchronize with threads which are in + * the ->hard_start_xmit() handler and already + * checked the frozen bit. + */ + __netif_tx_lock(txq, cpu); + set_bit(__QUEUE_STATE_FROZEN, &txq->state); + __netif_tx_unlock(txq); + } +} + +void netif_tx_lock(struct net_device *dev) +{ + spin_lock(&dev->tx_global_lock); + netif_freeze_queues(dev); +} +EXPORT_SYMBOL(netif_tx_lock); + +static void netif_unfreeze_queues(struct net_device *dev) +{ + unsigned int i; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *txq = netdev_get_tx_queue(dev, i); + + /* No need to grab the _xmit_lock here. 
If the + * queue is not stopped for another reason, we + * force a schedule. + */ + clear_bit(__QUEUE_STATE_FROZEN, &txq->state); + netif_schedule_queue(txq); + } +} + +void netif_tx_unlock(struct net_device *dev) +{ + netif_unfreeze_queues(dev); + spin_unlock(&dev->tx_global_lock); +} +EXPORT_SYMBOL(netif_tx_unlock); + +static void dev_watchdog(struct timer_list *t) +{ + struct net_device *dev = from_timer(dev, t, watchdog_timer); + bool release = true; + + spin_lock(&dev->tx_global_lock); + if (!qdisc_tx_is_noop(dev)) { + if (netif_device_present(dev) && + netif_running(dev) && + netif_carrier_ok(dev)) { + int some_queue_timedout = 0; + unsigned int i; + unsigned long trans_start; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *txq; + + txq = netdev_get_tx_queue(dev, i); + trans_start = READ_ONCE(txq->trans_start); + if (netif_xmit_stopped(txq) && + time_after(jiffies, (trans_start + + dev->watchdog_timeo))) { + some_queue_timedout = 1; + atomic_long_inc(&txq->trans_timeout); + break; + } + } + + if (unlikely(some_queue_timedout)) { + trace_net_dev_xmit_timeout(dev, i); + WARN_ONCE(1, KERN_INFO "NETDEV WATCHDOG: %s (%s): transmit queue %u timed out\n", + dev->name, netdev_drivername(dev), i); + netif_freeze_queues(dev); + dev->netdev_ops->ndo_tx_timeout(dev, i); + netif_unfreeze_queues(dev); + } + if (!mod_timer(&dev->watchdog_timer, + round_jiffies(jiffies + + dev->watchdog_timeo))) + release = false; + } + } + spin_unlock(&dev->tx_global_lock); + + if (release) + netdev_put(dev, &dev->watchdog_dev_tracker); +} + +void __netdev_watchdog_up(struct net_device *dev) +{ + if (dev->netdev_ops->ndo_tx_timeout) { + if (dev->watchdog_timeo <= 0) + dev->watchdog_timeo = 5*HZ; + if (!mod_timer(&dev->watchdog_timer, + round_jiffies(jiffies + dev->watchdog_timeo))) + netdev_hold(dev, &dev->watchdog_dev_tracker, + GFP_ATOMIC); + } +} +EXPORT_SYMBOL_GPL(__netdev_watchdog_up); + +static void dev_watchdog_up(struct net_device *dev) +{ + __netdev_watchdog_up(dev); +} + +static void dev_watchdog_down(struct net_device *dev) +{ + netif_tx_lock_bh(dev); + if (del_timer(&dev->watchdog_timer)) + netdev_put(dev, &dev->watchdog_dev_tracker); + netif_tx_unlock_bh(dev); +} + +/** + * netif_carrier_on - set carrier + * @dev: network device + * + * Device has detected acquisition of carrier. + */ +void netif_carrier_on(struct net_device *dev) +{ + if (test_and_clear_bit(__LINK_STATE_NOCARRIER, &dev->state)) { + if (dev->reg_state == NETREG_UNINITIALIZED) + return; + atomic_inc(&dev->carrier_up_count); + linkwatch_fire_event(dev); + if (netif_running(dev)) + __netdev_watchdog_up(dev); + } +} +EXPORT_SYMBOL(netif_carrier_on); + +/** + * netif_carrier_off - clear carrier + * @dev: network device + * + * Device has detected loss of carrier. + */ +void netif_carrier_off(struct net_device *dev) +{ + if (!test_and_set_bit(__LINK_STATE_NOCARRIER, &dev->state)) { + if (dev->reg_state == NETREG_UNINITIALIZED) + return; + atomic_inc(&dev->carrier_down_count); + linkwatch_fire_event(dev); + } +} +EXPORT_SYMBOL(netif_carrier_off); + +/** + * netif_carrier_event - report carrier state event + * @dev: network device + * + * Device has detected a carrier event but the carrier state wasn't changed. + * Use in drivers when querying carrier state asynchronously, to avoid missing + * events (link flaps) if link recovers before it's queried. 
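+ *
+ * Illustrative use (not taken from a specific driver): a driver that reads
+ * PHY link status asynchronously can call this when it observes a flap that
+ * left the reported carrier state unchanged, so both carrier change
+ * counters still record the event.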
+ */ +void netif_carrier_event(struct net_device *dev) +{ + if (dev->reg_state == NETREG_UNINITIALIZED) + return; + atomic_inc(&dev->carrier_up_count); + atomic_inc(&dev->carrier_down_count); + linkwatch_fire_event(dev); +} +EXPORT_SYMBOL_GPL(netif_carrier_event); + +/* "NOOP" scheduler: the best scheduler, recommended for all interfaces + under all circumstances. It is difficult to invent anything faster or + cheaper. + */ + +static int noop_enqueue(struct sk_buff *skb, struct Qdisc *qdisc, + struct sk_buff **to_free) +{ + __qdisc_drop(skb, to_free); + return NET_XMIT_CN; +} + +static struct sk_buff *noop_dequeue(struct Qdisc *qdisc) +{ + return NULL; +} + +struct Qdisc_ops noop_qdisc_ops __read_mostly = { + .id = "noop", + .priv_size = 0, + .enqueue = noop_enqueue, + .dequeue = noop_dequeue, + .peek = noop_dequeue, + .owner = THIS_MODULE, +}; + +static struct netdev_queue noop_netdev_queue = { + RCU_POINTER_INITIALIZER(qdisc, &noop_qdisc), + RCU_POINTER_INITIALIZER(qdisc_sleeping, &noop_qdisc), +}; + +struct Qdisc noop_qdisc = { + .enqueue = noop_enqueue, + .dequeue = noop_dequeue, + .flags = TCQ_F_BUILTIN, + .ops = &noop_qdisc_ops, + .q.lock = __SPIN_LOCK_UNLOCKED(noop_qdisc.q.lock), + .dev_queue = &noop_netdev_queue, + .busylock = __SPIN_LOCK_UNLOCKED(noop_qdisc.busylock), + .gso_skb = { + .next = (struct sk_buff *)&noop_qdisc.gso_skb, + .prev = (struct sk_buff *)&noop_qdisc.gso_skb, + .qlen = 0, + .lock = __SPIN_LOCK_UNLOCKED(noop_qdisc.gso_skb.lock), + }, + .skb_bad_txq = { + .next = (struct sk_buff *)&noop_qdisc.skb_bad_txq, + .prev = (struct sk_buff *)&noop_qdisc.skb_bad_txq, + .qlen = 0, + .lock = __SPIN_LOCK_UNLOCKED(noop_qdisc.skb_bad_txq.lock), + }, +}; +EXPORT_SYMBOL(noop_qdisc); + +static int noqueue_init(struct Qdisc *qdisc, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + /* register_qdisc() assigns a default of noop_enqueue if unset, + * but __dev_queue_xmit() treats noqueue only as such + * if this is NULL - so clear it here. */ + qdisc->enqueue = NULL; + return 0; +} + +struct Qdisc_ops noqueue_qdisc_ops __read_mostly = { + .id = "noqueue", + .priv_size = 0, + .init = noqueue_init, + .enqueue = noop_enqueue, + .dequeue = noop_dequeue, + .peek = noop_dequeue, + .owner = THIS_MODULE, +}; + +static const u8 prio2band[TC_PRIO_MAX + 1] = { + 1, 2, 2, 2, 1, 2, 0, 0 , 1, 1, 1, 1, 1, 1, 1, 1 +}; + +/* 3-band FIFO queue: old style, but should be a bit faster than + generic prio+fifo combination. 
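+
+   As an illustration of the prio2band table above: TC_PRIO_INTERACTIVE (6)
+   and TC_PRIO_CONTROL (7) select band 0, which is dequeued first,
+   TC_PRIO_BESTEFFORT (0) selects band 1, and bulk priorities such as
+   TC_PRIO_BULK (2) select band 2.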
+ */ + +#define PFIFO_FAST_BANDS 3 + +/* + * Private data for a pfifo_fast scheduler containing: + * - rings for priority bands + */ +struct pfifo_fast_priv { + struct skb_array q[PFIFO_FAST_BANDS]; +}; + +static inline struct skb_array *band2list(struct pfifo_fast_priv *priv, + int band) +{ + return &priv->q[band]; +} + +static int pfifo_fast_enqueue(struct sk_buff *skb, struct Qdisc *qdisc, + struct sk_buff **to_free) +{ + int band = prio2band[skb->priority & TC_PRIO_MAX]; + struct pfifo_fast_priv *priv = qdisc_priv(qdisc); + struct skb_array *q = band2list(priv, band); + unsigned int pkt_len = qdisc_pkt_len(skb); + int err; + + err = skb_array_produce(q, skb); + + if (unlikely(err)) { + if (qdisc_is_percpu_stats(qdisc)) + return qdisc_drop_cpu(skb, qdisc, to_free); + else + return qdisc_drop(skb, qdisc, to_free); + } + + qdisc_update_stats_at_enqueue(qdisc, pkt_len); + return NET_XMIT_SUCCESS; +} + +static struct sk_buff *pfifo_fast_dequeue(struct Qdisc *qdisc) +{ + struct pfifo_fast_priv *priv = qdisc_priv(qdisc); + struct sk_buff *skb = NULL; + bool need_retry = true; + int band; + +retry: + for (band = 0; band < PFIFO_FAST_BANDS && !skb; band++) { + struct skb_array *q = band2list(priv, band); + + if (__skb_array_empty(q)) + continue; + + skb = __skb_array_consume(q); + } + if (likely(skb)) { + qdisc_update_stats_at_dequeue(qdisc, skb); + } else if (need_retry && + READ_ONCE(qdisc->state) & QDISC_STATE_NON_EMPTY) { + /* Delay clearing the STATE_MISSED here to reduce + * the overhead of the second spin_trylock() in + * qdisc_run_begin() and __netif_schedule() calling + * in qdisc_run_end(). + */ + clear_bit(__QDISC_STATE_MISSED, &qdisc->state); + clear_bit(__QDISC_STATE_DRAINING, &qdisc->state); + + /* Make sure dequeuing happens after clearing + * STATE_MISSED. + */ + smp_mb__after_atomic(); + + need_retry = false; + + goto retry; + } + + return skb; +} + +static struct sk_buff *pfifo_fast_peek(struct Qdisc *qdisc) +{ + struct pfifo_fast_priv *priv = qdisc_priv(qdisc); + struct sk_buff *skb = NULL; + int band; + + for (band = 0; band < PFIFO_FAST_BANDS && !skb; band++) { + struct skb_array *q = band2list(priv, band); + + skb = __skb_array_peek(q); + } + + return skb; +} + +static void pfifo_fast_reset(struct Qdisc *qdisc) +{ + int i, band; + struct pfifo_fast_priv *priv = qdisc_priv(qdisc); + + for (band = 0; band < PFIFO_FAST_BANDS; band++) { + struct skb_array *q = band2list(priv, band); + struct sk_buff *skb; + + /* NULL ring is possible if destroy path is due to a failed + * skb_array_init() in pfifo_fast_init() case. 
+ */ + if (!q->ring.queue) + continue; + + while ((skb = __skb_array_consume(q)) != NULL) + kfree_skb(skb); + } + + if (qdisc_is_percpu_stats(qdisc)) { + for_each_possible_cpu(i) { + struct gnet_stats_queue *q; + + q = per_cpu_ptr(qdisc->cpu_qstats, i); + q->backlog = 0; + q->qlen = 0; + } + } +} + +static int pfifo_fast_dump(struct Qdisc *qdisc, struct sk_buff *skb) +{ + struct tc_prio_qopt opt = { .bands = PFIFO_FAST_BANDS }; + + memcpy(&opt.priomap, prio2band, TC_PRIO_MAX + 1); + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + return skb->len; + +nla_put_failure: + return -1; +} + +static int pfifo_fast_init(struct Qdisc *qdisc, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + unsigned int qlen = qdisc_dev(qdisc)->tx_queue_len; + struct pfifo_fast_priv *priv = qdisc_priv(qdisc); + int prio; + + /* guard against zero length rings */ + if (!qlen) + return -EINVAL; + + for (prio = 0; prio < PFIFO_FAST_BANDS; prio++) { + struct skb_array *q = band2list(priv, prio); + int err; + + err = skb_array_init(q, qlen, GFP_KERNEL); + if (err) + return -ENOMEM; + } + + /* Can by-pass the queue discipline */ + qdisc->flags |= TCQ_F_CAN_BYPASS; + return 0; +} + +static void pfifo_fast_destroy(struct Qdisc *sch) +{ + struct pfifo_fast_priv *priv = qdisc_priv(sch); + int prio; + + for (prio = 0; prio < PFIFO_FAST_BANDS; prio++) { + struct skb_array *q = band2list(priv, prio); + + /* NULL ring is possible if destroy path is due to a failed + * skb_array_init() in pfifo_fast_init() case. + */ + if (!q->ring.queue) + continue; + /* Destroy ring but no need to kfree_skb because a call to + * pfifo_fast_reset() has already done that work. + */ + ptr_ring_cleanup(&q->ring, NULL); + } +} + +static int pfifo_fast_change_tx_queue_len(struct Qdisc *sch, + unsigned int new_len) +{ + struct pfifo_fast_priv *priv = qdisc_priv(sch); + struct skb_array *bands[PFIFO_FAST_BANDS]; + int prio; + + for (prio = 0; prio < PFIFO_FAST_BANDS; prio++) { + struct skb_array *q = band2list(priv, prio); + + bands[prio] = q; + } + + return skb_array_resize_multiple(bands, PFIFO_FAST_BANDS, new_len, + GFP_KERNEL); +} + +struct Qdisc_ops pfifo_fast_ops __read_mostly = { + .id = "pfifo_fast", + .priv_size = sizeof(struct pfifo_fast_priv), + .enqueue = pfifo_fast_enqueue, + .dequeue = pfifo_fast_dequeue, + .peek = pfifo_fast_peek, + .init = pfifo_fast_init, + .destroy = pfifo_fast_destroy, + .reset = pfifo_fast_reset, + .dump = pfifo_fast_dump, + .change_tx_queue_len = pfifo_fast_change_tx_queue_len, + .owner = THIS_MODULE, + .static_flags = TCQ_F_NOLOCK | TCQ_F_CPUSTATS, +}; +EXPORT_SYMBOL(pfifo_fast_ops); + +static struct lock_class_key qdisc_tx_busylock; + +struct Qdisc *qdisc_alloc(struct netdev_queue *dev_queue, + const struct Qdisc_ops *ops, + struct netlink_ext_ack *extack) +{ + struct Qdisc *sch; + unsigned int size = sizeof(*sch) + ops->priv_size; + int err = -ENOBUFS; + struct net_device *dev; + + if (!dev_queue) { + NL_SET_ERR_MSG(extack, "No device queue given"); + err = -EINVAL; + goto errout; + } + + dev = dev_queue->dev; + sch = kzalloc_node(size, GFP_KERNEL, netdev_queue_numa_node_read(dev_queue)); + + if (!sch) + goto errout; + __skb_queue_head_init(&sch->gso_skb); + __skb_queue_head_init(&sch->skb_bad_txq); + gnet_stats_basic_sync_init(&sch->bstats); + spin_lock_init(&sch->q.lock); + + if (ops->static_flags & TCQ_F_CPUSTATS) { + sch->cpu_bstats = + netdev_alloc_pcpu_stats(struct gnet_stats_basic_sync); + if (!sch->cpu_bstats) + goto errout1; + + sch->cpu_qstats = alloc_percpu(struct 
gnet_stats_queue); + if (!sch->cpu_qstats) { + free_percpu(sch->cpu_bstats); + goto errout1; + } + } + + spin_lock_init(&sch->busylock); + lockdep_set_class(&sch->busylock, + dev->qdisc_tx_busylock ?: &qdisc_tx_busylock); + + /* seqlock has the same scope of busylock, for NOLOCK qdisc */ + spin_lock_init(&sch->seqlock); + lockdep_set_class(&sch->seqlock, + dev->qdisc_tx_busylock ?: &qdisc_tx_busylock); + + sch->ops = ops; + sch->flags = ops->static_flags; + sch->enqueue = ops->enqueue; + sch->dequeue = ops->dequeue; + sch->dev_queue = dev_queue; + netdev_hold(dev, &sch->dev_tracker, GFP_KERNEL); + refcount_set(&sch->refcnt, 1); + + return sch; +errout1: + kfree(sch); +errout: + return ERR_PTR(err); +} + +struct Qdisc *qdisc_create_dflt(struct netdev_queue *dev_queue, + const struct Qdisc_ops *ops, + unsigned int parentid, + struct netlink_ext_ack *extack) +{ + struct Qdisc *sch; + + if (!try_module_get(ops->owner)) { + NL_SET_ERR_MSG(extack, "Failed to increase module reference counter"); + return NULL; + } + + sch = qdisc_alloc(dev_queue, ops, extack); + if (IS_ERR(sch)) { + module_put(ops->owner); + return NULL; + } + sch->parent = parentid; + + if (!ops->init || ops->init(sch, NULL, extack) == 0) { + trace_qdisc_create(ops, dev_queue->dev, parentid); + return sch; + } + + qdisc_put(sch); + return NULL; +} +EXPORT_SYMBOL(qdisc_create_dflt); + +/* Under qdisc_lock(qdisc) and BH! */ + +void qdisc_reset(struct Qdisc *qdisc) +{ + const struct Qdisc_ops *ops = qdisc->ops; + + trace_qdisc_reset(qdisc); + + if (ops->reset) + ops->reset(qdisc); + + __skb_queue_purge(&qdisc->gso_skb); + __skb_queue_purge(&qdisc->skb_bad_txq); + + qdisc->q.qlen = 0; + qdisc->qstats.backlog = 0; +} +EXPORT_SYMBOL(qdisc_reset); + +void qdisc_free(struct Qdisc *qdisc) +{ + if (qdisc_is_percpu_stats(qdisc)) { + free_percpu(qdisc->cpu_bstats); + free_percpu(qdisc->cpu_qstats); + } + + kfree(qdisc); +} + +static void qdisc_free_cb(struct rcu_head *head) +{ + struct Qdisc *q = container_of(head, struct Qdisc, rcu); + + qdisc_free(q); +} + +static void __qdisc_destroy(struct Qdisc *qdisc) +{ + const struct Qdisc_ops *ops = qdisc->ops; + +#ifdef CONFIG_NET_SCHED + qdisc_hash_del(qdisc); + + qdisc_put_stab(rtnl_dereference(qdisc->stab)); +#endif + gen_kill_estimator(&qdisc->rate_est); + + qdisc_reset(qdisc); + + if (ops->destroy) + ops->destroy(qdisc); + + module_put(ops->owner); + netdev_put(qdisc_dev(qdisc), &qdisc->dev_tracker); + + trace_qdisc_destroy(qdisc); + + call_rcu(&qdisc->rcu, qdisc_free_cb); +} + +void qdisc_destroy(struct Qdisc *qdisc) +{ + if (qdisc->flags & TCQ_F_BUILTIN) + return; + + __qdisc_destroy(qdisc); +} + +void qdisc_put(struct Qdisc *qdisc) +{ + if (!qdisc) + return; + + if (qdisc->flags & TCQ_F_BUILTIN || + !refcount_dec_and_test(&qdisc->refcnt)) + return; + + __qdisc_destroy(qdisc); +} +EXPORT_SYMBOL(qdisc_put); + +/* Version of qdisc_put() that is called with rtnl mutex unlocked. + * Intended to be used as optimization, this function only takes rtnl lock if + * qdisc reference counter reached zero. + */ + +void qdisc_put_unlocked(struct Qdisc *qdisc) +{ + if (qdisc->flags & TCQ_F_BUILTIN || + !refcount_dec_and_rtnl_lock(&qdisc->refcnt)) + return; + + __qdisc_destroy(qdisc); + rtnl_unlock(); +} +EXPORT_SYMBOL(qdisc_put_unlocked); + +/* Attach toplevel qdisc to device queue. 
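+ *
+ * The new qdisc is installed as qdisc_sleeping while the active qdisc
+ * pointer is parked on noop_qdisc; a later dev_activate() makes it visible
+ * to the transmit path. The previously attached qdisc is returned so the
+ * caller can re-graft or release it.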
*/ +struct Qdisc *dev_graft_qdisc(struct netdev_queue *dev_queue, + struct Qdisc *qdisc) +{ + struct Qdisc *oqdisc = rtnl_dereference(dev_queue->qdisc_sleeping); + spinlock_t *root_lock; + + root_lock = qdisc_lock(oqdisc); + spin_lock_bh(root_lock); + + /* ... and graft new one */ + if (qdisc == NULL) + qdisc = &noop_qdisc; + rcu_assign_pointer(dev_queue->qdisc_sleeping, qdisc); + rcu_assign_pointer(dev_queue->qdisc, &noop_qdisc); + + spin_unlock_bh(root_lock); + + return oqdisc; +} +EXPORT_SYMBOL(dev_graft_qdisc); + +static void shutdown_scheduler_queue(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_qdisc_default) +{ + struct Qdisc *qdisc = rtnl_dereference(dev_queue->qdisc_sleeping); + struct Qdisc *qdisc_default = _qdisc_default; + + if (qdisc) { + rcu_assign_pointer(dev_queue->qdisc, qdisc_default); + rcu_assign_pointer(dev_queue->qdisc_sleeping, qdisc_default); + + qdisc_put(qdisc); + } +} + +static void attach_one_default_qdisc(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_unused) +{ + struct Qdisc *qdisc; + const struct Qdisc_ops *ops = default_qdisc_ops; + + if (dev->priv_flags & IFF_NO_QUEUE) + ops = &noqueue_qdisc_ops; + else if(dev->type == ARPHRD_CAN) + ops = &pfifo_fast_ops; + + qdisc = qdisc_create_dflt(dev_queue, ops, TC_H_ROOT, NULL); + if (!qdisc) + return; + + if (!netif_is_multiqueue(dev)) + qdisc->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + rcu_assign_pointer(dev_queue->qdisc_sleeping, qdisc); +} + +static void attach_default_qdiscs(struct net_device *dev) +{ + struct netdev_queue *txq; + struct Qdisc *qdisc; + + txq = netdev_get_tx_queue(dev, 0); + + if (!netif_is_multiqueue(dev) || + dev->priv_flags & IFF_NO_QUEUE) { + netdev_for_each_tx_queue(dev, attach_one_default_qdisc, NULL); + qdisc = rtnl_dereference(txq->qdisc_sleeping); + rcu_assign_pointer(dev->qdisc, qdisc); + qdisc_refcount_inc(qdisc); + } else { + qdisc = qdisc_create_dflt(txq, &mq_qdisc_ops, TC_H_ROOT, NULL); + if (qdisc) { + rcu_assign_pointer(dev->qdisc, qdisc); + qdisc->ops->attach(qdisc); + } + } + qdisc = rtnl_dereference(dev->qdisc); + + /* Detect default qdisc setup/init failed and fallback to "noqueue" */ + if (qdisc == &noop_qdisc) { + netdev_warn(dev, "default qdisc (%s) fail, fallback to %s\n", + default_qdisc_ops->id, noqueue_qdisc_ops.id); + netdev_for_each_tx_queue(dev, shutdown_scheduler_queue, &noop_qdisc); + dev->priv_flags |= IFF_NO_QUEUE; + netdev_for_each_tx_queue(dev, attach_one_default_qdisc, NULL); + qdisc = rtnl_dereference(txq->qdisc_sleeping); + rcu_assign_pointer(dev->qdisc, qdisc); + qdisc_refcount_inc(qdisc); + dev->priv_flags ^= IFF_NO_QUEUE; + } + +#ifdef CONFIG_NET_SCHED + if (qdisc != &noop_qdisc) + qdisc_hash_add(qdisc, false); +#endif +} + +static void transition_one_qdisc(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_need_watchdog) +{ + struct Qdisc *new_qdisc = rtnl_dereference(dev_queue->qdisc_sleeping); + int *need_watchdog_p = _need_watchdog; + + if (!(new_qdisc->flags & TCQ_F_BUILTIN)) + clear_bit(__QDISC_STATE_DEACTIVATED, &new_qdisc->state); + + rcu_assign_pointer(dev_queue->qdisc, new_qdisc); + if (need_watchdog_p) { + WRITE_ONCE(dev_queue->trans_start, 0); + *need_watchdog_p = 1; + } +} + +void dev_activate(struct net_device *dev) +{ + int need_watchdog; + + /* No queueing discipline is attached to device; + * create default one for devices, which need queueing + * and noqueue_qdisc for virtual interfaces + */ + + if (rtnl_dereference(dev->qdisc) == &noop_qdisc) + attach_default_qdiscs(dev); + + 
if (!netif_carrier_ok(dev)) + /* Delay activation until next carrier-on event */ + return; + + need_watchdog = 0; + netdev_for_each_tx_queue(dev, transition_one_qdisc, &need_watchdog); + if (dev_ingress_queue(dev)) + transition_one_qdisc(dev, dev_ingress_queue(dev), NULL); + + if (need_watchdog) { + netif_trans_update(dev); + dev_watchdog_up(dev); + } +} +EXPORT_SYMBOL(dev_activate); + +static void qdisc_deactivate(struct Qdisc *qdisc) +{ + if (qdisc->flags & TCQ_F_BUILTIN) + return; + + set_bit(__QDISC_STATE_DEACTIVATED, &qdisc->state); +} + +static void dev_deactivate_queue(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_qdisc_default) +{ + struct Qdisc *qdisc_default = _qdisc_default; + struct Qdisc *qdisc; + + qdisc = rtnl_dereference(dev_queue->qdisc); + if (qdisc) { + qdisc_deactivate(qdisc); + rcu_assign_pointer(dev_queue->qdisc, qdisc_default); + } +} + +static void dev_reset_queue(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_unused) +{ + struct Qdisc *qdisc; + bool nolock; + + qdisc = rtnl_dereference(dev_queue->qdisc_sleeping); + if (!qdisc) + return; + + nolock = qdisc->flags & TCQ_F_NOLOCK; + + if (nolock) + spin_lock_bh(&qdisc->seqlock); + spin_lock_bh(qdisc_lock(qdisc)); + + qdisc_reset(qdisc); + + spin_unlock_bh(qdisc_lock(qdisc)); + if (nolock) { + clear_bit(__QDISC_STATE_MISSED, &qdisc->state); + clear_bit(__QDISC_STATE_DRAINING, &qdisc->state); + spin_unlock_bh(&qdisc->seqlock); + } +} + +static bool some_qdisc_is_busy(struct net_device *dev) +{ + unsigned int i; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *dev_queue; + spinlock_t *root_lock; + struct Qdisc *q; + int val; + + dev_queue = netdev_get_tx_queue(dev, i); + q = rtnl_dereference(dev_queue->qdisc_sleeping); + + root_lock = qdisc_lock(q); + spin_lock_bh(root_lock); + + val = (qdisc_is_running(q) || + test_bit(__QDISC_STATE_SCHED, &q->state)); + + spin_unlock_bh(root_lock); + + if (val) + return true; + } + return false; +} + +/** + * dev_deactivate_many - deactivate transmissions on several devices + * @head: list of devices to deactivate + * + * This function returns only when all outstanding transmissions + * have completed, unless all devices are in dismantle phase. + */ +void dev_deactivate_many(struct list_head *head) +{ + struct net_device *dev; + + list_for_each_entry(dev, head, close_list) { + netdev_for_each_tx_queue(dev, dev_deactivate_queue, + &noop_qdisc); + if (dev_ingress_queue(dev)) + dev_deactivate_queue(dev, dev_ingress_queue(dev), + &noop_qdisc); + + dev_watchdog_down(dev); + } + + /* Wait for outstanding qdisc-less dev_queue_xmit calls or + * outstanding qdisc enqueuing calls. + * This is avoided if all devices are in dismantle phase : + * Caller will call synchronize_net() for us + */ + synchronize_net(); + + list_for_each_entry(dev, head, close_list) { + netdev_for_each_tx_queue(dev, dev_reset_queue, NULL); + + if (dev_ingress_queue(dev)) + dev_reset_queue(dev, dev_ingress_queue(dev), NULL); + } + + /* Wait for outstanding qdisc_run calls. */ + list_for_each_entry(dev, head, close_list) { + while (some_qdisc_is_busy(dev)) { + /* wait_event() would avoid this sleep-loop but would + * require expensive checks in the fast paths of packet + * processing which isn't worth it. 
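+ * Busy here means the qdisc is still running or is scheduled for
+ * another run (see some_qdisc_is_busy() above), so sleep for a jiffy
+ * and check again.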
+ */ + schedule_timeout_uninterruptible(1); + } + } +} + +void dev_deactivate(struct net_device *dev) +{ + LIST_HEAD(single); + + list_add(&dev->close_list, &single); + dev_deactivate_many(&single); + list_del(&single); +} +EXPORT_SYMBOL(dev_deactivate); + +static int qdisc_change_tx_queue_len(struct net_device *dev, + struct netdev_queue *dev_queue) +{ + struct Qdisc *qdisc = rtnl_dereference(dev_queue->qdisc_sleeping); + const struct Qdisc_ops *ops = qdisc->ops; + + if (ops->change_tx_queue_len) + return ops->change_tx_queue_len(qdisc, dev->tx_queue_len); + return 0; +} + +void dev_qdisc_change_real_num_tx(struct net_device *dev, + unsigned int new_real_tx) +{ + struct Qdisc *qdisc = rtnl_dereference(dev->qdisc); + + if (qdisc->ops->change_real_num_tx) + qdisc->ops->change_real_num_tx(qdisc, new_real_tx); +} + +void mq_change_real_num_tx(struct Qdisc *sch, unsigned int new_real_tx) +{ +#ifdef CONFIG_NET_SCHED + struct net_device *dev = qdisc_dev(sch); + struct Qdisc *qdisc; + unsigned int i; + + for (i = new_real_tx; i < dev->real_num_tx_queues; i++) { + qdisc = rtnl_dereference(netdev_get_tx_queue(dev, i)->qdisc_sleeping); + /* Only update the default qdiscs we created, + * qdiscs with handles are always hashed. + */ + if (qdisc != &noop_qdisc && !qdisc->handle) + qdisc_hash_del(qdisc); + } + for (i = dev->real_num_tx_queues; i < new_real_tx; i++) { + qdisc = rtnl_dereference(netdev_get_tx_queue(dev, i)->qdisc_sleeping); + if (qdisc != &noop_qdisc && !qdisc->handle) + qdisc_hash_add(qdisc, false); + } +#endif +} +EXPORT_SYMBOL(mq_change_real_num_tx); + +int dev_qdisc_change_tx_queue_len(struct net_device *dev) +{ + bool up = dev->flags & IFF_UP; + unsigned int i; + int ret = 0; + + if (up) + dev_deactivate(dev); + + for (i = 0; i < dev->num_tx_queues; i++) { + ret = qdisc_change_tx_queue_len(dev, &dev->_tx[i]); + + /* TODO: revert changes on a partial failure */ + if (ret) + break; + } + + if (up) + dev_activate(dev); + return ret; +} + +static void dev_init_scheduler_queue(struct net_device *dev, + struct netdev_queue *dev_queue, + void *_qdisc) +{ + struct Qdisc *qdisc = _qdisc; + + rcu_assign_pointer(dev_queue->qdisc, qdisc); + rcu_assign_pointer(dev_queue->qdisc_sleeping, qdisc); +} + +void dev_init_scheduler(struct net_device *dev) +{ + rcu_assign_pointer(dev->qdisc, &noop_qdisc); + netdev_for_each_tx_queue(dev, dev_init_scheduler_queue, &noop_qdisc); + if (dev_ingress_queue(dev)) + dev_init_scheduler_queue(dev, dev_ingress_queue(dev), &noop_qdisc); + + timer_setup(&dev->watchdog_timer, dev_watchdog, 0); +} + +void dev_shutdown(struct net_device *dev) +{ + netdev_for_each_tx_queue(dev, shutdown_scheduler_queue, &noop_qdisc); + if (dev_ingress_queue(dev)) + shutdown_scheduler_queue(dev, dev_ingress_queue(dev), &noop_qdisc); + qdisc_put(rtnl_dereference(dev->qdisc)); + rcu_assign_pointer(dev->qdisc, &noop_qdisc); + + WARN_ON(timer_pending(&dev->watchdog_timer)); +} + +/** + * psched_ratecfg_precompute__() - Pre-compute values for reciprocal division + * @rate: Rate to compute reciprocal division values of + * @mult: Multiplier for reciprocal division + * @shift: Shift for reciprocal division + * + * The multiplier and shift for reciprocal division by rate are stored + * in mult and shift. 
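+ *
+ * Worked example (illustrative): with rate = 125000000 bytes/s (1 Gbit/s),
+ * a 1500 byte packet takes 1500 * NSEC_PER_SEC / rate = 12000 ns; the
+ * mult/shift pair computed below reproduces this as (1500 * mult) >> shift.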
+ * + * The deal here is to replace a divide by a reciprocal one + * in fast path (a reciprocal divide is a multiply and a shift) + * + * Normal formula would be : + * time_in_ns = (NSEC_PER_SEC * len) / rate_bps + * + * We compute mult/shift to use instead : + * time_in_ns = (len * mult) >> shift; + * + * We try to get the highest possible mult value for accuracy, + * but have to make sure no overflows will ever happen. + * + * reciprocal_value() is not used here it doesn't handle 64-bit values. + */ +static void psched_ratecfg_precompute__(u64 rate, u32 *mult, u8 *shift) +{ + u64 factor = NSEC_PER_SEC; + + *mult = 1; + *shift = 0; + + if (rate <= 0) + return; + + for (;;) { + *mult = div64_u64(factor, rate); + if (*mult & (1U << 31) || factor & (1ULL << 63)) + break; + factor <<= 1; + (*shift)++; + } +} + +void psched_ratecfg_precompute(struct psched_ratecfg *r, + const struct tc_ratespec *conf, + u64 rate64) +{ + memset(r, 0, sizeof(*r)); + r->overhead = conf->overhead; + r->mpu = conf->mpu; + r->rate_bytes_ps = max_t(u64, conf->rate, rate64); + r->linklayer = (conf->linklayer & TC_LINKLAYER_MASK); + psched_ratecfg_precompute__(r->rate_bytes_ps, &r->mult, &r->shift); +} +EXPORT_SYMBOL(psched_ratecfg_precompute); + +void psched_ppscfg_precompute(struct psched_pktrate *r, u64 pktrate64) +{ + r->rate_pkts_ps = pktrate64; + psched_ratecfg_precompute__(r->rate_pkts_ps, &r->mult, &r->shift); +} +EXPORT_SYMBOL(psched_ppscfg_precompute); + +void mini_qdisc_pair_swap(struct mini_Qdisc_pair *miniqp, + struct tcf_proto *tp_head) +{ + /* Protected with chain0->filter_chain_lock. + * Can't access chain directly because tp_head can be NULL. + */ + struct mini_Qdisc *miniq_old = + rcu_dereference_protected(*miniqp->p_miniq, 1); + struct mini_Qdisc *miniq; + + if (!tp_head) { + RCU_INIT_POINTER(*miniqp->p_miniq, NULL); + } else { + miniq = miniq_old != &miniqp->miniq1 ? + &miniqp->miniq1 : &miniqp->miniq2; + + /* We need to make sure that readers won't see the miniq + * we are about to modify. So ensure that at least one RCU + * grace period has elapsed since the miniq was made + * inactive. + */ + if (IS_ENABLED(CONFIG_PREEMPT_RT)) + cond_synchronize_rcu(miniq->rcu_state); + else if (!poll_state_synchronize_rcu(miniq->rcu_state)) + synchronize_rcu_expedited(); + + miniq->filter_list = tp_head; + rcu_assign_pointer(*miniqp->p_miniq, miniq); + } + + if (miniq_old) + /* This is counterpart of the rcu sync above. We need to + * block potential new user of miniq_old until all readers + * are not seeing it. 
+ */ + miniq_old->rcu_state = start_poll_synchronize_rcu(); +} +EXPORT_SYMBOL(mini_qdisc_pair_swap); + +void mini_qdisc_pair_block_init(struct mini_Qdisc_pair *miniqp, + struct tcf_block *block) +{ + miniqp->miniq1.block = block; + miniqp->miniq2.block = block; +} +EXPORT_SYMBOL(mini_qdisc_pair_block_init); + +void mini_qdisc_pair_init(struct mini_Qdisc_pair *miniqp, struct Qdisc *qdisc, + struct mini_Qdisc __rcu **p_miniq) +{ + miniqp->miniq1.cpu_bstats = qdisc->cpu_bstats; + miniqp->miniq1.cpu_qstats = qdisc->cpu_qstats; + miniqp->miniq2.cpu_bstats = qdisc->cpu_bstats; + miniqp->miniq2.cpu_qstats = qdisc->cpu_qstats; + miniqp->miniq1.rcu_state = get_state_synchronize_rcu(); + miniqp->miniq2.rcu_state = miniqp->miniq1.rcu_state; + miniqp->p_miniq = p_miniq; +} +EXPORT_SYMBOL(mini_qdisc_pair_init); diff --git a/net/sched/sch_gred.c b/net/sched/sch_gred.c new file mode 100644 index 000000000..872d127c9 --- /dev/null +++ b/net/sched/sch_gred.c @@ -0,0 +1,947 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_gred.c Generic Random Early Detection queue. + * + * Authors: J Hadi Salim (hadi@cyberus.ca) 1998-2002 + * + * 991129: - Bug fix with grio mode + * - a better sing. AvgQ mode with Grio(WRED) + * - A finer grained VQ dequeue based on suggestion + * from Ren Liu + * - More error checks + * + * For all the glorious comments look at include/net/red.h + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GRED_DEF_PRIO (MAX_DPs / 2) +#define GRED_VQ_MASK (MAX_DPs - 1) + +#define GRED_VQ_RED_FLAGS (TC_RED_ECN | TC_RED_HARDDROP) + +struct gred_sched_data; +struct gred_sched; + +struct gred_sched_data { + u32 limit; /* HARD maximal queue length */ + u32 DP; /* the drop parameters */ + u32 red_flags; /* virtualQ version of red_flags */ + u64 bytesin; /* bytes seen on virtualQ so far*/ + u32 packetsin; /* packets seen on virtualQ so far*/ + u32 backlog; /* bytes on the virtualQ */ + u8 prio; /* the prio of this vq */ + + struct red_parms parms; + struct red_vars vars; + struct red_stats stats; +}; + +enum { + GRED_WRED_MODE = 1, + GRED_RIO_MODE, +}; + +struct gred_sched { + struct gred_sched_data *tab[MAX_DPs]; + unsigned long flags; + u32 red_flags; + u32 DPs; + u32 def; + struct red_vars wred_set; + struct tc_gred_qopt_offload *opt; +}; + +static inline int gred_wred_mode(struct gred_sched *table) +{ + return test_bit(GRED_WRED_MODE, &table->flags); +} + +static inline void gred_enable_wred_mode(struct gred_sched *table) +{ + __set_bit(GRED_WRED_MODE, &table->flags); +} + +static inline void gred_disable_wred_mode(struct gred_sched *table) +{ + __clear_bit(GRED_WRED_MODE, &table->flags); +} + +static inline int gred_rio_mode(struct gred_sched *table) +{ + return test_bit(GRED_RIO_MODE, &table->flags); +} + +static inline void gred_enable_rio_mode(struct gred_sched *table) +{ + __set_bit(GRED_RIO_MODE, &table->flags); +} + +static inline void gred_disable_rio_mode(struct gred_sched *table) +{ + __clear_bit(GRED_RIO_MODE, &table->flags); +} + +static inline int gred_wred_mode_check(struct Qdisc *sch) +{ + struct gred_sched *table = qdisc_priv(sch); + int i; + + /* Really ugly O(n^2) but shouldn't be necessary too frequent. 
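+ *
+ * Returns 1 if any two configured virtual queues share the same priority;
+ * the callers use this to decide whether to switch the table into WRED
+ * mode.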
*/ + for (i = 0; i < table->DPs; i++) { + struct gred_sched_data *q = table->tab[i]; + int n; + + if (q == NULL) + continue; + + for (n = i + 1; n < table->DPs; n++) + if (table->tab[n] && table->tab[n]->prio == q->prio) + return 1; + } + + return 0; +} + +static inline unsigned int gred_backlog(struct gred_sched *table, + struct gred_sched_data *q, + struct Qdisc *sch) +{ + if (gred_wred_mode(table)) + return sch->qstats.backlog; + else + return q->backlog; +} + +static inline u16 tc_index_to_dp(struct sk_buff *skb) +{ + return skb->tc_index & GRED_VQ_MASK; +} + +static inline void gred_load_wred_set(const struct gred_sched *table, + struct gred_sched_data *q) +{ + q->vars.qavg = table->wred_set.qavg; + q->vars.qidlestart = table->wred_set.qidlestart; +} + +static inline void gred_store_wred_set(struct gred_sched *table, + struct gred_sched_data *q) +{ + table->wred_set.qavg = q->vars.qavg; + table->wred_set.qidlestart = q->vars.qidlestart; +} + +static int gred_use_ecn(struct gred_sched_data *q) +{ + return q->red_flags & TC_RED_ECN; +} + +static int gred_use_harddrop(struct gred_sched_data *q) +{ + return q->red_flags & TC_RED_HARDDROP; +} + +static bool gred_per_vq_red_flags_used(struct gred_sched *table) +{ + unsigned int i; + + /* Local per-vq flags couldn't have been set unless global are 0 */ + if (table->red_flags) + return false; + for (i = 0; i < MAX_DPs; i++) + if (table->tab[i] && table->tab[i]->red_flags) + return true; + return false; +} + +static int gred_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct gred_sched_data *q = NULL; + struct gred_sched *t = qdisc_priv(sch); + unsigned long qavg = 0; + u16 dp = tc_index_to_dp(skb); + + if (dp >= t->DPs || (q = t->tab[dp]) == NULL) { + dp = t->def; + + q = t->tab[dp]; + if (!q) { + /* Pass through packets not assigned to a DP + * if no default DP has been configured. This + * allows for DP flows to be left untouched. + */ + if (likely(sch->qstats.backlog + qdisc_pkt_len(skb) <= + sch->limit)) + return qdisc_enqueue_tail(skb, sch); + else + goto drop; + } + + /* fix tc_index? 
--could be controversial but needed for + requeueing */ + skb->tc_index = (skb->tc_index & ~GRED_VQ_MASK) | dp; + } + + /* sum up all the qaves of prios < ours to get the new qave */ + if (!gred_wred_mode(t) && gred_rio_mode(t)) { + int i; + + for (i = 0; i < t->DPs; i++) { + if (t->tab[i] && t->tab[i]->prio < q->prio && + !red_is_idling(&t->tab[i]->vars)) + qavg += t->tab[i]->vars.qavg; + } + + } + + q->packetsin++; + q->bytesin += qdisc_pkt_len(skb); + + if (gred_wred_mode(t)) + gred_load_wred_set(t, q); + + q->vars.qavg = red_calc_qavg(&q->parms, + &q->vars, + gred_backlog(t, q, sch)); + + if (red_is_idling(&q->vars)) + red_end_of_idle_period(&q->vars); + + if (gred_wred_mode(t)) + gred_store_wred_set(t, q); + + switch (red_action(&q->parms, &q->vars, q->vars.qavg + qavg)) { + case RED_DONT_MARK: + break; + + case RED_PROB_MARK: + qdisc_qstats_overlimit(sch); + if (!gred_use_ecn(q) || !INET_ECN_set_ce(skb)) { + q->stats.prob_drop++; + goto congestion_drop; + } + + q->stats.prob_mark++; + break; + + case RED_HARD_MARK: + qdisc_qstats_overlimit(sch); + if (gred_use_harddrop(q) || !gred_use_ecn(q) || + !INET_ECN_set_ce(skb)) { + q->stats.forced_drop++; + goto congestion_drop; + } + q->stats.forced_mark++; + break; + } + + if (gred_backlog(t, q, sch) + qdisc_pkt_len(skb) <= q->limit) { + q->backlog += qdisc_pkt_len(skb); + return qdisc_enqueue_tail(skb, sch); + } + + q->stats.pdrop++; +drop: + return qdisc_drop(skb, sch, to_free); + +congestion_drop: + qdisc_drop(skb, sch, to_free); + return NET_XMIT_CN; +} + +static struct sk_buff *gred_dequeue(struct Qdisc *sch) +{ + struct sk_buff *skb; + struct gred_sched *t = qdisc_priv(sch); + + skb = qdisc_dequeue_head(sch); + + if (skb) { + struct gred_sched_data *q; + u16 dp = tc_index_to_dp(skb); + + if (dp >= t->DPs || (q = t->tab[dp]) == NULL) { + net_warn_ratelimited("GRED: Unable to relocate VQ 0x%x after dequeue, screwing up backlog\n", + tc_index_to_dp(skb)); + } else { + q->backlog -= qdisc_pkt_len(skb); + + if (gred_wred_mode(t)) { + if (!sch->qstats.backlog) + red_start_of_idle_period(&t->wred_set); + } else { + if (!q->backlog) + red_start_of_idle_period(&q->vars); + } + } + + return skb; + } + + return NULL; +} + +static void gred_reset(struct Qdisc *sch) +{ + int i; + struct gred_sched *t = qdisc_priv(sch); + + qdisc_reset_queue(sch); + + for (i = 0; i < t->DPs; i++) { + struct gred_sched_data *q = t->tab[i]; + + if (!q) + continue; + + red_restart(&q->vars); + q->backlog = 0; + } +} + +static void gred_offload(struct Qdisc *sch, enum tc_gred_command command) +{ + struct gred_sched *table = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_gred_qopt_offload *opt = table->opt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + memset(opt, 0, sizeof(*opt)); + opt->command = command; + opt->handle = sch->handle; + opt->parent = sch->parent; + + if (command == TC_GRED_REPLACE) { + unsigned int i; + + opt->set.grio_on = gred_rio_mode(table); + opt->set.wred_on = gred_wred_mode(table); + opt->set.dp_cnt = table->DPs; + opt->set.dp_def = table->def; + + for (i = 0; i < table->DPs; i++) { + struct gred_sched_data *q = table->tab[i]; + + if (!q) + continue; + opt->set.tab[i].present = true; + opt->set.tab[i].limit = q->limit; + opt->set.tab[i].prio = q->prio; + opt->set.tab[i].min = q->parms.qth_min >> q->parms.Wlog; + opt->set.tab[i].max = q->parms.qth_max >> q->parms.Wlog; + opt->set.tab[i].is_ecn = gred_use_ecn(q); + opt->set.tab[i].is_harddrop = gred_use_harddrop(q); + 
opt->set.tab[i].probability = q->parms.max_P; + opt->set.tab[i].backlog = &q->backlog; + } + opt->set.qstats = &sch->qstats; + } + + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_GRED, opt); +} + +static int gred_offload_dump_stats(struct Qdisc *sch) +{ + struct gred_sched *table = qdisc_priv(sch); + struct tc_gred_qopt_offload *hw_stats; + u64 bytes = 0, packets = 0; + unsigned int i; + int ret; + + hw_stats = kzalloc(sizeof(*hw_stats), GFP_KERNEL); + if (!hw_stats) + return -ENOMEM; + + hw_stats->command = TC_GRED_STATS; + hw_stats->handle = sch->handle; + hw_stats->parent = sch->parent; + + for (i = 0; i < MAX_DPs; i++) { + gnet_stats_basic_sync_init(&hw_stats->stats.bstats[i]); + if (table->tab[i]) + hw_stats->stats.xstats[i] = &table->tab[i]->stats; + } + + ret = qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_GRED, hw_stats); + /* Even if driver returns failure adjust the stats - in case offload + * ended but driver still wants to adjust the values. + */ + sch_tree_lock(sch); + for (i = 0; i < MAX_DPs; i++) { + if (!table->tab[i]) + continue; + table->tab[i]->packetsin += u64_stats_read(&hw_stats->stats.bstats[i].packets); + table->tab[i]->bytesin += u64_stats_read(&hw_stats->stats.bstats[i].bytes); + table->tab[i]->backlog += hw_stats->stats.qstats[i].backlog; + + bytes += u64_stats_read(&hw_stats->stats.bstats[i].bytes); + packets += u64_stats_read(&hw_stats->stats.bstats[i].packets); + sch->qstats.qlen += hw_stats->stats.qstats[i].qlen; + sch->qstats.backlog += hw_stats->stats.qstats[i].backlog; + sch->qstats.drops += hw_stats->stats.qstats[i].drops; + sch->qstats.requeues += hw_stats->stats.qstats[i].requeues; + sch->qstats.overlimits += hw_stats->stats.qstats[i].overlimits; + } + _bstats_update(&sch->bstats, bytes, packets); + sch_tree_unlock(sch); + + kfree(hw_stats); + return ret; +} + +static inline void gred_destroy_vq(struct gred_sched_data *q) +{ + kfree(q); +} + +static int gred_change_table_def(struct Qdisc *sch, struct nlattr *dps, + struct netlink_ext_ack *extack) +{ + struct gred_sched *table = qdisc_priv(sch); + struct tc_gred_sopt *sopt; + bool red_flags_changed; + int i; + + if (!dps) + return -EINVAL; + + sopt = nla_data(dps); + + if (sopt->DPs > MAX_DPs) { + NL_SET_ERR_MSG_MOD(extack, "number of virtual queues too high"); + return -EINVAL; + } + if (sopt->DPs == 0) { + NL_SET_ERR_MSG_MOD(extack, + "number of virtual queues can't be 0"); + return -EINVAL; + } + if (sopt->def_DP >= sopt->DPs) { + NL_SET_ERR_MSG_MOD(extack, "default virtual queue above virtual queue count"); + return -EINVAL; + } + if (sopt->flags && gred_per_vq_red_flags_used(table)) { + NL_SET_ERR_MSG_MOD(extack, "can't set per-Qdisc RED flags when per-virtual queue flags are used"); + return -EINVAL; + } + + sch_tree_lock(sch); + table->DPs = sopt->DPs; + table->def = sopt->def_DP; + red_flags_changed = table->red_flags != sopt->flags; + table->red_flags = sopt->flags; + + /* + * Every entry point to GRED is synchronized with the above code + * and the DP is checked against DPs, i.e. shadowed VQs can no + * longer be found so we can unlock right here. 
+ */ + sch_tree_unlock(sch); + + if (sopt->grio) { + gred_enable_rio_mode(table); + gred_disable_wred_mode(table); + if (gred_wred_mode_check(sch)) + gred_enable_wred_mode(table); + } else { + gred_disable_rio_mode(table); + gred_disable_wred_mode(table); + } + + if (red_flags_changed) + for (i = 0; i < table->DPs; i++) + if (table->tab[i]) + table->tab[i]->red_flags = + table->red_flags & GRED_VQ_RED_FLAGS; + + for (i = table->DPs; i < MAX_DPs; i++) { + if (table->tab[i]) { + pr_warn("GRED: Warning: Destroying shadowed VQ 0x%x\n", + i); + gred_destroy_vq(table->tab[i]); + table->tab[i] = NULL; + } + } + + gred_offload(sch, TC_GRED_REPLACE); + return 0; +} + +static inline int gred_change_vq(struct Qdisc *sch, int dp, + struct tc_gred_qopt *ctl, int prio, + u8 *stab, u32 max_P, + struct gred_sched_data **prealloc, + struct netlink_ext_ack *extack) +{ + struct gred_sched *table = qdisc_priv(sch); + struct gred_sched_data *q = table->tab[dp]; + + if (!red_check_params(ctl->qth_min, ctl->qth_max, ctl->Wlog, ctl->Scell_log, stab)) { + NL_SET_ERR_MSG_MOD(extack, "invalid RED parameters"); + return -EINVAL; + } + + if (!q) { + table->tab[dp] = q = *prealloc; + *prealloc = NULL; + if (!q) + return -ENOMEM; + q->red_flags = table->red_flags & GRED_VQ_RED_FLAGS; + } + + q->DP = dp; + q->prio = prio; + if (ctl->limit > sch->limit) + q->limit = sch->limit; + else + q->limit = ctl->limit; + + if (q->backlog == 0) + red_end_of_idle_period(&q->vars); + + red_set_parms(&q->parms, + ctl->qth_min, ctl->qth_max, ctl->Wlog, ctl->Plog, + ctl->Scell_log, stab, max_P); + red_set_vars(&q->vars); + return 0; +} + +static const struct nla_policy gred_vq_policy[TCA_GRED_VQ_MAX + 1] = { + [TCA_GRED_VQ_DP] = { .type = NLA_U32 }, + [TCA_GRED_VQ_FLAGS] = { .type = NLA_U32 }, +}; + +static const struct nla_policy gred_vqe_policy[TCA_GRED_VQ_ENTRY_MAX + 1] = { + [TCA_GRED_VQ_ENTRY] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy gred_policy[TCA_GRED_MAX + 1] = { + [TCA_GRED_PARMS] = { .len = sizeof(struct tc_gred_qopt) }, + [TCA_GRED_STAB] = { .len = 256 }, + [TCA_GRED_DPS] = { .len = sizeof(struct tc_gred_sopt) }, + [TCA_GRED_MAX_P] = { .type = NLA_U32 }, + [TCA_GRED_LIMIT] = { .type = NLA_U32 }, + [TCA_GRED_VQ_LIST] = { .type = NLA_NESTED }, +}; + +static void gred_vq_apply(struct gred_sched *table, const struct nlattr *entry) +{ + struct nlattr *tb[TCA_GRED_VQ_MAX + 1]; + u32 dp; + + nla_parse_nested_deprecated(tb, TCA_GRED_VQ_MAX, entry, + gred_vq_policy, NULL); + + dp = nla_get_u32(tb[TCA_GRED_VQ_DP]); + + if (tb[TCA_GRED_VQ_FLAGS]) + table->tab[dp]->red_flags = nla_get_u32(tb[TCA_GRED_VQ_FLAGS]); +} + +static void gred_vqs_apply(struct gred_sched *table, struct nlattr *vqs) +{ + const struct nlattr *attr; + int rem; + + nla_for_each_nested(attr, vqs, rem) { + switch (nla_type(attr)) { + case TCA_GRED_VQ_ENTRY: + gred_vq_apply(table, attr); + break; + } + } +} + +static int gred_vq_validate(struct gred_sched *table, u32 cdp, + const struct nlattr *entry, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_GRED_VQ_MAX + 1]; + int err; + u32 dp; + + err = nla_parse_nested_deprecated(tb, TCA_GRED_VQ_MAX, entry, + gred_vq_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_GRED_VQ_DP]) { + NL_SET_ERR_MSG_MOD(extack, "Virtual queue with no index specified"); + return -EINVAL; + } + dp = nla_get_u32(tb[TCA_GRED_VQ_DP]); + if (dp >= table->DPs) { + NL_SET_ERR_MSG_MOD(extack, "Virtual queue with index out of bounds"); + return -EINVAL; + } + if (dp != cdp && !table->tab[dp]) { + 
NL_SET_ERR_MSG_MOD(extack, "Virtual queue not yet instantiated"); + return -EINVAL; + } + + if (tb[TCA_GRED_VQ_FLAGS]) { + u32 red_flags = nla_get_u32(tb[TCA_GRED_VQ_FLAGS]); + + if (table->red_flags && table->red_flags != red_flags) { + NL_SET_ERR_MSG_MOD(extack, "can't change per-virtual queue RED flags when per-Qdisc flags are used"); + return -EINVAL; + } + if (red_flags & ~GRED_VQ_RED_FLAGS) { + NL_SET_ERR_MSG_MOD(extack, + "invalid RED flags specified"); + return -EINVAL; + } + } + + return 0; +} + +static int gred_vqs_validate(struct gred_sched *table, u32 cdp, + struct nlattr *vqs, struct netlink_ext_ack *extack) +{ + const struct nlattr *attr; + int rem, err; + + err = nla_validate_nested_deprecated(vqs, TCA_GRED_VQ_ENTRY_MAX, + gred_vqe_policy, extack); + if (err < 0) + return err; + + nla_for_each_nested(attr, vqs, rem) { + switch (nla_type(attr)) { + case TCA_GRED_VQ_ENTRY: + err = gred_vq_validate(table, cdp, attr, extack); + if (err) + return err; + break; + default: + NL_SET_ERR_MSG_MOD(extack, "GRED_VQ_LIST can contain only entry attributes"); + return -EINVAL; + } + } + + if (rem > 0) { + NL_SET_ERR_MSG_MOD(extack, "Trailing data after parsing virtual queue list"); + return -EINVAL; + } + + return 0; +} + +static int gred_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct gred_sched *table = qdisc_priv(sch); + struct tc_gred_qopt *ctl; + struct nlattr *tb[TCA_GRED_MAX + 1]; + int err, prio = GRED_DEF_PRIO; + u8 *stab; + u32 max_P; + struct gred_sched_data *prealloc; + + err = nla_parse_nested_deprecated(tb, TCA_GRED_MAX, opt, gred_policy, + extack); + if (err < 0) + return err; + + if (tb[TCA_GRED_PARMS] == NULL && tb[TCA_GRED_STAB] == NULL) { + if (tb[TCA_GRED_LIMIT] != NULL) + sch->limit = nla_get_u32(tb[TCA_GRED_LIMIT]); + return gred_change_table_def(sch, tb[TCA_GRED_DPS], extack); + } + + if (tb[TCA_GRED_PARMS] == NULL || + tb[TCA_GRED_STAB] == NULL || + tb[TCA_GRED_LIMIT] != NULL) { + NL_SET_ERR_MSG_MOD(extack, "can't configure Qdisc and virtual queue at the same time"); + return -EINVAL; + } + + max_P = tb[TCA_GRED_MAX_P] ? 
nla_get_u32(tb[TCA_GRED_MAX_P]) : 0; + + ctl = nla_data(tb[TCA_GRED_PARMS]); + stab = nla_data(tb[TCA_GRED_STAB]); + + if (ctl->DP >= table->DPs) { + NL_SET_ERR_MSG_MOD(extack, "virtual queue index above virtual queue count"); + return -EINVAL; + } + + if (tb[TCA_GRED_VQ_LIST]) { + err = gred_vqs_validate(table, ctl->DP, tb[TCA_GRED_VQ_LIST], + extack); + if (err) + return err; + } + + if (gred_rio_mode(table)) { + if (ctl->prio == 0) { + int def_prio = GRED_DEF_PRIO; + + if (table->tab[table->def]) + def_prio = table->tab[table->def]->prio; + + printk(KERN_DEBUG "GRED: DP %u does not have a prio " + "setting default to %d\n", ctl->DP, def_prio); + + prio = def_prio; + } else + prio = ctl->prio; + } + + prealloc = kzalloc(sizeof(*prealloc), GFP_KERNEL); + sch_tree_lock(sch); + + err = gred_change_vq(sch, ctl->DP, ctl, prio, stab, max_P, &prealloc, + extack); + if (err < 0) + goto err_unlock_free; + + if (tb[TCA_GRED_VQ_LIST]) + gred_vqs_apply(table, tb[TCA_GRED_VQ_LIST]); + + if (gred_rio_mode(table)) { + gred_disable_wred_mode(table); + if (gred_wred_mode_check(sch)) + gred_enable_wred_mode(table); + } + + sch_tree_unlock(sch); + kfree(prealloc); + + gred_offload(sch, TC_GRED_REPLACE); + return 0; + +err_unlock_free: + sch_tree_unlock(sch); + kfree(prealloc); + return err; +} + +static int gred_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct gred_sched *table = qdisc_priv(sch); + struct nlattr *tb[TCA_GRED_MAX + 1]; + int err; + + if (!opt) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_GRED_MAX, opt, gred_policy, + extack); + if (err < 0) + return err; + + if (tb[TCA_GRED_PARMS] || tb[TCA_GRED_STAB]) { + NL_SET_ERR_MSG_MOD(extack, + "virtual queue configuration can't be specified at initialization time"); + return -EINVAL; + } + + if (tb[TCA_GRED_LIMIT]) + sch->limit = nla_get_u32(tb[TCA_GRED_LIMIT]); + else + sch->limit = qdisc_dev(sch)->tx_queue_len + * psched_mtu(qdisc_dev(sch)); + + if (qdisc_dev(sch)->netdev_ops->ndo_setup_tc) { + table->opt = kzalloc(sizeof(*table->opt), GFP_KERNEL); + if (!table->opt) + return -ENOMEM; + } + + return gred_change_table_def(sch, tb[TCA_GRED_DPS], extack); +} + +static int gred_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct gred_sched *table = qdisc_priv(sch); + struct nlattr *parms, *vqs, *opts = NULL; + int i; + u32 max_p[MAX_DPs]; + struct tc_gred_sopt sopt = { + .DPs = table->DPs, + .def_DP = table->def, + .grio = gred_rio_mode(table), + .flags = table->red_flags, + }; + + if (gred_offload_dump_stats(sch)) + goto nla_put_failure; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + if (nla_put(skb, TCA_GRED_DPS, sizeof(sopt), &sopt)) + goto nla_put_failure; + + for (i = 0; i < MAX_DPs; i++) { + struct gred_sched_data *q = table->tab[i]; + + max_p[i] = q ? 
q->parms.max_P : 0; + } + if (nla_put(skb, TCA_GRED_MAX_P, sizeof(max_p), max_p)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_GRED_LIMIT, sch->limit)) + goto nla_put_failure; + + /* Old style all-in-one dump of VQs */ + parms = nla_nest_start_noflag(skb, TCA_GRED_PARMS); + if (parms == NULL) + goto nla_put_failure; + + for (i = 0; i < MAX_DPs; i++) { + struct gred_sched_data *q = table->tab[i]; + struct tc_gred_qopt opt; + unsigned long qavg; + + memset(&opt, 0, sizeof(opt)); + + if (!q) { + /* hack -- fix at some point with proper message + This is how we indicate to tc that there is no VQ + at this DP */ + + opt.DP = MAX_DPs + i; + goto append_opt; + } + + opt.limit = q->limit; + opt.DP = q->DP; + opt.backlog = gred_backlog(table, q, sch); + opt.prio = q->prio; + opt.qth_min = q->parms.qth_min >> q->parms.Wlog; + opt.qth_max = q->parms.qth_max >> q->parms.Wlog; + opt.Wlog = q->parms.Wlog; + opt.Plog = q->parms.Plog; + opt.Scell_log = q->parms.Scell_log; + opt.early = q->stats.prob_drop; + opt.forced = q->stats.forced_drop; + opt.pdrop = q->stats.pdrop; + opt.packets = q->packetsin; + opt.bytesin = q->bytesin; + + if (gred_wred_mode(table)) + gred_load_wred_set(table, q); + + qavg = red_calc_qavg(&q->parms, &q->vars, + q->vars.qavg >> q->parms.Wlog); + opt.qave = qavg >> q->parms.Wlog; + +append_opt: + if (nla_append(skb, sizeof(opt), &opt) < 0) + goto nla_put_failure; + } + + nla_nest_end(skb, parms); + + /* Dump the VQs again, in more structured way */ + vqs = nla_nest_start_noflag(skb, TCA_GRED_VQ_LIST); + if (!vqs) + goto nla_put_failure; + + for (i = 0; i < MAX_DPs; i++) { + struct gred_sched_data *q = table->tab[i]; + struct nlattr *vq; + + if (!q) + continue; + + vq = nla_nest_start_noflag(skb, TCA_GRED_VQ_ENTRY); + if (!vq) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_GRED_VQ_DP, q->DP)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_GRED_VQ_FLAGS, q->red_flags)) + goto nla_put_failure; + + /* Stats */ + if (nla_put_u64_64bit(skb, TCA_GRED_VQ_STAT_BYTES, q->bytesin, + TCA_GRED_VQ_PAD)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_PACKETS, q->packetsin)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_BACKLOG, + gred_backlog(table, q, sch))) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_PROB_DROP, + q->stats.prob_drop)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_PROB_MARK, + q->stats.prob_mark)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_FORCED_DROP, + q->stats.forced_drop)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_FORCED_MARK, + q->stats.forced_mark)) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_GRED_VQ_STAT_PDROP, q->stats.pdrop)) + goto nla_put_failure; + + nla_nest_end(skb, vq); + } + nla_nest_end(skb, vqs); + + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static void gred_destroy(struct Qdisc *sch) +{ + struct gred_sched *table = qdisc_priv(sch); + int i; + + for (i = 0; i < table->DPs; i++) + gred_destroy_vq(table->tab[i]); + + gred_offload(sch, TC_GRED_DESTROY); + kfree(table->opt); +} + +static struct Qdisc_ops gred_qdisc_ops __read_mostly = { + .id = "gred", + .priv_size = sizeof(struct gred_sched), + .enqueue = gred_enqueue, + .dequeue = gred_dequeue, + .peek = qdisc_peek_head, + .init = gred_init, + .reset = gred_reset, + .destroy = gred_destroy, + .change = gred_change, + .dump = gred_dump, + .owner = THIS_MODULE, +}; + +static int __init gred_module_init(void) +{ 
+ return register_qdisc(&gred_qdisc_ops); +} + +static void __exit gred_module_exit(void) +{ + unregister_qdisc(&gred_qdisc_ops); +} + +module_init(gred_module_init) +module_exit(gred_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_hfsc.c b/net/sched/sch_hfsc.c new file mode 100644 index 000000000..54dddc2ff --- /dev/null +++ b/net/sched/sch_hfsc.c @@ -0,0 +1,1695 @@ +/* + * Copyright (c) 2003 Patrick McHardy, + * + * This program is free software; you can redistribute it and/or + * modify it under the terms of the GNU General Public License + * as published by the Free Software Foundation; either version 2 + * of the License, or (at your option) any later version. + * + * 2003-10-17 - Ported from altq + */ +/* + * Copyright (c) 1997-1999 Carnegie Mellon University. All Rights Reserved. + * + * Permission to use, copy, modify, and distribute this software and + * its documentation is hereby granted (including for commercial or + * for-profit use), provided that both the copyright notice and this + * permission notice appear in all copies of the software, derivative + * works, or modified versions, and any portions thereof. + * + * THIS SOFTWARE IS EXPERIMENTAL AND IS KNOWN TO HAVE BUGS, SOME OF + * WHICH MAY HAVE SERIOUS CONSEQUENCES. CARNEGIE MELLON PROVIDES THIS + * SOFTWARE IN ITS ``AS IS'' CONDITION, AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT + * OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH + * DAMAGE. + * + * Carnegie Mellon encourages (but does not require) users of this + * software to return any improvements or extensions that they make, + * and to grant Carnegie Mellon the rights to redistribute these + * changes without encumbrance. + */ +/* + * H-FSC is described in Proceedings of SIGCOMM'97, + * "A Hierarchical Fair Service Curve Algorithm for Link-Sharing, + * Real-Time and Priority Service" + * by Ion Stoica, Hui Zhang, and T. S. Eugene Ng. + * + * Oleg Cherevko added the upperlimit for link-sharing. + * when a class has an upperlimit, the fit-time is computed from the + * upperlimit service curve. the link-sharing scheduler does not schedule + * a class whose fit-time exceeds the current time. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * kernel internal service curve representation: + * coordinates are given by 64 bit unsigned integers. + * x-axis: unit is clock count. + * y-axis: unit is byte. + * + * The service curve parameters are converted to the internal + * representation. The slope values are scaled to avoid overflow. + * the inverse slope values as well as the y-projection of the 1st + * segment are kept in order to avoid 64-bit divide operations + * that are expensive on 32-bit architectures. 
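+ *
+ * Illustration: a configured curve (m1 [bps], d [us], m2 [bps]) is turned
+ * by sc2isc() below into sm1/ism1 (scaled slope and inverse slope of the
+ * 1st segment), dx = d2dx(d), dy = seg_x2y(dx, sm1) (bytes served by the
+ * 1st segment), and sm2/ism2 for the 2nd segment.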
+ */ + +struct internal_sc { + u64 sm1; /* scaled slope of the 1st segment */ + u64 ism1; /* scaled inverse-slope of the 1st segment */ + u64 dx; /* the x-projection of the 1st segment */ + u64 dy; /* the y-projection of the 1st segment */ + u64 sm2; /* scaled slope of the 2nd segment */ + u64 ism2; /* scaled inverse-slope of the 2nd segment */ +}; + +/* runtime service curve */ +struct runtime_sc { + u64 x; /* current starting position on x-axis */ + u64 y; /* current starting position on y-axis */ + u64 sm1; /* scaled slope of the 1st segment */ + u64 ism1; /* scaled inverse-slope of the 1st segment */ + u64 dx; /* the x-projection of the 1st segment */ + u64 dy; /* the y-projection of the 1st segment */ + u64 sm2; /* scaled slope of the 2nd segment */ + u64 ism2; /* scaled inverse-slope of the 2nd segment */ +}; + +enum hfsc_class_flags { + HFSC_RSC = 0x1, + HFSC_FSC = 0x2, + HFSC_USC = 0x4 +}; + +struct hfsc_class { + struct Qdisc_class_common cl_common; + + struct gnet_stats_basic_sync bstats; + struct gnet_stats_queue qstats; + struct net_rate_estimator __rcu *rate_est; + struct tcf_proto __rcu *filter_list; /* filter list */ + struct tcf_block *block; + unsigned int filter_cnt; /* filter count */ + unsigned int level; /* class level in hierarchy */ + + struct hfsc_sched *sched; /* scheduler data */ + struct hfsc_class *cl_parent; /* parent class */ + struct list_head siblings; /* sibling classes */ + struct list_head children; /* child classes */ + struct Qdisc *qdisc; /* leaf qdisc */ + + struct rb_node el_node; /* qdisc's eligible tree member */ + struct rb_root vt_tree; /* active children sorted by cl_vt */ + struct rb_node vt_node; /* parent's vt_tree member */ + struct rb_root cf_tree; /* active children sorted by cl_f */ + struct rb_node cf_node; /* parent's cf_heap member */ + + u64 cl_total; /* total work in bytes */ + u64 cl_cumul; /* cumulative work in bytes done by + real-time criteria */ + + u64 cl_d; /* deadline*/ + u64 cl_e; /* eligible time */ + u64 cl_vt; /* virtual time */ + u64 cl_f; /* time when this class will fit for + link-sharing, max(myf, cfmin) */ + u64 cl_myf; /* my fit-time (calculated from this + class's own upperlimit curve) */ + u64 cl_cfmin; /* earliest children's fit-time (used + with cl_myf to obtain cl_f) */ + u64 cl_cvtmin; /* minimal virtual time among the + children fit for link-sharing + (monotonic within a period) */ + u64 cl_vtadj; /* intra-period cumulative vt + adjustment */ + u64 cl_cvtoff; /* largest virtual time seen among + the children */ + + struct internal_sc cl_rsc; /* internal real-time service curve */ + struct internal_sc cl_fsc; /* internal fair service curve */ + struct internal_sc cl_usc; /* internal upperlimit service curve */ + struct runtime_sc cl_deadline; /* deadline curve */ + struct runtime_sc cl_eligible; /* eligible curve */ + struct runtime_sc cl_virtual; /* virtual curve */ + struct runtime_sc cl_ulimit; /* upperlimit curve */ + + u8 cl_flags; /* which curves are valid */ + u32 cl_vtperiod; /* vt period sequence number */ + u32 cl_parentperiod;/* parent's vt period sequence number*/ + u32 cl_nactive; /* number of active children */ +}; + +struct hfsc_sched { + u16 defcls; /* default class id */ + struct hfsc_class root; /* root class */ + struct Qdisc_class_hash clhash; /* class hash */ + struct rb_root eligible; /* eligible tree */ + struct qdisc_watchdog watchdog; /* watchdog timer */ +}; + +#define HT_INFINITY 0xffffffffffffffffULL /* infinite time value */ + + +/* + * eligible tree holds backlogged classes being 
sorted by their eligible times. + * there is one eligible tree per hfsc instance. + */ + +static void +eltree_insert(struct hfsc_class *cl) +{ + struct rb_node **p = &cl->sched->eligible.rb_node; + struct rb_node *parent = NULL; + struct hfsc_class *cl1; + + while (*p != NULL) { + parent = *p; + cl1 = rb_entry(parent, struct hfsc_class, el_node); + if (cl->cl_e >= cl1->cl_e) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&cl->el_node, parent, p); + rb_insert_color(&cl->el_node, &cl->sched->eligible); +} + +static inline void +eltree_remove(struct hfsc_class *cl) +{ + rb_erase(&cl->el_node, &cl->sched->eligible); +} + +static inline void +eltree_update(struct hfsc_class *cl) +{ + eltree_remove(cl); + eltree_insert(cl); +} + +/* find the class with the minimum deadline among the eligible classes */ +static inline struct hfsc_class * +eltree_get_mindl(struct hfsc_sched *q, u64 cur_time) +{ + struct hfsc_class *p, *cl = NULL; + struct rb_node *n; + + for (n = rb_first(&q->eligible); n != NULL; n = rb_next(n)) { + p = rb_entry(n, struct hfsc_class, el_node); + if (p->cl_e > cur_time) + break; + if (cl == NULL || p->cl_d < cl->cl_d) + cl = p; + } + return cl; +} + +/* find the class with minimum eligible time among the eligible classes */ +static inline struct hfsc_class * +eltree_get_minel(struct hfsc_sched *q) +{ + struct rb_node *n; + + n = rb_first(&q->eligible); + if (n == NULL) + return NULL; + return rb_entry(n, struct hfsc_class, el_node); +} + +/* + * vttree holds holds backlogged child classes being sorted by their virtual + * time. each intermediate class has one vttree. + */ +static void +vttree_insert(struct hfsc_class *cl) +{ + struct rb_node **p = &cl->cl_parent->vt_tree.rb_node; + struct rb_node *parent = NULL; + struct hfsc_class *cl1; + + while (*p != NULL) { + parent = *p; + cl1 = rb_entry(parent, struct hfsc_class, vt_node); + if (cl->cl_vt >= cl1->cl_vt) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&cl->vt_node, parent, p); + rb_insert_color(&cl->vt_node, &cl->cl_parent->vt_tree); +} + +static inline void +vttree_remove(struct hfsc_class *cl) +{ + rb_erase(&cl->vt_node, &cl->cl_parent->vt_tree); +} + +static inline void +vttree_update(struct hfsc_class *cl) +{ + vttree_remove(cl); + vttree_insert(cl); +} + +static inline struct hfsc_class * +vttree_firstfit(struct hfsc_class *cl, u64 cur_time) +{ + struct hfsc_class *p; + struct rb_node *n; + + for (n = rb_first(&cl->vt_tree); n != NULL; n = rb_next(n)) { + p = rb_entry(n, struct hfsc_class, vt_node); + if (p->cl_f <= cur_time) + return p; + } + return NULL; +} + +/* + * get the leaf class with the minimum vt in the hierarchy + */ +static struct hfsc_class * +vttree_get_minvt(struct hfsc_class *cl, u64 cur_time) +{ + /* if root-class's cfmin is bigger than cur_time nothing to do */ + if (cl->cl_cfmin > cur_time) + return NULL; + + while (cl->level > 0) { + cl = vttree_firstfit(cl, cur_time); + if (cl == NULL) + return NULL; + /* + * update parent's cl_cvtmin. 
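+	 * cl_cvtmin tracks the smallest virtual time among the children fit
+	 * for link-sharing in the current period, so this comparison only
+	 * ever raises it.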
+ */ + if (cl->cl_parent->cl_cvtmin < cl->cl_vt) + cl->cl_parent->cl_cvtmin = cl->cl_vt; + } + return cl; +} + +static void +cftree_insert(struct hfsc_class *cl) +{ + struct rb_node **p = &cl->cl_parent->cf_tree.rb_node; + struct rb_node *parent = NULL; + struct hfsc_class *cl1; + + while (*p != NULL) { + parent = *p; + cl1 = rb_entry(parent, struct hfsc_class, cf_node); + if (cl->cl_f >= cl1->cl_f) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&cl->cf_node, parent, p); + rb_insert_color(&cl->cf_node, &cl->cl_parent->cf_tree); +} + +static inline void +cftree_remove(struct hfsc_class *cl) +{ + rb_erase(&cl->cf_node, &cl->cl_parent->cf_tree); +} + +static inline void +cftree_update(struct hfsc_class *cl) +{ + cftree_remove(cl); + cftree_insert(cl); +} + +/* + * service curve support functions + * + * external service curve parameters + * m: bps + * d: us + * internal service curve parameters + * sm: (bytes/psched_us) << SM_SHIFT + * ism: (psched_us/byte) << ISM_SHIFT + * dx: psched_us + * + * The clock source resolution with ktime and PSCHED_SHIFT 10 is 1.024us. + * + * sm and ism are scaled in order to keep effective digits. + * SM_SHIFT and ISM_SHIFT are selected to keep at least 4 effective + * digits in decimal using the following table. + * + * bits/sec 100Kbps 1Mbps 10Mbps 100Mbps 1Gbps + * ------------+------------------------------------------------------- + * bytes/1.024us 12.8e-3 128e-3 1280e-3 12800e-3 128000e-3 + * + * 1.024us/byte 78.125 7.8125 0.78125 0.078125 0.0078125 + * + * So, for PSCHED_SHIFT 10 we need: SM_SHIFT 20, ISM_SHIFT 18. + */ +#define SM_SHIFT (30 - PSCHED_SHIFT) +#define ISM_SHIFT (8 + PSCHED_SHIFT) + +#define SM_MASK ((1ULL << SM_SHIFT) - 1) +#define ISM_MASK ((1ULL << ISM_SHIFT) - 1) + +static inline u64 +seg_x2y(u64 x, u64 sm) +{ + u64 y; + + /* + * compute + * y = x * sm >> SM_SHIFT + * but divide it for the upper and lower bits to avoid overflow + */ + y = (x >> SM_SHIFT) * sm + (((x & SM_MASK) * sm) >> SM_SHIFT); + return y; +} + +static inline u64 +seg_y2x(u64 y, u64 ism) +{ + u64 x; + + if (y == 0) + x = 0; + else if (ism == HT_INFINITY) + x = HT_INFINITY; + else { + x = (y >> ISM_SHIFT) * ism + + (((y & ISM_MASK) * ism) >> ISM_SHIFT); + } + return x; +} + +/* Convert m (bps) into sm (bytes/psched us) */ +static u64 +m2sm(u32 m) +{ + u64 sm; + + sm = ((u64)m << SM_SHIFT); + sm += PSCHED_TICKS_PER_SEC - 1; + do_div(sm, PSCHED_TICKS_PER_SEC); + return sm; +} + +/* convert m (bps) into ism (psched us/byte) */ +static u64 +m2ism(u32 m) +{ + u64 ism; + + if (m == 0) + ism = HT_INFINITY; + else { + ism = ((u64)PSCHED_TICKS_PER_SEC << ISM_SHIFT); + ism += m - 1; + do_div(ism, m); + } + return ism; +} + +/* convert d (us) into dx (psched us) */ +static u64 +d2dx(u32 d) +{ + u64 dx; + + dx = ((u64)d * PSCHED_TICKS_PER_SEC); + dx += USEC_PER_SEC - 1; + do_div(dx, USEC_PER_SEC); + return dx; +} + +/* convert sm (bytes/psched us) into m (bps) */ +static u32 +sm2m(u64 sm) +{ + u64 m; + + m = (sm * PSCHED_TICKS_PER_SEC) >> SM_SHIFT; + return (u32)m; +} + +/* convert dx (psched us) into d (us) */ +static u32 +dx2d(u64 dx) +{ + u64 d; + + d = dx * USEC_PER_SEC; + do_div(d, PSCHED_TICKS_PER_SEC); + return (u32)d; +} + +static void +sc2isc(struct tc_service_curve *sc, struct internal_sc *isc) +{ + isc->sm1 = m2sm(sc->m1); + isc->ism1 = m2ism(sc->m1); + isc->dx = d2dx(sc->d); + isc->dy = seg_x2y(isc->dx, isc->sm1); + isc->sm2 = m2sm(sc->m2); + isc->ism2 = m2ism(sc->m2); +} + +/* + * initialize the runtime service curve with the given internal 
+ * service curve starting at (x, y). + */ +static void +rtsc_init(struct runtime_sc *rtsc, struct internal_sc *isc, u64 x, u64 y) +{ + rtsc->x = x; + rtsc->y = y; + rtsc->sm1 = isc->sm1; + rtsc->ism1 = isc->ism1; + rtsc->dx = isc->dx; + rtsc->dy = isc->dy; + rtsc->sm2 = isc->sm2; + rtsc->ism2 = isc->ism2; +} + +/* + * calculate the y-projection of the runtime service curve by the + * given x-projection value + */ +static u64 +rtsc_y2x(struct runtime_sc *rtsc, u64 y) +{ + u64 x; + + if (y < rtsc->y) + x = rtsc->x; + else if (y <= rtsc->y + rtsc->dy) { + /* x belongs to the 1st segment */ + if (rtsc->dy == 0) + x = rtsc->x + rtsc->dx; + else + x = rtsc->x + seg_y2x(y - rtsc->y, rtsc->ism1); + } else { + /* x belongs to the 2nd segment */ + x = rtsc->x + rtsc->dx + + seg_y2x(y - rtsc->y - rtsc->dy, rtsc->ism2); + } + return x; +} + +static u64 +rtsc_x2y(struct runtime_sc *rtsc, u64 x) +{ + u64 y; + + if (x <= rtsc->x) + y = rtsc->y; + else if (x <= rtsc->x + rtsc->dx) + /* y belongs to the 1st segment */ + y = rtsc->y + seg_x2y(x - rtsc->x, rtsc->sm1); + else + /* y belongs to the 2nd segment */ + y = rtsc->y + rtsc->dy + + seg_x2y(x - rtsc->x - rtsc->dx, rtsc->sm2); + return y; +} + +/* + * update the runtime service curve by taking the minimum of the current + * runtime service curve and the service curve starting at (x, y). + */ +static void +rtsc_min(struct runtime_sc *rtsc, struct internal_sc *isc, u64 x, u64 y) +{ + u64 y1, y2, dx, dy; + u32 dsm; + + if (isc->sm1 <= isc->sm2) { + /* service curve is convex */ + y1 = rtsc_x2y(rtsc, x); + if (y1 < y) + /* the current rtsc is smaller */ + return; + rtsc->x = x; + rtsc->y = y; + return; + } + + /* + * service curve is concave + * compute the two y values of the current rtsc + * y1: at x + * y2: at (x + dx) + */ + y1 = rtsc_x2y(rtsc, x); + if (y1 <= y) { + /* rtsc is below isc, no change to rtsc */ + return; + } + + y2 = rtsc_x2y(rtsc, x + isc->dx); + if (y2 >= y + isc->dy) { + /* rtsc is above isc, replace rtsc by isc */ + rtsc->x = x; + rtsc->y = y; + rtsc->dx = isc->dx; + rtsc->dy = isc->dy; + return; + } + + /* + * the two curves intersect + * compute the offsets (dx, dy) using the reverse + * function of seg_x2y() + * seg_x2y(dx, sm1) == seg_x2y(dx, sm2) + (y1 - y) + */ + dx = (y1 - y) << SM_SHIFT; + dsm = isc->sm1 - isc->sm2; + do_div(dx, dsm); + /* + * check if (x, y1) belongs to the 1st segment of rtsc. + * if so, add the offset. + */ + if (rtsc->x + rtsc->dx > x) + dx += rtsc->x + rtsc->dx - x; + dy = seg_x2y(dx, isc->sm1); + + rtsc->x = x; + rtsc->y = y; + rtsc->dx = dx; + rtsc->dy = dy; +} + +static void +init_ed(struct hfsc_class *cl, unsigned int next_len) +{ + u64 cur_time = psched_get_time(); + + /* update the deadline curve */ + rtsc_min(&cl->cl_deadline, &cl->cl_rsc, cur_time, cl->cl_cumul); + + /* + * update the eligible curve. + * for concave, it is equal to the deadline curve. + * for convex, it is a linear curve with slope m2. 
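+	 * convex here means sm1 <= sm2; zeroing dx and dy drops the first
+	 * segment, leaving a single line of slope m2 from the same origin.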
+ */ + cl->cl_eligible = cl->cl_deadline; + if (cl->cl_rsc.sm1 <= cl->cl_rsc.sm2) { + cl->cl_eligible.dx = 0; + cl->cl_eligible.dy = 0; + } + + /* compute e and d */ + cl->cl_e = rtsc_y2x(&cl->cl_eligible, cl->cl_cumul); + cl->cl_d = rtsc_y2x(&cl->cl_deadline, cl->cl_cumul + next_len); + + eltree_insert(cl); +} + +static void +update_ed(struct hfsc_class *cl, unsigned int next_len) +{ + cl->cl_e = rtsc_y2x(&cl->cl_eligible, cl->cl_cumul); + cl->cl_d = rtsc_y2x(&cl->cl_deadline, cl->cl_cumul + next_len); + + eltree_update(cl); +} + +static inline void +update_d(struct hfsc_class *cl, unsigned int next_len) +{ + cl->cl_d = rtsc_y2x(&cl->cl_deadline, cl->cl_cumul + next_len); +} + +static inline void +update_cfmin(struct hfsc_class *cl) +{ + struct rb_node *n = rb_first(&cl->cf_tree); + struct hfsc_class *p; + + if (n == NULL) { + cl->cl_cfmin = 0; + return; + } + p = rb_entry(n, struct hfsc_class, cf_node); + cl->cl_cfmin = p->cl_f; +} + +static void +init_vf(struct hfsc_class *cl, unsigned int len) +{ + struct hfsc_class *max_cl; + struct rb_node *n; + u64 vt, f, cur_time; + int go_active; + + cur_time = 0; + go_active = 1; + for (; cl->cl_parent != NULL; cl = cl->cl_parent) { + if (go_active && cl->cl_nactive++ == 0) + go_active = 1; + else + go_active = 0; + + if (go_active) { + n = rb_last(&cl->cl_parent->vt_tree); + if (n != NULL) { + max_cl = rb_entry(n, struct hfsc_class, vt_node); + /* + * set vt to the average of the min and max + * classes. if the parent's period didn't + * change, don't decrease vt of the class. + */ + vt = max_cl->cl_vt; + if (cl->cl_parent->cl_cvtmin != 0) + vt = (cl->cl_parent->cl_cvtmin + vt)/2; + + if (cl->cl_parent->cl_vtperiod != + cl->cl_parentperiod || vt > cl->cl_vt) + cl->cl_vt = vt; + } else { + /* + * first child for a new parent backlog period. + * initialize cl_vt to the highest value seen + * among the siblings. this is analogous to + * what cur_time would provide in realtime case. + */ + cl->cl_vt = cl->cl_parent->cl_cvtoff; + cl->cl_parent->cl_cvtmin = 0; + } + + /* update the virtual curve */ + rtsc_min(&cl->cl_virtual, &cl->cl_fsc, cl->cl_vt, cl->cl_total); + cl->cl_vtadj = 0; + + cl->cl_vtperiod++; /* increment vt period */ + cl->cl_parentperiod = cl->cl_parent->cl_vtperiod; + if (cl->cl_parent->cl_nactive == 0) + cl->cl_parentperiod++; + cl->cl_f = 0; + + vttree_insert(cl); + cftree_insert(cl); + + if (cl->cl_flags & HFSC_USC) { + /* class has upper limit curve */ + if (cur_time == 0) + cur_time = psched_get_time(); + + /* update the ulimit curve */ + rtsc_min(&cl->cl_ulimit, &cl->cl_usc, cur_time, + cl->cl_total); + /* compute myf */ + cl->cl_myf = rtsc_y2x(&cl->cl_ulimit, + cl->cl_total); + } + } + + f = max(cl->cl_myf, cl->cl_cfmin); + if (f != cl->cl_f) { + cl->cl_f = f; + cftree_update(cl); + } + update_cfmin(cl->cl_parent); + } +} + +static void +update_vf(struct hfsc_class *cl, unsigned int len, u64 cur_time) +{ + u64 f; /* , myf_bound, delta; */ + int go_passive = 0; + + if (cl->qdisc->q.qlen == 0 && cl->cl_flags & HFSC_FSC) + go_passive = 1; + + for (; cl->cl_parent != NULL; cl = cl->cl_parent) { + cl->cl_total += len; + + if (!(cl->cl_flags & HFSC_FSC) || cl->cl_nactive == 0) + continue; + + if (go_passive && --cl->cl_nactive == 0) + go_passive = 1; + else + go_passive = 0; + + /* update vt */ + cl->cl_vt = rtsc_y2x(&cl->cl_virtual, cl->cl_total) + cl->cl_vtadj; + + /* + * if vt of the class is smaller than cvtmin, + * the class was skipped in the past due to non-fit. + * if so, we need to adjust vtadj. 
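+		 * cl_vt is computed above as the virtual curve value plus
+		 * cl_vtadj, so accumulating the shortfall here shifts later
+		 * computations up as well.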
+ */ + if (cl->cl_vt < cl->cl_parent->cl_cvtmin) { + cl->cl_vtadj += cl->cl_parent->cl_cvtmin - cl->cl_vt; + cl->cl_vt = cl->cl_parent->cl_cvtmin; + } + + if (go_passive) { + /* no more active child, going passive */ + + /* update cvtoff of the parent class */ + if (cl->cl_vt > cl->cl_parent->cl_cvtoff) + cl->cl_parent->cl_cvtoff = cl->cl_vt; + + /* remove this class from the vt tree */ + vttree_remove(cl); + + cftree_remove(cl); + update_cfmin(cl->cl_parent); + + continue; + } + + /* update the vt tree */ + vttree_update(cl); + + /* update f */ + if (cl->cl_flags & HFSC_USC) { + cl->cl_myf = rtsc_y2x(&cl->cl_ulimit, cl->cl_total); +#if 0 + cl->cl_myf = cl->cl_myfadj + rtsc_y2x(&cl->cl_ulimit, + cl->cl_total); + /* + * This code causes classes to stay way under their + * limit when multiple classes are used at gigabit + * speed. needs investigation. -kaber + */ + /* + * if myf lags behind by more than one clock tick + * from the current time, adjust myfadj to prevent + * a rate-limited class from going greedy. + * in a steady state under rate-limiting, myf + * fluctuates within one clock tick. + */ + myf_bound = cur_time - PSCHED_JIFFIE2US(1); + if (cl->cl_myf < myf_bound) { + delta = cur_time - cl->cl_myf; + cl->cl_myfadj += delta; + cl->cl_myf += delta; + } +#endif + } + + f = max(cl->cl_myf, cl->cl_cfmin); + if (f != cl->cl_f) { + cl->cl_f = f; + cftree_update(cl); + update_cfmin(cl->cl_parent); + } + } +} + +static unsigned int +qdisc_peek_len(struct Qdisc *sch) +{ + struct sk_buff *skb; + unsigned int len; + + skb = sch->ops->peek(sch); + if (unlikely(skb == NULL)) { + qdisc_warn_nonwc("qdisc_peek_len", sch); + return 0; + } + len = qdisc_pkt_len(skb); + + return len; +} + +static void +hfsc_adjust_levels(struct hfsc_class *cl) +{ + struct hfsc_class *p; + unsigned int level; + + do { + level = 0; + list_for_each_entry(p, &cl->children, siblings) { + if (p->level >= level) + level = p->level + 1; + } + cl->level = level; + } while ((cl = cl->cl_parent) != NULL); +} + +static inline struct hfsc_class * +hfsc_find_class(u32 classid, struct Qdisc *sch) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct Qdisc_class_common *clc; + + clc = qdisc_class_find(&q->clhash, classid); + if (clc == NULL) + return NULL; + return container_of(clc, struct hfsc_class, cl_common); +} + +static void +hfsc_change_rsc(struct hfsc_class *cl, struct tc_service_curve *rsc, + u64 cur_time) +{ + sc2isc(rsc, &cl->cl_rsc); + rtsc_init(&cl->cl_deadline, &cl->cl_rsc, cur_time, cl->cl_cumul); + cl->cl_eligible = cl->cl_deadline; + if (cl->cl_rsc.sm1 <= cl->cl_rsc.sm2) { + cl->cl_eligible.dx = 0; + cl->cl_eligible.dy = 0; + } + cl->cl_flags |= HFSC_RSC; +} + +static void +hfsc_change_fsc(struct hfsc_class *cl, struct tc_service_curve *fsc) +{ + sc2isc(fsc, &cl->cl_fsc); + rtsc_init(&cl->cl_virtual, &cl->cl_fsc, cl->cl_vt, cl->cl_total); + cl->cl_flags |= HFSC_FSC; +} + +static void +hfsc_change_usc(struct hfsc_class *cl, struct tc_service_curve *usc, + u64 cur_time) +{ + sc2isc(usc, &cl->cl_usc); + rtsc_init(&cl->cl_ulimit, &cl->cl_usc, cur_time, cl->cl_total); + cl->cl_flags |= HFSC_USC; +} + +static void +hfsc_upgrade_rt(struct hfsc_class *cl) +{ + cl->cl_fsc = cl->cl_rsc; + rtsc_init(&cl->cl_virtual, &cl->cl_fsc, cl->cl_vt, cl->cl_total); + cl->cl_flags |= HFSC_FSC; +} + +static const struct nla_policy hfsc_policy[TCA_HFSC_MAX + 1] = { + [TCA_HFSC_RSC] = { .len = sizeof(struct tc_service_curve) }, + [TCA_HFSC_FSC] = { .len = sizeof(struct tc_service_curve) }, + [TCA_HFSC_USC] = { .len = sizeof(struct 
tc_service_curve) }, +}; + +static int +hfsc_change_class(struct Qdisc *sch, u32 classid, u32 parentid, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl = (struct hfsc_class *)*arg; + struct hfsc_class *parent = NULL; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_HFSC_MAX + 1]; + struct tc_service_curve *rsc = NULL, *fsc = NULL, *usc = NULL; + u64 cur_time; + int err; + + if (opt == NULL) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_HFSC_MAX, opt, hfsc_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_HFSC_RSC]) { + rsc = nla_data(tb[TCA_HFSC_RSC]); + if (rsc->m1 == 0 && rsc->m2 == 0) + rsc = NULL; + } + + if (tb[TCA_HFSC_FSC]) { + fsc = nla_data(tb[TCA_HFSC_FSC]); + if (fsc->m1 == 0 && fsc->m2 == 0) + fsc = NULL; + } + + if (tb[TCA_HFSC_USC]) { + usc = nla_data(tb[TCA_HFSC_USC]); + if (usc->m1 == 0 && usc->m2 == 0) + usc = NULL; + } + + if (cl != NULL) { + int old_flags; + + if (parentid) { + if (cl->cl_parent && + cl->cl_parent->cl_common.classid != parentid) + return -EINVAL; + if (cl->cl_parent == NULL && parentid != TC_H_ROOT) + return -EINVAL; + } + cur_time = psched_get_time(); + + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) + return err; + } + + sch_tree_lock(sch); + old_flags = cl->cl_flags; + + if (rsc != NULL) + hfsc_change_rsc(cl, rsc, cur_time); + if (fsc != NULL) + hfsc_change_fsc(cl, fsc); + if (usc != NULL) + hfsc_change_usc(cl, usc, cur_time); + + if (cl->qdisc->q.qlen != 0) { + int len = qdisc_peek_len(cl->qdisc); + + if (cl->cl_flags & HFSC_RSC) { + if (old_flags & HFSC_RSC) + update_ed(cl, len); + else + init_ed(cl, len); + } + + if (cl->cl_flags & HFSC_FSC) { + if (old_flags & HFSC_FSC) + update_vf(cl, 0, cur_time); + else + init_vf(cl, len); + } + } + sch_tree_unlock(sch); + + return 0; + } + + if (parentid == TC_H_ROOT) + return -EEXIST; + + parent = &q->root; + if (parentid) { + parent = hfsc_find_class(parentid, sch); + if (parent == NULL) + return -ENOENT; + } + + if (classid == 0 || TC_H_MAJ(classid ^ sch->handle) != 0) + return -EINVAL; + if (hfsc_find_class(classid, sch)) + return -EEXIST; + + if (rsc == NULL && fsc == NULL) + return -EINVAL; + + cl = kzalloc(sizeof(struct hfsc_class), GFP_KERNEL); + if (cl == NULL) + return -ENOBUFS; + + err = tcf_block_get(&cl->block, &cl->filter_list, sch, extack); + if (err) { + kfree(cl); + return err; + } + + if (tca[TCA_RATE]) { + err = gen_new_estimator(&cl->bstats, NULL, &cl->rate_est, + NULL, true, tca[TCA_RATE]); + if (err) { + tcf_block_put(cl->block); + kfree(cl); + return err; + } + } + + if (rsc != NULL) + hfsc_change_rsc(cl, rsc, 0); + if (fsc != NULL) + hfsc_change_fsc(cl, fsc); + if (usc != NULL) + hfsc_change_usc(cl, usc, 0); + + cl->cl_common.classid = classid; + cl->sched = q; + cl->cl_parent = parent; + cl->qdisc = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + classid, NULL); + if (cl->qdisc == NULL) + cl->qdisc = &noop_qdisc; + else + qdisc_hash_add(cl->qdisc, true); + INIT_LIST_HEAD(&cl->children); + cl->vt_tree = RB_ROOT; + cl->cf_tree = RB_ROOT; + + sch_tree_lock(sch); + /* Check if the inner class is a misconfigured 'rt' */ + if (!(parent->cl_flags & HFSC_FSC) && parent != &q->root) { + NL_SET_ERR_MSG(extack, + "Forced curve change on parent 'rt' to 'sc'"); + hfsc_upgrade_rt(parent); + } + qdisc_class_hash_insert(&q->clhash, &cl->cl_common); + 
list_add_tail(&cl->siblings, &parent->children); + if (parent->level == 0) + qdisc_purge_queue(parent->qdisc); + hfsc_adjust_levels(parent); + sch_tree_unlock(sch); + + qdisc_class_hash_grow(sch, &q->clhash); + + *arg = (unsigned long)cl; + return 0; +} + +static void +hfsc_destroy_class(struct Qdisc *sch, struct hfsc_class *cl) +{ + struct hfsc_sched *q = qdisc_priv(sch); + + tcf_block_put(cl->block); + qdisc_put(cl->qdisc); + gen_kill_estimator(&cl->rate_est); + if (cl != &q->root) + kfree(cl); +} + +static int +hfsc_delete_class(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl = (struct hfsc_class *)arg; + + if (cl->level > 0 || cl->filter_cnt > 0 || cl == &q->root) + return -EBUSY; + + sch_tree_lock(sch); + + list_del(&cl->siblings); + hfsc_adjust_levels(cl->cl_parent); + + qdisc_purge_queue(cl->qdisc); + qdisc_class_hash_remove(&q->clhash, &cl->cl_common); + + sch_tree_unlock(sch); + + hfsc_destroy_class(sch, cl); + return 0; +} + +static struct hfsc_class * +hfsc_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *head, *cl; + struct tcf_result res; + struct tcf_proto *tcf; + int result; + + if (TC_H_MAJ(skb->priority ^ sch->handle) == 0 && + (cl = hfsc_find_class(skb->priority, sch)) != NULL) + if (cl->level == 0) + return cl; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + head = &q->root; + tcf = rcu_dereference_bh(q->root.filter_list); + while (tcf && (result = tcf_classify(skb, NULL, tcf, &res, false)) >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + cl = (struct hfsc_class *)res.class; + if (!cl) { + cl = hfsc_find_class(res.classid, sch); + if (!cl) + break; /* filter selected invalid classid */ + if (cl->level >= head->level) + break; /* filter may only point downwards */ + } + + if (cl->level == 0) + return cl; /* hit leaf class */ + + /* apply inner filter chain */ + tcf = rcu_dereference_bh(cl->filter_list); + head = cl; + } + + /* classification failed, try default class */ + cl = hfsc_find_class(TC_H_MAKE(TC_H_MAJ(sch->handle), q->defcls), sch); + if (cl == NULL || cl->level > 0) + return NULL; + + return cl; +} + +static int +hfsc_graft_class(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + + if (cl->level > 0) + return -EINVAL; + if (new == NULL) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + cl->cl_common.classid, NULL); + if (new == NULL) + new = &noop_qdisc; + } + + *old = qdisc_replace(sch, new, &cl->qdisc); + return 0; +} + +static struct Qdisc * +hfsc_class_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + + if (cl->level == 0) + return cl->qdisc; + + return NULL; +} + +static void +hfsc_qlen_notify(struct Qdisc *sch, unsigned long arg) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + + /* vttree is now handled in update_vf() so that update_vf(cl, 0, 0) + * needs to be called explicitly to remove a class from vttree. 
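+	 * with len == 0 and an empty leaf qdisc, update_vf() goes passive and
+	 * removes the class (and any ancestors left with no active children)
+	 * from their parents' vt and cf trees.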
+ */ + update_vf(cl, 0, 0); + if (cl->cl_flags & HFSC_RSC) + eltree_remove(cl); +} + +static unsigned long +hfsc_search_class(struct Qdisc *sch, u32 classid) +{ + return (unsigned long)hfsc_find_class(classid, sch); +} + +static unsigned long +hfsc_bind_tcf(struct Qdisc *sch, unsigned long parent, u32 classid) +{ + struct hfsc_class *p = (struct hfsc_class *)parent; + struct hfsc_class *cl = hfsc_find_class(classid, sch); + + if (cl != NULL) { + if (p != NULL && p->level <= cl->level) + return 0; + cl->filter_cnt++; + } + + return (unsigned long)cl; +} + +static void +hfsc_unbind_tcf(struct Qdisc *sch, unsigned long arg) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + + cl->filter_cnt--; +} + +static struct tcf_block *hfsc_tcf_block(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl = (struct hfsc_class *)arg; + + if (cl == NULL) + cl = &q->root; + + return cl->block; +} + +static int +hfsc_dump_sc(struct sk_buff *skb, int attr, struct internal_sc *sc) +{ + struct tc_service_curve tsc; + + tsc.m1 = sm2m(sc->sm1); + tsc.d = dx2d(sc->dx); + tsc.m2 = sm2m(sc->sm2); + if (nla_put(skb, attr, sizeof(tsc), &tsc)) + goto nla_put_failure; + + return skb->len; + + nla_put_failure: + return -1; +} + +static int +hfsc_dump_curves(struct sk_buff *skb, struct hfsc_class *cl) +{ + if ((cl->cl_flags & HFSC_RSC) && + (hfsc_dump_sc(skb, TCA_HFSC_RSC, &cl->cl_rsc) < 0)) + goto nla_put_failure; + + if ((cl->cl_flags & HFSC_FSC) && + (hfsc_dump_sc(skb, TCA_HFSC_FSC, &cl->cl_fsc) < 0)) + goto nla_put_failure; + + if ((cl->cl_flags & HFSC_USC) && + (hfsc_dump_sc(skb, TCA_HFSC_USC, &cl->cl_usc) < 0)) + goto nla_put_failure; + + return skb->len; + + nla_put_failure: + return -1; +} + +static int +hfsc_dump_class(struct Qdisc *sch, unsigned long arg, struct sk_buff *skb, + struct tcmsg *tcm) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + struct nlattr *nest; + + tcm->tcm_parent = cl->cl_parent ? 
cl->cl_parent->cl_common.classid : + TC_H_ROOT; + tcm->tcm_handle = cl->cl_common.classid; + if (cl->level == 0) + tcm->tcm_info = cl->qdisc->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (hfsc_dump_curves(skb, cl) < 0) + goto nla_put_failure; + return nla_nest_end(skb, nest); + + nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int +hfsc_dump_class_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct hfsc_class *cl = (struct hfsc_class *)arg; + struct tc_hfsc_stats xstats; + __u32 qlen; + + qdisc_qstats_qlen_backlog(cl->qdisc, &qlen, &cl->qstats.backlog); + xstats.level = cl->level; + xstats.period = cl->cl_vtperiod; + xstats.work = cl->cl_total; + xstats.rtwork = cl->cl_cumul; + + if (gnet_stats_copy_basic(d, NULL, &cl->bstats, true) < 0 || + gnet_stats_copy_rate_est(d, &cl->rate_est) < 0 || + gnet_stats_copy_queue(d, NULL, &cl->qstats, qlen) < 0) + return -1; + + return gnet_stats_copy_app(d, &xstats, sizeof(xstats)); +} + + + +static void +hfsc_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl; + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], + cl_common.hnode) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)cl, arg)) + return; + } + } +} + +static void +hfsc_schedule_watchdog(struct Qdisc *sch) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl; + u64 next_time = 0; + + cl = eltree_get_minel(q); + if (cl) + next_time = cl->cl_e; + if (q->root.cl_cfmin != 0) { + if (next_time == 0 || next_time > q->root.cl_cfmin) + next_time = q->root.cl_cfmin; + } + if (next_time) + qdisc_watchdog_schedule(&q->watchdog, next_time); +} + +static int +hfsc_init_qdisc(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct tc_hfsc_qopt *qopt; + int err; + + qdisc_watchdog_init(&q->watchdog, sch); + + if (!opt || nla_len(opt) < sizeof(*qopt)) + return -EINVAL; + qopt = nla_data(opt); + + q->defcls = qopt->defcls; + err = qdisc_class_hash_init(&q->clhash); + if (err < 0) + return err; + q->eligible = RB_ROOT; + + err = tcf_block_get(&q->root.block, &q->root.filter_list, sch, extack); + if (err) + return err; + + gnet_stats_basic_sync_init(&q->root.bstats); + q->root.cl_common.classid = sch->handle; + q->root.sched = q; + q->root.qdisc = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + sch->handle, NULL); + if (q->root.qdisc == NULL) + q->root.qdisc = &noop_qdisc; + else + qdisc_hash_add(q->root.qdisc, true); + INIT_LIST_HEAD(&q->root.children); + q->root.vt_tree = RB_ROOT; + q->root.cf_tree = RB_ROOT; + + qdisc_class_hash_insert(&q->clhash, &q->root.cl_common); + qdisc_class_hash_grow(sch, &q->clhash); + + return 0; +} + +static int +hfsc_change_qdisc(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct tc_hfsc_qopt *qopt; + + if (nla_len(opt) < sizeof(*qopt)) + return -EINVAL; + qopt = nla_data(opt); + + sch_tree_lock(sch); + q->defcls = qopt->defcls; + sch_tree_unlock(sch); + + return 0; +} + +static void +hfsc_reset_class(struct hfsc_class *cl) +{ + cl->cl_total = 0; + cl->cl_cumul = 0; + cl->cl_d = 0; + cl->cl_e = 0; + cl->cl_vt = 0; + cl->cl_vtadj = 0; + cl->cl_cvtmin = 0; + cl->cl_cvtoff = 0; + cl->cl_vtperiod = 0; + cl->cl_parentperiod = 0; + cl->cl_f = 0; + 
cl->cl_myf = 0; + cl->cl_cfmin = 0; + cl->cl_nactive = 0; + + cl->vt_tree = RB_ROOT; + cl->cf_tree = RB_ROOT; + qdisc_reset(cl->qdisc); + + if (cl->cl_flags & HFSC_RSC) + rtsc_init(&cl->cl_deadline, &cl->cl_rsc, 0, 0); + if (cl->cl_flags & HFSC_FSC) + rtsc_init(&cl->cl_virtual, &cl->cl_fsc, 0, 0); + if (cl->cl_flags & HFSC_USC) + rtsc_init(&cl->cl_ulimit, &cl->cl_usc, 0, 0); +} + +static void +hfsc_reset_qdisc(struct Qdisc *sch) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl; + unsigned int i; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], cl_common.hnode) + hfsc_reset_class(cl); + } + q->eligible = RB_ROOT; + qdisc_watchdog_cancel(&q->watchdog); +} + +static void +hfsc_destroy_qdisc(struct Qdisc *sch) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hlist_node *next; + struct hfsc_class *cl; + unsigned int i; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], cl_common.hnode) { + tcf_block_put(cl->block); + cl->block = NULL; + } + } + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry_safe(cl, next, &q->clhash.hash[i], + cl_common.hnode) + hfsc_destroy_class(sch, cl); + } + qdisc_class_hash_destroy(&q->clhash); + qdisc_watchdog_cancel(&q->watchdog); +} + +static int +hfsc_dump_qdisc(struct Qdisc *sch, struct sk_buff *skb) +{ + struct hfsc_sched *q = qdisc_priv(sch); + unsigned char *b = skb_tail_pointer(skb); + struct tc_hfsc_qopt qopt; + + qopt.defcls = q->defcls; + if (nla_put(skb, TCA_OPTIONS, sizeof(qopt), &qopt)) + goto nla_put_failure; + return skb->len; + + nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int +hfsc_enqueue(struct sk_buff *skb, struct Qdisc *sch, struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + struct hfsc_class *cl; + int err; + bool first; + + cl = hfsc_classify(skb, sch, &err); + if (cl == NULL) { + if (err & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return err; + } + + first = !cl->qdisc->q.qlen; + err = qdisc_enqueue(skb, cl->qdisc, to_free); + if (unlikely(err != NET_XMIT_SUCCESS)) { + if (net_xmit_drop_count(err)) { + cl->qstats.drops++; + qdisc_qstats_drop(sch); + } + return err; + } + + if (first) { + if (cl->cl_flags & HFSC_RSC) + init_ed(cl, len); + if (cl->cl_flags & HFSC_FSC) + init_vf(cl, len); + /* + * If this is the first packet, isolate the head so an eventual + * head drop before the first dequeue operation has no chance + * to invalidate the deadline. + */ + if (cl->cl_flags & HFSC_RSC) + cl->qdisc->ops->peek(cl->qdisc); + + } + + sch->qstats.backlog += len; + sch->q.qlen++; + + return NET_XMIT_SUCCESS; +} + +static struct sk_buff * +hfsc_dequeue(struct Qdisc *sch) +{ + struct hfsc_sched *q = qdisc_priv(sch); + struct hfsc_class *cl; + struct sk_buff *skb; + u64 cur_time; + unsigned int next_len; + int realtime = 0; + + if (sch->q.qlen == 0) + return NULL; + + cur_time = psched_get_time(); + + /* + * if there are eligible classes, use real-time criteria. + * find the class with the minimum deadline among + * the eligible classes. 
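+	 * a class is eligible when its eligible time cl_e is not later than
+	 * cur_time; eltree_get_mindl() picks the smallest cl_d among those.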
+ */ + cl = eltree_get_mindl(q, cur_time); + if (cl) { + realtime = 1; + } else { + /* + * use link-sharing criteria + * get the class with the minimum vt in the hierarchy + */ + cl = vttree_get_minvt(&q->root, cur_time); + if (cl == NULL) { + qdisc_qstats_overlimit(sch); + hfsc_schedule_watchdog(sch); + return NULL; + } + } + + skb = qdisc_dequeue_peeked(cl->qdisc); + if (skb == NULL) { + qdisc_warn_nonwc("HFSC", cl->qdisc); + return NULL; + } + + bstats_update(&cl->bstats, skb); + update_vf(cl, qdisc_pkt_len(skb), cur_time); + if (realtime) + cl->cl_cumul += qdisc_pkt_len(skb); + + if (cl->cl_flags & HFSC_RSC) { + if (cl->qdisc->q.qlen != 0) { + /* update ed */ + next_len = qdisc_peek_len(cl->qdisc); + if (realtime) + update_ed(cl, next_len); + else + update_d(cl, next_len); + } else { + /* the class becomes passive */ + eltree_remove(cl); + } + } + + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + + return skb; +} + +static const struct Qdisc_class_ops hfsc_class_ops = { + .change = hfsc_change_class, + .delete = hfsc_delete_class, + .graft = hfsc_graft_class, + .leaf = hfsc_class_leaf, + .qlen_notify = hfsc_qlen_notify, + .find = hfsc_search_class, + .bind_tcf = hfsc_bind_tcf, + .unbind_tcf = hfsc_unbind_tcf, + .tcf_block = hfsc_tcf_block, + .dump = hfsc_dump_class, + .dump_stats = hfsc_dump_class_stats, + .walk = hfsc_walk +}; + +static struct Qdisc_ops hfsc_qdisc_ops __read_mostly = { + .id = "hfsc", + .init = hfsc_init_qdisc, + .change = hfsc_change_qdisc, + .reset = hfsc_reset_qdisc, + .destroy = hfsc_destroy_qdisc, + .dump = hfsc_dump_qdisc, + .enqueue = hfsc_enqueue, + .dequeue = hfsc_dequeue, + .peek = qdisc_peek_dequeued, + .cl_ops = &hfsc_class_ops, + .priv_size = sizeof(struct hfsc_sched), + .owner = THIS_MODULE +}; + +static int __init +hfsc_init(void) +{ + return register_qdisc(&hfsc_qdisc_ops); +} + +static void __exit +hfsc_cleanup(void) +{ + unregister_qdisc(&hfsc_qdisc_ops); +} + +MODULE_LICENSE("GPL"); +module_init(hfsc_init); +module_exit(hfsc_cleanup); diff --git a/net/sched/sch_hhf.c b/net/sched/sch_hhf.c new file mode 100644 index 000000000..d26cd436c --- /dev/null +++ b/net/sched/sch_hhf.c @@ -0,0 +1,721 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* net/sched/sch_hhf.c Heavy-Hitter Filter (HHF) + * + * Copyright (C) 2013 Terry Lam + * Copyright (C) 2013 Nandita Dukkipati + */ + +#include +#include +#include +#include +#include +#include +#include + +/* Heavy-Hitter Filter (HHF) + * + * Principles : + * Flows are classified into two buckets: non-heavy-hitter and heavy-hitter + * buckets. Initially, a new flow starts as non-heavy-hitter. Once classified + * as heavy-hitter, it is immediately switched to the heavy-hitter bucket. + * The buckets are dequeued by a Weighted Deficit Round Robin (WDRR) scheduler, + * in which the heavy-hitter bucket is served with less weight. + * In other words, non-heavy-hitters (e.g., short bursts of critical traffic) + * are isolated from heavy-hitters (e.g., persistent bulk traffic) and also have + * higher share of bandwidth. + * + * To capture heavy-hitters, we use the "multi-stage filter" algorithm in the + * following paper: + * [EV02] C. Estan and G. Varghese, "New Directions in Traffic Measurement and + * Accounting", in ACM SIGCOMM, 2002. + * + * Conceptually, a multi-stage filter comprises k independent hash functions + * and k counter arrays. Packets are indexed into k counter arrays by k hash + * functions, respectively. The counters are then increased by the packet sizes. 
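+ * (For instance, with k = 4 each packet indexes one slot in each of the
+ * four arrays and contributes its byte length to those four counters.)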
+ * Therefore, + * - For a heavy-hitter flow: *all* of its k array counters must be large. + * - For a non-heavy-hitter flow: some of its k array counters can be large + * due to hash collision with other small flows; however, with high + * probability, not *all* k counters are large. + * + * By the design of the multi-stage filter algorithm, the false negative rate + * (heavy-hitters getting away uncaptured) is zero. However, the algorithm is + * susceptible to false positives (non-heavy-hitters mistakenly classified as + * heavy-hitters). + * Therefore, we also implement the following optimizations to reduce false + * positives by avoiding unnecessary increment of the counter values: + * - Optimization O1: once a heavy-hitter is identified, its bytes are not + * accounted in the array counters. This technique is called "shielding" + * in Section 3.3.1 of [EV02]. + * - Optimization O2: conservative update of counters + * (Section 3.3.2 of [EV02]), + * New counter value = max {old counter value, + * smallest counter value + packet bytes} + * + * Finally, we refresh the counters periodically since otherwise the counter + * values will keep accumulating. + * + * Once a flow is classified as heavy-hitter, we also save its per-flow state + * in an exact-matching flow table so that its subsequent packets can be + * dispatched to the heavy-hitter bucket accordingly. + * + * + * At a high level, this qdisc works as follows: + * Given a packet p: + * - If the flow-id of p (e.g., TCP 5-tuple) is already in the exact-matching + * heavy-hitter flow table, denoted table T, then send p to the heavy-hitter + * bucket. + * - Otherwise, forward p to the multi-stage filter, denoted filter F + * + If F decides that p belongs to a non-heavy-hitter flow, then send p + * to the non-heavy-hitter bucket. + * + Otherwise, if F decides that p belongs to a new heavy-hitter flow, + * then set up a new flow entry for the flow-id of p in the table T and + * send p to the heavy-hitter bucket. + * + * In this implementation: + * - T is a fixed-size hash-table with 1024 entries. Hash collision is + * resolved by linked-list chaining. + * - F has four counter arrays, each array containing 1024 32-bit counters. + * That means 4 * 1024 * 32 bits = 16KB of memory. + * - Since each array in F contains 1024 counters, 10 bits are sufficient to + * index into each array. + * Hence, instead of having four hash functions, we chop the 32-bit + * skb-hash into three 10-bit chunks, and the remaining 10-bit chunk is + * computed as XOR sum of those three chunks. + * - We need to clear the counter arrays periodically; however, directly + * memsetting 16KB of memory can lead to cache eviction and unwanted delay. + * So by representing each counter by a valid bit, we only need to reset + * 4K of 1 bit (i.e. 512 bytes) instead of 16KB of memory. + * - The Deficit Round Robin engine is taken from fq_codel implementation + * (net/sched/sch_fq_codel.c). Note that wdrr_bucket corresponds to + * fq_codel_flow in fq_codel implementation. 
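+ * - With the default non-HH weight of 2, the non-heavy-hitter bucket is
+ *   given two quanta of service for every quantum given to the
+ *   heavy-hitter bucket.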
+ * + */ + +/* Non-configurable parameters */ +#define HH_FLOWS_CNT 1024 /* number of entries in exact-matching table T */ +#define HHF_ARRAYS_CNT 4 /* number of arrays in multi-stage filter F */ +#define HHF_ARRAYS_LEN 1024 /* number of counters in each array of F */ +#define HHF_BIT_MASK_LEN 10 /* masking 10 bits */ +#define HHF_BIT_MASK 0x3FF /* bitmask of 10 bits */ + +#define WDRR_BUCKET_CNT 2 /* two buckets for Weighted DRR */ +enum wdrr_bucket_idx { + WDRR_BUCKET_FOR_HH = 0, /* bucket id for heavy-hitters */ + WDRR_BUCKET_FOR_NON_HH = 1 /* bucket id for non-heavy-hitters */ +}; + +#define hhf_time_before(a, b) \ + (typecheck(u32, a) && typecheck(u32, b) && ((s32)((a) - (b)) < 0)) + +/* Heavy-hitter per-flow state */ +struct hh_flow_state { + u32 hash_id; /* hash of flow-id (e.g. TCP 5-tuple) */ + u32 hit_timestamp; /* last time heavy-hitter was seen */ + struct list_head flowchain; /* chaining under hash collision */ +}; + +/* Weighted Deficit Round Robin (WDRR) scheduler */ +struct wdrr_bucket { + struct sk_buff *head; + struct sk_buff *tail; + struct list_head bucketchain; + int deficit; +}; + +struct hhf_sched_data { + struct wdrr_bucket buckets[WDRR_BUCKET_CNT]; + siphash_key_t perturbation; /* hash perturbation */ + u32 quantum; /* psched_mtu(qdisc_dev(sch)); */ + u32 drop_overlimit; /* number of times max qdisc packet + * limit was hit + */ + struct list_head *hh_flows; /* table T (currently active HHs) */ + u32 hh_flows_limit; /* max active HH allocs */ + u32 hh_flows_overlimit; /* num of disallowed HH allocs */ + u32 hh_flows_total_cnt; /* total admitted HHs */ + u32 hh_flows_current_cnt; /* total current HHs */ + u32 *hhf_arrays[HHF_ARRAYS_CNT]; /* HH filter F */ + u32 hhf_arrays_reset_timestamp; /* last time hhf_arrays + * was reset + */ + unsigned long *hhf_valid_bits[HHF_ARRAYS_CNT]; /* shadow valid bits + * of hhf_arrays + */ + /* Similar to the "new_flows" vs. "old_flows" concept in fq_codel DRR */ + struct list_head new_buckets; /* list of new buckets */ + struct list_head old_buckets; /* list of old buckets */ + + /* Configurable HHF parameters */ + u32 hhf_reset_timeout; /* interval to reset counter + * arrays in filter F + * (default 40ms) + */ + u32 hhf_admit_bytes; /* counter thresh to classify as + * HH (default 128KB). + * With these default values, + * 128KB / 40ms = 25 Mbps + * i.e., we expect to capture HHs + * sending > 25 Mbps. + */ + u32 hhf_evict_timeout; /* aging threshold to evict idle + * HHs out of table T. This should + * be large enough to avoid + * reordering during HH eviction. + * (default 1s) + */ + u32 hhf_non_hh_weight; /* WDRR weight for non-HHs + * (default 2, + * i.e., non-HH : HH = 2 : 1) + */ +}; + +static u32 hhf_time_stamp(void) +{ + return jiffies; +} + +/* Looks up a heavy-hitter flow in a chaining list of table T. */ +static struct hh_flow_state *seek_list(const u32 hash, + struct list_head *head, + struct hhf_sched_data *q) +{ + struct hh_flow_state *flow, *next; + u32 now = hhf_time_stamp(); + + if (list_empty(head)) + return NULL; + + list_for_each_entry_safe(flow, next, head, flowchain) { + u32 prev = flow->hit_timestamp + q->hhf_evict_timeout; + + if (hhf_time_before(prev, now)) { + /* Delete expired heavy-hitters, but preserve one entry + * to avoid kzalloc() when next time this slot is hit. 
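+			 * The last entry of the chain is kept in place;
+			 * alloc_new_hh() reuses any expired entry it finds
+			 * before allocating a new one.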
+ */ + if (list_is_last(&flow->flowchain, head)) + return NULL; + list_del(&flow->flowchain); + kfree(flow); + q->hh_flows_current_cnt--; + } else if (flow->hash_id == hash) { + return flow; + } + } + return NULL; +} + +/* Returns a flow state entry for a new heavy-hitter. Either reuses an expired + * entry or dynamically alloc a new entry. + */ +static struct hh_flow_state *alloc_new_hh(struct list_head *head, + struct hhf_sched_data *q) +{ + struct hh_flow_state *flow; + u32 now = hhf_time_stamp(); + + if (!list_empty(head)) { + /* Find an expired heavy-hitter flow entry. */ + list_for_each_entry(flow, head, flowchain) { + u32 prev = flow->hit_timestamp + q->hhf_evict_timeout; + + if (hhf_time_before(prev, now)) + return flow; + } + } + + if (q->hh_flows_current_cnt >= q->hh_flows_limit) { + q->hh_flows_overlimit++; + return NULL; + } + /* Create new entry. */ + flow = kzalloc(sizeof(struct hh_flow_state), GFP_ATOMIC); + if (!flow) + return NULL; + + q->hh_flows_current_cnt++; + INIT_LIST_HEAD(&flow->flowchain); + list_add_tail(&flow->flowchain, head); + + return flow; +} + +/* Assigns packets to WDRR buckets. Implements a multi-stage filter to + * classify heavy-hitters. + */ +static enum wdrr_bucket_idx hhf_classify(struct sk_buff *skb, struct Qdisc *sch) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + u32 tmp_hash, hash; + u32 xorsum, filter_pos[HHF_ARRAYS_CNT], flow_pos; + struct hh_flow_state *flow; + u32 pkt_len, min_hhf_val; + int i; + u32 prev; + u32 now = hhf_time_stamp(); + + /* Reset the HHF counter arrays if this is the right time. */ + prev = q->hhf_arrays_reset_timestamp + q->hhf_reset_timeout; + if (hhf_time_before(prev, now)) { + for (i = 0; i < HHF_ARRAYS_CNT; i++) + bitmap_zero(q->hhf_valid_bits[i], HHF_ARRAYS_LEN); + q->hhf_arrays_reset_timestamp = now; + } + + /* Get hashed flow-id of the skb. */ + hash = skb_get_hash_perturb(skb, &q->perturbation); + + /* Check if this packet belongs to an already established HH flow. */ + flow_pos = hash & HHF_BIT_MASK; + flow = seek_list(hash, &q->hh_flows[flow_pos], q); + if (flow) { /* found its HH flow */ + flow->hit_timestamp = now; + return WDRR_BUCKET_FOR_HH; + } + + /* Now pass the packet through the multi-stage filter. */ + tmp_hash = hash; + xorsum = 0; + for (i = 0; i < HHF_ARRAYS_CNT - 1; i++) { + /* Split the skb_hash into three 10-bit chunks. */ + filter_pos[i] = tmp_hash & HHF_BIT_MASK; + xorsum ^= filter_pos[i]; + tmp_hash >>= HHF_BIT_MASK_LEN; + } + /* The last chunk is computed as XOR sum of other chunks. */ + filter_pos[HHF_ARRAYS_CNT - 1] = xorsum ^ tmp_hash; + + pkt_len = qdisc_pkt_len(skb); + min_hhf_val = ~0U; + for (i = 0; i < HHF_ARRAYS_CNT; i++) { + u32 val; + + if (!test_bit(filter_pos[i], q->hhf_valid_bits[i])) { + q->hhf_arrays[i][filter_pos[i]] = 0; + __set_bit(filter_pos[i], q->hhf_valid_bits[i]); + } + + val = q->hhf_arrays[i][filter_pos[i]] + pkt_len; + if (min_hhf_val > val) + min_hhf_val = val; + } + + /* Found a new HH iff all counter values > HH admit threshold. */ + if (min_hhf_val > q->hhf_admit_bytes) { + /* Just captured a new heavy-hitter. */ + flow = alloc_new_hh(&q->hh_flows[flow_pos], q); + if (!flow) /* memory alloc problem */ + return WDRR_BUCKET_FOR_NON_HH; + flow->hash_id = hash; + flow->hit_timestamp = now; + q->hh_flows_total_cnt++; + + /* By returning without updating counters in q->hhf_arrays, + * we implicitly implement "shielding" (see Optimization O1). + */ + return WDRR_BUCKET_FOR_HH; + } + + /* Conservative update of HHF arrays (see Optimization O2). 
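+	 * Each counter is raised to at most min_hhf_val (the smallest counter
+	 * plus this packet's length), never beyond it.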
*/ + for (i = 0; i < HHF_ARRAYS_CNT; i++) { + if (q->hhf_arrays[i][filter_pos[i]] < min_hhf_val) + q->hhf_arrays[i][filter_pos[i]] = min_hhf_val; + } + return WDRR_BUCKET_FOR_NON_HH; +} + +/* Removes one skb from head of bucket. */ +static struct sk_buff *dequeue_head(struct wdrr_bucket *bucket) +{ + struct sk_buff *skb = bucket->head; + + bucket->head = skb->next; + skb_mark_not_on_list(skb); + return skb; +} + +/* Tail-adds skb to bucket. */ +static void bucket_add(struct wdrr_bucket *bucket, struct sk_buff *skb) +{ + if (bucket->head == NULL) + bucket->head = skb; + else + bucket->tail->next = skb; + bucket->tail = skb; + skb->next = NULL; +} + +static unsigned int hhf_drop(struct Qdisc *sch, struct sk_buff **to_free) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + struct wdrr_bucket *bucket; + + /* Always try to drop from heavy-hitters first. */ + bucket = &q->buckets[WDRR_BUCKET_FOR_HH]; + if (!bucket->head) + bucket = &q->buckets[WDRR_BUCKET_FOR_NON_HH]; + + if (bucket->head) { + struct sk_buff *skb = dequeue_head(bucket); + + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + qdisc_drop(skb, sch, to_free); + } + + /* Return id of the bucket from which the packet was dropped. */ + return bucket - q->buckets; +} + +static int hhf_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + enum wdrr_bucket_idx idx; + struct wdrr_bucket *bucket; + unsigned int prev_backlog; + + idx = hhf_classify(skb, sch); + + bucket = &q->buckets[idx]; + bucket_add(bucket, skb); + qdisc_qstats_backlog_inc(sch, skb); + + if (list_empty(&bucket->bucketchain)) { + unsigned int weight; + + /* The logic of new_buckets vs. old_buckets is the same as + * new_flows vs. old_flows in the implementation of fq_codel, + * i.e., short bursts of non-HHs should have strict priority. + */ + if (idx == WDRR_BUCKET_FOR_HH) { + /* Always move heavy-hitters to old bucket. */ + weight = 1; + list_add_tail(&bucket->bucketchain, &q->old_buckets); + } else { + weight = q->hhf_non_hh_weight; + list_add_tail(&bucket->bucketchain, &q->new_buckets); + } + bucket->deficit = weight * q->quantum; + } + if (++sch->q.qlen <= sch->limit) + return NET_XMIT_SUCCESS; + + prev_backlog = sch->qstats.backlog; + q->drop_overlimit++; + /* Return Congestion Notification only if we dropped a packet from this + * bucket. + */ + if (hhf_drop(sch, to_free) == idx) + return NET_XMIT_CN; + + /* As we dropped a packet, better let upper stack know this. */ + qdisc_tree_reduce_backlog(sch, 1, prev_backlog - sch->qstats.backlog); + return NET_XMIT_SUCCESS; +} + +static struct sk_buff *hhf_dequeue(struct Qdisc *sch) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb = NULL; + struct wdrr_bucket *bucket; + struct list_head *head; + +begin: + head = &q->new_buckets; + if (list_empty(head)) { + head = &q->old_buckets; + if (list_empty(head)) + return NULL; + } + bucket = list_first_entry(head, struct wdrr_bucket, bucketchain); + + if (bucket->deficit <= 0) { + int weight = (bucket - q->buckets == WDRR_BUCKET_FOR_HH) ? + 1 : q->hhf_non_hh_weight; + + bucket->deficit += weight * q->quantum; + list_move_tail(&bucket->bucketchain, &q->old_buckets); + goto begin; + } + + if (bucket->head) { + skb = dequeue_head(bucket); + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + } + + if (!skb) { + /* Force a pass through old_buckets to prevent starvation. 
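+		 * An empty bucket taken from new_buckets is moved to the tail
+		 * of old_buckets (when old buckets exist) instead of being
+		 * unlinked.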
*/ + if ((head == &q->new_buckets) && !list_empty(&q->old_buckets)) + list_move_tail(&bucket->bucketchain, &q->old_buckets); + else + list_del_init(&bucket->bucketchain); + goto begin; + } + qdisc_bstats_update(sch, skb); + bucket->deficit -= qdisc_pkt_len(skb); + + return skb; +} + +static void hhf_reset(struct Qdisc *sch) +{ + struct sk_buff *skb; + + while ((skb = hhf_dequeue(sch)) != NULL) + rtnl_kfree_skbs(skb, skb); +} + +static void hhf_destroy(struct Qdisc *sch) +{ + int i; + struct hhf_sched_data *q = qdisc_priv(sch); + + for (i = 0; i < HHF_ARRAYS_CNT; i++) { + kvfree(q->hhf_arrays[i]); + kvfree(q->hhf_valid_bits[i]); + } + + if (!q->hh_flows) + return; + + for (i = 0; i < HH_FLOWS_CNT; i++) { + struct hh_flow_state *flow, *next; + struct list_head *head = &q->hh_flows[i]; + + if (list_empty(head)) + continue; + list_for_each_entry_safe(flow, next, head, flowchain) { + list_del(&flow->flowchain); + kfree(flow); + } + } + kvfree(q->hh_flows); +} + +static const struct nla_policy hhf_policy[TCA_HHF_MAX + 1] = { + [TCA_HHF_BACKLOG_LIMIT] = { .type = NLA_U32 }, + [TCA_HHF_QUANTUM] = { .type = NLA_U32 }, + [TCA_HHF_HH_FLOWS_LIMIT] = { .type = NLA_U32 }, + [TCA_HHF_RESET_TIMEOUT] = { .type = NLA_U32 }, + [TCA_HHF_ADMIT_BYTES] = { .type = NLA_U32 }, + [TCA_HHF_EVICT_TIMEOUT] = { .type = NLA_U32 }, + [TCA_HHF_NON_HH_WEIGHT] = { .type = NLA_U32 }, +}; + +static int hhf_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_HHF_MAX + 1]; + unsigned int qlen, prev_backlog; + int err; + u64 non_hh_quantum; + u32 new_quantum = q->quantum; + u32 new_hhf_non_hh_weight = q->hhf_non_hh_weight; + + err = nla_parse_nested_deprecated(tb, TCA_HHF_MAX, opt, hhf_policy, + NULL); + if (err < 0) + return err; + + if (tb[TCA_HHF_QUANTUM]) + new_quantum = nla_get_u32(tb[TCA_HHF_QUANTUM]); + + if (tb[TCA_HHF_NON_HH_WEIGHT]) + new_hhf_non_hh_weight = nla_get_u32(tb[TCA_HHF_NON_HH_WEIGHT]); + + non_hh_quantum = (u64)new_quantum * new_hhf_non_hh_weight; + if (non_hh_quantum == 0 || non_hh_quantum > INT_MAX) + return -EINVAL; + + sch_tree_lock(sch); + + if (tb[TCA_HHF_BACKLOG_LIMIT]) + sch->limit = nla_get_u32(tb[TCA_HHF_BACKLOG_LIMIT]); + + q->quantum = new_quantum; + q->hhf_non_hh_weight = new_hhf_non_hh_weight; + + if (tb[TCA_HHF_HH_FLOWS_LIMIT]) + q->hh_flows_limit = nla_get_u32(tb[TCA_HHF_HH_FLOWS_LIMIT]); + + if (tb[TCA_HHF_RESET_TIMEOUT]) { + u32 us = nla_get_u32(tb[TCA_HHF_RESET_TIMEOUT]); + + q->hhf_reset_timeout = usecs_to_jiffies(us); + } + + if (tb[TCA_HHF_ADMIT_BYTES]) + q->hhf_admit_bytes = nla_get_u32(tb[TCA_HHF_ADMIT_BYTES]); + + if (tb[TCA_HHF_EVICT_TIMEOUT]) { + u32 us = nla_get_u32(tb[TCA_HHF_EVICT_TIMEOUT]); + + q->hhf_evict_timeout = usecs_to_jiffies(us); + } + + qlen = sch->q.qlen; + prev_backlog = sch->qstats.backlog; + while (sch->q.qlen > sch->limit) { + struct sk_buff *skb = hhf_dequeue(sch); + + rtnl_kfree_skbs(skb, skb); + } + qdisc_tree_reduce_backlog(sch, qlen - sch->q.qlen, + prev_backlog - sch->qstats.backlog); + + sch_tree_unlock(sch); + return 0; +} + +static int hhf_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + int i; + + sch->limit = 1000; + q->quantum = psched_mtu(qdisc_dev(sch)); + get_random_bytes(&q->perturbation, sizeof(q->perturbation)); + INIT_LIST_HEAD(&q->new_buckets); + INIT_LIST_HEAD(&q->old_buckets); + + /* Configurable HHF parameters */ + q->hhf_reset_timeout = HZ / 25; /* 40 ms */ + 
q->hhf_admit_bytes = 131072; /* 128 KB */ + q->hhf_evict_timeout = HZ; /* 1 sec */ + q->hhf_non_hh_weight = 2; + + if (opt) { + int err = hhf_change(sch, opt, extack); + + if (err) + return err; + } + + if (!q->hh_flows) { + /* Initialize heavy-hitter flow table. */ + q->hh_flows = kvcalloc(HH_FLOWS_CNT, sizeof(struct list_head), + GFP_KERNEL); + if (!q->hh_flows) + return -ENOMEM; + for (i = 0; i < HH_FLOWS_CNT; i++) + INIT_LIST_HEAD(&q->hh_flows[i]); + + /* Cap max active HHs at twice len of hh_flows table. */ + q->hh_flows_limit = 2 * HH_FLOWS_CNT; + q->hh_flows_overlimit = 0; + q->hh_flows_total_cnt = 0; + q->hh_flows_current_cnt = 0; + + /* Initialize heavy-hitter filter arrays. */ + for (i = 0; i < HHF_ARRAYS_CNT; i++) { + q->hhf_arrays[i] = kvcalloc(HHF_ARRAYS_LEN, + sizeof(u32), + GFP_KERNEL); + if (!q->hhf_arrays[i]) { + /* Note: hhf_destroy() will be called + * by our caller. + */ + return -ENOMEM; + } + } + q->hhf_arrays_reset_timestamp = hhf_time_stamp(); + + /* Initialize valid bits of heavy-hitter filter arrays. */ + for (i = 0; i < HHF_ARRAYS_CNT; i++) { + q->hhf_valid_bits[i] = kvzalloc(HHF_ARRAYS_LEN / + BITS_PER_BYTE, GFP_KERNEL); + if (!q->hhf_valid_bits[i]) { + /* Note: hhf_destroy() will be called + * by our caller. + */ + return -ENOMEM; + } + } + + /* Initialize Weighted DRR buckets. */ + for (i = 0; i < WDRR_BUCKET_CNT; i++) { + struct wdrr_bucket *bucket = q->buckets + i; + + INIT_LIST_HEAD(&bucket->bucketchain); + } + } + + return 0; +} + +static int hhf_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_HHF_BACKLOG_LIMIT, sch->limit) || + nla_put_u32(skb, TCA_HHF_QUANTUM, q->quantum) || + nla_put_u32(skb, TCA_HHF_HH_FLOWS_LIMIT, q->hh_flows_limit) || + nla_put_u32(skb, TCA_HHF_RESET_TIMEOUT, + jiffies_to_usecs(q->hhf_reset_timeout)) || + nla_put_u32(skb, TCA_HHF_ADMIT_BYTES, q->hhf_admit_bytes) || + nla_put_u32(skb, TCA_HHF_EVICT_TIMEOUT, + jiffies_to_usecs(q->hhf_evict_timeout)) || + nla_put_u32(skb, TCA_HHF_NON_HH_WEIGHT, q->hhf_non_hh_weight)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + return -1; +} + +static int hhf_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct hhf_sched_data *q = qdisc_priv(sch); + struct tc_hhf_xstats st = { + .drop_overlimit = q->drop_overlimit, + .hh_overlimit = q->hh_flows_overlimit, + .hh_tot_count = q->hh_flows_total_cnt, + .hh_cur_count = q->hh_flows_current_cnt, + }; + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static struct Qdisc_ops hhf_qdisc_ops __read_mostly = { + .id = "hhf", + .priv_size = sizeof(struct hhf_sched_data), + + .enqueue = hhf_enqueue, + .dequeue = hhf_dequeue, + .peek = qdisc_peek_dequeued, + .init = hhf_init, + .reset = hhf_reset, + .destroy = hhf_destroy, + .change = hhf_change, + .dump = hhf_dump, + .dump_stats = hhf_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init hhf_module_init(void) +{ + return register_qdisc(&hhf_qdisc_ops); +} + +static void __exit hhf_module_exit(void) +{ + unregister_qdisc(&hhf_qdisc_ops); +} + +module_init(hhf_module_init) +module_exit(hhf_module_exit) +MODULE_AUTHOR("Terry Lam"); +MODULE_AUTHOR("Nandita Dukkipati"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("Heavy-Hitter Filter (HHF)"); diff --git a/net/sched/sch_htb.c b/net/sched/sch_htb.c new file mode 100644 index 000000000..67b1879ea --- /dev/null +++ 
b/net/sched/sch_htb.c @@ -0,0 +1,2178 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_htb.c Hierarchical token bucket, feed tree version + * + * Authors: Martin Devera, + * + * Credits (in time order) for older HTB versions: + * Stef Coene + * HTB support at LARTC mailing list + * Ondrej Kraus, + * found missing INIT_QDISC(htb) + * Vladimir Smelhaus, Aamer Akhter, Bert Hubert + * helped a lot to locate nasty class stall bug + * Andi Kleen, Jamal Hadi, Bert Hubert + * code review and helpful comments on shaping + * Tomasz Wrona, + * created test case so that I was able to fix nasty bug + * Wilfried Weissmann + * spotted bug in dequeue code and helped with fix + * Jiri Fojtasek + * fixed requeue routine + * and many others. thanks. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* HTB algorithm. + Author: devik@cdi.cz + ======================================================================== + HTB is like TBF with multiple classes. It is also similar to CBQ because + it allows to assign priority to each class in hierarchy. + In fact it is another implementation of Floyd's formal sharing. + + Levels: + Each class is assigned level. Leaf has ALWAYS level 0 and root + classes have level TC_HTB_MAXDEPTH-1. Interior nodes has level + one less than their parent. +*/ + +static int htb_hysteresis __read_mostly = 0; /* whether to use mode hysteresis for speedup */ +#define HTB_VER 0x30011 /* major must be matched with number supplied by TC as version */ + +#if HTB_VER >> 16 != TC_HTB_PROTOVER +#error "Mismatched sch_htb.c and pkt_sch.h" +#endif + +/* Module parameter and sysfs export */ +module_param (htb_hysteresis, int, 0640); +MODULE_PARM_DESC(htb_hysteresis, "Hysteresis mode, less CPU load, less accurate"); + +static int htb_rate_est = 0; /* htb classes have a default rate estimator */ +module_param(htb_rate_est, int, 0640); +MODULE_PARM_DESC(htb_rate_est, "setup a default rate estimator (4sec 16sec) for htb classes"); + +/* used internaly to keep status of single class */ +enum htb_cmode { + HTB_CANT_SEND, /* class can't send and can't borrow */ + HTB_MAY_BORROW, /* class can't send but may borrow */ + HTB_CAN_SEND /* class can send */ +}; + +struct htb_prio { + union { + struct rb_root row; + struct rb_root feed; + }; + struct rb_node *ptr; + /* When class changes from state 1->2 and disconnects from + * parent's feed then we lost ptr value and start from the + * first child again. Here we store classid of the + * last valid ptr (used when ptr is NULL). + */ + u32 last_ptr_id; +}; + +/* interior & leaf nodes; props specific to leaves are marked L: + * To reduce false sharing, place mostly read fields at beginning, + * and mostly written ones at the end. + */ +struct htb_class { + struct Qdisc_class_common common; + struct psched_ratecfg rate; + struct psched_ratecfg ceil; + s64 buffer, cbuffer;/* token bucket depth/rate */ + s64 mbuffer; /* max wait time */ + u32 prio; /* these two are used only by leaves... 
*/ + int quantum; /* but stored for parent-to-leaf return */ + + struct tcf_proto __rcu *filter_list; /* class attached filters */ + struct tcf_block *block; + int filter_cnt; + + int level; /* our level (see above) */ + unsigned int children; + struct htb_class *parent; /* parent class */ + + struct net_rate_estimator __rcu *rate_est; + + /* + * Written often fields + */ + struct gnet_stats_basic_sync bstats; + struct gnet_stats_basic_sync bstats_bias; + struct tc_htb_xstats xstats; /* our special stats */ + + /* token bucket parameters */ + s64 tokens, ctokens;/* current number of tokens */ + s64 t_c; /* checkpoint time */ + + union { + struct htb_class_leaf { + int deficit[TC_HTB_MAXDEPTH]; + struct Qdisc *q; + struct netdev_queue *offload_queue; + } leaf; + struct htb_class_inner { + struct htb_prio clprio[TC_HTB_NUMPRIO]; + } inner; + }; + s64 pq_key; + + int prio_activity; /* for which prios are we active */ + enum htb_cmode cmode; /* current mode of the class */ + struct rb_node pq_node; /* node for event queue */ + struct rb_node node[TC_HTB_NUMPRIO]; /* node for self or feed tree */ + + unsigned int drops ____cacheline_aligned_in_smp; + unsigned int overlimits; +}; + +struct htb_level { + struct rb_root wait_pq; + struct htb_prio hprio[TC_HTB_NUMPRIO]; +}; + +struct htb_sched { + struct Qdisc_class_hash clhash; + int defcls; /* class where unclassified flows go to */ + int rate2quantum; /* quant = rate / rate2quantum */ + + /* filters for qdisc itself */ + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + +#define HTB_WARN_TOOMANYEVENTS 0x1 + unsigned int warned; /* only one warning */ + int direct_qlen; + struct work_struct work; + + /* non shaped skbs; let them go directly thru */ + struct qdisc_skb_head direct_queue; + u32 direct_pkts; + u32 overlimits; + + struct qdisc_watchdog watchdog; + + s64 now; /* cached dequeue time */ + + /* time of nearest event per level (row) */ + s64 near_ev_cache[TC_HTB_MAXDEPTH]; + + int row_mask[TC_HTB_MAXDEPTH]; + + struct htb_level hlevel[TC_HTB_MAXDEPTH]; + + struct Qdisc **direct_qdiscs; + unsigned int num_direct_qdiscs; + + bool offload; +}; + +/* find class in global hash table using given handle */ +static inline struct htb_class *htb_find(u32 handle, struct Qdisc *sch) +{ + struct htb_sched *q = qdisc_priv(sch); + struct Qdisc_class_common *clc; + + clc = qdisc_class_find(&q->clhash, handle); + if (clc == NULL) + return NULL; + return container_of(clc, struct htb_class, common); +} + +static unsigned long htb_search(struct Qdisc *sch, u32 handle) +{ + return (unsigned long)htb_find(handle, sch); +} +/** + * htb_classify - classify a packet into class + * + * It returns NULL if the packet should be dropped or -1 if the packet + * should be passed directly thru. In all other cases leaf class is returned. + * We allow direct class selection by classid in priority. The we examine + * filters in qdisc and in inner nodes (if higher filter points to the inner + * node). If we end up with classid MAJOR:0 we enqueue the skb into special + * internal fifo (direct). These packets then go directly thru. If we still + * have no valid leaf we try to use MAJOR:default leaf. It still unsuccessful + * then finish and return direct queue. 
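As an aside on the classid notation used in the comment above ("MAJOR:0", "MAJOR:default"): a classid is a 32-bit handle split into a 16-bit major and a 16-bit minor, and the default-class fallback simply combines the qdisc's major with the configured minor. A minimal userspace sketch of that arithmetic follows; the TOY_* macros are stand-ins written from the documented behaviour of TC_H_MAJ/TC_H_MIN/TC_H_MAKE, not copies of the kernel macros.

#include <stdio.h>
#include <stdint.h>

/* Illustrative sketch only: TOY_* mirror the 16:16 split of traffic-control
 * handles; they are not the kernel macros themselves. */
#define TOY_H_MAJ(h)		((h) & 0xffff0000U)
#define TOY_H_MIN(h)		((h) & 0x0000ffffU)
#define TOY_H_MAKE(maj, min)	(TOY_H_MAJ(maj) | TOY_H_MIN(min))

int main(void)
{
	uint32_t qdisc_handle = 0x00010000;	/* "handle 1:" */
	uint32_t defcls = 0x30;			/* "default 30" */
	uint32_t leaf = TOY_H_MAKE(qdisc_handle, defcls);

	/* prints: default leaf classid = 1:30 */
	printf("default leaf classid = %x:%x\n", leaf >> 16, TOY_H_MIN(leaf));
	return 0;
}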
+ */ +#define HTB_DIRECT ((struct htb_class *)-1L) + +static struct htb_class *htb_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl; + struct tcf_result res; + struct tcf_proto *tcf; + int result; + + /* allow to select class by setting skb->priority to valid classid; + * note that nfmark can be used too by attaching filter fw with no + * rules in it + */ + if (skb->priority == sch->handle) + return HTB_DIRECT; /* X:0 (direct flow) selected */ + cl = htb_find(skb->priority, sch); + if (cl) { + if (cl->level == 0) + return cl; + /* Start with inner filter chain if a non-leaf class is selected */ + tcf = rcu_dereference_bh(cl->filter_list); + } else { + tcf = rcu_dereference_bh(q->filter_list); + } + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + while (tcf && (result = tcf_classify(skb, NULL, tcf, &res, false)) >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + cl = (void *)res.class; + if (!cl) { + if (res.classid == sch->handle) + return HTB_DIRECT; /* X:0 (direct flow) */ + cl = htb_find(res.classid, sch); + if (!cl) + break; /* filter selected invalid classid */ + } + if (!cl->level) + return cl; /* we hit leaf; return it */ + + /* we have got inner class; apply inner filter chain */ + tcf = rcu_dereference_bh(cl->filter_list); + } + /* classification failed; try to use default class */ + cl = htb_find(TC_H_MAKE(TC_H_MAJ(sch->handle), q->defcls), sch); + if (!cl || cl->level) + return HTB_DIRECT; /* bad default .. this is safe bet */ + return cl; +} + +/** + * htb_add_to_id_tree - adds class to the round robin list + * @root: the root of the tree + * @cl: the class to add + * @prio: the give prio in class + * + * Routine adds class to the list (actually tree) sorted by classid. + * Make sure that class is not already on such list for given prio. + */ +static void htb_add_to_id_tree(struct rb_root *root, + struct htb_class *cl, int prio) +{ + struct rb_node **p = &root->rb_node, *parent = NULL; + + while (*p) { + struct htb_class *c; + parent = *p; + c = rb_entry(parent, struct htb_class, node[prio]); + + if (cl->common.classid > c->common.classid) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&cl->node[prio], parent, p); + rb_insert_color(&cl->node[prio], root); +} + +/** + * htb_add_to_wait_tree - adds class to the event queue with delay + * @q: the priority event queue + * @cl: the class to add + * @delay: delay in microseconds + * + * The class is added to priority event queue to indicate that class will + * change its mode in cl->pq_key microseconds. Make sure that class is not + * already in the queue. 
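The event queue documented above is keyed by pq_key = q->now + delay, and near_ev_cache[] remembers the smallest key per level so the dequeue path can skip the tree until that time. Below is a throwaway userspace sketch of that "track the minimum" idea; toy_event and nearest_event are invented names and a flat array stands in for the rbtree.

#include <stdio.h>
#include <stdint.h>

/* Toy model only: each entry is a class waiting to change mode at pq_key. */
struct toy_event {
	uint32_t classid;
	int64_t pq_key;		/* absolute time of the pending mode change */
};

/* Return the earliest pending event, starting from a cached upper bound. */
static int64_t nearest_event(const struct toy_event *ev, int n, int64_t cache)
{
	for (int i = 0; i < n; i++)
		if (ev[i].pq_key < cache)
			cache = ev[i].pq_key;
	return cache;
}

int main(void)
{
	struct toy_event ev[] = {
		{ 0x10010, 1500 }, { 0x10020, 900 }, { 0x10030, 2200 },
	};

	/* prints: next event at t=900 */
	printf("next event at t=%lld\n",
	       (long long)nearest_event(ev, 3, INT64_MAX));
	return 0;
}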
+ */ +static void htb_add_to_wait_tree(struct htb_sched *q, + struct htb_class *cl, s64 delay) +{ + struct rb_node **p = &q->hlevel[cl->level].wait_pq.rb_node, *parent = NULL; + + cl->pq_key = q->now + delay; + if (cl->pq_key == q->now) + cl->pq_key++; + + /* update the nearest event cache */ + if (q->near_ev_cache[cl->level] > cl->pq_key) + q->near_ev_cache[cl->level] = cl->pq_key; + + while (*p) { + struct htb_class *c; + parent = *p; + c = rb_entry(parent, struct htb_class, pq_node); + if (cl->pq_key >= c->pq_key) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&cl->pq_node, parent, p); + rb_insert_color(&cl->pq_node, &q->hlevel[cl->level].wait_pq); +} + +/** + * htb_next_rb_node - finds next node in binary tree + * @n: the current node in binary tree + * + * When we are past last key we return NULL. + * Average complexity is 2 steps per call. + */ +static inline void htb_next_rb_node(struct rb_node **n) +{ + *n = rb_next(*n); +} + +/** + * htb_add_class_to_row - add class to its row + * @q: the priority event queue + * @cl: the class to add + * @mask: the given priorities in class in bitmap + * + * The class is added to row at priorities marked in mask. + * It does nothing if mask == 0. + */ +static inline void htb_add_class_to_row(struct htb_sched *q, + struct htb_class *cl, int mask) +{ + q->row_mask[cl->level] |= mask; + while (mask) { + int prio = ffz(~mask); + mask &= ~(1 << prio); + htb_add_to_id_tree(&q->hlevel[cl->level].hprio[prio].row, cl, prio); + } +} + +/* If this triggers, it is a bug in this code, but it need not be fatal */ +static void htb_safe_rb_erase(struct rb_node *rb, struct rb_root *root) +{ + if (RB_EMPTY_NODE(rb)) { + WARN_ON(1); + } else { + rb_erase(rb, root); + RB_CLEAR_NODE(rb); + } +} + + +/** + * htb_remove_class_from_row - removes class from its row + * @q: the priority event queue + * @cl: the class to add + * @mask: the given priorities in class in bitmap + * + * The class is removed from row at priorities marked in mask. + * It does nothing if mask == 0. + */ +static inline void htb_remove_class_from_row(struct htb_sched *q, + struct htb_class *cl, int mask) +{ + int m = 0; + struct htb_level *hlevel = &q->hlevel[cl->level]; + + while (mask) { + int prio = ffz(~mask); + struct htb_prio *hprio = &hlevel->hprio[prio]; + + mask &= ~(1 << prio); + if (hprio->ptr == cl->node + prio) + htb_next_rb_node(&hprio->ptr); + + htb_safe_rb_erase(cl->node + prio, &hprio->row); + if (!hprio->row.rb_node) + m |= 1 << prio; + } + q->row_mask[cl->level] &= ~m; +} + +/** + * htb_activate_prios - creates active classe's feed chain + * @q: the priority event queue + * @cl: the class to activate + * + * The class is connected to ancestors and/or appropriate rows + * for priorities it is participating on. cl->cmode must be new + * (activated) mode. It does nothing if cl->prio_activity == 0. 
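Both row helpers above iterate a priority bitmap with the ffz(~mask) idiom: take the index of the lowest set bit, handle that priority, clear the bit, repeat. A compilable userspace equivalent using ffs() from <strings.h>; this only illustrates the bit walk, it is not the kernel primitive itself.

#include <stdio.h>
#include <strings.h>	/* ffs() */

int main(void)
{
	unsigned int mask = 0x05;	/* e.g. priorities 0 and 2 are active */

	while (mask) {
		int prio = ffs(mask) - 1;	/* lowest set bit, like ffz(~mask) */

		mask &= ~(1u << prio);
		printf("handle prio %d\n", prio);	/* prints 0, then 2 */
	}
	return 0;
}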
+ */ +static void htb_activate_prios(struct htb_sched *q, struct htb_class *cl) +{ + struct htb_class *p = cl->parent; + long m, mask = cl->prio_activity; + + while (cl->cmode == HTB_MAY_BORROW && p && mask) { + m = mask; + while (m) { + unsigned int prio = ffz(~m); + + if (WARN_ON_ONCE(prio >= ARRAY_SIZE(p->inner.clprio))) + break; + m &= ~(1 << prio); + + if (p->inner.clprio[prio].feed.rb_node) + /* parent already has its feed in use so that + * reset bit in mask as parent is already ok + */ + mask &= ~(1 << prio); + + htb_add_to_id_tree(&p->inner.clprio[prio].feed, cl, prio); + } + p->prio_activity |= mask; + cl = p; + p = cl->parent; + + } + if (cl->cmode == HTB_CAN_SEND && mask) + htb_add_class_to_row(q, cl, mask); +} + +/** + * htb_deactivate_prios - remove class from feed chain + * @q: the priority event queue + * @cl: the class to deactivate + * + * cl->cmode must represent old mode (before deactivation). It does + * nothing if cl->prio_activity == 0. Class is removed from all feed + * chains and rows. + */ +static void htb_deactivate_prios(struct htb_sched *q, struct htb_class *cl) +{ + struct htb_class *p = cl->parent; + long m, mask = cl->prio_activity; + + while (cl->cmode == HTB_MAY_BORROW && p && mask) { + m = mask; + mask = 0; + while (m) { + int prio = ffz(~m); + m &= ~(1 << prio); + + if (p->inner.clprio[prio].ptr == cl->node + prio) { + /* we are removing child which is pointed to from + * parent feed - forget the pointer but remember + * classid + */ + p->inner.clprio[prio].last_ptr_id = cl->common.classid; + p->inner.clprio[prio].ptr = NULL; + } + + htb_safe_rb_erase(cl->node + prio, + &p->inner.clprio[prio].feed); + + if (!p->inner.clprio[prio].feed.rb_node) + mask |= 1 << prio; + } + + p->prio_activity &= ~mask; + cl = p; + p = cl->parent; + + } + if (cl->cmode == HTB_CAN_SEND && mask) + htb_remove_class_from_row(q, cl, mask); +} + +static inline s64 htb_lowater(const struct htb_class *cl) +{ + if (htb_hysteresis) + return cl->cmode != HTB_CANT_SEND ? -cl->cbuffer : 0; + else + return 0; +} +static inline s64 htb_hiwater(const struct htb_class *cl) +{ + if (htb_hysteresis) + return cl->cmode == HTB_CAN_SEND ? -cl->buffer : 0; + else + return 0; +} + + +/** + * htb_class_mode - computes and returns current class mode + * @cl: the target class + * @diff: diff time in microseconds + * + * It computes cl's mode at time cl->t_c+diff and returns it. If mode + * is not HTB_CAN_SEND then cl->pq_key is updated to time difference + * from now to time when cl will change its state. + * Also it is worth to note that class mode doesn't change simply + * at cl->{c,}tokens == 0 but there can rather be hysteresis of + * 0 .. -cl->{c,}buffer range. It is meant to limit number of + * mode transitions per time unit. The speed gain is about 1/6. + */ +static inline enum htb_cmode +htb_class_mode(struct htb_class *cl, s64 *diff) +{ + s64 toks; + + if ((toks = (cl->ctokens + *diff)) < htb_lowater(cl)) { + *diff = -toks; + return HTB_CANT_SEND; + } + + if ((toks = (cl->tokens + *diff)) >= htb_hiwater(cl)) + return HTB_CAN_SEND; + + *diff = -toks; + return HTB_MAY_BORROW; +} + +/** + * htb_change_class_mode - changes classe's mode + * @q: the priority event queue + * @cl: the target class + * @diff: diff time in microseconds + * + * This should be the only way how to change classe's mode under normal + * circumstances. Routine will update feed lists linkage, change mode + * and add class to the wait event queue if appropriate. 
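Seen from a distance, htb_class_mode() above is a three-way decision on two token counters: ceil tokens under the low-water mark forbid sending, rate tokens at or above the high-water mark allow it, and anything in between may only borrow. A stripped-down sketch of just that rule, with the elapsed-time delta already folded into the counters and plain integers instead of PSCHED ticks; all toy_* names are invented.

#include <stdio.h>
#include <stdint.h>

/* Same ordering as enum htb_cmode: can't send, may borrow, can send. */
enum toy_cmode { TOY_CANT_SEND, TOY_MAY_BORROW, TOY_CAN_SEND };

/* Toy restatement only; the kernel version also reports how long to wait. */
static enum toy_cmode toy_class_mode(int64_t tokens, int64_t ctokens,
				     int64_t lowater, int64_t hiwater)
{
	if (ctokens < lowater)
		return TOY_CANT_SEND;	/* ceil exhausted */
	if (tokens >= hiwater)
		return TOY_CAN_SEND;	/* within own rate */
	return TOY_MAY_BORROW;		/* over rate but under ceil */
}

int main(void)
{
	printf("%d\n", toy_class_mode(-10, 5, 0, 0));	/* 1 = may borrow */
	printf("%d\n", toy_class_mode(100, -1, 0, 0));	/* 0 = can't send */
	printf("%d\n", toy_class_mode(3, 7, 0, 0));	/* 2 = can send */
	return 0;
}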
New mode should + * be different from old one and cl->pq_key has to be valid if changing + * to mode other than HTB_CAN_SEND (see htb_add_to_wait_tree). + */ +static void +htb_change_class_mode(struct htb_sched *q, struct htb_class *cl, s64 *diff) +{ + enum htb_cmode new_mode = htb_class_mode(cl, diff); + + if (new_mode == cl->cmode) + return; + + if (new_mode == HTB_CANT_SEND) { + cl->overlimits++; + q->overlimits++; + } + + if (cl->prio_activity) { /* not necessary: speed optimization */ + if (cl->cmode != HTB_CANT_SEND) + htb_deactivate_prios(q, cl); + cl->cmode = new_mode; + if (new_mode != HTB_CANT_SEND) + htb_activate_prios(q, cl); + } else + cl->cmode = new_mode; +} + +/** + * htb_activate - inserts leaf cl into appropriate active feeds + * @q: the priority event queue + * @cl: the target class + * + * Routine learns (new) priority of leaf and activates feed chain + * for the prio. It can be called on already active leaf safely. + * It also adds leaf into droplist. + */ +static inline void htb_activate(struct htb_sched *q, struct htb_class *cl) +{ + WARN_ON(cl->level || !cl->leaf.q || !cl->leaf.q->q.qlen); + + if (!cl->prio_activity) { + cl->prio_activity = 1 << cl->prio; + htb_activate_prios(q, cl); + } +} + +/** + * htb_deactivate - remove leaf cl from active feeds + * @q: the priority event queue + * @cl: the target class + * + * Make sure that leaf is active. In the other words it can't be called + * with non-active leaf. It also removes class from the drop list. + */ +static inline void htb_deactivate(struct htb_sched *q, struct htb_class *cl) +{ + WARN_ON(!cl->prio_activity); + + htb_deactivate_prios(q, cl); + cl->prio_activity = 0; +} + +static int htb_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + int ret; + unsigned int len = qdisc_pkt_len(skb); + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl = htb_classify(skb, sch, &ret); + + if (cl == HTB_DIRECT) { + /* enqueue to helper queue */ + if (q->direct_queue.qlen < q->direct_qlen) { + __qdisc_enqueue_tail(skb, &q->direct_queue); + q->direct_pkts++; + } else { + return qdisc_drop(skb, sch, to_free); + } +#ifdef CONFIG_NET_CLS_ACT + } else if (!cl) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; +#endif + } else if ((ret = qdisc_enqueue(skb, cl->leaf.q, + to_free)) != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(ret)) { + qdisc_qstats_drop(sch); + cl->drops++; + } + return ret; + } else { + htb_activate(q, cl); + } + + sch->qstats.backlog += len; + sch->q.qlen++; + return NET_XMIT_SUCCESS; +} + +static inline void htb_accnt_tokens(struct htb_class *cl, int bytes, s64 diff) +{ + s64 toks = diff + cl->tokens; + + if (toks > cl->buffer) + toks = cl->buffer; + toks -= (s64) psched_l2t_ns(&cl->rate, bytes); + if (toks <= -cl->mbuffer) + toks = 1 - cl->mbuffer; + + cl->tokens = toks; +} + +static inline void htb_accnt_ctokens(struct htb_class *cl, int bytes, s64 diff) +{ + s64 toks = diff + cl->ctokens; + + if (toks > cl->cbuffer) + toks = cl->cbuffer; + toks -= (s64) psched_l2t_ns(&cl->ceil, bytes); + if (toks <= -cl->mbuffer) + toks = 1 - cl->mbuffer; + + cl->ctokens = toks; +} + +/** + * htb_charge_class - charges amount "bytes" to leaf and ancestors + * @q: the priority event queue + * @cl: the class to start iterate + * @level: the minimum level to account + * @skb: the socket buffer + * + * Routine assumes that packet "bytes" long was dequeued from leaf cl + * borrowing from "level". 
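htb_accnt_tokens()/htb_accnt_ctokens() above are a textbook token-bucket update: refill by the elapsed time, cap at the bucket depth, then subtract the wire time of the packet. A rough userspace restatement, assuming the rate is expressed in bytes per second and all times in nanoseconds; the mbuffer clamp is left out and toy_account is an invented name.

#include <stdio.h>
#include <stdint.h>

/* Sketch only: refill, cap, then charge bytes at the given rate. */
static int64_t toy_account(int64_t tokens, int64_t buffer_ns,
			   int64_t elapsed_ns, int bytes, int64_t rate_bps)
{
	int64_t toks = tokens + elapsed_ns;		/* refill */

	if (toks > buffer_ns)
		toks = buffer_ns;			/* bucket is full */
	toks -= bytes * 1000000000LL / rate_bps;	/* psched_l2t_ns() analogue */
	return toks;
}

int main(void)
{
	/* a 1500-byte packet at 1 MB/s costs 1.5 ms of tokens; with only
	 * 0.5 ms refilled the class ends up 1 ms in debt */
	printf("%lld\n", (long long)toy_account(0, 2000000, 500000,
						1500, 1000000));	/* -1000000 */
	return 0;
}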
It accounts bytes to ceil leaky bucket for + * leaf and all ancestors and to rate bucket for ancestors at levels + * "level" and higher. It also handles possible change of mode resulting + * from the update. Note that mode can also increase here (MAY_BORROW to + * CAN_SEND) because we can use more precise clock that event queue here. + * In such case we remove class from event queue first. + */ +static void htb_charge_class(struct htb_sched *q, struct htb_class *cl, + int level, struct sk_buff *skb) +{ + int bytes = qdisc_pkt_len(skb); + enum htb_cmode old_mode; + s64 diff; + + while (cl) { + diff = min_t(s64, q->now - cl->t_c, cl->mbuffer); + if (cl->level >= level) { + if (cl->level == level) + cl->xstats.lends++; + htb_accnt_tokens(cl, bytes, diff); + } else { + cl->xstats.borrows++; + cl->tokens += diff; /* we moved t_c; update tokens */ + } + htb_accnt_ctokens(cl, bytes, diff); + cl->t_c = q->now; + + old_mode = cl->cmode; + diff = 0; + htb_change_class_mode(q, cl, &diff); + if (old_mode != cl->cmode) { + if (old_mode != HTB_CAN_SEND) + htb_safe_rb_erase(&cl->pq_node, &q->hlevel[cl->level].wait_pq); + if (cl->cmode != HTB_CAN_SEND) + htb_add_to_wait_tree(q, cl, diff); + } + + /* update basic stats except for leaves which are already updated */ + if (cl->level) + bstats_update(&cl->bstats, skb); + + cl = cl->parent; + } +} + +/** + * htb_do_events - make mode changes to classes at the level + * @q: the priority event queue + * @level: which wait_pq in 'q->hlevel' + * @start: start jiffies + * + * Scans event queue for pending events and applies them. Returns time of + * next pending event (0 for no event in pq, q->now for too many events). + * Note: Applied are events whose have cl->pq_key <= q->now. + */ +static s64 htb_do_events(struct htb_sched *q, const int level, + unsigned long start) +{ + /* don't run for longer than 2 jiffies; 2 is used instead of + * 1 to simplify things when jiffy is going to be incremented + * too soon + */ + unsigned long stop_at = start + 2; + struct rb_root *wait_pq = &q->hlevel[level].wait_pq; + + while (time_before(jiffies, stop_at)) { + struct htb_class *cl; + s64 diff; + struct rb_node *p = rb_first(wait_pq); + + if (!p) + return 0; + + cl = rb_entry(p, struct htb_class, pq_node); + if (cl->pq_key > q->now) + return cl->pq_key; + + htb_safe_rb_erase(p, wait_pq); + diff = min_t(s64, q->now - cl->t_c, cl->mbuffer); + htb_change_class_mode(q, cl, &diff); + if (cl->cmode != HTB_CAN_SEND) + htb_add_to_wait_tree(q, cl, diff); + } + + /* too much load - let's continue after a break for scheduling */ + if (!(q->warned & HTB_WARN_TOOMANYEVENTS)) { + pr_warn("htb: too many events!\n"); + q->warned |= HTB_WARN_TOOMANYEVENTS; + } + + return q->now; +} + +/* Returns class->node+prio from id-tree where classe's id is >= id. NULL + * is no such one exists. + */ +static struct rb_node *htb_id_find_next_upper(int prio, struct rb_node *n, + u32 id) +{ + struct rb_node *r = NULL; + while (n) { + struct htb_class *cl = + rb_entry(n, struct htb_class, node[prio]); + + if (id > cl->common.classid) { + n = n->rb_right; + } else if (id < cl->common.classid) { + r = n; + n = n->rb_left; + } else { + return n; + } + } + return r; +} + +/** + * htb_lookup_leaf - returns next leaf class in DRR order + * @hprio: the current one + * @prio: which prio in class + * + * Find leaf where current feed pointers points to. 
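htb_id_find_next_upper() above recovers an invalidated DRR pointer by locating the first class whose classid is greater than or equal to the remembered last_ptr_id. The same "lower bound" search is shown below over a sorted array, as a self-contained toy; next_upper is a made-up name and the array replaces the rbtree.

#include <stdio.h>

/* Return the index of the smallest id >= key, or -1 (the NULL case above). */
static int next_upper(const unsigned int *ids, int n, unsigned int key)
{
	int lo = 0, hi = n, ans = -1;

	while (lo < hi) {
		int mid = (lo + hi) / 2;

		if (ids[mid] >= key) {
			ans = mid;	/* candidate; keep searching left */
			hi = mid;
		} else {
			lo = mid + 1;
		}
	}
	return ans;
}

int main(void)
{
	unsigned int ids[] = { 0x10001, 0x10003, 0x10007 };

	/* the hint 1:4 was invalidated; resume at 1:7 (index 2) */
	printf("%d\n", next_upper(ids, 3, 0x10004));
	return 0;
}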
+ */ +static struct htb_class *htb_lookup_leaf(struct htb_prio *hprio, const int prio) +{ + int i; + struct { + struct rb_node *root; + struct rb_node **pptr; + u32 *pid; + } stk[TC_HTB_MAXDEPTH], *sp = stk; + + BUG_ON(!hprio->row.rb_node); + sp->root = hprio->row.rb_node; + sp->pptr = &hprio->ptr; + sp->pid = &hprio->last_ptr_id; + + for (i = 0; i < 65535; i++) { + if (!*sp->pptr && *sp->pid) { + /* ptr was invalidated but id is valid - try to recover + * the original or next ptr + */ + *sp->pptr = + htb_id_find_next_upper(prio, sp->root, *sp->pid); + } + *sp->pid = 0; /* ptr is valid now so that remove this hint as it + * can become out of date quickly + */ + if (!*sp->pptr) { /* we are at right end; rewind & go up */ + *sp->pptr = sp->root; + while ((*sp->pptr)->rb_left) + *sp->pptr = (*sp->pptr)->rb_left; + if (sp > stk) { + sp--; + if (!*sp->pptr) { + WARN_ON(1); + return NULL; + } + htb_next_rb_node(sp->pptr); + } + } else { + struct htb_class *cl; + struct htb_prio *clp; + + cl = rb_entry(*sp->pptr, struct htb_class, node[prio]); + if (!cl->level) + return cl; + clp = &cl->inner.clprio[prio]; + (++sp)->root = clp->feed.rb_node; + sp->pptr = &clp->ptr; + sp->pid = &clp->last_ptr_id; + } + } + WARN_ON(1); + return NULL; +} + +/* dequeues packet at given priority and level; call only if + * you are sure that there is active class at prio/level + */ +static struct sk_buff *htb_dequeue_tree(struct htb_sched *q, const int prio, + const int level) +{ + struct sk_buff *skb = NULL; + struct htb_class *cl, *start; + struct htb_level *hlevel = &q->hlevel[level]; + struct htb_prio *hprio = &hlevel->hprio[prio]; + + /* look initial class up in the row */ + start = cl = htb_lookup_leaf(hprio, prio); + + do { +next: + if (unlikely(!cl)) + return NULL; + + /* class can be empty - it is unlikely but can be true if leaf + * qdisc drops packets in enqueue routine or if someone used + * graft operation on the leaf since last dequeue; + * simply deactivate and skip such class + */ + if (unlikely(cl->leaf.q->q.qlen == 0)) { + struct htb_class *next; + htb_deactivate(q, cl); + + /* row/level might become empty */ + if ((q->row_mask[level] & (1 << prio)) == 0) + return NULL; + + next = htb_lookup_leaf(hprio, prio); + + if (cl == start) /* fix start if we just deleted it */ + start = next; + cl = next; + goto next; + } + + skb = cl->leaf.q->dequeue(cl->leaf.q); + if (likely(skb != NULL)) + break; + + qdisc_warn_nonwc("htb", cl->leaf.q); + htb_next_rb_node(level ? &cl->parent->inner.clprio[prio].ptr: + &q->hlevel[0].hprio[prio].ptr); + cl = htb_lookup_leaf(hprio, prio); + + } while (cl != start); + + if (likely(skb != NULL)) { + bstats_update(&cl->bstats, skb); + cl->leaf.deficit[level] -= qdisc_pkt_len(skb); + if (cl->leaf.deficit[level] < 0) { + cl->leaf.deficit[level] += cl->quantum; + htb_next_rb_node(level ? &cl->parent->inner.clprio[prio].ptr : + &q->hlevel[0].hprio[prio].ptr); + } + /* this used to be after charge_class but this constelation + * gives us slightly better performance + */ + if (!cl->leaf.q->q.qlen) + htb_deactivate(q, cl); + htb_charge_class(q, cl, level, skb); + } + return skb; +} + +static struct sk_buff *htb_dequeue(struct Qdisc *sch) +{ + struct sk_buff *skb; + struct htb_sched *q = qdisc_priv(sch); + int level; + s64 next_event; + unsigned long start_at; + + /* try to dequeue direct packets as high prio (!) 
to minimize cpu work */ + skb = __qdisc_dequeue_head(&q->direct_queue); + if (skb != NULL) { +ok: + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + return skb; + } + + if (!sch->q.qlen) + goto fin; + q->now = ktime_get_ns(); + start_at = jiffies; + + next_event = q->now + 5LLU * NSEC_PER_SEC; + + for (level = 0; level < TC_HTB_MAXDEPTH; level++) { + /* common case optimization - skip event handler quickly */ + int m; + s64 event = q->near_ev_cache[level]; + + if (q->now >= event) { + event = htb_do_events(q, level, start_at); + if (!event) + event = q->now + NSEC_PER_SEC; + q->near_ev_cache[level] = event; + } + + if (next_event > event) + next_event = event; + + m = ~q->row_mask[level]; + while (m != (int)(-1)) { + int prio = ffz(m); + + m |= 1 << prio; + skb = htb_dequeue_tree(q, prio, level); + if (likely(skb != NULL)) + goto ok; + } + } + if (likely(next_event > q->now)) + qdisc_watchdog_schedule_ns(&q->watchdog, next_event); + else + schedule_work(&q->work); +fin: + return skb; +} + +/* reset all classes */ +/* always caled under BH & queue lock */ +static void htb_reset(struct Qdisc *sch) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl; + unsigned int i; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (cl->level) + memset(&cl->inner, 0, sizeof(cl->inner)); + else { + if (cl->leaf.q && !q->offload) + qdisc_reset(cl->leaf.q); + } + cl->prio_activity = 0; + cl->cmode = HTB_CAN_SEND; + } + } + qdisc_watchdog_cancel(&q->watchdog); + __qdisc_reset_queue(&q->direct_queue); + memset(q->hlevel, 0, sizeof(q->hlevel)); + memset(q->row_mask, 0, sizeof(q->row_mask)); +} + +static const struct nla_policy htb_policy[TCA_HTB_MAX + 1] = { + [TCA_HTB_PARMS] = { .len = sizeof(struct tc_htb_opt) }, + [TCA_HTB_INIT] = { .len = sizeof(struct tc_htb_glob) }, + [TCA_HTB_CTAB] = { .type = NLA_BINARY, .len = TC_RTAB_SIZE }, + [TCA_HTB_RTAB] = { .type = NLA_BINARY, .len = TC_RTAB_SIZE }, + [TCA_HTB_DIRECT_QLEN] = { .type = NLA_U32 }, + [TCA_HTB_RATE64] = { .type = NLA_U64 }, + [TCA_HTB_CEIL64] = { .type = NLA_U64 }, + [TCA_HTB_OFFLOAD] = { .type = NLA_FLAG }, +}; + +static void htb_work_func(struct work_struct *work) +{ + struct htb_sched *q = container_of(work, struct htb_sched, work); + struct Qdisc *sch = q->watchdog.qdisc; + + rcu_read_lock(); + __netif_schedule(qdisc_root(sch)); + rcu_read_unlock(); +} + +static void htb_set_lockdep_class_child(struct Qdisc *q) +{ + static struct lock_class_key child_key; + + lockdep_set_class(qdisc_lock(q), &child_key); +} + +static int htb_offload(struct net_device *dev, struct tc_htb_qopt_offload *opt) +{ + return dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_HTB, opt); +} + +static int htb_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_htb_qopt_offload offload_opt; + struct htb_sched *q = qdisc_priv(sch); + struct nlattr *tb[TCA_HTB_MAX + 1]; + struct tc_htb_glob *gopt; + unsigned int ntx; + bool offload; + int err; + + qdisc_watchdog_init(&q->watchdog, sch); + INIT_WORK(&q->work, htb_work_func); + + if (!opt) + return -EINVAL; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + err = nla_parse_nested_deprecated(tb, TCA_HTB_MAX, opt, htb_policy, + NULL); + if (err < 0) + return err; + + if (!tb[TCA_HTB_INIT]) + return -EINVAL; + + gopt = nla_data(tb[TCA_HTB_INIT]); + if (gopt->version != HTB_VER >> 16) + return 
-EINVAL; + + offload = nla_get_flag(tb[TCA_HTB_OFFLOAD]); + + if (offload) { + if (sch->parent != TC_H_ROOT) { + NL_SET_ERR_MSG(extack, "HTB must be the root qdisc to use offload"); + return -EOPNOTSUPP; + } + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) { + NL_SET_ERR_MSG(extack, "hw-tc-offload ethtool feature flag must be on"); + return -EOPNOTSUPP; + } + + q->num_direct_qdiscs = dev->real_num_tx_queues; + q->direct_qdiscs = kcalloc(q->num_direct_qdiscs, + sizeof(*q->direct_qdiscs), + GFP_KERNEL); + if (!q->direct_qdiscs) + return -ENOMEM; + } + + err = qdisc_class_hash_init(&q->clhash); + if (err < 0) + return err; + + if (tb[TCA_HTB_DIRECT_QLEN]) + q->direct_qlen = nla_get_u32(tb[TCA_HTB_DIRECT_QLEN]); + else + q->direct_qlen = qdisc_dev(sch)->tx_queue_len; + + if ((q->rate2quantum = gopt->rate2quantum) < 1) + q->rate2quantum = 1; + q->defcls = gopt->defcls; + + if (!offload) + return 0; + + for (ntx = 0; ntx < q->num_direct_qdiscs; ntx++) { + struct netdev_queue *dev_queue = netdev_get_tx_queue(dev, ntx); + struct Qdisc *qdisc; + + qdisc = qdisc_create_dflt(dev_queue, &pfifo_qdisc_ops, + TC_H_MAKE(sch->handle, 0), extack); + if (!qdisc) { + return -ENOMEM; + } + + htb_set_lockdep_class_child(qdisc); + q->direct_qdiscs[ntx] = qdisc; + qdisc->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + } + + sch->flags |= TCQ_F_MQROOT; + + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_CREATE, + .parent_classid = TC_H_MAJ(sch->handle) >> 16, + .classid = TC_H_MIN(q->defcls), + .extack = extack, + }; + err = htb_offload(dev, &offload_opt); + if (err) + return err; + + /* Defer this assignment, so that htb_destroy skips offload-related + * parts (especially calling ndo_setup_tc) on errors. + */ + q->offload = true; + + return 0; +} + +static void htb_attach_offload(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct htb_sched *q = qdisc_priv(sch); + unsigned int ntx; + + for (ntx = 0; ntx < q->num_direct_qdiscs; ntx++) { + struct Qdisc *old, *qdisc = q->direct_qdiscs[ntx]; + + old = dev_graft_qdisc(qdisc->dev_queue, qdisc); + qdisc_put(old); + qdisc_hash_add(qdisc, false); + } + for (ntx = q->num_direct_qdiscs; ntx < dev->num_tx_queues; ntx++) { + struct netdev_queue *dev_queue = netdev_get_tx_queue(dev, ntx); + struct Qdisc *old = dev_graft_qdisc(dev_queue, NULL); + + qdisc_put(old); + } + + kfree(q->direct_qdiscs); + q->direct_qdiscs = NULL; +} + +static void htb_attach_software(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned int ntx; + + /* Resemble qdisc_graft behavior. */ + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + struct netdev_queue *dev_queue = netdev_get_tx_queue(dev, ntx); + struct Qdisc *old = dev_graft_qdisc(dev_queue, sch); + + qdisc_refcount_inc(sch); + + qdisc_put(old); + } +} + +static void htb_attach(struct Qdisc *sch) +{ + struct htb_sched *q = qdisc_priv(sch); + + if (q->offload) + htb_attach_offload(sch); + else + htb_attach_software(sch); +} + +static int htb_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct htb_sched *q = qdisc_priv(sch); + struct nlattr *nest; + struct tc_htb_glob gopt; + + if (q->offload) + sch->flags |= TCQ_F_OFFLOADED; + else + sch->flags &= ~TCQ_F_OFFLOADED; + + sch->qstats.overlimits = q->overlimits; + /* Its safe to not acquire qdisc lock. As we hold RTNL, + * no change can happen on the qdisc parameters. 
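htb_init() above stores gopt->rate2quantum, the "r2q" divisor from tc (10 by default in iproute2, if memory serves); htb_change_class() later derives each leaf's DRR quantum as rate/r2q and, when no explicit quantum was supplied, clamps it to the 1000..200000 byte range. A hedged userspace restatement of that arithmetic; toy_quantum is an invented helper.

#include <stdio.h>
#include <stdint.h>

/* Sketch of the quantum derivation; the clamps mirror the ones applied
 * further below for classes without an explicit quantum. */
static unsigned int toy_quantum(uint64_t rate_bytes_ps, unsigned int r2q)
{
	uint64_t quantum = rate_bytes_ps / (r2q ? r2q : 1);

	if (quantum < 1000)
		quantum = 1000;
	if (quantum > 200000)
		quantum = 200000;
	return (unsigned int)quantum;
}

int main(void)
{
	/* 100 Mbit/s is 12.5 MB/s; divided by r2q=10, then clamped to 200000 */
	printf("%u\n", toy_quantum(100ULL * 1000 * 1000 / 8, 10));
	return 0;
}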
+ */ + + gopt.direct_pkts = q->direct_pkts; + gopt.version = HTB_VER; + gopt.rate2quantum = q->rate2quantum; + gopt.defcls = q->defcls; + gopt.debug = 0; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (nla_put(skb, TCA_HTB_INIT, sizeof(gopt), &gopt) || + nla_put_u32(skb, TCA_HTB_DIRECT_QLEN, q->direct_qlen)) + goto nla_put_failure; + if (q->offload && nla_put_flag(skb, TCA_HTB_OFFLOAD)) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int htb_dump_class(struct Qdisc *sch, unsigned long arg, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct htb_class *cl = (struct htb_class *)arg; + struct htb_sched *q = qdisc_priv(sch); + struct nlattr *nest; + struct tc_htb_opt opt; + + /* Its safe to not acquire qdisc lock. As we hold RTNL, + * no change can happen on the class parameters. + */ + tcm->tcm_parent = cl->parent ? cl->parent->common.classid : TC_H_ROOT; + tcm->tcm_handle = cl->common.classid; + if (!cl->level && cl->leaf.q) + tcm->tcm_info = cl->leaf.q->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + memset(&opt, 0, sizeof(opt)); + + psched_ratecfg_getrate(&opt.rate, &cl->rate); + opt.buffer = PSCHED_NS2TICKS(cl->buffer); + psched_ratecfg_getrate(&opt.ceil, &cl->ceil); + opt.cbuffer = PSCHED_NS2TICKS(cl->cbuffer); + opt.quantum = cl->quantum; + opt.prio = cl->prio; + opt.level = cl->level; + if (nla_put(skb, TCA_HTB_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + if (q->offload && nla_put_flag(skb, TCA_HTB_OFFLOAD)) + goto nla_put_failure; + if ((cl->rate.rate_bytes_ps >= (1ULL << 32)) && + nla_put_u64_64bit(skb, TCA_HTB_RATE64, cl->rate.rate_bytes_ps, + TCA_HTB_PAD)) + goto nla_put_failure; + if ((cl->ceil.rate_bytes_ps >= (1ULL << 32)) && + nla_put_u64_64bit(skb, TCA_HTB_CEIL64, cl->ceil.rate_bytes_ps, + TCA_HTB_PAD)) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static void htb_offload_aggregate_stats(struct htb_sched *q, + struct htb_class *cl) +{ + u64 bytes = 0, packets = 0; + struct htb_class *c; + unsigned int i; + + gnet_stats_basic_sync_init(&cl->bstats); + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(c, &q->clhash.hash[i], common.hnode) { + struct htb_class *p = c; + + while (p && p->level < cl->level) + p = p->parent; + + if (p != cl) + continue; + + bytes += u64_stats_read(&c->bstats_bias.bytes); + packets += u64_stats_read(&c->bstats_bias.packets); + if (c->level == 0) { + bytes += u64_stats_read(&c->leaf.q->bstats.bytes); + packets += u64_stats_read(&c->leaf.q->bstats.packets); + } + } + } + _bstats_update(&cl->bstats, bytes, packets); +} + +static int +htb_dump_class_stats(struct Qdisc *sch, unsigned long arg, struct gnet_dump *d) +{ + struct htb_class *cl = (struct htb_class *)arg; + struct htb_sched *q = qdisc_priv(sch); + struct gnet_stats_queue qs = { + .drops = cl->drops, + .overlimits = cl->overlimits, + }; + __u32 qlen = 0; + + if (!cl->level && cl->leaf.q) + qdisc_qstats_qlen_backlog(cl->leaf.q, &qlen, &qs.backlog); + + cl->xstats.tokens = clamp_t(s64, PSCHED_NS2TICKS(cl->tokens), + INT_MIN, INT_MAX); + cl->xstats.ctokens = clamp_t(s64, PSCHED_NS2TICKS(cl->ctokens), + INT_MIN, INT_MAX); + + if (q->offload) { + if (!cl->level) { + if (cl->leaf.q) + cl->bstats = cl->leaf.q->bstats; + else + gnet_stats_basic_sync_init(&cl->bstats); + _bstats_update(&cl->bstats, 
+ u64_stats_read(&cl->bstats_bias.bytes), + u64_stats_read(&cl->bstats_bias.packets)); + } else { + htb_offload_aggregate_stats(q, cl); + } + } + + if (gnet_stats_copy_basic(d, NULL, &cl->bstats, true) < 0 || + gnet_stats_copy_rate_est(d, &cl->rate_est) < 0 || + gnet_stats_copy_queue(d, NULL, &qs, qlen) < 0) + return -1; + + return gnet_stats_copy_app(d, &cl->xstats, sizeof(cl->xstats)); +} + +static struct netdev_queue * +htb_select_queue(struct Qdisc *sch, struct tcmsg *tcm) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_htb_qopt_offload offload_opt; + struct htb_sched *q = qdisc_priv(sch); + int err; + + if (!q->offload) + return sch->dev_queue; + + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_LEAF_QUERY_QUEUE, + .classid = TC_H_MIN(tcm->tcm_parent), + }; + err = htb_offload(dev, &offload_opt); + if (err || offload_opt.qid >= dev->num_tx_queues) + return NULL; + return netdev_get_tx_queue(dev, offload_opt.qid); +} + +static struct Qdisc * +htb_graft_helper(struct netdev_queue *dev_queue, struct Qdisc *new_q) +{ + struct net_device *dev = dev_queue->dev; + struct Qdisc *old_q; + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + old_q = dev_graft_qdisc(dev_queue, new_q); + if (new_q) + new_q->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + if (dev->flags & IFF_UP) + dev_activate(dev); + + return old_q; +} + +static struct netdev_queue *htb_offload_get_queue(struct htb_class *cl) +{ + struct netdev_queue *queue; + + queue = cl->leaf.offload_queue; + if (!(cl->leaf.q->flags & TCQ_F_BUILTIN)) + WARN_ON(cl->leaf.q->dev_queue != queue); + + return queue; +} + +static void htb_offload_move_qdisc(struct Qdisc *sch, struct htb_class *cl_old, + struct htb_class *cl_new, bool destroying) +{ + struct netdev_queue *queue_old, *queue_new; + struct net_device *dev = qdisc_dev(sch); + + queue_old = htb_offload_get_queue(cl_old); + queue_new = htb_offload_get_queue(cl_new); + + if (!destroying) { + struct Qdisc *qdisc; + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + qdisc = dev_graft_qdisc(queue_old, NULL); + WARN_ON(qdisc != cl_old->leaf.q); + } + + if (!(cl_old->leaf.q->flags & TCQ_F_BUILTIN)) + cl_old->leaf.q->dev_queue = queue_new; + cl_old->leaf.offload_queue = queue_new; + + if (!destroying) { + struct Qdisc *qdisc; + + qdisc = dev_graft_qdisc(queue_new, cl_old->leaf.q); + if (dev->flags & IFF_UP) + dev_activate(dev); + WARN_ON(!(qdisc->flags & TCQ_F_BUILTIN)); + } +} + +static int htb_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct netdev_queue *dev_queue = sch->dev_queue; + struct htb_class *cl = (struct htb_class *)arg; + struct htb_sched *q = qdisc_priv(sch); + struct Qdisc *old_q; + + if (cl->level) + return -EINVAL; + + if (q->offload) + dev_queue = htb_offload_get_queue(cl); + + if (!new) { + new = qdisc_create_dflt(dev_queue, &pfifo_qdisc_ops, + cl->common.classid, extack); + if (!new) + return -ENOBUFS; + } + + if (q->offload) { + htb_set_lockdep_class_child(new); + /* One ref for cl->leaf.q, the other for dev_queue->qdisc. */ + qdisc_refcount_inc(new); + old_q = htb_graft_helper(dev_queue, new); + } + + *old = qdisc_replace(sch, new, &cl->leaf.q); + + if (q->offload) { + WARN_ON(old_q != *old); + qdisc_put(old_q); + } + + return 0; +} + +static struct Qdisc *htb_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct htb_class *cl = (struct htb_class *)arg; + return !cl->level ? 
cl->leaf.q : NULL; +} + +static void htb_qlen_notify(struct Qdisc *sch, unsigned long arg) +{ + struct htb_class *cl = (struct htb_class *)arg; + + htb_deactivate(qdisc_priv(sch), cl); +} + +static inline int htb_parent_last_child(struct htb_class *cl) +{ + if (!cl->parent) + /* the root class */ + return 0; + if (cl->parent->children > 1) + /* not the last child */ + return 0; + return 1; +} + +static void htb_parent_to_leaf(struct Qdisc *sch, struct htb_class *cl, + struct Qdisc *new_q) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *parent = cl->parent; + + WARN_ON(cl->level || !cl->leaf.q || cl->prio_activity); + + if (parent->cmode != HTB_CAN_SEND) + htb_safe_rb_erase(&parent->pq_node, + &q->hlevel[parent->level].wait_pq); + + parent->level = 0; + memset(&parent->inner, 0, sizeof(parent->inner)); + parent->leaf.q = new_q ? new_q : &noop_qdisc; + parent->tokens = parent->buffer; + parent->ctokens = parent->cbuffer; + parent->t_c = ktime_get_ns(); + parent->cmode = HTB_CAN_SEND; + if (q->offload) + parent->leaf.offload_queue = cl->leaf.offload_queue; +} + +static void htb_parent_to_leaf_offload(struct Qdisc *sch, + struct netdev_queue *dev_queue, + struct Qdisc *new_q) +{ + struct Qdisc *old_q; + + /* One ref for cl->leaf.q, the other for dev_queue->qdisc. */ + if (new_q) + qdisc_refcount_inc(new_q); + old_q = htb_graft_helper(dev_queue, new_q); + WARN_ON(!(old_q->flags & TCQ_F_BUILTIN)); +} + +static int htb_destroy_class_offload(struct Qdisc *sch, struct htb_class *cl, + bool last_child, bool destroying, + struct netlink_ext_ack *extack) +{ + struct tc_htb_qopt_offload offload_opt; + struct netdev_queue *dev_queue; + struct Qdisc *q = cl->leaf.q; + struct Qdisc *old; + int err; + + if (cl->level) + return -EINVAL; + + WARN_ON(!q); + dev_queue = htb_offload_get_queue(cl); + /* When destroying, caller qdisc_graft grafts the new qdisc and invokes + * qdisc_put for the qdisc being destroyed. htb_destroy_class_offload + * does not need to graft or qdisc_put the qdisc being destroyed. + */ + if (!destroying) { + old = htb_graft_helper(dev_queue, NULL); + /* Last qdisc grafted should be the same as cl->leaf.q when + * calling htb_delete. + */ + WARN_ON(old != q); + } + + if (cl->parent) { + _bstats_update(&cl->parent->bstats_bias, + u64_stats_read(&q->bstats.bytes), + u64_stats_read(&q->bstats.packets)); + } + + offload_opt = (struct tc_htb_qopt_offload) { + .command = !last_child ? TC_HTB_LEAF_DEL : + destroying ? 
TC_HTB_LEAF_DEL_LAST_FORCE : + TC_HTB_LEAF_DEL_LAST, + .classid = cl->common.classid, + .extack = extack, + }; + err = htb_offload(qdisc_dev(sch), &offload_opt); + + if (!destroying) { + if (!err) + qdisc_put(old); + else + htb_graft_helper(dev_queue, old); + } + + if (last_child) + return err; + + if (!err && offload_opt.classid != TC_H_MIN(cl->common.classid)) { + u32 classid = TC_H_MAJ(sch->handle) | + TC_H_MIN(offload_opt.classid); + struct htb_class *moved_cl = htb_find(classid, sch); + + htb_offload_move_qdisc(sch, moved_cl, cl, destroying); + } + + return err; +} + +static void htb_destroy_class(struct Qdisc *sch, struct htb_class *cl) +{ + if (!cl->level) { + WARN_ON(!cl->leaf.q); + qdisc_put(cl->leaf.q); + } + gen_kill_estimator(&cl->rate_est); + tcf_block_put(cl->block); + kfree(cl); +} + +static void htb_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_htb_qopt_offload offload_opt; + struct htb_sched *q = qdisc_priv(sch); + struct hlist_node *next; + bool nonempty, changed; + struct htb_class *cl; + unsigned int i; + + cancel_work_sync(&q->work); + qdisc_watchdog_cancel(&q->watchdog); + /* This line used to be after htb_destroy_class call below + * and surprisingly it worked in 2.4. But it must precede it + * because filter need its target class alive to be able to call + * unbind_filter on it (without Oops). + */ + tcf_block_put(q->block); + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + tcf_block_put(cl->block); + cl->block = NULL; + } + } + + do { + nonempty = false; + changed = false; + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry_safe(cl, next, &q->clhash.hash[i], + common.hnode) { + bool last_child; + + if (!q->offload) { + htb_destroy_class(sch, cl); + continue; + } + + nonempty = true; + + if (cl->level) + continue; + + changed = true; + + last_child = htb_parent_last_child(cl); + htb_destroy_class_offload(sch, cl, last_child, + true, NULL); + qdisc_class_hash_remove(&q->clhash, + &cl->common); + if (cl->parent) + cl->parent->children--; + if (last_child) + htb_parent_to_leaf(sch, cl, NULL); + htb_destroy_class(sch, cl); + } + } + } while (changed); + WARN_ON(nonempty); + + qdisc_class_hash_destroy(&q->clhash); + __qdisc_reset_queue(&q->direct_queue); + + if (q->offload) { + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_DESTROY, + }; + htb_offload(dev, &offload_opt); + } + + if (!q->direct_qdiscs) + return; + for (i = 0; i < q->num_direct_qdiscs && q->direct_qdiscs[i]; i++) + qdisc_put(q->direct_qdiscs[i]); + kfree(q->direct_qdiscs); +} + +static int htb_delete(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl = (struct htb_class *)arg; + struct Qdisc *new_q = NULL; + int last_child = 0; + int err; + + /* TODO: why don't allow to delete subtree ? references ? does + * tc subsys guarantee us that in htb_destroy it holds no class + * refs so that we can remove children safely there ? 
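A small aside on the last-child handling used by htb_delete()/htb_destroy() above: an inner class reverts to a leaf only when its single remaining child is removed, which is what htb_parent_last_child() tests. A toy restatement with invented names, just to make the condition explicit.

#include <stdbool.h>
#include <stdio.h>

/* Toy model: only the parent pointer and child count matter for this test. */
struct toy_node {
	struct toy_node *parent;
	unsigned int children;
};

static bool parent_becomes_leaf(const struct toy_node *cl)
{
	return cl->parent && cl->parent->children == 1;
}

int main(void)
{
	struct toy_node root  = { NULL,   2 };
	struct toy_node inner = { &root,  1 };
	struct toy_node leaf  = { &inner, 0 };

	printf("%d\n", parent_becomes_leaf(&leaf));	/* 1: inner reverts */
	printf("%d\n", parent_becomes_leaf(&inner));	/* 0: root keeps a child */
	return 0;
}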
+ */ + if (cl->children || cl->filter_cnt) + return -EBUSY; + + if (!cl->level && htb_parent_last_child(cl)) + last_child = 1; + + if (q->offload) { + err = htb_destroy_class_offload(sch, cl, last_child, false, + extack); + if (err) + return err; + } + + if (last_child) { + struct netdev_queue *dev_queue = sch->dev_queue; + + if (q->offload) + dev_queue = htb_offload_get_queue(cl); + + new_q = qdisc_create_dflt(dev_queue, &pfifo_qdisc_ops, + cl->parent->common.classid, + NULL); + if (q->offload) { + if (new_q) + htb_set_lockdep_class_child(new_q); + htb_parent_to_leaf_offload(sch, dev_queue, new_q); + } + } + + sch_tree_lock(sch); + + if (!cl->level) + qdisc_purge_queue(cl->leaf.q); + + /* delete from hash and active; remainder in destroy_class */ + qdisc_class_hash_remove(&q->clhash, &cl->common); + if (cl->parent) + cl->parent->children--; + + if (cl->prio_activity) + htb_deactivate(q, cl); + + if (cl->cmode != HTB_CAN_SEND) + htb_safe_rb_erase(&cl->pq_node, + &q->hlevel[cl->level].wait_pq); + + if (last_child) + htb_parent_to_leaf(sch, cl, new_q); + + sch_tree_unlock(sch); + + htb_destroy_class(sch, cl); + return 0; +} + +static int htb_change_class(struct Qdisc *sch, u32 classid, + u32 parentid, struct nlattr **tca, + unsigned long *arg, struct netlink_ext_ack *extack) +{ + int err = -EINVAL; + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl = (struct htb_class *)*arg, *parent; + struct tc_htb_qopt_offload offload_opt; + struct nlattr *opt = tca[TCA_OPTIONS]; + struct nlattr *tb[TCA_HTB_MAX + 1]; + struct Qdisc *parent_qdisc = NULL; + struct netdev_queue *dev_queue; + struct tc_htb_opt *hopt; + u64 rate64, ceil64; + int warn = 0; + + /* extract all subattrs from opt attr */ + if (!opt) + goto failure; + + err = nla_parse_nested_deprecated(tb, TCA_HTB_MAX, opt, htb_policy, + NULL); + if (err < 0) + goto failure; + + err = -EINVAL; + if (tb[TCA_HTB_PARMS] == NULL) + goto failure; + + parent = parentid == TC_H_ROOT ? NULL : htb_find(parentid, sch); + + hopt = nla_data(tb[TCA_HTB_PARMS]); + if (!hopt->rate.rate || !hopt->ceil.rate) + goto failure; + + if (q->offload) { + /* Options not supported by the offload. */ + if (hopt->rate.overhead || hopt->ceil.overhead) { + NL_SET_ERR_MSG(extack, "HTB offload doesn't support the overhead parameter"); + goto failure; + } + if (hopt->rate.mpu || hopt->ceil.mpu) { + NL_SET_ERR_MSG(extack, "HTB offload doesn't support the mpu parameter"); + goto failure; + } + if (hopt->quantum) { + NL_SET_ERR_MSG(extack, "HTB offload doesn't support the quantum parameter"); + goto failure; + } + if (hopt->prio) { + NL_SET_ERR_MSG(extack, "HTB offload doesn't support the prio parameter"); + goto failure; + } + } + + /* Keeping backward compatible with rate_table based iproute2 tc */ + if (hopt->rate.linklayer == TC_LINKLAYER_UNAWARE) + qdisc_put_rtab(qdisc_get_rtab(&hopt->rate, tb[TCA_HTB_RTAB], + NULL)); + + if (hopt->ceil.linklayer == TC_LINKLAYER_UNAWARE) + qdisc_put_rtab(qdisc_get_rtab(&hopt->ceil, tb[TCA_HTB_CTAB], + NULL)); + + rate64 = tb[TCA_HTB_RATE64] ? nla_get_u64(tb[TCA_HTB_RATE64]) : 0; + ceil64 = tb[TCA_HTB_CEIL64] ? 
nla_get_u64(tb[TCA_HTB_CEIL64]) : 0; + + if (!cl) { /* new class */ + struct net_device *dev = qdisc_dev(sch); + struct Qdisc *new_q, *old_q; + int prio; + struct { + struct nlattr nla; + struct gnet_estimator opt; + } est = { + .nla = { + .nla_len = nla_attr_size(sizeof(est.opt)), + .nla_type = TCA_RATE, + }, + .opt = { + /* 4s interval, 16s averaging constant */ + .interval = 2, + .ewma_log = 2, + }, + }; + + /* check for valid classid */ + if (!classid || TC_H_MAJ(classid ^ sch->handle) || + htb_find(classid, sch)) + goto failure; + + /* check maximal depth */ + if (parent && parent->parent && parent->parent->level < 2) { + pr_err("htb: tree is too deep\n"); + goto failure; + } + err = -ENOBUFS; + cl = kzalloc(sizeof(*cl), GFP_KERNEL); + if (!cl) + goto failure; + + gnet_stats_basic_sync_init(&cl->bstats); + gnet_stats_basic_sync_init(&cl->bstats_bias); + + err = tcf_block_get(&cl->block, &cl->filter_list, sch, extack); + if (err) { + kfree(cl); + goto failure; + } + if (htb_rate_est || tca[TCA_RATE]) { + err = gen_new_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE] ? : &est.nla); + if (err) + goto err_block_put; + } + + cl->children = 0; + RB_CLEAR_NODE(&cl->pq_node); + + for (prio = 0; prio < TC_HTB_NUMPRIO; prio++) + RB_CLEAR_NODE(&cl->node[prio]); + + cl->common.classid = classid; + + /* Make sure nothing interrupts us in between of two + * ndo_setup_tc calls. + */ + ASSERT_RTNL(); + + /* create leaf qdisc early because it uses kmalloc(GFP_KERNEL) + * so that can't be used inside of sch_tree_lock + * -- thanks to Karlis Peisenieks + */ + if (!q->offload) { + dev_queue = sch->dev_queue; + } else if (!(parent && !parent->level)) { + /* Assign a dev_queue to this classid. */ + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_LEAF_ALLOC_QUEUE, + .classid = cl->common.classid, + .parent_classid = parent ? + TC_H_MIN(parent->common.classid) : + TC_HTB_CLASSID_ROOT, + .rate = max_t(u64, hopt->rate.rate, rate64), + .ceil = max_t(u64, hopt->ceil.rate, ceil64), + .extack = extack, + }; + err = htb_offload(dev, &offload_opt); + if (err) { + pr_err("htb: TC_HTB_LEAF_ALLOC_QUEUE failed with err = %d\n", + err); + goto err_kill_estimator; + } + dev_queue = netdev_get_tx_queue(dev, offload_opt.qid); + } else { /* First child. */ + dev_queue = htb_offload_get_queue(parent); + old_q = htb_graft_helper(dev_queue, NULL); + WARN_ON(old_q != parent->leaf.q); + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_LEAF_TO_INNER, + .classid = cl->common.classid, + .parent_classid = + TC_H_MIN(parent->common.classid), + .rate = max_t(u64, hopt->rate.rate, rate64), + .ceil = max_t(u64, hopt->ceil.rate, ceil64), + .extack = extack, + }; + err = htb_offload(dev, &offload_opt); + if (err) { + pr_err("htb: TC_HTB_LEAF_TO_INNER failed with err = %d\n", + err); + htb_graft_helper(dev_queue, old_q); + goto err_kill_estimator; + } + _bstats_update(&parent->bstats_bias, + u64_stats_read(&old_q->bstats.bytes), + u64_stats_read(&old_q->bstats.packets)); + qdisc_put(old_q); + } + new_q = qdisc_create_dflt(dev_queue, &pfifo_qdisc_ops, + classid, NULL); + if (q->offload) { + if (new_q) { + htb_set_lockdep_class_child(new_q); + /* One ref for cl->leaf.q, the other for + * dev_queue->qdisc. + */ + qdisc_refcount_inc(new_q); + } + old_q = htb_graft_helper(dev_queue, new_q); + /* No qdisc_put needed. 
*/ + WARN_ON(!(old_q->flags & TCQ_F_BUILTIN)); + } + sch_tree_lock(sch); + if (parent && !parent->level) { + /* turn parent into inner node */ + qdisc_purge_queue(parent->leaf.q); + parent_qdisc = parent->leaf.q; + if (parent->prio_activity) + htb_deactivate(q, parent); + + /* remove from evt list because of level change */ + if (parent->cmode != HTB_CAN_SEND) { + htb_safe_rb_erase(&parent->pq_node, &q->hlevel[0].wait_pq); + parent->cmode = HTB_CAN_SEND; + } + parent->level = (parent->parent ? parent->parent->level + : TC_HTB_MAXDEPTH) - 1; + memset(&parent->inner, 0, sizeof(parent->inner)); + } + + /* leaf (we) needs elementary qdisc */ + cl->leaf.q = new_q ? new_q : &noop_qdisc; + if (q->offload) + cl->leaf.offload_queue = dev_queue; + + cl->parent = parent; + + /* set class to be in HTB_CAN_SEND state */ + cl->tokens = PSCHED_TICKS2NS(hopt->buffer); + cl->ctokens = PSCHED_TICKS2NS(hopt->cbuffer); + cl->mbuffer = 60ULL * NSEC_PER_SEC; /* 1min */ + cl->t_c = ktime_get_ns(); + cl->cmode = HTB_CAN_SEND; + + /* attach to the hash list and parent's family */ + qdisc_class_hash_insert(&q->clhash, &cl->common); + if (parent) + parent->children++; + if (cl->leaf.q != &noop_qdisc) + qdisc_hash_add(cl->leaf.q, true); + } else { + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) + return err; + } + + if (q->offload) { + struct net_device *dev = qdisc_dev(sch); + + offload_opt = (struct tc_htb_qopt_offload) { + .command = TC_HTB_NODE_MODIFY, + .classid = cl->common.classid, + .rate = max_t(u64, hopt->rate.rate, rate64), + .ceil = max_t(u64, hopt->ceil.rate, ceil64), + .extack = extack, + }; + err = htb_offload(dev, &offload_opt); + if (err) + /* Estimator was replaced, and rollback may fail + * as well, so we don't try to recover it, and + * the estimator won't work property with the + * offload anyway, because bstats are updated + * only when the stats are queried. + */ + return err; + } + + sch_tree_lock(sch); + } + + psched_ratecfg_precompute(&cl->rate, &hopt->rate, rate64); + psched_ratecfg_precompute(&cl->ceil, &hopt->ceil, ceil64); + + /* it used to be a nasty bug here, we have to check that node + * is really leaf before changing cl->leaf ! + */ + if (!cl->level) { + u64 quantum = cl->rate.rate_bytes_ps; + + do_div(quantum, q->rate2quantum); + cl->quantum = min_t(u64, quantum, INT_MAX); + + if (!hopt->quantum && cl->quantum < 1000) { + warn = -1; + cl->quantum = 1000; + } + if (!hopt->quantum && cl->quantum > 200000) { + warn = 1; + cl->quantum = 200000; + } + if (hopt->quantum) + cl->quantum = hopt->quantum; + if ((cl->prio = hopt->prio) >= TC_HTB_NUMPRIO) + cl->prio = TC_HTB_NUMPRIO - 1; + } + + cl->buffer = PSCHED_TICKS2NS(hopt->buffer); + cl->cbuffer = PSCHED_TICKS2NS(hopt->cbuffer); + + sch_tree_unlock(sch); + qdisc_put(parent_qdisc); + + if (warn) + pr_warn("HTB: quantum of class %X is %s. Consider r2q change.\n", + cl->common.classid, (warn == -1 ? "small" : "big")); + + qdisc_class_hash_grow(sch, &q->clhash); + + *arg = (unsigned long)cl; + return 0; + +err_kill_estimator: + gen_kill_estimator(&cl->rate_est); +err_block_put: + tcf_block_put(cl->block); + kfree(cl); +failure: + return err; +} + +static struct tcf_block *htb_tcf_block(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl = (struct htb_class *)arg; + + return cl ? 
cl->block : q->block; +} + +static unsigned long htb_bind_filter(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + struct htb_class *cl = htb_find(classid, sch); + + /*if (cl && !cl->level) return 0; + * The line above used to be there to prevent attaching filters to + * leaves. But at least tc_index filter uses this just to get class + * for other reasons so that we have to allow for it. + * ---- + * 19.6.2002 As Werner explained it is ok - bind filter is just + * another way to "lock" the class - unlike "get" this lock can + * be broken by class during destroy IIUC. + */ + if (cl) + cl->filter_cnt++; + return (unsigned long)cl; +} + +static void htb_unbind_filter(struct Qdisc *sch, unsigned long arg) +{ + struct htb_class *cl = (struct htb_class *)arg; + + if (cl) + cl->filter_cnt--; +} + +static void htb_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct htb_sched *q = qdisc_priv(sch); + struct htb_class *cl; + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)cl, arg)) + return; + } + } +} + +static const struct Qdisc_class_ops htb_class_ops = { + .select_queue = htb_select_queue, + .graft = htb_graft, + .leaf = htb_leaf, + .qlen_notify = htb_qlen_notify, + .find = htb_search, + .change = htb_change_class, + .delete = htb_delete, + .walk = htb_walk, + .tcf_block = htb_tcf_block, + .bind_tcf = htb_bind_filter, + .unbind_tcf = htb_unbind_filter, + .dump = htb_dump_class, + .dump_stats = htb_dump_class_stats, +}; + +static struct Qdisc_ops htb_qdisc_ops __read_mostly = { + .cl_ops = &htb_class_ops, + .id = "htb", + .priv_size = sizeof(struct htb_sched), + .enqueue = htb_enqueue, + .dequeue = htb_dequeue, + .peek = qdisc_peek_dequeued, + .init = htb_init, + .attach = htb_attach, + .reset = htb_reset, + .destroy = htb_destroy, + .dump = htb_dump, + .owner = THIS_MODULE, +}; + +static int __init htb_module_init(void) +{ + return register_qdisc(&htb_qdisc_ops); +} +static void __exit htb_module_exit(void) +{ + unregister_qdisc(&htb_qdisc_ops); +} + +module_init(htb_module_init) +module_exit(htb_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_ingress.c b/net/sched/sch_ingress.c new file mode 100644 index 000000000..e43a45499 --- /dev/null +++ b/net/sched/sch_ingress.c @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* net/sched/sch_ingress.c - Ingress and clsact qdisc + * + * Authors: Jamal Hadi Salim 1999 + */ + +#include +#include +#include +#include +#include + +#include +#include +#include + +struct ingress_sched_data { + struct tcf_block *block; + struct tcf_block_ext_info block_info; + struct mini_Qdisc_pair miniqp; +}; + +static struct Qdisc *ingress_leaf(struct Qdisc *sch, unsigned long arg) +{ + return NULL; +} + +static unsigned long ingress_find(struct Qdisc *sch, u32 classid) +{ + return TC_H_MIN(classid) + 1; +} + +static unsigned long ingress_bind_filter(struct Qdisc *sch, + unsigned long parent, u32 classid) +{ + return ingress_find(sch, classid); +} + +static void ingress_unbind_filter(struct Qdisc *sch, unsigned long cl) +{ +} + +static void ingress_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ +} + +static struct tcf_block *ingress_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct ingress_sched_data *q = qdisc_priv(sch); + + return q->block; +} + +static void clsact_chain_head_change(struct tcf_proto *tp_head, void 
*priv) +{ + struct mini_Qdisc_pair *miniqp = priv; + + mini_qdisc_pair_swap(miniqp, tp_head); +}; + +static void ingress_ingress_block_set(struct Qdisc *sch, u32 block_index) +{ + struct ingress_sched_data *q = qdisc_priv(sch); + + q->block_info.block_index = block_index; +} + +static u32 ingress_ingress_block_get(struct Qdisc *sch) +{ + struct ingress_sched_data *q = qdisc_priv(sch); + + return q->block_info.block_index; +} + +static int ingress_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct ingress_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int err; + + if (sch->parent != TC_H_INGRESS) + return -EOPNOTSUPP; + + net_inc_ingress_queue(); + + mini_qdisc_pair_init(&q->miniqp, sch, &dev->miniq_ingress); + + q->block_info.binder_type = FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS; + q->block_info.chain_head_change = clsact_chain_head_change; + q->block_info.chain_head_change_priv = &q->miniqp; + + err = tcf_block_get_ext(&q->block, sch, &q->block_info, extack); + if (err) + return err; + + mini_qdisc_pair_block_init(&q->miniqp, q->block); + + return 0; +} + +static void ingress_destroy(struct Qdisc *sch) +{ + struct ingress_sched_data *q = qdisc_priv(sch); + + if (sch->parent != TC_H_INGRESS) + return; + + tcf_block_put_ext(q->block, sch, &q->block_info); + net_dec_ingress_queue(); +} + +static int ingress_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static const struct Qdisc_class_ops ingress_class_ops = { + .flags = QDISC_CLASS_OPS_DOIT_UNLOCKED, + .leaf = ingress_leaf, + .find = ingress_find, + .walk = ingress_walk, + .tcf_block = ingress_tcf_block, + .bind_tcf = ingress_bind_filter, + .unbind_tcf = ingress_unbind_filter, +}; + +static struct Qdisc_ops ingress_qdisc_ops __read_mostly = { + .cl_ops = &ingress_class_ops, + .id = "ingress", + .priv_size = sizeof(struct ingress_sched_data), + .static_flags = TCQ_F_INGRESS | TCQ_F_CPUSTATS, + .init = ingress_init, + .destroy = ingress_destroy, + .dump = ingress_dump, + .ingress_block_set = ingress_ingress_block_set, + .ingress_block_get = ingress_ingress_block_get, + .owner = THIS_MODULE, +}; + +struct clsact_sched_data { + struct tcf_block *ingress_block; + struct tcf_block *egress_block; + struct tcf_block_ext_info ingress_block_info; + struct tcf_block_ext_info egress_block_info; + struct mini_Qdisc_pair miniqp_ingress; + struct mini_Qdisc_pair miniqp_egress; +}; + +static unsigned long clsact_find(struct Qdisc *sch, u32 classid) +{ + switch (TC_H_MIN(classid)) { + case TC_H_MIN(TC_H_MIN_INGRESS): + case TC_H_MIN(TC_H_MIN_EGRESS): + return TC_H_MIN(classid); + default: + return 0; + } +} + +static unsigned long clsact_bind_filter(struct Qdisc *sch, + unsigned long parent, u32 classid) +{ + return clsact_find(sch, classid); +} + +static struct tcf_block *clsact_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + switch (cl) { + case TC_H_MIN(TC_H_MIN_INGRESS): + return q->ingress_block; + case TC_H_MIN(TC_H_MIN_EGRESS): + return q->egress_block; + default: + return NULL; + } +} + +static void clsact_ingress_block_set(struct Qdisc *sch, u32 block_index) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + q->ingress_block_info.block_index = block_index; +} + +static 
void clsact_egress_block_set(struct Qdisc *sch, u32 block_index) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + q->egress_block_info.block_index = block_index; +} + +static u32 clsact_ingress_block_get(struct Qdisc *sch) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + return q->ingress_block_info.block_index; +} + +static u32 clsact_egress_block_get(struct Qdisc *sch) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + return q->egress_block_info.block_index; +} + +static int clsact_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int err; + + if (sch->parent != TC_H_CLSACT) + return -EOPNOTSUPP; + + net_inc_ingress_queue(); + net_inc_egress_queue(); + + mini_qdisc_pair_init(&q->miniqp_ingress, sch, &dev->miniq_ingress); + + q->ingress_block_info.binder_type = FLOW_BLOCK_BINDER_TYPE_CLSACT_INGRESS; + q->ingress_block_info.chain_head_change = clsact_chain_head_change; + q->ingress_block_info.chain_head_change_priv = &q->miniqp_ingress; + + err = tcf_block_get_ext(&q->ingress_block, sch, &q->ingress_block_info, + extack); + if (err) + return err; + + mini_qdisc_pair_block_init(&q->miniqp_ingress, q->ingress_block); + + mini_qdisc_pair_init(&q->miniqp_egress, sch, &dev->miniq_egress); + + q->egress_block_info.binder_type = FLOW_BLOCK_BINDER_TYPE_CLSACT_EGRESS; + q->egress_block_info.chain_head_change = clsact_chain_head_change; + q->egress_block_info.chain_head_change_priv = &q->miniqp_egress; + + return tcf_block_get_ext(&q->egress_block, sch, &q->egress_block_info, extack); +} + +static void clsact_destroy(struct Qdisc *sch) +{ + struct clsact_sched_data *q = qdisc_priv(sch); + + if (sch->parent != TC_H_CLSACT) + return; + + tcf_block_put_ext(q->egress_block, sch, &q->egress_block_info); + tcf_block_put_ext(q->ingress_block, sch, &q->ingress_block_info); + + net_dec_ingress_queue(); + net_dec_egress_queue(); +} + +static const struct Qdisc_class_ops clsact_class_ops = { + .flags = QDISC_CLASS_OPS_DOIT_UNLOCKED, + .leaf = ingress_leaf, + .find = clsact_find, + .walk = ingress_walk, + .tcf_block = clsact_tcf_block, + .bind_tcf = clsact_bind_filter, + .unbind_tcf = ingress_unbind_filter, +}; + +static struct Qdisc_ops clsact_qdisc_ops __read_mostly = { + .cl_ops = &clsact_class_ops, + .id = "clsact", + .priv_size = sizeof(struct clsact_sched_data), + .static_flags = TCQ_F_INGRESS | TCQ_F_CPUSTATS, + .init = clsact_init, + .destroy = clsact_destroy, + .dump = ingress_dump, + .ingress_block_set = clsact_ingress_block_set, + .egress_block_set = clsact_egress_block_set, + .ingress_block_get = clsact_ingress_block_get, + .egress_block_get = clsact_egress_block_get, + .owner = THIS_MODULE, +}; + +static int __init ingress_module_init(void) +{ + int ret; + + ret = register_qdisc(&ingress_qdisc_ops); + if (!ret) { + ret = register_qdisc(&clsact_qdisc_ops); + if (ret) + unregister_qdisc(&ingress_qdisc_ops); + } + + return ret; +} + +static void __exit ingress_module_exit(void) +{ + unregister_qdisc(&ingress_qdisc_ops); + unregister_qdisc(&clsact_qdisc_ops); +} + +module_init(ingress_module_init); +module_exit(ingress_module_exit); + +MODULE_ALIAS("sch_clsact"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_mq.c b/net/sched/sch_mq.c new file mode 100644 index 000000000..c860119a8 --- /dev/null +++ b/net/sched/sch_mq.c @@ -0,0 +1,275 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_mq.c Classful multiqueue dummy scheduler + * + * 
Copyright (c) 2009 Patrick McHardy + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct mq_sched { + struct Qdisc **qdiscs; +}; + +static int mq_offload(struct Qdisc *sch, enum tc_mq_command cmd) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_mq_qopt_offload opt = { + .command = cmd, + .handle = sch->handle, + }; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return -EOPNOTSUPP; + + return dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_MQ, &opt); +} + +static int mq_offload_stats(struct Qdisc *sch) +{ + struct tc_mq_qopt_offload opt = { + .command = TC_MQ_STATS, + .handle = sch->handle, + .stats = { + .bstats = &sch->bstats, + .qstats = &sch->qstats, + }, + }; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_MQ, &opt); +} + +static void mq_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct mq_sched *priv = qdisc_priv(sch); + unsigned int ntx; + + mq_offload(sch, TC_MQ_DESTROY); + + if (!priv->qdiscs) + return; + for (ntx = 0; ntx < dev->num_tx_queues && priv->qdiscs[ntx]; ntx++) + qdisc_put(priv->qdiscs[ntx]); + kfree(priv->qdiscs); +} + +static int mq_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct mq_sched *priv = qdisc_priv(sch); + struct netdev_queue *dev_queue; + struct Qdisc *qdisc; + unsigned int ntx; + + if (sch->parent != TC_H_ROOT) + return -EOPNOTSUPP; + + if (!netif_is_multiqueue(dev)) + return -EOPNOTSUPP; + + /* pre-allocate qdiscs, attachment can't fail */ + priv->qdiscs = kcalloc(dev->num_tx_queues, sizeof(priv->qdiscs[0]), + GFP_KERNEL); + if (!priv->qdiscs) + return -ENOMEM; + + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + dev_queue = netdev_get_tx_queue(dev, ntx); + qdisc = qdisc_create_dflt(dev_queue, get_default_qdisc_ops(dev, ntx), + TC_H_MAKE(TC_H_MAJ(sch->handle), + TC_H_MIN(ntx + 1)), + extack); + if (!qdisc) + return -ENOMEM; + priv->qdiscs[ntx] = qdisc; + qdisc->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + } + + sch->flags |= TCQ_F_MQROOT; + + mq_offload(sch, TC_MQ_CREATE); + return 0; +} + +static void mq_attach(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct mq_sched *priv = qdisc_priv(sch); + struct Qdisc *qdisc, *old; + unsigned int ntx; + + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + qdisc = priv->qdiscs[ntx]; + old = dev_graft_qdisc(qdisc->dev_queue, qdisc); + if (old) + qdisc_put(old); +#ifdef CONFIG_NET_SCHED + if (ntx < dev->real_num_tx_queues) + qdisc_hash_add(qdisc, false); +#endif + + } + kfree(priv->qdiscs); + priv->qdiscs = NULL; +} + +static int mq_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct net_device *dev = qdisc_dev(sch); + struct Qdisc *qdisc; + unsigned int ntx; + + sch->q.qlen = 0; + gnet_stats_basic_sync_init(&sch->bstats); + memset(&sch->qstats, 0, sizeof(sch->qstats)); + + /* MQ supports lockless qdiscs. However, statistics accounting needs + * to account for all, none, or a mix of locked and unlocked child + * qdiscs. Percpu stats are added to counters in-band and locking + * qdisc totals are added at end. 
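+ * (Each child qdisc is summed under its own qdisc_lock() in the loop below.)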
+ */ + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + qdisc = rtnl_dereference(netdev_get_tx_queue(dev, ntx)->qdisc_sleeping); + spin_lock_bh(qdisc_lock(qdisc)); + + gnet_stats_add_basic(&sch->bstats, qdisc->cpu_bstats, + &qdisc->bstats, false); + gnet_stats_add_queue(&sch->qstats, qdisc->cpu_qstats, + &qdisc->qstats); + sch->q.qlen += qdisc_qlen(qdisc); + + spin_unlock_bh(qdisc_lock(qdisc)); + } + + return mq_offload_stats(sch); +} + +static struct netdev_queue *mq_queue_get(struct Qdisc *sch, unsigned long cl) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx = cl - 1; + + if (ntx >= dev->num_tx_queues) + return NULL; + return netdev_get_tx_queue(dev, ntx); +} + +static struct netdev_queue *mq_select_queue(struct Qdisc *sch, + struct tcmsg *tcm) +{ + return mq_queue_get(sch, TC_H_MIN(tcm->tcm_parent)); +} + +static int mq_graft(struct Qdisc *sch, unsigned long cl, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct netdev_queue *dev_queue = mq_queue_get(sch, cl); + struct tc_mq_qopt_offload graft_offload; + struct net_device *dev = qdisc_dev(sch); + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + + *old = dev_graft_qdisc(dev_queue, new); + if (new) + new->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + if (dev->flags & IFF_UP) + dev_activate(dev); + + graft_offload.handle = sch->handle; + graft_offload.graft_params.queue = cl - 1; + graft_offload.graft_params.child_handle = new ? new->handle : 0; + graft_offload.command = TC_MQ_GRAFT; + + qdisc_offload_graft_helper(qdisc_dev(sch), sch, new, *old, + TC_SETUP_QDISC_MQ, &graft_offload, extack); + return 0; +} + +static struct Qdisc *mq_leaf(struct Qdisc *sch, unsigned long cl) +{ + struct netdev_queue *dev_queue = mq_queue_get(sch, cl); + + return rtnl_dereference(dev_queue->qdisc_sleeping); +} + +static unsigned long mq_find(struct Qdisc *sch, u32 classid) +{ + unsigned int ntx = TC_H_MIN(classid); + + if (!mq_queue_get(sch, ntx)) + return 0; + return ntx; +} + +static int mq_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct netdev_queue *dev_queue = mq_queue_get(sch, cl); + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle |= TC_H_MIN(cl); + tcm->tcm_info = rtnl_dereference(dev_queue->qdisc_sleeping)->handle; + return 0; +} + +static int mq_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct netdev_queue *dev_queue = mq_queue_get(sch, cl); + + sch = rtnl_dereference(dev_queue->qdisc_sleeping); + if (gnet_stats_copy_basic(d, sch->cpu_bstats, &sch->bstats, true) < 0 || + qdisc_qstats_copy(d, sch) < 0) + return -1; + return 0; +} + +static void mq_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned int ntx; + + if (arg->stop) + return; + + arg->count = arg->skip; + for (ntx = arg->skip; ntx < dev->num_tx_queues; ntx++) { + if (!tc_qdisc_stats_dump(sch, ntx + 1, arg)) + break; + } +} + +static const struct Qdisc_class_ops mq_class_ops = { + .select_queue = mq_select_queue, + .graft = mq_graft, + .leaf = mq_leaf, + .find = mq_find, + .walk = mq_walk, + .dump = mq_dump_class, + .dump_stats = mq_dump_class_stats, +}; + +struct Qdisc_ops mq_qdisc_ops __read_mostly = { + .cl_ops = &mq_class_ops, + .id = "mq", + .priv_size = sizeof(struct mq_sched), + .init = mq_init, + .destroy = mq_destroy, + .attach = mq_attach, + .change_real_num_tx = mq_change_real_num_tx, + .dump = mq_dump, + .owner = THIS_MODULE, +}; diff --git a/net/sched/sch_mqprio.c 
b/net/sched/sch_mqprio.c new file mode 100644 index 000000000..b99e0fe02 --- /dev/null +++ b/net/sched/sch_mqprio.c @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_mqprio.c + * + * Copyright (c) 2010 John Fastabend + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct mqprio_sched { + struct Qdisc **qdiscs; + u16 mode; + u16 shaper; + int hw_offload; + u32 flags; + u64 min_rate[TC_QOPT_MAX_QUEUE]; + u64 max_rate[TC_QOPT_MAX_QUEUE]; +}; + +static void mqprio_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct mqprio_sched *priv = qdisc_priv(sch); + unsigned int ntx; + + if (priv->qdiscs) { + for (ntx = 0; + ntx < dev->num_tx_queues && priv->qdiscs[ntx]; + ntx++) + qdisc_put(priv->qdiscs[ntx]); + kfree(priv->qdiscs); + } + + if (priv->hw_offload && dev->netdev_ops->ndo_setup_tc) { + struct tc_mqprio_qopt_offload mqprio = { { 0 } }; + + switch (priv->mode) { + case TC_MQPRIO_MODE_DCB: + case TC_MQPRIO_MODE_CHANNEL: + dev->netdev_ops->ndo_setup_tc(dev, + TC_SETUP_QDISC_MQPRIO, + &mqprio); + break; + default: + return; + } + } else { + netdev_set_num_tc(dev, 0); + } +} + +static int mqprio_parse_opt(struct net_device *dev, struct tc_mqprio_qopt *qopt) +{ + int i, j; + + /* Verify num_tc is not out of max range */ + if (qopt->num_tc > TC_MAX_QUEUE) + return -EINVAL; + + /* Verify priority mapping uses valid tcs */ + for (i = 0; i < TC_BITMASK + 1; i++) { + if (qopt->prio_tc_map[i] >= qopt->num_tc) + return -EINVAL; + } + + /* Limit qopt->hw to maximum supported offload value. Drivers have + * the option of overriding this later if they don't support the a + * given offload type. + */ + if (qopt->hw > TC_MQPRIO_HW_OFFLOAD_MAX) + qopt->hw = TC_MQPRIO_HW_OFFLOAD_MAX; + + /* If hardware offload is requested we will leave it to the device + * to either populate the queue counts itself or to validate the + * provided queue counts. If ndo_setup_tc is not present then + * hardware doesn't support offload and we should return an error. + */ + if (qopt->hw) + return dev->netdev_ops->ndo_setup_tc ? 0 : -EINVAL; + + for (i = 0; i < qopt->num_tc; i++) { + unsigned int last = qopt->offset[i] + qopt->count[i]; + + /* Verify the queue count is in tx range being equal to the + * real_num_tx_queues indicates the last queue is in use. 
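+ * In other words, each offset must lie within real_num_tx_queues and
+ * offset + count must not run past it.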
+ */ + if (qopt->offset[i] >= dev->real_num_tx_queues || + !qopt->count[i] || + last > dev->real_num_tx_queues) + return -EINVAL; + + /* Verify that the offset and counts do not overlap */ + for (j = i + 1; j < qopt->num_tc; j++) { + if (last > qopt->offset[j]) + return -EINVAL; + } + } + + return 0; +} + +static const struct nla_policy mqprio_policy[TCA_MQPRIO_MAX + 1] = { + [TCA_MQPRIO_MODE] = { .len = sizeof(u16) }, + [TCA_MQPRIO_SHAPER] = { .len = sizeof(u16) }, + [TCA_MQPRIO_MIN_RATE64] = { .type = NLA_NESTED }, + [TCA_MQPRIO_MAX_RATE64] = { .type = NLA_NESTED }, +}; + +static int parse_attr(struct nlattr *tb[], int maxtype, struct nlattr *nla, + const struct nla_policy *policy, int len) +{ + int nested_len = nla_len(nla) - NLA_ALIGN(len); + + if (nested_len >= nla_attr_size(0)) + return nla_parse_deprecated(tb, maxtype, + nla_data(nla) + NLA_ALIGN(len), + nested_len, policy, NULL); + + memset(tb, 0, sizeof(struct nlattr *) * (maxtype + 1)); + return 0; +} + +static int mqprio_parse_nlattr(struct Qdisc *sch, struct tc_mqprio_qopt *qopt, + struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct mqprio_sched *priv = qdisc_priv(sch); + struct nlattr *tb[TCA_MQPRIO_MAX + 1]; + struct nlattr *attr; + int i, rem, err; + + err = parse_attr(tb, TCA_MQPRIO_MAX, opt, mqprio_policy, + sizeof(*qopt)); + if (err < 0) + return err; + + if (!qopt->hw) { + NL_SET_ERR_MSG(extack, + "mqprio TCA_OPTIONS can only contain netlink attributes in hardware mode"); + return -EINVAL; + } + + if (tb[TCA_MQPRIO_MODE]) { + priv->flags |= TC_MQPRIO_F_MODE; + priv->mode = *(u16 *)nla_data(tb[TCA_MQPRIO_MODE]); + } + + if (tb[TCA_MQPRIO_SHAPER]) { + priv->flags |= TC_MQPRIO_F_SHAPER; + priv->shaper = *(u16 *)nla_data(tb[TCA_MQPRIO_SHAPER]); + } + + if (tb[TCA_MQPRIO_MIN_RATE64]) { + if (priv->shaper != TC_MQPRIO_SHAPER_BW_RATE) { + NL_SET_ERR_MSG_ATTR(extack, tb[TCA_MQPRIO_MIN_RATE64], + "min_rate accepted only when shaper is in bw_rlimit mode"); + return -EINVAL; + } + i = 0; + nla_for_each_nested(attr, tb[TCA_MQPRIO_MIN_RATE64], + rem) { + if (nla_type(attr) != TCA_MQPRIO_MIN_RATE64) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Attribute type expected to be TCA_MQPRIO_MIN_RATE64"); + return -EINVAL; + } + + if (nla_len(attr) != sizeof(u64)) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Attribute TCA_MQPRIO_MIN_RATE64 expected to have 8 bytes length"); + return -EINVAL; + } + + if (i >= qopt->num_tc) + break; + priv->min_rate[i] = *(u64 *)nla_data(attr); + i++; + } + priv->flags |= TC_MQPRIO_F_MIN_RATE; + } + + if (tb[TCA_MQPRIO_MAX_RATE64]) { + if (priv->shaper != TC_MQPRIO_SHAPER_BW_RATE) { + NL_SET_ERR_MSG_ATTR(extack, tb[TCA_MQPRIO_MAX_RATE64], + "max_rate accepted only when shaper is in bw_rlimit mode"); + return -EINVAL; + } + i = 0; + nla_for_each_nested(attr, tb[TCA_MQPRIO_MAX_RATE64], + rem) { + if (nla_type(attr) != TCA_MQPRIO_MAX_RATE64) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Attribute type expected to be TCA_MQPRIO_MAX_RATE64"); + return -EINVAL; + } + + if (nla_len(attr) != sizeof(u64)) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "Attribute TCA_MQPRIO_MAX_RATE64 expected to have 8 bytes length"); + return -EINVAL; + } + + if (i >= qopt->num_tc) + break; + priv->max_rate[i] = *(u64 *)nla_data(attr); + i++; + } + priv->flags |= TC_MQPRIO_F_MAX_RATE; + } + + return 0; +} + +static int mqprio_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct mqprio_sched *priv = qdisc_priv(sch); + struct netdev_queue *dev_queue; + struct Qdisc 
*qdisc; + int i, err = -EOPNOTSUPP; + struct tc_mqprio_qopt *qopt = NULL; + int len; + + BUILD_BUG_ON(TC_MAX_QUEUE != TC_QOPT_MAX_QUEUE); + BUILD_BUG_ON(TC_BITMASK != TC_QOPT_BITMASK); + + if (sch->parent != TC_H_ROOT) + return -EOPNOTSUPP; + + if (!netif_is_multiqueue(dev)) + return -EOPNOTSUPP; + + /* make certain can allocate enough classids to handle queues */ + if (dev->num_tx_queues >= TC_H_MIN_PRIORITY) + return -ENOMEM; + + if (!opt || nla_len(opt) < sizeof(*qopt)) + return -EINVAL; + + qopt = nla_data(opt); + if (mqprio_parse_opt(dev, qopt)) + return -EINVAL; + + len = nla_len(opt) - NLA_ALIGN(sizeof(*qopt)); + if (len > 0) { + err = mqprio_parse_nlattr(sch, qopt, opt, extack); + if (err) + return err; + } + + /* pre-allocate qdisc, attachment can't fail */ + priv->qdiscs = kcalloc(dev->num_tx_queues, sizeof(priv->qdiscs[0]), + GFP_KERNEL); + if (!priv->qdiscs) + return -ENOMEM; + + for (i = 0; i < dev->num_tx_queues; i++) { + dev_queue = netdev_get_tx_queue(dev, i); + qdisc = qdisc_create_dflt(dev_queue, + get_default_qdisc_ops(dev, i), + TC_H_MAKE(TC_H_MAJ(sch->handle), + TC_H_MIN(i + 1)), extack); + if (!qdisc) + return -ENOMEM; + + priv->qdiscs[i] = qdisc; + qdisc->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + } + + /* If the mqprio options indicate that hardware should own + * the queue mapping then run ndo_setup_tc otherwise use the + * supplied and verified mapping + */ + if (qopt->hw) { + struct tc_mqprio_qopt_offload mqprio = {.qopt = *qopt}; + + switch (priv->mode) { + case TC_MQPRIO_MODE_DCB: + if (priv->shaper != TC_MQPRIO_SHAPER_DCB) + return -EINVAL; + break; + case TC_MQPRIO_MODE_CHANNEL: + mqprio.flags = priv->flags; + if (priv->flags & TC_MQPRIO_F_MODE) + mqprio.mode = priv->mode; + if (priv->flags & TC_MQPRIO_F_SHAPER) + mqprio.shaper = priv->shaper; + if (priv->flags & TC_MQPRIO_F_MIN_RATE) + for (i = 0; i < mqprio.qopt.num_tc; i++) + mqprio.min_rate[i] = priv->min_rate[i]; + if (priv->flags & TC_MQPRIO_F_MAX_RATE) + for (i = 0; i < mqprio.qopt.num_tc; i++) + mqprio.max_rate[i] = priv->max_rate[i]; + break; + default: + return -EINVAL; + } + err = dev->netdev_ops->ndo_setup_tc(dev, + TC_SETUP_QDISC_MQPRIO, + &mqprio); + if (err) + return err; + + priv->hw_offload = mqprio.qopt.hw; + } else { + netdev_set_num_tc(dev, qopt->num_tc); + for (i = 0; i < qopt->num_tc; i++) + netdev_set_tc_queue(dev, i, + qopt->count[i], qopt->offset[i]); + } + + /* Always use supplied priority mappings */ + for (i = 0; i < TC_BITMASK + 1; i++) + netdev_set_prio_tc_map(dev, i, qopt->prio_tc_map[i]); + + sch->flags |= TCQ_F_MQROOT; + return 0; +} + +static void mqprio_attach(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct mqprio_sched *priv = qdisc_priv(sch); + struct Qdisc *qdisc, *old; + unsigned int ntx; + + /* Attach underlying qdisc */ + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + qdisc = priv->qdiscs[ntx]; + old = dev_graft_qdisc(qdisc->dev_queue, qdisc); + if (old) + qdisc_put(old); + if (ntx < dev->real_num_tx_queues) + qdisc_hash_add(qdisc, false); + } + kfree(priv->qdiscs); + priv->qdiscs = NULL; +} + +static struct netdev_queue *mqprio_queue_get(struct Qdisc *sch, + unsigned long cl) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx = cl - 1; + + if (ntx >= dev->num_tx_queues) + return NULL; + return netdev_get_tx_queue(dev, ntx); +} + +static int mqprio_graft(struct Qdisc *sch, unsigned long cl, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct 
netdev_queue *dev_queue = mqprio_queue_get(sch, cl); + + if (!dev_queue) + return -EINVAL; + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + + *old = dev_graft_qdisc(dev_queue, new); + + if (new) + new->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + + if (dev->flags & IFF_UP) + dev_activate(dev); + + return 0; +} + +static int dump_rates(struct mqprio_sched *priv, + struct tc_mqprio_qopt *opt, struct sk_buff *skb) +{ + struct nlattr *nest; + int i; + + if (priv->flags & TC_MQPRIO_F_MIN_RATE) { + nest = nla_nest_start_noflag(skb, TCA_MQPRIO_MIN_RATE64); + if (!nest) + goto nla_put_failure; + + for (i = 0; i < opt->num_tc; i++) { + if (nla_put(skb, TCA_MQPRIO_MIN_RATE64, + sizeof(priv->min_rate[i]), + &priv->min_rate[i])) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + + if (priv->flags & TC_MQPRIO_F_MAX_RATE) { + nest = nla_nest_start_noflag(skb, TCA_MQPRIO_MAX_RATE64); + if (!nest) + goto nla_put_failure; + + for (i = 0; i < opt->num_tc; i++) { + if (nla_put(skb, TCA_MQPRIO_MAX_RATE64, + sizeof(priv->max_rate[i]), + &priv->max_rate[i])) + goto nla_put_failure; + } + nla_nest_end(skb, nest); + } + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int mqprio_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct net_device *dev = qdisc_dev(sch); + struct mqprio_sched *priv = qdisc_priv(sch); + struct nlattr *nla = (struct nlattr *)skb_tail_pointer(skb); + struct tc_mqprio_qopt opt = { 0 }; + struct Qdisc *qdisc; + unsigned int ntx, tc; + + sch->q.qlen = 0; + gnet_stats_basic_sync_init(&sch->bstats); + memset(&sch->qstats, 0, sizeof(sch->qstats)); + + /* MQ supports lockless qdiscs. However, statistics accounting needs + * to account for all, none, or a mix of locked and unlocked child + * qdiscs. Percpu stats are added to counters in-band and locking + * qdisc totals are added at end. 
+ */ + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + qdisc = rtnl_dereference(netdev_get_tx_queue(dev, ntx)->qdisc_sleeping); + spin_lock_bh(qdisc_lock(qdisc)); + + gnet_stats_add_basic(&sch->bstats, qdisc->cpu_bstats, + &qdisc->bstats, false); + gnet_stats_add_queue(&sch->qstats, qdisc->cpu_qstats, + &qdisc->qstats); + sch->q.qlen += qdisc_qlen(qdisc); + + spin_unlock_bh(qdisc_lock(qdisc)); + } + + opt.num_tc = netdev_get_num_tc(dev); + memcpy(opt.prio_tc_map, dev->prio_tc_map, sizeof(opt.prio_tc_map)); + opt.hw = priv->hw_offload; + + for (tc = 0; tc < netdev_get_num_tc(dev); tc++) { + opt.count[tc] = dev->tc_to_txq[tc].count; + opt.offset[tc] = dev->tc_to_txq[tc].offset; + } + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + + if ((priv->flags & TC_MQPRIO_F_MODE) && + nla_put_u16(skb, TCA_MQPRIO_MODE, priv->mode)) + goto nla_put_failure; + + if ((priv->flags & TC_MQPRIO_F_SHAPER) && + nla_put_u16(skb, TCA_MQPRIO_SHAPER, priv->shaper)) + goto nla_put_failure; + + if ((priv->flags & TC_MQPRIO_F_MIN_RATE || + priv->flags & TC_MQPRIO_F_MAX_RATE) && + (dump_rates(priv, &opt, skb) != 0)) + goto nla_put_failure; + + return nla_nest_end(skb, nla); +nla_put_failure: + nlmsg_trim(skb, nla); + return -1; +} + +static struct Qdisc *mqprio_leaf(struct Qdisc *sch, unsigned long cl) +{ + struct netdev_queue *dev_queue = mqprio_queue_get(sch, cl); + + if (!dev_queue) + return NULL; + + return rtnl_dereference(dev_queue->qdisc_sleeping); +} + +static unsigned long mqprio_find(struct Qdisc *sch, u32 classid) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned int ntx = TC_H_MIN(classid); + + /* There are essentially two regions here that have valid classid + * values. The first region will have a classid value of 1 through + * num_tx_queues. All of these are backed by actual Qdiscs. + */ + if (ntx < TC_H_MIN_PRIORITY) + return (ntx <= dev->num_tx_queues) ? ntx : 0; + + /* The second region represents the hardware traffic classes. These + * are represented by classid values of TC_H_MIN_PRIORITY through + * TC_H_MIN_PRIORITY + netdev_get_num_tc - 1 + */ + return ((ntx - TC_H_MIN_PRIORITY) < netdev_get_num_tc(dev)) ? ntx : 0; +} + +static int mqprio_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + if (cl < TC_H_MIN_PRIORITY) { + struct netdev_queue *dev_queue = mqprio_queue_get(sch, cl); + struct net_device *dev = qdisc_dev(sch); + int tc = netdev_txq_to_tc(dev, cl - 1); + + tcm->tcm_parent = (tc < 0) ? 0 : + TC_H_MAKE(TC_H_MAJ(sch->handle), + TC_H_MIN(tc + TC_H_MIN_PRIORITY)); + tcm->tcm_info = rtnl_dereference(dev_queue->qdisc_sleeping)->handle; + } else { + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_info = 0; + } + tcm->tcm_handle |= TC_H_MIN(cl); + return 0; +} + +static int mqprio_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) + __releases(d->lock) + __acquires(d->lock) +{ + if (cl >= TC_H_MIN_PRIORITY) { + int i; + __u32 qlen; + struct gnet_stats_queue qstats = {0}; + struct gnet_stats_basic_sync bstats; + struct net_device *dev = qdisc_dev(sch); + struct netdev_tc_txq tc = dev->tc_to_txq[cl & TC_BITMASK]; + + gnet_stats_basic_sync_init(&bstats); + /* Drop lock here it will be reclaimed before touching + * statistics this is required because the d->lock we + * hold here is the look on dev_queue->qdisc_sleeping + * also acquired below. 
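+ * (While d->lock is dropped, each per-queue child is summed under its
+ * own qdisc_lock() instead.)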
+ */ + if (d->lock) + spin_unlock_bh(d->lock); + + for (i = tc.offset; i < tc.offset + tc.count; i++) { + struct netdev_queue *q = netdev_get_tx_queue(dev, i); + struct Qdisc *qdisc = rtnl_dereference(q->qdisc); + + spin_lock_bh(qdisc_lock(qdisc)); + + gnet_stats_add_basic(&bstats, qdisc->cpu_bstats, + &qdisc->bstats, false); + gnet_stats_add_queue(&qstats, qdisc->cpu_qstats, + &qdisc->qstats); + sch->q.qlen += qdisc_qlen(qdisc); + + spin_unlock_bh(qdisc_lock(qdisc)); + } + qlen = qdisc_qlen(sch) + qstats.qlen; + + /* Reclaim root sleeping lock before completing stats */ + if (d->lock) + spin_lock_bh(d->lock); + if (gnet_stats_copy_basic(d, NULL, &bstats, false) < 0 || + gnet_stats_copy_queue(d, NULL, &qstats, qlen) < 0) + return -1; + } else { + struct netdev_queue *dev_queue = mqprio_queue_get(sch, cl); + + sch = rtnl_dereference(dev_queue->qdisc_sleeping); + if (gnet_stats_copy_basic(d, sch->cpu_bstats, + &sch->bstats, true) < 0 || + qdisc_qstats_copy(d, sch) < 0) + return -1; + } + return 0; +} + +static void mqprio_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx; + + if (arg->stop) + return; + + /* Walk hierarchy with a virtual class per tc */ + arg->count = arg->skip; + for (ntx = arg->skip; ntx < netdev_get_num_tc(dev); ntx++) { + if (!tc_qdisc_stats_dump(sch, ntx + TC_H_MIN_PRIORITY, arg)) + return; + } + + /* Pad the values and skip over unused traffic classes */ + if (ntx < TC_MAX_QUEUE) { + arg->count = TC_MAX_QUEUE; + ntx = TC_MAX_QUEUE; + } + + /* Reset offset, sort out remaining per-queue qdiscs */ + for (ntx -= TC_MAX_QUEUE; ntx < dev->num_tx_queues; ntx++) { + if (arg->fn(sch, ntx + 1, arg) < 0) { + arg->stop = 1; + return; + } + arg->count++; + } +} + +static struct netdev_queue *mqprio_select_queue(struct Qdisc *sch, + struct tcmsg *tcm) +{ + return mqprio_queue_get(sch, TC_H_MIN(tcm->tcm_parent)); +} + +static const struct Qdisc_class_ops mqprio_class_ops = { + .graft = mqprio_graft, + .leaf = mqprio_leaf, + .find = mqprio_find, + .walk = mqprio_walk, + .dump = mqprio_dump_class, + .dump_stats = mqprio_dump_class_stats, + .select_queue = mqprio_select_queue, +}; + +static struct Qdisc_ops mqprio_qdisc_ops __read_mostly = { + .cl_ops = &mqprio_class_ops, + .id = "mqprio", + .priv_size = sizeof(struct mqprio_sched), + .init = mqprio_init, + .destroy = mqprio_destroy, + .attach = mqprio_attach, + .change_real_num_tx = mq_change_real_num_tx, + .dump = mqprio_dump, + .owner = THIS_MODULE, +}; + +static int __init mqprio_module_init(void) +{ + return register_qdisc(&mqprio_qdisc_ops); +} + +static void __exit mqprio_module_exit(void) +{ + unregister_qdisc(&mqprio_qdisc_ops); +} + +module_init(mqprio_module_init); +module_exit(mqprio_module_exit); + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_multiq.c b/net/sched/sch_multiq.c new file mode 100644 index 000000000..75c9c8601 --- /dev/null +++ b/net/sched/sch_multiq.c @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright (c) 2008, Intel Corporation. 
+ * + * Author: Alexander Duyck + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct multiq_sched_data { + u16 bands; + u16 max_bands; + u16 curband; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + struct Qdisc **queues; +}; + + +static struct Qdisc * +multiq_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + u32 band; + struct tcf_result res; + struct tcf_proto *fl = rcu_dereference_bh(q->filter_list); + int err; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + err = tcf_classify(skb, NULL, fl, &res, false); +#ifdef CONFIG_NET_CLS_ACT + switch (err) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + band = skb_get_queue_mapping(skb); + + if (band >= q->bands) + return q->queues[0]; + + return q->queues[band]; +} + +static int +multiq_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct Qdisc *qdisc; + int ret; + + qdisc = multiq_classify(skb, sch, &ret); +#ifdef CONFIG_NET_CLS_ACT + if (qdisc == NULL) { + + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } +#endif + + ret = qdisc_enqueue(skb, qdisc, to_free); + if (ret == NET_XMIT_SUCCESS) { + sch->q.qlen++; + return NET_XMIT_SUCCESS; + } + if (net_xmit_drop_count(ret)) + qdisc_qstats_drop(sch); + return ret; +} + +static struct sk_buff *multiq_dequeue(struct Qdisc *sch) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + struct Qdisc *qdisc; + struct sk_buff *skb; + int band; + + for (band = 0; band < q->bands; band++) { + /* cycle through bands to ensure fairness */ + q->curband++; + if (q->curband >= q->bands) + q->curband = 0; + + /* Check that target subqueue is available before + * pulling an skb to avoid head-of-line blocking. + */ + if (!netif_xmit_stopped( + netdev_get_tx_queue(qdisc_dev(sch), q->curband))) { + qdisc = q->queues[q->curband]; + skb = qdisc->dequeue(qdisc); + if (skb) { + qdisc_bstats_update(sch, skb); + sch->q.qlen--; + return skb; + } + } + } + return NULL; + +} + +static struct sk_buff *multiq_peek(struct Qdisc *sch) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + unsigned int curband = q->curband; + struct Qdisc *qdisc; + struct sk_buff *skb; + int band; + + for (band = 0; band < q->bands; band++) { + /* cycle through bands to ensure fairness */ + curband++; + if (curband >= q->bands) + curband = 0; + + /* Check that target subqueue is available before + * pulling an skb to avoid head-of-line blocking. 
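+ * (Unlike multiq_dequeue(), a local band cursor is used here so that
+ * peeking does not advance q->curband.)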
+ */ + if (!netif_xmit_stopped( + netdev_get_tx_queue(qdisc_dev(sch), curband))) { + qdisc = q->queues[curband]; + skb = qdisc->ops->peek(qdisc); + if (skb) + return skb; + } + } + return NULL; + +} + +static void +multiq_reset(struct Qdisc *sch) +{ + u16 band; + struct multiq_sched_data *q = qdisc_priv(sch); + + for (band = 0; band < q->bands; band++) + qdisc_reset(q->queues[band]); + q->curband = 0; +} + +static void +multiq_destroy(struct Qdisc *sch) +{ + int band; + struct multiq_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + for (band = 0; band < q->bands; band++) + qdisc_put(q->queues[band]); + + kfree(q->queues); +} + +static int multiq_tune(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + struct tc_multiq_qopt *qopt; + struct Qdisc **removed; + int i, n_removed = 0; + + if (!netif_is_multiqueue(qdisc_dev(sch))) + return -EOPNOTSUPP; + if (nla_len(opt) < sizeof(*qopt)) + return -EINVAL; + + qopt = nla_data(opt); + + qopt->bands = qdisc_dev(sch)->real_num_tx_queues; + + removed = kmalloc(sizeof(*removed) * (q->max_bands - q->bands), + GFP_KERNEL); + if (!removed) + return -ENOMEM; + + sch_tree_lock(sch); + q->bands = qopt->bands; + for (i = q->bands; i < q->max_bands; i++) { + if (q->queues[i] != &noop_qdisc) { + struct Qdisc *child = q->queues[i]; + + q->queues[i] = &noop_qdisc; + qdisc_purge_queue(child); + removed[n_removed++] = child; + } + } + + sch_tree_unlock(sch); + + for (i = 0; i < n_removed; i++) + qdisc_put(removed[i]); + kfree(removed); + + for (i = 0; i < q->bands; i++) { + if (q->queues[i] == &noop_qdisc) { + struct Qdisc *child, *old; + child = qdisc_create_dflt(sch->dev_queue, + &pfifo_qdisc_ops, + TC_H_MAKE(sch->handle, + i + 1), extack); + if (child) { + sch_tree_lock(sch); + old = q->queues[i]; + q->queues[i] = child; + if (child != &noop_qdisc) + qdisc_hash_add(child, true); + + if (old != &noop_qdisc) + qdisc_purge_queue(old); + sch_tree_unlock(sch); + qdisc_put(old); + } + } + } + return 0; +} + +static int multiq_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + int i, err; + + q->queues = NULL; + + if (!opt) + return -EINVAL; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + q->max_bands = qdisc_dev(sch)->num_tx_queues; + + q->queues = kcalloc(q->max_bands, sizeof(struct Qdisc *), GFP_KERNEL); + if (!q->queues) + return -ENOBUFS; + for (i = 0; i < q->max_bands; i++) + q->queues[i] = &noop_qdisc; + + return multiq_tune(sch, opt, extack); +} + +static int multiq_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + unsigned char *b = skb_tail_pointer(skb); + struct tc_multiq_qopt opt; + + opt.bands = q->bands; + opt.max_bands = q->max_bands; + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int multiq_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + unsigned long band = arg - 1; + + if (new == NULL) + new = &noop_qdisc; + + *old = qdisc_replace(sch, new, &q->queues[band]); + return 0; +} + +static struct Qdisc * +multiq_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + unsigned long band = arg - 1; + + return 
q->queues[band]; +} + +static unsigned long multiq_find(struct Qdisc *sch, u32 classid) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + unsigned long band = TC_H_MIN(classid); + + if (band - 1 >= q->bands) + return 0; + return band; +} + +static unsigned long multiq_bind(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return multiq_find(sch, classid); +} + + +static void multiq_unbind(struct Qdisc *q, unsigned long cl) +{ +} + +static int multiq_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + + tcm->tcm_handle |= TC_H_MIN(cl); + tcm->tcm_info = q->queues[cl - 1]->handle; + return 0; +} + +static int multiq_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + struct Qdisc *cl_q; + + cl_q = q->queues[cl - 1]; + if (gnet_stats_copy_basic(d, cl_q->cpu_bstats, &cl_q->bstats, true) < 0 || + qdisc_qstats_copy(d, cl_q) < 0) + return -1; + + return 0; +} + +static void multiq_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + int band; + + if (arg->stop) + return; + + for (band = 0; band < q->bands; band++) { + if (!tc_qdisc_stats_dump(sch, band + 1, arg)) + break; + } +} + +static struct tcf_block *multiq_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct multiq_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static const struct Qdisc_class_ops multiq_class_ops = { + .graft = multiq_graft, + .leaf = multiq_leaf, + .find = multiq_find, + .walk = multiq_walk, + .tcf_block = multiq_tcf_block, + .bind_tcf = multiq_bind, + .unbind_tcf = multiq_unbind, + .dump = multiq_dump_class, + .dump_stats = multiq_dump_class_stats, +}; + +static struct Qdisc_ops multiq_qdisc_ops __read_mostly = { + .next = NULL, + .cl_ops = &multiq_class_ops, + .id = "multiq", + .priv_size = sizeof(struct multiq_sched_data), + .enqueue = multiq_enqueue, + .dequeue = multiq_dequeue, + .peek = multiq_peek, + .init = multiq_init, + .reset = multiq_reset, + .destroy = multiq_destroy, + .change = multiq_tune, + .dump = multiq_dump, + .owner = THIS_MODULE, +}; + +static int __init multiq_module_init(void) +{ + return register_qdisc(&multiq_qdisc_ops); +} + +static void __exit multiq_module_exit(void) +{ + unregister_qdisc(&multiq_qdisc_ops); +} + +module_init(multiq_module_init) +module_exit(multiq_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_netem.c b/net/sched/sch_netem.c new file mode 100644 index 000000000..d0e045116 --- /dev/null +++ b/net/sched/sch_netem.c @@ -0,0 +1,1289 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_netem.c Network emulator + * + * Many of the algorithms and ideas for this came from + * NIST Net which is not copyrighted. + * + * Authors: Stephen Hemminger + * Catalin(ux aka Dino) BOIE + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define VERSION "1.3" + +/* Network Emulation Queuing algorithm. 
+ ==================================== + + Sources: [1] Mark Carson, Darrin Santay, "NIST Net - A Linux-based + Network Emulation Tool + [2] Luigi Rizzo, DummyNet for FreeBSD + + ---------------------------------------------------------------- + + This started out as a simple way to delay outgoing packets to + test TCP but has grown to include most of the functionality + of a full blown network emulator like NISTnet. It can delay + packets and add random jitter (and correlation). The random + distribution can be loaded from a table as well to provide + normal, Pareto, or experimental curves. Packet loss, + duplication, and reordering can also be emulated. + + This qdisc does not do classification that can be handled in + layering other disciplines. It does not need to do bandwidth + control either since that can be handled by using token + bucket or other rate control. + + Correlated Loss Generator models + + Added generation of correlated loss according to the + "Gilbert-Elliot" model, a 4-state markov model. + + References: + [1] NetemCLG Home http://netgroup.uniroma2.it/NetemCLG + [2] S. Salsano, F. Ludovici, A. Ordine, "Definition of a general + and intuitive loss model for packet networks and its implementation + in the Netem module in the Linux kernel", available in [1] + + Authors: Stefano Salsano +*/ + +struct disttable { + u32 size; + s16 table[]; +}; + +struct netem_sched_data { + /* internal t(ime)fifo qdisc uses t_root and sch->limit */ + struct rb_root t_root; + + /* a linear queue; reduces rbtree rebalancing when jitter is low */ + struct sk_buff *t_head; + struct sk_buff *t_tail; + + /* optional qdisc for classful handling (NULL at netem init) */ + struct Qdisc *qdisc; + + struct qdisc_watchdog watchdog; + + s64 latency; + s64 jitter; + + u32 loss; + u32 ecn; + u32 limit; + u32 counter; + u32 gap; + u32 duplicate; + u32 reorder; + u32 corrupt; + u64 rate; + s32 packet_overhead; + u32 cell_size; + struct reciprocal_value cell_size_reciprocal; + s32 cell_overhead; + + struct crndstate { + u32 last; + u32 rho; + } delay_cor, loss_cor, dup_cor, reorder_cor, corrupt_cor; + + struct disttable *delay_dist; + + enum { + CLG_RANDOM, + CLG_4_STATES, + CLG_GILB_ELL, + } loss_model; + + enum { + TX_IN_GAP_PERIOD = 1, + TX_IN_BURST_PERIOD, + LOST_IN_GAP_PERIOD, + LOST_IN_BURST_PERIOD, + } _4_state_model; + + enum { + GOOD_STATE = 1, + BAD_STATE, + } GE_state_model; + + /* Correlated Loss Generation models */ + struct clgstate { + /* state of the Markov chain */ + u8 state; + + /* 4-states and Gilbert-Elliot models */ + u32 a1; /* p13 for 4-states or p for GE */ + u32 a2; /* p31 for 4-states or r for GE */ + u32 a3; /* p32 for 4-states or h for GE */ + u32 a4; /* p14 for 4-states or 1-k for GE */ + u32 a5; /* p23 used only in 4-states */ + } clg; + + struct tc_netem_slot slot_config; + struct slotstate { + u64 slot_next; + s32 packets_left; + s32 bytes_left; + } slot; + + struct disttable *slot_dist; +}; + +/* Time stamp put into socket buffer control block + * Only valid when skbs are in our internal t(ime)fifo queue. + * + * As skb->rbnode uses same storage than skb->next, skb->prev and skb->tstamp, + * and skb->next & skb->prev are scratch space for a qdisc, + * we save skb->tstamp value in skb->cb[] before destroying it. 
+ */ +struct netem_skb_cb { + u64 time_to_send; +}; + +static inline struct netem_skb_cb *netem_skb_cb(struct sk_buff *skb) +{ + /* we assume we can use skb next/prev/tstamp as storage for rb_node */ + qdisc_cb_private_validate(skb, sizeof(struct netem_skb_cb)); + return (struct netem_skb_cb *)qdisc_skb_cb(skb)->data; +} + +/* init_crandom - initialize correlated random number generator + * Use entropy source for initial seed. + */ +static void init_crandom(struct crndstate *state, unsigned long rho) +{ + state->rho = rho; + state->last = get_random_u32(); +} + +/* get_crandom - correlated random number generator + * Next number depends on last value. + * rho is scaled to avoid floating point. + */ +static u32 get_crandom(struct crndstate *state) +{ + u64 value, rho; + unsigned long answer; + + if (!state || state->rho == 0) /* no correlation */ + return get_random_u32(); + + value = get_random_u32(); + rho = (u64)state->rho + 1; + answer = (value * ((1ull<<32) - rho) + state->last * rho) >> 32; + state->last = answer; + return answer; +} + +/* loss_4state - 4-state model loss generator + * Generates losses according to the 4-state Markov chain adopted in + * the GI (General and Intuitive) loss model. + */ +static bool loss_4state(struct netem_sched_data *q) +{ + struct clgstate *clg = &q->clg; + u32 rnd = get_random_u32(); + + /* + * Makes a comparison between rnd and the transition + * probabilities outgoing from the current state, then decides the + * next state and if the next packet has to be transmitted or lost. + * The four states correspond to: + * TX_IN_GAP_PERIOD => successfully transmitted packets within a gap period + * LOST_IN_GAP_PERIOD => isolated losses within a gap period + * LOST_IN_BURST_PERIOD => lost packets within a burst period + * TX_IN_BURST_PERIOD => successfully transmitted packets within a burst period + */ + switch (clg->state) { + case TX_IN_GAP_PERIOD: + if (rnd < clg->a4) { + clg->state = LOST_IN_GAP_PERIOD; + return true; + } else if (clg->a4 < rnd && rnd < clg->a1 + clg->a4) { + clg->state = LOST_IN_BURST_PERIOD; + return true; + } else if (clg->a1 + clg->a4 < rnd) { + clg->state = TX_IN_GAP_PERIOD; + } + + break; + case TX_IN_BURST_PERIOD: + if (rnd < clg->a5) { + clg->state = LOST_IN_BURST_PERIOD; + return true; + } else { + clg->state = TX_IN_BURST_PERIOD; + } + + break; + case LOST_IN_BURST_PERIOD: + if (rnd < clg->a3) + clg->state = TX_IN_BURST_PERIOD; + else if (clg->a3 < rnd && rnd < clg->a2 + clg->a3) { + clg->state = TX_IN_GAP_PERIOD; + } else if (clg->a2 + clg->a3 < rnd) { + clg->state = LOST_IN_BURST_PERIOD; + return true; + } + break; + case LOST_IN_GAP_PERIOD: + clg->state = TX_IN_GAP_PERIOD; + break; + } + + return false; +} + +/* loss_gilb_ell - Gilbert-Elliot model loss generator + * Generates losses according to the Gilbert-Elliot loss model or + * its special cases (Gilbert or Simple Gilbert) + * + * Makes a comparison between random number and the transition + * probabilities outgoing from the current state, then decides the + * next state. A second random number is extracted and the comparison + * with the loss probability of the current state decides if the next + * packet will be transmitted or lost. 
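+ * In this implementation clg->a1 holds p (good->bad), clg->a2 holds r
+ * (bad->good), clg->a3 holds h and clg->a4 holds 1-k, so a packet is lost
+ * with probability 1-k in the good state and 1-h in the bad state.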
+ */ +static bool loss_gilb_ell(struct netem_sched_data *q) +{ + struct clgstate *clg = &q->clg; + + switch (clg->state) { + case GOOD_STATE: + if (get_random_u32() < clg->a1) + clg->state = BAD_STATE; + if (get_random_u32() < clg->a4) + return true; + break; + case BAD_STATE: + if (get_random_u32() < clg->a2) + clg->state = GOOD_STATE; + if (get_random_u32() > clg->a3) + return true; + } + + return false; +} + +static bool loss_event(struct netem_sched_data *q) +{ + switch (q->loss_model) { + case CLG_RANDOM: + /* Random packet drop 0 => none, ~0 => all */ + return q->loss && q->loss >= get_crandom(&q->loss_cor); + + case CLG_4_STATES: + /* 4state loss model algorithm (used also for GI model) + * Extracts a value from the markov 4 state loss generator, + * if it is 1 drops a packet and if needed writes the event in + * the kernel logs + */ + return loss_4state(q); + + case CLG_GILB_ELL: + /* Gilbert-Elliot loss model algorithm + * Extracts a value from the Gilbert-Elliot loss generator, + * if it is 1 drops a packet and if needed writes the event in + * the kernel logs + */ + return loss_gilb_ell(q); + } + + return false; /* not reached */ +} + + +/* tabledist - return a pseudo-randomly distributed value with mean mu and + * std deviation sigma. Uses table lookup to approximate the desired + * distribution, and a uniformly-distributed pseudo-random source. + */ +static s64 tabledist(s64 mu, s32 sigma, + struct crndstate *state, + const struct disttable *dist) +{ + s64 x; + long t; + u32 rnd; + + if (sigma == 0) + return mu; + + rnd = get_crandom(state); + + /* default uniform distribution */ + if (dist == NULL) + return ((rnd % (2 * (u32)sigma)) + mu) - sigma; + + t = dist->table[rnd % dist->size]; + x = (sigma % NETEM_DIST_SCALE) * t; + if (x >= 0) + x += NETEM_DIST_SCALE/2; + else + x -= NETEM_DIST_SCALE/2; + + return x / NETEM_DIST_SCALE + (sigma / NETEM_DIST_SCALE) * t + mu; +} + +static u64 packet_time_ns(u64 len, const struct netem_sched_data *q) +{ + len += q->packet_overhead; + + if (q->cell_size) { + u32 cells = reciprocal_divide(len, q->cell_size_reciprocal); + + if (len > cells * q->cell_size) /* extra cell needed for remainder */ + cells++; + len = cells * (q->cell_size + q->cell_overhead); + } + + return div64_u64(len * NSEC_PER_SEC, q->rate); +} + +static void tfifo_reset(struct Qdisc *sch) +{ + struct netem_sched_data *q = qdisc_priv(sch); + struct rb_node *p = rb_first(&q->t_root); + + while (p) { + struct sk_buff *skb = rb_to_skb(p); + + p = rb_next(p); + rb_erase(&skb->rbnode, &q->t_root); + rtnl_kfree_skbs(skb, skb); + } + + rtnl_kfree_skbs(q->t_head, q->t_tail); + q->t_head = NULL; + q->t_tail = NULL; +} + +static void tfifo_enqueue(struct sk_buff *nskb, struct Qdisc *sch) +{ + struct netem_sched_data *q = qdisc_priv(sch); + u64 tnext = netem_skb_cb(nskb)->time_to_send; + + if (!q->t_tail || tnext >= netem_skb_cb(q->t_tail)->time_to_send) { + if (q->t_tail) + q->t_tail->next = nskb; + else + q->t_head = nskb; + q->t_tail = nskb; + } else { + struct rb_node **p = &q->t_root.rb_node, *parent = NULL; + + while (*p) { + struct sk_buff *skb; + + parent = *p; + skb = rb_to_skb(parent); + if (tnext >= netem_skb_cb(skb)->time_to_send) + p = &parent->rb_right; + else + p = &parent->rb_left; + } + rb_link_node(&nskb->rbnode, parent, p); + rb_insert_color(&nskb->rbnode, &q->t_root); + } + sch->q.qlen++; +} + +/* netem can't properly corrupt a megapacket (like we get from GSO), so instead + * when we statistically choose to corrupt one, we instead segment it, returning + * the 
first packet to be corrupted, and re-enqueue the remaining frames + */ +static struct sk_buff *netem_segment(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct sk_buff *segs; + netdev_features_t features = netif_skb_features(skb); + + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + + if (IS_ERR_OR_NULL(segs)) { + qdisc_drop(skb, sch, to_free); + return NULL; + } + consume_skb(skb); + return segs; +} + +/* + * Insert one skb into qdisc. + * Note: parent depends on return value to account for queue length. + * NET_XMIT_DROP: queue length didn't change. + * NET_XMIT_SUCCESS: one skb was queued. + */ +static int netem_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct netem_sched_data *q = qdisc_priv(sch); + /* We don't fill cb now as skb_unshare() may invalidate it */ + struct netem_skb_cb *cb; + struct sk_buff *skb2; + struct sk_buff *segs = NULL; + unsigned int prev_len = qdisc_pkt_len(skb); + int count = 1; + int rc = NET_XMIT_SUCCESS; + int rc_drop = NET_XMIT_DROP; + + /* Do not fool qdisc_drop_all() */ + skb->prev = NULL; + + /* Random duplication */ + if (q->duplicate && q->duplicate >= get_crandom(&q->dup_cor)) + ++count; + + /* Drop packet? */ + if (loss_event(q)) { + if (q->ecn && INET_ECN_set_ce(skb)) + qdisc_qstats_drop(sch); /* mark packet */ + else + --count; + } + if (count == 0) { + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + } + + /* If a delay is expected, orphan the skb. (orphaning usually takes + * place at TX completion time, so _before_ the link transit delay) + */ + if (q->latency || q->jitter || q->rate) + skb_orphan_partial(skb); + + /* + * If we need to duplicate packet, then re-insert at top of the + * qdisc tree, since parent queuer expects that only one + * skb will be queued. + */ + if (count > 1 && (skb2 = skb_clone(skb, GFP_ATOMIC)) != NULL) { + struct Qdisc *rootq = qdisc_root_bh(sch); + u32 dupsave = q->duplicate; /* prevent duplicating a dup... */ + + q->duplicate = 0; + rootq->enqueue(skb2, rootq, to_free); + q->duplicate = dupsave; + rc_drop = NET_XMIT_SUCCESS; + } + + /* + * Randomized packet corruption. + * Make copy if needed since we are modifying + * If packet is going to be hardware checksummed, then + * do it now in software before we mangle it. 
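+ * (GSO packets are segmented first so that only the first segment is
+ * mangled; a single bit at a random offset within the skb's linear data
+ * area is then flipped.)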
+ */ + if (q->corrupt && q->corrupt >= get_crandom(&q->corrupt_cor)) { + if (skb_is_gso(skb)) { + skb = netem_segment(skb, sch, to_free); + if (!skb) + return rc_drop; + segs = skb->next; + skb_mark_not_on_list(skb); + qdisc_skb_cb(skb)->pkt_len = skb->len; + } + + skb = skb_unshare(skb, GFP_ATOMIC); + if (unlikely(!skb)) { + qdisc_qstats_drop(sch); + goto finish_segs; + } + if (skb->ip_summed == CHECKSUM_PARTIAL && + skb_checksum_help(skb)) { + qdisc_drop(skb, sch, to_free); + skb = NULL; + goto finish_segs; + } + + skb->data[prandom_u32_max(skb_headlen(skb))] ^= + 1<q.qlen >= sch->limit)) { + /* re-link segs, so that qdisc_drop_all() frees them all */ + skb->next = segs; + qdisc_drop_all(skb, sch, to_free); + return rc_drop; + } + + qdisc_qstats_backlog_inc(sch, skb); + + cb = netem_skb_cb(skb); + if (q->gap == 0 || /* not doing reordering */ + q->counter < q->gap - 1 || /* inside last reordering gap */ + q->reorder < get_crandom(&q->reorder_cor)) { + u64 now; + s64 delay; + + delay = tabledist(q->latency, q->jitter, + &q->delay_cor, q->delay_dist); + + now = ktime_get_ns(); + + if (q->rate) { + struct netem_skb_cb *last = NULL; + + if (sch->q.tail) + last = netem_skb_cb(sch->q.tail); + if (q->t_root.rb_node) { + struct sk_buff *t_skb; + struct netem_skb_cb *t_last; + + t_skb = skb_rb_last(&q->t_root); + t_last = netem_skb_cb(t_skb); + if (!last || + t_last->time_to_send > last->time_to_send) + last = t_last; + } + if (q->t_tail) { + struct netem_skb_cb *t_last = + netem_skb_cb(q->t_tail); + + if (!last || + t_last->time_to_send > last->time_to_send) + last = t_last; + } + + if (last) { + /* + * Last packet in queue is reference point (now), + * calculate this time bonus and subtract + * from delay. + */ + delay -= last->time_to_send - now; + delay = max_t(s64, 0, delay); + now = last->time_to_send; + } + + delay += packet_time_ns(qdisc_pkt_len(skb), q); + } + + cb->time_to_send = now + delay; + ++q->counter; + tfifo_enqueue(skb, sch); + } else { + /* + * Do re-ordering by putting one out of N packets at the front + * of the queue. + */ + cb->time_to_send = ktime_get_ns(); + q->counter = 0; + + __qdisc_enqueue_head(skb, &sch->q); + sch->qstats.requeues++; + } + +finish_segs: + if (segs) { + unsigned int len, last_len; + int nb; + + len = skb ? skb->len : 0; + nb = skb ? 1 : 0; + + while (segs) { + skb2 = segs->next; + skb_mark_not_on_list(segs); + qdisc_skb_cb(segs)->pkt_len = segs->len; + last_len = segs->len; + rc = qdisc_enqueue(segs, sch, to_free); + if (rc != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(rc)) + qdisc_qstats_drop(sch); + } else { + nb++; + len += last_len; + } + segs = skb2; + } + /* Parent qdiscs accounted for 1 skb of size @prev_len */ + qdisc_tree_reduce_backlog(sch, -(nb - 1), -(len - prev_len)); + } else if (!skb) { + return NET_XMIT_DROP; + } + return NET_XMIT_SUCCESS; +} + +/* Delay the next round with a new future slot with a + * correct number of bytes and packets. 
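+ * The start of the next slot is either drawn uniformly from
+ * [min_delay, max_delay] or, if a slot distribution table was supplied,
+ * from tabledist() around dist_delay with dist_jitter.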
+ */ + +static void get_slot_next(struct netem_sched_data *q, u64 now) +{ + s64 next_delay; + + if (!q->slot_dist) + next_delay = q->slot_config.min_delay + + (get_random_u32() * + (q->slot_config.max_delay - + q->slot_config.min_delay) >> 32); + else + next_delay = tabledist(q->slot_config.dist_delay, + (s32)(q->slot_config.dist_jitter), + NULL, q->slot_dist); + + q->slot.slot_next = now + next_delay; + q->slot.packets_left = q->slot_config.max_packets; + q->slot.bytes_left = q->slot_config.max_bytes; +} + +static struct sk_buff *netem_peek(struct netem_sched_data *q) +{ + struct sk_buff *skb = skb_rb_first(&q->t_root); + u64 t1, t2; + + if (!skb) + return q->t_head; + if (!q->t_head) + return skb; + + t1 = netem_skb_cb(skb)->time_to_send; + t2 = netem_skb_cb(q->t_head)->time_to_send; + if (t1 < t2) + return skb; + return q->t_head; +} + +static void netem_erase_head(struct netem_sched_data *q, struct sk_buff *skb) +{ + if (skb == q->t_head) { + q->t_head = skb->next; + if (!q->t_head) + q->t_tail = NULL; + } else { + rb_erase(&skb->rbnode, &q->t_root); + } +} + +static struct sk_buff *netem_dequeue(struct Qdisc *sch) +{ + struct netem_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + +tfifo_dequeue: + skb = __qdisc_dequeue_head(&sch->q); + if (skb) { + qdisc_qstats_backlog_dec(sch, skb); +deliver: + qdisc_bstats_update(sch, skb); + return skb; + } + skb = netem_peek(q); + if (skb) { + u64 time_to_send; + u64 now = ktime_get_ns(); + + /* if more time remaining? */ + time_to_send = netem_skb_cb(skb)->time_to_send; + if (q->slot.slot_next && q->slot.slot_next < time_to_send) + get_slot_next(q, now); + + if (time_to_send <= now && q->slot.slot_next <= now) { + netem_erase_head(q, skb); + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + skb->next = NULL; + skb->prev = NULL; + /* skb->dev shares skb->rbnode area, + * we need to restore its value. + */ + skb->dev = qdisc_dev(sch); + + if (q->slot.slot_next) { + q->slot.packets_left--; + q->slot.bytes_left -= qdisc_pkt_len(skb); + if (q->slot.packets_left <= 0 || + q->slot.bytes_left <= 0) + get_slot_next(q, now); + } + + if (q->qdisc) { + unsigned int pkt_len = qdisc_pkt_len(skb); + struct sk_buff *to_free = NULL; + int err; + + err = qdisc_enqueue(skb, q->qdisc, &to_free); + kfree_skb_list(to_free); + if (err != NET_XMIT_SUCCESS && + net_xmit_drop_count(err)) { + qdisc_qstats_drop(sch); + qdisc_tree_reduce_backlog(sch, 1, + pkt_len); + } + goto tfifo_dequeue; + } + goto deliver; + } + + if (q->qdisc) { + skb = q->qdisc->ops->dequeue(q->qdisc); + if (skb) + goto deliver; + } + + qdisc_watchdog_schedule_ns(&q->watchdog, + max(time_to_send, + q->slot.slot_next)); + } + + if (q->qdisc) { + skb = q->qdisc->ops->dequeue(q->qdisc); + if (skb) + goto deliver; + } + return NULL; +} + +static void netem_reset(struct Qdisc *sch) +{ + struct netem_sched_data *q = qdisc_priv(sch); + + qdisc_reset_queue(sch); + tfifo_reset(sch); + if (q->qdisc) + qdisc_reset(q->qdisc); + qdisc_watchdog_cancel(&q->watchdog); +} + +static void dist_free(struct disttable *d) +{ + kvfree(d); +} + +/* + * Distribution data is a variable size payload containing + * signed 16 bit values. 
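+ * The table is limited to NETEM_DIST_MAX entries and is later indexed by
+ * tabledist() to shape the delay and slot jitter distributions.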
+ */ + +static int get_dist_table(struct disttable **tbl, const struct nlattr *attr) +{ + size_t n = nla_len(attr)/sizeof(__s16); + const __s16 *data = nla_data(attr); + struct disttable *d; + int i; + + if (!n || n > NETEM_DIST_MAX) + return -EINVAL; + + d = kvmalloc(struct_size(d, table, n), GFP_KERNEL); + if (!d) + return -ENOMEM; + + d->size = n; + for (i = 0; i < n; i++) + d->table[i] = data[i]; + + *tbl = d; + return 0; +} + +static void get_slot(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct tc_netem_slot *c = nla_data(attr); + + q->slot_config = *c; + if (q->slot_config.max_packets == 0) + q->slot_config.max_packets = INT_MAX; + if (q->slot_config.max_bytes == 0) + q->slot_config.max_bytes = INT_MAX; + + /* capping dist_jitter to the range acceptable by tabledist() */ + q->slot_config.dist_jitter = min_t(__s64, INT_MAX, abs(q->slot_config.dist_jitter)); + + q->slot.packets_left = q->slot_config.max_packets; + q->slot.bytes_left = q->slot_config.max_bytes; + if (q->slot_config.min_delay | q->slot_config.max_delay | + q->slot_config.dist_jitter) + q->slot.slot_next = ktime_get_ns(); + else + q->slot.slot_next = 0; +} + +static void get_correlation(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct tc_netem_corr *c = nla_data(attr); + + init_crandom(&q->delay_cor, c->delay_corr); + init_crandom(&q->loss_cor, c->loss_corr); + init_crandom(&q->dup_cor, c->dup_corr); +} + +static void get_reorder(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct tc_netem_reorder *r = nla_data(attr); + + q->reorder = r->probability; + init_crandom(&q->reorder_cor, r->correlation); +} + +static void get_corrupt(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct tc_netem_corrupt *r = nla_data(attr); + + q->corrupt = r->probability; + init_crandom(&q->corrupt_cor, r->correlation); +} + +static void get_rate(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct tc_netem_rate *r = nla_data(attr); + + q->rate = r->rate; + q->packet_overhead = r->packet_overhead; + q->cell_size = r->cell_size; + q->cell_overhead = r->cell_overhead; + if (q->cell_size) + q->cell_size_reciprocal = reciprocal_value(q->cell_size); + else + q->cell_size_reciprocal = (struct reciprocal_value) { 0 }; +} + +static int get_loss_clg(struct netem_sched_data *q, const struct nlattr *attr) +{ + const struct nlattr *la; + int rem; + + nla_for_each_nested(la, attr, rem) { + u16 type = nla_type(la); + + switch (type) { + case NETEM_LOSS_GI: { + const struct tc_netem_gimodel *gi = nla_data(la); + + if (nla_len(la) < sizeof(struct tc_netem_gimodel)) { + pr_info("netem: incorrect gi model size\n"); + return -EINVAL; + } + + q->loss_model = CLG_4_STATES; + + q->clg.state = TX_IN_GAP_PERIOD; + q->clg.a1 = gi->p13; + q->clg.a2 = gi->p31; + q->clg.a3 = gi->p32; + q->clg.a4 = gi->p14; + q->clg.a5 = gi->p23; + break; + } + + case NETEM_LOSS_GE: { + const struct tc_netem_gemodel *ge = nla_data(la); + + if (nla_len(la) < sizeof(struct tc_netem_gemodel)) { + pr_info("netem: incorrect ge model size\n"); + return -EINVAL; + } + + q->loss_model = CLG_GILB_ELL; + q->clg.state = GOOD_STATE; + q->clg.a1 = ge->p; + q->clg.a2 = ge->r; + q->clg.a3 = ge->h; + q->clg.a4 = ge->k1; + break; + } + + default: + pr_info("netem: unknown loss type %u\n", type); + return -EINVAL; + } + } + + return 0; +} + +static const struct nla_policy netem_policy[TCA_NETEM_MAX + 1] = { + [TCA_NETEM_CORR] = { .len = sizeof(struct tc_netem_corr) }, + [TCA_NETEM_REORDER] = { .len 
= sizeof(struct tc_netem_reorder) }, + [TCA_NETEM_CORRUPT] = { .len = sizeof(struct tc_netem_corrupt) }, + [TCA_NETEM_RATE] = { .len = sizeof(struct tc_netem_rate) }, + [TCA_NETEM_LOSS] = { .type = NLA_NESTED }, + [TCA_NETEM_ECN] = { .type = NLA_U32 }, + [TCA_NETEM_RATE64] = { .type = NLA_U64 }, + [TCA_NETEM_LATENCY64] = { .type = NLA_S64 }, + [TCA_NETEM_JITTER64] = { .type = NLA_S64 }, + [TCA_NETEM_SLOT] = { .len = sizeof(struct tc_netem_slot) }, +}; + +static int parse_attr(struct nlattr *tb[], int maxtype, struct nlattr *nla, + const struct nla_policy *policy, int len) +{ + int nested_len = nla_len(nla) - NLA_ALIGN(len); + + if (nested_len < 0) { + pr_info("netem: invalid attributes len %d\n", nested_len); + return -EINVAL; + } + + if (nested_len >= nla_attr_size(0)) + return nla_parse_deprecated(tb, maxtype, + nla_data(nla) + NLA_ALIGN(len), + nested_len, policy, NULL); + + memset(tb, 0, sizeof(struct nlattr *) * (maxtype + 1)); + return 0; +} + +/* Parse netlink message to set options */ +static int netem_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct netem_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_NETEM_MAX + 1]; + struct disttable *delay_dist = NULL; + struct disttable *slot_dist = NULL; + struct tc_netem_qopt *qopt; + struct clgstate old_clg; + int old_loss_model = CLG_RANDOM; + int ret; + + qopt = nla_data(opt); + ret = parse_attr(tb, TCA_NETEM_MAX, opt, netem_policy, sizeof(*qopt)); + if (ret < 0) + return ret; + + if (tb[TCA_NETEM_DELAY_DIST]) { + ret = get_dist_table(&delay_dist, tb[TCA_NETEM_DELAY_DIST]); + if (ret) + goto table_free; + } + + if (tb[TCA_NETEM_SLOT_DIST]) { + ret = get_dist_table(&slot_dist, tb[TCA_NETEM_SLOT_DIST]); + if (ret) + goto table_free; + } + + sch_tree_lock(sch); + /* backup q->clg and q->loss_model */ + old_clg = q->clg; + old_loss_model = q->loss_model; + + if (tb[TCA_NETEM_LOSS]) { + ret = get_loss_clg(q, tb[TCA_NETEM_LOSS]); + if (ret) { + q->loss_model = old_loss_model; + q->clg = old_clg; + goto unlock; + } + } else { + q->loss_model = CLG_RANDOM; + } + + if (delay_dist) + swap(q->delay_dist, delay_dist); + if (slot_dist) + swap(q->slot_dist, slot_dist); + sch->limit = qopt->limit; + + q->latency = PSCHED_TICKS2NS(qopt->latency); + q->jitter = PSCHED_TICKS2NS(qopt->jitter); + q->limit = qopt->limit; + q->gap = qopt->gap; + q->counter = 0; + q->loss = qopt->loss; + q->duplicate = qopt->duplicate; + + /* for compatibility with earlier versions. 
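+ * (Probability fields here are u32 fixed-point values in which ~0,
+ * i.e. UINT_MAX, stands for 100%; as an illustrative example, a 25%
+ * reorder probability would arrive from userspace as roughly
+ * 0x40000000.)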
+ * if gap is set, need to assume 100% probability + */ + if (q->gap) + q->reorder = ~0; + + if (tb[TCA_NETEM_CORR]) + get_correlation(q, tb[TCA_NETEM_CORR]); + + if (tb[TCA_NETEM_REORDER]) + get_reorder(q, tb[TCA_NETEM_REORDER]); + + if (tb[TCA_NETEM_CORRUPT]) + get_corrupt(q, tb[TCA_NETEM_CORRUPT]); + + if (tb[TCA_NETEM_RATE]) + get_rate(q, tb[TCA_NETEM_RATE]); + + if (tb[TCA_NETEM_RATE64]) + q->rate = max_t(u64, q->rate, + nla_get_u64(tb[TCA_NETEM_RATE64])); + + if (tb[TCA_NETEM_LATENCY64]) + q->latency = nla_get_s64(tb[TCA_NETEM_LATENCY64]); + + if (tb[TCA_NETEM_JITTER64]) + q->jitter = nla_get_s64(tb[TCA_NETEM_JITTER64]); + + if (tb[TCA_NETEM_ECN]) + q->ecn = nla_get_u32(tb[TCA_NETEM_ECN]); + + if (tb[TCA_NETEM_SLOT]) + get_slot(q, tb[TCA_NETEM_SLOT]); + + /* capping jitter to the range acceptable by tabledist() */ + q->jitter = min_t(s64, abs(q->jitter), INT_MAX); + +unlock: + sch_tree_unlock(sch); + +table_free: + dist_free(delay_dist); + dist_free(slot_dist); + return ret; +} + +static int netem_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct netem_sched_data *q = qdisc_priv(sch); + int ret; + + qdisc_watchdog_init(&q->watchdog, sch); + + if (!opt) + return -EINVAL; + + q->loss_model = CLG_RANDOM; + ret = netem_change(sch, opt, extack); + if (ret) + pr_info("netem: change failed\n"); + return ret; +} + +static void netem_destroy(struct Qdisc *sch) +{ + struct netem_sched_data *q = qdisc_priv(sch); + + qdisc_watchdog_cancel(&q->watchdog); + if (q->qdisc) + qdisc_put(q->qdisc); + dist_free(q->delay_dist); + dist_free(q->slot_dist); +} + +static int dump_loss_model(const struct netem_sched_data *q, + struct sk_buff *skb) +{ + struct nlattr *nest; + + nest = nla_nest_start_noflag(skb, TCA_NETEM_LOSS); + if (nest == NULL) + goto nla_put_failure; + + switch (q->loss_model) { + case CLG_RANDOM: + /* legacy loss model */ + nla_nest_cancel(skb, nest); + return 0; /* no data */ + + case CLG_4_STATES: { + struct tc_netem_gimodel gi = { + .p13 = q->clg.a1, + .p31 = q->clg.a2, + .p32 = q->clg.a3, + .p14 = q->clg.a4, + .p23 = q->clg.a5, + }; + + if (nla_put(skb, NETEM_LOSS_GI, sizeof(gi), &gi)) + goto nla_put_failure; + break; + } + case CLG_GILB_ELL: { + struct tc_netem_gemodel ge = { + .p = q->clg.a1, + .r = q->clg.a2, + .h = q->clg.a3, + .k1 = q->clg.a4, + }; + + if (nla_put(skb, NETEM_LOSS_GE, sizeof(ge), &ge)) + goto nla_put_failure; + break; + } + } + + nla_nest_end(skb, nest); + return 0; + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int netem_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + const struct netem_sched_data *q = qdisc_priv(sch); + struct nlattr *nla = (struct nlattr *) skb_tail_pointer(skb); + struct tc_netem_qopt qopt; + struct tc_netem_corr cor; + struct tc_netem_reorder reorder; + struct tc_netem_corrupt corrupt; + struct tc_netem_rate rate; + struct tc_netem_slot slot; + + qopt.latency = min_t(psched_time_t, PSCHED_NS2TICKS(q->latency), + UINT_MAX); + qopt.jitter = min_t(psched_time_t, PSCHED_NS2TICKS(q->jitter), + UINT_MAX); + qopt.limit = q->limit; + qopt.loss = q->loss; + qopt.gap = q->gap; + qopt.duplicate = q->duplicate; + if (nla_put(skb, TCA_OPTIONS, sizeof(qopt), &qopt)) + goto nla_put_failure; + + if (nla_put(skb, TCA_NETEM_LATENCY64, sizeof(q->latency), &q->latency)) + goto nla_put_failure; + + if (nla_put(skb, TCA_NETEM_JITTER64, sizeof(q->jitter), &q->jitter)) + goto nla_put_failure; + + cor.delay_corr = q->delay_cor.rho; + cor.loss_corr = q->loss_cor.rho; + cor.dup_corr = 
q->dup_cor.rho; + if (nla_put(skb, TCA_NETEM_CORR, sizeof(cor), &cor)) + goto nla_put_failure; + + reorder.probability = q->reorder; + reorder.correlation = q->reorder_cor.rho; + if (nla_put(skb, TCA_NETEM_REORDER, sizeof(reorder), &reorder)) + goto nla_put_failure; + + corrupt.probability = q->corrupt; + corrupt.correlation = q->corrupt_cor.rho; + if (nla_put(skb, TCA_NETEM_CORRUPT, sizeof(corrupt), &corrupt)) + goto nla_put_failure; + + if (q->rate >= (1ULL << 32)) { + if (nla_put_u64_64bit(skb, TCA_NETEM_RATE64, q->rate, + TCA_NETEM_PAD)) + goto nla_put_failure; + rate.rate = ~0U; + } else { + rate.rate = q->rate; + } + rate.packet_overhead = q->packet_overhead; + rate.cell_size = q->cell_size; + rate.cell_overhead = q->cell_overhead; + if (nla_put(skb, TCA_NETEM_RATE, sizeof(rate), &rate)) + goto nla_put_failure; + + if (q->ecn && nla_put_u32(skb, TCA_NETEM_ECN, q->ecn)) + goto nla_put_failure; + + if (dump_loss_model(q, skb) != 0) + goto nla_put_failure; + + if (q->slot_config.min_delay | q->slot_config.max_delay | + q->slot_config.dist_jitter) { + slot = q->slot_config; + if (slot.max_packets == INT_MAX) + slot.max_packets = 0; + if (slot.max_bytes == INT_MAX) + slot.max_bytes = 0; + if (nla_put(skb, TCA_NETEM_SLOT, sizeof(slot), &slot)) + goto nla_put_failure; + } + + return nla_nest_end(skb, nla); + +nla_put_failure: + nlmsg_trim(skb, nla); + return -1; +} + +static int netem_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct netem_sched_data *q = qdisc_priv(sch); + + if (cl != 1 || !q->qdisc) /* only one class */ + return -ENOENT; + + tcm->tcm_handle |= TC_H_MIN(1); + tcm->tcm_info = q->qdisc->handle; + + return 0; +} + +static int netem_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct netem_sched_data *q = qdisc_priv(sch); + + *old = qdisc_replace(sch, new, &q->qdisc); + return 0; +} + +static struct Qdisc *netem_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct netem_sched_data *q = qdisc_priv(sch); + return q->qdisc; +} + +static unsigned long netem_find(struct Qdisc *sch, u32 classid) +{ + return 1; +} + +static void netem_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + if (!walker->stop) { + if (!tc_qdisc_stats_dump(sch, 1, walker)) + return; + } +} + +static const struct Qdisc_class_ops netem_class_ops = { + .graft = netem_graft, + .leaf = netem_leaf, + .find = netem_find, + .walk = netem_walk, + .dump = netem_dump_class, +}; + +static struct Qdisc_ops netem_qdisc_ops __read_mostly = { + .id = "netem", + .cl_ops = &netem_class_ops, + .priv_size = sizeof(struct netem_sched_data), + .enqueue = netem_enqueue, + .dequeue = netem_dequeue, + .peek = qdisc_peek_dequeued, + .init = netem_init, + .reset = netem_reset, + .destroy = netem_destroy, + .change = netem_change, + .dump = netem_dump, + .owner = THIS_MODULE, +}; + + +static int __init netem_module_init(void) +{ + pr_info("netem: version " VERSION "\n"); + return register_qdisc(&netem_qdisc_ops); +} +static void __exit netem_module_exit(void) +{ + unregister_qdisc(&netem_qdisc_ops); +} +module_init(netem_module_init) +module_exit(netem_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_pie.c b/net/sched/sch_pie.c new file mode 100644 index 000000000..b60b31ef7 --- /dev/null +++ b/net/sched/sch_pie.c @@ -0,0 +1,576 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (C) 2013 Cisco Systems, Inc, 2013. 
+ * + * Author: Vijay Subramanian + * Author: Mythili Prabhu + * + * ECN support is added by Naeem Khademi + * University of Oslo, Norway. + * + * References: + * RFC 8033: https://tools.ietf.org/html/rfc8033 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* private data for the Qdisc */ +struct pie_sched_data { + struct pie_vars vars; + struct pie_params params; + struct pie_stats stats; + struct timer_list adapt_timer; + struct Qdisc *sch; +}; + +bool pie_drop_early(struct Qdisc *sch, struct pie_params *params, + struct pie_vars *vars, u32 backlog, u32 packet_size) +{ + u64 rnd; + u64 local_prob = vars->prob; + u32 mtu = psched_mtu(qdisc_dev(sch)); + + /* If there is still burst allowance left skip random early drop */ + if (vars->burst_time > 0) + return false; + + /* If current delay is less than half of target, and + * if drop prob is low already, disable early_drop + */ + if ((vars->qdelay < params->target / 2) && + (vars->prob < MAX_PROB / 5)) + return false; + + /* If we have fewer than 2 mtu-sized packets, disable pie_drop_early, + * similar to min_th in RED + */ + if (backlog < 2 * mtu) + return false; + + /* If bytemode is turned on, use packet size to compute new + * probablity. Smaller packets will have lower drop prob in this case + */ + if (params->bytemode && packet_size <= mtu) + local_prob = (u64)packet_size * div_u64(local_prob, mtu); + else + local_prob = vars->prob; + + if (local_prob == 0) + vars->accu_prob = 0; + else + vars->accu_prob += local_prob; + + if (vars->accu_prob < (MAX_PROB / 100) * 85) + return false; + if (vars->accu_prob >= (MAX_PROB / 2) * 17) + return true; + + get_random_bytes(&rnd, 8); + if ((rnd >> BITS_PER_BYTE) < local_prob) { + vars->accu_prob = 0; + return true; + } + + return false; +} +EXPORT_SYMBOL_GPL(pie_drop_early); + +static int pie_qdisc_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct pie_sched_data *q = qdisc_priv(sch); + bool enqueue = false; + + if (unlikely(qdisc_qlen(sch) >= sch->limit)) { + q->stats.overlimit++; + goto out; + } + + if (!pie_drop_early(sch, &q->params, &q->vars, sch->qstats.backlog, + skb->len)) { + enqueue = true; + } else if (q->params.ecn && (q->vars.prob <= MAX_PROB / 10) && + INET_ECN_set_ce(skb)) { + /* If packet is ecn capable, mark it if drop probability + * is lower than 10%, else drop it. + */ + q->stats.ecn_mark++; + enqueue = true; + } + + /* we can enqueue the packet */ + if (enqueue) { + /* Set enqueue time only when dq_rate_estimator is disabled. 
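+ * When the rate estimator is enabled, pie_calculate_probability()
+ * derives qdelay from backlog and avg_dq_rate instead, so no
+ * per-packet timestamp is needed.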
*/ + if (!q->params.dq_rate_estimator) + pie_set_enqueue_time(skb); + + q->stats.packets_in++; + if (qdisc_qlen(sch) > q->stats.maxq) + q->stats.maxq = qdisc_qlen(sch); + + return qdisc_enqueue_tail(skb, sch); + } + +out: + q->stats.dropped++; + q->vars.accu_prob = 0; + return qdisc_drop(skb, sch, to_free); +} + +static const struct nla_policy pie_policy[TCA_PIE_MAX + 1] = { + [TCA_PIE_TARGET] = {.type = NLA_U32}, + [TCA_PIE_LIMIT] = {.type = NLA_U32}, + [TCA_PIE_TUPDATE] = {.type = NLA_U32}, + [TCA_PIE_ALPHA] = {.type = NLA_U32}, + [TCA_PIE_BETA] = {.type = NLA_U32}, + [TCA_PIE_ECN] = {.type = NLA_U32}, + [TCA_PIE_BYTEMODE] = {.type = NLA_U32}, + [TCA_PIE_DQ_RATE_ESTIMATOR] = {.type = NLA_U32}, +}; + +static int pie_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct pie_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_PIE_MAX + 1]; + unsigned int qlen, dropped = 0; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_PIE_MAX, opt, pie_policy, + NULL); + if (err < 0) + return err; + + sch_tree_lock(sch); + + /* convert from microseconds to pschedtime */ + if (tb[TCA_PIE_TARGET]) { + /* target is in us */ + u32 target = nla_get_u32(tb[TCA_PIE_TARGET]); + + /* convert to pschedtime */ + q->params.target = PSCHED_NS2TICKS((u64)target * NSEC_PER_USEC); + } + + /* tupdate is in jiffies */ + if (tb[TCA_PIE_TUPDATE]) + q->params.tupdate = + usecs_to_jiffies(nla_get_u32(tb[TCA_PIE_TUPDATE])); + + if (tb[TCA_PIE_LIMIT]) { + u32 limit = nla_get_u32(tb[TCA_PIE_LIMIT]); + + q->params.limit = limit; + sch->limit = limit; + } + + if (tb[TCA_PIE_ALPHA]) + q->params.alpha = nla_get_u32(tb[TCA_PIE_ALPHA]); + + if (tb[TCA_PIE_BETA]) + q->params.beta = nla_get_u32(tb[TCA_PIE_BETA]); + + if (tb[TCA_PIE_ECN]) + q->params.ecn = nla_get_u32(tb[TCA_PIE_ECN]); + + if (tb[TCA_PIE_BYTEMODE]) + q->params.bytemode = nla_get_u32(tb[TCA_PIE_BYTEMODE]); + + if (tb[TCA_PIE_DQ_RATE_ESTIMATOR]) + q->params.dq_rate_estimator = + nla_get_u32(tb[TCA_PIE_DQ_RATE_ESTIMATOR]); + + /* Drop excess packets if new limit is lower */ + qlen = sch->q.qlen; + while (sch->q.qlen > sch->limit) { + struct sk_buff *skb = __qdisc_dequeue_head(&sch->q); + + dropped += qdisc_pkt_len(skb); + qdisc_qstats_backlog_dec(sch, skb); + rtnl_qdisc_drop(skb, sch); + } + qdisc_tree_reduce_backlog(sch, qlen - sch->q.qlen, dropped); + + sch_tree_unlock(sch); + return 0; +} + +void pie_process_dequeue(struct sk_buff *skb, struct pie_params *params, + struct pie_vars *vars, u32 backlog) +{ + psched_time_t now = psched_get_time(); + u32 dtime = 0; + + /* If dq_rate_estimator is disabled, calculate qdelay using the + * packet timestamp. + */ + if (!params->dq_rate_estimator) { + vars->qdelay = now - pie_get_enqueue_time(skb); + + if (vars->dq_tstamp != DTIME_INVALID) + dtime = now - vars->dq_tstamp; + + vars->dq_tstamp = now; + + if (backlog == 0) + vars->qdelay = 0; + + if (dtime == 0) + return; + + goto burst_allowance_reduction; + } + + /* If current queue is about 10 packets or more and dq_count is unset + * we have enough packets to calculate the drain rate. Save + * current time as dq_tstamp and start measurement cycle. + */ + if (backlog >= QUEUE_THRESHOLD && vars->dq_count == DQCOUNT_INVALID) { + vars->dq_tstamp = psched_get_time(); + vars->dq_count = 0; + } + + /* Calculate the average drain rate from this value. If queue length + * has receded to a small value viz., <= QUEUE_THRESHOLD bytes, reset + * the dq_count to -1 as we don't have enough packets to calculate the + * drain rate anymore. 
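+ * The running average below is an EWMA with weight 1/8: once dq_count
+ * bytes have drained in dtime ticks, the sample
+ * count = (dq_count << PIE_SCALE) / dtime is folded in as
+ * avg_dq_rate = avg_dq_rate - avg_dq_rate/8 + count/8.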
The following if block is entered only when we + * have a substantial queue built up (QUEUE_THRESHOLD bytes or more) + * and we calculate the drain rate for the threshold here. dq_count is + * in bytes, time difference in psched_time, hence rate is in + * bytes/psched_time. + */ + if (vars->dq_count != DQCOUNT_INVALID) { + vars->dq_count += skb->len; + + if (vars->dq_count >= QUEUE_THRESHOLD) { + u32 count = vars->dq_count << PIE_SCALE; + + dtime = now - vars->dq_tstamp; + + if (dtime == 0) + return; + + count = count / dtime; + + if (vars->avg_dq_rate == 0) + vars->avg_dq_rate = count; + else + vars->avg_dq_rate = + (vars->avg_dq_rate - + (vars->avg_dq_rate >> 3)) + (count >> 3); + + /* If the queue has receded below the threshold, we hold + * on to the last drain rate calculated, else we reset + * dq_count to 0 to re-enter the if block when the next + * packet is dequeued + */ + if (backlog < QUEUE_THRESHOLD) { + vars->dq_count = DQCOUNT_INVALID; + } else { + vars->dq_count = 0; + vars->dq_tstamp = psched_get_time(); + } + + goto burst_allowance_reduction; + } + } + + return; + +burst_allowance_reduction: + if (vars->burst_time > 0) { + if (vars->burst_time > dtime) + vars->burst_time -= dtime; + else + vars->burst_time = 0; + } +} +EXPORT_SYMBOL_GPL(pie_process_dequeue); + +void pie_calculate_probability(struct pie_params *params, struct pie_vars *vars, + u32 backlog) +{ + psched_time_t qdelay = 0; /* in pschedtime */ + psched_time_t qdelay_old = 0; /* in pschedtime */ + s64 delta = 0; /* determines the change in probability */ + u64 oldprob; + u64 alpha, beta; + u32 power; + bool update_prob = true; + + if (params->dq_rate_estimator) { + qdelay_old = vars->qdelay; + vars->qdelay_old = vars->qdelay; + + if (vars->avg_dq_rate > 0) + qdelay = (backlog << PIE_SCALE) / vars->avg_dq_rate; + else + qdelay = 0; + } else { + qdelay = vars->qdelay; + qdelay_old = vars->qdelay_old; + } + + /* If qdelay is zero and backlog is not, it means backlog is very small, + * so we do not update probabilty in this round. + */ + if (qdelay == 0 && backlog != 0) + update_prob = false; + + /* In the algorithm, alpha and beta are between 0 and 2 with typical + * value for alpha as 0.125. In this implementation, we use values 0-32 + * passed from user space to represent this. Also, alpha and beta have + * unit of HZ and need to be scaled before they can used to update + * probability. alpha/beta are updated locally below by scaling down + * by 16 to come to 0-2 range. + */ + alpha = ((u64)params->alpha * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; + beta = ((u64)params->beta * (MAX_PROB / PSCHED_TICKS_PER_SEC)) >> 4; + + /* We scale alpha and beta differently depending on how heavy the + * congestion is. Please see RFC 8033 for details. 
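+ * As an illustrative example, a userspace alpha of 2 represents
+ * 2/16 = 0.125, so the scaled alpha computed above is
+ * 0.125 * MAX_PROB / PSCHED_TICKS_PER_SEC, and the shifts below
+ * reduce it further while the drop probability is still small.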
+ */ + if (vars->prob < MAX_PROB / 10) { + alpha >>= 1; + beta >>= 1; + + power = 100; + while (vars->prob < div_u64(MAX_PROB, power) && + power <= 1000000) { + alpha >>= 2; + beta >>= 2; + power *= 10; + } + } + + /* alpha and beta should be between 0 and 32, in multiples of 1/16 */ + delta += alpha * (qdelay - params->target); + delta += beta * (qdelay - qdelay_old); + + oldprob = vars->prob; + + /* to ensure we increase probability in steps of no more than 2% */ + if (delta > (s64)(MAX_PROB / (100 / 2)) && + vars->prob >= MAX_PROB / 10) + delta = (MAX_PROB / 100) * 2; + + /* Non-linear drop: + * Tune drop probability to increase quickly for high delays(>= 250ms) + * 250ms is derived through experiments and provides error protection + */ + + if (qdelay > (PSCHED_NS2TICKS(250 * NSEC_PER_MSEC))) + delta += MAX_PROB / (100 / 2); + + vars->prob += delta; + + if (delta > 0) { + /* prevent overflow */ + if (vars->prob < oldprob) { + vars->prob = MAX_PROB; + /* Prevent normalization error. If probability is at + * maximum value already, we normalize it here, and + * skip the check to do a non-linear drop in the next + * section. + */ + update_prob = false; + } + } else { + /* prevent underflow */ + if (vars->prob > oldprob) + vars->prob = 0; + } + + /* Non-linear drop in probability: Reduce drop probability quickly if + * delay is 0 for 2 consecutive Tupdate periods. + */ + + if (qdelay == 0 && qdelay_old == 0 && update_prob) + /* Reduce drop probability to 98.4% */ + vars->prob -= vars->prob / 64; + + vars->qdelay = qdelay; + vars->backlog_old = backlog; + + /* We restart the measurement cycle if the following conditions are met + * 1. If the delay has been low for 2 consecutive Tupdate periods + * 2. Calculated drop probability is zero + * 3. If average dq_rate_estimator is enabled, we have at least one + * estimate for the avg_dq_rate ie., is a non-zero value + */ + if ((vars->qdelay < params->target / 2) && + (vars->qdelay_old < params->target / 2) && + vars->prob == 0 && + (!params->dq_rate_estimator || vars->avg_dq_rate > 0)) { + pie_vars_init(vars); + } + + if (!params->dq_rate_estimator) + vars->qdelay_old = qdelay; +} +EXPORT_SYMBOL_GPL(pie_calculate_probability); + +static void pie_timer(struct timer_list *t) +{ + struct pie_sched_data *q = from_timer(q, t, adapt_timer); + struct Qdisc *sch = q->sch; + spinlock_t *root_lock; + + rcu_read_lock(); + root_lock = qdisc_lock(qdisc_root_sleeping(sch)); + spin_lock(root_lock); + pie_calculate_probability(&q->params, &q->vars, sch->qstats.backlog); + + /* reset the timer to fire after 'tupdate'. tupdate is in jiffies. 
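+ * A tupdate of zero leaves the timer disarmed, which is how
+ * pie_destroy() makes sure the probability recalculation stops.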
*/ + if (q->params.tupdate) + mod_timer(&q->adapt_timer, jiffies + q->params.tupdate); + spin_unlock(root_lock); + rcu_read_unlock(); +} + +static int pie_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct pie_sched_data *q = qdisc_priv(sch); + + pie_params_init(&q->params); + pie_vars_init(&q->vars); + sch->limit = q->params.limit; + + q->sch = sch; + timer_setup(&q->adapt_timer, pie_timer, 0); + + if (opt) { + int err = pie_change(sch, opt, extack); + + if (err) + return err; + } + + mod_timer(&q->adapt_timer, jiffies + HZ / 2); + return 0; +} + +static int pie_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct pie_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!opts) + goto nla_put_failure; + + /* convert target from pschedtime to us */ + if (nla_put_u32(skb, TCA_PIE_TARGET, + ((u32)PSCHED_TICKS2NS(q->params.target)) / + NSEC_PER_USEC) || + nla_put_u32(skb, TCA_PIE_LIMIT, sch->limit) || + nla_put_u32(skb, TCA_PIE_TUPDATE, + jiffies_to_usecs(q->params.tupdate)) || + nla_put_u32(skb, TCA_PIE_ALPHA, q->params.alpha) || + nla_put_u32(skb, TCA_PIE_BETA, q->params.beta) || + nla_put_u32(skb, TCA_PIE_ECN, q->params.ecn) || + nla_put_u32(skb, TCA_PIE_BYTEMODE, q->params.bytemode) || + nla_put_u32(skb, TCA_PIE_DQ_RATE_ESTIMATOR, + q->params.dq_rate_estimator)) + goto nla_put_failure; + + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -1; +} + +static int pie_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct pie_sched_data *q = qdisc_priv(sch); + struct tc_pie_xstats st = { + .prob = q->vars.prob << BITS_PER_BYTE, + .delay = ((u32)PSCHED_TICKS2NS(q->vars.qdelay)) / + NSEC_PER_USEC, + .packets_in = q->stats.packets_in, + .overlimit = q->stats.overlimit, + .maxq = q->stats.maxq, + .dropped = q->stats.dropped, + .ecn_mark = q->stats.ecn_mark, + }; + + /* avg_dq_rate is only valid if dq_rate_estimator is enabled */ + st.dq_rate_estimating = q->params.dq_rate_estimator; + + /* unscale and return dq_rate in bytes per sec */ + if (q->params.dq_rate_estimator) + st.avg_dq_rate = q->vars.avg_dq_rate * + (PSCHED_TICKS_PER_SEC) >> PIE_SCALE; + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static struct sk_buff *pie_qdisc_dequeue(struct Qdisc *sch) +{ + struct pie_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb = qdisc_dequeue_head(sch); + + if (!skb) + return NULL; + + pie_process_dequeue(skb, &q->params, &q->vars, sch->qstats.backlog); + return skb; +} + +static void pie_reset(struct Qdisc *sch) +{ + struct pie_sched_data *q = qdisc_priv(sch); + + qdisc_reset_queue(sch); + pie_vars_init(&q->vars); +} + +static void pie_destroy(struct Qdisc *sch) +{ + struct pie_sched_data *q = qdisc_priv(sch); + + q->params.tupdate = 0; + del_timer_sync(&q->adapt_timer); +} + +static struct Qdisc_ops pie_qdisc_ops __read_mostly = { + .id = "pie", + .priv_size = sizeof(struct pie_sched_data), + .enqueue = pie_qdisc_enqueue, + .dequeue = pie_qdisc_dequeue, + .peek = qdisc_peek_dequeued, + .init = pie_init, + .destroy = pie_destroy, + .reset = pie_reset, + .change = pie_change, + .dump = pie_dump, + .dump_stats = pie_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init pie_module_init(void) +{ + return register_qdisc(&pie_qdisc_ops); +} + +static void __exit pie_module_exit(void) +{ + unregister_qdisc(&pie_qdisc_ops); +} + +module_init(pie_module_init); +module_exit(pie_module_exit); + +MODULE_DESCRIPTION("Proportional Integral 
controller Enhanced (PIE) scheduler"); +MODULE_AUTHOR("Vijay Subramanian"); +MODULE_AUTHOR("Mythili Prabhu"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_plug.c b/net/sched/sch_plug.c new file mode 100644 index 000000000..35f49edf6 --- /dev/null +++ b/net/sched/sch_plug.c @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * sch_plug.c Queue traffic until an explicit release command + * + * There are two ways to use this qdisc: + * 1. A simple "instantaneous" plug/unplug operation, by issuing an alternating + * sequence of TCQ_PLUG_BUFFER & TCQ_PLUG_RELEASE_INDEFINITE commands. + * + * 2. For network output buffering (a.k.a output commit) functionality. + * Output commit property is commonly used by applications using checkpoint + * based fault-tolerance to ensure that the checkpoint from which a system + * is being restored is consistent w.r.t outside world. + * + * Consider for e.g. Remus - a Virtual Machine checkpointing system, + * wherein a VM is checkpointed, say every 50ms. The checkpoint is replicated + * asynchronously to the backup host, while the VM continues executing the + * next epoch speculatively. + * + * The following is a typical sequence of output buffer operations: + * 1.At epoch i, start_buffer(i) + * 2. At end of epoch i (i.e. after 50ms): + * 2.1 Stop VM and take checkpoint(i). + * 2.2 start_buffer(i+1) and Resume VM + * 3. While speculatively executing epoch(i+1), asynchronously replicate + * checkpoint(i) to backup host. + * 4. When checkpoint_ack(i) is received from backup, release_buffer(i) + * Thus, this Qdisc would receive the following sequence of commands: + * TCQ_PLUG_BUFFER (epoch i) + * .. TCQ_PLUG_BUFFER (epoch i+1) + * ....TCQ_PLUG_RELEASE_ONE (epoch i) + * ......TCQ_PLUG_BUFFER (epoch i+2) + * ........ + */ + +#include +#include +#include +#include +#include +#include +#include + +/* + * State of the queue, when used for network output buffering: + * + * plug(i+1) plug(i) head + * ------------------+--------------------+----------------> + * | | + * | | + * pkts_current_epoch| pkts_last_epoch |pkts_to_release + * ----------------->|<--------+--------->|+---------------> + * v v + * + */ + +struct plug_sched_data { + /* If true, the dequeue function releases all packets + * from head to end of the queue. The queue turns into + * a pass-through queue for newly arriving packets. + */ + bool unplug_indefinite; + + bool throttled; + + /* Queue Limit in bytes */ + u32 limit; + + /* Number of packets (output) from the current speculatively + * executing epoch. + */ + u32 pkts_current_epoch; + + /* Number of packets corresponding to the recently finished + * epoch. These will be released when we receive a + * TCQ_PLUG_RELEASE_ONE command. This command is typically + * issued after committing a checkpoint at the target. + */ + u32 pkts_last_epoch; + + /* + * Number of packets from the head of the queue, that can + * be released (committed checkpoint). 
+ */ + u32 pkts_to_release; +}; + +static int plug_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct plug_sched_data *q = qdisc_priv(sch); + + if (likely(sch->qstats.backlog + skb->len <= q->limit)) { + if (!q->unplug_indefinite) + q->pkts_current_epoch++; + return qdisc_enqueue_tail(skb, sch); + } + + return qdisc_drop(skb, sch, to_free); +} + +static struct sk_buff *plug_dequeue(struct Qdisc *sch) +{ + struct plug_sched_data *q = qdisc_priv(sch); + + if (q->throttled) + return NULL; + + if (!q->unplug_indefinite) { + if (!q->pkts_to_release) { + /* No more packets to dequeue. Block the queue + * and wait for the next release command. + */ + q->throttled = true; + return NULL; + } + q->pkts_to_release--; + } + + return qdisc_dequeue_head(sch); +} + +static int plug_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct plug_sched_data *q = qdisc_priv(sch); + + q->pkts_current_epoch = 0; + q->pkts_last_epoch = 0; + q->pkts_to_release = 0; + q->unplug_indefinite = false; + + if (opt == NULL) { + q->limit = qdisc_dev(sch)->tx_queue_len + * psched_mtu(qdisc_dev(sch)); + } else { + struct tc_plug_qopt *ctl = nla_data(opt); + + if (nla_len(opt) < sizeof(*ctl)) + return -EINVAL; + + q->limit = ctl->limit; + } + + q->throttled = true; + return 0; +} + +/* Receives 4 types of messages: + * TCQ_PLUG_BUFFER: Inset a plug into the queue and + * buffer any incoming packets + * TCQ_PLUG_RELEASE_ONE: Dequeue packets from queue head + * to beginning of the next plug. + * TCQ_PLUG_RELEASE_INDEFINITE: Dequeue all packets from queue. + * Stop buffering packets until the next TCQ_PLUG_BUFFER + * command is received (just act as a pass-thru queue). + * TCQ_PLUG_LIMIT: Increase/decrease queue size + */ +static int plug_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct plug_sched_data *q = qdisc_priv(sch); + struct tc_plug_qopt *msg; + + msg = nla_data(opt); + if (nla_len(opt) < sizeof(*msg)) + return -EINVAL; + + switch (msg->action) { + case TCQ_PLUG_BUFFER: + /* Save size of the current buffer */ + q->pkts_last_epoch = q->pkts_current_epoch; + q->pkts_current_epoch = 0; + if (q->unplug_indefinite) + q->throttled = true; + q->unplug_indefinite = false; + break; + case TCQ_PLUG_RELEASE_ONE: + /* Add packets from the last complete buffer to the + * packets to be released set. 
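+ * For example, if epoch i buffered 10 packets, those 10 are added to
+ * pkts_to_release here and the device queue is kicked so the dequeue
+ * path can drain them before throttling again.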
+ */ + q->pkts_to_release += q->pkts_last_epoch; + q->pkts_last_epoch = 0; + q->throttled = false; + netif_schedule_queue(sch->dev_queue); + break; + case TCQ_PLUG_RELEASE_INDEFINITE: + q->unplug_indefinite = true; + q->pkts_to_release = 0; + q->pkts_last_epoch = 0; + q->pkts_current_epoch = 0; + q->throttled = false; + netif_schedule_queue(sch->dev_queue); + break; + case TCQ_PLUG_LIMIT: + /* Limit is supplied in bytes */ + q->limit = msg->limit; + break; + default: + return -EINVAL; + } + + return 0; +} + +static struct Qdisc_ops plug_qdisc_ops __read_mostly = { + .id = "plug", + .priv_size = sizeof(struct plug_sched_data), + .enqueue = plug_enqueue, + .dequeue = plug_dequeue, + .peek = qdisc_peek_dequeued, + .init = plug_init, + .change = plug_change, + .reset = qdisc_reset_queue, + .owner = THIS_MODULE, +}; + +static int __init plug_module_init(void) +{ + return register_qdisc(&plug_qdisc_ops); +} + +static void __exit plug_module_exit(void) +{ + unregister_qdisc(&plug_qdisc_ops); +} +module_init(plug_module_init) +module_exit(plug_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_prio.c b/net/sched/sch_prio.c new file mode 100644 index 000000000..fdc5ef52c --- /dev/null +++ b/net/sched/sch_prio.c @@ -0,0 +1,435 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_prio.c Simple 3-band priority "scheduler". + * + * Authors: Alexey Kuznetsov, + * Fixes: 19990609: J Hadi Salim : + * Init -- EINVAL when opt undefined + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct prio_sched_data { + int bands; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + u8 prio2band[TC_PRIO_MAX+1]; + struct Qdisc *queues[TCQ_PRIO_BANDS]; +}; + + +static struct Qdisc * +prio_classify(struct sk_buff *skb, struct Qdisc *sch, int *qerr) +{ + struct prio_sched_data *q = qdisc_priv(sch); + u32 band = skb->priority; + struct tcf_result res; + struct tcf_proto *fl; + int err; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + if (TC_H_MAJ(skb->priority) != sch->handle) { + fl = rcu_dereference_bh(q->filter_list); + err = tcf_classify(skb, NULL, fl, &res, false); +#ifdef CONFIG_NET_CLS_ACT + switch (err) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + if (!fl || err < 0) { + if (TC_H_MAJ(band)) + band = 0; + return q->queues[q->prio2band[band & TC_PRIO_MAX]]; + } + band = res.classid; + } + band = TC_H_MIN(band) - 1; + if (band >= q->bands) + return q->queues[q->prio2band[0]]; + + return q->queues[band]; +} + +static int +prio_enqueue(struct sk_buff *skb, struct Qdisc *sch, struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb); + struct Qdisc *qdisc; + int ret; + + qdisc = prio_classify(skb, sch, &ret); +#ifdef CONFIG_NET_CLS_ACT + if (qdisc == NULL) { + + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } +#endif + + ret = qdisc_enqueue(skb, qdisc, to_free); + if (ret == NET_XMIT_SUCCESS) { + sch->qstats.backlog += len; + sch->q.qlen++; + return NET_XMIT_SUCCESS; + } + if (net_xmit_drop_count(ret)) + qdisc_qstats_drop(sch); + return ret; +} + +static struct sk_buff *prio_peek(struct Qdisc *sch) +{ + struct prio_sched_data *q = qdisc_priv(sch); + int prio; + + for (prio = 0; prio < q->bands; prio++) { + struct Qdisc *qdisc = q->queues[prio]; + struct sk_buff *skb = qdisc->ops->peek(qdisc); + if (skb) + return skb; 
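+ /* The scan order gives strict priority: a band is consulted only
+ * when every lower-numbered band turned up empty.
+ */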
+ } + return NULL; +} + +static struct sk_buff *prio_dequeue(struct Qdisc *sch) +{ + struct prio_sched_data *q = qdisc_priv(sch); + int prio; + + for (prio = 0; prio < q->bands; prio++) { + struct Qdisc *qdisc = q->queues[prio]; + struct sk_buff *skb = qdisc_dequeue_peeked(qdisc); + if (skb) { + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + return skb; + } + } + return NULL; + +} + +static void +prio_reset(struct Qdisc *sch) +{ + int prio; + struct prio_sched_data *q = qdisc_priv(sch); + + for (prio = 0; prio < q->bands; prio++) + qdisc_reset(q->queues[prio]); +} + +static int prio_offload(struct Qdisc *sch, struct tc_prio_qopt *qopt) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_prio_qopt_offload opt = { + .handle = sch->handle, + .parent = sch->parent, + }; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return -EOPNOTSUPP; + + if (qopt) { + opt.command = TC_PRIO_REPLACE; + opt.replace_params.bands = qopt->bands; + memcpy(&opt.replace_params.priomap, qopt->priomap, + TC_PRIO_MAX + 1); + opt.replace_params.qstats = &sch->qstats; + } else { + opt.command = TC_PRIO_DESTROY; + } + + return dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_PRIO, &opt); +} + +static void +prio_destroy(struct Qdisc *sch) +{ + int prio; + struct prio_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + prio_offload(sch, NULL); + for (prio = 0; prio < q->bands; prio++) + qdisc_put(q->queues[prio]); +} + +static int prio_tune(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct prio_sched_data *q = qdisc_priv(sch); + struct Qdisc *queues[TCQ_PRIO_BANDS]; + int oldbands = q->bands, i; + struct tc_prio_qopt *qopt; + + if (nla_len(opt) < sizeof(*qopt)) + return -EINVAL; + qopt = nla_data(opt); + + if (qopt->bands > TCQ_PRIO_BANDS || qopt->bands < TCQ_MIN_PRIO_BANDS) + return -EINVAL; + + for (i = 0; i <= TC_PRIO_MAX; i++) { + if (qopt->priomap[i] >= qopt->bands) + return -EINVAL; + } + + /* Before commit, make sure we can allocate all new qdiscs */ + for (i = oldbands; i < qopt->bands; i++) { + queues[i] = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + TC_H_MAKE(sch->handle, i + 1), + extack); + if (!queues[i]) { + while (i > oldbands) + qdisc_put(queues[--i]); + return -ENOMEM; + } + } + + prio_offload(sch, qopt); + sch_tree_lock(sch); + q->bands = qopt->bands; + memcpy(q->prio2band, qopt->priomap, TC_PRIO_MAX+1); + + for (i = q->bands; i < oldbands; i++) + qdisc_tree_flush_backlog(q->queues[i]); + + for (i = oldbands; i < q->bands; i++) { + q->queues[i] = queues[i]; + if (q->queues[i] != &noop_qdisc) + qdisc_hash_add(q->queues[i], true); + } + + sch_tree_unlock(sch); + + for (i = q->bands; i < oldbands; i++) + qdisc_put(q->queues[i]); + return 0; +} + +static int prio_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct prio_sched_data *q = qdisc_priv(sch); + int err; + + if (!opt) + return -EINVAL; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + return prio_tune(sch, opt, extack); +} + +static int prio_dump_offload(struct Qdisc *sch) +{ + struct tc_prio_qopt_offload hw_stats = { + .command = TC_PRIO_STATS, + .handle = sch->handle, + .parent = sch->parent, + { + .stats = { + .bstats = &sch->bstats, + .qstats = &sch->qstats, + }, + }, + }; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_PRIO, &hw_stats); +} + +static int prio_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct prio_sched_data *q = 
qdisc_priv(sch); + unsigned char *b = skb_tail_pointer(skb); + struct tc_prio_qopt opt; + int err; + + opt.bands = q->bands; + memcpy(&opt.priomap, q->prio2band, TC_PRIO_MAX + 1); + + err = prio_dump_offload(sch); + if (err) + goto nla_put_failure; + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static int prio_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct prio_sched_data *q = qdisc_priv(sch); + struct tc_prio_qopt_offload graft_offload; + unsigned long band = arg - 1; + + if (!new) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + TC_H_MAKE(sch->handle, arg), extack); + if (!new) + new = &noop_qdisc; + else + qdisc_hash_add(new, true); + } + + *old = qdisc_replace(sch, new, &q->queues[band]); + + graft_offload.handle = sch->handle; + graft_offload.parent = sch->parent; + graft_offload.graft_params.band = band; + graft_offload.graft_params.child_handle = new->handle; + graft_offload.command = TC_PRIO_GRAFT; + + qdisc_offload_graft_helper(qdisc_dev(sch), sch, new, *old, + TC_SETUP_QDISC_PRIO, &graft_offload, + extack); + return 0; +} + +static struct Qdisc * +prio_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct prio_sched_data *q = qdisc_priv(sch); + unsigned long band = arg - 1; + + return q->queues[band]; +} + +static unsigned long prio_find(struct Qdisc *sch, u32 classid) +{ + struct prio_sched_data *q = qdisc_priv(sch); + unsigned long band = TC_H_MIN(classid); + + if (band - 1 >= q->bands) + return 0; + return band; +} + +static unsigned long prio_bind(struct Qdisc *sch, unsigned long parent, u32 classid) +{ + return prio_find(sch, classid); +} + + +static void prio_unbind(struct Qdisc *q, unsigned long cl) +{ +} + +static int prio_dump_class(struct Qdisc *sch, unsigned long cl, struct sk_buff *skb, + struct tcmsg *tcm) +{ + struct prio_sched_data *q = qdisc_priv(sch); + + tcm->tcm_handle |= TC_H_MIN(cl); + tcm->tcm_info = q->queues[cl-1]->handle; + return 0; +} + +static int prio_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct prio_sched_data *q = qdisc_priv(sch); + struct Qdisc *cl_q; + + cl_q = q->queues[cl - 1]; + if (gnet_stats_copy_basic(d, cl_q->cpu_bstats, + &cl_q->bstats, true) < 0 || + qdisc_qstats_copy(d, cl_q) < 0) + return -1; + + return 0; +} + +static void prio_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct prio_sched_data *q = qdisc_priv(sch); + int prio; + + if (arg->stop) + return; + + for (prio = 0; prio < q->bands; prio++) { + if (!tc_qdisc_stats_dump(sch, prio + 1, arg)) + break; + } +} + +static struct tcf_block *prio_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct prio_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static const struct Qdisc_class_ops prio_class_ops = { + .graft = prio_graft, + .leaf = prio_leaf, + .find = prio_find, + .walk = prio_walk, + .tcf_block = prio_tcf_block, + .bind_tcf = prio_bind, + .unbind_tcf = prio_unbind, + .dump = prio_dump_class, + .dump_stats = prio_dump_class_stats, +}; + +static struct Qdisc_ops prio_qdisc_ops __read_mostly = { + .next = NULL, + .cl_ops = &prio_class_ops, + .id = "prio", + .priv_size = sizeof(struct prio_sched_data), + .enqueue = prio_enqueue, + .dequeue = prio_dequeue, + .peek = prio_peek, + .init = prio_init, + .reset = prio_reset, + .destroy = prio_destroy, + 
.change = prio_tune, + .dump = prio_dump, + .owner = THIS_MODULE, +}; + +static int __init prio_module_init(void) +{ + return register_qdisc(&prio_qdisc_ops); +} + +static void __exit prio_module_exit(void) +{ + unregister_qdisc(&prio_qdisc_ops); +} + +module_init(prio_module_init) +module_exit(prio_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_qfq.c b/net/sched/sch_qfq.c new file mode 100644 index 000000000..ed01634af --- /dev/null +++ b/net/sched/sch_qfq.c @@ -0,0 +1,1536 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_qfq.c Quick Fair Queueing Plus Scheduler. + * + * Copyright (c) 2009 Fabio Checconi, Luigi Rizzo, and Paolo Valente. + * Copyright (c) 2012 Paolo Valente. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +/* Quick Fair Queueing Plus + ======================== + + Sources: + + [1] Paolo Valente, + "Reducing the Execution Time of Fair-Queueing Schedulers." + http://algo.ing.unimo.it/people/paolo/agg-sched/agg-sched.pdf + + Sources for QFQ: + + [2] Fabio Checconi, Luigi Rizzo, and Paolo Valente: "QFQ: Efficient + Packet Scheduling with Tight Bandwidth Distribution Guarantees." + + See also: + http://retis.sssup.it/~fabio/linux/qfq/ + */ + +/* + + QFQ+ divides classes into aggregates of at most MAX_AGG_CLASSES + classes. Each aggregate is timestamped with a virtual start time S + and a virtual finish time F, and scheduled according to its + timestamps. S and F are computed as a function of a system virtual + time function V. The classes within each aggregate are instead + scheduled with DRR. + + To speed up operations, QFQ+ divides also aggregates into a limited + number of groups. Which group a class belongs to depends on the + ratio between the maximum packet length for the class and the weight + of the class. Groups have their own S and F. In the end, QFQ+ + schedules groups, then aggregates within groups, then classes within + aggregates. See [1] and [2] for a full description. + + Virtual time computations. + + S, F and V are all computed in fixed point arithmetic with + FRAC_BITS decimal bits. + + QFQ_MAX_INDEX is the maximum index allowed for a group. We need + one bit per index. + QFQ_MAX_WSHIFT is the maximum power of two supported as a weight. + + The layout of the bits is as below: + + [ MTU_SHIFT ][ FRAC_BITS ] + [ MAX_INDEX ][ MIN_SLOT_SHIFT ] + ^.__grp->index = 0 + *.__grp->slot_shift + + where MIN_SLOT_SHIFT is derived by difference from the others. + + The max group index corresponds to Lmax/w_min, where + Lmax=1<group mapping. We allow class weights that are + * in the range [1, 2^MAX_WSHIFT], and we try to map each aggregate i to the + * group with the smallest index that can support the L_i / r_i configured + * for the classes in the aggregate. + * + * grp->index is the index of the group; and grp->slot_shift + * is the shift for the corresponding (scaled) sigma_i. + */ +#define QFQ_MAX_INDEX 24 +#define QFQ_MAX_WSHIFT 10 + +#define QFQ_MAX_WEIGHT (1<clhash, classid); + if (clc == NULL) + return NULL; + return container_of(clc, struct qfq_class, common); +} + +static struct netlink_range_validation lmax_range = { + .min = QFQ_MIN_LMAX, + .max = QFQ_MAX_LMAX, +}; + +static const struct nla_policy qfq_policy[TCA_QFQ_MAX + 1] = { + [TCA_QFQ_WEIGHT] = NLA_POLICY_RANGE(NLA_U32, 1, QFQ_MAX_WEIGHT), + [TCA_QFQ_LMAX] = NLA_POLICY_FULL_RANGE(NLA_U32, &lmax_range), +}; + +/* + * Calculate a flow index, given its weight and maximum packet length. 
+ * index = log_2(maxlen/weight) but we need to apply the scaling. + * This is used only once at flow creation. + */ +static int qfq_calc_index(u32 inv_w, unsigned int maxlen, u32 min_slot_shift) +{ + u64 slot_size = (u64)maxlen * inv_w; + unsigned long size_map; + int index = 0; + + size_map = slot_size >> min_slot_shift; + if (!size_map) + goto out; + + index = __fls(size_map) + 1; /* basically a log_2 */ + index -= !(slot_size - (1ULL << (index + min_slot_shift - 1))); + + if (index < 0) + index = 0; +out: + pr_debug("qfq calc_index: W = %lu, L = %u, I = %d\n", + (unsigned long) ONE_FP/inv_w, maxlen, index); + + return index; +} + +static void qfq_deactivate_agg(struct qfq_sched *, struct qfq_aggregate *); +static void qfq_activate_agg(struct qfq_sched *, struct qfq_aggregate *, + enum update_reason); + +static void qfq_init_agg(struct qfq_sched *q, struct qfq_aggregate *agg, + u32 lmax, u32 weight) +{ + INIT_LIST_HEAD(&agg->active); + hlist_add_head(&agg->nonfull_next, &q->nonfull_aggs); + + agg->lmax = lmax; + agg->class_weight = weight; +} + +static struct qfq_aggregate *qfq_find_agg(struct qfq_sched *q, + u32 lmax, u32 weight) +{ + struct qfq_aggregate *agg; + + hlist_for_each_entry(agg, &q->nonfull_aggs, nonfull_next) + if (agg->lmax == lmax && agg->class_weight == weight) + return agg; + + return NULL; +} + + +/* Update aggregate as a function of the new number of classes. */ +static void qfq_update_agg(struct qfq_sched *q, struct qfq_aggregate *agg, + int new_num_classes) +{ + u32 new_agg_weight; + + if (new_num_classes == q->max_agg_classes) + hlist_del_init(&agg->nonfull_next); + + if (agg->num_classes > new_num_classes && + new_num_classes == q->max_agg_classes - 1) /* agg no more full */ + hlist_add_head(&agg->nonfull_next, &q->nonfull_aggs); + + /* The next assignment may let + * agg->initial_budget > agg->budgetmax + * hold, we will take it into account in charge_actual_service(). + */ + agg->budgetmax = new_num_classes * agg->lmax; + new_agg_weight = agg->class_weight * new_num_classes; + agg->inv_w = ONE_FP/new_agg_weight; + + if (agg->grp == NULL) { + int i = qfq_calc_index(agg->inv_w, agg->budgetmax, + q->min_slot_shift); + agg->grp = &q->groups[i]; + } + + q->wsum += + (int) agg->class_weight * (new_num_classes - agg->num_classes); + q->iwsum = ONE_FP / q->wsum; + + agg->num_classes = new_num_classes; +} + +/* Add class to aggregate. */ +static void qfq_add_to_agg(struct qfq_sched *q, + struct qfq_aggregate *agg, + struct qfq_class *cl) +{ + cl->agg = agg; + + qfq_update_agg(q, agg, agg->num_classes+1); + if (cl->qdisc->q.qlen > 0) { /* adding an active class */ + list_add_tail(&cl->alist, &agg->active); + if (list_first_entry(&agg->active, struct qfq_class, alist) == + cl && q->in_serv_agg != agg) /* agg was inactive */ + qfq_activate_agg(q, agg, enqueue); /* schedule agg */ + } +} + +static struct qfq_aggregate *qfq_choose_next_agg(struct qfq_sched *); + +static void qfq_destroy_agg(struct qfq_sched *q, struct qfq_aggregate *agg) +{ + hlist_del_init(&agg->nonfull_next); + q->wsum -= agg->class_weight; + if (q->wsum != 0) + q->iwsum = ONE_FP / q->wsum; + + if (q->in_serv_agg == agg) + q->in_serv_agg = qfq_choose_next_agg(q); + kfree(agg); +} + +/* Deschedule class from within its parent aggregate. 
*/ +static void qfq_deactivate_class(struct qfq_sched *q, struct qfq_class *cl) +{ + struct qfq_aggregate *agg = cl->agg; + + + list_del(&cl->alist); /* remove from RR queue of the aggregate */ + if (list_empty(&agg->active)) /* agg is now inactive */ + qfq_deactivate_agg(q, agg); +} + +/* Remove class from its parent aggregate. */ +static void qfq_rm_from_agg(struct qfq_sched *q, struct qfq_class *cl) +{ + struct qfq_aggregate *agg = cl->agg; + + cl->agg = NULL; + if (agg->num_classes == 1) { /* agg being emptied, destroy it */ + qfq_destroy_agg(q, agg); + return; + } + qfq_update_agg(q, agg, agg->num_classes-1); +} + +/* Deschedule class and remove it from its parent aggregate. */ +static void qfq_deact_rm_from_agg(struct qfq_sched *q, struct qfq_class *cl) +{ + if (cl->qdisc->q.qlen > 0) /* class is active */ + qfq_deactivate_class(q, cl); + + qfq_rm_from_agg(q, cl); +} + +/* Move class to a new aggregate, matching the new class weight and/or lmax */ +static int qfq_change_agg(struct Qdisc *sch, struct qfq_class *cl, u32 weight, + u32 lmax) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_aggregate *new_agg; + + /* 'lmax' can range from [QFQ_MIN_LMAX, pktlen + stab overhead] */ + if (lmax > QFQ_MAX_LMAX) + return -EINVAL; + + new_agg = qfq_find_agg(q, lmax, weight); + if (new_agg == NULL) { /* create new aggregate */ + new_agg = kzalloc(sizeof(*new_agg), GFP_ATOMIC); + if (new_agg == NULL) + return -ENOBUFS; + qfq_init_agg(q, new_agg, lmax, weight); + } + qfq_deact_rm_from_agg(q, cl); + qfq_add_to_agg(q, new_agg, cl); + + return 0; +} + +static int qfq_change_class(struct Qdisc *sch, u32 classid, u32 parentid, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl = (struct qfq_class *)*arg; + bool existing = false; + struct nlattr *tb[TCA_QFQ_MAX + 1]; + struct qfq_aggregate *new_agg = NULL; + u32 weight, lmax, inv_w; + int err; + int delta_w; + + if (tca[TCA_OPTIONS] == NULL) { + pr_notice("qfq: no options\n"); + return -EINVAL; + } + + err = nla_parse_nested_deprecated(tb, TCA_QFQ_MAX, tca[TCA_OPTIONS], + qfq_policy, extack); + if (err < 0) + return err; + + if (tb[TCA_QFQ_WEIGHT]) + weight = nla_get_u32(tb[TCA_QFQ_WEIGHT]); + else + weight = 1; + + if (tb[TCA_QFQ_LMAX]) { + lmax = nla_get_u32(tb[TCA_QFQ_LMAX]); + } else { + /* MTU size is user controlled */ + lmax = psched_mtu(qdisc_dev(sch)); + if (lmax < QFQ_MIN_LMAX || lmax > QFQ_MAX_LMAX) { + NL_SET_ERR_MSG_MOD(extack, + "MTU size out of bounds for qfq"); + return -EINVAL; + } + } + + inv_w = ONE_FP / weight; + weight = ONE_FP / inv_w; + + if (cl != NULL && + lmax == cl->agg->lmax && + weight == cl->agg->class_weight) + return 0; /* nothing to change */ + + delta_w = weight - (cl ? 
cl->agg->class_weight : 0); + + if (q->wsum + delta_w > QFQ_MAX_WSUM) { + pr_notice("qfq: total weight out of range (%d + %u)\n", + delta_w, q->wsum); + return -EINVAL; + } + + if (cl != NULL) { /* modify existing class */ + if (tca[TCA_RATE]) { + err = gen_replace_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) + return err; + } + existing = true; + goto set_change_agg; + } + + /* create and init new class */ + cl = kzalloc(sizeof(struct qfq_class), GFP_KERNEL); + if (cl == NULL) + return -ENOBUFS; + + gnet_stats_basic_sync_init(&cl->bstats); + cl->common.classid = classid; + cl->deficit = lmax; + + cl->qdisc = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + classid, NULL); + if (cl->qdisc == NULL) + cl->qdisc = &noop_qdisc; + + if (tca[TCA_RATE]) { + err = gen_new_estimator(&cl->bstats, NULL, + &cl->rate_est, + NULL, + true, + tca[TCA_RATE]); + if (err) + goto destroy_class; + } + + if (cl->qdisc != &noop_qdisc) + qdisc_hash_add(cl->qdisc, true); + +set_change_agg: + sch_tree_lock(sch); + new_agg = qfq_find_agg(q, lmax, weight); + if (new_agg == NULL) { /* create new aggregate */ + sch_tree_unlock(sch); + new_agg = kzalloc(sizeof(*new_agg), GFP_KERNEL); + if (new_agg == NULL) { + err = -ENOBUFS; + gen_kill_estimator(&cl->rate_est); + goto destroy_class; + } + sch_tree_lock(sch); + qfq_init_agg(q, new_agg, lmax, weight); + } + if (existing) + qfq_deact_rm_from_agg(q, cl); + else + qdisc_class_hash_insert(&q->clhash, &cl->common); + qfq_add_to_agg(q, new_agg, cl); + sch_tree_unlock(sch); + qdisc_class_hash_grow(sch, &q->clhash); + + *arg = (unsigned long)cl; + return 0; + +destroy_class: + qdisc_put(cl->qdisc); + kfree(cl); + return err; +} + +static void qfq_destroy_class(struct Qdisc *sch, struct qfq_class *cl) +{ + struct qfq_sched *q = qdisc_priv(sch); + + qfq_rm_from_agg(q, cl); + gen_kill_estimator(&cl->rate_est); + qdisc_put(cl->qdisc); + kfree(cl); +} + +static int qfq_delete_class(struct Qdisc *sch, unsigned long arg, + struct netlink_ext_ack *extack) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl = (struct qfq_class *)arg; + + if (cl->filter_cnt > 0) + return -EBUSY; + + sch_tree_lock(sch); + + qdisc_purge_queue(cl->qdisc); + qdisc_class_hash_remove(&q->clhash, &cl->common); + + sch_tree_unlock(sch); + + qfq_destroy_class(sch, cl); + return 0; +} + +static unsigned long qfq_search_class(struct Qdisc *sch, u32 classid) +{ + return (unsigned long)qfq_find_class(sch, classid); +} + +static struct tcf_block *qfq_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct qfq_sched *q = qdisc_priv(sch); + + if (cl) + return NULL; + + return q->block; +} + +static unsigned long qfq_bind_tcf(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + struct qfq_class *cl = qfq_find_class(sch, classid); + + if (cl != NULL) + cl->filter_cnt++; + + return (unsigned long)cl; +} + +static void qfq_unbind_tcf(struct Qdisc *sch, unsigned long arg) +{ + struct qfq_class *cl = (struct qfq_class *)arg; + + cl->filter_cnt--; +} + +static int qfq_graft_class(struct Qdisc *sch, unsigned long arg, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct qfq_class *cl = (struct qfq_class *)arg; + + if (new == NULL) { + new = qdisc_create_dflt(sch->dev_queue, &pfifo_qdisc_ops, + cl->common.classid, NULL); + if (new == NULL) + new = &noop_qdisc; + } + + *old = qdisc_replace(sch, new, &cl->qdisc); + return 0; +} + +static struct Qdisc *qfq_class_leaf(struct Qdisc *sch, 
unsigned long arg) +{ + struct qfq_class *cl = (struct qfq_class *)arg; + + return cl->qdisc; +} + +static int qfq_dump_class(struct Qdisc *sch, unsigned long arg, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct qfq_class *cl = (struct qfq_class *)arg; + struct nlattr *nest; + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle = cl->common.classid; + tcm->tcm_info = cl->qdisc->handle; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + if (nla_put_u32(skb, TCA_QFQ_WEIGHT, cl->agg->class_weight) || + nla_put_u32(skb, TCA_QFQ_LMAX, cl->agg->lmax)) + goto nla_put_failure; + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -EMSGSIZE; +} + +static int qfq_dump_class_stats(struct Qdisc *sch, unsigned long arg, + struct gnet_dump *d) +{ + struct qfq_class *cl = (struct qfq_class *)arg; + struct tc_qfq_stats xstats; + + memset(&xstats, 0, sizeof(xstats)); + + xstats.weight = cl->agg->class_weight; + xstats.lmax = cl->agg->lmax; + + if (gnet_stats_copy_basic(d, NULL, &cl->bstats, true) < 0 || + gnet_stats_copy_rate_est(d, &cl->rate_est) < 0 || + qdisc_qstats_copy(d, cl->qdisc) < 0) + return -1; + + return gnet_stats_copy_app(d, &xstats, sizeof(xstats)); +} + +static void qfq_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl; + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (!tc_qdisc_stats_dump(sch, (unsigned long)cl, arg)) + return; + } + } +} + +static struct qfq_class *qfq_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl; + struct tcf_result res; + struct tcf_proto *fl; + int result; + + if (TC_H_MAJ(skb->priority ^ sch->handle) == 0) { + pr_debug("qfq_classify: found %d\n", skb->priority); + cl = qfq_find_class(sch, skb->priority); + if (cl != NULL) + return cl; + } + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + fl = rcu_dereference_bh(q->filter_list); + result = tcf_classify(skb, NULL, fl, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_QUEUED: + case TC_ACT_STOLEN: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return NULL; + } +#endif + cl = (struct qfq_class *)res.class; + if (cl == NULL) + cl = qfq_find_class(sch, res.classid); + return cl; + } + + return NULL; +} + +/* Generic comparison function, handling wraparound. */ +static inline int qfq_gt(u64 a, u64 b) +{ + return (s64)(a - b) > 0; +} + +/* Round a precise timestamp to its slotted value. */ +static inline u64 qfq_round_down(u64 ts, unsigned int shift) +{ + return ts & ~((1ULL << shift) - 1); +} + +/* return the pointer to the group with lowest index in the bitmap */ +static inline struct qfq_group *qfq_ffs(struct qfq_sched *q, + unsigned long bitmap) +{ + int index = __ffs(bitmap); + return &q->groups[index]; +} +/* Calculate a mask to mimic what would be ffs_from(). 
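+ * E.g. mask_from(0x2d, 3) clears the three low bits and returns 0x28,
+ * so a subsequent __ffs() effectively starts its search at bit 'from'.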
*/ +static inline unsigned long mask_from(unsigned long bitmap, int from) +{ + return bitmap & ~((1UL << from) - 1); +} + +/* + * The state computation relies on ER=0, IR=1, EB=2, IB=3 + * First compute eligibility comparing grp->S, q->V, + * then check if someone is blocking us and possibly add EB + */ +static int qfq_calc_state(struct qfq_sched *q, const struct qfq_group *grp) +{ + /* if S > V we are not eligible */ + unsigned int state = qfq_gt(grp->S, q->V); + unsigned long mask = mask_from(q->bitmaps[ER], grp->index); + struct qfq_group *next; + + if (mask) { + next = qfq_ffs(q, mask); + if (qfq_gt(grp->F, next->F)) + state |= EB; + } + + return state; +} + + +/* + * In principle + * q->bitmaps[dst] |= q->bitmaps[src] & mask; + * q->bitmaps[src] &= ~mask; + * but we should make sure that src != dst + */ +static inline void qfq_move_groups(struct qfq_sched *q, unsigned long mask, + int src, int dst) +{ + q->bitmaps[dst] |= q->bitmaps[src] & mask; + q->bitmaps[src] &= ~mask; +} + +static void qfq_unblock_groups(struct qfq_sched *q, int index, u64 old_F) +{ + unsigned long mask = mask_from(q->bitmaps[ER], index + 1); + struct qfq_group *next; + + if (mask) { + next = qfq_ffs(q, mask); + if (!qfq_gt(next->F, old_F)) + return; + } + + mask = (1UL << index) - 1; + qfq_move_groups(q, mask, EB, ER); + qfq_move_groups(q, mask, IB, IR); +} + +/* + * perhaps + * + old_V ^= q->V; + old_V >>= q->min_slot_shift; + if (old_V) { + ... + } + * + */ +static void qfq_make_eligible(struct qfq_sched *q) +{ + unsigned long vslot = q->V >> q->min_slot_shift; + unsigned long old_vslot = q->oldV >> q->min_slot_shift; + + if (vslot != old_vslot) { + unsigned long mask; + int last_flip_pos = fls(vslot ^ old_vslot); + + if (last_flip_pos > 31) /* higher than the number of groups */ + mask = ~0UL; /* make all groups eligible */ + else + mask = (1UL << last_flip_pos) - 1; + + qfq_move_groups(q, mask, IR, ER); + qfq_move_groups(q, mask, IB, EB); + } +} + +/* + * The index of the slot in which the input aggregate agg is to be + * inserted must not be higher than QFQ_MAX_SLOTS-2. There is a '-2' + * and not a '-1' because the start time of the group may be moved + * backward by one slot after the aggregate has been inserted, and + * this would cause non-empty slots to be right-shifted by one + * position. + * + * QFQ+ fully satisfies this bound to the slot index if the parameters + * of the classes are not changed dynamically, and if QFQ+ never + * happens to postpone the service of agg unjustly, i.e., it never + * happens that the aggregate becomes backlogged and eligible, or just + * eligible, while an aggregate with a higher approximated finish time + * is being served. In particular, in this case QFQ+ guarantees that + * the timestamps of agg are low enough that the slot index is never + * higher than 2. Unfortunately, QFQ+ cannot provide the same + * guarantee if it happens to unjustly postpone the service of agg, or + * if the parameters of some class are changed. 
+ * + * As for the first event, i.e., an out-of-order service, the + * upper bound to the slot index guaranteed by QFQ+ grows to + * 2 + + * QFQ_MAX_AGG_CLASSES * ((1<S) >> grp->slot_shift; + unsigned int i; /* slot index in the bucket list */ + + if (unlikely(slot > QFQ_MAX_SLOTS - 2)) { + u64 deltaS = roundedS - grp->S - + ((u64)(QFQ_MAX_SLOTS - 2)<slot_shift); + agg->S -= deltaS; + agg->F -= deltaS; + slot = QFQ_MAX_SLOTS - 2; + } + + i = (grp->front + slot) % QFQ_MAX_SLOTS; + + hlist_add_head(&agg->next, &grp->slots[i]); + __set_bit(slot, &grp->full_slots); +} + +/* Maybe introduce hlist_first_entry?? */ +static struct qfq_aggregate *qfq_slot_head(struct qfq_group *grp) +{ + return hlist_entry(grp->slots[grp->front].first, + struct qfq_aggregate, next); +} + +/* + * remove the entry from the slot + */ +static void qfq_front_slot_remove(struct qfq_group *grp) +{ + struct qfq_aggregate *agg = qfq_slot_head(grp); + + BUG_ON(!agg); + hlist_del(&agg->next); + if (hlist_empty(&grp->slots[grp->front])) + __clear_bit(0, &grp->full_slots); +} + +/* + * Returns the first aggregate in the first non-empty bucket of the + * group. As a side effect, adjusts the bucket list so the first + * non-empty bucket is at position 0 in full_slots. + */ +static struct qfq_aggregate *qfq_slot_scan(struct qfq_group *grp) +{ + unsigned int i; + + pr_debug("qfq slot_scan: grp %u full %#lx\n", + grp->index, grp->full_slots); + + if (grp->full_slots == 0) + return NULL; + + i = __ffs(grp->full_slots); /* zero based */ + if (i > 0) { + grp->front = (grp->front + i) % QFQ_MAX_SLOTS; + grp->full_slots >>= i; + } + + return qfq_slot_head(grp); +} + +/* + * adjust the bucket list. When the start time of a group decreases, + * we move the index down (modulo QFQ_MAX_SLOTS) so we don't need to + * move the objects. The mask of occupied slots must be shifted + * because we use ffs() to find the first non-empty slot. + * This covers decreases in the group's start time, but what about + * increases of the start time ? + * Here too we should make sure that i is less than 32 + */ +static void qfq_slot_rotate(struct qfq_group *grp, u64 roundedS) +{ + unsigned int i = (grp->S - roundedS) >> grp->slot_shift; + + grp->full_slots <<= i; + grp->front = (grp->front - i) % QFQ_MAX_SLOTS; +} + +static void qfq_update_eligible(struct qfq_sched *q) +{ + struct qfq_group *grp; + unsigned long ineligible; + + ineligible = q->bitmaps[IR] | q->bitmaps[IB]; + if (ineligible) { + if (!q->bitmaps[ER]) { + grp = qfq_ffs(q, ineligible); + if (qfq_gt(grp->S, q->V)) + q->V = grp->S; + } + qfq_make_eligible(q); + } +} + +/* Dequeue head packet of the head class in the DRR queue of the aggregate. 
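+ * Rough sketch of the deficit round robin step performed below, with
+ * purely illustrative numbers: a class enters the active list with
+ * deficit == agg->lmax, and every dequeued packet consumes its length
+ * from that deficit.  With lmax = 3000 and back-to-back 1500-byte
+ * packets, the class sends two packets, then its next peeked packet no
+ * longer fits, so the deficit is recharged by lmax and the class is
+ * moved to the tail of agg->active.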
*/ +static struct sk_buff *agg_dequeue(struct qfq_aggregate *agg, + struct qfq_class *cl, unsigned int len) +{ + struct sk_buff *skb = qdisc_dequeue_peeked(cl->qdisc); + + if (!skb) + return NULL; + + cl->deficit -= (int) len; + + if (cl->qdisc->q.qlen == 0) /* no more packets, remove from list */ + list_del(&cl->alist); + else if (cl->deficit < qdisc_pkt_len(cl->qdisc->ops->peek(cl->qdisc))) { + cl->deficit += agg->lmax; + list_move_tail(&cl->alist, &agg->active); + } + + return skb; +} + +static inline struct sk_buff *qfq_peek_skb(struct qfq_aggregate *agg, + struct qfq_class **cl, + unsigned int *len) +{ + struct sk_buff *skb; + + *cl = list_first_entry(&agg->active, struct qfq_class, alist); + skb = (*cl)->qdisc->ops->peek((*cl)->qdisc); + if (skb == NULL) + WARN_ONCE(1, "qfq_dequeue: non-workconserving leaf\n"); + else + *len = qdisc_pkt_len(skb); + + return skb; +} + +/* Update F according to the actual service received by the aggregate. */ +static inline void charge_actual_service(struct qfq_aggregate *agg) +{ + /* Compute the service received by the aggregate, taking into + * account that, after decreasing the number of classes in + * agg, it may happen that + * agg->initial_budget - agg->budget > agg->bugdetmax + */ + u32 service_received = min(agg->budgetmax, + agg->initial_budget - agg->budget); + + agg->F = agg->S + (u64)service_received * agg->inv_w; +} + +/* Assign a reasonable start time for a new aggregate in group i. + * Admissible values for \hat(F) are multiples of \sigma_i + * no greater than V+\sigma_i . Larger values mean that + * we had a wraparound so we consider the timestamp to be stale. + * + * If F is not stale and F >= V then we set S = F. + * Otherwise we should assign S = V, but this may violate + * the ordering in EB (see [2]). So, if we have groups in ER, + * set S to the F_j of the first group j which would be blocking us. + * We are guaranteed not to move S backward because + * otherwise our group i would still be blocked. + */ +static void qfq_update_start(struct qfq_sched *q, struct qfq_aggregate *agg) +{ + unsigned long mask; + u64 limit, roundedF; + int slot_shift = agg->grp->slot_shift; + + roundedF = qfq_round_down(agg->F, slot_shift); + limit = qfq_round_down(q->V, slot_shift) + (1ULL << slot_shift); + + if (!qfq_gt(agg->F, q->V) || qfq_gt(roundedF, limit)) { + /* timestamp was stale */ + mask = mask_from(q->bitmaps[ER], agg->grp->index); + if (mask) { + struct qfq_group *next = qfq_ffs(q, mask); + if (qfq_gt(roundedF, next->F)) { + if (qfq_gt(limit, next->F)) + agg->S = next->F; + else /* preserve timestamp correctness */ + agg->S = limit; + return; + } + } + agg->S = q->V; + } else /* timestamp is not stale */ + agg->S = agg->F; +} + +/* Update the timestamps of agg before scheduling/rescheduling it for + * service. In particular, assign to agg->F its maximum possible + * value, i.e., the virtual finish time with which the aggregate + * should be labeled if it used all its budget once in service. 
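+ * Concretely: once agg->S has been fixed (by qfq_update_start(), or by
+ * reusing the old finish time on a requeue), the new finish time is
+ * simply F = S + budgetmax * inv_w, i.e. the maximum service the
+ * aggregate may receive scaled by the inverse of its weight, which is
+ * what the assignment below computes.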
+ */ +static inline void +qfq_update_agg_ts(struct qfq_sched *q, + struct qfq_aggregate *agg, enum update_reason reason) +{ + if (reason != requeue) + qfq_update_start(q, agg); + else /* just charge agg for the service received */ + agg->S = agg->F; + + agg->F = agg->S + (u64)agg->budgetmax * agg->inv_w; +} + +static void qfq_schedule_agg(struct qfq_sched *q, struct qfq_aggregate *agg); + +static struct sk_buff *qfq_dequeue(struct Qdisc *sch) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_aggregate *in_serv_agg = q->in_serv_agg; + struct qfq_class *cl; + struct sk_buff *skb = NULL; + /* next-packet len, 0 means no more active classes in in-service agg */ + unsigned int len = 0; + + if (in_serv_agg == NULL) + return NULL; + + if (!list_empty(&in_serv_agg->active)) + skb = qfq_peek_skb(in_serv_agg, &cl, &len); + + /* + * If there are no active classes in the in-service aggregate, + * or if the aggregate has not enough budget to serve its next + * class, then choose the next aggregate to serve. + */ + if (len == 0 || in_serv_agg->budget < len) { + charge_actual_service(in_serv_agg); + + /* recharge the budget of the aggregate */ + in_serv_agg->initial_budget = in_serv_agg->budget = + in_serv_agg->budgetmax; + + if (!list_empty(&in_serv_agg->active)) { + /* + * Still active: reschedule for + * service. Possible optimization: if no other + * aggregate is active, then there is no point + * in rescheduling this aggregate, and we can + * just keep it as the in-service one. This + * should be however a corner case, and to + * handle it, we would need to maintain an + * extra num_active_aggs field. + */ + qfq_update_agg_ts(q, in_serv_agg, requeue); + qfq_schedule_agg(q, in_serv_agg); + } else if (sch->q.qlen == 0) { /* no aggregate to serve */ + q->in_serv_agg = NULL; + return NULL; + } + + /* + * If we get here, there are other aggregates queued: + * choose the new aggregate to serve. + */ + in_serv_agg = q->in_serv_agg = qfq_choose_next_agg(q); + skb = qfq_peek_skb(in_serv_agg, &cl, &len); + } + if (!skb) + return NULL; + + sch->q.qlen--; + + skb = agg_dequeue(in_serv_agg, cl, len); + + if (!skb) { + sch->q.qlen++; + return NULL; + } + + qdisc_qstats_backlog_dec(sch, skb); + qdisc_bstats_update(sch, skb); + + /* If lmax is lowered, through qfq_change_class, for a class + * owning pending packets with larger size than the new value + * of lmax, then the following condition may hold. 
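+ * In that case len may exceed the budget left in the aggregate, so the
+ * budget is clamped to 0 rather than decremented: letting the unsigned
+ * counter wrap around would look like an almost unlimited budget.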
+ */ + if (unlikely(in_serv_agg->budget < len)) + in_serv_agg->budget = 0; + else + in_serv_agg->budget -= len; + + q->V += (u64)len * q->iwsum; + pr_debug("qfq dequeue: len %u F %lld now %lld\n", + len, (unsigned long long) in_serv_agg->F, + (unsigned long long) q->V); + + return skb; +} + +static struct qfq_aggregate *qfq_choose_next_agg(struct qfq_sched *q) +{ + struct qfq_group *grp; + struct qfq_aggregate *agg, *new_front_agg; + u64 old_F; + + qfq_update_eligible(q); + q->oldV = q->V; + + if (!q->bitmaps[ER]) + return NULL; + + grp = qfq_ffs(q, q->bitmaps[ER]); + old_F = grp->F; + + agg = qfq_slot_head(grp); + + /* agg starts to be served, remove it from schedule */ + qfq_front_slot_remove(grp); + + new_front_agg = qfq_slot_scan(grp); + + if (new_front_agg == NULL) /* group is now inactive, remove from ER */ + __clear_bit(grp->index, &q->bitmaps[ER]); + else { + u64 roundedS = qfq_round_down(new_front_agg->S, + grp->slot_shift); + unsigned int s; + + if (grp->S == roundedS) + return agg; + grp->S = roundedS; + grp->F = roundedS + (2ULL << grp->slot_shift); + __clear_bit(grp->index, &q->bitmaps[ER]); + s = qfq_calc_state(q, grp); + __set_bit(grp->index, &q->bitmaps[s]); + } + + qfq_unblock_groups(q, grp->index, old_F); + + return agg; +} + +static int qfq_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + unsigned int len = qdisc_pkt_len(skb), gso_segs; + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl; + struct qfq_aggregate *agg; + int err = 0; + bool first; + + cl = qfq_classify(skb, sch, &err); + if (cl == NULL) { + if (err & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return err; + } + pr_debug("qfq_enqueue: cl = %x\n", cl->common.classid); + + if (unlikely(cl->agg->lmax < len)) { + pr_debug("qfq: increasing maxpkt from %u to %u for class %u", + cl->agg->lmax, len, cl->common.classid); + err = qfq_change_agg(sch, cl, cl->agg->class_weight, len); + if (err) { + cl->qstats.drops++; + return qdisc_drop(skb, sch, to_free); + } + } + + gso_segs = skb_is_gso(skb) ? skb_shinfo(skb)->gso_segs : 1; + first = !cl->qdisc->q.qlen; + err = qdisc_enqueue(skb, cl->qdisc, to_free); + if (unlikely(err != NET_XMIT_SUCCESS)) { + pr_debug("qfq_enqueue: enqueue failed %d\n", err); + if (net_xmit_drop_count(err)) { + cl->qstats.drops++; + qdisc_qstats_drop(sch); + } + return err; + } + + _bstats_update(&cl->bstats, len, gso_segs); + sch->qstats.backlog += len; + ++sch->q.qlen; + + agg = cl->agg; + /* if the queue was not empty, then done here */ + if (!first) { + if (unlikely(skb == cl->qdisc->ops->peek(cl->qdisc)) && + list_first_entry(&agg->active, struct qfq_class, alist) + == cl && cl->deficit < len) + list_move_tail(&cl->alist, &agg->active); + + return err; + } + + /* schedule class for service within the aggregate */ + cl->deficit = agg->lmax; + list_add_tail(&cl->alist, &agg->active); + + if (list_first_entry(&agg->active, struct qfq_class, alist) != cl || + q->in_serv_agg == agg) + return err; /* non-empty or in service, nothing else to do */ + + qfq_activate_agg(q, agg, enqueue); + + return err; +} + +/* + * Schedule aggregate according to its timestamps. + */ +static void qfq_schedule_agg(struct qfq_sched *q, struct qfq_aggregate *agg) +{ + struct qfq_group *grp = agg->grp; + u64 roundedS; + int s; + + roundedS = qfq_round_down(agg->S, grp->slot_shift); + + /* + * Insert agg in the correct bucket. + * If agg->S >= grp->S we don't need to adjust the + * bucket list and simply go to the insertion phase. 
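+ * (In that case the group's slot window already starts early enough to
+ * hold the aggregate's rounded start time, so no rotation is needed.)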
+ * Otherwise grp->S is decreasing, we must make room + * in the bucket list, and also recompute the group state. + * Finally, if there were no flows in this group and nobody + * was in ER make sure to adjust V. + */ + if (grp->full_slots) { + if (!qfq_gt(grp->S, agg->S)) + goto skip_update; + + /* create a slot for this agg->S */ + qfq_slot_rotate(grp, roundedS); + /* group was surely ineligible, remove */ + __clear_bit(grp->index, &q->bitmaps[IR]); + __clear_bit(grp->index, &q->bitmaps[IB]); + } else if (!q->bitmaps[ER] && qfq_gt(roundedS, q->V) && + q->in_serv_agg == NULL) + q->V = roundedS; + + grp->S = roundedS; + grp->F = roundedS + (2ULL << grp->slot_shift); + s = qfq_calc_state(q, grp); + __set_bit(grp->index, &q->bitmaps[s]); + + pr_debug("qfq enqueue: new state %d %#lx S %lld F %lld V %lld\n", + s, q->bitmaps[s], + (unsigned long long) agg->S, + (unsigned long long) agg->F, + (unsigned long long) q->V); + +skip_update: + qfq_slot_insert(grp, agg, roundedS); +} + + +/* Update agg ts and schedule agg for service */ +static void qfq_activate_agg(struct qfq_sched *q, struct qfq_aggregate *agg, + enum update_reason reason) +{ + agg->initial_budget = agg->budget = agg->budgetmax; /* recharge budg. */ + + qfq_update_agg_ts(q, agg, reason); + if (q->in_serv_agg == NULL) { /* no aggr. in service or scheduled */ + q->in_serv_agg = agg; /* start serving this aggregate */ + /* update V: to be in service, agg must be eligible */ + q->oldV = q->V = agg->S; + } else if (agg != q->in_serv_agg) + qfq_schedule_agg(q, agg); +} + +static void qfq_slot_remove(struct qfq_sched *q, struct qfq_group *grp, + struct qfq_aggregate *agg) +{ + unsigned int i, offset; + u64 roundedS; + + roundedS = qfq_round_down(agg->S, grp->slot_shift); + offset = (roundedS - grp->S) >> grp->slot_shift; + + i = (grp->front + offset) % QFQ_MAX_SLOTS; + + hlist_del(&agg->next); + if (hlist_empty(&grp->slots[i])) + __clear_bit(offset, &grp->full_slots); +} + +/* + * Called to forcibly deschedule an aggregate. If the aggregate is + * not in the front bucket, or if the latter has other aggregates in + * the front bucket, we can simply remove the aggregate with no other + * side effects. + * Otherwise we must propagate the event up. 
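+ * "Propagating up" is what the two branches below do: if the group has
+ * no occupied slot left it is cleared from every state bitmap (possibly
+ * un-blocking other groups), while if only its front bucket emptied the
+ * remaining slots are rescanned and the group's S, F and state are
+ * recomputed.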
+ */ +static void qfq_deactivate_agg(struct qfq_sched *q, struct qfq_aggregate *agg) +{ + struct qfq_group *grp = agg->grp; + unsigned long mask; + u64 roundedS; + int s; + + if (agg == q->in_serv_agg) { + charge_actual_service(agg); + q->in_serv_agg = qfq_choose_next_agg(q); + return; + } + + agg->F = agg->S; + qfq_slot_remove(q, grp, agg); + + if (!grp->full_slots) { + __clear_bit(grp->index, &q->bitmaps[IR]); + __clear_bit(grp->index, &q->bitmaps[EB]); + __clear_bit(grp->index, &q->bitmaps[IB]); + + if (test_bit(grp->index, &q->bitmaps[ER]) && + !(q->bitmaps[ER] & ~((1UL << grp->index) - 1))) { + mask = q->bitmaps[ER] & ((1UL << grp->index) - 1); + if (mask) + mask = ~((1UL << __fls(mask)) - 1); + else + mask = ~0UL; + qfq_move_groups(q, mask, EB, ER); + qfq_move_groups(q, mask, IB, IR); + } + __clear_bit(grp->index, &q->bitmaps[ER]); + } else if (hlist_empty(&grp->slots[grp->front])) { + agg = qfq_slot_scan(grp); + roundedS = qfq_round_down(agg->S, grp->slot_shift); + if (grp->S != roundedS) { + __clear_bit(grp->index, &q->bitmaps[ER]); + __clear_bit(grp->index, &q->bitmaps[IR]); + __clear_bit(grp->index, &q->bitmaps[EB]); + __clear_bit(grp->index, &q->bitmaps[IB]); + grp->S = roundedS; + grp->F = roundedS + (2ULL << grp->slot_shift); + s = qfq_calc_state(q, grp); + __set_bit(grp->index, &q->bitmaps[s]); + } + } +} + +static void qfq_qlen_notify(struct Qdisc *sch, unsigned long arg) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl = (struct qfq_class *)arg; + + qfq_deactivate_class(q, cl); +} + +static int qfq_init_qdisc(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_group *grp; + int i, j, err; + u32 max_cl_shift, maxbudg_shift, max_classes; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + err = qdisc_class_hash_init(&q->clhash); + if (err < 0) + return err; + + max_classes = min_t(u64, (u64)qdisc_dev(sch)->tx_queue_len + 1, + QFQ_MAX_AGG_CLASSES); + /* max_cl_shift = floor(log_2(max_classes)) */ + max_cl_shift = __fls(max_classes); + q->max_agg_classes = 1<min_slot_shift = FRAC_BITS + maxbudg_shift - QFQ_MAX_INDEX; + + for (i = 0; i <= QFQ_MAX_INDEX; i++) { + grp = &q->groups[i]; + grp->index = i; + grp->slot_shift = q->min_slot_shift + i; + for (j = 0; j < QFQ_MAX_SLOTS; j++) + INIT_HLIST_HEAD(&grp->slots[j]); + } + + INIT_HLIST_HEAD(&q->nonfull_aggs); + + return 0; +} + +static void qfq_reset_qdisc(struct Qdisc *sch) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl; + unsigned int i; + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry(cl, &q->clhash.hash[i], common.hnode) { + if (cl->qdisc->q.qlen > 0) + qfq_deactivate_class(q, cl); + + qdisc_reset(cl->qdisc); + } + } +} + +static void qfq_destroy_qdisc(struct Qdisc *sch) +{ + struct qfq_sched *q = qdisc_priv(sch); + struct qfq_class *cl; + struct hlist_node *next; + unsigned int i; + + tcf_block_put(q->block); + + for (i = 0; i < q->clhash.hashsize; i++) { + hlist_for_each_entry_safe(cl, next, &q->clhash.hash[i], + common.hnode) { + qfq_destroy_class(sch, cl); + } + } + qdisc_class_hash_destroy(&q->clhash); +} + +static const struct Qdisc_class_ops qfq_class_ops = { + .change = qfq_change_class, + .delete = qfq_delete_class, + .find = qfq_search_class, + .tcf_block = qfq_tcf_block, + .bind_tcf = qfq_bind_tcf, + .unbind_tcf = qfq_unbind_tcf, + .graft = qfq_graft_class, + .leaf = qfq_class_leaf, + .qlen_notify = qfq_qlen_notify, + .dump = 
qfq_dump_class, + .dump_stats = qfq_dump_class_stats, + .walk = qfq_walk, +}; + +static struct Qdisc_ops qfq_qdisc_ops __read_mostly = { + .cl_ops = &qfq_class_ops, + .id = "qfq", + .priv_size = sizeof(struct qfq_sched), + .enqueue = qfq_enqueue, + .dequeue = qfq_dequeue, + .peek = qdisc_peek_dequeued, + .init = qfq_init_qdisc, + .reset = qfq_reset_qdisc, + .destroy = qfq_destroy_qdisc, + .owner = THIS_MODULE, +}; + +static int __init qfq_init(void) +{ + return register_qdisc(&qfq_qdisc_ops); +} + +static void __exit qfq_exit(void) +{ + unregister_qdisc(&qfq_qdisc_ops); +} + +module_init(qfq_init); +module_exit(qfq_exit); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_red.c b/net/sched/sch_red.c new file mode 100644 index 000000000..16277b6a0 --- /dev/null +++ b/net/sched/sch_red.c @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_red.c Random Early Detection queue. + * + * Authors: Alexey Kuznetsov, + * + * Changes: + * J Hadi Salim 980914: computation fixes + * Alexey Makarenko 990814: qave on idle link was calculated incorrectly. + * J Hadi Salim 980816: ECN support + */ + +#include +#include +#include +#include +#include +#include +#include +#include + + +/* Parameters, settable by user: + ----------------------------- + + limit - bytes (must be > qth_max + burst) + + Hard limit on queue length, should be chosen >qth_max + to allow packet bursts. This parameter does not + affect the algorithms behaviour and can be chosen + arbitrarily high (well, less than ram size) + Really, this limit will never be reached + if RED works correctly. + */ + +struct red_sched_data { + u32 limit; /* HARD maximal queue length */ + + unsigned char flags; + /* Non-flags in tc_red_qopt.flags. */ + unsigned char userbits; + + struct timer_list adapt_timer; + struct Qdisc *sch; + struct red_parms parms; + struct red_vars vars; + struct red_stats stats; + struct Qdisc *qdisc; + struct tcf_qevent qe_early_drop; + struct tcf_qevent qe_mark; +}; + +#define TC_RED_SUPPORTED_FLAGS (TC_RED_HISTORIC_FLAGS | TC_RED_NODROP) + +static inline int red_use_ecn(struct red_sched_data *q) +{ + return q->flags & TC_RED_ECN; +} + +static inline int red_use_harddrop(struct red_sched_data *q) +{ + return q->flags & TC_RED_HARDDROP; +} + +static int red_use_nodrop(struct red_sched_data *q) +{ + return q->flags & TC_RED_NODROP; +} + +static int red_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct Qdisc *child = q->qdisc; + unsigned int len; + int ret; + + q->vars.qavg = red_calc_qavg(&q->parms, + &q->vars, + child->qstats.backlog); + + if (red_is_idling(&q->vars)) + red_end_of_idle_period(&q->vars); + + switch (red_action(&q->parms, &q->vars, q->vars.qavg)) { + case RED_DONT_MARK: + break; + + case RED_PROB_MARK: + qdisc_qstats_overlimit(sch); + if (!red_use_ecn(q)) { + q->stats.prob_drop++; + goto congestion_drop; + } + + if (INET_ECN_set_ce(skb)) { + q->stats.prob_mark++; + skb = tcf_qevent_handle(&q->qe_mark, sch, skb, to_free, &ret); + if (!skb) + return NET_XMIT_CN | ret; + } else if (!red_use_nodrop(q)) { + q->stats.prob_drop++; + goto congestion_drop; + } + + /* Non-ECT packet in ECN nodrop mode: queue it. 
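+ * (Summary of the RED_PROB_MARK branch: with ECN disabled the packet
+ * is probabilistically dropped; an ECT packet is CE-marked and then
+ * passed through the "mark" qevent filters, which may consume it; a
+ * non-ECT packet is dropped unless the nodrop flag turns RED into a
+ * mark/queue-only policy.)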
*/ + break; + + case RED_HARD_MARK: + qdisc_qstats_overlimit(sch); + if (red_use_harddrop(q) || !red_use_ecn(q)) { + q->stats.forced_drop++; + goto congestion_drop; + } + + if (INET_ECN_set_ce(skb)) { + q->stats.forced_mark++; + skb = tcf_qevent_handle(&q->qe_mark, sch, skb, to_free, &ret); + if (!skb) + return NET_XMIT_CN | ret; + } else if (!red_use_nodrop(q)) { + q->stats.forced_drop++; + goto congestion_drop; + } + + /* Non-ECT packet in ECN nodrop mode: queue it. */ + break; + } + + len = qdisc_pkt_len(skb); + ret = qdisc_enqueue(skb, child, to_free); + if (likely(ret == NET_XMIT_SUCCESS)) { + sch->qstats.backlog += len; + sch->q.qlen++; + } else if (net_xmit_drop_count(ret)) { + q->stats.pdrop++; + qdisc_qstats_drop(sch); + } + return ret; + +congestion_drop: + skb = tcf_qevent_handle(&q->qe_early_drop, sch, skb, to_free, &ret); + if (!skb) + return NET_XMIT_CN | ret; + + qdisc_drop(skb, sch, to_free); + return NET_XMIT_CN; +} + +static struct sk_buff *red_dequeue(struct Qdisc *sch) +{ + struct sk_buff *skb; + struct red_sched_data *q = qdisc_priv(sch); + struct Qdisc *child = q->qdisc; + + skb = child->dequeue(child); + if (skb) { + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + } else { + if (!red_is_idling(&q->vars)) + red_start_of_idle_period(&q->vars); + } + return skb; +} + +static struct sk_buff *red_peek(struct Qdisc *sch) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct Qdisc *child = q->qdisc; + + return child->ops->peek(child); +} + +static void red_reset(struct Qdisc *sch) +{ + struct red_sched_data *q = qdisc_priv(sch); + + qdisc_reset(q->qdisc); + red_restart(&q->vars); +} + +static int red_offload(struct Qdisc *sch, bool enable) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_red_qopt_offload opt = { + .handle = sch->handle, + .parent = sch->parent, + }; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return -EOPNOTSUPP; + + if (enable) { + opt.command = TC_RED_REPLACE; + opt.set.min = q->parms.qth_min >> q->parms.Wlog; + opt.set.max = q->parms.qth_max >> q->parms.Wlog; + opt.set.probability = q->parms.max_P; + opt.set.limit = q->limit; + opt.set.is_ecn = red_use_ecn(q); + opt.set.is_harddrop = red_use_harddrop(q); + opt.set.is_nodrop = red_use_nodrop(q); + opt.set.qstats = &sch->qstats; + } else { + opt.command = TC_RED_DESTROY; + } + + return dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_RED, &opt); +} + +static void red_destroy(struct Qdisc *sch) +{ + struct red_sched_data *q = qdisc_priv(sch); + + tcf_qevent_destroy(&q->qe_mark, sch); + tcf_qevent_destroy(&q->qe_early_drop, sch); + del_timer_sync(&q->adapt_timer); + red_offload(sch, false); + qdisc_put(q->qdisc); +} + +static const struct nla_policy red_policy[TCA_RED_MAX + 1] = { + [TCA_RED_UNSPEC] = { .strict_start_type = TCA_RED_FLAGS }, + [TCA_RED_PARMS] = { .len = sizeof(struct tc_red_qopt) }, + [TCA_RED_STAB] = { .len = RED_STAB_SIZE }, + [TCA_RED_MAX_P] = { .type = NLA_U32 }, + [TCA_RED_FLAGS] = NLA_POLICY_BITFIELD32(TC_RED_SUPPORTED_FLAGS), + [TCA_RED_EARLY_DROP_BLOCK] = { .type = NLA_U32 }, + [TCA_RED_MARK_BLOCK] = { .type = NLA_U32 }, +}; + +static int __red_change(struct Qdisc *sch, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct Qdisc *old_child = NULL, *child = NULL; + struct red_sched_data *q = qdisc_priv(sch); + struct nla_bitfield32 flags_bf; + struct tc_red_qopt *ctl; + unsigned char userbits; + unsigned char flags; + int err; + u32 max_P; + u8 *stab; + + 
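+ /* TCA_RED_PARMS and TCA_RED_STAB are mandatory for this qdisc;
+ * TCA_RED_MAX_P is optional and defaults to 0 below.
+ */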
if (tb[TCA_RED_PARMS] == NULL || + tb[TCA_RED_STAB] == NULL) + return -EINVAL; + + max_P = tb[TCA_RED_MAX_P] ? nla_get_u32(tb[TCA_RED_MAX_P]) : 0; + + ctl = nla_data(tb[TCA_RED_PARMS]); + stab = nla_data(tb[TCA_RED_STAB]); + if (!red_check_params(ctl->qth_min, ctl->qth_max, ctl->Wlog, + ctl->Scell_log, stab)) + return -EINVAL; + + err = red_get_flags(ctl->flags, TC_RED_HISTORIC_FLAGS, + tb[TCA_RED_FLAGS], TC_RED_SUPPORTED_FLAGS, + &flags_bf, &userbits, extack); + if (err) + return err; + + if (ctl->limit > 0) { + child = fifo_create_dflt(sch, &bfifo_qdisc_ops, ctl->limit, + extack); + if (IS_ERR(child)) + return PTR_ERR(child); + + /* child is fifo, no need to check for noop_qdisc */ + qdisc_hash_add(child, true); + } + + sch_tree_lock(sch); + + flags = (q->flags & ~flags_bf.selector) | flags_bf.value; + err = red_validate_flags(flags, extack); + if (err) + goto unlock_out; + + q->flags = flags; + q->userbits = userbits; + q->limit = ctl->limit; + if (child) { + qdisc_tree_flush_backlog(q->qdisc); + old_child = q->qdisc; + q->qdisc = child; + } + + red_set_parms(&q->parms, + ctl->qth_min, ctl->qth_max, ctl->Wlog, + ctl->Plog, ctl->Scell_log, + stab, + max_P); + red_set_vars(&q->vars); + + del_timer(&q->adapt_timer); + if (ctl->flags & TC_RED_ADAPTATIVE) + mod_timer(&q->adapt_timer, jiffies + HZ/2); + + if (!q->qdisc->q.qlen) + red_start_of_idle_period(&q->vars); + + sch_tree_unlock(sch); + + red_offload(sch, true); + + if (old_child) + qdisc_put(old_child); + return 0; + +unlock_out: + sch_tree_unlock(sch); + if (child) + qdisc_put(child); + return err; +} + +static inline void red_adaptative_timer(struct timer_list *t) +{ + struct red_sched_data *q = from_timer(q, t, adapt_timer); + struct Qdisc *sch = q->sch; + spinlock_t *root_lock; + + rcu_read_lock(); + root_lock = qdisc_lock(qdisc_root_sleeping(sch)); + spin_lock(root_lock); + red_adaptative_algo(&q->parms, &q->vars); + mod_timer(&q->adapt_timer, jiffies + HZ/2); + spin_unlock(root_lock); + rcu_read_unlock(); +} + +static int red_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_RED_MAX + 1]; + int err; + + q->qdisc = &noop_qdisc; + q->sch = sch; + timer_setup(&q->adapt_timer, red_adaptative_timer, 0); + + if (!opt) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, TCA_RED_MAX, opt, red_policy, + extack); + if (err < 0) + return err; + + err = __red_change(sch, tb, extack); + if (err) + return err; + + err = tcf_qevent_init(&q->qe_early_drop, sch, + FLOW_BLOCK_BINDER_TYPE_RED_EARLY_DROP, + tb[TCA_RED_EARLY_DROP_BLOCK], extack); + if (err) + return err; + + return tcf_qevent_init(&q->qe_mark, sch, + FLOW_BLOCK_BINDER_TYPE_RED_MARK, + tb[TCA_RED_MARK_BLOCK], extack); +} + +static int red_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_RED_MAX + 1]; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_RED_MAX, opt, red_policy, + extack); + if (err < 0) + return err; + + err = tcf_qevent_validate_change(&q->qe_early_drop, + tb[TCA_RED_EARLY_DROP_BLOCK], extack); + if (err) + return err; + + err = tcf_qevent_validate_change(&q->qe_mark, + tb[TCA_RED_MARK_BLOCK], extack); + if (err) + return err; + + return __red_change(sch, tb, extack); +} + +static int red_dump_offload_stats(struct Qdisc *sch) +{ + struct tc_red_qopt_offload hw_stats = { + .command = TC_RED_STATS, + .handle = sch->handle, + .parent = sch->parent, + { + 
.stats.bstats = &sch->bstats, + .stats.qstats = &sch->qstats, + }, + }; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_RED, &hw_stats); +} + +static int red_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct nlattr *opts = NULL; + struct tc_red_qopt opt = { + .limit = q->limit, + .flags = (q->flags & TC_RED_HISTORIC_FLAGS) | + q->userbits, + .qth_min = q->parms.qth_min >> q->parms.Wlog, + .qth_max = q->parms.qth_max >> q->parms.Wlog, + .Wlog = q->parms.Wlog, + .Plog = q->parms.Plog, + .Scell_log = q->parms.Scell_log, + }; + int err; + + err = red_dump_offload_stats(sch); + if (err) + goto nla_put_failure; + + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + if (nla_put(skb, TCA_RED_PARMS, sizeof(opt), &opt) || + nla_put_u32(skb, TCA_RED_MAX_P, q->parms.max_P) || + nla_put_bitfield32(skb, TCA_RED_FLAGS, + q->flags, TC_RED_SUPPORTED_FLAGS) || + tcf_qevent_dump(skb, TCA_RED_MARK_BLOCK, &q->qe_mark) || + tcf_qevent_dump(skb, TCA_RED_EARLY_DROP_BLOCK, &q->qe_early_drop)) + goto nla_put_failure; + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static int red_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct red_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_red_xstats st = {0}; + + if (sch->flags & TCQ_F_OFFLOADED) { + struct tc_red_qopt_offload hw_stats_request = { + .command = TC_RED_XSTATS, + .handle = sch->handle, + .parent = sch->parent, + { + .xstats = &q->stats, + }, + }; + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_RED, + &hw_stats_request); + } + st.early = q->stats.prob_drop + q->stats.forced_drop; + st.pdrop = q->stats.pdrop; + st.marked = q->stats.prob_mark + q->stats.forced_mark; + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static int red_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct red_sched_data *q = qdisc_priv(sch); + + tcm->tcm_handle |= TC_H_MIN(1); + tcm->tcm_info = q->qdisc->handle; + return 0; +} + +static void red_graft_offload(struct Qdisc *sch, + struct Qdisc *new, struct Qdisc *old, + struct netlink_ext_ack *extack) +{ + struct tc_red_qopt_offload graft_offload = { + .handle = sch->handle, + .parent = sch->parent, + .child_handle = new->handle, + .command = TC_RED_GRAFT, + }; + + qdisc_offload_graft_helper(qdisc_dev(sch), sch, new, old, + TC_SETUP_QDISC_RED, &graft_offload, extack); +} + +static int red_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct red_sched_data *q = qdisc_priv(sch); + + if (new == NULL) + new = &noop_qdisc; + + *old = qdisc_replace(sch, new, &q->qdisc); + + red_graft_offload(sch, new, *old, extack); + return 0; +} + +static struct Qdisc *red_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct red_sched_data *q = qdisc_priv(sch); + return q->qdisc; +} + +static unsigned long red_find(struct Qdisc *sch, u32 classid) +{ + return 1; +} + +static void red_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + if (!walker->stop) { + tc_qdisc_stats_dump(sch, 1, walker); + } +} + +static const struct Qdisc_class_ops red_class_ops = { + .graft = red_graft, + .leaf = red_leaf, + .find = red_find, + .walk = red_walk, + .dump = red_dump_class, +}; + +static struct Qdisc_ops red_qdisc_ops __read_mostly = { + .id = "red", + .priv_size = sizeof(struct red_sched_data), + .cl_ops = 
&red_class_ops, + .enqueue = red_enqueue, + .dequeue = red_dequeue, + .peek = red_peek, + .init = red_init, + .reset = red_reset, + .destroy = red_destroy, + .change = red_change, + .dump = red_dump, + .dump_stats = red_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init red_module_init(void) +{ + return register_qdisc(&red_qdisc_ops); +} + +static void __exit red_module_exit(void) +{ + unregister_qdisc(&red_qdisc_ops); +} + +module_init(red_module_init) +module_exit(red_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_sfb.c b/net/sched/sch_sfb.c new file mode 100644 index 000000000..1871a1c02 --- /dev/null +++ b/net/sched/sch_sfb.c @@ -0,0 +1,729 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sched/sch_sfb.c Stochastic Fair Blue + * + * Copyright (c) 2008-2011 Juliusz Chroboczek + * Copyright (c) 2011 Eric Dumazet + * + * W. Feng, D. Kandlur, D. Saha, K. Shin. Blue: + * A New Class of Active Queue Management Algorithms. + * U. Michigan CSE-TR-387-99, April 1999. + * + * http://www.thefengs.com/wuchang/blue/CSE-TR-387-99.pdf + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * SFB uses two B[l][n] : L x N arrays of bins (L levels, N bins per level) + * This implementation uses L = 8 and N = 16 + * This permits us to split one 32bit hash (provided per packet by rxhash or + * external classifier) into 8 subhashes of 4 bits. + */ +#define SFB_BUCKET_SHIFT 4 +#define SFB_NUMBUCKETS (1 << SFB_BUCKET_SHIFT) /* N bins per Level */ +#define SFB_BUCKET_MASK (SFB_NUMBUCKETS - 1) +#define SFB_LEVELS (32 / SFB_BUCKET_SHIFT) /* L */ + +/* SFB algo uses a virtual queue, named "bin" */ +struct sfb_bucket { + u16 qlen; /* length of virtual queue */ + u16 p_mark; /* marking probability */ +}; + +/* We use a double buffering right before hash change + * (Section 4.4 of SFB reference : moving hash functions) + */ +struct sfb_bins { + siphash_key_t perturbation; /* siphash key */ + struct sfb_bucket bins[SFB_LEVELS][SFB_NUMBUCKETS]; +}; + +struct sfb_sched_data { + struct Qdisc *qdisc; + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + unsigned long rehash_interval; + unsigned long warmup_time; /* double buffering warmup time in jiffies */ + u32 max; + u32 bin_size; /* maximum queue length per bin */ + u32 increment; /* d1 */ + u32 decrement; /* d2 */ + u32 limit; /* HARD maximal queue length */ + u32 penalty_rate; + u32 penalty_burst; + u32 tokens_avail; + unsigned long rehash_time; + unsigned long token_time; + + u8 slot; /* current active bins (0 or 1) */ + bool double_buffering; + struct sfb_bins bins[2]; + + struct { + u32 earlydrop; + u32 penaltydrop; + u32 bucketdrop; + u32 queuedrop; + u32 childdrop; /* drops in child qdisc */ + u32 marked; /* ECN mark */ + } stats; +}; + +/* + * Each queued skb might be hashed on one or two bins + * We store in skb_cb the two hash values. + * (A zero value means double buffering was not used) + */ +struct sfb_skb_cb { + u32 hashes[2]; +}; + +static inline struct sfb_skb_cb *sfb_skb_cb(const struct sk_buff *skb) +{ + qdisc_cb_private_validate(skb, sizeof(struct sfb_skb_cb)); + return (struct sfb_skb_cb *)qdisc_skb_cb(skb)->data; +} + +/* + * If using 'internal' SFB flow classifier, hash comes from skb rxhash + * If using external classifier, hash comes from the classid. 
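+ * Either way, the 32-bit value stored in sfb_skb_cb is consumed four
+ * bits at a time: level i of the active slot uses bits [4*i, 4*i+3] as
+ * its bucket index, so the SFB_LEVELS x SFB_NUMBUCKETS (8 x 16) bins
+ * exactly cover one hash word.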
+ */ +static u32 sfb_hash(const struct sk_buff *skb, u32 slot) +{ + return sfb_skb_cb(skb)->hashes[slot]; +} + +/* Probabilities are coded as Q0.16 fixed-point values, + * with 0xFFFF representing 65535/65536 (almost 1.0) + * Addition and subtraction are saturating in [0, 65535] + */ +static u32 prob_plus(u32 p1, u32 p2) +{ + u32 res = p1 + p2; + + return min_t(u32, res, SFB_MAX_PROB); +} + +static u32 prob_minus(u32 p1, u32 p2) +{ + return p1 > p2 ? p1 - p2 : 0; +} + +static void increment_one_qlen(u32 sfbhash, u32 slot, struct sfb_sched_data *q) +{ + int i; + struct sfb_bucket *b = &q->bins[slot].bins[0][0]; + + for (i = 0; i < SFB_LEVELS; i++) { + u32 hash = sfbhash & SFB_BUCKET_MASK; + + sfbhash >>= SFB_BUCKET_SHIFT; + if (b[hash].qlen < 0xFFFF) + b[hash].qlen++; + b += SFB_NUMBUCKETS; /* next level */ + } +} + +static void increment_qlen(const struct sfb_skb_cb *cb, struct sfb_sched_data *q) +{ + u32 sfbhash; + + sfbhash = cb->hashes[0]; + if (sfbhash) + increment_one_qlen(sfbhash, 0, q); + + sfbhash = cb->hashes[1]; + if (sfbhash) + increment_one_qlen(sfbhash, 1, q); +} + +static void decrement_one_qlen(u32 sfbhash, u32 slot, + struct sfb_sched_data *q) +{ + int i; + struct sfb_bucket *b = &q->bins[slot].bins[0][0]; + + for (i = 0; i < SFB_LEVELS; i++) { + u32 hash = sfbhash & SFB_BUCKET_MASK; + + sfbhash >>= SFB_BUCKET_SHIFT; + if (b[hash].qlen > 0) + b[hash].qlen--; + b += SFB_NUMBUCKETS; /* next level */ + } +} + +static void decrement_qlen(const struct sk_buff *skb, struct sfb_sched_data *q) +{ + u32 sfbhash; + + sfbhash = sfb_hash(skb, 0); + if (sfbhash) + decrement_one_qlen(sfbhash, 0, q); + + sfbhash = sfb_hash(skb, 1); + if (sfbhash) + decrement_one_qlen(sfbhash, 1, q); +} + +static void decrement_prob(struct sfb_bucket *b, struct sfb_sched_data *q) +{ + b->p_mark = prob_minus(b->p_mark, q->decrement); +} + +static void increment_prob(struct sfb_bucket *b, struct sfb_sched_data *q) +{ + b->p_mark = prob_plus(b->p_mark, q->increment); +} + +static void sfb_zero_all_buckets(struct sfb_sched_data *q) +{ + memset(&q->bins, 0, sizeof(q->bins)); +} + +/* + * compute max qlen, max p_mark, and avg p_mark + */ +static u32 sfb_compute_qlen(u32 *prob_r, u32 *avgpm_r, const struct sfb_sched_data *q) +{ + int i; + u32 qlen = 0, prob = 0, totalpm = 0; + const struct sfb_bucket *b = &q->bins[q->slot].bins[0][0]; + + for (i = 0; i < SFB_LEVELS * SFB_NUMBUCKETS; i++) { + if (qlen < b->qlen) + qlen = b->qlen; + totalpm += b->p_mark; + if (prob < b->p_mark) + prob = b->p_mark; + b++; + } + *prob_r = prob; + *avgpm_r = totalpm / (SFB_LEVELS * SFB_NUMBUCKETS); + return qlen; +} + + +static void sfb_init_perturbation(u32 slot, struct sfb_sched_data *q) +{ + get_random_bytes(&q->bins[slot].perturbation, + sizeof(q->bins[slot].perturbation)); +} + +static void sfb_swap_slot(struct sfb_sched_data *q) +{ + sfb_init_perturbation(q->slot, q); + q->slot ^= 1; + q->double_buffering = false; +} + +/* Non elastic flows are allowed to use part of the bandwidth, expressed + * in "penalty_rate" packets per second, with "penalty_burst" burst + */ +static bool sfb_rate_limit(struct sk_buff *skb, struct sfb_sched_data *q) +{ + if (q->penalty_rate == 0 || q->penalty_burst == 0) + return true; + + if (q->tokens_avail < 1) { + unsigned long age = min(10UL * HZ, jiffies - q->token_time); + + q->tokens_avail = (age * q->penalty_rate) / HZ; + if (q->tokens_avail > q->penalty_burst) + q->tokens_avail = q->penalty_burst; + q->token_time = jiffies; + if (q->tokens_avail < 1) + return true; + } + + q->tokens_avail--; + 
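+ /* A token was available and has been consumed: let this packet of
+ * the inelastic flow through rather than penalty-dropping it.
+ */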
return false; +} + +static bool sfb_classify(struct sk_buff *skb, struct tcf_proto *fl, + int *qerr, u32 *salt) +{ + struct tcf_result res; + int result; + + result = tcf_classify(skb, NULL, fl, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return false; + } +#endif + *salt = TC_H_MIN(res.classid); + return true; + } + return false; +} + +static int sfb_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + + struct sfb_sched_data *q = qdisc_priv(sch); + unsigned int len = qdisc_pkt_len(skb); + struct Qdisc *child = q->qdisc; + struct tcf_proto *fl; + struct sfb_skb_cb cb; + int i; + u32 p_min = ~0; + u32 minqlen = ~0; + u32 r, sfbhash; + u32 slot = q->slot; + int ret = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + + if (unlikely(sch->q.qlen >= q->limit)) { + qdisc_qstats_overlimit(sch); + q->stats.queuedrop++; + goto drop; + } + + if (q->rehash_interval > 0) { + unsigned long limit = q->rehash_time + q->rehash_interval; + + if (unlikely(time_after(jiffies, limit))) { + sfb_swap_slot(q); + q->rehash_time = jiffies; + } else if (unlikely(!q->double_buffering && q->warmup_time > 0 && + time_after(jiffies, limit - q->warmup_time))) { + q->double_buffering = true; + } + } + + fl = rcu_dereference_bh(q->filter_list); + if (fl) { + u32 salt; + + /* If using external classifiers, get result and record it. */ + if (!sfb_classify(skb, fl, &ret, &salt)) + goto other_drop; + sfbhash = siphash_1u32(salt, &q->bins[slot].perturbation); + } else { + sfbhash = skb_get_hash_perturb(skb, &q->bins[slot].perturbation); + } + + + if (!sfbhash) + sfbhash = 1; + sfb_skb_cb(skb)->hashes[slot] = sfbhash; + + for (i = 0; i < SFB_LEVELS; i++) { + u32 hash = sfbhash & SFB_BUCKET_MASK; + struct sfb_bucket *b = &q->bins[slot].bins[i][hash]; + + sfbhash >>= SFB_BUCKET_SHIFT; + if (b->qlen == 0) + decrement_prob(b, q); + else if (b->qlen >= q->bin_size) + increment_prob(b, q); + if (minqlen > b->qlen) + minqlen = b->qlen; + if (p_min > b->p_mark) + p_min = b->p_mark; + } + + slot ^= 1; + sfb_skb_cb(skb)->hashes[slot] = 0; + + if (unlikely(minqlen >= q->max)) { + qdisc_qstats_overlimit(sch); + q->stats.bucketdrop++; + goto drop; + } + + if (unlikely(p_min >= SFB_MAX_PROB)) { + /* Inelastic flow */ + if (q->double_buffering) { + sfbhash = skb_get_hash_perturb(skb, + &q->bins[slot].perturbation); + if (!sfbhash) + sfbhash = 1; + sfb_skb_cb(skb)->hashes[slot] = sfbhash; + + for (i = 0; i < SFB_LEVELS; i++) { + u32 hash = sfbhash & SFB_BUCKET_MASK; + struct sfb_bucket *b = &q->bins[slot].bins[i][hash]; + + sfbhash >>= SFB_BUCKET_SHIFT; + if (b->qlen == 0) + decrement_prob(b, q); + else if (b->qlen >= q->bin_size) + increment_prob(b, q); + } + } + if (sfb_rate_limit(skb, q)) { + qdisc_qstats_overlimit(sch); + q->stats.penaltydrop++; + goto drop; + } + goto enqueue; + } + + r = get_random_u16() & SFB_MAX_PROB; + + if (unlikely(r < p_min)) { + if (unlikely(p_min > SFB_MAX_PROB / 2)) { + /* If we're marking that many packets, then either + * this flow is unresponsive, or we're badly congested. + * In either case, we want to start dropping packets. 
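+ * The test below therefore drops with probability 2 * (p_min - 1/2).
+ * For a purely illustrative value p_min = 0.75 (in the Q0.16 coding
+ * used for probabilities) the packet is dropped instead of marked
+ * about half of the time, approaching always as p_min nears 1.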
+ */ + if (r < (p_min - SFB_MAX_PROB / 2) * 2) { + q->stats.earlydrop++; + goto drop; + } + } + if (INET_ECN_set_ce(skb)) { + q->stats.marked++; + } else { + q->stats.earlydrop++; + goto drop; + } + } + +enqueue: + memcpy(&cb, sfb_skb_cb(skb), sizeof(cb)); + ret = qdisc_enqueue(skb, child, to_free); + if (likely(ret == NET_XMIT_SUCCESS)) { + sch->qstats.backlog += len; + sch->q.qlen++; + increment_qlen(&cb, q); + } else if (net_xmit_drop_count(ret)) { + q->stats.childdrop++; + qdisc_qstats_drop(sch); + } + return ret; + +drop: + qdisc_drop(skb, sch, to_free); + return NET_XMIT_CN; +other_drop: + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + kfree_skb(skb); + return ret; +} + +static struct sk_buff *sfb_dequeue(struct Qdisc *sch) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + struct Qdisc *child = q->qdisc; + struct sk_buff *skb; + + skb = child->dequeue(q->qdisc); + + if (skb) { + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + decrement_qlen(skb, q); + } + + return skb; +} + +static struct sk_buff *sfb_peek(struct Qdisc *sch) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + struct Qdisc *child = q->qdisc; + + return child->ops->peek(child); +} + +/* No sfb_drop -- impossible since the child doesn't return the dropped skb. */ + +static void sfb_reset(struct Qdisc *sch) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + + if (likely(q->qdisc)) + qdisc_reset(q->qdisc); + q->slot = 0; + q->double_buffering = false; + sfb_zero_all_buckets(q); + sfb_init_perturbation(0, q); +} + +static void sfb_destroy(struct Qdisc *sch) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + qdisc_put(q->qdisc); +} + +static const struct nla_policy sfb_policy[TCA_SFB_MAX + 1] = { + [TCA_SFB_PARMS] = { .len = sizeof(struct tc_sfb_qopt) }, +}; + +static const struct tc_sfb_qopt sfb_default_ops = { + .rehash_interval = 600 * MSEC_PER_SEC, + .warmup_time = 60 * MSEC_PER_SEC, + .limit = 0, + .max = 25, + .bin_size = 20, + .increment = (SFB_MAX_PROB + 500) / 1000, /* 0.1 % */ + .decrement = (SFB_MAX_PROB + 3000) / 6000, + .penalty_rate = 10, + .penalty_burst = 20, +}; + +static int sfb_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + struct Qdisc *child, *old; + struct nlattr *tb[TCA_SFB_MAX + 1]; + const struct tc_sfb_qopt *ctl = &sfb_default_ops; + u32 limit; + int err; + + if (opt) { + err = nla_parse_nested_deprecated(tb, TCA_SFB_MAX, opt, + sfb_policy, NULL); + if (err < 0) + return -EINVAL; + + if (tb[TCA_SFB_PARMS] == NULL) + return -EINVAL; + + ctl = nla_data(tb[TCA_SFB_PARMS]); + } + + limit = ctl->limit; + if (limit == 0) + limit = qdisc_dev(sch)->tx_queue_len; + + child = fifo_create_dflt(sch, &pfifo_qdisc_ops, limit, extack); + if (IS_ERR(child)) + return PTR_ERR(child); + + if (child != &noop_qdisc) + qdisc_hash_add(child, true); + sch_tree_lock(sch); + + qdisc_purge_queue(q->qdisc); + old = q->qdisc; + q->qdisc = child; + + q->rehash_interval = msecs_to_jiffies(ctl->rehash_interval); + q->warmup_time = msecs_to_jiffies(ctl->warmup_time); + q->rehash_time = jiffies; + q->limit = limit; + q->increment = ctl->increment; + q->decrement = ctl->decrement; + q->max = ctl->max; + q->bin_size = ctl->bin_size; + q->penalty_rate = ctl->penalty_rate; + q->penalty_burst = ctl->penalty_burst; + q->tokens_avail = ctl->penalty_burst; + q->token_time = jiffies; + + q->slot = 0; + q->double_buffering = false; + sfb_zero_all_buckets(q); + 
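+ /* Re-key both hash slots so the freshly zeroed bins start from
+ * perturbation values unrelated to any earlier configuration.
+ */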
sfb_init_perturbation(0, q); + sfb_init_perturbation(1, q); + + sch_tree_unlock(sch); + qdisc_put(old); + + return 0; +} + +static int sfb_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + int err; + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + q->qdisc = &noop_qdisc; + return sfb_change(sch, opt, extack); +} + +static int sfb_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + struct nlattr *opts; + struct tc_sfb_qopt opt = { + .rehash_interval = jiffies_to_msecs(q->rehash_interval), + .warmup_time = jiffies_to_msecs(q->warmup_time), + .limit = q->limit, + .max = q->max, + .bin_size = q->bin_size, + .increment = q->increment, + .decrement = q->decrement, + .penalty_rate = q->penalty_rate, + .penalty_burst = q->penalty_burst, + }; + + sch->qstats.backlog = q->qdisc->qstats.backlog; + opts = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (opts == NULL) + goto nla_put_failure; + if (nla_put(skb, TCA_SFB_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + return nla_nest_end(skb, opts); + +nla_put_failure: + nla_nest_cancel(skb, opts); + return -EMSGSIZE; +} + +static int sfb_dump_stats(struct Qdisc *sch, struct gnet_dump *d) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + struct tc_sfb_xstats st = { + .earlydrop = q->stats.earlydrop, + .penaltydrop = q->stats.penaltydrop, + .bucketdrop = q->stats.bucketdrop, + .queuedrop = q->stats.queuedrop, + .childdrop = q->stats.childdrop, + .marked = q->stats.marked, + }; + + st.maxqlen = sfb_compute_qlen(&st.maxprob, &st.avgprob, q); + + return gnet_stats_copy_app(d, &st, sizeof(st)); +} + +static int sfb_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + return -ENOSYS; +} + +static int sfb_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + + if (new == NULL) + new = &noop_qdisc; + + *old = qdisc_replace(sch, new, &q->qdisc); + return 0; +} + +static struct Qdisc *sfb_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + + return q->qdisc; +} + +static unsigned long sfb_find(struct Qdisc *sch, u32 classid) +{ + return 1; +} + +static void sfb_unbind(struct Qdisc *sch, unsigned long arg) +{ +} + +static int sfb_change_class(struct Qdisc *sch, u32 classid, u32 parentid, + struct nlattr **tca, unsigned long *arg, + struct netlink_ext_ack *extack) +{ + return -ENOSYS; +} + +static int sfb_delete(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + return -ENOSYS; +} + +static void sfb_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + if (!walker->stop) { + tc_qdisc_stats_dump(sch, 1, walker); + } +} + +static struct tcf_block *sfb_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct sfb_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static unsigned long sfb_bind(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return 0; +} + + +static const struct Qdisc_class_ops sfb_class_ops = { + .graft = sfb_graft, + .leaf = sfb_leaf, + .find = sfb_find, + .change = sfb_change_class, + .delete = sfb_delete, + .walk = sfb_walk, + .tcf_block = sfb_tcf_block, + .bind_tcf = sfb_bind, + .unbind_tcf = sfb_unbind, + .dump = sfb_dump_class, +}; + +static struct Qdisc_ops sfb_qdisc_ops 
__read_mostly = { + .id = "sfb", + .priv_size = sizeof(struct sfb_sched_data), + .cl_ops = &sfb_class_ops, + .enqueue = sfb_enqueue, + .dequeue = sfb_dequeue, + .peek = sfb_peek, + .init = sfb_init, + .reset = sfb_reset, + .destroy = sfb_destroy, + .change = sfb_change, + .dump = sfb_dump, + .dump_stats = sfb_dump_stats, + .owner = THIS_MODULE, +}; + +static int __init sfb_module_init(void) +{ + return register_qdisc(&sfb_qdisc_ops); +} + +static void __exit sfb_module_exit(void) +{ + unregister_qdisc(&sfb_qdisc_ops); +} + +module_init(sfb_module_init) +module_exit(sfb_module_exit) + +MODULE_DESCRIPTION("Stochastic Fair Blue queue discipline"); +MODULE_AUTHOR("Juliusz Chroboczek"); +MODULE_AUTHOR("Eric Dumazet"); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_sfq.c b/net/sched/sch_sfq.c new file mode 100644 index 000000000..66dcb1863 --- /dev/null +++ b/net/sched/sch_sfq.c @@ -0,0 +1,939 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_sfq.c Stochastic Fairness Queueing discipline. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +/* Stochastic Fairness Queuing algorithm. + ======================================= + + Source: + Paul E. McKenney "Stochastic Fairness Queuing", + IEEE INFOCOMM'90 Proceedings, San Francisco, 1990. + + Paul E. McKenney "Stochastic Fairness Queuing", + "Interworking: Research and Experience", v.2, 1991, p.113-131. + + + See also: + M. Shreedhar and George Varghese "Efficient Fair + Queuing using Deficit Round Robin", Proc. SIGCOMM 95. + + + This is not the thing that is usually called (W)FQ nowadays. + It does not use any timestamp mechanism, but instead + processes queues in round-robin order. + + ADVANTAGE: + + - It is very cheap. Both CPU and memory requirements are minimal. + + DRAWBACKS: + + - "Stochastic" -> It is not 100% fair. + When hash collisions occur, several flows are considered as one. + + - "Round-robin" -> It introduces larger delays than virtual clock + based schemes, and should not be used for isolating interactive + traffic from non-interactive. It means, that this scheduler + should be used as leaf of CBQ or P3, which put interactive traffic + to higher priority band. + + We still need true WFQ for top level CSZ, but using WFQ + for the best effort traffic is absolutely pointless: + SFQ is superior for this purpose. + + IMPLEMENTATION: + This implementation limits : + - maximal queue length per flow to 127 packets. + - max mtu to 2^18-1; + - max 65408 flows, + - number of hash buckets to 65536. + + It is easy to increase these values, but not in flight. */ + +#define SFQ_MAX_DEPTH 127 /* max number of packets per flow */ +#define SFQ_DEFAULT_FLOWS 128 +#define SFQ_MAX_FLOWS (0x10000 - SFQ_MAX_DEPTH - 1) /* max number of flows */ +#define SFQ_EMPTY_SLOT 0xffff +#define SFQ_DEFAULT_HASH_DIVISOR 1024 + +/* We use 16 bits to store allot, and want to handle packets up to 64K + * Scale allot by 8 (1<<3) so that no overflow occurs. + */ +#define SFQ_ALLOT_SHIFT 3 +#define SFQ_ALLOT_SIZE(X) DIV_ROUND_UP(X, 1 << SFQ_ALLOT_SHIFT) + +/* This type should contain at least SFQ_MAX_DEPTH + 1 + SFQ_MAX_FLOWS values */ +typedef u16 sfq_index; + +/* + * We dont use pointers to save space. + * Small indexes [0 ... SFQ_MAX_FLOWS - 1] are 'pointers' to slots[] array + * while following values [SFQ_MAX_FLOWS ... 
SFQ_MAX_FLOWS + SFQ_MAX_DEPTH] + * are 'pointers' to dep[] array + */ +struct sfq_head { + sfq_index next; + sfq_index prev; +}; + +struct sfq_slot { + struct sk_buff *skblist_next; + struct sk_buff *skblist_prev; + sfq_index qlen; /* number of skbs in skblist */ + sfq_index next; /* next slot in sfq RR chain */ + struct sfq_head dep; /* anchor in dep[] chains */ + unsigned short hash; /* hash value (index in ht[]) */ + short allot; /* credit for this slot */ + + unsigned int backlog; + struct red_vars vars; +}; + +struct sfq_sched_data { +/* frequently used fields */ + int limit; /* limit of total number of packets in this qdisc */ + unsigned int divisor; /* number of slots in hash table */ + u8 headdrop; + u8 maxdepth; /* limit of packets per flow */ + + siphash_key_t perturbation; + u8 cur_depth; /* depth of longest slot */ + u8 flags; + unsigned short scaled_quantum; /* SFQ_ALLOT_SIZE(quantum) */ + struct tcf_proto __rcu *filter_list; + struct tcf_block *block; + sfq_index *ht; /* Hash table ('divisor' slots) */ + struct sfq_slot *slots; /* Flows table ('maxflows' entries) */ + + struct red_parms *red_parms; + struct tc_sfqred_stats stats; + struct sfq_slot *tail; /* current slot in round */ + + struct sfq_head dep[SFQ_MAX_DEPTH + 1]; + /* Linked lists of slots, indexed by depth + * dep[0] : list of unused flows + * dep[1] : list of flows with 1 packet + * dep[X] : list of flows with X packets + */ + + unsigned int maxflows; /* number of flows in flows array */ + int perturb_period; + unsigned int quantum; /* Allotment per round: MUST BE >= MTU */ + struct timer_list perturb_timer; + struct Qdisc *sch; +}; + +/* + * sfq_head are either in a sfq_slot or in dep[] array + */ +static inline struct sfq_head *sfq_dep_head(struct sfq_sched_data *q, sfq_index val) +{ + if (val < SFQ_MAX_FLOWS) + return &q->slots[val].dep; + return &q->dep[val - SFQ_MAX_FLOWS]; +} + +static unsigned int sfq_hash(const struct sfq_sched_data *q, + const struct sk_buff *skb) +{ + return skb_get_hash_perturb(skb, &q->perturbation) & (q->divisor - 1); +} + +static unsigned int sfq_classify(struct sk_buff *skb, struct Qdisc *sch, + int *qerr) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + struct tcf_result res; + struct tcf_proto *fl; + int result; + + if (TC_H_MAJ(skb->priority) == sch->handle && + TC_H_MIN(skb->priority) > 0 && + TC_H_MIN(skb->priority) <= q->divisor) + return TC_H_MIN(skb->priority); + + fl = rcu_dereference_bh(q->filter_list); + if (!fl) + return sfq_hash(q, skb) + 1; + + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_BYPASS; + result = tcf_classify(skb, NULL, fl, &res, false); + if (result >= 0) { +#ifdef CONFIG_NET_CLS_ACT + switch (result) { + case TC_ACT_STOLEN: + case TC_ACT_QUEUED: + case TC_ACT_TRAP: + *qerr = NET_XMIT_SUCCESS | __NET_XMIT_STOLEN; + fallthrough; + case TC_ACT_SHOT: + return 0; + } +#endif + if (TC_H_MIN(res.classid) <= q->divisor) + return TC_H_MIN(res.classid); + } + return 0; +} + +/* + * x : slot number [0 .. 
SFQ_MAX_FLOWS - 1] + */ +static inline void sfq_link(struct sfq_sched_data *q, sfq_index x) +{ + sfq_index p, n; + struct sfq_slot *slot = &q->slots[x]; + int qlen = slot->qlen; + + p = qlen + SFQ_MAX_FLOWS; + n = q->dep[qlen].next; + + slot->dep.next = n; + slot->dep.prev = p; + + q->dep[qlen].next = x; /* sfq_dep_head(q, p)->next = x */ + sfq_dep_head(q, n)->prev = x; +} + +#define sfq_unlink(q, x, n, p) \ + do { \ + n = q->slots[x].dep.next; \ + p = q->slots[x].dep.prev; \ + sfq_dep_head(q, p)->next = n; \ + sfq_dep_head(q, n)->prev = p; \ + } while (0) + + +static inline void sfq_dec(struct sfq_sched_data *q, sfq_index x) +{ + sfq_index p, n; + int d; + + sfq_unlink(q, x, n, p); + + d = q->slots[x].qlen--; + if (n == p && q->cur_depth == d) + q->cur_depth--; + sfq_link(q, x); +} + +static inline void sfq_inc(struct sfq_sched_data *q, sfq_index x) +{ + sfq_index p, n; + int d; + + sfq_unlink(q, x, n, p); + + d = ++q->slots[x].qlen; + if (q->cur_depth < d) + q->cur_depth = d; + sfq_link(q, x); +} + +/* helper functions : might be changed when/if skb use a standard list_head */ + +/* remove one skb from tail of slot queue */ +static inline struct sk_buff *slot_dequeue_tail(struct sfq_slot *slot) +{ + struct sk_buff *skb = slot->skblist_prev; + + slot->skblist_prev = skb->prev; + skb->prev->next = (struct sk_buff *)slot; + skb->next = skb->prev = NULL; + return skb; +} + +/* remove one skb from head of slot queue */ +static inline struct sk_buff *slot_dequeue_head(struct sfq_slot *slot) +{ + struct sk_buff *skb = slot->skblist_next; + + slot->skblist_next = skb->next; + skb->next->prev = (struct sk_buff *)slot; + skb->next = skb->prev = NULL; + return skb; +} + +static inline void slot_queue_init(struct sfq_slot *slot) +{ + memset(slot, 0, sizeof(*slot)); + slot->skblist_prev = slot->skblist_next = (struct sk_buff *)slot; +} + +/* add skb to slot queue (tail add) */ +static inline void slot_queue_add(struct sfq_slot *slot, struct sk_buff *skb) +{ + skb->prev = slot->skblist_prev; + skb->next = (struct sk_buff *)slot; + slot->skblist_prev->next = skb; + slot->skblist_prev = skb; +} + +static unsigned int sfq_drop(struct Qdisc *sch, struct sk_buff **to_free) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + sfq_index x, d = q->cur_depth; + struct sk_buff *skb; + unsigned int len; + struct sfq_slot *slot; + + /* Queue is full! Find the longest slot and drop tail packet from it */ + if (d > 1) { + x = q->dep[d].next; + slot = &q->slots[x]; +drop: + skb = q->headdrop ? slot_dequeue_head(slot) : slot_dequeue_tail(slot); + len = qdisc_pkt_len(skb); + slot->backlog -= len; + sfq_dec(q, x); + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + qdisc_drop(skb, sch, to_free); + return len; + } + + if (d == 1) { + /* It is difficult to believe, but ALL THE SLOTS HAVE LENGTH 1. 
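+ * In that case every flow is an equally good victim: take the slot
+ * right after the current tail, unlink it from the round-robin ring
+ * and from the hash table, and fall through to the common drop path.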
*/ + x = q->tail->next; + slot = &q->slots[x]; + q->tail->next = slot->next; + q->ht[slot->hash] = SFQ_EMPTY_SLOT; + goto drop; + } + + return 0; +} + +/* Is ECN parameter configured */ +static int sfq_prob_mark(const struct sfq_sched_data *q) +{ + return q->flags & TC_RED_ECN; +} + +/* Should packets over max threshold just be marked */ +static int sfq_hard_mark(const struct sfq_sched_data *q) +{ + return (q->flags & (TC_RED_ECN | TC_RED_HARDDROP)) == TC_RED_ECN; +} + +static int sfq_headdrop(const struct sfq_sched_data *q) +{ + return q->headdrop; +} + +static int +sfq_enqueue(struct sk_buff *skb, struct Qdisc *sch, struct sk_buff **to_free) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + unsigned int hash, dropped; + sfq_index x, qlen; + struct sfq_slot *slot; + int ret; + struct sk_buff *head; + int delta; + + hash = sfq_classify(skb, sch, &ret); + if (hash == 0) { + if (ret & __NET_XMIT_BYPASS) + qdisc_qstats_drop(sch); + __qdisc_drop(skb, to_free); + return ret; + } + hash--; + + x = q->ht[hash]; + slot = &q->slots[x]; + if (x == SFQ_EMPTY_SLOT) { + x = q->dep[0].next; /* get a free slot */ + if (x >= SFQ_MAX_FLOWS) + return qdisc_drop(skb, sch, to_free); + q->ht[hash] = x; + slot = &q->slots[x]; + slot->hash = hash; + slot->backlog = 0; /* should already be 0 anyway... */ + red_set_vars(&slot->vars); + goto enqueue; + } + if (q->red_parms) { + slot->vars.qavg = red_calc_qavg_no_idle_time(q->red_parms, + &slot->vars, + slot->backlog); + switch (red_action(q->red_parms, + &slot->vars, + slot->vars.qavg)) { + case RED_DONT_MARK: + break; + + case RED_PROB_MARK: + qdisc_qstats_overlimit(sch); + if (sfq_prob_mark(q)) { + /* We know we have at least one packet in queue */ + if (sfq_headdrop(q) && + INET_ECN_set_ce(slot->skblist_next)) { + q->stats.prob_mark_head++; + break; + } + if (INET_ECN_set_ce(skb)) { + q->stats.prob_mark++; + break; + } + } + q->stats.prob_drop++; + goto congestion_drop; + + case RED_HARD_MARK: + qdisc_qstats_overlimit(sch); + if (sfq_hard_mark(q)) { + /* We know we have at least one packet in queue */ + if (sfq_headdrop(q) && + INET_ECN_set_ce(slot->skblist_next)) { + q->stats.forced_mark_head++; + break; + } + if (INET_ECN_set_ce(skb)) { + q->stats.forced_mark++; + break; + } + } + q->stats.forced_drop++; + goto congestion_drop; + } + } + + if (slot->qlen >= q->maxdepth) { +congestion_drop: + if (!sfq_headdrop(q)) + return qdisc_drop(skb, sch, to_free); + + /* We know we have at least one packet in queue */ + head = slot_dequeue_head(slot); + delta = qdisc_pkt_len(head) - qdisc_pkt_len(skb); + sch->qstats.backlog -= delta; + slot->backlog -= delta; + qdisc_drop(head, sch, to_free); + + slot_queue_add(slot, skb); + qdisc_tree_reduce_backlog(sch, 0, delta); + return NET_XMIT_CN; + } + +enqueue: + qdisc_qstats_backlog_inc(sch, skb); + slot->backlog += qdisc_pkt_len(skb); + slot_queue_add(slot, skb); + sfq_inc(q, x); + if (slot->qlen == 1) { /* The flow is new */ + if (q->tail == NULL) { /* It is the first flow */ + slot->next = x; + } else { + slot->next = q->tail->next; + q->tail->next = x; + } + /* We put this flow at the end of our flow list. + * This might sound unfair for a new flow to wait after old ones, + * but we could endup servicing new flows only, and freeze old ones. 
+ */ + q->tail = slot; + /* We could use a bigger initial quantum for new flows */ + slot->allot = q->scaled_quantum; + } + if (++sch->q.qlen <= q->limit) + return NET_XMIT_SUCCESS; + + qlen = slot->qlen; + dropped = sfq_drop(sch, to_free); + /* Return Congestion Notification only if we dropped a packet + * from this flow. + */ + if (qlen != slot->qlen) { + qdisc_tree_reduce_backlog(sch, 0, dropped - qdisc_pkt_len(skb)); + return NET_XMIT_CN; + } + + /* As we dropped a packet, better let upper stack know this */ + qdisc_tree_reduce_backlog(sch, 1, dropped); + return NET_XMIT_SUCCESS; +} + +static struct sk_buff * +sfq_dequeue(struct Qdisc *sch) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + sfq_index a, next_a; + struct sfq_slot *slot; + + /* No active slots */ + if (q->tail == NULL) + return NULL; + +next_slot: + a = q->tail->next; + slot = &q->slots[a]; + if (slot->allot <= 0) { + q->tail = slot; + slot->allot += q->scaled_quantum; + goto next_slot; + } + skb = slot_dequeue_head(slot); + sfq_dec(q, a); + qdisc_bstats_update(sch, skb); + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + slot->backlog -= qdisc_pkt_len(skb); + /* Is the slot empty? */ + if (slot->qlen == 0) { + q->ht[slot->hash] = SFQ_EMPTY_SLOT; + next_a = slot->next; + if (a == next_a) { + q->tail = NULL; /* no more active slots */ + return skb; + } + q->tail->next = next_a; + } else { + slot->allot -= SFQ_ALLOT_SIZE(qdisc_pkt_len(skb)); + } + return skb; +} + +static void +sfq_reset(struct Qdisc *sch) +{ + struct sk_buff *skb; + + while ((skb = sfq_dequeue(sch)) != NULL) + rtnl_kfree_skbs(skb, skb); +} + +/* + * When q->perturbation is changed, we rehash all queued skbs + * to avoid OOO (Out Of Order) effects. + * We dont use sfq_dequeue()/sfq_enqueue() because we dont want to change + * counters. 
+ */ +static void sfq_rehash(struct Qdisc *sch) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + int i; + struct sfq_slot *slot; + struct sk_buff_head list; + int dropped = 0; + unsigned int drop_len = 0; + + __skb_queue_head_init(&list); + + for (i = 0; i < q->maxflows; i++) { + slot = &q->slots[i]; + if (!slot->qlen) + continue; + while (slot->qlen) { + skb = slot_dequeue_head(slot); + sfq_dec(q, i); + __skb_queue_tail(&list, skb); + } + slot->backlog = 0; + red_set_vars(&slot->vars); + q->ht[slot->hash] = SFQ_EMPTY_SLOT; + } + q->tail = NULL; + + while ((skb = __skb_dequeue(&list)) != NULL) { + unsigned int hash = sfq_hash(q, skb); + sfq_index x = q->ht[hash]; + + slot = &q->slots[x]; + if (x == SFQ_EMPTY_SLOT) { + x = q->dep[0].next; /* get a free slot */ + if (x >= SFQ_MAX_FLOWS) { +drop: + qdisc_qstats_backlog_dec(sch, skb); + drop_len += qdisc_pkt_len(skb); + kfree_skb(skb); + dropped++; + continue; + } + q->ht[hash] = x; + slot = &q->slots[x]; + slot->hash = hash; + } + if (slot->qlen >= q->maxdepth) + goto drop; + slot_queue_add(slot, skb); + if (q->red_parms) + slot->vars.qavg = red_calc_qavg(q->red_parms, + &slot->vars, + slot->backlog); + slot->backlog += qdisc_pkt_len(skb); + sfq_inc(q, x); + if (slot->qlen == 1) { /* The flow is new */ + if (q->tail == NULL) { /* It is the first flow */ + slot->next = x; + } else { + slot->next = q->tail->next; + q->tail->next = x; + } + q->tail = slot; + slot->allot = q->scaled_quantum; + } + } + sch->q.qlen -= dropped; + qdisc_tree_reduce_backlog(sch, dropped, drop_len); +} + +static void sfq_perturbation(struct timer_list *t) +{ + struct sfq_sched_data *q = from_timer(q, t, perturb_timer); + struct Qdisc *sch = q->sch; + spinlock_t *root_lock; + siphash_key_t nkey; + + get_random_bytes(&nkey, sizeof(nkey)); + rcu_read_lock(); + root_lock = qdisc_lock(qdisc_root_sleeping(sch)); + spin_lock(root_lock); + q->perturbation = nkey; + if (!q->filter_list && q->tail) + sfq_rehash(sch); + spin_unlock(root_lock); + + if (q->perturb_period) + mod_timer(&q->perturb_timer, jiffies + q->perturb_period); + rcu_read_unlock(); +} + +static int sfq_change(struct Qdisc *sch, struct nlattr *opt) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + struct tc_sfq_qopt *ctl = nla_data(opt); + struct tc_sfq_qopt_v1 *ctl_v1 = NULL; + unsigned int qlen, dropped = 0; + struct red_parms *p = NULL; + struct sk_buff *to_free = NULL; + struct sk_buff *tail = NULL; + + if (opt->nla_len < nla_attr_size(sizeof(*ctl))) + return -EINVAL; + if (opt->nla_len >= nla_attr_size(sizeof(*ctl_v1))) + ctl_v1 = nla_data(opt); + if (ctl->divisor && + (!is_power_of_2(ctl->divisor) || ctl->divisor > 65536)) + return -EINVAL; + + /* slot->allot is a short, make sure quantum is not too big. 
*/ + if (ctl->quantum) { + unsigned int scaled = SFQ_ALLOT_SIZE(ctl->quantum); + + if (scaled <= 0 || scaled > SHRT_MAX) + return -EINVAL; + } + + if (ctl_v1 && !red_check_params(ctl_v1->qth_min, ctl_v1->qth_max, + ctl_v1->Wlog, ctl_v1->Scell_log, NULL)) + return -EINVAL; + if (ctl_v1 && ctl_v1->qth_min) { + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (!p) + return -ENOMEM; + } + sch_tree_lock(sch); + if (ctl->quantum) { + q->quantum = ctl->quantum; + q->scaled_quantum = SFQ_ALLOT_SIZE(q->quantum); + } + q->perturb_period = ctl->perturb_period * HZ; + if (ctl->flows) + q->maxflows = min_t(u32, ctl->flows, SFQ_MAX_FLOWS); + if (ctl->divisor) { + q->divisor = ctl->divisor; + q->maxflows = min_t(u32, q->maxflows, q->divisor); + } + if (ctl_v1) { + if (ctl_v1->depth) + q->maxdepth = min_t(u32, ctl_v1->depth, SFQ_MAX_DEPTH); + if (p) { + swap(q->red_parms, p); + red_set_parms(q->red_parms, + ctl_v1->qth_min, ctl_v1->qth_max, + ctl_v1->Wlog, + ctl_v1->Plog, ctl_v1->Scell_log, + NULL, + ctl_v1->max_P); + } + q->flags = ctl_v1->flags; + q->headdrop = ctl_v1->headdrop; + } + if (ctl->limit) { + q->limit = min_t(u32, ctl->limit, q->maxdepth * q->maxflows); + q->maxflows = min_t(u32, q->maxflows, q->limit); + } + + qlen = sch->q.qlen; + while (sch->q.qlen > q->limit) { + dropped += sfq_drop(sch, &to_free); + if (!tail) + tail = to_free; + } + + rtnl_kfree_skbs(to_free, tail); + qdisc_tree_reduce_backlog(sch, qlen - sch->q.qlen, dropped); + + del_timer(&q->perturb_timer); + if (q->perturb_period) { + mod_timer(&q->perturb_timer, jiffies + q->perturb_period); + get_random_bytes(&q->perturbation, sizeof(q->perturbation)); + } + sch_tree_unlock(sch); + kfree(p); + return 0; +} + +static void *sfq_alloc(size_t sz) +{ + return kvmalloc(sz, GFP_KERNEL); +} + +static void sfq_free(void *addr) +{ + kvfree(addr); +} + +static void sfq_destroy(struct Qdisc *sch) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + + tcf_block_put(q->block); + q->perturb_period = 0; + del_timer_sync(&q->perturb_timer); + sfq_free(q->ht); + sfq_free(q->slots); + kfree(q->red_parms); +} + +static int sfq_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + int i; + int err; + + q->sch = sch; + timer_setup(&q->perturb_timer, sfq_perturbation, TIMER_DEFERRABLE); + + err = tcf_block_get(&q->block, &q->filter_list, sch, extack); + if (err) + return err; + + for (i = 0; i < SFQ_MAX_DEPTH + 1; i++) { + q->dep[i].next = i + SFQ_MAX_FLOWS; + q->dep[i].prev = i + SFQ_MAX_FLOWS; + } + + q->limit = SFQ_MAX_DEPTH; + q->maxdepth = SFQ_MAX_DEPTH; + q->cur_depth = 0; + q->tail = NULL; + q->divisor = SFQ_DEFAULT_HASH_DIVISOR; + q->maxflows = SFQ_DEFAULT_FLOWS; + q->quantum = psched_mtu(qdisc_dev(sch)); + q->scaled_quantum = SFQ_ALLOT_SIZE(q->quantum); + q->perturb_period = 0; + get_random_bytes(&q->perturbation, sizeof(q->perturbation)); + + if (opt) { + int err = sfq_change(sch, opt); + if (err) + return err; + } + + q->ht = sfq_alloc(sizeof(q->ht[0]) * q->divisor); + q->slots = sfq_alloc(sizeof(q->slots[0]) * q->maxflows); + if (!q->ht || !q->slots) { + /* Note: sfq_destroy() will be called by our caller */ + return -ENOMEM; + } + + for (i = 0; i < q->divisor; i++) + q->ht[i] = SFQ_EMPTY_SLOT; + + for (i = 0; i < q->maxflows; i++) { + slot_queue_init(&q->slots[i]); + sfq_link(q, i); + } + if (q->limit >= 1) + sch->flags |= TCQ_F_CAN_BYPASS; + else + sch->flags &= ~TCQ_F_CAN_BYPASS; + return 0; +} + +static int sfq_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct 
sfq_sched_data *q = qdisc_priv(sch); + unsigned char *b = skb_tail_pointer(skb); + struct tc_sfq_qopt_v1 opt; + struct red_parms *p = q->red_parms; + + memset(&opt, 0, sizeof(opt)); + opt.v0.quantum = q->quantum; + opt.v0.perturb_period = q->perturb_period / HZ; + opt.v0.limit = q->limit; + opt.v0.divisor = q->divisor; + opt.v0.flows = q->maxflows; + opt.depth = q->maxdepth; + opt.headdrop = q->headdrop; + + if (p) { + opt.qth_min = p->qth_min >> p->Wlog; + opt.qth_max = p->qth_max >> p->Wlog; + opt.Wlog = p->Wlog; + opt.Plog = p->Plog; + opt.Scell_log = p->Scell_log; + opt.max_P = p->max_P; + } + memcpy(&opt.stats, &q->stats, sizeof(opt.stats)); + opt.flags = q->flags; + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + goto nla_put_failure; + + return skb->len; + +nla_put_failure: + nlmsg_trim(skb, b); + return -1; +} + +static struct Qdisc *sfq_leaf(struct Qdisc *sch, unsigned long arg) +{ + return NULL; +} + +static unsigned long sfq_find(struct Qdisc *sch, u32 classid) +{ + return 0; +} + +static unsigned long sfq_bind(struct Qdisc *sch, unsigned long parent, + u32 classid) +{ + return 0; +} + +static void sfq_unbind(struct Qdisc *q, unsigned long cl) +{ +} + +static struct tcf_block *sfq_tcf_block(struct Qdisc *sch, unsigned long cl, + struct netlink_ext_ack *extack) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + + if (cl) + return NULL; + return q->block; +} + +static int sfq_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + tcm->tcm_handle |= TC_H_MIN(cl); + return 0; +} + +static int sfq_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + sfq_index idx = q->ht[cl - 1]; + struct gnet_stats_queue qs = { 0 }; + struct tc_sfq_xstats xstats = { 0 }; + + if (idx != SFQ_EMPTY_SLOT) { + const struct sfq_slot *slot = &q->slots[idx]; + + xstats.allot = slot->allot << SFQ_ALLOT_SHIFT; + qs.qlen = slot->qlen; + qs.backlog = slot->backlog; + } + if (gnet_stats_copy_queue(d, NULL, &qs, qs.qlen) < 0) + return -1; + return gnet_stats_copy_app(d, &xstats, sizeof(xstats)); +} + +static void sfq_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct sfq_sched_data *q = qdisc_priv(sch); + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < q->divisor; i++) { + if (q->ht[i] == SFQ_EMPTY_SLOT) { + arg->count++; + continue; + } + if (!tc_qdisc_stats_dump(sch, i + 1, arg)) + break; + } +} + +static const struct Qdisc_class_ops sfq_class_ops = { + .leaf = sfq_leaf, + .find = sfq_find, + .tcf_block = sfq_tcf_block, + .bind_tcf = sfq_bind, + .unbind_tcf = sfq_unbind, + .dump = sfq_dump_class, + .dump_stats = sfq_dump_class_stats, + .walk = sfq_walk, +}; + +static struct Qdisc_ops sfq_qdisc_ops __read_mostly = { + .cl_ops = &sfq_class_ops, + .id = "sfq", + .priv_size = sizeof(struct sfq_sched_data), + .enqueue = sfq_enqueue, + .dequeue = sfq_dequeue, + .peek = qdisc_peek_dequeued, + .init = sfq_init, + .reset = sfq_reset, + .destroy = sfq_destroy, + .change = NULL, + .dump = sfq_dump, + .owner = THIS_MODULE, +}; + +static int __init sfq_module_init(void) +{ + return register_qdisc(&sfq_qdisc_ops); +} +static void __exit sfq_module_exit(void) +{ + unregister_qdisc(&sfq_qdisc_ops); +} +module_init(sfq_module_init) +module_exit(sfq_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_skbprio.c b/net/sched/sch_skbprio.c new file mode 100644 index 000000000..5df2dacb7 --- /dev/null +++ b/net/sched/sch_skbprio.c @@ -0,0 +1,309 @@ +// 
SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_skbprio.c SKB Priority Queue. + * + * Authors: Nishanth Devarajan, + * Cody Doucette, + * original idea by Michel Machado, Cody Doucette, and Qiaobin Fu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* SKB Priority Queue + * ================================= + * + * Skbprio (SKB Priority Queue) is a queueing discipline that prioritizes + * packets according to their skb->priority field. Under congestion, + * Skbprio drops already-enqueued lower priority packets to make space + * available for higher priority packets; it was conceived as a solution + * for denial-of-service defenses that need to route packets with different + * priorities as a mean to overcome DoS attacks. + */ + +struct skbprio_sched_data { + /* Queue state. */ + struct sk_buff_head qdiscs[SKBPRIO_MAX_PRIORITY]; + struct gnet_stats_queue qstats[SKBPRIO_MAX_PRIORITY]; + u16 highest_prio; + u16 lowest_prio; +}; + +static u16 calc_new_high_prio(const struct skbprio_sched_data *q) +{ + int prio; + + for (prio = q->highest_prio - 1; prio >= q->lowest_prio; prio--) { + if (!skb_queue_empty(&q->qdiscs[prio])) + return prio; + } + + /* SKB queue is empty, return 0 (default highest priority setting). */ + return 0; +} + +static u16 calc_new_low_prio(const struct skbprio_sched_data *q) +{ + int prio; + + for (prio = q->lowest_prio + 1; prio <= q->highest_prio; prio++) { + if (!skb_queue_empty(&q->qdiscs[prio])) + return prio; + } + + /* SKB queue is empty, return SKBPRIO_MAX_PRIORITY - 1 + * (default lowest priority setting). + */ + return SKBPRIO_MAX_PRIORITY - 1; +} + +static int skbprio_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + const unsigned int max_priority = SKBPRIO_MAX_PRIORITY - 1; + struct skbprio_sched_data *q = qdisc_priv(sch); + struct sk_buff_head *qdisc; + struct sk_buff_head *lp_qdisc; + struct sk_buff *to_drop; + u16 prio, lp; + + /* Obtain the priority of @skb. */ + prio = min(skb->priority, max_priority); + + qdisc = &q->qdiscs[prio]; + if (sch->q.qlen < sch->limit) { + __skb_queue_tail(qdisc, skb); + qdisc_qstats_backlog_inc(sch, skb); + q->qstats[prio].backlog += qdisc_pkt_len(skb); + + /* Check to update highest and lowest priorities. */ + if (prio > q->highest_prio) + q->highest_prio = prio; + + if (prio < q->lowest_prio) + q->lowest_prio = prio; + + sch->q.qlen++; + return NET_XMIT_SUCCESS; + } + + /* If this packet has the lowest priority, drop it. */ + lp = q->lowest_prio; + if (prio <= lp) { + q->qstats[prio].drops++; + q->qstats[prio].overlimits++; + return qdisc_drop(skb, sch, to_free); + } + + __skb_queue_tail(qdisc, skb); + qdisc_qstats_backlog_inc(sch, skb); + q->qstats[prio].backlog += qdisc_pkt_len(skb); + + /* Drop the packet at the tail of the lowest priority qdisc. */ + lp_qdisc = &q->qdiscs[lp]; + to_drop = __skb_dequeue_tail(lp_qdisc); + BUG_ON(!to_drop); + qdisc_qstats_backlog_dec(sch, to_drop); + qdisc_drop(to_drop, sch, to_free); + + q->qstats[lp].backlog -= qdisc_pkt_len(to_drop); + q->qstats[lp].drops++; + q->qstats[lp].overlimits++; + + /* Check to update highest and lowest priorities. */ + if (skb_queue_empty(lp_qdisc)) { + if (q->lowest_prio == q->highest_prio) { + /* The incoming packet is the only packet in queue. 
*/ + BUG_ON(sch->q.qlen != 1); + q->lowest_prio = prio; + q->highest_prio = prio; + } else { + q->lowest_prio = calc_new_low_prio(q); + } + } + + if (prio > q->highest_prio) + q->highest_prio = prio; + + return NET_XMIT_CN; +} + +static struct sk_buff *skbprio_dequeue(struct Qdisc *sch) +{ + struct skbprio_sched_data *q = qdisc_priv(sch); + struct sk_buff_head *hpq = &q->qdiscs[q->highest_prio]; + struct sk_buff *skb = __skb_dequeue(hpq); + + if (unlikely(!skb)) + return NULL; + + sch->q.qlen--; + qdisc_qstats_backlog_dec(sch, skb); + qdisc_bstats_update(sch, skb); + + q->qstats[q->highest_prio].backlog -= qdisc_pkt_len(skb); + + /* Update highest priority field. */ + if (skb_queue_empty(hpq)) { + if (q->lowest_prio == q->highest_prio) { + BUG_ON(sch->q.qlen); + q->highest_prio = 0; + q->lowest_prio = SKBPRIO_MAX_PRIORITY - 1; + } else { + q->highest_prio = calc_new_high_prio(q); + } + } + return skb; +} + +static int skbprio_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct tc_skbprio_qopt *ctl = nla_data(opt); + + if (opt->nla_len != nla_attr_size(sizeof(*ctl))) + return -EINVAL; + + sch->limit = ctl->limit; + return 0; +} + +static int skbprio_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct skbprio_sched_data *q = qdisc_priv(sch); + int prio; + + /* Initialise all queues, one for each possible priority. */ + for (prio = 0; prio < SKBPRIO_MAX_PRIORITY; prio++) + __skb_queue_head_init(&q->qdiscs[prio]); + + memset(&q->qstats, 0, sizeof(q->qstats)); + q->highest_prio = 0; + q->lowest_prio = SKBPRIO_MAX_PRIORITY - 1; + sch->limit = 64; + if (!opt) + return 0; + + return skbprio_change(sch, opt, extack); +} + +static int skbprio_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct tc_skbprio_qopt opt; + + opt.limit = sch->limit; + + if (nla_put(skb, TCA_OPTIONS, sizeof(opt), &opt)) + return -1; + + return skb->len; +} + +static void skbprio_reset(struct Qdisc *sch) +{ + struct skbprio_sched_data *q = qdisc_priv(sch); + int prio; + + for (prio = 0; prio < SKBPRIO_MAX_PRIORITY; prio++) + __skb_queue_purge(&q->qdiscs[prio]); + + memset(&q->qstats, 0, sizeof(q->qstats)); + q->highest_prio = 0; + q->lowest_prio = SKBPRIO_MAX_PRIORITY - 1; +} + +static void skbprio_destroy(struct Qdisc *sch) +{ + struct skbprio_sched_data *q = qdisc_priv(sch); + int prio; + + for (prio = 0; prio < SKBPRIO_MAX_PRIORITY; prio++) + __skb_queue_purge(&q->qdiscs[prio]); +} + +static struct Qdisc *skbprio_leaf(struct Qdisc *sch, unsigned long arg) +{ + return NULL; +} + +static unsigned long skbprio_find(struct Qdisc *sch, u32 classid) +{ + return 0; +} + +static int skbprio_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + tcm->tcm_handle |= TC_H_MIN(cl); + return 0; +} + +static int skbprio_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) +{ + struct skbprio_sched_data *q = qdisc_priv(sch); + if (gnet_stats_copy_queue(d, NULL, &q->qstats[cl - 1], + q->qstats[cl - 1].qlen) < 0) + return -1; + return 0; +} + +static void skbprio_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + unsigned int i; + + if (arg->stop) + return; + + for (i = 0; i < SKBPRIO_MAX_PRIORITY; i++) { + if (!tc_qdisc_stats_dump(sch, i + 1, arg)) + break; + } +} + +static const struct Qdisc_class_ops skbprio_class_ops = { + .leaf = skbprio_leaf, + .find = skbprio_find, + .dump = skbprio_dump_class, + .dump_stats = skbprio_dump_class_stats, + .walk = skbprio_walk, +}; + +static struct 
Qdisc_ops skbprio_qdisc_ops __read_mostly = { + .cl_ops = &skbprio_class_ops, + .id = "skbprio", + .priv_size = sizeof(struct skbprio_sched_data), + .enqueue = skbprio_enqueue, + .dequeue = skbprio_dequeue, + .peek = qdisc_peek_dequeued, + .init = skbprio_init, + .reset = skbprio_reset, + .change = skbprio_change, + .dump = skbprio_dump, + .destroy = skbprio_destroy, + .owner = THIS_MODULE, +}; + +static int __init skbprio_module_init(void) +{ + return register_qdisc(&skbprio_qdisc_ops); +} + +static void __exit skbprio_module_exit(void) +{ + unregister_qdisc(&skbprio_qdisc_ops); +} + +module_init(skbprio_module_init) +module_exit(skbprio_module_exit) + +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_taprio.c b/net/sched/sch_taprio.c new file mode 100644 index 000000000..8d5eebb2d --- /dev/null +++ b/net/sched/sch_taprio.c @@ -0,0 +1,2171 @@ +// SPDX-License-Identifier: GPL-2.0 + +/* net/sched/sch_taprio.c Time Aware Priority Scheduler + * + * Authors: Vinicius Costa Gomes + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(taprio_list); + +#define TAPRIO_ALL_GATES_OPEN -1 + +#define TXTIME_ASSIST_IS_ENABLED(flags) ((flags) & TCA_TAPRIO_ATTR_FLAG_TXTIME_ASSIST) +#define FULL_OFFLOAD_IS_ENABLED(flags) ((flags) & TCA_TAPRIO_ATTR_FLAG_FULL_OFFLOAD) +#define TAPRIO_FLAGS_INVALID U32_MAX + +struct sched_entry { + struct list_head list; + + /* The instant that this entry "closes" and the next one + * should open, the qdisc will make some effort so that no + * packet leaves after this time. + */ + ktime_t close_time; + ktime_t next_txtime; + atomic_t budget; + int index; + u32 gate_mask; + u32 interval; + u8 command; +}; + +struct sched_gate_list { + struct rcu_head rcu; + struct list_head entries; + size_t num_entries; + ktime_t cycle_close_time; + s64 cycle_time; + s64 cycle_time_extension; + s64 base_time; +}; + +struct taprio_sched { + struct Qdisc **qdiscs; + struct Qdisc *root; + u32 flags; + enum tk_offsets tk_offset; + int clockid; + bool offloaded; + atomic64_t picos_per_byte; /* Using picoseconds because for 10Gbps+ + * speeds it's sub-nanoseconds per byte + */ + + /* Protects the update side of the RCU protected current_entry */ + spinlock_t current_entry_lock; + struct sched_entry __rcu *current_entry; + struct sched_gate_list __rcu *oper_sched; + struct sched_gate_list __rcu *admin_sched; + struct hrtimer advance_timer; + struct list_head taprio_list; + u32 max_frm_len[TC_MAX_QUEUE]; /* for the fast path */ + u32 max_sdu[TC_MAX_QUEUE]; /* for dump and offloading */ + u32 txtime_delay; +}; + +struct __tc_taprio_qopt_offload { + refcount_t users; + struct tc_taprio_qopt_offload offload; +}; + +static ktime_t sched_base_time(const struct sched_gate_list *sched) +{ + if (!sched) + return KTIME_MAX; + + return ns_to_ktime(sched->base_time); +} + +static ktime_t taprio_mono_to_any(const struct taprio_sched *q, ktime_t mono) +{ + /* This pairs with WRITE_ONCE() in taprio_parse_clockid() */ + enum tk_offsets tk_offset = READ_ONCE(q->tk_offset); + + switch (tk_offset) { + case TK_OFFS_MAX: + return mono; + default: + return ktime_mono_to_any(mono, tk_offset); + } +} + +static ktime_t taprio_get_time(const struct taprio_sched *q) +{ + return taprio_mono_to_any(q, ktime_get()); +} + +static void taprio_free_sched_cb(struct rcu_head *head) +{ + struct sched_gate_list *sched = container_of(head, struct sched_gate_list, rcu); + 
struct sched_entry *entry, *n; + + list_for_each_entry_safe(entry, n, &sched->entries, list) { + list_del(&entry->list); + kfree(entry); + } + + kfree(sched); +} + +static void switch_schedules(struct taprio_sched *q, + struct sched_gate_list **admin, + struct sched_gate_list **oper) +{ + rcu_assign_pointer(q->oper_sched, *admin); + rcu_assign_pointer(q->admin_sched, NULL); + + if (*oper) + call_rcu(&(*oper)->rcu, taprio_free_sched_cb); + + *oper = *admin; + *admin = NULL; +} + +/* Get how much time has been already elapsed in the current cycle. */ +static s32 get_cycle_time_elapsed(struct sched_gate_list *sched, ktime_t time) +{ + ktime_t time_since_sched_start; + s32 time_elapsed; + + time_since_sched_start = ktime_sub(time, sched->base_time); + div_s64_rem(time_since_sched_start, sched->cycle_time, &time_elapsed); + + return time_elapsed; +} + +static ktime_t get_interval_end_time(struct sched_gate_list *sched, + struct sched_gate_list *admin, + struct sched_entry *entry, + ktime_t intv_start) +{ + s32 cycle_elapsed = get_cycle_time_elapsed(sched, intv_start); + ktime_t intv_end, cycle_ext_end, cycle_end; + + cycle_end = ktime_add_ns(intv_start, sched->cycle_time - cycle_elapsed); + intv_end = ktime_add_ns(intv_start, entry->interval); + cycle_ext_end = ktime_add(cycle_end, sched->cycle_time_extension); + + if (ktime_before(intv_end, cycle_end)) + return intv_end; + else if (admin && admin != sched && + ktime_after(admin->base_time, cycle_end) && + ktime_before(admin->base_time, cycle_ext_end)) + return admin->base_time; + else + return cycle_end; +} + +static int length_to_duration(struct taprio_sched *q, int len) +{ + return div_u64(len * atomic64_read(&q->picos_per_byte), PSEC_PER_NSEC); +} + +/* Returns the entry corresponding to next available interval. If + * validate_interval is set, it only validates whether the timestamp occurs + * when the gate corresponding to the skb's traffic class is open. 
+ */ +static struct sched_entry *find_entry_to_transmit(struct sk_buff *skb, + struct Qdisc *sch, + struct sched_gate_list *sched, + struct sched_gate_list *admin, + ktime_t time, + ktime_t *interval_start, + ktime_t *interval_end, + bool validate_interval) +{ + ktime_t curr_intv_start, curr_intv_end, cycle_end, packet_transmit_time; + ktime_t earliest_txtime = KTIME_MAX, txtime, cycle, transmit_end_time; + struct sched_entry *entry = NULL, *entry_found = NULL; + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + bool entry_available = false; + s32 cycle_elapsed; + int tc, n; + + tc = netdev_get_prio_tc_map(dev, skb->priority); + packet_transmit_time = length_to_duration(q, qdisc_pkt_len(skb)); + + *interval_start = 0; + *interval_end = 0; + + if (!sched) + return NULL; + + cycle = sched->cycle_time; + cycle_elapsed = get_cycle_time_elapsed(sched, time); + curr_intv_end = ktime_sub_ns(time, cycle_elapsed); + cycle_end = ktime_add_ns(curr_intv_end, cycle); + + list_for_each_entry(entry, &sched->entries, list) { + curr_intv_start = curr_intv_end; + curr_intv_end = get_interval_end_time(sched, admin, entry, + curr_intv_start); + + if (ktime_after(curr_intv_start, cycle_end)) + break; + + if (!(entry->gate_mask & BIT(tc)) || + packet_transmit_time > entry->interval) + continue; + + txtime = entry->next_txtime; + + if (ktime_before(txtime, time) || validate_interval) { + transmit_end_time = ktime_add_ns(time, packet_transmit_time); + if ((ktime_before(curr_intv_start, time) && + ktime_before(transmit_end_time, curr_intv_end)) || + (ktime_after(curr_intv_start, time) && !validate_interval)) { + entry_found = entry; + *interval_start = curr_intv_start; + *interval_end = curr_intv_end; + break; + } else if (!entry_available && !validate_interval) { + /* Here, we are just trying to find out the + * first available interval in the next cycle. + */ + entry_available = true; + entry_found = entry; + *interval_start = ktime_add_ns(curr_intv_start, cycle); + *interval_end = ktime_add_ns(curr_intv_end, cycle); + } + } else if (ktime_before(txtime, earliest_txtime) && + !entry_available) { + earliest_txtime = txtime; + entry_found = entry; + n = div_s64(ktime_sub(txtime, curr_intv_start), cycle); + *interval_start = ktime_add(curr_intv_start, n * cycle); + *interval_end = ktime_add(curr_intv_end, n * cycle); + } + } + + return entry_found; +} + +static bool is_valid_interval(struct sk_buff *skb, struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct sched_gate_list *sched, *admin; + ktime_t interval_start, interval_end; + struct sched_entry *entry; + + rcu_read_lock(); + sched = rcu_dereference(q->oper_sched); + admin = rcu_dereference(q->admin_sched); + + entry = find_entry_to_transmit(skb, sch, sched, admin, skb->tstamp, + &interval_start, &interval_end, true); + rcu_read_unlock(); + + return entry; +} + +static bool taprio_flags_valid(u32 flags) +{ + /* Make sure no other flag bits are set. */ + if (flags & ~(TCA_TAPRIO_ATTR_FLAG_TXTIME_ASSIST | + TCA_TAPRIO_ATTR_FLAG_FULL_OFFLOAD)) + return false; + /* txtime-assist and full offload are mutually exclusive */ + if ((flags & TCA_TAPRIO_ATTR_FLAG_TXTIME_ASSIST) && + (flags & TCA_TAPRIO_ATTR_FLAG_FULL_OFFLOAD)) + return false; + return true; +} + +/* This returns the tstamp value set by TCP in terms of the set clock. 
*/ +static ktime_t get_tcp_tstamp(struct taprio_sched *q, struct sk_buff *skb) +{ + unsigned int offset = skb_network_offset(skb); + const struct ipv6hdr *ipv6h; + const struct iphdr *iph; + struct ipv6hdr _ipv6h; + + ipv6h = skb_header_pointer(skb, offset, sizeof(_ipv6h), &_ipv6h); + if (!ipv6h) + return 0; + + if (ipv6h->version == 4) { + iph = (struct iphdr *)ipv6h; + offset += iph->ihl * 4; + + /* special-case 6in4 tunnelling, as that is a common way to get + * v6 connectivity in the home + */ + if (iph->protocol == IPPROTO_IPV6) { + ipv6h = skb_header_pointer(skb, offset, + sizeof(_ipv6h), &_ipv6h); + + if (!ipv6h || ipv6h->nexthdr != IPPROTO_TCP) + return 0; + } else if (iph->protocol != IPPROTO_TCP) { + return 0; + } + } else if (ipv6h->version == 6 && ipv6h->nexthdr != IPPROTO_TCP) { + return 0; + } + + return taprio_mono_to_any(q, skb->skb_mstamp_ns); +} + +/* There are a few scenarios where we will have to modify the txtime from + * what is read from next_txtime in sched_entry. They are: + * 1. If txtime is in the past, + * a. The gate for the traffic class is currently open and packet can be + * transmitted before it closes, schedule the packet right away. + * b. If the gate corresponding to the traffic class is going to open later + * in the cycle, set the txtime of packet to the interval start. + * 2. If txtime is in the future, there are packets corresponding to the + * current traffic class waiting to be transmitted. So, the following + * possibilities exist: + * a. We can transmit the packet before the window containing the txtime + * closes. + * b. The window might close before the transmission can be completed + * successfully. So, schedule the packet in the next open window. + */ +static long get_packet_txtime(struct sk_buff *skb, struct Qdisc *sch) +{ + ktime_t transmit_end_time, interval_end, interval_start, tcp_tstamp; + struct taprio_sched *q = qdisc_priv(sch); + struct sched_gate_list *sched, *admin; + ktime_t minimum_time, now, txtime; + int len, packet_transmit_time; + struct sched_entry *entry; + bool sched_changed; + + now = taprio_get_time(q); + minimum_time = ktime_add_ns(now, q->txtime_delay); + + tcp_tstamp = get_tcp_tstamp(q, skb); + minimum_time = max_t(ktime_t, minimum_time, tcp_tstamp); + + rcu_read_lock(); + admin = rcu_dereference(q->admin_sched); + sched = rcu_dereference(q->oper_sched); + if (admin && ktime_after(minimum_time, admin->base_time)) + switch_schedules(q, &admin, &sched); + + /* Until the schedule starts, all the queues are open */ + if (!sched || ktime_before(minimum_time, sched->base_time)) { + txtime = minimum_time; + goto done; + } + + len = qdisc_pkt_len(skb); + packet_transmit_time = length_to_duration(q, len); + + do { + sched_changed = false; + + entry = find_entry_to_transmit(skb, sch, sched, admin, + minimum_time, + &interval_start, &interval_end, + false); + if (!entry) { + txtime = 0; + goto done; + } + + txtime = entry->next_txtime; + txtime = max_t(ktime_t, txtime, minimum_time); + txtime = max_t(ktime_t, txtime, interval_start); + + if (admin && admin != sched && + ktime_after(txtime, admin->base_time)) { + sched = admin; + sched_changed = true; + continue; + } + + transmit_end_time = ktime_add(txtime, packet_transmit_time); + minimum_time = transmit_end_time; + + /* Update the txtime of current entry to the next time it's + * interval starts. 
+ */ + if (ktime_after(transmit_end_time, interval_end)) + entry->next_txtime = ktime_add(interval_start, sched->cycle_time); + } while (sched_changed || ktime_after(transmit_end_time, interval_end)); + + entry->next_txtime = transmit_end_time; + +done: + rcu_read_unlock(); + return txtime; +} + +static int taprio_enqueue_one(struct sk_buff *skb, struct Qdisc *sch, + struct Qdisc *child, struct sk_buff **to_free) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int prio = skb->priority; + u8 tc; + + /* sk_flags are only safe to use on full sockets. */ + if (skb->sk && sk_fullsock(skb->sk) && sock_flag(skb->sk, SOCK_TXTIME)) { + if (!is_valid_interval(skb, sch)) + return qdisc_drop(skb, sch, to_free); + } else if (TXTIME_ASSIST_IS_ENABLED(q->flags)) { + skb->tstamp = get_packet_txtime(skb, sch); + if (!skb->tstamp) + return qdisc_drop(skb, sch, to_free); + } + + /* Devices with full offload are expected to honor this in hardware */ + tc = netdev_get_prio_tc_map(dev, prio); + if (skb->len > q->max_frm_len[tc]) + return qdisc_drop(skb, sch, to_free); + + qdisc_qstats_backlog_inc(sch, skb); + sch->q.qlen++; + + return qdisc_enqueue(skb, child, to_free); +} + +/* Will not be called in the full offload case, since the TX queues are + * attached to the Qdisc created using qdisc_create_dflt() + */ +static int taprio_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct Qdisc *child; + int queue; + + queue = skb_get_queue_mapping(skb); + + child = q->qdiscs[queue]; + if (unlikely(!child)) + return qdisc_drop(skb, sch, to_free); + + /* Large packets might not be transmitted when the transmission duration + * exceeds any configured interval. Therefore, segment the skb into + * smaller chunks. Drivers with full offload are expected to handle + * this in hardware. + */ + if (skb_is_gso(skb)) { + unsigned int slen = 0, numsegs = 0, len = qdisc_pkt_len(skb); + netdev_features_t features = netif_skb_features(skb); + struct sk_buff *segs, *nskb; + int ret; + + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + if (IS_ERR_OR_NULL(segs)) + return qdisc_drop(skb, sch, to_free); + + skb_list_walk_safe(segs, segs, nskb) { + skb_mark_not_on_list(segs); + qdisc_skb_cb(segs)->pkt_len = segs->len; + slen += segs->len; + + ret = taprio_enqueue_one(segs, sch, child, to_free); + if (ret != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(ret)) + qdisc_qstats_drop(sch); + } else { + numsegs++; + } + } + + if (numsegs > 1) + qdisc_tree_reduce_backlog(sch, 1 - numsegs, len - slen); + consume_skb(skb); + + return numsegs > 0 ? NET_XMIT_SUCCESS : NET_XMIT_DROP; + } + + return taprio_enqueue_one(skb, sch, child, to_free); +} + +/* Will not be called in the full offload case, since the TX queues are + * attached to the Qdisc created using qdisc_create_dflt() + */ +static struct sk_buff *taprio_peek(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_entry *entry; + struct sk_buff *skb; + u32 gate_mask; + int i; + + rcu_read_lock(); + entry = rcu_dereference(q->current_entry); + gate_mask = entry ? 
entry->gate_mask : TAPRIO_ALL_GATES_OPEN; + rcu_read_unlock(); + + if (!gate_mask) + return NULL; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct Qdisc *child = q->qdiscs[i]; + int prio; + u8 tc; + + if (unlikely(!child)) + continue; + + skb = child->ops->peek(child); + if (!skb) + continue; + + if (TXTIME_ASSIST_IS_ENABLED(q->flags)) + return skb; + + prio = skb->priority; + tc = netdev_get_prio_tc_map(dev, prio); + + if (!(gate_mask & BIT(tc))) + continue; + + return skb; + } + + return NULL; +} + +static void taprio_set_budget(struct taprio_sched *q, struct sched_entry *entry) +{ + atomic_set(&entry->budget, + div64_u64((u64)entry->interval * PSEC_PER_NSEC, + atomic64_read(&q->picos_per_byte))); +} + +/* Will not be called in the full offload case, since the TX queues are + * attached to the Qdisc created using qdisc_create_dflt() + */ +static struct sk_buff *taprio_dequeue(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sk_buff *skb = NULL; + struct sched_entry *entry; + u32 gate_mask; + int i; + + rcu_read_lock(); + entry = rcu_dereference(q->current_entry); + /* if there's no entry, it means that the schedule didn't + * start yet, so force all gates to be open, this is in + * accordance to IEEE 802.1Qbv-2015 Section 8.6.9.4.5 + * "AdminGateStates" + */ + gate_mask = entry ? entry->gate_mask : TAPRIO_ALL_GATES_OPEN; + + if (!gate_mask) + goto done; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct Qdisc *child = q->qdiscs[i]; + ktime_t guard; + int prio; + int len; + u8 tc; + + if (unlikely(!child)) + continue; + + if (TXTIME_ASSIST_IS_ENABLED(q->flags)) { + skb = child->ops->dequeue(child); + if (!skb) + continue; + goto skb_found; + } + + skb = child->ops->peek(child); + if (!skb) + continue; + + prio = skb->priority; + tc = netdev_get_prio_tc_map(dev, prio); + + if (!(gate_mask & BIT(tc))) { + skb = NULL; + continue; + } + + len = qdisc_pkt_len(skb); + guard = ktime_add_ns(taprio_get_time(q), + length_to_duration(q, len)); + + /* In the case that there's no gate entry, there's no + * guard band ... + */ + if (gate_mask != TAPRIO_ALL_GATES_OPEN && + ktime_after(guard, entry->close_time)) { + skb = NULL; + continue; + } + + /* ... and no budget. */ + if (gate_mask != TAPRIO_ALL_GATES_OPEN && + atomic_sub_return(len, &entry->budget) < 0) { + skb = NULL; + continue; + } + + skb = child->ops->dequeue(child); + if (unlikely(!skb)) + goto done; + +skb_found: + qdisc_bstats_update(sch, skb); + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + + goto done; + } + +done: + rcu_read_unlock(); + + return skb; +} + +static bool should_restart_cycle(const struct sched_gate_list *oper, + const struct sched_entry *entry) +{ + if (list_is_last(&entry->list, &oper->entries)) + return true; + + if (ktime_compare(entry->close_time, oper->cycle_close_time) == 0) + return true; + + return false; +} + +static bool should_change_schedules(const struct sched_gate_list *admin, + const struct sched_gate_list *oper, + ktime_t close_time) +{ + ktime_t next_base_time, extension_time; + + if (!admin) + return false; + + next_base_time = sched_base_time(admin); + + /* This is the simple case, the close_time would fall after + * the next schedule base_time. 
+ */ + if (ktime_compare(next_base_time, close_time) <= 0) + return true; + + /* This is the cycle_time_extension case, if the close_time + * plus the amount that can be extended would fall after the + * next schedule base_time, we can extend the current schedule + * for that amount. + */ + extension_time = ktime_add_ns(close_time, oper->cycle_time_extension); + + /* FIXME: the IEEE 802.1Q-2018 Specification isn't clear about + * how precisely the extension should be made. So after + * conformance testing, this logic may change. + */ + if (ktime_compare(next_base_time, extension_time) <= 0) + return true; + + return false; +} + +static enum hrtimer_restart advance_sched(struct hrtimer *timer) +{ + struct taprio_sched *q = container_of(timer, struct taprio_sched, + advance_timer); + struct sched_gate_list *oper, *admin; + struct sched_entry *entry, *next; + struct Qdisc *sch = q->root; + ktime_t close_time; + + spin_lock(&q->current_entry_lock); + entry = rcu_dereference_protected(q->current_entry, + lockdep_is_held(&q->current_entry_lock)); + oper = rcu_dereference_protected(q->oper_sched, + lockdep_is_held(&q->current_entry_lock)); + admin = rcu_dereference_protected(q->admin_sched, + lockdep_is_held(&q->current_entry_lock)); + + if (!oper) + switch_schedules(q, &admin, &oper); + + /* This can happen in two cases: 1. this is the very first run + * of this function (i.e. we weren't running any schedule + * previously); 2. The previous schedule just ended. The first + * entry of all schedules are pre-calculated during the + * schedule initialization. + */ + if (unlikely(!entry || entry->close_time == oper->base_time)) { + next = list_first_entry(&oper->entries, struct sched_entry, + list); + close_time = next->close_time; + goto first_run; + } + + if (should_restart_cycle(oper, entry)) { + next = list_first_entry(&oper->entries, struct sched_entry, + list); + oper->cycle_close_time = ktime_add_ns(oper->cycle_close_time, + oper->cycle_time); + } else { + next = list_next_entry(entry, list); + } + + close_time = ktime_add_ns(entry->close_time, next->interval); + close_time = min_t(ktime_t, close_time, oper->cycle_close_time); + + if (should_change_schedules(admin, oper, close_time)) { + /* Set things so the next time this runs, the new + * schedule runs. 
+ */ + close_time = sched_base_time(admin); + switch_schedules(q, &admin, &oper); + } + + next->close_time = close_time; + taprio_set_budget(q, next); + +first_run: + rcu_assign_pointer(q->current_entry, next); + spin_unlock(&q->current_entry_lock); + + hrtimer_set_expires(&q->advance_timer, close_time); + + rcu_read_lock(); + __netif_schedule(sch); + rcu_read_unlock(); + + return HRTIMER_RESTART; +} + +static const struct nla_policy entry_policy[TCA_TAPRIO_SCHED_ENTRY_MAX + 1] = { + [TCA_TAPRIO_SCHED_ENTRY_INDEX] = { .type = NLA_U32 }, + [TCA_TAPRIO_SCHED_ENTRY_CMD] = { .type = NLA_U8 }, + [TCA_TAPRIO_SCHED_ENTRY_GATE_MASK] = { .type = NLA_U32 }, + [TCA_TAPRIO_SCHED_ENTRY_INTERVAL] = { .type = NLA_U32 }, +}; + +static const struct nla_policy taprio_tc_policy[TCA_TAPRIO_TC_ENTRY_MAX + 1] = { + [TCA_TAPRIO_TC_ENTRY_INDEX] = { .type = NLA_U32 }, + [TCA_TAPRIO_TC_ENTRY_MAX_SDU] = { .type = NLA_U32 }, +}; + +static struct netlink_range_validation_signed taprio_cycle_time_range = { + .min = 0, + .max = INT_MAX, +}; + +static const struct nla_policy taprio_policy[TCA_TAPRIO_ATTR_MAX + 1] = { + [TCA_TAPRIO_ATTR_PRIOMAP] = { + .len = sizeof(struct tc_mqprio_qopt) + }, + [TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST] = { .type = NLA_NESTED }, + [TCA_TAPRIO_ATTR_SCHED_BASE_TIME] = { .type = NLA_S64 }, + [TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY] = { .type = NLA_NESTED }, + [TCA_TAPRIO_ATTR_SCHED_CLOCKID] = { .type = NLA_S32 }, + [TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME] = + NLA_POLICY_FULL_RANGE_SIGNED(NLA_S64, &taprio_cycle_time_range), + [TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME_EXTENSION] = { .type = NLA_S64 }, + [TCA_TAPRIO_ATTR_FLAGS] = { .type = NLA_U32 }, + [TCA_TAPRIO_ATTR_TXTIME_DELAY] = { .type = NLA_U32 }, + [TCA_TAPRIO_ATTR_TC_ENTRY] = { .type = NLA_NESTED }, +}; + +static int fill_sched_entry(struct taprio_sched *q, struct nlattr **tb, + struct sched_entry *entry, + struct netlink_ext_ack *extack) +{ + int min_duration = length_to_duration(q, ETH_ZLEN); + u32 interval = 0; + + if (tb[TCA_TAPRIO_SCHED_ENTRY_CMD]) + entry->command = nla_get_u8( + tb[TCA_TAPRIO_SCHED_ENTRY_CMD]); + + if (tb[TCA_TAPRIO_SCHED_ENTRY_GATE_MASK]) + entry->gate_mask = nla_get_u32( + tb[TCA_TAPRIO_SCHED_ENTRY_GATE_MASK]); + + if (tb[TCA_TAPRIO_SCHED_ENTRY_INTERVAL]) + interval = nla_get_u32( + tb[TCA_TAPRIO_SCHED_ENTRY_INTERVAL]); + + /* The interval should allow at least the minimum ethernet + * frame to go out. 
+ */ + if (interval < min_duration) { + NL_SET_ERR_MSG(extack, "Invalid interval for schedule entry"); + return -EINVAL; + } + + entry->interval = interval; + + return 0; +} + +static int parse_sched_entry(struct taprio_sched *q, struct nlattr *n, + struct sched_entry *entry, int index, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TAPRIO_SCHED_ENTRY_MAX + 1] = { }; + int err; + + err = nla_parse_nested_deprecated(tb, TCA_TAPRIO_SCHED_ENTRY_MAX, n, + entry_policy, NULL); + if (err < 0) { + NL_SET_ERR_MSG(extack, "Could not parse nested entry"); + return -EINVAL; + } + + entry->index = index; + + return fill_sched_entry(q, tb, entry, extack); +} + +static int parse_sched_list(struct taprio_sched *q, struct nlattr *list, + struct sched_gate_list *sched, + struct netlink_ext_ack *extack) +{ + struct nlattr *n; + int err, rem; + int i = 0; + + if (!list) + return -EINVAL; + + nla_for_each_nested(n, list, rem) { + struct sched_entry *entry; + + if (nla_type(n) != TCA_TAPRIO_SCHED_ENTRY) { + NL_SET_ERR_MSG(extack, "Attribute is not of type 'entry'"); + continue; + } + + entry = kzalloc(sizeof(*entry), GFP_KERNEL); + if (!entry) { + NL_SET_ERR_MSG(extack, "Not enough memory for entry"); + return -ENOMEM; + } + + err = parse_sched_entry(q, n, entry, i, extack); + if (err < 0) { + kfree(entry); + return err; + } + + list_add_tail(&entry->list, &sched->entries); + i++; + } + + sched->num_entries = i; + + return i; +} + +static int parse_taprio_schedule(struct taprio_sched *q, struct nlattr **tb, + struct sched_gate_list *new, + struct netlink_ext_ack *extack) +{ + int err = 0; + + if (tb[TCA_TAPRIO_ATTR_SCHED_SINGLE_ENTRY]) { + NL_SET_ERR_MSG(extack, "Adding a single entry is not supported"); + return -ENOTSUPP; + } + + if (tb[TCA_TAPRIO_ATTR_SCHED_BASE_TIME]) + new->base_time = nla_get_s64(tb[TCA_TAPRIO_ATTR_SCHED_BASE_TIME]); + + if (tb[TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME_EXTENSION]) + new->cycle_time_extension = nla_get_s64(tb[TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME_EXTENSION]); + + if (tb[TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME]) + new->cycle_time = nla_get_s64(tb[TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME]); + + if (tb[TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST]) + err = parse_sched_list(q, tb[TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST], + new, extack); + if (err < 0) + return err; + + if (!new->cycle_time) { + struct sched_entry *entry; + ktime_t cycle = 0; + + list_for_each_entry(entry, &new->entries, list) + cycle = ktime_add_ns(cycle, entry->interval); + + if (!cycle) { + NL_SET_ERR_MSG(extack, "'cycle_time' can never be 0"); + return -EINVAL; + } + + if (cycle < 0 || cycle > INT_MAX) { + NL_SET_ERR_MSG(extack, "'cycle_time' is too big"); + return -EINVAL; + } + + new->cycle_time = cycle; + } + + return 0; +} + +static int taprio_parse_mqprio_opt(struct net_device *dev, + struct tc_mqprio_qopt *qopt, + struct netlink_ext_ack *extack, + u32 taprio_flags) +{ + int i, j; + + if (!qopt && !dev->num_tc) { + NL_SET_ERR_MSG(extack, "'mqprio' configuration is necessary"); + return -EINVAL; + } + + /* If num_tc is already set, it means that the user already + * configured the mqprio part + */ + if (dev->num_tc) + return 0; + + /* Verify num_tc is not out of max range */ + if (qopt->num_tc > TC_MAX_QUEUE) { + NL_SET_ERR_MSG(extack, "Number of traffic classes is outside valid range"); + return -EINVAL; + } + + /* taprio imposes that traffic classes map 1:n to tx queues */ + if (qopt->num_tc > dev->num_tx_queues) { + NL_SET_ERR_MSG(extack, "Number of traffic classes is greater than number of HW queues"); + return -EINVAL; + } + + /* 
Verify priority mapping uses valid tcs */ + for (i = 0; i <= TC_BITMASK; i++) { + if (qopt->prio_tc_map[i] >= qopt->num_tc) { + NL_SET_ERR_MSG(extack, "Invalid traffic class in priority to traffic class mapping"); + return -EINVAL; + } + } + + for (i = 0; i < qopt->num_tc; i++) { + unsigned int last = qopt->offset[i] + qopt->count[i]; + + /* Verify the queue count is in tx range being equal to the + * real_num_tx_queues indicates the last queue is in use. + */ + if (qopt->offset[i] >= dev->num_tx_queues || + !qopt->count[i] || + last > dev->real_num_tx_queues) { + NL_SET_ERR_MSG(extack, "Invalid queue in traffic class to queue mapping"); + return -EINVAL; + } + + if (TXTIME_ASSIST_IS_ENABLED(taprio_flags)) + continue; + + /* Verify that the offset and counts do not overlap */ + for (j = i + 1; j < qopt->num_tc; j++) { + if (last > qopt->offset[j]) { + NL_SET_ERR_MSG(extack, "Detected overlap in the traffic class to queue mapping"); + return -EINVAL; + } + } + } + + return 0; +} + +static int taprio_get_start_time(struct Qdisc *sch, + struct sched_gate_list *sched, + ktime_t *start) +{ + struct taprio_sched *q = qdisc_priv(sch); + ktime_t now, base, cycle; + s64 n; + + base = sched_base_time(sched); + now = taprio_get_time(q); + + if (ktime_after(base, now)) { + *start = base; + return 0; + } + + cycle = sched->cycle_time; + + /* The qdisc is expected to have at least one sched_entry. Moreover, + * any entry must have 'interval' > 0. Thus if the cycle time is zero, + * something went really wrong. In that case, we should warn about this + * inconsistent state and return error. + */ + if (WARN_ON(!cycle)) + return -EFAULT; + + /* Schedule the start time for the beginning of the next + * cycle. + */ + n = div64_s64(ktime_sub_ns(now, base), cycle); + *start = ktime_add_ns(base, (n + 1) * cycle); + return 0; +} + +static void setup_first_close_time(struct taprio_sched *q, + struct sched_gate_list *sched, ktime_t base) +{ + struct sched_entry *first; + ktime_t cycle; + + first = list_first_entry(&sched->entries, + struct sched_entry, list); + + cycle = sched->cycle_time; + + /* FIXME: find a better place to do this */ + sched->cycle_close_time = ktime_add_ns(base, cycle); + + first->close_time = ktime_add_ns(base, first->interval); + taprio_set_budget(q, first); + rcu_assign_pointer(q->current_entry, NULL); +} + +static void taprio_start_sched(struct Qdisc *sch, + ktime_t start, struct sched_gate_list *new) +{ + struct taprio_sched *q = qdisc_priv(sch); + ktime_t expires; + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) + return; + + expires = hrtimer_get_expires(&q->advance_timer); + if (expires == 0) + expires = KTIME_MAX; + + /* If the new schedule starts before the next expiration, we + * reprogram it to the earliest one, so we change the admin + * schedule to the operational one at the right time. 
+ */ + start = min_t(ktime_t, start, expires); + + hrtimer_start(&q->advance_timer, start, HRTIMER_MODE_ABS); +} + +static void taprio_set_picos_per_byte(struct net_device *dev, + struct taprio_sched *q) +{ + struct ethtool_link_ksettings ecmd; + int speed = SPEED_10; + int picos_per_byte; + int err; + + err = __ethtool_get_link_ksettings(dev, &ecmd); + if (err < 0) + goto skip; + + if (ecmd.base.speed && ecmd.base.speed != SPEED_UNKNOWN) + speed = ecmd.base.speed; + +skip: + picos_per_byte = (USEC_PER_SEC * 8) / speed; + + atomic64_set(&q->picos_per_byte, picos_per_byte); + netdev_dbg(dev, "taprio: set %s's picos_per_byte to: %lld, linkspeed: %d\n", + dev->name, (long long)atomic64_read(&q->picos_per_byte), + ecmd.base.speed); +} + +static int taprio_dev_notifier(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct taprio_sched *q; + + ASSERT_RTNL(); + + if (event != NETDEV_UP && event != NETDEV_CHANGE) + return NOTIFY_DONE; + + list_for_each_entry(q, &taprio_list, taprio_list) { + if (dev != qdisc_dev(q->root)) + continue; + + taprio_set_picos_per_byte(dev, q); + break; + } + + return NOTIFY_DONE; +} + +static void setup_txtime(struct taprio_sched *q, + struct sched_gate_list *sched, ktime_t base) +{ + struct sched_entry *entry; + u64 interval = 0; + + list_for_each_entry(entry, &sched->entries, list) { + entry->next_txtime = ktime_add_ns(base, interval); + interval += entry->interval; + } +} + +static struct tc_taprio_qopt_offload *taprio_offload_alloc(int num_entries) +{ + struct __tc_taprio_qopt_offload *__offload; + + __offload = kzalloc(struct_size(__offload, offload.entries, num_entries), + GFP_KERNEL); + if (!__offload) + return NULL; + + refcount_set(&__offload->users, 1); + + return &__offload->offload; +} + +struct tc_taprio_qopt_offload *taprio_offload_get(struct tc_taprio_qopt_offload + *offload) +{ + struct __tc_taprio_qopt_offload *__offload; + + __offload = container_of(offload, struct __tc_taprio_qopt_offload, + offload); + + refcount_inc(&__offload->users); + + return offload; +} +EXPORT_SYMBOL_GPL(taprio_offload_get); + +void taprio_offload_free(struct tc_taprio_qopt_offload *offload) +{ + struct __tc_taprio_qopt_offload *__offload; + + __offload = container_of(offload, struct __tc_taprio_qopt_offload, + offload); + + if (!refcount_dec_and_test(&__offload->users)) + return; + + kfree(__offload); +} +EXPORT_SYMBOL_GPL(taprio_offload_free); + +/* The function will only serve to keep the pointers to the "oper" and "admin" + * schedules valid in relation to their base times, so when calling dump() the + * users looks at the right schedules. + * When using full offload, the admin configuration is promoted to oper at the + * base_time in the PHC time domain. But because the system time is not + * necessarily in sync with that, we can't just trigger a hrtimer to call + * switch_schedules at the right hardware time. + * At the moment we call this by hand right away from taprio, but in the future + * it will be useful to create a mechanism for drivers to notify taprio of the + * offload state (PENDING, ACTIVE, INACTIVE) so it can be visible in dump(). + * This is left as TODO. 
+ */ +static void taprio_offload_config_changed(struct taprio_sched *q) +{ + struct sched_gate_list *oper, *admin; + + oper = rtnl_dereference(q->oper_sched); + admin = rtnl_dereference(q->admin_sched); + + switch_schedules(q, &admin, &oper); +} + +static u32 tc_map_to_queue_mask(struct net_device *dev, u32 tc_mask) +{ + u32 i, queue_mask = 0; + + for (i = 0; i < dev->num_tc; i++) { + u32 offset, count; + + if (!(tc_mask & BIT(i))) + continue; + + offset = dev->tc_to_txq[i].offset; + count = dev->tc_to_txq[i].count; + + queue_mask |= GENMASK(offset + count - 1, offset); + } + + return queue_mask; +} + +static void taprio_sched_to_offload(struct net_device *dev, + struct sched_gate_list *sched, + struct tc_taprio_qopt_offload *offload) +{ + struct sched_entry *entry; + int i = 0; + + offload->base_time = sched->base_time; + offload->cycle_time = sched->cycle_time; + offload->cycle_time_extension = sched->cycle_time_extension; + + list_for_each_entry(entry, &sched->entries, list) { + struct tc_taprio_sched_entry *e = &offload->entries[i]; + + e->command = entry->command; + e->interval = entry->interval; + e->gate_mask = tc_map_to_queue_mask(dev, entry->gate_mask); + + i++; + } + + offload->num_entries = i; +} + +static int taprio_enable_offload(struct net_device *dev, + struct taprio_sched *q, + struct sched_gate_list *sched, + struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct tc_taprio_qopt_offload *offload; + struct tc_taprio_caps caps; + int tc, err = 0; + + if (!ops->ndo_setup_tc) { + NL_SET_ERR_MSG(extack, + "Device does not support taprio offload"); + return -EOPNOTSUPP; + } + + qdisc_offload_query_caps(dev, TC_SETUP_QDISC_TAPRIO, + &caps, sizeof(caps)); + + if (!caps.supports_queue_max_sdu) { + for (tc = 0; tc < TC_MAX_QUEUE; tc++) { + if (q->max_sdu[tc]) { + NL_SET_ERR_MSG_MOD(extack, + "Device does not handle queueMaxSDU"); + return -EOPNOTSUPP; + } + } + } + + offload = taprio_offload_alloc(sched->num_entries); + if (!offload) { + NL_SET_ERR_MSG(extack, + "Not enough memory for enabling offload mode"); + return -ENOMEM; + } + offload->enable = 1; + taprio_sched_to_offload(dev, sched, offload); + + for (tc = 0; tc < TC_MAX_QUEUE; tc++) + offload->max_sdu[tc] = q->max_sdu[tc]; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_TAPRIO, offload); + if (err < 0) { + NL_SET_ERR_MSG(extack, + "Device failed to setup taprio offload"); + goto done; + } + + q->offloaded = true; + +done: + taprio_offload_free(offload); + + return err; +} + +static int taprio_disable_offload(struct net_device *dev, + struct taprio_sched *q, + struct netlink_ext_ack *extack) +{ + const struct net_device_ops *ops = dev->netdev_ops; + struct tc_taprio_qopt_offload *offload; + int err; + + if (!q->offloaded) + return 0; + + offload = taprio_offload_alloc(0); + if (!offload) { + NL_SET_ERR_MSG(extack, + "Not enough memory to disable offload mode"); + return -ENOMEM; + } + offload->enable = 0; + + err = ops->ndo_setup_tc(dev, TC_SETUP_QDISC_TAPRIO, offload); + if (err < 0) { + NL_SET_ERR_MSG(extack, + "Device failed to disable offload"); + goto out; + } + + q->offloaded = false; + +out: + taprio_offload_free(offload); + + return err; +} + +/* If full offload is enabled, the only possible clockid is the net device's + * PHC. For that reason, specifying a clockid through netlink is incorrect. + * For txtime-assist, it is implicitly assumed that the device's PHC is kept + * in sync with the specified clockid via a user space daemon such as phc2sys. 
+ * For both software taprio and txtime-assist, the clockid is used for the + * hrtimer that advances the schedule and hence mandatory. + */ +static int taprio_parse_clockid(struct Qdisc *sch, struct nlattr **tb, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int err = -EINVAL; + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) { + const struct ethtool_ops *ops = dev->ethtool_ops; + struct ethtool_ts_info info = { + .cmd = ETHTOOL_GET_TS_INFO, + .phc_index = -1, + }; + + if (tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]) { + NL_SET_ERR_MSG(extack, + "The 'clockid' cannot be specified for full offload"); + goto out; + } + + if (ops && ops->get_ts_info) + err = ops->get_ts_info(dev, &info); + + if (err || info.phc_index < 0) { + NL_SET_ERR_MSG(extack, + "Device does not have a PTP clock"); + err = -ENOTSUPP; + goto out; + } + } else if (tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]) { + int clockid = nla_get_s32(tb[TCA_TAPRIO_ATTR_SCHED_CLOCKID]); + enum tk_offsets tk_offset; + + /* We only support static clockids and we don't allow + * for it to be modified after the first init. + */ + if (clockid < 0 || + (q->clockid != -1 && q->clockid != clockid)) { + NL_SET_ERR_MSG(extack, + "Changing the 'clockid' of a running schedule is not supported"); + err = -ENOTSUPP; + goto out; + } + + switch (clockid) { + case CLOCK_REALTIME: + tk_offset = TK_OFFS_REAL; + break; + case CLOCK_MONOTONIC: + tk_offset = TK_OFFS_MAX; + break; + case CLOCK_BOOTTIME: + tk_offset = TK_OFFS_BOOT; + break; + case CLOCK_TAI: + tk_offset = TK_OFFS_TAI; + break; + default: + NL_SET_ERR_MSG(extack, "Invalid 'clockid'"); + err = -EINVAL; + goto out; + } + /* This pairs with READ_ONCE() in taprio_mono_to_any */ + WRITE_ONCE(q->tk_offset, tk_offset); + + q->clockid = clockid; + } else { + NL_SET_ERR_MSG(extack, "Specifying a 'clockid' is mandatory"); + goto out; + } + + /* Everything went ok, return success. 
*/ + err = 0; + +out: + return err; +} + +static int taprio_parse_tc_entry(struct Qdisc *sch, + struct nlattr *opt, + u32 max_sdu[TC_QOPT_MAX_QUEUE], + unsigned long *seen_tcs, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TAPRIO_TC_ENTRY_MAX + 1] = { }; + struct net_device *dev = qdisc_dev(sch); + u32 val = 0; + int err, tc; + + err = nla_parse_nested(tb, TCA_TAPRIO_TC_ENTRY_MAX, opt, + taprio_tc_policy, extack); + if (err < 0) + return err; + + if (!tb[TCA_TAPRIO_TC_ENTRY_INDEX]) { + NL_SET_ERR_MSG_MOD(extack, "TC entry index missing"); + return -EINVAL; + } + + tc = nla_get_u32(tb[TCA_TAPRIO_TC_ENTRY_INDEX]); + if (tc >= TC_QOPT_MAX_QUEUE) { + NL_SET_ERR_MSG_MOD(extack, "TC entry index out of range"); + return -ERANGE; + } + + if (*seen_tcs & BIT(tc)) { + NL_SET_ERR_MSG_MOD(extack, "Duplicate TC entry"); + return -EINVAL; + } + + *seen_tcs |= BIT(tc); + + if (tb[TCA_TAPRIO_TC_ENTRY_MAX_SDU]) + val = nla_get_u32(tb[TCA_TAPRIO_TC_ENTRY_MAX_SDU]); + + if (val > dev->max_mtu) { + NL_SET_ERR_MSG_MOD(extack, "TC max SDU exceeds device max MTU"); + return -ERANGE; + } + + max_sdu[tc] = val; + + return 0; +} + +static int taprio_parse_tc_entries(struct Qdisc *sch, + struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + u32 max_sdu[TC_QOPT_MAX_QUEUE]; + unsigned long seen_tcs = 0; + struct nlattr *n; + int tc, rem; + int err = 0; + + for (tc = 0; tc < TC_QOPT_MAX_QUEUE; tc++) + max_sdu[tc] = q->max_sdu[tc]; + + nla_for_each_nested(n, opt, rem) { + if (nla_type(n) != TCA_TAPRIO_ATTR_TC_ENTRY) + continue; + + err = taprio_parse_tc_entry(sch, n, max_sdu, &seen_tcs, extack); + if (err) + goto out; + } + + for (tc = 0; tc < TC_QOPT_MAX_QUEUE; tc++) { + q->max_sdu[tc] = max_sdu[tc]; + if (max_sdu[tc]) + q->max_frm_len[tc] = max_sdu[tc] + dev->hard_header_len; + else + q->max_frm_len[tc] = U32_MAX; /* never oversized */ + } + +out: + return err; +} + +static int taprio_mqprio_cmp(const struct net_device *dev, + const struct tc_mqprio_qopt *mqprio) +{ + int i; + + if (!mqprio || mqprio->num_tc != dev->num_tc) + return -1; + + for (i = 0; i < mqprio->num_tc; i++) + if (dev->tc_to_txq[i].count != mqprio->count[i] || + dev->tc_to_txq[i].offset != mqprio->offset[i]) + return -1; + + for (i = 0; i <= TC_BITMASK; i++) + if (dev->prio_tc_map[i] != mqprio->prio_tc_map[i]) + return -1; + + return 0; +} + +/* The semantics of the 'flags' argument in relation to 'change()' + * requests, are interpreted following two rules (which are applied in + * this order): (1) an omitted 'flags' argument is interpreted as + * zero; (2) the 'flags' of a "running" taprio instance cannot be + * changed. 
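+ *
+ * The bits currently defined for 'flags' are 0x1 (txtime-assist mode) and
+ * 0x2 (full offload), so, for example, an instance created with
+ * "tc qdisc replace ... taprio ... flags 0x2" has to carry "flags 0x2" on
+ * every later change request as well, since an omitted 'flags' is read
+ * back as zero.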
+ */ +static int taprio_new_flags(const struct nlattr *attr, u32 old, + struct netlink_ext_ack *extack) +{ + u32 new = 0; + + if (attr) + new = nla_get_u32(attr); + + if (old != TAPRIO_FLAGS_INVALID && old != new) { + NL_SET_ERR_MSG_MOD(extack, "Changing 'flags' of a running schedule is not supported"); + return -EOPNOTSUPP; + } + + if (!taprio_flags_valid(new)) { + NL_SET_ERR_MSG_MOD(extack, "Specified 'flags' are not valid"); + return -EINVAL; + } + + return new; +} + +static int taprio_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[TCA_TAPRIO_ATTR_MAX + 1] = { }; + struct sched_gate_list *oper, *admin, *new_admin; + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_mqprio_qopt *mqprio = NULL; + unsigned long flags; + ktime_t start; + int i, err; + + err = nla_parse_nested_deprecated(tb, TCA_TAPRIO_ATTR_MAX, opt, + taprio_policy, extack); + if (err < 0) + return err; + + if (tb[TCA_TAPRIO_ATTR_PRIOMAP]) + mqprio = nla_data(tb[TCA_TAPRIO_ATTR_PRIOMAP]); + + err = taprio_new_flags(tb[TCA_TAPRIO_ATTR_FLAGS], + q->flags, extack); + if (err < 0) + return err; + + q->flags = err; + + err = taprio_parse_mqprio_opt(dev, mqprio, extack, q->flags); + if (err < 0) + return err; + + err = taprio_parse_tc_entries(sch, opt, extack); + if (err) + return err; + + new_admin = kzalloc(sizeof(*new_admin), GFP_KERNEL); + if (!new_admin) { + NL_SET_ERR_MSG(extack, "Not enough memory for a new schedule"); + return -ENOMEM; + } + INIT_LIST_HEAD(&new_admin->entries); + + oper = rtnl_dereference(q->oper_sched); + admin = rtnl_dereference(q->admin_sched); + + /* no changes - no new mqprio settings */ + if (!taprio_mqprio_cmp(dev, mqprio)) + mqprio = NULL; + + if (mqprio && (oper || admin)) { + NL_SET_ERR_MSG(extack, "Changing the traffic mapping of a running schedule is not supported"); + err = -ENOTSUPP; + goto free_sched; + } + + err = parse_taprio_schedule(q, tb, new_admin, extack); + if (err < 0) + goto free_sched; + + if (new_admin->num_entries == 0) { + NL_SET_ERR_MSG(extack, "There should be at least one entry in the schedule"); + err = -EINVAL; + goto free_sched; + } + + err = taprio_parse_clockid(sch, tb, extack); + if (err < 0) + goto free_sched; + + taprio_set_picos_per_byte(dev, q); + + if (mqprio) { + err = netdev_set_num_tc(dev, mqprio->num_tc); + if (err) + goto free_sched; + for (i = 0; i < mqprio->num_tc; i++) + netdev_set_tc_queue(dev, i, + mqprio->count[i], + mqprio->offset[i]); + + /* Always use supplied priority mappings */ + for (i = 0; i <= TC_BITMASK; i++) + netdev_set_prio_tc_map(dev, i, + mqprio->prio_tc_map[i]); + } + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) + err = taprio_enable_offload(dev, q, new_admin, extack); + else + err = taprio_disable_offload(dev, q, extack); + if (err) + goto free_sched; + + /* Protects against enqueue()/dequeue() */ + spin_lock_bh(qdisc_lock(sch)); + + if (tb[TCA_TAPRIO_ATTR_TXTIME_DELAY]) { + if (!TXTIME_ASSIST_IS_ENABLED(q->flags)) { + NL_SET_ERR_MSG_MOD(extack, "txtime-delay can only be set when txtime-assist mode is enabled"); + err = -EINVAL; + goto unlock; + } + + q->txtime_delay = nla_get_u32(tb[TCA_TAPRIO_ATTR_TXTIME_DELAY]); + } + + if (!TXTIME_ASSIST_IS_ENABLED(q->flags) && + !FULL_OFFLOAD_IS_ENABLED(q->flags) && + !hrtimer_active(&q->advance_timer)) { + hrtimer_init(&q->advance_timer, q->clockid, HRTIMER_MODE_ABS); + q->advance_timer.function = advance_sched; + } + + err = taprio_get_start_time(sch, new_admin, &start); + if (err < 0) { + 
NL_SET_ERR_MSG(extack, "Internal error: failed get start time"); + goto unlock; + } + + setup_txtime(q, new_admin, start); + + if (TXTIME_ASSIST_IS_ENABLED(q->flags)) { + if (!oper) { + rcu_assign_pointer(q->oper_sched, new_admin); + err = 0; + new_admin = NULL; + goto unlock; + } + + rcu_assign_pointer(q->admin_sched, new_admin); + if (admin) + call_rcu(&admin->rcu, taprio_free_sched_cb); + } else { + setup_first_close_time(q, new_admin, start); + + /* Protects against advance_sched() */ + spin_lock_irqsave(&q->current_entry_lock, flags); + + taprio_start_sched(sch, start, new_admin); + + rcu_assign_pointer(q->admin_sched, new_admin); + if (admin) + call_rcu(&admin->rcu, taprio_free_sched_cb); + + spin_unlock_irqrestore(&q->current_entry_lock, flags); + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) + taprio_offload_config_changed(q); + } + + new_admin = NULL; + err = 0; + +unlock: + spin_unlock_bh(qdisc_lock(sch)); + +free_sched: + if (new_admin) + call_rcu(&new_admin->rcu, taprio_free_sched_cb); + + return err; +} + +static void taprio_reset(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int i; + + hrtimer_cancel(&q->advance_timer); + + if (q->qdiscs) { + for (i = 0; i < dev->num_tx_queues; i++) + if (q->qdiscs[i]) + qdisc_reset(q->qdiscs[i]); + } +} + +static void taprio_destroy(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_gate_list *oper, *admin; + unsigned int i; + + list_del(&q->taprio_list); + + /* Note that taprio_reset() might not be called if an error + * happens in qdisc_create(), after taprio_init() has been called. + */ + hrtimer_cancel(&q->advance_timer); + qdisc_synchronize(sch); + + taprio_disable_offload(dev, q, NULL); + + if (q->qdiscs) { + for (i = 0; i < dev->num_tx_queues; i++) + qdisc_put(q->qdiscs[i]); + + kfree(q->qdiscs); + } + q->qdiscs = NULL; + + netdev_reset_tc(dev); + + oper = rtnl_dereference(q->oper_sched); + admin = rtnl_dereference(q->admin_sched); + + if (oper) + call_rcu(&oper->rcu, taprio_free_sched_cb); + + if (admin) + call_rcu(&admin->rcu, taprio_free_sched_cb); +} + +static int taprio_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + int i; + + spin_lock_init(&q->current_entry_lock); + + hrtimer_init(&q->advance_timer, CLOCK_TAI, HRTIMER_MODE_ABS); + q->advance_timer.function = advance_sched; + + q->root = sch; + + /* We only support static clockids. Use an invalid value as default + * and get the valid one on taprio_change(). 
+ */ + q->clockid = -1; + q->flags = TAPRIO_FLAGS_INVALID; + + list_add(&q->taprio_list, &taprio_list); + + if (sch->parent != TC_H_ROOT) { + NL_SET_ERR_MSG_MOD(extack, "Can only be attached as root qdisc"); + return -EOPNOTSUPP; + } + + if (!netif_is_multiqueue(dev)) { + NL_SET_ERR_MSG_MOD(extack, "Multi-queue device is required"); + return -EOPNOTSUPP; + } + + /* pre-allocate qdisc, attachment can't fail */ + q->qdiscs = kcalloc(dev->num_tx_queues, + sizeof(q->qdiscs[0]), + GFP_KERNEL); + + if (!q->qdiscs) + return -ENOMEM; + + if (!opt) + return -EINVAL; + + for (i = 0; i < dev->num_tx_queues; i++) { + struct netdev_queue *dev_queue; + struct Qdisc *qdisc; + + dev_queue = netdev_get_tx_queue(dev, i); + qdisc = qdisc_create_dflt(dev_queue, + &pfifo_qdisc_ops, + TC_H_MAKE(TC_H_MAJ(sch->handle), + TC_H_MIN(i + 1)), + extack); + if (!qdisc) + return -ENOMEM; + + if (i < dev->real_num_tx_queues) + qdisc_hash_add(qdisc, false); + + q->qdiscs[i] = qdisc; + } + + return taprio_change(sch, opt, extack); +} + +static void taprio_attach(struct Qdisc *sch) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + unsigned int ntx; + + /* Attach underlying qdisc */ + for (ntx = 0; ntx < dev->num_tx_queues; ntx++) { + struct Qdisc *qdisc = q->qdiscs[ntx]; + struct Qdisc *old; + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) { + qdisc->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + old = dev_graft_qdisc(qdisc->dev_queue, qdisc); + } else { + old = dev_graft_qdisc(qdisc->dev_queue, sch); + qdisc_refcount_inc(sch); + } + if (old) + qdisc_put(old); + } + + /* access to the child qdiscs is not needed in offload mode */ + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) { + kfree(q->qdiscs); + q->qdiscs = NULL; + } +} + +static struct netdev_queue *taprio_queue_get(struct Qdisc *sch, + unsigned long cl) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx = cl - 1; + + if (ntx >= dev->num_tx_queues) + return NULL; + + return netdev_get_tx_queue(dev, ntx); +} + +static int taprio_graft(struct Qdisc *sch, unsigned long cl, + struct Qdisc *new, struct Qdisc **old, + struct netlink_ext_ack *extack) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + if (!dev_queue) + return -EINVAL; + + if (dev->flags & IFF_UP) + dev_deactivate(dev); + + if (FULL_OFFLOAD_IS_ENABLED(q->flags)) { + *old = dev_graft_qdisc(dev_queue, new); + } else { + *old = q->qdiscs[cl - 1]; + q->qdiscs[cl - 1] = new; + } + + if (new) + new->flags |= TCQ_F_ONETXQUEUE | TCQ_F_NOPARENT; + + if (dev->flags & IFF_UP) + dev_activate(dev); + + return 0; +} + +static int dump_entry(struct sk_buff *msg, + const struct sched_entry *entry) +{ + struct nlattr *item; + + item = nla_nest_start_noflag(msg, TCA_TAPRIO_SCHED_ENTRY); + if (!item) + return -ENOSPC; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_INDEX, entry->index)) + goto nla_put_failure; + + if (nla_put_u8(msg, TCA_TAPRIO_SCHED_ENTRY_CMD, entry->command)) + goto nla_put_failure; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_GATE_MASK, + entry->gate_mask)) + goto nla_put_failure; + + if (nla_put_u32(msg, TCA_TAPRIO_SCHED_ENTRY_INTERVAL, + entry->interval)) + goto nla_put_failure; + + return nla_nest_end(msg, item); + +nla_put_failure: + nla_nest_cancel(msg, item); + return -1; +} + +static int dump_schedule(struct sk_buff *msg, + const struct sched_gate_list *root) +{ + struct nlattr *entry_list; + struct sched_entry *entry; + + if (nla_put_s64(msg, 
TCA_TAPRIO_ATTR_SCHED_BASE_TIME, + root->base_time, TCA_TAPRIO_PAD)) + return -1; + + if (nla_put_s64(msg, TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME, + root->cycle_time, TCA_TAPRIO_PAD)) + return -1; + + if (nla_put_s64(msg, TCA_TAPRIO_ATTR_SCHED_CYCLE_TIME_EXTENSION, + root->cycle_time_extension, TCA_TAPRIO_PAD)) + return -1; + + entry_list = nla_nest_start_noflag(msg, + TCA_TAPRIO_ATTR_SCHED_ENTRY_LIST); + if (!entry_list) + goto error_nest; + + list_for_each_entry(entry, &root->entries, list) { + if (dump_entry(msg, entry) < 0) + goto error_nest; + } + + nla_nest_end(msg, entry_list); + return 0; + +error_nest: + nla_nest_cancel(msg, entry_list); + return -1; +} + +static int taprio_dump_tc_entries(struct taprio_sched *q, struct sk_buff *skb) +{ + struct nlattr *n; + int tc; + + for (tc = 0; tc < TC_MAX_QUEUE; tc++) { + n = nla_nest_start(skb, TCA_TAPRIO_ATTR_TC_ENTRY); + if (!n) + return -EMSGSIZE; + + if (nla_put_u32(skb, TCA_TAPRIO_TC_ENTRY_INDEX, tc)) + goto nla_put_failure; + + if (nla_put_u32(skb, TCA_TAPRIO_TC_ENTRY_MAX_SDU, + q->max_sdu[tc])) + goto nla_put_failure; + + nla_nest_end(skb, n); + } + + return 0; + +nla_put_failure: + nla_nest_cancel(skb, n); + return -EMSGSIZE; +} + +static int taprio_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct taprio_sched *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct sched_gate_list *oper, *admin; + struct tc_mqprio_qopt opt = { 0 }; + struct nlattr *nest, *sched_nest; + unsigned int i; + + oper = rtnl_dereference(q->oper_sched); + admin = rtnl_dereference(q->admin_sched); + + opt.num_tc = netdev_get_num_tc(dev); + memcpy(opt.prio_tc_map, dev->prio_tc_map, sizeof(opt.prio_tc_map)); + + for (i = 0; i < netdev_get_num_tc(dev); i++) { + opt.count[i] = dev->tc_to_txq[i].count; + opt.offset[i] = dev->tc_to_txq[i].offset; + } + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (!nest) + goto start_error; + + if (nla_put(skb, TCA_TAPRIO_ATTR_PRIOMAP, sizeof(opt), &opt)) + goto options_error; + + if (!FULL_OFFLOAD_IS_ENABLED(q->flags) && + nla_put_s32(skb, TCA_TAPRIO_ATTR_SCHED_CLOCKID, q->clockid)) + goto options_error; + + if (q->flags && nla_put_u32(skb, TCA_TAPRIO_ATTR_FLAGS, q->flags)) + goto options_error; + + if (q->txtime_delay && + nla_put_u32(skb, TCA_TAPRIO_ATTR_TXTIME_DELAY, q->txtime_delay)) + goto options_error; + + if (taprio_dump_tc_entries(q, skb)) + goto options_error; + + if (oper && dump_schedule(skb, oper)) + goto options_error; + + if (!admin) + goto done; + + sched_nest = nla_nest_start_noflag(skb, TCA_TAPRIO_ATTR_ADMIN_SCHED); + if (!sched_nest) + goto options_error; + + if (dump_schedule(skb, admin)) + goto admin_error; + + nla_nest_end(skb, sched_nest); + +done: + return nla_nest_end(skb, nest); + +admin_error: + nla_nest_cancel(skb, sched_nest); + +options_error: + nla_nest_cancel(skb, nest); + +start_error: + return -ENOSPC; +} + +static struct Qdisc *taprio_leaf(struct Qdisc *sch, unsigned long cl) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + if (!dev_queue) + return NULL; + + return rtnl_dereference(dev_queue->qdisc_sleeping); +} + +static unsigned long taprio_find(struct Qdisc *sch, u32 classid) +{ + unsigned int ntx = TC_H_MIN(classid); + + if (!taprio_queue_get(sch, ntx)) + return 0; + return ntx; +} + +static int taprio_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + tcm->tcm_parent = TC_H_ROOT; + tcm->tcm_handle |= TC_H_MIN(cl); + tcm->tcm_info = 
rtnl_dereference(dev_queue->qdisc_sleeping)->handle; + + return 0; +} + +static int taprio_dump_class_stats(struct Qdisc *sch, unsigned long cl, + struct gnet_dump *d) + __releases(d->lock) + __acquires(d->lock) +{ + struct netdev_queue *dev_queue = taprio_queue_get(sch, cl); + + sch = rtnl_dereference(dev_queue->qdisc_sleeping); + if (gnet_stats_copy_basic(d, NULL, &sch->bstats, true) < 0 || + qdisc_qstats_copy(d, sch) < 0) + return -1; + return 0; +} + +static void taprio_walk(struct Qdisc *sch, struct qdisc_walker *arg) +{ + struct net_device *dev = qdisc_dev(sch); + unsigned long ntx; + + if (arg->stop) + return; + + arg->count = arg->skip; + for (ntx = arg->skip; ntx < dev->num_tx_queues; ntx++) { + if (!tc_qdisc_stats_dump(sch, ntx + 1, arg)) + break; + } +} + +static struct netdev_queue *taprio_select_queue(struct Qdisc *sch, + struct tcmsg *tcm) +{ + return taprio_queue_get(sch, TC_H_MIN(tcm->tcm_parent)); +} + +static const struct Qdisc_class_ops taprio_class_ops = { + .graft = taprio_graft, + .leaf = taprio_leaf, + .find = taprio_find, + .walk = taprio_walk, + .dump = taprio_dump_class, + .dump_stats = taprio_dump_class_stats, + .select_queue = taprio_select_queue, +}; + +static struct Qdisc_ops taprio_qdisc_ops __read_mostly = { + .cl_ops = &taprio_class_ops, + .id = "taprio", + .priv_size = sizeof(struct taprio_sched), + .init = taprio_init, + .change = taprio_change, + .destroy = taprio_destroy, + .reset = taprio_reset, + .attach = taprio_attach, + .peek = taprio_peek, + .dequeue = taprio_dequeue, + .enqueue = taprio_enqueue, + .dump = taprio_dump, + .owner = THIS_MODULE, +}; + +static struct notifier_block taprio_device_notifier = { + .notifier_call = taprio_dev_notifier, +}; + +static int __init taprio_module_init(void) +{ + int err = register_netdevice_notifier(&taprio_device_notifier); + + if (err) + return err; + + return register_qdisc(&taprio_qdisc_ops); +} + +static void __exit taprio_module_exit(void) +{ + unregister_qdisc(&taprio_qdisc_ops); + unregister_netdevice_notifier(&taprio_device_notifier); +} + +module_init(taprio_module_init); +module_exit(taprio_module_exit); +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_tbf.c b/net/sched/sch_tbf.c new file mode 100644 index 000000000..277ad11f4 --- /dev/null +++ b/net/sched/sch_tbf.c @@ -0,0 +1,622 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/sched/sch_tbf.c Token Bucket Filter queue. + * + * Authors: Alexey Kuznetsov, + * Dmitry Torokhov - allow attaching inner qdiscs - + * original idea by Martin Devera + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +/* Simple Token Bucket Filter. + ======================================= + + SOURCE. + ------- + + None. + + Description. + ------------ + + A data flow obeys TBF with rate R and depth B, if for any + time interval t_i...t_f the number of transmitted bits + does not exceed B + R*(t_f-t_i). + + Packetized version of this definition: + The sequence of packets of sizes s_i served at moments t_i + obeys TBF, if for any i<=k: + + s_i+....+s_k <= B + R*(t_k - t_i) + + Algorithm. + ---------- + + Let N(t_i) be B/R initially and N(t) grow continuously with time as: + + N(t+delta) = min{B/R, N(t) + delta} + + If the first packet in queue has length S, it may be + transmitted only at the time t_* when S/R <= N(t_*), + and in this case N(t) jumps: + + N(t_* + 0) = N(t_* - 0) - S/R. + + + + Actually, QoS requires two TBF to be applied to a data stream. 
+ One of them controls steady state burst size, another + one with rate P (peak rate) and depth M (equal to link MTU) + limits bursts at a smaller time scale. + + It is easy to see that P>R, and B>M. If P is infinity, this double + TBF is equivalent to a single one. + + When TBF works in reshaping mode, latency is estimated as: + + lat = max ((L-B)/R, (L-M)/P) + + + NOTES. + ------ + + If TBF throttles, it starts a watchdog timer, which will wake it up + when it is ready to transmit. + Note that the minimal timer resolution is 1/HZ. + If no new packets arrive during this period, + or if the device is not awaken by EOI for some previous packet, + TBF can stop its activity for 1/HZ. + + + This means, that with depth B, the maximal rate is + + R_crit = B*HZ + + F.e. for 10Mbit ethernet and HZ=100 the minimal allowed B is ~10Kbytes. + + Note that the peak rate TBF is much more tough: with MTU 1500 + P_crit = 150Kbytes/sec. So, if you need greater peak + rates, use alpha with HZ=1000 :-) + + With classful TBF, limit is just kept for backwards compatibility. + It is passed to the default bfifo qdisc - if the inner qdisc is + changed the limit is not effective anymore. +*/ + +struct tbf_sched_data { +/* Parameters */ + u32 limit; /* Maximal length of backlog: bytes */ + u32 max_size; + s64 buffer; /* Token bucket depth/rate: MUST BE >= MTU/B */ + s64 mtu; + struct psched_ratecfg rate; + struct psched_ratecfg peak; + +/* Variables */ + s64 tokens; /* Current number of B tokens */ + s64 ptokens; /* Current number of P tokens */ + s64 t_c; /* Time check-point */ + struct Qdisc *qdisc; /* Inner qdisc, default - bfifo queue */ + struct qdisc_watchdog watchdog; /* Watchdog timer */ +}; + + +/* Time to Length, convert time in ns to length in bytes + * to determinate how many bytes can be sent in given time. 
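+ * For example, at rate_bytes_ps = 125000000 (1 Gbit/s) a budget of
+ * 100000 ns corresponds to 125000000 * 100000 / 10^9 = 12500 bytes,
+ * before the ATM cell and overhead corrections applied below.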
+ */ +static u64 psched_ns_t2l(const struct psched_ratecfg *r, + u64 time_in_ns) +{ + /* The formula is : + * len = (time_in_ns * r->rate_bytes_ps) / NSEC_PER_SEC + */ + u64 len = time_in_ns * r->rate_bytes_ps; + + do_div(len, NSEC_PER_SEC); + + if (unlikely(r->linklayer == TC_LINKLAYER_ATM)) { + do_div(len, 53); + len = len * 48; + } + + if (len > r->overhead) + len -= r->overhead; + else + len = 0; + + return len; +} + +static void tbf_offload_change(struct Qdisc *sch) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + struct net_device *dev = qdisc_dev(sch); + struct tc_tbf_qopt_offload qopt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_TBF_REPLACE; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.replace_params.rate = q->rate; + qopt.replace_params.max_size = q->max_size; + qopt.replace_params.qstats = &sch->qstats; + + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_TBF, &qopt); +} + +static void tbf_offload_destroy(struct Qdisc *sch) +{ + struct net_device *dev = qdisc_dev(sch); + struct tc_tbf_qopt_offload qopt; + + if (!tc_can_offload(dev) || !dev->netdev_ops->ndo_setup_tc) + return; + + qopt.command = TC_TBF_DESTROY; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + dev->netdev_ops->ndo_setup_tc(dev, TC_SETUP_QDISC_TBF, &qopt); +} + +static int tbf_offload_dump(struct Qdisc *sch) +{ + struct tc_tbf_qopt_offload qopt; + + qopt.command = TC_TBF_STATS; + qopt.handle = sch->handle; + qopt.parent = sch->parent; + qopt.stats.bstats = &sch->bstats; + qopt.stats.qstats = &sch->qstats; + + return qdisc_offload_dump_helper(sch, TC_SETUP_QDISC_TBF, &qopt); +} + +static void tbf_offload_graft(struct Qdisc *sch, struct Qdisc *new, + struct Qdisc *old, struct netlink_ext_ack *extack) +{ + struct tc_tbf_qopt_offload graft_offload = { + .handle = sch->handle, + .parent = sch->parent, + .child_handle = new->handle, + .command = TC_TBF_GRAFT, + }; + + qdisc_offload_graft_helper(qdisc_dev(sch), sch, new, old, + TC_SETUP_QDISC_TBF, &graft_offload, extack); +} + +/* GSO packet is too big, segment it so that tbf can transmit + * each segment in time + */ +static int tbf_segment(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + struct sk_buff *segs, *nskb; + netdev_features_t features = netif_skb_features(skb); + unsigned int len = 0, prev_len = qdisc_pkt_len(skb); + int ret, nb; + + segs = skb_gso_segment(skb, features & ~NETIF_F_GSO_MASK); + + if (IS_ERR_OR_NULL(segs)) + return qdisc_drop(skb, sch, to_free); + + nb = 0; + skb_list_walk_safe(segs, segs, nskb) { + skb_mark_not_on_list(segs); + qdisc_skb_cb(segs)->pkt_len = segs->len; + len += segs->len; + ret = qdisc_enqueue(segs, q->qdisc, to_free); + if (ret != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(ret)) + qdisc_qstats_drop(sch); + } else { + nb++; + } + } + sch->q.qlen += nb; + if (nb > 1) + qdisc_tree_reduce_backlog(sch, 1 - nb, prev_len - len); + consume_skb(skb); + return nb > 0 ? 
NET_XMIT_SUCCESS : NET_XMIT_DROP; +} + +static int tbf_enqueue(struct sk_buff *skb, struct Qdisc *sch, + struct sk_buff **to_free) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + unsigned int len = qdisc_pkt_len(skb); + int ret; + + if (qdisc_pkt_len(skb) > q->max_size) { + if (skb_is_gso(skb) && + skb_gso_validate_mac_len(skb, q->max_size)) + return tbf_segment(skb, sch, to_free); + return qdisc_drop(skb, sch, to_free); + } + ret = qdisc_enqueue(skb, q->qdisc, to_free); + if (ret != NET_XMIT_SUCCESS) { + if (net_xmit_drop_count(ret)) + qdisc_qstats_drop(sch); + return ret; + } + + sch->qstats.backlog += len; + sch->q.qlen++; + return NET_XMIT_SUCCESS; +} + +static bool tbf_peak_present(const struct tbf_sched_data *q) +{ + return q->peak.rate_bytes_ps; +} + +static struct sk_buff *tbf_dequeue(struct Qdisc *sch) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + struct sk_buff *skb; + + skb = q->qdisc->ops->peek(q->qdisc); + + if (skb) { + s64 now; + s64 toks; + s64 ptoks = 0; + unsigned int len = qdisc_pkt_len(skb); + + now = ktime_get_ns(); + toks = min_t(s64, now - q->t_c, q->buffer); + + if (tbf_peak_present(q)) { + ptoks = toks + q->ptokens; + if (ptoks > q->mtu) + ptoks = q->mtu; + ptoks -= (s64) psched_l2t_ns(&q->peak, len); + } + toks += q->tokens; + if (toks > q->buffer) + toks = q->buffer; + toks -= (s64) psched_l2t_ns(&q->rate, len); + + if ((toks|ptoks) >= 0) { + skb = qdisc_dequeue_peeked(q->qdisc); + if (unlikely(!skb)) + return NULL; + + q->t_c = now; + q->tokens = toks; + q->ptokens = ptoks; + qdisc_qstats_backlog_dec(sch, skb); + sch->q.qlen--; + qdisc_bstats_update(sch, skb); + return skb; + } + + qdisc_watchdog_schedule_ns(&q->watchdog, + now + max_t(long, -toks, -ptoks)); + + /* Maybe we have a shorter packet in the queue, + which can be sent now. It sounds cool, + but, however, this is wrong in principle. + We MUST NOT reorder packets under these circumstances. + + Really, if we split the flow into independent + subflows, it would be a very good solution. + This is the main idea of all FQ algorithms + (cf. 
CSZ, HPFQ, HFSC) + */ + + qdisc_qstats_overlimit(sch); + } + return NULL; +} + +static void tbf_reset(struct Qdisc *sch) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + + qdisc_reset(q->qdisc); + q->t_c = ktime_get_ns(); + q->tokens = q->buffer; + q->ptokens = q->mtu; + qdisc_watchdog_cancel(&q->watchdog); +} + +static const struct nla_policy tbf_policy[TCA_TBF_MAX + 1] = { + [TCA_TBF_PARMS] = { .len = sizeof(struct tc_tbf_qopt) }, + [TCA_TBF_RTAB] = { .type = NLA_BINARY, .len = TC_RTAB_SIZE }, + [TCA_TBF_PTAB] = { .type = NLA_BINARY, .len = TC_RTAB_SIZE }, + [TCA_TBF_RATE64] = { .type = NLA_U64 }, + [TCA_TBF_PRATE64] = { .type = NLA_U64 }, + [TCA_TBF_BURST] = { .type = NLA_U32 }, + [TCA_TBF_PBURST] = { .type = NLA_U32 }, +}; + +static int tbf_change(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + int err; + struct tbf_sched_data *q = qdisc_priv(sch); + struct nlattr *tb[TCA_TBF_MAX + 1]; + struct tc_tbf_qopt *qopt; + struct Qdisc *child = NULL; + struct Qdisc *old = NULL; + struct psched_ratecfg rate; + struct psched_ratecfg peak; + u64 max_size; + s64 buffer, mtu; + u64 rate64 = 0, prate64 = 0; + + err = nla_parse_nested_deprecated(tb, TCA_TBF_MAX, opt, tbf_policy, + NULL); + if (err < 0) + return err; + + err = -EINVAL; + if (tb[TCA_TBF_PARMS] == NULL) + goto done; + + qopt = nla_data(tb[TCA_TBF_PARMS]); + if (qopt->rate.linklayer == TC_LINKLAYER_UNAWARE) + qdisc_put_rtab(qdisc_get_rtab(&qopt->rate, + tb[TCA_TBF_RTAB], + NULL)); + + if (qopt->peakrate.linklayer == TC_LINKLAYER_UNAWARE) + qdisc_put_rtab(qdisc_get_rtab(&qopt->peakrate, + tb[TCA_TBF_PTAB], + NULL)); + + buffer = min_t(u64, PSCHED_TICKS2NS(qopt->buffer), ~0U); + mtu = min_t(u64, PSCHED_TICKS2NS(qopt->mtu), ~0U); + + if (tb[TCA_TBF_RATE64]) + rate64 = nla_get_u64(tb[TCA_TBF_RATE64]); + psched_ratecfg_precompute(&rate, &qopt->rate, rate64); + + if (tb[TCA_TBF_BURST]) { + max_size = nla_get_u32(tb[TCA_TBF_BURST]); + buffer = psched_l2t_ns(&rate, max_size); + } else { + max_size = min_t(u64, psched_ns_t2l(&rate, buffer), ~0U); + } + + if (qopt->peakrate.rate) { + if (tb[TCA_TBF_PRATE64]) + prate64 = nla_get_u64(tb[TCA_TBF_PRATE64]); + psched_ratecfg_precompute(&peak, &qopt->peakrate, prate64); + if (peak.rate_bytes_ps <= rate.rate_bytes_ps) { + pr_warn_ratelimited("sch_tbf: peakrate %llu is lower than or equals to rate %llu !\n", + peak.rate_bytes_ps, rate.rate_bytes_ps); + err = -EINVAL; + goto done; + } + + if (tb[TCA_TBF_PBURST]) { + u32 pburst = nla_get_u32(tb[TCA_TBF_PBURST]); + max_size = min_t(u32, max_size, pburst); + mtu = psched_l2t_ns(&peak, pburst); + } else { + max_size = min_t(u64, max_size, psched_ns_t2l(&peak, mtu)); + } + } else { + memset(&peak, 0, sizeof(peak)); + } + + if (max_size < psched_mtu(qdisc_dev(sch))) + pr_warn_ratelimited("sch_tbf: burst %llu is lower than device %s mtu (%u) !\n", + max_size, qdisc_dev(sch)->name, + psched_mtu(qdisc_dev(sch))); + + if (!max_size) { + err = -EINVAL; + goto done; + } + + if (q->qdisc != &noop_qdisc) { + err = fifo_set_limit(q->qdisc, qopt->limit); + if (err) + goto done; + } else if (qopt->limit > 0) { + child = fifo_create_dflt(sch, &bfifo_qdisc_ops, qopt->limit, + extack); + if (IS_ERR(child)) { + err = PTR_ERR(child); + goto done; + } + + /* child is fifo, no need to check for noop_qdisc */ + qdisc_hash_add(child, true); + } + + sch_tree_lock(sch); + if (child) { + qdisc_tree_flush_backlog(q->qdisc); + old = q->qdisc; + q->qdisc = child; + } + q->limit = qopt->limit; + if (tb[TCA_TBF_PBURST]) + q->mtu = mtu; + else + q->mtu = 
PSCHED_TICKS2NS(qopt->mtu); + q->max_size = max_size; + if (tb[TCA_TBF_BURST]) + q->buffer = buffer; + else + q->buffer = PSCHED_TICKS2NS(qopt->buffer); + q->tokens = q->buffer; + q->ptokens = q->mtu; + + memcpy(&q->rate, &rate, sizeof(struct psched_ratecfg)); + memcpy(&q->peak, &peak, sizeof(struct psched_ratecfg)); + + sch_tree_unlock(sch); + qdisc_put(old); + err = 0; + + tbf_offload_change(sch); +done: + return err; +} + +static int tbf_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + + qdisc_watchdog_init(&q->watchdog, sch); + q->qdisc = &noop_qdisc; + + if (!opt) + return -EINVAL; + + q->t_c = ktime_get_ns(); + + return tbf_change(sch, opt, extack); +} + +static void tbf_destroy(struct Qdisc *sch) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + + qdisc_watchdog_cancel(&q->watchdog); + tbf_offload_destroy(sch); + qdisc_put(q->qdisc); +} + +static int tbf_dump(struct Qdisc *sch, struct sk_buff *skb) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + struct nlattr *nest; + struct tc_tbf_qopt opt; + int err; + + err = tbf_offload_dump(sch); + if (err) + return err; + + nest = nla_nest_start_noflag(skb, TCA_OPTIONS); + if (nest == NULL) + goto nla_put_failure; + + opt.limit = q->limit; + psched_ratecfg_getrate(&opt.rate, &q->rate); + if (tbf_peak_present(q)) + psched_ratecfg_getrate(&opt.peakrate, &q->peak); + else + memset(&opt.peakrate, 0, sizeof(opt.peakrate)); + opt.mtu = PSCHED_NS2TICKS(q->mtu); + opt.buffer = PSCHED_NS2TICKS(q->buffer); + if (nla_put(skb, TCA_TBF_PARMS, sizeof(opt), &opt)) + goto nla_put_failure; + if (q->rate.rate_bytes_ps >= (1ULL << 32) && + nla_put_u64_64bit(skb, TCA_TBF_RATE64, q->rate.rate_bytes_ps, + TCA_TBF_PAD)) + goto nla_put_failure; + if (tbf_peak_present(q) && + q->peak.rate_bytes_ps >= (1ULL << 32) && + nla_put_u64_64bit(skb, TCA_TBF_PRATE64, q->peak.rate_bytes_ps, + TCA_TBF_PAD)) + goto nla_put_failure; + + return nla_nest_end(skb, nest); + +nla_put_failure: + nla_nest_cancel(skb, nest); + return -1; +} + +static int tbf_dump_class(struct Qdisc *sch, unsigned long cl, + struct sk_buff *skb, struct tcmsg *tcm) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + + tcm->tcm_handle |= TC_H_MIN(1); + tcm->tcm_info = q->qdisc->handle; + + return 0; +} + +static int tbf_graft(struct Qdisc *sch, unsigned long arg, struct Qdisc *new, + struct Qdisc **old, struct netlink_ext_ack *extack) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + + if (new == NULL) + new = &noop_qdisc; + + *old = qdisc_replace(sch, new, &q->qdisc); + + tbf_offload_graft(sch, new, *old, extack); + return 0; +} + +static struct Qdisc *tbf_leaf(struct Qdisc *sch, unsigned long arg) +{ + struct tbf_sched_data *q = qdisc_priv(sch); + return q->qdisc; +} + +static unsigned long tbf_find(struct Qdisc *sch, u32 classid) +{ + return 1; +} + +static void tbf_walk(struct Qdisc *sch, struct qdisc_walker *walker) +{ + if (!walker->stop) { + tc_qdisc_stats_dump(sch, 1, walker); + } +} + +static const struct Qdisc_class_ops tbf_class_ops = { + .graft = tbf_graft, + .leaf = tbf_leaf, + .find = tbf_find, + .walk = tbf_walk, + .dump = tbf_dump_class, +}; + +static struct Qdisc_ops tbf_qdisc_ops __read_mostly = { + .next = NULL, + .cl_ops = &tbf_class_ops, + .id = "tbf", + .priv_size = sizeof(struct tbf_sched_data), + .enqueue = tbf_enqueue, + .dequeue = tbf_dequeue, + .peek = qdisc_peek_dequeued, + .init = tbf_init, + .reset = tbf_reset, + .destroy = tbf_destroy, + .change = tbf_change, + .dump = tbf_dump, + .owner = 
THIS_MODULE, +}; + +static int __init tbf_module_init(void) +{ + return register_qdisc(&tbf_qdisc_ops); +} + +static void __exit tbf_module_exit(void) +{ + unregister_qdisc(&tbf_qdisc_ops); +} +module_init(tbf_module_init) +module_exit(tbf_module_exit) +MODULE_LICENSE("GPL"); diff --git a/net/sched/sch_teql.c b/net/sched/sch_teql.c new file mode 100644 index 000000000..7721239c1 --- /dev/null +++ b/net/sched/sch_teql.c @@ -0,0 +1,525 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* net/sched/sch_teql.c "True" (or "trivial") link equalizer. + * + * Authors: Alexey Kuznetsov, + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* + How to setup it. + ---------------- + + After loading this module you will find a new device teqlN + and new qdisc with the same name. To join a slave to the equalizer + you should just set this qdisc on a device f.e. + + # tc qdisc add dev eth0 root teql0 + # tc qdisc add dev eth1 root teql0 + + That's all. Full PnP 8) + + Applicability. + -------------- + + 1. Slave devices MUST be active devices, i.e., they must raise the tbusy + signal and generate EOI events. If you want to equalize virtual devices + like tunnels, use a normal eql device. + 2. This device puts no limitations on physical slave characteristics + f.e. it will equalize 9600baud line and 100Mb ethernet perfectly :-) + Certainly, large difference in link speeds will make the resulting + eqalized link unusable, because of huge packet reordering. + I estimate an upper useful difference as ~10 times. + 3. If the slave requires address resolution, only protocols using + neighbour cache (IPv4/IPv6) will work over the equalized link. + Other protocols are still allowed to use the slave device directly, + which will not break load balancing, though native slave + traffic will have the highest priority. 
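+ *
+ *   The number of teqlN master devices is fixed at module load time via the
+ *   max_equalizers parameter (default 1), e.g.
+ *
+ *   # modprobe sch_teql max_equalizers=2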
*/ + +struct teql_master { + struct Qdisc_ops qops; + struct net_device *dev; + struct Qdisc *slaves; + struct list_head master_list; + unsigned long tx_bytes; + unsigned long tx_packets; + unsigned long tx_errors; + unsigned long tx_dropped; +}; + +struct teql_sched_data { + struct Qdisc *next; + struct teql_master *m; + struct sk_buff_head q; +}; + +#define NEXT_SLAVE(q) (((struct teql_sched_data *)qdisc_priv(q))->next) + +#define FMASK (IFF_BROADCAST | IFF_POINTOPOINT) + +/* "teql*" qdisc routines */ + +static int +teql_enqueue(struct sk_buff *skb, struct Qdisc *sch, struct sk_buff **to_free) +{ + struct net_device *dev = qdisc_dev(sch); + struct teql_sched_data *q = qdisc_priv(sch); + + if (q->q.qlen < dev->tx_queue_len) { + __skb_queue_tail(&q->q, skb); + return NET_XMIT_SUCCESS; + } + + return qdisc_drop(skb, sch, to_free); +} + +static struct sk_buff * +teql_dequeue(struct Qdisc *sch) +{ + struct teql_sched_data *dat = qdisc_priv(sch); + struct netdev_queue *dat_queue; + struct sk_buff *skb; + struct Qdisc *q; + + skb = __skb_dequeue(&dat->q); + dat_queue = netdev_get_tx_queue(dat->m->dev, 0); + q = rcu_dereference_bh(dat_queue->qdisc); + + if (skb == NULL) { + struct net_device *m = qdisc_dev(q); + if (m) { + dat->m->slaves = sch; + netif_wake_queue(m); + } + } else { + qdisc_bstats_update(sch, skb); + } + sch->q.qlen = dat->q.qlen + q->q.qlen; + return skb; +} + +static struct sk_buff * +teql_peek(struct Qdisc *sch) +{ + /* teql is meant to be used as root qdisc */ + return NULL; +} + +static void +teql_reset(struct Qdisc *sch) +{ + struct teql_sched_data *dat = qdisc_priv(sch); + + skb_queue_purge(&dat->q); +} + +static void +teql_destroy(struct Qdisc *sch) +{ + struct Qdisc *q, *prev; + struct teql_sched_data *dat = qdisc_priv(sch); + struct teql_master *master = dat->m; + + if (!master) + return; + + prev = master->slaves; + if (prev) { + do { + q = NEXT_SLAVE(prev); + if (q == sch) { + NEXT_SLAVE(prev) = NEXT_SLAVE(q); + if (q == master->slaves) { + master->slaves = NEXT_SLAVE(q); + if (q == master->slaves) { + struct netdev_queue *txq; + spinlock_t *root_lock; + + txq = netdev_get_tx_queue(master->dev, 0); + master->slaves = NULL; + + root_lock = qdisc_root_sleeping_lock(rtnl_dereference(txq->qdisc)); + spin_lock_bh(root_lock); + qdisc_reset(rtnl_dereference(txq->qdisc)); + spin_unlock_bh(root_lock); + } + } + skb_queue_purge(&dat->q); + break; + } + + } while ((prev = q) != master->slaves); + } +} + +static int teql_qdisc_init(struct Qdisc *sch, struct nlattr *opt, + struct netlink_ext_ack *extack) +{ + struct net_device *dev = qdisc_dev(sch); + struct teql_master *m = (struct teql_master *)sch->ops; + struct teql_sched_data *q = qdisc_priv(sch); + + if (dev->hard_header_len > m->dev->hard_header_len) + return -EINVAL; + + if (m->dev == dev) + return -ELOOP; + + q->m = m; + + skb_queue_head_init(&q->q); + + if (m->slaves) { + if (m->dev->flags & IFF_UP) { + if ((m->dev->flags & IFF_POINTOPOINT && + !(dev->flags & IFF_POINTOPOINT)) || + (m->dev->flags & IFF_BROADCAST && + !(dev->flags & IFF_BROADCAST)) || + (m->dev->flags & IFF_MULTICAST && + !(dev->flags & IFF_MULTICAST)) || + dev->mtu < m->dev->mtu) + return -EINVAL; + } else { + if (!(dev->flags&IFF_POINTOPOINT)) + m->dev->flags &= ~IFF_POINTOPOINT; + if (!(dev->flags&IFF_BROADCAST)) + m->dev->flags &= ~IFF_BROADCAST; + if (!(dev->flags&IFF_MULTICAST)) + m->dev->flags &= ~IFF_MULTICAST; + if (dev->mtu < m->dev->mtu) + m->dev->mtu = dev->mtu; + } + q->next = NEXT_SLAVE(m->slaves); + NEXT_SLAVE(m->slaves) = sch; + } else { + 
q->next = sch; + m->slaves = sch; + m->dev->mtu = dev->mtu; + m->dev->flags = (m->dev->flags&~FMASK)|(dev->flags&FMASK); + } + return 0; +} + + +static int +__teql_resolve(struct sk_buff *skb, struct sk_buff *skb_res, + struct net_device *dev, struct netdev_queue *txq, + struct dst_entry *dst) +{ + struct neighbour *n; + int err = 0; + + n = dst_neigh_lookup_skb(dst, skb); + if (!n) + return -ENOENT; + + if (dst->dev != dev) { + struct neighbour *mn; + + mn = __neigh_lookup_errno(n->tbl, n->primary_key, dev); + neigh_release(n); + if (IS_ERR(mn)) + return PTR_ERR(mn); + n = mn; + } + + if (neigh_event_send(n, skb_res) == 0) { + int err; + char haddr[MAX_ADDR_LEN]; + + neigh_ha_snapshot(haddr, n, dev); + err = dev_hard_header(skb, dev, ntohs(skb_protocol(skb, false)), + haddr, NULL, skb->len); + + if (err < 0) + err = -EINVAL; + } else { + err = (skb_res == NULL) ? -EAGAIN : 1; + } + neigh_release(n); + return err; +} + +static inline int teql_resolve(struct sk_buff *skb, + struct sk_buff *skb_res, + struct net_device *dev, + struct netdev_queue *txq) +{ + struct dst_entry *dst = skb_dst(skb); + int res; + + if (rcu_access_pointer(txq->qdisc) == &noop_qdisc) + return -ENODEV; + + if (!dev->header_ops || !dst) + return 0; + + rcu_read_lock(); + res = __teql_resolve(skb, skb_res, dev, txq, dst); + rcu_read_unlock(); + + return res; +} + +static netdev_tx_t teql_master_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct teql_master *master = netdev_priv(dev); + struct Qdisc *start, *q; + int busy; + int nores; + int subq = skb_get_queue_mapping(skb); + struct sk_buff *skb_res = NULL; + + start = master->slaves; + +restart: + nores = 0; + busy = 0; + + q = start; + if (!q) + goto drop; + + do { + struct net_device *slave = qdisc_dev(q); + struct netdev_queue *slave_txq = netdev_get_tx_queue(slave, 0); + + if (rcu_access_pointer(slave_txq->qdisc_sleeping) != q) + continue; + if (netif_xmit_stopped(netdev_get_tx_queue(slave, subq)) || + !netif_running(slave)) { + busy = 1; + continue; + } + + switch (teql_resolve(skb, skb_res, slave, slave_txq)) { + case 0: + if (__netif_tx_trylock(slave_txq)) { + unsigned int length = qdisc_pkt_len(skb); + + if (!netif_xmit_frozen_or_stopped(slave_txq) && + netdev_start_xmit(skb, slave, slave_txq, false) == + NETDEV_TX_OK) { + __netif_tx_unlock(slave_txq); + master->slaves = NEXT_SLAVE(q); + netif_wake_queue(dev); + master->tx_packets++; + master->tx_bytes += length; + return NETDEV_TX_OK; + } + __netif_tx_unlock(slave_txq); + } + if (netif_xmit_stopped(netdev_get_tx_queue(dev, 0))) + busy = 1; + break; + case 1: + master->slaves = NEXT_SLAVE(q); + return NETDEV_TX_OK; + default: + nores = 1; + break; + } + __skb_pull(skb, skb_network_offset(skb)); + } while ((q = NEXT_SLAVE(q)) != start); + + if (nores && skb_res == NULL) { + skb_res = skb; + goto restart; + } + + if (busy) { + netif_stop_queue(dev); + return NETDEV_TX_BUSY; + } + master->tx_errors++; + +drop: + master->tx_dropped++; + dev_kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int teql_master_open(struct net_device *dev) +{ + struct Qdisc *q; + struct teql_master *m = netdev_priv(dev); + int mtu = 0xFFFE; + unsigned int flags = IFF_NOARP | IFF_MULTICAST; + + if (m->slaves == NULL) + return -EUNATCH; + + flags = FMASK; + + q = m->slaves; + do { + struct net_device *slave = qdisc_dev(q); + + if (slave == NULL) + return -EUNATCH; + + if (slave->mtu < mtu) + mtu = slave->mtu; + if (slave->hard_header_len > LL_MAX_HEADER) + return -EINVAL; + + /* If all the slaves are BROADCAST, master is 
BROADCAST + If all the slaves are PtP, master is PtP + Otherwise, master is NBMA. + */ + if (!(slave->flags&IFF_POINTOPOINT)) + flags &= ~IFF_POINTOPOINT; + if (!(slave->flags&IFF_BROADCAST)) + flags &= ~IFF_BROADCAST; + if (!(slave->flags&IFF_MULTICAST)) + flags &= ~IFF_MULTICAST; + } while ((q = NEXT_SLAVE(q)) != m->slaves); + + m->dev->mtu = mtu; + m->dev->flags = (m->dev->flags&~FMASK) | flags; + netif_start_queue(m->dev); + return 0; +} + +static int teql_master_close(struct net_device *dev) +{ + netif_stop_queue(dev); + return 0; +} + +static void teql_master_stats64(struct net_device *dev, + struct rtnl_link_stats64 *stats) +{ + struct teql_master *m = netdev_priv(dev); + + stats->tx_packets = m->tx_packets; + stats->tx_bytes = m->tx_bytes; + stats->tx_errors = m->tx_errors; + stats->tx_dropped = m->tx_dropped; +} + +static int teql_master_mtu(struct net_device *dev, int new_mtu) +{ + struct teql_master *m = netdev_priv(dev); + struct Qdisc *q; + + q = m->slaves; + if (q) { + do { + if (new_mtu > qdisc_dev(q)->mtu) + return -EINVAL; + } while ((q = NEXT_SLAVE(q)) != m->slaves); + } + + dev->mtu = new_mtu; + return 0; +} + +static const struct net_device_ops teql_netdev_ops = { + .ndo_open = teql_master_open, + .ndo_stop = teql_master_close, + .ndo_start_xmit = teql_master_xmit, + .ndo_get_stats64 = teql_master_stats64, + .ndo_change_mtu = teql_master_mtu, +}; + +static __init void teql_master_setup(struct net_device *dev) +{ + struct teql_master *master = netdev_priv(dev); + struct Qdisc_ops *ops = &master->qops; + + master->dev = dev; + ops->priv_size = sizeof(struct teql_sched_data); + + ops->enqueue = teql_enqueue; + ops->dequeue = teql_dequeue; + ops->peek = teql_peek; + ops->init = teql_qdisc_init; + ops->reset = teql_reset; + ops->destroy = teql_destroy; + ops->owner = THIS_MODULE; + + dev->netdev_ops = &teql_netdev_ops; + dev->type = ARPHRD_VOID; + dev->mtu = 1500; + dev->min_mtu = 68; + dev->max_mtu = 65535; + dev->tx_queue_len = 100; + dev->flags = IFF_NOARP; + dev->hard_header_len = LL_MAX_HEADER; + netif_keep_dst(dev); +} + +static LIST_HEAD(master_dev_list); +static int max_equalizers = 1; +module_param(max_equalizers, int, 0); +MODULE_PARM_DESC(max_equalizers, "Max number of link equalizers"); + +static int __init teql_init(void) +{ + int i; + int err = -ENODEV; + + for (i = 0; i < max_equalizers; i++) { + struct net_device *dev; + struct teql_master *master; + + dev = alloc_netdev(sizeof(struct teql_master), "teql%d", + NET_NAME_UNKNOWN, teql_master_setup); + if (!dev) { + err = -ENOMEM; + break; + } + + if ((err = register_netdev(dev))) { + free_netdev(dev); + break; + } + + master = netdev_priv(dev); + + strscpy(master->qops.id, dev->name, IFNAMSIZ); + err = register_qdisc(&master->qops); + + if (err) { + unregister_netdev(dev); + free_netdev(dev); + break; + } + + list_add_tail(&master->master_list, &master_dev_list); + } + return i ? 
0 : err; +} + +static void __exit teql_exit(void) +{ + struct teql_master *master, *nxt; + + list_for_each_entry_safe(master, nxt, &master_dev_list, master_list) { + + list_del(&master->master_list); + + unregister_qdisc(&master->qops); + unregister_netdev(master->dev); + free_netdev(master->dev); + } +} + +module_init(teql_init); +module_exit(teql_exit); + +MODULE_LICENSE("GPL"); diff --git a/net/sctp/Kconfig b/net/sctp/Kconfig new file mode 100644 index 000000000..5da599ff8 --- /dev/null +++ b/net/sctp/Kconfig @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# SCTP configuration +# + +menuconfig IP_SCTP + tristate "The SCTP Protocol" + depends on INET + depends on IPV6 || IPV6=n + select CRYPTO + select CRYPTO_HMAC + select CRYPTO_SHA1 + select LIBCRC32C + select NET_UDP_TUNNEL + help + Stream Control Transmission Protocol + + From RFC 2960 . + + "SCTP is a reliable transport protocol operating on top of a + connectionless packet network such as IP. It offers the following + services to its users: + + -- acknowledged error-free non-duplicated transfer of user data, + -- data fragmentation to conform to discovered path MTU size, + -- sequenced delivery of user messages within multiple streams, + with an option for order-of-arrival delivery of individual user + messages, + -- optional bundling of multiple user messages into a single SCTP + packet, and + -- network-level fault tolerance through supporting of multi- + homing at either or both ends of an association." + + To compile this protocol support as a module, choose M here: the + module will be called sctp. Debug messages are handled by the + kernel's dynamic debugging framework. + + If in doubt, say N. + +if IP_SCTP + +config SCTP_DBG_OBJCNT + bool "SCTP: Debug object counts" + depends on PROC_FS + help + If you say Y, this will enable debugging support for counting the + type of objects that are currently allocated. This is useful for + identifying memory leaks. 
This debug information can be viewed by + 'cat /proc/net/sctp/sctp_dbg_objcnt' + + If unsure, say N +choice + prompt "Default SCTP cookie HMAC encoding" + default SCTP_DEFAULT_COOKIE_HMAC_MD5 + help + This option sets the default sctp cookie hmac algorithm + when in doubt select 'md5' + +config SCTP_DEFAULT_COOKIE_HMAC_MD5 + bool "Enable optional MD5 hmac cookie generation" + help + Enable optional MD5 hmac based SCTP cookie generation + select SCTP_COOKIE_HMAC_MD5 + +config SCTP_DEFAULT_COOKIE_HMAC_SHA1 + bool "Enable optional SHA1 hmac cookie generation" + help + Enable optional SHA1 hmac based SCTP cookie generation + select SCTP_COOKIE_HMAC_SHA1 + +config SCTP_DEFAULT_COOKIE_HMAC_NONE + bool "Use no hmac alg in SCTP cookie generation" + help + Use no hmac algorithm in SCTP cookie generation + +endchoice + +config SCTP_COOKIE_HMAC_MD5 + bool "Enable optional MD5 hmac cookie generation" + help + Enable optional MD5 hmac based SCTP cookie generation + select CRYPTO_HMAC if SCTP_COOKIE_HMAC_MD5 + select CRYPTO_MD5 if SCTP_COOKIE_HMAC_MD5 + +config SCTP_COOKIE_HMAC_SHA1 + bool "Enable optional SHA1 hmac cookie generation" + help + Enable optional SHA1 hmac based SCTP cookie generation + select CRYPTO_HMAC if SCTP_COOKIE_HMAC_SHA1 + select CRYPTO_SHA1 if SCTP_COOKIE_HMAC_SHA1 + +config INET_SCTP_DIAG + depends on INET_DIAG + def_tristate INET_DIAG + + +endif # IP_SCTP diff --git a/net/sctp/Makefile b/net/sctp/Makefile new file mode 100644 index 000000000..e845e4588 --- /dev/null +++ b/net/sctp/Makefile @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for SCTP support code. +# + +obj-$(CONFIG_IP_SCTP) += sctp.o +obj-$(CONFIG_INET_SCTP_DIAG) += sctp_diag.o + +sctp-y := sm_statetable.o sm_statefuns.o sm_sideeffect.o \ + protocol.o endpointola.o associola.o \ + transport.o chunk.o sm_make_chunk.o ulpevent.o \ + inqueue.o outqueue.o ulpqueue.o \ + tsnmap.o bind_addr.o socket.o primitive.o \ + output.o input.o debug.o stream.o auth.o \ + offload.o stream_sched.o stream_sched_prio.o \ + stream_sched_rr.o stream_interleave.o + +sctp_diag-y := diag.o + +sctp-$(CONFIG_SCTP_DBG_OBJCNT) += objcnt.o +sctp-$(CONFIG_PROC_FS) += proc.o +sctp-$(CONFIG_SYSCTL) += sysctl.o + +sctp-$(subst m,y,$(CONFIG_IPV6)) += ipv6.o diff --git a/net/sctp/associola.c b/net/sctp/associola.c new file mode 100644 index 000000000..2965a12fe --- /dev/null +++ b/net/sctp/associola.c @@ -0,0 +1,1729 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * This module provides the abstraction for an SCTP association. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Xingang Guo + * Hui Huang + * Sridhar Samudrala + * Daisy Chang + * Ryan Layer + * Kevin Gao + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +/* Forward declarations for internal functions. 
*/ +static void sctp_select_active_and_retran_path(struct sctp_association *asoc); +static void sctp_assoc_bh_rcv(struct work_struct *work); +static void sctp_assoc_free_asconf_acks(struct sctp_association *asoc); +static void sctp_assoc_free_asconf_queue(struct sctp_association *asoc); + +/* 1st Level Abstractions. */ + +/* Initialize a new association from provided memory. */ +static struct sctp_association *sctp_association_init( + struct sctp_association *asoc, + const struct sctp_endpoint *ep, + const struct sock *sk, + enum sctp_scope scope, gfp_t gfp) +{ + struct sctp_sock *sp; + struct sctp_paramhdr *p; + int i; + + /* Retrieve the SCTP per socket area. */ + sp = sctp_sk((struct sock *)sk); + + /* Discarding const is appropriate here. */ + asoc->ep = (struct sctp_endpoint *)ep; + asoc->base.sk = (struct sock *)sk; + asoc->base.net = sock_net(sk); + + sctp_endpoint_hold(asoc->ep); + sock_hold(asoc->base.sk); + + /* Initialize the common base substructure. */ + asoc->base.type = SCTP_EP_TYPE_ASSOCIATION; + + /* Initialize the object handling fields. */ + refcount_set(&asoc->base.refcnt, 1); + + /* Initialize the bind addr area. */ + sctp_bind_addr_init(&asoc->base.bind_addr, ep->base.bind_addr.port); + + asoc->state = SCTP_STATE_CLOSED; + asoc->cookie_life = ms_to_ktime(sp->assocparams.sasoc_cookie_life); + asoc->user_frag = sp->user_frag; + + /* Set the association max_retrans and RTO values from the + * socket values. + */ + asoc->max_retrans = sp->assocparams.sasoc_asocmaxrxt; + asoc->pf_retrans = sp->pf_retrans; + asoc->ps_retrans = sp->ps_retrans; + asoc->pf_expose = sp->pf_expose; + + asoc->rto_initial = msecs_to_jiffies(sp->rtoinfo.srto_initial); + asoc->rto_max = msecs_to_jiffies(sp->rtoinfo.srto_max); + asoc->rto_min = msecs_to_jiffies(sp->rtoinfo.srto_min); + + /* Initialize the association's heartbeat interval based on the + * sock configured value. + */ + asoc->hbinterval = msecs_to_jiffies(sp->hbinterval); + asoc->probe_interval = msecs_to_jiffies(sp->probe_interval); + + asoc->encap_port = sp->encap_port; + + /* Initialize path max retrans value. */ + asoc->pathmaxrxt = sp->pathmaxrxt; + + asoc->flowlabel = sp->flowlabel; + asoc->dscp = sp->dscp; + + /* Set association default SACK delay */ + asoc->sackdelay = msecs_to_jiffies(sp->sackdelay); + asoc->sackfreq = sp->sackfreq; + + /* Set the association default flags controlling + * Heartbeat, SACK delay, and Path MTU Discovery. + */ + asoc->param_flags = sp->param_flags; + + /* Initialize the maximum number of new data packets that can be sent + * in a burst. + */ + asoc->max_burst = sp->max_burst; + + asoc->subscribe = sp->subscribe; + + /* initialize association timers */ + asoc->timeouts[SCTP_EVENT_TIMEOUT_T1_COOKIE] = asoc->rto_initial; + asoc->timeouts[SCTP_EVENT_TIMEOUT_T1_INIT] = asoc->rto_initial; + asoc->timeouts[SCTP_EVENT_TIMEOUT_T2_SHUTDOWN] = asoc->rto_initial; + + /* sctpimpguide Section 2.12.2 + * If the 'T5-shutdown-guard' timer is used, it SHOULD be set to the + * recommended value of 5 times 'RTO.Max'. + */ + asoc->timeouts[SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD] + = 5 * asoc->rto_max; + + asoc->timeouts[SCTP_EVENT_TIMEOUT_SACK] = asoc->sackdelay; + asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE] = sp->autoclose * HZ; + + /* Initializes the timers */ + for (i = SCTP_EVENT_TIMEOUT_NONE; i < SCTP_NUM_TIMEOUT_TYPES; ++i) + timer_setup(&asoc->timers[i], sctp_timer_events[i], 0); + + /* Pull default initialization values from the sock options. 
+ * Note: This assumes that the values have already been + * validated in the sock. + */ + asoc->c.sinit_max_instreams = sp->initmsg.sinit_max_instreams; + asoc->c.sinit_num_ostreams = sp->initmsg.sinit_num_ostreams; + asoc->max_init_attempts = sp->initmsg.sinit_max_attempts; + + asoc->max_init_timeo = + msecs_to_jiffies(sp->initmsg.sinit_max_init_timeo); + + /* Set the local window size for receive. + * This is also the rcvbuf space per association. + * RFC 6 - A SCTP receiver MUST be able to receive a minimum of + * 1500 bytes in one SCTP packet. + */ + if ((sk->sk_rcvbuf/2) < SCTP_DEFAULT_MINWINDOW) + asoc->rwnd = SCTP_DEFAULT_MINWINDOW; + else + asoc->rwnd = sk->sk_rcvbuf/2; + + asoc->a_rwnd = asoc->rwnd; + + /* Use my own max window until I learn something better. */ + asoc->peer.rwnd = SCTP_DEFAULT_MAXWINDOW; + + /* Initialize the receive memory counter */ + atomic_set(&asoc->rmem_alloc, 0); + + init_waitqueue_head(&asoc->wait); + + asoc->c.my_vtag = sctp_generate_tag(ep); + asoc->c.my_port = ep->base.bind_addr.port; + + asoc->c.initial_tsn = sctp_generate_tsn(ep); + + asoc->next_tsn = asoc->c.initial_tsn; + + asoc->ctsn_ack_point = asoc->next_tsn - 1; + asoc->adv_peer_ack_point = asoc->ctsn_ack_point; + asoc->highest_sacked = asoc->ctsn_ack_point; + asoc->last_cwr_tsn = asoc->ctsn_ack_point; + + /* ADDIP Section 4.1 Asconf Chunk Procedures + * + * When an endpoint has an ASCONF signaled change to be sent to the + * remote endpoint it should do the following: + * ... + * A2) a serial number should be assigned to the chunk. The serial + * number SHOULD be a monotonically increasing number. The serial + * numbers SHOULD be initialized at the start of the + * association to the same value as the initial TSN. + */ + asoc->addip_serial = asoc->c.initial_tsn; + asoc->strreset_outseq = asoc->c.initial_tsn; + + INIT_LIST_HEAD(&asoc->addip_chunk_list); + INIT_LIST_HEAD(&asoc->asconf_ack_list); + + /* Make an empty list of remote transport addresses. */ + INIT_LIST_HEAD(&asoc->peer.transport_addr_list); + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * After the reception of the first data chunk in an + * association the endpoint must immediately respond with a + * sack to acknowledge the data chunk. Subsequent + * acknowledgements should be done as described in Section + * 6.2. + * + * [We implement this by telling a new association that it + * already received one packet.] + */ + asoc->peer.sack_needed = 1; + asoc->peer.sack_generation = 1; + + /* Create an input queue. */ + sctp_inq_init(&asoc->base.inqueue); + sctp_inq_set_th_handler(&asoc->base.inqueue, sctp_assoc_bh_rcv); + + /* Create an output queue. */ + sctp_outq_init(asoc, &asoc->outqueue); + + if (!sctp_ulpq_init(&asoc->ulpq, asoc)) + goto fail_init; + + if (sctp_stream_init(&asoc->stream, asoc->c.sinit_num_ostreams, 0, gfp)) + goto stream_free; + + /* Initialize default path MTU. */ + asoc->pathmtu = sp->pathmtu; + sctp_assoc_update_frag_point(asoc); + + /* Assume that peer would support both address types unless we are + * told otherwise. 
+ */ + asoc->peer.ipv4_address = 1; + if (asoc->base.sk->sk_family == PF_INET6) + asoc->peer.ipv6_address = 1; + INIT_LIST_HEAD(&asoc->asocs); + + asoc->default_stream = sp->default_stream; + asoc->default_ppid = sp->default_ppid; + asoc->default_flags = sp->default_flags; + asoc->default_context = sp->default_context; + asoc->default_timetolive = sp->default_timetolive; + asoc->default_rcv_context = sp->default_rcv_context; + + /* AUTH related initializations */ + INIT_LIST_HEAD(&asoc->endpoint_shared_keys); + if (sctp_auth_asoc_copy_shkeys(ep, asoc, gfp)) + goto stream_free; + + asoc->active_key_id = ep->active_key_id; + asoc->strreset_enable = ep->strreset_enable; + + /* Save the hmacs and chunks list into this association */ + if (ep->auth_hmacs_list) + memcpy(asoc->c.auth_hmacs, ep->auth_hmacs_list, + ntohs(ep->auth_hmacs_list->param_hdr.length)); + if (ep->auth_chunk_list) + memcpy(asoc->c.auth_chunks, ep->auth_chunk_list, + ntohs(ep->auth_chunk_list->param_hdr.length)); + + /* Get the AUTH random number for this association */ + p = (struct sctp_paramhdr *)asoc->c.auth_random; + p->type = SCTP_PARAM_RANDOM; + p->length = htons(sizeof(*p) + SCTP_AUTH_RANDOM_LENGTH); + get_random_bytes(p+1, SCTP_AUTH_RANDOM_LENGTH); + + return asoc; + +stream_free: + sctp_stream_free(&asoc->stream); +fail_init: + sock_put(asoc->base.sk); + sctp_endpoint_put(asoc->ep); + return NULL; +} + +/* Allocate and initialize a new association */ +struct sctp_association *sctp_association_new(const struct sctp_endpoint *ep, + const struct sock *sk, + enum sctp_scope scope, gfp_t gfp) +{ + struct sctp_association *asoc; + + asoc = kzalloc(sizeof(*asoc), gfp); + if (!asoc) + goto fail; + + if (!sctp_association_init(asoc, ep, sk, scope, gfp)) + goto fail_init; + + SCTP_DBG_OBJCNT_INC(assoc); + + pr_debug("Created asoc %p\n", asoc); + + return asoc; + +fail_init: + kfree(asoc); +fail: + return NULL; +} + +/* Free this association if possible. There may still be users, so + * the actual deallocation may be delayed. + */ +void sctp_association_free(struct sctp_association *asoc) +{ + struct sock *sk = asoc->base.sk; + struct sctp_transport *transport; + struct list_head *pos, *temp; + int i; + + /* Only real associations count against the endpoint, so + * don't bother for if this is a temporary association. + */ + if (!list_empty(&asoc->asocs)) { + list_del(&asoc->asocs); + + /* Decrement the backlog value for a TCP-style listening + * socket. + */ + if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING)) + sk_acceptq_removed(sk); + } + + /* Mark as dead, so other users can know this structure is + * going away. + */ + asoc->base.dead = true; + + /* Dispose of any data lying around in the outqueue. */ + sctp_outq_free(&asoc->outqueue); + + /* Dispose of any pending messages for the upper layer. */ + sctp_ulpq_free(&asoc->ulpq); + + /* Dispose of any pending chunks on the inqueue. */ + sctp_inq_free(&asoc->base.inqueue); + + sctp_tsnmap_free(&asoc->peer.tsn_map); + + /* Free stream information. */ + sctp_stream_free(&asoc->stream); + + if (asoc->strreset_chunk) + sctp_chunk_free(asoc->strreset_chunk); + + /* Clean up the bound address list. */ + sctp_bind_addr_free(&asoc->base.bind_addr); + + /* Do we need to go through all of our timers and + * delete them? To be safe we will try to delete all, but we + * should be able to go through and make a guess based + * on our state. 
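Each armed association timer holds a reference on the association, so every successful del_timer() in the loop below is paired with an sctp_association_put(). A fragmentary sketch of the same idiom with hypothetical names (struct foo, foo_put, foo_timer_fn, foo_arm and foo_teardown are illustrative only):

    #include <linux/refcount.h>
    #include <linux/slab.h>
    #include <linux/timer.h>

    struct foo {
        refcount_t        refcnt;
        struct timer_list retry_timer;  /* timer_setup(&f->retry_timer,
                                         * foo_timer_fn, 0) assumed at init */
    };

    static void foo_put(struct foo *f)
    {
        if (refcount_dec_and_test(&f->refcnt))
            kfree(f);
    }

    static void foo_timer_fn(struct timer_list *t)
    {
        struct foo *f = from_timer(f, t, retry_timer);

        /* ... timer work ... */
        foo_put(f);                     /* the expired timer drops its reference */
    }

    static void foo_arm(struct foo *f, unsigned long expires)
    {
        /* take a reference only if the timer was not already pending */
        if (!mod_timer(&f->retry_timer, expires))
            refcount_inc(&f->refcnt);
    }

    static void foo_teardown(struct foo *f)
    {
        /* a still-pending timer holds a reference; drop it on deletion */
        if (del_timer(&f->retry_timer))
            foo_put(f);
    }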
+ */ + for (i = SCTP_EVENT_TIMEOUT_NONE; i < SCTP_NUM_TIMEOUT_TYPES; ++i) { + if (del_timer(&asoc->timers[i])) + sctp_association_put(asoc); + } + + /* Free peer's cached cookie. */ + kfree(asoc->peer.cookie); + kfree(asoc->peer.peer_random); + kfree(asoc->peer.peer_chunks); + kfree(asoc->peer.peer_hmacs); + + /* Release the transport structures. */ + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + transport = list_entry(pos, struct sctp_transport, transports); + list_del_rcu(pos); + sctp_unhash_transport(transport); + sctp_transport_free(transport); + } + + asoc->peer.transport_count = 0; + + sctp_asconf_queue_teardown(asoc); + + /* Free pending address space being deleted */ + kfree(asoc->asconf_addr_del_pending); + + /* AUTH - Free the endpoint shared keys */ + sctp_auth_destroy_keys(&asoc->endpoint_shared_keys); + + /* AUTH - Free the association shared key */ + sctp_auth_key_put(asoc->asoc_shared_key); + + sctp_association_put(asoc); +} + +/* Cleanup and free up an association. */ +static void sctp_association_destroy(struct sctp_association *asoc) +{ + if (unlikely(!asoc->base.dead)) { + WARN(1, "Attempt to destroy undead association %p!\n", asoc); + return; + } + + sctp_endpoint_put(asoc->ep); + sock_put(asoc->base.sk); + + if (asoc->assoc_id != 0) { + spin_lock_bh(&sctp_assocs_id_lock); + idr_remove(&sctp_assocs_id, asoc->assoc_id); + spin_unlock_bh(&sctp_assocs_id_lock); + } + + WARN_ON(atomic_read(&asoc->rmem_alloc)); + + kfree_rcu(asoc, rcu); + SCTP_DBG_OBJCNT_DEC(assoc); +} + +/* Change the primary destination address for the peer. */ +void sctp_assoc_set_primary(struct sctp_association *asoc, + struct sctp_transport *transport) +{ + int changeover = 0; + + /* it's a changeover only if we already have a primary path + * that we are changing + */ + if (asoc->peer.primary_path != NULL && + asoc->peer.primary_path != transport) + changeover = 1 ; + + asoc->peer.primary_path = transport; + sctp_ulpevent_notify_peer_addr_change(transport, + SCTP_ADDR_MADE_PRIM, 0); + + /* Set a default msg_name for events. */ + memcpy(&asoc->peer.primary_addr, &transport->ipaddr, + sizeof(union sctp_addr)); + + /* If the primary path is changing, assume that the + * user wants to use this new path. + */ + if ((transport->state == SCTP_ACTIVE) || + (transport->state == SCTP_UNKNOWN)) + asoc->peer.active_path = transport; + + /* + * SFR-CACC algorithm: + * Upon the receipt of a request to change the primary + * destination address, on the data structure for the new + * primary destination, the sender MUST do the following: + * + * 1) If CHANGEOVER_ACTIVE is set, then there was a switch + * to this destination address earlier. The sender MUST set + * CYCLING_CHANGEOVER to indicate that this switch is a + * double switch to the same destination address. + * + * Really, only bother is we have data queued or outstanding on + * the association. + */ + if (!asoc->outqueue.outstanding_bytes && !asoc->outqueue.out_qlen) + return; + + if (transport->cacc.changeover_active) + transport->cacc.cycling_changeover = changeover; + + /* 2) The sender MUST set CHANGEOVER_ACTIVE to indicate that + * a changeover has occurred. + */ + transport->cacc.changeover_active = changeover; + + /* 3) The sender MUST store the next TSN to be sent in + * next_tsn_at_change. + */ + transport->cacc.next_tsn_at_change = asoc->next_tsn; +} + +/* Remove a transport from an association. 
*/ +void sctp_assoc_rm_peer(struct sctp_association *asoc, + struct sctp_transport *peer) +{ + struct sctp_transport *transport; + struct list_head *pos; + struct sctp_chunk *ch; + + pr_debug("%s: association:%p addr:%pISpc\n", + __func__, asoc, &peer->ipaddr.sa); + + /* If we are to remove the current retran_path, update it + * to the next peer before removing this peer from the list. + */ + if (asoc->peer.retran_path == peer) + sctp_assoc_update_retran_path(asoc); + + /* Remove this peer from the list. */ + list_del_rcu(&peer->transports); + /* Remove this peer from the transport hashtable */ + sctp_unhash_transport(peer); + + /* Get the first transport of asoc. */ + pos = asoc->peer.transport_addr_list.next; + transport = list_entry(pos, struct sctp_transport, transports); + + /* Update any entries that match the peer to be deleted. */ + if (asoc->peer.primary_path == peer) + sctp_assoc_set_primary(asoc, transport); + if (asoc->peer.active_path == peer) + asoc->peer.active_path = transport; + if (asoc->peer.retran_path == peer) + asoc->peer.retran_path = transport; + if (asoc->peer.last_data_from == peer) + asoc->peer.last_data_from = transport; + + if (asoc->strreset_chunk && + asoc->strreset_chunk->transport == peer) { + asoc->strreset_chunk->transport = transport; + sctp_transport_reset_reconf_timer(transport); + } + + /* If we remove the transport an INIT was last sent to, set it to + * NULL. Combined with the update of the retran path above, this + * will cause the next INIT to be sent to the next available + * transport, maintaining the cycle. + */ + if (asoc->init_last_sent_to == peer) + asoc->init_last_sent_to = NULL; + + /* If we remove the transport an SHUTDOWN was last sent to, set it + * to NULL. Combined with the update of the retran path above, this + * will cause the next SHUTDOWN to be sent to the next available + * transport, maintaining the cycle. + */ + if (asoc->shutdown_last_sent_to == peer) + asoc->shutdown_last_sent_to = NULL; + + /* If we remove the transport an ASCONF was last sent to, set it to + * NULL. + */ + if (asoc->addip_last_asconf && + asoc->addip_last_asconf->transport == peer) + asoc->addip_last_asconf->transport = NULL; + + /* If we have something on the transmitted list, we have to + * save it off. The best place is the active path. + */ + if (!list_empty(&peer->transmitted)) { + struct sctp_transport *active = asoc->peer.active_path; + + /* Reset the transport of each chunk on this list */ + list_for_each_entry(ch, &peer->transmitted, + transmitted_list) { + ch->transport = NULL; + ch->rtt_in_progress = 0; + } + + list_splice_tail_init(&peer->transmitted, + &active->transmitted); + + /* Start a T3 timer here in case it wasn't running so + * that these migrated packets have a chance to get + * retransmitted. + */ + if (!timer_pending(&active->T3_rtx_timer)) + if (!mod_timer(&active->T3_rtx_timer, + jiffies + active->rto)) + sctp_transport_hold(active); + } + + list_for_each_entry(ch, &asoc->outqueue.out_chunk_list, list) + if (ch->transport == peer) + ch->transport = NULL; + + asoc->peer.transport_count--; + + sctp_ulpevent_notify_peer_addr_change(peer, SCTP_ADDR_REMOVED, 0); + sctp_transport_free(peer); +} + +/* Add a transport address to an association. 
*/ +struct sctp_transport *sctp_assoc_add_peer(struct sctp_association *asoc, + const union sctp_addr *addr, + const gfp_t gfp, + const int peer_state) +{ + struct sctp_transport *peer; + struct sctp_sock *sp; + unsigned short port; + + sp = sctp_sk(asoc->base.sk); + + /* AF_INET and AF_INET6 share common port field. */ + port = ntohs(addr->v4.sin_port); + + pr_debug("%s: association:%p addr:%pISpc state:%d\n", __func__, + asoc, &addr->sa, peer_state); + + /* Set the port if it has not been set yet. */ + if (0 == asoc->peer.port) + asoc->peer.port = port; + + /* Check to see if this is a duplicate. */ + peer = sctp_assoc_lookup_paddr(asoc, addr); + if (peer) { + /* An UNKNOWN state is only set on transports added by + * user in sctp_connectx() call. Such transports should be + * considered CONFIRMED per RFC 4960, Section 5.4. + */ + if (peer->state == SCTP_UNKNOWN) { + peer->state = SCTP_ACTIVE; + } + return peer; + } + + peer = sctp_transport_new(asoc->base.net, addr, gfp); + if (!peer) + return NULL; + + sctp_transport_set_owner(peer, asoc); + + /* Initialize the peer's heartbeat interval based on the + * association configured value. + */ + peer->hbinterval = asoc->hbinterval; + peer->probe_interval = asoc->probe_interval; + + peer->encap_port = asoc->encap_port; + + /* Set the path max_retrans. */ + peer->pathmaxrxt = asoc->pathmaxrxt; + + /* And the partial failure retrans threshold */ + peer->pf_retrans = asoc->pf_retrans; + /* And the primary path switchover retrans threshold */ + peer->ps_retrans = asoc->ps_retrans; + + /* Initialize the peer's SACK delay timeout based on the + * association configured value. + */ + peer->sackdelay = asoc->sackdelay; + peer->sackfreq = asoc->sackfreq; + + if (addr->sa.sa_family == AF_INET6) { + __be32 info = addr->v6.sin6_flowinfo; + + if (info) { + peer->flowlabel = ntohl(info & IPV6_FLOWLABEL_MASK); + peer->flowlabel |= SCTP_FLOWLABEL_SET_MASK; + } else { + peer->flowlabel = asoc->flowlabel; + } + } + peer->dscp = asoc->dscp; + + /* Enable/disable heartbeat, SACK delay, and path MTU discovery + * based on association setting. + */ + peer->param_flags = asoc->param_flags; + + /* Initialize the pmtu of the transport. */ + sctp_transport_route(peer, NULL, sp); + + /* If this is the first transport addr on this association, + * initialize the association PMTU to the peer's PMTU. + * If not and the current association PMTU is higher than the new + * peer's PMTU, reset the association PMTU to the new peer's PMTU. + */ + sctp_assoc_set_pmtu(asoc, asoc->pathmtu ? + min_t(int, peer->pathmtu, asoc->pathmtu) : + peer->pathmtu); + + peer->pmtu_pending = 0; + + /* The asoc->peer.port might not be meaningful yet, but + * initialize the packet structure anyway. + */ + sctp_packet_init(&peer->packet, peer, asoc->base.bind_addr.port, + asoc->peer.port); + + /* 7.2.1 Slow-Start + * + * o The initial cwnd before DATA transmission or after a sufficiently + * long idle period MUST be set to + * min(4*MTU, max(2*MTU, 4380 bytes)) + * + * o The initial value of ssthresh MAY be arbitrarily high + * (for example, implementations MAY use the size of the + * receiver advertised window). + */ + peer->cwnd = min(4*asoc->pathmtu, max_t(__u32, 2*asoc->pathmtu, 4380)); + + /* At this point, we may not have the receiver's advertised window, + * so initialize ssthresh to the default value and it will be set + * later when we process the INIT. 
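For the initial cwnd rule quoted a few lines above (RFC 4960, Section 7.2.1: min(4*MTU, max(2*MTU, 4380))), a small standalone illustration of the resulting values; initial_cwnd is a hypothetical helper, not kernel code:

    #include <stdint.h>
    #include <stdio.h>

    static uint32_t initial_cwnd(uint32_t pmtu)
    {
        uint32_t two_mtu = 2 * pmtu, four_mtu = 4 * pmtu;
        uint32_t floor = two_mtu > 4380 ? two_mtu : 4380;

        return four_mtu < floor ? four_mtu : floor;
    }

    int main(void)
    {
        printf("%u\n", initial_cwnd(1500)); /* min(6000, max(3000, 4380))   = 4380  */
        printf("%u\n", initial_cwnd(9000)); /* min(36000, max(18000, 4380)) = 18000 */
        return 0;
    }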
+ */ + peer->ssthresh = SCTP_DEFAULT_MAXWINDOW; + + peer->partial_bytes_acked = 0; + peer->flight_size = 0; + peer->burst_limited = 0; + + /* Set the transport's RTO.initial value */ + peer->rto = asoc->rto_initial; + sctp_max_rto(asoc, peer); + + /* Set the peer's active state. */ + peer->state = peer_state; + + /* Add this peer into the transport hashtable */ + if (sctp_hash_transport(peer)) { + sctp_transport_free(peer); + return NULL; + } + + sctp_transport_pl_reset(peer); + + /* Attach the remote transport to our asoc. */ + list_add_tail_rcu(&peer->transports, &asoc->peer.transport_addr_list); + asoc->peer.transport_count++; + + sctp_ulpevent_notify_peer_addr_change(peer, SCTP_ADDR_ADDED, 0); + + /* If we do not yet have a primary path, set one. */ + if (!asoc->peer.primary_path) { + sctp_assoc_set_primary(asoc, peer); + asoc->peer.retran_path = peer; + } + + if (asoc->peer.active_path == asoc->peer.retran_path && + peer->state != SCTP_UNCONFIRMED) { + asoc->peer.retran_path = peer; + } + + return peer; +} + +/* Delete a transport address from an association. */ +void sctp_assoc_del_peer(struct sctp_association *asoc, + const union sctp_addr *addr) +{ + struct list_head *pos; + struct list_head *temp; + struct sctp_transport *transport; + + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + transport = list_entry(pos, struct sctp_transport, transports); + if (sctp_cmp_addr_exact(addr, &transport->ipaddr)) { + /* Do book keeping for removing the peer and free it. */ + sctp_assoc_rm_peer(asoc, transport); + break; + } + } +} + +/* Lookup a transport by address. */ +struct sctp_transport *sctp_assoc_lookup_paddr( + const struct sctp_association *asoc, + const union sctp_addr *address) +{ + struct sctp_transport *t; + + /* Cycle through all transports searching for a peer address. */ + + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + if (sctp_cmp_addr_exact(address, &t->ipaddr)) + return t; + } + + return NULL; +} + +/* Remove all transports except a give one */ +void sctp_assoc_del_nonprimary_peers(struct sctp_association *asoc, + struct sctp_transport *primary) +{ + struct sctp_transport *temp; + struct sctp_transport *t; + + list_for_each_entry_safe(t, temp, &asoc->peer.transport_addr_list, + transports) { + /* if the current transport is not the primary one, delete it */ + if (t != primary) + sctp_assoc_rm_peer(asoc, t); + } +} + +/* Engage in transport control operations. + * Mark the transport up or down and send a notification to the user. + * Select and update the new active and retran paths. + */ +void sctp_assoc_control_transport(struct sctp_association *asoc, + struct sctp_transport *transport, + enum sctp_transport_cmd command, + sctp_sn_error_t error) +{ + int spc_state = SCTP_ADDR_AVAILABLE; + bool ulp_notify = true; + + /* Record the transition on the transport. */ + switch (command) { + case SCTP_TRANSPORT_UP: + /* If we are moving from UNCONFIRMED state due + * to heartbeat success, report the SCTP_ADDR_CONFIRMED + * state to the user, otherwise report SCTP_ADDR_AVAILABLE. + */ + if (transport->state == SCTP_PF && + asoc->pf_expose != SCTP_PF_EXPOSE_ENABLE) + ulp_notify = false; + else if (transport->state == SCTP_UNCONFIRMED && + error == SCTP_HEARTBEAT_SUCCESS) + spc_state = SCTP_ADDR_CONFIRMED; + + transport->state = SCTP_ACTIVE; + sctp_transport_pl_reset(transport); + break; + + case SCTP_TRANSPORT_DOWN: + /* If the transport was never confirmed, do not transition it + * to inactive state. 
Also, release the cached route since + * there may be a better route next time. + */ + if (transport->state != SCTP_UNCONFIRMED) { + transport->state = SCTP_INACTIVE; + sctp_transport_pl_reset(transport); + spc_state = SCTP_ADDR_UNREACHABLE; + } else { + sctp_transport_dst_release(transport); + ulp_notify = false; + } + break; + + case SCTP_TRANSPORT_PF: + transport->state = SCTP_PF; + if (asoc->pf_expose != SCTP_PF_EXPOSE_ENABLE) + ulp_notify = false; + else + spc_state = SCTP_ADDR_POTENTIALLY_FAILED; + break; + + default: + return; + } + + /* Generate and send a SCTP_PEER_ADDR_CHANGE notification + * to the user. + */ + if (ulp_notify) + sctp_ulpevent_notify_peer_addr_change(transport, + spc_state, error); + + /* Select new active and retran paths. */ + sctp_select_active_and_retran_path(asoc); +} + +/* Hold a reference to an association. */ +void sctp_association_hold(struct sctp_association *asoc) +{ + refcount_inc(&asoc->base.refcnt); +} + +/* Release a reference to an association and cleanup + * if there are no more references. + */ +void sctp_association_put(struct sctp_association *asoc) +{ + if (refcount_dec_and_test(&asoc->base.refcnt)) + sctp_association_destroy(asoc); +} + +/* Allocate the next TSN, Transmission Sequence Number, for the given + * association. + */ +__u32 sctp_association_get_next_tsn(struct sctp_association *asoc) +{ + /* From Section 1.6 Serial Number Arithmetic: + * Transmission Sequence Numbers wrap around when they reach + * 2**32 - 1. That is, the next TSN a DATA chunk MUST use + * after transmitting TSN = 2*32 - 1 is TSN = 0. + */ + __u32 retval = asoc->next_tsn; + asoc->next_tsn++; + asoc->unack_data++; + + return retval; +} + +/* Compare two addresses to see if they match. Wildcard addresses + * only match themselves. + */ +int sctp_cmp_addr_exact(const union sctp_addr *ss1, + const union sctp_addr *ss2) +{ + struct sctp_af *af; + + af = sctp_get_af_specific(ss1->sa.sa_family); + if (unlikely(!af)) + return 0; + + return af->cmp_addr(ss1, ss2); +} + +/* Return an ecne chunk to get prepended to a packet. + * Note: We are sly and return a shared, prealloced chunk. FIXME: + * No we don't, but we could/should. + */ +struct sctp_chunk *sctp_get_ecne_prepend(struct sctp_association *asoc) +{ + if (!asoc->need_ecne) + return NULL; + + /* Send ECNE if needed. + * Not being able to allocate a chunk here is not deadly. + */ + return sctp_make_ecne(asoc, asoc->last_ecne_tsn); +} + +/* + * Find which transport this TSN was sent on. + */ +struct sctp_transport *sctp_assoc_lookup_tsn(struct sctp_association *asoc, + __u32 tsn) +{ + struct sctp_transport *active; + struct sctp_transport *match; + struct sctp_transport *transport; + struct sctp_chunk *chunk; + __be32 key = htonl(tsn); + + match = NULL; + + /* + * FIXME: In general, find a more efficient data structure for + * searching. + */ + + /* + * The general strategy is to search each transport's transmitted + * list. Return which transport this TSN lives on. + * + * Let's be hopeful and check the active_path first. + * Another optimization would be to know if there is only one + * outbound path and not have to look for the TSN at all. + * + */ + + active = asoc->peer.active_path; + + list_for_each_entry(chunk, &active->transmitted, + transmitted_list) { + + if (key == chunk->subh.data_hdr->tsn) { + match = active; + goto out; + } + } + + /* If not found, go search all the other transports. 
*/ + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + + if (transport == active) + continue; + list_for_each_entry(chunk, &transport->transmitted, + transmitted_list) { + if (key == chunk->subh.data_hdr->tsn) { + match = transport; + goto out; + } + } + } +out: + return match; +} + +/* Do delayed input processing. This is scheduled by sctp_rcv(). */ +static void sctp_assoc_bh_rcv(struct work_struct *work) +{ + struct sctp_association *asoc = + container_of(work, struct sctp_association, + base.inqueue.immediate); + struct net *net = asoc->base.net; + union sctp_subtype subtype; + struct sctp_endpoint *ep; + struct sctp_chunk *chunk; + struct sctp_inq *inqueue; + int first_time = 1; /* is this the first time through the loop */ + int error = 0; + int state; + + /* The association should be held so we should be safe. */ + ep = asoc->ep; + + inqueue = &asoc->base.inqueue; + sctp_association_hold(asoc); + while (NULL != (chunk = sctp_inq_pop(inqueue))) { + state = asoc->state; + subtype = SCTP_ST_CHUNK(chunk->chunk_hdr->type); + + /* If the first chunk in the packet is AUTH, do special + * processing specified in Section 6.3 of SCTP-AUTH spec + */ + if (first_time && subtype.chunk == SCTP_CID_AUTH) { + struct sctp_chunkhdr *next_hdr; + + next_hdr = sctp_inq_peek(inqueue); + if (!next_hdr) + goto normal; + + /* If the next chunk is COOKIE-ECHO, skip the AUTH + * chunk while saving a pointer to it so we can do + * Authentication later (during cookie-echo + * processing). + */ + if (next_hdr->type == SCTP_CID_COOKIE_ECHO) { + chunk->auth_chunk = skb_clone(chunk->skb, + GFP_ATOMIC); + chunk->auth = 1; + continue; + } + } + +normal: + /* SCTP-AUTH, Section 6.3: + * The receiver has a list of chunk types which it expects + * to be received only after an AUTH-chunk. This list has + * been sent to the peer during the association setup. It + * MUST silently discard these chunks if they are not placed + * after an AUTH chunk in the packet. + */ + if (sctp_auth_recv_cid(subtype.chunk, asoc) && !chunk->auth) + continue; + + /* Remember where the last DATA chunk came from so we + * know where to send the SACK. + */ + if (sctp_chunk_is_data(chunk)) + asoc->peer.last_data_from = chunk->transport; + else { + SCTP_INC_STATS(net, SCTP_MIB_INCTRLCHUNKS); + asoc->stats.ictrlchunks++; + if (chunk->chunk_hdr->type == SCTP_CID_SACK) + asoc->stats.isacks++; + } + + if (chunk->transport) + chunk->transport->last_time_heard = ktime_get(); + + /* Run through the state machine. */ + error = sctp_do_sm(net, SCTP_EVENT_T_CHUNK, subtype, + state, ep, asoc, chunk, GFP_ATOMIC); + + /* Check to see if the association is freed in response to + * the incoming chunk. If so, get out of the while loop. + */ + if (asoc->base.dead) + break; + + /* If there is an error on chunk, discard this packet. */ + if (error && chunk) + chunk->pdiscard = 1; + + if (first_time) + first_time = 0; + } + sctp_association_put(asoc); +} + +/* This routine moves an association from its old sk to a new sk. */ +void sctp_assoc_migrate(struct sctp_association *assoc, struct sock *newsk) +{ + struct sctp_sock *newsp = sctp_sk(newsk); + struct sock *oldsk = assoc->base.sk; + + /* Delete the association from the old endpoint's list of + * associations. + */ + list_del_init(&assoc->asocs); + + /* Decrement the backlog value for a TCP-style socket. */ + if (sctp_style(oldsk, TCP)) + sk_acceptq_removed(oldsk); + + /* Release references to the old endpoint and the sock. 
*/ + sctp_endpoint_put(assoc->ep); + sock_put(assoc->base.sk); + + /* Get a reference to the new endpoint. */ + assoc->ep = newsp->ep; + sctp_endpoint_hold(assoc->ep); + + /* Get a reference to the new sock. */ + assoc->base.sk = newsk; + sock_hold(assoc->base.sk); + + /* Add the association to the new endpoint's list of associations. */ + sctp_endpoint_add_asoc(newsp->ep, assoc); +} + +/* Update an association (possibly from unexpected COOKIE-ECHO processing). */ +int sctp_assoc_update(struct sctp_association *asoc, + struct sctp_association *new) +{ + struct sctp_transport *trans; + struct list_head *pos, *temp; + + /* Copy in new parameters of peer. */ + asoc->c = new->c; + asoc->peer.rwnd = new->peer.rwnd; + asoc->peer.sack_needed = new->peer.sack_needed; + asoc->peer.auth_capable = new->peer.auth_capable; + asoc->peer.i = new->peer.i; + + if (!sctp_tsnmap_init(&asoc->peer.tsn_map, SCTP_TSN_MAP_INITIAL, + asoc->peer.i.initial_tsn, GFP_ATOMIC)) + return -ENOMEM; + + /* Remove any peer addresses not present in the new association. */ + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + trans = list_entry(pos, struct sctp_transport, transports); + if (!sctp_assoc_lookup_paddr(new, &trans->ipaddr)) { + sctp_assoc_rm_peer(asoc, trans); + continue; + } + + if (asoc->state >= SCTP_STATE_ESTABLISHED) + sctp_transport_reset(trans); + } + + /* If the case is A (association restart), use + * initial_tsn as next_tsn. If the case is B, use + * current next_tsn in case data sent to peer + * has been discarded and needs retransmission. + */ + if (asoc->state >= SCTP_STATE_ESTABLISHED) { + asoc->next_tsn = new->next_tsn; + asoc->ctsn_ack_point = new->ctsn_ack_point; + asoc->adv_peer_ack_point = new->adv_peer_ack_point; + + /* Reinitialize SSN for both local streams + * and peer's streams. + */ + sctp_stream_clear(&asoc->stream); + + /* Flush the ULP reassembly and ordered queue. + * Any data there will now be stale and will + * cause problems. + */ + sctp_ulpq_flush(&asoc->ulpq); + + /* reset the overall association error count so + * that the restarted association doesn't get torn + * down on the next retransmission timer. + */ + asoc->overall_error_count = 0; + + } else { + /* Add any peer addresses from the new association. */ + list_for_each_entry(trans, &new->peer.transport_addr_list, + transports) + if (!sctp_assoc_add_peer(asoc, &trans->ipaddr, + GFP_ATOMIC, trans->state)) + return -ENOMEM; + + asoc->ctsn_ack_point = asoc->next_tsn - 1; + asoc->adv_peer_ack_point = asoc->ctsn_ack_point; + + if (sctp_state(asoc, COOKIE_WAIT)) + sctp_stream_update(&asoc->stream, &new->stream); + + /* get a new assoc id if we don't have one yet. */ + if (sctp_assoc_set_id(asoc, GFP_ATOMIC)) + return -ENOMEM; + } + + /* SCTP-AUTH: Save the peer parameters from the new associations + * and also move the association shared keys over + */ + kfree(asoc->peer.peer_random); + asoc->peer.peer_random = new->peer.peer_random; + new->peer.peer_random = NULL; + + kfree(asoc->peer.peer_chunks); + asoc->peer.peer_chunks = new->peer.peer_chunks; + new->peer.peer_chunks = NULL; + + kfree(asoc->peer.peer_hmacs); + asoc->peer.peer_hmacs = new->peer.peer_hmacs; + new->peer.peer_hmacs = NULL; + + return sctp_auth_asoc_init_active_key(asoc, GFP_ATOMIC); +} + +/* Update the retran path for sending a retransmitted packet. + * See also RFC4960, 6.4. 
Multi-Homed SCTP Endpoints: + * + * When there is outbound data to send and the primary path + * becomes inactive (e.g., due to failures), or where the + * SCTP user explicitly requests to send data to an + * inactive destination transport address, before reporting + * an error to its ULP, the SCTP endpoint should try to send + * the data to an alternate active destination transport + * address if one exists. + * + * When retransmitting data that timed out, if the endpoint + * is multihomed, it should consider each source-destination + * address pair in its retransmission selection policy. + * When retransmitting timed-out data, the endpoint should + * attempt to pick the most divergent source-destination + * pair from the original source-destination pair to which + * the packet was transmitted. + * + * Note: Rules for picking the most divergent source-destination + * pair are an implementation decision and are not specified + * within this document. + * + * Our basic strategy is to round-robin transports in priorities + * according to sctp_trans_score() e.g., if no such + * transport with state SCTP_ACTIVE exists, round-robin through + * SCTP_UNKNOWN, etc. You get the picture. + */ +static u8 sctp_trans_score(const struct sctp_transport *trans) +{ + switch (trans->state) { + case SCTP_ACTIVE: + return 3; /* best case */ + case SCTP_UNKNOWN: + return 2; + case SCTP_PF: + return 1; + default: /* case SCTP_INACTIVE */ + return 0; /* worst case */ + } +} + +static struct sctp_transport *sctp_trans_elect_tie(struct sctp_transport *trans1, + struct sctp_transport *trans2) +{ + if (trans1->error_count > trans2->error_count) { + return trans2; + } else if (trans1->error_count == trans2->error_count && + ktime_after(trans2->last_time_heard, + trans1->last_time_heard)) { + return trans2; + } else { + return trans1; + } +} + +static struct sctp_transport *sctp_trans_elect_best(struct sctp_transport *curr, + struct sctp_transport *best) +{ + u8 score_curr, score_best; + + if (best == NULL || curr == best) + return curr; + + score_curr = sctp_trans_score(curr); + score_best = sctp_trans_score(best); + + /* First, try a score-based selection if both transport states + * differ. If we're in a tie, lets try to make a more clever + * decision here based on error counts and last time heard. + */ + if (score_curr > score_best) + return curr; + else if (score_curr == score_best) + return sctp_trans_elect_tie(best, curr); + else + return best; +} + +void sctp_assoc_update_retran_path(struct sctp_association *asoc) +{ + struct sctp_transport *trans = asoc->peer.retran_path; + struct sctp_transport *trans_next = NULL; + + /* We're done as we only have the one and only path. */ + if (asoc->peer.transport_count == 1) + return; + /* If active_path and retran_path are the same and active, + * then this is the only active path. Use it. + */ + if (asoc->peer.active_path == asoc->peer.retran_path && + asoc->peer.active_path->state == SCTP_ACTIVE) + return; + + /* Iterate from retran_path's successor back to retran_path. */ + for (trans = list_next_entry(trans, transports); 1; + trans = list_next_entry(trans, transports)) { + /* Manually skip the head element. */ + if (&trans->transports == &asoc->peer.transport_addr_list) + continue; + if (trans->state == SCTP_UNCONFIRMED) + continue; + trans_next = sctp_trans_elect_best(trans, trans_next); + /* Active is good enough for immediate return. */ + if (trans_next->state == SCTP_ACTIVE) + break; + /* We've reached the end, time to update path. 
*/ + if (trans == asoc->peer.retran_path) + break; + } + + asoc->peer.retran_path = trans_next; + + pr_debug("%s: association:%p updated new path to addr:%pISpc\n", + __func__, asoc, &asoc->peer.retran_path->ipaddr.sa); +} + +static void sctp_select_active_and_retran_path(struct sctp_association *asoc) +{ + struct sctp_transport *trans, *trans_pri = NULL, *trans_sec = NULL; + struct sctp_transport *trans_pf = NULL; + + /* Look for the two most recently used active transports. */ + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) { + /* Skip uninteresting transports. */ + if (trans->state == SCTP_INACTIVE || + trans->state == SCTP_UNCONFIRMED) + continue; + /* Keep track of the best PF transport from our + * list in case we don't find an active one. + */ + if (trans->state == SCTP_PF) { + trans_pf = sctp_trans_elect_best(trans, trans_pf); + continue; + } + /* For active transports, pick the most recent ones. */ + if (trans_pri == NULL || + ktime_after(trans->last_time_heard, + trans_pri->last_time_heard)) { + trans_sec = trans_pri; + trans_pri = trans; + } else if (trans_sec == NULL || + ktime_after(trans->last_time_heard, + trans_sec->last_time_heard)) { + trans_sec = trans; + } + } + + /* RFC 2960 6.4 Multi-Homed SCTP Endpoints + * + * By default, an endpoint should always transmit to the primary + * path, unless the SCTP user explicitly specifies the + * destination transport address (and possibly source transport + * address) to use. [If the primary is active but not most recent, + * bump the most recently used transport.] + */ + if ((asoc->peer.primary_path->state == SCTP_ACTIVE || + asoc->peer.primary_path->state == SCTP_UNKNOWN) && + asoc->peer.primary_path != trans_pri) { + trans_sec = trans_pri; + trans_pri = asoc->peer.primary_path; + } + + /* We did not find anything useful for a possible retransmission + * path; either primary path that we found is the same as + * the current one, or we didn't generally find an active one. + */ + if (trans_sec == NULL) + trans_sec = trans_pri; + + /* If we failed to find a usable transport, just camp on the + * active or pick a PF iff it's the better choice. + */ + if (trans_pri == NULL) { + trans_pri = sctp_trans_elect_best(asoc->peer.active_path, trans_pf); + trans_sec = trans_pri; + } + + /* Set the active and retran transports. */ + asoc->peer.active_path = trans_pri; + asoc->peer.retran_path = trans_sec; +} + +struct sctp_transport * +sctp_assoc_choose_alter_transport(struct sctp_association *asoc, + struct sctp_transport *last_sent_to) +{ + /* If this is the first time packet is sent, use the active path, + * else use the retran path. If the last packet was sent over the + * retran path, update the retran path and use it. 
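The retransmission-path machinery above boils down to a fixed preference order: higher state score first, then fewer errors, then the transport heard from most recently, with full ties keeping the incumbent. A plain-C restatement of what sctp_trans_elect_best() and sctp_trans_elect_tie() compute (struct path and prefer() are illustrative stand-ins, not kernel types):

    #include <stdint.h>

    struct path {
        int      score;      /* 3 ACTIVE, 2 UNKNOWN, 1 PF, 0 INACTIVE      */
        unsigned errors;     /* corresponds to trans->error_count          */
        int64_t  last_heard; /* larger means more recent (ktime stand-in)  */
    };

    /* Return the preferred of the incumbent 'best' and candidate 'cand'. */
    static const struct path *prefer(const struct path *best,
                                     const struct path *cand)
    {
        if (best->score != cand->score)
            return best->score > cand->score ? best : cand;
        if (best->errors != cand->errors)
            return best->errors < cand->errors ? best : cand;
        return cand->last_heard > best->last_heard ? cand : best;
    }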
+ */ + if (last_sent_to == NULL) { + return asoc->peer.active_path; + } else { + if (last_sent_to == asoc->peer.retran_path) + sctp_assoc_update_retran_path(asoc); + + return asoc->peer.retran_path; + } +} + +void sctp_assoc_update_frag_point(struct sctp_association *asoc) +{ + int frag = sctp_mtu_payload(sctp_sk(asoc->base.sk), asoc->pathmtu, + sctp_datachk_len(&asoc->stream)); + + if (asoc->user_frag) + frag = min_t(int, frag, asoc->user_frag); + + frag = min_t(int, frag, SCTP_MAX_CHUNK_LEN - + sctp_datachk_len(&asoc->stream)); + + asoc->frag_point = SCTP_TRUNC4(frag); +} + +void sctp_assoc_set_pmtu(struct sctp_association *asoc, __u32 pmtu) +{ + if (asoc->pathmtu != pmtu) { + asoc->pathmtu = pmtu; + sctp_assoc_update_frag_point(asoc); + } + + pr_debug("%s: asoc:%p, pmtu:%d, frag_point:%d\n", __func__, asoc, + asoc->pathmtu, asoc->frag_point); +} + +/* Update the association's pmtu and frag_point by going through all the + * transports. This routine is called when a transport's PMTU has changed. + */ +void sctp_assoc_sync_pmtu(struct sctp_association *asoc) +{ + struct sctp_transport *t; + __u32 pmtu = 0; + + if (!asoc) + return; + + /* Get the lowest pmtu of all the transports. */ + list_for_each_entry(t, &asoc->peer.transport_addr_list, transports) { + if (t->pmtu_pending && t->dst) { + sctp_transport_update_pmtu(t, + atomic_read(&t->mtu_info)); + t->pmtu_pending = 0; + } + if (!pmtu || (t->pathmtu < pmtu)) + pmtu = t->pathmtu; + } + + sctp_assoc_set_pmtu(asoc, pmtu); +} + +/* Should we send a SACK to update our peer? */ +static inline bool sctp_peer_needs_update(struct sctp_association *asoc) +{ + struct net *net = asoc->base.net; + + switch (asoc->state) { + case SCTP_STATE_ESTABLISHED: + case SCTP_STATE_SHUTDOWN_PENDING: + case SCTP_STATE_SHUTDOWN_RECEIVED: + case SCTP_STATE_SHUTDOWN_SENT: + if ((asoc->rwnd > asoc->a_rwnd) && + ((asoc->rwnd - asoc->a_rwnd) >= max_t(__u32, + (asoc->base.sk->sk_rcvbuf >> net->sctp.rwnd_upd_shift), + asoc->pathmtu))) + return true; + break; + default: + break; + } + return false; +} + +/* Increase asoc's rwnd by len and send any window update SACK if needed. */ +void sctp_assoc_rwnd_increase(struct sctp_association *asoc, unsigned int len) +{ + struct sctp_chunk *sack; + struct timer_list *timer; + + if (asoc->rwnd_over) { + if (asoc->rwnd_over >= len) { + asoc->rwnd_over -= len; + } else { + asoc->rwnd += (len - asoc->rwnd_over); + asoc->rwnd_over = 0; + } + } else { + asoc->rwnd += len; + } + + /* If we had window pressure, start recovering it + * once our rwnd had reached the accumulated pressure + * threshold. The idea is to recover slowly, but up + * to the initial advertised window. + */ + if (asoc->rwnd_press) { + int change = min(asoc->pathmtu, asoc->rwnd_press); + asoc->rwnd += change; + asoc->rwnd_press -= change; + } + + pr_debug("%s: asoc:%p rwnd increased by %d to (%u, %u) - %u\n", + __func__, asoc, len, asoc->rwnd, asoc->rwnd_over, + asoc->a_rwnd); + + /* Send a window update SACK if the rwnd has increased by at least the + * minimum of the association's PMTU and half of the receive buffer. + * The algorithm used is similar to the one described in + * Section 4.2.3.3 of RFC 1122. 
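As a worked example of the condition evaluated by sctp_peer_needs_update() above (all numbers are illustrative, not defaults): with an sk_rcvbuf of 212992 bytes, a shift of 4 and a path MTU of 1500, the threshold is max(212992 >> 4, 1500) = 13312 bytes, so growing rwnd from an advertised a_rwnd of 100000 to 120000 is enough to trigger a window-update SACK:

    #include <stdint.h>
    #include <stdio.h>

    int main(void)
    {
        uint32_t rcvbuf = 212992;          /* example sk_rcvbuf        */
        uint32_t shift  = 4;               /* example rwnd_upd_shift   */
        uint32_t pmtu   = 1500;            /* example path MTU         */
        uint32_t rwnd   = 120000;          /* current receive window   */
        uint32_t a_rwnd = 100000;          /* last advertised window   */
        uint32_t thresh = rcvbuf >> shift; /* 13312                    */

        if (thresh < pmtu)
            thresh = pmtu;

        /* 120000 - 100000 = 20000 >= 13312, so an update SACK is due */
        printf("%s\n", (rwnd > a_rwnd && rwnd - a_rwnd >= thresh) ?
               "send window-update SACK" : "wait");
        return 0;
    }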
+ */ + if (sctp_peer_needs_update(asoc)) { + asoc->a_rwnd = asoc->rwnd; + + pr_debug("%s: sending window update SACK- asoc:%p rwnd:%u " + "a_rwnd:%u\n", __func__, asoc, asoc->rwnd, + asoc->a_rwnd); + + sack = sctp_make_sack(asoc); + if (!sack) + return; + + asoc->peer.sack_needed = 0; + + sctp_outq_tail(&asoc->outqueue, sack, GFP_ATOMIC); + + /* Stop the SACK timer. */ + timer = &asoc->timers[SCTP_EVENT_TIMEOUT_SACK]; + if (del_timer(timer)) + sctp_association_put(asoc); + } +} + +/* Decrease asoc's rwnd by len. */ +void sctp_assoc_rwnd_decrease(struct sctp_association *asoc, unsigned int len) +{ + int rx_count; + int over = 0; + + if (unlikely(!asoc->rwnd || asoc->rwnd_over)) + pr_debug("%s: association:%p has asoc->rwnd:%u, " + "asoc->rwnd_over:%u!\n", __func__, asoc, + asoc->rwnd, asoc->rwnd_over); + + if (asoc->ep->rcvbuf_policy) + rx_count = atomic_read(&asoc->rmem_alloc); + else + rx_count = atomic_read(&asoc->base.sk->sk_rmem_alloc); + + /* If we've reached or overflowed our receive buffer, announce + * a 0 rwnd if rwnd would still be positive. Store the + * potential pressure overflow so that the window can be restored + * back to original value. + */ + if (rx_count >= asoc->base.sk->sk_rcvbuf) + over = 1; + + if (asoc->rwnd >= len) { + asoc->rwnd -= len; + if (over) { + asoc->rwnd_press += asoc->rwnd; + asoc->rwnd = 0; + } + } else { + asoc->rwnd_over += len - asoc->rwnd; + asoc->rwnd = 0; + } + + pr_debug("%s: asoc:%p rwnd decreased by %d to (%u, %u, %u)\n", + __func__, asoc, len, asoc->rwnd, asoc->rwnd_over, + asoc->rwnd_press); +} + +/* Build the bind address list for the association based on info from the + * local endpoint and the remote peer. + */ +int sctp_assoc_set_bind_addr_from_ep(struct sctp_association *asoc, + enum sctp_scope scope, gfp_t gfp) +{ + struct sock *sk = asoc->base.sk; + int flags; + + /* Use scoping rules to determine the subset of addresses from + * the endpoint. + */ + flags = (PF_INET6 == sk->sk_family) ? SCTP_ADDR6_ALLOWED : 0; + if (!inet_v6_ipv6only(sk)) + flags |= SCTP_ADDR4_ALLOWED; + if (asoc->peer.ipv4_address) + flags |= SCTP_ADDR4_PEERSUPP; + if (asoc->peer.ipv6_address) + flags |= SCTP_ADDR6_PEERSUPP; + + return sctp_bind_addr_copy(asoc->base.net, + &asoc->base.bind_addr, + &asoc->ep->base.bind_addr, + scope, gfp, flags); +} + +/* Build the association's bind address list from the cookie. */ +int sctp_assoc_set_bind_addr_from_cookie(struct sctp_association *asoc, + struct sctp_cookie *cookie, + gfp_t gfp) +{ + int var_size2 = ntohs(cookie->peer_init->chunk_hdr.length); + int var_size3 = cookie->raw_addr_list_len; + __u8 *raw = (__u8 *)cookie->peer_init + var_size2; + + return sctp_raw_to_bind_addrs(&asoc->base.bind_addr, raw, var_size3, + asoc->ep->base.bind_addr.port, gfp); +} + +/* Lookup laddr in the bind address list of an association. */ +int sctp_assoc_lookup_laddr(struct sctp_association *asoc, + const union sctp_addr *laddr) +{ + int found = 0; + + if ((asoc->base.bind_addr.port == ntohs(laddr->v4.sin_port)) && + sctp_bind_addr_match(&asoc->base.bind_addr, laddr, + sctp_sk(asoc->base.sk))) + found = 1; + + return found; +} + +/* Set an association id for a given association */ +int sctp_assoc_set_id(struct sctp_association *asoc, gfp_t gfp) +{ + bool preload = gfpflags_allow_blocking(gfp); + int ret; + + /* If the id is already assigned, keep it. 
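The allocation just below follows the usual preload-then-GFP_NOWAIT idr pattern, so the caller's GFP flags are honoured outside the spinlock while the allocation under the lock never sleeps. A minimal sketch of the same idiom with a hypothetical idr and lock (example_idr, example_lock and example_alloc_id are illustrative names only):

    #include <linux/gfp.h>
    #include <linux/idr.h>
    #include <linux/spinlock.h>
    #include <linux/types.h>

    static DEFINE_IDR(example_idr);
    static DEFINE_SPINLOCK(example_lock);

    static int example_alloc_id(void *obj, gfp_t gfp)
    {
        bool preload = gfpflags_allow_blocking(gfp);
        int id;

        if (preload)
            idr_preload(gfp);   /* may sleep; fills the per-cpu cache */
        spin_lock_bh(&example_lock);
        /* ids start at 1 here; 0 stays free, much as the reserved
         * SCTP_{FUTURE,CURRENT,ALL}_ASSOC values are skipped below */
        id = idr_alloc_cyclic(&example_idr, obj, 1, 0, GFP_NOWAIT);
        spin_unlock_bh(&example_lock);
        if (preload)
            idr_preload_end();

        return id;              /* >= 1 on success, -errno on failure */
    }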
*/ + if (asoc->assoc_id) + return 0; + + if (preload) + idr_preload(gfp); + spin_lock_bh(&sctp_assocs_id_lock); + /* 0, 1, 2 are used as SCTP_FUTURE_ASSOC, SCTP_CURRENT_ASSOC and + * SCTP_ALL_ASSOC, so an available id must be > SCTP_ALL_ASSOC. + */ + ret = idr_alloc_cyclic(&sctp_assocs_id, asoc, SCTP_ALL_ASSOC + 1, 0, + GFP_NOWAIT); + spin_unlock_bh(&sctp_assocs_id_lock); + if (preload) + idr_preload_end(); + if (ret < 0) + return ret; + + asoc->assoc_id = (sctp_assoc_t)ret; + return 0; +} + +/* Free the ASCONF queue */ +static void sctp_assoc_free_asconf_queue(struct sctp_association *asoc) +{ + struct sctp_chunk *asconf; + struct sctp_chunk *tmp; + + list_for_each_entry_safe(asconf, tmp, &asoc->addip_chunk_list, list) { + list_del_init(&asconf->list); + sctp_chunk_free(asconf); + } +} + +/* Free asconf_ack cache */ +static void sctp_assoc_free_asconf_acks(struct sctp_association *asoc) +{ + struct sctp_chunk *ack; + struct sctp_chunk *tmp; + + list_for_each_entry_safe(ack, tmp, &asoc->asconf_ack_list, + transmitted_list) { + list_del_init(&ack->transmitted_list); + sctp_chunk_free(ack); + } +} + +/* Clean up the ASCONF_ACK queue */ +void sctp_assoc_clean_asconf_ack_cache(const struct sctp_association *asoc) +{ + struct sctp_chunk *ack; + struct sctp_chunk *tmp; + + /* We can remove all the entries from the queue up to + * the "Peer-Sequence-Number". + */ + list_for_each_entry_safe(ack, tmp, &asoc->asconf_ack_list, + transmitted_list) { + if (ack->subh.addip_hdr->serial == + htonl(asoc->peer.addip_serial)) + break; + + list_del_init(&ack->transmitted_list); + sctp_chunk_free(ack); + } +} + +/* Find the ASCONF_ACK whose serial number matches ASCONF */ +struct sctp_chunk *sctp_assoc_lookup_asconf_ack( + const struct sctp_association *asoc, + __be32 serial) +{ + struct sctp_chunk *ack; + + /* Walk through the list of cached ASCONF-ACKs and find the + * ack chunk whose serial number matches that of the request. + */ + list_for_each_entry(ack, &asoc->asconf_ack_list, transmitted_list) { + if (sctp_chunk_pending(ack)) + continue; + if (ack->subh.addip_hdr->serial == serial) { + sctp_chunk_hold(ack); + return ack; + } + } + + return NULL; +} + +void sctp_asconf_queue_teardown(struct sctp_association *asoc) +{ + /* Free any cached ASCONF_ACK chunk. */ + sctp_assoc_free_asconf_acks(asoc); + + /* Free the ASCONF queue. */ + sctp_assoc_free_asconf_queue(asoc); + + /* Free any cached ASCONF chunk. */ + if (asoc->addip_last_asconf) + sctp_chunk_free(asoc->addip_last_asconf); +} diff --git a/net/sctp/auth.c b/net/sctp/auth.c new file mode 100644 index 000000000..349641455 --- /dev/null +++ b/net/sctp/auth.c @@ -0,0 +1,1089 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright 2007 Hewlett-Packard Development Company, L.P. + * + * This file is part of the SCTP kernel implementation + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Vlad Yasevich + */ + +#include +#include +#include +#include +#include +#include + +static struct sctp_hmac sctp_hmac_list[SCTP_AUTH_NUM_HMACS] = { + { + /* id 0 is reserved. 
as all 0 */ + .hmac_id = SCTP_AUTH_HMAC_ID_RESERVED_0, + }, + { + .hmac_id = SCTP_AUTH_HMAC_ID_SHA1, + .hmac_name = "hmac(sha1)", + .hmac_len = SCTP_SHA1_SIG_SIZE, + }, + { + /* id 2 is reserved as well */ + .hmac_id = SCTP_AUTH_HMAC_ID_RESERVED_2, + }, +#if IS_ENABLED(CONFIG_CRYPTO_SHA256) + { + .hmac_id = SCTP_AUTH_HMAC_ID_SHA256, + .hmac_name = "hmac(sha256)", + .hmac_len = SCTP_SHA256_SIG_SIZE, + } +#endif +}; + + +void sctp_auth_key_put(struct sctp_auth_bytes *key) +{ + if (!key) + return; + + if (refcount_dec_and_test(&key->refcnt)) { + kfree_sensitive(key); + SCTP_DBG_OBJCNT_DEC(keys); + } +} + +/* Create a new key structure of a given length */ +static struct sctp_auth_bytes *sctp_auth_create_key(__u32 key_len, gfp_t gfp) +{ + struct sctp_auth_bytes *key; + + /* Verify that we are not going to overflow INT_MAX */ + if (key_len > (INT_MAX - sizeof(struct sctp_auth_bytes))) + return NULL; + + /* Allocate the shared key */ + key = kmalloc(sizeof(struct sctp_auth_bytes) + key_len, gfp); + if (!key) + return NULL; + + key->len = key_len; + refcount_set(&key->refcnt, 1); + SCTP_DBG_OBJCNT_INC(keys); + + return key; +} + +/* Create a new shared key container with a give key id */ +struct sctp_shared_key *sctp_auth_shkey_create(__u16 key_id, gfp_t gfp) +{ + struct sctp_shared_key *new; + + /* Allocate the shared key container */ + new = kzalloc(sizeof(struct sctp_shared_key), gfp); + if (!new) + return NULL; + + INIT_LIST_HEAD(&new->key_list); + refcount_set(&new->refcnt, 1); + new->key_id = key_id; + + return new; +} + +/* Free the shared key structure */ +static void sctp_auth_shkey_destroy(struct sctp_shared_key *sh_key) +{ + BUG_ON(!list_empty(&sh_key->key_list)); + sctp_auth_key_put(sh_key->key); + sh_key->key = NULL; + kfree(sh_key); +} + +void sctp_auth_shkey_release(struct sctp_shared_key *sh_key) +{ + if (refcount_dec_and_test(&sh_key->refcnt)) + sctp_auth_shkey_destroy(sh_key); +} + +void sctp_auth_shkey_hold(struct sctp_shared_key *sh_key) +{ + refcount_inc(&sh_key->refcnt); +} + +/* Destroy the entire key list. This is done during the + * associon and endpoint free process. + */ +void sctp_auth_destroy_keys(struct list_head *keys) +{ + struct sctp_shared_key *ep_key; + struct sctp_shared_key *tmp; + + if (list_empty(keys)) + return; + + key_for_each_safe(ep_key, tmp, keys) { + list_del_init(&ep_key->key_list); + sctp_auth_shkey_release(ep_key); + } +} + +/* Compare two byte vectors as numbers. Return values + * are: + * 0 - vectors are equal + * < 0 - vector 1 is smaller than vector2 + * > 0 - vector 1 is greater than vector2 + * + * Algorithm is: + * This is performed by selecting the numerically smaller key vector... + * If the key vectors are equal as numbers but differ in length ... + * the shorter vector is considered smaller + * + * Examples (with small values): + * 000123456789 > 123456789 (first number is longer) + * 000123456789 < 234567891 (second number is larger numerically) + * 123456789 > 2345678 (first number is both larger & longer) + */ +static int sctp_auth_compare_vectors(struct sctp_auth_bytes *vector1, + struct sctp_auth_bytes *vector2) +{ + int diff; + int i; + const __u8 *longer; + + diff = vector1->len - vector2->len; + if (diff) { + longer = (diff > 0) ? vector1->data : vector2->data; + + /* Check to see if the longer number is + * lead-zero padded. If it is not, it + * is automatically larger numerically. 
+ */ + for (i = 0; i < abs(diff); i++) { + if (longer[i] != 0) + return diff; + } + } + + /* lengths are the same, compare numbers */ + return memcmp(vector1->data, vector2->data, vector1->len); +} + +/* + * Create a key vector as described in SCTP-AUTH, Section 6.1 + * The RANDOM parameter, the CHUNKS parameter and the HMAC-ALGO + * parameter sent by each endpoint are concatenated as byte vectors. + * These parameters include the parameter type, parameter length, and + * the parameter value, but padding is omitted; all padding MUST be + * removed from this concatenation before proceeding with further + * computation of keys. Parameters which were not sent are simply + * omitted from the concatenation process. The resulting two vectors + * are called the two key vectors. + */ +static struct sctp_auth_bytes *sctp_auth_make_key_vector( + struct sctp_random_param *random, + struct sctp_chunks_param *chunks, + struct sctp_hmac_algo_param *hmacs, + gfp_t gfp) +{ + struct sctp_auth_bytes *new; + __u32 len; + __u32 offset = 0; + __u16 random_len, hmacs_len, chunks_len = 0; + + random_len = ntohs(random->param_hdr.length); + hmacs_len = ntohs(hmacs->param_hdr.length); + if (chunks) + chunks_len = ntohs(chunks->param_hdr.length); + + len = random_len + hmacs_len + chunks_len; + + new = sctp_auth_create_key(len, gfp); + if (!new) + return NULL; + + memcpy(new->data, random, random_len); + offset += random_len; + + if (chunks) { + memcpy(new->data + offset, chunks, chunks_len); + offset += chunks_len; + } + + memcpy(new->data + offset, hmacs, hmacs_len); + + return new; +} + + +/* Make a key vector based on our local parameters */ +static struct sctp_auth_bytes *sctp_auth_make_local_vector( + const struct sctp_association *asoc, + gfp_t gfp) +{ + return sctp_auth_make_key_vector( + (struct sctp_random_param *)asoc->c.auth_random, + (struct sctp_chunks_param *)asoc->c.auth_chunks, + (struct sctp_hmac_algo_param *)asoc->c.auth_hmacs, gfp); +} + +/* Make a key vector based on peer's parameters */ +static struct sctp_auth_bytes *sctp_auth_make_peer_vector( + const struct sctp_association *asoc, + gfp_t gfp) +{ + return sctp_auth_make_key_vector(asoc->peer.peer_random, + asoc->peer.peer_chunks, + asoc->peer.peer_hmacs, + gfp); +} + + +/* Set the value of the association shared key base on the parameters + * given. The algorithm is: + * From the endpoint pair shared keys and the key vectors the + * association shared keys are computed. This is performed by selecting + * the numerically smaller key vector and concatenating it to the + * endpoint pair shared key, and then concatenating the numerically + * larger key vector to that. The result of the concatenation is the + * association shared key. + */ +static struct sctp_auth_bytes *sctp_auth_asoc_set_secret( + struct sctp_shared_key *ep_key, + struct sctp_auth_bytes *first_vector, + struct sctp_auth_bytes *last_vector, + gfp_t gfp) +{ + struct sctp_auth_bytes *secret; + __u32 offset = 0; + __u32 auth_len; + + auth_len = first_vector->len + last_vector->len; + if (ep_key->key) + auth_len += ep_key->key->len; + + secret = sctp_auth_create_key(auth_len, gfp); + if (!secret) + return NULL; + + if (ep_key->key) { + memcpy(secret->data, ep_key->key->data, ep_key->key->len); + offset += ep_key->key->len; + } + + memcpy(secret->data + offset, first_vector->data, first_vector->len); + offset += first_vector->len; + + memcpy(secret->data + offset, last_vector->data, last_vector->len); + + return secret; +} + +/* Create an association shared key. 
Follow the algorithm + * described in SCTP-AUTH, Section 6.1 + */ +static struct sctp_auth_bytes *sctp_auth_asoc_create_secret( + const struct sctp_association *asoc, + struct sctp_shared_key *ep_key, + gfp_t gfp) +{ + struct sctp_auth_bytes *local_key_vector; + struct sctp_auth_bytes *peer_key_vector; + struct sctp_auth_bytes *first_vector, + *last_vector; + struct sctp_auth_bytes *secret = NULL; + int cmp; + + + /* Now we need to build the key vectors + * SCTP-AUTH , Section 6.1 + * The RANDOM parameter, the CHUNKS parameter and the HMAC-ALGO + * parameter sent by each endpoint are concatenated as byte vectors. + * These parameters include the parameter type, parameter length, and + * the parameter value, but padding is omitted; all padding MUST be + * removed from this concatenation before proceeding with further + * computation of keys. Parameters which were not sent are simply + * omitted from the concatenation process. The resulting two vectors + * are called the two key vectors. + */ + + local_key_vector = sctp_auth_make_local_vector(asoc, gfp); + peer_key_vector = sctp_auth_make_peer_vector(asoc, gfp); + + if (!peer_key_vector || !local_key_vector) + goto out; + + /* Figure out the order in which the key_vectors will be + * added to the endpoint shared key. + * SCTP-AUTH, Section 6.1: + * This is performed by selecting the numerically smaller key + * vector and concatenating it to the endpoint pair shared + * key, and then concatenating the numerically larger key + * vector to that. If the key vectors are equal as numbers + * but differ in length, then the concatenation order is the + * endpoint shared key, followed by the shorter key vector, + * followed by the longer key vector. Otherwise, the key + * vectors are identical, and may be concatenated to the + * endpoint pair key in any order. + */ + cmp = sctp_auth_compare_vectors(local_key_vector, + peer_key_vector); + if (cmp < 0) { + first_vector = local_key_vector; + last_vector = peer_key_vector; + } else { + first_vector = peer_key_vector; + last_vector = local_key_vector; + } + + secret = sctp_auth_asoc_set_secret(ep_key, first_vector, last_vector, + gfp); +out: + sctp_auth_key_put(local_key_vector); + sctp_auth_key_put(peer_key_vector); + + return secret; +} + +/* + * Populate the association overlay list with the list + * from the endpoint. + */ +int sctp_auth_asoc_copy_shkeys(const struct sctp_endpoint *ep, + struct sctp_association *asoc, + gfp_t gfp) +{ + struct sctp_shared_key *sh_key; + struct sctp_shared_key *new; + + BUG_ON(!list_empty(&asoc->endpoint_shared_keys)); + + key_for_each(sh_key, &ep->endpoint_shared_keys) { + new = sctp_auth_shkey_create(sh_key->key_id, gfp); + if (!new) + goto nomem; + + new->key = sh_key->key; + sctp_auth_key_hold(new->key); + list_add(&new->key_list, &asoc->endpoint_shared_keys); + } + + return 0; + +nomem: + sctp_auth_destroy_keys(&asoc->endpoint_shared_keys); + return -ENOMEM; +} + + +/* Public interface to create the association shared key. + * See code above for the algorithm. + */ +int sctp_auth_asoc_init_active_key(struct sctp_association *asoc, gfp_t gfp) +{ + struct sctp_auth_bytes *secret; + struct sctp_shared_key *ep_key; + struct sctp_chunk *chunk; + + /* If we don't support AUTH, or peer is not capable + * we don't need to do anything. + */ + if (!asoc->peer.auth_capable) + return 0; + + /* If the key_id is non-zero and we couldn't find an + * endpoint pair shared key, we can't compute the + * secret. + * For key_id 0, endpoint pair shared key is a NULL key. 
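The association shared key assembled by sctp_auth_asoc_set_secret() above is a plain concatenation: the endpoint-pair key, then the numerically smaller key vector, then the larger one. A minimal user-space sketch of that layout (concat_secret is a hypothetical helper, not kernel code):

    #include <stddef.h>
    #include <stdint.h>
    #include <string.h>

    /* association key = endpoint-pair key || smaller vector || larger vector */
    static size_t concat_secret(uint8_t *out,
                                const uint8_t *ep_key, size_t ep_len,
                                const uint8_t *small_vec, size_t small_len,
                                const uint8_t *large_vec, size_t large_len)
    {
        size_t off = 0;

        if (ep_len) {                   /* key_id 0 uses a NULL (empty) key */
            memcpy(out, ep_key, ep_len);
            off += ep_len;
        }
        memcpy(out + off, small_vec, small_len);
        off += small_len;
        memcpy(out + off, large_vec, large_len);
        off += large_len;

        return off;                     /* total length of the shared secret */
    }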
+ */ + ep_key = sctp_auth_get_shkey(asoc, asoc->active_key_id); + BUG_ON(!ep_key); + + secret = sctp_auth_asoc_create_secret(asoc, ep_key, gfp); + if (!secret) + return -ENOMEM; + + sctp_auth_key_put(asoc->asoc_shared_key); + asoc->asoc_shared_key = secret; + asoc->shkey = ep_key; + + /* Update send queue in case any chunk already in there now + * needs authenticating + */ + list_for_each_entry(chunk, &asoc->outqueue.out_chunk_list, list) { + if (sctp_auth_send_cid(chunk->chunk_hdr->type, asoc)) { + chunk->auth = 1; + if (!chunk->shkey) { + chunk->shkey = asoc->shkey; + sctp_auth_shkey_hold(chunk->shkey); + } + } + } + + return 0; +} + + +/* Find the endpoint pair shared key based on the key_id */ +struct sctp_shared_key *sctp_auth_get_shkey( + const struct sctp_association *asoc, + __u16 key_id) +{ + struct sctp_shared_key *key; + + /* First search associations set of endpoint pair shared keys */ + key_for_each(key, &asoc->endpoint_shared_keys) { + if (key->key_id == key_id) { + if (!key->deactivated) + return key; + break; + } + } + + return NULL; +} + +/* + * Initialize all the possible digest transforms that we can use. Right + * now, the supported digests are SHA1 and SHA256. We do this here once + * because of the restrictiong that transforms may only be allocated in + * user context. This forces us to pre-allocated all possible transforms + * at the endpoint init time. + */ +int sctp_auth_init_hmacs(struct sctp_endpoint *ep, gfp_t gfp) +{ + struct crypto_shash *tfm = NULL; + __u16 id; + + /* If the transforms are already allocated, we are done */ + if (ep->auth_hmacs) + return 0; + + /* Allocated the array of pointers to transorms */ + ep->auth_hmacs = kcalloc(SCTP_AUTH_NUM_HMACS, + sizeof(struct crypto_shash *), + gfp); + if (!ep->auth_hmacs) + return -ENOMEM; + + for (id = 0; id < SCTP_AUTH_NUM_HMACS; id++) { + + /* See is we support the id. Supported IDs have name and + * length fields set, so that we can allocated and use + * them. We can safely just check for name, for without the + * name, we can't allocate the TFM. + */ + if (!sctp_hmac_list[id].hmac_name) + continue; + + /* If this TFM has been allocated, we are all set */ + if (ep->auth_hmacs[id]) + continue; + + /* Allocate the ID */ + tfm = crypto_alloc_shash(sctp_hmac_list[id].hmac_name, 0, 0); + if (IS_ERR(tfm)) + goto out_err; + + ep->auth_hmacs[id] = tfm; + } + + return 0; + +out_err: + /* Clean up any successful allocations */ + sctp_auth_destroy_hmacs(ep->auth_hmacs); + ep->auth_hmacs = NULL; + return -ENOMEM; +} + +/* Destroy the hmac tfm array */ +void sctp_auth_destroy_hmacs(struct crypto_shash *auth_hmacs[]) +{ + int i; + + if (!auth_hmacs) + return; + + for (i = 0; i < SCTP_AUTH_NUM_HMACS; i++) { + crypto_free_shash(auth_hmacs[i]); + } + kfree(auth_hmacs); +} + + +struct sctp_hmac *sctp_auth_get_hmac(__u16 hmac_id) +{ + return &sctp_hmac_list[hmac_id]; +} + +/* Get an hmac description information that we can use to build + * the AUTH chunk + */ +struct sctp_hmac *sctp_auth_asoc_get_hmac(const struct sctp_association *asoc) +{ + struct sctp_hmac_algo_param *hmacs; + __u16 n_elt; + __u16 id = 0; + int i; + + /* If we have a default entry, use it */ + if (asoc->default_hmac_id) + return &sctp_hmac_list[asoc->default_hmac_id]; + + /* Since we do not have a default entry, find the first entry + * we support and return that. Do not cache that id. 
+ */ + hmacs = asoc->peer.peer_hmacs; + if (!hmacs) + return NULL; + + n_elt = (ntohs(hmacs->param_hdr.length) - + sizeof(struct sctp_paramhdr)) >> 1; + for (i = 0; i < n_elt; i++) { + id = ntohs(hmacs->hmac_ids[i]); + + /* Check the id is in the supported range. And + * see if we support the id. Supported IDs have name and + * length fields set, so that we can allocate and use + * them. We can safely just check for name, for without the + * name, we can't allocate the TFM. + */ + if (id > SCTP_AUTH_HMAC_ID_MAX || + !sctp_hmac_list[id].hmac_name) { + id = 0; + continue; + } + + break; + } + + if (id == 0) + return NULL; + + return &sctp_hmac_list[id]; +} + +static int __sctp_auth_find_hmacid(__be16 *hmacs, int n_elts, __be16 hmac_id) +{ + int found = 0; + int i; + + for (i = 0; i < n_elts; i++) { + if (hmac_id == hmacs[i]) { + found = 1; + break; + } + } + + return found; +} + +/* See if the HMAC_ID is one that we claim as supported */ +int sctp_auth_asoc_verify_hmac_id(const struct sctp_association *asoc, + __be16 hmac_id) +{ + struct sctp_hmac_algo_param *hmacs; + __u16 n_elt; + + if (!asoc) + return 0; + + hmacs = (struct sctp_hmac_algo_param *)asoc->c.auth_hmacs; + n_elt = (ntohs(hmacs->param_hdr.length) - + sizeof(struct sctp_paramhdr)) >> 1; + + return __sctp_auth_find_hmacid(hmacs->hmac_ids, n_elt, hmac_id); +} + + +/* Cache the default HMAC id. This to follow this text from SCTP-AUTH: + * Section 6.1: + * The receiver of a HMAC-ALGO parameter SHOULD use the first listed + * algorithm it supports. + */ +void sctp_auth_asoc_set_default_hmac(struct sctp_association *asoc, + struct sctp_hmac_algo_param *hmacs) +{ + struct sctp_endpoint *ep; + __u16 id; + int i; + int n_params; + + /* if the default id is already set, use it */ + if (asoc->default_hmac_id) + return; + + n_params = (ntohs(hmacs->param_hdr.length) - + sizeof(struct sctp_paramhdr)) >> 1; + ep = asoc->ep; + for (i = 0; i < n_params; i++) { + id = ntohs(hmacs->hmac_ids[i]); + + /* Check the id is in the supported range */ + if (id > SCTP_AUTH_HMAC_ID_MAX) + continue; + + /* If this TFM has been allocated, use this id */ + if (ep->auth_hmacs[id]) { + asoc->default_hmac_id = id; + break; + } + } +} + + +/* Check to see if the given chunk is supposed to be authenticated */ +static int __sctp_auth_cid(enum sctp_cid chunk, struct sctp_chunks_param *param) +{ + unsigned short len; + int found = 0; + int i; + + if (!param || param->param_hdr.length == 0) + return 0; + + len = ntohs(param->param_hdr.length) - sizeof(struct sctp_paramhdr); + + /* SCTP-AUTH, Section 3.2 + * The chunk types for INIT, INIT-ACK, SHUTDOWN-COMPLETE and AUTH + * chunks MUST NOT be listed in the CHUNKS parameter. However, if + * a CHUNKS parameter is received then the types for INIT, INIT-ACK, + * SHUTDOWN-COMPLETE and AUTH chunks MUST be ignored. + */ + for (i = 0; !found && i < len; i++) { + switch (param->chunks[i]) { + case SCTP_CID_INIT: + case SCTP_CID_INIT_ACK: + case SCTP_CID_SHUTDOWN_COMPLETE: + case SCTP_CID_AUTH: + break; + + default: + if (param->chunks[i] == chunk) + found = 1; + break; + } + } + + return found; +} + +/* Check if peer requested that this chunk is authenticated */ +int sctp_auth_send_cid(enum sctp_cid chunk, const struct sctp_association *asoc) +{ + if (!asoc) + return 0; + + if (!asoc->peer.auth_capable) + return 0; + + return __sctp_auth_cid(chunk, asoc->peer.peer_chunks); +} + +/* Check if we requested that peer authenticate this chunk. 
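
The peer_chunks list consulted by sctp_auth_send_cid() comes from the peer's CHUNKS parameter; the local equivalent is filled in by applications through the SCTP_AUTH_CHUNK socket option, which is backed by sctp_auth_ep_add_chunkid() later in this file. A hedged user-space sketch, assuming the lksctp-tools/uapi headers and net.sctp.auth_enable=1:

    #include <netinet/in.h>
    #include <netinet/sctp.h>
    #include <stdio.h>
    #include <string.h>
    #include <sys/socket.h>

    int main(void)
    {
        struct sctp_authchunk ac;
        int fd;

        fd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        /* Ask that the peer authenticate DATA chunks (chunk type 0).
         * INIT, INIT-ACK, SHUTDOWN-COMPLETE and AUTH itself are rejected,
         * per the Section 3.2 rule quoted above.
         */
        memset(&ac, 0, sizeof(ac));
        ac.sauth_chunk = 0;                       /* SCTP_CID_DATA */
        if (setsockopt(fd, IPPROTO_SCTP, SCTP_AUTH_CHUNK, &ac, sizeof(ac)) < 0)
            perror("setsockopt(SCTP_AUTH_CHUNK)");

        return 0;
    }
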
*/ +int sctp_auth_recv_cid(enum sctp_cid chunk, const struct sctp_association *asoc) +{ + if (!asoc) + return 0; + + if (!asoc->peer.auth_capable) + return 0; + + return __sctp_auth_cid(chunk, + (struct sctp_chunks_param *)asoc->c.auth_chunks); +} + +/* SCTP-AUTH: Section 6.2: + * The sender MUST calculate the MAC as described in RFC2104 [2] using + * the hash function H as described by the MAC Identifier and the shared + * association key K based on the endpoint pair shared key described by + * the shared key identifier. The 'data' used for the computation of + * the AUTH-chunk is given by the AUTH chunk with its HMAC field set to + * zero (as shown in Figure 6) followed by all chunks that are placed + * after the AUTH chunk in the SCTP packet. + */ +void sctp_auth_calculate_hmac(const struct sctp_association *asoc, + struct sk_buff *skb, struct sctp_auth_chunk *auth, + struct sctp_shared_key *ep_key, gfp_t gfp) +{ + struct sctp_auth_bytes *asoc_key; + struct crypto_shash *tfm; + __u16 key_id, hmac_id; + unsigned char *end; + int free_key = 0; + __u8 *digest; + + /* Extract the info we need: + * - hmac id + * - key id + */ + key_id = ntohs(auth->auth_hdr.shkey_id); + hmac_id = ntohs(auth->auth_hdr.hmac_id); + + if (key_id == asoc->active_key_id) + asoc_key = asoc->asoc_shared_key; + else { + /* ep_key can't be NULL here */ + asoc_key = sctp_auth_asoc_create_secret(asoc, ep_key, gfp); + if (!asoc_key) + return; + + free_key = 1; + } + + /* set up scatter list */ + end = skb_tail_pointer(skb); + + tfm = asoc->ep->auth_hmacs[hmac_id]; + + digest = auth->auth_hdr.hmac; + if (crypto_shash_setkey(tfm, &asoc_key->data[0], asoc_key->len)) + goto free; + + crypto_shash_tfm_digest(tfm, (u8 *)auth, end - (unsigned char *)auth, + digest); + +free: + if (free_key) + sctp_auth_key_put(asoc_key); +} + +/* API Helpers */ + +/* Add a chunk to the endpoint authenticated chunk list */ +int sctp_auth_ep_add_chunkid(struct sctp_endpoint *ep, __u8 chunk_id) +{ + struct sctp_chunks_param *p = ep->auth_chunk_list; + __u16 nchunks; + __u16 param_len; + + /* If this chunk is already specified, we are done */ + if (__sctp_auth_cid(chunk_id, p)) + return 0; + + /* Check if we can add this chunk to the array */ + param_len = ntohs(p->param_hdr.length); + nchunks = param_len - sizeof(struct sctp_paramhdr); + if (nchunks == SCTP_NUM_CHUNK_TYPES) + return -EINVAL; + + p->chunks[nchunks] = chunk_id; + p->param_hdr.length = htons(param_len + 1); + return 0; +} + +/* Add hmac identifires to the endpoint list of supported hmac ids */ +int sctp_auth_ep_set_hmacs(struct sctp_endpoint *ep, + struct sctp_hmacalgo *hmacs) +{ + int has_sha1 = 0; + __u16 id; + int i; + + /* Scan the list looking for unsupported id. Also make sure that + * SHA1 is specified. + */ + for (i = 0; i < hmacs->shmac_num_idents; i++) { + id = hmacs->shmac_idents[i]; + + if (id > SCTP_AUTH_HMAC_ID_MAX) + return -EOPNOTSUPP; + + if (SCTP_AUTH_HMAC_ID_SHA1 == id) + has_sha1 = 1; + + if (!sctp_hmac_list[id].hmac_name) + return -EOPNOTSUPP; + } + + if (!has_sha1) + return -EINVAL; + + for (i = 0; i < hmacs->shmac_num_idents; i++) + ep->auth_hmacs_list->hmac_ids[i] = + htons(hmacs->shmac_idents[i]); + ep->auth_hmacs_list->param_hdr.length = + htons(sizeof(struct sctp_paramhdr) + + hmacs->shmac_num_idents * sizeof(__u16)); + return 0; +} + +/* Set a new shared key on either endpoint or association. If the + * key with a same ID already exists, replace the key (remove the + * old key and add a new one). 
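
sctp_auth_set_key() below is reached from the SCTP_AUTH_KEY socket option, and sctp_auth_set_active_key() backs SCTP_AUTH_ACTIVE_KEY. A hedged user-space sketch of installing and then activating a key; the key material and key number are arbitrary examples, and the uapi struct layouts are assumed:

    #include <netinet/in.h>
    #include <netinet/sctp.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <sys/socket.h>

    int main(void)
    {
        const unsigned char secret[] = "example-shared-secret"; /* arbitrary */
        struct sctp_authkeyid active;
        struct sctp_authkey *key;
        size_t len;
        int fd;

        fd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        /* Install key number 1 on the endpoint (sca_assoc_id left at 0). */
        len = sizeof(*key) + sizeof(secret);
        key = calloc(1, len);
        if (!key)
            return 1;
        key->sca_keynumber = 1;
        key->sca_keylength = sizeof(secret);
        memcpy(key->sca_key, secret, sizeof(secret));
        if (setsockopt(fd, IPPROTO_SCTP, SCTP_AUTH_KEY, key, len) < 0)
            perror("setsockopt(SCTP_AUTH_KEY)");

        /* Make key 1 the active key used when sending AUTH chunks. */
        memset(&active, 0, sizeof(active));
        active.scact_keynumber = 1;
        if (setsockopt(fd, IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY,
                       &active, sizeof(active)) < 0)
            perror("setsockopt(SCTP_AUTH_ACTIVE_KEY)");

        free(key);
        return 0;
    }
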
+ */ +int sctp_auth_set_key(struct sctp_endpoint *ep, + struct sctp_association *asoc, + struct sctp_authkey *auth_key) +{ + struct sctp_shared_key *cur_key, *shkey; + struct sctp_auth_bytes *key; + struct list_head *sh_keys; + int replace = 0; + + /* Try to find the given key id to see if + * we are doing a replace, or adding a new key + */ + if (asoc) { + if (!asoc->peer.auth_capable) + return -EACCES; + sh_keys = &asoc->endpoint_shared_keys; + } else { + if (!ep->auth_enable) + return -EACCES; + sh_keys = &ep->endpoint_shared_keys; + } + + key_for_each(shkey, sh_keys) { + if (shkey->key_id == auth_key->sca_keynumber) { + replace = 1; + break; + } + } + + cur_key = sctp_auth_shkey_create(auth_key->sca_keynumber, GFP_KERNEL); + if (!cur_key) + return -ENOMEM; + + /* Create a new key data based on the info passed in */ + key = sctp_auth_create_key(auth_key->sca_keylength, GFP_KERNEL); + if (!key) { + kfree(cur_key); + return -ENOMEM; + } + + memcpy(key->data, &auth_key->sca_key[0], auth_key->sca_keylength); + cur_key->key = key; + + if (!replace) { + list_add(&cur_key->key_list, sh_keys); + return 0; + } + + list_del_init(&shkey->key_list); + list_add(&cur_key->key_list, sh_keys); + + if (asoc && asoc->active_key_id == auth_key->sca_keynumber && + sctp_auth_asoc_init_active_key(asoc, GFP_KERNEL)) { + list_del_init(&cur_key->key_list); + sctp_auth_shkey_release(cur_key); + list_add(&shkey->key_list, sh_keys); + return -ENOMEM; + } + + sctp_auth_shkey_release(shkey); + return 0; +} + +int sctp_auth_set_active_key(struct sctp_endpoint *ep, + struct sctp_association *asoc, + __u16 key_id) +{ + struct sctp_shared_key *key; + struct list_head *sh_keys; + int found = 0; + + /* The key identifier MUST correst to an existing key */ + if (asoc) { + if (!asoc->peer.auth_capable) + return -EACCES; + sh_keys = &asoc->endpoint_shared_keys; + } else { + if (!ep->auth_enable) + return -EACCES; + sh_keys = &ep->endpoint_shared_keys; + } + + key_for_each(key, sh_keys) { + if (key->key_id == key_id) { + found = 1; + break; + } + } + + if (!found || key->deactivated) + return -EINVAL; + + if (asoc) { + __u16 active_key_id = asoc->active_key_id; + + asoc->active_key_id = key_id; + if (sctp_auth_asoc_init_active_key(asoc, GFP_KERNEL)) { + asoc->active_key_id = active_key_id; + return -ENOMEM; + } + } else + ep->active_key_id = key_id; + + return 0; +} + +int sctp_auth_del_key_id(struct sctp_endpoint *ep, + struct sctp_association *asoc, + __u16 key_id) +{ + struct sctp_shared_key *key; + struct list_head *sh_keys; + int found = 0; + + /* The key identifier MUST NOT be the current active key + * The key identifier MUST correst to an existing key + */ + if (asoc) { + if (!asoc->peer.auth_capable) + return -EACCES; + if (asoc->active_key_id == key_id) + return -EINVAL; + + sh_keys = &asoc->endpoint_shared_keys; + } else { + if (!ep->auth_enable) + return -EACCES; + if (ep->active_key_id == key_id) + return -EINVAL; + + sh_keys = &ep->endpoint_shared_keys; + } + + key_for_each(key, sh_keys) { + if (key->key_id == key_id) { + found = 1; + break; + } + } + + if (!found) + return -EINVAL; + + /* Delete the shared key */ + list_del_init(&key->key_list); + sctp_auth_shkey_release(key); + + return 0; +} + +int sctp_auth_deact_key_id(struct sctp_endpoint *ep, + struct sctp_association *asoc, __u16 key_id) +{ + struct sctp_shared_key *key; + struct list_head *sh_keys; + int found = 0; + + /* The key identifier MUST NOT be the current active key + * The key identifier MUST correst to an existing key + */ + if (asoc) { + if 
(!asoc->peer.auth_capable) + return -EACCES; + if (asoc->active_key_id == key_id) + return -EINVAL; + + sh_keys = &asoc->endpoint_shared_keys; + } else { + if (!ep->auth_enable) + return -EACCES; + if (ep->active_key_id == key_id) + return -EINVAL; + + sh_keys = &ep->endpoint_shared_keys; + } + + key_for_each(key, sh_keys) { + if (key->key_id == key_id) { + found = 1; + break; + } + } + + if (!found) + return -EINVAL; + + /* refcnt == 1 and !list_empty mean it's not being used anywhere + * and deactivated will be set, so it's time to notify userland + * that this shkey can be freed. + */ + if (asoc && !list_empty(&key->key_list) && + refcount_read(&key->refcnt) == 1) { + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_authkey(asoc, key->key_id, + SCTP_AUTH_FREE_KEY, GFP_KERNEL); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + } + + key->deactivated = 1; + + return 0; +} + +int sctp_auth_init(struct sctp_endpoint *ep, gfp_t gfp) +{ + int err = -ENOMEM; + + /* Allocate space for HMACS and CHUNKS authentication + * variables. There are arrays that we encode directly + * into parameters to make the rest of the operations easier. + */ + if (!ep->auth_hmacs_list) { + struct sctp_hmac_algo_param *auth_hmacs; + + auth_hmacs = kzalloc(struct_size(auth_hmacs, hmac_ids, + SCTP_AUTH_NUM_HMACS), gfp); + if (!auth_hmacs) + goto nomem; + /* Initialize the HMACS parameter. + * SCTP-AUTH: Section 3.3 + * Every endpoint supporting SCTP chunk authentication MUST + * support the HMAC based on the SHA-1 algorithm. + */ + auth_hmacs->param_hdr.type = SCTP_PARAM_HMAC_ALGO; + auth_hmacs->param_hdr.length = + htons(sizeof(struct sctp_paramhdr) + 2); + auth_hmacs->hmac_ids[0] = htons(SCTP_AUTH_HMAC_ID_SHA1); + ep->auth_hmacs_list = auth_hmacs; + } + + if (!ep->auth_chunk_list) { + struct sctp_chunks_param *auth_chunks; + + auth_chunks = kzalloc(sizeof(*auth_chunks) + + SCTP_NUM_CHUNK_TYPES, gfp); + if (!auth_chunks) + goto nomem; + /* Initialize the CHUNKS parameter */ + auth_chunks->param_hdr.type = SCTP_PARAM_CHUNKS; + auth_chunks->param_hdr.length = + htons(sizeof(struct sctp_paramhdr)); + ep->auth_chunk_list = auth_chunks; + } + + /* Allocate and initialize transorms arrays for supported + * HMACs. + */ + err = sctp_auth_init_hmacs(ep, gfp); + if (err) + goto nomem; + + return 0; + +nomem: + /* Free all allocations */ + kfree(ep->auth_hmacs_list); + kfree(ep->auth_chunk_list); + ep->auth_hmacs_list = NULL; + ep->auth_chunk_list = NULL; + return err; +} + +void sctp_auth_free(struct sctp_endpoint *ep) +{ + kfree(ep->auth_hmacs_list); + kfree(ep->auth_chunk_list); + ep->auth_hmacs_list = NULL; + ep->auth_chunk_list = NULL; + sctp_auth_destroy_hmacs(ep->auth_hmacs); + ep->auth_hmacs = NULL; +} diff --git a/net/sctp/bind_addr.c b/net/sctp/bind_addr.c new file mode 100644 index 000000000..6b95d3ba8 --- /dev/null +++ b/net/sctp/bind_addr.c @@ -0,0 +1,575 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2003 + * Copyright (c) Cisco 1999,2000 + * Copyright (c) Motorola 1999,2000,2001 + * Copyright (c) La Monte H.P. Yarroll 2001 + * + * This file is part of the SCTP kernel implementation. + * + * A collection class to handle the storage of transport addresses. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. 
Yarroll + * Karl Knutson + * Jon Grimm + * Daisy Chang + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* Forward declarations for internal helpers. */ +static int sctp_copy_one_addr(struct net *net, struct sctp_bind_addr *dest, + union sctp_addr *addr, enum sctp_scope scope, + gfp_t gfp, int flags); +static void sctp_bind_addr_clean(struct sctp_bind_addr *); + +/* First Level Abstractions. */ + +/* Copy 'src' to 'dest' taking 'scope' into account. Omit addresses + * in 'src' which have a broader scope than 'scope'. + */ +int sctp_bind_addr_copy(struct net *net, struct sctp_bind_addr *dest, + const struct sctp_bind_addr *src, + enum sctp_scope scope, gfp_t gfp, + int flags) +{ + struct sctp_sockaddr_entry *addr; + int error = 0; + + /* All addresses share the same port. */ + dest->port = src->port; + + /* Extract the addresses which are relevant for this scope. */ + list_for_each_entry(addr, &src->address_list, list) { + error = sctp_copy_one_addr(net, dest, &addr->a, scope, + gfp, flags); + if (error < 0) + goto out; + } + + /* If there are no addresses matching the scope and + * this is global scope, try to get a link scope address, with + * the assumption that we must be sitting behind a NAT. + */ + if (list_empty(&dest->address_list) && (SCTP_SCOPE_GLOBAL == scope)) { + list_for_each_entry(addr, &src->address_list, list) { + error = sctp_copy_one_addr(net, dest, &addr->a, + SCTP_SCOPE_LINK, gfp, + flags); + if (error < 0) + goto out; + } + } + + /* If somehow no addresses were found that can be used with this + * scope, it's an error. + */ + if (list_empty(&dest->address_list)) + error = -ENETUNREACH; + +out: + if (error) + sctp_bind_addr_clean(dest); + + return error; +} + +/* Exactly duplicate the address lists. This is necessary when doing + * peer-offs and accepts. We don't want to put all the current system + * addresses into the endpoint. That's useless. But we do want duplicat + * the list of bound addresses that the older endpoint used. + */ +int sctp_bind_addr_dup(struct sctp_bind_addr *dest, + const struct sctp_bind_addr *src, + gfp_t gfp) +{ + struct sctp_sockaddr_entry *addr; + int error = 0; + + /* All addresses share the same port. */ + dest->port = src->port; + + list_for_each_entry(addr, &src->address_list, list) { + error = sctp_add_bind_addr(dest, &addr->a, sizeof(addr->a), + 1, gfp); + if (error < 0) + break; + } + + return error; +} + +/* Initialize the SCTP_bind_addr structure for either an endpoint or + * an association. + */ +void sctp_bind_addr_init(struct sctp_bind_addr *bp, __u16 port) +{ + INIT_LIST_HEAD(&bp->address_list); + bp->port = port; +} + +/* Dispose of the address list. */ +static void sctp_bind_addr_clean(struct sctp_bind_addr *bp) +{ + struct sctp_sockaddr_entry *addr, *temp; + + /* Empty the bind address list. */ + list_for_each_entry_safe(addr, temp, &bp->address_list, list) { + list_del_rcu(&addr->list); + kfree_rcu(addr, rcu); + SCTP_DBG_OBJCNT_DEC(addr); + } +} + +/* Dispose of an SCTP_bind_addr structure */ +void sctp_bind_addr_free(struct sctp_bind_addr *bp) +{ + /* Empty the bind address list. */ + sctp_bind_addr_clean(bp); +} + +/* Add an address to the bind address list in the SCTP_bind_addr structure. */ +int sctp_add_bind_addr(struct sctp_bind_addr *bp, union sctp_addr *new, + int new_size, __u8 addr_state, gfp_t gfp) +{ + struct sctp_sockaddr_entry *addr; + + /* Add the address to the bind address list. 
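
Entries on this list are what bind() and sctp_bindx() create from user space, one sctp_sockaddr_entry per bound address. A hedged sketch using the lksctp-tools helper (link with -lsctp); the addresses are documentation placeholders and the call will fail unless they are configured locally:

    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <netinet/sctp.h>
    #include <stdio.h>
    #include <string.h>
    #include <sys/socket.h>

    int main(void)
    {
        struct sockaddr_in addrs[2];
        int fd;

        fd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
        if (fd < 0) {
            perror("socket");
            return 1;
        }

        /* Two local addresses, same port, for a multi-homed endpoint. */
        memset(addrs, 0, sizeof(addrs));
        addrs[0].sin_family = AF_INET;
        addrs[0].sin_port = htons(5000);
        inet_pton(AF_INET, "192.0.2.1", &addrs[0].sin_addr);
        addrs[1].sin_family = AF_INET;
        addrs[1].sin_port = htons(5000);
        inet_pton(AF_INET, "198.51.100.1", &addrs[1].sin_addr);

        /* Both addresses end up on the endpoint's bind address list. */
        if (sctp_bindx(fd, (struct sockaddr *)addrs, 2,
                       SCTP_BINDX_ADD_ADDR) < 0)
            perror("sctp_bindx");

        return 0;
    }
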
*/ + addr = kzalloc(sizeof(*addr), gfp); + if (!addr) + return -ENOMEM; + + memcpy(&addr->a, new, min_t(size_t, sizeof(*new), new_size)); + + /* Fix up the port if it has not yet been set. + * Both v4 and v6 have the port at the same offset. + */ + if (!addr->a.v4.sin_port) + addr->a.v4.sin_port = htons(bp->port); + + addr->state = addr_state; + addr->valid = 1; + + INIT_LIST_HEAD(&addr->list); + + /* We always hold a socket lock when calling this function, + * and that acts as a writer synchronizing lock. + */ + list_add_tail_rcu(&addr->list, &bp->address_list); + SCTP_DBG_OBJCNT_INC(addr); + + return 0; +} + +/* Delete an address from the bind address list in the SCTP_bind_addr + * structure. + */ +int sctp_del_bind_addr(struct sctp_bind_addr *bp, union sctp_addr *del_addr) +{ + struct sctp_sockaddr_entry *addr, *temp; + int found = 0; + + /* We hold the socket lock when calling this function, + * and that acts as a writer synchronizing lock. + */ + list_for_each_entry_safe(addr, temp, &bp->address_list, list) { + if (sctp_cmp_addr_exact(&addr->a, del_addr)) { + /* Found the exact match. */ + found = 1; + addr->valid = 0; + list_del_rcu(&addr->list); + break; + } + } + + if (found) { + kfree_rcu(addr, rcu); + SCTP_DBG_OBJCNT_DEC(addr); + return 0; + } + + return -EINVAL; +} + +/* Create a network byte-order representation of all the addresses + * formated as SCTP parameters. + * + * The second argument is the return value for the length. + */ +union sctp_params sctp_bind_addrs_to_raw(const struct sctp_bind_addr *bp, + int *addrs_len, + gfp_t gfp) +{ + union sctp_params addrparms; + union sctp_params retval; + int addrparms_len; + union sctp_addr_param rawaddr; + int len; + struct sctp_sockaddr_entry *addr; + struct list_head *pos; + struct sctp_af *af; + + addrparms_len = 0; + len = 0; + + /* Allocate enough memory at once. */ + list_for_each(pos, &bp->address_list) { + len += sizeof(union sctp_addr_param); + } + + /* Don't even bother embedding an address if there + * is only one. + */ + if (len == sizeof(union sctp_addr_param)) { + retval.v = NULL; + goto end_raw; + } + + retval.v = kmalloc(len, gfp); + if (!retval.v) + goto end_raw; + + addrparms = retval; + + list_for_each_entry(addr, &bp->address_list, list) { + af = sctp_get_af_specific(addr->a.v4.sin_family); + len = af->to_addr_param(&addr->a, &rawaddr); + memcpy(addrparms.v, &rawaddr, len); + addrparms.v += len; + addrparms_len += len; + } + +end_raw: + *addrs_len = addrparms_len; + return retval; +} + +/* + * Create an address list out of the raw address list format (IPv4 and IPv6 + * address parameters). + */ +int sctp_raw_to_bind_addrs(struct sctp_bind_addr *bp, __u8 *raw_addr_list, + int addrs_len, __u16 port, gfp_t gfp) +{ + union sctp_addr_param *rawaddr; + struct sctp_paramhdr *param; + union sctp_addr addr; + int retval = 0; + int len; + struct sctp_af *af; + + /* Convert the raw address to standard address format */ + while (addrs_len) { + param = (struct sctp_paramhdr *)raw_addr_list; + rawaddr = (union sctp_addr_param *)raw_addr_list; + + af = sctp_get_af_specific(param_type2af(param->type)); + if (unlikely(!af) || + !af->from_addr_param(&addr, rawaddr, htons(port), 0)) { + retval = -EINVAL; + goto out_err; + } + + if (sctp_bind_addr_state(bp, &addr) != -1) + goto next; + retval = sctp_add_bind_addr(bp, &addr, sizeof(addr), + SCTP_ADDR_SRC, gfp); + if (retval) + /* Can't finish building the list, clean up. 
*/ + goto out_err; + +next: + len = ntohs(param->length); + addrs_len -= len; + raw_addr_list += len; + } + + return retval; + +out_err: + if (retval) + sctp_bind_addr_clean(bp); + + return retval; +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* Does this contain a specified address? Allow wildcarding. */ +int sctp_bind_addr_match(struct sctp_bind_addr *bp, + const union sctp_addr *addr, + struct sctp_sock *opt) +{ + struct sctp_sockaddr_entry *laddr; + int match = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + if (!laddr->valid) + continue; + if (opt->pf->cmp_addr(&laddr->a, addr, opt)) { + match = 1; + break; + } + } + rcu_read_unlock(); + + return match; +} + +int sctp_bind_addrs_check(struct sctp_sock *sp, + struct sctp_sock *sp2, int cnt2) +{ + struct sctp_bind_addr *bp2 = &sp2->ep->base.bind_addr; + struct sctp_bind_addr *bp = &sp->ep->base.bind_addr; + struct sctp_sockaddr_entry *laddr, *laddr2; + bool exist = false; + int cnt = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + list_for_each_entry_rcu(laddr2, &bp2->address_list, list) { + if (sp->pf->af->cmp_addr(&laddr->a, &laddr2->a) && + laddr->valid && laddr2->valid) { + exist = true; + goto next; + } + } + cnt = 0; + break; +next: + cnt++; + } + rcu_read_unlock(); + + return (cnt == cnt2) ? 0 : (exist ? -EEXIST : 1); +} + +/* Does the address 'addr' conflict with any addresses in + * the bp. + */ +int sctp_bind_addr_conflict(struct sctp_bind_addr *bp, + const union sctp_addr *addr, + struct sctp_sock *bp_sp, + struct sctp_sock *addr_sp) +{ + struct sctp_sockaddr_entry *laddr; + int conflict = 0; + struct sctp_sock *sp; + + /* Pick the IPv6 socket as the basis of comparison + * since it's usually a superset of the IPv4. + * If there is no IPv6 socket, then default to bind_addr. + */ + if (sctp_opt2sk(bp_sp)->sk_family == AF_INET6) + sp = bp_sp; + else if (sctp_opt2sk(addr_sp)->sk_family == AF_INET6) + sp = addr_sp; + else + sp = bp_sp; + + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + if (!laddr->valid) + continue; + + conflict = sp->pf->cmp_addr(&laddr->a, addr, sp); + if (conflict) + break; + } + rcu_read_unlock(); + + return conflict; +} + +/* Get the state of the entry in the bind_addr_list */ +int sctp_bind_addr_state(const struct sctp_bind_addr *bp, + const union sctp_addr *addr) +{ + struct sctp_sockaddr_entry *laddr; + struct sctp_af *af; + + af = sctp_get_af_specific(addr->sa.sa_family); + if (unlikely(!af)) + return -1; + + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + if (!laddr->valid) + continue; + if (af->cmp_addr(&laddr->a, addr)) + return laddr->state; + } + + return -1; +} + +/* Find the first address in the bind address list that is not present in + * the addrs packed array. + */ +union sctp_addr *sctp_find_unmatch_addr(struct sctp_bind_addr *bp, + const union sctp_addr *addrs, + int addrcnt, + struct sctp_sock *opt) +{ + struct sctp_sockaddr_entry *laddr; + union sctp_addr *addr; + void *addr_buf; + struct sctp_af *af; + int i; + + /* This is only called sctp_send_asconf_del_ip() and we hold + * the socket lock in that code patch, so that address list + * can't change. 
+ */ + list_for_each_entry(laddr, &bp->address_list, list) { + addr_buf = (union sctp_addr *)addrs; + for (i = 0; i < addrcnt; i++) { + addr = addr_buf; + af = sctp_get_af_specific(addr->v4.sin_family); + if (!af) + break; + + if (opt->pf->cmp_addr(&laddr->a, addr, opt)) + break; + + addr_buf += af->sockaddr_len; + } + if (i == addrcnt) + return &laddr->a; + } + + return NULL; +} + +/* Copy out addresses from the global local address list. */ +static int sctp_copy_one_addr(struct net *net, struct sctp_bind_addr *dest, + union sctp_addr *addr, enum sctp_scope scope, + gfp_t gfp, int flags) +{ + int error = 0; + + if (sctp_is_any(NULL, addr)) { + error = sctp_copy_local_addr_list(net, dest, scope, gfp, flags); + } else if (sctp_in_scope(net, addr, scope)) { + /* Now that the address is in scope, check to see if + * the address type is supported by local sock as + * well as the remote peer. + */ + if ((((AF_INET == addr->sa.sa_family) && + (flags & SCTP_ADDR4_ALLOWED) && + (flags & SCTP_ADDR4_PEERSUPP))) || + (((AF_INET6 == addr->sa.sa_family) && + (flags & SCTP_ADDR6_ALLOWED) && + (flags & SCTP_ADDR6_PEERSUPP)))) + error = sctp_add_bind_addr(dest, addr, sizeof(*addr), + SCTP_ADDR_SRC, gfp); + } + + return error; +} + +/* Is this a wildcard address? */ +int sctp_is_any(struct sock *sk, const union sctp_addr *addr) +{ + unsigned short fam = 0; + struct sctp_af *af; + + /* Try to get the right address family */ + if (addr->sa.sa_family != AF_UNSPEC) + fam = addr->sa.sa_family; + else if (sk) + fam = sk->sk_family; + + af = sctp_get_af_specific(fam); + if (!af) + return 0; + + return af->is_any(addr); +} + +/* Is 'addr' valid for 'scope'? */ +int sctp_in_scope(struct net *net, const union sctp_addr *addr, + enum sctp_scope scope) +{ + enum sctp_scope addr_scope = sctp_scope(addr); + + /* The unusable SCTP addresses will not be considered with + * any defined scopes. + */ + if (SCTP_SCOPE_UNUSABLE == addr_scope) + return 0; + /* + * For INIT and INIT-ACK address list, let L be the level of + * requested destination address, sender and receiver + * SHOULD include all of its addresses with level greater + * than or equal to L. + * + * Address scoping can be selectively controlled via sysctl + * option + */ + switch (net->sctp.scope_policy) { + case SCTP_SCOPE_POLICY_DISABLE: + return 1; + case SCTP_SCOPE_POLICY_ENABLE: + if (addr_scope <= scope) + return 1; + break; + case SCTP_SCOPE_POLICY_PRIVATE: + if (addr_scope <= scope || SCTP_SCOPE_PRIVATE == addr_scope) + return 1; + break; + case SCTP_SCOPE_POLICY_LINK: + if (addr_scope <= scope || SCTP_SCOPE_LINK == addr_scope) + return 1; + break; + default: + break; + } + + return 0; +} + +int sctp_is_ep_boundall(struct sock *sk) +{ + struct sctp_bind_addr *bp; + struct sctp_sockaddr_entry *addr; + + bp = &sctp_sk(sk)->ep->base.bind_addr; + if (sctp_list_single_entry(&bp->address_list)) { + addr = list_entry(bp->address_list.next, + struct sctp_sockaddr_entry, list); + if (sctp_is_any(sk, &addr->a)) + return 1; + } + return 0; +} + +/******************************************************************** + * 3rd Level Abstractions + ********************************************************************/ + +/* What is the scope of 'addr'? 
*/ +enum sctp_scope sctp_scope(const union sctp_addr *addr) +{ + struct sctp_af *af; + + af = sctp_get_af_specific(addr->sa.sa_family); + if (!af) + return SCTP_SCOPE_UNUSABLE; + + return af->scope((union sctp_addr *)addr); +} diff --git a/net/sctp/chunk.c b/net/sctp/chunk.c new file mode 100644 index 000000000..fd4f8243c --- /dev/null +++ b/net/sctp/chunk.c @@ -0,0 +1,353 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2003, 2004 + * + * This file is part of the SCTP kernel implementation + * + * This file contains the code relating the chunk abstraction. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Jon Grimm + * Sridhar Samudrala + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* This file is mostly in anticipation of future work, but initially + * populate with fragment tracking for an outbound message. + */ + +/* Initialize datamsg from memory. */ +static void sctp_datamsg_init(struct sctp_datamsg *msg) +{ + refcount_set(&msg->refcnt, 1); + msg->send_failed = 0; + msg->send_error = 0; + msg->can_delay = 1; + msg->abandoned = 0; + msg->expires_at = 0; + INIT_LIST_HEAD(&msg->chunks); +} + +/* Allocate and initialize datamsg. */ +static struct sctp_datamsg *sctp_datamsg_new(gfp_t gfp) +{ + struct sctp_datamsg *msg; + msg = kmalloc(sizeof(struct sctp_datamsg), gfp); + if (msg) { + sctp_datamsg_init(msg); + SCTP_DBG_OBJCNT_INC(datamsg); + } + return msg; +} + +void sctp_datamsg_free(struct sctp_datamsg *msg) +{ + struct sctp_chunk *chunk; + + /* This doesn't have to be a _safe vairant because + * sctp_chunk_free() only drops the refs. + */ + list_for_each_entry(chunk, &msg->chunks, frag_list) + sctp_chunk_free(chunk); + + sctp_datamsg_put(msg); +} + +/* Final destructruction of datamsg memory. */ +static void sctp_datamsg_destroy(struct sctp_datamsg *msg) +{ + struct sctp_association *asoc = NULL; + struct list_head *pos, *temp; + struct sctp_chunk *chunk; + struct sctp_ulpevent *ev; + int error, sent; + + /* Release all references. */ + list_for_each_safe(pos, temp, &msg->chunks) { + list_del_init(pos); + chunk = list_entry(pos, struct sctp_chunk, frag_list); + + if (!msg->send_failed) { + sctp_chunk_put(chunk); + continue; + } + + asoc = chunk->asoc; + error = msg->send_error ?: asoc->outqueue.error; + sent = chunk->has_tsn ? SCTP_DATA_SENT : SCTP_DATA_UNSENT; + + if (sctp_ulpevent_type_enabled(asoc->subscribe, + SCTP_SEND_FAILED)) { + ev = sctp_ulpevent_make_send_failed(asoc, chunk, sent, + error, GFP_ATOMIC); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + } + + if (sctp_ulpevent_type_enabled(asoc->subscribe, + SCTP_SEND_FAILED_EVENT)) { + ev = sctp_ulpevent_make_send_failed_event(asoc, chunk, + sent, error, + GFP_ATOMIC); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + } + + sctp_chunk_put(chunk); + } + + SCTP_DBG_OBJCNT_DEC(datamsg); + kfree(msg); +} + +/* Hold a reference. */ +static void sctp_datamsg_hold(struct sctp_datamsg *msg) +{ + refcount_inc(&msg->refcnt); +} + +/* Release a reference. */ +void sctp_datamsg_put(struct sctp_datamsg *msg) +{ + if (refcount_dec_and_test(&msg->refcnt)) + sctp_datamsg_destroy(msg); +} + +/* Assign a chunk to this datamsg. 
*/ +static void sctp_datamsg_assign(struct sctp_datamsg *msg, struct sctp_chunk *chunk) +{ + sctp_datamsg_hold(msg); + chunk->msg = msg; +} + + +/* A data chunk can have a maximum payload of (2^16 - 20). Break + * down any such message into smaller chunks. Opportunistically, fragment + * the chunks down to the current MTU constraints. We may get refragmented + * later if the PMTU changes, but it is _much better_ to fragment immediately + * with a reasonable guess than always doing our fragmentation on the + * soft-interrupt. + */ +struct sctp_datamsg *sctp_datamsg_from_user(struct sctp_association *asoc, + struct sctp_sndrcvinfo *sinfo, + struct iov_iter *from) +{ + size_t len, first_len, max_data, remaining; + size_t msg_len = iov_iter_count(from); + struct sctp_shared_key *shkey = NULL; + struct list_head *pos, *temp; + struct sctp_chunk *chunk; + struct sctp_datamsg *msg; + int err; + + msg = sctp_datamsg_new(GFP_KERNEL); + if (!msg) + return ERR_PTR(-ENOMEM); + + /* Note: Calculate this outside of the loop, so that all fragments + * have the same expiration. + */ + if (asoc->peer.prsctp_capable && sinfo->sinfo_timetolive && + (SCTP_PR_TTL_ENABLED(sinfo->sinfo_flags) || + !SCTP_PR_POLICY(sinfo->sinfo_flags))) + msg->expires_at = jiffies + + msecs_to_jiffies(sinfo->sinfo_timetolive); + + /* This is the biggest possible DATA chunk that can fit into + * the packet + */ + max_data = asoc->frag_point; + if (unlikely(!max_data)) { + max_data = sctp_min_frag_point(sctp_sk(asoc->base.sk), + sctp_datachk_len(&asoc->stream)); + pr_warn_ratelimited("%s: asoc:%p frag_point is zero, forcing max_data to default minimum (%zu)", + __func__, asoc, max_data); + } + + /* If the peer requested that we authenticate DATA chunks + * we need to account for bundling of the AUTH chunks along with + * DATA. + */ + if (sctp_auth_send_cid(SCTP_CID_DATA, asoc)) { + struct sctp_hmac *hmac_desc = sctp_auth_asoc_get_hmac(asoc); + + if (hmac_desc) + max_data -= SCTP_PAD4(sizeof(struct sctp_auth_chunk) + + hmac_desc->hmac_len); + + if (sinfo->sinfo_tsn && + sinfo->sinfo_ssn != asoc->active_key_id) { + shkey = sctp_auth_get_shkey(asoc, sinfo->sinfo_ssn); + if (!shkey) { + err = -EINVAL; + goto errout; + } + } else { + shkey = asoc->shkey; + } + } + + /* Set first_len and then account for possible bundles on first frag */ + first_len = max_data; + + /* Check to see if we have a pending SACK and try to let it be bundled + * with this message. Do this if we don't have any data queued already. + * To check that, look at out_qlen and retransmit list. + * NOTE: we will not reduce to account for SACK, if the message would + * not have been fragmented. + */ + if (timer_pending(&asoc->timers[SCTP_EVENT_TIMEOUT_SACK]) && + asoc->outqueue.out_qlen == 0 && + list_empty(&asoc->outqueue.retransmit) && + msg_len > max_data) + first_len -= SCTP_PAD4(sizeof(struct sctp_sack_chunk)); + + /* Encourage Cookie-ECHO bundling. */ + if (asoc->state < SCTP_STATE_COOKIE_ECHOED) + first_len -= SCTP_ARBITRARY_COOKIE_ECHO_LEN; + + /* Account for a different sized first fragment */ + if (msg_len >= first_len) { + msg->can_delay = 0; + if (msg_len > first_len) + SCTP_INC_STATS(asoc->base.net, + SCTP_MIB_FRAGUSRMSGS); + } else { + /* Which may be the only one... */ + first_len = msg_len; + } + + /* Create chunks for all DATA chunks. 
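
Before the fragmentation loop that follows, the sizing rules may be easier to see with concrete numbers. A small illustrative program; the figures are made up, and the real code sets FIRST_FRAG/LAST_FRAG flag bits rather than printing strings:

    #include <stdio.h>

    int main(void)
    {
        /* Illustrative values only: assume 1400 bytes of user data fit per
         * DATA chunk and the first fragment was not shrunk for SACK or
         * COOKIE-ECHO bundling.
         */
        size_t msg_len = 3000, max_data = 1400, first_len = 1400;
        size_t remaining, len;

        for (remaining = msg_len; remaining; remaining -= len) {
            const char *frag = "MIDDLE";

            len = (remaining == msg_len) ? first_len : max_data;
            if (remaining == msg_len)
                frag = "FIRST";
            if (len >= remaining) {
                len = remaining;
                frag = (msg_len <= first_len) ? "UNFRAGMENTED" : "LAST";
            }
            printf("%zu byte fragment (%s)\n", len, frag);
        }
        /* Prints: 1400 (FIRST), 1400 (MIDDLE), 200 (LAST). */
        return 0;
    }
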
*/ + for (remaining = msg_len; remaining; remaining -= len) { + u8 frag = SCTP_DATA_MIDDLE_FRAG; + + if (remaining == msg_len) { + /* First frag, which may also be the last */ + frag |= SCTP_DATA_FIRST_FRAG; + len = first_len; + } else { + /* Middle frags */ + len = max_data; + } + + if (len >= remaining) { + /* Last frag, which may also be the first */ + len = remaining; + frag |= SCTP_DATA_LAST_FRAG; + + /* The application requests to set the I-bit of the + * last DATA chunk of a user message when providing + * the user message to the SCTP implementation. + */ + if ((sinfo->sinfo_flags & SCTP_EOF) || + (sinfo->sinfo_flags & SCTP_SACK_IMMEDIATELY)) + frag |= SCTP_DATA_SACK_IMM; + } + + chunk = asoc->stream.si->make_datafrag(asoc, sinfo, len, frag, + GFP_KERNEL); + if (!chunk) { + err = -ENOMEM; + goto errout; + } + + err = sctp_user_addto_chunk(chunk, len, from); + if (err < 0) + goto errout_chunk_free; + + chunk->shkey = shkey; + + /* Put the chunk->skb back into the form expected by send. */ + __skb_pull(chunk->skb, (__u8 *)chunk->chunk_hdr - + chunk->skb->data); + + sctp_datamsg_assign(msg, chunk); + list_add_tail(&chunk->frag_list, &msg->chunks); + } + + return msg; + +errout_chunk_free: + sctp_chunk_free(chunk); + +errout: + list_for_each_safe(pos, temp, &msg->chunks) { + list_del_init(pos); + chunk = list_entry(pos, struct sctp_chunk, frag_list); + sctp_chunk_free(chunk); + } + sctp_datamsg_put(msg); + + return ERR_PTR(err); +} + +/* Check whether this message has expired. */ +int sctp_chunk_abandoned(struct sctp_chunk *chunk) +{ + if (!chunk->asoc->peer.prsctp_capable) + return 0; + + if (chunk->msg->abandoned) + return 1; + + if (!chunk->has_tsn && + !(chunk->chunk_hdr->flags & SCTP_DATA_FIRST_FRAG)) + return 0; + + if (SCTP_PR_TTL_ENABLED(chunk->sinfo.sinfo_flags) && + time_after(jiffies, chunk->msg->expires_at)) { + struct sctp_stream_out *streamout = + SCTP_SO(&chunk->asoc->stream, + chunk->sinfo.sinfo_stream); + + if (chunk->sent_count) { + chunk->asoc->abandoned_sent[SCTP_PR_INDEX(TTL)]++; + streamout->ext->abandoned_sent[SCTP_PR_INDEX(TTL)]++; + } else { + chunk->asoc->abandoned_unsent[SCTP_PR_INDEX(TTL)]++; + streamout->ext->abandoned_unsent[SCTP_PR_INDEX(TTL)]++; + } + chunk->msg->abandoned = 1; + return 1; + } else if (SCTP_PR_RTX_ENABLED(chunk->sinfo.sinfo_flags) && + chunk->sent_count > chunk->sinfo.sinfo_timetolive) { + struct sctp_stream_out *streamout = + SCTP_SO(&chunk->asoc->stream, + chunk->sinfo.sinfo_stream); + + chunk->asoc->abandoned_sent[SCTP_PR_INDEX(RTX)]++; + streamout->ext->abandoned_sent[SCTP_PR_INDEX(RTX)]++; + chunk->msg->abandoned = 1; + return 1; + } else if (!SCTP_PR_POLICY(chunk->sinfo.sinfo_flags) && + chunk->msg->expires_at && + time_after(jiffies, chunk->msg->expires_at)) { + chunk->msg->abandoned = 1; + return 1; + } + /* PRIO policy is processed by sendmsg, not here */ + + return 0; +} + +/* This chunk (and consequently entire message) has failed in its sending. */ +void sctp_chunk_fail(struct sctp_chunk *chunk, int error) +{ + chunk->msg->send_failed = 1; + chunk->msg->send_error = error; +} diff --git a/net/sctp/debug.c b/net/sctp/debug.c new file mode 100644 index 000000000..ccd773e4c --- /dev/null +++ b/net/sctp/debug.c @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. 
+ * + * This file is part of the SCTP kernel implementation + * + * This file converts numerical ID value to alphabetical names for SCTP + * terms such as chunk type, parameter time, event type, etc. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Xingang Guo + * Jon Grimm + * Daisy Chang + * Sridhar Samudrala + */ + +#include + +/* These are printable forms of Chunk ID's from section 3.1. */ +static const char *const sctp_cid_tbl[SCTP_NUM_BASE_CHUNK_TYPES] = { + "DATA", + "INIT", + "INIT_ACK", + "SACK", + "HEARTBEAT", + "HEARTBEAT_ACK", + "ABORT", + "SHUTDOWN", + "SHUTDOWN_ACK", + "ERROR", + "COOKIE_ECHO", + "COOKIE_ACK", + "ECN_ECNE", + "ECN_CWR", + "SHUTDOWN_COMPLETE", +}; + +/* Lookup "chunk type" debug name. */ +const char *sctp_cname(const union sctp_subtype cid) +{ + if (cid.chunk <= SCTP_CID_BASE_MAX) + return sctp_cid_tbl[cid.chunk]; + + switch (cid.chunk) { + case SCTP_CID_ASCONF: + return "ASCONF"; + + case SCTP_CID_ASCONF_ACK: + return "ASCONF_ACK"; + + case SCTP_CID_FWD_TSN: + return "FWD_TSN"; + + case SCTP_CID_AUTH: + return "AUTH"; + + case SCTP_CID_RECONF: + return "RECONF"; + + case SCTP_CID_I_DATA: + return "I_DATA"; + + case SCTP_CID_I_FWD_TSN: + return "I_FWD_TSN"; + + default: + break; + } + + return "unknown chunk"; +} + +/* These are printable forms of the states. */ +const char *const sctp_state_tbl[SCTP_STATE_NUM_STATES] = { + "STATE_CLOSED", + "STATE_COOKIE_WAIT", + "STATE_COOKIE_ECHOED", + "STATE_ESTABLISHED", + "STATE_SHUTDOWN_PENDING", + "STATE_SHUTDOWN_SENT", + "STATE_SHUTDOWN_RECEIVED", + "STATE_SHUTDOWN_ACK_SENT", +}; + +/* Events that could change the state of an association. */ +const char *const sctp_evttype_tbl[] = { + "EVENT_T_unknown", + "EVENT_T_CHUNK", + "EVENT_T_TIMEOUT", + "EVENT_T_OTHER", + "EVENT_T_PRIMITIVE" +}; + +/* Return value of a state function */ +const char *const sctp_status_tbl[] = { + "DISPOSITION_DISCARD", + "DISPOSITION_CONSUME", + "DISPOSITION_NOMEM", + "DISPOSITION_DELETE_TCB", + "DISPOSITION_ABORT", + "DISPOSITION_VIOLATION", + "DISPOSITION_NOT_IMPL", + "DISPOSITION_ERROR", + "DISPOSITION_BUG" +}; + +/* Printable forms of primitives */ +static const char *const sctp_primitive_tbl[SCTP_NUM_PRIMITIVE_TYPES] = { + "PRIMITIVE_ASSOCIATE", + "PRIMITIVE_SHUTDOWN", + "PRIMITIVE_ABORT", + "PRIMITIVE_SEND", + "PRIMITIVE_REQUESTHEARTBEAT", + "PRIMITIVE_ASCONF", +}; + +/* Lookup primitive debug name. */ +const char *sctp_pname(const union sctp_subtype id) +{ + if (id.primitive <= SCTP_EVENT_PRIMITIVE_MAX) + return sctp_primitive_tbl[id.primitive]; + return "unknown_primitive"; +} + +static const char *const sctp_other_tbl[] = { + "NO_PENDING_TSN", + "ICMP_PROTO_UNREACH", +}; + +/* Lookup "other" debug name. */ +const char *sctp_oname(const union sctp_subtype id) +{ + if (id.other <= SCTP_EVENT_OTHER_MAX) + return sctp_other_tbl[id.other]; + return "unknown 'other' event"; +} + +static const char *const sctp_timer_tbl[] = { + "TIMEOUT_NONE", + "TIMEOUT_T1_COOKIE", + "TIMEOUT_T1_INIT", + "TIMEOUT_T2_SHUTDOWN", + "TIMEOUT_T3_RTX", + "TIMEOUT_T4_RTO", + "TIMEOUT_T5_SHUTDOWN_GUARD", + "TIMEOUT_HEARTBEAT", + "TIMEOUT_RECONF", + "TIMEOUT_PROBE", + "TIMEOUT_SACK", + "TIMEOUT_AUTOCLOSE", +}; + +/* Lookup timer debug name. 
*/ +const char *sctp_tname(const union sctp_subtype id) +{ + BUILD_BUG_ON(SCTP_EVENT_TIMEOUT_MAX + 1 != ARRAY_SIZE(sctp_timer_tbl)); + + if (id.timeout < ARRAY_SIZE(sctp_timer_tbl)) + return sctp_timer_tbl[id.timeout]; + return "unknown_timer"; +} diff --git a/net/sctp/diag.c b/net/sctp/diag.c new file mode 100644 index 000000000..b0ce10808 --- /dev/null +++ b/net/sctp/diag.c @@ -0,0 +1,529 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright Red Hat Inc. 2017 + * + * This file is part of the SCTP kernel implementation + * + * These functions implement sctp diag support. + * + * Please send any bug reports or fixes you make to the + * email addresched(es): + * lksctp developers + * + * Written or modified by: + * Xin Long + */ + +#include +#include +#include +#include + +static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *info); + +/* define some functions to make asoc/ep fill look clean */ +static void inet_diag_msg_sctpasoc_fill(struct inet_diag_msg *r, + struct sock *sk, + struct sctp_association *asoc) +{ + union sctp_addr laddr, paddr; + struct dst_entry *dst; + struct timer_list *t3_rtx = &asoc->peer.primary_path->T3_rtx_timer; + + laddr = list_entry(asoc->base.bind_addr.address_list.next, + struct sctp_sockaddr_entry, list)->a; + paddr = asoc->peer.primary_path->ipaddr; + dst = asoc->peer.primary_path->dst; + + r->idiag_family = sk->sk_family; + r->id.idiag_sport = htons(asoc->base.bind_addr.port); + r->id.idiag_dport = htons(asoc->peer.port); + r->id.idiag_if = dst ? dst->dev->ifindex : 0; + sock_diag_save_cookie(sk, r->id.idiag_cookie); + +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) { + *(struct in6_addr *)r->id.idiag_src = laddr.v6.sin6_addr; + *(struct in6_addr *)r->id.idiag_dst = paddr.v6.sin6_addr; + } else +#endif + { + memset(&r->id.idiag_src, 0, sizeof(r->id.idiag_src)); + memset(&r->id.idiag_dst, 0, sizeof(r->id.idiag_dst)); + + r->id.idiag_src[0] = laddr.v4.sin_addr.s_addr; + r->id.idiag_dst[0] = paddr.v4.sin_addr.s_addr; + } + + r->idiag_state = asoc->state; + if (timer_pending(t3_rtx)) { + r->idiag_timer = SCTP_EVENT_TIMEOUT_T3_RTX; + r->idiag_retrans = asoc->rtx_data_chunks; + r->idiag_expires = jiffies_to_msecs(t3_rtx->expires - jiffies); + } +} + +static int inet_diag_msg_sctpladdrs_fill(struct sk_buff *skb, + struct list_head *address_list) +{ + struct sctp_sockaddr_entry *laddr; + int addrlen = sizeof(struct sockaddr_storage); + int addrcnt = 0; + struct nlattr *attr; + void *info = NULL; + + list_for_each_entry_rcu(laddr, address_list, list) + addrcnt++; + + attr = nla_reserve(skb, INET_DIAG_LOCALS, addrlen * addrcnt); + if (!attr) + return -EMSGSIZE; + + info = nla_data(attr); + list_for_each_entry_rcu(laddr, address_list, list) { + memcpy(info, &laddr->a, sizeof(laddr->a)); + memset(info + sizeof(laddr->a), 0, addrlen - sizeof(laddr->a)); + info += addrlen; + } + + return 0; +} + +static int inet_diag_msg_sctpaddrs_fill(struct sk_buff *skb, + struct sctp_association *asoc) +{ + int addrlen = sizeof(struct sockaddr_storage); + struct sctp_transport *from; + struct nlattr *attr; + void *info = NULL; + + attr = nla_reserve(skb, INET_DIAG_PEERS, + addrlen * asoc->peer.transport_count); + if (!attr) + return -EMSGSIZE; + + info = nla_data(attr); + list_for_each_entry(from, &asoc->peer.transport_addr_list, + transports) { + memcpy(info, &from->ipaddr, sizeof(from->ipaddr)); + memset(info + sizeof(from->ipaddr), 0, + addrlen - sizeof(from->ipaddr)); + info += addrlen; + } + + 
return 0; +} + +/* sctp asoc/ep fill*/ +static int inet_sctp_diag_fill(struct sock *sk, struct sctp_association *asoc, + struct sk_buff *skb, + const struct inet_diag_req_v2 *req, + struct user_namespace *user_ns, + int portid, u32 seq, u16 nlmsg_flags, + const struct nlmsghdr *unlh, + bool net_admin) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct list_head *addr_list; + struct inet_diag_msg *r; + struct nlmsghdr *nlh; + int ext = req->idiag_ext; + struct sctp_infox infox; + void *info = NULL; + + nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), + nlmsg_flags); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + BUG_ON(!sk_fullsock(sk)); + + r->idiag_timer = 0; + r->idiag_retrans = 0; + r->idiag_expires = 0; + if (asoc) { + inet_diag_msg_sctpasoc_fill(r, sk, asoc); + } else { + inet_diag_msg_common_fill(r, sk); + r->idiag_state = sk->sk_state; + } + + if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin)) + goto errout; + + if (ext & (1 << (INET_DIAG_SKMEMINFO - 1))) { + u32 mem[SK_MEMINFO_VARS]; + int amt; + + if (asoc && asoc->ep->sndbuf_policy) + amt = asoc->sndbuf_used; + else + amt = sk_wmem_alloc_get(sk); + mem[SK_MEMINFO_WMEM_ALLOC] = amt; + if (asoc && asoc->ep->rcvbuf_policy) + amt = atomic_read(&asoc->rmem_alloc); + else + amt = sk_rmem_alloc_get(sk); + mem[SK_MEMINFO_RMEM_ALLOC] = amt; + mem[SK_MEMINFO_RCVBUF] = sk->sk_rcvbuf; + mem[SK_MEMINFO_SNDBUF] = sk->sk_sndbuf; + mem[SK_MEMINFO_FWD_ALLOC] = sk->sk_forward_alloc; + mem[SK_MEMINFO_WMEM_QUEUED] = sk->sk_wmem_queued; + mem[SK_MEMINFO_OPTMEM] = atomic_read(&sk->sk_omem_alloc); + mem[SK_MEMINFO_BACKLOG] = READ_ONCE(sk->sk_backlog.len); + mem[SK_MEMINFO_DROPS] = atomic_read(&sk->sk_drops); + + if (nla_put(skb, INET_DIAG_SKMEMINFO, sizeof(mem), &mem) < 0) + goto errout; + } + + if (ext & (1 << (INET_DIAG_INFO - 1))) { + struct nlattr *attr; + + attr = nla_reserve_64bit(skb, INET_DIAG_INFO, + sizeof(struct sctp_info), + INET_DIAG_PAD); + if (!attr) + goto errout; + + info = nla_data(attr); + } + infox.sctpinfo = (struct sctp_info *)info; + infox.asoc = asoc; + sctp_diag_get_info(sk, r, &infox); + + addr_list = asoc ? 
&asoc->base.bind_addr.address_list + : &ep->base.bind_addr.address_list; + if (inet_diag_msg_sctpladdrs_fill(skb, addr_list)) + goto errout; + + if (asoc && (ext & (1 << (INET_DIAG_CONG - 1)))) + if (nla_put_string(skb, INET_DIAG_CONG, "reno") < 0) + goto errout; + + if (asoc && inet_diag_msg_sctpaddrs_fill(skb, asoc)) + goto errout; + + nlmsg_end(skb, nlh); + return 0; + +errout: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +/* callback and param */ +struct sctp_comm_param { + struct sk_buff *skb; + struct netlink_callback *cb; + const struct inet_diag_req_v2 *r; + const struct nlmsghdr *nlh; + bool net_admin; +}; + +static size_t inet_assoc_attr_size(struct sctp_association *asoc) +{ + int addrlen = sizeof(struct sockaddr_storage); + int addrcnt = 0; + struct sctp_sockaddr_entry *laddr; + + list_for_each_entry_rcu(laddr, &asoc->base.bind_addr.address_list, + list) + addrcnt++; + + return nla_total_size(sizeof(struct sctp_info)) + + nla_total_size(addrlen * asoc->peer.transport_count) + + nla_total_size(addrlen * addrcnt) + + nla_total_size(sizeof(struct inet_diag_msg)) + + inet_diag_msg_attrs_size() + + nla_total_size(sizeof(struct inet_diag_meminfo)) + + 64; +} + +static int sctp_sock_dump_one(struct sctp_endpoint *ep, struct sctp_transport *tsp, void *p) +{ + struct sctp_association *assoc = tsp->asoc; + struct sctp_comm_param *commp = p; + struct sock *sk = ep->base.sk; + const struct inet_diag_req_v2 *req = commp->r; + struct sk_buff *skb = commp->skb; + struct sk_buff *rep; + int err; + + err = sock_diag_check_cookie(sk, req->id.idiag_cookie); + if (err) + return err; + + rep = nlmsg_new(inet_assoc_attr_size(assoc), GFP_KERNEL); + if (!rep) + return -ENOMEM; + + lock_sock(sk); + if (ep != assoc->ep) { + err = -EAGAIN; + goto out; + } + + err = inet_sctp_diag_fill(sk, assoc, rep, req, sk_user_ns(NETLINK_CB(skb).sk), + NETLINK_CB(skb).portid, commp->nlh->nlmsg_seq, 0, + commp->nlh, commp->net_admin); + if (err < 0) { + WARN_ON(err == -EMSGSIZE); + goto out; + } + release_sock(sk); + + return nlmsg_unicast(sock_net(skb->sk)->diag_nlsk, rep, NETLINK_CB(skb).portid); + +out: + release_sock(sk); + kfree_skb(rep); + return err; +} + +static int sctp_sock_dump(struct sctp_endpoint *ep, struct sctp_transport *tsp, void *p) +{ + struct sctp_comm_param *commp = p; + struct sock *sk = ep->base.sk; + struct sk_buff *skb = commp->skb; + struct netlink_callback *cb = commp->cb; + const struct inet_diag_req_v2 *r = commp->r; + struct sctp_association *assoc; + int err = 0; + + lock_sock(sk); + if (ep != tsp->asoc->ep) + goto release; + list_for_each_entry(assoc, &ep->asocs, asocs) { + if (cb->args[4] < cb->args[1]) + goto next; + + if (r->id.idiag_sport != htons(assoc->base.bind_addr.port) && + r->id.idiag_sport) + goto next; + if (r->id.idiag_dport != htons(assoc->peer.port) && + r->id.idiag_dport) + goto next; + + if (!cb->args[3] && + inet_sctp_diag_fill(sk, NULL, skb, r, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, cb->nlh, + commp->net_admin) < 0) { + err = 1; + goto release; + } + cb->args[3] = 1; + + if (inet_sctp_diag_fill(sk, assoc, skb, r, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, 0, cb->nlh, + commp->net_admin) < 0) { + err = 1; + goto release; + } +next: + cb->args[4]++; + } + cb->args[1] = 0; + cb->args[3] = 0; + cb->args[4] = 0; +release: + release_sock(sk); + return err; +} + +static int sctp_sock_filter(struct sctp_endpoint *ep, struct sctp_transport *tsp, void *p) +{ + 
struct sctp_comm_param *commp = p; + struct sock *sk = ep->base.sk; + const struct inet_diag_req_v2 *r = commp->r; + + /* find the ep only once through the transports by this condition */ + if (!list_is_first(&tsp->asoc->asocs, &ep->asocs)) + return 0; + + if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family) + return 0; + + return 1; +} + +static int sctp_ep_dump(struct sctp_endpoint *ep, void *p) +{ + struct sctp_comm_param *commp = p; + struct sock *sk = ep->base.sk; + struct sk_buff *skb = commp->skb; + struct netlink_callback *cb = commp->cb; + const struct inet_diag_req_v2 *r = commp->r; + struct net *net = sock_net(skb->sk); + struct inet_sock *inet = inet_sk(sk); + int err = 0; + + if (!net_eq(sock_net(sk), net)) + goto out; + + if (cb->args[4] < cb->args[1]) + goto next; + + if (!(r->idiag_states & TCPF_LISTEN) && !list_empty(&ep->asocs)) + goto next; + + if (r->sdiag_family != AF_UNSPEC && + sk->sk_family != r->sdiag_family) + goto next; + + if (r->id.idiag_sport != inet->inet_sport && + r->id.idiag_sport) + goto next; + + if (r->id.idiag_dport != inet->inet_dport && + r->id.idiag_dport) + goto next; + + if (inet_sctp_diag_fill(sk, NULL, skb, r, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + cb->nlh, commp->net_admin) < 0) { + err = 2; + goto out; + } +next: + cb->args[4]++; +out: + return err; +} + +/* define the functions for sctp_diag_handler*/ +static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, + void *info) +{ + struct sctp_infox *infox = (struct sctp_infox *)info; + + if (infox->asoc) { + r->idiag_rqueue = atomic_read(&infox->asoc->rmem_alloc); + r->idiag_wqueue = infox->asoc->sndbuf_used; + } else { + r->idiag_rqueue = READ_ONCE(sk->sk_ack_backlog); + r->idiag_wqueue = READ_ONCE(sk->sk_max_ack_backlog); + } + if (infox->sctpinfo) + sctp_get_sctp_info(sk, infox->asoc, infox->sctpinfo); +} + +static int sctp_diag_dump_one(struct netlink_callback *cb, + const struct inet_diag_req_v2 *req) +{ + struct sk_buff *skb = cb->skb; + struct net *net = sock_net(skb->sk); + const struct nlmsghdr *nlh = cb->nlh; + union sctp_addr laddr, paddr; + struct sctp_comm_param commp = { + .skb = skb, + .r = req, + .nlh = nlh, + .net_admin = netlink_net_capable(skb, CAP_NET_ADMIN), + }; + + if (req->sdiag_family == AF_INET) { + laddr.v4.sin_port = req->id.idiag_sport; + laddr.v4.sin_addr.s_addr = req->id.idiag_src[0]; + laddr.v4.sin_family = AF_INET; + + paddr.v4.sin_port = req->id.idiag_dport; + paddr.v4.sin_addr.s_addr = req->id.idiag_dst[0]; + paddr.v4.sin_family = AF_INET; + } else { + laddr.v6.sin6_port = req->id.idiag_sport; + memcpy(&laddr.v6.sin6_addr, req->id.idiag_src, + sizeof(laddr.v6.sin6_addr)); + laddr.v6.sin6_family = AF_INET6; + + paddr.v6.sin6_port = req->id.idiag_dport; + memcpy(&paddr.v6.sin6_addr, req->id.idiag_dst, + sizeof(paddr.v6.sin6_addr)); + paddr.v6.sin6_family = AF_INET6; + } + + return sctp_transport_lookup_process(sctp_sock_dump_one, + net, &laddr, &paddr, &commp); +} + +static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, + const struct inet_diag_req_v2 *r) +{ + u32 idiag_states = r->idiag_states; + struct net *net = sock_net(skb->sk); + struct sctp_comm_param commp = { + .skb = skb, + .cb = cb, + .r = r, + .net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN), + }; + int pos = cb->args[2]; + + /* eps hashtable dumps + * args: + * 0 : if it will traversal listen sock + * 1 : to record the sock pos of this time's traversal + * 4 : to work as a 
temporary variable to traversal list + */ + if (cb->args[0] == 0) { + if (!(idiag_states & TCPF_LISTEN)) + goto skip; + if (sctp_for_each_endpoint(sctp_ep_dump, &commp)) + goto done; +skip: + cb->args[0] = 1; + cb->args[1] = 0; + cb->args[4] = 0; + } + + /* asocs by transport hashtable dump + * args: + * 1 : to record the assoc pos of this time's traversal + * 2 : to record the transport pos of this time's traversal + * 3 : to mark if we have dumped the ep info of the current asoc + * 4 : to work as a temporary variable to traversal list + * 5 : to save the sk we get from travelsing the tsp list. + */ + if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) + goto done; + + sctp_transport_traverse_process(sctp_sock_filter, sctp_sock_dump, + net, &pos, &commp); + cb->args[2] = pos; + +done: + cb->args[1] = cb->args[4]; + cb->args[4] = 0; +} + +static const struct inet_diag_handler sctp_diag_handler = { + .dump = sctp_diag_dump, + .dump_one = sctp_diag_dump_one, + .idiag_get_info = sctp_diag_get_info, + .idiag_type = IPPROTO_SCTP, + .idiag_info_size = sizeof(struct sctp_info), +}; + +static int __init sctp_diag_init(void) +{ + return inet_diag_register(&sctp_diag_handler); +} + +static void __exit sctp_diag_exit(void) +{ + inet_diag_unregister(&sctp_diag_handler); +} + +module_init(sctp_diag_init); +module_exit(sctp_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 2-132); diff --git a/net/sctp/endpointola.c b/net/sctp/endpointola.c new file mode 100644 index 000000000..efffde7f2 --- /dev/null +++ b/net/sctp/endpointola.c @@ -0,0 +1,417 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2002 International Business Machines, Corp. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * This abstraction represents an SCTP endpoint. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Daisy Chang + * Dajiang Zhang + */ + +#include +#include +#include +#include /* get_random_bytes() */ +#include +#include +#include +#include + +/* Forward declarations for internal helpers. */ +static void sctp_endpoint_bh_rcv(struct work_struct *work); + +/* + * Initialize the base fields of the endpoint structure. + */ +static struct sctp_endpoint *sctp_endpoint_init(struct sctp_endpoint *ep, + struct sock *sk, + gfp_t gfp) +{ + struct net *net = sock_net(sk); + struct sctp_shared_key *null_key; + + ep->digest = kzalloc(SCTP_SIGNATURE_SIZE, gfp); + if (!ep->digest) + return NULL; + + ep->asconf_enable = net->sctp.addip_enable; + ep->auth_enable = net->sctp.auth_enable; + if (ep->auth_enable) { + if (sctp_auth_init(ep, gfp)) + goto nomem; + if (ep->asconf_enable) { + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF); + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF_ACK); + } + } + + /* Initialize the base structure. */ + /* What type of endpoint are we? */ + ep->base.type = SCTP_EP_TYPE_SOCKET; + + /* Initialize the basic object fields. */ + refcount_set(&ep->base.refcnt, 1); + ep->base.dead = false; + + /* Create an input queue. 
*/ + sctp_inq_init(&ep->base.inqueue); + + /* Set its top-half handler */ + sctp_inq_set_th_handler(&ep->base.inqueue, sctp_endpoint_bh_rcv); + + /* Initialize the bind addr area */ + sctp_bind_addr_init(&ep->base.bind_addr, 0); + + /* Create the lists of associations. */ + INIT_LIST_HEAD(&ep->asocs); + + /* Use SCTP specific send buffer space queues. */ + ep->sndbuf_policy = net->sctp.sndbuf_policy; + + sk->sk_data_ready = sctp_data_ready; + sk->sk_write_space = sctp_write_space; + sock_set_flag(sk, SOCK_USE_WRITE_QUEUE); + + /* Get the receive buffer policy for this endpoint */ + ep->rcvbuf_policy = net->sctp.rcvbuf_policy; + + /* Initialize the secret key used with cookie. */ + get_random_bytes(ep->secret_key, sizeof(ep->secret_key)); + + /* SCTP-AUTH extensions*/ + INIT_LIST_HEAD(&ep->endpoint_shared_keys); + null_key = sctp_auth_shkey_create(0, gfp); + if (!null_key) + goto nomem_shkey; + + list_add(&null_key->key_list, &ep->endpoint_shared_keys); + + /* Add the null key to the endpoint shared keys list and + * set the hmcas and chunks pointers. + */ + ep->prsctp_enable = net->sctp.prsctp_enable; + ep->reconf_enable = net->sctp.reconf_enable; + ep->ecn_enable = net->sctp.ecn_enable; + + /* Remember who we are attached to. */ + ep->base.sk = sk; + ep->base.net = sock_net(sk); + sock_hold(ep->base.sk); + + return ep; + +nomem_shkey: + sctp_auth_free(ep); +nomem: + kfree(ep->digest); + return NULL; + +} + +/* Create a sctp_endpoint with all that boring stuff initialized. + * Returns NULL if there isn't enough memory. + */ +struct sctp_endpoint *sctp_endpoint_new(struct sock *sk, gfp_t gfp) +{ + struct sctp_endpoint *ep; + + /* Build a local endpoint. */ + ep = kzalloc(sizeof(*ep), gfp); + if (!ep) + goto fail; + + if (!sctp_endpoint_init(ep, sk, gfp)) + goto fail_init; + + SCTP_DBG_OBJCNT_INC(ep); + return ep; + +fail_init: + kfree(ep); +fail: + return NULL; +} + +/* Add an association to an endpoint. */ +void sctp_endpoint_add_asoc(struct sctp_endpoint *ep, + struct sctp_association *asoc) +{ + struct sock *sk = ep->base.sk; + + /* If this is a temporary association, don't bother + * since we'll be removing it shortly and don't + * want anyone to find it anyway. + */ + if (asoc->temp) + return; + + /* Now just add it to our list of asocs */ + list_add_tail(&asoc->asocs, &ep->asocs); + + /* Increment the backlog value for a TCP-style listening socket. */ + if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING)) + sk_acceptq_added(sk); +} + +/* Free the endpoint structure. Delay cleanup until + * all users have released their reference count on this structure. + */ +void sctp_endpoint_free(struct sctp_endpoint *ep) +{ + ep->base.dead = true; + + inet_sk_set_state(ep->base.sk, SCTP_SS_CLOSED); + + /* Unlink this endpoint, so we can't find it again! */ + sctp_unhash_endpoint(ep); + + sctp_endpoint_put(ep); +} + +/* Final destructor for endpoint. 
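+ * Runs as an RCU callback, so readers still dereferencing the endpoint under rcu_read_lock() have finished before the socket reference and the memory are released.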
*/ +static void sctp_endpoint_destroy_rcu(struct rcu_head *head) +{ + struct sctp_endpoint *ep = container_of(head, struct sctp_endpoint, rcu); + struct sock *sk = ep->base.sk; + + sctp_sk(sk)->ep = NULL; + sock_put(sk); + + kfree(ep); + SCTP_DBG_OBJCNT_DEC(ep); +} + +static void sctp_endpoint_destroy(struct sctp_endpoint *ep) +{ + struct sock *sk; + + if (unlikely(!ep->base.dead)) { + WARN(1, "Attempt to destroy undead endpoint %p!\n", ep); + return; + } + + /* Free the digest buffer */ + kfree(ep->digest); + + /* SCTP-AUTH: Free up AUTH releated data such as shared keys + * chunks and hmacs arrays that were allocated + */ + sctp_auth_destroy_keys(&ep->endpoint_shared_keys); + sctp_auth_free(ep); + + /* Cleanup. */ + sctp_inq_free(&ep->base.inqueue); + sctp_bind_addr_free(&ep->base.bind_addr); + + memset(ep->secret_key, 0, sizeof(ep->secret_key)); + + sk = ep->base.sk; + /* Remove and free the port */ + if (sctp_sk(sk)->bind_hash) + sctp_put_port(sk); + + call_rcu(&ep->rcu, sctp_endpoint_destroy_rcu); +} + +/* Hold a reference to an endpoint. */ +int sctp_endpoint_hold(struct sctp_endpoint *ep) +{ + return refcount_inc_not_zero(&ep->base.refcnt); +} + +/* Release a reference to an endpoint and clean up if there are + * no more references. + */ +void sctp_endpoint_put(struct sctp_endpoint *ep) +{ + if (refcount_dec_and_test(&ep->base.refcnt)) + sctp_endpoint_destroy(ep); +} + +/* Is this the endpoint we are looking for? */ +struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep, + struct net *net, + const union sctp_addr *laddr) +{ + struct sctp_endpoint *retval = NULL; + + if ((htons(ep->base.bind_addr.port) == laddr->v4.sin_port) && + net_eq(ep->base.net, net)) { + if (sctp_bind_addr_match(&ep->base.bind_addr, laddr, + sctp_sk(ep->base.sk))) + retval = ep; + } + + return retval; +} + +/* Find the association that goes with this chunk. + * We lookup the transport from hashtable at first, then get association + * through t->assoc. + */ +struct sctp_association *sctp_endpoint_lookup_assoc( + const struct sctp_endpoint *ep, + const union sctp_addr *paddr, + struct sctp_transport **transport) +{ + struct sctp_association *asoc = NULL; + struct sctp_transport *t; + + *transport = NULL; + + /* If the local port is not set, there can't be any associations + * on this endpoint. + */ + if (!ep->base.bind_addr.port) + return NULL; + + rcu_read_lock(); + t = sctp_epaddr_lookup_transport(ep, paddr); + if (!t) + goto out; + + *transport = t; + asoc = t->asoc; +out: + rcu_read_unlock(); + return asoc; +} + +/* Look for any peeled off association from the endpoint that matches the + * given peer address. + */ +bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep, + const union sctp_addr *paddr) +{ + struct sctp_sockaddr_entry *addr; + struct net *net = ep->base.net; + struct sctp_bind_addr *bp; + + bp = &ep->base.bind_addr; + /* This function is called with the socket lock held, + * so the address_list can not change. + */ + list_for_each_entry(addr, &bp->address_list, list) { + if (sctp_has_association(net, &addr->a, paddr)) + return true; + } + + return false; +} + +/* Do delayed input processing. This is scheduled by sctp_rcv(). + * This may be called on BH or task time. 
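+ * Each popped chunk is matched to an association, if one exists, and handed to sctp_do_sm() for state machine processing.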
+ */ +static void sctp_endpoint_bh_rcv(struct work_struct *work) +{ + struct sctp_endpoint *ep = + container_of(work, struct sctp_endpoint, + base.inqueue.immediate); + struct sctp_association *asoc; + struct sock *sk; + struct net *net; + struct sctp_transport *transport; + struct sctp_chunk *chunk; + struct sctp_inq *inqueue; + union sctp_subtype subtype; + enum sctp_state state; + int error = 0; + int first_time = 1; /* is this the first time through the loop */ + + if (ep->base.dead) + return; + + asoc = NULL; + inqueue = &ep->base.inqueue; + sk = ep->base.sk; + net = sock_net(sk); + + while (NULL != (chunk = sctp_inq_pop(inqueue))) { + subtype = SCTP_ST_CHUNK(chunk->chunk_hdr->type); + + /* If the first chunk in the packet is AUTH, do special + * processing specified in Section 6.3 of SCTP-AUTH spec + */ + if (first_time && (subtype.chunk == SCTP_CID_AUTH)) { + struct sctp_chunkhdr *next_hdr; + + next_hdr = sctp_inq_peek(inqueue); + if (!next_hdr) + goto normal; + + /* If the next chunk is COOKIE-ECHO, skip the AUTH + * chunk while saving a pointer to it so we can do + * Authentication later (during cookie-echo + * processing). + */ + if (next_hdr->type == SCTP_CID_COOKIE_ECHO) { + chunk->auth_chunk = skb_clone(chunk->skb, + GFP_ATOMIC); + chunk->auth = 1; + continue; + } + } +normal: + /* We might have grown an association since last we + * looked, so try again. + * + * This happens when we've just processed our + * COOKIE-ECHO chunk. + */ + if (NULL == chunk->asoc) { + asoc = sctp_endpoint_lookup_assoc(ep, + sctp_source(chunk), + &transport); + chunk->asoc = asoc; + chunk->transport = transport; + } + + state = asoc ? asoc->state : SCTP_STATE_CLOSED; + if (sctp_auth_recv_cid(subtype.chunk, asoc) && !chunk->auth) + continue; + + /* Remember where the last DATA chunk came from so we + * know where to send the SACK. + */ + if (asoc && sctp_chunk_is_data(chunk)) + asoc->peer.last_data_from = chunk->transport; + else { + SCTP_INC_STATS(ep->base.net, SCTP_MIB_INCTRLCHUNKS); + if (asoc) + asoc->stats.ictrlchunks++; + } + + if (chunk->transport) + chunk->transport->last_time_heard = ktime_get(); + + error = sctp_do_sm(net, SCTP_EVENT_T_CHUNK, subtype, state, + ep, asoc, chunk, GFP_ATOMIC); + + if (error && chunk) + chunk->pdiscard = 1; + + /* Check to see if the endpoint is freed in response to + * the incoming chunk. If so, get out of the while loop. + */ + if (!sctp_sk(sk)->ep) + break; + + if (first_time) + first_time = 0; + } +} diff --git a/net/sctp/input.c b/net/sctp/input.c new file mode 100644 index 000000000..4f43afa86 --- /dev/null +++ b/net/sctp/input.c @@ -0,0 +1,1349 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2003 International Business Machines, Corp. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * These functions handle all input from the IP layer into SCTP. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. 
Yarroll + * Karl Knutson + * Xingang Guo + * Jon Grimm + * Hui Huang + * Daisy Chang + * Sridhar Samudrala + * Ardelle Fan + */ + +#include +#include /* For struct list_head */ +#include +#include +#include /* For struct timeval */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/* Forward declarations for internal helpers. */ +static int sctp_rcv_ootb(struct sk_buff *); +static struct sctp_association *__sctp_rcv_lookup(struct net *net, + struct sk_buff *skb, + const union sctp_addr *paddr, + const union sctp_addr *laddr, + struct sctp_transport **transportp); +static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( + struct net *net, struct sk_buff *skb, + const union sctp_addr *laddr, + const union sctp_addr *daddr); +static struct sctp_association *__sctp_lookup_association( + struct net *net, + const union sctp_addr *local, + const union sctp_addr *peer, + struct sctp_transport **pt); + +static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb); + + +/* Calculate the SCTP checksum of an SCTP packet. */ +static inline int sctp_rcv_checksum(struct net *net, struct sk_buff *skb) +{ + struct sctphdr *sh = sctp_hdr(skb); + __le32 cmp = sh->checksum; + __le32 val = sctp_compute_cksum(skb, 0); + + if (val != cmp) { + /* CRC failure, dump it. */ + __SCTP_INC_STATS(net, SCTP_MIB_CHECKSUMERRORS); + return -1; + } + return 0; +} + +/* + * This is the routine which IP calls when receiving an SCTP packet. + */ +int sctp_rcv(struct sk_buff *skb) +{ + struct sock *sk; + struct sctp_association *asoc; + struct sctp_endpoint *ep = NULL; + struct sctp_ep_common *rcvr; + struct sctp_transport *transport = NULL; + struct sctp_chunk *chunk; + union sctp_addr src; + union sctp_addr dest; + int bound_dev_if; + int family; + struct sctp_af *af; + struct net *net = dev_net(skb->dev); + bool is_gso = skb_is_gso(skb) && skb_is_gso_sctp(skb); + + if (skb->pkt_type != PACKET_HOST) + goto discard_it; + + __SCTP_INC_STATS(net, SCTP_MIB_INSCTPPACKS); + + /* If packet is too small to contain a single chunk, let's not + * waste time on it anymore. + */ + if (skb->len < sizeof(struct sctphdr) + sizeof(struct sctp_chunkhdr) + + skb_transport_offset(skb)) + goto discard_it; + + /* If the packet is fragmented and we need to do crc checking, + * it's better to just linearize it otherwise crc computing + * takes longer. + */ + if ((!is_gso && skb_linearize(skb)) || + !pskb_may_pull(skb, sizeof(struct sctphdr))) + goto discard_it; + + /* Pull up the IP header. */ + __skb_pull(skb, skb_transport_offset(skb)); + + skb->csum_valid = 0; /* Previous value not applicable */ + if (skb_csum_unnecessary(skb)) + __skb_decr_checksum_unnecessary(skb); + else if (!sctp_checksum_disable && + !is_gso && + sctp_rcv_checksum(net, skb) < 0) + goto discard_it; + skb->csum_valid = 1; + + __skb_pull(skb, sizeof(struct sctphdr)); + + family = ipver2af(ip_hdr(skb)->version); + af = sctp_get_af_specific(family); + if (unlikely(!af)) + goto discard_it; + SCTP_INPUT_CB(skb)->af = af; + + /* Initialize local addresses for lookups. */ + af->from_skb(&src, skb, 1); + af->from_skb(&dest, skb, 0); + + /* If the packet is to or from a non-unicast address, + * silently discard the packet. + * + * This is not clearly defined in the RFC except in section + * 8.4 - OOTB handling. However, based on the book "Stream Control + * Transmission Protocol" 2.1, "It is important to note that the + * IP address of an SCTP transport address must be a routable + * unicast address. 
In other words, IP multicast addresses and + * IP broadcast addresses cannot be used in an SCTP transport + * address." + */ + if (!af->addr_valid(&src, NULL, skb) || + !af->addr_valid(&dest, NULL, skb)) + goto discard_it; + + asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport); + + if (!asoc) + ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src); + + /* Retrieve the common input handling substructure. */ + rcvr = asoc ? &asoc->base : &ep->base; + sk = rcvr->sk; + + /* + * If a frame arrives on an interface and the receiving socket is + * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB + */ + bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) { + if (transport) { + sctp_transport_put(transport); + asoc = NULL; + transport = NULL; + } else { + sctp_endpoint_put(ep); + ep = NULL; + } + sk = net->sctp.ctl_sock; + ep = sctp_sk(sk)->ep; + sctp_endpoint_hold(ep); + rcvr = &ep->base; + } + + /* + * RFC 2960, 8.4 - Handle "Out of the blue" Packets. + * An SCTP packet is called an "out of the blue" (OOTB) + * packet if it is correctly formed, i.e., passed the + * receiver's checksum check, but the receiver is not + * able to identify the association to which this + * packet belongs. + */ + if (!asoc) { + if (sctp_rcv_ootb(skb)) { + __SCTP_INC_STATS(net, SCTP_MIB_OUTOFBLUES); + goto discard_release; + } + } + + if (!xfrm_policy_check(sk, XFRM_POLICY_IN, skb, family)) + goto discard_release; + nf_reset_ct(skb); + + if (sk_filter(sk, skb)) + goto discard_release; + + /* Create an SCTP packet structure. */ + chunk = sctp_chunkify(skb, asoc, sk, GFP_ATOMIC); + if (!chunk) + goto discard_release; + SCTP_INPUT_CB(skb)->chunk = chunk; + + /* Remember what endpoint is to handle this packet. */ + chunk->rcvr = rcvr; + + /* Remember the SCTP header. */ + chunk->sctp_hdr = sctp_hdr(skb); + + /* Set the source and destination addresses of the incoming chunk. */ + sctp_init_addrs(chunk, &src, &dest); + + /* Remember where we came from. */ + chunk->transport = transport; + + /* Acquire access to the sock lock. Note: We are safe from other + * bottom halves on this lock, but a user may be in the lock too, + * so check if it is busy. + */ + bh_lock_sock(sk); + + if (sk != rcvr->sk) { + /* Our cached sk is different from the rcvr->sk. This is + * because migrate()/accept() may have moved the association + * to a new socket and released all the sockets. So now we + * are holding a lock on the old socket while the user may + * be doing something with the new socket. Switch our veiw + * of the current sk. + */ + bh_unlock_sock(sk); + sk = rcvr->sk; + bh_lock_sock(sk); + } + + if (sock_owned_by_user(sk) || !sctp_newsk_ready(sk)) { + if (sctp_add_backlog(sk, skb)) { + bh_unlock_sock(sk); + sctp_chunk_free(chunk); + skb = NULL; /* sctp_chunk_free already freed the skb */ + goto discard_release; + } + __SCTP_INC_STATS(net, SCTP_MIB_IN_PKT_BACKLOG); + } else { + __SCTP_INC_STATS(net, SCTP_MIB_IN_PKT_SOFTIRQ); + sctp_inq_push(&chunk->rcvr->inqueue, chunk); + } + + bh_unlock_sock(sk); + + /* Release the asoc/ep ref we took in the lookup calls. */ + if (transport) + sctp_transport_put(transport); + else + sctp_endpoint_put(ep); + + return 0; + +discard_it: + __SCTP_INC_STATS(net, SCTP_MIB_IN_PKT_DISCARDS); + kfree_skb(skb); + return 0; + +discard_release: + /* Release the asoc/ep ref we took in the lookup calls. 
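+ * Holding the transport keeps its association alive; with no association we hold the endpoint itself.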
*/ + if (transport) + sctp_transport_put(transport); + else + sctp_endpoint_put(ep); + + goto discard_it; +} + +/* Process the backlog queue of the socket. Every skb on + * the backlog holds a ref on an association or endpoint. + * We hold this ref throughout the state machine to make + * sure that the structure we need is still around. + */ +int sctp_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + struct sctp_chunk *chunk = SCTP_INPUT_CB(skb)->chunk; + struct sctp_inq *inqueue = &chunk->rcvr->inqueue; + struct sctp_transport *t = chunk->transport; + struct sctp_ep_common *rcvr = NULL; + int backloged = 0; + + rcvr = chunk->rcvr; + + /* If the rcvr is dead then the association or endpoint + * has been deleted and we can safely drop the chunk + * and refs that we are holding. + */ + if (rcvr->dead) { + sctp_chunk_free(chunk); + goto done; + } + + if (unlikely(rcvr->sk != sk)) { + /* In this case, the association moved from one socket to + * another. We are currently sitting on the backlog of the + * old socket, so we need to move. + * However, since we are here in the process context we + * need to take make sure that the user doesn't own + * the new socket when we process the packet. + * If the new socket is user-owned, queue the chunk to the + * backlog of the new socket without dropping any refs. + * Otherwise, we can safely push the chunk on the inqueue. + */ + + sk = rcvr->sk; + local_bh_disable(); + bh_lock_sock(sk); + + if (sock_owned_by_user(sk) || !sctp_newsk_ready(sk)) { + if (sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf))) + sctp_chunk_free(chunk); + else + backloged = 1; + } else + sctp_inq_push(inqueue, chunk); + + bh_unlock_sock(sk); + local_bh_enable(); + + /* If the chunk was backloged again, don't drop refs */ + if (backloged) + return 0; + } else { + if (!sctp_newsk_ready(sk)) { + if (!sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf))) + return 0; + sctp_chunk_free(chunk); + } else { + sctp_inq_push(inqueue, chunk); + } + } + +done: + /* Release the refs we took in sctp_add_backlog */ + if (SCTP_EP_TYPE_ASSOCIATION == rcvr->type) + sctp_transport_put(t); + else if (SCTP_EP_TYPE_SOCKET == rcvr->type) + sctp_endpoint_put(sctp_ep(rcvr)); + else + BUG(); + + return 0; +} + +static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb) +{ + struct sctp_chunk *chunk = SCTP_INPUT_CB(skb)->chunk; + struct sctp_transport *t = chunk->transport; + struct sctp_ep_common *rcvr = chunk->rcvr; + int ret; + + ret = sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)); + if (!ret) { + /* Hold the assoc/ep while hanging on the backlog queue. + * This way, we know structures we need will not disappear + * from us + */ + if (SCTP_EP_TYPE_ASSOCIATION == rcvr->type) + sctp_transport_hold(t); + else if (SCTP_EP_TYPE_SOCKET == rcvr->type) + sctp_endpoint_hold(sctp_ep(rcvr)); + else + BUG(); + } + return ret; + +} + +/* Handle icmp frag needed error. */ +void sctp_icmp_frag_needed(struct sock *sk, struct sctp_association *asoc, + struct sctp_transport *t, __u32 pmtu) +{ + if (!t || + (t->pathmtu <= pmtu && + t->pl.probe_size + sctp_transport_pl_hlen(t) <= pmtu)) + return; + + if (sock_owned_by_user(sk)) { + atomic_set(&t->mtu_info, pmtu); + asoc->pmtu_pending = 1; + t->pmtu_pending = 1; + return; + } + + if (!(t->param_flags & SPP_PMTUD_ENABLE)) + /* We can't allow retransmitting in such case, as the + * retransmission would be sized just as before, and thus we + * would get another icmp, and retransmit again. + */ + return; + + /* Update transports view of the MTU. 
Return if no update was needed. + * If an update wasn't needed/possible, it also doesn't make sense to + * try to retransmit now. + */ + if (!sctp_transport_update_pmtu(t, pmtu)) + return; + + /* Update association pmtu. */ + sctp_assoc_sync_pmtu(asoc); + + /* Retransmit with the new pmtu setting. */ + sctp_retransmit(&asoc->outqueue, t, SCTP_RTXR_PMTUD); +} + +void sctp_icmp_redirect(struct sock *sk, struct sctp_transport *t, + struct sk_buff *skb) +{ + struct dst_entry *dst; + + if (sock_owned_by_user(sk) || !t) + return; + dst = sctp_transport_dst_check(t); + if (dst) + dst->ops->redirect(dst, sk, skb); +} + +/* + * SCTP Implementer's Guide, 2.37 ICMP handling procedures + * + * ICMP8) If the ICMP code is a "Unrecognized next header type encountered" + * or a "Protocol Unreachable" treat this message as an abort + * with the T bit set. + * + * This function sends an event to the state machine, which will abort the + * association. + * + */ +void sctp_icmp_proto_unreachable(struct sock *sk, + struct sctp_association *asoc, + struct sctp_transport *t) +{ + if (sock_owned_by_user(sk)) { + if (timer_pending(&t->proto_unreach_timer)) + return; + else { + if (!mod_timer(&t->proto_unreach_timer, + jiffies + (HZ/20))) + sctp_transport_hold(t); + } + } else { + struct net *net = sock_net(sk); + + pr_debug("%s: unrecognized next header type " + "encountered!\n", __func__); + + if (del_timer(&t->proto_unreach_timer)) + sctp_transport_put(t); + + sctp_do_sm(net, SCTP_EVENT_T_OTHER, + SCTP_ST_OTHER(SCTP_EVENT_ICMP_PROTO_UNREACH), + asoc->state, asoc->ep, asoc, t, + GFP_ATOMIC); + } +} + +/* Common lookup code for icmp/icmpv6 error handler. */ +struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb, + struct sctphdr *sctphdr, + struct sctp_association **app, + struct sctp_transport **tpp) +{ + struct sctp_init_chunk *chunkhdr, _chunkhdr; + union sctp_addr saddr; + union sctp_addr daddr; + struct sctp_af *af; + struct sock *sk = NULL; + struct sctp_association *asoc; + struct sctp_transport *transport = NULL; + __u32 vtag = ntohl(sctphdr->vtag); + + *app = NULL; *tpp = NULL; + + af = sctp_get_af_specific(family); + if (unlikely(!af)) { + return NULL; + } + + /* Initialize local addresses for lookups. */ + af->from_skb(&saddr, skb, 1); + af->from_skb(&daddr, skb, 0); + + /* Look for an association that matches the incoming ICMP error + * packet. + */ + asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport); + if (!asoc) + return NULL; + + sk = asoc->base.sk; + + /* RFC 4960, Appendix C. ICMP Handling + * + * ICMP6) An implementation MUST validate that the Verification Tag + * contained in the ICMP message matches the Verification Tag of + * the peer. If the Verification Tag is not 0 and does NOT + * match, discard the ICMP message. If it is 0 and the ICMP + * message contains enough bytes to verify that the chunk type is + * an INIT chunk and that the Initiate Tag matches the tag of the + * peer, continue with ICMP7. If the ICMP message is too short + * or the chunk type or the Initiate Tag does not match, silently + * discard the packet. 
+ */ + if (vtag == 0) { + /* chunk header + first 4 octects of init header */ + chunkhdr = skb_header_pointer(skb, skb_transport_offset(skb) + + sizeof(struct sctphdr), + sizeof(struct sctp_chunkhdr) + + sizeof(__be32), &_chunkhdr); + if (!chunkhdr || + chunkhdr->chunk_hdr.type != SCTP_CID_INIT || + ntohl(chunkhdr->init_hdr.init_tag) != asoc->c.my_vtag) + goto out; + + } else if (vtag != asoc->c.peer_vtag) { + goto out; + } + + bh_lock_sock(sk); + + /* If too many ICMPs get dropped on busy + * servers this needs to be solved differently. + */ + if (sock_owned_by_user(sk)) + __NET_INC_STATS(net, LINUX_MIB_LOCKDROPPEDICMPS); + + *app = asoc; + *tpp = transport; + return sk; + +out: + sctp_transport_put(transport); + return NULL; +} + +/* Common cleanup code for icmp/icmpv6 error handler. */ +void sctp_err_finish(struct sock *sk, struct sctp_transport *t) + __releases(&((__sk)->sk_lock.slock)) +{ + bh_unlock_sock(sk); + sctp_transport_put(t); +} + +static void sctp_v4_err_handle(struct sctp_transport *t, struct sk_buff *skb, + __u8 type, __u8 code, __u32 info) +{ + struct sctp_association *asoc = t->asoc; + struct sock *sk = asoc->base.sk; + int err = 0; + + switch (type) { + case ICMP_PARAMETERPROB: + err = EPROTO; + break; + case ICMP_DEST_UNREACH: + if (code > NR_ICMP_UNREACH) + return; + if (code == ICMP_FRAG_NEEDED) { + sctp_icmp_frag_needed(sk, asoc, t, SCTP_TRUNC4(info)); + return; + } + if (code == ICMP_PROT_UNREACH) { + sctp_icmp_proto_unreachable(sk, asoc, t); + return; + } + err = icmp_err_convert[code].errno; + break; + case ICMP_TIME_EXCEEDED: + if (code == ICMP_EXC_FRAGTIME) + return; + + err = EHOSTUNREACH; + break; + case ICMP_REDIRECT: + sctp_icmp_redirect(sk, t, skb); + return; + default: + return; + } + if (!sock_owned_by_user(sk) && inet_sk(sk)->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else { /* Only an error on timeout */ + sk->sk_err_soft = err; + } +} + +/* + * This routine is called by the ICMP module when it gets some + * sort of error condition. If err < 0 then the socket should + * be closed and the error returned to the user. If err > 0 + * it's just the icmp type << 8 | icmp code. After adjustment + * header points to the first 8 bytes of the sctp header. We need + * to find the appropriate port. + * + * The locking strategy used here is very "optimistic". When + * someone else accesses the socket the ICMP is just dropped + * and for some paths there is no check at all. + * A more general error queue to queue errors for later handling + * is probably better. + * + */ +int sctp_v4_err(struct sk_buff *skb, __u32 info) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + const int type = icmp_hdr(skb)->type; + const int code = icmp_hdr(skb)->code; + struct net *net = dev_net(skb->dev); + struct sctp_transport *transport; + struct sctp_association *asoc; + __u16 saveip, savesctp; + struct sock *sk; + + /* Fix up skb to look at the embedded net header. */ + saveip = skb->network_header; + savesctp = skb->transport_header; + skb_reset_network_header(skb); + skb_set_transport_header(skb, iph->ihl * 4); + sk = sctp_err_lookup(net, AF_INET, skb, sctp_hdr(skb), &asoc, &transport); + /* Put back, the original values. 
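+ * The offsets were only moved above so sctp_err_lookup() could parse the SCTP header embedded in the ICMP payload.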
*/ + skb->network_header = saveip; + skb->transport_header = savesctp; + if (!sk) { + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return -ENOENT; + } + + sctp_v4_err_handle(transport, skb, type, code, info); + sctp_err_finish(sk, transport); + + return 0; +} + +int sctp_udp_v4_err(struct sock *sk, struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct sctp_association *asoc; + struct sctp_transport *t; + struct icmphdr *hdr; + __u32 info = 0; + + skb->transport_header += sizeof(struct udphdr); + sk = sctp_err_lookup(net, AF_INET, skb, sctp_hdr(skb), &asoc, &t); + if (!sk) { + __ICMP_INC_STATS(net, ICMP_MIB_INERRORS); + return -ENOENT; + } + + skb->transport_header -= sizeof(struct udphdr); + hdr = (struct icmphdr *)(skb_network_header(skb) - sizeof(struct icmphdr)); + if (hdr->type == ICMP_REDIRECT) { + /* can't be handled without outer iphdr known, leave it to udp_err */ + sctp_err_finish(sk, t); + return 0; + } + if (hdr->type == ICMP_DEST_UNREACH && hdr->code == ICMP_FRAG_NEEDED) + info = ntohs(hdr->un.frag.mtu); + sctp_v4_err_handle(t, skb, hdr->type, hdr->code, info); + + sctp_err_finish(sk, t); + return 1; +} + +/* + * RFC 2960, 8.4 - Handle "Out of the blue" Packets. + * + * This function scans all the chunks in the OOTB packet to determine if + * the packet should be discarded right away. If a response might be needed + * for this packet, or, if further processing is possible, the packet will + * be queued to a proper inqueue for the next phase of handling. + * + * Output: + * Return 0 - If further processing is needed. + * Return 1 - If the packet can be discarded right away. + */ +static int sctp_rcv_ootb(struct sk_buff *skb) +{ + struct sctp_chunkhdr *ch, _ch; + int ch_end, offset = 0; + + /* Scan through all the chunks in the packet. */ + do { + /* Make sure we have at least the header there */ + if (offset + sizeof(_ch) > skb->len) + break; + + ch = skb_header_pointer(skb, offset, sizeof(*ch), &_ch); + + /* Break out if chunk length is less then minimal. */ + if (!ch || ntohs(ch->length) < sizeof(_ch)) + break; + + ch_end = offset + SCTP_PAD4(ntohs(ch->length)); + if (ch_end > skb->len) + break; + + /* RFC 8.4, 2) If the OOTB packet contains an ABORT chunk, the + * receiver MUST silently discard the OOTB packet and take no + * further action. + */ + if (SCTP_CID_ABORT == ch->type) + goto discard; + + /* RFC 8.4, 6) If the packet contains a SHUTDOWN COMPLETE + * chunk, the receiver should silently discard the packet + * and take no further action. + */ + if (SCTP_CID_SHUTDOWN_COMPLETE == ch->type) + goto discard; + + /* RFC 4460, 2.11.2 + * This will discard packets with INIT chunk bundled as + * subsequent chunks in the packet. When INIT is first, + * the normal INIT processing will discard the chunk. + */ + if (SCTP_CID_INIT == ch->type && (void *)ch != skb->data) + goto discard; + + offset = ch_end; + } while (ch_end < skb->len); + + return 0; + +discard: + return 1; +} + +/* Insert endpoint into the hash table. 
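+ * For SO_REUSEPORT sockets the endpoint is also attached to the reuseport group of a compatible socket already in the bucket, or a new group is allocated.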
*/ +static int __sctp_hash_endpoint(struct sctp_endpoint *ep) +{ + struct sock *sk = ep->base.sk; + struct net *net = sock_net(sk); + struct sctp_hashbucket *head; + + ep->hashent = sctp_ep_hashfn(net, ep->base.bind_addr.port); + head = &sctp_ep_hashtable[ep->hashent]; + + if (sk->sk_reuseport) { + bool any = sctp_is_ep_boundall(sk); + struct sctp_endpoint *ep2; + struct list_head *list; + int cnt = 0, err = 1; + + list_for_each(list, &ep->base.bind_addr.address_list) + cnt++; + + sctp_for_each_hentry(ep2, &head->chain) { + struct sock *sk2 = ep2->base.sk; + + if (!net_eq(sock_net(sk2), net) || sk2 == sk || + !uid_eq(sock_i_uid(sk2), sock_i_uid(sk)) || + !sk2->sk_reuseport) + continue; + + err = sctp_bind_addrs_check(sctp_sk(sk2), + sctp_sk(sk), cnt); + if (!err) { + err = reuseport_add_sock(sk, sk2, any); + if (err) + return err; + break; + } else if (err < 0) { + return err; + } + } + + if (err) { + err = reuseport_alloc(sk, any); + if (err) + return err; + } + } + + write_lock(&head->lock); + hlist_add_head(&ep->node, &head->chain); + write_unlock(&head->lock); + return 0; +} + +/* Add an endpoint to the hash. Local BH-safe. */ +int sctp_hash_endpoint(struct sctp_endpoint *ep) +{ + int err; + + local_bh_disable(); + err = __sctp_hash_endpoint(ep); + local_bh_enable(); + + return err; +} + +/* Remove endpoint from the hash table. */ +static void __sctp_unhash_endpoint(struct sctp_endpoint *ep) +{ + struct sock *sk = ep->base.sk; + struct sctp_hashbucket *head; + + ep->hashent = sctp_ep_hashfn(sock_net(sk), ep->base.bind_addr.port); + + head = &sctp_ep_hashtable[ep->hashent]; + + if (rcu_access_pointer(sk->sk_reuseport_cb)) + reuseport_detach_sock(sk); + + write_lock(&head->lock); + hlist_del_init(&ep->node); + write_unlock(&head->lock); +} + +/* Remove endpoint from the hash. Local BH-safe. */ +void sctp_unhash_endpoint(struct sctp_endpoint *ep) +{ + local_bh_disable(); + __sctp_unhash_endpoint(ep); + local_bh_enable(); +} + +static inline __u32 sctp_hashfn(const struct net *net, __be16 lport, + const union sctp_addr *paddr, __u32 seed) +{ + __u32 addr; + + if (paddr->sa.sa_family == AF_INET6) + addr = jhash(&paddr->v6.sin6_addr, 16, seed); + else + addr = (__force __u32)paddr->v4.sin_addr.s_addr; + + return jhash_3words(addr, ((__force __u32)paddr->v4.sin_port) << 16 | + (__force __u32)lport, net_hash_mix(net), seed); +} + +/* Look up an endpoint. 
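+ * Falls back to the control socket's endpoint when nothing bound matches, and lets SO_REUSEPORT pick among the sockets sharing the port.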
*/ +static struct sctp_endpoint *__sctp_rcv_lookup_endpoint( + struct net *net, struct sk_buff *skb, + const union sctp_addr *laddr, + const union sctp_addr *paddr) +{ + struct sctp_hashbucket *head; + struct sctp_endpoint *ep; + struct sock *sk; + __be16 lport; + int hash; + + lport = laddr->v4.sin_port; + hash = sctp_ep_hashfn(net, ntohs(lport)); + head = &sctp_ep_hashtable[hash]; + read_lock(&head->lock); + sctp_for_each_hentry(ep, &head->chain) { + if (sctp_endpoint_is_match(ep, net, laddr)) + goto hit; + } + + ep = sctp_sk(net->sctp.ctl_sock)->ep; + +hit: + sk = ep->base.sk; + if (sk->sk_reuseport) { + __u32 phash = sctp_hashfn(net, lport, paddr, 0); + + sk = reuseport_select_sock(sk, phash, skb, + sizeof(struct sctphdr)); + if (sk) + ep = sctp_sk(sk)->ep; + } + sctp_endpoint_hold(ep); + read_unlock(&head->lock); + return ep; +} + +/* rhashtable for transport */ +struct sctp_hash_cmp_arg { + const union sctp_addr *paddr; + const struct net *net; + __be16 lport; +}; + +static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + struct sctp_transport *t = (struct sctp_transport *)ptr; + const struct sctp_hash_cmp_arg *x = arg->key; + int err = 1; + + if (!sctp_cmp_addr_exact(&t->ipaddr, x->paddr)) + return err; + if (!sctp_transport_hold(t)) + return err; + + if (!net_eq(t->asoc->base.net, x->net)) + goto out; + if (x->lport != htons(t->asoc->base.bind_addr.port)) + goto out; + + err = 0; +out: + sctp_transport_put(t); + return err; +} + +static inline __u32 sctp_hash_obj(const void *data, u32 len, u32 seed) +{ + const struct sctp_transport *t = data; + + return sctp_hashfn(t->asoc->base.net, + htons(t->asoc->base.bind_addr.port), + &t->ipaddr, seed); +} + +static inline __u32 sctp_hash_key(const void *data, u32 len, u32 seed) +{ + const struct sctp_hash_cmp_arg *x = data; + + return sctp_hashfn(x->net, x->lport, x->paddr, seed); +} + +static const struct rhashtable_params sctp_hash_params = { + .head_offset = offsetof(struct sctp_transport, node), + .hashfn = sctp_hash_key, + .obj_hashfn = sctp_hash_obj, + .obj_cmpfn = sctp_hash_cmp, + .automatic_shrinking = true, +}; + +int sctp_transport_hashtable_init(void) +{ + return rhltable_init(&sctp_transport_hashtable, &sctp_hash_params); +} + +void sctp_transport_hashtable_destroy(void) +{ + rhltable_destroy(&sctp_transport_hashtable); +} + +int sctp_hash_transport(struct sctp_transport *t) +{ + struct sctp_transport *transport; + struct rhlist_head *tmp, *list; + struct sctp_hash_cmp_arg arg; + int err; + + if (t->asoc->temp) + return 0; + + arg.net = t->asoc->base.net; + arg.paddr = &t->ipaddr; + arg.lport = htons(t->asoc->base.bind_addr.port); + + rcu_read_lock(); + list = rhltable_lookup(&sctp_transport_hashtable, &arg, + sctp_hash_params); + + rhl_for_each_entry_rcu(transport, tmp, list, node) + if (transport->asoc->ep == t->asoc->ep) { + rcu_read_unlock(); + return -EEXIST; + } + rcu_read_unlock(); + + err = rhltable_insert_key(&sctp_transport_hashtable, &arg, + &t->node, sctp_hash_params); + if (err) + pr_err_once("insert transport fail, errno %d\n", err); + + return err; +} + +void sctp_unhash_transport(struct sctp_transport *t) +{ + if (t->asoc->temp) + return; + + rhltable_remove(&sctp_transport_hashtable, &t->node, + sctp_hash_params); +} + +/* return a transport with holding it */ +struct sctp_transport *sctp_addrs_lookup_transport( + struct net *net, + const union sctp_addr *laddr, + const union sctp_addr *paddr) +{ + struct rhlist_head *tmp, *list; + struct sctp_transport *t; + struct 
sctp_hash_cmp_arg arg = { + .paddr = paddr, + .net = net, + .lport = laddr->v4.sin_port, + }; + + list = rhltable_lookup(&sctp_transport_hashtable, &arg, + sctp_hash_params); + + rhl_for_each_entry_rcu(t, tmp, list, node) { + if (!sctp_transport_hold(t)) + continue; + + if (sctp_bind_addr_match(&t->asoc->base.bind_addr, + laddr, sctp_sk(t->asoc->base.sk))) + return t; + sctp_transport_put(t); + } + + return NULL; +} + +/* return a transport without holding it, as it's only used under sock lock */ +struct sctp_transport *sctp_epaddr_lookup_transport( + const struct sctp_endpoint *ep, + const union sctp_addr *paddr) +{ + struct rhlist_head *tmp, *list; + struct sctp_transport *t; + struct sctp_hash_cmp_arg arg = { + .paddr = paddr, + .net = ep->base.net, + .lport = htons(ep->base.bind_addr.port), + }; + + list = rhltable_lookup(&sctp_transport_hashtable, &arg, + sctp_hash_params); + + rhl_for_each_entry_rcu(t, tmp, list, node) + if (ep == t->asoc->ep) + return t; + + return NULL; +} + +/* Look up an association. */ +static struct sctp_association *__sctp_lookup_association( + struct net *net, + const union sctp_addr *local, + const union sctp_addr *peer, + struct sctp_transport **pt) +{ + struct sctp_transport *t; + struct sctp_association *asoc = NULL; + + t = sctp_addrs_lookup_transport(net, local, peer); + if (!t) + goto out; + + asoc = t->asoc; + *pt = t; + +out: + return asoc; +} + +/* Look up an association. protected by RCU read lock */ +static +struct sctp_association *sctp_lookup_association(struct net *net, + const union sctp_addr *laddr, + const union sctp_addr *paddr, + struct sctp_transport **transportp) +{ + struct sctp_association *asoc; + + rcu_read_lock(); + asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + rcu_read_unlock(); + + return asoc; +} + +/* Is there an association matching the given local and peer addresses? */ +bool sctp_has_association(struct net *net, + const union sctp_addr *laddr, + const union sctp_addr *paddr) +{ + struct sctp_transport *transport; + + if (sctp_lookup_association(net, laddr, paddr, &transport)) { + sctp_transport_put(transport); + return true; + } + + return false; +} + +/* + * SCTP Implementors Guide, 2.18 Handling of address + * parameters within the INIT or INIT-ACK. + * + * D) When searching for a matching TCB upon reception of an INIT + * or INIT-ACK chunk the receiver SHOULD use not only the + * source address of the packet (containing the INIT or + * INIT-ACK) but the receiver SHOULD also use all valid + * address parameters contained within the chunk. + * + * 2.18.3 Solution description + * + * This new text clearly specifies to an implementor the need + * to look within the INIT or INIT-ACK. Any implementation that + * does not do this, may not be able to establish associations + * in certain circumstances. + * + */ +static struct sctp_association *__sctp_rcv_init_lookup(struct net *net, + struct sk_buff *skb, + const union sctp_addr *laddr, struct sctp_transport **transportp) +{ + struct sctp_association *asoc; + union sctp_addr addr; + union sctp_addr *paddr = &addr; + struct sctphdr *sh = sctp_hdr(skb); + union sctp_params params; + struct sctp_init_chunk *init; + struct sctp_af *af; + + /* + * This code will NOT touch anything inside the chunk--it is + * strictly READ-ONLY. + * + * RFC 2960 3 SCTP packet Format + * + * Multiple chunks can be bundled into one SCTP packet up to + * the MTU size, except for the INIT, INIT ACK, and SHUTDOWN + * COMPLETE chunks. 
These chunks MUST NOT be bundled with any + * other chunk in a packet. See Section 6.10 for more details + * on chunk bundling. + */ + + /* Find the start of the TLVs and the end of the chunk. This is + * the region we search for address parameters. + */ + init = (struct sctp_init_chunk *)skb->data; + + /* Walk the parameters looking for embedded addresses. */ + sctp_walk_params(params, init, init_hdr.params) { + + /* Note: Ignoring hostname addresses. */ + af = sctp_get_af_specific(param_type2af(params.p->type)); + if (!af) + continue; + + if (!af->from_addr_param(paddr, params.addr, sh->source, 0)) + continue; + + asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + if (asoc) + return asoc; + } + + return NULL; +} + +/* ADD-IP, Section 5.2 + * When an endpoint receives an ASCONF Chunk from the remote peer + * special procedures may be needed to identify the association the + * ASCONF Chunk is associated with. To properly find the association + * the following procedures SHOULD be followed: + * + * D2) If the association is not found, use the address found in the + * Address Parameter TLV combined with the port number found in the + * SCTP common header. If found proceed to rule D4. + * + * D2-ext) If more than one ASCONF Chunks are packed together, use the + * address found in the ASCONF Address Parameter TLV of each of the + * subsequent ASCONF Chunks. If found, proceed to rule D4. + */ +static struct sctp_association *__sctp_rcv_asconf_lookup( + struct net *net, + struct sctp_chunkhdr *ch, + const union sctp_addr *laddr, + __be16 peer_port, + struct sctp_transport **transportp) +{ + struct sctp_addip_chunk *asconf = (struct sctp_addip_chunk *)ch; + struct sctp_af *af; + union sctp_addr_param *param; + union sctp_addr paddr; + + if (ntohs(ch->length) < sizeof(*asconf) + sizeof(struct sctp_paramhdr)) + return NULL; + + /* Skip over the ADDIP header and find the Address parameter */ + param = (union sctp_addr_param *)(asconf + 1); + + af = sctp_get_af_specific(param_type2af(param->p.type)); + if (unlikely(!af)) + return NULL; + + if (!af->from_addr_param(&paddr, param, peer_port, 0)) + return NULL; + + return __sctp_lookup_association(net, laddr, &paddr, transportp); +} + + +/* SCTP-AUTH, Section 6.3: +* If the receiver does not find a STCB for a packet containing an AUTH +* chunk as the first chunk and not a COOKIE-ECHO chunk as the second +* chunk, it MUST use the chunks after the AUTH chunk to look up an existing +* association. +* +* This means that any chunks that can help us identify the association need +* to be looked at to find this association. +*/ +static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net, + struct sk_buff *skb, + const union sctp_addr *laddr, + struct sctp_transport **transportp) +{ + struct sctp_association *asoc = NULL; + struct sctp_chunkhdr *ch; + int have_auth = 0; + unsigned int chunk_num = 1; + __u8 *ch_end; + + /* Walk through the chunks looking for AUTH or ASCONF chunks + * to help us find the association. + */ + ch = (struct sctp_chunkhdr *)skb->data; + do { + /* Break out if chunk length is less then minimal. 
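+ * A zero-length chunk would otherwise leave the walk stuck on the same header forever.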
*/ + if (ntohs(ch->length) < sizeof(*ch)) + break; + + ch_end = ((__u8 *)ch) + SCTP_PAD4(ntohs(ch->length)); + if (ch_end > skb_tail_pointer(skb)) + break; + + switch (ch->type) { + case SCTP_CID_AUTH: + have_auth = chunk_num; + break; + + case SCTP_CID_COOKIE_ECHO: + /* If a packet arrives containing an AUTH chunk as + * a first chunk, a COOKIE-ECHO chunk as the second + * chunk, and possibly more chunks after them, and + * the receiver does not have an STCB for that + * packet, then authentication is based on + * the contents of the COOKIE- ECHO chunk. + */ + if (have_auth == 1 && chunk_num == 2) + return NULL; + break; + + case SCTP_CID_ASCONF: + if (have_auth || net->sctp.addip_noauth) + asoc = __sctp_rcv_asconf_lookup( + net, ch, laddr, + sctp_hdr(skb)->source, + transportp); + break; + default: + break; + } + + if (asoc) + break; + + ch = (struct sctp_chunkhdr *)ch_end; + chunk_num++; + } while (ch_end + sizeof(*ch) < skb_tail_pointer(skb)); + + return asoc; +} + +/* + * There are circumstances when we need to look inside the SCTP packet + * for information to help us find the association. Examples + * include looking inside of INIT/INIT-ACK chunks or after the AUTH + * chunks. + */ +static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net, + struct sk_buff *skb, + const union sctp_addr *laddr, + struct sctp_transport **transportp) +{ + struct sctp_chunkhdr *ch; + + /* We do not allow GSO frames here as we need to linearize and + * then cannot guarantee frame boundaries. This shouldn't be an + * issue as packets hitting this are mostly INIT or INIT-ACK and + * those cannot be on GSO-style anyway. + */ + if (skb_is_gso(skb) && skb_is_gso_sctp(skb)) + return NULL; + + ch = (struct sctp_chunkhdr *)skb->data; + + /* The code below will attempt to walk the chunk and extract + * parameter information. Before we do that, we need to verify + * that the chunk length doesn't cause overflow. Otherwise, we'll + * walk off the end. + */ + if (SCTP_PAD4(ntohs(ch->length)) > skb->len) + return NULL; + + /* If this is INIT/INIT-ACK look inside the chunk too. */ + if (ch->type == SCTP_CID_INIT || ch->type == SCTP_CID_INIT_ACK) + return __sctp_rcv_init_lookup(net, skb, laddr, transportp); + + return __sctp_rcv_walk_lookup(net, skb, laddr, transportp); +} + +/* Lookup an association for an inbound skb. */ +static struct sctp_association *__sctp_rcv_lookup(struct net *net, + struct sk_buff *skb, + const union sctp_addr *paddr, + const union sctp_addr *laddr, + struct sctp_transport **transportp) +{ + struct sctp_association *asoc; + + asoc = __sctp_lookup_association(net, laddr, paddr, transportp); + if (asoc) + goto out; + + /* Further lookup for INIT/INIT-ACK packets. + * SCTP Implementors Guide, 2.18 Handling of address + * parameters within the INIT or INIT-ACK. 
+ */ + asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp); + if (asoc) + goto out; + + if (paddr->sa.sa_family == AF_INET) + pr_debug("sctp: asoc not found for src:%pI4:%d dst:%pI4:%d\n", + &laddr->v4.sin_addr, ntohs(laddr->v4.sin_port), + &paddr->v4.sin_addr, ntohs(paddr->v4.sin_port)); + else + pr_debug("sctp: asoc not found for src:%pI6:%d dst:%pI6:%d\n", + &laddr->v6.sin6_addr, ntohs(laddr->v6.sin6_port), + &paddr->v6.sin6_addr, ntohs(paddr->v6.sin6_port)); + +out: + return asoc; +} diff --git a/net/sctp/inqueue.c b/net/sctp/inqueue.c new file mode 100644 index 000000000..7182c5a45 --- /dev/null +++ b/net/sctp/inqueue.c @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2002 International Business Machines, Corp. + * + * This file is part of the SCTP kernel implementation + * + * These functions are the methods for accessing the SCTP inqueue. + * + * An SCTP inqueue is a queue into which you push SCTP packets + * (which might be bundles or fragments of chunks) and out of which you + * pop SCTP whole chunks. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include + +/* Initialize an SCTP inqueue. */ +void sctp_inq_init(struct sctp_inq *queue) +{ + INIT_LIST_HEAD(&queue->in_chunk_list); + queue->in_progress = NULL; + + /* Create a task for delivering data. */ + INIT_WORK(&queue->immediate, NULL); +} + +/* Release the memory associated with an SCTP inqueue. */ +void sctp_inq_free(struct sctp_inq *queue) +{ + struct sctp_chunk *chunk, *tmp; + + /* Empty the queue. */ + list_for_each_entry_safe(chunk, tmp, &queue->in_chunk_list, list) { + list_del_init(&chunk->list); + sctp_chunk_free(chunk); + } + + /* If there is a packet which is currently being worked on, + * free it as well. + */ + if (queue->in_progress) { + sctp_chunk_free(queue->in_progress); + queue->in_progress = NULL; + } +} + +/* Put a new packet in an SCTP inqueue. + * We assume that packet->sctp_hdr is set and in host byte order. + */ +void sctp_inq_push(struct sctp_inq *q, struct sctp_chunk *chunk) +{ + /* Directly call the packet handling routine. */ + if (chunk->rcvr->dead) { + sctp_chunk_free(chunk); + return; + } + + /* We are now calling this either from the soft interrupt + * or from the backlog processing. + * Eventually, we should clean up inqueue to not rely + * on the BH related data structures. + */ + list_add_tail(&chunk->list, &q->in_chunk_list); + if (chunk->asoc) + chunk->asoc->stats.ipackets++; + q->immediate.func(&q->immediate); +} + +/* Peek at the next chunk on the inqeue. */ +struct sctp_chunkhdr *sctp_inq_peek(struct sctp_inq *queue) +{ + struct sctp_chunk *chunk; + struct sctp_chunkhdr *ch = NULL; + + chunk = queue->in_progress; + /* If there is no more chunks in this packet, say so */ + if (chunk->singleton || + chunk->end_of_packet || + chunk->pdiscard) + return NULL; + + ch = (struct sctp_chunkhdr *)chunk->chunk_end; + + return ch; +} + + +/* Extract a chunk from an SCTP inqueue. + * + * WARNING: If you need to put the chunk on another queue, you need to + * make a shallow copy (clone) of it. 
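+ * The same chunk structure is reused for every chunk remaining in the packet, so it cannot simply be linked onto another list.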
+ */ +struct sctp_chunk *sctp_inq_pop(struct sctp_inq *queue) +{ + struct sctp_chunk *chunk; + struct sctp_chunkhdr *ch = NULL; + + /* The assumption is that we are safe to process the chunks + * at this time. + */ + + chunk = queue->in_progress; + if (chunk) { + /* There is a packet that we have been working on. + * Any post processing work to do before we move on? + */ + if (chunk->singleton || + chunk->end_of_packet || + chunk->pdiscard) { + if (chunk->head_skb == chunk->skb) { + chunk->skb = skb_shinfo(chunk->skb)->frag_list; + goto new_skb; + } + if (chunk->skb->next) { + chunk->skb = chunk->skb->next; + goto new_skb; + } + + if (chunk->head_skb) + chunk->skb = chunk->head_skb; + sctp_chunk_free(chunk); + chunk = queue->in_progress = NULL; + } else { + /* Nothing to do. Next chunk in the packet, please. */ + ch = (struct sctp_chunkhdr *)chunk->chunk_end; + /* Force chunk->skb->data to chunk->chunk_end. */ + skb_pull(chunk->skb, chunk->chunk_end - chunk->skb->data); + /* We are guaranteed to pull a SCTP header. */ + } + } + + /* Do we need to take the next packet out of the queue to process? */ + if (!chunk) { + struct list_head *entry; + +next_chunk: + /* Is the queue empty? */ + entry = sctp_list_dequeue(&queue->in_chunk_list); + if (!entry) + return NULL; + + chunk = list_entry(entry, struct sctp_chunk, list); + + if (skb_is_gso(chunk->skb) && skb_is_gso_sctp(chunk->skb)) { + /* GSO-marked skbs but without frags, handle + * them normally + */ + if (skb_shinfo(chunk->skb)->frag_list) + chunk->head_skb = chunk->skb; + + /* skbs with "cover letter" */ + if (chunk->head_skb && chunk->skb->data_len == chunk->skb->len) + chunk->skb = skb_shinfo(chunk->skb)->frag_list; + + if (WARN_ON(!chunk->skb)) { + __SCTP_INC_STATS(dev_net(chunk->skb->dev), SCTP_MIB_IN_PKT_DISCARDS); + sctp_chunk_free(chunk); + goto next_chunk; + } + } + + if (chunk->asoc) + sock_rps_save_rxhash(chunk->asoc->base.sk, chunk->skb); + + queue->in_progress = chunk; + +new_skb: + /* This is the first chunk in the packet. */ + ch = (struct sctp_chunkhdr *)chunk->skb->data; + chunk->singleton = 1; + chunk->data_accepted = 0; + chunk->pdiscard = 0; + chunk->auth = 0; + chunk->has_asconf = 0; + chunk->end_of_packet = 0; + if (chunk->head_skb) { + struct sctp_input_cb + *cb = SCTP_INPUT_CB(chunk->skb), + *head_cb = SCTP_INPUT_CB(chunk->head_skb); + + cb->chunk = head_cb->chunk; + cb->af = head_cb->af; + } + } + + chunk->chunk_hdr = ch; + chunk->chunk_end = ((__u8 *)ch) + SCTP_PAD4(ntohs(ch->length)); + skb_pull(chunk->skb, sizeof(*ch)); + chunk->subh.v = NULL; /* Subheader is no longer valid. */ + + if (chunk->chunk_end + sizeof(*ch) <= skb_tail_pointer(chunk->skb)) { + /* This is not a singleton */ + chunk->singleton = 0; + } else if (chunk->chunk_end > skb_tail_pointer(chunk->skb)) { + /* Discard inside state machine. */ + chunk->pdiscard = 1; + chunk->chunk_end = skb_tail_pointer(chunk->skb); + } else { + /* We are at the end of the packet, so mark the chunk + * in case we need to send a SACK. + */ + chunk->end_of_packet = 1; + } + + pr_debug("+++sctp_inq_pop+++ chunk:%p[%s], length:%d, skb->len:%d\n", + chunk, sctp_cname(SCTP_ST_CHUNK(chunk->chunk_hdr->type)), + ntohs(chunk->chunk_hdr->length), chunk->skb->len); + + return chunk; +} + +/* Set a top-half handler. + * + * Originally, we the top-half handler was scheduled as a BH. We now + * call the handler directly in sctp_inq_push() at a time that + * we know we are lock safe. + * The intent is that this routine will pull stuff out of the + * inqueue and process it. 
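+ * Endpoints install sctp_endpoint_bh_rcv() here; associations set their own bottom-half handler the same way.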
+ */ +void sctp_inq_set_th_handler(struct sctp_inq *q, work_func_t callback) +{ + INIT_WORK(&q->immediate, callback); +} diff --git a/net/sctp/ipv6.c b/net/sctp/ipv6.c new file mode 100644 index 000000000..d081858c2 --- /dev/null +++ b/net/sctp/ipv6.c @@ -0,0 +1,1217 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2002, 2004 + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * Copyright (c) 2002-2003 Intel Corp. + * + * This file is part of the SCTP kernel implementation + * + * SCTP over IPv6. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Le Yanqun + * Hui Huang + * La Monte H.P. Yarroll + * Sridhar Samudrala + * Jon Grimm + * Ardelle Fan + * + * Based on: + * linux/net/ipv6/tcp_ipv6.c + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +static inline int sctp_v6_addr_match_len(union sctp_addr *s1, + union sctp_addr *s2); +static void sctp_v6_to_addr(union sctp_addr *addr, struct in6_addr *saddr, + __be16 port); +static int sctp_v6_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2); + +/* Event handler for inet6 address addition/deletion events. + * The sctp_local_addr_list needs to be protocted by a spin lock since + * multiple notifiers (say IPv4 and IPv6) may be running at the same + * time and thus corrupt the list. + * The reader side is protected with RCU. 
+ */ +static int sctp_inet6addr_event(struct notifier_block *this, unsigned long ev, + void *ptr) +{ + struct inet6_ifaddr *ifa = (struct inet6_ifaddr *)ptr; + struct sctp_sockaddr_entry *addr = NULL; + struct sctp_sockaddr_entry *temp; + struct net *net = dev_net(ifa->idev->dev); + int found = 0; + + switch (ev) { + case NETDEV_UP: + addr = kzalloc(sizeof(*addr), GFP_ATOMIC); + if (addr) { + addr->a.v6.sin6_family = AF_INET6; + addr->a.v6.sin6_addr = ifa->addr; + addr->a.v6.sin6_scope_id = ifa->idev->dev->ifindex; + addr->valid = 1; + spin_lock_bh(&net->sctp.local_addr_lock); + list_add_tail_rcu(&addr->list, &net->sctp.local_addr_list); + sctp_addr_wq_mgmt(net, addr, SCTP_ADDR_NEW); + spin_unlock_bh(&net->sctp.local_addr_lock); + } + break; + case NETDEV_DOWN: + spin_lock_bh(&net->sctp.local_addr_lock); + list_for_each_entry_safe(addr, temp, + &net->sctp.local_addr_list, list) { + if (addr->a.sa.sa_family == AF_INET6 && + ipv6_addr_equal(&addr->a.v6.sin6_addr, + &ifa->addr) && + addr->a.v6.sin6_scope_id == ifa->idev->dev->ifindex) { + sctp_addr_wq_mgmt(net, addr, SCTP_ADDR_DEL); + found = 1; + addr->valid = 0; + list_del_rcu(&addr->list); + break; + } + } + spin_unlock_bh(&net->sctp.local_addr_lock); + if (found) + kfree_rcu(addr, rcu); + break; + } + + return NOTIFY_DONE; +} + +static struct notifier_block sctp_inet6addr_notifier = { + .notifier_call = sctp_inet6addr_event, +}; + +static void sctp_v6_err_handle(struct sctp_transport *t, struct sk_buff *skb, + __u8 type, __u8 code, __u32 info) +{ + struct sctp_association *asoc = t->asoc; + struct sock *sk = asoc->base.sk; + struct ipv6_pinfo *np; + int err = 0; + + switch (type) { + case ICMPV6_PKT_TOOBIG: + if (ip6_sk_accept_pmtu(sk)) + sctp_icmp_frag_needed(sk, asoc, t, info); + return; + case ICMPV6_PARAMPROB: + if (ICMPV6_UNK_NEXTHDR == code) { + sctp_icmp_proto_unreachable(sk, asoc, t); + return; + } + break; + case NDISC_REDIRECT: + sctp_icmp_redirect(sk, t, skb); + return; + default: + break; + } + + np = inet6_sk(sk); + icmpv6_err_convert(type, code, &err); + if (!sock_owned_by_user(sk) && np->recverr) { + sk->sk_err = err; + sk_error_report(sk); + } else { + sk->sk_err_soft = err; + } +} + +/* ICMP error handler. */ +static int sctp_v6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + struct net *net = dev_net(skb->dev); + struct sctp_transport *transport; + struct sctp_association *asoc; + __u16 saveip, savesctp; + struct sock *sk; + + /* Fix up skb to look at the embedded net header. */ + saveip = skb->network_header; + savesctp = skb->transport_header; + skb_reset_network_header(skb); + skb_set_transport_header(skb, offset); + sk = sctp_err_lookup(net, AF_INET6, skb, sctp_hdr(skb), &asoc, &transport); + /* Put back, the original pointers. 
*/ + skb->network_header = saveip; + skb->transport_header = savesctp; + if (!sk) { + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), ICMP6_MIB_INERRORS); + return -ENOENT; + } + + sctp_v6_err_handle(transport, skb, type, code, ntohl(info)); + sctp_err_finish(sk, transport); + + return 0; +} + +int sctp_udp_v6_err(struct sock *sk, struct sk_buff *skb) +{ + struct net *net = dev_net(skb->dev); + struct sctp_association *asoc; + struct sctp_transport *t; + struct icmp6hdr *hdr; + __u32 info = 0; + + skb->transport_header += sizeof(struct udphdr); + sk = sctp_err_lookup(net, AF_INET6, skb, sctp_hdr(skb), &asoc, &t); + if (!sk) { + __ICMP6_INC_STATS(net, __in6_dev_get(skb->dev), ICMP6_MIB_INERRORS); + return -ENOENT; + } + + skb->transport_header -= sizeof(struct udphdr); + hdr = (struct icmp6hdr *)(skb_network_header(skb) - sizeof(struct icmp6hdr)); + if (hdr->icmp6_type == NDISC_REDIRECT) { + /* can't be handled without outer ip6hdr known, leave it to udpv6_err */ + sctp_err_finish(sk, t); + return 0; + } + if (hdr->icmp6_type == ICMPV6_PKT_TOOBIG) + info = ntohl(hdr->icmp6_mtu); + sctp_v6_err_handle(t, skb, hdr->icmp6_type, hdr->icmp6_code, info); + + sctp_err_finish(sk, t); + return 1; +} + +static int sctp_v6_xmit(struct sk_buff *skb, struct sctp_transport *t) +{ + struct dst_entry *dst = dst_clone(t->dst); + struct flowi6 *fl6 = &t->fl.u.ip6; + struct sock *sk = skb->sk; + struct ipv6_pinfo *np = inet6_sk(sk); + __u8 tclass = np->tclass; + __be32 label; + + pr_debug("%s: skb:%p, len:%d, src:%pI6 dst:%pI6\n", __func__, skb, + skb->len, &fl6->saddr, &fl6->daddr); + + if (t->dscp & SCTP_DSCP_SET_MASK) + tclass = t->dscp & SCTP_DSCP_VAL_MASK; + + if (INET_ECN_is_capable(tclass)) + IP6_ECN_flow_xmit(sk, fl6->flowlabel); + + if (!(t->param_flags & SPP_PMTUD_ENABLE)) + skb->ignore_df = 1; + + SCTP_INC_STATS(sock_net(sk), SCTP_MIB_OUTSCTPPACKS); + + if (!t->encap_port || !sctp_sk(sk)->udp_port) { + int res; + + skb_dst_set(skb, dst); + rcu_read_lock(); + res = ip6_xmit(sk, skb, fl6, sk->sk_mark, + rcu_dereference(np->opt), + tclass, sk->sk_priority); + rcu_read_unlock(); + return res; + } + + if (skb_is_gso(skb)) + skb_shinfo(skb)->gso_type |= SKB_GSO_UDP_TUNNEL_CSUM; + + skb->encapsulation = 1; + skb_reset_inner_mac_header(skb); + skb_reset_inner_transport_header(skb); + skb_set_inner_ipproto(skb, IPPROTO_SCTP); + label = ip6_make_flowlabel(sock_net(sk), skb, fl6->flowlabel, true, fl6); + + return udp_tunnel6_xmit_skb(dst, sk, skb, NULL, &fl6->saddr, + &fl6->daddr, tclass, ip6_dst_hoplimit(dst), + label, sctp_sk(sk)->udp_port, t->encap_port, false); +} + +/* Returns the dst cache entry for the given source and destination ip + * addresses. 
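+ * When the caller does not supply a source address, the association's bind address list is walked to find one that fits the chosen route.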
+ */ +static void sctp_v6_get_dst(struct sctp_transport *t, union sctp_addr *saddr, + struct flowi *fl, struct sock *sk) +{ + struct sctp_association *asoc = t->asoc; + struct dst_entry *dst = NULL; + struct flowi _fl; + struct flowi6 *fl6 = &_fl.u.ip6; + struct sctp_bind_addr *bp; + struct ipv6_pinfo *np = inet6_sk(sk); + struct sctp_sockaddr_entry *laddr; + union sctp_addr *daddr = &t->ipaddr; + union sctp_addr dst_saddr; + struct in6_addr *final_p, final; + enum sctp_scope scope; + __u8 matchlen = 0; + + memset(&_fl, 0, sizeof(_fl)); + fl6->daddr = daddr->v6.sin6_addr; + fl6->fl6_dport = daddr->v6.sin6_port; + fl6->flowi6_proto = IPPROTO_SCTP; + if (ipv6_addr_type(&daddr->v6.sin6_addr) & IPV6_ADDR_LINKLOCAL) + fl6->flowi6_oif = daddr->v6.sin6_scope_id; + else if (asoc) + fl6->flowi6_oif = asoc->base.sk->sk_bound_dev_if; + if (t->flowlabel & SCTP_FLOWLABEL_SET_MASK) + fl6->flowlabel = htonl(t->flowlabel & SCTP_FLOWLABEL_VAL_MASK); + + if (np->sndflow && (fl6->flowlabel & IPV6_FLOWLABEL_MASK)) { + struct ip6_flowlabel *flowlabel; + + flowlabel = fl6_sock_lookup(sk, fl6->flowlabel); + if (IS_ERR(flowlabel)) + goto out; + fl6_sock_release(flowlabel); + } + + pr_debug("%s: dst=%pI6 ", __func__, &fl6->daddr); + + if (asoc) + fl6->fl6_sport = htons(asoc->base.bind_addr.port); + + if (saddr) { + fl6->saddr = saddr->v6.sin6_addr; + if (!fl6->fl6_sport) + fl6->fl6_sport = saddr->v6.sin6_port; + + pr_debug("src=%pI6 - ", &fl6->saddr); + } + + rcu_read_lock(); + final_p = fl6_update_dst(fl6, rcu_dereference(np->opt), &final); + rcu_read_unlock(); + + dst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_p); + if (!asoc || saddr) { + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + goto out; + } + + bp = &asoc->base.bind_addr; + scope = sctp_scope(daddr); + /* ip6_dst_lookup has filled in the fl6->saddr for us. Check + * to see if we can use it. + */ + if (!IS_ERR(dst)) { + /* Walk through the bind address list and look for a bind + * address that matches the source address of the returned dst. + */ + sctp_v6_to_addr(&dst_saddr, &fl6->saddr, htons(bp->port)); + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + if (!laddr->valid || laddr->state == SCTP_ADDR_DEL || + (laddr->state != SCTP_ADDR_SRC && + !asoc->src_out_of_asoc_ok)) + continue; + + /* Do not compare against v4 addrs */ + if ((laddr->a.sa.sa_family == AF_INET6) && + (sctp_v6_cmp_addr(&dst_saddr, &laddr->a))) { + rcu_read_unlock(); + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + goto out; + } + } + rcu_read_unlock(); + /* None of the bound addresses match the source address of the + * dst. So release it. + */ + dst_release(dst); + dst = NULL; + } + + /* Walk through the bind address list and try to get the + * best source address for a given destination. 
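+ * Candidates are ranked by sctp_v6_addr_match_len(), i.e. the length
+ * of the prefix they share with the destination; a candidate that is
+ * also a local address on the route's device (ipv6_chk_addr()) is
+ * taken immediately and ends the walk.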
+ */ + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + struct dst_entry *bdst; + __u8 bmatchlen; + + if (!laddr->valid || + laddr->state != SCTP_ADDR_SRC || + laddr->a.sa.sa_family != AF_INET6 || + scope > sctp_scope(&laddr->a)) + continue; + + fl6->saddr = laddr->a.v6.sin6_addr; + fl6->fl6_sport = laddr->a.v6.sin6_port; + final_p = fl6_update_dst(fl6, rcu_dereference(np->opt), &final); + bdst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_p); + + if (IS_ERR(bdst)) + continue; + + if (ipv6_chk_addr(dev_net(bdst->dev), + &laddr->a.v6.sin6_addr, bdst->dev, 1)) { + if (!IS_ERR_OR_NULL(dst)) + dst_release(dst); + dst = bdst; + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + break; + } + + bmatchlen = sctp_v6_addr_match_len(daddr, &laddr->a); + if (matchlen > bmatchlen) { + dst_release(bdst); + continue; + } + + if (!IS_ERR_OR_NULL(dst)) + dst_release(dst); + dst = bdst; + matchlen = bmatchlen; + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + } + rcu_read_unlock(); + +out: + if (!IS_ERR_OR_NULL(dst)) { + struct rt6_info *rt; + + rt = (struct rt6_info *)dst; + t->dst_cookie = rt6_get_cookie(rt); + pr_debug("rt6_dst:%pI6/%d rt6_src:%pI6\n", + &rt->rt6i_dst.addr, rt->rt6i_dst.plen, + &fl->u.ip6.saddr); + } else { + t->dst = NULL; + pr_debug("no route\n"); + } +} + +/* Returns the number of consecutive initial bits that match in the 2 ipv6 + * addresses. + */ +static inline int sctp_v6_addr_match_len(union sctp_addr *s1, + union sctp_addr *s2) +{ + return ipv6_addr_diff(&s1->v6.sin6_addr, &s2->v6.sin6_addr); +} + +/* Fills in the source address(saddr) based on the destination address(daddr) + * and asoc's bind address list. + */ +static void sctp_v6_get_saddr(struct sctp_sock *sk, + struct sctp_transport *t, + struct flowi *fl) +{ + struct flowi6 *fl6 = &fl->u.ip6; + union sctp_addr *saddr = &t->saddr; + + pr_debug("%s: asoc:%p dst:%p\n", __func__, t->asoc, t->dst); + + if (t->dst) { + saddr->v6.sin6_family = AF_INET6; + saddr->v6.sin6_addr = fl6->saddr; + } +} + +/* Make a copy of all potential local addresses. */ +static void sctp_v6_copy_addrlist(struct list_head *addrlist, + struct net_device *dev) +{ + struct inet6_dev *in6_dev; + struct inet6_ifaddr *ifp; + struct sctp_sockaddr_entry *addr; + + rcu_read_lock(); + if ((in6_dev = __in6_dev_get(dev)) == NULL) { + rcu_read_unlock(); + return; + } + + read_lock_bh(&in6_dev->lock); + list_for_each_entry(ifp, &in6_dev->addr_list, if_list) { + /* Add the address to the local list. 
*/ + addr = kzalloc(sizeof(*addr), GFP_ATOMIC); + if (addr) { + addr->a.v6.sin6_family = AF_INET6; + addr->a.v6.sin6_addr = ifp->addr; + addr->a.v6.sin6_scope_id = dev->ifindex; + addr->valid = 1; + INIT_LIST_HEAD(&addr->list); + list_add_tail(&addr->list, addrlist); + } + } + + read_unlock_bh(&in6_dev->lock); + rcu_read_unlock(); +} + +/* Copy over any ip options */ +static void sctp_v6_copy_ip_options(struct sock *sk, struct sock *newsk) +{ + struct ipv6_pinfo *newnp, *np = inet6_sk(sk); + struct ipv6_txoptions *opt; + + newnp = inet6_sk(newsk); + + rcu_read_lock(); + opt = rcu_dereference(np->opt); + if (opt) { + opt = ipv6_dup_options(newsk, opt); + if (!opt) + pr_err("%s: Failed to copy ip options\n", __func__); + } + RCU_INIT_POINTER(newnp->opt, opt); + rcu_read_unlock(); +} + +/* Account for the IP options */ +static int sctp_v6_ip_options_len(struct sock *sk) +{ + struct ipv6_pinfo *np = inet6_sk(sk); + struct ipv6_txoptions *opt; + int len = 0; + + rcu_read_lock(); + opt = rcu_dereference(np->opt); + if (opt) + len = opt->opt_flen + opt->opt_nflen; + + rcu_read_unlock(); + return len; +} + +/* Initialize a sockaddr_storage from in incoming skb. */ +static void sctp_v6_from_skb(union sctp_addr *addr, struct sk_buff *skb, + int is_saddr) +{ + /* Always called on head skb, so this is safe */ + struct sctphdr *sh = sctp_hdr(skb); + struct sockaddr_in6 *sa = &addr->v6; + + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_flowinfo = 0; /* FIXME */ + addr->v6.sin6_scope_id = ((struct inet6_skb_parm *)skb->cb)->iif; + + if (is_saddr) { + sa->sin6_port = sh->source; + sa->sin6_addr = ipv6_hdr(skb)->saddr; + } else { + sa->sin6_port = sh->dest; + sa->sin6_addr = ipv6_hdr(skb)->daddr; + } +} + +/* Initialize an sctp_addr from a socket. */ +static void sctp_v6_from_sk(union sctp_addr *addr, struct sock *sk) +{ + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_port = 0; + addr->v6.sin6_addr = sk->sk_v6_rcv_saddr; +} + +/* Initialize sk->sk_rcv_saddr from sctp_addr. */ +static void sctp_v6_to_sk_saddr(union sctp_addr *addr, struct sock *sk) +{ + if (addr->sa.sa_family == AF_INET) { + sk->sk_v6_rcv_saddr.s6_addr32[0] = 0; + sk->sk_v6_rcv_saddr.s6_addr32[1] = 0; + sk->sk_v6_rcv_saddr.s6_addr32[2] = htonl(0x0000ffff); + sk->sk_v6_rcv_saddr.s6_addr32[3] = + addr->v4.sin_addr.s_addr; + } else { + sk->sk_v6_rcv_saddr = addr->v6.sin6_addr; + } +} + +/* Initialize sk->sk_daddr from sctp_addr. */ +static void sctp_v6_to_sk_daddr(union sctp_addr *addr, struct sock *sk) +{ + if (addr->sa.sa_family == AF_INET) { + sk->sk_v6_daddr.s6_addr32[0] = 0; + sk->sk_v6_daddr.s6_addr32[1] = 0; + sk->sk_v6_daddr.s6_addr32[2] = htonl(0x0000ffff); + sk->sk_v6_daddr.s6_addr32[3] = addr->v4.sin_addr.s_addr; + } else { + sk->sk_v6_daddr = addr->v6.sin6_addr; + } +} + +/* Initialize a sctp_addr from an address parameter. */ +static bool sctp_v6_from_addr_param(union sctp_addr *addr, + union sctp_addr_param *param, + __be16 port, int iif) +{ + if (ntohs(param->v6.param_hdr.length) < sizeof(struct sctp_ipv6addr_param)) + return false; + + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_port = port; + addr->v6.sin6_flowinfo = 0; /* BUG */ + addr->v6.sin6_addr = param->v6.addr; + addr->v6.sin6_scope_id = iif; + + return true; +} + +/* Initialize an address parameter from a sctp_addr and return the length + * of the address parameter. 
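+ * The parameter is a simple TLV: a 4-byte parameter header carrying
+ * SCTP_PARAM_IPV6_ADDRESS and the total length, followed by the
+ * 16-byte IPv6 address (20 bytes in all).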
+ */ +static int sctp_v6_to_addr_param(const union sctp_addr *addr, + union sctp_addr_param *param) +{ + int length = sizeof(struct sctp_ipv6addr_param); + + param->v6.param_hdr.type = SCTP_PARAM_IPV6_ADDRESS; + param->v6.param_hdr.length = htons(length); + param->v6.addr = addr->v6.sin6_addr; + + return length; +} + +/* Initialize a sctp_addr from struct in6_addr. */ +static void sctp_v6_to_addr(union sctp_addr *addr, struct in6_addr *saddr, + __be16 port) +{ + addr->sa.sa_family = AF_INET6; + addr->v6.sin6_port = port; + addr->v6.sin6_flowinfo = 0; + addr->v6.sin6_addr = *saddr; + addr->v6.sin6_scope_id = 0; +} + +static int __sctp_v6_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2) +{ + if (addr1->sa.sa_family != addr2->sa.sa_family) { + if (addr1->sa.sa_family == AF_INET && + addr2->sa.sa_family == AF_INET6 && + ipv6_addr_v4mapped(&addr2->v6.sin6_addr) && + addr2->v6.sin6_addr.s6_addr32[3] == + addr1->v4.sin_addr.s_addr) + return 1; + + if (addr2->sa.sa_family == AF_INET && + addr1->sa.sa_family == AF_INET6 && + ipv6_addr_v4mapped(&addr1->v6.sin6_addr) && + addr1->v6.sin6_addr.s6_addr32[3] == + addr2->v4.sin_addr.s_addr) + return 1; + + return 0; + } + + if (!ipv6_addr_equal(&addr1->v6.sin6_addr, &addr2->v6.sin6_addr)) + return 0; + + /* If this is a linklocal address, compare the scope_id. */ + if ((ipv6_addr_type(&addr1->v6.sin6_addr) & IPV6_ADDR_LINKLOCAL) && + addr1->v6.sin6_scope_id && addr2->v6.sin6_scope_id && + addr1->v6.sin6_scope_id != addr2->v6.sin6_scope_id) + return 0; + + return 1; +} + +/* Compare addresses exactly. + * v4-mapped-v6 is also in consideration. + */ +static int sctp_v6_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2) +{ + return __sctp_v6_cmp_addr(addr1, addr2) && + addr1->v6.sin6_port == addr2->v6.sin6_port; +} + +/* Initialize addr struct to INADDR_ANY. */ +static void sctp_v6_inaddr_any(union sctp_addr *addr, __be16 port) +{ + memset(addr, 0x00, sizeof(union sctp_addr)); + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_port = port; +} + +/* Is this a wildcard address? */ +static int sctp_v6_is_any(const union sctp_addr *addr) +{ + return ipv6_addr_any(&addr->v6.sin6_addr); +} + +/* Should this be available for binding? */ +static int sctp_v6_available(union sctp_addr *addr, struct sctp_sock *sp) +{ + int type; + struct net *net = sock_net(&sp->inet.sk); + const struct in6_addr *in6 = (const struct in6_addr *)&addr->v6.sin6_addr; + + type = ipv6_addr_type(in6); + if (IPV6_ADDR_ANY == type) + return 1; + if (type == IPV6_ADDR_MAPPED) { + if (sp && ipv6_only_sock(sctp_opt2sk(sp))) + return 0; + sctp_v6_map_v4(addr); + return sctp_get_af_specific(AF_INET)->available(addr, sp); + } + if (!(type & IPV6_ADDR_UNICAST)) + return 0; + + return ipv6_can_nonlocal_bind(net, &sp->inet) || + ipv6_chk_addr(net, in6, NULL, 0); +} + +/* This function checks if the address is a valid address to be used for + * SCTP. + * + * Output: + * Return 0 - If the address is a non-unicast or an illegal address. + * Return 1 - If the address is a unicast. + */ +static int sctp_v6_addr_valid(union sctp_addr *addr, + struct sctp_sock *sp, + const struct sk_buff *skb) +{ + int ret = ipv6_addr_type(&addr->v6.sin6_addr); + + /* Support v4-mapped-v6 address. */ + if (ret == IPV6_ADDR_MAPPED) { + /* Note: This routine is used in input, so v4-mapped-v6 + * are disallowed here when there is no sctp_sock. 
+ */ + if (sp && ipv6_only_sock(sctp_opt2sk(sp))) + return 0; + sctp_v6_map_v4(addr); + return sctp_get_af_specific(AF_INET)->addr_valid(addr, sp, skb); + } + + /* Is this a non-unicast address */ + if (!(ret & IPV6_ADDR_UNICAST)) + return 0; + + return 1; +} + +/* What is the scope of 'addr'? */ +static enum sctp_scope sctp_v6_scope(union sctp_addr *addr) +{ + enum sctp_scope retval; + int v6scope; + + /* The IPv6 scope is really a set of bit fields. + * See IFA_* in . Map to a generic SCTP scope. + */ + + v6scope = ipv6_addr_scope(&addr->v6.sin6_addr); + switch (v6scope) { + case IFA_HOST: + retval = SCTP_SCOPE_LOOPBACK; + break; + case IFA_LINK: + retval = SCTP_SCOPE_LINK; + break; + case IFA_SITE: + retval = SCTP_SCOPE_PRIVATE; + break; + default: + retval = SCTP_SCOPE_GLOBAL; + break; + } + + return retval; +} + +/* Create and initialize a new sk for the socket to be returned by accept(). */ +static struct sock *sctp_v6_create_accept_sk(struct sock *sk, + struct sctp_association *asoc, + bool kern) +{ + struct sock *newsk; + struct ipv6_pinfo *newnp, *np = inet6_sk(sk); + struct sctp6_sock *newsctp6sk; + + newsk = sk_alloc(sock_net(sk), PF_INET6, GFP_KERNEL, sk->sk_prot, kern); + if (!newsk) + goto out; + + sock_init_data(NULL, newsk); + + sctp_copy_sock(newsk, sk, asoc); + sock_reset_flag(sk, SOCK_ZAPPED); + + newsctp6sk = (struct sctp6_sock *)newsk; + inet_sk(newsk)->pinet6 = &newsctp6sk->inet6; + + sctp_sk(newsk)->v4mapped = sctp_sk(sk)->v4mapped; + + newnp = inet6_sk(newsk); + + memcpy(newnp, np, sizeof(struct ipv6_pinfo)); + newnp->ipv6_mc_list = NULL; + newnp->ipv6_ac_list = NULL; + newnp->ipv6_fl_list = NULL; + + sctp_v6_copy_ip_options(sk, newsk); + + /* Initialize sk's sport, dport, rcv_saddr and daddr for getsockname() + * and getpeername(). + */ + sctp_v6_to_sk_daddr(&asoc->peer.primary_addr, newsk); + + newsk->sk_v6_rcv_saddr = sk->sk_v6_rcv_saddr; + + sk_refcnt_debug_inc(newsk); + + if (newsk->sk_prot->init(newsk)) { + sk_common_release(newsk); + newsk = NULL; + } + +out: + return newsk; +} + +/* Format a sockaddr for return to user space. This makes sure the return is + * AF_INET or AF_INET6 depending on the SCTP_I_WANT_MAPPED_V4_ADDR option. + */ +static int sctp_v6_addr_to_user(struct sctp_sock *sp, union sctp_addr *addr) +{ + if (sp->v4mapped) { + if (addr->sa.sa_family == AF_INET) + sctp_v4_map_v6(addr); + } else { + if (addr->sa.sa_family == AF_INET6 && + ipv6_addr_v4mapped(&addr->v6.sin6_addr)) + sctp_v6_map_v4(addr); + } + + if (addr->sa.sa_family == AF_INET) { + memset(addr->v4.sin_zero, 0, sizeof(addr->v4.sin_zero)); + return sizeof(struct sockaddr_in); + } + return sizeof(struct sockaddr_in6); +} + +/* Where did this skb come from? */ +static int sctp_v6_skb_iif(const struct sk_buff *skb) +{ + return IP6CB(skb)->iif; +} + +/* Was this packet marked by Explicit Congestion Notification? */ +static int sctp_v6_is_ce(const struct sk_buff *skb) +{ + return *((__u32 *)(ipv6_hdr(skb))) & (__force __u32)htonl(1 << 20); +} + +/* Dump the v6 addr to the seq file. */ +static void sctp_v6_seq_dump_addr(struct seq_file *seq, union sctp_addr *addr) +{ + seq_printf(seq, "%pI6 ", &addr->v6.sin6_addr); +} + +static void sctp_v6_ecn_capable(struct sock *sk) +{ + inet6_sk(sk)->tclass |= INET_ECN_ECT_0; +} + +/* Initialize a PF_INET msgname from a ulpevent. 
*/ +static void sctp_inet6_event_msgname(struct sctp_ulpevent *event, + char *msgname, int *addrlen) +{ + union sctp_addr *addr; + struct sctp_association *asoc; + union sctp_addr *paddr; + + if (!msgname) + return; + + addr = (union sctp_addr *)msgname; + asoc = event->asoc; + paddr = &asoc->peer.primary_addr; + + if (paddr->sa.sa_family == AF_INET) { + addr->v4.sin_family = AF_INET; + addr->v4.sin_port = htons(asoc->peer.port); + addr->v4.sin_addr = paddr->v4.sin_addr; + } else { + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_flowinfo = 0; + if (ipv6_addr_type(&paddr->v6.sin6_addr) & IPV6_ADDR_LINKLOCAL) + addr->v6.sin6_scope_id = paddr->v6.sin6_scope_id; + else + addr->v6.sin6_scope_id = 0; + addr->v6.sin6_port = htons(asoc->peer.port); + addr->v6.sin6_addr = paddr->v6.sin6_addr; + } + + *addrlen = sctp_v6_addr_to_user(sctp_sk(asoc->base.sk), addr); +} + +/* Initialize a msg_name from an inbound skb. */ +static void sctp_inet6_skb_msgname(struct sk_buff *skb, char *msgname, + int *addr_len) +{ + union sctp_addr *addr; + struct sctphdr *sh; + + if (!msgname) + return; + + addr = (union sctp_addr *)msgname; + sh = sctp_hdr(skb); + + if (ip_hdr(skb)->version == 4) { + addr->v4.sin_family = AF_INET; + addr->v4.sin_port = sh->source; + addr->v4.sin_addr.s_addr = ip_hdr(skb)->saddr; + } else { + addr->v6.sin6_family = AF_INET6; + addr->v6.sin6_flowinfo = 0; + addr->v6.sin6_port = sh->source; + addr->v6.sin6_addr = ipv6_hdr(skb)->saddr; + if (ipv6_addr_type(&addr->v6.sin6_addr) & IPV6_ADDR_LINKLOCAL) + addr->v6.sin6_scope_id = sctp_v6_skb_iif(skb); + else + addr->v6.sin6_scope_id = 0; + } + + *addr_len = sctp_v6_addr_to_user(sctp_sk(skb->sk), addr); +} + +/* Do we support this AF? */ +static int sctp_inet6_af_supported(sa_family_t family, struct sctp_sock *sp) +{ + switch (family) { + case AF_INET6: + return 1; + /* v4-mapped-v6 addresses */ + case AF_INET: + if (!ipv6_only_sock(sctp_opt2sk(sp))) + return 1; + fallthrough; + default: + return 0; + } +} + +/* Address matching with wildcards allowed. This extra level + * of indirection lets us choose whether a PF_INET6 should + * disallow any v4 addresses if we so choose. + */ +static int sctp_inet6_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2, + struct sctp_sock *opt) +{ + struct sock *sk = sctp_opt2sk(opt); + struct sctp_af *af1, *af2; + + af1 = sctp_get_af_specific(addr1->sa.sa_family); + af2 = sctp_get_af_specific(addr2->sa.sa_family); + + if (!af1 || !af2) + return 0; + + /* If the socket is IPv6 only, v4 addrs will not match */ + if (ipv6_only_sock(sk) && af1 != af2) + return 0; + + /* Today, wildcard AF_INET/AF_INET6. */ + if (sctp_is_any(sk, addr1) || sctp_is_any(sk, addr2)) + return 1; + + if (addr1->sa.sa_family == AF_INET && addr2->sa.sa_family == AF_INET) + return addr1->v4.sin_addr.s_addr == addr2->v4.sin_addr.s_addr; + + return __sctp_v6_cmp_addr(addr1, addr2); +} + +/* Verify that the provided sockaddr looks bindable. Common verification, + * has already been taken care of. + */ +static int sctp_inet6_bind_verify(struct sctp_sock *opt, union sctp_addr *addr) +{ + struct sctp_af *af; + + /* ASSERT: address family has already been verified. 
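+ * For link-local addresses the check below additionally requires a
+ * non-zero sin6_scope_id that names an existing device and, unless
+ * nonlocal binds are allowed, a device that actually carries the
+ * address.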
*/ + if (addr->sa.sa_family != AF_INET6) + af = sctp_get_af_specific(addr->sa.sa_family); + else { + int type = ipv6_addr_type(&addr->v6.sin6_addr); + struct net_device *dev; + + if (type & IPV6_ADDR_LINKLOCAL) { + struct net *net; + if (!addr->v6.sin6_scope_id) + return 0; + net = sock_net(&opt->inet.sk); + rcu_read_lock(); + dev = dev_get_by_index_rcu(net, addr->v6.sin6_scope_id); + if (!dev || !(ipv6_can_nonlocal_bind(net, &opt->inet) || + ipv6_chk_addr(net, &addr->v6.sin6_addr, + dev, 0))) { + rcu_read_unlock(); + return 0; + } + rcu_read_unlock(); + } + + af = opt->pf->af; + } + return af->available(addr, opt); +} + +/* Verify that the provided sockaddr looks sendable. Common verification, + * has already been taken care of. + */ +static int sctp_inet6_send_verify(struct sctp_sock *opt, union sctp_addr *addr) +{ + struct sctp_af *af = NULL; + + /* ASSERT: address family has already been verified. */ + if (addr->sa.sa_family != AF_INET6) + af = sctp_get_af_specific(addr->sa.sa_family); + else { + int type = ipv6_addr_type(&addr->v6.sin6_addr); + struct net_device *dev; + + if (type & IPV6_ADDR_LINKLOCAL) { + if (!addr->v6.sin6_scope_id) + return 0; + rcu_read_lock(); + dev = dev_get_by_index_rcu(sock_net(&opt->inet.sk), + addr->v6.sin6_scope_id); + rcu_read_unlock(); + if (!dev) + return 0; + } + af = opt->pf->af; + } + + return af != NULL; +} + +/* Fill in Supported Address Type information for INIT and INIT-ACK + * chunks. Note: In the future, we may want to look at sock options + * to determine whether a PF_INET6 socket really wants to have IPV4 + * addresses. + * Returns number of addresses supported. + */ +static int sctp_inet6_supported_addrs(const struct sctp_sock *opt, + __be16 *types) +{ + types[0] = SCTP_PARAM_IPV6_ADDRESS; + if (!opt || !ipv6_only_sock(sctp_opt2sk(opt))) { + types[1] = SCTP_PARAM_IPV4_ADDRESS; + return 2; + } + return 1; +} + +/* Handle SCTP_I_WANT_MAPPED_V4_ADDR for getpeername() and getsockname() */ +static int sctp_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + int rc; + + rc = inet6_getname(sock, uaddr, peer); + + if (rc < 0) + return rc; + + rc = sctp_v6_addr_to_user(sctp_sk(sock->sk), + (union sctp_addr *)uaddr); + + return rc; +} + +static const struct proto_ops inet6_seqpacket_ops = { + .family = PF_INET6, + .owner = THIS_MODULE, + .release = inet6_release, + .bind = inet6_bind, + .connect = sctp_inet_connect, + .socketpair = sock_no_socketpair, + .accept = inet_accept, + .getname = sctp_getname, + .poll = sctp_poll, + .ioctl = inet6_ioctl, + .gettstamp = sock_gettstamp, + .listen = sctp_inet_listen, + .shutdown = inet_shutdown, + .setsockopt = sock_common_setsockopt, + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = inet_recvmsg, + .mmap = sock_no_mmap, +#ifdef CONFIG_COMPAT + .compat_ioctl = inet6_compat_ioctl, +#endif +}; + +static struct inet_protosw sctpv6_seqpacket_protosw = { + .type = SOCK_SEQPACKET, + .protocol = IPPROTO_SCTP, + .prot = &sctpv6_prot, + .ops = &inet6_seqpacket_ops, + .flags = SCTP_PROTOSW_FLAG +}; +static struct inet_protosw sctpv6_stream_protosw = { + .type = SOCK_STREAM, + .protocol = IPPROTO_SCTP, + .prot = &sctpv6_prot, + .ops = &inet6_seqpacket_ops, + .flags = SCTP_PROTOSW_FLAG, +}; + +static int sctp6_rcv(struct sk_buff *skb) +{ + SCTP_INPUT_CB(skb)->encap_port = 0; + return sctp_rcv(skb) ? 
-1 : 0; +} + +static const struct inet6_protocol sctpv6_protocol = { + .handler = sctp6_rcv, + .err_handler = sctp_v6_err, + .flags = INET6_PROTO_NOPOLICY | INET6_PROTO_FINAL, +}; + +static struct sctp_af sctp_af_inet6 = { + .sa_family = AF_INET6, + .sctp_xmit = sctp_v6_xmit, + .setsockopt = ipv6_setsockopt, + .getsockopt = ipv6_getsockopt, + .get_dst = sctp_v6_get_dst, + .get_saddr = sctp_v6_get_saddr, + .copy_addrlist = sctp_v6_copy_addrlist, + .from_skb = sctp_v6_from_skb, + .from_sk = sctp_v6_from_sk, + .from_addr_param = sctp_v6_from_addr_param, + .to_addr_param = sctp_v6_to_addr_param, + .cmp_addr = sctp_v6_cmp_addr, + .scope = sctp_v6_scope, + .addr_valid = sctp_v6_addr_valid, + .inaddr_any = sctp_v6_inaddr_any, + .is_any = sctp_v6_is_any, + .available = sctp_v6_available, + .skb_iif = sctp_v6_skb_iif, + .is_ce = sctp_v6_is_ce, + .seq_dump_addr = sctp_v6_seq_dump_addr, + .ecn_capable = sctp_v6_ecn_capable, + .net_header_len = sizeof(struct ipv6hdr), + .sockaddr_len = sizeof(struct sockaddr_in6), + .ip_options_len = sctp_v6_ip_options_len, +}; + +static struct sctp_pf sctp_pf_inet6 = { + .event_msgname = sctp_inet6_event_msgname, + .skb_msgname = sctp_inet6_skb_msgname, + .af_supported = sctp_inet6_af_supported, + .cmp_addr = sctp_inet6_cmp_addr, + .bind_verify = sctp_inet6_bind_verify, + .send_verify = sctp_inet6_send_verify, + .supported_addrs = sctp_inet6_supported_addrs, + .create_accept_sk = sctp_v6_create_accept_sk, + .addr_to_user = sctp_v6_addr_to_user, + .to_sk_saddr = sctp_v6_to_sk_saddr, + .to_sk_daddr = sctp_v6_to_sk_daddr, + .copy_ip_options = sctp_v6_copy_ip_options, + .af = &sctp_af_inet6, +}; + +/* Initialize IPv6 support and register with socket layer. */ +void sctp_v6_pf_init(void) +{ + /* Register the SCTP specific PF_INET6 functions. */ + sctp_register_pf(&sctp_pf_inet6, PF_INET6); + + /* Register the SCTP specific AF_INET6 functions. */ + sctp_register_af(&sctp_af_inet6); +} + +void sctp_v6_pf_exit(void) +{ + list_del(&sctp_af_inet6.list); +} + +/* Initialize IPv6 support and register with socket layer. */ +int sctp_v6_protosw_init(void) +{ + int rc; + + rc = proto_register(&sctpv6_prot, 1); + if (rc) + return rc; + + /* Add SCTPv6(UDP and TCP style) to inetsw6 linked list. */ + inet6_register_protosw(&sctpv6_seqpacket_protosw); + inet6_register_protosw(&sctpv6_stream_protosw); + + return 0; +} + +void sctp_v6_protosw_exit(void) +{ + inet6_unregister_protosw(&sctpv6_seqpacket_protosw); + inet6_unregister_protosw(&sctpv6_stream_protosw); + proto_unregister(&sctpv6_prot); +} + + +/* Register with inet6 layer. */ +int sctp_v6_add_protocol(void) +{ + /* Register notifier for inet6 address additions/deletions. */ + register_inet6addr_notifier(&sctp_inet6addr_notifier); + + if (inet6_add_protocol(&sctpv6_protocol, IPPROTO_SCTP) < 0) + return -EAGAIN; + + return 0; +} + +/* Unregister with inet6 layer. */ +void sctp_v6_del_protocol(void) +{ + inet6_del_protocol(&sctpv6_protocol, IPPROTO_SCTP); + unregister_inet6addr_notifier(&sctp_inet6addr_notifier); +} diff --git a/net/sctp/objcnt.c b/net/sctp/objcnt.c new file mode 100644 index 000000000..0400c964e --- /dev/null +++ b/net/sctp/objcnt.c @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * + * This file is part of the SCTP kernel implementation + * + * Support for memory object debugging. This allows one to monitor the + * object allocations/deallocations for types instrumented for this + * via the proc fs. 
+ * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Jon Grimm + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include + +/* + * Global counters to count raw object allocation counts. + * To add new counters, choose a unique suffix for the variable + * name as the helper macros key off this suffix to make + * life easier for the programmer. + */ + +SCTP_DBG_OBJCNT(sock); +SCTP_DBG_OBJCNT(ep); +SCTP_DBG_OBJCNT(transport); +SCTP_DBG_OBJCNT(assoc); +SCTP_DBG_OBJCNT(bind_addr); +SCTP_DBG_OBJCNT(bind_bucket); +SCTP_DBG_OBJCNT(chunk); +SCTP_DBG_OBJCNT(addr); +SCTP_DBG_OBJCNT(datamsg); +SCTP_DBG_OBJCNT(keys); + +/* An array to make it easy to pretty print the debug information + * to the proc fs. + */ +static struct sctp_dbg_objcnt_entry sctp_dbg_objcnt[] = { + SCTP_DBG_OBJCNT_ENTRY(sock), + SCTP_DBG_OBJCNT_ENTRY(ep), + SCTP_DBG_OBJCNT_ENTRY(assoc), + SCTP_DBG_OBJCNT_ENTRY(transport), + SCTP_DBG_OBJCNT_ENTRY(chunk), + SCTP_DBG_OBJCNT_ENTRY(bind_addr), + SCTP_DBG_OBJCNT_ENTRY(bind_bucket), + SCTP_DBG_OBJCNT_ENTRY(addr), + SCTP_DBG_OBJCNT_ENTRY(datamsg), + SCTP_DBG_OBJCNT_ENTRY(keys), +}; + +/* Callback from procfs to read out objcount information. + * Walk through the entries in the sctp_dbg_objcnt array, dumping + * the raw object counts for each monitored type. + */ +static int sctp_objcnt_seq_show(struct seq_file *seq, void *v) +{ + int i; + + i = (int)*(loff_t *)v; + seq_setwidth(seq, 127); + seq_printf(seq, "%s: %d", sctp_dbg_objcnt[i].label, + atomic_read(sctp_dbg_objcnt[i].counter)); + seq_pad(seq, '\n'); + return 0; +} + +static void *sctp_objcnt_seq_start(struct seq_file *seq, loff_t *pos) +{ + return (*pos >= ARRAY_SIZE(sctp_dbg_objcnt)) ? NULL : (void *)pos; +} + +static void sctp_objcnt_seq_stop(struct seq_file *seq, void *v) +{ +} + +static void *sctp_objcnt_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + return (*pos >= ARRAY_SIZE(sctp_dbg_objcnt)) ? NULL : (void *)pos; +} + +static const struct seq_operations sctp_objcnt_seq_ops = { + .start = sctp_objcnt_seq_start, + .next = sctp_objcnt_seq_next, + .stop = sctp_objcnt_seq_stop, + .show = sctp_objcnt_seq_show, +}; + +/* Initialize the objcount in the proc filesystem. */ +void sctp_dbg_objcnt_init(struct net *net) +{ + struct proc_dir_entry *ent; + + ent = proc_create_seq("sctp_dbg_objcnt", 0, + net->sctp.proc_net_sctp, &sctp_objcnt_seq_ops); + if (!ent) + pr_warn("sctp_dbg_objcnt: Unable to create /proc entry.\n"); +} diff --git a/net/sctp/offload.c b/net/sctp/offload.c new file mode 100644 index 000000000..eb874e3c3 --- /dev/null +++ b/net/sctp/offload.c @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * sctp_offload - GRO/GSO Offloading for SCTP + * + * Copyright (C) 2015, Marcelo Ricardo Leitner + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +static __le32 sctp_gso_make_checksum(struct sk_buff *skb) +{ + skb->ip_summed = CHECKSUM_NONE; + skb->csum_not_inet = 0; + /* csum and csum_start in GSO CB may be needed to do the UDP + * checksum when it's a UDP tunneling packet. 
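+ * The checksum returned below is SCTP's CRC32c, computed by
+ * sctp_compute_cksum() over the packet starting at the transport
+ * header.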
+ */ + SKB_GSO_CB(skb)->csum = (__force __wsum)~0; + SKB_GSO_CB(skb)->csum_start = skb_headroom(skb) + skb->len; + return sctp_compute_cksum(skb, skb_transport_offset(skb)); +} + +static struct sk_buff *sctp_gso_segment(struct sk_buff *skb, + netdev_features_t features) +{ + struct sk_buff *segs = ERR_PTR(-EINVAL); + struct sctphdr *sh; + + if (!skb_is_gso_sctp(skb)) + goto out; + + sh = sctp_hdr(skb); + if (!pskb_may_pull(skb, sizeof(*sh))) + goto out; + + __skb_pull(skb, sizeof(*sh)); + + if (skb_gso_ok(skb, features | NETIF_F_GSO_ROBUST)) { + /* Packet is from an untrusted source, reset gso_segs. */ + struct skb_shared_info *pinfo = skb_shinfo(skb); + struct sk_buff *frag_iter; + + pinfo->gso_segs = 0; + if (skb->len != skb->data_len) { + /* Means we have chunks in here too */ + pinfo->gso_segs++; + } + + skb_walk_frags(skb, frag_iter) + pinfo->gso_segs++; + + segs = NULL; + goto out; + } + + segs = skb_segment(skb, (features | NETIF_F_HW_CSUM) & ~NETIF_F_SG); + if (IS_ERR(segs)) + goto out; + + /* All that is left is update SCTP CRC if necessary */ + if (!(features & NETIF_F_SCTP_CRC)) { + for (skb = segs; skb; skb = skb->next) { + if (skb->ip_summed == CHECKSUM_PARTIAL) { + sh = sctp_hdr(skb); + sh->checksum = sctp_gso_make_checksum(skb); + } + } + } + +out: + return segs; +} + +static const struct net_offload sctp_offload = { + .callbacks = { + .gso_segment = sctp_gso_segment, + }, +}; + +static const struct net_offload sctp6_offload = { + .callbacks = { + .gso_segment = sctp_gso_segment, + }, +}; + +int __init sctp_offload_init(void) +{ + int ret; + + ret = inet_add_offload(&sctp_offload, IPPROTO_SCTP); + if (ret) + goto out; + + ret = inet6_add_offload(&sctp6_offload, IPPROTO_SCTP); + if (ret) + goto ipv4; + + crc32c_csum_stub = &sctp_csum_ops; + return ret; + +ipv4: + inet_del_offload(&sctp_offload, IPPROTO_SCTP); +out: + return ret; +} diff --git a/net/sctp/output.c b/net/sctp/output.c new file mode 100644 index 000000000..a63df055a --- /dev/null +++ b/net/sctp/output.c @@ -0,0 +1,865 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * + * This file is part of the SCTP kernel implementation + * + * These functions handle output processing. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Sridhar Samudrala + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include /* for sa_family_t */ +#include + +#include +#include +#include + +/* Forward declarations for private helpers. */ +static enum sctp_xmit __sctp_packet_append_chunk(struct sctp_packet *packet, + struct sctp_chunk *chunk); +static enum sctp_xmit sctp_packet_can_append_data(struct sctp_packet *packet, + struct sctp_chunk *chunk); +static void sctp_packet_append_data(struct sctp_packet *packet, + struct sctp_chunk *chunk); +static enum sctp_xmit sctp_packet_will_fit(struct sctp_packet *packet, + struct sctp_chunk *chunk, + u16 chunk_len); + +static void sctp_packet_reset(struct sctp_packet *packet) +{ + /* sctp_packet_transmit() relies on this to reset size to the + * current overhead after sending packets. 
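+ * An "empty" packet is thus one whose size equals its overhead (the
+ * IP + SCTP common header estimate from sctp_packet_config()), which
+ * is what sctp_packet_empty() checks.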
+ */ + packet->size = packet->overhead; + + packet->has_cookie_echo = 0; + packet->has_sack = 0; + packet->has_data = 0; + packet->has_auth = 0; + packet->ipfragok = 0; + packet->auth = NULL; +} + +/* Config a packet. + * This appears to be a followup set of initializations. + */ +void sctp_packet_config(struct sctp_packet *packet, __u32 vtag, + int ecn_capable) +{ + struct sctp_transport *tp = packet->transport; + struct sctp_association *asoc = tp->asoc; + struct sctp_sock *sp = NULL; + struct sock *sk; + + pr_debug("%s: packet:%p vtag:0x%x\n", __func__, packet, vtag); + packet->vtag = vtag; + + /* do the following jobs only once for a flush schedule */ + if (!sctp_packet_empty(packet)) + return; + + /* set packet max_size with pathmtu, then calculate overhead */ + packet->max_size = tp->pathmtu; + + if (asoc) { + sk = asoc->base.sk; + sp = sctp_sk(sk); + } + packet->overhead = sctp_mtu_payload(sp, 0, 0); + packet->size = packet->overhead; + + if (!asoc) + return; + + /* update dst or transport pathmtu if in need */ + if (!sctp_transport_dst_check(tp)) { + sctp_transport_route(tp, NULL, sp); + if (asoc->param_flags & SPP_PMTUD_ENABLE) + sctp_assoc_sync_pmtu(asoc); + } else if (!sctp_transport_pl_enabled(tp) && + asoc->param_flags & SPP_PMTUD_ENABLE) { + if (!sctp_transport_pmtu_check(tp)) + sctp_assoc_sync_pmtu(asoc); + } + + if (asoc->pmtu_pending) { + if (asoc->param_flags & SPP_PMTUD_ENABLE) + sctp_assoc_sync_pmtu(asoc); + asoc->pmtu_pending = 0; + } + + /* If there a is a prepend chunk stick it on the list before + * any other chunks get appended. + */ + if (ecn_capable) { + struct sctp_chunk *chunk = sctp_get_ecne_prepend(asoc); + + if (chunk) + sctp_packet_append_chunk(packet, chunk); + } + + if (!tp->dst) + return; + + /* set packet max_size with gso_max_size if gso is enabled*/ + rcu_read_lock(); + if (__sk_dst_get(sk) != tp->dst) { + dst_hold(tp->dst); + sk_setup_caps(sk, tp->dst); + } + packet->max_size = sk_can_gso(sk) ? min(READ_ONCE(tp->dst->dev->gso_max_size), + GSO_LEGACY_MAX_SIZE) + : asoc->pathmtu; + rcu_read_unlock(); +} + +/* Initialize the packet structure. */ +void sctp_packet_init(struct sctp_packet *packet, + struct sctp_transport *transport, + __u16 sport, __u16 dport) +{ + pr_debug("%s: packet:%p transport:%p\n", __func__, packet, transport); + + packet->transport = transport; + packet->source_port = sport; + packet->destination_port = dport; + INIT_LIST_HEAD(&packet->chunk_list); + /* The overhead will be calculated by sctp_packet_config() */ + packet->overhead = 0; + sctp_packet_reset(packet); + packet->vtag = 0; +} + +/* Free a packet. */ +void sctp_packet_free(struct sctp_packet *packet) +{ + struct sctp_chunk *chunk, *tmp; + + pr_debug("%s: packet:%p\n", __func__, packet); + + list_for_each_entry_safe(chunk, tmp, &packet->chunk_list, list) { + list_del_init(&chunk->list); + sctp_chunk_free(chunk); + } +} + +/* This routine tries to append the chunk to the offered packet. If adding + * the chunk causes the packet to exceed the path MTU and COOKIE_ECHO chunk + * is not present in the packet, it transmits the input packet. + * Data can be bundled with a packet containing a COOKIE_ECHO chunk as long + * as it can fit in the packet, but any more data that does not fit in this + * packet can be sent only after receiving the COOKIE_ACK. 
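+ *
+ * For example: if appending returns SCTP_XMIT_PMTU_FULL and the packet
+ * holds no COOKIE_ECHO, the packet is flushed via sctp_packet_transmit()
+ * and, unless the caller asked for a single packet, the append is
+ * retried on the now-flushed packet.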
+ */ +enum sctp_xmit sctp_packet_transmit_chunk(struct sctp_packet *packet, + struct sctp_chunk *chunk, + int one_packet, gfp_t gfp) +{ + enum sctp_xmit retval; + + pr_debug("%s: packet:%p size:%zu chunk:%p size:%d\n", __func__, + packet, packet->size, chunk, chunk->skb ? chunk->skb->len : -1); + + switch ((retval = (sctp_packet_append_chunk(packet, chunk)))) { + case SCTP_XMIT_PMTU_FULL: + if (!packet->has_cookie_echo) { + int error = 0; + + error = sctp_packet_transmit(packet, gfp); + if (error < 0) + chunk->skb->sk->sk_err = -error; + + /* If we have an empty packet, then we can NOT ever + * return PMTU_FULL. + */ + if (!one_packet) + retval = sctp_packet_append_chunk(packet, + chunk); + } + break; + + case SCTP_XMIT_RWND_FULL: + case SCTP_XMIT_OK: + case SCTP_XMIT_DELAY: + break; + } + + return retval; +} + +/* Try to bundle a pad chunk into a packet with a heartbeat chunk for PLPMTUTD probe */ +static enum sctp_xmit sctp_packet_bundle_pad(struct sctp_packet *pkt, struct sctp_chunk *chunk) +{ + struct sctp_transport *t = pkt->transport; + struct sctp_chunk *pad; + int overhead = 0; + + if (!chunk->pmtu_probe) + return SCTP_XMIT_OK; + + /* calculate the Padding Data size for the pad chunk */ + overhead += sizeof(struct sctphdr) + sizeof(struct sctp_chunkhdr); + overhead += sizeof(struct sctp_sender_hb_info) + sizeof(struct sctp_pad_chunk); + pad = sctp_make_pad(t->asoc, t->pl.probe_size - overhead); + if (!pad) + return SCTP_XMIT_DELAY; + + list_add_tail(&pad->list, &pkt->chunk_list); + pkt->size += SCTP_PAD4(ntohs(pad->chunk_hdr->length)); + chunk->transport = t; + + return SCTP_XMIT_OK; +} + +/* Try to bundle an auth chunk into the packet. */ +static enum sctp_xmit sctp_packet_bundle_auth(struct sctp_packet *pkt, + struct sctp_chunk *chunk) +{ + struct sctp_association *asoc = pkt->transport->asoc; + enum sctp_xmit retval = SCTP_XMIT_OK; + struct sctp_chunk *auth; + + /* if we don't have an association, we can't do authentication */ + if (!asoc) + return retval; + + /* See if this is an auth chunk we are bundling or if + * auth is already bundled. + */ + if (chunk->chunk_hdr->type == SCTP_CID_AUTH || pkt->has_auth) + return retval; + + /* if the peer did not request this chunk to be authenticated, + * don't do it + */ + if (!chunk->auth) + return retval; + + auth = sctp_make_auth(asoc, chunk->shkey->key_id); + if (!auth) + return retval; + + auth->shkey = chunk->shkey; + sctp_auth_shkey_hold(auth->shkey); + + retval = __sctp_packet_append_chunk(pkt, auth); + + if (retval != SCTP_XMIT_OK) + sctp_chunk_free(auth); + + return retval; +} + +/* Try to bundle a SACK with the packet. */ +static enum sctp_xmit sctp_packet_bundle_sack(struct sctp_packet *pkt, + struct sctp_chunk *chunk) +{ + enum sctp_xmit retval = SCTP_XMIT_OK; + + /* If sending DATA and haven't aleady bundled a SACK, try to + * bundle one in to the packet. 
+ */ + if (sctp_chunk_is_data(chunk) && !pkt->has_sack && + !pkt->has_cookie_echo) { + struct sctp_association *asoc; + struct timer_list *timer; + asoc = pkt->transport->asoc; + timer = &asoc->timers[SCTP_EVENT_TIMEOUT_SACK]; + + /* If the SACK timer is running, we have a pending SACK */ + if (timer_pending(timer)) { + struct sctp_chunk *sack; + + if (pkt->transport->sack_generation != + pkt->transport->asoc->peer.sack_generation) + return retval; + + asoc->a_rwnd = asoc->rwnd; + sack = sctp_make_sack(asoc); + if (sack) { + retval = __sctp_packet_append_chunk(pkt, sack); + if (retval != SCTP_XMIT_OK) { + sctp_chunk_free(sack); + goto out; + } + SCTP_INC_STATS(asoc->base.net, + SCTP_MIB_OUTCTRLCHUNKS); + asoc->stats.octrlchunks++; + asoc->peer.sack_needed = 0; + if (del_timer(timer)) + sctp_association_put(asoc); + } + } + } +out: + return retval; +} + + +/* Append a chunk to the offered packet reporting back any inability to do + * so. + */ +static enum sctp_xmit __sctp_packet_append_chunk(struct sctp_packet *packet, + struct sctp_chunk *chunk) +{ + __u16 chunk_len = SCTP_PAD4(ntohs(chunk->chunk_hdr->length)); + enum sctp_xmit retval = SCTP_XMIT_OK; + + /* Check to see if this chunk will fit into the packet */ + retval = sctp_packet_will_fit(packet, chunk, chunk_len); + if (retval != SCTP_XMIT_OK) + goto finish; + + /* We believe that this chunk is OK to add to the packet */ + switch (chunk->chunk_hdr->type) { + case SCTP_CID_DATA: + case SCTP_CID_I_DATA: + /* Account for the data being in the packet */ + sctp_packet_append_data(packet, chunk); + /* Disallow SACK bundling after DATA. */ + packet->has_sack = 1; + /* Disallow AUTH bundling after DATA */ + packet->has_auth = 1; + /* Let it be knows that packet has DATA in it */ + packet->has_data = 1; + /* timestamp the chunk for rtx purposes */ + chunk->sent_at = jiffies; + /* Mainly used for prsctp RTX policy */ + chunk->sent_count++; + break; + case SCTP_CID_COOKIE_ECHO: + packet->has_cookie_echo = 1; + break; + + case SCTP_CID_SACK: + packet->has_sack = 1; + if (chunk->asoc) + chunk->asoc->stats.osacks++; + break; + + case SCTP_CID_AUTH: + packet->has_auth = 1; + packet->auth = chunk; + break; + } + + /* It is OK to send this chunk. */ + list_add_tail(&chunk->list, &packet->chunk_list); + packet->size += chunk_len; + chunk->transport = packet->transport; +finish: + return retval; +} + +/* Append a chunk to the offered packet reporting back any inability to do + * so. + */ +enum sctp_xmit sctp_packet_append_chunk(struct sctp_packet *packet, + struct sctp_chunk *chunk) +{ + enum sctp_xmit retval = SCTP_XMIT_OK; + + pr_debug("%s: packet:%p chunk:%p\n", __func__, packet, chunk); + + /* Data chunks are special. Before seeing what else we can + * bundle into this packet, check to see if we are allowed to + * send this DATA. 
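+ *
+ * The bundling order attempted below is: an AUTH chunk (if this chunk
+ * must be authenticated), then a pending SACK, then the chunk itself,
+ * and finally a PAD chunk when this is a PLPMTUD probe.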
+ */ + if (sctp_chunk_is_data(chunk)) { + retval = sctp_packet_can_append_data(packet, chunk); + if (retval != SCTP_XMIT_OK) + goto finish; + } + + /* Try to bundle AUTH chunk */ + retval = sctp_packet_bundle_auth(packet, chunk); + if (retval != SCTP_XMIT_OK) + goto finish; + + /* Try to bundle SACK chunk */ + retval = sctp_packet_bundle_sack(packet, chunk); + if (retval != SCTP_XMIT_OK) + goto finish; + + retval = __sctp_packet_append_chunk(packet, chunk); + if (retval != SCTP_XMIT_OK) + goto finish; + + retval = sctp_packet_bundle_pad(packet, chunk); + +finish: + return retval; +} + +static void sctp_packet_gso_append(struct sk_buff *head, struct sk_buff *skb) +{ + if (SCTP_OUTPUT_CB(head)->last == head) + skb_shinfo(head)->frag_list = skb; + else + SCTP_OUTPUT_CB(head)->last->next = skb; + SCTP_OUTPUT_CB(head)->last = skb; + + head->truesize += skb->truesize; + head->data_len += skb->len; + head->len += skb->len; + refcount_add(skb->truesize, &head->sk->sk_wmem_alloc); + + __skb_header_release(skb); +} + +static int sctp_packet_pack(struct sctp_packet *packet, + struct sk_buff *head, int gso, gfp_t gfp) +{ + struct sctp_transport *tp = packet->transport; + struct sctp_auth_chunk *auth = NULL; + struct sctp_chunk *chunk, *tmp; + int pkt_count = 0, pkt_size; + struct sock *sk = head->sk; + struct sk_buff *nskb; + int auth_len = 0; + + if (gso) { + skb_shinfo(head)->gso_type = sk->sk_gso_type; + SCTP_OUTPUT_CB(head)->last = head; + } else { + nskb = head; + pkt_size = packet->size; + goto merge; + } + + do { + /* calculate the pkt_size and alloc nskb */ + pkt_size = packet->overhead; + list_for_each_entry_safe(chunk, tmp, &packet->chunk_list, + list) { + int padded = SCTP_PAD4(chunk->skb->len); + + if (chunk == packet->auth) + auth_len = padded; + else if (auth_len + padded + packet->overhead > + tp->pathmtu) + return 0; + else if (pkt_size + padded > tp->pathmtu) + break; + pkt_size += padded; + } + nskb = alloc_skb(pkt_size + MAX_HEADER, gfp); + if (!nskb) + return 0; + skb_reserve(nskb, packet->overhead + MAX_HEADER); + +merge: + /* merge chunks into nskb and append nskb into head list */ + pkt_size -= packet->overhead; + list_for_each_entry_safe(chunk, tmp, &packet->chunk_list, list) { + int padding; + + list_del_init(&chunk->list); + if (sctp_chunk_is_data(chunk)) { + if (!sctp_chunk_retransmitted(chunk) && + !tp->rto_pending) { + chunk->rtt_in_progress = 1; + tp->rto_pending = 1; + } + } + + padding = SCTP_PAD4(chunk->skb->len) - chunk->skb->len; + if (padding) + skb_put_zero(chunk->skb, padding); + + if (chunk == packet->auth) + auth = (struct sctp_auth_chunk *) + skb_tail_pointer(nskb); + + skb_put_data(nskb, chunk->skb->data, chunk->skb->len); + + pr_debug("*** Chunk:%p[%s] %s 0x%x, length:%d, chunk->skb->len:%d, rtt_in_progress:%d\n", + chunk, + sctp_cname(SCTP_ST_CHUNK(chunk->chunk_hdr->type)), + chunk->has_tsn ? "TSN" : "No TSN", + chunk->has_tsn ? 
ntohl(chunk->subh.data_hdr->tsn) : 0, + ntohs(chunk->chunk_hdr->length), chunk->skb->len, + chunk->rtt_in_progress); + + pkt_size -= SCTP_PAD4(chunk->skb->len); + + if (!sctp_chunk_is_data(chunk) && chunk != packet->auth) + sctp_chunk_free(chunk); + + if (!pkt_size) + break; + } + + if (auth) { + sctp_auth_calculate_hmac(tp->asoc, nskb, auth, + packet->auth->shkey, gfp); + /* free auth if no more chunks, or add it back */ + if (list_empty(&packet->chunk_list)) + sctp_chunk_free(packet->auth); + else + list_add(&packet->auth->list, + &packet->chunk_list); + } + + if (gso) + sctp_packet_gso_append(head, nskb); + + pkt_count++; + } while (!list_empty(&packet->chunk_list)); + + if (gso) { + memset(head->cb, 0, max(sizeof(struct inet_skb_parm), + sizeof(struct inet6_skb_parm))); + skb_shinfo(head)->gso_segs = pkt_count; + skb_shinfo(head)->gso_size = GSO_BY_FRAGS; + goto chksum; + } + + if (sctp_checksum_disable) + return 1; + + if (!(tp->dst->dev->features & NETIF_F_SCTP_CRC) || + dst_xfrm(tp->dst) || packet->ipfragok || tp->encap_port) { + struct sctphdr *sh = + (struct sctphdr *)skb_transport_header(head); + + sh->checksum = sctp_compute_cksum(head, 0); + } else { +chksum: + head->ip_summed = CHECKSUM_PARTIAL; + head->csum_not_inet = 1; + head->csum_start = skb_transport_header(head) - head->head; + head->csum_offset = offsetof(struct sctphdr, checksum); + } + + return pkt_count; +} + +/* All packets are sent to the network through this function from + * sctp_outq_tail(). + * + * The return value is always 0 for now. + */ +int sctp_packet_transmit(struct sctp_packet *packet, gfp_t gfp) +{ + struct sctp_transport *tp = packet->transport; + struct sctp_association *asoc = tp->asoc; + struct sctp_chunk *chunk, *tmp; + int pkt_count, gso = 0; + struct sk_buff *head; + struct sctphdr *sh; + struct sock *sk; + + pr_debug("%s: packet:%p\n", __func__, packet); + if (list_empty(&packet->chunk_list)) + return 0; + chunk = list_entry(packet->chunk_list.next, struct sctp_chunk, list); + sk = chunk->skb->sk; + + if (packet->size > tp->pathmtu && !packet->ipfragok && !chunk->pmtu_probe) { + if (tp->pl.state == SCTP_PL_ERROR) { /* do IP fragmentation if in Error state */ + packet->ipfragok = 1; + } else { + if (!sk_can_gso(sk)) { /* check gso */ + pr_err_once("Trying to GSO but underlying device doesn't support it."); + goto out; + } + gso = 1; + } + } + + /* alloc head skb */ + head = alloc_skb((gso ? 
packet->overhead : packet->size) + + MAX_HEADER, gfp); + if (!head) + goto out; + skb_reserve(head, packet->overhead + MAX_HEADER); + skb_set_owner_w(head, sk); + + /* set sctp header */ + sh = skb_push(head, sizeof(struct sctphdr)); + skb_reset_transport_header(head); + sh->source = htons(packet->source_port); + sh->dest = htons(packet->destination_port); + sh->vtag = htonl(packet->vtag); + sh->checksum = 0; + + /* drop packet if no dst */ + if (!tp->dst) { + IP_INC_STATS(sock_net(sk), IPSTATS_MIB_OUTNOROUTES); + kfree_skb(head); + goto out; + } + + /* pack up chunks */ + pkt_count = sctp_packet_pack(packet, head, gso, gfp); + if (!pkt_count) { + kfree_skb(head); + goto out; + } + pr_debug("***sctp_transmit_packet*** skb->len:%d\n", head->len); + + /* start autoclose timer */ + if (packet->has_data && sctp_state(asoc, ESTABLISHED) && + asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) { + struct timer_list *timer = + &asoc->timers[SCTP_EVENT_TIMEOUT_AUTOCLOSE]; + unsigned long timeout = + asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]; + + if (!mod_timer(timer, jiffies + timeout)) + sctp_association_hold(asoc); + } + + /* sctp xmit */ + tp->af_specific->ecn_capable(sk); + if (asoc) { + asoc->stats.opackets += pkt_count; + if (asoc->peer.last_sent_to != tp) + asoc->peer.last_sent_to = tp; + } + head->ignore_df = packet->ipfragok; + if (tp->dst_pending_confirm) + skb_set_dst_pending_confirm(head, 1); + /* neighbour should be confirmed on successful transmission or + * positive error + */ + if (tp->af_specific->sctp_xmit(head, tp) >= 0 && + tp->dst_pending_confirm) + tp->dst_pending_confirm = 0; + +out: + list_for_each_entry_safe(chunk, tmp, &packet->chunk_list, list) { + list_del_init(&chunk->list); + if (!sctp_chunk_is_data(chunk)) + sctp_chunk_free(chunk); + } + sctp_packet_reset(packet); + return 0; +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* This private function check to see if a chunk can be added */ +static enum sctp_xmit sctp_packet_can_append_data(struct sctp_packet *packet, + struct sctp_chunk *chunk) +{ + size_t datasize, rwnd, inflight, flight_size; + struct sctp_transport *transport = packet->transport; + struct sctp_association *asoc = transport->asoc; + struct sctp_outq *q = &asoc->outqueue; + + /* RFC 2960 6.1 Transmission of DATA Chunks + * + * A) At any given time, the data sender MUST NOT transmit new data to + * any destination transport address if its peer's rwnd indicates + * that the peer has no buffer space (i.e. rwnd is 0, see Section + * 6.2.1). However, regardless of the value of rwnd (including if it + * is 0), the data sender can always have one DATA chunk in flight to + * the receiver if allowed by cwnd (see rule B below). This rule + * allows the sender to probe for a change in rwnd that the sender + * missed due to the SACK having been lost in transit from the data + * receiver to the data sender. + */ + + rwnd = asoc->peer.rwnd; + inflight = q->outstanding_bytes; + flight_size = transport->flight_size; + + datasize = sctp_data_size(chunk); + + if (datasize > rwnd && inflight > 0) + /* We have (at least) one data chunk in flight, + * so we can't fall back to rule 6.1 B). + */ + return SCTP_XMIT_RWND_FULL; + + /* RFC 2960 6.1 Transmission of DATA Chunks + * + * B) At any given time, the sender MUST NOT transmit new data + * to a given transport address if it has cwnd or more bytes + * of data outstanding to that transport address. 
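+ *
+ * Illustrative example (hypothetical numbers): with cwnd = 4380 bytes
+ * and flight_size already at 4380, a new DATA chunk is refused here
+ * unless it is marked for fast retransmit, for which the check below
+ * deliberately ignores cwnd.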
+ */ + /* RFC 7.2.4 & the Implementers Guide 2.8. + * + * 3) ... + * When a Fast Retransmit is being performed the sender SHOULD + * ignore the value of cwnd and SHOULD NOT delay retransmission. + */ + if (chunk->fast_retransmit != SCTP_NEED_FRTX && + flight_size >= transport->cwnd) + return SCTP_XMIT_RWND_FULL; + + /* Nagle's algorithm to solve small-packet problem: + * Inhibit the sending of new chunks when new outgoing data arrives + * if any previously transmitted data on the connection remains + * unacknowledged. + */ + + if ((sctp_sk(asoc->base.sk)->nodelay || inflight == 0) && + !asoc->force_delay) + /* Nothing unacked */ + return SCTP_XMIT_OK; + + if (!sctp_packet_empty(packet)) + /* Append to packet */ + return SCTP_XMIT_OK; + + if (!sctp_state(asoc, ESTABLISHED)) + return SCTP_XMIT_OK; + + /* Check whether this chunk and all the rest of pending data will fit + * or delay in hopes of bundling a full sized packet. + */ + if (chunk->skb->len + q->out_qlen > transport->pathmtu - + packet->overhead - sctp_datachk_len(&chunk->asoc->stream) - 4) + /* Enough data queued to fill a packet */ + return SCTP_XMIT_OK; + + /* Don't delay large message writes that may have been fragmented */ + if (!chunk->msg->can_delay) + return SCTP_XMIT_OK; + + /* Defer until all data acked or packet full */ + return SCTP_XMIT_DELAY; +} + +/* This private function does management things when adding DATA chunk */ +static void sctp_packet_append_data(struct sctp_packet *packet, + struct sctp_chunk *chunk) +{ + struct sctp_transport *transport = packet->transport; + size_t datasize = sctp_data_size(chunk); + struct sctp_association *asoc = transport->asoc; + u32 rwnd = asoc->peer.rwnd; + + /* Keep track of how many bytes are in flight over this transport. */ + transport->flight_size += datasize; + + /* Keep track of how many bytes are in flight to the receiver. */ + asoc->outqueue.outstanding_bytes += datasize; + + /* Update our view of the receiver's rwnd. */ + if (datasize < rwnd) + rwnd -= datasize; + else + rwnd = 0; + + asoc->peer.rwnd = rwnd; + sctp_chunk_assign_tsn(chunk); + asoc->stream.si->assign_number(chunk); +} + +static enum sctp_xmit sctp_packet_will_fit(struct sctp_packet *packet, + struct sctp_chunk *chunk, + u16 chunk_len) +{ + enum sctp_xmit retval = SCTP_XMIT_OK; + size_t psize, pmtu, maxsize; + + /* Don't bundle in this packet if this chunk's auth key doesn't + * match other chunks already enqueued on this packet. Also, + * don't bundle the chunk with auth key if other chunks in this + * packet don't have auth key. + */ + if ((packet->auth && chunk->shkey != packet->auth->shkey) || + (!packet->auth && chunk->shkey && + chunk->chunk_hdr->type != SCTP_CID_AUTH)) + return SCTP_XMIT_PMTU_FULL; + + psize = packet->size; + if (packet->transport->asoc) + pmtu = packet->transport->asoc->pathmtu; + else + pmtu = packet->transport->pathmtu; + + /* Decide if we need to fragment or resubmit later. */ + if (psize + chunk_len > pmtu) { + /* It's OK to fragment at IP level if any one of the following + * is true: + * 1. The packet is empty (meaning this chunk is greater + * the MTU) + * 2. The packet doesn't have any data in it yet and data + * requires authentication. + */ + if (sctp_packet_empty(packet) || + (!packet->has_data && chunk->auth)) { + /* We no longer do re-fragmentation. 
+ * Just fragment at the IP layer, if we + * actually hit this condition + */ + packet->ipfragok = 1; + goto out; + } + + /* Similarly, if this chunk was built before a PMTU + * reduction, we have to fragment it at IP level now. So + * if the packet already contains something, we need to + * flush. + */ + maxsize = pmtu - packet->overhead; + if (packet->auth) + maxsize -= SCTP_PAD4(packet->auth->skb->len); + if (chunk_len > maxsize) + retval = SCTP_XMIT_PMTU_FULL; + + /* It is also okay to fragment if the chunk we are + * adding is a control chunk, but only if current packet + * is not a GSO one otherwise it causes fragmentation of + * a large frame. So in this case we allow the + * fragmentation by forcing it to be in a new packet. + */ + if (!sctp_chunk_is_data(chunk) && packet->has_data) + retval = SCTP_XMIT_PMTU_FULL; + + if (psize + chunk_len > packet->max_size) + /* Hit GSO/PMTU limit, gotta flush */ + retval = SCTP_XMIT_PMTU_FULL; + + if (!packet->transport->burst_limited && + psize + chunk_len > (packet->transport->cwnd >> 1)) + /* Do not allow a single GSO packet to use more + * than half of cwnd. + */ + retval = SCTP_XMIT_PMTU_FULL; + + if (packet->transport->burst_limited && + psize + chunk_len > (packet->transport->burst_limited >> 1)) + /* Do not allow a single GSO packet to use more + * than half of original cwnd. + */ + retval = SCTP_XMIT_PMTU_FULL; + /* Otherwise it will fit in the GSO packet */ + } + +out: + return retval; +} diff --git a/net/sctp/outqueue.c b/net/sctp/outqueue.c new file mode 100644 index 000000000..20831079f --- /dev/null +++ b/net/sctp/outqueue.c @@ -0,0 +1,1922 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2003 Intel Corp. + * + * This file is part of the SCTP kernel implementation + * + * These functions implement the sctp_outq class. The outqueue handles + * bundling and queueing of outgoing SCTP chunks. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Perry Melange + * Xingang Guo + * Hui Huang + * Sridhar Samudrala + * Jon Grimm + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include /* For struct list_head */ +#include +#include +#include +#include /* For skb_set_owner_w */ + +#include +#include +#include +#include + +/* Declare internal functions here. */ +static int sctp_acked(struct sctp_sackhdr *sack, __u32 tsn); +static void sctp_check_transmitted(struct sctp_outq *q, + struct list_head *transmitted_queue, + struct sctp_transport *transport, + union sctp_addr *saddr, + struct sctp_sackhdr *sack, + __u32 *highest_new_tsn); + +static void sctp_mark_missing(struct sctp_outq *q, + struct list_head *transmitted_queue, + struct sctp_transport *transport, + __u32 highest_new_tsn, + int count_of_newacks); + +static void sctp_outq_flush(struct sctp_outq *q, int rtx_timeout, gfp_t gfp); + +/* Add data to the front of the queue. */ +static inline void sctp_outq_head_data(struct sctp_outq *q, + struct sctp_chunk *ch) +{ + struct sctp_stream_out_ext *oute; + __u16 stream; + + list_add(&ch->list, &q->out_chunk_list); + q->out_qlen += ch->skb->len; + + stream = sctp_chunk_stream_no(ch); + oute = SCTP_SO(&q->asoc->stream, stream)->ext; + list_add(&ch->stream_list, &oute->outq); +} + +/* Take data from the front of the queue. 
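+ * "Front" is decided by the association's stream scheduler
+ * (q->sched->dequeue()); by default this is first-come-first-served,
+ * but a different scheduler may have been selected on the socket.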
*/ +static inline struct sctp_chunk *sctp_outq_dequeue_data(struct sctp_outq *q) +{ + return q->sched->dequeue(q); +} + +/* Add data chunk to the end of the queue. */ +static inline void sctp_outq_tail_data(struct sctp_outq *q, + struct sctp_chunk *ch) +{ + struct sctp_stream_out_ext *oute; + __u16 stream; + + list_add_tail(&ch->list, &q->out_chunk_list); + q->out_qlen += ch->skb->len; + + stream = sctp_chunk_stream_no(ch); + oute = SCTP_SO(&q->asoc->stream, stream)->ext; + list_add_tail(&ch->stream_list, &oute->outq); +} + +/* + * SFR-CACC algorithm: + * D) If count_of_newacks is greater than or equal to 2 + * and t was not sent to the current primary then the + * sender MUST NOT increment missing report count for t. + */ +static inline int sctp_cacc_skip_3_1_d(struct sctp_transport *primary, + struct sctp_transport *transport, + int count_of_newacks) +{ + if (count_of_newacks >= 2 && transport != primary) + return 1; + return 0; +} + +/* + * SFR-CACC algorithm: + * F) If count_of_newacks is less than 2, let d be the + * destination to which t was sent. If cacc_saw_newack + * is 0 for destination d, then the sender MUST NOT + * increment missing report count for t. + */ +static inline int sctp_cacc_skip_3_1_f(struct sctp_transport *transport, + int count_of_newacks) +{ + if (count_of_newacks < 2 && + (transport && !transport->cacc.cacc_saw_newack)) + return 1; + return 0; +} + +/* + * SFR-CACC algorithm: + * 3.1) If CYCLING_CHANGEOVER is 0, the sender SHOULD + * execute steps C, D, F. + * + * C has been implemented in sctp_outq_sack + */ +static inline int sctp_cacc_skip_3_1(struct sctp_transport *primary, + struct sctp_transport *transport, + int count_of_newacks) +{ + if (!primary->cacc.cycling_changeover) { + if (sctp_cacc_skip_3_1_d(primary, transport, count_of_newacks)) + return 1; + if (sctp_cacc_skip_3_1_f(transport, count_of_newacks)) + return 1; + return 0; + } + return 0; +} + +/* + * SFR-CACC algorithm: + * 3.2) Else if CYCLING_CHANGEOVER is 1, and t is less + * than next_tsn_at_change of the current primary, then + * the sender MUST NOT increment missing report count + * for t. + */ +static inline int sctp_cacc_skip_3_2(struct sctp_transport *primary, __u32 tsn) +{ + if (primary->cacc.cycling_changeover && + TSN_lt(tsn, primary->cacc.next_tsn_at_change)) + return 1; + return 0; +} + +/* + * SFR-CACC algorithm: + * 3) If the missing report count for TSN t is to be + * incremented according to [RFC2960] and + * [SCTP_STEWART-2002], and CHANGEOVER_ACTIVE is set, + * then the sender MUST further execute steps 3.1 and + * 3.2 to determine if the missing report count for + * TSN t SHOULD NOT be incremented. + * + * 3.3) If 3.1 and 3.2 do not dictate that the missing + * report count for t should not be incremented, then + * the sender SHOULD increment missing report count for + * t (according to [RFC2960] and [SCTP_STEWART_2002]). + */ +static inline int sctp_cacc_skip(struct sctp_transport *primary, + struct sctp_transport *transport, + int count_of_newacks, + __u32 tsn) +{ + if (primary->cacc.changeover_active && + (sctp_cacc_skip_3_1(primary, transport, count_of_newacks) || + sctp_cacc_skip_3_2(primary, tsn))) + return 1; + return 0; +} + +/* Initialize an existing sctp_outq. This does the boring stuff. + * You still need to define handlers if you really want to DO + * something with this structure... 
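A compact userspace sketch of the SFR-CACC "skip missing report" rules quoted above (steps D, F, 3.1 and 3.2), with the per-path state flattened into an invented struct cacc_path; only the decision logic is shown, not the kernel data structures.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct cacc_path {
	bool changeover_active;
	bool cycling_changeover;
	bool saw_newack;
	uint32_t next_tsn_at_change;
};

/* Serial-number comparison, same idea as the kernel's TSN_lt(). */
static bool tsn_lt(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;
}

/* Should the missing-report counter for TSN t NOT be incremented? */
static bool cacc_skip(const struct cacc_path *primary,
		      const struct cacc_path *sent_to, bool is_primary,
		      int count_of_newacks, uint32_t tsn)
{
	if (!primary->changeover_active)
		return false;
	if (!primary->cycling_changeover) {		/* step 3.1 */
		if (count_of_newacks >= 2 && !is_primary)	  /* D */
			return true;
		if (count_of_newacks < 2 && !sent_to->saw_newack) /* F */
			return true;
		return false;
	}
	/* step 3.2: CYCLING_CHANGEOVER is set */
	return tsn_lt(tsn, primary->next_tsn_at_change);
}

int main(void)
{
	struct cacc_path primary = { .changeover_active = true };
	struct cacc_path alt = { 0 };

	printf("skip? %d\n", cacc_skip(&primary, &alt, false, 2, 100));
	return 0;
}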
+ */ +void sctp_outq_init(struct sctp_association *asoc, struct sctp_outq *q) +{ + memset(q, 0, sizeof(struct sctp_outq)); + + q->asoc = asoc; + INIT_LIST_HEAD(&q->out_chunk_list); + INIT_LIST_HEAD(&q->control_chunk_list); + INIT_LIST_HEAD(&q->retransmit); + INIT_LIST_HEAD(&q->sacked); + INIT_LIST_HEAD(&q->abandoned); + sctp_sched_set_sched(asoc, sctp_sk(asoc->base.sk)->default_ss); +} + +/* Free the outqueue structure and any related pending chunks. + */ +static void __sctp_outq_teardown(struct sctp_outq *q) +{ + struct sctp_transport *transport; + struct list_head *lchunk, *temp; + struct sctp_chunk *chunk, *tmp; + + /* Throw away unacknowledged chunks. */ + list_for_each_entry(transport, &q->asoc->peer.transport_addr_list, + transports) { + while ((lchunk = sctp_list_dequeue(&transport->transmitted)) != NULL) { + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + /* Mark as part of a failed message. */ + sctp_chunk_fail(chunk, q->error); + sctp_chunk_free(chunk); + } + } + + /* Throw away chunks that have been gap ACKed. */ + list_for_each_safe(lchunk, temp, &q->sacked) { + list_del_init(lchunk); + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + sctp_chunk_fail(chunk, q->error); + sctp_chunk_free(chunk); + } + + /* Throw away any chunks in the retransmit queue. */ + list_for_each_safe(lchunk, temp, &q->retransmit) { + list_del_init(lchunk); + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + sctp_chunk_fail(chunk, q->error); + sctp_chunk_free(chunk); + } + + /* Throw away any chunks that are in the abandoned queue. */ + list_for_each_safe(lchunk, temp, &q->abandoned) { + list_del_init(lchunk); + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + sctp_chunk_fail(chunk, q->error); + sctp_chunk_free(chunk); + } + + /* Throw away any leftover data chunks. */ + while ((chunk = sctp_outq_dequeue_data(q)) != NULL) { + sctp_sched_dequeue_done(q, chunk); + + /* Mark as send failure. */ + sctp_chunk_fail(chunk, q->error); + sctp_chunk_free(chunk); + } + + /* Throw away any leftover control chunks. */ + list_for_each_entry_safe(chunk, tmp, &q->control_chunk_list, list) { + list_del_init(&chunk->list); + sctp_chunk_free(chunk); + } +} + +void sctp_outq_teardown(struct sctp_outq *q) +{ + __sctp_outq_teardown(q); + sctp_outq_init(q->asoc, q); +} + +/* Free the outqueue structure and any related pending chunks. */ +void sctp_outq_free(struct sctp_outq *q) +{ + /* Throw away leftover chunks. */ + __sctp_outq_teardown(q); +} + +/* Put a new chunk in an sctp_outq. */ +void sctp_outq_tail(struct sctp_outq *q, struct sctp_chunk *chunk, gfp_t gfp) +{ + struct net *net = q->asoc->base.net; + + pr_debug("%s: outq:%p, chunk:%p[%s]\n", __func__, q, chunk, + chunk && chunk->chunk_hdr ? + sctp_cname(SCTP_ST_CHUNK(chunk->chunk_hdr->type)) : + "illegal chunk"); + + /* If it is data, queue it up, otherwise, send it + * immediately. + */ + if (sctp_chunk_is_data(chunk)) { + pr_debug("%s: outqueueing: outq:%p, chunk:%p[%s])\n", + __func__, q, chunk, chunk && chunk->chunk_hdr ? 
+ sctp_cname(SCTP_ST_CHUNK(chunk->chunk_hdr->type)) : + "illegal chunk"); + + sctp_outq_tail_data(q, chunk); + if (chunk->asoc->peer.prsctp_capable && + SCTP_PR_PRIO_ENABLED(chunk->sinfo.sinfo_flags)) + chunk->asoc->sent_cnt_removable++; + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) + SCTP_INC_STATS(net, SCTP_MIB_OUTUNORDERCHUNKS); + else + SCTP_INC_STATS(net, SCTP_MIB_OUTORDERCHUNKS); + } else { + list_add_tail(&chunk->list, &q->control_chunk_list); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + } + + if (!q->cork) + sctp_outq_flush(q, 0, gfp); +} + +/* Insert a chunk into the sorted list based on the TSNs. The retransmit list + * and the abandoned list are in ascending order. + */ +static void sctp_insert_list(struct list_head *head, struct list_head *new) +{ + struct list_head *pos; + struct sctp_chunk *nchunk, *lchunk; + __u32 ntsn, ltsn; + int done = 0; + + nchunk = list_entry(new, struct sctp_chunk, transmitted_list); + ntsn = ntohl(nchunk->subh.data_hdr->tsn); + + list_for_each(pos, head) { + lchunk = list_entry(pos, struct sctp_chunk, transmitted_list); + ltsn = ntohl(lchunk->subh.data_hdr->tsn); + if (TSN_lt(ntsn, ltsn)) { + list_add(new, pos->prev); + done = 1; + break; + } + } + if (!done) + list_add_tail(new, head); +} + +static int sctp_prsctp_prune_sent(struct sctp_association *asoc, + struct sctp_sndrcvinfo *sinfo, + struct list_head *queue, int msg_len) +{ + struct sctp_chunk *chk, *temp; + + list_for_each_entry_safe(chk, temp, queue, transmitted_list) { + struct sctp_stream_out *streamout; + + if (!chk->msg->abandoned && + (!SCTP_PR_PRIO_ENABLED(chk->sinfo.sinfo_flags) || + chk->sinfo.sinfo_timetolive <= sinfo->sinfo_timetolive)) + continue; + + chk->msg->abandoned = 1; + list_del_init(&chk->transmitted_list); + sctp_insert_list(&asoc->outqueue.abandoned, + &chk->transmitted_list); + + streamout = SCTP_SO(&asoc->stream, chk->sinfo.sinfo_stream); + asoc->sent_cnt_removable--; + asoc->abandoned_sent[SCTP_PR_INDEX(PRIO)]++; + streamout->ext->abandoned_sent[SCTP_PR_INDEX(PRIO)]++; + + if (queue != &asoc->outqueue.retransmit && + !chk->tsn_gap_acked) { + if (chk->transport) + chk->transport->flight_size -= + sctp_data_size(chk); + asoc->outqueue.outstanding_bytes -= sctp_data_size(chk); + } + + msg_len -= chk->skb->truesize + sizeof(struct sctp_chunk); + if (msg_len <= 0) + break; + } + + return msg_len; +} + +static int sctp_prsctp_prune_unsent(struct sctp_association *asoc, + struct sctp_sndrcvinfo *sinfo, int msg_len) +{ + struct sctp_outq *q = &asoc->outqueue; + struct sctp_chunk *chk, *temp; + struct sctp_stream_out *sout; + + q->sched->unsched_all(&asoc->stream); + + list_for_each_entry_safe(chk, temp, &q->out_chunk_list, list) { + if (!chk->msg->abandoned && + (!(chk->chunk_hdr->flags & SCTP_DATA_FIRST_FRAG) || + !SCTP_PR_PRIO_ENABLED(chk->sinfo.sinfo_flags) || + chk->sinfo.sinfo_timetolive <= sinfo->sinfo_timetolive)) + continue; + + chk->msg->abandoned = 1; + sctp_sched_dequeue_common(q, chk); + asoc->sent_cnt_removable--; + asoc->abandoned_unsent[SCTP_PR_INDEX(PRIO)]++; + + sout = SCTP_SO(&asoc->stream, chk->sinfo.sinfo_stream); + sout->ext->abandoned_unsent[SCTP_PR_INDEX(PRIO)]++; + + /* clear out_curr if all frag chunks are pruned */ + if (asoc->stream.out_curr == sout && + list_is_last(&chk->frag_list, &chk->msg->chunks)) + asoc->stream.out_curr = NULL; + + msg_len -= chk->skb->truesize + sizeof(struct sctp_chunk); + sctp_chunk_free(chk); + if (msg_len <= 0) + break; + } + + q->sched->sched_all(&asoc->stream); + + return msg_len; +} + +/* Abandon the 
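The TSN-ordered insert above relies on serial-number arithmetic so ordering stays correct across the 32-bit wrap. A standalone userspace sketch of the same idea, using a plain array instead of the kernel list:

#include <stdint.h>
#include <stdio.h>

static int tsn_lt(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;	/* serial-number comparison */
}

/* Insert tsn into tsns[0..*n) keeping ascending TSN order. */
static void insert_sorted(uint32_t *tsns, int *n, uint32_t tsn)
{
	int pos = *n;

	while (pos > 0 && tsn_lt(tsn, tsns[pos - 1])) {
		tsns[pos] = tsns[pos - 1];	/* shift larger TSNs right */
		pos--;
	}
	tsns[pos] = tsn;
	(*n)++;
}

int main(void)
{
	uint32_t q[8];
	int n = 0;

	insert_sorted(q, &n, 4294967290u);	/* close to the wrap point */
	insert_sorted(q, &n, 3);		/* logically after the wrap */
	insert_sorted(q, &n, 4294967295u);
	for (int i = 0; i < n; i++)
		printf("%u\n", (unsigned)q[i]);	/* 4294967290 4294967295 3 */
	return 0;
}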
chunks according their priorities */ +void sctp_prsctp_prune(struct sctp_association *asoc, + struct sctp_sndrcvinfo *sinfo, int msg_len) +{ + struct sctp_transport *transport; + + if (!asoc->peer.prsctp_capable || !asoc->sent_cnt_removable) + return; + + msg_len = sctp_prsctp_prune_sent(asoc, sinfo, + &asoc->outqueue.retransmit, + msg_len); + if (msg_len <= 0) + return; + + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + msg_len = sctp_prsctp_prune_sent(asoc, sinfo, + &transport->transmitted, + msg_len); + if (msg_len <= 0) + return; + } + + sctp_prsctp_prune_unsent(asoc, sinfo, msg_len); +} + +/* Mark all the eligible packets on a transport for retransmission. */ +void sctp_retransmit_mark(struct sctp_outq *q, + struct sctp_transport *transport, + __u8 reason) +{ + struct list_head *lchunk, *ltemp; + struct sctp_chunk *chunk; + + /* Walk through the specified transmitted queue. */ + list_for_each_safe(lchunk, ltemp, &transport->transmitted) { + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + + /* If the chunk is abandoned, move it to abandoned list. */ + if (sctp_chunk_abandoned(chunk)) { + list_del_init(lchunk); + sctp_insert_list(&q->abandoned, lchunk); + + /* If this chunk has not been previousely acked, + * stop considering it 'outstanding'. Our peer + * will most likely never see it since it will + * not be retransmitted + */ + if (!chunk->tsn_gap_acked) { + if (chunk->transport) + chunk->transport->flight_size -= + sctp_data_size(chunk); + q->outstanding_bytes -= sctp_data_size(chunk); + q->asoc->peer.rwnd += sctp_data_size(chunk); + } + continue; + } + + /* If we are doing retransmission due to a timeout or pmtu + * discovery, only the chunks that are not yet acked should + * be added to the retransmit queue. + */ + if ((reason == SCTP_RTXR_FAST_RTX && + (chunk->fast_retransmit == SCTP_NEED_FRTX)) || + (reason != SCTP_RTXR_FAST_RTX && !chunk->tsn_gap_acked)) { + /* RFC 2960 6.2.1 Processing a Received SACK + * + * C) Any time a DATA chunk is marked for + * retransmission (via either T3-rtx timer expiration + * (Section 6.3.3) or via fast retransmit + * (Section 7.2.4)), add the data size of those + * chunks to the rwnd. + */ + q->asoc->peer.rwnd += sctp_data_size(chunk); + q->outstanding_bytes -= sctp_data_size(chunk); + if (chunk->transport) + transport->flight_size -= sctp_data_size(chunk); + + /* sctpimpguide-05 Section 2.8.2 + * M5) If a T3-rtx timer expires, the + * 'TSN.Missing.Report' of all affected TSNs is set + * to 0. + */ + chunk->tsn_missing_report = 0; + + /* If a chunk that is being used for RTT measurement + * has to be retransmitted, we cannot use this chunk + * anymore for RTT measurements. Reset rto_pending so + * that a new RTT measurement is started when a new + * data chunk is sent. + */ + if (chunk->rtt_in_progress) { + chunk->rtt_in_progress = 0; + transport->rto_pending = 0; + } + + /* Move the chunk to the retransmit queue. The chunks + * on the retransmit queue are always kept in order. + */ + list_del_init(lchunk); + sctp_insert_list(&q->retransmit, lchunk); + } + } + + pr_debug("%s: transport:%p, reason:%d, cwnd:%d, ssthresh:%d, " + "flight_size:%d, pba:%d\n", __func__, transport, reason, + transport->cwnd, transport->ssthresh, transport->flight_size, + transport->partial_bytes_acked); +} + +/* Mark all the eligible packets on a transport for retransmission and force + * one packet out. 
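The bookkeeping in sctp_retransmit_mark() above follows RFC 2960 6.2.1 C): bytes of a chunk marked for retransmission are removed from the in-flight counters and credited back to the peer's rwnd. A tiny userspace illustration of that accounting (struct and values invented for the example):

#include <stdint.h>
#include <stdio.h>

struct accounting {
	uint32_t peer_rwnd;
	uint32_t outstanding_bytes;	/* association-wide */
	uint32_t flight_size;		/* per transport    */
};

static void mark_for_retransmit(struct accounting *a, uint32_t data_size)
{
	a->peer_rwnd         += data_size;	/* RFC 2960 6.2.1 C) */
	a->outstanding_bytes -= data_size;
	a->flight_size       -= data_size;
}

int main(void)
{
	struct accounting a = { .peer_rwnd = 1000,
				.outstanding_bytes = 3000,
				.flight_size = 3000 };

	mark_for_retransmit(&a, 1200);	/* e.g. one full-sized DATA chunk */
	printf("rwnd=%u outstanding=%u flight=%u\n",
	       a.peer_rwnd, a.outstanding_bytes, a.flight_size);
	return 0;
}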
+ */ +void sctp_retransmit(struct sctp_outq *q, struct sctp_transport *transport, + enum sctp_retransmit_reason reason) +{ + struct net *net = q->asoc->base.net; + + switch (reason) { + case SCTP_RTXR_T3_RTX: + SCTP_INC_STATS(net, SCTP_MIB_T3_RETRANSMITS); + sctp_transport_lower_cwnd(transport, SCTP_LOWER_CWND_T3_RTX); + /* Update the retran path if the T3-rtx timer has expired for + * the current retran path. + */ + if (transport == transport->asoc->peer.retran_path) + sctp_assoc_update_retran_path(transport->asoc); + transport->asoc->rtx_data_chunks += + transport->asoc->unack_data; + if (transport->pl.state == SCTP_PL_COMPLETE && + transport->asoc->unack_data) + sctp_transport_reset_probe_timer(transport); + break; + case SCTP_RTXR_FAST_RTX: + SCTP_INC_STATS(net, SCTP_MIB_FAST_RETRANSMITS); + sctp_transport_lower_cwnd(transport, SCTP_LOWER_CWND_FAST_RTX); + q->fast_rtx = 1; + break; + case SCTP_RTXR_PMTUD: + SCTP_INC_STATS(net, SCTP_MIB_PMTUD_RETRANSMITS); + break; + case SCTP_RTXR_T1_RTX: + SCTP_INC_STATS(net, SCTP_MIB_T1_RETRANSMITS); + transport->asoc->init_retries++; + break; + default: + BUG(); + } + + sctp_retransmit_mark(q, transport, reason); + + /* PR-SCTP A5) Any time the T3-rtx timer expires, on any destination, + * the sender SHOULD try to advance the "Advanced.Peer.Ack.Point" by + * following the procedures outlined in C1 - C5. + */ + if (reason == SCTP_RTXR_T3_RTX) + q->asoc->stream.si->generate_ftsn(q, q->asoc->ctsn_ack_point); + + /* Flush the queues only on timeout, since fast_rtx is only + * triggered during sack processing and the queue + * will be flushed at the end. + */ + if (reason != SCTP_RTXR_FAST_RTX) + sctp_outq_flush(q, /* rtx_timeout */ 1, GFP_ATOMIC); +} + +/* + * Transmit DATA chunks on the retransmit queue. Upon return from + * __sctp_outq_flush_rtx() the packet 'pkt' may contain chunks which + * need to be transmitted by the caller. + * We assume that pkt->transport has already been set. + * + * The return value is a normal kernel error return value. + */ +static int __sctp_outq_flush_rtx(struct sctp_outq *q, struct sctp_packet *pkt, + int rtx_timeout, int *start_timer, gfp_t gfp) +{ + struct sctp_transport *transport = pkt->transport; + struct sctp_chunk *chunk, *chunk1; + struct list_head *lqueue; + enum sctp_xmit status; + int error = 0; + int timer = 0; + int done = 0; + int fast_rtx; + + lqueue = &q->retransmit; + fast_rtx = q->fast_rtx; + + /* This loop handles time-out retransmissions, fast retransmissions, + * and retransmissions due to opening of whindow. + * + * RFC 2960 6.3.3 Handle T3-rtx Expiration + * + * E3) Determine how many of the earliest (i.e., lowest TSN) + * outstanding DATA chunks for the address for which the + * T3-rtx has expired will fit into a single packet, subject + * to the MTU constraint for the path corresponding to the + * destination transport address to which the retransmission + * is being sent (this may be different from the address for + * which the timer expires [see Section 6.4]). Call this value + * K. Bundle and retransmit those K DATA chunks in a single + * packet to the destination endpoint. + * + * [Just to be painfully clear, if we are retransmitting + * because a timeout just happened, we should send only ONE + * packet of retransmitted data.] + * + * For fast retransmissions we also send only ONE packet. However, + * if we are just flushing the queue due to open window, we'll + * try to send as much as possible. 
+ */ + list_for_each_entry_safe(chunk, chunk1, lqueue, transmitted_list) { + /* If the chunk is abandoned, move it to abandoned list. */ + if (sctp_chunk_abandoned(chunk)) { + list_del_init(&chunk->transmitted_list); + sctp_insert_list(&q->abandoned, + &chunk->transmitted_list); + continue; + } + + /* Make sure that Gap Acked TSNs are not retransmitted. A + * simple approach is just to move such TSNs out of the + * way and into a 'transmitted' queue and skip to the + * next chunk. + */ + if (chunk->tsn_gap_acked) { + list_move_tail(&chunk->transmitted_list, + &transport->transmitted); + continue; + } + + /* If we are doing fast retransmit, ignore non-fast_rtransmit + * chunks + */ + if (fast_rtx && !chunk->fast_retransmit) + continue; + +redo: + /* Attempt to append this chunk to the packet. */ + status = sctp_packet_append_chunk(pkt, chunk); + + switch (status) { + case SCTP_XMIT_PMTU_FULL: + if (!pkt->has_data && !pkt->has_cookie_echo) { + /* If this packet did not contain DATA then + * retransmission did not happen, so do it + * again. We'll ignore the error here since + * control chunks are already freed so there + * is nothing we can do. + */ + sctp_packet_transmit(pkt, gfp); + goto redo; + } + + /* Send this packet. */ + error = sctp_packet_transmit(pkt, gfp); + + /* If we are retransmitting, we should only + * send a single packet. + * Otherwise, try appending this chunk again. + */ + if (rtx_timeout || fast_rtx) + done = 1; + else + goto redo; + + /* Bundle next chunk in the next round. */ + break; + + case SCTP_XMIT_RWND_FULL: + /* Send this packet. */ + error = sctp_packet_transmit(pkt, gfp); + + /* Stop sending DATA as there is no more room + * at the receiver. + */ + done = 1; + break; + + case SCTP_XMIT_DELAY: + /* Send this packet. */ + error = sctp_packet_transmit(pkt, gfp); + + /* Stop sending DATA because of nagle delay. */ + done = 1; + break; + + default: + /* The append was successful, so add this chunk to + * the transmitted list. + */ + list_move_tail(&chunk->transmitted_list, + &transport->transmitted); + + /* Mark the chunk as ineligible for fast retransmit + * after it is retransmitted. + */ + if (chunk->fast_retransmit == SCTP_NEED_FRTX) + chunk->fast_retransmit = SCTP_DONT_FRTX; + + q->asoc->stats.rtxchunks++; + break; + } + + /* Set the timer if there were no errors */ + if (!error && !timer) + timer = 1; + + if (done) + break; + } + + /* If we are here due to a retransmit timeout or a fast + * retransmit and if there are any chunks left in the retransmit + * queue that could not fit in the PMTU sized packet, they need + * to be marked as ineligible for a subsequent fast retransmit. + */ + if (rtx_timeout || fast_rtx) { + list_for_each_entry(chunk1, lqueue, transmitted_list) { + if (chunk1->fast_retransmit == SCTP_NEED_FRTX) + chunk1->fast_retransmit = SCTP_DONT_FRTX; + } + } + + *start_timer = timer; + + /* Clear fast retransmit hint */ + if (fast_rtx) + q->fast_rtx = 0; + + return error; +} + +/* Cork the outqueue so queued chunks are really queued. 
*/ +void sctp_outq_uncork(struct sctp_outq *q, gfp_t gfp) +{ + if (q->cork) + q->cork = 0; + + sctp_outq_flush(q, 0, gfp); +} + +static int sctp_packet_singleton(struct sctp_transport *transport, + struct sctp_chunk *chunk, gfp_t gfp) +{ + const struct sctp_association *asoc = transport->asoc; + const __u16 sport = asoc->base.bind_addr.port; + const __u16 dport = asoc->peer.port; + const __u32 vtag = asoc->peer.i.init_tag; + struct sctp_packet singleton; + + sctp_packet_init(&singleton, transport, sport, dport); + sctp_packet_config(&singleton, vtag, 0); + if (sctp_packet_append_chunk(&singleton, chunk) != SCTP_XMIT_OK) { + list_del_init(&chunk->list); + sctp_chunk_free(chunk); + return -ENOMEM; + } + return sctp_packet_transmit(&singleton, gfp); +} + +/* Struct to hold the context during sctp outq flush */ +struct sctp_flush_ctx { + struct sctp_outq *q; + /* Current transport being used. It's NOT the same as curr active one */ + struct sctp_transport *transport; + /* These transports have chunks to send. */ + struct list_head transport_list; + struct sctp_association *asoc; + /* Packet on the current transport above */ + struct sctp_packet *packet; + gfp_t gfp; +}; + +/* transport: current transport */ +static void sctp_outq_select_transport(struct sctp_flush_ctx *ctx, + struct sctp_chunk *chunk) +{ + struct sctp_transport *new_transport = chunk->transport; + + if (!new_transport) { + if (!sctp_chunk_is_data(chunk)) { + /* If we have a prior transport pointer, see if + * the destination address of the chunk + * matches the destination address of the + * current transport. If not a match, then + * try to look up the transport with a given + * destination address. We do this because + * after processing ASCONFs, we may have new + * transports created. + */ + if (ctx->transport && sctp_cmp_addr_exact(&chunk->dest, + &ctx->transport->ipaddr)) + new_transport = ctx->transport; + else + new_transport = sctp_assoc_lookup_paddr(ctx->asoc, + &chunk->dest); + } + + /* if we still don't have a new transport, then + * use the current active path. + */ + if (!new_transport) + new_transport = ctx->asoc->peer.active_path; + } else { + __u8 type; + + switch (new_transport->state) { + case SCTP_INACTIVE: + case SCTP_UNCONFIRMED: + case SCTP_PF: + /* If the chunk is Heartbeat or Heartbeat Ack, + * send it to chunk->transport, even if it's + * inactive. + * + * 3.3.6 Heartbeat Acknowledgement: + * ... + * A HEARTBEAT ACK is always sent to the source IP + * address of the IP datagram containing the + * HEARTBEAT chunk to which this ack is responding. + * ... + * + * ASCONF_ACKs also must be sent to the source. + */ + type = chunk->chunk_hdr->type; + if (type != SCTP_CID_HEARTBEAT && + type != SCTP_CID_HEARTBEAT_ACK && + type != SCTP_CID_ASCONF_ACK) + new_transport = ctx->asoc->peer.active_path; + break; + default: + break; + } + } + + /* Are we switching transports? Take care of transport locks. */ + if (new_transport != ctx->transport) { + ctx->transport = new_transport; + ctx->packet = &ctx->transport->packet; + + if (list_empty(&ctx->transport->send_ready)) + list_add_tail(&ctx->transport->send_ready, + &ctx->transport_list); + + sctp_packet_config(ctx->packet, + ctx->asoc->peer.i.init_tag, + ctx->asoc->peer.ecn_capable); + /* We've switched transports, so apply the + * Burst limit to the new transport. 
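The transport-selection rule above can be reduced to a single predicate: on an inactive, unconfirmed or PF transport, only HEARTBEAT, HEARTBEAT ACK and ASCONF ACK stay where they were addressed, everything else is redirected to the active path. A small userspace sketch, with chunk type values taken from RFC 4960/5061 (illustrative, not the kernel enum):

#include <stdbool.h>
#include <stdio.h>

enum { CID_SACK = 3, CID_HEARTBEAT = 4, CID_HEARTBEAT_ACK = 5,
       CID_ASCONF_ACK = 0x80 };

static bool keep_on_inactive_transport(int chunk_type)
{
	return chunk_type == CID_HEARTBEAT ||
	       chunk_type == CID_HEARTBEAT_ACK ||
	       chunk_type == CID_ASCONF_ACK;
}

int main(void)
{
	printf("HEARTBEAT stays: %d, SACK stays: %d\n",
	       keep_on_inactive_transport(CID_HEARTBEAT),
	       keep_on_inactive_transport(CID_SACK));
	return 0;
}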
+ */ + sctp_transport_burst_limited(ctx->transport); + } +} + +static void sctp_outq_flush_ctrl(struct sctp_flush_ctx *ctx) +{ + struct sctp_chunk *chunk, *tmp; + enum sctp_xmit status; + int one_packet, error; + + list_for_each_entry_safe(chunk, tmp, &ctx->q->control_chunk_list, list) { + one_packet = 0; + + /* RFC 5061, 5.3 + * F1) This means that until such time as the ASCONF + * containing the add is acknowledged, the sender MUST + * NOT use the new IP address as a source for ANY SCTP + * packet except on carrying an ASCONF Chunk. + */ + if (ctx->asoc->src_out_of_asoc_ok && + chunk->chunk_hdr->type != SCTP_CID_ASCONF) + continue; + + list_del_init(&chunk->list); + + /* Pick the right transport to use. Should always be true for + * the first chunk as we don't have a transport by then. + */ + sctp_outq_select_transport(ctx, chunk); + + switch (chunk->chunk_hdr->type) { + /* 6.10 Bundling + * ... + * An endpoint MUST NOT bundle INIT, INIT ACK or SHUTDOWN + * COMPLETE with any other chunks. [Send them immediately.] + */ + case SCTP_CID_INIT: + case SCTP_CID_INIT_ACK: + case SCTP_CID_SHUTDOWN_COMPLETE: + error = sctp_packet_singleton(ctx->transport, chunk, + ctx->gfp); + if (error < 0) { + ctx->asoc->base.sk->sk_err = -error; + return; + } + ctx->asoc->stats.octrlchunks++; + break; + + case SCTP_CID_ABORT: + if (sctp_test_T_bit(chunk)) + ctx->packet->vtag = ctx->asoc->c.my_vtag; + fallthrough; + + /* The following chunks are "response" chunks, i.e. + * they are generated in response to something we + * received. If we are sending these, then we can + * send only 1 packet containing these chunks. + */ + case SCTP_CID_HEARTBEAT_ACK: + case SCTP_CID_SHUTDOWN_ACK: + case SCTP_CID_COOKIE_ACK: + case SCTP_CID_COOKIE_ECHO: + case SCTP_CID_ERROR: + case SCTP_CID_ECN_CWR: + case SCTP_CID_ASCONF_ACK: + one_packet = 1; + fallthrough; + + case SCTP_CID_HEARTBEAT: + if (chunk->pmtu_probe) { + error = sctp_packet_singleton(ctx->transport, + chunk, ctx->gfp); + if (!error) + ctx->asoc->stats.octrlchunks++; + break; + } + fallthrough; + case SCTP_CID_SACK: + case SCTP_CID_SHUTDOWN: + case SCTP_CID_ECN_ECNE: + case SCTP_CID_ASCONF: + case SCTP_CID_FWD_TSN: + case SCTP_CID_I_FWD_TSN: + case SCTP_CID_RECONF: + status = sctp_packet_transmit_chunk(ctx->packet, chunk, + one_packet, ctx->gfp); + if (status != SCTP_XMIT_OK) { + /* put the chunk back */ + list_add(&chunk->list, &ctx->q->control_chunk_list); + break; + } + + ctx->asoc->stats.octrlchunks++; + /* PR-SCTP C5) If a FORWARD TSN is sent, the + * sender MUST assure that at least one T3-rtx + * timer is running. + */ + if (chunk->chunk_hdr->type == SCTP_CID_FWD_TSN || + chunk->chunk_hdr->type == SCTP_CID_I_FWD_TSN) { + sctp_transport_reset_t3_rtx(ctx->transport); + ctx->transport->last_time_sent = jiffies; + } + + if (chunk == ctx->asoc->strreset_chunk) + sctp_transport_reset_reconf_timer(ctx->transport); + + break; + + default: + /* We built a chunk with an illegal type! */ + BUG(); + } + } +} + +/* Returns false if new data shouldn't be sent */ +static bool sctp_outq_flush_rtx(struct sctp_flush_ctx *ctx, + int rtx_timeout) +{ + int error, start_timer = 0; + + if (ctx->asoc->peer.retran_path->state == SCTP_UNCONFIRMED) + return false; + + if (ctx->transport != ctx->asoc->peer.retran_path) { + /* Switch transports & prepare the packet. 
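The switch in sctp_outq_flush_ctrl() above encodes two bundling policies: INIT, INIT ACK and SHUTDOWN COMPLETE are always sent as singletons, while "response" chunks are limited to a single packet. A userspace sketch of those two predicates (chunk type values from RFC 4960, ASCONF ACK from RFC 5061; the helper names are invented):

#include <stdbool.h>
#include <stdio.h>

enum {
	CID_INIT = 1, CID_INIT_ACK = 2, CID_SACK = 3, CID_HEARTBEAT = 4,
	CID_HEARTBEAT_ACK = 5, CID_ABORT = 6, CID_SHUTDOWN = 7,
	CID_SHUTDOWN_ACK = 8, CID_ERROR = 9, CID_COOKIE_ECHO = 10,
	CID_COOKIE_ACK = 11, CID_ECN_ECNE = 12, CID_ECN_CWR = 13,
	CID_SHUTDOWN_COMPLETE = 14, CID_ASCONF_ACK = 0x80,
};

/* 6.10 Bundling: these three must never share a packet with other chunks. */
static bool must_be_singleton(int cid)
{
	return cid == CID_INIT || cid == CID_INIT_ACK ||
	       cid == CID_SHUTDOWN_COMPLETE;
}

/* "Response" chunks restricted to one packet when they are sent. */
static bool one_packet_only(int cid)
{
	switch (cid) {
	case CID_ABORT:
	case CID_HEARTBEAT_ACK:
	case CID_SHUTDOWN_ACK:
	case CID_COOKIE_ACK:
	case CID_COOKIE_ECHO:
	case CID_ERROR:
	case CID_ECN_CWR:
	case CID_ASCONF_ACK:
		return true;
	default:
		return false;
	}
}

int main(void)
{
	printf("INIT singleton: %d, SACK one-packet: %d\n",
	       must_be_singleton(CID_INIT), one_packet_only(CID_SACK));
	return 0;
}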
*/ + ctx->transport = ctx->asoc->peer.retran_path; + ctx->packet = &ctx->transport->packet; + + if (list_empty(&ctx->transport->send_ready)) + list_add_tail(&ctx->transport->send_ready, + &ctx->transport_list); + + sctp_packet_config(ctx->packet, ctx->asoc->peer.i.init_tag, + ctx->asoc->peer.ecn_capable); + } + + error = __sctp_outq_flush_rtx(ctx->q, ctx->packet, rtx_timeout, + &start_timer, ctx->gfp); + if (error < 0) + ctx->asoc->base.sk->sk_err = -error; + + if (start_timer) { + sctp_transport_reset_t3_rtx(ctx->transport); + ctx->transport->last_time_sent = jiffies; + } + + /* This can happen on COOKIE-ECHO resend. Only + * one chunk can get bundled with a COOKIE-ECHO. + */ + if (ctx->packet->has_cookie_echo) + return false; + + /* Don't send new data if there is still data + * waiting to retransmit. + */ + if (!list_empty(&ctx->q->retransmit)) + return false; + + return true; +} + +static void sctp_outq_flush_data(struct sctp_flush_ctx *ctx, + int rtx_timeout) +{ + struct sctp_chunk *chunk; + enum sctp_xmit status; + + /* Is it OK to send data chunks? */ + switch (ctx->asoc->state) { + case SCTP_STATE_COOKIE_ECHOED: + /* Only allow bundling when this packet has a COOKIE-ECHO + * chunk. + */ + if (!ctx->packet || !ctx->packet->has_cookie_echo) + return; + + fallthrough; + case SCTP_STATE_ESTABLISHED: + case SCTP_STATE_SHUTDOWN_PENDING: + case SCTP_STATE_SHUTDOWN_RECEIVED: + break; + + default: + /* Do nothing. */ + return; + } + + /* RFC 2960 6.1 Transmission of DATA Chunks + * + * C) When the time comes for the sender to transmit, + * before sending new DATA chunks, the sender MUST + * first transmit any outstanding DATA chunks which + * are marked for retransmission (limited by the + * current cwnd). + */ + if (!list_empty(&ctx->q->retransmit) && + !sctp_outq_flush_rtx(ctx, rtx_timeout)) + return; + + /* Apply Max.Burst limitation to the current transport in + * case it will be used for new data. We are going to + * rest it before we return, but we want to apply the limit + * to the currently queued data. + */ + if (ctx->transport) + sctp_transport_burst_limited(ctx->transport); + + /* Finally, transmit new packets. */ + while ((chunk = sctp_outq_dequeue_data(ctx->q)) != NULL) { + __u32 sid = ntohs(chunk->subh.data_hdr->stream); + __u8 stream_state = SCTP_SO(&ctx->asoc->stream, sid)->state; + + /* Has this chunk expired? */ + if (sctp_chunk_abandoned(chunk)) { + sctp_sched_dequeue_done(ctx->q, chunk); + sctp_chunk_fail(chunk, 0); + sctp_chunk_free(chunk); + continue; + } + + if (stream_state == SCTP_STREAM_CLOSED) { + sctp_outq_head_data(ctx->q, chunk); + break; + } + + sctp_outq_select_transport(ctx, chunk); + + pr_debug("%s: outq:%p, chunk:%p[%s], tx-tsn:0x%x skb->head:%p skb->users:%d\n", + __func__, ctx->q, chunk, chunk && chunk->chunk_hdr ? + sctp_cname(SCTP_ST_CHUNK(chunk->chunk_hdr->type)) : + "illegal chunk", ntohl(chunk->subh.data_hdr->tsn), + chunk->skb ? chunk->skb->head : NULL, chunk->skb ? + refcount_read(&chunk->skb->users) : -1); + + /* Add the chunk to the packet. */ + status = sctp_packet_transmit_chunk(ctx->packet, chunk, 0, + ctx->gfp); + if (status != SCTP_XMIT_OK) { + /* We could not append this chunk, so put + * the chunk back on the output queue. + */ + pr_debug("%s: could not transmit tsn:0x%x, status:%d\n", + __func__, ntohl(chunk->subh.data_hdr->tsn), + status); + + sctp_outq_head_data(ctx->q, chunk); + break; + } + + /* The sender is in the SHUTDOWN-PENDING state, + * The sender MAY set the I-bit in the DATA + * chunk header. 
+ */ + if (ctx->asoc->state == SCTP_STATE_SHUTDOWN_PENDING) + chunk->chunk_hdr->flags |= SCTP_DATA_SACK_IMM; + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) + ctx->asoc->stats.ouodchunks++; + else + ctx->asoc->stats.oodchunks++; + + /* Only now it's safe to consider this + * chunk as sent, sched-wise. + */ + sctp_sched_dequeue_done(ctx->q, chunk); + + list_add_tail(&chunk->transmitted_list, + &ctx->transport->transmitted); + + sctp_transport_reset_t3_rtx(ctx->transport); + ctx->transport->last_time_sent = jiffies; + + /* Only let one DATA chunk get bundled with a + * COOKIE-ECHO chunk. + */ + if (ctx->packet->has_cookie_echo) + break; + } +} + +static void sctp_outq_flush_transports(struct sctp_flush_ctx *ctx) +{ + struct sock *sk = ctx->asoc->base.sk; + struct list_head *ltransport; + struct sctp_packet *packet; + struct sctp_transport *t; + int error = 0; + + while ((ltransport = sctp_list_dequeue(&ctx->transport_list)) != NULL) { + t = list_entry(ltransport, struct sctp_transport, send_ready); + packet = &t->packet; + if (!sctp_packet_empty(packet)) { + rcu_read_lock(); + if (t->dst && __sk_dst_get(sk) != t->dst) { + dst_hold(t->dst); + sk_setup_caps(sk, t->dst); + } + rcu_read_unlock(); + error = sctp_packet_transmit(packet, ctx->gfp); + if (error < 0) + ctx->q->asoc->base.sk->sk_err = -error; + } + + /* Clear the burst limited state, if any */ + sctp_transport_burst_reset(t); + } +} + +/* Try to flush an outqueue. + * + * Description: Send everything in q which we legally can, subject to + * congestion limitations. + * * Note: This function can be called from multiple contexts so appropriate + * locking concerns must be made. Today we use the sock lock to protect + * this function. + */ + +static void sctp_outq_flush(struct sctp_outq *q, int rtx_timeout, gfp_t gfp) +{ + struct sctp_flush_ctx ctx = { + .q = q, + .transport = NULL, + .transport_list = LIST_HEAD_INIT(ctx.transport_list), + .asoc = q->asoc, + .packet = NULL, + .gfp = gfp, + }; + + /* 6.10 Bundling + * ... + * When bundling control chunks with DATA chunks, an + * endpoint MUST place control chunks first in the outbound + * SCTP packet. The transmitter MUST transmit DATA chunks + * within a SCTP packet in increasing order of TSN. + * ... + */ + + sctp_outq_flush_ctrl(&ctx); + + if (q->asoc->src_out_of_asoc_ok) + goto sctp_flush_out; + + sctp_outq_flush_data(&ctx, rtx_timeout); + +sctp_flush_out: + + sctp_outq_flush_transports(&ctx); +} + +/* Update unack_data based on the incoming SACK chunk */ +static void sctp_sack_update_unack_data(struct sctp_association *assoc, + struct sctp_sackhdr *sack) +{ + union sctp_sack_variable *frags; + __u16 unack_data; + int i; + + unack_data = assoc->next_tsn - assoc->ctsn_ack_point - 1; + + frags = sack->variable; + for (i = 0; i < ntohs(sack->num_gap_ack_blocks); i++) { + unack_data -= ((ntohs(frags[i].gab.end) - + ntohs(frags[i].gab.start) + 1)); + } + + assoc->unack_data = unack_data; +} + +/* This is where we REALLY process a SACK. + * + * Process the SACK against the outqueue. Mostly, this just frees + * things off the transmitted queue. 
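A worked userspace version of the arithmetic in sctp_sack_update_unack_data() above: everything between the Cumulative TSN Ack point and next_tsn is unacked, minus the TSNs covered by Gap Ack Blocks (block offsets are relative to the cumulative ack; host byte order is assumed here for simplicity).

#include <stdint.h>
#include <stdio.h>

struct gab { uint16_t start, end; };

static uint16_t unack_data(uint32_t next_tsn, uint32_t ctsn_ack_point,
			   const struct gab *gabs, int ngabs)
{
	uint16_t unack = (uint16_t)(next_tsn - ctsn_ack_point - 1);

	for (int i = 0; i < ngabs; i++)
		unack -= (uint16_t)(gabs[i].end - gabs[i].start + 1);
	return unack;
}

int main(void)
{
	/* Cum ack 100, next TSN to assign 111: TSNs 101..110 outstanding.
	 * One gap block acking offsets 3..5 (TSNs 103..105): 7 remain. */
	struct gab g[] = { { 3, 5 } };

	printf("unack_data = %u\n", unack_data(111, 100, g, 1));
	return 0;
}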
+ */ +int sctp_outq_sack(struct sctp_outq *q, struct sctp_chunk *chunk) +{ + struct sctp_association *asoc = q->asoc; + struct sctp_sackhdr *sack = chunk->subh.sack_hdr; + struct sctp_transport *transport; + struct sctp_chunk *tchunk = NULL; + struct list_head *lchunk, *transport_list, *temp; + union sctp_sack_variable *frags = sack->variable; + __u32 sack_ctsn, ctsn, tsn; + __u32 highest_tsn, highest_new_tsn; + __u32 sack_a_rwnd; + unsigned int outstanding; + struct sctp_transport *primary = asoc->peer.primary_path; + int count_of_newacks = 0; + int gap_ack_blocks; + u8 accum_moved = 0; + + /* Grab the association's destination address list. */ + transport_list = &asoc->peer.transport_addr_list; + + /* SCTP path tracepoint for congestion control debugging. */ + if (trace_sctp_probe_path_enabled()) { + list_for_each_entry(transport, transport_list, transports) + trace_sctp_probe_path(transport, asoc); + } + + sack_ctsn = ntohl(sack->cum_tsn_ack); + gap_ack_blocks = ntohs(sack->num_gap_ack_blocks); + asoc->stats.gapcnt += gap_ack_blocks; + /* + * SFR-CACC algorithm: + * On receipt of a SACK the sender SHOULD execute the + * following statements. + * + * 1) If the cumulative ack in the SACK passes next tsn_at_change + * on the current primary, the CHANGEOVER_ACTIVE flag SHOULD be + * cleared. The CYCLING_CHANGEOVER flag SHOULD also be cleared for + * all destinations. + * 2) If the SACK contains gap acks and the flag CHANGEOVER_ACTIVE + * is set the receiver of the SACK MUST take the following actions: + * + * A) Initialize the cacc_saw_newack to 0 for all destination + * addresses. + * + * Only bother if changeover_active is set. Otherwise, this is + * totally suboptimal to do on every SACK. + */ + if (primary->cacc.changeover_active) { + u8 clear_cycling = 0; + + if (TSN_lte(primary->cacc.next_tsn_at_change, sack_ctsn)) { + primary->cacc.changeover_active = 0; + clear_cycling = 1; + } + + if (clear_cycling || gap_ack_blocks) { + list_for_each_entry(transport, transport_list, + transports) { + if (clear_cycling) + transport->cacc.cycling_changeover = 0; + if (gap_ack_blocks) + transport->cacc.cacc_saw_newack = 0; + } + } + } + + /* Get the highest TSN in the sack. */ + highest_tsn = sack_ctsn; + if (gap_ack_blocks) + highest_tsn += ntohs(frags[gap_ack_blocks - 1].gab.end); + + if (TSN_lt(asoc->highest_sacked, highest_tsn)) + asoc->highest_sacked = highest_tsn; + + highest_new_tsn = sack_ctsn; + + /* Run through the retransmit queue. Credit bytes received + * and free those chunks that we can. + */ + sctp_check_transmitted(q, &q->retransmit, NULL, NULL, sack, &highest_new_tsn); + + /* Run through the transmitted queue. + * Credit bytes received and free those chunks which we can. + * + * This is a MASSIVE candidate for optimization. + */ + list_for_each_entry(transport, transport_list, transports) { + sctp_check_transmitted(q, &transport->transmitted, + transport, &chunk->source, sack, + &highest_new_tsn); + /* + * SFR-CACC algorithm: + * C) Let count_of_newacks be the number of + * destinations for which cacc_saw_newack is set. + */ + if (transport->cacc.cacc_saw_newack) + count_of_newacks++; + } + + /* Move the Cumulative TSN Ack Point if appropriate. 
*/ + if (TSN_lt(asoc->ctsn_ack_point, sack_ctsn)) { + asoc->ctsn_ack_point = sack_ctsn; + accum_moved = 1; + } + + if (gap_ack_blocks) { + + if (asoc->fast_recovery && accum_moved) + highest_new_tsn = highest_tsn; + + list_for_each_entry(transport, transport_list, transports) + sctp_mark_missing(q, &transport->transmitted, transport, + highest_new_tsn, count_of_newacks); + } + + /* Update unack_data field in the assoc. */ + sctp_sack_update_unack_data(asoc, sack); + + ctsn = asoc->ctsn_ack_point; + + /* Throw away stuff rotting on the sack queue. */ + list_for_each_safe(lchunk, temp, &q->sacked) { + tchunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + tsn = ntohl(tchunk->subh.data_hdr->tsn); + if (TSN_lte(tsn, ctsn)) { + list_del_init(&tchunk->transmitted_list); + if (asoc->peer.prsctp_capable && + SCTP_PR_PRIO_ENABLED(chunk->sinfo.sinfo_flags)) + asoc->sent_cnt_removable--; + sctp_chunk_free(tchunk); + } + } + + /* ii) Set rwnd equal to the newly received a_rwnd minus the + * number of bytes still outstanding after processing the + * Cumulative TSN Ack and the Gap Ack Blocks. + */ + + sack_a_rwnd = ntohl(sack->a_rwnd); + asoc->peer.zero_window_announced = !sack_a_rwnd; + outstanding = q->outstanding_bytes; + + if (outstanding < sack_a_rwnd) + sack_a_rwnd -= outstanding; + else + sack_a_rwnd = 0; + + asoc->peer.rwnd = sack_a_rwnd; + + asoc->stream.si->generate_ftsn(q, sack_ctsn); + + pr_debug("%s: sack cumulative tsn ack:0x%x\n", __func__, sack_ctsn); + pr_debug("%s: cumulative tsn ack of assoc:%p is 0x%x, " + "advertised peer ack point:0x%x\n", __func__, asoc, ctsn, + asoc->adv_peer_ack_point); + + return sctp_outq_is_empty(q); +} + +/* Is the outqueue empty? + * The queue is empty when we have not pending data, no in-flight data + * and nothing pending retransmissions. + */ +int sctp_outq_is_empty(const struct sctp_outq *q) +{ + return q->out_qlen == 0 && q->outstanding_bytes == 0 && + list_empty(&q->retransmit); +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* Go through a transport's transmitted list or the association's retransmit + * list and move chunks that are acked by the Cumulative TSN Ack to q->sacked. + * The retransmit list will not have an associated transport. + * + * I added coherent debug information output. --xguo + * + * Instead of printing 'sacked' or 'kept' for each TSN on the + * transmitted_queue, we print a range: SACKED: TSN1-TSN2, TSN3, TSN4-TSN5. + * KEPT TSN6-TSN7, etc. + */ +static void sctp_check_transmitted(struct sctp_outq *q, + struct list_head *transmitted_queue, + struct sctp_transport *transport, + union sctp_addr *saddr, + struct sctp_sackhdr *sack, + __u32 *highest_new_tsn_in_sack) +{ + struct list_head *lchunk; + struct sctp_chunk *tchunk; + struct list_head tlist; + __u32 tsn; + __u32 sack_ctsn; + __u32 rtt; + __u8 restart_timer = 0; + int bytes_acked = 0; + int migrate_bytes = 0; + bool forward_progress = false; + + sack_ctsn = ntohl(sack->cum_tsn_ack); + + INIT_LIST_HEAD(&tlist); + + /* The while loop will skip empty transmitted queues. */ + while (NULL != (lchunk = sctp_list_dequeue(transmitted_queue))) { + tchunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + + if (sctp_chunk_abandoned(tchunk)) { + /* Move the chunk to abandoned list. */ + sctp_insert_list(&q->abandoned, lchunk); + + /* If this chunk has not been acked, stop + * considering it as 'outstanding'. 
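Rule ii) quoted above boils down to one saturating subtraction: the locally tracked peer rwnd is the advertised a_rwnd minus whatever is still outstanding after this SACK, clamped at zero. A minimal userspace illustration (values in main() are arbitrary):

#include <stdint.h>
#include <stdio.h>

static uint32_t peer_rwnd_after_sack(uint32_t a_rwnd, uint32_t outstanding)
{
	return outstanding < a_rwnd ? a_rwnd - outstanding : 0;
}

int main(void)
{
	/* Peer advertises 64 KiB while 20 KiB of DATA is still unacked. */
	printf("rwnd = %u\n", peer_rwnd_after_sack(65536, 20480));
	/* More in flight than advertised: treat the window as closed. */
	printf("rwnd = %u\n", peer_rwnd_after_sack(1000, 4000));
	return 0;
}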
+ */ + if (transmitted_queue != &q->retransmit && + !tchunk->tsn_gap_acked) { + if (tchunk->transport) + tchunk->transport->flight_size -= + sctp_data_size(tchunk); + q->outstanding_bytes -= sctp_data_size(tchunk); + } + continue; + } + + tsn = ntohl(tchunk->subh.data_hdr->tsn); + if (sctp_acked(sack, tsn)) { + /* If this queue is the retransmit queue, the + * retransmit timer has already reclaimed + * the outstanding bytes for this chunk, so only + * count bytes associated with a transport. + */ + if (transport && !tchunk->tsn_gap_acked) { + /* If this chunk is being used for RTT + * measurement, calculate the RTT and update + * the RTO using this value. + * + * 6.3.1 C5) Karn's algorithm: RTT measurements + * MUST NOT be made using packets that were + * retransmitted (and thus for which it is + * ambiguous whether the reply was for the + * first instance of the packet or a later + * instance). + */ + if (!sctp_chunk_retransmitted(tchunk) && + tchunk->rtt_in_progress) { + tchunk->rtt_in_progress = 0; + rtt = jiffies - tchunk->sent_at; + sctp_transport_update_rto(transport, + rtt); + } + + if (TSN_lte(tsn, sack_ctsn)) { + /* + * SFR-CACC algorithm: + * 2) If the SACK contains gap acks + * and the flag CHANGEOVER_ACTIVE is + * set the receiver of the SACK MUST + * take the following action: + * + * B) For each TSN t being acked that + * has not been acked in any SACK so + * far, set cacc_saw_newack to 1 for + * the destination that the TSN was + * sent to. + */ + if (sack->num_gap_ack_blocks && + q->asoc->peer.primary_path->cacc. + changeover_active) + transport->cacc.cacc_saw_newack + = 1; + } + } + + /* If the chunk hasn't been marked as ACKED, + * mark it and account bytes_acked if the + * chunk had a valid transport (it will not + * have a transport if ASCONF had deleted it + * while DATA was outstanding). + */ + if (!tchunk->tsn_gap_acked) { + tchunk->tsn_gap_acked = 1; + if (TSN_lt(*highest_new_tsn_in_sack, tsn)) + *highest_new_tsn_in_sack = tsn; + bytes_acked += sctp_data_size(tchunk); + if (!tchunk->transport) + migrate_bytes += sctp_data_size(tchunk); + forward_progress = true; + } + + if (TSN_lte(tsn, sack_ctsn)) { + /* RFC 2960 6.3.2 Retransmission Timer Rules + * + * R3) Whenever a SACK is received + * that acknowledges the DATA chunk + * with the earliest outstanding TSN + * for that address, restart T3-rtx + * timer for that address with its + * current RTO. + */ + restart_timer = 1; + forward_progress = true; + + list_add_tail(&tchunk->transmitted_list, + &q->sacked); + } else { + /* RFC2960 7.2.4, sctpimpguide-05 2.8.2 + * M2) Each time a SACK arrives reporting + * 'Stray DATA chunk(s)' record the highest TSN + * reported as newly acknowledged, call this + * value 'HighestTSNinSack'. A newly + * acknowledged DATA chunk is one not + * previously acknowledged in a SACK. + * + * When the SCTP sender of data receives a SACK + * chunk that acknowledges, for the first time, + * the receipt of a DATA chunk, all the still + * unacknowledged DATA chunks whose TSN is + * older than that newly acknowledged DATA + * chunk, are qualified as 'Stray DATA chunks'. 
+ */ + list_add_tail(lchunk, &tlist); + } + } else { + if (tchunk->tsn_gap_acked) { + pr_debug("%s: receiver reneged on data TSN:0x%x\n", + __func__, tsn); + + tchunk->tsn_gap_acked = 0; + + if (tchunk->transport) + bytes_acked -= sctp_data_size(tchunk); + + /* RFC 2960 6.3.2 Retransmission Timer Rules + * + * R4) Whenever a SACK is received missing a + * TSN that was previously acknowledged via a + * Gap Ack Block, start T3-rtx for the + * destination address to which the DATA + * chunk was originally + * transmitted if it is not already running. + */ + restart_timer = 1; + } + + list_add_tail(lchunk, &tlist); + } + } + + if (transport) { + if (bytes_acked) { + struct sctp_association *asoc = transport->asoc; + + /* We may have counted DATA that was migrated + * to this transport due to DEL-IP operation. + * Subtract those bytes, since the were never + * send on this transport and shouldn't be + * credited to this transport. + */ + bytes_acked -= migrate_bytes; + + /* 8.2. When an outstanding TSN is acknowledged, + * the endpoint shall clear the error counter of + * the destination transport address to which the + * DATA chunk was last sent. + * The association's overall error counter is + * also cleared. + */ + transport->error_count = 0; + transport->asoc->overall_error_count = 0; + forward_progress = true; + + /* + * While in SHUTDOWN PENDING, we may have started + * the T5 shutdown guard timer after reaching the + * retransmission limit. Stop that timer as soon + * as the receiver acknowledged any data. + */ + if (asoc->state == SCTP_STATE_SHUTDOWN_PENDING && + del_timer(&asoc->timers + [SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD])) + sctp_association_put(asoc); + + /* Mark the destination transport address as + * active if it is not so marked. + */ + if ((transport->state == SCTP_INACTIVE || + transport->state == SCTP_UNCONFIRMED) && + sctp_cmp_addr_exact(&transport->ipaddr, saddr)) { + sctp_assoc_control_transport( + transport->asoc, + transport, + SCTP_TRANSPORT_UP, + SCTP_RECEIVED_SACK); + } + + sctp_transport_raise_cwnd(transport, sack_ctsn, + bytes_acked); + + transport->flight_size -= bytes_acked; + if (transport->flight_size == 0) + transport->partial_bytes_acked = 0; + q->outstanding_bytes -= bytes_acked + migrate_bytes; + } else { + /* RFC 2960 6.1, sctpimpguide-06 2.15.2 + * When a sender is doing zero window probing, it + * should not timeout the association if it continues + * to receive new packets from the receiver. The + * reason is that the receiver MAY keep its window + * closed for an indefinite time. + * A sender is doing zero window probing when the + * receiver's advertised window is zero, and there is + * only one data chunk in flight to the receiver. + * + * Allow the association to timeout while in SHUTDOWN + * PENDING or SHUTDOWN RECEIVED in case the receiver + * stays in zero window mode forever. + */ + if (!q->asoc->peer.rwnd && + !list_empty(&tlist) && + (sack_ctsn+2 == q->asoc->next_tsn) && + q->asoc->state < SCTP_STATE_SHUTDOWN_PENDING) { + pr_debug("%s: sack received for zero window " + "probe:%u\n", __func__, sack_ctsn); + + q->asoc->overall_error_count = 0; + transport->error_count = 0; + } + } + + /* RFC 2960 6.3.2 Retransmission Timer Rules + * + * R2) Whenever all outstanding data sent to an address have + * been acknowledged, turn off the T3-rtx timer of that + * address. 
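The RTT sample taken in sctp_check_transmitted() above (only for chunks never retransmitted, per Karn's rule) feeds sctp_transport_update_rto(), which applies the RFC 4960 6.3.1 smoothing. A userspace sketch of that smoothing in floating point for readability; the kernel works in jiffies with configurable alpha/beta shifts, and the RTO.Min value used below is the protocol default, not read from the stack.

#include <math.h>
#include <stdio.h>

struct rto_state {
	double srtt, rttvar, rto;
	int    first;			/* no measurement taken yet */
};

static void rto_update(struct rto_state *s, double r /* seconds */)
{
	const double alpha = 0.125, beta = 0.25;	/* RFC 4960 defaults */

	if (s->first) {			/* C2: first measurement */
		s->srtt = r;
		s->rttvar = r / 2.0;
		s->first = 0;
	} else {			/* C3: subsequent measurements */
		s->rttvar = (1 - beta) * s->rttvar + beta * fabs(s->srtt - r);
		s->srtt   = (1 - alpha) * s->srtt + alpha * r;
	}
	s->rto = s->srtt + 4 * s->rttvar;
	if (s->rto < 1.0)		/* clamp to RTO.Min (1 second) */
		s->rto = 1.0;
}

int main(void)
{
	struct rto_state s = { .first = 1 };

	rto_update(&s, 0.200);
	rto_update(&s, 0.250);
	printf("srtt=%.3f rttvar=%.3f rto=%.3f\n", s.srtt, s.rttvar, s.rto);
	return 0;
}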
+ */ + if (!transport->flight_size) { + if (del_timer(&transport->T3_rtx_timer)) + sctp_transport_put(transport); + } else if (restart_timer) { + if (!mod_timer(&transport->T3_rtx_timer, + jiffies + transport->rto)) + sctp_transport_hold(transport); + } + + if (forward_progress) { + if (transport->dst) + sctp_transport_dst_confirm(transport); + } + } + + list_splice(&tlist, transmitted_queue); +} + +/* Mark chunks as missing and consequently may get retransmitted. */ +static void sctp_mark_missing(struct sctp_outq *q, + struct list_head *transmitted_queue, + struct sctp_transport *transport, + __u32 highest_new_tsn_in_sack, + int count_of_newacks) +{ + struct sctp_chunk *chunk; + __u32 tsn; + char do_fast_retransmit = 0; + struct sctp_association *asoc = q->asoc; + struct sctp_transport *primary = asoc->peer.primary_path; + + list_for_each_entry(chunk, transmitted_queue, transmitted_list) { + + tsn = ntohl(chunk->subh.data_hdr->tsn); + + /* RFC 2960 7.2.4, sctpimpguide-05 2.8.2 M3) Examine all + * 'Unacknowledged TSN's', if the TSN number of an + * 'Unacknowledged TSN' is smaller than the 'HighestTSNinSack' + * value, increment the 'TSN.Missing.Report' count on that + * chunk if it has NOT been fast retransmitted or marked for + * fast retransmit already. + */ + if (chunk->fast_retransmit == SCTP_CAN_FRTX && + !chunk->tsn_gap_acked && + TSN_lt(tsn, highest_new_tsn_in_sack)) { + + /* SFR-CACC may require us to skip marking + * this chunk as missing. + */ + if (!transport || !sctp_cacc_skip(primary, + chunk->transport, + count_of_newacks, tsn)) { + chunk->tsn_missing_report++; + + pr_debug("%s: tsn:0x%x missing counter:%d\n", + __func__, tsn, chunk->tsn_missing_report); + } + } + /* + * M4) If any DATA chunk is found to have a + * 'TSN.Missing.Report' + * value larger than or equal to 3, mark that chunk for + * retransmission and start the fast retransmit procedure. + */ + + if (chunk->tsn_missing_report >= 3) { + chunk->fast_retransmit = SCTP_NEED_FRTX; + do_fast_retransmit = 1; + } + } + + if (transport) { + if (do_fast_retransmit) + sctp_retransmit(q, transport, SCTP_RTXR_FAST_RTX); + + pr_debug("%s: transport:%p, cwnd:%d, ssthresh:%d, " + "flight_size:%d, pba:%d\n", __func__, transport, + transport->cwnd, transport->ssthresh, + transport->flight_size, transport->partial_bytes_acked); + } +} + +/* Is the given TSN acked by this packet? */ +static int sctp_acked(struct sctp_sackhdr *sack, __u32 tsn) +{ + __u32 ctsn = ntohl(sack->cum_tsn_ack); + union sctp_sack_variable *frags; + __u16 tsn_offset, blocks; + int i; + + if (TSN_lte(tsn, ctsn)) + goto pass; + + /* 3.3.4 Selective Acknowledgment (SACK) (3): + * + * Gap Ack Blocks: + * These fields contain the Gap Ack Blocks. They are repeated + * for each Gap Ack Block up to the number of Gap Ack Blocks + * defined in the Number of Gap Ack Blocks field. All DATA + * chunks with TSNs greater than or equal to (Cumulative TSN + * Ack + Gap Ack Block Start) and less than or equal to + * (Cumulative TSN Ack + Gap Ack Block End) of each Gap Ack + * Block are assumed to have been received correctly. 
+ */ + + frags = sack->variable; + blocks = ntohs(sack->num_gap_ack_blocks); + tsn_offset = tsn - ctsn; + for (i = 0; i < blocks; ++i) { + if (tsn_offset >= ntohs(frags[i].gab.start) && + tsn_offset <= ntohs(frags[i].gab.end)) + goto pass; + } + + return 0; +pass: + return 1; +} + +static inline int sctp_get_skip_pos(struct sctp_fwdtsn_skip *skiplist, + int nskips, __be16 stream) +{ + int i; + + for (i = 0; i < nskips; i++) { + if (skiplist[i].stream == stream) + return i; + } + return i; +} + +/* Create and add a fwdtsn chunk to the outq's control queue if needed. */ +void sctp_generate_fwdtsn(struct sctp_outq *q, __u32 ctsn) +{ + struct sctp_association *asoc = q->asoc; + struct sctp_chunk *ftsn_chunk = NULL; + struct sctp_fwdtsn_skip ftsn_skip_arr[10]; + int nskips = 0; + int skip_pos = 0; + __u32 tsn; + struct sctp_chunk *chunk; + struct list_head *lchunk, *temp; + + if (!asoc->peer.prsctp_capable) + return; + + /* PR-SCTP C1) Let SackCumAck be the Cumulative TSN ACK carried in the + * received SACK. + * + * If (Advanced.Peer.Ack.Point < SackCumAck), then update + * Advanced.Peer.Ack.Point to be equal to SackCumAck. + */ + if (TSN_lt(asoc->adv_peer_ack_point, ctsn)) + asoc->adv_peer_ack_point = ctsn; + + /* PR-SCTP C2) Try to further advance the "Advanced.Peer.Ack.Point" + * locally, that is, to move "Advanced.Peer.Ack.Point" up as long as + * the chunk next in the out-queue space is marked as "abandoned" as + * shown in the following example: + * + * Assuming that a SACK arrived with the Cumulative TSN ACK 102 + * and the Advanced.Peer.Ack.Point is updated to this value: + * + * out-queue at the end of ==> out-queue after Adv.Ack.Point + * normal SACK processing local advancement + * ... ... + * Adv.Ack.Pt-> 102 acked 102 acked + * 103 abandoned 103 abandoned + * 104 abandoned Adv.Ack.P-> 104 abandoned + * 105 105 + * 106 acked 106 acked + * ... ... + * + * In this example, the data sender successfully advanced the + * "Advanced.Peer.Ack.Point" from 102 to 104 locally. + */ + list_for_each_safe(lchunk, temp, &q->abandoned) { + chunk = list_entry(lchunk, struct sctp_chunk, + transmitted_list); + tsn = ntohl(chunk->subh.data_hdr->tsn); + + /* Remove any chunks in the abandoned queue that are acked by + * the ctsn. + */ + if (TSN_lte(tsn, ctsn)) { + list_del_init(lchunk); + sctp_chunk_free(chunk); + } else { + if (TSN_lte(tsn, asoc->adv_peer_ack_point+1)) { + asoc->adv_peer_ack_point = tsn; + if (chunk->chunk_hdr->flags & + SCTP_DATA_UNORDERED) + continue; + skip_pos = sctp_get_skip_pos(&ftsn_skip_arr[0], + nskips, + chunk->subh.data_hdr->stream); + ftsn_skip_arr[skip_pos].stream = + chunk->subh.data_hdr->stream; + ftsn_skip_arr[skip_pos].ssn = + chunk->subh.data_hdr->ssn; + if (skip_pos == nskips) + nskips++; + if (nskips == 10) + break; + } else + break; + } + } + + /* PR-SCTP C3) If, after step C1 and C2, the "Advanced.Peer.Ack.Point" + * is greater than the Cumulative TSN ACK carried in the received + * SACK, the data sender MUST send the data receiver a FORWARD TSN + * chunk containing the latest value of the + * "Advanced.Peer.Ack.Point". + * + * C4) For each "abandoned" TSN the sender of the FORWARD TSN SHOULD + * list each stream and sequence number in the forwarded TSN. This + * information will enable the receiver to easily find any + * stranded TSN's waiting on stream reorder queues. Each stream + * SHOULD only be reported once; this means that if multiple + * abandoned messages occur in the same stream then only the + * highest abandoned stream sequence number is reported. 
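The acked test in sctp_acked() above reduces to: a TSN is acked if it is at or below the cumulative ack, or if its offset from the cumulative ack lands inside one of the Gap Ack Blocks. A standalone userspace sketch (host byte order assumed; the kernel converts from network order first):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct gab { uint16_t start, end; };

static bool tsn_lte(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) <= 0;	/* serial-number comparison */
}

static bool tsn_acked(uint32_t tsn, uint32_t ctsn,
		      const struct gab *gabs, int ngabs)
{
	if (tsn_lte(tsn, ctsn))
		return true;

	uint16_t off = (uint16_t)(tsn - ctsn);
	for (int i = 0; i < ngabs; i++)
		if (off >= gabs[i].start && off <= gabs[i].end)
			return true;
	return false;
}

int main(void)
{
	struct gab g[] = { { 2, 4 } };		/* acks ctsn+2 .. ctsn+4 */

	printf("%d %d %d\n",
	       tsn_acked(100, 100, g, 1),	/* covered by cum ack    */
	       tsn_acked(103, 100, g, 1),	/* inside the gap block  */
	       tsn_acked(101, 100, g, 1));	/* the missing TSN       */
	return 0;
}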
If the + * total size of the FORWARD TSN does NOT fit in a single MTU then + * the sender of the FORWARD TSN SHOULD lower the + * Advanced.Peer.Ack.Point to the last TSN that will fit in a + * single MTU. + */ + if (asoc->adv_peer_ack_point > ctsn) + ftsn_chunk = sctp_make_fwdtsn(asoc, asoc->adv_peer_ack_point, + nskips, &ftsn_skip_arr[0]); + + if (ftsn_chunk) { + list_add_tail(&ftsn_chunk->list, &q->control_chunk_list); + SCTP_INC_STATS(asoc->base.net, SCTP_MIB_OUTCTRLCHUNKS); + } +} diff --git a/net/sctp/primitive.c b/net/sctp/primitive.c new file mode 100644 index 000000000..782d673c3 --- /dev/null +++ b/net/sctp/primitive.c @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * + * This file is part of the SCTP kernel implementation + * + * These functions implement the SCTP primitive functions from Section 10. + * + * Note that the descriptions from the specification are USER level + * functions--this file is the functions which populate the struct proto + * for SCTP which is the BOTTOM of the sockets interface. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Narasimha Budihal + * Karl Knutson + * Ardelle Fan + * Kevin Gao + */ + +#include +#include /* For struct list_head */ +#include +#include +#include /* For struct timeval */ +#include +#include +#include +#include + +#define DECLARE_PRIMITIVE(name) \ +/* This is called in the code as sctp_primitive_ ## name. */ \ +int sctp_primitive_ ## name(struct net *net, struct sctp_association *asoc, \ + void *arg) { \ + int error = 0; \ + enum sctp_event_type event_type; union sctp_subtype subtype; \ + enum sctp_state state; \ + struct sctp_endpoint *ep; \ + \ + event_type = SCTP_EVENT_T_PRIMITIVE; \ + subtype = SCTP_ST_PRIMITIVE(SCTP_PRIMITIVE_ ## name); \ + state = asoc ? asoc->state : SCTP_STATE_CLOSED; \ + ep = asoc ? asoc->ep : NULL; \ + \ + error = sctp_do_sm(net, event_type, subtype, state, ep, asoc, \ + arg, GFP_KERNEL); \ + return error; \ +} + +/* 10.1 ULP-to-SCTP + * B) Associate + * + * Format: ASSOCIATE(local SCTP instance name, destination transport addr, + * outbound stream count) + * -> association id [,destination transport addr list] [,outbound stream + * count] + * + * This primitive allows the upper layer to initiate an association to a + * specific peer endpoint. + * + * This version assumes that asoc is fully populated with the initial + * parameters. We then return a traditional kernel indicator of + * success or failure. + */ + +/* This is called in the code as sctp_primitive_ASSOCIATE. */ + +DECLARE_PRIMITIVE(ASSOCIATE) + +/* 10.1 ULP-to-SCTP + * C) Shutdown + * + * Format: SHUTDOWN(association id) + * -> result + * + * Gracefully closes an association. Any locally queued user data + * will be delivered to the peer. The association will be terminated only + * after the peer acknowledges all the SCTP packets sent. A success code + * will be returned on successful termination of the association. If + * attempting to terminate the association results in a failure, an error + * code shall be returned. + */ + +DECLARE_PRIMITIVE(SHUTDOWN); + +/* 10.1 ULP-to-SCTP + * C) Abort + * + * Format: Abort(association id [, cause code]) + * -> result + * + * Ungracefully closes an association. Any locally queued user data + * will be discarded and an ABORT chunk is sent to the peer. 
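A small userspace sketch of the Advanced.Peer.Ack.Point walk and the 102-to-104 example shown above. It is deliberately simplified: the abandoned TSNs are a plain sorted array and contiguity is tested with +1, whereas the kernel walks the abandoned chunk list with serial-number comparisons and also collects per-stream skip entries.

#include <stdint.h>
#include <stdio.h>

static uint32_t advance_peer_ack_point(uint32_t adv_ack_point,
				       const uint32_t *abandoned, int n)
{
	for (int i = 0; i < n; i++) {
		if (abandoned[i] != adv_ack_point + 1)
			break;			/* gap: stop advancing */
		adv_ack_point = abandoned[i];
	}
	return adv_ack_point;
}

int main(void)
{
	/* Matches the worked example: 102 acked, 103 and 104 abandoned,
	 * 105 still outstanding, so the ack point moves from 102 to 104. */
	uint32_t abandoned[] = { 103, 104, 106 };

	printf("Adv.Ack.Point = %u\n",
	       advance_peer_ack_point(102, abandoned, 3));
	return 0;
}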
A success + * code will be returned on successful abortion of the association. If + * attempting to abort the association results in a failure, an error + * code shall be returned. + */ + +DECLARE_PRIMITIVE(ABORT); + +/* 10.1 ULP-to-SCTP + * E) Send + * + * Format: SEND(association id, buffer address, byte count [,context] + * [,stream id] [,life time] [,destination transport address] + * [,unorder flag] [,no-bundle flag] [,payload protocol-id] ) + * -> result + * + * This is the main method to send user data via SCTP. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * o buffer address - the location where the user message to be + * transmitted is stored; + * + * o byte count - The size of the user data in number of bytes; + * + * Optional attributes: + * + * o context - an optional 32 bit integer that will be carried in the + * sending failure notification to the ULP if the transportation of + * this User Message fails. + * + * o stream id - to indicate which stream to send the data on. If not + * specified, stream 0 will be used. + * + * o life time - specifies the life time of the user data. The user data + * will not be sent by SCTP after the life time expires. This + * parameter can be used to avoid efforts to transmit stale + * user messages. SCTP notifies the ULP if the data cannot be + * initiated to transport (i.e. sent to the destination via SCTP's + * send primitive) within the life time variable. However, the + * user data will be transmitted if SCTP has attempted to transmit a + * chunk before the life time expired. + * + * o destination transport address - specified as one of the destination + * transport addresses of the peer endpoint to which this packet + * should be sent. Whenever possible, SCTP should use this destination + * transport address for sending the packets, instead of the current + * primary path. + * + * o unorder flag - this flag, if present, indicates that the user + * would like the data delivered in an unordered fashion to the peer + * (i.e., the U flag is set to 1 on all DATA chunks carrying this + * message). + * + * o no-bundle flag - instructs SCTP not to bundle this user data with + * other outbound DATA chunks. SCTP MAY still bundle even when + * this flag is present, when faced with network congestion. + * + * o payload protocol-id - A 32 bit unsigned integer that is to be + * passed to the peer indicating the type of payload protocol data + * being transmitted. This value is passed as opaque data by SCTP. + */ + +DECLARE_PRIMITIVE(SEND); + +/* 10.1 ULP-to-SCTP + * J) Request Heartbeat + * + * Format: REQUESTHEARTBEAT(association id, destination transport address) + * + * -> result + * + * Instructs the local endpoint to perform a HeartBeat on the specified + * destination transport address of the given association. The returned + * result should indicate whether the transmission of the HEARTBEAT + * chunk to the destination address is successful. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * o destination transport address - the transport address of the + * association on which a heartbeat should be issued. + */ + +DECLARE_PRIMITIVE(REQUESTHEARTBEAT); + +/* ADDIP +* 3.1.1 Address Configuration Change Chunk (ASCONF) +* +* This chunk is used to communicate to the remote endpoint one of the +* configuration change requests that MUST be acknowledged. 
The +* information carried in the ASCONF Chunk uses the form of a +* Type-Length-Value (TLV), as described in "3.2.1 Optional/ +* Variable-length Parameter Format" in RFC2960 [5], forall variable +* parameters. +*/ + +DECLARE_PRIMITIVE(ASCONF); + +/* RE-CONFIG 5.1 */ +DECLARE_PRIMITIVE(RECONF); diff --git a/net/sctp/proc.c b/net/sctp/proc.c new file mode 100644 index 000000000..ec00ee75d --- /dev/null +++ b/net/sctp/proc.c @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 2003 International Business Machines, Corp. + * + * This file is part of the SCTP kernel implementation + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Sridhar Samudrala + */ + +#include +#include +#include +#include +#include +#include /* for snmp_fold_field */ + +static const struct snmp_mib sctp_snmp_list[] = { + SNMP_MIB_ITEM("SctpCurrEstab", SCTP_MIB_CURRESTAB), + SNMP_MIB_ITEM("SctpActiveEstabs", SCTP_MIB_ACTIVEESTABS), + SNMP_MIB_ITEM("SctpPassiveEstabs", SCTP_MIB_PASSIVEESTABS), + SNMP_MIB_ITEM("SctpAborteds", SCTP_MIB_ABORTEDS), + SNMP_MIB_ITEM("SctpShutdowns", SCTP_MIB_SHUTDOWNS), + SNMP_MIB_ITEM("SctpOutOfBlues", SCTP_MIB_OUTOFBLUES), + SNMP_MIB_ITEM("SctpChecksumErrors", SCTP_MIB_CHECKSUMERRORS), + SNMP_MIB_ITEM("SctpOutCtrlChunks", SCTP_MIB_OUTCTRLCHUNKS), + SNMP_MIB_ITEM("SctpOutOrderChunks", SCTP_MIB_OUTORDERCHUNKS), + SNMP_MIB_ITEM("SctpOutUnorderChunks", SCTP_MIB_OUTUNORDERCHUNKS), + SNMP_MIB_ITEM("SctpInCtrlChunks", SCTP_MIB_INCTRLCHUNKS), + SNMP_MIB_ITEM("SctpInOrderChunks", SCTP_MIB_INORDERCHUNKS), + SNMP_MIB_ITEM("SctpInUnorderChunks", SCTP_MIB_INUNORDERCHUNKS), + SNMP_MIB_ITEM("SctpFragUsrMsgs", SCTP_MIB_FRAGUSRMSGS), + SNMP_MIB_ITEM("SctpReasmUsrMsgs", SCTP_MIB_REASMUSRMSGS), + SNMP_MIB_ITEM("SctpOutSCTPPacks", SCTP_MIB_OUTSCTPPACKS), + SNMP_MIB_ITEM("SctpInSCTPPacks", SCTP_MIB_INSCTPPACKS), + SNMP_MIB_ITEM("SctpT1InitExpireds", SCTP_MIB_T1_INIT_EXPIREDS), + SNMP_MIB_ITEM("SctpT1CookieExpireds", SCTP_MIB_T1_COOKIE_EXPIREDS), + SNMP_MIB_ITEM("SctpT2ShutdownExpireds", SCTP_MIB_T2_SHUTDOWN_EXPIREDS), + SNMP_MIB_ITEM("SctpT3RtxExpireds", SCTP_MIB_T3_RTX_EXPIREDS), + SNMP_MIB_ITEM("SctpT4RtoExpireds", SCTP_MIB_T4_RTO_EXPIREDS), + SNMP_MIB_ITEM("SctpT5ShutdownGuardExpireds", SCTP_MIB_T5_SHUTDOWN_GUARD_EXPIREDS), + SNMP_MIB_ITEM("SctpDelaySackExpireds", SCTP_MIB_DELAY_SACK_EXPIREDS), + SNMP_MIB_ITEM("SctpAutocloseExpireds", SCTP_MIB_AUTOCLOSE_EXPIREDS), + SNMP_MIB_ITEM("SctpT3Retransmits", SCTP_MIB_T3_RETRANSMITS), + SNMP_MIB_ITEM("SctpPmtudRetransmits", SCTP_MIB_PMTUD_RETRANSMITS), + SNMP_MIB_ITEM("SctpFastRetransmits", SCTP_MIB_FAST_RETRANSMITS), + SNMP_MIB_ITEM("SctpInPktSoftirq", SCTP_MIB_IN_PKT_SOFTIRQ), + SNMP_MIB_ITEM("SctpInPktBacklog", SCTP_MIB_IN_PKT_BACKLOG), + SNMP_MIB_ITEM("SctpInPktDiscards", SCTP_MIB_IN_PKT_DISCARDS), + SNMP_MIB_ITEM("SctpInDataChunkDiscards", SCTP_MIB_IN_DATA_CHUNK_DISCARDS), + SNMP_MIB_SENTINEL +}; + +/* Display sctp snmp mib statistics(/proc/net/sctp/snmp). 
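+ *
+ * As a purely illustrative aside, these counters can be read from user
+ * space with plain stdio; the sketch below assumes only the
+ * "Name<whitespace>value" layout emitted by the show function that follows.
+ *
+ *    #include <stdio.h>
+ *
+ *    int main(void)
+ *    {
+ *        char name[64];
+ *        unsigned long val;
+ *        FILE *f = fopen("/proc/net/sctp/snmp", "r");
+ *
+ *        if (!f)
+ *            return 1;
+ *        // Each row is "<MIB name> <counter value>".
+ *        while (fscanf(f, "%63s %lu", name, &val) == 2)
+ *            printf("%-32s %lu\n", name, val);
+ *        fclose(f);
+ *        return 0;
+ *    }
+ *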
*/ +static int sctp_snmp_seq_show(struct seq_file *seq, void *v) +{ + unsigned long buff[SCTP_MIB_MAX]; + struct net *net = seq->private; + int i; + + memset(buff, 0, sizeof(unsigned long) * SCTP_MIB_MAX); + + snmp_get_cpu_field_batch(buff, sctp_snmp_list, + net->sctp.sctp_statistics); + for (i = 0; sctp_snmp_list[i].name; i++) + seq_printf(seq, "%-32s\t%ld\n", sctp_snmp_list[i].name, + buff[i]); + + return 0; +} + +/* Dump local addresses of an association/endpoint. */ +static void sctp_seq_dump_local_addrs(struct seq_file *seq, struct sctp_ep_common *epb) +{ + struct sctp_association *asoc; + struct sctp_sockaddr_entry *laddr; + struct sctp_transport *peer; + union sctp_addr *addr, *primary = NULL; + struct sctp_af *af; + + if (epb->type == SCTP_EP_TYPE_ASSOCIATION) { + asoc = sctp_assoc(epb); + + peer = asoc->peer.primary_path; + if (unlikely(peer == NULL)) { + WARN(1, "Association %p with NULL primary path!\n", asoc); + return; + } + + primary = &peer->saddr; + } + + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &epb->bind_addr.address_list, list) { + if (!laddr->valid) + continue; + + addr = &laddr->a; + af = sctp_get_af_specific(addr->sa.sa_family); + if (primary && af->cmp_addr(addr, primary)) { + seq_printf(seq, "*"); + } + af->seq_dump_addr(seq, addr); + } + rcu_read_unlock(); +} + +/* Dump remote addresses of an association. */ +static void sctp_seq_dump_remote_addrs(struct seq_file *seq, struct sctp_association *assoc) +{ + struct sctp_transport *transport; + union sctp_addr *addr, *primary; + struct sctp_af *af; + + primary = &assoc->peer.primary_addr; + list_for_each_entry_rcu(transport, &assoc->peer.transport_addr_list, + transports) { + addr = &transport->ipaddr; + + af = sctp_get_af_specific(addr->sa.sa_family); + if (af->cmp_addr(addr, primary)) { + seq_printf(seq, "*"); + } + af->seq_dump_addr(seq, addr); + } +} + +static void *sctp_eps_seq_start(struct seq_file *seq, loff_t *pos) +{ + if (*pos >= sctp_ep_hashsize) + return NULL; + + if (*pos < 0) + *pos = 0; + + if (*pos == 0) + seq_printf(seq, " ENDPT SOCK STY SST HBKT LPORT UID INODE LADDRS\n"); + + return (void *)pos; +} + +static void sctp_eps_seq_stop(struct seq_file *seq, void *v) +{ +} + + +static void *sctp_eps_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + if (++*pos >= sctp_ep_hashsize) + return NULL; + + return pos; +} + + +/* Display sctp endpoints (/proc/net/sctp/eps). 
*/ +static int sctp_eps_seq_show(struct seq_file *seq, void *v) +{ + struct sctp_hashbucket *head; + struct sctp_endpoint *ep; + struct sock *sk; + int hash = *(loff_t *)v; + + if (hash >= sctp_ep_hashsize) + return -ENOMEM; + + head = &sctp_ep_hashtable[hash]; + read_lock_bh(&head->lock); + sctp_for_each_hentry(ep, &head->chain) { + sk = ep->base.sk; + if (!net_eq(sock_net(sk), seq_file_net(seq))) + continue; + seq_printf(seq, "%8pK %8pK %-3d %-3d %-4d %-5d %5u %5lu ", ep, sk, + sctp_sk(sk)->type, sk->sk_state, hash, + ep->base.bind_addr.port, + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), + sock_i_ino(sk)); + + sctp_seq_dump_local_addrs(seq, &ep->base); + seq_printf(seq, "\n"); + } + read_unlock_bh(&head->lock); + + return 0; +} + +static const struct seq_operations sctp_eps_ops = { + .start = sctp_eps_seq_start, + .next = sctp_eps_seq_next, + .stop = sctp_eps_seq_stop, + .show = sctp_eps_seq_show, +}; + +struct sctp_ht_iter { + struct seq_net_private p; + struct rhashtable_iter hti; +}; + +static void *sctp_transport_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct sctp_ht_iter *iter = seq->private; + + sctp_transport_walk_start(&iter->hti); + + return sctp_transport_get_idx(seq_file_net(seq), &iter->hti, *pos); +} + +static void sctp_transport_seq_stop(struct seq_file *seq, void *v) +{ + struct sctp_ht_iter *iter = seq->private; + + if (v && v != SEQ_START_TOKEN) { + struct sctp_transport *transport = v; + + sctp_transport_put(transport); + } + + sctp_transport_walk_stop(&iter->hti); +} + +static void *sctp_transport_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct sctp_ht_iter *iter = seq->private; + + if (v && v != SEQ_START_TOKEN) { + struct sctp_transport *transport = v; + + sctp_transport_put(transport); + } + + ++*pos; + + return sctp_transport_get_next(seq_file_net(seq), &iter->hti); +} + +/* Display sctp associations (/proc/net/sctp/assocs). 
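+ *
+ * An association only has a row here while it exists; as an illustrative
+ * sketch (client side, assuming some SCTP peer really is listening on
+ * 127.0.0.1:5000 - both the address and the port are arbitrary examples),
+ * the following adds one row for the lifetime of the connected socket:
+ *
+ *    #include <netinet/in.h>
+ *    #include <string.h>
+ *    #include <sys/socket.h>
+ *    #include <unistd.h>
+ *
+ *    int main(void)
+ *    {
+ *        struct sockaddr_in peer;
+ *        int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
+ *
+ *        if (fd < 0)
+ *            return 1;
+ *        memset(&peer, 0, sizeof(peer));
+ *        peer.sin_family = AF_INET;
+ *        peer.sin_port = htons(5000);
+ *        peer.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+ *        // connect() drives the ASSOCIATE primitive; the association is
+ *        // then listed in /proc/net/sctp/assocs until it is shut down.
+ *        if (connect(fd, (struct sockaddr *)&peer, sizeof(peer)) == 0) {
+ *            sleep(10);        // inspect the proc file in this window
+ *            close(fd);        // graceful close -> SHUTDOWN primitive
+ *        }
+ *        return 0;
+ *    }
+ *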
*/ +static int sctp_assocs_seq_show(struct seq_file *seq, void *v) +{ + struct sctp_transport *transport; + struct sctp_association *assoc; + struct sctp_ep_common *epb; + struct sock *sk; + + if (v == SEQ_START_TOKEN) { + seq_printf(seq, " ASSOC SOCK STY SST ST HBKT " + "ASSOC-ID TX_QUEUE RX_QUEUE UID INODE LPORT " + "RPORT LADDRS <-> RADDRS " + "HBINT INS OUTS MAXRT T1X T2X RTXC " + "wmema wmemq sndbuf rcvbuf\n"); + return 0; + } + + transport = (struct sctp_transport *)v; + assoc = transport->asoc; + epb = &assoc->base; + sk = epb->sk; + + seq_printf(seq, + "%8pK %8pK %-3d %-3d %-2d %-4d " + "%4d %8d %8d %7u %5lu %-5d %5d ", + assoc, sk, sctp_sk(sk)->type, sk->sk_state, + assoc->state, 0, + assoc->assoc_id, + assoc->sndbuf_used, + atomic_read(&assoc->rmem_alloc), + from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)), + sock_i_ino(sk), + epb->bind_addr.port, + assoc->peer.port); + seq_printf(seq, " "); + sctp_seq_dump_local_addrs(seq, epb); + seq_printf(seq, "<-> "); + sctp_seq_dump_remote_addrs(seq, assoc); + seq_printf(seq, "\t%8lu %5d %5d %4d %4d %4d %8d " + "%8d %8d %8d %8d", + assoc->hbinterval, assoc->stream.incnt, + assoc->stream.outcnt, assoc->max_retrans, + assoc->init_retries, assoc->shutdown_retries, + assoc->rtx_data_chunks, + refcount_read(&sk->sk_wmem_alloc), + READ_ONCE(sk->sk_wmem_queued), + sk->sk_sndbuf, + sk->sk_rcvbuf); + seq_printf(seq, "\n"); + + return 0; +} + +static const struct seq_operations sctp_assoc_ops = { + .start = sctp_transport_seq_start, + .next = sctp_transport_seq_next, + .stop = sctp_transport_seq_stop, + .show = sctp_assocs_seq_show, +}; + +static int sctp_remaddr_seq_show(struct seq_file *seq, void *v) +{ + struct sctp_association *assoc; + struct sctp_transport *transport, *tsp; + + if (v == SEQ_START_TOKEN) { + seq_printf(seq, "ADDR ASSOC_ID HB_ACT RTO MAX_PATH_RTX " + "REM_ADDR_RTX START STATE\n"); + return 0; + } + + transport = (struct sctp_transport *)v; + assoc = transport->asoc; + + list_for_each_entry_rcu(tsp, &assoc->peer.transport_addr_list, + transports) { + /* + * The remote address (ADDR) + */ + tsp->af_specific->seq_dump_addr(seq, &tsp->ipaddr); + seq_printf(seq, " "); + /* + * The association ID (ASSOC_ID) + */ + seq_printf(seq, "%d ", tsp->asoc->assoc_id); + + /* + * If the Heartbeat is active (HB_ACT) + * Note: 1 = Active, 0 = Inactive + */ + seq_printf(seq, "%d ", timer_pending(&tsp->hb_timer)); + + /* + * Retransmit time out (RTO) + */ + seq_printf(seq, "%lu ", tsp->rto); + + /* + * Maximum path retransmit count (PATH_MAX_RTX) + */ + seq_printf(seq, "%d ", tsp->pathmaxrxt); + + /* + * remote address retransmit count (REM_ADDR_RTX) + * Note: We don't have a way to tally this at the moment + * so lets just leave it as zero for the moment + */ + seq_puts(seq, "0 "); + + /* + * remote address start time (START). This is also not + * currently implemented, but we can record it with a + * jiffies marker in a subsequent patch + */ + seq_puts(seq, "0 "); + + /* + * The current state of this destination. I.e. + * SCTP_ACTIVE, SCTP_INACTIVE, ... + */ + seq_printf(seq, "%d", tsp->state); + + seq_printf(seq, "\n"); + } + + return 0; +} + +static const struct seq_operations sctp_remaddr_ops = { + .start = sctp_transport_seq_start, + .next = sctp_transport_seq_next, + .stop = sctp_transport_seq_stop, + .show = sctp_remaddr_seq_show, +}; + +/* Set up the proc fs entry for the SCTP protocol. 
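+ *
+ * A trivial way to sanity-check the result from user space (illustrative
+ * only; it merely probes the four entries created below):
+ *
+ *    #include <stdio.h>
+ *    #include <unistd.h>
+ *
+ *    int main(void)
+ *    {
+ *        static const char *entries[] = {
+ *            "/proc/net/sctp/snmp", "/proc/net/sctp/eps",
+ *            "/proc/net/sctp/assocs", "/proc/net/sctp/remaddr",
+ *        };
+ *        unsigned int i;
+ *
+ *        for (i = 0; i < sizeof(entries) / sizeof(entries[0]); i++)
+ *            printf("%s: %s\n", entries[i],
+ *                   access(entries[i], R_OK) ? "missing" : "ok");
+ *        return 0;
+ *    }
+ *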
*/ +int __net_init sctp_proc_init(struct net *net) +{ + net->sctp.proc_net_sctp = proc_net_mkdir(net, "sctp", net->proc_net); + if (!net->sctp.proc_net_sctp) + return -ENOMEM; + if (!proc_create_net_single("snmp", 0444, net->sctp.proc_net_sctp, + sctp_snmp_seq_show, NULL)) + goto cleanup; + if (!proc_create_net("eps", 0444, net->sctp.proc_net_sctp, + &sctp_eps_ops, sizeof(struct seq_net_private))) + goto cleanup; + if (!proc_create_net("assocs", 0444, net->sctp.proc_net_sctp, + &sctp_assoc_ops, sizeof(struct sctp_ht_iter))) + goto cleanup; + if (!proc_create_net("remaddr", 0444, net->sctp.proc_net_sctp, + &sctp_remaddr_ops, sizeof(struct sctp_ht_iter))) + goto cleanup; + return 0; + +cleanup: + remove_proc_subtree("sctp", net->proc_net); + net->sctp.proc_net_sctp = NULL; + return -ENOMEM; +} diff --git a/net/sctp/protocol.c b/net/sctp/protocol.c new file mode 100644 index 000000000..bcd3384ab --- /dev/null +++ b/net/sctp/protocol.c @@ -0,0 +1,1728 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * Initialization/cleanup for SCTP protocol support. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Sridhar Samudrala + * Daisy Chang + * Ardelle Fan + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX_SCTP_PORT_HASH_ENTRIES (64 * 1024) + +/* Global data structures. */ +struct sctp_globals sctp_globals __read_mostly; + +struct idr sctp_assocs_id; +DEFINE_SPINLOCK(sctp_assocs_id_lock); + +static struct sctp_pf *sctp_pf_inet6_specific; +static struct sctp_pf *sctp_pf_inet_specific; +static struct sctp_af *sctp_af_v4_specific; +static struct sctp_af *sctp_af_v6_specific; + +struct kmem_cache *sctp_chunk_cachep __read_mostly; +struct kmem_cache *sctp_bucket_cachep __read_mostly; + +long sysctl_sctp_mem[3]; +int sysctl_sctp_rmem[3]; +int sysctl_sctp_wmem[3]; + +/* Private helper to extract ipv4 address and stash them in + * the protocol structure. + */ +static void sctp_v4_copy_addrlist(struct list_head *addrlist, + struct net_device *dev) +{ + struct in_device *in_dev; + struct in_ifaddr *ifa; + struct sctp_sockaddr_entry *addr; + + rcu_read_lock(); + if ((in_dev = __in_dev_get_rcu(dev)) == NULL) { + rcu_read_unlock(); + return; + } + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + /* Add the address to the local list. */ + addr = kzalloc(sizeof(*addr), GFP_ATOMIC); + if (addr) { + addr->a.v4.sin_family = AF_INET; + addr->a.v4.sin_addr.s_addr = ifa->ifa_local; + addr->valid = 1; + INIT_LIST_HEAD(&addr->list); + list_add_tail(&addr->list, addrlist); + } + } + + rcu_read_unlock(); +} + +/* Extract our IP addresses from the system and stash them in the + * protocol structure. 
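+ *
+ * The user-space analogue of this walk is getifaddrs(); the sketch below
+ * (illustrative only) lists the same per-interface IPv4 addresses that
+ * the copy_addrlist callbacks gather here.
+ *
+ *    #include <arpa/inet.h>
+ *    #include <ifaddrs.h>
+ *    #include <netinet/in.h>
+ *    #include <stdio.h>
+ *
+ *    int main(void)
+ *    {
+ *        struct ifaddrs *ifa, *p;
+ *        char buf[INET_ADDRSTRLEN];
+ *
+ *        if (getifaddrs(&ifa))
+ *            return 1;
+ *        for (p = ifa; p; p = p->ifa_next) {
+ *            if (!p->ifa_addr || p->ifa_addr->sa_family != AF_INET)
+ *                continue;
+ *            inet_ntop(AF_INET,
+ *                      &((struct sockaddr_in *)p->ifa_addr)->sin_addr,
+ *                      buf, sizeof(buf));
+ *            printf("%s %s\n", p->ifa_name, buf);
+ *        }
+ *        freeifaddrs(ifa);
+ *        return 0;
+ *    }
+ *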
+ */ +static void sctp_get_local_addr_list(struct net *net) +{ + struct net_device *dev; + struct list_head *pos; + struct sctp_af *af; + + rcu_read_lock(); + for_each_netdev_rcu(net, dev) { + list_for_each(pos, &sctp_address_families) { + af = list_entry(pos, struct sctp_af, list); + af->copy_addrlist(&net->sctp.local_addr_list, dev); + } + } + rcu_read_unlock(); +} + +/* Free the existing local addresses. */ +static void sctp_free_local_addr_list(struct net *net) +{ + struct sctp_sockaddr_entry *addr; + struct list_head *pos, *temp; + + list_for_each_safe(pos, temp, &net->sctp.local_addr_list) { + addr = list_entry(pos, struct sctp_sockaddr_entry, list); + list_del(pos); + kfree(addr); + } +} + +/* Copy the local addresses which are valid for 'scope' into 'bp'. */ +int sctp_copy_local_addr_list(struct net *net, struct sctp_bind_addr *bp, + enum sctp_scope scope, gfp_t gfp, int copy_flags) +{ + struct sctp_sockaddr_entry *addr; + union sctp_addr laddr; + int error = 0; + + rcu_read_lock(); + list_for_each_entry_rcu(addr, &net->sctp.local_addr_list, list) { + if (!addr->valid) + continue; + if (!sctp_in_scope(net, &addr->a, scope)) + continue; + + /* Now that the address is in scope, check to see if + * the address type is really supported by the local + * sock as well as the remote peer. + */ + if (addr->a.sa.sa_family == AF_INET && + (!(copy_flags & SCTP_ADDR4_ALLOWED) || + !(copy_flags & SCTP_ADDR4_PEERSUPP))) + continue; + if (addr->a.sa.sa_family == AF_INET6 && + (!(copy_flags & SCTP_ADDR6_ALLOWED) || + !(copy_flags & SCTP_ADDR6_PEERSUPP))) + continue; + + laddr = addr->a; + /* also works for setting ipv6 address port */ + laddr.v4.sin_port = htons(bp->port); + if (sctp_bind_addr_state(bp, &laddr) != -1) + continue; + + error = sctp_add_bind_addr(bp, &addr->a, sizeof(addr->a), + SCTP_ADDR_SRC, GFP_ATOMIC); + if (error) + break; + } + + rcu_read_unlock(); + return error; +} + +/* Copy over any ip options */ +static void sctp_v4_copy_ip_options(struct sock *sk, struct sock *newsk) +{ + struct inet_sock *newinet, *inet = inet_sk(sk); + struct ip_options_rcu *inet_opt, *newopt = NULL; + + newinet = inet_sk(newsk); + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt) { + newopt = sock_kmalloc(newsk, sizeof(*inet_opt) + + inet_opt->opt.optlen, GFP_ATOMIC); + if (newopt) + memcpy(newopt, inet_opt, sizeof(*inet_opt) + + inet_opt->opt.optlen); + else + pr_err("%s: Failed to copy ip options\n", __func__); + } + RCU_INIT_POINTER(newinet->inet_opt, newopt); + rcu_read_unlock(); +} + +/* Account for the IP options */ +static int sctp_v4_ip_options_len(struct sock *sk) +{ + struct inet_sock *inet = inet_sk(sk); + struct ip_options_rcu *inet_opt; + int len = 0; + + rcu_read_lock(); + inet_opt = rcu_dereference(inet->inet_opt); + if (inet_opt) + len = inet_opt->opt.optlen; + + rcu_read_unlock(); + return len; +} + +/* Initialize a sctp_addr from in incoming skb. */ +static void sctp_v4_from_skb(union sctp_addr *addr, struct sk_buff *skb, + int is_saddr) +{ + /* Always called on head skb, so this is safe */ + struct sctphdr *sh = sctp_hdr(skb); + struct sockaddr_in *sa = &addr->v4; + + addr->v4.sin_family = AF_INET; + + if (is_saddr) { + sa->sin_port = sh->source; + sa->sin_addr.s_addr = ip_hdr(skb)->saddr; + } else { + sa->sin_port = sh->dest; + sa->sin_addr.s_addr = ip_hdr(skb)->daddr; + } + memset(sa->sin_zero, 0, sizeof(sa->sin_zero)); +} + +/* Initialize an sctp_addr from a socket. 
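+ *
+ * (From user space the corresponding information is obtained with
+ * getsockname(); a minimal helper sketch, which assumes the socket has
+ * already been bound:)
+ *
+ *    #include <arpa/inet.h>
+ *    #include <netinet/in.h>
+ *    #include <stdio.h>
+ *    #include <sys/socket.h>
+ *
+ *    // Print the local IPv4 address and port a bound socket ended up with.
+ *    static void print_local_addr(int fd)
+ *    {
+ *        struct sockaddr_in sin;
+ *        socklen_t len = sizeof(sin);
+ *        char buf[INET_ADDRSTRLEN];
+ *
+ *        if (getsockname(fd, (struct sockaddr *)&sin, &len))
+ *            return;
+ *        printf("%s:%u\n",
+ *               inet_ntop(AF_INET, &sin.sin_addr, buf, sizeof(buf)),
+ *               ntohs(sin.sin_port));
+ *    }
+ *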
*/ +static void sctp_v4_from_sk(union sctp_addr *addr, struct sock *sk) +{ + addr->v4.sin_family = AF_INET; + addr->v4.sin_port = 0; + addr->v4.sin_addr.s_addr = inet_sk(sk)->inet_rcv_saddr; + memset(addr->v4.sin_zero, 0, sizeof(addr->v4.sin_zero)); +} + +/* Initialize sk->sk_rcv_saddr from sctp_addr. */ +static void sctp_v4_to_sk_saddr(union sctp_addr *addr, struct sock *sk) +{ + inet_sk(sk)->inet_rcv_saddr = addr->v4.sin_addr.s_addr; +} + +/* Initialize sk->sk_daddr from sctp_addr. */ +static void sctp_v4_to_sk_daddr(union sctp_addr *addr, struct sock *sk) +{ + inet_sk(sk)->inet_daddr = addr->v4.sin_addr.s_addr; +} + +/* Initialize a sctp_addr from an address parameter. */ +static bool sctp_v4_from_addr_param(union sctp_addr *addr, + union sctp_addr_param *param, + __be16 port, int iif) +{ + if (ntohs(param->v4.param_hdr.length) < sizeof(struct sctp_ipv4addr_param)) + return false; + + addr->v4.sin_family = AF_INET; + addr->v4.sin_port = port; + addr->v4.sin_addr.s_addr = param->v4.addr.s_addr; + memset(addr->v4.sin_zero, 0, sizeof(addr->v4.sin_zero)); + + return true; +} + +/* Initialize an address parameter from a sctp_addr and return the length + * of the address parameter. + */ +static int sctp_v4_to_addr_param(const union sctp_addr *addr, + union sctp_addr_param *param) +{ + int length = sizeof(struct sctp_ipv4addr_param); + + param->v4.param_hdr.type = SCTP_PARAM_IPV4_ADDRESS; + param->v4.param_hdr.length = htons(length); + param->v4.addr.s_addr = addr->v4.sin_addr.s_addr; + + return length; +} + +/* Initialize a sctp_addr from a dst_entry. */ +static void sctp_v4_dst_saddr(union sctp_addr *saddr, struct flowi4 *fl4, + __be16 port) +{ + saddr->v4.sin_family = AF_INET; + saddr->v4.sin_port = port; + saddr->v4.sin_addr.s_addr = fl4->saddr; + memset(saddr->v4.sin_zero, 0, sizeof(saddr->v4.sin_zero)); +} + +/* Compare two addresses exactly. */ +static int sctp_v4_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2) +{ + if (addr1->sa.sa_family != addr2->sa.sa_family) + return 0; + if (addr1->v4.sin_port != addr2->v4.sin_port) + return 0; + if (addr1->v4.sin_addr.s_addr != addr2->v4.sin_addr.s_addr) + return 0; + + return 1; +} + +/* Initialize addr struct to INADDR_ANY. */ +static void sctp_v4_inaddr_any(union sctp_addr *addr, __be16 port) +{ + addr->v4.sin_family = AF_INET; + addr->v4.sin_addr.s_addr = htonl(INADDR_ANY); + addr->v4.sin_port = port; + memset(addr->v4.sin_zero, 0, sizeof(addr->v4.sin_zero)); +} + +/* Is this a wildcard address? */ +static int sctp_v4_is_any(const union sctp_addr *addr) +{ + return htonl(INADDR_ANY) == addr->v4.sin_addr.s_addr; +} + +/* This function checks if the address is a valid address to be used for + * SCTP binding. + * + * Output: + * Return 0 - If the address is a non-unicast or an illegal address. + * Return 1 - If the address is a unicast. + */ +static int sctp_v4_addr_valid(union sctp_addr *addr, + struct sctp_sock *sp, + const struct sk_buff *skb) +{ + /* IPv4 addresses not allowed */ + if (sp && ipv6_only_sock(sctp_opt2sk(sp))) + return 0; + + /* Is this a non-unicast address or a unusable SCTP address? */ + if (IS_IPV4_UNUSABLE_ADDRESS(addr->v4.sin_addr.s_addr)) + return 0; + + /* Is this a broadcast address? */ + if (skb && skb_rtable(skb)->rt_flags & RTCF_BROADCAST) + return 0; + + return 1; +} + +/* Should this be available for binding? 
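+ *
+ * In user-space terms: binding an SCTP socket to an address this host does
+ * not own is expected to fail unless the socket requested freebind (or the
+ * net.ipv4.ip_nonlocal_bind sysctl is set). Illustrative sketch; the
+ * TEST-NET-1 address 192.0.2.1 merely stands in for "not a local address":
+ *
+ *    #include <arpa/inet.h>
+ *    #include <netinet/in.h>
+ *    #include <stdio.h>
+ *    #include <string.h>
+ *    #include <sys/socket.h>
+ *
+ *    int main(void)
+ *    {
+ *        struct sockaddr_in a;
+ *        int on = 1;
+ *        int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
+ *
+ *        if (fd < 0)
+ *            return 1;
+ *        memset(&a, 0, sizeof(a));
+ *        a.sin_family = AF_INET;
+ *        inet_pton(AF_INET, "192.0.2.1", &a.sin_addr);
+ *        if (bind(fd, (struct sockaddr *)&a, sizeof(a)))
+ *            perror("bind without IP_FREEBIND");   // EADDRNOTAVAIL expected
+ *        setsockopt(fd, IPPROTO_IP, IP_FREEBIND, &on, sizeof(on));
+ *        if (bind(fd, (struct sockaddr *)&a, sizeof(a)) == 0)
+ *            puts("bind accepted with IP_FREEBIND");
+ *        return 0;
+ *    }
+ *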
*/ +static int sctp_v4_available(union sctp_addr *addr, struct sctp_sock *sp) +{ + struct net *net = sock_net(&sp->inet.sk); + int ret = inet_addr_type(net, addr->v4.sin_addr.s_addr); + + + if (addr->v4.sin_addr.s_addr != htonl(INADDR_ANY) && + ret != RTN_LOCAL && + !sp->inet.freebind && + !READ_ONCE(net->ipv4.sysctl_ip_nonlocal_bind)) + return 0; + + if (ipv6_only_sock(sctp_opt2sk(sp))) + return 0; + + return 1; +} + +/* Checking the loopback, private and other address scopes as defined in + * RFC 1918. The IPv4 scoping is based on the draft for SCTP IPv4 + * scoping . + * + * Level 0 - unusable SCTP addresses + * Level 1 - loopback address + * Level 2 - link-local addresses + * Level 3 - private addresses. + * Level 4 - global addresses + * For INIT and INIT-ACK address list, let L be the level of + * requested destination address, sender and receiver + * SHOULD include all of its addresses with level greater + * than or equal to L. + * + * IPv4 scoping can be controlled through sysctl option + * net.sctp.addr_scope_policy + */ +static enum sctp_scope sctp_v4_scope(union sctp_addr *addr) +{ + enum sctp_scope retval; + + /* Check for unusable SCTP addresses. */ + if (IS_IPV4_UNUSABLE_ADDRESS(addr->v4.sin_addr.s_addr)) { + retval = SCTP_SCOPE_UNUSABLE; + } else if (ipv4_is_loopback(addr->v4.sin_addr.s_addr)) { + retval = SCTP_SCOPE_LOOPBACK; + } else if (ipv4_is_linklocal_169(addr->v4.sin_addr.s_addr)) { + retval = SCTP_SCOPE_LINK; + } else if (ipv4_is_private_10(addr->v4.sin_addr.s_addr) || + ipv4_is_private_172(addr->v4.sin_addr.s_addr) || + ipv4_is_private_192(addr->v4.sin_addr.s_addr) || + ipv4_is_test_198(addr->v4.sin_addr.s_addr)) { + retval = SCTP_SCOPE_PRIVATE; + } else { + retval = SCTP_SCOPE_GLOBAL; + } + + return retval; +} + +/* Returns a valid dst cache entry for the given source and destination ip + * addresses. If an association is passed, trys to get a dst entry with a + * source address that matches an address in the bind address list. + */ +static void sctp_v4_get_dst(struct sctp_transport *t, union sctp_addr *saddr, + struct flowi *fl, struct sock *sk) +{ + struct sctp_association *asoc = t->asoc; + struct rtable *rt; + struct flowi _fl; + struct flowi4 *fl4 = &_fl.u.ip4; + struct sctp_bind_addr *bp; + struct sctp_sockaddr_entry *laddr; + struct dst_entry *dst = NULL; + union sctp_addr *daddr = &t->ipaddr; + union sctp_addr dst_saddr; + __u8 tos = inet_sk(sk)->tos; + + if (t->dscp & SCTP_DSCP_SET_MASK) + tos = t->dscp & SCTP_DSCP_VAL_MASK; + memset(&_fl, 0x0, sizeof(_fl)); + fl4->daddr = daddr->v4.sin_addr.s_addr; + fl4->fl4_dport = daddr->v4.sin_port; + fl4->flowi4_proto = IPPROTO_SCTP; + if (asoc) { + fl4->flowi4_tos = RT_CONN_FLAGS_TOS(asoc->base.sk, tos); + fl4->flowi4_oif = asoc->base.sk->sk_bound_dev_if; + fl4->fl4_sport = htons(asoc->base.bind_addr.port); + } + if (saddr) { + fl4->saddr = saddr->v4.sin_addr.s_addr; + if (!fl4->fl4_sport) + fl4->fl4_sport = saddr->v4.sin_port; + } + + pr_debug("%s: dst:%pI4, src:%pI4 - ", __func__, &fl4->daddr, + &fl4->saddr); + + rt = ip_route_output_key(sock_net(sk), fl4); + if (!IS_ERR(rt)) { + dst = &rt->dst; + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + } + + /* If there is no association or if a source address is passed, no + * more validation is required. + */ + if (!asoc || saddr) + goto out; + + bp = &asoc->base.bind_addr; + + if (dst) { + /* Walk through the bind address list and look for a bind + * address that matches the source address of the returned dst. 
+ */ + sctp_v4_dst_saddr(&dst_saddr, fl4, htons(bp->port)); + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + if (!laddr->valid || (laddr->state == SCTP_ADDR_DEL) || + (laddr->state != SCTP_ADDR_SRC && + !asoc->src_out_of_asoc_ok)) + continue; + if (sctp_v4_cmp_addr(&dst_saddr, &laddr->a)) + goto out_unlock; + } + rcu_read_unlock(); + + /* None of the bound addresses match the source address of the + * dst. So release it. + */ + dst_release(dst); + dst = NULL; + } + + /* Walk through the bind address list and try to get a dst that + * matches a bind address as the source address. + */ + rcu_read_lock(); + list_for_each_entry_rcu(laddr, &bp->address_list, list) { + struct net_device *odev; + + if (!laddr->valid) + continue; + if (laddr->state != SCTP_ADDR_SRC || + AF_INET != laddr->a.sa.sa_family) + continue; + + fl4->fl4_sport = laddr->a.v4.sin_port; + flowi4_update_output(fl4, + asoc->base.sk->sk_bound_dev_if, + RT_CONN_FLAGS_TOS(asoc->base.sk, tos), + daddr->v4.sin_addr.s_addr, + laddr->a.v4.sin_addr.s_addr); + + rt = ip_route_output_key(sock_net(sk), fl4); + if (IS_ERR(rt)) + continue; + + /* Ensure the src address belongs to the output + * interface. + */ + odev = __ip_dev_find(sock_net(sk), laddr->a.v4.sin_addr.s_addr, + false); + if (!odev || odev->ifindex != fl4->flowi4_oif) { + if (!dst) { + dst = &rt->dst; + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + } else { + dst_release(&rt->dst); + } + continue; + } + + dst_release(dst); + dst = &rt->dst; + t->dst = dst; + memcpy(fl, &_fl, sizeof(_fl)); + break; + } + +out_unlock: + rcu_read_unlock(); +out: + if (dst) { + pr_debug("rt_dst:%pI4, rt_src:%pI4\n", + &fl->u.ip4.daddr, &fl->u.ip4.saddr); + } else { + t->dst = NULL; + pr_debug("no route\n"); + } +} + +/* For v4, the source address is cached in the route entry(dst). So no need + * to cache it separately and hence this is an empty routine. + */ +static void sctp_v4_get_saddr(struct sctp_sock *sk, + struct sctp_transport *t, + struct flowi *fl) +{ + union sctp_addr *saddr = &t->saddr; + struct rtable *rt = (struct rtable *)t->dst; + + if (rt) { + saddr->v4.sin_family = AF_INET; + saddr->v4.sin_addr.s_addr = fl->u.ip4.saddr; + } +} + +/* What interface did this skb arrive on? */ +static int sctp_v4_skb_iif(const struct sk_buff *skb) +{ + return inet_iif(skb); +} + +/* Was this packet marked by Explicit Congestion Notification? */ +static int sctp_v4_is_ce(const struct sk_buff *skb) +{ + return INET_ECN_is_ce(ip_hdr(skb)->tos); +} + +/* Create and initialize a new sk for the socket returned by accept(). */ +static struct sock *sctp_v4_create_accept_sk(struct sock *sk, + struct sctp_association *asoc, + bool kern) +{ + struct sock *newsk = sk_alloc(sock_net(sk), PF_INET, GFP_KERNEL, + sk->sk_prot, kern); + struct inet_sock *newinet; + + if (!newsk) + goto out; + + sock_init_data(NULL, newsk); + + sctp_copy_sock(newsk, sk, asoc); + sock_reset_flag(newsk, SOCK_ZAPPED); + + sctp_v4_copy_ip_options(sk, newsk); + + newinet = inet_sk(newsk); + + newinet->inet_daddr = asoc->peer.primary_addr.v4.sin_addr.s_addr; + + sk_refcnt_debug_inc(newsk); + + if (newsk->sk_prot->init(newsk)) { + sk_common_release(newsk); + newsk = NULL; + } + +out: + return newsk; +} + +static int sctp_v4_addr_to_user(struct sctp_sock *sp, union sctp_addr *addr) +{ + /* No address mapping for V4 sockets */ + memset(addr->v4.sin_zero, 0, sizeof(addr->v4.sin_zero)); + return sizeof(struct sockaddr_in); +} + +/* Dump the v4 addr to the seq file. 
*/ +static void sctp_v4_seq_dump_addr(struct seq_file *seq, union sctp_addr *addr) +{ + seq_printf(seq, "%pI4 ", &addr->v4.sin_addr); +} + +static void sctp_v4_ecn_capable(struct sock *sk) +{ + INET_ECN_xmit(sk); +} + +static void sctp_addr_wq_timeout_handler(struct timer_list *t) +{ + struct net *net = from_timer(net, t, sctp.addr_wq_timer); + struct sctp_sockaddr_entry *addrw, *temp; + struct sctp_sock *sp; + + spin_lock_bh(&net->sctp.addr_wq_lock); + + list_for_each_entry_safe(addrw, temp, &net->sctp.addr_waitq, list) { + pr_debug("%s: the first ent in wq:%p is addr:%pISc for cmd:%d at " + "entry:%p\n", __func__, &net->sctp.addr_waitq, &addrw->a.sa, + addrw->state, addrw); + +#if IS_ENABLED(CONFIG_IPV6) + /* Now we send an ASCONF for each association */ + /* Note. we currently don't handle link local IPv6 addressees */ + if (addrw->a.sa.sa_family == AF_INET6) { + struct in6_addr *in6; + + if (ipv6_addr_type(&addrw->a.v6.sin6_addr) & + IPV6_ADDR_LINKLOCAL) + goto free_next; + + in6 = (struct in6_addr *)&addrw->a.v6.sin6_addr; + if (ipv6_chk_addr(net, in6, NULL, 0) == 0 && + addrw->state == SCTP_ADDR_NEW) { + unsigned long timeo_val; + + pr_debug("%s: this is on DAD, trying %d sec " + "later\n", __func__, + SCTP_ADDRESS_TICK_DELAY); + + timeo_val = jiffies; + timeo_val += msecs_to_jiffies(SCTP_ADDRESS_TICK_DELAY); + mod_timer(&net->sctp.addr_wq_timer, timeo_val); + break; + } + } +#endif + list_for_each_entry(sp, &net->sctp.auto_asconf_splist, auto_asconf_list) { + struct sock *sk; + + sk = sctp_opt2sk(sp); + /* ignore bound-specific endpoints */ + if (!sctp_is_ep_boundall(sk)) + continue; + bh_lock_sock(sk); + if (sctp_asconf_mgmt(sp, addrw) < 0) + pr_debug("%s: sctp_asconf_mgmt failed\n", __func__); + bh_unlock_sock(sk); + } +#if IS_ENABLED(CONFIG_IPV6) +free_next: +#endif + list_del(&addrw->list); + kfree(addrw); + } + spin_unlock_bh(&net->sctp.addr_wq_lock); +} + +static void sctp_free_addr_wq(struct net *net) +{ + struct sctp_sockaddr_entry *addrw; + struct sctp_sockaddr_entry *temp; + + spin_lock_bh(&net->sctp.addr_wq_lock); + del_timer(&net->sctp.addr_wq_timer); + list_for_each_entry_safe(addrw, temp, &net->sctp.addr_waitq, list) { + list_del(&addrw->list); + kfree(addrw); + } + spin_unlock_bh(&net->sctp.addr_wq_lock); +} + +/* lookup the entry for the same address in the addr_waitq + * sctp_addr_wq MUST be locked + */ +static struct sctp_sockaddr_entry *sctp_addr_wq_lookup(struct net *net, + struct sctp_sockaddr_entry *addr) +{ + struct sctp_sockaddr_entry *addrw; + + list_for_each_entry(addrw, &net->sctp.addr_waitq, list) { + if (addrw->a.sa.sa_family != addr->a.sa.sa_family) + continue; + if (addrw->a.sa.sa_family == AF_INET) { + if (addrw->a.v4.sin_addr.s_addr == + addr->a.v4.sin_addr.s_addr) + return addrw; + } else if (addrw->a.sa.sa_family == AF_INET6) { + if (ipv6_addr_equal(&addrw->a.v6.sin6_addr, + &addr->a.v6.sin6_addr)) + return addrw; + } + } + return NULL; +} + +void sctp_addr_wq_mgmt(struct net *net, struct sctp_sockaddr_entry *addr, int cmd) +{ + struct sctp_sockaddr_entry *addrw; + unsigned long timeo_val; + + /* first, we check if an opposite message already exist in the queue. + * If we found such message, it is removed. 
+ * This operation is a bit stupid, but the DHCP client attaches the + * new address after a couple of addition and deletion of that address + */ + + spin_lock_bh(&net->sctp.addr_wq_lock); + /* Offsets existing events in addr_wq */ + addrw = sctp_addr_wq_lookup(net, addr); + if (addrw) { + if (addrw->state != cmd) { + pr_debug("%s: offsets existing entry for %d, addr:%pISc " + "in wq:%p\n", __func__, addrw->state, &addrw->a.sa, + &net->sctp.addr_waitq); + + list_del(&addrw->list); + kfree(addrw); + } + spin_unlock_bh(&net->sctp.addr_wq_lock); + return; + } + + /* OK, we have to add the new address to the wait queue */ + addrw = kmemdup(addr, sizeof(struct sctp_sockaddr_entry), GFP_ATOMIC); + if (addrw == NULL) { + spin_unlock_bh(&net->sctp.addr_wq_lock); + return; + } + addrw->state = cmd; + list_add_tail(&addrw->list, &net->sctp.addr_waitq); + + pr_debug("%s: add new entry for cmd:%d, addr:%pISc in wq:%p\n", + __func__, addrw->state, &addrw->a.sa, &net->sctp.addr_waitq); + + if (!timer_pending(&net->sctp.addr_wq_timer)) { + timeo_val = jiffies; + timeo_val += msecs_to_jiffies(SCTP_ADDRESS_TICK_DELAY); + mod_timer(&net->sctp.addr_wq_timer, timeo_val); + } + spin_unlock_bh(&net->sctp.addr_wq_lock); +} + +/* Event handler for inet address addition/deletion events. + * The sctp_local_addr_list needs to be protocted by a spin lock since + * multiple notifiers (say IPv4 and IPv6) may be running at the same + * time and thus corrupt the list. + * The reader side is protected with RCU. + */ +static int sctp_inetaddr_event(struct notifier_block *this, unsigned long ev, + void *ptr) +{ + struct in_ifaddr *ifa = (struct in_ifaddr *)ptr; + struct sctp_sockaddr_entry *addr = NULL; + struct sctp_sockaddr_entry *temp; + struct net *net = dev_net(ifa->ifa_dev->dev); + int found = 0; + + switch (ev) { + case NETDEV_UP: + addr = kzalloc(sizeof(*addr), GFP_ATOMIC); + if (addr) { + addr->a.v4.sin_family = AF_INET; + addr->a.v4.sin_addr.s_addr = ifa->ifa_local; + addr->valid = 1; + spin_lock_bh(&net->sctp.local_addr_lock); + list_add_tail_rcu(&addr->list, &net->sctp.local_addr_list); + sctp_addr_wq_mgmt(net, addr, SCTP_ADDR_NEW); + spin_unlock_bh(&net->sctp.local_addr_lock); + } + break; + case NETDEV_DOWN: + spin_lock_bh(&net->sctp.local_addr_lock); + list_for_each_entry_safe(addr, temp, + &net->sctp.local_addr_list, list) { + if (addr->a.sa.sa_family == AF_INET && + addr->a.v4.sin_addr.s_addr == + ifa->ifa_local) { + sctp_addr_wq_mgmt(net, addr, SCTP_ADDR_DEL); + found = 1; + addr->valid = 0; + list_del_rcu(&addr->list); + break; + } + } + spin_unlock_bh(&net->sctp.local_addr_lock); + if (found) + kfree_rcu(addr, rcu); + break; + } + + return NOTIFY_DONE; +} + +/* + * Initialize the control inode/socket with a control endpoint data + * structure. This endpoint is reserved exclusively for the OOTB processing. 
+ */ +static int sctp_ctl_sock_init(struct net *net) +{ + int err; + sa_family_t family = PF_INET; + + if (sctp_get_pf_specific(PF_INET6)) + family = PF_INET6; + + err = inet_ctl_sock_create(&net->sctp.ctl_sock, family, + SOCK_SEQPACKET, IPPROTO_SCTP, net); + + /* If IPv6 socket could not be created, try the IPv4 socket */ + if (err < 0 && family == PF_INET6) + err = inet_ctl_sock_create(&net->sctp.ctl_sock, AF_INET, + SOCK_SEQPACKET, IPPROTO_SCTP, + net); + + if (err < 0) { + pr_err("Failed to create the SCTP control socket\n"); + return err; + } + return 0; +} + +static int sctp_udp_rcv(struct sock *sk, struct sk_buff *skb) +{ + SCTP_INPUT_CB(skb)->encap_port = udp_hdr(skb)->source; + + skb_set_transport_header(skb, sizeof(struct udphdr)); + sctp_rcv(skb); + return 0; +} + +int sctp_udp_sock_start(struct net *net) +{ + struct udp_tunnel_sock_cfg tuncfg = {NULL}; + struct udp_port_cfg udp_conf = {0}; + struct socket *sock; + int err; + + udp_conf.family = AF_INET; + udp_conf.local_ip.s_addr = htonl(INADDR_ANY); + udp_conf.local_udp_port = htons(net->sctp.udp_port); + err = udp_sock_create(net, &udp_conf, &sock); + if (err) { + pr_err("Failed to create the SCTP UDP tunneling v4 sock\n"); + return err; + } + + tuncfg.encap_type = 1; + tuncfg.encap_rcv = sctp_udp_rcv; + tuncfg.encap_err_lookup = sctp_udp_v4_err; + setup_udp_tunnel_sock(net, sock, &tuncfg); + net->sctp.udp4_sock = sock->sk; + +#if IS_ENABLED(CONFIG_IPV6) + memset(&udp_conf, 0, sizeof(udp_conf)); + + udp_conf.family = AF_INET6; + udp_conf.local_ip6 = in6addr_any; + udp_conf.local_udp_port = htons(net->sctp.udp_port); + udp_conf.use_udp6_rx_checksums = true; + udp_conf.ipv6_v6only = true; + err = udp_sock_create(net, &udp_conf, &sock); + if (err) { + pr_err("Failed to create the SCTP UDP tunneling v6 sock\n"); + udp_tunnel_sock_release(net->sctp.udp4_sock->sk_socket); + net->sctp.udp4_sock = NULL; + return err; + } + + tuncfg.encap_type = 1; + tuncfg.encap_rcv = sctp_udp_rcv; + tuncfg.encap_err_lookup = sctp_udp_v6_err; + setup_udp_tunnel_sock(net, sock, &tuncfg); + net->sctp.udp6_sock = sock->sk; +#endif + + return 0; +} + +void sctp_udp_sock_stop(struct net *net) +{ + if (net->sctp.udp4_sock) { + udp_tunnel_sock_release(net->sctp.udp4_sock->sk_socket); + net->sctp.udp4_sock = NULL; + } + if (net->sctp.udp6_sock) { + udp_tunnel_sock_release(net->sctp.udp6_sock->sk_socket); + net->sctp.udp6_sock = NULL; + } +} + +/* Register address family specific functions. */ +int sctp_register_af(struct sctp_af *af) +{ + switch (af->sa_family) { + case AF_INET: + if (sctp_af_v4_specific) + return 0; + sctp_af_v4_specific = af; + break; + case AF_INET6: + if (sctp_af_v6_specific) + return 0; + sctp_af_v6_specific = af; + break; + default: + return 0; + } + + INIT_LIST_HEAD(&af->list); + list_add_tail(&af->list, &sctp_address_families); + return 1; +} + +/* Get the table of functions for manipulating a particular address + * family. + */ +struct sctp_af *sctp_get_af_specific(sa_family_t family) +{ + switch (family) { + case AF_INET: + return sctp_af_v4_specific; + case AF_INET6: + return sctp_af_v6_specific; + default: + return NULL; + } +} + +/* Common code to initialize a AF_INET msg_name. */ +static void sctp_inet_msgname(char *msgname, int *addr_len) +{ + struct sockaddr_in *sin; + + sin = (struct sockaddr_in *)msgname; + *addr_len = sizeof(struct sockaddr_in); + sin->sin_family = AF_INET; + memset(sin->sin_zero, 0, sizeof(sin->sin_zero)); +} + +/* Copy the primary address of the peer primary address as the msg_name. 
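+ *
+ * Seen from the receiving application this is simply the peer address that
+ * recvmsg()/recvfrom() reports in msg_name; a minimal sketch (helper only,
+ * error handling trimmed, an already set up SCTP socket is assumed):
+ *
+ *    #include <arpa/inet.h>
+ *    #include <netinet/in.h>
+ *    #include <stdio.h>
+ *    #include <sys/socket.h>
+ *
+ *    // Receive one message and show which peer address it is attributed to.
+ *    static void recv_one(int fd)
+ *    {
+ *        struct sockaddr_in from;
+ *        socklen_t fromlen = sizeof(from);
+ *        char data[1024], buf[INET_ADDRSTRLEN];
+ *        ssize_t n;
+ *
+ *        n = recvfrom(fd, data, sizeof(data), 0,
+ *                     (struct sockaddr *)&from, &fromlen);
+ *        if (n < 0)
+ *            return;
+ *        printf("%zd bytes from %s:%u\n", n,
+ *               inet_ntop(AF_INET, &from.sin_addr, buf, sizeof(buf)),
+ *               ntohs(from.sin_port));
+ *    }
+ *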
*/ +static void sctp_inet_event_msgname(struct sctp_ulpevent *event, char *msgname, + int *addr_len) +{ + struct sockaddr_in *sin, *sinfrom; + + if (msgname) { + struct sctp_association *asoc; + + asoc = event->asoc; + sctp_inet_msgname(msgname, addr_len); + sin = (struct sockaddr_in *)msgname; + sinfrom = &asoc->peer.primary_addr.v4; + sin->sin_port = htons(asoc->peer.port); + sin->sin_addr.s_addr = sinfrom->sin_addr.s_addr; + } +} + +/* Initialize and copy out a msgname from an inbound skb. */ +static void sctp_inet_skb_msgname(struct sk_buff *skb, char *msgname, int *len) +{ + if (msgname) { + struct sctphdr *sh = sctp_hdr(skb); + struct sockaddr_in *sin = (struct sockaddr_in *)msgname; + + sctp_inet_msgname(msgname, len); + sin->sin_port = sh->source; + sin->sin_addr.s_addr = ip_hdr(skb)->saddr; + } +} + +/* Do we support this AF? */ +static int sctp_inet_af_supported(sa_family_t family, struct sctp_sock *sp) +{ + /* PF_INET only supports AF_INET addresses. */ + return AF_INET == family; +} + +/* Address matching with wildcards allowed. */ +static int sctp_inet_cmp_addr(const union sctp_addr *addr1, + const union sctp_addr *addr2, + struct sctp_sock *opt) +{ + /* PF_INET only supports AF_INET addresses. */ + if (addr1->sa.sa_family != addr2->sa.sa_family) + return 0; + if (htonl(INADDR_ANY) == addr1->v4.sin_addr.s_addr || + htonl(INADDR_ANY) == addr2->v4.sin_addr.s_addr) + return 1; + if (addr1->v4.sin_addr.s_addr == addr2->v4.sin_addr.s_addr) + return 1; + + return 0; +} + +/* Verify that provided sockaddr looks bindable. Common verification has + * already been taken care of. + */ +static int sctp_inet_bind_verify(struct sctp_sock *opt, union sctp_addr *addr) +{ + return sctp_v4_available(addr, opt); +} + +/* Verify that sockaddr looks sendable. Common verification has already + * been taken care of. + */ +static int sctp_inet_send_verify(struct sctp_sock *opt, union sctp_addr *addr) +{ + return 1; +} + +/* Fill in Supported Address Type information for INIT and INIT-ACK + * chunks. Returns number of addresses supported. + */ +static int sctp_inet_supported_addrs(const struct sctp_sock *opt, + __be16 *types) +{ + types[0] = SCTP_PARAM_IPV4_ADDRESS; + return 1; +} + +/* Wrapper routine that calls the ip transmit routine. */ +static inline int sctp_v4_xmit(struct sk_buff *skb, struct sctp_transport *t) +{ + struct dst_entry *dst = dst_clone(t->dst); + struct flowi4 *fl4 = &t->fl.u.ip4; + struct sock *sk = skb->sk; + struct inet_sock *inet = inet_sk(sk); + __u8 dscp = inet->tos; + __be16 df = 0; + + pr_debug("%s: skb:%p, len:%d, src:%pI4, dst:%pI4\n", __func__, skb, + skb->len, &fl4->saddr, &fl4->daddr); + + if (t->dscp & SCTP_DSCP_SET_MASK) + dscp = t->dscp & SCTP_DSCP_VAL_MASK; + + inet->pmtudisc = t->param_flags & SPP_PMTUD_ENABLE ? 
IP_PMTUDISC_DO + : IP_PMTUDISC_DONT; + SCTP_INC_STATS(sock_net(sk), SCTP_MIB_OUTSCTPPACKS); + + if (!t->encap_port || !sctp_sk(sk)->udp_port) { + skb_dst_set(skb, dst); + return __ip_queue_xmit(sk, skb, &t->fl, dscp); + } + + if (skb_is_gso(skb)) + skb_shinfo(skb)->gso_type |= SKB_GSO_UDP_TUNNEL_CSUM; + + if (ip_dont_fragment(sk, dst) && !skb->ignore_df) + df = htons(IP_DF); + + skb->encapsulation = 1; + skb_reset_inner_mac_header(skb); + skb_reset_inner_transport_header(skb); + skb_set_inner_ipproto(skb, IPPROTO_SCTP); + udp_tunnel_xmit_skb((struct rtable *)dst, sk, skb, fl4->saddr, + fl4->daddr, dscp, ip4_dst_hoplimit(dst), df, + sctp_sk(sk)->udp_port, t->encap_port, false, false); + return 0; +} + +static struct sctp_af sctp_af_inet; + +static struct sctp_pf sctp_pf_inet = { + .event_msgname = sctp_inet_event_msgname, + .skb_msgname = sctp_inet_skb_msgname, + .af_supported = sctp_inet_af_supported, + .cmp_addr = sctp_inet_cmp_addr, + .bind_verify = sctp_inet_bind_verify, + .send_verify = sctp_inet_send_verify, + .supported_addrs = sctp_inet_supported_addrs, + .create_accept_sk = sctp_v4_create_accept_sk, + .addr_to_user = sctp_v4_addr_to_user, + .to_sk_saddr = sctp_v4_to_sk_saddr, + .to_sk_daddr = sctp_v4_to_sk_daddr, + .copy_ip_options = sctp_v4_copy_ip_options, + .af = &sctp_af_inet +}; + +/* Notifier for inetaddr addition/deletion events. */ +static struct notifier_block sctp_inetaddr_notifier = { + .notifier_call = sctp_inetaddr_event, +}; + +/* Socket operations. */ +static const struct proto_ops inet_seqpacket_ops = { + .family = PF_INET, + .owner = THIS_MODULE, + .release = inet_release, /* Needs to be wrapped... */ + .bind = inet_bind, + .connect = sctp_inet_connect, + .socketpair = sock_no_socketpair, + .accept = inet_accept, + .getname = inet_getname, /* Semantics are different. */ + .poll = sctp_poll, + .ioctl = inet_ioctl, + .gettstamp = sock_gettstamp, + .listen = sctp_inet_listen, + .shutdown = inet_shutdown, /* Looks harmless. */ + .setsockopt = sock_common_setsockopt, /* IP_SOL IP_OPTION is a problem */ + .getsockopt = sock_common_getsockopt, + .sendmsg = inet_sendmsg, + .recvmsg = inet_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +/* Registration with AF_INET family. */ +static struct inet_protosw sctp_seqpacket_protosw = { + .type = SOCK_SEQPACKET, + .protocol = IPPROTO_SCTP, + .prot = &sctp_prot, + .ops = &inet_seqpacket_ops, + .flags = SCTP_PROTOSW_FLAG +}; +static struct inet_protosw sctp_stream_protosw = { + .type = SOCK_STREAM, + .protocol = IPPROTO_SCTP, + .prot = &sctp_prot, + .ops = &inet_seqpacket_ops, + .flags = SCTP_PROTOSW_FLAG +}; + +static int sctp4_rcv(struct sk_buff *skb) +{ + SCTP_INPUT_CB(skb)->encap_port = 0; + return sctp_rcv(skb); +} + +/* Register with IP layer. */ +static const struct net_protocol sctp_protocol = { + .handler = sctp4_rcv, + .err_handler = sctp_v4_err, + .no_policy = 1, + .icmp_strict_tag_validation = 1, +}; + +/* IPv4 address related functions. 
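+ *
+ * This is the AF_INET instance of the struct sctp_af operations table; the
+ * dispatch idiom itself is nothing more than a table of function pointers
+ * keyed by sa_family. A stripped-down, stand-alone illustration of the
+ * same pattern (all "toy_*" names are invented for the sketch):
+ *
+ *    #include <stdio.h>
+ *    #include <sys/socket.h>
+ *
+ *    struct toy_af_ops {
+ *        sa_family_t family;
+ *        const char *(*name)(void);
+ *    };
+ *
+ *    static const char *toy_v4_name(void) { return "IPv4"; }
+ *    static const char *toy_v6_name(void) { return "IPv6"; }
+ *
+ *    static const struct toy_af_ops toy_af_tbl[] = {
+ *        { AF_INET,  toy_v4_name },
+ *        { AF_INET6, toy_v6_name },
+ *    };
+ *
+ *    // Same idea as sctp_get_af_specific(): pick the ops for a family.
+ *    static const struct toy_af_ops *toy_get_af(sa_family_t family)
+ *    {
+ *        unsigned int i;
+ *
+ *        for (i = 0; i < sizeof(toy_af_tbl) / sizeof(toy_af_tbl[0]); i++)
+ *            if (toy_af_tbl[i].family == family)
+ *                return &toy_af_tbl[i];
+ *        return NULL;
+ *    }
+ *
+ *    int main(void)
+ *    {
+ *        printf("%s\n", toy_get_af(AF_INET)->name());
+ *        return 0;
+ *    }
+ *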
*/ +static struct sctp_af sctp_af_inet = { + .sa_family = AF_INET, + .sctp_xmit = sctp_v4_xmit, + .setsockopt = ip_setsockopt, + .getsockopt = ip_getsockopt, + .get_dst = sctp_v4_get_dst, + .get_saddr = sctp_v4_get_saddr, + .copy_addrlist = sctp_v4_copy_addrlist, + .from_skb = sctp_v4_from_skb, + .from_sk = sctp_v4_from_sk, + .from_addr_param = sctp_v4_from_addr_param, + .to_addr_param = sctp_v4_to_addr_param, + .cmp_addr = sctp_v4_cmp_addr, + .addr_valid = sctp_v4_addr_valid, + .inaddr_any = sctp_v4_inaddr_any, + .is_any = sctp_v4_is_any, + .available = sctp_v4_available, + .scope = sctp_v4_scope, + .skb_iif = sctp_v4_skb_iif, + .is_ce = sctp_v4_is_ce, + .seq_dump_addr = sctp_v4_seq_dump_addr, + .ecn_capable = sctp_v4_ecn_capable, + .net_header_len = sizeof(struct iphdr), + .sockaddr_len = sizeof(struct sockaddr_in), + .ip_options_len = sctp_v4_ip_options_len, +}; + +struct sctp_pf *sctp_get_pf_specific(sa_family_t family) +{ + switch (family) { + case PF_INET: + return sctp_pf_inet_specific; + case PF_INET6: + return sctp_pf_inet6_specific; + default: + return NULL; + } +} + +/* Register the PF specific function table. */ +int sctp_register_pf(struct sctp_pf *pf, sa_family_t family) +{ + switch (family) { + case PF_INET: + if (sctp_pf_inet_specific) + return 0; + sctp_pf_inet_specific = pf; + break; + case PF_INET6: + if (sctp_pf_inet6_specific) + return 0; + sctp_pf_inet6_specific = pf; + break; + default: + return 0; + } + return 1; +} + +static inline int init_sctp_mibs(struct net *net) +{ + net->sctp.sctp_statistics = alloc_percpu(struct sctp_mib); + if (!net->sctp.sctp_statistics) + return -ENOMEM; + return 0; +} + +static inline void cleanup_sctp_mibs(struct net *net) +{ + free_percpu(net->sctp.sctp_statistics); +} + +static void sctp_v4_pf_init(void) +{ + /* Initialize the SCTP specific PF functions. */ + sctp_register_pf(&sctp_pf_inet, PF_INET); + sctp_register_af(&sctp_af_inet); +} + +static void sctp_v4_pf_exit(void) +{ + list_del(&sctp_af_inet.list); +} + +static int sctp_v4_protosw_init(void) +{ + int rc; + + rc = proto_register(&sctp_prot, 1); + if (rc) + return rc; + + /* Register SCTP(UDP and TCP style) with socket layer. */ + inet_register_protosw(&sctp_seqpacket_protosw); + inet_register_protosw(&sctp_stream_protosw); + + return 0; +} + +static void sctp_v4_protosw_exit(void) +{ + inet_unregister_protosw(&sctp_stream_protosw); + inet_unregister_protosw(&sctp_seqpacket_protosw); + proto_unregister(&sctp_prot); +} + +static int sctp_v4_add_protocol(void) +{ + /* Register notifier for inet address additions/deletions. */ + register_inetaddr_notifier(&sctp_inetaddr_notifier); + + /* Register SCTP with inet layer. */ + if (inet_add_protocol(&sctp_protocol, IPPROTO_SCTP) < 0) + return -EAGAIN; + + return 0; +} + +static void sctp_v4_del_protocol(void) +{ + inet_del_protocol(&sctp_protocol, IPPROTO_SCTP); + unregister_inetaddr_notifier(&sctp_inetaddr_notifier); +} + +static int __net_init sctp_defaults_init(struct net *net) +{ + int status; + + /* + * 14. 
Suggested SCTP Protocol Parameter Values + */ + /* The following protocol parameters are RECOMMENDED: */ + /* RTO.Initial - 3 seconds */ + net->sctp.rto_initial = SCTP_RTO_INITIAL; + /* RTO.Min - 1 second */ + net->sctp.rto_min = SCTP_RTO_MIN; + /* RTO.Max - 60 seconds */ + net->sctp.rto_max = SCTP_RTO_MAX; + /* RTO.Alpha - 1/8 */ + net->sctp.rto_alpha = SCTP_RTO_ALPHA; + /* RTO.Beta - 1/4 */ + net->sctp.rto_beta = SCTP_RTO_BETA; + + /* Valid.Cookie.Life - 60 seconds */ + net->sctp.valid_cookie_life = SCTP_DEFAULT_COOKIE_LIFE; + + /* Whether Cookie Preservative is enabled(1) or not(0) */ + net->sctp.cookie_preserve_enable = 1; + + /* Default sctp sockets to use md5 as their hmac alg */ +#if defined (CONFIG_SCTP_DEFAULT_COOKIE_HMAC_MD5) + net->sctp.sctp_hmac_alg = "md5"; +#elif defined (CONFIG_SCTP_DEFAULT_COOKIE_HMAC_SHA1) + net->sctp.sctp_hmac_alg = "sha1"; +#else + net->sctp.sctp_hmac_alg = NULL; +#endif + + /* Max.Burst - 4 */ + net->sctp.max_burst = SCTP_DEFAULT_MAX_BURST; + + /* Disable of Primary Path Switchover by default */ + net->sctp.ps_retrans = SCTP_PS_RETRANS_MAX; + + /* Enable pf state by default */ + net->sctp.pf_enable = 1; + + /* Ignore pf exposure feature by default */ + net->sctp.pf_expose = SCTP_PF_EXPOSE_UNSET; + + /* Association.Max.Retrans - 10 attempts + * Path.Max.Retrans - 5 attempts (per destination address) + * Max.Init.Retransmits - 8 attempts + */ + net->sctp.max_retrans_association = 10; + net->sctp.max_retrans_path = 5; + net->sctp.max_retrans_init = 8; + + /* Sendbuffer growth - do per-socket accounting */ + net->sctp.sndbuf_policy = 0; + + /* Rcvbuffer growth - do per-socket accounting */ + net->sctp.rcvbuf_policy = 0; + + /* HB.interval - 30 seconds */ + net->sctp.hb_interval = SCTP_DEFAULT_TIMEOUT_HEARTBEAT; + + /* delayed SACK timeout */ + net->sctp.sack_timeout = SCTP_DEFAULT_TIMEOUT_SACK; + + /* Disable ADDIP by default. */ + net->sctp.addip_enable = 0; + net->sctp.addip_noauth = 0; + net->sctp.default_auto_asconf = 0; + + /* Enable PR-SCTP by default. */ + net->sctp.prsctp_enable = 1; + + /* Disable RECONF by default. */ + net->sctp.reconf_enable = 0; + + /* Disable AUTH by default. */ + net->sctp.auth_enable = 0; + + /* Enable ECN by default. */ + net->sctp.ecn_enable = 1; + + /* Set UDP tunneling listening port to 0 by default */ + net->sctp.udp_port = 0; + + /* Set remote encap port to 0 by default */ + net->sctp.encap_port = 0; + + /* Set SCOPE policy to enabled */ + net->sctp.scope_policy = SCTP_SCOPE_POLICY_ENABLE; + + /* Set the default rwnd update threshold */ + net->sctp.rwnd_upd_shift = SCTP_DEFAULT_RWND_SHIFT; + + /* Initialize maximum autoclose timeout. */ + net->sctp.max_autoclose = INT_MAX / HZ; + + status = sctp_sysctl_net_register(net); + if (status) + goto err_sysctl_register; + + /* Allocate and initialise sctp mibs. */ + status = init_sctp_mibs(net); + if (status) + goto err_init_mibs; + +#ifdef CONFIG_PROC_FS + /* Initialize proc fs directory. */ + status = sctp_proc_init(net); + if (status) + goto err_init_proc; +#endif + + sctp_dbg_objcnt_init(net); + + /* Initialize the local address list. 
*/ + INIT_LIST_HEAD(&net->sctp.local_addr_list); + spin_lock_init(&net->sctp.local_addr_lock); + sctp_get_local_addr_list(net); + + /* Initialize the address event list */ + INIT_LIST_HEAD(&net->sctp.addr_waitq); + INIT_LIST_HEAD(&net->sctp.auto_asconf_splist); + spin_lock_init(&net->sctp.addr_wq_lock); + net->sctp.addr_wq_timer.expires = 0; + timer_setup(&net->sctp.addr_wq_timer, sctp_addr_wq_timeout_handler, 0); + + return 0; + +#ifdef CONFIG_PROC_FS +err_init_proc: + cleanup_sctp_mibs(net); +#endif +err_init_mibs: + sctp_sysctl_net_unregister(net); +err_sysctl_register: + return status; +} + +static void __net_exit sctp_defaults_exit(struct net *net) +{ + /* Free the local address list */ + sctp_free_addr_wq(net); + sctp_free_local_addr_list(net); + +#ifdef CONFIG_PROC_FS + remove_proc_subtree("sctp", net->proc_net); + net->sctp.proc_net_sctp = NULL; +#endif + cleanup_sctp_mibs(net); + sctp_sysctl_net_unregister(net); +} + +static struct pernet_operations sctp_defaults_ops = { + .init = sctp_defaults_init, + .exit = sctp_defaults_exit, +}; + +static int __net_init sctp_ctrlsock_init(struct net *net) +{ + int status; + + /* Initialize the control inode/socket for handling OOTB packets. */ + status = sctp_ctl_sock_init(net); + if (status) + pr_err("Failed to initialize the SCTP control sock\n"); + + return status; +} + +static void __net_exit sctp_ctrlsock_exit(struct net *net) +{ + /* Free the control endpoint. */ + inet_ctl_sock_destroy(net->sctp.ctl_sock); +} + +static struct pernet_operations sctp_ctrlsock_ops = { + .init = sctp_ctrlsock_init, + .exit = sctp_ctrlsock_exit, +}; + +/* Initialize the universe into something sensible. */ +static __init int sctp_init(void) +{ + unsigned long nr_pages = totalram_pages(); + unsigned long limit; + unsigned long goal; + int max_entry_order; + int num_entries; + int max_share; + int status; + int order; + int i; + + sock_skb_cb_check_size(sizeof(struct sctp_ulpevent)); + + /* Allocate bind_bucket and chunk caches. */ + status = -ENOBUFS; + sctp_bucket_cachep = kmem_cache_create("sctp_bind_bucket", + sizeof(struct sctp_bind_bucket), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!sctp_bucket_cachep) + goto out; + + sctp_chunk_cachep = kmem_cache_create("sctp_chunk", + sizeof(struct sctp_chunk), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!sctp_chunk_cachep) + goto err_chunk_cachep; + + status = percpu_counter_init(&sctp_sockets_allocated, 0, GFP_KERNEL); + if (status) + goto err_percpu_counter_init; + + /* Implementation specific variables. */ + + /* Initialize default stream count setup information. */ + sctp_max_instreams = SCTP_DEFAULT_INSTREAMS; + sctp_max_outstreams = SCTP_DEFAULT_OUTSTREAMS; + + /* Initialize handle used for association ids. */ + idr_init(&sctp_assocs_id); + + limit = nr_free_buffer_pages() / 8; + limit = max(limit, 128UL); + sysctl_sctp_mem[0] = limit / 4 * 3; + sysctl_sctp_mem[1] = limit; + sysctl_sctp_mem[2] = sysctl_sctp_mem[0] * 2; + + /* Set per-socket limits to no more than 1/128 the pressure threshold*/ + limit = (sysctl_sctp_mem[1]) << (PAGE_SHIFT - 7); + max_share = min(4UL*1024*1024, limit); + + sysctl_sctp_rmem[0] = PAGE_SIZE; /* give each asoc 1 page min */ + sysctl_sctp_rmem[1] = 1500 * SKB_TRUESIZE(1); + sysctl_sctp_rmem[2] = max(sysctl_sctp_rmem[1], max_share); + + sysctl_sctp_wmem[0] = PAGE_SIZE; + sysctl_sctp_wmem[1] = 16*1024; + sysctl_sctp_wmem[2] = max(64*1024, max_share); + + /* Size and allocate the association hash table. + * The methodology is similar to that of the tcp hash tables. + * Though not identical. 
Start by getting a goal size + */ + if (nr_pages >= (128 * 1024)) + goal = nr_pages >> (22 - PAGE_SHIFT); + else + goal = nr_pages >> (24 - PAGE_SHIFT); + + /* Then compute the page order for said goal */ + order = get_order(goal); + + /* Now compute the required page order for the maximum sized table we + * want to create + */ + max_entry_order = get_order(MAX_SCTP_PORT_HASH_ENTRIES * + sizeof(struct sctp_bind_hashbucket)); + + /* Limit the page order by that maximum hash table size */ + order = min(order, max_entry_order); + + /* Allocate and initialize the endpoint hash table. */ + sctp_ep_hashsize = 64; + sctp_ep_hashtable = + kmalloc_array(64, sizeof(struct sctp_hashbucket), GFP_KERNEL); + if (!sctp_ep_hashtable) { + pr_err("Failed endpoint_hash alloc\n"); + status = -ENOMEM; + goto err_ehash_alloc; + } + for (i = 0; i < sctp_ep_hashsize; i++) { + rwlock_init(&sctp_ep_hashtable[i].lock); + INIT_HLIST_HEAD(&sctp_ep_hashtable[i].chain); + } + + /* Allocate and initialize the SCTP port hash table. + * Note that order is initalized to start at the max sized + * table we want to support. If we can't get that many pages + * reduce the order and try again + */ + do { + sctp_port_hashtable = (struct sctp_bind_hashbucket *) + __get_free_pages(GFP_KERNEL | __GFP_NOWARN, order); + } while (!sctp_port_hashtable && --order > 0); + + if (!sctp_port_hashtable) { + pr_err("Failed bind hash alloc\n"); + status = -ENOMEM; + goto err_bhash_alloc; + } + + /* Now compute the number of entries that will fit in the + * port hash space we allocated + */ + num_entries = (1UL << order) * PAGE_SIZE / + sizeof(struct sctp_bind_hashbucket); + + /* And finish by rounding it down to the nearest power of two. + * This wastes some memory of course, but it's needed because + * the hash function operates based on the assumption that + * the number of entries is a power of two. + */ + sctp_port_hashsize = rounddown_pow_of_two(num_entries); + + for (i = 0; i < sctp_port_hashsize; i++) { + spin_lock_init(&sctp_port_hashtable[i].lock); + INIT_HLIST_HEAD(&sctp_port_hashtable[i].chain); + } + + status = sctp_transport_hashtable_init(); + if (status) + goto err_thash_alloc; + + pr_info("Hash tables configured (bind %d/%d)\n", sctp_port_hashsize, + num_entries); + + sctp_sysctl_register(); + + INIT_LIST_HEAD(&sctp_address_families); + sctp_v4_pf_init(); + sctp_v6_pf_init(); + sctp_sched_ops_init(); + + status = register_pernet_subsys(&sctp_defaults_ops); + if (status) + goto err_register_defaults; + + status = sctp_v4_protosw_init(); + if (status) + goto err_protosw_init; + + status = sctp_v6_protosw_init(); + if (status) + goto err_v6_protosw_init; + + status = register_pernet_subsys(&sctp_ctrlsock_ops); + if (status) + goto err_register_ctrlsock; + + status = sctp_v4_add_protocol(); + if (status) + goto err_add_protocol; + + /* Register SCTP with inet6 layer. 
*/ + status = sctp_v6_add_protocol(); + if (status) + goto err_v6_add_protocol; + + if (sctp_offload_init() < 0) + pr_crit("%s: Cannot add SCTP protocol offload\n", __func__); + +out: + return status; +err_v6_add_protocol: + sctp_v4_del_protocol(); +err_add_protocol: + unregister_pernet_subsys(&sctp_ctrlsock_ops); +err_register_ctrlsock: + sctp_v6_protosw_exit(); +err_v6_protosw_init: + sctp_v4_protosw_exit(); +err_protosw_init: + unregister_pernet_subsys(&sctp_defaults_ops); +err_register_defaults: + sctp_v4_pf_exit(); + sctp_v6_pf_exit(); + sctp_sysctl_unregister(); + free_pages((unsigned long)sctp_port_hashtable, + get_order(sctp_port_hashsize * + sizeof(struct sctp_bind_hashbucket))); +err_bhash_alloc: + sctp_transport_hashtable_destroy(); +err_thash_alloc: + kfree(sctp_ep_hashtable); +err_ehash_alloc: + percpu_counter_destroy(&sctp_sockets_allocated); +err_percpu_counter_init: + kmem_cache_destroy(sctp_chunk_cachep); +err_chunk_cachep: + kmem_cache_destroy(sctp_bucket_cachep); + goto out; +} + +/* Exit handler for the SCTP protocol. */ +static __exit void sctp_exit(void) +{ + /* BUG. This should probably do something useful like clean + * up all the remaining associations and all that memory. + */ + + /* Unregister with inet6/inet layers. */ + sctp_v6_del_protocol(); + sctp_v4_del_protocol(); + + unregister_pernet_subsys(&sctp_ctrlsock_ops); + + /* Free protosw registrations */ + sctp_v6_protosw_exit(); + sctp_v4_protosw_exit(); + + unregister_pernet_subsys(&sctp_defaults_ops); + + /* Unregister with socket layer. */ + sctp_v6_pf_exit(); + sctp_v4_pf_exit(); + + sctp_sysctl_unregister(); + + free_pages((unsigned long)sctp_port_hashtable, + get_order(sctp_port_hashsize * + sizeof(struct sctp_bind_hashbucket))); + kfree(sctp_ep_hashtable); + sctp_transport_hashtable_destroy(); + + percpu_counter_destroy(&sctp_sockets_allocated); + + rcu_barrier(); /* Wait for completion of call_rcu()'s */ + + kmem_cache_destroy(sctp_chunk_cachep); + kmem_cache_destroy(sctp_bucket_cachep); +} + +module_init(sctp_init); +module_exit(sctp_exit); + +/* + * __stringify doesn't likes enums, so use IPPROTO_SCTP value (132) directly. + */ +MODULE_ALIAS("net-pf-" __stringify(PF_INET) "-proto-132"); +MODULE_ALIAS("net-pf-" __stringify(PF_INET6) "-proto-132"); +MODULE_AUTHOR("Linux Kernel SCTP developers "); +MODULE_DESCRIPTION("Support for the SCTP protocol (RFC2960)"); +module_param_named(no_checksums, sctp_checksum_disable, bool, 0644); +MODULE_PARM_DESC(no_checksums, "Disable checksums computing and verification"); +MODULE_LICENSE("GPL"); diff --git a/net/sctp/sm_make_chunk.c b/net/sctp/sm_make_chunk.c new file mode 100644 index 000000000..c7503fd64 --- /dev/null +++ b/net/sctp/sm_make_chunk.c @@ -0,0 +1,3937 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2002 Intel Corp. + * + * This file is part of the SCTP kernel implementation + * + * These functions work with the state functions in sctp_sm_statefuns.c + * to implement the state operations. These functions implement the + * steps which require modifying existing data structures. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * C. 
Robin + * Jon Grimm + * Xingang Guo + * Dajiang Zhang + * Sridhar Samudrala + * Daisy Chang + * Ardelle Fan + * Kevin Gao + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include /* for get_random_bytes */ +#include +#include + +static struct sctp_chunk *sctp_make_control(const struct sctp_association *asoc, + __u8 type, __u8 flags, int paylen, + gfp_t gfp); +static struct sctp_chunk *sctp_make_data(const struct sctp_association *asoc, + __u8 flags, int paylen, gfp_t gfp); +static struct sctp_chunk *_sctp_make_chunk(const struct sctp_association *asoc, + __u8 type, __u8 flags, int paylen, + gfp_t gfp); +static struct sctp_cookie_param *sctp_pack_cookie( + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const struct sctp_chunk *init_chunk, + int *cookie_len, + const __u8 *raw_addrs, int addrs_len); +static int sctp_process_param(struct sctp_association *asoc, + union sctp_params param, + const union sctp_addr *peer_addr, + gfp_t gfp); +static void *sctp_addto_param(struct sctp_chunk *chunk, int len, + const void *data); + +/* Control chunk destructor */ +static void sctp_control_release_owner(struct sk_buff *skb) +{ + struct sctp_chunk *chunk = skb_shinfo(skb)->destructor_arg; + + if (chunk->shkey) { + struct sctp_shared_key *shkey = chunk->shkey; + struct sctp_association *asoc = chunk->asoc; + + /* refcnt == 2 and !list_empty mean after this release, it's + * not being used anywhere, and it's time to notify userland + * that this shkey can be freed if it's been deactivated. + */ + if (shkey->deactivated && !list_empty(&shkey->key_list) && + refcount_read(&shkey->refcnt) == 2) { + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_authkey(asoc, shkey->key_id, + SCTP_AUTH_FREE_KEY, + GFP_KERNEL); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + } + sctp_auth_shkey_release(chunk->shkey); + } +} + +static void sctp_control_set_owner_w(struct sctp_chunk *chunk) +{ + struct sctp_association *asoc = chunk->asoc; + struct sk_buff *skb = chunk->skb; + + /* TODO: properly account for control chunks. + * To do it right we'll need: + * 1) endpoint if association isn't known. + * 2) proper memory accounting. + * + * For now don't do anything for now. + */ + if (chunk->auth) { + chunk->shkey = asoc->shkey; + sctp_auth_shkey_hold(chunk->shkey); + } + skb->sk = asoc ? asoc->base.sk : NULL; + skb_shinfo(skb)->destructor_arg = chunk; + skb->destructor = sctp_control_release_owner; +} + +/* What was the inbound interface for this chunk? */ +int sctp_chunk_iif(const struct sctp_chunk *chunk) +{ + struct sk_buff *skb = chunk->skb; + + return SCTP_INPUT_CB(skb)->af->skb_iif(skb); +} + +/* RFC 2960 3.3.2 Initiation (INIT) (1) + * + * Note 2: The ECN capable field is reserved for future use of + * Explicit Congestion Notification. + */ +static const struct sctp_paramhdr ecap_param = { + SCTP_PARAM_ECN_CAPABLE, + cpu_to_be16(sizeof(struct sctp_paramhdr)), +}; +static const struct sctp_paramhdr prsctp_param = { + SCTP_PARAM_FWD_TSN_SUPPORT, + cpu_to_be16(sizeof(struct sctp_paramhdr)), +}; + +/* A helper to initialize an op error inside a provided chunk, as most + * cause codes will be embedded inside an abort chunk. + */ +int sctp_init_cause(struct sctp_chunk *chunk, __be16 cause_code, + size_t paylen) +{ + struct sctp_errhdr err; + __u16 len; + + /* Cause code constants are now defined in network order. 
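[Illustrative sketch, not part of the patch] The sctp_init_cause() that begins here writes a small TLV header: a 2-byte cause code (kept in network order, like the kernel's constants) and a 2-byte length covering header plus payload. A standalone sketch of that wire format; struct errhdr and put_cause() are local illustrations, not the kernel's struct sctp_errhdr or API, and unlike the kernel helper this version also copies the payload.

#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <arpa/inet.h>  /* htons() */

/* Local mirror of the error-cause header: 2-byte cause, 2-byte length. */
struct errhdr {
        uint16_t cause;
        uint16_t length;
};

/* Append one error cause (header + payload) into buf; returns the number of
 * bytes written, or -1 if it would not fit.
 */
int put_cause(uint8_t *buf, size_t room, uint16_t cause_be,
              const void *payload, size_t paylen)
{
        struct errhdr eh;
        size_t len = sizeof(eh) + paylen;

        if (room < len)
                return -1;
        eh.cause = cause_be;            /* constants already in network order */
        eh.length = htons((uint16_t)len);
        memcpy(buf, &eh, sizeof(eh));
        memcpy(buf + sizeof(eh), payload, paylen);
        return (int)len;
}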
*/ + err.cause = cause_code; + len = sizeof(err) + paylen; + err.length = htons(len); + + if (skb_tailroom(chunk->skb) < len) + return -ENOSPC; + + chunk->subh.err_hdr = sctp_addto_chunk(chunk, sizeof(err), &err); + + return 0; +} + +/* 3.3.2 Initiation (INIT) (1) + * + * This chunk is used to initiate a SCTP association between two + * endpoints. The format of the INIT chunk is shown below: + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 1 | Chunk Flags | Chunk Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Initiate Tag | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Advertised Receiver Window Credit (a_rwnd) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Number of Outbound Streams | Number of Inbound Streams | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Initial TSN | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / Optional/Variable-Length Parameters / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * + * The INIT chunk contains the following parameters. Unless otherwise + * noted, each parameter MUST only be included once in the INIT chunk. + * + * Fixed Parameters Status + * ---------------------------------------------- + * Initiate Tag Mandatory + * Advertised Receiver Window Credit Mandatory + * Number of Outbound Streams Mandatory + * Number of Inbound Streams Mandatory + * Initial TSN Mandatory + * + * Variable Parameters Status Type Value + * ------------------------------------------------------------- + * IPv4 Address (Note 1) Optional 5 + * IPv6 Address (Note 1) Optional 6 + * Cookie Preservative Optional 9 + * Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) + * Host Name Address (Note 3) Optional 11 + * Supported Address Types (Note 4) Optional 12 + */ +struct sctp_chunk *sctp_make_init(const struct sctp_association *asoc, + const struct sctp_bind_addr *bp, + gfp_t gfp, int vparam_len) +{ + struct sctp_supported_ext_param ext_param; + struct sctp_adaptation_ind_param aiparam; + struct sctp_paramhdr *auth_chunks = NULL; + struct sctp_paramhdr *auth_hmacs = NULL; + struct sctp_supported_addrs_param sat; + struct sctp_endpoint *ep = asoc->ep; + struct sctp_chunk *retval = NULL; + int num_types, addrs_len = 0; + struct sctp_inithdr init; + union sctp_params addrs; + struct sctp_sock *sp; + __u8 extensions[5]; + size_t chunksize; + __be16 types[2]; + int num_ext = 0; + + /* RFC 2960 3.3.2 Initiation (INIT) (1) + * + * Note 1: The INIT chunks can contain multiple addresses that + * can be IPv4 and/or IPv6 in any combination. + */ + + /* Convert the provided bind address list to raw format. */ + addrs = sctp_bind_addrs_to_raw(bp, &addrs_len, gfp); + + init.init_tag = htonl(asoc->c.my_vtag); + init.a_rwnd = htonl(asoc->rwnd); + init.num_outbound_streams = htons(asoc->c.sinit_num_ostreams); + init.num_inbound_streams = htons(asoc->c.sinit_max_instreams); + init.initial_tsn = htonl(asoc->c.initial_tsn); + + /* How many address types are needed? 
*/ + sp = sctp_sk(asoc->base.sk); + num_types = sp->pf->supported_addrs(sp, types); + + chunksize = sizeof(init) + addrs_len; + chunksize += SCTP_PAD4(SCTP_SAT_LEN(num_types)); + + if (asoc->ep->ecn_enable) + chunksize += sizeof(ecap_param); + + if (asoc->ep->prsctp_enable) + chunksize += sizeof(prsctp_param); + + /* ADDIP: Section 4.2.7: + * An implementation supporting this extension [ADDIP] MUST list + * the ASCONF,the ASCONF-ACK, and the AUTH chunks in its INIT and + * INIT-ACK parameters. + */ + if (asoc->ep->asconf_enable) { + extensions[num_ext] = SCTP_CID_ASCONF; + extensions[num_ext+1] = SCTP_CID_ASCONF_ACK; + num_ext += 2; + } + + if (asoc->ep->reconf_enable) { + extensions[num_ext] = SCTP_CID_RECONF; + num_ext += 1; + } + + if (sp->adaptation_ind) + chunksize += sizeof(aiparam); + + if (asoc->ep->intl_enable) { + extensions[num_ext] = SCTP_CID_I_DATA; + num_ext += 1; + } + + chunksize += vparam_len; + + /* Account for AUTH related parameters */ + if (ep->auth_enable) { + /* Add random parameter length*/ + chunksize += sizeof(asoc->c.auth_random); + + /* Add HMACS parameter length if any were defined */ + auth_hmacs = (struct sctp_paramhdr *)asoc->c.auth_hmacs; + if (auth_hmacs->length) + chunksize += SCTP_PAD4(ntohs(auth_hmacs->length)); + else + auth_hmacs = NULL; + + /* Add CHUNKS parameter length */ + auth_chunks = (struct sctp_paramhdr *)asoc->c.auth_chunks; + if (auth_chunks->length) + chunksize += SCTP_PAD4(ntohs(auth_chunks->length)); + else + auth_chunks = NULL; + + extensions[num_ext] = SCTP_CID_AUTH; + num_ext += 1; + } + + /* If we have any extensions to report, account for that */ + if (num_ext) + chunksize += SCTP_PAD4(sizeof(ext_param) + num_ext); + + /* RFC 2960 3.3.2 Initiation (INIT) (1) + * + * Note 3: An INIT chunk MUST NOT contain more than one Host + * Name address parameter. Moreover, the sender of the INIT + * MUST NOT combine any other address types with the Host Name + * address in the INIT. The receiver of INIT MUST ignore any + * other address types if the Host Name address parameter is + * present in the received INIT chunk. + * + * PLEASE DO NOT FIXME [This version does not support Host Name.] + */ + + retval = sctp_make_control(asoc, SCTP_CID_INIT, 0, chunksize, gfp); + if (!retval) + goto nodata; + + retval->subh.init_hdr = + sctp_addto_chunk(retval, sizeof(init), &init); + retval->param_hdr.v = + sctp_addto_chunk(retval, addrs_len, addrs.v); + + /* RFC 2960 3.3.2 Initiation (INIT) (1) + * + * Note 4: This parameter, when present, specifies all the + * address types the sending endpoint can support. The absence + * of this parameter indicates that the sending endpoint can + * support any address type. + */ + sat.param_hdr.type = SCTP_PARAM_SUPPORTED_ADDRESS_TYPES; + sat.param_hdr.length = htons(SCTP_SAT_LEN(num_types)); + sctp_addto_chunk(retval, sizeof(sat), &sat); + sctp_addto_chunk(retval, num_types * sizeof(__u16), &types); + + if (asoc->ep->ecn_enable) + sctp_addto_chunk(retval, sizeof(ecap_param), &ecap_param); + + /* Add the supported extensions parameter. 
Be nice and add this + * fist before addiding the parameters for the extensions themselves + */ + if (num_ext) { + ext_param.param_hdr.type = SCTP_PARAM_SUPPORTED_EXT; + ext_param.param_hdr.length = htons(sizeof(ext_param) + num_ext); + sctp_addto_chunk(retval, sizeof(ext_param), &ext_param); + sctp_addto_param(retval, num_ext, extensions); + } + + if (asoc->ep->prsctp_enable) + sctp_addto_chunk(retval, sizeof(prsctp_param), &prsctp_param); + + if (sp->adaptation_ind) { + aiparam.param_hdr.type = SCTP_PARAM_ADAPTATION_LAYER_IND; + aiparam.param_hdr.length = htons(sizeof(aiparam)); + aiparam.adaptation_ind = htonl(sp->adaptation_ind); + sctp_addto_chunk(retval, sizeof(aiparam), &aiparam); + } + + /* Add SCTP-AUTH chunks to the parameter list */ + if (ep->auth_enable) { + sctp_addto_chunk(retval, sizeof(asoc->c.auth_random), + asoc->c.auth_random); + if (auth_hmacs) + sctp_addto_chunk(retval, ntohs(auth_hmacs->length), + auth_hmacs); + if (auth_chunks) + sctp_addto_chunk(retval, ntohs(auth_chunks->length), + auth_chunks); + } +nodata: + kfree(addrs.v); + return retval; +} + +struct sctp_chunk *sctp_make_init_ack(const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + gfp_t gfp, int unkparam_len) +{ + struct sctp_supported_ext_param ext_param; + struct sctp_adaptation_ind_param aiparam; + struct sctp_paramhdr *auth_chunks = NULL; + struct sctp_paramhdr *auth_random = NULL; + struct sctp_paramhdr *auth_hmacs = NULL; + struct sctp_chunk *retval = NULL; + struct sctp_cookie_param *cookie; + struct sctp_inithdr initack; + union sctp_params addrs; + struct sctp_sock *sp; + __u8 extensions[5]; + size_t chunksize; + int num_ext = 0; + int cookie_len; + int addrs_len; + + /* Note: there may be no addresses to embed. */ + addrs = sctp_bind_addrs_to_raw(&asoc->base.bind_addr, &addrs_len, gfp); + + initack.init_tag = htonl(asoc->c.my_vtag); + initack.a_rwnd = htonl(asoc->rwnd); + initack.num_outbound_streams = htons(asoc->c.sinit_num_ostreams); + initack.num_inbound_streams = htons(asoc->c.sinit_max_instreams); + initack.initial_tsn = htonl(asoc->c.initial_tsn); + + /* FIXME: We really ought to build the cookie right + * into the packet instead of allocating more fresh memory. + */ + cookie = sctp_pack_cookie(asoc->ep, asoc, chunk, &cookie_len, + addrs.v, addrs_len); + if (!cookie) + goto nomem_cookie; + + /* Calculate the total size of allocation, include the reserved + * space for reporting unknown parameters if it is specified. + */ + sp = sctp_sk(asoc->base.sk); + chunksize = sizeof(initack) + addrs_len + cookie_len + unkparam_len; + + /* Tell peer that we'll do ECN only if peer advertised such cap. 
*/ + if (asoc->peer.ecn_capable) + chunksize += sizeof(ecap_param); + + if (asoc->peer.prsctp_capable) + chunksize += sizeof(prsctp_param); + + if (asoc->peer.asconf_capable) { + extensions[num_ext] = SCTP_CID_ASCONF; + extensions[num_ext+1] = SCTP_CID_ASCONF_ACK; + num_ext += 2; + } + + if (asoc->peer.reconf_capable) { + extensions[num_ext] = SCTP_CID_RECONF; + num_ext += 1; + } + + if (sp->adaptation_ind) + chunksize += sizeof(aiparam); + + if (asoc->peer.intl_capable) { + extensions[num_ext] = SCTP_CID_I_DATA; + num_ext += 1; + } + + if (asoc->peer.auth_capable) { + auth_random = (struct sctp_paramhdr *)asoc->c.auth_random; + chunksize += ntohs(auth_random->length); + + auth_hmacs = (struct sctp_paramhdr *)asoc->c.auth_hmacs; + if (auth_hmacs->length) + chunksize += SCTP_PAD4(ntohs(auth_hmacs->length)); + else + auth_hmacs = NULL; + + auth_chunks = (struct sctp_paramhdr *)asoc->c.auth_chunks; + if (auth_chunks->length) + chunksize += SCTP_PAD4(ntohs(auth_chunks->length)); + else + auth_chunks = NULL; + + extensions[num_ext] = SCTP_CID_AUTH; + num_ext += 1; + } + + if (num_ext) + chunksize += SCTP_PAD4(sizeof(ext_param) + num_ext); + + /* Now allocate and fill out the chunk. */ + retval = sctp_make_control(asoc, SCTP_CID_INIT_ACK, 0, chunksize, gfp); + if (!retval) + goto nomem_chunk; + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it received the DATA or control chunk + * to which it is replying. + * + * [INIT ACK back to where the INIT came from.] + */ + if (chunk->transport) + retval->transport = + sctp_assoc_lookup_paddr(asoc, + &chunk->transport->ipaddr); + + retval->subh.init_hdr = + sctp_addto_chunk(retval, sizeof(initack), &initack); + retval->param_hdr.v = sctp_addto_chunk(retval, addrs_len, addrs.v); + sctp_addto_chunk(retval, cookie_len, cookie); + if (asoc->peer.ecn_capable) + sctp_addto_chunk(retval, sizeof(ecap_param), &ecap_param); + if (num_ext) { + ext_param.param_hdr.type = SCTP_PARAM_SUPPORTED_EXT; + ext_param.param_hdr.length = htons(sizeof(ext_param) + num_ext); + sctp_addto_chunk(retval, sizeof(ext_param), &ext_param); + sctp_addto_param(retval, num_ext, extensions); + } + if (asoc->peer.prsctp_capable) + sctp_addto_chunk(retval, sizeof(prsctp_param), &prsctp_param); + + if (sp->adaptation_ind) { + aiparam.param_hdr.type = SCTP_PARAM_ADAPTATION_LAYER_IND; + aiparam.param_hdr.length = htons(sizeof(aiparam)); + aiparam.adaptation_ind = htonl(sp->adaptation_ind); + sctp_addto_chunk(retval, sizeof(aiparam), &aiparam); + } + + if (asoc->peer.auth_capable) { + sctp_addto_chunk(retval, ntohs(auth_random->length), + auth_random); + if (auth_hmacs) + sctp_addto_chunk(retval, ntohs(auth_hmacs->length), + auth_hmacs); + if (auth_chunks) + sctp_addto_chunk(retval, ntohs(auth_chunks->length), + auth_chunks); + } + + /* We need to remove the const qualifier at this point. */ + retval->asoc = (struct sctp_association *) asoc; + +nomem_chunk: + kfree(cookie); +nomem_cookie: + kfree(addrs.v); + return retval; +} + +/* 3.3.11 Cookie Echo (COOKIE ECHO) (10): + * + * This chunk is used only during the initialization of an association. + * It is sent by the initiator of an association to its peer to complete + * the initialization process. This chunk MUST precede any DATA chunk + * sent within the association, but MAY be bundled with one or more DATA + * chunks in the same packet. 
+ * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 10 |Chunk Flags | Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * / Cookie / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Chunk Flags: 8 bit + * + * Set to zero on transmit and ignored on receipt. + * + * Length: 16 bits (unsigned integer) + * + * Set to the size of the chunk in bytes, including the 4 bytes of + * the chunk header and the size of the Cookie. + * + * Cookie: variable size + * + * This field must contain the exact cookie received in the + * State Cookie parameter from the previous INIT ACK. + * + * An implementation SHOULD make the cookie as small as possible + * to insure interoperability. + */ +struct sctp_chunk *sctp_make_cookie_echo(const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_chunk *retval; + int cookie_len; + void *cookie; + + cookie = asoc->peer.cookie; + cookie_len = asoc->peer.cookie_len; + + /* Build a cookie echo chunk. */ + retval = sctp_make_control(asoc, SCTP_CID_COOKIE_ECHO, 0, + cookie_len, GFP_ATOMIC); + if (!retval) + goto nodata; + retval->subh.cookie_hdr = + sctp_addto_chunk(retval, cookie_len, cookie); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [COOKIE ECHO back to where the INIT ACK came from.] + */ + if (chunk) + retval->transport = chunk->transport; + +nodata: + return retval; +} + +/* 3.3.12 Cookie Acknowledgement (COOKIE ACK) (11): + * + * This chunk is used only during the initialization of an + * association. It is used to acknowledge the receipt of a COOKIE + * ECHO chunk. This chunk MUST precede any DATA or SACK chunk sent + * within the association, but MAY be bundled with one or more DATA + * chunks or SACK chunk in the same SCTP packet. + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 11 |Chunk Flags | Length = 4 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Chunk Flags: 8 bits + * + * Set to zero on transmit and ignored on receipt. + */ +struct sctp_chunk *sctp_make_cookie_ack(const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_COOKIE_ACK, 0, 0, GFP_ATOMIC); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [COOKIE ACK back to where the COOKIE ECHO came from.] + */ + if (retval && chunk && chunk->transport) + retval->transport = + sctp_assoc_lookup_paddr(asoc, + &chunk->transport->ipaddr); + + return retval; +} + +/* + * Appendix A: Explicit Congestion Notification: + * CWR: + * + * RFC 2481 details a specific bit for a sender to send in the header of + * its next outbound TCP segment to indicate to its peer that it has + * reduced its congestion window. This is termed the CWR bit. For + * SCTP the same indication is made by including the CWR chunk. 
+ * This chunk contains one data element, i.e. the TSN number that + * was sent in the ECNE chunk. This element represents the lowest + * TSN number in the datagram that was originally marked with the + * CE bit. + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Chunk Type=13 | Flags=00000000| Chunk Length = 8 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Lowest TSN Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Note: The CWR is considered a Control chunk. + */ +struct sctp_chunk *sctp_make_cwr(const struct sctp_association *asoc, + const __u32 lowest_tsn, + const struct sctp_chunk *chunk) +{ + struct sctp_chunk *retval; + struct sctp_cwrhdr cwr; + + cwr.lowest_tsn = htonl(lowest_tsn); + retval = sctp_make_control(asoc, SCTP_CID_ECN_CWR, 0, + sizeof(cwr), GFP_ATOMIC); + + if (!retval) + goto nodata; + + retval->subh.ecn_cwr_hdr = + sctp_addto_chunk(retval, sizeof(cwr), &cwr); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [Report a reduced congestion window back to where the ECNE + * came from.] + */ + if (chunk) + retval->transport = chunk->transport; + +nodata: + return retval; +} + +/* Make an ECNE chunk. This is a congestion experienced report. */ +struct sctp_chunk *sctp_make_ecne(const struct sctp_association *asoc, + const __u32 lowest_tsn) +{ + struct sctp_chunk *retval; + struct sctp_ecnehdr ecne; + + ecne.lowest_tsn = htonl(lowest_tsn); + retval = sctp_make_control(asoc, SCTP_CID_ECN_ECNE, 0, + sizeof(ecne), GFP_ATOMIC); + if (!retval) + goto nodata; + retval->subh.ecne_hdr = + sctp_addto_chunk(retval, sizeof(ecne), &ecne); + +nodata: + return retval; +} + +/* Make a DATA chunk for the given association from the provided + * parameters. However, do not populate the data payload. + */ +struct sctp_chunk *sctp_make_datafrag_empty(const struct sctp_association *asoc, + const struct sctp_sndrcvinfo *sinfo, + int len, __u8 flags, gfp_t gfp) +{ + struct sctp_chunk *retval; + struct sctp_datahdr dp; + + /* We assign the TSN as LATE as possible, not here when + * creating the chunk. + */ + memset(&dp, 0, sizeof(dp)); + dp.ppid = sinfo->sinfo_ppid; + dp.stream = htons(sinfo->sinfo_stream); + + /* Set the flags for an unordered send. */ + if (sinfo->sinfo_flags & SCTP_UNORDERED) + flags |= SCTP_DATA_UNORDERED; + + retval = sctp_make_data(asoc, flags, sizeof(dp) + len, gfp); + if (!retval) + return NULL; + + retval->subh.data_hdr = sctp_addto_chunk(retval, sizeof(dp), &dp); + memcpy(&retval->sinfo, sinfo, sizeof(struct sctp_sndrcvinfo)); + + return retval; +} + +/* Create a selective ackowledgement (SACK) for the given + * association. This reports on which TSN's we've seen to date, + * including duplicates and gaps. 
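[Illustrative sketch, not part of the patch] A compact illustration of the ECN round trip built by sctp_make_ecne()/sctp_make_cwr() above, assuming nothing beyond the two 4-byte payloads: the receiver's ECNE names the lowest CE-marked TSN, and the sender's CWR echoes the same TSN back once it has reduced its congestion window.

#include <stdint.h>
#include <stdio.h>
#include <arpa/inet.h>

struct ecne_payload { uint32_t lowest_tsn; };
struct cwr_payload  { uint32_t lowest_tsn; };

int main(void)
{
        uint32_t ce_marked_tsn = 0x12345678;    /* hypothetical TSN */
        struct ecne_payload ecne = { htonl(ce_marked_tsn) };
        struct cwr_payload cwr = { ecne.lowest_tsn };   /* echoed back verbatim */

        printf("ECNE/CWR carry TSN 0x%x\n", ntohl(cwr.lowest_tsn));
        return 0;
}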
+ */ +struct sctp_chunk *sctp_make_sack(struct sctp_association *asoc) +{ + struct sctp_tsnmap *map = (struct sctp_tsnmap *)&asoc->peer.tsn_map; + struct sctp_gap_ack_block gabs[SCTP_MAX_GABS]; + __u16 num_gabs, num_dup_tsns; + struct sctp_transport *trans; + struct sctp_chunk *retval; + struct sctp_sackhdr sack; + __u32 ctsn; + int len; + + memset(gabs, 0, sizeof(gabs)); + ctsn = sctp_tsnmap_get_ctsn(map); + + pr_debug("%s: sackCTSNAck sent:0x%x\n", __func__, ctsn); + + /* How much room is needed in the chunk? */ + num_gabs = sctp_tsnmap_num_gabs(map, gabs); + num_dup_tsns = sctp_tsnmap_num_dups(map); + + /* Initialize the SACK header. */ + sack.cum_tsn_ack = htonl(ctsn); + sack.a_rwnd = htonl(asoc->a_rwnd); + sack.num_gap_ack_blocks = htons(num_gabs); + sack.num_dup_tsns = htons(num_dup_tsns); + + len = sizeof(sack) + + sizeof(struct sctp_gap_ack_block) * num_gabs + + sizeof(__u32) * num_dup_tsns; + + /* Create the chunk. */ + retval = sctp_make_control(asoc, SCTP_CID_SACK, 0, len, GFP_ATOMIC); + if (!retval) + goto nodata; + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, etc.) to the same destination transport + * address from which it received the DATA or control chunk to + * which it is replying. This rule should also be followed if + * the endpoint is bundling DATA chunks together with the + * reply chunk. + * + * However, when acknowledging multiple DATA chunks received + * in packets from different source addresses in a single + * SACK, the SACK chunk may be transmitted to one of the + * destination transport addresses from which the DATA or + * control chunks being acknowledged were received. + * + * [BUG: We do not implement the following paragraph. + * Perhaps we should remember the last transport we used for a + * SACK and avoid that (if possible) if we have seen any + * duplicates. --piggy] + * + * When a receiver of a duplicate DATA chunk sends a SACK to a + * multi- homed endpoint it MAY be beneficial to vary the + * destination address and not use the source address of the + * DATA chunk. The reason being that receiving a duplicate + * from a multi-homed endpoint might indicate that the return + * path (as specified in the source address of the DATA chunk) + * for the SACK is broken. + * + * [Send to the address from which we last received a DATA chunk.] + */ + retval->transport = asoc->peer.last_data_from; + + retval->subh.sack_hdr = + sctp_addto_chunk(retval, sizeof(sack), &sack); + + /* Add the gap ack block information. */ + if (num_gabs) + sctp_addto_chunk(retval, sizeof(__u32) * num_gabs, + gabs); + + /* Add the duplicate TSN information. */ + if (num_dup_tsns) { + asoc->stats.idupchunks += num_dup_tsns; + sctp_addto_chunk(retval, sizeof(__u32) * num_dup_tsns, + sctp_tsnmap_get_dups(map)); + } + /* Once we have a sack generated, check to see what our sack + * generation is, if its 0, reset the transports to 0, and reset + * the association generation to 1 + * + * The idea is that zero is never used as a valid generation for the + * association so no transport will match after a wrap event like this, + * Until the next sack + */ + if (++asoc->peer.sack_generation == 0) { + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) + trans->sack_generation = 0; + asoc->peer.sack_generation = 1; + } +nodata: + return retval; +} + +/* Make a SHUTDOWN chunk. 
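[Illustrative sketch, not part of the patch] The SACK length computed above is pure wire arithmetic: a 12-byte header, 4 bytes per gap-ack block and 4 bytes per duplicate TSN. A standalone restatement (on-wire sizes, not kernel struct names):

#include <stdio.h>

#define SACK_HDR_LEN   12U      /* cum TSN ack + a_rwnd + two 16-bit counts */
#define GAB_LEN         4U      /* gap-ack block: 16-bit start + 16-bit end */
#define DUP_TSN_LEN     4U

static unsigned int sack_len(unsigned int num_gabs, unsigned int num_dup_tsns)
{
        return SACK_HDR_LEN + GAB_LEN * num_gabs + DUP_TSN_LEN * num_dup_tsns;
}

int main(void)
{
        printf("%u\n", sack_len(3, 2)); /* 12 + 12 + 8 = 32 bytes of SACK body */
        return 0;
}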
*/ +struct sctp_chunk *sctp_make_shutdown(const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_shutdownhdr shut; + struct sctp_chunk *retval; + __u32 ctsn; + + ctsn = sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map); + shut.cum_tsn_ack = htonl(ctsn); + + retval = sctp_make_control(asoc, SCTP_CID_SHUTDOWN, 0, + sizeof(shut), GFP_ATOMIC); + if (!retval) + goto nodata; + + retval->subh.shutdown_hdr = + sctp_addto_chunk(retval, sizeof(shut), &shut); + + if (chunk) + retval->transport = chunk->transport; +nodata: + return retval; +} + +struct sctp_chunk *sctp_make_shutdown_ack(const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_SHUTDOWN_ACK, 0, 0, + GFP_ATOMIC); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [ACK back to where the SHUTDOWN came from.] + */ + if (retval && chunk) + retval->transport = chunk->transport; + + return retval; +} + +struct sctp_chunk *sctp_make_shutdown_complete( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_chunk *retval; + __u8 flags = 0; + + /* Set the T-bit if we have no association (vtag will be + * reflected) + */ + flags |= asoc ? 0 : SCTP_CHUNK_FLAG_T; + + retval = sctp_make_control(asoc, SCTP_CID_SHUTDOWN_COMPLETE, flags, + 0, GFP_ATOMIC); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [Report SHUTDOWN COMPLETE back to where the SHUTDOWN ACK + * came from.] + */ + if (retval && chunk) + retval->transport = chunk->transport; + + return retval; +} + +/* Create an ABORT. Note that we set the T bit if we have no + * association, except when responding to an INIT (sctpimpguide 2.41). + */ +struct sctp_chunk *sctp_make_abort(const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + const size_t hint) +{ + struct sctp_chunk *retval; + __u8 flags = 0; + + /* Set the T-bit if we have no association and 'chunk' is not + * an INIT (vtag will be reflected). + */ + if (!asoc) { + if (chunk && chunk->chunk_hdr && + chunk->chunk_hdr->type == SCTP_CID_INIT) + flags = 0; + else + flags = SCTP_CHUNK_FLAG_T; + } + + retval = sctp_make_control(asoc, SCTP_CID_ABORT, flags, hint, + GFP_ATOMIC); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [ABORT back to where the offender came from.] + */ + if (retval && chunk) + retval->transport = chunk->transport; + + return retval; +} + +/* Helper to create ABORT with a NO_USER_DATA error. */ +struct sctp_chunk *sctp_make_abort_no_data( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + __u32 tsn) +{ + struct sctp_chunk *retval; + __be32 payload; + + retval = sctp_make_abort(asoc, chunk, + sizeof(struct sctp_errhdr) + sizeof(tsn)); + + if (!retval) + goto no_mem; + + /* Put the tsn back into network byte order. 
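[Illustrative sketch, not part of the patch] The T-bit logic in sctp_make_abort() above condenses to a two-input decision; here it is as a tiny standalone helper, with CHUNK_FLAG_T as a stand-in constant.

#include <stdint.h>

#define CHUNK_FLAG_T 0x01       /* stand-in for SCTP_CHUNK_FLAG_T */

/* Set the T bit ("verification tag is reflected") only when there is no
 * association for the packet, and even then not when the ABORT answers an
 * INIT, per the implementer's guide rule cited in the code.
 */
uint8_t abort_flags(int have_asoc, int replying_to_init)
{
        if (have_asoc)
                return 0;
        return replying_to_init ? 0 : CHUNK_FLAG_T;
}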
*/ + payload = htonl(tsn); + sctp_init_cause(retval, SCTP_ERROR_NO_DATA, sizeof(payload)); + sctp_addto_chunk(retval, sizeof(payload), (const void *)&payload); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [ABORT back to where the offender came from.] + */ + if (chunk) + retval->transport = chunk->transport; + +no_mem: + return retval; +} + +/* Helper to create ABORT with a SCTP_ERROR_USER_ABORT error. */ +struct sctp_chunk *sctp_make_abort_user(const struct sctp_association *asoc, + struct msghdr *msg, + size_t paylen) +{ + struct sctp_chunk *retval; + void *payload = NULL; + int err; + + retval = sctp_make_abort(asoc, NULL, + sizeof(struct sctp_errhdr) + paylen); + if (!retval) + goto err_chunk; + + if (paylen) { + /* Put the msg_iov together into payload. */ + payload = kmalloc(paylen, GFP_KERNEL); + if (!payload) + goto err_payload; + + err = memcpy_from_msg(payload, msg, paylen); + if (err < 0) + goto err_copy; + } + + sctp_init_cause(retval, SCTP_ERROR_USER_ABORT, paylen); + sctp_addto_chunk(retval, paylen, payload); + + if (paylen) + kfree(payload); + + return retval; + +err_copy: + kfree(payload); +err_payload: + sctp_chunk_free(retval); + retval = NULL; +err_chunk: + return retval; +} + +/* Append bytes to the end of a parameter. Will panic if chunk is not big + * enough. + */ +static void *sctp_addto_param(struct sctp_chunk *chunk, int len, + const void *data) +{ + int chunklen = ntohs(chunk->chunk_hdr->length); + void *target; + + target = skb_put(chunk->skb, len); + + if (data) + memcpy(target, data, len); + else + memset(target, 0, len); + + /* Adjust the chunk length field. */ + chunk->chunk_hdr->length = htons(chunklen + len); + chunk->chunk_end = skb_tail_pointer(chunk->skb); + + return target; +} + +/* Make an ABORT chunk with a PROTOCOL VIOLATION cause code. 
*/ +struct sctp_chunk *sctp_make_abort_violation( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + const __u8 *payload, + const size_t paylen) +{ + struct sctp_chunk *retval; + struct sctp_paramhdr phdr; + + retval = sctp_make_abort(asoc, chunk, sizeof(struct sctp_errhdr) + + paylen + sizeof(phdr)); + if (!retval) + goto end; + + sctp_init_cause(retval, SCTP_ERROR_PROTO_VIOLATION, paylen + + sizeof(phdr)); + + phdr.type = htons(chunk->chunk_hdr->type); + phdr.length = chunk->chunk_hdr->length; + sctp_addto_chunk(retval, paylen, payload); + sctp_addto_param(retval, sizeof(phdr), &phdr); + +end: + return retval; +} + +struct sctp_chunk *sctp_make_violation_paramlen( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + struct sctp_paramhdr *param) +{ + static const char error[] = "The following parameter had invalid length:"; + size_t payload_len = sizeof(error) + sizeof(struct sctp_errhdr) + + sizeof(*param); + struct sctp_chunk *retval; + + retval = sctp_make_abort(asoc, chunk, payload_len); + if (!retval) + goto nodata; + + sctp_init_cause(retval, SCTP_ERROR_PROTO_VIOLATION, + sizeof(error) + sizeof(*param)); + sctp_addto_chunk(retval, sizeof(error), error); + sctp_addto_param(retval, sizeof(*param), param); + +nodata: + return retval; +} + +struct sctp_chunk *sctp_make_violation_max_retrans( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + static const char error[] = "Association exceeded its max_retrans count"; + size_t payload_len = sizeof(error) + sizeof(struct sctp_errhdr); + struct sctp_chunk *retval; + + retval = sctp_make_abort(asoc, chunk, payload_len); + if (!retval) + goto nodata; + + sctp_init_cause(retval, SCTP_ERROR_PROTO_VIOLATION, sizeof(error)); + sctp_addto_chunk(retval, sizeof(error), error); + +nodata: + return retval; +} + +struct sctp_chunk *sctp_make_new_encap_port(const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_new_encap_port_hdr nep; + struct sctp_chunk *retval; + + retval = sctp_make_abort(asoc, chunk, + sizeof(struct sctp_errhdr) + sizeof(nep)); + if (!retval) + goto nodata; + + sctp_init_cause(retval, SCTP_ERROR_NEW_ENCAP_PORT, sizeof(nep)); + nep.cur_port = SCTP_INPUT_CB(chunk->skb)->encap_port; + nep.new_port = chunk->transport->encap_port; + sctp_addto_chunk(retval, sizeof(nep), &nep); + +nodata: + return retval; +} + +/* Make a HEARTBEAT chunk. */ +struct sctp_chunk *sctp_make_heartbeat(const struct sctp_association *asoc, + const struct sctp_transport *transport, + __u32 probe_size) +{ + struct sctp_sender_hb_info hbinfo = {}; + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_HEARTBEAT, 0, + sizeof(hbinfo), GFP_ATOMIC); + + if (!retval) + goto nodata; + + hbinfo.param_hdr.type = SCTP_PARAM_HEARTBEAT_INFO; + hbinfo.param_hdr.length = htons(sizeof(hbinfo)); + hbinfo.daddr = transport->ipaddr; + hbinfo.sent_at = jiffies; + hbinfo.hb_nonce = transport->hb_nonce; + hbinfo.probe_size = probe_size; + + /* Cast away the 'const', as this is just telling the chunk + * what transport it belongs to. 
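[Illustrative sketch, not part of the patch] A sketch of what the heartbeat info block built above is for, assuming (as in RFC 4960 heartbeat handling) that the peer echoes the opaque block back in its HEARTBEAT ACK: the nonce ties the ACK to the request, the timestamp gives an RTT sample, and a non-zero probe size marks a PLPMTUD probe. The struct below is illustrative, not the kernel's struct sctp_sender_hb_info.

#include <stdint.h>
#include <time.h>

struct hb_info {
        uint64_t nonce;         /* matched against the echoed copy */
        time_t   sent_at;       /* lets the sender take an RTT sample */
        uint32_t probe_size;    /* non-zero only for PLPMTUD probe heartbeats */
};

int hb_ack_matches(const struct hb_info *sent, const struct hb_info *echoed)
{
        return sent->nonce == echoed->nonce;
}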
+ */ + retval->transport = (struct sctp_transport *) transport; + retval->subh.hbs_hdr = sctp_addto_chunk(retval, sizeof(hbinfo), + &hbinfo); + retval->pmtu_probe = !!probe_size; + +nodata: + return retval; +} + +struct sctp_chunk *sctp_make_heartbeat_ack(const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + const void *payload, + const size_t paylen) +{ + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_HEARTBEAT_ACK, 0, paylen, + GFP_ATOMIC); + if (!retval) + goto nodata; + + retval->subh.hbs_hdr = sctp_addto_chunk(retval, paylen, payload); + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, * etc.) to the same destination transport + * address from which it * received the DATA or control chunk + * to which it is replying. + * + * [HBACK back to where the HEARTBEAT came from.] + */ + if (chunk) + retval->transport = chunk->transport; + +nodata: + return retval; +} + +/* RFC4820 3. Padding Chunk (PAD) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 0x84 | Flags=0 | Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | | + * \ Padding Data / + * / \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_pad(const struct sctp_association *asoc, int len) +{ + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_PAD, 0, len, GFP_ATOMIC); + if (!retval) + return NULL; + + skb_put_zero(retval->skb, len); + retval->chunk_hdr->length = htons(ntohs(retval->chunk_hdr->length) + len); + retval->chunk_end = skb_tail_pointer(retval->skb); + + return retval; +} + +/* Create an Operation Error chunk with the specified space reserved. + * This routine can be used for containing multiple causes in the chunk. + */ +static struct sctp_chunk *sctp_make_op_error_space( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + size_t size) +{ + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_ERROR, 0, + sizeof(struct sctp_errhdr) + size, + GFP_ATOMIC); + if (!retval) + goto nodata; + + /* RFC 2960 6.4 Multi-homed SCTP Endpoints + * + * An endpoint SHOULD transmit reply chunks (e.g., SACK, + * HEARTBEAT ACK, etc.) to the same destination transport + * address from which it received the DATA or control chunk + * to which it is replying. + * + */ + if (chunk) + retval->transport = chunk->transport; + +nodata: + return retval; +} + +/* Create an Operation Error chunk of a fixed size, specifically, + * min(asoc->pathmtu, SCTP_DEFAULT_MAXSEGMENT) - overheads. + * This is a helper function to allocate an error chunk for those + * invalid parameter codes in which we may not want to report all the + * errors, if the incoming chunk is large. If it can't fit in a single + * packet, we ignore it. + */ +static inline struct sctp_chunk *sctp_make_op_error_limited( + const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + size_t size = SCTP_DEFAULT_MAXSEGMENT; + struct sctp_sock *sp = NULL; + + if (asoc) { + size = min_t(size_t, size, asoc->pathmtu); + sp = sctp_sk(asoc->base.sk); + } + + size = sctp_mtu_payload(sp, size, sizeof(struct sctp_errhdr)); + + return sctp_make_op_error_space(asoc, chunk, size); +} + +/* Create an Operation Error chunk. 
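[Illustrative sketch, not part of the patch] sctp_make_op_error_limited() above keeps an ERROR chunk built for a bad parameter inside a single packet. A minimal restatement of that cap, with a placeholder overhead figure (the real code subtracts the actual IP/SCTP/chunk header sizes via sctp_mtu_payload()):

#include <stddef.h>

#define DEFAULT_MAXSEG  1500U   /* stand-in for SCTP_DEFAULT_MAXSEGMENT */
#define WIRE_OVERHEAD     48U   /* hypothetical IP + SCTP + chunk header overhead */

/* Room available for error causes: min(path MTU, default max segment)
 * minus the headers, never negative.
 */
size_t op_error_room(size_t pathmtu)
{
        size_t size = DEFAULT_MAXSEG;

        if (pathmtu && pathmtu < size)
                size = pathmtu;
        return size > WIRE_OVERHEAD ? size - WIRE_OVERHEAD : 0;
}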
*/ +struct sctp_chunk *sctp_make_op_error(const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + __be16 cause_code, const void *payload, + size_t paylen, size_t reserve_tail) +{ + struct sctp_chunk *retval; + + retval = sctp_make_op_error_space(asoc, chunk, paylen + reserve_tail); + if (!retval) + goto nodata; + + sctp_init_cause(retval, cause_code, paylen + reserve_tail); + sctp_addto_chunk(retval, paylen, payload); + if (reserve_tail) + sctp_addto_param(retval, reserve_tail, NULL); + +nodata: + return retval; +} + +struct sctp_chunk *sctp_make_auth(const struct sctp_association *asoc, + __u16 key_id) +{ + struct sctp_authhdr auth_hdr; + struct sctp_hmac *hmac_desc; + struct sctp_chunk *retval; + + /* Get the first hmac that the peer told us to use */ + hmac_desc = sctp_auth_asoc_get_hmac(asoc); + if (unlikely(!hmac_desc)) + return NULL; + + retval = sctp_make_control(asoc, SCTP_CID_AUTH, 0, + hmac_desc->hmac_len + sizeof(auth_hdr), + GFP_ATOMIC); + if (!retval) + return NULL; + + auth_hdr.hmac_id = htons(hmac_desc->hmac_id); + auth_hdr.shkey_id = htons(key_id); + + retval->subh.auth_hdr = sctp_addto_chunk(retval, sizeof(auth_hdr), + &auth_hdr); + + skb_put_zero(retval->skb, hmac_desc->hmac_len); + + /* Adjust the chunk header to include the empty MAC */ + retval->chunk_hdr->length = + htons(ntohs(retval->chunk_hdr->length) + hmac_desc->hmac_len); + retval->chunk_end = skb_tail_pointer(retval->skb); + + return retval; +} + + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* Turn an skb into a chunk. + * FIXME: Eventually move the structure directly inside the skb->cb[]. + * + * sctpimpguide-05.txt Section 2.8.2 + * M1) Each time a new DATA chunk is transmitted + * set the 'TSN.Missing.Report' count for that TSN to 0. The + * 'TSN.Missing.Report' count will be used to determine missing chunks + * and when to fast retransmit. + * + */ +struct sctp_chunk *sctp_chunkify(struct sk_buff *skb, + const struct sctp_association *asoc, + struct sock *sk, gfp_t gfp) +{ + struct sctp_chunk *retval; + + retval = kmem_cache_zalloc(sctp_chunk_cachep, gfp); + + if (!retval) + goto nodata; + if (!sk) + pr_debug("%s: chunkifying skb:%p w/o an sk\n", __func__, skb); + + INIT_LIST_HEAD(&retval->list); + retval->skb = skb; + retval->asoc = (struct sctp_association *)asoc; + retval->singleton = 1; + + retval->fast_retransmit = SCTP_CAN_FRTX; + + /* Polish the bead hole. */ + INIT_LIST_HEAD(&retval->transmitted_list); + INIT_LIST_HEAD(&retval->frag_list); + SCTP_DBG_OBJCNT_INC(chunk); + refcount_set(&retval->refcnt, 1); + +nodata: + return retval; +} + +/* Set chunk->source and dest based on the IP header in chunk->skb. */ +void sctp_init_addrs(struct sctp_chunk *chunk, union sctp_addr *src, + union sctp_addr *dest) +{ + memcpy(&chunk->source, src, sizeof(union sctp_addr)); + memcpy(&chunk->dest, dest, sizeof(union sctp_addr)); +} + +/* Extract the source address from a chunk. */ +const union sctp_addr *sctp_source(const struct sctp_chunk *chunk) +{ + /* If we have a known transport, use that. */ + if (chunk->transport) { + return &chunk->transport->ipaddr; + } else { + /* Otherwise, extract it from the IP header. */ + return &chunk->source; + } +} + +/* Create a new chunk, setting the type and flags headers from the + * arguments, reserving enough space for a 'paylen' byte payload. 
+ */ +static struct sctp_chunk *_sctp_make_chunk(const struct sctp_association *asoc, + __u8 type, __u8 flags, int paylen, + gfp_t gfp) +{ + struct sctp_chunkhdr *chunk_hdr; + struct sctp_chunk *retval; + struct sk_buff *skb; + struct sock *sk; + int chunklen; + + chunklen = SCTP_PAD4(sizeof(*chunk_hdr) + paylen); + if (chunklen > SCTP_MAX_CHUNK_LEN) + goto nodata; + + /* No need to allocate LL here, as this is only a chunk. */ + skb = alloc_skb(chunklen, gfp); + if (!skb) + goto nodata; + + /* Make room for the chunk header. */ + chunk_hdr = (struct sctp_chunkhdr *)skb_put(skb, sizeof(*chunk_hdr)); + chunk_hdr->type = type; + chunk_hdr->flags = flags; + chunk_hdr->length = htons(sizeof(*chunk_hdr)); + + sk = asoc ? asoc->base.sk : NULL; + retval = sctp_chunkify(skb, asoc, sk, gfp); + if (!retval) { + kfree_skb(skb); + goto nodata; + } + + retval->chunk_hdr = chunk_hdr; + retval->chunk_end = ((__u8 *)chunk_hdr) + sizeof(*chunk_hdr); + + /* Determine if the chunk needs to be authenticated */ + if (sctp_auth_send_cid(type, asoc)) + retval->auth = 1; + + return retval; +nodata: + return NULL; +} + +static struct sctp_chunk *sctp_make_data(const struct sctp_association *asoc, + __u8 flags, int paylen, gfp_t gfp) +{ + return _sctp_make_chunk(asoc, SCTP_CID_DATA, flags, paylen, gfp); +} + +struct sctp_chunk *sctp_make_idata(const struct sctp_association *asoc, + __u8 flags, int paylen, gfp_t gfp) +{ + return _sctp_make_chunk(asoc, SCTP_CID_I_DATA, flags, paylen, gfp); +} + +static struct sctp_chunk *sctp_make_control(const struct sctp_association *asoc, + __u8 type, __u8 flags, int paylen, + gfp_t gfp) +{ + struct sctp_chunk *chunk; + + chunk = _sctp_make_chunk(asoc, type, flags, paylen, gfp); + if (chunk) + sctp_control_set_owner_w(chunk); + + return chunk; +} + +/* Release the memory occupied by a chunk. */ +static void sctp_chunk_destroy(struct sctp_chunk *chunk) +{ + BUG_ON(!list_empty(&chunk->list)); + list_del_init(&chunk->transmitted_list); + + consume_skb(chunk->skb); + consume_skb(chunk->auth_chunk); + + SCTP_DBG_OBJCNT_DEC(chunk); + kmem_cache_free(sctp_chunk_cachep, chunk); +} + +/* Possibly, free the chunk. */ +void sctp_chunk_free(struct sctp_chunk *chunk) +{ + /* Release our reference on the message tracker. */ + if (chunk->msg) + sctp_datamsg_put(chunk->msg); + + sctp_chunk_put(chunk); +} + +/* Grab a reference to the chunk. */ +void sctp_chunk_hold(struct sctp_chunk *ch) +{ + refcount_inc(&ch->refcnt); +} + +/* Release a reference to the chunk. */ +void sctp_chunk_put(struct sctp_chunk *ch) +{ + if (refcount_dec_and_test(&ch->refcnt)) + sctp_chunk_destroy(ch); +} + +/* Append bytes to the end of a chunk. Will panic if chunk is not big + * enough. + */ +void *sctp_addto_chunk(struct sctp_chunk *chunk, int len, const void *data) +{ + int chunklen = ntohs(chunk->chunk_hdr->length); + int padlen = SCTP_PAD4(chunklen) - chunklen; + void *target; + + skb_put_zero(chunk->skb, padlen); + target = skb_put_data(chunk->skb, data, len); + + /* Adjust the chunk length field. */ + chunk->chunk_hdr->length = htons(chunklen + padlen + len); + chunk->chunk_end = skb_tail_pointer(chunk->skb); + + return target; +} + +/* Append bytes from user space to the end of a chunk. Will panic if + * chunk is not big enough. + * Returns a kernel err value. + */ +int sctp_user_addto_chunk(struct sctp_chunk *chunk, int len, + struct iov_iter *from) +{ + void *target; + + /* Make room in chunk for data. 
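[Illustrative sketch, not part of the patch] The append rule used by sctp_addto_chunk() above is worth restating on a plain buffer: pad the existing contents to a 4-byte boundary, append the new bytes, and count old + pad + new; the trailing pad of the newest piece is left for the next append (or for transmit time).

#include <stdint.h>
#include <stddef.h>
#include <string.h>

#define PAD4(len) (((len) + 3) & ~3U)

/* Returns the new logical length; buf must have room for
 * PAD4(cur_len) + len bytes.
 */
size_t chunk_append(uint8_t *buf, size_t cur_len, const void *data, size_t len)
{
        size_t padlen = PAD4(cur_len) - cur_len;

        memset(buf + cur_len, 0, padlen);
        memcpy(buf + cur_len + padlen, data, len);
        return cur_len + padlen + len;
}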
*/ + target = skb_put(chunk->skb, len); + + /* Copy data (whole iovec) into chunk */ + if (!copy_from_iter_full(target, len, from)) + return -EFAULT; + + /* Adjust the chunk length field. */ + chunk->chunk_hdr->length = + htons(ntohs(chunk->chunk_hdr->length) + len); + chunk->chunk_end = skb_tail_pointer(chunk->skb); + + return 0; +} + +/* Helper function to assign a TSN if needed. This assumes that both + * the data_hdr and association have already been assigned. + */ +void sctp_chunk_assign_ssn(struct sctp_chunk *chunk) +{ + struct sctp_stream *stream; + struct sctp_chunk *lchunk; + struct sctp_datamsg *msg; + __u16 ssn, sid; + + if (chunk->has_ssn) + return; + + /* All fragments will be on the same stream */ + sid = ntohs(chunk->subh.data_hdr->stream); + stream = &chunk->asoc->stream; + + /* Now assign the sequence number to the entire message. + * All fragments must have the same stream sequence number. + */ + msg = chunk->msg; + list_for_each_entry(lchunk, &msg->chunks, frag_list) { + if (lchunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) { + ssn = 0; + } else { + if (lchunk->chunk_hdr->flags & SCTP_DATA_LAST_FRAG) + ssn = sctp_ssn_next(stream, out, sid); + else + ssn = sctp_ssn_peek(stream, out, sid); + } + + lchunk->subh.data_hdr->ssn = htons(ssn); + lchunk->has_ssn = 1; + } +} + +/* Helper function to assign a TSN if needed. This assumes that both + * the data_hdr and association have already been assigned. + */ +void sctp_chunk_assign_tsn(struct sctp_chunk *chunk) +{ + if (!chunk->has_tsn) { + /* This is the last possible instant to + * assign a TSN. + */ + chunk->subh.data_hdr->tsn = + htonl(sctp_association_get_next_tsn(chunk->asoc)); + chunk->has_tsn = 1; + } +} + +/* Create a CLOSED association to use with an incoming packet. */ +struct sctp_association *sctp_make_temp_asoc(const struct sctp_endpoint *ep, + struct sctp_chunk *chunk, + gfp_t gfp) +{ + struct sctp_association *asoc; + enum sctp_scope scope; + struct sk_buff *skb; + + /* Create the bare association. */ + scope = sctp_scope(sctp_source(chunk)); + asoc = sctp_association_new(ep, ep->base.sk, scope, gfp); + if (!asoc) + goto nodata; + asoc->temp = 1; + skb = chunk->skb; + /* Create an entry for the source address of the packet. */ + SCTP_INPUT_CB(skb)->af->from_skb(&asoc->c.peer_addr, skb, 1); + +nodata: + return asoc; +} + +/* Build a cookie representing asoc. + * This INCLUDES the param header needed to put the cookie in the INIT ACK. + */ +static struct sctp_cookie_param *sctp_pack_cookie( + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const struct sctp_chunk *init_chunk, + int *cookie_len, const __u8 *raw_addrs, + int addrs_len) +{ + struct sctp_signed_cookie *cookie; + struct sctp_cookie_param *retval; + int headersize, bodysize; + + /* Header size is static data prior to the actual cookie, including + * any padding. + */ + headersize = sizeof(struct sctp_paramhdr) + + (sizeof(struct sctp_signed_cookie) - + sizeof(struct sctp_cookie)); + bodysize = sizeof(struct sctp_cookie) + + ntohs(init_chunk->chunk_hdr->length) + addrs_len; + + /* Pad out the cookie to a multiple to make the signature + * functions simpler to write. + */ + if (bodysize % SCTP_COOKIE_MULTIPLE) + bodysize += SCTP_COOKIE_MULTIPLE + - (bodysize % SCTP_COOKIE_MULTIPLE); + *cookie_len = headersize + bodysize; + + /* Clear this memory since we are sending this data structure + * out on the network. 
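[Illustrative sketch, not part of the patch] The cookie sizing just computed in sctp_pack_cookie() reads as one small formula: a fixed header (parameter header plus the signature part of the signed cookie), then a body holding the cookie proper, the peer's entire INIT chunk and the raw local address list, with the body rounded up to the signing block size so the HMAC code never sees a ragged tail. A standalone version, with COOKIE_MULTIPLE as a placeholder value:

#include <stddef.h>

#define COOKIE_MULTIPLE 32U     /* placeholder for SCTP_COOKIE_MULTIPLE */

size_t cookie_param_len(size_t headersize, size_t cookie_fixed,
                        size_t peer_init_len, size_t addrs_len)
{
        size_t bodysize = cookie_fixed + peer_init_len + addrs_len;

        if (bodysize % COOKIE_MULTIPLE)
                bodysize += COOKIE_MULTIPLE - bodysize % COOKIE_MULTIPLE;
        return headersize + bodysize;
}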
+ */ + retval = kzalloc(*cookie_len, GFP_ATOMIC); + if (!retval) + goto nodata; + + cookie = (struct sctp_signed_cookie *) retval->body; + + /* Set up the parameter header. */ + retval->p.type = SCTP_PARAM_STATE_COOKIE; + retval->p.length = htons(*cookie_len); + + /* Copy the cookie part of the association itself. */ + cookie->c = asoc->c; + /* Save the raw address list length in the cookie. */ + cookie->c.raw_addr_list_len = addrs_len; + + /* Remember PR-SCTP capability. */ + cookie->c.prsctp_capable = asoc->peer.prsctp_capable; + + /* Save adaptation indication in the cookie. */ + cookie->c.adaptation_ind = asoc->peer.adaptation_ind; + + /* Set an expiration time for the cookie. */ + cookie->c.expiration = ktime_add(asoc->cookie_life, + ktime_get_real()); + + /* Copy the peer's init packet. */ + memcpy(&cookie->c.peer_init[0], init_chunk->chunk_hdr, + ntohs(init_chunk->chunk_hdr->length)); + + /* Copy the raw local address list of the association. */ + memcpy((__u8 *)&cookie->c.peer_init[0] + + ntohs(init_chunk->chunk_hdr->length), raw_addrs, addrs_len); + + if (sctp_sk(ep->base.sk)->hmac) { + struct crypto_shash *tfm = sctp_sk(ep->base.sk)->hmac; + int err; + + /* Sign the message. */ + err = crypto_shash_setkey(tfm, ep->secret_key, + sizeof(ep->secret_key)) ?: + crypto_shash_tfm_digest(tfm, (u8 *)&cookie->c, bodysize, + cookie->signature); + if (err) + goto free_cookie; + } + + return retval; + +free_cookie: + kfree(retval); +nodata: + *cookie_len = 0; + return NULL; +} + +/* Unpack the cookie from COOKIE ECHO chunk, recreating the association. */ +struct sctp_association *sctp_unpack_cookie( + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, gfp_t gfp, + int *error, struct sctp_chunk **errp) +{ + struct sctp_association *retval = NULL; + int headersize, bodysize, fixed_size; + struct sctp_signed_cookie *cookie; + struct sk_buff *skb = chunk->skb; + struct sctp_cookie *bear_cookie; + __u8 *digest = ep->digest; + enum sctp_scope scope; + unsigned int len; + ktime_t kt; + + /* Header size is static data prior to the actual cookie, including + * any padding. + */ + headersize = sizeof(struct sctp_chunkhdr) + + (sizeof(struct sctp_signed_cookie) - + sizeof(struct sctp_cookie)); + bodysize = ntohs(chunk->chunk_hdr->length) - headersize; + fixed_size = headersize + sizeof(struct sctp_cookie); + + /* Verify that the chunk looks like it even has a cookie. + * There must be enough room for our cookie and our peer's + * INIT chunk. + */ + len = ntohs(chunk->chunk_hdr->length); + if (len < fixed_size + sizeof(struct sctp_chunkhdr)) + goto malformed; + + /* Verify that the cookie has been padded out. */ + if (bodysize % SCTP_COOKIE_MULTIPLE) + goto malformed; + + /* Process the cookie. */ + cookie = chunk->subh.cookie_hdr; + bear_cookie = &cookie->c; + + if (!sctp_sk(ep->base.sk)->hmac) + goto no_hmac; + + /* Check the signature. 
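[Illustrative sketch, not part of the patch] The signature check that follows is the mirror image of signing at cookie-pack time: recompute the keyed digest over the cookie body with the endpoint's secret and compare it against the signature the peer echoed back. In the sketch below keyed_digest() is a hypothetical stand-in for whatever keyed hash (HMAC) the endpoint uses and is only declared, not implemented.

#include <stdint.h>
#include <stddef.h>
#include <string.h>

#define SIG_LEN 32      /* placeholder for SCTP_SIGNATURE_SIZE */

void keyed_digest(const uint8_t *key, size_t key_len,
                  const uint8_t *msg, size_t msg_len, uint8_t out[SIG_LEN]);

/* A mismatch means the cookie was not minted by this endpoint (or was
 * tampered with), so the COOKIE ECHO is rejected.
 */
int cookie_sig_ok(const uint8_t *secret, size_t secret_len,
                  const uint8_t *cookie_body, size_t body_len,
                  const uint8_t signature[SIG_LEN])
{
        uint8_t digest[SIG_LEN];

        keyed_digest(secret, secret_len, cookie_body, body_len, digest);
        return memcmp(digest, signature, SIG_LEN) == 0;
}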
*/ + { + struct crypto_shash *tfm = sctp_sk(ep->base.sk)->hmac; + int err; + + err = crypto_shash_setkey(tfm, ep->secret_key, + sizeof(ep->secret_key)) ?: + crypto_shash_tfm_digest(tfm, (u8 *)bear_cookie, bodysize, + digest); + if (err) { + *error = -SCTP_IERROR_NOMEM; + goto fail; + } + } + + if (memcmp(digest, cookie->signature, SCTP_SIGNATURE_SIZE)) { + *error = -SCTP_IERROR_BAD_SIG; + goto fail; + } + +no_hmac: + /* IG Section 2.35.2: + * 3) Compare the port numbers and the verification tag contained + * within the COOKIE ECHO chunk to the actual port numbers and the + * verification tag within the SCTP common header of the received + * packet. If these values do not match the packet MUST be silently + * discarded, + */ + if (ntohl(chunk->sctp_hdr->vtag) != bear_cookie->my_vtag) { + *error = -SCTP_IERROR_BAD_TAG; + goto fail; + } + + if (chunk->sctp_hdr->source != bear_cookie->peer_addr.v4.sin_port || + ntohs(chunk->sctp_hdr->dest) != bear_cookie->my_port) { + *error = -SCTP_IERROR_BAD_PORTS; + goto fail; + } + + /* Check to see if the cookie is stale. If there is already + * an association, there is no need to check cookie's expiration + * for init collision case of lost COOKIE ACK. + * If skb has been timestamped, then use the stamp, otherwise + * use current time. This introduces a small possibility that + * a cookie may be considered expired, but this would only slow + * down the new association establishment instead of every packet. + */ + if (sock_flag(ep->base.sk, SOCK_TIMESTAMP)) + kt = skb_get_ktime(skb); + else + kt = ktime_get_real(); + + if (!asoc && ktime_before(bear_cookie->expiration, kt)) { + suseconds_t usecs = ktime_to_us(ktime_sub(kt, bear_cookie->expiration)); + __be32 n = htonl(usecs); + + /* + * Section 3.3.10.3 Stale Cookie Error (3) + * + * Cause of error + * --------------- + * Stale Cookie Error: Indicates the receipt of a valid State + * Cookie that has expired. + */ + *errp = sctp_make_op_error(asoc, chunk, + SCTP_ERROR_STALE_COOKIE, &n, + sizeof(n), 0); + if (*errp) + *error = -SCTP_IERROR_STALE_COOKIE; + else + *error = -SCTP_IERROR_NOMEM; + + goto fail; + } + + /* Make a new base association. */ + scope = sctp_scope(sctp_source(chunk)); + retval = sctp_association_new(ep, ep->base.sk, scope, gfp); + if (!retval) { + *error = -SCTP_IERROR_NOMEM; + goto fail; + } + + /* Set up our peer's port number. */ + retval->peer.port = ntohs(chunk->sctp_hdr->source); + + /* Populate the association from the cookie. */ + memcpy(&retval->c, bear_cookie, sizeof(*bear_cookie)); + + if (sctp_assoc_set_bind_addr_from_cookie(retval, bear_cookie, + GFP_ATOMIC) < 0) { + *error = -SCTP_IERROR_NOMEM; + goto fail; + } + + /* Also, add the destination address. */ + if (list_empty(&retval->base.bind_addr.address_list)) { + sctp_add_bind_addr(&retval->base.bind_addr, &chunk->dest, + sizeof(chunk->dest), SCTP_ADDR_SRC, + GFP_ATOMIC); + } + + retval->next_tsn = retval->c.initial_tsn; + retval->ctsn_ack_point = retval->next_tsn - 1; + retval->addip_serial = retval->c.initial_tsn; + retval->strreset_outseq = retval->c.initial_tsn; + retval->adv_peer_ack_point = retval->ctsn_ack_point; + retval->peer.prsctp_capable = retval->c.prsctp_capable; + retval->peer.adaptation_ind = retval->c.adaptation_ind; + + /* The INIT stuff will be done by the side effects. */ + return retval; + +fail: + if (retval) + sctp_association_free(retval); + + return NULL; + +malformed: + /* Yikes! The packet is either corrupt or deliberately + * malformed. 
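[Illustrative sketch, not part of the patch] A simplified view of the staleness test above, using plain microsecond integers instead of ktime_t: prefer the packet's receive timestamp when one was taken, otherwise use the current time, and report how far past its expiration the cookie is (the Stale Cookie error cause carries that measure in microseconds).

#include <stdint.h>
#include <time.h>

int64_t cookie_staleness_us(int64_t expiration_us, int64_t skb_stamp_us,
                            int have_stamp)
{
        int64_t now_us = have_stamp ? skb_stamp_us
                                    : (int64_t)time(NULL) * 1000000;

        return now_us > expiration_us ? now_us - expiration_us : 0;
}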
+ */ + *error = -SCTP_IERROR_MALFORMED; + goto fail; +} + +/******************************************************************** + * 3rd Level Abstractions + ********************************************************************/ + +struct __sctp_missing { + __be32 num_missing; + __be16 type; +} __packed; + +/* + * Report a missing mandatory parameter. + */ +static int sctp_process_missing_param(const struct sctp_association *asoc, + enum sctp_param paramtype, + struct sctp_chunk *chunk, + struct sctp_chunk **errp) +{ + struct __sctp_missing report; + __u16 len; + + len = SCTP_PAD4(sizeof(report)); + + /* Make an ERROR chunk, preparing enough room for + * returning multiple unknown parameters. + */ + if (!*errp) + *errp = sctp_make_op_error_space(asoc, chunk, len); + + if (*errp) { + report.num_missing = htonl(1); + report.type = paramtype; + sctp_init_cause(*errp, SCTP_ERROR_MISS_PARAM, + sizeof(report)); + sctp_addto_chunk(*errp, sizeof(report), &report); + } + + /* Stop processing this chunk. */ + return 0; +} + +/* Report an Invalid Mandatory Parameter. */ +static int sctp_process_inv_mandatory(const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_chunk **errp) +{ + /* Invalid Mandatory Parameter Error has no payload. */ + + if (!*errp) + *errp = sctp_make_op_error_space(asoc, chunk, 0); + + if (*errp) + sctp_init_cause(*errp, SCTP_ERROR_INV_PARAM, 0); + + /* Stop processing this chunk. */ + return 0; +} + +static int sctp_process_inv_paramlength(const struct sctp_association *asoc, + struct sctp_paramhdr *param, + const struct sctp_chunk *chunk, + struct sctp_chunk **errp) +{ + /* This is a fatal error. Any accumulated non-fatal errors are + * not reported. + */ + if (*errp) + sctp_chunk_free(*errp); + + /* Create an error chunk and fill it in with our payload. */ + *errp = sctp_make_violation_paramlen(asoc, chunk, param); + + return 0; +} + + +/* Do not attempt to handle the HOST_NAME parm. However, do + * send back an indicator to the peer. + */ +static int sctp_process_hn_param(const struct sctp_association *asoc, + union sctp_params param, + struct sctp_chunk *chunk, + struct sctp_chunk **errp) +{ + __u16 len = ntohs(param.p->length); + + /* Processing of the HOST_NAME parameter will generate an + * ABORT. If we've accumulated any non-fatal errors, they + * would be unrecognized parameters and we should not include + * them in the ABORT. + */ + if (*errp) + sctp_chunk_free(*errp); + + *errp = sctp_make_op_error(asoc, chunk, SCTP_ERROR_DNS_FAILED, + param.v, len, 0); + + /* Stop processing this chunk. */ + return 0; +} + +static int sctp_verify_ext_param(struct net *net, + const struct sctp_endpoint *ep, + union sctp_params param) +{ + __u16 num_ext = ntohs(param.p->length) - sizeof(struct sctp_paramhdr); + int have_asconf = 0; + int have_auth = 0; + int i; + + for (i = 0; i < num_ext; i++) { + switch (param.ext->chunks[i]) { + case SCTP_CID_AUTH: + have_auth = 1; + break; + case SCTP_CID_ASCONF: + case SCTP_CID_ASCONF_ACK: + have_asconf = 1; + break; + } + } + + /* ADD-IP Security: The draft requires us to ABORT or ignore the + * INIT/INIT-ACK if ADD-IP is listed, but AUTH is not. Do this + * only if ADD-IP is turned on and we are not backward-compatible + * mode. 
+ */ + if (net->sctp.addip_noauth) + return 1; + + if (ep->asconf_enable && !have_auth && have_asconf) + return 0; + + return 1; +} + +static void sctp_process_ext_param(struct sctp_association *asoc, + union sctp_params param) +{ + __u16 num_ext = ntohs(param.p->length) - sizeof(struct sctp_paramhdr); + int i; + + for (i = 0; i < num_ext; i++) { + switch (param.ext->chunks[i]) { + case SCTP_CID_RECONF: + if (asoc->ep->reconf_enable) + asoc->peer.reconf_capable = 1; + break; + case SCTP_CID_FWD_TSN: + if (asoc->ep->prsctp_enable) + asoc->peer.prsctp_capable = 1; + break; + case SCTP_CID_AUTH: + /* if the peer reports AUTH, assume that he + * supports AUTH. + */ + if (asoc->ep->auth_enable) + asoc->peer.auth_capable = 1; + break; + case SCTP_CID_ASCONF: + case SCTP_CID_ASCONF_ACK: + if (asoc->ep->asconf_enable) + asoc->peer.asconf_capable = 1; + break; + case SCTP_CID_I_DATA: + if (asoc->ep->intl_enable) + asoc->peer.intl_capable = 1; + break; + default: + break; + } + } +} + +/* RFC 3.2.1 & the Implementers Guide 2.2. + * + * The Parameter Types are encoded such that the + * highest-order two bits specify the action that must be + * taken if the processing endpoint does not recognize the + * Parameter Type. + * + * 00 - Stop processing this parameter; do not process any further + * parameters within this chunk + * + * 01 - Stop processing this parameter, do not process any further + * parameters within this chunk, and report the unrecognized + * parameter in an 'Unrecognized Parameter' ERROR chunk. + * + * 10 - Skip this parameter and continue processing. + * + * 11 - Skip this parameter and continue processing but + * report the unrecognized parameter in an + * 'Unrecognized Parameter' ERROR chunk. + * + * Return value: + * SCTP_IERROR_NO_ERROR - continue with the chunk + * SCTP_IERROR_ERROR - stop and report an error. + * SCTP_IERROR_NOMEME - out of memory. + */ +static enum sctp_ierror sctp_process_unk_param( + const struct sctp_association *asoc, + union sctp_params param, + struct sctp_chunk *chunk, + struct sctp_chunk **errp) +{ + int retval = SCTP_IERROR_NO_ERROR; + + switch (param.p->type & SCTP_PARAM_ACTION_MASK) { + case SCTP_PARAM_ACTION_DISCARD: + retval = SCTP_IERROR_ERROR; + break; + case SCTP_PARAM_ACTION_SKIP: + break; + case SCTP_PARAM_ACTION_DISCARD_ERR: + retval = SCTP_IERROR_ERROR; + fallthrough; + case SCTP_PARAM_ACTION_SKIP_ERR: + /* Make an ERROR chunk, preparing enough room for + * returning multiple unknown parameters. + */ + if (!*errp) { + *errp = sctp_make_op_error_limited(asoc, chunk); + if (!*errp) { + /* If there is no memory for generating the + * ERROR report as specified, an ABORT will be + * triggered to the peer and the association + * won't be established. 
+ */ + retval = SCTP_IERROR_NOMEM; + break; + } + } + + if (!sctp_init_cause(*errp, SCTP_ERROR_UNKNOWN_PARAM, + ntohs(param.p->length))) + sctp_addto_chunk(*errp, ntohs(param.p->length), + param.v); + break; + default: + break; + } + + return retval; +} + +/* Verify variable length parameters + * Return values: + * SCTP_IERROR_ABORT - trigger an ABORT + * SCTP_IERROR_NOMEM - out of memory (abort) + * SCTP_IERROR_ERROR - stop processing, trigger an ERROR + * SCTP_IERROR_NO_ERROR - continue with the chunk + */ +static enum sctp_ierror sctp_verify_param(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + union sctp_params param, + enum sctp_cid cid, + struct sctp_chunk *chunk, + struct sctp_chunk **err_chunk) +{ + struct sctp_hmac_algo_param *hmacs; + int retval = SCTP_IERROR_NO_ERROR; + __u16 n_elt, id = 0; + int i; + + /* FIXME - This routine is not looking at each parameter per the + * chunk type, i.e., unrecognized parameters should be further + * identified based on the chunk id. + */ + + switch (param.p->type) { + case SCTP_PARAM_IPV4_ADDRESS: + case SCTP_PARAM_IPV6_ADDRESS: + case SCTP_PARAM_COOKIE_PRESERVATIVE: + case SCTP_PARAM_SUPPORTED_ADDRESS_TYPES: + case SCTP_PARAM_STATE_COOKIE: + case SCTP_PARAM_HEARTBEAT_INFO: + case SCTP_PARAM_UNRECOGNIZED_PARAMETERS: + case SCTP_PARAM_ECN_CAPABLE: + case SCTP_PARAM_ADAPTATION_LAYER_IND: + break; + + case SCTP_PARAM_SUPPORTED_EXT: + if (!sctp_verify_ext_param(net, ep, param)) + return SCTP_IERROR_ABORT; + break; + + case SCTP_PARAM_SET_PRIMARY: + if (!ep->asconf_enable) + goto unhandled; + + if (ntohs(param.p->length) < sizeof(struct sctp_addip_param) + + sizeof(struct sctp_paramhdr)) { + sctp_process_inv_paramlength(asoc, param.p, + chunk, err_chunk); + retval = SCTP_IERROR_ABORT; + } + break; + + case SCTP_PARAM_HOST_NAME_ADDRESS: + /* Tell the peer, we won't support this param. */ + sctp_process_hn_param(asoc, param, chunk, err_chunk); + retval = SCTP_IERROR_ABORT; + break; + + case SCTP_PARAM_FWD_TSN_SUPPORT: + if (ep->prsctp_enable) + break; + goto unhandled; + + case SCTP_PARAM_RANDOM: + if (!ep->auth_enable) + goto unhandled; + + /* SCTP-AUTH: Secion 6.1 + * If the random number is not 32 byte long the association + * MUST be aborted. The ABORT chunk SHOULD contain the error + * cause 'Protocol Violation'. + */ + if (SCTP_AUTH_RANDOM_LENGTH != ntohs(param.p->length) - + sizeof(struct sctp_paramhdr)) { + sctp_process_inv_paramlength(asoc, param.p, + chunk, err_chunk); + retval = SCTP_IERROR_ABORT; + } + break; + + case SCTP_PARAM_CHUNKS: + if (!ep->auth_enable) + goto unhandled; + + /* SCTP-AUTH: Section 3.2 + * The CHUNKS parameter MUST be included once in the INIT or + * INIT-ACK chunk if the sender wants to receive authenticated + * chunks. Its maximum length is 260 bytes. + */ + if (260 < ntohs(param.p->length)) { + sctp_process_inv_paramlength(asoc, param.p, + chunk, err_chunk); + retval = SCTP_IERROR_ABORT; + } + break; + + case SCTP_PARAM_HMAC_ALGO: + if (!ep->auth_enable) + goto unhandled; + + hmacs = (struct sctp_hmac_algo_param *)param.p; + n_elt = (ntohs(param.p->length) - + sizeof(struct sctp_paramhdr)) >> 1; + + /* SCTP-AUTH: Section 6.1 + * The HMAC algorithm based on SHA-1 MUST be supported and + * included in the HMAC-ALGO parameter. 
+ */ + for (i = 0; i < n_elt; i++) { + id = ntohs(hmacs->hmac_ids[i]); + + if (id == SCTP_AUTH_HMAC_ID_SHA1) + break; + } + + if (id != SCTP_AUTH_HMAC_ID_SHA1) { + sctp_process_inv_paramlength(asoc, param.p, chunk, + err_chunk); + retval = SCTP_IERROR_ABORT; + } + break; +unhandled: + default: + pr_debug("%s: unrecognized param:%d for chunk:%d\n", + __func__, ntohs(param.p->type), cid); + + retval = sctp_process_unk_param(asoc, param, chunk, err_chunk); + break; + } + return retval; +} + +/* Verify the INIT packet before we process it. */ +int sctp_verify_init(struct net *net, const struct sctp_endpoint *ep, + const struct sctp_association *asoc, enum sctp_cid cid, + struct sctp_init_chunk *peer_init, + struct sctp_chunk *chunk, struct sctp_chunk **errp) +{ + union sctp_params param; + bool has_cookie = false; + int result; + + /* Check for missing mandatory parameters. Note: Initial TSN is + * also mandatory, but is not checked here since the valid range + * is 0..2**32-1. RFC4960, section 3.3.3. + */ + if (peer_init->init_hdr.num_outbound_streams == 0 || + peer_init->init_hdr.num_inbound_streams == 0 || + peer_init->init_hdr.init_tag == 0 || + ntohl(peer_init->init_hdr.a_rwnd) < SCTP_DEFAULT_MINWINDOW) + return sctp_process_inv_mandatory(asoc, chunk, errp); + + sctp_walk_params(param, peer_init, init_hdr.params) { + if (param.p->type == SCTP_PARAM_STATE_COOKIE) + has_cookie = true; + } + + /* There is a possibility that a parameter length was bad and + * in that case we would have stoped walking the parameters. + * The current param.p would point at the bad one. + * Current consensus on the mailing list is to generate a PROTOCOL + * VIOLATION error. We build the ERROR chunk here and let the normal + * error handling code build and send the packet. + */ + if (param.v != (void *)chunk->chunk_end) + return sctp_process_inv_paramlength(asoc, param.p, chunk, errp); + + /* The only missing mandatory param possible today is + * the state cookie for an INIT-ACK chunk. + */ + if ((SCTP_CID_INIT_ACK == cid) && !has_cookie) + return sctp_process_missing_param(asoc, SCTP_PARAM_STATE_COOKIE, + chunk, errp); + + /* Verify all the variable length parameters */ + sctp_walk_params(param, peer_init, init_hdr.params) { + result = sctp_verify_param(net, ep, asoc, param, cid, + chunk, errp); + switch (result) { + case SCTP_IERROR_ABORT: + case SCTP_IERROR_NOMEM: + return 0; + case SCTP_IERROR_ERROR: + return 1; + case SCTP_IERROR_NO_ERROR: + default: + break; + } + + } /* for (loop through all parameters) */ + + return 1; +} + +/* Unpack the parameters in an INIT packet into an association. + * Returns 0 on failure, else success. + * FIXME: This is an association method. + */ +int sctp_process_init(struct sctp_association *asoc, struct sctp_chunk *chunk, + const union sctp_addr *peer_addr, + struct sctp_init_chunk *peer_init, gfp_t gfp) +{ + struct sctp_transport *transport; + struct list_head *pos, *temp; + union sctp_params param; + union sctp_addr addr; + struct sctp_af *af; + int src_match = 0; + + /* We must include the address that the INIT packet came from. + * This is the only address that matters for an INIT packet. + * When processing a COOKIE ECHO, we retrieve the from address + * of the INIT from the cookie. + */ + + /* This implementation defaults to making the first transport + * added as the primary transport. The source address seems to + * be a better choice than any of the embedded addresses. 
+ */ + asoc->encap_port = SCTP_INPUT_CB(chunk->skb)->encap_port; + if (!sctp_assoc_add_peer(asoc, peer_addr, gfp, SCTP_ACTIVE)) + goto nomem; + + if (sctp_cmp_addr_exact(sctp_source(chunk), peer_addr)) + src_match = 1; + + /* Process the initialization parameters. */ + sctp_walk_params(param, peer_init, init_hdr.params) { + if (!src_match && + (param.p->type == SCTP_PARAM_IPV4_ADDRESS || + param.p->type == SCTP_PARAM_IPV6_ADDRESS)) { + af = sctp_get_af_specific(param_type2af(param.p->type)); + if (!af->from_addr_param(&addr, param.addr, + chunk->sctp_hdr->source, 0)) + continue; + if (sctp_cmp_addr_exact(sctp_source(chunk), &addr)) + src_match = 1; + } + + if (!sctp_process_param(asoc, param, peer_addr, gfp)) + goto clean_up; + } + + /* source address of chunk may not match any valid address */ + if (!src_match) + goto clean_up; + + /* AUTH: After processing the parameters, make sure that we + * have all the required info to potentially do authentications. + */ + if (asoc->peer.auth_capable && (!asoc->peer.peer_random || + !asoc->peer.peer_hmacs)) + asoc->peer.auth_capable = 0; + + /* In a non-backward compatible mode, if the peer claims + * support for ADD-IP but not AUTH, the ADD-IP spec states + * that we MUST ABORT the association. Section 6. The section + * also give us an option to silently ignore the packet, which + * is what we'll do here. + */ + if (!asoc->base.net->sctp.addip_noauth && + (asoc->peer.asconf_capable && !asoc->peer.auth_capable)) { + asoc->peer.addip_disabled_mask |= (SCTP_PARAM_ADD_IP | + SCTP_PARAM_DEL_IP | + SCTP_PARAM_SET_PRIMARY); + asoc->peer.asconf_capable = 0; + goto clean_up; + } + + /* Walk list of transports, removing transports in the UNKNOWN state. */ + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + transport = list_entry(pos, struct sctp_transport, transports); + if (transport->state == SCTP_UNKNOWN) { + sctp_assoc_rm_peer(asoc, transport); + } + } + + /* The fixed INIT headers are always in network byte + * order. + */ + asoc->peer.i.init_tag = + ntohl(peer_init->init_hdr.init_tag); + asoc->peer.i.a_rwnd = + ntohl(peer_init->init_hdr.a_rwnd); + asoc->peer.i.num_outbound_streams = + ntohs(peer_init->init_hdr.num_outbound_streams); + asoc->peer.i.num_inbound_streams = + ntohs(peer_init->init_hdr.num_inbound_streams); + asoc->peer.i.initial_tsn = + ntohl(peer_init->init_hdr.initial_tsn); + + asoc->strreset_inseq = asoc->peer.i.initial_tsn; + + /* Apply the upper bounds for output streams based on peer's + * number of inbound streams. + */ + if (asoc->c.sinit_num_ostreams > + ntohs(peer_init->init_hdr.num_inbound_streams)) { + asoc->c.sinit_num_ostreams = + ntohs(peer_init->init_hdr.num_inbound_streams); + } + + if (asoc->c.sinit_max_instreams > + ntohs(peer_init->init_hdr.num_outbound_streams)) { + asoc->c.sinit_max_instreams = + ntohs(peer_init->init_hdr.num_outbound_streams); + } + + /* Copy Initiation tag from INIT to VT_peer in cookie. */ + asoc->c.peer_vtag = asoc->peer.i.init_tag; + + /* Peer Rwnd : Current calculated value of the peer's rwnd. */ + asoc->peer.rwnd = asoc->peer.i.a_rwnd; + + /* RFC 2960 7.2.1 The initial value of ssthresh MAY be arbitrarily + * high (for example, implementations MAY use the size of the receiver + * advertised window). + */ + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + transport->ssthresh = asoc->peer.i.a_rwnd; + } + + /* Set up the TSN tracking pieces. 
*/ + if (!sctp_tsnmap_init(&asoc->peer.tsn_map, SCTP_TSN_MAP_INITIAL, + asoc->peer.i.initial_tsn, gfp)) + goto clean_up; + + /* RFC 2960 6.5 Stream Identifier and Stream Sequence Number + * + * The stream sequence number in all the streams shall start + * from 0 when the association is established. Also, when the + * stream sequence number reaches the value 65535 the next + * stream sequence number shall be set to 0. + */ + + if (sctp_stream_init(&asoc->stream, asoc->c.sinit_num_ostreams, + asoc->c.sinit_max_instreams, gfp)) + goto clean_up; + + /* Update frag_point when stream_interleave may get changed. */ + sctp_assoc_update_frag_point(asoc); + + if (!asoc->temp && sctp_assoc_set_id(asoc, gfp)) + goto clean_up; + + /* ADDIP Section 4.1 ASCONF Chunk Procedures + * + * When an endpoint has an ASCONF signaled change to be sent to the + * remote endpoint it should do the following: + * ... + * A2) A serial number should be assigned to the Chunk. The serial + * number should be a monotonically increasing number. All serial + * numbers are defined to be initialized at the start of the + * association to the same value as the Initial TSN. + */ + asoc->peer.addip_serial = asoc->peer.i.initial_tsn - 1; + return 1; + +clean_up: + /* Release the transport structures. */ + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + transport = list_entry(pos, struct sctp_transport, transports); + if (transport->state != SCTP_ACTIVE) + sctp_assoc_rm_peer(asoc, transport); + } + +nomem: + return 0; +} + + +/* Update asoc with the option described in param. + * + * RFC2960 3.3.2.1 Optional/Variable Length Parameters in INIT + * + * asoc is the association to update. + * param is the variable length parameter to use for update. + * cid tells us if this is an INIT, INIT ACK or COOKIE ECHO. + * If the current packet is an INIT we want to minimize the amount of + * work we do. In particular, we should not build transport + * structures for the addresses. + */ +static int sctp_process_param(struct sctp_association *asoc, + union sctp_params param, + const union sctp_addr *peer_addr, + gfp_t gfp) +{ + struct sctp_endpoint *ep = asoc->ep; + union sctp_addr_param *addr_param; + struct net *net = asoc->base.net; + struct sctp_transport *t; + enum sctp_scope scope; + union sctp_addr addr; + struct sctp_af *af; + int retval = 1, i; + u32 stale; + __u16 sat; + + /* We maintain all INIT parameters in network byte order all the + * time. This allows us to not worry about whether the parameters + * came from a fresh INIT, and INIT ACK, or were stored in a cookie. + */ + switch (param.p->type) { + case SCTP_PARAM_IPV6_ADDRESS: + if (PF_INET6 != asoc->base.sk->sk_family) + break; + goto do_addr_param; + + case SCTP_PARAM_IPV4_ADDRESS: + /* v4 addresses are not allowed on v6-only socket */ + if (ipv6_only_sock(asoc->base.sk)) + break; +do_addr_param: + af = sctp_get_af_specific(param_type2af(param.p->type)); + if (!af->from_addr_param(&addr, param.addr, htons(asoc->peer.port), 0)) + break; + scope = sctp_scope(peer_addr); + if (sctp_in_scope(net, &addr, scope)) + if (!sctp_assoc_add_peer(asoc, &addr, gfp, SCTP_UNCONFIRMED)) + return 0; + break; + + case SCTP_PARAM_COOKIE_PRESERVATIVE: + if (!net->sctp.cookie_preserve_enable) + break; + + stale = ntohl(param.life->lifespan_increment); + + /* Suggested Cookie Life span increment's unit is msec, + * (1/1000sec). 
+ */ + asoc->cookie_life = ktime_add_ms(asoc->cookie_life, stale); + break; + + case SCTP_PARAM_HOST_NAME_ADDRESS: + pr_debug("%s: unimplemented SCTP_HOST_NAME_ADDRESS\n", __func__); + break; + + case SCTP_PARAM_SUPPORTED_ADDRESS_TYPES: + /* Turn off the default values first so we'll know which + * ones are really set by the peer. + */ + asoc->peer.ipv4_address = 0; + asoc->peer.ipv6_address = 0; + + /* Assume that peer supports the address family + * by which it sends a packet. + */ + if (peer_addr->sa.sa_family == AF_INET6) + asoc->peer.ipv6_address = 1; + else if (peer_addr->sa.sa_family == AF_INET) + asoc->peer.ipv4_address = 1; + + /* Cycle through address types; avoid divide by 0. */ + sat = ntohs(param.p->length) - sizeof(struct sctp_paramhdr); + if (sat) + sat /= sizeof(__u16); + + for (i = 0; i < sat; ++i) { + switch (param.sat->types[i]) { + case SCTP_PARAM_IPV4_ADDRESS: + asoc->peer.ipv4_address = 1; + break; + + case SCTP_PARAM_IPV6_ADDRESS: + if (PF_INET6 == asoc->base.sk->sk_family) + asoc->peer.ipv6_address = 1; + break; + + case SCTP_PARAM_HOST_NAME_ADDRESS: + asoc->peer.hostname_address = 1; + break; + + default: /* Just ignore anything else. */ + break; + } + } + break; + + case SCTP_PARAM_STATE_COOKIE: + asoc->peer.cookie_len = + ntohs(param.p->length) - sizeof(struct sctp_paramhdr); + kfree(asoc->peer.cookie); + asoc->peer.cookie = kmemdup(param.cookie->body, asoc->peer.cookie_len, gfp); + if (!asoc->peer.cookie) + retval = 0; + break; + + case SCTP_PARAM_HEARTBEAT_INFO: + /* Would be odd to receive, but it causes no problems. */ + break; + + case SCTP_PARAM_UNRECOGNIZED_PARAMETERS: + /* Rejected during verify stage. */ + break; + + case SCTP_PARAM_ECN_CAPABLE: + if (asoc->ep->ecn_enable) { + asoc->peer.ecn_capable = 1; + break; + } + /* Fall Through */ + goto fall_through; + + + case SCTP_PARAM_ADAPTATION_LAYER_IND: + asoc->peer.adaptation_ind = ntohl(param.aind->adaptation_ind); + break; + + case SCTP_PARAM_SET_PRIMARY: + if (!ep->asconf_enable) + goto fall_through; + + addr_param = param.v + sizeof(struct sctp_addip_param); + + af = sctp_get_af_specific(param_type2af(addr_param->p.type)); + if (!af) + break; + + if (!af->from_addr_param(&addr, addr_param, + htons(asoc->peer.port), 0)) + break; + + if (!af->addr_valid(&addr, NULL, NULL)) + break; + + t = sctp_assoc_lookup_paddr(asoc, &addr); + if (!t) + break; + + sctp_assoc_set_primary(asoc, t); + break; + + case SCTP_PARAM_SUPPORTED_EXT: + sctp_process_ext_param(asoc, param); + break; + + case SCTP_PARAM_FWD_TSN_SUPPORT: + if (asoc->ep->prsctp_enable) { + asoc->peer.prsctp_capable = 1; + break; + } + /* Fall Through */ + goto fall_through; + + case SCTP_PARAM_RANDOM: + if (!ep->auth_enable) + goto fall_through; + + /* Save peer's random parameter */ + kfree(asoc->peer.peer_random); + asoc->peer.peer_random = kmemdup(param.p, + ntohs(param.p->length), gfp); + if (!asoc->peer.peer_random) { + retval = 0; + break; + } + break; + + case SCTP_PARAM_HMAC_ALGO: + if (!ep->auth_enable) + goto fall_through; + + /* Save peer's HMAC list */ + kfree(asoc->peer.peer_hmacs); + asoc->peer.peer_hmacs = kmemdup(param.p, + ntohs(param.p->length), gfp); + if (!asoc->peer.peer_hmacs) { + retval = 0; + break; + } + + /* Set the default HMAC the peer requested*/ + sctp_auth_asoc_set_default_hmac(asoc, param.hmac_algo); + break; + + case SCTP_PARAM_CHUNKS: + if (!ep->auth_enable) + goto fall_through; + + kfree(asoc->peer.peer_chunks); + asoc->peer.peer_chunks = kmemdup(param.p, + ntohs(param.p->length), gfp); + if 
(!asoc->peer.peer_chunks) + retval = 0; + break; +fall_through: + default: + /* Any unrecognized parameters should have been caught + * and handled by sctp_verify_param() which should be + * called prior to this routine. Simply log the error + * here. + */ + pr_debug("%s: ignoring param:%d for association:%p.\n", + __func__, ntohs(param.p->type), asoc); + break; + } + + return retval; +} + +/* Select a new verification tag. */ +__u32 sctp_generate_tag(const struct sctp_endpoint *ep) +{ + /* I believe that this random number generator complies with RFC1750. + * A tag of 0 is reserved for special cases (e.g. INIT). + */ + __u32 x; + + do { + get_random_bytes(&x, sizeof(__u32)); + } while (x == 0); + + return x; +} + +/* Select an initial TSN to send during startup. */ +__u32 sctp_generate_tsn(const struct sctp_endpoint *ep) +{ + __u32 retval; + + get_random_bytes(&retval, sizeof(__u32)); + return retval; +} + +/* + * ADDIP 3.1.1 Address Configuration Change Chunk (ASCONF) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 0xC1 | Chunk Flags | Chunk Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Serial Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Address Parameter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF Parameter #1 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / .... / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF Parameter #N | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Address Parameter and other parameter will not be wrapped in this function + */ +static struct sctp_chunk *sctp_make_asconf(struct sctp_association *asoc, + union sctp_addr *addr, + int vparam_len) +{ + struct sctp_addiphdr asconf; + struct sctp_chunk *retval; + int length = sizeof(asconf) + vparam_len; + union sctp_addr_param addrparam; + int addrlen; + struct sctp_af *af = sctp_get_af_specific(addr->v4.sin_family); + + addrlen = af->to_addr_param(addr, &addrparam); + if (!addrlen) + return NULL; + length += addrlen; + + /* Create the chunk. 
*/ + retval = sctp_make_control(asoc, SCTP_CID_ASCONF, 0, length, + GFP_ATOMIC); + if (!retval) + return NULL; + + asconf.serial = htonl(asoc->addip_serial++); + + retval->subh.addip_hdr = + sctp_addto_chunk(retval, sizeof(asconf), &asconf); + retval->param_hdr.v = + sctp_addto_chunk(retval, addrlen, &addrparam); + + return retval; +} + +/* ADDIP + * 3.2.1 Add IP Address + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 0xC001 | Length = Variable | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF-Request Correlation ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Address Parameter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * 3.2.2 Delete IP Address + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 0xC002 | Length = Variable | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF-Request Correlation ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Address Parameter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + */ +struct sctp_chunk *sctp_make_asconf_update_ip(struct sctp_association *asoc, + union sctp_addr *laddr, + struct sockaddr *addrs, + int addrcnt, __be16 flags) +{ + union sctp_addr_param addr_param; + struct sctp_addip_param param; + int paramlen = sizeof(param); + struct sctp_chunk *retval; + int addr_param_len = 0; + union sctp_addr *addr; + int totallen = 0, i; + int del_pickup = 0; + struct sctp_af *af; + void *addr_buf; + + /* Get total length of all the address parameters. */ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + addr = addr_buf; + af = sctp_get_af_specific(addr->v4.sin_family); + addr_param_len = af->to_addr_param(addr, &addr_param); + + totallen += paramlen; + totallen += addr_param_len; + + addr_buf += af->sockaddr_len; + if (asoc->asconf_addr_del_pending && !del_pickup) { + /* reuse the parameter length from the same scope one */ + totallen += paramlen; + totallen += addr_param_len; + del_pickup = 1; + + pr_debug("%s: picked same-scope del_pending addr, " + "totallen for all addresses is %d\n", + __func__, totallen); + } + } + + /* Create an asconf chunk with the required length. */ + retval = sctp_make_asconf(asoc, laddr, totallen); + if (!retval) + return NULL; + + /* Add the address parameters to the asconf chunk. 
*/ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + addr = addr_buf; + af = sctp_get_af_specific(addr->v4.sin_family); + addr_param_len = af->to_addr_param(addr, &addr_param); + param.param_hdr.type = flags; + param.param_hdr.length = htons(paramlen + addr_param_len); + param.crr_id = htonl(i); + + sctp_addto_chunk(retval, paramlen, ¶m); + sctp_addto_chunk(retval, addr_param_len, &addr_param); + + addr_buf += af->sockaddr_len; + } + if (flags == SCTP_PARAM_ADD_IP && del_pickup) { + addr = asoc->asconf_addr_del_pending; + af = sctp_get_af_specific(addr->v4.sin_family); + addr_param_len = af->to_addr_param(addr, &addr_param); + param.param_hdr.type = SCTP_PARAM_DEL_IP; + param.param_hdr.length = htons(paramlen + addr_param_len); + param.crr_id = htonl(i); + + sctp_addto_chunk(retval, paramlen, ¶m); + sctp_addto_chunk(retval, addr_param_len, &addr_param); + } + return retval; +} + +/* ADDIP + * 3.2.4 Set Primary IP Address + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type =0xC004 | Length = Variable | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF-Request Correlation ID | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Address Parameter | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Create an ASCONF chunk with Set Primary IP address parameter. + */ +struct sctp_chunk *sctp_make_asconf_set_prim(struct sctp_association *asoc, + union sctp_addr *addr) +{ + struct sctp_af *af = sctp_get_af_specific(addr->v4.sin_family); + union sctp_addr_param addrparam; + struct sctp_addip_param param; + struct sctp_chunk *retval; + int len = sizeof(param); + int addrlen; + + addrlen = af->to_addr_param(addr, &addrparam); + if (!addrlen) + return NULL; + len += addrlen; + + /* Create the chunk and make asconf header. */ + retval = sctp_make_asconf(asoc, addr, len); + if (!retval) + return NULL; + + param.param_hdr.type = SCTP_PARAM_SET_PRIMARY; + param.param_hdr.length = htons(len); + param.crr_id = 0; + + sctp_addto_chunk(retval, sizeof(param), ¶m); + sctp_addto_chunk(retval, addrlen, &addrparam); + + return retval; +} + +/* ADDIP 3.1.2 Address Configuration Acknowledgement Chunk (ASCONF-ACK) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 0x80 | Chunk Flags | Chunk Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Serial Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF Parameter Response#1 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / .... / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | ASCONF Parameter Response#N | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Create an ASCONF_ACK chunk with enough space for the parameter responses. + */ +static struct sctp_chunk *sctp_make_asconf_ack(const struct sctp_association *asoc, + __u32 serial, int vparam_len) +{ + struct sctp_addiphdr asconf; + struct sctp_chunk *retval; + int length = sizeof(asconf) + vparam_len; + + /* Create the chunk. 
*/ + retval = sctp_make_control(asoc, SCTP_CID_ASCONF_ACK, 0, length, + GFP_ATOMIC); + if (!retval) + return NULL; + + asconf.serial = htonl(serial); + + retval->subh.addip_hdr = + sctp_addto_chunk(retval, sizeof(asconf), &asconf); + + return retval; +} + +/* Add response parameters to an ASCONF_ACK chunk. */ +static void sctp_add_asconf_response(struct sctp_chunk *chunk, __be32 crr_id, + __be16 err_code, + struct sctp_addip_param *asconf_param) +{ + struct sctp_addip_param ack_param; + struct sctp_errhdr err_param; + int asconf_param_len = 0; + int err_param_len = 0; + __be16 response_type; + + if (SCTP_ERROR_NO_ERROR == err_code) { + response_type = SCTP_PARAM_SUCCESS_REPORT; + } else { + response_type = SCTP_PARAM_ERR_CAUSE; + err_param_len = sizeof(err_param); + if (asconf_param) + asconf_param_len = + ntohs(asconf_param->param_hdr.length); + } + + /* Add Success Indication or Error Cause Indication parameter. */ + ack_param.param_hdr.type = response_type; + ack_param.param_hdr.length = htons(sizeof(ack_param) + + err_param_len + + asconf_param_len); + ack_param.crr_id = crr_id; + sctp_addto_chunk(chunk, sizeof(ack_param), &ack_param); + + if (SCTP_ERROR_NO_ERROR == err_code) + return; + + /* Add Error Cause parameter. */ + err_param.cause = err_code; + err_param.length = htons(err_param_len + asconf_param_len); + sctp_addto_chunk(chunk, err_param_len, &err_param); + + /* Add the failed TLV copied from ASCONF chunk. */ + if (asconf_param) + sctp_addto_chunk(chunk, asconf_param_len, asconf_param); +} + +/* Process a asconf parameter. */ +static __be16 sctp_process_asconf_param(struct sctp_association *asoc, + struct sctp_chunk *asconf, + struct sctp_addip_param *asconf_param) +{ + union sctp_addr_param *addr_param; + struct sctp_transport *peer; + union sctp_addr addr; + struct sctp_af *af; + + addr_param = (void *)asconf_param + sizeof(*asconf_param); + + if (asconf_param->param_hdr.type != SCTP_PARAM_ADD_IP && + asconf_param->param_hdr.type != SCTP_PARAM_DEL_IP && + asconf_param->param_hdr.type != SCTP_PARAM_SET_PRIMARY) + return SCTP_ERROR_UNKNOWN_PARAM; + + switch (addr_param->p.type) { + case SCTP_PARAM_IPV6_ADDRESS: + if (!asoc->peer.ipv6_address) + return SCTP_ERROR_DNS_FAILED; + break; + case SCTP_PARAM_IPV4_ADDRESS: + if (!asoc->peer.ipv4_address) + return SCTP_ERROR_DNS_FAILED; + break; + default: + return SCTP_ERROR_DNS_FAILED; + } + + af = sctp_get_af_specific(param_type2af(addr_param->p.type)); + if (unlikely(!af)) + return SCTP_ERROR_DNS_FAILED; + + if (!af->from_addr_param(&addr, addr_param, htons(asoc->peer.port), 0)) + return SCTP_ERROR_DNS_FAILED; + + /* ADDIP 4.2.1 This parameter MUST NOT contain a broadcast + * or multicast address. + * (note: wildcard is permitted and requires special handling so + * make sure we check for that) + */ + if (!af->is_any(&addr) && !af->addr_valid(&addr, NULL, asconf->skb)) + return SCTP_ERROR_DNS_FAILED; + + switch (asconf_param->param_hdr.type) { + case SCTP_PARAM_ADD_IP: + /* Section 4.2.1: + * If the address 0.0.0.0 or ::0 is provided, the source + * address of the packet MUST be added. 
+ */ + if (af->is_any(&addr)) + memcpy(&addr, &asconf->source, sizeof(addr)); + + if (security_sctp_bind_connect(asoc->ep->base.sk, + SCTP_PARAM_ADD_IP, + (struct sockaddr *)&addr, + af->sockaddr_len)) + return SCTP_ERROR_REQ_REFUSED; + + /* ADDIP 4.3 D9) If an endpoint receives an ADD IP address + * request and does not have the local resources to add this + * new address to the association, it MUST return an Error + * Cause TLV set to the new error code 'Operation Refused + * Due to Resource Shortage'. + */ + + peer = sctp_assoc_add_peer(asoc, &addr, GFP_ATOMIC, SCTP_UNCONFIRMED); + if (!peer) + return SCTP_ERROR_RSRC_LOW; + + /* Start the heartbeat timer. */ + sctp_transport_reset_hb_timer(peer); + asoc->new_transport = peer; + break; + case SCTP_PARAM_DEL_IP: + /* ADDIP 4.3 D7) If a request is received to delete the + * last remaining IP address of a peer endpoint, the receiver + * MUST send an Error Cause TLV with the error cause set to the + * new error code 'Request to Delete Last Remaining IP Address'. + */ + if (asoc->peer.transport_count == 1) + return SCTP_ERROR_DEL_LAST_IP; + + /* ADDIP 4.3 D8) If a request is received to delete an IP + * address which is also the source address of the IP packet + * which contained the ASCONF chunk, the receiver MUST reject + * this request. To reject the request the receiver MUST send + * an Error Cause TLV set to the new error code 'Request to + * Delete Source IP Address' + */ + if (sctp_cmp_addr_exact(&asconf->source, &addr)) + return SCTP_ERROR_DEL_SRC_IP; + + /* Section 4.2.2 + * If the address 0.0.0.0 or ::0 is provided, all + * addresses of the peer except the source address of the + * packet MUST be deleted. + */ + if (af->is_any(&addr)) { + sctp_assoc_set_primary(asoc, asconf->transport); + sctp_assoc_del_nonprimary_peers(asoc, + asconf->transport); + return SCTP_ERROR_NO_ERROR; + } + + /* If the address is not part of the association, the + * ASCONF-ACK with Error Cause Indication Parameter + * which including cause of Unresolvable Address should + * be sent. + */ + peer = sctp_assoc_lookup_paddr(asoc, &addr); + if (!peer) + return SCTP_ERROR_DNS_FAILED; + + sctp_assoc_rm_peer(asoc, peer); + break; + case SCTP_PARAM_SET_PRIMARY: + /* ADDIP Section 4.2.4 + * If the address 0.0.0.0 or ::0 is provided, the receiver + * MAY mark the source address of the packet as its + * primary. + */ + if (af->is_any(&addr)) + memcpy(&addr, sctp_source(asconf), sizeof(addr)); + + if (security_sctp_bind_connect(asoc->ep->base.sk, + SCTP_PARAM_SET_PRIMARY, + (struct sockaddr *)&addr, + af->sockaddr_len)) + return SCTP_ERROR_REQ_REFUSED; + + peer = sctp_assoc_lookup_paddr(asoc, &addr); + if (!peer) + return SCTP_ERROR_DNS_FAILED; + + sctp_assoc_set_primary(asoc, peer); + break; + } + + return SCTP_ERROR_NO_ERROR; +} + +/* Verify the ASCONF packet before we process it. */ +bool sctp_verify_asconf(const struct sctp_association *asoc, + struct sctp_chunk *chunk, bool addr_param_needed, + struct sctp_paramhdr **errp) +{ + struct sctp_addip_chunk *addip; + bool addr_param_seen = false; + union sctp_params param; + + addip = (struct sctp_addip_chunk *)chunk->chunk_hdr; + sctp_walk_params(param, addip, addip_hdr.params) { + size_t length = ntohs(param.p->length); + + *errp = param.p; + switch (param.p->type) { + case SCTP_PARAM_ERR_CAUSE: + break; + case SCTP_PARAM_IPV4_ADDRESS: + if (length != sizeof(struct sctp_ipv4addr_param)) + return false; + /* ensure there is only one addr param and it's in the + * beginning of addip_hdr params, or we reject it. 
+ */ + if (param.v != addip->addip_hdr.params) + return false; + addr_param_seen = true; + break; + case SCTP_PARAM_IPV6_ADDRESS: + if (length != sizeof(struct sctp_ipv6addr_param)) + return false; + if (param.v != addip->addip_hdr.params) + return false; + addr_param_seen = true; + break; + case SCTP_PARAM_ADD_IP: + case SCTP_PARAM_DEL_IP: + case SCTP_PARAM_SET_PRIMARY: + /* In ASCONF chunks, these need to be first. */ + if (addr_param_needed && !addr_param_seen) + return false; + length = ntohs(param.addip->param_hdr.length); + if (length < sizeof(struct sctp_addip_param) + + sizeof(**errp)) + return false; + break; + case SCTP_PARAM_SUCCESS_REPORT: + case SCTP_PARAM_ADAPTATION_LAYER_IND: + if (length != sizeof(struct sctp_addip_param)) + return false; + break; + default: + /* This is unknown to us, reject! */ + return false; + } + } + + /* Remaining sanity checks. */ + if (addr_param_needed && !addr_param_seen) + return false; + if (!addr_param_needed && addr_param_seen) + return false; + if (param.v != chunk->chunk_end) + return false; + + return true; +} + +/* Process an incoming ASCONF chunk with the next expected serial no. and + * return an ASCONF_ACK chunk to be sent in response. + */ +struct sctp_chunk *sctp_process_asconf(struct sctp_association *asoc, + struct sctp_chunk *asconf) +{ + union sctp_addr_param *addr_param; + struct sctp_addip_chunk *addip; + struct sctp_chunk *asconf_ack; + bool all_param_pass = true; + struct sctp_addiphdr *hdr; + int length = 0, chunk_len; + union sctp_params param; + __be16 err_code; + __u32 serial; + + addip = (struct sctp_addip_chunk *)asconf->chunk_hdr; + chunk_len = ntohs(asconf->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr); + hdr = (struct sctp_addiphdr *)asconf->skb->data; + serial = ntohl(hdr->serial); + + /* Skip the addiphdr and store a pointer to address parameter. */ + length = sizeof(*hdr); + addr_param = (union sctp_addr_param *)(asconf->skb->data + length); + chunk_len -= length; + + /* Skip the address parameter and store a pointer to the first + * asconf parameter. + */ + length = ntohs(addr_param->p.length); + chunk_len -= length; + + /* create an ASCONF_ACK chunk. + * Based on the definitions of parameters, we know that the size of + * ASCONF_ACK parameters are less than or equal to the fourfold of ASCONF + * parameters. + */ + asconf_ack = sctp_make_asconf_ack(asoc, serial, chunk_len * 4); + if (!asconf_ack) + goto done; + + /* Process the TLVs contained within the ASCONF chunk. */ + sctp_walk_params(param, addip, addip_hdr.params) { + /* Skip preceeding address parameters. */ + if (param.p->type == SCTP_PARAM_IPV4_ADDRESS || + param.p->type == SCTP_PARAM_IPV6_ADDRESS) + continue; + + err_code = sctp_process_asconf_param(asoc, asconf, + param.addip); + /* ADDIP 4.1 A7) + * If an error response is received for a TLV parameter, + * all TLVs with no response before the failed TLV are + * considered successful if not reported. All TLVs after + * the failed response are considered unsuccessful unless + * a specific success indication is present for the parameter. + */ + if (err_code != SCTP_ERROR_NO_ERROR) + all_param_pass = false; + if (!all_param_pass) + sctp_add_asconf_response(asconf_ack, param.addip->crr_id, + err_code, param.addip); + + /* ADDIP 4.3 D11) When an endpoint receiving an ASCONF to add + * an IP address sends an 'Out of Resource' in its response, it + * MUST also fail any subsequent add or delete requests bundled + * in the ASCONF. 
+ */ + if (err_code == SCTP_ERROR_RSRC_LOW) + goto done; + } +done: + asoc->peer.addip_serial++; + + /* If we are sending a new ASCONF_ACK hold a reference to it in assoc + * after freeing the reference to old asconf ack if any. + */ + if (asconf_ack) { + sctp_chunk_hold(asconf_ack); + list_add_tail(&asconf_ack->transmitted_list, + &asoc->asconf_ack_list); + } + + return asconf_ack; +} + +/* Process a asconf parameter that is successfully acked. */ +static void sctp_asconf_param_success(struct sctp_association *asoc, + struct sctp_addip_param *asconf_param) +{ + struct sctp_bind_addr *bp = &asoc->base.bind_addr; + union sctp_addr_param *addr_param; + struct sctp_sockaddr_entry *saddr; + struct sctp_transport *transport; + union sctp_addr addr; + struct sctp_af *af; + + addr_param = (void *)asconf_param + sizeof(*asconf_param); + + /* We have checked the packet before, so we do not check again. */ + af = sctp_get_af_specific(param_type2af(addr_param->p.type)); + if (!af->from_addr_param(&addr, addr_param, htons(bp->port), 0)) + return; + + switch (asconf_param->param_hdr.type) { + case SCTP_PARAM_ADD_IP: + /* This is always done in BH context with a socket lock + * held, so the list can not change. + */ + local_bh_disable(); + list_for_each_entry(saddr, &bp->address_list, list) { + if (sctp_cmp_addr_exact(&saddr->a, &addr)) + saddr->state = SCTP_ADDR_SRC; + } + local_bh_enable(); + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + sctp_transport_dst_release(transport); + } + break; + case SCTP_PARAM_DEL_IP: + local_bh_disable(); + sctp_del_bind_addr(bp, &addr); + if (asoc->asconf_addr_del_pending != NULL && + sctp_cmp_addr_exact(asoc->asconf_addr_del_pending, &addr)) { + kfree(asoc->asconf_addr_del_pending); + asoc->asconf_addr_del_pending = NULL; + } + local_bh_enable(); + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + sctp_transport_dst_release(transport); + } + break; + default: + break; + } +} + +/* Get the corresponding ASCONF response error code from the ASCONF_ACK chunk + * for the given asconf parameter. If there is no response for this parameter, + * return the error code based on the third argument 'no_err'. + * ADDIP 4.1 + * A7) If an error response is received for a TLV parameter, all TLVs with no + * response before the failed TLV are considered successful if not reported. + * All TLVs after the failed response are considered unsuccessful unless a + * specific success indication is present for the parameter. + */ +static __be16 sctp_get_asconf_response(struct sctp_chunk *asconf_ack, + struct sctp_addip_param *asconf_param, + int no_err) +{ + struct sctp_addip_param *asconf_ack_param; + struct sctp_errhdr *err_param; + int asconf_ack_len; + __be16 err_code; + int length; + + if (no_err) + err_code = SCTP_ERROR_NO_ERROR; + else + err_code = SCTP_ERROR_REQ_REFUSED; + + asconf_ack_len = ntohs(asconf_ack->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr); + + /* Skip the addiphdr from the asconf_ack chunk and store a pointer to + * the first asconf_ack parameter. 
+ */ + length = sizeof(struct sctp_addiphdr); + asconf_ack_param = (struct sctp_addip_param *)(asconf_ack->skb->data + + length); + asconf_ack_len -= length; + + while (asconf_ack_len > 0) { + if (asconf_ack_param->crr_id == asconf_param->crr_id) { + switch (asconf_ack_param->param_hdr.type) { + case SCTP_PARAM_SUCCESS_REPORT: + return SCTP_ERROR_NO_ERROR; + case SCTP_PARAM_ERR_CAUSE: + length = sizeof(*asconf_ack_param); + err_param = (void *)asconf_ack_param + length; + asconf_ack_len -= length; + if (asconf_ack_len > 0) + return err_param->cause; + else + return SCTP_ERROR_INV_PARAM; + break; + default: + return SCTP_ERROR_INV_PARAM; + } + } + + length = ntohs(asconf_ack_param->param_hdr.length); + asconf_ack_param = (void *)asconf_ack_param + length; + asconf_ack_len -= length; + } + + return err_code; +} + +/* Process an incoming ASCONF_ACK chunk against the cached last ASCONF chunk. */ +int sctp_process_asconf_ack(struct sctp_association *asoc, + struct sctp_chunk *asconf_ack) +{ + struct sctp_chunk *asconf = asoc->addip_last_asconf; + struct sctp_addip_param *asconf_param; + __be16 err_code = SCTP_ERROR_NO_ERROR; + union sctp_addr_param *addr_param; + int asconf_len = asconf->skb->len; + int all_param_pass = 0; + int length = 0; + int no_err = 1; + int retval = 0; + + /* Skip the chunkhdr and addiphdr from the last asconf sent and store + * a pointer to address parameter. + */ + length = sizeof(struct sctp_addip_chunk); + addr_param = (union sctp_addr_param *)(asconf->skb->data + length); + asconf_len -= length; + + /* Skip the address parameter in the last asconf sent and store a + * pointer to the first asconf parameter. + */ + length = ntohs(addr_param->p.length); + asconf_param = (void *)addr_param + length; + asconf_len -= length; + + /* ADDIP 4.1 + * A8) If there is no response(s) to specific TLV parameter(s), and no + * failures are indicated, then all request(s) are considered + * successful. + */ + if (asconf_ack->skb->len == sizeof(struct sctp_addiphdr)) + all_param_pass = 1; + + /* Process the TLVs contained in the last sent ASCONF chunk. */ + while (asconf_len > 0) { + if (all_param_pass) + err_code = SCTP_ERROR_NO_ERROR; + else { + err_code = sctp_get_asconf_response(asconf_ack, + asconf_param, + no_err); + if (no_err && (SCTP_ERROR_NO_ERROR != err_code)) + no_err = 0; + } + + switch (err_code) { + case SCTP_ERROR_NO_ERROR: + sctp_asconf_param_success(asoc, asconf_param); + break; + + case SCTP_ERROR_RSRC_LOW: + retval = 1; + break; + + case SCTP_ERROR_UNKNOWN_PARAM: + /* Disable sending this type of asconf parameter in + * future. + */ + asoc->peer.addip_disabled_mask |= + asconf_param->param_hdr.type; + break; + + case SCTP_ERROR_REQ_REFUSED: + case SCTP_ERROR_DEL_LAST_IP: + case SCTP_ERROR_DEL_SRC_IP: + default: + break; + } + + /* Skip the processed asconf parameter and move to the next + * one. + */ + length = ntohs(asconf_param->param_hdr.length); + asconf_param = (void *)asconf_param + length; + asconf_len -= length; + } + + if (no_err && asoc->src_out_of_asoc_ok) { + asoc->src_out_of_asoc_ok = 0; + sctp_transport_immediate_rtx(asoc->peer.primary_path); + } + + /* Free the cached last sent asconf chunk. */ + list_del_init(&asconf->transmitted_list); + sctp_chunk_free(asconf); + asoc->addip_last_asconf = NULL; + + return retval; +} + +/* Make a FWD TSN chunk. 
*/ +struct sctp_chunk *sctp_make_fwdtsn(const struct sctp_association *asoc, + __u32 new_cum_tsn, size_t nstreams, + struct sctp_fwdtsn_skip *skiplist) +{ + struct sctp_chunk *retval = NULL; + struct sctp_fwdtsn_hdr ftsn_hdr; + struct sctp_fwdtsn_skip skip; + size_t hint; + int i; + + hint = (nstreams + 1) * sizeof(__u32); + + retval = sctp_make_control(asoc, SCTP_CID_FWD_TSN, 0, hint, GFP_ATOMIC); + + if (!retval) + return NULL; + + ftsn_hdr.new_cum_tsn = htonl(new_cum_tsn); + retval->subh.fwdtsn_hdr = + sctp_addto_chunk(retval, sizeof(ftsn_hdr), &ftsn_hdr); + + for (i = 0; i < nstreams; i++) { + skip.stream = skiplist[i].stream; + skip.ssn = skiplist[i].ssn; + sctp_addto_chunk(retval, sizeof(skip), &skip); + } + + return retval; +} + +struct sctp_chunk *sctp_make_ifwdtsn(const struct sctp_association *asoc, + __u32 new_cum_tsn, size_t nstreams, + struct sctp_ifwdtsn_skip *skiplist) +{ + struct sctp_chunk *retval = NULL; + struct sctp_ifwdtsn_hdr ftsn_hdr; + size_t hint; + + hint = (nstreams + 1) * sizeof(__u32); + + retval = sctp_make_control(asoc, SCTP_CID_I_FWD_TSN, 0, hint, + GFP_ATOMIC); + if (!retval) + return NULL; + + ftsn_hdr.new_cum_tsn = htonl(new_cum_tsn); + retval->subh.ifwdtsn_hdr = + sctp_addto_chunk(retval, sizeof(ftsn_hdr), &ftsn_hdr); + + sctp_addto_chunk(retval, nstreams * sizeof(skiplist[0]), skiplist); + + return retval; +} + +/* RE-CONFIG 3.1 (RE-CONFIG chunk) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Type = 130 | Chunk Flags | Chunk Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / Re-configuration Parameter / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / Re-configuration Parameter (optional) / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +static struct sctp_chunk *sctp_make_reconf(const struct sctp_association *asoc, + int length) +{ + struct sctp_reconf_chunk *reconf; + struct sctp_chunk *retval; + + retval = sctp_make_control(asoc, SCTP_CID_RECONF, 0, length, + GFP_ATOMIC); + if (!retval) + return NULL; + + reconf = (struct sctp_reconf_chunk *)retval->chunk_hdr; + retval->param_hdr.v = reconf->params; + + return retval; +} + +/* RE-CONFIG 4.1 (STREAM OUT RESET) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 13 | Parameter Length = 16 + 2 * N | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Request Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Response Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Sender's Last Assigned TSN | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Stream Number 1 (optional) | Stream Number 2 (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * / ...... 
/ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Stream Number N-1 (optional) | Stream Number N (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * RE-CONFIG 4.2 (STREAM IN RESET) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 14 | Parameter Length = 8 + 2 * N | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Request Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Stream Number 1 (optional) | Stream Number 2 (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * / ...... / + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Stream Number N-1 (optional) | Stream Number N (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_strreset_req( + const struct sctp_association *asoc, + __u16 stream_num, __be16 *stream_list, + bool out, bool in) +{ + __u16 stream_len = stream_num * sizeof(__u16); + struct sctp_strreset_outreq outreq; + struct sctp_strreset_inreq inreq; + struct sctp_chunk *retval; + __u16 outlen, inlen; + + outlen = (sizeof(outreq) + stream_len) * out; + inlen = (sizeof(inreq) + stream_len) * in; + + retval = sctp_make_reconf(asoc, SCTP_PAD4(outlen) + SCTP_PAD4(inlen)); + if (!retval) + return NULL; + + if (outlen) { + outreq.param_hdr.type = SCTP_PARAM_RESET_OUT_REQUEST; + outreq.param_hdr.length = htons(outlen); + outreq.request_seq = htonl(asoc->strreset_outseq); + outreq.response_seq = htonl(asoc->strreset_inseq - 1); + outreq.send_reset_at_tsn = htonl(asoc->next_tsn - 1); + + sctp_addto_chunk(retval, sizeof(outreq), &outreq); + + if (stream_len) + sctp_addto_chunk(retval, stream_len, stream_list); + } + + if (inlen) { + inreq.param_hdr.type = SCTP_PARAM_RESET_IN_REQUEST; + inreq.param_hdr.length = htons(inlen); + inreq.request_seq = htonl(asoc->strreset_outseq + out); + + sctp_addto_chunk(retval, sizeof(inreq), &inreq); + + if (stream_len) + sctp_addto_chunk(retval, stream_len, stream_list); + } + + return retval; +} + +/* RE-CONFIG 4.3 (SSN/TSN RESET ALL) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 15 | Parameter Length = 8 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Request Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_strreset_tsnreq( + const struct sctp_association *asoc) +{ + struct sctp_strreset_tsnreq tsnreq; + __u16 length = sizeof(tsnreq); + struct sctp_chunk *retval; + + retval = sctp_make_reconf(asoc, length); + if (!retval) + return NULL; + + tsnreq.param_hdr.type = SCTP_PARAM_RESET_TSN_REQUEST; + tsnreq.param_hdr.length = htons(length); + tsnreq.request_seq = htonl(asoc->strreset_outseq); + + sctp_addto_chunk(retval, sizeof(tsnreq), &tsnreq); + + return retval; +} + +/* RE-CONFIG 4.5/4.6 (ADD STREAM) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 17 | Parameter Length = 12 | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Request Sequence Number | 
+ * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Number of new streams | Reserved | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_strreset_addstrm( + const struct sctp_association *asoc, + __u16 out, __u16 in) +{ + struct sctp_strreset_addstrm addstrm; + __u16 size = sizeof(addstrm); + struct sctp_chunk *retval; + + retval = sctp_make_reconf(asoc, (!!out + !!in) * size); + if (!retval) + return NULL; + + if (out) { + addstrm.param_hdr.type = SCTP_PARAM_RESET_ADD_OUT_STREAMS; + addstrm.param_hdr.length = htons(size); + addstrm.number_of_streams = htons(out); + addstrm.request_seq = htonl(asoc->strreset_outseq); + addstrm.reserved = 0; + + sctp_addto_chunk(retval, size, &addstrm); + } + + if (in) { + addstrm.param_hdr.type = SCTP_PARAM_RESET_ADD_IN_STREAMS; + addstrm.param_hdr.length = htons(size); + addstrm.number_of_streams = htons(in); + addstrm.request_seq = htonl(asoc->strreset_outseq + !!out); + addstrm.reserved = 0; + + sctp_addto_chunk(retval, size, &addstrm); + } + + return retval; +} + +/* RE-CONFIG 4.4 (RESP) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 16 | Parameter Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Response Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Result | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_strreset_resp(const struct sctp_association *asoc, + __u32 result, __u32 sn) +{ + struct sctp_strreset_resp resp; + __u16 length = sizeof(resp); + struct sctp_chunk *retval; + + retval = sctp_make_reconf(asoc, length); + if (!retval) + return NULL; + + resp.param_hdr.type = SCTP_PARAM_RESET_RESPONSE; + resp.param_hdr.length = htons(length); + resp.response_seq = htonl(sn); + resp.result = htonl(result); + + sctp_addto_chunk(retval, sizeof(resp), &resp); + + return retval; +} + +/* RE-CONFIG 4.4 OPTIONAL (TSNRESP) + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Parameter Type = 16 | Parameter Length | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Re-configuration Response Sequence Number | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Result | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Sender's Next TSN (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | Receiver's Next TSN (optional) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + */ +struct sctp_chunk *sctp_make_strreset_tsnresp(struct sctp_association *asoc, + __u32 result, __u32 sn, + __u32 sender_tsn, + __u32 receiver_tsn) +{ + struct sctp_strreset_resptsn tsnresp; + __u16 length = sizeof(tsnresp); + struct sctp_chunk *retval; + + retval = sctp_make_reconf(asoc, length); + if (!retval) + return NULL; + + tsnresp.param_hdr.type = SCTP_PARAM_RESET_RESPONSE; + tsnresp.param_hdr.length = htons(length); + + tsnresp.response_seq = htonl(sn); + tsnresp.result = htonl(result); + tsnresp.senders_next_tsn = htonl(sender_tsn); + tsnresp.receivers_next_tsn = htonl(receiver_tsn); + + sctp_addto_chunk(retval, sizeof(tsnresp), &tsnresp); + + return retval; +} + +bool sctp_verify_reconf(const 
struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_paramhdr **errp) +{ + struct sctp_reconf_chunk *hdr; + union sctp_params param; + __be16 last = 0; + __u16 cnt = 0; + + hdr = (struct sctp_reconf_chunk *)chunk->chunk_hdr; + sctp_walk_params(param, hdr, params) { + __u16 length = ntohs(param.p->length); + + *errp = param.p; + if (cnt++ > 2) + return false; + switch (param.p->type) { + case SCTP_PARAM_RESET_OUT_REQUEST: + if (length < sizeof(struct sctp_strreset_outreq) || + (last && last != SCTP_PARAM_RESET_RESPONSE && + last != SCTP_PARAM_RESET_IN_REQUEST)) + return false; + break; + case SCTP_PARAM_RESET_IN_REQUEST: + if (length < sizeof(struct sctp_strreset_inreq) || + (last && last != SCTP_PARAM_RESET_OUT_REQUEST)) + return false; + break; + case SCTP_PARAM_RESET_RESPONSE: + if ((length != sizeof(struct sctp_strreset_resp) && + length != sizeof(struct sctp_strreset_resptsn)) || + (last && last != SCTP_PARAM_RESET_RESPONSE && + last != SCTP_PARAM_RESET_OUT_REQUEST)) + return false; + break; + case SCTP_PARAM_RESET_TSN_REQUEST: + if (length != + sizeof(struct sctp_strreset_tsnreq) || last) + return false; + break; + case SCTP_PARAM_RESET_ADD_IN_STREAMS: + if (length != sizeof(struct sctp_strreset_addstrm) || + (last && last != SCTP_PARAM_RESET_ADD_OUT_STREAMS)) + return false; + break; + case SCTP_PARAM_RESET_ADD_OUT_STREAMS: + if (length != sizeof(struct sctp_strreset_addstrm) || + (last && last != SCTP_PARAM_RESET_ADD_IN_STREAMS)) + return false; + break; + default: + return false; + } + + last = param.p->type; + } + + return true; +} diff --git a/net/sctp/sm_sideeffect.c b/net/sctp/sm_sideeffect.c new file mode 100644 index 000000000..970c6a486 --- /dev/null +++ b/net/sctp/sm_sideeffect.c @@ -0,0 +1,1825 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * + * This file is part of the SCTP kernel implementation + * + * These functions work with the state functions in sctp_sm_statefuns.c + * to implement that state operations. These functions implement the + * steps which require modifying existing data structures. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Hui Huang + * Dajiang Zhang + * Daisy Chang + * Sridhar Samudrala + * Ardelle Fan + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int sctp_cmd_interpreter(enum sctp_event_type event_type, + union sctp_subtype subtype, + enum sctp_state state, + struct sctp_endpoint *ep, + struct sctp_association *asoc, + void *event_arg, + enum sctp_disposition status, + struct sctp_cmd_seq *commands, + gfp_t gfp); +static int sctp_side_effects(enum sctp_event_type event_type, + union sctp_subtype subtype, + enum sctp_state state, + struct sctp_endpoint *ep, + struct sctp_association **asoc, + void *event_arg, + enum sctp_disposition status, + struct sctp_cmd_seq *commands, + gfp_t gfp); + +/******************************************************************** + * Helper functions + ********************************************************************/ + +/* A helper function for delayed processing of INET ECN CE bit. 
*/ +static void sctp_do_ecn_ce_work(struct sctp_association *asoc, + __u32 lowest_tsn) +{ + /* Save the TSN away for comparison when we receive CWR */ + + asoc->last_ecne_tsn = lowest_tsn; + asoc->need_ecne = 1; +} + +/* Helper function for delayed processing of SCTP ECNE chunk. */ +/* RFC 2960 Appendix A + * + * RFC 2481 details a specific bit for a sender to send in + * the header of its next outbound TCP segment to indicate to + * its peer that it has reduced its congestion window. This + * is termed the CWR bit. For SCTP the same indication is made + * by including the CWR chunk. This chunk contains one data + * element, i.e. the TSN number that was sent in the ECNE chunk. + * This element represents the lowest TSN number in the datagram + * that was originally marked with the CE bit. + */ +static struct sctp_chunk *sctp_do_ecn_ecne_work(struct sctp_association *asoc, + __u32 lowest_tsn, + struct sctp_chunk *chunk) +{ + struct sctp_chunk *repl; + + /* Our previously transmitted packet ran into some congestion + * so we should take action by reducing cwnd and ssthresh + * and then ACK our peer that we we've done so by + * sending a CWR. + */ + + /* First, try to determine if we want to actually lower + * our cwnd variables. Only lower them if the ECNE looks more + * recent than the last response. + */ + if (TSN_lt(asoc->last_cwr_tsn, lowest_tsn)) { + struct sctp_transport *transport; + + /* Find which transport's congestion variables + * need to be adjusted. + */ + transport = sctp_assoc_lookup_tsn(asoc, lowest_tsn); + + /* Update the congestion variables. */ + if (transport) + sctp_transport_lower_cwnd(transport, + SCTP_LOWER_CWND_ECNE); + asoc->last_cwr_tsn = lowest_tsn; + } + + /* Always try to quiet the other end. In case of lost CWR, + * resend last_cwr_tsn. + */ + repl = sctp_make_cwr(asoc, asoc->last_cwr_tsn, chunk); + + /* If we run out of memory, it will look like a lost CWR. We'll + * get back in sync eventually. + */ + return repl; +} + +/* Helper function to do delayed processing of ECN CWR chunk. */ +static void sctp_do_ecn_cwr_work(struct sctp_association *asoc, + __u32 lowest_tsn) +{ + /* Turn off ECNE getting auto-prepended to every outgoing + * packet + */ + asoc->need_ecne = 0; +} + +/* Generate SACK if necessary. We call this at the end of a packet. */ +static int sctp_gen_sack(struct sctp_association *asoc, int force, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *trans = asoc->peer.last_data_from; + __u32 ctsn, max_tsn_seen; + struct sctp_chunk *sack; + int error = 0; + + if (force || + (!trans && (asoc->param_flags & SPP_SACKDELAY_DISABLE)) || + (trans && (trans->param_flags & SPP_SACKDELAY_DISABLE))) + asoc->peer.sack_needed = 1; + + ctsn = sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map); + max_tsn_seen = sctp_tsnmap_get_max_tsn_seen(&asoc->peer.tsn_map); + + /* From 12.2 Parameters necessary per association (i.e. the TCB): + * + * Ack State : This flag indicates if the next received packet + * : is to be responded to with a SACK. ... + * : When DATA chunks are out of order, SACK's + * : are not delayed (see Section 6). + * + * [This is actually not mentioned in Section 6, but we + * implement it here anyway. --piggy] + */ + if (max_tsn_seen != ctsn) + asoc->peer.sack_needed = 1; + + /* From 6.2 Acknowledgement on Reception of DATA Chunks: + * + * Section 4.2 of [RFC2581] SHOULD be followed. 
Specifically, + * an acknowledgement SHOULD be generated for at least every + * second packet (not every second DATA chunk) received, and + * SHOULD be generated within 200 ms of the arrival of any + * unacknowledged DATA chunk. ... + */ + if (!asoc->peer.sack_needed) { + asoc->peer.sack_cnt++; + + /* Set the SACK delay timeout based on the + * SACK delay for the last transport + * data was received from, or the default + * for the association. + */ + if (trans) { + /* We will need a SACK for the next packet. */ + if (asoc->peer.sack_cnt >= trans->sackfreq - 1) + asoc->peer.sack_needed = 1; + + asoc->timeouts[SCTP_EVENT_TIMEOUT_SACK] = + trans->sackdelay; + } else { + /* We will need a SACK for the next packet. */ + if (asoc->peer.sack_cnt >= asoc->sackfreq - 1) + asoc->peer.sack_needed = 1; + + asoc->timeouts[SCTP_EVENT_TIMEOUT_SACK] = + asoc->sackdelay; + } + + /* Restart the SACK timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_SACK)); + } else { + __u32 old_a_rwnd = asoc->a_rwnd; + + asoc->a_rwnd = asoc->rwnd; + sack = sctp_make_sack(asoc); + if (!sack) { + asoc->a_rwnd = old_a_rwnd; + goto nomem; + } + + asoc->peer.sack_needed = 0; + asoc->peer.sack_cnt = 0; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(sack)); + + /* Stop the SACK timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_SACK)); + } + + return error; +nomem: + error = -ENOMEM; + return error; +} + +/* When the T3-RTX timer expires, it calls this function to create the + * relevant state machine event. + */ +void sctp_generate_t3_rtx_event(struct timer_list *t) +{ + struct sctp_transport *transport = + from_timer(transport, t, T3_rtx_timer); + struct sctp_association *asoc = transport->asoc; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + int error; + + /* Check whether a task is in the sock. */ + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy\n", __func__); + + /* Try again later. */ + if (!mod_timer(&transport->T3_rtx_timer, jiffies + (HZ/20))) + sctp_transport_hold(transport); + goto out_unlock; + } + + /* Run through the state machine. */ + error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, + SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_T3_RTX), + asoc->state, + asoc->ep, asoc, + transport, GFP_ATOMIC); + + if (error) + sk->sk_err = -error; + +out_unlock: + bh_unlock_sock(sk); + sctp_transport_put(transport); +} + +/* This is a sa interface for producing timeout events. It works + * for timeouts which use the association as their parameter. + */ +static void sctp_generate_timeout_event(struct sctp_association *asoc, + enum sctp_event_timeout timeout_type) +{ + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + int error = 0; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy: timer %d\n", __func__, + timeout_type); + + /* Try again later. */ + if (!mod_timer(&asoc->timers[timeout_type], jiffies + (HZ/20))) + sctp_association_hold(asoc); + goto out_unlock; + } + + /* Is this association really dead and just waiting around for + * the timer to let go of the reference? + */ + if (asoc->base.dead) + goto out_unlock; + + /* Run through the state machine. 
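+ *
+ * (Schematically, the timer/refcount discipline used throughout this file:
+ *  a pending timer owns one reference on its association or transport, and
+ *  mod_timer() returns 0 when the timer was not already pending, so an idle
+ *  timer that gets armed takes a reference and a pending timer that gets
+ *  deleted gives one back; 'type' and 'delay' below are only placeholders
+ *  for the timeout being manipulated:
+ *
+ *	if (!mod_timer(&asoc->timers[type], jiffies + delay))
+ *		sctp_association_hold(asoc);
+ *	...
+ *	if (del_timer(&asoc->timers[type]))
+ *		sctp_association_put(asoc);
+ *
+ *  The expiry handler itself drops its reference on the way out, as the
+ *  sctp_association_put() at the end of this function does.)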
*/ + error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, + SCTP_ST_TIMEOUT(timeout_type), + asoc->state, asoc->ep, asoc, + (void *)timeout_type, GFP_ATOMIC); + + if (error) + sk->sk_err = -error; + +out_unlock: + bh_unlock_sock(sk); + sctp_association_put(asoc); +} + +static void sctp_generate_t1_cookie_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_T1_COOKIE]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_T1_COOKIE); +} + +static void sctp_generate_t1_init_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_T1_INIT]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_T1_INIT); +} + +static void sctp_generate_t2_shutdown_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_T2_SHUTDOWN]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_T2_SHUTDOWN); +} + +static void sctp_generate_t4_rto_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_T4_RTO]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_T4_RTO); +} + +static void sctp_generate_t5_shutdown_guard_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, + timers[SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD]); + + sctp_generate_timeout_event(asoc, + SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD); + +} /* sctp_generate_t5_shutdown_guard_event() */ + +static void sctp_generate_autoclose_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_AUTOCLOSE]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_AUTOCLOSE); +} + +/* Generate a heart beat event. If the sock is busy, reschedule. Make + * sure that the transport is still valid. + */ +void sctp_generate_heartbeat_event(struct timer_list *t) +{ + struct sctp_transport *transport = from_timer(transport, t, hb_timer); + struct sctp_association *asoc = transport->asoc; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + u32 elapsed, timeout; + int error = 0; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy\n", __func__); + + /* Try again later. */ + if (!mod_timer(&transport->hb_timer, jiffies + (HZ/20))) + sctp_transport_hold(transport); + goto out_unlock; + } + + /* Check if we should still send the heartbeat or reschedule */ + elapsed = jiffies - transport->last_time_sent; + timeout = sctp_transport_timeout(transport); + if (elapsed < timeout) { + elapsed = timeout - elapsed; + if (!mod_timer(&transport->hb_timer, jiffies + elapsed)) + sctp_transport_hold(transport); + goto out_unlock; + } + + error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, + SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_HEARTBEAT), + asoc->state, asoc->ep, asoc, + transport, GFP_ATOMIC); + + if (error) + sk->sk_err = -error; + +out_unlock: + bh_unlock_sock(sk); + sctp_transport_put(transport); +} + +/* Handle the timeout of the ICMP protocol unreachable timer. Trigger + * the correct state machine transition that will close the association. 
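+ *
+ * (This timer is armed from the ICMP error path - roughly what
+ *  sctp_icmp_proto_unreachable() in net/sctp/input.c does - when the socket
+ *  is owned by the user and the abort must be deferred to timer context:
+ *
+ *	if (sock_owned_by_user(sk)) {
+ *		if (!timer_pending(&t->proto_unreach_timer) &&
+ *		    !mod_timer(&t->proto_unreach_timer, jiffies + HZ / 20))
+ *			sctp_transport_hold(t);
+ *	} else {
+ *		sctp_do_sm(net, SCTP_EVENT_T_OTHER,
+ *			   SCTP_ST_OTHER(SCTP_EVENT_ICMP_PROTO_UNREACH),
+ *			   asoc->state, asoc->ep, asoc, t, GFP_ATOMIC);
+ *	}
+ *
+ *  The caller above is only a sketch of that path, not the exact code.)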
+ */ +void sctp_generate_proto_unreach_event(struct timer_list *t) +{ + struct sctp_transport *transport = + from_timer(transport, t, proto_unreach_timer); + struct sctp_association *asoc = transport->asoc; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy\n", __func__); + + /* Try again later. */ + if (!mod_timer(&transport->proto_unreach_timer, + jiffies + (HZ/20))) + sctp_transport_hold(transport); + goto out_unlock; + } + + /* Is this structure just waiting around for us to actually + * get destroyed? + */ + if (asoc->base.dead) + goto out_unlock; + + sctp_do_sm(net, SCTP_EVENT_T_OTHER, + SCTP_ST_OTHER(SCTP_EVENT_ICMP_PROTO_UNREACH), + asoc->state, asoc->ep, asoc, transport, GFP_ATOMIC); + +out_unlock: + bh_unlock_sock(sk); + sctp_transport_put(transport); +} + + /* Handle the timeout of the RE-CONFIG timer. */ +void sctp_generate_reconf_event(struct timer_list *t) +{ + struct sctp_transport *transport = + from_timer(transport, t, reconf_timer); + struct sctp_association *asoc = transport->asoc; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + int error = 0; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy\n", __func__); + + /* Try again later. */ + if (!mod_timer(&transport->reconf_timer, jiffies + (HZ / 20))) + sctp_transport_hold(transport); + goto out_unlock; + } + + /* This happens when the response arrives after the timer is triggered. */ + if (!asoc->strreset_chunk) + goto out_unlock; + + error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, + SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_RECONF), + asoc->state, asoc->ep, asoc, + transport, GFP_ATOMIC); + + if (error) + sk->sk_err = -error; + +out_unlock: + bh_unlock_sock(sk); + sctp_transport_put(transport); +} + +/* Handle the timeout of the probe timer. */ +void sctp_generate_probe_event(struct timer_list *t) +{ + struct sctp_transport *transport = from_timer(transport, t, probe_timer); + struct sctp_association *asoc = transport->asoc; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + int error = 0; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { + pr_debug("%s: sock is busy\n", __func__); + + /* Try again later. */ + if (!mod_timer(&transport->probe_timer, jiffies + (HZ / 20))) + sctp_transport_hold(transport); + goto out_unlock; + } + + error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, + SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_PROBE), + asoc->state, asoc->ep, asoc, + transport, GFP_ATOMIC); + + if (error) + sk->sk_err = -error; + +out_unlock: + bh_unlock_sock(sk); + sctp_transport_put(transport); +} + +/* Inject a SACK Timeout event into the state machine. 
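+ *
+ * (The delay and frequency driving this timer are the per-transport and
+ *  per-association sackdelay/sackfreq values used in sctp_gen_sack() above.
+ *  Assuming the usual lksctp socket API, an application tunes them with the
+ *  SCTP_DELAYED_SACK option; sack_delay is the delayed-SACK timeout in
+ *  milliseconds, sack_freq is the number of packets per SACK (1 disables
+ *  delaying), and 'sd' / 'assoc_id' stand for the caller's socket and
+ *  association id:
+ *
+ *	struct sctp_sack_info si = {
+ *		.sack_assoc_id = assoc_id,
+ *		.sack_delay    = 100,
+ *		.sack_freq     = 2,
+ *	};
+ *	setsockopt(sd, IPPROTO_SCTP, SCTP_DELAYED_SACK, &si, sizeof(si));
+ *  )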
*/ +static void sctp_generate_sack_event(struct timer_list *t) +{ + struct sctp_association *asoc = + from_timer(asoc, t, timers[SCTP_EVENT_TIMEOUT_SACK]); + + sctp_generate_timeout_event(asoc, SCTP_EVENT_TIMEOUT_SACK); +} + +sctp_timer_event_t *sctp_timer_events[SCTP_NUM_TIMEOUT_TYPES] = { + [SCTP_EVENT_TIMEOUT_NONE] = NULL, + [SCTP_EVENT_TIMEOUT_T1_COOKIE] = sctp_generate_t1_cookie_event, + [SCTP_EVENT_TIMEOUT_T1_INIT] = sctp_generate_t1_init_event, + [SCTP_EVENT_TIMEOUT_T2_SHUTDOWN] = sctp_generate_t2_shutdown_event, + [SCTP_EVENT_TIMEOUT_T3_RTX] = NULL, + [SCTP_EVENT_TIMEOUT_T4_RTO] = sctp_generate_t4_rto_event, + [SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD] = + sctp_generate_t5_shutdown_guard_event, + [SCTP_EVENT_TIMEOUT_HEARTBEAT] = NULL, + [SCTP_EVENT_TIMEOUT_RECONF] = NULL, + [SCTP_EVENT_TIMEOUT_SACK] = sctp_generate_sack_event, + [SCTP_EVENT_TIMEOUT_AUTOCLOSE] = sctp_generate_autoclose_event, +}; + + +/* RFC 2960 8.2 Path Failure Detection + * + * When its peer endpoint is multi-homed, an endpoint should keep a + * error counter for each of the destination transport addresses of the + * peer endpoint. + * + * Each time the T3-rtx timer expires on any address, or when a + * HEARTBEAT sent to an idle address is not acknowledged within a RTO, + * the error counter of that destination address will be incremented. + * When the value in the error counter exceeds the protocol parameter + * 'Path.Max.Retrans' of that destination address, the endpoint should + * mark the destination transport address as inactive, and a + * notification SHOULD be sent to the upper layer. + * + */ +static void sctp_do_8_2_transport_strike(struct sctp_cmd_seq *commands, + struct sctp_association *asoc, + struct sctp_transport *transport, + int is_hb) +{ + /* The check for association's overall error counter exceeding the + * threshold is done in the state function. + */ + /* We are here due to a timer expiration. If the timer was + * not a HEARTBEAT, then normal error tracking is done. + * If the timer was a heartbeat, we only increment error counts + * when we already have an outstanding HEARTBEAT that has not + * been acknowledged. + * Additionally, some tranport states inhibit error increments. 
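+ *
+ * ('Path.Max.Retrans' is transport->pathmaxrxt below.  Assuming the usual
+ *  lksctp socket API it is settable per destination address with
+ *  SCTP_PEER_ADDR_PARAMS; 'sd', 'assoc_id' and 'peer_addr' stand for the
+ *  caller's socket, association id and peer sockaddr:
+ *
+ *	struct sctp_paddrparams pp = { .spp_assoc_id = assoc_id };
+ *	memcpy(&pp.spp_address, &peer_addr, sizeof(peer_addr));
+ *	pp.spp_pathmaxrxt = 5;
+ *	setsockopt(sd, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &pp, sizeof(pp));
+ *
+ *  The PF threshold (transport->pf_retrans) comes from the
+ *  net.sctp.pf_retrans sysctl or the SCTP_PEER_ADDR_THLDS option.)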
+ */ + if (!is_hb) { + asoc->overall_error_count++; + if (transport->state != SCTP_INACTIVE) + transport->error_count++; + } else if (transport->hb_sent) { + if (transport->state != SCTP_UNCONFIRMED) + asoc->overall_error_count++; + if (transport->state != SCTP_INACTIVE) + transport->error_count++; + } + + /* If the transport error count is greater than the pf_retrans + * threshold, and less than pathmaxrtx, and if the current state + * is SCTP_ACTIVE, then mark this transport as Partially Failed, + * see SCTP Quick Failover Draft, section 5.1 + */ + if (asoc->base.net->sctp.pf_enable && + transport->state == SCTP_ACTIVE && + transport->error_count < transport->pathmaxrxt && + transport->error_count > transport->pf_retrans) { + + sctp_assoc_control_transport(asoc, transport, + SCTP_TRANSPORT_PF, + 0); + + /* Update the hb timer to resend a heartbeat every rto */ + sctp_transport_reset_hb_timer(transport); + } + + if (transport->state != SCTP_INACTIVE && + (transport->error_count > transport->pathmaxrxt)) { + pr_debug("%s: association:%p transport addr:%pISpc failed\n", + __func__, asoc, &transport->ipaddr.sa); + + sctp_assoc_control_transport(asoc, transport, + SCTP_TRANSPORT_DOWN, + SCTP_FAILED_THRESHOLD); + } + + if (transport->error_count > transport->ps_retrans && + asoc->peer.primary_path == transport && + asoc->peer.active_path != transport) + sctp_assoc_set_primary(asoc, asoc->peer.active_path); + + /* E2) For the destination address for which the timer + * expires, set RTO <- RTO * 2 ("back off the timer"). The + * maximum value discussed in rule C7 above (RTO.max) may be + * used to provide an upper bound to this doubling operation. + * + * Special Case: the first HB doesn't trigger exponential backoff. + * The first unacknowledged HB triggers it. We do this with a flag + * that indicates that we have an outstanding HB. + */ + if (!is_hb || transport->hb_sent) { + transport->rto = min((transport->rto * 2), transport->asoc->rto_max); + sctp_max_rto(asoc, transport); + } +} + +/* Worker routine to handle INIT command failure. */ +static void sctp_cmd_init_failed(struct sctp_cmd_seq *commands, + struct sctp_association *asoc, + unsigned int error) +{ + struct sctp_ulpevent *event; + + event = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_CANT_STR_ASSOC, + (__u16)error, 0, 0, NULL, + GFP_ATOMIC); + + if (event) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(event)); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + + /* SEND_FAILED sent later when cleaning up the association. */ + asoc->outqueue.error = error; + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); +} + +/* Worker routine to handle SCTP_CMD_ASSOC_FAILED. */ +static void sctp_cmd_assoc_failed(struct sctp_cmd_seq *commands, + struct sctp_association *asoc, + enum sctp_event_type event_type, + union sctp_subtype subtype, + struct sctp_chunk *chunk, + unsigned int error) +{ + struct sctp_ulpevent *event; + struct sctp_chunk *abort; + + /* Cancel any partial delivery in progress. 
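+ *
+ * (The SCTP_COMM_LOST association-change event built just below only
+ *  reaches the application if it subscribed to association events.
+ *  Assuming the usual lksctp socket API that is done with SCTP_EVENTS,
+ *  where 'sd' is the caller's socket:
+ *
+ *	struct sctp_event_subscribe ev = { 0 };
+ *	ev.sctp_association_event = 1;
+ *	setsockopt(sd, IPPROTO_SCTP, SCTP_EVENTS, &ev, sizeof(ev));
+ *  )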
*/ + asoc->stream.si->abort_pd(&asoc->ulpq, GFP_ATOMIC); + + if (event_type == SCTP_EVENT_T_CHUNK && subtype.chunk == SCTP_CID_ABORT) + event = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_COMM_LOST, + (__u16)error, 0, 0, chunk, + GFP_ATOMIC); + else + event = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_COMM_LOST, + (__u16)error, 0, 0, NULL, + GFP_ATOMIC); + if (event) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(event)); + + if (asoc->overall_error_count >= asoc->max_retrans) { + abort = sctp_make_violation_max_retrans(asoc, chunk); + if (abort) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(abort)); + } + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + + /* SEND_FAILED sent later when cleaning up the association. */ + asoc->outqueue.error = error; + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); +} + +/* Process an init chunk (may be real INIT/INIT-ACK or an embedded INIT + * inside the cookie. In reality, this is only used for INIT-ACK processing + * since all other cases use "temporary" associations and can do all + * their work in statefuns directly. + */ +static int sctp_cmd_process_init(struct sctp_cmd_seq *commands, + struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_init_chunk *peer_init, + gfp_t gfp) +{ + int error; + + /* We only process the init as a sideeffect in a single + * case. This is when we process the INIT-ACK. If we + * fail during INIT processing (due to malloc problems), + * just return the error and stop processing the stack. + */ + if (!sctp_process_init(asoc, chunk, sctp_source(chunk), peer_init, gfp)) + error = -ENOMEM; + else + error = 0; + + return error; +} + +/* Helper function to break out starting up of heartbeat timers. */ +static void sctp_cmd_hb_timers_start(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc) +{ + struct sctp_transport *t; + + /* Start a heartbeat timer for each transport on the association. + * hold a reference on the transport to make sure none of + * the needed data structures go away. + */ + list_for_each_entry(t, &asoc->peer.transport_addr_list, transports) + sctp_transport_reset_hb_timer(t); +} + +static void sctp_cmd_hb_timers_stop(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc) +{ + struct sctp_transport *t; + + /* Stop all heartbeat timers. */ + + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + if (del_timer(&t->hb_timer)) + sctp_transport_put(t); + } +} + +/* Helper function to stop any pending T3-RTX timers */ +static void sctp_cmd_t3_rtx_timers_stop(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc) +{ + struct sctp_transport *t; + + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + if (del_timer(&t->T3_rtx_timer)) + sctp_transport_put(t); + } +} + + +/* Helper function to handle the reception of an HEARTBEAT ACK. */ +static void sctp_cmd_transport_on(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + struct sctp_transport *t, + struct sctp_chunk *chunk) +{ + struct sctp_sender_hb_info *hbinfo; + int was_unconfirmed = 0; + + /* 8.3 Upon the receipt of the HEARTBEAT ACK, the sender of the + * HEARTBEAT should clear the error counter of the destination + * transport address to which the HEARTBEAT was sent. + */ + t->error_count = 0; + + /* + * Although RFC4960 specifies that the overall error count must + * be cleared when a HEARTBEAT ACK is received, we make an + * exception while in SHUTDOWN PENDING. 
If the peer keeps its + * window shut forever, we may never be able to transmit our + * outstanding data and rely on the retransmission limit be reached + * to shutdown the association. + */ + if (t->asoc->state < SCTP_STATE_SHUTDOWN_PENDING) + t->asoc->overall_error_count = 0; + + /* Clear the hb_sent flag to signal that we had a good + * acknowledgement. + */ + t->hb_sent = 0; + + /* Mark the destination transport address as active if it is not so + * marked. + */ + if ((t->state == SCTP_INACTIVE) || (t->state == SCTP_UNCONFIRMED)) { + was_unconfirmed = 1; + sctp_assoc_control_transport(asoc, t, SCTP_TRANSPORT_UP, + SCTP_HEARTBEAT_SUCCESS); + } + + if (t->state == SCTP_PF) + sctp_assoc_control_transport(asoc, t, SCTP_TRANSPORT_UP, + SCTP_HEARTBEAT_SUCCESS); + + /* HB-ACK was received for a the proper HB. Consider this + * forward progress. + */ + if (t->dst) + sctp_transport_dst_confirm(t); + + /* The receiver of the HEARTBEAT ACK should also perform an + * RTT measurement for that destination transport address + * using the time value carried in the HEARTBEAT ACK chunk. + * If the transport's rto_pending variable has been cleared, + * it was most likely due to a retransmit. However, we want + * to re-enable it to properly update the rto. + */ + if (t->rto_pending == 0) + t->rto_pending = 1; + + hbinfo = (struct sctp_sender_hb_info *)chunk->skb->data; + sctp_transport_update_rto(t, (jiffies - hbinfo->sent_at)); + + /* Update the heartbeat timer. */ + sctp_transport_reset_hb_timer(t); + + if (was_unconfirmed && asoc->peer.transport_count == 1) + sctp_transport_immediate_rtx(t); +} + + +/* Helper function to process the process SACK command. */ +static int sctp_cmd_process_sack(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + int err = 0; + + if (sctp_outq_sack(&asoc->outqueue, chunk)) { + /* There are no more TSNs awaiting SACK. */ + err = sctp_do_sm(asoc->base.net, SCTP_EVENT_T_OTHER, + SCTP_ST_OTHER(SCTP_EVENT_NO_PENDING_TSN), + asoc->state, asoc->ep, asoc, NULL, + GFP_ATOMIC); + } + + return err; +} + +/* Helper function to set the timeout value for T2-SHUTDOWN timer and to set + * the transport for a shutdown chunk. + */ +static void sctp_cmd_setup_t2(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + struct sctp_transport *t; + + if (chunk->transport) + t = chunk->transport; + else { + t = sctp_assoc_choose_alter_transport(asoc, + asoc->shutdown_last_sent_to); + chunk->transport = t; + } + asoc->shutdown_last_sent_to = t; + asoc->timeouts[SCTP_EVENT_TIMEOUT_T2_SHUTDOWN] = t->rto; +} + +/* Helper function to change the state of an association. */ +static void sctp_cmd_new_state(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + enum sctp_state state) +{ + struct sock *sk = asoc->base.sk; + + asoc->state = state; + + pr_debug("%s: asoc:%p[%s]\n", __func__, asoc, sctp_state_tbl[state]); + + if (sctp_style(sk, TCP)) { + /* Change the sk->sk_state of a TCP-style socket that has + * successfully completed a connect() call. + */ + if (sctp_state(asoc, ESTABLISHED) && sctp_sstate(sk, CLOSED)) + inet_sk_set_state(sk, SCTP_SS_ESTABLISHED); + + /* Set the RCV_SHUTDOWN flag when a SHUTDOWN is received. */ + if (sctp_state(asoc, SHUTDOWN_RECEIVED) && + sctp_sstate(sk, ESTABLISHED)) { + inet_sk_set_state(sk, SCTP_SS_CLOSING); + sk->sk_shutdown |= RCV_SHUTDOWN; + } + } + + if (sctp_state(asoc, COOKIE_WAIT)) { + /* Reset init timeouts since they may have been + * increased due to timer expirations. 
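+ *
+ * (asoc->rto_initial defaults to RTO.Initial from RFC 4960 section 15,
+ *  i.e. 3 seconds, via the net.sctp.rto_initial sysctl.  Assuming the
+ *  usual lksctp socket API it can also be set per socket or association
+ *  with SCTP_RTOINFO; all values are in milliseconds and 'sd' is the
+ *  caller's socket:
+ *
+ *	struct sctp_rtoinfo ri = {
+ *		.srto_initial = 3000,
+ *		.srto_min     = 1000,
+ *		.srto_max     = 60000,
+ *	};
+ *	setsockopt(sd, IPPROTO_SCTP, SCTP_RTOINFO, &ri, sizeof(ri));
+ *  )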
+ */ + asoc->timeouts[SCTP_EVENT_TIMEOUT_T1_INIT] = + asoc->rto_initial; + asoc->timeouts[SCTP_EVENT_TIMEOUT_T1_COOKIE] = + asoc->rto_initial; + } + + if (sctp_state(asoc, ESTABLISHED)) { + kfree(asoc->peer.cookie); + asoc->peer.cookie = NULL; + } + + if (sctp_state(asoc, ESTABLISHED) || + sctp_state(asoc, CLOSED) || + sctp_state(asoc, SHUTDOWN_RECEIVED)) { + /* Wake up any processes waiting in the asoc's wait queue in + * sctp_wait_for_connect() or sctp_wait_for_sndbuf(). + */ + if (waitqueue_active(&asoc->wait)) + wake_up_interruptible(&asoc->wait); + + /* Wake up any processes waiting in the sk's sleep queue of + * a TCP-style or UDP-style peeled-off socket in + * sctp_wait_for_accept() or sctp_wait_for_packet(). + * For a UDP-style socket, the waiters are woken up by the + * notifications. + */ + if (!sctp_style(sk, UDP)) + sk->sk_state_change(sk); + } + + if (sctp_state(asoc, SHUTDOWN_PENDING) && + !sctp_outq_is_empty(&asoc->outqueue)) + sctp_outq_uncork(&asoc->outqueue, GFP_ATOMIC); +} + +/* Helper function to delete an association. */ +static void sctp_cmd_delete_tcb(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc) +{ + struct sock *sk = asoc->base.sk; + + /* If it is a non-temporary association belonging to a TCP-style + * listening socket that is not closed, do not free it so that accept() + * can pick it up later. + */ + if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING) && + (!asoc->temp) && (sk->sk_shutdown != SHUTDOWN_MASK)) + return; + + sctp_association_free(asoc); +} + +/* + * ADDIP Section 4.1 ASCONF Chunk Procedures + * A4) Start a T-4 RTO timer, using the RTO value of the selected + * destination address (we use active path instead of primary path just + * because primary path may be inactive. + */ +static void sctp_cmd_setup_t4(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + struct sctp_transport *t; + + t = sctp_assoc_choose_alter_transport(asoc, chunk->transport); + asoc->timeouts[SCTP_EVENT_TIMEOUT_T4_RTO] = t->rto; + chunk->transport = t; +} + +/* Process an incoming Operation Error Chunk. */ +static void sctp_cmd_process_operr(struct sctp_cmd_seq *cmds, + struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + struct sctp_errhdr *err_hdr; + struct sctp_ulpevent *ev; + + while (chunk->chunk_end > chunk->skb->data) { + err_hdr = (struct sctp_errhdr *)(chunk->skb->data); + + ev = sctp_ulpevent_make_remote_error(asoc, chunk, 0, + GFP_ATOMIC); + if (!ev) + return; + + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + + switch (err_hdr->cause) { + case SCTP_ERROR_UNKNOWN_CHUNK: + { + struct sctp_chunkhdr *unk_chunk_hdr; + + unk_chunk_hdr = (struct sctp_chunkhdr *) + err_hdr->variable; + switch (unk_chunk_hdr->type) { + /* ADDIP 4.1 A9) If the peer responds to an ASCONF with + * an ERROR chunk reporting that it did not recognized + * the ASCONF chunk type, the sender of the ASCONF MUST + * NOT send any further ASCONF chunks and MUST stop its + * T-4 timer. + */ + case SCTP_CID_ASCONF: + if (asoc->peer.asconf_capable == 0) + break; + + asoc->peer.asconf_capable = 0; + sctp_add_cmd_sf(cmds, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + break; + default: + break; + } + break; + } + default: + break; + } + } +} + +/* Helper function to remove the association non-primary peer + * transports. 
+ */ +static void sctp_cmd_del_non_primary(struct sctp_association *asoc) +{ + struct sctp_transport *t; + struct list_head *temp; + struct list_head *pos; + + list_for_each_safe(pos, temp, &asoc->peer.transport_addr_list) { + t = list_entry(pos, struct sctp_transport, transports); + if (!sctp_cmp_addr_exact(&t->ipaddr, + &asoc->peer.primary_addr)) { + sctp_assoc_rm_peer(asoc, t); + } + } +} + +/* Helper function to set sk_err on a 1-1 style socket. */ +static void sctp_cmd_set_sk_err(struct sctp_association *asoc, int error) +{ + struct sock *sk = asoc->base.sk; + + if (!sctp_style(sk, UDP)) + sk->sk_err = error; +} + +/* Helper function to generate an association change event */ +static void sctp_cmd_assoc_change(struct sctp_cmd_seq *commands, + struct sctp_association *asoc, + u8 state) +{ + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_assoc_change(asoc, 0, state, 0, + asoc->c.sinit_num_ostreams, + asoc->c.sinit_max_instreams, + NULL, GFP_ATOMIC); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); +} + +static void sctp_cmd_peer_no_auth(struct sctp_cmd_seq *commands, + struct sctp_association *asoc) +{ + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_authkey(asoc, 0, SCTP_AUTH_NO_AUTH, GFP_ATOMIC); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); +} + +/* Helper function to generate an adaptation indication event */ +static void sctp_cmd_adaptation_ind(struct sctp_cmd_seq *commands, + struct sctp_association *asoc) +{ + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_adaptation_indication(asoc, GFP_ATOMIC); + + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); +} + + +static void sctp_cmd_t1_timer_update(struct sctp_association *asoc, + enum sctp_event_timeout timer, + char *name) +{ + struct sctp_transport *t; + + t = asoc->init_last_sent_to; + asoc->init_err_counter++; + + if (t->init_sent_count > (asoc->init_cycle + 1)) { + asoc->timeouts[timer] *= 2; + if (asoc->timeouts[timer] > asoc->max_init_timeo) { + asoc->timeouts[timer] = asoc->max_init_timeo; + } + asoc->init_cycle++; + + pr_debug("%s: T1[%s] timeout adjustment init_err_counter:%d" + " cycle:%d timeout:%ld\n", __func__, name, + asoc->init_err_counter, asoc->init_cycle, + asoc->timeouts[timer]); + } + +} + +/* Send the whole message, chunk by chunk, to the outqueue. + * This way the whole message is queued up and bundling if + * encouraged for small fragments. + */ +static void sctp_cmd_send_msg(struct sctp_association *asoc, + struct sctp_datamsg *msg, gfp_t gfp) +{ + struct sctp_chunk *chunk; + + list_for_each_entry(chunk, &msg->chunks, frag_list) + sctp_outq_tail(&asoc->outqueue, chunk, gfp); + + asoc->outqueue.sched->enqueue(&asoc->outqueue, msg); +} + + +/* These three macros allow us to pull the debugging code out of the + * main flow of sctp_do_sm() to keep attention focused on the real + * functionality there. + */ +#define debug_pre_sfn() \ + pr_debug("%s[pre-fn]: ep:%p, %s, %s, asoc:%p[%s], %s\n", __func__, \ + ep, sctp_evttype_tbl[event_type], (*debug_fn)(subtype), \ + asoc, sctp_state_tbl[state], state_fn->name) + +#define debug_post_sfn() \ + pr_debug("%s[post-fn]: asoc:%p, status:%s\n", __func__, asoc, \ + sctp_status_tbl[status]) + +#define debug_post_sfx() \ + pr_debug("%s[post-sfx]: error:%d, asoc:%p[%s]\n", __func__, error, \ + asoc, sctp_state_tbl[(asoc && sctp_id2assoc(ep->base.sk, \ + sctp_assoc2id(asoc))) ? asoc->state : SCTP_STATE_CLOSED]) + +/* + * This is the master state machine processing function. 
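+ *
+ * Every event - an inbound chunk, a timeout, a primitive from the socket
+ * layer - funnels through here.  A call from the receive path, for
+ * example, looks roughly like
+ *
+ *	sctp_do_sm(net, SCTP_EVENT_T_CHUNK,
+ *		   SCTP_ST_CHUNK(chunk->chunk_hdr->type),
+ *		   asoc->state, ep, asoc, chunk, GFP_ATOMIC);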
+ * + * If you want to understand all of lksctp, this is a + * good place to start. + */ +int sctp_do_sm(struct net *net, enum sctp_event_type event_type, + union sctp_subtype subtype, enum sctp_state state, + struct sctp_endpoint *ep, struct sctp_association *asoc, + void *event_arg, gfp_t gfp) +{ + typedef const char *(printfn_t)(union sctp_subtype); + static printfn_t *table[] = { + NULL, sctp_cname, sctp_tname, sctp_oname, sctp_pname, + }; + printfn_t *debug_fn __attribute__ ((unused)) = table[event_type]; + const struct sctp_sm_table_entry *state_fn; + struct sctp_cmd_seq commands; + enum sctp_disposition status; + int error = 0; + + /* Look up the state function, run it, and then process the + * side effects. These three steps are the heart of lksctp. + */ + state_fn = sctp_sm_lookup_event(net, event_type, state, subtype); + + sctp_init_cmd_seq(&commands); + + debug_pre_sfn(); + status = state_fn->fn(net, ep, asoc, subtype, event_arg, &commands); + debug_post_sfn(); + + error = sctp_side_effects(event_type, subtype, state, + ep, &asoc, event_arg, status, + &commands, gfp); + debug_post_sfx(); + + return error; +} + +/***************************************************************** + * This the master state function side effect processing function. + *****************************************************************/ +static int sctp_side_effects(enum sctp_event_type event_type, + union sctp_subtype subtype, + enum sctp_state state, + struct sctp_endpoint *ep, + struct sctp_association **asoc, + void *event_arg, + enum sctp_disposition status, + struct sctp_cmd_seq *commands, + gfp_t gfp) +{ + int error; + + /* FIXME - Most of the dispositions left today would be categorized + * as "exceptional" dispositions. For those dispositions, it + * may not be proper to run through any of the commands at all. + * For example, the command interpreter might be run only with + * disposition SCTP_DISPOSITION_CONSUME. + */ + if (0 != (error = sctp_cmd_interpreter(event_type, subtype, state, + ep, *asoc, + event_arg, status, + commands, gfp))) + goto bail; + + switch (status) { + case SCTP_DISPOSITION_DISCARD: + pr_debug("%s: ignored sctp protocol event - state:%d, " + "event_type:%d, event_id:%d\n", __func__, state, + event_type, subtype.chunk); + break; + + case SCTP_DISPOSITION_NOMEM: + /* We ran out of memory, so we need to discard this + * packet. + */ + /* BUG--we should now recover some memory, probably by + * reneging... + */ + error = -ENOMEM; + break; + + case SCTP_DISPOSITION_DELETE_TCB: + case SCTP_DISPOSITION_ABORT: + /* This should now be a command. */ + *asoc = NULL; + break; + + case SCTP_DISPOSITION_CONSUME: + /* + * We should no longer have much work to do here as the + * real work has been done as explicit commands above. 
+ */ + break; + + case SCTP_DISPOSITION_VIOLATION: + net_err_ratelimited("protocol violation state %d chunkid %d\n", + state, subtype.chunk); + break; + + case SCTP_DISPOSITION_NOT_IMPL: + pr_warn("unimplemented feature in state %d, event_type %d, event_id %d\n", + state, event_type, subtype.chunk); + break; + + case SCTP_DISPOSITION_BUG: + pr_err("bug in state %d, event_type %d, event_id %d\n", + state, event_type, subtype.chunk); + BUG(); + break; + + default: + pr_err("impossible disposition %d in state %d, event_type %d, event_id %d\n", + status, state, event_type, subtype.chunk); + error = status; + if (error >= 0) + error = -EINVAL; + WARN_ON_ONCE(1); + break; + } + +bail: + return error; +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* This is the side-effect interpreter. */ +static int sctp_cmd_interpreter(enum sctp_event_type event_type, + union sctp_subtype subtype, + enum sctp_state state, + struct sctp_endpoint *ep, + struct sctp_association *asoc, + void *event_arg, + enum sctp_disposition status, + struct sctp_cmd_seq *commands, + gfp_t gfp) +{ + struct sctp_sock *sp = sctp_sk(ep->base.sk); + struct sctp_chunk *chunk = NULL, *new_obj; + struct sctp_packet *packet; + struct sctp_sackhdr sackh; + struct timer_list *timer; + struct sctp_transport *t; + unsigned long timeout; + struct sctp_cmd *cmd; + int local_cork = 0; + int error = 0; + int force; + + if (SCTP_EVENT_T_TIMEOUT != event_type) + chunk = event_arg; + + /* Note: This whole file is a huge candidate for rework. + * For example, each command could either have its own handler, so + * the loop would look like: + * while (cmds) + * cmd->handle(x, y, z) + * --jgrimm + */ + while (NULL != (cmd = sctp_next_cmd(commands))) { + switch (cmd->verb) { + case SCTP_CMD_NOP: + /* Do nothing. */ + break; + + case SCTP_CMD_NEW_ASOC: + /* Register a new association. */ + if (local_cork) { + sctp_outq_uncork(&asoc->outqueue, gfp); + local_cork = 0; + } + + /* Register with the endpoint. */ + asoc = cmd->obj.asoc; + BUG_ON(asoc->peer.primary_path == NULL); + sctp_endpoint_add_asoc(ep, asoc); + break; + + case SCTP_CMD_PURGE_OUTQUEUE: + sctp_outq_teardown(&asoc->outqueue); + break; + + case SCTP_CMD_DELETE_TCB: + if (local_cork) { + sctp_outq_uncork(&asoc->outqueue, gfp); + local_cork = 0; + } + /* Delete the current association. */ + sctp_cmd_delete_tcb(commands, asoc); + asoc = NULL; + break; + + case SCTP_CMD_NEW_STATE: + /* Enter a new state. */ + sctp_cmd_new_state(commands, asoc, cmd->obj.state); + break; + + case SCTP_CMD_REPORT_TSN: + /* Record the arrival of a TSN. */ + error = sctp_tsnmap_mark(&asoc->peer.tsn_map, + cmd->obj.u32, NULL); + break; + + case SCTP_CMD_REPORT_FWDTSN: + asoc->stream.si->report_ftsn(&asoc->ulpq, cmd->obj.u32); + break; + + case SCTP_CMD_PROCESS_FWDTSN: + asoc->stream.si->handle_ftsn(&asoc->ulpq, + cmd->obj.chunk); + break; + + case SCTP_CMD_GEN_SACK: + /* Generate a Selective ACK. + * The argument tells us whether to just count + * the packet and MAYBE generate a SACK, or + * force a SACK out. + */ + force = cmd->obj.i32; + error = sctp_gen_sack(asoc, force, commands); + break; + + case SCTP_CMD_PROCESS_SACK: + /* Process an inbound SACK. */ + error = sctp_cmd_process_sack(commands, asoc, + cmd->obj.chunk); + break; + + case SCTP_CMD_GEN_INIT_ACK: + /* Generate an INIT ACK chunk. 
*/ + new_obj = sctp_make_init_ack(asoc, chunk, GFP_ATOMIC, + 0); + if (!new_obj) { + error = -ENOMEM; + break; + } + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(new_obj)); + break; + + case SCTP_CMD_PEER_INIT: + /* Process a unified INIT from the peer. + * Note: Only used during INIT-ACK processing. If + * there is an error just return to the outter + * layer which will bail. + */ + error = sctp_cmd_process_init(commands, asoc, chunk, + cmd->obj.init, gfp); + break; + + case SCTP_CMD_GEN_COOKIE_ECHO: + /* Generate a COOKIE ECHO chunk. */ + new_obj = sctp_make_cookie_echo(asoc, chunk); + if (!new_obj) { + if (cmd->obj.chunk) + sctp_chunk_free(cmd->obj.chunk); + error = -ENOMEM; + break; + } + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(new_obj)); + + /* If there is an ERROR chunk to be sent along with + * the COOKIE_ECHO, send it, too. + */ + if (cmd->obj.chunk) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(cmd->obj.chunk)); + + if (new_obj->transport) { + new_obj->transport->init_sent_count++; + asoc->init_last_sent_to = new_obj->transport; + } + + /* FIXME - Eventually come up with a cleaner way to + * enabling COOKIE-ECHO + DATA bundling during + * multihoming stale cookie scenarios, the following + * command plays with asoc->peer.retran_path to + * avoid the problem of sending the COOKIE-ECHO and + * DATA in different paths, which could result + * in the association being ABORTed if the DATA chunk + * is processed first by the server. Checking the + * init error counter simply causes this command + * to be executed only during failed attempts of + * association establishment. + */ + if ((asoc->peer.retran_path != + asoc->peer.primary_path) && + (asoc->init_err_counter > 0)) { + sctp_add_cmd_sf(commands, + SCTP_CMD_FORCE_PRIM_RETRAN, + SCTP_NULL()); + } + + break; + + case SCTP_CMD_GEN_SHUTDOWN: + /* Generate SHUTDOWN when in SHUTDOWN_SENT state. + * Reset error counts. + */ + asoc->overall_error_count = 0; + + /* Generate a SHUTDOWN chunk. */ + new_obj = sctp_make_shutdown(asoc, chunk); + if (!new_obj) { + error = -ENOMEM; + break; + } + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(new_obj)); + break; + + case SCTP_CMD_CHUNK_ULP: + /* Send a chunk to the sockets layer. */ + pr_debug("%s: sm_sideff: chunk_up:%p, ulpq:%p\n", + __func__, cmd->obj.chunk, &asoc->ulpq); + + asoc->stream.si->ulpevent_data(&asoc->ulpq, + cmd->obj.chunk, + GFP_ATOMIC); + break; + + case SCTP_CMD_EVENT_ULP: + /* Send a notification to the sockets layer. */ + pr_debug("%s: sm_sideff: event_up:%p, ulpq:%p\n", + __func__, cmd->obj.ulpevent, &asoc->ulpq); + + asoc->stream.si->enqueue_event(&asoc->ulpq, + cmd->obj.ulpevent); + break; + + case SCTP_CMD_REPLY: + /* If an caller has not already corked, do cork. */ + if (!asoc->outqueue.cork) { + sctp_outq_cork(&asoc->outqueue); + local_cork = 1; + } + /* Send a chunk to our peer. */ + sctp_outq_tail(&asoc->outqueue, cmd->obj.chunk, gfp); + break; + + case SCTP_CMD_SEND_PKT: + /* Send a full packet to our peer. */ + packet = cmd->obj.packet; + sctp_packet_transmit(packet, gfp); + sctp_ootb_pkt_free(packet); + break; + + case SCTP_CMD_T1_RETRAN: + /* Mark a transport for retransmission. */ + sctp_retransmit(&asoc->outqueue, cmd->obj.transport, + SCTP_RTXR_T1_RTX); + break; + + case SCTP_CMD_RETRAN: + /* Mark a transport for retransmission. */ + sctp_retransmit(&asoc->outqueue, cmd->obj.transport, + SCTP_RTXR_T3_RTX); + break; + + case SCTP_CMD_ECN_CE: + /* Do delayed CE processing. 
*/ + sctp_do_ecn_ce_work(asoc, cmd->obj.u32); + break; + + case SCTP_CMD_ECN_ECNE: + /* Do delayed ECNE processing. */ + new_obj = sctp_do_ecn_ecne_work(asoc, cmd->obj.u32, + chunk); + if (new_obj) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(new_obj)); + break; + + case SCTP_CMD_ECN_CWR: + /* Do delayed CWR processing. */ + sctp_do_ecn_cwr_work(asoc, cmd->obj.u32); + break; + + case SCTP_CMD_SETUP_T2: + sctp_cmd_setup_t2(commands, asoc, cmd->obj.chunk); + break; + + case SCTP_CMD_TIMER_START_ONCE: + timer = &asoc->timers[cmd->obj.to]; + + if (timer_pending(timer)) + break; + fallthrough; + + case SCTP_CMD_TIMER_START: + timer = &asoc->timers[cmd->obj.to]; + timeout = asoc->timeouts[cmd->obj.to]; + BUG_ON(!timeout); + + /* + * SCTP has a hard time with timer starts. Because we process + * timer starts as side effects, it can be hard to tell if we + * have already started a timer or not, which leads to BUG + * halts when we call add_timer. So here, instead of just starting + * a timer, if the timer is already started, and just mod + * the timer with the shorter of the two expiration times + */ + if (!timer_pending(timer)) + sctp_association_hold(asoc); + timer_reduce(timer, jiffies + timeout); + break; + + case SCTP_CMD_TIMER_RESTART: + timer = &asoc->timers[cmd->obj.to]; + timeout = asoc->timeouts[cmd->obj.to]; + if (!mod_timer(timer, jiffies + timeout)) + sctp_association_hold(asoc); + break; + + case SCTP_CMD_TIMER_STOP: + timer = &asoc->timers[cmd->obj.to]; + if (del_timer(timer)) + sctp_association_put(asoc); + break; + + case SCTP_CMD_INIT_CHOOSE_TRANSPORT: + chunk = cmd->obj.chunk; + t = sctp_assoc_choose_alter_transport(asoc, + asoc->init_last_sent_to); + asoc->init_last_sent_to = t; + chunk->transport = t; + t->init_sent_count++; + /* Set the new transport as primary */ + sctp_assoc_set_primary(asoc, t); + break; + + case SCTP_CMD_INIT_RESTART: + /* Do the needed accounting and updates + * associated with restarting an initialization + * timer. Only multiply the timeout by two if + * all transports have been tried at the current + * timeout. + */ + sctp_cmd_t1_timer_update(asoc, + SCTP_EVENT_TIMEOUT_T1_INIT, + "INIT"); + + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + break; + + case SCTP_CMD_COOKIEECHO_RESTART: + /* Do the needed accounting and updates + * associated with restarting an initialization + * timer. Only multiply the timeout by two if + * all transports have been tried at the current + * timeout. + */ + sctp_cmd_t1_timer_update(asoc, + SCTP_EVENT_TIMEOUT_T1_COOKIE, + "COOKIE"); + + /* If we've sent any data bundled with + * COOKIE-ECHO we need to resend. 
+ */ + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + sctp_retransmit_mark(&asoc->outqueue, t, + SCTP_RTXR_T1_RTX); + } + + sctp_add_cmd_sf(commands, + SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + break; + + case SCTP_CMD_INIT_FAILED: + sctp_cmd_init_failed(commands, asoc, cmd->obj.u16); + break; + + case SCTP_CMD_ASSOC_FAILED: + sctp_cmd_assoc_failed(commands, asoc, event_type, + subtype, chunk, cmd->obj.u16); + break; + + case SCTP_CMD_INIT_COUNTER_INC: + asoc->init_err_counter++; + break; + + case SCTP_CMD_INIT_COUNTER_RESET: + asoc->init_err_counter = 0; + asoc->init_cycle = 0; + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + t->init_sent_count = 0; + } + break; + + case SCTP_CMD_REPORT_DUP: + sctp_tsnmap_mark_dup(&asoc->peer.tsn_map, + cmd->obj.u32); + break; + + case SCTP_CMD_REPORT_BAD_TAG: + pr_debug("%s: vtag mismatch!\n", __func__); + break; + + case SCTP_CMD_STRIKE: + /* Mark one strike against a transport. */ + sctp_do_8_2_transport_strike(commands, asoc, + cmd->obj.transport, 0); + break; + + case SCTP_CMD_TRANSPORT_IDLE: + t = cmd->obj.transport; + sctp_transport_lower_cwnd(t, SCTP_LOWER_CWND_INACTIVE); + break; + + case SCTP_CMD_TRANSPORT_HB_SENT: + t = cmd->obj.transport; + sctp_do_8_2_transport_strike(commands, asoc, + t, 1); + t->hb_sent = 1; + break; + + case SCTP_CMD_TRANSPORT_ON: + t = cmd->obj.transport; + sctp_cmd_transport_on(commands, asoc, t, chunk); + break; + + case SCTP_CMD_HB_TIMERS_START: + sctp_cmd_hb_timers_start(commands, asoc); + break; + + case SCTP_CMD_HB_TIMER_UPDATE: + t = cmd->obj.transport; + sctp_transport_reset_hb_timer(t); + break; + + case SCTP_CMD_HB_TIMERS_STOP: + sctp_cmd_hb_timers_stop(commands, asoc); + break; + + case SCTP_CMD_PROBE_TIMER_UPDATE: + t = cmd->obj.transport; + sctp_transport_reset_probe_timer(t); + break; + + case SCTP_CMD_REPORT_ERROR: + error = cmd->obj.error; + break; + + case SCTP_CMD_PROCESS_CTSN: + /* Dummy up a SACK for processing. */ + sackh.cum_tsn_ack = cmd->obj.be32; + sackh.a_rwnd = htonl(asoc->peer.rwnd + + asoc->outqueue.outstanding_bytes); + sackh.num_gap_ack_blocks = 0; + sackh.num_dup_tsns = 0; + chunk->subh.sack_hdr = &sackh; + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_SACK, + SCTP_CHUNK(chunk)); + break; + + case SCTP_CMD_DISCARD_PACKET: + /* We need to discard the whole packet. 
+ * Uncork the queue since there might be + * responses pending + */ + chunk->pdiscard = 1; + if (asoc) { + sctp_outq_uncork(&asoc->outqueue, gfp); + local_cork = 0; + } + break; + + case SCTP_CMD_RTO_PENDING: + t = cmd->obj.transport; + t->rto_pending = 1; + break; + + case SCTP_CMD_PART_DELIVER: + asoc->stream.si->start_pd(&asoc->ulpq, GFP_ATOMIC); + break; + + case SCTP_CMD_RENEGE: + asoc->stream.si->renege_events(&asoc->ulpq, + cmd->obj.chunk, + GFP_ATOMIC); + break; + + case SCTP_CMD_SETUP_T4: + sctp_cmd_setup_t4(commands, asoc, cmd->obj.chunk); + break; + + case SCTP_CMD_PROCESS_OPERR: + sctp_cmd_process_operr(commands, asoc, chunk); + break; + case SCTP_CMD_CLEAR_INIT_TAG: + asoc->peer.i.init_tag = 0; + break; + case SCTP_CMD_DEL_NON_PRIMARY: + sctp_cmd_del_non_primary(asoc); + break; + case SCTP_CMD_T3_RTX_TIMERS_STOP: + sctp_cmd_t3_rtx_timers_stop(commands, asoc); + break; + case SCTP_CMD_FORCE_PRIM_RETRAN: + t = asoc->peer.retran_path; + asoc->peer.retran_path = asoc->peer.primary_path; + sctp_outq_uncork(&asoc->outqueue, gfp); + local_cork = 0; + asoc->peer.retran_path = t; + break; + case SCTP_CMD_SET_SK_ERR: + sctp_cmd_set_sk_err(asoc, cmd->obj.error); + break; + case SCTP_CMD_ASSOC_CHANGE: + sctp_cmd_assoc_change(commands, asoc, + cmd->obj.u8); + break; + case SCTP_CMD_ADAPTATION_IND: + sctp_cmd_adaptation_ind(commands, asoc); + break; + case SCTP_CMD_PEER_NO_AUTH: + sctp_cmd_peer_no_auth(commands, asoc); + break; + + case SCTP_CMD_ASSOC_SHKEY: + error = sctp_auth_asoc_init_active_key(asoc, + GFP_ATOMIC); + break; + case SCTP_CMD_UPDATE_INITTAG: + asoc->peer.i.init_tag = cmd->obj.u32; + break; + case SCTP_CMD_SEND_MSG: + if (!asoc->outqueue.cork) { + sctp_outq_cork(&asoc->outqueue); + local_cork = 1; + } + sctp_cmd_send_msg(asoc, cmd->obj.msg, gfp); + break; + case SCTP_CMD_PURGE_ASCONF_QUEUE: + sctp_asconf_queue_teardown(asoc); + break; + + case SCTP_CMD_SET_ASOC: + if (asoc && local_cork) { + sctp_outq_uncork(&asoc->outqueue, gfp); + local_cork = 0; + } + asoc = cmd->obj.asoc; + break; + + default: + pr_warn("Impossible command: %u\n", + cmd->verb); + break; + } + + if (error) { + cmd = sctp_next_cmd(commands); + while (cmd) { + if (cmd->verb == SCTP_CMD_REPLY) + sctp_chunk_free(cmd->obj.chunk); + cmd = sctp_next_cmd(commands); + } + break; + } + } + + /* If this is in response to a received chunk, wait until + * we are done with the packet to open the queue so that we don't + * send multiple packets in response to a single request. + */ + if (asoc && SCTP_EVENT_T_CHUNK == event_type && chunk) { + if (chunk->end_of_packet || chunk->singleton) + sctp_outq_uncork(&asoc->outqueue, gfp); + } else if (local_cork) + sctp_outq_uncork(&asoc->outqueue, gfp); + + if (sp->data_ready_signalled) + sp->data_ready_signalled = 0; + + return error; +} diff --git a/net/sctp/sm_statefuns.c b/net/sctp/sm_statefuns.c new file mode 100644 index 000000000..5383b6a9d --- /dev/null +++ b/net/sctp/sm_statefuns.c @@ -0,0 +1,6677 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2002 Intel Corp. + * Copyright (c) 2002 Nokia Corp. + * + * This is part of the SCTP Linux Kernel Implementation. + * + * These are the state functions for the state machine. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. 
Yarroll + * Karl Knutson + * Mathew Kotowsky + * Sridhar Samudrala + * Jon Grimm + * Hui Huang + * Dajiang Zhang + * Daisy Chang + * Ardelle Fan + * Ryan Layer + * Kevin Gao + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CREATE_TRACE_POINTS +#include + +static struct sctp_packet *sctp_abort_pkt_new( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + const void *payload, size_t paylen); +static int sctp_eat_data(const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands); +static struct sctp_packet *sctp_ootb_pkt_new( + struct net *net, + const struct sctp_association *asoc, + const struct sctp_chunk *chunk); +static void sctp_send_stale_cookie_err(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_chunk *err_chunk); +static enum sctp_disposition sctp_sf_do_5_2_6_stale( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); +static enum sctp_disposition sctp_sf_shut_8_4_5( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); +static enum sctp_disposition sctp_sf_tabort_8_4_8( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); +static enum sctp_disposition sctp_sf_new_encap_port( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); +static struct sctp_sackhdr *sctp_sm_pull_sack(struct sctp_chunk *chunk); + +static enum sctp_disposition sctp_stop_t1_and_abort( + struct net *net, + struct sctp_cmd_seq *commands, + __be16 error, int sk_err, + const struct sctp_association *asoc, + struct sctp_transport *transport); + +static enum sctp_disposition sctp_sf_abort_violation( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + void *arg, + struct sctp_cmd_seq *commands, + const __u8 *payload, + const size_t paylen); + +static enum sctp_disposition sctp_sf_violation_chunklen( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); + +static enum sctp_disposition sctp_sf_violation_paramlen( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, void *ext, + struct sctp_cmd_seq *commands); + +static enum sctp_disposition sctp_sf_violation_ctsn( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); + +static enum sctp_disposition sctp_sf_violation_chunk( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); + +static enum sctp_ierror sctp_sf_authenticate( + const struct sctp_association *asoc, + 
struct sctp_chunk *chunk); + +static enum sctp_disposition __sctp_sf_do_9_1_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands); + +static enum sctp_disposition +__sctp_sf_do_9_2_reshutack(struct net *net, const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, void *arg, + struct sctp_cmd_seq *commands); + +/* Small helper function that checks if the chunk length + * is of the appropriate length. The 'required_length' argument + * is set to be the size of a specific chunk we are testing. + * Return Values: true = Valid length + * false = Invalid length + * + */ +static inline bool sctp_chunk_length_valid(struct sctp_chunk *chunk, + __u16 required_length) +{ + __u16 chunk_length = ntohs(chunk->chunk_hdr->length); + + /* Previously already marked? */ + if (unlikely(chunk->pdiscard)) + return false; + if (unlikely(chunk_length < required_length)) + return false; + + return true; +} + +/* Check for format error in an ABORT chunk */ +static inline bool sctp_err_chunk_valid(struct sctp_chunk *chunk) +{ + struct sctp_errhdr *err; + + sctp_walk_errors(err, chunk->chunk_hdr); + + return (void *)err == (void *)chunk->chunk_end; +} + +/********************************************************** + * These are the state functions for handling chunk events. + **********************************************************/ + +/* + * Process the final SHUTDOWN COMPLETE. + * + * Section: 4 (C) (diagram), 9.2 + * Upon reception of the SHUTDOWN COMPLETE chunk the endpoint will verify + * that it is in SHUTDOWN-ACK-SENT state, if it is not the chunk should be + * discarded. If the endpoint is in the SHUTDOWN-ACK-SENT state the endpoint + * should stop the T2-shutdown timer and remove all knowledge of the + * association (and thus the association enters the CLOSED state). + * + * Verification Tag: 8.5.1(C), sctpimpguide 2.41. + * C) Rules for packet carrying SHUTDOWN COMPLETE: + * ... + * - The receiver of a SHUTDOWN COMPLETE shall accept the packet + * if the Verification Tag field of the packet matches its own tag and + * the T bit is not set + * OR + * it is set to its peer's tag and the T bit is set in the Chunk + * Flags. + * Otherwise, the receiver MUST silently discard the packet + * and take no further action. An endpoint MUST ignore the + * SHUTDOWN COMPLETE if it is not in the SHUTDOWN-ACK-SENT state. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_4_C(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_ulpevent *ev; + + if (!sctp_vtag_verify_either(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* RFC 2960 6.10 Bundling + * + * An endpoint MUST NOT bundle INIT, INIT ACK or + * SHUTDOWN COMPLETE with any other chunks. + */ + if (!chunk->singleton) + return sctp_sf_violation_chunk(net, ep, asoc, type, arg, commands); + + /* Make sure that the SHUTDOWN_COMPLETE chunk has a valid length. 
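+ *
+ * A SHUTDOWN COMPLETE carries no parameters, so the minimum (and in
+ * practice the whole) chunk is the common 4-byte header, where type is
+ * SCTP_CID_SHUTDOWN_COMPLETE (14), the low-order bit of flags is the
+ * reflection (T) bit, and length must be at least
+ * sizeof(struct sctp_chunkhdr):
+ *
+ *	struct sctp_chunkhdr {
+ *		__u8	type;
+ *		__u8	flags;
+ *		__be16	length;
+ *	};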
*/ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* RFC 2960 10.2 SCTP-to-ULP + * + * H) SHUTDOWN COMPLETE notification + * + * When SCTP completes the shutdown procedures (section 9.2) this + * notification is passed to the upper layer. + */ + ev = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_SHUTDOWN_COMP, + 0, 0, 0, NULL, GFP_ATOMIC); + if (ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + + /* Upon reception of the SHUTDOWN COMPLETE chunk the endpoint + * will verify that it is in SHUTDOWN-ACK-SENT state, if it is + * not the chunk should be discarded. If the endpoint is in + * the SHUTDOWN-ACK-SENT state the endpoint should stop the + * T2-shutdown timer and remove all knowledge of the + * association (and thus the association enters the CLOSED + * state). + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + + SCTP_INC_STATS(net, SCTP_MIB_SHUTDOWNS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + + return SCTP_DISPOSITION_DELETE_TCB; +} + +/* + * Respond to a normal INIT chunk. + * We are the side that is being asked for an association. + * + * Section: 5.1 Normal Establishment of an Association, B + * B) "Z" shall respond immediately with an INIT ACK chunk. The + * destination IP address of the INIT ACK MUST be set to the source + * IP address of the INIT to which this INIT ACK is responding. In + * the response, besides filling in other parameters, "Z" must set the + * Verification Tag field to Tag_A, and also provide its own + * Verification Tag (Tag_Z) in the Initiate Tag field. + * + * Verification Tag: Must be 0. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_1B_init(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg, *repl, *err_chunk; + struct sctp_unrecognized_param *unk_param; + struct sctp_association *new_asoc; + struct sctp_packet *packet; + int len; + + /* 6.10 Bundling + * An endpoint MUST NOT bundle INIT, INIT ACK or + * SHUTDOWN COMPLETE with any other chunks. + * + * IG Section 2.11.2 + * Furthermore, we require that the receiver of an INIT chunk MUST + * enforce these rules by silently discarding an arriving packet + * with an INIT chunk that is bundled with other chunks. + */ + if (!chunk->singleton) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the INIT chunk has a valid length. + * Normally, this would cause an ABORT with a Protocol Violation + * error, but since we don't have an association, we'll + * just discard the packet. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_init_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* If the packet is an OOTB packet which is temporarily on the + * control endpoint, respond with an ABORT. 
+ */ + if (ep == sctp_sk(net->sctp.ctl_sock)->ep) { + SCTP_INC_STATS(net, SCTP_MIB_OUTOFBLUES); + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + } + + /* 3.1 A packet containing an INIT chunk MUST have a zero Verification + * Tag. + */ + if (chunk->sctp_hdr->vtag != 0) + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + + /* If the INIT is coming toward a closing socket, we'll send back + * and ABORT. Essentially, this catches the race of INIT being + * backloged to the socket at the same time as the user issues close(). + * Since the socket and all its associations are going away, we + * can treat this OOTB + */ + if (sctp_sstate(ep->base.sk, CLOSING)) + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + + /* Verify the INIT chunk before processing it. */ + err_chunk = NULL; + if (!sctp_verify_init(net, ep, asoc, chunk->chunk_hdr->type, + (struct sctp_init_chunk *)chunk->chunk_hdr, chunk, + &err_chunk)) { + /* This chunk contains fatal error. It is to be discarded. + * Send an ABORT, with causes if there is any. + */ + if (err_chunk) { + packet = sctp_abort_pkt_new(net, ep, asoc, arg, + (__u8 *)(err_chunk->chunk_hdr) + + sizeof(struct sctp_chunkhdr), + ntohs(err_chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr)); + + sctp_chunk_free(err_chunk); + + if (packet) { + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + return SCTP_DISPOSITION_CONSUME; + } else { + return SCTP_DISPOSITION_NOMEM; + } + } else { + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, + commands); + } + } + + /* Grab the INIT header. */ + chunk->subh.init_hdr = (struct sctp_inithdr *)chunk->skb->data; + + /* Tag the variable length parameters. */ + chunk->param_hdr.v = skb_pull(chunk->skb, sizeof(struct sctp_inithdr)); + + new_asoc = sctp_make_temp_asoc(ep, chunk, GFP_ATOMIC); + if (!new_asoc) + goto nomem; + + /* Update socket peer label if first association. */ + if (security_sctp_assoc_request(new_asoc, chunk->skb)) { + sctp_association_free(new_asoc); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (sctp_assoc_set_bind_addr_from_ep(new_asoc, + sctp_scope(sctp_source(chunk)), + GFP_ATOMIC) < 0) + goto nomem_init; + + /* The call, sctp_process_init(), can fail on memory allocation. */ + if (!sctp_process_init(new_asoc, chunk, sctp_source(chunk), + (struct sctp_init_chunk *)chunk->chunk_hdr, + GFP_ATOMIC)) + goto nomem_init; + + /* B) "Z" shall respond immediately with an INIT ACK chunk. */ + + /* If there are errors need to be reported for unknown parameters, + * make sure to reserve enough room in the INIT ACK for them. + */ + len = 0; + if (err_chunk) + len = ntohs(err_chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr); + + repl = sctp_make_init_ack(new_asoc, chunk, GFP_ATOMIC, len); + if (!repl) + goto nomem_init; + + /* If there are errors need to be reported for unknown parameters, + * include them in the outgoing INIT ACK as "Unrecognized parameter" + * parameter. + */ + if (err_chunk) { + /* Get the "Unrecognized parameter" parameter(s) out of the + * ERROR chunk generated by sctp_verify_init(). Since the + * error cause code for "unknown parameter" and the + * "Unrecognized parameter" type is the same, we can + * construct the parameters in INIT ACK by copying the + * ERROR causes over. 
+ */ + unk_param = (struct sctp_unrecognized_param *) + ((__u8 *)(err_chunk->chunk_hdr) + + sizeof(struct sctp_chunkhdr)); + /* Replace the cause code with the "Unrecognized parameter" + * parameter type. + */ + sctp_addto_chunk(repl, len, unk_param); + sctp_chunk_free(err_chunk); + } + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_ASOC, SCTP_ASOC(new_asoc)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + + /* + * Note: After sending out INIT ACK with the State Cookie parameter, + * "Z" MUST NOT allocate any resources, nor keep any states for the + * new association. Otherwise, "Z" will be vulnerable to resource + * attacks. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + + return SCTP_DISPOSITION_DELETE_TCB; + +nomem_init: + sctp_association_free(new_asoc); +nomem: + if (err_chunk) + sctp_chunk_free(err_chunk); + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Respond to a normal INIT ACK chunk. + * We are the side that is initiating the association. + * + * Section: 5.1 Normal Establishment of an Association, C + * C) Upon reception of the INIT ACK from "Z", "A" shall stop the T1-init + * timer and leave COOKIE-WAIT state. "A" shall then send the State + * Cookie received in the INIT ACK chunk in a COOKIE ECHO chunk, start + * the T1-cookie timer, and enter the COOKIE-ECHOED state. + * + * Note: The COOKIE ECHO chunk can be bundled with any pending outbound + * DATA chunks, but it MUST be the first chunk in the packet and + * until the COOKIE ACK is returned the sender MUST NOT send any + * other packets to the peer. + * + * Verification Tag: 3.3.3 + * If the value of the Initiate Tag in a received INIT ACK chunk is + * found to be 0, the receiver MUST treat it as an error and close the + * association by transmitting an ABORT. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_1C_ack(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_init_chunk *initchunk; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *err_chunk; + struct sctp_packet *packet; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* 6.10 Bundling + * An endpoint MUST NOT bundle INIT, INIT ACK or + * SHUTDOWN COMPLETE with any other chunks. + */ + if (!chunk->singleton) + return sctp_sf_violation_chunk(net, ep, asoc, type, arg, commands); + + /* Make sure that the INIT-ACK chunk has a valid length */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_initack_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + /* Grab the INIT header. */ + chunk->subh.init_hdr = (struct sctp_inithdr *)chunk->skb->data; + + /* Verify the INIT chunk before processing it. */ + err_chunk = NULL; + if (!sctp_verify_init(net, ep, asoc, chunk->chunk_hdr->type, + (struct sctp_init_chunk *)chunk->chunk_hdr, chunk, + &err_chunk)) { + + enum sctp_error error = SCTP_ERROR_NO_RESOURCE; + + /* This chunk contains fatal error. It is to be discarded. + * Send an ABORT, with causes. If there are no causes, + * then there wasn't enough memory. Just terminate + * the association. 
+ */ + if (err_chunk) { + packet = sctp_abort_pkt_new(net, ep, asoc, arg, + (__u8 *)(err_chunk->chunk_hdr) + + sizeof(struct sctp_chunkhdr), + ntohs(err_chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr)); + + sctp_chunk_free(err_chunk); + + if (packet) { + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + error = SCTP_ERROR_INV_PARAM; + } + } + + /* SCTP-AUTH, Section 6.3: + * It should be noted that if the receiver wants to tear + * down an association in an authenticated way only, the + * handling of malformed packets should not result in + * tearing down the association. + * + * This means that if we only want to abort associations + * in an authenticated way (i.e AUTH+ABORT), then we + * can't destroy this association just because the packet + * was malformed. + */ + if (sctp_auth_recv_cid(SCTP_CID_ABORT, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + return sctp_stop_t1_and_abort(net, commands, error, ECONNREFUSED, + asoc, chunk->transport); + } + + /* Tag the variable length parameters. Note that we never + * convert the parameters in an INIT chunk. + */ + chunk->param_hdr.v = skb_pull(chunk->skb, sizeof(struct sctp_inithdr)); + + initchunk = (struct sctp_init_chunk *)chunk->chunk_hdr; + + sctp_add_cmd_sf(commands, SCTP_CMD_PEER_INIT, + SCTP_PEER_INIT(initchunk)); + + /* Reset init error count upon receipt of INIT-ACK. */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_COUNTER_RESET, SCTP_NULL()); + + /* 5.1 C) "A" shall stop the T1-init timer and leave + * COOKIE-WAIT state. "A" shall then ... start the T1-cookie + * timer, and enter the COOKIE-ECHOED state. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_COOKIE_ECHOED)); + + /* SCTP-AUTH: generate the association shared keys so that + * we can potentially sign the COOKIE-ECHO. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_SHKEY, SCTP_NULL()); + + /* 5.1 C) "A" shall then send the State Cookie received in the + * INIT ACK chunk in a COOKIE ECHO chunk, ... + */ + /* If there is any errors to report, send the ERROR chunk generated + * for unknown parameters as well. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_COOKIE_ECHO, + SCTP_CHUNK(err_chunk)); + + return SCTP_DISPOSITION_CONSUME; +} + +static bool sctp_auth_chunk_verify(struct net *net, struct sctp_chunk *chunk, + const struct sctp_association *asoc) +{ + struct sctp_chunk auth; + + if (!chunk->auth_chunk) + return true; + + /* SCTP-AUTH: auth_chunk pointer is only set when the cookie-echo + * is supposed to be authenticated and we have to do delayed + * authentication. We've just recreated the association using + * the information in the cookie and now it's much easier to + * do the authentication. 
+ */ + + /* Make sure that we and the peer are AUTH capable */ + if (!net->sctp.auth_enable || !asoc->peer.auth_capable) + return false; + + /* set-up our fake chunk so that we can process it */ + auth.skb = chunk->auth_chunk; + auth.asoc = chunk->asoc; + auth.sctp_hdr = chunk->sctp_hdr; + auth.chunk_hdr = (struct sctp_chunkhdr *) + skb_push(chunk->auth_chunk, + sizeof(struct sctp_chunkhdr)); + skb_pull(chunk->auth_chunk, sizeof(struct sctp_chunkhdr)); + auth.transport = chunk->transport; + + return sctp_sf_authenticate(asoc, &auth) == SCTP_IERROR_NO_ERROR; +} + +/* + * Respond to a normal COOKIE ECHO chunk. + * We are the side that is being asked for an association. + * + * Section: 5.1 Normal Establishment of an Association, D + * D) Upon reception of the COOKIE ECHO chunk, Endpoint "Z" will reply + * with a COOKIE ACK chunk after building a TCB and moving to + * the ESTABLISHED state. A COOKIE ACK chunk may be bundled with + * any pending DATA chunks (and/or SACK chunks), but the COOKIE ACK + * chunk MUST be the first chunk in the packet. + * + * IMPLEMENTATION NOTE: An implementation may choose to send the + * Communication Up notification to the SCTP user upon reception + * of a valid COOKIE ECHO chunk. + * + * Verification Tag: 8.5.1 Exceptions in Verification Tag Rules + * D) Rules for packet carrying a COOKIE ECHO + * + * - When sending a COOKIE ECHO, the endpoint MUST use the value of the + * Initial Tag received in the INIT ACK. + * + * - The receiver of a COOKIE ECHO follows the procedures in Section 5. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_1D_ce(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_ulpevent *ev, *ai_ev = NULL, *auth_ev = NULL; + struct sctp_association *new_asoc; + struct sctp_init_chunk *peer_init; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *err_chk_p; + struct sctp_chunk *repl; + struct sock *sk; + int error = 0; + + if (asoc && !sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* If the packet is an OOTB packet which is temporarily on the + * control endpoint, respond with an ABORT. + */ + if (ep == sctp_sk(net->sctp.ctl_sock)->ep) { + SCTP_INC_STATS(net, SCTP_MIB_OUTOFBLUES); + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + } + + /* Make sure that the COOKIE_ECHO chunk has a valid length. + * In this case, we check that we have enough for at least a + * chunk header. More detailed verification is done + * in sctp_unpack_cookie(). + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* If the endpoint is not listening or if the number of associations + * on the TCP-style socket exceed the max backlog, respond with an + * ABORT. + */ + sk = ep->base.sk; + if (!sctp_sstate(sk, LISTENING) || + (sctp_style(sk, TCP) && sk_acceptq_is_full(sk))) + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + + /* "Decode" the chunk. We have no optional parameters so we + * are in good shape. 
+ */ + chunk->subh.cookie_hdr = + (struct sctp_signed_cookie *)chunk->skb->data; + if (!pskb_pull(chunk->skb, ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr))) + goto nomem; + + /* 5.1 D) Upon reception of the COOKIE ECHO chunk, Endpoint + * "Z" will reply with a COOKIE ACK chunk after building a TCB + * and moving to the ESTABLISHED state. + */ + new_asoc = sctp_unpack_cookie(ep, asoc, chunk, GFP_ATOMIC, &error, + &err_chk_p); + + /* FIXME: + * If the re-build failed, what is the proper error path + * from here? + * + * [We should abort the association. --piggy] + */ + if (!new_asoc) { + /* FIXME: Several errors are possible. A bad cookie should + * be silently discarded, but think about logging it too. + */ + switch (error) { + case -SCTP_IERROR_NOMEM: + goto nomem; + + case -SCTP_IERROR_STALE_COOKIE: + sctp_send_stale_cookie_err(net, ep, asoc, chunk, commands, + err_chk_p); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + case -SCTP_IERROR_BAD_SIG: + default: + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + } + + if (security_sctp_assoc_request(new_asoc, chunk->head_skb ?: chunk->skb)) { + sctp_association_free(new_asoc); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Delay state machine commands until later. + * + * Re-build the bind address for the association is done in + * the sctp_unpack_cookie() already. + */ + /* This is a brand-new association, so these are not yet side + * effects--it is safe to run them here. + */ + peer_init = &chunk->subh.cookie_hdr->c.peer_init[0]; + + if (!sctp_process_init(new_asoc, chunk, + &chunk->subh.cookie_hdr->c.peer_addr, + peer_init, GFP_ATOMIC)) + goto nomem_init; + + /* SCTP-AUTH: Now that we've populate required fields in + * sctp_process_init, set up the association shared keys as + * necessary so that we can potentially authenticate the ACK + */ + error = sctp_auth_asoc_init_active_key(new_asoc, GFP_ATOMIC); + if (error) + goto nomem_init; + + if (!sctp_auth_chunk_verify(net, chunk, new_asoc)) { + sctp_association_free(new_asoc); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + repl = sctp_make_cookie_ack(new_asoc, chunk); + if (!repl) + goto nomem_init; + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * D) IMPLEMENTATION NOTE: An implementation may choose to + * send the Communication Up notification to the SCTP user + * upon reception of a valid COOKIE ECHO chunk. + */ + ev = sctp_ulpevent_make_assoc_change(new_asoc, 0, SCTP_COMM_UP, 0, + new_asoc->c.sinit_num_ostreams, + new_asoc->c.sinit_max_instreams, + NULL, GFP_ATOMIC); + if (!ev) + goto nomem_ev; + + /* Sockets API Draft Section 5.3.1.6 + * When a peer sends a Adaptation Layer Indication parameter , SCTP + * delivers this notification to inform the application that of the + * peers requested adaptation layer. + */ + if (new_asoc->peer.adaptation_ind) { + ai_ev = sctp_ulpevent_make_adaptation_indication(new_asoc, + GFP_ATOMIC); + if (!ai_ev) + goto nomem_aiev; + } + + if (!new_asoc->peer.auth_capable) { + auth_ev = sctp_ulpevent_make_authkey(new_asoc, 0, + SCTP_AUTH_NO_AUTH, + GFP_ATOMIC); + if (!auth_ev) + goto nomem_authev; + } + + /* Add all the state machine commands now since we've created + * everything. This way we don't introduce memory corruptions + * during side-effect processing and correctly count established + * associations. 
+ */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_ASOC, SCTP_ASOC(new_asoc)); + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_ESTABLISHED)); + SCTP_INC_STATS(net, SCTP_MIB_CURRESTAB); + SCTP_INC_STATS(net, SCTP_MIB_PASSIVEESTABS); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_START, SCTP_NULL()); + + if (new_asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + + /* This will send the COOKIE ACK */ + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + + /* Queue the ASSOC_CHANGE event */ + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(ev)); + + /* Send up the Adaptation Layer Indication event */ + if (ai_ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ai_ev)); + + if (auth_ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(auth_ev)); + + return SCTP_DISPOSITION_CONSUME; + +nomem_authev: + sctp_ulpevent_free(ai_ev); +nomem_aiev: + sctp_ulpevent_free(ev); +nomem_ev: + sctp_chunk_free(repl); +nomem_init: + sctp_association_free(new_asoc); +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Respond to a normal COOKIE ACK chunk. + * We are the side that is asking for an association. + * + * RFC 2960 5.1 Normal Establishment of an Association + * + * E) Upon reception of the COOKIE ACK, endpoint "A" will move from the + * COOKIE-ECHOED state to the ESTABLISHED state, stopping the T1-cookie + * timer. It may also notify its ULP about the successful + * establishment of the association with a Communication Up + * notification (see Section 10). + * + * Verification Tag: + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_1E_ca(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_ulpevent *ev; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Set peer label for connection. */ + if (security_sctp_assoc_established((struct sctp_association *)asoc, + chunk->head_skb ?: chunk->skb)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Verify that the chunk length for the COOKIE-ACK is OK. + * If we don't do this, any bundled chunks may be junked. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Reset init error count upon receipt of COOKIE-ACK, + * to avoid problems with the management of this + * counter in stale cookie situations when a transition back + * from the COOKIE-ECHOED state to the COOKIE-WAIT + * state is performed. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_COUNTER_RESET, SCTP_NULL()); + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * E) Upon reception of the COOKIE ACK, endpoint "A" will move + * from the COOKIE-ECHOED state to the ESTABLISHED state, + * stopping the T1-cookie timer. 
+ */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_ESTABLISHED)); + SCTP_INC_STATS(net, SCTP_MIB_CURRESTAB); + SCTP_INC_STATS(net, SCTP_MIB_ACTIVEESTABS); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_START, SCTP_NULL()); + if (asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + + /* It may also notify its ULP about the successful + * establishment of the association with a Communication Up + * notification (see Section 10). + */ + ev = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_COMM_UP, + 0, asoc->c.sinit_num_ostreams, + asoc->c.sinit_max_instreams, + NULL, GFP_ATOMIC); + + if (!ev) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(ev)); + + /* Sockets API Draft Section 5.3.1.6 + * When a peer sends a Adaptation Layer Indication parameter , SCTP + * delivers this notification to inform the application that of the + * peers requested adaptation layer. + */ + if (asoc->peer.adaptation_ind) { + ev = sctp_ulpevent_make_adaptation_indication(asoc, GFP_ATOMIC); + if (!ev) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + } + + if (!asoc->peer.auth_capable) { + ev = sctp_ulpevent_make_authkey(asoc, 0, SCTP_AUTH_NO_AUTH, + GFP_ATOMIC); + if (!ev) + goto nomem; + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + } + + return SCTP_DISPOSITION_CONSUME; +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* Generate and sendout a heartbeat packet. */ +static enum sctp_disposition sctp_sf_heartbeat( + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *transport = (struct sctp_transport *) arg; + struct sctp_chunk *reply; + + /* Send a heartbeat to our peer. */ + reply = sctp_make_heartbeat(asoc, transport, 0); + if (!reply) + return SCTP_DISPOSITION_NOMEM; + + /* Set rto_pending indicating that an RTT measurement + * is started with this heartbeat chunk. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_RTO_PENDING, + SCTP_TRANSPORT(transport)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + return SCTP_DISPOSITION_CONSUME; +} + +/* Generate a HEARTBEAT packet on the given transport. */ +enum sctp_disposition sctp_sf_sendbeat_8_3(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *transport = (struct sctp_transport *) arg; + + if (asoc->overall_error_count >= asoc->max_retrans) { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + /* CMD_ASSOC_FAILED calls CMD_DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_DELETE_TCB; + } + + /* Section 3.3.5. + * The Sender-specific Heartbeat Info field should normally include + * information about the sender's current time when this HEARTBEAT + * chunk is sent and the destination transport address to which this + * HEARTBEAT is sent (see Section 8.3). 
+ */ + + if (transport->param_flags & SPP_HB_ENABLE) { + if (SCTP_DISPOSITION_NOMEM == + sctp_sf_heartbeat(ep, asoc, type, arg, + commands)) + return SCTP_DISPOSITION_NOMEM; + + /* Set transport error counter and association error counter + * when sending heartbeat. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TRANSPORT_HB_SENT, + SCTP_TRANSPORT(transport)); + } + sctp_add_cmd_sf(commands, SCTP_CMD_TRANSPORT_IDLE, + SCTP_TRANSPORT(transport)); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMER_UPDATE, + SCTP_TRANSPORT(transport)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* resend asoc strreset_chunk. */ +enum sctp_disposition sctp_sf_send_reconf(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *transport = arg; + + if (asoc->overall_error_count >= asoc->max_retrans) { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + /* CMD_ASSOC_FAILED calls CMD_DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_DELETE_TCB; + } + + sctp_chunk_hold(asoc->strreset_chunk); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(asoc->strreset_chunk)); + sctp_add_cmd_sf(commands, SCTP_CMD_STRIKE, SCTP_TRANSPORT(transport)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* send hb chunk with padding for PLPMUTD. */ +enum sctp_disposition sctp_sf_send_probe(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *transport = (struct sctp_transport *)arg; + struct sctp_chunk *reply; + + if (!sctp_transport_pl_enabled(transport)) + return SCTP_DISPOSITION_CONSUME; + + sctp_transport_pl_send(transport); + reply = sctp_make_heartbeat(asoc, transport, transport->pl.probe_size); + if (!reply) + return SCTP_DISPOSITION_NOMEM; + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + sctp_add_cmd_sf(commands, SCTP_CMD_PROBE_TIMER_UPDATE, + SCTP_TRANSPORT(transport)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Process an heartbeat request. + * + * Section: 8.3 Path Heartbeat + * The receiver of the HEARTBEAT should immediately respond with a + * HEARTBEAT ACK that contains the Heartbeat Information field copied + * from the received HEARTBEAT chunk. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * When receiving an SCTP packet, the endpoint MUST ensure that the + * value in the Verification Tag field of the received SCTP packet + * matches its own Tag. If the received Verification Tag value does not + * match the receiver's own tag value, the receiver shall silently + * discard the packet and shall not process it any further except for + * those cases listed in Section 8.5.1 below. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. 
+ */ +enum sctp_disposition sctp_sf_beat_8_3(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + struct sctp_paramhdr *param_hdr; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *reply; + size_t paylen = 0; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the HEARTBEAT chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, + sizeof(struct sctp_heartbeat_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* 8.3 The receiver of the HEARTBEAT should immediately + * respond with a HEARTBEAT ACK that contains the Heartbeat + * Information field copied from the received HEARTBEAT chunk. + */ + chunk->subh.hb_hdr = (struct sctp_heartbeathdr *)chunk->skb->data; + param_hdr = (struct sctp_paramhdr *)chunk->subh.hb_hdr; + paylen = ntohs(chunk->chunk_hdr->length) - sizeof(struct sctp_chunkhdr); + + if (ntohs(param_hdr->length) > paylen) + return sctp_sf_violation_paramlen(net, ep, asoc, type, arg, + param_hdr, commands); + + if (!pskb_pull(chunk->skb, paylen)) + goto nomem; + + reply = sctp_make_heartbeat_ack(asoc, chunk, param_hdr, paylen); + if (!reply) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Process the returning HEARTBEAT ACK. + * + * Section: 8.3 Path Heartbeat + * Upon the receipt of the HEARTBEAT ACK, the sender of the HEARTBEAT + * should clear the error counter of the destination transport + * address to which the HEARTBEAT was sent, and mark the destination + * transport address as active if it is not so marked. The endpoint may + * optionally report to the upper layer when an inactive destination + * address is marked as active due to the reception of the latest + * HEARTBEAT ACK. The receiver of the HEARTBEAT ACK must also + * clear the association overall error count as well (as defined + * in section 8.1). + * + * The receiver of the HEARTBEAT ACK should also perform an RTT + * measurement for that destination transport address using the time + * value carried in the HEARTBEAT ACK chunk. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_backbeat_8_3(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_sender_hb_info *hbinfo; + struct sctp_chunk *chunk = arg; + struct sctp_transport *link; + unsigned long max_interval; + union sctp_addr from_addr; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the HEARTBEAT-ACK chunk has a valid length. 
*/ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr) + + sizeof(*hbinfo))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + hbinfo = (struct sctp_sender_hb_info *)chunk->skb->data; + /* Make sure that the length of the parameter is what we expect */ + if (ntohs(hbinfo->param_hdr.length) != sizeof(*hbinfo)) + return SCTP_DISPOSITION_DISCARD; + + from_addr = hbinfo->daddr; + link = sctp_assoc_lookup_paddr(asoc, &from_addr); + + /* This should never happen, but lets log it if so. */ + if (unlikely(!link)) { + if (from_addr.sa.sa_family == AF_INET6) { + net_warn_ratelimited("%s association %p could not find address %pI6\n", + __func__, + asoc, + &from_addr.v6.sin6_addr); + } else { + net_warn_ratelimited("%s association %p could not find address %pI4\n", + __func__, + asoc, + &from_addr.v4.sin_addr.s_addr); + } + return SCTP_DISPOSITION_DISCARD; + } + + /* Validate the 64-bit random nonce. */ + if (hbinfo->hb_nonce != link->hb_nonce) + return SCTP_DISPOSITION_DISCARD; + + if (hbinfo->probe_size) { + if (hbinfo->probe_size != link->pl.probe_size || + !sctp_transport_pl_enabled(link)) + return SCTP_DISPOSITION_DISCARD; + + if (sctp_transport_pl_recv(link)) + return SCTP_DISPOSITION_CONSUME; + + return sctp_sf_send_probe(net, ep, asoc, type, link, commands); + } + + max_interval = link->hbinterval + link->rto; + + /* Check if the timestamp looks valid. */ + if (time_after(hbinfo->sent_at, jiffies) || + time_after(jiffies, hbinfo->sent_at + max_interval)) { + pr_debug("%s: HEARTBEAT ACK with invalid timestamp received " + "for transport:%p\n", __func__, link); + + return SCTP_DISPOSITION_DISCARD; + } + + /* 8.3 Upon the receipt of the HEARTBEAT ACK, the sender of + * the HEARTBEAT should clear the error counter of the + * destination transport address to which the HEARTBEAT was + * sent and mark the destination transport address as active if + * it is not so marked. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TRANSPORT_ON, SCTP_TRANSPORT(link)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* Helper function to send out an abort for the restart + * condition. + */ +static int sctp_sf_send_restart_abort(struct net *net, union sctp_addr *ssa, + struct sctp_chunk *init, + struct sctp_cmd_seq *commands) +{ + struct sctp_af *af = sctp_get_af_specific(ssa->v4.sin_family); + union sctp_addr_param *addrparm; + struct sctp_errhdr *errhdr; + char buffer[sizeof(*errhdr) + sizeof(*addrparm)]; + struct sctp_endpoint *ep; + struct sctp_packet *pkt; + int len; + + /* Build the error on the stack. We are way to malloc crazy + * throughout the code today. + */ + errhdr = (struct sctp_errhdr *)buffer; + addrparm = (union sctp_addr_param *)errhdr->variable; + + /* Copy into a parm format. */ + len = af->to_addr_param(ssa, addrparm); + len += sizeof(*errhdr); + + errhdr->cause = SCTP_ERROR_RESTART; + errhdr->length = htons(len); + + /* Assign to the control socket. */ + ep = sctp_sk(net->sctp.ctl_sock)->ep; + + /* Association is NULL since this may be a restart attack and we + * want to send back the attacker's vtag. + */ + pkt = sctp_abort_pkt_new(net, ep, NULL, init, errhdr, len); + + if (!pkt) + goto out; + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, SCTP_PACKET(pkt)); + + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + /* Discard the rest of the inbound packet. */ + sctp_add_cmd_sf(commands, SCTP_CMD_DISCARD_PACKET, SCTP_NULL()); + +out: + /* Even if there is no memory, treat as a failure so + * the packet will get dropped. 
+ */ + return 0; +} + +static bool list_has_sctp_addr(const struct list_head *list, + union sctp_addr *ipaddr) +{ + struct sctp_transport *addr; + + list_for_each_entry(addr, list, transports) { + if (sctp_cmp_addr_exact(ipaddr, &addr->ipaddr)) + return true; + } + + return false; +} +/* A restart is occurring, check to make sure no new addresses + * are being added as we may be under a takeover attack. + */ +static int sctp_sf_check_restart_addrs(const struct sctp_association *new_asoc, + const struct sctp_association *asoc, + struct sctp_chunk *init, + struct sctp_cmd_seq *commands) +{ + struct net *net = new_asoc->base.net; + struct sctp_transport *new_addr; + int ret = 1; + + /* Implementor's Guide - Section 5.2.2 + * ... + * Before responding the endpoint MUST check to see if the + * unexpected INIT adds new addresses to the association. If new + * addresses are added to the association, the endpoint MUST respond + * with an ABORT.. + */ + + /* Search through all current addresses and make sure + * we aren't adding any new ones. + */ + list_for_each_entry(new_addr, &new_asoc->peer.transport_addr_list, + transports) { + if (!list_has_sctp_addr(&asoc->peer.transport_addr_list, + &new_addr->ipaddr)) { + sctp_sf_send_restart_abort(net, &new_addr->ipaddr, init, + commands); + ret = 0; + break; + } + } + + /* Return success if all addresses were found. */ + return ret; +} + +/* Populate the verification/tie tags based on overlapping INIT + * scenario. + * + * Note: Do not use in CLOSED or SHUTDOWN-ACK-SENT state. + */ +static void sctp_tietags_populate(struct sctp_association *new_asoc, + const struct sctp_association *asoc) +{ + switch (asoc->state) { + + /* 5.2.1 INIT received in COOKIE-WAIT or COOKIE-ECHOED State */ + + case SCTP_STATE_COOKIE_WAIT: + new_asoc->c.my_vtag = asoc->c.my_vtag; + new_asoc->c.my_ttag = asoc->c.my_vtag; + new_asoc->c.peer_ttag = 0; + break; + + case SCTP_STATE_COOKIE_ECHOED: + new_asoc->c.my_vtag = asoc->c.my_vtag; + new_asoc->c.my_ttag = asoc->c.my_vtag; + new_asoc->c.peer_ttag = asoc->c.peer_vtag; + break; + + /* 5.2.2 Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, + * COOKIE-WAIT and SHUTDOWN-ACK-SENT + */ + default: + new_asoc->c.my_ttag = asoc->c.my_vtag; + new_asoc->c.peer_ttag = asoc->c.peer_vtag; + break; + } + + /* Other parameters for the endpoint SHOULD be copied from the + * existing parameters of the association (e.g. number of + * outbound streams) into the INIT ACK and cookie. + */ + new_asoc->rwnd = asoc->rwnd; + new_asoc->c.sinit_num_ostreams = asoc->c.sinit_num_ostreams; + new_asoc->c.sinit_max_instreams = asoc->c.sinit_max_instreams; + new_asoc->c.initial_tsn = asoc->c.initial_tsn; +} + +/* + * Compare vtag/tietag values to determine unexpected COOKIE-ECHO + * handling action. + * + * RFC 2960 5.2.4 Handle a COOKIE ECHO when a TCB exists. + * + * Returns value representing action to be taken. These action values + * correspond to Action/Description values in RFC 2960, Table 2. + */ +static char sctp_tietags_compare(struct sctp_association *new_asoc, + const struct sctp_association *asoc) +{ + /* In this case, the peer may have restarted. */ + if ((asoc->c.my_vtag != new_asoc->c.my_vtag) && + (asoc->c.peer_vtag != new_asoc->c.peer_vtag) && + (asoc->c.my_vtag == new_asoc->c.my_ttag) && + (asoc->c.peer_vtag == new_asoc->c.peer_ttag)) + return 'A'; + + /* Collision case B. 
*/ + if ((asoc->c.my_vtag == new_asoc->c.my_vtag) && + ((asoc->c.peer_vtag != new_asoc->c.peer_vtag) || + (0 == asoc->c.peer_vtag))) { + return 'B'; + } + + /* Collision case D. */ + if ((asoc->c.my_vtag == new_asoc->c.my_vtag) && + (asoc->c.peer_vtag == new_asoc->c.peer_vtag)) + return 'D'; + + /* Collision case C. */ + if ((asoc->c.my_vtag != new_asoc->c.my_vtag) && + (asoc->c.peer_vtag == new_asoc->c.peer_vtag) && + (0 == new_asoc->c.my_ttag) && + (0 == new_asoc->c.peer_ttag)) + return 'C'; + + /* No match to any of the special cases; discard this packet. */ + return 'E'; +} + +/* Common helper routine for both duplicate and simultaneous INIT + * chunk handling. + */ +static enum sctp_disposition sctp_sf_do_unexpected_init( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg, *repl, *err_chunk; + struct sctp_unrecognized_param *unk_param; + struct sctp_association *new_asoc; + enum sctp_disposition retval; + struct sctp_packet *packet; + int len; + + /* 6.10 Bundling + * An endpoint MUST NOT bundle INIT, INIT ACK or + * SHUTDOWN COMPLETE with any other chunks. + * + * IG Section 2.11.2 + * Furthermore, we require that the receiver of an INIT chunk MUST + * enforce these rules by silently discarding an arriving packet + * with an INIT chunk that is bundled with other chunks. + */ + if (!chunk->singleton) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the INIT chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_init_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* 3.1 A packet containing an INIT chunk MUST have a zero Verification + * Tag. + */ + if (chunk->sctp_hdr->vtag != 0) + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + + if (SCTP_INPUT_CB(chunk->skb)->encap_port != chunk->transport->encap_port) + return sctp_sf_new_encap_port(net, ep, asoc, type, arg, commands); + + /* Grab the INIT header. */ + chunk->subh.init_hdr = (struct sctp_inithdr *)chunk->skb->data; + + /* Tag the variable length parameters. */ + chunk->param_hdr.v = skb_pull(chunk->skb, sizeof(struct sctp_inithdr)); + + /* Verify the INIT chunk before processing it. */ + err_chunk = NULL; + if (!sctp_verify_init(net, ep, asoc, chunk->chunk_hdr->type, + (struct sctp_init_chunk *)chunk->chunk_hdr, chunk, + &err_chunk)) { + /* This chunk contains fatal error. It is to be discarded. + * Send an ABORT, with causes if there is any. + */ + if (err_chunk) { + packet = sctp_abort_pkt_new(net, ep, asoc, arg, + (__u8 *)(err_chunk->chunk_hdr) + + sizeof(struct sctp_chunkhdr), + ntohs(err_chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr)); + + if (packet) { + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + retval = SCTP_DISPOSITION_CONSUME; + } else { + retval = SCTP_DISPOSITION_NOMEM; + } + goto cleanup; + } else { + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, + commands); + } + } + + /* + * Other parameters for the endpoint SHOULD be copied from the + * existing parameters of the association (e.g. number of + * outbound streams) into the INIT ACK and cookie. + * FIXME: We are copying parameters from the endpoint not the + * association. 
+ */ + new_asoc = sctp_make_temp_asoc(ep, chunk, GFP_ATOMIC); + if (!new_asoc) + goto nomem; + + /* Update socket peer label if first association. */ + if (security_sctp_assoc_request(new_asoc, chunk->skb)) { + sctp_association_free(new_asoc); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (sctp_assoc_set_bind_addr_from_ep(new_asoc, + sctp_scope(sctp_source(chunk)), GFP_ATOMIC) < 0) + goto nomem; + + /* In the outbound INIT ACK the endpoint MUST copy its current + * Verification Tag and Peers Verification tag into a reserved + * place (local tie-tag and per tie-tag) within the state cookie. + */ + if (!sctp_process_init(new_asoc, chunk, sctp_source(chunk), + (struct sctp_init_chunk *)chunk->chunk_hdr, + GFP_ATOMIC)) + goto nomem; + + /* Make sure no new addresses are being added during the + * restart. Do not do this check for COOKIE-WAIT state, + * since there are no peer addresses to check against. + * Upon return an ABORT will have been sent if needed. + */ + if (!sctp_state(asoc, COOKIE_WAIT)) { + if (!sctp_sf_check_restart_addrs(new_asoc, asoc, chunk, + commands)) { + retval = SCTP_DISPOSITION_CONSUME; + goto nomem_retval; + } + } + + sctp_tietags_populate(new_asoc, asoc); + + /* B) "Z" shall respond immediately with an INIT ACK chunk. */ + + /* If there are errors need to be reported for unknown parameters, + * make sure to reserve enough room in the INIT ACK for them. + */ + len = 0; + if (err_chunk) { + len = ntohs(err_chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr); + } + + repl = sctp_make_init_ack(new_asoc, chunk, GFP_ATOMIC, len); + if (!repl) + goto nomem; + + /* If there are errors need to be reported for unknown parameters, + * include them in the outgoing INIT ACK as "Unrecognized parameter" + * parameter. + */ + if (err_chunk) { + /* Get the "Unrecognized parameter" parameter(s) out of the + * ERROR chunk generated by sctp_verify_init(). Since the + * error cause code for "unknown parameter" and the + * "Unrecognized parameter" type is the same, we can + * construct the parameters in INIT ACK by copying the + * ERROR causes over. + */ + unk_param = (struct sctp_unrecognized_param *) + ((__u8 *)(err_chunk->chunk_hdr) + + sizeof(struct sctp_chunkhdr)); + /* Replace the cause code with the "Unrecognized parameter" + * parameter type. + */ + sctp_addto_chunk(repl, len, unk_param); + } + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_ASOC, SCTP_ASOC(new_asoc)); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + + /* + * Note: After sending out INIT ACK with the State Cookie parameter, + * "Z" MUST NOT allocate any resources for this new association. + * Otherwise, "Z" will be vulnerable to resource attacks. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + retval = SCTP_DISPOSITION_CONSUME; + + return retval; + +nomem: + retval = SCTP_DISPOSITION_NOMEM; +nomem_retval: + if (new_asoc) + sctp_association_free(new_asoc); +cleanup: + if (err_chunk) + sctp_chunk_free(err_chunk); + return retval; +} + +/* + * Handle simultaneous INIT. + * This means we started an INIT and then we got an INIT request from + * our peer. + * + * Section: 5.2.1 INIT received in COOKIE-WAIT or COOKIE-ECHOED State (Item B) + * This usually indicates an initialization collision, i.e., each + * endpoint is attempting, at about the same time, to establish an + * association with the other endpoint. 
+ * + * Upon receipt of an INIT in the COOKIE-WAIT or COOKIE-ECHOED state, an + * endpoint MUST respond with an INIT ACK using the same parameters it + * sent in its original INIT chunk (including its Verification Tag, + * unchanged). These original parameters are combined with those from the + * newly received INIT chunk. The endpoint shall also generate a State + * Cookie with the INIT ACK. The endpoint uses the parameters sent in its + * INIT to calculate the State Cookie. + * + * After that, the endpoint MUST NOT change its state, the T1-init + * timer shall be left running and the corresponding TCB MUST NOT be + * destroyed. The normal procedures for handling State Cookies when + * a TCB exists will resolve the duplicate INITs to a single association. + * + * For an endpoint that is in the COOKIE-ECHOED state it MUST populate + * its Tie-Tags with the Tag information of itself and its peer (see + * section 5.2.2 for a description of the Tie-Tags). + * + * Verification Tag: Not explicit, but an INIT can not have a valid + * verification tag, so we skip the check. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_2_1_siminit( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* Call helper to do the real work for both simultaneous and + * duplicate INIT chunk handling. + */ + return sctp_sf_do_unexpected_init(net, ep, asoc, type, arg, commands); +} + +/* + * Handle duplicated INIT messages. These are usually delayed + * restransmissions. + * + * Section: 5.2.2 Unexpected INIT in States Other than CLOSED, + * COOKIE-ECHOED and COOKIE-WAIT + * + * Unless otherwise stated, upon reception of an unexpected INIT for + * this association, the endpoint shall generate an INIT ACK with a + * State Cookie. In the outbound INIT ACK the endpoint MUST copy its + * current Verification Tag and peer's Verification Tag into a reserved + * place within the state cookie. We shall refer to these locations as + * the Peer's-Tie-Tag and the Local-Tie-Tag. The outbound SCTP packet + * containing this INIT ACK MUST carry a Verification Tag value equal to + * the Initiation Tag found in the unexpected INIT. And the INIT ACK + * MUST contain a new Initiation Tag (randomly generated see Section + * 5.3.1). Other parameters for the endpoint SHOULD be copied from the + * existing parameters of the association (e.g. number of outbound + * streams) into the INIT ACK and cookie. + * + * After sending out the INIT ACK, the endpoint shall take no further + * actions, i.e., the existing association, including its current state, + * and the corresponding TCB MUST NOT be changed. + * + * Note: Only when a TCB exists and the association is not in a COOKIE- + * WAIT state are the Tie-Tags populated. For a normal association INIT + * (i.e. the endpoint is in a COOKIE-WAIT state), the Tie-Tags MUST be + * set to 0 (indicating that no previous TCB existed). The INIT ACK and + * State Cookie are populated as specified in section 5.2.1. + * + * Verification Tag: Not specified, but an INIT has no way of knowing + * what the verification tag could be, so we ignore it. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. 
+ */ +enum sctp_disposition sctp_sf_do_5_2_2_dupinit( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* Call helper to do the real work for both simultaneous and + * duplicate INIT chunk handling. + */ + return sctp_sf_do_unexpected_init(net, ep, asoc, type, arg, commands); +} + + +/* + * Unexpected INIT-ACK handler. + * + * Section 5.2.3 + * If an INIT ACK received by an endpoint in any state other than the + * COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. + * An unexpected INIT ACK usually indicates the processing of an old or + * duplicated INIT chunk. +*/ +enum sctp_disposition sctp_sf_do_5_2_3_initack( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* Per the above section, we'll discard the chunk if we have an + * endpoint. If this is an OOTB INIT-ACK, treat it as such. + */ + if (ep == sctp_sk(net->sctp.ctl_sock)->ep) + return sctp_sf_ootb(net, ep, asoc, type, arg, commands); + else + return sctp_sf_discard_chunk(net, ep, asoc, type, arg, commands); +} + +static int sctp_sf_do_assoc_update(struct sctp_association *asoc, + struct sctp_association *new, + struct sctp_cmd_seq *cmds) +{ + struct net *net = asoc->base.net; + struct sctp_chunk *abort; + + if (!sctp_assoc_update(asoc, new)) + return 0; + + abort = sctp_make_abort(asoc, NULL, sizeof(struct sctp_errhdr)); + if (abort) { + sctp_init_cause(abort, SCTP_ERROR_RSRC_LOW, 0); + sctp_add_cmd_sf(cmds, SCTP_CMD_REPLY, SCTP_CHUNK(abort)); + } + sctp_add_cmd_sf(cmds, SCTP_CMD_SET_SK_ERR, SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(cmds, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_RSRC_LOW)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + + return -ENOMEM; +} + +/* Unexpected COOKIE-ECHO handler for peer restart (Table 2, action 'A') + * + * Section 5.2.4 + * A) In this case, the peer may have restarted. + */ +static enum sctp_disposition sctp_sf_do_dupcook_a( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_association *new_asoc) +{ + struct sctp_init_chunk *peer_init; + enum sctp_disposition disposition; + struct sctp_ulpevent *ev; + struct sctp_chunk *repl; + struct sctp_chunk *err; + + /* new_asoc is a brand-new association, so these are not yet + * side effects--it is safe to run them here. + */ + peer_init = &chunk->subh.cookie_hdr->c.peer_init[0]; + + if (!sctp_process_init(new_asoc, chunk, sctp_source(chunk), peer_init, + GFP_ATOMIC)) + goto nomem; + + if (sctp_auth_asoc_init_active_key(new_asoc, GFP_ATOMIC)) + goto nomem; + + if (!sctp_auth_chunk_verify(net, chunk, new_asoc)) + return SCTP_DISPOSITION_DISCARD; + + /* Make sure no new addresses are being added during the + * restart. Though this is a pretty complicated attack + * since you'd have to get inside the cookie. + */ + if (!sctp_sf_check_restart_addrs(new_asoc, asoc, chunk, commands)) + return SCTP_DISPOSITION_CONSUME; + + /* If the endpoint is in the SHUTDOWN-ACK-SENT state and recognizes + * the peer has restarted (Action A), it MUST NOT setup a new + * association but instead resend the SHUTDOWN ACK and send an ERROR + * chunk with a "Cookie Received while Shutting Down" error cause to + * its peer. 
+ */ + if (sctp_state(asoc, SHUTDOWN_ACK_SENT)) { + disposition = __sctp_sf_do_9_2_reshutack(net, ep, asoc, + SCTP_ST_CHUNK(chunk->chunk_hdr->type), + chunk, commands); + if (SCTP_DISPOSITION_NOMEM == disposition) + goto nomem; + + err = sctp_make_op_error(asoc, chunk, + SCTP_ERROR_COOKIE_IN_SHUTDOWN, + NULL, 0, 0); + if (err) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(err)); + + return SCTP_DISPOSITION_CONSUME; + } + + /* For now, stop pending T3-rtx and SACK timers, fail any unsent/unacked + * data. Consider the optional choice of resending of this data. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_T3_RTX_TIMERS_STOP, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_SACK)); + sctp_add_cmd_sf(commands, SCTP_CMD_PURGE_OUTQUEUE, SCTP_NULL()); + + /* Stop pending T4-rto timer, teardown ASCONF queue, ASCONF-ACK queue + * and ASCONF-ACK cache. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + sctp_add_cmd_sf(commands, SCTP_CMD_PURGE_ASCONF_QUEUE, SCTP_NULL()); + + /* Update the content of current association. */ + if (sctp_sf_do_assoc_update((struct sctp_association *)asoc, new_asoc, commands)) + goto nomem; + + repl = sctp_make_cookie_ack(asoc, chunk); + if (!repl) + goto nomem; + + /* Report association restart to upper layer. */ + ev = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_RESTART, 0, + asoc->c.sinit_num_ostreams, + asoc->c.sinit_max_instreams, + NULL, GFP_ATOMIC); + if (!ev) + goto nomem_ev; + + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(ev)); + if ((sctp_state(asoc, SHUTDOWN_PENDING) || + sctp_state(asoc, SHUTDOWN_SENT)) && + (sctp_sstate(asoc->base.sk, CLOSING) || + sock_flag(asoc->base.sk, SOCK_DEAD))) { + /* If the socket has been closed by user, don't + * transition to ESTABLISHED. Instead trigger SHUTDOWN + * bundled with COOKIE_ACK. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + return sctp_sf_do_9_2_start_shutdown(net, ep, asoc, + SCTP_ST_CHUNK(0), repl, + commands); + } else { + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_ESTABLISHED)); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + } + return SCTP_DISPOSITION_CONSUME; + +nomem_ev: + sctp_chunk_free(repl); +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* Unexpected COOKIE-ECHO handler for setup collision (Table 2, action 'B') + * + * Section 5.2.4 + * B) In this case, both sides may be attempting to start an association + * at about the same time but the peer endpoint started its INIT + * after responding to the local endpoint's INIT + */ +/* This case represents an initialization collision. */ +static enum sctp_disposition sctp_sf_do_dupcook_b( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_association *new_asoc) +{ + struct sctp_init_chunk *peer_init; + struct sctp_chunk *repl; + + /* new_asoc is a brand-new association, so these are not yet + * side effects--it is safe to run them here. 
+ */ + peer_init = &chunk->subh.cookie_hdr->c.peer_init[0]; + if (!sctp_process_init(new_asoc, chunk, sctp_source(chunk), peer_init, + GFP_ATOMIC)) + goto nomem; + + if (sctp_auth_asoc_init_active_key(new_asoc, GFP_ATOMIC)) + goto nomem; + + if (!sctp_auth_chunk_verify(net, chunk, new_asoc)) + return SCTP_DISPOSITION_DISCARD; + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_ESTABLISHED)); + if (asoc->state < SCTP_STATE_ESTABLISHED) + SCTP_INC_STATS(net, SCTP_MIB_CURRESTAB); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_START, SCTP_NULL()); + + /* Update the content of current association. */ + if (sctp_sf_do_assoc_update((struct sctp_association *)asoc, new_asoc, commands)) + goto nomem; + + repl = sctp_make_cookie_ack(asoc, chunk); + if (!repl) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * D) IMPLEMENTATION NOTE: An implementation may choose to + * send the Communication Up notification to the SCTP user + * upon reception of a valid COOKIE ECHO chunk. + * + * Sadly, this needs to be implemented as a side-effect, because + * we are not guaranteed to have set the association id of the real + * association and so these notifications need to be delayed until + * the association id is allocated. + */ + + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_CHANGE, SCTP_U8(SCTP_COMM_UP)); + + /* Sockets API Draft Section 5.3.1.6 + * When a peer sends a Adaptation Layer Indication parameter , SCTP + * delivers this notification to inform the application that of the + * peers requested adaptation layer. + * + * This also needs to be done as a side effect for the same reason as + * above. + */ + if (asoc->peer.adaptation_ind) + sctp_add_cmd_sf(commands, SCTP_CMD_ADAPTATION_IND, SCTP_NULL()); + + if (!asoc->peer.auth_capable) + sctp_add_cmd_sf(commands, SCTP_CMD_PEER_NO_AUTH, SCTP_NULL()); + + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* Unexpected COOKIE-ECHO handler for setup collision (Table 2, action 'C') + * + * Section 5.2.4 + * C) In this case, the local endpoint's cookie has arrived late. + * Before it arrived, the local endpoint sent an INIT and received an + * INIT-ACK and finally sent a COOKIE ECHO with the peer's same tag + * but a new tag of its own. + */ +/* This case represents an initialization collision. */ +static enum sctp_disposition sctp_sf_do_dupcook_c( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_association *new_asoc) +{ + /* The cookie should be silently discarded. + * The endpoint SHOULD NOT change states and should leave + * any timers running. + */ + return SCTP_DISPOSITION_DISCARD; +} + +/* Unexpected COOKIE-ECHO handler lost chunk (Table 2, action 'D') + * + * Section 5.2.4 + * + * D) When both local and remote tags match the endpoint should always + * enter the ESTABLISHED state, if it has not already done so. + */ +/* This case represents an initialization collision. 
*/ +static enum sctp_disposition sctp_sf_do_dupcook_d( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_association *new_asoc) +{ + struct sctp_ulpevent *ev = NULL, *ai_ev = NULL, *auth_ev = NULL; + struct sctp_chunk *repl; + + /* Clarification from Implementor's Guide: + * D) When both local and remote tags match the endpoint should + * enter the ESTABLISHED state, if it is in the COOKIE-ECHOED state. + * It should stop any cookie timer that may be running and send + * a COOKIE ACK. + */ + + if (!sctp_auth_chunk_verify(net, chunk, asoc)) + return SCTP_DISPOSITION_DISCARD; + + /* Don't accidentally move back into established state. */ + if (asoc->state < SCTP_STATE_ESTABLISHED) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_ESTABLISHED)); + SCTP_INC_STATS(net, SCTP_MIB_CURRESTAB); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_START, + SCTP_NULL()); + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * D) IMPLEMENTATION NOTE: An implementation may choose + * to send the Communication Up notification to the + * SCTP user upon reception of a valid COOKIE + * ECHO chunk. + */ + ev = sctp_ulpevent_make_assoc_change(asoc, 0, + SCTP_COMM_UP, 0, + asoc->c.sinit_num_ostreams, + asoc->c.sinit_max_instreams, + NULL, GFP_ATOMIC); + if (!ev) + goto nomem; + + /* Sockets API Draft Section 5.3.1.6 + * When a peer sends a Adaptation Layer Indication parameter, + * SCTP delivers this notification to inform the application + * that of the peers requested adaptation layer. + */ + if (asoc->peer.adaptation_ind) { + ai_ev = sctp_ulpevent_make_adaptation_indication(asoc, + GFP_ATOMIC); + if (!ai_ev) + goto nomem; + + } + + if (!asoc->peer.auth_capable) { + auth_ev = sctp_ulpevent_make_authkey(asoc, 0, + SCTP_AUTH_NO_AUTH, + GFP_ATOMIC); + if (!auth_ev) + goto nomem; + } + } + + repl = sctp_make_cookie_ack(asoc, chunk); + if (!repl) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + + if (ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + if (ai_ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ai_ev)); + if (auth_ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(auth_ev)); + + return SCTP_DISPOSITION_CONSUME; + +nomem: + if (auth_ev) + sctp_ulpevent_free(auth_ev); + if (ai_ev) + sctp_ulpevent_free(ai_ev); + if (ev) + sctp_ulpevent_free(ev); + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Handle a duplicate COOKIE-ECHO. This usually means a cookie-carrying + * chunk was retransmitted and then delayed in the network. + * + * Section: 5.2.4 Handle a COOKIE ECHO when a TCB exists + * + * Verification Tag: None. Do cookie validation. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_5_2_4_dupcook( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_association *new_asoc; + struct sctp_chunk *chunk = arg; + enum sctp_disposition retval; + struct sctp_chunk *err_chk_p; + int error = 0; + char action; + + /* Make sure that the chunk has a valid length from the protocol + * perspective. 
In this case check to make sure we have at least + * enough for the chunk header. Cookie length verification is + * done later. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) { + if (!sctp_vtag_verify(chunk, asoc)) + asoc = NULL; + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, commands); + } + + /* "Decode" the chunk. We have no optional parameters so we + * are in good shape. + */ + chunk->subh.cookie_hdr = (struct sctp_signed_cookie *)chunk->skb->data; + if (!pskb_pull(chunk->skb, ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr))) + goto nomem; + + /* In RFC 2960 5.2.4 3, if both Verification Tags in the State Cookie + * of a duplicate COOKIE ECHO match the Verification Tags of the + * current association, consider the State Cookie valid even if + * the lifespan is exceeded. + */ + new_asoc = sctp_unpack_cookie(ep, asoc, chunk, GFP_ATOMIC, &error, + &err_chk_p); + + /* FIXME: + * If the re-build failed, what is the proper error path + * from here? + * + * [We should abort the association. --piggy] + */ + if (!new_asoc) { + /* FIXME: Several errors are possible. A bad cookie should + * be silently discarded, but think about logging it too. + */ + switch (error) { + case -SCTP_IERROR_NOMEM: + goto nomem; + + case -SCTP_IERROR_STALE_COOKIE: + sctp_send_stale_cookie_err(net, ep, asoc, chunk, commands, + err_chk_p); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + case -SCTP_IERROR_BAD_SIG: + default: + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + } + + /* Update socket peer label if first association. */ + if (security_sctp_assoc_request(new_asoc, chunk->head_skb ?: chunk->skb)) { + sctp_association_free(new_asoc); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Set temp so that it won't be added into hashtable */ + new_asoc->temp = 1; + + /* Compare the tie_tag in cookie with the verification tag of + * current association. + */ + action = sctp_tietags_compare(new_asoc, asoc); + + switch (action) { + case 'A': /* Association restart. */ + retval = sctp_sf_do_dupcook_a(net, ep, asoc, chunk, commands, + new_asoc); + break; + + case 'B': /* Collision case B. */ + retval = sctp_sf_do_dupcook_b(net, ep, asoc, chunk, commands, + new_asoc); + break; + + case 'C': /* Collision case C. */ + retval = sctp_sf_do_dupcook_c(net, ep, asoc, chunk, commands, + new_asoc); + break; + + case 'D': /* Collision case D. */ + retval = sctp_sf_do_dupcook_d(net, ep, asoc, chunk, commands, + new_asoc); + break; + + default: /* Discard packet for all others. */ + retval = sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + break; + } + + /* Delete the temporary new association. */ + sctp_add_cmd_sf(commands, SCTP_CMD_SET_ASOC, SCTP_ASOC(new_asoc)); + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + + /* Restore association pointer to provide SCTP command interpreter + * with a valid context in case it needs to manipulate + * the queues */ + sctp_add_cmd_sf(commands, SCTP_CMD_SET_ASOC, + SCTP_ASOC((struct sctp_association *)asoc)); + + return retval; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Process an ABORT. (SHUTDOWN-PENDING state) + * + * See sctp_sf_do_9_1_abort(). 
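The action character ('A' through 'D') dispatched on in sctp_sf_do_5_2_4_dupcook() comes from comparing the verification tags and tie-tags carried in the unpacked cookie against the existing TCB (RFC 4960, Section 5.2.4, Table 2); in the tree that comparison lives in sctp_tietags_compare(). Below is a minimal userspace sketch of that decision only, with a simplified tcb_tags struct standing in for struct sctp_association; the exact predicates are an approximation of the table, not the kernel code.

#include <stdint.h>
#include <stdio.h>

/* Simplified view of the tags kept in a TCB / unpacked cookie. */
struct tcb_tags {
	uint32_t my_vtag;	/* local verification tag */
	uint32_t peer_vtag;	/* peer verification tag */
	uint32_t my_ttag;	/* local tie-tag (0 = absent) */
	uint32_t peer_ttag;	/* peer tie-tag (0 = absent) */
};

/* Return the Table 2 action for a duplicate COOKIE ECHO:
 * 'A' restart, 'B'/'C'/'D' collision cases, 'E' discard.
 */
static char tietags_compare(const struct tcb_tags *cur,
			    const struct tcb_tags *cookie)
{
	/* A: neither verification tag matches, but both tie-tags in the
	 * cookie match the current tags -> the peer has restarted.
	 */
	if (cur->my_vtag != cookie->my_vtag &&
	    cur->peer_vtag != cookie->peer_vtag &&
	    cur->my_vtag == cookie->my_ttag &&
	    cur->peer_vtag == cookie->peer_ttag)
		return 'A';

	/* B: our tag matches but the peer's does not, or we have not
	 * learned the peer's tag yet (init collision).
	 */
	if (cur->my_vtag == cookie->my_vtag &&
	    (cur->peer_vtag != cookie->peer_vtag || cur->peer_vtag == 0))
		return 'B';

	/* D: both tags match -> enter ESTABLISHED if not already there. */
	if (cur->my_vtag == cookie->my_vtag &&
	    cur->peer_vtag == cookie->peer_vtag)
		return 'D';

	/* C: only the peer's tag matches and no tie-tags are present. */
	if (cur->my_vtag != cookie->my_vtag &&
	    cur->peer_vtag == cookie->peer_vtag &&
	    cookie->my_ttag == 0 && cookie->peer_ttag == 0)
		return 'C';

	return 'E';
}

int main(void)
{
	struct tcb_tags cur = { 0x1111, 0x2222, 0, 0 };
	struct tcb_tags restart = { 0x3333, 0x4444, 0x1111, 0x2222 };

	printf("action: %c\n", tietags_compare(&cur, &restart));	/* A */
	return 0;
}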
+ */ +enum sctp_disposition sctp_sf_shutdown_pending_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!sctp_vtag_verify_either(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ABORT chunk has a valid length. + * Since this is an ABORT chunk, we have to discard it + * because of the following text: + * RFC 2960, Section 3.3.7 + * If an endpoint receives an ABORT with a format error or for an + * association that doesn't exist, it MUST silently discard it. + * Because the length is "invalid", we can't really discard just + * as we do not know its true length. So, to be safe, discard the + * packet. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_abort_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* ADD-IP: Special case for ABORT chunks + * F4) One special consideration is that ABORT Chunks arriving + * destined to the IP address being deleted MUST be + * ignored (see Section 5.3.1 for further details). + */ + if (SCTP_ADDR_DEL == + sctp_bind_addr_state(&asoc->base.bind_addr, &chunk->dest)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_err_chunk_valid(chunk)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + return __sctp_sf_do_9_1_abort(net, ep, asoc, type, arg, commands); +} + +/* + * Process an ABORT. (SHUTDOWN-SENT state) + * + * See sctp_sf_do_9_1_abort(). + */ +enum sctp_disposition sctp_sf_shutdown_sent_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!sctp_vtag_verify_either(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ABORT chunk has a valid length. + * Since this is an ABORT chunk, we have to discard it + * because of the following text: + * RFC 2960, Section 3.3.7 + * If an endpoint receives an ABORT with a format error or for an + * association that doesn't exist, it MUST silently discard it. + * Because the length is "invalid", we can't really discard just + * as we do not know its true length. So, to be safe, discard the + * packet. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_abort_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* ADD-IP: Special case for ABORT chunks + * F4) One special consideration is that ABORT Chunks arriving + * destined to the IP address being deleted MUST be + * ignored (see Section 5.3.1 for further details). + */ + if (SCTP_ADDR_DEL == + sctp_bind_addr_state(&asoc->base.bind_addr, &chunk->dest)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_err_chunk_valid(chunk)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Stop the T2-shutdown timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + /* Stop the T5-shutdown guard timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + return __sctp_sf_do_9_1_abort(net, ep, asoc, type, arg, commands); +} + +/* + * Process an ABORT. (SHUTDOWN-ACK-SENT state) + * + * See sctp_sf_do_9_1_abort(). 
+ */ +enum sctp_disposition sctp_sf_shutdown_ack_sent_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* The same T2 timer, so we should be able to use + * common function with the SHUTDOWN-SENT state. + */ + return sctp_sf_shutdown_sent_abort(net, ep, asoc, type, arg, commands); +} + +/* + * Handle an Error received in COOKIE_ECHOED state. + * + * Only handle the error type of stale COOKIE Error, the other errors will + * be ignored. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_cookie_echoed_err( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_errhdr *err; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ERROR chunk has a valid length. + * The parameter walking depends on this as well. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_operr_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Process the error here */ + /* FUTURE FIXME: When PR-SCTP related and other optional + * parms are emitted, this will have to change to handle multiple + * errors. + */ + sctp_walk_errors(err, chunk->chunk_hdr) { + if (SCTP_ERROR_STALE_COOKIE == err->cause) + return sctp_sf_do_5_2_6_stale(net, ep, asoc, type, + arg, commands); + } + + /* It is possible to have malformed error causes, and that + * will cause us to end the walk early. However, since + * we are discarding the packet, there should be no adverse + * affects. + */ + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); +} + +/* + * Handle a Stale COOKIE Error + * + * Section: 5.2.6 Handle Stale COOKIE Error + * If the association is in the COOKIE-ECHOED state, the endpoint may elect + * one of the following three alternatives. + * ... + * 3) Send a new INIT chunk to the endpoint, adding a Cookie + * Preservative parameter requesting an extension to the lifetime of + * the State Cookie. When calculating the time extension, an + * implementation SHOULD use the RTT information measured based on the + * previous COOKIE ECHO / ERROR exchange, and should add no more + * than 1 second beyond the measured RTT, due to long State Cookie + * lifetimes making the endpoint more subject to a replay attack. + * + * Verification Tag: Not explicit, but safe to ignore. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. 
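sctp_walk_errors() in sctp_sf_cookie_echoed_err() iterates over the variable-length error causes that follow the ERROR chunk header, looking for a Stale Cookie cause. The following is a standalone sketch of that kind of bounded TLV walk, using a hypothetical err_cause layout (16-bit cause code, 16-bit length including the header, value padded to 4 bytes) and the Stale Cookie cause code 3 from RFC 4960; it is not the kernel macro.

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define CAUSE_STALE_COOKIE 3	/* RFC 4960, Section 3.3.10.3 */

/* Error cause header, both fields in network byte order. */
struct err_cause {
	uint16_t cause;
	uint16_t length;
};

/* Walk the causes in an ERROR chunk body and report whether a
 * Stale Cookie cause is present.  Stops early on a malformed length.
 */
static int has_stale_cookie(const uint8_t *body, size_t body_len)
{
	size_t off = 0;

	while (off + sizeof(struct err_cause) <= body_len) {
		struct err_cause eh;
		size_t len, padded;

		memcpy(&eh, body + off, sizeof(eh));
		len = ntohs(eh.length);
		if (len < sizeof(eh))
			break;				/* malformed: stop the walk */
		padded = (len + 3) & ~(size_t)3;	/* causes are 4-byte aligned */
		if (off + padded > body_len)
			break;				/* truncated cause */

		if (ntohs(eh.cause) == CAUSE_STALE_COOKIE)
			return 1;
		off += padded;
	}
	return 0;
}

int main(void)
{
	/* One 8-byte Stale Cookie cause: code 3, length 8, then a
	 * 4-byte Measure of Staleness (100000 usec).
	 */
	uint8_t body[8] = { 0x00, 0x03, 0x00, 0x08, 0x00, 0x01, 0x86, 0xa0 };

	printf("stale cookie present: %d\n",
	       has_stale_cookie(body, sizeof(body)));
	return 0;
}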
+ */ +static enum sctp_disposition sctp_sf_do_5_2_6_stale( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + int attempts = asoc->init_err_counter + 1; + struct sctp_chunk *chunk = arg, *reply; + struct sctp_cookie_preserve_param bht; + struct sctp_bind_addr *bp; + struct sctp_errhdr *err; + u32 stale; + + if (attempts > asoc->max_init_attempts) { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(SCTP_ERROR_STALE_COOKIE)); + return SCTP_DISPOSITION_DELETE_TCB; + } + + err = (struct sctp_errhdr *)(chunk->skb->data); + + /* When calculating the time extension, an implementation + * SHOULD use the RTT information measured based on the + * previous COOKIE ECHO / ERROR exchange, and should add no + * more than 1 second beyond the measured RTT, due to long + * State Cookie lifetimes making the endpoint more subject to + * a replay attack. + * Measure of Staleness's unit is usec. (1/1000000 sec) + * Suggested Cookie Life-span Increment's unit is msec. + * (1/1000 sec) + * In general, if you use the suggested cookie life, the value + * found in the field of measure of staleness should be doubled + * to give ample time to retransmit the new cookie and thus + * yield a higher probability of success on the reattempt. + */ + stale = ntohl(*(__be32 *)((u8 *)err + sizeof(*err))); + stale = (stale * 2) / 1000; + + bht.param_hdr.type = SCTP_PARAM_COOKIE_PRESERVATIVE; + bht.param_hdr.length = htons(sizeof(bht)); + bht.lifespan_increment = htonl(stale); + + /* Build that new INIT chunk. */ + bp = (struct sctp_bind_addr *) &asoc->base.bind_addr; + reply = sctp_make_init(asoc, bp, GFP_ATOMIC, sizeof(bht)); + if (!reply) + goto nomem; + + sctp_addto_chunk(reply, sizeof(bht), &bht); + + /* Clear peer's init_tag cached in assoc as we are sending a new INIT */ + sctp_add_cmd_sf(commands, SCTP_CMD_CLEAR_INIT_TAG, SCTP_NULL()); + + /* Stop pending T3-rtx and heartbeat timers */ + sctp_add_cmd_sf(commands, SCTP_CMD_T3_RTX_TIMERS_STOP, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_STOP, SCTP_NULL()); + + /* Delete non-primary peer ip addresses since we are transitioning + * back to the COOKIE-WAIT state + */ + sctp_add_cmd_sf(commands, SCTP_CMD_DEL_NON_PRIMARY, SCTP_NULL()); + + /* If we've sent any data bundled with COOKIE-ECHO we will need to + * resend + */ + sctp_add_cmd_sf(commands, SCTP_CMD_T1_RETRAN, + SCTP_TRANSPORT(asoc->peer.primary_path)); + + /* Cast away the const modifier, as we want to just + * rerun it through as a sideffect. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_COUNTER_INC, SCTP_NULL()); + + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_COOKIE_WAIT)); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Process an ABORT. + * + * Section: 9.1 + * After checking the Verification Tag, the receiving endpoint shall + * remove the association from its record, and shall report the + * termination to its upper layer. 
+ * + * Verification Tag: 8.5.1 Exceptions in Verification Tag Rules + * B) Rules for packet carrying ABORT: + * + * - The endpoint shall always fill in the Verification Tag field of the + * outbound packet with the destination endpoint's tag value if it + * is known. + * + * - If the ABORT is sent in response to an OOTB packet, the endpoint + * MUST follow the procedure described in Section 8.4. + * + * - The receiver MUST accept the packet if the Verification Tag + * matches either its own tag, OR the tag of its peer. Otherwise, the + * receiver MUST silently discard the packet and take no further + * action. + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_9_1_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!sctp_vtag_verify_either(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ABORT chunk has a valid length. + * Since this is an ABORT chunk, we have to discard it + * because of the following text: + * RFC 2960, Section 3.3.7 + * If an endpoint receives an ABORT with a format error or for an + * association that doesn't exist, it MUST silently discard it. + * Because the length is "invalid", we can't really discard just + * as we do not know its true length. So, to be safe, discard the + * packet. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_abort_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* ADD-IP: Special case for ABORT chunks + * F4) One special consideration is that ABORT Chunks arriving + * destined to the IP address being deleted MUST be + * ignored (see Section 5.3.1 for further details). + */ + if (SCTP_ADDR_DEL == + sctp_bind_addr_state(&asoc->base.bind_addr, &chunk->dest)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_err_chunk_valid(chunk)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + return __sctp_sf_do_9_1_abort(net, ep, asoc, type, arg, commands); +} + +static enum sctp_disposition __sctp_sf_do_9_1_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + __be16 error = SCTP_ERROR_NO_ERROR; + struct sctp_chunk *chunk = arg; + unsigned int len; + + /* See if we have an error cause code in the chunk. */ + len = ntohs(chunk->chunk_hdr->length); + if (len >= sizeof(struct sctp_chunkhdr) + sizeof(struct sctp_errhdr)) + error = ((struct sctp_errhdr *)chunk->skb->data)->cause; + + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, SCTP_ERROR(ECONNRESET)); + /* ASSOC_FAILED will DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, SCTP_PERR(error)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + + return SCTP_DISPOSITION_ABORT; +} + +/* + * Process an ABORT. (COOKIE-WAIT state) + * + * See sctp_sf_do_9_1_abort() above. 
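__sctp_sf_do_9_1_abort() only looks at the first error cause of the ABORT, and only when the chunk is long enough to carry one; otherwise it reports "no error" to the upper layer. A sketch of that length check on a raw buffer, assuming the 4-byte chunk and cause headers defined by RFC 4960; the helper name and layout are illustrative.

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define CHUNK_HDR_LEN 4U	/* type, flags, length */
#define CAUSE_HDR_LEN 4U	/* cause code, cause length */

/* Return the first error cause code carried by an ABORT chunk, or 0
 * ("no user-visible error") if the chunk carries no cause at all.
 * 'chunk' points at the chunk header, 'chunk_len' is the value from
 * its length field (already validated against the packet).
 */
static uint16_t abort_first_cause(const uint8_t *chunk, uint16_t chunk_len)
{
	uint16_t cause;

	if (chunk_len < CHUNK_HDR_LEN + CAUSE_HDR_LEN)
		return 0;

	memcpy(&cause, chunk + CHUNK_HDR_LEN, sizeof(cause));
	return ntohs(cause);
}

int main(void)
{
	/* ABORT (type 6) carrying one User-Initiated Abort cause (12). */
	uint8_t abort_chunk[8] = { 0x06, 0x00, 0x00, 0x08,
				   0x00, 0x0c, 0x00, 0x04 };

	printf("cause: %u\n", abort_first_cause(abort_chunk, 8));
	return 0;
}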
+ */ +enum sctp_disposition sctp_sf_cookie_wait_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + __be16 error = SCTP_ERROR_NO_ERROR; + struct sctp_chunk *chunk = arg; + unsigned int len; + + if (!sctp_vtag_verify_either(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ABORT chunk has a valid length. + * Since this is an ABORT chunk, we have to discard it + * because of the following text: + * RFC 2960, Section 3.3.7 + * If an endpoint receives an ABORT with a format error or for an + * association that doesn't exist, it MUST silently discard it. + * Because the length is "invalid", we can't really discard just + * as we do not know its true length. So, to be safe, discard the + * packet. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_abort_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* See if we have an error cause code in the chunk. */ + len = ntohs(chunk->chunk_hdr->length); + if (len >= sizeof(struct sctp_chunkhdr) + sizeof(struct sctp_errhdr)) + error = ((struct sctp_errhdr *)chunk->skb->data)->cause; + + return sctp_stop_t1_and_abort(net, commands, error, ECONNREFUSED, asoc, + chunk->transport); +} + +/* + * Process an incoming ICMP as an ABORT. (COOKIE-WAIT state) + */ +enum sctp_disposition sctp_sf_cookie_wait_icmp_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + return sctp_stop_t1_and_abort(net, commands, SCTP_ERROR_NO_ERROR, + ENOPROTOOPT, asoc, + (struct sctp_transport *)arg); +} + +/* + * Process an ABORT. (COOKIE-ECHOED state) + */ +enum sctp_disposition sctp_sf_cookie_echoed_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* There is a single T1 timer, so we should be able to use + * common function with the COOKIE-WAIT state. + */ + return sctp_sf_cookie_wait_abort(net, ep, asoc, type, arg, commands); +} + +/* + * Stop T1 timer and abort association with "INIT failed". + * + * This is common code called by several sctp_sf_*_abort() functions above. + */ +static enum sctp_disposition sctp_stop_t1_and_abort( + struct net *net, + struct sctp_cmd_seq *commands, + __be16 error, int sk_err, + const struct sctp_association *asoc, + struct sctp_transport *transport) +{ + pr_debug("%s: ABORT received (INIT)\n", __func__); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, SCTP_ERROR(sk_err)); + /* CMD_INIT_FAILED will DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(error)); + + return SCTP_DISPOSITION_ABORT; +} + +/* + * sctp_sf_do_9_2_shut + * + * Section: 9.2 + * Upon the reception of the SHUTDOWN, the peer endpoint shall + * - enter the SHUTDOWN-RECEIVED state, + * + * - stop accepting new data from its SCTP user + * + * - verify, by checking the Cumulative TSN Ack field of the chunk, + * that all its outstanding DATA chunks have been received by the + * SHUTDOWN sender. 
+ * + * Once an endpoint as reached the SHUTDOWN-RECEIVED state it MUST NOT + * send a SHUTDOWN in response to a ULP request. And should discard + * subsequent SHUTDOWN chunks. + * + * If there are still outstanding DATA chunks left, the SHUTDOWN + * receiver shall continue to follow normal data transmission + * procedures defined in Section 6 until all outstanding DATA chunks + * are acknowledged; however, the SHUTDOWN receiver MUST NOT accept + * new data from its SCTP user. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_9_2_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + enum sctp_disposition disposition; + struct sctp_chunk *chunk = arg; + struct sctp_shutdownhdr *sdh; + struct sctp_ulpevent *ev; + __u32 ctsn; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the SHUTDOWN chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_shutdown_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Convert the elaborate header. */ + sdh = (struct sctp_shutdownhdr *)chunk->skb->data; + skb_pull(chunk->skb, sizeof(*sdh)); + chunk->subh.shutdown_hdr = sdh; + ctsn = ntohl(sdh->cum_tsn_ack); + + if (TSN_lt(ctsn, asoc->ctsn_ack_point)) { + pr_debug("%s: ctsn:%x, ctsn_ack_point:%x\n", __func__, ctsn, + asoc->ctsn_ack_point); + + return SCTP_DISPOSITION_DISCARD; + } + + /* If Cumulative TSN Ack beyond the max tsn currently + * send, terminating the association and respond to the + * sender with an ABORT. + */ + if (!TSN_lt(ctsn, asoc->next_tsn)) + return sctp_sf_violation_ctsn(net, ep, asoc, type, arg, commands); + + /* API 5.3.1.5 SCTP_SHUTDOWN_EVENT + * When a peer sends a SHUTDOWN, SCTP delivers this notification to + * inform the application that it should cease sending data. + */ + ev = sctp_ulpevent_make_shutdown_event(asoc, 0, GFP_ATOMIC); + if (!ev) { + disposition = SCTP_DISPOSITION_NOMEM; + goto out; + } + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(ev)); + + /* Upon the reception of the SHUTDOWN, the peer endpoint shall + * - enter the SHUTDOWN-RECEIVED state, + * - stop accepting new data from its SCTP user + * + * [This is implicit in the new state.] + */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_SHUTDOWN_RECEIVED)); + disposition = SCTP_DISPOSITION_CONSUME; + + if (sctp_outq_is_empty(&asoc->outqueue)) { + disposition = sctp_sf_do_9_2_shutdown_ack(net, ep, asoc, type, + arg, commands); + } + + if (SCTP_DISPOSITION_NOMEM == disposition) + goto out; + + /* - verify, by checking the Cumulative TSN Ack field of the + * chunk, that all its outstanding DATA chunks have been + * received by the SHUTDOWN sender. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_CTSN, + SCTP_BE32(chunk->subh.shutdown_hdr->cum_tsn_ack)); + +out: + return disposition; +} + +/* + * sctp_sf_do_9_2_shut_ctsn + * + * Once an endpoint has reached the SHUTDOWN-RECEIVED state, + * it MUST NOT send a SHUTDOWN in response to a ULP request. + * The Cumulative TSN Ack of the received SHUTDOWN chunk + * MUST be processed. 
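The TSN_lt()/TSN_lte() checks used throughout these SHUTDOWN handlers are serial-number comparisons: a 32-bit TSN a is "less than" b when the difference a - b, taken modulo 2^32, reads as a negative signed value, so the comparison keeps working across wraparound. A minimal sketch of that comparison and of the two SHUTDOWN checks, with hypothetical variable names:

#include <stdint.h>
#include <stdio.h>

/* Serial-number comparison for 32-bit TSNs:
 * a < b when the difference, taken modulo 2^32, is "negative".
 */
static int tsn_lt(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) < 0;
}

int main(void)
{
	uint32_t ctsn_ack_point = 0xfffffffe;	/* our cumulative ack point */
	uint32_t next_tsn = 0x00000003;		/* next TSN we would assign */
	uint32_t ctsn = 0x00000001;		/* from the SHUTDOWN chunk */

	/* Numerically smaller than the ack point, but serially newer
	 * because of the wrap, so the SHUTDOWN is not stale (prints 0).
	 */
	printf("stale: %d\n", tsn_lt(ctsn, ctsn_ack_point));

	/* A ctsn at or beyond next_tsn would acknowledge data we never
	 * sent and is treated as a protocol violation (prints 0 here).
	 */
	printf("beyond next_tsn: %d\n", !tsn_lt(ctsn, next_tsn));
	return 0;
}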
+ */ +enum sctp_disposition sctp_sf_do_9_2_shut_ctsn( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_shutdownhdr *sdh; + __u32 ctsn; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the SHUTDOWN chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_shutdown_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + sdh = (struct sctp_shutdownhdr *)chunk->skb->data; + ctsn = ntohl(sdh->cum_tsn_ack); + + if (TSN_lt(ctsn, asoc->ctsn_ack_point)) { + pr_debug("%s: ctsn:%x, ctsn_ack_point:%x\n", __func__, ctsn, + asoc->ctsn_ack_point); + + return SCTP_DISPOSITION_DISCARD; + } + + /* If Cumulative TSN Ack beyond the max tsn currently + * send, terminating the association and respond to the + * sender with an ABORT. + */ + if (!TSN_lt(ctsn, asoc->next_tsn)) + return sctp_sf_violation_ctsn(net, ep, asoc, type, arg, commands); + + /* verify, by checking the Cumulative TSN Ack field of the + * chunk, that all its outstanding DATA chunks have been + * received by the SHUTDOWN sender. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_CTSN, + SCTP_BE32(sdh->cum_tsn_ack)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* RFC 2960 9.2 + * If an endpoint is in SHUTDOWN-ACK-SENT state and receives an INIT chunk + * (e.g., if the SHUTDOWN COMPLETE was lost) with source and destination + * transport addresses (either in the IP addresses or in the INIT chunk) + * that belong to this association, it should discard the INIT chunk and + * retransmit the SHUTDOWN ACK chunk. + */ +static enum sctp_disposition +__sctp_sf_do_9_2_reshutack(struct net *net, const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_chunk *reply; + + /* Make sure that the chunk has a valid length */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Since we are not going to really process this INIT, there + * is no point in verifying chunk boundaries. Just generate + * the SHUTDOWN ACK. + */ + reply = sctp_make_shutdown_ack(asoc, chunk); + if (NULL == reply) + goto nomem; + + /* Set the transport for the SHUTDOWN ACK chunk and the timeout for + * the T2-SHUTDOWN timer. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T2, SCTP_CHUNK(reply)); + + /* and restart the T2-shutdown timer. 
*/ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + + return SCTP_DISPOSITION_CONSUME; +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +enum sctp_disposition +sctp_sf_do_9_2_reshutack(struct net *net, const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!chunk->singleton) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_init_chunk))) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (chunk->sctp_hdr->vtag != 0) + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); + + return __sctp_sf_do_9_2_reshutack(net, ep, asoc, type, arg, commands); +} + +/* + * sctp_sf_do_ecn_cwr + * + * Section: Appendix A: Explicit Congestion Notification + * + * CWR: + * + * RFC 2481 details a specific bit for a sender to send in the header of + * its next outbound TCP segment to indicate to its peer that it has + * reduced its congestion window. This is termed the CWR bit. For + * SCTP the same indication is made by including the CWR chunk. + * This chunk contains one data element, i.e. the TSN number that + * was sent in the ECNE chunk. This element represents the lowest + * TSN number in the datagram that was originally marked with the + * CE bit. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_do_ecn_cwr(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_cwrhdr *cwr; + u32 lowest_tsn; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_ecne_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + cwr = (struct sctp_cwrhdr *)chunk->skb->data; + skb_pull(chunk->skb, sizeof(*cwr)); + + lowest_tsn = ntohl(cwr->lowest_tsn); + + /* Does this CWR ack the last sent congestion notification? */ + if (TSN_lte(asoc->last_ecne_tsn, lowest_tsn)) { + /* Stop sending ECNE. */ + sctp_add_cmd_sf(commands, + SCTP_CMD_ECN_CWR, + SCTP_U32(lowest_tsn)); + } + return SCTP_DISPOSITION_CONSUME; +} + +/* + * sctp_sf_do_ecne + * + * Section: Appendix A: Explicit Congestion Notification + * + * ECN-Echo + * + * RFC 2481 details a specific bit for a receiver to send back in its + * TCP acknowledgements to notify the sender of the Congestion + * Experienced (CE) bit having arrived from the network. For SCTP this + * same indication is made by including the ECNE chunk. This chunk + * contains one data element, i.e. the lowest TSN associated with the IP + * datagram marked with the CE bit..... + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. 
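sctp_sf_do_ecn_cwr() stops echoing congestion once the peer's CWR covers the TSN that triggered the last ECNE. A toy sketch of that bookkeeping, using a hypothetical ecn_state that holds only the last echoed TSN and a flag; in the kernel this state lives in the association and is driven through the SCTP_CMD_ECN_CWR / SCTP_CMD_ECN_ECNE side effects rather than direct calls like these.

#include <stdint.h>
#include <stdio.h>

struct ecn_state {
	uint32_t last_ecne_tsn;	/* lowest TSN reported in our last ECNE */
	int need_ecne;		/* still echoing congestion to the peer? */
};

static int tsn_lte(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) <= 0;	/* serial-number comparison */
}

/* We reported a CE mark to the peer: remember the lowest marked TSN
 * and keep repeating ECNE until the peer answers with a CWR.
 */
static void ecne_sent(struct ecn_state *st, uint32_t lowest_tsn)
{
	st->last_ecne_tsn = lowest_tsn;
	st->need_ecne = 1;
}

/* CWR received: if it covers the TSN of our last ECNE, stop echoing. */
static void cwr_received(struct ecn_state *st, uint32_t lowest_tsn)
{
	if (tsn_lte(st->last_ecne_tsn, lowest_tsn))
		st->need_ecne = 0;
}

int main(void)
{
	struct ecn_state st = { 0, 0 };

	ecne_sent(&st, 100);
	cwr_received(&st, 99);		/* does not cover TSN 100 */
	printf("still need ECNE: %d\n", st.need_ecne);
	cwr_received(&st, 100);		/* covers it */
	printf("still need ECNE: %d\n", st.need_ecne);
	return 0;
}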
+ */ +enum sctp_disposition sctp_sf_do_ecne(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_ecnehdr *ecne; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_ecne_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + ecne = (struct sctp_ecnehdr *)chunk->skb->data; + skb_pull(chunk->skb, sizeof(*ecne)); + + /* If this is a newer ECNE than the last CWR packet we sent out */ + sctp_add_cmd_sf(commands, SCTP_CMD_ECN_ECNE, + SCTP_U32(ntohl(ecne->lowest_tsn))); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Section: 6.2 Acknowledgement on Reception of DATA Chunks + * + * The SCTP endpoint MUST always acknowledge the reception of each valid + * DATA chunk. + * + * The guidelines on delayed acknowledgement algorithm specified in + * Section 4.2 of [RFC2581] SHOULD be followed. Specifically, an + * acknowledgement SHOULD be generated for at least every second packet + * (not every second DATA chunk) received, and SHOULD be generated within + * 200 ms of the arrival of any unacknowledged DATA chunk. In some + * situations it may be beneficial for an SCTP transmitter to be more + * conservative than the algorithms detailed in this document allow. + * However, an SCTP transmitter MUST NOT be more aggressive than the + * following algorithms allow. + * + * A SCTP receiver MUST NOT generate more than one SACK for every + * incoming packet, other than to update the offered window as the + * receiving application consumes new data. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. 
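The Section 6.2 text quoted above is what the SCTP_CMD_GEN_SACK side effect enforces on the data path: acknowledge at least every second packet that carried DATA, and never later than a fixed delay after the first unacknowledged chunk. A deliberately simplified sketch of that decision, assuming a hypothetical per-association packet counter and leaving the 200 ms delayed-SACK timer to the caller:

#include <stdbool.h>
#include <stdio.h>

#define SACK_DELAY_MS 200	/* upper bound suggested by RFC 4960, 6.2 */

struct sack_state {
	unsigned int unacked_packets;	/* data packets since last SACK */
	bool timer_running;		/* delayed-SACK timer armed? */
};

/* Called once per received packet that carried DATA chunks.  Returns
 * true when a SACK should be generated immediately; otherwise the
 * caller is expected to (re)arm a SACK_DELAY_MS timer.
 */
static bool packet_with_data_received(struct sack_state *st,
				      bool sack_immediately)
{
	st->unacked_packets++;

	/* Immediate SACK: every second packet, or when the sender asked
	 * for one (e.g. SACK-IMM or a duplicate-only packet).
	 */
	if (sack_immediately || st->unacked_packets >= 2) {
		st->unacked_packets = 0;
		st->timer_running = false;
		return true;
	}

	st->timer_running = true;	/* SACK within SACK_DELAY_MS */
	return false;
}

int main(void)
{
	struct sack_state st = { 0, false };

	printf("packet 1 -> SACK now: %d\n",
	       packet_with_data_received(&st, false));
	printf("packet 2 -> SACK now: %d\n",
	       packet_with_data_received(&st, false));
	return 0;
}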
+ */ +enum sctp_disposition sctp_sf_eat_data_6_2(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + union sctp_arg force = SCTP_NOFORCE(); + struct sctp_chunk *chunk = arg; + int error; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (!sctp_chunk_length_valid(chunk, sctp_datachk_len(&asoc->stream))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + error = sctp_eat_data(asoc, chunk, commands); + switch (error) { + case SCTP_IERROR_NO_ERROR: + break; + case SCTP_IERROR_HIGH_TSN: + case SCTP_IERROR_BAD_STREAM: + SCTP_INC_STATS(net, SCTP_MIB_IN_DATA_CHUNK_DISCARDS); + goto discard_noforce; + case SCTP_IERROR_DUP_TSN: + case SCTP_IERROR_IGNORE_TSN: + SCTP_INC_STATS(net, SCTP_MIB_IN_DATA_CHUNK_DISCARDS); + goto discard_force; + case SCTP_IERROR_NO_DATA: + return SCTP_DISPOSITION_ABORT; + case SCTP_IERROR_PROTO_VIOLATION: + return sctp_sf_abort_violation(net, ep, asoc, chunk, commands, + (u8 *)chunk->subh.data_hdr, + sctp_datahdr_len(&asoc->stream)); + default: + BUG(); + } + + if (chunk->chunk_hdr->flags & SCTP_DATA_SACK_IMM) + force = SCTP_FORCE(); + + if (asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + } + + /* If this is the last chunk in a packet, we need to count it + * toward sack generation. Note that we need to SACK every + * OTHER packet containing data chunks, EVEN IF WE DISCARD + * THEM. We elect to NOT generate SACK's if the chunk fails + * the verification tag test. + * + * RFC 2960 6.2 Acknowledgement on Reception of DATA Chunks + * + * The SCTP endpoint MUST always acknowledge the reception of + * each valid DATA chunk. + * + * The guidelines on delayed acknowledgement algorithm + * specified in Section 4.2 of [RFC2581] SHOULD be followed. + * Specifically, an acknowledgement SHOULD be generated for at + * least every second packet (not every second DATA chunk) + * received, and SHOULD be generated within 200 ms of the + * arrival of any unacknowledged DATA chunk. In some + * situations it may be beneficial for an SCTP transmitter to + * be more conservative than the algorithms detailed in this + * document allow. However, an SCTP transmitter MUST NOT be + * more aggressive than the following algorithms allow. + */ + if (chunk->end_of_packet) + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, force); + + return SCTP_DISPOSITION_CONSUME; + +discard_force: + /* RFC 2960 6.2 Acknowledgement on Reception of DATA Chunks + * + * When a packet arrives with duplicate DATA chunk(s) and with + * no new DATA chunk(s), the endpoint MUST immediately send a + * SACK with no delay. If a packet arrives with duplicate + * DATA chunk(s) bundled with new DATA chunks, the endpoint + * MAY immediately send a SACK. Normally receipt of duplicate + * DATA chunks will occur when the original SACK chunk was lost + * and the peer's RTO has expired. The duplicate TSN number(s) + * SHOULD be reported in the SACK as duplicate. + */ + /* In our case, we split the MAY SACK advice up whether or not + * the last chunk is a duplicate.' 
+ */ + if (chunk->end_of_packet) + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, SCTP_FORCE()); + return SCTP_DISPOSITION_DISCARD; + +discard_noforce: + if (chunk->end_of_packet) + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, force); + + return SCTP_DISPOSITION_DISCARD; +} + +/* + * sctp_sf_eat_data_fast_4_4 + * + * Section: 4 (4) + * (4) In SHUTDOWN-SENT state the endpoint MUST acknowledge any received + * DATA chunks without delay. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_eat_data_fast_4_4( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + int error; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (!sctp_chunk_length_valid(chunk, sctp_datachk_len(&asoc->stream))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + error = sctp_eat_data(asoc, chunk, commands); + switch (error) { + case SCTP_IERROR_NO_ERROR: + case SCTP_IERROR_HIGH_TSN: + case SCTP_IERROR_DUP_TSN: + case SCTP_IERROR_IGNORE_TSN: + case SCTP_IERROR_BAD_STREAM: + break; + case SCTP_IERROR_NO_DATA: + return SCTP_DISPOSITION_ABORT; + case SCTP_IERROR_PROTO_VIOLATION: + return sctp_sf_abort_violation(net, ep, asoc, chunk, commands, + (u8 *)chunk->subh.data_hdr, + sctp_datahdr_len(&asoc->stream)); + default: + BUG(); + } + + /* Go a head and force a SACK, since we are shutting down. */ + + /* Implementor's Guide. + * + * While in SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately + * respond to each received packet containing one or more DATA chunk(s) + * with a SACK, a SHUTDOWN chunk, and restart the T2-shutdown timer + */ + if (chunk->end_of_packet) { + /* We must delay the chunk creation since the cumulative + * TSN has not been updated yet. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SHUTDOWN, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, SCTP_FORCE()); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + } + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Section: 6.2 Processing a Received SACK + * D) Any time a SACK arrives, the endpoint performs the following: + * + * i) If Cumulative TSN Ack is less than the Cumulative TSN Ack Point, + * then drop the SACK. Since Cumulative TSN Ack is monotonically + * increasing, a SACK whose Cumulative TSN Ack is less than the + * Cumulative TSN Ack Point indicates an out-of-order SACK. + * + * ii) Set rwnd equal to the newly received a_rwnd minus the number + * of bytes still outstanding after processing the Cumulative TSN Ack + * and the Gap Ack Blocks. + * + * iii) If the SACK is missing a TSN that was previously + * acknowledged via a Gap Ack Block (e.g., the data receiver + * reneged on the data), then mark the corresponding DATA chunk + * as available for retransmit: Mark it as missing for fast + * retransmit as described in Section 7.2.4 and if no retransmit + * timer is running for the destination address to which the DATA + * chunk was originally transmitted, then T3-rtx is started for + * that destination address. 
+ * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_eat_sack_6_2(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_sackhdr *sackh; + __u32 ctsn; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the SACK chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_sack_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Pull the SACK chunk from the data buffer */ + sackh = sctp_sm_pull_sack(chunk); + /* Was this a bogus SACK? */ + if (!sackh) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + chunk->subh.sack_hdr = sackh; + ctsn = ntohl(sackh->cum_tsn_ack); + + /* If Cumulative TSN Ack beyond the max tsn currently + * send, terminating the association and respond to the + * sender with an ABORT. + */ + if (TSN_lte(asoc->next_tsn, ctsn)) + return sctp_sf_violation_ctsn(net, ep, asoc, type, arg, commands); + + trace_sctp_probe(ep, asoc, chunk); + + /* i) If Cumulative TSN Ack is less than the Cumulative TSN + * Ack Point, then drop the SACK. Since Cumulative TSN + * Ack is monotonically increasing, a SACK whose + * Cumulative TSN Ack is less than the Cumulative TSN Ack + * Point indicates an out-of-order SACK. + */ + if (TSN_lt(ctsn, asoc->ctsn_ack_point)) { + pr_debug("%s: ctsn:%x, ctsn_ack_point:%x\n", __func__, ctsn, + asoc->ctsn_ack_point); + + return SCTP_DISPOSITION_DISCARD; + } + + /* Return this SACK for further processing. */ + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_SACK, SCTP_CHUNK(chunk)); + + /* Note: We do the rest of the work on the PROCESS_SACK + * sideeffect. + */ + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Generate an ABORT in response to a packet. + * + * Section: 8.4 Handle "Out of the blue" Packets, sctpimpguide 2.41 + * + * 8) The receiver should respond to the sender of the OOTB packet with + * an ABORT. When sending the ABORT, the receiver of the OOTB packet + * MUST fill in the Verification Tag field of the outbound packet + * with the value found in the Verification Tag field of the OOTB + * packet and set the T-bit in the Chunk Flags to indicate that the + * Verification Tag is reflected. After sending this ABORT, the + * receiver of the OOTB packet shall discard the OOTB packet and take + * no further action. + * + * Verification Tag: + * + * The return value is the disposition of the chunk. +*/ +static enum sctp_disposition sctp_sf_tabort_8_4_8( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_packet *packet = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *abort; + + packet = sctp_ootb_pkt_new(net, asoc, chunk); + if (!packet) + return SCTP_DISPOSITION_NOMEM; + + /* Make an ABORT. The T bit will be set if the asoc + * is NULL. 
+ */ + abort = sctp_make_abort(asoc, chunk, 0); + if (!abort) { + sctp_ootb_pkt_free(packet); + return SCTP_DISPOSITION_NOMEM; + } + + /* Reflect vtag if T-Bit is set */ + if (sctp_test_T_bit(abort)) + packet->vtag = ntohl(chunk->sctp_hdr->vtag); + + /* Set the skb to the belonging sock for accounting. */ + abort->skb->sk = ep->base.sk; + + sctp_packet_append_chunk(packet, abort); + + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, SCTP_PACKET(packet)); + + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + return SCTP_DISPOSITION_CONSUME; +} + +/* Handling of SCTP Packets Containing an INIT Chunk Matching an + * Existing Associations when the UDP encap port is incorrect. + * + * From Section 4 at draft-tuexen-tsvwg-sctp-udp-encaps-cons-03. + */ +static enum sctp_disposition sctp_sf_new_encap_port( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_packet *packet = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *abort; + + packet = sctp_ootb_pkt_new(net, asoc, chunk); + if (!packet) + return SCTP_DISPOSITION_NOMEM; + + abort = sctp_make_new_encap_port(asoc, chunk); + if (!abort) { + sctp_ootb_pkt_free(packet); + return SCTP_DISPOSITION_NOMEM; + } + + abort->skb->sk = ep->base.sk; + + sctp_packet_append_chunk(packet, abort); + + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Received an ERROR chunk from peer. Generate SCTP_REMOTE_ERROR + * event as ULP notification for each cause included in the chunk. + * + * API 5.3.1.3 - SCTP_REMOTE_ERROR + * + * The return value is the disposition of the chunk. +*/ +enum sctp_disposition sctp_sf_operr_notify(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_errhdr *err; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the ERROR chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_operr_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + sctp_walk_errors(err, chunk->chunk_hdr); + if ((void *)err != (void *)chunk->chunk_end) + return sctp_sf_violation_paramlen(net, ep, asoc, type, arg, + (void *)err, commands); + + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_OPERR, + SCTP_CHUNK(chunk)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Process an inbound SHUTDOWN ACK. + * + * From Section 9.2: + * Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall + * stop the T2-shutdown timer, send a SHUTDOWN COMPLETE chunk to its + * peer, and remove all record of the association. + * + * The return value is the disposition. 
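Both sctp_sf_tabort_8_4_8() and sctp_sf_new_encap_port() end up reflecting the offending packet's verification tag into the reply when no association tag is available, and the T bit in the chunk flags records that the tag is reflected. A sketch of that header fix-up on its own, with an illustrative sctp_common_hdr mirroring the RFC 4960 common header and the T bit taken as flag 0x01:

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

#define SCTP_CHUNK_FLAG_T 0x01	/* "tag reflected" bit (ABORT, SHUTDOWN COMPLETE) */

/* SCTP common header (RFC 4960, Section 3.1). */
struct sctp_common_hdr {
	uint16_t source;
	uint16_t dest;
	uint32_t vtag;
	uint32_t checksum;
};

/* Fill in the verification tag of an outbound ABORT built for an OOTB
 * packet: use the peer's tag when we know it, otherwise reflect the tag
 * found in the offending packet and set the T bit on the chunk.
 */
static void set_abort_vtag(struct sctp_common_hdr *out, uint8_t *chunk_flags,
			   uint32_t known_peer_vtag, uint32_t inbound_vtag)
{
	if (known_peer_vtag) {
		out->vtag = htonl(known_peer_vtag);
	} else {
		out->vtag = inbound_vtag;	/* already network order */
		*chunk_flags |= SCTP_CHUNK_FLAG_T;
	}
}

int main(void)
{
	struct sctp_common_hdr hdr = { htons(5000), htons(5001), 0, 0 };
	uint8_t flags = 0;

	set_abort_vtag(&hdr, &flags, 0, htonl(0xdeadbeef));
	printf("vtag: %#x, T bit: %d\n", ntohl(hdr.vtag),
	       !!(flags & SCTP_CHUNK_FLAG_T));
	return 0;
}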
+ */ +enum sctp_disposition sctp_sf_do_9_2_final(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_chunk *reply; + struct sctp_ulpevent *ev; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the SHUTDOWN_ACK chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + /* 10.2 H) SHUTDOWN COMPLETE notification + * + * When SCTP completes the shutdown procedures (section 9.2) this + * notification is passed to the upper layer. + */ + ev = sctp_ulpevent_make_assoc_change(asoc, 0, SCTP_SHUTDOWN_COMP, + 0, 0, 0, NULL, GFP_ATOMIC); + if (!ev) + goto nomem; + + /* ...send a SHUTDOWN COMPLETE chunk to its peer, */ + reply = sctp_make_shutdown_complete(asoc, chunk); + if (!reply) + goto nomem_chunk; + + /* Do all the commands now (after allocation), so that we + * have consistent state if memory allocation fails + */ + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(ev)); + + /* Upon the receipt of the SHUTDOWN ACK, the SHUTDOWN sender shall + * stop the T2-shutdown timer, + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + SCTP_INC_STATS(net, SCTP_MIB_SHUTDOWNS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + + /* ...and remove all record of the association. */ + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + return SCTP_DISPOSITION_DELETE_TCB; + +nomem_chunk: + sctp_ulpevent_free(ev); +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * RFC 2960, 8.4 - Handle "Out of the blue" Packets, sctpimpguide 2.41. + * + * 5) If the packet contains a SHUTDOWN ACK chunk, the receiver should + * respond to the sender of the OOTB packet with a SHUTDOWN COMPLETE. + * When sending the SHUTDOWN COMPLETE, the receiver of the OOTB + * packet must fill in the Verification Tag field of the outbound + * packet with the Verification Tag received in the SHUTDOWN ACK and + * set the T-bit in the Chunk Flags to indicate that the Verification + * Tag is reflected. + * + * 8) The receiver should respond to the sender of the OOTB packet with + * an ABORT. When sending the ABORT, the receiver of the OOTB packet + * MUST fill in the Verification Tag field of the outbound packet + * with the value found in the Verification Tag field of the OOTB + * packet and set the T-bit in the Chunk Flags to indicate that the + * Verification Tag is reflected. After sending this ABORT, the + * receiver of the OOTB packet shall discard the OOTB packet and take + * no further action. 
+ */ +enum sctp_disposition sctp_sf_ootb(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sk_buff *skb = chunk->skb; + struct sctp_chunkhdr *ch; + struct sctp_errhdr *err; + int ootb_cookie_ack = 0; + int ootb_shut_ack = 0; + __u8 *ch_end; + + SCTP_INC_STATS(net, SCTP_MIB_OUTOFBLUES); + + if (asoc && !sctp_vtag_verify(chunk, asoc)) + asoc = NULL; + + ch = (struct sctp_chunkhdr *)chunk->chunk_hdr; + do { + /* Report violation if the chunk is less then minimal */ + if (ntohs(ch->length) < sizeof(*ch)) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Report violation if chunk len overflows */ + ch_end = ((__u8 *)ch) + SCTP_PAD4(ntohs(ch->length)); + if (ch_end > skb_tail_pointer(skb)) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Now that we know we at least have a chunk header, + * do things that are type appropriate. + */ + if (SCTP_CID_SHUTDOWN_ACK == ch->type) + ootb_shut_ack = 1; + + /* RFC 2960, Section 3.3.7 + * Moreover, under any circumstances, an endpoint that + * receives an ABORT MUST NOT respond to that ABORT by + * sending an ABORT of its own. + */ + if (SCTP_CID_ABORT == ch->type) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* RFC 8.4, 7) If the packet contains a "Stale cookie" ERROR + * or a COOKIE ACK the SCTP Packet should be silently + * discarded. + */ + + if (SCTP_CID_COOKIE_ACK == ch->type) + ootb_cookie_ack = 1; + + if (SCTP_CID_ERROR == ch->type) { + sctp_walk_errors(err, ch) { + if (SCTP_ERROR_STALE_COOKIE == err->cause) { + ootb_cookie_ack = 1; + break; + } + } + } + + ch = (struct sctp_chunkhdr *)ch_end; + } while (ch_end < skb_tail_pointer(skb)); + + if (ootb_shut_ack) + return sctp_sf_shut_8_4_5(net, ep, asoc, type, arg, commands); + else if (ootb_cookie_ack) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + else + return sctp_sf_tabort_8_4_8(net, ep, asoc, type, arg, commands); +} + +/* + * Handle an "Out of the blue" SHUTDOWN ACK. + * + * Section: 8.4 5, sctpimpguide 2.41. + * + * 5) If the packet contains a SHUTDOWN ACK chunk, the receiver should + * respond to the sender of the OOTB packet with a SHUTDOWN COMPLETE. + * When sending the SHUTDOWN COMPLETE, the receiver of the OOTB + * packet must fill in the Verification Tag field of the outbound + * packet with the Verification Tag received in the SHUTDOWN ACK and + * set the T-bit in the Chunk Flags to indicate that the Verification + * Tag is reflected. + * + * Inputs + * (endpoint, asoc, type, arg, commands) + * + * Outputs + * (enum sctp_disposition) + * + * The return value is the disposition of the chunk. + */ +static enum sctp_disposition sctp_sf_shut_8_4_5( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_packet *packet = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *shut; + + packet = sctp_ootb_pkt_new(net, asoc, chunk); + if (!packet) + return SCTP_DISPOSITION_NOMEM; + + /* Make an SHUTDOWN_COMPLETE. + * The T bit will be set if the asoc is NULL. 
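The do/while loop in sctp_sf_ootb() steps through every chunk in the out-of-the-blue packet, padding each length to a 4-byte boundary and refusing to read past the tail of the skb. Below is a standalone sketch of that walk over a raw buffer, simplified to classify only the chunk types the handler reacts to (chunk type values are those assigned by RFC 4960: ABORT = 6, SHUTDOWN ACK = 8, ERROR = 9, COOKIE ACK = 11); the in-tree handler additionally checks the ERROR causes for a Stale Cookie before discarding and keeps scanning instead of returning on the first match.

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define CID_ABORT		6
#define CID_SHUTDOWN_ACK	8
#define CID_ERROR		9
#define CID_COOKIE_ACK		11

#define PAD4(len) (((len) + 3) & ~(size_t)3)

struct chunk_hdr {
	uint8_t type;
	uint8_t flags;
	uint16_t length;	/* network byte order, header included */
};

/* Walk the chunks of an out-of-the-blue packet.  Returns -1 on a
 * malformed length, 1 if a SHUTDOWN ACK was seen, 2 for ABORT,
 * 3 for COOKIE ACK / ERROR, 0 otherwise.
 */
static int classify_ootb(const uint8_t *buf, size_t len)
{
	size_t off = 0;

	while (off + sizeof(struct chunk_hdr) <= len) {
		struct chunk_hdr ch;
		size_t clen;

		memcpy(&ch, buf + off, sizeof(ch));
		clen = ntohs(ch.length);
		if (clen < sizeof(ch) || off + PAD4(clen) > len)
			return -1;		/* length violation */

		if (ch.type == CID_SHUTDOWN_ACK)
			return 1;		/* reply with SHUTDOWN COMPLETE */
		if (ch.type == CID_ABORT)
			return 2;		/* never ABORT an ABORT */
		if (ch.type == CID_COOKIE_ACK || ch.type == CID_ERROR)
			return 3;		/* silently discard */

		off += PAD4(clen);
	}
	return 0;				/* default: reply with ABORT */
}

int main(void)
{
	/* A lone SHUTDOWN ACK chunk: type 8, flags 0, length 4. */
	uint8_t pkt[4] = { 0x08, 0x00, 0x00, 0x04 };

	printf("classification: %d\n", classify_ootb(pkt, sizeof(pkt)));
	return 0;
}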
+ */ + shut = sctp_make_shutdown_complete(asoc, chunk); + if (!shut) { + sctp_ootb_pkt_free(packet); + return SCTP_DISPOSITION_NOMEM; + } + + /* Reflect vtag if T-Bit is set */ + if (sctp_test_T_bit(shut)) + packet->vtag = ntohl(chunk->sctp_hdr->vtag); + + /* Set the skb to the belonging sock for accounting. */ + shut->skb->sk = ep->base.sk; + + sctp_packet_append_chunk(packet, shut); + + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + /* We need to discard the rest of the packet to prevent + * potential boomming attacks from additional bundled chunks. + * This is documented in SCTP Threats ID. + */ + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); +} + +/* + * Handle SHUTDOWN ACK in COOKIE_ECHOED or COOKIE_WAIT state. + * + * Verification Tag: 8.5.1 E) Rules for packet carrying a SHUTDOWN ACK + * If the receiver is in COOKIE-ECHOED or COOKIE-WAIT state the + * procedures in section 8.4 SHOULD be followed, in other words it + * should be treated as an Out Of The Blue packet. + * [This means that we do NOT check the Verification Tag on these + * chunks. --piggy ] + * + */ +enum sctp_disposition sctp_sf_do_8_5_1_E_sa(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!sctp_vtag_verify(chunk, asoc)) + asoc = NULL; + + /* Make sure that the SHUTDOWN_ACK chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* Although we do have an association in this case, it corresponds + * to a restarted association. So the packet is treated as an OOTB + * packet and the state function that handles OOTB SHUTDOWN_ACK is + * called with a NULL association. + */ + SCTP_INC_STATS(net, SCTP_MIB_OUTOFBLUES); + + return sctp_sf_shut_8_4_5(net, ep, NULL, type, arg, commands); +} + +/* ADDIP Section 4.2 Upon reception of an ASCONF Chunk. */ +enum sctp_disposition sctp_sf_do_asconf(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_paramhdr *err_param = NULL; + struct sctp_chunk *asconf_ack = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_addiphdr *hdr; + __u32 serial; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Make sure that the ASCONF ADDIP chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_addip_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* ADD-IP: Section 4.1.1 + * This chunk MUST be sent in an authenticated way by using + * the mechanism defined in [I-D.ietf-tsvwg-sctp-auth]. If this chunk + * is received unauthenticated it MUST be silently discarded as + * described in [I-D.ietf-tsvwg-sctp-auth]. + */ + if (!asoc->peer.asconf_capable || + (!net->sctp.addip_noauth && !chunk->auth)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + hdr = (struct sctp_addiphdr *)chunk->skb->data; + serial = ntohl(hdr->serial); + + /* Verify the ASCONF chunk before processing it. 
*/ + if (!sctp_verify_asconf(asoc, chunk, true, &err_param)) + return sctp_sf_violation_paramlen(net, ep, asoc, type, arg, + (void *)err_param, commands); + + /* ADDIP 5.2 E1) Compare the value of the serial number to the value + * the endpoint stored in a new association variable + * 'Peer-Serial-Number'. + */ + if (serial == asoc->peer.addip_serial + 1) { + /* If this is the first instance of ASCONF in the packet, + * we can clean our old ASCONF-ACKs. + */ + if (!chunk->has_asconf) + sctp_assoc_clean_asconf_ack_cache(asoc); + + /* ADDIP 5.2 E4) When the Sequence Number matches the next one + * expected, process the ASCONF as described below and after + * processing the ASCONF Chunk, append an ASCONF-ACK Chunk to + * the response packet and cache a copy of it (in the event it + * later needs to be retransmitted). + * + * Essentially, do V1-V5. + */ + asconf_ack = sctp_process_asconf((struct sctp_association *) + asoc, chunk); + if (!asconf_ack) + return SCTP_DISPOSITION_NOMEM; + } else if (serial < asoc->peer.addip_serial + 1) { + /* ADDIP 5.2 E2) + * If the value found in the Sequence Number is less than the + * ('Peer- Sequence-Number' + 1), simply skip to the next + * ASCONF, and include in the outbound response packet + * any previously cached ASCONF-ACK response that was + * sent and saved that matches the Sequence Number of the + * ASCONF. Note: It is possible that no cached ASCONF-ACK + * Chunk exists. This will occur when an older ASCONF + * arrives out of order. In such a case, the receiver + * should skip the ASCONF Chunk and not include ASCONF-ACK + * Chunk for that chunk. + */ + asconf_ack = sctp_assoc_lookup_asconf_ack(asoc, hdr->serial); + if (!asconf_ack) + return SCTP_DISPOSITION_DISCARD; + + /* Reset the transport so that we select the correct one + * this time around. This is to make sure that we don't + * accidentally use a stale transport that's been removed. + */ + asconf_ack->transport = NULL; + } else { + /* ADDIP 5.2 E5) Otherwise, the ASCONF Chunk is discarded since + * it must be either a stale packet or from an attacker. + */ + return SCTP_DISPOSITION_DISCARD; + } + + /* ADDIP 5.2 E6) The destination address of the SCTP packet + * containing the ASCONF-ACK Chunks MUST be the source address of + * the SCTP packet that held the ASCONF Chunks. + * + * To do this properly, we'll set the destination address of the chunk + * and at the transmit time, will try look up the transport to use. + * Since ASCONFs may be bundled, the correct transport may not be + * created until we process the entire packet, thus this workaround. 
+ */ + asconf_ack->dest = chunk->source; + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(asconf_ack)); + if (asoc->new_transport) { + sctp_sf_heartbeat(ep, asoc, type, asoc->new_transport, commands); + ((struct sctp_association *)asoc)->new_transport = NULL; + } + + return SCTP_DISPOSITION_CONSUME; +} + +static enum sctp_disposition sctp_send_next_asconf( + struct net *net, + const struct sctp_endpoint *ep, + struct sctp_association *asoc, + const union sctp_subtype type, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *asconf; + struct list_head *entry; + + if (list_empty(&asoc->addip_chunk_list)) + return SCTP_DISPOSITION_CONSUME; + + entry = asoc->addip_chunk_list.next; + asconf = list_entry(entry, struct sctp_chunk, list); + + list_del_init(entry); + sctp_chunk_hold(asconf); + asoc->addip_last_asconf = asconf; + + return sctp_sf_do_prm_asconf(net, ep, asoc, type, asconf, commands); +} + +/* + * ADDIP Section 4.3 General rules for address manipulation + * When building TLV parameters for the ASCONF Chunk that will add or + * delete IP addresses the D0 to D13 rules should be applied: + */ +enum sctp_disposition sctp_sf_do_asconf_ack(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *last_asconf = asoc->addip_last_asconf; + struct sctp_paramhdr *err_param = NULL; + struct sctp_chunk *asconf_ack = arg; + struct sctp_addiphdr *addip_hdr; + __u32 sent_serial, rcvd_serial; + struct sctp_chunk *abort; + + if (!sctp_vtag_verify(asconf_ack, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Make sure that the ADDIP chunk has a valid length. */ + if (!sctp_chunk_length_valid(asconf_ack, + sizeof(struct sctp_addip_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + /* ADD-IP, Section 4.1.2: + * This chunk MUST be sent in an authenticated way by using + * the mechanism defined in [I-D.ietf-tsvwg-sctp-auth]. If this chunk + * is received unauthenticated it MUST be silently discarded as + * described in [I-D.ietf-tsvwg-sctp-auth]. + */ + if (!asoc->peer.asconf_capable || + (!net->sctp.addip_noauth && !asconf_ack->auth)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + addip_hdr = (struct sctp_addiphdr *)asconf_ack->skb->data; + rcvd_serial = ntohl(addip_hdr->serial); + + /* Verify the ASCONF-ACK chunk before processing it. */ + if (!sctp_verify_asconf(asoc, asconf_ack, false, &err_param)) + return sctp_sf_violation_paramlen(net, ep, asoc, type, arg, + (void *)err_param, commands); + + if (last_asconf) { + addip_hdr = (struct sctp_addiphdr *)last_asconf->subh.addip_hdr; + sent_serial = ntohl(addip_hdr->serial); + } else { + sent_serial = asoc->addip_serial - 1; + } + + /* D0) If an endpoint receives an ASCONF-ACK that is greater than or + * equal to the next serial number to be used but no ASCONF chunk is + * outstanding the endpoint MUST ABORT the association. Note that a + * sequence number is greater than if it is no more than 2^^31-1 + * larger than the current sequence number (using serial arithmetic). 
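Rule D0 quoted above relies on 32-bit serial arithmetic: "greater than or equal" holds when the received serial is at most 2^31 - 1 ahead of the expected one, so the comparison survives counter wrap-around. A minimal sketch of that comparison is below; serial_gte() is a hypothetical helper written for the example, not the kernel's ADDIP_SERIAL_gte() macro used just after this comment, although the underlying idea is the same.

/* Illustrative 32-bit serial-number comparison (RFC 1982 style), as used
 * conceptually by the D0 rule above.
 */
#include <stdint.h>
#include <assert.h>

/* a >= b in serial arithmetic: true when a is at most 2^31 - 1 ahead of b. */
static int serial_gte(uint32_t a, uint32_t b)
{
	return (int32_t)(a - b) >= 0;
}

int main(void)
{
	assert(serial_gte(5, 5));                    /* equal                  */
	assert(serial_gte(6, 5));                    /* one ahead              */
	assert(!serial_gte(5, 6));                   /* one behind             */
	assert(serial_gte(0x00000002, 0xfffffffe));  /* ahead across the wrap  */
	return 0;
}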
+ */ + if (ADDIP_SERIAL_gte(rcvd_serial, sent_serial + 1) && + !(asoc->addip_last_asconf)) { + abort = sctp_make_abort(asoc, asconf_ack, + sizeof(struct sctp_errhdr)); + if (abort) { + sctp_init_cause(abort, SCTP_ERROR_ASCONF_ACK, 0); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(abort)); + } + /* We are going to ABORT, so we might as well stop + * processing the rest of the chunks in the packet. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + sctp_add_cmd_sf(commands, SCTP_CMD_DISCARD_PACKET, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_ASCONF_ACK)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_ABORT; + } + + if ((rcvd_serial == sent_serial) && asoc->addip_last_asconf) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + + if (!sctp_process_asconf_ack((struct sctp_association *)asoc, + asconf_ack)) + return sctp_send_next_asconf(net, ep, + (struct sctp_association *)asoc, + type, commands); + + abort = sctp_make_abort(asoc, asconf_ack, + sizeof(struct sctp_errhdr)); + if (abort) { + sctp_init_cause(abort, SCTP_ERROR_RSRC_LOW, 0); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(abort)); + } + /* We are going to ABORT, so we might as well stop + * processing the rest of the chunks in the packet. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_DISCARD_PACKET, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_ASCONF_ACK)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_ABORT; + } + + return SCTP_DISPOSITION_DISCARD; +} + +/* RE-CONFIG Section 5.2 Upon reception of an RECONF Chunk. */ +enum sctp_disposition sctp_sf_do_reconf(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_paramhdr *err_param = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_reconf_chunk *hdr; + union sctp_params param; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Make sure that the RECONF chunk has a valid length. 
*/ + if (!sctp_chunk_length_valid(chunk, sizeof(*hdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + if (!sctp_verify_reconf(asoc, chunk, &err_param)) + return sctp_sf_violation_paramlen(net, ep, asoc, type, arg, + (void *)err_param, commands); + + hdr = (struct sctp_reconf_chunk *)chunk->chunk_hdr; + sctp_walk_params(param, hdr, params) { + struct sctp_chunk *reply = NULL; + struct sctp_ulpevent *ev = NULL; + + if (param.p->type == SCTP_PARAM_RESET_OUT_REQUEST) + reply = sctp_process_strreset_outreq( + (struct sctp_association *)asoc, param, &ev); + else if (param.p->type == SCTP_PARAM_RESET_IN_REQUEST) + reply = sctp_process_strreset_inreq( + (struct sctp_association *)asoc, param, &ev); + else if (param.p->type == SCTP_PARAM_RESET_TSN_REQUEST) + reply = sctp_process_strreset_tsnreq( + (struct sctp_association *)asoc, param, &ev); + else if (param.p->type == SCTP_PARAM_RESET_ADD_OUT_STREAMS) + reply = sctp_process_strreset_addstrm_out( + (struct sctp_association *)asoc, param, &ev); + else if (param.p->type == SCTP_PARAM_RESET_ADD_IN_STREAMS) + reply = sctp_process_strreset_addstrm_in( + (struct sctp_association *)asoc, param, &ev); + else if (param.p->type == SCTP_PARAM_RESET_RESPONSE) + reply = sctp_process_strreset_resp( + (struct sctp_association *)asoc, param, &ev); + + if (ev) + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + + if (reply) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(reply)); + } + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * PR-SCTP Section 3.6 Receiver Side Implementation of PR-SCTP + * + * When a FORWARD TSN chunk arrives, the data receiver MUST first update + * its cumulative TSN point to the value carried in the FORWARD TSN + * chunk, and then MUST further advance its cumulative TSN point locally + * if possible. + * After the above processing, the data receiver MUST stop reporting any + * missing TSNs earlier than or equal to the new cumulative TSN point. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_eat_fwd_tsn(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_fwdtsn_hdr *fwdtsn_hdr; + struct sctp_chunk *chunk = arg; + __u16 len; + __u32 tsn; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (!asoc->peer.prsctp_capable) + return sctp_sf_unk_chunk(net, ep, asoc, type, arg, commands); + + /* Make sure that the FORWARD_TSN chunk has valid length. */ + if (!sctp_chunk_length_valid(chunk, sctp_ftsnchk_len(&asoc->stream))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + fwdtsn_hdr = (struct sctp_fwdtsn_hdr *)chunk->skb->data; + chunk->subh.fwdtsn_hdr = fwdtsn_hdr; + len = ntohs(chunk->chunk_hdr->length); + len -= sizeof(struct sctp_chunkhdr); + skb_pull(chunk->skb, len); + + tsn = ntohl(fwdtsn_hdr->new_cum_tsn); + pr_debug("%s: TSN 0x%x\n", __func__, tsn); + + /* The TSN is too high--silently discard the chunk and count on it + * getting retransmitted later. 
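The FORWARD TSN header pulled just above carries the new cumulative TSN followed by optional (stream, stream sequence number) skip pairs naming the streams that had messages abandoned; any trailing pairs are later handed to SCTP_CMD_PROCESS_FWDTSN. A user-space sketch of parsing that classic RFC 3758 layout is below; the struct and function names are simplified stand-ins for the kernel's sctp_fwdtsn_hdr and sctp_fwdtsn_skip, and the interleaved I-FORWARD-TSN format is deliberately not covered.

/* Illustrative parse of a FORWARD TSN chunk body (RFC 3758 layout):
 * a 4-byte new cumulative TSN followed by zero or more
 * { stream (2 bytes), stream sequence number (2 bytes) } skip entries.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>

struct fwdtsn_skip {            /* simplified stand-in for struct sctp_fwdtsn_skip */
	uint16_t stream;
	uint16_t ssn;
};

static void parse_fwdtsn(const uint8_t *body, size_t len)
{
	uint32_t new_cum_tsn;
	size_t nskips, i;

	if (len < sizeof(uint32_t))
		return;                         /* too short for the fixed header */

	memcpy(&new_cum_tsn, body, sizeof(new_cum_tsn));
	printf("new cumulative TSN: 0x%x\n", ntohl(new_cum_tsn));

	nskips = (len - sizeof(uint32_t)) / sizeof(struct fwdtsn_skip);
	for (i = 0; i < nskips; i++) {
		struct fwdtsn_skip skip;

		memcpy(&skip, body + sizeof(uint32_t) + i * sizeof(skip),
		       sizeof(skip));
		printf("  skip: stream %u up to SSN %u\n",
		       ntohs(skip.stream), ntohs(skip.ssn));
	}
}

int main(void)
{
	/* new_cum_tsn = 0x1000, one skip entry: stream 1 up to SSN 7 */
	uint8_t body[] = { 0x00, 0x00, 0x10, 0x00, 0x00, 0x01, 0x00, 0x07 };

	parse_fwdtsn(body, sizeof(body));
	return 0;
}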
+ */ + if (sctp_tsnmap_check(&asoc->peer.tsn_map, tsn) < 0) + goto discard_noforce; + + if (!asoc->stream.si->validate_ftsn(chunk)) + goto discard_noforce; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_FWDTSN, SCTP_U32(tsn)); + if (len > sctp_ftsnhdr_len(&asoc->stream)) + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_FWDTSN, + SCTP_CHUNK(chunk)); + + /* Count this as receiving DATA. */ + if (asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + } + + /* FIXME: For now send a SACK, but DATA processing may + * send another. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, SCTP_NOFORCE()); + + return SCTP_DISPOSITION_CONSUME; + +discard_noforce: + return SCTP_DISPOSITION_DISCARD; +} + +enum sctp_disposition sctp_sf_eat_fwd_tsn_fast( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_fwdtsn_hdr *fwdtsn_hdr; + struct sctp_chunk *chunk = arg; + __u16 len; + __u32 tsn; + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + if (!asoc->peer.prsctp_capable) + return sctp_sf_unk_chunk(net, ep, asoc, type, arg, commands); + + /* Make sure that the FORWARD_TSN chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sctp_ftsnchk_len(&asoc->stream))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + fwdtsn_hdr = (struct sctp_fwdtsn_hdr *)chunk->skb->data; + chunk->subh.fwdtsn_hdr = fwdtsn_hdr; + len = ntohs(chunk->chunk_hdr->length); + len -= sizeof(struct sctp_chunkhdr); + skb_pull(chunk->skb, len); + + tsn = ntohl(fwdtsn_hdr->new_cum_tsn); + pr_debug("%s: TSN 0x%x\n", __func__, tsn); + + /* The TSN is too high--silently discard the chunk and count on it + * getting retransmitted later. + */ + if (sctp_tsnmap_check(&asoc->peer.tsn_map, tsn) < 0) + goto gen_shutdown; + + if (!asoc->stream.si->validate_ftsn(chunk)) + goto gen_shutdown; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_FWDTSN, SCTP_U32(tsn)); + if (len > sctp_ftsnhdr_len(&asoc->stream)) + sctp_add_cmd_sf(commands, SCTP_CMD_PROCESS_FWDTSN, + SCTP_CHUNK(chunk)); + + /* Go a head and force a SACK, since we are shutting down. */ +gen_shutdown: + /* Implementor's Guide. + * + * While in SHUTDOWN-SENT state, the SHUTDOWN sender MUST immediately + * respond to each received packet containing one or more DATA chunk(s) + * with a SACK, a SHUTDOWN chunk, and restart the T2-shutdown timer + */ + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SHUTDOWN, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, SCTP_FORCE()); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * SCTP-AUTH Section 6.3 Receiving authenticated chunks + * + * The receiver MUST use the HMAC algorithm indicated in the HMAC + * Identifier field. If this algorithm was not specified by the + * receiver in the HMAC-ALGO parameter in the INIT or INIT-ACK chunk + * during association setup, the AUTH chunk and all chunks after it MUST + * be discarded and an ERROR chunk SHOULD be sent with the error cause + * defined in Section 4.1. + * + * If an endpoint with no shared key receives a Shared Key Identifier + * other than 0, it MUST silently discard all authenticated chunks. 
If + * the endpoint has at least one endpoint pair shared key for the peer, + * it MUST use the key specified by the Shared Key Identifier if a + * key has been configured for that Shared Key Identifier. If no + * endpoint pair shared key has been configured for that Shared Key + * Identifier, all authenticated chunks MUST be silently discarded. + * + * Verification Tag: 8.5 Verification Tag [Normal verification] + * + * The return value is the disposition of the chunk. + */ +static enum sctp_ierror sctp_sf_authenticate( + const struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + struct sctp_shared_key *sh_key = NULL; + struct sctp_authhdr *auth_hdr; + __u8 *save_digest, *digest; + struct sctp_hmac *hmac; + unsigned int sig_len; + __u16 key_id; + + /* Pull in the auth header, so we can do some more verification */ + auth_hdr = (struct sctp_authhdr *)chunk->skb->data; + chunk->subh.auth_hdr = auth_hdr; + skb_pull(chunk->skb, sizeof(*auth_hdr)); + + /* Make sure that we support the HMAC algorithm from the auth + * chunk. + */ + if (!sctp_auth_asoc_verify_hmac_id(asoc, auth_hdr->hmac_id)) + return SCTP_IERROR_AUTH_BAD_HMAC; + + /* Make sure that the provided shared key identifier has been + * configured + */ + key_id = ntohs(auth_hdr->shkey_id); + if (key_id != asoc->active_key_id) { + sh_key = sctp_auth_get_shkey(asoc, key_id); + if (!sh_key) + return SCTP_IERROR_AUTH_BAD_KEYID; + } + + /* Make sure that the length of the signature matches what + * we expect. + */ + sig_len = ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_auth_chunk); + hmac = sctp_auth_get_hmac(ntohs(auth_hdr->hmac_id)); + if (sig_len != hmac->hmac_len) + return SCTP_IERROR_PROTO_VIOLATION; + + /* Now that we've done validation checks, we can compute and + * verify the hmac. The steps involved are: + * 1. Save the digest from the chunk. + * 2. Zero out the digest in the chunk. + * 3. Compute the new digest + * 4. Compare saved and new digests. + */ + digest = auth_hdr->hmac; + skb_pull(chunk->skb, sig_len); + + save_digest = kmemdup(digest, sig_len, GFP_ATOMIC); + if (!save_digest) + goto nomem; + + memset(digest, 0, sig_len); + + sctp_auth_calculate_hmac(asoc, chunk->skb, + (struct sctp_auth_chunk *)chunk->chunk_hdr, + sh_key, GFP_ATOMIC); + + /* Discard the packet if the digests do not match */ + if (memcmp(save_digest, digest, sig_len)) { + kfree(save_digest); + return SCTP_IERROR_BAD_SIG; + } + + kfree(save_digest); + chunk->auth = 1; + + return SCTP_IERROR_NO_ERROR; +nomem: + return SCTP_IERROR_NOMEM; +} + +enum sctp_disposition sctp_sf_eat_auth(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_authhdr *auth_hdr; + struct sctp_chunk *err_chunk; + enum sctp_ierror error; + + /* Make sure that the peer has AUTH capable */ + if (!asoc->peer.auth_capable) + return sctp_sf_unk_chunk(net, ep, asoc, type, arg, commands); + + if (!sctp_vtag_verify(chunk, asoc)) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_BAD_TAG, + SCTP_NULL()); + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + } + + /* Make sure that the AUTH chunk has valid length. 
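sctp_sf_authenticate() above follows a classic zero-and-recompute pattern: save the received digest, zero the digest field in place (it is covered by the keyed hash), recompute the digest over the chunk with the negotiated key, and compare the two with memcmp(). A condensed user-space sketch of those four steps is below; compute_hmac() here is a toy placeholder so the example is self-contained (the kernel uses sctp_auth_calculate_hmac()), and the fixed 32-byte digest size is only an assumption for the example.

/* Illustrative sketch of the save/zero/recompute/compare digest check
 * performed by sctp_sf_authenticate() above.  compute_hmac() is a toy
 * keyed checksum, NOT a real HMAC; it only keeps the sketch runnable.
 */
#include <stdint.h>
#include <string.h>
#include <stdio.h>

#define DIGEST_LEN 32   /* assumed digest size for the example */

/* Toy stand-in for HMAC(key, msg). */
static void compute_hmac(const uint8_t *key, size_t key_len,
			 const uint8_t *msg, size_t msg_len,
			 uint8_t out[DIGEST_LEN])
{
	size_t i;

	memset(out, 0, DIGEST_LEN);
	for (i = 0; i < key_len; i++)
		out[i % DIGEST_LEN] ^= key[i];
	for (i = 0; i < msg_len; i++)
		out[i % DIGEST_LEN] ^= msg[i];
}

/* Returns 1 if the digest stored at msg[digest_off] verifies, else 0. */
static int verify_embedded_digest(uint8_t *msg, size_t msg_len,
				  size_t digest_off,
				  const uint8_t *key, size_t key_len)
{
	uint8_t saved[DIGEST_LEN], recomputed[DIGEST_LEN];

	/* 1. Save the digest carried in the message. */
	memcpy(saved, msg + digest_off, DIGEST_LEN);

	/* 2. Zero the digest field; it is part of the hashed data and must
	 *    be zero while the digest is (re)computed.
	 */
	memset(msg + digest_off, 0, DIGEST_LEN);

	/* 3. Recompute the digest over the whole message. */
	compute_hmac(key, key_len, msg, msg_len, recomputed);

	/* 4. Compare the saved and recomputed digests. */
	return memcmp(saved, recomputed, DIGEST_LEN) == 0;
}

int main(void)
{
	uint8_t key[] = "shared-key";
	uint8_t msg[64] = "some chunk bytes";   /* bytes 32..63 hold the digest */
	uint8_t tmp[DIGEST_LEN];

	/* Sender: compute over the message with a zeroed digest field, then fill it in. */
	compute_hmac(key, sizeof(key), msg, sizeof(msg), tmp);
	memcpy(msg + 32, tmp, DIGEST_LEN);

	printf("digest %s\n",
	       verify_embedded_digest(msg, sizeof(msg), 32, key, sizeof(key)) ?
	       "verifies" : "does not verify");
	return 0;
}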
*/ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_auth_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + auth_hdr = (struct sctp_authhdr *)chunk->skb->data; + error = sctp_sf_authenticate(asoc, chunk); + switch (error) { + case SCTP_IERROR_AUTH_BAD_HMAC: + /* Generate the ERROR chunk and discard the rest + * of the packet + */ + err_chunk = sctp_make_op_error(asoc, chunk, + SCTP_ERROR_UNSUP_HMAC, + &auth_hdr->hmac_id, + sizeof(__u16), 0); + if (err_chunk) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(err_chunk)); + } + fallthrough; + case SCTP_IERROR_AUTH_BAD_KEYID: + case SCTP_IERROR_BAD_SIG: + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + case SCTP_IERROR_PROTO_VIOLATION: + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + case SCTP_IERROR_NOMEM: + return SCTP_DISPOSITION_NOMEM; + + default: /* Prevent gcc warnings */ + break; + } + + if (asoc->active_key_id != ntohs(auth_hdr->shkey_id)) { + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_authkey(asoc, ntohs(auth_hdr->shkey_id), + SCTP_AUTH_NEW_KEY, GFP_ATOMIC); + + if (!ev) + return SCTP_DISPOSITION_NOMEM; + + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, + SCTP_ULPEVENT(ev)); + } + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Process an unknown chunk. + * + * Section: 3.2. Also, 2.1 in the implementor's guide. + * + * Chunk Types are encoded such that the highest-order two bits specify + * the action that must be taken if the processing endpoint does not + * recognize the Chunk Type. + * + * 00 - Stop processing this SCTP packet and discard it, do not process + * any further chunks within it. + * + * 01 - Stop processing this SCTP packet and discard it, do not process + * any further chunks within it, and report the unrecognized + * chunk in an 'Unrecognized Chunk Type'. + * + * 10 - Skip this chunk and continue processing. + * + * 11 - Skip this chunk and continue processing, but report in an ERROR + * Chunk using the 'Unrecognized Chunk Type' cause of error. + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_unk_chunk(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *unk_chunk = arg; + struct sctp_chunk *err_chunk; + struct sctp_chunkhdr *hdr; + + pr_debug("%s: processing unknown chunk id:%d\n", __func__, type.chunk); + + if (!sctp_vtag_verify(unk_chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the chunk has a valid length. + * Since we don't know the chunk type, we use a general + * chunkhdr structure to make a comparison. + */ + if (!sctp_chunk_length_valid(unk_chunk, sizeof(*hdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + switch (type.chunk & SCTP_CID_ACTION_MASK) { + case SCTP_CID_ACTION_DISCARD: + /* Discard the packet. */ + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + case SCTP_CID_ACTION_DISCARD_ERR: + /* Generate an ERROR chunk as response. */ + hdr = unk_chunk->chunk_hdr; + err_chunk = sctp_make_op_error(asoc, unk_chunk, + SCTP_ERROR_UNKNOWN_CHUNK, hdr, + SCTP_PAD4(ntohs(hdr->length)), + 0); + if (err_chunk) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(err_chunk)); + } + + /* Discard the packet. 
*/ + sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + return SCTP_DISPOSITION_CONSUME; + case SCTP_CID_ACTION_SKIP: + /* Skip the chunk. */ + return SCTP_DISPOSITION_DISCARD; + case SCTP_CID_ACTION_SKIP_ERR: + /* Generate an ERROR chunk as response. */ + hdr = unk_chunk->chunk_hdr; + err_chunk = sctp_make_op_error(asoc, unk_chunk, + SCTP_ERROR_UNKNOWN_CHUNK, hdr, + SCTP_PAD4(ntohs(hdr->length)), + 0); + if (err_chunk) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(err_chunk)); + } + /* Skip the chunk. */ + return SCTP_DISPOSITION_CONSUME; + default: + break; + } + + return SCTP_DISPOSITION_DISCARD; +} + +/* + * Discard the chunk. + * + * Section: 0.2, 5.2.3, 5.2.5, 5.2.6, 6.0, 8.4.6, 8.5.1c, 9.2 + * [Too numerous to mention...] + * Verification Tag: No verification needed. + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_discard_chunk(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (asoc && !sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the chunk has a valid length. + * Since we don't know the chunk type, we use a general + * chunkhdr structure to make a comparison. + */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + pr_debug("%s: chunk:%d is discarded\n", __func__, type.chunk); + + return SCTP_DISPOSITION_DISCARD; +} + +/* + * Discard the whole packet. + * + * Section: 8.4 2) + * + * 2) If the OOTB packet contains an ABORT chunk, the receiver MUST + * silently discard the OOTB packet and take no further action. + * + * Verification Tag: No verification necessary + * + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_pdiscard(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + SCTP_INC_STATS(net, SCTP_MIB_IN_PKT_DISCARDS); + sctp_add_cmd_sf(commands, SCTP_CMD_DISCARD_PACKET, SCTP_NULL()); + + return SCTP_DISPOSITION_CONSUME; +} + + +/* + * The other end is violating protocol. + * + * Section: Not specified + * Verification Tag: Not specified + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (asoc, reply_msg, msg_up, timers, counters) + * + * We simply tag the chunk as a violation. The state machine will log + * the violation and continue. + */ +enum sctp_disposition sctp_sf_violation(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, commands); + + /* Make sure that the chunk has a valid length. */ + if (!sctp_chunk_length_valid(chunk, sizeof(struct sctp_chunkhdr))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, arg, + commands); + + return SCTP_DISPOSITION_VIOLATION; +} + +/* + * Common function to handle a protocol violation. 
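The sctp_sf_unk_chunk() handler above dispatches on the two highest-order bits of the chunk type, which RFC 4960 Section 3.2 defines as "stop and discard", "stop, discard and report", "skip", and "skip and report". A small sketch of decoding those bits follows; the enum and function names are invented for the example rather than being the kernel's SCTP_CID_ACTION_* constants, although the 0xc0 mask matches the on-the-wire encoding.

/* Illustrative decode of the "unrecognized chunk" action encoded in the
 * two high-order bits of an SCTP chunk type (RFC 4960, Section 3.2).
 */
#include <stdint.h>
#include <stdio.h>

enum unk_action {
	UNK_STOP_DISCARD        = 0x00,  /* 00: stop processing, discard packet   */
	UNK_STOP_DISCARD_REPORT = 0x40,  /* 01: as above, but report the chunk    */
	UNK_SKIP                = 0x80,  /* 10: skip this chunk, keep going       */
	UNK_SKIP_REPORT         = 0xc0,  /* 11: skip, keep going, and report      */
};

static enum unk_action unk_chunk_action(uint8_t chunk_type)
{
	return (enum unk_action)(chunk_type & 0xc0);
}

int main(void)
{
	uint8_t type;

	for (type = 0; ; type += 0x40) {
		switch (unk_chunk_action(type)) {
		case UNK_STOP_DISCARD:
			printf("type 0x%02x: stop and discard\n", type);
			break;
		case UNK_STOP_DISCARD_REPORT:
			printf("type 0x%02x: stop, discard, report in ERROR\n", type);
			break;
		case UNK_SKIP:
			printf("type 0x%02x: skip and continue\n", type);
			break;
		case UNK_SKIP_REPORT:
			printf("type 0x%02x: skip, continue, report in ERROR\n", type);
			break;
		}
		if (type == 0xc0)
			break;
	}
	return 0;
}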
+ */ +static enum sctp_disposition sctp_sf_abort_violation( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + void *arg, + struct sctp_cmd_seq *commands, + const __u8 *payload, + const size_t paylen) +{ + struct sctp_packet *packet = NULL; + struct sctp_chunk *chunk = arg; + struct sctp_chunk *abort = NULL; + + /* SCTP-AUTH, Section 6.3: + * It should be noted that if the receiver wants to tear + * down an association in an authenticated way only, the + * handling of malformed packets should not result in + * tearing down the association. + * + * This means that if we only want to abort associations + * in an authenticated way (i.e AUTH+ABORT), then we + * can't destroy this association just because the packet + * was malformed. + */ + if (sctp_auth_recv_cid(SCTP_CID_ABORT, asoc)) + goto discard; + + /* Make the abort chunk. */ + abort = sctp_make_abort_violation(asoc, chunk, payload, paylen); + if (!abort) + goto nomem; + + if (asoc) { + /* Treat INIT-ACK as a special case during COOKIE-WAIT. */ + if (chunk->chunk_hdr->type == SCTP_CID_INIT_ACK && + !asoc->peer.i.init_tag) { + struct sctp_initack_chunk *initack; + + initack = (struct sctp_initack_chunk *)chunk->chunk_hdr; + if (!sctp_chunk_length_valid(chunk, sizeof(*initack))) + abort->chunk_hdr->flags |= SCTP_CHUNK_FLAG_T; + else { + unsigned int inittag; + + inittag = ntohl(initack->init_hdr.init_tag); + sctp_add_cmd_sf(commands, SCTP_CMD_UPDATE_INITTAG, + SCTP_U32(inittag)); + } + } + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(abort)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + if (asoc->state <= SCTP_STATE_COOKIE_ECHOED) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNREFUSED)); + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(SCTP_ERROR_PROTO_VIOLATION)); + } else { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_PROTO_VIOLATION)); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + } + } else { + packet = sctp_ootb_pkt_new(net, asoc, chunk); + + if (!packet) + goto nomem_pkt; + + if (sctp_test_T_bit(abort)) + packet->vtag = ntohl(chunk->sctp_hdr->vtag); + + abort->skb->sk = ep->base.sk; + + sctp_packet_append_chunk(packet, abort); + + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + } + + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + +discard: + sctp_sf_pdiscard(net, ep, asoc, SCTP_ST_CHUNK(0), arg, commands); + return SCTP_DISPOSITION_ABORT; + +nomem_pkt: + sctp_chunk_free(abort); +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Handle a protocol violation when the chunk length is invalid. + * "Invalid" length is identified as smaller than the minimal length a + * given chunk can be. For example, a SACK chunk has invalid length + * if its length is set to be smaller than the size of struct sctp_sack_chunk. + * + * We inform the other end by sending an ABORT with a Protocol Violation + * error code. + * + * Section: Not specified + * Verification Tag: Nothing to do + * Inputs + * (endpoint, asoc, chunk) + * + * Outputs + * (reply_msg, msg_up, counters) + * + * Generate an ABORT chunk and terminate the association. 
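The abort-violation helper above, and the thin wrappers that follow it, embed a short reason string as the Cause Information of a "Protocol Violation" error cause (cause code 13) carried inside the outgoing ABORT, passing sizeof(err_str) as the information length. As a rough sketch of what that error-cause TLV looks like on the wire (16-bit cause code, 16-bit cause length covering the header plus the information, padded to a multiple of 4), one might build it as below; build_proto_violation_cause() is a hypothetical helper, and the kernel additionally appends chunk-header details via sctp_make_abort_violation().

/* Illustrative construction of an SCTP "Protocol Violation" error cause
 * TLV (cause code 13), similar in spirit to what the ABORT built by the
 * violation handlers carries.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>

#define CAUSE_PROTO_VIOLATION 0x000d

static size_t pad4(size_t len)
{
	return (len + 3) & ~(size_t)3;
}

/* Writes the TLV into buf and returns the padded byte count, 0 if too small. */
static size_t build_proto_violation_cause(uint8_t *buf, size_t buflen,
					  const char *reason, size_t reason_len)
{
	size_t cause_len = 4 + reason_len;           /* header + cause information */
	size_t padded = pad4(cause_len);
	uint16_t code = htons(CAUSE_PROTO_VIOLATION);
	uint16_t len = htons((uint16_t)cause_len);   /* length excludes padding */

	if (buflen < padded)
		return 0;

	memset(buf, 0, padded);                      /* padding bytes must be zero */
	memcpy(buf, &code, 2);
	memcpy(buf + 2, &len, 2);
	memcpy(buf + 4, reason, reason_len);
	return padded;
}

int main(void)
{
	static const char reason[] = "The following chunk had invalid length:";
	uint8_t buf[64];
	size_t i, used;

	used = build_proto_violation_cause(buf, sizeof(buf), reason, sizeof(reason));
	for (i = 0; i < used; i++)
		printf("%02x%s", buf[i], (i % 16 == 15) ? "\n" : " ");
	printf("\n");
	return 0;
}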
+ */ +static enum sctp_disposition sctp_sf_violation_chunklen( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + static const char err_str[] = "The following chunk had invalid length:"; + + return sctp_sf_abort_violation(net, ep, asoc, arg, commands, err_str, + sizeof(err_str)); +} + +/* + * Handle a protocol violation when the parameter length is invalid. + * If the length is smaller than the minimum length of a given parameter, + * or accumulated length in multi parameters exceeds the end of the chunk, + * the length is considered as invalid. + */ +static enum sctp_disposition sctp_sf_violation_paramlen( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, void *ext, + struct sctp_cmd_seq *commands) +{ + struct sctp_paramhdr *param = ext; + struct sctp_chunk *abort = NULL; + struct sctp_chunk *chunk = arg; + + if (sctp_auth_recv_cid(SCTP_CID_ABORT, asoc)) + goto discard; + + /* Make the abort chunk. */ + abort = sctp_make_violation_paramlen(asoc, chunk, param); + if (!abort) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(abort)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_PROTO_VIOLATION)); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + +discard: + sctp_sf_pdiscard(net, ep, asoc, SCTP_ST_CHUNK(0), arg, commands); + return SCTP_DISPOSITION_ABORT; +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* Handle a protocol violation when the peer trying to advance the + * cumulative tsn ack to a point beyond the max tsn currently sent. + * + * We inform the other end by sending an ABORT with a Protocol Violation + * error code. + */ +static enum sctp_disposition sctp_sf_violation_ctsn( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + static const char err_str[] = "The cumulative tsn ack beyond the max tsn currently sent:"; + + return sctp_sf_abort_violation(net, ep, asoc, arg, commands, err_str, + sizeof(err_str)); +} + +/* Handle protocol violation of an invalid chunk bundling. For example, + * when we have an association and we receive bundled INIT-ACK, or + * SHUTDOWN-COMPLETE, our peer is clearly violating the "MUST NOT bundle" + * statement from the specs. Additionally, there might be an attacker + * on the path and we may not want to continue this communication. + */ +static enum sctp_disposition sctp_sf_violation_chunk( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + static const char err_str[] = "The following chunk violates protocol:"; + + return sctp_sf_abort_violation(net, ep, asoc, arg, commands, err_str, + sizeof(err_str)); +} +/*************************************************************************** + * These are the state functions for handling primitive (Section 10) events. 
+ ***************************************************************************/ +/* + * sctp_sf_do_prm_asoc + * + * Section: 10.1 ULP-to-SCTP + * B) Associate + * + * Format: ASSOCIATE(local SCTP instance name, destination transport addr, + * outbound stream count) + * -> association id [,destination transport addr list] [,outbound stream + * count] + * + * This primitive allows the upper layer to initiate an association to a + * specific peer endpoint. + * + * The peer endpoint shall be specified by one of the transport addresses + * which defines the endpoint (see Section 1.4). If the local SCTP + * instance has not been initialized, the ASSOCIATE is considered an + * error. + * [This is not relevant for the kernel implementation since we do all + * initialization at boot time. It we hadn't initialized we wouldn't + * get anywhere near this code.] + * + * An association id, which is a local handle to the SCTP association, + * will be returned on successful establishment of the association. If + * SCTP is not able to open an SCTP association with the peer endpoint, + * an error is returned. + * [In the kernel implementation, the struct sctp_association needs to + * be created BEFORE causing this primitive to run.] + * + * Other association parameters may be returned, including the + * complete destination transport addresses of the peer as well as the + * outbound stream count of the local endpoint. One of the transport + * address from the returned destination addresses will be selected by + * the local endpoint as default primary path for sending SCTP packets + * to this peer. The returned "destination transport addr list" can + * be used by the ULP to change the default primary path or to force + * sending a packet to a specific transport address. [All of this + * stuff happens when the INIT ACK arrives. This is a NON-BLOCKING + * function.] + * + * Mandatory attributes: + * + * o local SCTP instance name - obtained from the INITIALIZE operation. + * [This is the argument asoc.] + * o destination transport addr - specified as one of the transport + * addresses of the peer endpoint with which the association is to be + * established. + * [This is asoc->peer.active_path.] + * o outbound stream count - the number of outbound streams the ULP + * would like to open towards this peer endpoint. + * [BUG: This is not currently implemented.] + * Optional attributes: + * + * None. + * + * The return value is a disposition. + */ +enum sctp_disposition sctp_sf_do_prm_asoc(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_association *my_asoc; + struct sctp_chunk *repl; + + /* The comment below says that we enter COOKIE-WAIT AFTER + * sending the INIT, but that doesn't actually work in our + * implementation... + */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_COOKIE_WAIT)); + + /* RFC 2960 5.1 Normal Establishment of an Association + * + * A) "A" first sends an INIT chunk to "Z". In the INIT, "A" + * must provide its Verification Tag (Tag_A) in the Initiate + * Tag field. Tag_A SHOULD be a random number in the range of + * 1 to 4294967295 (see 5.3.1 for Tag value selection). ... + */ + + repl = sctp_make_init(asoc, &asoc->base.bind_addr, GFP_ATOMIC, 0); + if (!repl) + goto nomem; + + /* Choose transport for INIT. 
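A pattern worth keeping in mind for the primitive handlers in this part of the file: a state function such as sctp_sf_do_prm_asoc() never modifies the association or sends anything itself, it only appends SCTP_CMD_* entries to the supplied sctp_cmd_seq, and a separate side-effect interpreter later executes that queue in order. A stripped-down sketch of that "collect commands, then execute" shape is below; every name in it is invented for illustration, none of it is kernel API.

/* Minimal sketch of the "state function appends commands, interpreter
 * executes them" pattern used throughout this file (names are invented).
 */
#include <stdio.h>

enum cmd_verb { CMD_NEW_STATE, CMD_START_TIMER, CMD_REPLY };

struct cmd {
	enum cmd_verb verb;
	int arg;
};

struct cmd_seq {
	struct cmd cmds[8];
	int count;
};

static void add_cmd(struct cmd_seq *seq, enum cmd_verb verb, int arg)
{
	if (seq->count < 8) {
		seq->cmds[seq->count].verb = verb;
		seq->cmds[seq->count].arg = arg;
		seq->count++;
	}
}

/* A "state function": decides what should happen, but performs nothing. */
static void prm_associate(struct cmd_seq *seq)
{
	add_cmd(seq, CMD_NEW_STATE, 1);     /* e.g. enter COOKIE-WAIT       */
	add_cmd(seq, CMD_REPLY, 42);        /* e.g. queue an INIT chunk     */
	add_cmd(seq, CMD_START_TIMER, 3);   /* e.g. start the T1-init timer */
}

/* The "side-effect interpreter": applies the queued commands in order. */
static void run_cmds(const struct cmd_seq *seq)
{
	int i;

	for (i = 0; i < seq->count; i++) {
		switch (seq->cmds[i].verb) {
		case CMD_NEW_STATE:
			printf("set state %d\n", seq->cmds[i].arg);
			break;
		case CMD_START_TIMER:
			printf("start timer %d\n", seq->cmds[i].arg);
			break;
		case CMD_REPLY:
			printf("send chunk %d\n", seq->cmds[i].arg);
			break;
		}
	}
}

int main(void)
{
	struct cmd_seq seq = { .count = 0 };

	prm_associate(&seq);
	run_cmds(&seq);
	return 0;
}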
*/ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_CHOOSE_TRANSPORT, + SCTP_CHUNK(repl)); + + /* Cast away the const modifier, as we want to just + * rerun it through as a sideffect. + */ + my_asoc = (struct sctp_association *)asoc; + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_ASOC, SCTP_ASOC(my_asoc)); + + /* After sending the INIT, "A" starts the T1-init timer and + * enters the COOKIE-WAIT state. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Process the SEND primitive. + * + * Section: 10.1 ULP-to-SCTP + * E) Send + * + * Format: SEND(association id, buffer address, byte count [,context] + * [,stream id] [,life time] [,destination transport address] + * [,unorder flag] [,no-bundle flag] [,payload protocol-id] ) + * -> result + * + * This is the main method to send user data via SCTP. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * o buffer address - the location where the user message to be + * transmitted is stored; + * + * o byte count - The size of the user data in number of bytes; + * + * Optional attributes: + * + * o context - an optional 32 bit integer that will be carried in the + * sending failure notification to the ULP if the transportation of + * this User Message fails. + * + * o stream id - to indicate which stream to send the data on. If not + * specified, stream 0 will be used. + * + * o life time - specifies the life time of the user data. The user data + * will not be sent by SCTP after the life time expires. This + * parameter can be used to avoid efforts to transmit stale + * user messages. SCTP notifies the ULP if the data cannot be + * initiated to transport (i.e. sent to the destination via SCTP's + * send primitive) within the life time variable. However, the + * user data will be transmitted if SCTP has attempted to transmit a + * chunk before the life time expired. + * + * o destination transport address - specified as one of the destination + * transport addresses of the peer endpoint to which this packet + * should be sent. Whenever possible, SCTP should use this destination + * transport address for sending the packets, instead of the current + * primary path. + * + * o unorder flag - this flag, if present, indicates that the user + * would like the data delivered in an unordered fashion to the peer + * (i.e., the U flag is set to 1 on all DATA chunks carrying this + * message). + * + * o no-bundle flag - instructs SCTP not to bundle this user data with + * other outbound DATA chunks. SCTP MAY still bundle even when + * this flag is present, when faced with network congestion. + * + * o payload protocol-id - A 32 bit unsigned integer that is to be + * passed to the peer indicating the type of payload protocol data + * being transmitted. This value is passed as opaque data by SCTP. + * + * The return value is the disposition. + */ +enum sctp_disposition sctp_sf_do_prm_send(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_datamsg *msg = arg; + + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_MSG, SCTP_DATAMSG(msg)); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Process the SHUTDOWN primitive. 
+ * + * Section: 10.1: + * C) Shutdown + * + * Format: SHUTDOWN(association id) + * -> result + * + * Gracefully closes an association. Any locally queued user data + * will be delivered to the peer. The association will be terminated only + * after the peer acknowledges all the SCTP packets sent. A success code + * will be returned on successful termination of the association. If + * attempting to terminate the association results in a failure, an error + * code shall be returned. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * Optional attributes: + * + * None. + * + * The return value is the disposition. + */ +enum sctp_disposition sctp_sf_do_9_2_prm_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + enum sctp_disposition disposition; + + /* From 9.2 Shutdown of an Association + * Upon receipt of the SHUTDOWN primitive from its upper + * layer, the endpoint enters SHUTDOWN-PENDING state and + * remains there until all outstanding data has been + * acknowledged by its peer. The endpoint accepts no new data + * from its upper layer, but retransmits data to the far end + * if necessary to fill gaps. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_SHUTDOWN_PENDING)); + + disposition = SCTP_DISPOSITION_CONSUME; + if (sctp_outq_is_empty(&asoc->outqueue)) { + disposition = sctp_sf_do_9_2_start_shutdown(net, ep, asoc, type, + arg, commands); + } + + return disposition; +} + +/* + * Process the ABORT primitive. + * + * Section: 10.1: + * C) Abort + * + * Format: Abort(association id [, cause code]) + * -> result + * + * Ungracefully closes an association. Any locally queued user data + * will be discarded and an ABORT chunk is sent to the peer. A success code + * will be returned on successful abortion of the association. If + * attempting to abort the association results in a failure, an error + * code shall be returned. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * Optional attributes: + * + * o cause code - reason of the abort to be passed to the peer + * + * None. + * + * The return value is the disposition. + */ +enum sctp_disposition sctp_sf_do_9_1_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* From 9.1 Abort of an Association + * Upon receipt of the ABORT primitive from its upper + * layer, the endpoint enters CLOSED state and + * discard all outstanding data has been + * acknowledged by its peer. The endpoint accepts no new data + * from its upper layer, but retransmits data to the far end + * if necessary to fill gaps. + */ + struct sctp_chunk *abort = arg; + + if (abort) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(abort)); + + /* Even if we can't send the ABORT due to low memory delete the + * TCB. This is a departure from our typical NOMEM handling. + */ + + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + /* Delete the established association. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_USER_ABORT)); + + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + + return SCTP_DISPOSITION_ABORT; +} + +/* We tried an illegal operation on an association which is closed. 
*/ +enum sctp_disposition sctp_sf_error_closed(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_ERROR, SCTP_ERROR(-EINVAL)); + return SCTP_DISPOSITION_CONSUME; +} + +/* We tried an illegal operation on an association which is shutting + * down. + */ +enum sctp_disposition sctp_sf_error_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_ERROR, + SCTP_ERROR(-ESHUTDOWN)); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * sctp_cookie_wait_prm_shutdown + * + * Section: 4 Note: 2 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * The RFC does not explicitly address this issue, but is the route through the + * state table when someone issues a shutdown while in COOKIE_WAIT state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_cookie_wait_prm_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + + SCTP_INC_STATS(net, SCTP_MIB_SHUTDOWNS); + + sctp_add_cmd_sf(commands, SCTP_CMD_DELETE_TCB, SCTP_NULL()); + + return SCTP_DISPOSITION_DELETE_TCB; +} + +/* + * sctp_cookie_echoed_prm_shutdown + * + * Section: 4 Note: 2 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * The RFC does not explicitly address this issue, but is the route through the + * state table when someone issues a shutdown while in COOKIE_ECHOED state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_cookie_echoed_prm_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* There is a single T1 timer, so we should be able to use + * common function with the COOKIE-WAIT state. + */ + return sctp_sf_cookie_wait_prm_shutdown(net, ep, asoc, type, arg, commands); +} + +/* + * sctp_sf_cookie_wait_prm_abort + * + * Section: 4 Note: 2 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * The RFC does not explicitly address this issue, but is the route through the + * state table when someone issues an abort while in COOKIE_WAIT state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_cookie_wait_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *abort = arg; + + /* Stop T1-init timer */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + + if (abort) + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(abort)); + + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_CLOSED)); + + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + + /* Even if we can't send the ABORT due to low memory delete the + * TCB. This is a departure from our typical NOMEM handling. + */ + + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNREFUSED)); + /* Delete the established association. 
*/ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(SCTP_ERROR_USER_ABORT)); + + return SCTP_DISPOSITION_ABORT; +} + +/* + * sctp_sf_cookie_echoed_prm_abort + * + * Section: 4 Note: 3 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * The RFC does not explcitly address this issue, but is the route through the + * state table when someone issues an abort while in COOKIE_ECHOED state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_cookie_echoed_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* There is a single T1 timer, so we should be able to use + * common function with the COOKIE-WAIT state. + */ + return sctp_sf_cookie_wait_prm_abort(net, ep, asoc, type, arg, commands); +} + +/* + * sctp_sf_shutdown_pending_prm_abort + * + * Inputs + * (endpoint, asoc) + * + * The RFC does not explicitly address this issue, but is the route through the + * state table when someone issues an abort while in SHUTDOWN-PENDING state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_shutdown_pending_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* Stop the T5-shutdown guard timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + return sctp_sf_do_9_1_prm_abort(net, ep, asoc, type, arg, commands); +} + +/* + * sctp_sf_shutdown_sent_prm_abort + * + * Inputs + * (endpoint, asoc) + * + * The RFC does not explicitly address this issue, but is the route through the + * state table when someone issues an abort while in SHUTDOWN-SENT state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_shutdown_sent_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* Stop the T2-shutdown timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + /* Stop the T5-shutdown guard timer. */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + return sctp_sf_do_9_1_prm_abort(net, ep, asoc, type, arg, commands); +} + +/* + * sctp_sf_cookie_echoed_prm_abort + * + * Inputs + * (endpoint, asoc) + * + * The RFC does not explcitly address this issue, but is the route through the + * state table when someone issues an abort while in COOKIE_ECHOED state. + * + * Outputs + * (timers) + */ +enum sctp_disposition sctp_sf_shutdown_ack_sent_prm_abort( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + /* The same T2 timer, so we should be able to use + * common function with the SHUTDOWN-SENT state. + */ + return sctp_sf_shutdown_sent_prm_abort(net, ep, asoc, type, arg, commands); +} + +/* + * Process the REQUESTHEARTBEAT primitive + * + * 10.1 ULP-to-SCTP + * J) Request Heartbeat + * + * Format: REQUESTHEARTBEAT(association id, destination transport address) + * + * -> result + * + * Instructs the local endpoint to perform a HeartBeat on the specified + * destination transport address of the given association. 
The returned + * result should indicate whether the transmission of the HEARTBEAT + * chunk to the destination address is successful. + * + * Mandatory attributes: + * + * o association id - local handle to the SCTP association + * + * o destination transport address - the transport address of the + * association on which a heartbeat should be issued. + */ +enum sctp_disposition sctp_sf_do_prm_requestheartbeat( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + if (SCTP_DISPOSITION_NOMEM == sctp_sf_heartbeat(ep, asoc, type, + (struct sctp_transport *)arg, commands)) + return SCTP_DISPOSITION_NOMEM; + + /* + * RFC 2960 (bis), section 8.3 + * + * D) Request an on-demand HEARTBEAT on a specific destination + * transport address of a given association. + * + * The endpoint should increment the respective error counter of + * the destination transport address each time a HEARTBEAT is sent + * to that address and not acknowledged within one RTO. + * + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TRANSPORT_HB_SENT, + SCTP_TRANSPORT(arg)); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * ADDIP Section 4.1 ASCONF Chunk Procedures + * When an endpoint has an ASCONF signaled change to be sent to the + * remote endpoint it should do A1 to A9 + */ +enum sctp_disposition sctp_sf_do_prm_asconf(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T4, SCTP_CHUNK(chunk)); + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(chunk)); + return SCTP_DISPOSITION_CONSUME; +} + +/* RE-CONFIG Section 5.1 RECONF Chunk Procedures */ +enum sctp_disposition sctp_sf_do_prm_reconf(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(chunk)); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Ignore the primitive event + * + * The return value is the disposition of the primitive. + */ +enum sctp_disposition sctp_sf_ignore_primitive( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + pr_debug("%s: primitive type:%d is ignored\n", __func__, + type.primitive); + + return SCTP_DISPOSITION_DISCARD; +} + +/*************************************************************************** + * These are the state functions for the OTHER events. + ***************************************************************************/ + +/* + * When the SCTP stack has no more user data to send or retransmit, this + * notification is given to the user. Also, at the time when a user app + * subscribes to this event, if there is no data to be sent or + * retransmit, the stack will immediately send up this notification. 
+ */ +enum sctp_disposition sctp_sf_do_no_pending_tsn( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_ulpevent *event; + + event = sctp_ulpevent_make_sender_dry_event(asoc, GFP_ATOMIC); + if (!event) + return SCTP_DISPOSITION_NOMEM; + + sctp_add_cmd_sf(commands, SCTP_CMD_EVENT_ULP, SCTP_ULPEVENT(event)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Start the shutdown negotiation. + * + * From Section 9.2: + * Once all its outstanding data has been acknowledged, the endpoint + * shall send a SHUTDOWN chunk to its peer including in the Cumulative + * TSN Ack field the last sequential TSN it has received from the peer. + * It shall then start the T2-shutdown timer and enter the SHUTDOWN-SENT + * state. If the timer expires, the endpoint must re-send the SHUTDOWN + * with the updated last sequential TSN received from its peer. + * + * The return value is the disposition. + */ +enum sctp_disposition sctp_sf_do_9_2_start_shutdown( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *reply; + + /* Once all its outstanding data has been acknowledged, the + * endpoint shall send a SHUTDOWN chunk to its peer including + * in the Cumulative TSN Ack field the last sequential TSN it + * has received from the peer. + */ + reply = sctp_make_shutdown(asoc, arg); + if (!reply) + goto nomem; + + /* Set the transport for the SHUTDOWN chunk and the timeout for the + * T2-shutdown timer. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T2, SCTP_CHUNK(reply)); + + /* It shall then start the T2-shutdown timer */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + /* RFC 4960 Section 9.2 + * The sender of the SHUTDOWN MAY also start an overall guard timer + * 'T5-shutdown-guard' to bound the overall time for shutdown sequence. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + + if (asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + + /* and enter the SHUTDOWN-SENT state. */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_SHUTDOWN_SENT)); + + /* sctp-implguide 2.10 Issues with Heartbeating and failover + * + * HEARTBEAT ... is discontinued after sending either SHUTDOWN + * or SHUTDOWN-ACK. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_STOP, SCTP_NULL()); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Generate a SHUTDOWN ACK now that everything is SACK'd. + * + * From Section 9.2: + * + * If it has no more outstanding DATA chunks, the SHUTDOWN receiver + * shall send a SHUTDOWN ACK and start a T2-shutdown timer of its own, + * entering the SHUTDOWN-ACK-SENT state. If the timer expires, the + * endpoint must re-send the SHUTDOWN ACK. + * + * The return value is the disposition. 
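For orientation, the graceful close driven by sctp_sf_do_9_2_start_shutdown() above and sctp_sf_do_9_2_shutdown_ack() below is a three-message exchange: the side with no outstanding data sends SHUTDOWN and enters SHUTDOWN-SENT, the peer answers with SHUTDOWN ACK, and the original sender replies with SHUTDOWN COMPLETE. The deliberately simplified model below ignores the SHUTDOWN-PENDING and SHUTDOWN-RECEIVED intermediate states, all timers and retransmission; it is a toy state table, not the kernel's state machine.

/* Toy model of the SCTP graceful shutdown exchange.  Guard timers,
 * retransmission and error paths are intentionally left out.
 */
#include <stdio.h>

enum state { ESTABLISHED, SHUTDOWN_SENT, SHUTDOWN_ACK_SENT, CLOSED };
enum msg   { NONE, SHUTDOWN, SHUTDOWN_ACK, SHUTDOWN_COMPLETE };

/* Returns the chunk to send in response (or NONE) and updates *state. */
static enum msg on_event(enum state *state, enum msg received)
{
	switch (*state) {
	case ESTABLISHED:
		if (received == NONE) {            /* local "all data acked" event */
			*state = SHUTDOWN_SENT;
			return SHUTDOWN;
		}
		if (received == SHUTDOWN) {        /* peer shut down first */
			*state = SHUTDOWN_ACK_SENT;
			return SHUTDOWN_ACK;
		}
		break;
	case SHUTDOWN_SENT:
		if (received == SHUTDOWN_ACK) {
			*state = CLOSED;
			return SHUTDOWN_COMPLETE;
		}
		break;
	case SHUTDOWN_ACK_SENT:
		if (received == SHUTDOWN_COMPLETE)
			*state = CLOSED;
		break;
	case CLOSED:
		break;
	}
	return NONE;
}

int main(void)
{
	enum state a = ESTABLISHED, z = ESTABLISHED;
	enum msg wire;

	wire = on_event(&a, NONE);           /* A: all data acked, send SHUTDOWN */
	wire = on_event(&z, wire);           /* Z: reply with SHUTDOWN ACK       */
	wire = on_event(&a, wire);           /* A: reply with SHUTDOWN COMPLETE  */
	(void)on_event(&z, wire);            /* Z: closes                        */

	printf("A state=%d  Z state=%d\n", a, z);   /* both should be CLOSED (3) */
	return 0;
}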
+ */ +enum sctp_disposition sctp_sf_do_9_2_shutdown_ack( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = arg; + struct sctp_chunk *reply; + + /* There are 2 ways of getting here: + * 1) called in response to a SHUTDOWN chunk + * 2) called when SCTP_EVENT_NO_PENDING_TSN event is issued. + * + * For the case (2), the arg parameter is set to NULL. We need + * to check that we have a chunk before accessing it's fields. + */ + if (chunk) { + if (!sctp_vtag_verify(chunk, asoc)) + return sctp_sf_pdiscard(net, ep, asoc, type, arg, + commands); + + /* Make sure that the SHUTDOWN chunk has a valid length. */ + if (!sctp_chunk_length_valid( + chunk, sizeof(struct sctp_shutdown_chunk))) + return sctp_sf_violation_chunklen(net, ep, asoc, type, + arg, commands); + } + + /* If it has no more outstanding DATA chunks, the SHUTDOWN receiver + * shall send a SHUTDOWN ACK ... + */ + reply = sctp_make_shutdown_ack(asoc, chunk); + if (!reply) + goto nomem; + + /* Set the transport for the SHUTDOWN ACK chunk and the timeout for + * the T2-shutdown timer. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T2, SCTP_CHUNK(reply)); + + /* and start/restart a T2-shutdown timer of its own, */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + + if (asoc->timeouts[SCTP_EVENT_TIMEOUT_AUTOCLOSE]) + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_AUTOCLOSE)); + + /* Enter the SHUTDOWN-ACK-SENT state. */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_SHUTDOWN_ACK_SENT)); + + /* sctp-implguide 2.10 Issues with Heartbeating and failover + * + * HEARTBEAT ... is discontinued after sending either SHUTDOWN + * or SHUTDOWN-ACK. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_HB_TIMERS_STOP, SCTP_NULL()); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * Ignore the event defined as other + * + * The return value is the disposition of the event. + */ +enum sctp_disposition sctp_sf_ignore_other(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + pr_debug("%s: the event other type:%d is ignored\n", + __func__, type.other); + + return SCTP_DISPOSITION_DISCARD; +} + +/************************************************************ + * These are the state functions for handling timeout events. + ************************************************************/ + +/* + * RTX Timeout + * + * Section: 6.3.3 Handle T3-rtx Expiration + * + * Whenever the retransmission timer T3-rtx expires for a destination + * address, do the following: + * [See below] + * + * The return value is the disposition of the chunk. 
+ */ +enum sctp_disposition sctp_sf_do_6_3_3_rtx(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_transport *transport = arg; + + SCTP_INC_STATS(net, SCTP_MIB_T3_RTX_EXPIREDS); + + if (asoc->overall_error_count >= asoc->max_retrans) { + if (asoc->peer.zero_window_announced && + asoc->state == SCTP_STATE_SHUTDOWN_PENDING) { + /* + * We are here likely because the receiver had its rwnd + * closed for a while and we have not been able to + * transmit the locally queued data within the maximum + * retransmission attempts limit. Start the T5 + * shutdown guard timer to give the receiver one last + * chance and some additional time to recover before + * aborting. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_START_ONCE, + SCTP_TO(SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD)); + } else { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + /* CMD_ASSOC_FAILED calls CMD_DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_DELETE_TCB; + } + } + + /* E1) For the destination address for which the timer + * expires, adjust its ssthresh with rules defined in Section + * 7.2.3 and set the cwnd <- MTU. + */ + + /* E2) For the destination address for which the timer + * expires, set RTO <- RTO * 2 ("back off the timer"). The + * maximum value discussed in rule C7 above (RTO.max) may be + * used to provide an upper bound to this doubling operation. + */ + + /* E3) Determine how many of the earliest (i.e., lowest TSN) + * outstanding DATA chunks for the address for which the + * T3-rtx has expired will fit into a single packet, subject + * to the MTU constraint for the path corresponding to the + * destination transport address to which the retransmission + * is being sent (this may be different from the address for + * which the timer expires [see Section 6.4]). Call this + * value K. Bundle and retransmit those K DATA chunks in a + * single packet to the destination endpoint. + * + * Note: Any DATA chunks that were sent to the address for + * which the T3-rtx timer expired but did not fit in one MTU + * (rule E3 above), should be marked for retransmission and + * sent as soon as cwnd allows (normally when a SACK arrives). + */ + + /* Do some failure management (Section 8.2). */ + sctp_add_cmd_sf(commands, SCTP_CMD_STRIKE, SCTP_TRANSPORT(transport)); + + /* NB: Rules E4 and F1 are implicit in R1. */ + sctp_add_cmd_sf(commands, SCTP_CMD_RETRAN, SCTP_TRANSPORT(transport)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * Generate delayed SACK on timeout + * + * Section: 6.2 Acknowledgement on Reception of DATA Chunks + * + * The guidelines on delayed acknowledgement algorithm specified in + * Section 4.2 of [RFC2581] SHOULD be followed. Specifically, an + * acknowledgement SHOULD be generated for at least every second packet + * (not every second DATA chunk) received, and SHOULD be generated + * within 200 ms of the arrival of any unacknowledged DATA chunk. In + * some situations it may be beneficial for an SCTP transmitter to be + * more conservative than the algorithms detailed in this document + * allow. However, an SCTP transmitter MUST NOT be more aggressive than + * the following algorithms allow. 
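
The delayed-SACK behaviour described above is tunable per socket. The sketch below uses the SCTP_DELAYED_SACK socket option from the lksctp sockets API to shorten the delay and acknowledge every packet; it is illustrative only and assumes an existing SCTP socket sd (on a one-to-one style socket the association id field is ignored).

    #include <stdio.h>
    #include <string.h>
    #include <sys/socket.h>
    #include <netinet/in.h>
    #include <netinet/sctp.h>

    /* Reduce the delayed-SACK timer to 100 ms and SACK every packet. */
    static int tune_delayed_sack(int sd)
    {
            struct sctp_sack_info si;

            memset(&si, 0, sizeof(si));
            si.sack_assoc_id = 0;
            si.sack_delay = 100;    /* milliseconds */
            si.sack_freq = 1;       /* acknowledge every packet */

            if (setsockopt(sd, IPPROTO_SCTP, SCTP_DELAYED_SACK,
                           &si, sizeof(si)) < 0) {
                    perror("setsockopt(SCTP_DELAYED_SACK)");
                    return -1;
            }
            return 0;
    }
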
+ */ +enum sctp_disposition sctp_sf_do_6_2_sack(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + SCTP_INC_STATS(net, SCTP_MIB_DELAY_SACK_EXPIREDS); + sctp_add_cmd_sf(commands, SCTP_CMD_GEN_SACK, SCTP_FORCE()); + return SCTP_DISPOSITION_CONSUME; +} + +/* + * sctp_sf_t1_init_timer_expire + * + * Section: 4 Note: 2 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * RFC 2960 Section 4 Notes + * 2) If the T1-init timer expires, the endpoint MUST retransmit INIT + * and re-start the T1-init timer without changing state. This MUST + * be repeated up to 'Max.Init.Retransmits' times. After that, the + * endpoint MUST abort the initialization process and report the + * error to SCTP user. + * + * Outputs + * (timers, events) + * + */ +enum sctp_disposition sctp_sf_t1_init_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + int attempts = asoc->init_err_counter + 1; + struct sctp_chunk *repl = NULL; + struct sctp_bind_addr *bp; + + pr_debug("%s: timer T1 expired (INIT)\n", __func__); + + SCTP_INC_STATS(net, SCTP_MIB_T1_INIT_EXPIREDS); + + if (attempts <= asoc->max_init_attempts) { + bp = (struct sctp_bind_addr *) &asoc->base.bind_addr; + repl = sctp_make_init(asoc, bp, GFP_ATOMIC, 0); + if (!repl) + return SCTP_DISPOSITION_NOMEM; + + /* Choose transport for INIT. */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_CHOOSE_TRANSPORT, + SCTP_CHUNK(repl)); + + /* Issue a sideeffect to do the needed accounting. */ + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_INIT)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + } else { + pr_debug("%s: giving up on INIT, attempts:%d " + "max_init_attempts:%d\n", __func__, attempts, + asoc->max_init_attempts); + + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + return SCTP_DISPOSITION_DELETE_TCB; + } + + return SCTP_DISPOSITION_CONSUME; +} + +/* + * sctp_sf_t1_cookie_timer_expire + * + * Section: 4 Note: 2 + * Verification Tag: + * Inputs + * (endpoint, asoc) + * + * RFC 2960 Section 4 Notes + * 3) If the T1-cookie timer expires, the endpoint MUST retransmit + * COOKIE ECHO and re-start the T1-cookie timer without changing + * state. This MUST be repeated up to 'Max.Init.Retransmits' times. + * After that, the endpoint MUST abort the initialization process and + * report the error to SCTP user. + * + * Outputs + * (timers, events) + * + */ +enum sctp_disposition sctp_sf_t1_cookie_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + int attempts = asoc->init_err_counter + 1; + struct sctp_chunk *repl = NULL; + + pr_debug("%s: timer T1 expired (COOKIE-ECHO)\n", __func__); + + SCTP_INC_STATS(net, SCTP_MIB_T1_COOKIE_EXPIREDS); + + if (attempts <= asoc->max_init_attempts) { + repl = sctp_make_cookie_echo(asoc, NULL); + if (!repl) + return SCTP_DISPOSITION_NOMEM; + + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_CHOOSE_TRANSPORT, + SCTP_CHUNK(repl)); + /* Issue a sideeffect to do the needed accounting. 
*/ + sctp_add_cmd_sf(commands, SCTP_CMD_COOKIEECHO_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T1_COOKIE)); + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(repl)); + } else { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + sctp_add_cmd_sf(commands, SCTP_CMD_INIT_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + return SCTP_DISPOSITION_DELETE_TCB; + } + + return SCTP_DISPOSITION_CONSUME; +} + +/* RFC2960 9.2 If the timer expires, the endpoint must re-send the SHUTDOWN + * with the updated last sequential TSN received from its peer. + * + * An endpoint should limit the number of retransmission of the + * SHUTDOWN chunk to the protocol parameter 'Association.Max.Retrans'. + * If this threshold is exceeded the endpoint should destroy the TCB and + * MUST report the peer endpoint unreachable to the upper layer (and + * thus the association enters the CLOSED state). The reception of any + * packet from its peer (i.e. as the peer sends all of its queued DATA + * chunks) should clear the endpoint's retransmission count and restart + * the T2-Shutdown timer, giving its peer ample opportunity to transmit + * all of its queued DATA chunks that have not yet been sent. + */ +enum sctp_disposition sctp_sf_t2_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *reply = NULL; + + pr_debug("%s: timer T2 expired\n", __func__); + + SCTP_INC_STATS(net, SCTP_MIB_T2_SHUTDOWN_EXPIREDS); + + ((struct sctp_association *)asoc)->shutdown_retries++; + + if (asoc->overall_error_count >= asoc->max_retrans) { + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + /* Note: CMD_ASSOC_FAILED calls CMD_DELETE_TCB. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_DELETE_TCB; + } + + switch (asoc->state) { + case SCTP_STATE_SHUTDOWN_SENT: + reply = sctp_make_shutdown(asoc, NULL); + break; + + case SCTP_STATE_SHUTDOWN_ACK_SENT: + reply = sctp_make_shutdown_ack(asoc, NULL); + break; + + default: + BUG(); + break; + } + + if (!reply) + goto nomem; + + /* Do some failure management (Section 8.2). + * If we remove the transport an SHUTDOWN was last sent to, don't + * do failure management. + */ + if (asoc->shutdown_last_sent_to) + sctp_add_cmd_sf(commands, SCTP_CMD_STRIKE, + SCTP_TRANSPORT(asoc->shutdown_last_sent_to)); + + /* Set the transport for the SHUTDOWN/ACK chunk and the timeout for + * the T2-shutdown timer. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T2, SCTP_CHUNK(reply)); + + /* Restart the T2-shutdown timer. 
*/ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T2_SHUTDOWN)); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + return SCTP_DISPOSITION_CONSUME; + +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* + * ADDIP Section 4.1 ASCONF Chunk Procedures + * If the T4 RTO timer expires the endpoint should do B1 to B5 + */ +enum sctp_disposition sctp_sf_t4_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *chunk = asoc->addip_last_asconf; + struct sctp_transport *transport = chunk->transport; + + SCTP_INC_STATS(net, SCTP_MIB_T4_RTO_EXPIREDS); + + /* ADDIP 4.1 B1) Increment the error counters and perform path failure + * detection on the appropriate destination address as defined in + * RFC2960 [5] section 8.1 and 8.2. + */ + if (transport) + sctp_add_cmd_sf(commands, SCTP_CMD_STRIKE, + SCTP_TRANSPORT(transport)); + + /* Reconfig T4 timer and transport. */ + sctp_add_cmd_sf(commands, SCTP_CMD_SETUP_T4, SCTP_CHUNK(chunk)); + + /* ADDIP 4.1 B2) Increment the association error counters and perform + * endpoint failure detection on the association as defined in + * RFC2960 [5] section 8.1 and 8.2. + * association error counter is incremented in SCTP_CMD_STRIKE. + */ + if (asoc->overall_error_count >= asoc->max_retrans) { + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_STOP, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_DISPOSITION_ABORT; + } + + /* ADDIP 4.1 B3) Back-off the destination address RTO value to which + * the ASCONF chunk was sent by doubling the RTO timer value. + * This is done in SCTP_CMD_STRIKE. + */ + + /* ADDIP 4.1 B4) Re-transmit the ASCONF Chunk last sent and if possible + * choose an alternate destination address (please refer to RFC2960 + * [5] section 6.4.1). An endpoint MUST NOT add new parameters to this + * chunk, it MUST be the same (including its serial number) as the last + * ASCONF sent. + */ + sctp_chunk_hold(asoc->addip_last_asconf); + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(asoc->addip_last_asconf)); + + /* ADDIP 4.1 B5) Restart the T-4 RTO timer. Note that if a different + * destination is selected, then the RTO used will be that of the new + * destination address. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_TIMER_RESTART, + SCTP_TO(SCTP_EVENT_TIMEOUT_T4_RTO)); + + return SCTP_DISPOSITION_CONSUME; +} + +/* sctpimpguide-05 Section 2.12.2 + * The sender of the SHUTDOWN MAY also start an overall guard timer + * 'T5-shutdown-guard' to bound the overall time for shutdown sequence. + * At the expiration of this timer the sender SHOULD abort the association + * by sending an ABORT chunk. 
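
When the guard timer handled just below expires, the association is aborted; user space observes this as an SCTP_ASSOC_CHANGE notification rather than as an error return. A hedged sketch of decoding such a notification follows (illustrative only; it assumes buf was filled by a recvmsg() call that returned with MSG_NOTIFICATION set).

    #include <stdio.h>
    #include <sys/socket.h>
    #include <netinet/in.h>
    #include <netinet/sctp.h>

    /* Classify one notification read from an SCTP socket. */
    static void handle_notification(const void *buf, size_t len)
    {
            const union sctp_notification *notif = buf;

            if (len < sizeof(notif->sn_header))
                    return;

            if (notif->sn_header.sn_type != SCTP_ASSOC_CHANGE)
                    return;

            switch (notif->sn_assoc_change.sac_state) {
            case SCTP_COMM_LOST:
                    fprintf(stderr, "association aborted, error %u\n",
                            notif->sn_assoc_change.sac_error);
                    break;
            case SCTP_SHUTDOWN_COMP:
                    fprintf(stderr, "graceful shutdown completed\n");
                    break;
            }
    }
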
+ */ +enum sctp_disposition sctp_sf_t5_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + struct sctp_chunk *reply = NULL; + + pr_debug("%s: timer T5 expired\n", __func__); + + SCTP_INC_STATS(net, SCTP_MIB_T5_SHUTDOWN_GUARD_EXPIREDS); + + reply = sctp_make_abort(asoc, NULL, 0); + if (!reply) + goto nomem; + + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, SCTP_CHUNK(reply)); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ETIMEDOUT)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_ERROR)); + + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + + return SCTP_DISPOSITION_DELETE_TCB; +nomem: + return SCTP_DISPOSITION_NOMEM; +} + +/* Handle expiration of AUTOCLOSE timer. When the autoclose timer expires, + * the association is automatically closed by starting the shutdown process. + * The work that needs to be done is same as when SHUTDOWN is initiated by + * the user. So this routine looks same as sctp_sf_do_9_2_prm_shutdown(). + */ +enum sctp_disposition sctp_sf_autoclose_timer_expire( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + enum sctp_disposition disposition; + + SCTP_INC_STATS(net, SCTP_MIB_AUTOCLOSE_EXPIREDS); + + /* From 9.2 Shutdown of an Association + * Upon receipt of the SHUTDOWN primitive from its upper + * layer, the endpoint enters SHUTDOWN-PENDING state and + * remains there until all outstanding data has been + * acknowledged by its peer. The endpoint accepts no new data + * from its upper layer, but retransmits data to the far end + * if necessary to fill gaps. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_NEW_STATE, + SCTP_STATE(SCTP_STATE_SHUTDOWN_PENDING)); + + disposition = SCTP_DISPOSITION_CONSUME; + if (sctp_outq_is_empty(&asoc->outqueue)) { + disposition = sctp_sf_do_9_2_start_shutdown(net, ep, asoc, type, + NULL, commands); + } + + return disposition; +} + +/***************************************************************************** + * These are sa state functions which could apply to all types of events. + ****************************************************************************/ + +/* + * This table entry is not implemented. + * + * Inputs + * (endpoint, asoc, chunk) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_not_impl(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + return SCTP_DISPOSITION_NOT_IMPL; +} + +/* + * This table entry represents a bug. + * + * Inputs + * (endpoint, asoc, chunk) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_bug(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, struct sctp_cmd_seq *commands) +{ + return SCTP_DISPOSITION_BUG; +} + +/* + * This table entry represents the firing of a timer in the wrong state. + * Since timer deletion cannot be guaranteed a timer 'may' end up firing + * when the association is in the wrong state. This event should + * be ignored, so as to prevent any rearming of the timer. 
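
The autoclose expiry handler shown above is only armed when the application sets the SCTP_AUTOCLOSE socket option. A minimal sketch of enabling it is given below; this is an editorial illustration, and note that autoclose applies to one-to-many (SOCK_SEQPACKET) style sockets, with a value of 0 disabling the timer.

    #include <stdio.h>
    #include <sys/socket.h>
    #include <netinet/in.h>
    #include <netinet/sctp.h>

    /* Shut down associations that stay idle for 30 seconds. */
    static int enable_autoclose(int sd)
    {
            int secs = 30;

            if (setsockopt(sd, IPPROTO_SCTP, SCTP_AUTOCLOSE,
                           &secs, sizeof(secs)) < 0) {
                    perror("setsockopt(SCTP_AUTOCLOSE)");
                    return -1;
            }
            return 0;
    }
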
+ * + * Inputs + * (endpoint, asoc, chunk) + * + * The return value is the disposition of the chunk. + */ +enum sctp_disposition sctp_sf_timer_ignore(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const union sctp_subtype type, + void *arg, + struct sctp_cmd_seq *commands) +{ + pr_debug("%s: timer %d ignored\n", __func__, type.chunk); + + return SCTP_DISPOSITION_CONSUME; +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* Pull the SACK chunk based on the SACK header. */ +static struct sctp_sackhdr *sctp_sm_pull_sack(struct sctp_chunk *chunk) +{ + struct sctp_sackhdr *sack; + __u16 num_dup_tsns; + unsigned int len; + __u16 num_blocks; + + /* Protect ourselves from reading too far into + * the skb from a bogus sender. + */ + sack = (struct sctp_sackhdr *) chunk->skb->data; + + num_blocks = ntohs(sack->num_gap_ack_blocks); + num_dup_tsns = ntohs(sack->num_dup_tsns); + len = sizeof(struct sctp_sackhdr); + len += (num_blocks + num_dup_tsns) * sizeof(__u32); + if (len > chunk->skb->len) + return NULL; + + skb_pull(chunk->skb, len); + + return sack; +} + +/* Create an ABORT packet to be sent as a response, with the specified + * error causes. + */ +static struct sctp_packet *sctp_abort_pkt_new( + struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + struct sctp_chunk *chunk, + const void *payload, size_t paylen) +{ + struct sctp_packet *packet; + struct sctp_chunk *abort; + + packet = sctp_ootb_pkt_new(net, asoc, chunk); + + if (packet) { + /* Make an ABORT. + * The T bit will be set if the asoc is NULL. + */ + abort = sctp_make_abort(asoc, chunk, paylen); + if (!abort) { + sctp_ootb_pkt_free(packet); + return NULL; + } + + /* Reflect vtag if T-Bit is set */ + if (sctp_test_T_bit(abort)) + packet->vtag = ntohl(chunk->sctp_hdr->vtag); + + /* Add specified error causes, i.e., payload, to the + * end of the chunk. + */ + sctp_addto_chunk(abort, paylen, payload); + + /* Set the skb to the belonging sock for accounting. */ + abort->skb->sk = ep->base.sk; + + sctp_packet_append_chunk(packet, abort); + + } + + return packet; +} + +/* Allocate a packet for responding in the OOTB conditions. */ +static struct sctp_packet *sctp_ootb_pkt_new( + struct net *net, + const struct sctp_association *asoc, + const struct sctp_chunk *chunk) +{ + struct sctp_transport *transport; + struct sctp_packet *packet; + __u16 sport, dport; + __u32 vtag; + + /* Get the source and destination port from the inbound packet. */ + sport = ntohs(chunk->sctp_hdr->dest); + dport = ntohs(chunk->sctp_hdr->source); + + /* The V-tag is going to be the same as the inbound packet if no + * association exists, otherwise, use the peer's vtag. + */ + if (asoc) { + /* Special case the INIT-ACK as there is no peer's vtag + * yet. + */ + switch (chunk->chunk_hdr->type) { + case SCTP_CID_INIT: + case SCTP_CID_INIT_ACK: + { + struct sctp_initack_chunk *initack; + + initack = (struct sctp_initack_chunk *)chunk->chunk_hdr; + vtag = ntohl(initack->init_hdr.init_tag); + break; + } + default: + vtag = asoc->peer.i.init_tag; + break; + } + } else { + /* Special case the INIT and stale COOKIE_ECHO as there is no + * vtag yet. 
+ */ + switch (chunk->chunk_hdr->type) { + case SCTP_CID_INIT: + { + struct sctp_init_chunk *init; + + init = (struct sctp_init_chunk *)chunk->chunk_hdr; + vtag = ntohl(init->init_hdr.init_tag); + break; + } + default: + vtag = ntohl(chunk->sctp_hdr->vtag); + break; + } + } + + /* Make a transport for the bucket, Eliza... */ + transport = sctp_transport_new(net, sctp_source(chunk), GFP_ATOMIC); + if (!transport) + goto nomem; + + transport->encap_port = SCTP_INPUT_CB(chunk->skb)->encap_port; + + /* Cache a route for the transport with the chunk's destination as + * the source address. + */ + sctp_transport_route(transport, (union sctp_addr *)&chunk->dest, + sctp_sk(net->sctp.ctl_sock)); + + packet = &transport->packet; + sctp_packet_init(packet, transport, sport, dport); + sctp_packet_config(packet, vtag, 0); + + return packet; + +nomem: + return NULL; +} + +/* Free the packet allocated earlier for responding in the OOTB condition. */ +void sctp_ootb_pkt_free(struct sctp_packet *packet) +{ + sctp_transport_free(packet->transport); +} + +/* Send a stale cookie error when a invalid COOKIE ECHO chunk is found */ +static void sctp_send_stale_cookie_err(struct net *net, + const struct sctp_endpoint *ep, + const struct sctp_association *asoc, + const struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands, + struct sctp_chunk *err_chunk) +{ + struct sctp_packet *packet; + + if (err_chunk) { + packet = sctp_ootb_pkt_new(net, asoc, chunk); + if (packet) { + struct sctp_signed_cookie *cookie; + + /* Override the OOTB vtag from the cookie. */ + cookie = chunk->subh.cookie_hdr; + packet->vtag = cookie->c.peer_vtag; + + /* Set the skb to the belonging sock for accounting. */ + err_chunk->skb->sk = ep->base.sk; + sctp_packet_append_chunk(packet, err_chunk); + sctp_add_cmd_sf(commands, SCTP_CMD_SEND_PKT, + SCTP_PACKET(packet)); + SCTP_INC_STATS(net, SCTP_MIB_OUTCTRLCHUNKS); + } else + sctp_chunk_free (err_chunk); + } +} + + +/* Process a data chunk */ +static int sctp_eat_data(const struct sctp_association *asoc, + struct sctp_chunk *chunk, + struct sctp_cmd_seq *commands) +{ + struct sctp_tsnmap *map = (struct sctp_tsnmap *)&asoc->peer.tsn_map; + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + struct sctp_datahdr *data_hdr; + struct sctp_chunk *err; + enum sctp_verb deliver; + size_t datalen; + __u32 tsn; + int tmp; + + data_hdr = (struct sctp_datahdr *)chunk->skb->data; + chunk->subh.data_hdr = data_hdr; + skb_pull(chunk->skb, sctp_datahdr_len(&asoc->stream)); + + tsn = ntohl(data_hdr->tsn); + pr_debug("%s: TSN 0x%x\n", __func__, tsn); + + /* ASSERT: Now skb->data is really the user data. */ + + /* Process ECN based congestion. + * + * Since the chunk structure is reused for all chunks within + * a packet, we use ecn_ce_done to track if we've already + * done CE processing for this packet. + * + * We need to do ECN processing even if we plan to discard the + * chunk later. + */ + + if (asoc->peer.ecn_capable && !chunk->ecn_ce_done) { + struct sctp_af *af = SCTP_INPUT_CB(chunk->skb)->af; + chunk->ecn_ce_done = 1; + + if (af->is_ce(sctp_gso_headskb(chunk->skb))) { + /* Do real work as side effect. */ + sctp_add_cmd_sf(commands, SCTP_CMD_ECN_CE, + SCTP_U32(tsn)); + } + } + + tmp = sctp_tsnmap_check(&asoc->peer.tsn_map, tsn); + if (tmp < 0) { + /* The TSN is too high--silently discard the chunk and + * count on it getting retransmitted later. + */ + if (chunk->asoc) + chunk->asoc->stats.outofseqtsns++; + return SCTP_IERROR_HIGH_TSN; + } else if (tmp > 0) { + /* This is a duplicate. 
Record it. */ + sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_DUP, SCTP_U32(tsn)); + return SCTP_IERROR_DUP_TSN; + } + + /* This is a new TSN. */ + + /* Discard if there is no room in the receive window. + * Actually, allow a little bit of overflow (up to a MTU). + */ + datalen = ntohs(chunk->chunk_hdr->length); + datalen -= sctp_datachk_len(&asoc->stream); + + deliver = SCTP_CMD_CHUNK_ULP; + + /* Think about partial delivery. */ + if ((datalen >= asoc->rwnd) && (!asoc->ulpq.pd_mode)) { + + /* Even if we don't accept this chunk there is + * memory pressure. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_PART_DELIVER, SCTP_NULL()); + } + + /* Spill over rwnd a little bit. Note: While allowed, this spill over + * seems a bit troublesome in that frag_point varies based on + * PMTU. In cases, such as loopback, this might be a rather + * large spill over. + */ + if ((!chunk->data_accepted) && (!asoc->rwnd || asoc->rwnd_over || + (datalen > asoc->rwnd + asoc->frag_point))) { + + /* If this is the next TSN, consider reneging to make + * room. Note: Playing nice with a confused sender. A + * malicious sender can still eat up all our buffer + * space and in the future we may want to detect and + * do more drastic reneging. + */ + if (sctp_tsnmap_has_gap(map) && + (sctp_tsnmap_get_ctsn(map) + 1) == tsn) { + pr_debug("%s: reneging for tsn:%u\n", __func__, tsn); + deliver = SCTP_CMD_RENEGE; + } else { + pr_debug("%s: discard tsn:%u len:%zu, rwnd:%d\n", + __func__, tsn, datalen, asoc->rwnd); + + return SCTP_IERROR_IGNORE_TSN; + } + } + + /* + * Also try to renege to limit our memory usage in the event that + * we are under memory pressure + * If we can't renege, don't worry about it, the sk_rmem_schedule + * in sctp_ulpevent_make_rcvmsg will drop the frame if we grow our + * memory usage too much + */ + if (sk_under_memory_pressure(sk)) { + if (sctp_tsnmap_has_gap(map) && + (sctp_tsnmap_get_ctsn(map) + 1) == tsn) { + pr_debug("%s: under pressure, reneging for tsn:%u\n", + __func__, tsn); + deliver = SCTP_CMD_RENEGE; + } + } + + /* + * Section 3.3.10.9 No User Data (9) + * + * Cause of error + * --------------- + * No User Data: This error cause is returned to the originator of a + * DATA chunk if a received DATA chunk has no user data. + */ + if (unlikely(0 == datalen)) { + err = sctp_make_abort_no_data(asoc, chunk, tsn); + if (err) { + sctp_add_cmd_sf(commands, SCTP_CMD_REPLY, + SCTP_CHUNK(err)); + } + /* We are going to ABORT, so we might as well stop + * processing the rest of the chunks in the packet. + */ + sctp_add_cmd_sf(commands, SCTP_CMD_DISCARD_PACKET, SCTP_NULL()); + sctp_add_cmd_sf(commands, SCTP_CMD_SET_SK_ERR, + SCTP_ERROR(ECONNABORTED)); + sctp_add_cmd_sf(commands, SCTP_CMD_ASSOC_FAILED, + SCTP_PERR(SCTP_ERROR_NO_DATA)); + SCTP_INC_STATS(net, SCTP_MIB_ABORTEDS); + SCTP_DEC_STATS(net, SCTP_MIB_CURRESTAB); + return SCTP_IERROR_NO_DATA; + } + + chunk->data_accepted = 1; + + /* Note: Some chunks may get overcounted (if we drop) or overcounted + * if we renege and the chunk arrives again. 
+ */
+	if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) {
+		SCTP_INC_STATS(net, SCTP_MIB_INUNORDERCHUNKS);
+		if (chunk->asoc)
+			chunk->asoc->stats.iuodchunks++;
+	} else {
+		SCTP_INC_STATS(net, SCTP_MIB_INORDERCHUNKS);
+		if (chunk->asoc)
+			chunk->asoc->stats.iodchunks++;
+	}
+
+	/* RFC 2960 6.5 Stream Identifier and Stream Sequence Number
+	 *
+	 * If an endpoint receives a DATA chunk with an invalid stream
+	 * identifier, it shall acknowledge the reception of the DATA chunk
+	 * following the normal procedure, immediately send an ERROR chunk
+	 * with cause set to "Invalid Stream Identifier" (See Section 3.3.10)
+	 * and discard the DATA chunk.
+	 */
+	if (ntohs(data_hdr->stream) >= asoc->stream.incnt) {
+		/* Mark tsn as received even though we drop it */
+		sctp_add_cmd_sf(commands, SCTP_CMD_REPORT_TSN, SCTP_U32(tsn));
+
+		err = sctp_make_op_error(asoc, chunk, SCTP_ERROR_INV_STRM,
+					 &data_hdr->stream,
+					 sizeof(data_hdr->stream),
+					 sizeof(u16));
+		if (err)
+			sctp_add_cmd_sf(commands, SCTP_CMD_REPLY,
+					SCTP_CHUNK(err));
+		return SCTP_IERROR_BAD_STREAM;
+	}
+
+	/* Check to see if the SSN is possible for this TSN.
+	 * The biggest gap we can record is 4K wide. Since SSNs wrap
+	 * at an unsigned short, there is no way that an SSN can
+	 * wrap for a valid TSN. We can simply check if the current
+	 * SSN is smaller than the next expected one. If it is, it wrapped
+	 * and is invalid.
+	 */
+	if (!asoc->stream.si->validate_data(chunk))
+		return SCTP_IERROR_PROTO_VIOLATION;
+
+	/* Send the data up to the user. Note: Schedule the
+	 * SCTP_CMD_CHUNK_ULP cmd before the SCTP_CMD_GEN_SACK, as the SACK
+	 * chunk needs the updated rwnd.
+	 */
+	sctp_add_cmd_sf(commands, deliver, SCTP_CHUNK(chunk));
+
+	return SCTP_IERROR_NO_ERROR;
+}
diff --git a/net/sctp/sm_statetable.c b/net/sctp/sm_statetable.c
new file mode 100644
index 000000000..1816a4410
--- /dev/null
+++ b/net/sctp/sm_statetable.c
@@ -0,0 +1,1041 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/* SCTP kernel implementation
+ * (C) Copyright IBM Corp. 2001, 2004
+ * Copyright (c) 1999-2000 Cisco, Inc.
+ * Copyright (c) 1999-2001 Motorola, Inc.
+ * Copyright (c) 2001 Intel Corp.
+ * Copyright (c) 2001 Nokia, Inc.
+ *
+ * This file is part of the SCTP kernel implementation
+ *
+ * These are the state tables for the SCTP state machine.
+ *
+ * Please send any bug reports or fixes you make to the
+ * email address(es):
+ *    lksctp developers
+ *
+ * Written or modified by:
+ *    La Monte H.P.
Yarroll + * Karl Knutson + * Jon Grimm + * Hui Huang + * Daisy Chang + * Ardelle Fan + * Sridhar Samudrala + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +static const struct sctp_sm_table_entry +primitive_event_table[SCTP_NUM_PRIMITIVE_TYPES][SCTP_STATE_NUM_STATES]; +static const struct sctp_sm_table_entry +other_event_table[SCTP_NUM_OTHER_TYPES][SCTP_STATE_NUM_STATES]; +static const struct sctp_sm_table_entry +timeout_event_table[SCTP_NUM_TIMEOUT_TYPES][SCTP_STATE_NUM_STATES]; + +static const struct sctp_sm_table_entry *sctp_chunk_event_lookup( + struct net *net, + enum sctp_cid cid, + enum sctp_state state); + + +static const struct sctp_sm_table_entry bug = { + .fn = sctp_sf_bug, + .name = "sctp_sf_bug" +}; + +#define DO_LOOKUP(_max, _type, _table) \ +({ \ + const struct sctp_sm_table_entry *rtn; \ + \ + if ((event_subtype._type > (_max))) { \ + pr_warn("table %p possible attack: event %d exceeds max %d\n", \ + _table, event_subtype._type, _max); \ + rtn = &bug; \ + } else \ + rtn = &_table[event_subtype._type][(int)state]; \ + \ + rtn; \ +}) + +const struct sctp_sm_table_entry *sctp_sm_lookup_event( + struct net *net, + enum sctp_event_type event_type, + enum sctp_state state, + union sctp_subtype event_subtype) +{ + switch (event_type) { + case SCTP_EVENT_T_CHUNK: + return sctp_chunk_event_lookup(net, event_subtype.chunk, state); + case SCTP_EVENT_T_TIMEOUT: + return DO_LOOKUP(SCTP_EVENT_TIMEOUT_MAX, timeout, + timeout_event_table); + case SCTP_EVENT_T_OTHER: + return DO_LOOKUP(SCTP_EVENT_OTHER_MAX, other, + other_event_table); + case SCTP_EVENT_T_PRIMITIVE: + return DO_LOOKUP(SCTP_EVENT_PRIMITIVE_MAX, primitive, + primitive_event_table); + default: + /* Yikes! We got an illegal event type. */ + return &bug; + } +} + +#define TYPE_SCTP_FUNC(func) {.fn = func, .name = #func} + +#define TYPE_SCTP_DATA { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_data_6_2), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_data_6_2), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_data_fast_4_4), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_DATA */ + +#define TYPE_SCTP_INIT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_1B_init), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_1_siminit), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_1_siminit), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_2_dupinit), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_2_dupinit), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_2_dupinit), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_2_dupinit), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_reshutack), \ +} /* TYPE_SCTP_INIT */ + +#define TYPE_SCTP_INIT_ACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_3_initack), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_1C_ack), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ 
+ /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_INIT_ACK */ + +#define TYPE_SCTP_SACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_sack_6_2), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_sack_6_2), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_sack_6_2), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_sack_6_2), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_SACK */ + +#define TYPE_SCTP_HEARTBEAT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + /* This should not happen, but we are nice. */ \ + TYPE_SCTP_FUNC(sctp_sf_beat_8_3), \ +} /* TYPE_SCTP_HEARTBEAT */ + +#define TYPE_SCTP_HEARTBEAT_ACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_violation), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_backbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_backbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_backbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_backbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_HEARTBEAT_ACK */ + +#define TYPE_SCTP_ABORT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_pdiscard), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_wait_abort), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_echoed_abort), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_1_abort), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_pending_abort), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_sent_abort), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_1_abort), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_ack_sent_abort), \ +} /* TYPE_SCTP_ABORT */ + +#define TYPE_SCTP_SHUTDOWN { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_shutdown), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_shutdown), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + 
TYPE_SCTP_FUNC(sctp_sf_do_9_2_shutdown_ack), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_shut_ctsn), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_SHUTDOWN */ + +#define TYPE_SCTP_SHUTDOWN_ACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_8_5_1_E_sa), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_8_5_1_E_sa), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_violation), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_violation), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_final), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_violation), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_final), \ +} /* TYPE_SCTP_SHUTDOWN_ACK */ + +#define TYPE_SCTP_ERROR { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_echoed_err), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_operr_notify), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_operr_notify), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_operr_notify), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_ERROR */ + +#define TYPE_SCTP_COOKIE_ECHO { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_1D_ce), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_2_4_dupcook), \ +} /* TYPE_SCTP_COOKIE_ECHO */ + +#define TYPE_SCTP_COOKIE_ACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_5_1E_ca), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_COOKIE_ACK */ + +#define TYPE_SCTP_ECN_ECNE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecne), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecne), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecne), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecne), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecne), \ + /* 
SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_ECN_ECNE */ + +#define TYPE_SCTP_ECN_CWR { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecn_cwr), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecn_cwr), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_ecn_cwr), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_ECN_CWR */ + +#define TYPE_SCTP_SHUTDOWN_COMPLETE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_4_C), \ +} /* TYPE_SCTP_SHUTDOWN_COMPLETE */ + +/* The primary index for this table is the chunk type. + * The secondary index for this table is the state. + * + * For base protocol (RFC 2960). + */ +static const struct sctp_sm_table_entry +chunk_event_table[SCTP_NUM_BASE_CHUNK_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_DATA, + TYPE_SCTP_INIT, + TYPE_SCTP_INIT_ACK, + TYPE_SCTP_SACK, + TYPE_SCTP_HEARTBEAT, + TYPE_SCTP_HEARTBEAT_ACK, + TYPE_SCTP_ABORT, + TYPE_SCTP_SHUTDOWN, + TYPE_SCTP_SHUTDOWN_ACK, + TYPE_SCTP_ERROR, + TYPE_SCTP_COOKIE_ECHO, + TYPE_SCTP_COOKIE_ACK, + TYPE_SCTP_ECN_ECNE, + TYPE_SCTP_ECN_CWR, + TYPE_SCTP_SHUTDOWN_COMPLETE, +}; /* state_fn_t chunk_event_table[][] */ + +#define TYPE_SCTP_ASCONF { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_ASCONF */ + +#define TYPE_SCTP_ASCONF_ACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf_ack), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf_ack), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf_ack), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_asconf_ack), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_ASCONF_ACK */ + +/* The primary index for this table is the chunk type. 
+ * The secondary index for this table is the state. + */ +static const struct sctp_sm_table_entry +addip_chunk_event_table[SCTP_NUM_ADDIP_CHUNK_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_ASCONF, + TYPE_SCTP_ASCONF_ACK, +}; /*state_fn_t addip_chunk_event_table[][] */ + +#define TYPE_SCTP_FWD_TSN { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_fwd_tsn), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_fwd_tsn), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_fwd_tsn_fast), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_FWD_TSN */ + +/* The primary index for this table is the chunk type. + * The secondary index for this table is the state. + */ +static const struct sctp_sm_table_entry +prsctp_chunk_event_table[SCTP_NUM_PRSCTP_CHUNK_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_FWD_TSN, +}; /*state_fn_t prsctp_chunk_event_table[][] */ + +#define TYPE_SCTP_RECONF { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_reconf), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_reconf), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ +} /* TYPE_SCTP_RECONF */ + +/* The primary index for this table is the chunk type. + * The secondary index for this table is the state. + */ +static const struct sctp_sm_table_entry +reconf_chunk_event_table[SCTP_NUM_RECONF_CHUNK_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_RECONF, +}; /*state_fn_t reconf_chunk_event_table[][] */ + +#define TYPE_SCTP_AUTH { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ootb), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_eat_auth), \ +} /* TYPE_SCTP_AUTH */ + +/* The primary index for this table is the chunk type. + * The secondary index for this table is the state. 
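
The chunk tables defined here all follow one pattern: a two-dimensional array of handler entries indexed first by event type and then by association state, with a shared "bug" entry returned for out-of-range input. A reduced, self-contained sketch of that dispatch pattern is shown below; every name in it is invented for illustration and none of it is kernel code.

    #include <stdio.h>

    enum demo_state { ST_CLOSED, ST_ESTABLISHED, ST_MAX = ST_ESTABLISHED };
    enum demo_event { EV_DATA, EV_SHUTDOWN, EV_MAX = EV_SHUTDOWN };

    typedef const char *(*demo_handler_t)(void);

    static const char *demo_discard(void) { return "discard"; }
    static const char *demo_deliver(void) { return "deliver"; }
    static const char *demo_bug(void)     { return "bug"; }

    /* Primary index: event type.  Secondary index: state. */
    static const demo_handler_t demo_table[EV_MAX + 1][ST_MAX + 1] = {
            [EV_DATA]     = { [ST_CLOSED]      = demo_discard,
                              [ST_ESTABLISHED] = demo_deliver },
            [EV_SHUTDOWN] = { [ST_CLOSED]      = demo_discard,
                              [ST_ESTABLISHED] = demo_discard },
    };

    static demo_handler_t demo_lookup(int ev, int state)
    {
            if (ev < 0 || ev > EV_MAX || state < 0 || state > ST_MAX)
                    return demo_bug;        /* mirrors the "bug" fallback */
            return demo_table[ev][state];
    }

    int main(void)
    {
            printf("%s\n", demo_lookup(EV_DATA, ST_ESTABLISHED)());
            printf("%s\n", demo_lookup(99, ST_CLOSED)());
            return 0;
    }
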
+ */ +static const struct sctp_sm_table_entry +auth_chunk_event_table[SCTP_NUM_AUTH_CHUNK_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_AUTH, +}; /*state_fn_t auth_chunk_event_table[][] */ + +static const struct sctp_sm_table_entry +pad_chunk_event_table[SCTP_STATE_NUM_STATES] = { + /* SCTP_STATE_CLOSED */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_COOKIE_WAIT */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_COOKIE_ECHOED */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_ESTABLISHED */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_SHUTDOWN_PENDING */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_SHUTDOWN_SENT */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_SHUTDOWN_RECEIVED */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ + TYPE_SCTP_FUNC(sctp_sf_discard_chunk), +}; /* chunk pad */ + +static const struct sctp_sm_table_entry +chunk_event_table_unknown[SCTP_STATE_NUM_STATES] = { + /* SCTP_STATE_CLOSED */ + TYPE_SCTP_FUNC(sctp_sf_ootb), + /* SCTP_STATE_COOKIE_WAIT */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_COOKIE_ECHOED */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_ESTABLISHED */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_SHUTDOWN_PENDING */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_SHUTDOWN_SENT */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_SHUTDOWN_RECEIVED */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ + TYPE_SCTP_FUNC(sctp_sf_unk_chunk), +}; /* chunk unknown */ + + +#define TYPE_SCTP_PRIMITIVE_ASSOCIATE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_asoc), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_not_impl), \ +} /* TYPE_SCTP_PRIMITIVE_ASSOCIATE */ + +#define TYPE_SCTP_PRIMITIVE_SHUTDOWN { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_wait_prm_shutdown), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_echoed_prm_shutdown),\ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_prm_shutdown), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_primitive), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_primitive), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_primitive), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_primitive), \ +} /* TYPE_SCTP_PRIMITIVE_SHUTDOWN */ + +#define TYPE_SCTP_PRIMITIVE_ABORT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_wait_prm_abort), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_echoed_prm_abort), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_1_prm_abort), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_pending_prm_abort), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_sent_prm_abort), \ + /* 
SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_1_prm_abort), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_shutdown_ack_sent_prm_abort), \ +} /* TYPE_SCTP_PRIMITIVE_ABORT */ + +#define TYPE_SCTP_PRIMITIVE_SEND { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_send), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_send), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_send), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ +} /* TYPE_SCTP_PRIMITIVE_SEND */ + +#define TYPE_SCTP_PRIMITIVE_REQUESTHEARTBEAT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_requestheartbeat), \ +} /* TYPE_SCTP_PRIMITIVE_REQUESTHEARTBEAT */ + +#define TYPE_SCTP_PRIMITIVE_ASCONF { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_asconf), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_asconf), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_asconf), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_asconf), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ +} /* TYPE_SCTP_PRIMITIVE_ASCONF */ + +#define TYPE_SCTP_PRIMITIVE_RECONF { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_error_closed), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_reconf), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_reconf), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_reconf), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_prm_reconf), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_error_shutdown), \ +} /* TYPE_SCTP_PRIMITIVE_RECONF */ + +/* The primary index for this table is the primitive type. + * The secondary index for this table is the state. 
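
The primitive table that follows is driven from the sockets layer rather than from the wire; for instance, once an association is established, a user-space shutdown() or close() is what ends up dispatching the SHUTDOWN primitive and, per the TYPE_SCTP_PRIMITIVE_SHUTDOWN entries above, sctp_sf_do_9_2_prm_shutdown() in the ESTABLISHED state. A minimal client sketch follows (illustrative only; 192.0.2.1 and port 5000 are placeholders).

    #include <stdio.h>
    #include <string.h>
    #include <unistd.h>
    #include <arpa/inet.h>
    #include <sys/socket.h>
    #include <netinet/in.h>

    int main(void)
    {
            struct sockaddr_in peer;
            int sd;

            sd = socket(AF_INET, SOCK_STREAM, IPPROTO_SCTP);
            if (sd < 0) {
                    perror("socket");
                    return 1;
            }

            memset(&peer, 0, sizeof(peer));
            peer.sin_family = AF_INET;
            peer.sin_port = htons(5000);
            inet_pton(AF_INET, "192.0.2.1", &peer.sin_addr);

            if (connect(sd, (struct sockaddr *)&peer, sizeof(peer)) < 0) {
                    perror("connect");
                    close(sd);
                    return 1;
            }

            /* Graceful teardown: this is what reaches the SHUTDOWN
             * primitive handlers in the table below.
             */
            shutdown(sd, SHUT_WR);
            close(sd);
            return 0;
    }
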
+ */ +static const struct sctp_sm_table_entry +primitive_event_table[SCTP_NUM_PRIMITIVE_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_PRIMITIVE_ASSOCIATE, + TYPE_SCTP_PRIMITIVE_SHUTDOWN, + TYPE_SCTP_PRIMITIVE_ABORT, + TYPE_SCTP_PRIMITIVE_SEND, + TYPE_SCTP_PRIMITIVE_REQUESTHEARTBEAT, + TYPE_SCTP_PRIMITIVE_ASCONF, + TYPE_SCTP_PRIMITIVE_RECONF, +}; + +#define TYPE_SCTP_OTHER_NO_PENDING_TSN { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_no_pending_tsn), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_start_shutdown), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_9_2_shutdown_ack), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ +} + +#define TYPE_SCTP_OTHER_ICMP_PROTO_UNREACH { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_cookie_wait_icmp_abort), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_ignore_other), \ +} + +static const struct sctp_sm_table_entry +other_event_table[SCTP_NUM_OTHER_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_OTHER_NO_PENDING_TSN, + TYPE_SCTP_OTHER_ICMP_PROTO_UNREACH, +}; + +#define TYPE_SCTP_EVENT_TIMEOUT_NONE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T1_COOKIE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_bug), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_t1_cookie_timer_expire), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T1_INIT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_t1_init_timer_expire), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + 
TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T2_SHUTDOWN { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_t2_timer_expire), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_t2_timer_expire), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T3_RTX { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_3_3_rtx), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_3_3_rtx), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_3_3_rtx), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_3_3_rtx), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T4_RTO { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_t4_timer_expire), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_t5_timer_expire), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_t5_timer_expire), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_HEARTBEAT { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_sendbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_sendbeat_8_3), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_sendbeat_8_3), 
\ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_SACK { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_2_sack), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_2_sack), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_do_6_2_sack), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_AUTOCLOSE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_autoclose_timer_expire), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_RECONF { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_send_reconf), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +#define TYPE_SCTP_EVENT_TIMEOUT_PROBE { \ + /* SCTP_STATE_CLOSED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_WAIT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_COOKIE_ECHOED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_ESTABLISHED */ \ + TYPE_SCTP_FUNC(sctp_sf_send_probe), \ + /* SCTP_STATE_SHUTDOWN_PENDING */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_RECEIVED */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ + /* SCTP_STATE_SHUTDOWN_ACK_SENT */ \ + TYPE_SCTP_FUNC(sctp_sf_timer_ignore), \ +} + +static const struct sctp_sm_table_entry +timeout_event_table[SCTP_NUM_TIMEOUT_TYPES][SCTP_STATE_NUM_STATES] = { + TYPE_SCTP_EVENT_TIMEOUT_NONE, + TYPE_SCTP_EVENT_TIMEOUT_T1_COOKIE, + TYPE_SCTP_EVENT_TIMEOUT_T1_INIT, + TYPE_SCTP_EVENT_TIMEOUT_T2_SHUTDOWN, + TYPE_SCTP_EVENT_TIMEOUT_T3_RTX, + TYPE_SCTP_EVENT_TIMEOUT_T4_RTO, + TYPE_SCTP_EVENT_TIMEOUT_T5_SHUTDOWN_GUARD, + TYPE_SCTP_EVENT_TIMEOUT_HEARTBEAT, + TYPE_SCTP_EVENT_TIMEOUT_RECONF, + TYPE_SCTP_EVENT_TIMEOUT_PROBE, + TYPE_SCTP_EVENT_TIMEOUT_SACK, + TYPE_SCTP_EVENT_TIMEOUT_AUTOCLOSE, +}; + +static const struct sctp_sm_table_entry *sctp_chunk_event_lookup( + struct net *net, + enum sctp_cid cid, + enum sctp_state state) +{ + if (state > SCTP_STATE_MAX) + return &bug; + + if (cid == SCTP_CID_I_DATA) + cid = SCTP_CID_DATA; + + if (cid <= 
SCTP_CID_BASE_MAX) + return &chunk_event_table[cid][state]; + + switch ((u16)cid) { + case SCTP_CID_FWD_TSN: + case SCTP_CID_I_FWD_TSN: + return &prsctp_chunk_event_table[0][state]; + + case SCTP_CID_ASCONF: + return &addip_chunk_event_table[0][state]; + + case SCTP_CID_ASCONF_ACK: + return &addip_chunk_event_table[1][state]; + + case SCTP_CID_RECONF: + return &reconf_chunk_event_table[0][state]; + + case SCTP_CID_AUTH: + return &auth_chunk_event_table[0][state]; + + case SCTP_CID_PAD: + return &pad_chunk_event_table[state]; + } + + return &chunk_event_table_unknown[state]; +} diff --git a/net/sctp/socket.c b/net/sctp/socket.c new file mode 100644 index 000000000..237a6b04a --- /dev/null +++ b/net/sctp/socket.c @@ -0,0 +1,9745 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2003 Intel Corp. + * Copyright (c) 2001-2002 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * These functions interface with the sockets layer to implement the + * SCTP Extensions for the Sockets API. + * + * Note that the descriptions from the specification are USER level + * functions--this file is the functions which populate the struct proto + * for SCTP which is the BOTTOM of the sockets interface. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Narasimha Budihal + * Karl Knutson + * Jon Grimm + * Xingang Guo + * Daisy Chang + * Sridhar Samudrala + * Inaky Perez-Gonzalez + * Ardelle Fan + * Ryan Layer + * Anup Pemmaiah + * Kevin Gao + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include /* for sa_family_t */ +#include +#include +#include +#include +#include + +/* Forward declarations for internal helper functions. 
*/ +static bool sctp_writeable(const struct sock *sk); +static void sctp_wfree(struct sk_buff *skb); +static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p, + size_t msg_len); +static int sctp_wait_for_packet(struct sock *sk, int *err, long *timeo_p); +static int sctp_wait_for_connect(struct sctp_association *, long *timeo_p); +static int sctp_wait_for_accept(struct sock *sk, long timeo); +static void sctp_wait_for_close(struct sock *sk, long timeo); +static void sctp_destruct_sock(struct sock *sk); +static struct sctp_af *sctp_sockaddr_af(struct sctp_sock *opt, + union sctp_addr *addr, int len); +static int sctp_bindx_add(struct sock *, struct sockaddr *, int); +static int sctp_bindx_rem(struct sock *, struct sockaddr *, int); +static int sctp_send_asconf_add_ip(struct sock *, struct sockaddr *, int); +static int sctp_send_asconf_del_ip(struct sock *, struct sockaddr *, int); +static int sctp_send_asconf(struct sctp_association *asoc, + struct sctp_chunk *chunk); +static int sctp_do_bind(struct sock *, union sctp_addr *, int); +static int sctp_autobind(struct sock *sk); +static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, + struct sctp_association *assoc, + enum sctp_socket_type type); + +static unsigned long sctp_memory_pressure; +static atomic_long_t sctp_memory_allocated; +static DEFINE_PER_CPU(int, sctp_memory_per_cpu_fw_alloc); +struct percpu_counter sctp_sockets_allocated; + +static void sctp_enter_memory_pressure(struct sock *sk) +{ + WRITE_ONCE(sctp_memory_pressure, 1); +} + + +/* Get the sndbuf space available at the time on the association. */ +static inline int sctp_wspace(struct sctp_association *asoc) +{ + struct sock *sk = asoc->base.sk; + + return asoc->ep->sndbuf_policy ? sk->sk_sndbuf - asoc->sndbuf_used + : sk_stream_wspace(sk); +} + +/* Increment the used sndbuf space count of the corresponding association by + * the size of the outgoing data chunk. + * Also, set the skb destructor for sndbuf accounting later. + * + * Since it is always 1-1 between chunk and skb, and also a new skb is always + * allocated for chunk bundling in sctp_packet_transmit(), we can use the + * destructor in the data chunk skb for the purpose of the sndbuf space + * tracking. + */ +static inline void sctp_set_owner_w(struct sctp_chunk *chunk) +{ + struct sctp_association *asoc = chunk->asoc; + struct sock *sk = asoc->base.sk; + + /* The sndbuf space is tracked per association. */ + sctp_association_hold(asoc); + + if (chunk->shkey) + sctp_auth_shkey_hold(chunk->shkey); + + skb_set_owner_w(chunk->skb, sk); + + chunk->skb->destructor = sctp_wfree; + /* Save the chunk pointer in skb for sctp_wfree to use later. 
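+ *
+ * Each queued data chunk charges chunk->skb->truesize plus
+ * sizeof(struct sctp_chunk) against asoc->sndbuf_used and the
+ * socket's write-memory counters below; sctp_wfree() later undoes
+ * the same charge when the skb is freed, which is what keeps
+ * sctp_wspace() consistent.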
*/ + skb_shinfo(chunk->skb)->destructor_arg = chunk; + + refcount_add(sizeof(struct sctp_chunk), &sk->sk_wmem_alloc); + asoc->sndbuf_used += chunk->skb->truesize + sizeof(struct sctp_chunk); + sk_wmem_queued_add(sk, chunk->skb->truesize + sizeof(struct sctp_chunk)); + sk_mem_charge(sk, chunk->skb->truesize); +} + +static void sctp_clear_owner_w(struct sctp_chunk *chunk) +{ + skb_orphan(chunk->skb); +} + +#define traverse_and_process() \ +do { \ + msg = chunk->msg; \ + if (msg == prev_msg) \ + continue; \ + list_for_each_entry(c, &msg->chunks, frag_list) { \ + if ((clear && asoc->base.sk == c->skb->sk) || \ + (!clear && asoc->base.sk != c->skb->sk)) \ + cb(c); \ + } \ + prev_msg = msg; \ +} while (0) + +static void sctp_for_each_tx_datachunk(struct sctp_association *asoc, + bool clear, + void (*cb)(struct sctp_chunk *)) + +{ + struct sctp_datamsg *msg, *prev_msg = NULL; + struct sctp_outq *q = &asoc->outqueue; + struct sctp_chunk *chunk, *c; + struct sctp_transport *t; + + list_for_each_entry(t, &asoc->peer.transport_addr_list, transports) + list_for_each_entry(chunk, &t->transmitted, transmitted_list) + traverse_and_process(); + + list_for_each_entry(chunk, &q->retransmit, transmitted_list) + traverse_and_process(); + + list_for_each_entry(chunk, &q->sacked, transmitted_list) + traverse_and_process(); + + list_for_each_entry(chunk, &q->abandoned, transmitted_list) + traverse_and_process(); + + list_for_each_entry(chunk, &q->out_chunk_list, list) + traverse_and_process(); +} + +static void sctp_for_each_rx_skb(struct sctp_association *asoc, struct sock *sk, + void (*cb)(struct sk_buff *, struct sock *)) + +{ + struct sk_buff *skb, *tmp; + + sctp_skb_for_each(skb, &asoc->ulpq.lobby, tmp) + cb(skb, sk); + + sctp_skb_for_each(skb, &asoc->ulpq.reasm, tmp) + cb(skb, sk); + + sctp_skb_for_each(skb, &asoc->ulpq.reasm_uo, tmp) + cb(skb, sk); +} + +/* Verify that this is a valid address. */ +static inline int sctp_verify_addr(struct sock *sk, union sctp_addr *addr, + int len) +{ + struct sctp_af *af; + + /* Verify basic sockaddr. */ + af = sctp_sockaddr_af(sctp_sk(sk), addr, len); + if (!af) + return -EINVAL; + + /* Is this a valid SCTP address? */ + if (!af->addr_valid(addr, sctp_sk(sk), NULL)) + return -EINVAL; + + if (!sctp_sk(sk)->pf->send_verify(sctp_sk(sk), (addr))) + return -EINVAL; + + return 0; +} + +/* Look up the association by its id. If this is not a UDP-style + * socket, the ID field is always ignored. + */ +struct sctp_association *sctp_id2assoc(struct sock *sk, sctp_assoc_t id) +{ + struct sctp_association *asoc = NULL; + + /* If this is not a UDP-style socket, assoc id should be ignored. */ + if (!sctp_style(sk, UDP)) { + /* Return NULL if the socket state is not ESTABLISHED. It + * could be a TCP-style listening socket or a socket which + * hasn't yet called connect() to establish an association. + */ + if (!sctp_sstate(sk, ESTABLISHED) && !sctp_sstate(sk, CLOSING)) + return NULL; + + /* Get the first and the only association from the list. */ + if (!list_empty(&sctp_sk(sk)->ep->asocs)) + asoc = list_entry(sctp_sk(sk)->ep->asocs.next, + struct sctp_association, asocs); + return asoc; + } + + /* Otherwise this is a UDP-style socket. */ + if (id <= SCTP_ALL_ASSOC) + return NULL; + + spin_lock_bh(&sctp_assocs_id_lock); + asoc = (struct sctp_association *)idr_find(&sctp_assocs_id, (int)id); + if (asoc && (asoc->base.sk != sk || asoc->base.dead)) + asoc = NULL; + spin_unlock_bh(&sctp_assocs_id_lock); + + return asoc; +} + +/* Look up the transport from an address and an assoc id. 
If both address and + * id are specified, the associations matching the address and the id should be + * the same. + */ +static struct sctp_transport *sctp_addr_id2transport(struct sock *sk, + struct sockaddr_storage *addr, + sctp_assoc_t id) +{ + struct sctp_association *addr_asoc = NULL, *id_asoc = NULL; + struct sctp_af *af = sctp_get_af_specific(addr->ss_family); + union sctp_addr *laddr = (union sctp_addr *)addr; + struct sctp_transport *transport; + + if (!af || sctp_verify_addr(sk, laddr, af->sockaddr_len)) + return NULL; + + addr_asoc = sctp_endpoint_lookup_assoc(sctp_sk(sk)->ep, + laddr, + &transport); + + if (!addr_asoc) + return NULL; + + id_asoc = sctp_id2assoc(sk, id); + if (id_asoc && (id_asoc != addr_asoc)) + return NULL; + + sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk), + (union sctp_addr *)addr); + + return transport; +} + +/* API 3.1.2 bind() - UDP Style Syntax + * The syntax of bind() is, + * + * ret = bind(int sd, struct sockaddr *addr, int addrlen); + * + * sd - the socket descriptor returned by socket(). + * addr - the address structure (struct sockaddr_in or struct + * sockaddr_in6 [RFC 2553]), + * addr_len - the size of the address structure. + */ +static int sctp_bind(struct sock *sk, struct sockaddr *addr, int addr_len) +{ + int retval = 0; + + lock_sock(sk); + + pr_debug("%s: sk:%p, addr:%p, addr_len:%d\n", __func__, sk, + addr, addr_len); + + /* Disallow binding twice. */ + if (!sctp_sk(sk)->ep->base.bind_addr.port) + retval = sctp_do_bind(sk, (union sctp_addr *)addr, + addr_len); + else + retval = -EINVAL; + + release_sock(sk); + + return retval; +} + +static int sctp_get_port_local(struct sock *, union sctp_addr *); + +/* Verify this is a valid sockaddr. */ +static struct sctp_af *sctp_sockaddr_af(struct sctp_sock *opt, + union sctp_addr *addr, int len) +{ + struct sctp_af *af; + + /* Check minimum size. */ + if (len < sizeof (struct sockaddr)) + return NULL; + + if (!opt->pf->af_supported(addr->sa.sa_family, opt)) + return NULL; + + if (addr->sa.sa_family == AF_INET6) { + if (len < SIN6_LEN_RFC2133) + return NULL; + /* V4 mapped address are really of AF_INET family */ + if (ipv6_addr_v4mapped(&addr->v6.sin6_addr) && + !opt->pf->af_supported(AF_INET, opt)) + return NULL; + } + + /* If we get this far, af is valid. */ + af = sctp_get_af_specific(addr->sa.sa_family); + + if (len < af->sockaddr_len) + return NULL; + + return af; +} + +static void sctp_auto_asconf_init(struct sctp_sock *sp) +{ + struct net *net = sock_net(&sp->inet.sk); + + if (net->sctp.default_auto_asconf) { + spin_lock_bh(&net->sctp.addr_wq_lock); + list_add_tail(&sp->auto_asconf_list, &net->sctp.auto_asconf_splist); + spin_unlock_bh(&net->sctp.addr_wq_lock); + sp->do_auto_asconf = 1; + } +} + +/* Bind a local address either to an endpoint or to an association. */ +static int sctp_do_bind(struct sock *sk, union sctp_addr *addr, int len) +{ + struct net *net = sock_net(sk); + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_endpoint *ep = sp->ep; + struct sctp_bind_addr *bp = &ep->base.bind_addr; + struct sctp_af *af; + unsigned short snum; + int ret = 0; + + /* Common sockaddr verification. */ + af = sctp_sockaddr_af(sp, addr, len); + if (!af) { + pr_debug("%s: sk:%p, newaddr:%p, len:%d EINVAL\n", + __func__, sk, addr, len); + return -EINVAL; + } + + snum = ntohs(addr->v4.sin_port); + + pr_debug("%s: sk:%p, new addr:%pISc, port:%d, new port:%d, len:%d\n", + __func__, sk, &addr->sa, bp->port, snum, len); + + /* PF specific bind() address verification. 
*/ + if (!sp->pf->bind_verify(sp, addr)) + return -EADDRNOTAVAIL; + + /* We must either be unbound, or bind to the same port. + * It's OK to allow 0 ports if we are already bound. + * We'll just inhert an already bound port in this case + */ + if (bp->port) { + if (!snum) + snum = bp->port; + else if (snum != bp->port) { + pr_debug("%s: new port %d doesn't match existing port " + "%d\n", __func__, snum, bp->port); + return -EINVAL; + } + } + + if (snum && inet_port_requires_bind_service(net, snum) && + !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) + return -EACCES; + + /* See if the address matches any of the addresses we may have + * already bound before checking against other endpoints. + */ + if (sctp_bind_addr_match(bp, addr, sp)) + return -EINVAL; + + /* Make sure we are allowed to bind here. + * The function sctp_get_port_local() does duplicate address + * detection. + */ + addr->v4.sin_port = htons(snum); + if (sctp_get_port_local(sk, addr)) + return -EADDRINUSE; + + /* Refresh ephemeral port. */ + if (!bp->port) { + bp->port = inet_sk(sk)->inet_num; + sctp_auto_asconf_init(sp); + } + + /* Add the address to the bind address list. + * Use GFP_ATOMIC since BHs will be disabled. + */ + ret = sctp_add_bind_addr(bp, addr, af->sockaddr_len, + SCTP_ADDR_SRC, GFP_ATOMIC); + + if (ret) { + sctp_put_port(sk); + return ret; + } + /* Copy back into socket for getsockname() use. */ + inet_sk(sk)->inet_sport = htons(inet_sk(sk)->inet_num); + sp->pf->to_sk_saddr(addr, sk); + + return ret; +} + + /* ADDIP Section 4.1.1 Congestion Control of ASCONF Chunks + * + * R1) One and only one ASCONF Chunk MAY be in transit and unacknowledged + * at any one time. If a sender, after sending an ASCONF chunk, decides + * it needs to transfer another ASCONF Chunk, it MUST wait until the + * ASCONF-ACK Chunk returns from the previous ASCONF Chunk before sending a + * subsequent ASCONF. Note this restriction binds each side, so at any + * time two ASCONF may be in-transit on any given association (one sent + * from each endpoint). + */ +static int sctp_send_asconf(struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + int retval = 0; + + /* If there is an outstanding ASCONF chunk, queue it for later + * transmission. + */ + if (asoc->addip_last_asconf) { + list_add_tail(&chunk->list, &asoc->addip_chunk_list); + goto out; + } + + /* Hold the chunk until an ASCONF_ACK is received. */ + sctp_chunk_hold(chunk); + retval = sctp_primitive_ASCONF(asoc->base.net, asoc, chunk); + if (retval) + sctp_chunk_free(chunk); + else + asoc->addip_last_asconf = chunk; + +out: + return retval; +} + +/* Add a list of addresses as bind addresses to local endpoint or + * association. + * + * Basically run through each address specified in the addrs/addrcnt + * array/length pair, determine if it is IPv6 or IPv4 and call + * sctp_do_bind() on it. + * + * If any of them fails, then the operation will be reversed and the + * ones that were added will be removed. + * + * Only sctp_setsockopt_bindx() is supposed to call this function. + */ +static int sctp_bindx_add(struct sock *sk, struct sockaddr *addrs, int addrcnt) +{ + int cnt; + int retval = 0; + void *addr_buf; + struct sockaddr *sa_addr; + struct sctp_af *af; + + pr_debug("%s: sk:%p, addrs:%p, addrcnt:%d\n", __func__, sk, + addrs, addrcnt); + + addr_buf = addrs; + for (cnt = 0; cnt < addrcnt; cnt++) { + /* The list may contain either IPv4 or IPv6 address; + * determine the address length for walking thru the list. 
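+ *
+ * As a rough user-space sketch of how such a packed list is built
+ * (addresses and port are made up, and sctp_bindx() here is the
+ * lksctp-tools wrapper, not a function in this file):
+ *
+ *	struct sockaddr_in a4 = { .sin_family = AF_INET,
+ *				  .sin_port = htons(5000) };
+ *	struct sockaddr_in6 a6 = { .sin6_family = AF_INET6,
+ *				   .sin6_port = htons(5000) };
+ *	char buf[sizeof(a4) + sizeof(a6)];
+ *
+ *	inet_pton(AF_INET, "192.0.2.1", &a4.sin_addr);
+ *	inet_pton(AF_INET6, "2001:db8::1", &a6.sin6_addr);
+ *	memcpy(buf, &a4, sizeof(a4));
+ *	memcpy(buf + sizeof(a4), &a6, sizeof(a6));
+ *	sctp_bindx(sd, (struct sockaddr *)buf, 2, SCTP_BINDX_ADD_ADDR);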
+ */ + sa_addr = addr_buf; + af = sctp_get_af_specific(sa_addr->sa_family); + if (!af) { + retval = -EINVAL; + goto err_bindx_add; + } + + retval = sctp_do_bind(sk, (union sctp_addr *)sa_addr, + af->sockaddr_len); + + addr_buf += af->sockaddr_len; + +err_bindx_add: + if (retval < 0) { + /* Failed. Cleanup the ones that have been added */ + if (cnt > 0) + sctp_bindx_rem(sk, addrs, cnt); + return retval; + } + } + + return retval; +} + +/* Send an ASCONF chunk with Add IP address parameters to all the peers of the + * associations that are part of the endpoint indicating that a list of local + * addresses are added to the endpoint. + * + * If any of the addresses is already in the bind address list of the + * association, we do not send the chunk for that association. But it will not + * affect other associations. + * + * Only sctp_setsockopt_bindx() is supposed to call this function. + */ +static int sctp_send_asconf_add_ip(struct sock *sk, + struct sockaddr *addrs, + int addrcnt) +{ + struct sctp_sock *sp; + struct sctp_endpoint *ep; + struct sctp_association *asoc; + struct sctp_bind_addr *bp; + struct sctp_chunk *chunk; + struct sctp_sockaddr_entry *laddr; + union sctp_addr *addr; + union sctp_addr saveaddr; + void *addr_buf; + struct sctp_af *af; + struct list_head *p; + int i; + int retval = 0; + + sp = sctp_sk(sk); + ep = sp->ep; + + if (!ep->asconf_enable) + return retval; + + pr_debug("%s: sk:%p, addrs:%p, addrcnt:%d\n", + __func__, sk, addrs, addrcnt); + + list_for_each_entry(asoc, &ep->asocs, asocs) { + if (!asoc->peer.asconf_capable) + continue; + + if (asoc->peer.addip_disabled_mask & SCTP_PARAM_ADD_IP) + continue; + + if (!sctp_state(asoc, ESTABLISHED)) + continue; + + /* Check if any address in the packed array of addresses is + * in the bind address list of the association. If so, + * do not send the asconf chunk to its peer, but continue with + * other associations. + */ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + addr = addr_buf; + af = sctp_get_af_specific(addr->v4.sin_family); + if (!af) { + retval = -EINVAL; + goto out; + } + + if (sctp_assoc_lookup_laddr(asoc, addr)) + break; + + addr_buf += af->sockaddr_len; + } + if (i < addrcnt) + continue; + + /* Use the first valid address in bind addr list of + * association as Address Parameter of ASCONF CHUNK. + */ + bp = &asoc->base.bind_addr; + p = bp->address_list.next; + laddr = list_entry(p, struct sctp_sockaddr_entry, list); + chunk = sctp_make_asconf_update_ip(asoc, &laddr->a, addrs, + addrcnt, SCTP_PARAM_ADD_IP); + if (!chunk) { + retval = -ENOMEM; + goto out; + } + + /* Add the new addresses to the bind address list with + * use_as_src set to 0. 
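+ *
+ * (When src_out_of_asoc_ok is set, the transport reset a little
+ * further below simply re-applies the RFC 4960, Section 7.2.1
+ * initial values: cwnd = min(4*MTU, max(2*MTU, 4380 bytes)),
+ * ssthresh = the peer's a_rwnd, and RTO = RTO.Initial.)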
+ */ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + addr = addr_buf; + af = sctp_get_af_specific(addr->v4.sin_family); + memcpy(&saveaddr, addr, af->sockaddr_len); + retval = sctp_add_bind_addr(bp, &saveaddr, + sizeof(saveaddr), + SCTP_ADDR_NEW, GFP_ATOMIC); + addr_buf += af->sockaddr_len; + } + if (asoc->src_out_of_asoc_ok) { + struct sctp_transport *trans; + + list_for_each_entry(trans, + &asoc->peer.transport_addr_list, transports) { + trans->cwnd = min(4*asoc->pathmtu, max_t(__u32, + 2*asoc->pathmtu, 4380)); + trans->ssthresh = asoc->peer.i.a_rwnd; + trans->rto = asoc->rto_initial; + sctp_max_rto(asoc, trans); + trans->rtt = trans->srtt = trans->rttvar = 0; + /* Clear the source and route cache */ + sctp_transport_route(trans, NULL, + sctp_sk(asoc->base.sk)); + } + } + retval = sctp_send_asconf(asoc, chunk); + } + +out: + return retval; +} + +/* Remove a list of addresses from bind addresses list. Do not remove the + * last address. + * + * Basically run through each address specified in the addrs/addrcnt + * array/length pair, determine if it is IPv6 or IPv4 and call + * sctp_del_bind() on it. + * + * If any of them fails, then the operation will be reversed and the + * ones that were removed will be added back. + * + * At least one address has to be left; if only one address is + * available, the operation will return -EBUSY. + * + * Only sctp_setsockopt_bindx() is supposed to call this function. + */ +static int sctp_bindx_rem(struct sock *sk, struct sockaddr *addrs, int addrcnt) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_endpoint *ep = sp->ep; + int cnt; + struct sctp_bind_addr *bp = &ep->base.bind_addr; + int retval = 0; + void *addr_buf; + union sctp_addr *sa_addr; + struct sctp_af *af; + + pr_debug("%s: sk:%p, addrs:%p, addrcnt:%d\n", + __func__, sk, addrs, addrcnt); + + addr_buf = addrs; + for (cnt = 0; cnt < addrcnt; cnt++) { + /* If the bind address list is empty or if there is only one + * bind address, there is nothing more to be removed (we need + * at least one address here). + */ + if (list_empty(&bp->address_list) || + (sctp_list_single_entry(&bp->address_list))) { + retval = -EBUSY; + goto err_bindx_rem; + } + + sa_addr = addr_buf; + af = sctp_get_af_specific(sa_addr->sa.sa_family); + if (!af) { + retval = -EINVAL; + goto err_bindx_rem; + } + + if (!af->addr_valid(sa_addr, sp, NULL)) { + retval = -EADDRNOTAVAIL; + goto err_bindx_rem; + } + + if (sa_addr->v4.sin_port && + sa_addr->v4.sin_port != htons(bp->port)) { + retval = -EINVAL; + goto err_bindx_rem; + } + + if (!sa_addr->v4.sin_port) + sa_addr->v4.sin_port = htons(bp->port); + + /* FIXME - There is probably a need to check if sk->sk_saddr and + * sk->sk_rcv_addr are currently set to one of the addresses to + * be removed. This is something which needs to be looked into + * when we are fixing the outstanding issues with multi-homing + * socket routing and failover schemes. Refer to comments in + * sctp_do_bind(). -daisy + */ + retval = sctp_del_bind_addr(bp, sa_addr); + + addr_buf += af->sockaddr_len; +err_bindx_rem: + if (retval < 0) { + /* Failed. Add the ones that has been removed back */ + if (cnt > 0) + sctp_bindx_add(sk, addrs, cnt); + return retval; + } + } + + return retval; +} + +/* Send an ASCONF chunk with Delete IP address parameters to all the peers of + * the associations that are part of the endpoint indicating that a list of + * local addresses are removed from the endpoint. 
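+ *
+ * (An ASCONF may never remove an endpoint's last remaining address;
+ * RFC 5061 even defines a dedicated error cause, "Request to Delete
+ * Last Remaining IP Address", for that case. This is why the code
+ * below keeps one address back and, when only a single address was
+ * passed in, defers it via asconf_addr_del_pending instead of
+ * building a normal Delete IP parameter.)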
+ * + * If any of the addresses is already in the bind address list of the + * association, we do not send the chunk for that association. But it will not + * affect other associations. + * + * Only sctp_setsockopt_bindx() is supposed to call this function. + */ +static int sctp_send_asconf_del_ip(struct sock *sk, + struct sockaddr *addrs, + int addrcnt) +{ + struct sctp_sock *sp; + struct sctp_endpoint *ep; + struct sctp_association *asoc; + struct sctp_transport *transport; + struct sctp_bind_addr *bp; + struct sctp_chunk *chunk; + union sctp_addr *laddr; + void *addr_buf; + struct sctp_af *af; + struct sctp_sockaddr_entry *saddr; + int i; + int retval = 0; + int stored = 0; + + chunk = NULL; + sp = sctp_sk(sk); + ep = sp->ep; + + if (!ep->asconf_enable) + return retval; + + pr_debug("%s: sk:%p, addrs:%p, addrcnt:%d\n", + __func__, sk, addrs, addrcnt); + + list_for_each_entry(asoc, &ep->asocs, asocs) { + + if (!asoc->peer.asconf_capable) + continue; + + if (asoc->peer.addip_disabled_mask & SCTP_PARAM_DEL_IP) + continue; + + if (!sctp_state(asoc, ESTABLISHED)) + continue; + + /* Check if any address in the packed array of addresses is + * not present in the bind address list of the association. + * If so, do not send the asconf chunk to its peer, but + * continue with other associations. + */ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + laddr = addr_buf; + af = sctp_get_af_specific(laddr->v4.sin_family); + if (!af) { + retval = -EINVAL; + goto out; + } + + if (!sctp_assoc_lookup_laddr(asoc, laddr)) + break; + + addr_buf += af->sockaddr_len; + } + if (i < addrcnt) + continue; + + /* Find one address in the association's bind address list + * that is not in the packed array of addresses. This is to + * make sure that we do not delete all the addresses in the + * association. + */ + bp = &asoc->base.bind_addr; + laddr = sctp_find_unmatch_addr(bp, (union sctp_addr *)addrs, + addrcnt, sp); + if ((laddr == NULL) && (addrcnt == 1)) { + if (asoc->asconf_addr_del_pending) + continue; + asoc->asconf_addr_del_pending = + kzalloc(sizeof(union sctp_addr), GFP_ATOMIC); + if (asoc->asconf_addr_del_pending == NULL) { + retval = -ENOMEM; + goto out; + } + asoc->asconf_addr_del_pending->sa.sa_family = + addrs->sa_family; + asoc->asconf_addr_del_pending->v4.sin_port = + htons(bp->port); + if (addrs->sa_family == AF_INET) { + struct sockaddr_in *sin; + + sin = (struct sockaddr_in *)addrs; + asoc->asconf_addr_del_pending->v4.sin_addr.s_addr = sin->sin_addr.s_addr; + } else if (addrs->sa_family == AF_INET6) { + struct sockaddr_in6 *sin6; + + sin6 = (struct sockaddr_in6 *)addrs; + asoc->asconf_addr_del_pending->v6.sin6_addr = sin6->sin6_addr; + } + + pr_debug("%s: keep the last address asoc:%p %pISc at %p\n", + __func__, asoc, &asoc->asconf_addr_del_pending->sa, + asoc->asconf_addr_del_pending); + + asoc->src_out_of_asoc_ok = 1; + stored = 1; + goto skip_mkasconf; + } + + if (laddr == NULL) + return -EINVAL; + + /* We do not need RCU protection throughout this loop + * because this is done under a socket lock from the + * setsockopt call. + */ + chunk = sctp_make_asconf_update_ip(asoc, laddr, addrs, addrcnt, + SCTP_PARAM_DEL_IP); + if (!chunk) { + retval = -ENOMEM; + goto out; + } + +skip_mkasconf: + /* Reset use_as_src flag for the addresses in the bind address + * list that are to be deleted. 
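+ *
+ * Note that setting SCTP_ADDR_DEL here only stops these entries from
+ * being chosen as source addresses; they are actually removed from
+ * the bind address list when the peer's ASCONF-ACK for this chunk is
+ * processed.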
+ */ + addr_buf = addrs; + for (i = 0; i < addrcnt; i++) { + laddr = addr_buf; + af = sctp_get_af_specific(laddr->v4.sin_family); + list_for_each_entry(saddr, &bp->address_list, list) { + if (sctp_cmp_addr_exact(&saddr->a, laddr)) + saddr->state = SCTP_ADDR_DEL; + } + addr_buf += af->sockaddr_len; + } + + /* Update the route and saddr entries for all the transports + * as some of the addresses in the bind address list are + * about to be deleted and cannot be used as source addresses. + */ + list_for_each_entry(transport, &asoc->peer.transport_addr_list, + transports) { + sctp_transport_route(transport, NULL, + sctp_sk(asoc->base.sk)); + } + + if (stored) + /* We don't need to transmit ASCONF */ + continue; + retval = sctp_send_asconf(asoc, chunk); + } +out: + return retval; +} + +/* set addr events to assocs in the endpoint. ep and addr_wq must be locked */ +int sctp_asconf_mgmt(struct sctp_sock *sp, struct sctp_sockaddr_entry *addrw) +{ + struct sock *sk = sctp_opt2sk(sp); + union sctp_addr *addr; + struct sctp_af *af; + + /* It is safe to write port space in caller. */ + addr = &addrw->a; + addr->v4.sin_port = htons(sp->ep->base.bind_addr.port); + af = sctp_get_af_specific(addr->sa.sa_family); + if (!af) + return -EINVAL; + if (sctp_verify_addr(sk, addr, af->sockaddr_len)) + return -EINVAL; + + if (addrw->state == SCTP_ADDR_NEW) + return sctp_send_asconf_add_ip(sk, (struct sockaddr *)addr, 1); + else + return sctp_send_asconf_del_ip(sk, (struct sockaddr *)addr, 1); +} + +/* Helper for tunneling sctp_bindx() requests through sctp_setsockopt() + * + * API 8.1 + * int sctp_bindx(int sd, struct sockaddr *addrs, int addrcnt, + * int flags); + * + * If sd is an IPv4 socket, the addresses passed must be IPv4 addresses. + * If the sd is an IPv6 socket, the addresses passed can either be IPv4 + * or IPv6 addresses. + * + * A single address may be specified as INADDR_ANY or IN6ADDR_ANY, see + * Section 3.1.2 for this usage. + * + * addrs is a pointer to an array of one or more socket addresses. Each + * address is contained in its appropriate structure (i.e. struct + * sockaddr_in or struct sockaddr_in6) the family of the address type + * must be used to distinguish the address length (note that this + * representation is termed a "packed array" of addresses). The caller + * specifies the number of addresses in the array with addrcnt. + * + * On success, sctp_bindx() returns 0. On failure, sctp_bindx() returns + * -1, and sets errno to the appropriate error code. + * + * For SCTP, the port given in each socket address must be the same, or + * sctp_bindx() will fail, setting errno to EINVAL. + * + * The flags parameter is formed from the bitwise OR of zero or more of + * the following currently defined flags: + * + * SCTP_BINDX_ADD_ADDR + * + * SCTP_BINDX_REM_ADDR + * + * SCTP_BINDX_ADD_ADDR directs SCTP to add the given addresses to the + * association, and SCTP_BINDX_REM_ADDR directs SCTP to remove the given + * addresses from the association. The two flags are mutually exclusive; + * if both are given, sctp_bindx() will fail with EINVAL. A caller may + * not remove all addresses from an association; sctp_bindx() will + * reject such an attempt with EINVAL. + * + * An application can use sctp_bindx(SCTP_BINDX_ADD_ADDR) to associate + * additional addresses with an endpoint after calling bind(). Or use + * sctp_bindx(SCTP_BINDX_REM_ADDR) to remove some addresses a listening + * socket is associated with so that no new association accepted will be + * associated with those addresses. 
If the endpoint supports dynamic + * address a SCTP_BINDX_REM_ADDR or SCTP_BINDX_ADD_ADDR may cause a + * endpoint to send the appropriate message to the peer to change the + * peers address lists. + * + * Adding and removing addresses from a connected association is + * optional functionality. Implementations that do not support this + * functionality should return EOPNOTSUPP. + * + * Basically do nothing but copying the addresses from user to kernel + * land and invoking either sctp_bindx_add() or sctp_bindx_rem() on the sk. + * This is used for tunneling the sctp_bindx() request through sctp_setsockopt() + * from userspace. + * + * On exit there is no need to do sockfd_put(), sys_setsockopt() does + * it. + * + * sk The sk of the socket + * addrs The pointer to the addresses + * addrssize Size of the addrs buffer + * op Operation to perform (add or remove, see the flags of + * sctp_bindx) + * + * Returns 0 if ok, <0 errno code on error. + */ +static int sctp_setsockopt_bindx(struct sock *sk, struct sockaddr *addrs, + int addrs_size, int op) +{ + int err; + int addrcnt = 0; + int walk_size = 0; + struct sockaddr *sa_addr; + void *addr_buf = addrs; + struct sctp_af *af; + + pr_debug("%s: sk:%p addrs:%p addrs_size:%d opt:%d\n", + __func__, sk, addr_buf, addrs_size, op); + + if (unlikely(addrs_size <= 0)) + return -EINVAL; + + /* Walk through the addrs buffer and count the number of addresses. */ + while (walk_size < addrs_size) { + if (walk_size + sizeof(sa_family_t) > addrs_size) + return -EINVAL; + + sa_addr = addr_buf; + af = sctp_get_af_specific(sa_addr->sa_family); + + /* If the address family is not supported or if this address + * causes the address buffer to overflow return EINVAL. + */ + if (!af || (walk_size + af->sockaddr_len) > addrs_size) + return -EINVAL; + addrcnt++; + addr_buf += af->sockaddr_len; + walk_size += af->sockaddr_len; + } + + /* Do the work. */ + switch (op) { + case SCTP_BINDX_ADD_ADDR: + /* Allow security module to validate bindx addresses. 
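+ *
+ * User space normally reaches this switch through the setsockopt()
+ * tunnel described above; roughly (the call itself is only an
+ * illustration, the option value is the same one passed to the
+ * security hook below):
+ *
+ *	setsockopt(sd, IPPROTO_SCTP, SCTP_SOCKOPT_BINDX_ADD,
+ *		   addrs, addrs_size);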
*/ + err = security_sctp_bind_connect(sk, SCTP_SOCKOPT_BINDX_ADD, + addrs, addrs_size); + if (err) + return err; + err = sctp_bindx_add(sk, addrs, addrcnt); + if (err) + return err; + return sctp_send_asconf_add_ip(sk, addrs, addrcnt); + case SCTP_BINDX_REM_ADDR: + err = sctp_bindx_rem(sk, addrs, addrcnt); + if (err) + return err; + return sctp_send_asconf_del_ip(sk, addrs, addrcnt); + + default: + return -EINVAL; + } +} + +static int sctp_bind_add(struct sock *sk, struct sockaddr *addrs, + int addrlen) +{ + int err; + + lock_sock(sk); + err = sctp_setsockopt_bindx(sk, addrs, addrlen, SCTP_BINDX_ADD_ADDR); + release_sock(sk); + return err; +} + +static int sctp_connect_new_asoc(struct sctp_endpoint *ep, + const union sctp_addr *daddr, + const struct sctp_initmsg *init, + struct sctp_transport **tp) +{ + struct sctp_association *asoc; + struct sock *sk = ep->base.sk; + struct net *net = sock_net(sk); + enum sctp_scope scope; + int err; + + if (sctp_endpoint_is_peeled_off(ep, daddr)) + return -EADDRNOTAVAIL; + + if (!ep->base.bind_addr.port) { + if (sctp_autobind(sk)) + return -EAGAIN; + } else { + if (inet_port_requires_bind_service(net, ep->base.bind_addr.port) && + !ns_capable(net->user_ns, CAP_NET_BIND_SERVICE)) + return -EACCES; + } + + scope = sctp_scope(daddr); + asoc = sctp_association_new(ep, sk, scope, GFP_KERNEL); + if (!asoc) + return -ENOMEM; + + err = sctp_assoc_set_bind_addr_from_ep(asoc, scope, GFP_KERNEL); + if (err < 0) + goto free; + + *tp = sctp_assoc_add_peer(asoc, daddr, GFP_KERNEL, SCTP_UNKNOWN); + if (!*tp) { + err = -ENOMEM; + goto free; + } + + if (!init) + return 0; + + if (init->sinit_num_ostreams) { + __u16 outcnt = init->sinit_num_ostreams; + + asoc->c.sinit_num_ostreams = outcnt; + /* outcnt has been changed, need to re-init stream */ + err = sctp_stream_init(&asoc->stream, outcnt, 0, GFP_KERNEL); + if (err) + goto free; + } + + if (init->sinit_max_instreams) + asoc->c.sinit_max_instreams = init->sinit_max_instreams; + + if (init->sinit_max_attempts) + asoc->max_init_attempts = init->sinit_max_attempts; + + if (init->sinit_max_init_timeo) + asoc->max_init_timeo = + msecs_to_jiffies(init->sinit_max_init_timeo); + + return 0; +free: + sctp_association_free(asoc); + return err; +} + +static int sctp_connect_add_peer(struct sctp_association *asoc, + union sctp_addr *daddr, int addr_len) +{ + struct sctp_endpoint *ep = asoc->ep; + struct sctp_association *old; + struct sctp_transport *t; + int err; + + err = sctp_verify_addr(ep->base.sk, daddr, addr_len); + if (err) + return err; + + old = sctp_endpoint_lookup_assoc(ep, daddr, &t); + if (old && old != asoc) + return old->state >= SCTP_STATE_ESTABLISHED ? -EISCONN + : -EALREADY; + + if (sctp_endpoint_is_peeled_off(ep, daddr)) + return -EADDRNOTAVAIL; + + t = sctp_assoc_add_peer(asoc, daddr, GFP_KERNEL, SCTP_UNKNOWN); + if (!t) + return -ENOMEM; + + return 0; +} + +/* __sctp_connect(struct sock* sk, struct sockaddr *kaddrs, int addrs_size) + * + * Common routine for handling connect() and sctp_connectx(). + * Connect will come in with just a single address. 
+ */ +static int __sctp_connect(struct sock *sk, struct sockaddr *kaddrs, + int addrs_size, int flags, sctp_assoc_t *assoc_id) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_endpoint *ep = sp->ep; + struct sctp_transport *transport; + struct sctp_association *asoc; + void *addr_buf = kaddrs; + union sctp_addr *daddr; + struct sctp_af *af; + int walk_size, err; + long timeo; + + if (sctp_sstate(sk, ESTABLISHED) || sctp_sstate(sk, CLOSING) || + (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING))) + return -EISCONN; + + daddr = addr_buf; + af = sctp_get_af_specific(daddr->sa.sa_family); + if (!af || af->sockaddr_len > addrs_size) + return -EINVAL; + + err = sctp_verify_addr(sk, daddr, af->sockaddr_len); + if (err) + return err; + + asoc = sctp_endpoint_lookup_assoc(ep, daddr, &transport); + if (asoc) + return asoc->state >= SCTP_STATE_ESTABLISHED ? -EISCONN + : -EALREADY; + + err = sctp_connect_new_asoc(ep, daddr, NULL, &transport); + if (err) + return err; + asoc = transport->asoc; + + addr_buf += af->sockaddr_len; + walk_size = af->sockaddr_len; + while (walk_size < addrs_size) { + err = -EINVAL; + if (walk_size + sizeof(sa_family_t) > addrs_size) + goto out_free; + + daddr = addr_buf; + af = sctp_get_af_specific(daddr->sa.sa_family); + if (!af || af->sockaddr_len + walk_size > addrs_size) + goto out_free; + + if (asoc->peer.port != ntohs(daddr->v4.sin_port)) + goto out_free; + + err = sctp_connect_add_peer(asoc, daddr, af->sockaddr_len); + if (err) + goto out_free; + + addr_buf += af->sockaddr_len; + walk_size += af->sockaddr_len; + } + + /* In case the user of sctp_connectx() wants an association + * id back, assign one now. + */ + if (assoc_id) { + err = sctp_assoc_set_id(asoc, GFP_KERNEL); + if (err < 0) + goto out_free; + } + + err = sctp_primitive_ASSOCIATE(sock_net(sk), asoc, NULL); + if (err < 0) + goto out_free; + + /* Initialize sk's dport and daddr for getpeername() */ + inet_sk(sk)->inet_dport = htons(asoc->peer.port); + sp->pf->to_sk_daddr(daddr, sk); + sk->sk_err = 0; + + if (assoc_id) + *assoc_id = asoc->assoc_id; + + timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); + return sctp_wait_for_connect(asoc, &timeo); + +out_free: + pr_debug("%s: took out_free path with asoc:%p kaddrs:%p err:%d\n", + __func__, asoc, kaddrs, err); + sctp_association_free(asoc); + return err; +} + +/* Helper for tunneling sctp_connectx() requests through sctp_setsockopt() + * + * API 8.9 + * int sctp_connectx(int sd, struct sockaddr *addrs, int addrcnt, + * sctp_assoc_t *asoc); + * + * If sd is an IPv4 socket, the addresses passed must be IPv4 addresses. + * If the sd is an IPv6 socket, the addresses passed can either be IPv4 + * or IPv6 addresses. + * + * A single address may be specified as INADDR_ANY or IN6ADDR_ANY, see + * Section 3.1.2 for this usage. + * + * addrs is a pointer to an array of one or more socket addresses. Each + * address is contained in its appropriate structure (i.e. struct + * sockaddr_in or struct sockaddr_in6) the family of the address type + * must be used to distengish the address length (note that this + * representation is termed a "packed array" of addresses). The caller + * specifies the number of addresses in the array with addrcnt. + * + * On success, sctp_connectx() returns 0. It also sets the assoc_id to + * the association id of the new association. On failure, sctp_connectx() + * returns -1, and sets errno to the appropriate error code. The assoc_id + * is not touched by the kernel. 
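+ *
+ * A minimal user-space sketch of such a call (addresses and port are
+ * hypothetical, and sctp_connectx() is the lksctp-tools wrapper
+ * rather than a symbol in this file):
+ *
+ *	struct sockaddr_in peers[2] = {
+ *		{ .sin_family = AF_INET, .sin_port = htons(5001) },
+ *		{ .sin_family = AF_INET, .sin_port = htons(5001) },
+ *	};
+ *	sctp_assoc_t id;
+ *
+ *	inet_pton(AF_INET, "192.0.2.1", &peers[0].sin_addr);
+ *	inet_pton(AF_INET, "192.0.2.2", &peers[1].sin_addr);
+ *	if (sctp_connectx(sd, (struct sockaddr *)peers, 2, &id) < 0)
+ *		perror("sctp_connectx");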
+ * + * For SCTP, the port given in each socket address must be the same, or + * sctp_connectx() will fail, setting errno to EINVAL. + * + * An application can use sctp_connectx to initiate an association with + * an endpoint that is multi-homed. Much like sctp_bindx() this call + * allows a caller to specify multiple addresses at which a peer can be + * reached. The way the SCTP stack uses the list of addresses to set up + * the association is implementation dependent. This function only + * specifies that the stack will try to make use of all the addresses in + * the list when needed. + * + * Note that the list of addresses passed in is only used for setting up + * the association. It does not necessarily equal the set of addresses + * the peer uses for the resulting association. If the caller wants to + * find out the set of peer addresses, it must use sctp_getpaddrs() to + * retrieve them after the association has been set up. + * + * Basically do nothing but copying the addresses from user to kernel + * land and invoking either sctp_connectx(). This is used for tunneling + * the sctp_connectx() request through sctp_setsockopt() from userspace. + * + * On exit there is no need to do sockfd_put(), sys_setsockopt() does + * it. + * + * sk The sk of the socket + * addrs The pointer to the addresses + * addrssize Size of the addrs buffer + * + * Returns >=0 if ok, <0 errno code on error. + */ +static int __sctp_setsockopt_connectx(struct sock *sk, struct sockaddr *kaddrs, + int addrs_size, sctp_assoc_t *assoc_id) +{ + int err = 0, flags = 0; + + pr_debug("%s: sk:%p addrs:%p addrs_size:%d\n", + __func__, sk, kaddrs, addrs_size); + + /* make sure the 1st addr's sa_family is accessible later */ + if (unlikely(addrs_size < sizeof(sa_family_t))) + return -EINVAL; + + /* Allow security module to validate connectx addresses. */ + err = security_sctp_bind_connect(sk, SCTP_SOCKOPT_CONNECTX, + (struct sockaddr *)kaddrs, + addrs_size); + if (err) + return err; + + /* in-kernel sockets don't generally have a file allocated to them + * if all they do is call sock_create_kern(). + */ + if (sk->sk_socket->file) + flags = sk->sk_socket->file->f_flags; + + return __sctp_connect(sk, kaddrs, addrs_size, flags, assoc_id); +} + +/* + * This is an older interface. It's kept for backward compatibility + * to the option that doesn't provide association id. + */ +static int sctp_setsockopt_connectx_old(struct sock *sk, + struct sockaddr *kaddrs, + int addrs_size) +{ + return __sctp_setsockopt_connectx(sk, kaddrs, addrs_size, NULL); +} + +/* + * New interface for the API. The since the API is done with a socket + * option, to make it simple we feed back the association id is as a return + * indication to the call. Error is always negative and association id is + * always positive. + */ +static int sctp_setsockopt_connectx(struct sock *sk, + struct sockaddr *kaddrs, + int addrs_size) +{ + sctp_assoc_t assoc_id = 0; + int err = 0; + + err = __sctp_setsockopt_connectx(sk, kaddrs, addrs_size, &assoc_id); + + if (err) + return err; + else + return assoc_id; +} + +/* + * New (hopefully final) interface for the API. + * We use the sctp_getaddrs_old structure so that use-space library + * can avoid any unnecessary allocations. The only different part + * is that we store the actual length of the address buffer into the + * addrs_num structure member. That way we can re-use the existing + * code. 
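+ *
+ * (The CONFIG_COMPAT variant below exists because a 32-bit user space
+ * stores the addrs pointer as a 32-bit compat_uptr_t, so the structure
+ * layout differs from the native one and has to be converted
+ * explicitly.)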
+ */ +#ifdef CONFIG_COMPAT +struct compat_sctp_getaddrs_old { + sctp_assoc_t assoc_id; + s32 addr_num; + compat_uptr_t addrs; /* struct sockaddr * */ +}; +#endif + +static int sctp_getsockopt_connectx3(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_getaddrs_old param; + sctp_assoc_t assoc_id = 0; + struct sockaddr *kaddrs; + int err = 0; + +#ifdef CONFIG_COMPAT + if (in_compat_syscall()) { + struct compat_sctp_getaddrs_old param32; + + if (len < sizeof(param32)) + return -EINVAL; + if (copy_from_user(¶m32, optval, sizeof(param32))) + return -EFAULT; + + param.assoc_id = param32.assoc_id; + param.addr_num = param32.addr_num; + param.addrs = compat_ptr(param32.addrs); + } else +#endif + { + if (len < sizeof(param)) + return -EINVAL; + if (copy_from_user(¶m, optval, sizeof(param))) + return -EFAULT; + } + + kaddrs = memdup_user(param.addrs, param.addr_num); + if (IS_ERR(kaddrs)) + return PTR_ERR(kaddrs); + + err = __sctp_setsockopt_connectx(sk, kaddrs, param.addr_num, &assoc_id); + kfree(kaddrs); + if (err == 0 || err == -EINPROGRESS) { + if (copy_to_user(optval, &assoc_id, sizeof(assoc_id))) + return -EFAULT; + if (put_user(sizeof(assoc_id), optlen)) + return -EFAULT; + } + + return err; +} + +/* API 3.1.4 close() - UDP Style Syntax + * Applications use close() to perform graceful shutdown (as described in + * Section 10.1 of [SCTP]) on ALL the associations currently represented + * by a UDP-style socket. + * + * The syntax is + * + * ret = close(int sd); + * + * sd - the socket descriptor of the associations to be closed. + * + * To gracefully shutdown a specific association represented by the + * UDP-style socket, an application should use the sendmsg() call, + * passing no user data, but including the appropriate flag in the + * ancillary data (see Section xxxx). + * + * If sd in the close() call is a branched-off socket representing only + * one association, the shutdown is performed on that association only. + * + * 4.1.6 close() - TCP Style Syntax + * + * Applications use close() to gracefully close down an association. + * + * The syntax is: + * + * int close(int sd); + * + * sd - the socket descriptor of the association to be closed. + * + * After an application calls close() on a socket descriptor, no further + * socket operations will succeed on that descriptor. + * + * API 7.1.4 SO_LINGER + * + * An application using the TCP-style socket can use this option to + * perform the SCTP ABORT primitive. The linger option structure is: + * + * struct linger { + * int l_onoff; // option on/off + * int l_linger; // linger time + * }; + * + * To enable the option, set l_onoff to 1. If the l_linger value is set + * to 0, calling close() is the same as the ABORT primitive. If the + * value is set to a negative value, the setsockopt() call will return + * an error. If the value is set to a positive value linger_time, the + * close() can be blocked for at most linger_time ms. If the graceful + * shutdown phase does not finish during this period, close() will + * return but the graceful shutdown phase continues in the system. 
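+ *
+ * For instance (illustrative only), a TCP-style application that wants
+ * close() to behave as ABORT rather than as a graceful SHUTDOWN could
+ * do:
+ *
+ *	struct linger lin = { .l_onoff = 1, .l_linger = 0 };
+ *
+ *	setsockopt(sd, SOL_SOCKET, SO_LINGER, &lin, sizeof(lin));
+ *	close(sd);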
+ */ +static void sctp_close(struct sock *sk, long timeout) +{ + struct net *net = sock_net(sk); + struct sctp_endpoint *ep; + struct sctp_association *asoc; + struct list_head *pos, *temp; + unsigned int data_was_unread; + + pr_debug("%s: sk:%p, timeout:%ld\n", __func__, sk, timeout); + + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + sk->sk_shutdown = SHUTDOWN_MASK; + inet_sk_set_state(sk, SCTP_SS_CLOSING); + + ep = sctp_sk(sk)->ep; + + /* Clean up any skbs sitting on the receive queue. */ + data_was_unread = sctp_queue_purge_ulpevents(&sk->sk_receive_queue); + data_was_unread += sctp_queue_purge_ulpevents(&sctp_sk(sk)->pd_lobby); + + /* Walk all associations on an endpoint. */ + list_for_each_safe(pos, temp, &ep->asocs) { + asoc = list_entry(pos, struct sctp_association, asocs); + + if (sctp_style(sk, TCP)) { + /* A closed association can still be in the list if + * it belongs to a TCP-style listening socket that is + * not yet accepted. If so, free it. If not, send an + * ABORT or SHUTDOWN based on the linger options. + */ + if (sctp_state(asoc, CLOSED)) { + sctp_association_free(asoc); + continue; + } + } + + if (data_was_unread || !skb_queue_empty(&asoc->ulpq.lobby) || + !skb_queue_empty(&asoc->ulpq.reasm) || + !skb_queue_empty(&asoc->ulpq.reasm_uo) || + (sock_flag(sk, SOCK_LINGER) && !sk->sk_lingertime)) { + struct sctp_chunk *chunk; + + chunk = sctp_make_abort_user(asoc, NULL, 0); + sctp_primitive_ABORT(net, asoc, chunk); + } else + sctp_primitive_SHUTDOWN(net, asoc, NULL); + } + + /* On a TCP-style socket, block for at most linger_time if set. */ + if (sctp_style(sk, TCP) && timeout) + sctp_wait_for_close(sk, timeout); + + /* This will run the backlog queue. */ + release_sock(sk); + + /* Supposedly, no process has access to the socket, but + * the net layers still may. + * Also, sctp_destroy_sock() needs to be called with addr_wq_lock + * held and that should be grabbed before socket lock. + */ + spin_lock_bh(&net->sctp.addr_wq_lock); + bh_lock_sock_nested(sk); + + /* Hold the sock, since sk_common_release() will put sock_put() + * and we have just a little more cleanup. + */ + sock_hold(sk); + sk_common_release(sk); + + bh_unlock_sock(sk); + spin_unlock_bh(&net->sctp.addr_wq_lock); + + sock_put(sk); + + SCTP_DBG_OBJCNT_DEC(sock); +} + +/* Handle EPIPE error. */ +static int sctp_error(struct sock *sk, int flags, int err) +{ + if (err == -EPIPE) + err = sock_error(sk) ? : -EPIPE; + if (err == -EPIPE && !(flags & MSG_NOSIGNAL)) + send_sig(SIGPIPE, current, 0); + return err; +} + +/* API 3.1.3 sendmsg() - UDP Style Syntax + * + * An application uses sendmsg() and recvmsg() calls to transmit data to + * and receive data from its peer. + * + * ssize_t sendmsg(int socket, const struct msghdr *message, + * int flags); + * + * socket - the socket descriptor of the endpoint. + * message - pointer to the msghdr structure which contains a single + * user message and possibly some ancillary data. + * + * See Section 5 for complete description of the data + * structures. + * + * flags - flags sent or received with the user message, see Section + * 5 for complete description of the flags. + * + * Note: This function could use a rewrite especially when explicit + * connect support comes in. + */ +/* BUG: We do not implement the equivalent of sk_stream_wait_memory(). 
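+ *
+ * As an illustrative user-space sketch of the sendmsg() syntax
+ * documented above (stream number and payload are made up;
+ * SCTP_SNDINFO ancillary data is the RFC 6458 way to select a
+ * stream), a sender on an established one-to-one socket might do:
+ *
+ *	struct sctp_sndinfo snd = { .snd_sid = 1 };
+ *	char cbuf[CMSG_SPACE(sizeof(snd))] = { 0 };
+ *	struct iovec iov = { .iov_base = (void *)"hi", .iov_len = 2 };
+ *	struct msghdr mh = { .msg_iov = &iov, .msg_iovlen = 1,
+ *			     .msg_control = cbuf,
+ *			     .msg_controllen = sizeof(cbuf) };
+ *	struct cmsghdr *c = CMSG_FIRSTHDR(&mh);
+ *
+ *	c->cmsg_level = IPPROTO_SCTP;
+ *	c->cmsg_type = SCTP_SNDINFO;
+ *	c->cmsg_len = CMSG_LEN(sizeof(snd));
+ *	memcpy(CMSG_DATA(c), &snd, sizeof(snd));
+ *	sendmsg(sd, &mh, 0);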
*/ + +static int sctp_msghdr_parse(const struct msghdr *msg, + struct sctp_cmsgs *cmsgs); + +static int sctp_sendmsg_parse(struct sock *sk, struct sctp_cmsgs *cmsgs, + struct sctp_sndrcvinfo *srinfo, + const struct msghdr *msg, size_t msg_len) +{ + __u16 sflags; + int err; + + if (sctp_sstate(sk, LISTENING) && sctp_style(sk, TCP)) + return -EPIPE; + + if (msg_len > sk->sk_sndbuf) + return -EMSGSIZE; + + memset(cmsgs, 0, sizeof(*cmsgs)); + err = sctp_msghdr_parse(msg, cmsgs); + if (err) { + pr_debug("%s: msghdr parse err:%x\n", __func__, err); + return err; + } + + memset(srinfo, 0, sizeof(*srinfo)); + if (cmsgs->srinfo) { + srinfo->sinfo_stream = cmsgs->srinfo->sinfo_stream; + srinfo->sinfo_flags = cmsgs->srinfo->sinfo_flags; + srinfo->sinfo_ppid = cmsgs->srinfo->sinfo_ppid; + srinfo->sinfo_context = cmsgs->srinfo->sinfo_context; + srinfo->sinfo_assoc_id = cmsgs->srinfo->sinfo_assoc_id; + srinfo->sinfo_timetolive = cmsgs->srinfo->sinfo_timetolive; + } + + if (cmsgs->sinfo) { + srinfo->sinfo_stream = cmsgs->sinfo->snd_sid; + srinfo->sinfo_flags = cmsgs->sinfo->snd_flags; + srinfo->sinfo_ppid = cmsgs->sinfo->snd_ppid; + srinfo->sinfo_context = cmsgs->sinfo->snd_context; + srinfo->sinfo_assoc_id = cmsgs->sinfo->snd_assoc_id; + } + + if (cmsgs->prinfo) { + srinfo->sinfo_timetolive = cmsgs->prinfo->pr_value; + SCTP_PR_SET_POLICY(srinfo->sinfo_flags, + cmsgs->prinfo->pr_policy); + } + + sflags = srinfo->sinfo_flags; + if (!sflags && msg_len) + return 0; + + if (sctp_style(sk, TCP) && (sflags & (SCTP_EOF | SCTP_ABORT))) + return -EINVAL; + + if (((sflags & SCTP_EOF) && msg_len > 0) || + (!(sflags & (SCTP_EOF | SCTP_ABORT)) && msg_len == 0)) + return -EINVAL; + + if ((sflags & SCTP_ADDR_OVER) && !msg->msg_name) + return -EINVAL; + + return 0; +} + +static int sctp_sendmsg_new_asoc(struct sock *sk, __u16 sflags, + struct sctp_cmsgs *cmsgs, + union sctp_addr *daddr, + struct sctp_transport **tp) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + struct cmsghdr *cmsg; + __be32 flowinfo = 0; + struct sctp_af *af; + int err; + + *tp = NULL; + + if (sflags & (SCTP_EOF | SCTP_ABORT)) + return -EINVAL; + + if (sctp_style(sk, TCP) && (sctp_sstate(sk, ESTABLISHED) || + sctp_sstate(sk, CLOSING))) + return -EADDRNOTAVAIL; + + /* Label connection socket for first association 1-to-many + * style for client sequence socket()->sendmsg(). This + * needs to be done before sctp_assoc_add_peer() as that will + * set up the initial packet that needs to account for any + * security ip options (CIPSO/CALIPSO) added to the packet. 
+ */ + af = sctp_get_af_specific(daddr->sa.sa_family); + if (!af) + return -EINVAL; + err = security_sctp_bind_connect(sk, SCTP_SENDMSG_CONNECT, + (struct sockaddr *)daddr, + af->sockaddr_len); + if (err < 0) + return err; + + err = sctp_connect_new_asoc(ep, daddr, cmsgs->init, tp); + if (err) + return err; + asoc = (*tp)->asoc; + + if (!cmsgs->addrs_msg) + return 0; + + if (daddr->sa.sa_family == AF_INET6) + flowinfo = daddr->v6.sin6_flowinfo; + + /* sendv addr list parse */ + for_each_cmsghdr(cmsg, cmsgs->addrs_msg) { + union sctp_addr _daddr; + int dlen; + + if (cmsg->cmsg_level != IPPROTO_SCTP || + (cmsg->cmsg_type != SCTP_DSTADDRV4 && + cmsg->cmsg_type != SCTP_DSTADDRV6)) + continue; + + daddr = &_daddr; + memset(daddr, 0, sizeof(*daddr)); + dlen = cmsg->cmsg_len - sizeof(struct cmsghdr); + if (cmsg->cmsg_type == SCTP_DSTADDRV4) { + if (dlen < sizeof(struct in_addr)) { + err = -EINVAL; + goto free; + } + + dlen = sizeof(struct in_addr); + daddr->v4.sin_family = AF_INET; + daddr->v4.sin_port = htons(asoc->peer.port); + memcpy(&daddr->v4.sin_addr, CMSG_DATA(cmsg), dlen); + } else { + if (dlen < sizeof(struct in6_addr)) { + err = -EINVAL; + goto free; + } + + dlen = sizeof(struct in6_addr); + daddr->v6.sin6_flowinfo = flowinfo; + daddr->v6.sin6_family = AF_INET6; + daddr->v6.sin6_port = htons(asoc->peer.port); + memcpy(&daddr->v6.sin6_addr, CMSG_DATA(cmsg), dlen); + } + + err = sctp_connect_add_peer(asoc, daddr, sizeof(*daddr)); + if (err) + goto free; + } + + return 0; + +free: + sctp_association_free(asoc); + return err; +} + +static int sctp_sendmsg_check_sflags(struct sctp_association *asoc, + __u16 sflags, struct msghdr *msg, + size_t msg_len) +{ + struct sock *sk = asoc->base.sk; + struct net *net = sock_net(sk); + + if (sctp_state(asoc, CLOSED) && sctp_style(sk, TCP)) + return -EPIPE; + + if ((sflags & SCTP_SENDALL) && sctp_style(sk, UDP) && + !sctp_state(asoc, ESTABLISHED)) + return 0; + + if (sflags & SCTP_EOF) { + pr_debug("%s: shutting down association:%p\n", __func__, asoc); + sctp_primitive_SHUTDOWN(net, asoc, NULL); + + return 0; + } + + if (sflags & SCTP_ABORT) { + struct sctp_chunk *chunk; + + chunk = sctp_make_abort_user(asoc, msg, msg_len); + if (!chunk) + return -ENOMEM; + + pr_debug("%s: aborting association:%p\n", __func__, asoc); + sctp_primitive_ABORT(net, asoc, chunk); + iov_iter_revert(&msg->msg_iter, msg_len); + + return 0; + } + + return 1; +} + +static int sctp_sendmsg_to_asoc(struct sctp_association *asoc, + struct msghdr *msg, size_t msg_len, + struct sctp_transport *transport, + struct sctp_sndrcvinfo *sinfo) +{ + struct sock *sk = asoc->base.sk; + struct sctp_sock *sp = sctp_sk(sk); + struct net *net = sock_net(sk); + struct sctp_datamsg *datamsg; + bool wait_connect = false; + struct sctp_chunk *chunk; + long timeo; + int err; + + if (sinfo->sinfo_stream >= asoc->stream.outcnt) { + err = -EINVAL; + goto err; + } + + if (unlikely(!SCTP_SO(&asoc->stream, sinfo->sinfo_stream)->ext)) { + err = sctp_stream_init_ext(&asoc->stream, sinfo->sinfo_stream); + if (err) + goto err; + } + + if (sp->disable_fragments && msg_len > asoc->frag_point) { + err = -EMSGSIZE; + goto err; + } + + if (asoc->pmtu_pending) { + if (sp->param_flags & SPP_PMTUD_ENABLE) + sctp_assoc_sync_pmtu(asoc); + asoc->pmtu_pending = 0; + } + + if (sctp_wspace(asoc) < (int)msg_len) + sctp_prsctp_prune(asoc, sinfo, msg_len - sctp_wspace(asoc)); + + if (sctp_wspace(asoc) <= 0 || !sk_wmem_schedule(sk, msg_len)) { + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + err = 
sctp_wait_for_sndbuf(asoc, &timeo, msg_len); + if (err) + goto err; + if (unlikely(sinfo->sinfo_stream >= asoc->stream.outcnt)) { + err = -EINVAL; + goto err; + } + } + + if (sctp_state(asoc, CLOSED)) { + err = sctp_primitive_ASSOCIATE(net, asoc, NULL); + if (err) + goto err; + + if (asoc->ep->intl_enable) { + timeo = sock_sndtimeo(sk, 0); + err = sctp_wait_for_connect(asoc, &timeo); + if (err) { + err = -ESRCH; + goto err; + } + } else { + wait_connect = true; + } + + pr_debug("%s: we associated primitively\n", __func__); + } + + datamsg = sctp_datamsg_from_user(asoc, sinfo, &msg->msg_iter); + if (IS_ERR(datamsg)) { + err = PTR_ERR(datamsg); + goto err; + } + + asoc->force_delay = !!(msg->msg_flags & MSG_MORE); + + list_for_each_entry(chunk, &datamsg->chunks, frag_list) { + sctp_chunk_hold(chunk); + sctp_set_owner_w(chunk); + chunk->transport = transport; + } + + err = sctp_primitive_SEND(net, asoc, datamsg); + if (err) { + sctp_datamsg_free(datamsg); + goto err; + } + + pr_debug("%s: we sent primitively\n", __func__); + + sctp_datamsg_put(datamsg); + + if (unlikely(wait_connect)) { + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + sctp_wait_for_connect(asoc, &timeo); + } + + err = msg_len; + +err: + return err; +} + +static union sctp_addr *sctp_sendmsg_get_daddr(struct sock *sk, + const struct msghdr *msg, + struct sctp_cmsgs *cmsgs) +{ + union sctp_addr *daddr = NULL; + int err; + + if (!sctp_style(sk, UDP_HIGH_BANDWIDTH) && msg->msg_name) { + int len = msg->msg_namelen; + + if (len > sizeof(*daddr)) + len = sizeof(*daddr); + + daddr = (union sctp_addr *)msg->msg_name; + + err = sctp_verify_addr(sk, daddr, len); + if (err) + return ERR_PTR(err); + } + + return daddr; +} + +static void sctp_sendmsg_update_sinfo(struct sctp_association *asoc, + struct sctp_sndrcvinfo *sinfo, + struct sctp_cmsgs *cmsgs) +{ + if (!cmsgs->srinfo && !cmsgs->sinfo) { + sinfo->sinfo_stream = asoc->default_stream; + sinfo->sinfo_ppid = asoc->default_ppid; + sinfo->sinfo_context = asoc->default_context; + sinfo->sinfo_assoc_id = sctp_assoc2id(asoc); + + if (!cmsgs->prinfo) + sinfo->sinfo_flags = asoc->default_flags; + } + + if (!cmsgs->srinfo && !cmsgs->prinfo) + sinfo->sinfo_timetolive = asoc->default_timetolive; + + if (cmsgs->authinfo) { + /* Reuse sinfo_tsn to indicate that authinfo was set and + * sinfo_ssn to save the keyid on tx path. 
+ */ + sinfo->sinfo_tsn = 1; + sinfo->sinfo_ssn = cmsgs->authinfo->auth_keynumber; + } +} + +static int sctp_sendmsg(struct sock *sk, struct msghdr *msg, size_t msg_len) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_transport *transport = NULL; + struct sctp_sndrcvinfo _sinfo, *sinfo; + struct sctp_association *asoc, *tmp; + struct sctp_cmsgs cmsgs; + union sctp_addr *daddr; + bool new = false; + __u16 sflags; + int err; + + /* Parse and get snd_info */ + err = sctp_sendmsg_parse(sk, &cmsgs, &_sinfo, msg, msg_len); + if (err) + goto out; + + sinfo = &_sinfo; + sflags = sinfo->sinfo_flags; + + /* Get daddr from msg */ + daddr = sctp_sendmsg_get_daddr(sk, msg, &cmsgs); + if (IS_ERR(daddr)) { + err = PTR_ERR(daddr); + goto out; + } + + lock_sock(sk); + + /* SCTP_SENDALL process */ + if ((sflags & SCTP_SENDALL) && sctp_style(sk, UDP)) { + list_for_each_entry_safe(asoc, tmp, &ep->asocs, asocs) { + err = sctp_sendmsg_check_sflags(asoc, sflags, msg, + msg_len); + if (err == 0) + continue; + if (err < 0) + goto out_unlock; + + sctp_sendmsg_update_sinfo(asoc, sinfo, &cmsgs); + + err = sctp_sendmsg_to_asoc(asoc, msg, msg_len, + NULL, sinfo); + if (err < 0) + goto out_unlock; + + iov_iter_revert(&msg->msg_iter, err); + } + + goto out_unlock; + } + + /* Get and check or create asoc */ + if (daddr) { + asoc = sctp_endpoint_lookup_assoc(ep, daddr, &transport); + if (asoc) { + err = sctp_sendmsg_check_sflags(asoc, sflags, msg, + msg_len); + if (err <= 0) + goto out_unlock; + } else { + err = sctp_sendmsg_new_asoc(sk, sflags, &cmsgs, daddr, + &transport); + if (err) + goto out_unlock; + + asoc = transport->asoc; + new = true; + } + + if (!sctp_style(sk, TCP) && !(sflags & SCTP_ADDR_OVER)) + transport = NULL; + } else { + asoc = sctp_id2assoc(sk, sinfo->sinfo_assoc_id); + if (!asoc) { + err = -EPIPE; + goto out_unlock; + } + + err = sctp_sendmsg_check_sflags(asoc, sflags, msg, msg_len); + if (err <= 0) + goto out_unlock; + } + + /* Update snd_info with the asoc */ + sctp_sendmsg_update_sinfo(asoc, sinfo, &cmsgs); + + /* Send msg to the asoc */ + err = sctp_sendmsg_to_asoc(asoc, msg, msg_len, transport, sinfo); + if (err < 0 && err != -ESRCH && new) + sctp_association_free(asoc); + +out_unlock: + release_sock(sk); +out: + return sctp_error(sk, msg->msg_flags, err); +} + +/* This is an extended version of skb_pull() that removes the data from the + * start of a skb even when data is spread across the list of skb's in the + * frag_list. len specifies the total amount of data that needs to be removed. + * when 'len' bytes could be removed from the skb, it returns 0. + * If 'len' exceeds the total skb length, it returns the no. of bytes that + * could not be removed. + */ +static int sctp_skb_pull(struct sk_buff *skb, int len) +{ + struct sk_buff *list; + int skb_len = skb_headlen(skb); + int rlen; + + if (len <= skb_len) { + __skb_pull(skb, len); + return 0; + } + len -= skb_len; + __skb_pull(skb, skb_len); + + skb_walk_frags(skb, list) { + rlen = sctp_skb_pull(list, len); + skb->len -= (len-rlen); + skb->data_len -= (len-rlen); + + if (!rlen) + return 0; + + len = rlen; + } + + return len; +} + +/* API 3.1.3 recvmsg() - UDP Style Syntax + * + * ssize_t recvmsg(int socket, struct msghdr *message, + * int flags); + * + * socket - the socket descriptor of the endpoint. + * message - pointer to the msghdr structure which contains a single + * user message and possibly some ancillary data. + * + * See Section 5 for complete description of the data + * structures. 
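The sendmsg path above accepts extra peer addresses for a new association as SCTP_DSTADDRV4/SCTP_DSTADDRV6 ancillary data, copying a struct in_addr or struct in6_addr out of each cmsg. A minimal userspace sketch of that usage, assuming an already created one-to-many SCTP socket sd and the lksctp-tools <netinet/sctp.h> header; the helper name send_to_multihomed_peer is made up for illustration:

#include <stdint.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/sctp.h>

/* Illustrative only: send one message, naming a second IPv4 address of the
 * same (multi-homed) peer via an SCTP_DSTADDRV4 cmsg so the kernel adds it
 * as an extra transport when it creates the association. */
static ssize_t send_to_multihomed_peer(int sd, struct sockaddr_in *primary,
				       const struct in_addr *alt,
				       const char *buf, size_t len)
{
	char cbuf[CMSG_SPACE(sizeof(struct in_addr))];
	struct iovec iov = { .iov_base = (void *)buf, .iov_len = len };
	struct msghdr msg = {
		.msg_name	= primary,
		.msg_namelen	= sizeof(*primary),
		.msg_iov	= &iov,
		.msg_iovlen	= 1,
		.msg_control	= cbuf,
		.msg_controllen	= sizeof(cbuf),
	};
	struct cmsghdr *cmsg;

	memset(cbuf, 0, sizeof(cbuf));
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = IPPROTO_SCTP;
	cmsg->cmsg_type	 = SCTP_DSTADDRV4;	/* payload is a struct in_addr */
	cmsg->cmsg_len	 = CMSG_LEN(sizeof(struct in_addr));
	memcpy(CMSG_DATA(cmsg), alt, sizeof(struct in_addr));

	return sendmsg(sd, &msg, 0);
}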
+ * + * flags - flags sent or received with the user message, see Section + * 5 for complete description of the flags. + */ +static int sctp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct sctp_ulpevent *event = NULL; + struct sctp_sock *sp = sctp_sk(sk); + struct sk_buff *skb, *head_skb; + int copied; + int err = 0; + int skb_len; + + pr_debug("%s: sk:%p, msghdr:%p, len:%zd, flags:0x%x, addr_len:%p)\n", + __func__, sk, msg, len, flags, addr_len); + + if (unlikely(flags & MSG_ERRQUEUE)) + return inet_recv_error(sk, msg, len, addr_len); + + if (sk_can_busy_loop(sk) && + skb_queue_empty_lockless(&sk->sk_receive_queue)) + sk_busy_loop(sk, flags & MSG_DONTWAIT); + + lock_sock(sk); + + if (sctp_style(sk, TCP) && !sctp_sstate(sk, ESTABLISHED) && + !sctp_sstate(sk, CLOSING) && !sctp_sstate(sk, CLOSED)) { + err = -ENOTCONN; + goto out; + } + + skb = sctp_skb_recv_datagram(sk, flags, &err); + if (!skb) + goto out; + + /* Get the total length of the skb including any skb's in the + * frag_list. + */ + skb_len = skb->len; + + copied = skb_len; + if (copied > len) + copied = len; + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + + event = sctp_skb2event(skb); + + if (err) + goto out_free; + + if (event->chunk && event->chunk->head_skb) + head_skb = event->chunk->head_skb; + else + head_skb = skb; + sock_recv_cmsgs(msg, sk, head_skb); + if (sctp_ulpevent_is_notification(event)) { + msg->msg_flags |= MSG_NOTIFICATION; + sp->pf->event_msgname(event, msg->msg_name, addr_len); + } else { + sp->pf->skb_msgname(head_skb, msg->msg_name, addr_len); + } + + /* Check if we allow SCTP_NXTINFO. */ + if (sp->recvnxtinfo) + sctp_ulpevent_read_nxtinfo(event, msg, sk); + /* Check if we allow SCTP_RCVINFO. */ + if (sp->recvrcvinfo) + sctp_ulpevent_read_rcvinfo(event, msg); + /* Check if we allow SCTP_SNDRCVINFO. */ + if (sctp_ulpevent_type_enabled(sp->subscribe, SCTP_DATA_IO_EVENT)) + sctp_ulpevent_read_sndrcvinfo(event, msg); + + err = copied; + + /* If skb's length exceeds the user's buffer, update the skb and + * push it back to the receive_queue so that the next call to + * recvmsg() will return the remaining data. Don't set MSG_EOR. + */ + if (skb_len > copied) { + msg->msg_flags &= ~MSG_EOR; + if (flags & MSG_PEEK) + goto out_free; + sctp_skb_pull(skb, copied); + skb_queue_head(&sk->sk_receive_queue, skb); + + /* When only partial message is copied to the user, increase + * rwnd by that amount. If all the data in the skb is read, + * rwnd is updated when the event is freed. + */ + if (!sctp_ulpevent_is_notification(event)) + sctp_assoc_rwnd_increase(event->asoc, copied); + goto out; + } else if ((event->msg_flags & MSG_NOTIFICATION) || + (event->msg_flags & MSG_EOR)) + msg->msg_flags |= MSG_EOR; + else + msg->msg_flags &= ~MSG_EOR; + +out_free: + if (flags & MSG_PEEK) { + /* Release the skb reference acquired after peeking the skb in + * sctp_skb_recv_datagram(). + */ + kfree_skb(skb); + } else { + /* Free the event which includes releasing the reference to + * the owner of the skb, freeing the skb and updating the + * rwnd. + */ + sctp_ulpevent_free(event); + } +out: + release_sock(sk); + return err; +} + +/* 7.1.12 Enable/Disable message fragmentation (SCTP_DISABLE_FRAGMENTS) + * + * This option is a on/off flag. If enabled no SCTP message + * fragmentation will be performed. Instead if a message being sent + * exceeds the current PMTU size, the message will NOT be sent and + * instead a error will be indicated to the user. 
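The sctp_recvmsg() implementation above marks notifications with MSG_NOTIFICATION and clears MSG_EOR when only part of a message fit into the caller's buffer. A sketch of a userspace receive loop honouring those flags, under the same header and socket assumptions as the earlier sketch; recv_one_message is a made-up helper:

/* Illustrative only: read until a complete user message has been received,
 * skipping notifications and treating a cleared MSG_EOR as "more of this
 * message is still queued". */
static ssize_t recv_one_message(int sd, char *buf, size_t buflen)
{
	size_t off = 0;

	for (;;) {
		struct iovec iov = { .iov_base = buf + off,
				     .iov_len = buflen - off };
		struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1 };
		ssize_t n = recvmsg(sd, &msg, 0);

		if (n < 0)
			return -1;
		if (msg.msg_flags & MSG_NOTIFICATION)
			continue;	/* event, not user data; ignored here */
		off += (size_t)n;
		if ((msg.msg_flags & MSG_EOR) || off == buflen)
			return (ssize_t)off;	/* complete message or full buffer */
	}
}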
+ */ +static int sctp_setsockopt_disable_fragments(struct sock *sk, int *val, + unsigned int optlen) +{ + if (optlen < sizeof(int)) + return -EINVAL; + sctp_sk(sk)->disable_fragments = (*val == 0) ? 0 : 1; + return 0; +} + +static int sctp_setsockopt_events(struct sock *sk, __u8 *sn_type, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + int i; + + if (optlen > sizeof(struct sctp_event_subscribe)) + return -EINVAL; + + for (i = 0; i < optlen; i++) + sctp_ulpevent_type_set(&sp->subscribe, SCTP_SN_TYPE_BASE + i, + sn_type[i]); + + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + asoc->subscribe = sctp_sk(sk)->subscribe; + + /* At the time when a user app subscribes to SCTP_SENDER_DRY_EVENT, + * if there is no data to be sent or retransmit, the stack will + * immediately send up this notification. + */ + if (sctp_ulpevent_type_enabled(sp->subscribe, SCTP_SENDER_DRY_EVENT)) { + struct sctp_ulpevent *event; + + asoc = sctp_id2assoc(sk, 0); + if (asoc && sctp_outq_is_empty(&asoc->outqueue)) { + event = sctp_ulpevent_make_sender_dry_event(asoc, + GFP_USER | __GFP_NOWARN); + if (!event) + return -ENOMEM; + + asoc->stream.si->enqueue_event(&asoc->ulpq, event); + } + } + + return 0; +} + +/* 7.1.8 Automatic Close of associations (SCTP_AUTOCLOSE) + * + * This socket option is applicable to the UDP-style socket only. When + * set it will cause associations that are idle for more than the + * specified number of seconds to automatically close. An association + * being idle is defined an association that has NOT sent or received + * user data. The special value of '0' indicates that no automatic + * close of any associations should be performed. The option expects an + * integer defining the number of seconds of idle time before an + * association is closed. + */ +static int sctp_setsockopt_autoclose(struct sock *sk, u32 *optval, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct net *net = sock_net(sk); + + /* Applicable to UDP-style socket only */ + if (sctp_style(sk, TCP)) + return -EOPNOTSUPP; + if (optlen != sizeof(int)) + return -EINVAL; + + sp->autoclose = *optval; + if (sp->autoclose > net->sctp.max_autoclose) + sp->autoclose = net->sctp.max_autoclose; + + return 0; +} + +/* 7.1.13 Peer Address Parameters (SCTP_PEER_ADDR_PARAMS) + * + * Applications can enable or disable heartbeats for any peer address of + * an association, modify an address's heartbeat interval, force a + * heartbeat to be sent immediately, and adjust the address's maximum + * number of retransmissions sent before an address is considered + * unreachable. The following structure is used to access and modify an + * address's parameters: + * + * struct sctp_paddrparams { + * sctp_assoc_t spp_assoc_id; + * struct sockaddr_storage spp_address; + * uint32_t spp_hbinterval; + * uint16_t spp_pathmaxrxt; + * uint32_t spp_pathmtu; + * uint32_t spp_sackdelay; + * uint32_t spp_flags; + * uint32_t spp_ipv6_flowlabel; + * uint8_t spp_dscp; + * }; + * + * spp_assoc_id - (one-to-many style socket) This is filled in the + * application, and identifies the association for + * this query. + * spp_address - This specifies which address is of interest. + * spp_hbinterval - This contains the value of the heartbeat interval, + * in milliseconds. If a value of zero + * is present in this field then no changes are to + * be made to this parameter. + * spp_pathmaxrxt - This contains the maximum number of + * retransmissions before this address shall be + * considered unreachable. 
If a value of zero + * is present in this field then no changes are to + * be made to this parameter. + * spp_pathmtu - When Path MTU discovery is disabled the value + * specified here will be the "fixed" path mtu. + * Note that if the spp_address field is empty + * then all associations on this address will + * have this fixed path mtu set upon them. + * + * spp_sackdelay - When delayed sack is enabled, this value specifies + * the number of milliseconds that sacks will be delayed + * for. This value will apply to all addresses of an + * association if the spp_address field is empty. Note + * also, that if delayed sack is enabled and this + * value is set to 0, no change is made to the last + * recorded delayed sack timer value. + * + * spp_flags - These flags are used to control various features + * on an association. The flag field may contain + * zero or more of the following options. + * + * SPP_HB_ENABLE - Enable heartbeats on the + * specified address. Note that if the address + * field is empty all addresses for the association + * have heartbeats enabled upon them. + * + * SPP_HB_DISABLE - Disable heartbeats on the + * specified address. Note that if the address + * field is empty all addresses for the association + * will have their heartbeats disabled. Note also + * that SPP_HB_ENABLE and SPP_HB_DISABLE are + * mutually exclusive, only one of these two should + * be specified. Enabling both fields will have + * undetermined results. + * + * SPP_HB_DEMAND - Request a user initiated heartbeat + * to be made immediately. + * + * SPP_HB_TIME_IS_ZERO - Specifies that the time for + * heartbeat delay is to be set to the value of 0 + * milliseconds. + * + * SPP_PMTUD_ENABLE - This field will enable PMTU + * discovery upon the specified address. Note that + * if the address field is empty then all addresses + * on the association are affected. + * + * SPP_PMTUD_DISABLE - This field will disable PMTU + * discovery upon the specified address. Note that + * if the address field is empty then all addresses + * on the association are affected. Note also that + * SPP_PMTUD_ENABLE and SPP_PMTUD_DISABLE are mutually + * exclusive. Enabling both will have undetermined + * results. + * + * SPP_SACKDELAY_ENABLE - Setting this flag turns + * on delayed sack. The time specified in spp_sackdelay + * is used to specify the sack delay for this address. Note + * that if spp_address is empty then all addresses will + * enable delayed sack and take on the sack delay + * value specified in spp_sackdelay. + * SPP_SACKDELAY_DISABLE - Setting this flag turns + * off delayed sack. If the spp_address field is blank then + * delayed sack is disabled for the entire association. Note + * also that this field is mutually exclusive to + * SPP_SACKDELAY_ENABLE, setting both will have undefined + * results. + * + * SPP_IPV6_FLOWLABEL: Setting this flag enables the + * setting of the IPv6 flow label value. The value is + * contained in the spp_ipv6_flowlabel field. + * Upon retrieval, this flag will be set to indicate that + * the spp_ipv6_flowlabel field has a valid value returned. + * If a specific destination address is set (in the + * spp_address field), then the value returned is that of + * the address. If just an association is specified (and + * no address), then the association's default flow label + * is returned. If neither an association nor a destination + * is specified, then the socket's default flow label is + * returned. For non-IPv6 sockets, this flag will be left + * cleared.
+ * + * SPP_DSCP: Setting this flag enables the setting of the + * Differentiated Services Code Point (DSCP) value + * associated with either the association or a specific + * address. The value is obtained in the spp_dscp field. + * Upon retrieval, this flag will be set to indicate that + * the spp_dscp field has a valid value returned. If a + * specific destination address is set when called (in the + * spp_address field), then that specific destination + * address's DSCP value is returned. If just an association + * is specified, then the association's default DSCP is + * returned. If neither an association nor a destination is + * specified, then the socket's default DSCP is returned. + * + * spp_ipv6_flowlabel + * - This field is used in conjunction with the + * SPP_IPV6_FLOWLABEL flag and contains the IPv6 flow label. + * The 20 least significant bits are used for the flow + * label. This setting has precedence over any IPv6-layer + * setting. + * + * spp_dscp - This field is used in conjunction with the SPP_DSCP flag + * and contains the DSCP. The 6 most significant bits are + * used for the DSCP. This setting has precedence over any + * IPv4- or IPv6- layer setting. + */ +static int sctp_apply_peer_addr_params(struct sctp_paddrparams *params, + struct sctp_transport *trans, + struct sctp_association *asoc, + struct sctp_sock *sp, + int hb_change, + int pmtud_change, + int sackdelay_change) +{ + int error; + + if (params->spp_flags & SPP_HB_DEMAND && trans) { + error = sctp_primitive_REQUESTHEARTBEAT(trans->asoc->base.net, + trans->asoc, trans); + if (error) + return error; + } + + /* Note that unless the spp_flag is set to SPP_HB_ENABLE the value of + * this field is ignored. Note also that a value of zero indicates + * the current setting should be left unchanged. + */ + if (params->spp_flags & SPP_HB_ENABLE) { + + /* Re-zero the interval if the SPP_HB_TIME_IS_ZERO is + * set. This lets us use 0 value when this flag + * is set. + */ + if (params->spp_flags & SPP_HB_TIME_IS_ZERO) + params->spp_hbinterval = 0; + + if (params->spp_hbinterval || + (params->spp_flags & SPP_HB_TIME_IS_ZERO)) { + if (trans) { + trans->hbinterval = + msecs_to_jiffies(params->spp_hbinterval); + sctp_transport_reset_hb_timer(trans); + } else if (asoc) { + asoc->hbinterval = + msecs_to_jiffies(params->spp_hbinterval); + } else { + sp->hbinterval = params->spp_hbinterval; + } + } + } + + if (hb_change) { + if (trans) { + trans->param_flags = + (trans->param_flags & ~SPP_HB) | hb_change; + } else if (asoc) { + asoc->param_flags = + (asoc->param_flags & ~SPP_HB) | hb_change; + } else { + sp->param_flags = + (sp->param_flags & ~SPP_HB) | hb_change; + } + } + + /* When Path MTU discovery is disabled the value specified here will + * be the "fixed" path mtu (i.e. the value of the spp_flags field must + * include the flag SPP_PMTUD_DISABLE for this field to have any + * effect). 
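sctp_apply_peer_addr_params() above is what SCTP_PEER_ADDR_PARAMS ultimately funnels into, applying the spp_* fields at socket, association or transport scope. A minimal sketch of enabling heartbeats association-wide from userspace, under the same assumptions as the earlier sketches; enable_heartbeats is a made-up helper name:

/* Illustrative only: turn heartbeats on with the given interval for every
 * path of the association (spp_address left zeroed so the setting applies
 * to all of its transports). */
static int enable_heartbeats(int sd, sctp_assoc_t assoc_id, uint32_t interval_ms)
{
	struct sctp_paddrparams params;

	memset(&params, 0, sizeof(params));
	params.spp_assoc_id   = assoc_id;
	params.spp_flags      = SPP_HB_ENABLE;
	params.spp_hbinterval = interval_ms;	/* milliseconds */

	return setsockopt(sd, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS,
			  &params, sizeof(params));
}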
+ */ + if ((params->spp_flags & SPP_PMTUD_DISABLE) && params->spp_pathmtu) { + if (trans) { + trans->pathmtu = params->spp_pathmtu; + sctp_assoc_sync_pmtu(asoc); + } else if (asoc) { + sctp_assoc_set_pmtu(asoc, params->spp_pathmtu); + } else { + sp->pathmtu = params->spp_pathmtu; + } + } + + if (pmtud_change) { + if (trans) { + int update = (trans->param_flags & SPP_PMTUD_DISABLE) && + (params->spp_flags & SPP_PMTUD_ENABLE); + trans->param_flags = + (trans->param_flags & ~SPP_PMTUD) | pmtud_change; + if (update) { + sctp_transport_pmtu(trans, sctp_opt2sk(sp)); + sctp_assoc_sync_pmtu(asoc); + } + sctp_transport_pl_reset(trans); + } else if (asoc) { + asoc->param_flags = + (asoc->param_flags & ~SPP_PMTUD) | pmtud_change; + } else { + sp->param_flags = + (sp->param_flags & ~SPP_PMTUD) | pmtud_change; + } + } + + /* Note that unless the spp_flag is set to SPP_SACKDELAY_ENABLE the + * value of this field is ignored. Note also that a value of zero + * indicates the current setting should be left unchanged. + */ + if ((params->spp_flags & SPP_SACKDELAY_ENABLE) && params->spp_sackdelay) { + if (trans) { + trans->sackdelay = + msecs_to_jiffies(params->spp_sackdelay); + } else if (asoc) { + asoc->sackdelay = + msecs_to_jiffies(params->spp_sackdelay); + } else { + sp->sackdelay = params->spp_sackdelay; + } + } + + if (sackdelay_change) { + if (trans) { + trans->param_flags = + (trans->param_flags & ~SPP_SACKDELAY) | + sackdelay_change; + } else if (asoc) { + asoc->param_flags = + (asoc->param_flags & ~SPP_SACKDELAY) | + sackdelay_change; + } else { + sp->param_flags = + (sp->param_flags & ~SPP_SACKDELAY) | + sackdelay_change; + } + } + + /* Note that a value of zero indicates the current setting should be + left unchanged. + */ + if (params->spp_pathmaxrxt) { + if (trans) { + trans->pathmaxrxt = params->spp_pathmaxrxt; + } else if (asoc) { + asoc->pathmaxrxt = params->spp_pathmaxrxt; + } else { + sp->pathmaxrxt = params->spp_pathmaxrxt; + } + } + + if (params->spp_flags & SPP_IPV6_FLOWLABEL) { + if (trans) { + if (trans->ipaddr.sa.sa_family == AF_INET6) { + trans->flowlabel = params->spp_ipv6_flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + trans->flowlabel |= SCTP_FLOWLABEL_SET_MASK; + } + } else if (asoc) { + struct sctp_transport *t; + + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + if (t->ipaddr.sa.sa_family != AF_INET6) + continue; + t->flowlabel = params->spp_ipv6_flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + t->flowlabel |= SCTP_FLOWLABEL_SET_MASK; + } + asoc->flowlabel = params->spp_ipv6_flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + asoc->flowlabel |= SCTP_FLOWLABEL_SET_MASK; + } else if (sctp_opt2sk(sp)->sk_family == AF_INET6) { + sp->flowlabel = params->spp_ipv6_flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + sp->flowlabel |= SCTP_FLOWLABEL_SET_MASK; + } + } + + if (params->spp_flags & SPP_DSCP) { + if (trans) { + trans->dscp = params->spp_dscp & SCTP_DSCP_VAL_MASK; + trans->dscp |= SCTP_DSCP_SET_MASK; + } else if (asoc) { + struct sctp_transport *t; + + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) { + t->dscp = params->spp_dscp & + SCTP_DSCP_VAL_MASK; + t->dscp |= SCTP_DSCP_SET_MASK; + } + asoc->dscp = params->spp_dscp & SCTP_DSCP_VAL_MASK; + asoc->dscp |= SCTP_DSCP_SET_MASK; + } else { + sp->dscp = params->spp_dscp & SCTP_DSCP_VAL_MASK; + sp->dscp |= SCTP_DSCP_SET_MASK; + } + } + + return 0; +} + +static int sctp_setsockopt_peer_addr_params(struct sock *sk, + struct sctp_paddrparams *params, + unsigned int optlen) +{ + struct sctp_transport *trans = NULL; + 
struct sctp_association *asoc = NULL; + struct sctp_sock *sp = sctp_sk(sk); + int error; + int hb_change, pmtud_change, sackdelay_change; + + if (optlen == ALIGN(offsetof(struct sctp_paddrparams, + spp_ipv6_flowlabel), 4)) { + if (params->spp_flags & (SPP_DSCP | SPP_IPV6_FLOWLABEL)) + return -EINVAL; + } else if (optlen != sizeof(*params)) { + return -EINVAL; + } + + /* Validate flags and value parameters. */ + hb_change = params->spp_flags & SPP_HB; + pmtud_change = params->spp_flags & SPP_PMTUD; + sackdelay_change = params->spp_flags & SPP_SACKDELAY; + + if (hb_change == SPP_HB || + pmtud_change == SPP_PMTUD || + sackdelay_change == SPP_SACKDELAY || + params->spp_sackdelay > 500 || + (params->spp_pathmtu && + params->spp_pathmtu < SCTP_DEFAULT_MINSEGMENT)) + return -EINVAL; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + if (!sctp_is_any(sk, (union sctp_addr *)¶ms->spp_address)) { + trans = sctp_addr_id2transport(sk, ¶ms->spp_address, + params->spp_assoc_id); + if (!trans) + return -EINVAL; + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, params->spp_assoc_id); + if (!asoc && params->spp_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* Heartbeat demand can only be sent on a transport or + * association, but not a socket. + */ + if (params->spp_flags & SPP_HB_DEMAND && !trans && !asoc) + return -EINVAL; + + /* Process parameters. */ + error = sctp_apply_peer_addr_params(params, trans, asoc, sp, + hb_change, pmtud_change, + sackdelay_change); + + if (error) + return error; + + /* If changes are for association, also apply parameters to each + * transport. + */ + if (!trans && asoc) { + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) { + sctp_apply_peer_addr_params(params, trans, asoc, sp, + hb_change, pmtud_change, + sackdelay_change); + } + } + + return 0; +} + +static inline __u32 sctp_spp_sackdelay_enable(__u32 param_flags) +{ + return (param_flags & ~SPP_SACKDELAY) | SPP_SACKDELAY_ENABLE; +} + +static inline __u32 sctp_spp_sackdelay_disable(__u32 param_flags) +{ + return (param_flags & ~SPP_SACKDELAY) | SPP_SACKDELAY_DISABLE; +} + +static void sctp_apply_asoc_delayed_ack(struct sctp_sack_info *params, + struct sctp_association *asoc) +{ + struct sctp_transport *trans; + + if (params->sack_delay) { + asoc->sackdelay = msecs_to_jiffies(params->sack_delay); + asoc->param_flags = + sctp_spp_sackdelay_enable(asoc->param_flags); + } + if (params->sack_freq == 1) { + asoc->param_flags = + sctp_spp_sackdelay_disable(asoc->param_flags); + } else if (params->sack_freq > 1) { + asoc->sackfreq = params->sack_freq; + asoc->param_flags = + sctp_spp_sackdelay_enable(asoc->param_flags); + } + + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) { + if (params->sack_delay) { + trans->sackdelay = msecs_to_jiffies(params->sack_delay); + trans->param_flags = + sctp_spp_sackdelay_enable(trans->param_flags); + } + if (params->sack_freq == 1) { + trans->param_flags = + sctp_spp_sackdelay_disable(trans->param_flags); + } else if (params->sack_freq > 1) { + trans->sackfreq = params->sack_freq; + trans->param_flags = + sctp_spp_sackdelay_enable(trans->param_flags); + } + } +} + +/* + * 7.1.23. 
Get or set delayed ack timer (SCTP_DELAYED_SACK) + * + * This option will effect the way delayed acks are performed. This + * option allows you to get or set the delayed ack time, in + * milliseconds. It also allows changing the delayed ack frequency. + * Changing the frequency to 1 disables the delayed sack algorithm. If + * the assoc_id is 0, then this sets or gets the endpoints default + * values. If the assoc_id field is non-zero, then the set or get + * effects the specified association for the one to many model (the + * assoc_id field is ignored by the one to one model). Note that if + * sack_delay or sack_freq are 0 when setting this option, then the + * current values will remain unchanged. + * + * struct sctp_sack_info { + * sctp_assoc_t sack_assoc_id; + * uint32_t sack_delay; + * uint32_t sack_freq; + * }; + * + * sack_assoc_id - This parameter, indicates which association the user + * is performing an action upon. Note that if this field's value is + * zero then the endpoints default value is changed (effecting future + * associations only). + * + * sack_delay - This parameter contains the number of milliseconds that + * the user is requesting the delayed ACK timer be set to. Note that + * this value is defined in the standard to be between 200 and 500 + * milliseconds. + * + * sack_freq - This parameter contains the number of packets that must + * be received before a sack is sent without waiting for the delay + * timer to expire. The default value for this is 2, setting this + * value to 1 will disable the delayed sack algorithm. + */ +static int __sctp_setsockopt_delayed_ack(struct sock *sk, + struct sctp_sack_info *params) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + + /* Validate value parameter. */ + if (params->sack_delay > 500) + return -EINVAL; + + /* Get association, if sack_assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. 
+ */ + asoc = sctp_id2assoc(sk, params->sack_assoc_id); + if (!asoc && params->sack_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + sctp_apply_asoc_delayed_ack(params, asoc); + + return 0; + } + + if (sctp_style(sk, TCP)) + params->sack_assoc_id = SCTP_FUTURE_ASSOC; + + if (params->sack_assoc_id == SCTP_FUTURE_ASSOC || + params->sack_assoc_id == SCTP_ALL_ASSOC) { + if (params->sack_delay) { + sp->sackdelay = params->sack_delay; + sp->param_flags = + sctp_spp_sackdelay_enable(sp->param_flags); + } + if (params->sack_freq == 1) { + sp->param_flags = + sctp_spp_sackdelay_disable(sp->param_flags); + } else if (params->sack_freq > 1) { + sp->sackfreq = params->sack_freq; + sp->param_flags = + sctp_spp_sackdelay_enable(sp->param_flags); + } + } + + if (params->sack_assoc_id == SCTP_CURRENT_ASSOC || + params->sack_assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + sctp_apply_asoc_delayed_ack(params, asoc); + + return 0; +} + +static int sctp_setsockopt_delayed_ack(struct sock *sk, + struct sctp_sack_info *params, + unsigned int optlen) +{ + if (optlen == sizeof(struct sctp_assoc_value)) { + struct sctp_assoc_value *v = (struct sctp_assoc_value *)params; + struct sctp_sack_info p; + + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of struct sctp_assoc_value in delayed_ack socket option.\n" + "Use struct sctp_sack_info instead\n", + current->comm, task_pid_nr(current)); + + p.sack_assoc_id = v->assoc_id; + p.sack_delay = v->assoc_value; + p.sack_freq = v->assoc_value ? 0 : 1; + return __sctp_setsockopt_delayed_ack(sk, &p); + } + + if (optlen != sizeof(struct sctp_sack_info)) + return -EINVAL; + if (params->sack_delay == 0 && params->sack_freq == 0) + return 0; + return __sctp_setsockopt_delayed_ack(sk, params); +} + +/* 7.1.3 Initialization Parameters (SCTP_INITMSG) + * + * Applications can specify protocol parameters for the default association + * initialization. The option name argument to setsockopt() and getsockopt() + * is SCTP_INITMSG. + * + * Setting initialization parameters is effective only on an unconnected + * socket (for UDP-style sockets only future associations are effected + * by the change). With TCP-style sockets, this option is inherited by + * sockets derived from a listener socket. + */ +static int sctp_setsockopt_initmsg(struct sock *sk, struct sctp_initmsg *sinit, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + + if (optlen != sizeof(struct sctp_initmsg)) + return -EINVAL; + + if (sinit->sinit_num_ostreams) + sp->initmsg.sinit_num_ostreams = sinit->sinit_num_ostreams; + if (sinit->sinit_max_instreams) + sp->initmsg.sinit_max_instreams = sinit->sinit_max_instreams; + if (sinit->sinit_max_attempts) + sp->initmsg.sinit_max_attempts = sinit->sinit_max_attempts; + if (sinit->sinit_max_init_timeo) + sp->initmsg.sinit_max_init_timeo = sinit->sinit_max_init_timeo; + + return 0; +} + +/* + * 7.1.14 Set default send parameters (SCTP_DEFAULT_SEND_PARAM) + * + * Applications that wish to use the sendto() system call may wish to + * specify a default set of parameters that would normally be supplied + * through the inclusion of ancillary data. This socket option allows + * such an application to set the default sctp_sndrcvinfo structure. 
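The SCTP_DELAYED_SACK handler above takes a struct sctp_sack_info and rejects sack_delay values above 500 ms; a sack_freq of 1 switches delayed SACKs off. A hedged userspace sketch, same assumptions as the earlier sketches:

/* Illustrative only: ask for at most a 100 ms SACK delay and a SACK after
 * every second packet on all future associations of this endpoint. */
static int tune_delayed_sack(int sd)
{
	struct sctp_sack_info sack;

	memset(&sack, 0, sizeof(sack));
	sack.sack_assoc_id = SCTP_FUTURE_ASSOC;
	sack.sack_delay    = 100;	/* ms, the handler rejects values > 500 */
	sack.sack_freq     = 2;		/* 1 would disable delayed SACKs */

	return setsockopt(sd, IPPROTO_SCTP, SCTP_DELAYED_SACK,
			  &sack, sizeof(sack));
}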
+ * The application that wishes to use this socket option simply passes + * in to this call the sctp_sndrcvinfo structure defined in Section + * 5.2.2) The input parameters accepted by this call include + * sinfo_stream, sinfo_flags, sinfo_ppid, sinfo_context, + * sinfo_timetolive. The user must provide the sinfo_assoc_id field in + * to this call if the caller is using the UDP model. + */ +static int sctp_setsockopt_default_send_param(struct sock *sk, + struct sctp_sndrcvinfo *info, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + + if (optlen != sizeof(*info)) + return -EINVAL; + if (info->sinfo_flags & + ~(SCTP_UNORDERED | SCTP_ADDR_OVER | + SCTP_ABORT | SCTP_EOF)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, info->sinfo_assoc_id); + if (!asoc && info->sinfo_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + asoc->default_stream = info->sinfo_stream; + asoc->default_flags = info->sinfo_flags; + asoc->default_ppid = info->sinfo_ppid; + asoc->default_context = info->sinfo_context; + asoc->default_timetolive = info->sinfo_timetolive; + + return 0; + } + + if (sctp_style(sk, TCP)) + info->sinfo_assoc_id = SCTP_FUTURE_ASSOC; + + if (info->sinfo_assoc_id == SCTP_FUTURE_ASSOC || + info->sinfo_assoc_id == SCTP_ALL_ASSOC) { + sp->default_stream = info->sinfo_stream; + sp->default_flags = info->sinfo_flags; + sp->default_ppid = info->sinfo_ppid; + sp->default_context = info->sinfo_context; + sp->default_timetolive = info->sinfo_timetolive; + } + + if (info->sinfo_assoc_id == SCTP_CURRENT_ASSOC || + info->sinfo_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + asoc->default_stream = info->sinfo_stream; + asoc->default_flags = info->sinfo_flags; + asoc->default_ppid = info->sinfo_ppid; + asoc->default_context = info->sinfo_context; + asoc->default_timetolive = info->sinfo_timetolive; + } + } + + return 0; +} + +/* RFC6458, Section 8.1.31. 
Set/get Default Send Parameters + * (SCTP_DEFAULT_SNDINFO) + */ +static int sctp_setsockopt_default_sndinfo(struct sock *sk, + struct sctp_sndinfo *info, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + + if (optlen != sizeof(*info)) + return -EINVAL; + if (info->snd_flags & + ~(SCTP_UNORDERED | SCTP_ADDR_OVER | + SCTP_ABORT | SCTP_EOF)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, info->snd_assoc_id); + if (!asoc && info->snd_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + asoc->default_stream = info->snd_sid; + asoc->default_flags = info->snd_flags; + asoc->default_ppid = info->snd_ppid; + asoc->default_context = info->snd_context; + + return 0; + } + + if (sctp_style(sk, TCP)) + info->snd_assoc_id = SCTP_FUTURE_ASSOC; + + if (info->snd_assoc_id == SCTP_FUTURE_ASSOC || + info->snd_assoc_id == SCTP_ALL_ASSOC) { + sp->default_stream = info->snd_sid; + sp->default_flags = info->snd_flags; + sp->default_ppid = info->snd_ppid; + sp->default_context = info->snd_context; + } + + if (info->snd_assoc_id == SCTP_CURRENT_ASSOC || + info->snd_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + asoc->default_stream = info->snd_sid; + asoc->default_flags = info->snd_flags; + asoc->default_ppid = info->snd_ppid; + asoc->default_context = info->snd_context; + } + } + + return 0; +} + +/* 7.1.10 Set Primary Address (SCTP_PRIMARY_ADDR) + * + * Requests that the local SCTP stack use the enclosed peer address as + * the association primary. The enclosed address must be one of the + * association peer's addresses. + */ +static int sctp_setsockopt_primary_addr(struct sock *sk, struct sctp_prim *prim, + unsigned int optlen) +{ + struct sctp_transport *trans; + struct sctp_af *af; + int err; + + if (optlen != sizeof(struct sctp_prim)) + return -EINVAL; + + /* Allow security module to validate address but need address len. */ + af = sctp_get_af_specific(prim->ssp_addr.ss_family); + if (!af) + return -EINVAL; + + err = security_sctp_bind_connect(sk, SCTP_PRIMARY_ADDR, + (struct sockaddr *)&prim->ssp_addr, + af->sockaddr_len); + if (err) + return err; + + trans = sctp_addr_id2transport(sk, &prim->ssp_addr, prim->ssp_assoc_id); + if (!trans) + return -EINVAL; + + sctp_assoc_set_primary(trans->asoc, trans); + + return 0; +} + +/* + * 7.1.5 SCTP_NODELAY + * + * Turn on/off any Nagle-like algorithm. This means that packets are + * generally sent as soon as possible and no unnecessary delays are + * introduced, at the cost of more packets in the network. Expects an + * integer boolean flag. + */ +static int sctp_setsockopt_nodelay(struct sock *sk, int *val, + unsigned int optlen) +{ + if (optlen < sizeof(int)) + return -EINVAL; + sctp_sk(sk)->nodelay = (*val == 0) ? 0 : 1; + return 0; +} + +/* + * + * 7.1.1 SCTP_RTOINFO + * + * The protocol parameters used to initialize and bound retransmission + * timeout (RTO) are tunable. sctp_rtoinfo structure is used to access + * and modify these parameters. + * All parameters are time values, in milliseconds. A value of 0, when + * modifying the parameters, indicates that the current value should not + * be changed. 
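SCTP_RTOINFO, documented above, tunes the retransmission timeout bounds; a zero in any field leaves that parameter untouched. A minimal sketch under the same assumptions as the earlier examples; set_rto_bounds is a made-up name:

/* Illustrative only: 500 ms initial RTO, clamped between 200 ms and 2 s,
 * applied to the endpoint defaults (future associations). */
static int set_rto_bounds(int sd)
{
	struct sctp_rtoinfo rto;

	memset(&rto, 0, sizeof(rto));
	rto.srto_assoc_id = SCTP_FUTURE_ASSOC;
	rto.srto_initial  = 500;	/* ms */
	rto.srto_min	  = 200;	/* ms */
	rto.srto_max	  = 2000;	/* ms */

	return setsockopt(sd, IPPROTO_SCTP, SCTP_RTOINFO, &rto, sizeof(rto));
}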
+ * + */ +static int sctp_setsockopt_rtoinfo(struct sock *sk, + struct sctp_rtoinfo *rtoinfo, + unsigned int optlen) +{ + struct sctp_association *asoc; + unsigned long rto_min, rto_max; + struct sctp_sock *sp = sctp_sk(sk); + + if (optlen != sizeof (struct sctp_rtoinfo)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, rtoinfo->srto_assoc_id); + + /* Set the values to the specific association */ + if (!asoc && rtoinfo->srto_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + rto_max = rtoinfo->srto_max; + rto_min = rtoinfo->srto_min; + + if (rto_max) + rto_max = asoc ? msecs_to_jiffies(rto_max) : rto_max; + else + rto_max = asoc ? asoc->rto_max : sp->rtoinfo.srto_max; + + if (rto_min) + rto_min = asoc ? msecs_to_jiffies(rto_min) : rto_min; + else + rto_min = asoc ? asoc->rto_min : sp->rtoinfo.srto_min; + + if (rto_min > rto_max) + return -EINVAL; + + if (asoc) { + if (rtoinfo->srto_initial != 0) + asoc->rto_initial = + msecs_to_jiffies(rtoinfo->srto_initial); + asoc->rto_max = rto_max; + asoc->rto_min = rto_min; + } else { + /* If there is no association or the association-id = 0 + * set the values to the endpoint. + */ + if (rtoinfo->srto_initial != 0) + sp->rtoinfo.srto_initial = rtoinfo->srto_initial; + sp->rtoinfo.srto_max = rto_max; + sp->rtoinfo.srto_min = rto_min; + } + + return 0; +} + +/* + * + * 7.1.2 SCTP_ASSOCINFO + * + * This option is used to tune the maximum retransmission attempts + * of the association. + * Returns an error if the new association retransmission value is + * greater than the sum of the retransmission value of the peer. + * See [SCTP] for more information. + * + */ +static int sctp_setsockopt_associnfo(struct sock *sk, + struct sctp_assocparams *assocparams, + unsigned int optlen) +{ + + struct sctp_association *asoc; + + if (optlen != sizeof(struct sctp_assocparams)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, assocparams->sasoc_assoc_id); + + if (!asoc && assocparams->sasoc_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* Set the values to the specific association */ + if (asoc) { + if (assocparams->sasoc_asocmaxrxt != 0) { + __u32 path_sum = 0; + int paths = 0; + struct sctp_transport *peer_addr; + + list_for_each_entry(peer_addr, &asoc->peer.transport_addr_list, + transports) { + path_sum += peer_addr->pathmaxrxt; + paths++; + } + + /* Only validate asocmaxrxt if we have more than + * one path/transport. We do this because path + * retransmissions are only counted when we have more + * then one path. + */ + if (paths > 1 && + assocparams->sasoc_asocmaxrxt > path_sum) + return -EINVAL; + + asoc->max_retrans = assocparams->sasoc_asocmaxrxt; + } + + if (assocparams->sasoc_cookie_life != 0) + asoc->cookie_life = + ms_to_ktime(assocparams->sasoc_cookie_life); + } else { + /* Set the values to the endpoint */ + struct sctp_sock *sp = sctp_sk(sk); + + if (assocparams->sasoc_asocmaxrxt != 0) + sp->assocparams.sasoc_asocmaxrxt = + assocparams->sasoc_asocmaxrxt; + if (assocparams->sasoc_cookie_life != 0) + sp->assocparams.sasoc_cookie_life = + assocparams->sasoc_cookie_life; + } + return 0; +} + +/* + * 7.1.16 Set/clear IPv4 mapped addresses (SCTP_I_WANT_MAPPED_V4_ADDR) + * + * This socket option is a boolean flag which turns on or off mapped V4 + * addresses. If this option is turned on and the socket is type + * PF_INET6, then IPv4 addresses will be mapped to V6 representation. 
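SCTP_ASSOCINFO, handled above, validates sasoc_asocmaxrxt against the sum of the per-path limits only when the association is multi-homed, and zero fields are left unchanged. A sketch of setting the endpoint defaults, assumptions as before; tune_assocparams is a made-up name:

/* Illustrative only: allow 8 association-level retransmissions and a
 * 90 s cookie lifetime for future associations. */
static int tune_assocparams(int sd)
{
	struct sctp_assocparams ap;

	memset(&ap, 0, sizeof(ap));
	ap.sasoc_assoc_id    = SCTP_FUTURE_ASSOC;
	ap.sasoc_asocmaxrxt  = 8;
	ap.sasoc_cookie_life = 90000;	/* milliseconds */

	return setsockopt(sd, IPPROTO_SCTP, SCTP_ASSOCINFO, &ap, sizeof(ap));
}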
+ * If this option is turned off, then no mapping will be done of V4 + * addresses and a user will receive both PF_INET6 and PF_INET type + * addresses on the socket. + */ +static int sctp_setsockopt_mappedv4(struct sock *sk, int *val, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + + if (optlen < sizeof(int)) + return -EINVAL; + if (*val) + sp->v4mapped = 1; + else + sp->v4mapped = 0; + + return 0; +} + +/* + * 8.1.16. Get or Set the Maximum Fragmentation Size (SCTP_MAXSEG) + * This option will get or set the maximum size to put in any outgoing + * SCTP DATA chunk. If a message is larger than this size it will be + * fragmented by SCTP into the specified size. Note that the underlying + * SCTP implementation may fragment into smaller sized chunks when the + * PMTU of the underlying association is smaller than the value set by + * the user. The default value for this option is '0' which indicates + * the user is NOT limiting fragmentation and only the PMTU will effect + * SCTP's choice of DATA chunk size. Note also that values set larger + * than the maximum size of an IP datagram will effectively let SCTP + * control fragmentation (i.e. the same as setting this option to 0). + * + * The following structure is used to access and modify this parameter: + * + * struct sctp_assoc_value { + * sctp_assoc_t assoc_id; + * uint32_t assoc_value; + * }; + * + * assoc_id: This parameter is ignored for one-to-one style sockets. + * For one-to-many style sockets this parameter indicates which + * association the user is performing an action upon. Note that if + * this field's value is zero then the endpoints default value is + * changed (effecting future associations only). + * assoc_value: This parameter specifies the maximum size in bytes. + */ +static int sctp_setsockopt_maxseg(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + sctp_assoc_t assoc_id; + int val; + + if (optlen == sizeof(int)) { + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of int in maxseg socket option.\n" + "Use struct sctp_assoc_value instead\n", + current->comm, task_pid_nr(current)); + assoc_id = SCTP_FUTURE_ASSOC; + val = *(int *)params; + } else if (optlen == sizeof(struct sctp_assoc_value)) { + assoc_id = params->assoc_id; + val = params->assoc_value; + } else { + return -EINVAL; + } + + asoc = sctp_id2assoc(sk, assoc_id); + if (!asoc && assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (val) { + int min_len, max_len; + __u16 datasize = asoc ? sctp_datachk_len(&asoc->stream) : + sizeof(struct sctp_data_chunk); + + min_len = sctp_min_frag_point(sp, datasize); + max_len = SCTP_MAX_CHUNK_LEN - datasize; + + if (val < min_len || val > max_len) + return -EINVAL; + } + + if (asoc) { + asoc->user_frag = val; + sctp_assoc_update_frag_point(asoc); + } else { + sp->user_frag = val; + } + + return 0; +} + + +/* + * 7.1.9 Set Peer Primary Address (SCTP_SET_PEER_PRIMARY_ADDR) + * + * Requests that the peer mark the enclosed address as the association + * primary. The enclosed address must be one of the association's + * locally bound addresses. 
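The SCTP_MAXSEG handler above still accepts a bare int but warns that this form is deprecated; the struct sctp_assoc_value form is preferred. A sketch, assumptions as before; set_max_fragment is a made-up name:

/* Illustrative only: cap outgoing DATA chunks at the given size for future
 * associations; a value of 0 would return to purely PMTU-driven sizing. */
static int set_max_fragment(int sd, uint32_t bytes)
{
	struct sctp_assoc_value av;

	memset(&av, 0, sizeof(av));
	av.assoc_id    = SCTP_FUTURE_ASSOC;
	av.assoc_value = bytes;

	return setsockopt(sd, IPPROTO_SCTP, SCTP_MAXSEG, &av, sizeof(av));
}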
The following structure is used to make a + * set primary request: + */ +static int sctp_setsockopt_peer_primary_addr(struct sock *sk, + struct sctp_setpeerprim *prim, + unsigned int optlen) +{ + struct sctp_sock *sp; + struct sctp_association *asoc = NULL; + struct sctp_chunk *chunk; + struct sctp_af *af; + int err; + + sp = sctp_sk(sk); + + if (!sp->ep->asconf_enable) + return -EPERM; + + if (optlen != sizeof(struct sctp_setpeerprim)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, prim->sspp_assoc_id); + if (!asoc) + return -EINVAL; + + if (!asoc->peer.asconf_capable) + return -EPERM; + + if (asoc->peer.addip_disabled_mask & SCTP_PARAM_SET_PRIMARY) + return -EPERM; + + if (!sctp_state(asoc, ESTABLISHED)) + return -ENOTCONN; + + af = sctp_get_af_specific(prim->sspp_addr.ss_family); + if (!af) + return -EINVAL; + + if (!af->addr_valid((union sctp_addr *)&prim->sspp_addr, sp, NULL)) + return -EADDRNOTAVAIL; + + if (!sctp_assoc_lookup_laddr(asoc, (union sctp_addr *)&prim->sspp_addr)) + return -EADDRNOTAVAIL; + + /* Allow security module to validate address. */ + err = security_sctp_bind_connect(sk, SCTP_SET_PEER_PRIMARY_ADDR, + (struct sockaddr *)&prim->sspp_addr, + af->sockaddr_len); + if (err) + return err; + + /* Create an ASCONF chunk with SET_PRIMARY parameter */ + chunk = sctp_make_asconf_set_prim(asoc, + (union sctp_addr *)&prim->sspp_addr); + if (!chunk) + return -ENOMEM; + + err = sctp_send_asconf(asoc, chunk); + + pr_debug("%s: we set peer primary addr primitively\n", __func__); + + return err; +} + +static int sctp_setsockopt_adaptation_layer(struct sock *sk, + struct sctp_setadaptation *adapt, + unsigned int optlen) +{ + if (optlen != sizeof(struct sctp_setadaptation)) + return -EINVAL; + + sctp_sk(sk)->adaptation_ind = adapt->ssb_adaptation_ind; + + return 0; +} + +/* + * 7.1.29. Set or Get the default context (SCTP_CONTEXT) + * + * The context field in the sctp_sndrcvinfo structure is normally only + * used when a failed message is retrieved holding the value that was + * sent down on the actual send call. This option allows the setting of + * a default context on an association basis that will be received on + * reading messages from the peer. This is especially helpful in the + * one-2-many model for an application to keep some reference to an + * internal state machine that is processing messages on the + * association. Note that the setting of this value only effects + * received messages from the peer and does not effect the value that is + * saved with outbound messages. + */ +static int sctp_setsockopt_context(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + + if (optlen != sizeof(struct sctp_assoc_value)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + asoc->default_rcv_context = params->assoc_value; + + return 0; + } + + if (sctp_style(sk, TCP)) + params->assoc_id = SCTP_FUTURE_ASSOC; + + if (params->assoc_id == SCTP_FUTURE_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) + sp->default_rcv_context = params->assoc_value; + + if (params->assoc_id == SCTP_CURRENT_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + asoc->default_rcv_context = params->assoc_value; + + return 0; +} + +/* + * 7.1.24. 
Get or set fragmented interleave (SCTP_FRAGMENT_INTERLEAVE) + * + * This options will at a minimum specify if the implementation is doing + * fragmented interleave. Fragmented interleave, for a one to many + * socket, is when subsequent calls to receive a message may return + * parts of messages from different associations. Some implementations + * may allow you to turn this value on or off. If so, when turned off, + * no fragment interleave will occur (which will cause a head of line + * blocking amongst multiple associations sharing the same one to many + * socket). When this option is turned on, then each receive call may + * come from a different association (thus the user must receive data + * with the extended calls (e.g. sctp_recvmsg) to keep track of which + * association each receive belongs to. + * + * This option takes a boolean value. A non-zero value indicates that + * fragmented interleave is on. A value of zero indicates that + * fragmented interleave is off. + * + * Note that it is important that an implementation that allows this + * option to be turned on, have it off by default. Otherwise an unaware + * application using the one to many model may become confused and act + * incorrectly. + */ +static int sctp_setsockopt_fragment_interleave(struct sock *sk, int *val, + unsigned int optlen) +{ + if (optlen != sizeof(int)) + return -EINVAL; + + sctp_sk(sk)->frag_interleave = !!*val; + + if (!sctp_sk(sk)->frag_interleave) + sctp_sk(sk)->ep->intl_enable = 0; + + return 0; +} + +/* + * 8.1.21. Set or Get the SCTP Partial Delivery Point + * (SCTP_PARTIAL_DELIVERY_POINT) + * + * This option will set or get the SCTP partial delivery point. This + * point is the size of a message where the partial delivery API will be + * invoked to help free up rwnd space for the peer. Setting this to a + * lower value will cause partial deliveries to happen more often. The + * calls argument is an integer that sets or gets the partial delivery + * point. Note also that the call will fail if the user attempts to set + * this value larger than the socket receive buffer size. + * + * Note that any single message having a length smaller than or equal to + * the SCTP partial delivery point will be delivered in one single read + * call as long as the user provided buffer is large enough to hold the + * message. + */ +static int sctp_setsockopt_partial_delivery_point(struct sock *sk, u32 *val, + unsigned int optlen) +{ + if (optlen != sizeof(u32)) + return -EINVAL; + + /* Note: We double the receive buffer from what the user sets + * it to be, also initial rwnd is based on rcvbuf/2. + */ + if (*val > (sk->sk_rcvbuf >> 1)) + return -EINVAL; + + sctp_sk(sk)->pd_point = *val; + + return 0; /* is this the right error code? */ +} + +/* + * 7.1.28. Set or Get the maximum burst (SCTP_MAX_BURST) + * + * This option will allow a user to change the maximum burst of packets + * that can be emitted by this association. Note that the default value + * is 4, and some implementations may restrict this setting so that it + * can only be lowered. + * + * NOTE: This text doesn't seem right. Do this on a socket basis with + * future associations inheriting the socket value. 
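SCTP_MAX_BURST, documented above, uses the same struct sctp_assoc_value pattern; SCTP_ALL_ASSOC applies the value both to the socket default and to every existing association of a one-to-many socket. A sketch, assumptions as before; set_max_burst is a made-up name:

/* Illustrative only: limit bursts to 8 packets everywhere on this socket. */
static int set_max_burst(int sd)
{
	struct sctp_assoc_value av = {
		.assoc_id    = SCTP_ALL_ASSOC,
		.assoc_value = 8,
	};

	return setsockopt(sd, IPPROTO_SCTP, SCTP_MAX_BURST, &av, sizeof(av));
}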
+ */ +static int sctp_setsockopt_maxburst(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + sctp_assoc_t assoc_id; + u32 assoc_value; + + if (optlen == sizeof(int)) { + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of int in max_burst socket option deprecated.\n" + "Use struct sctp_assoc_value instead\n", + current->comm, task_pid_nr(current)); + assoc_id = SCTP_FUTURE_ASSOC; + assoc_value = *((int *)params); + } else if (optlen == sizeof(struct sctp_assoc_value)) { + assoc_id = params->assoc_id; + assoc_value = params->assoc_value; + } else + return -EINVAL; + + asoc = sctp_id2assoc(sk, assoc_id); + if (!asoc && assoc_id > SCTP_ALL_ASSOC && sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + asoc->max_burst = assoc_value; + + return 0; + } + + if (sctp_style(sk, TCP)) + assoc_id = SCTP_FUTURE_ASSOC; + + if (assoc_id == SCTP_FUTURE_ASSOC || assoc_id == SCTP_ALL_ASSOC) + sp->max_burst = assoc_value; + + if (assoc_id == SCTP_CURRENT_ASSOC || assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &sp->ep->asocs, asocs) + asoc->max_burst = assoc_value; + + return 0; +} + +/* + * 7.1.18. Add a chunk that must be authenticated (SCTP_AUTH_CHUNK) + * + * This set option adds a chunk type that the user is requesting to be + * received only in an authenticated way. Changes to the list of chunks + * will only effect future associations on the socket. + */ +static int sctp_setsockopt_auth_chunk(struct sock *sk, + struct sctp_authchunk *val, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + + if (!ep->auth_enable) + return -EACCES; + + if (optlen != sizeof(struct sctp_authchunk)) + return -EINVAL; + + switch (val->sauth_chunk) { + case SCTP_CID_INIT: + case SCTP_CID_INIT_ACK: + case SCTP_CID_SHUTDOWN_COMPLETE: + case SCTP_CID_AUTH: + return -EINVAL; + } + + /* add this chunk id to the endpoint */ + return sctp_auth_ep_add_chunkid(ep, val->sauth_chunk); +} + +/* + * 7.1.19. Get or set the list of supported HMAC Identifiers (SCTP_HMAC_IDENT) + * + * This option gets or sets the list of HMAC algorithms that the local + * endpoint requires the peer to use. + */ +static int sctp_setsockopt_hmac_ident(struct sock *sk, + struct sctp_hmacalgo *hmacs, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + u32 idents; + + if (!ep->auth_enable) + return -EACCES; + + if (optlen < sizeof(struct sctp_hmacalgo)) + return -EINVAL; + optlen = min_t(unsigned int, optlen, sizeof(struct sctp_hmacalgo) + + SCTP_AUTH_NUM_HMACS * sizeof(u16)); + + idents = hmacs->shmac_num_idents; + if (idents == 0 || idents > SCTP_AUTH_NUM_HMACS || + (idents * sizeof(u16)) > (optlen - sizeof(struct sctp_hmacalgo))) + return -EINVAL; + + return sctp_auth_ep_set_hmacs(ep, hmacs); +} + +/* + * 7.1.20. Set a shared key (SCTP_AUTH_KEY) + * + * This option will set a shared secret key which is used to build an + * association shared key. + */ +static int sctp_setsockopt_auth_key(struct sock *sk, + struct sctp_authkey *authkey, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + int ret = -EINVAL; + + if (optlen <= sizeof(struct sctp_authkey)) + return -EINVAL; + /* authkey->sca_keylength is u16, so optlen can't be bigger than + * this. 
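The SCTP_AUTH_CHUNK and SCTP_HMAC_IDENT handlers above require auth to be enabled on the endpoint and refuse INIT, INIT-ACK, SHUTDOWN-COMPLETE and AUTH as protected chunk types. A hedged sketch that requests AUTH protection for ASCONF and offers two HMACs; 0xC1 is the ASCONF chunk type and setup_sctp_auth is a made-up helper:

/* Illustrative only: require authentication for ASCONF chunks and offer
 * SHA-256 first, then SHA-1, to peers of future associations. */
static int setup_sctp_auth(int sd)
{
	struct sctp_authchunk ac = { .sauth_chunk = 0xC1 };	/* ASCONF */
	union {
		struct sctp_hmacalgo ha;
		char buf[sizeof(struct sctp_hmacalgo) + 2 * sizeof(uint16_t)];
	} u;

	if (setsockopt(sd, IPPROTO_SCTP, SCTP_AUTH_CHUNK, &ac, sizeof(ac)) < 0)
		return -1;

	memset(&u, 0, sizeof(u));
	u.ha.shmac_num_idents = 2;
	u.ha.shmac_idents[0]  = SCTP_AUTH_HMAC_ID_SHA256;
	u.ha.shmac_idents[1]  = SCTP_AUTH_HMAC_ID_SHA1;

	return setsockopt(sd, IPPROTO_SCTP, SCTP_HMAC_IDENT, &u.ha, sizeof(u));
}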
+ */ + optlen = min_t(unsigned int, optlen, USHRT_MAX + sizeof(*authkey)); + + if (authkey->sca_keylength > optlen - sizeof(*authkey)) + goto out; + + asoc = sctp_id2assoc(sk, authkey->sca_assoc_id); + if (!asoc && authkey->sca_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + if (asoc) { + ret = sctp_auth_set_key(ep, asoc, authkey); + goto out; + } + + if (sctp_style(sk, TCP)) + authkey->sca_assoc_id = SCTP_FUTURE_ASSOC; + + if (authkey->sca_assoc_id == SCTP_FUTURE_ASSOC || + authkey->sca_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_set_key(ep, asoc, authkey); + if (ret) + goto out; + } + + ret = 0; + + if (authkey->sca_assoc_id == SCTP_CURRENT_ASSOC || + authkey->sca_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_set_key(ep, asoc, authkey); + + if (res && !ret) + ret = res; + } + } + +out: + memzero_explicit(authkey, optlen); + return ret; +} + +/* + * 7.1.21. Get or set the active shared key (SCTP_AUTH_ACTIVE_KEY) + * + * This option will get or set the active shared key to be used to build + * the association shared key. + */ +static int sctp_setsockopt_active_key(struct sock *sk, + struct sctp_authkeyid *val, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + int ret = 0; + + if (optlen != sizeof(struct sctp_authkeyid)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, val->scact_assoc_id); + if (!asoc && val->scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + return sctp_auth_set_active_key(ep, asoc, val->scact_keynumber); + + if (sctp_style(sk, TCP)) + val->scact_assoc_id = SCTP_FUTURE_ASSOC; + + if (val->scact_assoc_id == SCTP_FUTURE_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_set_active_key(ep, asoc, val->scact_keynumber); + if (ret) + return ret; + } + + if (val->scact_assoc_id == SCTP_CURRENT_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_set_active_key(ep, asoc, + val->scact_keynumber); + + if (res && !ret) + ret = res; + } + } + + return ret; +} + +/* + * 7.1.22. Delete a shared key (SCTP_AUTH_DELETE_KEY) + * + * This set option will delete a shared secret key from use. + */ +static int sctp_setsockopt_del_key(struct sock *sk, + struct sctp_authkeyid *val, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + int ret = 0; + + if (optlen != sizeof(struct sctp_authkeyid)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, val->scact_assoc_id); + if (!asoc && val->scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + return sctp_auth_del_key_id(ep, asoc, val->scact_keynumber); + + if (sctp_style(sk, TCP)) + val->scact_assoc_id = SCTP_FUTURE_ASSOC; + + if (val->scact_assoc_id == SCTP_FUTURE_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_del_key_id(ep, asoc, val->scact_keynumber); + if (ret) + return ret; + } + + if (val->scact_assoc_id == SCTP_CURRENT_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_del_key_id(ep, asoc, + val->scact_keynumber); + + if (res && !ret) + ret = res; + } + } + + return ret; +} + +/* + * 8.3.4 Deactivate a Shared Key (SCTP_AUTH_DEACTIVATE_KEY) + * + * This set option will deactivate a shared secret key. 
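SCTP_AUTH_KEY above installs a variable-length key whose bytes follow struct sctp_authkey, and SCTP_AUTH_ACTIVE_KEY selects which key number to use. A sketch that installs key number 1 and activates it for future associations; the helper name and the 64-byte cap are illustrative:

/* Illustrative only: install a shared key as key number 1 and make it the
 * active key for this endpoint's future associations. */
static int install_auth_key(int sd, const uint8_t *key, uint16_t keylen)
{
	union {
		struct sctp_authkey ak;
		char buf[sizeof(struct sctp_authkey) + 64];
	} u;
	struct sctp_authkeyid keyid;

	if (keylen == 0 || keylen > 64)
		return -1;

	memset(&u, 0, sizeof(u));
	u.ak.sca_assoc_id  = SCTP_FUTURE_ASSOC;
	u.ak.sca_keynumber = 1;
	u.ak.sca_keylength = keylen;
	memcpy(u.ak.sca_key, key, keylen);

	if (setsockopt(sd, IPPROTO_SCTP, SCTP_AUTH_KEY, &u.ak,
		       sizeof(u.ak) + keylen) < 0)
		return -1;

	memset(&keyid, 0, sizeof(keyid));
	keyid.scact_assoc_id  = SCTP_FUTURE_ASSOC;
	keyid.scact_keynumber = 1;

	return setsockopt(sd, IPPROTO_SCTP, SCTP_AUTH_ACTIVE_KEY,
			  &keyid, sizeof(keyid));
}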
+ */ +static int sctp_setsockopt_deactivate_key(struct sock *sk, + struct sctp_authkeyid *val, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + int ret = 0; + + if (optlen != sizeof(struct sctp_authkeyid)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, val->scact_assoc_id); + if (!asoc && val->scact_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + return sctp_auth_deact_key_id(ep, asoc, val->scact_keynumber); + + if (sctp_style(sk, TCP)) + val->scact_assoc_id = SCTP_FUTURE_ASSOC; + + if (val->scact_assoc_id == SCTP_FUTURE_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + ret = sctp_auth_deact_key_id(ep, asoc, val->scact_keynumber); + if (ret) + return ret; + } + + if (val->scact_assoc_id == SCTP_CURRENT_ASSOC || + val->scact_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &ep->asocs, asocs) { + int res = sctp_auth_deact_key_id(ep, asoc, + val->scact_keynumber); + + if (res && !ret) + ret = res; + } + } + + return ret; +} + +/* + * 8.1.23 SCTP_AUTO_ASCONF + * + * This option will enable or disable the use of the automatic generation of + * ASCONF chunks to add and delete addresses to an existing association. Note + * that this option has two caveats namely: a) it only affects sockets that + * are bound to all addresses available to the SCTP stack, and b) the system + * administrator may have an overriding control that turns the ASCONF feature + * off no matter what setting the socket option may have. + * This option expects an integer boolean flag, where a non-zero value turns on + * the option, and a zero value turns off the option. + * Note. In this implementation, socket operation overrides default parameter + * being set by sysctl as well as FreeBSD implementation + */ +static int sctp_setsockopt_auto_asconf(struct sock *sk, int *val, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + + if (optlen < sizeof(int)) + return -EINVAL; + if (!sctp_is_ep_boundall(sk) && *val) + return -EINVAL; + if ((*val && sp->do_auto_asconf) || (!*val && !sp->do_auto_asconf)) + return 0; + + spin_lock_bh(&sock_net(sk)->sctp.addr_wq_lock); + if (*val == 0 && sp->do_auto_asconf) { + list_del(&sp->auto_asconf_list); + sp->do_auto_asconf = 0; + } else if (*val && !sp->do_auto_asconf) { + list_add_tail(&sp->auto_asconf_list, + &sock_net(sk)->sctp.auto_asconf_splist); + sp->do_auto_asconf = 1; + } + spin_unlock_bh(&sock_net(sk)->sctp.addr_wq_lock); + return 0; +} + +/* + * SCTP_PEER_ADDR_THLDS + * + * This option allows us to alter the partially failed threshold for one or all + * transports in an association. See Section 6.1 of: + * http://www.ietf.org/id/draft-nishida-tsvwg-sctp-failover-05.txt + */ +static int sctp_setsockopt_paddr_thresholds(struct sock *sk, + struct sctp_paddrthlds_v2 *val, + unsigned int optlen, bool v2) +{ + struct sctp_transport *trans; + struct sctp_association *asoc; + int len; + + len = v2 ? 
sizeof(*val) : sizeof(struct sctp_paddrthlds); + if (optlen < len) + return -EINVAL; + + if (v2 && val->spt_pathpfthld > val->spt_pathcpthld) + return -EINVAL; + + if (!sctp_is_any(sk, (const union sctp_addr *)&val->spt_address)) { + trans = sctp_addr_id2transport(sk, &val->spt_address, + val->spt_assoc_id); + if (!trans) + return -ENOENT; + + if (val->spt_pathmaxrxt) + trans->pathmaxrxt = val->spt_pathmaxrxt; + if (v2) + trans->ps_retrans = val->spt_pathcpthld; + trans->pf_retrans = val->spt_pathpfthld; + + return 0; + } + + asoc = sctp_id2assoc(sk, val->spt_assoc_id); + if (!asoc && val->spt_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + list_for_each_entry(trans, &asoc->peer.transport_addr_list, + transports) { + if (val->spt_pathmaxrxt) + trans->pathmaxrxt = val->spt_pathmaxrxt; + if (v2) + trans->ps_retrans = val->spt_pathcpthld; + trans->pf_retrans = val->spt_pathpfthld; + } + + if (val->spt_pathmaxrxt) + asoc->pathmaxrxt = val->spt_pathmaxrxt; + if (v2) + asoc->ps_retrans = val->spt_pathcpthld; + asoc->pf_retrans = val->spt_pathpfthld; + } else { + struct sctp_sock *sp = sctp_sk(sk); + + if (val->spt_pathmaxrxt) + sp->pathmaxrxt = val->spt_pathmaxrxt; + if (v2) + sp->ps_retrans = val->spt_pathcpthld; + sp->pf_retrans = val->spt_pathpfthld; + } + + return 0; +} + +static int sctp_setsockopt_recvrcvinfo(struct sock *sk, int *val, + unsigned int optlen) +{ + if (optlen < sizeof(int)) + return -EINVAL; + + sctp_sk(sk)->recvrcvinfo = (*val == 0) ? 0 : 1; + + return 0; +} + +static int sctp_setsockopt_recvnxtinfo(struct sock *sk, int *val, + unsigned int optlen) +{ + if (optlen < sizeof(int)) + return -EINVAL; + + sctp_sk(sk)->recvnxtinfo = (*val == 0) ? 0 : 1; + + return 0; +} + +static int sctp_setsockopt_pr_supported(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + + if (optlen != sizeof(*params)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + sctp_sk(sk)->ep->prsctp_enable = !!params->assoc_value; + + return 0; +} + +static int sctp_setsockopt_default_prinfo(struct sock *sk, + struct sctp_default_prinfo *info, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + int retval = -EINVAL; + + if (optlen != sizeof(*info)) + goto out; + + if (info->pr_policy & ~SCTP_PR_SCTP_MASK) + goto out; + + if (info->pr_policy == SCTP_PR_SCTP_NONE) + info->pr_value = 0; + + asoc = sctp_id2assoc(sk, info->pr_assoc_id); + if (!asoc && info->pr_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + retval = 0; + + if (asoc) { + SCTP_PR_SET_POLICY(asoc->default_flags, info->pr_policy); + asoc->default_timetolive = info->pr_value; + goto out; + } + + if (sctp_style(sk, TCP)) + info->pr_assoc_id = SCTP_FUTURE_ASSOC; + + if (info->pr_assoc_id == SCTP_FUTURE_ASSOC || + info->pr_assoc_id == SCTP_ALL_ASSOC) { + SCTP_PR_SET_POLICY(sp->default_flags, info->pr_policy); + sp->default_timetolive = info->pr_value; + } + + if (info->pr_assoc_id == SCTP_CURRENT_ASSOC || + info->pr_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + SCTP_PR_SET_POLICY(asoc->default_flags, + info->pr_policy); + asoc->default_timetolive = info->pr_value; + } + } + +out: + return retval; +} + +static int sctp_setsockopt_reconfig_supported(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct 
sctp_association *asoc; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + goto out; + + sctp_sk(sk)->ep->reconf_enable = !!params->assoc_value; + + retval = 0; + +out: + return retval; +} + +static int sctp_setsockopt_enable_strreset(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_association *asoc; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + if (params->assoc_value & (~SCTP_ENABLE_STRRESET_MASK)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + goto out; + + retval = 0; + + if (asoc) { + asoc->strreset_enable = params->assoc_value; + goto out; + } + + if (sctp_style(sk, TCP)) + params->assoc_id = SCTP_FUTURE_ASSOC; + + if (params->assoc_id == SCTP_FUTURE_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) + ep->strreset_enable = params->assoc_value; + + if (params->assoc_id == SCTP_CURRENT_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) + list_for_each_entry(asoc, &ep->asocs, asocs) + asoc->strreset_enable = params->assoc_value; + +out: + return retval; +} + +static int sctp_setsockopt_reset_streams(struct sock *sk, + struct sctp_reset_streams *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + + if (optlen < sizeof(*params)) + return -EINVAL; + /* srs_number_streams is u16, so optlen can't be bigger than this. */ + optlen = min_t(unsigned int, optlen, USHRT_MAX + + sizeof(__u16) * sizeof(*params)); + + if (params->srs_number_streams * sizeof(__u16) > + optlen - sizeof(*params)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, params->srs_assoc_id); + if (!asoc) + return -EINVAL; + + return sctp_send_reset_streams(asoc, params); +} + +static int sctp_setsockopt_reset_assoc(struct sock *sk, sctp_assoc_t *associd, + unsigned int optlen) +{ + struct sctp_association *asoc; + + if (optlen != sizeof(*associd)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, *associd); + if (!asoc) + return -EINVAL; + + return sctp_send_reset_assoc(asoc); +} + +static int sctp_setsockopt_add_streams(struct sock *sk, + struct sctp_add_streams *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + + if (optlen != sizeof(*params)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, params->sas_assoc_id); + if (!asoc) + return -EINVAL; + + return sctp_send_add_streams(asoc, params); +} + +static int sctp_setsockopt_scheduler(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + int retval = 0; + + if (optlen < sizeof(*params)) + return -EINVAL; + + if (params->assoc_value > SCTP_SS_MAX) + return -EINVAL; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + return sctp_sched_set_sched(asoc, params->assoc_value); + + if (sctp_style(sk, TCP)) + params->assoc_id = SCTP_FUTURE_ASSOC; + + if (params->assoc_id == SCTP_FUTURE_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) + sp->default_ss = params->assoc_value; + + if (params->assoc_id == SCTP_CURRENT_ASSOC || + params->assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + int ret = sctp_sched_set_sched(asoc, + params->assoc_value); + + if (ret && !retval) + retval = ret; + 
} + } + + return retval; +} + +static int sctp_setsockopt_scheduler_value(struct sock *sk, + struct sctp_stream_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + int retval = -EINVAL; + + if (optlen < sizeof(*params)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_CURRENT_ASSOC && + sctp_style(sk, UDP)) + goto out; + + if (asoc) { + retval = sctp_sched_set_value(asoc, params->stream_id, + params->stream_value, GFP_KERNEL); + goto out; + } + + retval = 0; + + list_for_each_entry(asoc, &sctp_sk(sk)->ep->asocs, asocs) { + int ret = sctp_sched_set_value(asoc, params->stream_id, + params->stream_value, + GFP_KERNEL); + if (ret && !retval) /* try to return the 1st error. */ + retval = ret; + } + +out: + return retval; +} + +static int sctp_setsockopt_interleaving_supported(struct sock *sk, + struct sctp_assoc_value *p, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + + if (optlen < sizeof(*p)) + return -EINVAL; + + asoc = sctp_id2assoc(sk, p->assoc_id); + if (!asoc && p->assoc_id != SCTP_FUTURE_ASSOC && sctp_style(sk, UDP)) + return -EINVAL; + + if (!sock_net(sk)->sctp.intl_enable || !sp->frag_interleave) { + return -EPERM; + } + + sp->ep->intl_enable = !!p->assoc_value; + return 0; +} + +static int sctp_setsockopt_reuse_port(struct sock *sk, int *val, + unsigned int optlen) +{ + if (!sctp_style(sk, TCP)) + return -EOPNOTSUPP; + + if (sctp_sk(sk)->ep->base.bind_addr.port) + return -EFAULT; + + if (optlen < sizeof(int)) + return -EINVAL; + + sctp_sk(sk)->reuse = !!*val; + + return 0; +} + +static int sctp_assoc_ulpevent_type_set(struct sctp_event *param, + struct sctp_association *asoc) +{ + struct sctp_ulpevent *event; + + sctp_ulpevent_type_set(&asoc->subscribe, param->se_type, param->se_on); + + if (param->se_type == SCTP_SENDER_DRY_EVENT && param->se_on) { + if (sctp_outq_is_empty(&asoc->outqueue)) { + event = sctp_ulpevent_make_sender_dry_event(asoc, + GFP_USER | __GFP_NOWARN); + if (!event) + return -ENOMEM; + + asoc->stream.si->enqueue_event(&asoc->ulpq, event); + } + } + + return 0; +} + +static int sctp_setsockopt_event(struct sock *sk, struct sctp_event *param, + unsigned int optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + int retval = 0; + + if (optlen < sizeof(*param)) + return -EINVAL; + + if (param->se_type < SCTP_SN_TYPE_BASE || + param->se_type > SCTP_SN_TYPE_MAX) + return -EINVAL; + + asoc = sctp_id2assoc(sk, param->se_assoc_id); + if (!asoc && param->se_assoc_id > SCTP_ALL_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + return sctp_assoc_ulpevent_type_set(param, asoc); + + if (sctp_style(sk, TCP)) + param->se_assoc_id = SCTP_FUTURE_ASSOC; + + if (param->se_assoc_id == SCTP_FUTURE_ASSOC || + param->se_assoc_id == SCTP_ALL_ASSOC) + sctp_ulpevent_type_set(&sp->subscribe, + param->se_type, param->se_on); + + if (param->se_assoc_id == SCTP_CURRENT_ASSOC || + param->se_assoc_id == SCTP_ALL_ASSOC) { + list_for_each_entry(asoc, &sp->ep->asocs, asocs) { + int ret = sctp_assoc_ulpevent_type_set(param, asoc); + + if (ret && !retval) + retval = ret; + } + } + + return retval; +} + +static int sctp_setsockopt_asconf_supported(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + struct sctp_endpoint *ep; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id 
!= SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + goto out; + + ep = sctp_sk(sk)->ep; + ep->asconf_enable = !!params->assoc_value; + + if (ep->asconf_enable && ep->auth_enable) { + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF); + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF_ACK); + } + + retval = 0; + +out: + return retval; +} + +static int sctp_setsockopt_auth_supported(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + struct sctp_endpoint *ep; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + goto out; + + ep = sctp_sk(sk)->ep; + if (params->assoc_value) { + retval = sctp_auth_init(ep, GFP_KERNEL); + if (retval) + goto out; + if (ep->asconf_enable) { + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF); + sctp_auth_ep_add_chunkid(ep, SCTP_CID_ASCONF_ACK); + } + } + + ep->auth_enable = !!params->assoc_value; + retval = 0; + +out: + return retval; +} + +static int sctp_setsockopt_ecn_supported(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + goto out; + + sctp_sk(sk)->ep->ecn_enable = !!params->assoc_value; + retval = 0; + +out: + return retval; +} + +static int sctp_setsockopt_pf_expose(struct sock *sk, + struct sctp_assoc_value *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + int retval = -EINVAL; + + if (optlen != sizeof(*params)) + goto out; + + if (params->assoc_value > SCTP_PF_EXPOSE_MAX) + goto out; + + asoc = sctp_id2assoc(sk, params->assoc_id); + if (!asoc && params->assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + goto out; + + if (asoc) + asoc->pf_expose = params->assoc_value; + else + sctp_sk(sk)->pf_expose = params->assoc_value; + retval = 0; + +out: + return retval; +} + +static int sctp_setsockopt_encap_port(struct sock *sk, + struct sctp_udpencaps *encap, + unsigned int optlen) +{ + struct sctp_association *asoc; + struct sctp_transport *t; + __be16 encap_port; + + if (optlen != sizeof(*encap)) + return -EINVAL; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + encap_port = (__force __be16)encap->sue_port; + if (!sctp_is_any(sk, (union sctp_addr *)&encap->sue_address)) { + t = sctp_addr_id2transport(sk, &encap->sue_address, + encap->sue_assoc_id); + if (!t) + return -EINVAL; + + t->encap_port = encap_port; + return 0; + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, encap->sue_assoc_id); + if (!asoc && encap->sue_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* If changes are for association, also apply encap_port to + * each transport. 
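+ *
+ * From user space the same SCTP_REMOTE_UDP_ENCAPS_PORT option covers all
+ * three scopes (transport, association, socket) depending on sue_address
+ * and sue_assoc_id.  A hedged sketch, assuming the RFC 6951 UDP port and
+ * leaving sue_address zeroed so the whole socket is affected:
+ *
+ *   struct sctp_udpencaps enc = {
+ *           .sue_assoc_id = SCTP_FUTURE_ASSOC,
+ *           .sue_port     = htons(9899),
+ *   };
+ *   setsockopt(sd, IPPROTO_SCTP, SCTP_REMOTE_UDP_ENCAPS_PORT,
+ *              &enc, sizeof(enc));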
+ */ + if (asoc) { + list_for_each_entry(t, &asoc->peer.transport_addr_list, + transports) + t->encap_port = encap_port; + + asoc->encap_port = encap_port; + return 0; + } + + sctp_sk(sk)->encap_port = encap_port; + return 0; +} + +static int sctp_setsockopt_probe_interval(struct sock *sk, + struct sctp_probeinterval *params, + unsigned int optlen) +{ + struct sctp_association *asoc; + struct sctp_transport *t; + __u32 probe_interval; + + if (optlen != sizeof(*params)) + return -EINVAL; + + probe_interval = params->spi_interval; + if (probe_interval && probe_interval < SCTP_PROBE_TIMER_MIN) + return -EINVAL; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + if (!sctp_is_any(sk, (union sctp_addr *)¶ms->spi_address)) { + t = sctp_addr_id2transport(sk, ¶ms->spi_address, + params->spi_assoc_id); + if (!t) + return -EINVAL; + + t->probe_interval = msecs_to_jiffies(probe_interval); + sctp_transport_pl_reset(t); + return 0; + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, params->spi_assoc_id); + if (!asoc && params->spi_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* If changes are for association, also apply probe_interval to + * each transport. + */ + if (asoc) { + list_for_each_entry(t, &asoc->peer.transport_addr_list, transports) { + t->probe_interval = msecs_to_jiffies(probe_interval); + sctp_transport_pl_reset(t); + } + + asoc->probe_interval = msecs_to_jiffies(probe_interval); + return 0; + } + + sctp_sk(sk)->probe_interval = probe_interval; + return 0; +} + +/* API 6.2 setsockopt(), getsockopt() + * + * Applications use setsockopt() and getsockopt() to set or retrieve + * socket options. Socket options are used to change the default + * behavior of sockets calls. They are described in Section 7. + * + * The syntax is: + * + * ret = getsockopt(int sd, int level, int optname, void __user *optval, + * int __user *optlen); + * ret = setsockopt(int sd, int level, int optname, const void __user *optval, + * int optlen); + * + * sd - the socket descript. + * level - set to IPPROTO_SCTP for all SCTP options. + * optname - the option name. + * optval - the buffer to store the value of the option. + * optlen - the size of the buffer. + */ +static int sctp_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + void *kopt = NULL; + int retval = 0; + + pr_debug("%s: sk:%p, optname:%d\n", __func__, sk, optname); + + /* I can hardly begin to describe how wrong this is. This is + * so broken as to be worse than useless. The API draft + * REALLY is NOT helpful here... I am not convinced that the + * semantics of setsockopt() with a level OTHER THAN SOL_SCTP + * are at all well-founded. + */ + if (level != SOL_SCTP) { + struct sctp_af *af = sctp_sk(sk)->pf->af; + + return af->setsockopt(sk, level, optname, optval, optlen); + } + + if (optlen > 0) { + /* Trim it to the biggest size sctp sockopt may need if necessary */ + optlen = min_t(unsigned int, optlen, + PAGE_ALIGN(USHRT_MAX + + sizeof(__u16) * sizeof(struct sctp_reset_streams))); + kopt = memdup_sockptr(optval, optlen); + if (IS_ERR(kopt)) + return PTR_ERR(kopt); + } + + lock_sock(sk); + + switch (optname) { + case SCTP_SOCKOPT_BINDX_ADD: + /* 'optlen' is the size of the addresses buffer. 
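+ * The buffer itself is a packed run of struct sockaddr_in and/or
+ * struct sockaddr_in6 entries.  User space normally builds it through
+ * the lksctp-tools sctp_bindx() helper; a minimal sketch (illustrative,
+ * documentation address used):
+ *
+ *   struct sockaddr_in a = {
+ *           .sin_family = AF_INET,
+ *           .sin_port   = htons(5000),
+ *   };
+ *   inet_pton(AF_INET, "192.0.2.1", &a.sin_addr);
+ *   sctp_bindx(sd, (struct sockaddr *)&a, 1, SCTP_BINDX_ADD_ADDR);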
*/ + retval = sctp_setsockopt_bindx(sk, kopt, optlen, + SCTP_BINDX_ADD_ADDR); + break; + + case SCTP_SOCKOPT_BINDX_REM: + /* 'optlen' is the size of the addresses buffer. */ + retval = sctp_setsockopt_bindx(sk, kopt, optlen, + SCTP_BINDX_REM_ADDR); + break; + + case SCTP_SOCKOPT_CONNECTX_OLD: + /* 'optlen' is the size of the addresses buffer. */ + retval = sctp_setsockopt_connectx_old(sk, kopt, optlen); + break; + + case SCTP_SOCKOPT_CONNECTX: + /* 'optlen' is the size of the addresses buffer. */ + retval = sctp_setsockopt_connectx(sk, kopt, optlen); + break; + + case SCTP_DISABLE_FRAGMENTS: + retval = sctp_setsockopt_disable_fragments(sk, kopt, optlen); + break; + + case SCTP_EVENTS: + retval = sctp_setsockopt_events(sk, kopt, optlen); + break; + + case SCTP_AUTOCLOSE: + retval = sctp_setsockopt_autoclose(sk, kopt, optlen); + break; + + case SCTP_PEER_ADDR_PARAMS: + retval = sctp_setsockopt_peer_addr_params(sk, kopt, optlen); + break; + + case SCTP_DELAYED_SACK: + retval = sctp_setsockopt_delayed_ack(sk, kopt, optlen); + break; + case SCTP_PARTIAL_DELIVERY_POINT: + retval = sctp_setsockopt_partial_delivery_point(sk, kopt, optlen); + break; + + case SCTP_INITMSG: + retval = sctp_setsockopt_initmsg(sk, kopt, optlen); + break; + case SCTP_DEFAULT_SEND_PARAM: + retval = sctp_setsockopt_default_send_param(sk, kopt, optlen); + break; + case SCTP_DEFAULT_SNDINFO: + retval = sctp_setsockopt_default_sndinfo(sk, kopt, optlen); + break; + case SCTP_PRIMARY_ADDR: + retval = sctp_setsockopt_primary_addr(sk, kopt, optlen); + break; + case SCTP_SET_PEER_PRIMARY_ADDR: + retval = sctp_setsockopt_peer_primary_addr(sk, kopt, optlen); + break; + case SCTP_NODELAY: + retval = sctp_setsockopt_nodelay(sk, kopt, optlen); + break; + case SCTP_RTOINFO: + retval = sctp_setsockopt_rtoinfo(sk, kopt, optlen); + break; + case SCTP_ASSOCINFO: + retval = sctp_setsockopt_associnfo(sk, kopt, optlen); + break; + case SCTP_I_WANT_MAPPED_V4_ADDR: + retval = sctp_setsockopt_mappedv4(sk, kopt, optlen); + break; + case SCTP_MAXSEG: + retval = sctp_setsockopt_maxseg(sk, kopt, optlen); + break; + case SCTP_ADAPTATION_LAYER: + retval = sctp_setsockopt_adaptation_layer(sk, kopt, optlen); + break; + case SCTP_CONTEXT: + retval = sctp_setsockopt_context(sk, kopt, optlen); + break; + case SCTP_FRAGMENT_INTERLEAVE: + retval = sctp_setsockopt_fragment_interleave(sk, kopt, optlen); + break; + case SCTP_MAX_BURST: + retval = sctp_setsockopt_maxburst(sk, kopt, optlen); + break; + case SCTP_AUTH_CHUNK: + retval = sctp_setsockopt_auth_chunk(sk, kopt, optlen); + break; + case SCTP_HMAC_IDENT: + retval = sctp_setsockopt_hmac_ident(sk, kopt, optlen); + break; + case SCTP_AUTH_KEY: + retval = sctp_setsockopt_auth_key(sk, kopt, optlen); + break; + case SCTP_AUTH_ACTIVE_KEY: + retval = sctp_setsockopt_active_key(sk, kopt, optlen); + break; + case SCTP_AUTH_DELETE_KEY: + retval = sctp_setsockopt_del_key(sk, kopt, optlen); + break; + case SCTP_AUTH_DEACTIVATE_KEY: + retval = sctp_setsockopt_deactivate_key(sk, kopt, optlen); + break; + case SCTP_AUTO_ASCONF: + retval = sctp_setsockopt_auto_asconf(sk, kopt, optlen); + break; + case SCTP_PEER_ADDR_THLDS: + retval = sctp_setsockopt_paddr_thresholds(sk, kopt, optlen, + false); + break; + case SCTP_PEER_ADDR_THLDS_V2: + retval = sctp_setsockopt_paddr_thresholds(sk, kopt, optlen, + true); + break; + case SCTP_RECVRCVINFO: + retval = sctp_setsockopt_recvrcvinfo(sk, kopt, optlen); + break; + case SCTP_RECVNXTINFO: + retval = sctp_setsockopt_recvnxtinfo(sk, kopt, optlen); + break; + case SCTP_PR_SUPPORTED: + 
retval = sctp_setsockopt_pr_supported(sk, kopt, optlen); + break; + case SCTP_DEFAULT_PRINFO: + retval = sctp_setsockopt_default_prinfo(sk, kopt, optlen); + break; + case SCTP_RECONFIG_SUPPORTED: + retval = sctp_setsockopt_reconfig_supported(sk, kopt, optlen); + break; + case SCTP_ENABLE_STREAM_RESET: + retval = sctp_setsockopt_enable_strreset(sk, kopt, optlen); + break; + case SCTP_RESET_STREAMS: + retval = sctp_setsockopt_reset_streams(sk, kopt, optlen); + break; + case SCTP_RESET_ASSOC: + retval = sctp_setsockopt_reset_assoc(sk, kopt, optlen); + break; + case SCTP_ADD_STREAMS: + retval = sctp_setsockopt_add_streams(sk, kopt, optlen); + break; + case SCTP_STREAM_SCHEDULER: + retval = sctp_setsockopt_scheduler(sk, kopt, optlen); + break; + case SCTP_STREAM_SCHEDULER_VALUE: + retval = sctp_setsockopt_scheduler_value(sk, kopt, optlen); + break; + case SCTP_INTERLEAVING_SUPPORTED: + retval = sctp_setsockopt_interleaving_supported(sk, kopt, + optlen); + break; + case SCTP_REUSE_PORT: + retval = sctp_setsockopt_reuse_port(sk, kopt, optlen); + break; + case SCTP_EVENT: + retval = sctp_setsockopt_event(sk, kopt, optlen); + break; + case SCTP_ASCONF_SUPPORTED: + retval = sctp_setsockopt_asconf_supported(sk, kopt, optlen); + break; + case SCTP_AUTH_SUPPORTED: + retval = sctp_setsockopt_auth_supported(sk, kopt, optlen); + break; + case SCTP_ECN_SUPPORTED: + retval = sctp_setsockopt_ecn_supported(sk, kopt, optlen); + break; + case SCTP_EXPOSE_POTENTIALLY_FAILED_STATE: + retval = sctp_setsockopt_pf_expose(sk, kopt, optlen); + break; + case SCTP_REMOTE_UDP_ENCAPS_PORT: + retval = sctp_setsockopt_encap_port(sk, kopt, optlen); + break; + case SCTP_PLPMTUD_PROBE_INTERVAL: + retval = sctp_setsockopt_probe_interval(sk, kopt, optlen); + break; + default: + retval = -ENOPROTOOPT; + break; + } + + release_sock(sk); + kfree(kopt); + return retval; +} + +/* API 3.1.6 connect() - UDP Style Syntax + * + * An application may use the connect() call in the UDP model to initiate an + * association without sending data. + * + * The syntax is: + * + * ret = connect(int sd, const struct sockaddr *nam, socklen_t len); + * + * sd: the socket descriptor to have a new association added to. + * + * nam: the address structure (either struct sockaddr_in or struct + * sockaddr_in6 defined in RFC2553 [7]). + * + * len: the size of the address. + */ +static int sctp_connect(struct sock *sk, struct sockaddr *addr, + int addr_len, int flags) +{ + struct sctp_af *af; + int err = -EINVAL; + + lock_sock(sk); + pr_debug("%s: sk:%p, sockaddr:%p, addr_len:%d\n", __func__, sk, + addr, addr_len); + + /* Validate addr_len before calling common connect/connectx routine. */ + af = sctp_get_af_specific(addr->sa_family); + if (af && addr_len >= af->sockaddr_len) + err = __sctp_connect(sk, addr, af->sockaddr_len, flags, NULL); + + release_sock(sk); + return err; +} + +int sctp_inet_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + if (addr_len < sizeof(uaddr->sa_family)) + return -EINVAL; + + if (uaddr->sa_family == AF_UNSPEC) + return -EOPNOTSUPP; + + return sctp_connect(sock->sk, uaddr, addr_len, flags); +} + +/* FIXME: Write comments. */ +static int sctp_disconnect(struct sock *sk, int flags) +{ + return -EOPNOTSUPP; /* STUB */ +} + +/* 4.1.4 accept() - TCP Style Syntax + * + * Applications use accept() call to remove an established SCTP + * association from the accept queue of the endpoint. A new socket + * descriptor will be returned from accept() to represent the newly + * formed association. 
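+ *
+ * A minimal sketch of the corresponding user-space call (illustrative;
+ * 'listen_sd' is assumed to be a one-to-one style SCTP socket on which
+ * listen() has already been called):
+ *
+ *   int new_sd = accept(listen_sd, NULL, NULL);
+ *
+ * Each successful call removes one established association from the
+ * accept queue and returns it as its own one-to-one socket.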
+ */ +static struct sock *sctp_accept(struct sock *sk, int flags, int *err, bool kern) +{ + struct sctp_sock *sp; + struct sctp_endpoint *ep; + struct sock *newsk = NULL; + struct sctp_association *asoc; + long timeo; + int error = 0; + + lock_sock(sk); + + sp = sctp_sk(sk); + ep = sp->ep; + + if (!sctp_style(sk, TCP)) { + error = -EOPNOTSUPP; + goto out; + } + + if (!sctp_sstate(sk, LISTENING)) { + error = -EINVAL; + goto out; + } + + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + + error = sctp_wait_for_accept(sk, timeo); + if (error) + goto out; + + /* We treat the list of associations on the endpoint as the accept + * queue and pick the first association on the list. + */ + asoc = list_entry(ep->asocs.next, struct sctp_association, asocs); + + newsk = sp->pf->create_accept_sk(sk, asoc, kern); + if (!newsk) { + error = -ENOMEM; + goto out; + } + + /* Populate the fields of the newsk from the oldsk and migrate the + * asoc to the newsk. + */ + error = sctp_sock_migrate(sk, newsk, asoc, SCTP_SOCKET_TCP); + if (error) { + sk_common_release(newsk); + newsk = NULL; + } + +out: + release_sock(sk); + *err = error; + return newsk; +} + +/* The SCTP ioctl handler. */ +static int sctp_ioctl(struct sock *sk, int cmd, unsigned long arg) +{ + int rc = -ENOTCONN; + + lock_sock(sk); + + /* + * SEQPACKET-style sockets in LISTENING state are valid, for + * SCTP, so only discard TCP-style sockets in LISTENING state. + */ + if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING)) + goto out; + + switch (cmd) { + case SIOCINQ: { + struct sk_buff *skb; + unsigned int amount = 0; + + skb = skb_peek(&sk->sk_receive_queue); + if (skb != NULL) { + /* + * We will only return the amount of this packet since + * that is all that will be read. + */ + amount = skb->len; + } + rc = put_user(amount, (int __user *)arg); + break; + } + default: + rc = -ENOIOCTLCMD; + break; + } +out: + release_sock(sk); + return rc; +} + +/* This is the function which gets called during socket creation to + * initialized the SCTP-specific portion of the sock. + * The sock structure should already be zero-filled memory. + */ +static int sctp_init_sock(struct sock *sk) +{ + struct net *net = sock_net(sk); + struct sctp_sock *sp; + + pr_debug("%s: sk:%p\n", __func__, sk); + + sp = sctp_sk(sk); + + /* Initialize the SCTP per socket area. */ + switch (sk->sk_type) { + case SOCK_SEQPACKET: + sp->type = SCTP_SOCKET_UDP; + break; + case SOCK_STREAM: + sp->type = SCTP_SOCKET_TCP; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sk->sk_gso_type = SKB_GSO_SCTP; + + /* Initialize default send parameters. These parameters can be + * modified with the SCTP_DEFAULT_SEND_PARAM socket option. + */ + sp->default_stream = 0; + sp->default_ppid = 0; + sp->default_flags = 0; + sp->default_context = 0; + sp->default_timetolive = 0; + + sp->default_rcv_context = 0; + sp->max_burst = net->sctp.max_burst; + + sp->sctp_hmac_alg = net->sctp.sctp_hmac_alg; + + /* Initialize default setup parameters. These parameters + * can be modified with the SCTP_INITMSG socket option or + * overridden by the SCTP_INIT CMSG. + */ + sp->initmsg.sinit_num_ostreams = sctp_max_outstreams; + sp->initmsg.sinit_max_instreams = sctp_max_instreams; + sp->initmsg.sinit_max_attempts = net->sctp.max_retrans_init; + sp->initmsg.sinit_max_init_timeo = net->sctp.rto_max; + + /* Initialize default RTO related parameters. These parameters can + * be modified for with the SCTP_RTOINFO socket option. 
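+ * A hedged user-space sketch of overriding them (illustrative only;
+ * values are in milliseconds and mirror the RFC 4960 defaults):
+ *
+ *   struct sctp_rtoinfo rto = {
+ *           .srto_assoc_id = SCTP_FUTURE_ASSOC,
+ *           .srto_initial  = 3000,
+ *           .srto_max      = 60000,
+ *           .srto_min      = 1000,
+ *   };
+ *   setsockopt(sd, IPPROTO_SCTP, SCTP_RTOINFO, &rto, sizeof(rto));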
+ */ + sp->rtoinfo.srto_initial = net->sctp.rto_initial; + sp->rtoinfo.srto_max = net->sctp.rto_max; + sp->rtoinfo.srto_min = net->sctp.rto_min; + + /* Initialize default association related parameters. These parameters + * can be modified with the SCTP_ASSOCINFO socket option. + */ + sp->assocparams.sasoc_asocmaxrxt = net->sctp.max_retrans_association; + sp->assocparams.sasoc_number_peer_destinations = 0; + sp->assocparams.sasoc_peer_rwnd = 0; + sp->assocparams.sasoc_local_rwnd = 0; + sp->assocparams.sasoc_cookie_life = net->sctp.valid_cookie_life; + + /* Initialize default event subscriptions. By default, all the + * options are off. + */ + sp->subscribe = 0; + + /* Default Peer Address Parameters. These defaults can + * be modified via SCTP_PEER_ADDR_PARAMS + */ + sp->hbinterval = net->sctp.hb_interval; + sp->udp_port = htons(net->sctp.udp_port); + sp->encap_port = htons(net->sctp.encap_port); + sp->pathmaxrxt = net->sctp.max_retrans_path; + sp->pf_retrans = net->sctp.pf_retrans; + sp->ps_retrans = net->sctp.ps_retrans; + sp->pf_expose = net->sctp.pf_expose; + sp->pathmtu = 0; /* allow default discovery */ + sp->sackdelay = net->sctp.sack_timeout; + sp->sackfreq = 2; + sp->param_flags = SPP_HB_ENABLE | + SPP_PMTUD_ENABLE | + SPP_SACKDELAY_ENABLE; + sp->default_ss = SCTP_SS_DEFAULT; + + /* If enabled no SCTP message fragmentation will be performed. + * Configure through SCTP_DISABLE_FRAGMENTS socket option. + */ + sp->disable_fragments = 0; + + /* Enable Nagle algorithm by default. */ + sp->nodelay = 0; + + sp->recvrcvinfo = 0; + sp->recvnxtinfo = 0; + + /* Enable by default. */ + sp->v4mapped = 1; + + /* Auto-close idle associations after the configured + * number of seconds. A value of 0 disables this + * feature. Configure through the SCTP_AUTOCLOSE socket option, + * for UDP-style sockets only. + */ + sp->autoclose = 0; + + /* User specified fragmentation limit. */ + sp->user_frag = 0; + + sp->adaptation_ind = 0; + + sp->pf = sctp_get_pf_specific(sk->sk_family); + + /* Control variables for partial data delivery. */ + atomic_set(&sp->pd_mode, 0); + skb_queue_head_init(&sp->pd_lobby); + sp->frag_interleave = 0; + sp->probe_interval = net->sctp.probe_interval; + + /* Create a per socket endpoint structure. Even if we + * change the data structure relationships, this may still + * be useful for storing pre-connect address information. + */ + sp->ep = sctp_endpoint_new(sk, GFP_KERNEL); + if (!sp->ep) + return -ENOMEM; + + sp->hmac = NULL; + + sk->sk_destruct = sctp_destruct_sock; + + SCTP_DBG_OBJCNT_INC(sock); + + sk_sockets_allocated_inc(sk); + sock_prot_inuse_add(net, sk->sk_prot, 1); + + return 0; +} + +/* Cleanup any SCTP per socket resources. Must be called with + * sock_net(sk)->sctp.addr_wq_lock held if sp->do_auto_asconf is true + */ +static void sctp_destroy_sock(struct sock *sk) +{ + struct sctp_sock *sp; + + pr_debug("%s: sk:%p\n", __func__, sk); + + /* Release our hold on the endpoint. */ + sp = sctp_sk(sk); + /* This could happen during socket init, thus we bail out + * early, since the rest of the below is not setup either. + */ + if (sp->ep == NULL) + return; + + if (sp->do_auto_asconf) { + sp->do_auto_asconf = 0; + list_del(&sp->auto_asconf_list); + } + sctp_endpoint_free(sp->ep); + sk_sockets_allocated_dec(sk); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); +} + +/* Triggered when there are no references on the socket anymore */ +static void sctp_destruct_common(struct sock *sk) +{ + struct sctp_sock *sp = sctp_sk(sk); + + /* Free up the HMAC transform. 
*/ + crypto_free_shash(sp->hmac); +} + +static void sctp_destruct_sock(struct sock *sk) +{ + sctp_destruct_common(sk); + inet_sock_destruct(sk); +} + +/* API 4.1.7 shutdown() - TCP Style Syntax + * int shutdown(int socket, int how); + * + * sd - the socket descriptor of the association to be closed. + * how - Specifies the type of shutdown. The values are + * as follows: + * SHUT_RD + * Disables further receive operations. No SCTP + * protocol action is taken. + * SHUT_WR + * Disables further send operations, and initiates + * the SCTP shutdown sequence. + * SHUT_RDWR + * Disables further send and receive operations + * and initiates the SCTP shutdown sequence. + */ +static void sctp_shutdown(struct sock *sk, int how) +{ + struct net *net = sock_net(sk); + struct sctp_endpoint *ep; + + if (!sctp_style(sk, TCP)) + return; + + ep = sctp_sk(sk)->ep; + if (how & SEND_SHUTDOWN && !list_empty(&ep->asocs)) { + struct sctp_association *asoc; + + inet_sk_set_state(sk, SCTP_SS_CLOSING); + asoc = list_entry(ep->asocs.next, + struct sctp_association, asocs); + sctp_primitive_SHUTDOWN(net, asoc, NULL); + } +} + +int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc, + struct sctp_info *info) +{ + struct sctp_transport *prim; + struct list_head *pos; + int mask; + + memset(info, 0, sizeof(*info)); + if (!asoc) { + struct sctp_sock *sp = sctp_sk(sk); + + info->sctpi_s_autoclose = sp->autoclose; + info->sctpi_s_adaptation_ind = sp->adaptation_ind; + info->sctpi_s_pd_point = sp->pd_point; + info->sctpi_s_nodelay = sp->nodelay; + info->sctpi_s_disable_fragments = sp->disable_fragments; + info->sctpi_s_v4mapped = sp->v4mapped; + info->sctpi_s_frag_interleave = sp->frag_interleave; + info->sctpi_s_type = sp->type; + + return 0; + } + + info->sctpi_tag = asoc->c.my_vtag; + info->sctpi_state = asoc->state; + info->sctpi_rwnd = asoc->a_rwnd; + info->sctpi_unackdata = asoc->unack_data; + info->sctpi_penddata = sctp_tsnmap_pending(&asoc->peer.tsn_map); + info->sctpi_instrms = asoc->stream.incnt; + info->sctpi_outstrms = asoc->stream.outcnt; + list_for_each(pos, &asoc->base.inqueue.in_chunk_list) + info->sctpi_inqueue++; + list_for_each(pos, &asoc->outqueue.out_chunk_list) + info->sctpi_outqueue++; + info->sctpi_overall_error = asoc->overall_error_count; + info->sctpi_max_burst = asoc->max_burst; + info->sctpi_maxseg = asoc->frag_point; + info->sctpi_peer_rwnd = asoc->peer.rwnd; + info->sctpi_peer_tag = asoc->c.peer_vtag; + + mask = asoc->peer.ecn_capable << 1; + mask = (mask | asoc->peer.ipv4_address) << 1; + mask = (mask | asoc->peer.ipv6_address) << 1; + mask = (mask | asoc->peer.hostname_address) << 1; + mask = (mask | asoc->peer.asconf_capable) << 1; + mask = (mask | asoc->peer.prsctp_capable) << 1; + mask = (mask | asoc->peer.auth_capable); + info->sctpi_peer_capable = mask; + mask = asoc->peer.sack_needed << 1; + mask = (mask | asoc->peer.sack_generation) << 1; + mask = (mask | asoc->peer.zero_window_announced); + info->sctpi_peer_sack = mask; + + info->sctpi_isacks = asoc->stats.isacks; + info->sctpi_osacks = asoc->stats.osacks; + info->sctpi_opackets = asoc->stats.opackets; + info->sctpi_ipackets = asoc->stats.ipackets; + info->sctpi_rtxchunks = asoc->stats.rtxchunks; + info->sctpi_outofseqtsns = asoc->stats.outofseqtsns; + info->sctpi_idupchunks = asoc->stats.idupchunks; + info->sctpi_gapcnt = asoc->stats.gapcnt; + info->sctpi_ouodchunks = asoc->stats.ouodchunks; + info->sctpi_iuodchunks = asoc->stats.iuodchunks; + info->sctpi_oodchunks = asoc->stats.oodchunks; + info->sctpi_iodchunks = 
asoc->stats.iodchunks; + info->sctpi_octrlchunks = asoc->stats.octrlchunks; + info->sctpi_ictrlchunks = asoc->stats.ictrlchunks; + + prim = asoc->peer.primary_path; + memcpy(&info->sctpi_p_address, &prim->ipaddr, sizeof(prim->ipaddr)); + info->sctpi_p_state = prim->state; + info->sctpi_p_cwnd = prim->cwnd; + info->sctpi_p_srtt = prim->srtt; + info->sctpi_p_rto = jiffies_to_msecs(prim->rto); + info->sctpi_p_hbinterval = prim->hbinterval; + info->sctpi_p_pathmaxrxt = prim->pathmaxrxt; + info->sctpi_p_sackdelay = jiffies_to_msecs(prim->sackdelay); + info->sctpi_p_ssthresh = prim->ssthresh; + info->sctpi_p_partial_bytes_acked = prim->partial_bytes_acked; + info->sctpi_p_flight_size = prim->flight_size; + info->sctpi_p_error = prim->error_count; + + return 0; +} +EXPORT_SYMBOL_GPL(sctp_get_sctp_info); + +/* use callback to avoid exporting the core structure */ +void sctp_transport_walk_start(struct rhashtable_iter *iter) __acquires(RCU) +{ + rhltable_walk_enter(&sctp_transport_hashtable, iter); + + rhashtable_walk_start(iter); +} + +void sctp_transport_walk_stop(struct rhashtable_iter *iter) __releases(RCU) +{ + rhashtable_walk_stop(iter); + rhashtable_walk_exit(iter); +} + +struct sctp_transport *sctp_transport_get_next(struct net *net, + struct rhashtable_iter *iter) +{ + struct sctp_transport *t; + + t = rhashtable_walk_next(iter); + for (; t; t = rhashtable_walk_next(iter)) { + if (IS_ERR(t)) { + if (PTR_ERR(t) == -EAGAIN) + continue; + break; + } + + if (!sctp_transport_hold(t)) + continue; + + if (net_eq(t->asoc->base.net, net) && + t->asoc->peer.primary_path == t) + break; + + sctp_transport_put(t); + } + + return t; +} + +struct sctp_transport *sctp_transport_get_idx(struct net *net, + struct rhashtable_iter *iter, + int pos) +{ + struct sctp_transport *t; + + if (!pos) + return SEQ_START_TOKEN; + + while ((t = sctp_transport_get_next(net, iter)) && !IS_ERR(t)) { + if (!--pos) + break; + sctp_transport_put(t); + } + + return t; +} + +int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), + void *p) { + int err = 0; + int hash = 0; + struct sctp_endpoint *ep; + struct sctp_hashbucket *head; + + for (head = sctp_ep_hashtable; hash < sctp_ep_hashsize; + hash++, head++) { + read_lock_bh(&head->lock); + sctp_for_each_hentry(ep, &head->chain) { + err = cb(ep, p); + if (err) + break; + } + read_unlock_bh(&head->lock); + } + + return err; +} +EXPORT_SYMBOL_GPL(sctp_for_each_endpoint); + +int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net, + const union sctp_addr *laddr, + const union sctp_addr *paddr, void *p) +{ + struct sctp_transport *transport; + struct sctp_endpoint *ep; + int err = -ENOENT; + + rcu_read_lock(); + transport = sctp_addrs_lookup_transport(net, laddr, paddr); + if (!transport) { + rcu_read_unlock(); + return err; + } + ep = transport->asoc->ep; + if (!sctp_endpoint_hold(ep)) { /* asoc can be peeled off */ + sctp_transport_put(transport); + rcu_read_unlock(); + return err; + } + rcu_read_unlock(); + + err = cb(ep, transport, p); + sctp_endpoint_put(ep); + sctp_transport_put(transport); + return err; +} +EXPORT_SYMBOL_GPL(sctp_transport_lookup_process); + +int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done, + struct net *net, int *pos, void *p) +{ + struct rhashtable_iter hti; + struct sctp_transport *tsp; + struct sctp_endpoint *ep; + int ret; + +again: + ret = 0; + sctp_transport_walk_start(&hti); + + tsp = sctp_transport_get_idx(net, &hti, *pos + 1); + for (; !IS_ERR_OR_NULL(tsp); tsp = 
sctp_transport_get_next(net, &hti)) { + ep = tsp->asoc->ep; + if (sctp_endpoint_hold(ep)) { /* asoc can be peeled off */ + ret = cb(ep, tsp, p); + if (ret) + break; + sctp_endpoint_put(ep); + } + (*pos)++; + sctp_transport_put(tsp); + } + sctp_transport_walk_stop(&hti); + + if (ret) { + if (cb_done && !cb_done(ep, tsp, p)) { + (*pos)++; + sctp_endpoint_put(ep); + sctp_transport_put(tsp); + goto again; + } + sctp_endpoint_put(ep); + sctp_transport_put(tsp); + } + + return ret; +} +EXPORT_SYMBOL_GPL(sctp_transport_traverse_process); + +/* 7.2.1 Association Status (SCTP_STATUS) + + * Applications can retrieve current status information about an + * association, including association state, peer receiver window size, + * number of unacked data chunks, and number of data chunks pending + * receipt. This information is read-only. + */ +static int sctp_getsockopt_sctp_status(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_status status; + struct sctp_association *asoc = NULL; + struct sctp_transport *transport; + sctp_assoc_t associd; + int retval = 0; + + if (len < sizeof(status)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(status); + if (copy_from_user(&status, optval, len)) { + retval = -EFAULT; + goto out; + } + + associd = status.sstat_assoc_id; + asoc = sctp_id2assoc(sk, associd); + if (!asoc) { + retval = -EINVAL; + goto out; + } + + transport = asoc->peer.primary_path; + + status.sstat_assoc_id = sctp_assoc2id(asoc); + status.sstat_state = sctp_assoc_to_state(asoc); + status.sstat_rwnd = asoc->peer.rwnd; + status.sstat_unackdata = asoc->unack_data; + + status.sstat_penddata = sctp_tsnmap_pending(&asoc->peer.tsn_map); + status.sstat_instrms = asoc->stream.incnt; + status.sstat_outstrms = asoc->stream.outcnt; + status.sstat_fragmentation_point = asoc->frag_point; + status.sstat_primary.spinfo_assoc_id = sctp_assoc2id(transport->asoc); + memcpy(&status.sstat_primary.spinfo_address, &transport->ipaddr, + transport->af_specific->sockaddr_len); + /* Map ipv4 address into v4-mapped-on-v6 address. */ + sctp_get_pf_specific(sk->sk_family)->addr_to_user(sctp_sk(sk), + (union sctp_addr *)&status.sstat_primary.spinfo_address); + status.sstat_primary.spinfo_state = transport->state; + status.sstat_primary.spinfo_cwnd = transport->cwnd; + status.sstat_primary.spinfo_srtt = transport->srtt; + status.sstat_primary.spinfo_rto = jiffies_to_msecs(transport->rto); + status.sstat_primary.spinfo_mtu = transport->pathmtu; + + if (status.sstat_primary.spinfo_state == SCTP_UNKNOWN) + status.sstat_primary.spinfo_state = SCTP_ACTIVE; + + if (put_user(len, optlen)) { + retval = -EFAULT; + goto out; + } + + pr_debug("%s: len:%d, state:%d, rwnd:%d, assoc_id:%d\n", + __func__, len, status.sstat_state, status.sstat_rwnd, + status.sstat_assoc_id); + + if (copy_to_user(optval, &status, len)) { + retval = -EFAULT; + goto out; + } + +out: + return retval; +} + + +/* 7.2.2 Peer Address Information (SCTP_GET_PEER_ADDR_INFO) + * + * Applications can retrieve information about a specific peer address + * of an association, including its reachability state, congestion + * window, and retransmission timer values. This information is + * read-only. 
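+ *
+ * A minimal sketch of querying one peer address (illustrative; 'peer'
+ * is assumed to be a struct sockaddr_in or sockaddr_in6 already known
+ * to belong to the association identified by 'assoc_id'):
+ *
+ *   struct sctp_paddrinfo pi = { .spinfo_assoc_id = assoc_id };
+ *   socklen_t len = sizeof(pi);
+ *
+ *   memcpy(&pi.spinfo_address, &peer, sizeof(peer));
+ *   getsockopt(sd, IPPROTO_SCTP, SCTP_GET_PEER_ADDR_INFO, &pi, &len);
+ *
+ * On success, fields such as spinfo_state, spinfo_cwnd, spinfo_srtt and
+ * spinfo_rto describe the selected path.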
+ */ +static int sctp_getsockopt_peer_addr_info(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_paddrinfo pinfo; + struct sctp_transport *transport; + int retval = 0; + + if (len < sizeof(pinfo)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(pinfo); + if (copy_from_user(&pinfo, optval, len)) { + retval = -EFAULT; + goto out; + } + + transport = sctp_addr_id2transport(sk, &pinfo.spinfo_address, + pinfo.spinfo_assoc_id); + if (!transport) { + retval = -EINVAL; + goto out; + } + + if (transport->state == SCTP_PF && + transport->asoc->pf_expose == SCTP_PF_EXPOSE_DISABLE) { + retval = -EACCES; + goto out; + } + + pinfo.spinfo_assoc_id = sctp_assoc2id(transport->asoc); + pinfo.spinfo_state = transport->state; + pinfo.spinfo_cwnd = transport->cwnd; + pinfo.spinfo_srtt = transport->srtt; + pinfo.spinfo_rto = jiffies_to_msecs(transport->rto); + pinfo.spinfo_mtu = transport->pathmtu; + + if (pinfo.spinfo_state == SCTP_UNKNOWN) + pinfo.spinfo_state = SCTP_ACTIVE; + + if (put_user(len, optlen)) { + retval = -EFAULT; + goto out; + } + + if (copy_to_user(optval, &pinfo, len)) { + retval = -EFAULT; + goto out; + } + +out: + return retval; +} + +/* 7.1.12 Enable/Disable message fragmentation (SCTP_DISABLE_FRAGMENTS) + * + * This option is a on/off flag. If enabled no SCTP message + * fragmentation will be performed. Instead if a message being sent + * exceeds the current PMTU size, the message will NOT be sent and + * instead a error will be indicated to the user. + */ +static int sctp_getsockopt_disable_fragments(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + int val; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + val = (sctp_sk(sk)->disable_fragments == 1); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +/* 7.1.15 Set notification and ancillary events (SCTP_EVENTS) + * + * This socket option is used to specify various notifications and + * ancillary data the user wishes to receive. + */ +static int sctp_getsockopt_events(struct sock *sk, int len, char __user *optval, + int __user *optlen) +{ + struct sctp_event_subscribe subscribe; + __u8 *sn_type = (__u8 *)&subscribe; + int i; + + if (len == 0) + return -EINVAL; + if (len > sizeof(struct sctp_event_subscribe)) + len = sizeof(struct sctp_event_subscribe); + if (put_user(len, optlen)) + return -EFAULT; + + for (i = 0; i < len; i++) + sn_type[i] = sctp_ulpevent_type_enabled(sctp_sk(sk)->subscribe, + SCTP_SN_TYPE_BASE + i); + + if (copy_to_user(optval, &subscribe, len)) + return -EFAULT; + + return 0; +} + +/* 7.1.8 Automatic Close of associations (SCTP_AUTOCLOSE) + * + * This socket option is applicable to the UDP-style socket only. When + * set it will cause associations that are idle for more than the + * specified number of seconds to automatically close. An association + * being idle is defined an association that has NOT sent or received + * user data. The special value of '0' indicates that no automatic + * close of any associations should be performed. The option expects an + * integer defining the number of seconds of idle time before an + * association is closed. 
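+ *
+ * A minimal sketch (illustrative; 'sd' is assumed to be a one-to-many
+ * style socket):
+ *
+ *   int secs = 30;
+ *
+ *   setsockopt(sd, IPPROTO_SCTP, SCTP_AUTOCLOSE, &secs, sizeof(secs));
+ *
+ * Reading the current value back goes through getsockopt() with the
+ * same option and an int-sized buffer.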
+ */ +static int sctp_getsockopt_autoclose(struct sock *sk, int len, char __user *optval, int __user *optlen) +{ + /* Applicable to UDP-style socket only */ + if (sctp_style(sk, TCP)) + return -EOPNOTSUPP; + if (len < sizeof(int)) + return -EINVAL; + len = sizeof(int); + if (put_user(len, optlen)) + return -EFAULT; + if (put_user(sctp_sk(sk)->autoclose, (int __user *)optval)) + return -EFAULT; + return 0; +} + +/* Helper routine to branch off an association to a new socket. */ +int sctp_do_peeloff(struct sock *sk, sctp_assoc_t id, struct socket **sockp) +{ + struct sctp_association *asoc = sctp_id2assoc(sk, id); + struct sctp_sock *sp = sctp_sk(sk); + struct socket *sock; + int err = 0; + + /* Do not peel off from one netns to another one. */ + if (!net_eq(current->nsproxy->net_ns, sock_net(sk))) + return -EINVAL; + + if (!asoc) + return -EINVAL; + + /* An association cannot be branched off from an already peeled-off + * socket, nor is this supported for tcp style sockets. + */ + if (!sctp_style(sk, UDP)) + return -EINVAL; + + /* Create a new socket. */ + err = sock_create(sk->sk_family, SOCK_SEQPACKET, IPPROTO_SCTP, &sock); + if (err < 0) + return err; + + sctp_copy_sock(sock->sk, sk, asoc); + + /* Make peeled-off sockets more like 1-1 accepted sockets. + * Set the daddr and initialize id to something more random and also + * copy over any ip options. + */ + sp->pf->to_sk_daddr(&asoc->peer.primary_addr, sock->sk); + sp->pf->copy_ip_options(sk, sock->sk); + + /* Populate the fields of the newsk from the oldsk and migrate the + * asoc to the newsk. + */ + err = sctp_sock_migrate(sk, sock->sk, asoc, + SCTP_SOCKET_UDP_HIGH_BANDWIDTH); + if (err) { + sock_release(sock); + sock = NULL; + } + + *sockp = sock; + + return err; +} +EXPORT_SYMBOL(sctp_do_peeloff); + +static int sctp_getsockopt_peeloff_common(struct sock *sk, sctp_peeloff_arg_t *peeloff, + struct file **newfile, unsigned flags) +{ + struct socket *newsock; + int retval; + + retval = sctp_do_peeloff(sk, peeloff->associd, &newsock); + if (retval < 0) + goto out; + + /* Map the socket to an unused fd that can be returned to the user. */ + retval = get_unused_fd_flags(flags & SOCK_CLOEXEC); + if (retval < 0) { + sock_release(newsock); + goto out; + } + + *newfile = sock_alloc_file(newsock, 0, NULL); + if (IS_ERR(*newfile)) { + put_unused_fd(retval); + retval = PTR_ERR(*newfile); + *newfile = NULL; + return retval; + } + + pr_debug("%s: sk:%p, newsk:%p, sd:%d\n", __func__, sk, newsock->sk, + retval); + + peeloff->sd = retval; + + if (flags & SOCK_NONBLOCK) + (*newfile)->f_flags |= O_NONBLOCK; +out: + return retval; +} + +static int sctp_getsockopt_peeloff(struct sock *sk, int len, char __user *optval, int __user *optlen) +{ + sctp_peeloff_arg_t peeloff; + struct file *newfile = NULL; + int retval = 0; + + if (len < sizeof(sctp_peeloff_arg_t)) + return -EINVAL; + len = sizeof(sctp_peeloff_arg_t); + if (copy_from_user(&peeloff, optval, len)) + return -EFAULT; + + retval = sctp_getsockopt_peeloff_common(sk, &peeloff, &newfile, 0); + if (retval < 0) + goto out; + + /* Return the fd mapped to the new socket. 
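+ *
+ * From user space the whole sequence is typically a single getsockopt()
+ * call, or the lksctp sctp_peeloff() wrapper around it.  A hedged sketch
+ * (illustrative; 'assoc_id' is assumed to come from an SCTP_ASSOC_CHANGE
+ * notification):
+ *
+ *   sctp_peeloff_arg_t peel = { .associd = assoc_id };
+ *   socklen_t len = sizeof(peel);
+ *
+ *   if (getsockopt(sd, IPPROTO_SCTP, SCTP_SOCKOPT_PEELOFF,
+ *                  &peel, &len) == 0)
+ *           new_sd = peel.sd;
+ *
+ * where 'new_sd' then behaves like a one-to-one (TCP-style) socket for
+ * that single association.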
*/ + if (put_user(len, optlen)) { + fput(newfile); + put_unused_fd(retval); + return -EFAULT; + } + + if (copy_to_user(optval, &peeloff, len)) { + fput(newfile); + put_unused_fd(retval); + return -EFAULT; + } + fd_install(retval, newfile); +out: + return retval; +} + +static int sctp_getsockopt_peeloff_flags(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + sctp_peeloff_flags_arg_t peeloff; + struct file *newfile = NULL; + int retval = 0; + + if (len < sizeof(sctp_peeloff_flags_arg_t)) + return -EINVAL; + len = sizeof(sctp_peeloff_flags_arg_t); + if (copy_from_user(&peeloff, optval, len)) + return -EFAULT; + + retval = sctp_getsockopt_peeloff_common(sk, &peeloff.p_arg, + &newfile, peeloff.flags); + if (retval < 0) + goto out; + + /* Return the fd mapped to the new socket. */ + if (put_user(len, optlen)) { + fput(newfile); + put_unused_fd(retval); + return -EFAULT; + } + + if (copy_to_user(optval, &peeloff, len)) { + fput(newfile); + put_unused_fd(retval); + return -EFAULT; + } + fd_install(retval, newfile); +out: + return retval; +} + +/* 7.1.13 Peer Address Parameters (SCTP_PEER_ADDR_PARAMS) + * + * Applications can enable or disable heartbeats for any peer address of + * an association, modify an address's heartbeat interval, force a + * heartbeat to be sent immediately, and adjust the address's maximum + * number of retransmissions sent before an address is considered + * unreachable. The following structure is used to access and modify an + * address's parameters: + * + * struct sctp_paddrparams { + * sctp_assoc_t spp_assoc_id; + * struct sockaddr_storage spp_address; + * uint32_t spp_hbinterval; + * uint16_t spp_pathmaxrxt; + * uint32_t spp_pathmtu; + * uint32_t spp_sackdelay; + * uint32_t spp_flags; + * }; + * + * spp_assoc_id - (one-to-many style socket) This is filled in the + * application, and identifies the association for + * this query. + * spp_address - This specifies which address is of interest. + * spp_hbinterval - This contains the value of the heartbeat interval, + * in milliseconds. If a value of zero + * is present in this field then no changes are to + * be made to this parameter. + * spp_pathmaxrxt - This contains the maximum number of + * retransmissions before this address shall be + * considered unreachable. If a value of zero + * is present in this field then no changes are to + * be made to this parameter. + * spp_pathmtu - When Path MTU discovery is disabled the value + * specified here will be the "fixed" path mtu. + * Note that if the spp_address field is empty + * then all associations on this address will + * have this fixed path mtu set upon them. + * + * spp_sackdelay - When delayed sack is enabled, this value specifies + * the number of milliseconds that sacks will be delayed + * for. This value will apply to all addresses of an + * association if the spp_address field is empty. Note + * also, that if delayed sack is enabled and this + * value is set to 0, no change is made to the last + * recorded delayed sack timer value. + * + * spp_flags - These flags are used to control various features + * on an association. The flag field may contain + * zero or more of the following options. + * + * SPP_HB_ENABLE - Enable heartbeats on the + * specified address. Note that if the address + * field is empty all addresses for the association + * have heartbeats enabled upon them. + * + * SPP_HB_DISABLE - Disable heartbeats on the + * speicifed address. 
Note that if the address + * field is empty all addresses for the association + * will have their heartbeats disabled. Note also + * that SPP_HB_ENABLE and SPP_HB_DISABLE are + * mutually exclusive, only one of these two should + * be specified. Enabling both fields will have + * undetermined results. + * + * SPP_HB_DEMAND - Request a user initiated heartbeat + * to be made immediately. + * + * SPP_PMTUD_ENABLE - This field will enable PMTU + * discovery upon the specified address. Note that + * if the address feild is empty then all addresses + * on the association are effected. + * + * SPP_PMTUD_DISABLE - This field will disable PMTU + * discovery upon the specified address. Note that + * if the address feild is empty then all addresses + * on the association are effected. Not also that + * SPP_PMTUD_ENABLE and SPP_PMTUD_DISABLE are mutually + * exclusive. Enabling both will have undetermined + * results. + * + * SPP_SACKDELAY_ENABLE - Setting this flag turns + * on delayed sack. The time specified in spp_sackdelay + * is used to specify the sack delay for this address. Note + * that if spp_address is empty then all addresses will + * enable delayed sack and take on the sack delay + * value specified in spp_sackdelay. + * SPP_SACKDELAY_DISABLE - Setting this flag turns + * off delayed sack. If the spp_address field is blank then + * delayed sack is disabled for the entire association. Note + * also that this field is mutually exclusive to + * SPP_SACKDELAY_ENABLE, setting both will have undefined + * results. + * + * SPP_IPV6_FLOWLABEL: Setting this flag enables the + * setting of the IPV6 flow label value. The value is + * contained in the spp_ipv6_flowlabel field. + * Upon retrieval, this flag will be set to indicate that + * the spp_ipv6_flowlabel field has a valid value returned. + * If a specific destination address is set (in the + * spp_address field), then the value returned is that of + * the address. If just an association is specified (and + * no address), then the association's default flow label + * is returned. If neither an association nor a destination + * is specified, then the socket's default flow label is + * returned. For non-IPv6 sockets, this flag will be left + * cleared. + * + * SPP_DSCP: Setting this flag enables the setting of the + * Differentiated Services Code Point (DSCP) value + * associated with either the association or a specific + * address. The value is obtained in the spp_dscp field. + * Upon retrieval, this flag will be set to indicate that + * the spp_dscp field has a valid value returned. If a + * specific destination address is set when called (in the + * spp_address field), then that specific destination + * address's DSCP value is returned. If just an association + * is specified, then the association's default DSCP is + * returned. If neither an association nor a destination is + * specified, then the socket's default DSCP is returned. + * + * spp_ipv6_flowlabel + * - This field is used in conjunction with the + * SPP_IPV6_FLOWLABEL flag and contains the IPv6 flow label. + * The 20 least significant bits are used for the flow + * label. This setting has precedence over any IPv6-layer + * setting. + * + * spp_dscp - This field is used in conjunction with the SPP_DSCP flag + * and contains the DSCP. The 6 most significant bits are + * used for the DSCP. This setting has precedence over any + * IPv4- or IPv6- layer setting. 
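+ *
+ * A hedged usage sketch (illustrative; leaving spp_address zeroed makes
+ * the call apply to the association or socket defaults rather than to a
+ * single peer address):
+ *
+ *   struct sctp_paddrparams p = { .spp_assoc_id = SCTP_FUTURE_ASSOC };
+ *   socklen_t len = sizeof(p);
+ *
+ *   getsockopt(sd, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &p, &len);
+ *   p.spp_hbinterval = 10000;
+ *   p.spp_flags      = SPP_HB_ENABLE;
+ *   setsockopt(sd, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &p, sizeof(p));
+ *
+ * which reads the current defaults and then re-enables heartbeats with
+ * a 10 second interval.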
+ */ +static int sctp_getsockopt_peer_addr_params(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_paddrparams params; + struct sctp_transport *trans = NULL; + struct sctp_association *asoc = NULL; + struct sctp_sock *sp = sctp_sk(sk); + + if (len >= sizeof(params)) + len = sizeof(params); + else if (len >= ALIGN(offsetof(struct sctp_paddrparams, + spp_ipv6_flowlabel), 4)) + len = ALIGN(offsetof(struct sctp_paddrparams, + spp_ipv6_flowlabel), 4); + else + return -EINVAL; + + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + if (!sctp_is_any(sk, (union sctp_addr *)¶ms.spp_address)) { + trans = sctp_addr_id2transport(sk, ¶ms.spp_address, + params.spp_assoc_id); + if (!trans) { + pr_debug("%s: failed no transport\n", __func__); + return -EINVAL; + } + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, params.spp_assoc_id); + if (!asoc && params.spp_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + pr_debug("%s: failed no association\n", __func__); + return -EINVAL; + } + + if (trans) { + /* Fetch transport values. */ + params.spp_hbinterval = jiffies_to_msecs(trans->hbinterval); + params.spp_pathmtu = trans->pathmtu; + params.spp_pathmaxrxt = trans->pathmaxrxt; + params.spp_sackdelay = jiffies_to_msecs(trans->sackdelay); + + /*draft-11 doesn't say what to return in spp_flags*/ + params.spp_flags = trans->param_flags; + if (trans->flowlabel & SCTP_FLOWLABEL_SET_MASK) { + params.spp_ipv6_flowlabel = trans->flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + params.spp_flags |= SPP_IPV6_FLOWLABEL; + } + if (trans->dscp & SCTP_DSCP_SET_MASK) { + params.spp_dscp = trans->dscp & SCTP_DSCP_VAL_MASK; + params.spp_flags |= SPP_DSCP; + } + } else if (asoc) { + /* Fetch association values. */ + params.spp_hbinterval = jiffies_to_msecs(asoc->hbinterval); + params.spp_pathmtu = asoc->pathmtu; + params.spp_pathmaxrxt = asoc->pathmaxrxt; + params.spp_sackdelay = jiffies_to_msecs(asoc->sackdelay); + + /*draft-11 doesn't say what to return in spp_flags*/ + params.spp_flags = asoc->param_flags; + if (asoc->flowlabel & SCTP_FLOWLABEL_SET_MASK) { + params.spp_ipv6_flowlabel = asoc->flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + params.spp_flags |= SPP_IPV6_FLOWLABEL; + } + if (asoc->dscp & SCTP_DSCP_SET_MASK) { + params.spp_dscp = asoc->dscp & SCTP_DSCP_VAL_MASK; + params.spp_flags |= SPP_DSCP; + } + } else { + /* Fetch socket values. */ + params.spp_hbinterval = sp->hbinterval; + params.spp_pathmtu = sp->pathmtu; + params.spp_sackdelay = sp->sackdelay; + params.spp_pathmaxrxt = sp->pathmaxrxt; + + /*draft-11 doesn't say what to return in spp_flags*/ + params.spp_flags = sp->param_flags; + if (sp->flowlabel & SCTP_FLOWLABEL_SET_MASK) { + params.spp_ipv6_flowlabel = sp->flowlabel & + SCTP_FLOWLABEL_VAL_MASK; + params.spp_flags |= SPP_IPV6_FLOWLABEL; + } + if (sp->dscp & SCTP_DSCP_SET_MASK) { + params.spp_dscp = sp->dscp & SCTP_DSCP_VAL_MASK; + params.spp_flags |= SPP_DSCP; + } + } + + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.23. Get or set delayed ack timer (SCTP_DELAYED_SACK) + * + * This option will effect the way delayed acks are performed. 
This + * option allows you to get or set the delayed ack time, in + * milliseconds. It also allows changing the delayed ack frequency. + * Changing the frequency to 1 disables the delayed sack algorithm. If + * the assoc_id is 0, then this sets or gets the endpoints default + * values. If the assoc_id field is non-zero, then the set or get + * effects the specified association for the one to many model (the + * assoc_id field is ignored by the one to one model). Note that if + * sack_delay or sack_freq are 0 when setting this option, then the + * current values will remain unchanged. + * + * struct sctp_sack_info { + * sctp_assoc_t sack_assoc_id; + * uint32_t sack_delay; + * uint32_t sack_freq; + * }; + * + * sack_assoc_id - This parameter, indicates which association the user + * is performing an action upon. Note that if this field's value is + * zero then the endpoints default value is changed (effecting future + * associations only). + * + * sack_delay - This parameter contains the number of milliseconds that + * the user is requesting the delayed ACK timer be set to. Note that + * this value is defined in the standard to be between 200 and 500 + * milliseconds. + * + * sack_freq - This parameter contains the number of packets that must + * be received before a sack is sent without waiting for the delay + * timer to expire. The default value for this is 2, setting this + * value to 1 will disable the delayed sack algorithm. + */ +static int sctp_getsockopt_delayed_ack(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_sack_info params; + struct sctp_association *asoc = NULL; + struct sctp_sock *sp = sctp_sk(sk); + + if (len >= sizeof(struct sctp_sack_info)) { + len = sizeof(struct sctp_sack_info); + + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + } else if (len == sizeof(struct sctp_assoc_value)) { + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of struct sctp_assoc_value in delayed_ack socket option.\n" + "Use struct sctp_sack_info instead\n", + current->comm, task_pid_nr(current)); + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + } else + return -EINVAL; + + /* Get association, if sack_assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, params.sack_assoc_id); + if (!asoc && params.sack_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + /* Fetch association values. */ + if (asoc->param_flags & SPP_SACKDELAY_ENABLE) { + params.sack_delay = jiffies_to_msecs(asoc->sackdelay); + params.sack_freq = asoc->sackfreq; + + } else { + params.sack_delay = 0; + params.sack_freq = 1; + } + } else { + /* Fetch socket values. */ + if (sp->param_flags & SPP_SACKDELAY_ENABLE) { + params.sack_delay = sp->sackdelay; + params.sack_freq = sp->sackfreq; + } else { + params.sack_delay = 0; + params.sack_freq = 1; + } + } + + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +/* 7.1.3 Initialization Parameters (SCTP_INITMSG) + * + * Applications can specify protocol parameters for the default association + * initialization. The option name argument to setsockopt() and getsockopt() + * is SCTP_INITMSG. + * + * Setting initialization parameters is effective only on an unconnected + * socket (for UDP-style sockets only future associations are effected + * by the change). 
With TCP-style sockets, this option is inherited by + * sockets derived from a listener socket. + */ +static int sctp_getsockopt_initmsg(struct sock *sk, int len, char __user *optval, int __user *optlen) +{ + if (len < sizeof(struct sctp_initmsg)) + return -EINVAL; + len = sizeof(struct sctp_initmsg); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &sctp_sk(sk)->initmsg, len)) + return -EFAULT; + return 0; +} + + +static int sctp_getsockopt_peer_addrs(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_association *asoc; + int cnt = 0; + struct sctp_getaddrs getaddrs; + struct sctp_transport *from; + void __user *to; + union sctp_addr temp; + struct sctp_sock *sp = sctp_sk(sk); + int addrlen; + size_t space_left; + int bytes_copied; + + if (len < sizeof(struct sctp_getaddrs)) + return -EINVAL; + + if (copy_from_user(&getaddrs, optval, sizeof(struct sctp_getaddrs))) + return -EFAULT; + + /* For UDP-style sockets, id specifies the association to query. */ + asoc = sctp_id2assoc(sk, getaddrs.assoc_id); + if (!asoc) + return -EINVAL; + + to = optval + offsetof(struct sctp_getaddrs, addrs); + space_left = len - offsetof(struct sctp_getaddrs, addrs); + + list_for_each_entry(from, &asoc->peer.transport_addr_list, + transports) { + memcpy(&temp, &from->ipaddr, sizeof(temp)); + addrlen = sctp_get_pf_specific(sk->sk_family) + ->addr_to_user(sp, &temp); + if (space_left < addrlen) + return -ENOMEM; + if (copy_to_user(to, &temp, addrlen)) + return -EFAULT; + to += addrlen; + cnt++; + space_left -= addrlen; + } + + if (put_user(cnt, &((struct sctp_getaddrs __user *)optval)->addr_num)) + return -EFAULT; + bytes_copied = ((char __user *)to) - optval; + if (put_user(bytes_copied, optlen)) + return -EFAULT; + + return 0; +} + +static int sctp_copy_laddrs(struct sock *sk, __u16 port, void *to, + size_t space_left, int *bytes_copied) +{ + struct sctp_sockaddr_entry *addr; + union sctp_addr temp; + int cnt = 0; + int addrlen; + struct net *net = sock_net(sk); + + rcu_read_lock(); + list_for_each_entry_rcu(addr, &net->sctp.local_addr_list, list) { + if (!addr->valid) + continue; + + if ((PF_INET == sk->sk_family) && + (AF_INET6 == addr->a.sa.sa_family)) + continue; + if ((PF_INET6 == sk->sk_family) && + inet_v6_ipv6only(sk) && + (AF_INET == addr->a.sa.sa_family)) + continue; + memcpy(&temp, &addr->a, sizeof(temp)); + if (!temp.v4.sin_port) + temp.v4.sin_port = htons(port); + + addrlen = sctp_get_pf_specific(sk->sk_family) + ->addr_to_user(sctp_sk(sk), &temp); + + if (space_left < addrlen) { + cnt = -ENOMEM; + break; + } + memcpy(to, &temp, addrlen); + + to += addrlen; + cnt++; + space_left -= addrlen; + *bytes_copied += addrlen; + } + rcu_read_unlock(); + + return cnt; +} + + +static int sctp_getsockopt_local_addrs(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_bind_addr *bp; + struct sctp_association *asoc; + int cnt = 0; + struct sctp_getaddrs getaddrs; + struct sctp_sockaddr_entry *addr; + void __user *to; + union sctp_addr temp; + struct sctp_sock *sp = sctp_sk(sk); + int addrlen; + int err = 0; + size_t space_left; + int bytes_copied = 0; + void *addrs; + void *buf; + + if (len < sizeof(struct sctp_getaddrs)) + return -EINVAL; + + if (copy_from_user(&getaddrs, optval, sizeof(struct sctp_getaddrs))) + return -EFAULT; + + /* + * For UDP-style sockets, id specifies the association to query. 
+ * If the id field is set to the value '0' then the locally bound + * addresses are returned without regard to any particular + * association. + */ + if (0 == getaddrs.assoc_id) { + bp = &sctp_sk(sk)->ep->base.bind_addr; + } else { + asoc = sctp_id2assoc(sk, getaddrs.assoc_id); + if (!asoc) + return -EINVAL; + bp = &asoc->base.bind_addr; + } + + to = optval + offsetof(struct sctp_getaddrs, addrs); + space_left = len - offsetof(struct sctp_getaddrs, addrs); + + addrs = kmalloc(space_left, GFP_USER | __GFP_NOWARN); + if (!addrs) + return -ENOMEM; + + /* If the endpoint is bound to 0.0.0.0 or ::0, get the valid + * addresses from the global local address list. + */ + if (sctp_list_single_entry(&bp->address_list)) { + addr = list_entry(bp->address_list.next, + struct sctp_sockaddr_entry, list); + if (sctp_is_any(sk, &addr->a)) { + cnt = sctp_copy_laddrs(sk, bp->port, addrs, + space_left, &bytes_copied); + if (cnt < 0) { + err = cnt; + goto out; + } + goto copy_getaddrs; + } + } + + buf = addrs; + /* Protection on the bound address list is not needed since + * in the socket option context we hold a socket lock and + * thus the bound address list can't change. + */ + list_for_each_entry(addr, &bp->address_list, list) { + memcpy(&temp, &addr->a, sizeof(temp)); + addrlen = sctp_get_pf_specific(sk->sk_family) + ->addr_to_user(sp, &temp); + if (space_left < addrlen) { + err = -ENOMEM; /*fixme: right error?*/ + goto out; + } + memcpy(buf, &temp, addrlen); + buf += addrlen; + bytes_copied += addrlen; + cnt++; + space_left -= addrlen; + } + +copy_getaddrs: + if (copy_to_user(to, addrs, bytes_copied)) { + err = -EFAULT; + goto out; + } + if (put_user(cnt, &((struct sctp_getaddrs __user *)optval)->addr_num)) { + err = -EFAULT; + goto out; + } + /* XXX: We should have accounted for sizeof(struct sctp_getaddrs) too, + * but we can't change it anymore. + */ + if (put_user(bytes_copied, optlen)) + err = -EFAULT; +out: + kfree(addrs); + return err; +} + +/* 7.1.10 Set Primary Address (SCTP_PRIMARY_ADDR) + * + * Requests that the local SCTP stack use the enclosed peer address as + * the association primary. The enclosed address must be one of the + * association peer's addresses. + */ +static int sctp_getsockopt_primary_addr(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_prim prim; + struct sctp_association *asoc; + struct sctp_sock *sp = sctp_sk(sk); + + if (len < sizeof(struct sctp_prim)) + return -EINVAL; + + len = sizeof(struct sctp_prim); + + if (copy_from_user(&prim, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, prim.ssp_assoc_id); + if (!asoc) + return -EINVAL; + + if (!asoc->peer.primary_path) + return -ENOTCONN; + + memcpy(&prim.ssp_addr, &asoc->peer.primary_path->ipaddr, + asoc->peer.primary_path->af_specific->sockaddr_len); + + sctp_get_pf_specific(sk->sk_family)->addr_to_user(sp, + (union sctp_addr *)&prim.ssp_addr); + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &prim, len)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.11 Set Adaptation Layer Indicator (SCTP_ADAPTATION_LAYER) + * + * Requests that the local endpoint set the specified Adaptation Layer + * Indication parameter for all future INIT and INIT-ACK exchanges. 
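+ *
+ * Illustrative usage sketch only (assumes the userspace <netinet/sctp.h>
+ * and an SCTP socket "fd"):
+ *
+ *     struct sctp_setadaptation ind;
+ *     socklen_t len = sizeof(ind);
+ *
+ *     if (getsockopt(fd, IPPROTO_SCTP, SCTP_ADAPTATION_LAYER, &ind, &len) == 0)
+ *         printf("adaptation indication: 0x%x\n", ind.ssb_adaptation_ind);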
+ */ +static int sctp_getsockopt_adaptation_layer(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_setadaptation adaptation; + + if (len < sizeof(struct sctp_setadaptation)) + return -EINVAL; + + len = sizeof(struct sctp_setadaptation); + + adaptation.ssb_adaptation_ind = sctp_sk(sk)->adaptation_ind; + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &adaptation, len)) + return -EFAULT; + + return 0; +} + +/* + * + * 7.1.14 Set default send parameters (SCTP_DEFAULT_SEND_PARAM) + * + * Applications that wish to use the sendto() system call may wish to + * specify a default set of parameters that would normally be supplied + * through the inclusion of ancillary data. This socket option allows + * such an application to set the default sctp_sndrcvinfo structure. + + + * The application that wishes to use this socket option simply passes + * in to this call the sctp_sndrcvinfo structure defined in Section + * 5.2.2) The input parameters accepted by this call include + * sinfo_stream, sinfo_flags, sinfo_ppid, sinfo_context, + * sinfo_timetolive. The user must provide the sinfo_assoc_id field in + * to this call if the caller is using the UDP model. + * + * For getsockopt, it get the default sctp_sndrcvinfo structure. + */ +static int sctp_getsockopt_default_send_param(struct sock *sk, + int len, char __user *optval, + int __user *optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + struct sctp_sndrcvinfo info; + + if (len < sizeof(info)) + return -EINVAL; + + len = sizeof(info); + + if (copy_from_user(&info, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, info.sinfo_assoc_id); + if (!asoc && info.sinfo_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + info.sinfo_stream = asoc->default_stream; + info.sinfo_flags = asoc->default_flags; + info.sinfo_ppid = asoc->default_ppid; + info.sinfo_context = asoc->default_context; + info.sinfo_timetolive = asoc->default_timetolive; + } else { + info.sinfo_stream = sp->default_stream; + info.sinfo_flags = sp->default_flags; + info.sinfo_ppid = sp->default_ppid; + info.sinfo_context = sp->default_context; + info.sinfo_timetolive = sp->default_timetolive; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &info, len)) + return -EFAULT; + + return 0; +} + +/* RFC6458, Section 8.1.31. Set/get Default Send Parameters + * (SCTP_DEFAULT_SNDINFO) + */ +static int sctp_getsockopt_default_sndinfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + struct sctp_sndinfo info; + + if (len < sizeof(info)) + return -EINVAL; + + len = sizeof(info); + + if (copy_from_user(&info, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, info.snd_assoc_id); + if (!asoc && info.snd_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + info.snd_sid = asoc->default_stream; + info.snd_flags = asoc->default_flags; + info.snd_ppid = asoc->default_ppid; + info.snd_context = asoc->default_context; + } else { + info.snd_sid = sp->default_stream; + info.snd_flags = sp->default_flags; + info.snd_ppid = sp->default_ppid; + info.snd_context = sp->default_context; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &info, len)) + return -EFAULT; + + return 0; +} + +/* + * + * 7.1.5 SCTP_NODELAY + * + * Turn on/off any Nagle-like algorithm. 
This means that packets are + * generally sent as soon as possible and no unnecessary delays are + * introduced, at the cost of more packets in the network. Expects an + * integer boolean flag. + */ + +static int sctp_getsockopt_nodelay(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + int val; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + val = (sctp_sk(sk)->nodelay == 1); + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +/* + * + * 7.1.1 SCTP_RTOINFO + * + * The protocol parameters used to initialize and bound retransmission + * timeout (RTO) are tunable. sctp_rtoinfo structure is used to access + * and modify these parameters. + * All parameters are time values, in milliseconds. A value of 0, when + * modifying the parameters, indicates that the current value should not + * be changed. + * + */ +static int sctp_getsockopt_rtoinfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) { + struct sctp_rtoinfo rtoinfo; + struct sctp_association *asoc; + + if (len < sizeof (struct sctp_rtoinfo)) + return -EINVAL; + + len = sizeof(struct sctp_rtoinfo); + + if (copy_from_user(&rtoinfo, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, rtoinfo.srto_assoc_id); + + if (!asoc && rtoinfo.srto_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* Values corresponding to the specific association. */ + if (asoc) { + rtoinfo.srto_initial = jiffies_to_msecs(asoc->rto_initial); + rtoinfo.srto_max = jiffies_to_msecs(asoc->rto_max); + rtoinfo.srto_min = jiffies_to_msecs(asoc->rto_min); + } else { + /* Values corresponding to the endpoint. */ + struct sctp_sock *sp = sctp_sk(sk); + + rtoinfo.srto_initial = sp->rtoinfo.srto_initial; + rtoinfo.srto_max = sp->rtoinfo.srto_max; + rtoinfo.srto_min = sp->rtoinfo.srto_min; + } + + if (put_user(len, optlen)) + return -EFAULT; + + if (copy_to_user(optval, &rtoinfo, len)) + return -EFAULT; + + return 0; +} + +/* + * + * 7.1.2 SCTP_ASSOCINFO + * + * This option is used to tune the maximum retransmission attempts + * of the association. + * Returns an error if the new association retransmission value is + * greater than the sum of the retransmission value of the peer. + * See [SCTP] for more information. 
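+ *
+ * Illustrative usage sketch only (assumes the userspace <netinet/sctp.h>
+ * and an existing association id "assoc_id" on socket "fd"):
+ *
+ *     struct sctp_assocparams ap = { .sasoc_assoc_id = assoc_id };
+ *     socklen_t len = sizeof(ap);
+ *
+ *     if (getsockopt(fd, IPPROTO_SCTP, SCTP_ASSOCINFO, &ap, &len) == 0)
+ *         printf("asocmaxrxt %u, %u peer destinations, cookie life %u ms\n",
+ *                ap.sasoc_asocmaxrxt, ap.sasoc_number_peer_destinations,
+ *                ap.sasoc_cookie_life);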
+ * + */ +static int sctp_getsockopt_associnfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + + struct sctp_assocparams assocparams; + struct sctp_association *asoc; + struct list_head *pos; + int cnt = 0; + + if (len < sizeof (struct sctp_assocparams)) + return -EINVAL; + + len = sizeof(struct sctp_assocparams); + + if (copy_from_user(&assocparams, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, assocparams.sasoc_assoc_id); + + if (!asoc && assocparams.sasoc_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + /* Values correspoinding to the specific association */ + if (asoc) { + assocparams.sasoc_asocmaxrxt = asoc->max_retrans; + assocparams.sasoc_peer_rwnd = asoc->peer.rwnd; + assocparams.sasoc_local_rwnd = asoc->a_rwnd; + assocparams.sasoc_cookie_life = ktime_to_ms(asoc->cookie_life); + + list_for_each(pos, &asoc->peer.transport_addr_list) { + cnt++; + } + + assocparams.sasoc_number_peer_destinations = cnt; + } else { + /* Values corresponding to the endpoint */ + struct sctp_sock *sp = sctp_sk(sk); + + assocparams.sasoc_asocmaxrxt = sp->assocparams.sasoc_asocmaxrxt; + assocparams.sasoc_peer_rwnd = sp->assocparams.sasoc_peer_rwnd; + assocparams.sasoc_local_rwnd = sp->assocparams.sasoc_local_rwnd; + assocparams.sasoc_cookie_life = + sp->assocparams.sasoc_cookie_life; + assocparams.sasoc_number_peer_destinations = + sp->assocparams. + sasoc_number_peer_destinations; + } + + if (put_user(len, optlen)) + return -EFAULT; + + if (copy_to_user(optval, &assocparams, len)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.16 Set/clear IPv4 mapped addresses (SCTP_I_WANT_MAPPED_V4_ADDR) + * + * This socket option is a boolean flag which turns on or off mapped V4 + * addresses. If this option is turned on and the socket is type + * PF_INET6, then IPv4 addresses will be mapped to V6 representation. + * If this option is turned off, then no mapping will be done of V4 + * addresses and a user will receive both PF_INET6 and PF_INET type + * addresses on the socket. + */ +static int sctp_getsockopt_mappedv4(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + int val; + struct sctp_sock *sp = sctp_sk(sk); + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + val = sp->v4mapped; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.29. Set or Get the default context (SCTP_CONTEXT) + * (chapter and verse is quoted at sctp_setsockopt_context()) + */ +static int sctp_getsockopt_context(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + + if (len < sizeof(struct sctp_assoc_value)) + return -EINVAL; + + len = sizeof(struct sctp_assoc_value); + + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + params.assoc_value = asoc ? asoc->default_rcv_context + : sctp_sk(sk)->default_rcv_context; + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + + return 0; +} + +/* + * 8.1.16. Get or Set the Maximum Fragmentation Size (SCTP_MAXSEG) + * This option will get or set the maximum size to put in any outgoing + * SCTP DATA chunk. If a message is larger than this size it will be + * fragmented by SCTP into the specified size. 
Note that the underlying + * SCTP implementation may fragment into smaller sized chunks when the + * PMTU of the underlying association is smaller than the value set by + * the user. The default value for this option is '0' which indicates + * the user is NOT limiting fragmentation and only the PMTU will effect + * SCTP's choice of DATA chunk size. Note also that values set larger + * than the maximum size of an IP datagram will effectively let SCTP + * control fragmentation (i.e. the same as setting this option to 0). + * + * The following structure is used to access and modify this parameter: + * + * struct sctp_assoc_value { + * sctp_assoc_t assoc_id; + * uint32_t assoc_value; + * }; + * + * assoc_id: This parameter is ignored for one-to-one style sockets. + * For one-to-many style sockets this parameter indicates which + * association the user is performing an action upon. Note that if + * this field's value is zero then the endpoints default value is + * changed (effecting future associations only). + * assoc_value: This parameter specifies the maximum size in bytes. + */ +static int sctp_getsockopt_maxseg(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + + if (len == sizeof(int)) { + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of int in maxseg socket option.\n" + "Use struct sctp_assoc_value instead\n", + current->comm, task_pid_nr(current)); + params.assoc_id = SCTP_FUTURE_ASSOC; + } else if (len >= sizeof(struct sctp_assoc_value)) { + len = sizeof(struct sctp_assoc_value); + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + } else + return -EINVAL; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) + params.assoc_value = asoc->frag_point; + else + params.assoc_value = sctp_sk(sk)->user_frag; + + if (put_user(len, optlen)) + return -EFAULT; + if (len == sizeof(int)) { + if (copy_to_user(optval, ¶ms.assoc_value, len)) + return -EFAULT; + } else { + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + } + + return 0; +} + +/* + * 7.1.24. Get or set fragmented interleave (SCTP_FRAGMENT_INTERLEAVE) + * (chapter and verse is quoted at sctp_setsockopt_fragment_interleave()) + */ +static int sctp_getsockopt_fragment_interleave(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + int val; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + + val = sctp_sk(sk)->frag_interleave; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.25. Set or Get the sctp partial delivery point + * (chapter and verse is quoted at sctp_setsockopt_partial_delivery_point()) + */ +static int sctp_getsockopt_partial_delivery_point(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + u32 val; + + if (len < sizeof(u32)) + return -EINVAL; + + len = sizeof(u32); + + val = sctp_sk(sk)->pd_point; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +/* + * 7.1.28. 
Set or Get the maximum burst (SCTP_MAX_BURST) + * (chapter and verse is quoted at sctp_setsockopt_maxburst()) + */ +static int sctp_getsockopt_maxburst(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + + if (len == sizeof(int)) { + pr_warn_ratelimited(DEPRECATED + "%s (pid %d) " + "Use of int in max_burst socket option.\n" + "Use struct sctp_assoc_value instead\n", + current->comm, task_pid_nr(current)); + params.assoc_id = SCTP_FUTURE_ASSOC; + } else if (len >= sizeof(struct sctp_assoc_value)) { + len = sizeof(struct sctp_assoc_value); + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + } else + return -EINVAL; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + params.assoc_value = asoc ? asoc->max_burst : sctp_sk(sk)->max_burst; + + if (len == sizeof(int)) { + if (copy_to_user(optval, ¶ms.assoc_value, len)) + return -EFAULT; + } else { + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + } + + return 0; + +} + +static int sctp_getsockopt_hmac_ident(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_hmacalgo __user *p = (void __user *)optval; + struct sctp_hmac_algo_param *hmacs; + __u16 data_len = 0; + u32 num_idents; + int i; + + if (!ep->auth_enable) + return -EACCES; + + hmacs = ep->auth_hmacs_list; + data_len = ntohs(hmacs->param_hdr.length) - + sizeof(struct sctp_paramhdr); + + if (len < sizeof(struct sctp_hmacalgo) + data_len) + return -EINVAL; + + len = sizeof(struct sctp_hmacalgo) + data_len; + num_idents = data_len / sizeof(u16); + + if (put_user(len, optlen)) + return -EFAULT; + if (put_user(num_idents, &p->shmac_num_idents)) + return -EFAULT; + for (i = 0; i < num_idents; i++) { + __u16 hmacid = ntohs(hmacs->hmac_ids[i]); + + if (copy_to_user(&p->shmac_idents[i], &hmacid, sizeof(__u16))) + return -EFAULT; + } + return 0; +} + +static int sctp_getsockopt_active_key(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_authkeyid val; + struct sctp_association *asoc; + + if (len < sizeof(struct sctp_authkeyid)) + return -EINVAL; + + len = sizeof(struct sctp_authkeyid); + if (copy_from_user(&val, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, val.scact_assoc_id); + if (!asoc && val.scact_assoc_id && sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + if (!asoc->peer.auth_capable) + return -EACCES; + val.scact_keynumber = asoc->active_key_id; + } else { + if (!ep->auth_enable) + return -EACCES; + val.scact_keynumber = ep->active_key_id; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_peer_auth_chunks(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_authchunks __user *p = (void __user *)optval; + struct sctp_authchunks val; + struct sctp_association *asoc; + struct sctp_chunks_param *ch; + u32 num_chunks = 0; + char __user *to; + + if (len < sizeof(struct sctp_authchunks)) + return -EINVAL; + + if (copy_from_user(&val, optval, sizeof(val))) + return -EFAULT; + + to = p->gauth_chunks; + asoc = sctp_id2assoc(sk, val.gauth_assoc_id); + if (!asoc) + return -EINVAL; + + if (!asoc->peer.auth_capable) + return -EACCES; + + ch = asoc->peer.peer_chunks; + if (!ch) + goto num; 
+ + /* See if the user provided enough room for all the data */ + num_chunks = ntohs(ch->param_hdr.length) - sizeof(struct sctp_paramhdr); + if (len < num_chunks) + return -EINVAL; + + if (copy_to_user(to, ch->chunks, num_chunks)) + return -EFAULT; +num: + len = sizeof(struct sctp_authchunks) + num_chunks; + if (put_user(len, optlen)) + return -EFAULT; + if (put_user(num_chunks, &p->gauth_number_of_chunks)) + return -EFAULT; + return 0; +} + +static int sctp_getsockopt_local_auth_chunks(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + struct sctp_authchunks __user *p = (void __user *)optval; + struct sctp_authchunks val; + struct sctp_association *asoc; + struct sctp_chunks_param *ch; + u32 num_chunks = 0; + char __user *to; + + if (len < sizeof(struct sctp_authchunks)) + return -EINVAL; + + if (copy_from_user(&val, optval, sizeof(val))) + return -EFAULT; + + to = p->gauth_chunks; + asoc = sctp_id2assoc(sk, val.gauth_assoc_id); + if (!asoc && val.gauth_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + if (!asoc->peer.auth_capable) + return -EACCES; + ch = (struct sctp_chunks_param *)asoc->c.auth_chunks; + } else { + if (!ep->auth_enable) + return -EACCES; + ch = ep->auth_chunk_list; + } + if (!ch) + goto num; + + num_chunks = ntohs(ch->param_hdr.length) - sizeof(struct sctp_paramhdr); + if (len < sizeof(struct sctp_authchunks) + num_chunks) + return -EINVAL; + + if (copy_to_user(to, ch->chunks, num_chunks)) + return -EFAULT; +num: + len = sizeof(struct sctp_authchunks) + num_chunks; + if (put_user(len, optlen)) + return -EFAULT; + if (put_user(num_chunks, &p->gauth_number_of_chunks)) + return -EFAULT; + + return 0; +} + +/* + * 8.2.5. Get the Current Number of Associations (SCTP_GET_ASSOC_NUMBER) + * This option gets the current number of associations that are attached + * to a one-to-many style socket. The option value is an uint32_t. + */ +static int sctp_getsockopt_assoc_number(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + u32 val = 0; + + if (sctp_style(sk, TCP)) + return -EOPNOTSUPP; + + if (len < sizeof(u32)) + return -EINVAL; + + len = sizeof(u32); + + list_for_each_entry(asoc, &(sp->ep->asocs), asocs) { + val++; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +/* + * 8.1.23 SCTP_AUTO_ASCONF + * See the corresponding setsockopt entry as description + */ +static int sctp_getsockopt_auto_asconf(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + int val = 0; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + if (sctp_sk(sk)->do_auto_asconf && sctp_is_ep_boundall(sk)) + val = 1; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + return 0; +} + +/* + * 8.2.6. Get the Current Identifiers of Associations + * (SCTP_GET_ASSOC_ID_LIST) + * + * This option gets the current list of SCTP association identifiers of + * the SCTP associations handled by a one-to-many style socket. 
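+ *
+ * Illustrative usage sketch only: the supplied buffer must be large enough
+ * for the trailing id array, so a caller typically queries
+ * SCTP_GET_ASSOC_NUMBER first and sizes the buffer from that count (the
+ * count can still race with association setup/teardown, in which case this
+ * call fails with EINVAL and the query has to be retried). Assumes the
+ * userspace <netinet/sctp.h> and a one-to-many socket "fd":
+ *
+ *     uint32_t n = 0;
+ *     socklen_t len = sizeof(n);
+ *
+ *     if (getsockopt(fd, IPPROTO_SCTP, SCTP_GET_ASSOC_NUMBER, &n, &len) == 0) {
+ *         len = sizeof(struct sctp_assoc_ids) + n * sizeof(sctp_assoc_t);
+ *         struct sctp_assoc_ids *ids = malloc(len);
+ *
+ *         if (ids &&
+ *             getsockopt(fd, IPPROTO_SCTP, SCTP_GET_ASSOC_ID_LIST, ids, &len) == 0)
+ *             for (uint32_t i = 0; i < ids->gaids_number_of_ids; i++)
+ *                 printf("assoc id %d\n", ids->gaids_assoc_id[i]);
+ *         free(ids);
+ *     }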
+ */ +static int sctp_getsockopt_assoc_ids(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_association *asoc; + struct sctp_assoc_ids *ids; + u32 num = 0; + + if (sctp_style(sk, TCP)) + return -EOPNOTSUPP; + + if (len < sizeof(struct sctp_assoc_ids)) + return -EINVAL; + + list_for_each_entry(asoc, &(sp->ep->asocs), asocs) { + num++; + } + + if (len < sizeof(struct sctp_assoc_ids) + sizeof(sctp_assoc_t) * num) + return -EINVAL; + + len = sizeof(struct sctp_assoc_ids) + sizeof(sctp_assoc_t) * num; + + ids = kmalloc(len, GFP_USER | __GFP_NOWARN); + if (unlikely(!ids)) + return -ENOMEM; + + ids->gaids_number_of_ids = num; + num = 0; + list_for_each_entry(asoc, &(sp->ep->asocs), asocs) { + ids->gaids_assoc_id[num++] = asoc->assoc_id; + } + + if (put_user(len, optlen) || copy_to_user(optval, ids, len)) { + kfree(ids); + return -EFAULT; + } + + kfree(ids); + return 0; +} + +/* + * SCTP_PEER_ADDR_THLDS + * + * This option allows us to fetch the partially failed threshold for one or all + * transports in an association. See Section 6.1 of: + * http://www.ietf.org/id/draft-nishida-tsvwg-sctp-failover-05.txt + */ +static int sctp_getsockopt_paddr_thresholds(struct sock *sk, + char __user *optval, int len, + int __user *optlen, bool v2) +{ + struct sctp_paddrthlds_v2 val; + struct sctp_transport *trans; + struct sctp_association *asoc; + int min; + + min = v2 ? sizeof(val) : sizeof(struct sctp_paddrthlds); + if (len < min) + return -EINVAL; + len = min; + if (copy_from_user(&val, optval, len)) + return -EFAULT; + + if (!sctp_is_any(sk, (const union sctp_addr *)&val.spt_address)) { + trans = sctp_addr_id2transport(sk, &val.spt_address, + val.spt_assoc_id); + if (!trans) + return -ENOENT; + + val.spt_pathmaxrxt = trans->pathmaxrxt; + val.spt_pathpfthld = trans->pf_retrans; + val.spt_pathcpthld = trans->ps_retrans; + + goto out; + } + + asoc = sctp_id2assoc(sk, val.spt_assoc_id); + if (!asoc && val.spt_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + if (asoc) { + val.spt_pathpfthld = asoc->pf_retrans; + val.spt_pathmaxrxt = asoc->pathmaxrxt; + val.spt_pathcpthld = asoc->ps_retrans; + } else { + struct sctp_sock *sp = sctp_sk(sk); + + val.spt_pathpfthld = sp->pf_retrans; + val.spt_pathmaxrxt = sp->pathmaxrxt; + val.spt_pathcpthld = sp->ps_retrans; + } + +out: + if (put_user(len, optlen) || copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +/* + * SCTP_GET_ASSOC_STATS + * + * This option retrieves local per endpoint statistics. 
It is modeled + * after OpenSolaris' implementation + */ +static int sctp_getsockopt_assoc_stats(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_stats sas; + struct sctp_association *asoc = NULL; + + /* User must provide at least the assoc id */ + if (len < sizeof(sctp_assoc_t)) + return -EINVAL; + + /* Allow the struct to grow and fill in as much as possible */ + len = min_t(size_t, len, sizeof(sas)); + + if (copy_from_user(&sas, optval, len)) + return -EFAULT; + + asoc = sctp_id2assoc(sk, sas.sas_assoc_id); + if (!asoc) + return -EINVAL; + + sas.sas_rtxchunks = asoc->stats.rtxchunks; + sas.sas_gapcnt = asoc->stats.gapcnt; + sas.sas_outofseqtsns = asoc->stats.outofseqtsns; + sas.sas_osacks = asoc->stats.osacks; + sas.sas_isacks = asoc->stats.isacks; + sas.sas_octrlchunks = asoc->stats.octrlchunks; + sas.sas_ictrlchunks = asoc->stats.ictrlchunks; + sas.sas_oodchunks = asoc->stats.oodchunks; + sas.sas_iodchunks = asoc->stats.iodchunks; + sas.sas_ouodchunks = asoc->stats.ouodchunks; + sas.sas_iuodchunks = asoc->stats.iuodchunks; + sas.sas_idupchunks = asoc->stats.idupchunks; + sas.sas_opackets = asoc->stats.opackets; + sas.sas_ipackets = asoc->stats.ipackets; + + /* New high max rto observed, will return 0 if not a single + * RTO update took place. obs_rto_ipaddr will be bogus + * in such a case + */ + sas.sas_maxrto = asoc->stats.max_obs_rto; + memcpy(&sas.sas_obs_rto_ipaddr, &asoc->stats.obs_rto_ipaddr, + sizeof(struct sockaddr_storage)); + + /* Mark beginning of a new observation period */ + asoc->stats.max_obs_rto = asoc->rto_min; + + if (put_user(len, optlen)) + return -EFAULT; + + pr_debug("%s: len:%d, assoc_id:%d\n", __func__, len, sas.sas_assoc_id); + + if (copy_to_user(optval, &sas, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_recvrcvinfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + int val = 0; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + if (sctp_sk(sk)->recvrcvinfo) + val = 1; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_recvnxtinfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + int val = 0; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + if (sctp_sk(sk)->recvnxtinfo) + val = 1; + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_pr_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? 
asoc->peer.prsctp_capable + : sctp_sk(sk)->ep->prsctp_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_default_prinfo(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_default_prinfo info; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(info)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(info); + if (copy_from_user(&info, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, info.pr_assoc_id); + if (!asoc && info.pr_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + if (asoc) { + info.pr_policy = SCTP_PR_POLICY(asoc->default_flags); + info.pr_value = asoc->default_timetolive; + } else { + struct sctp_sock *sp = sctp_sk(sk); + + info.pr_policy = SCTP_PR_POLICY(sp->default_flags); + info.pr_value = sp->default_timetolive; + } + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, &info, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_pr_assocstatus(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_prstatus params; + struct sctp_association *asoc; + int policy; + int retval = -EINVAL; + + if (len < sizeof(params)) + goto out; + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) { + retval = -EFAULT; + goto out; + } + + policy = params.sprstat_policy; + if (!policy || (policy & ~(SCTP_PR_SCTP_MASK | SCTP_PR_SCTP_ALL)) || + ((policy & SCTP_PR_SCTP_ALL) && (policy & SCTP_PR_SCTP_MASK))) + goto out; + + asoc = sctp_id2assoc(sk, params.sprstat_assoc_id); + if (!asoc) + goto out; + + if (policy == SCTP_PR_SCTP_ALL) { + params.sprstat_abandoned_unsent = 0; + params.sprstat_abandoned_sent = 0; + for (policy = 0; policy <= SCTP_PR_INDEX(MAX); policy++) { + params.sprstat_abandoned_unsent += + asoc->abandoned_unsent[policy]; + params.sprstat_abandoned_sent += + asoc->abandoned_sent[policy]; + } + } else { + params.sprstat_abandoned_unsent = + asoc->abandoned_unsent[__SCTP_PR_INDEX(policy)]; + params.sprstat_abandoned_sent = + asoc->abandoned_sent[__SCTP_PR_INDEX(policy)]; + } + + if (put_user(len, optlen)) { + retval = -EFAULT; + goto out; + } + + if (copy_to_user(optval, ¶ms, len)) { + retval = -EFAULT; + goto out; + } + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_pr_streamstatus(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_stream_out_ext *streamoute; + struct sctp_association *asoc; + struct sctp_prstatus params; + int retval = -EINVAL; + int policy; + + if (len < sizeof(params)) + goto out; + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) { + retval = -EFAULT; + goto out; + } + + policy = params.sprstat_policy; + if (!policy || (policy & ~(SCTP_PR_SCTP_MASK | SCTP_PR_SCTP_ALL)) || + ((policy & SCTP_PR_SCTP_ALL) && (policy & SCTP_PR_SCTP_MASK))) + goto out; + + asoc = sctp_id2assoc(sk, params.sprstat_assoc_id); + if (!asoc || params.sprstat_sid >= asoc->stream.outcnt) + goto out; + + streamoute = SCTP_SO(&asoc->stream, params.sprstat_sid)->ext; + if (!streamoute) { + /* Not allocated yet, means all stats are 0 */ + params.sprstat_abandoned_unsent = 0; + params.sprstat_abandoned_sent = 0; + retval = 0; + goto out; + } + + if (policy == SCTP_PR_SCTP_ALL) { + params.sprstat_abandoned_unsent = 0; + params.sprstat_abandoned_sent = 0; + for (policy = 0; policy 
<= SCTP_PR_INDEX(MAX); policy++) { + params.sprstat_abandoned_unsent += + streamoute->abandoned_unsent[policy]; + params.sprstat_abandoned_sent += + streamoute->abandoned_sent[policy]; + } + } else { + params.sprstat_abandoned_unsent = + streamoute->abandoned_unsent[__SCTP_PR_INDEX(policy)]; + params.sprstat_abandoned_sent = + streamoute->abandoned_sent[__SCTP_PR_INDEX(policy)]; + } + + if (put_user(len, optlen) || copy_to_user(optval, ¶ms, len)) { + retval = -EFAULT; + goto out; + } + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_reconfig_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->peer.reconf_capable + : sctp_sk(sk)->ep->reconf_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_enable_strreset(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->strreset_enable + : sctp_sk(sk)->ep->strreset_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_scheduler(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? 
sctp_sched_get_sched(asoc) + : sctp_sk(sk)->default_ss; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_scheduler_value(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_stream_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc) { + retval = -EINVAL; + goto out; + } + + retval = sctp_sched_get_value(asoc, params.stream_id, + ¶ms.stream_value); + if (retval) + goto out; + + if (put_user(len, optlen)) { + retval = -EFAULT; + goto out; + } + + if (copy_to_user(optval, ¶ms, len)) { + retval = -EFAULT; + goto out; + } + +out: + return retval; +} + +static int sctp_getsockopt_interleaving_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->peer.intl_capable + : sctp_sk(sk)->ep->intl_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_reuse_port(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + int val; + + if (len < sizeof(int)) + return -EINVAL; + + len = sizeof(int); + val = sctp_sk(sk)->reuse; + if (put_user(len, optlen)) + return -EFAULT; + + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_event(struct sock *sk, int len, char __user *optval, + int __user *optlen) +{ + struct sctp_association *asoc; + struct sctp_event param; + __u16 subscribe; + + if (len < sizeof(param)) + return -EINVAL; + + len = sizeof(param); + if (copy_from_user(¶m, optval, len)) + return -EFAULT; + + if (param.se_type < SCTP_SN_TYPE_BASE || + param.se_type > SCTP_SN_TYPE_MAX) + return -EINVAL; + + asoc = sctp_id2assoc(sk, param.se_assoc_id); + if (!asoc && param.se_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) + return -EINVAL; + + subscribe = asoc ? asoc->subscribe : sctp_sk(sk)->subscribe; + param.se_on = sctp_ulpevent_type_enabled(subscribe, param.se_type); + + if (put_user(len, optlen)) + return -EFAULT; + + if (copy_to_user(optval, ¶m, len)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_asconf_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? 
asoc->peer.asconf_capable + : sctp_sk(sk)->ep->asconf_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_auth_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->peer.auth_capable + : sctp_sk(sk)->ep->auth_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_ecn_supported(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->peer.ecn_capable + : sctp_sk(sk)->ep->ecn_enable; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_pf_expose(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_assoc_value params; + struct sctp_association *asoc; + int retval = -EFAULT; + + if (len < sizeof(params)) { + retval = -EINVAL; + goto out; + } + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + goto out; + + asoc = sctp_id2assoc(sk, params.assoc_id); + if (!asoc && params.assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + retval = -EINVAL; + goto out; + } + + params.assoc_value = asoc ? asoc->pf_expose + : sctp_sk(sk)->pf_expose; + + if (put_user(len, optlen)) + goto out; + + if (copy_to_user(optval, ¶ms, len)) + goto out; + + retval = 0; + +out: + return retval; +} + +static int sctp_getsockopt_encap_port(struct sock *sk, int len, + char __user *optval, int __user *optlen) +{ + struct sctp_association *asoc; + struct sctp_udpencaps encap; + struct sctp_transport *t; + __be16 encap_port; + + if (len < sizeof(encap)) + return -EINVAL; + + len = sizeof(encap); + if (copy_from_user(&encap, optval, len)) + return -EFAULT; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + if (!sctp_is_any(sk, (union sctp_addr *)&encap.sue_address)) { + t = sctp_addr_id2transport(sk, &encap.sue_address, + encap.sue_assoc_id); + if (!t) { + pr_debug("%s: failed no transport\n", __func__); + return -EINVAL; + } + + encap_port = t->encap_port; + goto out; + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. 
+ */ + asoc = sctp_id2assoc(sk, encap.sue_assoc_id); + if (!asoc && encap.sue_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + pr_debug("%s: failed no association\n", __func__); + return -EINVAL; + } + + if (asoc) { + encap_port = asoc->encap_port; + goto out; + } + + encap_port = sctp_sk(sk)->encap_port; + +out: + encap.sue_port = (__force uint16_t)encap_port; + if (copy_to_user(optval, &encap, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt_probe_interval(struct sock *sk, int len, + char __user *optval, + int __user *optlen) +{ + struct sctp_probeinterval params; + struct sctp_association *asoc; + struct sctp_transport *t; + __u32 probe_interval; + + if (len < sizeof(params)) + return -EINVAL; + + len = sizeof(params); + if (copy_from_user(¶ms, optval, len)) + return -EFAULT; + + /* If an address other than INADDR_ANY is specified, and + * no transport is found, then the request is invalid. + */ + if (!sctp_is_any(sk, (union sctp_addr *)¶ms.spi_address)) { + t = sctp_addr_id2transport(sk, ¶ms.spi_address, + params.spi_assoc_id); + if (!t) { + pr_debug("%s: failed no transport\n", __func__); + return -EINVAL; + } + + probe_interval = jiffies_to_msecs(t->probe_interval); + goto out; + } + + /* Get association, if assoc_id != SCTP_FUTURE_ASSOC and the + * socket is a one to many style socket, and an association + * was not found, then the id was invalid. + */ + asoc = sctp_id2assoc(sk, params.spi_assoc_id); + if (!asoc && params.spi_assoc_id != SCTP_FUTURE_ASSOC && + sctp_style(sk, UDP)) { + pr_debug("%s: failed no association\n", __func__); + return -EINVAL; + } + + if (asoc) { + probe_interval = jiffies_to_msecs(asoc->probe_interval); + goto out; + } + + probe_interval = sctp_sk(sk)->probe_interval; + +out: + params.spi_interval = probe_interval; + if (copy_to_user(optval, ¶ms, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +static int sctp_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + int retval = 0; + int len; + + pr_debug("%s: sk:%p, optname:%d\n", __func__, sk, optname); + + /* I can hardly begin to describe how wrong this is. This is + * so broken as to be worse than useless. The API draft + * REALLY is NOT helpful here... I am not convinced that the + * semantics of getsockopt() with a level OTHER THAN SOL_SCTP + * are at all well-founded. 
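+ *
+ * For reference: userspace passes SCTP-level options with level
+ * IPPROTO_SCTP (== SOL_SCTP), which is dispatched by the switch below;
+ * any other level is simply forwarded to the address family handler via
+ * af->getsockopt(). A hypothetical sketch (illustrative only):
+ *
+ *     int v = 0;
+ *     socklen_t len = sizeof(v);
+ *
+ *     getsockopt(fd, IPPROTO_SCTP, SCTP_NODELAY, &v, &len);  handled here
+ *     getsockopt(fd, IPPROTO_IP, IP_TOS, &v, &len);          forwarded to the AF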
+ */ + if (level != SOL_SCTP) { + struct sctp_af *af = sctp_sk(sk)->pf->af; + + retval = af->getsockopt(sk, level, optname, optval, optlen); + return retval; + } + + if (get_user(len, optlen)) + return -EFAULT; + + if (len < 0) + return -EINVAL; + + lock_sock(sk); + + switch (optname) { + case SCTP_STATUS: + retval = sctp_getsockopt_sctp_status(sk, len, optval, optlen); + break; + case SCTP_DISABLE_FRAGMENTS: + retval = sctp_getsockopt_disable_fragments(sk, len, optval, + optlen); + break; + case SCTP_EVENTS: + retval = sctp_getsockopt_events(sk, len, optval, optlen); + break; + case SCTP_AUTOCLOSE: + retval = sctp_getsockopt_autoclose(sk, len, optval, optlen); + break; + case SCTP_SOCKOPT_PEELOFF: + retval = sctp_getsockopt_peeloff(sk, len, optval, optlen); + break; + case SCTP_SOCKOPT_PEELOFF_FLAGS: + retval = sctp_getsockopt_peeloff_flags(sk, len, optval, optlen); + break; + case SCTP_PEER_ADDR_PARAMS: + retval = sctp_getsockopt_peer_addr_params(sk, len, optval, + optlen); + break; + case SCTP_DELAYED_SACK: + retval = sctp_getsockopt_delayed_ack(sk, len, optval, + optlen); + break; + case SCTP_INITMSG: + retval = sctp_getsockopt_initmsg(sk, len, optval, optlen); + break; + case SCTP_GET_PEER_ADDRS: + retval = sctp_getsockopt_peer_addrs(sk, len, optval, + optlen); + break; + case SCTP_GET_LOCAL_ADDRS: + retval = sctp_getsockopt_local_addrs(sk, len, optval, + optlen); + break; + case SCTP_SOCKOPT_CONNECTX3: + retval = sctp_getsockopt_connectx3(sk, len, optval, optlen); + break; + case SCTP_DEFAULT_SEND_PARAM: + retval = sctp_getsockopt_default_send_param(sk, len, + optval, optlen); + break; + case SCTP_DEFAULT_SNDINFO: + retval = sctp_getsockopt_default_sndinfo(sk, len, + optval, optlen); + break; + case SCTP_PRIMARY_ADDR: + retval = sctp_getsockopt_primary_addr(sk, len, optval, optlen); + break; + case SCTP_NODELAY: + retval = sctp_getsockopt_nodelay(sk, len, optval, optlen); + break; + case SCTP_RTOINFO: + retval = sctp_getsockopt_rtoinfo(sk, len, optval, optlen); + break; + case SCTP_ASSOCINFO: + retval = sctp_getsockopt_associnfo(sk, len, optval, optlen); + break; + case SCTP_I_WANT_MAPPED_V4_ADDR: + retval = sctp_getsockopt_mappedv4(sk, len, optval, optlen); + break; + case SCTP_MAXSEG: + retval = sctp_getsockopt_maxseg(sk, len, optval, optlen); + break; + case SCTP_GET_PEER_ADDR_INFO: + retval = sctp_getsockopt_peer_addr_info(sk, len, optval, + optlen); + break; + case SCTP_ADAPTATION_LAYER: + retval = sctp_getsockopt_adaptation_layer(sk, len, optval, + optlen); + break; + case SCTP_CONTEXT: + retval = sctp_getsockopt_context(sk, len, optval, optlen); + break; + case SCTP_FRAGMENT_INTERLEAVE: + retval = sctp_getsockopt_fragment_interleave(sk, len, optval, + optlen); + break; + case SCTP_PARTIAL_DELIVERY_POINT: + retval = sctp_getsockopt_partial_delivery_point(sk, len, optval, + optlen); + break; + case SCTP_MAX_BURST: + retval = sctp_getsockopt_maxburst(sk, len, optval, optlen); + break; + case SCTP_AUTH_KEY: + case SCTP_AUTH_CHUNK: + case SCTP_AUTH_DELETE_KEY: + case SCTP_AUTH_DEACTIVATE_KEY: + retval = -EOPNOTSUPP; + break; + case SCTP_HMAC_IDENT: + retval = sctp_getsockopt_hmac_ident(sk, len, optval, optlen); + break; + case SCTP_AUTH_ACTIVE_KEY: + retval = sctp_getsockopt_active_key(sk, len, optval, optlen); + break; + case SCTP_PEER_AUTH_CHUNKS: + retval = sctp_getsockopt_peer_auth_chunks(sk, len, optval, + optlen); + break; + case SCTP_LOCAL_AUTH_CHUNKS: + retval = sctp_getsockopt_local_auth_chunks(sk, len, optval, + optlen); + break; + case SCTP_GET_ASSOC_NUMBER: + retval = 
sctp_getsockopt_assoc_number(sk, len, optval, optlen); + break; + case SCTP_GET_ASSOC_ID_LIST: + retval = sctp_getsockopt_assoc_ids(sk, len, optval, optlen); + break; + case SCTP_AUTO_ASCONF: + retval = sctp_getsockopt_auto_asconf(sk, len, optval, optlen); + break; + case SCTP_PEER_ADDR_THLDS: + retval = sctp_getsockopt_paddr_thresholds(sk, optval, len, + optlen, false); + break; + case SCTP_PEER_ADDR_THLDS_V2: + retval = sctp_getsockopt_paddr_thresholds(sk, optval, len, + optlen, true); + break; + case SCTP_GET_ASSOC_STATS: + retval = sctp_getsockopt_assoc_stats(sk, len, optval, optlen); + break; + case SCTP_RECVRCVINFO: + retval = sctp_getsockopt_recvrcvinfo(sk, len, optval, optlen); + break; + case SCTP_RECVNXTINFO: + retval = sctp_getsockopt_recvnxtinfo(sk, len, optval, optlen); + break; + case SCTP_PR_SUPPORTED: + retval = sctp_getsockopt_pr_supported(sk, len, optval, optlen); + break; + case SCTP_DEFAULT_PRINFO: + retval = sctp_getsockopt_default_prinfo(sk, len, optval, + optlen); + break; + case SCTP_PR_ASSOC_STATUS: + retval = sctp_getsockopt_pr_assocstatus(sk, len, optval, + optlen); + break; + case SCTP_PR_STREAM_STATUS: + retval = sctp_getsockopt_pr_streamstatus(sk, len, optval, + optlen); + break; + case SCTP_RECONFIG_SUPPORTED: + retval = sctp_getsockopt_reconfig_supported(sk, len, optval, + optlen); + break; + case SCTP_ENABLE_STREAM_RESET: + retval = sctp_getsockopt_enable_strreset(sk, len, optval, + optlen); + break; + case SCTP_STREAM_SCHEDULER: + retval = sctp_getsockopt_scheduler(sk, len, optval, + optlen); + break; + case SCTP_STREAM_SCHEDULER_VALUE: + retval = sctp_getsockopt_scheduler_value(sk, len, optval, + optlen); + break; + case SCTP_INTERLEAVING_SUPPORTED: + retval = sctp_getsockopt_interleaving_supported(sk, len, optval, + optlen); + break; + case SCTP_REUSE_PORT: + retval = sctp_getsockopt_reuse_port(sk, len, optval, optlen); + break; + case SCTP_EVENT: + retval = sctp_getsockopt_event(sk, len, optval, optlen); + break; + case SCTP_ASCONF_SUPPORTED: + retval = sctp_getsockopt_asconf_supported(sk, len, optval, + optlen); + break; + case SCTP_AUTH_SUPPORTED: + retval = sctp_getsockopt_auth_supported(sk, len, optval, + optlen); + break; + case SCTP_ECN_SUPPORTED: + retval = sctp_getsockopt_ecn_supported(sk, len, optval, optlen); + break; + case SCTP_EXPOSE_POTENTIALLY_FAILED_STATE: + retval = sctp_getsockopt_pf_expose(sk, len, optval, optlen); + break; + case SCTP_REMOTE_UDP_ENCAPS_PORT: + retval = sctp_getsockopt_encap_port(sk, len, optval, optlen); + break; + case SCTP_PLPMTUD_PROBE_INTERVAL: + retval = sctp_getsockopt_probe_interval(sk, len, optval, optlen); + break; + default: + retval = -ENOPROTOOPT; + break; + } + + release_sock(sk); + return retval; +} + +static bool sctp_bpf_bypass_getsockopt(int level, int optname) +{ + if (level == SOL_SCTP) { + switch (optname) { + case SCTP_SOCKOPT_PEELOFF: + case SCTP_SOCKOPT_PEELOFF_FLAGS: + case SCTP_SOCKOPT_CONNECTX3: + return true; + default: + return false; + } + } + + return false; +} + +static int sctp_hash(struct sock *sk) +{ + /* STUB */ + return 0; +} + +static void sctp_unhash(struct sock *sk) +{ + /* STUB */ +} + +/* Check if port is acceptable. Possibly find first available port. + * + * The port hash table (contained in the 'global' SCTP protocol storage + * returned by struct sctp_protocol *sctp_get_protocol()). The hash + * table is an array of 4096 lists (sctp_bind_hashbucket). 
Each + * list (the list number is the port number hashed out, so as you + * would expect from a hash function, all the ports in a given list have + * such a number that hashes out to the same list number; you were + * expecting that, right?); so each list has a set of ports, with a + * link to the socket (struct sock) that uses it, the port number and + * a fastreuse flag (FIXME: NPI ipg). + */ +static struct sctp_bind_bucket *sctp_bucket_create( + struct sctp_bind_hashbucket *head, struct net *, unsigned short snum); + +static int sctp_get_port_local(struct sock *sk, union sctp_addr *addr) +{ + struct sctp_sock *sp = sctp_sk(sk); + bool reuse = (sk->sk_reuse || sp->reuse); + struct sctp_bind_hashbucket *head; /* hash list */ + struct net *net = sock_net(sk); + kuid_t uid = sock_i_uid(sk); + struct sctp_bind_bucket *pp; + unsigned short snum; + int ret; + + snum = ntohs(addr->v4.sin_port); + + pr_debug("%s: begins, snum:%d\n", __func__, snum); + + if (snum == 0) { + /* Search for an available port. */ + int low, high, remaining, index; + unsigned int rover; + + inet_sk_get_local_port_range(sk, &low, &high); + remaining = (high - low) + 1; + rover = prandom_u32_max(remaining) + low; + + do { + rover++; + if ((rover < low) || (rover > high)) + rover = low; + if (inet_is_local_reserved_port(net, rover)) + continue; + index = sctp_phashfn(net, rover); + head = &sctp_port_hashtable[index]; + spin_lock_bh(&head->lock); + sctp_for_each_hentry(pp, &head->chain) + if ((pp->port == rover) && + net_eq(net, pp->net)) + goto next; + break; + next: + spin_unlock_bh(&head->lock); + cond_resched(); + } while (--remaining > 0); + + /* Exhausted local port range during search? */ + ret = 1; + if (remaining <= 0) + return ret; + + /* OK, here is the one we will use. HEAD (the port + * hash table list entry) is non-NULL and we hold it's + * mutex. + */ + snum = rover; + } else { + /* We are given an specific port number; we verify + * that it is not being used. If it is used, we will + * exahust the search in the hash list corresponding + * to the port number (snum) - we detect that with the + * port iterator, pp being NULL. + */ + head = &sctp_port_hashtable[sctp_phashfn(net, snum)]; + spin_lock_bh(&head->lock); + sctp_for_each_hentry(pp, &head->chain) { + if ((pp->port == snum) && net_eq(pp->net, net)) + goto pp_found; + } + } + pp = NULL; + goto pp_not_found; +pp_found: + if (!hlist_empty(&pp->owner)) { + /* We had a port hash table hit - there is an + * available port (pp != NULL) and it is being + * used by other socket (pp->owner not empty); that other + * socket is going to be sk2. + */ + struct sock *sk2; + + pr_debug("%s: found a possible match\n", __func__); + + if ((pp->fastreuse && reuse && + sk->sk_state != SCTP_SS_LISTENING) || + (pp->fastreuseport && sk->sk_reuseport && + uid_eq(pp->fastuid, uid))) + goto success; + + /* Run through the list of sockets bound to the port + * (pp->port) [via the pointers bind_next and + * bind_pprev in the struct sock *sk2 (pp->sk)]. On each one, + * we get the endpoint they describe and run through + * the endpoint's list of IP (v4 or v6) addresses, + * comparing each of the addresses with the address of + * the socket sk. If we find a match, then that means + * that this port/socket (sk) combination are already + * in an endpoint. 
+ */ + sk_for_each_bound(sk2, &pp->owner) { + struct sctp_sock *sp2 = sctp_sk(sk2); + struct sctp_endpoint *ep2 = sp2->ep; + + if (sk == sk2 || + (reuse && (sk2->sk_reuse || sp2->reuse) && + sk2->sk_state != SCTP_SS_LISTENING) || + (sk->sk_reuseport && sk2->sk_reuseport && + uid_eq(uid, sock_i_uid(sk2)))) + continue; + + if (sctp_bind_addr_conflict(&ep2->base.bind_addr, + addr, sp2, sp)) { + ret = 1; + goto fail_unlock; + } + } + + pr_debug("%s: found a match\n", __func__); + } +pp_not_found: + /* If there was a hash table miss, create a new port. */ + ret = 1; + if (!pp && !(pp = sctp_bucket_create(head, net, snum))) + goto fail_unlock; + + /* In either case (hit or miss), make sure fastreuse is 1 only + * if sk->sk_reuse is too (that is, if the caller requested + * SO_REUSEADDR on this socket -sk-). + */ + if (hlist_empty(&pp->owner)) { + if (reuse && sk->sk_state != SCTP_SS_LISTENING) + pp->fastreuse = 1; + else + pp->fastreuse = 0; + + if (sk->sk_reuseport) { + pp->fastreuseport = 1; + pp->fastuid = uid; + } else { + pp->fastreuseport = 0; + } + } else { + if (pp->fastreuse && + (!reuse || sk->sk_state == SCTP_SS_LISTENING)) + pp->fastreuse = 0; + + if (pp->fastreuseport && + (!sk->sk_reuseport || !uid_eq(pp->fastuid, uid))) + pp->fastreuseport = 0; + } + + /* We are set, so fill up all the data in the hash table + * entry, tie the socket list information with the rest of the + * sockets FIXME: Blurry, NPI (ipg). + */ +success: + if (!sp->bind_hash) { + inet_sk(sk)->inet_num = snum; + sk_add_bind_node(sk, &pp->owner); + sp->bind_hash = pp; + } + ret = 0; + +fail_unlock: + spin_unlock_bh(&head->lock); + return ret; +} + +/* Assign a 'snum' port to the socket. If snum == 0, an ephemeral + * port is requested. + */ +static int sctp_get_port(struct sock *sk, unsigned short snum) +{ + union sctp_addr addr; + struct sctp_af *af = sctp_sk(sk)->pf->af; + + /* Set up a dummy address struct from the sk. */ + af->from_sk(&addr, sk); + addr.v4.sin_port = htons(snum); + + /* Note: sk->sk_num gets filled in if ephemeral port request. */ + return sctp_get_port_local(sk, &addr); +} + +/* + * Move a socket to LISTENING state. + */ +static int sctp_listen_start(struct sock *sk, int backlog) +{ + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_endpoint *ep = sp->ep; + struct crypto_shash *tfm = NULL; + char alg[32]; + + /* Allocate HMAC for generating cookie. */ + if (!sp->hmac && sp->sctp_hmac_alg) { + sprintf(alg, "hmac(%s)", sp->sctp_hmac_alg); + tfm = crypto_alloc_shash(alg, 0, 0); + if (IS_ERR(tfm)) { + net_info_ratelimited("failed to load transform for %s: %ld\n", + sp->sctp_hmac_alg, PTR_ERR(tfm)); + return -ENOSYS; + } + sctp_sk(sk)->hmac = tfm; + } + + /* + * If a bind() or sctp_bindx() is not called prior to a listen() + * call that allows new associations to be accepted, the system + * picks an ephemeral port and will choose an address set equivalent + * to binding with a wildcard address. + * + * This is not currently spelled out in the SCTP sockets + * extensions draft, but follows the practice as seen in TCP + * sockets. + * + */ + inet_sk_set_state(sk, SCTP_SS_LISTENING); + if (!ep->base.bind_addr.port) { + if (sctp_autobind(sk)) + return -EAGAIN; + } else { + if (sctp_get_port(sk, inet_sk(sk)->inet_num)) { + inet_sk_set_state(sk, SCTP_SS_CLOSED); + return -EADDRINUSE; + } + } + + WRITE_ONCE(sk->sk_max_ack_backlog, backlog); + return sctp_hash_endpoint(ep); +} + +/* + * 4.1.3 / 5.1.3 listen() + * + * By default, new associations are not accepted for UDP style sockets. 
+ * An application uses listen() to mark a socket as being able to + * accept new associations. + * + * On TCP style sockets, applications use listen() to ready the SCTP + * endpoint for accepting inbound associations. + * + * On both types of endpoints a backlog of '0' disables listening. + * + * Move a socket to LISTENING state. + */ +int sctp_inet_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + struct sctp_endpoint *ep = sctp_sk(sk)->ep; + int err = -EINVAL; + + if (unlikely(backlog < 0)) + return err; + + lock_sock(sk); + + /* Peeled-off sockets are not allowed to listen(). */ + if (sctp_style(sk, UDP_HIGH_BANDWIDTH)) + goto out; + + if (sock->state != SS_UNCONNECTED) + goto out; + + if (!sctp_sstate(sk, LISTENING) && !sctp_sstate(sk, CLOSED)) + goto out; + + /* If backlog is zero, disable listening. */ + if (!backlog) { + if (sctp_sstate(sk, CLOSED)) + goto out; + + err = 0; + sctp_unhash_endpoint(ep); + sk->sk_state = SCTP_SS_CLOSED; + if (sk->sk_reuse || sctp_sk(sk)->reuse) + sctp_sk(sk)->bind_hash->fastreuse = 1; + goto out; + } + + /* If we are already listening, just update the backlog */ + if (sctp_sstate(sk, LISTENING)) + WRITE_ONCE(sk->sk_max_ack_backlog, backlog); + else { + err = sctp_listen_start(sk, backlog); + if (err) + goto out; + } + + err = 0; +out: + release_sock(sk); + return err; +} + +/* + * This function is done by modeling the current datagram_poll() and the + * tcp_poll(). Note that, based on these implementations, we don't + * lock the socket in this function, even though it seems that, + * ideally, locking or some other mechanisms can be used to ensure + * the integrity of the counters (sndbuf and wmem_alloc) used + * in this place. We assume that we don't need locks either until proven + * otherwise. + * + * Another thing to note is that we include the Async I/O support + * here, again, by modeling the current TCP/UDP code. We don't have + * a good way to test with it yet. + */ +__poll_t sctp_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + struct sock *sk = sock->sk; + struct sctp_sock *sp = sctp_sk(sk); + __poll_t mask; + + poll_wait(file, sk_sleep(sk), wait); + + sock_rps_record_flow(sk); + + /* A TCP-style listening socket becomes readable when the accept queue + * is not empty. + */ + if (sctp_style(sk, TCP) && sctp_sstate(sk, LISTENING)) + return (!list_empty(&sp->ep->asocs)) ? + (EPOLLIN | EPOLLRDNORM) : 0; + + mask = 0; + + /* Is there any exceptional events? */ + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? EPOLLPRI : 0); + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + if (sk->sk_shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + /* Is it readable? Reconsider this code with TCP-style support. */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* The association is either gone or not ready. */ + if (!sctp_style(sk, UDP) && sctp_sstate(sk, CLOSED)) + return mask; + + /* Is it writable? */ + if (sctp_writeable(sk)) { + mask |= EPOLLOUT | EPOLLWRNORM; + } else { + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + /* + * Since the socket is not locked, the buffer + * might be made available after the writeable check and + * before the bit is set. This could cause a lost I/O + * signal. tcp_poll() has a race breaker for this race + * condition. Based on their implementation, we put + * in the following code to cover it as well. 
+ */ + if (sctp_writeable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM; + } + return mask; +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +static struct sctp_bind_bucket *sctp_bucket_create( + struct sctp_bind_hashbucket *head, struct net *net, unsigned short snum) +{ + struct sctp_bind_bucket *pp; + + pp = kmem_cache_alloc(sctp_bucket_cachep, GFP_ATOMIC); + if (pp) { + SCTP_DBG_OBJCNT_INC(bind_bucket); + pp->port = snum; + pp->fastreuse = 0; + INIT_HLIST_HEAD(&pp->owner); + pp->net = net; + hlist_add_head(&pp->node, &head->chain); + } + return pp; +} + +/* Caller must hold hashbucket lock for this tb with local BH disabled */ +static void sctp_bucket_destroy(struct sctp_bind_bucket *pp) +{ + if (pp && hlist_empty(&pp->owner)) { + __hlist_del(&pp->node); + kmem_cache_free(sctp_bucket_cachep, pp); + SCTP_DBG_OBJCNT_DEC(bind_bucket); + } +} + +/* Release this socket's reference to a local port. */ +static inline void __sctp_put_port(struct sock *sk) +{ + struct sctp_bind_hashbucket *head = + &sctp_port_hashtable[sctp_phashfn(sock_net(sk), + inet_sk(sk)->inet_num)]; + struct sctp_bind_bucket *pp; + + spin_lock(&head->lock); + pp = sctp_sk(sk)->bind_hash; + __sk_del_bind_node(sk); + sctp_sk(sk)->bind_hash = NULL; + inet_sk(sk)->inet_num = 0; + sctp_bucket_destroy(pp); + spin_unlock(&head->lock); +} + +void sctp_put_port(struct sock *sk) +{ + local_bh_disable(); + __sctp_put_port(sk); + local_bh_enable(); +} + +/* + * The system picks an ephemeral port and choose an address set equivalent + * to binding with a wildcard address. + * One of those addresses will be the primary address for the association. + * This automatically enables the multihoming capability of SCTP. + */ +static int sctp_autobind(struct sock *sk) +{ + union sctp_addr autoaddr; + struct sctp_af *af; + __be16 port; + + /* Initialize a local sockaddr structure to INADDR_ANY. */ + af = sctp_sk(sk)->pf->af; + + port = htons(inet_sk(sk)->inet_num); + af->inaddr_any(&autoaddr, port); + + return sctp_do_bind(sk, &autoaddr, af->sockaddr_len); +} + +/* Parse out IPPROTO_SCTP CMSG headers. Perform only minimal validation. + * + * From RFC 2292 + * 4.2 The cmsghdr Structure * + * + * When ancillary data is sent or received, any number of ancillary data + * objects can be specified by the msg_control and msg_controllen members of + * the msghdr structure, because each object is preceded by + * a cmsghdr structure defining the object's length (the cmsg_len member). + * Historically Berkeley-derived implementations have passed only one object + * at a time, but this API allows multiple objects to be + * passed in a single call to sendmsg() or recvmsg(). The following example + * shows two ancillary data objects in a control buffer. 
+ * + * |<--------------------------- msg_controllen -------------------------->| + * | | + * + * |<----- ancillary data object ----->|<----- ancillary data object ----->| + * + * |<---------- CMSG_SPACE() --------->|<---------- CMSG_SPACE() --------->| + * | | | + * + * |<---------- cmsg_len ---------->| |<--------- cmsg_len ----------->| | + * + * |<--------- CMSG_LEN() --------->| |<-------- CMSG_LEN() ---------->| | + * | | | | | + * + * +-----+-----+-----+--+-----------+--+-----+-----+-----+--+-----------+--+ + * |cmsg_|cmsg_|cmsg_|XX| |XX|cmsg_|cmsg_|cmsg_|XX| |XX| + * + * |len |level|type |XX|cmsg_data[]|XX|len |level|type |XX|cmsg_data[]|XX| + * + * +-----+-----+-----+--+-----------+--+-----+-----+-----+--+-----------+--+ + * ^ + * | + * + * msg_control + * points here + */ +static int sctp_msghdr_parse(const struct msghdr *msg, struct sctp_cmsgs *cmsgs) +{ + struct msghdr *my_msg = (struct msghdr *)msg; + struct cmsghdr *cmsg; + + for_each_cmsghdr(cmsg, my_msg) { + if (!CMSG_OK(my_msg, cmsg)) + return -EINVAL; + + /* Should we parse this header or ignore? */ + if (cmsg->cmsg_level != IPPROTO_SCTP) + continue; + + /* Strictly check lengths following example in SCM code. */ + switch (cmsg->cmsg_type) { + case SCTP_INIT: + /* SCTP Socket API Extension + * 5.3.1 SCTP Initiation Structure (SCTP_INIT) + * + * This cmsghdr structure provides information for + * initializing new SCTP associations with sendmsg(). + * The SCTP_INITMSG socket option uses this same data + * structure. This structure is not used for + * recvmsg(). + * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ ---------------------- + * IPPROTO_SCTP SCTP_INIT struct sctp_initmsg + */ + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct sctp_initmsg))) + return -EINVAL; + + cmsgs->init = CMSG_DATA(cmsg); + break; + + case SCTP_SNDRCV: + /* SCTP Socket API Extension + * 5.3.2 SCTP Header Information Structure(SCTP_SNDRCV) + * + * This cmsghdr structure specifies SCTP options for + * sendmsg() and describes SCTP header information + * about a received message through recvmsg(). + * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ ---------------------- + * IPPROTO_SCTP SCTP_SNDRCV struct sctp_sndrcvinfo + */ + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct sctp_sndrcvinfo))) + return -EINVAL; + + cmsgs->srinfo = CMSG_DATA(cmsg); + + if (cmsgs->srinfo->sinfo_flags & + ~(SCTP_UNORDERED | SCTP_ADDR_OVER | + SCTP_SACK_IMMEDIATELY | SCTP_SENDALL | + SCTP_PR_SCTP_MASK | SCTP_ABORT | SCTP_EOF)) + return -EINVAL; + break; + + case SCTP_SNDINFO: + /* SCTP Socket API Extension + * 5.3.4 SCTP Send Information Structure (SCTP_SNDINFO) + * + * This cmsghdr structure specifies SCTP options for + * sendmsg(). This structure and SCTP_RCVINFO replaces + * SCTP_SNDRCV which has been deprecated. + * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ --------------------- + * IPPROTO_SCTP SCTP_SNDINFO struct sctp_sndinfo + */ + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct sctp_sndinfo))) + return -EINVAL; + + cmsgs->sinfo = CMSG_DATA(cmsg); + + if (cmsgs->sinfo->snd_flags & + ~(SCTP_UNORDERED | SCTP_ADDR_OVER | + SCTP_SACK_IMMEDIATELY | SCTP_SENDALL | + SCTP_PR_SCTP_MASK | SCTP_ABORT | SCTP_EOF)) + return -EINVAL; + break; + case SCTP_PRINFO: + /* SCTP Socket API Extension + * 5.3.7 SCTP PR-SCTP Information Structure (SCTP_PRINFO) + * + * This cmsghdr structure specifies SCTP options for sendmsg(). 
+ * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ --------------------- + * IPPROTO_SCTP SCTP_PRINFO struct sctp_prinfo + */ + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct sctp_prinfo))) + return -EINVAL; + + cmsgs->prinfo = CMSG_DATA(cmsg); + if (cmsgs->prinfo->pr_policy & ~SCTP_PR_SCTP_MASK) + return -EINVAL; + + if (cmsgs->prinfo->pr_policy == SCTP_PR_SCTP_NONE) + cmsgs->prinfo->pr_value = 0; + break; + case SCTP_AUTHINFO: + /* SCTP Socket API Extension + * 5.3.8 SCTP AUTH Information Structure (SCTP_AUTHINFO) + * + * This cmsghdr structure specifies SCTP options for sendmsg(). + * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ --------------------- + * IPPROTO_SCTP SCTP_AUTHINFO struct sctp_authinfo + */ + if (cmsg->cmsg_len != CMSG_LEN(sizeof(struct sctp_authinfo))) + return -EINVAL; + + cmsgs->authinfo = CMSG_DATA(cmsg); + break; + case SCTP_DSTADDRV4: + case SCTP_DSTADDRV6: + /* SCTP Socket API Extension + * 5.3.9/10 SCTP Destination IPv4/6 Address Structure (SCTP_DSTADDRV4/6) + * + * This cmsghdr structure specifies SCTP options for sendmsg(). + * + * cmsg_level cmsg_type cmsg_data[] + * ------------ ------------ --------------------- + * IPPROTO_SCTP SCTP_DSTADDRV4 struct in_addr + * ------------ ------------ --------------------- + * IPPROTO_SCTP SCTP_DSTADDRV6 struct in6_addr + */ + cmsgs->addrs_msg = my_msg; + break; + default: + return -EINVAL; + } + } + + return 0; +} + +/* + * Wait for a packet.. + * Note: This function is the same function as in core/datagram.c + * with a few modifications to make lksctp work. + */ +static int sctp_wait_for_packet(struct sock *sk, int *err, long *timeo_p) +{ + int error; + DEFINE_WAIT(wait); + + prepare_to_wait_exclusive(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + /* Socket errors? */ + error = sock_error(sk); + if (error) + goto out; + + if (!skb_queue_empty(&sk->sk_receive_queue)) + goto ready; + + /* Socket shut down? */ + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto out; + + /* Sequenced packets can come disconnected. If so we report the + * problem. + */ + error = -ENOTCONN; + + /* Is there a good reason to think that we may receive some data? */ + if (list_empty(&sctp_sk(sk)->ep->asocs) && !sctp_sstate(sk, LISTENING)) + goto out; + + /* Handle signals. */ + if (signal_pending(current)) + goto interrupted; + + /* Let another process have a go. Since we are going to sleep + * anyway. Note: This may cause odd behaviors if the message + * does not fit in the user's buffer, but this seems to be the + * only way to honor MSG_DONTWAIT realistically. + */ + release_sock(sk); + *timeo_p = schedule_timeout(*timeo_p); + lock_sock(sk); + +ready: + finish_wait(sk_sleep(sk), &wait); + return 0; + +interrupted: + error = sock_intr_errno(*timeo_p); + +out: + finish_wait(sk_sleep(sk), &wait); + *err = error; + return error; +} + +/* Receive a datagram. + * Note: This is pretty much the same routine as in core/datagram.c + * with a few changes to make lksctp work. + */ +struct sk_buff *sctp_skb_recv_datagram(struct sock *sk, int flags, int *err) +{ + int error; + struct sk_buff *skb; + long timeo; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + pr_debug("%s: timeo:%ld, max:%ld\n", __func__, timeo, + MAX_SCHEDULE_TIMEOUT); + + do { + /* Again only user level code calls this function, + * so nothing interrupt level + * will suddenly eat the receive_queue. + * + * Look at current nfs client by the way... + * However, this function was correct in any case. 
8) + */ + if (flags & MSG_PEEK) { + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + refcount_inc(&skb->users); + } else { + skb = __skb_dequeue(&sk->sk_receive_queue); + } + + if (skb) + return skb; + + /* Caller is allowed not to check sk->sk_err before calling. */ + error = sock_error(sk); + if (error) + goto no_packet; + + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + + + /* User doesn't want to wait. */ + error = -EAGAIN; + if (!timeo) + goto no_packet; + } while (sctp_wait_for_packet(sk, err, &timeo) == 0); + + return NULL; + +no_packet: + *err = error; + return NULL; +} + +/* If sndbuf has changed, wake up per association sndbuf waiters. */ +static void __sctp_write_space(struct sctp_association *asoc) +{ + struct sock *sk = asoc->base.sk; + + if (sctp_wspace(asoc) <= 0) + return; + + if (waitqueue_active(&asoc->wait)) + wake_up_interruptible(&asoc->wait); + + if (sctp_writeable(sk)) { + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (wq) { + if (waitqueue_active(&wq->wait)) + wake_up_interruptible(&wq->wait); + + /* Note that we try to include the Async I/O support + * here by modeling from the current TCP/UDP code. + * We have not tested with it yet. + */ + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) + sock_wake_async(wq, SOCK_WAKE_SPACE, POLL_OUT); + } + rcu_read_unlock(); + } +} + +static void sctp_wake_up_waiters(struct sock *sk, + struct sctp_association *asoc) +{ + struct sctp_association *tmp = asoc; + + /* We do accounting for the sndbuf space per association, + * so we only need to wake our own association. + */ + if (asoc->ep->sndbuf_policy) + return __sctp_write_space(asoc); + + /* If association goes down and is just flushing its + * outq, then just normally notify others. + */ + if (asoc->base.dead) + return sctp_write_space(sk); + + /* Accounting for the sndbuf space is per socket, so we + * need to wake up others, try to be fair and in case of + * other associations, let them have a go first instead + * of just doing a sctp_write_space() call. + * + * Note that we reach sctp_wake_up_waiters() only when + * associations free up queued chunks, thus we are under + * lock and the list of associations on a socket is + * guaranteed not to change. + */ + for (tmp = list_next_entry(tmp, asocs); 1; + tmp = list_next_entry(tmp, asocs)) { + /* Manually skip the head element. */ + if (&tmp->asocs == &((sctp_sk(sk))->ep->asocs)) + continue; + /* Wake up association. */ + __sctp_write_space(tmp); + /* We've reached the end. */ + if (tmp == asoc) + break; + } +} + +/* Do accounting for the sndbuf space. + * Decrement the used sndbuf space of the corresponding association by the + * data size which was just transmitted(freed). + */ +static void sctp_wfree(struct sk_buff *skb) +{ + struct sctp_chunk *chunk = skb_shinfo(skb)->destructor_arg; + struct sctp_association *asoc = chunk->asoc; + struct sock *sk = asoc->base.sk; + + sk_mem_uncharge(sk, skb->truesize); + sk_wmem_queued_add(sk, -(skb->truesize + sizeof(struct sctp_chunk))); + asoc->sndbuf_used -= skb->truesize + sizeof(struct sctp_chunk); + WARN_ON(refcount_sub_and_test(sizeof(struct sctp_chunk), + &sk->sk_wmem_alloc)); + + if (chunk->shkey) { + struct sctp_shared_key *shkey = chunk->shkey; + + /* refcnt == 2 and !list_empty mean after this release, it's + * not being used anywhere, and it's time to notify userland + * that this shkey can be freed if it's been deactivated. 
+ */ + if (shkey->deactivated && !list_empty(&shkey->key_list) && + refcount_read(&shkey->refcnt) == 2) { + struct sctp_ulpevent *ev; + + ev = sctp_ulpevent_make_authkey(asoc, shkey->key_id, + SCTP_AUTH_FREE_KEY, + GFP_KERNEL); + if (ev) + asoc->stream.si->enqueue_event(&asoc->ulpq, ev); + } + sctp_auth_shkey_release(chunk->shkey); + } + + sock_wfree(skb); + sctp_wake_up_waiters(sk, asoc); + + sctp_association_put(asoc); +} + +/* Do accounting for the receive space on the socket. + * Accounting for the association is done in ulpevent.c + * We set this as a destructor for the cloned data skbs so that + * accounting is done at the correct time. + */ +void sctp_sock_rfree(struct sk_buff *skb) +{ + struct sock *sk = skb->sk; + struct sctp_ulpevent *event = sctp_skb2event(skb); + + atomic_sub(event->rmem_len, &sk->sk_rmem_alloc); + + /* + * Mimic the behavior of sock_rfree + */ + sk_mem_uncharge(sk, event->rmem_len); +} + + +/* Helper function to wait for space in the sndbuf. */ +static int sctp_wait_for_sndbuf(struct sctp_association *asoc, long *timeo_p, + size_t msg_len) +{ + struct sock *sk = asoc->base.sk; + long current_timeo = *timeo_p; + DEFINE_WAIT(wait); + int err = 0; + + pr_debug("%s: asoc:%p, timeo:%ld, msg_len:%zu\n", __func__, asoc, + *timeo_p, msg_len); + + /* Increment the association's refcnt. */ + sctp_association_hold(asoc); + + /* Wait on the association specific sndbuf space. */ + for (;;) { + prepare_to_wait_exclusive(&asoc->wait, &wait, + TASK_INTERRUPTIBLE); + if (asoc->base.dead) + goto do_dead; + if (!*timeo_p) + goto do_nonblock; + if (sk->sk_err || asoc->state >= SCTP_STATE_SHUTDOWN_PENDING) + goto do_error; + if (signal_pending(current)) + goto do_interrupted; + if ((int)msg_len <= sctp_wspace(asoc) && + sk_wmem_schedule(sk, msg_len)) + break; + + /* Let another process have a go. Since we are going + * to sleep anyway. + */ + release_sock(sk); + current_timeo = schedule_timeout(current_timeo); + lock_sock(sk); + if (sk != asoc->base.sk) + goto do_error; + + *timeo_p = current_timeo; + } + +out: + finish_wait(&asoc->wait, &wait); + + /* Release the association's refcnt. */ + sctp_association_put(asoc); + + return err; + +do_dead: + err = -ESRCH; + goto out; + +do_error: + err = -EPIPE; + goto out; + +do_interrupted: + err = sock_intr_errno(*timeo_p); + goto out; + +do_nonblock: + err = -EAGAIN; + goto out; +} + +void sctp_data_ready(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLIN | + EPOLLRDNORM | EPOLLRDBAND); + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN); + rcu_read_unlock(); +} + +/* If socket sndbuf has changed, wake up all per association waiters. */ +void sctp_write_space(struct sock *sk) +{ + struct sctp_association *asoc; + + /* Wake up the tasks in each wait queue. */ + list_for_each_entry(asoc, &((sctp_sk(sk))->ep->asocs), asocs) { + __sctp_write_space(asoc); + } +} + +/* Is there any sndbuf space available on the socket? + * + * Note that sk_wmem_alloc is the sum of the send buffers on all of the + * associations on the same socket. For a UDP-style socket with + * multiple associations, it is possible for it to be "unwriteable" + * prematurely. I assume that this is acceptable because + * a premature "unwriteable" is better than an accidental "writeable" which + * would cause an unwanted block under certain circumstances. For the 1-1 + * UDP-style sockets or TCP-style sockets, this code should work. 
+ * - Daisy + */ +static bool sctp_writeable(const struct sock *sk) +{ + return READ_ONCE(sk->sk_sndbuf) > READ_ONCE(sk->sk_wmem_queued); +} + +/* Wait for an association to go into ESTABLISHED state. If timeout is 0, + * returns immediately with EINPROGRESS. + */ +static int sctp_wait_for_connect(struct sctp_association *asoc, long *timeo_p) +{ + struct sock *sk = asoc->base.sk; + int err = 0; + long current_timeo = *timeo_p; + DEFINE_WAIT(wait); + + pr_debug("%s: asoc:%p, timeo:%ld\n", __func__, asoc, *timeo_p); + + /* Increment the association's refcnt. */ + sctp_association_hold(asoc); + + for (;;) { + prepare_to_wait_exclusive(&asoc->wait, &wait, + TASK_INTERRUPTIBLE); + if (!*timeo_p) + goto do_nonblock; + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + if (sk->sk_err || asoc->state >= SCTP_STATE_SHUTDOWN_PENDING || + asoc->base.dead) + goto do_error; + if (signal_pending(current)) + goto do_interrupted; + + if (sctp_state(asoc, ESTABLISHED)) + break; + + /* Let another process have a go. Since we are going + * to sleep anyway. + */ + release_sock(sk); + current_timeo = schedule_timeout(current_timeo); + lock_sock(sk); + + *timeo_p = current_timeo; + } + +out: + finish_wait(&asoc->wait, &wait); + + /* Release the association's refcnt. */ + sctp_association_put(asoc); + + return err; + +do_error: + if (asoc->init_err_counter + 1 > asoc->max_init_attempts) + err = -ETIMEDOUT; + else + err = -ECONNREFUSED; + goto out; + +do_interrupted: + err = sock_intr_errno(*timeo_p); + goto out; + +do_nonblock: + err = -EINPROGRESS; + goto out; +} + +static int sctp_wait_for_accept(struct sock *sk, long timeo) +{ + struct sctp_endpoint *ep; + int err = 0; + DEFINE_WAIT(wait); + + ep = sctp_sk(sk)->ep; + + + for (;;) { + prepare_to_wait_exclusive(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + + if (list_empty(&ep->asocs)) { + release_sock(sk); + timeo = schedule_timeout(timeo); + lock_sock(sk); + } + + err = -EINVAL; + if (!sctp_sstate(sk, LISTENING)) + break; + + err = 0; + if (!list_empty(&ep->asocs)) + break; + + err = sock_intr_errno(timeo); + if (signal_pending(current)) + break; + + err = -EAGAIN; + if (!timeo) + break; + } + + finish_wait(sk_sleep(sk), &wait); + + return err; +} + +static void sctp_wait_for_close(struct sock *sk, long timeout) +{ + DEFINE_WAIT(wait); + + do { + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + if (list_empty(&sctp_sk(sk)->ep->asocs)) + break; + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + } while (!signal_pending(current) && timeout); + + finish_wait(sk_sleep(sk), &wait); +} + +static void sctp_skb_set_owner_r_frag(struct sk_buff *skb, struct sock *sk) +{ + struct sk_buff *frag; + + if (!skb->data_len) + goto done; + + /* Don't forget the fragments. 
*/ + skb_walk_frags(skb, frag) + sctp_skb_set_owner_r_frag(frag, sk); + +done: + sctp_skb_set_owner_r(skb, sk); +} + +void sctp_copy_sock(struct sock *newsk, struct sock *sk, + struct sctp_association *asoc) +{ + struct inet_sock *inet = inet_sk(sk); + struct inet_sock *newinet; + struct sctp_sock *sp = sctp_sk(sk); + + newsk->sk_type = sk->sk_type; + newsk->sk_bound_dev_if = sk->sk_bound_dev_if; + newsk->sk_flags = sk->sk_flags; + newsk->sk_tsflags = sk->sk_tsflags; + newsk->sk_no_check_tx = sk->sk_no_check_tx; + newsk->sk_no_check_rx = sk->sk_no_check_rx; + newsk->sk_reuse = sk->sk_reuse; + sctp_sk(newsk)->reuse = sp->reuse; + + newsk->sk_shutdown = sk->sk_shutdown; + newsk->sk_destruct = sk->sk_destruct; + newsk->sk_family = sk->sk_family; + newsk->sk_protocol = IPPROTO_SCTP; + newsk->sk_backlog_rcv = sk->sk_prot->backlog_rcv; + newsk->sk_sndbuf = sk->sk_sndbuf; + newsk->sk_rcvbuf = sk->sk_rcvbuf; + newsk->sk_lingertime = sk->sk_lingertime; + newsk->sk_rcvtimeo = sk->sk_rcvtimeo; + newsk->sk_sndtimeo = sk->sk_sndtimeo; + newsk->sk_rxhash = sk->sk_rxhash; + + newinet = inet_sk(newsk); + + /* Initialize sk's sport, dport, rcv_saddr and daddr for + * getsockname() and getpeername() + */ + newinet->inet_sport = inet->inet_sport; + newinet->inet_saddr = inet->inet_saddr; + newinet->inet_rcv_saddr = inet->inet_rcv_saddr; + newinet->inet_dport = htons(asoc->peer.port); + newinet->pmtudisc = inet->pmtudisc; + atomic_set(&newinet->inet_id, get_random_u16()); + + newinet->uc_ttl = inet->uc_ttl; + newinet->mc_loop = 1; + newinet->mc_ttl = 1; + newinet->mc_index = 0; + newinet->mc_list = NULL; + + if (newsk->sk_flags & SK_FLAGS_TIMESTAMP) + net_enable_timestamp(); + + /* Set newsk security attributes from original sk and connection + * security attribute from asoc. + */ + security_sctp_sk_clone(asoc, sk, newsk); +} + +static inline void sctp_copy_descendant(struct sock *sk_to, + const struct sock *sk_from) +{ + size_t ancestor_size = sizeof(struct inet_sock); + + ancestor_size += sk_from->sk_prot->obj_size; + ancestor_size -= offsetof(struct sctp_sock, pd_lobby); + __inet_sk_copy_descendant(sk_to, sk_from, ancestor_size); +} + +/* Populate the fields of the newsk from the oldsk and migrate the assoc + * and its messages to the newsk. + */ +static int sctp_sock_migrate(struct sock *oldsk, struct sock *newsk, + struct sctp_association *assoc, + enum sctp_socket_type type) +{ + struct sctp_sock *oldsp = sctp_sk(oldsk); + struct sctp_sock *newsp = sctp_sk(newsk); + struct sctp_bind_bucket *pp; /* hash list port iterator */ + struct sctp_endpoint *newep = newsp->ep; + struct sk_buff *skb, *tmp; + struct sctp_ulpevent *event; + struct sctp_bind_hashbucket *head; + int err; + + /* Migrate socket buffer sizes and all the socket level options to the + * new socket. + */ + newsk->sk_sndbuf = oldsk->sk_sndbuf; + newsk->sk_rcvbuf = oldsk->sk_rcvbuf; + /* Brute force copy old sctp opt. */ + sctp_copy_descendant(newsk, oldsk); + + /* Restore the ep value that was overwritten with the above structure + * copy. + */ + newsp->ep = newep; + newsp->hmac = NULL; + + /* Hook this new socket in to the bind_hash list. 
*/ + head = &sctp_port_hashtable[sctp_phashfn(sock_net(oldsk), + inet_sk(oldsk)->inet_num)]; + spin_lock_bh(&head->lock); + pp = sctp_sk(oldsk)->bind_hash; + sk_add_bind_node(newsk, &pp->owner); + sctp_sk(newsk)->bind_hash = pp; + inet_sk(newsk)->inet_num = inet_sk(oldsk)->inet_num; + spin_unlock_bh(&head->lock); + + /* Copy the bind_addr list from the original endpoint to the new + * endpoint so that we can handle restarts properly + */ + err = sctp_bind_addr_dup(&newsp->ep->base.bind_addr, + &oldsp->ep->base.bind_addr, GFP_KERNEL); + if (err) + return err; + + /* New ep's auth_hmacs should be set if old ep's is set, in case + * that net->sctp.auth_enable has been changed to 0 by users and + * new ep's auth_hmacs couldn't be set in sctp_endpoint_init(). + */ + if (oldsp->ep->auth_hmacs) { + err = sctp_auth_init_hmacs(newsp->ep, GFP_KERNEL); + if (err) + return err; + } + + sctp_auto_asconf_init(newsp); + + /* Move any messages in the old socket's receive queue that are for the + * peeled off association to the new socket's receive queue. + */ + sctp_skb_for_each(skb, &oldsk->sk_receive_queue, tmp) { + event = sctp_skb2event(skb); + if (event->asoc == assoc) { + __skb_unlink(skb, &oldsk->sk_receive_queue); + __skb_queue_tail(&newsk->sk_receive_queue, skb); + sctp_skb_set_owner_r_frag(skb, newsk); + } + } + + /* Clean up any messages pending delivery due to partial + * delivery. Three cases: + * 1) No partial deliver; no work. + * 2) Peeling off partial delivery; keep pd_lobby in new pd_lobby. + * 3) Peeling off non-partial delivery; move pd_lobby to receive_queue. + */ + atomic_set(&sctp_sk(newsk)->pd_mode, assoc->ulpq.pd_mode); + + if (atomic_read(&sctp_sk(oldsk)->pd_mode)) { + struct sk_buff_head *queue; + + /* Decide which queue to move pd_lobby skbs to. */ + if (assoc->ulpq.pd_mode) { + queue = &newsp->pd_lobby; + } else + queue = &newsk->sk_receive_queue; + + /* Walk through the pd_lobby, looking for skbs that + * need moved to the new socket. + */ + sctp_skb_for_each(skb, &oldsp->pd_lobby, tmp) { + event = sctp_skb2event(skb); + if (event->asoc == assoc) { + __skb_unlink(skb, &oldsp->pd_lobby); + __skb_queue_tail(queue, skb); + sctp_skb_set_owner_r_frag(skb, newsk); + } + } + + /* Clear up any skbs waiting for the partial + * delivery to finish. + */ + if (assoc->ulpq.pd_mode) + sctp_clear_pd(oldsk, NULL); + + } + + sctp_for_each_rx_skb(assoc, newsk, sctp_skb_set_owner_r_frag); + + /* Set the type of socket to indicate that it is peeled off from the + * original UDP-style socket or created with the accept() call on a + * TCP-style socket.. + */ + newsp->type = type; + + /* Mark the new socket "in-use" by the user so that any packets + * that may arrive on the association after we've moved it are + * queued to the backlog. This prevents a potential race between + * backlog processing on the old socket and new-packet processing + * on the new socket. + * + * The caller has just allocated newsk so we can guarantee that other + * paths won't try to lock it and then oldsk. + */ + lock_sock_nested(newsk, SINGLE_DEPTH_NESTING); + sctp_for_each_tx_datachunk(assoc, true, sctp_clear_owner_w); + sctp_assoc_migrate(assoc, newsk); + sctp_for_each_tx_datachunk(assoc, false, sctp_set_owner_w); + + /* If the association on the newsk is already closed before accept() + * is called, set RCV_SHUTDOWN flag. 
+ */ + if (sctp_state(assoc, CLOSED) && sctp_style(newsk, TCP)) { + inet_sk_set_state(newsk, SCTP_SS_CLOSED); + newsk->sk_shutdown |= RCV_SHUTDOWN; + } else { + inet_sk_set_state(newsk, SCTP_SS_ESTABLISHED); + } + + release_sock(newsk); + + return 0; +} + + +/* This proto struct describes the ULP interface for SCTP. */ +struct proto sctp_prot = { + .name = "SCTP", + .owner = THIS_MODULE, + .close = sctp_close, + .disconnect = sctp_disconnect, + .accept = sctp_accept, + .ioctl = sctp_ioctl, + .init = sctp_init_sock, + .destroy = sctp_destroy_sock, + .shutdown = sctp_shutdown, + .setsockopt = sctp_setsockopt, + .getsockopt = sctp_getsockopt, + .bpf_bypass_getsockopt = sctp_bpf_bypass_getsockopt, + .sendmsg = sctp_sendmsg, + .recvmsg = sctp_recvmsg, + .bind = sctp_bind, + .bind_add = sctp_bind_add, + .backlog_rcv = sctp_backlog_rcv, + .hash = sctp_hash, + .unhash = sctp_unhash, + .no_autobind = true, + .obj_size = sizeof(struct sctp_sock), + .useroffset = offsetof(struct sctp_sock, subscribe), + .usersize = offsetof(struct sctp_sock, initmsg) - + offsetof(struct sctp_sock, subscribe) + + sizeof_field(struct sctp_sock, initmsg), + .sysctl_mem = sysctl_sctp_mem, + .sysctl_rmem = sysctl_sctp_rmem, + .sysctl_wmem = sysctl_sctp_wmem, + .memory_pressure = &sctp_memory_pressure, + .enter_memory_pressure = sctp_enter_memory_pressure, + + .memory_allocated = &sctp_memory_allocated, + .per_cpu_fw_alloc = &sctp_memory_per_cpu_fw_alloc, + + .sockets_allocated = &sctp_sockets_allocated, +}; + +#if IS_ENABLED(CONFIG_IPV6) + +static void sctp_v6_destruct_sock(struct sock *sk) +{ + sctp_destruct_common(sk); + inet6_sock_destruct(sk); +} + +static int sctp_v6_init_sock(struct sock *sk) +{ + int ret = sctp_init_sock(sk); + + if (!ret) + sk->sk_destruct = sctp_v6_destruct_sock; + + return ret; +} + +struct proto sctpv6_prot = { + .name = "SCTPv6", + .owner = THIS_MODULE, + .close = sctp_close, + .disconnect = sctp_disconnect, + .accept = sctp_accept, + .ioctl = sctp_ioctl, + .init = sctp_v6_init_sock, + .destroy = sctp_destroy_sock, + .shutdown = sctp_shutdown, + .setsockopt = sctp_setsockopt, + .getsockopt = sctp_getsockopt, + .bpf_bypass_getsockopt = sctp_bpf_bypass_getsockopt, + .sendmsg = sctp_sendmsg, + .recvmsg = sctp_recvmsg, + .bind = sctp_bind, + .bind_add = sctp_bind_add, + .backlog_rcv = sctp_backlog_rcv, + .hash = sctp_hash, + .unhash = sctp_unhash, + .no_autobind = true, + .obj_size = sizeof(struct sctp6_sock), + .useroffset = offsetof(struct sctp6_sock, sctp.subscribe), + .usersize = offsetof(struct sctp6_sock, sctp.initmsg) - + offsetof(struct sctp6_sock, sctp.subscribe) + + sizeof_field(struct sctp6_sock, sctp.initmsg), + .sysctl_mem = sysctl_sctp_mem, + .sysctl_rmem = sysctl_sctp_rmem, + .sysctl_wmem = sysctl_sctp_wmem, + .memory_pressure = &sctp_memory_pressure, + .enter_memory_pressure = sctp_enter_memory_pressure, + + .memory_allocated = &sctp_memory_allocated, + .per_cpu_fw_alloc = &sctp_memory_per_cpu_fw_alloc, + + .sockets_allocated = &sctp_sockets_allocated, +}; +#endif /* IS_ENABLED(CONFIG_IPV6) */ diff --git a/net/sctp/stream.c b/net/sctp/stream.c new file mode 100644 index 000000000..ee6514af8 --- /dev/null +++ b/net/sctp/stream.c @@ -0,0 +1,1087 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. 
+ * + * This file is part of the SCTP kernel implementation + * + * This file contains sctp stream manipulation primitives and helpers. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers <linux-sctp@vger.kernel.org> + * + * Written or modified by: + * Xin Long <lucien.xin@gmail.com> + */ + +#include <linux/list.h> +#include <net/sctp/sctp.h> +#include <net/sctp/sm.h> +#include <net/sctp/stream_sched.h> + +static void sctp_stream_shrink_out(struct sctp_stream *stream, __u16 outcnt) +{ + struct sctp_association *asoc; + struct sctp_chunk *ch, *temp; + struct sctp_outq *outq; + + asoc = container_of(stream, struct sctp_association, stream); + outq = &asoc->outqueue; + + list_for_each_entry_safe(ch, temp, &outq->out_chunk_list, list) { + __u16 sid = sctp_chunk_stream_no(ch); + + if (sid < outcnt) + continue; + + sctp_sched_dequeue_common(outq, ch); + /* No need to call dequeue_done here because + * the chunks are not scheduled by now. + */ + + /* Mark as failed send. */ + sctp_chunk_fail(ch, (__force __u32)SCTP_ERROR_INV_STRM); + if (asoc->peer.prsctp_capable && + SCTP_PR_PRIO_ENABLED(ch->sinfo.sinfo_flags)) + asoc->sent_cnt_removable--; + + sctp_chunk_free(ch); + } +} + +static void sctp_stream_free_ext(struct sctp_stream *stream, __u16 sid) +{ + struct sctp_sched_ops *sched; + + if (!SCTP_SO(stream, sid)->ext) + return; + + sched = sctp_sched_ops_from_stream(stream); + sched->free_sid(stream, sid); + kfree(SCTP_SO(stream, sid)->ext); + SCTP_SO(stream, sid)->ext = NULL; +} + +/* Migrates chunks from stream queues to new stream queues if needed, + * but not across associations. Also, removes those chunks to streams + * higher than the new max. + */ +static void sctp_stream_outq_migrate(struct sctp_stream *stream, + struct sctp_stream *new, __u16 outcnt) +{ + int i; + + if (stream->outcnt > outcnt) + sctp_stream_shrink_out(stream, outcnt); + + if (new) { + /* Here we actually move the old ext stuff into the new + * buffer, because we want to keep it. Then + * sctp_stream_update will swap ->out pointers. + */ + for (i = 0; i < outcnt; i++) { + sctp_stream_free_ext(new, i); + SCTP_SO(new, i)->ext = SCTP_SO(stream, i)->ext; + SCTP_SO(stream, i)->ext = NULL; + } + } + + for (i = outcnt; i < stream->outcnt; i++) + sctp_stream_free_ext(stream, i); +} + +static int sctp_stream_alloc_out(struct sctp_stream *stream, __u16 outcnt, + gfp_t gfp) +{ + int ret; + + if (outcnt <= stream->outcnt) + goto out; + + ret = genradix_prealloc(&stream->out, outcnt, gfp); + if (ret) + return ret; + +out: + stream->outcnt = outcnt; + return 0; +} + +static int sctp_stream_alloc_in(struct sctp_stream *stream, __u16 incnt, + gfp_t gfp) +{ + int ret; + + if (incnt <= stream->incnt) + goto out; + + ret = genradix_prealloc(&stream->in, incnt, gfp); + if (ret) + return ret; + +out: + stream->incnt = incnt; + return 0; +} + +int sctp_stream_init(struct sctp_stream *stream, __u16 outcnt, __u16 incnt, + gfp_t gfp) +{ + struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); + int i, ret = 0; + + gfp |= __GFP_NOWARN; + + /* Initial stream->out size may be very big, so free it and alloc + * a new one with new outcnt to save memory if needed.
+ */ + if (outcnt == stream->outcnt) + goto handle_in; + + /* Filter out chunks queued on streams that won't exist anymore */ + sched->unsched_all(stream); + sctp_stream_outq_migrate(stream, NULL, outcnt); + sched->sched_all(stream); + + ret = sctp_stream_alloc_out(stream, outcnt, gfp); + if (ret) + return ret; + + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + +handle_in: + sctp_stream_interleave_init(stream); + if (!incnt) + return 0; + + return sctp_stream_alloc_in(stream, incnt, gfp); +} + +int sctp_stream_init_ext(struct sctp_stream *stream, __u16 sid) +{ + struct sctp_stream_out_ext *soute; + int ret; + + soute = kzalloc(sizeof(*soute), GFP_KERNEL); + if (!soute) + return -ENOMEM; + SCTP_SO(stream, sid)->ext = soute; + + ret = sctp_sched_init_sid(stream, sid, GFP_KERNEL); + if (ret) { + kfree(SCTP_SO(stream, sid)->ext); + SCTP_SO(stream, sid)->ext = NULL; + } + + return ret; +} + +void sctp_stream_free(struct sctp_stream *stream) +{ + struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); + int i; + + sched->unsched_all(stream); + for (i = 0; i < stream->outcnt; i++) + sctp_stream_free_ext(stream, i); + genradix_free(&stream->out); + genradix_free(&stream->in); +} + +void sctp_stream_clear(struct sctp_stream *stream) +{ + int i; + + for (i = 0; i < stream->outcnt; i++) { + SCTP_SO(stream, i)->mid = 0; + SCTP_SO(stream, i)->mid_uo = 0; + } + + for (i = 0; i < stream->incnt; i++) + SCTP_SI(stream, i)->mid = 0; +} + +void sctp_stream_update(struct sctp_stream *stream, struct sctp_stream *new) +{ + struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); + + sched->unsched_all(stream); + sctp_stream_outq_migrate(stream, new, new->outcnt); + sctp_stream_free(stream); + + stream->out = new->out; + stream->in = new->in; + stream->outcnt = new->outcnt; + stream->incnt = new->incnt; + + sched->sched_all(stream); + + new->out.tree.root = NULL; + new->in.tree.root = NULL; + new->outcnt = 0; + new->incnt = 0; +} + +static int sctp_send_reconf(struct sctp_association *asoc, + struct sctp_chunk *chunk) +{ + int retval = 0; + + retval = sctp_primitive_RECONF(asoc->base.net, asoc, chunk); + if (retval) + sctp_chunk_free(chunk); + + return retval; +} + +static bool sctp_stream_outq_is_empty(struct sctp_stream *stream, + __u16 str_nums, __be16 *str_list) +{ + struct sctp_association *asoc; + __u16 i; + + asoc = container_of(stream, struct sctp_association, stream); + if (!asoc->outqueue.out_qlen) + return true; + + if (!str_nums) + return false; + + for (i = 0; i < str_nums; i++) { + __u16 sid = ntohs(str_list[i]); + + if (SCTP_SO(stream, sid)->ext && + !list_empty(&SCTP_SO(stream, sid)->ext->outq)) + return false; + } + + return true; +} + +int sctp_send_reset_streams(struct sctp_association *asoc, + struct sctp_reset_streams *params) +{ + struct sctp_stream *stream = &asoc->stream; + __u16 i, str_nums, *str_list; + struct sctp_chunk *chunk; + int retval = -EINVAL; + __be16 *nstr_list; + bool out, in; + + if (!asoc->peer.reconf_capable || + !(asoc->strreset_enable & SCTP_ENABLE_RESET_STREAM_REQ)) { + retval = -ENOPROTOOPT; + goto out; + } + + if (asoc->strreset_outstanding) { + retval = -EINPROGRESS; + goto out; + } + + out = params->srs_flags & SCTP_STREAM_RESET_OUTGOING; + in = params->srs_flags & SCTP_STREAM_RESET_INCOMING; + if (!out && !in) + goto out; + + str_nums = params->srs_number_streams; + str_list = params->srs_stream_list; + if (str_nums) { + int param_len = 0; + + if (out) { + for (i = 0; i < str_nums; i++) + if (str_list[i] 
>= stream->outcnt) + goto out; + + param_len = str_nums * sizeof(__u16) + + sizeof(struct sctp_strreset_outreq); + } + + if (in) { + for (i = 0; i < str_nums; i++) + if (str_list[i] >= stream->incnt) + goto out; + + param_len += str_nums * sizeof(__u16) + + sizeof(struct sctp_strreset_inreq); + } + + if (param_len > SCTP_MAX_CHUNK_LEN - + sizeof(struct sctp_reconf_chunk)) + goto out; + } + + nstr_list = kcalloc(str_nums, sizeof(__be16), GFP_KERNEL); + if (!nstr_list) { + retval = -ENOMEM; + goto out; + } + + for (i = 0; i < str_nums; i++) + nstr_list[i] = htons(str_list[i]); + + if (out && !sctp_stream_outq_is_empty(stream, str_nums, nstr_list)) { + kfree(nstr_list); + retval = -EAGAIN; + goto out; + } + + chunk = sctp_make_strreset_req(asoc, str_nums, nstr_list, out, in); + + kfree(nstr_list); + + if (!chunk) { + retval = -ENOMEM; + goto out; + } + + if (out) { + if (str_nums) + for (i = 0; i < str_nums; i++) + SCTP_SO(stream, str_list[i])->state = + SCTP_STREAM_CLOSED; + else + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_CLOSED; + } + + asoc->strreset_chunk = chunk; + sctp_chunk_hold(asoc->strreset_chunk); + + retval = sctp_send_reconf(asoc, chunk); + if (retval) { + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + if (!out) + goto out; + + if (str_nums) + for (i = 0; i < str_nums; i++) + SCTP_SO(stream, str_list[i])->state = + SCTP_STREAM_OPEN; + else + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + + goto out; + } + + asoc->strreset_outstanding = out + in; + +out: + return retval; +} + +int sctp_send_reset_assoc(struct sctp_association *asoc) +{ + struct sctp_stream *stream = &asoc->stream; + struct sctp_chunk *chunk = NULL; + int retval; + __u16 i; + + if (!asoc->peer.reconf_capable || + !(asoc->strreset_enable & SCTP_ENABLE_RESET_ASSOC_REQ)) + return -ENOPROTOOPT; + + if (asoc->strreset_outstanding) + return -EINPROGRESS; + + if (!sctp_outq_is_empty(&asoc->outqueue)) + return -EAGAIN; + + chunk = sctp_make_strreset_tsnreq(asoc); + if (!chunk) + return -ENOMEM; + + /* Block further xmit of data until this request is completed */ + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_CLOSED; + + asoc->strreset_chunk = chunk; + sctp_chunk_hold(asoc->strreset_chunk); + + retval = sctp_send_reconf(asoc, chunk); + if (retval) { + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + + return retval; + } + + asoc->strreset_outstanding = 1; + + return 0; +} + +int sctp_send_add_streams(struct sctp_association *asoc, + struct sctp_add_streams *params) +{ + struct sctp_stream *stream = &asoc->stream; + struct sctp_chunk *chunk = NULL; + int retval; + __u32 outcnt, incnt; + __u16 out, in; + + if (!asoc->peer.reconf_capable || + !(asoc->strreset_enable & SCTP_ENABLE_CHANGE_ASSOC_REQ)) { + retval = -ENOPROTOOPT; + goto out; + } + + if (asoc->strreset_outstanding) { + retval = -EINPROGRESS; + goto out; + } + + out = params->sas_outstrms; + in = params->sas_instrms; + outcnt = stream->outcnt + out; + incnt = stream->incnt + in; + if (outcnt > SCTP_MAX_STREAM || incnt > SCTP_MAX_STREAM || + (!out && !in)) { + retval = -EINVAL; + goto out; + } + + if (out) { + retval = sctp_stream_alloc_out(stream, outcnt, GFP_KERNEL); + if (retval) + goto out; + } + + chunk = sctp_make_strreset_addstrm(asoc, out, in); + if (!chunk) { + retval = -ENOMEM; + goto out; + } + + 
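+ /* Stash the request chunk on the association and take an extra + * reference: the peer's RECONF response is matched against + * asoc->strreset_chunk via sctp_chunk_lookup_strreset_param(), and + * the reference is dropped again if sending the request fails. + */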
asoc->strreset_chunk = chunk; + sctp_chunk_hold(asoc->strreset_chunk); + + retval = sctp_send_reconf(asoc, chunk); + if (retval) { + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + goto out; + } + + asoc->strreset_outstanding = !!out + !!in; + +out: + return retval; +} + +static struct sctp_paramhdr *sctp_chunk_lookup_strreset_param( + struct sctp_association *asoc, __be32 resp_seq, + __be16 type) +{ + struct sctp_chunk *chunk = asoc->strreset_chunk; + struct sctp_reconf_chunk *hdr; + union sctp_params param; + + if (!chunk) + return NULL; + + hdr = (struct sctp_reconf_chunk *)chunk->chunk_hdr; + sctp_walk_params(param, hdr, params) { + /* sctp_strreset_tsnreq is actually the basic structure + * of all stream reconf params, so it's safe to use it + * to access request_seq. + */ + struct sctp_strreset_tsnreq *req = param.v; + + if ((!resp_seq || req->request_seq == resp_seq) && + (!type || type == req->param_hdr.type)) + return param.v; + } + + return NULL; +} + +static void sctp_update_strreset_result(struct sctp_association *asoc, + __u32 result) +{ + asoc->strreset_result[1] = asoc->strreset_result[0]; + asoc->strreset_result[0] = result; +} + +struct sctp_chunk *sctp_process_strreset_outreq( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + struct sctp_strreset_outreq *outreq = param.v; + struct sctp_stream *stream = &asoc->stream; + __u32 result = SCTP_STRRESET_DENIED; + __be16 *str_p = NULL; + __u32 request_seq; + __u16 i, nums; + + request_seq = ntohl(outreq->request_seq); + + if (ntohl(outreq->send_reset_at_tsn) > + sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map)) { + result = SCTP_STRRESET_IN_PROGRESS; + goto err; + } + + if (TSN_lt(asoc->strreset_inseq, request_seq) || + TSN_lt(request_seq, asoc->strreset_inseq - 2)) { + result = SCTP_STRRESET_ERR_BAD_SEQNO; + goto err; + } else if (TSN_lt(request_seq, asoc->strreset_inseq)) { + i = asoc->strreset_inseq - request_seq - 1; + result = asoc->strreset_result[i]; + goto err; + } + asoc->strreset_inseq++; + + /* Check strreset_enable after inseq inc, as sender cannot tell + * the peer doesn't enable strreset after receiving response with + * result denied, as well as to keep consistent with bsd. 
+ */ + if (!(asoc->strreset_enable & SCTP_ENABLE_RESET_STREAM_REQ)) + goto out; + + nums = (ntohs(param.p->length) - sizeof(*outreq)) / sizeof(__u16); + str_p = outreq->list_of_streams; + for (i = 0; i < nums; i++) { + if (ntohs(str_p[i]) >= stream->incnt) { + result = SCTP_STRRESET_ERR_WRONG_SSN; + goto out; + } + } + + if (asoc->strreset_chunk) { + if (!sctp_chunk_lookup_strreset_param( + asoc, outreq->response_seq, + SCTP_PARAM_RESET_IN_REQUEST)) { + /* same process with outstanding isn't 0 */ + result = SCTP_STRRESET_ERR_IN_PROGRESS; + goto out; + } + + asoc->strreset_outstanding--; + asoc->strreset_outseq++; + + if (!asoc->strreset_outstanding) { + struct sctp_transport *t; + + t = asoc->strreset_chunk->transport; + if (del_timer(&t->reconf_timer)) + sctp_transport_put(t); + + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + } + } + + if (nums) + for (i = 0; i < nums; i++) + SCTP_SI(stream, ntohs(str_p[i]))->mid = 0; + else + for (i = 0; i < stream->incnt; i++) + SCTP_SI(stream, i)->mid = 0; + + result = SCTP_STRRESET_PERFORMED; + + *evp = sctp_ulpevent_make_stream_reset_event(asoc, + SCTP_STREAM_RESET_INCOMING_SSN, nums, str_p, GFP_ATOMIC); + +out: + sctp_update_strreset_result(asoc, result); +err: + return sctp_make_strreset_resp(asoc, result, request_seq); +} + +struct sctp_chunk *sctp_process_strreset_inreq( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + struct sctp_strreset_inreq *inreq = param.v; + struct sctp_stream *stream = &asoc->stream; + __u32 result = SCTP_STRRESET_DENIED; + struct sctp_chunk *chunk = NULL; + __u32 request_seq; + __u16 i, nums; + __be16 *str_p; + + request_seq = ntohl(inreq->request_seq); + if (TSN_lt(asoc->strreset_inseq, request_seq) || + TSN_lt(request_seq, asoc->strreset_inseq - 2)) { + result = SCTP_STRRESET_ERR_BAD_SEQNO; + goto err; + } else if (TSN_lt(request_seq, asoc->strreset_inseq)) { + i = asoc->strreset_inseq - request_seq - 1; + result = asoc->strreset_result[i]; + if (result == SCTP_STRRESET_PERFORMED) + return NULL; + goto err; + } + asoc->strreset_inseq++; + + if (!(asoc->strreset_enable & SCTP_ENABLE_RESET_STREAM_REQ)) + goto out; + + if (asoc->strreset_outstanding) { + result = SCTP_STRRESET_ERR_IN_PROGRESS; + goto out; + } + + nums = (ntohs(param.p->length) - sizeof(*inreq)) / sizeof(__u16); + str_p = inreq->list_of_streams; + for (i = 0; i < nums; i++) { + if (ntohs(str_p[i]) >= stream->outcnt) { + result = SCTP_STRRESET_ERR_WRONG_SSN; + goto out; + } + } + + if (!sctp_stream_outq_is_empty(stream, nums, str_p)) { + result = SCTP_STRRESET_IN_PROGRESS; + asoc->strreset_inseq--; + goto err; + } + + chunk = sctp_make_strreset_req(asoc, nums, str_p, 1, 0); + if (!chunk) + goto out; + + if (nums) + for (i = 0; i < nums; i++) + SCTP_SO(stream, ntohs(str_p[i]))->state = + SCTP_STREAM_CLOSED; + else + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_CLOSED; + + asoc->strreset_chunk = chunk; + asoc->strreset_outstanding = 1; + sctp_chunk_hold(asoc->strreset_chunk); + + result = SCTP_STRRESET_PERFORMED; + +out: + sctp_update_strreset_result(asoc, result); +err: + if (!chunk) + chunk = sctp_make_strreset_resp(asoc, result, request_seq); + + return chunk; +} + +struct sctp_chunk *sctp_process_strreset_tsnreq( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + __u32 init_tsn = 0, next_tsn = 0, max_tsn_seen; + struct sctp_strreset_tsnreq *tsnreq = param.v; + struct sctp_stream *stream = &asoc->stream; + 
__u32 result = SCTP_STRRESET_DENIED; + __u32 request_seq; + __u16 i; + + request_seq = ntohl(tsnreq->request_seq); + if (TSN_lt(asoc->strreset_inseq, request_seq) || + TSN_lt(request_seq, asoc->strreset_inseq - 2)) { + result = SCTP_STRRESET_ERR_BAD_SEQNO; + goto err; + } else if (TSN_lt(request_seq, asoc->strreset_inseq)) { + i = asoc->strreset_inseq - request_seq - 1; + result = asoc->strreset_result[i]; + if (result == SCTP_STRRESET_PERFORMED) { + next_tsn = asoc->ctsn_ack_point + 1; + init_tsn = + sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map) + 1; + } + goto err; + } + + if (!sctp_outq_is_empty(&asoc->outqueue)) { + result = SCTP_STRRESET_IN_PROGRESS; + goto err; + } + + asoc->strreset_inseq++; + + if (!(asoc->strreset_enable & SCTP_ENABLE_RESET_ASSOC_REQ)) + goto out; + + if (asoc->strreset_outstanding) { + result = SCTP_STRRESET_ERR_IN_PROGRESS; + goto out; + } + + /* G4: The same processing as though a FWD-TSN chunk (as defined in + * [RFC3758]) with all streams affected and a new cumulative TSN + * ACK of the Receiver's Next TSN minus 1 were received MUST be + * performed. + */ + max_tsn_seen = sctp_tsnmap_get_max_tsn_seen(&asoc->peer.tsn_map); + asoc->stream.si->report_ftsn(&asoc->ulpq, max_tsn_seen); + + /* G1: Compute an appropriate value for the Receiver's Next TSN -- the + * TSN that the peer should use to send the next DATA chunk. The + * value SHOULD be the smallest TSN not acknowledged by the + * receiver of the request plus 2^31. + */ + init_tsn = sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map) + (1 << 31); + sctp_tsnmap_init(&asoc->peer.tsn_map, SCTP_TSN_MAP_INITIAL, + init_tsn, GFP_ATOMIC); + + /* G3: The same processing as though a SACK chunk with no gap report + * and a cumulative TSN ACK of the Sender's Next TSN minus 1 were + * received MUST be performed. + */ + sctp_outq_free(&asoc->outqueue); + + /* G2: Compute an appropriate value for the local endpoint's next TSN, + * i.e., the next TSN assigned by the receiver of the SSN/TSN reset + * chunk. The value SHOULD be the highest TSN sent by the receiver + * of the request plus 1. + */ + next_tsn = asoc->next_tsn; + asoc->ctsn_ack_point = next_tsn - 1; + asoc->adv_peer_ack_point = asoc->ctsn_ack_point; + + /* G5: The next expected and outgoing SSNs MUST be reset to 0 for all + * incoming and outgoing streams. 
+ */ + for (i = 0; i < stream->outcnt; i++) { + SCTP_SO(stream, i)->mid = 0; + SCTP_SO(stream, i)->mid_uo = 0; + } + for (i = 0; i < stream->incnt; i++) + SCTP_SI(stream, i)->mid = 0; + + result = SCTP_STRRESET_PERFORMED; + + *evp = sctp_ulpevent_make_assoc_reset_event(asoc, 0, init_tsn, + next_tsn, GFP_ATOMIC); + +out: + sctp_update_strreset_result(asoc, result); +err: + return sctp_make_strreset_tsnresp(asoc, result, request_seq, + next_tsn, init_tsn); +} + +struct sctp_chunk *sctp_process_strreset_addstrm_out( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + struct sctp_strreset_addstrm *addstrm = param.v; + struct sctp_stream *stream = &asoc->stream; + __u32 result = SCTP_STRRESET_DENIED; + __u32 request_seq, incnt; + __u16 in, i; + + request_seq = ntohl(addstrm->request_seq); + if (TSN_lt(asoc->strreset_inseq, request_seq) || + TSN_lt(request_seq, asoc->strreset_inseq - 2)) { + result = SCTP_STRRESET_ERR_BAD_SEQNO; + goto err; + } else if (TSN_lt(request_seq, asoc->strreset_inseq)) { + i = asoc->strreset_inseq - request_seq - 1; + result = asoc->strreset_result[i]; + goto err; + } + asoc->strreset_inseq++; + + if (!(asoc->strreset_enable & SCTP_ENABLE_CHANGE_ASSOC_REQ)) + goto out; + + in = ntohs(addstrm->number_of_streams); + incnt = stream->incnt + in; + if (!in || incnt > SCTP_MAX_STREAM) + goto out; + + if (sctp_stream_alloc_in(stream, incnt, GFP_ATOMIC)) + goto out; + + if (asoc->strreset_chunk) { + if (!sctp_chunk_lookup_strreset_param( + asoc, 0, SCTP_PARAM_RESET_ADD_IN_STREAMS)) { + /* same process with outstanding isn't 0 */ + result = SCTP_STRRESET_ERR_IN_PROGRESS; + goto out; + } + + asoc->strreset_outstanding--; + asoc->strreset_outseq++; + + if (!asoc->strreset_outstanding) { + struct sctp_transport *t; + + t = asoc->strreset_chunk->transport; + if (del_timer(&t->reconf_timer)) + sctp_transport_put(t); + + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + } + } + + stream->incnt = incnt; + + result = SCTP_STRRESET_PERFORMED; + + *evp = sctp_ulpevent_make_stream_change_event(asoc, + 0, ntohs(addstrm->number_of_streams), 0, GFP_ATOMIC); + +out: + sctp_update_strreset_result(asoc, result); +err: + return sctp_make_strreset_resp(asoc, result, request_seq); +} + +struct sctp_chunk *sctp_process_strreset_addstrm_in( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + struct sctp_strreset_addstrm *addstrm = param.v; + struct sctp_stream *stream = &asoc->stream; + __u32 result = SCTP_STRRESET_DENIED; + struct sctp_chunk *chunk = NULL; + __u32 request_seq, outcnt; + __u16 out, i; + int ret; + + request_seq = ntohl(addstrm->request_seq); + if (TSN_lt(asoc->strreset_inseq, request_seq) || + TSN_lt(request_seq, asoc->strreset_inseq - 2)) { + result = SCTP_STRRESET_ERR_BAD_SEQNO; + goto err; + } else if (TSN_lt(request_seq, asoc->strreset_inseq)) { + i = asoc->strreset_inseq - request_seq - 1; + result = asoc->strreset_result[i]; + if (result == SCTP_STRRESET_PERFORMED) + return NULL; + goto err; + } + asoc->strreset_inseq++; + + if (!(asoc->strreset_enable & SCTP_ENABLE_CHANGE_ASSOC_REQ)) + goto out; + + if (asoc->strreset_outstanding) { + result = SCTP_STRRESET_ERR_IN_PROGRESS; + goto out; + } + + out = ntohs(addstrm->number_of_streams); + outcnt = stream->outcnt + out; + if (!out || outcnt > SCTP_MAX_STREAM) + goto out; + + ret = sctp_stream_alloc_out(stream, outcnt, GFP_ATOMIC); + if (ret) + goto out; + + chunk = sctp_make_strreset_addstrm(asoc, out, 0); + if (!chunk) + 
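Both add-stream handlers reject a request that would push the stream count past SCTP_MAX_STREAM or that arrives while another reconf is outstanding. On the requesting side this machinery is normally driven from user space through the RFC 6525 reconfiguration socket API; the sketch below is a hedged illustration that assumes the SCTP_ENABLE_STREAM_RESET and SCTP_ADD_STREAMS options and the SCTP_ENABLE_CHANGE_ASSOC_REQ flag as named in the mainline UAPI headers (adjust if your toolchain differs):

/* Hedged user-space sketch: asking the peer for more outgoing streams via
 * the RFC 6525 reconfiguration API.  Assumes the SCTP_ENABLE_STREAM_RESET
 * and SCTP_ADD_STREAMS socket options from the Linux UAPI headers; error
 * handling is reduced to perror() for brevity.
 */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/sctp.h>

int request_more_streams(int fd, sctp_assoc_t assoc_id, unsigned int extra_out)
{
	struct sctp_assoc_value av;
	struct sctp_add_streams sas;

	/* The request is only honoured if reconf is enabled locally
	 * (mirrors the strreset_enable checks in the kernel handlers).
	 */
	memset(&av, 0, sizeof(av));
	av.assoc_id = assoc_id;
	av.assoc_value = SCTP_ENABLE_CHANGE_ASSOC_REQ;
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_ENABLE_STREAM_RESET,
		       &av, sizeof(av)) < 0) {
		perror("SCTP_ENABLE_STREAM_RESET");
		return -1;
	}

	memset(&sas, 0, sizeof(sas));
	sas.sas_assoc_id = assoc_id;
	sas.sas_outstrms = extra_out;	/* ask for this many extra outgoing streams */
	sas.sas_instrms = 0;
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_ADD_STREAMS,
		       &sas, sizeof(sas)) < 0) {
		perror("SCTP_ADD_STREAMS");
		return -1;
	}
	return 0;
}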
goto out; + + asoc->strreset_chunk = chunk; + asoc->strreset_outstanding = 1; + sctp_chunk_hold(asoc->strreset_chunk); + + stream->outcnt = outcnt; + + result = SCTP_STRRESET_PERFORMED; + +out: + sctp_update_strreset_result(asoc, result); +err: + if (!chunk) + chunk = sctp_make_strreset_resp(asoc, result, request_seq); + + return chunk; +} + +struct sctp_chunk *sctp_process_strreset_resp( + struct sctp_association *asoc, + union sctp_params param, + struct sctp_ulpevent **evp) +{ + struct sctp_stream *stream = &asoc->stream; + struct sctp_strreset_resp *resp = param.v; + struct sctp_transport *t; + __u16 i, nums, flags = 0; + struct sctp_paramhdr *req; + __u32 result; + + req = sctp_chunk_lookup_strreset_param(asoc, resp->response_seq, 0); + if (!req) + return NULL; + + result = ntohl(resp->result); + if (result != SCTP_STRRESET_PERFORMED) { + /* if in progress, do nothing but retransmit */ + if (result == SCTP_STRRESET_IN_PROGRESS) + return NULL; + else if (result == SCTP_STRRESET_DENIED) + flags = SCTP_STREAM_RESET_DENIED; + else + flags = SCTP_STREAM_RESET_FAILED; + } + + if (req->type == SCTP_PARAM_RESET_OUT_REQUEST) { + struct sctp_strreset_outreq *outreq; + __be16 *str_p; + + outreq = (struct sctp_strreset_outreq *)req; + str_p = outreq->list_of_streams; + nums = (ntohs(outreq->param_hdr.length) - sizeof(*outreq)) / + sizeof(__u16); + + if (result == SCTP_STRRESET_PERFORMED) { + struct sctp_stream_out *sout; + if (nums) { + for (i = 0; i < nums; i++) { + sout = SCTP_SO(stream, ntohs(str_p[i])); + sout->mid = 0; + sout->mid_uo = 0; + } + } else { + for (i = 0; i < stream->outcnt; i++) { + sout = SCTP_SO(stream, i); + sout->mid = 0; + sout->mid_uo = 0; + } + } + } + + flags |= SCTP_STREAM_RESET_OUTGOING_SSN; + + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + + *evp = sctp_ulpevent_make_stream_reset_event(asoc, flags, + nums, str_p, GFP_ATOMIC); + } else if (req->type == SCTP_PARAM_RESET_IN_REQUEST) { + struct sctp_strreset_inreq *inreq; + __be16 *str_p; + + /* if the result is performed, it's impossible for inreq */ + if (result == SCTP_STRRESET_PERFORMED) + return NULL; + + inreq = (struct sctp_strreset_inreq *)req; + str_p = inreq->list_of_streams; + nums = (ntohs(inreq->param_hdr.length) - sizeof(*inreq)) / + sizeof(__u16); + + flags |= SCTP_STREAM_RESET_INCOMING_SSN; + + *evp = sctp_ulpevent_make_stream_reset_event(asoc, flags, + nums, str_p, GFP_ATOMIC); + } else if (req->type == SCTP_PARAM_RESET_TSN_REQUEST) { + struct sctp_strreset_resptsn *resptsn; + __u32 stsn, rtsn; + + /* check for resptsn, as sctp_verify_reconf didn't do it*/ + if (ntohs(param.p->length) != sizeof(*resptsn)) + return NULL; + + resptsn = (struct sctp_strreset_resptsn *)resp; + stsn = ntohl(resptsn->senders_next_tsn); + rtsn = ntohl(resptsn->receivers_next_tsn); + + if (result == SCTP_STRRESET_PERFORMED) { + __u32 mtsn = sctp_tsnmap_get_max_tsn_seen( + &asoc->peer.tsn_map); + LIST_HEAD(temp); + + asoc->stream.si->report_ftsn(&asoc->ulpq, mtsn); + + sctp_tsnmap_init(&asoc->peer.tsn_map, + SCTP_TSN_MAP_INITIAL, + stsn, GFP_ATOMIC); + + /* Clean up sacked and abandoned queues only. As the + * out_chunk_list may not be empty, splice it to temp, + * then get it back after sctp_outq_free is done. 
+ */ + list_splice_init(&asoc->outqueue.out_chunk_list, &temp); + sctp_outq_free(&asoc->outqueue); + list_splice_init(&temp, &asoc->outqueue.out_chunk_list); + + asoc->next_tsn = rtsn; + asoc->ctsn_ack_point = asoc->next_tsn - 1; + asoc->adv_peer_ack_point = asoc->ctsn_ack_point; + + for (i = 0; i < stream->outcnt; i++) { + SCTP_SO(stream, i)->mid = 0; + SCTP_SO(stream, i)->mid_uo = 0; + } + for (i = 0; i < stream->incnt; i++) + SCTP_SI(stream, i)->mid = 0; + } + + for (i = 0; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + + *evp = sctp_ulpevent_make_assoc_reset_event(asoc, flags, + stsn, rtsn, GFP_ATOMIC); + } else if (req->type == SCTP_PARAM_RESET_ADD_OUT_STREAMS) { + struct sctp_strreset_addstrm *addstrm; + __u16 number; + + addstrm = (struct sctp_strreset_addstrm *)req; + nums = ntohs(addstrm->number_of_streams); + number = stream->outcnt - nums; + + if (result == SCTP_STRRESET_PERFORMED) { + for (i = number; i < stream->outcnt; i++) + SCTP_SO(stream, i)->state = SCTP_STREAM_OPEN; + } else { + sctp_stream_shrink_out(stream, number); + stream->outcnt = number; + } + + *evp = sctp_ulpevent_make_stream_change_event(asoc, flags, + 0, nums, GFP_ATOMIC); + } else if (req->type == SCTP_PARAM_RESET_ADD_IN_STREAMS) { + struct sctp_strreset_addstrm *addstrm; + + /* if the result is performed, it's impossible for addstrm in + * request. + */ + if (result == SCTP_STRRESET_PERFORMED) + return NULL; + + addstrm = (struct sctp_strreset_addstrm *)req; + nums = ntohs(addstrm->number_of_streams); + + *evp = sctp_ulpevent_make_stream_change_event(asoc, flags, + nums, 0, GFP_ATOMIC); + } + + asoc->strreset_outstanding--; + asoc->strreset_outseq++; + + /* remove everything for this reconf request */ + if (!asoc->strreset_outstanding) { + t = asoc->strreset_chunk->transport; + if (del_timer(&t->reconf_timer)) + sctp_transport_put(t); + + sctp_chunk_put(asoc->strreset_chunk); + asoc->strreset_chunk = NULL; + } + + return NULL; +} diff --git a/net/sctp/stream_interleave.c b/net/sctp/stream_interleave.c new file mode 100644 index 000000000..4fbd6d052 --- /dev/null +++ b/net/sctp/stream_interleave.c @@ -0,0 +1,1359 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright Red Hat Inc. 2017 + * + * This file is part of the SCTP kernel implementation + * + * These functions implement sctp stream message interleaving, mostly + * including I-DATA and I-FORWARD-TSN chunks process. 
+ * + * Please send any bug reports or fixes you make to the + * email addresched(es): + * lksctp developers + * + * Written or modified by: + * Xin Long + */ + +#include +#include +#include +#include +#include + +static struct sctp_chunk *sctp_make_idatafrag_empty( + const struct sctp_association *asoc, + const struct sctp_sndrcvinfo *sinfo, + int len, __u8 flags, gfp_t gfp) +{ + struct sctp_chunk *retval; + struct sctp_idatahdr dp; + + memset(&dp, 0, sizeof(dp)); + dp.stream = htons(sinfo->sinfo_stream); + + if (sinfo->sinfo_flags & SCTP_UNORDERED) + flags |= SCTP_DATA_UNORDERED; + + retval = sctp_make_idata(asoc, flags, sizeof(dp) + len, gfp); + if (!retval) + return NULL; + + retval->subh.idata_hdr = sctp_addto_chunk(retval, sizeof(dp), &dp); + memcpy(&retval->sinfo, sinfo, sizeof(struct sctp_sndrcvinfo)); + + return retval; +} + +static void sctp_chunk_assign_mid(struct sctp_chunk *chunk) +{ + struct sctp_stream *stream; + struct sctp_chunk *lchunk; + __u32 cfsn = 0; + __u16 sid; + + if (chunk->has_mid) + return; + + sid = sctp_chunk_stream_no(chunk); + stream = &chunk->asoc->stream; + + list_for_each_entry(lchunk, &chunk->msg->chunks, frag_list) { + struct sctp_idatahdr *hdr; + __u32 mid; + + lchunk->has_mid = 1; + + hdr = lchunk->subh.idata_hdr; + + if (lchunk->chunk_hdr->flags & SCTP_DATA_FIRST_FRAG) + hdr->ppid = lchunk->sinfo.sinfo_ppid; + else + hdr->fsn = htonl(cfsn++); + + if (lchunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) { + mid = lchunk->chunk_hdr->flags & SCTP_DATA_LAST_FRAG ? + sctp_mid_uo_next(stream, out, sid) : + sctp_mid_uo_peek(stream, out, sid); + } else { + mid = lchunk->chunk_hdr->flags & SCTP_DATA_LAST_FRAG ? + sctp_mid_next(stream, out, sid) : + sctp_mid_peek(stream, out, sid); + } + hdr->mid = htonl(mid); + } +} + +static bool sctp_validate_data(struct sctp_chunk *chunk) +{ + struct sctp_stream *stream; + __u16 sid, ssn; + + if (chunk->chunk_hdr->type != SCTP_CID_DATA) + return false; + + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) + return true; + + stream = &chunk->asoc->stream; + sid = sctp_chunk_stream_no(chunk); + ssn = ntohs(chunk->subh.data_hdr->ssn); + + return !SSN_lt(ssn, sctp_ssn_peek(stream, in, sid)); +} + +static bool sctp_validate_idata(struct sctp_chunk *chunk) +{ + struct sctp_stream *stream; + __u32 mid; + __u16 sid; + + if (chunk->chunk_hdr->type != SCTP_CID_I_DATA) + return false; + + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) + return true; + + stream = &chunk->asoc->stream; + sid = sctp_chunk_stream_no(chunk); + mid = ntohl(chunk->subh.idata_hdr->mid); + + return !MID_lt(mid, sctp_mid_peek(stream, in, sid)); +} + +static void sctp_intl_store_reasm(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *cevent; + struct sk_buff *pos, *loc; + + pos = skb_peek_tail(&ulpq->reasm); + if (!pos) { + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + return; + } + + cevent = sctp_skb2event(pos); + + if (event->stream == cevent->stream && + event->mid == cevent->mid && + (cevent->msg_flags & SCTP_DATA_FIRST_FRAG || + (!(event->msg_flags & SCTP_DATA_FIRST_FRAG) && + event->fsn > cevent->fsn))) { + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + return; + } + + if ((event->stream == cevent->stream && + MID_lt(cevent->mid, event->mid)) || + event->stream > cevent->stream) { + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + return; + } + + loc = NULL; + skb_queue_walk(&ulpq->reasm, pos) { + cevent = sctp_skb2event(pos); + + if (event->stream < cevent->stream || + (event->stream == 
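sctp_chunk_assign_mid() stamps every fragment of one user message with the same MID, numbers the non-first fragments with a running FSN (the first fragment reuses that field for the PPID), and only advances the stream's MID counter when the last fragment is assigned; unordered traffic uses the separate mid_uo counter. The following stand-alone sketch reproduces that numbering for a four-fragment ordered message, using hypothetical structures rather than the kernel's:

/* Stand-alone illustration (hypothetical types, not the kernel structs) of
 * how I-DATA fragments of one message are numbered: same MID for all,
 * FSN counts the non-first fragments, and the per-stream MID counter is
 * advanced only when the last fragment is stamped.
 */
#include <stdio.h>

struct frag { int first, last; unsigned mid, fsn; };

int main(void)
{
	struct frag msg[4] = {
		{ .first = 1 }, { 0 }, { 0 }, { .last = 1 },
	};
	unsigned stream_mid = 7;	/* next MID for this (ordered) stream */
	unsigned cfsn = 0;
	int i;

	for (i = 0; i < 4; i++) {
		/* peek for all but the last fragment, "next" on the last one */
		msg[i].mid = msg[i].last ? stream_mid++ : stream_mid;
		if (!msg[i].first)
			msg[i].fsn = cfsn++;	/* first fragment carries PPID instead */
	}

	for (i = 0; i < 4; i++)
		printf("frag %d: mid=%u fsn=%u%s\n", i, msg[i].mid, msg[i].fsn,
		       msg[i].first ? " (FSN field reused for PPID)" : "");
	printf("next MID for the stream: %u\n", stream_mid);
	return 0;
}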
cevent->stream && + MID_lt(event->mid, cevent->mid))) { + loc = pos; + break; + } + if (event->stream == cevent->stream && + event->mid == cevent->mid && + !(cevent->msg_flags & SCTP_DATA_FIRST_FRAG) && + (event->msg_flags & SCTP_DATA_FIRST_FRAG || + event->fsn < cevent->fsn)) { + loc = pos; + break; + } + } + + if (!loc) + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + else + __skb_queue_before(&ulpq->reasm, loc, sctp_event2skb(event)); +} + +static struct sctp_ulpevent *sctp_intl_retrieve_partial( + struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff *first_frag = NULL; + struct sk_buff *last_frag = NULL; + struct sctp_ulpevent *retval; + struct sctp_stream_in *sin; + struct sk_buff *pos; + __u32 next_fsn = 0; + int is_last = 0; + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + + skb_queue_walk(&ulpq->reasm, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + if (cevent->stream < event->stream) + continue; + + if (cevent->stream > event->stream || + cevent->mid != sin->mid) + break; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + goto out; + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) { + if (cevent->fsn == sin->fsn) { + first_frag = pos; + last_frag = pos; + next_fsn = cevent->fsn + 1; + } + } else if (cevent->fsn == next_fsn) { + last_frag = pos; + next_fsn++; + } else { + goto out; + } + break; + case SCTP_DATA_LAST_FRAG: + if (!first_frag) { + if (cevent->fsn == sin->fsn) { + first_frag = pos; + last_frag = pos; + next_fsn = 0; + is_last = 1; + } + } else if (cevent->fsn == next_fsn) { + last_frag = pos; + next_fsn = 0; + is_last = 1; + } + goto out; + default: + goto out; + } + } + +out: + if (!first_frag) + return NULL; + + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, &ulpq->reasm, + first_frag, last_frag); + if (retval) { + sin->fsn = next_fsn; + if (is_last) { + retval->msg_flags |= MSG_EOR; + sin->pd_mode = 0; + } + } + + return retval; +} + +static struct sctp_ulpevent *sctp_intl_retrieve_reassembled( + struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_association *asoc = ulpq->asoc; + struct sk_buff *pos, *first_frag = NULL; + struct sctp_ulpevent *retval = NULL; + struct sk_buff *pd_first = NULL; + struct sk_buff *pd_last = NULL; + struct sctp_stream_in *sin; + __u32 next_fsn = 0; + __u32 pd_point = 0; + __u32 pd_len = 0; + __u32 mid = 0; + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + + skb_queue_walk(&ulpq->reasm, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + if (cevent->stream < event->stream) + continue; + if (cevent->stream > event->stream) + break; + + if (MID_lt(cevent->mid, event->mid)) + continue; + if (MID_lt(event->mid, cevent->mid)) + break; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (cevent->mid == sin->mid) { + pd_first = pos; + pd_last = pos; + pd_len = pos->len; + } + + first_frag = pos; + next_fsn = 0; + mid = cevent->mid; + break; + + case SCTP_DATA_MIDDLE_FRAG: + if (first_frag && cevent->mid == mid && + cevent->fsn == next_fsn) { + next_fsn++; + if (pd_first) { + pd_last = pos; + pd_len += pos->len; + } + } else { + first_frag = NULL; + } + break; + + case SCTP_DATA_LAST_FRAG: + if (first_frag && cevent->mid == mid && + cevent->fsn == next_fsn) + goto found; + else + first_frag = NULL; + break; + } + } + + if (!pd_first) + goto out; + + pd_point = sctp_sk(asoc->base.sk)->pd_point; + if (pd_point && pd_point <= pd_len) { + retval = 
sctp_make_reassembled_event(asoc->base.net, + &ulpq->reasm, + pd_first, pd_last); + if (retval) { + sin->fsn = next_fsn; + sin->pd_mode = 1; + } + } + goto out; + +found: + retval = sctp_make_reassembled_event(asoc->base.net, &ulpq->reasm, + first_frag, pos); + if (retval) + retval->msg_flags |= MSG_EOR; + +out: + return retval; +} + +static struct sctp_ulpevent *sctp_intl_reasm(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *retval = NULL; + struct sctp_stream_in *sin; + + if (SCTP_DATA_NOT_FRAG == (event->msg_flags & SCTP_DATA_FRAG_MASK)) { + event->msg_flags |= MSG_EOR; + return event; + } + + sctp_intl_store_reasm(ulpq, event); + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + if (sin->pd_mode && event->mid == sin->mid && + event->fsn == sin->fsn) + retval = sctp_intl_retrieve_partial(ulpq, event); + + if (!retval) + retval = sctp_intl_retrieve_reassembled(ulpq, event); + + return retval; +} + +static void sctp_intl_store_ordered(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *cevent; + struct sk_buff *pos, *loc; + + pos = skb_peek_tail(&ulpq->lobby); + if (!pos) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + cevent = (struct sctp_ulpevent *)pos->cb; + if (event->stream == cevent->stream && + MID_lt(cevent->mid, event->mid)) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + if (event->stream > cevent->stream) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + loc = NULL; + skb_queue_walk(&ulpq->lobby, pos) { + cevent = (struct sctp_ulpevent *)pos->cb; + + if (cevent->stream > event->stream) { + loc = pos; + break; + } + if (cevent->stream == event->stream && + MID_lt(event->mid, cevent->mid)) { + loc = pos; + break; + } + } + + if (!loc) + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + else + __skb_queue_before(&ulpq->lobby, loc, sctp_event2skb(event)); +} + +static void sctp_intl_retrieve_ordered(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff_head *event_list; + struct sctp_stream *stream; + struct sk_buff *pos, *tmp; + __u16 sid = event->stream; + + stream = &ulpq->asoc->stream; + event_list = (struct sk_buff_head *)sctp_event2skb(event)->prev; + + sctp_skb_for_each(pos, &ulpq->lobby, tmp) { + struct sctp_ulpevent *cevent = (struct sctp_ulpevent *)pos->cb; + + if (cevent->stream > sid) + break; + + if (cevent->stream < sid) + continue; + + if (cevent->mid != sctp_mid_peek(stream, in, sid)) + break; + + sctp_mid_next(stream, in, sid); + + __skb_unlink(pos, &ulpq->lobby); + + __skb_queue_tail(event_list, pos); + } +} + +static struct sctp_ulpevent *sctp_intl_order(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_stream *stream; + __u16 sid; + + stream = &ulpq->asoc->stream; + sid = event->stream; + + if (event->mid != sctp_mid_peek(stream, in, sid)) { + sctp_intl_store_ordered(ulpq, event); + return NULL; + } + + sctp_mid_next(stream, in, sid); + + sctp_intl_retrieve_ordered(ulpq, event); + + return event; +} + +static int sctp_enqueue_event(struct sctp_ulpq *ulpq, + struct sk_buff_head *skb_list) +{ + struct sock *sk = ulpq->asoc->base.sk; + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_ulpevent *event; + struct sk_buff *skb; + + skb = __skb_peek(skb_list); + event = sctp_skb2event(skb); + + if (sk->sk_shutdown & RCV_SHUTDOWN && + (sk->sk_shutdown & SEND_SHUTDOWN || + !sctp_ulpevent_is_notification(event))) + goto out_free; + + if 
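The pd_point test above starts partial delivery of a still-incomplete message once its queued fragments reach the socket's partial delivery point, so the receiver sees data without MSG_EOR before the last fragment arrives. That threshold is configured from user space; a hedged sketch, assuming the RFC 6458 SCTP_PARTIAL_DELIVERY_POINT option is available in the headers:

/* Hedged user-space sketch: lower the partial delivery point so large,
 * still-incomplete messages start arriving early (reported without MSG_EOR).
 * Assumes SCTP_PARTIAL_DELIVERY_POINT from the RFC 6458 socket API.
 */
#include <stdio.h>
#include <stdint.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/sctp.h>

int set_pd_point(int fd, uint32_t bytes)
{
	/* 0 leaves threshold-based partial delivery off; the kernel check
	 * above is "pd_point && pd_point <= pd_len".
	 */
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_PARTIAL_DELIVERY_POINT,
		       &bytes, sizeof(bytes)) < 0) {
		perror("SCTP_PARTIAL_DELIVERY_POINT");
		return -1;
	}
	return 0;
}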
(!sctp_ulpevent_is_notification(event)) { + sk_mark_napi_id(sk, skb); + sk_incoming_cpu_update(sk); + } + + if (!sctp_ulpevent_is_enabled(event, ulpq->asoc->subscribe)) + goto out_free; + + if (skb_list) + skb_queue_splice_tail_init(skb_list, + &sk->sk_receive_queue); + else + __skb_queue_tail(&sk->sk_receive_queue, skb); + + if (!sp->data_ready_signalled) { + sp->data_ready_signalled = 1; + sk->sk_data_ready(sk); + } + + return 1; + +out_free: + if (skb_list) + sctp_queue_purge_ulpevents(skb_list); + else + sctp_ulpevent_free(event); + + return 0; +} + +static void sctp_intl_store_reasm_uo(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *cevent; + struct sk_buff *pos; + + pos = skb_peek_tail(&ulpq->reasm_uo); + if (!pos) { + __skb_queue_tail(&ulpq->reasm_uo, sctp_event2skb(event)); + return; + } + + cevent = sctp_skb2event(pos); + + if (event->stream == cevent->stream && + event->mid == cevent->mid && + (cevent->msg_flags & SCTP_DATA_FIRST_FRAG || + (!(event->msg_flags & SCTP_DATA_FIRST_FRAG) && + event->fsn > cevent->fsn))) { + __skb_queue_tail(&ulpq->reasm_uo, sctp_event2skb(event)); + return; + } + + if ((event->stream == cevent->stream && + MID_lt(cevent->mid, event->mid)) || + event->stream > cevent->stream) { + __skb_queue_tail(&ulpq->reasm_uo, sctp_event2skb(event)); + return; + } + + skb_queue_walk(&ulpq->reasm_uo, pos) { + cevent = sctp_skb2event(pos); + + if (event->stream < cevent->stream || + (event->stream == cevent->stream && + MID_lt(event->mid, cevent->mid))) + break; + + if (event->stream == cevent->stream && + event->mid == cevent->mid && + !(cevent->msg_flags & SCTP_DATA_FIRST_FRAG) && + (event->msg_flags & SCTP_DATA_FIRST_FRAG || + event->fsn < cevent->fsn)) + break; + } + + __skb_queue_before(&ulpq->reasm_uo, pos, sctp_event2skb(event)); +} + +static struct sctp_ulpevent *sctp_intl_retrieve_partial_uo( + struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff *first_frag = NULL; + struct sk_buff *last_frag = NULL; + struct sctp_ulpevent *retval; + struct sctp_stream_in *sin; + struct sk_buff *pos; + __u32 next_fsn = 0; + int is_last = 0; + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + + skb_queue_walk(&ulpq->reasm_uo, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + if (cevent->stream < event->stream) + continue; + if (cevent->stream > event->stream) + break; + + if (MID_lt(cevent->mid, sin->mid_uo)) + continue; + if (MID_lt(sin->mid_uo, cevent->mid)) + break; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + goto out; + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) { + if (cevent->fsn == sin->fsn_uo) { + first_frag = pos; + last_frag = pos; + next_fsn = cevent->fsn + 1; + } + } else if (cevent->fsn == next_fsn) { + last_frag = pos; + next_fsn++; + } else { + goto out; + } + break; + case SCTP_DATA_LAST_FRAG: + if (!first_frag) { + if (cevent->fsn == sin->fsn_uo) { + first_frag = pos; + last_frag = pos; + next_fsn = 0; + is_last = 1; + } + } else if (cevent->fsn == next_fsn) { + last_frag = pos; + next_fsn = 0; + is_last = 1; + } + goto out; + default: + goto out; + } + } + +out: + if (!first_frag) + return NULL; + + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, + &ulpq->reasm_uo, first_frag, + last_frag); + if (retval) { + sin->fsn_uo = next_fsn; + if (is_last) { + retval->msg_flags |= MSG_EOR; + sin->pd_mode_uo = 0; + } + } + + return retval; +} + +static struct sctp_ulpevent *sctp_intl_retrieve_reassembled_uo( + struct sctp_ulpq 
*ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_association *asoc = ulpq->asoc; + struct sk_buff *pos, *first_frag = NULL; + struct sctp_ulpevent *retval = NULL; + struct sk_buff *pd_first = NULL; + struct sk_buff *pd_last = NULL; + struct sctp_stream_in *sin; + __u32 next_fsn = 0; + __u32 pd_point = 0; + __u32 pd_len = 0; + __u32 mid = 0; + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + + skb_queue_walk(&ulpq->reasm_uo, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + if (cevent->stream < event->stream) + continue; + if (cevent->stream > event->stream) + break; + + if (MID_lt(cevent->mid, event->mid)) + continue; + if (MID_lt(event->mid, cevent->mid)) + break; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (!sin->pd_mode_uo) { + sin->mid_uo = cevent->mid; + pd_first = pos; + pd_last = pos; + pd_len = pos->len; + } + + first_frag = pos; + next_fsn = 0; + mid = cevent->mid; + break; + + case SCTP_DATA_MIDDLE_FRAG: + if (first_frag && cevent->mid == mid && + cevent->fsn == next_fsn) { + next_fsn++; + if (pd_first) { + pd_last = pos; + pd_len += pos->len; + } + } else { + first_frag = NULL; + } + break; + + case SCTP_DATA_LAST_FRAG: + if (first_frag && cevent->mid == mid && + cevent->fsn == next_fsn) + goto found; + else + first_frag = NULL; + break; + } + } + + if (!pd_first) + goto out; + + pd_point = sctp_sk(asoc->base.sk)->pd_point; + if (pd_point && pd_point <= pd_len) { + retval = sctp_make_reassembled_event(asoc->base.net, + &ulpq->reasm_uo, + pd_first, pd_last); + if (retval) { + sin->fsn_uo = next_fsn; + sin->pd_mode_uo = 1; + } + } + goto out; + +found: + retval = sctp_make_reassembled_event(asoc->base.net, &ulpq->reasm_uo, + first_frag, pos); + if (retval) + retval->msg_flags |= MSG_EOR; + +out: + return retval; +} + +static struct sctp_ulpevent *sctp_intl_reasm_uo(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *retval = NULL; + struct sctp_stream_in *sin; + + if (SCTP_DATA_NOT_FRAG == (event->msg_flags & SCTP_DATA_FRAG_MASK)) { + event->msg_flags |= MSG_EOR; + return event; + } + + sctp_intl_store_reasm_uo(ulpq, event); + + sin = sctp_stream_in(&ulpq->asoc->stream, event->stream); + if (sin->pd_mode_uo && event->mid == sin->mid_uo && + event->fsn == sin->fsn_uo) + retval = sctp_intl_retrieve_partial_uo(ulpq, event); + + if (!retval) + retval = sctp_intl_retrieve_reassembled_uo(ulpq, event); + + return retval; +} + +static struct sctp_ulpevent *sctp_intl_retrieve_first_uo(struct sctp_ulpq *ulpq) +{ + struct sctp_stream_in *csin, *sin = NULL; + struct sk_buff *first_frag = NULL; + struct sk_buff *last_frag = NULL; + struct sctp_ulpevent *retval; + struct sk_buff *pos; + __u32 next_fsn = 0; + __u16 sid = 0; + + skb_queue_walk(&ulpq->reasm_uo, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + csin = sctp_stream_in(&ulpq->asoc->stream, cevent->stream); + if (csin->pd_mode_uo) + continue; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (first_frag) + goto out; + first_frag = pos; + last_frag = pos; + next_fsn = 0; + sin = csin; + sid = cevent->stream; + sin->mid_uo = cevent->mid; + break; + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) + break; + if (cevent->stream == sid && + cevent->mid == sin->mid_uo && + cevent->fsn == next_fsn) { + next_fsn++; + last_frag = pos; + } else { + goto out; + } + break; + case SCTP_DATA_LAST_FRAG: + if (first_frag) + goto out; + break; + default: + break; + } + } + + if (!first_frag) + 
return NULL; + +out: + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, + &ulpq->reasm_uo, first_frag, + last_frag); + if (retval) { + sin->fsn_uo = next_fsn; + sin->pd_mode_uo = 1; + } + + return retval; +} + +static int sctp_ulpevent_idata(struct sctp_ulpq *ulpq, + struct sctp_chunk *chunk, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sk_buff_head temp; + int event_eor = 0; + + event = sctp_ulpevent_make_rcvmsg(chunk->asoc, chunk, gfp); + if (!event) + return -ENOMEM; + + event->mid = ntohl(chunk->subh.idata_hdr->mid); + if (event->msg_flags & SCTP_DATA_FIRST_FRAG) + event->ppid = chunk->subh.idata_hdr->ppid; + else + event->fsn = ntohl(chunk->subh.idata_hdr->fsn); + + if (!(event->msg_flags & SCTP_DATA_UNORDERED)) { + event = sctp_intl_reasm(ulpq, event); + if (event) { + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + + if (event->msg_flags & MSG_EOR) + event = sctp_intl_order(ulpq, event); + } + } else { + event = sctp_intl_reasm_uo(ulpq, event); + if (event) { + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + } + } + + if (event) { + event_eor = (event->msg_flags & MSG_EOR) ? 1 : 0; + sctp_enqueue_event(ulpq, &temp); + } + + return event_eor; +} + +static struct sctp_ulpevent *sctp_intl_retrieve_first(struct sctp_ulpq *ulpq) +{ + struct sctp_stream_in *csin, *sin = NULL; + struct sk_buff *first_frag = NULL; + struct sk_buff *last_frag = NULL; + struct sctp_ulpevent *retval; + struct sk_buff *pos; + __u32 next_fsn = 0; + __u16 sid = 0; + + skb_queue_walk(&ulpq->reasm, pos) { + struct sctp_ulpevent *cevent = sctp_skb2event(pos); + + csin = sctp_stream_in(&ulpq->asoc->stream, cevent->stream); + if (csin->pd_mode) + continue; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (first_frag) + goto out; + if (cevent->mid == csin->mid) { + first_frag = pos; + last_frag = pos; + next_fsn = 0; + sin = csin; + sid = cevent->stream; + } + break; + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) + break; + if (cevent->stream == sid && + cevent->mid == sin->mid && + cevent->fsn == next_fsn) { + next_fsn++; + last_frag = pos; + } else { + goto out; + } + break; + case SCTP_DATA_LAST_FRAG: + if (first_frag) + goto out; + break; + default: + break; + } + } + + if (!first_frag) + return NULL; + +out: + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, + &ulpq->reasm, first_frag, + last_frag); + if (retval) { + sin->fsn = next_fsn; + sin->pd_mode = 1; + } + + return retval; +} + +static void sctp_intl_start_pd(struct sctp_ulpq *ulpq, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sk_buff_head temp; + + if (!skb_queue_empty(&ulpq->reasm)) { + do { + event = sctp_intl_retrieve_first(ulpq); + if (event) { + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + sctp_enqueue_event(ulpq, &temp); + } + } while (event); + } + + if (!skb_queue_empty(&ulpq->reasm_uo)) { + do { + event = sctp_intl_retrieve_first_uo(ulpq); + if (event) { + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + sctp_enqueue_event(ulpq, &temp); + } + } while (event); + } +} + +static void sctp_renege_events(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk, + gfp_t gfp) +{ + struct sctp_association *asoc = ulpq->asoc; + __u32 freed = 0; + __u16 needed; + + needed = ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_idata_chunk); + + if (skb_queue_empty(&asoc->base.sk->sk_receive_queue)) { + freed = sctp_ulpq_renege_list(ulpq, &ulpq->lobby, 
needed); + if (freed < needed) + freed += sctp_ulpq_renege_list(ulpq, &ulpq->reasm, + needed); + if (freed < needed) + freed += sctp_ulpq_renege_list(ulpq, &ulpq->reasm_uo, + needed); + } + + if (freed >= needed && sctp_ulpevent_idata(ulpq, chunk, gfp) <= 0) + sctp_intl_start_pd(ulpq, gfp); +} + +static void sctp_intl_stream_abort_pd(struct sctp_ulpq *ulpq, __u16 sid, + __u32 mid, __u16 flags, gfp_t gfp) +{ + struct sock *sk = ulpq->asoc->base.sk; + struct sctp_ulpevent *ev = NULL; + + if (!sctp_ulpevent_type_enabled(ulpq->asoc->subscribe, + SCTP_PARTIAL_DELIVERY_EVENT)) + return; + + ev = sctp_ulpevent_make_pdapi(ulpq->asoc, SCTP_PARTIAL_DELIVERY_ABORTED, + sid, mid, flags, gfp); + if (ev) { + struct sctp_sock *sp = sctp_sk(sk); + + __skb_queue_tail(&sk->sk_receive_queue, sctp_event2skb(ev)); + + if (!sp->data_ready_signalled) { + sp->data_ready_signalled = 1; + sk->sk_data_ready(sk); + } + } +} + +static void sctp_intl_reap_ordered(struct sctp_ulpq *ulpq, __u16 sid) +{ + struct sctp_stream *stream = &ulpq->asoc->stream; + struct sctp_ulpevent *cevent, *event = NULL; + struct sk_buff_head *lobby = &ulpq->lobby; + struct sk_buff *pos, *tmp; + struct sk_buff_head temp; + __u16 csid; + __u32 cmid; + + skb_queue_head_init(&temp); + sctp_skb_for_each(pos, lobby, tmp) { + cevent = (struct sctp_ulpevent *)pos->cb; + csid = cevent->stream; + cmid = cevent->mid; + + if (csid > sid) + break; + + if (csid < sid) + continue; + + if (!MID_lt(cmid, sctp_mid_peek(stream, in, csid))) + break; + + __skb_unlink(pos, lobby); + if (!event) + event = sctp_skb2event(pos); + + __skb_queue_tail(&temp, pos); + } + + if (!event && pos != (struct sk_buff *)lobby) { + cevent = (struct sctp_ulpevent *)pos->cb; + csid = cevent->stream; + cmid = cevent->mid; + + if (csid == sid && cmid == sctp_mid_peek(stream, in, csid)) { + sctp_mid_next(stream, in, csid); + __skb_unlink(pos, lobby); + __skb_queue_tail(&temp, pos); + event = sctp_skb2event(pos); + } + } + + if (event) { + sctp_intl_retrieve_ordered(ulpq, event); + sctp_enqueue_event(ulpq, &temp); + } +} + +static void sctp_intl_abort_pd(struct sctp_ulpq *ulpq, gfp_t gfp) +{ + struct sctp_stream *stream = &ulpq->asoc->stream; + __u16 sid; + + for (sid = 0; sid < stream->incnt; sid++) { + struct sctp_stream_in *sin = SCTP_SI(stream, sid); + __u32 mid; + + if (sin->pd_mode_uo) { + sin->pd_mode_uo = 0; + + mid = sin->mid_uo; + sctp_intl_stream_abort_pd(ulpq, sid, mid, 0x1, gfp); + } + + if (sin->pd_mode) { + sin->pd_mode = 0; + + mid = sin->mid; + sctp_intl_stream_abort_pd(ulpq, sid, mid, 0, gfp); + sctp_mid_skip(stream, in, sid, mid); + + sctp_intl_reap_ordered(ulpq, sid); + } + } + + /* intl abort pd happens only when all data needs to be cleaned */ + sctp_ulpq_flush(ulpq); +} + +static inline int sctp_get_skip_pos(struct sctp_ifwdtsn_skip *skiplist, + int nskips, __be16 stream, __u8 flags) +{ + int i; + + for (i = 0; i < nskips; i++) + if (skiplist[i].stream == stream && + skiplist[i].flags == flags) + return i; + + return i; +} + +#define SCTP_FTSN_U_BIT 0x1 +static void sctp_generate_iftsn(struct sctp_outq *q, __u32 ctsn) +{ + struct sctp_ifwdtsn_skip ftsn_skip_arr[10]; + struct sctp_association *asoc = q->asoc; + struct sctp_chunk *ftsn_chunk = NULL; + struct list_head *lchunk, *temp; + int nskips = 0, skip_pos; + struct sctp_chunk *chunk; + __u32 tsn; + + if (!asoc->peer.prsctp_capable) + return; + + if (TSN_lt(asoc->adv_peer_ack_point, ctsn)) + asoc->adv_peer_ack_point = ctsn; + + list_for_each_safe(lchunk, temp, &q->abandoned) { + chunk = list_entry(lchunk, 
struct sctp_chunk, transmitted_list); + tsn = ntohl(chunk->subh.data_hdr->tsn); + + if (TSN_lte(tsn, ctsn)) { + list_del_init(lchunk); + sctp_chunk_free(chunk); + } else if (TSN_lte(tsn, asoc->adv_peer_ack_point + 1)) { + __be16 sid = chunk->subh.idata_hdr->stream; + __be32 mid = chunk->subh.idata_hdr->mid; + __u8 flags = 0; + + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) + flags |= SCTP_FTSN_U_BIT; + + asoc->adv_peer_ack_point = tsn; + skip_pos = sctp_get_skip_pos(&ftsn_skip_arr[0], nskips, + sid, flags); + ftsn_skip_arr[skip_pos].stream = sid; + ftsn_skip_arr[skip_pos].reserved = 0; + ftsn_skip_arr[skip_pos].flags = flags; + ftsn_skip_arr[skip_pos].mid = mid; + if (skip_pos == nskips) + nskips++; + if (nskips == 10) + break; + } else { + break; + } + } + + if (asoc->adv_peer_ack_point > ctsn) + ftsn_chunk = sctp_make_ifwdtsn(asoc, asoc->adv_peer_ack_point, + nskips, &ftsn_skip_arr[0]); + + if (ftsn_chunk) { + list_add_tail(&ftsn_chunk->list, &q->control_chunk_list); + SCTP_INC_STATS(asoc->base.net, SCTP_MIB_OUTCTRLCHUNKS); + } +} + +#define _sctp_walk_ifwdtsn(pos, chunk, end) \ + for (pos = chunk->subh.ifwdtsn_hdr->skip; \ + (void *)pos <= (void *)chunk->subh.ifwdtsn_hdr->skip + (end) - \ + sizeof(struct sctp_ifwdtsn_skip); pos++) + +#define sctp_walk_ifwdtsn(pos, ch) \ + _sctp_walk_ifwdtsn((pos), (ch), ntohs((ch)->chunk_hdr->length) - \ + sizeof(struct sctp_ifwdtsn_chunk)) + +static bool sctp_validate_fwdtsn(struct sctp_chunk *chunk) +{ + struct sctp_fwdtsn_skip *skip; + __u16 incnt; + + if (chunk->chunk_hdr->type != SCTP_CID_FWD_TSN) + return false; + + incnt = chunk->asoc->stream.incnt; + sctp_walk_fwdtsn(skip, chunk) + if (ntohs(skip->stream) >= incnt) + return false; + + return true; +} + +static bool sctp_validate_iftsn(struct sctp_chunk *chunk) +{ + struct sctp_ifwdtsn_skip *skip; + __u16 incnt; + + if (chunk->chunk_hdr->type != SCTP_CID_I_FWD_TSN) + return false; + + incnt = chunk->asoc->stream.incnt; + sctp_walk_ifwdtsn(skip, chunk) + if (ntohs(skip->stream) >= incnt) + return false; + + return true; +} + +static void sctp_report_fwdtsn(struct sctp_ulpq *ulpq, __u32 ftsn) +{ + /* Move the Cumulattive TSN Ack ahead. */ + sctp_tsnmap_skip(&ulpq->asoc->peer.tsn_map, ftsn); + /* purge the fragmentation queue */ + sctp_ulpq_reasm_flushtsn(ulpq, ftsn); + /* Abort any in progress partial delivery. */ + sctp_ulpq_abort_pd(ulpq, GFP_ATOMIC); +} + +static void sctp_intl_reasm_flushtsn(struct sctp_ulpq *ulpq, __u32 ftsn) +{ + struct sk_buff *pos, *tmp; + + skb_queue_walk_safe(&ulpq->reasm, pos, tmp) { + struct sctp_ulpevent *event = sctp_skb2event(pos); + __u32 tsn = event->tsn; + + if (TSN_lte(tsn, ftsn)) { + __skb_unlink(pos, &ulpq->reasm); + sctp_ulpevent_free(event); + } + } + + skb_queue_walk_safe(&ulpq->reasm_uo, pos, tmp) { + struct sctp_ulpevent *event = sctp_skb2event(pos); + __u32 tsn = event->tsn; + + if (TSN_lte(tsn, ftsn)) { + __skb_unlink(pos, &ulpq->reasm_uo); + sctp_ulpevent_free(event); + } + } +} + +static void sctp_report_iftsn(struct sctp_ulpq *ulpq, __u32 ftsn) +{ + /* Move the Cumulattive TSN Ack ahead. 
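When building an I-FORWARD-TSN, sctp_generate_iftsn() keeps at most ten skip entries and uses sctp_get_skip_pos() to overwrite an existing (stream, flags) slot instead of appending a duplicate, so only the newest abandoned MID per stream and flag pair survives. A small stand-alone test of that dedup pattern (hypothetical types, not the kernel structs):

/* Stand-alone test of the skip-list dedup used when building I-FORWARD-TSN:
 * an existing (stream, flags) entry is overwritten, otherwise the entry is
 * appended, capped at 10 slots.
 */
#include <stdio.h>

struct skip { unsigned stream, flags, mid; };

static int skip_pos(const struct skip *arr, int nskips,
		    unsigned stream, unsigned flags)
{
	int i;

	for (i = 0; i < nskips; i++)
		if (arr[i].stream == stream && arr[i].flags == flags)
			return i;	/* reuse the existing slot */
	return i;			/* == nskips: append here */
}

int main(void)
{
	struct skip arr[10];
	int nskips = 0, pos, i;
	/* abandoned chunks, in TSN order: (stream, flags, mid) */
	struct skip in[] = { {1, 0, 5}, {2, 0, 9}, {1, 0, 7}, {1, 1, 3} };

	for (i = 0; i < 4 && nskips < 10; i++) {
		pos = skip_pos(arr, nskips, in[i].stream, in[i].flags);
		arr[pos] = in[i];	/* newest MID for that stream wins */
		if (pos == nskips)
			nskips++;
	}

	for (i = 0; i < nskips; i++)
		printf("stream %u flags %u mid %u\n",
		       arr[i].stream, arr[i].flags, arr[i].mid);
	return 0;	/* 3 entries; ordered stream 1 keeps mid 7 */
}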
*/ + sctp_tsnmap_skip(&ulpq->asoc->peer.tsn_map, ftsn); + /* purge the fragmentation queue */ + sctp_intl_reasm_flushtsn(ulpq, ftsn); + /* abort only when it's for all data */ + if (ftsn == sctp_tsnmap_get_max_tsn_seen(&ulpq->asoc->peer.tsn_map)) + sctp_intl_abort_pd(ulpq, GFP_ATOMIC); +} + +static void sctp_handle_fwdtsn(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk) +{ + struct sctp_fwdtsn_skip *skip; + + /* Walk through all the skipped SSNs */ + sctp_walk_fwdtsn(skip, chunk) + sctp_ulpq_skip(ulpq, ntohs(skip->stream), ntohs(skip->ssn)); +} + +static void sctp_intl_skip(struct sctp_ulpq *ulpq, __u16 sid, __u32 mid, + __u8 flags) +{ + struct sctp_stream_in *sin = sctp_stream_in(&ulpq->asoc->stream, sid); + struct sctp_stream *stream = &ulpq->asoc->stream; + + if (flags & SCTP_FTSN_U_BIT) { + if (sin->pd_mode_uo && MID_lt(sin->mid_uo, mid)) { + sin->pd_mode_uo = 0; + sctp_intl_stream_abort_pd(ulpq, sid, mid, 0x1, + GFP_ATOMIC); + } + return; + } + + if (MID_lt(mid, sctp_mid_peek(stream, in, sid))) + return; + + if (sin->pd_mode) { + sin->pd_mode = 0; + sctp_intl_stream_abort_pd(ulpq, sid, mid, 0x0, GFP_ATOMIC); + } + + sctp_mid_skip(stream, in, sid, mid); + + sctp_intl_reap_ordered(ulpq, sid); +} + +static void sctp_handle_iftsn(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk) +{ + struct sctp_ifwdtsn_skip *skip; + + /* Walk through all the skipped MIDs and abort stream pd if possible */ + sctp_walk_ifwdtsn(skip, chunk) + sctp_intl_skip(ulpq, ntohs(skip->stream), + ntohl(skip->mid), skip->flags); +} + +static int do_ulpq_tail_event(struct sctp_ulpq *ulpq, struct sctp_ulpevent *event) +{ + struct sk_buff_head temp; + + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + return sctp_ulpq_tail_event(ulpq, &temp); +} + +static struct sctp_stream_interleave sctp_stream_interleave_0 = { + .data_chunk_len = sizeof(struct sctp_data_chunk), + .ftsn_chunk_len = sizeof(struct sctp_fwdtsn_chunk), + /* DATA process functions */ + .make_datafrag = sctp_make_datafrag_empty, + .assign_number = sctp_chunk_assign_ssn, + .validate_data = sctp_validate_data, + .ulpevent_data = sctp_ulpq_tail_data, + .enqueue_event = do_ulpq_tail_event, + .renege_events = sctp_ulpq_renege, + .start_pd = sctp_ulpq_partial_delivery, + .abort_pd = sctp_ulpq_abort_pd, + /* FORWARD-TSN process functions */ + .generate_ftsn = sctp_generate_fwdtsn, + .validate_ftsn = sctp_validate_fwdtsn, + .report_ftsn = sctp_report_fwdtsn, + .handle_ftsn = sctp_handle_fwdtsn, +}; + +static int do_sctp_enqueue_event(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff_head temp; + + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + return sctp_enqueue_event(ulpq, &temp); +} + +static struct sctp_stream_interleave sctp_stream_interleave_1 = { + .data_chunk_len = sizeof(struct sctp_idata_chunk), + .ftsn_chunk_len = sizeof(struct sctp_ifwdtsn_chunk), + /* I-DATA process functions */ + .make_datafrag = sctp_make_idatafrag_empty, + .assign_number = sctp_chunk_assign_mid, + .validate_data = sctp_validate_idata, + .ulpevent_data = sctp_ulpevent_idata, + .enqueue_event = do_sctp_enqueue_event, + .renege_events = sctp_renege_events, + .start_pd = sctp_intl_start_pd, + .abort_pd = sctp_intl_abort_pd, + /* I-FORWARD-TSN process functions */ + .generate_ftsn = sctp_generate_iftsn, + .validate_ftsn = sctp_validate_iftsn, + .report_ftsn = sctp_report_iftsn, + .handle_ftsn = sctp_handle_iftsn, +}; + +void sctp_stream_interleave_init(struct sctp_stream *stream) +{ + struct 
sctp_association *asoc; + + asoc = container_of(stream, struct sctp_association, stream); + stream->si = asoc->peer.intl_capable ? &sctp_stream_interleave_1 + : &sctp_stream_interleave_0; +} diff --git a/net/sctp/stream_sched.c b/net/sctp/stream_sched.c new file mode 100644 index 000000000..7c8f9d89e --- /dev/null +++ b/net/sctp/stream_sched.c @@ -0,0 +1,275 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright Red Hat Inc. 2017 + * + * This file is part of the SCTP kernel implementation + * + * These functions manipulate sctp stream queue/scheduling. + * + * Please send any bug reports or fixes you make to the + * email addresched(es): + * lksctp developers + * + * Written or modified by: + * Marcelo Ricardo Leitner + */ + +#include +#include +#include +#include + +/* First Come First Serve (a.k.a. FIFO) + * RFC DRAFT ndata Section 3.1 + */ +static int sctp_sched_fcfs_set(struct sctp_stream *stream, __u16 sid, + __u16 value, gfp_t gfp) +{ + return 0; +} + +static int sctp_sched_fcfs_get(struct sctp_stream *stream, __u16 sid, + __u16 *value) +{ + *value = 0; + return 0; +} + +static int sctp_sched_fcfs_init(struct sctp_stream *stream) +{ + return 0; +} + +static int sctp_sched_fcfs_init_sid(struct sctp_stream *stream, __u16 sid, + gfp_t gfp) +{ + return 0; +} + +static void sctp_sched_fcfs_free_sid(struct sctp_stream *stream, __u16 sid) +{ +} + +static void sctp_sched_fcfs_free(struct sctp_stream *stream) +{ +} + +static void sctp_sched_fcfs_enqueue(struct sctp_outq *q, + struct sctp_datamsg *msg) +{ +} + +static struct sctp_chunk *sctp_sched_fcfs_dequeue(struct sctp_outq *q) +{ + struct sctp_stream *stream = &q->asoc->stream; + struct sctp_chunk *ch = NULL; + struct list_head *entry; + + if (list_empty(&q->out_chunk_list)) + goto out; + + if (stream->out_curr) { + ch = list_entry(stream->out_curr->ext->outq.next, + struct sctp_chunk, stream_list); + } else { + entry = q->out_chunk_list.next; + ch = list_entry(entry, struct sctp_chunk, list); + } + + sctp_sched_dequeue_common(q, ch); + +out: + return ch; +} + +static void sctp_sched_fcfs_dequeue_done(struct sctp_outq *q, + struct sctp_chunk *chunk) +{ +} + +static void sctp_sched_fcfs_sched_all(struct sctp_stream *stream) +{ +} + +static void sctp_sched_fcfs_unsched_all(struct sctp_stream *stream) +{ +} + +static struct sctp_sched_ops sctp_sched_fcfs = { + .set = sctp_sched_fcfs_set, + .get = sctp_sched_fcfs_get, + .init = sctp_sched_fcfs_init, + .init_sid = sctp_sched_fcfs_init_sid, + .free_sid = sctp_sched_fcfs_free_sid, + .free = sctp_sched_fcfs_free, + .enqueue = sctp_sched_fcfs_enqueue, + .dequeue = sctp_sched_fcfs_dequeue, + .dequeue_done = sctp_sched_fcfs_dequeue_done, + .sched_all = sctp_sched_fcfs_sched_all, + .unsched_all = sctp_sched_fcfs_unsched_all, +}; + +static void sctp_sched_ops_fcfs_init(void) +{ + sctp_sched_ops_register(SCTP_SS_FCFS, &sctp_sched_fcfs); +} + +/* API to other parts of the stack */ + +static struct sctp_sched_ops *sctp_sched_ops[SCTP_SS_MAX + 1]; + +void sctp_sched_ops_register(enum sctp_sched_type sched, + struct sctp_sched_ops *sched_ops) +{ + sctp_sched_ops[sched] = sched_ops; +} + +void sctp_sched_ops_init(void) +{ + sctp_sched_ops_fcfs_init(); + sctp_sched_ops_prio_init(); + sctp_sched_ops_rr_init(); +} + +int sctp_sched_set_sched(struct sctp_association *asoc, + enum sctp_sched_type sched) +{ + struct sctp_sched_ops *n = sctp_sched_ops[sched]; + struct sctp_sched_ops *old = asoc->outqueue.sched; + struct sctp_datamsg *msg = NULL; + struct sctp_chunk *ch; + 
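sctp_stream_interleave_init() selects the I-DATA operations only when the peer negotiated interleaving (intl_capable), and that negotiation is requested from user space. The sketch below is a hedged example of enabling it; it assumes the SCTP_FRAGMENT_INTERLEAVE and SCTP_INTERLEAVING_SUPPORTED options as named in the Linux UAPI, and on mainline kernels it additionally requires the net.sctp.intl_enable sysctl to be 1:

/* Hedged user-space sketch: negotiating I-DATA (stream interleaving), which
 * is what makes the kernel pick the interleave_1 ops above.  Assumes the
 * Linux UAPI names SCTP_FRAGMENT_INTERLEAVE and SCTP_INTERLEAVING_SUPPORTED.
 */
#include <stdio.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/sctp.h>

int enable_interleaving(int fd)
{
	struct sctp_assoc_value av = { .assoc_id = 0, .assoc_value = 1 };
	int level = 2;	/* interleave messages from different streams on delivery */

	if (setsockopt(fd, IPPROTO_SCTP, SCTP_FRAGMENT_INTERLEAVE,
		       &level, sizeof(level)) < 0) {
		perror("SCTP_FRAGMENT_INTERLEAVE");
		return -1;
	}
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_INTERLEAVING_SUPPORTED,
		       &av, sizeof(av)) < 0) {
		perror("SCTP_INTERLEAVING_SUPPORTED");
		return -1;
	}
	return 0;
}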
int i, ret = 0; + + if (old == n) + return ret; + + if (sched > SCTP_SS_MAX) + return -EINVAL; + + if (old) { + old->free(&asoc->stream); + + /* Give the next scheduler a clean slate. */ + for (i = 0; i < asoc->stream.outcnt; i++) { + struct sctp_stream_out_ext *ext = SCTP_SO(&asoc->stream, i)->ext; + + if (!ext) + continue; + memset_after(ext, 0, outq); + } + } + + asoc->outqueue.sched = n; + n->init(&asoc->stream); + for (i = 0; i < asoc->stream.outcnt; i++) { + if (!SCTP_SO(&asoc->stream, i)->ext) + continue; + + ret = n->init_sid(&asoc->stream, i, GFP_ATOMIC); + if (ret) + goto err; + } + + /* We have to requeue all chunks already queued. */ + list_for_each_entry(ch, &asoc->outqueue.out_chunk_list, list) { + if (ch->msg == msg) + continue; + msg = ch->msg; + n->enqueue(&asoc->outqueue, msg); + } + + return ret; + +err: + n->free(&asoc->stream); + asoc->outqueue.sched = &sctp_sched_fcfs; /* Always safe */ + + return ret; +} + +int sctp_sched_get_sched(struct sctp_association *asoc) +{ + int i; + + for (i = 0; i <= SCTP_SS_MAX; i++) + if (asoc->outqueue.sched == sctp_sched_ops[i]) + return i; + + return 0; +} + +int sctp_sched_set_value(struct sctp_association *asoc, __u16 sid, + __u16 value, gfp_t gfp) +{ + if (sid >= asoc->stream.outcnt) + return -EINVAL; + + if (!SCTP_SO(&asoc->stream, sid)->ext) { + int ret; + + ret = sctp_stream_init_ext(&asoc->stream, sid); + if (ret) + return ret; + } + + return asoc->outqueue.sched->set(&asoc->stream, sid, value, gfp); +} + +int sctp_sched_get_value(struct sctp_association *asoc, __u16 sid, + __u16 *value) +{ + if (sid >= asoc->stream.outcnt) + return -EINVAL; + + if (!SCTP_SO(&asoc->stream, sid)->ext) + return 0; + + return asoc->outqueue.sched->get(&asoc->stream, sid, value); +} + +void sctp_sched_dequeue_done(struct sctp_outq *q, struct sctp_chunk *ch) +{ + if (!list_is_last(&ch->frag_list, &ch->msg->chunks) && + !q->asoc->peer.intl_capable) { + struct sctp_stream_out *sout; + __u16 sid; + + /* datamsg is not finish, so save it as current one, + * in case application switch scheduler or a higher + * priority stream comes in. + */ + sid = sctp_chunk_stream_no(ch); + sout = SCTP_SO(&q->asoc->stream, sid); + q->asoc->stream.out_curr = sout; + return; + } + + q->asoc->stream.out_curr = NULL; + q->sched->dequeue_done(q, ch); +} + +/* Auxiliary functions for the schedulers */ +void sctp_sched_dequeue_common(struct sctp_outq *q, struct sctp_chunk *ch) +{ + list_del_init(&ch->list); + list_del_init(&ch->stream_list); + q->out_qlen -= ch->skb->len; +} + +int sctp_sched_init_sid(struct sctp_stream *stream, __u16 sid, gfp_t gfp) +{ + struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream); + struct sctp_stream_out_ext *ext = SCTP_SO(stream, sid)->ext; + + INIT_LIST_HEAD(&ext->outq); + return sched->init_sid(stream, sid, gfp); +} + +struct sctp_sched_ops *sctp_sched_ops_from_stream(struct sctp_stream *stream) +{ + struct sctp_association *asoc; + + asoc = container_of(stream, struct sctp_association, stream); + + return asoc->outqueue.sched; +} diff --git a/net/sctp/stream_sched_prio.c b/net/sctp/stream_sched_prio.c new file mode 100644 index 000000000..7dd9f8b38 --- /dev/null +++ b/net/sctp/stream_sched_prio.c @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright Red Hat Inc. 2017 + * + * This file is part of the SCTP kernel implementation + * + * These functions manipulate sctp stream queue/scheduling. 
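sctp_sched_set_sched() is the kernel half of a scheduler switch requested over the socket API: it frees the old scheduler's per-stream state and re-enqueues every datamsg already queued. A hedged user-space counterpart, assuming the SCTP_STREAM_SCHEDULER and SCTP_STREAM_SCHEDULER_VALUE options and the SCTP_SS_PRIO constant from the Linux UAPI; with the priority scheduler (see stream_sched_prio.c below) a lower numeric value is served first:

/* Hedged user-space sketch: select the priority stream scheduler and give
 * stream 0 a higher priority than stream 1.  Assumes the Linux UAPI names
 * SCTP_STREAM_SCHEDULER, SCTP_STREAM_SCHEDULER_VALUE and SCTP_SS_PRIO.
 */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/sctp.h>

int use_prio_scheduler(int fd)
{
	struct sctp_assoc_value av;
	struct sctp_stream_value sv;

	memset(&av, 0, sizeof(av));
	av.assoc_value = SCTP_SS_PRIO;
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_STREAM_SCHEDULER,
		       &av, sizeof(av)) < 0) {
		perror("SCTP_STREAM_SCHEDULER");
		return -1;
	}

	memset(&sv, 0, sizeof(sv));
	sv.stream_id = 0;
	sv.stream_value = 1;	/* stream 0: high priority (lower value) */
	if (setsockopt(fd, IPPROTO_SCTP, SCTP_STREAM_SCHEDULER_VALUE,
		       &sv, sizeof(sv)) < 0) {
		perror("SCTP_STREAM_SCHEDULER_VALUE");
		return -1;
	}

	sv.stream_id = 1;
	sv.stream_value = 10;	/* stream 1: lower priority */
	return setsockopt(fd, IPPROTO_SCTP, SCTP_STREAM_SCHEDULER_VALUE,
			  &sv, sizeof(sv));
}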
+ * + * Please send any bug reports or fixes you make to the + * email addresched(es): + * lksctp developers + * + * Written or modified by: + * Marcelo Ricardo Leitner + */ + +#include +#include +#include +#include + +/* Priority handling + * RFC DRAFT ndata section 3.4 + */ + +static void sctp_sched_prio_unsched_all(struct sctp_stream *stream); + +static struct sctp_stream_priorities *sctp_sched_prio_head_get(struct sctp_stream_priorities *p) +{ + p->users++; + return p; +} + +static void sctp_sched_prio_head_put(struct sctp_stream_priorities *p) +{ + if (p && --p->users == 0) + kfree(p); +} + +static struct sctp_stream_priorities *sctp_sched_prio_new_head( + struct sctp_stream *stream, int prio, gfp_t gfp) +{ + struct sctp_stream_priorities *p; + + p = kmalloc(sizeof(*p), gfp); + if (!p) + return NULL; + + INIT_LIST_HEAD(&p->prio_sched); + INIT_LIST_HEAD(&p->active); + p->next = NULL; + p->prio = prio; + p->users = 1; + + return p; +} + +static struct sctp_stream_priorities *sctp_sched_prio_get_head( + struct sctp_stream *stream, int prio, gfp_t gfp) +{ + struct sctp_stream_priorities *p; + int i; + + /* Look into scheduled priorities first, as they are sorted and + * we can find it fast IF it's scheduled. + */ + list_for_each_entry(p, &stream->prio_list, prio_sched) { + if (p->prio == prio) + return sctp_sched_prio_head_get(p); + if (p->prio > prio) + break; + } + + /* No luck. So we search on all streams now. */ + for (i = 0; i < stream->outcnt; i++) { + if (!SCTP_SO(stream, i)->ext) + continue; + + p = SCTP_SO(stream, i)->ext->prio_head; + if (!p) + /* Means all other streams won't be initialized + * as well. + */ + break; + if (p->prio == prio) + return sctp_sched_prio_head_get(p); + } + + /* If not even there, allocate a new one. */ + return sctp_sched_prio_new_head(stream, prio, gfp); +} + +static void sctp_sched_prio_next_stream(struct sctp_stream_priorities *p) +{ + struct list_head *pos; + + pos = p->next->prio_list.next; + if (pos == &p->active) + pos = pos->next; + p->next = list_entry(pos, struct sctp_stream_out_ext, prio_list); +} + +static bool sctp_sched_prio_unsched(struct sctp_stream_out_ext *soute) +{ + bool scheduled = false; + + if (!list_empty(&soute->prio_list)) { + struct sctp_stream_priorities *prio_head = soute->prio_head; + + /* Scheduled */ + scheduled = true; + + if (prio_head->next == soute) + /* Try to move to the next stream */ + sctp_sched_prio_next_stream(prio_head); + + list_del_init(&soute->prio_list); + + /* Also unsched the priority if this was the last stream */ + if (list_empty(&prio_head->active)) { + list_del_init(&prio_head->prio_sched); + /* If there is no stream left, clear next */ + prio_head->next = NULL; + } + } + + return scheduled; +} + +static void sctp_sched_prio_sched(struct sctp_stream *stream, + struct sctp_stream_out_ext *soute) +{ + struct sctp_stream_priorities *prio, *prio_head; + + prio_head = soute->prio_head; + + /* Nothing to do if already scheduled */ + if (!list_empty(&soute->prio_list)) + return; + + /* Schedule the stream. If there is a next, we schedule the new + * one before it, so it's the last in round robin order. + * If there isn't, we also have to schedule the priority. 
+ */ + if (prio_head->next) { + list_add(&soute->prio_list, prio_head->next->prio_list.prev); + return; + } + + list_add(&soute->prio_list, &prio_head->active); + prio_head->next = soute; + + list_for_each_entry(prio, &stream->prio_list, prio_sched) { + if (prio->prio > prio_head->prio) { + list_add(&prio_head->prio_sched, prio->prio_sched.prev); + return; + } + } + + list_add_tail(&prio_head->prio_sched, &stream->prio_list); +} + +static int sctp_sched_prio_set(struct sctp_stream *stream, __u16 sid, + __u16 prio, gfp_t gfp) +{ + struct sctp_stream_out *sout = SCTP_SO(stream, sid); + struct sctp_stream_out_ext *soute = sout->ext; + struct sctp_stream_priorities *prio_head, *old; + bool reschedule = false; + + old = soute->prio_head; + if (old && old->prio == prio) + return 0; + + prio_head = sctp_sched_prio_get_head(stream, prio, gfp); + if (!prio_head) + return -ENOMEM; + + reschedule = sctp_sched_prio_unsched(soute); + soute->prio_head = prio_head; + if (reschedule) + sctp_sched_prio_sched(stream, soute); + + sctp_sched_prio_head_put(old); + return 0; +} + +static int sctp_sched_prio_get(struct sctp_stream *stream, __u16 sid, + __u16 *value) +{ + *value = SCTP_SO(stream, sid)->ext->prio_head->prio; + return 0; +} + +static int sctp_sched_prio_init(struct sctp_stream *stream) +{ + INIT_LIST_HEAD(&stream->prio_list); + + return 0; +} + +static int sctp_sched_prio_init_sid(struct sctp_stream *stream, __u16 sid, + gfp_t gfp) +{ + INIT_LIST_HEAD(&SCTP_SO(stream, sid)->ext->prio_list); + return sctp_sched_prio_set(stream, sid, 0, gfp); +} + +static void sctp_sched_prio_free_sid(struct sctp_stream *stream, __u16 sid) +{ + sctp_sched_prio_head_put(SCTP_SO(stream, sid)->ext->prio_head); + SCTP_SO(stream, sid)->ext->prio_head = NULL; +} + +static void sctp_sched_prio_free(struct sctp_stream *stream) +{ + struct sctp_stream_priorities *prio, *n; + LIST_HEAD(list); + int i; + + /* As we don't keep a list of priorities, to avoid multiple + * frees we have to do it in 3 steps: + * 1. unsched everyone, so the lists are free to use in 2. + * 2. build the list of the priorities + * 3. free the list + */ + sctp_sched_prio_unsched_all(stream); + for (i = 0; i < stream->outcnt; i++) { + if (!SCTP_SO(stream, i)->ext) + continue; + prio = SCTP_SO(stream, i)->ext->prio_head; + if (prio && list_empty(&prio->prio_sched)) + list_add(&prio->prio_sched, &list); + } + list_for_each_entry_safe(prio, n, &list, prio_sched) { + list_del_init(&prio->prio_sched); + kfree(prio); + } +} + +static void sctp_sched_prio_enqueue(struct sctp_outq *q, + struct sctp_datamsg *msg) +{ + struct sctp_stream *stream; + struct sctp_chunk *ch; + __u16 sid; + + ch = list_first_entry(&msg->chunks, struct sctp_chunk, frag_list); + sid = sctp_chunk_stream_no(ch); + stream = &q->asoc->stream; + sctp_sched_prio_sched(stream, SCTP_SO(stream, sid)->ext); +} + +static struct sctp_chunk *sctp_sched_prio_dequeue(struct sctp_outq *q) +{ + struct sctp_stream *stream = &q->asoc->stream; + struct sctp_stream_priorities *prio; + struct sctp_stream_out_ext *soute; + struct sctp_chunk *ch = NULL; + + /* Bail out quickly if queue is empty */ + if (list_empty(&q->out_chunk_list)) + goto out; + + /* Find which chunk is next. It's easy, it's either the current + * one or the first chunk on the next active stream. 
+ */ + if (stream->out_curr) { + soute = stream->out_curr->ext; + } else { + prio = list_entry(stream->prio_list.next, + struct sctp_stream_priorities, prio_sched); + soute = prio->next; + } + ch = list_entry(soute->outq.next, struct sctp_chunk, stream_list); + sctp_sched_dequeue_common(q, ch); + +out: + return ch; +} + +static void sctp_sched_prio_dequeue_done(struct sctp_outq *q, + struct sctp_chunk *ch) +{ + struct sctp_stream_priorities *prio; + struct sctp_stream_out_ext *soute; + __u16 sid; + + /* Last chunk on that msg, move to the next stream on + * this priority. + */ + sid = sctp_chunk_stream_no(ch); + soute = SCTP_SO(&q->asoc->stream, sid)->ext; + prio = soute->prio_head; + + sctp_sched_prio_next_stream(prio); + + if (list_empty(&soute->outq)) + sctp_sched_prio_unsched(soute); +} + +static void sctp_sched_prio_sched_all(struct sctp_stream *stream) +{ + struct sctp_association *asoc; + struct sctp_stream_out *sout; + struct sctp_chunk *ch; + + asoc = container_of(stream, struct sctp_association, stream); + list_for_each_entry(ch, &asoc->outqueue.out_chunk_list, list) { + __u16 sid; + + sid = sctp_chunk_stream_no(ch); + sout = SCTP_SO(stream, sid); + if (sout->ext) + sctp_sched_prio_sched(stream, sout->ext); + } +} + +static void sctp_sched_prio_unsched_all(struct sctp_stream *stream) +{ + struct sctp_stream_priorities *p, *tmp; + struct sctp_stream_out_ext *soute, *souttmp; + + list_for_each_entry_safe(p, tmp, &stream->prio_list, prio_sched) + list_for_each_entry_safe(soute, souttmp, &p->active, prio_list) + sctp_sched_prio_unsched(soute); +} + +static struct sctp_sched_ops sctp_sched_prio = { + .set = sctp_sched_prio_set, + .get = sctp_sched_prio_get, + .init = sctp_sched_prio_init, + .init_sid = sctp_sched_prio_init_sid, + .free_sid = sctp_sched_prio_free_sid, + .free = sctp_sched_prio_free, + .enqueue = sctp_sched_prio_enqueue, + .dequeue = sctp_sched_prio_dequeue, + .dequeue_done = sctp_sched_prio_dequeue_done, + .sched_all = sctp_sched_prio_sched_all, + .unsched_all = sctp_sched_prio_unsched_all, +}; + +void sctp_sched_ops_prio_init(void) +{ + sctp_sched_ops_register(SCTP_SS_PRIO, &sctp_sched_prio); +} diff --git a/net/sctp/stream_sched_rr.c b/net/sctp/stream_sched_rr.c new file mode 100644 index 000000000..cc444fe0d --- /dev/null +++ b/net/sctp/stream_sched_rr.c @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright Red Hat Inc. 2017 + * + * This file is part of the SCTP kernel implementation + * + * These functions manipulate sctp stream queue/scheduling. 
+ * + * Please send any bug reports or fixes you make to the + * email addresched(es): + * lksctp developers + * + * Written or modified by: + * Marcelo Ricardo Leitner + */ + +#include +#include +#include +#include + +/* Priority handling + * RFC DRAFT ndata section 3.2 + */ +static void sctp_sched_rr_unsched_all(struct sctp_stream *stream); + +static void sctp_sched_rr_next_stream(struct sctp_stream *stream) +{ + struct list_head *pos; + + pos = stream->rr_next->rr_list.next; + if (pos == &stream->rr_list) + pos = pos->next; + stream->rr_next = list_entry(pos, struct sctp_stream_out_ext, rr_list); +} + +static void sctp_sched_rr_unsched(struct sctp_stream *stream, + struct sctp_stream_out_ext *soute) +{ + if (stream->rr_next == soute) + /* Try to move to the next stream */ + sctp_sched_rr_next_stream(stream); + + list_del_init(&soute->rr_list); + + /* If we have no other stream queued, clear next */ + if (list_empty(&stream->rr_list)) + stream->rr_next = NULL; +} + +static void sctp_sched_rr_sched(struct sctp_stream *stream, + struct sctp_stream_out_ext *soute) +{ + if (!list_empty(&soute->rr_list)) + /* Already scheduled. */ + return; + + /* Schedule the stream */ + list_add_tail(&soute->rr_list, &stream->rr_list); + + if (!stream->rr_next) + stream->rr_next = soute; +} + +static int sctp_sched_rr_set(struct sctp_stream *stream, __u16 sid, + __u16 prio, gfp_t gfp) +{ + return 0; +} + +static int sctp_sched_rr_get(struct sctp_stream *stream, __u16 sid, + __u16 *value) +{ + return 0; +} + +static int sctp_sched_rr_init(struct sctp_stream *stream) +{ + INIT_LIST_HEAD(&stream->rr_list); + stream->rr_next = NULL; + + return 0; +} + +static int sctp_sched_rr_init_sid(struct sctp_stream *stream, __u16 sid, + gfp_t gfp) +{ + INIT_LIST_HEAD(&SCTP_SO(stream, sid)->ext->rr_list); + + return 0; +} + +static void sctp_sched_rr_free_sid(struct sctp_stream *stream, __u16 sid) +{ +} + +static void sctp_sched_rr_free(struct sctp_stream *stream) +{ + sctp_sched_rr_unsched_all(stream); +} + +static void sctp_sched_rr_enqueue(struct sctp_outq *q, + struct sctp_datamsg *msg) +{ + struct sctp_stream *stream; + struct sctp_chunk *ch; + __u16 sid; + + ch = list_first_entry(&msg->chunks, struct sctp_chunk, frag_list); + sid = sctp_chunk_stream_no(ch); + stream = &q->asoc->stream; + sctp_sched_rr_sched(stream, SCTP_SO(stream, sid)->ext); +} + +static struct sctp_chunk *sctp_sched_rr_dequeue(struct sctp_outq *q) +{ + struct sctp_stream *stream = &q->asoc->stream; + struct sctp_stream_out_ext *soute; + struct sctp_chunk *ch = NULL; + + /* Bail out quickly if queue is empty */ + if (list_empty(&q->out_chunk_list)) + goto out; + + /* Find which chunk is next */ + if (stream->out_curr) + soute = stream->out_curr->ext; + else + soute = stream->rr_next; + ch = list_entry(soute->outq.next, struct sctp_chunk, stream_list); + + sctp_sched_dequeue_common(q, ch); + +out: + return ch; +} + +static void sctp_sched_rr_dequeue_done(struct sctp_outq *q, + struct sctp_chunk *ch) +{ + struct sctp_stream_out_ext *soute; + __u16 sid; + + /* Last chunk on that msg, move to the next stream */ + sid = sctp_chunk_stream_no(ch); + soute = SCTP_SO(&q->asoc->stream, sid)->ext; + + sctp_sched_rr_next_stream(&q->asoc->stream); + + if (list_empty(&soute->outq)) + sctp_sched_rr_unsched(&q->asoc->stream, soute); +} + +static void sctp_sched_rr_sched_all(struct sctp_stream *stream) +{ + struct sctp_association *asoc; + struct sctp_stream_out_ext *soute; + struct sctp_chunk *ch; + + asoc = container_of(stream, struct sctp_association, stream); + 
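The round-robin scheduler serves one complete message per active stream per turn and drops a stream from the rotation once its queue drains. A trivial stand-alone illustration of that ordering (plain arrays instead of the kernel's rr_list/rr_next machinery):

/* Stand-alone illustration of SCTP_SS_RR behaviour: one message per active
 * stream per turn; streams without queued data are skipped.  Hypothetical
 * arrays, not the kernel lists.
 */
#include <stdio.h>

int main(void)
{
	int queued[4] = { 3, 0, 1, 2 };	/* messages queued per stream id */
	int sid = 0, served = 0;

	while (served < 6) {		/* 3+0+1+2 messages in total */
		if (queued[sid]) {	/* inactive streams take no turn */
			printf("dequeue from stream %d\n", sid);
			queued[sid]--;
			served++;
		}
		sid = (sid + 1) % 4;	/* advance rr_next */
	}
	return 0;
}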
list_for_each_entry(ch, &asoc->outqueue.out_chunk_list, list) { + __u16 sid; + + sid = sctp_chunk_stream_no(ch); + soute = SCTP_SO(stream, sid)->ext; + if (soute) + sctp_sched_rr_sched(stream, soute); + } +} + +static void sctp_sched_rr_unsched_all(struct sctp_stream *stream) +{ + struct sctp_stream_out_ext *soute, *tmp; + + list_for_each_entry_safe(soute, tmp, &stream->rr_list, rr_list) + sctp_sched_rr_unsched(stream, soute); +} + +static struct sctp_sched_ops sctp_sched_rr = { + .set = sctp_sched_rr_set, + .get = sctp_sched_rr_get, + .init = sctp_sched_rr_init, + .init_sid = sctp_sched_rr_init_sid, + .free_sid = sctp_sched_rr_free_sid, + .free = sctp_sched_rr_free, + .enqueue = sctp_sched_rr_enqueue, + .dequeue = sctp_sched_rr_dequeue, + .dequeue_done = sctp_sched_rr_dequeue_done, + .sched_all = sctp_sched_rr_sched_all, + .unsched_all = sctp_sched_rr_unsched_all, +}; + +void sctp_sched_ops_rr_init(void) +{ + sctp_sched_ops_register(SCTP_SS_RR, &sctp_sched_rr); +} diff --git a/net/sctp/sysctl.c b/net/sctp/sysctl.c new file mode 100644 index 000000000..43ebf0900 --- /dev/null +++ b/net/sctp/sysctl.c @@ -0,0 +1,633 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2002, 2004 + * Copyright (c) 2002 Intel Corp. + * + * This file is part of the SCTP kernel implementation + * + * Sysctl related interfaces for SCTP. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Mingqin Liu + * Jon Grimm + * Ardelle Fan + * Ryan Layer + * Sridhar Samudrala + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include + +static int timer_max = 86400000; /* ms in one day */ +static int sack_timer_min = 1; +static int sack_timer_max = 500; +static int addr_scope_max = SCTP_SCOPE_POLICY_MAX; +static int rwnd_scale_max = 16; +static int rto_alpha_min = 0; +static int rto_beta_min = 0; +static int rto_alpha_max = 1000; +static int rto_beta_max = 1000; +static int pf_expose_max = SCTP_PF_EXPOSE_MAX; +static int ps_retrans_max = SCTP_PS_RETRANS_MAX; +static int udp_port_max = 65535; + +static unsigned long max_autoclose_min = 0; +static unsigned long max_autoclose_max = + (MAX_SCHEDULE_TIMEOUT / HZ > UINT_MAX) + ? 
UINT_MAX : MAX_SCHEDULE_TIMEOUT / HZ; + +static int proc_sctp_do_hmac_alg(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos); +static int proc_sctp_do_rto_min(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos); +static int proc_sctp_do_rto_max(struct ctl_table *ctl, int write, void *buffer, + size_t *lenp, loff_t *ppos); +static int proc_sctp_do_udp_port(struct ctl_table *ctl, int write, void *buffer, + size_t *lenp, loff_t *ppos); +static int proc_sctp_do_alpha_beta(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos); +static int proc_sctp_do_auth(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos); +static int proc_sctp_do_probe_interval(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos); + +static struct ctl_table sctp_table[] = { + { + .procname = "sctp_mem", + .data = &sysctl_sctp_mem, + .maxlen = sizeof(sysctl_sctp_mem), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax + }, + { + .procname = "sctp_rmem", + .data = &sysctl_sctp_rmem, + .maxlen = sizeof(sysctl_sctp_rmem), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "sctp_wmem", + .data = &sysctl_sctp_wmem, + .maxlen = sizeof(sysctl_sctp_wmem), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + + { /* sentinel */ } +}; + +/* The following index defines are used in sctp_sysctl_net_register(). + * If you add new items to the sctp_net_table, please ensure that + * the index values of these defines hold the same meaning indicated by + * their macro names when they appear in sctp_net_table. + */ +#define SCTP_RTO_MIN_IDX 0 +#define SCTP_RTO_MAX_IDX 1 +#define SCTP_PF_RETRANS_IDX 2 +#define SCTP_PS_RETRANS_IDX 3 + +static struct ctl_table sctp_net_table[] = { + [SCTP_RTO_MIN_IDX] = { + .procname = "rto_min", + .data = &init_net.sctp.rto_min, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_sctp_do_rto_min, + .extra1 = SYSCTL_ONE, + .extra2 = &init_net.sctp.rto_max + }, + [SCTP_RTO_MAX_IDX] = { + .procname = "rto_max", + .data = &init_net.sctp.rto_max, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_sctp_do_rto_max, + .extra1 = &init_net.sctp.rto_min, + .extra2 = &timer_max + }, + [SCTP_PF_RETRANS_IDX] = { + .procname = "pf_retrans", + .data = &init_net.sctp.pf_retrans, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &init_net.sctp.ps_retrans, + }, + [SCTP_PS_RETRANS_IDX] = { + .procname = "ps_retrans", + .data = &init_net.sctp.ps_retrans, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &init_net.sctp.pf_retrans, + .extra2 = &ps_retrans_max, + }, + { + .procname = "rto_initial", + .data = &init_net.sctp.rto_initial, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &timer_max + }, + { + .procname = "rto_alpha_exp_divisor", + .data = &init_net.sctp.rto_alpha, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_sctp_do_alpha_beta, + .extra1 = &rto_alpha_min, + .extra2 = &rto_alpha_max, + }, + { + .procname = "rto_beta_exp_divisor", + .data = &init_net.sctp.rto_beta, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_sctp_do_alpha_beta, + .extra1 = &rto_beta_min, + .extra2 = &rto_beta_max, + }, + { + .procname = "max_burst", + .data = &init_net.sctp.max_burst, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler 
= proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_INT_MAX, + }, + { + .procname = "cookie_preserve_enable", + .data = &init_net.sctp.cookie_preserve_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "cookie_hmac_alg", + .data = &init_net.sctp.sctp_hmac_alg, + .maxlen = 8, + .mode = 0644, + .proc_handler = proc_sctp_do_hmac_alg, + }, + { + .procname = "valid_cookie_life", + .data = &init_net.sctp.valid_cookie_life, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &timer_max + }, + { + .procname = "sack_timeout", + .data = &init_net.sctp.sack_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &sack_timer_min, + .extra2 = &sack_timer_max, + }, + { + .procname = "hb_interval", + .data = &init_net.sctp.hb_interval, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &timer_max + }, + { + .procname = "association_max_retrans", + .data = &init_net.sctp.max_retrans_association, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = SYSCTL_INT_MAX, + }, + { + .procname = "path_max_retrans", + .data = &init_net.sctp.max_retrans_path, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = SYSCTL_INT_MAX, + }, + { + .procname = "max_init_retransmits", + .data = &init_net.sctp.max_retrans_init, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = SYSCTL_INT_MAX, + }, + { + .procname = "sndbuf_policy", + .data = &init_net.sctp.sndbuf_policy, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "rcvbuf_policy", + .data = &init_net.sctp.rcvbuf_policy, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "default_auto_asconf", + .data = &init_net.sctp.default_auto_asconf, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "addip_enable", + .data = &init_net.sctp.addip_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "addip_noauth_enable", + .data = &init_net.sctp.addip_noauth, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "prsctp_enable", + .data = &init_net.sctp.prsctp_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "reconf_enable", + .data = &init_net.sctp.reconf_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "auth_enable", + .data = &init_net.sctp.auth_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_sctp_do_auth, + }, + { + .procname = "intl_enable", + .data = &init_net.sctp.intl_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "ecn_enable", + .data = &init_net.sctp.ecn_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "plpmtud_probe_interval", + .data = &init_net.sctp.probe_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_sctp_do_probe_interval, + }, + { + .procname = "udp_port", + .data = &init_net.sctp.udp_port, + .maxlen = sizeof(int), + 
.mode = 0644, + .proc_handler = proc_sctp_do_udp_port, + .extra1 = SYSCTL_ZERO, + .extra2 = &udp_port_max, + }, + { + .procname = "encap_port", + .data = &init_net.sctp.encap_port, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &udp_port_max, + }, + { + .procname = "addr_scope_policy", + .data = &init_net.sctp.scope_policy, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &addr_scope_max, + }, + { + .procname = "rwnd_update_shift", + .data = &init_net.sctp.rwnd_upd_shift, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + .extra2 = &rwnd_scale_max, + }, + { + .procname = "max_autoclose", + .data = &init_net.sctp.max_autoclose, + .maxlen = sizeof(unsigned long), + .mode = 0644, + .proc_handler = &proc_doulongvec_minmax, + .extra1 = &max_autoclose_min, + .extra2 = &max_autoclose_max, + }, + { + .procname = "pf_enable", + .data = &init_net.sctp.pf_enable, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { + .procname = "pf_expose", + .data = &init_net.sctp.pf_expose, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &pf_expose_max, + }, + + { /* sentinel */ } +}; + +static int proc_sctp_do_hmac_alg(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + struct ctl_table tbl; + bool changed = false; + char *none = "none"; + char tmp[8] = {0}; + int ret; + + memset(&tbl, 0, sizeof(struct ctl_table)); + + if (write) { + tbl.data = tmp; + tbl.maxlen = sizeof(tmp); + } else { + tbl.data = net->sctp.sctp_hmac_alg ? 
: none; + tbl.maxlen = strlen(tbl.data); + } + + ret = proc_dostring(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { +#ifdef CONFIG_CRYPTO_MD5 + if (!strncmp(tmp, "md5", 3)) { + net->sctp.sctp_hmac_alg = "md5"; + changed = true; + } +#endif +#ifdef CONFIG_CRYPTO_SHA1 + if (!strncmp(tmp, "sha1", 4)) { + net->sctp.sctp_hmac_alg = "sha1"; + changed = true; + } +#endif + if (!strncmp(tmp, "none", 4)) { + net->sctp.sctp_hmac_alg = NULL; + changed = true; + } + if (!changed) + ret = -EINVAL; + } + + return ret; +} + +static int proc_sctp_do_rto_min(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + unsigned int min = *(unsigned int *) ctl->extra1; + unsigned int max = *(unsigned int *) ctl->extra2; + struct ctl_table tbl; + int ret, new_value; + + memset(&tbl, 0, sizeof(struct ctl_table)); + tbl.maxlen = sizeof(unsigned int); + + if (write) + tbl.data = &new_value; + else + tbl.data = &net->sctp.rto_min; + + ret = proc_dointvec(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { + if (new_value > max || new_value < min) + return -EINVAL; + + net->sctp.rto_min = new_value; + } + + return ret; +} + +static int proc_sctp_do_rto_max(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + unsigned int min = *(unsigned int *) ctl->extra1; + unsigned int max = *(unsigned int *) ctl->extra2; + struct ctl_table tbl; + int ret, new_value; + + memset(&tbl, 0, sizeof(struct ctl_table)); + tbl.maxlen = sizeof(unsigned int); + + if (write) + tbl.data = &new_value; + else + tbl.data = &net->sctp.rto_max; + + ret = proc_dointvec(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { + if (new_value > max || new_value < min) + return -EINVAL; + + net->sctp.rto_max = new_value; + } + + return ret; +} + +static int proc_sctp_do_alpha_beta(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + if (write) + pr_warn_once("Changing rto_alpha or rto_beta may lead to " + "suboptimal rtt/srtt estimations!\n"); + + return proc_dointvec_minmax(ctl, write, buffer, lenp, ppos); +} + +static int proc_sctp_do_auth(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + struct ctl_table tbl; + int new_value, ret; + + memset(&tbl, 0, sizeof(struct ctl_table)); + tbl.maxlen = sizeof(unsigned int); + + if (write) + tbl.data = &new_value; + else + tbl.data = &net->sctp.auth_enable; + + ret = proc_dointvec(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { + struct sock *sk = net->sctp.ctl_sock; + + net->sctp.auth_enable = new_value; + /* Update the value in the control socket */ + lock_sock(sk); + sctp_sk(sk)->ep->auth_enable = new_value; + release_sock(sk); + } + + return ret; +} + +static int proc_sctp_do_udp_port(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + unsigned int min = *(unsigned int *)ctl->extra1; + unsigned int max = *(unsigned int *)ctl->extra2; + struct ctl_table tbl; + int ret, new_value; + + memset(&tbl, 0, sizeof(struct ctl_table)); + tbl.maxlen = sizeof(unsigned int); + + if (write) + tbl.data = &new_value; + else + tbl.data = &net->sctp.udp_port; + + ret = proc_dointvec(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { + struct sock *sk = net->sctp.ctl_sock; + + if (new_value > max || new_value < min) + return -EINVAL; + + 
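Each proc_sctp_do_*() handler above follows the same staged-write shape: on a write, point a temporary ctl_table at a stack variable, let proc_dointvec() parse the user buffer into it, range-check the parsed value against the bounds carried in extra1/extra2, and only then publish it to the per-net field (proc_sctp_do_udp_port() additionally stops and, for a non-zero port, restarts the UDP encapsulation socket before updating the control socket). The stand-alone sketch below reduces that validate-then-commit pattern to plain user-space C; parse_and_set() and setting are hypothetical names invented for the example, not kernel interfaces.

/*
 * Illustrative only: the validate-then-commit pattern the sysctl handlers
 * follow, reduced to user-space C with hypothetical names.
 */
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>

static int setting = 4096;      /* the live value, analogous to net->sctp.* */

/* Parse text into a temporary, range-check it, and only then publish it. */
static int parse_and_set(const char *buf, int min, int max)
{
    char *end;
    long tmp = strtol(buf, &end, 10);

    if (end == buf || *end != '\0')
        return -EINVAL;         /* not a plain number */
    if (tmp < min || tmp > max)
        return -EINVAL;         /* reject before touching 'setting' */

    setting = (int)tmp;         /* commit only after validation */
    return 0;
}

int main(void)
{
    printf("set to 100: %d, value now %d\n",
           parse_and_set("100", 1, 65535), setting);
    printf("set to 99999: %d, value still %d\n",
           parse_and_set("99999", 1, 65535), setting);
    return 0;
}

A rejected write leaves the live value untouched, which is the property the handlers above rely on to keep rto_min/rto_max and the UDP port settings consistent.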
net->sctp.udp_port = new_value; + sctp_udp_sock_stop(net); + if (new_value) { + ret = sctp_udp_sock_start(net); + if (ret) + net->sctp.udp_port = 0; + } + + /* Update the value in the control socket */ + lock_sock(sk); + sctp_sk(sk)->udp_port = htons(net->sctp.udp_port); + release_sock(sk); + } + + return ret; +} + +static int proc_sctp_do_probe_interval(struct ctl_table *ctl, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct net *net = current->nsproxy->net_ns; + struct ctl_table tbl; + int ret, new_value; + + memset(&tbl, 0, sizeof(struct ctl_table)); + tbl.maxlen = sizeof(unsigned int); + + if (write) + tbl.data = &new_value; + else + tbl.data = &net->sctp.probe_interval; + + ret = proc_dointvec(&tbl, write, buffer, lenp, ppos); + if (write && ret == 0) { + if (new_value && new_value < SCTP_PROBE_TIMER_MIN) + return -EINVAL; + + net->sctp.probe_interval = new_value; + } + + return ret; +} + +int sctp_sysctl_net_register(struct net *net) +{ + struct ctl_table *table; + int i; + + table = kmemdup(sctp_net_table, sizeof(sctp_net_table), GFP_KERNEL); + if (!table) + return -ENOMEM; + + for (i = 0; table[i].data; i++) + table[i].data += (char *)(&net->sctp) - (char *)&init_net.sctp; + + table[SCTP_RTO_MIN_IDX].extra2 = &net->sctp.rto_max; + table[SCTP_RTO_MAX_IDX].extra1 = &net->sctp.rto_min; + table[SCTP_PF_RETRANS_IDX].extra2 = &net->sctp.ps_retrans; + table[SCTP_PS_RETRANS_IDX].extra1 = &net->sctp.pf_retrans; + + net->sctp.sysctl_header = register_net_sysctl(net, "net/sctp", table); + if (net->sctp.sysctl_header == NULL) { + kfree(table); + return -ENOMEM; + } + return 0; +} + +void sctp_sysctl_net_unregister(struct net *net) +{ + struct ctl_table *table; + + table = net->sctp.sysctl_header->ctl_table_arg; + unregister_net_sysctl_table(net->sctp.sysctl_header); + kfree(table); +} + +static struct ctl_table_header *sctp_sysctl_header; + +/* Sysctl registration. */ +void sctp_sysctl_register(void) +{ + sctp_sysctl_header = register_net_sysctl(&init_net, "net/sctp", sctp_table); +} + +/* Sysctl deregistration. */ +void sctp_sysctl_unregister(void) +{ + unregister_net_sysctl_table(sctp_sysctl_header); +} diff --git a/net/sctp/transport.c b/net/sctp/transport.c new file mode 100644 index 000000000..2990365c2 --- /dev/null +++ b/net/sctp/transport.c @@ -0,0 +1,854 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001-2003 International Business Machines Corp. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This file is part of the SCTP kernel implementation + * + * This module provides the abstraction for an SCTP transport representing + * a remote transport address. For local transport addresses, we just use + * union sctp_addr. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. Yarroll + * Karl Knutson + * Jon Grimm + * Xingang Guo + * Hui Huang + * Sridhar Samudrala + * Ardelle Fan + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include + +/* 1st Level Abstractions. */ + +/* Initialize a new transport from provided memory. */ +static struct sctp_transport *sctp_transport_init(struct net *net, + struct sctp_transport *peer, + const union sctp_addr *addr, + gfp_t gfp) +{ + /* Copy in the address. 
*/ + peer->af_specific = sctp_get_af_specific(addr->sa.sa_family); + memcpy(&peer->ipaddr, addr, peer->af_specific->sockaddr_len); + memset(&peer->saddr, 0, sizeof(union sctp_addr)); + + peer->sack_generation = 0; + + /* From 6.3.1 RTO Calculation: + * + * C1) Until an RTT measurement has been made for a packet sent to the + * given destination transport address, set RTO to the protocol + * parameter 'RTO.Initial'. + */ + peer->rto = msecs_to_jiffies(net->sctp.rto_initial); + + peer->last_time_heard = 0; + peer->last_time_ecne_reduced = jiffies; + + peer->param_flags = SPP_HB_DISABLE | + SPP_PMTUD_ENABLE | + SPP_SACKDELAY_ENABLE; + + /* Initialize the default path max_retrans. */ + peer->pathmaxrxt = net->sctp.max_retrans_path; + peer->pf_retrans = net->sctp.pf_retrans; + + INIT_LIST_HEAD(&peer->transmitted); + INIT_LIST_HEAD(&peer->send_ready); + INIT_LIST_HEAD(&peer->transports); + + timer_setup(&peer->T3_rtx_timer, sctp_generate_t3_rtx_event, 0); + timer_setup(&peer->hb_timer, sctp_generate_heartbeat_event, 0); + timer_setup(&peer->reconf_timer, sctp_generate_reconf_event, 0); + timer_setup(&peer->probe_timer, sctp_generate_probe_event, 0); + timer_setup(&peer->proto_unreach_timer, + sctp_generate_proto_unreach_event, 0); + + /* Initialize the 64-bit random nonce sent with heartbeat. */ + get_random_bytes(&peer->hb_nonce, sizeof(peer->hb_nonce)); + + refcount_set(&peer->refcnt, 1); + + return peer; +} + +/* Allocate and initialize a new transport. */ +struct sctp_transport *sctp_transport_new(struct net *net, + const union sctp_addr *addr, + gfp_t gfp) +{ + struct sctp_transport *transport; + + transport = kzalloc(sizeof(*transport), gfp); + if (!transport) + goto fail; + + if (!sctp_transport_init(net, transport, addr, gfp)) + goto fail_init; + + SCTP_DBG_OBJCNT_INC(transport); + + return transport; + +fail_init: + kfree(transport); + +fail: + return NULL; +} + +/* This transport is no longer needed. Free up if possible, or + * delay until it last reference count. + */ +void sctp_transport_free(struct sctp_transport *transport) +{ + /* Try to delete the heartbeat timer. */ + if (del_timer(&transport->hb_timer)) + sctp_transport_put(transport); + + /* Delete the T3_rtx timer if it's active. + * There is no point in not doing this now and letting + * structure hang around in memory since we know + * the transport is going away. + */ + if (del_timer(&transport->T3_rtx_timer)) + sctp_transport_put(transport); + + if (del_timer(&transport->reconf_timer)) + sctp_transport_put(transport); + + if (del_timer(&transport->probe_timer)) + sctp_transport_put(transport); + + /* Delete the ICMP proto unreachable timer if it's active. */ + if (del_timer(&transport->proto_unreach_timer)) + sctp_transport_put(transport); + + sctp_transport_put(transport); +} + +static void sctp_transport_destroy_rcu(struct rcu_head *head) +{ + struct sctp_transport *transport; + + transport = container_of(head, struct sctp_transport, rcu); + + dst_release(transport->dst); + kfree(transport); + SCTP_DBG_OBJCNT_DEC(transport); +} + +/* Destroy the transport data structure. + * Assumes there are no more users of this structure. 
+ */ +static void sctp_transport_destroy(struct sctp_transport *transport) +{ + if (unlikely(refcount_read(&transport->refcnt))) { + WARN(1, "Attempt to destroy undead transport %p!\n", transport); + return; + } + + sctp_packet_free(&transport->packet); + + if (transport->asoc) + sctp_association_put(transport->asoc); + + call_rcu(&transport->rcu, sctp_transport_destroy_rcu); +} + +/* Start T3_rtx timer if it is not already running and update the heartbeat + * timer. This routine is called every time a DATA chunk is sent. + */ +void sctp_transport_reset_t3_rtx(struct sctp_transport *transport) +{ + /* RFC 2960 6.3.2 Retransmission Timer Rules + * + * R1) Every time a DATA chunk is sent to any address(including a + * retransmission), if the T3-rtx timer of that address is not running + * start it running so that it will expire after the RTO of that + * address. + */ + + if (!timer_pending(&transport->T3_rtx_timer)) + if (!mod_timer(&transport->T3_rtx_timer, + jiffies + transport->rto)) + sctp_transport_hold(transport); +} + +void sctp_transport_reset_hb_timer(struct sctp_transport *transport) +{ + unsigned long expires; + + /* When a data chunk is sent, reset the heartbeat interval. */ + expires = jiffies + sctp_transport_timeout(transport); + if (!mod_timer(&transport->hb_timer, + expires + prandom_u32_max(transport->rto))) + sctp_transport_hold(transport); +} + +void sctp_transport_reset_reconf_timer(struct sctp_transport *transport) +{ + if (!timer_pending(&transport->reconf_timer)) + if (!mod_timer(&transport->reconf_timer, + jiffies + transport->rto)) + sctp_transport_hold(transport); +} + +void sctp_transport_reset_probe_timer(struct sctp_transport *transport) +{ + if (!mod_timer(&transport->probe_timer, + jiffies + transport->probe_interval)) + sctp_transport_hold(transport); +} + +void sctp_transport_reset_raise_timer(struct sctp_transport *transport) +{ + if (!mod_timer(&transport->probe_timer, + jiffies + transport->probe_interval * 30)) + sctp_transport_hold(transport); +} + +/* This transport has been assigned to an association. + * Initialize fields from the association or from the sock itself. + * Register the reference count in the association. + */ +void sctp_transport_set_owner(struct sctp_transport *transport, + struct sctp_association *asoc) +{ + transport->asoc = asoc; + sctp_association_hold(asoc); +} + +/* Initialize the pmtu of a transport. 
*/ +void sctp_transport_pmtu(struct sctp_transport *transport, struct sock *sk) +{ + /* If we don't have a fresh route, look one up */ + if (!transport->dst || transport->dst->obsolete) { + sctp_transport_dst_release(transport); + transport->af_specific->get_dst(transport, &transport->saddr, + &transport->fl, sk); + } + + if (transport->param_flags & SPP_PMTUD_DISABLE) { + struct sctp_association *asoc = transport->asoc; + + if (!transport->pathmtu && asoc && asoc->pathmtu) + transport->pathmtu = asoc->pathmtu; + if (transport->pathmtu) + return; + } + + if (transport->dst) + transport->pathmtu = sctp_dst_mtu(transport->dst); + else + transport->pathmtu = SCTP_DEFAULT_MAXSEGMENT; + + sctp_transport_pl_update(transport); +} + +void sctp_transport_pl_send(struct sctp_transport *t) +{ + if (t->pl.probe_count < SCTP_MAX_PROBES) + goto out; + + t->pl.probe_count = 0; + if (t->pl.state == SCTP_PL_BASE) { + if (t->pl.probe_size == SCTP_BASE_PLPMTU) { /* BASE_PLPMTU Confirmation Failed */ + t->pl.state = SCTP_PL_ERROR; /* Base -> Error */ + + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_assoc_sync_pmtu(t->asoc); + } + } else if (t->pl.state == SCTP_PL_SEARCH) { + if (t->pl.pmtu == t->pl.probe_size) { /* Black Hole Detected */ + t->pl.state = SCTP_PL_BASE; /* Search -> Base */ + t->pl.probe_size = SCTP_BASE_PLPMTU; + t->pl.probe_high = 0; + + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_assoc_sync_pmtu(t->asoc); + } else { /* Normal probe failure. */ + t->pl.probe_high = t->pl.probe_size; + t->pl.probe_size = t->pl.pmtu; + } + } else if (t->pl.state == SCTP_PL_COMPLETE) { + if (t->pl.pmtu == t->pl.probe_size) { /* Black Hole Detected */ + t->pl.state = SCTP_PL_BASE; /* Search Complete -> Base */ + t->pl.probe_size = SCTP_BASE_PLPMTU; + + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_assoc_sync_pmtu(t->asoc); + } + } + +out: + pr_debug("%s: PLPMTUD: transport: %p, state: %d, pmtu: %d, size: %d, high: %d\n", + __func__, t, t->pl.state, t->pl.pmtu, t->pl.probe_size, t->pl.probe_high); + t->pl.probe_count++; +} + +bool sctp_transport_pl_recv(struct sctp_transport *t) +{ + pr_debug("%s: PLPMTUD: transport: %p, state: %d, pmtu: %d, size: %d, high: %d\n", + __func__, t, t->pl.state, t->pl.pmtu, t->pl.probe_size, t->pl.probe_high); + + t->pl.pmtu = t->pl.probe_size; + t->pl.probe_count = 0; + if (t->pl.state == SCTP_PL_BASE) { + t->pl.state = SCTP_PL_SEARCH; /* Base -> Search */ + t->pl.probe_size += SCTP_PL_BIG_STEP; + } else if (t->pl.state == SCTP_PL_ERROR) { + t->pl.state = SCTP_PL_SEARCH; /* Error -> Search */ + + t->pl.pmtu = t->pl.probe_size; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_assoc_sync_pmtu(t->asoc); + t->pl.probe_size += SCTP_PL_BIG_STEP; + } else if (t->pl.state == SCTP_PL_SEARCH) { + if (!t->pl.probe_high) { + if (t->pl.probe_size < SCTP_MAX_PLPMTU) { + t->pl.probe_size = min(t->pl.probe_size + SCTP_PL_BIG_STEP, + SCTP_MAX_PLPMTU); + return false; + } + t->pl.probe_high = SCTP_MAX_PLPMTU; + } + t->pl.probe_size += SCTP_PL_MIN_STEP; + if (t->pl.probe_size >= t->pl.probe_high) { + t->pl.probe_high = 0; + t->pl.state = SCTP_PL_COMPLETE; /* Search -> Search Complete */ + + t->pl.probe_size = t->pl.pmtu; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_assoc_sync_pmtu(t->asoc); + sctp_transport_reset_raise_timer(t); + } + } else if (t->pl.state == SCTP_PL_COMPLETE) { + /* Raise probe_size again after 30 * interval in 
Search Complete */ + t->pl.state = SCTP_PL_SEARCH; /* Search Complete -> Search */ + t->pl.probe_size = min(t->pl.probe_size + SCTP_PL_MIN_STEP, SCTP_MAX_PLPMTU); + } + + return t->pl.state == SCTP_PL_COMPLETE; +} + +static bool sctp_transport_pl_toobig(struct sctp_transport *t, u32 pmtu) +{ + pr_debug("%s: PLPMTUD: transport: %p, state: %d, pmtu: %d, size: %d, ptb: %d\n", + __func__, t, t->pl.state, t->pl.pmtu, t->pl.probe_size, pmtu); + + if (pmtu < SCTP_MIN_PLPMTU || pmtu >= t->pl.probe_size) + return false; + + if (t->pl.state == SCTP_PL_BASE) { + if (pmtu >= SCTP_MIN_PLPMTU && pmtu < SCTP_BASE_PLPMTU) { + t->pl.state = SCTP_PL_ERROR; /* Base -> Error */ + + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + return true; + } + } else if (t->pl.state == SCTP_PL_SEARCH) { + if (pmtu >= SCTP_BASE_PLPMTU && pmtu < t->pl.pmtu) { + t->pl.state = SCTP_PL_BASE; /* Search -> Base */ + t->pl.probe_size = SCTP_BASE_PLPMTU; + t->pl.probe_count = 0; + + t->pl.probe_high = 0; + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + return true; + } else if (pmtu > t->pl.pmtu && pmtu < t->pl.probe_size) { + t->pl.probe_size = pmtu; + t->pl.probe_count = 0; + } + } else if (t->pl.state == SCTP_PL_COMPLETE) { + if (pmtu >= SCTP_BASE_PLPMTU && pmtu < t->pl.pmtu) { + t->pl.state = SCTP_PL_BASE; /* Complete -> Base */ + t->pl.probe_size = SCTP_BASE_PLPMTU; + t->pl.probe_count = 0; + + t->pl.probe_high = 0; + t->pl.pmtu = SCTP_BASE_PLPMTU; + t->pathmtu = t->pl.pmtu + sctp_transport_pl_hlen(t); + sctp_transport_reset_probe_timer(t); + return true; + } + } + + return false; +} + +bool sctp_transport_update_pmtu(struct sctp_transport *t, u32 pmtu) +{ + struct sock *sk = t->asoc->base.sk; + struct dst_entry *dst; + bool change = true; + + if (unlikely(pmtu < SCTP_DEFAULT_MINSEGMENT)) { + pr_warn_ratelimited("%s: Reported pmtu %d too low, using default minimum of %d\n", + __func__, pmtu, SCTP_DEFAULT_MINSEGMENT); + /* Use default minimum segment instead */ + pmtu = SCTP_DEFAULT_MINSEGMENT; + } + pmtu = SCTP_TRUNC4(pmtu); + + if (sctp_transport_pl_enabled(t)) + return sctp_transport_pl_toobig(t, pmtu - sctp_transport_pl_hlen(t)); + + dst = sctp_transport_dst_check(t); + if (dst) { + struct sctp_pf *pf = sctp_get_pf_specific(dst->ops->family); + union sctp_addr addr; + + pf->af->from_sk(&addr, sk); + pf->to_sk_daddr(&t->ipaddr, sk); + dst->ops->update_pmtu(dst, sk, NULL, pmtu, true); + pf->to_sk_daddr(&addr, sk); + + dst = sctp_transport_dst_check(t); + } + + if (!dst) { + t->af_specific->get_dst(t, &t->saddr, &t->fl, sk); + dst = t->dst; + } + + if (dst) { + /* Re-fetch, as under layers may have a higher minimum size */ + pmtu = sctp_dst_mtu(dst); + change = t->pathmtu != pmtu; + } + t->pathmtu = pmtu; + + return change; +} + +/* Caches the dst entry and source address for a transport's destination + * address. + */ +void sctp_transport_route(struct sctp_transport *transport, + union sctp_addr *saddr, struct sctp_sock *opt) +{ + struct sctp_association *asoc = transport->asoc; + struct sctp_af *af = transport->af_specific; + + sctp_transport_dst_release(transport); + af->get_dst(transport, saddr, &transport->fl, sctp_opt2sk(opt)); + + if (saddr) + memcpy(&transport->saddr, saddr, sizeof(union sctp_addr)); + else + af->get_saddr(opt, transport, &transport->fl); + + sctp_transport_pmtu(transport, sctp_opt2sk(opt)); + + /* Initialize sk->sk_rcv_saddr, if the transport is the + * association's active path for getsockname(). 
+ */ + if (transport->dst && asoc && + (!asoc->peer.primary_path || transport == asoc->peer.active_path)) + opt->pf->to_sk_saddr(&transport->saddr, asoc->base.sk); +} + +/* Hold a reference to a transport. */ +int sctp_transport_hold(struct sctp_transport *transport) +{ + return refcount_inc_not_zero(&transport->refcnt); +} + +/* Release a reference to a transport and clean up + * if there are no more references. + */ +void sctp_transport_put(struct sctp_transport *transport) +{ + if (refcount_dec_and_test(&transport->refcnt)) + sctp_transport_destroy(transport); +} + +/* Update transport's RTO based on the newly calculated RTT. */ +void sctp_transport_update_rto(struct sctp_transport *tp, __u32 rtt) +{ + if (unlikely(!tp->rto_pending)) + /* We should not be doing any RTO updates unless rto_pending is set. */ + pr_debug("%s: rto_pending not set on transport %p!\n", __func__, tp); + + if (tp->rttvar || tp->srtt) { + struct net *net = tp->asoc->base.net; + /* 6.3.1 C3) When a new RTT measurement R' is made, set + * RTTVAR <- (1 - RTO.Beta) * RTTVAR + RTO.Beta * |SRTT - R'| + * SRTT <- (1 - RTO.Alpha) * SRTT + RTO.Alpha * R' + */ + + /* Note: The above algorithm has been rewritten to + * express rto_beta and rto_alpha as inverse powers + * of two. + * For example, assuming the default value of RTO.Alpha of + * 1/8, rto_alpha would be expressed as 3. + */ + tp->rttvar = tp->rttvar - (tp->rttvar >> net->sctp.rto_beta) + + (((__u32)abs((__s64)tp->srtt - (__s64)rtt)) >> net->sctp.rto_beta); + tp->srtt = tp->srtt - (tp->srtt >> net->sctp.rto_alpha) + + (rtt >> net->sctp.rto_alpha); + } else { + /* 6.3.1 C2) When the first RTT measurement R is made, set + * SRTT <- R, RTTVAR <- R/2. + */ + tp->srtt = rtt; + tp->rttvar = rtt >> 1; + } + + /* 6.3.1 G1) Whenever RTTVAR is computed, if RTTVAR = 0, then + * adjust RTTVAR <- G, where G is the CLOCK GRANULARITY. + */ + if (tp->rttvar == 0) + tp->rttvar = SCTP_CLOCK_GRANULARITY; + + /* 6.3.1 C3) After the computation, update RTO <- SRTT + 4 * RTTVAR. */ + tp->rto = tp->srtt + (tp->rttvar << 2); + + /* 6.3.1 C6) Whenever RTO is computed, if it is less than RTO.Min + * seconds then it is rounded up to RTO.Min seconds. + */ + if (tp->rto < tp->asoc->rto_min) + tp->rto = tp->asoc->rto_min; + + /* 6.3.1 C7) A maximum value may be placed on RTO provided it is + * at least RTO.max seconds. + */ + if (tp->rto > tp->asoc->rto_max) + tp->rto = tp->asoc->rto_max; + + sctp_max_rto(tp->asoc, tp); + tp->rtt = rtt; + + /* Reset rto_pending so that a new RTT measurement is started when a + * new data chunk is sent. + */ + tp->rto_pending = 0; + + pr_debug("%s: transport:%p, rtt:%d, srtt:%d rttvar:%d, rto:%ld\n", + __func__, tp, rtt, tp->srtt, tp->rttvar, tp->rto); +} + +/* This routine updates the transport's cwnd and partial_bytes_acked + * parameters based on the bytes acked in the received SACK. 
+ */ +void sctp_transport_raise_cwnd(struct sctp_transport *transport, + __u32 sack_ctsn, __u32 bytes_acked) +{ + struct sctp_association *asoc = transport->asoc; + __u32 cwnd, ssthresh, flight_size, pba, pmtu; + + cwnd = transport->cwnd; + flight_size = transport->flight_size; + + /* See if we need to exit Fast Recovery first */ + if (asoc->fast_recovery && + TSN_lte(asoc->fast_recovery_exit, sack_ctsn)) + asoc->fast_recovery = 0; + + ssthresh = transport->ssthresh; + pba = transport->partial_bytes_acked; + pmtu = transport->asoc->pathmtu; + + if (cwnd <= ssthresh) { + /* RFC 4960 7.2.1 + * o When cwnd is less than or equal to ssthresh, an SCTP + * endpoint MUST use the slow-start algorithm to increase + * cwnd only if the current congestion window is being fully + * utilized, an incoming SACK advances the Cumulative TSN + * Ack Point, and the data sender is not in Fast Recovery. + * Only when these three conditions are met can the cwnd be + * increased; otherwise, the cwnd MUST not be increased. + * If these conditions are met, then cwnd MUST be increased + * by, at most, the lesser of 1) the total size of the + * previously outstanding DATA chunk(s) acknowledged, and + * 2) the destination's path MTU. This upper bound protects + * against the ACK-Splitting attack outlined in [SAVAGE99]. + */ + if (asoc->fast_recovery) + return; + + /* The appropriate cwnd increase algorithm is performed + * if, and only if the congestion window is being fully + * utilized. Note that RFC4960 Errata 3.22 removed the + * other condition on ctsn moving. + */ + if (flight_size < cwnd) + return; + + if (bytes_acked > pmtu) + cwnd += pmtu; + else + cwnd += bytes_acked; + + pr_debug("%s: slow start: transport:%p, bytes_acked:%d, " + "cwnd:%d, ssthresh:%d, flight_size:%d, pba:%d\n", + __func__, transport, bytes_acked, cwnd, ssthresh, + flight_size, pba); + } else { + /* RFC 2960 7.2.2 Whenever cwnd is greater than ssthresh, + * upon each SACK arrival, increase partial_bytes_acked + * by the total number of bytes of all new chunks + * acknowledged in that SACK including chunks + * acknowledged by the new Cumulative TSN Ack and by Gap + * Ack Blocks. (updated by RFC4960 Errata 3.22) + * + * When partial_bytes_acked is greater than cwnd and + * before the arrival of the SACK the sender had less + * bytes of data outstanding than cwnd (i.e., before + * arrival of the SACK, flightsize was less than cwnd), + * reset partial_bytes_acked to cwnd. (RFC 4960 Errata + * 3.26) + * + * When partial_bytes_acked is equal to or greater than + * cwnd and before the arrival of the SACK the sender + * had cwnd or more bytes of data outstanding (i.e., + * before arrival of the SACK, flightsize was greater + * than or equal to cwnd), partial_bytes_acked is reset + * to (partial_bytes_acked - cwnd). Next, cwnd is + * increased by MTU. (RFC 4960 Errata 3.12) + */ + pba += bytes_acked; + if (pba > cwnd && flight_size < cwnd) + pba = cwnd; + if (pba >= cwnd && flight_size >= cwnd) { + pba = pba - cwnd; + cwnd += pmtu; + } + + pr_debug("%s: congestion avoidance: transport:%p, " + "bytes_acked:%d, cwnd:%d, ssthresh:%d, " + "flight_size:%d, pba:%d\n", __func__, + transport, bytes_acked, cwnd, ssthresh, + flight_size, pba); + } + + transport->cwnd = cwnd; + transport->partial_bytes_acked = pba; +} + +/* This routine is used to lower the transport's cwnd when congestion is + * detected. 
+ */ +void sctp_transport_lower_cwnd(struct sctp_transport *transport, + enum sctp_lower_cwnd reason) +{ + struct sctp_association *asoc = transport->asoc; + + switch (reason) { + case SCTP_LOWER_CWND_T3_RTX: + /* RFC 2960 Section 7.2.3, sctpimpguide + * When the T3-rtx timer expires on an address, SCTP should + * perform slow start by: + * ssthresh = max(cwnd/2, 4*MTU) + * cwnd = 1*MTU + * partial_bytes_acked = 0 + */ + transport->ssthresh = max(transport->cwnd/2, + 4*asoc->pathmtu); + transport->cwnd = asoc->pathmtu; + + /* T3-rtx also clears fast recovery */ + asoc->fast_recovery = 0; + break; + + case SCTP_LOWER_CWND_FAST_RTX: + /* RFC 2960 7.2.4 Adjust the ssthresh and cwnd of the + * destination address(es) to which the missing DATA chunks + * were last sent, according to the formula described in + * Section 7.2.3. + * + * RFC 2960 7.2.3, sctpimpguide Upon detection of packet + * losses from SACK (see Section 7.2.4), An endpoint + * should do the following: + * ssthresh = max(cwnd/2, 4*MTU) + * cwnd = ssthresh + * partial_bytes_acked = 0 + */ + if (asoc->fast_recovery) + return; + + /* Mark Fast recovery */ + asoc->fast_recovery = 1; + asoc->fast_recovery_exit = asoc->next_tsn - 1; + + transport->ssthresh = max(transport->cwnd/2, + 4*asoc->pathmtu); + transport->cwnd = transport->ssthresh; + break; + + case SCTP_LOWER_CWND_ECNE: + /* RFC 2481 Section 6.1.2. + * If the sender receives an ECN-Echo ACK packet + * then the sender knows that congestion was encountered in the + * network on the path from the sender to the receiver. The + * indication of congestion should be treated just as a + * congestion loss in non-ECN Capable TCP. That is, the TCP + * source halves the congestion window "cwnd" and reduces the + * slow start threshold "ssthresh". + * A critical condition is that TCP does not react to + * congestion indications more than once every window of + * data (or more loosely more than once every round-trip time). + */ + if (time_after(jiffies, transport->last_time_ecne_reduced + + transport->rtt)) { + transport->ssthresh = max(transport->cwnd/2, + 4*asoc->pathmtu); + transport->cwnd = transport->ssthresh; + transport->last_time_ecne_reduced = jiffies; + } + break; + + case SCTP_LOWER_CWND_INACTIVE: + /* RFC 2960 Section 7.2.1, sctpimpguide + * When the endpoint does not transmit data on a given + * transport address, the cwnd of the transport address + * should be adjusted to max(cwnd/2, 4*MTU) per RTO. + * NOTE: Although the draft recommends that this check needs + * to be done every RTO interval, we do it every hearbeat + * interval. + */ + transport->cwnd = max(transport->cwnd/2, + 4*asoc->pathmtu); + /* RFC 4960 Errata 3.27.2: also adjust sshthresh */ + transport->ssthresh = transport->cwnd; + break; + } + + transport->partial_bytes_acked = 0; + + pr_debug("%s: transport:%p, reason:%d, cwnd:%d, ssthresh:%d\n", + __func__, transport, reason, transport->cwnd, + transport->ssthresh); +} + +/* Apply Max.Burst limit to the congestion window: + * sctpimpguide-05 2.14.2 + * D) When the time comes for the sender to + * transmit new DATA chunks, the protocol parameter Max.Burst MUST + * first be applied to limit how many new DATA chunks may be sent. 
+ * The limit is applied by adjusting cwnd as follows: + * if ((flightsize+ Max.Burst * MTU) < cwnd) + * cwnd = flightsize + Max.Burst * MTU + */ + +void sctp_transport_burst_limited(struct sctp_transport *t) +{ + struct sctp_association *asoc = t->asoc; + u32 old_cwnd = t->cwnd; + u32 max_burst_bytes; + + if (t->burst_limited || asoc->max_burst == 0) + return; + + max_burst_bytes = t->flight_size + (asoc->max_burst * asoc->pathmtu); + if (max_burst_bytes < old_cwnd) { + t->cwnd = max_burst_bytes; + t->burst_limited = old_cwnd; + } +} + +/* Restore the old cwnd congestion window, after the burst had it's + * desired effect. + */ +void sctp_transport_burst_reset(struct sctp_transport *t) +{ + if (t->burst_limited) { + t->cwnd = t->burst_limited; + t->burst_limited = 0; + } +} + +/* What is the next timeout value for this transport? */ +unsigned long sctp_transport_timeout(struct sctp_transport *trans) +{ + /* RTO + timer slack +/- 50% of RTO */ + unsigned long timeout = trans->rto >> 1; + + if (trans->state != SCTP_UNCONFIRMED && + trans->state != SCTP_PF) + timeout += trans->hbinterval; + + return max_t(unsigned long, timeout, HZ / 5); +} + +/* Reset transport variables to their initial values */ +void sctp_transport_reset(struct sctp_transport *t) +{ + struct sctp_association *asoc = t->asoc; + + /* RFC 2960 (bis), Section 5.2.4 + * All the congestion control parameters (e.g., cwnd, ssthresh) + * related to this peer MUST be reset to their initial values + * (see Section 6.2.1) + */ + t->cwnd = min(4*asoc->pathmtu, max_t(__u32, 2*asoc->pathmtu, 4380)); + t->burst_limited = 0; + t->ssthresh = asoc->peer.i.a_rwnd; + t->rto = asoc->rto_initial; + sctp_max_rto(asoc, t); + t->rtt = 0; + t->srtt = 0; + t->rttvar = 0; + + /* Reset these additional variables so that we have a clean slate. */ + t->partial_bytes_acked = 0; + t->flight_size = 0; + t->error_count = 0; + t->rto_pending = 0; + t->hb_sent = 0; + + /* Initialize the state information for SFR-CACC */ + t->cacc.changeover_active = 0; + t->cacc.cycling_changeover = 0; + t->cacc.next_tsn_at_change = 0; + t->cacc.cacc_saw_newack = 0; +} + +/* Schedule retransmission on the given transport */ +void sctp_transport_immediate_rtx(struct sctp_transport *t) +{ + /* Stop pending T3_rtx_timer */ + if (del_timer(&t->T3_rtx_timer)) + sctp_transport_put(t); + + sctp_retransmit(&t->asoc->outqueue, t, SCTP_RTXR_T3_RTX); + if (!timer_pending(&t->T3_rtx_timer)) { + if (!mod_timer(&t->T3_rtx_timer, jiffies + t->rto)) + sctp_transport_hold(t); + } +} + +/* Drop dst */ +void sctp_transport_dst_release(struct sctp_transport *t) +{ + dst_release(t->dst); + t->dst = NULL; + t->dst_pending_confirm = 0; +} + +/* Schedule neighbour confirm */ +void sctp_transport_dst_confirm(struct sctp_transport *t) +{ + t->dst_pending_confirm = 1; +} diff --git a/net/sctp/tsnmap.c b/net/sctp/tsnmap.c new file mode 100644 index 000000000..5ba456727 --- /dev/null +++ b/net/sctp/tsnmap.c @@ -0,0 +1,364 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. + * + * This file is part of the SCTP kernel implementation + * + * These functions manipulate sctp tsn mapping array. + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * La Monte H.P. 
Yarroll + * Jon Grimm + * Karl Knutson + * Sridhar Samudrala + */ + +#include +#include +#include +#include +#include + +static void sctp_tsnmap_update(struct sctp_tsnmap *map); +static void sctp_tsnmap_find_gap_ack(unsigned long *map, __u16 off, + __u16 len, __u16 *start, __u16 *end); +static int sctp_tsnmap_grow(struct sctp_tsnmap *map, u16 size); + +/* Initialize a block of memory as a tsnmap. */ +struct sctp_tsnmap *sctp_tsnmap_init(struct sctp_tsnmap *map, __u16 len, + __u32 initial_tsn, gfp_t gfp) +{ + if (!map->tsn_map) { + map->tsn_map = kzalloc(len>>3, gfp); + if (map->tsn_map == NULL) + return NULL; + + map->len = len; + } else { + bitmap_zero(map->tsn_map, map->len); + } + + /* Keep track of TSNs represented by tsn_map. */ + map->base_tsn = initial_tsn; + map->cumulative_tsn_ack_point = initial_tsn - 1; + map->max_tsn_seen = map->cumulative_tsn_ack_point; + map->num_dup_tsns = 0; + + return map; +} + +void sctp_tsnmap_free(struct sctp_tsnmap *map) +{ + map->len = 0; + kfree(map->tsn_map); +} + +/* Test the tracking state of this TSN. + * Returns: + * 0 if the TSN has not yet been seen + * >0 if the TSN has been seen (duplicate) + * <0 if the TSN is invalid (too large to track) + */ +int sctp_tsnmap_check(const struct sctp_tsnmap *map, __u32 tsn) +{ + u32 gap; + + /* Check to see if this is an old TSN */ + if (TSN_lte(tsn, map->cumulative_tsn_ack_point)) + return 1; + + /* Verify that we can hold this TSN and that it will not + * overflow our map + */ + if (!TSN_lt(tsn, map->base_tsn + SCTP_TSN_MAP_SIZE)) + return -1; + + /* Calculate the index into the mapping arrays. */ + gap = tsn - map->base_tsn; + + /* Check to see if TSN has already been recorded. */ + if (gap < map->len && test_bit(gap, map->tsn_map)) + return 1; + else + return 0; +} + + +/* Mark this TSN as seen. */ +int sctp_tsnmap_mark(struct sctp_tsnmap *map, __u32 tsn, + struct sctp_transport *trans) +{ + u16 gap; + + if (TSN_lt(tsn, map->base_tsn)) + return 0; + + gap = tsn - map->base_tsn; + + if (gap >= map->len && !sctp_tsnmap_grow(map, gap + 1)) + return -ENOMEM; + + if (!sctp_tsnmap_has_gap(map) && gap == 0) { + /* In this case the map has no gaps and the tsn we are + * recording is the next expected tsn. We don't touch + * the map but simply bump the values. + */ + map->max_tsn_seen++; + map->cumulative_tsn_ack_point++; + if (trans) + trans->sack_generation = + trans->asoc->peer.sack_generation; + map->base_tsn++; + } else { + /* Either we already have a gap, or about to record a gap, so + * have work to do. + * + * Bump the max. + */ + if (TSN_lt(map->max_tsn_seen, tsn)) + map->max_tsn_seen = tsn; + + /* Mark the TSN as received. */ + set_bit(gap, map->tsn_map); + + /* Go fixup any internal TSN mapping variables including + * cumulative_tsn_ack_point. + */ + sctp_tsnmap_update(map); + } + + return 0; +} + + +/* Initialize a Gap Ack Block iterator from memory being provided. */ +static void sctp_tsnmap_iter_init(const struct sctp_tsnmap *map, + struct sctp_tsnmap_iter *iter) +{ + /* Only start looking one past the Cumulative TSN Ack Point. */ + iter->start = map->cumulative_tsn_ack_point + 1; +} + +/* Get the next Gap Ack Blocks. Returns 0 if there was not another block + * to get. + */ +static int sctp_tsnmap_next_gap_ack(const struct sctp_tsnmap *map, + struct sctp_tsnmap_iter *iter, + __u16 *start, __u16 *end) +{ + int ended = 0; + __u16 start_ = 0, end_ = 0, offset; + + /* If there are no more gap acks possible, get out fast. 
*/ + if (TSN_lte(map->max_tsn_seen, iter->start)) + return 0; + + offset = iter->start - map->base_tsn; + sctp_tsnmap_find_gap_ack(map->tsn_map, offset, map->len, + &start_, &end_); + + /* The Gap Ack Block happens to end at the end of the map. */ + if (start_ && !end_) + end_ = map->len - 1; + + /* If we found a Gap Ack Block, return the start and end and + * bump the iterator forward. + */ + if (end_) { + /* Fix up the start and end based on the + * Cumulative TSN Ack which is always 1 behind base. + */ + *start = start_ + 1; + *end = end_ + 1; + + /* Move the iterator forward. */ + iter->start = map->cumulative_tsn_ack_point + *end + 1; + ended = 1; + } + + return ended; +} + +/* Mark this and any lower TSN as seen. */ +void sctp_tsnmap_skip(struct sctp_tsnmap *map, __u32 tsn) +{ + u32 gap; + + if (TSN_lt(tsn, map->base_tsn)) + return; + if (!TSN_lt(tsn, map->base_tsn + SCTP_TSN_MAP_SIZE)) + return; + + /* Bump the max. */ + if (TSN_lt(map->max_tsn_seen, tsn)) + map->max_tsn_seen = tsn; + + gap = tsn - map->base_tsn + 1; + + map->base_tsn += gap; + map->cumulative_tsn_ack_point += gap; + if (gap >= map->len) { + /* If our gap is larger then the map size, just + * zero out the map. + */ + bitmap_zero(map->tsn_map, map->len); + } else { + /* If the gap is smaller than the map size, + * shift the map by 'gap' bits and update further. + */ + bitmap_shift_right(map->tsn_map, map->tsn_map, gap, map->len); + sctp_tsnmap_update(map); + } +} + +/******************************************************************** + * 2nd Level Abstractions + ********************************************************************/ + +/* This private helper function updates the tsnmap buffers and + * the Cumulative TSN Ack Point. + */ +static void sctp_tsnmap_update(struct sctp_tsnmap *map) +{ + u16 len; + unsigned long zero_bit; + + + len = map->max_tsn_seen - map->cumulative_tsn_ack_point; + zero_bit = find_first_zero_bit(map->tsn_map, len); + if (!zero_bit) + return; /* The first 0-bit is bit 0. nothing to do */ + + map->base_tsn += zero_bit; + map->cumulative_tsn_ack_point += zero_bit; + + bitmap_shift_right(map->tsn_map, map->tsn_map, zero_bit, map->len); +} + +/* How many data chunks are we missing from our peer? + */ +__u16 sctp_tsnmap_pending(struct sctp_tsnmap *map) +{ + __u32 cum_tsn = map->cumulative_tsn_ack_point; + __u32 max_tsn = map->max_tsn_seen; + __u32 base_tsn = map->base_tsn; + __u16 pending_data; + u32 gap; + + pending_data = max_tsn - cum_tsn; + gap = max_tsn - base_tsn; + + if (gap == 0 || gap >= map->len) + goto out; + + pending_data -= bitmap_weight(map->tsn_map, gap + 1); +out: + return pending_data; +} + +/* This is a private helper for finding Gap Ack Blocks. It searches a + * single array for the start and end of a Gap Ack Block. + * + * The flags "started" and "ended" tell is if we found the beginning + * or (respectively) the end of a Gap Ack Block. + */ +static void sctp_tsnmap_find_gap_ack(unsigned long *map, __u16 off, + __u16 len, __u16 *start, __u16 *end) +{ + int i = off; + + /* Look through the entire array, but break out + * early if we have found the end of the Gap Ack Block. + */ + + /* Also, stop looking past the maximum TSN seen. */ + + /* Look for the start. */ + i = find_next_bit(map, len, off); + if (i < len) + *start = i; + + /* Look for the end. */ + if (*start) { + /* We have found the start, let's find the + * end. If we find the end, break out. + */ + i = find_next_zero_bit(map, len, i); + if (i < len) + *end = i - 1; + } +} + +/* Renege that we have seen a TSN. 
*/ +void sctp_tsnmap_renege(struct sctp_tsnmap *map, __u32 tsn) +{ + u32 gap; + + if (TSN_lt(tsn, map->base_tsn)) + return; + /* Assert: TSN is in range. */ + if (!TSN_lt(tsn, map->base_tsn + map->len)) + return; + + gap = tsn - map->base_tsn; + + /* Pretend we never saw the TSN. */ + clear_bit(gap, map->tsn_map); +} + +/* How many gap ack blocks do we have recorded? */ +__u16 sctp_tsnmap_num_gabs(struct sctp_tsnmap *map, + struct sctp_gap_ack_block *gabs) +{ + struct sctp_tsnmap_iter iter; + int ngaps = 0; + + /* Refresh the gap ack information. */ + if (sctp_tsnmap_has_gap(map)) { + __u16 start = 0, end = 0; + sctp_tsnmap_iter_init(map, &iter); + while (sctp_tsnmap_next_gap_ack(map, &iter, + &start, + &end)) { + + gabs[ngaps].start = htons(start); + gabs[ngaps].end = htons(end); + ngaps++; + if (ngaps >= SCTP_MAX_GABS) + break; + } + } + return ngaps; +} + +static int sctp_tsnmap_grow(struct sctp_tsnmap *map, u16 size) +{ + unsigned long *new; + unsigned long inc; + u16 len; + + if (size > SCTP_TSN_MAP_SIZE) + return 0; + + inc = ALIGN((size - map->len), BITS_PER_LONG) + SCTP_TSN_MAP_INCREMENT; + len = min_t(u16, map->len + inc, SCTP_TSN_MAP_SIZE); + + new = kzalloc(len>>3, GFP_ATOMIC); + if (!new) + return 0; + + bitmap_copy(new, map->tsn_map, + map->max_tsn_seen - map->cumulative_tsn_ack_point); + kfree(map->tsn_map); + map->tsn_map = new; + map->len = len; + + return 1; +} diff --git a/net/sctp/ulpevent.c b/net/sctp/ulpevent.c new file mode 100644 index 000000000..8920ca92a --- /dev/null +++ b/net/sctp/ulpevent.c @@ -0,0 +1,1190 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * These functions manipulate an sctp event. The struct ulpevent is used + * to carry notifications and data to the ULP (sockets). + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Jon Grimm + * La Monte H.P. Yarroll + * Ardelle Fan + * Sridhar Samudrala + */ + +#include +#include +#include +#include +#include +#include + +static void sctp_ulpevent_receive_data(struct sctp_ulpevent *event, + struct sctp_association *asoc); +static void sctp_ulpevent_release_data(struct sctp_ulpevent *event); +static void sctp_ulpevent_release_frag_data(struct sctp_ulpevent *event); + + +/* Initialize an ULP event from an given skb. */ +static void sctp_ulpevent_init(struct sctp_ulpevent *event, + __u16 msg_flags, + unsigned int len) +{ + memset(event, 0, sizeof(struct sctp_ulpevent)); + event->msg_flags = msg_flags; + event->rmem_len = len; +} + +/* Create a new sctp_ulpevent. */ +static struct sctp_ulpevent *sctp_ulpevent_new(int size, __u16 msg_flags, + gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sk_buff *skb; + + skb = alloc_skb(size, gfp); + if (!skb) + goto fail; + + event = sctp_skb2event(skb); + sctp_ulpevent_init(event, msg_flags, skb->truesize); + + return event; + +fail: + return NULL; +} + +/* Is this a MSG_NOTIFICATION? */ +int sctp_ulpevent_is_notification(const struct sctp_ulpevent *event) +{ + return MSG_NOTIFICATION == (event->msg_flags & MSG_NOTIFICATION); +} + +/* Hold the association in case the msg_name needs read out of + * the association. 
+ */ +static inline void sctp_ulpevent_set_owner(struct sctp_ulpevent *event, + const struct sctp_association *asoc) +{ + struct sctp_chunk *chunk = event->chunk; + struct sk_buff *skb; + + /* Cast away the const, as we are just wanting to + * bump the reference count. + */ + sctp_association_hold((struct sctp_association *)asoc); + skb = sctp_event2skb(event); + event->asoc = (struct sctp_association *)asoc; + atomic_add(event->rmem_len, &event->asoc->rmem_alloc); + sctp_skb_set_owner_r(skb, asoc->base.sk); + if (chunk && chunk->head_skb && !chunk->head_skb->sk) + chunk->head_skb->sk = asoc->base.sk; +} + +/* A simple destructor to give up the reference to the association. */ +static inline void sctp_ulpevent_release_owner(struct sctp_ulpevent *event) +{ + struct sctp_association *asoc = event->asoc; + + atomic_sub(event->rmem_len, &asoc->rmem_alloc); + sctp_association_put(asoc); +} + +/* Create and initialize an SCTP_ASSOC_CHANGE event. + * + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * Communication notifications inform the ULP that an SCTP association + * has either begun or ended. The identifier for a new association is + * provided by this notification. + * + * Note: There is no field checking here. If a field is unused it will be + * zero'd out. + */ +struct sctp_ulpevent *sctp_ulpevent_make_assoc_change( + const struct sctp_association *asoc, + __u16 flags, __u16 state, __u16 error, __u16 outbound, + __u16 inbound, struct sctp_chunk *chunk, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_assoc_change *sac; + struct sk_buff *skb; + + /* If the lower layer passed in the chunk, it will be + * an ABORT, so we need to include it in the sac_info. + */ + if (chunk) { + /* Copy the chunk data to a new skb and reserve enough + * head room to use as notification. + */ + skb = skb_copy_expand(chunk->skb, + sizeof(struct sctp_assoc_change), 0, gfp); + + if (!skb) + goto fail; + + /* Embed the event fields inside the cloned skb. */ + event = sctp_skb2event(skb); + sctp_ulpevent_init(event, MSG_NOTIFICATION, skb->truesize); + + /* Include the notification structure */ + sac = skb_push(skb, sizeof(struct sctp_assoc_change)); + + /* Trim the buffer to the right length. */ + skb_trim(skb, sizeof(struct sctp_assoc_change) + + ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_chunkhdr)); + } else { + event = sctp_ulpevent_new(sizeof(struct sctp_assoc_change), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + sac = skb_put(skb, sizeof(struct sctp_assoc_change)); + } + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_type: + * It should be SCTP_ASSOC_CHANGE. + */ + sac->sac_type = SCTP_ASSOC_CHANGE; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_state: 32 bits (signed integer) + * This field holds one of a number of values that communicate the + * event that happened to the association. + */ + sac->sac_state = state; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_flags: 16 bits (unsigned integer) + * Currently unused. + */ + sac->sac_flags = 0; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_length: sizeof (__u32) + * This field is the total length of the notification data, including + * the notification header. + */ + sac->sac_length = skb->len; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_error: 32 bits (signed integer) + * + * If the state was reached due to a error condition (e.g. 
+ * COMMUNICATION_LOST) any relevant error information is available in + * this field. This corresponds to the protocol error codes defined in + * [SCTP]. + */ + sac->sac_error = error; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_outbound_streams: 16 bits (unsigned integer) + * sac_inbound_streams: 16 bits (unsigned integer) + * + * The maximum number of streams allowed in each direction are + * available in sac_outbound_streams and sac_inbound streams. + */ + sac->sac_outbound_streams = outbound; + sac->sac_inbound_streams = inbound; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * sac_assoc_id: sizeof (sctp_assoc_t) + * + * The association id field, holds the identifier for the association. + * All notifications for a given association have the same association + * identifier. For TCP style socket, this field is ignored. + */ + sctp_ulpevent_set_owner(event, asoc); + sac->sac_assoc_id = sctp_assoc2id(asoc); + + return event; + +fail: + return NULL; +} + +/* Create and initialize an SCTP_PEER_ADDR_CHANGE event. + * + * Socket Extensions for SCTP - draft-01 + * 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * When a destination address on a multi-homed peer encounters a change + * an interface details event is sent. + */ +static struct sctp_ulpevent *sctp_ulpevent_make_peer_addr_change( + const struct sctp_association *asoc, + const struct sockaddr_storage *aaddr, + int flags, int state, int error, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_paddr_change *spc; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_paddr_change), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + spc = skb_put(skb, sizeof(struct sctp_paddr_change)); + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_type: + * + * It should be SCTP_PEER_ADDR_CHANGE. + */ + spc->spc_type = SCTP_PEER_ADDR_CHANGE; + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_length: sizeof (__u32) + * + * This field is the total length of the notification data, including + * the notification header. + */ + spc->spc_length = sizeof(struct sctp_paddr_change); + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_flags: 16 bits (unsigned integer) + * Currently unused. + */ + spc->spc_flags = 0; + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_state: 32 bits (signed integer) + * + * This field holds one of a number of values that communicate the + * event that happened to the address. + */ + spc->spc_state = state; + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_error: 32 bits (signed integer) + * + * If the state was reached due to any error condition (e.g. + * ADDRESS_UNREACHABLE) any relevant error information is available in + * this field. + */ + spc->spc_error = error; + + /* Socket Extensions for SCTP + * 5.3.1.1 SCTP_ASSOC_CHANGE + * + * spc_assoc_id: sizeof (sctp_assoc_t) + * + * The association id field, holds the identifier for the association. + * All notifications for a given association have the same association + * identifier. For TCP style socket, this field is ignored. 
+ */ + sctp_ulpevent_set_owner(event, asoc); + spc->spc_assoc_id = sctp_assoc2id(asoc); + + /* Sockets API Extensions for SCTP + * Section 5.3.1.2 SCTP_PEER_ADDR_CHANGE + * + * spc_aaddr: sizeof (struct sockaddr_storage) + * + * The affected address field, holds the remote peer's address that is + * encountering the change of state. + */ + memcpy(&spc->spc_aaddr, aaddr, sizeof(struct sockaddr_storage)); + + /* Map ipv4 address into v4-mapped-on-v6 address. */ + sctp_get_pf_specific(asoc->base.sk->sk_family)->addr_to_user( + sctp_sk(asoc->base.sk), + (union sctp_addr *)&spc->spc_aaddr); + + return event; + +fail: + return NULL; +} + +void sctp_ulpevent_notify_peer_addr_change(struct sctp_transport *transport, + int state, int error) +{ + struct sctp_association *asoc = transport->asoc; + struct sockaddr_storage addr; + struct sctp_ulpevent *event; + + if (asoc->state < SCTP_STATE_ESTABLISHED) + return; + + memset(&addr, 0, sizeof(struct sockaddr_storage)); + memcpy(&addr, &transport->ipaddr, transport->af_specific->sockaddr_len); + + event = sctp_ulpevent_make_peer_addr_change(asoc, &addr, 0, state, + error, GFP_ATOMIC); + if (event) + asoc->stream.si->enqueue_event(&asoc->ulpq, event); +} + +/* Create and initialize an SCTP_REMOTE_ERROR notification. + * + * Note: This assumes that the chunk->skb->data already points to the + * operation error payload. + * + * Socket Extensions for SCTP - draft-01 + * 5.3.1.3 SCTP_REMOTE_ERROR + * + * A remote peer may send an Operational Error message to its peer. + * This message indicates a variety of error conditions on an + * association. The entire error TLV as it appears on the wire is + * included in a SCTP_REMOTE_ERROR event. Please refer to the SCTP + * specification [SCTP] and any extensions for a list of possible + * error formats. + */ +struct sctp_ulpevent * +sctp_ulpevent_make_remote_error(const struct sctp_association *asoc, + struct sctp_chunk *chunk, __u16 flags, + gfp_t gfp) +{ + struct sctp_remote_error *sre; + struct sctp_ulpevent *event; + struct sctp_errhdr *ch; + struct sk_buff *skb; + __be16 cause; + int elen; + + ch = (struct sctp_errhdr *)(chunk->skb->data); + cause = ch->cause; + elen = SCTP_PAD4(ntohs(ch->length)) - sizeof(*ch); + + /* Pull off the ERROR header. */ + skb_pull(chunk->skb, sizeof(*ch)); + + /* Copy the skb to a new skb with room for us to prepend + * notification with. + */ + skb = skb_copy_expand(chunk->skb, sizeof(*sre), 0, gfp); + + /* Pull off the rest of the cause TLV from the chunk. */ + skb_pull(chunk->skb, elen); + if (!skb) + goto fail; + + /* Embed the event fields inside the cloned skb. */ + event = sctp_skb2event(skb); + sctp_ulpevent_init(event, MSG_NOTIFICATION, skb->truesize); + + sre = skb_push(skb, sizeof(*sre)); + + /* Trim the buffer to the right length. */ + skb_trim(skb, sizeof(*sre) + elen); + + /* RFC6458, Section 6.1.3. SCTP_REMOTE_ERROR */ + memset(sre, 0, sizeof(*sre)); + sre->sre_type = SCTP_REMOTE_ERROR; + sre->sre_flags = 0; + sre->sre_length = skb->len; + sre->sre_error = cause; + sctp_ulpevent_set_owner(event, asoc); + sre->sre_assoc_id = sctp_assoc2id(asoc); + + return event; +fail: + return NULL; +} + +/* Create and initialize a SCTP_SEND_FAILED notification. 
+ * + * Socket Extensions for SCTP - draft-01 + * 5.3.1.4 SCTP_SEND_FAILED + */ +struct sctp_ulpevent *sctp_ulpevent_make_send_failed( + const struct sctp_association *asoc, struct sctp_chunk *chunk, + __u16 flags, __u32 error, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_send_failed *ssf; + struct sk_buff *skb; + + /* Pull off any padding. */ + int len = ntohs(chunk->chunk_hdr->length); + + /* Make skb with more room so we can prepend notification. */ + skb = skb_copy_expand(chunk->skb, + sizeof(struct sctp_send_failed), /* headroom */ + 0, /* tailroom */ + gfp); + if (!skb) + goto fail; + + /* Pull off the common chunk header and DATA header. */ + skb_pull(skb, sctp_datachk_len(&asoc->stream)); + len -= sctp_datachk_len(&asoc->stream); + + /* Embed the event fields inside the cloned skb. */ + event = sctp_skb2event(skb); + sctp_ulpevent_init(event, MSG_NOTIFICATION, skb->truesize); + + ssf = skb_push(skb, sizeof(struct sctp_send_failed)); + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_type: + * It should be SCTP_SEND_FAILED. + */ + ssf->ssf_type = SCTP_SEND_FAILED; + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_flags: 16 bits (unsigned integer) + * The flag value will take one of the following values + * + * SCTP_DATA_UNSENT - Indicates that the data was never put on + * the wire. + * + * SCTP_DATA_SENT - Indicates that the data was put on the wire. + * Note that this does not necessarily mean that the + * data was (or was not) successfully delivered. + */ + ssf->ssf_flags = flags; + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_length: sizeof (__u32) + * This field is the total length of the notification data, including + * the notification header. + */ + ssf->ssf_length = sizeof(struct sctp_send_failed) + len; + skb_trim(skb, ssf->ssf_length); + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_error: 16 bits (unsigned integer) + * This value represents the reason why the send failed, and if set, + * will be a SCTP protocol error code as defined in [SCTP] section + * 3.3.10. + */ + ssf->ssf_error = error; + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_info: sizeof (struct sctp_sndrcvinfo) + * The original send information associated with the undelivered + * message. + */ + memcpy(&ssf->ssf_info, &chunk->sinfo, sizeof(struct sctp_sndrcvinfo)); + + /* Per TSVWG discussion with Randy. Allow the application to + * reassemble a fragmented message. + */ + ssf->ssf_info.sinfo_flags = chunk->chunk_hdr->flags; + + /* Socket Extensions for SCTP + * 5.3.1.4 SCTP_SEND_FAILED + * + * ssf_assoc_id: sizeof (sctp_assoc_t) + * The association id field, sf_assoc_id, holds the identifier for the + * association. All notifications for a given association have the + * same association identifier. For TCP style socket, this field is + * ignored. 
+ */ + sctp_ulpevent_set_owner(event, asoc); + ssf->ssf_assoc_id = sctp_assoc2id(asoc); + return event; + +fail: + return NULL; +} + +struct sctp_ulpevent *sctp_ulpevent_make_send_failed_event( + const struct sctp_association *asoc, struct sctp_chunk *chunk, + __u16 flags, __u32 error, gfp_t gfp) +{ + struct sctp_send_failed_event *ssf; + struct sctp_ulpevent *event; + struct sk_buff *skb; + int len; + + skb = skb_copy_expand(chunk->skb, sizeof(*ssf), 0, gfp); + if (!skb) + return NULL; + + len = ntohs(chunk->chunk_hdr->length); + len -= sctp_datachk_len(&asoc->stream); + + skb_pull(skb, sctp_datachk_len(&asoc->stream)); + event = sctp_skb2event(skb); + sctp_ulpevent_init(event, MSG_NOTIFICATION, skb->truesize); + + ssf = skb_push(skb, sizeof(*ssf)); + ssf->ssf_type = SCTP_SEND_FAILED_EVENT; + ssf->ssf_flags = flags; + ssf->ssf_length = sizeof(*ssf) + len; + skb_trim(skb, ssf->ssf_length); + ssf->ssf_error = error; + + ssf->ssfe_info.snd_sid = chunk->sinfo.sinfo_stream; + ssf->ssfe_info.snd_ppid = chunk->sinfo.sinfo_ppid; + ssf->ssfe_info.snd_context = chunk->sinfo.sinfo_context; + ssf->ssfe_info.snd_assoc_id = chunk->sinfo.sinfo_assoc_id; + ssf->ssfe_info.snd_flags = chunk->chunk_hdr->flags; + + sctp_ulpevent_set_owner(event, asoc); + ssf->ssf_assoc_id = sctp_assoc2id(asoc); + + return event; +} + +/* Create and initialize a SCTP_SHUTDOWN_EVENT notification. + * + * Socket Extensions for SCTP - draft-01 + * 5.3.1.5 SCTP_SHUTDOWN_EVENT + */ +struct sctp_ulpevent *sctp_ulpevent_make_shutdown_event( + const struct sctp_association *asoc, + __u16 flags, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_shutdown_event *sse; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_shutdown_event), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + sse = skb_put(skb, sizeof(struct sctp_shutdown_event)); + + /* Socket Extensions for SCTP + * 5.3.1.5 SCTP_SHUTDOWN_EVENT + * + * sse_type + * It should be SCTP_SHUTDOWN_EVENT + */ + sse->sse_type = SCTP_SHUTDOWN_EVENT; + + /* Socket Extensions for SCTP + * 5.3.1.5 SCTP_SHUTDOWN_EVENT + * + * sse_flags: 16 bits (unsigned integer) + * Currently unused. + */ + sse->sse_flags = 0; + + /* Socket Extensions for SCTP + * 5.3.1.5 SCTP_SHUTDOWN_EVENT + * + * sse_length: sizeof (__u32) + * This field is the total length of the notification data, including + * the notification header. + */ + sse->sse_length = sizeof(struct sctp_shutdown_event); + + /* Socket Extensions for SCTP + * 5.3.1.5 SCTP_SHUTDOWN_EVENT + * + * sse_assoc_id: sizeof (sctp_assoc_t) + * The association id field, holds the identifier for the association. + * All notifications for a given association have the same association + * identifier. For TCP style socket, this field is ignored. + */ + sctp_ulpevent_set_owner(event, asoc); + sse->sse_assoc_id = sctp_assoc2id(asoc); + + return event; + +fail: + return NULL; +} + +/* Create and initialize a SCTP_ADAPTATION_INDICATION notification. 
+ * + * Socket Extensions for SCTP + * 5.3.1.6 SCTP_ADAPTATION_INDICATION + */ +struct sctp_ulpevent *sctp_ulpevent_make_adaptation_indication( + const struct sctp_association *asoc, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_adaptation_event *sai; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_adaptation_event), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + sai = skb_put(skb, sizeof(struct sctp_adaptation_event)); + + sai->sai_type = SCTP_ADAPTATION_INDICATION; + sai->sai_flags = 0; + sai->sai_length = sizeof(struct sctp_adaptation_event); + sai->sai_adaptation_ind = asoc->peer.adaptation_ind; + sctp_ulpevent_set_owner(event, asoc); + sai->sai_assoc_id = sctp_assoc2id(asoc); + + return event; + +fail: + return NULL; +} + +/* A message has been received. Package this message as a notification + * to pass it to the upper layers. Go ahead and calculate the sndrcvinfo + * even if filtered out later. + * + * Socket Extensions for SCTP + * 5.2.2 SCTP Header Information Structure (SCTP_SNDRCV) + */ +struct sctp_ulpevent *sctp_ulpevent_make_rcvmsg(struct sctp_association *asoc, + struct sctp_chunk *chunk, + gfp_t gfp) +{ + struct sctp_ulpevent *event = NULL; + struct sk_buff *skb = chunk->skb; + struct sock *sk = asoc->base.sk; + size_t padding, datalen; + int rx_count; + + /* + * check to see if we need to make space for this + * new skb, expand the rcvbuffer if needed, or drop + * the frame + */ + if (asoc->ep->rcvbuf_policy) + rx_count = atomic_read(&asoc->rmem_alloc); + else + rx_count = atomic_read(&sk->sk_rmem_alloc); + + datalen = ntohs(chunk->chunk_hdr->length); + + if (rx_count >= sk->sk_rcvbuf || !sk_rmem_schedule(sk, skb, datalen)) + goto fail; + + /* Clone the original skb, sharing the data. */ + skb = skb_clone(chunk->skb, gfp); + if (!skb) + goto fail; + + /* Now that all memory allocations for this chunk succeeded, we + * can mark it as received so the tsn_map is updated correctly. + */ + if (sctp_tsnmap_mark(&asoc->peer.tsn_map, + ntohl(chunk->subh.data_hdr->tsn), + chunk->transport)) + goto fail_mark; + + /* First calculate the padding, so we don't inadvertently + * pass up the wrong length to the user. + * + * RFC 2960 - Section 3.2 Chunk Field Descriptions + * + * The total length of a chunk(including Type, Length and Value fields) + * MUST be a multiple of 4 bytes. If the length of the chunk is not a + * multiple of 4 bytes, the sender MUST pad the chunk with all zero + * bytes and this padding is not included in the chunk length field. + * The sender should never pad with more than 3 bytes. The receiver + * MUST ignore the padding bytes. + */ + padding = SCTP_PAD4(datalen) - datalen; + + /* Fixup cloned skb with just this chunks data. */ + skb_trim(skb, chunk->chunk_end - padding - skb->data); + + /* Embed the event fields inside the cloned skb. */ + event = sctp_skb2event(skb); + + /* Initialize event with flags 0 and correct length + * Since this is a clone of the original skb, only account for + * the data of this chunk as other chunks will be accounted separately. 
+ */ + sctp_ulpevent_init(event, 0, skb->len + sizeof(struct sk_buff)); + + /* And hold the chunk as we need it for getting the IP headers + * later in recvmsg + */ + sctp_chunk_hold(chunk); + event->chunk = chunk; + + sctp_ulpevent_receive_data(event, asoc); + + event->stream = ntohs(chunk->subh.data_hdr->stream); + if (chunk->chunk_hdr->flags & SCTP_DATA_UNORDERED) { + event->flags |= SCTP_UNORDERED; + event->cumtsn = sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map); + } + event->tsn = ntohl(chunk->subh.data_hdr->tsn); + event->msg_flags |= chunk->chunk_hdr->flags; + + return event; + +fail_mark: + kfree_skb(skb); +fail: + return NULL; +} + +/* Create a partial delivery related event. + * + * 5.3.1.7 SCTP_PARTIAL_DELIVERY_EVENT + * + * When a receiver is engaged in a partial delivery of a + * message this notification will be used to indicate + * various events. + */ +struct sctp_ulpevent *sctp_ulpevent_make_pdapi( + const struct sctp_association *asoc, + __u32 indication, __u32 sid, __u32 seq, + __u32 flags, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_pdapi_event *pd; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_pdapi_event), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + pd = skb_put(skb, sizeof(struct sctp_pdapi_event)); + + /* pdapi_type + * It should be SCTP_PARTIAL_DELIVERY_EVENT + * + * pdapi_flags: 16 bits (unsigned integer) + * Currently unused. + */ + pd->pdapi_type = SCTP_PARTIAL_DELIVERY_EVENT; + pd->pdapi_flags = flags; + pd->pdapi_stream = sid; + pd->pdapi_seq = seq; + + /* pdapi_length: 32 bits (unsigned integer) + * + * This field is the total length of the notification data, including + * the notification header. It will generally be sizeof (struct + * sctp_pdapi_event). + */ + pd->pdapi_length = sizeof(struct sctp_pdapi_event); + + /* pdapi_indication: 32 bits (unsigned integer) + * + * This field holds the indication being sent to the application. + */ + pd->pdapi_indication = indication; + + /* pdapi_assoc_id: sizeof (sctp_assoc_t) + * + * The association id field, holds the identifier for the association. + */ + sctp_ulpevent_set_owner(event, asoc); + pd->pdapi_assoc_id = sctp_assoc2id(asoc); + + return event; +fail: + return NULL; +} + +struct sctp_ulpevent *sctp_ulpevent_make_authkey( + const struct sctp_association *asoc, __u16 key_id, + __u32 indication, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_authkey_event *ak; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_authkey_event), + MSG_NOTIFICATION, gfp); + if (!event) + goto fail; + + skb = sctp_event2skb(event); + ak = skb_put(skb, sizeof(struct sctp_authkey_event)); + + ak->auth_type = SCTP_AUTHENTICATION_EVENT; + ak->auth_flags = 0; + ak->auth_length = sizeof(struct sctp_authkey_event); + + ak->auth_keynumber = key_id; + ak->auth_altkeynumber = 0; + ak->auth_indication = indication; + + /* + * The association id field, holds the identifier for the association. + */ + sctp_ulpevent_set_owner(event, asoc); + ak->auth_assoc_id = sctp_assoc2id(asoc); + + return event; +fail: + return NULL; +} + +/* + * Socket Extensions for SCTP + * 6.3.10. 
SCTP_SENDER_DRY_EVENT + */ +struct sctp_ulpevent *sctp_ulpevent_make_sender_dry_event( + const struct sctp_association *asoc, gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_sender_dry_event *sdry; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_sender_dry_event), + MSG_NOTIFICATION, gfp); + if (!event) + return NULL; + + skb = sctp_event2skb(event); + sdry = skb_put(skb, sizeof(struct sctp_sender_dry_event)); + + sdry->sender_dry_type = SCTP_SENDER_DRY_EVENT; + sdry->sender_dry_flags = 0; + sdry->sender_dry_length = sizeof(struct sctp_sender_dry_event); + sctp_ulpevent_set_owner(event, asoc); + sdry->sender_dry_assoc_id = sctp_assoc2id(asoc); + + return event; +} + +struct sctp_ulpevent *sctp_ulpevent_make_stream_reset_event( + const struct sctp_association *asoc, __u16 flags, __u16 stream_num, + __be16 *stream_list, gfp_t gfp) +{ + struct sctp_stream_reset_event *sreset; + struct sctp_ulpevent *event; + struct sk_buff *skb; + int length, i; + + length = sizeof(struct sctp_stream_reset_event) + 2 * stream_num; + event = sctp_ulpevent_new(length, MSG_NOTIFICATION, gfp); + if (!event) + return NULL; + + skb = sctp_event2skb(event); + sreset = skb_put(skb, length); + + sreset->strreset_type = SCTP_STREAM_RESET_EVENT; + sreset->strreset_flags = flags; + sreset->strreset_length = length; + sctp_ulpevent_set_owner(event, asoc); + sreset->strreset_assoc_id = sctp_assoc2id(asoc); + + for (i = 0; i < stream_num; i++) + sreset->strreset_stream_list[i] = ntohs(stream_list[i]); + + return event; +} + +struct sctp_ulpevent *sctp_ulpevent_make_assoc_reset_event( + const struct sctp_association *asoc, __u16 flags, __u32 local_tsn, + __u32 remote_tsn, gfp_t gfp) +{ + struct sctp_assoc_reset_event *areset; + struct sctp_ulpevent *event; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_assoc_reset_event), + MSG_NOTIFICATION, gfp); + if (!event) + return NULL; + + skb = sctp_event2skb(event); + areset = skb_put(skb, sizeof(struct sctp_assoc_reset_event)); + + areset->assocreset_type = SCTP_ASSOC_RESET_EVENT; + areset->assocreset_flags = flags; + areset->assocreset_length = sizeof(struct sctp_assoc_reset_event); + sctp_ulpevent_set_owner(event, asoc); + areset->assocreset_assoc_id = sctp_assoc2id(asoc); + areset->assocreset_local_tsn = local_tsn; + areset->assocreset_remote_tsn = remote_tsn; + + return event; +} + +struct sctp_ulpevent *sctp_ulpevent_make_stream_change_event( + const struct sctp_association *asoc, __u16 flags, + __u32 strchange_instrms, __u32 strchange_outstrms, gfp_t gfp) +{ + struct sctp_stream_change_event *schange; + struct sctp_ulpevent *event; + struct sk_buff *skb; + + event = sctp_ulpevent_new(sizeof(struct sctp_stream_change_event), + MSG_NOTIFICATION, gfp); + if (!event) + return NULL; + + skb = sctp_event2skb(event); + schange = skb_put(skb, sizeof(struct sctp_stream_change_event)); + + schange->strchange_type = SCTP_STREAM_CHANGE_EVENT; + schange->strchange_flags = flags; + schange->strchange_length = sizeof(struct sctp_stream_change_event); + sctp_ulpevent_set_owner(event, asoc); + schange->strchange_assoc_id = sctp_assoc2id(asoc); + schange->strchange_instrms = strchange_instrms; + schange->strchange_outstrms = strchange_outstrms; + + return event; +} + +/* Return the notification type, assuming this is a notification + * event. 
+ */ +__u16 sctp_ulpevent_get_notification_type(const struct sctp_ulpevent *event) +{ + union sctp_notification *notification; + struct sk_buff *skb; + + skb = sctp_event2skb(event); + notification = (union sctp_notification *) skb->data; + return notification->sn_header.sn_type; +} + +/* RFC6458, Section 5.3.2. SCTP Header Information Structure + * (SCTP_SNDRCV, DEPRECATED) + */ +void sctp_ulpevent_read_sndrcvinfo(const struct sctp_ulpevent *event, + struct msghdr *msghdr) +{ + struct sctp_sndrcvinfo sinfo; + + if (sctp_ulpevent_is_notification(event)) + return; + + memset(&sinfo, 0, sizeof(sinfo)); + sinfo.sinfo_stream = event->stream; + sinfo.sinfo_ssn = event->ssn; + sinfo.sinfo_ppid = event->ppid; + sinfo.sinfo_flags = event->flags; + sinfo.sinfo_tsn = event->tsn; + sinfo.sinfo_cumtsn = event->cumtsn; + sinfo.sinfo_assoc_id = sctp_assoc2id(event->asoc); + /* Context value that is set via SCTP_CONTEXT socket option. */ + sinfo.sinfo_context = event->asoc->default_rcv_context; + /* These fields are not used while receiving. */ + sinfo.sinfo_timetolive = 0; + + put_cmsg(msghdr, IPPROTO_SCTP, SCTP_SNDRCV, + sizeof(sinfo), &sinfo); +} + +/* RFC6458, Section 5.3.5 SCTP Receive Information Structure + * (SCTP_SNDRCV) + */ +void sctp_ulpevent_read_rcvinfo(const struct sctp_ulpevent *event, + struct msghdr *msghdr) +{ + struct sctp_rcvinfo rinfo; + + if (sctp_ulpevent_is_notification(event)) + return; + + memset(&rinfo, 0, sizeof(struct sctp_rcvinfo)); + rinfo.rcv_sid = event->stream; + rinfo.rcv_ssn = event->ssn; + rinfo.rcv_ppid = event->ppid; + rinfo.rcv_flags = event->flags; + rinfo.rcv_tsn = event->tsn; + rinfo.rcv_cumtsn = event->cumtsn; + rinfo.rcv_assoc_id = sctp_assoc2id(event->asoc); + rinfo.rcv_context = event->asoc->default_rcv_context; + + put_cmsg(msghdr, IPPROTO_SCTP, SCTP_RCVINFO, + sizeof(rinfo), &rinfo); +} + +/* RFC6458, Section 5.3.6. SCTP Next Receive Information Structure + * (SCTP_NXTINFO) + */ +static void __sctp_ulpevent_read_nxtinfo(const struct sctp_ulpevent *event, + struct msghdr *msghdr, + const struct sk_buff *skb) +{ + struct sctp_nxtinfo nxtinfo; + + memset(&nxtinfo, 0, sizeof(nxtinfo)); + nxtinfo.nxt_sid = event->stream; + nxtinfo.nxt_ppid = event->ppid; + nxtinfo.nxt_flags = event->flags; + if (sctp_ulpevent_is_notification(event)) + nxtinfo.nxt_flags |= SCTP_NOTIFICATION; + nxtinfo.nxt_length = skb->len; + nxtinfo.nxt_assoc_id = sctp_assoc2id(event->asoc); + + put_cmsg(msghdr, IPPROTO_SCTP, SCTP_NXTINFO, + sizeof(nxtinfo), &nxtinfo); +} + +void sctp_ulpevent_read_nxtinfo(const struct sctp_ulpevent *event, + struct msghdr *msghdr, + struct sock *sk) +{ + struct sk_buff *skb; + int err; + + skb = sctp_skb_recv_datagram(sk, MSG_PEEK | MSG_DONTWAIT, &err); + if (skb != NULL) { + __sctp_ulpevent_read_nxtinfo(sctp_skb2event(skb), + msghdr, skb); + /* Just release refcount here. */ + kfree_skb(skb); + } +} + +/* Do accounting for bytes received and hold a reference to the association + * for each skb. + */ +static void sctp_ulpevent_receive_data(struct sctp_ulpevent *event, + struct sctp_association *asoc) +{ + struct sk_buff *skb, *frag; + + skb = sctp_event2skb(event); + /* Set the owner and charge rwnd for bytes received. */ + sctp_ulpevent_set_owner(event, asoc); + sctp_assoc_rwnd_decrease(asoc, skb_headlen(skb)); + + if (!skb->data_len) + return; + + /* Note: Not clearing the entire event struct as this is just a + * fragment of the real event. However, we still need to do rwnd + * accounting. 
+ * In general, the skb passed from IP can have only 1 level of + * fragments. But we allow multiple levels of fragments. + */ + skb_walk_frags(skb, frag) + sctp_ulpevent_receive_data(sctp_skb2event(frag), asoc); +} + +/* Do accounting for bytes just read by user and release the references to + * the association. + */ +static void sctp_ulpevent_release_data(struct sctp_ulpevent *event) +{ + struct sk_buff *skb, *frag; + unsigned int len; + + /* Current stack structures assume that the rcv buffer is + * per socket. For UDP style sockets this is not true as + * multiple associations may be on a single UDP-style socket. + * Use the local private area of the skb to track the owning + * association. + */ + + skb = sctp_event2skb(event); + len = skb->len; + + if (!skb->data_len) + goto done; + + /* Don't forget the fragments. */ + skb_walk_frags(skb, frag) { + /* NOTE: skb_shinfos are recursive. Although IP returns + * skb's with only 1 level of fragments, SCTP reassembly can + * increase the levels. + */ + sctp_ulpevent_release_frag_data(sctp_skb2event(frag)); + } + +done: + sctp_assoc_rwnd_increase(event->asoc, len); + sctp_chunk_put(event->chunk); + sctp_ulpevent_release_owner(event); +} + +static void sctp_ulpevent_release_frag_data(struct sctp_ulpevent *event) +{ + struct sk_buff *skb, *frag; + + skb = sctp_event2skb(event); + + if (!skb->data_len) + goto done; + + /* Don't forget the fragments. */ + skb_walk_frags(skb, frag) { + /* NOTE: skb_shinfos are recursive. Although IP returns + * skb's with only 1 level of fragments, SCTP reassembly can + * increase the levels. + */ + sctp_ulpevent_release_frag_data(sctp_skb2event(frag)); + } + +done: + sctp_chunk_put(event->chunk); + sctp_ulpevent_release_owner(event); +} + +/* Free a ulpevent that has an owner. It includes releasing the reference + * to the owner, updating the rwnd in case of a DATA event and freeing the + * skb. + */ +void sctp_ulpevent_free(struct sctp_ulpevent *event) +{ + if (sctp_ulpevent_is_notification(event)) + sctp_ulpevent_release_owner(event); + else + sctp_ulpevent_release_data(event); + + kfree_skb(sctp_event2skb(event)); +} + +/* Purge the skb lists holding ulpevents. */ +unsigned int sctp_queue_purge_ulpevents(struct sk_buff_head *list) +{ + struct sk_buff *skb; + unsigned int data_unread = 0; + + while ((skb = skb_dequeue(list)) != NULL) { + struct sctp_ulpevent *event = sctp_skb2event(skb); + + if (!sctp_ulpevent_is_notification(event)) + data_unread += skb->len; + + sctp_ulpevent_free(event); + } + + return data_unread; +} diff --git a/net/sctp/ulpqueue.c b/net/sctp/ulpqueue.c new file mode 100644 index 000000000..0a8510a0c --- /dev/null +++ b/net/sctp/ulpqueue.c @@ -0,0 +1,1132 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* SCTP kernel implementation + * (C) Copyright IBM Corp. 2001, 2004 + * Copyright (c) 1999-2000 Cisco, Inc. + * Copyright (c) 1999-2001 Motorola, Inc. + * Copyright (c) 2001 Intel Corp. + * Copyright (c) 2001 Nokia, Inc. + * Copyright (c) 2001 La Monte H.P. Yarroll + * + * This abstraction carries sctp events to the ULP (sockets). + * + * Please send any bug reports or fixes you make to the + * email address(es): + * lksctp developers + * + * Written or modified by: + * Jon Grimm + * La Monte H.P. Yarroll + * Sridhar Samudrala + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* Forward declarations for internal helpers. 
*/ +static struct sctp_ulpevent *sctp_ulpq_reasm(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *); +static struct sctp_ulpevent *sctp_ulpq_order(struct sctp_ulpq *, + struct sctp_ulpevent *); +static void sctp_ulpq_reasm_drain(struct sctp_ulpq *ulpq); + +/* 1st Level Abstractions */ + +/* Initialize a ULP queue from a block of memory. */ +struct sctp_ulpq *sctp_ulpq_init(struct sctp_ulpq *ulpq, + struct sctp_association *asoc) +{ + memset(ulpq, 0, sizeof(struct sctp_ulpq)); + + ulpq->asoc = asoc; + skb_queue_head_init(&ulpq->reasm); + skb_queue_head_init(&ulpq->reasm_uo); + skb_queue_head_init(&ulpq->lobby); + ulpq->pd_mode = 0; + + return ulpq; +} + + +/* Flush the reassembly and ordering queues. */ +void sctp_ulpq_flush(struct sctp_ulpq *ulpq) +{ + struct sk_buff *skb; + struct sctp_ulpevent *event; + + while ((skb = __skb_dequeue(&ulpq->lobby)) != NULL) { + event = sctp_skb2event(skb); + sctp_ulpevent_free(event); + } + + while ((skb = __skb_dequeue(&ulpq->reasm)) != NULL) { + event = sctp_skb2event(skb); + sctp_ulpevent_free(event); + } + + while ((skb = __skb_dequeue(&ulpq->reasm_uo)) != NULL) { + event = sctp_skb2event(skb); + sctp_ulpevent_free(event); + } +} + +/* Dispose of a ulpqueue. */ +void sctp_ulpq_free(struct sctp_ulpq *ulpq) +{ + sctp_ulpq_flush(ulpq); +} + +/* Process an incoming DATA chunk. */ +int sctp_ulpq_tail_data(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk, + gfp_t gfp) +{ + struct sk_buff_head temp; + struct sctp_ulpevent *event; + int event_eor = 0; + + /* Create an event from the incoming chunk. */ + event = sctp_ulpevent_make_rcvmsg(chunk->asoc, chunk, gfp); + if (!event) + return -ENOMEM; + + event->ssn = ntohs(chunk->subh.data_hdr->ssn); + event->ppid = chunk->subh.data_hdr->ppid; + + /* Do reassembly if needed. */ + event = sctp_ulpq_reasm(ulpq, event); + + /* Do ordering if needed. */ + if (event) { + /* Create a temporary list to collect chunks on. */ + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + + if (event->msg_flags & MSG_EOR) + event = sctp_ulpq_order(ulpq, event); + } + + /* Send event to the ULP. 'event' is the sctp_ulpevent for + * very first SKB on the 'temp' list. + */ + if (event) { + event_eor = (event->msg_flags & MSG_EOR) ? 1 : 0; + sctp_ulpq_tail_event(ulpq, &temp); + } + + return event_eor; +} + +/* Add a new event for propagation to the ULP. */ +/* Clear the partial delivery mode for this socket. Note: This + * assumes that no association is currently in partial delivery mode. + */ +int sctp_clear_pd(struct sock *sk, struct sctp_association *asoc) +{ + struct sctp_sock *sp = sctp_sk(sk); + + if (atomic_dec_and_test(&sp->pd_mode)) { + /* This means there are no other associations in PD, so + * we can go ahead and clear out the lobby in one shot + */ + if (!skb_queue_empty(&sp->pd_lobby)) { + skb_queue_splice_tail_init(&sp->pd_lobby, + &sk->sk_receive_queue); + return 1; + } + } else { + /* There are other associations in PD, so we only need to + * pull stuff out of the lobby that belongs to the + * associations that is exiting PD (all of its notifications + * are posted here). 
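+		 * Entries belonging to other associations are left in the
+		 * lobby until their own association leaves partial delivery.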
+ */ + if (!skb_queue_empty(&sp->pd_lobby) && asoc) { + struct sk_buff *skb, *tmp; + struct sctp_ulpevent *event; + + sctp_skb_for_each(skb, &sp->pd_lobby, tmp) { + event = sctp_skb2event(skb); + if (event->asoc == asoc) { + __skb_unlink(skb, &sp->pd_lobby); + __skb_queue_tail(&sk->sk_receive_queue, + skb); + } + } + } + } + + return 0; +} + +/* Set the pd_mode on the socket and ulpq */ +static void sctp_ulpq_set_pd(struct sctp_ulpq *ulpq) +{ + struct sctp_sock *sp = sctp_sk(ulpq->asoc->base.sk); + + atomic_inc(&sp->pd_mode); + ulpq->pd_mode = 1; +} + +/* Clear the pd_mode and restart any pending messages waiting for delivery. */ +static int sctp_ulpq_clear_pd(struct sctp_ulpq *ulpq) +{ + ulpq->pd_mode = 0; + sctp_ulpq_reasm_drain(ulpq); + return sctp_clear_pd(ulpq->asoc->base.sk, ulpq->asoc); +} + +int sctp_ulpq_tail_event(struct sctp_ulpq *ulpq, struct sk_buff_head *skb_list) +{ + struct sock *sk = ulpq->asoc->base.sk; + struct sctp_sock *sp = sctp_sk(sk); + struct sctp_ulpevent *event; + struct sk_buff_head *queue; + struct sk_buff *skb; + int clear_pd = 0; + + skb = __skb_peek(skb_list); + event = sctp_skb2event(skb); + + /* If the socket is just going to throw this away, do not + * even try to deliver it. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN && + (sk->sk_shutdown & SEND_SHUTDOWN || + !sctp_ulpevent_is_notification(event))) + goto out_free; + + if (!sctp_ulpevent_is_notification(event)) { + sk_mark_napi_id(sk, skb); + sk_incoming_cpu_update(sk); + } + /* Check if the user wishes to receive this event. */ + if (!sctp_ulpevent_is_enabled(event, ulpq->asoc->subscribe)) + goto out_free; + + /* If we are in partial delivery mode, post to the lobby until + * partial delivery is cleared, unless, of course _this_ is + * the association the cause of the partial delivery. + */ + + if (atomic_read(&sp->pd_mode) == 0) { + queue = &sk->sk_receive_queue; + } else { + if (ulpq->pd_mode) { + /* If the association is in partial delivery, we + * need to finish delivering the partially processed + * packet before passing any other data. This is + * because we don't truly support stream interleaving. + */ + if ((event->msg_flags & MSG_NOTIFICATION) || + (SCTP_DATA_NOT_FRAG == + (event->msg_flags & SCTP_DATA_FRAG_MASK))) + queue = &sp->pd_lobby; + else { + clear_pd = event->msg_flags & MSG_EOR; + queue = &sk->sk_receive_queue; + } + } else { + /* + * If fragment interleave is enabled, we + * can queue this to the receive queue instead + * of the lobby. + */ + if (sp->frag_interleave) + queue = &sk->sk_receive_queue; + else + queue = &sp->pd_lobby; + } + } + + skb_queue_splice_tail_init(skb_list, queue); + + /* Did we just complete partial delivery and need to get + * rolling again? Move pending data to the receive + * queue. + */ + if (clear_pd) + sctp_ulpq_clear_pd(ulpq); + + if (queue == &sk->sk_receive_queue && !sp->data_ready_signalled) { + if (!sock_owned_by_user(sk)) + sp->data_ready_signalled = 1; + sk->sk_data_ready(sk); + } + return 1; + +out_free: + if (skb_list) + sctp_queue_purge_ulpevents(skb_list); + else + sctp_ulpevent_free(event); + + return 0; +} + +/* 2nd Level Abstractions */ + +/* Helper function to store chunks that need to be reassembled. */ +static void sctp_ulpq_store_reasm(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff *pos; + struct sctp_ulpevent *cevent; + __u32 tsn, ctsn; + + tsn = event->tsn; + + /* See if it belongs at the end. 
*/ + pos = skb_peek_tail(&ulpq->reasm); + if (!pos) { + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + return; + } + + /* Short circuit just dropping it at the end. */ + cevent = sctp_skb2event(pos); + ctsn = cevent->tsn; + if (TSN_lt(ctsn, tsn)) { + __skb_queue_tail(&ulpq->reasm, sctp_event2skb(event)); + return; + } + + /* Find the right place in this list. We store them by TSN. */ + skb_queue_walk(&ulpq->reasm, pos) { + cevent = sctp_skb2event(pos); + ctsn = cevent->tsn; + + if (TSN_lt(tsn, ctsn)) + break; + } + + /* Insert before pos. */ + __skb_queue_before(&ulpq->reasm, pos, sctp_event2skb(event)); + +} + +/* Helper function to return an event corresponding to the reassembled + * datagram. + * This routine creates a re-assembled skb given the first and last skb's + * as stored in the reassembly queue. The skb's may be non-linear if the sctp + * payload was fragmented on the way and ip had to reassemble them. + * We add the rest of skb's to the first skb's fraglist. + */ +struct sctp_ulpevent *sctp_make_reassembled_event(struct net *net, + struct sk_buff_head *queue, + struct sk_buff *f_frag, + struct sk_buff *l_frag) +{ + struct sk_buff *pos; + struct sk_buff *new = NULL; + struct sctp_ulpevent *event; + struct sk_buff *pnext, *last; + struct sk_buff *list = skb_shinfo(f_frag)->frag_list; + + /* Store the pointer to the 2nd skb */ + if (f_frag == l_frag) + pos = NULL; + else + pos = f_frag->next; + + /* Get the last skb in the f_frag's frag_list if present. */ + for (last = list; list; last = list, list = list->next) + ; + + /* Add the list of remaining fragments to the first fragments + * frag_list. + */ + if (last) + last->next = pos; + else { + if (skb_cloned(f_frag)) { + /* This is a cloned skb, we can't just modify + * the frag_list. We need a new skb to do that. + * Instead of calling skb_unshare(), we'll do it + * ourselves since we need to delay the free. + */ + new = skb_copy(f_frag, GFP_ATOMIC); + if (!new) + return NULL; /* try again later */ + + sctp_skb_set_owner_r(new, f_frag->sk); + + skb_shinfo(new)->frag_list = pos; + } else + skb_shinfo(f_frag)->frag_list = pos; + } + + /* Remove the first fragment from the reassembly queue. */ + __skb_unlink(f_frag, queue); + + /* if we did unshare, then free the old skb and re-assign */ + if (new) { + kfree_skb(f_frag); + f_frag = new; + } + + while (pos) { + + pnext = pos->next; + + /* Update the len and data_len fields of the first fragment. */ + f_frag->len += pos->len; + f_frag->data_len += pos->len; + + /* Remove the fragment from the reassembly queue. */ + __skb_unlink(pos, queue); + + /* Break if we have reached the last fragment. */ + if (pos == l_frag) + break; + pos->next = pnext; + pos = pnext; + } + + event = sctp_skb2event(f_frag); + SCTP_INC_STATS(net, SCTP_MIB_REASMUSRMSGS); + + return event; +} + + +/* Helper function to check if an incoming chunk has filled up the last + * missing fragment in a SCTP datagram and return the corresponding event. + */ +static struct sctp_ulpevent *sctp_ulpq_retrieve_reassembled(struct sctp_ulpq *ulpq) +{ + struct sk_buff *pos; + struct sctp_ulpevent *cevent; + struct sk_buff *first_frag = NULL; + __u32 ctsn, next_tsn; + struct sctp_ulpevent *retval = NULL; + struct sk_buff *pd_first = NULL; + struct sk_buff *pd_last = NULL; + size_t pd_len = 0; + struct sctp_association *asoc; + u32 pd_point; + + /* Initialized to 0 just to avoid compiler warning message. Will + * never be used with this value. 
It is referenced only after it + * is set when we find the first fragment of a message. + */ + next_tsn = 0; + + /* The chunks are held in the reasm queue sorted by TSN. + * Walk through the queue sequentially and look for a sequence of + * fragmented chunks that complete a datagram. + * 'first_frag' and next_tsn are reset when we find a chunk which + * is the first fragment of a datagram. Once these 2 fields are set + * we expect to find the remaining middle fragments and the last + * fragment in order. If not, first_frag is reset to NULL and we + * start the next pass when we find another first fragment. + * + * There is a potential to do partial delivery if user sets + * SCTP_PARTIAL_DELIVERY_POINT option. Lets count some things here + * to see if can do PD. + */ + skb_queue_walk(&ulpq->reasm, pos) { + cevent = sctp_skb2event(pos); + ctsn = cevent->tsn; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + /* If this "FIRST_FRAG" is the first + * element in the queue, then count it towards + * possible PD. + */ + if (skb_queue_is_first(&ulpq->reasm, pos)) { + pd_first = pos; + pd_last = pos; + pd_len = pos->len; + } else { + pd_first = NULL; + pd_last = NULL; + pd_len = 0; + } + + first_frag = pos; + next_tsn = ctsn + 1; + break; + + case SCTP_DATA_MIDDLE_FRAG: + if ((first_frag) && (ctsn == next_tsn)) { + next_tsn++; + if (pd_first) { + pd_last = pos; + pd_len += pos->len; + } + } else + first_frag = NULL; + break; + + case SCTP_DATA_LAST_FRAG: + if (first_frag && (ctsn == next_tsn)) + goto found; + else + first_frag = NULL; + break; + } + } + + asoc = ulpq->asoc; + if (pd_first) { + /* Make sure we can enter partial deliver. + * We can trigger partial delivery only if framgent + * interleave is set, or the socket is not already + * in partial delivery. + */ + if (!sctp_sk(asoc->base.sk)->frag_interleave && + atomic_read(&sctp_sk(asoc->base.sk)->pd_mode)) + goto done; + + cevent = sctp_skb2event(pd_first); + pd_point = sctp_sk(asoc->base.sk)->pd_point; + if (pd_point && pd_point <= pd_len) { + retval = sctp_make_reassembled_event(asoc->base.net, + &ulpq->reasm, + pd_first, pd_last); + if (retval) + sctp_ulpq_set_pd(ulpq); + } + } +done: + return retval; +found: + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, + &ulpq->reasm, first_frag, pos); + if (retval) + retval->msg_flags |= MSG_EOR; + goto done; +} + +/* Retrieve the next set of fragments of a partial message. */ +static struct sctp_ulpevent *sctp_ulpq_retrieve_partial(struct sctp_ulpq *ulpq) +{ + struct sk_buff *pos, *last_frag, *first_frag; + struct sctp_ulpevent *cevent; + __u32 ctsn, next_tsn; + int is_last; + struct sctp_ulpevent *retval; + + /* The chunks are held in the reasm queue sorted by TSN. + * Walk through the queue sequentially and look for the first + * sequence of fragmented chunks. 
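+	 * Since this runs while the message is already being partially
+	 * delivered, the run may start with a MIDDLE or LAST fragment;
+	 * MSG_EOR is only set once a LAST fragment closes the run.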
+ */ + + if (skb_queue_empty(&ulpq->reasm)) + return NULL; + + last_frag = first_frag = NULL; + retval = NULL; + next_tsn = 0; + is_last = 0; + + skb_queue_walk(&ulpq->reasm, pos) { + cevent = sctp_skb2event(pos); + ctsn = cevent->tsn; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (!first_frag) + return NULL; + goto done; + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) { + first_frag = pos; + next_tsn = ctsn + 1; + last_frag = pos; + } else if (next_tsn == ctsn) { + next_tsn++; + last_frag = pos; + } else + goto done; + break; + case SCTP_DATA_LAST_FRAG: + if (!first_frag) + first_frag = pos; + else if (ctsn != next_tsn) + goto done; + last_frag = pos; + is_last = 1; + goto done; + default: + return NULL; + } + } + + /* We have the reassembled event. There is no need to look + * further. + */ +done: + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, &ulpq->reasm, + first_frag, last_frag); + if (retval && is_last) + retval->msg_flags |= MSG_EOR; + + return retval; +} + + +/* Helper function to reassemble chunks. Hold chunks on the reasm queue that + * need reassembling. + */ +static struct sctp_ulpevent *sctp_ulpq_reasm(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sctp_ulpevent *retval = NULL; + + /* Check if this is part of a fragmented message. */ + if (SCTP_DATA_NOT_FRAG == (event->msg_flags & SCTP_DATA_FRAG_MASK)) { + event->msg_flags |= MSG_EOR; + return event; + } + + sctp_ulpq_store_reasm(ulpq, event); + if (!ulpq->pd_mode) + retval = sctp_ulpq_retrieve_reassembled(ulpq); + else { + __u32 ctsn, ctsnap; + + /* Do not even bother unless this is the next tsn to + * be delivered. + */ + ctsn = event->tsn; + ctsnap = sctp_tsnmap_get_ctsn(&ulpq->asoc->peer.tsn_map); + if (TSN_lte(ctsn, ctsnap)) + retval = sctp_ulpq_retrieve_partial(ulpq); + } + + return retval; +} + +/* Retrieve the first part (sequential fragments) for partial delivery. */ +static struct sctp_ulpevent *sctp_ulpq_retrieve_first(struct sctp_ulpq *ulpq) +{ + struct sk_buff *pos, *last_frag, *first_frag; + struct sctp_ulpevent *cevent; + __u32 ctsn, next_tsn; + struct sctp_ulpevent *retval; + + /* The chunks are held in the reasm queue sorted by TSN. + * Walk through the queue sequentially and look for a sequence of + * fragmented chunks that start a datagram. + */ + + if (skb_queue_empty(&ulpq->reasm)) + return NULL; + + last_frag = first_frag = NULL; + retval = NULL; + next_tsn = 0; + + skb_queue_walk(&ulpq->reasm, pos) { + cevent = sctp_skb2event(pos); + ctsn = cevent->tsn; + + switch (cevent->msg_flags & SCTP_DATA_FRAG_MASK) { + case SCTP_DATA_FIRST_FRAG: + if (!first_frag) { + first_frag = pos; + next_tsn = ctsn + 1; + last_frag = pos; + } else + goto done; + break; + + case SCTP_DATA_MIDDLE_FRAG: + if (!first_frag) + return NULL; + if (ctsn == next_tsn) { + next_tsn++; + last_frag = pos; + } else + goto done; + break; + + case SCTP_DATA_LAST_FRAG: + if (!first_frag) + return NULL; + else + goto done; + break; + + default: + return NULL; + } + } + + /* We have the reassembled event. There is no need to look + * further. + */ +done: + retval = sctp_make_reassembled_event(ulpq->asoc->base.net, &ulpq->reasm, + first_frag, last_frag); + return retval; +} + +/* + * Flush out stale fragments from the reassembly queue when processing + * a Forward TSN. + * + * RFC 3758, Section 3.6 + * + * After receiving and processing a FORWARD TSN, the data receiver MUST + * take cautions in updating its re-assembly queue. 
The receiver MUST + * remove any partially reassembled message, which is still missing one + * or more TSNs earlier than or equal to the new cumulative TSN point. + * In the event that the receiver has invoked the partial delivery API, + * a notification SHOULD also be generated to inform the upper layer API + * that the message being partially delivered will NOT be completed. + */ +void sctp_ulpq_reasm_flushtsn(struct sctp_ulpq *ulpq, __u32 fwd_tsn) +{ + struct sk_buff *pos, *tmp; + struct sctp_ulpevent *event; + __u32 tsn; + + if (skb_queue_empty(&ulpq->reasm)) + return; + + skb_queue_walk_safe(&ulpq->reasm, pos, tmp) { + event = sctp_skb2event(pos); + tsn = event->tsn; + + /* Since the entire message must be abandoned by the + * sender (item A3 in Section 3.5, RFC 3758), we can + * free all fragments on the list that are less then + * or equal to ctsn_point + */ + if (TSN_lte(tsn, fwd_tsn)) { + __skb_unlink(pos, &ulpq->reasm); + sctp_ulpevent_free(event); + } else + break; + } +} + +/* + * Drain the reassembly queue. If we just cleared parted delivery, it + * is possible that the reassembly queue will contain already reassembled + * messages. Retrieve any such messages and give them to the user. + */ +static void sctp_ulpq_reasm_drain(struct sctp_ulpq *ulpq) +{ + struct sctp_ulpevent *event = NULL; + + if (skb_queue_empty(&ulpq->reasm)) + return; + + while ((event = sctp_ulpq_retrieve_reassembled(ulpq)) != NULL) { + struct sk_buff_head temp; + + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + + /* Do ordering if needed. */ + if (event->msg_flags & MSG_EOR) + event = sctp_ulpq_order(ulpq, event); + + /* Send event to the ULP. 'event' is the + * sctp_ulpevent for very first SKB on the temp' list. + */ + if (event) + sctp_ulpq_tail_event(ulpq, &temp); + } +} + + +/* Helper function to gather skbs that have possibly become + * ordered by an incoming chunk. + */ +static void sctp_ulpq_retrieve_ordered(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff_head *event_list; + struct sk_buff *pos, *tmp; + struct sctp_ulpevent *cevent; + struct sctp_stream *stream; + __u16 sid, csid, cssn; + + sid = event->stream; + stream = &ulpq->asoc->stream; + + event_list = (struct sk_buff_head *) sctp_event2skb(event)->prev; + + /* We are holding the chunks by stream, by SSN. */ + sctp_skb_for_each(pos, &ulpq->lobby, tmp) { + cevent = (struct sctp_ulpevent *) pos->cb; + csid = cevent->stream; + cssn = cevent->ssn; + + /* Have we gone too far? */ + if (csid > sid) + break; + + /* Have we not gone far enough? */ + if (csid < sid) + continue; + + if (cssn != sctp_ssn_peek(stream, in, sid)) + break; + + /* Found it, so mark in the stream. */ + sctp_ssn_next(stream, in, sid); + + __skb_unlink(pos, &ulpq->lobby); + + /* Attach all gathered skbs to the event. */ + __skb_queue_tail(event_list, pos); + } +} + +/* Helper function to store chunks needing ordering. 
*/ +static void sctp_ulpq_store_ordered(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + struct sk_buff *pos; + struct sctp_ulpevent *cevent; + __u16 sid, csid; + __u16 ssn, cssn; + + pos = skb_peek_tail(&ulpq->lobby); + if (!pos) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + sid = event->stream; + ssn = event->ssn; + + cevent = (struct sctp_ulpevent *) pos->cb; + csid = cevent->stream; + cssn = cevent->ssn; + if (sid > csid) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + if ((sid == csid) && SSN_lt(cssn, ssn)) { + __skb_queue_tail(&ulpq->lobby, sctp_event2skb(event)); + return; + } + + /* Find the right place in this list. We store them by + * stream ID and then by SSN. + */ + skb_queue_walk(&ulpq->lobby, pos) { + cevent = (struct sctp_ulpevent *) pos->cb; + csid = cevent->stream; + cssn = cevent->ssn; + + if (csid > sid) + break; + if (csid == sid && SSN_lt(ssn, cssn)) + break; + } + + + /* Insert before pos. */ + __skb_queue_before(&ulpq->lobby, pos, sctp_event2skb(event)); +} + +static struct sctp_ulpevent *sctp_ulpq_order(struct sctp_ulpq *ulpq, + struct sctp_ulpevent *event) +{ + __u16 sid, ssn; + struct sctp_stream *stream; + + /* Check if this message needs ordering. */ + if (event->msg_flags & SCTP_DATA_UNORDERED) + return event; + + /* Note: The stream ID must be verified before this routine. */ + sid = event->stream; + ssn = event->ssn; + stream = &ulpq->asoc->stream; + + /* Is this the expected SSN for this stream ID? */ + if (ssn != sctp_ssn_peek(stream, in, sid)) { + /* We've received something out of order, so find where it + * needs to be placed. We order by stream and then by SSN. + */ + sctp_ulpq_store_ordered(ulpq, event); + return NULL; + } + + /* Mark that the next chunk has been found. */ + sctp_ssn_next(stream, in, sid); + + /* Go find any other chunks that were waiting for + * ordering. + */ + sctp_ulpq_retrieve_ordered(ulpq, event); + + return event; +} + +/* Helper function to gather skbs that have possibly become + * ordered by forward tsn skipping their dependencies. + */ +static void sctp_ulpq_reap_ordered(struct sctp_ulpq *ulpq, __u16 sid) +{ + struct sk_buff *pos, *tmp; + struct sctp_ulpevent *cevent; + struct sctp_ulpevent *event; + struct sctp_stream *stream; + struct sk_buff_head temp; + struct sk_buff_head *lobby = &ulpq->lobby; + __u16 csid, cssn; + + stream = &ulpq->asoc->stream; + + /* We are holding the chunks by stream, by SSN. */ + skb_queue_head_init(&temp); + event = NULL; + sctp_skb_for_each(pos, lobby, tmp) { + cevent = (struct sctp_ulpevent *) pos->cb; + csid = cevent->stream; + cssn = cevent->ssn; + + /* Have we gone too far? */ + if (csid > sid) + break; + + /* Have we not gone far enough? */ + if (csid < sid) + continue; + + /* see if this ssn has been marked by skipping */ + if (!SSN_lt(cssn, sctp_ssn_peek(stream, in, csid))) + break; + + __skb_unlink(pos, lobby); + if (!event) + /* Create a temporary list to collect chunks on. */ + event = sctp_skb2event(pos); + + /* Attach all gathered skbs to the event. */ + __skb_queue_tail(&temp, pos); + } + + /* If we didn't reap any data, see if the next expected SSN + * is next on the queue and if so, use that. 
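+	 * (After the walk above, pos is left at the entry that stopped it,
+	 * or at the list head itself if the whole lobby was scanned; hence
+	 * the comparison against lobby below.)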
+ */ + if (event == NULL && pos != (struct sk_buff *)lobby) { + cevent = (struct sctp_ulpevent *) pos->cb; + csid = cevent->stream; + cssn = cevent->ssn; + + if (csid == sid && cssn == sctp_ssn_peek(stream, in, csid)) { + sctp_ssn_next(stream, in, csid); + __skb_unlink(pos, lobby); + __skb_queue_tail(&temp, pos); + event = sctp_skb2event(pos); + } + } + + /* Send event to the ULP. 'event' is the sctp_ulpevent for + * very first SKB on the 'temp' list. + */ + if (event) { + /* see if we have more ordered that we can deliver */ + sctp_ulpq_retrieve_ordered(ulpq, event); + sctp_ulpq_tail_event(ulpq, &temp); + } +} + +/* Skip over an SSN. This is used during the processing of + * Forwared TSN chunk to skip over the abandoned ordered data + */ +void sctp_ulpq_skip(struct sctp_ulpq *ulpq, __u16 sid, __u16 ssn) +{ + struct sctp_stream *stream; + + /* Note: The stream ID must be verified before this routine. */ + stream = &ulpq->asoc->stream; + + /* Is this an old SSN? If so ignore. */ + if (SSN_lt(ssn, sctp_ssn_peek(stream, in, sid))) + return; + + /* Mark that we are no longer expecting this SSN or lower. */ + sctp_ssn_skip(stream, in, sid, ssn); + + /* Go find any other chunks that were waiting for + * ordering and deliver them if needed. + */ + sctp_ulpq_reap_ordered(ulpq, sid); +} + +__u16 sctp_ulpq_renege_list(struct sctp_ulpq *ulpq, struct sk_buff_head *list, + __u16 needed) +{ + __u16 freed = 0; + __u32 tsn, last_tsn; + struct sk_buff *skb, *flist, *last; + struct sctp_ulpevent *event; + struct sctp_tsnmap *tsnmap; + + tsnmap = &ulpq->asoc->peer.tsn_map; + + while ((skb = skb_peek_tail(list)) != NULL) { + event = sctp_skb2event(skb); + tsn = event->tsn; + + /* Don't renege below the Cumulative TSN ACK Point. */ + if (TSN_lte(tsn, sctp_tsnmap_get_ctsn(tsnmap))) + break; + + /* Events in ordering queue may have multiple fragments + * corresponding to additional TSNs. Sum the total + * freed space; find the last TSN. + */ + freed += skb_headlen(skb); + flist = skb_shinfo(skb)->frag_list; + for (last = flist; flist; flist = flist->next) { + last = flist; + freed += skb_headlen(last); + } + if (last) + last_tsn = sctp_skb2event(last)->tsn; + else + last_tsn = tsn; + + /* Unlink the event, then renege all applicable TSNs. */ + __skb_unlink(skb, list); + sctp_ulpevent_free(event); + while (TSN_lte(tsn, last_tsn)) { + sctp_tsnmap_renege(tsnmap, tsn); + tsn++; + } + if (freed >= needed) + return freed; + } + + return freed; +} + +/* Renege 'needed' bytes from the ordering queue. */ +static __u16 sctp_ulpq_renege_order(struct sctp_ulpq *ulpq, __u16 needed) +{ + return sctp_ulpq_renege_list(ulpq, &ulpq->lobby, needed); +} + +/* Renege 'needed' bytes from the reassembly queue. */ +static __u16 sctp_ulpq_renege_frags(struct sctp_ulpq *ulpq, __u16 needed) +{ + return sctp_ulpq_renege_list(ulpq, &ulpq->reasm, needed); +} + +/* Partial deliver the first message as there is pressure on rwnd. */ +void sctp_ulpq_partial_delivery(struct sctp_ulpq *ulpq, + gfp_t gfp) +{ + struct sctp_ulpevent *event; + struct sctp_association *asoc; + struct sctp_sock *sp; + __u32 ctsn; + struct sk_buff *skb; + + asoc = ulpq->asoc; + sp = sctp_sk(asoc->base.sk); + + /* If the association is already in Partial Delivery mode + * we have nothing to do. + */ + if (ulpq->pd_mode) + return; + + /* Data must be at or below the Cumulative TSN ACK Point to + * start partial delivery. 
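+	 * Fragments at or below that point can no longer be reneged and no
+	 * earlier TSN is still missing, so it is safe to start handing them
+	 * to the user before the rest of the message arrives.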
+ */ + skb = skb_peek(&asoc->ulpq.reasm); + if (skb != NULL) { + ctsn = sctp_skb2event(skb)->tsn; + if (!TSN_lte(ctsn, sctp_tsnmap_get_ctsn(&asoc->peer.tsn_map))) + return; + } + + /* If the user enabled fragment interleave socket option, + * multiple associations can enter partial delivery. + * Otherwise, we can only enter partial delivery if the + * socket is not in partial deliver mode. + */ + if (sp->frag_interleave || atomic_read(&sp->pd_mode) == 0) { + /* Is partial delivery possible? */ + event = sctp_ulpq_retrieve_first(ulpq); + /* Send event to the ULP. */ + if (event) { + struct sk_buff_head temp; + + skb_queue_head_init(&temp); + __skb_queue_tail(&temp, sctp_event2skb(event)); + sctp_ulpq_tail_event(ulpq, &temp); + sctp_ulpq_set_pd(ulpq); + return; + } + } +} + +/* Renege some packets to make room for an incoming chunk. */ +void sctp_ulpq_renege(struct sctp_ulpq *ulpq, struct sctp_chunk *chunk, + gfp_t gfp) +{ + struct sctp_association *asoc = ulpq->asoc; + __u32 freed = 0; + __u16 needed; + + needed = ntohs(chunk->chunk_hdr->length) - + sizeof(struct sctp_data_chunk); + + if (skb_queue_empty(&asoc->base.sk->sk_receive_queue)) { + freed = sctp_ulpq_renege_order(ulpq, needed); + if (freed < needed) + freed += sctp_ulpq_renege_frags(ulpq, needed - freed); + } + /* If able to free enough room, accept this chunk. */ + if (sk_rmem_schedule(asoc->base.sk, chunk->skb, needed) && + freed >= needed) { + int retval = sctp_ulpq_tail_data(ulpq, chunk, gfp); + /* + * Enter partial delivery if chunk has not been + * delivered; otherwise, drain the reassembly queue. + */ + if (retval <= 0) + sctp_ulpq_partial_delivery(ulpq, gfp); + else if (retval == 1) + sctp_ulpq_reasm_drain(ulpq); + } +} + +/* Notify the application if an association is aborted and in + * partial delivery mode. Send up any pending received messages. + */ +void sctp_ulpq_abort_pd(struct sctp_ulpq *ulpq, gfp_t gfp) +{ + struct sctp_ulpevent *ev = NULL; + struct sctp_sock *sp; + struct sock *sk; + + if (!ulpq->pd_mode) + return; + + sk = ulpq->asoc->base.sk; + sp = sctp_sk(sk); + if (sctp_ulpevent_type_enabled(ulpq->asoc->subscribe, + SCTP_PARTIAL_DELIVERY_EVENT)) + ev = sctp_ulpevent_make_pdapi(ulpq->asoc, + SCTP_PARTIAL_DELIVERY_ABORTED, + 0, 0, 0, gfp); + if (ev) + __skb_queue_tail(&sk->sk_receive_queue, sctp_event2skb(ev)); + + /* If there is data waiting, send it up the socket now. */ + if ((sctp_ulpq_clear_pd(ulpq) || ev) && !sp->data_ready_signalled) { + sp->data_ready_signalled = 1; + sk->sk_data_ready(sk); + } +} diff --git a/net/smc/Kconfig b/net/smc/Kconfig new file mode 100644 index 000000000..1ab3c5a2c --- /dev/null +++ b/net/smc/Kconfig @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: GPL-2.0-only +config SMC + tristate "SMC socket protocol family" + depends on INET && INFINIBAND + help + SMC-R provides a "sockets over RDMA" solution making use of + RDMA over Converged Ethernet (RoCE) technology to upgrade + AF_INET TCP connections transparently. + The Linux implementation of the SMC-R solution is designed as + a separate socket family SMC. + + Select this option if you want to run SMC socket applications + +config SMC_DIAG + tristate "SMC: socket monitoring interface" + depends on SMC + help + Support for SMC socket monitoring interface used by tools such as + smcss. + + if unsure, say Y. 
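The help text above describes SMC as its own socket family; besides transparently converting AF_INET TCP sockets (e.g. via an LD_PRELOAD wrapper such as smc_run), an application can also request SMC explicitly. The following user-space sketch is illustrative only and not part of the patch: it assumes AF_SMC carries its usual value 43, that protocol 0 selects the IPv4 variant (SMCPROTO_SMC), and that the peer address and port are placeholders. When the SMC handshake cannot be completed, the connection normally falls back to plain TCP.

#include <arpa/inet.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#ifndef AF_SMC
#define AF_SMC 43	/* not yet defined by every libc */
#endif

int main(void)
{
	struct sockaddr_in peer;
	int fd;

	/* protocol 0 = SMCPROTO_SMC, the IPv4 flavour of the SMC family */
	fd = socket(AF_SMC, SOCK_STREAM, 0);
	if (fd < 0) {
		perror("socket(AF_SMC)");
		return 1;
	}

	memset(&peer, 0, sizeof(peer));
	peer.sin_family = AF_INET;
	peer.sin_port = htons(12345);			/* placeholder port */
	inet_pton(AF_INET, "192.0.2.1", &peer.sin_addr);	/* placeholder peer */

	/* From here on the socket behaves like a TCP stream socket. */
	if (connect(fd, (struct sockaddr *)&peer, sizeof(peer)) < 0)
		perror("connect");

	close(fd);
	return 0;
}
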
diff --git a/net/smc/Makefile b/net/smc/Makefile new file mode 100644 index 000000000..875efcd12 --- /dev/null +++ b/net/smc/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0-only +ccflags-y += -I$(src) +obj-$(CONFIG_SMC) += smc.o +obj-$(CONFIG_SMC_DIAG) += smc_diag.o +smc-y := af_smc.o smc_pnet.o smc_ib.o smc_clc.o smc_core.o smc_wr.o smc_llc.o +smc-y += smc_cdc.o smc_tx.o smc_rx.o smc_close.o smc_ism.o smc_netlink.o smc_stats.o +smc-y += smc_tracepoint.o +smc-$(CONFIG_SYSCTL) += smc_sysctl.o diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c new file mode 100644 index 000000000..b6609527d --- /dev/null +++ b/net/smc/af_smc.c @@ -0,0 +1,3563 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * AF_SMC protocol family socket handler keeping the AF_INET sock address type + * applies to SOCK_STREAM sockets only + * offers an alternative communication option for TCP-protocol sockets + * applicable with RoCE-cards only + * + * Initial restrictions: + * - support for alternate links postponed + * + * Copyright IBM Corp. 2016, 2018 + * + * Author(s): Ursula Braun + * based on prototype from Frank Blaschka + */ + +#define KMSG_COMPONENT "smc" +#define pr_fmt(fmt) KMSG_COMPONENT ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "smc_netns.h" + +#include "smc.h" +#include "smc_clc.h" +#include "smc_llc.h" +#include "smc_cdc.h" +#include "smc_core.h" +#include "smc_ib.h" +#include "smc_ism.h" +#include "smc_pnet.h" +#include "smc_netlink.h" +#include "smc_tx.h" +#include "smc_rx.h" +#include "smc_close.h" +#include "smc_stats.h" +#include "smc_tracepoint.h" +#include "smc_sysctl.h" + +static DEFINE_MUTEX(smc_server_lgr_pending); /* serialize link group + * creation on server + */ +static DEFINE_MUTEX(smc_client_lgr_pending); /* serialize link group + * creation on client + */ + +static struct workqueue_struct *smc_tcp_ls_wq; /* wq for tcp listen work */ +struct workqueue_struct *smc_hs_wq; /* wq for handshake work */ +struct workqueue_struct *smc_close_wq; /* wq for close work */ + +static void smc_tcp_listen_work(struct work_struct *); +static void smc_connect_work(struct work_struct *); + +int smc_nl_dump_hs_limitation(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + void *hdr; + + if (cb_ctx->pos[0]) + goto out; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_DUMP_HS_LIMITATION); + if (!hdr) + return -ENOMEM; + + if (nla_put_u8(skb, SMC_NLA_HS_LIMITATION_ENABLED, + sock_net(skb->sk)->smc.limit_smc_hs)) + goto err; + + genlmsg_end(skb, hdr); + cb_ctx->pos[0] = 1; +out: + return skb->len; +err: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +int smc_nl_enable_hs_limitation(struct sk_buff *skb, struct genl_info *info) +{ + sock_net(skb->sk)->smc.limit_smc_hs = true; + return 0; +} + +int smc_nl_disable_hs_limitation(struct sk_buff *skb, struct genl_info *info) +{ + sock_net(skb->sk)->smc.limit_smc_hs = false; + return 0; +} + +static void smc_set_keepalive(struct sock *sk, int val) +{ + struct smc_sock *smc = smc_sk(sk); + + smc->clcsock->sk->sk_prot->keepalive(smc->clcsock->sk, val); +} + +static struct sock *smc_tcp_syn_recv_sock(const struct sock *sk, + struct sk_buff *skb, + struct request_sock *req, + struct dst_entry *dst, + struct request_sock *req_unhash, + bool 
*own_req) +{ + struct smc_sock *smc; + struct sock *child; + + smc = smc_clcsock_user_data(sk); + + if (READ_ONCE(sk->sk_ack_backlog) + atomic_read(&smc->queued_smc_hs) > + sk->sk_max_ack_backlog) + goto drop; + + if (sk_acceptq_is_full(&smc->sk)) { + NET_INC_STATS(sock_net(sk), LINUX_MIB_LISTENOVERFLOWS); + goto drop; + } + + /* passthrough to original syn recv sock fct */ + child = smc->ori_af_ops->syn_recv_sock(sk, skb, req, dst, req_unhash, + own_req); + /* child must not inherit smc or its ops */ + if (child) { + rcu_assign_sk_user_data(child, NULL); + + /* v4-mapped sockets don't inherit parent ops. Don't restore. */ + if (inet_csk(child)->icsk_af_ops == inet_csk(sk)->icsk_af_ops) + inet_csk(child)->icsk_af_ops = smc->ori_af_ops; + } + return child; + +drop: + dst_release(dst); + tcp_listendrop(sk); + return NULL; +} + +static bool smc_hs_congested(const struct sock *sk) +{ + const struct smc_sock *smc; + + smc = smc_clcsock_user_data(sk); + + if (!smc) + return true; + + if (workqueue_congested(WORK_CPU_UNBOUND, smc_hs_wq)) + return true; + + return false; +} + +static struct smc_hashinfo smc_v4_hashinfo = { + .lock = __RW_LOCK_UNLOCKED(smc_v4_hashinfo.lock), +}; + +static struct smc_hashinfo smc_v6_hashinfo = { + .lock = __RW_LOCK_UNLOCKED(smc_v6_hashinfo.lock), +}; + +int smc_hash_sk(struct sock *sk) +{ + struct smc_hashinfo *h = sk->sk_prot->h.smc_hash; + struct hlist_head *head; + + head = &h->ht; + + write_lock_bh(&h->lock); + sk_add_node(sk, head); + write_unlock_bh(&h->lock); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); + + return 0; +} +EXPORT_SYMBOL_GPL(smc_hash_sk); + +void smc_unhash_sk(struct sock *sk) +{ + struct smc_hashinfo *h = sk->sk_prot->h.smc_hash; + + write_lock_bh(&h->lock); + if (sk_del_node_init(sk)) + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); + write_unlock_bh(&h->lock); +} +EXPORT_SYMBOL_GPL(smc_unhash_sk); + +/* This will be called before user really release sock_lock. 
So do the + * work which we didn't do because of user hold the sock_lock in the + * BH context + */ +static void smc_release_cb(struct sock *sk) +{ + struct smc_sock *smc = smc_sk(sk); + + if (smc->conn.tx_in_release_sock) { + smc_tx_pending(&smc->conn); + smc->conn.tx_in_release_sock = false; + } +} + +struct proto smc_proto = { + .name = "SMC", + .owner = THIS_MODULE, + .keepalive = smc_set_keepalive, + .hash = smc_hash_sk, + .unhash = smc_unhash_sk, + .release_cb = smc_release_cb, + .obj_size = sizeof(struct smc_sock), + .h.smc_hash = &smc_v4_hashinfo, + .slab_flags = SLAB_TYPESAFE_BY_RCU, +}; +EXPORT_SYMBOL_GPL(smc_proto); + +struct proto smc_proto6 = { + .name = "SMC6", + .owner = THIS_MODULE, + .keepalive = smc_set_keepalive, + .hash = smc_hash_sk, + .unhash = smc_unhash_sk, + .release_cb = smc_release_cb, + .obj_size = sizeof(struct smc_sock), + .h.smc_hash = &smc_v6_hashinfo, + .slab_flags = SLAB_TYPESAFE_BY_RCU, +}; +EXPORT_SYMBOL_GPL(smc_proto6); + +static void smc_fback_restore_callbacks(struct smc_sock *smc) +{ + struct sock *clcsk = smc->clcsock->sk; + + write_lock_bh(&clcsk->sk_callback_lock); + clcsk->sk_user_data = NULL; + + smc_clcsock_restore_cb(&clcsk->sk_state_change, &smc->clcsk_state_change); + smc_clcsock_restore_cb(&clcsk->sk_data_ready, &smc->clcsk_data_ready); + smc_clcsock_restore_cb(&clcsk->sk_write_space, &smc->clcsk_write_space); + smc_clcsock_restore_cb(&clcsk->sk_error_report, &smc->clcsk_error_report); + + write_unlock_bh(&clcsk->sk_callback_lock); +} + +static void smc_restore_fallback_changes(struct smc_sock *smc) +{ + if (smc->clcsock->file) { /* non-accepted sockets have no file yet */ + smc->clcsock->file->private_data = smc->sk.sk_socket; + smc->clcsock->file = NULL; + smc_fback_restore_callbacks(smc); + } +} + +static int __smc_release(struct smc_sock *smc) +{ + struct sock *sk = &smc->sk; + int rc = 0; + + if (!smc->use_fallback) { + rc = smc_close_active(smc); + smc_sock_set_flag(sk, SOCK_DEAD); + sk->sk_shutdown |= SHUTDOWN_MASK; + } else { + if (sk->sk_state != SMC_CLOSED) { + if (sk->sk_state != SMC_LISTEN && + sk->sk_state != SMC_INIT) + sock_put(sk); /* passive closing */ + if (sk->sk_state == SMC_LISTEN) { + /* wake up clcsock accept */ + rc = kernel_sock_shutdown(smc->clcsock, + SHUT_RDWR); + } + sk->sk_state = SMC_CLOSED; + sk->sk_state_change(sk); + } + smc_restore_fallback_changes(smc); + } + + sk->sk_prot->unhash(sk); + + if (sk->sk_state == SMC_CLOSED) { + if (smc->clcsock) { + release_sock(sk); + smc_clcsock_release(smc); + lock_sock(sk); + } + if (!smc->use_fallback) + smc_conn_free(&smc->conn); + } + + return rc; +} + +static int smc_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int old_state, rc = 0; + + if (!sk) + goto out; + + sock_hold(sk); /* sock_put below */ + smc = smc_sk(sk); + + old_state = sk->sk_state; + + /* cleanup for a dangling non-blocking connect */ + if (smc->connect_nonblock && old_state == SMC_INIT) + tcp_abort(smc->clcsock->sk, ECONNABORTED); + + if (cancel_work_sync(&smc->connect_work)) + sock_put(&smc->sk); /* sock_hold in smc_connect for passive closing */ + + if (sk->sk_state == SMC_LISTEN) + /* smc_close_non_accepted() is called and acquires + * sock lock for child sockets again + */ + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + else + lock_sock(sk); + + if (old_state == SMC_INIT && sk->sk_state == SMC_ACTIVE && + !smc->use_fallback) + smc_close_active_abort(smc); + + rc = __smc_release(smc); + + /* detach socket */ + sock_orphan(sk); + sock->sk = NULL; + 
release_sock(sk); + + sock_put(sk); /* sock_hold above */ + sock_put(sk); /* final sock_put */ +out: + return rc; +} + +static void smc_destruct(struct sock *sk) +{ + if (sk->sk_state != SMC_CLOSED) + return; + if (!sock_flag(sk, SOCK_DEAD)) + return; + + sk_refcnt_debug_dec(sk); +} + +static struct sock *smc_sock_alloc(struct net *net, struct socket *sock, + int protocol) +{ + struct smc_sock *smc; + struct proto *prot; + struct sock *sk; + + prot = (protocol == SMCPROTO_SMC6) ? &smc_proto6 : &smc_proto; + sk = sk_alloc(net, PF_SMC, GFP_KERNEL, prot, 0); + if (!sk) + return NULL; + + sock_init_data(sock, sk); /* sets sk_refcnt to 1 */ + sk->sk_state = SMC_INIT; + sk->sk_destruct = smc_destruct; + sk->sk_protocol = protocol; + WRITE_ONCE(sk->sk_sndbuf, 2 * READ_ONCE(net->smc.sysctl_wmem)); + WRITE_ONCE(sk->sk_rcvbuf, 2 * READ_ONCE(net->smc.sysctl_rmem)); + smc = smc_sk(sk); + INIT_WORK(&smc->tcp_listen_work, smc_tcp_listen_work); + INIT_WORK(&smc->connect_work, smc_connect_work); + INIT_DELAYED_WORK(&smc->conn.tx_work, smc_tx_work); + INIT_LIST_HEAD(&smc->accept_q); + spin_lock_init(&smc->accept_q_lock); + spin_lock_init(&smc->conn.send_lock); + sk->sk_prot->hash(sk); + sk_refcnt_debug_inc(sk); + mutex_init(&smc->clcsock_release_lock); + smc_init_saved_callbacks(smc); + + return sk; +} + +static int smc_bind(struct socket *sock, struct sockaddr *uaddr, + int addr_len) +{ + struct sockaddr_in *addr = (struct sockaddr_in *)uaddr; + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc; + + smc = smc_sk(sk); + + /* replicate tests from inet_bind(), to be safe wrt. future changes */ + rc = -EINVAL; + if (addr_len < sizeof(struct sockaddr_in)) + goto out; + + rc = -EAFNOSUPPORT; + if (addr->sin_family != AF_INET && + addr->sin_family != AF_INET6 && + addr->sin_family != AF_UNSPEC) + goto out; + /* accept AF_UNSPEC (mapped to AF_INET) only if s_addr is INADDR_ANY */ + if (addr->sin_family == AF_UNSPEC && + addr->sin_addr.s_addr != htonl(INADDR_ANY)) + goto out; + + lock_sock(sk); + + /* Check if socket is already active */ + rc = -EINVAL; + if (sk->sk_state != SMC_INIT || smc->connect_nonblock) + goto out_rel; + + smc->clcsock->sk->sk_reuse = sk->sk_reuse; + smc->clcsock->sk->sk_reuseport = sk->sk_reuseport; + rc = kernel_bind(smc->clcsock, uaddr, addr_len); + +out_rel: + release_sock(sk); +out: + return rc; +} + +/* copy only relevant settings and flags of SOL_SOCKET level from smc to + * clc socket (since smc is not called for these options from net/core) + */ + +#define SK_FLAGS_SMC_TO_CLC ((1UL << SOCK_URGINLINE) | \ + (1UL << SOCK_KEEPOPEN) | \ + (1UL << SOCK_LINGER) | \ + (1UL << SOCK_BROADCAST) | \ + (1UL << SOCK_TIMESTAMP) | \ + (1UL << SOCK_DBG) | \ + (1UL << SOCK_RCVTSTAMP) | \ + (1UL << SOCK_RCVTSTAMPNS) | \ + (1UL << SOCK_LOCALROUTE) | \ + (1UL << SOCK_TIMESTAMPING_RX_SOFTWARE) | \ + (1UL << SOCK_RXQ_OVFL) | \ + (1UL << SOCK_WIFI_STATUS) | \ + (1UL << SOCK_NOFCS) | \ + (1UL << SOCK_FILTER_LOCKED) | \ + (1UL << SOCK_TSTAMP_NEW)) + +/* if set, use value set by setsockopt() - else use IPv4 or SMC sysctl value */ +static void smc_adjust_sock_bufsizes(struct sock *nsk, struct sock *osk, + unsigned long mask) +{ + struct net *nnet = sock_net(nsk); + + nsk->sk_userlocks = osk->sk_userlocks; + if (osk->sk_userlocks & SOCK_SNDBUF_LOCK) { + nsk->sk_sndbuf = osk->sk_sndbuf; + } else { + if (mask == SK_FLAGS_SMC_TO_CLC) + WRITE_ONCE(nsk->sk_sndbuf, + READ_ONCE(nnet->ipv4.sysctl_tcp_wmem[1])); + else + WRITE_ONCE(nsk->sk_sndbuf, + 2 * READ_ONCE(nnet->smc.sysctl_wmem)); + } + if 
(osk->sk_userlocks & SOCK_RCVBUF_LOCK) { + nsk->sk_rcvbuf = osk->sk_rcvbuf; + } else { + if (mask == SK_FLAGS_SMC_TO_CLC) + WRITE_ONCE(nsk->sk_rcvbuf, + READ_ONCE(nnet->ipv4.sysctl_tcp_rmem[1])); + else + WRITE_ONCE(nsk->sk_rcvbuf, + 2 * READ_ONCE(nnet->smc.sysctl_rmem)); + } +} + +static void smc_copy_sock_settings(struct sock *nsk, struct sock *osk, + unsigned long mask) +{ + /* options we don't get control via setsockopt for */ + nsk->sk_type = osk->sk_type; + nsk->sk_sndtimeo = osk->sk_sndtimeo; + nsk->sk_rcvtimeo = osk->sk_rcvtimeo; + nsk->sk_mark = READ_ONCE(osk->sk_mark); + nsk->sk_priority = osk->sk_priority; + nsk->sk_rcvlowat = osk->sk_rcvlowat; + nsk->sk_bound_dev_if = osk->sk_bound_dev_if; + nsk->sk_err = osk->sk_err; + + nsk->sk_flags &= ~mask; + nsk->sk_flags |= osk->sk_flags & mask; + + smc_adjust_sock_bufsizes(nsk, osk, mask); +} + +static void smc_copy_sock_settings_to_clc(struct smc_sock *smc) +{ + smc_copy_sock_settings(smc->clcsock->sk, &smc->sk, SK_FLAGS_SMC_TO_CLC); +} + +#define SK_FLAGS_CLC_TO_SMC ((1UL << SOCK_URGINLINE) | \ + (1UL << SOCK_KEEPOPEN) | \ + (1UL << SOCK_LINGER) | \ + (1UL << SOCK_DBG)) +/* copy only settings and flags relevant for smc from clc to smc socket */ +static void smc_copy_sock_settings_to_smc(struct smc_sock *smc) +{ + smc_copy_sock_settings(&smc->sk, smc->clcsock->sk, SK_FLAGS_CLC_TO_SMC); +} + +/* register the new vzalloced sndbuf on all links */ +static int smcr_lgr_reg_sndbufs(struct smc_link *link, + struct smc_buf_desc *snd_desc) +{ + struct smc_link_group *lgr = link->lgr; + int i, rc = 0; + + if (!snd_desc->is_vm) + return -EINVAL; + + /* protect against parallel smcr_link_reg_buf() */ + mutex_lock(&lgr->llc_conf_mutex); + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_active(&lgr->lnk[i])) + continue; + rc = smcr_link_reg_buf(&lgr->lnk[i], snd_desc); + if (rc) + break; + } + mutex_unlock(&lgr->llc_conf_mutex); + return rc; +} + +/* register the new rmb on all links */ +static int smcr_lgr_reg_rmbs(struct smc_link *link, + struct smc_buf_desc *rmb_desc) +{ + struct smc_link_group *lgr = link->lgr; + int i, rc = 0; + + rc = smc_llc_flow_initiate(lgr, SMC_LLC_FLOW_RKEY); + if (rc) + return rc; + /* protect against parallel smc_llc_cli_rkey_exchange() and + * parallel smcr_link_reg_buf() + */ + mutex_lock(&lgr->llc_conf_mutex); + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_active(&lgr->lnk[i])) + continue; + rc = smcr_link_reg_buf(&lgr->lnk[i], rmb_desc); + if (rc) + goto out; + } + + /* exchange confirm_rkey msg with peer */ + rc = smc_llc_do_confirm_rkey(link, rmb_desc); + if (rc) { + rc = -EFAULT; + goto out; + } + rmb_desc->is_conf_rkey = true; +out: + mutex_unlock(&lgr->llc_conf_mutex); + smc_llc_flow_stop(lgr, &lgr->llc_flow_lcl); + return rc; +} + +static int smcr_clnt_conf_first_link(struct smc_sock *smc) +{ + struct smc_link *link = smc->conn.lnk; + struct smc_llc_qentry *qentry; + int rc; + + /* Receive CONFIRM LINK request from server over RoCE fabric. + * Increasing the client's timeout by twice as much as the server's + * timeout by default can temporarily avoid decline messages of + * both sides crossing or colliding + */ + qentry = smc_llc_wait(link->lgr, NULL, 2 * SMC_LLC_WAIT_TIME, + SMC_LLC_CONFIRM_LINK); + if (!qentry) { + struct smc_clc_msg_decline dclc; + + rc = smc_clc_wait_msg(smc, &dclc, sizeof(dclc), + SMC_CLC_DECLINE, CLC_WAIT_TIME_SHORT); + return rc == -EAGAIN ? 
SMC_CLC_DECL_TIMEOUT_CL : rc; + } + smc_llc_save_peer_uid(qentry); + rc = smc_llc_eval_conf_link(qentry, SMC_LLC_REQ); + smc_llc_flow_qentry_del(&link->lgr->llc_flow_lcl); + if (rc) + return SMC_CLC_DECL_RMBE_EC; + + rc = smc_ib_modify_qp_rts(link); + if (rc) + return SMC_CLC_DECL_ERR_RDYLNK; + + smc_wr_remember_qp_attr(link); + + /* reg the sndbuf if it was vzalloced */ + if (smc->conn.sndbuf_desc->is_vm) { + if (smcr_link_reg_buf(link, smc->conn.sndbuf_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + } + + /* reg the rmb */ + if (smcr_link_reg_buf(link, smc->conn.rmb_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + + /* confirm_rkey is implicit on 1st contact */ + smc->conn.rmb_desc->is_conf_rkey = true; + + /* send CONFIRM LINK response over RoCE fabric */ + rc = smc_llc_send_confirm_link(link, SMC_LLC_RESP); + if (rc < 0) + return SMC_CLC_DECL_TIMEOUT_CL; + + smc_llc_link_active(link); + smcr_lgr_set_type(link->lgr, SMC_LGR_SINGLE); + + /* optional 2nd link, receive ADD LINK request from server */ + qentry = smc_llc_wait(link->lgr, NULL, SMC_LLC_WAIT_TIME, + SMC_LLC_ADD_LINK); + if (!qentry) { + struct smc_clc_msg_decline dclc; + + rc = smc_clc_wait_msg(smc, &dclc, sizeof(dclc), + SMC_CLC_DECLINE, CLC_WAIT_TIME_SHORT); + if (rc == -EAGAIN) + rc = 0; /* no DECLINE received, go with one link */ + return rc; + } + smc_llc_flow_qentry_clr(&link->lgr->llc_flow_lcl); + smc_llc_cli_add_link(link, qentry); + return 0; +} + +static bool smc_isascii(char *hostname) +{ + int i; + + for (i = 0; i < SMC_MAX_HOSTNAME_LEN; i++) + if (!isascii(hostname[i])) + return false; + return true; +} + +static void smc_conn_save_peer_info_fce(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *clc) +{ + struct smc_clc_msg_accept_confirm_v2 *clc_v2 = + (struct smc_clc_msg_accept_confirm_v2 *)clc; + struct smc_clc_first_contact_ext *fce; + int clc_v2_len; + + if (clc->hdr.version == SMC_V1 || + !(clc->hdr.typev2 & SMC_FIRST_CONTACT_MASK)) + return; + + if (smc->conn.lgr->is_smcd) { + memcpy(smc->conn.lgr->negotiated_eid, clc_v2->d1.eid, + SMC_MAX_EID_LEN); + clc_v2_len = offsetofend(struct smc_clc_msg_accept_confirm_v2, + d1); + } else { + memcpy(smc->conn.lgr->negotiated_eid, clc_v2->r1.eid, + SMC_MAX_EID_LEN); + clc_v2_len = offsetofend(struct smc_clc_msg_accept_confirm_v2, + r1); + } + fce = (struct smc_clc_first_contact_ext *)(((u8 *)clc_v2) + clc_v2_len); + smc->conn.lgr->peer_os = fce->os_type; + smc->conn.lgr->peer_smc_release = fce->release; + if (smc_isascii(fce->hostname)) + memcpy(smc->conn.lgr->peer_hostname, fce->hostname, + SMC_MAX_HOSTNAME_LEN); +} + +static void smcr_conn_save_peer_info(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *clc) +{ + int bufsize = smc_uncompress_bufsize(clc->r0.rmbe_size); + + smc->conn.peer_rmbe_idx = clc->r0.rmbe_idx; + smc->conn.local_tx_ctrl.token = ntohl(clc->r0.rmbe_alert_token); + smc->conn.peer_rmbe_size = bufsize; + atomic_set(&smc->conn.peer_rmbe_space, smc->conn.peer_rmbe_size); + smc->conn.tx_off = bufsize * (smc->conn.peer_rmbe_idx - 1); +} + +static void smcd_conn_save_peer_info(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *clc) +{ + int bufsize = smc_uncompress_bufsize(clc->d0.dmbe_size); + + smc->conn.peer_rmbe_idx = clc->d0.dmbe_idx; + smc->conn.peer_token = clc->d0.token; + /* msg header takes up space in the buffer */ + smc->conn.peer_rmbe_size = bufsize - sizeof(struct smcd_cdc_msg); + atomic_set(&smc->conn.peer_rmbe_space, smc->conn.peer_rmbe_size); + smc->conn.tx_off = bufsize * smc->conn.peer_rmbe_idx; +} + +static void 
smc_conn_save_peer_info(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *clc) +{ + if (smc->conn.lgr->is_smcd) + smcd_conn_save_peer_info(smc, clc); + else + smcr_conn_save_peer_info(smc, clc); + smc_conn_save_peer_info_fce(smc, clc); +} + +static void smc_link_save_peer_info(struct smc_link *link, + struct smc_clc_msg_accept_confirm *clc, + struct smc_init_info *ini) +{ + link->peer_qpn = ntoh24(clc->r0.qpn); + memcpy(link->peer_gid, ini->peer_gid, SMC_GID_SIZE); + memcpy(link->peer_mac, ini->peer_mac, sizeof(link->peer_mac)); + link->peer_psn = ntoh24(clc->r0.psn); + link->peer_mtu = clc->r0.qp_mtu; +} + +static void smc_stat_inc_fback_rsn_cnt(struct smc_sock *smc, + struct smc_stats_fback *fback_arr) +{ + int cnt; + + for (cnt = 0; cnt < SMC_MAX_FBACK_RSN_CNT; cnt++) { + if (fback_arr[cnt].fback_code == smc->fallback_rsn) { + fback_arr[cnt].count++; + break; + } + if (!fback_arr[cnt].fback_code) { + fback_arr[cnt].fback_code = smc->fallback_rsn; + fback_arr[cnt].count++; + break; + } + } +} + +static void smc_stat_fallback(struct smc_sock *smc) +{ + struct net *net = sock_net(&smc->sk); + + mutex_lock(&net->smc.mutex_fback_rsn); + if (smc->listen_smc) { + smc_stat_inc_fback_rsn_cnt(smc, net->smc.fback_rsn->srv); + net->smc.fback_rsn->srv_fback_cnt++; + } else { + smc_stat_inc_fback_rsn_cnt(smc, net->smc.fback_rsn->clnt); + net->smc.fback_rsn->clnt_fback_cnt++; + } + mutex_unlock(&net->smc.mutex_fback_rsn); +} + +/* must be called under rcu read lock */ +static void smc_fback_wakeup_waitqueue(struct smc_sock *smc, void *key) +{ + struct socket_wq *wq; + __poll_t flags; + + wq = rcu_dereference(smc->sk.sk_wq); + if (!skwq_has_sleeper(wq)) + return; + + /* wake up smc sk->sk_wq */ + if (!key) { + /* sk_state_change */ + wake_up_interruptible_all(&wq->wait); + } else { + flags = key_to_poll(key); + if (flags & (EPOLLIN | EPOLLOUT)) + /* sk_data_ready or sk_write_space */ + wake_up_interruptible_sync_poll(&wq->wait, flags); + else if (flags & EPOLLERR) + /* sk_error_report */ + wake_up_interruptible_poll(&wq->wait, flags); + } +} + +static int smc_fback_mark_woken(wait_queue_entry_t *wait, + unsigned int mode, int sync, void *key) +{ + struct smc_mark_woken *mark = + container_of(wait, struct smc_mark_woken, wait_entry); + + mark->woken = true; + mark->key = key; + return 0; +} + +static void smc_fback_forward_wakeup(struct smc_sock *smc, struct sock *clcsk, + void (*clcsock_callback)(struct sock *sk)) +{ + struct smc_mark_woken mark = { .woken = false }; + struct socket_wq *wq; + + init_waitqueue_func_entry(&mark.wait_entry, + smc_fback_mark_woken); + rcu_read_lock(); + wq = rcu_dereference(clcsk->sk_wq); + if (!wq) + goto out; + add_wait_queue(sk_sleep(clcsk), &mark.wait_entry); + clcsock_callback(clcsk); + remove_wait_queue(sk_sleep(clcsk), &mark.wait_entry); + + if (mark.woken) + smc_fback_wakeup_waitqueue(smc, mark.key); +out: + rcu_read_unlock(); +} + +static void smc_fback_state_change(struct sock *clcsk) +{ + struct smc_sock *smc; + + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_state_change); + read_unlock_bh(&clcsk->sk_callback_lock); +} + +static void smc_fback_data_ready(struct sock *clcsk) +{ + struct smc_sock *smc; + + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_data_ready); + read_unlock_bh(&clcsk->sk_callback_lock); +} + +static void smc_fback_write_space(struct sock *clcsk) 
+{ + struct smc_sock *smc; + + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_write_space); + read_unlock_bh(&clcsk->sk_callback_lock); +} + +static void smc_fback_error_report(struct sock *clcsk) +{ + struct smc_sock *smc; + + read_lock_bh(&clcsk->sk_callback_lock); + smc = smc_clcsock_user_data(clcsk); + if (smc) + smc_fback_forward_wakeup(smc, clcsk, + smc->clcsk_error_report); + read_unlock_bh(&clcsk->sk_callback_lock); +} + +static void smc_fback_replace_callbacks(struct smc_sock *smc) +{ + struct sock *clcsk = smc->clcsock->sk; + + write_lock_bh(&clcsk->sk_callback_lock); + clcsk->sk_user_data = (void *)((uintptr_t)smc | SK_USER_DATA_NOCOPY); + + smc_clcsock_replace_cb(&clcsk->sk_state_change, smc_fback_state_change, + &smc->clcsk_state_change); + smc_clcsock_replace_cb(&clcsk->sk_data_ready, smc_fback_data_ready, + &smc->clcsk_data_ready); + smc_clcsock_replace_cb(&clcsk->sk_write_space, smc_fback_write_space, + &smc->clcsk_write_space); + smc_clcsock_replace_cb(&clcsk->sk_error_report, smc_fback_error_report, + &smc->clcsk_error_report); + + write_unlock_bh(&clcsk->sk_callback_lock); +} + +static int smc_switch_to_fallback(struct smc_sock *smc, int reason_code) +{ + int rc = 0; + + mutex_lock(&smc->clcsock_release_lock); + if (!smc->clcsock) { + rc = -EBADF; + goto out; + } + + smc->use_fallback = true; + smc->fallback_rsn = reason_code; + smc_stat_fallback(smc); + trace_smc_switch_to_fallback(smc, reason_code); + if (smc->sk.sk_socket && smc->sk.sk_socket->file) { + smc->clcsock->file = smc->sk.sk_socket->file; + smc->clcsock->file->private_data = smc->clcsock; + smc->clcsock->wq.fasync_list = + smc->sk.sk_socket->wq.fasync_list; + + /* There might be some wait entries remaining + * in smc sk->sk_wq and they should be woken up + * as clcsock's wait queue is woken up. 
+ */ + smc_fback_replace_callbacks(smc); + } +out: + mutex_unlock(&smc->clcsock_release_lock); + return rc; +} + +/* fall back during connect */ +static int smc_connect_fallback(struct smc_sock *smc, int reason_code) +{ + struct net *net = sock_net(&smc->sk); + int rc = 0; + + rc = smc_switch_to_fallback(smc, reason_code); + if (rc) { /* fallback fails */ + this_cpu_inc(net->smc.smc_stats->clnt_hshake_err_cnt); + if (smc->sk.sk_state == SMC_INIT) + sock_put(&smc->sk); /* passive closing */ + return rc; + } + smc_copy_sock_settings_to_clc(smc); + smc->connect_nonblock = 0; + if (smc->sk.sk_state == SMC_INIT) + smc->sk.sk_state = SMC_ACTIVE; + return 0; +} + +/* decline and fall back during connect */ +static int smc_connect_decline_fallback(struct smc_sock *smc, int reason_code, + u8 version) +{ + struct net *net = sock_net(&smc->sk); + int rc; + + if (reason_code < 0) { /* error, fallback is not possible */ + this_cpu_inc(net->smc.smc_stats->clnt_hshake_err_cnt); + if (smc->sk.sk_state == SMC_INIT) + sock_put(&smc->sk); /* passive closing */ + return reason_code; + } + if (reason_code != SMC_CLC_DECL_PEERDECL) { + rc = smc_clc_send_decline(smc, reason_code, version); + if (rc < 0) { + this_cpu_inc(net->smc.smc_stats->clnt_hshake_err_cnt); + if (smc->sk.sk_state == SMC_INIT) + sock_put(&smc->sk); /* passive closing */ + return rc; + } + } + return smc_connect_fallback(smc, reason_code); +} + +static void smc_conn_abort(struct smc_sock *smc, int local_first) +{ + struct smc_connection *conn = &smc->conn; + struct smc_link_group *lgr = conn->lgr; + bool lgr_valid = false; + + if (smc_conn_lgr_valid(conn)) + lgr_valid = true; + + smc_conn_free(conn); + if (local_first && lgr_valid) + smc_lgr_cleanup_early(lgr); +} + +/* check if there is a rdma device available for this connection. */ +/* called for connect and listen */ +static int smc_find_rdma_device(struct smc_sock *smc, struct smc_init_info *ini) +{ + /* PNET table look up: search active ib_device and port + * within same PNETID that also contains the ethernet device + * used for the internal TCP socket + */ + smc_pnet_find_roce_resource(smc->clcsock->sk, ini); + if (!ini->check_smcrv2 && !ini->ib_dev) + return SMC_CLC_DECL_NOSMCRDEV; + if (ini->check_smcrv2 && !ini->smcrv2.ib_dev_v2) + return SMC_CLC_DECL_NOSMCRDEV; + return 0; +} + +/* check if there is an ISM device available for this connection. */ +/* called for connect and listen */ +static int smc_find_ism_device(struct smc_sock *smc, struct smc_init_info *ini) +{ + /* Find ISM device with same PNETID as connecting interface */ + smc_pnet_find_ism_resource(smc->clcsock->sk, ini); + if (!ini->ism_dev[0]) + return SMC_CLC_DECL_NOSMCDDEV; + else + ini->ism_chid[0] = smc_ism_get_chid(ini->ism_dev[0]); + return 0; +} + +/* is chid unique for the ism devices that are already determined? */ +static bool smc_find_ism_v2_is_unique_chid(u16 chid, struct smc_init_info *ini, + int cnt) +{ + int i = (!ini->ism_dev[0]) ? 
1 : 0; + + for (; i < cnt; i++) + if (ini->ism_chid[i] == chid) + return false; + return true; +} + +/* determine possible V2 ISM devices (either without PNETID or with PNETID plus + * PNETID matching net_device) + */ +static int smc_find_ism_v2_device_clnt(struct smc_sock *smc, + struct smc_init_info *ini) +{ + int rc = SMC_CLC_DECL_NOSMCDDEV; + struct smcd_dev *smcd; + int i = 1; + u16 chid; + + if (smcd_indicated(ini->smc_type_v1)) + rc = 0; /* already initialized for V1 */ + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(smcd, &smcd_dev_list.list, list) { + if (smcd->going_away || smcd == ini->ism_dev[0]) + continue; + chid = smc_ism_get_chid(smcd); + if (!smc_find_ism_v2_is_unique_chid(chid, ini, i)) + continue; + if (!smc_pnet_is_pnetid_set(smcd->pnetid) || + smc_pnet_is_ndev_pnetid(sock_net(&smc->sk), smcd->pnetid)) { + ini->ism_dev[i] = smcd; + ini->ism_chid[i] = chid; + ini->is_smcd = true; + rc = 0; + i++; + if (i > SMC_MAX_ISM_DEVS) + break; + } + } + mutex_unlock(&smcd_dev_list.mutex); + ini->ism_offered_cnt = i - 1; + if (!ini->ism_dev[0] && !ini->ism_dev[1]) + ini->smcd_version = 0; + + return rc; +} + +/* Check for VLAN ID and register it on ISM device just for CLC handshake */ +static int smc_connect_ism_vlan_setup(struct smc_sock *smc, + struct smc_init_info *ini) +{ + if (ini->vlan_id && smc_ism_get_vlan(ini->ism_dev[0], ini->vlan_id)) + return SMC_CLC_DECL_ISMVLANERR; + return 0; +} + +static int smc_find_proposal_devices(struct smc_sock *smc, + struct smc_init_info *ini) +{ + int rc = 0; + + /* check if there is an ism device available */ + if (!(ini->smcd_version & SMC_V1) || + smc_find_ism_device(smc, ini) || + smc_connect_ism_vlan_setup(smc, ini)) + ini->smcd_version &= ~SMC_V1; + /* else ISM V1 is supported for this connection */ + + /* check if there is an rdma device available */ + if (!(ini->smcr_version & SMC_V1) || + smc_find_rdma_device(smc, ini)) + ini->smcr_version &= ~SMC_V1; + /* else RDMA is supported for this connection */ + + ini->smc_type_v1 = smc_indicated_type(ini->smcd_version & SMC_V1, + ini->smcr_version & SMC_V1); + + /* check if there is an ism v2 device available */ + if (!(ini->smcd_version & SMC_V2) || + !smc_ism_is_v2_capable() || + smc_find_ism_v2_device_clnt(smc, ini)) + ini->smcd_version &= ~SMC_V2; + + /* check if there is an rdma v2 device available */ + ini->check_smcrv2 = true; + ini->smcrv2.saddr = smc->clcsock->sk->sk_rcv_saddr; + if (!(ini->smcr_version & SMC_V2) || + smc->clcsock->sk->sk_family != AF_INET || + !smc_clc_ueid_count() || + smc_find_rdma_device(smc, ini)) + ini->smcr_version &= ~SMC_V2; + ini->check_smcrv2 = false; + + ini->smc_type_v2 = smc_indicated_type(ini->smcd_version & SMC_V2, + ini->smcr_version & SMC_V2); + + /* if neither ISM nor RDMA are supported, fallback */ + if (ini->smc_type_v1 == SMC_TYPE_N && ini->smc_type_v2 == SMC_TYPE_N) + rc = SMC_CLC_DECL_NOSMCDEV; + + return rc; +} + +/* cleanup temporary VLAN ID registration used for CLC handshake. If ISM is + * used, the VLAN ID will be registered again during the connection setup. 
+ */ +static int smc_connect_ism_vlan_cleanup(struct smc_sock *smc, + struct smc_init_info *ini) +{ + if (!smcd_indicated(ini->smc_type_v1)) + return 0; + if (ini->vlan_id && smc_ism_put_vlan(ini->ism_dev[0], ini->vlan_id)) + return SMC_CLC_DECL_CNFERR; + return 0; +} + +#define SMC_CLC_MAX_ACCEPT_LEN \ + (sizeof(struct smc_clc_msg_accept_confirm_v2) + \ + sizeof(struct smc_clc_first_contact_ext) + \ + sizeof(struct smc_clc_msg_trail)) + +/* CLC handshake during connect */ +static int smc_connect_clc(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm_v2 *aclc2, + struct smc_init_info *ini) +{ + int rc = 0; + + /* do inband token exchange */ + rc = smc_clc_send_proposal(smc, ini); + if (rc) + return rc; + /* receive SMC Accept CLC message */ + return smc_clc_wait_msg(smc, aclc2, SMC_CLC_MAX_ACCEPT_LEN, + SMC_CLC_ACCEPT, CLC_WAIT_TIME); +} + +void smc_fill_gid_list(struct smc_link_group *lgr, + struct smc_gidlist *gidlist, + struct smc_ib_device *known_dev, u8 *known_gid) +{ + struct smc_init_info *alt_ini = NULL; + + memset(gidlist, 0, sizeof(*gidlist)); + memcpy(gidlist->list[gidlist->len++], known_gid, SMC_GID_SIZE); + + alt_ini = kzalloc(sizeof(*alt_ini), GFP_KERNEL); + if (!alt_ini) + goto out; + + alt_ini->vlan_id = lgr->vlan_id; + alt_ini->check_smcrv2 = true; + alt_ini->smcrv2.saddr = lgr->saddr; + smc_pnet_find_alt_roce(lgr, alt_ini, known_dev); + + if (!alt_ini->smcrv2.ib_dev_v2) + goto out; + + memcpy(gidlist->list[gidlist->len++], alt_ini->smcrv2.ib_gid_v2, + SMC_GID_SIZE); + +out: + kfree(alt_ini); +} + +static int smc_connect_rdma_v2_prepare(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *aclc, + struct smc_init_info *ini) +{ + struct smc_clc_msg_accept_confirm_v2 *clc_v2 = + (struct smc_clc_msg_accept_confirm_v2 *)aclc; + struct smc_clc_first_contact_ext *fce = + (struct smc_clc_first_contact_ext *) + (((u8 *)clc_v2) + sizeof(*clc_v2)); + struct net *net = sock_net(&smc->sk); + + if (!ini->first_contact_peer || aclc->hdr.version == SMC_V1) + return 0; + + if (fce->v2_direct) { + memcpy(ini->smcrv2.nexthop_mac, &aclc->r0.lcl.mac, ETH_ALEN); + ini->smcrv2.uses_gateway = false; + } else { + if (smc_ib_find_route(net, smc->clcsock->sk->sk_rcv_saddr, + smc_ib_gid_to_ipv4(aclc->r0.lcl.gid), + ini->smcrv2.nexthop_mac, + &ini->smcrv2.uses_gateway)) + return SMC_CLC_DECL_NOROUTE; + if (!ini->smcrv2.uses_gateway) { + /* mismatch: peer claims indirect, but its direct */ + return SMC_CLC_DECL_NOINDIRECT; + } + } + return 0; +} + +/* setup for RDMA connection of client */ +static int smc_connect_rdma(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *aclc, + struct smc_init_info *ini) +{ + int i, reason_code = 0; + struct smc_link *link; + u8 *eid = NULL; + + ini->is_smcd = false; + ini->ib_clcqpn = ntoh24(aclc->r0.qpn); + ini->first_contact_peer = aclc->hdr.typev2 & SMC_FIRST_CONTACT_MASK; + memcpy(ini->peer_systemid, aclc->r0.lcl.id_for_peer, SMC_SYSTEMID_LEN); + memcpy(ini->peer_gid, aclc->r0.lcl.gid, SMC_GID_SIZE); + memcpy(ini->peer_mac, aclc->r0.lcl.mac, ETH_ALEN); + + reason_code = smc_connect_rdma_v2_prepare(smc, aclc, ini); + if (reason_code) + return reason_code; + + mutex_lock(&smc_client_lgr_pending); + reason_code = smc_conn_create(smc, ini); + if (reason_code) { + mutex_unlock(&smc_client_lgr_pending); + return reason_code; + } + + smc_conn_save_peer_info(smc, aclc); + + if (ini->first_contact_local) { + link = smc->conn.lnk; + } else { + /* set link that was assigned by server */ + link = NULL; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + 
struct smc_link *l = &smc->conn.lgr->lnk[i]; + + if (l->peer_qpn == ntoh24(aclc->r0.qpn) && + !memcmp(l->peer_gid, &aclc->r0.lcl.gid, + SMC_GID_SIZE) && + (aclc->hdr.version > SMC_V1 || + !memcmp(l->peer_mac, &aclc->r0.lcl.mac, + sizeof(l->peer_mac)))) { + link = l; + break; + } + } + if (!link) { + reason_code = SMC_CLC_DECL_NOSRVLINK; + goto connect_abort; + } + smc_switch_link_and_count(&smc->conn, link); + } + + /* create send buffer and rmb */ + if (smc_buf_create(smc, false)) { + reason_code = SMC_CLC_DECL_MEM; + goto connect_abort; + } + + if (ini->first_contact_local) + smc_link_save_peer_info(link, aclc, ini); + + if (smc_rmb_rtoken_handling(&smc->conn, link, aclc)) { + reason_code = SMC_CLC_DECL_ERR_RTOK; + goto connect_abort; + } + + smc_close_init(smc); + smc_rx_init(smc); + + if (ini->first_contact_local) { + if (smc_ib_ready_link(link)) { + reason_code = SMC_CLC_DECL_ERR_RDYLNK; + goto connect_abort; + } + } else { + /* reg sendbufs if they were vzalloced */ + if (smc->conn.sndbuf_desc->is_vm) { + if (smcr_lgr_reg_sndbufs(link, smc->conn.sndbuf_desc)) { + reason_code = SMC_CLC_DECL_ERR_REGBUF; + goto connect_abort; + } + } + if (smcr_lgr_reg_rmbs(link, smc->conn.rmb_desc)) { + reason_code = SMC_CLC_DECL_ERR_REGBUF; + goto connect_abort; + } + } + + if (aclc->hdr.version > SMC_V1) { + struct smc_clc_msg_accept_confirm_v2 *clc_v2 = + (struct smc_clc_msg_accept_confirm_v2 *)aclc; + + eid = clc_v2->r1.eid; + if (ini->first_contact_local) + smc_fill_gid_list(link->lgr, &ini->smcrv2.gidlist, + link->smcibdev, link->gid); + } + + reason_code = smc_clc_send_confirm(smc, ini->first_contact_local, + aclc->hdr.version, eid, ini); + if (reason_code) + goto connect_abort; + + smc_tx_init(smc); + + if (ini->first_contact_local) { + /* QP confirmation over RoCE fabric */ + smc_llc_flow_initiate(link->lgr, SMC_LLC_FLOW_ADD_LINK); + reason_code = smcr_clnt_conf_first_link(smc); + smc_llc_flow_stop(link->lgr, &link->lgr->llc_flow_lcl); + if (reason_code) + goto connect_abort; + } + mutex_unlock(&smc_client_lgr_pending); + + smc_copy_sock_settings_to_clc(smc); + smc->connect_nonblock = 0; + if (smc->sk.sk_state == SMC_INIT) + smc->sk.sk_state = SMC_ACTIVE; + + return 0; +connect_abort: + smc_conn_abort(smc, ini->first_contact_local); + mutex_unlock(&smc_client_lgr_pending); + smc->connect_nonblock = 0; + + return reason_code; +} + +/* The server has chosen one of the proposed ISM devices for the communication. + * Determine from the CHID of the received CLC ACCEPT the ISM device chosen. 
+ */ +static int +smc_v2_determine_accepted_chid(struct smc_clc_msg_accept_confirm_v2 *aclc, + struct smc_init_info *ini) +{ + int i; + + for (i = 0; i < ini->ism_offered_cnt + 1; i++) { + if (ini->ism_chid[i] == ntohs(aclc->d1.chid)) { + ini->ism_selected = i; + return 0; + } + } + + return -EPROTO; +} + +/* setup for ISM connection of client */ +static int smc_connect_ism(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm *aclc, + struct smc_init_info *ini) +{ + u8 *eid = NULL; + int rc = 0; + + ini->is_smcd = true; + ini->first_contact_peer = aclc->hdr.typev2 & SMC_FIRST_CONTACT_MASK; + + if (aclc->hdr.version == SMC_V2) { + struct smc_clc_msg_accept_confirm_v2 *aclc_v2 = + (struct smc_clc_msg_accept_confirm_v2 *)aclc; + + rc = smc_v2_determine_accepted_chid(aclc_v2, ini); + if (rc) + return rc; + } + ini->ism_peer_gid[ini->ism_selected] = aclc->d0.gid; + + /* there is only one lgr role for SMC-D; use server lock */ + mutex_lock(&smc_server_lgr_pending); + rc = smc_conn_create(smc, ini); + if (rc) { + mutex_unlock(&smc_server_lgr_pending); + return rc; + } + + /* Create send and receive buffers */ + rc = smc_buf_create(smc, true); + if (rc) { + rc = (rc == -ENOSPC) ? SMC_CLC_DECL_MAX_DMB : SMC_CLC_DECL_MEM; + goto connect_abort; + } + + smc_conn_save_peer_info(smc, aclc); + smc_close_init(smc); + smc_rx_init(smc); + smc_tx_init(smc); + + if (aclc->hdr.version > SMC_V1) { + struct smc_clc_msg_accept_confirm_v2 *clc_v2 = + (struct smc_clc_msg_accept_confirm_v2 *)aclc; + + eid = clc_v2->d1.eid; + } + + rc = smc_clc_send_confirm(smc, ini->first_contact_local, + aclc->hdr.version, eid, NULL); + if (rc) + goto connect_abort; + mutex_unlock(&smc_server_lgr_pending); + + smc_copy_sock_settings_to_clc(smc); + smc->connect_nonblock = 0; + if (smc->sk.sk_state == SMC_INIT) + smc->sk.sk_state = SMC_ACTIVE; + + return 0; +connect_abort: + smc_conn_abort(smc, ini->first_contact_local); + mutex_unlock(&smc_server_lgr_pending); + smc->connect_nonblock = 0; + + return rc; +} + +/* check if received accept type and version matches a proposed one */ +static int smc_connect_check_aclc(struct smc_init_info *ini, + struct smc_clc_msg_accept_confirm *aclc) +{ + if (aclc->hdr.typev1 != SMC_TYPE_R && + aclc->hdr.typev1 != SMC_TYPE_D) + return SMC_CLC_DECL_MODEUNSUPP; + + if (aclc->hdr.version >= SMC_V2) { + if ((aclc->hdr.typev1 == SMC_TYPE_R && + !smcr_indicated(ini->smc_type_v2)) || + (aclc->hdr.typev1 == SMC_TYPE_D && + !smcd_indicated(ini->smc_type_v2))) + return SMC_CLC_DECL_MODEUNSUPP; + } else { + if ((aclc->hdr.typev1 == SMC_TYPE_R && + !smcr_indicated(ini->smc_type_v1)) || + (aclc->hdr.typev1 == SMC_TYPE_D && + !smcd_indicated(ini->smc_type_v1))) + return SMC_CLC_DECL_MODEUNSUPP; + } + + return 0; +} + +/* perform steps before actually connecting */ +static int __smc_connect(struct smc_sock *smc) +{ + u8 version = smc_ism_is_v2_capable() ? 
SMC_V2 : SMC_V1; + struct smc_clc_msg_accept_confirm_v2 *aclc2; + struct smc_clc_msg_accept_confirm *aclc; + struct smc_init_info *ini = NULL; + u8 *buf = NULL; + int rc = 0; + + if (smc->use_fallback) + return smc_connect_fallback(smc, smc->fallback_rsn); + + /* if peer has not signalled SMC-capability, fall back */ + if (!tcp_sk(smc->clcsock->sk)->syn_smc) + return smc_connect_fallback(smc, SMC_CLC_DECL_PEERNOSMC); + + /* IPSec connections opt out of SMC optimizations */ + if (using_ipsec(smc)) + return smc_connect_decline_fallback(smc, SMC_CLC_DECL_IPSEC, + version); + + ini = kzalloc(sizeof(*ini), GFP_KERNEL); + if (!ini) + return smc_connect_decline_fallback(smc, SMC_CLC_DECL_MEM, + version); + + ini->smcd_version = SMC_V1 | SMC_V2; + ini->smcr_version = SMC_V1 | SMC_V2; + ini->smc_type_v1 = SMC_TYPE_B; + ini->smc_type_v2 = SMC_TYPE_B; + + /* get vlan id from IP device */ + if (smc_vlan_by_tcpsk(smc->clcsock, ini)) { + ini->smcd_version &= ~SMC_V1; + ini->smcr_version = 0; + ini->smc_type_v1 = SMC_TYPE_N; + if (!ini->smcd_version) { + rc = SMC_CLC_DECL_GETVLANERR; + goto fallback; + } + } + + rc = smc_find_proposal_devices(smc, ini); + if (rc) + goto fallback; + + buf = kzalloc(SMC_CLC_MAX_ACCEPT_LEN, GFP_KERNEL); + if (!buf) { + rc = SMC_CLC_DECL_MEM; + goto fallback; + } + aclc2 = (struct smc_clc_msg_accept_confirm_v2 *)buf; + aclc = (struct smc_clc_msg_accept_confirm *)aclc2; + + /* perform CLC handshake */ + rc = smc_connect_clc(smc, aclc2, ini); + if (rc) { + /* -EAGAIN on timeout, see tcp_recvmsg() */ + if (rc == -EAGAIN) { + rc = -ETIMEDOUT; + smc->sk.sk_err = ETIMEDOUT; + } + goto vlan_cleanup; + } + + /* check if smc modes and versions of CLC proposal and accept match */ + rc = smc_connect_check_aclc(ini, aclc); + version = aclc->hdr.version == SMC_V1 ? 
SMC_V1 : SMC_V2; + if (rc) + goto vlan_cleanup; + + /* depending on previous steps, connect using rdma or ism */ + if (aclc->hdr.typev1 == SMC_TYPE_R) { + ini->smcr_version = version; + rc = smc_connect_rdma(smc, aclc, ini); + } else if (aclc->hdr.typev1 == SMC_TYPE_D) { + ini->smcd_version = version; + rc = smc_connect_ism(smc, aclc, ini); + } + if (rc) + goto vlan_cleanup; + + SMC_STAT_CLNT_SUCC_INC(sock_net(smc->clcsock->sk), aclc); + smc_connect_ism_vlan_cleanup(smc, ini); + kfree(buf); + kfree(ini); + return 0; + +vlan_cleanup: + smc_connect_ism_vlan_cleanup(smc, ini); + kfree(buf); +fallback: + kfree(ini); + return smc_connect_decline_fallback(smc, rc, version); +} + +static void smc_connect_work(struct work_struct *work) +{ + struct smc_sock *smc = container_of(work, struct smc_sock, + connect_work); + long timeo = smc->sk.sk_sndtimeo; + int rc = 0; + + if (!timeo) + timeo = MAX_SCHEDULE_TIMEOUT; + lock_sock(smc->clcsock->sk); + if (smc->clcsock->sk->sk_err) { + smc->sk.sk_err = smc->clcsock->sk->sk_err; + } else if ((1 << smc->clcsock->sk->sk_state) & + (TCPF_SYN_SENT | TCPF_SYN_RECV)) { + rc = sk_stream_wait_connect(smc->clcsock->sk, &timeo); + if ((rc == -EPIPE) && + ((1 << smc->clcsock->sk->sk_state) & + (TCPF_ESTABLISHED | TCPF_CLOSE_WAIT))) + rc = 0; + } + release_sock(smc->clcsock->sk); + lock_sock(&smc->sk); + if (rc != 0 || smc->sk.sk_err) { + smc->sk.sk_state = SMC_CLOSED; + if (rc == -EPIPE || rc == -EAGAIN) + smc->sk.sk_err = EPIPE; + else if (rc == -ECONNREFUSED) + smc->sk.sk_err = ECONNREFUSED; + else if (signal_pending(current)) + smc->sk.sk_err = -sock_intr_errno(timeo); + sock_put(&smc->sk); /* passive closing */ + goto out; + } + + rc = __smc_connect(smc); + if (rc < 0) + smc->sk.sk_err = -rc; + +out: + if (!sock_flag(&smc->sk, SOCK_DEAD)) { + if (smc->sk.sk_err) { + smc->sk.sk_state_change(&smc->sk); + } else { /* allow polling before and after fallback decision */ + smc->clcsock->sk->sk_write_space(smc->clcsock->sk); + smc->sk.sk_write_space(&smc->sk); + } + } + release_sock(&smc->sk); +} + +static int smc_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc = -EINVAL; + + smc = smc_sk(sk); + + /* separate smc parameter checking to be safe */ + if (alen < sizeof(addr->sa_family)) + goto out_err; + if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) + goto out_err; + + lock_sock(sk); + switch (sock->state) { + default: + rc = -EINVAL; + goto out; + case SS_CONNECTED: + rc = sk->sk_state == SMC_ACTIVE ? -EISCONN : -EINVAL; + goto out; + case SS_CONNECTING: + if (sk->sk_state == SMC_ACTIVE) + goto connected; + break; + case SS_UNCONNECTED: + sock->state = SS_CONNECTING; + break; + } + + switch (sk->sk_state) { + default: + goto out; + case SMC_CLOSED: + rc = sock_error(sk) ? : -ECONNABORTED; + sock->state = SS_UNCONNECTED; + goto out; + case SMC_ACTIVE: + rc = -EISCONN; + goto out; + case SMC_INIT: + break; + } + + smc_copy_sock_settings_to_clc(smc); + tcp_sk(smc->clcsock->sk)->syn_smc = 1; + if (smc->connect_nonblock) { + rc = -EALREADY; + goto out; + } + rc = kernel_connect(smc->clcsock, addr, alen, flags); + if (rc && rc != -EINPROGRESS) + goto out; + + if (smc->use_fallback) { + sock->state = rc ? 
SS_CONNECTING : SS_CONNECTED; + goto out; + } + sock_hold(&smc->sk); /* sock put in passive closing */ + if (flags & O_NONBLOCK) { + if (queue_work(smc_hs_wq, &smc->connect_work)) + smc->connect_nonblock = 1; + rc = -EINPROGRESS; + goto out; + } else { + rc = __smc_connect(smc); + if (rc < 0) + goto out; + } + +connected: + rc = 0; + sock->state = SS_CONNECTED; +out: + release_sock(sk); +out_err: + return rc; +} + +static int smc_clcsock_accept(struct smc_sock *lsmc, struct smc_sock **new_smc) +{ + struct socket *new_clcsock = NULL; + struct sock *lsk = &lsmc->sk; + struct sock *new_sk; + int rc = -EINVAL; + + release_sock(lsk); + new_sk = smc_sock_alloc(sock_net(lsk), NULL, lsk->sk_protocol); + if (!new_sk) { + rc = -ENOMEM; + lsk->sk_err = ENOMEM; + *new_smc = NULL; + lock_sock(lsk); + goto out; + } + *new_smc = smc_sk(new_sk); + + mutex_lock(&lsmc->clcsock_release_lock); + if (lsmc->clcsock) + rc = kernel_accept(lsmc->clcsock, &new_clcsock, SOCK_NONBLOCK); + mutex_unlock(&lsmc->clcsock_release_lock); + lock_sock(lsk); + if (rc < 0 && rc != -EAGAIN) + lsk->sk_err = -rc; + if (rc < 0 || lsk->sk_state == SMC_CLOSED) { + new_sk->sk_prot->unhash(new_sk); + if (new_clcsock) + sock_release(new_clcsock); + new_sk->sk_state = SMC_CLOSED; + smc_sock_set_flag(new_sk, SOCK_DEAD); + sock_put(new_sk); /* final */ + *new_smc = NULL; + goto out; + } + + /* new clcsock has inherited the smc listen-specific sk_data_ready + * function; switch it back to the original sk_data_ready function + */ + new_clcsock->sk->sk_data_ready = lsmc->clcsk_data_ready; + + /* if new clcsock has also inherited the fallback-specific callback + * functions, switch them back to the original ones. + */ + if (lsmc->use_fallback) { + if (lsmc->clcsk_state_change) + new_clcsock->sk->sk_state_change = lsmc->clcsk_state_change; + if (lsmc->clcsk_write_space) + new_clcsock->sk->sk_write_space = lsmc->clcsk_write_space; + if (lsmc->clcsk_error_report) + new_clcsock->sk->sk_error_report = lsmc->clcsk_error_report; + } + + (*new_smc)->clcsock = new_clcsock; +out: + return rc; +} + +/* add a just created sock to the accept queue of the listen sock as + * candidate for a following socket accept call from user space + */ +static void smc_accept_enqueue(struct sock *parent, struct sock *sk) +{ + struct smc_sock *par = smc_sk(parent); + + sock_hold(sk); /* sock_put in smc_accept_unlink () */ + spin_lock(&par->accept_q_lock); + list_add_tail(&smc_sk(sk)->accept_q, &par->accept_q); + spin_unlock(&par->accept_q_lock); + sk_acceptq_added(parent); +} + +/* remove a socket from the accept queue of its parental listening socket */ +static void smc_accept_unlink(struct sock *sk) +{ + struct smc_sock *par = smc_sk(sk)->listen_smc; + + spin_lock(&par->accept_q_lock); + list_del_init(&smc_sk(sk)->accept_q); + spin_unlock(&par->accept_q_lock); + sk_acceptq_removed(&smc_sk(sk)->listen_smc->sk); + sock_put(sk); /* sock_hold in smc_accept_enqueue */ +} + +/* remove a sock from the accept queue to bind it to a new socket created + * for a socket accept call from user space + */ +struct sock *smc_accept_dequeue(struct sock *parent, + struct socket *new_sock) +{ + struct smc_sock *isk, *n; + struct sock *new_sk; + + list_for_each_entry_safe(isk, n, &smc_sk(parent)->accept_q, accept_q) { + new_sk = (struct sock *)isk; + + smc_accept_unlink(new_sk); + if (new_sk->sk_state == SMC_CLOSED) { + new_sk->sk_prot->unhash(new_sk); + if (isk->clcsock) { + sock_release(isk->clcsock); + isk->clcsock = NULL; + } + sock_put(new_sk); /* final */ + continue; + } + if 
(new_sock) { + sock_graft(new_sk, new_sock); + new_sock->state = SS_CONNECTED; + if (isk->use_fallback) { + smc_sk(new_sk)->clcsock->file = new_sock->file; + isk->clcsock->file->private_data = isk->clcsock; + } + } + return new_sk; + } + return NULL; +} + +/* clean up for a created but never accepted sock */ +void smc_close_non_accepted(struct sock *sk) +{ + struct smc_sock *smc = smc_sk(sk); + + sock_hold(sk); /* sock_put below */ + lock_sock(sk); + if (!sk->sk_lingertime) + /* wait for peer closing */ + WRITE_ONCE(sk->sk_lingertime, SMC_MAX_STREAM_WAIT_TIMEOUT); + __smc_release(smc); + release_sock(sk); + sock_put(sk); /* sock_hold above */ + sock_put(sk); /* final sock_put */ +} + +static int smcr_serv_conf_first_link(struct smc_sock *smc) +{ + struct smc_link *link = smc->conn.lnk; + struct smc_llc_qentry *qentry; + int rc; + + /* reg the sndbuf if it was vzalloced*/ + if (smc->conn.sndbuf_desc->is_vm) { + if (smcr_link_reg_buf(link, smc->conn.sndbuf_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + } + + /* reg the rmb */ + if (smcr_link_reg_buf(link, smc->conn.rmb_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + + /* send CONFIRM LINK request to client over the RoCE fabric */ + rc = smc_llc_send_confirm_link(link, SMC_LLC_REQ); + if (rc < 0) + return SMC_CLC_DECL_TIMEOUT_CL; + + /* receive CONFIRM LINK response from client over the RoCE fabric */ + qentry = smc_llc_wait(link->lgr, link, SMC_LLC_WAIT_TIME, + SMC_LLC_CONFIRM_LINK); + if (!qentry) { + struct smc_clc_msg_decline dclc; + + rc = smc_clc_wait_msg(smc, &dclc, sizeof(dclc), + SMC_CLC_DECLINE, CLC_WAIT_TIME_SHORT); + return rc == -EAGAIN ? SMC_CLC_DECL_TIMEOUT_CL : rc; + } + smc_llc_save_peer_uid(qentry); + rc = smc_llc_eval_conf_link(qentry, SMC_LLC_RESP); + smc_llc_flow_qentry_del(&link->lgr->llc_flow_lcl); + if (rc) + return SMC_CLC_DECL_RMBE_EC; + + /* confirm_rkey is implicit on 1st contact */ + smc->conn.rmb_desc->is_conf_rkey = true; + + smc_llc_link_active(link); + smcr_lgr_set_type(link->lgr, SMC_LGR_SINGLE); + + mutex_lock(&link->lgr->llc_conf_mutex); + /* initial contact - try to establish second link */ + smc_llc_srv_add_link(link, NULL); + mutex_unlock(&link->lgr->llc_conf_mutex); + return 0; +} + +/* listen worker: finish */ +static void smc_listen_out(struct smc_sock *new_smc) +{ + struct smc_sock *lsmc = new_smc->listen_smc; + struct sock *newsmcsk = &new_smc->sk; + + if (tcp_sk(new_smc->clcsock->sk)->syn_smc) + atomic_dec(&lsmc->queued_smc_hs); + + if (lsmc->sk.sk_state == SMC_LISTEN) { + lock_sock_nested(&lsmc->sk, SINGLE_DEPTH_NESTING); + smc_accept_enqueue(&lsmc->sk, newsmcsk); + release_sock(&lsmc->sk); + } else { /* no longer listening */ + smc_close_non_accepted(newsmcsk); + } + + /* Wake up accept */ + lsmc->sk.sk_data_ready(&lsmc->sk); + sock_put(&lsmc->sk); /* sock_hold in smc_tcp_listen_work */ +} + +/* listen worker: finish in state connected */ +static void smc_listen_out_connected(struct smc_sock *new_smc) +{ + struct sock *newsmcsk = &new_smc->sk; + + if (newsmcsk->sk_state == SMC_INIT) + newsmcsk->sk_state = SMC_ACTIVE; + + smc_listen_out(new_smc); +} + +/* listen worker: finish in error state */ +static void smc_listen_out_err(struct smc_sock *new_smc) +{ + struct sock *newsmcsk = &new_smc->sk; + struct net *net = sock_net(newsmcsk); + + this_cpu_inc(net->smc.smc_stats->srv_hshake_err_cnt); + if (newsmcsk->sk_state == SMC_INIT) + sock_put(&new_smc->sk); /* passive closing */ + newsmcsk->sk_state = SMC_CLOSED; + + smc_listen_out(new_smc); +} + +/* listen worker: decline and fall back if possible */ +static 
void smc_listen_decline(struct smc_sock *new_smc, int reason_code, + int local_first, u8 version) +{ + /* RDMA setup failed, switch back to TCP */ + smc_conn_abort(new_smc, local_first); + if (reason_code < 0 || + smc_switch_to_fallback(new_smc, reason_code)) { + /* error, no fallback possible */ + smc_listen_out_err(new_smc); + return; + } + if (reason_code && reason_code != SMC_CLC_DECL_PEERDECL) { + if (smc_clc_send_decline(new_smc, reason_code, version) < 0) { + smc_listen_out_err(new_smc); + return; + } + } + smc_listen_out_connected(new_smc); +} + +/* listen worker: version checking */ +static int smc_listen_v2_check(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + struct smc_clc_smcd_v2_extension *pclc_smcd_v2_ext; + struct smc_clc_v2_extension *pclc_v2_ext; + int rc = SMC_CLC_DECL_PEERNOSMC; + + ini->smc_type_v1 = pclc->hdr.typev1; + ini->smc_type_v2 = pclc->hdr.typev2; + ini->smcd_version = smcd_indicated(ini->smc_type_v1) ? SMC_V1 : 0; + ini->smcr_version = smcr_indicated(ini->smc_type_v1) ? SMC_V1 : 0; + if (pclc->hdr.version > SMC_V1) { + if (smcd_indicated(ini->smc_type_v2)) + ini->smcd_version |= SMC_V2; + if (smcr_indicated(ini->smc_type_v2)) + ini->smcr_version |= SMC_V2; + } + if (!(ini->smcd_version & SMC_V2) && !(ini->smcr_version & SMC_V2)) { + rc = SMC_CLC_DECL_PEERNOSMC; + goto out; + } + pclc_v2_ext = smc_get_clc_v2_ext(pclc); + if (!pclc_v2_ext) { + ini->smcd_version &= ~SMC_V2; + ini->smcr_version &= ~SMC_V2; + rc = SMC_CLC_DECL_NOV2EXT; + goto out; + } + pclc_smcd_v2_ext = smc_get_clc_smcd_v2_ext(pclc_v2_ext); + if (ini->smcd_version & SMC_V2) { + if (!smc_ism_is_v2_capable()) { + ini->smcd_version &= ~SMC_V2; + rc = SMC_CLC_DECL_NOISM2SUPP; + } else if (!pclc_smcd_v2_ext) { + ini->smcd_version &= ~SMC_V2; + rc = SMC_CLC_DECL_NOV2DEXT; + } else if (!pclc_v2_ext->hdr.eid_cnt && + !pclc_v2_ext->hdr.flag.seid) { + ini->smcd_version &= ~SMC_V2; + rc = SMC_CLC_DECL_NOUEID; + } + } + if (ini->smcr_version & SMC_V2) { + if (!pclc_v2_ext->hdr.eid_cnt) { + ini->smcr_version &= ~SMC_V2; + rc = SMC_CLC_DECL_NOUEID; + } + } + +out: + if (!ini->smcd_version && !ini->smcr_version) + return rc; + + return 0; +} + +/* listen worker: check prefixes */ +static int smc_listen_prfx_check(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc) +{ + struct smc_clc_msg_proposal_prefix *pclc_prfx; + struct socket *newclcsock = new_smc->clcsock; + + if (pclc->hdr.typev1 == SMC_TYPE_N) + return 0; + pclc_prfx = smc_clc_proposal_get_prefix(pclc); + if (smc_clc_prfx_match(newclcsock, pclc_prfx)) + return SMC_CLC_DECL_DIFFPREFIX; + + return 0; +} + +/* listen worker: initialize connection and buffers */ +static int smc_listen_rdma_init(struct smc_sock *new_smc, + struct smc_init_info *ini) +{ + int rc; + + /* allocate connection / link group */ + rc = smc_conn_create(new_smc, ini); + if (rc) + return rc; + + /* create send buffer and rmb */ + if (smc_buf_create(new_smc, false)) { + smc_conn_abort(new_smc, ini->first_contact_local); + return SMC_CLC_DECL_MEM; + } + + return 0; +} + +/* listen worker: initialize connection and buffers for SMC-D */ +static int smc_listen_ism_init(struct smc_sock *new_smc, + struct smc_init_info *ini) +{ + int rc; + + rc = smc_conn_create(new_smc, ini); + if (rc) + return rc; + + /* Create send and receive buffers */ + rc = smc_buf_create(new_smc, true); + if (rc) { + smc_conn_abort(new_smc, ini->first_contact_local); + return (rc == -ENOSPC) ? 
SMC_CLC_DECL_MAX_DMB : + SMC_CLC_DECL_MEM; + } + + return 0; +} + +static bool smc_is_already_selected(struct smcd_dev *smcd, + struct smc_init_info *ini, + int matches) +{ + int i; + + for (i = 0; i < matches; i++) + if (smcd == ini->ism_dev[i]) + return true; + + return false; +} + +/* check for ISM devices matching proposed ISM devices */ +static void smc_check_ism_v2_match(struct smc_init_info *ini, + u16 proposed_chid, u64 proposed_gid, + unsigned int *matches) +{ + struct smcd_dev *smcd; + + list_for_each_entry(smcd, &smcd_dev_list.list, list) { + if (smcd->going_away) + continue; + if (smc_is_already_selected(smcd, ini, *matches)) + continue; + if (smc_ism_get_chid(smcd) == proposed_chid && + !smc_ism_cantalk(proposed_gid, ISM_RESERVED_VLANID, smcd)) { + ini->ism_peer_gid[*matches] = proposed_gid; + ini->ism_dev[*matches] = smcd; + (*matches)++; + break; + } + } +} + +static void smc_find_ism_store_rc(u32 rc, struct smc_init_info *ini) +{ + if (!ini->rc) + ini->rc = rc; +} + +static void smc_find_ism_v2_device_serv(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + struct smc_clc_smcd_v2_extension *smcd_v2_ext; + struct smc_clc_v2_extension *smc_v2_ext; + struct smc_clc_msg_smcd *pclc_smcd; + unsigned int matches = 0; + u8 smcd_version; + u8 *eid = NULL; + int i, rc; + + if (!(ini->smcd_version & SMC_V2) || !smcd_indicated(ini->smc_type_v2)) + goto not_found; + + pclc_smcd = smc_get_clc_msg_smcd(pclc); + smc_v2_ext = smc_get_clc_v2_ext(pclc); + smcd_v2_ext = smc_get_clc_smcd_v2_ext(smc_v2_ext); + + mutex_lock(&smcd_dev_list.mutex); + if (pclc_smcd->ism.chid) + /* check for ISM device matching proposed native ISM device */ + smc_check_ism_v2_match(ini, ntohs(pclc_smcd->ism.chid), + ntohll(pclc_smcd->ism.gid), &matches); + for (i = 1; i <= smc_v2_ext->hdr.ism_gid_cnt; i++) { + /* check for ISM devices matching proposed non-native ISM + * devices + */ + smc_check_ism_v2_match(ini, + ntohs(smcd_v2_ext->gidchid[i - 1].chid), + ntohll(smcd_v2_ext->gidchid[i - 1].gid), + &matches); + } + mutex_unlock(&smcd_dev_list.mutex); + + if (!ini->ism_dev[0]) { + smc_find_ism_store_rc(SMC_CLC_DECL_NOSMCD2DEV, ini); + goto not_found; + } + + smc_ism_get_system_eid(&eid); + if (!smc_clc_match_eid(ini->negotiated_eid, smc_v2_ext, + smcd_v2_ext->system_eid, eid)) + goto not_found; + + /* separate - outside the smcd_dev_list.lock */ + smcd_version = ini->smcd_version; + for (i = 0; i < matches; i++) { + ini->smcd_version = SMC_V2; + ini->is_smcd = true; + ini->ism_selected = i; + rc = smc_listen_ism_init(new_smc, ini); + if (rc) { + smc_find_ism_store_rc(rc, ini); + /* try next active ISM device */ + continue; + } + return; /* matching and usable V2 ISM device found */ + } + /* no V2 ISM device could be initialized */ + ini->smcd_version = smcd_version; /* restore original value */ + ini->negotiated_eid[0] = 0; + +not_found: + ini->smcd_version &= ~SMC_V2; + ini->ism_dev[0] = NULL; + ini->is_smcd = false; +} + +static void smc_find_ism_v1_device_serv(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + struct smc_clc_msg_smcd *pclc_smcd = smc_get_clc_msg_smcd(pclc); + int rc = 0; + + /* check if ISM V1 is available */ + if (!(ini->smcd_version & SMC_V1) || !smcd_indicated(ini->smc_type_v1)) + goto not_found; + ini->is_smcd = true; /* prepare ISM check */ + ini->ism_peer_gid[0] = ntohll(pclc_smcd->ism.gid); + rc = smc_find_ism_device(new_smc, ini); + if (rc) + goto not_found; + ini->ism_selected = 0; + rc = 
smc_listen_ism_init(new_smc, ini); + if (!rc) + return; /* V1 ISM device found */ + +not_found: + smc_find_ism_store_rc(rc, ini); + ini->smcd_version &= ~SMC_V1; + ini->ism_dev[0] = NULL; + ini->is_smcd = false; +} + +/* listen worker: register buffers */ +static int smc_listen_rdma_reg(struct smc_sock *new_smc, bool local_first) +{ + struct smc_connection *conn = &new_smc->conn; + + if (!local_first) { + /* reg sendbufs if they were vzalloced */ + if (conn->sndbuf_desc->is_vm) { + if (smcr_lgr_reg_sndbufs(conn->lnk, + conn->sndbuf_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + } + if (smcr_lgr_reg_rmbs(conn->lnk, conn->rmb_desc)) + return SMC_CLC_DECL_ERR_REGBUF; + } + + return 0; +} + +static void smc_find_rdma_v2_device_serv(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + struct smc_clc_v2_extension *smc_v2_ext; + u8 smcr_version; + int rc; + + if (!(ini->smcr_version & SMC_V2) || !smcr_indicated(ini->smc_type_v2)) + goto not_found; + + smc_v2_ext = smc_get_clc_v2_ext(pclc); + if (!smc_clc_match_eid(ini->negotiated_eid, smc_v2_ext, NULL, NULL)) + goto not_found; + + /* prepare RDMA check */ + memcpy(ini->peer_systemid, pclc->lcl.id_for_peer, SMC_SYSTEMID_LEN); + memcpy(ini->peer_gid, smc_v2_ext->roce, SMC_GID_SIZE); + memcpy(ini->peer_mac, pclc->lcl.mac, ETH_ALEN); + ini->check_smcrv2 = true; + ini->smcrv2.clc_sk = new_smc->clcsock->sk; + ini->smcrv2.saddr = new_smc->clcsock->sk->sk_rcv_saddr; + ini->smcrv2.daddr = smc_ib_gid_to_ipv4(smc_v2_ext->roce); + rc = smc_find_rdma_device(new_smc, ini); + if (rc) { + smc_find_ism_store_rc(rc, ini); + goto not_found; + } + if (!ini->smcrv2.uses_gateway) + memcpy(ini->smcrv2.nexthop_mac, pclc->lcl.mac, ETH_ALEN); + + smcr_version = ini->smcr_version; + ini->smcr_version = SMC_V2; + rc = smc_listen_rdma_init(new_smc, ini); + if (!rc) { + rc = smc_listen_rdma_reg(new_smc, ini->first_contact_local); + if (rc) + smc_conn_abort(new_smc, ini->first_contact_local); + } + if (!rc) + return; + ini->smcr_version = smcr_version; + smc_find_ism_store_rc(rc, ini); + +not_found: + ini->smcr_version &= ~SMC_V2; + ini->smcrv2.ib_dev_v2 = NULL; + ini->check_smcrv2 = false; +} + +static int smc_find_rdma_v1_device_serv(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + int rc; + + if (!(ini->smcr_version & SMC_V1) || !smcr_indicated(ini->smc_type_v1)) + return SMC_CLC_DECL_NOSMCDEV; + + /* prepare RDMA check */ + memcpy(ini->peer_systemid, pclc->lcl.id_for_peer, SMC_SYSTEMID_LEN); + memcpy(ini->peer_gid, pclc->lcl.gid, SMC_GID_SIZE); + memcpy(ini->peer_mac, pclc->lcl.mac, ETH_ALEN); + rc = smc_find_rdma_device(new_smc, ini); + if (rc) { + /* no RDMA device found */ + return SMC_CLC_DECL_NOSMCDEV; + } + rc = smc_listen_rdma_init(new_smc, ini); + if (rc) + return rc; + return smc_listen_rdma_reg(new_smc, ini->first_contact_local); +} + +/* determine the local device matching to proposal */ +static int smc_listen_find_device(struct smc_sock *new_smc, + struct smc_clc_msg_proposal *pclc, + struct smc_init_info *ini) +{ + int prfx_rc; + + /* check for ISM device matching V2 proposed device */ + smc_find_ism_v2_device_serv(new_smc, pclc, ini); + if (ini->ism_dev[0]) + return 0; + + /* check for matching IP prefix and subnet length (V1) */ + prfx_rc = smc_listen_prfx_check(new_smc, pclc); + if (prfx_rc) + smc_find_ism_store_rc(prfx_rc, ini); + + /* get vlan id from IP device */ + if (smc_vlan_by_tcpsk(new_smc->clcsock, ini)) + return ini->rc ?: SMC_CLC_DECL_GETVLANERR; + + /* check 
for ISM device matching V1 proposed device */ + if (!prfx_rc) + smc_find_ism_v1_device_serv(new_smc, pclc, ini); + if (ini->ism_dev[0]) + return 0; + + if (!smcr_indicated(pclc->hdr.typev1) && + !smcr_indicated(pclc->hdr.typev2)) + /* skip RDMA and decline */ + return ini->rc ?: SMC_CLC_DECL_NOSMCDDEV; + + /* check if RDMA V2 is available */ + smc_find_rdma_v2_device_serv(new_smc, pclc, ini); + if (ini->smcrv2.ib_dev_v2) + return 0; + + /* check if RDMA V1 is available */ + if (!prfx_rc) { + int rc; + + rc = smc_find_rdma_v1_device_serv(new_smc, pclc, ini); + smc_find_ism_store_rc(rc, ini); + return (!rc) ? 0 : ini->rc; + } + return prfx_rc; +} + +/* listen worker: finish RDMA setup */ +static int smc_listen_rdma_finish(struct smc_sock *new_smc, + struct smc_clc_msg_accept_confirm *cclc, + bool local_first, + struct smc_init_info *ini) +{ + struct smc_link *link = new_smc->conn.lnk; + int reason_code = 0; + + if (local_first) + smc_link_save_peer_info(link, cclc, ini); + + if (smc_rmb_rtoken_handling(&new_smc->conn, link, cclc)) + return SMC_CLC_DECL_ERR_RTOK; + + if (local_first) { + if (smc_ib_ready_link(link)) + return SMC_CLC_DECL_ERR_RDYLNK; + /* QP confirmation over RoCE fabric */ + smc_llc_flow_initiate(link->lgr, SMC_LLC_FLOW_ADD_LINK); + reason_code = smcr_serv_conf_first_link(new_smc); + smc_llc_flow_stop(link->lgr, &link->lgr->llc_flow_lcl); + } + return reason_code; +} + +/* setup for connection of server */ +static void smc_listen_work(struct work_struct *work) +{ + struct smc_sock *new_smc = container_of(work, struct smc_sock, + smc_listen_work); + struct socket *newclcsock = new_smc->clcsock; + struct smc_clc_msg_accept_confirm *cclc; + struct smc_clc_msg_proposal_area *buf; + struct smc_clc_msg_proposal *pclc; + struct smc_init_info *ini = NULL; + u8 proposal_version = SMC_V1; + u8 accept_version; + int rc = 0; + + if (new_smc->listen_smc->sk.sk_state != SMC_LISTEN) + return smc_listen_out_err(new_smc); + + if (new_smc->use_fallback) { + smc_listen_out_connected(new_smc); + return; + } + + /* check if peer is smc capable */ + if (!tcp_sk(newclcsock->sk)->syn_smc) { + rc = smc_switch_to_fallback(new_smc, SMC_CLC_DECL_PEERNOSMC); + if (rc) + smc_listen_out_err(new_smc); + else + smc_listen_out_connected(new_smc); + return; + } + + /* do inband token exchange - + * wait for and receive SMC Proposal CLC message + */ + buf = kzalloc(sizeof(*buf), GFP_KERNEL); + if (!buf) { + rc = SMC_CLC_DECL_MEM; + goto out_decl; + } + pclc = (struct smc_clc_msg_proposal *)buf; + rc = smc_clc_wait_msg(new_smc, pclc, sizeof(*buf), + SMC_CLC_PROPOSAL, CLC_WAIT_TIME); + if (rc) + goto out_decl; + + if (pclc->hdr.version > SMC_V1) + proposal_version = SMC_V2; + + /* IPSec connections opt out of SMC optimizations */ + if (using_ipsec(new_smc)) { + rc = SMC_CLC_DECL_IPSEC; + goto out_decl; + } + + ini = kzalloc(sizeof(*ini), GFP_KERNEL); + if (!ini) { + rc = SMC_CLC_DECL_MEM; + goto out_decl; + } + + /* initial version checking */ + rc = smc_listen_v2_check(new_smc, pclc, ini); + if (rc) + goto out_decl; + + mutex_lock(&smc_server_lgr_pending); + smc_close_init(new_smc); + smc_rx_init(new_smc); + smc_tx_init(new_smc); + + /* determine ISM or RoCE device used for connection */ + rc = smc_listen_find_device(new_smc, pclc, ini); + if (rc) + goto out_unlock; + + /* send SMC Accept CLC message */ + accept_version = ini->is_smcd ? 
ini->smcd_version : ini->smcr_version; + rc = smc_clc_send_accept(new_smc, ini->first_contact_local, + accept_version, ini->negotiated_eid); + if (rc) + goto out_unlock; + + /* SMC-D does not need this lock any more */ + if (ini->is_smcd) + mutex_unlock(&smc_server_lgr_pending); + + /* receive SMC Confirm CLC message */ + memset(buf, 0, sizeof(*buf)); + cclc = (struct smc_clc_msg_accept_confirm *)buf; + rc = smc_clc_wait_msg(new_smc, cclc, sizeof(*buf), + SMC_CLC_CONFIRM, CLC_WAIT_TIME); + if (rc) { + if (!ini->is_smcd) + goto out_unlock; + goto out_decl; + } + + /* finish worker */ + if (!ini->is_smcd) { + rc = smc_listen_rdma_finish(new_smc, cclc, + ini->first_contact_local, ini); + if (rc) + goto out_unlock; + mutex_unlock(&smc_server_lgr_pending); + } + smc_conn_save_peer_info(new_smc, cclc); + smc_listen_out_connected(new_smc); + SMC_STAT_SERV_SUCC_INC(sock_net(newclcsock->sk), ini); + goto out_free; + +out_unlock: + mutex_unlock(&smc_server_lgr_pending); +out_decl: + smc_listen_decline(new_smc, rc, ini ? ini->first_contact_local : 0, + proposal_version); +out_free: + kfree(ini); + kfree(buf); +} + +static void smc_tcp_listen_work(struct work_struct *work) +{ + struct smc_sock *lsmc = container_of(work, struct smc_sock, + tcp_listen_work); + struct sock *lsk = &lsmc->sk; + struct smc_sock *new_smc; + int rc = 0; + + lock_sock(lsk); + while (lsk->sk_state == SMC_LISTEN) { + rc = smc_clcsock_accept(lsmc, &new_smc); + if (rc) /* clcsock accept queue empty or error */ + goto out; + if (!new_smc) + continue; + + if (tcp_sk(new_smc->clcsock->sk)->syn_smc) + atomic_inc(&lsmc->queued_smc_hs); + + new_smc->listen_smc = lsmc; + new_smc->use_fallback = lsmc->use_fallback; + new_smc->fallback_rsn = lsmc->fallback_rsn; + sock_hold(lsk); /* sock_put in smc_listen_work */ + INIT_WORK(&new_smc->smc_listen_work, smc_listen_work); + smc_copy_sock_settings_to_smc(new_smc); + sock_hold(&new_smc->sk); /* sock_put in passive closing */ + if (!queue_work(smc_hs_wq, &new_smc->smc_listen_work)) + sock_put(&new_smc->sk); + } + +out: + release_sock(lsk); + sock_put(&lsmc->sk); /* sock_hold in smc_clcsock_data_ready() */ +} + +static void smc_clcsock_data_ready(struct sock *listen_clcsock) +{ + struct smc_sock *lsmc; + + read_lock_bh(&listen_clcsock->sk_callback_lock); + lsmc = smc_clcsock_user_data(listen_clcsock); + if (!lsmc) + goto out; + lsmc->clcsk_data_ready(listen_clcsock); + if (lsmc->sk.sk_state == SMC_LISTEN) { + sock_hold(&lsmc->sk); /* sock_put in smc_tcp_listen_work() */ + if (!queue_work(smc_tcp_ls_wq, &lsmc->tcp_listen_work)) + sock_put(&lsmc->sk); + } +out: + read_unlock_bh(&listen_clcsock->sk_callback_lock); +} + +static int smc_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc; + + smc = smc_sk(sk); + lock_sock(sk); + + rc = -EINVAL; + if ((sk->sk_state != SMC_INIT && sk->sk_state != SMC_LISTEN) || + smc->connect_nonblock || sock->state != SS_UNCONNECTED) + goto out; + + rc = 0; + if (sk->sk_state == SMC_LISTEN) { + sk->sk_max_ack_backlog = backlog; + goto out; + } + /* some socket options are handled in core, so we could not apply + * them to the clc socket -- copy smc socket options to clc socket + */ + smc_copy_sock_settings_to_clc(smc); + if (!smc->use_fallback) + tcp_sk(smc->clcsock->sk)->syn_smc = 1; + + /* save original sk_data_ready function and establish + * smc-specific sk_data_ready function + */ + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); + smc->clcsock->sk->sk_user_data = + (void *)((uintptr_t)smc | 
SK_USER_DATA_NOCOPY); + smc_clcsock_replace_cb(&smc->clcsock->sk->sk_data_ready, + smc_clcsock_data_ready, &smc->clcsk_data_ready); + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); + + /* save original ops */ + smc->ori_af_ops = inet_csk(smc->clcsock->sk)->icsk_af_ops; + + smc->af_ops = *smc->ori_af_ops; + smc->af_ops.syn_recv_sock = smc_tcp_syn_recv_sock; + + inet_csk(smc->clcsock->sk)->icsk_af_ops = &smc->af_ops; + + if (smc->limit_smc_hs) + tcp_sk(smc->clcsock->sk)->smc_hs_congested = smc_hs_congested; + + rc = kernel_listen(smc->clcsock, backlog); + if (rc) { + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); + smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, + &smc->clcsk_data_ready); + smc->clcsock->sk->sk_user_data = NULL; + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); + goto out; + } + sk->sk_max_ack_backlog = backlog; + sk->sk_ack_backlog = 0; + sk->sk_state = SMC_LISTEN; + +out: + release_sock(sk); + return rc; +} + +static int smc_accept(struct socket *sock, struct socket *new_sock, + int flags, bool kern) +{ + struct sock *sk = sock->sk, *nsk; + DECLARE_WAITQUEUE(wait, current); + struct smc_sock *lsmc; + long timeo; + int rc = 0; + + lsmc = smc_sk(sk); + sock_hold(sk); /* sock_put below */ + lock_sock(sk); + + if (lsmc->sk.sk_state != SMC_LISTEN) { + rc = -EINVAL; + release_sock(sk); + goto out; + } + + /* Wait for an incoming connection */ + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + add_wait_queue_exclusive(sk_sleep(sk), &wait); + while (!(nsk = smc_accept_dequeue(sk, new_sock))) { + set_current_state(TASK_INTERRUPTIBLE); + if (!timeo) { + rc = -EAGAIN; + break; + } + release_sock(sk); + timeo = schedule_timeout(timeo); + /* wakeup by sk_data_ready in smc_listen_work() */ + sched_annotate_sleep(); + lock_sock(sk); + if (signal_pending(current)) { + rc = sock_intr_errno(timeo); + break; + } + } + set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + + if (!rc) + rc = sock_error(nsk); + release_sock(sk); + if (rc) + goto out; + + if (lsmc->sockopt_defer_accept && !(flags & O_NONBLOCK)) { + /* wait till data arrives on the socket */ + timeo = msecs_to_jiffies(lsmc->sockopt_defer_accept * + MSEC_PER_SEC); + if (smc_sk(nsk)->use_fallback) { + struct sock *clcsk = smc_sk(nsk)->clcsock->sk; + + lock_sock(clcsk); + if (skb_queue_empty(&clcsk->sk_receive_queue)) + sk_wait_data(clcsk, &timeo, NULL); + release_sock(clcsk); + } else if (!atomic_read(&smc_sk(nsk)->conn.bytes_to_rcv)) { + lock_sock(nsk); + smc_rx_wait(smc_sk(nsk), &timeo, smc_rx_data_available); + release_sock(nsk); + } + } + +out: + sock_put(sk); /* sock_hold above */ + return rc; +} + +static int smc_getname(struct socket *sock, struct sockaddr *addr, + int peer) +{ + struct smc_sock *smc; + + if (peer && (sock->sk->sk_state != SMC_ACTIVE) && + (sock->sk->sk_state != SMC_APPCLOSEWAIT1)) + return -ENOTCONN; + + smc = smc_sk(sock->sk); + + return smc->clcsock->ops->getname(smc->clcsock, addr, peer); +} + +static int smc_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc; + + smc = smc_sk(sk); + lock_sock(sk); + + /* SMC does not support connect with fastopen */ + if (msg->msg_flags & MSG_FASTOPEN) { + /* not connected yet, fallback */ + if (sk->sk_state == SMC_INIT && !smc->connect_nonblock) { + rc = smc_switch_to_fallback(smc, SMC_CLC_DECL_OPTUNSUPP); + if (rc) + goto out; + } else { + rc = -EINVAL; + goto out; + } + } else if ((sk->sk_state != SMC_ACTIVE) && + (sk->sk_state != 
SMC_APPCLOSEWAIT1) && + (sk->sk_state != SMC_INIT)) { + rc = -EPIPE; + goto out; + } + + if (smc->use_fallback) { + rc = smc->clcsock->ops->sendmsg(smc->clcsock, msg, len); + } else { + rc = smc_tx_sendmsg(smc, msg, len); + SMC_STAT_TX_PAYLOAD(smc, len, rc); + } +out: + release_sock(sk); + return rc; +} + +static int smc_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc = -ENOTCONN; + + smc = smc_sk(sk); + lock_sock(sk); + if (sk->sk_state == SMC_CLOSED && (sk->sk_shutdown & RCV_SHUTDOWN)) { + /* socket was connected before, no more data to read */ + rc = 0; + goto out; + } + if ((sk->sk_state == SMC_INIT) || + (sk->sk_state == SMC_LISTEN) || + (sk->sk_state == SMC_CLOSED)) + goto out; + + if (sk->sk_state == SMC_PEERFINCLOSEWAIT) { + rc = 0; + goto out; + } + + if (smc->use_fallback) { + rc = smc->clcsock->ops->recvmsg(smc->clcsock, msg, len, flags); + } else { + msg->msg_namelen = 0; + rc = smc_rx_recvmsg(smc, msg, NULL, len, flags); + SMC_STAT_RX_PAYLOAD(smc, rc, rc); + } + +out: + release_sock(sk); + return rc; +} + +static __poll_t smc_accept_poll(struct sock *parent) +{ + struct smc_sock *isk = smc_sk(parent); + __poll_t mask = 0; + + spin_lock(&isk->accept_q_lock); + if (!list_empty(&isk->accept_q)) + mask = EPOLLIN | EPOLLRDNORM; + spin_unlock(&isk->accept_q_lock); + + return mask; +} + +static __poll_t smc_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + __poll_t mask = 0; + + if (!sk) + return EPOLLNVAL; + + smc = smc_sk(sock->sk); + if (smc->use_fallback) { + /* delegate to CLC child sock */ + mask = smc->clcsock->ops->poll(file, smc->clcsock, wait); + sk->sk_err = smc->clcsock->sk->sk_err; + } else { + if (sk->sk_state != SMC_CLOSED) + sock_poll_wait(file, sock, wait); + if (sk->sk_err) + mask |= EPOLLERR; + if ((sk->sk_shutdown == SHUTDOWN_MASK) || + (sk->sk_state == SMC_CLOSED)) + mask |= EPOLLHUP; + if (sk->sk_state == SMC_LISTEN) { + /* woken up by sk_data_ready in smc_listen_work() */ + mask |= smc_accept_poll(sk); + } else if (smc->use_fallback) { /* as result of connect_work()*/ + mask |= smc->clcsock->ops->poll(file, smc->clcsock, + wait); + sk->sk_err = smc->clcsock->sk->sk_err; + } else { + if ((sk->sk_state != SMC_INIT && + atomic_read(&smc->conn.sndbuf_space)) || + sk->sk_shutdown & SEND_SHUTDOWN) { + mask |= EPOLLOUT | EPOLLWRNORM; + } else { + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + } + if (atomic_read(&smc->conn.bytes_to_rcv)) + mask |= EPOLLIN | EPOLLRDNORM; + if (sk->sk_shutdown & RCV_SHUTDOWN) + mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; + if (sk->sk_state == SMC_APPCLOSEWAIT1) + mask |= EPOLLIN; + if (smc->conn.urg_state == SMC_URG_VALID) + mask |= EPOLLPRI; + } + } + + return mask; +} + +static int smc_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + bool do_shutdown = true; + struct smc_sock *smc; + int rc = -EINVAL; + int old_state; + int rc1 = 0; + + smc = smc_sk(sk); + + if ((how < SHUT_RD) || (how > SHUT_RDWR)) + return rc; + + lock_sock(sk); + + if (sock->state == SS_CONNECTING) { + if (sk->sk_state == SMC_ACTIVE) + sock->state = SS_CONNECTED; + else if (sk->sk_state == SMC_PEERCLOSEWAIT1 || + sk->sk_state == SMC_PEERCLOSEWAIT2 || + sk->sk_state == SMC_APPCLOSEWAIT1 || + sk->sk_state == SMC_APPCLOSEWAIT2 || + sk->sk_state == SMC_APPFINCLOSEWAIT) + sock->state = SS_DISCONNECTING; + } + + rc = -ENOTCONN; + if 
((sk->sk_state != SMC_ACTIVE) && + (sk->sk_state != SMC_PEERCLOSEWAIT1) && + (sk->sk_state != SMC_PEERCLOSEWAIT2) && + (sk->sk_state != SMC_APPCLOSEWAIT1) && + (sk->sk_state != SMC_APPCLOSEWAIT2) && + (sk->sk_state != SMC_APPFINCLOSEWAIT)) + goto out; + if (smc->use_fallback) { + rc = kernel_sock_shutdown(smc->clcsock, how); + sk->sk_shutdown = smc->clcsock->sk->sk_shutdown; + if (sk->sk_shutdown == SHUTDOWN_MASK) { + sk->sk_state = SMC_CLOSED; + sk->sk_socket->state = SS_UNCONNECTED; + sock_put(sk); + } + goto out; + } + switch (how) { + case SHUT_RDWR: /* shutdown in both directions */ + old_state = sk->sk_state; + rc = smc_close_active(smc); + if (old_state == SMC_ACTIVE && + sk->sk_state == SMC_PEERCLOSEWAIT1) + do_shutdown = false; + break; + case SHUT_WR: + rc = smc_close_shutdown_write(smc); + break; + case SHUT_RD: + rc = 0; + /* nothing more to do because peer is not involved */ + break; + } + if (do_shutdown && smc->clcsock) + rc1 = kernel_sock_shutdown(smc->clcsock, how); + /* map sock_shutdown_cmd constants to sk_shutdown value range */ + sk->sk_shutdown |= how + 1; + + if (sk->sk_state == SMC_CLOSED) + sock->state = SS_UNCONNECTED; + else + sock->state = SS_DISCONNECTING; +out: + release_sock(sk); + return rc ? rc : rc1; +} + +static int __smc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct smc_sock *smc; + int val, len; + + smc = smc_sk(sock->sk); + + if (get_user(len, optlen)) + return -EFAULT; + + len = min_t(int, len, sizeof(int)); + + if (len < 0) + return -EINVAL; + + switch (optname) { + case SMC_LIMIT_HS: + val = smc->limit_smc_hs; + break; + default: + return -EOPNOTSUPP; + } + + if (put_user(len, optlen)) + return -EFAULT; + if (copy_to_user(optval, &val, len)) + return -EFAULT; + + return 0; +} + +static int __smc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int val, rc; + + smc = smc_sk(sk); + + lock_sock(sk); + switch (optname) { + case SMC_LIMIT_HS: + if (optlen < sizeof(int)) { + rc = -EINVAL; + break; + } + if (copy_from_sockptr(&val, optval, sizeof(int))) { + rc = -EFAULT; + break; + } + + smc->limit_smc_hs = !!val; + rc = 0; + break; + default: + rc = -EOPNOTSUPP; + break; + } + release_sock(sk); + + return rc; +} + +static int smc_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int val, rc; + + if (level == SOL_TCP && optname == TCP_ULP) + return -EOPNOTSUPP; + else if (level == SOL_SMC) + return __smc_setsockopt(sock, level, optname, optval, optlen); + + smc = smc_sk(sk); + + /* generic setsockopts reaching us here always apply to the + * CLC socket + */ + mutex_lock(&smc->clcsock_release_lock); + if (!smc->clcsock) { + mutex_unlock(&smc->clcsock_release_lock); + return -EBADF; + } + if (unlikely(!smc->clcsock->ops->setsockopt)) + rc = -EOPNOTSUPP; + else + rc = smc->clcsock->ops->setsockopt(smc->clcsock, level, optname, + optval, optlen); + if (smc->clcsock->sk->sk_err) { + sk->sk_err = smc->clcsock->sk->sk_err; + sk_error_report(sk); + } + mutex_unlock(&smc->clcsock_release_lock); + + if (optlen < sizeof(int)) + return -EINVAL; + if (copy_from_sockptr(&val, optval, sizeof(int))) + return -EFAULT; + + lock_sock(sk); + if (rc || smc->use_fallback) + goto out; + switch (optname) { + case TCP_FASTOPEN: + case TCP_FASTOPEN_CONNECT: + case TCP_FASTOPEN_KEY: + case 
TCP_FASTOPEN_NO_COOKIE: + /* option not supported by SMC */ + if (sk->sk_state == SMC_INIT && !smc->connect_nonblock) { + rc = smc_switch_to_fallback(smc, SMC_CLC_DECL_OPTUNSUPP); + } else { + rc = -EINVAL; + } + break; + case TCP_NODELAY: + if (sk->sk_state != SMC_INIT && + sk->sk_state != SMC_LISTEN && + sk->sk_state != SMC_CLOSED) { + if (val) { + SMC_STAT_INC(smc, ndly_cnt); + smc_tx_pending(&smc->conn); + cancel_delayed_work(&smc->conn.tx_work); + } + } + break; + case TCP_CORK: + if (sk->sk_state != SMC_INIT && + sk->sk_state != SMC_LISTEN && + sk->sk_state != SMC_CLOSED) { + if (!val) { + SMC_STAT_INC(smc, cork_cnt); + smc_tx_pending(&smc->conn); + cancel_delayed_work(&smc->conn.tx_work); + } + } + break; + case TCP_DEFER_ACCEPT: + smc->sockopt_defer_accept = val; + break; + default: + break; + } +out: + release_sock(sk); + + return rc; +} + +static int smc_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct smc_sock *smc; + int rc; + + if (level == SOL_SMC) + return __smc_getsockopt(sock, level, optname, optval, optlen); + + smc = smc_sk(sock->sk); + mutex_lock(&smc->clcsock_release_lock); + if (!smc->clcsock) { + mutex_unlock(&smc->clcsock_release_lock); + return -EBADF; + } + /* socket options apply to the CLC socket */ + if (unlikely(!smc->clcsock->ops->getsockopt)) { + mutex_unlock(&smc->clcsock_release_lock); + return -EOPNOTSUPP; + } + rc = smc->clcsock->ops->getsockopt(smc->clcsock, level, optname, + optval, optlen); + mutex_unlock(&smc->clcsock_release_lock); + return rc; +} + +static int smc_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + union smc_host_cursor cons, urg; + struct smc_connection *conn; + struct smc_sock *smc; + int answ; + + smc = smc_sk(sock->sk); + conn = &smc->conn; + lock_sock(&smc->sk); + if (smc->use_fallback) { + if (!smc->clcsock) { + release_sock(&smc->sk); + return -EBADF; + } + answ = smc->clcsock->ops->ioctl(smc->clcsock, cmd, arg); + release_sock(&smc->sk); + return answ; + } + switch (cmd) { + case SIOCINQ: /* same as FIONREAD */ + if (smc->sk.sk_state == SMC_LISTEN) { + release_sock(&smc->sk); + return -EINVAL; + } + if (smc->sk.sk_state == SMC_INIT || + smc->sk.sk_state == SMC_CLOSED) + answ = 0; + else + answ = atomic_read(&smc->conn.bytes_to_rcv); + break; + case SIOCOUTQ: + /* output queue size (not send + not acked) */ + if (smc->sk.sk_state == SMC_LISTEN) { + release_sock(&smc->sk); + return -EINVAL; + } + if (smc->sk.sk_state == SMC_INIT || + smc->sk.sk_state == SMC_CLOSED) + answ = 0; + else + answ = smc->conn.sndbuf_desc->len - + atomic_read(&smc->conn.sndbuf_space); + break; + case SIOCOUTQNSD: + /* output queue size (not send only) */ + if (smc->sk.sk_state == SMC_LISTEN) { + release_sock(&smc->sk); + return -EINVAL; + } + if (smc->sk.sk_state == SMC_INIT || + smc->sk.sk_state == SMC_CLOSED) + answ = 0; + else + answ = smc_tx_prepared_sends(&smc->conn); + break; + case SIOCATMARK: + if (smc->sk.sk_state == SMC_LISTEN) { + release_sock(&smc->sk); + return -EINVAL; + } + if (smc->sk.sk_state == SMC_INIT || + smc->sk.sk_state == SMC_CLOSED) { + answ = 0; + } else { + smc_curs_copy(&cons, &conn->local_tx_ctrl.cons, conn); + smc_curs_copy(&urg, &conn->urg_curs, conn); + answ = smc_curs_diff(conn->rmb_desc->len, + &cons, &urg) == 1; + } + break; + default: + release_sock(&smc->sk); + return -ENOIOCTLCMD; + } + release_sock(&smc->sk); + + return put_user(answ, (int __user *)arg); +} + +static ssize_t smc_sendpage(struct socket *sock, struct page *page, + int 
offset, size_t size, int flags) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc = -EPIPE; + + smc = smc_sk(sk); + lock_sock(sk); + if (sk->sk_state != SMC_ACTIVE) { + release_sock(sk); + goto out; + } + release_sock(sk); + if (smc->use_fallback) { + rc = kernel_sendpage(smc->clcsock, page, offset, + size, flags); + } else { + lock_sock(sk); + rc = smc_tx_sendpage(smc, page, offset, size, flags); + release_sock(sk); + SMC_STAT_INC(smc, sendpage_cnt); + } + +out: + return rc; +} + +/* Map the affected portions of the rmbe into an spd, note the number of bytes + * to splice in conn->splice_pending, and press 'go'. Delays consumer cursor + * updates till whenever a respective page has been fully processed. + * Note that subsequent recv() calls have to wait till all splice() processing + * completed. + */ +static ssize_t smc_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags) +{ + struct sock *sk = sock->sk; + struct smc_sock *smc; + int rc = -ENOTCONN; + + smc = smc_sk(sk); + lock_sock(sk); + if (sk->sk_state == SMC_CLOSED && (sk->sk_shutdown & RCV_SHUTDOWN)) { + /* socket was connected before, no more data to read */ + rc = 0; + goto out; + } + if (sk->sk_state == SMC_INIT || + sk->sk_state == SMC_LISTEN || + sk->sk_state == SMC_CLOSED) + goto out; + + if (sk->sk_state == SMC_PEERFINCLOSEWAIT) { + rc = 0; + goto out; + } + + if (smc->use_fallback) { + rc = smc->clcsock->ops->splice_read(smc->clcsock, ppos, + pipe, len, flags); + } else { + if (*ppos) { + rc = -ESPIPE; + goto out; + } + if (flags & SPLICE_F_NONBLOCK) + flags = MSG_DONTWAIT; + else + flags = 0; + SMC_STAT_INC(smc, splice_cnt); + rc = smc_rx_recvmsg(smc, NULL, pipe, len, flags); + } +out: + release_sock(sk); + + return rc; +} + +/* must look like tcp */ +static const struct proto_ops smc_sock_ops = { + .family = PF_SMC, + .owner = THIS_MODULE, + .release = smc_release, + .bind = smc_bind, + .connect = smc_connect, + .socketpair = sock_no_socketpair, + .accept = smc_accept, + .getname = smc_getname, + .poll = smc_poll, + .ioctl = smc_ioctl, + .listen = smc_listen, + .shutdown = smc_shutdown, + .setsockopt = smc_setsockopt, + .getsockopt = smc_getsockopt, + .sendmsg = smc_sendmsg, + .recvmsg = smc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = smc_sendpage, + .splice_read = smc_splice_read, +}; + +static int __smc_create(struct net *net, struct socket *sock, int protocol, + int kern, struct socket *clcsock) +{ + int family = (protocol == SMCPROTO_SMC6) ? 
PF_INET6 : PF_INET; + struct smc_sock *smc; + struct sock *sk; + int rc; + + rc = -ESOCKTNOSUPPORT; + if (sock->type != SOCK_STREAM) + goto out; + + rc = -EPROTONOSUPPORT; + if (protocol != SMCPROTO_SMC && protocol != SMCPROTO_SMC6) + goto out; + + rc = -ENOBUFS; + sock->ops = &smc_sock_ops; + sock->state = SS_UNCONNECTED; + sk = smc_sock_alloc(net, sock, protocol); + if (!sk) + goto out; + + /* create internal TCP socket for CLC handshake and fallback */ + smc = smc_sk(sk); + smc->use_fallback = false; /* assume rdma capability first */ + smc->fallback_rsn = 0; + + /* default behavior from limit_smc_hs in every net namespace */ + smc->limit_smc_hs = net->smc.limit_smc_hs; + + rc = 0; + if (!clcsock) { + rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, + &smc->clcsock); + if (rc) { + sk_common_release(sk); + goto out; + } + } else { + smc->clcsock = clcsock; + } + +out: + return rc; +} + +static int smc_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + return __smc_create(net, sock, protocol, kern, NULL); +} + +static const struct net_proto_family smc_sock_family_ops = { + .family = PF_SMC, + .owner = THIS_MODULE, + .create = smc_create, +}; + +static int smc_ulp_init(struct sock *sk) +{ + struct socket *tcp = sk->sk_socket; + struct net *net = sock_net(sk); + struct socket *smcsock; + int protocol, ret; + + /* only TCP can be replaced */ + if (tcp->type != SOCK_STREAM || sk->sk_protocol != IPPROTO_TCP || + (sk->sk_family != AF_INET && sk->sk_family != AF_INET6)) + return -ESOCKTNOSUPPORT; + /* don't handle wq now */ + if (tcp->state != SS_UNCONNECTED || !tcp->file || tcp->wq.fasync_list) + return -ENOTCONN; + + if (sk->sk_family == AF_INET) + protocol = SMCPROTO_SMC; + else + protocol = SMCPROTO_SMC6; + + smcsock = sock_alloc(); + if (!smcsock) + return -ENFILE; + + smcsock->type = SOCK_STREAM; + __module_get(THIS_MODULE); /* tried in __tcp_ulp_find_autoload */ + ret = __smc_create(net, smcsock, protocol, 1, tcp); + if (ret) { + sock_release(smcsock); /* module_put() which ops won't be NULL */ + return ret; + } + + /* replace tcp socket to smc */ + smcsock->file = tcp->file; + smcsock->file->private_data = smcsock; + smcsock->file->f_inode = SOCK_INODE(smcsock); /* replace inode when sock_close */ + smcsock->file->f_path.dentry->d_inode = SOCK_INODE(smcsock); /* dput() in __fput */ + tcp->file = NULL; + + return ret; +} + +static void smc_ulp_clone(const struct request_sock *req, struct sock *newsk, + const gfp_t priority) +{ + struct inet_connection_sock *icsk = inet_csk(newsk); + + /* don't inherit ulp ops to child when listen */ + icsk->icsk_ulp_ops = NULL; +} + +static struct tcp_ulp_ops smc_ulp_ops __read_mostly = { + .name = "smc", + .owner = THIS_MODULE, + .init = smc_ulp_init, + .clone = smc_ulp_clone, +}; + +unsigned int smc_net_id; + +static __net_init int smc_net_init(struct net *net) +{ + int rc; + + rc = smc_sysctl_net_init(net); + if (rc) + return rc; + return smc_pnet_net_init(net); +} + +static void __net_exit smc_net_exit(struct net *net) +{ + smc_sysctl_net_exit(net); + smc_pnet_net_exit(net); +} + +static __net_init int smc_net_stat_init(struct net *net) +{ + return smc_stats_init(net); +} + +static void __net_exit smc_net_stat_exit(struct net *net) +{ + smc_stats_exit(net); +} + +static struct pernet_operations smc_net_ops = { + .init = smc_net_init, + .exit = smc_net_exit, + .id = &smc_net_id, + .size = sizeof(struct smc_net), +}; + +static struct pernet_operations smc_net_stat_ops = { + .init = smc_net_stat_init, + .exit = 
smc_net_stat_exit, +}; + +static int __init smc_init(void) +{ + int rc; + + rc = register_pernet_subsys(&smc_net_ops); + if (rc) + return rc; + + rc = register_pernet_subsys(&smc_net_stat_ops); + if (rc) + goto out_pernet_subsys; + + smc_ism_init(); + smc_clc_init(); + + rc = smc_nl_init(); + if (rc) + goto out_pernet_subsys_stat; + + rc = smc_pnet_init(); + if (rc) + goto out_nl; + + rc = -ENOMEM; + + smc_tcp_ls_wq = alloc_workqueue("smc_tcp_ls_wq", 0, 0); + if (!smc_tcp_ls_wq) + goto out_pnet; + + smc_hs_wq = alloc_workqueue("smc_hs_wq", 0, 0); + if (!smc_hs_wq) + goto out_alloc_tcp_ls_wq; + + smc_close_wq = alloc_workqueue("smc_close_wq", 0, 0); + if (!smc_close_wq) + goto out_alloc_hs_wq; + + rc = smc_core_init(); + if (rc) { + pr_err("%s: smc_core_init fails with %d\n", __func__, rc); + goto out_alloc_wqs; + } + + rc = smc_llc_init(); + if (rc) { + pr_err("%s: smc_llc_init fails with %d\n", __func__, rc); + goto out_core; + } + + rc = smc_cdc_init(); + if (rc) { + pr_err("%s: smc_cdc_init fails with %d\n", __func__, rc); + goto out_core; + } + + rc = proto_register(&smc_proto, 1); + if (rc) { + pr_err("%s: proto_register(v4) fails with %d\n", __func__, rc); + goto out_core; + } + + rc = proto_register(&smc_proto6, 1); + if (rc) { + pr_err("%s: proto_register(v6) fails with %d\n", __func__, rc); + goto out_proto; + } + + rc = sock_register(&smc_sock_family_ops); + if (rc) { + pr_err("%s: sock_register fails with %d\n", __func__, rc); + goto out_proto6; + } + INIT_HLIST_HEAD(&smc_v4_hashinfo.ht); + INIT_HLIST_HEAD(&smc_v6_hashinfo.ht); + + rc = smc_ib_register_client(); + if (rc) { + pr_err("%s: ib_register fails with %d\n", __func__, rc); + goto out_sock; + } + + rc = tcp_register_ulp(&smc_ulp_ops); + if (rc) { + pr_err("%s: tcp_ulp_register fails with %d\n", __func__, rc); + goto out_ib; + } + + static_branch_enable(&tcp_have_smc); + return 0; + +out_ib: + smc_ib_unregister_client(); +out_sock: + sock_unregister(PF_SMC); +out_proto6: + proto_unregister(&smc_proto6); +out_proto: + proto_unregister(&smc_proto); +out_core: + smc_core_exit(); +out_alloc_wqs: + destroy_workqueue(smc_close_wq); +out_alloc_hs_wq: + destroy_workqueue(smc_hs_wq); +out_alloc_tcp_ls_wq: + destroy_workqueue(smc_tcp_ls_wq); +out_pnet: + smc_pnet_exit(); +out_nl: + smc_nl_exit(); +out_pernet_subsys_stat: + unregister_pernet_subsys(&smc_net_stat_ops); +out_pernet_subsys: + unregister_pernet_subsys(&smc_net_ops); + + return rc; +} + +static void __exit smc_exit(void) +{ + static_branch_disable(&tcp_have_smc); + tcp_unregister_ulp(&smc_ulp_ops); + sock_unregister(PF_SMC); + smc_core_exit(); + smc_ib_unregister_client(); + destroy_workqueue(smc_close_wq); + destroy_workqueue(smc_tcp_ls_wq); + destroy_workqueue(smc_hs_wq); + proto_unregister(&smc_proto6); + proto_unregister(&smc_proto); + smc_pnet_exit(); + smc_nl_exit(); + smc_clc_exit(); + unregister_pernet_subsys(&smc_net_stat_ops); + unregister_pernet_subsys(&smc_net_ops); + rcu_barrier(); +} + +module_init(smc_init); +module_exit(smc_exit); + +MODULE_AUTHOR("Ursula Braun "); +MODULE_DESCRIPTION("smc socket address family"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_SMC); +MODULE_ALIAS_TCP_ULP("smc"); +MODULE_ALIAS_GENL_FAMILY(SMC_GENL_FAMILY_NAME); diff --git a/net/smc/smc.h b/net/smc/smc.h new file mode 100644 index 000000000..bcb57e60b --- /dev/null +++ b/net/smc/smc.h @@ -0,0 +1,385 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Definitions for the SMC module (socket related) + * + * 
Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ +#ifndef __SMC_H +#define __SMC_H + +#include +#include +#include /* __aligned */ +#include +#include + +#include "smc_ib.h" + +#define SMC_V1 1 /* SMC version V1 */ +#define SMC_V2 2 /* SMC version V2 */ +#define SMC_RELEASE 0 + +#define SMCPROTO_SMC 0 /* SMC protocol, IPv4 */ +#define SMCPROTO_SMC6 1 /* SMC protocol, IPv6 */ + +#define SMC_MAX_ISM_DEVS 8 /* max # of proposed non-native ISM + * devices + */ +#define SMC_AUTOCORKING_DEFAULT_SIZE 0x10000 /* 64K by default */ + +extern struct proto smc_proto; +extern struct proto smc_proto6; + +#ifdef ATOMIC64_INIT +#define KERNEL_HAS_ATOMIC64 +#endif + +enum smc_state { /* possible states of an SMC socket */ + SMC_ACTIVE = 1, + SMC_INIT = 2, + SMC_CLOSED = 7, + SMC_LISTEN = 10, + /* normal close */ + SMC_PEERCLOSEWAIT1 = 20, + SMC_PEERCLOSEWAIT2 = 21, + SMC_APPFINCLOSEWAIT = 24, + SMC_APPCLOSEWAIT1 = 22, + SMC_APPCLOSEWAIT2 = 23, + SMC_PEERFINCLOSEWAIT = 25, + /* abnormal close */ + SMC_PEERABORTWAIT = 26, + SMC_PROCESSABORT = 27, +}; + +struct smc_link_group; + +struct smc_wr_rx_hdr { /* common prefix part of LLC and CDC to demultiplex */ + union { + u8 type; +#if defined(__BIG_ENDIAN_BITFIELD) + struct { + u8 llc_version:4, + llc_type:4; + }; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + struct { + u8 llc_type:4, + llc_version:4; + }; +#endif + }; +} __aligned(1); + +struct smc_cdc_conn_state_flags { +#if defined(__BIG_ENDIAN_BITFIELD) + u8 peer_done_writing : 1; /* Sending done indicator */ + u8 peer_conn_closed : 1; /* Peer connection closed indicator */ + u8 peer_conn_abort : 1; /* Abnormal close indicator */ + u8 reserved : 5; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 5; + u8 peer_conn_abort : 1; + u8 peer_conn_closed : 1; + u8 peer_done_writing : 1; +#endif +}; + +struct smc_cdc_producer_flags { +#if defined(__BIG_ENDIAN_BITFIELD) + u8 write_blocked : 1; /* Writing Blocked, no rx buf space */ + u8 urg_data_pending : 1; /* Urgent Data Pending */ + u8 urg_data_present : 1; /* Urgent Data Present */ + u8 cons_curs_upd_req : 1; /* cursor update requested */ + u8 failover_validation : 1;/* message replay due to failover */ + u8 reserved : 3; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 3; + u8 failover_validation : 1; + u8 cons_curs_upd_req : 1; + u8 urg_data_present : 1; + u8 urg_data_pending : 1; + u8 write_blocked : 1; +#endif +}; + +/* in host byte order */ +union smc_host_cursor { /* SMC cursor - an offset in an RMBE */ + struct { + u16 reserved; + u16 wrap; /* window wrap sequence number */ + u32 count; /* cursor (= offset) part */ + }; +#ifdef KERNEL_HAS_ATOMIC64 + atomic64_t acurs; /* for atomic processing */ +#else + u64 acurs; /* for atomic processing */ +#endif +} __aligned(8); + +/* in host byte order, except for flag bitfields in network byte order */ +struct smc_host_cdc_msg { /* Connection Data Control message */ + struct smc_wr_rx_hdr common; /* .type = 0xFE */ + u8 len; /* length = 44 */ + u16 seqno; /* connection seq # */ + u32 token; /* alert_token */ + union smc_host_cursor prod; /* producer cursor */ + union smc_host_cursor cons; /* consumer cursor, + * piggy backed "ack" + */ + struct smc_cdc_producer_flags prod_flags; /* conn. tx/rx status */ + struct smc_cdc_conn_state_flags conn_state_flags; /* peer conn. 
status*/ + u8 reserved[18]; +} __aligned(8); + +enum smc_urg_state { + SMC_URG_VALID = 1, /* data present */ + SMC_URG_NOTYET = 2, /* data pending */ + SMC_URG_READ = 3, /* data was already read */ +}; + +struct smc_mark_woken { + bool woken; + void *key; + wait_queue_entry_t wait_entry; +}; + +struct smc_connection { + struct rb_node alert_node; + struct smc_link_group *lgr; /* link group of connection */ + struct smc_link *lnk; /* assigned SMC-R link */ + u32 alert_token_local; /* unique conn. id */ + u8 peer_rmbe_idx; /* from tcp handshake */ + int peer_rmbe_size; /* size of peer rx buffer */ + atomic_t peer_rmbe_space;/* remaining free bytes in peer + * rmbe + */ + int rtoken_idx; /* idx to peer RMB rkey/addr */ + + struct smc_buf_desc *sndbuf_desc; /* send buffer descriptor */ + struct smc_buf_desc *rmb_desc; /* RMBE descriptor */ + int rmbe_size_comp; /* compressed notation */ + int rmbe_update_limit; + /* lower limit for consumer + * cursor update + */ + + struct smc_host_cdc_msg local_tx_ctrl; /* host byte order staging + * buffer for CDC msg send + * .prod cf. TCP snd_nxt + * .cons cf. TCP sends ack + */ + union smc_host_cursor local_tx_ctrl_fin; + /* prod crsr - confirmed by peer + */ + union smc_host_cursor tx_curs_prep; /* tx - prepared data + * snd_max..wmem_alloc + */ + union smc_host_cursor tx_curs_sent; /* tx - sent data + * snd_nxt ? + */ + union smc_host_cursor tx_curs_fin; /* tx - confirmed by peer + * snd-wnd-begin ? + */ + atomic_t sndbuf_space; /* remaining space in sndbuf */ + u16 tx_cdc_seq; /* sequence # for CDC send */ + u16 tx_cdc_seq_fin; /* sequence # - tx completed */ + spinlock_t send_lock; /* protect wr_sends */ + atomic_t cdc_pend_tx_wr; /* number of pending tx CDC wqe + * - inc when post wqe, + * - dec on polled tx cqe + */ + wait_queue_head_t cdc_pend_tx_wq; /* wakeup on no cdc_pend_tx_wr*/ + atomic_t tx_pushing; /* nr_threads trying tx push */ + struct delayed_work tx_work; /* retry of smc_cdc_msg_send */ + u32 tx_off; /* base offset in peer rmb */ + + struct smc_host_cdc_msg local_rx_ctrl; /* filled during event_handl. + * .prod cf. TCP rcv_nxt + * .cons cf. TCP snd_una + */ + union smc_host_cursor rx_curs_confirmed; /* confirmed to peer + * source of snd_una ? + */ + union smc_host_cursor urg_curs; /* points at urgent byte */ + enum smc_urg_state urg_state; + bool urg_tx_pend; /* urgent data staged */ + bool urg_rx_skip_pend; + /* indicate urgent oob data + * read, but previous regular + * data still pending + */ + char urg_rx_byte; /* urgent byte */ + bool tx_in_release_sock; + /* flush pending tx data in + * sock release_cb() + */ + atomic_t bytes_to_rcv; /* arrived data, + * not yet received + */ + atomic_t splice_pending; /* number of spliced bytes + * pending processing + */ +#ifndef KERNEL_HAS_ATOMIC64 + spinlock_t acurs_lock; /* protect cursors */ +#endif + struct work_struct close_work; /* peer sent some closing */ + struct work_struct abort_work; /* abort the connection */ + struct tasklet_struct rx_tsklet; /* Receiver tasklet for SMC-D */ + u8 rx_off; /* receive offset: + * 0 for SMC-R, 32 for SMC-D + */ + u64 peer_token; /* SMC-D token of peer */ + u8 killed : 1; /* abnormal termination */ + u8 freed : 1; /* normal termiation */ + u8 out_of_sync : 1; /* out of sync with peer */ +}; + +struct smc_sock { /* smc sock container */ + struct sock sk; + struct socket *clcsock; /* internal tcp socket */ + void (*clcsk_state_change)(struct sock *sk); + /* original stat_change fct. 
*/ + void (*clcsk_data_ready)(struct sock *sk); + /* original data_ready fct. */ + void (*clcsk_write_space)(struct sock *sk); + /* original write_space fct. */ + void (*clcsk_error_report)(struct sock *sk); + /* original error_report fct. */ + struct smc_connection conn; /* smc connection */ + struct smc_sock *listen_smc; /* listen parent */ + struct work_struct connect_work; /* handle non-blocking connect*/ + struct work_struct tcp_listen_work;/* handle tcp socket accepts */ + struct work_struct smc_listen_work;/* prepare new accept socket */ + struct list_head accept_q; /* sockets to be accepted */ + spinlock_t accept_q_lock; /* protects accept_q */ + bool limit_smc_hs; /* put constraint on handshake */ + bool use_fallback; /* fallback to tcp */ + int fallback_rsn; /* reason for fallback */ + u32 peer_diagnosis; /* decline reason from peer */ + atomic_t queued_smc_hs; /* queued smc handshakes */ + struct inet_connection_sock_af_ops af_ops; + const struct inet_connection_sock_af_ops *ori_af_ops; + /* original af ops */ + int sockopt_defer_accept; + /* sockopt TCP_DEFER_ACCEPT + * value + */ + u8 wait_close_tx_prepared : 1; + /* shutdown wr or close + * started, waiting for unsent + * data to be sent + */ + u8 connect_nonblock : 1; + /* non-blocking connect in + * flight + */ + struct mutex clcsock_release_lock; + /* protects clcsock of a listen + * socket + * */ +}; + +static inline struct smc_sock *smc_sk(const struct sock *sk) +{ + return (struct smc_sock *)sk; +} + +static inline void smc_init_saved_callbacks(struct smc_sock *smc) +{ + smc->clcsk_state_change = NULL; + smc->clcsk_data_ready = NULL; + smc->clcsk_write_space = NULL; + smc->clcsk_error_report = NULL; +} + +static inline struct smc_sock *smc_clcsock_user_data(const struct sock *clcsk) +{ + return (struct smc_sock *) + ((uintptr_t)clcsk->sk_user_data & ~SK_USER_DATA_NOCOPY); +} + +/* save target_cb in saved_cb, and replace target_cb with new_cb */ +static inline void smc_clcsock_replace_cb(void (**target_cb)(struct sock *), + void (*new_cb)(struct sock *), + void (**saved_cb)(struct sock *)) +{ + /* only save once */ + if (!*saved_cb) + *saved_cb = *target_cb; + *target_cb = new_cb; +} + +/* restore target_cb to saved_cb, and reset saved_cb to NULL */ +static inline void smc_clcsock_restore_cb(void (**target_cb)(struct sock *), + void (**saved_cb)(struct sock *)) +{ + if (!*saved_cb) + return; + *target_cb = *saved_cb; + *saved_cb = NULL; +} + +extern struct workqueue_struct *smc_hs_wq; /* wq for handshake work */ +extern struct workqueue_struct *smc_close_wq; /* wq for close work */ + +#define SMC_SYSTEMID_LEN 8 + +extern u8 local_systemid[SMC_SYSTEMID_LEN]; /* unique system identifier */ + +#define ntohll(x) be64_to_cpu(x) +#define htonll(x) cpu_to_be64(x) + +/* convert an u32 value into network byte order, store it into a 3 byte field */ +static inline void hton24(u8 *net, u32 host) +{ + __be32 t; + + t = cpu_to_be32(host); + memcpy(net, ((u8 *)&t) + 1, 3); +} + +/* convert a received 3 byte field into host byte order*/ +static inline u32 ntoh24(u8 *net) +{ + __be32 t = 0; + + memcpy(((u8 *)&t) + 1, net, 3); + return be32_to_cpu(t); +} + +#ifdef CONFIG_XFRM +static inline bool using_ipsec(struct smc_sock *smc) +{ + return (smc->clcsock->sk->sk_policy[0] || + smc->clcsock->sk->sk_policy[1]) ? 
true : false; +} +#else +static inline bool using_ipsec(struct smc_sock *smc) +{ + return false; +} +#endif + +struct smc_gidlist; + +struct sock *smc_accept_dequeue(struct sock *parent, struct socket *new_sock); +void smc_close_non_accepted(struct sock *sk); +void smc_fill_gid_list(struct smc_link_group *lgr, + struct smc_gidlist *gidlist, + struct smc_ib_device *known_dev, u8 *known_gid); + +/* smc handshake limitation interface for netlink */ +int smc_nl_dump_hs_limitation(struct sk_buff *skb, struct netlink_callback *cb); +int smc_nl_enable_hs_limitation(struct sk_buff *skb, struct genl_info *info); +int smc_nl_disable_hs_limitation(struct sk_buff *skb, struct genl_info *info); + +static inline void smc_sock_set_flag(struct sock *sk, enum sock_flags flag) +{ + set_bit(flag, &sk->sk_flags); +} + +#endif /* __SMC_H */ diff --git a/net/smc/smc_cdc.c b/net/smc/smc_cdc.c new file mode 100644 index 000000000..3c06625ce --- /dev/null +++ b/net/smc/smc_cdc.c @@ -0,0 +1,493 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Connection Data Control (CDC) + * handles flow control + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#include + +#include "smc.h" +#include "smc_wr.h" +#include "smc_cdc.h" +#include "smc_tx.h" +#include "smc_rx.h" +#include "smc_close.h" + +/********************************** send *************************************/ + +/* handler for send/transmission completion of a CDC msg */ +static void smc_cdc_tx_handler(struct smc_wr_tx_pend_priv *pnd_snd, + struct smc_link *link, + enum ib_wc_status wc_status) +{ + struct smc_cdc_tx_pend *cdcpend = (struct smc_cdc_tx_pend *)pnd_snd; + struct smc_connection *conn = cdcpend->conn; + struct smc_buf_desc *sndbuf_desc; + struct smc_sock *smc; + int diff; + + sndbuf_desc = conn->sndbuf_desc; + smc = container_of(conn, struct smc_sock, conn); + bh_lock_sock(&smc->sk); + if (!wc_status && sndbuf_desc) { + diff = smc_curs_diff(sndbuf_desc->len, + &cdcpend->conn->tx_curs_fin, + &cdcpend->cursor); + /* sndbuf_space is decreased in smc_sendmsg */ + smp_mb__before_atomic(); + atomic_add(diff, &cdcpend->conn->sndbuf_space); + /* guarantee 0 <= sndbuf_space <= sndbuf_desc->len */ + smp_mb__after_atomic(); + smc_curs_copy(&conn->tx_curs_fin, &cdcpend->cursor, conn); + smc_curs_copy(&conn->local_tx_ctrl_fin, &cdcpend->p_cursor, + conn); + conn->tx_cdc_seq_fin = cdcpend->ctrl_seq; + } + + if (atomic_dec_and_test(&conn->cdc_pend_tx_wr)) { + /* If user owns the sock_lock, mark the connection need sending. 
+ * User context will later try to send when it release sock_lock + * in smc_release_cb() + */ + if (sock_owned_by_user(&smc->sk)) + conn->tx_in_release_sock = true; + else + smc_tx_pending(conn); + + if (unlikely(wq_has_sleeper(&conn->cdc_pend_tx_wq))) + wake_up(&conn->cdc_pend_tx_wq); + } + WARN_ON(atomic_read(&conn->cdc_pend_tx_wr) < 0); + + smc_tx_sndbuf_nonfull(smc); + bh_unlock_sock(&smc->sk); +} + +int smc_cdc_get_free_slot(struct smc_connection *conn, + struct smc_link *link, + struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, + struct smc_cdc_tx_pend **pend) +{ + int rc; + + rc = smc_wr_tx_get_free_slot(link, smc_cdc_tx_handler, wr_buf, + wr_rdma_buf, + (struct smc_wr_tx_pend_priv **)pend); + if (conn->killed) { + /* abnormal termination */ + if (!rc) + smc_wr_tx_put_slot(link, + (struct smc_wr_tx_pend_priv *)(*pend)); + rc = -EPIPE; + } + return rc; +} + +static inline void smc_cdc_add_pending_send(struct smc_connection *conn, + struct smc_cdc_tx_pend *pend) +{ + BUILD_BUG_ON_MSG( + sizeof(struct smc_cdc_msg) > SMC_WR_BUF_SIZE, + "must increase SMC_WR_BUF_SIZE to at least sizeof(struct smc_cdc_msg)"); + BUILD_BUG_ON_MSG( + offsetofend(struct smc_cdc_msg, reserved) > SMC_WR_TX_SIZE, + "must adapt SMC_WR_TX_SIZE to sizeof(struct smc_cdc_msg); if not all smc_wr upper layer protocols use the same message size any more, must start to set link->wr_tx_sges[i].length on each individual smc_wr_tx_send()"); + BUILD_BUG_ON_MSG( + sizeof(struct smc_cdc_tx_pend) > SMC_WR_TX_PEND_PRIV_SIZE, + "must increase SMC_WR_TX_PEND_PRIV_SIZE to at least sizeof(struct smc_cdc_tx_pend)"); + pend->conn = conn; + pend->cursor = conn->tx_curs_sent; + pend->p_cursor = conn->local_tx_ctrl.prod; + pend->ctrl_seq = conn->tx_cdc_seq; +} + +int smc_cdc_msg_send(struct smc_connection *conn, + struct smc_wr_buf *wr_buf, + struct smc_cdc_tx_pend *pend) +{ + struct smc_link *link = conn->lnk; + union smc_host_cursor cfed; + int rc; + + smc_cdc_add_pending_send(conn, pend); + + conn->tx_cdc_seq++; + conn->local_tx_ctrl.seqno = conn->tx_cdc_seq; + smc_host_msg_to_cdc((struct smc_cdc_msg *)wr_buf, conn, &cfed); + + atomic_inc(&conn->cdc_pend_tx_wr); + smp_mb__after_atomic(); /* Make sure cdc_pend_tx_wr added before post */ + + rc = smc_wr_tx_send(link, (struct smc_wr_tx_pend_priv *)pend); + if (!rc) { + smc_curs_copy(&conn->rx_curs_confirmed, &cfed, conn); + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req = 0; + } else { + conn->tx_cdc_seq--; + conn->local_tx_ctrl.seqno = conn->tx_cdc_seq; + atomic_dec(&conn->cdc_pend_tx_wr); + } + + return rc; +} + +/* send a validation msg indicating the move of a conn to an other QP link */ +int smcr_cdc_msg_send_validation(struct smc_connection *conn, + struct smc_cdc_tx_pend *pend, + struct smc_wr_buf *wr_buf) +{ + struct smc_host_cdc_msg *local = &conn->local_tx_ctrl; + struct smc_link *link = conn->lnk; + struct smc_cdc_msg *peer; + int rc; + + peer = (struct smc_cdc_msg *)wr_buf; + peer->common.type = local->common.type; + peer->len = local->len; + peer->seqno = htons(conn->tx_cdc_seq_fin); /* seqno last compl. 
tx */ + peer->token = htonl(local->token); + peer->prod_flags.failover_validation = 1; + + /* We need to set pend->conn here to make sure smc_cdc_tx_handler() + * can handle properly + */ + smc_cdc_add_pending_send(conn, pend); + + atomic_inc(&conn->cdc_pend_tx_wr); + smp_mb__after_atomic(); /* Make sure cdc_pend_tx_wr added before post */ + + rc = smc_wr_tx_send(link, (struct smc_wr_tx_pend_priv *)pend); + if (unlikely(rc)) + atomic_dec(&conn->cdc_pend_tx_wr); + + return rc; +} + +static int smcr_cdc_get_slot_and_msg_send(struct smc_connection *conn) +{ + struct smc_cdc_tx_pend *pend; + struct smc_wr_buf *wr_buf; + struct smc_link *link; + bool again = false; + int rc; + +again: + link = conn->lnk; + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_cdc_get_free_slot(conn, link, &wr_buf, NULL, &pend); + if (rc) + goto put_out; + + spin_lock_bh(&conn->send_lock); + if (link != conn->lnk) { + /* link of connection changed, try again one time*/ + spin_unlock_bh(&conn->send_lock); + smc_wr_tx_put_slot(link, + (struct smc_wr_tx_pend_priv *)pend); + smc_wr_tx_link_put(link); + if (again) + return -ENOLINK; + again = true; + goto again; + } + rc = smc_cdc_msg_send(conn, wr_buf, pend); + spin_unlock_bh(&conn->send_lock); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +int smc_cdc_get_slot_and_msg_send(struct smc_connection *conn) +{ + int rc; + + if (!smc_conn_lgr_valid(conn) || + (conn->lgr->is_smcd && conn->lgr->peer_shutdown)) + return -EPIPE; + + if (conn->lgr->is_smcd) { + spin_lock_bh(&conn->send_lock); + rc = smcd_cdc_msg_send(conn); + spin_unlock_bh(&conn->send_lock); + } else { + rc = smcr_cdc_get_slot_and_msg_send(conn); + } + + return rc; +} + +void smc_cdc_wait_pend_tx_wr(struct smc_connection *conn) +{ + wait_event(conn->cdc_pend_tx_wq, !atomic_read(&conn->cdc_pend_tx_wr)); +} + +/* Send a SMC-D CDC header. + * This increments the free space available in our send buffer. + * Also update the confirmed receive buffer with what was sent to the peer. 
+ */ +int smcd_cdc_msg_send(struct smc_connection *conn) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + union smc_host_cursor curs; + struct smcd_cdc_msg cdc; + int rc, diff; + + memset(&cdc, 0, sizeof(cdc)); + cdc.common.type = SMC_CDC_MSG_TYPE; + curs.acurs.counter = atomic64_read(&conn->local_tx_ctrl.prod.acurs); + cdc.prod.wrap = curs.wrap; + cdc.prod.count = curs.count; + curs.acurs.counter = atomic64_read(&conn->local_tx_ctrl.cons.acurs); + cdc.cons.wrap = curs.wrap; + cdc.cons.count = curs.count; + cdc.cons.prod_flags = conn->local_tx_ctrl.prod_flags; + cdc.cons.conn_state_flags = conn->local_tx_ctrl.conn_state_flags; + rc = smcd_tx_ism_write(conn, &cdc, sizeof(cdc), 0, 1); + if (rc) + return rc; + smc_curs_copy(&conn->rx_curs_confirmed, &curs, conn); + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req = 0; + /* Calculate transmitted data and increment free send buffer space */ + diff = smc_curs_diff(conn->sndbuf_desc->len, &conn->tx_curs_fin, + &conn->tx_curs_sent); + /* increased by confirmed number of bytes */ + smp_mb__before_atomic(); + atomic_add(diff, &conn->sndbuf_space); + /* guarantee 0 <= sndbuf_space <= sndbuf_desc->len */ + smp_mb__after_atomic(); + smc_curs_copy(&conn->tx_curs_fin, &conn->tx_curs_sent, conn); + + smc_tx_sndbuf_nonfull(smc); + return rc; +} + +/********************************* receive ***********************************/ + +static inline bool smc_cdc_before(u16 seq1, u16 seq2) +{ + return (s16)(seq1 - seq2) < 0; +} + +static void smc_cdc_handle_urg_data_arrival(struct smc_sock *smc, + int *diff_prod) +{ + struct smc_connection *conn = &smc->conn; + char *base; + + /* new data included urgent business */ + smc_curs_copy(&conn->urg_curs, &conn->local_rx_ctrl.prod, conn); + conn->urg_state = SMC_URG_VALID; + if (!sock_flag(&smc->sk, SOCK_URGINLINE)) + /* we'll skip the urgent byte, so don't account for it */ + (*diff_prod)--; + base = (char *)conn->rmb_desc->cpu_addr + conn->rx_off; + if (conn->urg_curs.count) + conn->urg_rx_byte = *(base + conn->urg_curs.count - 1); + else + conn->urg_rx_byte = *(base + conn->rmb_desc->len - 1); + sk_send_sigurg(&smc->sk); +} + +static void smc_cdc_msg_validate(struct smc_sock *smc, struct smc_cdc_msg *cdc, + struct smc_link *link) +{ + struct smc_connection *conn = &smc->conn; + u16 recv_seq = ntohs(cdc->seqno); + s16 diff; + + /* check that seqnum was seen before */ + diff = conn->local_rx_ctrl.seqno - recv_seq; + if (diff < 0) { /* diff larger than 0x7fff */ + /* drop connection */ + conn->out_of_sync = 1; /* prevent any further receives */ + spin_lock_bh(&conn->send_lock); + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + conn->lnk = link; + spin_unlock_bh(&conn->send_lock); + sock_hold(&smc->sk); /* sock_put in abort_work */ + if (!queue_work(smc_close_wq, &conn->abort_work)) + sock_put(&smc->sk); + } +} + +static void smc_cdc_msg_recv_action(struct smc_sock *smc, + struct smc_cdc_msg *cdc) +{ + union smc_host_cursor cons_old, prod_old; + struct smc_connection *conn = &smc->conn; + int diff_cons, diff_prod; + + smc_curs_copy(&prod_old, &conn->local_rx_ctrl.prod, conn); + smc_curs_copy(&cons_old, &conn->local_rx_ctrl.cons, conn); + smc_cdc_msg_to_host(&conn->local_rx_ctrl, cdc, conn); + + diff_cons = smc_curs_diff(conn->peer_rmbe_size, &cons_old, + &conn->local_rx_ctrl.cons); + if (diff_cons) { + /* peer_rmbe_space is decreased during data transfer with RDMA + * write + */ + smp_mb__before_atomic(); + atomic_add(diff_cons, &conn->peer_rmbe_space); + /* guarantee 0 <= 
peer_rmbe_space <= peer_rmbe_size */ + smp_mb__after_atomic(); + } + + diff_prod = smc_curs_diff(conn->rmb_desc->len, &prod_old, + &conn->local_rx_ctrl.prod); + if (diff_prod) { + if (conn->local_rx_ctrl.prod_flags.urg_data_present) + smc_cdc_handle_urg_data_arrival(smc, &diff_prod); + /* bytes_to_rcv is decreased in smc_recvmsg */ + smp_mb__before_atomic(); + atomic_add(diff_prod, &conn->bytes_to_rcv); + /* guarantee 0 <= bytes_to_rcv <= rmb_desc->len */ + smp_mb__after_atomic(); + smc->sk.sk_data_ready(&smc->sk); + } else { + if (conn->local_rx_ctrl.prod_flags.write_blocked) + smc->sk.sk_data_ready(&smc->sk); + if (conn->local_rx_ctrl.prod_flags.urg_data_pending) + conn->urg_state = SMC_URG_NOTYET; + } + + /* trigger sndbuf consumer: RDMA write into peer RMBE and CDC */ + if ((diff_cons && smc_tx_prepared_sends(conn)) || + conn->local_rx_ctrl.prod_flags.cons_curs_upd_req || + conn->local_rx_ctrl.prod_flags.urg_data_pending) { + if (!sock_owned_by_user(&smc->sk)) + smc_tx_pending(conn); + else + conn->tx_in_release_sock = true; + } + + if (diff_cons && conn->urg_tx_pend && + atomic_read(&conn->peer_rmbe_space) == conn->peer_rmbe_size) { + /* urg data confirmed by peer, indicate we're ready for more */ + conn->urg_tx_pend = false; + smc->sk.sk_write_space(&smc->sk); + } + + if (conn->local_rx_ctrl.conn_state_flags.peer_conn_abort) { + smc->sk.sk_err = ECONNRESET; + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + } + if (smc_cdc_rxed_any_close_or_senddone(conn)) { + smc->sk.sk_shutdown |= RCV_SHUTDOWN; + if (smc->clcsock && smc->clcsock->sk) + smc->clcsock->sk->sk_shutdown |= RCV_SHUTDOWN; + smc_sock_set_flag(&smc->sk, SOCK_DONE); + sock_hold(&smc->sk); /* sock_put in close_work */ + if (!queue_work(smc_close_wq, &conn->close_work)) + sock_put(&smc->sk); + } +} + +/* called under tasklet context */ +static void smc_cdc_msg_recv(struct smc_sock *smc, struct smc_cdc_msg *cdc) +{ + sock_hold(&smc->sk); + bh_lock_sock(&smc->sk); + smc_cdc_msg_recv_action(smc, cdc); + bh_unlock_sock(&smc->sk); + sock_put(&smc->sk); /* no free sk in softirq-context */ +} + +/* Schedule a tasklet for this connection. Triggered from the ISM device IRQ + * handler to indicate update in the DMBE. + * + * Context: + * - tasklet context + */ +static void smcd_cdc_rx_tsklet(struct tasklet_struct *t) +{ + struct smc_connection *conn = from_tasklet(conn, t, rx_tsklet); + struct smcd_cdc_msg *data_cdc; + struct smcd_cdc_msg cdc; + struct smc_sock *smc; + + if (!conn || conn->killed) + return; + + data_cdc = (struct smcd_cdc_msg *)conn->rmb_desc->cpu_addr; + smcd_curs_copy(&cdc.prod, &data_cdc->prod, conn); + smcd_curs_copy(&cdc.cons, &data_cdc->cons, conn); + smc = container_of(conn, struct smc_sock, conn); + smc_cdc_msg_recv(smc, (struct smc_cdc_msg *)&cdc); +} + +/* Initialize receive tasklet. Called from ISM device IRQ handler to start + * receiver side. 
+ */ +void smcd_cdc_rx_init(struct smc_connection *conn) +{ + tasklet_setup(&conn->rx_tsklet, smcd_cdc_rx_tsklet); +} + +/***************************** init, exit, misc ******************************/ + +static void smc_cdc_rx_handler(struct ib_wc *wc, void *buf) +{ + struct smc_link *link = (struct smc_link *)wc->qp->qp_context; + struct smc_cdc_msg *cdc = buf; + struct smc_connection *conn; + struct smc_link_group *lgr; + struct smc_sock *smc; + + if (wc->byte_len < offsetof(struct smc_cdc_msg, reserved)) + return; /* short message */ + if (cdc->len != SMC_WR_TX_SIZE) + return; /* invalid message */ + + /* lookup connection */ + lgr = smc_get_lgr(link); + read_lock_bh(&lgr->conns_lock); + conn = smc_lgr_find_conn(ntohl(cdc->token), lgr); + read_unlock_bh(&lgr->conns_lock); + if (!conn || conn->out_of_sync) + return; + smc = container_of(conn, struct smc_sock, conn); + + if (cdc->prod_flags.failover_validation) { + smc_cdc_msg_validate(smc, cdc, link); + return; + } + if (smc_cdc_before(ntohs(cdc->seqno), + conn->local_rx_ctrl.seqno)) + /* received seqno is old */ + return; + + smc_cdc_msg_recv(smc, cdc); +} + +static struct smc_wr_rx_handler smc_cdc_rx_handlers[] = { + { + .handler = smc_cdc_rx_handler, + .type = SMC_CDC_MSG_TYPE + }, + { + .handler = NULL, + } +}; + +int __init smc_cdc_init(void) +{ + struct smc_wr_rx_handler *handler; + int rc = 0; + + for (handler = smc_cdc_rx_handlers; handler->handler; handler++) { + INIT_HLIST_NODE(&handler->list); + rc = smc_wr_rx_register_handler(handler); + if (rc) + break; + } + return rc; +} diff --git a/net/smc/smc_cdc.h b/net/smc/smc_cdc.h new file mode 100644 index 000000000..696cc11f2 --- /dev/null +++ b/net/smc/smc_cdc.h @@ -0,0 +1,305 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Connection Data Control (CDC) + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Ursula Braun + */ + +#ifndef SMC_CDC_H +#define SMC_CDC_H + +#include /* max_t */ +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_wr.h" + +#define SMC_CDC_MSG_TYPE 0xFE + +/* in network byte order */ +union smc_cdc_cursor { /* SMC cursor */ + struct { + __be16 reserved; + __be16 wrap; + __be32 count; + }; +#ifdef KERNEL_HAS_ATOMIC64 + atomic64_t acurs; /* for atomic processing */ +#else + u64 acurs; /* for atomic processing */ +#endif +} __aligned(8); + +/* in network byte order */ +struct smc_cdc_msg { + struct smc_wr_rx_hdr common; /* .type = 0xFE */ + u8 len; /* 44 */ + __be16 seqno; + __be32 token; + union smc_cdc_cursor prod; + union smc_cdc_cursor cons; /* piggy backed "ack" */ + struct smc_cdc_producer_flags prod_flags; + struct smc_cdc_conn_state_flags conn_state_flags; + u8 reserved[18]; +}; + +/* SMC-D cursor format */ +union smcd_cdc_cursor { + struct { + u16 wrap; + u32 count; + struct smc_cdc_producer_flags prod_flags; + struct smc_cdc_conn_state_flags conn_state_flags; + } __packed; +#ifdef KERNEL_HAS_ATOMIC64 + atomic64_t acurs; /* for atomic processing */ +#else + u64 acurs; /* for atomic processing */ +#endif +} __aligned(8); + +/* CDC message for SMC-D */ +struct smcd_cdc_msg { + struct smc_wr_rx_hdr common; /* Type = 0xFE */ + u8 res1[7]; + union smcd_cdc_cursor prod; + union smcd_cdc_cursor cons; + u8 res3[8]; +} __aligned(8); + +static inline bool smc_cdc_rxed_any_close(struct smc_connection *conn) +{ + return conn->local_rx_ctrl.conn_state_flags.peer_conn_abort || + conn->local_rx_ctrl.conn_state_flags.peer_conn_closed; +} + +static inline bool smc_cdc_rxed_any_close_or_senddone( + struct smc_connection *conn) +{ + return smc_cdc_rxed_any_close(conn) || + conn->local_rx_ctrl.conn_state_flags.peer_done_writing; +} + +static inline void smc_curs_add(int size, union smc_host_cursor *curs, + int value) +{ + curs->count += value; + if (curs->count >= size) { + curs->wrap++; + curs->count -= size; + } +} + +/* Copy cursor src into tgt */ +static inline void smc_curs_copy(union smc_host_cursor *tgt, + union smc_host_cursor *src, + struct smc_connection *conn) +{ +#ifndef KERNEL_HAS_ATOMIC64 + unsigned long flags; + + spin_lock_irqsave(&conn->acurs_lock, flags); + tgt->acurs = src->acurs; + spin_unlock_irqrestore(&conn->acurs_lock, flags); +#else + atomic64_set(&tgt->acurs, atomic64_read(&src->acurs)); +#endif +} + +static inline void smc_curs_copy_net(union smc_cdc_cursor *tgt, + union smc_cdc_cursor *src, + struct smc_connection *conn) +{ +#ifndef KERNEL_HAS_ATOMIC64 + unsigned long flags; + + spin_lock_irqsave(&conn->acurs_lock, flags); + tgt->acurs = src->acurs; + spin_unlock_irqrestore(&conn->acurs_lock, flags); +#else + atomic64_set(&tgt->acurs, atomic64_read(&src->acurs)); +#endif +} + +static inline void smcd_curs_copy(union smcd_cdc_cursor *tgt, + union smcd_cdc_cursor *src, + struct smc_connection *conn) +{ +#ifndef KERNEL_HAS_ATOMIC64 + unsigned long flags; + + spin_lock_irqsave(&conn->acurs_lock, flags); + tgt->acurs = src->acurs; + spin_unlock_irqrestore(&conn->acurs_lock, flags); +#else + atomic64_set(&tgt->acurs, atomic64_read(&src->acurs)); +#endif +} + +/* calculate cursor difference between old and new, where old <= new and + * difference cannot exceed size + */ +static inline int smc_curs_diff(unsigned int size, + union smc_host_cursor *old, + union smc_host_cursor *new) +{ + if (old->wrap != new->wrap) + return max_t(int, 0, + ((size - old->count) + new->count)); + + return max_t(int, 0, 
(new->count - old->count)); +} + +/* calculate cursor difference between old and new - returns negative + * value in case old > new + */ +static inline int smc_curs_comp(unsigned int size, + union smc_host_cursor *old, + union smc_host_cursor *new) +{ + if (old->wrap > new->wrap || + (old->wrap == new->wrap && old->count > new->count)) + return -smc_curs_diff(size, new, old); + return smc_curs_diff(size, old, new); +} + +/* calculate cursor difference between old and new, where old <= new and + * difference may exceed size + */ +static inline int smc_curs_diff_large(unsigned int size, + union smc_host_cursor *old, + union smc_host_cursor *new) +{ + if (old->wrap < new->wrap) + return min_t(int, + (size - old->count) + new->count + + (new->wrap - old->wrap - 1) * size, + size); + + if (old->wrap > new->wrap) /* wrap has switched from 0xffff to 0x0000 */ + return min_t(int, + (size - old->count) + new->count + + (new->wrap + 0xffff - old->wrap) * size, + size); + + return max_t(int, 0, (new->count - old->count)); +} + +static inline void smc_host_cursor_to_cdc(union smc_cdc_cursor *peer, + union smc_host_cursor *local, + union smc_host_cursor *save, + struct smc_connection *conn) +{ + smc_curs_copy(save, local, conn); + peer->count = htonl(save->count); + peer->wrap = htons(save->wrap); + /* peer->reserved = htons(0); must be ensured by caller */ +} + +static inline void smc_host_msg_to_cdc(struct smc_cdc_msg *peer, + struct smc_connection *conn, + union smc_host_cursor *save) +{ + struct smc_host_cdc_msg *local = &conn->local_tx_ctrl; + + peer->common.type = local->common.type; + peer->len = local->len; + peer->seqno = htons(local->seqno); + peer->token = htonl(local->token); + smc_host_cursor_to_cdc(&peer->prod, &local->prod, save, conn); + smc_host_cursor_to_cdc(&peer->cons, &local->cons, save, conn); + peer->prod_flags = local->prod_flags; + peer->conn_state_flags = local->conn_state_flags; +} + +static inline void smc_cdc_cursor_to_host(union smc_host_cursor *local, + union smc_cdc_cursor *peer, + struct smc_connection *conn) +{ + union smc_host_cursor temp, old; + union smc_cdc_cursor net; + + smc_curs_copy(&old, local, conn); + smc_curs_copy_net(&net, peer, conn); + temp.count = ntohl(net.count); + temp.wrap = ntohs(net.wrap); + if ((old.wrap > temp.wrap) && temp.wrap) + return; + if ((old.wrap == temp.wrap) && + (old.count > temp.count)) + return; + smc_curs_copy(local, &temp, conn); +} + +static inline void smcr_cdc_msg_to_host(struct smc_host_cdc_msg *local, + struct smc_cdc_msg *peer, + struct smc_connection *conn) +{ + local->common.type = peer->common.type; + local->len = peer->len; + local->seqno = ntohs(peer->seqno); + local->token = ntohl(peer->token); + smc_cdc_cursor_to_host(&local->prod, &peer->prod, conn); + smc_cdc_cursor_to_host(&local->cons, &peer->cons, conn); + local->prod_flags = peer->prod_flags; + local->conn_state_flags = peer->conn_state_flags; +} + +static inline void smcd_cdc_msg_to_host(struct smc_host_cdc_msg *local, + struct smcd_cdc_msg *peer, + struct smc_connection *conn) +{ + union smc_host_cursor temp; + + temp.wrap = peer->prod.wrap; + temp.count = peer->prod.count; + smc_curs_copy(&local->prod, &temp, conn); + + temp.wrap = peer->cons.wrap; + temp.count = peer->cons.count; + smc_curs_copy(&local->cons, &temp, conn); + local->prod_flags = peer->cons.prod_flags; + local->conn_state_flags = peer->cons.conn_state_flags; +} + +static inline void smc_cdc_msg_to_host(struct smc_host_cdc_msg *local, + struct smc_cdc_msg *peer, + struct smc_connection *conn) 
+{ + if (conn->lgr->is_smcd) + smcd_cdc_msg_to_host(local, (struct smcd_cdc_msg *)peer, conn); + else + smcr_cdc_msg_to_host(local, peer, conn); +} + +struct smc_cdc_tx_pend { + struct smc_connection *conn; /* socket connection */ + union smc_host_cursor cursor; /* tx sndbuf cursor sent */ + union smc_host_cursor p_cursor; /* rx RMBE cursor produced */ + u16 ctrl_seq; /* conn. tx sequence # */ +}; + +int smc_cdc_get_free_slot(struct smc_connection *conn, + struct smc_link *link, + struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, + struct smc_cdc_tx_pend **pend); +void smc_cdc_wait_pend_tx_wr(struct smc_connection *conn); +int smc_cdc_msg_send(struct smc_connection *conn, struct smc_wr_buf *wr_buf, + struct smc_cdc_tx_pend *pend); +int smc_cdc_get_slot_and_msg_send(struct smc_connection *conn); +int smcd_cdc_msg_send(struct smc_connection *conn); +int smcr_cdc_msg_send_validation(struct smc_connection *conn, + struct smc_cdc_tx_pend *pend, + struct smc_wr_buf *wr_buf); +int smc_cdc_init(void) __init; +void smcd_cdc_rx_init(struct smc_connection *conn); + +#endif /* SMC_CDC_H */ diff --git a/net/smc/smc_clc.c b/net/smc/smc_clc.c new file mode 100644 index 000000000..9b8999e2a --- /dev/null +++ b/net/smc/smc_clc.c @@ -0,0 +1,1177 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * CLC (connection layer control) handshake over initial TCP socket to + * prepare for RDMA traffic + * + * Copyright IBM Corp. 2016, 2018 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_clc.h" +#include "smc_ib.h" +#include "smc_ism.h" +#include "smc_netlink.h" + +#define SMCR_CLC_ACCEPT_CONFIRM_LEN 68 +#define SMCD_CLC_ACCEPT_CONFIRM_LEN 48 +#define SMCD_CLC_ACCEPT_CONFIRM_LEN_V2 78 +#define SMCR_CLC_ACCEPT_CONFIRM_LEN_V2 108 +#define SMC_CLC_RECV_BUF_LEN 100 + +/* eye catcher "SMCR" EBCDIC for CLC messages */ +static const char SMC_EYECATCHER[4] = {'\xe2', '\xd4', '\xc3', '\xd9'}; +/* eye catcher "SMCD" EBCDIC for CLC messages */ +static const char SMCD_EYECATCHER[4] = {'\xe2', '\xd4', '\xc3', '\xc4'}; + +static u8 smc_hostname[SMC_MAX_HOSTNAME_LEN]; + +struct smc_clc_eid_table { + rwlock_t lock; + struct list_head list; + u8 ueid_cnt; + u8 seid_enabled; +}; + +static struct smc_clc_eid_table smc_clc_eid_table; + +struct smc_clc_eid_entry { + struct list_head list; + u8 eid[SMC_MAX_EID_LEN]; +}; + +/* The size of a user EID is 32 characters. + * Valid characters should be (single-byte character set) A-Z, 0-9, '.' and '-'. + * Blanks should only be used to pad to the expected size. + * First character must be alphanumeric. + */ +static bool smc_clc_ueid_valid(char *ueid) +{ + char *end = ueid + SMC_MAX_EID_LEN; + + while (--end >= ueid && isspace(*end)) + ; + if (end < ueid) + return false; + if (!isalnum(*ueid) || islower(*ueid)) + return false; + while (ueid <= end) { + if ((!isalnum(*ueid) || islower(*ueid)) && *ueid != '.' 
&& + *ueid != '-') + return false; + ueid++; + } + return true; +} + +static int smc_clc_ueid_add(char *ueid) +{ + struct smc_clc_eid_entry *new_ueid, *tmp_ueid; + int rc; + + if (!smc_clc_ueid_valid(ueid)) + return -EINVAL; + + /* add a new ueid entry to the ueid table if there isn't one */ + new_ueid = kzalloc(sizeof(*new_ueid), GFP_KERNEL); + if (!new_ueid) + return -ENOMEM; + memcpy(new_ueid->eid, ueid, SMC_MAX_EID_LEN); + + write_lock(&smc_clc_eid_table.lock); + if (smc_clc_eid_table.ueid_cnt >= SMC_MAX_UEID) { + rc = -ERANGE; + goto err_out; + } + list_for_each_entry(tmp_ueid, &smc_clc_eid_table.list, list) { + if (!memcmp(tmp_ueid->eid, ueid, SMC_MAX_EID_LEN)) { + rc = -EEXIST; + goto err_out; + } + } + list_add_tail(&new_ueid->list, &smc_clc_eid_table.list); + smc_clc_eid_table.ueid_cnt++; + write_unlock(&smc_clc_eid_table.lock); + return 0; + +err_out: + write_unlock(&smc_clc_eid_table.lock); + kfree(new_ueid); + return rc; +} + +int smc_clc_ueid_count(void) +{ + int count; + + read_lock(&smc_clc_eid_table.lock); + count = smc_clc_eid_table.ueid_cnt; + read_unlock(&smc_clc_eid_table.lock); + + return count; +} + +int smc_nl_add_ueid(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *nla_ueid = info->attrs[SMC_NLA_EID_TABLE_ENTRY]; + char *ueid; + + if (!nla_ueid || nla_len(nla_ueid) != SMC_MAX_EID_LEN + 1) + return -EINVAL; + ueid = (char *)nla_data(nla_ueid); + + return smc_clc_ueid_add(ueid); +} + +/* remove one or all ueid entries from the table */ +static int smc_clc_ueid_remove(char *ueid) +{ + struct smc_clc_eid_entry *lst_ueid, *tmp_ueid; + int rc = -ENOENT; + + /* remove table entry */ + write_lock(&smc_clc_eid_table.lock); + list_for_each_entry_safe(lst_ueid, tmp_ueid, &smc_clc_eid_table.list, + list) { + if (!ueid || !memcmp(lst_ueid->eid, ueid, SMC_MAX_EID_LEN)) { + list_del(&lst_ueid->list); + smc_clc_eid_table.ueid_cnt--; + kfree(lst_ueid); + rc = 0; + } + } + if (!rc && !smc_clc_eid_table.ueid_cnt) { + smc_clc_eid_table.seid_enabled = 1; + rc = -EAGAIN; /* indicate success and enabling of seid */ + } + write_unlock(&smc_clc_eid_table.lock); + return rc; +} + +int smc_nl_remove_ueid(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *nla_ueid = info->attrs[SMC_NLA_EID_TABLE_ENTRY]; + char *ueid; + + if (!nla_ueid || nla_len(nla_ueid) != SMC_MAX_EID_LEN + 1) + return -EINVAL; + ueid = (char *)nla_data(nla_ueid); + + return smc_clc_ueid_remove(ueid); +} + +int smc_nl_flush_ueid(struct sk_buff *skb, struct genl_info *info) +{ + smc_clc_ueid_remove(NULL); + return 0; +} + +static int smc_nl_ueid_dumpinfo(struct sk_buff *skb, u32 portid, u32 seq, + u32 flags, char *ueid) +{ + char ueid_str[SMC_MAX_EID_LEN + 1]; + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &smc_gen_nl_family, + flags, SMC_NETLINK_DUMP_UEID); + if (!hdr) + return -ENOMEM; + memcpy(ueid_str, ueid, SMC_MAX_EID_LEN); + ueid_str[SMC_MAX_EID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_EID_TABLE_ENTRY, ueid_str)) { + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; + } + genlmsg_end(skb, hdr); + return 0; +} + +static int _smc_nl_ueid_dump(struct sk_buff *skb, u32 portid, u32 seq, + int start_idx) +{ + struct smc_clc_eid_entry *lst_ueid; + int idx = 0; + + read_lock(&smc_clc_eid_table.lock); + list_for_each_entry(lst_ueid, &smc_clc_eid_table.list, list) { + if (idx++ < start_idx) + continue; + if (smc_nl_ueid_dumpinfo(skb, portid, seq, NLM_F_MULTI, + lst_ueid->eid)) { + --idx; + break; + } + } + read_unlock(&smc_clc_eid_table.lock); + return idx; +} + +int smc_nl_dump_ueid(struct 
sk_buff *skb, struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + int idx; + + idx = _smc_nl_ueid_dump(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, cb_ctx->pos[0]); + + cb_ctx->pos[0] = idx; + return skb->len; +} + +int smc_nl_dump_seid(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + char seid_str[SMC_MAX_EID_LEN + 1]; + u8 seid_enabled; + void *hdr; + u8 *seid; + + if (cb_ctx->pos[0]) + return skb->len; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_DUMP_SEID); + if (!hdr) + return -ENOMEM; + if (!smc_ism_is_v2_capable()) + goto end; + + smc_ism_get_system_eid(&seid); + memcpy(seid_str, seid, SMC_MAX_EID_LEN); + seid_str[SMC_MAX_EID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_SEID_ENTRY, seid_str)) + goto err; + read_lock(&smc_clc_eid_table.lock); + seid_enabled = smc_clc_eid_table.seid_enabled; + read_unlock(&smc_clc_eid_table.lock); + if (nla_put_u8(skb, SMC_NLA_SEID_ENABLED, seid_enabled)) + goto err; +end: + genlmsg_end(skb, hdr); + cb_ctx->pos[0]++; + return skb->len; +err: + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; +} + +int smc_nl_enable_seid(struct sk_buff *skb, struct genl_info *info) +{ + write_lock(&smc_clc_eid_table.lock); + smc_clc_eid_table.seid_enabled = 1; + write_unlock(&smc_clc_eid_table.lock); + return 0; +} + +int smc_nl_disable_seid(struct sk_buff *skb, struct genl_info *info) +{ + int rc = 0; + + write_lock(&smc_clc_eid_table.lock); + if (!smc_clc_eid_table.ueid_cnt) + rc = -ENOENT; + else + smc_clc_eid_table.seid_enabled = 0; + write_unlock(&smc_clc_eid_table.lock); + return rc; +} + +static bool _smc_clc_match_ueid(u8 *peer_ueid) +{ + struct smc_clc_eid_entry *tmp_ueid; + + list_for_each_entry(tmp_ueid, &smc_clc_eid_table.list, list) { + if (!memcmp(tmp_ueid->eid, peer_ueid, SMC_MAX_EID_LEN)) + return true; + } + return false; +} + +bool smc_clc_match_eid(u8 *negotiated_eid, + struct smc_clc_v2_extension *smc_v2_ext, + u8 *peer_eid, u8 *local_eid) +{ + bool match = false; + int i; + + negotiated_eid[0] = 0; + read_lock(&smc_clc_eid_table.lock); + if (peer_eid && local_eid && + smc_clc_eid_table.seid_enabled && + smc_v2_ext->hdr.flag.seid && + !memcmp(peer_eid, local_eid, SMC_MAX_EID_LEN)) { + memcpy(negotiated_eid, peer_eid, SMC_MAX_EID_LEN); + match = true; + goto out; + } + + for (i = 0; i < smc_v2_ext->hdr.eid_cnt; i++) { + if (_smc_clc_match_ueid(smc_v2_ext->user_eids[i])) { + memcpy(negotiated_eid, smc_v2_ext->user_eids[i], + SMC_MAX_EID_LEN); + match = true; + goto out; + } + } +out: + read_unlock(&smc_clc_eid_table.lock); + return match; +} + +/* check arriving CLC proposal */ +static bool smc_clc_msg_prop_valid(struct smc_clc_msg_proposal *pclc) +{ + struct smc_clc_msg_proposal_prefix *pclc_prfx; + struct smc_clc_smcd_v2_extension *smcd_v2_ext; + struct smc_clc_msg_hdr *hdr = &pclc->hdr; + struct smc_clc_v2_extension *v2_ext; + + v2_ext = smc_get_clc_v2_ext(pclc); + pclc_prfx = smc_clc_proposal_get_prefix(pclc); + if (hdr->version == SMC_V1) { + if (hdr->typev1 == SMC_TYPE_N) + return false; + if (ntohs(hdr->length) != + sizeof(*pclc) + ntohs(pclc->iparea_offset) + + sizeof(*pclc_prfx) + + pclc_prfx->ipv6_prefixes_cnt * + sizeof(struct smc_clc_ipv6_prefix) + + sizeof(struct smc_clc_msg_trail)) + return false; + } else { + if (ntohs(hdr->length) != + sizeof(*pclc) + + sizeof(struct smc_clc_msg_smcd) + + (hdr->typev1 != SMC_TYPE_N ? 
+ sizeof(*pclc_prfx) + + pclc_prfx->ipv6_prefixes_cnt * + sizeof(struct smc_clc_ipv6_prefix) : 0) + + (hdr->typev2 != SMC_TYPE_N ? + sizeof(*v2_ext) + + v2_ext->hdr.eid_cnt * SMC_MAX_EID_LEN : 0) + + (smcd_indicated(hdr->typev2) ? + sizeof(*smcd_v2_ext) + v2_ext->hdr.ism_gid_cnt * + sizeof(struct smc_clc_smcd_gid_chid) : + 0) + + sizeof(struct smc_clc_msg_trail)) + return false; + } + return true; +} + +/* check arriving CLC accept or confirm */ +static bool +smc_clc_msg_acc_conf_valid(struct smc_clc_msg_accept_confirm_v2 *clc_v2) +{ + struct smc_clc_msg_hdr *hdr = &clc_v2->hdr; + + if (hdr->typev1 != SMC_TYPE_R && hdr->typev1 != SMC_TYPE_D) + return false; + if (hdr->version == SMC_V1) { + if ((hdr->typev1 == SMC_TYPE_R && + ntohs(hdr->length) != SMCR_CLC_ACCEPT_CONFIRM_LEN) || + (hdr->typev1 == SMC_TYPE_D && + ntohs(hdr->length) != SMCD_CLC_ACCEPT_CONFIRM_LEN)) + return false; + } else { + if (hdr->typev1 == SMC_TYPE_D && + ntohs(hdr->length) != SMCD_CLC_ACCEPT_CONFIRM_LEN_V2 && + (ntohs(hdr->length) != SMCD_CLC_ACCEPT_CONFIRM_LEN_V2 + + sizeof(struct smc_clc_first_contact_ext))) + return false; + if (hdr->typev1 == SMC_TYPE_R && + ntohs(hdr->length) < SMCR_CLC_ACCEPT_CONFIRM_LEN_V2) + return false; + } + return true; +} + +/* check arriving CLC decline */ +static bool +smc_clc_msg_decl_valid(struct smc_clc_msg_decline *dclc) +{ + struct smc_clc_msg_hdr *hdr = &dclc->hdr; + + if (hdr->typev1 != SMC_TYPE_R && hdr->typev1 != SMC_TYPE_D) + return false; + if (hdr->version == SMC_V1) { + if (ntohs(hdr->length) != sizeof(struct smc_clc_msg_decline)) + return false; + } else { + if (ntohs(hdr->length) != sizeof(struct smc_clc_msg_decline_v2)) + return false; + } + return true; +} + +static void smc_clc_fill_fce(struct smc_clc_first_contact_ext *fce, int *len) +{ + memset(fce, 0, sizeof(*fce)); + fce->os_type = SMC_CLC_OS_LINUX; + fce->release = SMC_RELEASE; + memcpy(fce->hostname, smc_hostname, sizeof(smc_hostname)); + (*len) += sizeof(*fce); +} + +/* check if received message has a correct header length and contains valid + * heading and trailing eyecatchers + */ +static bool smc_clc_msg_hdr_valid(struct smc_clc_msg_hdr *clcm, bool check_trl) +{ + struct smc_clc_msg_accept_confirm_v2 *clc_v2; + struct smc_clc_msg_proposal *pclc; + struct smc_clc_msg_decline *dclc; + struct smc_clc_msg_trail *trl; + + if (memcmp(clcm->eyecatcher, SMC_EYECATCHER, sizeof(SMC_EYECATCHER)) && + memcmp(clcm->eyecatcher, SMCD_EYECATCHER, sizeof(SMCD_EYECATCHER))) + return false; + switch (clcm->type) { + case SMC_CLC_PROPOSAL: + pclc = (struct smc_clc_msg_proposal *)clcm; + if (!smc_clc_msg_prop_valid(pclc)) + return false; + trl = (struct smc_clc_msg_trail *) + ((u8 *)pclc + ntohs(pclc->hdr.length) - sizeof(*trl)); + break; + case SMC_CLC_ACCEPT: + case SMC_CLC_CONFIRM: + clc_v2 = (struct smc_clc_msg_accept_confirm_v2 *)clcm; + if (!smc_clc_msg_acc_conf_valid(clc_v2)) + return false; + trl = (struct smc_clc_msg_trail *) + ((u8 *)clc_v2 + ntohs(clc_v2->hdr.length) - + sizeof(*trl)); + break; + case SMC_CLC_DECLINE: + dclc = (struct smc_clc_msg_decline *)clcm; + if (!smc_clc_msg_decl_valid(dclc)) + return false; + check_trl = false; + break; + default: + return false; + } + if (check_trl && + memcmp(trl->eyecatcher, SMC_EYECATCHER, sizeof(SMC_EYECATCHER)) && + memcmp(trl->eyecatcher, SMCD_EYECATCHER, sizeof(SMCD_EYECATCHER))) + return false; + return true; +} + +/* find ipv4 addr on device and get the prefix len, fill CLC proposal msg */ +static int smc_clc_prfx_set4_rcu(struct dst_entry *dst, __be32 ipv4, + struct 
smc_clc_msg_proposal_prefix *prop) +{ + struct in_device *in_dev = __in_dev_get_rcu(dst->dev); + const struct in_ifaddr *ifa; + + if (!in_dev) + return -ENODEV; + + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (!inet_ifa_match(ipv4, ifa)) + continue; + prop->prefix_len = inet_mask_len(ifa->ifa_mask); + prop->outgoing_subnet = ifa->ifa_address & ifa->ifa_mask; + /* prop->ipv6_prefixes_cnt = 0; already done by memset before */ + return 0; + } + return -ENOENT; +} + +/* fill CLC proposal msg with ipv6 prefixes from device */ +static int smc_clc_prfx_set6_rcu(struct dst_entry *dst, + struct smc_clc_msg_proposal_prefix *prop, + struct smc_clc_ipv6_prefix *ipv6_prfx) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct inet6_dev *in6_dev = __in6_dev_get(dst->dev); + struct inet6_ifaddr *ifa; + int cnt = 0; + + if (!in6_dev) + return -ENODEV; + /* use a maximum of 8 IPv6 prefixes from device */ + list_for_each_entry(ifa, &in6_dev->addr_list, if_list) { + if (ipv6_addr_type(&ifa->addr) & IPV6_ADDR_LINKLOCAL) + continue; + ipv6_addr_prefix(&ipv6_prfx[cnt].prefix, + &ifa->addr, ifa->prefix_len); + ipv6_prfx[cnt].prefix_len = ifa->prefix_len; + cnt++; + if (cnt == SMC_CLC_MAX_V6_PREFIX) + break; + } + prop->ipv6_prefixes_cnt = cnt; + if (cnt) + return 0; +#endif + return -ENOENT; +} + +/* retrieve and set prefixes in CLC proposal msg */ +static int smc_clc_prfx_set(struct socket *clcsock, + struct smc_clc_msg_proposal_prefix *prop, + struct smc_clc_ipv6_prefix *ipv6_prfx) +{ + struct dst_entry *dst = sk_dst_get(clcsock->sk); + struct sockaddr_storage addrs; + struct sockaddr_in6 *addr6; + struct sockaddr_in *addr; + int rc = -ENOENT; + + if (!dst) { + rc = -ENOTCONN; + goto out; + } + if (!dst->dev) { + rc = -ENODEV; + goto out_rel; + } + /* get address to which the internal TCP socket is bound */ + if (kernel_getsockname(clcsock, (struct sockaddr *)&addrs) < 0) + goto out_rel; + /* analyze IP specific data of net_device belonging to TCP socket */ + addr6 = (struct sockaddr_in6 *)&addrs; + rcu_read_lock(); + if (addrs.ss_family == PF_INET) { + /* IPv4 */ + addr = (struct sockaddr_in *)&addrs; + rc = smc_clc_prfx_set4_rcu(dst, addr->sin_addr.s_addr, prop); + } else if (ipv6_addr_v4mapped(&addr6->sin6_addr)) { + /* mapped IPv4 address - peer is IPv4 only */ + rc = smc_clc_prfx_set4_rcu(dst, addr6->sin6_addr.s6_addr32[3], + prop); + } else { + /* IPv6 */ + rc = smc_clc_prfx_set6_rcu(dst, prop, ipv6_prfx); + } + rcu_read_unlock(); +out_rel: + dst_release(dst); +out: + return rc; +} + +/* match ipv4 addrs of dev against addr in CLC proposal */ +static int smc_clc_prfx_match4_rcu(struct net_device *dev, + struct smc_clc_msg_proposal_prefix *prop) +{ + struct in_device *in_dev = __in_dev_get_rcu(dev); + const struct in_ifaddr *ifa; + + if (!in_dev) + return -ENODEV; + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (prop->prefix_len == inet_mask_len(ifa->ifa_mask) && + inet_ifa_match(prop->outgoing_subnet, ifa)) + return 0; + } + + return -ENOENT; +} + +/* match ipv6 addrs of dev against addrs in CLC proposal */ +static int smc_clc_prfx_match6_rcu(struct net_device *dev, + struct smc_clc_msg_proposal_prefix *prop) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct inet6_dev *in6_dev = __in6_dev_get(dev); + struct smc_clc_ipv6_prefix *ipv6_prfx; + struct inet6_ifaddr *ifa; + int i, max; + + if (!in6_dev) + return -ENODEV; + /* ipv6 prefix list starts behind smc_clc_msg_proposal_prefix */ + ipv6_prfx = (struct smc_clc_ipv6_prefix *)((u8 *)prop + sizeof(*prop)); + max = min_t(u8, prop->ipv6_prefixes_cnt, SMC_CLC_MAX_V6_PREFIX); + 
list_for_each_entry(ifa, &in6_dev->addr_list, if_list) { + if (ipv6_addr_type(&ifa->addr) & IPV6_ADDR_LINKLOCAL) + continue; + for (i = 0; i < max; i++) { + if (ifa->prefix_len == ipv6_prfx[i].prefix_len && + ipv6_prefix_equal(&ifa->addr, &ipv6_prfx[i].prefix, + ifa->prefix_len)) + return 0; + } + } +#endif + return -ENOENT; +} + +/* check if proposed prefixes match one of our device prefixes */ +int smc_clc_prfx_match(struct socket *clcsock, + struct smc_clc_msg_proposal_prefix *prop) +{ + struct dst_entry *dst = sk_dst_get(clcsock->sk); + int rc; + + if (!dst) { + rc = -ENOTCONN; + goto out; + } + if (!dst->dev) { + rc = -ENODEV; + goto out_rel; + } + rcu_read_lock(); + if (!prop->ipv6_prefixes_cnt) + rc = smc_clc_prfx_match4_rcu(dst->dev, prop); + else + rc = smc_clc_prfx_match6_rcu(dst->dev, prop); + rcu_read_unlock(); +out_rel: + dst_release(dst); +out: + return rc; +} + +/* Wait for data on the tcp-socket, analyze received data + * Returns: + * 0 if success and it was not a decline that we received. + * SMC_CLC_DECL_REPLY if decline received for fallback w/o another decl send. + * clcsock error, -EINTR, -ECONNRESET, -EPROTO otherwise. + */ +int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, + u8 expected_type, unsigned long timeout) +{ + long rcvtimeo = smc->clcsock->sk->sk_rcvtimeo; + struct sock *clc_sk = smc->clcsock->sk; + struct smc_clc_msg_hdr *clcm = buf; + struct msghdr msg = {NULL, 0}; + int reason_code = 0; + struct kvec vec = {buf, buflen}; + int len, datlen, recvlen; + bool check_trl = true; + int krflags; + + /* peek the first few bytes to determine length of data to receive + * so we don't consume any subsequent CLC message or payload data + * in the TCP byte stream + */ + /* + * Caller must make sure that buflen is no less than + * sizeof(struct smc_clc_msg_hdr) + */ + krflags = MSG_PEEK | MSG_WAITALL; + clc_sk->sk_rcvtimeo = timeout; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, + sizeof(struct smc_clc_msg_hdr)); + len = sock_recvmsg(smc->clcsock, &msg, krflags); + if (signal_pending(current)) { + reason_code = -EINTR; + clc_sk->sk_err = EINTR; + smc->sk.sk_err = EINTR; + goto out; + } + if (clc_sk->sk_err) { + reason_code = -clc_sk->sk_err; + if (clc_sk->sk_err == EAGAIN && + expected_type == SMC_CLC_DECLINE) + clc_sk->sk_err = 0; /* reset for fallback usage */ + else + smc->sk.sk_err = clc_sk->sk_err; + goto out; + } + if (!len) { /* peer has performed orderly shutdown */ + smc->sk.sk_err = ECONNRESET; + reason_code = -ECONNRESET; + goto out; + } + if (len < 0) { + if (len != -EAGAIN || expected_type != SMC_CLC_DECLINE) + smc->sk.sk_err = -len; + reason_code = len; + goto out; + } + datlen = ntohs(clcm->length); + if ((len < sizeof(struct smc_clc_msg_hdr)) || + (clcm->version < SMC_V1) || + ((clcm->type != SMC_CLC_DECLINE) && + (clcm->type != expected_type))) { + smc->sk.sk_err = EPROTO; + reason_code = -EPROTO; + goto out; + } + + /* receive the complete CLC message */ + memset(&msg, 0, sizeof(struct msghdr)); + if (datlen > buflen) { + check_trl = false; + recvlen = buflen; + } else { + recvlen = datlen; + } + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, recvlen); + krflags = MSG_WAITALL; + len = sock_recvmsg(smc->clcsock, &msg, krflags); + if (len < recvlen || !smc_clc_msg_hdr_valid(clcm, check_trl)) { + smc->sk.sk_err = EPROTO; + reason_code = -EPROTO; + goto out; + } + datlen -= len; + while (datlen) { + u8 tmp[SMC_CLC_RECV_BUF_LEN]; + + vec.iov_base = &tmp; + vec.iov_len = SMC_CLC_RECV_BUF_LEN; + /* receive remaining proposal message */ 
+ recvlen = datlen > SMC_CLC_RECV_BUF_LEN ? + SMC_CLC_RECV_BUF_LEN : datlen; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &vec, 1, recvlen); + len = sock_recvmsg(smc->clcsock, &msg, krflags); + datlen -= len; + } + if (clcm->type == SMC_CLC_DECLINE) { + struct smc_clc_msg_decline *dclc; + + dclc = (struct smc_clc_msg_decline *)clcm; + reason_code = SMC_CLC_DECL_PEERDECL; + smc->peer_diagnosis = ntohl(dclc->peer_diagnosis); + if (((struct smc_clc_msg_decline *)buf)->hdr.typev2 & + SMC_FIRST_CONTACT_MASK) { + smc->conn.lgr->sync_err = 1; + smc_lgr_terminate_sched(smc->conn.lgr); + } + } + +out: + clc_sk->sk_rcvtimeo = rcvtimeo; + return reason_code; +} + +/* send CLC DECLINE message across internal TCP socket */ +int smc_clc_send_decline(struct smc_sock *smc, u32 peer_diag_info, u8 version) +{ + struct smc_clc_msg_decline *dclc_v1; + struct smc_clc_msg_decline_v2 dclc; + struct msghdr msg; + int len, send_len; + struct kvec vec; + + dclc_v1 = (struct smc_clc_msg_decline *)&dclc; + memset(&dclc, 0, sizeof(dclc)); + memcpy(dclc.hdr.eyecatcher, SMC_EYECATCHER, sizeof(SMC_EYECATCHER)); + dclc.hdr.type = SMC_CLC_DECLINE; + dclc.hdr.version = version; + dclc.os_type = version == SMC_V1 ? 0 : SMC_CLC_OS_LINUX; + dclc.hdr.typev2 = (peer_diag_info == SMC_CLC_DECL_SYNCERR) ? + SMC_FIRST_CONTACT_MASK : 0; + if ((!smc_conn_lgr_valid(&smc->conn) || !smc->conn.lgr->is_smcd) && + smc_ib_is_valid_local_systemid()) + memcpy(dclc.id_for_peer, local_systemid, + sizeof(local_systemid)); + dclc.peer_diagnosis = htonl(peer_diag_info); + if (version == SMC_V1) { + memcpy(dclc_v1->trl.eyecatcher, SMC_EYECATCHER, + sizeof(SMC_EYECATCHER)); + send_len = sizeof(*dclc_v1); + } else { + memcpy(dclc.trl.eyecatcher, SMC_EYECATCHER, + sizeof(SMC_EYECATCHER)); + send_len = sizeof(dclc); + } + dclc.hdr.length = htons(send_len); + + memset(&msg, 0, sizeof(msg)); + vec.iov_base = &dclc; + vec.iov_len = send_len; + len = kernel_sendmsg(smc->clcsock, &msg, &vec, 1, send_len); + if (len < 0 || len < send_len) + len = -EPROTO; + return len > 0 ? 
0 : len; +} + +/* send CLC PROPOSAL message across internal TCP socket */ +int smc_clc_send_proposal(struct smc_sock *smc, struct smc_init_info *ini) +{ + struct smc_clc_smcd_v2_extension *smcd_v2_ext; + struct smc_clc_msg_proposal_prefix *pclc_prfx; + struct smc_clc_msg_proposal *pclc_base; + struct smc_clc_smcd_gid_chid *gidchids; + struct smc_clc_msg_proposal_area *pclc; + struct smc_clc_ipv6_prefix *ipv6_prfx; + struct smc_clc_v2_extension *v2_ext; + struct smc_clc_msg_smcd *pclc_smcd; + struct smc_clc_msg_trail *trl; + int len, i, plen, rc; + int reason_code = 0; + struct kvec vec[8]; + struct msghdr msg; + + pclc = kzalloc(sizeof(*pclc), GFP_KERNEL); + if (!pclc) + return -ENOMEM; + + pclc_base = &pclc->pclc_base; + pclc_smcd = &pclc->pclc_smcd; + pclc_prfx = &pclc->pclc_prfx; + ipv6_prfx = pclc->pclc_prfx_ipv6; + v2_ext = &pclc->pclc_v2_ext; + smcd_v2_ext = &pclc->pclc_smcd_v2_ext; + gidchids = pclc->pclc_gidchids; + trl = &pclc->pclc_trl; + + pclc_base->hdr.version = SMC_V2; + pclc_base->hdr.typev1 = ini->smc_type_v1; + pclc_base->hdr.typev2 = ini->smc_type_v2; + plen = sizeof(*pclc_base) + sizeof(*pclc_smcd) + sizeof(*trl); + + /* retrieve ip prefixes for CLC proposal msg */ + if (ini->smc_type_v1 != SMC_TYPE_N) { + rc = smc_clc_prfx_set(smc->clcsock, pclc_prfx, ipv6_prfx); + if (rc) { + if (ini->smc_type_v2 == SMC_TYPE_N) { + kfree(pclc); + return SMC_CLC_DECL_CNFERR; + } + pclc_base->hdr.typev1 = SMC_TYPE_N; + } else { + pclc_base->iparea_offset = htons(sizeof(*pclc_smcd)); + plen += sizeof(*pclc_prfx) + + pclc_prfx->ipv6_prefixes_cnt * + sizeof(ipv6_prfx[0]); + } + } + + /* build SMC Proposal CLC message */ + memcpy(pclc_base->hdr.eyecatcher, SMC_EYECATCHER, + sizeof(SMC_EYECATCHER)); + pclc_base->hdr.type = SMC_CLC_PROPOSAL; + if (smcr_indicated(ini->smc_type_v1)) { + /* add SMC-R specifics */ + memcpy(pclc_base->lcl.id_for_peer, local_systemid, + sizeof(local_systemid)); + memcpy(pclc_base->lcl.gid, ini->ib_gid, SMC_GID_SIZE); + memcpy(pclc_base->lcl.mac, &ini->ib_dev->mac[ini->ib_port - 1], + ETH_ALEN); + } + if (smcd_indicated(ini->smc_type_v1)) { + /* add SMC-D specifics */ + if (ini->ism_dev[0]) { + pclc_smcd->ism.gid = htonll(ini->ism_dev[0]->local_gid); + pclc_smcd->ism.chid = + htons(smc_ism_get_chid(ini->ism_dev[0])); + } + } + if (ini->smc_type_v2 == SMC_TYPE_N) { + pclc_smcd->v2_ext_offset = 0; + } else { + struct smc_clc_eid_entry *ueident; + u16 v2_ext_offset; + + v2_ext->hdr.flag.release = SMC_RELEASE; + v2_ext_offset = sizeof(*pclc_smcd) - + offsetofend(struct smc_clc_msg_smcd, v2_ext_offset); + if (ini->smc_type_v1 != SMC_TYPE_N) + v2_ext_offset += sizeof(*pclc_prfx) + + pclc_prfx->ipv6_prefixes_cnt * + sizeof(ipv6_prfx[0]); + pclc_smcd->v2_ext_offset = htons(v2_ext_offset); + plen += sizeof(*v2_ext); + + read_lock(&smc_clc_eid_table.lock); + v2_ext->hdr.eid_cnt = smc_clc_eid_table.ueid_cnt; + plen += smc_clc_eid_table.ueid_cnt * SMC_MAX_EID_LEN; + i = 0; + list_for_each_entry(ueident, &smc_clc_eid_table.list, list) { + memcpy(v2_ext->user_eids[i++], ueident->eid, + sizeof(ueident->eid)); + } + read_unlock(&smc_clc_eid_table.lock); + } + if (smcd_indicated(ini->smc_type_v2)) { + u8 *eid = NULL; + + v2_ext->hdr.flag.seid = smc_clc_eid_table.seid_enabled; + v2_ext->hdr.ism_gid_cnt = ini->ism_offered_cnt; + v2_ext->hdr.smcd_v2_ext_offset = htons(sizeof(*v2_ext) - + offsetofend(struct smc_clnt_opts_area_hdr, + smcd_v2_ext_offset) + + v2_ext->hdr.eid_cnt * SMC_MAX_EID_LEN); + smc_ism_get_system_eid(&eid); + if (eid && v2_ext->hdr.flag.seid) + 
memcpy(smcd_v2_ext->system_eid, eid, SMC_MAX_EID_LEN); + plen += sizeof(*smcd_v2_ext); + if (ini->ism_offered_cnt) { + for (i = 1; i <= ini->ism_offered_cnt; i++) { + gidchids[i - 1].gid = + htonll(ini->ism_dev[i]->local_gid); + gidchids[i - 1].chid = + htons(smc_ism_get_chid(ini->ism_dev[i])); + } + plen += ini->ism_offered_cnt * + sizeof(struct smc_clc_smcd_gid_chid); + } + } + if (smcr_indicated(ini->smc_type_v2)) + memcpy(v2_ext->roce, ini->smcrv2.ib_gid_v2, SMC_GID_SIZE); + + pclc_base->hdr.length = htons(plen); + memcpy(trl->eyecatcher, SMC_EYECATCHER, sizeof(SMC_EYECATCHER)); + + /* send SMC Proposal CLC message */ + memset(&msg, 0, sizeof(msg)); + i = 0; + vec[i].iov_base = pclc_base; + vec[i++].iov_len = sizeof(*pclc_base); + vec[i].iov_base = pclc_smcd; + vec[i++].iov_len = sizeof(*pclc_smcd); + if (ini->smc_type_v1 != SMC_TYPE_N) { + vec[i].iov_base = pclc_prfx; + vec[i++].iov_len = sizeof(*pclc_prfx); + if (pclc_prfx->ipv6_prefixes_cnt > 0) { + vec[i].iov_base = ipv6_prfx; + vec[i++].iov_len = pclc_prfx->ipv6_prefixes_cnt * + sizeof(ipv6_prfx[0]); + } + } + if (ini->smc_type_v2 != SMC_TYPE_N) { + vec[i].iov_base = v2_ext; + vec[i++].iov_len = sizeof(*v2_ext) + + (v2_ext->hdr.eid_cnt * SMC_MAX_EID_LEN); + if (smcd_indicated(ini->smc_type_v2)) { + vec[i].iov_base = smcd_v2_ext; + vec[i++].iov_len = sizeof(*smcd_v2_ext); + if (ini->ism_offered_cnt) { + vec[i].iov_base = gidchids; + vec[i++].iov_len = ini->ism_offered_cnt * + sizeof(struct smc_clc_smcd_gid_chid); + } + } + } + vec[i].iov_base = trl; + vec[i++].iov_len = sizeof(*trl); + /* due to the few bytes needed for clc-handshake this cannot block */ + len = kernel_sendmsg(smc->clcsock, &msg, vec, i, plen); + if (len < 0) { + smc->sk.sk_err = smc->clcsock->sk->sk_err; + reason_code = -smc->sk.sk_err; + } else if (len < ntohs(pclc_base->hdr.length)) { + reason_code = -ENETUNREACH; + smc->sk.sk_err = -reason_code; + } + + kfree(pclc); + return reason_code; +} + +/* build and send CLC CONFIRM / ACCEPT message */ +static int smc_clc_send_confirm_accept(struct smc_sock *smc, + struct smc_clc_msg_accept_confirm_v2 *clc_v2, + int first_contact, u8 version, + u8 *eid, struct smc_init_info *ini) +{ + struct smc_connection *conn = &smc->conn; + struct smc_clc_msg_accept_confirm *clc; + struct smc_clc_first_contact_ext fce; + struct smc_clc_fce_gid_ext gle; + struct smc_clc_msg_trail trl; + struct kvec vec[5]; + struct msghdr msg; + int i, len; + + /* send SMC Confirm CLC msg */ + clc = (struct smc_clc_msg_accept_confirm *)clc_v2; + clc->hdr.version = version; /* SMC version */ + if (first_contact) + clc->hdr.typev2 |= SMC_FIRST_CONTACT_MASK; + if (conn->lgr->is_smcd) { + /* SMC-D specific settings */ + memcpy(clc->hdr.eyecatcher, SMCD_EYECATCHER, + sizeof(SMCD_EYECATCHER)); + clc->hdr.typev1 = SMC_TYPE_D; + clc->d0.gid = conn->lgr->smcd->local_gid; + clc->d0.token = conn->rmb_desc->token; + clc->d0.dmbe_size = conn->rmbe_size_comp; + clc->d0.dmbe_idx = 0; + memcpy(&clc->d0.linkid, conn->lgr->id, SMC_LGR_ID_SIZE); + if (version == SMC_V1) { + clc->hdr.length = htons(SMCD_CLC_ACCEPT_CONFIRM_LEN); + } else { + clc_v2->d1.chid = + htons(smc_ism_get_chid(conn->lgr->smcd)); + if (eid && eid[0]) + memcpy(clc_v2->d1.eid, eid, SMC_MAX_EID_LEN); + len = SMCD_CLC_ACCEPT_CONFIRM_LEN_V2; + if (first_contact) + smc_clc_fill_fce(&fce, &len); + clc_v2->hdr.length = htons(len); + } + memcpy(trl.eyecatcher, SMCD_EYECATCHER, + sizeof(SMCD_EYECATCHER)); + } else { + struct smc_link *link = conn->lnk; + + /* SMC-R specific settings */ + 
memcpy(clc->hdr.eyecatcher, SMC_EYECATCHER, + sizeof(SMC_EYECATCHER)); + clc->hdr.typev1 = SMC_TYPE_R; + clc->hdr.length = htons(SMCR_CLC_ACCEPT_CONFIRM_LEN); + memcpy(clc->r0.lcl.id_for_peer, local_systemid, + sizeof(local_systemid)); + memcpy(&clc->r0.lcl.gid, link->gid, SMC_GID_SIZE); + memcpy(&clc->r0.lcl.mac, &link->smcibdev->mac[link->ibport - 1], + ETH_ALEN); + hton24(clc->r0.qpn, link->roce_qp->qp_num); + clc->r0.rmb_rkey = + htonl(conn->rmb_desc->mr[link->link_idx]->rkey); + clc->r0.rmbe_idx = 1; /* for now: 1 RMB = 1 RMBE */ + clc->r0.rmbe_alert_token = htonl(conn->alert_token_local); + switch (clc->hdr.type) { + case SMC_CLC_ACCEPT: + clc->r0.qp_mtu = link->path_mtu; + break; + case SMC_CLC_CONFIRM: + clc->r0.qp_mtu = min(link->path_mtu, link->peer_mtu); + break; + } + clc->r0.rmbe_size = conn->rmbe_size_comp; + clc->r0.rmb_dma_addr = conn->rmb_desc->is_vm ? + cpu_to_be64((uintptr_t)conn->rmb_desc->cpu_addr) : + cpu_to_be64((u64)sg_dma_address + (conn->rmb_desc->sgt[link->link_idx].sgl)); + hton24(clc->r0.psn, link->psn_initial); + if (version == SMC_V1) { + clc->hdr.length = htons(SMCR_CLC_ACCEPT_CONFIRM_LEN); + } else { + if (eid && eid[0]) + memcpy(clc_v2->r1.eid, eid, SMC_MAX_EID_LEN); + len = SMCR_CLC_ACCEPT_CONFIRM_LEN_V2; + if (first_contact) { + smc_clc_fill_fce(&fce, &len); + fce.v2_direct = !link->lgr->uses_gateway; + memset(&gle, 0, sizeof(gle)); + if (ini && clc->hdr.type == SMC_CLC_CONFIRM) { + gle.gid_cnt = ini->smcrv2.gidlist.len; + len += sizeof(gle); + len += gle.gid_cnt * sizeof(gle.gid[0]); + } else { + len += sizeof(gle.reserved); + } + } + clc_v2->hdr.length = htons(len); + } + memcpy(trl.eyecatcher, SMC_EYECATCHER, sizeof(SMC_EYECATCHER)); + } + + memset(&msg, 0, sizeof(msg)); + i = 0; + vec[i].iov_base = clc_v2; + if (version > SMC_V1) + vec[i++].iov_len = (clc->hdr.typev1 == SMC_TYPE_D ? + SMCD_CLC_ACCEPT_CONFIRM_LEN_V2 : + SMCR_CLC_ACCEPT_CONFIRM_LEN_V2) - + sizeof(trl); + else + vec[i++].iov_len = (clc->hdr.typev1 == SMC_TYPE_D ? 
+ SMCD_CLC_ACCEPT_CONFIRM_LEN : + SMCR_CLC_ACCEPT_CONFIRM_LEN) - + sizeof(trl); + if (version > SMC_V1 && first_contact) { + vec[i].iov_base = &fce; + vec[i++].iov_len = sizeof(fce); + if (!conn->lgr->is_smcd) { + if (clc->hdr.type == SMC_CLC_CONFIRM) { + vec[i].iov_base = &gle; + vec[i++].iov_len = sizeof(gle); + vec[i].iov_base = &ini->smcrv2.gidlist.list; + vec[i++].iov_len = gle.gid_cnt * + sizeof(gle.gid[0]); + } else { + vec[i].iov_base = &gle.reserved; + vec[i++].iov_len = sizeof(gle.reserved); + } + } + } + vec[i].iov_base = &trl; + vec[i++].iov_len = sizeof(trl); + return kernel_sendmsg(smc->clcsock, &msg, vec, 1, + ntohs(clc->hdr.length)); +} + +/* send CLC CONFIRM message across internal TCP socket */ +int smc_clc_send_confirm(struct smc_sock *smc, bool clnt_first_contact, + u8 version, u8 *eid, struct smc_init_info *ini) +{ + struct smc_clc_msg_accept_confirm_v2 cclc_v2; + int reason_code = 0; + int len; + + /* send SMC Confirm CLC msg */ + memset(&cclc_v2, 0, sizeof(cclc_v2)); + cclc_v2.hdr.type = SMC_CLC_CONFIRM; + len = smc_clc_send_confirm_accept(smc, &cclc_v2, clnt_first_contact, + version, eid, ini); + if (len < ntohs(cclc_v2.hdr.length)) { + if (len >= 0) { + reason_code = -ENETUNREACH; + smc->sk.sk_err = -reason_code; + } else { + smc->sk.sk_err = smc->clcsock->sk->sk_err; + reason_code = -smc->sk.sk_err; + } + } + return reason_code; +} + +/* send CLC ACCEPT message across internal TCP socket */ +int smc_clc_send_accept(struct smc_sock *new_smc, bool srv_first_contact, + u8 version, u8 *negotiated_eid) +{ + struct smc_clc_msg_accept_confirm_v2 aclc_v2; + int len; + + memset(&aclc_v2, 0, sizeof(aclc_v2)); + aclc_v2.hdr.type = SMC_CLC_ACCEPT; + len = smc_clc_send_confirm_accept(new_smc, &aclc_v2, srv_first_contact, + version, negotiated_eid, NULL); + if (len < ntohs(aclc_v2.hdr.length)) + len = len >= 0 ? -EPROTO : -new_smc->clcsock->sk->sk_err; + + return len > 0 ? 0 : len; +} + +void smc_clc_get_hostname(u8 **host) +{ + *host = &smc_hostname[0]; +} + +void __init smc_clc_init(void) +{ + struct new_utsname *u; + + memset(smc_hostname, _S, sizeof(smc_hostname)); /* ASCII blanks */ + u = utsname(); + memcpy(smc_hostname, u->nodename, + min_t(size_t, strlen(u->nodename), sizeof(smc_hostname))); + + INIT_LIST_HEAD(&smc_clc_eid_table.list); + rwlock_init(&smc_clc_eid_table.lock); + smc_clc_eid_table.ueid_cnt = 0; + smc_clc_eid_table.seid_enabled = 1; +} + +void smc_clc_exit(void) +{ + smc_clc_ueid_remove(NULL); +} diff --git a/net/smc/smc_clc.h b/net/smc/smc_clc.h new file mode 100644 index 000000000..5fee545c9 --- /dev/null +++ b/net/smc/smc_clc.h @@ -0,0 +1,401 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * CLC (connection layer control) handshake over initial TCP socket to + * prepare for RDMA traffic + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#ifndef _SMC_CLC_H +#define _SMC_CLC_H + +#include +#include + +#include "smc.h" +#include "smc_netlink.h" + +#define SMC_CLC_PROPOSAL 0x01 +#define SMC_CLC_ACCEPT 0x02 +#define SMC_CLC_CONFIRM 0x03 +#define SMC_CLC_DECLINE 0x04 + +#define SMC_TYPE_R 0 /* SMC-R only */ +#define SMC_TYPE_D 1 /* SMC-D only */ +#define SMC_TYPE_N 2 /* neither SMC-R nor SMC-D */ +#define SMC_TYPE_B 3 /* SMC-R and SMC-D */ +#define CLC_WAIT_TIME (6 * HZ) /* max. 
wait time on clcsock */ +#define CLC_WAIT_TIME_SHORT HZ /* short wait time on clcsock */ +#define SMC_CLC_DECL_MEM 0x01010000 /* insufficient memory resources */ +#define SMC_CLC_DECL_TIMEOUT_CL 0x02010000 /* timeout w4 QP confirm link */ +#define SMC_CLC_DECL_TIMEOUT_AL 0x02020000 /* timeout w4 QP add link */ +#define SMC_CLC_DECL_CNFERR 0x03000000 /* configuration error */ +#define SMC_CLC_DECL_PEERNOSMC 0x03010000 /* peer did not indicate SMC */ +#define SMC_CLC_DECL_IPSEC 0x03020000 /* IPsec usage */ +#define SMC_CLC_DECL_NOSMCDEV 0x03030000 /* no SMC device found (R or D) */ +#define SMC_CLC_DECL_NOSMCDDEV 0x03030001 /* no SMC-D device found */ +#define SMC_CLC_DECL_NOSMCRDEV 0x03030002 /* no SMC-R device found */ +#define SMC_CLC_DECL_NOISM2SUPP 0x03030003 /* hardware has no ISMv2 support */ +#define SMC_CLC_DECL_NOV2EXT 0x03030004 /* peer sent no clc v2 extension */ +#define SMC_CLC_DECL_NOV2DEXT 0x03030005 /* peer sent no clc SMC-Dv2 ext. */ +#define SMC_CLC_DECL_NOSEID 0x03030006 /* peer sent no SEID */ +#define SMC_CLC_DECL_NOSMCD2DEV 0x03030007 /* no SMC-Dv2 device found */ +#define SMC_CLC_DECL_NOUEID 0x03030008 /* peer sent no UEID */ +#define SMC_CLC_DECL_MODEUNSUPP 0x03040000 /* smc modes do not match (R or D)*/ +#define SMC_CLC_DECL_RMBE_EC 0x03050000 /* peer has eyecatcher in RMBE */ +#define SMC_CLC_DECL_OPTUNSUPP 0x03060000 /* fastopen sockopt not supported */ +#define SMC_CLC_DECL_DIFFPREFIX 0x03070000 /* IP prefix / subnet mismatch */ +#define SMC_CLC_DECL_GETVLANERR 0x03080000 /* err to get vlan id of ip device*/ +#define SMC_CLC_DECL_ISMVLANERR 0x03090000 /* err to reg vlan id on ism dev */ +#define SMC_CLC_DECL_NOACTLINK 0x030a0000 /* no active smc-r link in lgr */ +#define SMC_CLC_DECL_NOSRVLINK 0x030b0000 /* SMC-R link from srv not found */ +#define SMC_CLC_DECL_VERSMISMAT 0x030c0000 /* SMC version mismatch */ +#define SMC_CLC_DECL_MAX_DMB 0x030d0000 /* SMC-D DMB limit exceeded */ +#define SMC_CLC_DECL_NOROUTE 0x030e0000 /* SMC-Rv2 conn. no route to peer */ +#define SMC_CLC_DECL_NOINDIRECT 0x030f0000 /* SMC-Rv2 conn. indirect mismatch*/ +#define SMC_CLC_DECL_SYNCERR 0x04000000 /* synchronization error */ +#define SMC_CLC_DECL_PEERDECL 0x05000000 /* peer declined during handshake */ +#define SMC_CLC_DECL_INTERR 0x09990000 /* internal error */ +#define SMC_CLC_DECL_ERR_RTOK 0x09990001 /* rtoken handling failed */ +#define SMC_CLC_DECL_ERR_RDYLNK 0x09990002 /* ib ready link failed */ +#define SMC_CLC_DECL_ERR_REGBUF 0x09990003 /* reg rdma bufs failed */ + +#define SMC_FIRST_CONTACT_MASK 0b10 /* first contact bit within typev2 */ + +struct smc_clc_msg_hdr { /* header1 of clc messages */ + u8 eyecatcher[4]; /* eye catcher */ + u8 type; /* proposal / accept / confirm / decline */ + __be16 length; +#if defined(__BIG_ENDIAN_BITFIELD) + u8 version : 4, + typev2 : 2, + typev1 : 2; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 typev1 : 2, + typev2 : 2, + version : 4; +#endif +} __packed; /* format defined in RFC7609 */ + +struct smc_clc_msg_trail { /* trailer of clc messages */ + u8 eyecatcher[4]; +}; + +struct smc_clc_msg_local { /* header2 of clc messages */ + u8 id_for_peer[SMC_SYSTEMID_LEN]; /* unique system id */ + u8 gid[16]; /* gid of ib_device port */ + u8 mac[6]; /* mac of ib_device port */ +}; + +/* Struct would be 4 byte aligned, but it is used in an array that is sent + * to peers and must conform to RFC7609, hence we need to use packed here. 
+ */ +struct smc_clc_ipv6_prefix { + struct in6_addr prefix; + u8 prefix_len; +} __packed; /* format defined in RFC7609 */ + +#if defined(__BIG_ENDIAN_BITFIELD) +struct smc_clc_v2_flag { + u8 release : 4, + rsvd : 3, + seid : 1; +}; +#elif defined(__LITTLE_ENDIAN_BITFIELD) +struct smc_clc_v2_flag { + u8 seid : 1, + rsvd : 3, + release : 4; +}; +#endif + +struct smc_clnt_opts_area_hdr { + u8 eid_cnt; /* number of user defined EIDs */ + u8 ism_gid_cnt; /* number of ISMv2 GIDs */ + u8 reserved1; + struct smc_clc_v2_flag flag; + u8 reserved2[2]; + __be16 smcd_v2_ext_offset; /* SMC-Dv2 Extension Offset */ +}; + +struct smc_clc_smcd_gid_chid { + __be64 gid; /* ISM GID */ + __be16 chid; /* ISMv2 CHID */ +} __packed; /* format defined in + * IBM Shared Memory Communications Version 2 + * (https://www.ibm.com/support/pages/node/6326337) + */ + +struct smc_clc_v2_extension { + struct smc_clnt_opts_area_hdr hdr; + u8 roce[16]; /* RoCEv2 GID */ + u8 reserved[16]; + u8 user_eids[][SMC_MAX_EID_LEN]; +}; + +struct smc_clc_msg_proposal_prefix { /* prefix part of clc proposal message*/ + __be32 outgoing_subnet; /* subnet mask */ + u8 prefix_len; /* number of significant bits in mask */ + u8 reserved[2]; + u8 ipv6_prefixes_cnt; /* number of IPv6 prefixes in prefix array */ +} __aligned(4); + +struct smc_clc_msg_smcd { /* SMC-D GID information */ + struct smc_clc_smcd_gid_chid ism; /* ISM native GID+CHID of requestor */ + __be16 v2_ext_offset; /* SMC Version 2 Extension Offset */ + u8 reserved[28]; +}; + +struct smc_clc_smcd_v2_extension { + u8 system_eid[SMC_MAX_EID_LEN]; + u8 reserved[16]; + struct smc_clc_smcd_gid_chid gidchid[]; +}; + +struct smc_clc_msg_proposal { /* clc proposal message sent by Linux */ + struct smc_clc_msg_hdr hdr; + struct smc_clc_msg_local lcl; + __be16 iparea_offset; /* offset to IP address information area */ +} __aligned(4); + +#define SMC_CLC_MAX_V6_PREFIX 8 +#define SMC_CLC_MAX_UEID 8 + +struct smc_clc_msg_proposal_area { + struct smc_clc_msg_proposal pclc_base; + struct smc_clc_msg_smcd pclc_smcd; + struct smc_clc_msg_proposal_prefix pclc_prfx; + struct smc_clc_ipv6_prefix pclc_prfx_ipv6[SMC_CLC_MAX_V6_PREFIX]; + struct smc_clc_v2_extension pclc_v2_ext; + u8 user_eids[SMC_CLC_MAX_UEID][SMC_MAX_EID_LEN]; + struct smc_clc_smcd_v2_extension pclc_smcd_v2_ext; + struct smc_clc_smcd_gid_chid pclc_gidchids[SMC_MAX_ISM_DEVS]; + struct smc_clc_msg_trail pclc_trl; +}; + +struct smcr_clc_msg_accept_confirm { /* SMCR accept/confirm */ + struct smc_clc_msg_local lcl; + u8 qpn[3]; /* QP number */ + __be32 rmb_rkey; /* RMB rkey */ + u8 rmbe_idx; /* Index of RMBE in RMB */ + __be32 rmbe_alert_token; /* unique connection id */ + #if defined(__BIG_ENDIAN_BITFIELD) + u8 rmbe_size : 4, /* buf size (compressed) */ + qp_mtu : 4; /* QP mtu */ +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 qp_mtu : 4, + rmbe_size : 4; +#endif + u8 reserved; + __be64 rmb_dma_addr; /* RMB virtual address */ + u8 reserved2; + u8 psn[3]; /* packet sequence number */ +} __packed; + +struct smcd_clc_msg_accept_confirm_common { /* SMCD accept/confirm */ + u64 gid; /* Sender GID */ + u64 token; /* DMB token */ + u8 dmbe_idx; /* DMBE index */ +#if defined(__BIG_ENDIAN_BITFIELD) + u8 dmbe_size : 4, /* buf size (compressed) */ + reserved3 : 4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved3 : 4, + dmbe_size : 4; +#endif + u16 reserved4; + __be32 linkid; /* Link identifier */ +} __packed; + +#define SMC_CLC_OS_ZOS 1 +#define SMC_CLC_OS_LINUX 2 +#define SMC_CLC_OS_AIX 3 + +struct smc_clc_first_contact_ext { +#if 
defined(__BIG_ENDIAN_BITFIELD) + u8 v2_direct : 1, + reserved : 7; + u8 os_type : 4, + release : 4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 7, + v2_direct : 1; + u8 release : 4, + os_type : 4; +#endif + u8 reserved2[2]; + u8 hostname[SMC_MAX_HOSTNAME_LEN]; +}; + +struct smc_clc_fce_gid_ext { + u8 reserved[16]; + u8 gid_cnt; + u8 reserved2[3]; + u8 gid[][SMC_GID_SIZE]; +}; + +struct smc_clc_msg_accept_confirm { /* clc accept / confirm message */ + struct smc_clc_msg_hdr hdr; + union { + struct smcr_clc_msg_accept_confirm r0; /* SMC-R */ + struct { /* SMC-D */ + struct smcd_clc_msg_accept_confirm_common d0; + u32 reserved5[3]; + }; + }; +} __packed; /* format defined in RFC7609 */ + +struct smc_clc_msg_accept_confirm_v2 { /* clc accept / confirm message */ + struct smc_clc_msg_hdr hdr; + union { + struct { /* SMC-R */ + struct smcr_clc_msg_accept_confirm r0; + u8 eid[SMC_MAX_EID_LEN]; + u8 reserved6[8]; + } r1; + struct { /* SMC-D */ + struct smcd_clc_msg_accept_confirm_common d0; + __be16 chid; + u8 eid[SMC_MAX_EID_LEN]; + u8 reserved5[8]; + } d1; + }; +}; + +struct smc_clc_msg_decline { /* clc decline message */ + struct smc_clc_msg_hdr hdr; + u8 id_for_peer[SMC_SYSTEMID_LEN]; /* sender peer_id */ + __be32 peer_diagnosis; /* diagnosis information */ +#if defined(__BIG_ENDIAN_BITFIELD) + u8 os_type : 4, + reserved : 4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 4, + os_type : 4; +#endif + u8 reserved2[3]; + struct smc_clc_msg_trail trl; /* eye catcher "SMCD" or "SMCR" EBCDIC */ +} __aligned(4); + +#define SMC_DECL_DIAG_COUNT_V2 4 /* no. of additional peer diagnosis codes */ + +struct smc_clc_msg_decline_v2 { /* clc decline message */ + struct smc_clc_msg_hdr hdr; + u8 id_for_peer[SMC_SYSTEMID_LEN]; /* sender peer_id */ + __be32 peer_diagnosis; /* diagnosis information */ +#if defined(__BIG_ENDIAN_BITFIELD) + u8 os_type : 4, + reserved : 4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 4, + os_type : 4; +#endif + u8 reserved2[3]; + __be32 peer_diagnosis_v2[SMC_DECL_DIAG_COUNT_V2]; + struct smc_clc_msg_trail trl; /* eye catcher "SMCD" or "SMCR" EBCDIC */ +} __aligned(4); + +/* determine start of the prefix area within the proposal message */ +static inline struct smc_clc_msg_proposal_prefix * +smc_clc_proposal_get_prefix(struct smc_clc_msg_proposal *pclc) +{ + return (struct smc_clc_msg_proposal_prefix *) + ((u8 *)pclc + sizeof(*pclc) + ntohs(pclc->iparea_offset)); +} + +static inline bool smcr_indicated(int smc_type) +{ + return smc_type == SMC_TYPE_R || smc_type == SMC_TYPE_B; +} + +static inline bool smcd_indicated(int smc_type) +{ + return smc_type == SMC_TYPE_D || smc_type == SMC_TYPE_B; +} + +static inline u8 smc_indicated_type(int is_smcd, int is_smcr) +{ + if (is_smcd && is_smcr) + return SMC_TYPE_B; + if (is_smcd) + return SMC_TYPE_D; + if (is_smcr) + return SMC_TYPE_R; + return SMC_TYPE_N; +} + +/* get SMC-D info from proposal message */ +static inline struct smc_clc_msg_smcd * +smc_get_clc_msg_smcd(struct smc_clc_msg_proposal *prop) +{ + if (smcd_indicated(prop->hdr.typev1) && + ntohs(prop->iparea_offset) != sizeof(struct smc_clc_msg_smcd)) + return NULL; + + return (struct smc_clc_msg_smcd *)(prop + 1); +} + +static inline struct smc_clc_v2_extension * +smc_get_clc_v2_ext(struct smc_clc_msg_proposal *prop) +{ + struct smc_clc_msg_smcd *prop_smcd = smc_get_clc_msg_smcd(prop); + + if (!prop_smcd || !ntohs(prop_smcd->v2_ext_offset)) + return NULL; + + return (struct smc_clc_v2_extension *) + ((u8 *)prop_smcd + + offsetof(struct 
smc_clc_msg_smcd, v2_ext_offset) + + sizeof(prop_smcd->v2_ext_offset) + + ntohs(prop_smcd->v2_ext_offset)); +} + +static inline struct smc_clc_smcd_v2_extension * +smc_get_clc_smcd_v2_ext(struct smc_clc_v2_extension *prop_v2ext) +{ + if (!prop_v2ext) + return NULL; + if (!ntohs(prop_v2ext->hdr.smcd_v2_ext_offset)) + return NULL; + + return (struct smc_clc_smcd_v2_extension *) + ((u8 *)prop_v2ext + + offsetof(struct smc_clc_v2_extension, hdr) + + offsetof(struct smc_clnt_opts_area_hdr, smcd_v2_ext_offset) + + sizeof(prop_v2ext->hdr.smcd_v2_ext_offset) + + ntohs(prop_v2ext->hdr.smcd_v2_ext_offset)); +} + +struct smcd_dev; +struct smc_init_info; + +int smc_clc_prfx_match(struct socket *clcsock, + struct smc_clc_msg_proposal_prefix *prop); +int smc_clc_wait_msg(struct smc_sock *smc, void *buf, int buflen, + u8 expected_type, unsigned long timeout); +int smc_clc_send_decline(struct smc_sock *smc, u32 peer_diag_info, u8 version); +int smc_clc_send_proposal(struct smc_sock *smc, struct smc_init_info *ini); +int smc_clc_send_confirm(struct smc_sock *smc, bool clnt_first_contact, + u8 version, u8 *eid, struct smc_init_info *ini); +int smc_clc_send_accept(struct smc_sock *smc, bool srv_first_contact, + u8 version, u8 *negotiated_eid); +void smc_clc_init(void) __init; +void smc_clc_exit(void); +void smc_clc_get_hostname(u8 **host); +bool smc_clc_match_eid(u8 *negotiated_eid, + struct smc_clc_v2_extension *smc_v2_ext, + u8 *peer_eid, u8 *local_eid); +int smc_clc_ueid_count(void); +int smc_nl_dump_ueid(struct sk_buff *skb, struct netlink_callback *cb); +int smc_nl_add_ueid(struct sk_buff *skb, struct genl_info *info); +int smc_nl_remove_ueid(struct sk_buff *skb, struct genl_info *info); +int smc_nl_flush_ueid(struct sk_buff *skb, struct genl_info *info); +int smc_nl_dump_seid(struct sk_buff *skb, struct netlink_callback *cb); +int smc_nl_enable_seid(struct sk_buff *skb, struct genl_info *info); +int smc_nl_disable_seid(struct sk_buff *skb, struct genl_info *info); + +#endif diff --git a/net/smc/smc_close.c b/net/smc/smc_close.c new file mode 100644 index 000000000..10219f55a --- /dev/null +++ b/net/smc/smc_close.c @@ -0,0 +1,506 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Socket Closing - normal and abnormal + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Ursula Braun + */ + +#include +#include + +#include +#include + +#include "smc.h" +#include "smc_tx.h" +#include "smc_cdc.h" +#include "smc_close.h" + +/* release the clcsock that is assigned to the smc_sock */ +void smc_clcsock_release(struct smc_sock *smc) +{ + struct socket *tcp; + + if (smc->listen_smc && current_work() != &smc->smc_listen_work) + cancel_work_sync(&smc->smc_listen_work); + mutex_lock(&smc->clcsock_release_lock); + if (smc->clcsock) { + tcp = smc->clcsock; + smc->clcsock = NULL; + sock_release(tcp); + } + mutex_unlock(&smc->clcsock_release_lock); +} + +static void smc_close_cleanup_listen(struct sock *parent) +{ + struct sock *sk; + + /* Close non-accepted connections */ + while ((sk = smc_accept_dequeue(parent, NULL))) + smc_close_non_accepted(sk); +} + +/* wait for sndbuf data being transmitted */ +static void smc_close_stream_wait(struct smc_sock *smc, long timeout) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = &smc->sk; + + if (!timeout) + return; + + if (!smc_tx_prepared_sends(&smc->conn)) + return; + + /* Send out corked data remaining in sndbuf */ + smc_tx_pending(&smc->conn); + + smc->wait_close_tx_prepared = 1; + add_wait_queue(sk_sleep(sk), &wait); + while (!signal_pending(current) && timeout) { + int rc; + + rc = sk_wait_event(sk, &timeout, + !smc_tx_prepared_sends(&smc->conn) || + READ_ONCE(sk->sk_err) == ECONNABORTED || + READ_ONCE(sk->sk_err) == ECONNRESET || + smc->conn.killed, + &wait); + if (rc) + break; + } + remove_wait_queue(sk_sleep(sk), &wait); + smc->wait_close_tx_prepared = 0; +} + +void smc_close_wake_tx_prepared(struct smc_sock *smc) +{ + if (smc->wait_close_tx_prepared) + /* wake up socket closing */ + smc->sk.sk_state_change(&smc->sk); +} + +static int smc_close_wr(struct smc_connection *conn) +{ + conn->local_tx_ctrl.conn_state_flags.peer_done_writing = 1; + + return smc_cdc_get_slot_and_msg_send(conn); +} + +static int smc_close_final(struct smc_connection *conn) +{ + if (atomic_read(&conn->bytes_to_rcv)) + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + else + conn->local_tx_ctrl.conn_state_flags.peer_conn_closed = 1; + if (conn->killed) + return -EPIPE; + + return smc_cdc_get_slot_and_msg_send(conn); +} + +int smc_close_abort(struct smc_connection *conn) +{ + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + + return smc_cdc_get_slot_and_msg_send(conn); +} + +static void smc_close_cancel_work(struct smc_sock *smc) +{ + struct sock *sk = &smc->sk; + + release_sock(sk); + if (cancel_work_sync(&smc->conn.close_work)) + sock_put(sk); + cancel_delayed_work_sync(&smc->conn.tx_work); + lock_sock(sk); +} + +/* terminate smc socket abnormally - active abort + * link group is terminated, i.e. 
RDMA communication no longer possible + */ +void smc_close_active_abort(struct smc_sock *smc) +{ + struct sock *sk = &smc->sk; + bool release_clcsock = false; + + if (sk->sk_state != SMC_INIT && smc->clcsock && smc->clcsock->sk) { + sk->sk_err = ECONNABORTED; + if (smc->clcsock && smc->clcsock->sk) + tcp_abort(smc->clcsock->sk, ECONNABORTED); + } + switch (sk->sk_state) { + case SMC_ACTIVE: + case SMC_APPCLOSEWAIT1: + case SMC_APPCLOSEWAIT2: + sk->sk_state = SMC_PEERABORTWAIT; + smc_close_cancel_work(smc); + if (sk->sk_state != SMC_PEERABORTWAIT) + break; + sk->sk_state = SMC_CLOSED; + sock_put(sk); /* (postponed) passive closing */ + break; + case SMC_PEERCLOSEWAIT1: + case SMC_PEERCLOSEWAIT2: + case SMC_PEERFINCLOSEWAIT: + sk->sk_state = SMC_PEERABORTWAIT; + smc_close_cancel_work(smc); + if (sk->sk_state != SMC_PEERABORTWAIT) + break; + sk->sk_state = SMC_CLOSED; + smc_conn_free(&smc->conn); + release_clcsock = true; + sock_put(sk); /* passive closing */ + break; + case SMC_PROCESSABORT: + case SMC_APPFINCLOSEWAIT: + sk->sk_state = SMC_PEERABORTWAIT; + smc_close_cancel_work(smc); + if (sk->sk_state != SMC_PEERABORTWAIT) + break; + sk->sk_state = SMC_CLOSED; + smc_conn_free(&smc->conn); + release_clcsock = true; + break; + case SMC_INIT: + case SMC_PEERABORTWAIT: + case SMC_CLOSED: + break; + } + + smc_sock_set_flag(sk, SOCK_DEAD); + sk->sk_state_change(sk); + + if (release_clcsock) { + release_sock(sk); + smc_clcsock_release(smc); + lock_sock(sk); + } +} + +static inline bool smc_close_sent_any_close(struct smc_connection *conn) +{ + return conn->local_tx_ctrl.conn_state_flags.peer_conn_abort || + conn->local_tx_ctrl.conn_state_flags.peer_conn_closed; +} + +int smc_close_active(struct smc_sock *smc) +{ + struct smc_cdc_conn_state_flags *txflags = + &smc->conn.local_tx_ctrl.conn_state_flags; + struct smc_connection *conn = &smc->conn; + struct sock *sk = &smc->sk; + int old_state; + long timeout; + int rc = 0; + int rc1 = 0; + + timeout = current->flags & PF_EXITING ? + 0 : sock_flag(sk, SOCK_LINGER) ? + sk->sk_lingertime : SMC_MAX_STREAM_WAIT_TIMEOUT; + + old_state = sk->sk_state; +again: + switch (sk->sk_state) { + case SMC_INIT: + sk->sk_state = SMC_CLOSED; + break; + case SMC_LISTEN: + sk->sk_state = SMC_CLOSED; + sk->sk_state_change(sk); /* wake up accept */ + if (smc->clcsock && smc->clcsock->sk) { + write_lock_bh(&smc->clcsock->sk->sk_callback_lock); + smc_clcsock_restore_cb(&smc->clcsock->sk->sk_data_ready, + &smc->clcsk_data_ready); + smc->clcsock->sk->sk_user_data = NULL; + write_unlock_bh(&smc->clcsock->sk->sk_callback_lock); + rc = kernel_sock_shutdown(smc->clcsock, SHUT_RDWR); + } + smc_close_cleanup_listen(sk); + release_sock(sk); + flush_work(&smc->tcp_listen_work); + lock_sock(sk); + break; + case SMC_ACTIVE: + smc_close_stream_wait(smc, timeout); + release_sock(sk); + cancel_delayed_work_sync(&conn->tx_work); + lock_sock(sk); + if (sk->sk_state == SMC_ACTIVE) { + /* send close request */ + rc = smc_close_final(conn); + sk->sk_state = SMC_PEERCLOSEWAIT1; + + /* actively shutdown clcsock before peer close it, + * prevent peer from entering TIME_WAIT state. + */ + if (smc->clcsock && smc->clcsock->sk) { + rc1 = kernel_sock_shutdown(smc->clcsock, + SHUT_RDWR); + rc = rc ? 
rc : rc1; + } + } else { + /* peer event has changed the state */ + goto again; + } + break; + case SMC_APPFINCLOSEWAIT: + /* socket already shutdown wr or both (active close) */ + if (txflags->peer_done_writing && + !smc_close_sent_any_close(conn)) { + /* just shutdown wr done, send close request */ + rc = smc_close_final(conn); + } + sk->sk_state = SMC_CLOSED; + break; + case SMC_APPCLOSEWAIT1: + case SMC_APPCLOSEWAIT2: + if (!smc_cdc_rxed_any_close(conn)) + smc_close_stream_wait(smc, timeout); + release_sock(sk); + cancel_delayed_work_sync(&conn->tx_work); + lock_sock(sk); + if (sk->sk_state != SMC_APPCLOSEWAIT1 && + sk->sk_state != SMC_APPCLOSEWAIT2) + goto again; + /* confirm close from peer */ + rc = smc_close_final(conn); + if (smc_cdc_rxed_any_close(conn)) { + /* peer has closed the socket already */ + sk->sk_state = SMC_CLOSED; + sock_put(sk); /* postponed passive closing */ + } else { + /* peer has just issued a shutdown write */ + sk->sk_state = SMC_PEERFINCLOSEWAIT; + } + break; + case SMC_PEERCLOSEWAIT1: + case SMC_PEERCLOSEWAIT2: + if (txflags->peer_done_writing && + !smc_close_sent_any_close(conn)) { + /* just shutdown wr done, send close request */ + rc = smc_close_final(conn); + } + /* peer sending PeerConnectionClosed will cause transition */ + break; + case SMC_PEERFINCLOSEWAIT: + /* peer sending PeerConnectionClosed will cause transition */ + break; + case SMC_PROCESSABORT: + rc = smc_close_abort(conn); + sk->sk_state = SMC_CLOSED; + break; + case SMC_PEERABORTWAIT: + sk->sk_state = SMC_CLOSED; + break; + case SMC_CLOSED: + /* nothing to do, add tracing in future patch */ + break; + } + + if (old_state != sk->sk_state) + sk->sk_state_change(sk); + return rc; +} + +static void smc_close_passive_abort_received(struct smc_sock *smc) +{ + struct smc_cdc_conn_state_flags *txflags = + &smc->conn.local_tx_ctrl.conn_state_flags; + struct sock *sk = &smc->sk; + + switch (sk->sk_state) { + case SMC_INIT: + case SMC_ACTIVE: + case SMC_APPCLOSEWAIT1: + sk->sk_state = SMC_PROCESSABORT; + sock_put(sk); /* passive closing */ + break; + case SMC_APPFINCLOSEWAIT: + sk->sk_state = SMC_PROCESSABORT; + break; + case SMC_PEERCLOSEWAIT1: + case SMC_PEERCLOSEWAIT2: + if (txflags->peer_done_writing && + !smc_close_sent_any_close(&smc->conn)) + /* just shutdown, but not yet closed locally */ + sk->sk_state = SMC_PROCESSABORT; + else + sk->sk_state = SMC_CLOSED; + sock_put(sk); /* passive closing */ + break; + case SMC_APPCLOSEWAIT2: + case SMC_PEERFINCLOSEWAIT: + sk->sk_state = SMC_CLOSED; + sock_put(sk); /* passive closing */ + break; + case SMC_PEERABORTWAIT: + sk->sk_state = SMC_CLOSED; + break; + case SMC_PROCESSABORT: + /* nothing to do, add tracing in future patch */ + break; + } +} + +/* Either some kind of closing has been received: peer_conn_closed, + * peer_conn_abort, or peer_done_writing + * or the link group of the connection terminates abnormally. 
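+ * The handler runs as the connection's close_work (set up in smc_close_init()) and drives the passive side of the close state machine.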
+ */ +static void smc_close_passive_work(struct work_struct *work) +{ + struct smc_connection *conn = container_of(work, + struct smc_connection, + close_work); + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + struct smc_cdc_conn_state_flags *rxflags; + bool release_clcsock = false; + struct sock *sk = &smc->sk; + int old_state; + + lock_sock(sk); + old_state = sk->sk_state; + + rxflags = &conn->local_rx_ctrl.conn_state_flags; + if (rxflags->peer_conn_abort) { + /* peer has not received all data */ + smc_close_passive_abort_received(smc); + release_sock(sk); + cancel_delayed_work_sync(&conn->tx_work); + lock_sock(sk); + goto wakeup; + } + + switch (sk->sk_state) { + case SMC_INIT: + sk->sk_state = SMC_APPCLOSEWAIT1; + break; + case SMC_ACTIVE: + sk->sk_state = SMC_APPCLOSEWAIT1; + /* postpone sock_put() for passive closing to cover + * received SEND_SHUTDOWN as well + */ + break; + case SMC_PEERCLOSEWAIT1: + if (rxflags->peer_done_writing) + sk->sk_state = SMC_PEERCLOSEWAIT2; + fallthrough; + /* to check for closing */ + case SMC_PEERCLOSEWAIT2: + if (!smc_cdc_rxed_any_close(conn)) + break; + if (sock_flag(sk, SOCK_DEAD) && + smc_close_sent_any_close(conn)) { + /* smc_release has already been called locally */ + sk->sk_state = SMC_CLOSED; + } else { + /* just shutdown, but not yet closed locally */ + sk->sk_state = SMC_APPFINCLOSEWAIT; + } + sock_put(sk); /* passive closing */ + break; + case SMC_PEERFINCLOSEWAIT: + if (smc_cdc_rxed_any_close(conn)) { + sk->sk_state = SMC_CLOSED; + sock_put(sk); /* passive closing */ + } + break; + case SMC_APPCLOSEWAIT1: + case SMC_APPCLOSEWAIT2: + /* postpone sock_put() for passive closing to cover + * received SEND_SHUTDOWN as well + */ + break; + case SMC_APPFINCLOSEWAIT: + case SMC_PEERABORTWAIT: + case SMC_PROCESSABORT: + case SMC_CLOSED: + /* nothing to do, add tracing in future patch */ + break; + } + +wakeup: + sk->sk_data_ready(sk); /* wakeup blocked rcvbuf consumers */ + sk->sk_write_space(sk); /* wakeup blocked sndbuf producers */ + + if (old_state != sk->sk_state) { + sk->sk_state_change(sk); + if ((sk->sk_state == SMC_CLOSED) && + (sock_flag(sk, SOCK_DEAD) || !sk->sk_socket)) { + smc_conn_free(conn); + if (smc->clcsock) + release_clcsock = true; + } + } + release_sock(sk); + if (release_clcsock) + smc_clcsock_release(smc); + sock_put(sk); /* sock_hold done by schedulers of close_work */ +} + +int smc_close_shutdown_write(struct smc_sock *smc) +{ + struct smc_connection *conn = &smc->conn; + struct sock *sk = &smc->sk; + int old_state; + long timeout; + int rc = 0; + + timeout = current->flags & PF_EXITING ? + 0 : sock_flag(sk, SOCK_LINGER) ? 
+ sk->sk_lingertime : SMC_MAX_STREAM_WAIT_TIMEOUT; + + old_state = sk->sk_state; +again: + switch (sk->sk_state) { + case SMC_ACTIVE: + smc_close_stream_wait(smc, timeout); + release_sock(sk); + cancel_delayed_work_sync(&conn->tx_work); + lock_sock(sk); + if (sk->sk_state != SMC_ACTIVE) + goto again; + /* send close wr request */ + rc = smc_close_wr(conn); + sk->sk_state = SMC_PEERCLOSEWAIT1; + break; + case SMC_APPCLOSEWAIT1: + /* passive close */ + if (!smc_cdc_rxed_any_close(conn)) + smc_close_stream_wait(smc, timeout); + release_sock(sk); + cancel_delayed_work_sync(&conn->tx_work); + lock_sock(sk); + if (sk->sk_state != SMC_APPCLOSEWAIT1) + goto again; + /* confirm close from peer */ + rc = smc_close_wr(conn); + sk->sk_state = SMC_APPCLOSEWAIT2; + break; + case SMC_APPCLOSEWAIT2: + case SMC_PEERFINCLOSEWAIT: + case SMC_PEERCLOSEWAIT1: + case SMC_PEERCLOSEWAIT2: + case SMC_APPFINCLOSEWAIT: + case SMC_PROCESSABORT: + case SMC_PEERABORTWAIT: + /* nothing to do, add tracing in future patch */ + break; + } + + if (old_state != sk->sk_state) + sk->sk_state_change(sk); + return rc; +} + +/* Initialize close properties on connection establishment. */ +void smc_close_init(struct smc_sock *smc) +{ + INIT_WORK(&smc->conn.close_work, smc_close_passive_work); +} diff --git a/net/smc/smc_close.h b/net/smc/smc_close.h new file mode 100644 index 000000000..634fea2b7 --- /dev/null +++ b/net/smc/smc_close.h @@ -0,0 +1,30 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Socket Closing + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#ifndef SMC_CLOSE_H +#define SMC_CLOSE_H + +#include + +#include "smc.h" + +#define SMC_MAX_STREAM_WAIT_TIMEOUT (2 * HZ) +#define SMC_CLOSE_SOCK_PUT_DELAY HZ + +void smc_close_wake_tx_prepared(struct smc_sock *smc); +int smc_close_active(struct smc_sock *smc); +int smc_close_shutdown_write(struct smc_sock *smc); +void smc_close_init(struct smc_sock *smc); +void smc_clcsock_release(struct smc_sock *smc); +int smc_close_abort(struct smc_connection *conn); +void smc_close_active_abort(struct smc_sock *smc); + +#endif /* SMC_CLOSE_H */ diff --git a/net/smc/smc_core.c b/net/smc/smc_core.c new file mode 100644 index 000000000..64b6dd439 --- /dev/null +++ b/net/smc/smc_core.c @@ -0,0 +1,2617 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Basic Transport Functions exploiting Infiniband API + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "smc.h" +#include "smc_clc.h" +#include "smc_core.h" +#include "smc_ib.h" +#include "smc_wr.h" +#include "smc_llc.h" +#include "smc_cdc.h" +#include "smc_close.h" +#include "smc_ism.h" +#include "smc_netlink.h" +#include "smc_stats.h" +#include "smc_tracepoint.h" + +#define SMC_LGR_NUM_INCR 256 +#define SMC_LGR_FREE_DELAY_SERV (600 * HZ) +#define SMC_LGR_FREE_DELAY_CLNT (SMC_LGR_FREE_DELAY_SERV + 10 * HZ) + +struct smc_lgr_list smc_lgr_list = { /* established link groups */ + .lock = __SPIN_LOCK_UNLOCKED(smc_lgr_list.lock), + .list = LIST_HEAD_INIT(smc_lgr_list.list), + .num = 0, +}; + +static atomic_t lgr_cnt = ATOMIC_INIT(0); /* number of existing link groups */ +static DECLARE_WAIT_QUEUE_HEAD(lgrs_deleted); + +static void smc_buf_free(struct smc_link_group *lgr, bool is_rmb, + struct smc_buf_desc *buf_desc); +static void __smc_lgr_terminate(struct smc_link_group *lgr, bool soft); + +static void smc_link_down_work(struct work_struct *work); + +/* return head of link group list and its lock for a given link group */ +static inline struct list_head *smc_lgr_list_head(struct smc_link_group *lgr, + spinlock_t **lgr_lock) +{ + if (lgr->is_smcd) { + *lgr_lock = &lgr->smcd->lgr_lock; + return &lgr->smcd->lgr_list; + } + + *lgr_lock = &smc_lgr_list.lock; + return &smc_lgr_list.list; +} + +static void smc_ibdev_cnt_inc(struct smc_link *lnk) +{ + atomic_inc(&lnk->smcibdev->lnk_cnt_by_port[lnk->ibport - 1]); +} + +static void smc_ibdev_cnt_dec(struct smc_link *lnk) +{ + atomic_dec(&lnk->smcibdev->lnk_cnt_by_port[lnk->ibport - 1]); +} + +static void smc_lgr_schedule_free_work(struct smc_link_group *lgr) +{ + /* client link group creation always follows the server link group + * creation. For client use a somewhat higher removal delay time, + * otherwise there is a risk of out-of-sync link groups. + */ + if (!lgr->freeing) { + mod_delayed_work(system_wq, &lgr->free_work, + (!lgr->is_smcd && lgr->role == SMC_CLNT) ? + SMC_LGR_FREE_DELAY_CLNT : + SMC_LGR_FREE_DELAY_SERV); + } +} + +/* Register connection's alert token in our lookup structure. + * To use rbtrees we have to implement our own insert core. + * Requires @conns_lock + * @smc connection to register + * Returns 0 on success, != otherwise. + */ +static void smc_lgr_add_alert_token(struct smc_connection *conn) +{ + struct rb_node **link, *parent = NULL; + u32 token = conn->alert_token_local; + + link = &conn->lgr->conns_all.rb_node; + while (*link) { + struct smc_connection *cur = rb_entry(*link, + struct smc_connection, alert_node); + + parent = *link; + if (cur->alert_token_local > token) + link = &parent->rb_left; + else + link = &parent->rb_right; + } + /* Put the new node there */ + rb_link_node(&conn->alert_node, parent, link); + rb_insert_color(&conn->alert_node, &conn->lgr->conns_all); +} + +/* assign an SMC-R link to the connection */ +static int smcr_lgr_conn_assign_link(struct smc_connection *conn, bool first) +{ + enum smc_link_state expected = first ? 
SMC_LNK_ACTIVATING : + SMC_LNK_ACTIVE; + int i, j; + + /* do link balancing */ + conn->lnk = NULL; /* reset conn->lnk first */ + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + struct smc_link *lnk = &conn->lgr->lnk[i]; + + if (lnk->state != expected || lnk->link_is_asym) + continue; + if (conn->lgr->role == SMC_CLNT) { + conn->lnk = lnk; /* temporary, SMC server assigns link*/ + break; + } + if (conn->lgr->conns_num % 2) { + for (j = i + 1; j < SMC_LINKS_PER_LGR_MAX; j++) { + struct smc_link *lnk2; + + lnk2 = &conn->lgr->lnk[j]; + if (lnk2->state == expected && + !lnk2->link_is_asym) { + conn->lnk = lnk2; + break; + } + } + } + if (!conn->lnk) + conn->lnk = lnk; + break; + } + if (!conn->lnk) + return SMC_CLC_DECL_NOACTLINK; + atomic_inc(&conn->lnk->conn_cnt); + return 0; +} + +/* Register connection in link group by assigning an alert token + * registered in a search tree. + * Requires @conns_lock + * Note that '0' is a reserved value and not assigned. + */ +static int smc_lgr_register_conn(struct smc_connection *conn, bool first) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + static atomic_t nexttoken = ATOMIC_INIT(0); + int rc; + + if (!conn->lgr->is_smcd) { + rc = smcr_lgr_conn_assign_link(conn, first); + if (rc) { + conn->lgr = NULL; + return rc; + } + } + /* find a new alert_token_local value not yet used by some connection + * in this link group + */ + sock_hold(&smc->sk); /* sock_put in smc_lgr_unregister_conn() */ + while (!conn->alert_token_local) { + conn->alert_token_local = atomic_inc_return(&nexttoken); + if (smc_lgr_find_conn(conn->alert_token_local, conn->lgr)) + conn->alert_token_local = 0; + } + smc_lgr_add_alert_token(conn); + conn->lgr->conns_num++; + return 0; +} + +/* Unregister connection and reset the alert token of the given connection< + */ +static void __smc_lgr_unregister_conn(struct smc_connection *conn) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + struct smc_link_group *lgr = conn->lgr; + + rb_erase(&conn->alert_node, &lgr->conns_all); + if (conn->lnk) + atomic_dec(&conn->lnk->conn_cnt); + lgr->conns_num--; + conn->alert_token_local = 0; + sock_put(&smc->sk); /* sock_hold in smc_lgr_register_conn() */ +} + +/* Unregister connection from lgr + */ +static void smc_lgr_unregister_conn(struct smc_connection *conn) +{ + struct smc_link_group *lgr = conn->lgr; + + if (!smc_conn_lgr_valid(conn)) + return; + write_lock_bh(&lgr->conns_lock); + if (conn->alert_token_local) { + __smc_lgr_unregister_conn(conn); + } + write_unlock_bh(&lgr->conns_lock); +} + +int smc_nl_get_sys_info(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + char hostname[SMC_MAX_HOSTNAME_LEN + 1]; + char smc_seid[SMC_MAX_EID_LEN + 1]; + struct nlattr *attrs; + u8 *seid = NULL; + u8 *host = NULL; + void *nlh; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_SYS_INFO); + if (!nlh) + goto errmsg; + if (cb_ctx->pos[0]) + goto errout; + attrs = nla_nest_start(skb, SMC_GEN_SYS_INFO); + if (!attrs) + goto errout; + if (nla_put_u8(skb, SMC_NLA_SYS_VER, SMC_V2)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_SYS_REL, SMC_RELEASE)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_SYS_IS_ISM_V2, smc_ism_is_v2_capable())) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_SYS_IS_SMCR_V2, true)) + goto errattr; + smc_clc_get_hostname(&host); + if (host) { + memcpy(hostname, host, SMC_MAX_HOSTNAME_LEN); + 
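/* smc_clc_get_hostname() returns a fixed-size field without a guaranteed NUL; terminate the local copy before nla_put_string() */ + 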
hostname[SMC_MAX_HOSTNAME_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_SYS_LOCAL_HOST, hostname)) + goto errattr; + } + if (smc_ism_is_v2_capable()) { + smc_ism_get_system_eid(&seid); + memcpy(smc_seid, seid, SMC_MAX_EID_LEN); + smc_seid[SMC_MAX_EID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_SYS_SEID, smc_seid)) + goto errattr; + } + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + cb_ctx->pos[0] = 1; + return skb->len; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return skb->len; +} + +/* Fill SMC_NLA_LGR_D_V2_COMMON/SMC_NLA_LGR_R_V2_COMMON nested attributes */ +static int smc_nl_fill_lgr_v2_common(struct smc_link_group *lgr, + struct sk_buff *skb, + struct netlink_callback *cb, + struct nlattr *v2_attrs) +{ + char smc_host[SMC_MAX_HOSTNAME_LEN + 1]; + char smc_eid[SMC_MAX_EID_LEN + 1]; + + if (nla_put_u8(skb, SMC_NLA_LGR_V2_VER, lgr->smc_version)) + goto errv2attr; + if (nla_put_u8(skb, SMC_NLA_LGR_V2_REL, lgr->peer_smc_release)) + goto errv2attr; + if (nla_put_u8(skb, SMC_NLA_LGR_V2_OS, lgr->peer_os)) + goto errv2attr; + memcpy(smc_host, lgr->peer_hostname, SMC_MAX_HOSTNAME_LEN); + smc_host[SMC_MAX_HOSTNAME_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_LGR_V2_PEER_HOST, smc_host)) + goto errv2attr; + memcpy(smc_eid, lgr->negotiated_eid, SMC_MAX_EID_LEN); + smc_eid[SMC_MAX_EID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_LGR_V2_NEG_EID, smc_eid)) + goto errv2attr; + + nla_nest_end(skb, v2_attrs); + return 0; + +errv2attr: + nla_nest_cancel(skb, v2_attrs); + return -EMSGSIZE; +} + +static int smc_nl_fill_smcr_lgr_v2(struct smc_link_group *lgr, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct nlattr *v2_attrs; + + v2_attrs = nla_nest_start(skb, SMC_NLA_LGR_R_V2); + if (!v2_attrs) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_R_V2_DIRECT, !lgr->uses_gateway)) + goto errv2attr; + + nla_nest_end(skb, v2_attrs); + return 0; + +errv2attr: + nla_nest_cancel(skb, v2_attrs); +errattr: + return -EMSGSIZE; +} + +static int smc_nl_fill_lgr(struct smc_link_group *lgr, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + char smc_target[SMC_MAX_PNETID_LEN + 1]; + struct nlattr *attrs, *v2_attrs; + + attrs = nla_nest_start(skb, SMC_GEN_LGR_SMCR); + if (!attrs) + goto errout; + + if (nla_put_u32(skb, SMC_NLA_LGR_R_ID, *((u32 *)&lgr->id))) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LGR_R_CONNS_NUM, lgr->conns_num)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_R_ROLE, lgr->role)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_R_TYPE, lgr->type)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_R_BUF_TYPE, lgr->buf_type)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_R_VLAN_ID, lgr->vlan_id)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_LGR_R_NET_COOKIE, + lgr->net->net_cookie, SMC_NLA_LGR_R_PAD)) + goto errattr; + memcpy(smc_target, lgr->pnet_id, SMC_MAX_PNETID_LEN); + smc_target[SMC_MAX_PNETID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_LGR_R_PNETID, smc_target)) + goto errattr; + if (lgr->smc_version > SMC_V1) { + v2_attrs = nla_nest_start(skb, SMC_NLA_LGR_R_V2_COMMON); + if (!v2_attrs) + goto errattr; + if (smc_nl_fill_lgr_v2_common(lgr, skb, cb, v2_attrs)) + goto errattr; + if (smc_nl_fill_smcr_lgr_v2(lgr, skb, cb)) + goto errattr; + } + + nla_nest_end(skb, attrs); + return 0; +errattr: + nla_nest_cancel(skb, attrs); +errout: + return -EMSGSIZE; +} + +static int smc_nl_fill_lgr_link(struct smc_link_group *lgr, + struct smc_link *link, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + char 
smc_ibname[IB_DEVICE_NAME_MAX]; + u8 smc_gid_target[41]; + struct nlattr *attrs; + u32 link_uid = 0; + void *nlh; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_LINK_SMCR); + if (!nlh) + goto errmsg; + + attrs = nla_nest_start(skb, SMC_GEN_LINK_SMCR); + if (!attrs) + goto errout; + + if (nla_put_u8(skb, SMC_NLA_LINK_ID, link->link_id)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LINK_STATE, link->state)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LINK_CONN_CNT, + atomic_read(&link->conn_cnt))) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LINK_IB_PORT, link->ibport)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LINK_NET_DEV, link->ndev_ifidx)) + goto errattr; + snprintf(smc_ibname, sizeof(smc_ibname), "%s", link->ibname); + if (nla_put_string(skb, SMC_NLA_LINK_IB_DEV, smc_ibname)) + goto errattr; + memcpy(&link_uid, link->link_uid, sizeof(link_uid)); + if (nla_put_u32(skb, SMC_NLA_LINK_UID, link_uid)) + goto errattr; + memcpy(&link_uid, link->peer_link_uid, sizeof(link_uid)); + if (nla_put_u32(skb, SMC_NLA_LINK_PEER_UID, link_uid)) + goto errattr; + memset(smc_gid_target, 0, sizeof(smc_gid_target)); + smc_gid_be16_convert(smc_gid_target, link->gid); + if (nla_put_string(skb, SMC_NLA_LINK_GID, smc_gid_target)) + goto errattr; + memset(smc_gid_target, 0, sizeof(smc_gid_target)); + smc_gid_be16_convert(smc_gid_target, link->peer_gid); + if (nla_put_string(skb, SMC_NLA_LINK_PEER_GID, smc_gid_target)) + goto errattr; + + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + return 0; +errattr: + nla_nest_cancel(skb, attrs); +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + +static int smc_nl_handle_lgr(struct smc_link_group *lgr, + struct sk_buff *skb, + struct netlink_callback *cb, + bool list_links) +{ + void *nlh; + int i; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_LGR_SMCR); + if (!nlh) + goto errmsg; + if (smc_nl_fill_lgr(lgr, skb, cb)) + goto errout; + + genlmsg_end(skb, nlh); + if (!list_links) + goto out; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_usable(&lgr->lnk[i])) + continue; + if (smc_nl_fill_lgr_link(lgr, &lgr->lnk[i], skb, cb)) + goto errout; + } +out: + return 0; + +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + +static void smc_nl_fill_lgr_list(struct smc_lgr_list *smc_lgr, + struct sk_buff *skb, + struct netlink_callback *cb, + bool list_links) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct smc_link_group *lgr; + int snum = cb_ctx->pos[0]; + int num = 0; + + spin_lock_bh(&smc_lgr->lock); + list_for_each_entry(lgr, &smc_lgr->list, list) { + if (num < snum) + goto next; + if (smc_nl_handle_lgr(lgr, skb, cb, list_links)) + goto errout; +next: + num++; + } +errout: + spin_unlock_bh(&smc_lgr->lock); + cb_ctx->pos[0] = num; +} + +static int smc_nl_fill_smcd_lgr(struct smc_link_group *lgr, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + char smc_pnet[SMC_MAX_PNETID_LEN + 1]; + struct nlattr *attrs; + void *nlh; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_LGR_SMCD); + if (!nlh) + goto errmsg; + + attrs = nla_nest_start(skb, SMC_GEN_LGR_SMCD); + if (!attrs) + goto errout; + + if (nla_put_u32(skb, SMC_NLA_LGR_D_ID, *((u32 *)&lgr->id))) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_LGR_D_GID, lgr->smcd->local_gid, + 
SMC_NLA_LGR_D_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_LGR_D_PEER_GID, lgr->peer_gid, + SMC_NLA_LGR_D_PAD)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_LGR_D_VLAN_ID, lgr->vlan_id)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LGR_D_CONNS_NUM, lgr->conns_num)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_LGR_D_CHID, smc_ism_get_chid(lgr->smcd))) + goto errattr; + memcpy(smc_pnet, lgr->smcd->pnetid, SMC_MAX_PNETID_LEN); + smc_pnet[SMC_MAX_PNETID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_LGR_D_PNETID, smc_pnet)) + goto errattr; + if (lgr->smc_version > SMC_V1) { + struct nlattr *v2_attrs; + + v2_attrs = nla_nest_start(skb, SMC_NLA_LGR_D_V2_COMMON); + if (!v2_attrs) + goto errattr; + if (smc_nl_fill_lgr_v2_common(lgr, skb, cb, v2_attrs)) + goto errattr; + } + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + return 0; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + +static int smc_nl_handle_smcd_lgr(struct smcd_dev *dev, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct smc_link_group *lgr; + int snum = cb_ctx->pos[1]; + int rc = 0, num = 0; + + spin_lock_bh(&dev->lgr_lock); + list_for_each_entry(lgr, &dev->lgr_list, list) { + if (!lgr->is_smcd) + continue; + if (num < snum) + goto next; + rc = smc_nl_fill_smcd_lgr(lgr, skb, cb); + if (rc) + goto errout; +next: + num++; + } +errout: + spin_unlock_bh(&dev->lgr_lock); + cb_ctx->pos[1] = num; + return rc; +} + +static int smc_nl_fill_smcd_dev(struct smcd_dev_list *dev_list, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct smcd_dev *smcd_dev; + int snum = cb_ctx->pos[0]; + int rc = 0, num = 0; + + mutex_lock(&dev_list->mutex); + list_for_each_entry(smcd_dev, &dev_list->list, list) { + if (list_empty(&smcd_dev->lgr_list)) + continue; + if (num < snum) + goto next; + rc = smc_nl_handle_smcd_lgr(smcd_dev, skb, cb); + if (rc) + goto errout; +next: + num++; + } +errout: + mutex_unlock(&dev_list->mutex); + cb_ctx->pos[0] = num; + return rc; +} + +int smcr_nl_get_lgr(struct sk_buff *skb, struct netlink_callback *cb) +{ + bool list_links = false; + + smc_nl_fill_lgr_list(&smc_lgr_list, skb, cb, list_links); + return skb->len; +} + +int smcr_nl_get_link(struct sk_buff *skb, struct netlink_callback *cb) +{ + bool list_links = true; + + smc_nl_fill_lgr_list(&smc_lgr_list, skb, cb, list_links); + return skb->len; +} + +int smcd_nl_get_lgr(struct sk_buff *skb, struct netlink_callback *cb) +{ + smc_nl_fill_smcd_dev(&smcd_dev_list, skb, cb); + return skb->len; +} + +void smc_lgr_cleanup_early(struct smc_link_group *lgr) +{ + spinlock_t *lgr_lock; + + if (!lgr) + return; + + smc_lgr_list_head(lgr, &lgr_lock); + spin_lock_bh(lgr_lock); + /* do not use this link group for new connections */ + if (!list_empty(&lgr->list)) + list_del_init(&lgr->list); + spin_unlock_bh(lgr_lock); + __smc_lgr_terminate(lgr, true); +} + +static void smcr_lgr_link_deactivate_all(struct smc_link_group *lgr) +{ + int i; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + struct smc_link *lnk = &lgr->lnk[i]; + + if (smc_link_sendable(lnk)) + lnk->state = SMC_LNK_INACTIVE; + } + wake_up_all(&lgr->llc_msg_waiter); + wake_up_all(&lgr->llc_flow_waiter); +} + +static void smc_lgr_free(struct smc_link_group *lgr); + +static void smc_lgr_free_work(struct work_struct *work) +{ + struct smc_link_group *lgr = container_of(to_delayed_work(work), + struct 
smc_link_group, + free_work); + spinlock_t *lgr_lock; + bool conns; + + smc_lgr_list_head(lgr, &lgr_lock); + spin_lock_bh(lgr_lock); + if (lgr->freeing) { + spin_unlock_bh(lgr_lock); + return; + } + read_lock_bh(&lgr->conns_lock); + conns = RB_EMPTY_ROOT(&lgr->conns_all); + read_unlock_bh(&lgr->conns_lock); + if (!conns) { /* number of lgr connections is no longer zero */ + spin_unlock_bh(lgr_lock); + return; + } + list_del_init(&lgr->list); /* remove from smc_lgr_list */ + lgr->freeing = 1; /* this instance does the freeing, no new schedule */ + spin_unlock_bh(lgr_lock); + cancel_delayed_work(&lgr->free_work); + + if (!lgr->is_smcd && !lgr->terminating) + smc_llc_send_link_delete_all(lgr, true, + SMC_LLC_DEL_PROG_INIT_TERM); + if (lgr->is_smcd && !lgr->terminating) + smc_ism_signal_shutdown(lgr); + if (!lgr->is_smcd) + smcr_lgr_link_deactivate_all(lgr); + smc_lgr_free(lgr); +} + +static void smc_lgr_terminate_work(struct work_struct *work) +{ + struct smc_link_group *lgr = container_of(work, struct smc_link_group, + terminate_work); + + __smc_lgr_terminate(lgr, true); +} + +/* return next unique link id for the lgr */ +static u8 smcr_next_link_id(struct smc_link_group *lgr) +{ + u8 link_id; + int i; + + while (1) { +again: + link_id = ++lgr->next_link_id; + if (!link_id) /* skip zero as link_id */ + link_id = ++lgr->next_link_id; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (smc_link_usable(&lgr->lnk[i]) && + lgr->lnk[i].link_id == link_id) + goto again; + } + break; + } + return link_id; +} + +static void smcr_copy_dev_info_to_link(struct smc_link *link) +{ + struct smc_ib_device *smcibdev = link->smcibdev; + + snprintf(link->ibname, sizeof(link->ibname), "%s", + smcibdev->ibdev->name); + link->ndev_ifidx = smcibdev->ndev_ifidx[link->ibport - 1]; +} + +int smcr_link_init(struct smc_link_group *lgr, struct smc_link *lnk, + u8 link_idx, struct smc_init_info *ini) +{ + struct smc_ib_device *smcibdev; + u8 rndvec[3]; + int rc; + + if (lgr->smc_version == SMC_V2) { + lnk->smcibdev = ini->smcrv2.ib_dev_v2; + lnk->ibport = ini->smcrv2.ib_port_v2; + } else { + lnk->smcibdev = ini->ib_dev; + lnk->ibport = ini->ib_port; + } + get_device(&lnk->smcibdev->ibdev->dev); + atomic_inc(&lnk->smcibdev->lnk_cnt); + refcount_set(&lnk->refcnt, 1); /* link refcnt is set to 1 */ + lnk->clearing = 0; + lnk->path_mtu = lnk->smcibdev->pattr[lnk->ibport - 1].active_mtu; + lnk->link_id = smcr_next_link_id(lgr); + lnk->lgr = lgr; + smc_lgr_hold(lgr); /* lgr_put in smcr_link_clear() */ + lnk->link_idx = link_idx; + lnk->wr_rx_id_compl = 0; + smc_ibdev_cnt_inc(lnk); + smcr_copy_dev_info_to_link(lnk); + atomic_set(&lnk->conn_cnt, 0); + smc_llc_link_set_uid(lnk); + INIT_WORK(&lnk->link_down_wrk, smc_link_down_work); + if (!lnk->smcibdev->initialized) { + rc = (int)smc_ib_setup_per_ibdev(lnk->smcibdev); + if (rc) + goto out; + } + get_random_bytes(rndvec, sizeof(rndvec)); + lnk->psn_initial = rndvec[0] + (rndvec[1] << 8) + + (rndvec[2] << 16); + rc = smc_ib_determine_gid(lnk->smcibdev, lnk->ibport, + ini->vlan_id, lnk->gid, &lnk->sgid_index, + lgr->smc_version == SMC_V2 ? 
+ &ini->smcrv2 : NULL); + if (rc) + goto out; + rc = smc_llc_link_init(lnk); + if (rc) + goto out; + rc = smc_wr_alloc_link_mem(lnk); + if (rc) + goto clear_llc_lnk; + rc = smc_ib_create_protection_domain(lnk); + if (rc) + goto free_link_mem; + rc = smc_ib_create_queue_pair(lnk); + if (rc) + goto dealloc_pd; + rc = smc_wr_create_link(lnk); + if (rc) + goto destroy_qp; + lnk->state = SMC_LNK_ACTIVATING; + return 0; + +destroy_qp: + smc_ib_destroy_queue_pair(lnk); +dealloc_pd: + smc_ib_dealloc_protection_domain(lnk); +free_link_mem: + smc_wr_free_link_mem(lnk); +clear_llc_lnk: + smc_llc_link_clear(lnk, false); +out: + smc_ibdev_cnt_dec(lnk); + put_device(&lnk->smcibdev->ibdev->dev); + smcibdev = lnk->smcibdev; + memset(lnk, 0, sizeof(struct smc_link)); + lnk->state = SMC_LNK_UNUSED; + if (!atomic_dec_return(&smcibdev->lnk_cnt)) + wake_up(&smcibdev->lnks_deleted); + smc_lgr_put(lgr); /* lgr_hold above */ + return rc; +} + +/* create a new SMC link group */ +static int smc_lgr_create(struct smc_sock *smc, struct smc_init_info *ini) +{ + struct smc_link_group *lgr; + struct list_head *lgr_list; + struct smc_link *lnk; + spinlock_t *lgr_lock; + u8 link_idx; + int rc = 0; + int i; + + if (ini->is_smcd && ini->vlan_id) { + if (smc_ism_get_vlan(ini->ism_dev[ini->ism_selected], + ini->vlan_id)) { + rc = SMC_CLC_DECL_ISMVLANERR; + goto out; + } + } + + lgr = kzalloc(sizeof(*lgr), GFP_KERNEL); + if (!lgr) { + rc = SMC_CLC_DECL_MEM; + goto ism_put_vlan; + } + lgr->tx_wq = alloc_workqueue("smc_tx_wq-%*phN", 0, 0, + SMC_LGR_ID_SIZE, &lgr->id); + if (!lgr->tx_wq) { + rc = -ENOMEM; + goto free_lgr; + } + lgr->is_smcd = ini->is_smcd; + lgr->sync_err = 0; + lgr->terminating = 0; + lgr->freeing = 0; + lgr->vlan_id = ini->vlan_id; + refcount_set(&lgr->refcnt, 1); /* set lgr refcnt to 1 */ + init_rwsem(&lgr->sndbufs_lock); + init_rwsem(&lgr->rmbs_lock); + rwlock_init(&lgr->conns_lock); + for (i = 0; i < SMC_RMBE_SIZES; i++) { + INIT_LIST_HEAD(&lgr->sndbufs[i]); + INIT_LIST_HEAD(&lgr->rmbs[i]); + } + lgr->next_link_id = 0; + smc_lgr_list.num += SMC_LGR_NUM_INCR; + memcpy(&lgr->id, (u8 *)&smc_lgr_list.num, SMC_LGR_ID_SIZE); + INIT_DELAYED_WORK(&lgr->free_work, smc_lgr_free_work); + INIT_WORK(&lgr->terminate_work, smc_lgr_terminate_work); + lgr->conns_all = RB_ROOT; + if (ini->is_smcd) { + /* SMC-D specific settings */ + get_device(&ini->ism_dev[ini->ism_selected]->dev); + lgr->peer_gid = ini->ism_peer_gid[ini->ism_selected]; + lgr->smcd = ini->ism_dev[ini->ism_selected]; + lgr_list = &ini->ism_dev[ini->ism_selected]->lgr_list; + lgr_lock = &lgr->smcd->lgr_lock; + lgr->smc_version = ini->smcd_version; + lgr->peer_shutdown = 0; + atomic_inc(&ini->ism_dev[ini->ism_selected]->lgr_cnt); + } else { + /* SMC-R specific settings */ + struct smc_ib_device *ibdev; + int ibport; + + lgr->role = smc->listen_smc ? 
SMC_SERV : SMC_CLNT; + lgr->smc_version = ini->smcr_version; + memcpy(lgr->peer_systemid, ini->peer_systemid, + SMC_SYSTEMID_LEN); + if (lgr->smc_version == SMC_V2) { + ibdev = ini->smcrv2.ib_dev_v2; + ibport = ini->smcrv2.ib_port_v2; + lgr->saddr = ini->smcrv2.saddr; + lgr->uses_gateway = ini->smcrv2.uses_gateway; + memcpy(lgr->nexthop_mac, ini->smcrv2.nexthop_mac, + ETH_ALEN); + } else { + ibdev = ini->ib_dev; + ibport = ini->ib_port; + } + memcpy(lgr->pnet_id, ibdev->pnetid[ibport - 1], + SMC_MAX_PNETID_LEN); + rc = smc_wr_alloc_lgr_mem(lgr); + if (rc) + goto free_wq; + smc_llc_lgr_init(lgr, smc); + + link_idx = SMC_SINGLE_LINK; + lnk = &lgr->lnk[link_idx]; + rc = smcr_link_init(lgr, lnk, link_idx, ini); + if (rc) { + smc_wr_free_lgr_mem(lgr); + goto free_wq; + } + lgr->net = smc_ib_net(lnk->smcibdev); + lgr_list = &smc_lgr_list.list; + lgr_lock = &smc_lgr_list.lock; + lgr->buf_type = lgr->net->smc.sysctl_smcr_buf_type; + atomic_inc(&lgr_cnt); + } + smc->conn.lgr = lgr; + spin_lock_bh(lgr_lock); + list_add_tail(&lgr->list, lgr_list); + spin_unlock_bh(lgr_lock); + return 0; + +free_wq: + destroy_workqueue(lgr->tx_wq); +free_lgr: + kfree(lgr); +ism_put_vlan: + if (ini->is_smcd && ini->vlan_id) + smc_ism_put_vlan(ini->ism_dev[ini->ism_selected], ini->vlan_id); +out: + if (rc < 0) { + if (rc == -ENOMEM) + rc = SMC_CLC_DECL_MEM; + else + rc = SMC_CLC_DECL_INTERR; + } + return rc; +} + +static int smc_write_space(struct smc_connection *conn) +{ + int buffer_len = conn->peer_rmbe_size; + union smc_host_cursor prod; + union smc_host_cursor cons; + int space; + + smc_curs_copy(&prod, &conn->local_tx_ctrl.prod, conn); + smc_curs_copy(&cons, &conn->local_rx_ctrl.cons, conn); + /* determine rx_buf space */ + space = buffer_len - smc_curs_diff(buffer_len, &cons, &prod); + return space; +} + +static int smc_switch_cursor(struct smc_sock *smc, struct smc_cdc_tx_pend *pend, + struct smc_wr_buf *wr_buf) +{ + struct smc_connection *conn = &smc->conn; + union smc_host_cursor cons, fin; + int rc = 0; + int diff; + + smc_curs_copy(&conn->tx_curs_sent, &conn->tx_curs_fin, conn); + smc_curs_copy(&fin, &conn->local_tx_ctrl_fin, conn); + /* set prod cursor to old state, enforce tx_rdma_writes() */ + smc_curs_copy(&conn->local_tx_ctrl.prod, &fin, conn); + smc_curs_copy(&cons, &conn->local_rx_ctrl.cons, conn); + + if (smc_curs_comp(conn->peer_rmbe_size, &cons, &fin) < 0) { + /* cons cursor advanced more than fin, and prod was set + * fin above, so now prod is smaller than cons. Fix that. 
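+ * Advance the tx cursors, sndbuf_space and the local prod/fin cursors by the difference so that prod is no longer behind cons.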
+ */ + diff = smc_curs_diff(conn->peer_rmbe_size, &fin, &cons); + smc_curs_add(conn->sndbuf_desc->len, + &conn->tx_curs_sent, diff); + smc_curs_add(conn->sndbuf_desc->len, + &conn->tx_curs_fin, diff); + + smp_mb__before_atomic(); + atomic_add(diff, &conn->sndbuf_space); + smp_mb__after_atomic(); + + smc_curs_add(conn->peer_rmbe_size, + &conn->local_tx_ctrl.prod, diff); + smc_curs_add(conn->peer_rmbe_size, + &conn->local_tx_ctrl_fin, diff); + } + /* recalculate, value is used by tx_rdma_writes() */ + atomic_set(&smc->conn.peer_rmbe_space, smc_write_space(conn)); + + if (smc->sk.sk_state != SMC_INIT && + smc->sk.sk_state != SMC_CLOSED) { + rc = smcr_cdc_msg_send_validation(conn, pend, wr_buf); + if (!rc) { + queue_delayed_work(conn->lgr->tx_wq, &conn->tx_work, 0); + smc->sk.sk_data_ready(&smc->sk); + } + } else { + smc_wr_tx_put_slot(conn->lnk, + (struct smc_wr_tx_pend_priv *)pend); + } + return rc; +} + +void smc_switch_link_and_count(struct smc_connection *conn, + struct smc_link *to_lnk) +{ + atomic_dec(&conn->lnk->conn_cnt); + /* link_hold in smc_conn_create() */ + smcr_link_put(conn->lnk); + conn->lnk = to_lnk; + atomic_inc(&conn->lnk->conn_cnt); + /* link_put in smc_conn_free() */ + smcr_link_hold(conn->lnk); +} + +struct smc_link *smc_switch_conns(struct smc_link_group *lgr, + struct smc_link *from_lnk, bool is_dev_err) +{ + struct smc_link *to_lnk = NULL; + struct smc_cdc_tx_pend *pend; + struct smc_connection *conn; + struct smc_wr_buf *wr_buf; + struct smc_sock *smc; + struct rb_node *node; + int i, rc = 0; + + /* link is inactive, wake up tx waiters */ + smc_wr_wakeup_tx_wait(from_lnk); + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_active(&lgr->lnk[i]) || i == from_lnk->link_idx) + continue; + if (is_dev_err && from_lnk->smcibdev == lgr->lnk[i].smcibdev && + from_lnk->ibport == lgr->lnk[i].ibport) { + continue; + } + to_lnk = &lgr->lnk[i]; + break; + } + if (!to_lnk || !smc_wr_tx_link_hold(to_lnk)) { + smc_lgr_terminate_sched(lgr); + return NULL; + } +again: + read_lock_bh(&lgr->conns_lock); + for (node = rb_first(&lgr->conns_all); node; node = rb_next(node)) { + conn = rb_entry(node, struct smc_connection, alert_node); + if (conn->lnk != from_lnk) + continue; + smc = container_of(conn, struct smc_sock, conn); + /* conn->lnk not yet set in SMC_INIT state */ + if (smc->sk.sk_state == SMC_INIT) + continue; + if (smc->sk.sk_state == SMC_CLOSED || + smc->sk.sk_state == SMC_PEERCLOSEWAIT1 || + smc->sk.sk_state == SMC_PEERCLOSEWAIT2 || + smc->sk.sk_state == SMC_APPFINCLOSEWAIT || + smc->sk.sk_state == SMC_APPCLOSEWAIT1 || + smc->sk.sk_state == SMC_APPCLOSEWAIT2 || + smc->sk.sk_state == SMC_PEERFINCLOSEWAIT || + smc->sk.sk_state == SMC_PEERABORTWAIT || + smc->sk.sk_state == SMC_PROCESSABORT) { + spin_lock_bh(&conn->send_lock); + smc_switch_link_and_count(conn, to_lnk); + spin_unlock_bh(&conn->send_lock); + continue; + } + sock_hold(&smc->sk); + read_unlock_bh(&lgr->conns_lock); + /* pre-fetch buffer outside of send_lock, might sleep */ + rc = smc_cdc_get_free_slot(conn, to_lnk, &wr_buf, NULL, &pend); + if (rc) + goto err_out; + /* avoid race with smcr_tx_sndbuf_nonempty() */ + spin_lock_bh(&conn->send_lock); + smc_switch_link_and_count(conn, to_lnk); + rc = smc_switch_cursor(smc, pend, wr_buf); + spin_unlock_bh(&conn->send_lock); + sock_put(&smc->sk); + if (rc) + goto err_out; + goto again; + } + read_unlock_bh(&lgr->conns_lock); + smc_wr_tx_link_put(to_lnk); + return to_lnk; + +err_out: + smcr_link_down_cond_sched(to_lnk); + smc_wr_tx_link_put(to_lnk); + return NULL; 
+} + +static void smcr_buf_unuse(struct smc_buf_desc *buf_desc, bool is_rmb, + struct smc_link_group *lgr) +{ + struct rw_semaphore *lock; /* lock buffer list */ + int rc; + + if (is_rmb && buf_desc->is_conf_rkey && !list_empty(&lgr->list)) { + /* unregister rmb with peer */ + rc = smc_llc_flow_initiate(lgr, SMC_LLC_FLOW_RKEY); + if (!rc) { + /* protect against smc_llc_cli_rkey_exchange() */ + mutex_lock(&lgr->llc_conf_mutex); + smc_llc_do_delete_rkey(lgr, buf_desc); + buf_desc->is_conf_rkey = false; + mutex_unlock(&lgr->llc_conf_mutex); + smc_llc_flow_stop(lgr, &lgr->llc_flow_lcl); + } + } + + if (buf_desc->is_reg_err) { + /* buf registration failed, reuse not possible */ + lock = is_rmb ? &lgr->rmbs_lock : + &lgr->sndbufs_lock; + down_write(lock); + list_del(&buf_desc->list); + up_write(lock); + + smc_buf_free(lgr, is_rmb, buf_desc); + } else { + /* memzero_explicit provides potential memory barrier semantics */ + memzero_explicit(buf_desc->cpu_addr, buf_desc->len); + WRITE_ONCE(buf_desc->used, 0); + } +} + +static void smc_buf_unuse(struct smc_connection *conn, + struct smc_link_group *lgr) +{ + if (conn->sndbuf_desc) { + if (!lgr->is_smcd && conn->sndbuf_desc->is_vm) { + smcr_buf_unuse(conn->sndbuf_desc, false, lgr); + } else { + memzero_explicit(conn->sndbuf_desc->cpu_addr, conn->sndbuf_desc->len); + WRITE_ONCE(conn->sndbuf_desc->used, 0); + } + } + if (conn->rmb_desc) { + if (!lgr->is_smcd) { + smcr_buf_unuse(conn->rmb_desc, true, lgr); + } else { + memzero_explicit(conn->rmb_desc->cpu_addr, + conn->rmb_desc->len + sizeof(struct smcd_cdc_msg)); + WRITE_ONCE(conn->rmb_desc->used, 0); + } + } +} + +/* remove a finished connection from its link group */ +void smc_conn_free(struct smc_connection *conn) +{ + struct smc_link_group *lgr = conn->lgr; + + if (!lgr || conn->freed) + /* Connection has never been registered in a + * link group, or has already been freed. + */ + return; + + conn->freed = 1; + if (!smc_conn_lgr_valid(conn)) + /* Connection has already unregistered from + * link group. 
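+ * Only the link and lgr references taken in smc_conn_create() still need to be dropped.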
+ */ + goto lgr_put; + + if (lgr->is_smcd) { + if (!list_empty(&lgr->list)) + smc_ism_unset_conn(conn); + tasklet_kill(&conn->rx_tsklet); + } else { + smc_cdc_wait_pend_tx_wr(conn); + if (current_work() != &conn->abort_work) + cancel_work_sync(&conn->abort_work); + } + if (!list_empty(&lgr->list)) { + smc_buf_unuse(conn, lgr); /* allow buffer reuse */ + smc_lgr_unregister_conn(conn); + } + + if (!lgr->conns_num) + smc_lgr_schedule_free_work(lgr); +lgr_put: + if (!lgr->is_smcd) + smcr_link_put(conn->lnk); /* link_hold in smc_conn_create() */ + smc_lgr_put(lgr); /* lgr_hold in smc_conn_create() */ +} + +/* unregister a link from a buf_desc */ +static void smcr_buf_unmap_link(struct smc_buf_desc *buf_desc, bool is_rmb, + struct smc_link *lnk) +{ + if (is_rmb || buf_desc->is_vm) + buf_desc->is_reg_mr[lnk->link_idx] = false; + if (!buf_desc->is_map_ib[lnk->link_idx]) + return; + + if ((is_rmb || buf_desc->is_vm) && + buf_desc->mr[lnk->link_idx]) { + smc_ib_put_memory_region(buf_desc->mr[lnk->link_idx]); + buf_desc->mr[lnk->link_idx] = NULL; + } + if (is_rmb) + smc_ib_buf_unmap_sg(lnk, buf_desc, DMA_FROM_DEVICE); + else + smc_ib_buf_unmap_sg(lnk, buf_desc, DMA_TO_DEVICE); + + sg_free_table(&buf_desc->sgt[lnk->link_idx]); + buf_desc->is_map_ib[lnk->link_idx] = false; +} + +/* unmap all buffers of lgr for a deleted link */ +static void smcr_buf_unmap_lgr(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + struct smc_buf_desc *buf_desc, *bf; + int i; + + for (i = 0; i < SMC_RMBE_SIZES; i++) { + down_write(&lgr->rmbs_lock); + list_for_each_entry_safe(buf_desc, bf, &lgr->rmbs[i], list) + smcr_buf_unmap_link(buf_desc, true, lnk); + up_write(&lgr->rmbs_lock); + + down_write(&lgr->sndbufs_lock); + list_for_each_entry_safe(buf_desc, bf, &lgr->sndbufs[i], + list) + smcr_buf_unmap_link(buf_desc, false, lnk); + up_write(&lgr->sndbufs_lock); + } +} + +static void smcr_rtoken_clear_link(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + int i; + + for (i = 0; i < SMC_RMBS_PER_LGR_MAX; i++) { + lgr->rtokens[i][lnk->link_idx].rkey = 0; + lgr->rtokens[i][lnk->link_idx].dma_addr = 0; + } +} + +static void __smcr_link_clear(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + struct smc_ib_device *smcibdev; + + smc_wr_free_link_mem(lnk); + smc_ibdev_cnt_dec(lnk); + put_device(&lnk->smcibdev->ibdev->dev); + smcibdev = lnk->smcibdev; + memset(lnk, 0, sizeof(struct smc_link)); + lnk->state = SMC_LNK_UNUSED; + if (!atomic_dec_return(&smcibdev->lnk_cnt)) + wake_up(&smcibdev->lnks_deleted); + smc_lgr_put(lgr); /* lgr_hold in smcr_link_init() */ +} + +/* must be called under lgr->llc_conf_mutex lock */ +void smcr_link_clear(struct smc_link *lnk, bool log) +{ + if (!lnk->lgr || lnk->clearing || + lnk->state == SMC_LNK_UNUSED) + return; + lnk->clearing = 1; + lnk->peer_qpn = 0; + smc_llc_link_clear(lnk, log); + smcr_buf_unmap_lgr(lnk); + smcr_rtoken_clear_link(lnk); + smc_ib_modify_qp_error(lnk); + smc_wr_free_link(lnk); + smc_ib_destroy_queue_pair(lnk); + smc_ib_dealloc_protection_domain(lnk); + smcr_link_put(lnk); /* theoretically last link_put */ +} + +void smcr_link_hold(struct smc_link *lnk) +{ + refcount_inc(&lnk->refcnt); +} + +void smcr_link_put(struct smc_link *lnk) +{ + if (refcount_dec_and_test(&lnk->refcnt)) + __smcr_link_clear(lnk); +} + +static void smcr_buf_free(struct smc_link_group *lgr, bool is_rmb, + struct smc_buf_desc *buf_desc) +{ + int i; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) + smcr_buf_unmap_link(buf_desc, is_rmb, &lgr->lnk[i]); + + if 
(!buf_desc->is_vm && buf_desc->pages) + __free_pages(buf_desc->pages, buf_desc->order); + else if (buf_desc->is_vm && buf_desc->cpu_addr) + vfree(buf_desc->cpu_addr); + kfree(buf_desc); +} + +static void smcd_buf_free(struct smc_link_group *lgr, bool is_dmb, + struct smc_buf_desc *buf_desc) +{ + if (is_dmb) { + /* restore original buf len */ + buf_desc->len += sizeof(struct smcd_cdc_msg); + smc_ism_unregister_dmb(lgr->smcd, buf_desc); + } else { + kfree(buf_desc->cpu_addr); + } + kfree(buf_desc); +} + +static void smc_buf_free(struct smc_link_group *lgr, bool is_rmb, + struct smc_buf_desc *buf_desc) +{ + if (lgr->is_smcd) + smcd_buf_free(lgr, is_rmb, buf_desc); + else + smcr_buf_free(lgr, is_rmb, buf_desc); +} + +static void __smc_lgr_free_bufs(struct smc_link_group *lgr, bool is_rmb) +{ + struct smc_buf_desc *buf_desc, *bf_desc; + struct list_head *buf_list; + int i; + + for (i = 0; i < SMC_RMBE_SIZES; i++) { + if (is_rmb) + buf_list = &lgr->rmbs[i]; + else + buf_list = &lgr->sndbufs[i]; + list_for_each_entry_safe(buf_desc, bf_desc, buf_list, + list) { + list_del(&buf_desc->list); + smc_buf_free(lgr, is_rmb, buf_desc); + } + } +} + +static void smc_lgr_free_bufs(struct smc_link_group *lgr) +{ + /* free send buffers */ + __smc_lgr_free_bufs(lgr, false); + /* free rmbs */ + __smc_lgr_free_bufs(lgr, true); +} + +/* won't be freed until no one accesses to lgr anymore */ +static void __smc_lgr_free(struct smc_link_group *lgr) +{ + smc_lgr_free_bufs(lgr); + if (lgr->is_smcd) { + if (!atomic_dec_return(&lgr->smcd->lgr_cnt)) + wake_up(&lgr->smcd->lgrs_deleted); + } else { + smc_wr_free_lgr_mem(lgr); + if (!atomic_dec_return(&lgr_cnt)) + wake_up(&lgrs_deleted); + } + kfree(lgr); +} + +/* remove a link group */ +static void smc_lgr_free(struct smc_link_group *lgr) +{ + int i; + + if (!lgr->is_smcd) { + mutex_lock(&lgr->llc_conf_mutex); + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (lgr->lnk[i].state != SMC_LNK_UNUSED) + smcr_link_clear(&lgr->lnk[i], false); + } + mutex_unlock(&lgr->llc_conf_mutex); + smc_llc_lgr_clear(lgr); + } + + destroy_workqueue(lgr->tx_wq); + if (lgr->is_smcd) { + smc_ism_put_vlan(lgr->smcd, lgr->vlan_id); + put_device(&lgr->smcd->dev); + } + smc_lgr_put(lgr); /* theoretically last lgr_put */ +} + +void smc_lgr_hold(struct smc_link_group *lgr) +{ + refcount_inc(&lgr->refcnt); +} + +void smc_lgr_put(struct smc_link_group *lgr) +{ + if (refcount_dec_and_test(&lgr->refcnt)) + __smc_lgr_free(lgr); +} + +static void smc_sk_wake_ups(struct smc_sock *smc) +{ + smc->sk.sk_write_space(&smc->sk); + smc->sk.sk_data_ready(&smc->sk); + smc->sk.sk_state_change(&smc->sk); +} + +/* kill a connection */ +static void smc_conn_kill(struct smc_connection *conn, bool soft) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + + if (conn->lgr->is_smcd && conn->lgr->peer_shutdown) + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + else + smc_close_abort(conn); + conn->killed = 1; + smc->sk.sk_err = ECONNABORTED; + smc_sk_wake_ups(smc); + if (conn->lgr->is_smcd) { + smc_ism_unset_conn(conn); + if (soft) + tasklet_kill(&conn->rx_tsklet); + else + tasklet_unlock_wait(&conn->rx_tsklet); + } else { + smc_cdc_wait_pend_tx_wr(conn); + } + smc_lgr_unregister_conn(conn); + smc_close_active_abort(smc); +} + +static void smc_lgr_cleanup(struct smc_link_group *lgr) +{ + if (lgr->is_smcd) { + smc_ism_signal_shutdown(lgr); + } else { + u32 rsn = lgr->llc_termination_rsn; + + if (!rsn) + rsn = SMC_LLC_DEL_PROG_INIT_TERM; + smc_llc_send_link_delete_all(lgr, false, rsn); + 
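/* after notifying the peer, mark all sendable links inactive and wake up any LLC waiters */ + 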
smcr_lgr_link_deactivate_all(lgr); + } +} + +/* terminate link group + * @soft: true if link group shutdown can take its time + * false if immediate link group shutdown is required + */ +static void __smc_lgr_terminate(struct smc_link_group *lgr, bool soft) +{ + struct smc_connection *conn; + struct smc_sock *smc; + struct rb_node *node; + + if (lgr->terminating) + return; /* lgr already terminating */ + /* cancel free_work sync, will terminate when lgr->freeing is set */ + cancel_delayed_work(&lgr->free_work); + lgr->terminating = 1; + + /* kill remaining link group connections */ + read_lock_bh(&lgr->conns_lock); + node = rb_first(&lgr->conns_all); + while (node) { + read_unlock_bh(&lgr->conns_lock); + conn = rb_entry(node, struct smc_connection, alert_node); + smc = container_of(conn, struct smc_sock, conn); + sock_hold(&smc->sk); /* sock_put below */ + lock_sock(&smc->sk); + smc_conn_kill(conn, soft); + release_sock(&smc->sk); + sock_put(&smc->sk); /* sock_hold above */ + read_lock_bh(&lgr->conns_lock); + node = rb_first(&lgr->conns_all); + } + read_unlock_bh(&lgr->conns_lock); + smc_lgr_cleanup(lgr); + smc_lgr_free(lgr); +} + +/* unlink link group and schedule termination */ +void smc_lgr_terminate_sched(struct smc_link_group *lgr) +{ + spinlock_t *lgr_lock; + + smc_lgr_list_head(lgr, &lgr_lock); + spin_lock_bh(lgr_lock); + if (list_empty(&lgr->list) || lgr->terminating || lgr->freeing) { + spin_unlock_bh(lgr_lock); + return; /* lgr already terminating */ + } + list_del_init(&lgr->list); + lgr->freeing = 1; + spin_unlock_bh(lgr_lock); + schedule_work(&lgr->terminate_work); +} + +/* Called when peer lgr shutdown (regularly or abnormally) is received */ +void smc_smcd_terminate(struct smcd_dev *dev, u64 peer_gid, unsigned short vlan) +{ + struct smc_link_group *lgr, *l; + LIST_HEAD(lgr_free_list); + + /* run common cleanup function and build free list */ + spin_lock_bh(&dev->lgr_lock); + list_for_each_entry_safe(lgr, l, &dev->lgr_list, list) { + if ((!peer_gid || lgr->peer_gid == peer_gid) && + (vlan == VLAN_VID_MASK || lgr->vlan_id == vlan)) { + if (peer_gid) /* peer triggered termination */ + lgr->peer_shutdown = 1; + list_move(&lgr->list, &lgr_free_list); + lgr->freeing = 1; + } + } + spin_unlock_bh(&dev->lgr_lock); + + /* cancel the regular free workers and actually free lgrs */ + list_for_each_entry_safe(lgr, l, &lgr_free_list, list) { + list_del_init(&lgr->list); + schedule_work(&lgr->terminate_work); + } +} + +/* Called when an SMCD device is removed or the smc module is unloaded */ +void smc_smcd_terminate_all(struct smcd_dev *smcd) +{ + struct smc_link_group *lgr, *lg; + LIST_HEAD(lgr_free_list); + + spin_lock_bh(&smcd->lgr_lock); + list_splice_init(&smcd->lgr_list, &lgr_free_list); + list_for_each_entry(lgr, &lgr_free_list, list) + lgr->freeing = 1; + spin_unlock_bh(&smcd->lgr_lock); + + list_for_each_entry_safe(lgr, lg, &lgr_free_list, list) { + list_del_init(&lgr->list); + __smc_lgr_terminate(lgr, false); + } + + if (atomic_read(&smcd->lgr_cnt)) + wait_event(smcd->lgrs_deleted, !atomic_read(&smcd->lgr_cnt)); +} + +/* Called when an SMCR device is removed or the smc module is unloaded. + * If smcibdev is given, all SMCR link groups using this device are terminated. + * If smcibdev is NULL, all SMCR link groups are terminated. 
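+ * In both cases the function waits until the affected link or link group counters have dropped to zero.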
+ */ +void smc_smcr_terminate_all(struct smc_ib_device *smcibdev) +{ + struct smc_link_group *lgr, *lg; + LIST_HEAD(lgr_free_list); + int i; + + spin_lock_bh(&smc_lgr_list.lock); + if (!smcibdev) { + list_splice_init(&smc_lgr_list.list, &lgr_free_list); + list_for_each_entry(lgr, &lgr_free_list, list) + lgr->freeing = 1; + } else { + list_for_each_entry_safe(lgr, lg, &smc_lgr_list.list, list) { + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (lgr->lnk[i].smcibdev == smcibdev) + smcr_link_down_cond_sched(&lgr->lnk[i]); + } + } + } + spin_unlock_bh(&smc_lgr_list.lock); + + list_for_each_entry_safe(lgr, lg, &lgr_free_list, list) { + list_del_init(&lgr->list); + smc_llc_set_termination_rsn(lgr, SMC_LLC_DEL_OP_INIT_TERM); + __smc_lgr_terminate(lgr, false); + } + + if (smcibdev) { + if (atomic_read(&smcibdev->lnk_cnt)) + wait_event(smcibdev->lnks_deleted, + !atomic_read(&smcibdev->lnk_cnt)); + } else { + if (atomic_read(&lgr_cnt)) + wait_event(lgrs_deleted, !atomic_read(&lgr_cnt)); + } +} + +/* set new lgr type and clear all asymmetric link tagging */ +void smcr_lgr_set_type(struct smc_link_group *lgr, enum smc_lgr_type new_type) +{ + char *lgr_type = ""; + int i; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) + if (smc_link_usable(&lgr->lnk[i])) + lgr->lnk[i].link_is_asym = false; + if (lgr->type == new_type) + return; + lgr->type = new_type; + + switch (lgr->type) { + case SMC_LGR_NONE: + lgr_type = "NONE"; + break; + case SMC_LGR_SINGLE: + lgr_type = "SINGLE"; + break; + case SMC_LGR_SYMMETRIC: + lgr_type = "SYMMETRIC"; + break; + case SMC_LGR_ASYMMETRIC_PEER: + lgr_type = "ASYMMETRIC_PEER"; + break; + case SMC_LGR_ASYMMETRIC_LOCAL: + lgr_type = "ASYMMETRIC_LOCAL"; + break; + } + pr_warn_ratelimited("smc: SMC-R lg %*phN net %llu state changed: " + "%s, pnetid %.16s\n", SMC_LGR_ID_SIZE, &lgr->id, + lgr->net->net_cookie, lgr_type, lgr->pnet_id); +} + +/* set new lgr type and tag a link as asymmetric */ +void smcr_lgr_set_type_asym(struct smc_link_group *lgr, + enum smc_lgr_type new_type, int asym_lnk_idx) +{ + smcr_lgr_set_type(lgr, new_type); + lgr->lnk[asym_lnk_idx].link_is_asym = true; +} + +/* abort connection, abort_work scheduled from tasklet context */ +static void smc_conn_abort_work(struct work_struct *work) +{ + struct smc_connection *conn = container_of(work, + struct smc_connection, + abort_work); + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + + lock_sock(&smc->sk); + smc_conn_kill(conn, true); + release_sock(&smc->sk); + sock_put(&smc->sk); /* sock_hold done by schedulers of abort_work */ +} + +void smcr_port_add(struct smc_ib_device *smcibdev, u8 ibport) +{ + struct smc_link_group *lgr, *n; + + spin_lock_bh(&smc_lgr_list.lock); + list_for_each_entry_safe(lgr, n, &smc_lgr_list.list, list) { + struct smc_link *link; + + if (strncmp(smcibdev->pnetid[ibport - 1], lgr->pnet_id, + SMC_MAX_PNETID_LEN) || + lgr->type == SMC_LGR_SYMMETRIC || + lgr->type == SMC_LGR_ASYMMETRIC_PEER || + !rdma_dev_access_netns(smcibdev->ibdev, lgr->net)) + continue; + + /* trigger local add link processing */ + link = smc_llc_usable_link(lgr); + if (link) + smc_llc_add_link_local(link); + } + spin_unlock_bh(&smc_lgr_list.lock); +} + +/* link is down - switch connections to alternate link, + * must be called under lgr->llc_conf_mutex lock + */ +static void smcr_link_down(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + struct smc_link *to_lnk; + int del_link_id; + + if (!lgr || lnk->state == SMC_LNK_UNUSED || list_empty(&lgr->list)) + return; + + to_lnk = 
smc_switch_conns(lgr, lnk, true); + if (!to_lnk) { /* no backup link available */ + smcr_link_clear(lnk, true); + return; + } + smcr_lgr_set_type(lgr, SMC_LGR_SINGLE); + del_link_id = lnk->link_id; + + if (lgr->role == SMC_SERV) { + /* trigger local delete link processing */ + smc_llc_srv_delete_link_local(to_lnk, del_link_id); + } else { + if (lgr->llc_flow_lcl.type != SMC_LLC_FLOW_NONE) { + /* another llc task is ongoing */ + mutex_unlock(&lgr->llc_conf_mutex); + wait_event_timeout(lgr->llc_flow_waiter, + (list_empty(&lgr->list) || + lgr->llc_flow_lcl.type == SMC_LLC_FLOW_NONE), + SMC_LLC_WAIT_TIME); + mutex_lock(&lgr->llc_conf_mutex); + } + if (!list_empty(&lgr->list)) { + smc_llc_send_delete_link(to_lnk, del_link_id, + SMC_LLC_REQ, true, + SMC_LLC_DEL_LOST_PATH); + smcr_link_clear(lnk, true); + } + wake_up(&lgr->llc_flow_waiter); /* wake up next waiter */ + } +} + +/* must be called under lgr->llc_conf_mutex lock */ +void smcr_link_down_cond(struct smc_link *lnk) +{ + if (smc_link_downing(&lnk->state)) { + trace_smcr_link_down(lnk, __builtin_return_address(0)); + smcr_link_down(lnk); + } +} + +/* will get the lgr->llc_conf_mutex lock */ +void smcr_link_down_cond_sched(struct smc_link *lnk) +{ + if (smc_link_downing(&lnk->state)) { + trace_smcr_link_down(lnk, __builtin_return_address(0)); + schedule_work(&lnk->link_down_wrk); + } +} + +void smcr_port_err(struct smc_ib_device *smcibdev, u8 ibport) +{ + struct smc_link_group *lgr, *n; + int i; + + list_for_each_entry_safe(lgr, n, &smc_lgr_list.list, list) { + if (strncmp(smcibdev->pnetid[ibport - 1], lgr->pnet_id, + SMC_MAX_PNETID_LEN)) + continue; /* lgr is not affected */ + if (list_empty(&lgr->list)) + continue; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + struct smc_link *lnk = &lgr->lnk[i]; + + if (smc_link_usable(lnk) && + lnk->smcibdev == smcibdev && lnk->ibport == ibport) + smcr_link_down_cond_sched(lnk); + } + } +} + +static void smc_link_down_work(struct work_struct *work) +{ + struct smc_link *link = container_of(work, struct smc_link, + link_down_wrk); + struct smc_link_group *lgr = link->lgr; + + if (list_empty(&lgr->list)) + return; + wake_up_all(&lgr->llc_msg_waiter); + mutex_lock(&lgr->llc_conf_mutex); + smcr_link_down(link); + mutex_unlock(&lgr->llc_conf_mutex); +} + +static int smc_vlan_by_tcpsk_walk(struct net_device *lower_dev, + struct netdev_nested_priv *priv) +{ + unsigned short *vlan_id = (unsigned short *)priv->data; + + if (is_vlan_dev(lower_dev)) { + *vlan_id = vlan_dev_vlan_id(lower_dev); + return 1; + } + + return 0; +} + +/* Determine vlan of internal TCP socket. 
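+ * Sets ini->vlan_id to the VLAN id of the egress device of the clcsock
+ * route (or of a VLAN device among its lower devices); it stays 0 when
+ * no VLAN device is involved. Returns 0, -ENOTCONN or -ENODEV.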
*/ +int smc_vlan_by_tcpsk(struct socket *clcsock, struct smc_init_info *ini) +{ + struct dst_entry *dst = sk_dst_get(clcsock->sk); + struct netdev_nested_priv priv; + struct net_device *ndev; + int rc = 0; + + ini->vlan_id = 0; + if (!dst) { + rc = -ENOTCONN; + goto out; + } + if (!dst->dev) { + rc = -ENODEV; + goto out_rel; + } + + ndev = dst->dev; + if (is_vlan_dev(ndev)) { + ini->vlan_id = vlan_dev_vlan_id(ndev); + goto out_rel; + } + + priv.data = (void *)&ini->vlan_id; + rtnl_lock(); + netdev_walk_all_lower_dev(ndev, smc_vlan_by_tcpsk_walk, &priv); + rtnl_unlock(); + +out_rel: + dst_release(dst); +out: + return rc; +} + +static bool smcr_lgr_match(struct smc_link_group *lgr, u8 smcr_version, + u8 peer_systemid[], + u8 peer_gid[], + u8 peer_mac_v1[], + enum smc_lgr_role role, u32 clcqpn, + struct net *net) +{ + struct smc_link *lnk; + int i; + + if (memcmp(lgr->peer_systemid, peer_systemid, SMC_SYSTEMID_LEN) || + lgr->role != role) + return false; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + lnk = &lgr->lnk[i]; + + if (!smc_link_active(lnk)) + continue; + /* use verbs API to check netns, instead of lgr->net */ + if (!rdma_dev_access_netns(lnk->smcibdev->ibdev, net)) + return false; + if ((lgr->role == SMC_SERV || lnk->peer_qpn == clcqpn) && + !memcmp(lnk->peer_gid, peer_gid, SMC_GID_SIZE) && + (smcr_version == SMC_V2 || + !memcmp(lnk->peer_mac, peer_mac_v1, ETH_ALEN))) + return true; + } + return false; +} + +static bool smcd_lgr_match(struct smc_link_group *lgr, + struct smcd_dev *smcismdev, u64 peer_gid) +{ + return lgr->peer_gid == peer_gid && lgr->smcd == smcismdev; +} + +/* create a new SMC connection (and a new link group if necessary) */ +int smc_conn_create(struct smc_sock *smc, struct smc_init_info *ini) +{ + struct smc_connection *conn = &smc->conn; + struct net *net = sock_net(&smc->sk); + struct list_head *lgr_list; + struct smc_link_group *lgr; + enum smc_lgr_role role; + spinlock_t *lgr_lock; + int rc = 0; + + lgr_list = ini->is_smcd ? &ini->ism_dev[ini->ism_selected]->lgr_list : + &smc_lgr_list.list; + lgr_lock = ini->is_smcd ? &ini->ism_dev[ini->ism_selected]->lgr_lock : + &smc_lgr_list.lock; + ini->first_contact_local = 1; + role = smc->listen_smc ? SMC_SERV : SMC_CLNT; + if (role == SMC_CLNT && ini->first_contact_peer) + /* create new link group as well */ + goto create; + + /* determine if an existing link group can be reused */ + spin_lock_bh(lgr_lock); + list_for_each_entry(lgr, lgr_list, list) { + write_lock_bh(&lgr->conns_lock); + if ((ini->is_smcd ? + smcd_lgr_match(lgr, ini->ism_dev[ini->ism_selected], + ini->ism_peer_gid[ini->ism_selected]) : + smcr_lgr_match(lgr, ini->smcr_version, + ini->peer_systemid, + ini->peer_gid, ini->peer_mac, role, + ini->ib_clcqpn, net)) && + !lgr->sync_err && + (ini->smcd_version == SMC_V2 || + lgr->vlan_id == ini->vlan_id) && + (role == SMC_CLNT || ini->is_smcd || + (lgr->conns_num < SMC_RMBS_PER_LGR_MAX && + !bitmap_full(lgr->rtokens_used_mask, SMC_RMBS_PER_LGR_MAX)))) { + /* link group found */ + ini->first_contact_local = 0; + conn->lgr = lgr; + rc = smc_lgr_register_conn(conn, false); + write_unlock_bh(&lgr->conns_lock); + if (!rc && delayed_work_pending(&lgr->free_work)) + cancel_delayed_work(&lgr->free_work); + break; + } + write_unlock_bh(&lgr->conns_lock); + } + spin_unlock_bh(lgr_lock); + if (rc) + return rc; + + if (role == SMC_CLNT && !ini->first_contact_peer && + ini->first_contact_local) { + /* Server reuses a link group, but Client wants to start + * a new one + * send out_of_sync decline, reason synchr. 
error + */ + return SMC_CLC_DECL_SYNCERR; + } + +create: + if (ini->first_contact_local) { + rc = smc_lgr_create(smc, ini); + if (rc) + goto out; + lgr = conn->lgr; + write_lock_bh(&lgr->conns_lock); + rc = smc_lgr_register_conn(conn, true); + write_unlock_bh(&lgr->conns_lock); + if (rc) { + smc_lgr_cleanup_early(lgr); + goto out; + } + } + smc_lgr_hold(conn->lgr); /* lgr_put in smc_conn_free() */ + if (!conn->lgr->is_smcd) + smcr_link_hold(conn->lnk); /* link_put in smc_conn_free() */ + conn->freed = 0; + conn->local_tx_ctrl.common.type = SMC_CDC_MSG_TYPE; + conn->local_tx_ctrl.len = SMC_WR_TX_SIZE; + conn->urg_state = SMC_URG_READ; + init_waitqueue_head(&conn->cdc_pend_tx_wq); + INIT_WORK(&smc->conn.abort_work, smc_conn_abort_work); + if (ini->is_smcd) { + conn->rx_off = sizeof(struct smcd_cdc_msg); + smcd_cdc_rx_init(conn); /* init tasklet for this conn */ + } else { + conn->rx_off = 0; + } +#ifndef KERNEL_HAS_ATOMIC64 + spin_lock_init(&conn->acurs_lock); +#endif + +out: + return rc; +} + +#define SMCD_DMBE_SIZES 6 /* 0 -> 16KB, 1 -> 32KB, .. 6 -> 1MB */ +#define SMCR_RMBE_SIZES 5 /* 0 -> 16KB, 1 -> 32KB, .. 5 -> 512KB */ + +/* convert the RMB size into the compressed notation (minimum 16K, see + * SMCD/R_DMBE_SIZES. + * In contrast to plain ilog2, this rounds towards the next power of 2, + * so the socket application gets at least its desired sndbuf / rcvbuf size. + */ +static u8 smc_compress_bufsize(int size, bool is_smcd, bool is_rmb) +{ + const unsigned int max_scat = SG_MAX_SINGLE_ALLOC * PAGE_SIZE; + u8 compressed; + + if (size <= SMC_BUF_MIN_SIZE) + return 0; + + size = (size - 1) >> 14; /* convert to 16K multiple */ + compressed = min_t(u8, ilog2(size) + 1, + is_smcd ? SMCD_DMBE_SIZES : SMCR_RMBE_SIZES); + + if (!is_smcd && is_rmb) + /* RMBs are backed by & limited to max size of scatterlists */ + compressed = min_t(u8, compressed, ilog2(max_scat >> 14)); + + return compressed; +} + +/* convert the RMB size from compressed notation into integer */ +int smc_uncompress_bufsize(u8 compressed) +{ + u32 size; + + size = 0x00000001 << (((int)compressed) + 14); + return (int)size; +} + +/* try to reuse a sndbuf or rmb description slot for a certain + * buffer size; if not available, return NULL + */ +static struct smc_buf_desc *smc_buf_get_slot(int compressed_bufsize, + struct rw_semaphore *lock, + struct list_head *buf_list) +{ + struct smc_buf_desc *buf_slot; + + down_read(lock); + list_for_each_entry(buf_slot, buf_list, list) { + if (cmpxchg(&buf_slot->used, 0, 1) == 0) { + up_read(lock); + return buf_slot; + } + } + up_read(lock); + return NULL; +} + +/* one of the conditions for announcing a receiver's current window size is + * that it "results in a minimum increase in the window size of 10% of the + * receive buffer space" [RFC7609] + */ +static inline int smc_rmb_wnd_update_limit(int rmbe_size) +{ + return max_t(int, rmbe_size / 10, SOCK_MIN_SNDBUF / 2); +} + +/* map an buf to a link */ +static int smcr_buf_map_link(struct smc_buf_desc *buf_desc, bool is_rmb, + struct smc_link *lnk) +{ + int rc, i, nents, offset, buf_size, size, access_flags; + struct scatterlist *sg; + void *buf; + + if (buf_desc->is_map_ib[lnk->link_idx]) + return 0; + + if (buf_desc->is_vm) { + buf = buf_desc->cpu_addr; + buf_size = buf_desc->len; + offset = offset_in_page(buf_desc->cpu_addr); + nents = PAGE_ALIGN(buf_size + offset) / PAGE_SIZE; + } else { + nents = 1; + } + + rc = sg_alloc_table(&buf_desc->sgt[lnk->link_idx], nents, GFP_KERNEL); + if (rc) + return rc; + + if (buf_desc->is_vm) { + /* 
virtually contiguous buffer */ + for_each_sg(buf_desc->sgt[lnk->link_idx].sgl, sg, nents, i) { + size = min_t(int, PAGE_SIZE - offset, buf_size); + sg_set_page(sg, vmalloc_to_page(buf), size, offset); + buf += size / sizeof(*buf); + buf_size -= size; + offset = 0; + } + } else { + /* physically contiguous buffer */ + sg_set_buf(buf_desc->sgt[lnk->link_idx].sgl, + buf_desc->cpu_addr, buf_desc->len); + } + + /* map sg table to DMA address */ + rc = smc_ib_buf_map_sg(lnk, buf_desc, + is_rmb ? DMA_FROM_DEVICE : DMA_TO_DEVICE); + /* SMC protocol depends on mapping to one DMA address only */ + if (rc != nents) { + rc = -EAGAIN; + goto free_table; + } + + buf_desc->is_dma_need_sync |= + smc_ib_is_sg_need_sync(lnk, buf_desc) << lnk->link_idx; + + if (is_rmb || buf_desc->is_vm) { + /* create a new memory region for the RMB or vzalloced sndbuf */ + access_flags = is_rmb ? + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : + IB_ACCESS_LOCAL_WRITE; + + rc = smc_ib_get_memory_region(lnk->roce_pd, access_flags, + buf_desc, lnk->link_idx); + if (rc) + goto buf_unmap; + smc_ib_sync_sg_for_device(lnk, buf_desc, + is_rmb ? DMA_FROM_DEVICE : DMA_TO_DEVICE); + } + buf_desc->is_map_ib[lnk->link_idx] = true; + return 0; + +buf_unmap: + smc_ib_buf_unmap_sg(lnk, buf_desc, + is_rmb ? DMA_FROM_DEVICE : DMA_TO_DEVICE); +free_table: + sg_free_table(&buf_desc->sgt[lnk->link_idx]); + return rc; +} + +/* register a new buf on IB device, rmb or vzalloced sndbuf + * must be called under lgr->llc_conf_mutex lock + */ +int smcr_link_reg_buf(struct smc_link *link, struct smc_buf_desc *buf_desc) +{ + if (list_empty(&link->lgr->list)) + return -ENOLINK; + if (!buf_desc->is_reg_mr[link->link_idx]) { + /* register memory region for new buf */ + if (buf_desc->is_vm) + buf_desc->mr[link->link_idx]->iova = + (uintptr_t)buf_desc->cpu_addr; + if (smc_wr_reg_send(link, buf_desc->mr[link->link_idx])) { + buf_desc->is_reg_err = true; + return -EFAULT; + } + buf_desc->is_reg_mr[link->link_idx] = true; + } + return 0; +} + +static int _smcr_buf_map_lgr(struct smc_link *lnk, struct rw_semaphore *lock, + struct list_head *lst, bool is_rmb) +{ + struct smc_buf_desc *buf_desc, *bf; + int rc = 0; + + down_write(lock); + list_for_each_entry_safe(buf_desc, bf, lst, list) { + if (!buf_desc->used) + continue; + rc = smcr_buf_map_link(buf_desc, is_rmb, lnk); + if (rc) + goto out; + } +out: + up_write(lock); + return rc; +} + +/* map all used buffers of lgr for a new link */ +int smcr_buf_map_lgr(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + int i, rc = 0; + + for (i = 0; i < SMC_RMBE_SIZES; i++) { + rc = _smcr_buf_map_lgr(lnk, &lgr->rmbs_lock, + &lgr->rmbs[i], true); + if (rc) + return rc; + rc = _smcr_buf_map_lgr(lnk, &lgr->sndbufs_lock, + &lgr->sndbufs[i], false); + if (rc) + return rc; + } + return 0; +} + +/* register all used buffers of lgr for a new link, + * must be called under lgr->llc_conf_mutex lock + */ +int smcr_buf_reg_lgr(struct smc_link *lnk) +{ + struct smc_link_group *lgr = lnk->lgr; + struct smc_buf_desc *buf_desc, *bf; + int i, rc = 0; + + /* reg all RMBs for a new link */ + down_write(&lgr->rmbs_lock); + for (i = 0; i < SMC_RMBE_SIZES; i++) { + list_for_each_entry_safe(buf_desc, bf, &lgr->rmbs[i], list) { + if (!buf_desc->used) + continue; + rc = smcr_link_reg_buf(lnk, buf_desc); + if (rc) { + up_write(&lgr->rmbs_lock); + return rc; + } + } + } + up_write(&lgr->rmbs_lock); + + if (lgr->buf_type == SMCR_PHYS_CONT_BUFS) + return rc; + + /* reg all vzalloced sndbufs for a new link */ + 
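+	/* physically contiguous sndbufs have no memory region and need no
+	 * registration; only the vzalloced (is_vm) sndbufs are registered
+	 */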
down_write(&lgr->sndbufs_lock); + for (i = 0; i < SMC_RMBE_SIZES; i++) { + list_for_each_entry_safe(buf_desc, bf, &lgr->sndbufs[i], list) { + if (!buf_desc->used || !buf_desc->is_vm) + continue; + rc = smcr_link_reg_buf(lnk, buf_desc); + if (rc) { + up_write(&lgr->sndbufs_lock); + return rc; + } + } + } + up_write(&lgr->sndbufs_lock); + return rc; +} + +static struct smc_buf_desc *smcr_new_buf_create(struct smc_link_group *lgr, + bool is_rmb, int bufsize) +{ + struct smc_buf_desc *buf_desc; + + /* try to alloc a new buffer */ + buf_desc = kzalloc(sizeof(*buf_desc), GFP_KERNEL); + if (!buf_desc) + return ERR_PTR(-ENOMEM); + + switch (lgr->buf_type) { + case SMCR_PHYS_CONT_BUFS: + case SMCR_MIXED_BUFS: + buf_desc->order = get_order(bufsize); + buf_desc->pages = alloc_pages(GFP_KERNEL | __GFP_NOWARN | + __GFP_NOMEMALLOC | __GFP_COMP | + __GFP_NORETRY | __GFP_ZERO, + buf_desc->order); + if (buf_desc->pages) { + buf_desc->cpu_addr = + (void *)page_address(buf_desc->pages); + buf_desc->len = bufsize; + buf_desc->is_vm = false; + break; + } + if (lgr->buf_type == SMCR_PHYS_CONT_BUFS) + goto out; + fallthrough; // try virtually continguous buf + case SMCR_VIRT_CONT_BUFS: + buf_desc->order = get_order(bufsize); + buf_desc->cpu_addr = vzalloc(PAGE_SIZE << buf_desc->order); + if (!buf_desc->cpu_addr) + goto out; + buf_desc->pages = NULL; + buf_desc->len = bufsize; + buf_desc->is_vm = true; + break; + } + return buf_desc; + +out: + kfree(buf_desc); + return ERR_PTR(-EAGAIN); +} + +/* map buf_desc on all usable links, + * unused buffers stay mapped as long as the link is up + */ +static int smcr_buf_map_usable_links(struct smc_link_group *lgr, + struct smc_buf_desc *buf_desc, bool is_rmb) +{ + int i, rc = 0, cnt = 0; + + /* protect against parallel link reconfiguration */ + mutex_lock(&lgr->llc_conf_mutex); + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + struct smc_link *lnk = &lgr->lnk[i]; + + if (!smc_link_usable(lnk)) + continue; + if (smcr_buf_map_link(buf_desc, is_rmb, lnk)) { + rc = -ENOMEM; + goto out; + } + cnt++; + } +out: + mutex_unlock(&lgr->llc_conf_mutex); + if (!rc && !cnt) + rc = -EINVAL; + return rc; +} + +static struct smc_buf_desc *smcd_new_buf_create(struct smc_link_group *lgr, + bool is_dmb, int bufsize) +{ + struct smc_buf_desc *buf_desc; + int rc; + + /* try to alloc a new DMB */ + buf_desc = kzalloc(sizeof(*buf_desc), GFP_KERNEL); + if (!buf_desc) + return ERR_PTR(-ENOMEM); + if (is_dmb) { + rc = smc_ism_register_dmb(lgr, bufsize, buf_desc); + if (rc) { + kfree(buf_desc); + if (rc == -ENOMEM) + return ERR_PTR(-EAGAIN); + if (rc == -ENOSPC) + return ERR_PTR(-ENOSPC); + return ERR_PTR(-EIO); + } + buf_desc->pages = virt_to_page(buf_desc->cpu_addr); + /* CDC header stored in buf. 
So, pretend it was smaller */ + buf_desc->len = bufsize - sizeof(struct smcd_cdc_msg); + } else { + buf_desc->cpu_addr = kzalloc(bufsize, GFP_KERNEL | + __GFP_NOWARN | __GFP_NORETRY | + __GFP_NOMEMALLOC); + if (!buf_desc->cpu_addr) { + kfree(buf_desc); + return ERR_PTR(-EAGAIN); + } + buf_desc->len = bufsize; + } + return buf_desc; +} + +static int __smc_buf_create(struct smc_sock *smc, bool is_smcd, bool is_rmb) +{ + struct smc_buf_desc *buf_desc = ERR_PTR(-ENOMEM); + struct smc_connection *conn = &smc->conn; + struct smc_link_group *lgr = conn->lgr; + struct list_head *buf_list; + int bufsize, bufsize_comp; + struct rw_semaphore *lock; /* lock buffer list */ + bool is_dgraded = false; + + if (is_rmb) + /* use socket recv buffer size (w/o overhead) as start value */ + bufsize = smc->sk.sk_rcvbuf / 2; + else + /* use socket send buffer size (w/o overhead) as start value */ + bufsize = smc->sk.sk_sndbuf / 2; + + for (bufsize_comp = smc_compress_bufsize(bufsize, is_smcd, is_rmb); + bufsize_comp >= 0; bufsize_comp--) { + if (is_rmb) { + lock = &lgr->rmbs_lock; + buf_list = &lgr->rmbs[bufsize_comp]; + } else { + lock = &lgr->sndbufs_lock; + buf_list = &lgr->sndbufs[bufsize_comp]; + } + bufsize = smc_uncompress_bufsize(bufsize_comp); + + /* check for reusable slot in the link group */ + buf_desc = smc_buf_get_slot(bufsize_comp, lock, buf_list); + if (buf_desc) { + buf_desc->is_dma_need_sync = 0; + SMC_STAT_RMB_SIZE(smc, is_smcd, is_rmb, bufsize); + SMC_STAT_BUF_REUSE(smc, is_smcd, is_rmb); + break; /* found reusable slot */ + } + + if (is_smcd) + buf_desc = smcd_new_buf_create(lgr, is_rmb, bufsize); + else + buf_desc = smcr_new_buf_create(lgr, is_rmb, bufsize); + + if (PTR_ERR(buf_desc) == -ENOMEM) + break; + if (IS_ERR(buf_desc)) { + if (!is_dgraded) { + is_dgraded = true; + SMC_STAT_RMB_DOWNGRADED(smc, is_smcd, is_rmb); + } + continue; + } + + SMC_STAT_RMB_ALLOC(smc, is_smcd, is_rmb); + SMC_STAT_RMB_SIZE(smc, is_smcd, is_rmb, bufsize); + buf_desc->used = 1; + down_write(lock); + list_add(&buf_desc->list, buf_list); + up_write(lock); + break; /* found */ + } + + if (IS_ERR(buf_desc)) + return PTR_ERR(buf_desc); + + if (!is_smcd) { + if (smcr_buf_map_usable_links(lgr, buf_desc, is_rmb)) { + smcr_buf_unuse(buf_desc, is_rmb, lgr); + return -ENOMEM; + } + } + + if (is_rmb) { + conn->rmb_desc = buf_desc; + conn->rmbe_size_comp = bufsize_comp; + smc->sk.sk_rcvbuf = bufsize * 2; + atomic_set(&conn->bytes_to_rcv, 0); + conn->rmbe_update_limit = + smc_rmb_wnd_update_limit(buf_desc->len); + if (is_smcd) + smc_ism_set_conn(conn); /* map RMB/smcd_dev to conn */ + } else { + conn->sndbuf_desc = buf_desc; + smc->sk.sk_sndbuf = bufsize * 2; + atomic_set(&conn->sndbuf_space, bufsize); + } + return 0; +} + +void smc_sndbuf_sync_sg_for_device(struct smc_connection *conn) +{ + if (!conn->sndbuf_desc->is_dma_need_sync) + return; + if (!smc_conn_lgr_valid(conn) || conn->lgr->is_smcd || + !smc_link_active(conn->lnk)) + return; + smc_ib_sync_sg_for_device(conn->lnk, conn->sndbuf_desc, DMA_TO_DEVICE); +} + +void smc_rmb_sync_sg_for_cpu(struct smc_connection *conn) +{ + int i; + + if (!conn->rmb_desc->is_dma_need_sync) + return; + if (!smc_conn_lgr_valid(conn) || conn->lgr->is_smcd) + return; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_active(&conn->lgr->lnk[i])) + continue; + smc_ib_sync_sg_for_cpu(&conn->lgr->lnk[i], conn->rmb_desc, + DMA_FROM_DEVICE); + } +} + +/* create the send and receive buffer for an SMC socket; + * receive buffers are called RMBs; + * (even though the SMC protocol allows 
more than one RMB-element per RMB, + * the Linux implementation uses just one RMB-element per RMB, i.e. uses an + * extra RMB for every connection in a link group + */ +int smc_buf_create(struct smc_sock *smc, bool is_smcd) +{ + int rc; + + /* create send buffer */ + rc = __smc_buf_create(smc, is_smcd, false); + if (rc) + return rc; + /* create rmb */ + rc = __smc_buf_create(smc, is_smcd, true); + if (rc) { + down_write(&smc->conn.lgr->sndbufs_lock); + list_del(&smc->conn.sndbuf_desc->list); + up_write(&smc->conn.lgr->sndbufs_lock); + smc_buf_free(smc->conn.lgr, false, smc->conn.sndbuf_desc); + smc->conn.sndbuf_desc = NULL; + } + return rc; +} + +static inline int smc_rmb_reserve_rtoken_idx(struct smc_link_group *lgr) +{ + int i; + + for_each_clear_bit(i, lgr->rtokens_used_mask, SMC_RMBS_PER_LGR_MAX) { + if (!test_and_set_bit(i, lgr->rtokens_used_mask)) + return i; + } + return -ENOSPC; +} + +static int smc_rtoken_find_by_link(struct smc_link_group *lgr, int lnk_idx, + u32 rkey) +{ + int i; + + for (i = 0; i < SMC_RMBS_PER_LGR_MAX; i++) { + if (test_bit(i, lgr->rtokens_used_mask) && + lgr->rtokens[i][lnk_idx].rkey == rkey) + return i; + } + return -ENOENT; +} + +/* set rtoken for a new link to an existing rmb */ +void smc_rtoken_set(struct smc_link_group *lgr, int link_idx, int link_idx_new, + __be32 nw_rkey_known, __be64 nw_vaddr, __be32 nw_rkey) +{ + int rtok_idx; + + rtok_idx = smc_rtoken_find_by_link(lgr, link_idx, ntohl(nw_rkey_known)); + if (rtok_idx == -ENOENT) + return; + lgr->rtokens[rtok_idx][link_idx_new].rkey = ntohl(nw_rkey); + lgr->rtokens[rtok_idx][link_idx_new].dma_addr = be64_to_cpu(nw_vaddr); +} + +/* set rtoken for a new link whose link_id is given */ +void smc_rtoken_set2(struct smc_link_group *lgr, int rtok_idx, int link_id, + __be64 nw_vaddr, __be32 nw_rkey) +{ + u64 dma_addr = be64_to_cpu(nw_vaddr); + u32 rkey = ntohl(nw_rkey); + bool found = false; + int link_idx; + + for (link_idx = 0; link_idx < SMC_LINKS_PER_LGR_MAX; link_idx++) { + if (lgr->lnk[link_idx].link_id == link_id) { + found = true; + break; + } + } + if (!found) + return; + lgr->rtokens[rtok_idx][link_idx].rkey = rkey; + lgr->rtokens[rtok_idx][link_idx].dma_addr = dma_addr; +} + +/* add a new rtoken from peer */ +int smc_rtoken_add(struct smc_link *lnk, __be64 nw_vaddr, __be32 nw_rkey) +{ + struct smc_link_group *lgr = smc_get_lgr(lnk); + u64 dma_addr = be64_to_cpu(nw_vaddr); + u32 rkey = ntohl(nw_rkey); + int i; + + for (i = 0; i < SMC_RMBS_PER_LGR_MAX; i++) { + if (lgr->rtokens[i][lnk->link_idx].rkey == rkey && + lgr->rtokens[i][lnk->link_idx].dma_addr == dma_addr && + test_bit(i, lgr->rtokens_used_mask)) { + /* already in list */ + return i; + } + } + i = smc_rmb_reserve_rtoken_idx(lgr); + if (i < 0) + return i; + lgr->rtokens[i][lnk->link_idx].rkey = rkey; + lgr->rtokens[i][lnk->link_idx].dma_addr = dma_addr; + return i; +} + +/* delete an rtoken from all links */ +int smc_rtoken_delete(struct smc_link *lnk, __be32 nw_rkey) +{ + struct smc_link_group *lgr = smc_get_lgr(lnk); + u32 rkey = ntohl(nw_rkey); + int i, j; + + for (i = 0; i < SMC_RMBS_PER_LGR_MAX; i++) { + if (lgr->rtokens[i][lnk->link_idx].rkey == rkey && + test_bit(i, lgr->rtokens_used_mask)) { + for (j = 0; j < SMC_LINKS_PER_LGR_MAX; j++) { + lgr->rtokens[i][j].rkey = 0; + lgr->rtokens[i][j].dma_addr = 0; + } + clear_bit(i, lgr->rtokens_used_mask); + return 0; + } + } + return -ENOENT; +} + +/* save rkey and dma_addr received from peer during clc handshake */ +int smc_rmb_rtoken_handling(struct smc_connection *conn, + struct smc_link 
*lnk, + struct smc_clc_msg_accept_confirm *clc) +{ + conn->rtoken_idx = smc_rtoken_add(lnk, clc->r0.rmb_dma_addr, + clc->r0.rmb_rkey); + if (conn->rtoken_idx < 0) + return conn->rtoken_idx; + return 0; +} + +static void smc_core_going_away(void) +{ + struct smc_ib_device *smcibdev; + struct smcd_dev *smcd; + + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(smcibdev, &smc_ib_devices.list, list) { + int i; + + for (i = 0; i < SMC_MAX_PORTS; i++) + set_bit(i, smcibdev->ports_going_away); + } + mutex_unlock(&smc_ib_devices.mutex); + + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(smcd, &smcd_dev_list.list, list) { + smcd->going_away = 1; + } + mutex_unlock(&smcd_dev_list.mutex); +} + +/* Clean up all SMC link groups */ +static void smc_lgrs_shutdown(void) +{ + struct smcd_dev *smcd; + + smc_core_going_away(); + + smc_smcr_terminate_all(NULL); + + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(smcd, &smcd_dev_list.list, list) + smc_smcd_terminate_all(smcd); + mutex_unlock(&smcd_dev_list.mutex); +} + +static int smc_core_reboot_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + smc_lgrs_shutdown(); + smc_ib_unregister_client(); + return 0; +} + +static struct notifier_block smc_reboot_notifier = { + .notifier_call = smc_core_reboot_event, +}; + +int __init smc_core_init(void) +{ + return register_reboot_notifier(&smc_reboot_notifier); +} + +/* Called (from smc_exit) when module is removed */ +void smc_core_exit(void) +{ + unregister_reboot_notifier(&smc_reboot_notifier); + smc_lgrs_shutdown(); +} diff --git a/net/smc/smc_core.h b/net/smc/smc_core.h new file mode 100644 index 000000000..6051d9227 --- /dev/null +++ b/net/smc/smc_core.h @@ -0,0 +1,566 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Definitions for SMC Connections, Link Groups and Links + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#ifndef _SMC_CORE_H +#define _SMC_CORE_H + +#include +#include +#include +#include +#include + +#include "smc.h" +#include "smc_ib.h" + +#define SMC_RMBS_PER_LGR_MAX 255 /* max. # of RMBs per link group */ + +struct smc_lgr_list { /* list of link group definition */ + struct list_head list; + spinlock_t lock; /* protects list of link groups */ + u32 num; /* unique link group number */ +}; + +enum smc_lgr_role { /* possible roles of a link group */ + SMC_CLNT, /* client */ + SMC_SERV /* server */ +}; + +enum smc_link_state { /* possible states of a link */ + SMC_LNK_UNUSED, /* link is unused */ + SMC_LNK_INACTIVE, /* link is inactive */ + SMC_LNK_ACTIVATING, /* link is being activated */ + SMC_LNK_ACTIVE, /* link is active */ +}; + +#define SMC_WR_BUF_SIZE 48 /* size of work request buffer */ +#define SMC_WR_BUF_V2_SIZE 8192 /* size of v2 work request buffer */ + +struct smc_wr_buf { + u8 raw[SMC_WR_BUF_SIZE]; +}; + +struct smc_wr_v2_buf { + u8 raw[SMC_WR_BUF_V2_SIZE]; +}; + +#define SMC_WR_REG_MR_WAIT_TIME (5 * HZ)/* wait time for ib_wr_reg_mr result */ + +enum smc_wr_reg_state { + POSTED, /* ib_wr_reg_mr request posted */ + CONFIRMED, /* ib_wr_reg_mr response: successful */ + FAILED /* ib_wr_reg_mr response: failure */ +}; + +struct smc_rdma_sge { /* sges for RDMA writes */ + struct ib_sge wr_tx_rdma_sge[SMC_IB_MAX_SEND_SGE]; +}; + +#define SMC_MAX_RDMA_WRITES 2 /* max. 
# of RDMA writes per + * message send + */ + +struct smc_rdma_sges { /* sges per message send */ + struct smc_rdma_sge tx_rdma_sge[SMC_MAX_RDMA_WRITES]; +}; + +struct smc_rdma_wr { /* work requests per message + * send + */ + struct ib_rdma_wr wr_tx_rdma[SMC_MAX_RDMA_WRITES]; +}; + +#define SMC_LGR_ID_SIZE 4 + +struct smc_link { + struct smc_ib_device *smcibdev; /* ib-device */ + u8 ibport; /* port - values 1 | 2 */ + struct ib_pd *roce_pd; /* IB protection domain, + * unique for every RoCE QP + */ + struct ib_qp *roce_qp; /* IB queue pair */ + struct ib_qp_attr qp_attr; /* IB queue pair attributes */ + + struct smc_wr_buf *wr_tx_bufs; /* WR send payload buffers */ + struct ib_send_wr *wr_tx_ibs; /* WR send meta data */ + struct ib_sge *wr_tx_sges; /* WR send gather meta data */ + struct smc_rdma_sges *wr_tx_rdma_sges;/*RDMA WRITE gather meta data*/ + struct smc_rdma_wr *wr_tx_rdmas; /* WR RDMA WRITE */ + struct smc_wr_tx_pend *wr_tx_pends; /* WR send waiting for CQE */ + struct completion *wr_tx_compl; /* WR send CQE completion */ + /* above four vectors have wr_tx_cnt elements and use the same index */ + struct ib_send_wr *wr_tx_v2_ib; /* WR send v2 meta data */ + struct ib_sge *wr_tx_v2_sge; /* WR send v2 gather meta data*/ + struct smc_wr_tx_pend *wr_tx_v2_pend; /* WR send v2 waiting for CQE */ + dma_addr_t wr_tx_dma_addr; /* DMA address of wr_tx_bufs */ + dma_addr_t wr_tx_v2_dma_addr; /* DMA address of v2 tx buf*/ + atomic_long_t wr_tx_id; /* seq # of last sent WR */ + unsigned long *wr_tx_mask; /* bit mask of used indexes */ + u32 wr_tx_cnt; /* number of WR send buffers */ + wait_queue_head_t wr_tx_wait; /* wait for free WR send buf */ + atomic_t wr_tx_refcnt; /* tx refs to link */ + + struct smc_wr_buf *wr_rx_bufs; /* WR recv payload buffers */ + struct ib_recv_wr *wr_rx_ibs; /* WR recv meta data */ + struct ib_sge *wr_rx_sges; /* WR recv scatter meta data */ + /* above three vectors have wr_rx_cnt elements and use the same index */ + dma_addr_t wr_rx_dma_addr; /* DMA address of wr_rx_bufs */ + dma_addr_t wr_rx_v2_dma_addr; /* DMA address of v2 rx buf*/ + u64 wr_rx_id; /* seq # of last recv WR */ + u64 wr_rx_id_compl; /* seq # of last completed WR */ + u32 wr_rx_cnt; /* number of WR recv buffers */ + unsigned long wr_rx_tstamp; /* jiffies when last buf rx */ + wait_queue_head_t wr_rx_empty_wait; /* wait for RQ empty */ + + struct ib_reg_wr wr_reg; /* WR register memory region */ + wait_queue_head_t wr_reg_wait; /* wait for wr_reg result */ + atomic_t wr_reg_refcnt; /* reg refs to link */ + enum smc_wr_reg_state wr_reg_state; /* state of wr_reg request */ + + u8 gid[SMC_GID_SIZE];/* gid matching used vlan id*/ + u8 sgid_index; /* gid index for vlan id */ + u32 peer_qpn; /* QP number of peer */ + enum ib_mtu path_mtu; /* used mtu */ + enum ib_mtu peer_mtu; /* mtu size of peer */ + u32 psn_initial; /* QP tx initial packet seqno */ + u32 peer_psn; /* QP rx initial packet seqno */ + u8 peer_mac[ETH_ALEN]; /* = gid[8:10||13:15] */ + u8 peer_gid[SMC_GID_SIZE]; /* gid of peer*/ + u8 link_id; /* unique # within link group */ + u8 link_uid[SMC_LGR_ID_SIZE]; /* unique lnk id */ + u8 peer_link_uid[SMC_LGR_ID_SIZE]; /* peer uid */ + u8 link_idx; /* index in lgr link array */ + u8 link_is_asym; /* is link asymmetric? 
*/ + u8 clearing : 1; /* link is being cleared */ + refcount_t refcnt; /* link reference count */ + struct smc_link_group *lgr; /* parent link group */ + struct work_struct link_down_wrk; /* wrk to bring link down */ + char ibname[IB_DEVICE_NAME_MAX]; /* ib device name */ + int ndev_ifidx; /* network device ifindex */ + + enum smc_link_state state; /* state of link */ + struct delayed_work llc_testlink_wrk; /* testlink worker */ + struct completion llc_testlink_resp; /* wait for rx of testlink */ + int llc_testlink_time; /* testlink interval */ + atomic_t conn_cnt; /* connections on this link */ +}; + +/* For now we just allow one parallel link per link group. The SMC protocol + * allows more (up to 8). + */ +#define SMC_LINKS_PER_LGR_MAX 3 +#define SMC_SINGLE_LINK 0 + +/* tx/rx buffer list element for sndbufs list and rmbs list of a lgr */ +struct smc_buf_desc { + struct list_head list; + void *cpu_addr; /* virtual address of buffer */ + struct page *pages; + int len; /* length of buffer */ + u32 used; /* currently used / unused */ + union { + struct { /* SMC-R */ + struct sg_table sgt[SMC_LINKS_PER_LGR_MAX]; + /* virtual buffer */ + struct ib_mr *mr[SMC_LINKS_PER_LGR_MAX]; + /* memory region: for rmb and + * vzalloced sndbuf + * incl. rkey provided to peer + * and lkey provided to local + */ + u32 order; /* allocation order */ + + u8 is_conf_rkey; + /* confirm_rkey done */ + u8 is_reg_mr[SMC_LINKS_PER_LGR_MAX]; + /* mem region registered */ + u8 is_map_ib[SMC_LINKS_PER_LGR_MAX]; + /* mem region mapped to lnk */ + u8 is_dma_need_sync; + u8 is_reg_err; + /* buffer registration err */ + u8 is_vm; + /* virtually contiguous */ + }; + struct { /* SMC-D */ + unsigned short sba_idx; + /* SBA index number */ + u64 token; + /* DMB token number */ + dma_addr_t dma_addr; + /* DMA address */ + }; + }; +}; + +struct smc_rtoken { /* address/key of remote RMB */ + u64 dma_addr; + u32 rkey; +}; + +#define SMC_BUF_MIN_SIZE 16384 /* minimum size of an RMB */ +#define SMC_RMBE_SIZES 16 /* number of distinct RMBE sizes */ +/* theoretically, the RFC states that largest size would be 512K, + * i.e. 
compressed 5 and thus 6 sizes (0..5), despite + * struct smc_clc_msg_accept_confirm.rmbe_size being a 4 bit value (0..15) + */ + +struct smcd_dev; + +enum smc_lgr_type { /* redundancy state of lgr */ + SMC_LGR_NONE, /* no active links, lgr to be deleted */ + SMC_LGR_SINGLE, /* 1 active RNIC on each peer */ + SMC_LGR_SYMMETRIC, /* 2 active RNICs on each peer */ + SMC_LGR_ASYMMETRIC_PEER, /* local has 2, peer 1 active RNICs */ + SMC_LGR_ASYMMETRIC_LOCAL, /* local has 1, peer 2 active RNICs */ +}; + +enum smcr_buf_type { /* types of SMC-R sndbufs and RMBs */ + SMCR_PHYS_CONT_BUFS = 0, + SMCR_VIRT_CONT_BUFS = 1, + SMCR_MIXED_BUFS = 2, +}; + +enum smc_llc_flowtype { + SMC_LLC_FLOW_NONE = 0, + SMC_LLC_FLOW_ADD_LINK = 2, + SMC_LLC_FLOW_DEL_LINK = 4, + SMC_LLC_FLOW_REQ_ADD_LINK = 5, + SMC_LLC_FLOW_RKEY = 6, +}; + +struct smc_llc_qentry; + +struct smc_llc_flow { + enum smc_llc_flowtype type; + struct smc_llc_qentry *qentry; +}; + +struct smc_link_group { + struct list_head list; + struct rb_root conns_all; /* connection tree */ + rwlock_t conns_lock; /* protects conns_all */ + unsigned int conns_num; /* current # of connections */ + unsigned short vlan_id; /* vlan id of link group */ + + struct list_head sndbufs[SMC_RMBE_SIZES];/* tx buffers */ + struct rw_semaphore sndbufs_lock; /* protects tx buffers */ + struct list_head rmbs[SMC_RMBE_SIZES]; /* rx buffers */ + struct rw_semaphore rmbs_lock; /* protects rx buffers */ + + u8 id[SMC_LGR_ID_SIZE]; /* unique lgr id */ + struct delayed_work free_work; /* delayed freeing of an lgr */ + struct work_struct terminate_work; /* abnormal lgr termination */ + struct workqueue_struct *tx_wq; /* wq for conn. tx workers */ + u8 sync_err : 1; /* lgr no longer fits to peer */ + u8 terminating : 1;/* lgr is terminating */ + u8 freeing : 1; /* lgr is being freed */ + + refcount_t refcnt; /* lgr reference count */ + bool is_smcd; /* SMC-R or SMC-D */ + u8 smc_version; + u8 negotiated_eid[SMC_MAX_EID_LEN]; + u8 peer_os; /* peer operating system */ + u8 peer_smc_release; + u8 peer_hostname[SMC_MAX_HOSTNAME_LEN]; + union { + struct { /* SMC-R */ + enum smc_lgr_role role; + /* client or server */ + struct smc_link lnk[SMC_LINKS_PER_LGR_MAX]; + /* smc link */ + struct smc_wr_v2_buf *wr_rx_buf_v2; + /* WR v2 recv payload buffer */ + struct smc_wr_v2_buf *wr_tx_buf_v2; + /* WR v2 send payload buffer */ + char peer_systemid[SMC_SYSTEMID_LEN]; + /* unique system_id of peer */ + struct smc_rtoken rtokens[SMC_RMBS_PER_LGR_MAX] + [SMC_LINKS_PER_LGR_MAX]; + /* remote addr/key pairs */ + DECLARE_BITMAP(rtokens_used_mask, SMC_RMBS_PER_LGR_MAX); + /* used rtoken elements */ + u8 next_link_id; + enum smc_lgr_type type; + enum smcr_buf_type buf_type; + /* redundancy state */ + u8 pnet_id[SMC_MAX_PNETID_LEN + 1]; + /* pnet id of this lgr */ + struct list_head llc_event_q; + /* queue for llc events */ + spinlock_t llc_event_q_lock; + /* protects llc_event_q */ + struct mutex llc_conf_mutex; + /* protects lgr reconfig. 
*/ + struct work_struct llc_add_link_work; + struct work_struct llc_del_link_work; + struct work_struct llc_event_work; + /* llc event worker */ + wait_queue_head_t llc_flow_waiter; + /* w4 next llc event */ + wait_queue_head_t llc_msg_waiter; + /* w4 next llc msg */ + struct smc_llc_flow llc_flow_lcl; + /* llc local control field */ + struct smc_llc_flow llc_flow_rmt; + /* llc remote control field */ + struct smc_llc_qentry *delayed_event; + /* arrived when flow active */ + spinlock_t llc_flow_lock; + /* protects llc flow */ + int llc_testlink_time; + /* link keep alive time */ + u32 llc_termination_rsn; + /* rsn code for termination */ + u8 nexthop_mac[ETH_ALEN]; + u8 uses_gateway; + __be32 saddr; + /* net namespace */ + struct net *net; + }; + struct { /* SMC-D */ + u64 peer_gid; + /* Peer GID (remote) */ + struct smcd_dev *smcd; + /* ISM device for VLAN reg. */ + u8 peer_shutdown : 1; + /* peer triggered shutdownn */ + }; + }; +}; + +struct smc_clc_msg_local; + +#define GID_LIST_SIZE 2 + +struct smc_gidlist { + u8 len; + u8 list[GID_LIST_SIZE][SMC_GID_SIZE]; +}; + +struct smc_init_info_smcrv2 { + /* Input fields */ + __be32 saddr; + struct sock *clc_sk; + __be32 daddr; + + /* Output fields when saddr is set */ + struct smc_ib_device *ib_dev_v2; + u8 ib_port_v2; + u8 ib_gid_v2[SMC_GID_SIZE]; + + /* Additional output fields when clc_sk and daddr is set as well */ + u8 uses_gateway; + u8 nexthop_mac[ETH_ALEN]; + + struct smc_gidlist gidlist; +}; + +struct smc_init_info { + u8 is_smcd; + u8 smc_type_v1; + u8 smc_type_v2; + u8 first_contact_peer; + u8 first_contact_local; + unsigned short vlan_id; + u32 rc; + u8 negotiated_eid[SMC_MAX_EID_LEN]; + /* SMC-R */ + u8 smcr_version; + u8 check_smcrv2; + u8 peer_gid[SMC_GID_SIZE]; + u8 peer_mac[ETH_ALEN]; + u8 peer_systemid[SMC_SYSTEMID_LEN]; + struct smc_ib_device *ib_dev; + u8 ib_gid[SMC_GID_SIZE]; + u8 ib_port; + u32 ib_clcqpn; + struct smc_init_info_smcrv2 smcrv2; + /* SMC-D */ + u64 ism_peer_gid[SMC_MAX_ISM_DEVS + 1]; + struct smcd_dev *ism_dev[SMC_MAX_ISM_DEVS + 1]; + u16 ism_chid[SMC_MAX_ISM_DEVS + 1]; + u8 ism_offered_cnt; /* # of ISM devices offered */ + u8 ism_selected; /* index of selected ISM dev*/ + u8 smcd_version; +}; + +/* Find the connection associated with the given alert token in the link group. + * To use rbtrees we have to implement our own search core. + * Requires @conns_lock + * @token alert token to search for + * @lgr link group to search in + * Returns connection associated with token if found, NULL otherwise. + */ +static inline struct smc_connection *smc_lgr_find_conn( + u32 token, struct smc_link_group *lgr) +{ + struct smc_connection *res = NULL; + struct rb_node *node; + + node = lgr->conns_all.rb_node; + while (node) { + struct smc_connection *cur = rb_entry(node, + struct smc_connection, alert_node); + + if (cur->alert_token_local > token) { + node = node->rb_left; + } else { + if (cur->alert_token_local < token) { + node = node->rb_right; + } else { + res = cur; + break; + } + } + } + + return res; +} + +static inline bool smc_conn_lgr_valid(struct smc_connection *conn) +{ + return conn->lgr && conn->alert_token_local; +} + +/* + * Returns true if the specified link is usable. + * + * usable means the link is ready to receive RDMA messages, map memory + * on the link, etc. 
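+ * (i.e. its state is neither SMC_LNK_UNUSED nor SMC_LNK_INACTIVE).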
This doesn't ensure we are able to send RDMA messages + * on this link, if sending RDMA messages is needed, use smc_link_sendable() + */ +static inline bool smc_link_usable(struct smc_link *lnk) +{ + if (lnk->state == SMC_LNK_UNUSED || lnk->state == SMC_LNK_INACTIVE) + return false; + return true; +} + +/* + * Returns true if the specified link is ready to receive AND send RDMA + * messages. + * + * For the client side in first contact, the underlying QP may still in + * RESET or RTR when the link state is ACTIVATING, checks in smc_link_usable() + * is not strong enough. For those places that need to send any CDC or LLC + * messages, use smc_link_sendable(), otherwise, use smc_link_usable() instead + */ +static inline bool smc_link_sendable(struct smc_link *lnk) +{ + return smc_link_usable(lnk) && + lnk->qp_attr.cur_qp_state == IB_QPS_RTS; +} + +static inline bool smc_link_active(struct smc_link *lnk) +{ + return lnk->state == SMC_LNK_ACTIVE; +} + +static inline void smc_gid_be16_convert(__u8 *buf, u8 *gid_raw) +{ + sprintf(buf, "%04x:%04x:%04x:%04x:%04x:%04x:%04x:%04x", + be16_to_cpu(((__be16 *)gid_raw)[0]), + be16_to_cpu(((__be16 *)gid_raw)[1]), + be16_to_cpu(((__be16 *)gid_raw)[2]), + be16_to_cpu(((__be16 *)gid_raw)[3]), + be16_to_cpu(((__be16 *)gid_raw)[4]), + be16_to_cpu(((__be16 *)gid_raw)[5]), + be16_to_cpu(((__be16 *)gid_raw)[6]), + be16_to_cpu(((__be16 *)gid_raw)[7])); +} + +struct smc_pci_dev { + __u32 pci_fid; + __u16 pci_pchid; + __u16 pci_vendor; + __u16 pci_device; + __u8 pci_id[SMC_PCI_ID_STR_LEN]; +}; + +static inline void smc_set_pci_values(struct pci_dev *pci_dev, + struct smc_pci_dev *smc_dev) +{ + smc_dev->pci_vendor = pci_dev->vendor; + smc_dev->pci_device = pci_dev->device; + snprintf(smc_dev->pci_id, sizeof(smc_dev->pci_id), "%s", + pci_name(pci_dev)); +#if IS_ENABLED(CONFIG_S390) + { /* Set s390 specific PCI information */ + struct zpci_dev *zdev; + + zdev = to_zpci(pci_dev); + smc_dev->pci_fid = zdev->fid; + smc_dev->pci_pchid = zdev->pchid; + } +#endif +} + +struct smc_sock; +struct smc_clc_msg_accept_confirm; + +void smc_lgr_cleanup_early(struct smc_link_group *lgr); +void smc_lgr_terminate_sched(struct smc_link_group *lgr); +void smc_lgr_hold(struct smc_link_group *lgr); +void smc_lgr_put(struct smc_link_group *lgr); +void smcr_port_add(struct smc_ib_device *smcibdev, u8 ibport); +void smcr_port_err(struct smc_ib_device *smcibdev, u8 ibport); +void smc_smcd_terminate(struct smcd_dev *dev, u64 peer_gid, + unsigned short vlan); +void smc_smcd_terminate_all(struct smcd_dev *dev); +void smc_smcr_terminate_all(struct smc_ib_device *smcibdev); +int smc_buf_create(struct smc_sock *smc, bool is_smcd); +int smc_uncompress_bufsize(u8 compressed); +int smc_rmb_rtoken_handling(struct smc_connection *conn, struct smc_link *link, + struct smc_clc_msg_accept_confirm *clc); +int smc_rtoken_add(struct smc_link *lnk, __be64 nw_vaddr, __be32 nw_rkey); +int smc_rtoken_delete(struct smc_link *lnk, __be32 nw_rkey); +void smc_rtoken_set(struct smc_link_group *lgr, int link_idx, int link_idx_new, + __be32 nw_rkey_known, __be64 nw_vaddr, __be32 nw_rkey); +void smc_rtoken_set2(struct smc_link_group *lgr, int rtok_idx, int link_id, + __be64 nw_vaddr, __be32 nw_rkey); +void smc_sndbuf_sync_sg_for_device(struct smc_connection *conn); +void smc_rmb_sync_sg_for_cpu(struct smc_connection *conn); +int smc_vlan_by_tcpsk(struct socket *clcsock, struct smc_init_info *ini); + +void smc_conn_free(struct smc_connection *conn); +int smc_conn_create(struct smc_sock *smc, struct smc_init_info *ini); 
+void smc_lgr_schedule_free_work_fast(struct smc_link_group *lgr);
+int smc_core_init(void);
+void smc_core_exit(void);
+
+int smcr_link_init(struct smc_link_group *lgr, struct smc_link *lnk,
+		   u8 link_idx, struct smc_init_info *ini);
+void smcr_link_clear(struct smc_link *lnk, bool log);
+void smcr_link_hold(struct smc_link *lnk);
+void smcr_link_put(struct smc_link *lnk);
+void smc_switch_link_and_count(struct smc_connection *conn,
+			       struct smc_link *to_lnk);
+int smcr_buf_map_lgr(struct smc_link *lnk);
+int smcr_buf_reg_lgr(struct smc_link *lnk);
+void smcr_lgr_set_type(struct smc_link_group *lgr, enum smc_lgr_type new_type);
+void smcr_lgr_set_type_asym(struct smc_link_group *lgr,
+			    enum smc_lgr_type new_type, int asym_lnk_idx);
+int smcr_link_reg_buf(struct smc_link *link, struct smc_buf_desc *rmb_desc);
+struct smc_link *smc_switch_conns(struct smc_link_group *lgr,
+				  struct smc_link *from_lnk, bool is_dev_err);
+void smcr_link_down_cond(struct smc_link *lnk);
+void smcr_link_down_cond_sched(struct smc_link *lnk);
+int smc_nl_get_sys_info(struct sk_buff *skb, struct netlink_callback *cb);
+int smcr_nl_get_lgr(struct sk_buff *skb, struct netlink_callback *cb);
+int smcr_nl_get_link(struct sk_buff *skb, struct netlink_callback *cb);
+int smcd_nl_get_lgr(struct sk_buff *skb, struct netlink_callback *cb);
+
+static inline struct smc_link_group *smc_get_lgr(struct smc_link *link)
+{
+	return link->lgr;
+}
+#endif
diff --git a/net/smc/smc_diag.c b/net/smc/smc_diag.c
new file mode 100644
index 000000000..7a907186a
--- /dev/null
+++ b/net/smc/smc_diag.c
@@ -0,0 +1,270 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Shared Memory Communications over RDMA (SMC-R) and RoCE
+ *
+ * Monitoring SMC transport protocol sockets
+ *
+ * Copyright IBM Corp.
2016 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" + +struct smc_diag_dump_ctx { + int pos[2]; +}; + +static struct smc_diag_dump_ctx *smc_dump_context(struct netlink_callback *cb) +{ + return (struct smc_diag_dump_ctx *)cb->ctx; +} + +static void smc_diag_msg_common_fill(struct smc_diag_msg *r, struct sock *sk) +{ + struct smc_sock *smc = smc_sk(sk); + + memset(r, 0, sizeof(*r)); + r->diag_family = sk->sk_family; + sock_diag_save_cookie(sk, r->id.idiag_cookie); + if (!smc->clcsock) + return; + r->id.idiag_sport = htons(smc->clcsock->sk->sk_num); + r->id.idiag_dport = smc->clcsock->sk->sk_dport; + r->id.idiag_if = smc->clcsock->sk->sk_bound_dev_if; + if (sk->sk_protocol == SMCPROTO_SMC) { + r->id.idiag_src[0] = smc->clcsock->sk->sk_rcv_saddr; + r->id.idiag_dst[0] = smc->clcsock->sk->sk_daddr; +#if IS_ENABLED(CONFIG_IPV6) + } else if (sk->sk_protocol == SMCPROTO_SMC6) { + memcpy(&r->id.idiag_src, &smc->clcsock->sk->sk_v6_rcv_saddr, + sizeof(smc->clcsock->sk->sk_v6_rcv_saddr)); + memcpy(&r->id.idiag_dst, &smc->clcsock->sk->sk_v6_daddr, + sizeof(smc->clcsock->sk->sk_v6_daddr)); +#endif + } +} + +static int smc_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, + struct smc_diag_msg *r, + struct user_namespace *user_ns) +{ + if (nla_put_u8(skb, SMC_DIAG_SHUTDOWN, sk->sk_shutdown)) + return 1; + + r->diag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); + r->diag_inode = sock_i_ino(sk); + return 0; +} + +static int __smc_diag_dump(struct sock *sk, struct sk_buff *skb, + struct netlink_callback *cb, + const struct smc_diag_req *req, + struct nlattr *bc) +{ + struct smc_sock *smc = smc_sk(sk); + struct smc_diag_fallback fallback; + struct user_namespace *user_ns; + struct smc_diag_msg *r; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + cb->nlh->nlmsg_type, sizeof(*r), NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + + r = nlmsg_data(nlh); + smc_diag_msg_common_fill(r, sk); + r->diag_state = sk->sk_state; + if (smc->use_fallback) + r->diag_mode = SMC_DIAG_MODE_FALLBACK_TCP; + else if (smc_conn_lgr_valid(&smc->conn) && smc->conn.lgr->is_smcd) + r->diag_mode = SMC_DIAG_MODE_SMCD; + else + r->diag_mode = SMC_DIAG_MODE_SMCR; + user_ns = sk_user_ns(NETLINK_CB(cb->skb).sk); + if (smc_diag_msg_attrs_fill(sk, skb, r, user_ns)) + goto errout; + + fallback.reason = smc->fallback_rsn; + fallback.peer_diagnosis = smc->peer_diagnosis; + if (nla_put(skb, SMC_DIAG_FALLBACK, sizeof(fallback), &fallback) < 0) + goto errout; + + if ((req->diag_ext & (1 << (SMC_DIAG_CONNINFO - 1))) && + smc->conn.alert_token_local) { + struct smc_connection *conn = &smc->conn; + struct smc_diag_conninfo cinfo = { + .token = conn->alert_token_local, + .sndbuf_size = conn->sndbuf_desc ? + conn->sndbuf_desc->len : 0, + .rmbe_size = conn->rmb_desc ? 
conn->rmb_desc->len : 0, + .peer_rmbe_size = conn->peer_rmbe_size, + + .rx_prod.wrap = conn->local_rx_ctrl.prod.wrap, + .rx_prod.count = conn->local_rx_ctrl.prod.count, + .rx_cons.wrap = conn->local_rx_ctrl.cons.wrap, + .rx_cons.count = conn->local_rx_ctrl.cons.count, + + .tx_prod.wrap = conn->local_tx_ctrl.prod.wrap, + .tx_prod.count = conn->local_tx_ctrl.prod.count, + .tx_cons.wrap = conn->local_tx_ctrl.cons.wrap, + .tx_cons.count = conn->local_tx_ctrl.cons.count, + + .tx_prod_flags = + *(u8 *)&conn->local_tx_ctrl.prod_flags, + .tx_conn_state_flags = + *(u8 *)&conn->local_tx_ctrl.conn_state_flags, + .rx_prod_flags = *(u8 *)&conn->local_rx_ctrl.prod_flags, + .rx_conn_state_flags = + *(u8 *)&conn->local_rx_ctrl.conn_state_flags, + + .tx_prep.wrap = conn->tx_curs_prep.wrap, + .tx_prep.count = conn->tx_curs_prep.count, + .tx_sent.wrap = conn->tx_curs_sent.wrap, + .tx_sent.count = conn->tx_curs_sent.count, + .tx_fin.wrap = conn->tx_curs_fin.wrap, + .tx_fin.count = conn->tx_curs_fin.count, + }; + + if (nla_put(skb, SMC_DIAG_CONNINFO, sizeof(cinfo), &cinfo) < 0) + goto errout; + } + + if (smc_conn_lgr_valid(&smc->conn) && !smc->conn.lgr->is_smcd && + (req->diag_ext & (1 << (SMC_DIAG_LGRINFO - 1))) && + !list_empty(&smc->conn.lgr->list)) { + struct smc_link *link = smc->conn.lnk; + + struct smc_diag_lgrinfo linfo = { + .role = smc->conn.lgr->role, + .lnk[0].ibport = link->ibport, + .lnk[0].link_id = link->link_id, + }; + + memcpy(linfo.lnk[0].ibname, link->smcibdev->ibdev->name, + sizeof(link->smcibdev->ibdev->name)); + smc_gid_be16_convert(linfo.lnk[0].gid, link->gid); + smc_gid_be16_convert(linfo.lnk[0].peer_gid, link->peer_gid); + + if (nla_put(skb, SMC_DIAG_LGRINFO, sizeof(linfo), &linfo) < 0) + goto errout; + } + if (smc_conn_lgr_valid(&smc->conn) && smc->conn.lgr->is_smcd && + (req->diag_ext & (1 << (SMC_DIAG_DMBINFO - 1))) && + !list_empty(&smc->conn.lgr->list) && smc->conn.rmb_desc) { + struct smc_connection *conn = &smc->conn; + struct smcd_diag_dmbinfo dinfo; + + memset(&dinfo, 0, sizeof(dinfo)); + + dinfo.linkid = *((u32 *)conn->lgr->id); + dinfo.peer_gid = conn->lgr->peer_gid; + dinfo.my_gid = conn->lgr->smcd->local_gid; + dinfo.token = conn->rmb_desc->token; + dinfo.peer_token = conn->peer_token; + + if (nla_put(skb, SMC_DIAG_DMBINFO, sizeof(dinfo), &dinfo) < 0) + goto errout; + } + + nlmsg_end(skb, nlh); + return 0; + +errout: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int smc_diag_dump_proto(struct proto *prot, struct sk_buff *skb, + struct netlink_callback *cb, int p_type) +{ + struct smc_diag_dump_ctx *cb_ctx = smc_dump_context(cb); + struct net *net = sock_net(skb->sk); + int snum = cb_ctx->pos[p_type]; + struct nlattr *bc = NULL; + struct hlist_head *head; + int rc = 0, num = 0; + struct sock *sk; + + read_lock(&prot->h.smc_hash->lock); + head = &prot->h.smc_hash->ht; + if (hlist_empty(head)) + goto out; + + sk_for_each(sk, head) { + if (!net_eq(sock_net(sk), net)) + continue; + if (num < snum) + goto next; + rc = __smc_diag_dump(sk, skb, cb, nlmsg_data(cb->nlh), bc); + if (rc < 0) + goto out; +next: + num++; + } + +out: + read_unlock(&prot->h.smc_hash->lock); + cb_ctx->pos[p_type] = num; + return rc; +} + +static int smc_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int rc = 0; + + rc = smc_diag_dump_proto(&smc_proto, skb, cb, SMCPROTO_SMC); + if (!rc) + smc_diag_dump_proto(&smc_proto6, skb, cb, SMCPROTO_SMC6); + return skb->len; +} + +static int smc_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + struct net *net = 
sock_net(skb->sk); + + if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY && + h->nlmsg_flags & NLM_F_DUMP) { + { + struct netlink_dump_control c = { + .dump = smc_diag_dump, + .min_dump_alloc = SKB_WITH_OVERHEAD(32768), + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } + } + return 0; +} + +static const struct sock_diag_handler smc_diag_handler = { + .family = AF_SMC, + .dump = smc_diag_handler_dump, +}; + +static int __init smc_diag_init(void) +{ + return sock_diag_register(&smc_diag_handler); +} + +static void __exit smc_diag_exit(void) +{ + sock_diag_unregister(&smc_diag_handler); +} + +module_init(smc_diag_init); +module_exit(smc_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 43 /* AF_SMC */); +MODULE_ALIAS_GENL_FAMILY(SMCR_GENL_FAMILY_NAME); diff --git a/net/smc/smc_ib.c b/net/smc/smc_ib.c new file mode 100644 index 000000000..ace861173 --- /dev/null +++ b/net/smc/smc_ib.c @@ -0,0 +1,1018 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * IB infrastructure: + * Establish SMC-R as an Infiniband Client to be notified about added and + * removed IB devices of type RDMA. + * Determine device and port characteristics for these IB devices. + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "smc_pnet.h" +#include "smc_ib.h" +#include "smc_core.h" +#include "smc_wr.h" +#include "smc.h" +#include "smc_netlink.h" + +#define SMC_MAX_CQE 32766 /* max. # of completion queue elements */ + +#define SMC_QP_MIN_RNR_TIMER 5 +#define SMC_QP_TIMEOUT 15 /* 4096 * 2 ** timeout usec */ +#define SMC_QP_RETRY_CNT 7 /* 7: infinite */ +#define SMC_QP_RNR_RETRY 7 /* 7: infinite */ + +struct smc_ib_devices smc_ib_devices = { /* smc-registered ib devices */ + .mutex = __MUTEX_INITIALIZER(smc_ib_devices.mutex), + .list = LIST_HEAD_INIT(smc_ib_devices.list), +}; + +u8 local_systemid[SMC_SYSTEMID_LEN]; /* unique system identifier */ + +static int smc_ib_modify_qp_init(struct smc_link *lnk) +{ + struct ib_qp_attr qp_attr; + + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IB_QPS_INIT; + qp_attr.pkey_index = 0; + qp_attr.port_num = lnk->ibport; + qp_attr.qp_access_flags = IB_ACCESS_LOCAL_WRITE + | IB_ACCESS_REMOTE_WRITE; + return ib_modify_qp(lnk->roce_qp, &qp_attr, + IB_QP_STATE | IB_QP_PKEY_INDEX | + IB_QP_ACCESS_FLAGS | IB_QP_PORT); +} + +static int smc_ib_modify_qp_rtr(struct smc_link *lnk) +{ + enum ib_qp_attr_mask qp_attr_mask = + IB_QP_STATE | IB_QP_AV | IB_QP_PATH_MTU | IB_QP_DEST_QPN | + IB_QP_RQ_PSN | IB_QP_MAX_DEST_RD_ATOMIC | IB_QP_MIN_RNR_TIMER; + struct ib_qp_attr qp_attr; + u8 hop_lim = 1; + + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IB_QPS_RTR; + qp_attr.path_mtu = min(lnk->path_mtu, lnk->peer_mtu); + qp_attr.ah_attr.type = RDMA_AH_ATTR_TYPE_ROCE; + rdma_ah_set_port_num(&qp_attr.ah_attr, lnk->ibport); + if (lnk->lgr->smc_version == SMC_V2 && lnk->lgr->uses_gateway) + hop_lim = IPV6_DEFAULT_HOPLIMIT; + rdma_ah_set_grh(&qp_attr.ah_attr, NULL, 0, lnk->sgid_index, hop_lim, 0); + rdma_ah_set_dgid_raw(&qp_attr.ah_attr, lnk->peer_gid); + if (lnk->lgr->smc_version == SMC_V2 && lnk->lgr->uses_gateway) + memcpy(&qp_attr.ah_attr.roce.dmac, lnk->lgr->nexthop_mac, + sizeof(lnk->lgr->nexthop_mac)); + else + memcpy(&qp_attr.ah_attr.roce.dmac, lnk->peer_mac, + sizeof(lnk->peer_mac)); + qp_attr.dest_qp_num = lnk->peer_qpn; + qp_attr.rq_psn = 
lnk->peer_psn; /* starting receive packet seq # */ + qp_attr.max_dest_rd_atomic = 1; /* max # of resources for incoming + * requests + */ + qp_attr.min_rnr_timer = SMC_QP_MIN_RNR_TIMER; + + return ib_modify_qp(lnk->roce_qp, &qp_attr, qp_attr_mask); +} + +int smc_ib_modify_qp_rts(struct smc_link *lnk) +{ + struct ib_qp_attr qp_attr; + + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IB_QPS_RTS; + qp_attr.timeout = SMC_QP_TIMEOUT; /* local ack timeout */ + qp_attr.retry_cnt = SMC_QP_RETRY_CNT; /* retry count */ + qp_attr.rnr_retry = SMC_QP_RNR_RETRY; /* RNR retries, 7=infinite */ + qp_attr.sq_psn = lnk->psn_initial; /* starting send packet seq # */ + qp_attr.max_rd_atomic = 1; /* # of outstanding RDMA reads and + * atomic ops allowed + */ + return ib_modify_qp(lnk->roce_qp, &qp_attr, + IB_QP_STATE | IB_QP_TIMEOUT | IB_QP_RETRY_CNT | + IB_QP_SQ_PSN | IB_QP_RNR_RETRY | + IB_QP_MAX_QP_RD_ATOMIC); +} + +int smc_ib_modify_qp_error(struct smc_link *lnk) +{ + struct ib_qp_attr qp_attr; + + memset(&qp_attr, 0, sizeof(qp_attr)); + qp_attr.qp_state = IB_QPS_ERR; + return ib_modify_qp(lnk->roce_qp, &qp_attr, IB_QP_STATE); +} + +int smc_ib_ready_link(struct smc_link *lnk) +{ + struct smc_link_group *lgr = smc_get_lgr(lnk); + int rc = 0; + + rc = smc_ib_modify_qp_init(lnk); + if (rc) + goto out; + + rc = smc_ib_modify_qp_rtr(lnk); + if (rc) + goto out; + smc_wr_remember_qp_attr(lnk); + rc = ib_req_notify_cq(lnk->smcibdev->roce_cq_recv, + IB_CQ_SOLICITED_MASK); + if (rc) + goto out; + rc = smc_wr_rx_post_init(lnk); + if (rc) + goto out; + smc_wr_remember_qp_attr(lnk); + + if (lgr->role == SMC_SERV) { + rc = smc_ib_modify_qp_rts(lnk); + if (rc) + goto out; + smc_wr_remember_qp_attr(lnk); + } +out: + return rc; +} + +static int smc_ib_fill_mac(struct smc_ib_device *smcibdev, u8 ibport) +{ + const struct ib_gid_attr *attr; + int rc; + + attr = rdma_get_gid_attr(smcibdev->ibdev, ibport, 0); + if (IS_ERR(attr)) + return -ENODEV; + + rc = rdma_read_gid_l2_fields(attr, NULL, smcibdev->mac[ibport - 1]); + rdma_put_gid_attr(attr); + return rc; +} + +/* Create an identifier unique for this instance of SMC-R. + * The MAC-address of the first active registered IB device + * plus a random 2-byte number is used to create this identifier. + * This name is delivered to the peer during connection initialization. 
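+ * Layout: bytes 0-1 hold the random number (smc_ib_init_local_systemid()),
+ * bytes 2-7 the MAC address (smc_ib_define_local_systemid()).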
+ */ +static inline void smc_ib_define_local_systemid(struct smc_ib_device *smcibdev, + u8 ibport) +{ + memcpy(&local_systemid[2], &smcibdev->mac[ibport - 1], + sizeof(smcibdev->mac[ibport - 1])); +} + +bool smc_ib_is_valid_local_systemid(void) +{ + return !is_zero_ether_addr(&local_systemid[2]); +} + +static void smc_ib_init_local_systemid(void) +{ + get_random_bytes(&local_systemid[0], 2); +} + +bool smc_ib_port_active(struct smc_ib_device *smcibdev, u8 ibport) +{ + return smcibdev->pattr[ibport - 1].state == IB_PORT_ACTIVE; +} + +int smc_ib_find_route(struct net *net, __be32 saddr, __be32 daddr, + u8 nexthop_mac[], u8 *uses_gateway) +{ + struct neighbour *neigh = NULL; + struct rtable *rt = NULL; + struct flowi4 fl4 = { + .saddr = saddr, + .daddr = daddr + }; + + if (daddr == cpu_to_be32(INADDR_NONE)) + goto out; + rt = ip_route_output_flow(net, &fl4, NULL); + if (IS_ERR(rt)) + goto out; + if (rt->rt_uses_gateway && rt->rt_gw_family != AF_INET) + goto out; + neigh = rt->dst.ops->neigh_lookup(&rt->dst, NULL, &fl4.daddr); + if (neigh) { + memcpy(nexthop_mac, neigh->ha, ETH_ALEN); + *uses_gateway = rt->rt_uses_gateway; + return 0; + } +out: + return -ENOENT; +} + +static int smc_ib_determine_gid_rcu(const struct net_device *ndev, + const struct ib_gid_attr *attr, + u8 gid[], u8 *sgid_index, + struct smc_init_info_smcrv2 *smcrv2) +{ + if (!smcrv2 && attr->gid_type == IB_GID_TYPE_ROCE) { + if (gid) + memcpy(gid, &attr->gid, SMC_GID_SIZE); + if (sgid_index) + *sgid_index = attr->index; + return 0; + } + if (smcrv2 && attr->gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP && + smc_ib_gid_to_ipv4((u8 *)&attr->gid) != cpu_to_be32(INADDR_NONE)) { + struct in_device *in_dev = __in_dev_get_rcu(ndev); + struct net *net = dev_net(ndev); + const struct in_ifaddr *ifa; + bool subnet_match = false; + + if (!in_dev) + goto out; + in_dev_for_each_ifa_rcu(ifa, in_dev) { + if (!inet_ifa_match(smcrv2->saddr, ifa)) + continue; + subnet_match = true; + break; + } + if (!subnet_match) + goto out; + if (smcrv2->daddr && smc_ib_find_route(net, smcrv2->saddr, + smcrv2->daddr, + smcrv2->nexthop_mac, + &smcrv2->uses_gateway)) + goto out; + + if (gid) + memcpy(gid, &attr->gid, SMC_GID_SIZE); + if (sgid_index) + *sgid_index = attr->index; + return 0; + } +out: + return -ENODEV; +} + +/* determine the gid for an ib-device port and vlan id */ +int smc_ib_determine_gid(struct smc_ib_device *smcibdev, u8 ibport, + unsigned short vlan_id, u8 gid[], u8 *sgid_index, + struct smc_init_info_smcrv2 *smcrv2) +{ + const struct ib_gid_attr *attr; + const struct net_device *ndev; + int i; + + for (i = 0; i < smcibdev->pattr[ibport - 1].gid_tbl_len; i++) { + attr = rdma_get_gid_attr(smcibdev->ibdev, ibport, i); + if (IS_ERR(attr)) + continue; + + rcu_read_lock(); + ndev = rdma_read_gid_attr_ndev_rcu(attr); + if (!IS_ERR(ndev) && + ((!vlan_id && !is_vlan_dev(ndev)) || + (vlan_id && is_vlan_dev(ndev) && + vlan_dev_vlan_id(ndev) == vlan_id))) { + if (!smc_ib_determine_gid_rcu(ndev, attr, gid, + sgid_index, smcrv2)) { + rcu_read_unlock(); + rdma_put_gid_attr(attr); + return 0; + } + } + rcu_read_unlock(); + rdma_put_gid_attr(attr); + } + return -ENODEV; +} + +/* check if gid is still defined on smcibdev */ +static bool smc_ib_check_link_gid(u8 gid[SMC_GID_SIZE], bool smcrv2, + struct smc_ib_device *smcibdev, u8 ibport) +{ + const struct ib_gid_attr *attr; + bool rc = false; + int i; + + for (i = 0; !rc && i < smcibdev->pattr[ibport - 1].gid_tbl_len; i++) { + attr = rdma_get_gid_attr(smcibdev->ibdev, ibport, i); + if (IS_ERR(attr)) + continue; + + 
rcu_read_lock(); + if ((!smcrv2 && attr->gid_type == IB_GID_TYPE_ROCE) || + (smcrv2 && attr->gid_type == IB_GID_TYPE_ROCE_UDP_ENCAP && + !(ipv6_addr_type((const struct in6_addr *)&attr->gid) + & IPV6_ADDR_LINKLOCAL))) + if (!memcmp(gid, &attr->gid, SMC_GID_SIZE)) + rc = true; + rcu_read_unlock(); + rdma_put_gid_attr(attr); + } + return rc; +} + +/* check all links if the gid is still defined on smcibdev */ +static void smc_ib_gid_check(struct smc_ib_device *smcibdev, u8 ibport) +{ + struct smc_link_group *lgr; + int i; + + spin_lock_bh(&smc_lgr_list.lock); + list_for_each_entry(lgr, &smc_lgr_list.list, list) { + if (strncmp(smcibdev->pnetid[ibport - 1], lgr->pnet_id, + SMC_MAX_PNETID_LEN)) + continue; /* lgr is not affected */ + if (list_empty(&lgr->list)) + continue; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (lgr->lnk[i].state == SMC_LNK_UNUSED || + lgr->lnk[i].smcibdev != smcibdev) + continue; + if (!smc_ib_check_link_gid(lgr->lnk[i].gid, + lgr->smc_version == SMC_V2, + smcibdev, ibport)) + smcr_port_err(smcibdev, ibport); + } + } + spin_unlock_bh(&smc_lgr_list.lock); +} + +static int smc_ib_remember_port_attr(struct smc_ib_device *smcibdev, u8 ibport) +{ + int rc; + + memset(&smcibdev->pattr[ibport - 1], 0, + sizeof(smcibdev->pattr[ibport - 1])); + rc = ib_query_port(smcibdev->ibdev, ibport, + &smcibdev->pattr[ibport - 1]); + if (rc) + goto out; + /* the SMC protocol requires specification of the RoCE MAC address */ + rc = smc_ib_fill_mac(smcibdev, ibport); + if (rc) + goto out; + if (!smc_ib_is_valid_local_systemid() && + smc_ib_port_active(smcibdev, ibport)) + /* create unique system identifier */ + smc_ib_define_local_systemid(smcibdev, ibport); +out: + return rc; +} + +/* process context wrapper for might_sleep smc_ib_remember_port_attr */ +static void smc_ib_port_event_work(struct work_struct *work) +{ + struct smc_ib_device *smcibdev = container_of( + work, struct smc_ib_device, port_event_work); + u8 port_idx; + + for_each_set_bit(port_idx, &smcibdev->port_event_mask, SMC_MAX_PORTS) { + smc_ib_remember_port_attr(smcibdev, port_idx + 1); + clear_bit(port_idx, &smcibdev->port_event_mask); + if (!smc_ib_port_active(smcibdev, port_idx + 1)) { + set_bit(port_idx, smcibdev->ports_going_away); + smcr_port_err(smcibdev, port_idx + 1); + } else { + clear_bit(port_idx, smcibdev->ports_going_away); + smcr_port_add(smcibdev, port_idx + 1); + smc_ib_gid_check(smcibdev, port_idx + 1); + } + } +} + +/* can be called in IRQ context */ +static void smc_ib_global_event_handler(struct ib_event_handler *handler, + struct ib_event *ibevent) +{ + struct smc_ib_device *smcibdev; + bool schedule = false; + u8 port_idx; + + smcibdev = container_of(handler, struct smc_ib_device, event_handler); + + switch (ibevent->event) { + case IB_EVENT_DEVICE_FATAL: + /* terminate all ports on device */ + for (port_idx = 0; port_idx < SMC_MAX_PORTS; port_idx++) { + set_bit(port_idx, &smcibdev->port_event_mask); + if (!test_and_set_bit(port_idx, + smcibdev->ports_going_away)) + schedule = true; + } + if (schedule) + schedule_work(&smcibdev->port_event_work); + break; + case IB_EVENT_PORT_ACTIVE: + port_idx = ibevent->element.port_num - 1; + if (port_idx >= SMC_MAX_PORTS) + break; + set_bit(port_idx, &smcibdev->port_event_mask); + if (test_and_clear_bit(port_idx, smcibdev->ports_going_away)) + schedule_work(&smcibdev->port_event_work); + break; + case IB_EVENT_PORT_ERR: + port_idx = ibevent->element.port_num - 1; + if (port_idx >= SMC_MAX_PORTS) + break; + set_bit(port_idx, &smcibdev->port_event_mask); + if 
(!test_and_set_bit(port_idx, smcibdev->ports_going_away)) + schedule_work(&smcibdev->port_event_work); + break; + case IB_EVENT_GID_CHANGE: + port_idx = ibevent->element.port_num - 1; + if (port_idx >= SMC_MAX_PORTS) + break; + set_bit(port_idx, &smcibdev->port_event_mask); + schedule_work(&smcibdev->port_event_work); + break; + default: + break; + } +} + +void smc_ib_dealloc_protection_domain(struct smc_link *lnk) +{ + if (lnk->roce_pd) + ib_dealloc_pd(lnk->roce_pd); + lnk->roce_pd = NULL; +} + +int smc_ib_create_protection_domain(struct smc_link *lnk) +{ + int rc; + + lnk->roce_pd = ib_alloc_pd(lnk->smcibdev->ibdev, 0); + rc = PTR_ERR_OR_ZERO(lnk->roce_pd); + if (IS_ERR(lnk->roce_pd)) + lnk->roce_pd = NULL; + return rc; +} + +static bool smcr_diag_is_dev_critical(struct smc_lgr_list *smc_lgr, + struct smc_ib_device *smcibdev) +{ + struct smc_link_group *lgr; + bool rc = false; + int i; + + spin_lock_bh(&smc_lgr->lock); + list_for_each_entry(lgr, &smc_lgr->list, list) { + if (lgr->is_smcd) + continue; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (lgr->lnk[i].state == SMC_LNK_UNUSED || + lgr->lnk[i].smcibdev != smcibdev) + continue; + if (lgr->type == SMC_LGR_SINGLE || + lgr->type == SMC_LGR_ASYMMETRIC_LOCAL) { + rc = true; + goto out; + } + } + } +out: + spin_unlock_bh(&smc_lgr->lock); + return rc; +} + +static int smc_nl_handle_dev_port(struct sk_buff *skb, + struct ib_device *ibdev, + struct smc_ib_device *smcibdev, + int port) +{ + char smc_pnet[SMC_MAX_PNETID_LEN + 1]; + struct nlattr *port_attrs; + unsigned char port_state; + int lnk_count = 0; + + port_attrs = nla_nest_start(skb, SMC_NLA_DEV_PORT + port); + if (!port_attrs) + goto errout; + + if (nla_put_u8(skb, SMC_NLA_DEV_PORT_PNET_USR, + smcibdev->pnetid_by_user[port])) + goto errattr; + memcpy(smc_pnet, &smcibdev->pnetid[port], SMC_MAX_PNETID_LEN); + smc_pnet[SMC_MAX_PNETID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_DEV_PORT_PNETID, smc_pnet)) + goto errattr; + if (nla_put_u32(skb, SMC_NLA_DEV_PORT_NETDEV, + smcibdev->ndev_ifidx[port])) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_DEV_PORT_VALID, 1)) + goto errattr; + port_state = smc_ib_port_active(smcibdev, port + 1); + if (nla_put_u8(skb, SMC_NLA_DEV_PORT_STATE, port_state)) + goto errattr; + lnk_count = atomic_read(&smcibdev->lnk_cnt_by_port[port]); + if (nla_put_u32(skb, SMC_NLA_DEV_PORT_LNK_CNT, lnk_count)) + goto errattr; + nla_nest_end(skb, port_attrs); + return 0; +errattr: + nla_nest_cancel(skb, port_attrs); +errout: + return -EMSGSIZE; +} + +static bool smc_nl_handle_pci_values(const struct smc_pci_dev *smc_pci_dev, + struct sk_buff *skb) +{ + if (nla_put_u32(skb, SMC_NLA_DEV_PCI_FID, smc_pci_dev->pci_fid)) + return false; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_CHID, smc_pci_dev->pci_pchid)) + return false; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_VENDOR, smc_pci_dev->pci_vendor)) + return false; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_DEVICE, smc_pci_dev->pci_device)) + return false; + if (nla_put_string(skb, SMC_NLA_DEV_PCI_ID, smc_pci_dev->pci_id)) + return false; + return true; +} + +static int smc_nl_handle_smcr_dev(struct smc_ib_device *smcibdev, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + char smc_ibname[IB_DEVICE_NAME_MAX]; + struct smc_pci_dev smc_pci_dev; + struct pci_dev *pci_dev; + unsigned char is_crit; + struct nlattr *attrs; + void *nlh; + int i; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_DEV_SMCR); + if (!nlh) + goto errmsg; + attrs = 
nla_nest_start(skb, SMC_GEN_DEV_SMCR); + if (!attrs) + goto errout; + is_crit = smcr_diag_is_dev_critical(&smc_lgr_list, smcibdev); + if (nla_put_u8(skb, SMC_NLA_DEV_IS_CRIT, is_crit)) + goto errattr; + if (smcibdev->ibdev->dev.parent) { + memset(&smc_pci_dev, 0, sizeof(smc_pci_dev)); + pci_dev = to_pci_dev(smcibdev->ibdev->dev.parent); + smc_set_pci_values(pci_dev, &smc_pci_dev); + if (!smc_nl_handle_pci_values(&smc_pci_dev, skb)) + goto errattr; + } + snprintf(smc_ibname, sizeof(smc_ibname), "%s", smcibdev->ibdev->name); + if (nla_put_string(skb, SMC_NLA_DEV_IB_NAME, smc_ibname)) + goto errattr; + for (i = 1; i <= SMC_MAX_PORTS; i++) { + if (!rdma_is_port_valid(smcibdev->ibdev, i)) + continue; + if (smc_nl_handle_dev_port(skb, smcibdev->ibdev, + smcibdev, i - 1)) + goto errattr; + } + + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + return 0; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + +static void smc_nl_prep_smcr_dev(struct smc_ib_devices *dev_list, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct smc_ib_device *smcibdev; + int snum = cb_ctx->pos[0]; + int num = 0; + + mutex_lock(&dev_list->mutex); + list_for_each_entry(smcibdev, &dev_list->list, list) { + if (num < snum) + goto next; + if (smc_nl_handle_smcr_dev(smcibdev, skb, cb)) + goto errout; +next: + num++; + } +errout: + mutex_unlock(&dev_list->mutex); + cb_ctx->pos[0] = num; +} + +int smcr_nl_get_device(struct sk_buff *skb, struct netlink_callback *cb) +{ + smc_nl_prep_smcr_dev(&smc_ib_devices, skb, cb); + return skb->len; +} + +static void smc_ib_qp_event_handler(struct ib_event *ibevent, void *priv) +{ + struct smc_link *lnk = (struct smc_link *)priv; + struct smc_ib_device *smcibdev = lnk->smcibdev; + u8 port_idx; + + switch (ibevent->event) { + case IB_EVENT_QP_FATAL: + case IB_EVENT_QP_ACCESS_ERR: + port_idx = ibevent->element.qp->port - 1; + if (port_idx >= SMC_MAX_PORTS) + break; + set_bit(port_idx, &smcibdev->port_event_mask); + if (!test_and_set_bit(port_idx, smcibdev->ports_going_away)) + schedule_work(&smcibdev->port_event_work); + break; + default: + break; + } +} + +void smc_ib_destroy_queue_pair(struct smc_link *lnk) +{ + if (lnk->roce_qp) + ib_destroy_qp(lnk->roce_qp); + lnk->roce_qp = NULL; +} + +/* create a queue pair within the protection domain for a link */ +int smc_ib_create_queue_pair(struct smc_link *lnk) +{ + int sges_per_buf = (lnk->lgr->smc_version == SMC_V2) ? 2 : 1; + struct ib_qp_init_attr qp_attr = { + .event_handler = smc_ib_qp_event_handler, + .qp_context = lnk, + .send_cq = lnk->smcibdev->roce_cq_send, + .recv_cq = lnk->smcibdev->roce_cq_recv, + .srq = NULL, + .cap = { + /* include unsolicited rdma_writes as well, + * there are max. 
2 RDMA_WRITE per 1 WR_SEND + */ + .max_send_wr = SMC_WR_BUF_CNT * 3, + .max_recv_wr = SMC_WR_BUF_CNT * 3, + .max_send_sge = SMC_IB_MAX_SEND_SGE, + .max_recv_sge = sges_per_buf, + .max_inline_data = 0, + }, + .sq_sig_type = IB_SIGNAL_REQ_WR, + .qp_type = IB_QPT_RC, + }; + int rc; + + lnk->roce_qp = ib_create_qp(lnk->roce_pd, &qp_attr); + rc = PTR_ERR_OR_ZERO(lnk->roce_qp); + if (IS_ERR(lnk->roce_qp)) + lnk->roce_qp = NULL; + else + smc_wr_remember_qp_attr(lnk); + return rc; +} + +void smc_ib_put_memory_region(struct ib_mr *mr) +{ + ib_dereg_mr(mr); +} + +static int smc_ib_map_mr_sg(struct smc_buf_desc *buf_slot, u8 link_idx) +{ + unsigned int offset = 0; + int sg_num; + + /* map the largest prefix of a dma mapped SG list */ + sg_num = ib_map_mr_sg(buf_slot->mr[link_idx], + buf_slot->sgt[link_idx].sgl, + buf_slot->sgt[link_idx].orig_nents, + &offset, PAGE_SIZE); + + return sg_num; +} + +/* Allocate a memory region and map the dma mapped SG list of buf_slot */ +int smc_ib_get_memory_region(struct ib_pd *pd, int access_flags, + struct smc_buf_desc *buf_slot, u8 link_idx) +{ + if (buf_slot->mr[link_idx]) + return 0; /* already done */ + + buf_slot->mr[link_idx] = + ib_alloc_mr(pd, IB_MR_TYPE_MEM_REG, 1 << buf_slot->order); + if (IS_ERR(buf_slot->mr[link_idx])) { + int rc; + + rc = PTR_ERR(buf_slot->mr[link_idx]); + buf_slot->mr[link_idx] = NULL; + return rc; + } + + if (smc_ib_map_mr_sg(buf_slot, link_idx) != + buf_slot->sgt[link_idx].orig_nents) + return -EINVAL; + + return 0; +} + +bool smc_ib_is_sg_need_sync(struct smc_link *lnk, + struct smc_buf_desc *buf_slot) +{ + struct scatterlist *sg; + unsigned int i; + bool ret = false; + + /* for now there is just one DMA address */ + for_each_sg(buf_slot->sgt[lnk->link_idx].sgl, sg, + buf_slot->sgt[lnk->link_idx].nents, i) { + if (!sg_dma_len(sg)) + break; + if (dma_need_sync(lnk->smcibdev->ibdev->dma_device, + sg_dma_address(sg))) { + ret = true; + goto out; + } + } + +out: + return ret; +} + +/* synchronize buffer usage for cpu access */ +void smc_ib_sync_sg_for_cpu(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction) +{ + struct scatterlist *sg; + unsigned int i; + + if (!(buf_slot->is_dma_need_sync & (1U << lnk->link_idx))) + return; + + /* for now there is just one DMA address */ + for_each_sg(buf_slot->sgt[lnk->link_idx].sgl, sg, + buf_slot->sgt[lnk->link_idx].nents, i) { + if (!sg_dma_len(sg)) + break; + ib_dma_sync_single_for_cpu(lnk->smcibdev->ibdev, + sg_dma_address(sg), + sg_dma_len(sg), + data_direction); + } +} + +/* synchronize buffer usage for device access */ +void smc_ib_sync_sg_for_device(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction) +{ + struct scatterlist *sg; + unsigned int i; + + if (!(buf_slot->is_dma_need_sync & (1U << lnk->link_idx))) + return; + + /* for now there is just one DMA address */ + for_each_sg(buf_slot->sgt[lnk->link_idx].sgl, sg, + buf_slot->sgt[lnk->link_idx].nents, i) { + if (!sg_dma_len(sg)) + break; + ib_dma_sync_single_for_device(lnk->smcibdev->ibdev, + sg_dma_address(sg), + sg_dma_len(sg), + data_direction); + } +} + +/* Map a new TX or RX buffer SG-table to DMA */ +int smc_ib_buf_map_sg(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction) +{ + int mapped_nents; + + mapped_nents = ib_dma_map_sg(lnk->smcibdev->ibdev, + buf_slot->sgt[lnk->link_idx].sgl, + buf_slot->sgt[lnk->link_idx].orig_nents, + data_direction); + if (!mapped_nents) + return -ENOMEM; + + return 
mapped_nents; +} + +void smc_ib_buf_unmap_sg(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction) +{ + if (!buf_slot->sgt[lnk->link_idx].sgl->dma_address) + return; /* already unmapped */ + + ib_dma_unmap_sg(lnk->smcibdev->ibdev, + buf_slot->sgt[lnk->link_idx].sgl, + buf_slot->sgt[lnk->link_idx].orig_nents, + data_direction); + buf_slot->sgt[lnk->link_idx].sgl->dma_address = 0; +} + +long smc_ib_setup_per_ibdev(struct smc_ib_device *smcibdev) +{ + struct ib_cq_init_attr cqattr = { + .cqe = SMC_MAX_CQE, .comp_vector = 0 }; + int cqe_size_order, smc_order; + long rc; + + mutex_lock(&smcibdev->mutex); + rc = 0; + if (smcibdev->initialized) + goto out; + /* the calculated number of cq entries fits to mlx5 cq allocation */ + cqe_size_order = cache_line_size() == 128 ? 7 : 6; + smc_order = MAX_ORDER - cqe_size_order - 1; + if (SMC_MAX_CQE + 2 > (0x00000001 << smc_order) * PAGE_SIZE) + cqattr.cqe = (0x00000001 << smc_order) * PAGE_SIZE - 2; + smcibdev->roce_cq_send = ib_create_cq(smcibdev->ibdev, + smc_wr_tx_cq_handler, NULL, + smcibdev, &cqattr); + rc = PTR_ERR_OR_ZERO(smcibdev->roce_cq_send); + if (IS_ERR(smcibdev->roce_cq_send)) { + smcibdev->roce_cq_send = NULL; + goto out; + } + smcibdev->roce_cq_recv = ib_create_cq(smcibdev->ibdev, + smc_wr_rx_cq_handler, NULL, + smcibdev, &cqattr); + rc = PTR_ERR_OR_ZERO(smcibdev->roce_cq_recv); + if (IS_ERR(smcibdev->roce_cq_recv)) { + smcibdev->roce_cq_recv = NULL; + goto err; + } + smc_wr_add_dev(smcibdev); + smcibdev->initialized = 1; + goto out; + +err: + ib_destroy_cq(smcibdev->roce_cq_send); +out: + mutex_unlock(&smcibdev->mutex); + return rc; +} + +static void smc_ib_cleanup_per_ibdev(struct smc_ib_device *smcibdev) +{ + mutex_lock(&smcibdev->mutex); + if (!smcibdev->initialized) + goto out; + smcibdev->initialized = 0; + ib_destroy_cq(smcibdev->roce_cq_recv); + ib_destroy_cq(smcibdev->roce_cq_send); + smc_wr_remove_dev(smcibdev); +out: + mutex_unlock(&smcibdev->mutex); +} + +static struct ib_client smc_ib_client; + +static void smc_copy_netdev_ifindex(struct smc_ib_device *smcibdev, int port) +{ + struct ib_device *ibdev = smcibdev->ibdev; + struct net_device *ndev; + + if (!ibdev->ops.get_netdev) + return; + ndev = ibdev->ops.get_netdev(ibdev, port + 1); + if (ndev) { + smcibdev->ndev_ifidx[port] = ndev->ifindex; + dev_put(ndev); + } +} + +void smc_ib_ndev_change(struct net_device *ndev, unsigned long event) +{ + struct smc_ib_device *smcibdev; + struct ib_device *libdev; + struct net_device *lndev; + u8 port_cnt; + int i; + + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(smcibdev, &smc_ib_devices.list, list) { + port_cnt = smcibdev->ibdev->phys_port_cnt; + for (i = 0; i < min_t(size_t, port_cnt, SMC_MAX_PORTS); i++) { + libdev = smcibdev->ibdev; + if (!libdev->ops.get_netdev) + continue; + lndev = libdev->ops.get_netdev(libdev, i + 1); + dev_put(lndev); + if (lndev != ndev) + continue; + if (event == NETDEV_REGISTER) + smcibdev->ndev_ifidx[i] = ndev->ifindex; + if (event == NETDEV_UNREGISTER) + smcibdev->ndev_ifidx[i] = 0; + } + } + mutex_unlock(&smc_ib_devices.mutex); +} + +/* callback function for ib_register_client() */ +static int smc_ib_add_dev(struct ib_device *ibdev) +{ + struct smc_ib_device *smcibdev; + u8 port_cnt; + int i; + + if (ibdev->node_type != RDMA_NODE_IB_CA) + return -EOPNOTSUPP; + + smcibdev = kzalloc(sizeof(*smcibdev), GFP_KERNEL); + if (!smcibdev) + return -ENOMEM; + + smcibdev->ibdev = ibdev; + INIT_WORK(&smcibdev->port_event_work, smc_ib_port_event_work); + 
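Illustrative aside (not from the upstream patch): the completion-queue sizing in smc_ib_setup_per_ibdev() above is easier to follow as plain arithmetic. The sketch below reproduces only the clamping step; the demo_* names and the numeric values are illustrative stand-ins for SMC_MAX_CQE, MAX_ORDER and PAGE_SIZE, which depend on the kernel configuration.

	#include <stdio.h>

	/* Illustrative values only; the real constants come from smc_wr.h
	 * and the architecture headers.
	 */
	#define DEMO_MAX_CQE	32768UL
	#define DEMO_MAX_ORDER	11
	#define DEMO_PAGE_SIZE	4096UL

	static unsigned long demo_cq_entries(int cache_line_size)
	{
		int cqe_size_order = (cache_line_size == 128) ? 7 : 6;
		int smc_order = DEMO_MAX_ORDER - cqe_size_order - 1;
		unsigned long cqe = DEMO_MAX_CQE;

		/* same clamp as in smc_ib_setup_per_ibdev(): shrink the
		 * requested entry count if it exceeds the order-limited bound
		 */
		if (DEMO_MAX_CQE + 2 > (1UL << smc_order) * DEMO_PAGE_SIZE)
			cqe = (1UL << smc_order) * DEMO_PAGE_SIZE - 2;
		return cqe;
	}

	int main(void)
	{
		printf("64-byte cache line:  %lu CQ entries\n", demo_cq_entries(64));
		printf("128-byte cache line: %lu CQ entries\n", demo_cq_entries(128));
		return 0;
	}

As the kernel comment above puts it, the clamp keeps the calculated number of CQ entries within what the mlx5 CQ allocation can accommodate; with larger cache lines the bound is lower, so the request is reduced.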
atomic_set(&smcibdev->lnk_cnt, 0); + init_waitqueue_head(&smcibdev->lnks_deleted); + mutex_init(&smcibdev->mutex); + mutex_lock(&smc_ib_devices.mutex); + list_add_tail(&smcibdev->list, &smc_ib_devices.list); + mutex_unlock(&smc_ib_devices.mutex); + ib_set_client_data(ibdev, &smc_ib_client, smcibdev); + INIT_IB_EVENT_HANDLER(&smcibdev->event_handler, smcibdev->ibdev, + smc_ib_global_event_handler); + ib_register_event_handler(&smcibdev->event_handler); + + /* trigger reading of the port attributes */ + port_cnt = smcibdev->ibdev->phys_port_cnt; + pr_warn_ratelimited("smc: adding ib device %s with port count %d\n", + smcibdev->ibdev->name, port_cnt); + for (i = 0; + i < min_t(size_t, port_cnt, SMC_MAX_PORTS); + i++) { + set_bit(i, &smcibdev->port_event_mask); + /* determine pnetids of the port */ + if (smc_pnetid_by_dev_port(ibdev->dev.parent, i, + smcibdev->pnetid[i])) + smc_pnetid_by_table_ib(smcibdev, i + 1); + smc_copy_netdev_ifindex(smcibdev, i); + pr_warn_ratelimited("smc: ib device %s port %d has pnetid " + "%.16s%s\n", + smcibdev->ibdev->name, i + 1, + smcibdev->pnetid[i], + smcibdev->pnetid_by_user[i] ? + " (user defined)" : + ""); + } + schedule_work(&smcibdev->port_event_work); + return 0; +} + +/* callback function for ib_unregister_client() */ +static void smc_ib_remove_dev(struct ib_device *ibdev, void *client_data) +{ + struct smc_ib_device *smcibdev = client_data; + + mutex_lock(&smc_ib_devices.mutex); + list_del_init(&smcibdev->list); /* remove from smc_ib_devices */ + mutex_unlock(&smc_ib_devices.mutex); + pr_warn_ratelimited("smc: removing ib device %s\n", + smcibdev->ibdev->name); + smc_smcr_terminate_all(smcibdev); + smc_ib_cleanup_per_ibdev(smcibdev); + ib_unregister_event_handler(&smcibdev->event_handler); + cancel_work_sync(&smcibdev->port_event_work); + kfree(smcibdev); +} + +static struct ib_client smc_ib_client = { + .name = "smc_ib", + .add = smc_ib_add_dev, + .remove = smc_ib_remove_dev, +}; + +int __init smc_ib_register_client(void) +{ + smc_ib_init_local_systemid(); + return ib_register_client(&smc_ib_client); +} + +void smc_ib_unregister_client(void) +{ + ib_unregister_client(&smc_ib_client); +} diff --git a/net/smc/smc_ib.h b/net/smc/smc_ib.h new file mode 100644 index 000000000..ebcb05ede --- /dev/null +++ b/net/smc/smc_ib.h @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Definitions for IB environment + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#ifndef _SMC_IB_H +#define _SMC_IB_H + +#include +#include +#include +#include +#include +#include + +#define SMC_MAX_PORTS 2 /* Max # of ports */ +#define SMC_GID_SIZE sizeof(union ib_gid) + +#define SMC_IB_MAX_SEND_SGE 2 + +struct smc_ib_devices { /* list of smc ib devices definition */ + struct list_head list; + struct mutex mutex; /* protects list of smc ib devices */ +}; + +extern struct smc_ib_devices smc_ib_devices; /* list of smc ib devices */ +extern struct smc_lgr_list smc_lgr_list; /* list of linkgroups */ + +struct smc_ib_device { /* ib-device infos for smc */ + struct list_head list; + struct ib_device *ibdev; + struct ib_port_attr pattr[SMC_MAX_PORTS]; /* ib dev. 
port attrs */ + struct ib_event_handler event_handler; /* global ib_event handler */ + struct ib_cq *roce_cq_send; /* send completion queue */ + struct ib_cq *roce_cq_recv; /* recv completion queue */ + struct tasklet_struct send_tasklet; /* called by send cq handler */ + struct tasklet_struct recv_tasklet; /* called by recv cq handler */ + char mac[SMC_MAX_PORTS][ETH_ALEN]; + /* mac address per port*/ + u8 pnetid[SMC_MAX_PORTS][SMC_MAX_PNETID_LEN]; + /* pnetid per port */ + bool pnetid_by_user[SMC_MAX_PORTS]; + /* pnetid defined by user? */ + u8 initialized : 1; /* ib dev CQ, evthdl done */ + struct work_struct port_event_work; + unsigned long port_event_mask; + DECLARE_BITMAP(ports_going_away, SMC_MAX_PORTS); + atomic_t lnk_cnt; /* number of links on ibdev */ + wait_queue_head_t lnks_deleted; /* wait 4 removal of all links*/ + struct mutex mutex; /* protect dev setup+cleanup */ + atomic_t lnk_cnt_by_port[SMC_MAX_PORTS]; + /* number of links per port */ + int ndev_ifidx[SMC_MAX_PORTS]; /* ndev if indexes */ +}; + +static inline __be32 smc_ib_gid_to_ipv4(u8 gid[SMC_GID_SIZE]) +{ + struct in6_addr *addr6 = (struct in6_addr *)gid; + + if (ipv6_addr_v4mapped(addr6) || + !(addr6->s6_addr32[0] | addr6->s6_addr32[1] | addr6->s6_addr32[2])) + return addr6->s6_addr32[3]; + return cpu_to_be32(INADDR_NONE); +} + +static inline struct net *smc_ib_net(struct smc_ib_device *smcibdev) +{ + if (smcibdev && smcibdev->ibdev) + return read_pnet(&smcibdev->ibdev->coredev.rdma_net); + return NULL; +} + +struct smc_init_info_smcrv2; +struct smc_buf_desc; +struct smc_link; + +void smc_ib_ndev_change(struct net_device *ndev, unsigned long event); +int smc_ib_register_client(void) __init; +void smc_ib_unregister_client(void); +bool smc_ib_port_active(struct smc_ib_device *smcibdev, u8 ibport); +int smc_ib_buf_map_sg(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction); +void smc_ib_buf_unmap_sg(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction); +void smc_ib_dealloc_protection_domain(struct smc_link *lnk); +int smc_ib_create_protection_domain(struct smc_link *lnk); +void smc_ib_destroy_queue_pair(struct smc_link *lnk); +int smc_ib_create_queue_pair(struct smc_link *lnk); +int smc_ib_ready_link(struct smc_link *lnk); +int smc_ib_modify_qp_rts(struct smc_link *lnk); +int smc_ib_modify_qp_reset(struct smc_link *lnk); +int smc_ib_modify_qp_error(struct smc_link *lnk); +long smc_ib_setup_per_ibdev(struct smc_ib_device *smcibdev); +int smc_ib_get_memory_region(struct ib_pd *pd, int access_flags, + struct smc_buf_desc *buf_slot, u8 link_idx); +void smc_ib_put_memory_region(struct ib_mr *mr); +bool smc_ib_is_sg_need_sync(struct smc_link *lnk, + struct smc_buf_desc *buf_slot); +void smc_ib_sync_sg_for_cpu(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction); +void smc_ib_sync_sg_for_device(struct smc_link *lnk, + struct smc_buf_desc *buf_slot, + enum dma_data_direction data_direction); +int smc_ib_determine_gid(struct smc_ib_device *smcibdev, u8 ibport, + unsigned short vlan_id, u8 gid[], u8 *sgid_index, + struct smc_init_info_smcrv2 *smcrv2); +int smc_ib_find_route(struct net *net, __be32 saddr, __be32 daddr, + u8 nexthop_mac[], u8 *uses_gateway); +bool smc_ib_is_valid_local_systemid(void); +int smcr_nl_get_device(struct sk_buff *skb, struct netlink_callback *cb); +#endif diff --git a/net/smc/smc_ism.c b/net/smc/smc_ism.c new file mode 100644 index 000000000..911fe08bc --- /dev/null +++ 
b/net/smc/smc_ism.c @@ -0,0 +1,534 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Shared Memory Communications Direct over ISM devices (SMC-D) + * + * Functions for ISM device. + * + * Copyright IBM Corp. 2018 + */ + +#include +#include +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_ism.h" +#include "smc_pnet.h" +#include "smc_netlink.h" + +struct smcd_dev_list smcd_dev_list = { + .list = LIST_HEAD_INIT(smcd_dev_list.list), + .mutex = __MUTEX_INITIALIZER(smcd_dev_list.mutex) +}; + +static bool smc_ism_v2_capable; +static u8 smc_ism_v2_system_eid[SMC_MAX_EID_LEN]; + +/* Test if an ISM communication is possible - same CPC */ +int smc_ism_cantalk(u64 peer_gid, unsigned short vlan_id, struct smcd_dev *smcd) +{ + return smcd->ops->query_remote_gid(smcd, peer_gid, vlan_id ? 1 : 0, + vlan_id); +} + +void smc_ism_get_system_eid(u8 **eid) +{ + if (!smc_ism_v2_capable) + *eid = NULL; + else + *eid = smc_ism_v2_system_eid; +} + +u16 smc_ism_get_chid(struct smcd_dev *smcd) +{ + return smcd->ops->get_chid(smcd); +} + +/* HW supports ISM V2 and thus System EID is defined */ +bool smc_ism_is_v2_capable(void) +{ + return smc_ism_v2_capable; +} + +/* Set a connection using this DMBE. */ +void smc_ism_set_conn(struct smc_connection *conn) +{ + unsigned long flags; + + spin_lock_irqsave(&conn->lgr->smcd->lock, flags); + conn->lgr->smcd->conn[conn->rmb_desc->sba_idx] = conn; + spin_unlock_irqrestore(&conn->lgr->smcd->lock, flags); +} + +/* Unset a connection using this DMBE. */ +void smc_ism_unset_conn(struct smc_connection *conn) +{ + unsigned long flags; + + if (!conn->rmb_desc) + return; + + spin_lock_irqsave(&conn->lgr->smcd->lock, flags); + conn->lgr->smcd->conn[conn->rmb_desc->sba_idx] = NULL; + spin_unlock_irqrestore(&conn->lgr->smcd->lock, flags); +} + +/* Register a VLAN identifier with the ISM device. Use a reference count + * and add a VLAN identifier only when the first DMB using this VLAN is + * registered. + */ +int smc_ism_get_vlan(struct smcd_dev *smcd, unsigned short vlanid) +{ + struct smc_ism_vlanid *new_vlan, *vlan; + unsigned long flags; + int rc = 0; + + if (!vlanid) /* No valid vlan id */ + return -EINVAL; + + /* create new vlan entry, in case we need it */ + new_vlan = kzalloc(sizeof(*new_vlan), GFP_KERNEL); + if (!new_vlan) + return -ENOMEM; + new_vlan->vlanid = vlanid; + refcount_set(&new_vlan->refcnt, 1); + + /* if there is an existing entry, increase count and return */ + spin_lock_irqsave(&smcd->lock, flags); + list_for_each_entry(vlan, &smcd->vlan, list) { + if (vlan->vlanid == vlanid) { + refcount_inc(&vlan->refcnt); + kfree(new_vlan); + goto out; + } + } + + /* no existing entry found. + * add new entry to device; might fail, e.g., if HW limit reached + */ + if (smcd->ops->add_vlan_id(smcd, vlanid)) { + kfree(new_vlan); + rc = -EIO; + goto out; + } + list_add_tail(&new_vlan->list, &smcd->vlan); +out: + spin_unlock_irqrestore(&smcd->lock, flags); + return rc; +} + +/* Unregister a VLAN identifier with the ISM device. Use a reference count + * and remove a VLAN identifier only when the last DMB using this VLAN is + * unregistered. 
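Illustrative aside (not from the upstream patch): smc_ism_get_vlan() above, together with smc_ism_put_vlan() that follows, registers a VLAN id with the ISM device only for its first user and removes it only when the last user is gone. Below is a minimal userspace sketch of that reference-counting pattern, using a fixed-size table instead of the kernel list; all demo_* names are hypothetical.

	#include <stdio.h>

	struct demo_vlan { unsigned short vlanid; int refcnt; };
	static struct demo_vlan demo_tbl[8];	/* stand-in for smcd->vlan */

	static int demo_get_vlan(unsigned short id)
	{
		int free_slot = -1;

		if (!id)
			return -1;
		for (int i = 0; i < 8; i++) {
			if (demo_tbl[i].refcnt && demo_tbl[i].vlanid == id) {
				demo_tbl[i].refcnt++;	/* already registered */
				return 0;
			}
			if (!demo_tbl[i].refcnt && free_slot < 0)
				free_slot = i;
		}
		if (free_slot < 0)
			return -1;	/* analogous to the HW limit case */
		/* first user: this is where ops->add_vlan_id() would run */
		demo_tbl[free_slot].vlanid = id;
		demo_tbl[free_slot].refcnt = 1;
		return 0;
	}

	static int demo_put_vlan(unsigned short id)
	{
		for (int i = 0; i < 8; i++) {
			if (demo_tbl[i].refcnt && demo_tbl[i].vlanid == id) {
				if (--demo_tbl[i].refcnt == 0)
					demo_tbl[i].vlanid = 0;	/* last user:
						ops->del_vlan_id() would run */
				return 0;
			}
		}
		return -1;	/* VLAN id not in table */
	}

	int main(void)
	{
		demo_get_vlan(100);
		demo_get_vlan(100);
		demo_put_vlan(100);	/* still registered */
		demo_put_vlan(100);	/* last reference -> unregistered */
		printf("vlan 100 refcnt now %d\n", demo_tbl[0].refcnt);
		return 0;
	}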
+ */ +int smc_ism_put_vlan(struct smcd_dev *smcd, unsigned short vlanid) +{ + struct smc_ism_vlanid *vlan; + unsigned long flags; + bool found = false; + int rc = 0; + + if (!vlanid) /* No valid vlan id */ + return -EINVAL; + + spin_lock_irqsave(&smcd->lock, flags); + list_for_each_entry(vlan, &smcd->vlan, list) { + if (vlan->vlanid == vlanid) { + if (!refcount_dec_and_test(&vlan->refcnt)) + goto out; + found = true; + break; + } + } + if (!found) { + rc = -ENOENT; + goto out; /* VLAN id not in table */ + } + + /* Found and the last reference just gone */ + if (smcd->ops->del_vlan_id(smcd, vlanid)) + rc = -EIO; + list_del(&vlan->list); + kfree(vlan); +out: + spin_unlock_irqrestore(&smcd->lock, flags); + return rc; +} + +int smc_ism_unregister_dmb(struct smcd_dev *smcd, struct smc_buf_desc *dmb_desc) +{ + struct smcd_dmb dmb; + int rc = 0; + + if (!dmb_desc->dma_addr) + return rc; + + memset(&dmb, 0, sizeof(dmb)); + dmb.dmb_tok = dmb_desc->token; + dmb.sba_idx = dmb_desc->sba_idx; + dmb.cpu_addr = dmb_desc->cpu_addr; + dmb.dma_addr = dmb_desc->dma_addr; + dmb.dmb_len = dmb_desc->len; + rc = smcd->ops->unregister_dmb(smcd, &dmb); + if (!rc || rc == ISM_ERROR) { + dmb_desc->cpu_addr = NULL; + dmb_desc->dma_addr = 0; + } + + return rc; +} + +int smc_ism_register_dmb(struct smc_link_group *lgr, int dmb_len, + struct smc_buf_desc *dmb_desc) +{ + struct smcd_dmb dmb; + int rc; + + memset(&dmb, 0, sizeof(dmb)); + dmb.dmb_len = dmb_len; + dmb.sba_idx = dmb_desc->sba_idx; + dmb.vlan_id = lgr->vlan_id; + dmb.rgid = lgr->peer_gid; + rc = lgr->smcd->ops->register_dmb(lgr->smcd, &dmb); + if (!rc) { + dmb_desc->sba_idx = dmb.sba_idx; + dmb_desc->token = dmb.dmb_tok; + dmb_desc->cpu_addr = dmb.cpu_addr; + dmb_desc->dma_addr = dmb.dma_addr; + dmb_desc->len = dmb.dmb_len; + } + return rc; +} + +static int smc_nl_handle_smcd_dev(struct smcd_dev *smcd, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + char smc_pnet[SMC_MAX_PNETID_LEN + 1]; + struct smc_pci_dev smc_pci_dev; + struct nlattr *port_attrs; + struct nlattr *attrs; + int use_cnt = 0; + void *nlh; + + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_DEV_SMCD); + if (!nlh) + goto errmsg; + attrs = nla_nest_start(skb, SMC_GEN_DEV_SMCD); + if (!attrs) + goto errout; + use_cnt = atomic_read(&smcd->lgr_cnt); + if (nla_put_u32(skb, SMC_NLA_DEV_USE_CNT, use_cnt)) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_DEV_IS_CRIT, use_cnt > 0)) + goto errattr; + memset(&smc_pci_dev, 0, sizeof(smc_pci_dev)); + smc_set_pci_values(to_pci_dev(smcd->dev.parent), &smc_pci_dev); + if (nla_put_u32(skb, SMC_NLA_DEV_PCI_FID, smc_pci_dev.pci_fid)) + goto errattr; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_CHID, smc_pci_dev.pci_pchid)) + goto errattr; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_VENDOR, smc_pci_dev.pci_vendor)) + goto errattr; + if (nla_put_u16(skb, SMC_NLA_DEV_PCI_DEVICE, smc_pci_dev.pci_device)) + goto errattr; + if (nla_put_string(skb, SMC_NLA_DEV_PCI_ID, smc_pci_dev.pci_id)) + goto errattr; + + port_attrs = nla_nest_start(skb, SMC_NLA_DEV_PORT); + if (!port_attrs) + goto errattr; + if (nla_put_u8(skb, SMC_NLA_DEV_PORT_PNET_USR, smcd->pnetid_by_user)) + goto errportattr; + memcpy(smc_pnet, smcd->pnetid, SMC_MAX_PNETID_LEN); + smc_pnet[SMC_MAX_PNETID_LEN] = 0; + if (nla_put_string(skb, SMC_NLA_DEV_PORT_PNETID, smc_pnet)) + goto errportattr; + + nla_nest_end(skb, port_attrs); + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + return 0; + +errportattr: + nla_nest_cancel(skb, 
port_attrs); +errattr: + nla_nest_cancel(skb, attrs); +errout: + nlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + +static void smc_nl_prep_smcd_dev(struct smcd_dev_list *dev_list, + struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + int snum = cb_ctx->pos[0]; + struct smcd_dev *smcd; + int num = 0; + + mutex_lock(&dev_list->mutex); + list_for_each_entry(smcd, &dev_list->list, list) { + if (num < snum) + goto next; + if (smc_nl_handle_smcd_dev(smcd, skb, cb)) + goto errout; +next: + num++; + } +errout: + mutex_unlock(&dev_list->mutex); + cb_ctx->pos[0] = num; +} + +int smcd_nl_get_device(struct sk_buff *skb, struct netlink_callback *cb) +{ + smc_nl_prep_smcd_dev(&smcd_dev_list, skb, cb); + return skb->len; +} + +struct smc_ism_event_work { + struct work_struct work; + struct smcd_dev *smcd; + struct smcd_event event; +}; + +#define ISM_EVENT_REQUEST 0x0001 +#define ISM_EVENT_RESPONSE 0x0002 +#define ISM_EVENT_REQUEST_IR 0x00000001 +#define ISM_EVENT_CODE_SHUTDOWN 0x80 +#define ISM_EVENT_CODE_TESTLINK 0x83 + +union smcd_sw_event_info { + u64 info; + struct { + u8 uid[SMC_LGR_ID_SIZE]; + unsigned short vlan_id; + u16 code; + }; +}; + +static void smcd_handle_sw_event(struct smc_ism_event_work *wrk) +{ + union smcd_sw_event_info ev_info; + + ev_info.info = wrk->event.info; + switch (wrk->event.code) { + case ISM_EVENT_CODE_SHUTDOWN: /* Peer shut down DMBs */ + smc_smcd_terminate(wrk->smcd, wrk->event.tok, ev_info.vlan_id); + break; + case ISM_EVENT_CODE_TESTLINK: /* Activity timer */ + if (ev_info.code == ISM_EVENT_REQUEST) { + ev_info.code = ISM_EVENT_RESPONSE; + wrk->smcd->ops->signal_event(wrk->smcd, + wrk->event.tok, + ISM_EVENT_REQUEST_IR, + ISM_EVENT_CODE_TESTLINK, + ev_info.info); + } + break; + } +} + +int smc_ism_signal_shutdown(struct smc_link_group *lgr) +{ + int rc; + union smcd_sw_event_info ev_info; + + if (lgr->peer_shutdown) + return 0; + + memcpy(ev_info.uid, lgr->id, SMC_LGR_ID_SIZE); + ev_info.vlan_id = lgr->vlan_id; + ev_info.code = ISM_EVENT_REQUEST; + rc = lgr->smcd->ops->signal_event(lgr->smcd, lgr->peer_gid, + ISM_EVENT_REQUEST_IR, + ISM_EVENT_CODE_SHUTDOWN, + ev_info.info); + return rc; +} + +/* worker for SMC-D events */ +static void smc_ism_event_work(struct work_struct *work) +{ + struct smc_ism_event_work *wrk = + container_of(work, struct smc_ism_event_work, work); + + switch (wrk->event.type) { + case ISM_EVENT_GID: /* GID event, token is peer GID */ + smc_smcd_terminate(wrk->smcd, wrk->event.tok, VLAN_VID_MASK); + break; + case ISM_EVENT_DMB: + break; + case ISM_EVENT_SWR: /* Software defined event */ + smcd_handle_sw_event(wrk); + break; + } + kfree(wrk); +} + +static void smcd_release(struct device *dev) +{ + struct smcd_dev *smcd = container_of(dev, struct smcd_dev, dev); + + kfree(smcd->conn); + kfree(smcd); +} + +struct smcd_dev *smcd_alloc_dev(struct device *parent, const char *name, + const struct smcd_ops *ops, int max_dmbs) +{ + struct smcd_dev *smcd; + + smcd = kzalloc(sizeof(*smcd), GFP_KERNEL); + if (!smcd) + return NULL; + smcd->conn = kcalloc(max_dmbs, sizeof(struct smc_connection *), + GFP_KERNEL); + if (!smcd->conn) { + kfree(smcd); + return NULL; + } + + smcd->event_wq = alloc_ordered_workqueue("ism_evt_wq-%s)", + WQ_MEM_RECLAIM, name); + if (!smcd->event_wq) { + kfree(smcd->conn); + kfree(smcd); + return NULL; + } + + smcd->dev.parent = parent; + smcd->dev.release = smcd_release; + device_initialize(&smcd->dev); + dev_set_name(&smcd->dev, name); + smcd->ops = ops; + if 
(smc_pnetid_by_dev_port(parent, 0, smcd->pnetid)) + smc_pnetid_by_table_smcd(smcd); + + spin_lock_init(&smcd->lock); + spin_lock_init(&smcd->lgr_lock); + INIT_LIST_HEAD(&smcd->vlan); + INIT_LIST_HEAD(&smcd->lgr_list); + init_waitqueue_head(&smcd->lgrs_deleted); + return smcd; +} +EXPORT_SYMBOL_GPL(smcd_alloc_dev); + +int smcd_register_dev(struct smcd_dev *smcd) +{ + int rc; + + mutex_lock(&smcd_dev_list.mutex); + if (list_empty(&smcd_dev_list.list)) { + u8 *system_eid = NULL; + + system_eid = smcd->ops->get_system_eid(); + if (system_eid[24] != '0' || system_eid[28] != '0') { + smc_ism_v2_capable = true; + memcpy(smc_ism_v2_system_eid, system_eid, + SMC_MAX_EID_LEN); + } + } + /* sort list: devices without pnetid before devices with pnetid */ + if (smcd->pnetid[0]) + list_add_tail(&smcd->list, &smcd_dev_list.list); + else + list_add(&smcd->list, &smcd_dev_list.list); + mutex_unlock(&smcd_dev_list.mutex); + + pr_warn_ratelimited("smc: adding smcd device %s with pnetid %.16s%s\n", + dev_name(&smcd->dev), smcd->pnetid, + smcd->pnetid_by_user ? " (user defined)" : ""); + + rc = device_add(&smcd->dev); + if (rc) { + mutex_lock(&smcd_dev_list.mutex); + list_del(&smcd->list); + mutex_unlock(&smcd_dev_list.mutex); + } + + return rc; +} +EXPORT_SYMBOL_GPL(smcd_register_dev); + +void smcd_unregister_dev(struct smcd_dev *smcd) +{ + pr_warn_ratelimited("smc: removing smcd device %s\n", + dev_name(&smcd->dev)); + mutex_lock(&smcd_dev_list.mutex); + list_del_init(&smcd->list); + mutex_unlock(&smcd_dev_list.mutex); + smcd->going_away = 1; + smc_smcd_terminate_all(smcd); + destroy_workqueue(smcd->event_wq); + + device_del(&smcd->dev); +} +EXPORT_SYMBOL_GPL(smcd_unregister_dev); + +void smcd_free_dev(struct smcd_dev *smcd) +{ + put_device(&smcd->dev); +} +EXPORT_SYMBOL_GPL(smcd_free_dev); + +/* SMCD Device event handler. Called from ISM device interrupt handler. + * Parameters are smcd device pointer, + * - event->type (0 --> DMB, 1 --> GID), + * - event->code (event code), + * - event->tok (either DMB token when event type 0, or GID when event type 1) + * - event->time (time of day) + * - event->info (debug info). + * + * Context: + * - Function called in IRQ context from ISM device driver event handler. + */ +void smcd_handle_event(struct smcd_dev *smcd, struct smcd_event *event) +{ + struct smc_ism_event_work *wrk; + + if (smcd->going_away) + return; + /* copy event to event work queue, and let it be handled there */ + wrk = kmalloc(sizeof(*wrk), GFP_ATOMIC); + if (!wrk) + return; + INIT_WORK(&wrk->work, smc_ism_event_work); + wrk->smcd = smcd; + wrk->event = *event; + queue_work(smcd->event_wq, &wrk->work); +} +EXPORT_SYMBOL_GPL(smcd_handle_event); + +/* SMCD Device interrupt handler. Called from ISM device interrupt handler. + * Parameters are smcd device pointer, DMB number, and the DMBE bitmask. + * Find the connection and schedule the tasklet for this connection. + * + * Context: + * - Function called in IRQ context from ISM device driver IRQ handler. 
+ */ +void smcd_handle_irq(struct smcd_dev *smcd, unsigned int dmbno, u16 dmbemask) +{ + struct smc_connection *conn = NULL; + unsigned long flags; + + spin_lock_irqsave(&smcd->lock, flags); + conn = smcd->conn[dmbno]; + if (conn && !conn->killed) + tasklet_schedule(&conn->rx_tsklet); + spin_unlock_irqrestore(&smcd->lock, flags); +} +EXPORT_SYMBOL_GPL(smcd_handle_irq); + +void __init smc_ism_init(void) +{ + smc_ism_v2_capable = false; + memset(smc_ism_v2_system_eid, 0, SMC_MAX_EID_LEN); +} diff --git a/net/smc/smc_ism.h b/net/smc/smc_ism.h new file mode 100644 index 000000000..d6b2db604 --- /dev/null +++ b/net/smc/smc_ism.h @@ -0,0 +1,58 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Shared Memory Communications Direct over ISM devices (SMC-D) + * + * SMC-D ISM device structure definitions. + * + * Copyright IBM Corp. 2018 + */ + +#ifndef SMCD_ISM_H +#define SMCD_ISM_H + +#include +#include +#include + +#include "smc.h" + +struct smcd_dev_list { /* List of SMCD devices */ + struct list_head list; + struct mutex mutex; /* Protects list of devices */ +}; + +extern struct smcd_dev_list smcd_dev_list; /* list of smcd devices */ + +struct smc_ism_vlanid { /* VLAN id set on ISM device */ + struct list_head list; + unsigned short vlanid; /* Vlan id */ + refcount_t refcnt; /* Reference count */ +}; + +struct smcd_dev; + +int smc_ism_cantalk(u64 peer_gid, unsigned short vlan_id, struct smcd_dev *dev); +void smc_ism_set_conn(struct smc_connection *conn); +void smc_ism_unset_conn(struct smc_connection *conn); +int smc_ism_get_vlan(struct smcd_dev *dev, unsigned short vlan_id); +int smc_ism_put_vlan(struct smcd_dev *dev, unsigned short vlan_id); +int smc_ism_register_dmb(struct smc_link_group *lgr, int buf_size, + struct smc_buf_desc *dmb_desc); +int smc_ism_unregister_dmb(struct smcd_dev *dev, struct smc_buf_desc *dmb_desc); +int smc_ism_signal_shutdown(struct smc_link_group *lgr); +void smc_ism_get_system_eid(u8 **eid); +u16 smc_ism_get_chid(struct smcd_dev *dev); +bool smc_ism_is_v2_capable(void); +void smc_ism_init(void); +int smcd_nl_get_device(struct sk_buff *skb, struct netlink_callback *cb); + +static inline int smc_ism_write(struct smcd_dev *smcd, u64 dmb_tok, + unsigned int idx, bool sf, unsigned int offset, + void *data, size_t len) +{ + int rc; + + rc = smcd->ops->move_data(smcd, dmb_tok, idx, sf, offset, data, len); + return rc < 0 ? rc : 0; +} + +#endif diff --git a/net/smc/smc_llc.c b/net/smc/smc_llc.c new file mode 100644 index 000000000..fcb24a0cc --- /dev/null +++ b/net/smc/smc_llc.c @@ -0,0 +1,2348 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Link Layer Control (LLC) + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Klaus Wacker + * Ursula Braun + */ + +#include +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_clc.h" +#include "smc_llc.h" +#include "smc_pnet.h" + +#define SMC_LLC_DATA_LEN 40 + +struct smc_llc_hdr { + struct smc_wr_rx_hdr common; + union { + struct { + u8 length; /* 44 */ + #if defined(__BIG_ENDIAN_BITFIELD) + u8 reserved:4, + add_link_rej_rsn:4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 add_link_rej_rsn:4, + reserved:4; +#endif + }; + u16 length_v2; /* 44 - 8192*/ + }; + u8 flags; +} __packed; /* format defined in + * IBM Shared Memory Communications Version 2 + * (https://www.ibm.com/support/pages/node/6326337) + */ + +#define SMC_LLC_FLAG_NO_RMBE_EYEC 0x03 + +struct smc_llc_msg_confirm_link { /* type 0x01 */ + struct smc_llc_hdr hd; + u8 sender_mac[ETH_ALEN]; + u8 sender_gid[SMC_GID_SIZE]; + u8 sender_qp_num[3]; + u8 link_num; + u8 link_uid[SMC_LGR_ID_SIZE]; + u8 max_links; + u8 reserved[9]; +}; + +#define SMC_LLC_FLAG_ADD_LNK_REJ 0x40 +#define SMC_LLC_REJ_RSN_NO_ALT_PATH 1 + +#define SMC_LLC_ADD_LNK_MAX_LINKS 2 + +struct smc_llc_msg_add_link { /* type 0x02 */ + struct smc_llc_hdr hd; + u8 sender_mac[ETH_ALEN]; + u8 reserved2[2]; + u8 sender_gid[SMC_GID_SIZE]; + u8 sender_qp_num[3]; + u8 link_num; +#if defined(__BIG_ENDIAN_BITFIELD) + u8 reserved3 : 4, + qp_mtu : 4; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 qp_mtu : 4, + reserved3 : 4; +#endif + u8 initial_psn[3]; + u8 reserved[8]; +}; + +struct smc_llc_msg_add_link_cont_rt { + __be32 rmb_key; + __be32 rmb_key_new; + __be64 rmb_vaddr_new; +}; + +struct smc_llc_msg_add_link_v2_ext { +#if defined(__BIG_ENDIAN_BITFIELD) + u8 v2_direct : 1, + reserved : 7; +#elif defined(__LITTLE_ENDIAN_BITFIELD) + u8 reserved : 7, + v2_direct : 1; +#endif + u8 reserved2; + u8 client_target_gid[SMC_GID_SIZE]; + u8 reserved3[8]; + u16 num_rkeys; + struct smc_llc_msg_add_link_cont_rt rt[]; +} __packed; /* format defined in + * IBM Shared Memory Communications Version 2 + * (https://www.ibm.com/support/pages/node/6326337) + */ + +struct smc_llc_msg_req_add_link_v2 { + struct smc_llc_hdr hd; + u8 reserved[20]; + u8 gid_cnt; + u8 reserved2[3]; + u8 gid[][SMC_GID_SIZE]; +}; + +#define SMC_LLC_RKEYS_PER_CONT_MSG 2 + +struct smc_llc_msg_add_link_cont { /* type 0x03 */ + struct smc_llc_hdr hd; + u8 link_num; + u8 num_rkeys; + u8 reserved2[2]; + struct smc_llc_msg_add_link_cont_rt rt[SMC_LLC_RKEYS_PER_CONT_MSG]; + u8 reserved[4]; +} __packed; /* format defined in RFC7609 */ + +#define SMC_LLC_FLAG_DEL_LINK_ALL 0x40 +#define SMC_LLC_FLAG_DEL_LINK_ORDERLY 0x20 + +struct smc_llc_msg_del_link { /* type 0x04 */ + struct smc_llc_hdr hd; + u8 link_num; + __be32 reason; + u8 reserved[35]; +} __packed; /* format defined in RFC7609 */ + +struct smc_llc_msg_test_link { /* type 0x07 */ + struct smc_llc_hdr hd; + u8 user_data[16]; + u8 reserved[24]; +}; + +struct smc_rmb_rtoken { + union { + u8 num_rkeys; /* first rtoken byte of CONFIRM LINK msg */ + /* is actually the num of rtokens, first */ + /* rtoken is always for the current link */ + u8 link_id; /* link id of the rtoken */ + }; + __be32 rmb_key; + __be64 rmb_vaddr; +} __packed; /* format defined in RFC7609 */ + +#define SMC_LLC_RKEYS_PER_MSG 3 +#define SMC_LLC_RKEYS_PER_MSG_V2 255 + +struct smc_llc_msg_confirm_rkey { /* type 0x06 */ + struct smc_llc_hdr hd; + struct smc_rmb_rtoken rtoken[SMC_LLC_RKEYS_PER_MSG]; + u8 reserved; +}; + +#define SMC_LLC_DEL_RKEY_MAX 8 +#define SMC_LLC_FLAG_RKEY_RETRY 0x10 +#define SMC_LLC_FLAG_RKEY_NEG 0x20 + +struct smc_llc_msg_delete_rkey 
{ /* type 0x09 */ + struct smc_llc_hdr hd; + u8 num_rkeys; + u8 err_mask; + u8 reserved[2]; + __be32 rkey[8]; + u8 reserved2[4]; +}; + +struct smc_llc_msg_delete_rkey_v2 { /* type 0x29 */ + struct smc_llc_hdr hd; + u8 num_rkeys; + u8 num_inval_rkeys; + u8 reserved[2]; + __be32 rkey[]; +}; + +union smc_llc_msg { + struct smc_llc_msg_confirm_link confirm_link; + struct smc_llc_msg_add_link add_link; + struct smc_llc_msg_req_add_link_v2 req_add_link; + struct smc_llc_msg_add_link_cont add_link_cont; + struct smc_llc_msg_del_link delete_link; + + struct smc_llc_msg_confirm_rkey confirm_rkey; + struct smc_llc_msg_delete_rkey delete_rkey; + + struct smc_llc_msg_test_link test_link; + struct { + struct smc_llc_hdr hdr; + u8 data[SMC_LLC_DATA_LEN]; + } raw; +}; + +#define SMC_LLC_FLAG_RESP 0x80 + +struct smc_llc_qentry { + struct list_head list; + struct smc_link *link; + union smc_llc_msg msg; +}; + +static void smc_llc_enqueue(struct smc_link *link, union smc_llc_msg *llc); + +struct smc_llc_qentry *smc_llc_flow_qentry_clr(struct smc_llc_flow *flow) +{ + struct smc_llc_qentry *qentry = flow->qentry; + + flow->qentry = NULL; + return qentry; +} + +void smc_llc_flow_qentry_del(struct smc_llc_flow *flow) +{ + struct smc_llc_qentry *qentry; + + if (flow->qentry) { + qentry = flow->qentry; + flow->qentry = NULL; + kfree(qentry); + } +} + +static inline void smc_llc_flow_qentry_set(struct smc_llc_flow *flow, + struct smc_llc_qentry *qentry) +{ + flow->qentry = qentry; +} + +static void smc_llc_flow_parallel(struct smc_link_group *lgr, u8 flow_type, + struct smc_llc_qentry *qentry) +{ + u8 msg_type = qentry->msg.raw.hdr.common.llc_type; + + if ((msg_type == SMC_LLC_ADD_LINK || msg_type == SMC_LLC_DELETE_LINK) && + flow_type != msg_type && !lgr->delayed_event) { + lgr->delayed_event = qentry; + return; + } + /* drop parallel or already-in-progress llc requests */ + if (flow_type != msg_type) + pr_warn_once("smc: SMC-R lg %*phN net %llu dropped parallel " + "LLC msg: msg %d flow %d role %d\n", + SMC_LGR_ID_SIZE, &lgr->id, + lgr->net->net_cookie, + qentry->msg.raw.hdr.common.type, + flow_type, lgr->role); + kfree(qentry); +} + +/* try to start a new llc flow, initiated by an incoming llc msg */ +static bool smc_llc_flow_start(struct smc_llc_flow *flow, + struct smc_llc_qentry *qentry) +{ + struct smc_link_group *lgr = qentry->link->lgr; + + spin_lock_bh(&lgr->llc_flow_lock); + if (flow->type) { + /* a flow is already active */ + smc_llc_flow_parallel(lgr, flow->type, qentry); + spin_unlock_bh(&lgr->llc_flow_lock); + return false; + } + switch (qentry->msg.raw.hdr.common.llc_type) { + case SMC_LLC_ADD_LINK: + flow->type = SMC_LLC_FLOW_ADD_LINK; + break; + case SMC_LLC_DELETE_LINK: + flow->type = SMC_LLC_FLOW_DEL_LINK; + break; + case SMC_LLC_CONFIRM_RKEY: + case SMC_LLC_DELETE_RKEY: + flow->type = SMC_LLC_FLOW_RKEY; + break; + default: + flow->type = SMC_LLC_FLOW_NONE; + } + smc_llc_flow_qentry_set(flow, qentry); + spin_unlock_bh(&lgr->llc_flow_lock); + return true; +} + +/* start a new local llc flow, wait till current flow finished */ +int smc_llc_flow_initiate(struct smc_link_group *lgr, + enum smc_llc_flowtype type) +{ + enum smc_llc_flowtype allowed_remote = SMC_LLC_FLOW_NONE; + int rc; + + /* all flows except confirm_rkey and delete_rkey are exclusive, + * confirm/delete rkey flows can run concurrently (local and remote) + */ + if (type == SMC_LLC_FLOW_RKEY) + allowed_remote = SMC_LLC_FLOW_RKEY; +again: + if (list_empty(&lgr->list)) + return -ENODEV; + spin_lock_bh(&lgr->llc_flow_lock); + if 
(lgr->llc_flow_lcl.type == SMC_LLC_FLOW_NONE && + (lgr->llc_flow_rmt.type == SMC_LLC_FLOW_NONE || + lgr->llc_flow_rmt.type == allowed_remote)) { + lgr->llc_flow_lcl.type = type; + spin_unlock_bh(&lgr->llc_flow_lock); + return 0; + } + spin_unlock_bh(&lgr->llc_flow_lock); + rc = wait_event_timeout(lgr->llc_flow_waiter, (list_empty(&lgr->list) || + (lgr->llc_flow_lcl.type == SMC_LLC_FLOW_NONE && + (lgr->llc_flow_rmt.type == SMC_LLC_FLOW_NONE || + lgr->llc_flow_rmt.type == allowed_remote))), + SMC_LLC_WAIT_TIME * 10); + if (!rc) + return -ETIMEDOUT; + goto again; +} + +/* finish the current llc flow */ +void smc_llc_flow_stop(struct smc_link_group *lgr, struct smc_llc_flow *flow) +{ + spin_lock_bh(&lgr->llc_flow_lock); + memset(flow, 0, sizeof(*flow)); + flow->type = SMC_LLC_FLOW_NONE; + spin_unlock_bh(&lgr->llc_flow_lock); + if (!list_empty(&lgr->list) && lgr->delayed_event && + flow == &lgr->llc_flow_lcl) + schedule_work(&lgr->llc_event_work); + else + wake_up(&lgr->llc_flow_waiter); +} + +/* lnk is optional and used for early wakeup when link goes down, useful in + * cases where we wait for a response on the link after we sent a request + */ +struct smc_llc_qentry *smc_llc_wait(struct smc_link_group *lgr, + struct smc_link *lnk, + int time_out, u8 exp_msg) +{ + struct smc_llc_flow *flow = &lgr->llc_flow_lcl; + u8 rcv_msg; + + wait_event_timeout(lgr->llc_msg_waiter, + (flow->qentry || + (lnk && !smc_link_usable(lnk)) || + list_empty(&lgr->list)), + time_out); + if (!flow->qentry || + (lnk && !smc_link_usable(lnk)) || list_empty(&lgr->list)) { + smc_llc_flow_qentry_del(flow); + goto out; + } + rcv_msg = flow->qentry->msg.raw.hdr.common.llc_type; + if (exp_msg && rcv_msg != exp_msg) { + if (exp_msg == SMC_LLC_ADD_LINK && + rcv_msg == SMC_LLC_DELETE_LINK) { + /* flow_start will delay the unexpected msg */ + smc_llc_flow_start(&lgr->llc_flow_lcl, + smc_llc_flow_qentry_clr(flow)); + return NULL; + } + pr_warn_once("smc: SMC-R lg %*phN net %llu dropped unexpected LLC msg: " + "msg %d exp %d flow %d role %d flags %x\n", + SMC_LGR_ID_SIZE, &lgr->id, lgr->net->net_cookie, + rcv_msg, exp_msg, + flow->type, lgr->role, + flow->qentry->msg.raw.hdr.flags); + smc_llc_flow_qentry_del(flow); + } +out: + return flow->qentry; +} + +/********************************** send *************************************/ + +struct smc_llc_tx_pend { +}; + +/* handler for send/transmission completion of an LLC msg */ +static void smc_llc_tx_handler(struct smc_wr_tx_pend_priv *pend, + struct smc_link *link, + enum ib_wc_status wc_status) +{ + /* future work: handle wc_status error for recovery and failover */ +} + +/** + * smc_llc_add_pending_send() - add LLC control message to pending WQE transmits + * @link: Pointer to SMC link used for sending LLC control message. + * @wr_buf: Out variable returning pointer to work request payload buffer. + * @pend: Out variable returning pointer to private pending WR tracking. + * It's the context the transmit complete handler will get. + * + * Reserves and pre-fills an entry for a pending work request send/tx. + * Used by mid-level smc_llc_send_msg() to prepare for later actual send/tx. + * Can sleep due to smc_get_ctrl_buf (if not in softirq context). + * + * Return: 0 on success, otherwise an error value. 
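Illustrative aside (not from the upstream patch): smc_llc_add_pending_send() just below pins the LLC wire-format sizes with BUILD_BUG_ON_MSG() so that a layout change fails at build time rather than at runtime. The same idea in portable C11, with a purely hypothetical message layout and buffer size:

	#include <stdint.h>
	#include <stdio.h>

	/* Hypothetical wire format and buffer size, for illustration only. */
	struct demo_llc_msg {
		uint8_t type;
		uint8_t length;
		uint8_t payload[42];
	};

	#define DEMO_WR_BUF_SIZE 48

	/* C11 analogue of BUILD_BUG_ON_MSG(): refuse to compile if the
	 * message no longer fits the preallocated work-request buffer.
	 */
	_Static_assert(sizeof(struct demo_llc_msg) <= DEMO_WR_BUF_SIZE,
		       "demo_llc_msg must fit DEMO_WR_BUF_SIZE");

	int main(void)
	{
		printf("msg %zu bytes, buffer %d bytes\n",
		       sizeof(struct demo_llc_msg), DEMO_WR_BUF_SIZE);
		return 0;
	}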
+ */ +static int smc_llc_add_pending_send(struct smc_link *link, + struct smc_wr_buf **wr_buf, + struct smc_wr_tx_pend_priv **pend) +{ + int rc; + + rc = smc_wr_tx_get_free_slot(link, smc_llc_tx_handler, wr_buf, NULL, + pend); + if (rc < 0) + return rc; + BUILD_BUG_ON_MSG( + sizeof(union smc_llc_msg) > SMC_WR_BUF_SIZE, + "must increase SMC_WR_BUF_SIZE to at least sizeof(struct smc_llc_msg)"); + BUILD_BUG_ON_MSG( + sizeof(union smc_llc_msg) != SMC_WR_TX_SIZE, + "must adapt SMC_WR_TX_SIZE to sizeof(struct smc_llc_msg); if not all smc_wr upper layer protocols use the same message size any more, must start to set link->wr_tx_sges[i].length on each individual smc_wr_tx_send()"); + BUILD_BUG_ON_MSG( + sizeof(struct smc_llc_tx_pend) > SMC_WR_TX_PEND_PRIV_SIZE, + "must increase SMC_WR_TX_PEND_PRIV_SIZE to at least sizeof(struct smc_llc_tx_pend)"); + return 0; +} + +static int smc_llc_add_pending_send_v2(struct smc_link *link, + struct smc_wr_v2_buf **wr_buf, + struct smc_wr_tx_pend_priv **pend) +{ + int rc; + + rc = smc_wr_tx_get_v2_slot(link, smc_llc_tx_handler, wr_buf, pend); + if (rc < 0) + return rc; + return 0; +} + +static void smc_llc_init_msg_hdr(struct smc_llc_hdr *hdr, + struct smc_link_group *lgr, size_t len) +{ + if (lgr->smc_version == SMC_V2) { + hdr->common.llc_version = SMC_V2; + hdr->length_v2 = len; + } else { + hdr->common.llc_version = 0; + hdr->length = len; + } +} + +/* high-level API to send LLC confirm link */ +int smc_llc_send_confirm_link(struct smc_link *link, + enum smc_llc_reqresp reqresp) +{ + struct smc_llc_msg_confirm_link *confllc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + confllc = (struct smc_llc_msg_confirm_link *)wr_buf; + memset(confllc, 0, sizeof(*confllc)); + confllc->hd.common.llc_type = SMC_LLC_CONFIRM_LINK; + smc_llc_init_msg_hdr(&confllc->hd, link->lgr, sizeof(*confllc)); + confllc->hd.flags |= SMC_LLC_FLAG_NO_RMBE_EYEC; + if (reqresp == SMC_LLC_RESP) + confllc->hd.flags |= SMC_LLC_FLAG_RESP; + memcpy(confllc->sender_mac, link->smcibdev->mac[link->ibport - 1], + ETH_ALEN); + memcpy(confllc->sender_gid, link->gid, SMC_GID_SIZE); + hton24(confllc->sender_qp_num, link->roce_qp->qp_num); + confllc->link_num = link->link_id; + memcpy(confllc->link_uid, link->link_uid, SMC_LGR_ID_SIZE); + confllc->max_links = SMC_LLC_ADD_LNK_MAX_LINKS; + /* send llc message */ + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* send LLC confirm rkey request */ +static int smc_llc_send_confirm_rkey(struct smc_link *send_link, + struct smc_buf_desc *rmb_desc) +{ + struct smc_llc_msg_confirm_rkey *rkeyllc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + struct smc_link *link; + int i, rc, rtok_ix; + + if (!smc_wr_tx_link_hold(send_link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(send_link, &wr_buf, &pend); + if (rc) + goto put_out; + rkeyllc = (struct smc_llc_msg_confirm_rkey *)wr_buf; + memset(rkeyllc, 0, sizeof(*rkeyllc)); + rkeyllc->hd.common.llc_type = SMC_LLC_CONFIRM_RKEY; + smc_llc_init_msg_hdr(&rkeyllc->hd, send_link->lgr, sizeof(*rkeyllc)); + + rtok_ix = 1; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + link = &send_link->lgr->lnk[i]; + if (smc_link_active(link) && link != send_link) { + rkeyllc->rtoken[rtok_ix].link_id = link->link_id; + rkeyllc->rtoken[rtok_ix].rmb_key = + htonl(rmb_desc->mr[link->link_idx]->rkey); + 
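Illustrative aside (not from the upstream patch): the rtoken array being filled above relies on the union in struct smc_rmb_rtoken, where byte 0 of element 0 carries the number of additional rtokens while the same byte in the following elements carries a link id. A self-contained sketch of that overlay; the demo_* names are hypothetical and the layout is cut down for brevity.

	#include <arpa/inet.h>	/* htonl() */
	#include <stdint.h>
	#include <stdio.h>
	#include <string.h>

	struct demo_rtoken {
		union {
			uint8_t num_rkeys;	/* element 0 only */
			uint8_t link_id;	/* elements 1..n */
		};
		uint32_t rmb_key;		/* big endian on the wire */
		uint64_t rmb_vaddr;		/* big endian on the wire */
	} __attribute__((packed));

	int main(void)
	{
		struct demo_rtoken rt[3];

		memset(rt, 0, sizeof(rt));
		rt[1].link_id = 2;		/* rtoken for link 2 */
		rt[1].rmb_key = htonl(0x1234);
		rt[2].link_id = 3;		/* rtoken for link 3 */
		rt[2].rmb_key = htonl(0x5678);
		rt[0].num_rkeys = 2;		/* count of additional rtokens */
		rt[0].rmb_key = htonl(0x9abc);	/* rkey of the sending link */

		printf("%u additional rtokens, first one for link %u\n",
		       (unsigned)rt[0].num_rkeys, (unsigned)rt[1].link_id);
		return 0;
	}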
rkeyllc->rtoken[rtok_ix].rmb_vaddr = rmb_desc->is_vm ? + cpu_to_be64((uintptr_t)rmb_desc->cpu_addr) : + cpu_to_be64((u64)sg_dma_address + (rmb_desc->sgt[link->link_idx].sgl)); + rtok_ix++; + } + } + /* rkey of send_link is in rtoken[0] */ + rkeyllc->rtoken[0].num_rkeys = rtok_ix - 1; + rkeyllc->rtoken[0].rmb_key = + htonl(rmb_desc->mr[send_link->link_idx]->rkey); + rkeyllc->rtoken[0].rmb_vaddr = rmb_desc->is_vm ? + cpu_to_be64((uintptr_t)rmb_desc->cpu_addr) : + cpu_to_be64((u64)sg_dma_address + (rmb_desc->sgt[send_link->link_idx].sgl)); + /* send llc message */ + rc = smc_wr_tx_send(send_link, pend); +put_out: + smc_wr_tx_link_put(send_link); + return rc; +} + +/* send LLC delete rkey request */ +static int smc_llc_send_delete_rkey(struct smc_link *link, + struct smc_buf_desc *rmb_desc) +{ + struct smc_llc_msg_delete_rkey *rkeyllc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + rkeyllc = (struct smc_llc_msg_delete_rkey *)wr_buf; + memset(rkeyllc, 0, sizeof(*rkeyllc)); + rkeyllc->hd.common.llc_type = SMC_LLC_DELETE_RKEY; + smc_llc_init_msg_hdr(&rkeyllc->hd, link->lgr, sizeof(*rkeyllc)); + rkeyllc->num_rkeys = 1; + rkeyllc->rkey[0] = htonl(rmb_desc->mr[link->link_idx]->rkey); + /* send llc message */ + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* return first buffer from any of the next buf lists */ +static struct smc_buf_desc *_smc_llc_get_next_rmb(struct smc_link_group *lgr, + int *buf_lst) +{ + struct smc_buf_desc *buf_pos; + + while (*buf_lst < SMC_RMBE_SIZES) { + buf_pos = list_first_entry_or_null(&lgr->rmbs[*buf_lst], + struct smc_buf_desc, list); + if (buf_pos) + return buf_pos; + (*buf_lst)++; + } + return NULL; +} + +/* return next rmb from buffer lists */ +static struct smc_buf_desc *smc_llc_get_next_rmb(struct smc_link_group *lgr, + int *buf_lst, + struct smc_buf_desc *buf_pos) +{ + struct smc_buf_desc *buf_next; + + if (!buf_pos) + return _smc_llc_get_next_rmb(lgr, buf_lst); + + if (list_is_last(&buf_pos->list, &lgr->rmbs[*buf_lst])) { + (*buf_lst)++; + return _smc_llc_get_next_rmb(lgr, buf_lst); + } + buf_next = list_next_entry(buf_pos, list); + return buf_next; +} + +static struct smc_buf_desc *smc_llc_get_first_rmb(struct smc_link_group *lgr, + int *buf_lst) +{ + *buf_lst = 0; + return smc_llc_get_next_rmb(lgr, buf_lst, NULL); +} + +static int smc_llc_fill_ext_v2(struct smc_llc_msg_add_link_v2_ext *ext, + struct smc_link *link, struct smc_link *link_new) +{ + struct smc_link_group *lgr = link->lgr; + struct smc_buf_desc *buf_pos; + int prim_lnk_idx, lnk_idx, i; + struct smc_buf_desc *rmb; + int len = sizeof(*ext); + int buf_lst; + + ext->v2_direct = !lgr->uses_gateway; + memcpy(ext->client_target_gid, link_new->gid, SMC_GID_SIZE); + + prim_lnk_idx = link->link_idx; + lnk_idx = link_new->link_idx; + down_write(&lgr->rmbs_lock); + ext->num_rkeys = lgr->conns_num; + if (!ext->num_rkeys) + goto out; + buf_pos = smc_llc_get_first_rmb(lgr, &buf_lst); + for (i = 0; i < ext->num_rkeys; i++) { + while (buf_pos && !(buf_pos)->used) + buf_pos = smc_llc_get_next_rmb(lgr, &buf_lst, buf_pos); + if (!buf_pos) + break; + rmb = buf_pos; + ext->rt[i].rmb_key = htonl(rmb->mr[prim_lnk_idx]->rkey); + ext->rt[i].rmb_key_new = htonl(rmb->mr[lnk_idx]->rkey); + ext->rt[i].rmb_vaddr_new = rmb->is_vm ? 
+ cpu_to_be64((uintptr_t)rmb->cpu_addr) : + cpu_to_be64((u64)sg_dma_address(rmb->sgt[lnk_idx].sgl)); + buf_pos = smc_llc_get_next_rmb(lgr, &buf_lst, buf_pos); + } + len += i * sizeof(ext->rt[0]); +out: + up_write(&lgr->rmbs_lock); + return len; +} + +/* send ADD LINK request or response */ +int smc_llc_send_add_link(struct smc_link *link, u8 mac[], u8 gid[], + struct smc_link *link_new, + enum smc_llc_reqresp reqresp) +{ + struct smc_llc_msg_add_link_v2_ext *ext = NULL; + struct smc_llc_msg_add_link *addllc; + struct smc_wr_tx_pend_priv *pend; + int len = sizeof(*addllc); + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + if (link->lgr->smc_version == SMC_V2) { + struct smc_wr_v2_buf *wr_buf; + + rc = smc_llc_add_pending_send_v2(link, &wr_buf, &pend); + if (rc) + goto put_out; + addllc = (struct smc_llc_msg_add_link *)wr_buf; + ext = (struct smc_llc_msg_add_link_v2_ext *) + &wr_buf->raw[sizeof(*addllc)]; + memset(ext, 0, SMC_WR_TX_SIZE); + } else { + struct smc_wr_buf *wr_buf; + + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + addllc = (struct smc_llc_msg_add_link *)wr_buf; + } + + memset(addllc, 0, sizeof(*addllc)); + addllc->hd.common.llc_type = SMC_LLC_ADD_LINK; + if (reqresp == SMC_LLC_RESP) + addllc->hd.flags |= SMC_LLC_FLAG_RESP; + memcpy(addllc->sender_mac, mac, ETH_ALEN); + memcpy(addllc->sender_gid, gid, SMC_GID_SIZE); + if (link_new) { + addllc->link_num = link_new->link_id; + hton24(addllc->sender_qp_num, link_new->roce_qp->qp_num); + hton24(addllc->initial_psn, link_new->psn_initial); + if (reqresp == SMC_LLC_REQ) + addllc->qp_mtu = link_new->path_mtu; + else + addllc->qp_mtu = min(link_new->path_mtu, + link_new->peer_mtu); + } + if (ext && link_new) + len += smc_llc_fill_ext_v2(ext, link, link_new); + smc_llc_init_msg_hdr(&addllc->hd, link->lgr, len); + /* send llc message */ + if (link->lgr->smc_version == SMC_V2) + rc = smc_wr_tx_v2_send(link, pend, len); + else + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* send DELETE LINK request or response */ +int smc_llc_send_delete_link(struct smc_link *link, u8 link_del_id, + enum smc_llc_reqresp reqresp, bool orderly, + u32 reason) +{ + struct smc_llc_msg_del_link *delllc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + delllc = (struct smc_llc_msg_del_link *)wr_buf; + + memset(delllc, 0, sizeof(*delllc)); + delllc->hd.common.llc_type = SMC_LLC_DELETE_LINK; + smc_llc_init_msg_hdr(&delllc->hd, link->lgr, sizeof(*delllc)); + if (reqresp == SMC_LLC_RESP) + delllc->hd.flags |= SMC_LLC_FLAG_RESP; + if (orderly) + delllc->hd.flags |= SMC_LLC_FLAG_DEL_LINK_ORDERLY; + if (link_del_id) + delllc->link_num = link_del_id; + else + delllc->hd.flags |= SMC_LLC_FLAG_DEL_LINK_ALL; + delllc->reason = htonl(reason); + /* send llc message */ + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* send LLC test link request */ +static int smc_llc_send_test_link(struct smc_link *link, u8 user_data[16]) +{ + struct smc_llc_msg_test_link *testllc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + testllc = (struct smc_llc_msg_test_link *)wr_buf; + memset(testllc, 0, sizeof(*testllc)); + 
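/* the peer echoes this message, including user_data, with the RESP flag set */ +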
testllc->hd.common.llc_type = SMC_LLC_TEST_LINK; + smc_llc_init_msg_hdr(&testllc->hd, link->lgr, sizeof(*testllc)); + memcpy(testllc->user_data, user_data, sizeof(testllc->user_data)); + /* send llc message */ + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* schedule an llc send on link, may wait for buffers */ +static int smc_llc_send_message(struct smc_link *link, void *llcbuf) +{ + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + memcpy(wr_buf, llcbuf, sizeof(union smc_llc_msg)); + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/* schedule an llc send on link, may wait for buffers, + * and wait for send completion notification. + * @return 0 on success + */ +static int smc_llc_send_message_wait(struct smc_link *link, void *llcbuf) +{ + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + memcpy(wr_buf, llcbuf, sizeof(union smc_llc_msg)); + rc = smc_wr_tx_send_wait(link, pend, SMC_LLC_WAIT_TIME); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +/********************************* receive ***********************************/ + +static int smc_llc_alloc_alt_link(struct smc_link_group *lgr, + enum smc_lgr_type lgr_new_t) +{ + int i; + + if (lgr->type == SMC_LGR_SYMMETRIC || + (lgr->type != SMC_LGR_SINGLE && + (lgr_new_t == SMC_LGR_ASYMMETRIC_LOCAL || + lgr_new_t == SMC_LGR_ASYMMETRIC_PEER))) + return -EMLINK; + + if (lgr_new_t == SMC_LGR_ASYMMETRIC_LOCAL || + lgr_new_t == SMC_LGR_ASYMMETRIC_PEER) { + for (i = SMC_LINKS_PER_LGR_MAX - 1; i >= 0; i--) + if (lgr->lnk[i].state == SMC_LNK_UNUSED) + return i; + } else { + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) + if (lgr->lnk[i].state == SMC_LNK_UNUSED) + return i; + } + return -EMLINK; +} + +/* send one add_link_continue msg */ +static int smc_llc_add_link_cont(struct smc_link *link, + struct smc_link *link_new, u8 *num_rkeys_todo, + int *buf_lst, struct smc_buf_desc **buf_pos) +{ + struct smc_llc_msg_add_link_cont *addc_llc; + struct smc_link_group *lgr = link->lgr; + int prim_lnk_idx, lnk_idx, i, rc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_buf *wr_buf; + struct smc_buf_desc *rmb; + u8 n; + + if (!smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_llc_add_pending_send(link, &wr_buf, &pend); + if (rc) + goto put_out; + addc_llc = (struct smc_llc_msg_add_link_cont *)wr_buf; + memset(addc_llc, 0, sizeof(*addc_llc)); + + prim_lnk_idx = link->link_idx; + lnk_idx = link_new->link_idx; + addc_llc->link_num = link_new->link_id; + addc_llc->num_rkeys = *num_rkeys_todo; + n = *num_rkeys_todo; + for (i = 0; i < min_t(u8, n, SMC_LLC_RKEYS_PER_CONT_MSG); i++) { + while (*buf_pos && !(*buf_pos)->used) + *buf_pos = smc_llc_get_next_rmb(lgr, buf_lst, *buf_pos); + if (!*buf_pos) { + addc_llc->num_rkeys = addc_llc->num_rkeys - + *num_rkeys_todo; + *num_rkeys_todo = 0; + break; + } + rmb = *buf_pos; + + addc_llc->rt[i].rmb_key = htonl(rmb->mr[prim_lnk_idx]->rkey); + addc_llc->rt[i].rmb_key_new = htonl(rmb->mr[lnk_idx]->rkey); + addc_llc->rt[i].rmb_vaddr_new = rmb->is_vm ? 
+ cpu_to_be64((uintptr_t)rmb->cpu_addr) : + cpu_to_be64((u64)sg_dma_address(rmb->sgt[lnk_idx].sgl)); + + (*num_rkeys_todo)--; + *buf_pos = smc_llc_get_next_rmb(lgr, buf_lst, *buf_pos); + } + addc_llc->hd.common.llc_type = SMC_LLC_ADD_LINK_CONT; + addc_llc->hd.length = sizeof(struct smc_llc_msg_add_link_cont); + if (lgr->role == SMC_CLNT) + addc_llc->hd.flags |= SMC_LLC_FLAG_RESP; + rc = smc_wr_tx_send(link, pend); +put_out: + smc_wr_tx_link_put(link); + return rc; +} + +static int smc_llc_cli_rkey_exchange(struct smc_link *link, + struct smc_link *link_new) +{ + struct smc_llc_msg_add_link_cont *addc_llc; + struct smc_link_group *lgr = link->lgr; + u8 max, num_rkeys_send, num_rkeys_recv; + struct smc_llc_qentry *qentry; + struct smc_buf_desc *buf_pos; + int buf_lst; + int rc = 0; + int i; + + down_write(&lgr->rmbs_lock); + num_rkeys_send = lgr->conns_num; + buf_pos = smc_llc_get_first_rmb(lgr, &buf_lst); + do { + qentry = smc_llc_wait(lgr, NULL, SMC_LLC_WAIT_TIME, + SMC_LLC_ADD_LINK_CONT); + if (!qentry) { + rc = -ETIMEDOUT; + break; + } + addc_llc = &qentry->msg.add_link_cont; + num_rkeys_recv = addc_llc->num_rkeys; + max = min_t(u8, num_rkeys_recv, SMC_LLC_RKEYS_PER_CONT_MSG); + for (i = 0; i < max; i++) { + smc_rtoken_set(lgr, link->link_idx, link_new->link_idx, + addc_llc->rt[i].rmb_key, + addc_llc->rt[i].rmb_vaddr_new, + addc_llc->rt[i].rmb_key_new); + num_rkeys_recv--; + } + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + rc = smc_llc_add_link_cont(link, link_new, &num_rkeys_send, + &buf_lst, &buf_pos); + if (rc) + break; + } while (num_rkeys_send || num_rkeys_recv); + + up_write(&lgr->rmbs_lock); + return rc; +} + +/* prepare and send an add link reject response */ +static int smc_llc_cli_add_link_reject(struct smc_llc_qentry *qentry) +{ + qentry->msg.raw.hdr.flags |= SMC_LLC_FLAG_RESP; + qentry->msg.raw.hdr.flags |= SMC_LLC_FLAG_ADD_LNK_REJ; + qentry->msg.raw.hdr.add_link_rej_rsn = SMC_LLC_REJ_RSN_NO_ALT_PATH; + smc_llc_init_msg_hdr(&qentry->msg.raw.hdr, qentry->link->lgr, + sizeof(qentry->msg)); + return smc_llc_send_message(qentry->link, &qentry->msg); +} + +static int smc_llc_cli_conf_link(struct smc_link *link, + struct smc_init_info *ini, + struct smc_link *link_new, + enum smc_lgr_type lgr_new_t) +{ + struct smc_link_group *lgr = link->lgr; + struct smc_llc_qentry *qentry = NULL; + int rc = 0; + + /* receive CONFIRM LINK request over RoCE fabric */ + qentry = smc_llc_wait(lgr, NULL, SMC_LLC_WAIT_FIRST_TIME, 0); + if (!qentry) { + rc = smc_llc_send_delete_link(link, link_new->link_id, + SMC_LLC_REQ, false, + SMC_LLC_DEL_LOST_PATH); + return -ENOLINK; + } + if (qentry->msg.raw.hdr.common.llc_type != SMC_LLC_CONFIRM_LINK) { + /* received DELETE_LINK instead */ + qentry->msg.raw.hdr.flags |= SMC_LLC_FLAG_RESP; + smc_llc_send_message(link, &qentry->msg); + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + return -ENOLINK; + } + smc_llc_save_peer_uid(qentry); + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + + rc = smc_ib_modify_qp_rts(link_new); + if (rc) { + smc_llc_send_delete_link(link, link_new->link_id, SMC_LLC_REQ, + false, SMC_LLC_DEL_LOST_PATH); + return -ENOLINK; + } + smc_wr_remember_qp_attr(link_new); + + rc = smcr_buf_reg_lgr(link_new); + if (rc) { + smc_llc_send_delete_link(link, link_new->link_id, SMC_LLC_REQ, + false, SMC_LLC_DEL_LOST_PATH); + return -ENOLINK; + } + + /* send CONFIRM LINK response over RoCE fabric */ + rc = smc_llc_send_confirm_link(link_new, SMC_LLC_RESP); + if (rc) { + smc_llc_send_delete_link(link, link_new->link_id, SMC_LLC_REQ, + false, 
SMC_LLC_DEL_LOST_PATH); + return -ENOLINK; + } + smc_llc_link_active(link_new); + if (lgr_new_t == SMC_LGR_ASYMMETRIC_LOCAL || + lgr_new_t == SMC_LGR_ASYMMETRIC_PEER) + smcr_lgr_set_type_asym(lgr, lgr_new_t, link_new->link_idx); + else + smcr_lgr_set_type(lgr, lgr_new_t); + return 0; +} + +static void smc_llc_save_add_link_rkeys(struct smc_link *link, + struct smc_link *link_new) +{ + struct smc_llc_msg_add_link_v2_ext *ext; + struct smc_link_group *lgr = link->lgr; + int max, i; + + ext = (struct smc_llc_msg_add_link_v2_ext *)((u8 *)lgr->wr_rx_buf_v2 + + SMC_WR_TX_SIZE); + max = min_t(u8, ext->num_rkeys, SMC_LLC_RKEYS_PER_MSG_V2); + down_write(&lgr->rmbs_lock); + for (i = 0; i < max; i++) { + smc_rtoken_set(lgr, link->link_idx, link_new->link_idx, + ext->rt[i].rmb_key, + ext->rt[i].rmb_vaddr_new, + ext->rt[i].rmb_key_new); + } + up_write(&lgr->rmbs_lock); +} + +static void smc_llc_save_add_link_info(struct smc_link *link, + struct smc_llc_msg_add_link *add_llc) +{ + link->peer_qpn = ntoh24(add_llc->sender_qp_num); + memcpy(link->peer_gid, add_llc->sender_gid, SMC_GID_SIZE); + memcpy(link->peer_mac, add_llc->sender_mac, ETH_ALEN); + link->peer_psn = ntoh24(add_llc->initial_psn); + link->peer_mtu = add_llc->qp_mtu; +} + +/* as an SMC client, process an add link request */ +int smc_llc_cli_add_link(struct smc_link *link, struct smc_llc_qentry *qentry) +{ + struct smc_llc_msg_add_link *llc = &qentry->msg.add_link; + enum smc_lgr_type lgr_new_t = SMC_LGR_SYMMETRIC; + struct smc_link_group *lgr = smc_get_lgr(link); + struct smc_init_info *ini = NULL; + struct smc_link *lnk_new = NULL; + int lnk_idx, rc = 0; + + if (!llc->qp_mtu) + goto out_reject; + + ini = kzalloc(sizeof(*ini), GFP_KERNEL); + if (!ini) { + rc = -ENOMEM; + goto out_reject; + } + + ini->vlan_id = lgr->vlan_id; + if (lgr->smc_version == SMC_V2) { + ini->check_smcrv2 = true; + ini->smcrv2.saddr = lgr->saddr; + ini->smcrv2.daddr = smc_ib_gid_to_ipv4(llc->sender_gid); + } + smc_pnet_find_alt_roce(lgr, ini, link->smcibdev); + if (!memcmp(llc->sender_gid, link->peer_gid, SMC_GID_SIZE) && + (lgr->smc_version == SMC_V2 || + !memcmp(llc->sender_mac, link->peer_mac, ETH_ALEN))) { + if (!ini->ib_dev && !ini->smcrv2.ib_dev_v2) + goto out_reject; + lgr_new_t = SMC_LGR_ASYMMETRIC_PEER; + } + if (lgr->smc_version == SMC_V2 && !ini->smcrv2.ib_dev_v2) { + lgr_new_t = SMC_LGR_ASYMMETRIC_LOCAL; + ini->smcrv2.ib_dev_v2 = link->smcibdev; + ini->smcrv2.ib_port_v2 = link->ibport; + } else if (lgr->smc_version < SMC_V2 && !ini->ib_dev) { + lgr_new_t = SMC_LGR_ASYMMETRIC_LOCAL; + ini->ib_dev = link->smcibdev; + ini->ib_port = link->ibport; + } + lnk_idx = smc_llc_alloc_alt_link(lgr, lgr_new_t); + if (lnk_idx < 0) + goto out_reject; + lnk_new = &lgr->lnk[lnk_idx]; + rc = smcr_link_init(lgr, lnk_new, lnk_idx, ini); + if (rc) + goto out_reject; + smc_llc_save_add_link_info(lnk_new, llc); + lnk_new->link_id = llc->link_num; /* SMC server assigns link id */ + smc_llc_link_set_uid(lnk_new); + + rc = smc_ib_ready_link(lnk_new); + if (rc) + goto out_clear_lnk; + + rc = smcr_buf_map_lgr(lnk_new); + if (rc) + goto out_clear_lnk; + + rc = smc_llc_send_add_link(link, + lnk_new->smcibdev->mac[lnk_new->ibport - 1], + lnk_new->gid, lnk_new, SMC_LLC_RESP); + if (rc) + goto out_clear_lnk; + if (lgr->smc_version == SMC_V2) { + smc_llc_save_add_link_rkeys(link, lnk_new); + } else { + rc = smc_llc_cli_rkey_exchange(link, lnk_new); + if (rc) { + rc = 0; + goto out_clear_lnk; + } + } + rc = smc_llc_cli_conf_link(link, ini, lnk_new, lgr_new_t); + if (!rc) + goto out; 
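+ /* error unwinding: clear the half-initialized link, then reject the ADD_LINK request towards the peer */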
+out_clear_lnk: + lnk_new->state = SMC_LNK_INACTIVE; + smcr_link_clear(lnk_new, false); +out_reject: + smc_llc_cli_add_link_reject(qentry); +out: + kfree(ini); + kfree(qentry); + return rc; +} + +static void smc_llc_send_request_add_link(struct smc_link *link) +{ + struct smc_llc_msg_req_add_link_v2 *llc; + struct smc_wr_tx_pend_priv *pend; + struct smc_wr_v2_buf *wr_buf; + struct smc_gidlist gidlist; + int rc, len, i; + + if (!smc_wr_tx_link_hold(link)) + return; + if (link->lgr->type == SMC_LGR_SYMMETRIC || + link->lgr->type == SMC_LGR_ASYMMETRIC_PEER) + goto put_out; + + smc_fill_gid_list(link->lgr, &gidlist, link->smcibdev, link->gid); + if (gidlist.len <= 1) + goto put_out; + + rc = smc_llc_add_pending_send_v2(link, &wr_buf, &pend); + if (rc) + goto put_out; + llc = (struct smc_llc_msg_req_add_link_v2 *)wr_buf; + memset(llc, 0, SMC_WR_TX_SIZE); + + llc->hd.common.llc_type = SMC_LLC_REQ_ADD_LINK; + for (i = 0; i < gidlist.len; i++) + memcpy(llc->gid[i], gidlist.list[i], sizeof(gidlist.list[0])); + llc->gid_cnt = gidlist.len; + len = sizeof(*llc) + (gidlist.len * sizeof(gidlist.list[0])); + smc_llc_init_msg_hdr(&llc->hd, link->lgr, len); + rc = smc_wr_tx_v2_send(link, pend, len); + if (!rc) + /* set REQ_ADD_LINK flow and wait for response from peer */ + link->lgr->llc_flow_lcl.type = SMC_LLC_FLOW_REQ_ADD_LINK; +put_out: + smc_wr_tx_link_put(link); +} + +/* as an SMC client, invite server to start the add_link processing */ +static void smc_llc_cli_add_link_invite(struct smc_link *link, + struct smc_llc_qentry *qentry) +{ + struct smc_link_group *lgr = smc_get_lgr(link); + struct smc_init_info *ini = NULL; + + if (lgr->smc_version == SMC_V2) { + smc_llc_send_request_add_link(link); + goto out; + } + + if (lgr->type == SMC_LGR_SYMMETRIC || + lgr->type == SMC_LGR_ASYMMETRIC_PEER) + goto out; + + ini = kzalloc(sizeof(*ini), GFP_KERNEL); + if (!ini) + goto out; + + ini->vlan_id = lgr->vlan_id; + smc_pnet_find_alt_roce(lgr, ini, link->smcibdev); + if (!ini->ib_dev) + goto out; + + smc_llc_send_add_link(link, ini->ib_dev->mac[ini->ib_port - 1], + ini->ib_gid, NULL, SMC_LLC_REQ); +out: + kfree(ini); + kfree(qentry); +} + +static bool smc_llc_is_empty_llc_message(union smc_llc_msg *llc) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(llc->raw.data); i++) + if (llc->raw.data[i]) + return false; + return true; +} + +static bool smc_llc_is_local_add_link(union smc_llc_msg *llc) +{ + if (llc->raw.hdr.common.llc_type == SMC_LLC_ADD_LINK && + smc_llc_is_empty_llc_message(llc)) + return true; + return false; +} + +static void smc_llc_process_cli_add_link(struct smc_link_group *lgr) +{ + struct smc_llc_qentry *qentry; + + qentry = smc_llc_flow_qentry_clr(&lgr->llc_flow_lcl); + + mutex_lock(&lgr->llc_conf_mutex); + if (smc_llc_is_local_add_link(&qentry->msg)) + smc_llc_cli_add_link_invite(qentry->link, qentry); + else + smc_llc_cli_add_link(qentry->link, qentry); + mutex_unlock(&lgr->llc_conf_mutex); +} + +static int smc_llc_active_link_count(struct smc_link_group *lgr) +{ + int i, link_count = 0; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_active(&lgr->lnk[i])) + continue; + link_count++; + } + return link_count; +} + +/* find the asymmetric link when 3 links are established */ +static struct smc_link *smc_llc_find_asym_link(struct smc_link_group *lgr) +{ + int asym_idx = -ENOENT; + int i, j, k; + bool found; + + /* determine asymmetric link */ + found = false; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + for (j = i + 1; j < SMC_LINKS_PER_LGR_MAX; j++) { + if 
(!smc_link_usable(&lgr->lnk[i]) || + !smc_link_usable(&lgr->lnk[j])) + continue; + if (!memcmp(lgr->lnk[i].gid, lgr->lnk[j].gid, + SMC_GID_SIZE)) { + found = true; /* asym_lnk is i or j */ + break; + } + } + if (found) + break; + } + if (!found) + goto out; /* no asymmetric link */ + for (k = 0; k < SMC_LINKS_PER_LGR_MAX; k++) { + if (!smc_link_usable(&lgr->lnk[k])) + continue; + if (k != i && + !memcmp(lgr->lnk[i].peer_gid, lgr->lnk[k].peer_gid, + SMC_GID_SIZE)) { + asym_idx = i; + break; + } + if (k != j && + !memcmp(lgr->lnk[j].peer_gid, lgr->lnk[k].peer_gid, + SMC_GID_SIZE)) { + asym_idx = j; + break; + } + } +out: + return (asym_idx < 0) ? NULL : &lgr->lnk[asym_idx]; +} + +static void smc_llc_delete_asym_link(struct smc_link_group *lgr) +{ + struct smc_link *lnk_new = NULL, *lnk_asym; + struct smc_llc_qentry *qentry; + int rc; + + lnk_asym = smc_llc_find_asym_link(lgr); + if (!lnk_asym) + return; /* no asymmetric link */ + if (!smc_link_downing(&lnk_asym->state)) + return; + lnk_new = smc_switch_conns(lgr, lnk_asym, false); + smc_wr_tx_wait_no_pending_sends(lnk_asym); + if (!lnk_new) + goto out_free; + /* change flow type from ADD_LINK into DEL_LINK */ + lgr->llc_flow_lcl.type = SMC_LLC_FLOW_DEL_LINK; + rc = smc_llc_send_delete_link(lnk_new, lnk_asym->link_id, SMC_LLC_REQ, + true, SMC_LLC_DEL_NO_ASYM_NEEDED); + if (rc) { + smcr_link_down_cond(lnk_new); + goto out_free; + } + qentry = smc_llc_wait(lgr, lnk_new, SMC_LLC_WAIT_TIME, + SMC_LLC_DELETE_LINK); + if (!qentry) { + smcr_link_down_cond(lnk_new); + goto out_free; + } + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); +out_free: + smcr_link_clear(lnk_asym, true); +} + +static int smc_llc_srv_rkey_exchange(struct smc_link *link, + struct smc_link *link_new) +{ + struct smc_llc_msg_add_link_cont *addc_llc; + struct smc_link_group *lgr = link->lgr; + u8 max, num_rkeys_send, num_rkeys_recv; + struct smc_llc_qentry *qentry = NULL; + struct smc_buf_desc *buf_pos; + int buf_lst; + int rc = 0; + int i; + + down_write(&lgr->rmbs_lock); + num_rkeys_send = lgr->conns_num; + buf_pos = smc_llc_get_first_rmb(lgr, &buf_lst); + do { + smc_llc_add_link_cont(link, link_new, &num_rkeys_send, + &buf_lst, &buf_pos); + qentry = smc_llc_wait(lgr, link, SMC_LLC_WAIT_TIME, + SMC_LLC_ADD_LINK_CONT); + if (!qentry) { + rc = -ETIMEDOUT; + goto out; + } + addc_llc = &qentry->msg.add_link_cont; + num_rkeys_recv = addc_llc->num_rkeys; + max = min_t(u8, num_rkeys_recv, SMC_LLC_RKEYS_PER_CONT_MSG); + for (i = 0; i < max; i++) { + smc_rtoken_set(lgr, link->link_idx, link_new->link_idx, + addc_llc->rt[i].rmb_key, + addc_llc->rt[i].rmb_vaddr_new, + addc_llc->rt[i].rmb_key_new); + num_rkeys_recv--; + } + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + } while (num_rkeys_send || num_rkeys_recv); +out: + up_write(&lgr->rmbs_lock); + return rc; +} + +static int smc_llc_srv_conf_link(struct smc_link *link, + struct smc_link *link_new, + enum smc_lgr_type lgr_new_t) +{ + struct smc_link_group *lgr = link->lgr; + struct smc_llc_qentry *qentry = NULL; + int rc; + + /* send CONFIRM LINK request over the RoCE fabric */ + rc = smc_llc_send_confirm_link(link_new, SMC_LLC_REQ); + if (rc) + return -ENOLINK; + /* receive CONFIRM LINK response over the RoCE fabric */ + qentry = smc_llc_wait(lgr, link, SMC_LLC_WAIT_FIRST_TIME, 0); + if (!qentry || + qentry->msg.raw.hdr.common.llc_type != SMC_LLC_CONFIRM_LINK) { + /* send DELETE LINK */ + smc_llc_send_delete_link(link, link_new->link_id, SMC_LLC_REQ, + false, SMC_LLC_DEL_LOST_PATH); + if (qentry) + 
smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + return -ENOLINK; + } + smc_llc_save_peer_uid(qentry); + smc_llc_link_active(link_new); + if (lgr_new_t == SMC_LGR_ASYMMETRIC_LOCAL || + lgr_new_t == SMC_LGR_ASYMMETRIC_PEER) + smcr_lgr_set_type_asym(lgr, lgr_new_t, link_new->link_idx); + else + smcr_lgr_set_type(lgr, lgr_new_t); + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + return 0; +} + +static void smc_llc_send_req_add_link_response(struct smc_llc_qentry *qentry) +{ + qentry->msg.raw.hdr.flags |= SMC_LLC_FLAG_RESP; + smc_llc_init_msg_hdr(&qentry->msg.raw.hdr, qentry->link->lgr, + sizeof(qentry->msg)); + memset(&qentry->msg.raw.data, 0, sizeof(qentry->msg.raw.data)); + smc_llc_send_message(qentry->link, &qentry->msg); +} + +int smc_llc_srv_add_link(struct smc_link *link, + struct smc_llc_qentry *req_qentry) +{ + enum smc_lgr_type lgr_new_t = SMC_LGR_SYMMETRIC; + struct smc_link_group *lgr = link->lgr; + struct smc_llc_msg_add_link *add_llc; + struct smc_llc_qentry *qentry = NULL; + bool send_req_add_link_resp = false; + struct smc_link *link_new = NULL; + struct smc_init_info *ini = NULL; + int lnk_idx, rc = 0; + + if (req_qentry && + req_qentry->msg.raw.hdr.common.llc_type == SMC_LLC_REQ_ADD_LINK) + send_req_add_link_resp = true; + + ini = kzalloc(sizeof(*ini), GFP_KERNEL); + if (!ini) { + rc = -ENOMEM; + goto out; + } + + /* ignore client add link recommendation, start new flow */ + ini->vlan_id = lgr->vlan_id; + if (lgr->smc_version == SMC_V2) { + ini->check_smcrv2 = true; + ini->smcrv2.saddr = lgr->saddr; + if (send_req_add_link_resp) { + struct smc_llc_msg_req_add_link_v2 *req_add = + &req_qentry->msg.req_add_link; + + ini->smcrv2.daddr = smc_ib_gid_to_ipv4(req_add->gid[0]); + } + } + smc_pnet_find_alt_roce(lgr, ini, link->smcibdev); + if (lgr->smc_version == SMC_V2 && !ini->smcrv2.ib_dev_v2) { + lgr_new_t = SMC_LGR_ASYMMETRIC_LOCAL; + ini->smcrv2.ib_dev_v2 = link->smcibdev; + ini->smcrv2.ib_port_v2 = link->ibport; + } else if (lgr->smc_version < SMC_V2 && !ini->ib_dev) { + lgr_new_t = SMC_LGR_ASYMMETRIC_LOCAL; + ini->ib_dev = link->smcibdev; + ini->ib_port = link->ibport; + } + lnk_idx = smc_llc_alloc_alt_link(lgr, lgr_new_t); + if (lnk_idx < 0) { + rc = 0; + goto out; + } + + rc = smcr_link_init(lgr, &lgr->lnk[lnk_idx], lnk_idx, ini); + if (rc) + goto out; + link_new = &lgr->lnk[lnk_idx]; + + rc = smcr_buf_map_lgr(link_new); + if (rc) + goto out_err; + + rc = smc_llc_send_add_link(link, + link_new->smcibdev->mac[link_new->ibport-1], + link_new->gid, link_new, SMC_LLC_REQ); + if (rc) + goto out_err; + send_req_add_link_resp = false; + /* receive ADD LINK response over the RoCE fabric */ + qentry = smc_llc_wait(lgr, link, SMC_LLC_WAIT_TIME, SMC_LLC_ADD_LINK); + if (!qentry) { + rc = -ETIMEDOUT; + goto out_err; + } + add_llc = &qentry->msg.add_link; + if (add_llc->hd.flags & SMC_LLC_FLAG_ADD_LNK_REJ) { + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + rc = -ENOLINK; + goto out_err; + } + if (lgr->type == SMC_LGR_SINGLE && + (!memcmp(add_llc->sender_gid, link->peer_gid, SMC_GID_SIZE) && + (lgr->smc_version == SMC_V2 || + !memcmp(add_llc->sender_mac, link->peer_mac, ETH_ALEN)))) { + lgr_new_t = SMC_LGR_ASYMMETRIC_PEER; + } + smc_llc_save_add_link_info(link_new, add_llc); + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + + rc = smc_ib_ready_link(link_new); + if (rc) + goto out_err; + rc = smcr_buf_reg_lgr(link_new); + if (rc) + goto out_err; + if (lgr->smc_version == SMC_V2) { + smc_llc_save_add_link_rkeys(link, link_new); + } else { + rc = smc_llc_srv_rkey_exchange(link, link_new); + if 
(rc) + goto out_err; + } + rc = smc_llc_srv_conf_link(link, link_new, lgr_new_t); + if (rc) + goto out_err; + kfree(ini); + return 0; +out_err: + if (link_new) { + link_new->state = SMC_LNK_INACTIVE; + smcr_link_clear(link_new, false); + } +out: + kfree(ini); + if (send_req_add_link_resp) + smc_llc_send_req_add_link_response(req_qentry); + return rc; +} + +static void smc_llc_process_srv_add_link(struct smc_link_group *lgr) +{ + struct smc_link *link = lgr->llc_flow_lcl.qentry->link; + struct smc_llc_qentry *qentry; + int rc; + + qentry = smc_llc_flow_qentry_clr(&lgr->llc_flow_lcl); + + mutex_lock(&lgr->llc_conf_mutex); + rc = smc_llc_srv_add_link(link, qentry); + if (!rc && lgr->type == SMC_LGR_SYMMETRIC) { + /* delete any asymmetric link */ + smc_llc_delete_asym_link(lgr); + } + mutex_unlock(&lgr->llc_conf_mutex); + kfree(qentry); +} + +/* enqueue a local add_link req to trigger a new add_link flow */ +void smc_llc_add_link_local(struct smc_link *link) +{ + struct smc_llc_msg_add_link add_llc = {}; + + add_llc.hd.common.llc_type = SMC_LLC_ADD_LINK; + smc_llc_init_msg_hdr(&add_llc.hd, link->lgr, sizeof(add_llc)); + /* no dev and port needed */ + smc_llc_enqueue(link, (union smc_llc_msg *)&add_llc); +} + +/* worker to process an add link message */ +static void smc_llc_add_link_work(struct work_struct *work) +{ + struct smc_link_group *lgr = container_of(work, struct smc_link_group, + llc_add_link_work); + + if (list_empty(&lgr->list)) { + /* link group is terminating */ + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + goto out; + } + + if (lgr->role == SMC_CLNT) + smc_llc_process_cli_add_link(lgr); + else + smc_llc_process_srv_add_link(lgr); +out: + if (lgr->llc_flow_lcl.type != SMC_LLC_FLOW_REQ_ADD_LINK) + smc_llc_flow_stop(lgr, &lgr->llc_flow_lcl); +} + +/* enqueue a local del_link msg to trigger a new del_link flow, + * called only for role SMC_SERV + */ +void smc_llc_srv_delete_link_local(struct smc_link *link, u8 del_link_id) +{ + struct smc_llc_msg_del_link del_llc = {}; + + del_llc.hd.common.llc_type = SMC_LLC_DELETE_LINK; + smc_llc_init_msg_hdr(&del_llc.hd, link->lgr, sizeof(del_llc)); + del_llc.link_num = del_link_id; + del_llc.reason = htonl(SMC_LLC_DEL_LOST_PATH); + del_llc.hd.flags |= SMC_LLC_FLAG_DEL_LINK_ORDERLY; + smc_llc_enqueue(link, (union smc_llc_msg *)&del_llc); +} + +static void smc_llc_process_cli_delete_link(struct smc_link_group *lgr) +{ + struct smc_link *lnk_del = NULL, *lnk_asym, *lnk; + struct smc_llc_msg_del_link *del_llc; + struct smc_llc_qentry *qentry; + int active_links; + int lnk_idx; + + qentry = smc_llc_flow_qentry_clr(&lgr->llc_flow_lcl); + lnk = qentry->link; + del_llc = &qentry->msg.delete_link; + + if (del_llc->hd.flags & SMC_LLC_FLAG_DEL_LINK_ALL) { + smc_lgr_terminate_sched(lgr); + goto out; + } + mutex_lock(&lgr->llc_conf_mutex); + /* delete single link */ + for (lnk_idx = 0; lnk_idx < SMC_LINKS_PER_LGR_MAX; lnk_idx++) { + if (lgr->lnk[lnk_idx].link_id != del_llc->link_num) + continue; + lnk_del = &lgr->lnk[lnk_idx]; + break; + } + del_llc->hd.flags |= SMC_LLC_FLAG_RESP; + if (!lnk_del) { + /* link was not found */ + del_llc->reason = htonl(SMC_LLC_DEL_NOLNK); + smc_llc_send_message(lnk, &qentry->msg); + goto out_unlock; + } + lnk_asym = smc_llc_find_asym_link(lgr); + + del_llc->reason = 0; + smc_llc_send_message(lnk, &qentry->msg); /* response */ + + if (smc_link_downing(&lnk_del->state)) + smc_switch_conns(lgr, lnk_del, false); + smcr_link_clear(lnk_del, true); + + active_links = smc_llc_active_link_count(lgr); + if (lnk_del == lnk_asym) { + /* 
expected deletion of asym link, don't change lgr state */ + } else if (active_links == 1) { + smcr_lgr_set_type(lgr, SMC_LGR_SINGLE); + } else if (!active_links) { + smcr_lgr_set_type(lgr, SMC_LGR_NONE); + smc_lgr_terminate_sched(lgr); + } +out_unlock: + mutex_unlock(&lgr->llc_conf_mutex); +out: + kfree(qentry); +} + +/* try to send a DELETE LINK ALL request on any active link, + * waiting for send completion + */ +void smc_llc_send_link_delete_all(struct smc_link_group *lgr, bool ord, u32 rsn) +{ + struct smc_llc_msg_del_link delllc = {}; + int i; + + delllc.hd.common.llc_type = SMC_LLC_DELETE_LINK; + smc_llc_init_msg_hdr(&delllc.hd, lgr, sizeof(delllc)); + if (ord) + delllc.hd.flags |= SMC_LLC_FLAG_DEL_LINK_ORDERLY; + delllc.hd.flags |= SMC_LLC_FLAG_DEL_LINK_ALL; + delllc.reason = htonl(rsn); + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (!smc_link_sendable(&lgr->lnk[i])) + continue; + if (!smc_llc_send_message_wait(&lgr->lnk[i], &delllc)) + break; + } +} + +static void smc_llc_process_srv_delete_link(struct smc_link_group *lgr) +{ + struct smc_llc_msg_del_link *del_llc; + struct smc_link *lnk, *lnk_del; + struct smc_llc_qentry *qentry; + int active_links; + int i; + + mutex_lock(&lgr->llc_conf_mutex); + qentry = smc_llc_flow_qentry_clr(&lgr->llc_flow_lcl); + lnk = qentry->link; + del_llc = &qentry->msg.delete_link; + + if (qentry->msg.delete_link.hd.flags & SMC_LLC_FLAG_DEL_LINK_ALL) { + /* delete entire lgr */ + smc_llc_send_link_delete_all(lgr, true, ntohl( + qentry->msg.delete_link.reason)); + smc_lgr_terminate_sched(lgr); + goto out; + } + /* delete single link */ + lnk_del = NULL; + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) { + if (lgr->lnk[i].link_id == del_llc->link_num) { + lnk_del = &lgr->lnk[i]; + break; + } + } + if (!lnk_del) + goto out; /* asymmetric link already deleted */ + + if (smc_link_downing(&lnk_del->state)) { + if (smc_switch_conns(lgr, lnk_del, false)) + smc_wr_tx_wait_no_pending_sends(lnk_del); + } + if (!list_empty(&lgr->list)) { + /* qentry is either a request from peer (send it back to + * initiate the DELETE_LINK processing), or a locally + * enqueued DELETE_LINK request (forward it) + */ + if (!smc_llc_send_message(lnk, &qentry->msg)) { + struct smc_llc_qentry *qentry2; + + qentry2 = smc_llc_wait(lgr, lnk, SMC_LLC_WAIT_TIME, + SMC_LLC_DELETE_LINK); + if (qentry2) + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + } + } + smcr_link_clear(lnk_del, true); + + active_links = smc_llc_active_link_count(lgr); + if (active_links == 1) { + smcr_lgr_set_type(lgr, SMC_LGR_SINGLE); + } else if (!active_links) { + smcr_lgr_set_type(lgr, SMC_LGR_NONE); + smc_lgr_terminate_sched(lgr); + } + + if (lgr->type == SMC_LGR_SINGLE && !list_empty(&lgr->list)) { + /* trigger setup of asymm alt link */ + smc_llc_add_link_local(lnk); + } +out: + mutex_unlock(&lgr->llc_conf_mutex); + kfree(qentry); +} + +static void smc_llc_delete_link_work(struct work_struct *work) +{ + struct smc_link_group *lgr = container_of(work, struct smc_link_group, + llc_del_link_work); + + if (list_empty(&lgr->list)) { + /* link group is terminating */ + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + goto out; + } + + if (lgr->role == SMC_CLNT) + smc_llc_process_cli_delete_link(lgr); + else + smc_llc_process_srv_delete_link(lgr); +out: + smc_llc_flow_stop(lgr, &lgr->llc_flow_lcl); +} + +/* process a confirm_rkey request from peer, remote flow */ +static void smc_llc_rmt_conf_rkey(struct smc_link_group *lgr) +{ + struct smc_llc_msg_confirm_rkey *llc; + struct smc_llc_qentry *qentry; + struct smc_link 
*link; + int num_entries; + int rk_idx; + int i; + + qentry = lgr->llc_flow_rmt.qentry; + llc = &qentry->msg.confirm_rkey; + link = qentry->link; + + num_entries = llc->rtoken[0].num_rkeys; + if (num_entries > SMC_LLC_RKEYS_PER_MSG) + goto out_err; + /* first rkey entry is for receiving link */ + rk_idx = smc_rtoken_add(link, + llc->rtoken[0].rmb_vaddr, + llc->rtoken[0].rmb_key); + if (rk_idx < 0) + goto out_err; + + for (i = 1; i <= min_t(u8, num_entries, SMC_LLC_RKEYS_PER_MSG - 1); i++) + smc_rtoken_set2(lgr, rk_idx, llc->rtoken[i].link_id, + llc->rtoken[i].rmb_vaddr, + llc->rtoken[i].rmb_key); + /* max links is 3 so there is no need to support conf_rkey_cont msgs */ + goto out; +out_err: + llc->hd.flags |= SMC_LLC_FLAG_RKEY_NEG; + llc->hd.flags |= SMC_LLC_FLAG_RKEY_RETRY; +out: + llc->hd.flags |= SMC_LLC_FLAG_RESP; + smc_llc_init_msg_hdr(&llc->hd, link->lgr, sizeof(*llc)); + smc_llc_send_message(link, &qentry->msg); + smc_llc_flow_qentry_del(&lgr->llc_flow_rmt); +} + +/* process a delete_rkey request from peer, remote flow */ +static void smc_llc_rmt_delete_rkey(struct smc_link_group *lgr) +{ + struct smc_llc_msg_delete_rkey *llc; + struct smc_llc_qentry *qentry; + struct smc_link *link; + u8 err_mask = 0; + int i, max; + + qentry = lgr->llc_flow_rmt.qentry; + llc = &qentry->msg.delete_rkey; + link = qentry->link; + + if (lgr->smc_version == SMC_V2) { + struct smc_llc_msg_delete_rkey_v2 *llcv2; + + memcpy(lgr->wr_rx_buf_v2, llc, sizeof(*llc)); + llcv2 = (struct smc_llc_msg_delete_rkey_v2 *)lgr->wr_rx_buf_v2; + llcv2->num_inval_rkeys = 0; + + max = min_t(u8, llcv2->num_rkeys, SMC_LLC_RKEYS_PER_MSG_V2); + for (i = 0; i < max; i++) { + if (smc_rtoken_delete(link, llcv2->rkey[i])) + llcv2->num_inval_rkeys++; + } + memset(&llc->rkey[0], 0, sizeof(llc->rkey)); + memset(&llc->reserved2, 0, sizeof(llc->reserved2)); + smc_llc_init_msg_hdr(&llc->hd, link->lgr, sizeof(*llc)); + if (llcv2->num_inval_rkeys) { + llc->hd.flags |= SMC_LLC_FLAG_RKEY_NEG; + llc->err_mask = llcv2->num_inval_rkeys; + } + goto finish; + } + + max = min_t(u8, llc->num_rkeys, SMC_LLC_DEL_RKEY_MAX); + for (i = 0; i < max; i++) { + if (smc_rtoken_delete(link, llc->rkey[i])) + err_mask |= 1 << (SMC_LLC_DEL_RKEY_MAX - 1 - i); + } + if (err_mask) { + llc->hd.flags |= SMC_LLC_FLAG_RKEY_NEG; + llc->err_mask = err_mask; + } +finish: + llc->hd.flags |= SMC_LLC_FLAG_RESP; + smc_llc_send_message(link, &qentry->msg); + smc_llc_flow_qentry_del(&lgr->llc_flow_rmt); +} + +static void smc_llc_protocol_violation(struct smc_link_group *lgr, u8 type) +{ + pr_warn_ratelimited("smc: SMC-R lg %*phN net %llu LLC protocol violation: " + "llc_type %d\n", SMC_LGR_ID_SIZE, &lgr->id, + lgr->net->net_cookie, type); + smc_llc_set_termination_rsn(lgr, SMC_LLC_DEL_PROT_VIOL); + smc_lgr_terminate_sched(lgr); +} + +/* flush the llc event queue */ +static void smc_llc_event_flush(struct smc_link_group *lgr) +{ + struct smc_llc_qentry *qentry, *q; + + spin_lock_bh(&lgr->llc_event_q_lock); + list_for_each_entry_safe(qentry, q, &lgr->llc_event_q, list) { + list_del_init(&qentry->list); + kfree(qentry); + } + spin_unlock_bh(&lgr->llc_event_q_lock); +} + +static void smc_llc_event_handler(struct smc_llc_qentry *qentry) +{ + union smc_llc_msg *llc = &qentry->msg; + struct smc_link *link = qentry->link; + struct smc_link_group *lgr = link->lgr; + + if (!smc_link_usable(link)) + goto out; + + switch (llc->raw.hdr.common.llc_type) { + case SMC_LLC_TEST_LINK: + llc->test_link.hd.flags |= SMC_LLC_FLAG_RESP; + smc_llc_send_message(link, llc); + break; + case 
SMC_LLC_ADD_LINK: + if (list_empty(&lgr->list)) + goto out; /* lgr is terminating */ + if (lgr->role == SMC_CLNT) { + if (smc_llc_is_local_add_link(llc)) { + if (lgr->llc_flow_lcl.type == + SMC_LLC_FLOW_ADD_LINK) + break; /* add_link in progress */ + if (smc_llc_flow_start(&lgr->llc_flow_lcl, + qentry)) { + schedule_work(&lgr->llc_add_link_work); + } + return; + } + if (lgr->llc_flow_lcl.type == SMC_LLC_FLOW_ADD_LINK && + !lgr->llc_flow_lcl.qentry) { + /* a flow is waiting for this message */ + smc_llc_flow_qentry_set(&lgr->llc_flow_lcl, + qentry); + wake_up(&lgr->llc_msg_waiter); + return; + } + if (lgr->llc_flow_lcl.type == + SMC_LLC_FLOW_REQ_ADD_LINK) { + /* server started add_link processing */ + lgr->llc_flow_lcl.type = SMC_LLC_FLOW_ADD_LINK; + smc_llc_flow_qentry_set(&lgr->llc_flow_lcl, + qentry); + schedule_work(&lgr->llc_add_link_work); + return; + } + if (smc_llc_flow_start(&lgr->llc_flow_lcl, qentry)) { + schedule_work(&lgr->llc_add_link_work); + } + } else if (smc_llc_flow_start(&lgr->llc_flow_lcl, qentry)) { + /* as smc server, handle client suggestion */ + schedule_work(&lgr->llc_add_link_work); + } + return; + case SMC_LLC_CONFIRM_LINK: + case SMC_LLC_ADD_LINK_CONT: + if (lgr->llc_flow_lcl.type != SMC_LLC_FLOW_NONE) { + /* a flow is waiting for this message */ + smc_llc_flow_qentry_set(&lgr->llc_flow_lcl, qentry); + wake_up(&lgr->llc_msg_waiter); + return; + } + break; + case SMC_LLC_DELETE_LINK: + if (lgr->llc_flow_lcl.type == SMC_LLC_FLOW_ADD_LINK && + !lgr->llc_flow_lcl.qentry) { + /* DEL LINK REQ during ADD LINK SEQ */ + smc_llc_flow_qentry_set(&lgr->llc_flow_lcl, qentry); + wake_up(&lgr->llc_msg_waiter); + } else if (smc_llc_flow_start(&lgr->llc_flow_lcl, qentry)) { + schedule_work(&lgr->llc_del_link_work); + } + return; + case SMC_LLC_CONFIRM_RKEY: + /* new request from remote, assign to remote flow */ + if (smc_llc_flow_start(&lgr->llc_flow_rmt, qentry)) { + /* process here, does not wait for more llc msgs */ + smc_llc_rmt_conf_rkey(lgr); + smc_llc_flow_stop(lgr, &lgr->llc_flow_rmt); + } + return; + case SMC_LLC_CONFIRM_RKEY_CONT: + /* not used because max links is 3, and 3 rkeys fit into + * one CONFIRM_RKEY message + */ + break; + case SMC_LLC_DELETE_RKEY: + /* new request from remote, assign to remote flow */ + if (smc_llc_flow_start(&lgr->llc_flow_rmt, qentry)) { + /* process here, does not wait for more llc msgs */ + smc_llc_rmt_delete_rkey(lgr); + smc_llc_flow_stop(lgr, &lgr->llc_flow_rmt); + } + return; + case SMC_LLC_REQ_ADD_LINK: + /* handle response here, smc_llc_flow_stop() cannot be called + * in tasklet context + */ + if (lgr->role == SMC_CLNT && + lgr->llc_flow_lcl.type == SMC_LLC_FLOW_REQ_ADD_LINK && + (llc->raw.hdr.flags & SMC_LLC_FLAG_RESP)) { + smc_llc_flow_stop(link->lgr, &lgr->llc_flow_lcl); + } else if (lgr->role == SMC_SERV) { + if (smc_llc_flow_start(&lgr->llc_flow_lcl, qentry)) { + /* as smc server, handle client suggestion */ + lgr->llc_flow_lcl.type = SMC_LLC_FLOW_ADD_LINK; + schedule_work(&lgr->llc_add_link_work); + } + return; + } + break; + default: + smc_llc_protocol_violation(lgr, llc->raw.hdr.common.type); + break; + } +out: + kfree(qentry); +} + +/* worker to process llc messages on the event queue */ +static void smc_llc_event_work(struct work_struct *work) +{ + struct smc_link_group *lgr = container_of(work, struct smc_link_group, + llc_event_work); + struct smc_llc_qentry *qentry; + + if (!lgr->llc_flow_lcl.type && lgr->delayed_event) { + qentry = lgr->delayed_event; + lgr->delayed_event = NULL; + if (smc_link_usable(qentry->link)) + 
smc_llc_event_handler(qentry); + else + kfree(qentry); + } + +again: + spin_lock_bh(&lgr->llc_event_q_lock); + if (!list_empty(&lgr->llc_event_q)) { + qentry = list_first_entry(&lgr->llc_event_q, + struct smc_llc_qentry, list); + list_del_init(&qentry->list); + spin_unlock_bh(&lgr->llc_event_q_lock); + smc_llc_event_handler(qentry); + goto again; + } + spin_unlock_bh(&lgr->llc_event_q_lock); +} + +/* process llc responses in tasklet context */ +static void smc_llc_rx_response(struct smc_link *link, + struct smc_llc_qentry *qentry) +{ + enum smc_llc_flowtype flowtype = link->lgr->llc_flow_lcl.type; + struct smc_llc_flow *flow = &link->lgr->llc_flow_lcl; + u8 llc_type = qentry->msg.raw.hdr.common.llc_type; + + switch (llc_type) { + case SMC_LLC_TEST_LINK: + if (smc_link_active(link)) + complete(&link->llc_testlink_resp); + break; + case SMC_LLC_ADD_LINK: + case SMC_LLC_ADD_LINK_CONT: + case SMC_LLC_CONFIRM_LINK: + if (flowtype != SMC_LLC_FLOW_ADD_LINK || flow->qentry) + break; /* drop out-of-flow response */ + goto assign; + case SMC_LLC_DELETE_LINK: + if (flowtype != SMC_LLC_FLOW_DEL_LINK || flow->qentry) + break; /* drop out-of-flow response */ + goto assign; + case SMC_LLC_CONFIRM_RKEY: + case SMC_LLC_DELETE_RKEY: + if (flowtype != SMC_LLC_FLOW_RKEY || flow->qentry) + break; /* drop out-of-flow response */ + goto assign; + case SMC_LLC_CONFIRM_RKEY_CONT: + /* not used because max links is 3 */ + break; + default: + smc_llc_protocol_violation(link->lgr, + qentry->msg.raw.hdr.common.type); + break; + } + kfree(qentry); + return; +assign: + /* assign responses to the local flow, we requested them */ + smc_llc_flow_qentry_set(&link->lgr->llc_flow_lcl, qentry); + wake_up(&link->lgr->llc_msg_waiter); +} + +static void smc_llc_enqueue(struct smc_link *link, union smc_llc_msg *llc) +{ + struct smc_link_group *lgr = link->lgr; + struct smc_llc_qentry *qentry; + unsigned long flags; + + qentry = kmalloc(sizeof(*qentry), GFP_ATOMIC); + if (!qentry) + return; + qentry->link = link; + INIT_LIST_HEAD(&qentry->list); + memcpy(&qentry->msg, llc, sizeof(union smc_llc_msg)); + + /* process responses immediately */ + if ((llc->raw.hdr.flags & SMC_LLC_FLAG_RESP) && + llc->raw.hdr.common.llc_type != SMC_LLC_REQ_ADD_LINK) { + smc_llc_rx_response(link, qentry); + return; + } + + /* add requests to event queue */ + spin_lock_irqsave(&lgr->llc_event_q_lock, flags); + list_add_tail(&qentry->list, &lgr->llc_event_q); + spin_unlock_irqrestore(&lgr->llc_event_q_lock, flags); + queue_work(system_highpri_wq, &lgr->llc_event_work); +} + +/* copy received msg and add it to the event queue */ +static void smc_llc_rx_handler(struct ib_wc *wc, void *buf) +{ + struct smc_link *link = (struct smc_link *)wc->qp->qp_context; + union smc_llc_msg *llc = buf; + + if (wc->byte_len < sizeof(*llc)) + return; /* short message */ + if (!llc->raw.hdr.common.llc_version) { + if (llc->raw.hdr.length != sizeof(*llc)) + return; /* invalid message */ + } else { + if (llc->raw.hdr.length_v2 < sizeof(*llc)) + return; /* invalid message */ + } + + smc_llc_enqueue(link, llc); +} + +/***************************** worker, utils *********************************/ + +static void smc_llc_testlink_work(struct work_struct *work) +{ + struct smc_link *link = container_of(to_delayed_work(work), + struct smc_link, llc_testlink_wrk); + unsigned long next_interval; + unsigned long expire_time; + u8 user_data[16] = { 0 }; + int rc; + + if (!smc_link_active(link)) + return; /* don't reschedule worker */ + expire_time = link->wr_rx_tstamp + 
link->llc_testlink_time; + if (time_is_after_jiffies(expire_time)) { + next_interval = expire_time - jiffies; + goto out; + } + reinit_completion(&link->llc_testlink_resp); + smc_llc_send_test_link(link, user_data); + /* receive TEST LINK response over RoCE fabric */ + rc = wait_for_completion_interruptible_timeout(&link->llc_testlink_resp, + SMC_LLC_WAIT_TIME); + if (!smc_link_active(link)) + return; /* link state changed */ + if (rc <= 0) { + smcr_link_down_cond_sched(link); + return; + } + next_interval = link->llc_testlink_time; +out: + schedule_delayed_work(&link->llc_testlink_wrk, next_interval); +} + +void smc_llc_lgr_init(struct smc_link_group *lgr, struct smc_sock *smc) +{ + struct net *net = sock_net(smc->clcsock->sk); + + INIT_WORK(&lgr->llc_event_work, smc_llc_event_work); + INIT_WORK(&lgr->llc_add_link_work, smc_llc_add_link_work); + INIT_WORK(&lgr->llc_del_link_work, smc_llc_delete_link_work); + INIT_LIST_HEAD(&lgr->llc_event_q); + spin_lock_init(&lgr->llc_event_q_lock); + spin_lock_init(&lgr->llc_flow_lock); + init_waitqueue_head(&lgr->llc_flow_waiter); + init_waitqueue_head(&lgr->llc_msg_waiter); + mutex_init(&lgr->llc_conf_mutex); + lgr->llc_testlink_time = READ_ONCE(net->smc.sysctl_smcr_testlink_time); +} + +/* called after lgr was removed from lgr_list */ +void smc_llc_lgr_clear(struct smc_link_group *lgr) +{ + smc_llc_event_flush(lgr); + wake_up_all(&lgr->llc_flow_waiter); + wake_up_all(&lgr->llc_msg_waiter); + cancel_work_sync(&lgr->llc_event_work); + cancel_work_sync(&lgr->llc_add_link_work); + cancel_work_sync(&lgr->llc_del_link_work); + if (lgr->delayed_event) { + kfree(lgr->delayed_event); + lgr->delayed_event = NULL; + } +} + +int smc_llc_link_init(struct smc_link *link) +{ + init_completion(&link->llc_testlink_resp); + INIT_DELAYED_WORK(&link->llc_testlink_wrk, smc_llc_testlink_work); + return 0; +} + +void smc_llc_link_active(struct smc_link *link) +{ + pr_warn_ratelimited("smc: SMC-R lg %*phN net %llu link added: id %*phN, " + "peerid %*phN, ibdev %s, ibport %d\n", + SMC_LGR_ID_SIZE, &link->lgr->id, + link->lgr->net->net_cookie, + SMC_LGR_ID_SIZE, &link->link_uid, + SMC_LGR_ID_SIZE, &link->peer_link_uid, + link->smcibdev->ibdev->name, link->ibport); + link->state = SMC_LNK_ACTIVE; + if (link->lgr->llc_testlink_time) { + link->llc_testlink_time = link->lgr->llc_testlink_time; + schedule_delayed_work(&link->llc_testlink_wrk, + link->llc_testlink_time); + } +} + +/* called in worker context */ +void smc_llc_link_clear(struct smc_link *link, bool log) +{ + if (log) + pr_warn_ratelimited("smc: SMC-R lg %*phN net %llu link removed: id %*phN" + ", peerid %*phN, ibdev %s, ibport %d\n", + SMC_LGR_ID_SIZE, &link->lgr->id, + link->lgr->net->net_cookie, + SMC_LGR_ID_SIZE, &link->link_uid, + SMC_LGR_ID_SIZE, &link->peer_link_uid, + link->smcibdev->ibdev->name, link->ibport); + complete(&link->llc_testlink_resp); + cancel_delayed_work_sync(&link->llc_testlink_wrk); +} + +/* register a new rtoken at the remote peer (for all links) */ +int smc_llc_do_confirm_rkey(struct smc_link *send_link, + struct smc_buf_desc *rmb_desc) +{ + struct smc_link_group *lgr = send_link->lgr; + struct smc_llc_qentry *qentry = NULL; + int rc = 0; + + rc = smc_llc_send_confirm_rkey(send_link, rmb_desc); + if (rc) + goto out; + /* receive CONFIRM RKEY response from server over RoCE fabric */ + qentry = smc_llc_wait(lgr, send_link, SMC_LLC_WAIT_TIME, + SMC_LLC_CONFIRM_RKEY); + if (!qentry || (qentry->msg.raw.hdr.flags & SMC_LLC_FLAG_RKEY_NEG)) + rc = -EFAULT; +out: + if (qentry) + 
smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + return rc; +} + +/* unregister an rtoken at the remote peer */ +int smc_llc_do_delete_rkey(struct smc_link_group *lgr, + struct smc_buf_desc *rmb_desc) +{ + struct smc_llc_qentry *qentry = NULL; + struct smc_link *send_link; + int rc = 0; + + send_link = smc_llc_usable_link(lgr); + if (!send_link) + return -ENOLINK; + + /* protected by llc_flow control */ + rc = smc_llc_send_delete_rkey(send_link, rmb_desc); + if (rc) + goto out; + /* receive DELETE RKEY response from server over RoCE fabric */ + qentry = smc_llc_wait(lgr, send_link, SMC_LLC_WAIT_TIME, + SMC_LLC_DELETE_RKEY); + if (!qentry || (qentry->msg.raw.hdr.flags & SMC_LLC_FLAG_RKEY_NEG)) + rc = -EFAULT; +out: + if (qentry) + smc_llc_flow_qentry_del(&lgr->llc_flow_lcl); + return rc; +} + +void smc_llc_link_set_uid(struct smc_link *link) +{ + __be32 link_uid; + + link_uid = htonl(*((u32 *)link->lgr->id) + link->link_id); + memcpy(link->link_uid, &link_uid, SMC_LGR_ID_SIZE); +} + +/* save peers link user id, used for debug purposes */ +void smc_llc_save_peer_uid(struct smc_llc_qentry *qentry) +{ + memcpy(qentry->link->peer_link_uid, qentry->msg.confirm_link.link_uid, + SMC_LGR_ID_SIZE); +} + +/* evaluate confirm link request or response */ +int smc_llc_eval_conf_link(struct smc_llc_qentry *qentry, + enum smc_llc_reqresp type) +{ + if (type == SMC_LLC_REQ) { /* SMC server assigns link_id */ + qentry->link->link_id = qentry->msg.confirm_link.link_num; + smc_llc_link_set_uid(qentry->link); + } + if (!(qentry->msg.raw.hdr.flags & SMC_LLC_FLAG_NO_RMBE_EYEC)) + return -ENOTSUPP; + return 0; +} + +/***************************** init, exit, misc ******************************/ + +static struct smc_wr_rx_handler smc_llc_rx_handlers[] = { + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_CONFIRM_LINK + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_TEST_LINK + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_ADD_LINK + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_ADD_LINK_CONT + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_DELETE_LINK + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_CONFIRM_RKEY + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_CONFIRM_RKEY_CONT + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_DELETE_RKEY + }, + /* V2 types */ + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_CONFIRM_LINK_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_TEST_LINK_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_ADD_LINK_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_DELETE_LINK_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_REQ_ADD_LINK_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_CONFIRM_RKEY_V2 + }, + { + .handler = smc_llc_rx_handler, + .type = SMC_LLC_DELETE_RKEY_V2 + }, + { + .handler = NULL, + } +}; + +int __init smc_llc_init(void) +{ + struct smc_wr_rx_handler *handler; + int rc = 0; + + for (handler = smc_llc_rx_handlers; handler->handler; handler++) { + INIT_HLIST_NODE(&handler->list); + rc = smc_wr_rx_register_handler(handler); + if (rc) + break; + } + return rc; +} diff --git a/net/smc/smc_llc.h b/net/smc/smc_llc.h new file mode 100644 index 000000000..7e7a3162c --- /dev/null +++ b/net/smc/smc_llc.h @@ -0,0 +1,120 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Definitions for LLC (link layer control) message handling + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Klaus Wacker + * Ursula Braun + */ + +#ifndef SMC_LLC_H +#define SMC_LLC_H + +#include "smc_wr.h" + +#define SMC_LLC_FLAG_RESP 0x80 + +#define SMC_LLC_WAIT_FIRST_TIME (5 * HZ) +#define SMC_LLC_WAIT_TIME (2 * HZ) +#define SMC_LLC_TESTLINK_DEFAULT_TIME (30 * HZ) + +enum smc_llc_reqresp { + SMC_LLC_REQ, + SMC_LLC_RESP +}; + +enum smc_llc_msg_type { + SMC_LLC_CONFIRM_LINK = 0x01, + SMC_LLC_ADD_LINK = 0x02, + SMC_LLC_ADD_LINK_CONT = 0x03, + SMC_LLC_DELETE_LINK = 0x04, + SMC_LLC_REQ_ADD_LINK = 0x05, + SMC_LLC_CONFIRM_RKEY = 0x06, + SMC_LLC_TEST_LINK = 0x07, + SMC_LLC_CONFIRM_RKEY_CONT = 0x08, + SMC_LLC_DELETE_RKEY = 0x09, + /* V2 types */ + SMC_LLC_CONFIRM_LINK_V2 = 0x21, + SMC_LLC_ADD_LINK_V2 = 0x22, + SMC_LLC_DELETE_LINK_V2 = 0x24, + SMC_LLC_REQ_ADD_LINK_V2 = 0x25, + SMC_LLC_CONFIRM_RKEY_V2 = 0x26, + SMC_LLC_TEST_LINK_V2 = 0x27, + SMC_LLC_DELETE_RKEY_V2 = 0x29, +}; + +#define smc_link_downing(state) \ + (cmpxchg(state, SMC_LNK_ACTIVE, SMC_LNK_INACTIVE) == SMC_LNK_ACTIVE) + +/* LLC DELETE LINK Request Reason Codes */ +#define SMC_LLC_DEL_LOST_PATH 0x00010000 +#define SMC_LLC_DEL_OP_INIT_TERM 0x00020000 +#define SMC_LLC_DEL_PROG_INIT_TERM 0x00030000 +#define SMC_LLC_DEL_PROT_VIOL 0x00040000 +#define SMC_LLC_DEL_NO_ASYM_NEEDED 0x00050000 +/* LLC DELETE LINK Response Reason Codes */ +#define SMC_LLC_DEL_NOLNK 0x00100000 /* Unknown Link ID (no link) */ +#define SMC_LLC_DEL_NOLGR 0x00200000 /* Unknown Link Group */ + +/* returns a usable link of the link group, or NULL */ +static inline struct smc_link *smc_llc_usable_link(struct smc_link_group *lgr) +{ + int i; + + for (i = 0; i < SMC_LINKS_PER_LGR_MAX; i++) + if (smc_link_usable(&lgr->lnk[i])) + return &lgr->lnk[i]; + return NULL; +} + +/* set the termination reason code for the link group */ +static inline void smc_llc_set_termination_rsn(struct smc_link_group *lgr, + u32 rsn) +{ + if (!lgr->llc_termination_rsn) + lgr->llc_termination_rsn = rsn; +} + +/* transmit */ +int smc_llc_send_confirm_link(struct smc_link *lnk, + enum smc_llc_reqresp reqresp); +int smc_llc_send_add_link(struct smc_link *link, u8 mac[], u8 gid[], + struct smc_link *link_new, + enum smc_llc_reqresp reqresp); +int smc_llc_send_delete_link(struct smc_link *link, u8 link_del_id, + enum smc_llc_reqresp reqresp, bool orderly, + u32 reason); +void smc_llc_srv_delete_link_local(struct smc_link *link, u8 del_link_id); +void smc_llc_lgr_init(struct smc_link_group *lgr, struct smc_sock *smc); +void smc_llc_lgr_clear(struct smc_link_group *lgr); +int smc_llc_link_init(struct smc_link *link); +void smc_llc_link_active(struct smc_link *link); +void smc_llc_link_clear(struct smc_link *link, bool log); +int smc_llc_do_confirm_rkey(struct smc_link *send_link, + struct smc_buf_desc *rmb_desc); +int smc_llc_do_delete_rkey(struct smc_link_group *lgr, + struct smc_buf_desc *rmb_desc); +int smc_llc_flow_initiate(struct smc_link_group *lgr, + enum smc_llc_flowtype type); +void smc_llc_flow_stop(struct smc_link_group *lgr, struct smc_llc_flow *flow); +int smc_llc_eval_conf_link(struct smc_llc_qentry *qentry, + enum smc_llc_reqresp type); +void smc_llc_link_set_uid(struct smc_link *link); +void smc_llc_save_peer_uid(struct smc_llc_qentry *qentry); +struct smc_llc_qentry *smc_llc_wait(struct smc_link_group *lgr, + struct smc_link *lnk, + int time_out, u8 exp_msg); +struct smc_llc_qentry *smc_llc_flow_qentry_clr(struct smc_llc_flow *flow); +void smc_llc_flow_qentry_del(struct smc_llc_flow *flow); +void smc_llc_send_link_delete_all(struct smc_link_group *lgr, bool ord, + u32 
rsn); +int smc_llc_cli_add_link(struct smc_link *link, struct smc_llc_qentry *qentry); +int smc_llc_srv_add_link(struct smc_link *link, + struct smc_llc_qentry *req_qentry); +void smc_llc_add_link_local(struct smc_link *link); +int smc_llc_init(void) __init; + +#endif /* SMC_LLC_H */ diff --git a/net/smc/smc_netlink.c b/net/smc/smc_netlink.c new file mode 100644 index 000000000..621c46c70 --- /dev/null +++ b/net/smc/smc_netlink.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Generic netlink support functions to interact with SMC module + * + * Copyright IBM Corp. 2020 + * + * Author(s): Guvenc Gulce + */ + +#include +#include +#include +#include +#include +#include + +#include "smc_core.h" +#include "smc_ism.h" +#include "smc_ib.h" +#include "smc_clc.h" +#include "smc_stats.h" +#include "smc_netlink.h" + +const struct nla_policy +smc_gen_ueid_policy[SMC_NLA_EID_TABLE_MAX + 1] = { + [SMC_NLA_EID_TABLE_UNSPEC] = { .type = NLA_UNSPEC }, + [SMC_NLA_EID_TABLE_ENTRY] = { .type = NLA_STRING, + .len = SMC_MAX_EID_LEN, + }, +}; + +#define SMC_CMD_MAX_ATTR 1 +/* SMC_GENL generic netlink operation definition */ +static const struct genl_ops smc_gen_nl_ops[] = { + { + .cmd = SMC_NETLINK_GET_SYS_INFO, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_get_sys_info, + }, + { + .cmd = SMC_NETLINK_GET_LGR_SMCR, + /* can be retrieved by unprivileged users */ + .dumpit = smcr_nl_get_lgr, + }, + { + .cmd = SMC_NETLINK_GET_LINK_SMCR, + /* can be retrieved by unprivileged users */ + .dumpit = smcr_nl_get_link, + }, + { + .cmd = SMC_NETLINK_GET_LGR_SMCD, + /* can be retrieved by unprivileged users */ + .dumpit = smcd_nl_get_lgr, + }, + { + .cmd = SMC_NETLINK_GET_DEV_SMCD, + /* can be retrieved by unprivileged users */ + .dumpit = smcd_nl_get_device, + }, + { + .cmd = SMC_NETLINK_GET_DEV_SMCR, + /* can be retrieved by unprivileged users */ + .dumpit = smcr_nl_get_device, + }, + { + .cmd = SMC_NETLINK_GET_STATS, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_get_stats, + }, + { + .cmd = SMC_NETLINK_GET_FBACK_STATS, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_get_fback_stats, + }, + { + .cmd = SMC_NETLINK_DUMP_UEID, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_dump_ueid, + }, + { + .cmd = SMC_NETLINK_ADD_UEID, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_add_ueid, + .policy = smc_gen_ueid_policy, + }, + { + .cmd = SMC_NETLINK_REMOVE_UEID, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_remove_ueid, + .policy = smc_gen_ueid_policy, + }, + { + .cmd = SMC_NETLINK_FLUSH_UEID, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_flush_ueid, + }, + { + .cmd = SMC_NETLINK_DUMP_SEID, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_dump_seid, + }, + { + .cmd = SMC_NETLINK_ENABLE_SEID, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_enable_seid, + }, + { + .cmd = SMC_NETLINK_DISABLE_SEID, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_disable_seid, + }, + { + .cmd = SMC_NETLINK_DUMP_HS_LIMITATION, + /* can be retrieved by unprivileged users */ + .dumpit = smc_nl_dump_hs_limitation, + }, + { + .cmd = SMC_NETLINK_ENABLE_HS_LIMITATION, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_enable_hs_limitation, + }, + { + .cmd = SMC_NETLINK_DISABLE_HS_LIMITATION, + .flags = GENL_ADMIN_PERM, + .doit = smc_nl_disable_hs_limitation, + }, +}; + +static const struct nla_policy smc_gen_nl_policy[2] = { + [SMC_CMD_MAX_ATTR] = { .type = NLA_REJECT, }, +}; + +/* SMC_GENL family definition 
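(registered via genl_register_family() in smc_nl_init(); the GET/DUMP ops above are unprivileged dumpit callbacks, while the UEID/SEID/HS-limitation setters require GENL_ADMIN_PERM) 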
*/ +struct genl_family smc_gen_nl_family __ro_after_init = { + .hdrsize = 0, + .name = SMC_GENL_FAMILY_NAME, + .version = SMC_GENL_FAMILY_VERSION, + .maxattr = SMC_CMD_MAX_ATTR, + .policy = smc_gen_nl_policy, + .netnsok = true, + .module = THIS_MODULE, + .ops = smc_gen_nl_ops, + .n_ops = ARRAY_SIZE(smc_gen_nl_ops), + .resv_start_op = SMC_NETLINK_DISABLE_HS_LIMITATION + 1, +}; + +int __init smc_nl_init(void) +{ + return genl_register_family(&smc_gen_nl_family); +} + +void smc_nl_exit(void) +{ + genl_unregister_family(&smc_gen_nl_family); +} diff --git a/net/smc/smc_netlink.h b/net/smc/smc_netlink.h new file mode 100644 index 000000000..e8c6c3f0e --- /dev/null +++ b/net/smc/smc_netlink.h @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * SMC Generic netlink operations + * + * Copyright IBM Corp. 2020 + * + * Author(s): Guvenc Gulce + */ + +#ifndef _SMC_NETLINK_H +#define _SMC_NETLINK_H + +#include +#include + +extern struct genl_family smc_gen_nl_family; + +extern const struct nla_policy smc_gen_ueid_policy[]; + +struct smc_nl_dmp_ctx { + int pos[3]; +}; + +static inline struct smc_nl_dmp_ctx *smc_nl_dmp_ctx(struct netlink_callback *c) +{ + return (struct smc_nl_dmp_ctx *)c->ctx; +} + +int smc_nl_init(void) __init; +void smc_nl_exit(void); + +#endif diff --git a/net/smc/smc_netns.h b/net/smc/smc_netns.h new file mode 100644 index 000000000..0f4f35aa4 --- /dev/null +++ b/net/smc/smc_netns.h @@ -0,0 +1,21 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Shared Memory Communications + * + * Network namespace definitions. + * + * Copyright IBM Corp. 2018 + */ + +#ifndef SMC_NETNS_H +#define SMC_NETNS_H + +#include "smc_pnet.h" + +extern unsigned int smc_net_id; + +/* per-network namespace private data */ +struct smc_net { + struct smc_pnettable pnettable; + struct smc_pnetids_ndev pnetids_ndev; +}; +#endif diff --git a/net/smc/smc_pnet.c b/net/smc/smc_pnet.c new file mode 100644 index 000000000..25fb2fd18 --- /dev/null +++ b/net/smc/smc_pnet.c @@ -0,0 +1,1206 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Generic netlink support functions to configure an SMC-R PNET table + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Thomas Richter + */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include "smc_netns.h" + +#include "smc_pnet.h" +#include "smc_ib.h" +#include "smc_ism.h" +#include "smc_core.h" + +static struct net_device *__pnet_find_base_ndev(struct net_device *ndev); +static struct net_device *pnet_find_base_ndev(struct net_device *ndev); + +static const struct nla_policy smc_pnet_policy[SMC_PNETID_MAX + 1] = { + [SMC_PNETID_NAME] = { + .type = NLA_NUL_STRING, + .len = SMC_MAX_PNETID_LEN + }, + [SMC_PNETID_ETHNAME] = { + .type = NLA_NUL_STRING, + .len = IFNAMSIZ - 1 + }, + [SMC_PNETID_IBNAME] = { + .type = NLA_NUL_STRING, + .len = IB_DEVICE_NAME_MAX - 1 + }, + [SMC_PNETID_IBPORT] = { .type = NLA_U8 } +}; + +static struct genl_family smc_pnet_nl_family; + +enum smc_pnet_nametype { + SMC_PNET_ETH = 1, + SMC_PNET_IB = 2, +}; + +/* pnet entry stored in pnet table */ +struct smc_pnetentry { + struct list_head list; + char pnet_name[SMC_MAX_PNETID_LEN + 1]; + enum smc_pnet_nametype type; + union { + struct { + char eth_name[IFNAMSIZ + 1]; + struct net_device *ndev; + netdevice_tracker dev_tracker; + }; + struct { + char ib_name[IB_DEVICE_NAME_MAX + 1]; + u8 ib_port; + }; + }; +}; + +/* Check if the pnetid is set */ +bool smc_pnet_is_pnetid_set(u8 *pnetid) +{ + if (pnetid[0] == 0 || pnetid[0] == _S) + return false; + return true; +} + +/* Check if two given pnetids match */ +static bool smc_pnet_match(u8 *pnetid1, u8 *pnetid2) +{ + int i; + + for (i = 0; i < SMC_MAX_PNETID_LEN; i++) { + if ((pnetid1[i] == 0 || pnetid1[i] == _S) && + (pnetid2[i] == 0 || pnetid2[i] == _S)) + break; + if (pnetid1[i] != pnetid2[i]) + return false; + } + return true; +} + +/* Remove a pnetid from the pnet table. 
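smc_pnet_match() treats a PNETID as a 16-byte field that ends at the first NUL or blank, so a blank-padded and a zero-padded spelling of the same name compare equal. A standalone user-space rendering of the same comparison rule:

#include <stdbool.h>
#include <stdio.h>
#include <string.h>

#define PNETID_LEN 16

/* equal if the bytes match up to the point where both names end
 * (NUL or blank), mirroring the loop above */
static bool pnetid_match(const unsigned char *p1, const unsigned char *p2)
{
	int i;

	for (i = 0; i < PNETID_LEN; i++) {
		if ((p1[i] == 0 || p1[i] == ' ') &&
		    (p2[i] == 0 || p2[i] == ' '))
			break;
		if (p1[i] != p2[i])
			return false;
	}
	return true;
}

int main(void)
{
	unsigned char a[PNETID_LEN], b[PNETID_LEN];

	memset(a, ' ', sizeof(a));
	memcpy(a, "NET26", 5);		/* blank padded */
	memset(b, 0, sizeof(b));
	memcpy(b, "NET26", 5);		/* zero padded */
	printf("%d\n", pnetid_match(a, b));	/* 1: padding style is ignored */

	memcpy(b, "NET27", 5);
	printf("%d\n", pnetid_match(a, b));	/* 0: different names */
	return 0;
}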
+ */ +static int smc_pnet_remove_by_pnetid(struct net *net, char *pnet_name) +{ + struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct smc_ib_device *ibdev; + struct smcd_dev *smcd_dev; + struct smc_net *sn; + int rc = -ENOENT; + int ibport; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + /* remove table entry */ + mutex_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, + list) { + if (!pnet_name || + smc_pnet_match(pnetelem->pnet_name, pnet_name)) { + list_del(&pnetelem->list); + if (pnetelem->type == SMC_PNET_ETH && pnetelem->ndev) { + netdev_put(pnetelem->ndev, + &pnetelem->dev_tracker); + pr_warn_ratelimited("smc: net device %s " + "erased user defined " + "pnetid %.16s\n", + pnetelem->eth_name, + pnetelem->pnet_name); + } + kfree(pnetelem); + rc = 0; + } + } + mutex_unlock(&pnettable->lock); + + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return rc; + + /* remove ib devices */ + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + for (ibport = 0; ibport < SMC_MAX_PORTS; ibport++) { + if (ibdev->pnetid_by_user[ibport] && + (!pnet_name || + smc_pnet_match(pnet_name, + ibdev->pnetid[ibport]))) { + pr_warn_ratelimited("smc: ib device %s ibport " + "%d erased user defined " + "pnetid %.16s\n", + ibdev->ibdev->name, + ibport + 1, + ibdev->pnetid[ibport]); + memset(ibdev->pnetid[ibport], 0, + SMC_MAX_PNETID_LEN); + ibdev->pnetid_by_user[ibport] = false; + rc = 0; + } + } + } + mutex_unlock(&smc_ib_devices.mutex); + /* remove smcd devices */ + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(smcd_dev, &smcd_dev_list.list, list) { + if (smcd_dev->pnetid_by_user && + (!pnet_name || + smc_pnet_match(pnet_name, smcd_dev->pnetid))) { + pr_warn_ratelimited("smc: smcd device %s " + "erased user defined pnetid " + "%.16s\n", dev_name(&smcd_dev->dev), + smcd_dev->pnetid); + memset(smcd_dev->pnetid, 0, SMC_MAX_PNETID_LEN); + smcd_dev->pnetid_by_user = false; + rc = 0; + } + } + mutex_unlock(&smcd_dev_list.mutex); + return rc; +} + +/* Add the reference to a given network device to the pnet table. + */ +static int smc_pnet_add_by_ndev(struct net_device *ndev) +{ + struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + mutex_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { + if (pnetelem->type == SMC_PNET_ETH && !pnetelem->ndev && + !strncmp(pnetelem->eth_name, ndev->name, IFNAMSIZ)) { + netdev_hold(ndev, &pnetelem->dev_tracker, GFP_ATOMIC); + pnetelem->ndev = ndev; + rc = 0; + pr_warn_ratelimited("smc: adding net device %s with " + "user defined pnetid %.16s\n", + pnetelem->eth_name, + pnetelem->pnet_name); + break; + } + } + mutex_unlock(&pnettable->lock); + return rc; +} + +/* Remove the reference to a given network device from the pnet table. 
+ */ +static int smc_pnet_remove_by_ndev(struct net_device *ndev) +{ + struct smc_pnetentry *pnetelem, *tmp_pe; + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + mutex_lock(&pnettable->lock); + list_for_each_entry_safe(pnetelem, tmp_pe, &pnettable->pnetlist, list) { + if (pnetelem->type == SMC_PNET_ETH && pnetelem->ndev == ndev) { + netdev_put(pnetelem->ndev, &pnetelem->dev_tracker); + pnetelem->ndev = NULL; + rc = 0; + pr_warn_ratelimited("smc: removing net device %s with " + "user defined pnetid %.16s\n", + pnetelem->eth_name, + pnetelem->pnet_name); + break; + } + } + mutex_unlock(&pnettable->lock); + return rc; +} + +/* Apply pnetid to ib device when no pnetid is set. + */ +static bool smc_pnet_apply_ib(struct smc_ib_device *ib_dev, u8 ib_port, + char *pnet_name) +{ + bool applied = false; + + mutex_lock(&smc_ib_devices.mutex); + if (!smc_pnet_is_pnetid_set(ib_dev->pnetid[ib_port - 1])) { + memcpy(ib_dev->pnetid[ib_port - 1], pnet_name, + SMC_MAX_PNETID_LEN); + ib_dev->pnetid_by_user[ib_port - 1] = true; + applied = true; + } + mutex_unlock(&smc_ib_devices.mutex); + return applied; +} + +/* Apply pnetid to smcd device when no pnetid is set. + */ +static bool smc_pnet_apply_smcd(struct smcd_dev *smcd_dev, char *pnet_name) +{ + bool applied = false; + + mutex_lock(&smcd_dev_list.mutex); + if (!smc_pnet_is_pnetid_set(smcd_dev->pnetid)) { + memcpy(smcd_dev->pnetid, pnet_name, SMC_MAX_PNETID_LEN); + smcd_dev->pnetid_by_user = true; + applied = true; + } + mutex_unlock(&smcd_dev_list.mutex); + return applied; +} + +/* The limit for pnetid is 16 characters. + * Valid characters should be (single-byte character set) a-z, A-Z, 0-9. + * Lower case letters are converted to upper case. + * Interior blanks should not be used. + */ +static bool smc_pnetid_valid(const char *pnet_name, char *pnetid) +{ + char *bf = skip_spaces(pnet_name); + size_t len = strlen(bf); + char *end = bf + len; + + if (!len) + return false; + while (--end >= bf && isspace(*end)) + ; + if (end - bf >= SMC_MAX_PNETID_LEN) + return false; + while (bf <= end) { + if (!isalnum(*bf)) + return false; + *pnetid++ = islower(*bf) ? toupper(*bf) : *bf; + bf++; + } + *pnetid = '\0'; + return true; +} + +/* Find an infiniband device by a given name. The device might not exist. */ +static struct smc_ib_device *smc_pnet_find_ib(char *ib_name) +{ + struct smc_ib_device *ibdev; + + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + if (!strncmp(ibdev->ibdev->name, ib_name, + sizeof(ibdev->ibdev->name)) || + (ibdev->ibdev->dev.parent && + !strncmp(dev_name(ibdev->ibdev->dev.parent), ib_name, + IB_DEVICE_NAME_MAX - 1))) { + goto out; + } + } + ibdev = NULL; +out: + mutex_unlock(&smc_ib_devices.mutex); + return ibdev; +} + +/* Find an smcd device by a given name. The device might not exist. 
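smc_pnetid_valid() above is the gatekeeper for user-supplied names: surrounding blanks are stripped, 1 to 16 alphanumeric characters are accepted, interior blanks are rejected, and the result is upper-cased. The same normalization as a small self-contained program:

#include <ctype.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>

#define MAX_PNETID_LEN 16

static bool normalize_pnetid(const char *in, char out[MAX_PNETID_LEN + 1])
{
	const char *end;

	while (*in && isspace((unsigned char)*in))	/* skip leading blanks */
		in++;
	if (!*in)
		return false;
	end = in + strlen(in);
	while (end > in && isspace((unsigned char)end[-1]))	/* trim trailing */
		end--;
	if (end - in > MAX_PNETID_LEN)
		return false;
	for (; in < end; in++) {
		if (!isalnum((unsigned char)*in))	/* interior blank fails here */
			return false;
		*out++ = toupper((unsigned char)*in);
	}
	*out = '\0';
	return true;
}

int main(void)
{
	char id[MAX_PNETID_LEN + 1];

	printf("%d %s\n", normalize_pnetid("  net26  ", id), id); /* 1 NET26 */
	printf("%d\n", normalize_pnetid("net 26", id));           /* 0: interior blank */
	printf("%d\n", normalize_pnetid("averyverylongpnetid", id)); /* 0: > 16 chars */
	return 0;
}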
*/ +static struct smcd_dev *smc_pnet_find_smcd(char *smcd_name) +{ + struct smcd_dev *smcd_dev; + + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(smcd_dev, &smcd_dev_list.list, list) { + if (!strncmp(dev_name(&smcd_dev->dev), smcd_name, + IB_DEVICE_NAME_MAX - 1)) + goto out; + } + smcd_dev = NULL; +out: + mutex_unlock(&smcd_dev_list.mutex); + return smcd_dev; +} + +static int smc_pnet_add_eth(struct smc_pnettable *pnettable, struct net *net, + char *eth_name, char *pnet_name) +{ + struct smc_pnetentry *tmp_pe, *new_pe; + struct net_device *ndev, *base_ndev; + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + bool new_netdev; + int rc; + + /* check if (base) netdev already has a pnetid. If there is one, we do + * not want to add a pnet table entry + */ + rc = -EEXIST; + ndev = dev_get_by_name(net, eth_name); /* dev_hold() */ + if (ndev) { + base_ndev = pnet_find_base_ndev(ndev); + if (!smc_pnetid_by_dev_port(base_ndev->dev.parent, + base_ndev->dev_port, ndev_pnetid)) + goto out_put; + } + + /* add a new netdev entry to the pnet table if there isn't one */ + rc = -ENOMEM; + new_pe = kzalloc(sizeof(*new_pe), GFP_KERNEL); + if (!new_pe) + goto out_put; + new_pe->type = SMC_PNET_ETH; + memcpy(new_pe->pnet_name, pnet_name, SMC_MAX_PNETID_LEN); + strncpy(new_pe->eth_name, eth_name, IFNAMSIZ); + rc = -EEXIST; + new_netdev = true; + mutex_lock(&pnettable->lock); + list_for_each_entry(tmp_pe, &pnettable->pnetlist, list) { + if (tmp_pe->type == SMC_PNET_ETH && + !strncmp(tmp_pe->eth_name, eth_name, IFNAMSIZ)) { + new_netdev = false; + break; + } + } + if (new_netdev) { + if (ndev) { + new_pe->ndev = ndev; + netdev_tracker_alloc(ndev, &new_pe->dev_tracker, + GFP_ATOMIC); + } + list_add_tail(&new_pe->list, &pnettable->pnetlist); + mutex_unlock(&pnettable->lock); + } else { + mutex_unlock(&pnettable->lock); + kfree(new_pe); + goto out_put; + } + if (ndev) + pr_warn_ratelimited("smc: net device %s " + "applied user defined pnetid %.16s\n", + new_pe->eth_name, new_pe->pnet_name); + return 0; + +out_put: + dev_put(ndev); + return rc; +} + +static int smc_pnet_add_ib(struct smc_pnettable *pnettable, char *ib_name, + u8 ib_port, char *pnet_name) +{ + struct smc_pnetentry *tmp_pe, *new_pe; + struct smc_ib_device *ib_dev; + bool smcddev_applied = true; + bool ibdev_applied = true; + struct smcd_dev *smcd_dev; + bool new_ibdev; + + /* try to apply the pnetid to active devices */ + ib_dev = smc_pnet_find_ib(ib_name); + if (ib_dev) { + ibdev_applied = smc_pnet_apply_ib(ib_dev, ib_port, pnet_name); + if (ibdev_applied) + pr_warn_ratelimited("smc: ib device %s ibport %d " + "applied user defined pnetid " + "%.16s\n", ib_dev->ibdev->name, + ib_port, + ib_dev->pnetid[ib_port - 1]); + } + smcd_dev = smc_pnet_find_smcd(ib_name); + if (smcd_dev) { + smcddev_applied = smc_pnet_apply_smcd(smcd_dev, pnet_name); + if (smcddev_applied) + pr_warn_ratelimited("smc: smcd device %s " + "applied user defined pnetid " + "%.16s\n", dev_name(&smcd_dev->dev), + smcd_dev->pnetid); + } + /* Apply fails when a device has a hardware-defined pnetid set, do not + * add a pnet table entry in that case. 
+ */ + if (!ibdev_applied || !smcddev_applied) + return -EEXIST; + + /* add a new ib entry to the pnet table if there isn't one */ + new_pe = kzalloc(sizeof(*new_pe), GFP_KERNEL); + if (!new_pe) + return -ENOMEM; + new_pe->type = SMC_PNET_IB; + memcpy(new_pe->pnet_name, pnet_name, SMC_MAX_PNETID_LEN); + strncpy(new_pe->ib_name, ib_name, IB_DEVICE_NAME_MAX); + new_pe->ib_port = ib_port; + + new_ibdev = true; + mutex_lock(&pnettable->lock); + list_for_each_entry(tmp_pe, &pnettable->pnetlist, list) { + if (tmp_pe->type == SMC_PNET_IB && + !strncmp(tmp_pe->ib_name, ib_name, IB_DEVICE_NAME_MAX)) { + new_ibdev = false; + break; + } + } + if (new_ibdev) { + list_add_tail(&new_pe->list, &pnettable->pnetlist); + mutex_unlock(&pnettable->lock); + } else { + mutex_unlock(&pnettable->lock); + kfree(new_pe); + } + return (new_ibdev) ? 0 : -EEXIST; +} + +/* Append a pnetid to the end of the pnet table if not already on this list. + */ +static int smc_pnet_enter(struct net *net, struct nlattr *tb[]) +{ + char pnet_name[SMC_MAX_PNETID_LEN + 1]; + struct smc_pnettable *pnettable; + bool new_netdev = false; + bool new_ibdev = false; + struct smc_net *sn; + u8 ibport = 1; + char *string; + int rc; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + rc = -EINVAL; + if (!tb[SMC_PNETID_NAME]) + goto error; + string = (char *)nla_data(tb[SMC_PNETID_NAME]); + if (!smc_pnetid_valid(string, pnet_name)) + goto error; + + if (tb[SMC_PNETID_ETHNAME]) { + string = (char *)nla_data(tb[SMC_PNETID_ETHNAME]); + rc = smc_pnet_add_eth(pnettable, net, string, pnet_name); + if (!rc) + new_netdev = true; + else if (rc != -EEXIST) + goto error; + } + + /* if this is not the initial namespace, stop here */ + if (net != &init_net) + return new_netdev ? 0 : -EEXIST; + + rc = -EINVAL; + if (tb[SMC_PNETID_IBNAME]) { + string = (char *)nla_data(tb[SMC_PNETID_IBNAME]); + string = strim(string); + if (tb[SMC_PNETID_IBPORT]) { + ibport = nla_get_u8(tb[SMC_PNETID_IBPORT]); + if (ibport < 1 || ibport > SMC_MAX_PORTS) + goto error; + } + rc = smc_pnet_add_ib(pnettable, string, ibport, pnet_name); + if (!rc) + new_ibdev = true; + else if (rc != -EEXIST) + goto error; + } + return (new_netdev || new_ibdev) ? 
0 : -EEXIST; + +error: + return rc; +} + +/* Convert an smc_pnetentry to a netlink attribute sequence */ +static int smc_pnet_set_nla(struct sk_buff *msg, + struct smc_pnetentry *pnetelem) +{ + if (nla_put_string(msg, SMC_PNETID_NAME, pnetelem->pnet_name)) + return -1; + if (pnetelem->type == SMC_PNET_ETH) { + if (nla_put_string(msg, SMC_PNETID_ETHNAME, + pnetelem->eth_name)) + return -1; + } else { + if (nla_put_string(msg, SMC_PNETID_ETHNAME, "n/a")) + return -1; + } + if (pnetelem->type == SMC_PNET_IB) { + if (nla_put_string(msg, SMC_PNETID_IBNAME, pnetelem->ib_name) || + nla_put_u8(msg, SMC_PNETID_IBPORT, pnetelem->ib_port)) + return -1; + } else { + if (nla_put_string(msg, SMC_PNETID_IBNAME, "n/a") || + nla_put_u8(msg, SMC_PNETID_IBPORT, 0xff)) + return -1; + } + + return 0; +} + +static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + + return smc_pnet_enter(net, info->attrs); +} + +static int smc_pnet_del(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + + if (!info->attrs[SMC_PNETID_NAME]) + return -EINVAL; + return smc_pnet_remove_by_pnetid(net, + (char *)nla_data(info->attrs[SMC_PNETID_NAME])); +} + +static int smc_pnet_dump_start(struct netlink_callback *cb) +{ + cb->args[0] = 0; + return 0; +} + +static int smc_pnet_dumpinfo(struct sk_buff *skb, + u32 portid, u32 seq, u32 flags, + struct smc_pnetentry *pnetelem) +{ + void *hdr; + + hdr = genlmsg_put(skb, portid, seq, &smc_pnet_nl_family, + flags, SMC_PNETID_GET); + if (!hdr) + return -ENOMEM; + if (smc_pnet_set_nla(skb, pnetelem) < 0) { + genlmsg_cancel(skb, hdr); + return -EMSGSIZE; + } + genlmsg_end(skb, hdr); + return 0; +} + +static int _smc_pnet_dump(struct net *net, struct sk_buff *skb, u32 portid, + u32 seq, u8 *pnetid, int start_idx) +{ + struct smc_pnettable *pnettable; + struct smc_pnetentry *pnetelem; + struct smc_net *sn; + int idx = 0; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + /* dump pnettable entries */ + mutex_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (pnetid && !smc_pnet_match(pnetelem->pnet_name, pnetid)) + continue; + if (idx++ < start_idx) + continue; + /* if this is not the initial namespace, dump only netdev */ + if (net != &init_net && pnetelem->type != SMC_PNET_ETH) + continue; + if (smc_pnet_dumpinfo(skb, portid, seq, NLM_F_MULTI, + pnetelem)) { + --idx; + break; + } + } + mutex_unlock(&pnettable->lock); + return idx; +} + +static int smc_pnet_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int idx; + + idx = _smc_pnet_dump(net, skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NULL, cb->args[0]); + + cb->args[0] = idx; + return skb->len; +} + +/* Retrieve one PNETID entry */ +static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct sk_buff *msg; + void *hdr; + + if (!info->attrs[SMC_PNETID_NAME]) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + _smc_pnet_dump(net, msg, info->snd_portid, info->snd_seq, + nla_data(info->attrs[SMC_PNETID_NAME]), 0); + + /* finish multi part message and send it */ + hdr = nlmsg_put(msg, info->snd_portid, info->snd_seq, NLMSG_DONE, 0, + NLM_F_MULTI); + if (!hdr) { + nlmsg_free(msg); + return -EMSGSIZE; + } + return genlmsg_reply(msg, info); +} + +/* Remove and delete all pnetids from pnet 
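The dump path above is resumable: cb->args[0] records how many entries earlier passes already emitted, _smc_pnet_dump() skips that many, and when the skb fills up it backs the counter off by one so the partially written entry is retried on the next callback. A toy model of that cursor handling:

#include <stdio.h>

/* Each call may emit at most 'budget' entries and returns the position
 * to resume from, the equivalent of storing idx back in cb->args[0]. */
static int dump_chunk(const char *const *tbl, int n, int start, int budget)
{
	int idx = 0, emitted = 0;

	for (int i = 0; i < n; i++) {
		if (idx++ < start)
			continue;		/* already sent in a previous pass */
		if (emitted == budget) {	/* "skb full": undo and stop */
			--idx;
			break;
		}
		printf("entry %d: %s\n", i, tbl[i]);
		emitted++;
	}
	return idx;
}

int main(void)
{
	const char *tbl[] = { "NET1", "NET2", "NET3", "NET4", "NET5" };
	int pos = 0, n = 5;

	while (pos < n)
		pos = dump_chunk(tbl, n, pos, 2);	/* resumes where it left off */
	return 0;
}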
table. + */ +static int smc_pnet_flush(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + + smc_pnet_remove_by_pnetid(net, NULL); + return 0; +} + +/* SMC_PNETID generic netlink operation definition */ +static const struct genl_ops smc_pnet_ops[] = { + { + .cmd = SMC_PNETID_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + /* can be retrieved by unprivileged users */ + .doit = smc_pnet_get, + .dumpit = smc_pnet_dump, + .start = smc_pnet_dump_start + }, + { + .cmd = SMC_PNETID_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = smc_pnet_add + }, + { + .cmd = SMC_PNETID_DEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = smc_pnet_del + }, + { + .cmd = SMC_PNETID_FLUSH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_ADMIN_PERM, + .doit = smc_pnet_flush + } +}; + +/* SMC_PNETID family definition */ +static struct genl_family smc_pnet_nl_family __ro_after_init = { + .hdrsize = 0, + .name = SMCR_GENL_FAMILY_NAME, + .version = SMCR_GENL_FAMILY_VERSION, + .maxattr = SMC_PNETID_MAX, + .policy = smc_pnet_policy, + .netnsok = true, + .module = THIS_MODULE, + .ops = smc_pnet_ops, + .n_ops = ARRAY_SIZE(smc_pnet_ops), + .resv_start_op = SMC_PNETID_FLUSH + 1, +}; + +bool smc_pnet_is_ndev_pnetid(struct net *net, u8 *pnetid) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnetids_ndev_entry *pe; + bool rc = false; + + read_lock(&sn->pnetids_ndev.lock); + list_for_each_entry(pe, &sn->pnetids_ndev.list, list) { + if (smc_pnet_match(pnetid, pe->pnetid)) { + rc = true; + goto unlock; + } + } + +unlock: + read_unlock(&sn->pnetids_ndev.lock); + return rc; +} + +static int smc_pnet_add_pnetid(struct net *net, u8 *pnetid) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnetids_ndev_entry *pe, *pi; + + pe = kzalloc(sizeof(*pe), GFP_KERNEL); + if (!pe) + return -ENOMEM; + + write_lock(&sn->pnetids_ndev.lock); + list_for_each_entry(pi, &sn->pnetids_ndev.list, list) { + if (smc_pnet_match(pnetid, pe->pnetid)) { + refcount_inc(&pi->refcnt); + kfree(pe); + goto unlock; + } + } + refcount_set(&pe->refcnt, 1); + memcpy(pe->pnetid, pnetid, SMC_MAX_PNETID_LEN); + list_add_tail(&pe->list, &sn->pnetids_ndev.list); + +unlock: + write_unlock(&sn->pnetids_ndev.lock); + return 0; +} + +static void smc_pnet_remove_pnetid(struct net *net, u8 *pnetid) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnetids_ndev_entry *pe, *pe2; + + write_lock(&sn->pnetids_ndev.lock); + list_for_each_entry_safe(pe, pe2, &sn->pnetids_ndev.list, list) { + if (smc_pnet_match(pnetid, pe->pnetid)) { + if (refcount_dec_and_test(&pe->refcnt)) { + list_del(&pe->list); + kfree(pe); + } + break; + } + } + write_unlock(&sn->pnetids_ndev.lock); +} + +static void smc_pnet_add_base_pnetid(struct net *net, struct net_device *dev, + u8 *ndev_pnetid) +{ + struct net_device *base_dev; + + base_dev = __pnet_find_base_ndev(dev); + if (base_dev->flags & IFF_UP && + !smc_pnetid_by_dev_port(base_dev->dev.parent, base_dev->dev_port, + ndev_pnetid)) { + /* add to PNETIDs list */ + smc_pnet_add_pnetid(net, ndev_pnetid); + } +} + +/* create initial list of netdevice pnetids */ +static void smc_pnet_create_pnetids_list(struct net *net) +{ + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + struct net_device *dev; + + rtnl_lock(); + for_each_netdev(net, dev) + smc_pnet_add_base_pnetid(net, dev, ndev_pnetid); + 
rtnl_unlock(); +} + +/* clean up list of netdevice pnetids */ +static void smc_pnet_destroy_pnetids_list(struct net *net) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnetids_ndev_entry *pe, *temp_pe; + + write_lock(&sn->pnetids_ndev.lock); + list_for_each_entry_safe(pe, temp_pe, &sn->pnetids_ndev.list, list) { + list_del(&pe->list); + kfree(pe); + } + write_unlock(&sn->pnetids_ndev.lock); +} + +static int smc_pnet_netdev_event(struct notifier_block *this, + unsigned long event, void *ptr) +{ + struct net_device *event_dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(event_dev); + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + + switch (event) { + case NETDEV_REBOOT: + case NETDEV_UNREGISTER: + smc_pnet_remove_by_ndev(event_dev); + smc_ib_ndev_change(event_dev, event); + return NOTIFY_OK; + case NETDEV_REGISTER: + smc_pnet_add_by_ndev(event_dev); + smc_ib_ndev_change(event_dev, event); + return NOTIFY_OK; + case NETDEV_UP: + smc_pnet_add_base_pnetid(net, event_dev, ndev_pnetid); + return NOTIFY_OK; + case NETDEV_DOWN: + event_dev = __pnet_find_base_ndev(event_dev); + if (!smc_pnetid_by_dev_port(event_dev->dev.parent, + event_dev->dev_port, ndev_pnetid)) { + /* remove from PNETIDs list */ + smc_pnet_remove_pnetid(net, ndev_pnetid); + } + return NOTIFY_OK; + default: + return NOTIFY_DONE; + } +} + +static struct notifier_block smc_netdev_notifier = { + .notifier_call = smc_pnet_netdev_event +}; + +/* init network namespace */ +int smc_pnet_net_init(struct net *net) +{ + struct smc_net *sn = net_generic(net, smc_net_id); + struct smc_pnettable *pnettable = &sn->pnettable; + struct smc_pnetids_ndev *pnetids_ndev = &sn->pnetids_ndev; + + INIT_LIST_HEAD(&pnettable->pnetlist); + mutex_init(&pnettable->lock); + INIT_LIST_HEAD(&pnetids_ndev->list); + rwlock_init(&pnetids_ndev->lock); + + smc_pnet_create_pnetids_list(net); + + /* disable handshake limitation by default */ + net->smc.limit_smc_hs = 0; + + return 0; +} + +int __init smc_pnet_init(void) +{ + int rc; + + rc = genl_register_family(&smc_pnet_nl_family); + if (rc) + return rc; + rc = register_netdevice_notifier(&smc_netdev_notifier); + if (rc) + genl_unregister_family(&smc_pnet_nl_family); + + return rc; +} + +/* exit network namespace */ +void smc_pnet_net_exit(struct net *net) +{ + /* flush pnet table */ + smc_pnet_remove_by_pnetid(net, NULL); + smc_pnet_destroy_pnetids_list(net); +} + +void smc_pnet_exit(void) +{ + unregister_netdevice_notifier(&smc_netdev_notifier); + genl_unregister_family(&smc_pnet_nl_family); +} + +static struct net_device *__pnet_find_base_ndev(struct net_device *ndev) +{ + int i, nest_lvl; + + ASSERT_RTNL(); + nest_lvl = ndev->lower_level; + for (i = 0; i < nest_lvl; i++) { + struct list_head *lower = &ndev->adj_list.lower; + + if (list_empty(lower)) + break; + lower = lower->next; + ndev = netdev_lower_get_next(ndev, &lower); + } + return ndev; +} + +/* Determine one base device for stacked net devices. + * If the lower device level contains more than one devices + * (for instance with bonding slaves), just the first device + * is used to reach a base device. 
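Several interfaces may share one PNETID, so the per-namespace pnetids_ndev list is reference counted: NETDEV_UP either inserts an entry or bumps the count, NETDEV_DOWN drops it and frees the entry with the last user. A plain user-space sketch of that add/remove discipline (an int counter standing in for refcount_t, no locking):

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

struct pnetid_entry {
	struct pnetid_entry *next;
	char pnetid[17];
	unsigned int refcnt;
};

static struct pnetid_entry *head;

/* NETDEV_UP equivalent: new pnetid gets an entry, a known one a reference */
static void pnetid_add(const char *id)
{
	struct pnetid_entry *e;

	for (e = head; e; e = e->next) {
		if (!strcmp(e->pnetid, id)) {
			e->refcnt++;
			return;
		}
	}
	e = calloc(1, sizeof(*e));
	if (!e)
		return;
	snprintf(e->pnetid, sizeof(e->pnetid), "%s", id);
	e->refcnt = 1;
	e->next = head;
	head = e;
}

/* NETDEV_DOWN equivalent: the last reference frees the entry */
static void pnetid_remove(const char *id)
{
	struct pnetid_entry **pp, *e;

	for (pp = &head; (e = *pp); pp = &e->next) {
		if (!strcmp(e->pnetid, id)) {
			if (--e->refcnt == 0) {
				*pp = e->next;
				free(e);
			}
			return;
		}
	}
}

int main(void)
{
	pnetid_add("NET26");	/* eth0 up */
	pnetid_add("NET26");	/* eth1 up, same pnetid */
	pnetid_remove("NET26");	/* eth0 down: entry stays */
	printf("still listed: %s\n", head ? head->pnetid : "(none)");
	pnetid_remove("NET26");	/* eth1 down: entry freed */
	printf("list empty: %s\n", head ? "no" : "yes");
	return 0;
}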
+ */ +static struct net_device *pnet_find_base_ndev(struct net_device *ndev) +{ + rtnl_lock(); + ndev = __pnet_find_base_ndev(ndev); + rtnl_unlock(); + return ndev; +} + +static int smc_pnet_find_ndev_pnetid_by_table(struct net_device *ndev, + u8 *pnetid) +{ + struct smc_pnettable *pnettable; + struct net *net = dev_net(ndev); + struct smc_pnetentry *pnetelem; + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for namespace */ + sn = net_generic(net, smc_net_id); + pnettable = &sn->pnettable; + + mutex_lock(&pnettable->lock); + list_for_each_entry(pnetelem, &pnettable->pnetlist, list) { + if (pnetelem->type == SMC_PNET_ETH && ndev == pnetelem->ndev) { + /* get pnetid of netdev device */ + memcpy(pnetid, pnetelem->pnet_name, SMC_MAX_PNETID_LEN); + rc = 0; + break; + } + } + mutex_unlock(&pnettable->lock); + return rc; +} + +static int smc_pnet_determine_gid(struct smc_ib_device *ibdev, int i, + struct smc_init_info *ini) +{ + if (!ini->check_smcrv2 && + !smc_ib_determine_gid(ibdev, i, ini->vlan_id, ini->ib_gid, NULL, + NULL)) { + ini->ib_dev = ibdev; + ini->ib_port = i; + return 0; + } + if (ini->check_smcrv2 && + !smc_ib_determine_gid(ibdev, i, ini->vlan_id, ini->smcrv2.ib_gid_v2, + NULL, &ini->smcrv2)) { + ini->smcrv2.ib_dev_v2 = ibdev; + ini->smcrv2.ib_port_v2 = i; + return 0; + } + return -ENODEV; +} + +/* find a roce device for the given pnetid */ +static void _smc_pnet_find_roce_by_pnetid(u8 *pnet_id, + struct smc_init_info *ini, + struct smc_ib_device *known_dev, + struct net *net) +{ + struct smc_ib_device *ibdev; + int i; + + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + if (ibdev == known_dev || + !rdma_dev_access_netns(ibdev->ibdev, net)) + continue; + for (i = 1; i <= SMC_MAX_PORTS; i++) { + if (!rdma_is_port_valid(ibdev->ibdev, i)) + continue; + if (smc_pnet_match(ibdev->pnetid[i - 1], pnet_id) && + smc_ib_port_active(ibdev, i) && + !test_bit(i - 1, ibdev->ports_going_away)) { + if (!smc_pnet_determine_gid(ibdev, i, ini)) + goto out; + } + } + } +out: + mutex_unlock(&smc_ib_devices.mutex); +} + +/* find alternate roce device with same pnet_id, vlan_id and net namespace */ +void smc_pnet_find_alt_roce(struct smc_link_group *lgr, + struct smc_init_info *ini, + struct smc_ib_device *known_dev) +{ + struct net *net = lgr->net; + + _smc_pnet_find_roce_by_pnetid(lgr->pnet_id, ini, known_dev, net); +} + +/* if handshake network device belongs to a roce device, return its + * IB device and port + */ +static void smc_pnet_find_rdma_dev(struct net_device *netdev, + struct smc_init_info *ini) +{ + struct net *net = dev_net(netdev); + struct smc_ib_device *ibdev; + + mutex_lock(&smc_ib_devices.mutex); + list_for_each_entry(ibdev, &smc_ib_devices.list, list) { + struct net_device *ndev; + int i; + + /* check rdma net namespace */ + if (!rdma_dev_access_netns(ibdev->ibdev, net)) + continue; + + for (i = 1; i <= SMC_MAX_PORTS; i++) { + if (!rdma_is_port_valid(ibdev->ibdev, i)) + continue; + if (!ibdev->ibdev->ops.get_netdev) + continue; + ndev = ibdev->ibdev->ops.get_netdev(ibdev->ibdev, i); + if (!ndev) + continue; + dev_put(ndev); + if (netdev == ndev && + smc_ib_port_active(ibdev, i) && + !test_bit(i - 1, ibdev->ports_going_away)) { + if (!smc_pnet_determine_gid(ibdev, i, ini)) + break; + } + } + } + mutex_unlock(&smc_ib_devices.mutex); +} + +/* Determine the corresponding IB device port based on the hardware PNETID. + * Searching stops at the first matching active IB device port with vlan_id + * configured. 
+ * If nothing found, check pnetid table. + * If nothing found, try to use handshake device + */ +static void smc_pnet_find_roce_by_pnetid(struct net_device *ndev, + struct smc_init_info *ini) +{ + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + struct net *net; + + ndev = pnet_find_base_ndev(ndev); + net = dev_net(ndev); + if (smc_pnetid_by_dev_port(ndev->dev.parent, ndev->dev_port, + ndev_pnetid) && + smc_pnet_find_ndev_pnetid_by_table(ndev, ndev_pnetid)) { + smc_pnet_find_rdma_dev(ndev, ini); + return; /* pnetid could not be determined */ + } + _smc_pnet_find_roce_by_pnetid(ndev_pnetid, ini, NULL, net); +} + +static void smc_pnet_find_ism_by_pnetid(struct net_device *ndev, + struct smc_init_info *ini) +{ + u8 ndev_pnetid[SMC_MAX_PNETID_LEN]; + struct smcd_dev *ismdev; + + ndev = pnet_find_base_ndev(ndev); + if (smc_pnetid_by_dev_port(ndev->dev.parent, ndev->dev_port, + ndev_pnetid) && + smc_pnet_find_ndev_pnetid_by_table(ndev, ndev_pnetid)) + return; /* pnetid could not be determined */ + + mutex_lock(&smcd_dev_list.mutex); + list_for_each_entry(ismdev, &smcd_dev_list.list, list) { + if (smc_pnet_match(ismdev->pnetid, ndev_pnetid) && + !ismdev->going_away && + (!ini->ism_peer_gid[0] || + !smc_ism_cantalk(ini->ism_peer_gid[0], ini->vlan_id, + ismdev))) { + ini->ism_dev[0] = ismdev; + break; + } + } + mutex_unlock(&smcd_dev_list.mutex); +} + +/* PNET table analysis for a given sock: + * determine ib_device and port belonging to used internal TCP socket + * ethernet interface. + */ +void smc_pnet_find_roce_resource(struct sock *sk, struct smc_init_info *ini) +{ + struct dst_entry *dst = sk_dst_get(sk); + + if (!dst) + goto out; + if (!dst->dev) + goto out_rel; + + smc_pnet_find_roce_by_pnetid(dst->dev, ini); + +out_rel: + dst_release(dst); +out: + return; +} + +void smc_pnet_find_ism_resource(struct sock *sk, struct smc_init_info *ini) +{ + struct dst_entry *dst = sk_dst_get(sk); + + ini->ism_dev[0] = NULL; + if (!dst) + goto out; + if (!dst->dev) + goto out_rel; + + smc_pnet_find_ism_by_pnetid(dst->dev, ini); + +out_rel: + dst_release(dst); +out: + return; +} + +/* Lookup and apply a pnet table entry to the given ib device. + */ +int smc_pnetid_by_table_ib(struct smc_ib_device *smcibdev, u8 ib_port) +{ + char *ib_name = smcibdev->ibdev->name; + struct smc_pnettable *pnettable; + struct smc_pnetentry *tmp_pe; + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for init namespace */ + sn = net_generic(&init_net, smc_net_id); + pnettable = &sn->pnettable; + + mutex_lock(&pnettable->lock); + list_for_each_entry(tmp_pe, &pnettable->pnetlist, list) { + if (tmp_pe->type == SMC_PNET_IB && + !strncmp(tmp_pe->ib_name, ib_name, IB_DEVICE_NAME_MAX) && + tmp_pe->ib_port == ib_port) { + smc_pnet_apply_ib(smcibdev, ib_port, tmp_pe->pnet_name); + rc = 0; + break; + } + } + mutex_unlock(&pnettable->lock); + + return rc; +} + +/* Lookup and apply a pnet table entry to the given smcd device. 
+ */ +int smc_pnetid_by_table_smcd(struct smcd_dev *smcddev) +{ + const char *ib_name = dev_name(&smcddev->dev); + struct smc_pnettable *pnettable; + struct smc_pnetentry *tmp_pe; + struct smc_net *sn; + int rc = -ENOENT; + + /* get pnettable for init namespace */ + sn = net_generic(&init_net, smc_net_id); + pnettable = &sn->pnettable; + + mutex_lock(&pnettable->lock); + list_for_each_entry(tmp_pe, &pnettable->pnetlist, list) { + if (tmp_pe->type == SMC_PNET_IB && + !strncmp(tmp_pe->ib_name, ib_name, IB_DEVICE_NAME_MAX)) { + smc_pnet_apply_smcd(smcddev, tmp_pe->pnet_name); + rc = 0; + break; + } + } + mutex_unlock(&pnettable->lock); + + return rc; +} diff --git a/net/smc/smc_pnet.h b/net/smc/smc_pnet.h new file mode 100644 index 000000000..80a88eea4 --- /dev/null +++ b/net/smc/smc_pnet.h @@ -0,0 +1,70 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * PNET table queries + * + * Copyright IBM Corp. 2016 + * + * Author(s): Thomas Richter + */ + +#ifndef _SMC_PNET_H +#define _SMC_PNET_H + +#include + +#if IS_ENABLED(CONFIG_HAVE_PNETID) +#include +#endif + +struct smc_ib_device; +struct smcd_dev; +struct smc_init_info; +struct smc_link_group; + +/** + * struct smc_pnettable - SMC PNET table anchor + * @lock: Lock for list action + * @pnetlist: List of PNETIDs + */ +struct smc_pnettable { + struct mutex lock; + struct list_head pnetlist; +}; + +struct smc_pnetids_ndev { /* list of pnetids for net devices in UP state*/ + struct list_head list; + rwlock_t lock; +}; + +struct smc_pnetids_ndev_entry { + struct list_head list; + u8 pnetid[SMC_MAX_PNETID_LEN]; + refcount_t refcnt; +}; + +static inline int smc_pnetid_by_dev_port(struct device *dev, + unsigned short port, u8 *pnetid) +{ +#if IS_ENABLED(CONFIG_HAVE_PNETID) + return pnet_id_by_dev_port(dev, port, pnetid); +#else + return -ENOENT; +#endif +} + +int smc_pnet_init(void) __init; +int smc_pnet_net_init(struct net *net); +void smc_pnet_exit(void); +void smc_pnet_net_exit(struct net *net); +void smc_pnet_find_roce_resource(struct sock *sk, struct smc_init_info *ini); +void smc_pnet_find_ism_resource(struct sock *sk, struct smc_init_info *ini); +int smc_pnetid_by_table_ib(struct smc_ib_device *smcibdev, u8 ib_port); +int smc_pnetid_by_table_smcd(struct smcd_dev *smcd); +void smc_pnet_find_alt_roce(struct smc_link_group *lgr, + struct smc_init_info *ini, + struct smc_ib_device *known_dev); +bool smc_pnet_is_ndev_pnetid(struct net *net, u8 *pnetid); +bool smc_pnet_is_pnetid_set(u8 *pnetid); +#endif diff --git a/net/smc/smc_rx.c b/net/smc/smc_rx.c new file mode 100644 index 000000000..ffcc9996a --- /dev/null +++ b/net/smc/smc_rx.c @@ -0,0 +1,511 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Manage RMBE + * copy new RMBE data into user space + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include + +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_cdc.h" +#include "smc_tx.h" /* smc_tx_consumer_update() */ +#include "smc_rx.h" +#include "smc_stats.h" +#include "smc_tracepoint.h" + +/* callback implementation to wakeup consumers blocked with smc_rx_wait(). + * indirectly called by smc_cdc_msg_recv_action(). 
+ */ +static void smc_rx_wake_up(struct sock *sk) +{ + struct socket_wq *wq; + + /* derived from sock_def_readable() */ + /* called already in smc_listen_work() */ + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLIN | EPOLLPRI | + EPOLLRDNORM | EPOLLRDBAND); + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN); + if ((sk->sk_shutdown == SHUTDOWN_MASK) || + (sk->sk_state == SMC_CLOSED)) + sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_HUP); + rcu_read_unlock(); +} + +/* Update consumer cursor + * @conn connection to update + * @cons consumer cursor + * @len number of Bytes consumed + * Returns: + * 1 if we should end our receive, 0 otherwise + */ +static int smc_rx_update_consumer(struct smc_sock *smc, + union smc_host_cursor cons, size_t len) +{ + struct smc_connection *conn = &smc->conn; + struct sock *sk = &smc->sk; + bool force = false; + int diff, rc = 0; + + smc_curs_add(conn->rmb_desc->len, &cons, len); + + /* did we process urgent data? */ + if (conn->urg_state == SMC_URG_VALID || conn->urg_rx_skip_pend) { + diff = smc_curs_comp(conn->rmb_desc->len, &cons, + &conn->urg_curs); + if (sock_flag(sk, SOCK_URGINLINE)) { + if (diff == 0) { + force = true; + rc = 1; + conn->urg_state = SMC_URG_READ; + } + } else { + if (diff == 1) { + /* skip urgent byte */ + force = true; + smc_curs_add(conn->rmb_desc->len, &cons, 1); + conn->urg_rx_skip_pend = false; + } else if (diff < -1) + /* we read past urgent byte */ + conn->urg_state = SMC_URG_READ; + } + } + + smc_curs_copy(&conn->local_tx_ctrl.cons, &cons, conn); + + /* send consumer cursor update if required */ + /* similar to advertising new TCP rcv_wnd if required */ + smc_tx_consumer_update(conn, force); + + return rc; +} + +static void smc_rx_update_cons(struct smc_sock *smc, size_t len) +{ + struct smc_connection *conn = &smc->conn; + union smc_host_cursor cons; + + smc_curs_copy(&cons, &conn->local_tx_ctrl.cons, conn); + smc_rx_update_consumer(smc, cons, len); +} + +struct smc_spd_priv { + struct smc_sock *smc; + size_t len; +}; + +static void smc_rx_pipe_buf_release(struct pipe_inode_info *pipe, + struct pipe_buffer *buf) +{ + struct smc_spd_priv *priv = (struct smc_spd_priv *)buf->private; + struct smc_sock *smc = priv->smc; + struct smc_connection *conn; + struct sock *sk = &smc->sk; + + if (sk->sk_state == SMC_CLOSED || + sk->sk_state == SMC_PEERFINCLOSEWAIT || + sk->sk_state == SMC_APPFINCLOSEWAIT) + goto out; + conn = &smc->conn; + lock_sock(sk); + smc_rx_update_cons(smc, priv->len); + release_sock(sk); + if (atomic_sub_and_test(priv->len, &conn->splice_pending)) + smc_rx_wake_up(sk); +out: + kfree(priv); + put_page(buf->page); + sock_put(sk); +} + +static const struct pipe_buf_operations smc_pipe_ops = { + .release = smc_rx_pipe_buf_release, + .get = generic_pipe_buf_get +}; + +static void smc_rx_spd_release(struct splice_pipe_desc *spd, + unsigned int i) +{ + put_page(spd->pages[i]); +} + +static int smc_rx_splice(struct pipe_inode_info *pipe, char *src, size_t len, + struct smc_sock *smc) +{ + struct smc_link_group *lgr = smc->conn.lgr; + int offset = offset_in_page(src); + struct partial_page *partial; + struct splice_pipe_desc spd; + struct smc_spd_priv **priv; + struct page **pages; + int bytes, nr_pages; + int i; + + nr_pages = !lgr->is_smcd && smc->conn.rmb_desc->is_vm ? 
+ PAGE_ALIGN(len + offset) / PAGE_SIZE : 1; + + pages = kcalloc(nr_pages, sizeof(*pages), GFP_KERNEL); + if (!pages) + goto out; + partial = kcalloc(nr_pages, sizeof(*partial), GFP_KERNEL); + if (!partial) + goto out_page; + priv = kcalloc(nr_pages, sizeof(*priv), GFP_KERNEL); + if (!priv) + goto out_part; + for (i = 0; i < nr_pages; i++) { + priv[i] = kzalloc(sizeof(**priv), GFP_KERNEL); + if (!priv[i]) + goto out_priv; + } + + if (lgr->is_smcd || + (!lgr->is_smcd && !smc->conn.rmb_desc->is_vm)) { + /* smcd or smcr that uses physically contiguous RMBs */ + priv[0]->len = len; + priv[0]->smc = smc; + partial[0].offset = src - (char *)smc->conn.rmb_desc->cpu_addr; + partial[0].len = len; + partial[0].private = (unsigned long)priv[0]; + pages[0] = smc->conn.rmb_desc->pages; + } else { + int size, left = len; + void *buf = src; + /* smcr that uses virtually contiguous RMBs*/ + for (i = 0; i < nr_pages; i++) { + size = min_t(int, PAGE_SIZE - offset, left); + priv[i]->len = size; + priv[i]->smc = smc; + pages[i] = vmalloc_to_page(buf); + partial[i].offset = offset; + partial[i].len = size; + partial[i].private = (unsigned long)priv[i]; + buf += size / sizeof(*buf); + left -= size; + offset = 0; + } + } + spd.nr_pages_max = nr_pages; + spd.nr_pages = nr_pages; + spd.pages = pages; + spd.partial = partial; + spd.ops = &smc_pipe_ops; + spd.spd_release = smc_rx_spd_release; + + bytes = splice_to_pipe(pipe, &spd); + if (bytes > 0) { + sock_hold(&smc->sk); + if (!lgr->is_smcd && smc->conn.rmb_desc->is_vm) { + for (i = 0; i < PAGE_ALIGN(bytes + offset) / PAGE_SIZE; i++) + get_page(pages[i]); + } else { + get_page(smc->conn.rmb_desc->pages); + } + atomic_add(bytes, &smc->conn.splice_pending); + } + kfree(priv); + kfree(partial); + kfree(pages); + + return bytes; + +out_priv: + for (i = (i - 1); i >= 0; i--) + kfree(priv[i]); + kfree(priv); +out_part: + kfree(partial); +out_page: + kfree(pages); +out: + return -ENOMEM; +} + +static int smc_rx_data_available_and_no_splice_pend(struct smc_connection *conn) +{ + return atomic_read(&conn->bytes_to_rcv) && + !atomic_read(&conn->splice_pending); +} + +/* blocks rcvbuf consumer until >=len bytes available or timeout or interrupted + * @smc smc socket + * @timeo pointer to max seconds to wait, pointer to value 0 for no timeout + * @fcrit add'l criterion to evaluate as function pointer + * Returns: + * 1 if at least 1 byte available in rcvbuf or if socket error/shutdown. + * 0 otherwise (nothing in rcvbuf nor timeout, e.g. interrupted). 
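Consumer progress in the RMB element is tracked with cursors that carry a wrap counter plus an offset; smc_curs_add() and smc_curs_diff() (defined in smc_cdc.h, not in this file) advance and compare them modulo the buffer length. A simplified standalone sketch of that arithmetic, assuming the producer never laps the consumer by more than one wrap:

#include <stdio.h>

struct cursor {
	unsigned int wrap;	/* how often the offset wrapped */
	unsigned int count;	/* offset into the ring */
};

/* advance by len bytes, len is never larger than the ring */
static void curs_add(struct cursor *c, unsigned int ring_len, unsigned int len)
{
	c->count += len;
	if (c->count >= ring_len) {
		c->wrap++;
		c->count -= ring_len;
	}
}

/* bytes still readable, given consumer and producer cursors */
static unsigned int curs_diff(unsigned int ring_len,
			      const struct cursor *cons,
			      const struct cursor *prod)
{
	if (prod->wrap == cons->wrap)
		return prod->count - cons->count;
	return ring_len - cons->count + prod->count; /* producer wrapped once */
}

int main(void)
{
	struct cursor prod = { 0, 0 }, cons = { 0, 0 };
	unsigned int ring_len = 16;

	curs_add(&prod, ring_len, 12);	/* peer wrote 12 bytes */
	curs_add(&cons, ring_len, 8);	/* we consumed 8 */
	printf("available: %u\n", curs_diff(ring_len, &cons, &prod)); /* 4 */
	curs_add(&prod, ring_len, 6);	/* this write wraps the ring */
	printf("available: %u\n", curs_diff(ring_len, &cons, &prod)); /* 10 */
	return 0;
}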
+ */ +int smc_rx_wait(struct smc_sock *smc, long *timeo, + int (*fcrit)(struct smc_connection *conn)) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct smc_connection *conn = &smc->conn; + struct smc_cdc_conn_state_flags *cflags = + &conn->local_tx_ctrl.conn_state_flags; + struct sock *sk = &smc->sk; + int rc; + + if (fcrit(conn)) + return 1; + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + add_wait_queue(sk_sleep(sk), &wait); + rc = sk_wait_event(sk, timeo, + READ_ONCE(sk->sk_err) || + cflags->peer_conn_abort || + READ_ONCE(sk->sk_shutdown) & RCV_SHUTDOWN || + conn->killed || + fcrit(conn), + &wait); + remove_wait_queue(sk_sleep(sk), &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + return rc; +} + +static int smc_rx_recv_urg(struct smc_sock *smc, struct msghdr *msg, int len, + int flags) +{ + struct smc_connection *conn = &smc->conn; + union smc_host_cursor cons; + struct sock *sk = &smc->sk; + int rc = 0; + + if (sock_flag(sk, SOCK_URGINLINE) || + !(conn->urg_state == SMC_URG_VALID) || + conn->urg_state == SMC_URG_READ) + return -EINVAL; + + SMC_STAT_INC(smc, urg_data_cnt); + if (conn->urg_state == SMC_URG_VALID) { + if (!(flags & MSG_PEEK)) + smc->conn.urg_state = SMC_URG_READ; + msg->msg_flags |= MSG_OOB; + if (len > 0) { + if (!(flags & MSG_TRUNC)) + rc = memcpy_to_msg(msg, &conn->urg_rx_byte, 1); + len = 1; + smc_curs_copy(&cons, &conn->local_tx_ctrl.cons, conn); + if (smc_curs_diff(conn->rmb_desc->len, &cons, + &conn->urg_curs) > 1) + conn->urg_rx_skip_pend = true; + /* Urgent Byte was already accounted for, but trigger + * skipping the urgent byte in non-inline case + */ + if (!(flags & MSG_PEEK)) + smc_rx_update_consumer(smc, cons, 0); + } else { + msg->msg_flags |= MSG_TRUNC; + } + + return rc ? -EFAULT : len; + } + + if (sk->sk_state == SMC_CLOSED || sk->sk_shutdown & RCV_SHUTDOWN) + return 0; + + return -EAGAIN; +} + +static bool smc_rx_recvmsg_data_available(struct smc_sock *smc) +{ + struct smc_connection *conn = &smc->conn; + + if (smc_rx_data_available(conn)) + return true; + else if (conn->urg_state == SMC_URG_VALID) + /* we received a single urgent Byte - skip */ + smc_rx_update_cons(smc, 0); + return false; +} + +/* smc_rx_recvmsg - receive data from RMBE + * @msg: copy data to receive buffer + * @pipe: copy data to pipe if set - indicates splice() call + * + * rcvbuf consumer: main API called by socket layer. + * Called under sk lock. 
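From the application's point of view the urgent byte handled by smc_rx_recv_urg() is reached through the ordinary socket API, just as with TCP. A small helper sketch for an already connected AF_SMC (or AF_INET) socket; with SO_OOBINLINE set the byte arrives in the normal stream instead and the MSG_OOB read fails, which matches the first check in the function above:

#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>

/* read one urgent byte from connected socket 'fd' */
static int read_urgent_byte(int fd)
{
	char urg;
	ssize_t n = recv(fd, &urg, 1, MSG_OOB);

	if (n == 1) {
		printf("urgent byte: 0x%02x\n", (unsigned char)urg);
		return 0;
	}
	return -1;	/* no urgent data pending, or SO_OOBINLINE set; see errno */
}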
+ */ +int smc_rx_recvmsg(struct smc_sock *smc, struct msghdr *msg, + struct pipe_inode_info *pipe, size_t len, int flags) +{ + size_t copylen, read_done = 0, read_remaining = len; + size_t chunk_len, chunk_off, chunk_len_sum; + struct smc_connection *conn = &smc->conn; + int (*func)(struct smc_connection *conn); + union smc_host_cursor cons; + int readable, chunk; + char *rcvbuf_base; + struct sock *sk; + int splbytes; + long timeo; + int target; /* Read at least these many bytes */ + int rc; + + if (unlikely(flags & MSG_ERRQUEUE)) + return -EINVAL; /* future work for sk.sk_family == AF_SMC */ + + sk = &smc->sk; + if (sk->sk_state == SMC_LISTEN) + return -ENOTCONN; + if (flags & MSG_OOB) + return smc_rx_recv_urg(smc, msg, len, flags); + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + + readable = atomic_read(&conn->bytes_to_rcv); + if (readable >= conn->rmb_desc->len) + SMC_STAT_RMB_RX_FULL(smc, !conn->lnk); + + if (len < readable) + SMC_STAT_RMB_RX_SIZE_SMALL(smc, !conn->lnk); + /* we currently use 1 RMBE per RMB, so RMBE == RMB base addr */ + rcvbuf_base = conn->rx_off + conn->rmb_desc->cpu_addr; + + do { /* while (read_remaining) */ + if (read_done >= target || (pipe && read_done)) + break; + + if (conn->killed) + break; + + if (smc_rx_recvmsg_data_available(smc)) + goto copy; + + if (sk->sk_shutdown & RCV_SHUTDOWN) { + /* smc_cdc_msg_recv_action() could have run after + * above smc_rx_recvmsg_data_available() + */ + if (smc_rx_recvmsg_data_available(smc)) + goto copy; + break; + } + + if (read_done) { + if (sk->sk_err || + sk->sk_state == SMC_CLOSED || + !timeo || + signal_pending(current)) + break; + } else { + if (sk->sk_err) { + read_done = sock_error(sk); + break; + } + if (sk->sk_state == SMC_CLOSED) { + if (!sock_flag(sk, SOCK_DONE)) { + /* This occurs when user tries to read + * from never connected socket. 
+ */ + read_done = -ENOTCONN; + break; + } + break; + } + if (!timeo) + return -EAGAIN; + if (signal_pending(current)) { + read_done = sock_intr_errno(timeo); + break; + } + } + + if (!smc_rx_data_available(conn)) { + smc_rx_wait(smc, &timeo, smc_rx_data_available); + continue; + } + +copy: + /* initialize variables for 1st iteration of subsequent loop */ + /* could be just 1 byte, even after waiting on data above */ + readable = atomic_read(&conn->bytes_to_rcv); + splbytes = atomic_read(&conn->splice_pending); + if (!readable || (msg && splbytes)) { + if (splbytes) + func = smc_rx_data_available_and_no_splice_pend; + else + func = smc_rx_data_available; + smc_rx_wait(smc, &timeo, func); + continue; + } + + smc_curs_copy(&cons, &conn->local_tx_ctrl.cons, conn); + /* subsequent splice() calls pick up where previous left */ + if (splbytes) + smc_curs_add(conn->rmb_desc->len, &cons, splbytes); + if (conn->urg_state == SMC_URG_VALID && + sock_flag(&smc->sk, SOCK_URGINLINE) && + readable > 1) + readable--; /* always stop at urgent Byte */ + /* not more than what user space asked for */ + copylen = min_t(size_t, read_remaining, readable); + /* determine chunks where to read from rcvbuf */ + /* either unwrapped case, or 1st chunk of wrapped case */ + chunk_len = min_t(size_t, copylen, conn->rmb_desc->len - + cons.count); + chunk_len_sum = chunk_len; + chunk_off = cons.count; + smc_rmb_sync_sg_for_cpu(conn); + for (chunk = 0; chunk < 2; chunk++) { + if (!(flags & MSG_TRUNC)) { + if (msg) { + rc = memcpy_to_msg(msg, rcvbuf_base + + chunk_off, + chunk_len); + } else { + rc = smc_rx_splice(pipe, rcvbuf_base + + chunk_off, chunk_len, + smc); + } + if (rc < 0) { + if (!read_done) + read_done = -EFAULT; + goto out; + } + } + read_remaining -= chunk_len; + read_done += chunk_len; + + if (chunk_len_sum == copylen) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + chunk_len = copylen - chunk_len; /* remainder */ + chunk_len_sum += chunk_len; + chunk_off = 0; /* modulo offset in recv ring buffer */ + } + + /* update cursors */ + if (!(flags & MSG_PEEK)) { + /* increased in recv tasklet smc_cdc_msg_rcv() */ + smp_mb__before_atomic(); + atomic_sub(copylen, &conn->bytes_to_rcv); + /* guarantee 0 <= bytes_to_rcv <= rmb_desc->len */ + smp_mb__after_atomic(); + if (msg && smc_rx_update_consumer(smc, cons, copylen)) + goto out; + } + + trace_smc_rx_recvmsg(smc, copylen); + } while (read_remaining); +out: + return read_done; +} + +/* Initialize receive properties on connection establishment. NB: not __init! */ +void smc_rx_init(struct smc_sock *smc) +{ + smc->sk.sk_data_ready = smc_rx_wake_up; + atomic_set(&smc->conn.splice_pending, 0); + smc->conn.urg_state = SMC_URG_READ; +} diff --git a/net/smc/smc_rx.h b/net/smc/smc_rx.h new file mode 100644 index 000000000..db823c97d --- /dev/null +++ b/net/smc/smc_rx.h @@ -0,0 +1,31 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Manage RMBE + * + * Copyright IBM Corp. 
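The copy loop above never needs more than two copies per pass: one chunk from the consumer offset up to the end of the RMB and, if the data wrapped, a second chunk from offset zero. The same idea as a self-contained program:

#include <stdio.h>
#include <string.h>

/* Copy 'len' bytes starting at ring offset 'cons' out of a ring buffer of
 * size 'ring_len'; a read crossing the end is done in at most two chunks. */
static void ring_read(const char *ring, size_t ring_len,
		      size_t cons, size_t len, char *dst)
{
	size_t chunk = len < ring_len - cons ? len : ring_len - cons;

	memcpy(dst, ring + cons, chunk);		/* up to the wrap point */
	if (len > chunk)
		memcpy(dst + chunk, ring, len - chunk);	/* remainder from offset 0 */
}

int main(void)
{
	char ring[8];
	char out[9] = { 0 };

	/* producer wrote "HELWORLD" wrapped: "HEL" at the tail, "WORLD" at the start */
	memcpy(ring, "WORLDHEL", 8);
	ring_read(ring, sizeof(ring), 5, 8, out);	/* consumer cursor at offset 5 */
	printf("%s\n", out);				/* HELWORLD */
	return 0;
}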
2016 + * + * Author(s): Ursula Braun + */ + +#ifndef SMC_RX_H +#define SMC_RX_H + +#include +#include + +#include "smc.h" + +void smc_rx_init(struct smc_sock *smc); + +int smc_rx_recvmsg(struct smc_sock *smc, struct msghdr *msg, + struct pipe_inode_info *pipe, size_t len, int flags); +int smc_rx_wait(struct smc_sock *smc, long *timeo, + int (*fcrit)(struct smc_connection *conn)); +static inline int smc_rx_data_available(struct smc_connection *conn) +{ + return atomic_read(&conn->bytes_to_rcv); +} + +#endif /* SMC_RX_H */ diff --git a/net/smc/smc_stats.c b/net/smc/smc_stats.c new file mode 100644 index 000000000..e80e34f7a --- /dev/null +++ b/net/smc/smc_stats.c @@ -0,0 +1,413 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * SMC statistics netlink routines + * + * Copyright IBM Corp. 2021 + * + * Author(s): Guvenc Gulce + */ +#include +#include +#include +#include +#include +#include +#include +#include "smc_netlink.h" +#include "smc_stats.h" + +int smc_stats_init(struct net *net) +{ + net->smc.fback_rsn = kzalloc(sizeof(*net->smc.fback_rsn), GFP_KERNEL); + if (!net->smc.fback_rsn) + goto err_fback; + net->smc.smc_stats = alloc_percpu(struct smc_stats); + if (!net->smc.smc_stats) + goto err_stats; + mutex_init(&net->smc.mutex_fback_rsn); + return 0; + +err_stats: + kfree(net->smc.fback_rsn); +err_fback: + return -ENOMEM; +} + +void smc_stats_exit(struct net *net) +{ + kfree(net->smc.fback_rsn); + if (net->smc.smc_stats) + free_percpu(net->smc.smc_stats); +} + +static int smc_nl_fill_stats_rmb_data(struct sk_buff *skb, + struct smc_stats *stats, int tech, + int type) +{ + struct smc_stats_rmbcnt *stats_rmb_cnt; + struct nlattr *attrs; + + if (type == SMC_NLA_STATS_T_TX_RMB_STATS) + stats_rmb_cnt = &stats->smc[tech].rmb_tx; + else + stats_rmb_cnt = &stats->smc[tech].rmb_rx; + + attrs = nla_nest_start(skb, type); + if (!attrs) + goto errout; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_REUSE_CNT, + stats_rmb_cnt->reuse_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_SIZE_SM_PEER_CNT, + stats_rmb_cnt->buf_size_small_peer_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_SIZE_SM_CNT, + stats_rmb_cnt->buf_size_small_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_FULL_PEER_CNT, + stats_rmb_cnt->buf_full_peer_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_FULL_CNT, + stats_rmb_cnt->buf_full_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_ALLOC_CNT, + stats_rmb_cnt->alloc_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_RMB_DGRADE_CNT, + stats_rmb_cnt->dgrade_cnt, + SMC_NLA_STATS_RMB_PAD)) + goto errattr; + + nla_nest_end(skb, attrs); + return 0; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + return -EMSGSIZE; +} + +static int smc_nl_fill_stats_bufsize_data(struct sk_buff *skb, + struct smc_stats *stats, int tech, + int type) +{ + struct smc_stats_memsize *stats_pload; + struct nlattr *attrs; + + if (type == SMC_NLA_STATS_T_TXPLOAD_SIZE) + stats_pload = &stats->smc[tech].tx_pd; + else if (type == SMC_NLA_STATS_T_RXPLOAD_SIZE) + stats_pload = &stats->smc[tech].rx_pd; + else if (type == SMC_NLA_STATS_T_TX_RMB_SIZE) + stats_pload = &stats->smc[tech].tx_rmbsize; + else if (type == SMC_NLA_STATS_T_RX_RMB_SIZE) + stats_pload = &stats->smc[tech].rx_rmbsize; + else 
+ goto errout; + + attrs = nla_nest_start(skb, type); + if (!attrs) + goto errout; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_8K, + stats_pload->buf[SMC_BUF_8K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_16K, + stats_pload->buf[SMC_BUF_16K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_32K, + stats_pload->buf[SMC_BUF_32K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_64K, + stats_pload->buf[SMC_BUF_64K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_128K, + stats_pload->buf[SMC_BUF_128K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_256K, + stats_pload->buf[SMC_BUF_256K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_512K, + stats_pload->buf[SMC_BUF_512K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_1024K, + stats_pload->buf[SMC_BUF_1024K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_PLOAD_G_1024K, + stats_pload->buf[SMC_BUF_G_1024K], + SMC_NLA_STATS_PLOAD_PAD)) + goto errattr; + + nla_nest_end(skb, attrs); + return 0; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + return -EMSGSIZE; +} + +static int smc_nl_fill_stats_tech_data(struct sk_buff *skb, + struct smc_stats *stats, int tech) +{ + struct smc_stats_tech *smc_tech; + struct nlattr *attrs; + + smc_tech = &stats->smc[tech]; + if (tech == SMC_TYPE_D) + attrs = nla_nest_start(skb, SMC_NLA_STATS_SMCD_TECH); + else + attrs = nla_nest_start(skb, SMC_NLA_STATS_SMCR_TECH); + + if (!attrs) + goto errout; + if (smc_nl_fill_stats_rmb_data(skb, stats, tech, + SMC_NLA_STATS_T_TX_RMB_STATS)) + goto errattr; + if (smc_nl_fill_stats_rmb_data(skb, stats, tech, + SMC_NLA_STATS_T_RX_RMB_STATS)) + goto errattr; + if (smc_nl_fill_stats_bufsize_data(skb, stats, tech, + SMC_NLA_STATS_T_TXPLOAD_SIZE)) + goto errattr; + if (smc_nl_fill_stats_bufsize_data(skb, stats, tech, + SMC_NLA_STATS_T_RXPLOAD_SIZE)) + goto errattr; + if (smc_nl_fill_stats_bufsize_data(skb, stats, tech, + SMC_NLA_STATS_T_TX_RMB_SIZE)) + goto errattr; + if (smc_nl_fill_stats_bufsize_data(skb, stats, tech, + SMC_NLA_STATS_T_RX_RMB_SIZE)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_CLNT_V1_SUCC, + smc_tech->clnt_v1_succ_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_CLNT_V2_SUCC, + smc_tech->clnt_v2_succ_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_SRV_V1_SUCC, + smc_tech->srv_v1_succ_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_SRV_V2_SUCC, + smc_tech->srv_v2_succ_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_RX_BYTES, + smc_tech->rx_bytes, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_TX_BYTES, + smc_tech->tx_bytes, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_RX_CNT, + smc_tech->rx_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_TX_CNT, + smc_tech->tx_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_SENDPAGE_CNT, + smc_tech->sendpage_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_CORK_CNT, + smc_tech->cork_cnt, + 
SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_NDLY_CNT, + smc_tech->ndly_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_SPLICE_CNT, + smc_tech->splice_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_T_URG_DATA_CNT, + smc_tech->urg_data_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + + nla_nest_end(skb, attrs); + return 0; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + return -EMSGSIZE; +} + +int smc_nl_get_stats(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct net *net = sock_net(skb->sk); + struct smc_stats *stats; + struct nlattr *attrs; + int cpu, i, size; + void *nlh; + u64 *src; + u64 *sum; + + if (cb_ctx->pos[0]) + goto errmsg; + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_STATS); + if (!nlh) + goto errmsg; + + attrs = nla_nest_start(skb, SMC_GEN_STATS); + if (!attrs) + goto errnest; + stats = kzalloc(sizeof(*stats), GFP_KERNEL); + if (!stats) + goto erralloc; + size = sizeof(*stats) / sizeof(u64); + for_each_possible_cpu(cpu) { + src = (u64 *)per_cpu_ptr(net->smc.smc_stats, cpu); + sum = (u64 *)stats; + for (i = 0; i < size; i++) + *(sum++) += *(src++); + } + if (smc_nl_fill_stats_tech_data(skb, stats, SMC_TYPE_D)) + goto errattr; + if (smc_nl_fill_stats_tech_data(skb, stats, SMC_TYPE_R)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_CLNT_HS_ERR_CNT, + stats->clnt_hshake_err_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_STATS_SRV_HS_ERR_CNT, + stats->srv_hshake_err_cnt, + SMC_NLA_STATS_PAD)) + goto errattr; + + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + cb_ctx->pos[0] = 1; + kfree(stats); + return skb->len; + +errattr: + kfree(stats); +erralloc: + nla_nest_cancel(skb, attrs); +errnest: + genlmsg_cancel(skb, nlh); +errmsg: + return skb->len; +} + +static int smc_nl_get_fback_details(struct sk_buff *skb, + struct netlink_callback *cb, int pos, + bool is_srv) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct net *net = sock_net(skb->sk); + int cnt_reported = cb_ctx->pos[2]; + struct smc_stats_fback *trgt_arr; + struct nlattr *attrs; + int rc = 0; + void *nlh; + + if (is_srv) + trgt_arr = &net->smc.fback_rsn->srv[0]; + else + trgt_arr = &net->smc.fback_rsn->clnt[0]; + if (!trgt_arr[pos].fback_code) + return -ENODATA; + nlh = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &smc_gen_nl_family, NLM_F_MULTI, + SMC_NETLINK_GET_FBACK_STATS); + if (!nlh) + goto errmsg; + attrs = nla_nest_start(skb, SMC_GEN_FBACK_STATS); + if (!attrs) + goto errout; + if (nla_put_u8(skb, SMC_NLA_FBACK_STATS_TYPE, is_srv)) + goto errattr; + if (!cnt_reported) { + if (nla_put_u64_64bit(skb, SMC_NLA_FBACK_STATS_SRV_CNT, + net->smc.fback_rsn->srv_fback_cnt, + SMC_NLA_FBACK_STATS_PAD)) + goto errattr; + if (nla_put_u64_64bit(skb, SMC_NLA_FBACK_STATS_CLNT_CNT, + net->smc.fback_rsn->clnt_fback_cnt, + SMC_NLA_FBACK_STATS_PAD)) + goto errattr; + cnt_reported = 1; + } + + if (nla_put_u32(skb, SMC_NLA_FBACK_STATS_RSN_CODE, + trgt_arr[pos].fback_code)) + goto errattr; + if (nla_put_u16(skb, SMC_NLA_FBACK_STATS_RSN_CNT, + trgt_arr[pos].count)) + goto errattr; + + cb_ctx->pos[2] = cnt_reported; + nla_nest_end(skb, attrs); + genlmsg_end(skb, nlh); + return rc; + +errattr: + nla_nest_cancel(skb, attrs); +errout: + genlmsg_cancel(skb, nlh); +errmsg: + return -EMSGSIZE; +} + 
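smc_nl_get_stats() folds the per-CPU counters into one struct by walking each per-CPU copy as a flat array of u64, which works only because struct smc_stats consists of u64 fields throughout. A standalone demonstration of that summation:

#include <inttypes.h>
#include <stdio.h>

/* all-u64 counter block, one instance per CPU */
struct stats {
	uint64_t tx_cnt;
	uint64_t rx_cnt;
	uint64_t tx_bytes;
	uint64_t rx_bytes;
};

#define NR_CPUS 4

int main(void)
{
	struct stats percpu[NR_CPUS] = {
		{ 1, 2, 100, 200 }, { 3, 0, 300, 0 },
		{ 0, 5, 0, 500 },   { 2, 2, 50, 60 },
	};
	struct stats sum = { 0 };
	size_t words = sizeof(sum) / sizeof(uint64_t);

	for (int cpu = 0; cpu < NR_CPUS; cpu++) {
		const uint64_t *src = (const uint64_t *)&percpu[cpu];
		uint64_t *dst = (uint64_t *)&sum;

		for (size_t i = 0; i < words; i++)
			dst[i] += src[i];	/* field-by-field via the flat view */
	}
	printf("tx %" PRIu64 " rx %" PRIu64 " txb %" PRIu64 " rxb %" PRIu64 "\n",
	       sum.tx_cnt, sum.rx_cnt, sum.tx_bytes, sum.rx_bytes);
	return 0;	/* prints: tx 6 rx 9 txb 450 rxb 760 */
}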
+int smc_nl_get_fback_stats(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct smc_nl_dmp_ctx *cb_ctx = smc_nl_dmp_ctx(cb); + struct net *net = sock_net(skb->sk); + int rc_srv = 0, rc_clnt = 0, k; + int skip_serv = cb_ctx->pos[1]; + int snum = cb_ctx->pos[0]; + bool is_srv = true; + + mutex_lock(&net->smc.mutex_fback_rsn); + for (k = 0; k < SMC_MAX_FBACK_RSN_CNT; k++) { + if (k < snum) + continue; + if (!skip_serv) { + rc_srv = smc_nl_get_fback_details(skb, cb, k, is_srv); + if (rc_srv && rc_srv != -ENODATA) + break; + } else { + skip_serv = 0; + } + rc_clnt = smc_nl_get_fback_details(skb, cb, k, !is_srv); + if (rc_clnt && rc_clnt != -ENODATA) { + skip_serv = 1; + break; + } + if (rc_clnt == -ENODATA && rc_srv == -ENODATA) + break; + } + mutex_unlock(&net->smc.mutex_fback_rsn); + cb_ctx->pos[1] = skip_serv; + cb_ctx->pos[0] = k; + return skb->len; +} diff --git a/net/smc/smc_stats.h b/net/smc/smc_stats.h new file mode 100644 index 000000000..ee22d6f9a --- /dev/null +++ b/net/smc/smc_stats.h @@ -0,0 +1,271 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Macros for SMC statistics + * + * Copyright IBM Corp. 2021 + * + * Author(s): Guvenc Gulce + */ + +#ifndef NET_SMC_SMC_STATS_H_ +#define NET_SMC_SMC_STATS_H_ +#include +#include +#include +#include +#include + +#include "smc_clc.h" + +#define SMC_MAX_FBACK_RSN_CNT 30 + +enum { + SMC_BUF_8K, + SMC_BUF_16K, + SMC_BUF_32K, + SMC_BUF_64K, + SMC_BUF_128K, + SMC_BUF_256K, + SMC_BUF_512K, + SMC_BUF_1024K, + SMC_BUF_G_1024K, + SMC_BUF_MAX, +}; + +struct smc_stats_fback { + int fback_code; + u16 count; +}; + +struct smc_stats_rsn { + struct smc_stats_fback srv[SMC_MAX_FBACK_RSN_CNT]; + struct smc_stats_fback clnt[SMC_MAX_FBACK_RSN_CNT]; + u64 srv_fback_cnt; + u64 clnt_fback_cnt; +}; + +struct smc_stats_rmbcnt { + u64 buf_size_small_peer_cnt; + u64 buf_size_small_cnt; + u64 buf_full_peer_cnt; + u64 buf_full_cnt; + u64 reuse_cnt; + u64 alloc_cnt; + u64 dgrade_cnt; +}; + +struct smc_stats_memsize { + u64 buf[SMC_BUF_MAX]; +}; + +struct smc_stats_tech { + struct smc_stats_memsize tx_rmbsize; + struct smc_stats_memsize rx_rmbsize; + struct smc_stats_memsize tx_pd; + struct smc_stats_memsize rx_pd; + struct smc_stats_rmbcnt rmb_tx; + struct smc_stats_rmbcnt rmb_rx; + u64 clnt_v1_succ_cnt; + u64 clnt_v2_succ_cnt; + u64 srv_v1_succ_cnt; + u64 srv_v2_succ_cnt; + u64 sendpage_cnt; + u64 urg_data_cnt; + u64 splice_cnt; + u64 cork_cnt; + u64 ndly_cnt; + u64 rx_bytes; + u64 tx_bytes; + u64 rx_cnt; + u64 tx_cnt; +}; + +struct smc_stats { + struct smc_stats_tech smc[2]; + u64 clnt_hshake_err_cnt; + u64 srv_hshake_err_cnt; +}; + +#define SMC_STAT_PAYLOAD_SUB(_smc_stats, _tech, key, _len, _rc) \ +do { \ + typeof(_smc_stats) stats = (_smc_stats); \ + typeof(_tech) t = (_tech); \ + typeof(_len) l = (_len); \ + int _pos; \ + typeof(_rc) r = (_rc); \ + int m = SMC_BUF_MAX - 1; \ + this_cpu_inc((*stats).smc[t].key ## _cnt); \ + if (r <= 0 || l <= 0) \ + break; \ + _pos = fls64((l - 1) >> 13); \ + _pos = (_pos <= m) ? 
_pos : m; \ + this_cpu_inc((*stats).smc[t].key ## _pd.buf[_pos]); \ + this_cpu_add((*stats).smc[t].key ## _bytes, r); \ +} \ +while (0) + +#define SMC_STAT_TX_PAYLOAD(_smc, length, rcode) \ +do { \ + typeof(_smc) __smc = _smc; \ + struct net *_net = sock_net(&__smc->sk); \ + struct smc_stats __percpu *_smc_stats = _net->smc.smc_stats; \ + typeof(length) _len = (length); \ + typeof(rcode) _rc = (rcode); \ + bool is_smcd = !__smc->conn.lnk; \ + if (is_smcd) \ + SMC_STAT_PAYLOAD_SUB(_smc_stats, SMC_TYPE_D, tx, _len, _rc); \ + else \ + SMC_STAT_PAYLOAD_SUB(_smc_stats, SMC_TYPE_R, tx, _len, _rc); \ +} \ +while (0) + +#define SMC_STAT_RX_PAYLOAD(_smc, length, rcode) \ +do { \ + typeof(_smc) __smc = _smc; \ + struct net *_net = sock_net(&__smc->sk); \ + struct smc_stats __percpu *_smc_stats = _net->smc.smc_stats; \ + typeof(length) _len = (length); \ + typeof(rcode) _rc = (rcode); \ + bool is_smcd = !__smc->conn.lnk; \ + if (is_smcd) \ + SMC_STAT_PAYLOAD_SUB(_smc_stats, SMC_TYPE_D, rx, _len, _rc); \ + else \ + SMC_STAT_PAYLOAD_SUB(_smc_stats, SMC_TYPE_R, rx, _len, _rc); \ +} \ +while (0) + +#define SMC_STAT_RMB_SIZE_SUB(_smc_stats, _tech, k, _len) \ +do { \ + typeof(_len) _l = (_len); \ + typeof(_tech) t = (_tech); \ + int _pos; \ + int m = SMC_BUF_MAX - 1; \ + if (_l <= 0) \ + break; \ + _pos = fls((_l - 1) >> 13); \ + _pos = (_pos <= m) ? _pos : m; \ + this_cpu_inc((*(_smc_stats)).smc[t].k ## _rmbsize.buf[_pos]); \ +} \ +while (0) + +#define SMC_STAT_RMB_SUB(_smc_stats, type, t, key) \ + this_cpu_inc((*(_smc_stats)).smc[t].rmb ## _ ## key.type ## _cnt) + +#define SMC_STAT_RMB_SIZE(_smc, _is_smcd, _is_rx, _len) \ +do { \ + struct net *_net = sock_net(&(_smc)->sk); \ + struct smc_stats __percpu *_smc_stats = _net->smc.smc_stats; \ + typeof(_is_smcd) is_d = (_is_smcd); \ + typeof(_is_rx) is_r = (_is_rx); \ + typeof(_len) l = (_len); \ + if ((is_d) && (is_r)) \ + SMC_STAT_RMB_SIZE_SUB(_smc_stats, SMC_TYPE_D, rx, l); \ + if ((is_d) && !(is_r)) \ + SMC_STAT_RMB_SIZE_SUB(_smc_stats, SMC_TYPE_D, tx, l); \ + if (!(is_d) && (is_r)) \ + SMC_STAT_RMB_SIZE_SUB(_smc_stats, SMC_TYPE_R, rx, l); \ + if (!(is_d) && !(is_r)) \ + SMC_STAT_RMB_SIZE_SUB(_smc_stats, SMC_TYPE_R, tx, l); \ +} \ +while (0) + +#define SMC_STAT_RMB(_smc, type, _is_smcd, _is_rx) \ +do { \ + struct net *net = sock_net(&(_smc)->sk); \ + struct smc_stats __percpu *_smc_stats = net->smc.smc_stats; \ + typeof(_is_smcd) is_d = (_is_smcd); \ + typeof(_is_rx) is_r = (_is_rx); \ + if ((is_d) && (is_r)) \ + SMC_STAT_RMB_SUB(_smc_stats, type, SMC_TYPE_D, rx); \ + if ((is_d) && !(is_r)) \ + SMC_STAT_RMB_SUB(_smc_stats, type, SMC_TYPE_D, tx); \ + if (!(is_d) && (is_r)) \ + SMC_STAT_RMB_SUB(_smc_stats, type, SMC_TYPE_R, rx); \ + if (!(is_d) && !(is_r)) \ + SMC_STAT_RMB_SUB(_smc_stats, type, SMC_TYPE_R, tx); \ +} \ +while (0) + +#define SMC_STAT_BUF_REUSE(smc, is_smcd, is_rx) \ + SMC_STAT_RMB(smc, reuse, is_smcd, is_rx) + +#define SMC_STAT_RMB_ALLOC(smc, is_smcd, is_rx) \ + SMC_STAT_RMB(smc, alloc, is_smcd, is_rx) + +#define SMC_STAT_RMB_DOWNGRADED(smc, is_smcd, is_rx) \ + SMC_STAT_RMB(smc, dgrade, is_smcd, is_rx) + +#define SMC_STAT_RMB_TX_PEER_FULL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_full_peer, is_smcd, false) + +#define SMC_STAT_RMB_TX_FULL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_full, is_smcd, false) + +#define SMC_STAT_RMB_TX_PEER_SIZE_SMALL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_size_small_peer, is_smcd, false) + +#define SMC_STAT_RMB_TX_SIZE_SMALL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_size_small, is_smcd, false) + +#define 
SMC_STAT_RMB_RX_SIZE_SMALL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_size_small, is_smcd, true) + +#define SMC_STAT_RMB_RX_FULL(smc, is_smcd) \ + SMC_STAT_RMB(smc, buf_full, is_smcd, true) + +#define SMC_STAT_INC(_smc, type) \ +do { \ + typeof(_smc) __smc = _smc; \ + bool is_smcd = !(__smc)->conn.lnk; \ + struct net *net = sock_net(&(__smc)->sk); \ + struct smc_stats __percpu *smc_stats = net->smc.smc_stats; \ + if ((is_smcd)) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_D].type); \ + else \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_R].type); \ +} \ +while (0) + +#define SMC_STAT_CLNT_SUCC_INC(net, _aclc) \ +do { \ + typeof(_aclc) acl = (_aclc); \ + bool is_v2 = (acl->hdr.version == SMC_V2); \ + bool is_smcd = (acl->hdr.typev1 == SMC_TYPE_D); \ + struct smc_stats __percpu *smc_stats = (net)->smc.smc_stats; \ + if (is_v2 && is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_D].clnt_v2_succ_cnt); \ + else if (is_v2 && !is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_R].clnt_v2_succ_cnt); \ + else if (!is_v2 && is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_D].clnt_v1_succ_cnt); \ + else if (!is_v2 && !is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_R].clnt_v1_succ_cnt); \ +} \ +while (0) + +#define SMC_STAT_SERV_SUCC_INC(net, _ini) \ +do { \ + typeof(_ini) i = (_ini); \ + bool is_smcd = (i->is_smcd); \ + u8 version = is_smcd ? i->smcd_version : i->smcr_version; \ + bool is_v2 = (version & SMC_V2); \ + typeof(net->smc.smc_stats) smc_stats = (net)->smc.smc_stats; \ + if (is_v2 && is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_D].srv_v2_succ_cnt); \ + else if (is_v2 && !is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_R].srv_v2_succ_cnt); \ + else if (!is_v2 && is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_D].srv_v1_succ_cnt); \ + else if (!is_v2 && !is_smcd) \ + this_cpu_inc(smc_stats->smc[SMC_TYPE_R].srv_v1_succ_cnt); \ +} \ +while (0) + +int smc_nl_get_stats(struct sk_buff *skb, struct netlink_callback *cb); +int smc_nl_get_fback_stats(struct sk_buff *skb, struct netlink_callback *cb); +int smc_stats_init(struct net *net); +void smc_stats_exit(struct net *net); + +#endif /* NET_SMC_SMC_STATS_H_ */ diff --git a/net/smc/smc_sysctl.c b/net/smc/smc_sysctl.c new file mode 100644 index 000000000..0b2a957ca --- /dev/null +++ b/net/smc/smc_sysctl.c @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * smc_sysctl.c: sysctl interface to SMC subsystem. + * + * Copyright (c) 2022, Alibaba Inc. 
+ * + * Author: Tony Lu + * + */ + +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" +#include "smc_llc.h" +#include "smc_sysctl.h" + +static int min_sndbuf = SMC_BUF_MIN_SIZE; +static int min_rcvbuf = SMC_BUF_MIN_SIZE; +static int max_sndbuf = INT_MAX / 2; +static int max_rcvbuf = INT_MAX / 2; +static const int net_smc_wmem_init = (64 * 1024); +static const int net_smc_rmem_init = (64 * 1024); + +static struct ctl_table smc_table[] = { + { + .procname = "autocorking_size", + .data = &init_net.smc.sysctl_autocorking_size, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec, + }, + { + .procname = "smcr_buf_type", + .data = &init_net.smc.sysctl_smcr_buf_type, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_douintvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_TWO, + }, + { + .procname = "smcr_testlink_time", + .data = &init_net.smc.sysctl_smcr_testlink_time, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { + .procname = "wmem", + .data = &init_net.smc.sysctl_wmem, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_sndbuf, + .extra2 = &max_sndbuf, + }, + { + .procname = "rmem", + .data = &init_net.smc.sysctl_rmem, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_rcvbuf, + .extra2 = &max_rcvbuf, + }, + { } +}; + +int __net_init smc_sysctl_net_init(struct net *net) +{ + struct ctl_table *table; + + table = smc_table; + if (!net_eq(net, &init_net)) { + int i; + + table = kmemdup(table, sizeof(smc_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + for (i = 0; i < ARRAY_SIZE(smc_table) - 1; i++) + table[i].data += (void *)net - (void *)&init_net; + } + + net->smc.smc_hdr = register_net_sysctl(net, "net/smc", table); + if (!net->smc.smc_hdr) + goto err_reg; + + net->smc.sysctl_autocorking_size = SMC_AUTOCORKING_DEFAULT_SIZE; + net->smc.sysctl_smcr_buf_type = SMCR_PHYS_CONT_BUFS; + net->smc.sysctl_smcr_testlink_time = SMC_LLC_TESTLINK_DEFAULT_TIME; + WRITE_ONCE(net->smc.sysctl_wmem, net_smc_wmem_init); + WRITE_ONCE(net->smc.sysctl_rmem, net_smc_rmem_init); + + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +void __net_exit smc_sysctl_net_exit(struct net *net) +{ + struct ctl_table *table; + + table = net->smc.smc_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->smc.smc_hdr); + if (!net_eq(net, &init_net)) + kfree(table); +} diff --git a/net/smc/smc_sysctl.h b/net/smc/smc_sysctl.h new file mode 100644 index 000000000..0becc11bd --- /dev/null +++ b/net/smc/smc_sysctl.h @@ -0,0 +1,33 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * smc_sysctl.c: sysctl interface to SMC subsystem. + * + * Copyright (c) 2022, Alibaba Inc. 
+ * + * Author: Tony Lu + * + */ + +#ifndef _SMC_SYSCTL_H +#define _SMC_SYSCTL_H + +#ifdef CONFIG_SYSCTL + +int __net_init smc_sysctl_net_init(struct net *net); +void __net_exit smc_sysctl_net_exit(struct net *net); + +#else + +static inline int smc_sysctl_net_init(struct net *net) +{ + net->smc.sysctl_autocorking_size = SMC_AUTOCORKING_DEFAULT_SIZE; + return 0; +} + +static inline void smc_sysctl_net_exit(struct net *net) { } + +#endif /* CONFIG_SYSCTL */ + +#endif /* _SMC_SYSCTL_H */ diff --git a/net/smc/smc_tracepoint.c b/net/smc/smc_tracepoint.c new file mode 100644 index 000000000..8d47ced5a --- /dev/null +++ b/net/smc/smc_tracepoint.c @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: GPL-2.0-only + +#define CREATE_TRACE_POINTS +#include "smc_tracepoint.h" + +EXPORT_TRACEPOINT_SYMBOL(smc_switch_to_fallback); +EXPORT_TRACEPOINT_SYMBOL(smc_tx_sendmsg); +EXPORT_TRACEPOINT_SYMBOL(smc_rx_recvmsg); +EXPORT_TRACEPOINT_SYMBOL(smcr_link_down); diff --git a/net/smc/smc_tracepoint.h b/net/smc/smc_tracepoint.h new file mode 100644 index 000000000..9fc5e586d --- /dev/null +++ b/net/smc/smc_tracepoint.h @@ -0,0 +1,125 @@ +/* SPDX-License-Identifier: GPL-2.0 */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM smc + +#if !defined(_TRACE_SMC_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TRACE_SMC_H + +#include +#include +#include +#include +#include "smc.h" +#include "smc_core.h" + +TRACE_EVENT(smc_switch_to_fallback, + + TP_PROTO(const struct smc_sock *smc, int fallback_rsn), + + TP_ARGS(smc, fallback_rsn), + + TP_STRUCT__entry( + __field(const void *, sk) + __field(const void *, clcsk) + __field(u64, net_cookie) + __field(int, fallback_rsn) + ), + + TP_fast_assign( + const struct sock *sk = &smc->sk; + const struct sock *clcsk = smc->clcsock->sk; + + __entry->sk = sk; + __entry->clcsk = clcsk; + __entry->net_cookie = sock_net(sk)->net_cookie; + __entry->fallback_rsn = fallback_rsn; + ), + + TP_printk("sk=%p clcsk=%p net=%llu fallback_rsn=%d", + __entry->sk, __entry->clcsk, + __entry->net_cookie, __entry->fallback_rsn) +); + +DECLARE_EVENT_CLASS(smc_msg_event, + + TP_PROTO(const struct smc_sock *smc, size_t len), + + TP_ARGS(smc, len), + + TP_STRUCT__entry( + __field(const void *, smc) + __field(u64, net_cookie) + __field(size_t, len) + __string(name, smc->conn.lnk->ibname) + ), + + TP_fast_assign( + const struct sock *sk = &smc->sk; + + __entry->smc = smc; + __entry->net_cookie = sock_net(sk)->net_cookie; + __entry->len = len; + __assign_str(name, smc->conn.lnk->ibname); + ), + + TP_printk("smc=%p net=%llu len=%zu dev=%s", + __entry->smc, __entry->net_cookie, + __entry->len, __get_str(name)) +); + +DEFINE_EVENT(smc_msg_event, smc_tx_sendmsg, + + TP_PROTO(const struct smc_sock *smc, size_t len), + + TP_ARGS(smc, len) +); + +DEFINE_EVENT(smc_msg_event, smc_rx_recvmsg, + + TP_PROTO(const struct smc_sock *smc, size_t len), + + TP_ARGS(smc, len) +); + +TRACE_EVENT(smcr_link_down, + + TP_PROTO(const struct smc_link *lnk, void *location), + + TP_ARGS(lnk, location), + + TP_STRUCT__entry( + __field(const void *, lnk) + __field(const void *, lgr) + __field(u64, net_cookie) + __field(int, state) + __string(name, lnk->ibname) + __field(void *, location) + ), + + TP_fast_assign( + const struct smc_link_group *lgr = lnk->lgr; + + __entry->lnk = lnk; + __entry->lgr = lgr; + __entry->net_cookie = lgr->net->net_cookie; + __entry->state = lnk->state; + __assign_str(name, lnk->ibname); + __entry->location = location; + ), + + TP_printk("lnk=%p lgr=%p net=%llu state=%d dev=%s location=%pS", + __entry->lnk, __entry->lgr, 
__entry->net_cookie, + __entry->state, __get_str(name), + __entry->location) +); + +#endif /* _TRACE_SMC_H */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . + +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE smc_tracepoint + +#include diff --git a/net/smc/smc_tx.c b/net/smc/smc_tx.c new file mode 100644 index 000000000..45128443f --- /dev/null +++ b/net/smc/smc_tx.c @@ -0,0 +1,779 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Manage send buffer. + * Producer: + * Copy user space data into send buffer, if send buffer space available. + * Consumer: + * Trigger RDMA write into RMBE of peer and send CDC, if RMBE space available. + * + * Copyright IBM Corp. 2016 + * + * Author(s): Ursula Braun + */ + +#include +#include +#include +#include + +#include +#include + +#include "smc.h" +#include "smc_wr.h" +#include "smc_cdc.h" +#include "smc_close.h" +#include "smc_ism.h" +#include "smc_tx.h" +#include "smc_stats.h" +#include "smc_tracepoint.h" + +#define SMC_TX_WORK_DELAY 0 + +/***************************** sndbuf producer *******************************/ + +/* callback implementation for sk.sk_write_space() + * to wakeup sndbuf producers that blocked with smc_tx_wait(). + * called under sk_socket lock. + */ +static void smc_tx_write_space(struct sock *sk) +{ + struct socket *sock = sk->sk_socket; + struct smc_sock *smc = smc_sk(sk); + struct socket_wq *wq; + + /* similar to sk_stream_write_space */ + if (atomic_read(&smc->conn.sndbuf_space) && sock) { + if (test_bit(SOCK_NOSPACE, &sock->flags)) + SMC_STAT_RMB_TX_FULL(smc, !smc->conn.lnk); + clear_bit(SOCK_NOSPACE, &sock->flags); + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_poll(&wq->wait, + EPOLLOUT | EPOLLWRNORM | + EPOLLWRBAND); + if (wq && wq->fasync_list && !(sk->sk_shutdown & SEND_SHUTDOWN)) + sock_wake_async(wq, SOCK_WAKE_SPACE, POLL_OUT); + rcu_read_unlock(); + } +} + +/* Wakeup sndbuf producers that blocked with smc_tx_wait(). + * Cf. tcp_data_snd_check()=>tcp_check_space()=>tcp_new_space(). 
+ */ +void smc_tx_sndbuf_nonfull(struct smc_sock *smc) +{ + if (smc->sk.sk_socket && + test_bit(SOCK_NOSPACE, &smc->sk.sk_socket->flags)) + smc->sk.sk_write_space(&smc->sk); +} + +/* blocks sndbuf producer until at least one byte of free space available + * or urgent Byte was consumed + */ +static int smc_tx_wait(struct smc_sock *smc, int flags) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct smc_connection *conn = &smc->conn; + struct sock *sk = &smc->sk; + long timeo; + int rc = 0; + + /* similar to sk_stream_wait_memory */ + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + if (sk->sk_err || + (sk->sk_shutdown & SEND_SHUTDOWN) || + conn->killed || + conn->local_tx_ctrl.conn_state_flags.peer_done_writing) { + rc = -EPIPE; + break; + } + if (smc_cdc_rxed_any_close(conn)) { + rc = -ECONNRESET; + break; + } + if (!timeo) { + /* ensure EPOLLOUT is subsequently generated */ + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + rc = -EAGAIN; + break; + } + if (signal_pending(current)) { + rc = sock_intr_errno(timeo); + break; + } + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + if (atomic_read(&conn->sndbuf_space) && !conn->urg_tx_pend) + break; /* at least 1 byte of free & no urgent data */ + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + sk_wait_event(sk, &timeo, + READ_ONCE(sk->sk_err) || + (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || + smc_cdc_rxed_any_close(conn) || + (atomic_read(&conn->sndbuf_space) && + !conn->urg_tx_pend), + &wait); + } + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +static bool smc_tx_is_corked(struct smc_sock *smc) +{ + struct tcp_sock *tp = tcp_sk(smc->clcsock->sk); + + return (tp->nonagle & TCP_NAGLE_CORK) ? true : false; +} + +/* If we have pending CDC messages, do not send: + * Because CQE of this CDC message will happen shortly, it gives + * a chance to coalesce future sendmsg() payload in to one RDMA Write, + * without need for a timer, and with no latency trade off. + * Algorithm here: + * 1. First message should never cork + * 2. If we have pending Tx CDC messages, wait for the first CDC + * message's completion + * 3. Don't cork to much data in a single RDMA Write to prevent burst + * traffic, total corked message should not exceed sendbuf/2 + */ +static bool smc_should_autocork(struct smc_sock *smc) +{ + struct smc_connection *conn = &smc->conn; + int corking_size; + + corking_size = min_t(unsigned int, conn->sndbuf_desc->len >> 1, + sock_net(&smc->sk)->smc.sysctl_autocorking_size); + + if (atomic_read(&conn->cdc_pend_tx_wr) == 0 || + smc_tx_prepared_sends(conn) > corking_size) + return false; + return true; +} + +static bool smc_tx_should_cork(struct smc_sock *smc, struct msghdr *msg) +{ + struct smc_connection *conn = &smc->conn; + + if (smc_should_autocork(smc)) + return true; + + /* for a corked socket defer the RDMA writes if + * sndbuf_space is still available. The applications + * should known how/when to uncork it. + */ + if ((msg->msg_flags & MSG_MORE || + smc_tx_is_corked(smc) || + msg->msg_flags & MSG_SENDPAGE_NOTLAST) && + atomic_read(&conn->sndbuf_space)) + return true; + + return false; +} + +/* sndbuf producer: main API called by socket layer. + * called under sock lock. 
+ */ +int smc_tx_sendmsg(struct smc_sock *smc, struct msghdr *msg, size_t len) +{ + size_t copylen, send_done = 0, send_remaining = len; + size_t chunk_len, chunk_off, chunk_len_sum; + struct smc_connection *conn = &smc->conn; + union smc_host_cursor prep; + struct sock *sk = &smc->sk; + char *sndbuf_base; + int tx_cnt_prep; + int writespace; + int rc, chunk; + + /* This should be in poll */ + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + if (sk->sk_err || (sk->sk_shutdown & SEND_SHUTDOWN)) { + rc = -EPIPE; + goto out_err; + } + + if (sk->sk_state == SMC_INIT) + return -ENOTCONN; + + if (len > conn->sndbuf_desc->len) + SMC_STAT_RMB_TX_SIZE_SMALL(smc, !conn->lnk); + + if (len > conn->peer_rmbe_size) + SMC_STAT_RMB_TX_PEER_SIZE_SMALL(smc, !conn->lnk); + + if (msg->msg_flags & MSG_OOB) + SMC_STAT_INC(smc, urg_data_cnt); + + while (msg_data_left(msg)) { + if (smc->sk.sk_shutdown & SEND_SHUTDOWN || + (smc->sk.sk_err == ECONNABORTED) || + conn->killed) + return -EPIPE; + if (smc_cdc_rxed_any_close(conn)) + return send_done ?: -ECONNRESET; + + if (msg->msg_flags & MSG_OOB) + conn->local_tx_ctrl.prod_flags.urg_data_pending = 1; + + if (!atomic_read(&conn->sndbuf_space) || conn->urg_tx_pend) { + if (send_done) + return send_done; + rc = smc_tx_wait(smc, msg->msg_flags); + if (rc) + goto out_err; + continue; + } + + /* initialize variables for 1st iteration of subsequent loop */ + /* could be just 1 byte, even after smc_tx_wait above */ + writespace = atomic_read(&conn->sndbuf_space); + /* not more than what user space asked for */ + copylen = min_t(size_t, send_remaining, writespace); + /* determine start of sndbuf */ + sndbuf_base = conn->sndbuf_desc->cpu_addr; + smc_curs_copy(&prep, &conn->tx_curs_prep, conn); + tx_cnt_prep = prep.count; + /* determine chunks where to write into sndbuf */ + /* either unwrapped case, or 1st chunk of wrapped case */ + chunk_len = min_t(size_t, copylen, conn->sndbuf_desc->len - + tx_cnt_prep); + chunk_len_sum = chunk_len; + chunk_off = tx_cnt_prep; + for (chunk = 0; chunk < 2; chunk++) { + rc = memcpy_from_msg(sndbuf_base + chunk_off, + msg, chunk_len); + if (rc) { + smc_sndbuf_sync_sg_for_device(conn); + if (send_done) + return send_done; + goto out_err; + } + send_done += chunk_len; + send_remaining -= chunk_len; + + if (chunk_len_sum == copylen) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + chunk_len = copylen - chunk_len; /* remainder */ + chunk_len_sum += chunk_len; + chunk_off = 0; /* modulo offset in send ring buffer */ + } + smc_sndbuf_sync_sg_for_device(conn); + /* update cursors */ + smc_curs_add(conn->sndbuf_desc->len, &prep, copylen); + smc_curs_copy(&conn->tx_curs_prep, &prep, conn); + /* increased in send tasklet smc_cdc_tx_handler() */ + smp_mb__before_atomic(); + atomic_sub(copylen, &conn->sndbuf_space); + /* guarantee 0 <= sndbuf_space <= sndbuf_desc->len */ + smp_mb__after_atomic(); + /* since we just produced more new data into sndbuf, + * trigger sndbuf consumer: RDMA write into peer RMBE and CDC + */ + if ((msg->msg_flags & MSG_OOB) && !send_remaining) + conn->urg_tx_pend = true; + /* If we need to cork, do nothing and wait for the next + * sendmsg() call or push on tx completion + */ + if (!smc_tx_should_cork(smc, msg)) + smc_tx_sndbuf_nonempty(conn); + + trace_smc_tx_sendmsg(smc, copylen); + } /* while (msg_data_left(msg)) */ + + return send_done; + +out_err: + rc = sk_stream_error(sk, msg->msg_flags, rc); + /* make sure we wake any epoll edge trigger waiter */ + if (unlikely(rc == -EAGAIN)) + 
sk->sk_write_space(sk); + return rc; +} + +int smc_tx_sendpage(struct smc_sock *smc, struct page *page, int offset, + size_t size, int flags) +{ + struct msghdr msg = {.msg_flags = flags}; + char *kaddr = kmap(page); + struct kvec iov; + int rc; + + iov.iov_base = kaddr + offset; + iov.iov_len = size; + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, size); + rc = smc_tx_sendmsg(smc, &msg, size); + kunmap(page); + return rc; +} + +/***************************** sndbuf consumer *******************************/ + +/* sndbuf consumer: actual data transfer of one target chunk with ISM write */ +int smcd_tx_ism_write(struct smc_connection *conn, void *data, size_t len, + u32 offset, int signal) +{ + int rc; + + rc = smc_ism_write(conn->lgr->smcd, conn->peer_token, + conn->peer_rmbe_idx, signal, conn->tx_off + offset, + data, len); + if (rc) + conn->local_tx_ctrl.conn_state_flags.peer_conn_abort = 1; + return rc; +} + +/* sndbuf consumer: actual data transfer of one target chunk with RDMA write */ +static int smc_tx_rdma_write(struct smc_connection *conn, int peer_rmbe_offset, + int num_sges, struct ib_rdma_wr *rdma_wr) +{ + struct smc_link_group *lgr = conn->lgr; + struct smc_link *link = conn->lnk; + int rc; + + rdma_wr->wr.wr_id = smc_wr_tx_get_next_wr_id(link); + rdma_wr->wr.num_sge = num_sges; + rdma_wr->remote_addr = + lgr->rtokens[conn->rtoken_idx][link->link_idx].dma_addr + + /* RMBE within RMB */ + conn->tx_off + + /* offset within RMBE */ + peer_rmbe_offset; + rdma_wr->rkey = lgr->rtokens[conn->rtoken_idx][link->link_idx].rkey; + rc = ib_post_send(link->roce_qp, &rdma_wr->wr, NULL); + if (rc) + smcr_link_down_cond_sched(link); + return rc; +} + +/* sndbuf consumer */ +static inline void smc_tx_advance_cursors(struct smc_connection *conn, + union smc_host_cursor *prod, + union smc_host_cursor *sent, + size_t len) +{ + smc_curs_add(conn->peer_rmbe_size, prod, len); + /* increased in recv tasklet smc_cdc_msg_rcv() */ + smp_mb__before_atomic(); + /* data in flight reduces usable snd_wnd */ + atomic_sub(len, &conn->peer_rmbe_space); + /* guarantee 0 <= peer_rmbe_space <= peer_rmbe_size */ + smp_mb__after_atomic(); + smc_curs_add(conn->sndbuf_desc->len, sent, len); +} + +/* SMC-R helper for smc_tx_rdma_writes() */ +static int smcr_tx_rdma_writes(struct smc_connection *conn, size_t len, + size_t src_off, size_t src_len, + size_t dst_off, size_t dst_len, + struct smc_rdma_wr *wr_rdma_buf) +{ + struct smc_link *link = conn->lnk; + + dma_addr_t dma_addr = + sg_dma_address(conn->sndbuf_desc->sgt[link->link_idx].sgl); + u64 virt_addr = (uintptr_t)conn->sndbuf_desc->cpu_addr; + int src_len_sum = src_len, dst_len_sum = dst_len; + int sent_count = src_off; + int srcchunk, dstchunk; + int num_sges; + int rc; + + for (dstchunk = 0; dstchunk < 2; dstchunk++) { + struct ib_rdma_wr *wr = &wr_rdma_buf->wr_tx_rdma[dstchunk]; + struct ib_sge *sge = wr->wr.sg_list; + u64 base_addr = dma_addr; + + if (dst_len < link->qp_attr.cap.max_inline_data) { + base_addr = virt_addr; + wr->wr.send_flags |= IB_SEND_INLINE; + } else { + wr->wr.send_flags &= ~IB_SEND_INLINE; + } + + num_sges = 0; + for (srcchunk = 0; srcchunk < 2; srcchunk++) { + sge[srcchunk].addr = conn->sndbuf_desc->is_vm ? 
+ (virt_addr + src_off) : (base_addr + src_off); + sge[srcchunk].length = src_len; + if (conn->sndbuf_desc->is_vm) + sge[srcchunk].lkey = + conn->sndbuf_desc->mr[link->link_idx]->lkey; + num_sges++; + + src_off += src_len; + if (src_off >= conn->sndbuf_desc->len) + src_off -= conn->sndbuf_desc->len; + /* modulo in send ring */ + if (src_len_sum == dst_len) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + src_len = dst_len - src_len; /* remainder */ + src_len_sum += src_len; + } + rc = smc_tx_rdma_write(conn, dst_off, num_sges, wr); + if (rc) + return rc; + if (dst_len_sum == len) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + dst_off = 0; /* modulo offset in RMBE ring buffer */ + dst_len = len - dst_len; /* remainder */ + dst_len_sum += dst_len; + src_len = min_t(int, dst_len, conn->sndbuf_desc->len - + sent_count); + src_len_sum = src_len; + } + return 0; +} + +/* SMC-D helper for smc_tx_rdma_writes() */ +static int smcd_tx_rdma_writes(struct smc_connection *conn, size_t len, + size_t src_off, size_t src_len, + size_t dst_off, size_t dst_len) +{ + int src_len_sum = src_len, dst_len_sum = dst_len; + int srcchunk, dstchunk; + int rc; + + for (dstchunk = 0; dstchunk < 2; dstchunk++) { + for (srcchunk = 0; srcchunk < 2; srcchunk++) { + void *data = conn->sndbuf_desc->cpu_addr + src_off; + + rc = smcd_tx_ism_write(conn, data, src_len, dst_off + + sizeof(struct smcd_cdc_msg), 0); + if (rc) + return rc; + dst_off += src_len; + src_off += src_len; + if (src_off >= conn->sndbuf_desc->len) + src_off -= conn->sndbuf_desc->len; + /* modulo in send ring */ + if (src_len_sum == dst_len) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + src_len = dst_len - src_len; /* remainder */ + src_len_sum += src_len; + } + if (dst_len_sum == len) + break; /* either on 1st or 2nd iteration */ + /* prepare next (== 2nd) iteration */ + dst_off = 0; /* modulo offset in RMBE ring buffer */ + dst_len = len - dst_len; /* remainder */ + dst_len_sum += dst_len; + src_len = min_t(int, dst_len, conn->sndbuf_desc->len - src_off); + src_len_sum = src_len; + } + return 0; +} + +/* sndbuf consumer: prepare all necessary (src&dst) chunks of data transmit; + * usable snd_wnd as max transmit + */ +static int smc_tx_rdma_writes(struct smc_connection *conn, + struct smc_rdma_wr *wr_rdma_buf) +{ + size_t len, src_len, dst_off, dst_len; /* current chunk values */ + union smc_host_cursor sent, prep, prod, cons; + struct smc_cdc_producer_flags *pflags; + int to_send, rmbespace; + int rc; + + /* source: sndbuf */ + smc_curs_copy(&sent, &conn->tx_curs_sent, conn); + smc_curs_copy(&prep, &conn->tx_curs_prep, conn); + /* cf. wmem_alloc - (snd_max - snd_una) */ + to_send = smc_curs_diff(conn->sndbuf_desc->len, &sent, &prep); + if (to_send <= 0) + return 0; + + /* destination: RMBE */ + /* cf. snd_wnd */ + rmbespace = atomic_read(&conn->peer_rmbe_space); + if (rmbespace <= 0) { + struct smc_sock *smc = container_of(conn, struct smc_sock, + conn); + SMC_STAT_RMB_TX_PEER_FULL(smc, !conn->lnk); + return 0; + } + smc_curs_copy(&prod, &conn->local_tx_ctrl.prod, conn); + smc_curs_copy(&cons, &conn->local_rx_ctrl.cons, conn); + + /* if usable snd_wnd closes ask peer to advertise once it opens again */ + pflags = &conn->local_tx_ctrl.prod_flags; + pflags->write_blocked = (to_send >= rmbespace); + /* cf. 
usable snd_wnd */ + len = min(to_send, rmbespace); + + /* initialize variables for first iteration of subsequent nested loop */ + dst_off = prod.count; + if (prod.wrap == cons.wrap) { + /* the filled destination area is unwrapped, + * hence the available free destination space is wrapped + * and we need 2 destination chunks of sum len; start with 1st + * which is limited by what's available in sndbuf + */ + dst_len = min_t(size_t, + conn->peer_rmbe_size - prod.count, len); + } else { + /* the filled destination area is wrapped, + * hence the available free destination space is unwrapped + * and we need a single destination chunk of entire len + */ + dst_len = len; + } + /* dst_len determines the maximum src_len */ + if (sent.count + dst_len <= conn->sndbuf_desc->len) { + /* unwrapped src case: single chunk of entire dst_len */ + src_len = dst_len; + } else { + /* wrapped src case: 2 chunks of sum dst_len; start with 1st: */ + src_len = conn->sndbuf_desc->len - sent.count; + } + + if (conn->lgr->is_smcd) + rc = smcd_tx_rdma_writes(conn, len, sent.count, src_len, + dst_off, dst_len); + else + rc = smcr_tx_rdma_writes(conn, len, sent.count, src_len, + dst_off, dst_len, wr_rdma_buf); + if (rc) + return rc; + + if (conn->urg_tx_pend && len == to_send) + pflags->urg_data_present = 1; + smc_tx_advance_cursors(conn, &prod, &sent, len); + /* update connection's cursors with advanced local cursors */ + smc_curs_copy(&conn->local_tx_ctrl.prod, &prod, conn); + /* dst: peer RMBE */ + smc_curs_copy(&conn->tx_curs_sent, &sent, conn);/* src: local sndbuf */ + + return 0; +} + +/* Wakeup sndbuf consumers from any context (IRQ or process) + * since there is more data to transmit; usable snd_wnd as max transmit + */ +static int smcr_tx_sndbuf_nonempty(struct smc_connection *conn) +{ + struct smc_cdc_producer_flags *pflags = &conn->local_tx_ctrl.prod_flags; + struct smc_link *link = conn->lnk; + struct smc_rdma_wr *wr_rdma_buf; + struct smc_cdc_tx_pend *pend; + struct smc_wr_buf *wr_buf; + int rc; + + if (!link || !smc_wr_tx_link_hold(link)) + return -ENOLINK; + rc = smc_cdc_get_free_slot(conn, link, &wr_buf, &wr_rdma_buf, &pend); + if (rc < 0) { + smc_wr_tx_link_put(link); + if (rc == -EBUSY) { + struct smc_sock *smc = + container_of(conn, struct smc_sock, conn); + + if (smc->sk.sk_err == ECONNABORTED) + return sock_error(&smc->sk); + if (conn->killed) + return -EPIPE; + rc = 0; + mod_delayed_work(conn->lgr->tx_wq, &conn->tx_work, + SMC_TX_WORK_DELAY); + } + return rc; + } + + spin_lock_bh(&conn->send_lock); + if (link != conn->lnk) { + /* link of connection changed, tx_work will restart */ + smc_wr_tx_put_slot(link, + (struct smc_wr_tx_pend_priv *)pend); + rc = -ENOLINK; + goto out_unlock; + } + if (!pflags->urg_data_present) { + rc = smc_tx_rdma_writes(conn, wr_rdma_buf); + if (rc) { + smc_wr_tx_put_slot(link, + (struct smc_wr_tx_pend_priv *)pend); + goto out_unlock; + } + } + + rc = smc_cdc_msg_send(conn, wr_buf, pend); + if (!rc && pflags->urg_data_present) { + pflags->urg_data_pending = 0; + pflags->urg_data_present = 0; + } + +out_unlock: + spin_unlock_bh(&conn->send_lock); + smc_wr_tx_link_put(link); + return rc; +} + +static int smcd_tx_sndbuf_nonempty(struct smc_connection *conn) +{ + struct smc_cdc_producer_flags *pflags = &conn->local_tx_ctrl.prod_flags; + int rc = 0; + + spin_lock_bh(&conn->send_lock); + if (!pflags->urg_data_present) + rc = smc_tx_rdma_writes(conn, NULL); + if (!rc) + rc = smcd_cdc_msg_send(conn); + + if (!rc && pflags->urg_data_present) { + pflags->urg_data_pending = 0; + 
pflags->urg_data_present = 0; + } + spin_unlock_bh(&conn->send_lock); + return rc; +} + +static int __smc_tx_sndbuf_nonempty(struct smc_connection *conn) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + int rc = 0; + + /* No data in the send queue */ + if (unlikely(smc_tx_prepared_sends(conn) <= 0)) + goto out; + + /* Peer don't have RMBE space */ + if (unlikely(atomic_read(&conn->peer_rmbe_space) <= 0)) { + SMC_STAT_RMB_TX_PEER_FULL(smc, !conn->lnk); + goto out; + } + + if (conn->killed || + conn->local_rx_ctrl.conn_state_flags.peer_conn_abort) { + rc = -EPIPE; /* connection being aborted */ + goto out; + } + if (conn->lgr->is_smcd) + rc = smcd_tx_sndbuf_nonempty(conn); + else + rc = smcr_tx_sndbuf_nonempty(conn); + + if (!rc) { + /* trigger socket release if connection is closing */ + smc_close_wake_tx_prepared(smc); + } + +out: + return rc; +} + +int smc_tx_sndbuf_nonempty(struct smc_connection *conn) +{ + int rc; + + /* This make sure only one can send simultaneously to prevent wasting + * of CPU and CDC slot. + * Record whether someone has tried to push while we are pushing. + */ + if (atomic_inc_return(&conn->tx_pushing) > 1) + return 0; + +again: + atomic_set(&conn->tx_pushing, 1); + smp_wmb(); /* Make sure tx_pushing is 1 before real send */ + rc = __smc_tx_sndbuf_nonempty(conn); + + /* We need to check whether someone else have added some data into + * the send queue and tried to push but failed after the atomic_set() + * when we are pushing. + * If so, we need to push again to prevent those data hang in the send + * queue. + */ + if (unlikely(!atomic_dec_and_test(&conn->tx_pushing))) + goto again; + + return rc; +} + +/* Wakeup sndbuf consumers from process context + * since there is more data to transmit. The caller + * must hold sock lock. + */ +void smc_tx_pending(struct smc_connection *conn) +{ + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + int rc; + + if (smc->sk.sk_err) + return; + + rc = smc_tx_sndbuf_nonempty(conn); + if (!rc && conn->local_rx_ctrl.prod_flags.write_blocked && + !atomic_read(&conn->bytes_to_rcv)) + conn->local_rx_ctrl.prod_flags.write_blocked = 0; +} + +/* Wakeup sndbuf consumers from process context + * since there is more data to transmit in locked + * sock. 
+ */ +void smc_tx_work(struct work_struct *work) +{ + struct smc_connection *conn = container_of(to_delayed_work(work), + struct smc_connection, + tx_work); + struct smc_sock *smc = container_of(conn, struct smc_sock, conn); + + lock_sock(&smc->sk); + smc_tx_pending(conn); + release_sock(&smc->sk); +} + +void smc_tx_consumer_update(struct smc_connection *conn, bool force) +{ + union smc_host_cursor cfed, cons, prod; + int sender_free = conn->rmb_desc->len; + int to_confirm; + + smc_curs_copy(&cons, &conn->local_tx_ctrl.cons, conn); + smc_curs_copy(&cfed, &conn->rx_curs_confirmed, conn); + to_confirm = smc_curs_diff(conn->rmb_desc->len, &cfed, &cons); + if (to_confirm > conn->rmbe_update_limit) { + smc_curs_copy(&prod, &conn->local_rx_ctrl.prod, conn); + sender_free = conn->rmb_desc->len - + smc_curs_diff_large(conn->rmb_desc->len, + &cfed, &prod); + } + + if (conn->local_rx_ctrl.prod_flags.cons_curs_upd_req || + force || + ((to_confirm > conn->rmbe_update_limit) && + ((sender_free <= (conn->rmb_desc->len / 2)) || + conn->local_rx_ctrl.prod_flags.write_blocked))) { + if (conn->killed || + conn->local_rx_ctrl.conn_state_flags.peer_conn_abort) + return; + if ((smc_cdc_get_slot_and_msg_send(conn) < 0) && + !conn->killed) { + queue_delayed_work(conn->lgr->tx_wq, &conn->tx_work, + SMC_TX_WORK_DELAY); + return; + } + } + if (conn->local_rx_ctrl.prod_flags.write_blocked && + !atomic_read(&conn->bytes_to_rcv)) + conn->local_rx_ctrl.prod_flags.write_blocked = 0; +} + +/***************************** send initialize *******************************/ + +/* Initialize send properties on connection establishment. NB: not __init! */ +void smc_tx_init(struct smc_sock *smc) +{ + smc->sk.sk_write_space = smc_tx_write_space; +} diff --git a/net/smc/smc_tx.h b/net/smc/smc_tx.h new file mode 100644 index 000000000..34b578498 --- /dev/null +++ b/net/smc/smc_tx.h @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Manage send buffer + * + * Copyright IBM Corp. 
2016 + * + * Author(s): Ursula Braun + */ + +#ifndef SMC_TX_H +#define SMC_TX_H + +#include +#include + +#include "smc.h" +#include "smc_cdc.h" + +static inline int smc_tx_prepared_sends(struct smc_connection *conn) +{ + union smc_host_cursor sent, prep; + + smc_curs_copy(&sent, &conn->tx_curs_sent, conn); + smc_curs_copy(&prep, &conn->tx_curs_prep, conn); + return smc_curs_diff(conn->sndbuf_desc->len, &sent, &prep); +} + +void smc_tx_pending(struct smc_connection *conn); +void smc_tx_work(struct work_struct *work); +void smc_tx_init(struct smc_sock *smc); +int smc_tx_sendmsg(struct smc_sock *smc, struct msghdr *msg, size_t len); +int smc_tx_sendpage(struct smc_sock *smc, struct page *page, int offset, + size_t size, int flags); +int smc_tx_sndbuf_nonempty(struct smc_connection *conn); +void smc_tx_sndbuf_nonfull(struct smc_sock *smc); +void smc_tx_consumer_update(struct smc_connection *conn, bool force); +int smcd_tx_ism_write(struct smc_connection *conn, void *data, size_t len, + u32 offset, int signal); + +#endif /* SMC_TX_H */ diff --git a/net/smc/smc_wr.c b/net/smc/smc_wr.c new file mode 100644 index 000000000..b0678a417 --- /dev/null +++ b/net/smc/smc_wr.c @@ -0,0 +1,918 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Work Requests exploiting Infiniband API + * + * Work requests (WR) of type ib_post_send or ib_post_recv respectively + * are submitted to either RC SQ or RC RQ respectively + * (reliably connected send/receive queue) + * and become work queue entries (WQEs). + * While an SQ WR/WQE is pending, we track it until transmission completion. + * Through a send or receive completion queue (CQ) respectively, + * we get completion queue entries (CQEs) [aka work completions (WCs)]. + * Since the CQ callback is called from IRQ context, we split work by using + * bottom halves implemented by tasklets. + * + * SMC uses this to exchange LLC (link layer control) + * and CDC (connection data control) messages. + * + * Copyright IBM Corp. 2016 + * + * Author(s): Steffen Maier + */ + +#include +#include +#include +#include +#include + +#include "smc.h" +#include "smc_wr.h" + +#define SMC_WR_MAX_POLL_CQE 10 /* max. # of compl. 
queue elements in 1 poll */ + +#define SMC_WR_RX_HASH_BITS 4 +static DEFINE_HASHTABLE(smc_wr_rx_hash, SMC_WR_RX_HASH_BITS); +static DEFINE_SPINLOCK(smc_wr_rx_hash_lock); + +struct smc_wr_tx_pend { /* control data for a pending send request */ + u64 wr_id; /* work request id sent */ + smc_wr_tx_handler handler; + enum ib_wc_status wc_status; /* CQE status */ + struct smc_link *link; + u32 idx; + struct smc_wr_tx_pend_priv priv; + u8 compl_requested; +}; + +/******************************** send queue *********************************/ + +/*------------------------------- completion --------------------------------*/ + +/* returns true if at least one tx work request is pending on the given link */ +static inline bool smc_wr_is_tx_pend(struct smc_link *link) +{ + return !bitmap_empty(link->wr_tx_mask, link->wr_tx_cnt); +} + +/* wait till all pending tx work requests on the given link are completed */ +void smc_wr_tx_wait_no_pending_sends(struct smc_link *link) +{ + wait_event(link->wr_tx_wait, !smc_wr_is_tx_pend(link)); +} + +static inline int smc_wr_tx_find_pending_index(struct smc_link *link, u64 wr_id) +{ + u32 i; + + for (i = 0; i < link->wr_tx_cnt; i++) { + if (link->wr_tx_pends[i].wr_id == wr_id) + return i; + } + return link->wr_tx_cnt; +} + +static inline void smc_wr_tx_process_cqe(struct ib_wc *wc) +{ + struct smc_wr_tx_pend pnd_snd; + struct smc_link *link; + u32 pnd_snd_idx; + + link = wc->qp->qp_context; + + if (wc->opcode == IB_WC_REG_MR) { + if (wc->status) + link->wr_reg_state = FAILED; + else + link->wr_reg_state = CONFIRMED; + smc_wr_wakeup_reg_wait(link); + return; + } + + pnd_snd_idx = smc_wr_tx_find_pending_index(link, wc->wr_id); + if (pnd_snd_idx == link->wr_tx_cnt) { + if (link->lgr->smc_version != SMC_V2 || + link->wr_tx_v2_pend->wr_id != wc->wr_id) + return; + link->wr_tx_v2_pend->wc_status = wc->status; + memcpy(&pnd_snd, link->wr_tx_v2_pend, sizeof(pnd_snd)); + /* clear the full struct smc_wr_tx_pend including .priv */ + memset(link->wr_tx_v2_pend, 0, + sizeof(*link->wr_tx_v2_pend)); + memset(link->lgr->wr_tx_buf_v2, 0, + sizeof(*link->lgr->wr_tx_buf_v2)); + } else { + link->wr_tx_pends[pnd_snd_idx].wc_status = wc->status; + if (link->wr_tx_pends[pnd_snd_idx].compl_requested) + complete(&link->wr_tx_compl[pnd_snd_idx]); + memcpy(&pnd_snd, &link->wr_tx_pends[pnd_snd_idx], + sizeof(pnd_snd)); + /* clear the full struct smc_wr_tx_pend including .priv */ + memset(&link->wr_tx_pends[pnd_snd_idx], 0, + sizeof(link->wr_tx_pends[pnd_snd_idx])); + memset(&link->wr_tx_bufs[pnd_snd_idx], 0, + sizeof(link->wr_tx_bufs[pnd_snd_idx])); + if (!test_and_clear_bit(pnd_snd_idx, link->wr_tx_mask)) + return; + } + + if (wc->status) { + if (link->lgr->smc_version == SMC_V2) { + memset(link->wr_tx_v2_pend, 0, + sizeof(*link->wr_tx_v2_pend)); + memset(link->lgr->wr_tx_buf_v2, 0, + sizeof(*link->lgr->wr_tx_buf_v2)); + } + /* terminate link */ + smcr_link_down_cond_sched(link); + } + if (pnd_snd.handler) + pnd_snd.handler(&pnd_snd.priv, link, wc->status); + wake_up(&link->wr_tx_wait); +} + +static void smc_wr_tx_tasklet_fn(struct tasklet_struct *t) +{ + struct smc_ib_device *dev = from_tasklet(dev, t, send_tasklet); + struct ib_wc wc[SMC_WR_MAX_POLL_CQE]; + int i = 0, rc; + int polled = 0; + +again: + polled++; + do { + memset(&wc, 0, sizeof(wc)); + rc = ib_poll_cq(dev->roce_cq_send, SMC_WR_MAX_POLL_CQE, wc); + if (polled == 1) { + ib_req_notify_cq(dev->roce_cq_send, + IB_CQ_NEXT_COMP | + IB_CQ_REPORT_MISSED_EVENTS); + } + if (!rc) + break; + for (i = 0; i < rc; i++) + 
smc_wr_tx_process_cqe(&wc[i]); + } while (rc > 0); + if (polled == 1) + goto again; +} + +void smc_wr_tx_cq_handler(struct ib_cq *ib_cq, void *cq_context) +{ + struct smc_ib_device *dev = (struct smc_ib_device *)cq_context; + + tasklet_schedule(&dev->send_tasklet); +} + +/*---------------------------- request submission ---------------------------*/ + +static inline int smc_wr_tx_get_free_slot_index(struct smc_link *link, u32 *idx) +{ + *idx = link->wr_tx_cnt; + if (!smc_link_sendable(link)) + return -ENOLINK; + for_each_clear_bit(*idx, link->wr_tx_mask, link->wr_tx_cnt) { + if (!test_and_set_bit(*idx, link->wr_tx_mask)) + return 0; + } + *idx = link->wr_tx_cnt; + return -EBUSY; +} + +/** + * smc_wr_tx_get_free_slot() - returns buffer for message assembly, + * and sets info for pending transmit tracking + * @link: Pointer to smc_link used to later send the message. + * @handler: Send completion handler function pointer. + * @wr_buf: Out value returns pointer to message buffer. + * @wr_rdma_buf: Out value returns pointer to rdma work request. + * @wr_pend_priv: Out value returns pointer serving as handler context. + * + * Return: 0 on success, or -errno on error. + */ +int smc_wr_tx_get_free_slot(struct smc_link *link, + smc_wr_tx_handler handler, + struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wr_rdma_buf, + struct smc_wr_tx_pend_priv **wr_pend_priv) +{ + struct smc_link_group *lgr = smc_get_lgr(link); + struct smc_wr_tx_pend *wr_pend; + u32 idx = link->wr_tx_cnt; + struct ib_send_wr *wr_ib; + u64 wr_id; + int rc; + + *wr_buf = NULL; + *wr_pend_priv = NULL; + if (in_softirq() || lgr->terminating) { + rc = smc_wr_tx_get_free_slot_index(link, &idx); + if (rc) + return rc; + } else { + rc = wait_event_interruptible_timeout( + link->wr_tx_wait, + !smc_link_sendable(link) || + lgr->terminating || + (smc_wr_tx_get_free_slot_index(link, &idx) != -EBUSY), + SMC_WR_TX_WAIT_FREE_SLOT_TIME); + if (!rc) { + /* timeout - terminate link */ + smcr_link_down_cond_sched(link); + return -EPIPE; + } + if (idx == link->wr_tx_cnt) + return -EPIPE; + } + wr_id = smc_wr_tx_get_next_wr_id(link); + wr_pend = &link->wr_tx_pends[idx]; + wr_pend->wr_id = wr_id; + wr_pend->handler = handler; + wr_pend->link = link; + wr_pend->idx = idx; + wr_ib = &link->wr_tx_ibs[idx]; + wr_ib->wr_id = wr_id; + *wr_buf = &link->wr_tx_bufs[idx]; + if (wr_rdma_buf) + *wr_rdma_buf = &link->wr_tx_rdmas[idx]; + *wr_pend_priv = &wr_pend->priv; + return 0; +} + +int smc_wr_tx_get_v2_slot(struct smc_link *link, + smc_wr_tx_handler handler, + struct smc_wr_v2_buf **wr_buf, + struct smc_wr_tx_pend_priv **wr_pend_priv) +{ + struct smc_wr_tx_pend *wr_pend; + struct ib_send_wr *wr_ib; + u64 wr_id; + + if (link->wr_tx_v2_pend->idx == link->wr_tx_cnt) + return -EBUSY; + + *wr_buf = NULL; + *wr_pend_priv = NULL; + wr_id = smc_wr_tx_get_next_wr_id(link); + wr_pend = link->wr_tx_v2_pend; + wr_pend->wr_id = wr_id; + wr_pend->handler = handler; + wr_pend->link = link; + wr_pend->idx = link->wr_tx_cnt; + wr_ib = link->wr_tx_v2_ib; + wr_ib->wr_id = wr_id; + *wr_buf = link->lgr->wr_tx_buf_v2; + *wr_pend_priv = &wr_pend->priv; + return 0; +} + +int smc_wr_tx_put_slot(struct smc_link *link, + struct smc_wr_tx_pend_priv *wr_pend_priv) +{ + struct smc_wr_tx_pend *pend; + + pend = container_of(wr_pend_priv, struct smc_wr_tx_pend, priv); + if (pend->idx < link->wr_tx_cnt) { + u32 idx = pend->idx; + + /* clear the full struct smc_wr_tx_pend including .priv */ + memset(&link->wr_tx_pends[idx], 0, + sizeof(link->wr_tx_pends[idx])); + 
memset(&link->wr_tx_bufs[idx], 0, + sizeof(link->wr_tx_bufs[idx])); + test_and_clear_bit(idx, link->wr_tx_mask); + wake_up(&link->wr_tx_wait); + return 1; + } else if (link->lgr->smc_version == SMC_V2 && + pend->idx == link->wr_tx_cnt) { + /* Large v2 buffer */ + memset(&link->wr_tx_v2_pend, 0, + sizeof(link->wr_tx_v2_pend)); + memset(&link->lgr->wr_tx_buf_v2, 0, + sizeof(link->lgr->wr_tx_buf_v2)); + return 1; + } + + return 0; +} + +/* Send prepared WR slot via ib_post_send. + * @priv: pointer to smc_wr_tx_pend_priv identifying prepared message buffer + */ +int smc_wr_tx_send(struct smc_link *link, struct smc_wr_tx_pend_priv *priv) +{ + struct smc_wr_tx_pend *pend; + int rc; + + ib_req_notify_cq(link->smcibdev->roce_cq_send, + IB_CQ_NEXT_COMP | IB_CQ_REPORT_MISSED_EVENTS); + pend = container_of(priv, struct smc_wr_tx_pend, priv); + rc = ib_post_send(link->roce_qp, &link->wr_tx_ibs[pend->idx], NULL); + if (rc) { + smc_wr_tx_put_slot(link, priv); + smcr_link_down_cond_sched(link); + } + return rc; +} + +int smc_wr_tx_v2_send(struct smc_link *link, struct smc_wr_tx_pend_priv *priv, + int len) +{ + int rc; + + link->wr_tx_v2_ib->sg_list[0].length = len; + ib_req_notify_cq(link->smcibdev->roce_cq_send, + IB_CQ_NEXT_COMP | IB_CQ_REPORT_MISSED_EVENTS); + rc = ib_post_send(link->roce_qp, link->wr_tx_v2_ib, NULL); + if (rc) { + smc_wr_tx_put_slot(link, priv); + smcr_link_down_cond_sched(link); + } + return rc; +} + +/* Send prepared WR slot via ib_post_send and wait for send completion + * notification. + * @priv: pointer to smc_wr_tx_pend_priv identifying prepared message buffer + */ +int smc_wr_tx_send_wait(struct smc_link *link, struct smc_wr_tx_pend_priv *priv, + unsigned long timeout) +{ + struct smc_wr_tx_pend *pend; + u32 pnd_idx; + int rc; + + pend = container_of(priv, struct smc_wr_tx_pend, priv); + pend->compl_requested = 1; + pnd_idx = pend->idx; + init_completion(&link->wr_tx_compl[pnd_idx]); + + rc = smc_wr_tx_send(link, priv); + if (rc) + return rc; + /* wait for completion by smc_wr_tx_process_cqe() */ + rc = wait_for_completion_interruptible_timeout( + &link->wr_tx_compl[pnd_idx], timeout); + if (rc <= 0) + rc = -ENODATA; + if (rc > 0) + rc = 0; + return rc; +} + +/* Register a memory region and wait for result. 
*/ +int smc_wr_reg_send(struct smc_link *link, struct ib_mr *mr) +{ + int rc; + + ib_req_notify_cq(link->smcibdev->roce_cq_send, + IB_CQ_NEXT_COMP | IB_CQ_REPORT_MISSED_EVENTS); + link->wr_reg_state = POSTED; + link->wr_reg.wr.wr_id = (u64)(uintptr_t)mr; + link->wr_reg.mr = mr; + link->wr_reg.key = mr->rkey; + rc = ib_post_send(link->roce_qp, &link->wr_reg.wr, NULL); + if (rc) + return rc; + + atomic_inc(&link->wr_reg_refcnt); + rc = wait_event_interruptible_timeout(link->wr_reg_wait, + (link->wr_reg_state != POSTED), + SMC_WR_REG_MR_WAIT_TIME); + if (atomic_dec_and_test(&link->wr_reg_refcnt)) + wake_up_all(&link->wr_reg_wait); + if (!rc) { + /* timeout - terminate link */ + smcr_link_down_cond_sched(link); + return -EPIPE; + } + if (rc == -ERESTARTSYS) + return -EINTR; + switch (link->wr_reg_state) { + case CONFIRMED: + rc = 0; + break; + case FAILED: + rc = -EIO; + break; + case POSTED: + rc = -EPIPE; + break; + } + return rc; +} + +/****************************** receive queue ********************************/ + +int smc_wr_rx_register_handler(struct smc_wr_rx_handler *handler) +{ + struct smc_wr_rx_handler *h_iter; + int rc = 0; + + spin_lock(&smc_wr_rx_hash_lock); + hash_for_each_possible(smc_wr_rx_hash, h_iter, list, handler->type) { + if (h_iter->type == handler->type) { + rc = -EEXIST; + goto out_unlock; + } + } + hash_add(smc_wr_rx_hash, &handler->list, handler->type); +out_unlock: + spin_unlock(&smc_wr_rx_hash_lock); + return rc; +} + +/* Demultiplex a received work request based on the message type to its handler. + * Relies on smc_wr_rx_hash having been completely filled before any IB WRs, + * and not being modified any more afterwards so we don't need to lock it. + */ +static inline void smc_wr_rx_demultiplex(struct ib_wc *wc) +{ + struct smc_link *link = (struct smc_link *)wc->qp->qp_context; + struct smc_wr_rx_handler *handler; + struct smc_wr_rx_hdr *wr_rx; + u64 temp_wr_id; + u32 index; + + if (wc->byte_len < sizeof(*wr_rx)) + return; /* short message */ + temp_wr_id = wc->wr_id; + index = do_div(temp_wr_id, link->wr_rx_cnt); + wr_rx = (struct smc_wr_rx_hdr *)&link->wr_rx_bufs[index]; + hash_for_each_possible(smc_wr_rx_hash, handler, list, wr_rx->type) { + if (handler->type == wr_rx->type) + handler->handler(wc, wr_rx); + } +} + +static inline void smc_wr_rx_process_cqes(struct ib_wc wc[], int num) +{ + struct smc_link *link; + int i; + + for (i = 0; i < num; i++) { + link = wc[i].qp->qp_context; + link->wr_rx_id_compl = wc[i].wr_id; + if (wc[i].status == IB_WC_SUCCESS) { + link->wr_rx_tstamp = jiffies; + smc_wr_rx_demultiplex(&wc[i]); + smc_wr_rx_post(link); /* refill WR RX */ + } else { + /* handle status errors */ + switch (wc[i].status) { + case IB_WC_RETRY_EXC_ERR: + case IB_WC_RNR_RETRY_EXC_ERR: + case IB_WC_WR_FLUSH_ERR: + smcr_link_down_cond_sched(link); + if (link->wr_rx_id_compl == link->wr_rx_id) + wake_up(&link->wr_rx_empty_wait); + break; + default: + smc_wr_rx_post(link); /* refill WR RX */ + break; + } + } + } +} + +static void smc_wr_rx_tasklet_fn(struct tasklet_struct *t) +{ + struct smc_ib_device *dev = from_tasklet(dev, t, recv_tasklet); + struct ib_wc wc[SMC_WR_MAX_POLL_CQE]; + int polled = 0; + int rc; + +again: + polled++; + do { + memset(&wc, 0, sizeof(wc)); + rc = ib_poll_cq(dev->roce_cq_recv, SMC_WR_MAX_POLL_CQE, wc); + if (polled == 1) { + ib_req_notify_cq(dev->roce_cq_recv, + IB_CQ_SOLICITED_MASK + | IB_CQ_REPORT_MISSED_EVENTS); + } + if (!rc) + break; + smc_wr_rx_process_cqes(&wc[0], rc); + } while (rc > 0); + if (polled == 1) + goto again; +} + 
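The receive tasklet above and the send tasklet earlier share one drain pattern: poll the completion queue until it is empty, re-arm the completion notification during the first pass, and then make one more full pass so that completions racing with the re-arm are handled now rather than waiting for the next interrupt. A self-contained sketch of that control flow, where pending, poll_queue(), arm_notification() and process_entry() are hypothetical stand-ins for the real CQ and the ib_poll_cq()/ib_req_notify_cq() verbs:

#include <stdio.h>

#define BUDGET 10	/* mirrors SMC_WR_MAX_POLL_CQE */

static int pending = 23;	/* pretend 23 completions are queued */

/* hypothetical stand-in for ib_poll_cq(): drain up to budget entries */
static int poll_queue(int budget)
{
	int n = pending < budget ? pending : budget;

	pending -= n;
	return n;
}

/* hypothetical stand-in for ib_req_notify_cq(): ask for the next interrupt */
static void arm_notification(void)
{
}

/* per-completion work, e.g. what smc_wr_tx_process_cqe() does */
static void process_entry(void)
{
}

static void drain_queue(void)
{
	int polled = 0;
	int n, i;

again:
	polled++;
	do {
		n = poll_queue(BUDGET);
		if (polled == 1)
			arm_notification();	/* re-arm during the first pass */
		if (!n)
			break;
		for (i = 0; i < n; i++)
			process_entry();
	} while (n > 0);
	/* the second pass catches completions that arrived between the
	 * last empty poll and the re-arm above
	 */
	if (polled == 1)
		goto again;
}

int main(void)
{
	drain_queue();
	printf("left in queue: %d\n", pending);	/* prints 0 */
	return 0;
}

The extra pass costs one usually-empty poll, but a completion that slips in just before the queue is re-armed gets processed immediately instead of being stranded until the next notification fires.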
+void smc_wr_rx_cq_handler(struct ib_cq *ib_cq, void *cq_context) +{ + struct smc_ib_device *dev = (struct smc_ib_device *)cq_context; + + tasklet_schedule(&dev->recv_tasklet); +} + +int smc_wr_rx_post_init(struct smc_link *link) +{ + u32 i; + int rc = 0; + + for (i = 0; i < link->wr_rx_cnt; i++) + rc = smc_wr_rx_post(link); + return rc; +} + +/***************************** init, exit, misc ******************************/ + +void smc_wr_remember_qp_attr(struct smc_link *lnk) +{ + struct ib_qp_attr *attr = &lnk->qp_attr; + struct ib_qp_init_attr init_attr; + + memset(attr, 0, sizeof(*attr)); + memset(&init_attr, 0, sizeof(init_attr)); + ib_query_qp(lnk->roce_qp, attr, + IB_QP_STATE | + IB_QP_CUR_STATE | + IB_QP_PKEY_INDEX | + IB_QP_PORT | + IB_QP_QKEY | + IB_QP_AV | + IB_QP_PATH_MTU | + IB_QP_TIMEOUT | + IB_QP_RETRY_CNT | + IB_QP_RNR_RETRY | + IB_QP_RQ_PSN | + IB_QP_ALT_PATH | + IB_QP_MIN_RNR_TIMER | + IB_QP_SQ_PSN | + IB_QP_PATH_MIG_STATE | + IB_QP_CAP | + IB_QP_DEST_QPN, + &init_attr); + + lnk->wr_tx_cnt = min_t(size_t, SMC_WR_BUF_CNT, + lnk->qp_attr.cap.max_send_wr); + lnk->wr_rx_cnt = min_t(size_t, SMC_WR_BUF_CNT * 3, + lnk->qp_attr.cap.max_recv_wr); +} + +static void smc_wr_init_sge(struct smc_link *lnk) +{ + int sges_per_buf = (lnk->lgr->smc_version == SMC_V2) ? 2 : 1; + bool send_inline = (lnk->qp_attr.cap.max_inline_data > SMC_WR_TX_SIZE); + u32 i; + + for (i = 0; i < lnk->wr_tx_cnt; i++) { + lnk->wr_tx_sges[i].addr = send_inline ? (uintptr_t)(&lnk->wr_tx_bufs[i]) : + lnk->wr_tx_dma_addr + i * SMC_WR_BUF_SIZE; + lnk->wr_tx_sges[i].length = SMC_WR_TX_SIZE; + lnk->wr_tx_sges[i].lkey = lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge[0].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge[1].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge[0].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge[1].lkey = + lnk->roce_pd->local_dma_lkey; + lnk->wr_tx_ibs[i].next = NULL; + lnk->wr_tx_ibs[i].sg_list = &lnk->wr_tx_sges[i]; + lnk->wr_tx_ibs[i].num_sge = 1; + lnk->wr_tx_ibs[i].opcode = IB_WR_SEND; + lnk->wr_tx_ibs[i].send_flags = + IB_SEND_SIGNALED | IB_SEND_SOLICITED; + if (send_inline) + lnk->wr_tx_ibs[i].send_flags |= IB_SEND_INLINE; + lnk->wr_tx_rdmas[i].wr_tx_rdma[0].wr.opcode = IB_WR_RDMA_WRITE; + lnk->wr_tx_rdmas[i].wr_tx_rdma[1].wr.opcode = IB_WR_RDMA_WRITE; + lnk->wr_tx_rdmas[i].wr_tx_rdma[0].wr.sg_list = + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[0].wr_tx_rdma_sge; + lnk->wr_tx_rdmas[i].wr_tx_rdma[1].wr.sg_list = + lnk->wr_tx_rdma_sges[i].tx_rdma_sge[1].wr_tx_rdma_sge; + } + + if (lnk->lgr->smc_version == SMC_V2) { + lnk->wr_tx_v2_sge->addr = lnk->wr_tx_v2_dma_addr; + lnk->wr_tx_v2_sge->length = SMC_WR_BUF_V2_SIZE; + lnk->wr_tx_v2_sge->lkey = lnk->roce_pd->local_dma_lkey; + + lnk->wr_tx_v2_ib->next = NULL; + lnk->wr_tx_v2_ib->sg_list = lnk->wr_tx_v2_sge; + lnk->wr_tx_v2_ib->num_sge = 1; + lnk->wr_tx_v2_ib->opcode = IB_WR_SEND; + lnk->wr_tx_v2_ib->send_flags = + IB_SEND_SIGNALED | IB_SEND_SOLICITED; + } + + /* With SMC-Rv2 there can be messages larger than SMC_WR_TX_SIZE. + * Each ib_recv_wr gets 2 sges, the second one is a spillover buffer + * and the same buffer for all sges. When a larger message arrived then + * the content of the first small sge is copied to the beginning of + * the larger spillover buffer, allowing easy data mapping. 
+ */ + for (i = 0; i < lnk->wr_rx_cnt; i++) { + int x = i * sges_per_buf; + + lnk->wr_rx_sges[x].addr = + lnk->wr_rx_dma_addr + i * SMC_WR_BUF_SIZE; + lnk->wr_rx_sges[x].length = SMC_WR_TX_SIZE; + lnk->wr_rx_sges[x].lkey = lnk->roce_pd->local_dma_lkey; + if (lnk->lgr->smc_version == SMC_V2) { + lnk->wr_rx_sges[x + 1].addr = + lnk->wr_rx_v2_dma_addr + SMC_WR_TX_SIZE; + lnk->wr_rx_sges[x + 1].length = + SMC_WR_BUF_V2_SIZE - SMC_WR_TX_SIZE; + lnk->wr_rx_sges[x + 1].lkey = + lnk->roce_pd->local_dma_lkey; + } + lnk->wr_rx_ibs[i].next = NULL; + lnk->wr_rx_ibs[i].sg_list = &lnk->wr_rx_sges[x]; + lnk->wr_rx_ibs[i].num_sge = sges_per_buf; + } + lnk->wr_reg.wr.next = NULL; + lnk->wr_reg.wr.num_sge = 0; + lnk->wr_reg.wr.send_flags = IB_SEND_SIGNALED; + lnk->wr_reg.wr.opcode = IB_WR_REG_MR; + lnk->wr_reg.access = IB_ACCESS_LOCAL_WRITE | IB_ACCESS_REMOTE_WRITE; +} + +void smc_wr_free_link(struct smc_link *lnk) +{ + struct ib_device *ibdev; + + if (!lnk->smcibdev) + return; + ibdev = lnk->smcibdev->ibdev; + + smc_wr_drain_cq(lnk); + smc_wr_wakeup_reg_wait(lnk); + smc_wr_wakeup_tx_wait(lnk); + + smc_wr_tx_wait_no_pending_sends(lnk); + wait_event(lnk->wr_reg_wait, (!atomic_read(&lnk->wr_reg_refcnt))); + wait_event(lnk->wr_tx_wait, (!atomic_read(&lnk->wr_tx_refcnt))); + + if (lnk->wr_rx_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_rx_dma_addr, + SMC_WR_BUF_SIZE * lnk->wr_rx_cnt, + DMA_FROM_DEVICE); + lnk->wr_rx_dma_addr = 0; + } + if (lnk->wr_rx_v2_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_rx_v2_dma_addr, + SMC_WR_BUF_V2_SIZE, + DMA_FROM_DEVICE); + lnk->wr_rx_v2_dma_addr = 0; + } + if (lnk->wr_tx_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_tx_dma_addr, + SMC_WR_BUF_SIZE * lnk->wr_tx_cnt, + DMA_TO_DEVICE); + lnk->wr_tx_dma_addr = 0; + } + if (lnk->wr_tx_v2_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_tx_v2_dma_addr, + SMC_WR_BUF_V2_SIZE, + DMA_TO_DEVICE); + lnk->wr_tx_v2_dma_addr = 0; + } +} + +void smc_wr_free_lgr_mem(struct smc_link_group *lgr) +{ + if (lgr->smc_version < SMC_V2) + return; + + kfree(lgr->wr_rx_buf_v2); + lgr->wr_rx_buf_v2 = NULL; + kfree(lgr->wr_tx_buf_v2); + lgr->wr_tx_buf_v2 = NULL; +} + +void smc_wr_free_link_mem(struct smc_link *lnk) +{ + kfree(lnk->wr_tx_v2_ib); + lnk->wr_tx_v2_ib = NULL; + kfree(lnk->wr_tx_v2_sge); + lnk->wr_tx_v2_sge = NULL; + kfree(lnk->wr_tx_v2_pend); + lnk->wr_tx_v2_pend = NULL; + kfree(lnk->wr_tx_compl); + lnk->wr_tx_compl = NULL; + kfree(lnk->wr_tx_pends); + lnk->wr_tx_pends = NULL; + bitmap_free(lnk->wr_tx_mask); + lnk->wr_tx_mask = NULL; + kfree(lnk->wr_tx_sges); + lnk->wr_tx_sges = NULL; + kfree(lnk->wr_tx_rdma_sges); + lnk->wr_tx_rdma_sges = NULL; + kfree(lnk->wr_rx_sges); + lnk->wr_rx_sges = NULL; + kfree(lnk->wr_tx_rdmas); + lnk->wr_tx_rdmas = NULL; + kfree(lnk->wr_rx_ibs); + lnk->wr_rx_ibs = NULL; + kfree(lnk->wr_tx_ibs); + lnk->wr_tx_ibs = NULL; + kfree(lnk->wr_tx_bufs); + lnk->wr_tx_bufs = NULL; + kfree(lnk->wr_rx_bufs); + lnk->wr_rx_bufs = NULL; +} + +int smc_wr_alloc_lgr_mem(struct smc_link_group *lgr) +{ + if (lgr->smc_version < SMC_V2) + return 0; + + lgr->wr_rx_buf_v2 = kzalloc(SMC_WR_BUF_V2_SIZE, GFP_KERNEL); + if (!lgr->wr_rx_buf_v2) + return -ENOMEM; + lgr->wr_tx_buf_v2 = kzalloc(SMC_WR_BUF_V2_SIZE, GFP_KERNEL); + if (!lgr->wr_tx_buf_v2) { + kfree(lgr->wr_rx_buf_v2); + return -ENOMEM; + } + return 0; +} + +int smc_wr_alloc_link_mem(struct smc_link *link) +{ + int sges_per_buf = link->lgr->smc_version == SMC_V2 ? 
2 : 1; + + /* allocate link related memory */ + link->wr_tx_bufs = kcalloc(SMC_WR_BUF_CNT, SMC_WR_BUF_SIZE, GFP_KERNEL); + if (!link->wr_tx_bufs) + goto no_mem; + link->wr_rx_bufs = kcalloc(SMC_WR_BUF_CNT * 3, SMC_WR_BUF_SIZE, + GFP_KERNEL); + if (!link->wr_rx_bufs) + goto no_mem_wr_tx_bufs; + link->wr_tx_ibs = kcalloc(SMC_WR_BUF_CNT, sizeof(link->wr_tx_ibs[0]), + GFP_KERNEL); + if (!link->wr_tx_ibs) + goto no_mem_wr_rx_bufs; + link->wr_rx_ibs = kcalloc(SMC_WR_BUF_CNT * 3, + sizeof(link->wr_rx_ibs[0]), + GFP_KERNEL); + if (!link->wr_rx_ibs) + goto no_mem_wr_tx_ibs; + link->wr_tx_rdmas = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_rdmas[0]), + GFP_KERNEL); + if (!link->wr_tx_rdmas) + goto no_mem_wr_rx_ibs; + link->wr_tx_rdma_sges = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_rdma_sges[0]), + GFP_KERNEL); + if (!link->wr_tx_rdma_sges) + goto no_mem_wr_tx_rdmas; + link->wr_tx_sges = kcalloc(SMC_WR_BUF_CNT, sizeof(link->wr_tx_sges[0]), + GFP_KERNEL); + if (!link->wr_tx_sges) + goto no_mem_wr_tx_rdma_sges; + link->wr_rx_sges = kcalloc(SMC_WR_BUF_CNT * 3, + sizeof(link->wr_rx_sges[0]) * sges_per_buf, + GFP_KERNEL); + if (!link->wr_rx_sges) + goto no_mem_wr_tx_sges; + link->wr_tx_mask = bitmap_zalloc(SMC_WR_BUF_CNT, GFP_KERNEL); + if (!link->wr_tx_mask) + goto no_mem_wr_rx_sges; + link->wr_tx_pends = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_pends[0]), + GFP_KERNEL); + if (!link->wr_tx_pends) + goto no_mem_wr_tx_mask; + link->wr_tx_compl = kcalloc(SMC_WR_BUF_CNT, + sizeof(link->wr_tx_compl[0]), + GFP_KERNEL); + if (!link->wr_tx_compl) + goto no_mem_wr_tx_pends; + + if (link->lgr->smc_version == SMC_V2) { + link->wr_tx_v2_ib = kzalloc(sizeof(*link->wr_tx_v2_ib), + GFP_KERNEL); + if (!link->wr_tx_v2_ib) + goto no_mem_tx_compl; + link->wr_tx_v2_sge = kzalloc(sizeof(*link->wr_tx_v2_sge), + GFP_KERNEL); + if (!link->wr_tx_v2_sge) + goto no_mem_v2_ib; + link->wr_tx_v2_pend = kzalloc(sizeof(*link->wr_tx_v2_pend), + GFP_KERNEL); + if (!link->wr_tx_v2_pend) + goto no_mem_v2_sge; + } + return 0; + +no_mem_v2_sge: + kfree(link->wr_tx_v2_sge); +no_mem_v2_ib: + kfree(link->wr_tx_v2_ib); +no_mem_tx_compl: + kfree(link->wr_tx_compl); +no_mem_wr_tx_pends: + kfree(link->wr_tx_pends); +no_mem_wr_tx_mask: + kfree(link->wr_tx_mask); +no_mem_wr_rx_sges: + kfree(link->wr_rx_sges); +no_mem_wr_tx_sges: + kfree(link->wr_tx_sges); +no_mem_wr_tx_rdma_sges: + kfree(link->wr_tx_rdma_sges); +no_mem_wr_tx_rdmas: + kfree(link->wr_tx_rdmas); +no_mem_wr_rx_ibs: + kfree(link->wr_rx_ibs); +no_mem_wr_tx_ibs: + kfree(link->wr_tx_ibs); +no_mem_wr_rx_bufs: + kfree(link->wr_rx_bufs); +no_mem_wr_tx_bufs: + kfree(link->wr_tx_bufs); +no_mem: + return -ENOMEM; +} + +void smc_wr_remove_dev(struct smc_ib_device *smcibdev) +{ + tasklet_kill(&smcibdev->recv_tasklet); + tasklet_kill(&smcibdev->send_tasklet); +} + +void smc_wr_add_dev(struct smc_ib_device *smcibdev) +{ + tasklet_setup(&smcibdev->recv_tasklet, smc_wr_rx_tasklet_fn); + tasklet_setup(&smcibdev->send_tasklet, smc_wr_tx_tasklet_fn); +} + +int smc_wr_create_link(struct smc_link *lnk) +{ + struct ib_device *ibdev = lnk->smcibdev->ibdev; + int rc = 0; + + smc_wr_tx_set_wr_id(&lnk->wr_tx_id, 0); + lnk->wr_rx_id = 0; + lnk->wr_rx_dma_addr = ib_dma_map_single( + ibdev, lnk->wr_rx_bufs, SMC_WR_BUF_SIZE * lnk->wr_rx_cnt, + DMA_FROM_DEVICE); + if (ib_dma_mapping_error(ibdev, lnk->wr_rx_dma_addr)) { + lnk->wr_rx_dma_addr = 0; + rc = -EIO; + goto out; + } + if (lnk->lgr->smc_version == SMC_V2) { + lnk->wr_rx_v2_dma_addr = ib_dma_map_single(ibdev, + lnk->lgr->wr_rx_buf_v2, 
SMC_WR_BUF_V2_SIZE, + DMA_FROM_DEVICE); + if (ib_dma_mapping_error(ibdev, lnk->wr_rx_v2_dma_addr)) { + lnk->wr_rx_v2_dma_addr = 0; + rc = -EIO; + goto dma_unmap; + } + lnk->wr_tx_v2_dma_addr = ib_dma_map_single(ibdev, + lnk->lgr->wr_tx_buf_v2, SMC_WR_BUF_V2_SIZE, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(ibdev, lnk->wr_tx_v2_dma_addr)) { + lnk->wr_tx_v2_dma_addr = 0; + rc = -EIO; + goto dma_unmap; + } + } + lnk->wr_tx_dma_addr = ib_dma_map_single( + ibdev, lnk->wr_tx_bufs, SMC_WR_BUF_SIZE * lnk->wr_tx_cnt, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(ibdev, lnk->wr_tx_dma_addr)) { + rc = -EIO; + goto dma_unmap; + } + smc_wr_init_sge(lnk); + bitmap_zero(lnk->wr_tx_mask, SMC_WR_BUF_CNT); + init_waitqueue_head(&lnk->wr_tx_wait); + atomic_set(&lnk->wr_tx_refcnt, 0); + init_waitqueue_head(&lnk->wr_reg_wait); + atomic_set(&lnk->wr_reg_refcnt, 0); + init_waitqueue_head(&lnk->wr_rx_empty_wait); + return rc; + +dma_unmap: + if (lnk->wr_rx_v2_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_rx_v2_dma_addr, + SMC_WR_BUF_V2_SIZE, + DMA_FROM_DEVICE); + lnk->wr_rx_v2_dma_addr = 0; + } + if (lnk->wr_tx_v2_dma_addr) { + ib_dma_unmap_single(ibdev, lnk->wr_tx_v2_dma_addr, + SMC_WR_BUF_V2_SIZE, + DMA_TO_DEVICE); + lnk->wr_tx_v2_dma_addr = 0; + } + ib_dma_unmap_single(ibdev, lnk->wr_rx_dma_addr, + SMC_WR_BUF_SIZE * lnk->wr_rx_cnt, + DMA_FROM_DEVICE); + lnk->wr_rx_dma_addr = 0; +out: + return rc; +} diff --git a/net/smc/smc_wr.h b/net/smc/smc_wr.h new file mode 100644 index 000000000..45e9b894d --- /dev/null +++ b/net/smc/smc_wr.h @@ -0,0 +1,140 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Shared Memory Communications over RDMA (SMC-R) and RoCE + * + * Work Requests exploiting Infiniband API + * + * Copyright IBM Corp. 2016 + * + * Author(s): Steffen Maier + */ + +#ifndef SMC_WR_H +#define SMC_WR_H + +#include +#include +#include + +#include "smc.h" +#include "smc_core.h" + +#define SMC_WR_BUF_CNT 16 /* # of ctrl buffers per link */ + +#define SMC_WR_TX_WAIT_FREE_SLOT_TIME (10 * HZ) + +#define SMC_WR_TX_SIZE 44 /* actual size of wr_send data (<=SMC_WR_BUF_SIZE) */ + +#define SMC_WR_TX_PEND_PRIV_SIZE 32 + +struct smc_wr_tx_pend_priv { + u8 priv[SMC_WR_TX_PEND_PRIV_SIZE]; +}; + +typedef void (*smc_wr_tx_handler)(struct smc_wr_tx_pend_priv *, + struct smc_link *, + enum ib_wc_status); + +typedef bool (*smc_wr_tx_filter)(struct smc_wr_tx_pend_priv *, + unsigned long); + +typedef void (*smc_wr_tx_dismisser)(struct smc_wr_tx_pend_priv *); + +struct smc_wr_rx_handler { + struct hlist_node list; /* hash table collision resolution */ + void (*handler)(struct ib_wc *, void *); + u8 type; +}; + +/* Only used by RDMA write WRs. 
+ * All other WRs (CDC/LLC) use smc_wr_tx_send handling WR_ID implicitly + */ +static inline long smc_wr_tx_get_next_wr_id(struct smc_link *link) +{ + return atomic_long_inc_return(&link->wr_tx_id); +} + +static inline void smc_wr_tx_set_wr_id(atomic_long_t *wr_tx_id, long val) +{ + atomic_long_set(wr_tx_id, val); +} + +static inline bool smc_wr_tx_link_hold(struct smc_link *link) +{ + if (!smc_link_sendable(link)) + return false; + atomic_inc(&link->wr_tx_refcnt); + return true; +} + +static inline void smc_wr_tx_link_put(struct smc_link *link) +{ + if (atomic_dec_and_test(&link->wr_tx_refcnt)) + wake_up_all(&link->wr_tx_wait); +} + +static inline void smc_wr_drain_cq(struct smc_link *lnk) +{ + wait_event(lnk->wr_rx_empty_wait, lnk->wr_rx_id_compl == lnk->wr_rx_id); +} + +static inline void smc_wr_wakeup_tx_wait(struct smc_link *lnk) +{ + wake_up_all(&lnk->wr_tx_wait); +} + +static inline void smc_wr_wakeup_reg_wait(struct smc_link *lnk) +{ + wake_up(&lnk->wr_reg_wait); +} + +/* post a new receive work request to fill a completed old work request entry */ +static inline int smc_wr_rx_post(struct smc_link *link) +{ + int rc; + u64 wr_id, temp_wr_id; + u32 index; + + wr_id = ++link->wr_rx_id; /* tasklet context, thus not atomic */ + temp_wr_id = wr_id; + index = do_div(temp_wr_id, link->wr_rx_cnt); + link->wr_rx_ibs[index].wr_id = wr_id; + rc = ib_post_recv(link->roce_qp, &link->wr_rx_ibs[index], NULL); + return rc; +} + +int smc_wr_create_link(struct smc_link *lnk); +int smc_wr_alloc_link_mem(struct smc_link *lnk); +int smc_wr_alloc_lgr_mem(struct smc_link_group *lgr); +void smc_wr_free_link(struct smc_link *lnk); +void smc_wr_free_link_mem(struct smc_link *lnk); +void smc_wr_free_lgr_mem(struct smc_link_group *lgr); +void smc_wr_remember_qp_attr(struct smc_link *lnk); +void smc_wr_remove_dev(struct smc_ib_device *smcibdev); +void smc_wr_add_dev(struct smc_ib_device *smcibdev); + +int smc_wr_tx_get_free_slot(struct smc_link *link, smc_wr_tx_handler handler, + struct smc_wr_buf **wr_buf, + struct smc_rdma_wr **wrs, + struct smc_wr_tx_pend_priv **wr_pend_priv); +int smc_wr_tx_get_v2_slot(struct smc_link *link, + smc_wr_tx_handler handler, + struct smc_wr_v2_buf **wr_buf, + struct smc_wr_tx_pend_priv **wr_pend_priv); +int smc_wr_tx_put_slot(struct smc_link *link, + struct smc_wr_tx_pend_priv *wr_pend_priv); +int smc_wr_tx_send(struct smc_link *link, + struct smc_wr_tx_pend_priv *wr_pend_priv); +int smc_wr_tx_v2_send(struct smc_link *link, + struct smc_wr_tx_pend_priv *priv, int len); +int smc_wr_tx_send_wait(struct smc_link *link, struct smc_wr_tx_pend_priv *priv, + unsigned long timeout); +void smc_wr_tx_cq_handler(struct ib_cq *ib_cq, void *cq_context); +void smc_wr_tx_wait_no_pending_sends(struct smc_link *link); + +int smc_wr_rx_register_handler(struct smc_wr_rx_handler *handler); +int smc_wr_rx_post_init(struct smc_link *link); +void smc_wr_rx_cq_handler(struct ib_cq *ib_cq, void *cq_context); +int smc_wr_reg_send(struct smc_link *link, struct ib_mr *mr); + +#endif /* SMC_WR_H */ diff --git a/net/socket.c b/net/socket.c new file mode 100644 index 000000000..639d76f20 --- /dev/null +++ b/net/socket.c @@ -0,0 +1,3693 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET An implementation of the SOCKET network access protocol. + * + * Version: @(#)socket.c 1.1.93 18/02/95 + * + * Authors: Orest Zborowski, + * Ross Biro + * Fred N. van Kempen, + * + * Fixes: + * Anonymous : NOTSOCK/BADF cleanup. 
Error fix in + * shutdown() + * Alan Cox : verify_area() fixes + * Alan Cox : Removed DDI + * Jonathan Kamens : SOCK_DGRAM reconnect bug + * Alan Cox : Moved a load of checks to the very + * top level. + * Alan Cox : Move address structures to/from user + * mode above the protocol layers. + * Rob Janssen : Allow 0 length sends. + * Alan Cox : Asynchronous I/O support (cribbed from the + * tty drivers). + * Niibe Yutaka : Asynchronous I/O for writes (4.4BSD style) + * Jeff Uphoff : Made max number of sockets command-line + * configurable. + * Matti Aarnio : Made the number of sockets dynamic, + * to be allocated when needed, and mr. + * Uphoff's max is used as max to be + * allowed to allocate. + * Linus : Argh. removed all the socket allocation + * altogether: it's in the inode now. + * Alan Cox : Made sock_alloc()/sock_release() public + * for NetROM and future kernel nfsd type + * stuff. + * Alan Cox : sendmsg/recvmsg basics. + * Tom Dyas : Export net symbols. + * Marcin Dalecki : Fixed problems with CONFIG_NET="n". + * Alan Cox : Added thread locking to sys_* calls + * for sockets. May have errors at the + * moment. + * Kevin Buhr : Fixed the dumb errors in the above. + * Andi Kleen : Some small cleanups, optimizations, + * and fixed a copy_from_user() bug. + * Tigran Aivazian : sys_send(args) calls sys_sendto(args, NULL, 0) + * Tigran Aivazian : Made listen(2) backlog sanity checks + * protocol-independent + * + * This module is effectively the top level interface to the BSD socket + * paradigm. + * + * Based upon Swansea University Computer Society NET3.039 + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_NET_RX_BUSY_POLL +unsigned int sysctl_net_busy_read __read_mostly; +unsigned int sysctl_net_busy_poll __read_mostly; +#endif + +static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to); +static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from); +static int sock_mmap(struct file *file, struct vm_area_struct *vma); + +static int sock_close(struct inode *inode, struct file *file); +static __poll_t sock_poll(struct file *file, + struct poll_table_struct *wait); +static long sock_ioctl(struct file *file, unsigned int cmd, unsigned long arg); +#ifdef CONFIG_COMPAT +static long compat_sock_ioctl(struct file *file, + unsigned int cmd, unsigned long arg); +#endif +static int sock_fasync(int fd, struct file *filp, int on); +static ssize_t sock_sendpage(struct file *file, struct page *page, + int offset, size_t size, loff_t *ppos, int more); +static ssize_t sock_splice_read(struct file *file, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags); +static void sock_splice_eof(struct file *file); + +#ifdef CONFIG_PROC_FS +static void sock_show_fdinfo(struct seq_file *m, struct file *f) +{ + struct socket *sock = f->private_data; + + if (sock->ops->show_fdinfo) + sock->ops->show_fdinfo(m, sock); +} +#else +#define sock_show_fdinfo NULL +#endif + +/* + * Socket files have a set of 'special' operations as well as the generic file ones. 
These don't appear + * in the operation structures but are done directly via the socketcall() multiplexor. + */ + +static const struct file_operations socket_file_ops = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read_iter = sock_read_iter, + .write_iter = sock_write_iter, + .poll = sock_poll, + .unlocked_ioctl = sock_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = compat_sock_ioctl, +#endif + .mmap = sock_mmap, + .release = sock_close, + .fasync = sock_fasync, + .sendpage = sock_sendpage, + .splice_write = generic_splice_sendpage, + .splice_read = sock_splice_read, + .splice_eof = sock_splice_eof, + .show_fdinfo = sock_show_fdinfo, +}; + +static const char * const pf_family_names[] = { + [PF_UNSPEC] = "PF_UNSPEC", + [PF_UNIX] = "PF_UNIX/PF_LOCAL", + [PF_INET] = "PF_INET", + [PF_AX25] = "PF_AX25", + [PF_IPX] = "PF_IPX", + [PF_APPLETALK] = "PF_APPLETALK", + [PF_NETROM] = "PF_NETROM", + [PF_BRIDGE] = "PF_BRIDGE", + [PF_ATMPVC] = "PF_ATMPVC", + [PF_X25] = "PF_X25", + [PF_INET6] = "PF_INET6", + [PF_ROSE] = "PF_ROSE", + [PF_DECnet] = "PF_DECnet", + [PF_NETBEUI] = "PF_NETBEUI", + [PF_SECURITY] = "PF_SECURITY", + [PF_KEY] = "PF_KEY", + [PF_NETLINK] = "PF_NETLINK/PF_ROUTE", + [PF_PACKET] = "PF_PACKET", + [PF_ASH] = "PF_ASH", + [PF_ECONET] = "PF_ECONET", + [PF_ATMSVC] = "PF_ATMSVC", + [PF_RDS] = "PF_RDS", + [PF_SNA] = "PF_SNA", + [PF_IRDA] = "PF_IRDA", + [PF_PPPOX] = "PF_PPPOX", + [PF_WANPIPE] = "PF_WANPIPE", + [PF_LLC] = "PF_LLC", + [PF_IB] = "PF_IB", + [PF_MPLS] = "PF_MPLS", + [PF_CAN] = "PF_CAN", + [PF_TIPC] = "PF_TIPC", + [PF_BLUETOOTH] = "PF_BLUETOOTH", + [PF_IUCV] = "PF_IUCV", + [PF_RXRPC] = "PF_RXRPC", + [PF_ISDN] = "PF_ISDN", + [PF_PHONET] = "PF_PHONET", + [PF_IEEE802154] = "PF_IEEE802154", + [PF_CAIF] = "PF_CAIF", + [PF_ALG] = "PF_ALG", + [PF_NFC] = "PF_NFC", + [PF_VSOCK] = "PF_VSOCK", + [PF_KCM] = "PF_KCM", + [PF_QIPCRTR] = "PF_QIPCRTR", + [PF_SMC] = "PF_SMC", + [PF_XDP] = "PF_XDP", + [PF_MCTP] = "PF_MCTP", +}; + +/* + * The protocol list. Each protocol is registered in here. + */ + +static DEFINE_SPINLOCK(net_family_lock); +static const struct net_proto_family __rcu *net_families[NPROTO] __read_mostly; + +/* + * Support routines. + * Move socket addresses back and forth across the kernel/user + * divide and look after the messy bits. + */ + +/** + * move_addr_to_kernel - copy a socket address into kernel space + * @uaddr: Address in user space + * @kaddr: Address in kernel space + * @ulen: Length in user space + * + * The address is copied into kernel space. If the provided address is + * too long an error code of -EINVAL is returned. If the copy gives + * invalid addresses -EFAULT is returned. On a success 0 is returned. + */ + +int move_addr_to_kernel(void __user *uaddr, int ulen, struct sockaddr_storage *kaddr) +{ + if (ulen < 0 || ulen > sizeof(struct sockaddr_storage)) + return -EINVAL; + if (ulen == 0) + return 0; + if (copy_from_user(kaddr, uaddr, ulen)) + return -EFAULT; + return audit_sockaddr(ulen, kaddr); +} + +/** + * move_addr_to_user - copy an address to user space + * @kaddr: kernel space address + * @klen: length of address in kernel + * @uaddr: user space address + * @ulen: pointer to user length field + * + * The value pointed to by ulen on entry is the buffer length available. + * This is overwritten with the buffer space used. -EINVAL is returned + * if an overlong buffer is specified or a negative buffer size. -EFAULT + * is returned if either the buffer or the length field are not + * accessible. 
+ * After copying the data up to the limit the user specifies, the true + * length of the data is written over the length limit the user + * specified. Zero is returned for a success. + */ + +static int move_addr_to_user(struct sockaddr_storage *kaddr, int klen, + void __user *uaddr, int __user *ulen) +{ + int err; + int len; + + BUG_ON(klen > sizeof(struct sockaddr_storage)); + err = get_user(len, ulen); + if (err) + return err; + if (len > klen) + len = klen; + if (len < 0) + return -EINVAL; + if (len) { + if (audit_sockaddr(klen, kaddr)) + return -ENOMEM; + if (copy_to_user(uaddr, kaddr, len)) + return -EFAULT; + } + /* + * "fromlen shall refer to the value before truncation.." + * 1003.1g + */ + return __put_user(klen, ulen); +} + +static struct kmem_cache *sock_inode_cachep __ro_after_init; + +static struct inode *sock_alloc_inode(struct super_block *sb) +{ + struct socket_alloc *ei; + + ei = alloc_inode_sb(sb, sock_inode_cachep, GFP_KERNEL); + if (!ei) + return NULL; + init_waitqueue_head(&ei->socket.wq.wait); + ei->socket.wq.fasync_list = NULL; + ei->socket.wq.flags = 0; + + ei->socket.state = SS_UNCONNECTED; + ei->socket.flags = 0; + ei->socket.ops = NULL; + ei->socket.sk = NULL; + ei->socket.file = NULL; + + return &ei->vfs_inode; +} + +static void sock_free_inode(struct inode *inode) +{ + struct socket_alloc *ei; + + ei = container_of(inode, struct socket_alloc, vfs_inode); + kmem_cache_free(sock_inode_cachep, ei); +} + +static void init_once(void *foo) +{ + struct socket_alloc *ei = (struct socket_alloc *)foo; + + inode_init_once(&ei->vfs_inode); +} + +static void init_inodecache(void) +{ + sock_inode_cachep = kmem_cache_create("sock_inode_cache", + sizeof(struct socket_alloc), + 0, + (SLAB_HWCACHE_ALIGN | + SLAB_RECLAIM_ACCOUNT | + SLAB_MEM_SPREAD | SLAB_ACCOUNT), + init_once); + BUG_ON(sock_inode_cachep == NULL); +} + +static const struct super_operations sockfs_ops = { + .alloc_inode = sock_alloc_inode, + .free_inode = sock_free_inode, + .statfs = simple_statfs, +}; + +/* + * sockfs_dname() is called from d_path(). + */ +static char *sockfs_dname(struct dentry *dentry, char *buffer, int buflen) +{ + return dynamic_dname(buffer, buflen, "socket:[%lu]", + d_inode(dentry)->i_ino); +} + +static const struct dentry_operations sockfs_dentry_operations = { + .d_dname = sockfs_dname, +}; + +static int sockfs_xattr_get(const struct xattr_handler *handler, + struct dentry *dentry, struct inode *inode, + const char *suffix, void *value, size_t size) +{ + if (value) { + if (dentry->d_name.len + 1 > size) + return -ERANGE; + memcpy(value, dentry->d_name.name, dentry->d_name.len + 1); + } + return dentry->d_name.len + 1; +} + +#define XATTR_SOCKPROTONAME_SUFFIX "sockprotoname" +#define XATTR_NAME_SOCKPROTONAME (XATTR_SYSTEM_PREFIX XATTR_SOCKPROTONAME_SUFFIX) +#define XATTR_NAME_SOCKPROTONAME_LEN (sizeof(XATTR_NAME_SOCKPROTONAME)-1) + +static const struct xattr_handler sockfs_xattr_handler = { + .name = XATTR_NAME_SOCKPROTONAME, + .get = sockfs_xattr_get, +}; + +static int sockfs_security_xattr_set(const struct xattr_handler *handler, + struct user_namespace *mnt_userns, + struct dentry *dentry, struct inode *inode, + const char *suffix, const void *value, + size_t size, int flags) +{ + /* Handled by LSM. 
*/ + return -EAGAIN; +} + +static const struct xattr_handler sockfs_security_xattr_handler = { + .prefix = XATTR_SECURITY_PREFIX, + .set = sockfs_security_xattr_set, +}; + +static const struct xattr_handler *sockfs_xattr_handlers[] = { + &sockfs_xattr_handler, + &sockfs_security_xattr_handler, + NULL +}; + +static int sockfs_init_fs_context(struct fs_context *fc) +{ + struct pseudo_fs_context *ctx = init_pseudo(fc, SOCKFS_MAGIC); + if (!ctx) + return -ENOMEM; + ctx->ops = &sockfs_ops; + ctx->dops = &sockfs_dentry_operations; + ctx->xattr = sockfs_xattr_handlers; + return 0; +} + +static struct vfsmount *sock_mnt __read_mostly; + +static struct file_system_type sock_fs_type = { + .name = "sockfs", + .init_fs_context = sockfs_init_fs_context, + .kill_sb = kill_anon_super, +}; + +/* + * Obtains the first available file descriptor and sets it up for use. + * + * These functions create file structures and maps them to fd space + * of the current process. On success it returns file descriptor + * and file struct implicitly stored in sock->file. + * Note that another thread may close file descriptor before we return + * from this function. We use the fact that now we do not refer + * to socket after mapping. If one day we will need it, this + * function will increment ref. count on file by 1. + * + * In any case returned fd MAY BE not valid! + * This race condition is unavoidable + * with shared fd spaces, we cannot solve it inside kernel, + * but we take care of internal coherence yet. + */ + +/** + * sock_alloc_file - Bind a &socket to a &file + * @sock: socket + * @flags: file status flags + * @dname: protocol name + * + * Returns the &file bound with @sock, implicitly storing it + * in sock->file. If dname is %NULL, sets to "". + * + * On failure @sock is released, and an ERR pointer is returned. + * + * This function uses GFP_KERNEL internally. + */ + +struct file *sock_alloc_file(struct socket *sock, int flags, const char *dname) +{ + struct file *file; + + if (!dname) + dname = sock->sk ? sock->sk->sk_prot_creator->name : ""; + + file = alloc_file_pseudo(SOCK_INODE(sock), sock_mnt, dname, + O_RDWR | (flags & O_NONBLOCK), + &socket_file_ops); + if (IS_ERR(file)) { + sock_release(sock); + return file; + } + + sock->file = file; + file->private_data = sock; + stream_open(SOCK_INODE(sock), file); + return file; +} +EXPORT_SYMBOL(sock_alloc_file); + +static int sock_map_fd(struct socket *sock, int flags) +{ + struct file *newfile; + int fd = get_unused_fd_flags(flags); + if (unlikely(fd < 0)) { + sock_release(sock); + return fd; + } + + newfile = sock_alloc_file(sock, flags, NULL); + if (!IS_ERR(newfile)) { + fd_install(fd, newfile); + return fd; + } + + put_unused_fd(fd); + return PTR_ERR(newfile); +} + +/** + * sock_from_file - Return the &socket bounded to @file. + * @file: file + * + * On failure returns %NULL. + */ + +struct socket *sock_from_file(struct file *file) +{ + if (file->f_op == &socket_file_ops) + return file->private_data; /* set in sock_alloc_file */ + + return NULL; +} +EXPORT_SYMBOL(sock_from_file); + +/** + * sockfd_lookup - Go from a file number to its socket slot + * @fd: file handle + * @err: pointer to an error code return + * + * The file handle passed in is locked and the socket it is bound + * to is returned. If an error occurs the err pointer is overwritten + * with a negative errno code and NULL is returned. The function checks + * for both invalid handles and passing a handle which is not a socket. + * + * On a success the socket object pointer is returned. 
+ */ + +struct socket *sockfd_lookup(int fd, int *err) +{ + struct file *file; + struct socket *sock; + + file = fget(fd); + if (!file) { + *err = -EBADF; + return NULL; + } + + sock = sock_from_file(file); + if (!sock) { + *err = -ENOTSOCK; + fput(file); + } + return sock; +} +EXPORT_SYMBOL(sockfd_lookup); + +static struct socket *sockfd_lookup_light(int fd, int *err, int *fput_needed) +{ + struct fd f = fdget(fd); + struct socket *sock; + + *err = -EBADF; + if (f.file) { + sock = sock_from_file(f.file); + if (likely(sock)) { + *fput_needed = f.flags & FDPUT_FPUT; + return sock; + } + *err = -ENOTSOCK; + fdput(f); + } + return NULL; +} + +static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer, + size_t size) +{ + ssize_t len; + ssize_t used = 0; + + len = security_inode_listsecurity(d_inode(dentry), buffer, size); + if (len < 0) + return len; + used += len; + if (buffer) { + if (size < used) + return -ERANGE; + buffer += len; + } + + len = (XATTR_NAME_SOCKPROTONAME_LEN + 1); + used += len; + if (buffer) { + if (size < used) + return -ERANGE; + memcpy(buffer, XATTR_NAME_SOCKPROTONAME, len); + buffer += len; + } + + return used; +} + +static int sockfs_setattr(struct user_namespace *mnt_userns, + struct dentry *dentry, struct iattr *iattr) +{ + int err = simple_setattr(&init_user_ns, dentry, iattr); + + if (!err && (iattr->ia_valid & ATTR_UID)) { + struct socket *sock = SOCKET_I(d_inode(dentry)); + + if (sock->sk) + sock->sk->sk_uid = iattr->ia_uid; + else + err = -ENOENT; + } + + return err; +} + +static const struct inode_operations sockfs_inode_ops = { + .listxattr = sockfs_listxattr, + .setattr = sockfs_setattr, +}; + +/** + * sock_alloc - allocate a socket + * + * Allocate a new inode and socket object. The two are bound together + * and initialised. The socket is then returned. If we are out of inodes + * NULL is returned. This functions uses GFP_KERNEL internally. + */ + +struct socket *sock_alloc(void) +{ + struct inode *inode; + struct socket *sock; + + inode = new_inode_pseudo(sock_mnt->mnt_sb); + if (!inode) + return NULL; + + sock = SOCKET_I(inode); + + inode->i_ino = get_next_ino(); + inode->i_mode = S_IFSOCK | S_IRWXUGO; + inode->i_uid = current_fsuid(); + inode->i_gid = current_fsgid(); + inode->i_op = &sockfs_inode_ops; + + return sock; +} +EXPORT_SYMBOL(sock_alloc); + +static void __sock_release(struct socket *sock, struct inode *inode) +{ + if (sock->ops) { + struct module *owner = sock->ops->owner; + + if (inode) + inode_lock(inode); + sock->ops->release(sock); + sock->sk = NULL; + if (inode) + inode_unlock(inode); + sock->ops = NULL; + module_put(owner); + } + + if (sock->wq.fasync_list) + pr_err("%s: fasync list not empty!\n", __func__); + + if (!sock->file) { + iput(SOCK_INODE(sock)); + return; + } + sock->file = NULL; +} + +/** + * sock_release - close a socket + * @sock: socket to close + * + * The socket is released from the protocol stack if it has a release + * callback, and the inode is then released if the socket is bound to + * an inode not a file. + */ +void sock_release(struct socket *sock) +{ + __sock_release(sock, NULL); +} +EXPORT_SYMBOL(sock_release); + +void __sock_tx_timestamp(__u16 tsflags, __u8 *tx_flags) +{ + u8 flags = *tx_flags; + + if (tsflags & SOF_TIMESTAMPING_TX_HARDWARE) { + flags |= SKBTX_HW_TSTAMP; + + /* PTP hardware clocks can provide a free running cycle counter + * as a time base for virtual clocks. Tell driver to use the + * free running cycle counter for timestamp if socket is bound + * to virtual clock. 
+ */ + if (tsflags & SOF_TIMESTAMPING_BIND_PHC) + flags |= SKBTX_HW_TSTAMP_USE_CYCLES; + } + + if (tsflags & SOF_TIMESTAMPING_TX_SOFTWARE) + flags |= SKBTX_SW_TSTAMP; + + if (tsflags & SOF_TIMESTAMPING_TX_SCHED) + flags |= SKBTX_SCHED_TSTAMP; + + *tx_flags = flags; +} +EXPORT_SYMBOL(__sock_tx_timestamp); + +INDIRECT_CALLABLE_DECLARE(int inet_sendmsg(struct socket *, struct msghdr *, + size_t)); +INDIRECT_CALLABLE_DECLARE(int inet6_sendmsg(struct socket *, struct msghdr *, + size_t)); +static inline int sock_sendmsg_nosec(struct socket *sock, struct msghdr *msg) +{ + int ret = INDIRECT_CALL_INET(sock->ops->sendmsg, inet6_sendmsg, + inet_sendmsg, sock, msg, + msg_data_left(msg)); + BUG_ON(ret == -EIOCBQUEUED); + return ret; +} + +static int __sock_sendmsg(struct socket *sock, struct msghdr *msg) +{ + int err = security_socket_sendmsg(sock, msg, + msg_data_left(msg)); + + return err ?: sock_sendmsg_nosec(sock, msg); +} + +/** + * sock_sendmsg - send a message through @sock + * @sock: socket + * @msg: message to send + * + * Sends @msg through @sock, passing through LSM. + * Returns the number of bytes sent, or an error code. + */ +int sock_sendmsg(struct socket *sock, struct msghdr *msg) +{ + struct sockaddr_storage *save_addr = (struct sockaddr_storage *)msg->msg_name; + struct sockaddr_storage address; + int save_len = msg->msg_namelen; + int ret; + + if (msg->msg_name) { + memcpy(&address, msg->msg_name, msg->msg_namelen); + msg->msg_name = &address; + } + + ret = __sock_sendmsg(sock, msg); + msg->msg_name = save_addr; + msg->msg_namelen = save_len; + + return ret; +} +EXPORT_SYMBOL(sock_sendmsg); + +/** + * kernel_sendmsg - send a message through @sock (kernel-space) + * @sock: socket + * @msg: message header + * @vec: kernel vec + * @num: vec array length + * @size: total message data size + * + * Builds the message data with @vec and sends it through @sock. + * Returns the number of bytes sent, or an error code. + */ + +int kernel_sendmsg(struct socket *sock, struct msghdr *msg, + struct kvec *vec, size_t num, size_t size) +{ + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); + return sock_sendmsg(sock, msg); +} +EXPORT_SYMBOL(kernel_sendmsg); + +/** + * kernel_sendmsg_locked - send a message through @sock (kernel-space) + * @sk: sock + * @msg: message header + * @vec: output s/g array + * @num: output s/g array length + * @size: total message data size + * + * Builds the message data with @vec and sends it through @sock. + * Returns the number of bytes sent, or an error code. + * Caller must hold @sk. + */ + +int kernel_sendmsg_locked(struct sock *sk, struct msghdr *msg, + struct kvec *vec, size_t num, size_t size) +{ + struct socket *sock = sk->sk_socket; + + if (!sock->ops->sendmsg_locked) + return sock_no_sendmsg_locked(sk, msg, size); + + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, num, size); + + return sock->ops->sendmsg_locked(sk, msg, msg_data_left(msg)); +} +EXPORT_SYMBOL(kernel_sendmsg_locked); + +static bool skb_is_err_queue(const struct sk_buff *skb) +{ + /* pkt_type of skbs enqueued on the error queue are set to + * PACKET_OUTGOING in skb_set_err_queue(). This is only safe to do + * in recvmsg, since skbs received on a local socket will never + * have a pkt_type of PACKET_OUTGOING. + */ + return skb->pkt_type == PACKET_OUTGOING; +} + +/* On transmit, software and hardware timestamps are returned independently. 
+ * As the two skb clones share the hardware timestamp, which may be updated + * before the software timestamp is received, a hardware TX timestamp may be + * returned only if there is no software TX timestamp. Ignore false software + * timestamps, which may be made in the __sock_recv_timestamp() call when the + * option SO_TIMESTAMP_OLD(NS) is enabled on the socket, even when the skb has a + * hardware timestamp. + */ +static bool skb_is_swtx_tstamp(const struct sk_buff *skb, int false_tstamp) +{ + return skb->tstamp && !false_tstamp && skb_is_err_queue(skb); +} + +static ktime_t get_timestamp(struct sock *sk, struct sk_buff *skb, int *if_index) +{ + bool cycles = READ_ONCE(sk->sk_tsflags) & SOF_TIMESTAMPING_BIND_PHC; + struct skb_shared_hwtstamps *shhwtstamps = skb_hwtstamps(skb); + struct net_device *orig_dev; + ktime_t hwtstamp; + + rcu_read_lock(); + orig_dev = dev_get_by_napi_id(skb_napi_id(skb)); + if (orig_dev) { + *if_index = orig_dev->ifindex; + hwtstamp = netdev_get_tstamp(orig_dev, shhwtstamps, cycles); + } else { + hwtstamp = shhwtstamps->hwtstamp; + } + rcu_read_unlock(); + + return hwtstamp; +} + +static void put_ts_pktinfo(struct msghdr *msg, struct sk_buff *skb, + int if_index) +{ + struct scm_ts_pktinfo ts_pktinfo; + struct net_device *orig_dev; + + if (!skb_mac_header_was_set(skb)) + return; + + memset(&ts_pktinfo, 0, sizeof(ts_pktinfo)); + + if (!if_index) { + rcu_read_lock(); + orig_dev = dev_get_by_napi_id(skb_napi_id(skb)); + if (orig_dev) + if_index = orig_dev->ifindex; + rcu_read_unlock(); + } + ts_pktinfo.if_index = if_index; + + ts_pktinfo.pkt_length = skb->len - skb_mac_offset(skb); + put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPING_PKTINFO, + sizeof(ts_pktinfo), &ts_pktinfo); +} + +/* + * called from sock_recv_timestamp() if sock_flag(sk, SOCK_RCVTSTAMP) + */ +void __sock_recv_timestamp(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + int need_software_tstamp = sock_flag(sk, SOCK_RCVTSTAMP); + int new_tstamp = sock_flag(sk, SOCK_TSTAMP_NEW); + struct scm_timestamping_internal tss; + int empty = 1, false_tstamp = 0; + struct skb_shared_hwtstamps *shhwtstamps = + skb_hwtstamps(skb); + int if_index; + ktime_t hwtstamp; + u32 tsflags; + + /* Race occurred between timestamp enabling and packet + receiving. Fill in the current time for now. 
*/ + if (need_software_tstamp && skb->tstamp == 0) { + __net_timestamp(skb); + false_tstamp = 1; + } + + if (need_software_tstamp) { + if (!sock_flag(sk, SOCK_RCVTSTAMPNS)) { + if (new_tstamp) { + struct __kernel_sock_timeval tv; + + skb_get_new_timestamp(skb, &tv); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_NEW, + sizeof(tv), &tv); + } else { + struct __kernel_old_timeval tv; + + skb_get_timestamp(skb, &tv); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMP_OLD, + sizeof(tv), &tv); + } + } else { + if (new_tstamp) { + struct __kernel_timespec ts; + + skb_get_new_timestampns(skb, &ts); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_NEW, + sizeof(ts), &ts); + } else { + struct __kernel_old_timespec ts; + + skb_get_timestampns(skb, &ts); + put_cmsg(msg, SOL_SOCKET, SO_TIMESTAMPNS_OLD, + sizeof(ts), &ts); + } + } + } + + memset(&tss, 0, sizeof(tss)); + tsflags = READ_ONCE(sk->sk_tsflags); + if ((tsflags & SOF_TIMESTAMPING_SOFTWARE) && + ktime_to_timespec64_cond(skb->tstamp, tss.ts + 0)) + empty = 0; + if (shhwtstamps && + (tsflags & SOF_TIMESTAMPING_RAW_HARDWARE) && + !skb_is_swtx_tstamp(skb, false_tstamp)) { + if_index = 0; + if (skb_shinfo(skb)->tx_flags & SKBTX_HW_TSTAMP_NETDEV) + hwtstamp = get_timestamp(sk, skb, &if_index); + else + hwtstamp = shhwtstamps->hwtstamp; + + if (tsflags & SOF_TIMESTAMPING_BIND_PHC) + hwtstamp = ptp_convert_timestamp(&hwtstamp, + READ_ONCE(sk->sk_bind_phc)); + + if (ktime_to_timespec64_cond(hwtstamp, tss.ts + 2)) { + empty = 0; + + if ((tsflags & SOF_TIMESTAMPING_OPT_PKTINFO) && + !skb_is_err_queue(skb)) + put_ts_pktinfo(msg, skb, if_index); + } + } + if (!empty) { + if (sock_flag(sk, SOCK_TSTAMP_NEW)) + put_cmsg_scm_timestamping64(msg, &tss); + else + put_cmsg_scm_timestamping(msg, &tss); + + if (skb_is_err_queue(skb) && skb->len && + SKB_EXT_ERR(skb)->opt_stats) + put_cmsg(msg, SOL_SOCKET, SCM_TIMESTAMPING_OPT_STATS, + skb->len, skb->data); + } +} +EXPORT_SYMBOL_GPL(__sock_recv_timestamp); + +void __sock_recv_wifi_status(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + int ack; + + if (!sock_flag(sk, SOCK_WIFI_STATUS)) + return; + if (!skb->wifi_acked_valid) + return; + + ack = skb->wifi_acked; + + put_cmsg(msg, SOL_SOCKET, SCM_WIFI_STATUS, sizeof(ack), &ack); +} +EXPORT_SYMBOL_GPL(__sock_recv_wifi_status); + +static inline void sock_recv_drops(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + if (sock_flag(sk, SOCK_RXQ_OVFL) && skb && SOCK_SKB_CB(skb)->dropcount) + put_cmsg(msg, SOL_SOCKET, SO_RXQ_OVFL, + sizeof(__u32), &SOCK_SKB_CB(skb)->dropcount); +} + +static void sock_recv_mark(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + if (sock_flag(sk, SOCK_RCVMARK) && skb) { + /* We must use a bounce buffer for CONFIG_HARDENED_USERCOPY=y */ + __u32 mark = skb->mark; + + put_cmsg(msg, SOL_SOCKET, SO_MARK, sizeof(__u32), &mark); + } +} + +void __sock_recv_cmsgs(struct msghdr *msg, struct sock *sk, + struct sk_buff *skb) +{ + sock_recv_timestamp(msg, sk, skb); + sock_recv_drops(msg, sk, skb); + sock_recv_mark(msg, sk, skb); +} +EXPORT_SYMBOL_GPL(__sock_recv_cmsgs); + +INDIRECT_CALLABLE_DECLARE(int inet_recvmsg(struct socket *, struct msghdr *, + size_t, int)); +INDIRECT_CALLABLE_DECLARE(int inet6_recvmsg(struct socket *, struct msghdr *, + size_t, int)); +static inline int sock_recvmsg_nosec(struct socket *sock, struct msghdr *msg, + int flags) +{ + return INDIRECT_CALL_INET(sock->ops->recvmsg, inet6_recvmsg, + inet_recvmsg, sock, msg, msg_data_left(msg), + flags); +} + +/** + * sock_recvmsg - receive a message from @sock + * 
@sock: socket + * @msg: message to receive + * @flags: message flags + * + * Receives @msg from @sock, passing through LSM. Returns the total number + * of bytes received, or an error. + */ +int sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags) +{ + int err = security_socket_recvmsg(sock, msg, msg_data_left(msg), flags); + + return err ?: sock_recvmsg_nosec(sock, msg, flags); +} +EXPORT_SYMBOL(sock_recvmsg); + +/** + * kernel_recvmsg - Receive a message from a socket (kernel space) + * @sock: The socket to receive the message from + * @msg: Received message + * @vec: Input s/g array for message data + * @num: Size of input s/g array + * @size: Number of bytes to read + * @flags: Message flags (MSG_DONTWAIT, etc...) + * + * On return the msg structure contains the scatter/gather array passed in the + * vec argument. The array is modified so that it consists of the unfilled + * portion of the original array. + * + * The returned value is the total number of bytes received, or an error. + */ + +int kernel_recvmsg(struct socket *sock, struct msghdr *msg, + struct kvec *vec, size_t num, size_t size, int flags) +{ + msg->msg_control_is_user = false; + iov_iter_kvec(&msg->msg_iter, ITER_DEST, vec, num, size); + return sock_recvmsg(sock, msg, flags); +} +EXPORT_SYMBOL(kernel_recvmsg); + +static ssize_t sock_sendpage(struct file *file, struct page *page, + int offset, size_t size, loff_t *ppos, int more) +{ + struct socket *sock; + int flags; + + sock = file->private_data; + + flags = (file->f_flags & O_NONBLOCK) ? MSG_DONTWAIT : 0; + /* more is a combination of MSG_MORE and MSG_SENDPAGE_NOTLAST */ + flags |= more; + + return kernel_sendpage(sock, page, offset, size, flags); +} + +static ssize_t sock_splice_read(struct file *file, loff_t *ppos, + struct pipe_inode_info *pipe, size_t len, + unsigned int flags) +{ + struct socket *sock = file->private_data; + + if (unlikely(!sock->ops->splice_read)) + return generic_file_splice_read(file, ppos, pipe, len, flags); + + return sock->ops->splice_read(sock, ppos, pipe, len, flags); +} + +static void sock_splice_eof(struct file *file) +{ + struct socket *sock = file->private_data; + + if (sock->ops->splice_eof) + sock->ops->splice_eof(sock); +} + +static ssize_t sock_read_iter(struct kiocb *iocb, struct iov_iter *to) +{ + struct file *file = iocb->ki_filp; + struct socket *sock = file->private_data; + struct msghdr msg = {.msg_iter = *to, + .msg_iocb = iocb}; + ssize_t res; + + if (file->f_flags & O_NONBLOCK || (iocb->ki_flags & IOCB_NOWAIT)) + msg.msg_flags = MSG_DONTWAIT; + + if (iocb->ki_pos != 0) + return -ESPIPE; + + if (!iov_iter_count(to)) /* Match SYS5 behaviour */ + return 0; + + res = sock_recvmsg(sock, &msg, msg.msg_flags); + *to = msg.msg_iter; + return res; +} + +static ssize_t sock_write_iter(struct kiocb *iocb, struct iov_iter *from) +{ + struct file *file = iocb->ki_filp; + struct socket *sock = file->private_data; + struct msghdr msg = {.msg_iter = *from, + .msg_iocb = iocb}; + ssize_t res; + + if (iocb->ki_pos != 0) + return -ESPIPE; + + if (file->f_flags & O_NONBLOCK || (iocb->ki_flags & IOCB_NOWAIT)) + msg.msg_flags = MSG_DONTWAIT; + + if (sock->type == SOCK_SEQPACKET) + msg.msg_flags |= MSG_EOR; + + res = __sock_sendmsg(sock, &msg); + *from = msg.msg_iter; + return res; +} + +/* + * Atomic setting of ioctl hooks to avoid race + * with module unload. 
+ */ + +static DEFINE_MUTEX(br_ioctl_mutex); +static int (*br_ioctl_hook)(struct net *net, struct net_bridge *br, + unsigned int cmd, struct ifreq *ifr, + void __user *uarg); + +void brioctl_set(int (*hook)(struct net *net, struct net_bridge *br, + unsigned int cmd, struct ifreq *ifr, + void __user *uarg)) +{ + mutex_lock(&br_ioctl_mutex); + br_ioctl_hook = hook; + mutex_unlock(&br_ioctl_mutex); +} +EXPORT_SYMBOL(brioctl_set); + +int br_ioctl_call(struct net *net, struct net_bridge *br, unsigned int cmd, + struct ifreq *ifr, void __user *uarg) +{ + int err = -ENOPKG; + + if (!br_ioctl_hook) + request_module("bridge"); + + mutex_lock(&br_ioctl_mutex); + if (br_ioctl_hook) + err = br_ioctl_hook(net, br, cmd, ifr, uarg); + mutex_unlock(&br_ioctl_mutex); + + return err; +} + +static DEFINE_MUTEX(vlan_ioctl_mutex); +static int (*vlan_ioctl_hook) (struct net *, void __user *arg); + +void vlan_ioctl_set(int (*hook) (struct net *, void __user *)) +{ + mutex_lock(&vlan_ioctl_mutex); + vlan_ioctl_hook = hook; + mutex_unlock(&vlan_ioctl_mutex); +} +EXPORT_SYMBOL(vlan_ioctl_set); + +static long sock_do_ioctl(struct net *net, struct socket *sock, + unsigned int cmd, unsigned long arg) +{ + struct ifreq ifr; + bool need_copyout; + int err; + void __user *argp = (void __user *)arg; + void __user *data; + + err = sock->ops->ioctl(sock, cmd, arg); + + /* + * If this ioctl is unknown try to hand it down + * to the NIC driver. + */ + if (err != -ENOIOCTLCMD) + return err; + + if (!is_socket_ioctl_cmd(cmd)) + return -ENOTTY; + + if (get_user_ifreq(&ifr, &data, argp)) + return -EFAULT; + err = dev_ioctl(net, cmd, &ifr, data, &need_copyout); + if (!err && need_copyout) + if (put_user_ifreq(&ifr, argp)) + return -EFAULT; + + return err; +} + +/* + * With an ioctl, arg may well be a user mode pointer, but we don't know + * what to do with it - that's up to the protocol still. 
+ */ + +static long sock_ioctl(struct file *file, unsigned cmd, unsigned long arg) +{ + struct socket *sock; + struct sock *sk; + void __user *argp = (void __user *)arg; + int pid, err; + struct net *net; + + sock = file->private_data; + sk = sock->sk; + net = sock_net(sk); + if (unlikely(cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15))) { + struct ifreq ifr; + void __user *data; + bool need_copyout; + if (get_user_ifreq(&ifr, &data, argp)) + return -EFAULT; + err = dev_ioctl(net, cmd, &ifr, data, &need_copyout); + if (!err && need_copyout) + if (put_user_ifreq(&ifr, argp)) + return -EFAULT; + } else +#ifdef CONFIG_WEXT_CORE + if (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST) { + err = wext_handle_ioctl(net, cmd, argp); + } else +#endif + switch (cmd) { + case FIOSETOWN: + case SIOCSPGRP: + err = -EFAULT; + if (get_user(pid, (int __user *)argp)) + break; + err = f_setown(sock->file, pid, 1); + break; + case FIOGETOWN: + case SIOCGPGRP: + err = put_user(f_getown(sock->file), + (int __user *)argp); + break; + case SIOCGIFBR: + case SIOCSIFBR: + case SIOCBRADDBR: + case SIOCBRDELBR: + err = br_ioctl_call(net, NULL, cmd, NULL, argp); + break; + case SIOCGIFVLAN: + case SIOCSIFVLAN: + err = -ENOPKG; + if (!vlan_ioctl_hook) + request_module("8021q"); + + mutex_lock(&vlan_ioctl_mutex); + if (vlan_ioctl_hook) + err = vlan_ioctl_hook(net, argp); + mutex_unlock(&vlan_ioctl_mutex); + break; + case SIOCGSKNS: + err = -EPERM; + if (!ns_capable(net->user_ns, CAP_NET_ADMIN)) + break; + + err = open_related_ns(&net->ns, get_net_ns); + break; + case SIOCGSTAMP_OLD: + case SIOCGSTAMPNS_OLD: + if (!sock->ops->gettstamp) { + err = -ENOIOCTLCMD; + break; + } + err = sock->ops->gettstamp(sock, argp, + cmd == SIOCGSTAMP_OLD, + !IS_ENABLED(CONFIG_64BIT)); + break; + case SIOCGSTAMP_NEW: + case SIOCGSTAMPNS_NEW: + if (!sock->ops->gettstamp) { + err = -ENOIOCTLCMD; + break; + } + err = sock->ops->gettstamp(sock, argp, + cmd == SIOCGSTAMP_NEW, + false); + break; + + case SIOCGIFCONF: + err = dev_ifconf(net, argp); + break; + + default: + err = sock_do_ioctl(net, sock, cmd, arg); + break; + } + return err; +} + +/** + * sock_create_lite - creates a socket + * @family: protocol family (AF_INET, ...) + * @type: communication type (SOCK_STREAM, ...) + * @protocol: protocol (0, ...) + * @res: new socket + * + * Creates a new socket and assigns it to @res, passing through LSM. + * The new socket initialization is not complete, see kernel_accept(). + * Returns 0 or an error. On failure @res is set to %NULL. + * This function internally uses GFP_KERNEL. 
+ */ + +int sock_create_lite(int family, int type, int protocol, struct socket **res) +{ + int err; + struct socket *sock = NULL; + + err = security_socket_create(family, type, protocol, 1); + if (err) + goto out; + + sock = sock_alloc(); + if (!sock) { + err = -ENOMEM; + goto out; + } + + sock->type = type; + err = security_socket_post_create(sock, family, type, protocol, 1); + if (err) + goto out_release; + +out: + *res = sock; + return err; +out_release: + sock_release(sock); + sock = NULL; + goto out; +} +EXPORT_SYMBOL(sock_create_lite); + +/* No kernel lock held - perfect */ +static __poll_t sock_poll(struct file *file, poll_table *wait) +{ + struct socket *sock = file->private_data; + __poll_t events = poll_requested_events(wait), flag = 0; + + if (!sock->ops->poll) + return 0; + + if (sk_can_busy_loop(sock->sk)) { + /* poll once if requested by the syscall */ + if (events & POLL_BUSY_LOOP) + sk_busy_loop(sock->sk, 1); + + /* if this socket can poll_ll, tell the system call */ + flag = POLL_BUSY_LOOP; + } + + return sock->ops->poll(file, sock, wait) | flag; +} + +static int sock_mmap(struct file *file, struct vm_area_struct *vma) +{ + struct socket *sock = file->private_data; + + return sock->ops->mmap(file, sock, vma); +} + +static int sock_close(struct inode *inode, struct file *filp) +{ + __sock_release(SOCKET_I(inode), inode); + return 0; +} + +/* + * Update the socket async list + * + * Fasync_list locking strategy. + * + * 1. fasync_list is modified only under process context socket lock + * i.e. under semaphore. + * 2. fasync_list is used under read_lock(&sk->sk_callback_lock) + * or under socket lock + */ + +static int sock_fasync(int fd, struct file *filp, int on) +{ + struct socket *sock = filp->private_data; + struct sock *sk = sock->sk; + struct socket_wq *wq = &sock->wq; + + if (sk == NULL) + return -EINVAL; + + lock_sock(sk); + fasync_helper(fd, filp, on, &wq->fasync_list); + + if (!wq->fasync_list) + sock_reset_flag(sk, SOCK_FASYNC); + else + sock_set_flag(sk, SOCK_FASYNC); + + release_sock(sk); + return 0; +} + +/* This function may be called only under rcu_lock */ + +int sock_wake_async(struct socket_wq *wq, int how, int band) +{ + if (!wq || !wq->fasync_list) + return -1; + + switch (how) { + case SOCK_WAKE_WAITD: + if (test_bit(SOCKWQ_ASYNC_WAITDATA, &wq->flags)) + break; + goto call_kill; + case SOCK_WAKE_SPACE: + if (!test_and_clear_bit(SOCKWQ_ASYNC_NOSPACE, &wq->flags)) + break; + fallthrough; + case SOCK_WAKE_IO: +call_kill: + kill_fasync(&wq->fasync_list, SIGIO, band); + break; + case SOCK_WAKE_URG: + kill_fasync(&wq->fasync_list, SIGURG, band); + } + + return 0; +} +EXPORT_SYMBOL(sock_wake_async); + +/** + * __sock_create - creates a socket + * @net: net namespace + * @family: protocol family (AF_INET, ...) + * @type: communication type (SOCK_STREAM, ...) + * @protocol: protocol (0, ...) + * @res: new socket + * @kern: boolean for kernel space sockets + * + * Creates a new socket and assigns it to @res, passing through LSM. + * Returns 0 or an error. On failure @res is set to %NULL. @kern must + * be set to true if the socket resides in kernel space. + * This function internally uses GFP_KERNEL. + */ + +int __sock_create(struct net *net, int family, int type, int protocol, + struct socket **res, int kern) +{ + int err; + struct socket *sock; + const struct net_proto_family *pf; + + /* + * Check protocol is in range + */ + if (family < 0 || family >= NPROTO) + return -EAFNOSUPPORT; + if (type < 0 || type >= SOCK_MAX) + return -EINVAL; + + /* Compatibility. 
+ + This uglymoron is moved from INET layer to here to avoid + deadlock in module load. + */ + if (family == PF_INET && type == SOCK_PACKET) { + pr_info_once("%s uses obsolete (PF_INET,SOCK_PACKET)\n", + current->comm); + family = PF_PACKET; + } + + err = security_socket_create(family, type, protocol, kern); + if (err) + return err; + + /* + * Allocate the socket and allow the family to set things up. if + * the protocol is 0, the family is instructed to select an appropriate + * default. + */ + sock = sock_alloc(); + if (!sock) { + net_warn_ratelimited("socket: no more sockets\n"); + return -ENFILE; /* Not exactly a match, but its the + closest posix thing */ + } + + sock->type = type; + +#ifdef CONFIG_MODULES + /* Attempt to load a protocol module if the find failed. + * + * 12/09/1996 Marcin: But! this makes REALLY only sense, if the user + * requested real, full-featured networking support upon configuration. + * Otherwise module support will break! + */ + if (rcu_access_pointer(net_families[family]) == NULL) + request_module("net-pf-%d", family); +#endif + + rcu_read_lock(); + pf = rcu_dereference(net_families[family]); + err = -EAFNOSUPPORT; + if (!pf) + goto out_release; + + /* + * We will call the ->create function, that possibly is in a loadable + * module, so we have to bump that loadable module refcnt first. + */ + if (!try_module_get(pf->owner)) + goto out_release; + + /* Now protected by module ref count */ + rcu_read_unlock(); + + err = pf->create(net, sock, protocol, kern); + if (err < 0) + goto out_module_put; + + /* + * Now to bump the refcnt of the [loadable] module that owns this + * socket at sock_release time we decrement its refcnt. + */ + if (!try_module_get(sock->ops->owner)) + goto out_module_busy; + + /* + * Now that we're done with the ->create function, the [loadable] + * module can have its refcnt decremented + */ + module_put(pf->owner); + err = security_socket_post_create(sock, family, type, protocol, kern); + if (err) + goto out_sock_release; + *res = sock; + + return 0; + +out_module_busy: + err = -EAFNOSUPPORT; +out_module_put: + sock->ops = NULL; + module_put(pf->owner); +out_sock_release: + sock_release(sock); + return err; + +out_release: + rcu_read_unlock(); + goto out_sock_release; +} +EXPORT_SYMBOL(__sock_create); + +/** + * sock_create - creates a socket + * @family: protocol family (AF_INET, ...) + * @type: communication type (SOCK_STREAM, ...) + * @protocol: protocol (0, ...) + * @res: new socket + * + * A wrapper around __sock_create(). + * Returns 0 or an error. This function internally uses GFP_KERNEL. + */ + +int sock_create(int family, int type, int protocol, struct socket **res) +{ + return __sock_create(current->nsproxy->net_ns, family, type, protocol, res, 0); +} +EXPORT_SYMBOL(sock_create); + +/** + * sock_create_kern - creates a socket (kernel space) + * @net: net namespace + * @family: protocol family (AF_INET, ...) + * @type: communication type (SOCK_STREAM, ...) + * @protocol: protocol (0, ...) + * @res: new socket + * + * A wrapper around __sock_create(). + * Returns 0 or an error. This function internally uses GFP_KERNEL. + */ + +int sock_create_kern(struct net *net, int family, int type, int protocol, struct socket **res) +{ + return __sock_create(net, family, type, protocol, res, 1); +} +EXPORT_SYMBOL(sock_create_kern); + +static struct socket *__sys_socket_create(int family, int type, int protocol) +{ + struct socket *sock; + int retval; + + /* Check the SOCK_* constants for consistency. 
*/ + BUILD_BUG_ON(SOCK_CLOEXEC != O_CLOEXEC); + BUILD_BUG_ON((SOCK_MAX | SOCK_TYPE_MASK) != SOCK_TYPE_MASK); + BUILD_BUG_ON(SOCK_CLOEXEC & SOCK_TYPE_MASK); + BUILD_BUG_ON(SOCK_NONBLOCK & SOCK_TYPE_MASK); + + if ((type & ~SOCK_TYPE_MASK) & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) + return ERR_PTR(-EINVAL); + type &= SOCK_TYPE_MASK; + + retval = sock_create(family, type, protocol, &sock); + if (retval < 0) + return ERR_PTR(retval); + + return sock; +} + +struct file *__sys_socket_file(int family, int type, int protocol) +{ + struct socket *sock; + int flags; + + sock = __sys_socket_create(family, type, protocol); + if (IS_ERR(sock)) + return ERR_CAST(sock); + + flags = type & ~SOCK_TYPE_MASK; + if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) + flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; + + return sock_alloc_file(sock, flags, NULL); +} + +int __sys_socket(int family, int type, int protocol) +{ + struct socket *sock; + int flags; + + sock = __sys_socket_create(family, type, protocol); + if (IS_ERR(sock)) + return PTR_ERR(sock); + + flags = type & ~SOCK_TYPE_MASK; + if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) + flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; + + return sock_map_fd(sock, flags & (O_CLOEXEC | O_NONBLOCK)); +} + +SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol) +{ + return __sys_socket(family, type, protocol); +} + +/* + * Create a pair of connected sockets. + */ + +int __sys_socketpair(int family, int type, int protocol, int __user *usockvec) +{ + struct socket *sock1, *sock2; + int fd1, fd2, err; + struct file *newfile1, *newfile2; + int flags; + + flags = type & ~SOCK_TYPE_MASK; + if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) + return -EINVAL; + type &= SOCK_TYPE_MASK; + + if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) + flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; + + /* + * reserve descriptors and make sure we won't fail + * to return them to userland. + */ + fd1 = get_unused_fd_flags(flags); + if (unlikely(fd1 < 0)) + return fd1; + + fd2 = get_unused_fd_flags(flags); + if (unlikely(fd2 < 0)) { + put_unused_fd(fd1); + return fd2; + } + + err = put_user(fd1, &usockvec[0]); + if (err) + goto out; + + err = put_user(fd2, &usockvec[1]); + if (err) + goto out; + + /* + * Obtain the first socket and check if the underlying protocol + * supports the socketpair call. + */ + + err = sock_create(family, type, protocol, &sock1); + if (unlikely(err < 0)) + goto out; + + err = sock_create(family, type, protocol, &sock2); + if (unlikely(err < 0)) { + sock_release(sock1); + goto out; + } + + err = security_socket_socketpair(sock1, sock2); + if (unlikely(err)) { + sock_release(sock2); + sock_release(sock1); + goto out; + } + + err = sock1->ops->socketpair(sock1, sock2); + if (unlikely(err < 0)) { + sock_release(sock2); + sock_release(sock1); + goto out; + } + + newfile1 = sock_alloc_file(sock1, flags, NULL); + if (IS_ERR(newfile1)) { + err = PTR_ERR(newfile1); + sock_release(sock2); + goto out; + } + + newfile2 = sock_alloc_file(sock2, flags, NULL); + if (IS_ERR(newfile2)) { + err = PTR_ERR(newfile2); + fput(newfile1); + goto out; + } + + audit_fd_pair(fd1, fd2); + + fd_install(fd1, newfile1); + fd_install(fd2, newfile2); + return 0; + +out: + put_unused_fd(fd2); + put_unused_fd(fd1); + return err; +} + +SYSCALL_DEFINE4(socketpair, int, family, int, type, int, protocol, + int __user *, usockvec) +{ + return __sys_socketpair(family, type, protocol, usockvec); +} + +/* + * Bind a name to a socket. 
Nothing much to do here since it's + * the protocol's responsibility to handle the local address. + * + * We move the socket address to kernel space before we call + * the protocol layer (having also checked the address is ok). + */ + +int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen) +{ + struct socket *sock; + struct sockaddr_storage address; + int err, fput_needed; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (sock) { + err = move_addr_to_kernel(umyaddr, addrlen, &address); + if (!err) { + err = security_socket_bind(sock, + (struct sockaddr *)&address, + addrlen); + if (!err) + err = sock->ops->bind(sock, + (struct sockaddr *) + &address, addrlen); + } + fput_light(sock->file, fput_needed); + } + return err; +} + +SYSCALL_DEFINE3(bind, int, fd, struct sockaddr __user *, umyaddr, int, addrlen) +{ + return __sys_bind(fd, umyaddr, addrlen); +} + +/* + * Perform a listen. Basically, we allow the protocol to do anything + * necessary for a listen, and if that works, we mark the socket as + * ready for listening. + */ + +int __sys_listen(int fd, int backlog) +{ + struct socket *sock; + int err, fput_needed; + int somaxconn; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (sock) { + somaxconn = READ_ONCE(sock_net(sock->sk)->core.sysctl_somaxconn); + if ((unsigned int)backlog > somaxconn) + backlog = somaxconn; + + err = security_socket_listen(sock, backlog); + if (!err) + err = sock->ops->listen(sock, backlog); + + fput_light(sock->file, fput_needed); + } + return err; +} + +SYSCALL_DEFINE2(listen, int, fd, int, backlog) +{ + return __sys_listen(fd, backlog); +} + +struct file *do_accept(struct file *file, unsigned file_flags, + struct sockaddr __user *upeer_sockaddr, + int __user *upeer_addrlen, int flags) +{ + struct socket *sock, *newsock; + struct file *newfile; + int err, len; + struct sockaddr_storage address; + + sock = sock_from_file(file); + if (!sock) + return ERR_PTR(-ENOTSOCK); + + newsock = sock_alloc(); + if (!newsock) + return ERR_PTR(-ENFILE); + + newsock->type = sock->type; + newsock->ops = sock->ops; + + /* + * We don't need try_module_get here, as the listening socket (sock) + * has the protocol module (sock->ops->owner) held. + */ + __module_get(newsock->ops->owner); + + newfile = sock_alloc_file(newsock, flags, sock->sk->sk_prot_creator->name); + if (IS_ERR(newfile)) + return newfile; + + err = security_socket_accept(sock, newsock); + if (err) + goto out_fd; + + err = sock->ops->accept(sock, newsock, sock->file->f_flags | file_flags, + false); + if (err < 0) + goto out_fd; + + if (upeer_sockaddr) { + len = newsock->ops->getname(newsock, + (struct sockaddr *)&address, 2); + if (len < 0) { + err = -ECONNABORTED; + goto out_fd; + } + err = move_addr_to_user(&address, + len, upeer_sockaddr, upeer_addrlen); + if (err < 0) + goto out_fd; + } + + /* File flags are not inherited via accept() unlike another OSes. 
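+ * A listener opened with O_NONBLOCK therefore hands out blocking connected
+ * sockets; callers that want a non-blocking fd ask for it explicitly, e.g.
+ * accept4(fd, NULL, NULL, SOCK_NONBLOCK).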
*/ + return newfile; +out_fd: + fput(newfile); + return ERR_PTR(err); +} + +static int __sys_accept4_file(struct file *file, struct sockaddr __user *upeer_sockaddr, + int __user *upeer_addrlen, int flags) +{ + struct file *newfile; + int newfd; + + if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK)) + return -EINVAL; + + if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK)) + flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK; + + newfd = get_unused_fd_flags(flags); + if (unlikely(newfd < 0)) + return newfd; + + newfile = do_accept(file, 0, upeer_sockaddr, upeer_addrlen, + flags); + if (IS_ERR(newfile)) { + put_unused_fd(newfd); + return PTR_ERR(newfile); + } + fd_install(newfd, newfile); + return newfd; +} + +/* + * For accept, we attempt to create a new socket, set up the link + * with the client, wake up the client, then return the new + * connected fd. We collect the address of the connector in kernel + * space and move it to user at the very end. This is unclean because + * we open the socket then return an error. + * + * 1003.1g adds the ability to recvmsg() to query connection pending + * status to recvmsg. We need to add that support in a way thats + * clean when we restructure accept also. + */ + +int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr, + int __user *upeer_addrlen, int flags) +{ + int ret = -EBADF; + struct fd f; + + f = fdget(fd); + if (f.file) { + ret = __sys_accept4_file(f.file, upeer_sockaddr, + upeer_addrlen, flags); + fdput(f); + } + + return ret; +} + +SYSCALL_DEFINE4(accept4, int, fd, struct sockaddr __user *, upeer_sockaddr, + int __user *, upeer_addrlen, int, flags) +{ + return __sys_accept4(fd, upeer_sockaddr, upeer_addrlen, flags); +} + +SYSCALL_DEFINE3(accept, int, fd, struct sockaddr __user *, upeer_sockaddr, + int __user *, upeer_addrlen) +{ + return __sys_accept4(fd, upeer_sockaddr, upeer_addrlen, 0); +} + +/* + * Attempt to connect to a socket with the server address. The address + * is in user space so we verify it is OK and move it to kernel space. + * + * For 1003.1g we need to add clean support for a bind to AF_UNSPEC to + * break bindings + * + * NOTE: 1003.1g draft 6.3 is broken with respect to AX.25/NetROM and + * other SEQPACKET protocols that take time to connect() as it doesn't + * include the -EINPROGRESS status for such sockets. + */ + +int __sys_connect_file(struct file *file, struct sockaddr_storage *address, + int addrlen, int file_flags) +{ + struct socket *sock; + int err; + + sock = sock_from_file(file); + if (!sock) { + err = -ENOTSOCK; + goto out; + } + + err = + security_socket_connect(sock, (struct sockaddr *)address, addrlen); + if (err) + goto out; + + err = sock->ops->connect(sock, (struct sockaddr *)address, addrlen, + sock->file->f_flags | file_flags); +out: + return err; +} + +int __sys_connect(int fd, struct sockaddr __user *uservaddr, int addrlen) +{ + int ret = -EBADF; + struct fd f; + + f = fdget(fd); + if (f.file) { + struct sockaddr_storage address; + + ret = move_addr_to_kernel(uservaddr, addrlen, &address); + if (!ret) + ret = __sys_connect_file(f.file, &address, addrlen, 0); + fdput(f); + } + + return ret; +} + +SYSCALL_DEFINE3(connect, int, fd, struct sockaddr __user *, uservaddr, + int, addrlen) +{ + return __sys_connect(fd, uservaddr, addrlen); +} + +/* + * Get the local address ('name') of a socket object. Move the obtained + * name to user space. 
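+ *
+ * This backs the plain userspace call, e.g.
+ *	struct sockaddr_storage ss;
+ *	socklen_t len = sizeof(ss);
+ *	getsockname(fd, (struct sockaddr *)&ss, &len);
+ * where at most the caller-supplied length is copied back.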
+ */ + +int __sys_getsockname(int fd, struct sockaddr __user *usockaddr, + int __user *usockaddr_len) +{ + struct socket *sock; + struct sockaddr_storage address; + int err, fput_needed; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + goto out; + + err = security_socket_getsockname(sock); + if (err) + goto out_put; + + err = sock->ops->getname(sock, (struct sockaddr *)&address, 0); + if (err < 0) + goto out_put; + /* "err" is actually length in this case */ + err = move_addr_to_user(&address, err, usockaddr, usockaddr_len); + +out_put: + fput_light(sock->file, fput_needed); +out: + return err; +} + +SYSCALL_DEFINE3(getsockname, int, fd, struct sockaddr __user *, usockaddr, + int __user *, usockaddr_len) +{ + return __sys_getsockname(fd, usockaddr, usockaddr_len); +} + +/* + * Get the remote address ('name') of a socket object. Move the obtained + * name to user space. + */ + +int __sys_getpeername(int fd, struct sockaddr __user *usockaddr, + int __user *usockaddr_len) +{ + struct socket *sock; + struct sockaddr_storage address; + int err, fput_needed; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (sock != NULL) { + err = security_socket_getpeername(sock); + if (err) { + fput_light(sock->file, fput_needed); + return err; + } + + err = sock->ops->getname(sock, (struct sockaddr *)&address, 1); + if (err >= 0) + /* "err" is actually length in this case */ + err = move_addr_to_user(&address, err, usockaddr, + usockaddr_len); + fput_light(sock->file, fput_needed); + } + return err; +} + +SYSCALL_DEFINE3(getpeername, int, fd, struct sockaddr __user *, usockaddr, + int __user *, usockaddr_len) +{ + return __sys_getpeername(fd, usockaddr, usockaddr_len); +} + +/* + * Send a datagram to a given address. We move the address into kernel + * space and check the user space data area is readable before invoking + * the protocol. + */ +int __sys_sendto(int fd, void __user *buff, size_t len, unsigned int flags, + struct sockaddr __user *addr, int addr_len) +{ + struct socket *sock; + struct sockaddr_storage address; + int err; + struct msghdr msg; + struct iovec iov; + int fput_needed; + + err = import_single_range(ITER_SOURCE, buff, len, &iov, &msg.msg_iter); + if (unlikely(err)) + return err; + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + goto out; + + msg.msg_name = NULL; + msg.msg_control = NULL; + msg.msg_controllen = 0; + msg.msg_namelen = 0; + msg.msg_ubuf = NULL; + if (addr) { + err = move_addr_to_kernel(addr, addr_len, &address); + if (err < 0) + goto out_put; + msg.msg_name = (struct sockaddr *)&address; + msg.msg_namelen = addr_len; + } + flags &= ~MSG_INTERNAL_SENDMSG_FLAGS; + if (sock->file->f_flags & O_NONBLOCK) + flags |= MSG_DONTWAIT; + msg.msg_flags = flags; + err = __sock_sendmsg(sock, &msg); + +out_put: + fput_light(sock->file, fput_needed); +out: + return err; +} + +SYSCALL_DEFINE6(sendto, int, fd, void __user *, buff, size_t, len, + unsigned int, flags, struct sockaddr __user *, addr, + int, addr_len) +{ + return __sys_sendto(fd, buff, len, flags, addr, addr_len); +} + +/* + * Send a datagram down a socket. + */ + +SYSCALL_DEFINE4(send, int, fd, void __user *, buff, size_t, len, + unsigned int, flags) +{ + return __sys_sendto(fd, buff, len, flags, NULL, 0); +} + +/* + * Receive a frame from the socket and optionally record the address of the + * sender. We verify the buffers are writable and if needed move the + * sender address from kernel to user space. 
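+ *
+ * The addr/addr_len pair is optional: passing NULL for both gives plain
+ * recv() semantics, e.g. recvfrom(fd, buf, sizeof(buf), 0, NULL, NULL).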
+ */ +int __sys_recvfrom(int fd, void __user *ubuf, size_t size, unsigned int flags, + struct sockaddr __user *addr, int __user *addr_len) +{ + struct sockaddr_storage address; + struct msghdr msg = { + /* Save some cycles and don't copy the address if not needed */ + .msg_name = addr ? (struct sockaddr *)&address : NULL, + }; + struct socket *sock; + struct iovec iov; + int err, err2; + int fput_needed; + + err = import_single_range(ITER_DEST, ubuf, size, &iov, &msg.msg_iter); + if (unlikely(err)) + return err; + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + goto out; + + if (sock->file->f_flags & O_NONBLOCK) + flags |= MSG_DONTWAIT; + err = sock_recvmsg(sock, &msg, flags); + + if (err >= 0 && addr != NULL) { + err2 = move_addr_to_user(&address, + msg.msg_namelen, addr, addr_len); + if (err2 < 0) + err = err2; + } + + fput_light(sock->file, fput_needed); +out: + return err; +} + +SYSCALL_DEFINE6(recvfrom, int, fd, void __user *, ubuf, size_t, size, + unsigned int, flags, struct sockaddr __user *, addr, + int __user *, addr_len) +{ + return __sys_recvfrom(fd, ubuf, size, flags, addr, addr_len); +} + +/* + * Receive a datagram from a socket. + */ + +SYSCALL_DEFINE4(recv, int, fd, void __user *, ubuf, size_t, size, + unsigned int, flags) +{ + return __sys_recvfrom(fd, ubuf, size, flags, NULL, NULL); +} + +static bool sock_use_custom_sol_socket(const struct socket *sock) +{ + const struct sock *sk = sock->sk; + + /* Use sock->ops->setsockopt() for MPTCP */ + return IS_ENABLED(CONFIG_MPTCP) && + sk->sk_protocol == IPPROTO_MPTCP && + sk->sk_type == SOCK_STREAM && + (sk->sk_family == AF_INET || sk->sk_family == AF_INET6); +} + +/* + * Set a socket option. Because we don't know the option lengths we have + * to pass the user mode parameter for the protocols to sort out. + */ +int __sys_setsockopt(int fd, int level, int optname, char __user *user_optval, + int optlen) +{ + sockptr_t optval = USER_SOCKPTR(user_optval); + char *kernel_optval = NULL; + int err, fput_needed; + struct socket *sock; + + if (optlen < 0) + return -EINVAL; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + return err; + + err = security_socket_setsockopt(sock, level, optname); + if (err) + goto out_put; + + if (!in_compat_syscall()) + err = BPF_CGROUP_RUN_PROG_SETSOCKOPT(sock->sk, &level, &optname, + user_optval, &optlen, + &kernel_optval); + if (err < 0) + goto out_put; + if (err > 0) { + err = 0; + goto out_put; + } + + if (kernel_optval) + optval = KERNEL_SOCKPTR(kernel_optval); + if (level == SOL_SOCKET && !sock_use_custom_sol_socket(sock)) + err = sock_setsockopt(sock, level, optname, optval, optlen); + else if (unlikely(!sock->ops->setsockopt)) + err = -EOPNOTSUPP; + else + err = sock->ops->setsockopt(sock, level, optname, optval, + optlen); + kfree(kernel_optval); +out_put: + fput_light(sock->file, fput_needed); + return err; +} + +SYSCALL_DEFINE5(setsockopt, int, fd, int, level, int, optname, + char __user *, optval, int, optlen) +{ + return __sys_setsockopt(fd, level, optname, optval, optlen); +} + +INDIRECT_CALLABLE_DECLARE(bool tcp_bpf_bypass_getsockopt(int level, + int optname)); + +/* + * Get a socket option. Because we don't know the option lengths we have + * to pass a user mode parameter for the protocols to sort out. 
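+ *
+ * e.g. the userspace call
+ *	int val;
+ *	socklen_t len = sizeof(val);
+ *	getsockopt(fd, SOL_SOCKET, SO_RCVBUF, &val, &len);
+ * is answered by sock_getsockopt() below, while non-SOL_SOCKET levels are
+ * handed to the protocol's own getsockopt handler.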
+ */ +int __sys_getsockopt(int fd, int level, int optname, char __user *optval, + int __user *optlen) +{ + int err, fput_needed; + struct socket *sock; + int max_optlen; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + return err; + + err = security_socket_getsockopt(sock, level, optname); + if (err) + goto out_put; + + if (!in_compat_syscall()) + max_optlen = BPF_CGROUP_GETSOCKOPT_MAX_OPTLEN(optlen); + + if (level == SOL_SOCKET) + err = sock_getsockopt(sock, level, optname, optval, optlen); + else if (unlikely(!sock->ops->getsockopt)) + err = -EOPNOTSUPP; + else + err = sock->ops->getsockopt(sock, level, optname, optval, + optlen); + + if (!in_compat_syscall()) + err = BPF_CGROUP_RUN_PROG_GETSOCKOPT(sock->sk, level, optname, + optval, optlen, max_optlen, + err); +out_put: + fput_light(sock->file, fput_needed); + return err; +} + +SYSCALL_DEFINE5(getsockopt, int, fd, int, level, int, optname, + char __user *, optval, int __user *, optlen) +{ + return __sys_getsockopt(fd, level, optname, optval, optlen); +} + +/* + * Shutdown a socket. + */ + +int __sys_shutdown_sock(struct socket *sock, int how) +{ + int err; + + err = security_socket_shutdown(sock, how); + if (!err) + err = sock->ops->shutdown(sock, how); + + return err; +} + +int __sys_shutdown(int fd, int how) +{ + int err, fput_needed; + struct socket *sock; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (sock != NULL) { + err = __sys_shutdown_sock(sock, how); + fput_light(sock->file, fput_needed); + } + return err; +} + +SYSCALL_DEFINE2(shutdown, int, fd, int, how) +{ + return __sys_shutdown(fd, how); +} + +/* A couple of helpful macros for getting the address of the 32/64 bit + * fields which are the same type (int / unsigned) on our platforms. + */ +#define COMPAT_MSG(msg, member) ((MSG_CMSG_COMPAT & flags) ? &msg##_compat->member : &msg->member) +#define COMPAT_NAMELEN(msg) COMPAT_MSG(msg, msg_namelen) +#define COMPAT_FLAGS(msg) COMPAT_MSG(msg, msg_flags) + +struct used_address { + struct sockaddr_storage name; + unsigned int name_len; +}; + +int __copy_msghdr(struct msghdr *kmsg, + struct user_msghdr *msg, + struct sockaddr __user **save_addr) +{ + ssize_t err; + + kmsg->msg_control_is_user = true; + kmsg->msg_get_inq = 0; + kmsg->msg_control_user = msg->msg_control; + kmsg->msg_controllen = msg->msg_controllen; + kmsg->msg_flags = msg->msg_flags; + + kmsg->msg_namelen = msg->msg_namelen; + if (!msg->msg_name) + kmsg->msg_namelen = 0; + + if (kmsg->msg_namelen < 0) + return -EINVAL; + + if (kmsg->msg_namelen > sizeof(struct sockaddr_storage)) + kmsg->msg_namelen = sizeof(struct sockaddr_storage); + + if (save_addr) + *save_addr = msg->msg_name; + + if (msg->msg_name && kmsg->msg_namelen) { + if (!save_addr) { + err = move_addr_to_kernel(msg->msg_name, + kmsg->msg_namelen, + kmsg->msg_name); + if (err < 0) + return err; + } + } else { + kmsg->msg_name = NULL; + kmsg->msg_namelen = 0; + } + + if (msg->msg_iovlen > UIO_MAXIOV) + return -EMSGSIZE; + + kmsg->msg_iocb = NULL; + kmsg->msg_ubuf = NULL; + return 0; +} + +static int copy_msghdr_from_user(struct msghdr *kmsg, + struct user_msghdr __user *umsg, + struct sockaddr __user **save_addr, + struct iovec **iov) +{ + struct user_msghdr msg; + ssize_t err; + + if (copy_from_user(&msg, umsg, sizeof(*umsg))) + return -EFAULT; + + err = __copy_msghdr(kmsg, &msg, save_addr); + if (err) + return err; + + err = import_iovec(save_addr ? ITER_DEST : ITER_SOURCE, + msg.msg_iov, msg.msg_iovlen, + UIO_FASTIOV, iov, &kmsg->msg_iter); + return err < 0 ? 
err : 0; +} + +static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys, + unsigned int flags, struct used_address *used_address, + unsigned int allowed_msghdr_flags) +{ + unsigned char ctl[sizeof(struct cmsghdr) + 20] + __aligned(sizeof(__kernel_size_t)); + /* 20 is size of ipv6_pktinfo */ + unsigned char *ctl_buf = ctl; + int ctl_len; + ssize_t err; + + err = -ENOBUFS; + + if (msg_sys->msg_controllen > INT_MAX) + goto out; + flags |= (msg_sys->msg_flags & allowed_msghdr_flags); + ctl_len = msg_sys->msg_controllen; + if ((MSG_CMSG_COMPAT & flags) && ctl_len) { + err = + cmsghdr_from_user_compat_to_kern(msg_sys, sock->sk, ctl, + sizeof(ctl)); + if (err) + goto out; + ctl_buf = msg_sys->msg_control; + ctl_len = msg_sys->msg_controllen; + } else if (ctl_len) { + BUILD_BUG_ON(sizeof(struct cmsghdr) != + CMSG_ALIGN(sizeof(struct cmsghdr))); + if (ctl_len > sizeof(ctl)) { + ctl_buf = sock_kmalloc(sock->sk, ctl_len, GFP_KERNEL); + if (ctl_buf == NULL) + goto out; + } + err = -EFAULT; + if (copy_from_user(ctl_buf, msg_sys->msg_control_user, ctl_len)) + goto out_freectl; + msg_sys->msg_control = ctl_buf; + msg_sys->msg_control_is_user = false; + } + flags &= ~MSG_INTERNAL_SENDMSG_FLAGS; + msg_sys->msg_flags = flags; + + if (sock->file->f_flags & O_NONBLOCK) + msg_sys->msg_flags |= MSG_DONTWAIT; + /* + * If this is sendmmsg() and current destination address is same as + * previously succeeded address, omit asking LSM's decision. + * used_address->name_len is initialized to UINT_MAX so that the first + * destination address never matches. + */ + if (used_address && msg_sys->msg_name && + used_address->name_len == msg_sys->msg_namelen && + !memcmp(&used_address->name, msg_sys->msg_name, + used_address->name_len)) { + err = sock_sendmsg_nosec(sock, msg_sys); + goto out_freectl; + } + err = __sock_sendmsg(sock, msg_sys); + /* + * If this is sendmmsg() and sending to current destination address was + * successful, remember it. 
+ */ + if (used_address && err >= 0) { + used_address->name_len = msg_sys->msg_namelen; + if (msg_sys->msg_name) + memcpy(&used_address->name, msg_sys->msg_name, + used_address->name_len); + } + +out_freectl: + if (ctl_buf != ctl) + sock_kfree_s(sock->sk, ctl_buf, ctl_len); +out: + return err; +} + +int sendmsg_copy_msghdr(struct msghdr *msg, + struct user_msghdr __user *umsg, unsigned flags, + struct iovec **iov) +{ + int err; + + if (flags & MSG_CMSG_COMPAT) { + struct compat_msghdr __user *msg_compat; + + msg_compat = (struct compat_msghdr __user *) umsg; + err = get_compat_msghdr(msg, msg_compat, NULL, iov); + } else { + err = copy_msghdr_from_user(msg, umsg, NULL, iov); + } + if (err < 0) + return err; + + return 0; +} + +static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg, + struct msghdr *msg_sys, unsigned int flags, + struct used_address *used_address, + unsigned int allowed_msghdr_flags) +{ + struct sockaddr_storage address; + struct iovec iovstack[UIO_FASTIOV], *iov = iovstack; + ssize_t err; + + msg_sys->msg_name = &address; + + err = sendmsg_copy_msghdr(msg_sys, msg, flags, &iov); + if (err < 0) + return err; + + err = ____sys_sendmsg(sock, msg_sys, flags, used_address, + allowed_msghdr_flags); + kfree(iov); + return err; +} + +/* + * BSD sendmsg interface + */ +long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg, + unsigned int flags) +{ + return ____sys_sendmsg(sock, msg, flags, NULL, 0); +} + +long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, + bool forbid_cmsg_compat) +{ + int fput_needed, err; + struct msghdr msg_sys; + struct socket *sock; + + if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT)) + return -EINVAL; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + goto out; + + err = ___sys_sendmsg(sock, msg, &msg_sys, flags, NULL, 0); + + fput_light(sock->file, fput_needed); +out: + return err; +} + +SYSCALL_DEFINE3(sendmsg, int, fd, struct user_msghdr __user *, msg, unsigned int, flags) +{ + return __sys_sendmsg(fd, msg, flags, true); +} + +/* + * Linux sendmmsg interface + */ + +int __sys_sendmmsg(int fd, struct mmsghdr __user *mmsg, unsigned int vlen, + unsigned int flags, bool forbid_cmsg_compat) +{ + int fput_needed, err, datagrams; + struct socket *sock; + struct mmsghdr __user *entry; + struct compat_mmsghdr __user *compat_entry; + struct msghdr msg_sys; + struct used_address used_address; + unsigned int oflags = flags; + + if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT)) + return -EINVAL; + + if (vlen > UIO_MAXIOV) + vlen = UIO_MAXIOV; + + datagrams = 0; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + return err; + + used_address.name_len = UINT_MAX; + entry = mmsg; + compat_entry = (struct compat_mmsghdr __user *)mmsg; + err = 0; + flags |= MSG_BATCH; + + while (datagrams < vlen) { + if (datagrams == vlen - 1) + flags = oflags; + + if (MSG_CMSG_COMPAT & flags) { + err = ___sys_sendmsg(sock, (struct user_msghdr __user *)compat_entry, + &msg_sys, flags, &used_address, MSG_EOR); + if (err < 0) + break; + err = __put_user(err, &compat_entry->msg_len); + ++compat_entry; + } else { + err = ___sys_sendmsg(sock, + (struct user_msghdr __user *)entry, + &msg_sys, flags, &used_address, MSG_EOR); + if (err < 0) + break; + err = put_user(err, &entry->msg_len); + ++entry; + } + + if (err) + break; + ++datagrams; + if (msg_data_left(&msg_sys)) + break; + cond_resched(); + } + + fput_light(sock->file, fput_needed); + + /* We only return an error if no datagrams were 
able to be sent */ + if (datagrams != 0) + return datagrams; + + return err; +} + +SYSCALL_DEFINE4(sendmmsg, int, fd, struct mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags) +{ + return __sys_sendmmsg(fd, mmsg, vlen, flags, true); +} + +int recvmsg_copy_msghdr(struct msghdr *msg, + struct user_msghdr __user *umsg, unsigned flags, + struct sockaddr __user **uaddr, + struct iovec **iov) +{ + ssize_t err; + + if (MSG_CMSG_COMPAT & flags) { + struct compat_msghdr __user *msg_compat; + + msg_compat = (struct compat_msghdr __user *) umsg; + err = get_compat_msghdr(msg, msg_compat, uaddr, iov); + } else { + err = copy_msghdr_from_user(msg, umsg, uaddr, iov); + } + if (err < 0) + return err; + + return 0; +} + +static int ____sys_recvmsg(struct socket *sock, struct msghdr *msg_sys, + struct user_msghdr __user *msg, + struct sockaddr __user *uaddr, + unsigned int flags, int nosec) +{ + struct compat_msghdr __user *msg_compat = + (struct compat_msghdr __user *) msg; + int __user *uaddr_len = COMPAT_NAMELEN(msg); + struct sockaddr_storage addr; + unsigned long cmsg_ptr; + int len; + ssize_t err; + + msg_sys->msg_name = &addr; + cmsg_ptr = (unsigned long)msg_sys->msg_control; + msg_sys->msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT); + + /* We assume all kernel code knows the size of sockaddr_storage */ + msg_sys->msg_namelen = 0; + + if (sock->file->f_flags & O_NONBLOCK) + flags |= MSG_DONTWAIT; + + if (unlikely(nosec)) + err = sock_recvmsg_nosec(sock, msg_sys, flags); + else + err = sock_recvmsg(sock, msg_sys, flags); + + if (err < 0) + goto out; + len = err; + + if (uaddr != NULL) { + err = move_addr_to_user(&addr, + msg_sys->msg_namelen, uaddr, + uaddr_len); + if (err < 0) + goto out; + } + err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT), + COMPAT_FLAGS(msg)); + if (err) + goto out; + if (MSG_CMSG_COMPAT & flags) + err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, + &msg_compat->msg_controllen); + else + err = __put_user((unsigned long)msg_sys->msg_control - cmsg_ptr, + &msg->msg_controllen); + if (err) + goto out; + err = len; +out: + return err; +} + +static int ___sys_recvmsg(struct socket *sock, struct user_msghdr __user *msg, + struct msghdr *msg_sys, unsigned int flags, int nosec) +{ + struct iovec iovstack[UIO_FASTIOV], *iov = iovstack; + /* user mode address pointers */ + struct sockaddr __user *uaddr; + ssize_t err; + + err = recvmsg_copy_msghdr(msg_sys, msg, flags, &uaddr, &iov); + if (err < 0) + return err; + + err = ____sys_recvmsg(sock, msg_sys, msg, uaddr, flags, nosec); + kfree(iov); + return err; +} + +/* + * BSD recvmsg interface + */ + +long __sys_recvmsg_sock(struct socket *sock, struct msghdr *msg, + struct user_msghdr __user *umsg, + struct sockaddr __user *uaddr, unsigned int flags) +{ + return ____sys_recvmsg(sock, msg, umsg, uaddr, flags, 0); +} + +long __sys_recvmsg(int fd, struct user_msghdr __user *msg, unsigned int flags, + bool forbid_cmsg_compat) +{ + int fput_needed, err; + struct msghdr msg_sys; + struct socket *sock; + + if (forbid_cmsg_compat && (flags & MSG_CMSG_COMPAT)) + return -EINVAL; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + goto out; + + err = ___sys_recvmsg(sock, msg, &msg_sys, flags, 0); + + fput_light(sock->file, fput_needed); +out: + return err; +} + +SYSCALL_DEFINE3(recvmsg, int, fd, struct user_msghdr __user *, msg, + unsigned int, flags) +{ + return __sys_recvmsg(fd, msg, flags, true); +} + +/* + * Linux recvmmsg interface + */ + +static int do_recvmmsg(int fd, struct 
mmsghdr __user *mmsg, + unsigned int vlen, unsigned int flags, + struct timespec64 *timeout) +{ + int fput_needed, err, datagrams; + struct socket *sock; + struct mmsghdr __user *entry; + struct compat_mmsghdr __user *compat_entry; + struct msghdr msg_sys; + struct timespec64 end_time; + struct timespec64 timeout64; + + if (timeout && + poll_select_set_timeout(&end_time, timeout->tv_sec, + timeout->tv_nsec)) + return -EINVAL; + + datagrams = 0; + + sock = sockfd_lookup_light(fd, &err, &fput_needed); + if (!sock) + return err; + + if (likely(!(flags & MSG_ERRQUEUE))) { + err = sock_error(sock->sk); + if (err) { + datagrams = err; + goto out_put; + } + } + + entry = mmsg; + compat_entry = (struct compat_mmsghdr __user *)mmsg; + + while (datagrams < vlen) { + /* + * No need to ask LSM for more than the first datagram. + */ + if (MSG_CMSG_COMPAT & flags) { + err = ___sys_recvmsg(sock, (struct user_msghdr __user *)compat_entry, + &msg_sys, flags & ~MSG_WAITFORONE, + datagrams); + if (err < 0) + break; + err = __put_user(err, &compat_entry->msg_len); + ++compat_entry; + } else { + err = ___sys_recvmsg(sock, + (struct user_msghdr __user *)entry, + &msg_sys, flags & ~MSG_WAITFORONE, + datagrams); + if (err < 0) + break; + err = put_user(err, &entry->msg_len); + ++entry; + } + + if (err) + break; + ++datagrams; + + /* MSG_WAITFORONE turns on MSG_DONTWAIT after one packet */ + if (flags & MSG_WAITFORONE) + flags |= MSG_DONTWAIT; + + if (timeout) { + ktime_get_ts64(&timeout64); + *timeout = timespec64_sub(end_time, timeout64); + if (timeout->tv_sec < 0) { + timeout->tv_sec = timeout->tv_nsec = 0; + break; + } + + /* Timeout, return less than vlen datagrams */ + if (timeout->tv_nsec == 0 && timeout->tv_sec == 0) + break; + } + + /* Out of band data, return right away */ + if (msg_sys.msg_flags & MSG_OOB) + break; + cond_resched(); + } + + if (err == 0) + goto out_put; + + if (datagrams == 0) { + datagrams = err; + goto out_put; + } + + /* + * We may return less entries than requested (vlen) if the + * sock is non block and there aren't enough datagrams... + */ + if (err != -EAGAIN) { + /* + * ... or if recvmsg returns an error after we + * received some datagrams, where we record the + * error to return on the next call or if the + * app asks about it using getsockopt(SO_ERROR). 
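+ * The positive errno ends up in sk->sk_err below, so the next
+ * receive call picks it up via sock_error(), or userspace can
+ * read (and clear) it with getsockopt(SO_ERROR).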
+ */ + WRITE_ONCE(sock->sk->sk_err, -err); + } +out_put: + fput_light(sock->file, fput_needed); + + return datagrams; +} + +int __sys_recvmmsg(int fd, struct mmsghdr __user *mmsg, + unsigned int vlen, unsigned int flags, + struct __kernel_timespec __user *timeout, + struct old_timespec32 __user *timeout32) +{ + int datagrams; + struct timespec64 timeout_sys; + + if (timeout && get_timespec64(&timeout_sys, timeout)) + return -EFAULT; + + if (timeout32 && get_old_timespec32(&timeout_sys, timeout32)) + return -EFAULT; + + if (!timeout && !timeout32) + return do_recvmmsg(fd, mmsg, vlen, flags, NULL); + + datagrams = do_recvmmsg(fd, mmsg, vlen, flags, &timeout_sys); + + if (datagrams <= 0) + return datagrams; + + if (timeout && put_timespec64(&timeout_sys, timeout)) + datagrams = -EFAULT; + + if (timeout32 && put_old_timespec32(&timeout_sys, timeout32)) + datagrams = -EFAULT; + + return datagrams; +} + +SYSCALL_DEFINE5(recvmmsg, int, fd, struct mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags, + struct __kernel_timespec __user *, timeout) +{ + if (flags & MSG_CMSG_COMPAT) + return -EINVAL; + + return __sys_recvmmsg(fd, mmsg, vlen, flags, timeout, NULL); +} + +#ifdef CONFIG_COMPAT_32BIT_TIME +SYSCALL_DEFINE5(recvmmsg_time32, int, fd, struct mmsghdr __user *, mmsg, + unsigned int, vlen, unsigned int, flags, + struct old_timespec32 __user *, timeout) +{ + if (flags & MSG_CMSG_COMPAT) + return -EINVAL; + + return __sys_recvmmsg(fd, mmsg, vlen, flags, NULL, timeout); +} +#endif + +#ifdef __ARCH_WANT_SYS_SOCKETCALL +/* Argument list sizes for sys_socketcall */ +#define AL(x) ((x) * sizeof(unsigned long)) +static const unsigned char nargs[21] = { + AL(0), AL(3), AL(3), AL(3), AL(2), AL(3), + AL(3), AL(3), AL(4), AL(4), AL(4), AL(6), + AL(6), AL(2), AL(5), AL(5), AL(3), AL(3), + AL(4), AL(5), AL(4) +}; + +#undef AL + +/* + * System call vectors. + * + * Argument checking cleaned up. Saved 20% in size. + * This function doesn't need to set the kernel lock because + * it is set by the callees. + */ + +SYSCALL_DEFINE2(socketcall, int, call, unsigned long __user *, args) +{ + unsigned long a[AUDITSC_ARGS]; + unsigned long a0, a1; + int err; + unsigned int len; + + if (call < 1 || call > SYS_SENDMMSG) + return -EINVAL; + call = array_index_nospec(call, SYS_SENDMMSG + 1); + + len = nargs[call]; + if (len > sizeof(a)) + return -EINVAL; + + /* copy_from_user should be SMP safe. 
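+ * nargs[] above gives the byte count of each call's argument block, e.g.
+ * nargs[SYS_BIND] == AL(3) == 3 * sizeof(unsigned long), so one
+ * copy_from_user() fetches exactly the arguments userspace packed into
+ * the args array.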
*/ + if (copy_from_user(a, args, len)) + return -EFAULT; + + err = audit_socketcall(nargs[call] / sizeof(unsigned long), a); + if (err) + return err; + + a0 = a[0]; + a1 = a[1]; + + switch (call) { + case SYS_SOCKET: + err = __sys_socket(a0, a1, a[2]); + break; + case SYS_BIND: + err = __sys_bind(a0, (struct sockaddr __user *)a1, a[2]); + break; + case SYS_CONNECT: + err = __sys_connect(a0, (struct sockaddr __user *)a1, a[2]); + break; + case SYS_LISTEN: + err = __sys_listen(a0, a1); + break; + case SYS_ACCEPT: + err = __sys_accept4(a0, (struct sockaddr __user *)a1, + (int __user *)a[2], 0); + break; + case SYS_GETSOCKNAME: + err = + __sys_getsockname(a0, (struct sockaddr __user *)a1, + (int __user *)a[2]); + break; + case SYS_GETPEERNAME: + err = + __sys_getpeername(a0, (struct sockaddr __user *)a1, + (int __user *)a[2]); + break; + case SYS_SOCKETPAIR: + err = __sys_socketpair(a0, a1, a[2], (int __user *)a[3]); + break; + case SYS_SEND: + err = __sys_sendto(a0, (void __user *)a1, a[2], a[3], + NULL, 0); + break; + case SYS_SENDTO: + err = __sys_sendto(a0, (void __user *)a1, a[2], a[3], + (struct sockaddr __user *)a[4], a[5]); + break; + case SYS_RECV: + err = __sys_recvfrom(a0, (void __user *)a1, a[2], a[3], + NULL, NULL); + break; + case SYS_RECVFROM: + err = __sys_recvfrom(a0, (void __user *)a1, a[2], a[3], + (struct sockaddr __user *)a[4], + (int __user *)a[5]); + break; + case SYS_SHUTDOWN: + err = __sys_shutdown(a0, a1); + break; + case SYS_SETSOCKOPT: + err = __sys_setsockopt(a0, a1, a[2], (char __user *)a[3], + a[4]); + break; + case SYS_GETSOCKOPT: + err = + __sys_getsockopt(a0, a1, a[2], (char __user *)a[3], + (int __user *)a[4]); + break; + case SYS_SENDMSG: + err = __sys_sendmsg(a0, (struct user_msghdr __user *)a1, + a[2], true); + break; + case SYS_SENDMMSG: + err = __sys_sendmmsg(a0, (struct mmsghdr __user *)a1, a[2], + a[3], true); + break; + case SYS_RECVMSG: + err = __sys_recvmsg(a0, (struct user_msghdr __user *)a1, + a[2], true); + break; + case SYS_RECVMMSG: + if (IS_ENABLED(CONFIG_64BIT)) + err = __sys_recvmmsg(a0, (struct mmsghdr __user *)a1, + a[2], a[3], + (struct __kernel_timespec __user *)a[4], + NULL); + else + err = __sys_recvmmsg(a0, (struct mmsghdr __user *)a1, + a[2], a[3], NULL, + (struct old_timespec32 __user *)a[4]); + break; + case SYS_ACCEPT4: + err = __sys_accept4(a0, (struct sockaddr __user *)a1, + (int __user *)a[2], a[3]); + break; + default: + err = -EINVAL; + break; + } + return err; +} + +#endif /* __ARCH_WANT_SYS_SOCKETCALL */ + +/** + * sock_register - add a socket protocol handler + * @ops: description of protocol + * + * This function is called by a protocol handler that wants to + * advertise its address family, and have it linked into the + * socket interface. The value ops->family corresponds to the + * socket system call protocol family. 
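+ *
+ * A protocol module typically calls this from its module init path; a
+ * minimal sketch (PF_FOO and foo_create are placeholders, not symbols in
+ * the tree):
+ *
+ *	static const struct net_proto_family foo_family_ops = {
+ *		.family	= PF_FOO,
+ *		.create	= foo_create,
+ *		.owner	= THIS_MODULE,
+ *	};
+ *
+ *	err = sock_register(&foo_family_ops);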
+ */ +int sock_register(const struct net_proto_family *ops) +{ + int err; + + if (ops->family >= NPROTO) { + pr_crit("protocol %d >= NPROTO(%d)\n", ops->family, NPROTO); + return -ENOBUFS; + } + + spin_lock(&net_family_lock); + if (rcu_dereference_protected(net_families[ops->family], + lockdep_is_held(&net_family_lock))) + err = -EEXIST; + else { + rcu_assign_pointer(net_families[ops->family], ops); + err = 0; + } + spin_unlock(&net_family_lock); + + pr_info("NET: Registered %s protocol family\n", pf_family_names[ops->family]); + return err; +} +EXPORT_SYMBOL(sock_register); + +/** + * sock_unregister - remove a protocol handler + * @family: protocol family to remove + * + * This function is called by a protocol handler that wants to + * remove its address family, and have it unlinked from the + * new socket creation. + * + * If protocol handler is a module, then it can use module reference + * counts to protect against new references. If protocol handler is not + * a module then it needs to provide its own protection in + * the ops->create routine. + */ +void sock_unregister(int family) +{ + BUG_ON(family < 0 || family >= NPROTO); + + spin_lock(&net_family_lock); + RCU_INIT_POINTER(net_families[family], NULL); + spin_unlock(&net_family_lock); + + synchronize_rcu(); + + pr_info("NET: Unregistered %s protocol family\n", pf_family_names[family]); +} +EXPORT_SYMBOL(sock_unregister); + +bool sock_is_registered(int family) +{ + return family < NPROTO && rcu_access_pointer(net_families[family]); +} + +static int __init sock_init(void) +{ + int err; + /* + * Initialize the network sysctl infrastructure. + */ + err = net_sysctl_init(); + if (err) + goto out; + + /* + * Initialize skbuff SLAB cache + */ + skb_init(); + + /* + * Initialize the protocols module. + */ + + init_inodecache(); + + err = register_filesystem(&sock_fs_type); + if (err) + goto out; + sock_mnt = kern_mount(&sock_fs_type); + if (IS_ERR(sock_mnt)) { + err = PTR_ERR(sock_mnt); + goto out_mount; + } + + /* The real protocol initialization is performed in later initcalls. + */ + +#ifdef CONFIG_NETFILTER + err = netfilter_init(); + if (err) + goto out; +#endif + + ptp_classifier_init(); + +out: + return err; + +out_mount: + unregister_filesystem(&sock_fs_type); + goto out; +} + +core_initcall(sock_init); /* early initcall */ + +#ifdef CONFIG_PROC_FS +void socket_seq_show(struct seq_file *seq) +{ + seq_printf(seq, "sockets: used %d\n", + sock_inuse_get(seq->private)); +} +#endif /* CONFIG_PROC_FS */ + +/* Handle the fact that while struct ifreq has the same *layout* on + * 32/64 for everything but ifreq::ifru_ifmap and ifreq::ifru_data, + * which are handled elsewhere, it still has different *size* due to + * ifreq::ifru_ifmap (which is 16 bytes on 32 bit, 24 bytes on 64-bit, + * resulting in struct ifreq being 32 and 40 bytes respectively). + * As a result, if the struct happens to be at the end of a page and + * the next page isn't readable/writable, we get a fault. To prevent + * that, copy back and forth to the full size. 
+ */ +int get_user_ifreq(struct ifreq *ifr, void __user **ifrdata, void __user *arg) +{ + if (in_compat_syscall()) { + struct compat_ifreq *ifr32 = (struct compat_ifreq *)ifr; + + memset(ifr, 0, sizeof(*ifr)); + if (copy_from_user(ifr32, arg, sizeof(*ifr32))) + return -EFAULT; + + if (ifrdata) + *ifrdata = compat_ptr(ifr32->ifr_data); + + return 0; + } + + if (copy_from_user(ifr, arg, sizeof(*ifr))) + return -EFAULT; + + if (ifrdata) + *ifrdata = ifr->ifr_data; + + return 0; +} +EXPORT_SYMBOL(get_user_ifreq); + +int put_user_ifreq(struct ifreq *ifr, void __user *arg) +{ + size_t size = sizeof(*ifr); + + if (in_compat_syscall()) + size = sizeof(struct compat_ifreq); + + if (copy_to_user(arg, ifr, size)) + return -EFAULT; + + return 0; +} +EXPORT_SYMBOL(put_user_ifreq); + +#ifdef CONFIG_COMPAT +static int compat_siocwandev(struct net *net, struct compat_ifreq __user *uifr32) +{ + compat_uptr_t uptr32; + struct ifreq ifr; + void __user *saved; + int err; + + if (get_user_ifreq(&ifr, NULL, uifr32)) + return -EFAULT; + + if (get_user(uptr32, &uifr32->ifr_settings.ifs_ifsu)) + return -EFAULT; + + saved = ifr.ifr_settings.ifs_ifsu.raw_hdlc; + ifr.ifr_settings.ifs_ifsu.raw_hdlc = compat_ptr(uptr32); + + err = dev_ioctl(net, SIOCWANDEV, &ifr, NULL, NULL); + if (!err) { + ifr.ifr_settings.ifs_ifsu.raw_hdlc = saved; + if (put_user_ifreq(&ifr, uifr32)) + err = -EFAULT; + } + return err; +} + +/* Handle ioctls that use ifreq::ifr_data and just need struct ifreq converted */ +static int compat_ifr_data_ioctl(struct net *net, unsigned int cmd, + struct compat_ifreq __user *u_ifreq32) +{ + struct ifreq ifreq; + void __user *data; + + if (!is_socket_ioctl_cmd(cmd)) + return -ENOTTY; + if (get_user_ifreq(&ifreq, &data, u_ifreq32)) + return -EFAULT; + ifreq.ifr_data = data; + + return dev_ioctl(net, cmd, &ifreq, data, NULL); +} + +static int compat_sock_ioctl_trans(struct file *file, struct socket *sock, + unsigned int cmd, unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + + if (cmd >= SIOCDEVPRIVATE && cmd <= (SIOCDEVPRIVATE + 15)) + return sock_ioctl(file, cmd, (unsigned long)argp); + + switch (cmd) { + case SIOCWANDEV: + return compat_siocwandev(net, argp); + case SIOCGSTAMP_OLD: + case SIOCGSTAMPNS_OLD: + if (!sock->ops->gettstamp) + return -ENOIOCTLCMD; + return sock->ops->gettstamp(sock, argp, cmd == SIOCGSTAMP_OLD, + !COMPAT_USE_64BIT_TIME); + + case SIOCETHTOOL: + case SIOCBONDSLAVEINFOQUERY: + case SIOCBONDINFOQUERY: + case SIOCSHWTSTAMP: + case SIOCGHWTSTAMP: + return compat_ifr_data_ioctl(net, cmd, argp); + + case FIOSETOWN: + case SIOCSPGRP: + case FIOGETOWN: + case SIOCGPGRP: + case SIOCBRADDBR: + case SIOCBRDELBR: + case SIOCGIFVLAN: + case SIOCSIFVLAN: + case SIOCGSKNS: + case SIOCGSTAMP_NEW: + case SIOCGSTAMPNS_NEW: + case SIOCGIFCONF: + case SIOCSIFBR: + case SIOCGIFBR: + return sock_ioctl(file, cmd, arg); + + case SIOCGIFFLAGS: + case SIOCSIFFLAGS: + case SIOCGIFMAP: + case SIOCSIFMAP: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + case SIOCGIFMTU: + case SIOCSIFMTU: + case SIOCGIFMEM: + case SIOCSIFMEM: + case SIOCGIFHWADDR: + case SIOCSIFHWADDR: + case SIOCADDMULTI: + case SIOCDELMULTI: + case SIOCGIFINDEX: + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCSIFHWBROADCAST: + case SIOCDIFADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCSIFPFLAGS: + case SIOCGIFPFLAGS: + case SIOCGIFTXQLEN: + case SIOCSIFTXQLEN: 
+ case SIOCBRADDIF: + case SIOCBRDELIF: + case SIOCGIFNAME: + case SIOCSIFNAME: + case SIOCGMIIPHY: + case SIOCGMIIREG: + case SIOCSMIIREG: + case SIOCBONDENSLAVE: + case SIOCBONDRELEASE: + case SIOCBONDSETHWADDR: + case SIOCBONDCHANGEACTIVE: + case SIOCSARP: + case SIOCGARP: + case SIOCDARP: + case SIOCOUTQ: + case SIOCOUTQNSD: + case SIOCATMARK: + return sock_do_ioctl(net, sock, cmd, arg); + } + + return -ENOIOCTLCMD; +} + +static long compat_sock_ioctl(struct file *file, unsigned int cmd, + unsigned long arg) +{ + struct socket *sock = file->private_data; + int ret = -ENOIOCTLCMD; + struct sock *sk; + struct net *net; + + sk = sock->sk; + net = sock_net(sk); + + if (sock->ops->compat_ioctl) + ret = sock->ops->compat_ioctl(sock, cmd, arg); + + if (ret == -ENOIOCTLCMD && + (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST)) + ret = compat_wext_handle_ioctl(net, cmd, arg); + + if (ret == -ENOIOCTLCMD) + ret = compat_sock_ioctl_trans(file, sock, cmd, arg); + + return ret; +} +#endif + +/** + * kernel_bind - bind an address to a socket (kernel space) + * @sock: socket + * @addr: address + * @addrlen: length of address + * + * Returns 0 or an error. + */ + +int kernel_bind(struct socket *sock, struct sockaddr *addr, int addrlen) +{ + struct sockaddr_storage address; + + memcpy(&address, addr, addrlen); + + return sock->ops->bind(sock, (struct sockaddr *)&address, addrlen); +} +EXPORT_SYMBOL(kernel_bind); + +/** + * kernel_listen - move socket to listening state (kernel space) + * @sock: socket + * @backlog: pending connections queue size + * + * Returns 0 or an error. + */ + +int kernel_listen(struct socket *sock, int backlog) +{ + return sock->ops->listen(sock, backlog); +} +EXPORT_SYMBOL(kernel_listen); + +/** + * kernel_accept - accept a connection (kernel space) + * @sock: listening socket + * @newsock: new connected socket + * @flags: flags + * + * @flags must be SOCK_CLOEXEC, SOCK_NONBLOCK or 0. + * If it fails, @newsock is guaranteed to be %NULL. + * Returns 0 or an error. + */ + +int kernel_accept(struct socket *sock, struct socket **newsock, int flags) +{ + struct sock *sk = sock->sk; + int err; + + err = sock_create_lite(sk->sk_family, sk->sk_type, sk->sk_protocol, + newsock); + if (err < 0) + goto done; + + err = sock->ops->accept(sock, *newsock, flags, true); + if (err < 0) { + sock_release(*newsock); + *newsock = NULL; + goto done; + } + + (*newsock)->ops = sock->ops; + __module_get((*newsock)->ops->owner); + +done: + return err; +} +EXPORT_SYMBOL(kernel_accept); + +/** + * kernel_connect - connect a socket (kernel space) + * @sock: socket + * @addr: address + * @addrlen: address length + * @flags: flags (O_NONBLOCK, ...) + * + * For datagram sockets, @addr is the address to which datagrams are sent + * by default, and the only address from which datagrams are received. + * For stream sockets, attempts to connect to @addr. + * Returns 0 or an error code. + */ + +int kernel_connect(struct socket *sock, struct sockaddr *addr, int addrlen, + int flags) +{ + struct sockaddr_storage address; + + memcpy(&address, addr, addrlen); + + return sock->ops->connect(sock, (struct sockaddr *)&address, addrlen, flags); +} +EXPORT_SYMBOL(kernel_connect); + +/** + * kernel_getsockname - get the address which the socket is bound (kernel space) + * @sock: socket + * @addr: address holder + * + * Fills the @addr pointer with the address which the socket is bound. + * Returns the length of the address in bytes or an error code. 
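+ *
+ * @addr must be large enough for the socket's address family; in-kernel
+ * callers usually pass a struct sockaddr_storage, roughly:
+ *	struct sockaddr_storage ss;
+ *	int len = kernel_getsockname(sock, (struct sockaddr *)&ss);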
+ */ + +int kernel_getsockname(struct socket *sock, struct sockaddr *addr) +{ + return sock->ops->getname(sock, addr, 0); +} +EXPORT_SYMBOL(kernel_getsockname); + +/** + * kernel_getpeername - get the address which the socket is connected (kernel space) + * @sock: socket + * @addr: address holder + * + * Fills the @addr pointer with the address which the socket is connected. + * Returns the length of the address in bytes or an error code. + */ + +int kernel_getpeername(struct socket *sock, struct sockaddr *addr) +{ + return sock->ops->getname(sock, addr, 1); +} +EXPORT_SYMBOL(kernel_getpeername); + +/** + * kernel_sendpage - send a &page through a socket (kernel space) + * @sock: socket + * @page: page + * @offset: page offset + * @size: total size in bytes + * @flags: flags (MSG_DONTWAIT, ...) + * + * Returns the total amount sent in bytes or an error. + */ + +int kernel_sendpage(struct socket *sock, struct page *page, int offset, + size_t size, int flags) +{ + if (sock->ops->sendpage) { + /* Warn in case the improper page to zero-copy send */ + WARN_ONCE(!sendpage_ok(page), "improper page for zero-copy send"); + return sock->ops->sendpage(sock, page, offset, size, flags); + } + return sock_no_sendpage(sock, page, offset, size, flags); +} +EXPORT_SYMBOL(kernel_sendpage); + +/** + * kernel_sendpage_locked - send a &page through the locked sock (kernel space) + * @sk: sock + * @page: page + * @offset: page offset + * @size: total size in bytes + * @flags: flags (MSG_DONTWAIT, ...) + * + * Returns the total amount sent in bytes or an error. + * Caller must hold @sk. + */ + +int kernel_sendpage_locked(struct sock *sk, struct page *page, int offset, + size_t size, int flags) +{ + struct socket *sock = sk->sk_socket; + + if (sock->ops->sendpage_locked) + return sock->ops->sendpage_locked(sk, page, offset, size, + flags); + + return sock_no_sendpage_locked(sk, page, offset, size, flags); +} +EXPORT_SYMBOL(kernel_sendpage_locked); + +/** + * kernel_sock_shutdown - shut down part of a full-duplex connection (kernel space) + * @sock: socket + * @how: connection part + * + * Returns 0 or an error. + */ + +int kernel_sock_shutdown(struct socket *sock, enum sock_shutdown_cmd how) +{ + return sock->ops->shutdown(sock, how); +} +EXPORT_SYMBOL(kernel_sock_shutdown); + +/** + * kernel_sock_ip_overhead - returns the IP overhead imposed by a socket + * @sk: socket + * + * This routine returns the IP overhead imposed by a socket i.e. + * the length of the underlying IP header, depending on whether + * this is an IPv4 or IPv6 socket and the length from IP options turned + * on at the socket. Assumes that the caller has a lock on the socket. 
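+ *
+ * For a plain AF_INET socket with no IP options this is
+ * sizeof(struct iphdr), i.e. 20 bytes; for AF_INET6 it starts at
+ * sizeof(struct ipv6hdr), i.e. 40 bytes, plus any transmit options.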
+ */ + +u32 kernel_sock_ip_overhead(struct sock *sk) +{ + struct inet_sock *inet; + struct ip_options_rcu *opt; + u32 overhead = 0; +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6_pinfo *np; + struct ipv6_txoptions *optv6 = NULL; +#endif /* IS_ENABLED(CONFIG_IPV6) */ + + if (!sk) + return overhead; + + switch (sk->sk_family) { + case AF_INET: + inet = inet_sk(sk); + overhead += sizeof(struct iphdr); + opt = rcu_dereference_protected(inet->inet_opt, + sock_owned_by_user(sk)); + if (opt) + overhead += opt->opt.optlen; + return overhead; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + np = inet6_sk(sk); + overhead += sizeof(struct ipv6hdr); + if (np) + optv6 = rcu_dereference_protected(np->opt, + sock_owned_by_user(sk)); + if (optv6) + overhead += (optv6->opt_flen + optv6->opt_nflen); + return overhead; +#endif /* IS_ENABLED(CONFIG_IPV6) */ + default: /* Returns 0 overhead if the socket is not ipv4 or ipv6 */ + return overhead; + } +} +EXPORT_SYMBOL(kernel_sock_ip_overhead); diff --git a/net/strparser/Kconfig b/net/strparser/Kconfig new file mode 100644 index 000000000..e6146c21a --- /dev/null +++ b/net/strparser/Kconfig @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: GPL-2.0-only +config STREAM_PARSER + def_bool n diff --git a/net/strparser/Makefile b/net/strparser/Makefile new file mode 100644 index 000000000..931319153 --- /dev/null +++ b/net/strparser/Makefile @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_STREAM_PARSER) += strparser.o diff --git a/net/strparser/strparser.c b/net/strparser/strparser.c new file mode 100644 index 000000000..8299ceb3e --- /dev/null +++ b/net/strparser/strparser.c @@ -0,0 +1,545 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Stream Parser + * + * Copyright (c) 2016 Tom Herbert + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct workqueue_struct *strp_wq; + +static inline struct _strp_msg *_strp_msg(struct sk_buff *skb) +{ + return (struct _strp_msg *)((void *)skb->cb + + offsetof(struct sk_skb_cb, strp)); +} + +/* Lower lock held */ +static void strp_abort_strp(struct strparser *strp, int err) +{ + /* Unrecoverable error in receive */ + + cancel_delayed_work(&strp->msg_timer_work); + + if (strp->stopped) + return; + + strp->stopped = 1; + + if (strp->sk) { + struct sock *sk = strp->sk; + + /* Report an error on the lower socket */ + sk->sk_err = -err; + sk_error_report(sk); + } +} + +static void strp_start_timer(struct strparser *strp, long timeo) +{ + if (timeo && timeo != LONG_MAX) + mod_delayed_work(strp_wq, &strp->msg_timer_work, timeo); +} + +/* Lower lock held */ +static void strp_parser_err(struct strparser *strp, int err, + read_descriptor_t *desc) +{ + desc->error = err; + kfree_skb(strp->skb_head); + strp->skb_head = NULL; + strp->cb.abort_parser(strp, err); +} + +static inline int strp_peek_len(struct strparser *strp) +{ + if (strp->sk) { + struct socket *sock = strp->sk->sk_socket; + + return sock->ops->peek_len(sock); + } + + /* If we don't have an associated socket there's nothing to peek. + * Return int max to avoid stopping the strparser. 
+ */ + + return INT_MAX; +} + +/* Lower socket lock held */ +static int __strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb, + unsigned int orig_offset, size_t orig_len, + size_t max_msg_size, long timeo) +{ + struct strparser *strp = (struct strparser *)desc->arg.data; + struct _strp_msg *stm; + struct sk_buff *head, *skb; + size_t eaten = 0, cand_len; + ssize_t extra; + int err; + bool cloned_orig = false; + + if (strp->paused) + return 0; + + head = strp->skb_head; + if (head) { + /* Message already in progress */ + if (unlikely(orig_offset)) { + /* Getting data with a non-zero offset when a message is + * in progress is not expected. If it does happen, we + * need to clone and pull since we can't deal with + * offsets in the skbs for a message expect in the head. + */ + orig_skb = skb_clone(orig_skb, GFP_ATOMIC); + if (!orig_skb) { + STRP_STATS_INCR(strp->stats.mem_fail); + desc->error = -ENOMEM; + return 0; + } + if (!pskb_pull(orig_skb, orig_offset)) { + STRP_STATS_INCR(strp->stats.mem_fail); + kfree_skb(orig_skb); + desc->error = -ENOMEM; + return 0; + } + cloned_orig = true; + orig_offset = 0; + } + + if (!strp->skb_nextp) { + /* We are going to append to the frags_list of head. + * Need to unshare the frag_list. + */ + err = skb_unclone(head, GFP_ATOMIC); + if (err) { + STRP_STATS_INCR(strp->stats.mem_fail); + desc->error = err; + return 0; + } + + if (unlikely(skb_shinfo(head)->frag_list)) { + /* We can't append to an sk_buff that already + * has a frag_list. We create a new head, point + * the frag_list of that to the old head, and + * then are able to use the old head->next for + * appending to the message. + */ + if (WARN_ON(head->next)) { + desc->error = -EINVAL; + return 0; + } + + skb = alloc_skb_for_msg(head); + if (!skb) { + STRP_STATS_INCR(strp->stats.mem_fail); + desc->error = -ENOMEM; + return 0; + } + + strp->skb_nextp = &head->next; + strp->skb_head = skb; + head = skb; + } else { + strp->skb_nextp = + &skb_shinfo(head)->frag_list; + } + } + } + + while (eaten < orig_len) { + /* Always clone since we will consume something */ + skb = skb_clone(orig_skb, GFP_ATOMIC); + if (!skb) { + STRP_STATS_INCR(strp->stats.mem_fail); + desc->error = -ENOMEM; + break; + } + + cand_len = orig_len - eaten; + + head = strp->skb_head; + if (!head) { + head = skb; + strp->skb_head = head; + /* Will set skb_nextp on next packet if needed */ + strp->skb_nextp = NULL; + stm = _strp_msg(head); + memset(stm, 0, sizeof(*stm)); + stm->strp.offset = orig_offset + eaten; + } else { + /* Unclone if we are appending to an skb that we + * already share a frag_list with. 
+ */ + if (skb_has_frag_list(skb)) { + err = skb_unclone(skb, GFP_ATOMIC); + if (err) { + STRP_STATS_INCR(strp->stats.mem_fail); + desc->error = err; + break; + } + } + + stm = _strp_msg(head); + *strp->skb_nextp = skb; + strp->skb_nextp = &skb->next; + head->data_len += skb->len; + head->len += skb->len; + head->truesize += skb->truesize; + } + + if (!stm->strp.full_len) { + ssize_t len; + + len = (*strp->cb.parse_msg)(strp, head); + + if (!len) { + /* Need more header to determine length */ + if (!stm->accum_len) { + /* Start RX timer for new message */ + strp_start_timer(strp, timeo); + } + stm->accum_len += cand_len; + eaten += cand_len; + STRP_STATS_INCR(strp->stats.need_more_hdr); + WARN_ON(eaten != orig_len); + break; + } else if (len < 0) { + if (len == -ESTRPIPE && stm->accum_len) { + len = -ENODATA; + strp->unrecov_intr = 1; + } else { + strp->interrupted = 1; + } + strp_parser_err(strp, len, desc); + break; + } else if (len > max_msg_size) { + /* Message length exceeds maximum allowed */ + STRP_STATS_INCR(strp->stats.msg_too_big); + strp_parser_err(strp, -EMSGSIZE, desc); + break; + } else if (len <= (ssize_t)head->len - + skb->len - stm->strp.offset) { + /* Length must be into new skb (and also + * greater than zero) + */ + STRP_STATS_INCR(strp->stats.bad_hdr_len); + strp_parser_err(strp, -EPROTO, desc); + break; + } + + stm->strp.full_len = len; + } + + extra = (ssize_t)(stm->accum_len + cand_len) - + stm->strp.full_len; + + if (extra < 0) { + /* Message not complete yet. */ + if (stm->strp.full_len - stm->accum_len > + strp_peek_len(strp)) { + /* Don't have the whole message in the socket + * buffer. Set strp->need_bytes to wait for + * the rest of the message. Also, set "early + * eaten" since we've already buffered the skb + * but don't consume yet per strp_read_sock. + */ + + if (!stm->accum_len) { + /* Start RX timer for new message */ + strp_start_timer(strp, timeo); + } + + stm->accum_len += cand_len; + eaten += cand_len; + strp->need_bytes = stm->strp.full_len - + stm->accum_len; + STRP_STATS_ADD(strp->stats.bytes, cand_len); + desc->count = 0; /* Stop reading socket */ + break; + } + stm->accum_len += cand_len; + eaten += cand_len; + WARN_ON(eaten != orig_len); + break; + } + + /* Positive extra indicates more bytes than needed for the + * message + */ + + WARN_ON(extra > cand_len); + + eaten += (cand_len - extra); + + /* Hurray, we have a new message! 
*/ + cancel_delayed_work(&strp->msg_timer_work); + strp->skb_head = NULL; + strp->need_bytes = 0; + STRP_STATS_INCR(strp->stats.msgs); + + /* Give skb to upper layer */ + strp->cb.rcv_msg(strp, head); + + if (unlikely(strp->paused)) { + /* Upper layer paused strp */ + break; + } + } + + if (cloned_orig) + kfree_skb(orig_skb); + + STRP_STATS_ADD(strp->stats.bytes, eaten); + + return eaten; +} + +int strp_process(struct strparser *strp, struct sk_buff *orig_skb, + unsigned int orig_offset, size_t orig_len, + size_t max_msg_size, long timeo) +{ + read_descriptor_t desc; /* Dummy arg to strp_recv */ + + desc.arg.data = strp; + + return __strp_recv(&desc, orig_skb, orig_offset, orig_len, + max_msg_size, timeo); +} +EXPORT_SYMBOL_GPL(strp_process); + +static int strp_recv(read_descriptor_t *desc, struct sk_buff *orig_skb, + unsigned int orig_offset, size_t orig_len) +{ + struct strparser *strp = (struct strparser *)desc->arg.data; + + return __strp_recv(desc, orig_skb, orig_offset, orig_len, + strp->sk->sk_rcvbuf, strp->sk->sk_rcvtimeo); +} + +static int default_read_sock_done(struct strparser *strp, int err) +{ + return err; +} + +/* Called with lock held on lower socket */ +static int strp_read_sock(struct strparser *strp) +{ + struct socket *sock = strp->sk->sk_socket; + read_descriptor_t desc; + + if (unlikely(!sock || !sock->ops || !sock->ops->read_sock)) + return -EBUSY; + + desc.arg.data = strp; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + /* sk should be locked here, so okay to do read_sock */ + sock->ops->read_sock(strp->sk, &desc, strp_recv); + + desc.error = strp->cb.read_sock_done(strp, desc.error); + + return desc.error; +} + +/* Lower sock lock held */ +void strp_data_ready(struct strparser *strp) +{ + if (unlikely(strp->stopped) || strp->paused) + return; + + /* This check is needed to synchronize with do_strp_work. + * do_strp_work acquires a process lock (lock_sock) whereas + * the lock held here is bh_lock_sock. The two locks can be + * held by different threads at the same time, but bh_lock_sock + * allows a thread in BH context to safely check if the process + * lock is held. In this case, if the lock is held, queue work. + */ + if (sock_owned_by_user_nocheck(strp->sk)) { + queue_work(strp_wq, &strp->work); + return; + } + + if (strp->need_bytes) { + if (strp_peek_len(strp) < strp->need_bytes) + return; + } + + if (strp_read_sock(strp) == -ENOMEM) + queue_work(strp_wq, &strp->work); +} +EXPORT_SYMBOL_GPL(strp_data_ready); + +static void do_strp_work(struct strparser *strp) +{ + /* We need the read lock to synchronize with strp_data_ready. We + * need the socket lock for calling strp_read_sock. 
+ */ + strp->cb.lock(strp); + + if (unlikely(strp->stopped)) + goto out; + + if (strp->paused) + goto out; + + if (strp_read_sock(strp) == -ENOMEM) + queue_work(strp_wq, &strp->work); + +out: + strp->cb.unlock(strp); +} + +static void strp_work(struct work_struct *w) +{ + do_strp_work(container_of(w, struct strparser, work)); +} + +static void strp_msg_timeout(struct work_struct *w) +{ + struct strparser *strp = container_of(w, struct strparser, + msg_timer_work.work); + + /* Message assembly timed out */ + STRP_STATS_INCR(strp->stats.msg_timeouts); + strp->cb.lock(strp); + strp->cb.abort_parser(strp, -ETIMEDOUT); + strp->cb.unlock(strp); +} + +static void strp_sock_lock(struct strparser *strp) +{ + lock_sock(strp->sk); +} + +static void strp_sock_unlock(struct strparser *strp) +{ + release_sock(strp->sk); +} + +int strp_init(struct strparser *strp, struct sock *sk, + const struct strp_callbacks *cb) +{ + + if (!cb || !cb->rcv_msg || !cb->parse_msg) + return -EINVAL; + + /* The sk (sock) arg determines the mode of the stream parser. + * + * If the sock is set then the strparser is in receive callback mode. + * The upper layer calls strp_data_ready to kick receive processing + * and strparser calls the read_sock function on the socket to + * get packets. + * + * If the sock is not set then the strparser is in general mode. + * The upper layer calls strp_process for each skb to be parsed. + */ + + if (!sk) { + if (!cb->lock || !cb->unlock) + return -EINVAL; + } + + memset(strp, 0, sizeof(*strp)); + + strp->sk = sk; + + strp->cb.lock = cb->lock ? : strp_sock_lock; + strp->cb.unlock = cb->unlock ? : strp_sock_unlock; + strp->cb.rcv_msg = cb->rcv_msg; + strp->cb.parse_msg = cb->parse_msg; + strp->cb.read_sock_done = cb->read_sock_done ? : default_read_sock_done; + strp->cb.abort_parser = cb->abort_parser ? : strp_abort_strp; + + INIT_DELAYED_WORK(&strp->msg_timer_work, strp_msg_timeout); + INIT_WORK(&strp->work, strp_work); + + return 0; +} +EXPORT_SYMBOL_GPL(strp_init); + +/* Sock process lock held (lock_sock) */ +void __strp_unpause(struct strparser *strp) +{ + strp->paused = 0; + + if (strp->need_bytes) { + if (strp_peek_len(strp) < strp->need_bytes) + return; + } + strp_read_sock(strp); +} +EXPORT_SYMBOL_GPL(__strp_unpause); + +void strp_unpause(struct strparser *strp) +{ + strp->paused = 0; + + /* Sync setting paused with RX work */ + smp_mb(); + + queue_work(strp_wq, &strp->work); +} +EXPORT_SYMBOL_GPL(strp_unpause); + +/* strp must already be stopped so that strp_recv will no longer be called. + * Note that strp_done is not called with the lower socket held. 
+ */ +void strp_done(struct strparser *strp) +{ + WARN_ON(!strp->stopped); + + cancel_delayed_work_sync(&strp->msg_timer_work); + cancel_work_sync(&strp->work); + + if (strp->skb_head) { + kfree_skb(strp->skb_head); + strp->skb_head = NULL; + } +} +EXPORT_SYMBOL_GPL(strp_done); + +void strp_stop(struct strparser *strp) +{ + strp->stopped = 1; +} +EXPORT_SYMBOL_GPL(strp_stop); + +void strp_check_rcv(struct strparser *strp) +{ + queue_work(strp_wq, &strp->work); +} +EXPORT_SYMBOL_GPL(strp_check_rcv); + +static int __init strp_dev_init(void) +{ + BUILD_BUG_ON(sizeof(struct sk_skb_cb) > + sizeof_field(struct sk_buff, cb)); + + strp_wq = create_singlethread_workqueue("kstrp"); + if (unlikely(!strp_wq)) + return -ENOMEM; + + return 0; +} +device_initcall(strp_dev_init); diff --git a/net/sunrpc/Kconfig b/net/sunrpc/Kconfig new file mode 100644 index 000000000..bbbb5af0a --- /dev/null +++ b/net/sunrpc/Kconfig @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: GPL-2.0-only +config SUNRPC + tristate + depends on MULTIUSER + +config SUNRPC_GSS + tristate + select OID_REGISTRY + depends on MULTIUSER + +config SUNRPC_BACKCHANNEL + bool + depends on SUNRPC + +config SUNRPC_SWAP + bool + depends on SUNRPC + +config RPCSEC_GSS_KRB5 + tristate "Secure RPC: Kerberos V mechanism" + depends on SUNRPC && CRYPTO + depends on CRYPTO_MD5 && CRYPTO_DES && CRYPTO_CBC && CRYPTO_CTS + depends on CRYPTO_ECB && CRYPTO_HMAC && CRYPTO_SHA1 && CRYPTO_AES + default y + select SUNRPC_GSS + help + Choose Y here to enable Secure RPC using the Kerberos version 5 + GSS-API mechanism (RFC 1964). + + Secure RPC calls with Kerberos require an auxiliary user-space + daemon which may be found in the Linux nfs-utils package + available from http://linux-nfs.org/. In addition, user-space + Kerberos support should be installed. + + If unsure, say Y. + +config SUNRPC_DISABLE_INSECURE_ENCTYPES + bool "Secure RPC: Disable insecure Kerberos encryption types" + depends on RPCSEC_GSS_KRB5 + default n + help + Choose Y here to disable the use of deprecated encryption types + with the Kerberos version 5 GSS-API mechanism (RFC 1964). The + deprecated encryption types include DES-CBC-MD5, DES-CBC-CRC, + and DES-CBC-MD4. These types were deprecated by RFC 6649 because + they were found to be insecure. + + N is the default because many sites have deployed KDCs and + keytabs that contain only these deprecated encryption types. + Choosing Y prevents the use of known-insecure encryption types + but might result in compatibility problems. + +config SUNRPC_DEBUG + bool "RPC: Enable dprintk debugging" + depends on SUNRPC && SYSCTL + select DEBUG_FS + help + This option enables a sysctl-based debugging interface + that is be used by the 'rpcdebug' utility to turn on or off + logging of different aspects of the kernel RPC activity. + + Disabling this option will make your kernel slightly smaller, + but makes troubleshooting NFS issues significantly harder. + + If unsure, say Y. + +config SUNRPC_XPRT_RDMA + tristate "RPC-over-RDMA transport" + depends on SUNRPC && INFINIBAND && INFINIBAND_ADDR_TRANS + default SUNRPC && INFINIBAND + select SG_POOL + help + This option allows the NFS client and server to use RDMA + transports (InfiniBand, iWARP, or RoCE). + + To compile this support as a module, choose M. The module + will be called rpcrdma.ko. + + If unsure, or you know there is no RDMA capability on your + hardware platform, say N. 
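Before moving on to the SUNRPC sources, a minimal sketch of how the strparser above might be driven in "general mode" (sk == NULL, data fed in by hand via strp_process()). This is an editorial illustration rather than code from the patch: the 4-byte big-endian length-prefix framing and every demo_* name are assumptions made for the example.

#include <linux/kernel.h>
#include <linux/skbuff.h>
#include <linux/mutex.h>
#include <linux/jiffies.h>
#include <net/strparser.h>

struct demo_parser {
	struct strparser strp;
	struct mutex lock;	/* general mode must supply lock/unlock */
};

static void demo_lock(struct strparser *strp)
{
	mutex_lock(&container_of(strp, struct demo_parser, strp)->lock);
}

static void demo_unlock(struct strparser *strp)
{
	mutex_unlock(&container_of(strp, struct demo_parser, strp)->lock);
}

/* Return the full frame length, 0 if the header is not complete yet,
 * or a negative errno to abort the stream.
 */
static int demo_parse_msg(struct strparser *strp, struct sk_buff *skb)
{
	struct strp_msg *stm = strp_msg(skb);
	__be32 hdr;

	if (skb_copy_bits(skb, stm->offset, &hdr, sizeof(hdr)) < 0)
		return 0;			/* need more header bytes */
	if (be32_to_cpu(hdr) < sizeof(hdr))
		return -EPROTO;			/* length cannot cover its own prefix */
	return be32_to_cpu(hdr);		/* assumed to include the prefix */
}

static void demo_rcv_msg(struct strparser *strp, struct sk_buff *skb)
{
	/* strp_msg(skb)->offset/->full_len frame the message; a real
	 * consumer would queue or process it.  The sketch just drops it.
	 */
	kfree_skb(skb);
}

static const struct strp_callbacks demo_cb = {
	.rcv_msg	= demo_rcv_msg,
	.parse_msg	= demo_parse_msg,
	.lock		= demo_lock,
	.unlock		= demo_unlock,
};

static int demo_parser_setup(struct demo_parser *p)
{
	mutex_init(&p->lock);
	return strp_init(&p->strp, NULL /* general mode */, &demo_cb);
}

/* Feed one skb worth of stream data into the parser. */
static int demo_feed(struct demo_parser *p, struct sk_buff *skb)
{
	return strp_process(&p->strp, skb, 0, skb->len,
			    64 * 1024 /* max_msg_size */, 5 * HZ /* timeo */);
}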
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile new file mode 100644 index 000000000..1c8de397d --- /dev/null +++ b/net/sunrpc/Makefile @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Linux kernel SUN RPC +# + + +obj-$(CONFIG_SUNRPC) += sunrpc.o +obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ + +sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ + auth.o auth_null.o auth_unix.o \ + svc.o svcsock.o svcauth.o svcauth_unix.o \ + addr.o rpcb_clnt.o timer.o xdr.o \ + sunrpc_syms.o cache.o rpc_pipe.o sysfs.o \ + svc_xprt.o \ + xprtmultipath.o +sunrpc-$(CONFIG_SUNRPC_DEBUG) += debugfs.o +sunrpc-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel_rqst.o +sunrpc-$(CONFIG_PROC_FS) += stats.o +sunrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/sunrpc/addr.c b/net/sunrpc/addr.c new file mode 100644 index 000000000..d435bffc6 --- /dev/null +++ b/net/sunrpc/addr.c @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Copyright 2009, Oracle. All rights reserved. + * + * Convert socket addresses to presentation addresses and universal + * addresses, and vice versa. + * + * Universal addresses are introduced by RFC 1833 and further refined by + * recent RFCs describing NFSv4. The universal address format is part + * of the external (network) interface provided by rpcbind version 3 + * and 4, and by NFSv4. Such an address is a string containing a + * presentation format IP address followed by a port number in + * "hibyte.lobyte" format. + * + * IPv6 addresses can also include a scope ID, typically denoted by + * a '%' followed by a device name or a non-negative integer. Refer to + * RFC 4291, Section 2.2 for details on IPv6 presentation formats. + */ + +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) + +static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, + char *buf, const int buflen) +{ + const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + const struct in6_addr *addr = &sin6->sin6_addr; + + /* + * RFC 4291, Section 2.2.2 + * + * Shorthanded ANY address + */ + if (ipv6_addr_any(addr)) + return snprintf(buf, buflen, "::"); + + /* + * RFC 4291, Section 2.2.2 + * + * Shorthanded loopback address + */ + if (ipv6_addr_loopback(addr)) + return snprintf(buf, buflen, "::1"); + + /* + * RFC 4291, Section 2.2.3 + * + * Special presentation address format for mapped v4 + * addresses. 
+ */ + if (ipv6_addr_v4mapped(addr)) + return snprintf(buf, buflen, "::ffff:%pI4", + &addr->s6_addr32[3]); + + /* + * RFC 4291, Section 2.2.1 + */ + return snprintf(buf, buflen, "%pI6c", addr); +} + +static size_t rpc_ntop6(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + const struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + char scopebuf[IPV6_SCOPE_ID_LEN]; + size_t len; + int rc; + + len = rpc_ntop6_noscopeid(sap, buf, buflen); + if (unlikely(len == 0)) + return len; + + if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL)) + return len; + if (sin6->sin6_scope_id == 0) + return len; + + rc = snprintf(scopebuf, sizeof(scopebuf), "%c%u", + IPV6_SCOPE_DELIMITER, sin6->sin6_scope_id); + if (unlikely((size_t)rc >= sizeof(scopebuf))) + return 0; + + len += rc; + if (unlikely(len >= buflen)) + return 0; + + strcat(buf, scopebuf); + return len; +} + +#else /* !IS_ENABLED(CONFIG_IPV6) */ + +static size_t rpc_ntop6_noscopeid(const struct sockaddr *sap, + char *buf, const int buflen) +{ + return 0; +} + +static size_t rpc_ntop6(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + return 0; +} + +#endif /* !IS_ENABLED(CONFIG_IPV6) */ + +static int rpc_ntop4(const struct sockaddr *sap, + char *buf, const size_t buflen) +{ + const struct sockaddr_in *sin = (struct sockaddr_in *)sap; + + return snprintf(buf, buflen, "%pI4", &sin->sin_addr); +} + +/** + * rpc_ntop - construct a presentation address in @buf + * @sap: socket address + * @buf: construction area + * @buflen: size of @buf, in bytes + * + * Plants a %NUL-terminated string in @buf and returns the length + * of the string, excluding the %NUL. Otherwise zero is returned. + */ +size_t rpc_ntop(const struct sockaddr *sap, char *buf, const size_t buflen) +{ + switch (sap->sa_family) { + case AF_INET: + return rpc_ntop4(sap, buf, buflen); + case AF_INET6: + return rpc_ntop6(sap, buf, buflen); + } + + return 0; +} +EXPORT_SYMBOL_GPL(rpc_ntop); + +static size_t rpc_pton4(const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)sap; + u8 *addr = (u8 *)&sin->sin_addr.s_addr; + + if (buflen > INET_ADDRSTRLEN || salen < sizeof(struct sockaddr_in)) + return 0; + + memset(sap, 0, sizeof(struct sockaddr_in)); + + if (in4_pton(buf, buflen, addr, '\0', NULL) == 0) + return 0; + + sin->sin_family = AF_INET; + return sizeof(struct sockaddr_in); +} + +#if IS_ENABLED(CONFIG_IPV6) +static int rpc_parse_scope_id(struct net *net, const char *buf, + const size_t buflen, const char *delim, + struct sockaddr_in6 *sin6) +{ + char p[IPV6_SCOPE_ID_LEN + 1]; + size_t len; + u32 scope_id = 0; + struct net_device *dev; + + if ((buf + buflen) == delim) + return 1; + + if (*delim != IPV6_SCOPE_DELIMITER) + return 0; + + if (!(ipv6_addr_type(&sin6->sin6_addr) & IPV6_ADDR_LINKLOCAL)) + return 0; + + len = (buf + buflen) - delim - 1; + if (len > IPV6_SCOPE_ID_LEN) + return 0; + + memcpy(p, delim + 1, len); + p[len] = 0; + + dev = dev_get_by_name(net, p); + if (dev != NULL) { + scope_id = dev->ifindex; + dev_put(dev); + } else { + if (kstrtou32(p, 10, &scope_id) != 0) + return 0; + } + + sin6->sin6_scope_id = scope_id; + return 1; +} + +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + u8 *addr = (u8 *)&sin6->sin6_addr.in6_u; + const char *delim; + + if (buflen > (INET6_ADDRSTRLEN + IPV6_SCOPE_ID_LEN) || + salen < 
sizeof(struct sockaddr_in6)) + return 0; + + memset(sap, 0, sizeof(struct sockaddr_in6)); + + if (in6_pton(buf, buflen, addr, IPV6_SCOPE_DELIMITER, &delim) == 0) + return 0; + + if (!rpc_parse_scope_id(net, buf, buflen, delim, sin6)) + return 0; + + sin6->sin6_family = AF_INET6; + return sizeof(struct sockaddr_in6); +} +#else +static size_t rpc_pton6(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + return 0; +} +#endif + +/** + * rpc_pton - Construct a sockaddr in @sap + * @net: applicable network namespace + * @buf: C string containing presentation format IP address + * @buflen: length of presentation address in bytes + * @sap: buffer into which to plant socket address + * @salen: size of buffer in bytes + * + * Returns the size of the socket address if successful; otherwise + * zero is returned. + * + * Plants a socket address in @sap and returns the size of the + * socket address, if successful. Returns zero if an error + * occurred. + */ +size_t rpc_pton(struct net *net, const char *buf, const size_t buflen, + struct sockaddr *sap, const size_t salen) +{ + unsigned int i; + + for (i = 0; i < buflen; i++) + if (buf[i] == ':') + return rpc_pton6(net, buf, buflen, sap, salen); + return rpc_pton4(buf, buflen, sap, salen); +} +EXPORT_SYMBOL_GPL(rpc_pton); + +/** + * rpc_sockaddr2uaddr - Construct a universal address string from @sap. + * @sap: socket address + * @gfp_flags: allocation mode + * + * Returns a %NUL-terminated string in dynamically allocated memory; + * otherwise NULL is returned if an error occurred. Caller must + * free the returned string. + */ +char *rpc_sockaddr2uaddr(const struct sockaddr *sap, gfp_t gfp_flags) +{ + char portbuf[RPCBIND_MAXUADDRPLEN]; + char addrbuf[RPCBIND_MAXUADDRLEN]; + unsigned short port; + + switch (sap->sa_family) { + case AF_INET: + if (rpc_ntop4(sap, addrbuf, sizeof(addrbuf)) == 0) + return NULL; + port = ntohs(((struct sockaddr_in *)sap)->sin_port); + break; + case AF_INET6: + if (rpc_ntop6_noscopeid(sap, addrbuf, sizeof(addrbuf)) == 0) + return NULL; + port = ntohs(((struct sockaddr_in6 *)sap)->sin6_port); + break; + default: + return NULL; + } + + if (snprintf(portbuf, sizeof(portbuf), + ".%u.%u", port >> 8, port & 0xff) > (int)sizeof(portbuf)) + return NULL; + + if (strlcat(addrbuf, portbuf, sizeof(addrbuf)) > sizeof(addrbuf)) + return NULL; + + return kstrdup(addrbuf, gfp_flags); +} + +/** + * rpc_uaddr2sockaddr - convert a universal address to a socket address. + * @net: applicable network namespace + * @uaddr: C string containing universal address to convert + * @uaddr_len: length of universal address string + * @sap: buffer into which to plant socket address + * @salen: size of buffer + * + * @uaddr does not have to be '\0'-terminated, but kstrtou8() and + * rpc_pton() require proper string termination to be successful. + * + * Returns the size of the socket address if successful; otherwise + * zero is returned. 
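To make the "hibyte.lobyte" convention that rpc_sockaddr2uaddr() and rpc_uaddr2sockaddr() convert between more concrete, here is a small self-contained sketch (plain userspace C, not kernel code; the address and port are made up for illustration):

#include <stdio.h>

int main(void)
{
	const char *ip = "10.0.0.5";	/* example presentation address */
	unsigned short port = 2049;	/* example port (NFS) */
	char uaddr[64];

	/* Append ".hibyte.lobyte", as rpc_sockaddr2uaddr() does. */
	snprintf(uaddr, sizeof(uaddr), "%s.%u.%u",
		 ip, port >> 8, port & 0xff);
	printf("%s\n", uaddr);			/* prints 10.0.0.5.8.1 */

	/* And reversing it, as rpc_uaddr2sockaddr() does: */
	unsigned int hi = 8, lo = 1;
	printf("port=%u\n", (hi << 8) | lo);	/* prints port=2049 */
	return 0;
}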
+ */ +size_t rpc_uaddr2sockaddr(struct net *net, const char *uaddr, + const size_t uaddr_len, struct sockaddr *sap, + const size_t salen) +{ + char *c, buf[RPCBIND_MAXUADDRLEN + sizeof('\0')]; + u8 portlo, porthi; + unsigned short port; + + if (uaddr_len > RPCBIND_MAXUADDRLEN) + return 0; + + memcpy(buf, uaddr, uaddr_len); + + buf[uaddr_len] = '\0'; + c = strrchr(buf, '.'); + if (unlikely(c == NULL)) + return 0; + if (unlikely(kstrtou8(c + 1, 10, &portlo) != 0)) + return 0; + + *c = '\0'; + c = strrchr(buf, '.'); + if (unlikely(c == NULL)) + return 0; + if (unlikely(kstrtou8(c + 1, 10, &porthi) != 0)) + return 0; + + port = (unsigned short)((porthi << 8) | portlo); + + *c = '\0'; + if (rpc_pton(net, buf, strlen(buf), sap, salen) == 0) + return 0; + + switch (sap->sa_family) { + case AF_INET: + ((struct sockaddr_in *)sap)->sin_port = htons(port); + return sizeof(struct sockaddr_in); + case AF_INET6: + ((struct sockaddr_in6 *)sap)->sin6_port = htons(port); + return sizeof(struct sockaddr_in6); + } + + return 0; +} +EXPORT_SYMBOL_GPL(rpc_uaddr2sockaddr); diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c new file mode 100644 index 000000000..fb75a8835 --- /dev/null +++ b/net/sunrpc/auth.c @@ -0,0 +1,891 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/auth.c + * + * Generic RPC client authentication API. + * + * Copyright (C) 1996, Olaf Kirch + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define RPC_CREDCACHE_DEFAULT_HASHBITS (4) +struct rpc_cred_cache { + struct hlist_head *hashtable; + unsigned int hashbits; + spinlock_t lock; +}; + +static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS; + +static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops, + [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops, + NULL, /* others can be loadable modules */ +}; + +static LIST_HEAD(cred_unused); +static unsigned long number_cred_unused; + +static struct cred machine_cred = { + .usage = ATOMIC_INIT(1), +#ifdef CONFIG_DEBUG_CREDENTIALS + .magic = CRED_MAGIC, +#endif +}; + +/* + * Return the machine_cred pointer to be used whenever + * the a generic machine credential is needed. 
+ */ +const struct cred *rpc_machine_cred(void) +{ + return &machine_cred; +} +EXPORT_SYMBOL_GPL(rpc_machine_cred); + +#define MAX_HASHTABLE_BITS (14) +static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp) +{ + unsigned long num; + unsigned int nbits; + int ret; + + if (!val) + goto out_inval; + ret = kstrtoul(val, 0, &num); + if (ret) + goto out_inval; + nbits = fls(num - 1); + if (nbits > MAX_HASHTABLE_BITS || nbits < 2) + goto out_inval; + *(unsigned int *)kp->arg = nbits; + return 0; +out_inval: + return -EINVAL; +} + +static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp) +{ + unsigned int nbits; + + nbits = *(unsigned int *)kp->arg; + return sprintf(buffer, "%u\n", 1U << nbits); +} + +#define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int); + +static const struct kernel_param_ops param_ops_hashtbl_sz = { + .set = param_set_hashtbl_sz, + .get = param_get_hashtbl_sz, +}; + +module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644); +MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size"); + +static unsigned long auth_max_cred_cachesize = ULONG_MAX; +module_param(auth_max_cred_cachesize, ulong, 0644); +MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size"); + +static u32 +pseudoflavor_to_flavor(u32 flavor) { + if (flavor > RPC_AUTH_MAXFLAVOR) + return RPC_AUTH_GSS; + return flavor; +} + +int +rpcauth_register(const struct rpc_authops *ops) +{ + const struct rpc_authops *old; + rpc_authflavor_t flavor; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops); + if (old == NULL || old == ops) + return 0; + return -EPERM; +} +EXPORT_SYMBOL_GPL(rpcauth_register); + +int +rpcauth_unregister(const struct rpc_authops *ops) +{ + const struct rpc_authops *old; + rpc_authflavor_t flavor; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + + old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL); + if (old == ops || old == NULL) + return 0; + return -EPERM; +} +EXPORT_SYMBOL_GPL(rpcauth_unregister); + +static const struct rpc_authops * +rpcauth_get_authops(rpc_authflavor_t flavor) +{ + const struct rpc_authops *ops; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) { + rcu_read_unlock(); + request_module("rpc-auth-%u", flavor); + rcu_read_lock(); + ops = rcu_dereference(auth_flavors[flavor]); + if (ops == NULL) + goto out; + } + if (!try_module_get(ops->owner)) + ops = NULL; +out: + rcu_read_unlock(); + return ops; +} + +static void +rpcauth_put_authops(const struct rpc_authops *ops) +{ + module_put(ops->owner); +} + +/** + * rpcauth_get_pseudoflavor - check if security flavor is supported + * @flavor: a security flavor + * @info: a GSS mech OID, quality of protection, and service value + * + * Verifies that an appropriate kernel module is available or already loaded. + * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is + * not supported locally. 
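The rounding performed by param_set_hashtbl_sz() above can be seen with a few sample values. A standalone userspace sketch (not kernel code; fls_demo() merely stands in for the kernel's fls()):

#include <stdio.h>

static unsigned int fls_demo(unsigned int x)	/* stands in for fls() */
{
	return x ? 32 - __builtin_clz(x) : 0;
}

int main(void)
{
	unsigned int requests[] = { 3, 16, 1000, 100000 };
	unsigned int i;

	for (i = 0; i < sizeof(requests) / sizeof(requests[0]); i++) {
		unsigned int nbits = fls_demo(requests[i] - 1);

		/* Reject outside [2, MAX_HASHTABLE_BITS] bits, as above. */
		if (nbits > 14 || nbits < 2)
			printf("%u -> rejected (-EINVAL)\n", requests[i]);
		else
			printf("%u -> %u hash buckets\n",
			       requests[i], 1U << nbits);
	}
	return 0;	/* 3 -> 4, 16 -> 16, 1000 -> 1024, 100000 -> rejected */
}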
+ */ +rpc_authflavor_t +rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info) +{ + const struct rpc_authops *ops = rpcauth_get_authops(flavor); + rpc_authflavor_t pseudoflavor; + + if (!ops) + return RPC_AUTH_MAXFLAVOR; + pseudoflavor = flavor; + if (ops->info2flavor != NULL) + pseudoflavor = ops->info2flavor(info); + + rpcauth_put_authops(ops); + return pseudoflavor; +} +EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor); + +/** + * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. + */ +int +rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info) +{ + rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor); + const struct rpc_authops *ops; + int result; + + ops = rpcauth_get_authops(flavor); + if (ops == NULL) + return -ENOENT; + + result = -ENOENT; + if (ops->flavor2info != NULL) + result = ops->flavor2info(pseudoflavor, info); + + rpcauth_put_authops(ops); + return result; +} +EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo); + +struct rpc_auth * +rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct rpc_auth *auth = ERR_PTR(-EINVAL); + const struct rpc_authops *ops; + u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor); + + ops = rpcauth_get_authops(flavor); + if (ops == NULL) + goto out; + + auth = ops->create(args, clnt); + + rpcauth_put_authops(ops); + if (IS_ERR(auth)) + return auth; + if (clnt->cl_auth) + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = auth; + +out: + return auth; +} +EXPORT_SYMBOL_GPL(rpcauth_create); + +void +rpcauth_release(struct rpc_auth *auth) +{ + if (!refcount_dec_and_test(&auth->au_count)) + return; + auth->au_ops->destroy(auth); +} + +static DEFINE_SPINLOCK(rpc_credcache_lock); + +/* + * On success, the caller is responsible for freeing the reference + * held by the hashtable + */ +static bool +rpcauth_unhash_cred_locked(struct rpc_cred *cred) +{ + if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; + hlist_del_rcu(&cred->cr_hash); + return true; +} + +static bool +rpcauth_unhash_cred(struct rpc_cred *cred) +{ + spinlock_t *cache_lock; + bool ret; + + if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + return false; + cache_lock = &cred->cr_auth->au_credcache->lock; + spin_lock(cache_lock); + ret = rpcauth_unhash_cred_locked(cred); + spin_unlock(cache_lock); + return ret; +} + +/* + * Initialize RPC credential cache + */ +int +rpcauth_init_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *new; + unsigned int hashsize; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + goto out_nocache; + new->hashbits = auth_hashbits; + hashsize = 1U << new->hashbits; + new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL); + if (!new->hashtable) + goto out_nohashtbl; + spin_lock_init(&new->lock); + auth->au_credcache = new; + return 0; +out_nohashtbl: + kfree(new); +out_nocache: + return -ENOMEM; +} +EXPORT_SYMBOL_GPL(rpcauth_init_credcache); + +char * +rpcauth_stringify_acceptor(struct rpc_cred *cred) +{ + if (!cred->cr_ops->crstringify_acceptor) + return NULL; + return cred->cr_ops->crstringify_acceptor(cred); +} +EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor); + +/* + * Destroy a list of credentials + */ +static inline +void rpcauth_destroy_credlist(struct list_head *head) +{ + struct rpc_cred *cred; + + while 
(!list_empty(head)) { + cred = list_entry(head->next, struct rpc_cred, cr_lru); + list_del_init(&cred->cr_lru); + put_rpccred(cred); + } +} + +static void +rpcauth_lru_add_locked(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + number_cred_unused++; + list_add_tail(&cred->cr_lru, &cred_unused); +} + +static void +rpcauth_lru_add(struct rpc_cred *cred) +{ + if (!list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_add_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + +static void +rpcauth_lru_remove_locked(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + number_cred_unused--; + list_del_init(&cred->cr_lru); +} + +static void +rpcauth_lru_remove(struct rpc_cred *cred) +{ + if (list_empty(&cred->cr_lru)) + return; + spin_lock(&rpc_credcache_lock); + rpcauth_lru_remove_locked(cred); + spin_unlock(&rpc_credcache_lock); +} + +/* + * Clear the RPC credential cache, and delete those credentials + * that are not referenced. + */ +void +rpcauth_clear_credcache(struct rpc_cred_cache *cache) +{ + LIST_HEAD(free); + struct hlist_head *head; + struct rpc_cred *cred; + unsigned int hashsize = 1U << cache->hashbits; + int i; + + spin_lock(&rpc_credcache_lock); + spin_lock(&cache->lock); + for (i = 0; i < hashsize; i++) { + head = &cache->hashtable[i]; + while (!hlist_empty(head)) { + cred = hlist_entry(head->first, struct rpc_cred, cr_hash); + rpcauth_unhash_cred_locked(cred); + /* Note: We now hold a reference to cred */ + rpcauth_lru_remove_locked(cred); + list_add_tail(&cred->cr_lru, &free); + } + } + spin_unlock(&cache->lock); + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); +} + +/* + * Destroy the RPC credential cache + */ +void +rpcauth_destroy_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *cache = auth->au_credcache; + + if (cache) { + auth->au_credcache = NULL; + rpcauth_clear_credcache(cache); + kfree(cache->hashtable); + kfree(cache); + } +} +EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); + + +#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ) + +/* + * Remove stale credentials. Avoid sleeping inside the loop. + */ +static long +rpcauth_prune_expired(struct list_head *free, int nr_to_scan) +{ + struct rpc_cred *cred, *next; + unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; + long freed = 0; + + list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { + + if (nr_to_scan-- == 0) + break; + if (refcount_read(&cred->cr_count) > 1) { + rpcauth_lru_remove_locked(cred); + continue; + } + /* + * Enforce a 60 second garbage collection moratorium + * Note that the cred_unused list must be time-ordered. + */ + if (time_in_range(cred->cr_expire, expired, jiffies)) + continue; + if (!rpcauth_unhash_cred(cred)) + continue; + + rpcauth_lru_remove_locked(cred); + freed++; + list_add_tail(&cred->cr_lru, free); + } + return freed ? freed : SHRINK_STOP; +} + +static unsigned long +rpcauth_cache_do_shrink(int nr_to_scan) +{ + LIST_HEAD(free); + unsigned long freed; + + spin_lock(&rpc_credcache_lock); + freed = rpcauth_prune_expired(&free, nr_to_scan); + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); + + return freed; +} + +/* + * Run memory cache shrinker. 
+ */ +static unsigned long +rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc) + +{ + if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL) + return SHRINK_STOP; + + /* nothing left, don't come back */ + if (list_empty(&cred_unused)) + return SHRINK_STOP; + + return rpcauth_cache_do_shrink(sc->nr_to_scan); +} + +static unsigned long +rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc) + +{ + return number_cred_unused * sysctl_vfs_cache_pressure / 100; +} + +static void +rpcauth_cache_enforce_limit(void) +{ + unsigned long diff; + unsigned int nr_to_scan; + + if (number_cred_unused <= auth_max_cred_cachesize) + return; + diff = number_cred_unused - auth_max_cred_cachesize; + nr_to_scan = 100; + if (diff < nr_to_scan) + nr_to_scan = diff; + rpcauth_cache_do_shrink(nr_to_scan); +} + +/* + * Look up a process' credentials in the authentication cache + */ +struct rpc_cred * +rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, + int flags, gfp_t gfp) +{ + LIST_HEAD(free); + struct rpc_cred_cache *cache = auth->au_credcache; + struct rpc_cred *cred = NULL, + *entry, *new; + unsigned int nr; + + nr = auth->au_ops->hash_cred(acred, cache->hashbits); + + rcu_read_lock(); + hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + cred = get_rpccred(entry); + if (cred) + break; + } + rcu_read_unlock(); + + if (cred != NULL) + goto found; + + new = auth->au_ops->crcreate(auth, acred, flags, gfp); + if (IS_ERR(new)) { + cred = new; + goto out; + } + + spin_lock(&cache->lock); + hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + cred = get_rpccred(entry); + if (cred) + break; + } + if (cred == NULL) { + cred = new; + set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + refcount_inc(&cred->cr_count); + hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); + } else + list_add_tail(&new->cr_lru, &free); + spin_unlock(&cache->lock); + rpcauth_cache_enforce_limit(); +found: + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && + cred->cr_ops->cr_init != NULL && + !(flags & RPCAUTH_LOOKUP_NEW)) { + int res = cred->cr_ops->cr_init(auth, cred); + if (res < 0) { + put_rpccred(cred); + cred = ERR_PTR(res); + } + } + rpcauth_destroy_credlist(&free); +out: + return cred; +} +EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache); + +struct rpc_cred * +rpcauth_lookupcred(struct rpc_auth *auth, int flags) +{ + struct auth_cred acred; + struct rpc_cred *ret; + const struct cred *cred = current_cred(); + + memset(&acred, 0, sizeof(acred)); + acred.cred = cred; + ret = auth->au_ops->lookup_cred(auth, &acred, flags); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_lookupcred); + +void +rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, + struct rpc_auth *auth, const struct rpc_credops *ops) +{ + INIT_HLIST_NODE(&cred->cr_hash); + INIT_LIST_HEAD(&cred->cr_lru); + refcount_set(&cred->cr_count, 1); + cred->cr_auth = auth; + cred->cr_flags = 0; + cred->cr_ops = ops; + cred->cr_expire = jiffies; + cred->cr_cred = get_cred(acred->cred); +} +EXPORT_SYMBOL_GPL(rpcauth_init_cred); + +static struct rpc_cred * +rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .cred = get_task_cred(&init_task), + }; + struct rpc_cred *ret; + + if (RPC_IS_ASYNC(task)) + lookupflags |= RPCAUTH_LOOKUP_ASYNC; + ret = 
auth->au_ops->lookup_cred(auth, &acred, lookupflags); + put_cred(acred.cred); + return ret; +} + +static struct rpc_cred * +rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .principal = task->tk_client->cl_principal, + .cred = init_task.cred, + }; + + if (!acred.principal) + return NULL; + if (RPC_IS_ASYNC(task)) + lookupflags |= RPCAUTH_LOOKUP_ASYNC; + return auth->au_ops->lookup_cred(auth, &acred, lookupflags); +} + +static struct rpc_cred * +rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + + return rpcauth_lookupcred(auth, lookupflags); +} + +static int +rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *new = NULL; + int lookupflags = 0; + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .cred = cred, + }; + + if (flags & RPC_TASK_ASYNC) + lookupflags |= RPCAUTH_LOOKUP_NEW | RPCAUTH_LOOKUP_ASYNC; + if (task->tk_op_cred) + /* Task must use exactly this rpc_cred */ + new = get_rpccred(task->tk_op_cred); + else if (cred != NULL && cred != &machine_cred) + new = auth->au_ops->lookup_cred(auth, &acred, lookupflags); + else if (cred == &machine_cred) + new = rpcauth_bind_machine_cred(task, lookupflags); + + /* If machine cred couldn't be bound, try a root cred */ + if (new) + ; + else if (cred == &machine_cred) + new = rpcauth_bind_root_cred(task, lookupflags); + else if (flags & RPC_TASK_NULLCREDS) + new = authnull_ops.lookup_cred(NULL, NULL, 0); + else + new = rpcauth_bind_new_cred(task, lookupflags); + if (IS_ERR(new)) + return PTR_ERR(new); + put_rpccred(req->rq_cred); + req->rq_cred = new; + return 0; +} + +void +put_rpccred(struct rpc_cred *cred) +{ + if (cred == NULL) + return; + rcu_read_lock(); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; + if (refcount_read(&cred->cr_count) != 1 || + !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)) + goto out; + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) { + cred->cr_expire = jiffies; + rpcauth_lru_add(cred); + /* Race breaker */ + if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))) + rpcauth_lru_remove(cred); + } else if (rpcauth_unhash_cred(cred)) { + rpcauth_lru_remove(cred); + if (refcount_dec_and_test(&cred->cr_count)) + goto destroy; + } +out: + rcu_read_unlock(); + return; +destroy: + rcu_read_unlock(); + cred->cr_ops->crdestroy(cred); +} +EXPORT_SYMBOL_GPL(put_rpccred); + +/** + * rpcauth_marshcred - Append RPC credential to end of @xdr + * @task: controlling RPC task + * @xdr: xdr_stream containing initial portion of RPC Call header + * + * On success, an appropriate verifier is added to @xdr, @xdr is + * updated to point past the verifier, and zero is returned. + * Otherwise, @xdr is in an undefined state and a negative errno + * is returned. + */ +int rpcauth_marshcred(struct rpc_task *task, struct xdr_stream *xdr) +{ + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; + + return ops->crmarshal(task, xdr); +} + +/** + * rpcauth_wrap_req_encode - XDR encode the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message. + * Otherwise, @xdr is in an undefined state. 
+ */ +int rpcauth_wrap_req_encode(struct rpc_task *task, struct xdr_stream *xdr) +{ + kxdreproc_t encode = task->tk_msg.rpc_proc->p_encode; + + encode(task->tk_rqstp, xdr, task->tk_msg.rpc_argp); + return 0; +} +EXPORT_SYMBOL_GPL(rpcauth_wrap_req_encode); + +/** + * rpcauth_wrap_req - XDR encode and wrap the RPC procedure + * @task: controlling RPC task + * @xdr: stream where on-the-wire bytes are to be marshalled + * + * On success, @xdr contains the encoded and wrapped message, + * and zero is returned. Otherwise, @xdr is in an undefined + * state and a negative errno is returned. + */ +int rpcauth_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) +{ + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; + + return ops->crwrap_req(task, xdr); +} + +/** + * rpcauth_checkverf - Validate verifier in RPC Reply header + * @task: controlling RPC task + * @xdr: xdr_stream containing RPC Reply header + * + * On success, @xdr is updated to point past the verifier and + * zero is returned. Otherwise, @xdr is in an undefined state + * and a negative errno is returned. + */ +int +rpcauth_checkverf(struct rpc_task *task, struct xdr_stream *xdr) +{ + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; + + return ops->crvalidate(task, xdr); +} + +/** + * rpcauth_unwrap_resp_decode - Invoke XDR decode function + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. + */ +int +rpcauth_unwrap_resp_decode(struct rpc_task *task, struct xdr_stream *xdr) +{ + kxdrdproc_t decode = task->tk_msg.rpc_proc->p_decode; + + return decode(task->tk_rqstp, xdr, task->tk_msg.rpc_resp); +} +EXPORT_SYMBOL_GPL(rpcauth_unwrap_resp_decode); + +/** + * rpcauth_unwrap_resp - Invoke unwrap and decode function for the cred + * @task: controlling RPC task + * @xdr: stream where the Reply message resides + * + * Returns zero on success; otherwise a negative errno is returned. 
+ */ +int +rpcauth_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) +{ + const struct rpc_credops *ops = task->tk_rqstp->rq_cred->cr_ops; + + return ops->crunwrap_resp(task, xdr); +} + +bool +rpcauth_xmit_need_reencode(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + if (!cred || !cred->cr_ops->crneed_reencode) + return false; + return cred->cr_ops->crneed_reencode(task); +} + +int +rpcauth_refreshcred(struct rpc_task *task) +{ + struct rpc_cred *cred; + int err; + + cred = task->tk_rqstp->rq_cred; + if (cred == NULL) { + err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags); + if (err < 0) + goto out; + cred = task->tk_rqstp->rq_cred; + } + + err = cred->cr_ops->crrefresh(task); +out: + if (err < 0) + task->tk_status = err; + return err; +} + +void +rpcauth_invalcred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + if (cred) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); +} + +int +rpcauth_uptodatecred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + + return cred == NULL || + test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; +} + +static struct shrinker rpc_cred_shrinker = { + .count_objects = rpcauth_cache_shrink_count, + .scan_objects = rpcauth_cache_shrink_scan, + .seeks = DEFAULT_SEEKS, +}; + +int __init rpcauth_init_module(void) +{ + int err; + + err = rpc_init_authunix(); + if (err < 0) + goto out1; + err = register_shrinker(&rpc_cred_shrinker, "sunrpc_cred"); + if (err < 0) + goto out2; + return 0; +out2: + rpc_destroy_authunix(); +out1: + return err; +} + +void rpcauth_remove_module(void) +{ + rpc_destroy_authunix(); + unregister_shrinker(&rpc_cred_shrinker); +} diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile new file mode 100644 index 000000000..4a29f4c5d --- /dev/null +++ b/net/sunrpc/auth_gss/Makefile @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for Linux kernel rpcsec_gss implementation +# + +obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o + +auth_rpcgss-y := auth_gss.o gss_generic_token.o \ + gss_mech_switch.o svcauth_gss.o \ + gss_rpc_upcall.o gss_rpc_xdr.o trace.o + +obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o + +rpcsec_gss_krb5-y := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ + gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o gss_krb5_keys.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c new file mode 100644 index 000000000..2d7b1e031 --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -0,0 +1,2281 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/auth_gss/auth_gss.c + * + * RPCSEC_GSS client authentication. + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. 
+ * + * Dug Song + * Andy Adamson + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "auth_gss_internal.h" +#include "../netns.h" + +#include + +static const struct rpc_authops authgss_ops; + +static const struct rpc_credops gss_credops; +static const struct rpc_credops gss_nullops; + +#define GSS_RETRY_EXPIRED 5 +static unsigned int gss_expired_cred_retry_delay = GSS_RETRY_EXPIRED; + +#define GSS_KEY_EXPIRE_TIMEO 240 +static unsigned int gss_key_expire_timeo = GSS_KEY_EXPIRE_TIMEO; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define GSS_CRED_SLACK (RPC_MAX_AUTH_SIZE * 2) +/* length of a krb5 verifier (48), plus data added before arguments when + * using integrity (two 4-byte integers): */ +#define GSS_VERF_SLACK 100 + +static DEFINE_HASHTABLE(gss_auth_hash_table, 4); +static DEFINE_SPINLOCK(gss_auth_hash_lock); + +struct gss_pipe { + struct rpc_pipe_dir_object pdo; + struct rpc_pipe *pipe; + struct rpc_clnt *clnt; + const char *name; + struct kref kref; +}; + +struct gss_auth { + struct kref kref; + struct hlist_node hash; + struct rpc_auth rpc_auth; + struct gss_api_mech *mech; + enum rpc_gss_svc service; + struct rpc_clnt *client; + struct net *net; + netns_tracker ns_tracker; + /* + * There are two upcall pipes; dentry[1], named "gssd", is used + * for the new text-based upcall; dentry[0] is named after the + * mechanism (for example, "krb5") and exists for + * backwards-compatibility with older gssd's. + */ + struct gss_pipe *gss_pipe[2]; + const char *target_name; +}; + +/* pipe_version >= 0 if and only if someone has a pipe open. */ +static DEFINE_SPINLOCK(pipe_version_lock); +static struct rpc_wait_queue pipe_version_rpc_waitqueue; +static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue); +static void gss_put_auth(struct gss_auth *gss_auth); + +static void gss_free_ctx(struct gss_cl_ctx *); +static const struct rpc_pipe_ops gss_upcall_ops_v0; +static const struct rpc_pipe_ops gss_upcall_ops_v1; + +static inline struct gss_cl_ctx * +gss_get_ctx(struct gss_cl_ctx *ctx) +{ + refcount_inc(&ctx->count); + return ctx; +} + +static inline void +gss_put_ctx(struct gss_cl_ctx *ctx) +{ + if (refcount_dec_and_test(&ctx->count)) + gss_free_ctx(ctx); +} + +/* gss_cred_set_ctx: + * called by gss_upcall_callback and gss_create_upcall in order + * to set the gss context. The actual exchange of an old context + * and a new one is protected by the pipe->lock. + */ +static void +gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + return; + gss_get_ctx(ctx); + rcu_assign_pointer(gss_cred->gc_ctx, ctx); + set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + smp_mb__before_atomic(); + clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags); +} + +static struct gss_cl_ctx * +gss_cred_get_ctx(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx = NULL; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (ctx) + gss_get_ctx(ctx); + rcu_read_unlock(); + return ctx; +} + +static struct gss_cl_ctx * +gss_alloc_context(void) +{ + struct gss_cl_ctx *ctx; + + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + if (ctx != NULL) { + ctx->gc_proc = RPC_GSS_PROC_DATA; + ctx->gc_seq = 1; /* NetApp 6.4R1 doesn't accept seq. no. 
0 */ + spin_lock_init(&ctx->gc_seq_lock); + refcount_set(&ctx->count,1); + } + return ctx; +} + +#define GSSD_MIN_TIMEOUT (60 * 60) +static const void * +gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm) +{ + const void *q; + unsigned int seclen; + unsigned int timeout; + unsigned long now = jiffies; + u32 window_size; + int ret; + + /* First unsigned int gives the remaining lifetime in seconds of the + * credential - e.g. the remaining TGT lifetime for Kerberos or + * the -t value passed to GSSD. + */ + p = simple_get_bytes(p, end, &timeout, sizeof(timeout)); + if (IS_ERR(p)) + goto err; + if (timeout == 0) + timeout = GSSD_MIN_TIMEOUT; + ctx->gc_expiry = now + ((unsigned long)timeout * HZ); + /* Sequence number window. Determines the maximum number of + * simultaneous requests + */ + p = simple_get_bytes(p, end, &window_size, sizeof(window_size)); + if (IS_ERR(p)) + goto err; + ctx->gc_win = window_size; + /* gssd signals an error by passing ctx->gc_win = 0: */ + if (ctx->gc_win == 0) { + /* + * in which case, p points to an error code. Anything other + * than -EKEYEXPIRED gets converted to -EACCES. + */ + p = simple_get_bytes(p, end, &ret, sizeof(ret)); + if (!IS_ERR(p)) + p = (ret == -EKEYEXPIRED) ? ERR_PTR(-EKEYEXPIRED) : + ERR_PTR(-EACCES); + goto err; + } + /* copy the opaque wire context */ + p = simple_get_netobj(p, end, &ctx->gc_wire_ctx); + if (IS_ERR(p)) + goto err; + /* import the opaque security context */ + p = simple_get_bytes(p, end, &seclen, sizeof(seclen)); + if (IS_ERR(p)) + goto err; + q = (const void *)((const char *)p + seclen); + if (unlikely(q > end || q < p)) { + p = ERR_PTR(-EFAULT); + goto err; + } + ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_KERNEL); + if (ret < 0) { + trace_rpcgss_import_ctx(ret); + p = ERR_PTR(ret); + goto err; + } + + /* is there any trailing data? */ + if (q == end) { + p = q; + goto done; + } + + /* pull in acceptor name (if there is one) */ + p = simple_get_netobj(q, end, &ctx->gc_acceptor); + if (IS_ERR(p)) + goto err; +done: + trace_rpcgss_context(window_size, ctx->gc_expiry, now, timeout, + ctx->gc_acceptor.len, ctx->gc_acceptor.data); +err: + return p; +} + +/* XXX: Need some documentation about why UPCALL_BUF_LEN is so small. + * Is user space expecting no more than UPCALL_BUF_LEN bytes? + * Note that there are now _two_ NI_MAXHOST sized data items + * being passed in this string. 
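For orientation, a rough standalone sketch (userspace C, not kernel code) of the downcall buffer that gss_fill_context() above, together with gss_pipe_downcall() further below, expects a user-space daemon to write: a uid, a context lifetime, a sequence window, and length-prefixed context blobs. The put_* helpers and the native-endian u32-length framing (per simple_get_netobj() in auth_gss_internal.h, which is not part of this hunk) are assumptions for illustration only.

#include <stdint.h>
#include <string.h>
#include <stdio.h>

static size_t put_u32(unsigned char *p, uint32_t v)
{
	memcpy(p, &v, sizeof(v));	/* native byte order, as simple_get_bytes() reads it */
	return sizeof(v);
}

static size_t put_netobj(unsigned char *p, const void *data, uint32_t len)
{
	size_t off = put_u32(p, len);	/* u32 length, then the bytes */

	memcpy(p + off, data, len);
	return off + len;
}

int main(void)
{
	unsigned char buf[1024];
	const char wire_ctx[] = "\x01\x02\x03\x04";	/* made-up context handle */
	const char sec_ctx[]  = "opaque-mech-context";	/* made-up mechanism blob */
	size_t off = 0;

	off += put_u32(buf + off, 1000);	/* uid */
	off += put_u32(buf + off, 3600);	/* context lifetime, seconds */
	off += put_u32(buf + off, 128);		/* sequence window (non-zero) */
	off += put_netobj(buf + off, wire_ctx, sizeof(wire_ctx) - 1);
	off += put_netobj(buf + off, sec_ctx, sizeof(sec_ctx) - 1);
	/* an acceptor-name netobj may optionally follow */

	printf("downcall message is %zu bytes\n", off);
	return 0;
}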
+ */ +#define UPCALL_BUF_LEN 256 + +struct gss_upcall_msg { + refcount_t count; + kuid_t uid; + const char *service_name; + struct rpc_pipe_msg msg; + struct list_head list; + struct gss_auth *auth; + struct rpc_pipe *pipe; + struct rpc_wait_queue rpc_waitqueue; + wait_queue_head_t waitqueue; + struct gss_cl_ctx *ctx; + char databuf[UPCALL_BUF_LEN]; +}; + +static int get_pipe_version(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; + + spin_lock(&pipe_version_lock); + if (sn->pipe_version >= 0) { + atomic_inc(&sn->pipe_users); + ret = sn->pipe_version; + } else + ret = -EAGAIN; + spin_unlock(&pipe_version_lock); + return ret; +} + +static void put_pipe_version(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (atomic_dec_and_lock(&sn->pipe_users, &pipe_version_lock)) { + sn->pipe_version = -1; + spin_unlock(&pipe_version_lock); + } +} + +static void +gss_release_msg(struct gss_upcall_msg *gss_msg) +{ + struct net *net = gss_msg->auth->net; + if (!refcount_dec_and_test(&gss_msg->count)) + return; + put_pipe_version(net); + BUG_ON(!list_empty(&gss_msg->list)); + if (gss_msg->ctx != NULL) + gss_put_ctx(gss_msg->ctx); + rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue); + gss_put_auth(gss_msg->auth); + kfree_const(gss_msg->service_name); + kfree(gss_msg); +} + +static struct gss_upcall_msg * +__gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) + continue; + if (pos->auth->service != auth->service) + continue; + refcount_inc(&pos->count); + return pos; + } + return NULL; +} + +/* Try to add an upcall to the pipefs queue. + * If an upcall owned by our uid already exists, then we return a reference + * to that upcall instead of adding the new upcall. 
+ */ +static inline struct gss_upcall_msg * +gss_add_msg(struct gss_upcall_msg *gss_msg) +{ + struct rpc_pipe *pipe = gss_msg->pipe; + struct gss_upcall_msg *old; + + spin_lock(&pipe->lock); + old = __gss_find_upcall(pipe, gss_msg->uid, gss_msg->auth); + if (old == NULL) { + refcount_inc(&gss_msg->count); + list_add(&gss_msg->list, &pipe->in_downcall); + } else + gss_msg = old; + spin_unlock(&pipe->lock); + return gss_msg; +} + +static void +__gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + list_del_init(&gss_msg->list); + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); + wake_up_all(&gss_msg->waitqueue); + refcount_dec(&gss_msg->count); +} + +static void +gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + struct rpc_pipe *pipe = gss_msg->pipe; + + if (list_empty(&gss_msg->list)) + return; + spin_lock(&pipe->lock); + if (!list_empty(&gss_msg->list)) + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); +} + +static void +gss_handle_downcall_result(struct gss_cred *gss_cred, struct gss_upcall_msg *gss_msg) +{ + switch (gss_msg->msg.errno) { + case 0: + if (gss_msg->ctx == NULL) + break; + clear_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + gss_cred_set_ctx(&gss_cred->gc_base, gss_msg->ctx); + break; + case -EKEYEXPIRED: + set_bit(RPCAUTH_CRED_NEGATIVE, &gss_cred->gc_base.cr_flags); + } + gss_cred->gc_upcall_timestamp = jiffies; + gss_cred->gc_upcall = NULL; + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); +} + +static void +gss_upcall_callback(struct rpc_task *task) +{ + struct gss_cred *gss_cred = container_of(task->tk_rqstp->rq_cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall; + struct rpc_pipe *pipe = gss_msg->pipe; + + spin_lock(&pipe->lock); + gss_handle_downcall_result(gss_cred, gss_msg); + spin_unlock(&pipe->lock); + task->tk_status = gss_msg->msg.errno; + gss_release_msg(gss_msg); +} + +static void gss_encode_v0_msg(struct gss_upcall_msg *gss_msg, + const struct cred *cred) +{ + struct user_namespace *userns = cred->user_ns; + + uid_t uid = from_kuid_munged(userns, gss_msg->uid); + memcpy(gss_msg->databuf, &uid, sizeof(uid)); + gss_msg->msg.data = gss_msg->databuf; + gss_msg->msg.len = sizeof(uid); + + BUILD_BUG_ON(sizeof(uid) > sizeof(gss_msg->databuf)); +} + +static ssize_t +gss_v0_upcall(struct file *file, struct rpc_pipe_msg *msg, + char __user *buf, size_t buflen) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, + struct gss_upcall_msg, + msg); + if (msg->copied == 0) + gss_encode_v0_msg(gss_msg, file->f_cred); + return rpc_pipe_generic_upcall(file, msg, buf, buflen); +} + +static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg, + const char *service_name, + const char *target_name, + const struct cred *cred) +{ + struct user_namespace *userns = cred->user_ns; + struct gss_api_mech *mech = gss_msg->auth->mech; + char *p = gss_msg->databuf; + size_t buflen = sizeof(gss_msg->databuf); + int len; + + len = scnprintf(p, buflen, "mech=%s uid=%d", mech->gm_name, + from_kuid_munged(userns, gss_msg->uid)); + buflen -= len; + p += len; + gss_msg->msg.len = len; + + /* + * target= is a full service principal that names the remote + * identity that we are authenticating to. + */ + if (target_name) { + len = scnprintf(p, buflen, " target=%s", target_name); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + + /* + * gssd uses service= and srchost= to select a matching key from + * the system's keytab to use as the source principal. 
+ * + * service= is the service name part of the source principal, + * or "*" (meaning choose any). + * + * srchost= is the hostname part of the source principal. When + * not provided, gssd uses the local hostname. + */ + if (service_name) { + char *c = strchr(service_name, '@'); + + if (!c) + len = scnprintf(p, buflen, " service=%s", + service_name); + else + len = scnprintf(p, buflen, + " service=%.*s srchost=%s", + (int)(c - service_name), + service_name, c + 1); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + + if (mech->gm_upcall_enctypes) { + len = scnprintf(p, buflen, " enctypes=%s", + mech->gm_upcall_enctypes); + buflen -= len; + p += len; + gss_msg->msg.len += len; + } + trace_rpcgss_upcall_msg(gss_msg->databuf); + len = scnprintf(p, buflen, "\n"); + if (len == 0) + goto out_overflow; + gss_msg->msg.len += len; + gss_msg->msg.data = gss_msg->databuf; + return 0; +out_overflow: + WARN_ON_ONCE(1); + return -ENOMEM; +} + +static ssize_t +gss_v1_upcall(struct file *file, struct rpc_pipe_msg *msg, + char __user *buf, size_t buflen) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, + struct gss_upcall_msg, + msg); + int err; + if (msg->copied == 0) { + err = gss_encode_v1_msg(gss_msg, + gss_msg->service_name, + gss_msg->auth->target_name, + file->f_cred); + if (err) + return err; + } + return rpc_pipe_generic_upcall(file, msg, buf, buflen); +} + +static struct gss_upcall_msg * +gss_alloc_msg(struct gss_auth *gss_auth, + kuid_t uid, const char *service_name) +{ + struct gss_upcall_msg *gss_msg; + int vers; + int err = -ENOMEM; + + gss_msg = kzalloc(sizeof(*gss_msg), GFP_KERNEL); + if (gss_msg == NULL) + goto err; + vers = get_pipe_version(gss_auth->net); + err = vers; + if (err < 0) + goto err_free_msg; + gss_msg->pipe = gss_auth->gss_pipe[vers]->pipe; + INIT_LIST_HEAD(&gss_msg->list); + rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); + init_waitqueue_head(&gss_msg->waitqueue); + refcount_set(&gss_msg->count, 1); + gss_msg->uid = uid; + gss_msg->auth = gss_auth; + kref_get(&gss_auth->kref); + if (service_name) { + gss_msg->service_name = kstrdup_const(service_name, GFP_KERNEL); + if (!gss_msg->service_name) { + err = -ENOMEM; + goto err_put_pipe_version; + } + } + return gss_msg; +err_put_pipe_version: + put_pipe_version(gss_auth->net); +err_free_msg: + kfree(gss_msg); +err: + return ERR_PTR(err); +} + +static struct gss_upcall_msg * +gss_setup_upcall(struct gss_auth *gss_auth, struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_new, *gss_msg; + kuid_t uid = cred->cr_cred->fsuid; + + gss_new = gss_alloc_msg(gss_auth, uid, gss_cred->gc_principal); + if (IS_ERR(gss_new)) + return gss_new; + gss_msg = gss_add_msg(gss_new); + if (gss_msg == gss_new) { + int res; + refcount_inc(&gss_msg->count); + res = rpc_queue_upcall(gss_new->pipe, &gss_new->msg); + if (res) { + gss_unhash_msg(gss_new); + refcount_dec(&gss_msg->count); + gss_release_msg(gss_new); + gss_msg = ERR_PTR(res); + } + } else + gss_release_msg(gss_new); + return gss_msg; +} + +static void warn_gssd(void) +{ + dprintk("AUTH_GSS upcall failed. 
Please check user daemon is running.\n"); +} + +static inline int +gss_refresh_upcall(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_auth *gss_auth = container_of(cred->cr_auth, + struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg; + struct rpc_pipe *pipe; + int err = 0; + + gss_msg = gss_setup_upcall(gss_auth, cred); + if (PTR_ERR(gss_msg) == -EAGAIN) { + /* XXX: warning on the first, under the assumption we + * shouldn't normally hit this case on a refresh. */ + warn_gssd(); + rpc_sleep_on_timeout(&pipe_version_rpc_waitqueue, + task, NULL, jiffies + (15 * HZ)); + err = -EAGAIN; + goto out; + } + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + pipe = gss_msg->pipe; + spin_lock(&pipe->lock); + if (gss_cred->gc_upcall != NULL) + rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL); + else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) { + gss_cred->gc_upcall = gss_msg; + /* gss_upcall_callback will release the reference to gss_upcall_msg */ + refcount_inc(&gss_msg->count); + rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback); + } else { + gss_handle_downcall_result(gss_cred, gss_msg); + err = gss_msg->msg.errno; + } + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); +out: + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); + return err; +} + +static inline int +gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) +{ + struct net *net = gss_auth->net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe; + struct rpc_cred *cred = &gss_cred->gc_base; + struct gss_upcall_msg *gss_msg; + DEFINE_WAIT(wait); + int err; + +retry: + err = 0; + /* if gssd is down, just skip upcalling altogether */ + if (!gssd_running(net)) { + warn_gssd(); + err = -EACCES; + goto out; + } + gss_msg = gss_setup_upcall(gss_auth, cred); + if (PTR_ERR(gss_msg) == -EAGAIN) { + err = wait_event_interruptible_timeout(pipe_version_waitqueue, + sn->pipe_version >= 0, 15 * HZ); + if (sn->pipe_version < 0) { + warn_gssd(); + err = -EACCES; + } + if (err < 0) + goto out; + goto retry; + } + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + pipe = gss_msg->pipe; + for (;;) { + prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_KILLABLE); + spin_lock(&pipe->lock); + if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) { + break; + } + spin_unlock(&pipe->lock); + if (fatal_signal_pending(current)) { + err = -ERESTARTSYS; + goto out_intr; + } + schedule(); + } + if (gss_msg->ctx) { + trace_rpcgss_ctx_init(gss_cred); + gss_cred_set_ctx(cred, gss_msg->ctx); + } else { + err = gss_msg->msg.errno; + } + spin_unlock(&pipe->lock); +out_intr: + finish_wait(&gss_msg->waitqueue, &wait); + gss_release_msg(gss_msg); +out: + trace_rpcgss_upcall_result(from_kuid(&init_user_ns, + cred->cr_cred->fsuid), err); + return err; +} + +static struct gss_upcall_msg * +gss_find_downcall(struct rpc_pipe *pipe, kuid_t uid) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &pipe->in_downcall, list) { + if (!uid_eq(pos->uid, uid)) + continue; + if (!rpc_msg_is_inflight(&pos->msg)) + continue; + refcount_inc(&pos->count); + return pos; + } + return NULL; +} + +#define MSG_BUF_MAXSIZE 1024 + +static ssize_t +gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) +{ + const void *p, *end; + void *buf; + struct gss_upcall_msg *gss_msg; + struct rpc_pipe *pipe 
= RPC_I(file_inode(filp))->pipe; + struct gss_cl_ctx *ctx; + uid_t id; + kuid_t uid; + ssize_t err = -EFBIG; + + if (mlen > MSG_BUF_MAXSIZE) + goto out; + err = -ENOMEM; + buf = kmalloc(mlen, GFP_KERNEL); + if (!buf) + goto out; + + err = -EFAULT; + if (copy_from_user(buf, src, mlen)) + goto err; + + end = (const void *)((char *)buf + mlen); + p = simple_get_bytes(buf, end, &id, sizeof(id)); + if (IS_ERR(p)) { + err = PTR_ERR(p); + goto err; + } + + uid = make_kuid(current_user_ns(), id); + if (!uid_valid(uid)) { + err = -EINVAL; + goto err; + } + + err = -ENOMEM; + ctx = gss_alloc_context(); + if (ctx == NULL) + goto err; + + err = -ENOENT; + /* Find a matching upcall */ + spin_lock(&pipe->lock); + gss_msg = gss_find_downcall(pipe, uid); + if (gss_msg == NULL) { + spin_unlock(&pipe->lock); + goto err_put_ctx; + } + list_del_init(&gss_msg->list); + spin_unlock(&pipe->lock); + + p = gss_fill_context(p, end, ctx, gss_msg->auth->mech); + if (IS_ERR(p)) { + err = PTR_ERR(p); + switch (err) { + case -EACCES: + case -EKEYEXPIRED: + gss_msg->msg.errno = err; + err = mlen; + break; + case -EFAULT: + case -ENOMEM: + case -EINVAL: + case -ENOSYS: + gss_msg->msg.errno = -EAGAIN; + break; + default: + printk(KERN_CRIT "%s: bad return from " + "gss_fill_context: %zd\n", __func__, err); + gss_msg->msg.errno = -EIO; + } + goto err_release_msg; + } + gss_msg->ctx = gss_get_ctx(ctx); + err = mlen; + +err_release_msg: + spin_lock(&pipe->lock); + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); +err_put_ctx: + gss_put_ctx(ctx); +err: + kfree(buf); +out: + return err; +} + +static int gss_pipe_open(struct inode *inode, int new_version) +{ + struct net *net = inode->i_sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret = 0; + + spin_lock(&pipe_version_lock); + if (sn->pipe_version < 0) { + /* First open of any gss pipe determines the version: */ + sn->pipe_version = new_version; + rpc_wake_up(&pipe_version_rpc_waitqueue); + wake_up(&pipe_version_waitqueue); + } else if (sn->pipe_version != new_version) { + /* Trying to open a pipe of a different version */ + ret = -EBUSY; + goto out; + } + atomic_inc(&sn->pipe_users); +out: + spin_unlock(&pipe_version_lock); + return ret; + +} + +static int gss_pipe_open_v0(struct inode *inode) +{ + return gss_pipe_open(inode, 0); +} + +static int gss_pipe_open_v1(struct inode *inode) +{ + return gss_pipe_open(inode, 1); +} + +static void +gss_pipe_release(struct inode *inode) +{ + struct net *net = inode->i_sb->s_fs_info; + struct rpc_pipe *pipe = RPC_I(inode)->pipe; + struct gss_upcall_msg *gss_msg; + +restart: + spin_lock(&pipe->lock); + list_for_each_entry(gss_msg, &pipe->in_downcall, list) { + + if (!list_empty(&gss_msg->msg.list)) + continue; + gss_msg->msg.errno = -EPIPE; + refcount_inc(&gss_msg->count); + __gss_unhash_msg(gss_msg); + spin_unlock(&pipe->lock); + gss_release_msg(gss_msg); + goto restart; + } + spin_unlock(&pipe->lock); + + put_pipe_version(net); +} + +static void +gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); + + if (msg->errno < 0) { + refcount_inc(&gss_msg->count); + gss_unhash_msg(gss_msg); + if (msg->errno == -ETIMEDOUT) + warn_gssd(); + gss_release_msg(gss_msg); + } + gss_release_msg(gss_msg); +} + +static void gss_pipe_dentry_destroy(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *gss_pipe = pdo->pdo_data; + struct rpc_pipe *pipe = gss_pipe->pipe; + + if (pipe->dentry != 
NULL) { + rpc_unlink(pipe->dentry); + pipe->dentry = NULL; + } +} + +static int gss_pipe_dentry_create(struct dentry *dir, + struct rpc_pipe_dir_object *pdo) +{ + struct gss_pipe *p = pdo->pdo_data; + struct dentry *dentry; + + dentry = rpc_mkpipe_dentry(dir, p->name, p->clnt, p->pipe); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + p->pipe->dentry = dentry; + return 0; +} + +static const struct rpc_pipe_dir_object_ops gss_pipe_dir_object_ops = { + .create = gss_pipe_dentry_create, + .destroy = gss_pipe_dentry_destroy, +}; + +static struct gss_pipe *gss_pipe_alloc(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct gss_pipe *p; + int err = -ENOMEM; + + p = kmalloc(sizeof(*p), GFP_KERNEL); + if (p == NULL) + goto err; + p->pipe = rpc_mkpipe_data(upcall_ops, RPC_PIPE_WAIT_FOR_OPEN); + if (IS_ERR(p->pipe)) { + err = PTR_ERR(p->pipe); + goto err_free_gss_pipe; + } + p->name = name; + p->clnt = clnt; + kref_init(&p->kref); + rpc_init_pipe_dir_object(&p->pdo, + &gss_pipe_dir_object_ops, + p); + return p; +err_free_gss_pipe: + kfree(p); +err: + return ERR_PTR(err); +} + +struct gss_alloc_pdo { + struct rpc_clnt *clnt; + const char *name; + const struct rpc_pipe_ops *upcall_ops; +}; + +static int gss_pipe_match_pdo(struct rpc_pipe_dir_object *pdo, void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + if (pdo->pdo_ops != &gss_pipe_dir_object_ops) + return 0; + gss_pipe = container_of(pdo, struct gss_pipe, pdo); + if (strcmp(gss_pipe->name, args->name) != 0) + return 0; + if (!kref_get_unless_zero(&gss_pipe->kref)) + return 0; + return 1; +} + +static struct rpc_pipe_dir_object *gss_pipe_alloc_pdo(void *data) +{ + struct gss_pipe *gss_pipe; + struct gss_alloc_pdo *args = data; + + gss_pipe = gss_pipe_alloc(args->clnt, args->name, args->upcall_ops); + if (!IS_ERR(gss_pipe)) + return &gss_pipe->pdo; + return NULL; +} + +static struct gss_pipe *gss_pipe_get(struct rpc_clnt *clnt, + const char *name, + const struct rpc_pipe_ops *upcall_ops) +{ + struct net *net = rpc_net_ns(clnt); + struct rpc_pipe_dir_object *pdo; + struct gss_alloc_pdo args = { + .clnt = clnt, + .name = name, + .upcall_ops = upcall_ops, + }; + + pdo = rpc_find_or_alloc_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + gss_pipe_match_pdo, + gss_pipe_alloc_pdo, + &args); + if (pdo != NULL) + return container_of(pdo, struct gss_pipe, pdo); + return ERR_PTR(-ENOMEM); +} + +static void __gss_pipe_free(struct gss_pipe *p) +{ + struct rpc_clnt *clnt = p->clnt; + struct net *net = rpc_net_ns(clnt); + + rpc_remove_pipe_dir_object(net, + &clnt->cl_pipedir_objects, + &p->pdo); + rpc_destroy_pipe_data(p->pipe); + kfree(p); +} + +static void __gss_pipe_release(struct kref *kref) +{ + struct gss_pipe *p = container_of(kref, struct gss_pipe, kref); + + __gss_pipe_free(p); +} + +static void gss_pipe_free(struct gss_pipe *p) +{ + if (p != NULL) + kref_put(&p->kref, __gss_pipe_release); +} + +/* + * NOTE: we have the opportunity to use different + * parameters based on the input flavor (which must be a pseudoflavor) + */ +static struct gss_auth * +gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + rpc_authflavor_t flavor = args->pseudoflavor; + struct gss_auth *gss_auth; + struct gss_pipe *gss_pipe; + struct rpc_auth * auth; + int err = -ENOMEM; /* XXX? 
*/ + + if (!try_module_get(THIS_MODULE)) + return ERR_PTR(err); + if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) + goto out_dec; + INIT_HLIST_NODE(&gss_auth->hash); + gss_auth->target_name = NULL; + if (args->target_name) { + gss_auth->target_name = kstrdup(args->target_name, GFP_KERNEL); + if (gss_auth->target_name == NULL) + goto err_free; + } + gss_auth->client = clnt; + gss_auth->net = get_net_track(rpc_net_ns(clnt), &gss_auth->ns_tracker, + GFP_KERNEL); + err = -EINVAL; + gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); + if (!gss_auth->mech) + goto err_put_net; + gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); + if (gss_auth->service == 0) + goto err_put_mech; + if (!gssd_running(gss_auth->net)) + goto err_put_mech; + auth = &gss_auth->rpc_auth; + auth->au_cslack = GSS_CRED_SLACK >> 2; + auth->au_rslack = GSS_KRB5_MAX_SLACK_NEEDED >> 2; + auth->au_verfsize = GSS_VERF_SLACK >> 2; + auth->au_ralign = GSS_VERF_SLACK >> 2; + __set_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags); + auth->au_ops = &authgss_ops; + auth->au_flavor = flavor; + if (gss_pseudoflavor_to_datatouch(gss_auth->mech, flavor)) + __set_bit(RPCAUTH_AUTH_DATATOUCH, &auth->au_flags); + refcount_set(&auth->au_count, 1); + kref_init(&gss_auth->kref); + + err = rpcauth_init_credcache(auth); + if (err) + goto err_put_mech; + /* + * Note: if we created the old pipe first, then someone who + * examined the directory at the right moment might conclude + * that we supported only the old pipe. So we instead create + * the new pipe first. + */ + gss_pipe = gss_pipe_get(clnt, "gssd", &gss_upcall_ops_v1); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_credcache; + } + gss_auth->gss_pipe[1] = gss_pipe; + + gss_pipe = gss_pipe_get(clnt, gss_auth->mech->gm_name, + &gss_upcall_ops_v0); + if (IS_ERR(gss_pipe)) { + err = PTR_ERR(gss_pipe); + goto err_destroy_pipe_1; + } + gss_auth->gss_pipe[0] = gss_pipe; + + return gss_auth; +err_destroy_pipe_1: + gss_pipe_free(gss_auth->gss_pipe[1]); +err_destroy_credcache: + rpcauth_destroy_credcache(auth); +err_put_mech: + gss_mech_put(gss_auth->mech); +err_put_net: + put_net_track(gss_auth->net, &gss_auth->ns_tracker); +err_free: + kfree(gss_auth->target_name); + kfree(gss_auth); +out_dec: + module_put(THIS_MODULE); + trace_rpcgss_createauth(flavor, err); + return ERR_PTR(err); +} + +static void +gss_free(struct gss_auth *gss_auth) +{ + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_mech_put(gss_auth->mech); + put_net_track(gss_auth->net, &gss_auth->ns_tracker); + kfree(gss_auth->target_name); + + kfree(gss_auth); + module_put(THIS_MODULE); +} + +static void +gss_free_callback(struct kref *kref) +{ + struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref); + + gss_free(gss_auth); +} + +static void +gss_put_auth(struct gss_auth *gss_auth) +{ + kref_put(&gss_auth->kref, gss_free_callback); +} + +static void +gss_destroy(struct rpc_auth *auth) +{ + struct gss_auth *gss_auth = container_of(auth, + struct gss_auth, rpc_auth); + + if (hash_hashed(&gss_auth->hash)) { + spin_lock(&gss_auth_hash_lock); + hash_del(&gss_auth->hash); + spin_unlock(&gss_auth_hash_lock); + } + + gss_pipe_free(gss_auth->gss_pipe[0]); + gss_auth->gss_pipe[0] = NULL; + gss_pipe_free(gss_auth->gss_pipe[1]); + gss_auth->gss_pipe[1] = NULL; + rpcauth_destroy_credcache(auth); + + gss_put_auth(gss_auth); +} + +/* + * Auths may be shared between rpc clients that were cloned from a + * common client with the same xprt, if 
they also share the flavor and + * target_name. + * + * The auth is looked up from the oldest parent sharing the same + * cl_xprt, and the auth itself references only that common parent + * (which is guaranteed to last as long as any of its descendants). + */ +static struct gss_auth * +gss_auth_find_or_add_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt, + struct gss_auth *new) +{ + struct gss_auth *gss_auth; + unsigned long hashval = (unsigned long)clnt; + + spin_lock(&gss_auth_hash_lock); + hash_for_each_possible(gss_auth_hash_table, + gss_auth, + hash, + hashval) { + if (gss_auth->client != clnt) + continue; + if (gss_auth->rpc_auth.au_flavor != args->pseudoflavor) + continue; + if (gss_auth->target_name != args->target_name) { + if (gss_auth->target_name == NULL) + continue; + if (args->target_name == NULL) + continue; + if (strcmp(gss_auth->target_name, args->target_name)) + continue; + } + if (!refcount_inc_not_zero(&gss_auth->rpc_auth.au_count)) + continue; + goto out; + } + if (new) + hash_add(gss_auth_hash_table, &new->hash, hashval); + gss_auth = new; +out: + spin_unlock(&gss_auth_hash_lock); + return gss_auth; +} + +static struct gss_auth * +gss_create_hashed(const struct rpc_auth_create_args *args, + struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct gss_auth *new; + + gss_auth = gss_auth_find_or_add_hashed(args, clnt, NULL); + if (gss_auth != NULL) + goto out; + new = gss_create_new(args, clnt); + if (IS_ERR(new)) + return new; + gss_auth = gss_auth_find_or_add_hashed(args, clnt, new); + if (gss_auth != new) + gss_destroy(&new->rpc_auth); +out: + return gss_auth; +} + +static struct rpc_auth * +gss_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + struct gss_auth *gss_auth; + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); + + while (clnt != clnt->cl_parent) { + struct rpc_clnt *parent = clnt->cl_parent; + /* Find the original parent for this transport */ + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) + break; + clnt = parent; + } + + gss_auth = gss_create_hashed(args, clnt); + if (IS_ERR(gss_auth)) + return ERR_CAST(gss_auth); + return &gss_auth->rpc_auth; +} + +static struct gss_cred * +gss_dup_cred(struct gss_auth *gss_auth, struct gss_cred *gss_cred) +{ + struct gss_cred *new; + + /* Make a copy of the cred so that we can reference count it */ + new = kzalloc(sizeof(*gss_cred), GFP_KERNEL); + if (new) { + struct auth_cred acred = { + .cred = gss_cred->gc_base.cr_cred, + }; + struct gss_cl_ctx *ctx = + rcu_dereference_protected(gss_cred->gc_ctx, 1); + + rpcauth_init_cred(&new->gc_base, &acred, + &gss_auth->rpc_auth, + &gss_nullops); + new->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + new->gc_service = gss_cred->gc_service; + new->gc_principal = gss_cred->gc_principal; + kref_get(&gss_auth->kref); + rcu_assign_pointer(new->gc_ctx, ctx); + gss_get_ctx(ctx); + } + return new; +} + +/* + * gss_send_destroy_context will cause the RPCSEC_GSS to send a NULL RPC call + * to the server with the GSS control procedure field set to + * RPC_GSS_PROC_DESTROY. This should normally cause the server to release + * all RPCSEC_GSS state associated with that context. 
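+ * The call is made asynchronously on a duplicated credential, and any
+ * error it returns is ignored: the original credential is being torn
+ * down regardless, so there is nothing useful left to do on failure.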
+ */ +static void +gss_send_destroy_context(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); + struct gss_cred *new; + struct rpc_task *task; + + new = gss_dup_cred(gss_auth, gss_cred); + if (new) { + ctx->gc_proc = RPC_GSS_PROC_DESTROY; + + trace_rpcgss_ctx_destroy(gss_cred); + task = rpc_call_null(gss_auth->client, &new->gc_base, + RPC_TASK_ASYNC); + if (!IS_ERR(task)) + rpc_put_task(task); + + put_rpccred(&new->gc_base); + } +} + +/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure + * to create a new cred or context, so they check that things have been + * allocated before freeing them. */ +static void +gss_do_free_ctx(struct gss_cl_ctx *ctx) +{ + gss_delete_sec_context(&ctx->gc_gss_ctx); + kfree(ctx->gc_wire_ctx.data); + kfree(ctx->gc_acceptor.data); + kfree(ctx); +} + +static void +gss_free_ctx_callback(struct rcu_head *head) +{ + struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu); + gss_do_free_ctx(ctx); +} + +static void +gss_free_ctx(struct gss_cl_ctx *ctx) +{ + call_rcu(&ctx->gc_rcu, gss_free_ctx_callback); +} + +static void +gss_free_cred(struct gss_cred *gss_cred) +{ + kfree(gss_cred); +} + +static void +gss_free_cred_callback(struct rcu_head *head) +{ + struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu); + gss_free_cred(gss_cred); +} + +static void +gss_destroy_nullcred(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = rcu_dereference_protected(gss_cred->gc_ctx, 1); + + RCU_INIT_POINTER(gss_cred->gc_ctx, NULL); + put_cred(cred->cr_cred); + call_rcu(&cred->cr_rcu, gss_free_cred_callback); + if (ctx) + gss_put_ctx(ctx); + gss_put_auth(gss_auth); +} + +static void +gss_destroy_cred(struct rpc_cred *cred) +{ + if (test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) + gss_send_destroy_context(cred); + gss_destroy_nullcred(cred); +} + +static int +gss_hash_cred(struct auth_cred *acred, unsigned int hashbits) +{ + return hash_64(from_kuid(&init_user_ns, acred->cred->fsuid), hashbits); +} + +/* + * Lookup RPCSEC_GSS cred for the current process + */ +static struct rpc_cred *gss_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(auth, acred, flags, + rpc_task_gfp_mask()); +} + +static struct rpc_cred * +gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t gfp) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *cred = NULL; + int err = -ENOMEM; + + if (!(cred = kzalloc(sizeof(*cred), gfp))) + goto out_err; + + rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops); + /* + * Note: in order to force a call to call_refresh(), we deliberately + * fail to flag the credential as RPCAUTH_CRED_UPTODATE. 
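+ * Instead the credential starts life with only RPCAUTH_CRED_NEW set,
+ * so its first use goes through gss_refresh() and upcalls to gssd to
+ * establish a GSS context before the cred is marked up to date.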
+ */ + cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW; + cred->gc_service = gss_auth->service; + cred->gc_principal = acred->principal; + kref_get(&gss_auth->kref); + return &cred->gc_base; + +out_err: + return ERR_PTR(err); +} + +static int +gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base); + int err; + + do { + err = gss_create_upcall(gss_auth, gss_cred); + } while (err == -EAGAIN); + return err; +} + +static char * +gss_stringify_acceptor(struct rpc_cred *cred) +{ + char *string = NULL; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned int len; + struct xdr_netobj *acceptor; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx) + goto out; + + len = ctx->gc_acceptor.len; + rcu_read_unlock(); + + /* no point if there's no string */ + if (!len) + return NULL; +realloc: + string = kmalloc(len + 1, GFP_KERNEL); + if (!string) + return NULL; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + + /* did the ctx disappear or was it replaced by one with no acceptor? */ + if (!ctx || !ctx->gc_acceptor.len) { + kfree(string); + string = NULL; + goto out; + } + + acceptor = &ctx->gc_acceptor; + + /* + * Did we find a new acceptor that's longer than the original? Allocate + * a longer buffer and try again. + */ + if (len < acceptor->len) { + len = acceptor->len; + rcu_read_unlock(); + kfree(string); + goto realloc; + } + + memcpy(string, acceptor->data, acceptor->len); + string[acceptor->len] = '\0'; +out: + rcu_read_unlock(); + return string; +} + +/* + * Returns -EACCES if GSS context is NULL or will expire within the + * timeout (miliseconds) + */ +static int +gss_key_timeout(struct rpc_cred *rc) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + unsigned long timeout = jiffies + (gss_key_expire_timeo * HZ); + int ret = 0; + + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(timeout, ctx->gc_expiry)) + ret = -EACCES; + rcu_read_unlock(); + + return ret; +} + +static int +gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx; + int ret; + + if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags)) + goto out; + /* Don't match with creds that have expired. */ + rcu_read_lock(); + ctx = rcu_dereference(gss_cred->gc_ctx); + if (!ctx || time_after(jiffies, ctx->gc_expiry)) { + rcu_read_unlock(); + return 0; + } + rcu_read_unlock(); + if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags)) + return 0; +out: + if (acred->principal != NULL) { + if (gss_cred->gc_principal == NULL) + return 0; + ret = strcmp(acred->principal, gss_cred->gc_principal) == 0; + } else { + if (gss_cred->gc_principal != NULL) + return 0; + ret = uid_eq(rc->cr_cred->fsuid, acred->cred->fsuid); + } + return ret; +} + +/* + * Marshal credentials. + * + * The expensive part is computing the verifier. We can't cache a + * pre-computed version of the verifier because the seqno, which + * is different every time, is included in the MIC. 
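+ *
+ * The credential body itself is cheap: RPC_GSS_VERSION, the context
+ * procedure, the per-request sequence number and the service level as
+ * XDR words, followed by the opaque wire context handle. The verifier
+ * is then a MIC computed over everything from the xid up to and
+ * including that credential.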
+ */ +static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *p, *cred_len; + u32 maj_stat = 0; + struct xdr_netobj mic; + struct kvec iov; + struct xdr_buf verf_buf; + int status; + + /* Credential */ + + p = xdr_reserve_space(xdr, 7 * sizeof(*p) + + ctx->gc_wire_ctx.len); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; + cred_len = p++; + + spin_lock(&ctx->gc_seq_lock); + req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ; + spin_unlock(&ctx->gc_seq_lock); + if (req->rq_seqno == MAXSEQ) + goto expired; + trace_rpcgss_seqno(task); + + *p++ = cpu_to_be32(RPC_GSS_VERSION); + *p++ = cpu_to_be32(ctx->gc_proc); + *p++ = cpu_to_be32(req->rq_seqno); + *p++ = cpu_to_be32(gss_cred->gc_service); + p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); + *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2); + + /* Verifier */ + + /* We compute the checksum for the verifier over the xdr-encoded bytes + * starting with the xid and ending at the end of the credential: */ + iov.iov_base = req->rq_snd_buf.head[0].iov_base; + iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; + xdr_buf_from_iov(&iov, &verf_buf); + + p = xdr_reserve_space(xdr, sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_gss; + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + goto expired; + else if (maj_stat != 0) + goto bad_mic; + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto marshal_failed; + status = 0; +out: + gss_put_ctx(ctx); + return status; +expired: + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + status = -EKEYEXPIRED; + goto out; +marshal_failed: + status = -EMSGSIZE; + goto out; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + status = -EIO; + goto out; +} + +static int gss_renew_cred(struct rpc_task *task) +{ + struct rpc_cred *oldcred = task->tk_rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(oldcred, + struct gss_cred, + gc_base); + struct rpc_auth *auth = oldcred->cr_auth; + struct auth_cred acred = { + .cred = oldcred->cr_cred, + .principal = gss_cred->gc_principal, + }; + struct rpc_cred *new; + + new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW); + if (IS_ERR(new)) + return PTR_ERR(new); + + task->tk_rqstp->rq_cred = new; + put_rpccred(oldcred); + return 0; +} + +static int gss_cred_is_negative_entry(struct rpc_cred *cred) +{ + if (test_bit(RPCAUTH_CRED_NEGATIVE, &cred->cr_flags)) { + unsigned long now = jiffies; + unsigned long begin, expire; + struct gss_cred *gss_cred; + + gss_cred = container_of(cred, struct gss_cred, gc_base); + begin = gss_cred->gc_upcall_timestamp; + expire = begin + gss_expired_cred_retry_delay * HZ; + + if (time_in_range_open(now, begin, expire)) + return 1; + } + return 0; +} + +/* +* Refresh credentials. 
XXX - finish +*/ +static int +gss_refresh(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + int ret = 0; + + if (gss_cred_is_negative_entry(cred)) + return -EKEYEXPIRED; + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && + !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { + ret = gss_renew_cred(task); + if (ret < 0) + goto out; + cred = task->tk_rqstp->rq_cred; + } + + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + ret = gss_refresh_upcall(task); +out: + return ret; +} + +/* Dummy refresh routine: used only when destroying the context */ +static int +gss_refresh_null(struct rpc_task *task) +{ + return 0; +} + +static int +gss_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *p, *seq = NULL; + struct kvec iov; + struct xdr_buf verf_buf; + struct xdr_netobj mic; + u32 len, maj_stat; + int status; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + goto validate_failed; + if (*p++ != rpc_auth_gss) + goto validate_failed; + len = be32_to_cpup(p); + if (len > RPC_MAX_AUTH_SIZE) + goto validate_failed; + p = xdr_inline_decode(xdr, len); + if (!p) + goto validate_failed; + + seq = kmalloc(4, GFP_KERNEL); + if (!seq) + goto validate_failed; + *seq = cpu_to_be32(task->tk_rqstp->rq_seqno); + iov.iov_base = seq; + iov.iov_len = 4; + xdr_buf_from_iov(&iov, &verf_buf); + mic.data = (u8 *)p; + mic.len = len; + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat) + goto bad_mic; + + /* We leave it to unwrap to calculate au_rslack. For now we just + * calculate the length of the verifier: */ + if (test_bit(RPCAUTH_AUTH_UPDATE_SLACK, &cred->cr_auth->au_flags)) + cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; + status = 0; +out: + gss_put_ctx(ctx); + kfree(seq); + return status; + +validate_failed: + status = -EIO; + goto out; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + status = -EACCES; + goto out; +} + +static noinline_for_stack int +gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_rqst *rqstp = task->tk_rqstp; + struct xdr_buf integ_buf, *snd_buf = &rqstp->rq_snd_buf; + struct xdr_netobj mic; + __be32 *p, *integ_len; + u32 offset, maj_stat; + + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; + integ_len = p++; + *p = cpu_to_be32(rqstp->rq_seqno); + + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; + + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + if (xdr_buf_subsegment(snd_buf, &integ_buf, + offset, snd_buf->len - offset)) + goto wrap_failed; + *integ_len = cpu_to_be32(integ_buf.len); + + p = xdr_reserve_space(xdr, 0); + if (!p) + goto wrap_failed; + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + goto bad_mic; + /* Check that the trailing MIC fit in the buffer, after the fact */ + if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0) + goto wrap_failed; + return 0; +wrap_failed: + return -EMSGSIZE; +bad_mic: + trace_rpcgss_get_mic(task, maj_stat); + return -EIO; +} + +static void +priv_release_snd_buf(struct rpc_rqst *rqstp) +{ + int i; + + for (i=0; i < rqstp->rq_enc_pages_num; i++) + 
__free_page(rqstp->rq_enc_pages[i]); + kfree(rqstp->rq_enc_pages); + rqstp->rq_release_snd_buf = NULL; +} + +static int +alloc_enc_pages(struct rpc_rqst *rqstp) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + int first, last, i; + + if (rqstp->rq_release_snd_buf) + rqstp->rq_release_snd_buf(rqstp); + + if (snd_buf->page_len == 0) { + rqstp->rq_enc_pages_num = 0; + return 0; + } + + first = snd_buf->page_base >> PAGE_SHIFT; + last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_SHIFT; + rqstp->rq_enc_pages_num = last - first + 1 + 1; + rqstp->rq_enc_pages + = kmalloc_array(rqstp->rq_enc_pages_num, + sizeof(struct page *), + GFP_KERNEL); + if (!rqstp->rq_enc_pages) + goto out; + for (i=0; i < rqstp->rq_enc_pages_num; i++) { + rqstp->rq_enc_pages[i] = alloc_page(GFP_KERNEL); + if (rqstp->rq_enc_pages[i] == NULL) + goto out_free; + } + rqstp->rq_release_snd_buf = priv_release_snd_buf; + return 0; +out_free: + rqstp->rq_enc_pages_num = i; + priv_release_snd_buf(rqstp); +out: + return -EAGAIN; +} + +static noinline_for_stack int +gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_rqst *rqstp = task->tk_rqstp; + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + u32 pad, offset, maj_stat; + int status; + __be32 *p, *opaque_len; + struct page **inpages; + int first; + struct kvec *iov; + + status = -EIO; + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto wrap_failed; + opaque_len = p++; + *p = cpu_to_be32(rqstp->rq_seqno); + + if (rpcauth_wrap_req_encode(task, xdr)) + goto wrap_failed; + + status = alloc_enc_pages(rqstp); + if (unlikely(status)) + goto wrap_failed; + first = snd_buf->page_base >> PAGE_SHIFT; + inpages = snd_buf->pages + first; + snd_buf->pages = rqstp->rq_enc_pages; + snd_buf->page_base -= first << PAGE_SHIFT; + /* + * Move the tail into its own page, in case gss_wrap needs + * more space in the head when wrapping. + * + * Still... Why can't gss_wrap just slide the tail down? 
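+ *
+ * The tail is copied into the last of the rq_enc_pages allocated by
+ * alloc_enc_pages() and the iovec repointed there, so nothing gss_wrap
+ * writes past the head can clobber the original tail data.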
+ */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) { + char *tmp; + + tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); + memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); + snd_buf->tail[0].iov_base = tmp; + } + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); + /* slack space should prevent this ever happening: */ + if (unlikely(snd_buf->len > snd_buf->buflen)) + goto wrap_failed; + /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was + * done anyway, so it's safe to put the request on the wire: */ + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + goto bad_wrap; + + *opaque_len = cpu_to_be32(snd_buf->len - offset); + /* guess whether the pad goes into the head or the tail: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) + iov = snd_buf->tail; + else + iov = snd_buf->head; + p = iov->iov_base + iov->iov_len; + pad = xdr_pad_size(snd_buf->len - offset); + memset(p, 0, pad); + iov->iov_len += pad; + snd_buf->len += pad; + + return 0; +wrap_failed: + return status; +bad_wrap: + trace_rpcgss_wrap(task, maj_stat); + return -EIO; +} + +static int gss_wrap_req(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + int status; + + status = -EIO; + if (ctx->gc_proc != RPC_GSS_PROC_DATA) { + /* The spec seems a little ambiguous here, but I think that not + * wrapping context destruction requests makes the most sense. + */ + status = rpcauth_wrap_req_encode(task, xdr); + goto out; + } + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + status = rpcauth_wrap_req_encode(task, xdr); + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_wrap_req_integ(cred, ctx, task, xdr); + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_wrap_req_priv(cred, ctx, task, xdr); + break; + default: + status = -EIO; + } +out: + gss_put_ctx(ctx); + return status; +} + +/** + * gss_update_rslack - Possibly update RPC receive buffer size estimates + * @task: rpc_task for incoming RPC Reply being unwrapped + * @cred: controlling rpc_cred for @task + * @before: XDR words needed before each RPC Reply message + * @after: XDR words needed following each RPC Reply message + * + */ +static void gss_update_rslack(struct rpc_task *task, struct rpc_cred *cred, + unsigned int before, unsigned int after) +{ + struct rpc_auth *auth = cred->cr_auth; + + if (test_and_clear_bit(RPCAUTH_AUTH_UPDATE_SLACK, &auth->au_flags)) { + auth->au_ralign = auth->au_verfsize + before; + auth->au_rslack = auth->au_verfsize + after; + trace_rpcgss_update_slack(task, auth); + } +} + +static int +gss_unwrap_resp_auth(struct rpc_task *task, struct rpc_cred *cred) +{ + gss_update_rslack(task, cred, 0, 0); + return 0; +} + +/* + * RFC 2203, Section 5.3.2.2 + * + * struct rpc_gss_integ_data { + * opaque databody_integ<>; + * opaque checksum<>; + * }; + * + * struct rpc_gss_data_t { + * unsigned int seq_num; + * proc_req_arg_t arg; + * }; + */ +static noinline_for_stack int +gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) +{ + struct xdr_buf gss_data, *rcv_buf = &rqstp->rq_rcv_buf; + u32 len, offset, seqno, maj_stat; + struct xdr_netobj mic; + int ret; + + ret = -EIO; + mic.data = NULL; 
+ + /* opaque databody_integ<>; */ + if (xdr_stream_decode_u32(xdr, &len)) + goto unwrap_failed; + if (len & 3) + goto unwrap_failed; + offset = rcv_buf->len - xdr_stream_remaining(xdr); + if (xdr_stream_decode_u32(xdr, &seqno)) + goto unwrap_failed; + if (seqno != rqstp->rq_seqno) + goto bad_seqno; + if (xdr_buf_subsegment(rcv_buf, &gss_data, offset, len)) + goto unwrap_failed; + + /* + * The xdr_stream now points to the beginning of the + * upper layer payload, to be passed below to + * rpcauth_unwrap_resp_decode(). The checksum, which + * follows the upper layer payload in @rcv_buf, is + * located and parsed without updating the xdr_stream. + */ + + /* opaque checksum<>; */ + offset += len; + if (xdr_decode_word(rcv_buf, offset, &len)) + goto unwrap_failed; + offset += sizeof(__be32); + if (offset + len > rcv_buf->len) + goto unwrap_failed; + mic.len = len; + mic.data = kmalloc(len, GFP_KERNEL); + if (ZERO_OR_NULL_PTR(mic.data)) + goto unwrap_failed; + if (read_bytes_from_xdr_buf(rcv_buf, offset, mic.data, mic.len)) + goto unwrap_failed; + + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &gss_data, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + goto bad_mic; + + gss_update_rslack(task, cred, 2, 2 + 1 + XDR_QUADLEN(mic.len)); + ret = 0; + +out: + kfree(mic.data); + return ret; + +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + goto out; +bad_seqno: + trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, seqno); + goto out; +bad_mic: + trace_rpcgss_verify_mic(task, maj_stat); + goto out; +} + +static noinline_for_stack int +gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred, + struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp, + struct xdr_stream *xdr) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + struct kvec *head = rqstp->rq_rcv_buf.head; + u32 offset, opaque_len, maj_stat; + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (unlikely(!p)) + goto unwrap_failed; + opaque_len = be32_to_cpup(p++); + offset = (u8 *)(p) - (u8 *)head->iov_base; + if (offset + opaque_len > rcv_buf->len) + goto unwrap_failed; + + maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, + offset + opaque_len, rcv_buf); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + goto bad_unwrap; + /* gss_unwrap decrypted the sequence number */ + if (be32_to_cpup(p++) != rqstp->rq_seqno) + goto bad_seqno; + + /* gss_unwrap redacts the opaque blob from the head iovec. + * rcv_buf has changed, thus the stream needs to be reset. 
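+ * The receive slack estimates are also refreshed below from the
+ * mechanism's align and slack values.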
+ */ + xdr_init_decode(xdr, rcv_buf, p, rqstp); + + gss_update_rslack(task, cred, 2 + ctx->gc_gss_ctx->align, + 2 + ctx->gc_gss_ctx->slack); + + return 0; +unwrap_failed: + trace_rpcgss_unwrap_failed(task); + return -EIO; +bad_seqno: + trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(--p)); + return -EIO; +bad_unwrap: + trace_rpcgss_unwrap(task, maj_stat); + return -EIO; +} + +static bool +gss_seq_is_newer(u32 new, u32 old) +{ + return (s32)(new - old) > 0; +} + +static bool +gss_xmit_need_reencode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_cred *cred = req->rq_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + u32 win, seq_xmit = 0; + bool ret = true; + + if (!ctx) + goto out; + + if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq))) + goto out_ctx; + + seq_xmit = READ_ONCE(ctx->gc_seq_xmit); + while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) { + u32 tmp = seq_xmit; + + seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno); + if (seq_xmit == tmp) { + ret = false; + goto out_ctx; + } + } + + win = ctx->gc_win; + if (win > 0) + ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win); + +out_ctx: + gss_put_ctx(ctx); +out: + trace_rpcgss_need_reencode(task, seq_xmit, ret); + return ret; +} + +static int +gss_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_rqst *rqstp = task->tk_rqstp; + struct rpc_cred *cred = rqstp->rq_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + int status = -EIO; + + if (ctx->gc_proc != RPC_GSS_PROC_DATA) + goto out_decode; + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + status = gss_unwrap_resp_auth(task, cred); + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_unwrap_resp_integ(task, cred, ctx, rqstp, xdr); + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_unwrap_resp_priv(task, cred, ctx, rqstp, xdr); + break; + } + if (status) + goto out; + +out_decode: + status = rpcauth_unwrap_resp_decode(task, xdr); +out: + gss_put_ctx(ctx); + return status; +} + +static const struct rpc_authops authgss_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_GSS, + .au_name = "RPCSEC_GSS", + .create = gss_create, + .destroy = gss_destroy, + .hash_cred = gss_hash_cred, + .lookup_cred = gss_lookup_cred, + .crcreate = gss_create_cred, + .info2flavor = gss_mech_info2flavor, + .flavor2info = gss_mech_flavor2info, +}; + +static const struct rpc_credops gss_credops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_cred, + .cr_init = gss_cred_init, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crkey_timeout = gss_key_timeout, + .crstringify_acceptor = gss_stringify_acceptor, + .crneed_reencode = gss_xmit_need_reencode, +}; + +static const struct rpc_credops gss_nullops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_nullcred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh_null, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, + .crstringify_acceptor = gss_stringify_acceptor, +}; + +static const struct rpc_pipe_ops gss_upcall_ops_v0 = { + .upcall = gss_v0_upcall, + .downcall = gss_pipe_downcall, + .destroy_msg = gss_pipe_destroy_msg, + .open_pipe = gss_pipe_open_v0, + .release_pipe = gss_pipe_release, +}; + +static const struct rpc_pipe_ops gss_upcall_ops_v1 = { + 
.upcall = gss_v1_upcall, + .downcall = gss_pipe_downcall, + .destroy_msg = gss_pipe_destroy_msg, + .open_pipe = gss_pipe_open_v1, + .release_pipe = gss_pipe_release, +}; + +static __net_init int rpcsec_gss_init_net(struct net *net) +{ + return gss_svc_init_net(net); +} + +static __net_exit void rpcsec_gss_exit_net(struct net *net) +{ + gss_svc_shutdown_net(net); +} + +static struct pernet_operations rpcsec_gss_net_ops = { + .init = rpcsec_gss_init_net, + .exit = rpcsec_gss_exit_net, +}; + +/* + * Initialize RPCSEC_GSS module + */ +static int __init init_rpcsec_gss(void) +{ + int err = 0; + + err = rpcauth_register(&authgss_ops); + if (err) + goto out; + err = gss_svc_init(); + if (err) + goto out_unregister; + err = register_pernet_subsys(&rpcsec_gss_net_ops); + if (err) + goto out_svc_exit; + rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version"); + return 0; +out_svc_exit: + gss_svc_shutdown(); +out_unregister: + rpcauth_unregister(&authgss_ops); +out: + return err; +} + +static void __exit exit_rpcsec_gss(void) +{ + unregister_pernet_subsys(&rpcsec_gss_net_ops); + gss_svc_shutdown(); + rpcauth_unregister(&authgss_ops); + rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} + +MODULE_ALIAS("rpc-auth-6"); +MODULE_LICENSE("GPL"); +module_param_named(expired_cred_retry_delay, + gss_expired_cred_retry_delay, + uint, 0644); +MODULE_PARM_DESC(expired_cred_retry_delay, "Timeout (in seconds) until " + "the RPC engine retries an expired credential"); + +module_param_named(key_expire_timeo, + gss_key_expire_timeo, + uint, 0644); +MODULE_PARM_DESC(key_expire_timeo, "Time (in seconds) at the end of a " + "credential keys lifetime where the NFS layer cleans up " + "prior to key expiration"); + +module_init(init_rpcsec_gss) +module_exit(exit_rpcsec_gss) diff --git a/net/sunrpc/auth_gss/auth_gss_internal.h b/net/sunrpc/auth_gss/auth_gss_internal.h new file mode 100644 index 000000000..c53b32909 --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss_internal.h @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/auth_gss/auth_gss_internal.h + * + * Internal definitions for RPCSEC_GSS client authentication + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + */ +#include +#include +#include + +static inline const void * +simple_get_bytes(const void *p, const void *end, void *res, size_t len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static inline const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) +{ + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + if (len) { + dest->data = kmemdup(p, len, GFP_KERNEL); + if (unlikely(dest->data == NULL)) + return ERR_PTR(-ENOMEM); + } else + dest->data = NULL; + dest->len = len; + return q; +} diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c new file mode 100644 index 000000000..4a4082bb2 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_generic_token.c @@ -0,0 +1,231 @@ +/* + * linux/net/sunrpc/gss_generic_token.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. 
+ * + * Andy Adamson + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include +#include +#include +#include +#include + + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* TWRITE_STR from gssapiP_generic.h */ +#define TWRITE_STR(ptr, str, len) \ + memcpy((ptr), (char *) (str), (len)); \ + (ptr) += (len); + +/* XXXX this code currently makes the assumption that a mech oid will + never be longer than 127 bytes. This assumption is not inherent in + the interfaces, so the code can be fixed if the OSI namespace + balloons unexpectedly. */ + +/* Each token looks like this: + +0x60 tag for APPLICATION 0, SEQUENCE + (constructed, definite-length) + possible multiple bytes, need to parse/generate + 0x06 tag for OBJECT IDENTIFIER + compile-time constant string (assume 1 byte) + compile-time constant string + the ANY containing the application token + bytes 0,1 are the token type + bytes 2,n are the token data + +For the purposes of this abstraction, the token "header" consists of +the sequence tag and length octets, the mech OID DER encoding, and the +first two inner bytes, which indicate the token type. The token +"body" consists of everything else. + +*/ + +static int +der_length_size( int length) +{ + if (length < (1<<7)) + return 1; + else if (length < (1<<8)) + return 2; +#if (SIZEOF_INT == 2) + else + return 3; +#else + else if (length < (1<<16)) + return 3; + else if (length < (1<<24)) + return 4; + else + return 5; +#endif +} + +static void +der_write_length(unsigned char **buf, int length) +{ + if (length < (1<<7)) { + *(*buf)++ = (unsigned char) length; + } else { + *(*buf)++ = (unsigned char) (der_length_size(length)+127); +#if (SIZEOF_INT > 2) + if (length >= (1<<24)) + *(*buf)++ = (unsigned char) (length>>24); + if (length >= (1<<16)) + *(*buf)++ = (unsigned char) ((length>>16)&0xff); +#endif + if (length >= (1<<8)) + *(*buf)++ = (unsigned char) ((length>>8)&0xff); + *(*buf)++ = (unsigned char) (length&0xff); + } +} + +/* returns decoded length, or < 0 on failure. 
Advances buf and + decrements bufsize */ + +static int +der_read_length(unsigned char **buf, int *bufsize) +{ + unsigned char sf; + int ret; + + if (*bufsize < 1) + return -1; + sf = *(*buf)++; + (*bufsize)--; + if (sf & 0x80) { + if ((sf &= 0x7f) > ((*bufsize)-1)) + return -1; + if (sf > SIZEOF_INT) + return -1; + ret = 0; + for (; sf; sf--) { + ret = (ret<<8) + (*(*buf)++); + (*bufsize)--; + } + } else { + ret = sf; + } + + return ret; +} + +/* returns the length of a token, given the mech oid and the body size */ + +int +g_token_size(struct xdr_netobj *mech, unsigned int body_size) +{ + /* set body_size to sequence contents size */ + body_size += 2 + (int) mech->len; /* NEED overflow check */ + return 1 + der_length_size(body_size) + body_size; +} + +EXPORT_SYMBOL_GPL(g_token_size); + +/* fills in a buffer with the token header. The buffer is assumed to + be the right size. buf is advanced past the token header */ + +void +g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf) +{ + *(*buf)++ = 0x60; + der_write_length(buf, 2 + mech->len + body_size); + *(*buf)++ = 0x06; + *(*buf)++ = (unsigned char) mech->len; + TWRITE_STR(*buf, mech->data, ((int) mech->len)); +} + +EXPORT_SYMBOL_GPL(g_make_token_header); + +/* + * Given a buffer containing a token, reads and verifies the token, + * leaving buf advanced past the token header, and setting body_size + * to the number of remaining bytes. Returns 0 on success, + * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the + * mechanism in the token does not match the mech argument. buf and + * *body_size are left unmodified on error. + */ +u32 +g_verify_token_header(struct xdr_netobj *mech, int *body_size, + unsigned char **buf_in, int toksize) +{ + unsigned char *buf = *buf_in; + int seqsize; + struct xdr_netobj toid; + int ret = 0; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + if (*buf++ != 0x60) + return G_BAD_TOK_HEADER; + + if ((seqsize = der_read_length(&buf, &toksize)) < 0) + return G_BAD_TOK_HEADER; + + if (seqsize != toksize) + return G_BAD_TOK_HEADER; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + if (*buf++ != 0x06) + return G_BAD_TOK_HEADER; + + if ((toksize-=1) < 0) + return G_BAD_TOK_HEADER; + toid.len = *buf++; + + if ((toksize-=toid.len) < 0) + return G_BAD_TOK_HEADER; + toid.data = buf; + buf+=toid.len; + + if (! g_OID_equal(&toid, mech)) + ret = G_WRONG_MECH; + + /* G_WRONG_MECH is not returned immediately because it's more important + to return G_BAD_TOK_HEADER if the token header is in fact bad */ + + if ((toksize-=2) < 0) + return G_BAD_TOK_HEADER; + + if (ret) + return ret; + + *buf_in = buf; + *body_size = toksize; + + return ret; +} + +EXPORT_SYMBOL_GPL(g_verify_token_header); diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c new file mode 100644 index 000000000..3ea58175e --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -0,0 +1,810 @@ +/* + * linux/net/sunrpc/gss_krb5_crypto.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson + * Bruce Fields + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. 
+ * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +u32 +krb5_encrypt( + struct crypto_sync_skcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); + + if (length % crypto_sync_skcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", + crypto_sync_skcipher_ivsize(tfm)); + goto out; + } + + if (iv) + memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + skcipher_request_set_sync_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, length, local_iv); + + ret = crypto_skcipher_encrypt(req); + skcipher_request_zero(req); +out: + dprintk("RPC: krb5_encrypt returns %d\n", ret); + return ret; +} + +u32 +krb5_decrypt( + struct crypto_sync_skcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[GSS_KRB5_MAX_BLOCKSIZE] = {0}; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); + + if (length % crypto_sync_skcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_sync_skcipher_ivsize(tfm) > GSS_KRB5_MAX_BLOCKSIZE) { + dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", + crypto_sync_skcipher_ivsize(tfm)); + goto out; + } + if (iv) + memcpy(local_iv, iv, crypto_sync_skcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + skcipher_request_set_sync_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, length, local_iv); + + ret = crypto_skcipher_decrypt(req); + skcipher_request_zero(req); +out: + dprintk("RPC: gss_k5decrypt returns %d\n",ret); + return ret; +} + +static int +checksummer(struct scatterlist *sg, void *data) +{ + struct ahash_request *req = data; + + ahash_request_set_crypt(req, sg, NULL, sg->length); + + return crypto_ahash_update(req); +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * The checksum is performed over the first 8 bytes of the + * gss token header and then over the data body + */ +u32 +make_checksum(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct 
crypto_ahash *tfm; + struct ahash_request *req; + struct scatterlist sg[1]; + int err = -1; + u8 *checksumdata; + unsigned int checksumlen; + + if (cksumout->len < kctx->gk5e->cksumlength) { + dprintk("%s: checksum buffer length, %u, too small for %s\n", + __func__, cksumout->len, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_KERNEL); + if (checksumdata == NULL) + return GSS_S_FAILURE; + + tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + goto out_free_cksum; + + req = ahash_request_alloc(tfm, GFP_KERNEL); + if (!req) + goto out_free_ahash; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + checksumlen = crypto_ahash_digestsize(tfm); + + if (cksumkey != NULL) { + err = crypto_ahash_setkey(tfm, cksumkey, + kctx->gk5e->keylength); + if (err) + goto out; + } + + err = crypto_ahash_init(req); + if (err) + goto out; + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out; + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); + if (err) + goto out; + + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_RSA_MD5: + err = kctx->gk5e->encrypt(kctx->seq, NULL, checksumdata, + checksumdata, checksumlen); + if (err) + goto out; + memcpy(cksumout->data, + checksumdata + checksumlen - kctx->gk5e->cksumlength, + kctx->gk5e->cksumlength); + break; + case CKSUMTYPE_HMAC_SHA1_DES3: + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } + cksumout->len = kctx->gk5e->cksumlength; +out: + ahash_request_free(req); +out_free_ahash: + crypto_free_ahash(tfm); +out_free_cksum: + kfree(checksumdata); + return err ? GSS_S_FAILURE : 0; +} + +/* + * checksum the plaintext data and hdrlen bytes of the token header + * Per rfc4121, sec. 4.2.4, the checksum is performed over the data + * body then over the first 16 octets of the MIC token + * Inclusion of the header data in the calculation of the + * checksum is optional. 
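+ * That ordering is why this function hashes the body before the
+ * (optional) header, the reverse of make_checksum() above, and why a
+ * NULL header is accepted here.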
+ */ +u32 +make_checksum_v2(struct krb5_ctx *kctx, char *header, int hdrlen, + struct xdr_buf *body, int body_offset, u8 *cksumkey, + unsigned int usage, struct xdr_netobj *cksumout) +{ + struct crypto_ahash *tfm; + struct ahash_request *req; + struct scatterlist sg[1]; + int err = -1; + u8 *checksumdata; + + if (kctx->gk5e->keyed_cksum == 0) { + dprintk("%s: expected keyed hash for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + if (cksumkey == NULL) { + dprintk("%s: no key supplied for %s\n", + __func__, kctx->gk5e->name); + return GSS_S_FAILURE; + } + + checksumdata = kmalloc(GSS_KRB5_MAX_CKSUM_LEN, GFP_KERNEL); + if (!checksumdata) + return GSS_S_FAILURE; + + tfm = crypto_alloc_ahash(kctx->gk5e->cksum_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + goto out_free_cksum; + + req = ahash_request_alloc(tfm, GFP_KERNEL); + if (!req) + goto out_free_ahash; + + ahash_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP, NULL, NULL); + + err = crypto_ahash_setkey(tfm, cksumkey, kctx->gk5e->keylength); + if (err) + goto out; + + err = crypto_ahash_init(req); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, req); + if (err) + goto out; + if (header != NULL) { + sg_init_one(sg, header, hdrlen); + ahash_request_set_crypt(req, sg, NULL, hdrlen); + err = crypto_ahash_update(req); + if (err) + goto out; + } + ahash_request_set_crypt(req, NULL, checksumdata, 0); + err = crypto_ahash_final(req); + if (err) + goto out; + + cksumout->len = kctx->gk5e->cksumlength; + + switch (kctx->gk5e->ctype) { + case CKSUMTYPE_HMAC_SHA1_96_AES128: + case CKSUMTYPE_HMAC_SHA1_96_AES256: + /* note that this truncates the hash */ + memcpy(cksumout->data, checksumdata, kctx->gk5e->cksumlength); + break; + default: + BUG(); + break; + } +out: + ahash_request_free(req); +out_free_ahash: + crypto_free_ahash(tfm); +out_free_cksum: + kfree(checksumdata); + return err ? GSS_S_FAILURE : 0; +} + +struct encryptor_desc { + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; + struct skcipher_request *req; + int pos; + struct xdr_buf *outbuf; + struct page **pages; + struct scatterlist infrags[4]; + struct scatterlist outfrags[4]; + int fragno; + int fraglen; +}; + +static int +encryptor(struct scatterlist *sg, void *data) +{ + struct encryptor_desc *desc = data; + struct xdr_buf *outbuf = desc->outbuf; + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); + struct page *in_page; + int thislen = desc->fraglen + sg->length; + int fraglen, ret; + int page_pos; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. 
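+ * Any partial cipher block left at the end of the accumulated data is
+ * carried over in desc->fraglen and fed back in on the next call, so
+ * the skcipher is only ever asked to process whole blocks.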
*/ + BUG_ON(desc->fragno > 3); + + page_pos = desc->pos - outbuf->head[0].iov_len; + if (page_pos >= 0 && page_pos < outbuf->page_len) { + /* pages are not in place: */ + int i = (page_pos + outbuf->page_base) >> PAGE_SHIFT; + in_page = desc->pages[i]; + } else { + in_page = sg_page(sg); + } + sg_set_page(&desc->infrags[desc->fragno], in_page, sg->length, + sg->offset); + sg_set_page(&desc->outfrags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + desc->pos += sg->length; + + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->infrags[desc->fragno - 1]); + sg_mark_end(&desc->outfrags[desc->fragno - 1]); + + skcipher_request_set_crypt(desc->req, desc->infrags, desc->outfrags, + thislen, desc->iv); + + ret = crypto_skcipher_encrypt(desc->req); + if (ret) + return ret; + + sg_init_table(desc->infrags, 4); + sg_init_table(desc->outfrags, 4); + + if (fraglen) { + sg_set_page(&desc->outfrags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->infrags[0] = desc->outfrags[0]; + sg_assign_page(&desc->infrags[0], in_page); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_encrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, + int offset, struct page **pages) +{ + int ret; + struct encryptor_desc desc; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); + + BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); + + skcipher_request_set_sync_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.req = req; + desc.pos = offset; + desc.outbuf = buf; + desc.pages = pages; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc); + skcipher_request_zero(req); + return ret; +} + +struct decryptor_desc { + u8 iv[GSS_KRB5_MAX_BLOCKSIZE]; + struct skcipher_request *req; + struct scatterlist frags[4]; + int fragno; + int fraglen; +}; + +static int +decryptor(struct scatterlist *sg, void *data) +{ + struct decryptor_desc *desc = data; + int thislen = desc->fraglen + sg->length; + struct crypto_sync_skcipher *tfm = + crypto_sync_skcipher_reqtfm(desc->req); + int fraglen, ret; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. 
*/ + BUG_ON(desc->fragno > 3); + sg_set_page(&desc->frags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + + fraglen = thislen & (crypto_sync_skcipher_blocksize(tfm) - 1); + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->frags[desc->fragno - 1]); + + skcipher_request_set_crypt(desc->req, desc->frags, desc->frags, + thislen, desc->iv); + + ret = crypto_skcipher_decrypt(desc->req); + if (ret) + return ret; + + sg_init_table(desc->frags, 4); + + if (fraglen) { + sg_set_page(&desc->frags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_decrypt_xdr_buf(struct crypto_sync_skcipher *tfm, struct xdr_buf *buf, + int offset) +{ + int ret; + struct decryptor_desc desc; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, tfm); + + /* XXXJBF: */ + BUG_ON((buf->len - offset) % crypto_sync_skcipher_blocksize(tfm) != 0); + + skcipher_request_set_sync_tfm(req, tfm); + skcipher_request_set_callback(req, 0, NULL, NULL); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.req = req; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.frags, 4); + + ret = xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); + skcipher_request_zero(req); + return ret; +} + +/* + * This function makes the assumption that it was ultimately called + * from gss_wrap(). + * + * The client auth_gss code moves any existing tail data into a + * separate page before calling gss_wrap. + * The server svcauth_gss code ensures that both the head and the + * tail have slack space of RPC_MAX_AUTH_SIZE before calling gss_wrap. + * + * Even with that guarantee, this function may be called more than + * once in the processing of gss_wrap(). The best we can do is + * verify at compile-time (see GSS_KRB5_SLACK_CHECK) that the + * largest expected shift will fit within RPC_MAX_AUTH_SIZE. + * At run-time we can verify that a single invocation of this + * function doesn't attempt to use more the RPC_MAX_AUTH_SIZE. + */ + +int +xdr_extend_head(struct xdr_buf *buf, unsigned int base, unsigned int shiftlen) +{ + u8 *p; + + if (shiftlen == 0) + return 0; + + BUILD_BUG_ON(GSS_KRB5_MAX_SLACK_NEEDED > RPC_MAX_AUTH_SIZE); + BUG_ON(shiftlen > RPC_MAX_AUTH_SIZE); + + p = buf->head[0].iov_base + base; + + memmove(p + shiftlen, p, buf->head[0].iov_len - base); + + buf->head[0].iov_len += shiftlen; + buf->len += shiftlen; + + return 0; +} + +static u32 +gss_krb5_cts_crypt(struct crypto_sync_skcipher *cipher, struct xdr_buf *buf, + u32 offset, u8 *iv, struct page **pages, int encrypt) +{ + u32 ret; + struct scatterlist sg[1]; + SYNC_SKCIPHER_REQUEST_ON_STACK(req, cipher); + u8 *data; + struct page **save_pages; + u32 len = buf->len - offset; + + if (len > GSS_KRB5_MAX_BLOCKSIZE * 2) { + WARN_ON(0); + return -ENOMEM; + } + data = kmalloc(GSS_KRB5_MAX_BLOCKSIZE * 2, GFP_KERNEL); + if (!data) + return -ENOMEM; + + /* + * For encryption, we want to read from the cleartext + * page cache pages, and write the encrypted data to + * the supplied xdr_buf pages. 
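+ *
+ * Note that this helper only ever handles the last one or two cipher
+ * blocks of the message (anything longer than two blocks is rejected
+ * above).  For larger payloads the callers first push the leading
+ * whole blocks through the plain CBC auxiliary cipher and then hand
+ * the remainder, together with the running IV, to this CTS step.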
+ */ + save_pages = buf->pages; + if (encrypt) + buf->pages = pages; + + ret = read_bytes_from_xdr_buf(buf, offset, data, len); + buf->pages = save_pages; + if (ret) + goto out; + + sg_init_one(sg, data, len); + + skcipher_request_set_sync_tfm(req, cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + skcipher_request_set_crypt(req, sg, sg, len, iv); + + if (encrypt) + ret = crypto_skcipher_encrypt(req); + else + ret = crypto_skcipher_decrypt(req); + + skcipher_request_zero(req); + + if (ret) + goto out; + + ret = write_bytes_to_xdr_buf(buf, offset, data, len); + +out: + kfree(data); + return ret; +} + +u32 +gss_krb5_aes_encrypt(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + u32 err; + struct xdr_netobj hmac; + u8 *cksumkey; + u8 *ecptr; + struct crypto_sync_skcipher *cipher, *aux_cipher; + int blocksize; + struct page **save_pages; + int nblocks, nbytes; + struct encryptor_desc desc; + u32 cbcbytes; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksumkey = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } else { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksumkey = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } + blocksize = crypto_sync_skcipher_blocksize(cipher); + + /* hide the gss token header and insert the confounder */ + offset += GSS_KRB5_TOK_HDR_LEN; + if (xdr_extend_head(buf, offset, kctx->gk5e->conflen)) + return GSS_S_FAILURE; + gss_krb5_make_confounder(buf->head[0].iov_base + offset, kctx->gk5e->conflen); + offset -= GSS_KRB5_TOK_HDR_LEN; + + if (buf->tail[0].iov_base != NULL) { + ecptr = buf->tail[0].iov_base + buf->tail[0].iov_len; + } else { + buf->tail[0].iov_base = buf->head[0].iov_base + + buf->head[0].iov_len; + buf->tail[0].iov_len = 0; + ecptr = buf->tail[0].iov_base; + } + + /* copy plaintext gss token header after filler (if any) */ + memcpy(ecptr, buf->head[0].iov_base + offset, GSS_KRB5_TOK_HDR_LEN); + buf->tail[0].iov_len += GSS_KRB5_TOK_HDR_LEN; + buf->len += GSS_KRB5_TOK_HDR_LEN; + + /* Do the HMAC */ + hmac.len = GSS_KRB5_MAX_CKSUM_LEN; + hmac.data = buf->tail[0].iov_base + buf->tail[0].iov_len; + + /* + * When we are called, pages points to the real page cache + * data -- which we can't go and encrypt! buf->pages points + * to scratch pages which we are going to send off to the + * client/server. Swap in the plaintext pages to calculate + * the hmac. + */ + save_pages = buf->pages; + buf->pages = pages; + + err = make_checksum_v2(kctx, NULL, 0, buf, + offset + GSS_KRB5_TOK_HDR_LEN, + cksumkey, usage, &hmac); + buf->pages = save_pages; + if (err) + return GSS_S_FAILURE; + + nbytes = buf->len - offset - GSS_KRB5_TOK_HDR_LEN; + nblocks = (nbytes + blocksize - 1) / blocksize; + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + SYNC_SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + + desc.pos = offset + GSS_KRB5_TOK_HDR_LEN; + desc.fragno = 0; + desc.fraglen = 0; + desc.pages = pages; + desc.outbuf = buf; + desc.req = req; + + skcipher_request_set_sync_tfm(req, aux_cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + err = xdr_process_buf(buf, offset + GSS_KRB5_TOK_HDR_LEN, + cbcbytes, encryptor, &desc); + skcipher_request_zero(req); + if (err) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. 
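+ * When the auxiliary CBC pass above ran, it left its final ciphertext
+ * block in desc.iv, so passing desc.iv here chains the two passes into
+ * one CBC-with-ciphertext-stealing stream; if there was no CBC pass,
+ * desc.iv is still the all-zero initial vector.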
*/ + err = gss_krb5_cts_crypt(cipher, buf, + offset + GSS_KRB5_TOK_HDR_LEN + cbcbytes, + desc.iv, pages, 1); + if (err) { + err = GSS_S_FAILURE; + goto out_err; + } + + /* Now update buf to account for HMAC */ + buf->tail[0].iov_len += kctx->gk5e->cksumlength; + buf->len += kctx->gk5e->cksumlength; + +out_err: + if (err) + err = GSS_S_FAILURE; + return err; +} + +u32 +gss_krb5_aes_decrypt(struct krb5_ctx *kctx, u32 offset, u32 len, + struct xdr_buf *buf, u32 *headskip, u32 *tailskip) +{ + struct xdr_buf subbuf; + u32 ret = 0; + u8 *cksum_key; + struct crypto_sync_skcipher *cipher, *aux_cipher; + struct xdr_netobj our_hmac_obj; + u8 our_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + u8 pkt_hmac[GSS_KRB5_MAX_CKSUM_LEN]; + int nblocks, blocksize, cbcbytes; + struct decryptor_desc desc; + unsigned int usage; + + if (kctx->initiate) { + cipher = kctx->acceptor_enc; + aux_cipher = kctx->acceptor_enc_aux; + cksum_key = kctx->acceptor_integ; + usage = KG_USAGE_ACCEPTOR_SEAL; + } else { + cipher = kctx->initiator_enc; + aux_cipher = kctx->initiator_enc_aux; + cksum_key = kctx->initiator_integ; + usage = KG_USAGE_INITIATOR_SEAL; + } + blocksize = crypto_sync_skcipher_blocksize(cipher); + + + /* create a segment skipping the header and leaving out the checksum */ + xdr_buf_subsegment(buf, &subbuf, offset + GSS_KRB5_TOK_HDR_LEN, + (len - offset - GSS_KRB5_TOK_HDR_LEN - + kctx->gk5e->cksumlength)); + + nblocks = (subbuf.len + blocksize - 1) / blocksize; + + cbcbytes = 0; + if (nblocks > 2) + cbcbytes = (nblocks - 2) * blocksize; + + memset(desc.iv, 0, sizeof(desc.iv)); + + if (cbcbytes) { + SYNC_SKCIPHER_REQUEST_ON_STACK(req, aux_cipher); + + desc.fragno = 0; + desc.fraglen = 0; + desc.req = req; + + skcipher_request_set_sync_tfm(req, aux_cipher); + skcipher_request_set_callback(req, 0, NULL, NULL); + + sg_init_table(desc.frags, 4); + + ret = xdr_process_buf(&subbuf, 0, cbcbytes, decryptor, &desc); + skcipher_request_zero(req); + if (ret) + goto out_err; + } + + /* Make sure IV carries forward from any CBC results. */ + ret = gss_krb5_cts_crypt(cipher, &subbuf, cbcbytes, desc.iv, NULL, 0); + if (ret) + goto out_err; + + + /* Calculate our hmac over the plaintext data */ + our_hmac_obj.len = sizeof(our_hmac); + our_hmac_obj.data = our_hmac; + + ret = make_checksum_v2(kctx, NULL, 0, &subbuf, 0, + cksum_key, usage, &our_hmac_obj); + if (ret) + goto out_err; + + /* Get the packet's hmac value */ + ret = read_bytes_from_xdr_buf(buf, len - kctx->gk5e->cksumlength, + pkt_hmac, kctx->gk5e->cksumlength); + if (ret) + goto out_err; + + if (crypto_memneq(pkt_hmac, our_hmac, kctx->gk5e->cksumlength) != 0) { + ret = GSS_S_BAD_SIG; + goto out_err; + } + *headskip = kctx->gk5e->conflen; + *tailskip = kctx->gk5e->cksumlength; +out_err: + if (ret && ret != GSS_S_BAD_SIG) + ret = GSS_S_FAILURE; + return ret; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_keys.c b/net/sunrpc/auth_gss/gss_krb5_keys.c new file mode 100644 index 000000000..726c07695 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_keys.c @@ -0,0 +1,322 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. 
If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. + * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* + * This is the n-fold function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. 
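+ *
+ * In this file it is only reached from krb5_derive_key(), which folds
+ * the 5-octet (GSS_KRB5_K5CLENGTH) usage constant up to the cipher
+ * block size (8 octets for des3, 16 for the AES enctypes); for the
+ * 8-octet case that means lcm(5, 8) = 40 rotated input bytes are
+ * accumulated into the output with end-around carry.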
+ */ + +static void krb5_nfold(u32 inbits, const u8 *in, + u32 outbits, u8 *out) +{ + unsigned long ulcm; + int byte, i, msbit; + + /* the code below is more readable if I make these bytes + instead of bits */ + + inbits >>= 3; + outbits >>= 3; + + /* first compute lcm(n,k) */ + ulcm = lcm(inbits, outbits); + + /* now do the real work */ + + memset(out, 0, outbits); + byte = 0; + + /* this will end up cycling through k lcm(k,n)/k times, which + is correct */ + for (i = ulcm-1; i >= 0; i--) { + /* compute the msbit in k which gets added into this byte */ + msbit = ( + /* first, start with the msbit in the first, + * unrotated byte */ + ((inbits << 3) - 1) + /* then, for each byte, shift to the right + * for each repetition */ + + (((inbits << 3) + 13) * (i/inbits)) + /* last, pick out the correct byte within + * that shifted repetition */ + + ((inbits - (i % inbits)) << 3) + ) % (inbits << 3); + + /* pull out the byte value itself */ + byte += (((in[((inbits - 1) - (msbit >> 3)) % inbits] << 8)| + (in[((inbits) - (msbit >> 3)) % inbits])) + >> ((msbit & 7) + 1)) & 0xff; + + /* do the addition */ + byte += out[i % outbits]; + out[i % outbits] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + + } + + /* if there's a carry bit left over, add it back in */ + if (byte) { + for (i = outbits - 1; i >= 0; i--) { + /* do the addition */ + byte += out[i]; + out[i] = byte & 0xff; + + /* keep around the carry bit, if any */ + byte >>= 8; + } + } +} + +/* + * This is the DK (derive_key) function as described in rfc3961, sec 5.1 + * Taken from MIT Kerberos and modified. + */ + +u32 krb5_derive_key(const struct gss_krb5_enctype *gk5e, + const struct xdr_netobj *inkey, + struct xdr_netobj *outkey, + const struct xdr_netobj *in_constant, + gfp_t gfp_mask) +{ + size_t blocksize, keybytes, keylength, n; + unsigned char *inblockdata, *outblockdata, *rawkey; + struct xdr_netobj inblock, outblock; + struct crypto_sync_skcipher *cipher; + u32 ret = EINVAL; + + blocksize = gk5e->blocksize; + keybytes = gk5e->keybytes; + keylength = gk5e->keylength; + + if ((inkey->len != keylength) || (outkey->len != keylength)) + goto err_return; + + cipher = crypto_alloc_sync_skcipher(gk5e->encrypt_name, 0, 0); + if (IS_ERR(cipher)) + goto err_return; + if (crypto_sync_skcipher_setkey(cipher, inkey->data, inkey->len)) + goto err_return; + + /* allocate and set up buffers */ + + ret = ENOMEM; + inblockdata = kmalloc(blocksize, gfp_mask); + if (inblockdata == NULL) + goto err_free_cipher; + + outblockdata = kmalloc(blocksize, gfp_mask); + if (outblockdata == NULL) + goto err_free_in; + + rawkey = kmalloc(keybytes, gfp_mask); + if (rawkey == NULL) + goto err_free_out; + + inblock.data = (char *) inblockdata; + inblock.len = blocksize; + + outblock.data = (char *) outblockdata; + outblock.len = blocksize; + + /* initialize the input block */ + + if (in_constant->len == inblock.len) { + memcpy(inblock.data, in_constant->data, inblock.len); + } else { + krb5_nfold(in_constant->len * 8, in_constant->data, + inblock.len * 8, inblock.data); + } + + /* loop encrypting the blocks until enough key bytes are generated */ + + n = 0; + while (n < keybytes) { + (*(gk5e->encrypt))(cipher, NULL, inblock.data, + outblock.data, inblock.len); + + if ((keybytes - n) <= outblock.len) { + memcpy(rawkey + n, outblock.data, (keybytes - n)); + break; + } + + memcpy(rawkey + n, outblock.data, outblock.len); + memcpy(inblock.data, outblock.data, outblock.len); + n += outblock.len; + } + + /* postprocess the key */ + + inblock.data = 
(char *) rawkey; + inblock.len = keybytes; + + BUG_ON(gk5e->mk_key == NULL); + ret = (*(gk5e->mk_key))(gk5e, &inblock, outkey); + if (ret) { + dprintk("%s: got %d from mk_key function for '%s'\n", + __func__, ret, gk5e->encrypt_name); + goto err_free_raw; + } + + /* clean memory, free resources and exit */ + + ret = 0; + +err_free_raw: + kfree_sensitive(rawkey); +err_free_out: + kfree_sensitive(outblockdata); +err_free_in: + kfree_sensitive(inblockdata); +err_free_cipher: + crypto_free_sync_skcipher(cipher); +err_return: + return ret; +} + +#define smask(step) ((1<<step)-1) +#define pstep(x, step) (((x)&smask(step))^(((x)>>step)&smask(step))) +#define parity_char(x) pstep(pstep(pstep((x), 4), 2), 1) + +static void mit_des_fixup_key_parity(u8 key[8]) +{ + int i; + for (i = 0; i < 8; i++) { + key[i] &= 0xfe; + key[i] |= 1^parity_char(key[i]); + } +} + +/* + * This is the des3 key derivation postprocess function + */ +u32 gss_krb5_des3_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + int i; + u32 ret = EINVAL; + + if (key->len != 24) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 21) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + + /* take the seven bytes, move them around into the top 7 bits of the + 8 key bytes, then compute the parity bits. Do this three times. */ + + for (i = 0; i < 3; i++) { + memcpy(key->data + i*8, randombits->data + i*7, 7); + key->data[i*8+7] = (((key->data[i*8]&1)<<1) | + ((key->data[i*8+1]&1)<<2) | + ((key->data[i*8+2]&1)<<3) | + ((key->data[i*8+3]&1)<<4) | + ((key->data[i*8+4]&1)<<5) | + ((key->data[i*8+5]&1)<<6) | + ((key->data[i*8+6]&1)<<7)); + + mit_des_fixup_key_parity(key->data + i*8); + } + ret = 0; +err_out: + return ret; +} + +/* + * This is the aes key derivation postprocess function + */ +u32 gss_krb5_aes_make_key(const struct gss_krb5_enctype *gk5e, + struct xdr_netobj *randombits, + struct xdr_netobj *key) +{ + u32 ret = EINVAL; + + if (key->len != 16 && key->len != 32) { + dprintk("%s: key->len is %d\n", __func__, key->len); + goto err_out; + } + if (randombits->len != 16 && randombits->len != 32) { + dprintk("%s: randombits->len is %d\n", + __func__, randombits->len); + goto err_out; + } + if (randombits->len != key->len) { + dprintk("%s: randombits->len is %d, key->len is %d\n", + __func__, randombits->len, key->len); + goto err_out; + } + memcpy(key->data, randombits->data, key->len); + ret = 0; +err_out: + return ret; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c new file mode 100644 index 000000000..1c092b05c --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/gss_krb5_mech.c + * + * Copyright (c) 2001-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson + * J.
Bruce Fields + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "auth_gss_internal.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct gss_api_mech gss_kerberos_mech; /* forward declaration */ + +static const struct gss_krb5_enctype supported_gss_krb5_enctypes[] = { +#ifndef CONFIG_SUNRPC_DISABLE_INSECURE_ENCTYPES + /* + * DES (All DES enctypes are mapped to the same gss functionality) + */ + { + .etype = ENCTYPE_DES_CBC_RAW, + .ctype = CKSUMTYPE_RSA_MD5, + .name = "des-cbc-crc", + .encrypt_name = "cbc(des)", + .cksum_name = "md5", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = NULL, + .signalg = SGN_ALG_DES_MAC_MD5, + .sealalg = SEAL_ALG_DES, + .keybytes = 7, + .keylength = 8, + .blocksize = 8, + .conflen = 8, + .cksumlength = 8, + .keyed_cksum = 0, + }, +#endif /* CONFIG_SUNRPC_DISABLE_INSECURE_ENCTYPES */ + /* + * 3DES + */ + { + .etype = ENCTYPE_DES3_CBC_RAW, + .ctype = CKSUMTYPE_HMAC_SHA1_DES3, + .name = "des3-hmac-sha1", + .encrypt_name = "cbc(des3_ede)", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_des3_make_key, + .signalg = SGN_ALG_HMAC_SHA1_DES3_KD, + .sealalg = SEAL_ALG_DES3KD, + .keybytes = 21, + .keylength = 24, + .blocksize = 8, + .conflen = 8, + .cksumlength = 20, + .keyed_cksum = 1, + }, + /* + * AES128 + */ + { + .etype = ENCTYPE_AES128_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES128, + .name = "aes128-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 16, + .keylength = 16, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, + /* + * AES256 + */ + { + .etype = ENCTYPE_AES256_CTS_HMAC_SHA1_96, + .ctype = CKSUMTYPE_HMAC_SHA1_96_AES256, + .name = "aes256-cts", + .encrypt_name = "cts(cbc(aes))", + .cksum_name = "hmac(sha1)", + .encrypt = krb5_encrypt, + .decrypt = krb5_decrypt, + .mk_key = gss_krb5_aes_make_key, + .encrypt_v2 = gss_krb5_aes_encrypt, + .decrypt_v2 = gss_krb5_aes_decrypt, + .signalg = -1, + .sealalg = -1, + .keybytes = 32, + .keylength = 32, + .blocksize = 16, + .conflen = 16, + .cksumlength = 12, + .keyed_cksum = 1, + }, +}; + +static const int num_supported_enctypes = + ARRAY_SIZE(supported_gss_krb5_enctypes); + +static int +supported_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return 1; + return 0; +} + +static const struct gss_krb5_enctype * +get_gss_krb5_enctype(int etype) +{ + int i; + for (i = 0; i < num_supported_enctypes; i++) + if (supported_gss_krb5_enctypes[i].etype == etype) + return &supported_gss_krb5_enctypes[i]; + return NULL; +} + +static inline const void * +get_key(const void *p, const void *end, + struct krb5_ctx *ctx, struct crypto_sync_skcipher **res) +{ + struct xdr_netobj key; + int alg; + + p = simple_get_bytes(p, end, &alg, sizeof(alg)); + if (IS_ERR(p)) + goto out_err; + + switch (alg) { + case ENCTYPE_DES_CBC_CRC: + case ENCTYPE_DES_CBC_MD4: + case ENCTYPE_DES_CBC_MD5: + /* Map all these key types to ENCTYPE_DES_CBC_RAW */ + alg = ENCTYPE_DES_CBC_RAW; + break; + } + + if (!supported_gss_krb5_enctype(alg)) { + printk(KERN_WARNING "gss_kerberos_mech: unsupported " 
+ "encryption key algorithm %d\n", alg); + p = ERR_PTR(-EINVAL); + goto out_err; + } + p = simple_get_netobj(p, end, &key); + if (IS_ERR(p)) + goto out_err; + + *res = crypto_alloc_sync_skcipher(ctx->gk5e->encrypt_name, 0, 0); + if (IS_ERR(*res)) { + printk(KERN_WARNING "gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); + *res = NULL; + goto out_err_free_key; + } + if (crypto_sync_skcipher_setkey(*res, key.data, key.len)) { + printk(KERN_WARNING "gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", ctx->gk5e->encrypt_name); + goto out_err_free_tfm; + } + + kfree(key.data); + return p; + +out_err_free_tfm: + crypto_free_sync_skcipher(*res); +out_err_free_key: + kfree(key.data); + p = ERR_PTR(-EINVAL); +out_err: + return p; +} + +static int +gss_import_v1_context(const void *p, const void *end, struct krb5_ctx *ctx) +{ + u32 seq_send; + int tmp; + u32 time32; + + p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); + if (IS_ERR(p)) + goto out_err; + + /* Old format supports only DES! Any other enctype uses new format */ + ctx->enctype = ENCTYPE_DES_CBC_RAW; + + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + p = ERR_PTR(-EINVAL); + goto out_err; + } + + /* The downcall format was designed before we completely understood + * the uses of the context fields; so it includes some stuff we + * just give some minimal sanity-checking, and some we ignore + * completely (like the next twenty bytes): */ + if (unlikely(p + 20 > end || p + 20 < p)) { + p = ERR_PTR(-EFAULT); + goto out_err; + } + p += 20; + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err; + if (tmp != SGN_ALG_DES_MAC_MD5) { + p = ERR_PTR(-ENOSYS); + goto out_err; + } + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err; + if (tmp != SEAL_ALG_DES) { + p = ERR_PTR(-ENOSYS); + goto out_err; + } + p = simple_get_bytes(p, end, &time32, sizeof(time32)); + if (IS_ERR(p)) + goto out_err; + /* unsigned 32-bit time overflows in year 2106 */ + ctx->endtime = (time64_t)time32; + p = simple_get_bytes(p, end, &seq_send, sizeof(seq_send)); + if (IS_ERR(p)) + goto out_err; + atomic_set(&ctx->seq_send, seq_send); + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) + goto out_err; + p = get_key(p, end, ctx, &ctx->enc); + if (IS_ERR(p)) + goto out_err_free_mech; + p = get_key(p, end, ctx, &ctx->seq); + if (IS_ERR(p)) + goto out_err_free_key1; + if (p != end) { + p = ERR_PTR(-EFAULT); + goto out_err_free_key2; + } + + return 0; + +out_err_free_key2: + crypto_free_sync_skcipher(ctx->seq); +out_err_free_key1: + crypto_free_sync_skcipher(ctx->enc); +out_err_free_mech: + kfree(ctx->mech_used.data); +out_err: + return PTR_ERR(p); +} + +static struct crypto_sync_skcipher * +context_v2_alloc_cipher(struct krb5_ctx *ctx, const char *cname, u8 *key) +{ + struct crypto_sync_skcipher *cp; + + cp = crypto_alloc_sync_skcipher(cname, 0, 0); + if (IS_ERR(cp)) { + dprintk("gss_kerberos_mech: unable to initialize " + "crypto algorithm %s\n", cname); + return NULL; + } + if (crypto_sync_skcipher_setkey(cp, key, ctx->gk5e->keylength)) { + dprintk("gss_kerberos_mech: error setting key for " + "crypto algorithm %s\n", cname); + crypto_free_sync_skcipher(cp); + return NULL; + } + return cp; +} + +static inline void +set_cdata(u8 cdata[GSS_KRB5_K5CLENGTH], u32 usage, u8 seed) +{ + cdata[0] = (usage>>24)&0xff; + cdata[1] = (usage>>16)&0xff; + cdata[2] = (usage>>8)&0xff; + cdata[3] = usage&0xff; + 
cdata[4] = seed; +} + +static int +context_derive_keys_des3(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* seq uses the raw key */ + ctx->seq = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->seq == NULL) + goto out_err; + + ctx->enc = context_v2_alloc_cipher(ctx, ctx->gk5e->encrypt_name, + ctx->Ksess); + if (ctx->enc == NULL) + goto out_free_seq; + + /* derive cksum */ + set_cdata(cdata, KG_USAGE_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->cksum; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving cksum key\n", + __func__, err); + goto out_free_enc; + } + + return 0; + +out_free_enc: + crypto_free_sync_skcipher(ctx->enc); +out_free_seq: + crypto_free_sync_skcipher(ctx->seq); +out_err: + return -EINVAL; +} + +static int +context_derive_keys_new(struct krb5_ctx *ctx, gfp_t gfp_mask) +{ + struct xdr_netobj c, keyin, keyout; + u8 cdata[GSS_KRB5_K5CLENGTH]; + u32 err; + + c.len = GSS_KRB5_K5CLENGTH; + c.data = cdata; + + keyin.data = ctx->Ksess; + keyin.len = ctx->gk5e->keylength; + keyout.len = ctx->gk5e->keylength; + + /* initiator seal encryption */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->initiator_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_seal key\n", + __func__, err); + goto out_err; + } + ctx->initiator_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->initiator_seal); + if (ctx->initiator_enc == NULL) + goto out_err; + + /* acceptor seal encryption */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_ENCRYPTION); + keyout.data = ctx->acceptor_seal; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_seal key\n", + __func__, err); + goto out_free_initiator_enc; + } + ctx->acceptor_enc = context_v2_alloc_cipher(ctx, + ctx->gk5e->encrypt_name, + ctx->acceptor_seal); + if (ctx->acceptor_enc == NULL) + goto out_free_initiator_enc; + + /* initiator sign checksum */ + set_cdata(cdata, KG_USAGE_INITIATOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->initiator_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor sign checksum */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SIGN, KEY_USAGE_SEED_CHECKSUM); + keyout.data = ctx->acceptor_sign; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving acceptor_sign key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* initiator seal integrity */ + set_cdata(cdata, KG_USAGE_INITIATOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->initiator_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: Error %d deriving initiator_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + /* acceptor seal integrity */ + set_cdata(cdata, KG_USAGE_ACCEPTOR_SEAL, KEY_USAGE_SEED_INTEGRITY); + keyout.data = ctx->acceptor_integ; + err = krb5_derive_key(ctx->gk5e, &keyin, &keyout, &c, gfp_mask); + if (err) { + dprintk("%s: 
Error %d deriving acceptor_integ key\n", + __func__, err); + goto out_free_acceptor_enc; + } + + switch (ctx->enctype) { + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + ctx->initiator_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->initiator_seal); + if (ctx->initiator_enc_aux == NULL) + goto out_free_acceptor_enc; + ctx->acceptor_enc_aux = + context_v2_alloc_cipher(ctx, "cbc(aes)", + ctx->acceptor_seal); + if (ctx->acceptor_enc_aux == NULL) { + crypto_free_sync_skcipher(ctx->initiator_enc_aux); + goto out_free_acceptor_enc; + } + } + + return 0; + +out_free_acceptor_enc: + crypto_free_sync_skcipher(ctx->acceptor_enc); +out_free_initiator_enc: + crypto_free_sync_skcipher(ctx->initiator_enc); +out_err: + return -EINVAL; +} + +static int +gss_import_v2_context(const void *p, const void *end, struct krb5_ctx *ctx, + gfp_t gfp_mask) +{ + u64 seq_send64; + int keylen; + u32 time32; + + p = simple_get_bytes(p, end, &ctx->flags, sizeof(ctx->flags)); + if (IS_ERR(p)) + goto out_err; + ctx->initiate = ctx->flags & KRB5_CTX_FLAG_INITIATOR; + + p = simple_get_bytes(p, end, &time32, sizeof(time32)); + if (IS_ERR(p)) + goto out_err; + /* unsigned 32-bit time overflows in year 2106 */ + ctx->endtime = (time64_t)time32; + p = simple_get_bytes(p, end, &seq_send64, sizeof(seq_send64)); + if (IS_ERR(p)) + goto out_err; + atomic64_set(&ctx->seq_send64, seq_send64); + /* set seq_send for use by "older" enctypes */ + atomic_set(&ctx->seq_send, seq_send64); + if (seq_send64 != atomic_read(&ctx->seq_send)) { + dprintk("%s: seq_send64 %llx, seq_send %x overflow?\n", __func__, + seq_send64, atomic_read(&ctx->seq_send)); + p = ERR_PTR(-EINVAL); + goto out_err; + } + p = simple_get_bytes(p, end, &ctx->enctype, sizeof(ctx->enctype)); + if (IS_ERR(p)) + goto out_err; + /* Map ENCTYPE_DES3_CBC_SHA1 to ENCTYPE_DES3_CBC_RAW */ + if (ctx->enctype == ENCTYPE_DES3_CBC_SHA1) + ctx->enctype = ENCTYPE_DES3_CBC_RAW; + ctx->gk5e = get_gss_krb5_enctype(ctx->enctype); + if (ctx->gk5e == NULL) { + dprintk("gss_kerberos_mech: unsupported krb5 enctype %u\n", + ctx->enctype); + p = ERR_PTR(-EINVAL); + goto out_err; + } + keylen = ctx->gk5e->keylength; + + p = simple_get_bytes(p, end, ctx->Ksess, keylen); + if (IS_ERR(p)) + goto out_err; + + if (p != end) { + p = ERR_PTR(-EINVAL); + goto out_err; + } + + ctx->mech_used.data = kmemdup(gss_kerberos_mech.gm_oid.data, + gss_kerberos_mech.gm_oid.len, gfp_mask); + if (unlikely(ctx->mech_used.data == NULL)) { + p = ERR_PTR(-ENOMEM); + goto out_err; + } + ctx->mech_used.len = gss_kerberos_mech.gm_oid.len; + + switch (ctx->enctype) { + case ENCTYPE_DES3_CBC_RAW: + return context_derive_keys_des3(ctx, gfp_mask); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return context_derive_keys_new(ctx, gfp_mask); + default: + return -EINVAL; + } + +out_err: + return PTR_ERR(p); +} + +static int +gss_import_sec_context_kerberos(const void *p, size_t len, + struct gss_ctx *ctx_id, + time64_t *endtime, + gfp_t gfp_mask) +{ + const void *end = (const void *)((const char *)p + len); + struct krb5_ctx *ctx; + int ret; + + ctx = kzalloc(sizeof(*ctx), gfp_mask); + if (ctx == NULL) + return -ENOMEM; + + if (len == 85) + ret = gss_import_v1_context(p, end, ctx); + else + ret = gss_import_v2_context(p, end, ctx, gfp_mask); + + if (ret == 0) { + ctx_id->internal_ctx_id = ctx; + if (endtime) + *endtime = ctx->endtime; + } else + kfree(ctx); + + dprintk("RPC: %s: returning %d\n", __func__, ret); + return ret; +} + +static void 
+gss_delete_sec_context_kerberos(void *internal_ctx) { + struct krb5_ctx *kctx = internal_ctx; + + crypto_free_sync_skcipher(kctx->seq); + crypto_free_sync_skcipher(kctx->enc); + crypto_free_sync_skcipher(kctx->acceptor_enc); + crypto_free_sync_skcipher(kctx->initiator_enc); + crypto_free_sync_skcipher(kctx->acceptor_enc_aux); + crypto_free_sync_skcipher(kctx->initiator_enc_aux); + kfree(kctx->mech_used.data); + kfree(kctx); +} + +static const struct gss_api_ops gss_kerberos_ops = { + .gss_import_sec_context = gss_import_sec_context_kerberos, + .gss_get_mic = gss_get_mic_kerberos, + .gss_verify_mic = gss_verify_mic_kerberos, + .gss_wrap = gss_wrap_kerberos, + .gss_unwrap = gss_unwrap_kerberos, + .gss_delete_sec_context = gss_delete_sec_context_kerberos, +}; + +static struct pf_desc gss_kerberos_pfs[] = { + [0] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_NONE, + .name = "krb5", + }, + [1] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5I, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_INTEGRITY, + .name = "krb5i", + .datatouch = true, + }, + [2] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5P, + .qop = GSS_C_QOP_DEFAULT, + .service = RPC_GSS_SVC_PRIVACY, + .name = "krb5p", + .datatouch = true, + }, +}; + +MODULE_ALIAS("rpc-auth-gss-krb5"); +MODULE_ALIAS("rpc-auth-gss-krb5i"); +MODULE_ALIAS("rpc-auth-gss-krb5p"); +MODULE_ALIAS("rpc-auth-gss-390003"); +MODULE_ALIAS("rpc-auth-gss-390004"); +MODULE_ALIAS("rpc-auth-gss-390005"); +MODULE_ALIAS("rpc-auth-gss-1.2.840.113554.1.2.2"); + +static struct gss_api_mech gss_kerberos_mech = { + .gm_name = "krb5", + .gm_owner = THIS_MODULE, + .gm_oid = { 9, "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02" }, + .gm_ops = &gss_kerberos_ops, + .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs), + .gm_pfs = gss_kerberos_pfs, + .gm_upcall_enctypes = KRB5_SUPPORTED_ENCTYPES, +}; + +static int __init init_kerberos_module(void) +{ + int status; + + status = gss_mech_register(&gss_kerberos_mech); + if (status) + printk("Failed to register kerberos gss mechanism!\n"); + return status; +} + +static void __exit cleanup_kerberos_module(void) +{ + gss_mech_unregister(&gss_kerberos_mech); +} + +MODULE_LICENSE("GPL"); +module_init(init_kerberos_module); +module_exit(cleanup_kerberos_module); diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c new file mode 100644 index 000000000..33061417e --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -0,0 +1,222 @@ +/* + * linux/net/sunrpc/gss_krb5_seal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5seal.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson + * J. Bruce Fields + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. 
+ * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static void * +setup_token(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + u16 *ptr; + void *krb5_hdr; + int body_size = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + + token->len = g_token_size(&ctx->mech_used, body_size); + + ptr = (u16 *)token->data; + g_make_token_header(&ctx->mech_used, body_size, (unsigned char **)&ptr); + + /* ptr now at start of header described in rfc 1964, section 1.2.1: */ + krb5_hdr = ptr; + *ptr++ = KG_TOK_MIC_MSG; + /* + * signalg is stored as if it were converted from LE to host endian, even + * though it's an opaque pair of bytes according to the RFC. 
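+ * In practice this writes the signalg least-significant octet first at
+ * offsets 2-3 of the header, followed by SEAL_ALG_NONE at offsets 4-5
+ * and a 0xffff filler at offsets 6-7, which is the layout that
+ * gss_verify_mic_v1() checks on the receive side.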
+ */ + *ptr++ = (__force u16)cpu_to_le16(ctx->gk5e->signalg); + *ptr++ = SEAL_ALG_NONE; + *ptr = 0xffff; + + return krb5_hdr; +} + +static void * +setup_token_v2(struct krb5_ctx *ctx, struct xdr_netobj *token) +{ + u16 *ptr; + void *krb5_hdr; + u8 *p, flags = 0x00; + + if ((ctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= 0x01; + if (ctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) + flags |= 0x04; + + /* Per rfc 4121, sec 4.2.6.1, there is no header, + * just start the token */ + krb5_hdr = ptr = (u16 *)token->data; + + *ptr++ = KG2_TOK_MIC; + p = (u8 *)ptr; + *p++ = flags; + *p++ = 0xff; + ptr = (u16 *)p; + *ptr++ = 0xffff; + *ptr = 0xffff; + + token->len = GSS_KRB5_TOK_HDR_LEN + ctx->gk5e->cksumlength; + return krb5_hdr; +} + +static u32 +gss_get_mic_v1(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + void *ptr; + time64_t now; + u32 seq_send; + u8 *cksumkey; + + dprintk("RPC: %s\n", __func__); + BUG_ON(ctx == NULL); + + now = ktime_get_real_seconds(); + + ptr = setup_token(ctx, token); + + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(ctx, ptr, 8, text, 0, cksumkey, + KG_USAGE_SIGN, &md5cksum)) + return GSS_S_FAILURE; + + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); + + seq_send = atomic_fetch_inc(&ctx->seq_send); + + if (krb5_make_seq_num(ctx, ctx->seq, ctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8)) + return GSS_S_FAILURE; + + return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_get_mic_v2(struct krb5_ctx *ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = { .len = sizeof(cksumdata), + .data = cksumdata}; + void *krb5_hdr; + time64_t now; + u8 *cksumkey; + unsigned int cksum_usage; + __be64 seq_send_be64; + + dprintk("RPC: %s\n", __func__); + + krb5_hdr = setup_token_v2(ctx, token); + + /* Set up the sequence number. Now 64-bits in clear + * text and w/o direction indicator */ + seq_send_be64 = cpu_to_be64(atomic64_fetch_inc(&ctx->seq_send64)); + memcpy(krb5_hdr + 8, (char *) &seq_send_be64, 8); + + if (ctx->initiate) { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } else { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } + + if (make_checksum_v2(ctx, krb5_hdr, GSS_KRB5_TOK_HDR_LEN, + text, 0, cksumkey, cksum_usage, &cksumobj)) + return GSS_S_FAILURE; + + memcpy(krb5_hdr + GSS_KRB5_TOK_HDR_LEN, cksumobj.data, cksumobj.len); + + now = ktime_get_real_seconds(); + + return (ctx->endtime < now) ? 
GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +u32 +gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + return gss_get_mic_v1(ctx, text, token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_get_mic_v2(ctx, text, token); + } +} diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c new file mode 100644 index 000000000..3200b971a --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -0,0 +1,104 @@ +/* + * linux/net/sunrpc/gss_krb5_seqnum.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/util_seqnum.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. 
+ */ + +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +s32 +krb5_make_seq_num(struct krb5_ctx *kctx, + struct crypto_sync_skcipher *key, + int direction, + u32 seqnum, + unsigned char *cksum, unsigned char *buf) +{ + unsigned char *plain; + s32 code; + + plain = kmalloc(8, GFP_KERNEL); + if (!plain) + return -ENOMEM; + + plain[0] = (unsigned char) (seqnum & 0xff); + plain[1] = (unsigned char) ((seqnum >> 8) & 0xff); + plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); + plain[3] = (unsigned char) ((seqnum >> 24) & 0xff); + + plain[4] = direction; + plain[5] = direction; + plain[6] = direction; + plain[7] = direction; + + code = krb5_encrypt(key, cksum, plain, buf, 8); + kfree(plain); + return code; +} + +s32 +krb5_get_seq_num(struct krb5_ctx *kctx, + unsigned char *cksum, + unsigned char *buf, + int *direction, u32 *seqnum) +{ + s32 code; + unsigned char *plain; + struct crypto_sync_skcipher *key = kctx->seq; + + dprintk("RPC: krb5_get_seq_num:\n"); + + plain = kmalloc(8, GFP_KERNEL); + if (!plain) + return -ENOMEM; + + if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) + goto out; + + if ((plain[4] != plain[5]) || (plain[4] != plain[6]) || + (plain[4] != plain[7])) { + code = (s32)KG_BAD_SEQ; + goto out; + } + + *direction = plain[4]; + + *seqnum = ((plain[0]) | + (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); + +out: + kfree(plain); + return code; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c new file mode 100644 index 000000000..ba04e3ec9 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c @@ -0,0 +1,226 @@ +/* + * linux/net/sunrpc/gss_krb5_unseal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5unseal.c + * + * Copyright (c) 2000-2008 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. 
+ * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* read_token is a mic token, and message_buffer is the data that the mic was + * supposedly taken over. */ + +static u32 +gss_verify_mic_v1(struct krb5_ctx *ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + int signalg; + int sealalg; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + s32 now; + int direction; + u32 seqnum; + unsigned char *ptr = (unsigned char *)read_token->data; + int bodysize; + u8 *cksumkey; + + dprintk("RPC: krb5_read_token\n"); + + if (g_verify_token_header(&ctx->mech_used, &bodysize, &ptr, + read_token->len)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_MIC_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_MIC_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != ctx->gk5e->signalg) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != SEAL_ALG_NONE) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + if (ctx->gk5e->keyed_cksum) + cksumkey = ctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(ctx, ptr, 8, message_buffer, 0, + cksumkey, KG_USAGE_SIGN, &md5cksum)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. 
Make sure the context is unexpired */ + + now = ktime_get_real_seconds(); + + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + if (krb5_get_seq_num(ctx, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, + &direction, &seqnum)) + return GSS_S_FAILURE; + + if ((ctx->initiate && direction != 0xff) || + (!ctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + return GSS_S_COMPLETE; +} + +static u32 +gss_verify_mic_v2(struct krb5_ctx *ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj cksumobj = {.len = sizeof(cksumdata), + .data = cksumdata}; + time64_t now; + u8 *ptr = read_token->data; + u8 *cksumkey; + u8 flags; + int i; + unsigned int cksum_usage; + __be16 be16_ptr; + + dprintk("RPC: %s\n", __func__); + + memcpy(&be16_ptr, (char *) ptr, 2); + if (be16_to_cpu(be16_ptr) != KG2_TOK_MIC) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!ctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (ctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if (flags & KG2_TOKEN_FLAG_SEALED) { + dprintk("%s: token has unexpected sealed flag\n", __func__); + return GSS_S_FAILURE; + } + + for (i = 3; i < 8; i++) + if (ptr[i] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + if (ctx->initiate) { + cksumkey = ctx->acceptor_sign; + cksum_usage = KG_USAGE_ACCEPTOR_SIGN; + } else { + cksumkey = ctx->initiator_sign; + cksum_usage = KG_USAGE_INITIATOR_SIGN; + } + + if (make_checksum_v2(ctx, ptr, GSS_KRB5_TOK_HDR_LEN, message_buffer, 0, + cksumkey, cksum_usage, &cksumobj)) + return GSS_S_FAILURE; + + if (memcmp(cksumobj.data, ptr + GSS_KRB5_TOK_HDR_LEN, + ctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + now = ktime_get_real_seconds(); + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ + + return GSS_S_COMPLETE; +} + +u32 +gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, + struct xdr_buf *message_buffer, + struct xdr_netobj *read_token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + + switch (ctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + return gss_verify_mic_v1(ctx, message_buffer, read_token); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_verify_mic_v2(ctx, message_buffer, read_token); + } +} diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c new file mode 100644 index 000000000..483376878 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -0,0 +1,596 @@ +/* + * COPYRIGHT (c) 2008 + * The Regents of the University of Michigan + * ALL RIGHTS RESERVED + * + * Permission is granted to use, copy, create derivative works + * and redistribute this software and such derivative works + * for any purpose, so long as the name of The University of + * Michigan is not used in any advertising or publicity + * pertaining to the use of distribution of this software + * without specific, written prior authorization. If the + * above copyright notice or any other identification of the + * University of Michigan is included in any copy of any + * portion of this software, then the disclaimer below must + * also be included. 
+ * + * THIS SOFTWARE IS PROVIDED AS IS, WITHOUT REPRESENTATION + * FROM THE UNIVERSITY OF MICHIGAN AS TO ITS FITNESS FOR ANY + * PURPOSE, AND WITHOUT WARRANTY BY THE UNIVERSITY OF + * MICHIGAN OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING + * WITHOUT LIMITATION THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE + * REGENTS OF THE UNIVERSITY OF MICHIGAN SHALL NOT BE LIABLE + * FOR ANY DAMAGES, INCLUDING SPECIAL, INDIRECT, INCIDENTAL, OR + * CONSEQUENTIAL DAMAGES, WITH RESPECT TO ANY CLAIM ARISING + * OUT OF OR IN CONNECTION WITH THE USE OF THE SOFTWARE, EVEN + * IF IT HAS BEEN OR IS HEREAFTER ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGES. + */ + +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static inline int +gss_krb5_padding(int blocksize, int length) +{ + return blocksize - (length % blocksize); +} + +static inline void +gss_krb5_add_padding(struct xdr_buf *buf, int offset, int blocksize) +{ + int padding = gss_krb5_padding(blocksize, buf->len - offset); + char *p; + struct kvec *iov; + + if (buf->page_len || buf->tail[0].iov_len) + iov = &buf->tail[0]; + else + iov = &buf->head[0]; + p = iov->iov_base + iov->iov_len; + iov->iov_len += padding; + buf->len += padding; + memset(p, padding, padding); +} + +static inline int +gss_krb5_remove_padding(struct xdr_buf *buf, int blocksize) +{ + u8 *ptr; + u8 pad; + size_t len = buf->len; + + if (len <= buf->head[0].iov_len) { + pad = *(u8 *)(buf->head[0].iov_base + len - 1); + if (pad > buf->head[0].iov_len) + return -EINVAL; + buf->head[0].iov_len -= pad; + goto out; + } else + len -= buf->head[0].iov_len; + if (len <= buf->page_len) { + unsigned int last = (buf->page_base + len - 1) + >>PAGE_SHIFT; + unsigned int offset = (buf->page_base + len - 1) + & (PAGE_SIZE - 1); + ptr = kmap_atomic(buf->pages[last]); + pad = *(ptr + offset); + kunmap_atomic(ptr); + goto out; + } else + len -= buf->page_len; + BUG_ON(len > buf->tail[0].iov_len); + pad = *(u8 *)(buf->tail[0].iov_base + len - 1); +out: + /* XXX: NOTE: we do not adjust the page lengths--they represent + * a range of data in the real filesystem page cache, and we need + * to know that range so the xdr code can properly place read data. + * However adjusting the head length, as we do above, is harmless. + * In the case of a request that fits into a single page, the server + * also uses length and head length together to determine the original + * start of the request to copy the request for deferal; so it's + * easier on the server if we adjust head and tail length in tandem. + * It's not really a problem that we don't fool with the page and + * tail lengths, though--at worst badly formed xdr might lead the + * server to attempt to parse the padding. + * XXX: Document all these weird requirements for gss mechanism + * wrap/unwrap functions. */ + if (pad > blocksize) + return -EINVAL; + if (buf->len > pad) + buf->len -= pad; + else + return -EINVAL; + return 0; +} + +void +gss_krb5_make_confounder(char *p, u32 conflen) +{ + static u64 i = 0; + u64 *q = (u64 *)p; + + /* rfc1964 claims this should be "random". But all that's really + * necessary is that it be unique. And not even that is necessary in + * our case since our "gssapi" implementation exists only to support + * rpcsec_gss, so we know that the only buffers we will ever encrypt + * already begin with a unique sequence number. 
Just to hedge my bets + * I'll make a half-hearted attempt at something unique, but ensuring + * uniqueness would mean worrying about atomicity and rollover, and I + * don't care enough. */ + + /* initialize to random value */ + if (i == 0) { + i = get_random_u32(); + i = (i << 32) | get_random_u32(); + } + + switch (conflen) { + case 16: + *q++ = i++; + fallthrough; + case 8: + *q++ = i++; + break; + default: + BUG(); + } +} + +/* Assumptions: the head and tail of inbuf are ours to play with. + * The pages, however, may be real pages in the page cache and we replace + * them with scratch pages from **pages before writing to them. */ +/* XXX: obviously the above should be documentation of wrap interface, + * and shouldn't be in this kerberos-specific file. */ + +/* XXX factor out common code with seal/unseal. */ + +static u32 +gss_wrap_kerberos_v1(struct krb5_ctx *kctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + int blocksize = 0, plainlen; + unsigned char *ptr, *msg_start; + time64_t now; + int headlen; + struct page **tmp_pages; + u32 seq_send; + u8 *cksumkey; + u32 conflen = kctx->gk5e->conflen; + + dprintk("RPC: %s\n", __func__); + + now = ktime_get_real_seconds(); + + blocksize = crypto_sync_skcipher_blocksize(kctx->enc); + gss_krb5_add_padding(buf, offset, blocksize); + BUG_ON((buf->len - offset) % blocksize); + plainlen = conflen + buf->len - offset; + + headlen = g_token_size(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength + plainlen) - + (buf->len - offset); + + ptr = buf->head[0].iov_base + offset; + /* shift data to make room for header. */ + xdr_extend_head(buf, offset, headlen); + + /* XXX Would be cleverer to encrypt while copying. */ + BUG_ON((buf->len - offset - headlen) % blocksize); + + g_make_token_header(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + + kctx->gk5e->cksumlength + plainlen, &ptr); + + + /* ptr now at header described in rfc 1964, section 1.2.1: */ + ptr[0] = (unsigned char) ((KG_TOK_WRAP_MSG >> 8) & 0xff); + ptr[1] = (unsigned char) (KG_TOK_WRAP_MSG & 0xff); + + msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength; + + /* + * signalg and sealalg are stored as if they were converted from LE + * to host endian, even though they're opaque pairs of bytes according + * to the RFC. + */ + *(__le16 *)(ptr + 2) = cpu_to_le16(kctx->gk5e->signalg); + *(__le16 *)(ptr + 4) = cpu_to_le16(kctx->gk5e->sealalg); + ptr[6] = 0xff; + ptr[7] = 0xff; + + gss_krb5_make_confounder(msg_start, conflen); + + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; + + /* XXXJBF: UGH!: */ + tmp_pages = buf->pages; + buf->pages = pages; + if (make_checksum(kctx, ptr, 8, buf, offset + headlen - conflen, + cksumkey, KG_USAGE_SEAL, &md5cksum)) + return GSS_S_FAILURE; + buf->pages = tmp_pages; + + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data, md5cksum.len); + + seq_send = atomic_fetch_inc(&kctx->seq_send); + + /* XXX would probably be more efficient to compute checksum + * and encrypt at the same time: */ + if ((krb5_make_seq_num(kctx, kctx->seq, kctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8))) + return GSS_S_FAILURE; + + if (gss_encrypt_xdr_buf(kctx->enc, buf, + offset + headlen - conflen, pages)) + return GSS_S_FAILURE; + + return (kctx->endtime < now) ? 
GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_unwrap_kerberos_v1(struct krb5_ctx *kctx, int offset, int len, + struct xdr_buf *buf, unsigned int *slack, + unsigned int *align) +{ + int signalg; + int sealalg; + char cksumdata[GSS_KRB5_MAX_CKSUM_LEN]; + struct xdr_netobj md5cksum = {.len = sizeof(cksumdata), + .data = cksumdata}; + time64_t now; + int direction; + s32 seqnum; + unsigned char *ptr; + int bodysize; + void *data_start, *orig_start; + int data_len; + int blocksize; + u32 conflen = kctx->gk5e->conflen; + int crypt_offset; + u8 *cksumkey; + unsigned int saved_len = buf->len; + + dprintk("RPC: gss_unwrap_kerberos\n"); + + ptr = (u8 *)buf->head[0].iov_base + offset; + if (g_verify_token_header(&kctx->mech_used, &bodysize, &ptr, + len - offset)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_WRAP_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_WRAP_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + /* get the sign and seal algorithms */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != kctx->gk5e->signalg) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != kctx->gk5e->sealalg) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + /* + * Data starts after token header and checksum. ptr points + * to the beginning of the token header + */ + crypt_offset = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) - + (unsigned char *)buf->head[0].iov_base; + + buf->len = len; + if (gss_decrypt_xdr_buf(kctx->enc, buf, crypt_offset)) + return GSS_S_DEFECTIVE_TOKEN; + + if (kctx->gk5e->keyed_cksum) + cksumkey = kctx->cksum; + else + cksumkey = NULL; + + if (make_checksum(kctx, ptr, 8, buf, crypt_offset, + cksumkey, KG_USAGE_SEAL, &md5cksum)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data, ptr + GSS_KRB5_TOK_HDR_LEN, + kctx->gk5e->cksumlength)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + + now = ktime_get_real_seconds(); + + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + if (krb5_get_seq_num(kctx, ptr + GSS_KRB5_TOK_HDR_LEN, + ptr + 8, &direction, &seqnum)) + return GSS_S_BAD_SIG; + + if ((kctx->initiate && direction != 0xff) || + (!kctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + /* Copy the data back to the right position. XXX: Would probably be + * better to copy and encrypt at the same time. */ + + blocksize = crypto_sync_skcipher_blocksize(kctx->enc); + data_start = ptr + (GSS_KRB5_TOK_HDR_LEN + kctx->gk5e->cksumlength) + + conflen; + orig_start = buf->head[0].iov_base + offset; + data_len = (buf->head[0].iov_base + buf->head[0].iov_len) - data_start; + memmove(orig_start, data_start, data_len); + buf->head[0].iov_len -= (data_start - orig_start); + buf->len = len - (data_start - orig_start); + + if (gss_krb5_remove_padding(buf, blocksize)) + return GSS_S_DEFECTIVE_TOKEN; + + /* slack must include room for krb5 padding */ + *slack = XDR_QUADLEN(saved_len - buf->len); + /* The GSS blob always precedes the RPC message payload */ + *align = *slack; + return GSS_S_COMPLETE; +} + +/* + * We can shift data by up to LOCAL_BUF_LEN bytes in a pass. If we need + * to do more than that, we shift repeatedly. Kevin Coffman reports + * seeing 28 bytes as the value used by Microsoft clients and servers + * with AES, so this constant is chosen to allow handling 28 in one pass + * without using too much stack space. 
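+ *
+ * A small worked illustration (hypothetical shift values, derived from
+ * the loop in _rotate_left() below): with LOCAL_BUF_LEN = 32, a rotation
+ * of 28 bytes finishes in a single rotate_buf_a_little() pass, while a
+ * rotation of 50 bytes would take two passes of 32 and then 18 bytes.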
+ * + * If that proves to a problem perhaps we could use a more clever + * algorithm. + */ +#define LOCAL_BUF_LEN 32u + +static void rotate_buf_a_little(struct xdr_buf *buf, unsigned int shift) +{ + char head[LOCAL_BUF_LEN]; + char tmp[LOCAL_BUF_LEN]; + unsigned int this_len, i; + + BUG_ON(shift > LOCAL_BUF_LEN); + + read_bytes_from_xdr_buf(buf, 0, head, shift); + for (i = 0; i + shift < buf->len; i += LOCAL_BUF_LEN) { + this_len = min(LOCAL_BUF_LEN, buf->len - (i + shift)); + read_bytes_from_xdr_buf(buf, i+shift, tmp, this_len); + write_bytes_to_xdr_buf(buf, i, tmp, this_len); + } + write_bytes_to_xdr_buf(buf, buf->len - shift, head, shift); +} + +static void _rotate_left(struct xdr_buf *buf, unsigned int shift) +{ + int shifted = 0; + int this_shift; + + shift %= buf->len; + while (shifted < shift) { + this_shift = min(shift - shifted, LOCAL_BUF_LEN); + rotate_buf_a_little(buf, this_shift); + shifted += this_shift; + } +} + +static void rotate_left(u32 base, struct xdr_buf *buf, unsigned int shift) +{ + struct xdr_buf subbuf; + + xdr_buf_subsegment(buf, &subbuf, base, buf->len - base); + _rotate_left(&subbuf, shift); +} + +static u32 +gss_wrap_kerberos_v2(struct krb5_ctx *kctx, u32 offset, + struct xdr_buf *buf, struct page **pages) +{ + u8 *ptr; + time64_t now; + u8 flags = 0x00; + __be16 *be16ptr; + __be64 *be64ptr; + u32 err; + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->encrypt_v2 == NULL) + return GSS_S_FAILURE; + + /* make room for gss token header */ + if (xdr_extend_head(buf, offset, GSS_KRB5_TOK_HDR_LEN)) + return GSS_S_FAILURE; + + /* construct gss token header */ + ptr = buf->head[0].iov_base + offset; + *ptr++ = (unsigned char) ((KG2_TOK_WRAP>>8) & 0xff); + *ptr++ = (unsigned char) (KG2_TOK_WRAP & 0xff); + + if ((kctx->flags & KRB5_CTX_FLAG_INITIATOR) == 0) + flags |= KG2_TOKEN_FLAG_SENTBYACCEPTOR; + if ((kctx->flags & KRB5_CTX_FLAG_ACCEPTOR_SUBKEY) != 0) + flags |= KG2_TOKEN_FLAG_ACCEPTORSUBKEY; + /* We always do confidentiality in wrap tokens */ + flags |= KG2_TOKEN_FLAG_SEALED; + + *ptr++ = flags; + *ptr++ = 0xff; + be16ptr = (__be16 *)ptr; + + *be16ptr++ = 0; + /* "inner" token header always uses 0 for RRC */ + *be16ptr++ = 0; + + be64ptr = (__be64 *)be16ptr; + *be64ptr = cpu_to_be64(atomic64_fetch_inc(&kctx->seq_send64)); + + err = (*kctx->gk5e->encrypt_v2)(kctx, offset, buf, pages); + if (err) + return err; + + now = ktime_get_real_seconds(); + return (kctx->endtime < now) ? 
GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +static u32 +gss_unwrap_kerberos_v2(struct krb5_ctx *kctx, int offset, int len, + struct xdr_buf *buf, unsigned int *slack, + unsigned int *align) +{ + time64_t now; + u8 *ptr; + u8 flags = 0x00; + u16 ec, rrc; + int err; + u32 headskip, tailskip; + u8 decrypted_hdr[GSS_KRB5_TOK_HDR_LEN]; + unsigned int movelen; + + + dprintk("RPC: %s\n", __func__); + + if (kctx->gk5e->decrypt_v2 == NULL) + return GSS_S_FAILURE; + + ptr = buf->head[0].iov_base + offset; + + if (be16_to_cpu(*((__be16 *)ptr)) != KG2_TOK_WRAP) + return GSS_S_DEFECTIVE_TOKEN; + + flags = ptr[2]; + if ((!kctx->initiate && (flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR)) || + (kctx->initiate && !(flags & KG2_TOKEN_FLAG_SENTBYACCEPTOR))) + return GSS_S_BAD_SIG; + + if ((flags & KG2_TOKEN_FLAG_SEALED) == 0) { + dprintk("%s: token missing expected sealed flag\n", __func__); + return GSS_S_DEFECTIVE_TOKEN; + } + + if (ptr[3] != 0xff) + return GSS_S_DEFECTIVE_TOKEN; + + ec = be16_to_cpup((__be16 *)(ptr + 4)); + rrc = be16_to_cpup((__be16 *)(ptr + 6)); + + /* + * NOTE: the sequence number at ptr + 8 is skipped, rpcsec_gss + * doesn't want it checked; see page 6 of rfc 2203. + */ + + if (rrc != 0) + rotate_left(offset + 16, buf, rrc); + + err = (*kctx->gk5e->decrypt_v2)(kctx, offset, len, buf, + &headskip, &tailskip); + if (err) + return GSS_S_FAILURE; + + /* + * Retrieve the decrypted gss token header and verify + * it against the original + */ + err = read_bytes_from_xdr_buf(buf, + len - GSS_KRB5_TOK_HDR_LEN - tailskip, + decrypted_hdr, GSS_KRB5_TOK_HDR_LEN); + if (err) { + dprintk("%s: error %u getting decrypted_hdr\n", __func__, err); + return GSS_S_FAILURE; + } + if (memcmp(ptr, decrypted_hdr, 6) + || memcmp(ptr + 8, decrypted_hdr + 8, 8)) { + dprintk("%s: token hdr, plaintext hdr mismatch!\n", __func__); + return GSS_S_FAILURE; + } + + /* do sequencing checks */ + + /* it got through unscathed. Make sure the context is unexpired */ + now = ktime_get_real_seconds(); + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* + * Move the head data back to the right position in xdr_buf. + * We ignore any "ec" data since it might be in the head or + * the tail, and we really don't need to deal with it. + * Note that buf->head[0].iov_len may indicate the available + * head buffer space rather than that actually occupied. 
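+	 * A hedged example with hypothetical sizes: if offset is 20, headskip
+	 * is 8, buf->head[0].iov_len is 200 and len is 500, the code below
+	 * moves min(200, 500) - (20 + GSS_KRB5_TOK_HDR_LEN + 8) bytes of
+	 * plaintext down to where the outer token header started.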
+ */ + movelen = min_t(unsigned int, buf->head[0].iov_len, len); + movelen -= offset + GSS_KRB5_TOK_HDR_LEN + headskip; + BUG_ON(offset + GSS_KRB5_TOK_HDR_LEN + headskip + movelen > + buf->head[0].iov_len); + memmove(ptr, ptr + GSS_KRB5_TOK_HDR_LEN + headskip, movelen); + buf->head[0].iov_len -= GSS_KRB5_TOK_HDR_LEN + headskip; + buf->len = len - (GSS_KRB5_TOK_HDR_LEN + headskip); + + /* Trim off the trailing "extra count" and checksum blob */ + xdr_buf_trim(buf, ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + + *align = XDR_QUADLEN(GSS_KRB5_TOK_HDR_LEN + headskip); + *slack = *align + XDR_QUADLEN(ec + GSS_KRB5_TOK_HDR_LEN + tailskip); + return GSS_S_COMPLETE; +} + +u32 +gss_wrap_kerberos(struct gss_ctx *gctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + return gss_wrap_kerberos_v1(kctx, offset, buf, pages); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_wrap_kerberos_v2(kctx, offset, buf, pages); + } +} + +u32 +gss_unwrap_kerberos(struct gss_ctx *gctx, int offset, + int len, struct xdr_buf *buf) +{ + struct krb5_ctx *kctx = gctx->internal_ctx_id; + + switch (kctx->enctype) { + default: + BUG(); + case ENCTYPE_DES_CBC_RAW: + case ENCTYPE_DES3_CBC_RAW: + return gss_unwrap_kerberos_v1(kctx, offset, len, buf, + &gctx->slack, &gctx->align); + case ENCTYPE_AES128_CTS_HMAC_SHA1_96: + case ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return gss_unwrap_kerberos_v2(kctx, offset, len, buf, + &gctx->slack, &gctx->align); + } +} diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c new file mode 100644 index 000000000..fae632da1 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -0,0 +1,448 @@ +// SPDX-License-Identifier: BSD-3-Clause +/* + * linux/net/sunrpc/gss_mech_switch.c + * + * Copyright (c) 2001 The Regents of the University of Michigan. + * All rights reserved. + * + * J. 
Bruce Fields + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static LIST_HEAD(registered_mechs); +static DEFINE_SPINLOCK(registered_mechs_lock); + +static void +gss_mech_free(struct gss_api_mech *gm) +{ + struct pf_desc *pf; + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + if (pf->domain) + auth_domain_put(pf->domain); + kfree(pf->auth_domain_name); + pf->auth_domain_name = NULL; + } +} + +static inline char * +make_auth_domain_name(char *name) +{ + static char *prefix = "gss/"; + char *new; + + new = kmalloc(strlen(name) + strlen(prefix) + 1, GFP_KERNEL); + if (new) { + strcpy(new, prefix); + strcat(new, name); + } + return new; +} + +static int +gss_mech_svc_setup(struct gss_api_mech *gm) +{ + struct auth_domain *dom; + struct pf_desc *pf; + int i, status; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + pf->auth_domain_name = make_auth_domain_name(pf->name); + status = -ENOMEM; + if (pf->auth_domain_name == NULL) + goto out; + dom = svcauth_gss_register_pseudoflavor( + pf->pseudoflavor, pf->auth_domain_name); + if (IS_ERR(dom)) { + status = PTR_ERR(dom); + goto out; + } + pf->domain = dom; + } + return 0; +out: + gss_mech_free(gm); + return status; +} + +/** + * gss_mech_register - register a GSS mechanism + * @gm: GSS mechanism handle + * + * Returns zero if successful, or a negative errno. + */ +int gss_mech_register(struct gss_api_mech *gm) +{ + int status; + + status = gss_mech_svc_setup(gm); + if (status) + return status; + spin_lock(®istered_mechs_lock); + list_add_rcu(&gm->gm_list, ®istered_mechs); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); + return 0; +} +EXPORT_SYMBOL_GPL(gss_mech_register); + +/** + * gss_mech_unregister - release a GSS mechanism + * @gm: GSS mechanism handle + * + */ +void gss_mech_unregister(struct gss_api_mech *gm) +{ + spin_lock(®istered_mechs_lock); + list_del_rcu(&gm->gm_list); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); + gss_mech_free(gm); +} +EXPORT_SYMBOL_GPL(gss_mech_unregister); + +struct gss_api_mech *gss_mech_get(struct gss_api_mech *gm) +{ + __module_get(gm->gm_owner); + return gm; +} +EXPORT_SYMBOL(gss_mech_get); + +static struct gss_api_mech * +_gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *pos, *gm = NULL; + + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { + if (0 == strcmp(name, pos->gm_name)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + rcu_read_unlock(); + return gm; + +} + +struct gss_api_mech * gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *gm = NULL; + + gm = _gss_mech_get_by_name(name); + if (!gm) { + request_module("rpc-auth-gss-%s", name); + gm = _gss_mech_get_by_name(name); + } + return gm; +} + +struct gss_api_mech *gss_mech_get_by_OID(struct rpcsec_gss_oid *obj) +{ + struct gss_api_mech *pos, *gm = NULL; + char buf[32]; + + if (sprint_oid(obj->data, obj->len, buf, sizeof(buf)) < 0) + return NULL; + request_module("rpc-auth-gss-%s", buf); + + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { + if (obj->len == pos->gm_oid.len) { + if (0 == memcmp(obj->data, pos->gm_oid.data, obj->len)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + } + rcu_read_unlock(); + if (!gm) + 
trace_rpcgss_oid_to_mech(buf); + return gm; +} + +static inline int +mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return 1; + } + return 0; +} + +static struct gss_api_mech *_gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *gm = NULL, *pos; + + rcu_read_lock(); + list_for_each_entry_rcu(pos, ®istered_mechs, gm_list) { + if (!mech_supports_pseudoflavor(pos, pseudoflavor)) + continue; + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + rcu_read_unlock(); + return gm; +} + +struct gss_api_mech * +gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *gm; + + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + + if (!gm) { + request_module("rpc-auth-gss-%u", pseudoflavor); + gm = _gss_mech_get_by_pseudoflavor(pseudoflavor); + } + return gm; +} + +/** + * gss_svc_to_pseudoflavor - map a GSS service number to a pseudoflavor + * @gm: GSS mechanism handle + * @qop: GSS quality-of-protection value + * @service: GSS service value + * + * Returns a matching security flavor, or RPC_AUTH_MAXFLAVOR if none is found. + */ +rpc_authflavor_t gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 qop, + u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].qop == qop && + gm->gm_pfs[i].service == service) { + return gm->gm_pfs[i].pseudoflavor; + } + } + return RPC_AUTH_MAXFLAVOR; +} + +/** + * gss_mech_info2flavor - look up a pseudoflavor given a GSS tuple + * @info: a GSS mech OID, quality of protection, and service value + * + * Returns a matching pseudoflavor, or RPC_AUTH_MAXFLAVOR if the tuple is + * not supported. + */ +rpc_authflavor_t gss_mech_info2flavor(struct rpcsec_gss_info *info) +{ + rpc_authflavor_t pseudoflavor; + struct gss_api_mech *gm; + + gm = gss_mech_get_by_OID(&info->oid); + if (gm == NULL) + return RPC_AUTH_MAXFLAVOR; + + pseudoflavor = gss_svc_to_pseudoflavor(gm, info->qop, info->service); + + gss_mech_put(gm); + return pseudoflavor; +} + +/** + * gss_mech_flavor2info - look up a GSS tuple for a given pseudoflavor + * @pseudoflavor: GSS pseudoflavor to match + * @info: rpcsec_gss_info structure to fill in + * + * Returns zero and fills in "info" if pseudoflavor matches a + * supported mechanism. Otherwise a negative errno is returned. 
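+ *
+ * A minimal usage sketch (hypothetical caller, for illustration only):
+ *
+ *	struct rpcsec_gss_info info = { };
+ *
+ *	if (gss_mech_flavor2info(RPC_AUTH_GSS_KRB5I, &info) == 0)
+ *		pr_debug("qop %u service %u\n", info.qop, info.service);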
+ */ +int gss_mech_flavor2info(rpc_authflavor_t pseudoflavor, + struct rpcsec_gss_info *info) +{ + struct gss_api_mech *gm; + int i; + + gm = gss_mech_get_by_pseudoflavor(pseudoflavor); + if (gm == NULL) + return -ENOENT; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) { + memcpy(info->oid.data, gm->gm_oid.data, gm->gm_oid.len); + info->oid.len = gm->gm_oid.len; + info->qop = gm->gm_pfs[i].qop; + info->service = gm->gm_pfs[i].service; + gss_mech_put(gm); + return 0; + } + } + + gss_mech_put(gm); + return -ENOENT; +} + +u32 +gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].service; + } + return 0; +} +EXPORT_SYMBOL(gss_pseudoflavor_to_service); + +bool +gss_pseudoflavor_to_datatouch(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].datatouch; + } + return false; +} + +char * +gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].service == service) + return gm->gm_pfs[i].auth_domain_name; + } + return NULL; +} + +void +gss_mech_put(struct gss_api_mech * gm) +{ + if (gm) + module_put(gm->gm_owner); +} +EXPORT_SYMBOL(gss_mech_put); + +/* The mech could probably be determined from the token instead, but it's just + * as easy for now to pass it in. */ +int +gss_import_sec_context(const void *input_token, size_t bufsize, + struct gss_api_mech *mech, + struct gss_ctx **ctx_id, + time64_t *endtime, + gfp_t gfp_mask) +{ + if (!(*ctx_id = kzalloc(sizeof(**ctx_id), gfp_mask))) + return -ENOMEM; + (*ctx_id)->mech_type = gss_mech_get(mech); + + return mech->gm_ops->gss_import_sec_context(input_token, bufsize, + *ctx_id, endtime, gfp_mask); +} + +/* gss_get_mic: compute a mic over message and return mic_token. */ + +u32 +gss_get_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_get_mic(context_handle, + message, + mic_token); +} + +/* gss_verify_mic: check whether the provided mic_token verifies message. */ + +u32 +gss_verify_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_verify_mic(context_handle, + message, + mic_token); +} + +/* + * This function is called from both the client and server code. + * Each makes guarantees about how much "slack" space is available + * for the underlying function in "buf"'s head and tail while + * performing the wrap. + * + * The client and server code allocate RPC_MAX_AUTH_SIZE extra + * space in both the head and tail which is available for use by + * the wrap function. + * + * Underlying functions should verify they do not use more than + * RPC_MAX_AUTH_SIZE of extra space in either the head or tail + * when performing the wrap. + */ +u32 +gss_wrap(struct gss_ctx *ctx_id, + int offset, + struct xdr_buf *buf, + struct page **inpages) +{ + return ctx_id->mech_type->gm_ops + ->gss_wrap(ctx_id, offset, buf, inpages); +} + +u32 +gss_unwrap(struct gss_ctx *ctx_id, + int offset, + int len, + struct xdr_buf *buf) +{ + return ctx_id->mech_type->gm_ops + ->gss_unwrap(ctx_id, offset, len, buf); +} + + +/* gss_delete_sec_context: free all resources associated with context_handle. 
+ * Note this differs from the RFC 2744-specified prototype in that we don't + * bother returning an output token, since it would never be used anyway. */ + +u32 +gss_delete_sec_context(struct gss_ctx **context_handle) +{ + dprintk("RPC: gss_delete_sec_context deleting %p\n", + *context_handle); + + if (!*context_handle) + return GSS_S_NO_CONTEXT; + if ((*context_handle)->internal_ctx_id) + (*context_handle)->mech_type->gm_ops + ->gss_delete_sec_context((*context_handle) + ->internal_ctx_id); + gss_mech_put((*context_handle)->mech_type); + kfree(*context_handle); + *context_handle=NULL; + return GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.c b/net/sunrpc/auth_gss/gss_rpc_upcall.c new file mode 100644 index 000000000..f549e4c05 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.c @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * linux/net/sunrpc/gss_rpc_upcall.c + * + * Copyright (C) 2012 Simo Sorce + */ + +#include +#include + +#include +#include "gss_rpc_upcall.h" + +#define GSSPROXY_SOCK_PATHNAME "/var/run/gssproxy.sock" + +#define GSSPROXY_PROGRAM (400112u) +#define GSSPROXY_VERS_1 (1u) + +/* + * Encoding/Decoding functions + */ + +enum { + GSSX_NULL = 0, /* Unused */ + GSSX_INDICATE_MECHS = 1, + GSSX_GET_CALL_CONTEXT = 2, + GSSX_IMPORT_AND_CANON_NAME = 3, + GSSX_EXPORT_CRED = 4, + GSSX_IMPORT_CRED = 5, + GSSX_ACQUIRE_CRED = 6, + GSSX_STORE_CRED = 7, + GSSX_INIT_SEC_CONTEXT = 8, + GSSX_ACCEPT_SEC_CONTEXT = 9, + GSSX_RELEASE_HANDLE = 10, + GSSX_GET_MIC = 11, + GSSX_VERIFY = 12, + GSSX_WRAP = 13, + GSSX_UNWRAP = 14, + GSSX_WRAP_SIZE_LIMIT = 15, +}; + +#define PROC(proc, name) \ +[GSSX_##proc] = { \ + .p_proc = GSSX_##proc, \ + .p_encode = gssx_enc_##name, \ + .p_decode = gssx_dec_##name, \ + .p_arglen = GSSX_ARG_##name##_sz, \ + .p_replen = GSSX_RES_##name##_sz, \ + .p_statidx = GSSX_##proc, \ + .p_name = #proc, \ +} + +static const struct rpc_procinfo gssp_procedures[] = { + PROC(INDICATE_MECHS, indicate_mechs), + PROC(GET_CALL_CONTEXT, get_call_context), + PROC(IMPORT_AND_CANON_NAME, import_and_canon_name), + PROC(EXPORT_CRED, export_cred), + PROC(IMPORT_CRED, import_cred), + PROC(ACQUIRE_CRED, acquire_cred), + PROC(STORE_CRED, store_cred), + PROC(INIT_SEC_CONTEXT, init_sec_context), + PROC(ACCEPT_SEC_CONTEXT, accept_sec_context), + PROC(RELEASE_HANDLE, release_handle), + PROC(GET_MIC, get_mic), + PROC(VERIFY, verify), + PROC(WRAP, wrap), + PROC(UNWRAP, unwrap), + PROC(WRAP_SIZE_LIMIT, wrap_size_limit), +}; + + + +/* + * Common transport functions + */ + +static const struct rpc_program gssp_program; + +static int gssp_rpc_create(struct net *net, struct rpc_clnt **_clnt) +{ + static const struct sockaddr_un gssp_localaddr = { + .sun_family = AF_LOCAL, + .sun_path = GSSPROXY_SOCK_PATHNAME, + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&gssp_localaddr, + .addrsize = sizeof(gssp_localaddr), + .servername = "localhost", + .program = &gssp_program, + .version = GSSPROXY_VERS_1, + .authflavor = RPC_AUTH_NULL, + /* + * Note we want connection to be done in the caller's + * filesystem namespace. 
We therefore turn off the idle + * timeout, which would result in reconnections being + * done without the correct namespace: + */ + .flags = RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_CONNECTED | + RPC_CLNT_CREATE_NO_IDLE_TIMEOUT + }; + struct rpc_clnt *clnt; + int result = 0; + + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + dprintk("RPC: failed to create AF_LOCAL gssproxy " + "client (errno %ld).\n", PTR_ERR(clnt)); + result = PTR_ERR(clnt); + *_clnt = NULL; + goto out; + } + + dprintk("RPC: created new gssp local client (gssp_local_clnt: " + "%p)\n", clnt); + *_clnt = clnt; + +out: + return result; +} + +void init_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_init(&sn->gssp_lock); + sn->gssp_clnt = NULL; +} + +int set_gssp_clnt(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int ret; + + mutex_lock(&sn->gssp_lock); + ret = gssp_rpc_create(net, &clnt); + if (!ret) { + if (sn->gssp_clnt) + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = clnt; + } + mutex_unlock(&sn->gssp_lock); + return ret; +} + +void clear_gssp_clnt(struct sunrpc_net *sn) +{ + mutex_lock(&sn->gssp_lock); + if (sn->gssp_clnt) { + rpc_shutdown_client(sn->gssp_clnt); + sn->gssp_clnt = NULL; + } + mutex_unlock(&sn->gssp_lock); +} + +static struct rpc_clnt *get_gssp_clnt(struct sunrpc_net *sn) +{ + struct rpc_clnt *clnt; + + mutex_lock(&sn->gssp_lock); + clnt = sn->gssp_clnt; + if (clnt) + refcount_inc(&clnt->cl_count); + mutex_unlock(&sn->gssp_lock); + return clnt; +} + +static int gssp_call(struct net *net, struct rpc_message *msg) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + int status; + + clnt = get_gssp_clnt(sn); + if (!clnt) + return -EIO; + status = rpc_call_sync(clnt, msg, 0); + if (status < 0) { + dprintk("gssp: rpc_call returned error %d\n", -status); + switch (status) { + case -EPROTONOSUPPORT: + status = -EINVAL; + break; + case -ECONNREFUSED: + case -ETIMEDOUT: + case -ENOTCONN: + status = -EAGAIN; + break; + case -ERESTARTSYS: + if (signalled ()) + status = -EINTR; + break; + default: + break; + } + } + rpc_release_client(clnt); + return status; +} + +static void gssp_free_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + unsigned int i; + + for (i = 0; i < arg->npages && arg->pages[i]; i++) + __free_page(arg->pages[i]); + + kfree(arg->pages); +} + +static int gssp_alloc_receive_pages(struct gssx_arg_accept_sec_context *arg) +{ + unsigned int i; + + arg->npages = DIV_ROUND_UP(NGROUPS_MAX * 4, PAGE_SIZE); + arg->pages = kcalloc(arg->npages, sizeof(struct page *), GFP_KERNEL); + if (!arg->pages) + return -ENOMEM; + for (i = 0; i < arg->npages; i++) { + arg->pages[i] = alloc_page(GFP_KERNEL); + if (!arg->pages[i]) { + gssp_free_receive_pages(arg); + return -ENOMEM; + } + } + return 0; +} + +static char *gssp_stringify(struct xdr_netobj *netobj) +{ + return kmemdup_nul(netobj->data, netobj->len, GFP_KERNEL); +} + +static void gssp_hostbased_service(char **principal) +{ + char *c; + + if (!*principal) + return; + + /* terminate and remove realm part */ + c = strchr(*principal, '@'); + if (c) { + *c = '\0'; + + /* change service-hostname delimiter */ + c = strchr(*principal, '/'); + if (c) + *c = '@'; + } + if (!c) { + /* not a service principal */ + kfree(*principal); + *principal = NULL; + } +} + +/* + * Public functions + */ + +/* numbers somewhat arbitrary but large enough for current needs */ +#define GSSX_MAX_OUT_HANDLE 128 +#define GSSX_MAX_SRC_PRINC 256 +#define GSSX_KMEMBUF 
(GSSX_max_output_handle_sz + \ + GSSX_max_oid_sz + \ + GSSX_max_princ_sz + \ + sizeof(struct svc_cred)) + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data) +{ + struct gssx_ctx ctxh = { + .state = data->in_handle + }; + struct gssx_arg_accept_sec_context arg = { + .input_token = data->in_token, + }; + struct gssx_ctx rctxh = { + /* + * pass in the max length we expect for each of these + * buffers but let the xdr code kmalloc them: + */ + .exported_context_token.len = GSSX_max_output_handle_sz, + .mech.len = GSS_OID_MAX_LEN, + .targ_name.display_name.len = GSSX_max_princ_sz, + .src_name.display_name.len = GSSX_max_princ_sz + }; + struct gssx_res_accept_sec_context res = { + .context_handle = &rctxh, + .output_token = &data->out_token + }; + struct rpc_message msg = { + .rpc_proc = &gssp_procedures[GSSX_ACCEPT_SEC_CONTEXT], + .rpc_argp = &arg, + .rpc_resp = &res, + .rpc_cred = NULL, /* FIXME ? */ + }; + struct xdr_netobj client_name = { 0 , NULL }; + struct xdr_netobj target_name = { 0, NULL }; + int ret; + + if (data->in_handle.len != 0) + arg.context_handle = &ctxh; + res.output_token->len = GSSX_max_output_token_sz; + + ret = gssp_alloc_receive_pages(&arg); + if (ret) + return ret; + + ret = gssp_call(net, &msg); + + gssp_free_receive_pages(&arg); + + /* we need to fetch all data even in case of error so + * that we can free special strctures is they have been allocated */ + data->major_status = res.status.major_status; + data->minor_status = res.status.minor_status; + if (res.context_handle) { + data->out_handle = rctxh.exported_context_token; + data->mech_oid.len = rctxh.mech.len; + if (rctxh.mech.data) { + memcpy(data->mech_oid.data, rctxh.mech.data, + data->mech_oid.len); + kfree(rctxh.mech.data); + } + client_name = rctxh.src_name.display_name; + target_name = rctxh.targ_name.display_name; + } + + if (res.options.count == 1) { + gssx_buffer *value = &res.options.data[0].value; + /* Currently we only decode CREDS_VALUE, if we add + * anything else we'll have to loop and match on the + * option name */ + if (value->len == 1) { + /* steal group info from struct svc_cred */ + data->creds = *(struct svc_cred *)value->data; + data->found_creds = 1; + } + /* whether we use it or not, free data */ + kfree(value->data); + } + + if (res.options.count != 0) { + kfree(res.options.data); + } + + /* convert to GSS_NT_HOSTBASED_SERVICE form and set into creds */ + if (data->found_creds) { + if (client_name.data) { + data->creds.cr_raw_principal = + gssp_stringify(&client_name); + data->creds.cr_principal = + gssp_stringify(&client_name); + gssp_hostbased_service(&data->creds.cr_principal); + } + if (target_name.data) { + data->creds.cr_targ_princ = + gssp_stringify(&target_name); + gssp_hostbased_service(&data->creds.cr_targ_princ); + } + } + kfree(client_name.data); + kfree(target_name.data); + + return ret; +} + +void gssp_free_upcall_data(struct gssp_upcall_data *data) +{ + kfree(data->in_handle.data); + kfree(data->out_handle.data); + kfree(data->out_token.data); + free_svc_cred(&data->creds); +} + +/* + * Initialization stuff + */ +static unsigned int gssp_version1_counts[ARRAY_SIZE(gssp_procedures)]; +static const struct rpc_version gssp_version1 = { + .number = GSSPROXY_VERS_1, + .nrprocs = ARRAY_SIZE(gssp_procedures), + .procs = gssp_procedures, + .counts = gssp_version1_counts, +}; + +static const struct rpc_version *gssp_version[] = { + NULL, + &gssp_version1, +}; + +static struct rpc_stat gssp_stats; + +static const struct rpc_program gssp_program = { 
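+	/*
+	 * How this table is consumed (inferred from gssp_rpc_create() above):
+	 * rpc_create() is handed this program with .version = GSSPROXY_VERS_1,
+	 * which indexes gssp_version[]; slot 0 of that array is NULL, so the
+	 * procedures above are only reachable as version 1.
+	 */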
+ .name = "gssproxy", + .number = GSSPROXY_PROGRAM, + .nrvers = ARRAY_SIZE(gssp_version), + .version = gssp_version, + .stats = &gssp_stats, +}; diff --git a/net/sunrpc/auth_gss/gss_rpc_upcall.h b/net/sunrpc/auth_gss/gss_rpc_upcall.h new file mode 100644 index 000000000..31e963441 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_upcall.h @@ -0,0 +1,36 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ +/* + * linux/net/sunrpc/gss_rpc_upcall.h + * + * Copyright (C) 2012 Simo Sorce + */ + +#ifndef _GSS_RPC_UPCALL_H +#define _GSS_RPC_UPCALL_H + +#include +#include +#include "gss_rpc_xdr.h" +#include "../netns.h" + +struct gssp_upcall_data { + struct xdr_netobj in_handle; + struct gssp_in_token in_token; + struct xdr_netobj out_handle; + struct xdr_netobj out_token; + struct rpcsec_gss_oid mech_oid; + struct svc_cred creds; + int found_creds; + int major_status; + int minor_status; +}; + +int gssp_accept_sec_context_upcall(struct net *net, + struct gssp_upcall_data *data); +void gssp_free_upcall_data(struct gssp_upcall_data *data); + +void init_gssp_clnt(struct sunrpc_net *); +int set_gssp_clnt(struct net *); +void clear_gssp_clnt(struct sunrpc_net *); + +#endif /* _GSS_RPC_UPCALL_H */ diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.c b/net/sunrpc/auth_gss/gss_rpc_xdr.c new file mode 100644 index 000000000..d79f12c25 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.c @@ -0,0 +1,838 @@ +// SPDX-License-Identifier: GPL-2.0+ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce + */ + +#include +#include "gss_rpc_xdr.h" + +static int gssx_enc_bool(struct xdr_stream *xdr, int v) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *p = v ? xdr_one : xdr_zero; + return 0; +} + +static int gssx_dec_bool(struct xdr_stream *xdr, u32 *v) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + *v = be32_to_cpu(*p); + return 0; +} + +static int gssx_enc_buffer(struct xdr_stream *xdr, + const gssx_buffer *buf) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, sizeof(u32) + buf->len); + if (!p) + return -ENOSPC; + xdr_encode_opaque(p, buf->data, buf->len); + return 0; +} + +static int gssx_enc_in_token(struct xdr_stream *xdr, + const struct gssp_in_token *in) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = cpu_to_be32(in->page_len); + + /* all we need to do is to write pages */ + xdr_write_pages(xdr, in->pages, in->page_base, in->page_len); + + return 0; +} + + +static int gssx_dec_buffer(struct xdr_stream *xdr, + gssx_buffer *buf) +{ + u32 length; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (buf->len == 0) { + /* we intentionally are not interested in this buffer */ + return 0; + } + if (length > buf->len) + return -ENOSPC; + + if (!buf->data) { + buf->data = kmemdup(p, length, GFP_KERNEL); + if (!buf->data) + return -ENOMEM; + } else { + memcpy(buf->data, p, length); + } + buf->len = length; + return 0; +} + +static int gssx_enc_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_enc_buffer(xdr, &opt->option); + if (err) + return err; + err = gssx_enc_buffer(xdr, &opt->value); + return err; +} + +static int gssx_dec_option(struct xdr_stream *xdr, + struct gssx_option *opt) +{ + int err; + + err = gssx_dec_buffer(xdr, &opt->option); + if (err) + return err; + err = 
gssx_dec_buffer(xdr, &opt->value); + return err; +} + +static int dummy_enc_opt_array(struct xdr_stream *xdr, + const struct gssx_option_array *oa) +{ + __be32 *p; + + if (oa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_opt_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct gssx_option dummy; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + memset(&dummy, 0, sizeof(dummy)); + for (i = 0; i < count; i++) { + gssx_dec_option(xdr, &dummy); + } + + oa->count = 0; + oa->data = NULL; + return 0; +} + +static int get_host_u32(struct xdr_stream *xdr, u32 *res) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (!p) + return -EINVAL; + /* Contents of linux creds are all host-endian: */ + memcpy(res, p, sizeof(u32)); + return 0; +} + +static int gssx_dec_linux_creds(struct xdr_stream *xdr, + struct svc_cred *creds) +{ + u32 length; + __be32 *p; + u32 tmp; + u32 N; + int i, err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + + if (length > (3 + NGROUPS_MAX) * sizeof(u32)) + return -ENOSPC; + + /* uid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_uid = make_kuid(&init_user_ns, tmp); + + /* gid */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + creds->cr_gid = make_kgid(&init_user_ns, tmp); + + /* number of additional gid's */ + err = get_host_u32(xdr, &tmp); + if (err) + return err; + N = tmp; + if ((3 + N) * sizeof(u32) != length) + return -EINVAL; + creds->cr_group_info = groups_alloc(N); + if (creds->cr_group_info == NULL) + return -ENOMEM; + + /* gid's */ + for (i = 0; i < N; i++) { + kgid_t kgid; + err = get_host_u32(xdr, &tmp); + if (err) + goto out_free_groups; + err = -EINVAL; + kgid = make_kgid(&init_user_ns, tmp); + if (!gid_valid(kgid)) + goto out_free_groups; + creds->cr_group_info->gid[i] = kgid; + } + groups_sort(creds->cr_group_info); + + return 0; +out_free_groups: + groups_free(creds->cr_group_info); + return err; +} + +static int gssx_dec_option_array(struct xdr_stream *xdr, + struct gssx_option_array *oa) +{ + struct svc_cred *creds; + u32 count, i; + __be32 *p; + int err; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + if (!count) + return 0; + + /* we recognize only 1 currently: CREDS_VALUE */ + oa->count = 1; + + oa->data = kmalloc(sizeof(struct gssx_option), GFP_KERNEL); + if (!oa->data) + return -ENOMEM; + + creds = kzalloc(sizeof(struct svc_cred), GFP_KERNEL); + if (!creds) { + kfree(oa->data); + return -ENOMEM; + } + + oa->data[0].option.data = CREDS_VALUE; + oa->data[0].option.len = sizeof(CREDS_VALUE); + oa->data[0].value.data = (void *)creds; + oa->data[0].value.len = 0; + + for (i = 0; i < count; i++) { + gssx_buffer dummy = { 0, NULL }; + u32 length; + + /* option buffer */ + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + + length = be32_to_cpup(p); + p = xdr_inline_decode(xdr, length); + if (unlikely(p == NULL)) + return -ENOSPC; + + if (length == sizeof(CREDS_VALUE) && + memcmp(p, CREDS_VALUE, sizeof(CREDS_VALUE)) == 0) { + /* We have creds here. 
parse them */ + err = gssx_dec_linux_creds(xdr, creds); + if (err) + return err; + oa->data[0].value.len = 1; /* presence */ + } else { + /* consume uninteresting buffer */ + err = gssx_dec_buffer(xdr, &dummy); + if (err) + return err; + } + } + return 0; +} + +static int gssx_dec_status(struct xdr_stream *xdr, + struct gssx_status *status) +{ + __be32 *p; + int err; + + /* status->major_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->major_status); + + /* status->mech */ + err = gssx_dec_buffer(xdr, &status->mech); + if (err) + return err; + + /* status->minor_status */ + p = xdr_inline_decode(xdr, 8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &status->minor_status); + + /* status->major_status_string */ + err = gssx_dec_buffer(xdr, &status->major_status_string); + if (err) + return err; + + /* status->minor_status_string */ + err = gssx_dec_buffer(xdr, &status->minor_status_string); + if (err) + return err; + + /* status->server_ctx */ + err = gssx_dec_buffer(xdr, &status->server_ctx); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* status->options */ + err = dummy_dec_opt_array(xdr, &status->options); + + return err; +} + +static int gssx_enc_call_ctx(struct xdr_stream *xdr, + const struct gssx_call_ctx *ctx) +{ + struct gssx_option opt; + __be32 *p; + int err; + + /* ctx->locale */ + err = gssx_enc_buffer(xdr, &ctx->locale); + if (err) + return err; + + /* ctx->server_ctx */ + err = gssx_enc_buffer(xdr, &ctx->server_ctx); + if (err) + return err; + + /* we always want to ask for lucid contexts */ + /* ctx->options */ + p = xdr_reserve_space(xdr, 4); + *p = cpu_to_be32(2); + + /* we want a lucid_v1 context */ + opt.option.data = LUCID_OPTION; + opt.option.len = sizeof(LUCID_OPTION); + opt.value.data = LUCID_VALUE; + opt.value.len = sizeof(LUCID_VALUE); + err = gssx_enc_option(xdr, &opt); + + /* ..and user creds */ + opt.option.data = CREDS_OPTION; + opt.option.len = sizeof(CREDS_OPTION); + opt.value.data = CREDS_VALUE; + opt.value.len = sizeof(CREDS_VALUE); + err = gssx_enc_option(xdr, &opt); + + return err; +} + +static int gssx_dec_name_attr(struct xdr_stream *xdr, + struct gssx_name_attr *attr) +{ + int err; + + /* attr->attr */ + err = gssx_dec_buffer(xdr, &attr->attr); + if (err) + return err; + + /* attr->value */ + err = gssx_dec_buffer(xdr, &attr->value); + if (err) + return err; + + /* attr->extensions */ + err = dummy_dec_opt_array(xdr, &attr->extensions); + + return err; +} + +static int dummy_enc_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + __be32 *p; + + if (naa->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int dummy_dec_nameattr_array(struct xdr_stream *xdr, + struct gssx_name_attr_array *naa) +{ + struct gssx_name_attr dummy = { .attr = {.len = 0} }; + u32 count, i; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -ENOSPC; + count = be32_to_cpup(p++); + for (i = 0; i < count; i++) { + gssx_dec_name_attr(xdr, &dummy); + } + + naa->count = 0; + naa->data = NULL; + return 0; +} + +static struct xdr_netobj zero_netobj = {}; + +static struct gssx_name_attr_array zero_name_attr_array = {}; + +static struct gssx_option_array zero_option_array = {}; + +static int gssx_enc_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + int err; + + /* name->display_name */ + err 
= gssx_enc_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_enc_buffer(xdr, &zero_netobj); + if (err) + return err; + + /* leave name_attributes empty for now, will add once we have any + * to pass up at all */ + /* name->name_attributes */ + err = dummy_enc_nameattr_array(xdr, &zero_name_attr_array); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* name->extensions */ + err = dummy_enc_opt_array(xdr, &zero_option_array); + + return err; +} + + +static int gssx_dec_name(struct xdr_stream *xdr, + struct gssx_name *name) +{ + struct xdr_netobj dummy_netobj = { .len = 0 }; + struct gssx_name_attr_array dummy_name_attr_array = { .count = 0 }; + struct gssx_option_array dummy_option_array = { .count = 0 }; + int err; + + /* name->display_name */ + err = gssx_dec_buffer(xdr, &name->display_name); + if (err) + return err; + + /* name->name_type */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* name->exported_composite_name */ + err = gssx_dec_buffer(xdr, &dummy_netobj); + if (err) + return err; + + /* we assume we have no attributes for now, so simply consume them */ + /* name->name_attributes */ + err = dummy_dec_nameattr_array(xdr, &dummy_name_attr_array); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* name->extensions */ + err = dummy_dec_opt_array(xdr, &dummy_option_array); + + return err; +} + +static int dummy_enc_credel_array(struct xdr_stream *xdr, + struct gssx_cred_element_array *cea) +{ + __be32 *p; + + if (cea->count != 0) + return -EINVAL; + + p = xdr_reserve_space(xdr, 4); + if (!p) + return -ENOSPC; + *p = 0; + + return 0; +} + +static int gssx_enc_cred(struct xdr_stream *xdr, + struct gssx_cred *cred) +{ + int err; + + /* cred->desired_name */ + err = gssx_enc_name(xdr, &cred->desired_name); + if (err) + return err; + + /* cred->elements */ + err = dummy_enc_credel_array(xdr, &cred->elements); + if (err) + return err; + + /* cred->cred_handle_reference */ + err = gssx_enc_buffer(xdr, &cred->cred_handle_reference); + if (err) + return err; + + /* cred->needs_release */ + err = gssx_enc_bool(xdr, cred->needs_release); + + return err; +} + +static int gssx_enc_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_enc_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_enc_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_enc_bool(xdr, ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_enc_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_enc_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_enc_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_reserve_space(xdr, 8+8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_encode_hyper(p, ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_enc_bool(xdr, 
ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_enc_bool(xdr, ctx->open); + if (err) + return err; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* ctx->options */ + err = dummy_enc_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_dec_ctx(struct xdr_stream *xdr, + struct gssx_ctx *ctx) +{ + __be32 *p; + int err; + + /* ctx->exported_context_token */ + err = gssx_dec_buffer(xdr, &ctx->exported_context_token); + if (err) + return err; + + /* ctx->state */ + err = gssx_dec_buffer(xdr, &ctx->state); + if (err) + return err; + + /* ctx->need_release */ + err = gssx_dec_bool(xdr, &ctx->need_release); + if (err) + return err; + + /* ctx->mech */ + err = gssx_dec_buffer(xdr, &ctx->mech); + if (err) + return err; + + /* ctx->src_name */ + err = gssx_dec_name(xdr, &ctx->src_name); + if (err) + return err; + + /* ctx->targ_name */ + err = gssx_dec_name(xdr, &ctx->targ_name); + if (err) + return err; + + /* ctx->lifetime */ + p = xdr_inline_decode(xdr, 8+8); + if (unlikely(p == NULL)) + return -ENOSPC; + p = xdr_decode_hyper(p, &ctx->lifetime); + + /* ctx->ctx_flags */ + p = xdr_decode_hyper(p, &ctx->ctx_flags); + + /* ctx->locally_initiated */ + err = gssx_dec_bool(xdr, &ctx->locally_initiated); + if (err) + return err; + + /* ctx->open */ + err = gssx_dec_bool(xdr, &ctx->open); + if (err) + return err; + + /* we assume we have no options for now, so simply consume them */ + /* ctx->options */ + err = dummy_dec_opt_array(xdr, &ctx->options); + + return err; +} + +static int gssx_enc_cb(struct xdr_stream *xdr, struct gssx_cb *cb) +{ + __be32 *p; + int err; + + /* cb->initiator_addrtype */ + p = xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->initiator_addrtype); + + /* cb->initiator_address */ + err = gssx_enc_buffer(xdr, &cb->initiator_address); + if (err) + return err; + + /* cb->acceptor_addrtype */ + p = xdr_reserve_space(xdr, 8); + if (!p) + return -ENOSPC; + p = xdr_encode_hyper(p, cb->acceptor_addrtype); + + /* cb->acceptor_address */ + err = gssx_enc_buffer(xdr, &cb->acceptor_address); + if (err) + return err; + + /* cb->application_data */ + err = gssx_enc_buffer(xdr, &cb->application_data); + + return err; +} + +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + const void *data) +{ + const struct gssx_arg_accept_sec_context *arg = data; + int err; + + err = gssx_enc_call_ctx(xdr, &arg->call_ctx); + if (err) + goto done; + + /* arg->context_handle */ + if (arg->context_handle) + err = gssx_enc_ctx(xdr, arg->context_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->cred_handle */ + if (arg->cred_handle) + err = gssx_enc_cred(xdr, arg->cred_handle); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + /* arg->input_token */ + err = gssx_enc_in_token(xdr, &arg->input_token); + if (err) + goto done; + + /* arg->input_cb */ + if (arg->input_cb) + err = gssx_enc_cb(xdr, arg->input_cb); + else + err = gssx_enc_bool(xdr, 0); + if (err) + goto done; + + err = gssx_enc_bool(xdr, arg->ret_deleg_cred); + if (err) + goto done; + + /* leave options empty for now, will add once we have any options + * to pass up at all */ + /* arg->options */ + err = dummy_enc_opt_array(xdr, &arg->options); + + xdr_inline_pages(&req->rq_rcv_buf, + PAGE_SIZE/2 /* pretty arbitrary */, + arg->pages, 0 /* page base */, arg->npages * PAGE_SIZE); +done: + if (err) + dprintk("RPC: gssx_enc_accept_sec_context: 
%d\n", err); +} + +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + void *data) +{ + struct gssx_res_accept_sec_context *res = data; + u32 value_follows; + int err; + struct page *scratch; + + scratch = alloc_page(GFP_KERNEL); + if (!scratch) + return -ENOMEM; + xdr_set_scratch_page(xdr, scratch); + + /* res->status */ + err = gssx_dec_status(xdr, &res->status); + if (err) + goto out_free; + + /* res->context_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + err = gssx_dec_ctx(xdr, res->context_handle); + if (err) + goto out_free; + } else { + res->context_handle = NULL; + } + + /* res->output_token */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + err = gssx_dec_buffer(xdr, res->output_token); + if (err) + goto out_free; + } else { + res->output_token = NULL; + } + + /* res->delegated_cred_handle */ + err = gssx_dec_bool(xdr, &value_follows); + if (err) + goto out_free; + if (value_follows) { + /* we do not support upcall servers sending this data. */ + err = -EINVAL; + goto out_free; + } + + /* res->options */ + err = gssx_dec_option_array(xdr, &res->options); + +out_free: + __free_page(scratch); + return err; +} diff --git a/net/sunrpc/auth_gss/gss_rpc_xdr.h b/net/sunrpc/auth_gss/gss_rpc_xdr.h new file mode 100644 index 000000000..3f17411b7 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_rpc_xdr.h @@ -0,0 +1,252 @@ +/* SPDX-License-Identifier: GPL-2.0+ */ +/* + * GSS Proxy upcall module + * + * Copyright (C) 2012 Simo Sorce + */ + +#ifndef _LINUX_GSS_RPC_XDR_H +#define _LINUX_GSS_RPC_XDR_H + +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define LUCID_OPTION "exported_context_type" +#define LUCID_VALUE "linux_lucid_v1" +#define CREDS_OPTION "exported_creds_type" +#define CREDS_VALUE "linux_creds_v1" + +typedef struct xdr_netobj gssx_buffer; +typedef struct xdr_netobj utf8string; +typedef struct xdr_netobj gssx_OID; + +enum gssx_cred_usage { + GSSX_C_INITIATE = 1, + GSSX_C_ACCEPT = 2, + GSSX_C_BOTH = 3, +}; + +struct gssx_option { + gssx_buffer option; + gssx_buffer value; +}; + +struct gssx_option_array { + u32 count; + struct gssx_option *data; +}; + +struct gssx_status { + u64 major_status; + gssx_OID mech; + u64 minor_status; + utf8string major_status_string; + utf8string minor_status_string; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_call_ctx { + utf8string locale; + gssx_buffer server_ctx; + struct gssx_option_array options; +}; + +struct gssx_name_attr { + gssx_buffer attr; + gssx_buffer value; + struct gssx_option_array extensions; +}; + +struct gssx_name_attr_array { + u32 count; + struct gssx_name_attr *data; +}; + +struct gssx_name { + gssx_buffer display_name; +}; +typedef struct gssx_name gssx_name; + +struct gssx_cred_element { + gssx_name MN; + gssx_OID mech; + u32 cred_usage; + u64 initiator_time_rec; + u64 acceptor_time_rec; + struct gssx_option_array options; +}; + +struct gssx_cred_element_array { + u32 count; + struct gssx_cred_element *data; +}; + +struct gssx_cred { + gssx_name desired_name; + struct gssx_cred_element_array elements; + gssx_buffer cred_handle_reference; + u32 needs_release; +}; + +struct gssx_ctx { + gssx_buffer exported_context_token; + gssx_buffer state; + u32 need_release; + gssx_OID mech; + gssx_name src_name; + gssx_name targ_name; + u64 lifetime; + u64 ctx_flags; + u32 locally_initiated; + u32 open; + 
struct gssx_option_array options; +}; + +struct gssx_cb { + u64 initiator_addrtype; + gssx_buffer initiator_address; + u64 acceptor_addrtype; + gssx_buffer acceptor_address; + gssx_buffer application_data; +}; + + +/* This structure is not defined in the protocol. + * It is used in the kernel to carry around a big buffer + * as a set of pages */ +struct gssp_in_token { + struct page **pages; /* Array of contiguous pages */ + unsigned int page_base; /* Start of page data */ + unsigned int page_len; /* Length of page data */ +}; + +struct gssx_arg_accept_sec_context { + struct gssx_call_ctx call_ctx; + struct gssx_ctx *context_handle; + struct gssx_cred *cred_handle; + struct gssp_in_token input_token; + struct gssx_cb *input_cb; + u32 ret_deleg_cred; + struct gssx_option_array options; + struct page **pages; + unsigned int npages; +}; + +struct gssx_res_accept_sec_context { + struct gssx_status status; + struct gssx_ctx *context_handle; + gssx_buffer *output_token; + /* struct gssx_cred *delegated_cred_handle; not used in kernel */ + struct gssx_option_array options; +}; + + + +#define gssx_enc_indicate_mechs NULL +#define gssx_dec_indicate_mechs NULL +#define gssx_enc_get_call_context NULL +#define gssx_dec_get_call_context NULL +#define gssx_enc_import_and_canon_name NULL +#define gssx_dec_import_and_canon_name NULL +#define gssx_enc_export_cred NULL +#define gssx_dec_export_cred NULL +#define gssx_enc_import_cred NULL +#define gssx_dec_import_cred NULL +#define gssx_enc_acquire_cred NULL +#define gssx_dec_acquire_cred NULL +#define gssx_enc_store_cred NULL +#define gssx_dec_store_cred NULL +#define gssx_enc_init_sec_context NULL +#define gssx_dec_init_sec_context NULL +void gssx_enc_accept_sec_context(struct rpc_rqst *req, + struct xdr_stream *xdr, + const void *data); +int gssx_dec_accept_sec_context(struct rpc_rqst *rqstp, + struct xdr_stream *xdr, + void *data); +#define gssx_enc_release_handle NULL +#define gssx_dec_release_handle NULL +#define gssx_enc_get_mic NULL +#define gssx_dec_get_mic NULL +#define gssx_enc_verify NULL +#define gssx_dec_verify NULL +#define gssx_enc_wrap NULL +#define gssx_dec_wrap NULL +#define gssx_enc_unwrap NULL +#define gssx_dec_unwrap NULL +#define gssx_enc_wrap_size_limit NULL +#define gssx_dec_wrap_size_limit NULL + +/* non implemented calls are set to 0 size */ +#define GSSX_ARG_indicate_mechs_sz 0 +#define GSSX_RES_indicate_mechs_sz 0 +#define GSSX_ARG_get_call_context_sz 0 +#define GSSX_RES_get_call_context_sz 0 +#define GSSX_ARG_import_and_canon_name_sz 0 +#define GSSX_RES_import_and_canon_name_sz 0 +#define GSSX_ARG_export_cred_sz 0 +#define GSSX_RES_export_cred_sz 0 +#define GSSX_ARG_import_cred_sz 0 +#define GSSX_RES_import_cred_sz 0 +#define GSSX_ARG_acquire_cred_sz 0 +#define GSSX_RES_acquire_cred_sz 0 +#define GSSX_ARG_store_cred_sz 0 +#define GSSX_RES_store_cred_sz 0 +#define GSSX_ARG_init_sec_context_sz 0 +#define GSSX_RES_init_sec_context_sz 0 + +#define GSSX_default_in_call_ctx_sz (4 + 4 + 4 + \ + 8 + sizeof(LUCID_OPTION) + sizeof(LUCID_VALUE) + \ + 8 + sizeof(CREDS_OPTION) + sizeof(CREDS_VALUE)) +#define GSSX_default_in_ctx_hndl_sz (4 + 4+8 + 4 + 4 + 6*4 + 6*4 + 8 + 8 + \ + 4 + 4 + 4) +#define GSSX_default_in_cred_sz 4 /* we send in no cred_handle */ +#define GSSX_default_in_token_sz 4 /* does *not* include token data */ +#define GSSX_default_in_cb_sz 4 /* we do not use channel bindings */ +#define GSSX_ARG_accept_sec_context_sz (GSSX_default_in_call_ctx_sz + \ + GSSX_default_in_ctx_hndl_sz + \ + GSSX_default_in_cred_sz + \ + 
GSSX_default_in_token_sz + \ + GSSX_default_in_cb_sz + \ + 4 /* no deleg creds boolean */ + \ + 4) /* empty options */ + +/* somewhat arbitrary numbers but large enough (we ignore some of the data + * sent down, but it is part of the protocol so we need enough space to take + * it in) */ +#define GSSX_default_status_sz 8 + 24 + 8 + 256 + 256 + 16 + 4 +#define GSSX_max_output_handle_sz 128 +#define GSSX_max_oid_sz 16 +#define GSSX_max_princ_sz 256 +#define GSSX_default_ctx_sz (GSSX_max_output_handle_sz + \ + 16 + 4 + GSSX_max_oid_sz + \ + 2 * GSSX_max_princ_sz + \ + 8 + 8 + 4 + 4 + 4) +#define GSSX_max_output_token_sz 1024 +/* grouplist not included; we allocate separate pages for that: */ +#define GSSX_max_creds_sz (4 + 4 + 4 /* + NGROUPS_MAX*4 */) +#define GSSX_RES_accept_sec_context_sz (GSSX_default_status_sz + \ + GSSX_default_ctx_sz + \ + GSSX_max_output_token_sz + \ + 4 + GSSX_max_creds_sz) + +#define GSSX_ARG_release_handle_sz 0 +#define GSSX_RES_release_handle_sz 0 +#define GSSX_ARG_get_mic_sz 0 +#define GSSX_RES_get_mic_sz 0 +#define GSSX_ARG_verify_sz 0 +#define GSSX_RES_verify_sz 0 +#define GSSX_ARG_wrap_sz 0 +#define GSSX_RES_wrap_sz 0 +#define GSSX_ARG_unwrap_sz 0 +#define GSSX_RES_unwrap_sz 0 +#define GSSX_ARG_wrap_size_limit_sz 0 +#define GSSX_RES_wrap_size_limit_sz 0 + +#endif /* _LINUX_GSS_RPC_XDR_H */ diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c new file mode 100644 index 000000000..bdc34ea0d --- /dev/null +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -0,0 +1,2017 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Neil Brown + * J. Bruce Fields + * Andy Adamson + * Dug Song + * + * RPCSEC_GSS server authentication. + * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078 + * (gssapi) + * + * The RPCSEC_GSS involves three stages: + * 1/ context creation + * 2/ data exchange + * 3/ context destruction + * + * Context creation is handled largely by upcalls to user-space. + * In particular, GSS_Accept_sec_context is handled by an upcall + * Data exchange is handled entirely within the kernel + * In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel. + * Context destruction is handled in-kernel + * GSS_Delete_sec_context is in-kernel + * + * Context creation is initiated by a RPCSEC_GSS_INIT request arriving. + * The context handle and gss_token are used as a key into the rpcsec_init cache. + * The content of this cache includes some of the outputs of GSS_Accept_sec_context, + * being major_status, minor_status, context_handle, reply_token. + * These are sent back to the client. + * Sequence window management is handled by the kernel. The window size if currently + * a compile time constant. + * + * When user-space is happy that a context is established, it places an entry + * in the rpcsec_context cache. The key for this cache is the context_handle. + * The content includes: + * uid/gidlist - for determining access rights + * mechanism type + * mechanism specific information, such as a key + * + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include "gss_rpc_upcall.h" + + +/* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests + * into replies. + * + * Key is context handle (\x if empty) and gss_token. + * Content is major_status minor_status (integers) context_handle, reply_token. 
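/*
 * Editorial sketch (not part of this patch): the init cache described in
 * the comment above is keyed by two opaque byte strings, the context
 * handle and the GSS token.  A user-space analogue hashes each blob and
 * XORs the results, and treats two keys as equal only when both blobs
 * match byte-for-byte.  hash_blob() is a stand-in (FNV-1a), not the
 * kernel's hash_mem().
 */
#include <stdint.h>
#include <string.h>

struct blob { const void *data; size_t len; };

static uint32_t hash_blob(const struct blob *b)
{
	const unsigned char *p = b->data;
	uint32_t h = 2166136261u;	/* FNV-1a 32-bit offset basis */
	size_t i;

	for (i = 0; i < b->len; i++)
		h = (h ^ p[i]) * 16777619u;
	return h;
}

static uint32_t init_key_hash(const struct blob *handle,
			      const struct blob *token)
{
	return hash_blob(handle) ^ hash_blob(token);
}

static int init_key_equal(const struct blob *h1, const struct blob *t1,
			  const struct blob *h2, const struct blob *t2)
{
	return h1->len == h2->len && !memcmp(h1->data, h2->data, h1->len) &&
	       t1->len == t2->len && !memcmp(t1->data, t2->data, t1->len);
}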
+ * + */ + +static int netobj_equal(struct xdr_netobj *a, struct xdr_netobj *b) +{ + return a->len == b->len && 0 == memcmp(a->data, b->data, a->len); +} + +#define RSI_HASHBITS 6 +#define RSI_HASHMAX (1<in_handle.data); + kfree(rsii->in_token.data); + kfree(rsii->out_handle.data); + kfree(rsii->out_token.data); +} + +static void rsi_free_rcu(struct rcu_head *head) +{ + struct rsi *rsii = container_of(head, struct rsi, rcu_head); + + rsi_free(rsii); + kfree(rsii); +} + +static void rsi_put(struct kref *ref) +{ + struct rsi *rsii = container_of(ref, struct rsi, h.ref); + + call_rcu(&rsii->rcu_head, rsi_free_rcu); +} + +static inline int rsi_hash(struct rsi *item) +{ + return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS) + ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS); +} + +static int rsi_match(struct cache_head *a, struct cache_head *b) +{ + struct rsi *item = container_of(a, struct rsi, h); + struct rsi *tmp = container_of(b, struct rsi, h); + return netobj_equal(&item->in_handle, &tmp->in_handle) && + netobj_equal(&item->in_token, &tmp->in_token); +} + +static int dup_to_netobj(struct xdr_netobj *dst, char *src, int len) +{ + dst->len = len; + dst->data = (len ? kmemdup(src, len, GFP_KERNEL) : NULL); + if (len && !dst->data) + return -ENOMEM; + return 0; +} + +static inline int dup_netobj(struct xdr_netobj *dst, struct xdr_netobj *src) +{ + return dup_to_netobj(dst, src->data, src->len); +} + +static void rsi_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + new->out_handle.data = NULL; + new->out_handle.len = 0; + new->out_token.data = NULL; + new->out_token.len = 0; + new->in_handle.len = item->in_handle.len; + item->in_handle.len = 0; + new->in_token.len = item->in_token.len; + item->in_token.len = 0; + new->in_handle.data = item->in_handle.data; + item->in_handle.data = NULL; + new->in_token.data = item->in_token.data; + item->in_token.data = NULL; +} + +static void update_rsi(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + BUG_ON(new->out_handle.data || new->out_token.data); + new->out_handle.len = item->out_handle.len; + item->out_handle.len = 0; + new->out_token.len = item->out_token.len; + item->out_token.len = 0; + new->out_handle.data = item->out_handle.data; + item->out_handle.data = NULL; + new->out_token.data = item->out_token.data; + item->out_token.data = NULL; + + new->major_status = item->major_status; + new->minor_status = item->minor_status; +} + +static struct cache_head *rsi_alloc(void) +{ + struct rsi *rsii = kmalloc(sizeof(*rsii), GFP_KERNEL); + if (rsii) + return &rsii->h; + else + return NULL; +} + +static int rsi_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall_timeout(cd, h); +} + +static void rsi_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + struct rsi *rsii = container_of(h, struct rsi, h); + + qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len); + qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len); + (*bpp)[-1] = '\n'; + WARN_ONCE(*blen < 0, + "RPCSEC/GSS credential too large - please use gssproxy\n"); +} + +static int rsi_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* context token expiry major minor context token */ + char *buf = mesg; + char *ep; + 
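/*
 * Editorial sketch (not part of this patch): rsi_init() and update_rsi()
 * above do not copy the netobj payloads -- they steal the data pointers
 * from the temporary item and NULL them out, so the caller's later
 * rsi_free() cannot free the same memory twice.  The same move-semantics
 * idiom in plain C:
 */
#include <stddef.h>

struct buf { unsigned char *data; size_t len; };

/* Transfer ownership of src's allocation to dst; src is left empty. */
static void buf_move(struct buf *dst, struct buf *src)
{
	dst->data = src->data;
	dst->len = src->len;
	src->data = NULL;	/* a later free(src->data) is now a no-op */
	src->len = 0;
}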
int len; + struct rsi rsii, *rsip = NULL; + time64_t expiry; + int status = -EINVAL; + + memset(&rsii, 0, sizeof(rsii)); + /* handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_handle, buf, len)) + goto out; + + /* token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_token, buf, len)) + goto out; + + rsip = rsi_lookup(cd, &rsii); + if (!rsip) + goto out; + + rsii.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + /* major/minor */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.major_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.minor_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + + /* out_handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_handle, buf, len)) + goto out; + + /* out_token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_token, buf, len)) + goto out; + rsii.h.expiry_time = expiry; + rsip = rsi_update(cd, &rsii, rsip); + status = 0; +out: + rsi_free(&rsii); + if (rsip) + cache_put(&rsip->h, cd); + else + status = -ENOMEM; + return status; +} + +static const struct cache_detail rsi_cache_template = { + .owner = THIS_MODULE, + .hash_size = RSI_HASHMAX, + .name = "auth.rpcsec.init", + .cache_put = rsi_put, + .cache_upcall = rsi_upcall, + .cache_request = rsi_request, + .cache_parse = rsi_parse, + .match = rsi_match, + .init = rsi_init, + .update = update_rsi, + .alloc = rsi_alloc, +}; + +static struct rsi *rsi_lookup(struct cache_detail *cd, struct rsi *item) +{ + struct cache_head *ch; + int hash = rsi_hash(item); + + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + +static struct rsi *rsi_update(struct cache_detail *cd, struct rsi *new, struct rsi *old) +{ + struct cache_head *ch; + int hash = rsi_hash(new); + + ch = sunrpc_cache_update(cd, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + + +/* + * The rpcsec_context cache is used to store a context that is + * used in data exchange. + * The key is a context handle. 
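/*
 * Editorial sketch (not part of this patch): rsi_parse() above pulls the
 * numeric fields (major/minor status) out of the downcall line with
 * simple_strtoul() and rejects the line if anything trails the digits.
 * The same check in portable, user-space C:
 */
#include <errno.h>
#include <stdlib.h>

/* Parse one unsigned decimal field; returns -1 on garbage or overflow. */
static int parse_u32_field(const char *field, unsigned long *out)
{
	char *end;

	errno = 0;
	*out = strtoul(field, &end, 10);
	if (errno || end == field || *end != '\0')
		return -1;
	return 0;
}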
The content is: + * uid, gidlist, mechanism, service-set, mech-specific-data + */ + +#define RSC_HASHBITS 10 +#define RSC_HASHMAX (1<handle.data); + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + free_svc_cred(&rsci->cred); +} + +static void rsc_free_rcu(struct rcu_head *head) +{ + struct rsc *rsci = container_of(head, struct rsc, rcu_head); + + kfree(rsci->handle.data); + kfree(rsci); +} + +static void rsc_put(struct kref *ref) +{ + struct rsc *rsci = container_of(ref, struct rsc, h.ref); + + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + free_svc_cred(&rsci->cred); + call_rcu(&rsci->rcu_head, rsc_free_rcu); +} + +static inline int +rsc_hash(struct rsc *rsci) +{ + return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS); +} + +static int +rsc_match(struct cache_head *a, struct cache_head *b) +{ + struct rsc *new = container_of(a, struct rsc, h); + struct rsc *tmp = container_of(b, struct rsc, h); + + return netobj_equal(&new->handle, &tmp->handle); +} + +static void +rsc_init(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->handle.len = tmp->handle.len; + tmp->handle.len = 0; + new->handle.data = tmp->handle.data; + tmp->handle.data = NULL; + new->mechctx = NULL; + init_svc_cred(&new->cred); +} + +static void +update_rsc(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->mechctx = tmp->mechctx; + tmp->mechctx = NULL; + memset(&new->seqdata, 0, sizeof(new->seqdata)); + spin_lock_init(&new->seqdata.sd_lock); + new->cred = tmp->cred; + init_svc_cred(&tmp->cred); +} + +static struct cache_head * +rsc_alloc(void) +{ + struct rsc *rsci = kmalloc(sizeof(*rsci), GFP_KERNEL); + if (rsci) + return &rsci->h; + else + return NULL; +} + +static int rsc_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return -EINVAL; +} + +static int rsc_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* contexthandle expiry [ uid gid N mechname ...mechdata... ] */ + char *buf = mesg; + int id; + int len, rv; + struct rsc rsci, *rscp = NULL; + time64_t expiry; + int status = -EINVAL; + struct gss_api_mech *gm = NULL; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsci.handle, buf, len)) + goto out; + + rsci.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + rscp = rsc_lookup(cd, &rsci); + if (!rscp) + goto out; + + /* uid, or NEGATIVE */ + rv = get_int(&mesg, &id); + if (rv == -EINVAL) + goto out; + if (rv == -ENOENT) + set_bit(CACHE_NEGATIVE, &rsci.h.flags); + else { + int N, i; + + /* + * NOTE: we skip uid_valid()/gid_valid() checks here: + * instead, * -1 id's are later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * + * (But supplementary gid's get no such special + * treatment so are checked for validity here.) 
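/*
 * Editorial sketch (not part of this patch): the parsing that follows
 * reads a count of supplementary groups, refuses anything outside
 * 0..NGROUPS_MAX, copies the ids, and finally sorts the list
 * (groups_sort()) so later membership checks can binary-search it.
 * A user-space version of that shape, with an illustrative limit:
 */
#include <stdlib.h>

#define SKETCH_NGROUPS_MAX 65536	/* stand-in for NGROUPS_MAX */

static int cmp_gid(const void *a, const void *b)
{
	unsigned int x = *(const unsigned int *)a;
	unsigned int y = *(const unsigned int *)b;

	return (x > y) - (x < y);
}

/* Validate the count, copy the ids, and sort them; returns NULL on error. */
static unsigned int *copy_and_sort_gids(const unsigned int *src, long n)
{
	unsigned int *gids;
	long i;

	if (n < 0 || n > SKETCH_NGROUPS_MAX)
		return NULL;
	gids = malloc(n ? n * sizeof(*gids) : sizeof(*gids));
	if (!gids)
		return NULL;
	for (i = 0; i < n; i++)
		gids[i] = src[i];
	qsort(gids, n, sizeof(*gids), cmp_gid);
	return gids;
}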
+ */ + /* uid */ + rsci.cred.cr_uid = make_kuid(current_user_ns(), id); + + /* gid */ + if (get_int(&mesg, &id)) + goto out; + rsci.cred.cr_gid = make_kgid(current_user_ns(), id); + + /* number of additional gid's */ + if (get_int(&mesg, &N)) + goto out; + if (N < 0 || N > NGROUPS_MAX) + goto out; + status = -ENOMEM; + rsci.cred.cr_group_info = groups_alloc(N); + if (rsci.cred.cr_group_info == NULL) + goto out; + + /* gid's */ + status = -EINVAL; + for (i=0; igid[i] = kgid; + } + groups_sort(rsci.cred.cr_group_info); + + /* mech name */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + gm = rsci.cred.cr_gss_mech = gss_mech_get_by_name(buf); + status = -EOPNOTSUPP; + if (!gm) + goto out; + + status = -EINVAL; + /* mech-specific data: */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = gss_import_sec_context(buf, len, gm, &rsci.mechctx, + NULL, GFP_KERNEL); + if (status) + goto out; + + /* get client name */ + len = qword_get(&mesg, buf, mlen); + if (len > 0) { + rsci.cred.cr_principal = kstrdup(buf, GFP_KERNEL); + if (!rsci.cred.cr_principal) { + status = -ENOMEM; + goto out; + } + } + + } + rsci.h.expiry_time = expiry; + rscp = rsc_update(cd, &rsci, rscp); + status = 0; +out: + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, cd); + else + status = -ENOMEM; + return status; +} + +static const struct cache_detail rsc_cache_template = { + .owner = THIS_MODULE, + .hash_size = RSC_HASHMAX, + .name = "auth.rpcsec.context", + .cache_put = rsc_put, + .cache_upcall = rsc_upcall, + .cache_parse = rsc_parse, + .match = rsc_match, + .init = rsc_init, + .update = update_rsc, + .alloc = rsc_alloc, +}; + +static struct rsc *rsc_lookup(struct cache_detail *cd, struct rsc *item) +{ + struct cache_head *ch; + int hash = rsc_hash(item); + + ch = sunrpc_cache_lookup_rcu(cd, &item->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + +static struct rsc *rsc_update(struct cache_detail *cd, struct rsc *new, struct rsc *old) +{ + struct cache_head *ch; + int hash = rsc_hash(new); + + ch = sunrpc_cache_update(cd, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + + +static struct rsc * +gss_svc_searchbyctx(struct cache_detail *cd, struct xdr_netobj *handle) +{ + struct rsc rsci; + struct rsc *found; + + memset(&rsci, 0, sizeof(rsci)); + if (dup_to_netobj(&rsci.handle, handle->data, handle->len)) + return NULL; + found = rsc_lookup(cd, &rsci); + rsc_free(&rsci); + if (!found) + return NULL; + if (cache_check(cd, &found->h, NULL)) + return NULL; + return found; +} + +/** + * gss_check_seq_num - GSS sequence number window check + * @rqstp: RPC Call to use when reporting errors + * @rsci: cached GSS context state (updated on return) + * @seq_num: sequence number to check + * + * Implements sequence number algorithm as specified in + * RFC 2203, Section 5.3.3.1. "Context Management". 
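/*
 * Editorial sketch (not part of this patch): gss_check_seq_num() below
 * implements the RFC 2203 replay window -- accept a sequence number that
 * is newer than anything seen (sliding the window forward), reject one
 * that falls behind the window, and otherwise accept it only the first
 * time its bit is seen.  A self-contained, single-threaded model of the
 * same algorithm; the window size and helper names are illustrative.
 */
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#define WIN 128u			/* stand-in for GSS_SEQ_WIN */

struct seq_window {
	uint32_t max;			/* highest sequence number seen */
	unsigned char bits[WIN / 8];	/* one bit per number in (max-WIN, max] */
};

static bool test_and_set_bit32(unsigned char *bits, uint32_t idx)
{
	unsigned char mask = 1u << (idx % 8);
	bool was_set = bits[idx / 8] & mask;

	bits[idx / 8] |= mask;
	return was_set;
}

static void clear_bit32(unsigned char *bits, uint32_t idx)
{
	bits[idx / 8] &= (unsigned char)~(1u << (idx % 8));
}

static bool seq_check(struct seq_window *w, uint32_t seq)
{
	if (seq > w->max) {
		if (seq >= w->max + WIN) {
			/* Jumped past the whole window: start it afresh. */
			memset(w->bits, 0, sizeof(w->bits));
			w->max = seq;
		} else {
			/* Slide forward, clearing the bits we pass over. */
			while (w->max < seq) {
				w->max++;
				clear_bit32(w->bits, w->max % WIN);
			}
		}
		test_and_set_bit32(w->bits, seq % WIN);
		return true;
	}
	if (seq + WIN <= w->max)
		return false;		/* fell off the back of the window */
	return !test_and_set_bit32(w->bits, seq % WIN);	/* replay check */
}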
+ * + * Return values: + * %true: @rqstp's GSS sequence number is inside the window + * %false: @rqstp's GSS sequence number is outside the window + */ +static bool gss_check_seq_num(const struct svc_rqst *rqstp, struct rsc *rsci, + u32 seq_num) +{ + struct gss_svc_seq_data *sd = &rsci->seqdata; + bool result = false; + + spin_lock(&sd->sd_lock); + if (seq_num > sd->sd_max) { + if (seq_num >= sd->sd_max + GSS_SEQ_WIN) { + memset(sd->sd_win, 0, sizeof(sd->sd_win)); + sd->sd_max = seq_num; + } else while (sd->sd_max < seq_num) { + sd->sd_max++; + __clear_bit(sd->sd_max % GSS_SEQ_WIN, sd->sd_win); + } + __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win); + goto ok; + } else if (seq_num + GSS_SEQ_WIN <= sd->sd_max) { + goto toolow; + } + if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) + goto alreadyseen; + +ok: + result = true; +out: + spin_unlock(&sd->sd_lock); + return result; + +toolow: + trace_rpcgss_svc_seqno_low(rqstp, seq_num, + sd->sd_max - GSS_SEQ_WIN, + sd->sd_max); + goto out; +alreadyseen: + trace_rpcgss_svc_seqno_seen(rqstp, seq_num); + goto out; +} + +static inline u32 round_up_to_quad(u32 i) +{ + return (i + 3 ) & ~3; +} + +static inline int +svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o) +{ + int l; + + if (argv->iov_len < 4) + return -1; + o->len = svc_getnl(argv); + l = round_up_to_quad(o->len); + if (argv->iov_len < l) + return -1; + o->data = argv->iov_base; + argv->iov_base += l; + argv->iov_len -= l; + return 0; +} + +static inline int +svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o) +{ + u8 *p; + + if (resv->iov_len + 4 > PAGE_SIZE) + return -1; + svc_putnl(resv, o->len); + p = resv->iov_base + resv->iov_len; + resv->iov_len += round_up_to_quad(o->len); + if (resv->iov_len > PAGE_SIZE) + return -1; + memcpy(p, o->data, o->len); + memset(p + o->len, 0, round_up_to_quad(o->len) - o->len); + return 0; +} + +/* + * Verify the checksum on the header and return SVC_OK on success. + * Otherwise, return SVC_DROP (in the case of a bad sequence number) + * or return SVC_DENIED and indicate error in rqstp->rq_auth_stat. + */ +static int +gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci, + __be32 *rpcstart, struct rpc_gss_wire_cred *gc) +{ + struct gss_ctx *ctx_id = rsci->mechctx; + struct xdr_buf rpchdr; + struct xdr_netobj checksum; + u32 flavor = 0; + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec iov; + + /* data to compute the checksum over: */ + iov.iov_base = rpcstart; + iov.iov_len = (u8 *)argv->iov_base - (u8 *)rpcstart; + xdr_buf_from_iov(&iov, &rpchdr); + + rqstp->rq_auth_stat = rpc_autherr_badverf; + if (argv->iov_len < 4) + return SVC_DENIED; + flavor = svc_getnl(argv); + if (flavor != RPC_AUTH_GSS) + return SVC_DENIED; + if (svc_safe_getnetobj(argv, &checksum)) + return SVC_DENIED; + + if (rqstp->rq_deferred) /* skip verification of revisited request */ + return SVC_OK; + if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) { + rqstp->rq_auth_stat = rpcsec_gsserr_credproblem; + return SVC_DENIED; + } + + if (gc->gc_seq > MAXSEQ) { + trace_rpcgss_svc_seqno_large(rqstp, gc->gc_seq); + rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem; + return SVC_DENIED; + } + if (!gss_check_seq_num(rqstp, rsci, gc->gc_seq)) + return SVC_DROP; + return SVC_OK; +} + +static int +gss_write_null_verf(struct svc_rqst *rqstp) +{ + __be32 *p; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_NULL); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + /* don't really need to check if head->iov_len > PAGE_SIZE ... 
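/*
 * Editorial sketch (not part of this patch): svc_safe_putnetobj() above
 * appends "length word + data + zero padding to a 4-byte boundary" and
 * refuses to run past the reply page.  The same bounded append over a
 * flat buffer in plain C; the limit field stands in for the page bound.
 */
#include <arpa/inet.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

struct outbuf { unsigned char *base; size_t used; size_t limit; };

static int put_opaque(struct outbuf *out, const void *data, uint32_t len)
{
	uint32_t padded = (len + 3) & ~3u;
	uint32_t belen = htonl(len);

	if (out->used + 4 + padded > out->limit)
		return -1;		/* would overflow the reply buffer */
	memcpy(out->base + out->used, &belen, 4);
	memcpy(out->base + out->used + 4, data, len);
	memset(out->base + out->used + 4 + len, 0, padded - len);
	out->used += 4 + padded;
	return 0;
}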
*/ + *p++ = 0; + if (!xdr_ressize_check(rqstp, p)) + return -1; + return 0; +} + +static int +gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq) +{ + __be32 *xdr_seq; + u32 maj_stat; + struct xdr_buf verf_data; + struct xdr_netobj mic; + __be32 *p; + struct kvec iov; + int err = -1; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_GSS); + xdr_seq = kmalloc(4, GFP_KERNEL); + if (!xdr_seq) + return -ENOMEM; + *xdr_seq = htonl(seq); + + iov.iov_base = xdr_seq; + iov.iov_len = 4; + xdr_buf_from_iov(&iov, &verf_data); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx_id, &verf_data, &mic); + if (maj_stat != GSS_S_COMPLETE) + goto out; + *p++ = htonl(mic.len); + memset((u8 *)p + mic.len, 0, round_up_to_quad(mic.len) - mic.len); + p += XDR_QUADLEN(mic.len); + if (!xdr_ressize_check(rqstp, p)) + goto out; + err = 0; +out: + kfree(xdr_seq); + return err; +} + +struct gss_domain { + struct auth_domain h; + u32 pseudoflavor; +}; + +static struct auth_domain * +find_gss_auth_domain(struct gss_ctx *ctx, u32 svc) +{ + char *name; + + name = gss_service_to_auth_domain_name(ctx->mech_type, svc); + if (!name) + return NULL; + return auth_domain_find(name); +} + +static struct auth_ops svcauthops_gss; + +u32 svcauth_gss_flavor(struct auth_domain *dom) +{ + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + return gd->pseudoflavor; +} + +EXPORT_SYMBOL_GPL(svcauth_gss_flavor); + +struct auth_domain * +svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name) +{ + struct gss_domain *new; + struct auth_domain *test; + int stat = -ENOMEM; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + goto out; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (!new->h.name) + goto out_free_dom; + new->h.flavour = &svcauthops_gss; + new->pseudoflavor = pseudoflavor; + + test = auth_domain_lookup(name, &new->h); + if (test != &new->h) { + pr_warn("svc: duplicate registration of gss pseudo flavour %s.\n", + name); + stat = -EADDRINUSE; + auth_domain_put(test); + goto out_free_name; + } + return test; + +out_free_name: + kfree(new->h.name); +out_free_dom: + kfree(new); +out: + return ERR_PTR(stat); +} +EXPORT_SYMBOL_GPL(svcauth_gss_register_pseudoflavor); + +static inline int +read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = ntohl(raw); + return 0; +} + +/* It would be nice if this bit of code could be shared with the client. + * Obstacles: + * The client shouldn't malloc(), would have to pass in own memory. + * The server uses base of head iovec as read pointer, while the + * client uses separate pointer. */ +static int +unwrap_integ_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + u32 integ_len, rseqno, maj_stat; + int stat = -EINVAL; + struct xdr_netobj mic; + struct xdr_buf integ_buf; + + mic.data = NULL; + + /* NFS READ normally uses splice to send data in-place. However + * the data in cache can change after the reply's MIC is computed + * but before the RPC reply is sent. To prevent the client from + * rejecting the server-computed MIC in this somewhat rare case, + * do not use splice with the GSS integrity service. + */ + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + + /* Did we already verify the signature on the original pass through? 
*/ + if (rqstp->rq_deferred) + return 0; + + integ_len = svc_getnl(&buf->head[0]); + if (integ_len & 3) + goto unwrap_failed; + if (integ_len > buf->len) + goto unwrap_failed; + if (xdr_buf_subsegment(buf, &integ_buf, 0, integ_len)) + goto unwrap_failed; + + /* copy out mic... */ + if (read_u32_from_xdr_buf(buf, integ_len, &mic.len)) + goto unwrap_failed; + if (mic.len > RPC_MAX_AUTH_SIZE) + goto unwrap_failed; + mic.data = kmalloc(mic.len, GFP_KERNEL); + if (!mic.data) + goto unwrap_failed; + if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len)) + goto unwrap_failed; + maj_stat = gss_verify_mic(ctx, &integ_buf, &mic); + if (maj_stat != GSS_S_COMPLETE) + goto bad_mic; + rseqno = svc_getnl(&buf->head[0]); + if (rseqno != seq) + goto bad_seqno; + /* trim off the mic and padding at the end before returning */ + xdr_buf_trim(buf, round_up_to_quad(mic.len) + 4); + stat = 0; +out: + kfree(mic.data); + return stat; + +unwrap_failed: + trace_rpcgss_svc_unwrap_failed(rqstp); + goto out; +bad_seqno: + trace_rpcgss_svc_seqno_bad(rqstp, seq, rseqno); + goto out; +bad_mic: + trace_rpcgss_svc_mic(rqstp, maj_stat); + goto out; +} + +static inline int +total_buf_len(struct xdr_buf *buf) +{ + return buf->head[0].iov_len + buf->page_len + buf->tail[0].iov_len; +} + +static void +fix_priv_head(struct xdr_buf *buf, int pad) +{ + if (buf->page_len == 0) { + /* We need to adjust head and buf->len in tandem in this + * case to make svc_defer() work--it finds the original + * buffer start using buf->len - buf->head[0].iov_len. */ + buf->head[0].iov_len -= pad; + } +} + +static int +unwrap_priv_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + u32 priv_len, maj_stat; + int pad, remaining_len, offset; + u32 rseqno; + + clear_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + + priv_len = svc_getnl(&buf->head[0]); + if (rqstp->rq_deferred) { + /* Already decrypted last time through! The sequence number + * check at out_seq is unnecessary but harmless: */ + goto out_seq; + } + /* buf->len is the number of bytes from the original start of the + * request to the end, where head[0].iov_len is just the bytes + * not yet read from the head, so these two values are different: */ + remaining_len = total_buf_len(buf); + if (priv_len > remaining_len) + goto unwrap_failed; + pad = remaining_len - priv_len; + buf->len -= pad; + fix_priv_head(buf, pad); + + maj_stat = gss_unwrap(ctx, 0, priv_len, buf); + pad = priv_len - buf->len; + /* The upper layers assume the buffer is aligned on 4-byte boundaries. + * In the krb5p case, at least, the data ends up offset, so we need to + * move it around. */ + /* XXX: This is very inefficient. It would be better to either do + * this while we encrypt, or maybe in the receive code, if we can peak + * ahead and work out the service and mechanism there. 
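/*
 * Editorial sketch (not part of this patch): for the integrity service the
 * request body handled above is laid out roughly as
 *     length | { seq_num | RPC args } | checksum
 * and unwrap_integ_data() verifies the MIC over the protected chunk and
 * then the embedded sequence number.  The skeleton below only shows that
 * order of checks over a flat buffer: mic() is a placeholder, not a real
 * GSS checksum, and the trailing checksum is shrunk to one fixed word
 * instead of the variable-length opaque the protocol actually carries.
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <string.h>

/* Placeholder MIC: NOT a real GSS checksum, just something comparable. */
static uint32_t mic(const unsigned char *p, uint32_t len)
{
	uint32_t h = 0;

	while (len--)
		h = h * 31 + *p++;
	return h;
}

static int check_integ_body(const unsigned char *buf, uint32_t buflen,
			    uint32_t expect_seq)
{
	uint32_t integ_len, seq, want, got;

	if (buflen < 4)
		return -1;
	memcpy(&integ_len, buf, 4);
	integ_len = ntohl(integ_len);
	if ((integ_len & 3) || integ_len < 4 || 4 + integ_len + 4 > buflen)
		return -1;
	/* 1. verify the checksum that trails the protected chunk */
	memcpy(&got, buf + 4 + integ_len, 4);
	want = mic(buf + 4, integ_len);
	if (ntohl(got) != want)
		return -1;
	/* 2. the first word of the protected chunk is the sequence number */
	memcpy(&seq, buf + 4, 4);
	if (ntohl(seq) != expect_seq)
		return -1;
	return 0;
}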
*/ + offset = xdr_pad_size(buf->head[0].iov_len); + if (offset) { + buf->buflen = RPCSVC_MAXPAYLOAD; + xdr_shift_buf(buf, offset); + fix_priv_head(buf, pad); + } + if (maj_stat != GSS_S_COMPLETE) + goto bad_unwrap; +out_seq: + rseqno = svc_getnl(&buf->head[0]); + if (rseqno != seq) + goto bad_seqno; + return 0; + +unwrap_failed: + trace_rpcgss_svc_unwrap_failed(rqstp); + return -EINVAL; +bad_seqno: + trace_rpcgss_svc_seqno_bad(rqstp, seq, rseqno); + return -EINVAL; +bad_unwrap: + trace_rpcgss_svc_unwrap(rqstp, maj_stat); + return -EINVAL; +} + +struct gss_svc_data { + /* decoded gss client cred: */ + struct rpc_gss_wire_cred clcred; + /* save a pointer to the beginning of the encoded verifier, + * for use in encryption/checksumming in svcauth_gss_release: */ + __be32 *verf_start; + struct rsc *rsci; +}; + +static int +svcauth_gss_set_client(struct svc_rqst *rqstp) +{ + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rsc *rsci = svcdata->rsci; + struct rpc_gss_wire_cred *gc = &svcdata->clcred; + int stat; + + rqstp->rq_auth_stat = rpc_autherr_badcred; + + /* + * A gss export can be specified either by: + * export *(sec=krb5,rw) + * or by + * export gss/krb5(rw) + * The latter is deprecated; but for backwards compatibility reasons + * the nfsd code will still fall back on trying it if the former + * doesn't work; so we try to make both available to nfsd, below. + */ + rqstp->rq_gssclient = find_gss_auth_domain(rsci->mechctx, gc->gc_svc); + if (rqstp->rq_gssclient == NULL) + return SVC_DENIED; + stat = svcauth_unix_set_client(rqstp); + if (stat == SVC_DROP || stat == SVC_CLOSE) + return stat; + + rqstp->rq_auth_stat = rpc_auth_ok; + return SVC_OK; +} + +static inline int +gss_write_init_verf(struct cache_detail *cd, struct svc_rqst *rqstp, + struct xdr_netobj *out_handle, int *major_status) +{ + struct rsc *rsci; + int rc; + + if (*major_status != GSS_S_COMPLETE) + return gss_write_null_verf(rqstp); + rsci = gss_svc_searchbyctx(cd, out_handle); + if (rsci == NULL) { + *major_status = GSS_S_NO_CONTEXT; + return gss_write_null_verf(rqstp); + } + rc = gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN); + cache_put(&rsci->h, cd); + return rc; +} + +static inline int +gss_read_common_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle) +{ + /* Read the verifier; should be NULL: */ + *authp = rpc_autherr_badverf; + if (argv->iov_len < 2 * 4) + return SVC_DENIED; + if (svc_getnl(argv) != RPC_AUTH_NULL) + return SVC_DENIED; + if (svc_getnl(argv) != 0) + return SVC_DENIED; + /* Martial context handle and token for upcall: */ + *authp = rpc_autherr_badcred; + if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) + return SVC_DENIED; + if (dup_netobj(in_handle, &gc->gc_ctx)) + return SVC_CLOSE; + *authp = rpc_autherr_badverf; + + return 0; +} + +static inline int +gss_read_verf(struct rpc_gss_wire_cred *gc, + struct kvec *argv, __be32 *authp, + struct xdr_netobj *in_handle, + struct xdr_netobj *in_token) +{ + struct xdr_netobj tmpobj; + int res; + + res = gss_read_common_verf(gc, argv, authp, in_handle); + if (res) + return res; + + if (svc_safe_getnetobj(argv, &tmpobj)) { + kfree(in_handle->data); + return SVC_DENIED; + } + if (dup_netobj(in_token, &tmpobj)) { + kfree(in_handle->data); + return SVC_CLOSE; + } + + return 0; +} + +static void gss_free_in_token_pages(struct gssp_in_token *in_token) +{ + u32 inlen; + int i; + + i = 0; + inlen = in_token->page_len; + while (inlen) { + if (in_token->pages[i]) + put_page(in_token->pages[i]); 
+ inlen -= inlen > PAGE_SIZE ? PAGE_SIZE : inlen; + } + + kfree(in_token->pages); + in_token->pages = NULL; +} + +static int gss_read_proxy_verf(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, + struct xdr_netobj *in_handle, + struct gssp_in_token *in_token) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + unsigned int length, pgto_offs, pgfrom_offs; + int pages, i, res, pgto, pgfrom; + size_t inlen, to_offs, from_offs; + + res = gss_read_common_verf(gc, argv, &rqstp->rq_auth_stat, in_handle); + if (res) + return res; + + inlen = svc_getnl(argv); + if (inlen > (argv->iov_len + rqstp->rq_arg.page_len)) { + kfree(in_handle->data); + return SVC_DENIED; + } + + pages = DIV_ROUND_UP(inlen, PAGE_SIZE); + in_token->pages = kcalloc(pages, sizeof(struct page *), GFP_KERNEL); + if (!in_token->pages) { + kfree(in_handle->data); + return SVC_DENIED; + } + in_token->page_base = 0; + in_token->page_len = inlen; + for (i = 0; i < pages; i++) { + in_token->pages[i] = alloc_page(GFP_KERNEL); + if (!in_token->pages[i]) { + kfree(in_handle->data); + gss_free_in_token_pages(in_token); + return SVC_DENIED; + } + } + + length = min_t(unsigned int, inlen, argv->iov_len); + memcpy(page_address(in_token->pages[0]), argv->iov_base, length); + inlen -= length; + + to_offs = length; + from_offs = rqstp->rq_arg.page_base; + while (inlen) { + pgto = to_offs >> PAGE_SHIFT; + pgfrom = from_offs >> PAGE_SHIFT; + pgto_offs = to_offs & ~PAGE_MASK; + pgfrom_offs = from_offs & ~PAGE_MASK; + + length = min_t(unsigned int, inlen, + min_t(unsigned int, PAGE_SIZE - pgto_offs, + PAGE_SIZE - pgfrom_offs)); + memcpy(page_address(in_token->pages[pgto]) + pgto_offs, + page_address(rqstp->rq_arg.pages[pgfrom]) + pgfrom_offs, + length); + + to_offs += length; + from_offs += length; + inlen -= length; + } + return 0; +} + +static inline int +gss_write_resv(struct kvec *resv, size_t size_limit, + struct xdr_netobj *out_handle, struct xdr_netobj *out_token, + int major_status, int minor_status) +{ + if (resv->iov_len + 4 > size_limit) + return -1; + svc_putnl(resv, RPC_SUCCESS); + if (svc_safe_putnetobj(resv, out_handle)) + return -1; + if (resv->iov_len + 3 * 4 > size_limit) + return -1; + svc_putnl(resv, major_status); + svc_putnl(resv, minor_status); + svc_putnl(resv, GSS_SEQ_WIN); + if (svc_safe_putnetobj(resv, out_token)) + return -1; + return 0; +} + +/* + * Having read the cred already and found we're in the context + * initiation case, read the verifier and initiate (or check the results + * of) upcalls to userspace for help with context initiation. If + * the upcall results are available, write the verifier and result. + * Otherwise, drop the request pending an answer to the upcall. 
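/*
 * Editorial sketch (not part of this patch): gss_read_proxy_verf() above
 * copies the client's token into its own page array one chunk at a time,
 * where each chunk is capped by whatever is left in the current source
 * page and in the current destination page.  The same loop over plain
 * user-space "pages"; PG_SIZE is an illustrative stand-in for PAGE_SIZE.
 */
#include <stddef.h>
#include <string.h>

#define PG_SIZE 4096u

static void copy_across_pages(unsigned char *const *dst_pages, size_t to_off,
			      const unsigned char *const *src_pages,
			      size_t from_off, size_t len)
{
	while (len) {
		size_t dst_room = PG_SIZE - (to_off % PG_SIZE);
		size_t src_room = PG_SIZE - (from_off % PG_SIZE);
		size_t chunk = len;

		if (chunk > dst_room)
			chunk = dst_room;
		if (chunk > src_room)
			chunk = src_room;
		memcpy(dst_pages[to_off / PG_SIZE] + to_off % PG_SIZE,
		       src_pages[from_off / PG_SIZE] + from_off % PG_SIZE,
		       chunk);
		to_off += chunk;
		from_off += chunk;
		len -= chunk;
	}
}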
+ */ +static int svcauth_gss_legacy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct rsi *rsip, rsikey; + int ret; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + memset(&rsikey, 0, sizeof(rsikey)); + ret = gss_read_verf(gc, argv, &rqstp->rq_auth_stat, + &rsikey.in_handle, &rsikey.in_token); + if (ret) + return ret; + + /* Perform upcall, or find upcall result: */ + rsip = rsi_lookup(sn->rsi_cache, &rsikey); + rsi_free(&rsikey); + if (!rsip) + return SVC_CLOSE; + if (cache_check(sn->rsi_cache, &rsip->h, &rqstp->rq_chandle) < 0) + /* No upcall result: */ + return SVC_CLOSE; + + ret = SVC_CLOSE; + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &rsip->out_handle, &rsip->major_status)) + goto out; + if (gss_write_resv(resv, PAGE_SIZE, + &rsip->out_handle, &rsip->out_token, + rsip->major_status, rsip->minor_status)) + goto out; + + ret = SVC_COMPLETE; +out: + cache_put(&rsip->h, sn->rsi_cache); + return ret; +} + +static int gss_proxy_save_rsc(struct cache_detail *cd, + struct gssp_upcall_data *ud, + uint64_t *handle) +{ + struct rsc rsci, *rscp = NULL; + static atomic64_t ctxhctr; + long long ctxh; + struct gss_api_mech *gm = NULL; + time64_t expiry; + int status; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + status = -ENOMEM; + /* the handle needs to be just a unique id, + * use a static counter */ + ctxh = atomic64_inc_return(&ctxhctr); + + /* make a copy for the caller */ + *handle = ctxh; + + /* make a copy for the rsc cache */ + if (dup_to_netobj(&rsci.handle, (char *)handle, sizeof(uint64_t))) + goto out; + rscp = rsc_lookup(cd, &rsci); + if (!rscp) + goto out; + + /* creds */ + if (!ud->found_creds) { + /* userspace seem buggy, we should always get at least a + * mapping to nobody */ + goto out; + } else { + struct timespec64 boot; + + /* steal creds */ + rsci.cred = ud->creds; + memset(&ud->creds, 0, sizeof(struct svc_cred)); + + status = -EOPNOTSUPP; + /* get mech handle from OID */ + gm = gss_mech_get_by_OID(&ud->mech_oid); + if (!gm) + goto out; + rsci.cred.cr_gss_mech = gm; + + status = -EINVAL; + /* mech-specific data: */ + status = gss_import_sec_context(ud->out_handle.data, + ud->out_handle.len, + gm, &rsci.mechctx, + &expiry, GFP_KERNEL); + if (status) + goto out; + + getboottime64(&boot); + expiry -= boot.tv_sec; + } + + rsci.h.expiry_time = expiry; + rscp = rsc_update(cd, &rsci, rscp); + status = 0; +out: + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, cd); + else + status = -ENOMEM; + return status; +} + +static int svcauth_gss_proxy_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc) +{ + struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_netobj cli_handle; + struct gssp_upcall_data ud; + uint64_t handle; + int status; + int ret; + struct net *net = SVC_NET(rqstp); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + memset(&ud, 0, sizeof(ud)); + ret = gss_read_proxy_verf(rqstp, gc, &ud.in_handle, &ud.in_token); + if (ret) + return ret; + + ret = SVC_CLOSE; + + /* Perform synchronous upcall to gss-proxy */ + status = gssp_accept_sec_context_upcall(net, &ud); + if (status) + goto out; + + trace_rpcgss_svc_accept_upcall(rqstp, ud.major_status, ud.minor_status); + + switch (ud.major_status) { + case GSS_S_CONTINUE_NEEDED: + cli_handle = ud.out_handle; + break; + case GSS_S_COMPLETE: + status = gss_proxy_save_rsc(sn->rsc_cache, &ud, &handle); + if 
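/*
 * Editorial sketch (not part of this patch): gss_proxy_save_rsc() above
 * only needs a context handle that is unique for the lifetime of the
 * cache, so it takes the next value of a static 64-bit atomic counter.
 * The same idea with C11 atomics:
 */
#include <stdatomic.h>
#include <stdint.h>

static _Atomic uint64_t ctx_handle_counter;

/* Returns a process-unique, monotonically increasing context handle. */
static uint64_t next_context_handle(void)
{
	return atomic_fetch_add(&ctx_handle_counter, 1) + 1;
}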
(status) + goto out; + cli_handle.data = (u8 *)&handle; + cli_handle.len = sizeof(handle); + break; + default: + goto out; + } + + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(sn->rsc_cache, rqstp, + &cli_handle, &ud.major_status)) + goto out; + if (gss_write_resv(resv, PAGE_SIZE, + &cli_handle, &ud.out_token, + ud.major_status, ud.minor_status)) + goto out; + + ret = SVC_COMPLETE; +out: + gss_free_in_token_pages(&ud.in_token); + gssp_free_upcall_data(&ud); + return ret; +} + +/* + * Try to set the sn->use_gss_proxy variable to a new value. We only allow + * it to be changed if it's currently undefined (-1). If it's any other value + * then return -EBUSY unless the type wouldn't have changed anyway. + */ +static int set_gss_proxy(struct net *net, int type) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int ret; + + WARN_ON_ONCE(type != 0 && type != 1); + ret = cmpxchg(&sn->use_gss_proxy, -1, type); + if (ret != -1 && ret != type) + return -EBUSY; + return 0; +} + +static bool use_gss_proxy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* If use_gss_proxy is still undefined, then try to disable it */ + if (sn->use_gss_proxy == -1) + set_gss_proxy(net, 0); + return sn->use_gss_proxy; +} + +#ifdef CONFIG_PROC_FS + +static ssize_t write_gssp(struct file *file, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = pde_data(file_inode(file)); + char tbuf[20]; + unsigned long i; + int res; + + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + + tbuf[count] = 0; + res = kstrtoul(tbuf, 0, &i); + if (res) + return res; + if (i != 1) + return -EINVAL; + res = set_gssp_clnt(net); + if (res) + return res; + res = set_gss_proxy(net, 1); + if (res) + return res; + return count; +} + +static ssize_t read_gssp(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct net *net = pde_data(file_inode(file)); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + unsigned long p = *ppos; + char tbuf[10]; + size_t len; + + snprintf(tbuf, sizeof(tbuf), "%d\n", sn->use_gss_proxy); + len = strlen(tbuf); + if (p >= len) + return 0; + len -= p; + if (len > count) + len = count; + if (copy_to_user(buf, (void *)(tbuf+p), len)) + return -EFAULT; + *ppos += len; + return len; +} + +static const struct proc_ops use_gss_proxy_proc_ops = { + .proc_open = nonseekable_open, + .proc_write = write_gssp, + .proc_read = read_gssp, +}; + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct proc_dir_entry **p = &sn->use_gssp_proc; + + sn->use_gss_proxy = -1; + *p = proc_create_data("use-gss-proxy", S_IFREG | 0600, + sn->proc_net_rpc, + &use_gss_proxy_proc_ops, net); + if (!*p) + return -ENOMEM; + init_gssp_clnt(sn); + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->use_gssp_proc) { + remove_proc_entry("use-gss-proxy", sn->proc_net_rpc); + clear_gssp_clnt(sn); + } +} +#else /* CONFIG_PROC_FS */ + +static int create_use_gss_proxy_proc_entry(struct net *net) +{ + return 0; +} + +static void destroy_use_gss_proxy_proc_entry(struct net *net) {} + +#endif /* CONFIG_PROC_FS */ + +/* + * Accept an rpcsec packet. 
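/*
 * Editorial sketch (not part of this patch): set_gss_proxy() above lets
 * the per-net flag be chosen exactly once -- the compare-and-swap only
 * succeeds while the value is still "undefined" (-1), and a later call is
 * an error unless it asks for the value that already won.  The same
 * one-shot latch with C11 atomics:
 */
#include <errno.h>
#include <stdatomic.h>

static _Atomic int use_proxy = -1;	/* -1 = not decided yet */

static int set_use_proxy(int type)	/* type is 0 or 1 */
{
	int expected = -1;

	if (atomic_compare_exchange_strong(&use_proxy, &expected, type))
		return 0;		/* we made the decision */
	return expected == type ? 0 : -EBUSY;	/* someone else already did */
}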
+ * If context establishment, punt to user space + * If data exchange, verify/decrypt + * If context destruction, handle here + * In the context establishment and destruction case we encode + * response here and return SVC_COMPLETE. + */ +static int +svcauth_gss_accept(struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + u32 crlen; + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + struct rsc *rsci = NULL; + __be32 *rpcstart; + __be32 *reject_stat = resv->iov_base + resv->iov_len; + int ret; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + rqstp->rq_auth_stat = rpc_autherr_badcred; + if (!svcdata) + svcdata = kmalloc(sizeof(*svcdata), GFP_KERNEL); + if (!svcdata) + goto auth_err; + rqstp->rq_auth_data = svcdata; + svcdata->verf_start = NULL; + svcdata->rsci = NULL; + gc = &svcdata->clcred; + + /* start of rpc packet is 7 u32's back from here: + * xid direction rpcversion prog vers proc flavour + */ + rpcstart = argv->iov_base; + rpcstart -= 7; + + /* credential is: + * version(==1), proc(0,1,2,3), seq, service (1,2,3), handle + * at least 5 u32s, and is preceded by length, so that makes 6. + */ + + if (argv->iov_len < 5 * 4) + goto auth_err; + crlen = svc_getnl(argv); + if (svc_getnl(argv) != RPC_GSS_VERSION) + goto auth_err; + gc->gc_proc = svc_getnl(argv); + gc->gc_seq = svc_getnl(argv); + gc->gc_svc = svc_getnl(argv); + if (svc_safe_getnetobj(argv, &gc->gc_ctx)) + goto auth_err; + if (crlen != round_up_to_quad(gc->gc_ctx.len) + 5 * 4) + goto auth_err; + + if ((gc->gc_proc != RPC_GSS_PROC_DATA) && (rqstp->rq_proc != 0)) + goto auth_err; + + rqstp->rq_auth_stat = rpc_autherr_badverf; + switch (gc->gc_proc) { + case RPC_GSS_PROC_INIT: + case RPC_GSS_PROC_CONTINUE_INIT: + if (use_gss_proxy(SVC_NET(rqstp))) + return svcauth_gss_proxy_init(rqstp, gc); + else + return svcauth_gss_legacy_init(rqstp, gc); + case RPC_GSS_PROC_DATA: + case RPC_GSS_PROC_DESTROY: + /* Look up the context, and check the verifier: */ + rqstp->rq_auth_stat = rpcsec_gsserr_credproblem; + rsci = gss_svc_searchbyctx(sn->rsc_cache, &gc->gc_ctx); + if (!rsci) + goto auth_err; + switch (gss_verify_header(rqstp, rsci, rpcstart, gc)) { + case SVC_OK: + break; + case SVC_DENIED: + goto auth_err; + case SVC_DROP: + goto drop; + } + break; + default: + rqstp->rq_auth_stat = rpc_autherr_rejectedcred; + goto auth_err; + } + + /* now act upon the command: */ + switch (gc->gc_proc) { + case RPC_GSS_PROC_DESTROY: + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + /* Delete the entry from the cache_list and call cache_put */ + sunrpc_cache_unhash(sn->rsc_cache, &rsci->h); + if (resv->iov_len + 4 > PAGE_SIZE) + goto drop; + svc_putnl(resv, RPC_SUCCESS); + goto complete; + case RPC_GSS_PROC_DATA: + rqstp->rq_auth_stat = rpcsec_gsserr_ctxproblem; + svcdata->verf_start = resv->iov_base + resv->iov_len; + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + rqstp->rq_cred = rsci->cred; + get_group_info(rsci->cred.cr_group_info); + rqstp->rq_auth_stat = rpc_autherr_badcred; + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + /* placeholders for length and seq. number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_integ_data(rqstp, &rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE; + break; + case RPC_GSS_SVC_PRIVACY: + /* placeholders for length and seq. 
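/*
 * Editorial sketch (not part of this patch): the accept path below pulls
 * the RPCSEC_GSS credential body apart as
 *     length, version, gss_proc, seq_num, service, handle<>
 * and cross-checks the advertised length against what it actually
 * consumed.  A flat-buffer version of that parse; the struct and constant
 * names are illustrative.
 */
#include <arpa/inet.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define SKETCH_RPC_GSS_VERSION 1

struct gss_cred_hdr {
	uint32_t gss_proc, seq, service;
	const uint8_t *handle;
	uint32_t handle_len;
};

static int parse_gss_cred(const uint8_t *p, size_t len, struct gss_cred_hdr *c)
{
	uint32_t w[5], i, padded;

	if (len < 6 * 4)
		return -1;
	for (i = 0; i < 5; i++) {
		memcpy(&w[i], p + 4 * i, 4);
		w[i] = ntohl(w[i]);	/* w[0]=length, w[1]=version, ... */
	}
	if (w[1] != SKETCH_RPC_GSS_VERSION)
		return -1;
	c->gss_proc = w[2];
	c->seq = w[3];
	c->service = w[4];
	memcpy(&c->handle_len, p + 5 * 4, 4);
	c->handle_len = ntohl(c->handle_len);
	padded = (c->handle_len + 3) & ~3u;
	/* the credential length word must cover exactly what follows it */
	if (w[0] != 5 * 4 + padded || 6 * 4 + padded > len)
		return -1;
	c->handle = p + 6 * 4;
	return 0;
}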
number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_priv_data(rqstp, &rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + rqstp->rq_auth_slack = RPC_MAX_AUTH_SIZE * 2; + break; + default: + goto auth_err; + } + svcdata->rsci = rsci; + cache_get(&rsci->h); + rqstp->rq_cred.cr_flavor = gss_svc_to_pseudoflavor( + rsci->mechctx->mech_type, + GSS_C_QOP_DEFAULT, + gc->gc_svc); + ret = SVC_OK; + trace_rpcgss_svc_authenticate(rqstp, gc); + goto out; + } +garbage_args: + ret = SVC_GARBAGE; + goto out; +auth_err: + /* Restore write pointer to its original value: */ + xdr_ressize_check(rqstp, reject_stat); + ret = SVC_DENIED; + goto out; +complete: + ret = SVC_COMPLETE; + goto out; +drop: + ret = SVC_CLOSE; +out: + if (rsci) + cache_put(&rsci->h, sn->rsc_cache); + return ret; +} + +static __be32 * +svcauth_gss_prepare_to_wrap(struct xdr_buf *resbuf, struct gss_svc_data *gsd) +{ + __be32 *p; + u32 verf_len; + + p = gsd->verf_start; + gsd->verf_start = NULL; + + /* If the reply stat is nonzero, don't wrap: */ + if (*(p-1) != rpc_success) + return NULL; + /* Skip the verifier: */ + p += 1; + verf_len = ntohl(*p++); + p += XDR_QUADLEN(verf_len); + /* move accept_stat to right place: */ + memcpy(p, p + 2, 4); + /* Also don't wrap if the accept stat is nonzero: */ + if (*p != rpc_success) { + resbuf->head[0].iov_len -= 2 * 4; + return NULL; + } + p++; + return p; +} + +static inline int +svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct xdr_buf integ_buf; + struct xdr_netobj mic; + struct kvec *resv; + __be32 *p; + int integ_offset, integ_len; + int stat = -EINVAL; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + goto out; + integ_offset = (u8 *)(p + 1) - (u8 *)resbuf->head[0].iov_base; + integ_len = resbuf->len - integ_offset; + if (integ_len & 3) + goto out; + *p++ = htonl(integ_len); + *p++ = htonl(gc->gc_seq); + if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, integ_len)) { + WARN_ON_ONCE(1); + goto out_err; + } + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE > PAGE_SIZE) + goto out_err; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len; + resbuf->tail[0].iov_len = 0; + } + resv = &resbuf->tail[0]; + mic.data = (u8 *)resv->iov_base + resv->iov_len + 4; + if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic)) + goto out_err; + svc_putnl(resv, mic.len); + memset(mic.data + mic.len, 0, + round_up_to_quad(mic.len) - mic.len); + resv->iov_len += XDR_QUADLEN(mic.len) << 2; + /* not strictly required: */ + resbuf->len += XDR_QUADLEN(mic.len) << 2; + if (resv->iov_len > PAGE_SIZE) + goto out_err; +out: + stat = 0; +out_err: + return stat; +} + +static inline int +svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct page **inpages = NULL; + __be32 *p, *len; + int offset; + int pad; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + return 0; + len = p++; + offset = (u8 *)p - (u8 *)resbuf->head[0].iov_base; + *p++ = htonl(gc->gc_seq); + inpages = resbuf->pages; + /* XXX: Would be better to write some xdr helper functions for + * nfs{2,3,4}xdr.c that place the data right, instead of copying: */ + + /* + * If there is 
currently tail data, make sure there is + * room for the head, tail, and 2 * RPC_MAX_AUTH_SIZE in + * the page, and move the current tail data such that + * there is RPC_MAX_AUTH_SIZE slack space available in + * both the head and tail. + */ + if (resbuf->tail[0].iov_base) { + if (resbuf->tail[0].iov_base >= + resbuf->head[0].iov_base + PAGE_SIZE) + return -EINVAL; + if (resbuf->tail[0].iov_base < resbuf->head[0].iov_base) + return -EINVAL; + if (resbuf->tail[0].iov_len + resbuf->head[0].iov_len + + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + memmove(resbuf->tail[0].iov_base + RPC_MAX_AUTH_SIZE, + resbuf->tail[0].iov_base, + resbuf->tail[0].iov_len); + resbuf->tail[0].iov_base += RPC_MAX_AUTH_SIZE; + } + /* + * If there is no current tail data, make sure there is + * room for the head data, and 2 * RPC_MAX_AUTH_SIZE in the + * allotted page, and set up tail information such that there + * is RPC_MAX_AUTH_SIZE slack space available in both the + * head and tail. + */ + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + 2*RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE; + resbuf->tail[0].iov_len = 0; + } + if (gss_wrap(gsd->rsci->mechctx, offset, resbuf, inpages)) + return -ENOMEM; + *len = htonl(resbuf->len - offset); + pad = 3 - ((resbuf->len - offset - 1)&3); + p = (__be32 *)(resbuf->tail[0].iov_base + resbuf->tail[0].iov_len); + memset(p, 0, pad); + resbuf->tail[0].iov_len += pad; + resbuf->len += pad; + return 0; +} + +static int +svcauth_gss_release(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + struct xdr_buf *resbuf = &rqstp->rq_res; + int stat = -EINVAL; + struct sunrpc_net *sn = net_generic(SVC_NET(rqstp), sunrpc_net_id); + + if (!gsd) + goto out; + gc = &gsd->clcred; + if (gc->gc_proc != RPC_GSS_PROC_DATA) + goto out; + /* Release can be called twice, but we only wrap once. */ + if (gsd->verf_start == NULL) + goto out; + /* normally not set till svc_send, but we need it here: */ + /* XXX: what for? Do we mess it up the moment we call svc_putu32 + * or whatever? 
*/ + resbuf->len = total_buf_len(resbuf); + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + stat = svcauth_gss_wrap_resp_integ(rqstp); + if (stat) + goto out_err; + break; + case RPC_GSS_SVC_PRIVACY: + stat = svcauth_gss_wrap_resp_priv(rqstp); + if (stat) + goto out_err; + break; + /* + * For any other gc_svc value, svcauth_gss_accept() already set + * the auth_error appropriately; just fall through: + */ + } + +out: + stat = 0; +out_err: + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_gssclient) + auth_domain_put(rqstp->rq_gssclient); + rqstp->rq_gssclient = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + if (gsd && gsd->rsci) { + cache_put(&gsd->rsci->h, sn->rsc_cache); + gsd->rsci = NULL; + } + return stat; +} + +static void +svcauth_gss_domain_release_rcu(struct rcu_head *head) +{ + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + kfree(dom->name); + kfree(gd); +} + +static void +svcauth_gss_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_gss_domain_release_rcu); +} + +static struct auth_ops svcauthops_gss = { + .name = "rpcsec_gss", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_GSS, + .accept = svcauth_gss_accept, + .release = svcauth_gss_release, + .domain_release = svcauth_gss_domain_release, + .set_client = svcauth_gss_set_client, +}; + +static int rsi_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsi_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsi_cache = cd; + return 0; +} + +static void rsi_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsi_cache; + + sn->rsi_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static int rsc_cache_create_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&rsc_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->rsc_cache = cd; + return 0; +} + +static void rsc_cache_destroy_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->rsc_cache; + + sn->rsc_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +int +gss_svc_init_net(struct net *net) +{ + int rv; + + rv = rsc_cache_create_net(net); + if (rv) + return rv; + rv = rsi_cache_create_net(net); + if (rv) + goto out1; + rv = create_use_gss_proxy_proc_entry(net); + if (rv) + goto out2; + return 0; +out2: + rsi_cache_destroy_net(net); +out1: + rsc_cache_destroy_net(net); + return rv; +} + +void +gss_svc_shutdown_net(struct net *net) +{ + destroy_use_gss_proxy_proc_entry(net); + rsi_cache_destroy_net(net); + rsc_cache_destroy_net(net); +} + +int +gss_svc_init(void) +{ + return svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss); +} + +void +gss_svc_shutdown(void) +{ + svc_auth_unregister(RPC_AUTH_GSS); +} diff --git 
a/net/sunrpc/auth_gss/trace.c b/net/sunrpc/auth_gss/trace.c new file mode 100644 index 000000000..76685abba --- /dev/null +++ b/net/sunrpc/auth_gss/trace.c @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2018, 2019 Oracle. All rights reserved. + */ + +#include +#include +#include +#include +#include +#include + +#define CREATE_TRACE_POINTS +#include diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c new file mode 100644 index 000000000..41a633a40 --- /dev/null +++ b/net/sunrpc/auth_null.c @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/auth_null.c + * + * AUTH_NULL authentication. Really :-) + * + * Copyright (C) 1996, Olaf Kirch + */ + +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth null_auth; +static struct rpc_cred null_cred; + +static struct rpc_auth * +nul_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + refcount_inc(&null_auth.au_count); + return &null_auth; +} + +static void +nul_destroy(struct rpc_auth *auth) +{ +} + +/* + * Lookup NULL creds for current process + */ +static struct rpc_cred * +nul_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return get_rpccred(&null_cred); +} + +/* + * Destroy cred handle. + */ +static void +nul_destroy_cred(struct rpc_cred *cred) +{ +} + +/* + * Match cred handle against current process + */ +static int +nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) +{ + return 1; +} + +/* + * Marshal credential. + */ +static int +nul_marshal(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + return -EMSGSIZE; + /* Credential */ + *p++ = rpc_auth_null; + *p++ = xdr_zero; + /* Verifier */ + *p++ = rpc_auth_null; + *p = xdr_zero; + return 0; +} + +/* + * Refresh credential. 
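/*
 * Editorial sketch (not part of this patch): nul_marshal() above emits the
 * entire AUTH_NULL credential and verifier as four XDR words -- flavor
 * AUTH_NULL (0) with a zero-length body, twice.  Built by hand that is
 * simply sixteen bytes:
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <string.h>

#define SKETCH_AUTH_NONE 0u	/* RFC 5531 flavor number for AUTH_NULL */

/* Fills buf[16] with { cred flavor, cred len, verf flavor, verf len }. */
static void build_auth_null(uint8_t buf[16])
{
	uint32_t words[4] = { htonl(SKETCH_AUTH_NONE), htonl(0),
			      htonl(SKETCH_AUTH_NONE), htonl(0) };

	memcpy(buf, words, sizeof(words));
}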
This is a no-op for AUTH_NULL + */ +static int +nul_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static int +nul_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + if (*p++ != rpc_auth_null) + return -EIO; + if (*p != xdr_zero) + return -EIO; + return 0; +} + +const struct rpc_authops authnull_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_NULL, + .au_name = "NULL", + .create = nul_create, + .destroy = nul_destroy, + .lookup_cred = nul_lookup_cred, +}; + +static +struct rpc_auth null_auth = { + .au_cslack = NUL_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ralign = NUL_REPLYSLACK, + .au_ops = &authnull_ops, + .au_flavor = RPC_AUTH_NULL, + .au_count = REFCOUNT_INIT(1), +}; + +static +const struct rpc_credops null_credops = { + .cr_name = "AUTH_NULL", + .crdestroy = nul_destroy_cred, + .crmatch = nul_match, + .crmarshal = nul_marshal, + .crwrap_req = rpcauth_wrap_req_encode, + .crrefresh = nul_refresh, + .crvalidate = nul_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, +}; + +static +struct rpc_cred null_cred = { + .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru), + .cr_auth = &null_auth, + .cr_ops = &null_credops, + .cr_count = REFCOUNT_INIT(2), + .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, +}; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c new file mode 100644 index 000000000..1e091d3fa --- /dev/null +++ b/net/sunrpc/auth_unix.c @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/auth_unix.c + * + * UNIX-style authentication; no AUTH_SHORT support + * + * Copyright (C) 1996, Olaf Kirch + */ + +#include +#include +#include +#include +#include +#include +#include +#include + + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth unix_auth; +static const struct rpc_credops unix_credops; +static mempool_t *unix_pool; + +static struct rpc_auth * +unx_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt) +{ + refcount_inc(&unix_auth.au_count); + return &unix_auth; +} + +static void +unx_destroy(struct rpc_auth *auth) +{ +} + +/* + * Lookup AUTH_UNIX creds for current process + */ +static struct rpc_cred *unx_lookup_cred(struct rpc_auth *auth, + struct auth_cred *acred, int flags) +{ + struct rpc_cred *ret; + + ret = kmalloc(sizeof(*ret), rpc_task_gfp_mask()); + if (!ret) { + if (!(flags & RPCAUTH_LOOKUP_ASYNC)) + return ERR_PTR(-ENOMEM); + ret = mempool_alloc(unix_pool, GFP_NOWAIT); + if (!ret) + return ERR_PTR(-ENOMEM); + } + rpcauth_init_cred(ret, acred, auth, &unix_credops); + ret->cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + return ret; +} + +static void +unx_free_cred_callback(struct rcu_head *head) +{ + struct rpc_cred *rpc_cred = container_of(head, struct rpc_cred, cr_rcu); + + put_cred(rpc_cred->cr_cred); + mempool_free(rpc_cred, unix_pool); +} + +static void +unx_destroy_cred(struct rpc_cred *cred) +{ + call_rcu(&cred->cr_rcu, unx_free_cred_callback); +} + +/* + * Match credentials against current the auth_cred. 
+ */ +static int +unx_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) +{ + unsigned int groups = 0; + unsigned int i; + + if (cred->cr_cred == acred->cred) + return 1; + + if (!uid_eq(cred->cr_cred->fsuid, acred->cred->fsuid) || !gid_eq(cred->cr_cred->fsgid, acred->cred->fsgid)) + return 0; + + if (acred->cred->group_info != NULL) + groups = acred->cred->group_info->ngroups; + if (groups > UNX_NGROUPS) + groups = UNX_NGROUPS; + if (cred->cr_cred->group_info == NULL) + return groups == 0; + if (groups != cred->cr_cred->group_info->ngroups) + return 0; + + for (i = 0; i < groups ; i++) + if (!gid_eq(cred->cr_cred->group_info->gid[i], acred->cred->group_info->gid[i])) + return 0; + return 1; +} + +/* + * Marshal credentials. + * Maybe we should keep a cached credential for performance reasons. + */ +static int +unx_marshal(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_cred *cred = task->tk_rqstp->rq_cred; + __be32 *p, *cred_len, *gidarr_len; + int i; + struct group_info *gi = cred->cr_cred->group_info; + struct user_namespace *userns = clnt->cl_cred ? + clnt->cl_cred->user_ns : &init_user_ns; + + /* Credential */ + + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_unix; + cred_len = p++; + *p++ = xdr_zero; /* stamp */ + if (xdr_stream_encode_opaque(xdr, clnt->cl_nodename, + clnt->cl_nodelen) < 0) + goto marshal_failed; + p = xdr_reserve_space(xdr, 3 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = cpu_to_be32(from_kuid_munged(userns, cred->cr_cred->fsuid)); + *p++ = cpu_to_be32(from_kgid_munged(userns, cred->cr_cred->fsgid)); + + gidarr_len = p++; + if (gi) + for (i = 0; i < UNX_NGROUPS && i < gi->ngroups; i++) + *p++ = cpu_to_be32(from_kgid_munged(userns, gi->gid[i])); + *gidarr_len = cpu_to_be32(p - gidarr_len - 1); + *cred_len = cpu_to_be32((p - cred_len - 1) << 2); + p = xdr_reserve_space(xdr, (p - gidarr_len - 1) << 2); + if (!p) + goto marshal_failed; + + /* Verifier */ + + p = xdr_reserve_space(xdr, 2 * sizeof(*p)); + if (!p) + goto marshal_failed; + *p++ = rpc_auth_null; + *p = xdr_zero; + + return 0; + +marshal_failed: + return -EMSGSIZE; +} + +/* + * Refresh credentials. This is a no-op for AUTH_UNIX + */ +static int +unx_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); + return 0; +} + +static int +unx_validate(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; + __be32 *p; + u32 size; + + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + return -EIO; + switch (*p++) { + case rpc_auth_null: + case rpc_auth_unix: + case rpc_auth_short: + break; + default: + return -EIO; + } + size = be32_to_cpup(p); + if (size > RPC_MAX_AUTH_SIZE) + return -EIO; + p = xdr_inline_decode(xdr, size); + if (!p) + return -EIO; + + auth->au_verfsize = XDR_QUADLEN(size) + 2; + auth->au_rslack = XDR_QUADLEN(size) + 2; + auth->au_ralign = XDR_QUADLEN(size) + 2; + return 0; +} + +int __init rpc_init_authunix(void) +{ + unix_pool = mempool_create_kmalloc_pool(16, sizeof(struct rpc_cred)); + return unix_pool ? 
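/*
 * Editorial sketch (not part of this patch): unx_marshal() above lays the
 * AUTH_UNIX (AUTH_SYS) credential body out as
 *     stamp, machinename<>, uid, gid, gids<16>
 * followed by a NULL verifier, back-filling the body length once the gid
 * count is known.  A user-space builder for the same body; the helper
 * names are illustrative and the caller is assumed to size buf generously.
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <string.h>

static unsigned char *put_word(unsigned char *p, uint32_t v)
{
	v = htonl(v);
	memcpy(p, &v, 4);
	return p + 4;
}

/* Returns the number of bytes written to buf. */
static size_t build_auth_sys(unsigned char *buf, const char *host,
			     uint32_t uid, uint32_t gid,
			     const uint32_t *gids, uint32_t ngids)
{
	size_t hlen = strlen(host), hpad = (hlen + 3) & ~(size_t)3;
	unsigned char *p = buf;
	uint32_t i;

	if (ngids > 16)
		ngids = 16;			/* AUTH_SYS carries at most 16 gids */
	p = put_word(p, 0);			/* stamp */
	p = put_word(p, (uint32_t)hlen);	/* machinename as an opaque */
	memcpy(p, host, hlen);
	memset(p + hlen, 0, hpad - hlen);
	p += hpad;
	p = put_word(p, uid);
	p = put_word(p, gid);
	p = put_word(p, ngids);			/* gid array count */
	for (i = 0; i < ngids; i++)
		p = put_word(p, gids[i]);
	return (size_t)(p - buf);
}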
0 : -ENOMEM; +} + +void rpc_destroy_authunix(void) +{ + mempool_destroy(unix_pool); +} + +const struct rpc_authops authunix_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_UNIX, + .au_name = "UNIX", + .create = unx_create, + .destroy = unx_destroy, + .lookup_cred = unx_lookup_cred, +}; + +static +struct rpc_auth unix_auth = { + .au_cslack = UNX_CALLSLACK, + .au_rslack = NUL_REPLYSLACK, + .au_verfsize = NUL_REPLYSLACK, + .au_ops = &authunix_ops, + .au_flavor = RPC_AUTH_UNIX, + .au_count = REFCOUNT_INIT(1), +}; + +static +const struct rpc_credops unix_credops = { + .cr_name = "AUTH_UNIX", + .crdestroy = unx_destroy_cred, + .crmatch = unx_match, + .crmarshal = unx_marshal, + .crwrap_req = rpcauth_wrap_req_encode, + .crrefresh = unx_refresh, + .crvalidate = unx_validate, + .crunwrap_resp = rpcauth_unwrap_resp_decode, +}; diff --git a/net/sunrpc/backchannel_rqst.c b/net/sunrpc/backchannel_rqst.c new file mode 100644 index 000000000..65a6c6429 --- /dev/null +++ b/net/sunrpc/backchannel_rqst.c @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: GPL-2.0-only +/****************************************************************************** + +(c) 2007 Network Appliance, Inc. All Rights Reserved. +(c) 2009 NetApp. All Rights Reserved. + + +******************************************************************************/ + +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +#define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +#define BC_MAX_SLOTS 64U + +unsigned int xprt_bc_max_slots(struct rpc_xprt *xprt) +{ + return BC_MAX_SLOTS; +} + +/* + * Helper routines that track the number of preallocation elements + * on the transport. + */ +static inline int xprt_need_to_requeue(struct rpc_xprt *xprt) +{ + return xprt->bc_alloc_count < xprt->bc_alloc_max; +} + +/* + * Free the preallocated rpc_rqst structure and the memory + * buffers hanging off of it. 
+ */ +static void xprt_free_allocation(struct rpc_rqst *req) +{ + struct xdr_buf *xbufp; + + dprintk("RPC: free allocations for req= %p\n", req); + WARN_ON_ONCE(test_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state)); + xbufp = &req->rq_rcv_buf; + free_page((unsigned long)xbufp->head[0].iov_base); + xbufp = &req->rq_snd_buf; + free_page((unsigned long)xbufp->head[0].iov_base); + kfree(req); +} + +static void xprt_bc_reinit_xdr_buf(struct xdr_buf *buf) +{ + buf->head[0].iov_len = PAGE_SIZE; + buf->tail[0].iov_len = 0; + buf->pages = NULL; + buf->page_len = 0; + buf->flags = 0; + buf->len = 0; + buf->buflen = PAGE_SIZE; +} + +static int xprt_alloc_xdr_buf(struct xdr_buf *buf, gfp_t gfp_flags) +{ + struct page *page; + /* Preallocate one XDR receive buffer */ + page = alloc_page(gfp_flags); + if (page == NULL) + return -ENOMEM; + xdr_buf_init(buf, page_address(page), PAGE_SIZE); + return 0; +} + +static struct rpc_rqst *xprt_alloc_bc_req(struct rpc_xprt *xprt) +{ + gfp_t gfp_flags = GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN; + struct rpc_rqst *req; + + /* Pre-allocate one backchannel rpc_rqst */ + req = kzalloc(sizeof(*req), gfp_flags); + if (req == NULL) + return NULL; + + req->rq_xprt = xprt; + INIT_LIST_HEAD(&req->rq_bc_list); + + /* Preallocate one XDR receive buffer */ + if (xprt_alloc_xdr_buf(&req->rq_rcv_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc receive xbuf\n"); + goto out_free; + } + req->rq_rcv_buf.len = PAGE_SIZE; + + /* Preallocate one XDR send buffer */ + if (xprt_alloc_xdr_buf(&req->rq_snd_buf, gfp_flags) < 0) { + printk(KERN_ERR "Failed to create bc snd xbuf\n"); + goto out_free; + } + return req; +out_free: + xprt_free_allocation(req); + return NULL; +} + +/* + * Preallocate up to min_reqs structures and related buffers for use + * by the backchannel. This function can be called multiple times + * when creating new sessions that use the same rpc_xprt. The + * preallocated buffers are added to the pool of resources used by + * the rpc_xprt. Any one of these resources may be used by an + * incoming callback request. It's up to the higher levels in the + * stack to enforce that the maximum number of session slots is not + * being exceeded. + * + * Some callback arguments can be large. For example, a pNFS server + * using multiple deviceids. The list can be unbound, but the client + * has the ability to tell the server the maximum size of the callback + * requests. Each deviceID is 16 bytes, so allocate one page + * for the arguments to have enough room to receive a number of these + * deviceIDs. The NFS client indicates to the pNFS server that its + * callback requests can be up to 4096 bytes in size. + */ +int xprt_setup_backchannel(struct rpc_xprt *xprt, unsigned int min_reqs) +{ + if (!xprt->ops->bc_setup) + return 0; + return xprt->ops->bc_setup(xprt, min_reqs); +} +EXPORT_SYMBOL_GPL(xprt_setup_backchannel); + +int xprt_setup_bc(struct rpc_xprt *xprt, unsigned int min_reqs) +{ + struct rpc_rqst *req; + struct list_head tmp_list; + int i; + + dprintk("RPC: setup backchannel transport\n"); + + if (min_reqs > BC_MAX_SLOTS) + min_reqs = BC_MAX_SLOTS; + + /* + * We use a temporary list to keep track of the preallocated + * buffers. Once we're done building the list we splice it + * into the backchannel preallocation list off of the rpc_xprt + * struct. This helps minimize the amount of time the list + * lock is held on the rpc_xprt struct. It also makes cleanup + * easier in case of memory allocation errors. 
+ */ + INIT_LIST_HEAD(&tmp_list); + for (i = 0; i < min_reqs; i++) { + /* Pre-allocate one backchannel rpc_rqst */ + req = xprt_alloc_bc_req(xprt); + if (req == NULL) { + printk(KERN_ERR "Failed to create bc rpc_rqst\n"); + goto out_free; + } + + /* Add the allocated buffer to the tmp list */ + dprintk("RPC: adding req= %p\n", req); + list_add(&req->rq_bc_pa_list, &tmp_list); + } + + /* + * Add the temporary list to the backchannel preallocation list + */ + spin_lock(&xprt->bc_pa_lock); + list_splice(&tmp_list, &xprt->bc_pa_list); + xprt->bc_alloc_count += min_reqs; + xprt->bc_alloc_max += min_reqs; + atomic_add(min_reqs, &xprt->bc_slot_count); + spin_unlock(&xprt->bc_pa_lock); + + dprintk("RPC: setup backchannel transport done\n"); + return 0; + +out_free: + /* + * Memory allocation failed, free the temporary list + */ + while (!list_empty(&tmp_list)) { + req = list_first_entry(&tmp_list, + struct rpc_rqst, + rq_bc_pa_list); + list_del(&req->rq_bc_pa_list); + xprt_free_allocation(req); + } + + dprintk("RPC: setup backchannel transport failed\n"); + return -ENOMEM; +} + +/** + * xprt_destroy_backchannel - Destroys the backchannel preallocated structures. + * @xprt: the transport holding the preallocated strucures + * @max_reqs: the maximum number of preallocated structures to destroy + * + * Since these structures may have been allocated by multiple calls + * to xprt_setup_backchannel, we only destroy up to the maximum number + * of reqs specified by the caller. + */ +void xprt_destroy_backchannel(struct rpc_xprt *xprt, unsigned int max_reqs) +{ + if (xprt->ops->bc_destroy) + xprt->ops->bc_destroy(xprt, max_reqs); +} +EXPORT_SYMBOL_GPL(xprt_destroy_backchannel); + +void xprt_destroy_bc(struct rpc_xprt *xprt, unsigned int max_reqs) +{ + struct rpc_rqst *req = NULL, *tmp = NULL; + + dprintk("RPC: destroy backchannel transport\n"); + + if (max_reqs == 0) + goto out; + + spin_lock_bh(&xprt->bc_pa_lock); + xprt->bc_alloc_max -= min(max_reqs, xprt->bc_alloc_max); + list_for_each_entry_safe(req, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { + dprintk("RPC: req=%p\n", req); + list_del(&req->rq_bc_pa_list); + xprt_free_allocation(req); + xprt->bc_alloc_count--; + atomic_dec(&xprt->bc_slot_count); + if (--max_reqs == 0) + break; + } + spin_unlock_bh(&xprt->bc_pa_lock); + +out: + dprintk("RPC: backchannel list empty= %s\n", + list_empty(&xprt->bc_pa_list) ? "true" : "false"); +} + +static struct rpc_rqst *xprt_get_bc_request(struct rpc_xprt *xprt, __be32 xid, + struct rpc_rqst *new) +{ + struct rpc_rqst *req = NULL; + + dprintk("RPC: allocate a backchannel request\n"); + if (list_empty(&xprt->bc_pa_list)) { + if (!new) + goto not_found; + if (atomic_read(&xprt->bc_slot_count) >= BC_MAX_SLOTS) + goto not_found; + list_add_tail(&new->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + atomic_inc(&xprt->bc_slot_count); + } + req = list_first_entry(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + req->rq_reply_bytes_recvd = 0; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + req->rq_xid = xid; + req->rq_connect_cookie = xprt->connect_cookie; + dprintk("RPC: backchannel req=%p\n", req); +not_found: + return req; +} + +/* + * Return the preallocated rpc_rqst structure and XDR buffers + * associated with this rpc_task. 
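As a rough user-space illustration of the pattern xprt_setup_bc() uses above, the sketch below preallocates a batch of objects on a private list with no lock held and then publishes the whole batch in one short critical section. The node type, pthread mutex and counter are stand-ins for rpc_rqst, bc_pa_lock and the bc_alloc counters, not the real kernel types.

#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

struct node {
	struct node *next;
	int id;
};

static struct node *shared_list;	/* protected by shared_lock */
static int shared_count;
static pthread_mutex_t shared_lock = PTHREAD_MUTEX_INITIALIZER;

static int preallocate_batch(int nr)
{
	struct node *tmp = NULL, *tail = NULL, *n;
	int i;

	/* Build the batch privately; allocation failures are easy to unwind. */
	for (i = 0; i < nr; i++) {
		n = malloc(sizeof(*n));
		if (!n)
			goto out_free;
		n->id = i;
		n->next = tmp;
		if (!tmp)
			tail = n;
		tmp = n;
	}
	if (!tmp)
		return 0;

	/* One short, O(1) critical section to publish the whole batch. */
	pthread_mutex_lock(&shared_lock);
	tail->next = shared_list;
	shared_list = tmp;
	shared_count += nr;
	pthread_mutex_unlock(&shared_lock);
	return 0;

out_free:
	while (tmp) {
		n = tmp;
		tmp = tmp->next;
		free(n);
	}
	return -1;
}

int main(void)
{
	if (preallocate_batch(8))
		return 1;
	printf("%d preallocated entries published\n", shared_count);
	return 0;
}

Keeping allocation, which can sleep or fail, outside the lock also makes the error path trivial: the half-built private list is simply freed.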
+ */ +void xprt_free_bc_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + xprt->ops->bc_free_rqst(req); +} + +void xprt_free_bc_rqst(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + dprintk("RPC: free backchannel req=%p\n", req); + + req->rq_connect_cookie = xprt->connect_cookie - 1; + smp_mb__before_atomic(); + clear_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + smp_mb__after_atomic(); + + /* + * Return it to the list of preallocations so that it + * may be reused by a new callback request. + */ + spin_lock_bh(&xprt->bc_pa_lock); + if (xprt_need_to_requeue(xprt)) { + xprt_bc_reinit_xdr_buf(&req->rq_snd_buf); + xprt_bc_reinit_xdr_buf(&req->rq_rcv_buf); + req->rq_rcv_buf.len = PAGE_SIZE; + list_add_tail(&req->rq_bc_pa_list, &xprt->bc_pa_list); + xprt->bc_alloc_count++; + atomic_inc(&xprt->bc_slot_count); + req = NULL; + } + spin_unlock_bh(&xprt->bc_pa_lock); + if (req != NULL) { + /* + * The last remaining session was destroyed while this + * entry was in use. Free the entry and don't attempt + * to add back to the list because there is no need to + * have anymore preallocated entries. + */ + dprintk("RPC: Last session removed req=%p\n", req); + xprt_free_allocation(req); + } + xprt_put(xprt); +} + +/* + * One or more rpc_rqst structure have been preallocated during the + * backchannel setup. Buffer space for the send and private XDR buffers + * has been preallocated as well. Use xprt_alloc_bc_request to allocate + * to this request. Use xprt_free_bc_request to return it. + * + * We know that we're called in soft interrupt context, grab the spin_lock + * since there is no need to grab the bottom half spin_lock. + * + * Return an available rpc_rqst, otherwise NULL if non are available. + */ +struct rpc_rqst *xprt_lookup_bc_request(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *req, *new = NULL; + + do { + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry(req, &xprt->bc_pa_list, rq_bc_pa_list) { + if (req->rq_connect_cookie != xprt->connect_cookie) + continue; + if (req->rq_xid == xid) + goto found; + } + req = xprt_get_bc_request(xprt, xid, new); +found: + spin_unlock(&xprt->bc_pa_lock); + if (new) { + if (req != new) + xprt_free_allocation(new); + break; + } else if (req) + break; + new = xprt_alloc_bc_req(xprt); + } while (new); + return req; +} + +/* + * Add callback request to callback list. The callback + * service sleeps on the sv_cb_waitq waiting for new + * requests. Wake it up after adding enqueing the + * request. + */ +void xprt_complete_bc_request(struct rpc_rqst *req, uint32_t copied) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct svc_serv *bc_serv = xprt->bc_serv; + + spin_lock(&xprt->bc_pa_lock); + list_del(&req->rq_bc_pa_list); + xprt->bc_alloc_count--; + spin_unlock(&xprt->bc_pa_lock); + + req->rq_private_buf.len = copied; + set_bit(RPC_BC_PA_IN_USE, &req->rq_bc_pa_state); + + dprintk("RPC: add callback request to list\n"); + xprt_get(xprt); + spin_lock(&bc_serv->sv_cb_lock); + list_add(&req->rq_bc_list, &bc_serv->sv_cb_list); + wake_up(&bc_serv->sv_cb_waitq); + spin_unlock(&bc_serv->sv_cb_lock); +} diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c new file mode 100644 index 000000000..f075a9fb5 --- /dev/null +++ b/net/sunrpc/cache.c @@ -0,0 +1,1918 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sunrpc/cache.c + * + * Generic code for various authentication-related caches + * used by sunrpc clients and servers. 
+ * + * Copyright (C) 2002 Neil Brown + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "netns.h" +#include "fail.h" + +#define RPCDBG_FACILITY RPCDBG_CACHE + +static bool cache_defer_req(struct cache_req *req, struct cache_head *item); +static void cache_revisit_request(struct cache_head *item); + +static void cache_init(struct cache_head *h, struct cache_detail *detail) +{ + time64_t now = seconds_since_boot(); + INIT_HLIST_NODE(&h->cache_list); + h->flags = 0; + kref_init(&h->ref); + h->expiry_time = now + CACHE_NEW_EXPIRY; + if (now <= detail->flush_time) + /* ensure it isn't already expired */ + now = detail->flush_time + 1; + h->last_refresh = now; +} + +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail); + +static struct cache_head *sunrpc_cache_find_rcu(struct cache_detail *detail, + struct cache_head *key, + int hash) +{ + struct hlist_head *head = &detail->hash_table[hash]; + struct cache_head *tmp; + + rcu_read_lock(); + hlist_for_each_entry_rcu(tmp, head, cache_list) { + if (!detail->match(tmp, key)) + continue; + if (test_bit(CACHE_VALID, &tmp->flags) && + cache_is_expired(detail, tmp)) + continue; + tmp = cache_get_rcu(tmp); + rcu_read_unlock(); + return tmp; + } + rcu_read_unlock(); + return NULL; +} + +static void sunrpc_begin_cache_remove_entry(struct cache_head *ch, + struct cache_detail *cd) +{ + /* Must be called under cd->hash_lock */ + hlist_del_init_rcu(&ch->cache_list); + set_bit(CACHE_CLEANED, &ch->flags); + cd->entries --; +} + +static void sunrpc_end_cache_remove_entry(struct cache_head *ch, + struct cache_detail *cd) +{ + cache_fresh_unlocked(ch, cd); + cache_put(ch, cd); +} + +static struct cache_head *sunrpc_cache_add_entry(struct cache_detail *detail, + struct cache_head *key, + int hash) +{ + struct cache_head *new, *tmp, *freeme = NULL; + struct hlist_head *head = &detail->hash_table[hash]; + + new = detail->alloc(); + if (!new) + return NULL; + /* must fully initialise 'new', else + * we might get lose if we need to + * cache_put it soon. 
+ */ + cache_init(new, detail); + detail->init(new, key); + + spin_lock(&detail->hash_lock); + + /* check if entry appeared while we slept */ + hlist_for_each_entry_rcu(tmp, head, cache_list, + lockdep_is_held(&detail->hash_lock)) { + if (!detail->match(tmp, key)) + continue; + if (test_bit(CACHE_VALID, &tmp->flags) && + cache_is_expired(detail, tmp)) { + sunrpc_begin_cache_remove_entry(tmp, detail); + trace_cache_entry_expired(detail, tmp); + freeme = tmp; + break; + } + cache_get(tmp); + spin_unlock(&detail->hash_lock); + cache_put(new, detail); + return tmp; + } + + hlist_add_head_rcu(&new->cache_list, head); + detail->entries++; + cache_get(new); + spin_unlock(&detail->hash_lock); + + if (freeme) + sunrpc_end_cache_remove_entry(freeme, detail); + return new; +} + +struct cache_head *sunrpc_cache_lookup_rcu(struct cache_detail *detail, + struct cache_head *key, int hash) +{ + struct cache_head *ret; + + ret = sunrpc_cache_find_rcu(detail, key, hash); + if (ret) + return ret; + /* Didn't find anything, insert an empty entry */ + return sunrpc_cache_add_entry(detail, key, hash); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_lookup_rcu); + +static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch); + +static void cache_fresh_locked(struct cache_head *head, time64_t expiry, + struct cache_detail *detail) +{ + time64_t now = seconds_since_boot(); + if (now <= detail->flush_time) + /* ensure it isn't immediately treated as expired */ + now = detail->flush_time + 1; + head->expiry_time = expiry; + head->last_refresh = now; + smp_wmb(); /* paired with smp_rmb() in cache_is_valid() */ + set_bit(CACHE_VALID, &head->flags); +} + +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail) +{ + if (test_and_clear_bit(CACHE_PENDING, &head->flags)) { + cache_revisit_request(head); + cache_dequeue(detail, head); + } +} + +static void cache_make_negative(struct cache_detail *detail, + struct cache_head *h) +{ + set_bit(CACHE_NEGATIVE, &h->flags); + trace_cache_entry_make_negative(detail, h); +} + +static void cache_entry_update(struct cache_detail *detail, + struct cache_head *h, + struct cache_head *new) +{ + if (!test_bit(CACHE_NEGATIVE, &new->flags)) { + detail->update(h, new); + trace_cache_entry_update(detail, h); + } else { + cache_make_negative(detail, h); + } +} + +struct cache_head *sunrpc_cache_update(struct cache_detail *detail, + struct cache_head *new, struct cache_head *old, int hash) +{ + /* The 'old' entry is to be replaced by 'new'. 
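The race handling in sunrpc_cache_add_entry() above follows a common shape: allocate and fully initialise a candidate with no lock held, then re-check under the lock whether another thread inserted the same key in the meantime and discard the candidate if it did. A minimal user-space sketch of that shape, using malloc, a pthread mutex and an integer key purely as stand-ins:

#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>

struct entry {
	struct entry *next;
	int key;
};

static struct entry *table;		/* protected by table_lock */
static pthread_mutex_t table_lock = PTHREAD_MUTEX_INITIALIZER;

static struct entry *lookup_or_insert(int key)
{
	struct entry *new, *e;

	new = malloc(sizeof(*new));
	if (!new)
		return NULL;
	new->key = key;			/* fully initialise before publishing */

	pthread_mutex_lock(&table_lock);
	for (e = table; e; e = e->next) {
		if (e->key != key)
			continue;
		/* Lost the race: somebody added it while we allocated. */
		pthread_mutex_unlock(&table_lock);
		free(new);
		return e;
	}
	new->next = table;
	table = new;
	pthread_mutex_unlock(&table_lock);
	return new;
}

int main(void)
{
	struct entry *a = lookup_or_insert(42);
	struct entry *b = lookup_or_insert(42);

	printf("same entry returned: %s\n", a == b ? "yes" : "no");
	return 0;
}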
+ * If 'old' is not VALID, we update it directly, + * otherwise we need to replace it + */ + struct cache_head *tmp; + + if (!test_bit(CACHE_VALID, &old->flags)) { + spin_lock(&detail->hash_lock); + if (!test_bit(CACHE_VALID, &old->flags)) { + cache_entry_update(detail, old, new); + cache_fresh_locked(old, new->expiry_time, detail); + spin_unlock(&detail->hash_lock); + cache_fresh_unlocked(old, detail); + return old; + } + spin_unlock(&detail->hash_lock); + } + /* We need to insert a new entry */ + tmp = detail->alloc(); + if (!tmp) { + cache_put(old, detail); + return NULL; + } + cache_init(tmp, detail); + detail->init(tmp, old); + + spin_lock(&detail->hash_lock); + cache_entry_update(detail, tmp, new); + hlist_add_head(&tmp->cache_list, &detail->hash_table[hash]); + detail->entries++; + cache_get(tmp); + cache_fresh_locked(tmp, new->expiry_time, detail); + cache_fresh_locked(old, 0, detail); + spin_unlock(&detail->hash_lock); + cache_fresh_unlocked(tmp, detail); + cache_fresh_unlocked(old, detail); + cache_put(old, detail); + return tmp; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_update); + +static inline int cache_is_valid(struct cache_head *h) +{ + if (!test_bit(CACHE_VALID, &h->flags)) + return -EAGAIN; + else { + /* entry is valid */ + if (test_bit(CACHE_NEGATIVE, &h->flags)) + return -ENOENT; + else { + /* + * In combination with write barrier in + * sunrpc_cache_update, ensures that anyone + * using the cache entry after this sees the + * updated contents: + */ + smp_rmb(); + return 0; + } + } +} + +static int try_to_negate_entry(struct cache_detail *detail, struct cache_head *h) +{ + int rv; + + spin_lock(&detail->hash_lock); + rv = cache_is_valid(h); + if (rv == -EAGAIN) { + cache_make_negative(detail, h); + cache_fresh_locked(h, seconds_since_boot()+CACHE_NEW_EXPIRY, + detail); + rv = -ENOENT; + } + spin_unlock(&detail->hash_lock); + cache_fresh_unlocked(h, detail); + return rv; +} + +/* + * This is the generic cache management routine for all + * the authentication caches. + * It checks the currency of a cache item and will (later) + * initiate an upcall to fill it if needed. + * + * + * Returns 0 if the cache_head can be used, or cache_puts it and returns + * -EAGAIN if upcall is pending and request has been queued + * -ETIMEDOUT if upcall failed or request could not be queue or + * upcall completed but item is still invalid (implying that + * the cache item has been replaced with a newer one). 
+ * -ENOENT if cache entry was negative + */ +int cache_check(struct cache_detail *detail, + struct cache_head *h, struct cache_req *rqstp) +{ + int rv; + time64_t refresh_age, age; + + /* First decide return status as best we can */ + rv = cache_is_valid(h); + + /* now see if we want to start an upcall */ + refresh_age = (h->expiry_time - h->last_refresh); + age = seconds_since_boot() - h->last_refresh; + + if (rqstp == NULL) { + if (rv == -EAGAIN) + rv = -ENOENT; + } else if (rv == -EAGAIN || + (h->expiry_time != 0 && age > refresh_age/2)) { + dprintk("RPC: Want update, refage=%lld, age=%lld\n", + refresh_age, age); + switch (detail->cache_upcall(detail, h)) { + case -EINVAL: + rv = try_to_negate_entry(detail, h); + break; + case -EAGAIN: + cache_fresh_unlocked(h, detail); + break; + } + } + + if (rv == -EAGAIN) { + if (!cache_defer_req(rqstp, h)) { + /* + * Request was not deferred; handle it as best + * we can ourselves: + */ + rv = cache_is_valid(h); + if (rv == -EAGAIN) + rv = -ETIMEDOUT; + } + } + if (rv) + cache_put(h, detail); + return rv; +} +EXPORT_SYMBOL_GPL(cache_check); + +/* + * caches need to be periodically cleaned. + * For this we maintain a list of cache_detail and + * a current pointer into that list and into the table + * for that entry. + * + * Each time cache_clean is called it finds the next non-empty entry + * in the current table and walks the list in that entry + * looking for entries that can be removed. + * + * An entry gets removed if: + * - The expiry is before current time + * - The last_refresh time is before the flush_time for that cache + * + * later we might drop old entries with non-NEVER expiry if that table + * is getting 'full' for some definition of 'full' + * + * The question of "how often to scan a table" is an interesting one + * and is answered in part by the use of the "nextcheck" field in the + * cache_detail. + * When a scan of a table begins, the nextcheck field is set to a time + * that is well into the future. + * While scanning, if an expiry time is found that is earlier than the + * current nextcheck time, nextcheck is set to that expiry time. + * If the flush_time is ever set to a time earlier than the nextcheck + * time, the nextcheck time is then set to that flush_time. + * + * A table is then only scanned if the current time is at least + * the nextcheck time. 
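The upcall heuristic in cache_check() above starts a refresh not only when the entry is invalid but also once more than half of its lifetime (expiry_time - last_refresh) has passed. A small stand-alone sketch of just that arithmetic, with made-up timestamps:

#include <stdbool.h>
#include <stdio.h>

/* Mirrors the "age > refresh_age / 2" test in cache_check() above. */
static bool want_refresh(long long now, long long last_refresh,
			 long long expiry_time, bool valid)
{
	long long refresh_age = expiry_time - last_refresh;
	long long age = now - last_refresh;

	if (!valid)
		return true;
	return expiry_time != 0 && age > refresh_age / 2;
}

int main(void)
{
	/* Entry refreshed at t=100, expires at t=160: refresh_age = 60. */
	printf("%d\n", want_refresh(120, 100, 160, true));	/* 0: age 20 <= 30 */
	printf("%d\n", want_refresh(140, 100, 160, true));	/* 1: age 40 >  30 */
	return 0;
}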
+ * + */ + +static LIST_HEAD(cache_list); +static DEFINE_SPINLOCK(cache_list_lock); +static struct cache_detail *current_detail; +static int current_index; + +static void do_cache_clean(struct work_struct *work); +static struct delayed_work cache_cleaner; + +void sunrpc_init_cache_detail(struct cache_detail *cd) +{ + spin_lock_init(&cd->hash_lock); + INIT_LIST_HEAD(&cd->queue); + spin_lock(&cache_list_lock); + cd->nextcheck = 0; + cd->entries = 0; + atomic_set(&cd->writers, 0); + cd->last_close = 0; + cd->last_warn = -1; + list_add(&cd->others, &cache_list); + spin_unlock(&cache_list_lock); + + /* start the cleaning process */ + queue_delayed_work(system_power_efficient_wq, &cache_cleaner, 0); +} +EXPORT_SYMBOL_GPL(sunrpc_init_cache_detail); + +void sunrpc_destroy_cache_detail(struct cache_detail *cd) +{ + cache_purge(cd); + spin_lock(&cache_list_lock); + spin_lock(&cd->hash_lock); + if (current_detail == cd) + current_detail = NULL; + list_del_init(&cd->others); + spin_unlock(&cd->hash_lock); + spin_unlock(&cache_list_lock); + if (list_empty(&cache_list)) { + /* module must be being unloaded so its safe to kill the worker */ + cancel_delayed_work_sync(&cache_cleaner); + } +} +EXPORT_SYMBOL_GPL(sunrpc_destroy_cache_detail); + +/* clean cache tries to find something to clean + * and cleans it. + * It returns 1 if it cleaned something, + * 0 if it didn't find anything this time + * -1 if it fell off the end of the list. + */ +static int cache_clean(void) +{ + int rv = 0; + struct list_head *next; + + spin_lock(&cache_list_lock); + + /* find a suitable table if we don't already have one */ + while (current_detail == NULL || + current_index >= current_detail->hash_size) { + if (current_detail) + next = current_detail->others.next; + else + next = cache_list.next; + if (next == &cache_list) { + current_detail = NULL; + spin_unlock(&cache_list_lock); + return -1; + } + current_detail = list_entry(next, struct cache_detail, others); + if (current_detail->nextcheck > seconds_since_boot()) + current_index = current_detail->hash_size; + else { + current_index = 0; + current_detail->nextcheck = seconds_since_boot()+30*60; + } + } + + /* find a non-empty bucket in the table */ + while (current_detail && + current_index < current_detail->hash_size && + hlist_empty(¤t_detail->hash_table[current_index])) + current_index++; + + /* find a cleanable entry in the bucket and clean it, or set to next bucket */ + + if (current_detail && current_index < current_detail->hash_size) { + struct cache_head *ch = NULL; + struct cache_detail *d; + struct hlist_head *head; + struct hlist_node *tmp; + + spin_lock(¤t_detail->hash_lock); + + /* Ok, now to clean this strand */ + + head = ¤t_detail->hash_table[current_index]; + hlist_for_each_entry_safe(ch, tmp, head, cache_list) { + if (current_detail->nextcheck > ch->expiry_time) + current_detail->nextcheck = ch->expiry_time+1; + if (!cache_is_expired(current_detail, ch)) + continue; + + sunrpc_begin_cache_remove_entry(ch, current_detail); + trace_cache_entry_expired(current_detail, ch); + rv = 1; + break; + } + + spin_unlock(¤t_detail->hash_lock); + d = current_detail; + if (!ch) + current_index ++; + spin_unlock(&cache_list_lock); + if (ch) + sunrpc_end_cache_remove_entry(ch, d); + } else + spin_unlock(&cache_list_lock); + + return rv; +} + +/* + * We want to regularly clean the cache, so we need to schedule some work ... 
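cache_clean() above is an incremental scanner: each call cleans at most one entry and reports 1 (cleaned something), 0 (nothing this pass) or -1 (walked off the end of the list), which lets the periodic worker and the flush-everything path that follow drive it differently. The toy sketch below mimics that contract with a plain array instead of the real hash tables; scan_step(), expired[] and flush_all() are invented names.

#include <stdio.h>

#define NR_ITEMS 8

static int expired[NR_ITEMS] = { 0, 1, 0, 0, 1, 1, 0, 0 };
static int cursor;

static int scan_step(void)
{
	if (cursor >= NR_ITEMS) {
		cursor = 0;		/* restart on the next round */
		return -1;		/* fell off the end */
	}
	if (expired[cursor]) {
		expired[cursor++] = 0;
		return 1;		/* cleaned one entry */
	}
	cursor++;
	return 0;			/* nothing cleanable here */
}

/* cache_flush()-style: keep stepping until a full pass finds nothing. */
static void flush_all(void)
{
	while (scan_step() != -1)
		;
	while (scan_step() != -1)
		;
}

int main(void)
{
	int i, left = 0;

	flush_all();
	for (i = 0; i < NR_ITEMS; i++)
		left += expired[i];
	printf("expired entries left after flush: %d\n", left);
	return 0;
}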
+ */ +static void do_cache_clean(struct work_struct *work) +{ + int delay; + + if (list_empty(&cache_list)) + return; + + if (cache_clean() == -1) + delay = round_jiffies_relative(30*HZ); + else + delay = 5; + + queue_delayed_work(system_power_efficient_wq, &cache_cleaner, delay); +} + + +/* + * Clean all caches promptly. This just calls cache_clean + * repeatedly until we are sure that every cache has had a chance to + * be fully cleaned + */ +void cache_flush(void) +{ + while (cache_clean() != -1) + cond_resched(); + while (cache_clean() != -1) + cond_resched(); +} +EXPORT_SYMBOL_GPL(cache_flush); + +void cache_purge(struct cache_detail *detail) +{ + struct cache_head *ch = NULL; + struct hlist_head *head = NULL; + int i = 0; + + spin_lock(&detail->hash_lock); + if (!detail->entries) { + spin_unlock(&detail->hash_lock); + return; + } + + dprintk("RPC: %d entries in %s cache\n", detail->entries, detail->name); + for (i = 0; i < detail->hash_size; i++) { + head = &detail->hash_table[i]; + while (!hlist_empty(head)) { + ch = hlist_entry(head->first, struct cache_head, + cache_list); + sunrpc_begin_cache_remove_entry(ch, detail); + spin_unlock(&detail->hash_lock); + sunrpc_end_cache_remove_entry(ch, detail); + spin_lock(&detail->hash_lock); + } + } + spin_unlock(&detail->hash_lock); +} +EXPORT_SYMBOL_GPL(cache_purge); + + +/* + * Deferral and Revisiting of Requests. + * + * If a cache lookup finds a pending entry, we + * need to defer the request and revisit it later. + * All deferred requests are stored in a hash table, + * indexed by "struct cache_head *". + * As it may be wasteful to store a whole request + * structure, we allow the request to provide a + * deferred form, which must contain a + * 'struct cache_deferred_req' + * This cache_deferred_req contains a method to allow + * it to be revisited when cache info is available + */ + +#define DFR_HASHSIZE (PAGE_SIZE/sizeof(struct list_head)) +#define DFR_HASH(item) ((((long)item)>>4 ^ (((long)item)>>13)) % DFR_HASHSIZE) + +#define DFR_MAX 300 /* ??? 
*/ + +static DEFINE_SPINLOCK(cache_defer_lock); +static LIST_HEAD(cache_defer_list); +static struct hlist_head cache_defer_hash[DFR_HASHSIZE]; +static int cache_defer_cnt; + +static void __unhash_deferred_req(struct cache_deferred_req *dreq) +{ + hlist_del_init(&dreq->hash); + if (!list_empty(&dreq->recent)) { + list_del_init(&dreq->recent); + cache_defer_cnt--; + } +} + +static void __hash_deferred_req(struct cache_deferred_req *dreq, struct cache_head *item) +{ + int hash = DFR_HASH(item); + + INIT_LIST_HEAD(&dreq->recent); + hlist_add_head(&dreq->hash, &cache_defer_hash[hash]); +} + +static void setup_deferral(struct cache_deferred_req *dreq, + struct cache_head *item, + int count_me) +{ + + dreq->item = item; + + spin_lock(&cache_defer_lock); + + __hash_deferred_req(dreq, item); + + if (count_me) { + cache_defer_cnt++; + list_add(&dreq->recent, &cache_defer_list); + } + + spin_unlock(&cache_defer_lock); + +} + +struct thread_deferred_req { + struct cache_deferred_req handle; + struct completion completion; +}; + +static void cache_restart_thread(struct cache_deferred_req *dreq, int too_many) +{ + struct thread_deferred_req *dr = + container_of(dreq, struct thread_deferred_req, handle); + complete(&dr->completion); +} + +static void cache_wait_req(struct cache_req *req, struct cache_head *item) +{ + struct thread_deferred_req sleeper; + struct cache_deferred_req *dreq = &sleeper.handle; + + sleeper.completion = COMPLETION_INITIALIZER_ONSTACK(sleeper.completion); + dreq->revisit = cache_restart_thread; + + setup_deferral(dreq, item, 0); + + if (!test_bit(CACHE_PENDING, &item->flags) || + wait_for_completion_interruptible_timeout( + &sleeper.completion, req->thread_wait) <= 0) { + /* The completion wasn't completed, so we need + * to clean up + */ + spin_lock(&cache_defer_lock); + if (!hlist_unhashed(&sleeper.handle.hash)) { + __unhash_deferred_req(&sleeper.handle); + spin_unlock(&cache_defer_lock); + } else { + /* cache_revisit_request already removed + * this from the hash table, but hasn't + * called ->revisit yet. It will very soon + * and we need to wait for it. + */ + spin_unlock(&cache_defer_lock); + wait_for_completion(&sleeper.completion); + } + } +} + +static void cache_limit_defers(void) +{ + /* Make sure we haven't exceed the limit of allowed deferred + * requests. + */ + struct cache_deferred_req *discard = NULL; + + if (cache_defer_cnt <= DFR_MAX) + return; + + spin_lock(&cache_defer_lock); + + /* Consider removing either the first or the last */ + if (cache_defer_cnt > DFR_MAX) { + if (prandom_u32_max(2)) + discard = list_entry(cache_defer_list.next, + struct cache_deferred_req, recent); + else + discard = list_entry(cache_defer_list.prev, + struct cache_deferred_req, recent); + __unhash_deferred_req(discard); + } + spin_unlock(&cache_defer_lock); + if (discard) + discard->revisit(discard, 1); +} + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +static inline bool cache_defer_immediately(void) +{ + return !fail_sunrpc.ignore_cache_wait && + should_fail(&fail_sunrpc.attr, 1); +} +#else +static inline bool cache_defer_immediately(void) +{ + return false; +} +#endif + +/* Return true if and only if a deferred request is queued. 
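cache_wait_req() and cache_restart_thread() above rely on embedding a small cache_deferred_req handle inside a larger per-thread structure and recovering the outer structure in the callback with container_of(). A user-space sketch of that idiom, with simplified stand-in types (deferred_req, thread_waiter) and an int flag in place of the completion:

#include <stddef.h>
#include <stdio.h>

#define container_of(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

struct deferred_req {
	void (*revisit)(struct deferred_req *dreq, int too_many);
};

struct thread_waiter {
	struct deferred_req handle;	/* what the deferral code stores */
	int done;			/* stands in for the completion */
};

static void waiter_revisit(struct deferred_req *dreq, int too_many)
{
	struct thread_waiter *w =
		container_of(dreq, struct thread_waiter, handle);

	w->done = 1;			/* complete(&sleeper.completion) */
	(void)too_many;
}

int main(void)
{
	struct thread_waiter sleeper = { .handle.revisit = waiter_revisit };

	/* The cache would call this when the entry becomes usable. */
	sleeper.handle.revisit(&sleeper.handle, 0);
	printf("revisited: done=%d\n", sleeper.done);
	return 0;
}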
*/ +static bool cache_defer_req(struct cache_req *req, struct cache_head *item) +{ + struct cache_deferred_req *dreq; + + if (!cache_defer_immediately()) { + cache_wait_req(req, item); + if (!test_bit(CACHE_PENDING, &item->flags)) + return false; + } + + dreq = req->defer(req); + if (dreq == NULL) + return false; + setup_deferral(dreq, item, 1); + if (!test_bit(CACHE_PENDING, &item->flags)) + /* Bit could have been cleared before we managed to + * set up the deferral, so need to revisit just in case + */ + cache_revisit_request(item); + + cache_limit_defers(); + return true; +} + +static void cache_revisit_request(struct cache_head *item) +{ + struct cache_deferred_req *dreq; + struct list_head pending; + struct hlist_node *tmp; + int hash = DFR_HASH(item); + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + hlist_for_each_entry_safe(dreq, tmp, &cache_defer_hash[hash], hash) + if (dreq->item == item) { + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); + } + + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 0); + } +} + +void cache_clean_deferred(void *owner) +{ + struct cache_deferred_req *dreq, *tmp; + struct list_head pending; + + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { + if (dreq->owner == owner) { + __unhash_deferred_req(dreq); + list_add(&dreq->recent, &pending); + } + } + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 1); + } +} + +/* + * communicate with user-space + * + * We have a magic /proc file - /proc/net/rpc//channel. + * On read, you get a full request, or block. + * On write, an update request is processed. + * Poll works if anything to read, and always allows write. + * + * Implemented by linked list of requests. Each open file has + * a ->private that also exists in this list. New requests are added + * to the end and may wakeup and preceding readers. + * New readers are added to the head. If, on read, an item is found with + * CACHE_UPCALLING clear, we free it from the list. 
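For orientation, here is a rough user-space skeleton of the daemon side of the channel file described above: block in poll() until a request is readable, read one full request, and write back a single newline-terminated reply. The path component SOMECACHE and the handle_request() helper are placeholders; a real daemon would parse and answer the line using the cache's own field format (see qword_add()/qword_get() further down).

#include <fcntl.h>
#include <poll.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

static void handle_request(const char *req, char *reply, size_t len)
{
	/* Placeholder: a real daemon would look the key up somewhere. */
	snprintf(reply, len, "%s", req);
}

int main(void)
{
	char req[4096], reply[4096];
	struct pollfd pfd;
	ssize_t n;

	pfd.fd = open("/proc/net/rpc/SOMECACHE/channel", O_RDWR); /* placeholder name */
	if (pfd.fd < 0)
		return 1;
	pfd.events = POLLIN;

	while (poll(&pfd, 1, -1) > 0 && (pfd.revents & POLLIN)) {
		n = read(pfd.fd, req, sizeof(req) - 1);	/* one full request */
		if (n <= 0)
			break;
		req[n] = '\0';
		handle_request(req, reply, sizeof(reply));
		if (write(pfd.fd, reply, strlen(reply)) < 0)
			break;
	}
	close(pfd.fd);
	return 0;
}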
+ * + */ + +static DEFINE_SPINLOCK(queue_lock); + +struct cache_queue { + struct list_head list; + int reader; /* if 0, then request */ +}; +struct cache_request { + struct cache_queue q; + struct cache_head *item; + char * buf; + int len; + int readers; +}; +struct cache_reader { + struct cache_queue q; + int offset; /* if non-0, we have a refcnt on next request */ +}; + +static int cache_request(struct cache_detail *detail, + struct cache_request *crq) +{ + char *bp = crq->buf; + int len = PAGE_SIZE; + + detail->cache_request(detail, crq->item, &bp, &len); + if (len < 0) + return -E2BIG; + return PAGE_SIZE - len; +} + +static ssize_t cache_read(struct file *filp, char __user *buf, size_t count, + loff_t *ppos, struct cache_detail *cd) +{ + struct cache_reader *rp = filp->private_data; + struct cache_request *rq; + struct inode *inode = file_inode(filp); + int err; + + if (count == 0) + return 0; + + inode_lock(inode); /* protect against multiple concurrent + * readers on this file */ + again: + spin_lock(&queue_lock); + /* need to find next request */ + while (rp->q.list.next != &cd->queue && + list_entry(rp->q.list.next, struct cache_queue, list) + ->reader) { + struct list_head *next = rp->q.list.next; + list_move(&rp->q.list, next); + } + if (rp->q.list.next == &cd->queue) { + spin_unlock(&queue_lock); + inode_unlock(inode); + WARN_ON_ONCE(rp->offset); + return 0; + } + rq = container_of(rp->q.list.next, struct cache_request, q.list); + WARN_ON_ONCE(rq->q.reader); + if (rp->offset == 0) + rq->readers++; + spin_unlock(&queue_lock); + + if (rq->len == 0) { + err = cache_request(cd, rq); + if (err < 0) + goto out; + rq->len = err; + } + + if (rp->offset == 0 && !test_bit(CACHE_PENDING, &rq->item->flags)) { + err = -EAGAIN; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } else { + if (rp->offset + count > rq->len) + count = rq->len - rp->offset; + err = -EFAULT; + if (copy_to_user(buf, rq->buf + rp->offset, count)) + goto out; + rp->offset += count; + if (rp->offset >= rq->len) { + rp->offset = 0; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } + err = 0; + } + out: + if (rp->offset == 0) { + /* need to release rq */ + spin_lock(&queue_lock); + rq->readers--; + if (rq->readers == 0 && + !test_bit(CACHE_PENDING, &rq->item->flags)) { + list_del(&rq->q.list); + spin_unlock(&queue_lock); + cache_put(rq->item, cd); + kfree(rq->buf); + kfree(rq); + } else + spin_unlock(&queue_lock); + } + if (err == -EAGAIN) + goto again; + inode_unlock(inode); + return err ? 
err : count; +} + +static ssize_t cache_do_downcall(char *kaddr, const char __user *buf, + size_t count, struct cache_detail *cd) +{ + ssize_t ret; + + if (count == 0) + return -EINVAL; + if (copy_from_user(kaddr, buf, count)) + return -EFAULT; + kaddr[count] = '\0'; + ret = cd->cache_parse(cd, kaddr, count); + if (!ret) + ret = count; + return ret; +} + +static ssize_t cache_downcall(struct address_space *mapping, + const char __user *buf, + size_t count, struct cache_detail *cd) +{ + char *write_buf; + ssize_t ret = -ENOMEM; + + if (count >= 32768) { /* 32k is max userland buffer, lets check anyway */ + ret = -EINVAL; + goto out; + } + + write_buf = kvmalloc(count + 1, GFP_KERNEL); + if (!write_buf) + goto out; + + ret = cache_do_downcall(write_buf, buf, count, cd); + kvfree(write_buf); +out: + return ret; +} + +static ssize_t cache_write(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + struct address_space *mapping = filp->f_mapping; + struct inode *inode = file_inode(filp); + ssize_t ret = -EINVAL; + + if (!cd->cache_parse) + goto out; + + inode_lock(inode); + ret = cache_downcall(mapping, buf, count, cd); + inode_unlock(inode); +out: + return ret; +} + +static DECLARE_WAIT_QUEUE_HEAD(queue_wait); + +static __poll_t cache_poll(struct file *filp, poll_table *wait, + struct cache_detail *cd) +{ + __poll_t mask; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + + poll_wait(filp, &queue_wait, wait); + + /* alway allow write */ + mask = EPOLLOUT | EPOLLWRNORM; + + if (!rp) + return mask; + + spin_lock(&queue_lock); + + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + mask |= EPOLLIN | EPOLLRDNORM; + break; + } + spin_unlock(&queue_lock); + return mask; +} + +static int cache_ioctl(struct inode *ino, struct file *filp, + unsigned int cmd, unsigned long arg, + struct cache_detail *cd) +{ + int len = 0; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + + if (cmd != FIONREAD || !rp) + return -EINVAL; + + spin_lock(&queue_lock); + + /* only find the length remaining in current request, + * or the length of the next request + */ + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + struct cache_request *cr = + container_of(cq, struct cache_request, q); + len = cr->len - rp->offset; + break; + } + spin_unlock(&queue_lock); + + return put_user(len, (int __user *)arg); +} + +static int cache_open(struct inode *inode, struct file *filp, + struct cache_detail *cd) +{ + struct cache_reader *rp = NULL; + + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + nonseekable_open(inode, filp); + if (filp->f_mode & FMODE_READ) { + rp = kmalloc(sizeof(*rp), GFP_KERNEL); + if (!rp) { + module_put(cd->owner); + return -ENOMEM; + } + rp->offset = 0; + rp->q.reader = 1; + + spin_lock(&queue_lock); + list_add(&rp->q.list, &cd->queue); + spin_unlock(&queue_lock); + } + if (filp->f_mode & FMODE_WRITE) + atomic_inc(&cd->writers); + filp->private_data = rp; + return 0; +} + +static int cache_release(struct inode *inode, struct file *filp, + struct cache_detail *cd) +{ + struct cache_reader *rp = filp->private_data; + + if (rp) { + spin_lock(&queue_lock); + if (rp->offset) { + struct cache_queue *cq; + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + container_of(cq, struct 
cache_request, q) + ->readers--; + break; + } + rp->offset = 0; + } + list_del(&rp->q.list); + spin_unlock(&queue_lock); + + filp->private_data = NULL; + kfree(rp); + + } + if (filp->f_mode & FMODE_WRITE) { + atomic_dec(&cd->writers); + cd->last_close = seconds_since_boot(); + } + module_put(cd->owner); + return 0; +} + + + +static void cache_dequeue(struct cache_detail *detail, struct cache_head *ch) +{ + struct cache_queue *cq, *tmp; + struct cache_request *cr; + struct list_head dequeued; + + INIT_LIST_HEAD(&dequeued); + spin_lock(&queue_lock); + list_for_each_entry_safe(cq, tmp, &detail->queue, list) + if (!cq->reader) { + cr = container_of(cq, struct cache_request, q); + if (cr->item != ch) + continue; + if (test_bit(CACHE_PENDING, &ch->flags)) + /* Lost a race and it is pending again */ + break; + if (cr->readers != 0) + continue; + list_move(&cr->q.list, &dequeued); + } + spin_unlock(&queue_lock); + while (!list_empty(&dequeued)) { + cr = list_entry(dequeued.next, struct cache_request, q.list); + list_del(&cr->q.list); + cache_put(cr->item, detail); + kfree(cr->buf); + kfree(cr); + } +} + +/* + * Support routines for text-based upcalls. + * Fields are separated by spaces. + * Fields are either mangled to quote space tab newline slosh with slosh + * or a hexified with a leading \x + * Record is terminated with newline. + * + */ + +void qword_add(char **bpp, int *lp, char *str) +{ + char *bp = *bpp; + int len = *lp; + int ret; + + if (len < 0) return; + + ret = string_escape_str(str, bp, len, ESCAPE_OCTAL, "\\ \n\t"); + if (ret >= len) { + bp += len; + len = -1; + } else { + bp += ret; + len -= ret; + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL_GPL(qword_add); + +void qword_addhex(char **bpp, int *lp, char *buf, int blen) +{ + char *bp = *bpp; + int len = *lp; + + if (len < 0) return; + + if (len > 2) { + *bp++ = '\\'; + *bp++ = 'x'; + len -= 2; + while (blen && len >= 2) { + bp = hex_byte_pack(bp, *buf++); + len -= 2; + blen--; + } + } + if (blen || len<1) len = -1; + else { + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL_GPL(qword_addhex); + +static void warn_no_listener(struct cache_detail *detail) +{ + if (detail->last_warn != detail->last_close) { + detail->last_warn = detail->last_close; + if (detail->warn_no_listener) + detail->warn_no_listener(detail, detail->last_close != 0); + } +} + +static bool cache_listeners_exist(struct cache_detail *detail) +{ + if (atomic_read(&detail->writers)) + return true; + if (detail->last_close == 0) + /* This cache was never opened */ + return false; + if (detail->last_close < seconds_since_boot() - 30) + /* + * We allow for the possibility that someone might + * restart a userspace daemon without restarting the + * server; but after 30 seconds, we give up. + */ + return false; + return true; +} + +/* + * register an upcall request to user-space and queue it up for read() by the + * upcall daemon. + * + * Each request is at most one page long. 
+ */ +static int cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) +{ + char *buf; + struct cache_request *crq; + int ret = 0; + + if (test_bit(CACHE_CLEANED, &h->flags)) + /* Too late to make an upcall */ + return -EAGAIN; + + buf = kmalloc(PAGE_SIZE, GFP_KERNEL); + if (!buf) + return -EAGAIN; + + crq = kmalloc(sizeof (*crq), GFP_KERNEL); + if (!crq) { + kfree(buf); + return -EAGAIN; + } + + crq->q.reader = 0; + crq->buf = buf; + crq->len = 0; + crq->readers = 0; + spin_lock(&queue_lock); + if (test_bit(CACHE_PENDING, &h->flags)) { + crq->item = cache_get(h); + list_add_tail(&crq->q.list, &detail->queue); + trace_cache_entry_upcall(detail, h); + } else + /* Lost a race, no longer PENDING, so don't enqueue */ + ret = -EAGAIN; + spin_unlock(&queue_lock); + wake_up(&queue_wait); + if (ret == -EAGAIN) { + kfree(buf); + kfree(crq); + } + return ret; +} + +int sunrpc_cache_pipe_upcall(struct cache_detail *detail, struct cache_head *h) +{ + if (test_and_set_bit(CACHE_PENDING, &h->flags)) + return 0; + return cache_pipe_upcall(detail, h); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall); + +int sunrpc_cache_pipe_upcall_timeout(struct cache_detail *detail, + struct cache_head *h) +{ + if (!cache_listeners_exist(detail)) { + warn_no_listener(detail); + trace_cache_entry_no_listener(detail, h); + return -EINVAL; + } + return sunrpc_cache_pipe_upcall(detail, h); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_pipe_upcall_timeout); + +/* + * parse a message from user-space and pass it + * to an appropriate cache + * Messages are, like requests, separated into fields by + * spaces and dequotes as \xHEXSTRING or embedded \nnn octal + * + * Message is + * reply cachename expiry key ... content.... + * + * key and content are both parsed by cache + */ + +int qword_get(char **bpp, char *dest, int bufsize) +{ + /* return bytes copied, or -1 on error */ + char *bp = *bpp; + int len = 0; + + while (*bp == ' ') bp++; + + if (bp[0] == '\\' && bp[1] == 'x') { + /* HEX STRING */ + bp += 2; + while (len < bufsize - 1) { + int h, l; + + h = hex_to_bin(bp[0]); + if (h < 0) + break; + + l = hex_to_bin(bp[1]); + if (l < 0) + break; + + *dest++ = (h << 4) | l; + bp += 2; + len++; + } + } else { + /* text with \nnn octal quoting */ + while (*bp != ' ' && *bp != '\n' && *bp && len < bufsize-1) { + if (*bp == '\\' && + isodigit(bp[1]) && (bp[1] <= '3') && + isodigit(bp[2]) && + isodigit(bp[3])) { + int byte = (*++bp -'0'); + bp++; + byte = (byte << 3) | (*bp++ - '0'); + byte = (byte << 3) | (*bp++ - '0'); + *dest++ = byte; + len++; + } else { + *dest++ = *bp++; + len++; + } + } + } + + if (*bp != ' ' && *bp != '\n' && *bp != '\0') + return -1; + while (*bp == ' ') bp++; + *bpp = bp; + *dest = '\0'; + return len; +} +EXPORT_SYMBOL_GPL(qword_get); + + +/* + * support /proc/net/rpc/$CACHENAME/content + * as a seqfile. 
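The quoting used on the channel, implemented by qword_add() above and undone by qword_get(), escapes space, tab, newline and backslash as three-digit octal sequences, separates fields with single spaces and terminates the record with a newline. A minimal user-space sketch of the writer side (emit_field() and the sample strings are made up):

#include <stdio.h>
#include <string.h>

static void emit_field(FILE *out, const char *s)
{
	for (; *s; s++) {
		if (*s == ' ' || *s == '\t' || *s == '\n' || *s == '\\')
			fprintf(out, "\\%03o", (unsigned int)(unsigned char)*s);
		else
			fputc(*s, out);
	}
	fputc(' ', out);		/* field separator, as qword_add() appends */
}

int main(void)
{
	/* Example record: two fields, the second containing spaces. */
	emit_field(stdout, "somekey");
	emit_field(stdout, "a value with spaces");
	fputc('\n', stdout);		/* record terminator */
	return 0;
}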
+ * We call ->cache_show passing NULL for the item to + * get a header, then pass each real item in the cache + */ + +static void *__cache_seq_start(struct seq_file *m, loff_t *pos) +{ + loff_t n = *pos; + unsigned int hash, entry; + struct cache_head *ch; + struct cache_detail *cd = m->private; + + if (!n--) + return SEQ_START_TOKEN; + hash = n >> 32; + entry = n & ((1LL<<32) - 1); + + hlist_for_each_entry_rcu(ch, &cd->hash_table[hash], cache_list) + if (!entry--) + return ch; + n &= ~((1LL<<32) - 1); + do { + hash++; + n += 1LL<<32; + } while(hash < cd->hash_size && + hlist_empty(&cd->hash_table[hash])); + if (hash >= cd->hash_size) + return NULL; + *pos = n+1; + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), + struct cache_head, cache_list); +} + +static void *cache_seq_next(struct seq_file *m, void *p, loff_t *pos) +{ + struct cache_head *ch = p; + int hash = (*pos >> 32); + struct cache_detail *cd = m->private; + + if (p == SEQ_START_TOKEN) + hash = 0; + else if (ch->cache_list.next == NULL) { + hash++; + *pos += 1LL<<32; + } else { + ++*pos; + return hlist_entry_safe(rcu_dereference_raw( + hlist_next_rcu(&ch->cache_list)), + struct cache_head, cache_list); + } + *pos &= ~((1LL<<32) - 1); + while (hash < cd->hash_size && + hlist_empty(&cd->hash_table[hash])) { + hash++; + *pos += 1LL<<32; + } + if (hash >= cd->hash_size) + return NULL; + ++*pos; + return hlist_entry_safe(rcu_dereference_raw( + hlist_first_rcu(&cd->hash_table[hash])), + struct cache_head, cache_list); +} + +void *cache_seq_start_rcu(struct seq_file *m, loff_t *pos) + __acquires(RCU) +{ + rcu_read_lock(); + return __cache_seq_start(m, pos); +} +EXPORT_SYMBOL_GPL(cache_seq_start_rcu); + +void *cache_seq_next_rcu(struct seq_file *file, void *p, loff_t *pos) +{ + return cache_seq_next(file, p, pos); +} +EXPORT_SYMBOL_GPL(cache_seq_next_rcu); + +void cache_seq_stop_rcu(struct seq_file *m, void *p) + __releases(RCU) +{ + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(cache_seq_stop_rcu); + +static int c_show(struct seq_file *m, void *p) +{ + struct cache_head *cp = p; + struct cache_detail *cd = m->private; + + if (p == SEQ_START_TOKEN) + return cd->cache_show(m, cd, NULL); + + ifdebug(CACHE) + seq_printf(m, "# expiry=%lld refcnt=%d flags=%lx\n", + convert_to_wallclock(cp->expiry_time), + kref_read(&cp->ref), cp->flags); + cache_get(cp); + if (cache_check(cd, cp, NULL)) + /* cache_check does a cache_put on failure */ + seq_puts(m, "# "); + else { + if (cache_is_expired(cd, cp)) + seq_puts(m, "# "); + cache_put(cp, cd); + } + + return cd->cache_show(m, cd, cp); +} + +static const struct seq_operations cache_content_op = { + .start = cache_seq_start_rcu, + .next = cache_seq_next_rcu, + .stop = cache_seq_stop_rcu, + .show = c_show, +}; + +static int content_open(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + struct seq_file *seq; + int err; + + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + + err = seq_open(file, &cache_content_op); + if (err) { + module_put(cd->owner); + return err; + } + + seq = file->private_data; + seq->private = cd; + return 0; +} + +static int content_release(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + int ret = seq_release(inode, file); + module_put(cd->owner); + return ret; +} + +static int open_flush(struct inode *inode, struct file *file, + struct cache_detail *cd) +{ + if (!cd || !try_module_get(cd->owner)) + return -EACCES; + return nonseekable_open(inode, file); +} + +static int release_flush(struct 
inode *inode, struct file *file, + struct cache_detail *cd) +{ + module_put(cd->owner); + return 0; +} + +static ssize_t read_flush(struct file *file, char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + char tbuf[22]; + size_t len; + + len = snprintf(tbuf, sizeof(tbuf), "%llu\n", + convert_to_wallclock(cd->flush_time)); + return simple_read_from_buffer(buf, count, ppos, tbuf, len); +} + +static ssize_t write_flush(struct file *file, const char __user *buf, + size_t count, loff_t *ppos, + struct cache_detail *cd) +{ + char tbuf[20]; + char *ep; + time64_t now; + + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + tbuf[count] = 0; + simple_strtoul(tbuf, &ep, 0); + if (*ep && *ep != '\n') + return -EINVAL; + /* Note that while we check that 'buf' holds a valid number, + * we always ignore the value and just flush everything. + * Making use of the number leads to races. + */ + + now = seconds_since_boot(); + /* Always flush everything, so behave like cache_purge() + * Do this by advancing flush_time to the current time, + * or by one second if it has already reached the current time. + * Newly added cache entries will always have ->last_refresh greater + * that ->flush_time, so they don't get flushed prematurely. + */ + + if (cd->flush_time >= now) + now = cd->flush_time + 1; + + cd->flush_time = now; + cd->nextcheck = now; + cache_flush(); + + if (cd->flush) + cd->flush(); + + *ppos += count; + return count; +} + +static ssize_t cache_read_procfs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = pde_data(file_inode(filp)); + + return cache_read(filp, buf, count, ppos, cd); +} + +static ssize_t cache_write_procfs(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = pde_data(file_inode(filp)); + + return cache_write(filp, buf, count, ppos, cd); +} + +static __poll_t cache_poll_procfs(struct file *filp, poll_table *wait) +{ + struct cache_detail *cd = pde_data(file_inode(filp)); + + return cache_poll(filp, wait, cd); +} + +static long cache_ioctl_procfs(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct cache_detail *cd = pde_data(inode); + + return cache_ioctl(inode, filp, cmd, arg, cd); +} + +static int cache_open_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return cache_open(inode, filp, cd); +} + +static int cache_release_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return cache_release(inode, filp, cd); +} + +static const struct proc_ops cache_channel_proc_ops = { + .proc_lseek = no_llseek, + .proc_read = cache_read_procfs, + .proc_write = cache_write_procfs, + .proc_poll = cache_poll_procfs, + .proc_ioctl = cache_ioctl_procfs, /* for FIONREAD */ + .proc_open = cache_open_procfs, + .proc_release = cache_release_procfs, +}; + +static int content_open_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return content_open(inode, filp, cd); +} + +static int content_release_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return content_release(inode, filp, cd); +} + +static const struct proc_ops content_proc_ops = { + .proc_open = content_open_procfs, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_release = content_release_procfs, 
+}; + +static int open_flush_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return open_flush(inode, filp, cd); +} + +static int release_flush_procfs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = pde_data(inode); + + return release_flush(inode, filp, cd); +} + +static ssize_t read_flush_procfs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = pde_data(file_inode(filp)); + + return read_flush(filp, buf, count, ppos, cd); +} + +static ssize_t write_flush_procfs(struct file *filp, + const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = pde_data(file_inode(filp)); + + return write_flush(filp, buf, count, ppos, cd); +} + +static const struct proc_ops cache_flush_proc_ops = { + .proc_open = open_flush_procfs, + .proc_read = read_flush_procfs, + .proc_write = write_flush_procfs, + .proc_release = release_flush_procfs, + .proc_lseek = no_llseek, +}; + +static void remove_cache_proc_entries(struct cache_detail *cd) +{ + if (cd->procfs) { + proc_remove(cd->procfs); + cd->procfs = NULL; + } +} + +#ifdef CONFIG_PROC_FS +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) +{ + struct proc_dir_entry *p; + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + cd->procfs = proc_mkdir(cd->name, sn->proc_net_rpc); + if (cd->procfs == NULL) + goto out_nomem; + + p = proc_create_data("flush", S_IFREG | 0600, + cd->procfs, &cache_flush_proc_ops, cd); + if (p == NULL) + goto out_nomem; + + if (cd->cache_request || cd->cache_parse) { + p = proc_create_data("channel", S_IFREG | 0600, cd->procfs, + &cache_channel_proc_ops, cd); + if (p == NULL) + goto out_nomem; + } + if (cd->cache_show) { + p = proc_create_data("content", S_IFREG | 0400, cd->procfs, + &content_proc_ops, cd); + if (p == NULL) + goto out_nomem; + } + return 0; +out_nomem: + remove_cache_proc_entries(cd); + return -ENOMEM; +} +#else /* CONFIG_PROC_FS */ +static int create_cache_proc_entries(struct cache_detail *cd, struct net *net) +{ + return 0; +} +#endif + +void __init cache_initialize(void) +{ + INIT_DEFERRABLE_WORK(&cache_cleaner, do_cache_clean); +} + +int cache_register_net(struct cache_detail *cd, struct net *net) +{ + int ret; + + sunrpc_init_cache_detail(cd); + ret = create_cache_proc_entries(cd, net); + if (ret) + sunrpc_destroy_cache_detail(cd); + return ret; +} +EXPORT_SYMBOL_GPL(cache_register_net); + +void cache_unregister_net(struct cache_detail *cd, struct net *net) +{ + remove_cache_proc_entries(cd); + sunrpc_destroy_cache_detail(cd); +} +EXPORT_SYMBOL_GPL(cache_unregister_net); + +struct cache_detail *cache_create_net(const struct cache_detail *tmpl, struct net *net) +{ + struct cache_detail *cd; + int i; + + cd = kmemdup(tmpl, sizeof(struct cache_detail), GFP_KERNEL); + if (cd == NULL) + return ERR_PTR(-ENOMEM); + + cd->hash_table = kcalloc(cd->hash_size, sizeof(struct hlist_head), + GFP_KERNEL); + if (cd->hash_table == NULL) { + kfree(cd); + return ERR_PTR(-ENOMEM); + } + + for (i = 0; i < cd->hash_size; i++) + INIT_HLIST_HEAD(&cd->hash_table[i]); + cd->net = net; + return cd; +} +EXPORT_SYMBOL_GPL(cache_create_net); + +void cache_destroy_net(struct cache_detail *cd, struct net *net) +{ + kfree(cd->hash_table); + kfree(cd); +} +EXPORT_SYMBOL_GPL(cache_destroy_net); + +static ssize_t cache_read_pipefs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + 
return cache_read(filp, buf, count, ppos, cd); +} + +static ssize_t cache_write_pipefs(struct file *filp, const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return cache_write(filp, buf, count, ppos, cd); +} + +static __poll_t cache_poll_pipefs(struct file *filp, poll_table *wait) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return cache_poll(filp, wait, cd); +} + +static long cache_ioctl_pipefs(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_ioctl(inode, filp, cmd, arg, cd); +} + +static int cache_open_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_open(inode, filp, cd); +} + +static int cache_release_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return cache_release(inode, filp, cd); +} + +const struct file_operations cache_file_operations_pipefs = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = cache_read_pipefs, + .write = cache_write_pipefs, + .poll = cache_poll_pipefs, + .unlocked_ioctl = cache_ioctl_pipefs, /* for FIONREAD */ + .open = cache_open_pipefs, + .release = cache_release_pipefs, +}; + +static int content_open_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return content_open(inode, filp, cd); +} + +static int content_release_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return content_release(inode, filp, cd); +} + +const struct file_operations content_file_operations_pipefs = { + .open = content_open_pipefs, + .read = seq_read, + .llseek = seq_lseek, + .release = content_release_pipefs, +}; + +static int open_flush_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return open_flush(inode, filp, cd); +} + +static int release_flush_pipefs(struct inode *inode, struct file *filp) +{ + struct cache_detail *cd = RPC_I(inode)->private; + + return release_flush(inode, filp, cd); +} + +static ssize_t read_flush_pipefs(struct file *filp, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return read_flush(filp, buf, count, ppos, cd); +} + +static ssize_t write_flush_pipefs(struct file *filp, + const char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = RPC_I(file_inode(filp))->private; + + return write_flush(filp, buf, count, ppos, cd); +} + +const struct file_operations cache_flush_operations_pipefs = { + .open = open_flush_pipefs, + .read = read_flush_pipefs, + .write = write_flush_pipefs, + .release = release_flush_pipefs, + .llseek = no_llseek, +}; + +int sunrpc_cache_register_pipefs(struct dentry *parent, + const char *name, umode_t umode, + struct cache_detail *cd) +{ + struct dentry *dir = rpc_create_cache_dir(parent, name, umode, cd); + if (IS_ERR(dir)) + return PTR_ERR(dir); + cd->pipefs = dir; + return 0; +} +EXPORT_SYMBOL_GPL(sunrpc_cache_register_pipefs); + +void sunrpc_cache_unregister_pipefs(struct cache_detail *cd) +{ + if (cd->pipefs) { + rpc_remove_cache_dir(cd->pipefs); + cd->pipefs = NULL; + } +} +EXPORT_SYMBOL_GPL(sunrpc_cache_unregister_pipefs); + +void sunrpc_cache_unhash(struct cache_detail *cd, struct cache_head *h) +{ + spin_lock(&cd->hash_lock); 
+ if (!hlist_unhashed(&h->cache_list)){ + sunrpc_begin_cache_remove_entry(h, cd); + spin_unlock(&cd->hash_lock); + sunrpc_end_cache_remove_entry(h, cd); + } else + spin_unlock(&cd->hash_lock); +} +EXPORT_SYMBOL_GPL(sunrpc_cache_unhash); diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c new file mode 100644 index 000000000..61e5c7746 --- /dev/null +++ b/net/sunrpc/clnt.c @@ -0,0 +1,3363 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/clnt.c + * + * This file contains the high-level RPC interface. + * It is modeled as a finite state machine to support both synchronous + * and asynchronous requests. + * + * - RPC header generation and argument serialization. + * - Credential refresh. + * - TCP connect handling. + * - Retry of operation when it is suspected the operation failed because + * of uid squashing on the server, or when the credentials were stale + * and need to be refreshed, or when a packet was damaged in transit. + * This may be have to be moved to the VFS layer. + * + * Copyright (C) 1992,1993 Rick Sladkey + * Copyright (C) 1995,1996 Olaf Kirch + */ + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "sunrpc.h" +#include "sysfs.h" +#include "netns.h" + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_CALL +#endif + +/* + * All RPC clients are linked into this list + */ + +static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); + + +static void call_start(struct rpc_task *task); +static void call_reserve(struct rpc_task *task); +static void call_reserveresult(struct rpc_task *task); +static void call_allocate(struct rpc_task *task); +static void call_encode(struct rpc_task *task); +static void call_decode(struct rpc_task *task); +static void call_bind(struct rpc_task *task); +static void call_bind_status(struct rpc_task *task); +static void call_transmit(struct rpc_task *task); +static void call_status(struct rpc_task *task); +static void call_transmit_status(struct rpc_task *task); +static void call_refresh(struct rpc_task *task); +static void call_refreshresult(struct rpc_task *task); +static void call_connect(struct rpc_task *task); +static void call_connect_status(struct rpc_task *task); + +static int rpc_encode_header(struct rpc_task *task, + struct xdr_stream *xdr); +static int rpc_decode_header(struct rpc_task *task, + struct xdr_stream *xdr); +static int rpc_ping(struct rpc_clnt *clnt); +static int rpc_ping_noreply(struct rpc_clnt *clnt); +static void rpc_check_timeout(struct rpc_task *task); + +static void rpc_register_client(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_add(&clnt->cl_clients, &sn->all_clients); + spin_unlock(&sn->rpc_client_lock); +} + +static void rpc_unregister_client(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_del(&clnt->cl_clients); + spin_unlock(&sn->rpc_client_lock); +} + +static void __rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) +{ + rpc_remove_client_dir(clnt); +} + +static void rpc_clnt_remove_pipedir(struct rpc_clnt *clnt) +{ + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; + + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + if (pipefs_sb == clnt->pipefs_sb) + 
__rpc_clnt_remove_pipedir(clnt); + rpc_put_sb_net(net); + } +} + +static struct dentry *rpc_setup_pipedir_sb(struct super_block *sb, + struct rpc_clnt *clnt) +{ + static uint32_t clntid; + const char *dir_name = clnt->cl_program->pipe_dir_name; + char name[15]; + struct dentry *dir, *dentry; + + dir = rpc_d_lookup_sb(sb, dir_name); + if (dir == NULL) { + pr_info("RPC: pipefs directory doesn't exist: %s\n", dir_name); + return dir; + } + for (;;) { + snprintf(name, sizeof(name), "clnt%x", (unsigned int)clntid++); + name[sizeof(name) - 1] = '\0'; + dentry = rpc_create_client_dir(dir, name, clnt); + if (!IS_ERR(dentry)) + break; + if (dentry == ERR_PTR(-EEXIST)) + continue; + printk(KERN_INFO "RPC: Couldn't create pipefs entry" + " %s/%s, error %ld\n", + dir_name, name, PTR_ERR(dentry)); + break; + } + dput(dir); + return dentry; +} + +static int +rpc_setup_pipedir(struct super_block *pipefs_sb, struct rpc_clnt *clnt) +{ + struct dentry *dentry; + + clnt->pipefs_sb = pipefs_sb; + + if (clnt->cl_program->pipe_dir_name != NULL) { + dentry = rpc_setup_pipedir_sb(pipefs_sb, clnt); + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + } + return 0; +} + +static int rpc_clnt_skip_event(struct rpc_clnt *clnt, unsigned long event) +{ + if (clnt->cl_program->pipe_dir_name == NULL) + return 1; + + switch (event) { + case RPC_PIPEFS_MOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry != NULL) + return 1; + if (refcount_read(&clnt->cl_count) == 0) + return 1; + break; + case RPC_PIPEFS_UMOUNT: + if (clnt->cl_pipedir_objects.pdh_dentry == NULL) + return 1; + break; + } + return 0; +} + +static int __rpc_clnt_handle_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + struct dentry *dentry; + + switch (event) { + case RPC_PIPEFS_MOUNT: + dentry = rpc_setup_pipedir_sb(sb, clnt); + if (!dentry) + return -ENOENT; + if (IS_ERR(dentry)) + return PTR_ERR(dentry); + break; + case RPC_PIPEFS_UMOUNT: + __rpc_clnt_remove_pipedir(clnt); + break; + default: + printk(KERN_ERR "%s: unknown event: %ld\n", __func__, event); + return -ENOTSUPP; + } + return 0; +} + +static int __rpc_pipefs_event(struct rpc_clnt *clnt, unsigned long event, + struct super_block *sb) +{ + int error = 0; + + for (;; clnt = clnt->cl_parent) { + if (!rpc_clnt_skip_event(clnt, event)) + error = __rpc_clnt_handle_event(clnt, event, sb); + if (error || clnt == clnt->cl_parent) + break; + } + return error; +} + +static struct rpc_clnt *rpc_get_client_for_event(struct net *net, int event) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt; + + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { + if (rpc_clnt_skip_event(clnt, event)) + continue; + spin_unlock(&sn->rpc_client_lock); + return clnt; + } + spin_unlock(&sn->rpc_client_lock); + return NULL; +} + +static int rpc_pipefs_event(struct notifier_block *nb, unsigned long event, + void *ptr) +{ + struct super_block *sb = ptr; + struct rpc_clnt *clnt; + int error = 0; + + while ((clnt = rpc_get_client_for_event(sb->s_fs_info, event))) { + error = __rpc_pipefs_event(clnt, event, sb); + if (error) + break; + } + return error; +} + +static struct notifier_block rpc_clients_block = { + .notifier_call = rpc_pipefs_event, + .priority = SUNRPC_PIPEFS_RPC_PRIO, +}; + +int rpc_clients_notifier_register(void) +{ + return rpc_pipefs_notifier_register(&rpc_clients_block); +} + +void rpc_clients_notifier_unregister(void) +{ + return rpc_pipefs_notifier_unregister(&rpc_clients_block); +} + +static struct rpc_xprt 
*rpc_clnt_set_transport(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + const struct rpc_timeout *timeout) +{ + struct rpc_xprt *old; + + spin_lock(&clnt->cl_lock); + old = rcu_dereference_protected(clnt->cl_xprt, + lockdep_is_held(&clnt->cl_lock)); + + if (!xprt_bound(xprt)) + clnt->cl_autobind = 1; + + clnt->cl_timeout = timeout; + rcu_assign_pointer(clnt->cl_xprt, xprt); + spin_unlock(&clnt->cl_lock); + + return old; +} + +static void rpc_clnt_set_nodename(struct rpc_clnt *clnt, const char *nodename) +{ + clnt->cl_nodelen = strlcpy(clnt->cl_nodename, + nodename, sizeof(clnt->cl_nodename)); +} + +static int rpc_client_register(struct rpc_clnt *clnt, + rpc_authflavor_t pseudoflavor, + const char *client_name) +{ + struct rpc_auth_create_args auth_args = { + .pseudoflavor = pseudoflavor, + .target_name = client_name, + }; + struct rpc_auth *auth; + struct net *net = rpc_net_ns(clnt); + struct super_block *pipefs_sb; + int err; + + rpc_clnt_debugfs_register(clnt); + + pipefs_sb = rpc_get_sb_net(net); + if (pipefs_sb) { + err = rpc_setup_pipedir(pipefs_sb, clnt); + if (err) + goto out; + } + + rpc_register_client(clnt); + if (pipefs_sb) + rpc_put_sb_net(net); + + auth = rpcauth_create(&auth_args, clnt); + if (IS_ERR(auth)) { + dprintk("RPC: Couldn't create auth handle (flavor %u)\n", + pseudoflavor); + err = PTR_ERR(auth); + goto err_auth; + } + return 0; +err_auth: + pipefs_sb = rpc_get_sb_net(net); + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); +out: + if (pipefs_sb) + rpc_put_sb_net(net); + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); + return err; +} + +static DEFINE_IDA(rpc_clids); + +void rpc_cleanup_clids(void) +{ + ida_destroy(&rpc_clids); +} + +static int rpc_alloc_clid(struct rpc_clnt *clnt) +{ + int clid; + + clid = ida_alloc(&rpc_clids, GFP_KERNEL); + if (clid < 0) + return clid; + clnt->cl_clid = clid; + return 0; +} + +static void rpc_free_clid(struct rpc_clnt *clnt) +{ + ida_free(&rpc_clids, clnt->cl_clid); +} + +static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + struct rpc_clnt *parent) +{ + const struct rpc_program *program = args->program; + const struct rpc_version *version; + struct rpc_clnt *clnt = NULL; + const struct rpc_timeout *timeout; + const char *nodename = args->nodename; + int err; + + err = rpciod_up(); + if (err) + goto out_no_rpciod; + + err = -EINVAL; + if (args->version >= program->nrvers) + goto out_err; + version = program->version[args->version]; + if (version == NULL) + goto out_err; + + err = -ENOMEM; + clnt = kzalloc(sizeof(*clnt), GFP_KERNEL); + if (!clnt) + goto out_err; + clnt->cl_parent = parent ? : clnt; + + err = rpc_alloc_clid(clnt); + if (err) + goto out_no_clid; + + clnt->cl_cred = get_cred(args->cred); + clnt->cl_procinfo = version->procs; + clnt->cl_maxproc = version->nrprocs; + clnt->cl_prog = args->prognumber ? 
: program->number; + clnt->cl_vers = version->number; + clnt->cl_stats = program->stats; + clnt->cl_metrics = rpc_alloc_iostats(clnt); + rpc_init_pipe_dir_head(&clnt->cl_pipedir_objects); + err = -ENOMEM; + if (clnt->cl_metrics == NULL) + goto out_no_stats; + clnt->cl_program = program; + INIT_LIST_HEAD(&clnt->cl_tasks); + spin_lock_init(&clnt->cl_lock); + + timeout = xprt->timeout; + if (args->timeout != NULL) { + memcpy(&clnt->cl_timeout_default, args->timeout, + sizeof(clnt->cl_timeout_default)); + timeout = &clnt->cl_timeout_default; + } + + rpc_clnt_set_transport(clnt, xprt, timeout); + xprt->main = true; + xprt_iter_init(&clnt->cl_xpi, xps); + xprt_switch_put(xps); + + clnt->cl_rtt = &clnt->cl_rtt_default; + rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval); + + refcount_set(&clnt->cl_count, 1); + + if (nodename == NULL) + nodename = utsname()->nodename; + /* save the nodename */ + rpc_clnt_set_nodename(clnt, nodename); + + rpc_sysfs_client_setup(clnt, xps, rpc_net_ns(clnt)); + err = rpc_client_register(clnt, args->authflavor, args->client_name); + if (err) + goto out_no_path; + if (parent) + refcount_inc(&parent->cl_count); + + trace_rpc_clnt_new(clnt, xprt, program->name, args->servername); + return clnt; + +out_no_path: + rpc_free_iostats(clnt->cl_metrics); +out_no_stats: + put_cred(clnt->cl_cred); + rpc_free_clid(clnt); +out_no_clid: + kfree(clnt); +out_err: + rpciod_down(); +out_no_rpciod: + xprt_switch_put(xps); + xprt_put(xprt); + trace_rpc_clnt_new_err(program->name, args->servername, err); + return ERR_PTR(err); +} + +static struct rpc_clnt *rpc_create_xprt(struct rpc_create_args *args, + struct rpc_xprt *xprt) +{ + struct rpc_clnt *clnt = NULL; + struct rpc_xprt_switch *xps; + + if (args->bc_xprt && args->bc_xprt->xpt_bc_xps) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xps = args->bc_xprt->xpt_bc_xps; + xprt_switch_get(xps); + } else { + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return ERR_PTR(-ENOMEM); + } + if (xprt->bc_xprt) { + xprt_switch_get(xps); + xprt->bc_xprt->xpt_bc_xps = xps; + } + } + clnt = rpc_new_client(args, xps, xprt, NULL); + if (IS_ERR(clnt)) + return clnt; + + if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { + int err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } else if (args->flags & RPC_CLNT_CREATE_CONNECTED) { + int err = rpc_ping_noreply(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } + + clnt->cl_softrtry = 1; + if (args->flags & (RPC_CLNT_CREATE_HARDRTRY|RPC_CLNT_CREATE_SOFTERR)) { + clnt->cl_softrtry = 0; + if (args->flags & RPC_CLNT_CREATE_SOFTERR) + clnt->cl_softerr = 1; + } + + if (args->flags & RPC_CLNT_CREATE_AUTOBIND) + clnt->cl_autobind = 1; + if (args->flags & RPC_CLNT_CREATE_NO_RETRANS_TIMEOUT) + clnt->cl_noretranstimeo = 1; + if (args->flags & RPC_CLNT_CREATE_DISCRTRY) + clnt->cl_discrtry = 1; + if (!(args->flags & RPC_CLNT_CREATE_QUIET)) + clnt->cl_chatty = 1; + + return clnt; +} + +/** + * rpc_create - create an RPC client and transport with one call + * @args: rpc_clnt create argument structure + * + * Creates and initializes an RPC transport and an RPC client. + * + * It can ping the server in order to determine if it is up, and to see if + * it supports this program and version. RPC_CLNT_CREATE_NOPING disables + * this behavior so asynchronous tasks can also use rpc_create. 
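+ *
+ * Illustrative sketch only (the program, address, version and flavour
+ * below are assumptions, not taken from this file): a typical
+ * synchronous caller fills in a struct rpc_create_args and checks the
+ * result with IS_ERR():
+ *
+ *	struct rpc_create_args args = {
+ *		.net		= net,
+ *		.protocol	= XPRT_TRANSPORT_TCP,
+ *		.address	= (struct sockaddr *)&addr,
+ *		.addrsize	= sizeof(addr),
+ *		.servername	= "example-server",
+ *		.program	= &example_program,
+ *		.version	= 1,
+ *		.authflavor	= RPC_AUTH_UNIX,
+ *	};
+ *	struct rpc_clnt *clnt = rpc_create(&args);
+ *
+ *	if (IS_ERR(clnt))
+ *		return PTR_ERR(clnt);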
+ */ +struct rpc_clnt *rpc_create(struct rpc_create_args *args) +{ + struct rpc_xprt *xprt; + struct xprt_create xprtargs = { + .net = args->net, + .ident = args->protocol, + .srcaddr = args->saddress, + .dstaddr = args->address, + .addrlen = args->addrsize, + .servername = args->servername, + .bc_xprt = args->bc_xprt, + }; + char servername[48]; + struct rpc_clnt *clnt; + int i; + + if (args->bc_xprt) { + WARN_ON_ONCE(!(args->protocol & XPRT_TRANSPORT_BC)); + xprt = args->bc_xprt->xpt_bc_xprt; + if (xprt) { + xprt_get(xprt); + return rpc_create_xprt(args, xprt); + } + } + + if (args->flags & RPC_CLNT_CREATE_INFINITE_SLOTS) + xprtargs.flags |= XPRT_CREATE_INFINITE_SLOTS; + if (args->flags & RPC_CLNT_CREATE_NO_IDLE_TIMEOUT) + xprtargs.flags |= XPRT_CREATE_NO_IDLE_TIMEOUT; + /* + * If the caller chooses not to specify a hostname, whip + * up a string representation of the passed-in address. + */ + if (xprtargs.servername == NULL) { + struct sockaddr_un *sun = + (struct sockaddr_un *)args->address; + struct sockaddr_in *sin = + (struct sockaddr_in *)args->address; + struct sockaddr_in6 *sin6 = + (struct sockaddr_in6 *)args->address; + + servername[0] = '\0'; + switch (args->address->sa_family) { + case AF_LOCAL: + snprintf(servername, sizeof(servername), "%s", + sun->sun_path); + break; + case AF_INET: + snprintf(servername, sizeof(servername), "%pI4", + &sin->sin_addr.s_addr); + break; + case AF_INET6: + snprintf(servername, sizeof(servername), "%pI6", + &sin6->sin6_addr); + break; + default: + /* caller wants default server name, but + * address family isn't recognized. */ + return ERR_PTR(-EINVAL); + } + xprtargs.servername = servername; + } + + xprt = xprt_create_transport(&xprtargs); + if (IS_ERR(xprt)) + return (struct rpc_clnt *)xprt; + + /* + * By default, kernel RPC client connects from a reserved port. + * CAP_NET_BIND_SERVICE will not be set for unprivileged requesters, + * but it is always enabled for rpciod, which handles the connect + * operation. + */ + xprt->resvport = 1; + if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT) + xprt->resvport = 0; + xprt->reuseport = 0; + if (args->flags & RPC_CLNT_CREATE_REUSEPORT) + xprt->reuseport = 1; + + clnt = rpc_create_xprt(args, xprt); + if (IS_ERR(clnt) || args->nconnect <= 1) + return clnt; + + for (i = 0; i < args->nconnect - 1; i++) { + if (rpc_clnt_add_xprt(clnt, &xprtargs, NULL, NULL) < 0) + break; + } + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_create); + +/* + * This function clones the RPC client structure. It allows us to share the + * same transport while varying parameters such as the authentication + * flavour. 
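+ *
+ * Callers normally go through the public wrappers further down rather
+ * than this helper. Hypothetical example (not from this file): keeping
+ * the transport but switching to a different security flavour:
+ *
+ *	new = rpc_clone_client_set_auth(clnt, RPC_AUTH_GSS_KRB5);
+ *	if (IS_ERR(new))
+ *		return PTR_ERR(new);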
+ */ +static struct rpc_clnt *__rpc_clone_client(struct rpc_create_args *args, + struct rpc_clnt *clnt) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; + struct rpc_clnt *new; + int err; + + err = -ENOMEM; + rcu_read_lock(); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + if (xprt == NULL || xps == NULL) { + xprt_put(xprt); + xprt_switch_put(xps); + goto out_err; + } + args->servername = xprt->servername; + args->nodename = clnt->cl_nodename; + + new = rpc_new_client(args, xps, xprt, clnt); + if (IS_ERR(new)) + return new; + + /* Turn off autobind on clones */ + new->cl_autobind = 0; + new->cl_softrtry = clnt->cl_softrtry; + new->cl_softerr = clnt->cl_softerr; + new->cl_noretranstimeo = clnt->cl_noretranstimeo; + new->cl_discrtry = clnt->cl_discrtry; + new->cl_chatty = clnt->cl_chatty; + new->cl_principal = clnt->cl_principal; + new->cl_max_connect = clnt->cl_max_connect; + return new; + +out_err: + trace_rpc_clnt_clone_err(clnt, err); + return ERR_PTR(err); +} + +/** + * rpc_clone_client - Clone an RPC client structure + * + * @clnt: RPC client whose parameters are copied + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt *rpc_clone_client(struct rpc_clnt *clnt) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = clnt->cl_auth->au_flavor, + .cred = clnt->cl_cred, + }; + return __rpc_clone_client(&args, clnt); +} +EXPORT_SYMBOL_GPL(rpc_clone_client); + +/** + * rpc_clone_client_set_auth - Clone an RPC client structure and set its auth + * + * @clnt: RPC client whose parameters are copied + * @flavor: security flavor for new client + * + * Returns a fresh RPC client or an ERR_PTR. + */ +struct rpc_clnt * +rpc_clone_client_set_auth(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + struct rpc_create_args args = { + .program = clnt->cl_program, + .prognumber = clnt->cl_prog, + .version = clnt->cl_vers, + .authflavor = flavor, + .cred = clnt->cl_cred, + }; + return __rpc_clone_client(&args, clnt); +} +EXPORT_SYMBOL_GPL(rpc_clone_client_set_auth); + +/** + * rpc_switch_client_transport: switch the RPC transport on the fly + * @clnt: pointer to a struct rpc_clnt + * @args: pointer to the new transport arguments + * @timeout: pointer to the new timeout parameters + * + * This function allows the caller to switch the RPC transport for the + * rpc_clnt structure 'clnt' to allow it to connect to a mirrored NFS + * server, for instance. It assumes that the caller has ensured that + * there are no active RPC tasks by using some form of locking. + * + * Returns zero if "clnt" is now using the new xprt. Otherwise a + * negative errno is returned, and "clnt" continues to use the old + * xprt. 
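+ *
+ * Minimal caller sketch (the address and server name here are
+ * illustrative assumptions):
+ *
+ *	struct xprt_create xprtargs = {
+ *		.ident		= XPRT_TRANSPORT_TCP,
+ *		.net		= rpc_net_ns(clnt),
+ *		.dstaddr	= (struct sockaddr *)&new_addr,
+ *		.addrlen	= sizeof(new_addr),
+ *		.servername	= "mirror-server",
+ *	};
+ *
+ *	err = rpc_switch_client_transport(clnt, &xprtargs, clnt->cl_timeout);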
+ */ +int rpc_switch_client_transport(struct rpc_clnt *clnt, + struct xprt_create *args, + const struct rpc_timeout *timeout) +{ + const struct rpc_timeout *old_timeo; + rpc_authflavor_t pseudoflavor; + struct rpc_xprt_switch *xps, *oldxps; + struct rpc_xprt *xprt, *old; + struct rpc_clnt *parent; + int err; + + xprt = xprt_create_transport(args); + if (IS_ERR(xprt)) + return PTR_ERR(xprt); + + xps = xprt_switch_alloc(xprt, GFP_KERNEL); + if (xps == NULL) { + xprt_put(xprt); + return -ENOMEM; + } + + pseudoflavor = clnt->cl_auth->au_flavor; + + old_timeo = clnt->cl_timeout; + old = rpc_clnt_set_transport(clnt, xprt, timeout); + oldxps = xprt_iter_xchg_switch(&clnt->cl_xpi, xps); + + rpc_unregister_client(clnt); + __rpc_clnt_remove_pipedir(clnt); + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); + + /* + * A new transport was created. "clnt" therefore + * becomes the root of a new cl_parent tree. clnt's + * children, if it has any, still point to the old xprt. + */ + parent = clnt->cl_parent; + clnt->cl_parent = clnt; + + /* + * The old rpc_auth cache cannot be re-used. GSS + * contexts in particular are between a single + * client and server. + */ + err = rpc_client_register(clnt, pseudoflavor, NULL); + if (err) + goto out_revert; + + synchronize_rcu(); + if (parent != clnt) + rpc_release_client(parent); + xprt_switch_put(oldxps); + xprt_put(old); + trace_rpc_clnt_replace_xprt(clnt); + return 0; + +out_revert: + xps = xprt_iter_xchg_switch(&clnt->cl_xpi, oldxps); + rpc_clnt_set_transport(clnt, old, old_timeo); + clnt->cl_parent = parent; + rpc_client_register(clnt, pseudoflavor, NULL); + xprt_switch_put(xps); + xprt_put(xprt); + trace_rpc_clnt_replace_xprt_err(clnt); + return err; +} +EXPORT_SYMBOL_GPL(rpc_switch_client_transport); + +static +int _rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi, + void func(struct rpc_xprt_iter *xpi, struct rpc_xprt_switch *xps)) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + if (xps == NULL) + return -EAGAIN; + func(xpi, xps); + xprt_switch_put(xps); + return 0; +} + +static +int rpc_clnt_xprt_iter_init(struct rpc_clnt *clnt, struct rpc_xprt_iter *xpi) +{ + return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listall); +} + +static +int rpc_clnt_xprt_iter_offline_init(struct rpc_clnt *clnt, + struct rpc_xprt_iter *xpi) +{ + return _rpc_clnt_xprt_iter_init(clnt, xpi, xprt_iter_init_listoffline); +} + +/** + * rpc_clnt_iterate_for_each_xprt - Apply a function to all transports + * @clnt: pointer to client + * @fn: function to apply + * @data: void pointer to function data + * + * Iterates through the list of RPC transports currently attached to the + * client and applies the function fn(clnt, xprt, data). + * + * On error, the iteration stops, and the function returns the error value. + */ +int rpc_clnt_iterate_for_each_xprt(struct rpc_clnt *clnt, + int (*fn)(struct rpc_clnt *, struct rpc_xprt *, void *), + void *data) +{ + struct rpc_xprt_iter xpi; + int ret; + + ret = rpc_clnt_xprt_iter_init(clnt, &xpi); + if (ret) + return ret; + for (;;) { + struct rpc_xprt *xprt = xprt_iter_get_next(&xpi); + + if (!xprt) + break; + ret = fn(clnt, xprt, data); + xprt_put(xprt); + if (ret < 0) + break; + } + xprt_iter_destroy(&xpi); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_iterate_for_each_xprt); + +/* + * Kill all tasks for the given client. + * XXX: kill their descendants as well? 
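+ *
+ * Added note: this only signals the tasks, it does not wait for them.
+ * rpc_shutdown_client() below shows the usual pattern of calling it in
+ * a loop until cl_tasks drains:
+ *
+ *	while (!list_empty(&clnt->cl_tasks)) {
+ *		rpc_killall_tasks(clnt);
+ *		wait_event_timeout(destroy_wait,
+ *				   list_empty(&clnt->cl_tasks), 1*HZ);
+ *	}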
+ */ +void rpc_killall_tasks(struct rpc_clnt *clnt) +{ + struct rpc_task *rovr; + + + if (list_empty(&clnt->cl_tasks)) + return; + + /* + * Spin lock all_tasks to prevent changes... + */ + trace_rpc_clnt_killall(clnt); + spin_lock(&clnt->cl_lock); + list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) + rpc_signal_task(rovr); + spin_unlock(&clnt->cl_lock); +} +EXPORT_SYMBOL_GPL(rpc_killall_tasks); + +/** + * rpc_cancel_tasks - try to cancel a set of RPC tasks + * @clnt: Pointer to RPC client + * @error: RPC task error value to set + * @fnmatch: Pointer to selector function + * @data: User data + * + * Uses @fnmatch to define a set of RPC tasks that are to be cancelled. + * The argument @error must be a negative error value. + */ +unsigned long rpc_cancel_tasks(struct rpc_clnt *clnt, int error, + bool (*fnmatch)(const struct rpc_task *, + const void *), + const void *data) +{ + struct rpc_task *task; + unsigned long count = 0; + + if (list_empty(&clnt->cl_tasks)) + return 0; + /* + * Spin lock all_tasks to prevent changes... + */ + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) { + if (!RPC_IS_ACTIVATED(task)) + continue; + if (!fnmatch(task, data)) + continue; + rpc_task_try_cancel(task, error); + count++; + } + spin_unlock(&clnt->cl_lock); + return count; +} +EXPORT_SYMBOL_GPL(rpc_cancel_tasks); + +static int rpc_clnt_disconnect_xprt(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, void *dummy) +{ + if (xprt_connected(xprt)) + xprt_force_disconnect(xprt); + return 0; +} + +void rpc_clnt_disconnect(struct rpc_clnt *clnt) +{ + rpc_clnt_iterate_for_each_xprt(clnt, rpc_clnt_disconnect_xprt, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_disconnect); + +/* + * Properly shut down an RPC client, terminating all outstanding + * requests. + */ +void rpc_shutdown_client(struct rpc_clnt *clnt) +{ + might_sleep(); + + trace_rpc_clnt_shutdown(clnt); + + while (!list_empty(&clnt->cl_tasks)) { + rpc_killall_tasks(clnt); + wait_event_timeout(destroy_wait, + list_empty(&clnt->cl_tasks), 1*HZ); + } + + rpc_release_client(clnt); +} +EXPORT_SYMBOL_GPL(rpc_shutdown_client); + +/* + * Free an RPC client + */ +static void rpc_free_client_work(struct work_struct *work) +{ + struct rpc_clnt *clnt = container_of(work, struct rpc_clnt, cl_work); + + trace_rpc_clnt_free(clnt); + + /* These might block on processes that might allocate memory, + * so they cannot be called in rpciod, so they are handled separately + * here. + */ + rpc_sysfs_client_destroy(clnt); + rpc_clnt_debugfs_unregister(clnt); + rpc_free_clid(clnt); + rpc_clnt_remove_pipedir(clnt); + xprt_put(rcu_dereference_raw(clnt->cl_xprt)); + + kfree(clnt); + rpciod_down(); +} +static struct rpc_clnt * +rpc_free_client(struct rpc_clnt *clnt) +{ + struct rpc_clnt *parent = NULL; + + trace_rpc_clnt_release(clnt); + if (clnt->cl_parent != clnt) + parent = clnt->cl_parent; + rpc_unregister_client(clnt); + rpc_free_iostats(clnt->cl_metrics); + clnt->cl_metrics = NULL; + xprt_iter_destroy(&clnt->cl_xpi); + put_cred(clnt->cl_cred); + + INIT_WORK(&clnt->cl_work, rpc_free_client_work); + schedule_work(&clnt->cl_work); + return parent; +} + +/* + * Free an RPC client + */ +static struct rpc_clnt * +rpc_free_auth(struct rpc_clnt *clnt) +{ + /* + * Note: RPCSEC_GSS may need to send NULL RPC calls in order to + * release remaining GSS contexts. This mechanism ensures + * that it can do so safely. 
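+ *
+ * (Added clarification: rpcauth_release() runs while we still hold the
+ * last reference on the client, so any NULL RPC calls it needs to send
+ * can still use this rpc_clnt; only afterwards is the final reference
+ * dropped and the client freed below.)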
+ */ + if (clnt->cl_auth != NULL) { + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = NULL; + } + if (refcount_dec_and_test(&clnt->cl_count)) + return rpc_free_client(clnt); + return NULL; +} + +/* + * Release reference to the RPC client + */ +void +rpc_release_client(struct rpc_clnt *clnt) +{ + do { + if (list_empty(&clnt->cl_tasks)) + wake_up(&destroy_wait); + if (refcount_dec_not_one(&clnt->cl_count)) + break; + clnt = rpc_free_auth(clnt); + } while (clnt != NULL); +} +EXPORT_SYMBOL_GPL(rpc_release_client); + +/** + * rpc_bind_new_program - bind a new RPC program to an existing client + * @old: old rpc_client + * @program: rpc program to set + * @vers: rpc program version + * + * Clones the rpc client and sets up a new RPC program. This is mainly + * of use for enabling different RPC programs to share the same transport. + * The Sun NFSv2/v3 ACL protocol can do this. + */ +struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, + const struct rpc_program *program, + u32 vers) +{ + struct rpc_create_args args = { + .program = program, + .prognumber = program->number, + .version = vers, + .authflavor = old->cl_auth->au_flavor, + .cred = old->cl_cred, + }; + struct rpc_clnt *clnt; + int err; + + clnt = __rpc_clone_client(&args, old); + if (IS_ERR(clnt)) + goto out; + err = rpc_ping(clnt); + if (err != 0) { + rpc_shutdown_client(clnt); + clnt = ERR_PTR(err); + } +out: + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_bind_new_program); + +struct rpc_xprt * +rpc_task_get_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + if (!xprt) + return NULL; + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + atomic_long_inc(&xps->xps_queuelen); + rcu_read_unlock(); + atomic_long_inc(&xprt->queuelen); + + return xprt; +} + +static void +rpc_task_release_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + atomic_long_dec(&xprt->queuelen); + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + atomic_long_dec(&xps->xps_queuelen); + rcu_read_unlock(); + + xprt_put(xprt); +} + +void rpc_task_release_transport(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + if (xprt) { + task->tk_xprt = NULL; + if (task->tk_client) + rpc_task_release_xprt(task->tk_client, xprt); + else + xprt_put(xprt); + } +} +EXPORT_SYMBOL_GPL(rpc_task_release_transport); + +void rpc_task_release_client(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + rpc_task_release_transport(task); + if (clnt != NULL) { + /* Remove from client task list */ + spin_lock(&clnt->cl_lock); + list_del(&task->tk_task); + spin_unlock(&clnt->cl_lock); + task->tk_client = NULL; + + rpc_release_client(clnt); + } +} + +static struct rpc_xprt * +rpc_task_get_first_xprt(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + return rpc_task_get_xprt(clnt, xprt); +} + +static struct rpc_xprt * +rpc_task_get_next_xprt(struct rpc_clnt *clnt) +{ + return rpc_task_get_xprt(clnt, xprt_iter_get_next(&clnt->cl_xpi)); +} + +static +void rpc_task_set_transport(struct rpc_task *task, struct rpc_clnt *clnt) +{ + if (task->tk_xprt) { + if (!(test_bit(XPRT_OFFLINE, &task->tk_xprt->state) && + (task->tk_flags & RPC_TASK_MOVEABLE))) + return; + xprt_release(task); + xprt_put(task->tk_xprt); + } + if (task->tk_flags & RPC_TASK_NO_ROUND_ROBIN) + task->tk_xprt = rpc_task_get_first_xprt(clnt); + else + task->tk_xprt = 
rpc_task_get_next_xprt(clnt); +} + +static +void rpc_task_set_client(struct rpc_task *task, struct rpc_clnt *clnt) +{ + rpc_task_set_transport(task, clnt); + task->tk_client = clnt; + refcount_inc(&clnt->cl_count); + if (clnt->cl_softrtry) + task->tk_flags |= RPC_TASK_SOFT; + if (clnt->cl_softerr) + task->tk_flags |= RPC_TASK_TIMEOUT; + if (clnt->cl_noretranstimeo) + task->tk_flags |= RPC_TASK_NO_RETRANS_TIMEOUT; + /* Add to the client's list of all tasks */ + spin_lock(&clnt->cl_lock); + list_add_tail(&task->tk_task, &clnt->cl_tasks); + spin_unlock(&clnt->cl_lock); +} + +static void +rpc_task_set_rpc_message(struct rpc_task *task, const struct rpc_message *msg) +{ + if (msg != NULL) { + task->tk_msg.rpc_proc = msg->rpc_proc; + task->tk_msg.rpc_argp = msg->rpc_argp; + task->tk_msg.rpc_resp = msg->rpc_resp; + task->tk_msg.rpc_cred = msg->rpc_cred; + if (!(task->tk_flags & RPC_TASK_CRED_NOREF)) + get_cred(task->tk_msg.rpc_cred); + } +} + +/* + * Default callback for async RPC calls + */ +static void +rpc_default_callback(struct rpc_task *task, void *data) +{ +} + +static const struct rpc_call_ops rpc_default_ops = { + .rpc_call_done = rpc_default_callback, +}; + +/** + * rpc_run_task - Allocate a new RPC task, then run rpc_execute against it + * @task_setup_data: pointer to task initialisation data + */ +struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) +{ + struct rpc_task *task; + + task = rpc_new_task(task_setup_data); + if (IS_ERR(task)) + return task; + + if (!RPC_IS_ASYNC(task)) + task->tk_flags |= RPC_TASK_CRED_NOREF; + + rpc_task_set_client(task, task_setup_data->rpc_client); + rpc_task_set_rpc_message(task, task_setup_data->rpc_message); + + if (task->tk_action == NULL) + rpc_call_start(task); + + atomic_inc(&task->tk_count); + rpc_execute(task); + return task; +} +EXPORT_SYMBOL_GPL(rpc_run_task); + +/** + * rpc_call_sync - Perform a synchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + */ +int rpc_call_sync(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = &rpc_default_ops, + .flags = flags, + }; + int status; + + WARN_ON_ONCE(flags & RPC_TASK_ASYNC); + if (flags & RPC_TASK_ASYNC) { + rpc_release_calldata(task_setup_data.callback_ops, + task_setup_data.callback_data); + return -EINVAL; + } + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} +EXPORT_SYMBOL_GPL(rpc_call_sync); + +/** + * rpc_call_async - Perform an asynchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + * @tk_ops: RPC call ops + * @data: user call data + */ +int +rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, + const struct rpc_call_ops *tk_ops, void *data) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = tk_ops, + .callback_data = data, + .flags = flags|RPC_TASK_ASYNC, + }; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + rpc_put_task(task); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_call_async); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_encode(struct rpc_task *task); + +/** + * rpc_run_bc_task - Allocate a new RPC task for backchannel use, then run + * 
rpc_execute against it + * @req: RPC request + */ +struct rpc_task *rpc_run_bc_task(struct rpc_rqst *req) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .callback_ops = &rpc_default_ops, + .flags = RPC_TASK_SOFTCONN | + RPC_TASK_NO_RETRANS_TIMEOUT, + }; + + dprintk("RPC: rpc_run_bc_task req= %p\n", req); + /* + * Create an rpc_task to send the data + */ + task = rpc_new_task(&task_setup_data); + if (IS_ERR(task)) { + xprt_free_bc_request(req); + return task; + } + + xprt_init_bc_request(req, task); + + task->tk_action = call_bc_encode; + atomic_inc(&task->tk_count); + WARN_ON_ONCE(atomic_read(&task->tk_count) != 2); + rpc_execute(task); + + dprintk("RPC: rpc_run_bc_task: task= %p\n", task); + return task; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/** + * rpc_prepare_reply_pages - Prepare to receive a reply data payload into pages + * @req: RPC request to prepare + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * @hdrsize: expected size of upper layer reply header, in XDR words + * + */ +void rpc_prepare_reply_pages(struct rpc_rqst *req, struct page **pages, + unsigned int base, unsigned int len, + unsigned int hdrsize) +{ + hdrsize += RPC_REPHDRSIZE + req->rq_cred->cr_auth->au_ralign; + + xdr_inline_pages(&req->rq_rcv_buf, hdrsize << 2, pages, base, len); + trace_rpc_xdr_reply_pages(req->rq_task, &req->rq_rcv_buf); +} +EXPORT_SYMBOL_GPL(rpc_prepare_reply_pages); + +void +rpc_call_start(struct rpc_task *task) +{ + task->tk_action = call_start; +} +EXPORT_SYMBOL_GPL(rpc_call_start); + +/** + * rpc_peeraddr - extract remote peer address from clnt's xprt + * @clnt: RPC client structure + * @buf: target buffer + * @bufsize: length of target buffer + * + * Returns the number of bytes that are actually in the stored address. + */ +size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize) +{ + size_t bytes; + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + + bytes = xprt->addrlen; + if (bytes > bufsize) + bytes = bufsize; + memcpy(buf, &xprt->addr, bytes); + rcu_read_unlock(); + + return bytes; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr); + +/** + * rpc_peeraddr2str - return remote peer address in printable format + * @clnt: RPC client structure + * @format: address format + * + * NB: the lifetime of the memory referenced by the returned pointer is + * the same as the rpc_xprt itself. As long as the caller uses this + * pointer, it must hold the RCU read lock. + */ +const char *rpc_peeraddr2str(struct rpc_clnt *clnt, + enum rpc_display_format_t format) +{ + struct rpc_xprt *xprt; + + xprt = rcu_dereference(clnt->cl_xprt); + + if (xprt->address_strings[format] != NULL) + return xprt->address_strings[format]; + else + return "unprintable"; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr2str); + +static const struct sockaddr_in rpc_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), +}; + +static const struct sockaddr_in6 rpc_in6addr_loopback = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, +}; + +/* + * Try a getsockname() on a connected datagram socket. Using a + * connected datagram socket prevents leaving a socket in TIME_WAIT. + * This conserves the ephemeral port number space. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. 
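+ *
+ * (Added note: the caller in this file is rpc_localaddr() below, which
+ * copies the transport's destination address, clears the port with
+ * rpc_set_port(), and falls back to rpc_anyaddr() if this probe fails.)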
+ */ +static int rpc_sockname(struct net *net, struct sockaddr *sap, size_t salen, + struct sockaddr *buf) +{ + struct socket *sock; + int err; + + err = __sock_create(net, sap->sa_family, + SOCK_DGRAM, IPPROTO_UDP, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create UDP socket (%d)\n", err); + goto out; + } + + switch (sap->sa_family) { + case AF_INET: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + err = kernel_bind(sock, + (struct sockaddr *)&rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + err = -EAFNOSUPPORT; + goto out_release; + } + if (err < 0) { + dprintk("RPC: can't bind UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_connect(sock, sap, salen, 0); + if (err < 0) { + dprintk("RPC: can't connect UDP socket (%d)\n", err); + goto out_release; + } + + err = kernel_getsockname(sock, buf); + if (err < 0) { + dprintk("RPC: getsockname failed (%d)\n", err); + goto out_release; + } + + err = 0; + if (buf->sa_family == AF_INET6) { + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)buf; + sin6->sin6_scope_id = 0; + } + dprintk("RPC: %s succeeded\n", __func__); + +out_release: + sock_release(sock); +out: + return err; +} + +/* + * Scraping a connected socket failed, so we don't have a useable + * local address. Fallback: generate an address that will prevent + * the server from calling us back. + * + * Returns zero and fills in "buf" if successful; otherwise, a + * negative errno is returned. + */ +static int rpc_anyaddr(int family, struct sockaddr *buf, size_t buflen) +{ + switch (family) { + case AF_INET: + if (buflen < sizeof(rpc_inaddr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_inaddr_loopback, + sizeof(rpc_inaddr_loopback)); + break; + case AF_INET6: + if (buflen < sizeof(rpc_in6addr_loopback)) + return -EINVAL; + memcpy(buf, &rpc_in6addr_loopback, + sizeof(rpc_in6addr_loopback)); + break; + default: + dprintk("RPC: %s: address family not supported\n", + __func__); + return -EAFNOSUPPORT; + } + dprintk("RPC: %s: succeeded\n", __func__); + return 0; +} + +/** + * rpc_localaddr - discover local endpoint address for an RPC client + * @clnt: RPC client structure + * @buf: target buffer + * @buflen: size of target buffer, in bytes + * + * Returns zero and fills in "buf" and "buflen" if successful; + * otherwise, a negative errno is returned. + * + * This works even if the underlying transport is not currently connected, + * or if the upper layer never previously provided a source address. + * + * The result of this function call is transient: multiple calls in + * succession may give different results, depending on how local + * networking configuration changes over time. 
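+ *
+ * Minimal usage sketch (illustrative only):
+ *
+ *	struct sockaddr_storage ss;
+ *	int err;
+ *
+ *	err = rpc_localaddr(clnt, (struct sockaddr *)&ss, sizeof(ss));
+ *	if (err)
+ *		return err;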
+ */ +int rpc_localaddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t buflen) +{ + struct sockaddr_storage address; + struct sockaddr *sap = (struct sockaddr *)&address; + struct rpc_xprt *xprt; + struct net *net; + size_t salen; + int err; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + salen = xprt->addrlen; + memcpy(sap, &xprt->addr, salen); + net = get_net(xprt->xprt_net); + rcu_read_unlock(); + + rpc_set_port(sap, 0); + err = rpc_sockname(net, sap, salen, buf); + put_net(net); + if (err != 0) + /* Couldn't discover local address, return ANYADDR */ + return rpc_anyaddr(sap->sa_family, buf, buflen); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_localaddr); + +void +rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + if (xprt->ops->set_buffer_size) + xprt->ops->set_buffer_size(xprt, sndsize, rcvsize); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_setbufsize); + +/** + * rpc_net_ns - Get the network namespace for this RPC client + * @clnt: RPC client to query + * + */ +struct net *rpc_net_ns(struct rpc_clnt *clnt) +{ + struct net *ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->xprt_net; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_net_ns); + +/** + * rpc_max_payload - Get maximum payload size for a transport, in bytes + * @clnt: RPC client to query + * + * For stream transports, this is one RPC record fragment (see RFC + * 1831), as we don't support multi-record requests yet. For datagram + * transports, this is the size of an IP packet minus the IP, UDP, and + * RPC header sizes. + */ +size_t rpc_max_payload(struct rpc_clnt *clnt) +{ + size_t ret; + + rcu_read_lock(); + ret = rcu_dereference(clnt->cl_xprt)->max_payload; + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_max_payload); + +/** + * rpc_max_bc_payload - Get maximum backchannel payload size, in bytes + * @clnt: RPC client to query + */ +size_t rpc_max_bc_payload(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + size_t ret; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + ret = xprt->ops->bc_maxpayload(xprt); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_max_bc_payload); + +unsigned int rpc_num_bc_slots(struct rpc_clnt *clnt) +{ + struct rpc_xprt *xprt; + unsigned int ret; + + rcu_read_lock(); + xprt = rcu_dereference(clnt->cl_xprt); + ret = xprt->ops->bc_num_slots(xprt); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_num_bc_slots); + +/** + * rpc_force_rebind - force transport to check that remote port is unchanged + * @clnt: client to rebind + * + */ +void rpc_force_rebind(struct rpc_clnt *clnt) +{ + if (clnt->cl_autobind) { + rcu_read_lock(); + xprt_clear_bound(rcu_dereference(clnt->cl_xprt)); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL_GPL(rpc_force_rebind); + +static int +__rpc_restart_call(struct rpc_task *task, void (*action)(struct rpc_task *)) +{ + task->tk_status = 0; + task->tk_rpc_status = 0; + task->tk_action = action; + return 1; +} + +/* + * Restart an (async) RPC call. Usually called from within the + * exit handler. + */ +int +rpc_restart_call(struct rpc_task *task) +{ + return __rpc_restart_call(task, call_start); +} +EXPORT_SYMBOL_GPL(rpc_restart_call); + +/* + * Restart an (async) RPC call from the call_prepare state. + * Usually called from within the exit handler. 
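+ *
+ * Hypothetical pattern (not from this file): an rpc_call_done callback
+ * that wants the request resent from the prepare stage might do
+ *
+ *	static void example_call_done(struct rpc_task *task, void *calldata)
+ *	{
+ *		if (task->tk_status == -EAGAIN) {
+ *			rpc_restart_call_prepare(task);
+ *			return;
+ *		}
+ *		...
+ *	}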
+ */ +int +rpc_restart_call_prepare(struct rpc_task *task) +{ + if (task->tk_ops->rpc_call_prepare != NULL) + return __rpc_restart_call(task, rpc_prepare_task); + return rpc_restart_call(task); +} +EXPORT_SYMBOL_GPL(rpc_restart_call_prepare); + +const char +*rpc_proc_name(const struct rpc_task *task) +{ + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + + if (proc) { + if (proc->p_name) + return proc->p_name; + else + return "NULL"; + } else + return "no proc"; +} + +static void +__rpc_call_rpcerror(struct rpc_task *task, int tk_status, int rpc_status) +{ + trace_rpc_call_rpcerror(task, tk_status, rpc_status); + rpc_task_set_rpc_status(task, rpc_status); + rpc_exit(task, tk_status); +} + +static void +rpc_call_rpcerror(struct rpc_task *task, int status) +{ + __rpc_call_rpcerror(task, status, status); +} + +/* + * 0. Initial state + * + * Other FSM states can be visited zero or more times, but + * this state is visited exactly once for each RPC. + */ +static void +call_start(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + int idx = task->tk_msg.rpc_proc->p_statidx; + + trace_rpc_request(task); + + /* Increment call count (version might not be valid for ping) */ + if (clnt->cl_program->version[clnt->cl_vers]) + clnt->cl_program->version[clnt->cl_vers]->counts[idx]++; + clnt->cl_stats->rpccnt++; + task->tk_action = call_reserve; + rpc_task_set_transport(task, clnt); +} + +/* + * 1. Reserve an RPC call slot + */ +static void +call_reserve(struct rpc_task *task) +{ + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_reserve(task); +} + +static void call_retry_reserve(struct rpc_task *task); + +/* + * 1b. Grok the result of xprt_reserve() + */ +static void +call_reserveresult(struct rpc_task *task) +{ + int status = task->tk_status; + + /* + * After a call to xprt_reserve(), we must have either + * a request slot or else an error status. + */ + task->tk_status = 0; + if (status >= 0) { + if (task->tk_rqstp) { + task->tk_action = call_refresh; + return; + } + + rpc_call_rpcerror(task, -EIO); + return; + } + + switch (status) { + case -ENOMEM: + rpc_delay(task, HZ >> 2); + fallthrough; + case -EAGAIN: /* woken up; retry */ + task->tk_action = call_retry_reserve; + return; + default: + rpc_call_rpcerror(task, status); + } +} + +/* + * 1c. Retry reserving an RPC call slot + */ +static void +call_retry_reserve(struct rpc_task *task) +{ + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_retry_reserve(task); +} + +/* + * 2. Bind and/or refresh the credentials + */ +static void +call_refresh(struct rpc_task *task) +{ + task->tk_action = call_refreshresult; + task->tk_status = 0; + task->tk_client->cl_stats->rpcauthrefresh++; + rpcauth_refreshcred(task); +} + +/* + * 2a. Process the results of a credential refresh + */ +static void +call_refreshresult(struct rpc_task *task) +{ + int status = task->tk_status; + + task->tk_status = 0; + task->tk_action = call_refresh; + switch (status) { + case 0: + if (rpcauth_uptodatecred(task)) { + task->tk_action = call_allocate; + return; + } + /* Use rate-limiting and a max number of retries if refresh + * had status 0 but failed to update the cred. 
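+ *
+ * (Added note: falling through below converts that case to -EACCES
+ * after a delay, so a credential that never becomes up to date is
+ * retried at most tk_cred_retry times and then fails the call with
+ * -EACCES.)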
+ */ + fallthrough; + case -ETIMEDOUT: + rpc_delay(task, 3*HZ); + fallthrough; + case -EAGAIN: + status = -EACCES; + fallthrough; + case -EKEYEXPIRED: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + trace_rpc_retry_refresh_status(task); + return; + case -ENOMEM: + rpc_delay(task, HZ >> 4); + return; + } + trace_rpc_refresh_status(task); + rpc_call_rpcerror(task, status); +} + +/* + * 2b. Allocate the buffer. For details, see sched.c:rpc_malloc. + * (Note: buffer memory is freed in xprt_release). + */ +static void +call_allocate(struct rpc_task *task) +{ + const struct rpc_auth *auth = task->tk_rqstp->rq_cred->cr_auth; + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + int status; + + task->tk_status = 0; + task->tk_action = call_encode; + + if (req->rq_buffer) + return; + + if (proc->p_proc != 0) { + BUG_ON(proc->p_arglen == 0); + if (proc->p_decode != NULL) + BUG_ON(proc->p_replen == 0); + } + + /* + * Calculate the size (in quads) of the RPC call + * and reply headers, and convert both values + * to byte sizes. + */ + req->rq_callsize = RPC_CALLHDRSIZE + (auth->au_cslack << 1) + + proc->p_arglen; + req->rq_callsize <<= 2; + /* + * Note: the reply buffer must at minimum allocate enough space + * for the 'struct accepted_reply' from RFC5531. + */ + req->rq_rcvsize = RPC_REPHDRSIZE + auth->au_rslack + \ + max_t(size_t, proc->p_replen, 2); + req->rq_rcvsize <<= 2; + + status = xprt->ops->buf_alloc(task); + trace_rpc_buf_alloc(task, status); + if (status == 0) + return; + if (status != -ENOMEM) { + rpc_call_rpcerror(task, status); + return; + } + + if (RPC_IS_ASYNC(task) || !fatal_signal_pending(current)) { + task->tk_action = call_allocate; + rpc_delay(task, HZ>>4); + return; + } + + rpc_call_rpcerror(task, -ERESTARTSYS); +} + +static int +rpc_task_need_encode(struct rpc_task *task) +{ + return test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) == 0 && + (!(task->tk_flags & RPC_TASK_SENT) || + !(task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) || + xprt_request_need_retransmit(task)); +} + +static void +rpc_xdr_encode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct xdr_stream xdr; + + xdr_buf_init(&req->rq_snd_buf, + req->rq_buffer, + req->rq_callsize); + xdr_buf_init(&req->rq_rcv_buf, + req->rq_rbuffer, + req->rq_rcvsize); + + req->rq_reply_bytes_recvd = 0; + req->rq_snd_buf.head[0].iov_len = 0; + xdr_init_encode(&xdr, &req->rq_snd_buf, + req->rq_snd_buf.head[0].iov_base, req); + if (rpc_encode_header(task, &xdr)) + return; + + task->tk_status = rpcauth_wrap_req(task, &xdr); +} + +/* + * 3. Encode arguments of an RPC call + */ +static void +call_encode(struct rpc_task *task) +{ + if (!rpc_task_need_encode(task)) + goto out; + + /* Dequeue task from the receive queue while we're encoding */ + xprt_request_dequeue_xprt(task); + /* Encode here so that rpcsec_gss can use correct sequence number. */ + rpc_xdr_encode(task); + /* Add task to reply queue before transmission to avoid races */ + if (task->tk_status == 0 && rpc_reply_expected(task)) + task->tk_status = xprt_request_enqueue_receive(task); + /* Did the encode result in an error condition? */ + if (task->tk_status != 0) { + /* Was the error nonfatal? 
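+ *
+ * (Added note: -EAGAIN and -ENOMEM below just delay and re-enter
+ * call_encode, -EKEYEXPIRED retries the credential refresh while
+ * tk_cred_retry lasts, and anything else terminates the RPC.)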
*/ + switch (task->tk_status) { + case -EAGAIN: + case -ENOMEM: + rpc_delay(task, HZ >> 4); + break; + case -EKEYEXPIRED: + if (!task->tk_cred_retry) { + rpc_call_rpcerror(task, task->tk_status); + } else { + task->tk_action = call_refresh; + task->tk_cred_retry--; + trace_rpc_retry_refresh_status(task); + } + break; + default: + rpc_call_rpcerror(task, task->tk_status); + } + return; + } + + xprt_request_enqueue_transmit(task); +out: + task->tk_action = call_transmit; + /* Check that the connection is OK */ + if (!xprt_bound(task->tk_xprt)) + task->tk_action = call_bind; + else if (!xprt_connected(task->tk_xprt)) + task->tk_action = call_connect; +} + +/* + * Helpers to check if the task was already transmitted, and + * to take action when that is the case. + */ +static bool +rpc_task_transmitted(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} + +static void +rpc_task_handle_transmitted(struct rpc_task *task) +{ + xprt_end_transmit(task); + task->tk_action = call_transmit_status; +} + +/* + * 4. Get the server port number if not yet set + */ +static void +call_bind(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + if (xprt_bound(xprt)) { + task->tk_action = call_connect; + return; + } + + task->tk_action = call_bind_status; + if (!xprt_prepare_transmit(task)) + return; + + xprt->ops->rpcbind(task); +} + +/* + * 4a. Sort out bind result + */ +static void +call_bind_status(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + int status = -EIO; + + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + if (task->tk_status >= 0) + goto out_next; + if (xprt_bound(xprt)) { + task->tk_status = 0; + goto out_next; + } + + switch (task->tk_status) { + case -ENOMEM: + rpc_delay(task, HZ >> 2); + goto retry_timeout; + case -EACCES: + trace_rpcb_prog_unavail_err(task); + /* fail immediately if this is an RPC ping */ + if (task->tk_msg.rpc_proc->p_proc == 0) { + status = -EOPNOTSUPP; + break; + } + rpc_delay(task, 3*HZ); + goto retry_timeout; + case -ENOBUFS: + rpc_delay(task, HZ >> 2); + goto retry_timeout; + case -EAGAIN: + goto retry_timeout; + case -ETIMEDOUT: + trace_rpcb_timeout_err(task); + goto retry_timeout; + case -EPFNOSUPPORT: + /* server doesn't support any rpcbind version we know of */ + trace_rpcb_bind_version_err(task); + break; + case -EPROTONOSUPPORT: + trace_rpcb_bind_version_err(task); + goto retry_timeout; + case -ECONNREFUSED: /* connection problems */ + case -ECONNRESET: + case -ECONNABORTED: + case -ENOTCONN: + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -EPIPE: + trace_rpcb_unreachable_err(task); + if (!RPC_IS_SOFTCONN(task)) { + rpc_delay(task, 5*HZ); + goto retry_timeout; + } + status = task->tk_status; + break; + default: + trace_rpcb_unrecognized_err(task); + } + + rpc_call_rpcerror(task, status); + return; +out_next: + task->tk_action = call_connect; + return; +retry_timeout: + task->tk_status = 0; + task->tk_action = call_bind; + rpc_check_timeout(task); +} + +/* + * 4b. 
Connect to the RPC server + */ +static void +call_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + if (xprt_connected(xprt)) { + task->tk_action = call_transmit; + return; + } + + task->tk_action = call_connect_status; + if (task->tk_status < 0) + return; + if (task->tk_flags & RPC_TASK_NOCONNECT) { + rpc_call_rpcerror(task, -ENOTCONN); + return; + } + if (!xprt_prepare_transmit(task)) + return; + xprt_connect(task); +} + +/* + * 4c. Sort out connect result + */ +static void +call_connect_status(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + struct rpc_clnt *clnt = task->tk_client; + int status = task->tk_status; + + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + trace_rpc_connect_status(task); + + if (task->tk_status == 0) { + clnt->cl_stats->netreconn++; + goto out_next; + } + if (xprt_connected(xprt)) { + task->tk_status = 0; + goto out_next; + } + + task->tk_status = 0; + switch (status) { + case -ECONNREFUSED: + case -ECONNRESET: + /* A positive refusal suggests a rebind is needed. */ + if (RPC_IS_SOFTCONN(task)) + break; + if (clnt->cl_autobind) { + rpc_force_rebind(clnt); + goto out_retry; + } + fallthrough; + case -ECONNABORTED: + case -ENETDOWN: + case -ENETUNREACH: + case -EHOSTUNREACH: + case -EPIPE: + case -EPROTO: + xprt_conditional_disconnect(task->tk_rqstp->rq_xprt, + task->tk_rqstp->rq_connect_cookie); + if (RPC_IS_SOFTCONN(task)) + break; + /* retry with existing socket, after a delay */ + rpc_delay(task, 3*HZ); + fallthrough; + case -EADDRINUSE: + case -ENOTCONN: + case -EAGAIN: + case -ETIMEDOUT: + if (!(task->tk_flags & RPC_TASK_NO_ROUND_ROBIN) && + (task->tk_flags & RPC_TASK_MOVEABLE) && + test_bit(XPRT_REMOVE, &xprt->state)) { + struct rpc_xprt *saved = task->tk_xprt; + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); + if (xps->xps_nxprts > 1) { + long value; + + xprt_release(task); + value = atomic_long_dec_return(&xprt->queuelen); + if (value == 0) + rpc_xprt_switch_remove_xprt(xps, saved, + true); + xprt_put(saved); + task->tk_xprt = NULL; + task->tk_action = call_start; + } + xprt_switch_put(xps); + if (!task->tk_xprt) + return; + } + goto out_retry; + case -ENOBUFS: + rpc_delay(task, HZ >> 2); + goto out_retry; + } + rpc_call_rpcerror(task, status); + return; +out_next: + task->tk_action = call_transmit; + return; +out_retry: + /* Check for timeouts before looping back to call_bind */ + task->tk_action = call_bind; + rpc_check_timeout(task); +} + +/* + * 5. Transmit the RPC request, and wait for reply + */ +static void +call_transmit(struct rpc_task *task) +{ + if (rpc_task_transmitted(task)) { + rpc_task_handle_transmitted(task); + return; + } + + task->tk_action = call_transmit_status; + if (!xprt_prepare_transmit(task)) + return; + task->tk_status = 0; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_connected(task->tk_xprt)) { + task->tk_status = -ENOTCONN; + return; + } + xprt_transmit(task); + } + xprt_end_transmit(task); +} + +/* + * 5a. Handle cleanup after a transmission + */ +static void +call_transmit_status(struct rpc_task *task) +{ + task->tk_action = call_status; + + /* + * Common case: success. Force the compiler to put this + * test first. 
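+ *
+ * (Added note: rpc_task_transmitted() only tests RPC_TASK_NEED_XMIT, so
+ * the common case reduces to clearing tk_status and queueing the task
+ * to wait for the reply.)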
+ */ + if (rpc_task_transmitted(task)) { + task->tk_status = 0; + xprt_request_wait_receive(task); + return; + } + + switch (task->tk_status) { + default: + break; + case -EBADMSG: + task->tk_status = 0; + task->tk_action = call_encode; + break; + /* + * Special cases: if we've been waiting on the + * socket's write_space() callback, or if the + * socket just returned a connection error, + * then hold onto the transport lock. + */ + case -ENOMEM: + case -ENOBUFS: + rpc_delay(task, HZ>>2); + fallthrough; + case -EBADSLT: + case -EAGAIN: + task->tk_action = call_transmit; + task->tk_status = 0; + break; + case -ECONNREFUSED: + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -EPERM: + if (RPC_IS_SOFTCONN(task)) { + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, + task->tk_status); + rpc_call_rpcerror(task, task->tk_status); + return; + } + fallthrough; + case -ECONNRESET: + case -ECONNABORTED: + case -EADDRINUSE: + case -ENOTCONN: + case -EPIPE: + task->tk_action = call_bind; + task->tk_status = 0; + break; + } + rpc_check_timeout(task); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void call_bc_transmit(struct rpc_task *task); +static void call_bc_transmit_status(struct rpc_task *task); + +static void +call_bc_encode(struct rpc_task *task) +{ + xprt_request_enqueue_transmit(task); + task->tk_action = call_bc_transmit; +} + +/* + * 5b. Send the backchannel RPC reply. On error, drop the reply. In + * addition, disconnect on connectivity errors. + */ +static void +call_bc_transmit(struct rpc_task *task) +{ + task->tk_action = call_bc_transmit_status; + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) { + if (!xprt_prepare_transmit(task)) + return; + task->tk_status = 0; + xprt_transmit(task); + } + xprt_end_transmit(task); +} + +static void +call_bc_transmit_status(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (rpc_task_transmitted(task)) + task->tk_status = 0; + + switch (task->tk_status) { + case 0: + /* Success */ + case -ENETDOWN: + case -EHOSTDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -ECONNRESET: + case -ECONNREFUSED: + case -EADDRINUSE: + case -ENOTCONN: + case -EPIPE: + break; + case -ENOMEM: + case -ENOBUFS: + rpc_delay(task, HZ>>2); + fallthrough; + case -EBADSLT: + case -EAGAIN: + task->tk_status = 0; + task->tk_action = call_bc_transmit; + return; + case -ETIMEDOUT: + /* + * Problem reaching the server. Disconnect and let the + * forechannel reestablish the connection. The server will + * have to retransmit the backchannel request and we'll + * reprocess it. Since these ops are idempotent, there's no + * need to cache our reply at this time. + */ + printk(KERN_NOTICE "RPC: Could not send backchannel reply " + "error: %d\n", task->tk_status); + xprt_conditional_disconnect(req->rq_xprt, + req->rq_connect_cookie); + break; + default: + /* + * We were unable to reply and will have to drop the + * request. The server should reconnect and retransmit. + */ + printk(KERN_NOTICE "RPC: Could not send backchannel reply " + "error: %d\n", task->tk_status); + break; + } + task->tk_action = rpc_exit_task; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/* + * 6. 
Sort out the RPC call status + */ +static void +call_status(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + int status; + + if (!task->tk_msg.rpc_proc->p_proc) + trace_xprt_ping(task->tk_xprt, task->tk_status); + + status = task->tk_status; + if (status >= 0) { + task->tk_action = call_decode; + return; + } + + trace_rpc_call_status(task); + task->tk_status = 0; + switch(status) { + case -EHOSTDOWN: + case -ENETDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + case -EPERM: + if (RPC_IS_SOFTCONN(task)) + goto out_exit; + /* + * Delay any retries for 3 seconds, then handle as if it + * were a timeout. + */ + rpc_delay(task, 3*HZ); + fallthrough; + case -ETIMEDOUT: + break; + case -ECONNREFUSED: + case -ECONNRESET: + case -ECONNABORTED: + case -ENOTCONN: + rpc_force_rebind(clnt); + break; + case -EADDRINUSE: + rpc_delay(task, 3*HZ); + fallthrough; + case -EPIPE: + case -EAGAIN: + break; + case -ENFILE: + case -ENOBUFS: + case -ENOMEM: + rpc_delay(task, HZ>>2); + break; + case -EIO: + /* shutdown or soft timeout */ + goto out_exit; + default: + if (clnt->cl_chatty) + printk("%s: RPC call returned error %d\n", + clnt->cl_program->name, -status); + goto out_exit; + } + task->tk_action = call_encode; + rpc_check_timeout(task); + return; +out_exit: + rpc_call_rpcerror(task, status); +} + +static bool +rpc_check_connected(const struct rpc_rqst *req) +{ + /* No allocated request or transport? return true */ + if (!req || !req->rq_xprt) + return true; + return xprt_connected(req->rq_xprt); +} + +static void +rpc_check_timeout(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + if (RPC_SIGNALLED(task)) + return; + + if (xprt_adjust_timeout(task->tk_rqstp) == 0) + return; + + trace_rpc_timeout_status(task); + task->tk_timeouts++; + + if (RPC_IS_SOFTCONN(task) && !rpc_check_connected(task->tk_rqstp)) { + rpc_call_rpcerror(task, -ETIMEDOUT); + return; + } + + if (RPC_IS_SOFT(task)) { + /* + * Once a "no retrans timeout" soft tasks (a.k.a NFSv4) has + * been sent, it should time out only if the transport + * connection gets terminally broken. + */ + if ((task->tk_flags & RPC_TASK_NO_RETRANS_TIMEOUT) && + rpc_check_connected(task->tk_rqstp)) + return; + + if (clnt->cl_chatty) { + pr_notice_ratelimited( + "%s: server %s not responding, timed out\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + if (task->tk_flags & RPC_TASK_TIMEOUT) + rpc_call_rpcerror(task, -ETIMEDOUT); + else + __rpc_call_rpcerror(task, -EIO, -ETIMEDOUT); + return; + } + + if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) { + task->tk_flags |= RPC_CALL_MAJORSEEN; + if (clnt->cl_chatty) { + pr_notice_ratelimited( + "%s: server %s not responding, still trying\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + } + rpc_force_rebind(clnt); + /* + * Did our request time out due to an RPCSEC_GSS out-of-sequence + * event? RFC2203 requires the server to drop all such requests. + */ + rpcauth_invalcred(task); +} + +/* + * 7. Decode the RPC reply + */ +static void +call_decode(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + struct xdr_stream xdr; + int err; + + if (!task->tk_msg.rpc_proc->p_decode) { + task->tk_action = rpc_exit_task; + return; + } + + if (task->tk_flags & RPC_CALL_MAJORSEEN) { + if (clnt->cl_chatty) { + pr_notice_ratelimited("%s: server %s OK\n", + clnt->cl_program->name, + task->tk_xprt->servername); + } + task->tk_flags &= ~RPC_CALL_MAJORSEEN; + } + + /* + * Did we ever call xprt_complete_rqst()? 
If not, we should assume + * the message is incomplete. + */ + err = -EAGAIN; + if (!req->rq_reply_bytes_recvd) + goto out; + + /* Ensure that we see all writes made by xprt_complete_rqst() + * before it changed req->rq_reply_bytes_recvd. + */ + smp_rmb(); + + req->rq_rcv_buf.len = req->rq_private_buf.len; + trace_rpc_xdr_recvfrom(task, &req->rq_rcv_buf); + + /* Check that the softirq receive buffer is valid */ + WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf, + sizeof(req->rq_rcv_buf)) != 0); + + xdr_init_decode(&xdr, &req->rq_rcv_buf, + req->rq_rcv_buf.head[0].iov_base, req); + err = rpc_decode_header(task, &xdr); +out: + switch (err) { + case 0: + task->tk_action = rpc_exit_task; + task->tk_status = rpcauth_unwrap_resp(task, &xdr); + return; + case -EAGAIN: + task->tk_status = 0; + if (task->tk_client->cl_discrtry) + xprt_conditional_disconnect(req->rq_xprt, + req->rq_connect_cookie); + task->tk_action = call_encode; + rpc_check_timeout(task); + break; + case -EKEYREJECTED: + task->tk_action = call_reserve; + rpc_check_timeout(task); + rpcauth_invalcred(task); + /* Ensure we obtain a new XID if we retry! */ + xprt_release(task); + } +} + +static int +rpc_encode_header(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + __be32 *p; + int error; + + error = -EMSGSIZE; + p = xdr_reserve_space(xdr, RPC_CALLHDRSIZE << 2); + if (!p) + goto out_fail; + *p++ = req->rq_xid; + *p++ = rpc_call; + *p++ = cpu_to_be32(RPC_VERSION); + *p++ = cpu_to_be32(clnt->cl_prog); + *p++ = cpu_to_be32(clnt->cl_vers); + *p = cpu_to_be32(task->tk_msg.rpc_proc->p_proc); + + error = rpcauth_marshcred(task, xdr); + if (error < 0) + goto out_fail; + return 0; +out_fail: + trace_rpc_bad_callhdr(task); + rpc_call_rpcerror(task, error); + return error; +} + +static noinline int +rpc_decode_header(struct rpc_task *task, struct xdr_stream *xdr) +{ + struct rpc_clnt *clnt = task->tk_client; + int error; + __be32 *p; + + /* RFC-1014 says that the representation of XDR data must be a + * multiple of four bytes + * - if it isn't pointer subtraction in the NFS client may give + * undefined results + */ + if (task->tk_rqstp->rq_rcv_buf.len & 3) + goto out_unparsable; + + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (!p) + goto out_unparsable; + p++; /* skip XID */ + if (*p++ != rpc_reply) + goto out_unparsable; + if (*p++ != rpc_msg_accepted) + goto out_msg_denied; + + error = rpcauth_checkverf(task, xdr); + if (error) + goto out_verifier; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p) { + case rpc_success: + return 0; + case rpc_prog_unavail: + trace_rpc__prog_unavail(task); + error = -EPFNOSUPPORT; + goto out_err; + case rpc_prog_mismatch: + trace_rpc__prog_mismatch(task); + error = -EPROTONOSUPPORT; + goto out_err; + case rpc_proc_unavail: + trace_rpc__proc_unavail(task); + error = -EOPNOTSUPP; + goto out_err; + case rpc_garbage_args: + case rpc_system_err: + trace_rpc__garbage_args(task); + error = -EIO; + break; + default: + goto out_unparsable; + } + +out_garbage: + clnt->cl_stats->rpcgarbage++; + if (task->tk_garb_retry) { + task->tk_garb_retry--; + task->tk_action = call_encode; + return -EAGAIN; + } +out_err: + rpc_call_rpcerror(task, error); + return error; + +out_unparsable: + trace_rpc__unparsable(task); + error = -EIO; + goto out_garbage; + +out_verifier: + trace_rpc_bad_verifier(task); + goto out_garbage; + +out_msg_denied: + error = -EACCES; + p = xdr_inline_decode(xdr, 
sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_auth_error: + break; + case rpc_mismatch: + trace_rpc__mismatch(task); + error = -EPROTONOSUPPORT; + goto out_err; + default: + goto out_unparsable; + } + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (!p) + goto out_unparsable; + switch (*p++) { + case rpc_autherr_rejectedcred: + case rpc_autherr_rejectedverf: + case rpcsec_gsserr_credproblem: + case rpcsec_gsserr_ctxproblem: + rpcauth_invalcred(task); + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + trace_rpc__stale_creds(task); + return -EKEYREJECTED; + case rpc_autherr_badcred: + case rpc_autherr_badverf: + /* possibly garbled cred/verf? */ + if (!task->tk_garb_retry) + break; + task->tk_garb_retry--; + trace_rpc__bad_creds(task); + task->tk_action = call_encode; + return -EAGAIN; + case rpc_autherr_tooweak: + trace_rpc__auth_tooweak(task); + pr_warn("RPC: server %s requires stronger authentication.\n", + task->tk_xprt->servername); + break; + default: + goto out_unparsable; + } + goto out_err; +} + +static void rpcproc_encode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + const void *obj) +{ +} + +static int rpcproc_decode_null(struct rpc_rqst *rqstp, struct xdr_stream *xdr, + void *obj) +{ + return 0; +} + +static const struct rpc_procinfo rpcproc_null = { + .p_encode = rpcproc_encode_null, + .p_decode = rpcproc_decode_null, +}; + +static const struct rpc_procinfo rpcproc_null_noreply = { + .p_encode = rpcproc_encode_null, +}; + +static void +rpc_null_call_prepare(struct rpc_task *task, void *data) +{ + task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT; + rpc_call_start(task); +} + +static const struct rpc_call_ops rpc_null_ops = { + .rpc_call_prepare = rpc_null_call_prepare, + .rpc_call_done = rpc_default_callback, +}; + +static +struct rpc_task *rpc_call_null_helper(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, struct rpc_cred *cred, int flags, + const struct rpc_call_ops *ops, void *data) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_xprt = xprt, + .rpc_message = &msg, + .rpc_op_cred = cred, + .callback_ops = ops ?: &rpc_null_ops, + .callback_data = data, + .flags = flags | RPC_TASK_SOFT | RPC_TASK_SOFTCONN | + RPC_TASK_NULLCREDS, + }; + + return rpc_run_task(&task_setup_data); +} + +struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred, int flags) +{ + return rpc_call_null_helper(clnt, NULL, cred, flags, NULL, NULL); +} +EXPORT_SYMBOL_GPL(rpc_call_null); + +static int rpc_ping(struct rpc_clnt *clnt) +{ + struct rpc_task *task; + int status; + + task = rpc_call_null_helper(clnt, NULL, NULL, 0, NULL, NULL); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} + +static int rpc_ping_noreply(struct rpc_clnt *clnt) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null_noreply, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = &msg, + .callback_ops = &rpc_null_ops, + .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN | RPC_TASK_NULLCREDS, + }; + struct rpc_task *task; + int status; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} + +struct rpc_cb_add_xprt_calldata { + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; +}; + +static void rpc_cb_add_xprt_done(struct rpc_task *task, void *calldata) +{ + struct 
rpc_cb_add_xprt_calldata *data = calldata; + + if (task->tk_status == 0) + rpc_xprt_switch_add_xprt(data->xps, data->xprt); +} + +static void rpc_cb_add_xprt_release(void *calldata) +{ + struct rpc_cb_add_xprt_calldata *data = calldata; + + xprt_put(data->xprt); + xprt_switch_put(data->xps); + kfree(data); +} + +static const struct rpc_call_ops rpc_cb_add_xprt_call_ops = { + .rpc_call_prepare = rpc_null_call_prepare, + .rpc_call_done = rpc_cb_add_xprt_done, + .rpc_release = rpc_cb_add_xprt_release, +}; + +/** + * rpc_clnt_test_and_add_xprt - Test and add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xps: pointer to struct rpc_xprt_switch, + * @xprt: pointer struct rpc_xprt + * @in_max_connect: pointer to the max_connect value for the passed in xprt transport + */ +int rpc_clnt_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, struct rpc_xprt *xprt, + void *in_max_connect) +{ + struct rpc_cb_add_xprt_calldata *data; + struct rpc_task *task; + int max_connect = clnt->cl_max_connect; + + if (in_max_connect) + max_connect = *(int *)in_max_connect; + if (xps->xps_nunique_destaddr_xprts + 1 > max_connect) { + rcu_read_lock(); + pr_warn("SUNRPC: reached max allowed number (%d) did not add " + "transport to server: %s\n", max_connect, + rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); + rcu_read_unlock(); + return -EINVAL; + } + + data = kmalloc(sizeof(*data), GFP_KERNEL); + if (!data) + return -ENOMEM; + data->xps = xprt_switch_get(xps); + data->xprt = xprt_get(xprt); + if (rpc_xprt_switch_has_addr(data->xps, (struct sockaddr *)&xprt->addr)) { + rpc_cb_add_xprt_release(data); + goto success; + } + + task = rpc_call_null_helper(clnt, xprt, NULL, RPC_TASK_ASYNC, + &rpc_cb_add_xprt_call_ops, data); + if (IS_ERR(task)) + return PTR_ERR(task); + + data->xps->xps_nunique_destaddr_xprts++; + rpc_put_task(task); +success: + return 1; +} +EXPORT_SYMBOL_GPL(rpc_clnt_test_and_add_xprt); + +static int rpc_clnt_add_xprt_helper(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + struct rpc_add_xprt_test *data) +{ + struct rpc_task *task; + int status = -EADDRINUSE; + + /* Test the connection */ + task = rpc_call_null_helper(clnt, xprt, NULL, 0, NULL, NULL); + if (IS_ERR(task)) + return PTR_ERR(task); + + status = task->tk_status; + rpc_put_task(task); + + if (status < 0) + return status; + + /* rpc_xprt_switch and rpc_xprt are deferrenced by add_xprt_test() */ + data->add_xprt_test(clnt, xprt, data->data); + + return 0; +} + +/** + * rpc_clnt_setup_test_and_add_xprt() + * + * This is an rpc_clnt_add_xprt setup() function which returns 1 so: + * 1) caller of the test function must dereference the rpc_xprt_switch + * and the rpc_xprt. + * 2) test function must call rpc_xprt_switch_add_xprt, usually in + * the rpc_call_done routine. 
+ * + * Upon success (return of 1), the test function adds the new + * transport to the rpc_clnt xprt switch + * + * @clnt: struct rpc_clnt to get the new transport + * @xps: the rpc_xprt_switch to hold the new transport + * @xprt: the rpc_xprt to test + * @data: a struct rpc_add_xprt_test pointer that holds the test function + * and test function call data + */ +int rpc_clnt_setup_test_and_add_xprt(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, + void *data) +{ + int status = -EADDRINUSE; + + xprt = xprt_get(xprt); + xprt_switch_get(xps); + + if (rpc_xprt_switch_has_addr(xps, (struct sockaddr *)&xprt->addr)) + goto out_err; + + status = rpc_clnt_add_xprt_helper(clnt, xprt, data); + if (status < 0) + goto out_err; + + status = 1; +out_err: + xprt_put(xprt); + xprt_switch_put(xps); + if (status < 0) + pr_info("RPC: rpc_clnt_test_xprt failed: %d addr %s not " + "added\n", status, + xprt->address_strings[RPC_DISPLAY_ADDR]); + /* so that rpc_clnt_add_xprt does not call rpc_xprt_switch_add_xprt */ + return status; +} +EXPORT_SYMBOL_GPL(rpc_clnt_setup_test_and_add_xprt); + +/** + * rpc_clnt_add_xprt - Add a new transport to a rpc_clnt + * @clnt: pointer to struct rpc_clnt + * @xprtargs: pointer to struct xprt_create + * @setup: callback to test and/or set up the connection + * @data: pointer to setup function data + * + * Creates a new transport using the parameters set in args and + * adds it to clnt. + * If ping is set, then test that connectivity succeeds before + * adding the new transport. + * + */ +int rpc_clnt_add_xprt(struct rpc_clnt *clnt, + struct xprt_create *xprtargs, + int (*setup)(struct rpc_clnt *, + struct rpc_xprt_switch *, + struct rpc_xprt *, + void *), + void *data) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *xprt; + unsigned long connect_timeout; + unsigned long reconnect_timeout; + unsigned char resvport, reuseport; + int ret = 0, ident; + + rcu_read_lock(); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + xprt = xprt_iter_xprt(&clnt->cl_xpi); + if (xps == NULL || xprt == NULL) { + rcu_read_unlock(); + xprt_switch_put(xps); + return -EAGAIN; + } + resvport = xprt->resvport; + reuseport = xprt->reuseport; + connect_timeout = xprt->connect_timeout; + reconnect_timeout = xprt->max_reconnect_timeout; + ident = xprt->xprt_class->ident; + rcu_read_unlock(); + + if (!xprtargs->ident) + xprtargs->ident = ident; + xprt = xprt_create_transport(xprtargs); + if (IS_ERR(xprt)) { + ret = PTR_ERR(xprt); + goto out_put_switch; + } + xprt->resvport = resvport; + xprt->reuseport = reuseport; + if (xprt->ops->set_connect_timeout != NULL) + xprt->ops->set_connect_timeout(xprt, + connect_timeout, + reconnect_timeout); + + rpc_xprt_switch_set_roundrobin(xps); + if (setup) { + ret = setup(clnt, xps, xprt, data); + if (ret != 0) + goto out_put_xprt; + } + rpc_xprt_switch_add_xprt(xps, xprt); +out_put_xprt: + xprt_put(xprt); +out_put_switch: + xprt_switch_put(xps); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_add_xprt); + +static int rpc_xprt_probe_trunked(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + struct rpc_add_xprt_test *data) +{ + struct rpc_xprt_switch *xps; + struct rpc_xprt *main_xprt; + int status = 0; + + xprt_get(xprt); + + rcu_read_lock(); + main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + status = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr, + (struct sockaddr *)&main_xprt->addr); + rcu_read_unlock(); + xprt_put(main_xprt); + if (status 
|| !test_bit(XPRT_OFFLINE, &xprt->state)) + goto out; + + status = rpc_clnt_add_xprt_helper(clnt, xprt, data); +out: + xprt_put(xprt); + xprt_switch_put(xps); + return status; +} + +/* rpc_clnt_probe_trunked_xprt -- probe offlined transport for session trunking + * @clnt rpc_clnt structure + * + * For each offlined transport found in the rpc_clnt structure call + * the function rpc_xprt_probe_trunked() which will determine if this + * transport still belongs to the trunking group. + */ +void rpc_clnt_probe_trunked_xprts(struct rpc_clnt *clnt, + struct rpc_add_xprt_test *data) +{ + struct rpc_xprt_iter xpi; + int ret; + + ret = rpc_clnt_xprt_iter_offline_init(clnt, &xpi); + if (ret) + return; + for (;;) { + struct rpc_xprt *xprt = xprt_iter_get_next(&xpi); + + if (!xprt) + break; + ret = rpc_xprt_probe_trunked(clnt, xprt, data); + xprt_put(xprt); + if (ret < 0) + break; + xprt_iter_rewind(&xpi); + } + xprt_iter_destroy(&xpi); +} +EXPORT_SYMBOL_GPL(rpc_clnt_probe_trunked_xprts); + +static int rpc_xprt_offline(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *data) +{ + struct rpc_xprt *main_xprt; + struct rpc_xprt_switch *xps; + int err = 0; + + xprt_get(xprt); + + rcu_read_lock(); + main_xprt = xprt_get(rcu_dereference(clnt->cl_xprt)); + xps = xprt_switch_get(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + err = rpc_cmp_addr_port((struct sockaddr *)&xprt->addr, + (struct sockaddr *)&main_xprt->addr); + rcu_read_unlock(); + xprt_put(main_xprt); + if (err) + goto out; + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + err = -EINTR; + goto out; + } + xprt_set_offline_locked(xprt, xps); + + xprt_release_write(xprt, NULL); +out: + xprt_put(xprt); + xprt_switch_put(xps); + return err; +} + +/* rpc_clnt_manage_trunked_xprts -- offline trunked transports + * @clnt rpc_clnt structure + * + * For each active transport found in the rpc_clnt structure call + * the function rpc_xprt_offline() which will identify trunked transports + * and will mark them offline. 
+ */ +void rpc_clnt_manage_trunked_xprts(struct rpc_clnt *clnt) +{ + rpc_clnt_iterate_for_each_xprt(clnt, rpc_xprt_offline, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_manage_trunked_xprts); + +struct connect_timeout_data { + unsigned long connect_timeout; + unsigned long reconnect_timeout; +}; + +static int +rpc_xprt_set_connect_timeout(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *data) +{ + struct connect_timeout_data *timeo = data; + + if (xprt->ops->set_connect_timeout) + xprt->ops->set_connect_timeout(xprt, + timeo->connect_timeout, + timeo->reconnect_timeout); + return 0; +} + +void +rpc_set_connect_timeout(struct rpc_clnt *clnt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct connect_timeout_data timeout = { + .connect_timeout = connect_timeout, + .reconnect_timeout = reconnect_timeout, + }; + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_xprt_set_connect_timeout, + &timeout); +} +EXPORT_SYMBOL_GPL(rpc_set_connect_timeout); + +void rpc_clnt_xprt_switch_put(struct rpc_clnt *clnt) +{ + rcu_read_lock(); + xprt_switch_put(rcu_dereference(clnt->cl_xpi.xpi_xpswitch)); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_put); + +void rpc_clnt_xprt_set_online(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + rcu_read_unlock(); + xprt_set_online_locked(xprt, xps); +} + +void rpc_clnt_xprt_switch_add_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + if (rpc_clnt_xprt_switch_has_addr(clnt, + (const struct sockaddr *)&xprt->addr)) { + return rpc_clnt_xprt_set_online(clnt, xprt); + } + rcu_read_lock(); + rpc_xprt_switch_add_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch), + xprt); + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_add_xprt); + +void rpc_clnt_xprt_switch_remove_xprt(struct rpc_clnt *clnt, struct rpc_xprt *xprt) +{ + struct rpc_xprt_switch *xps; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + rpc_xprt_switch_remove_xprt(rcu_dereference(clnt->cl_xpi.xpi_xpswitch), + xprt, 0); + xps->xps_nunique_destaddr_xprts--; + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_remove_xprt); + +bool rpc_clnt_xprt_switch_has_addr(struct rpc_clnt *clnt, + const struct sockaddr *sap) +{ + struct rpc_xprt_switch *xps; + bool ret; + + rcu_read_lock(); + xps = rcu_dereference(clnt->cl_xpi.xpi_xpswitch); + ret = rpc_xprt_switch_has_addr(xps, sap); + rcu_read_unlock(); + return ret; +} +EXPORT_SYMBOL_GPL(rpc_clnt_xprt_switch_has_addr); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +static void rpc_show_header(void) +{ + printk(KERN_INFO "-pid- flgs status -client- --rqstp- " + "-timeout ---ops--\n"); +} + +static void rpc_show_task(const struct rpc_clnt *clnt, + const struct rpc_task *task) +{ + const char *rpc_waitq = "none"; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%ps q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt, task->tk_rqstp, rpc_task_timeout(task), task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); +} + +void rpc_show_tasks(struct net *net) +{ + struct rpc_clnt *clnt; + struct rpc_task *task; + int header = 0; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpc_client_lock); + list_for_each_entry(clnt, &sn->all_clients, cl_clients) { + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, 
&clnt->cl_tasks, tk_task) { + if (!header) { + rpc_show_header(); + header++; + } + rpc_show_task(clnt, task); + } + spin_unlock(&clnt->cl_lock); + } + spin_unlock(&sn->rpc_client_lock); +} +#endif + +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +static int +rpc_clnt_swap_activate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + return xprt_enable_swap(xprt); +} + +int +rpc_clnt_swap_activate(struct rpc_clnt *clnt) +{ + while (clnt != clnt->cl_parent) + clnt = clnt->cl_parent; + if (atomic_inc_return(&clnt->cl_swapper) == 1) + return rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_activate_callback, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_activate); + +static int +rpc_clnt_swap_deactivate_callback(struct rpc_clnt *clnt, + struct rpc_xprt *xprt, + void *dummy) +{ + xprt_disable_swap(xprt); + return 0; +} + +void +rpc_clnt_swap_deactivate(struct rpc_clnt *clnt) +{ + while (clnt != clnt->cl_parent) + clnt = clnt->cl_parent; + if (atomic_dec_if_positive(&clnt->cl_swapper) == 0) + rpc_clnt_iterate_for_each_xprt(clnt, + rpc_clnt_swap_deactivate_callback, NULL); +} +EXPORT_SYMBOL_GPL(rpc_clnt_swap_deactivate); +#endif /* CONFIG_SUNRPC_SWAP */ diff --git a/net/sunrpc/debugfs.c b/net/sunrpc/debugfs.c new file mode 100644 index 000000000..a176d5a0b --- /dev/null +++ b/net/sunrpc/debugfs.c @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * debugfs interface for sunrpc + * + * (c) 2014 Jeff Layton + */ + +#include +#include +#include + +#include "netns.h" +#include "fail.h" + +static struct dentry *topdir; +static struct dentry *rpc_clnt_dir; +static struct dentry *rpc_xprt_dir; + +static int +tasks_show(struct seq_file *f, void *v) +{ + u32 xid = 0; + struct rpc_task *task = v; + struct rpc_clnt *clnt = task->tk_client; + const char *rpc_waitq = "none"; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + if (task->tk_rqstp) + xid = be32_to_cpu(task->tk_rqstp->rq_xid); + + seq_printf(f, "%5u %04x %6d 0x%x 0x%x %8ld %ps %sv%u %s a:%ps q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt->cl_clid, xid, rpc_task_timeout(task), task->tk_ops, + clnt->cl_program->name, clnt->cl_vers, rpc_proc_name(task), + task->tk_action, rpc_waitq); + return 0; +} + +static void * +tasks_start(struct seq_file *f, loff_t *ppos) + __acquires(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + loff_t pos = *ppos; + struct rpc_task *task; + + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) + if (pos-- == 0) + return task; + return NULL; +} + +static void * +tasks_next(struct seq_file *f, void *v, loff_t *pos) +{ + struct rpc_clnt *clnt = f->private; + struct rpc_task *task = v; + struct list_head *next = task->tk_task.next; + + ++*pos; + + /* If there's another task on list, return it */ + if (next == &clnt->cl_tasks) + return NULL; + return list_entry(next, struct rpc_task, tk_task); +} + +static void +tasks_stop(struct seq_file *f, void *v) + __releases(&clnt->cl_lock) +{ + struct rpc_clnt *clnt = f->private; + spin_unlock(&clnt->cl_lock); +} + +static const struct seq_operations tasks_seq_operations = { + .start = tasks_start, + .next = tasks_next, + .stop = tasks_stop, + .show = tasks_show, +}; + +static int tasks_open(struct inode *inode, struct file *filp) +{ + int ret = seq_open(filp, &tasks_seq_operations); + if (!ret) { + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private = inode->i_private; + + if (!refcount_inc_not_zero(&clnt->cl_count)) { + 
seq_release(inode, filp); + ret = -EINVAL; + } + } + + return ret; +} + +static int +tasks_release(struct inode *inode, struct file *filp) +{ + struct seq_file *seq = filp->private_data; + struct rpc_clnt *clnt = seq->private; + + rpc_release_client(clnt); + return seq_release(inode, filp); +} + +static const struct file_operations tasks_fops = { + .owner = THIS_MODULE, + .open = tasks_open, + .read = seq_read, + .llseek = seq_lseek, + .release = tasks_release, +}; + +static int do_xprt_debugfs(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *numv) +{ + int len; + char name[24]; /* enough for "../../rpc_xprt/ + 8 hex digits + NULL */ + char link[9]; /* enough for 8 hex digits + NULL */ + int *nump = numv; + + if (IS_ERR_OR_NULL(xprt->debugfs)) + return 0; + len = snprintf(name, sizeof(name), "../../rpc_xprt/%s", + xprt->debugfs->d_name.name); + if (len >= sizeof(name)) + return -1; + if (*nump == 0) + strcpy(link, "xprt"); + else { + len = snprintf(link, sizeof(link), "xprt%d", *nump); + if (len >= sizeof(link)) + return -1; + } + debugfs_create_symlink(link, clnt->cl_debugfs, name); + (*nump)++; + return 0; +} + +void +rpc_clnt_debugfs_register(struct rpc_clnt *clnt) +{ + int len; + char name[9]; /* enough for 8 hex digits + NULL */ + int xprtnum = 0; + + len = snprintf(name, sizeof(name), "%x", clnt->cl_clid); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + clnt->cl_debugfs = debugfs_create_dir(name, rpc_clnt_dir); + + /* make tasks file */ + debugfs_create_file("tasks", S_IFREG | 0400, clnt->cl_debugfs, clnt, + &tasks_fops); + + rpc_clnt_iterate_for_each_xprt(clnt, do_xprt_debugfs, &xprtnum); +} + +void +rpc_clnt_debugfs_unregister(struct rpc_clnt *clnt) +{ + debugfs_remove_recursive(clnt->cl_debugfs); + clnt->cl_debugfs = NULL; +} + +static int +xprt_info_show(struct seq_file *f, void *v) +{ + struct rpc_xprt *xprt = f->private; + + seq_printf(f, "netid: %s\n", xprt->address_strings[RPC_DISPLAY_NETID]); + seq_printf(f, "addr: %s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); + seq_printf(f, "port: %s\n", xprt->address_strings[RPC_DISPLAY_PORT]); + seq_printf(f, "state: 0x%lx\n", xprt->state); + return 0; +} + +static int +xprt_info_open(struct inode *inode, struct file *filp) +{ + int ret; + struct rpc_xprt *xprt = inode->i_private; + + ret = single_open(filp, xprt_info_show, xprt); + + if (!ret) { + if (!xprt_get(xprt)) { + single_release(inode, filp); + ret = -EINVAL; + } + } + return ret; +} + +static int +xprt_info_release(struct inode *inode, struct file *filp) +{ + struct rpc_xprt *xprt = inode->i_private; + + xprt_put(xprt); + return single_release(inode, filp); +} + +static const struct file_operations xprt_info_fops = { + .owner = THIS_MODULE, + .open = xprt_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = xprt_info_release, +}; + +void +rpc_xprt_debugfs_register(struct rpc_xprt *xprt) +{ + int len, id; + static atomic_t cur_id; + char name[9]; /* 8 hex digits + NULL term */ + + id = (unsigned int)atomic_inc_return(&cur_id); + + len = snprintf(name, sizeof(name), "%x", id); + if (len >= sizeof(name)) + return; + + /* make the per-client dir */ + xprt->debugfs = debugfs_create_dir(name, rpc_xprt_dir); + + /* make tasks file */ + debugfs_create_file("info", S_IFREG | 0400, xprt->debugfs, xprt, + &xprt_info_fops); +} + +void +rpc_xprt_debugfs_unregister(struct rpc_xprt *xprt) +{ + debugfs_remove_recursive(xprt->debugfs); + xprt->debugfs = NULL; +} + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +struct fail_sunrpc_attr fail_sunrpc = { + .attr = 
FAULT_ATTR_INITIALIZER, +}; +EXPORT_SYMBOL_GPL(fail_sunrpc); + +static void fail_sunrpc_init(void) +{ + struct dentry *dir; + + dir = fault_create_debugfs_attr("fail_sunrpc", NULL, + &fail_sunrpc.attr); + + debugfs_create_bool("ignore-client-disconnect", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_client_disconnect); + + debugfs_create_bool("ignore-server-disconnect", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_server_disconnect); + + debugfs_create_bool("ignore-cache-wait", S_IFREG | 0600, dir, + &fail_sunrpc.ignore_cache_wait); +} +#else +static void fail_sunrpc_init(void) +{ +} +#endif + +void __exit +sunrpc_debugfs_exit(void) +{ + debugfs_remove_recursive(topdir); + topdir = NULL; + rpc_clnt_dir = NULL; + rpc_xprt_dir = NULL; +} + +void __init +sunrpc_debugfs_init(void) +{ + topdir = debugfs_create_dir("sunrpc", NULL); + + rpc_clnt_dir = debugfs_create_dir("rpc_clnt", topdir); + + rpc_xprt_dir = debugfs_create_dir("rpc_xprt", topdir); + + fail_sunrpc_init(); +} diff --git a/net/sunrpc/fail.h b/net/sunrpc/fail.h new file mode 100644 index 000000000..4b4b500df --- /dev/null +++ b/net/sunrpc/fail.h @@ -0,0 +1,25 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (C) 2021, Oracle. All rights reserved. + */ + +#ifndef _NET_SUNRPC_FAIL_H_ +#define _NET_SUNRPC_FAIL_H_ + +#include + +#if IS_ENABLED(CONFIG_FAULT_INJECTION) + +struct fail_sunrpc_attr { + struct fault_attr attr; + + bool ignore_client_disconnect; + bool ignore_server_disconnect; + bool ignore_cache_wait; +}; + +extern struct fail_sunrpc_attr fail_sunrpc; + +#endif /* CONFIG_FAULT_INJECTION */ + +#endif /* _NET_SUNRPC_FAIL_H_ */ diff --git a/net/sunrpc/netns.h b/net/sunrpc/netns.h new file mode 100644 index 000000000..7ec10b92b --- /dev/null +++ b/net/sunrpc/netns.h @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __SUNRPC_NETNS_H__ +#define __SUNRPC_NETNS_H__ + +#include +#include + +struct cache_detail; + +struct sunrpc_net { + struct proc_dir_entry *proc_net_rpc; + struct cache_detail *ip_map_cache; + struct cache_detail *unix_gid_cache; + struct cache_detail *rsc_cache; + struct cache_detail *rsi_cache; + + struct super_block *pipefs_sb; + struct rpc_pipe *gssd_dummy; + struct mutex pipefs_sb_lock; + + struct list_head all_clients; + spinlock_t rpc_client_lock; + + struct rpc_clnt *rpcb_local_clnt; + struct rpc_clnt *rpcb_local_clnt4; + spinlock_t rpcb_clnt_lock; + unsigned int rpcb_users; + unsigned int rpcb_is_af_local : 1; + + struct mutex gssp_lock; + struct rpc_clnt *gssp_clnt; + int use_gss_proxy; + int pipe_version; + atomic_t pipe_users; + struct proc_dir_entry *use_gssp_proc; +}; + +extern unsigned int sunrpc_net_id; + +int ip_map_cache_create(struct net *); +void ip_map_cache_destroy(struct net *); + +#endif diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c new file mode 100644 index 000000000..0b6034fab --- /dev/null +++ b/net/sunrpc/rpc_pipe.c @@ -0,0 +1,1517 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * net/sunrpc/rpc_pipe.c + * + * Userland/kernel interface for rpcauth_gss. + * Code shamelessly plagiarized from fs/nfsd/nfsctl.c + * and fs/sysfs/inode.c + * + * Copyright (c) 2002, Trond Myklebust + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "netns.h" +#include "sunrpc.h" + +#define RPCDBG_FACILITY RPCDBG_DEBUG + +#define NET_NAME(net) ((net == &init_net) ? 
" (init_net)" : "") + +static struct file_system_type rpc_pipe_fs_type; +static const struct rpc_pipe_ops gssd_dummy_pipe_ops; + +static struct kmem_cache *rpc_inode_cachep __read_mostly; + +#define RPC_UPCALL_TIMEOUT (30*HZ) + +static BLOCKING_NOTIFIER_HEAD(rpc_pipefs_notifier_list); + +int rpc_pipefs_notifier_register(struct notifier_block *nb) +{ + return blocking_notifier_chain_register(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_register); + +void rpc_pipefs_notifier_unregister(struct notifier_block *nb) +{ + blocking_notifier_chain_unregister(&rpc_pipefs_notifier_list, nb); +} +EXPORT_SYMBOL_GPL(rpc_pipefs_notifier_unregister); + +static void rpc_purge_list(wait_queue_head_t *waitq, struct list_head *head, + void (*destroy_msg)(struct rpc_pipe_msg *), int err) +{ + struct rpc_pipe_msg *msg; + + if (list_empty(head)) + return; + do { + msg = list_entry(head->next, struct rpc_pipe_msg, list); + list_del_init(&msg->list); + msg->errno = err; + destroy_msg(msg); + } while (!list_empty(head)); + + if (waitq) + wake_up(waitq); +} + +static void +rpc_timeout_upcall_queue(struct work_struct *work) +{ + LIST_HEAD(free_list); + struct rpc_pipe *pipe = + container_of(work, struct rpc_pipe, queue_timeout.work); + void (*destroy_msg)(struct rpc_pipe_msg *); + struct dentry *dentry; + + spin_lock(&pipe->lock); + destroy_msg = pipe->ops->destroy_msg; + if (pipe->nreaders == 0) { + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + } + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + rpc_purge_list(dentry ? &RPC_I(d_inode(dentry))->waitq : NULL, + &free_list, destroy_msg, -ETIMEDOUT); + dput(dentry); +} + +ssize_t rpc_pipe_generic_upcall(struct file *filp, struct rpc_pipe_msg *msg, + char __user *dst, size_t buflen) +{ + char *data = (char *)msg->data + msg->copied; + size_t mlen = min(msg->len - msg->copied, buflen); + unsigned long left; + + left = copy_to_user(dst, data, mlen); + if (left == mlen) { + msg->errno = -EFAULT; + return -EFAULT; + } + + mlen -= left; + msg->copied += mlen; + msg->errno = 0; + return mlen; +} +EXPORT_SYMBOL_GPL(rpc_pipe_generic_upcall); + +/** + * rpc_queue_upcall - queue an upcall message to userspace + * @pipe: upcall pipe on which to queue given message + * @msg: message to queue + * + * Call with an @inode created by rpc_mkpipe() to queue an upcall. + * A userspace process may then later read the upcall by performing a + * read on an open file for this inode. It is up to the caller to + * initialize the fields of @msg (other than @msg->list) appropriately. 
+ */ +int +rpc_queue_upcall(struct rpc_pipe *pipe, struct rpc_pipe_msg *msg) +{ + int res = -EPIPE; + struct dentry *dentry; + + spin_lock(&pipe->lock); + if (pipe->nreaders) { + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; + res = 0; + } else if (pipe->flags & RPC_PIPE_WAIT_FOR_OPEN) { + if (list_empty(&pipe->pipe)) + queue_delayed_work(rpciod_workqueue, + &pipe->queue_timeout, + RPC_UPCALL_TIMEOUT); + list_add_tail(&msg->list, &pipe->pipe); + pipe->pipelen += msg->len; + res = 0; + } + dentry = dget(pipe->dentry); + spin_unlock(&pipe->lock); + if (dentry) { + wake_up(&RPC_I(d_inode(dentry))->waitq); + dput(dentry); + } + return res; +} +EXPORT_SYMBOL_GPL(rpc_queue_upcall); + +static inline void +rpc_inode_setowner(struct inode *inode, void *private) +{ + RPC_I(inode)->private = private; +} + +static void +rpc_close_pipes(struct inode *inode) +{ + struct rpc_pipe *pipe = RPC_I(inode)->pipe; + int need_release; + LIST_HEAD(free_list); + + inode_lock(inode); + spin_lock(&pipe->lock); + need_release = pipe->nreaders != 0 || pipe->nwriters != 0; + pipe->nreaders = 0; + list_splice_init(&pipe->in_upcall, &free_list); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + pipe->dentry = NULL; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, pipe->ops->destroy_msg, -EPIPE); + pipe->nwriters = 0; + if (need_release && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); + cancel_delayed_work_sync(&pipe->queue_timeout); + rpc_inode_setowner(inode, NULL); + RPC_I(inode)->pipe = NULL; + inode_unlock(inode); +} + +static struct inode * +rpc_alloc_inode(struct super_block *sb) +{ + struct rpc_inode *rpci; + rpci = alloc_inode_sb(sb, rpc_inode_cachep, GFP_KERNEL); + if (!rpci) + return NULL; + return &rpci->vfs_inode; +} + +static void +rpc_free_inode(struct inode *inode) +{ + kmem_cache_free(rpc_inode_cachep, RPC_I(inode)); +} + +static int +rpc_pipe_open(struct inode *inode, struct file *filp) +{ + struct rpc_pipe *pipe; + int first_open; + int res = -ENXIO; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) + goto out; + first_open = pipe->nreaders == 0 && pipe->nwriters == 0; + if (first_open && pipe->ops->open_pipe) { + res = pipe->ops->open_pipe(inode); + if (res) + goto out; + } + if (filp->f_mode & FMODE_READ) + pipe->nreaders++; + if (filp->f_mode & FMODE_WRITE) + pipe->nwriters++; + res = 0; +out: + inode_unlock(inode); + return res; +} + +static int +rpc_pipe_release(struct inode *inode, struct file *filp) +{ + struct rpc_pipe *pipe; + struct rpc_pipe_msg *msg; + int last_close; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) + goto out; + msg = filp->private_data; + if (msg != NULL) { + spin_lock(&pipe->lock); + msg->errno = -EAGAIN; + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); + } + if (filp->f_mode & FMODE_WRITE) + pipe->nwriters --; + if (filp->f_mode & FMODE_READ) { + pipe->nreaders --; + if (pipe->nreaders == 0) { + LIST_HEAD(free_list); + spin_lock(&pipe->lock); + list_splice_init(&pipe->pipe, &free_list); + pipe->pipelen = 0; + spin_unlock(&pipe->lock); + rpc_purge_list(&RPC_I(inode)->waitq, &free_list, + pipe->ops->destroy_msg, -EAGAIN); + } + } + last_close = pipe->nwriters == 0 && pipe->nreaders == 0; + if (last_close && pipe->ops->release_pipe) + pipe->ops->release_pipe(inode); +out: + inode_unlock(inode); + return 0; +} + +static ssize_t +rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) +{ + 
struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; + struct rpc_pipe_msg *msg; + int res = 0; + + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { + res = -EPIPE; + goto out_unlock; + } + msg = filp->private_data; + if (msg == NULL) { + spin_lock(&pipe->lock); + if (!list_empty(&pipe->pipe)) { + msg = list_entry(pipe->pipe.next, + struct rpc_pipe_msg, + list); + list_move(&msg->list, &pipe->in_upcall); + pipe->pipelen -= msg->len; + filp->private_data = msg; + msg->copied = 0; + } + spin_unlock(&pipe->lock); + if (msg == NULL) + goto out_unlock; + } + /* NOTE: it is up to the callback to update msg->copied */ + res = pipe->ops->upcall(filp, msg, buf, len); + if (res < 0 || msg->len == msg->copied) { + filp->private_data = NULL; + spin_lock(&pipe->lock); + list_del_init(&msg->list); + spin_unlock(&pipe->lock); + pipe->ops->destroy_msg(msg); + } +out_unlock: + inode_unlock(inode); + return res; +} + +static ssize_t +rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *offset) +{ + struct inode *inode = file_inode(filp); + int res; + + inode_lock(inode); + res = -EPIPE; + if (RPC_I(inode)->pipe != NULL) + res = RPC_I(inode)->pipe->ops->downcall(filp, buf, len); + inode_unlock(inode); + return res; +} + +static __poll_t +rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait) +{ + struct inode *inode = file_inode(filp); + struct rpc_inode *rpci = RPC_I(inode); + __poll_t mask = EPOLLOUT | EPOLLWRNORM; + + poll_wait(filp, &rpci->waitq, wait); + + inode_lock(inode); + if (rpci->pipe == NULL) + mask |= EPOLLERR | EPOLLHUP; + else if (filp->private_data || !list_empty(&rpci->pipe->pipe)) + mask |= EPOLLIN | EPOLLRDNORM; + inode_unlock(inode); + return mask; +} + +static long +rpc_pipe_ioctl(struct file *filp, unsigned int cmd, unsigned long arg) +{ + struct inode *inode = file_inode(filp); + struct rpc_pipe *pipe; + int len; + + switch (cmd) { + case FIONREAD: + inode_lock(inode); + pipe = RPC_I(inode)->pipe; + if (pipe == NULL) { + inode_unlock(inode); + return -EPIPE; + } + spin_lock(&pipe->lock); + len = pipe->pipelen; + if (filp->private_data) { + struct rpc_pipe_msg *msg; + msg = filp->private_data; + len += msg->len - msg->copied; + } + spin_unlock(&pipe->lock); + inode_unlock(inode); + return put_user(len, (int __user *)arg); + default: + return -EINVAL; + } +} + +static const struct file_operations rpc_pipe_fops = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = rpc_pipe_read, + .write = rpc_pipe_write, + .poll = rpc_pipe_poll, + .unlocked_ioctl = rpc_pipe_ioctl, + .open = rpc_pipe_open, + .release = rpc_pipe_release, +}; + +static int +rpc_show_info(struct seq_file *m, void *v) +{ + struct rpc_clnt *clnt = m->private; + + rcu_read_lock(); + seq_printf(m, "RPC server: %s\n", + rcu_dereference(clnt->cl_xprt)->servername); + seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_program->name, + clnt->cl_prog, clnt->cl_vers); + seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); + seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO)); + seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT)); + rcu_read_unlock(); + return 0; +} + +static int +rpc_info_open(struct inode *inode, struct file *file) +{ + struct rpc_clnt *clnt = NULL; + int ret = single_open(file, rpc_show_info, NULL); + + if (!ret) { + struct seq_file *m = file->private_data; + + spin_lock(&file->f_path.dentry->d_lock); + if (!d_unhashed(file->f_path.dentry)) + clnt = RPC_I(inode)->private; + 
if (clnt != NULL && refcount_inc_not_zero(&clnt->cl_count)) { + spin_unlock(&file->f_path.dentry->d_lock); + m->private = clnt; + } else { + spin_unlock(&file->f_path.dentry->d_lock); + single_release(inode, file); + ret = -EINVAL; + } + } + return ret; +} + +static int +rpc_info_release(struct inode *inode, struct file *file) +{ + struct seq_file *m = file->private_data; + struct rpc_clnt *clnt = (struct rpc_clnt *)m->private; + + if (clnt) + rpc_release_client(clnt); + return single_release(inode, file); +} + +static const struct file_operations rpc_info_operations = { + .owner = THIS_MODULE, + .open = rpc_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = rpc_info_release, +}; + + +/* + * Description of fs contents. + */ +struct rpc_filelist { + const char *name; + const struct file_operations *i_fop; + umode_t mode; +}; + +static struct inode * +rpc_get_inode(struct super_block *sb, umode_t mode) +{ + struct inode *inode = new_inode(sb); + if (!inode) + return NULL; + inode->i_ino = get_next_ino(); + inode->i_mode = mode; + inode->i_atime = inode->i_mtime = inode->i_ctime = current_time(inode); + switch (mode & S_IFMT) { + case S_IFDIR: + inode->i_fop = &simple_dir_operations; + inode->i_op = &simple_dir_inode_operations; + inc_nlink(inode); + break; + default: + break; + } + return inode; +} + +static int __rpc_create_common(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + struct inode *inode; + + d_drop(dentry); + inode = rpc_get_inode(dir->i_sb, mode); + if (!inode) + goto out_err; + inode->i_ino = iunique(dir->i_sb, 100); + if (i_fop) + inode->i_fop = i_fop; + if (private) + rpc_inode_setowner(inode, private); + d_add(dentry, inode); + return 0; +out_err: + printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %pd\n", + __FILE__, __func__, dentry); + dput(dentry); + return -ENOMEM; +} + +static int __rpc_create(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + int err; + + err = __rpc_create_common(dir, dentry, S_IFREG | mode, i_fop, private); + if (err) + return err; + fsnotify_create(dir, dentry); + return 0; +} + +static int __rpc_mkdir(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private) +{ + int err; + + err = __rpc_create_common(dir, dentry, S_IFDIR | mode, i_fop, private); + if (err) + return err; + inc_nlink(dir); + fsnotify_mkdir(dir, dentry); + return 0; +} + +static void +init_pipe(struct rpc_pipe *pipe) +{ + pipe->nreaders = 0; + pipe->nwriters = 0; + INIT_LIST_HEAD(&pipe->in_upcall); + INIT_LIST_HEAD(&pipe->in_downcall); + INIT_LIST_HEAD(&pipe->pipe); + pipe->pipelen = 0; + INIT_DELAYED_WORK(&pipe->queue_timeout, + rpc_timeout_upcall_queue); + pipe->ops = NULL; + spin_lock_init(&pipe->lock); + pipe->dentry = NULL; +} + +void rpc_destroy_pipe_data(struct rpc_pipe *pipe) +{ + kfree(pipe); +} +EXPORT_SYMBOL_GPL(rpc_destroy_pipe_data); + +struct rpc_pipe *rpc_mkpipe_data(const struct rpc_pipe_ops *ops, int flags) +{ + struct rpc_pipe *pipe; + + pipe = kzalloc(sizeof(struct rpc_pipe), GFP_KERNEL); + if (!pipe) + return ERR_PTR(-ENOMEM); + init_pipe(pipe); + pipe->ops = ops; + pipe->flags = flags; + return pipe; +} +EXPORT_SYMBOL_GPL(rpc_mkpipe_data); + +static int __rpc_mkpipe_dentry(struct inode *dir, struct dentry *dentry, + umode_t mode, + const struct file_operations *i_fop, + void *private, + struct rpc_pipe *pipe) +{ + struct rpc_inode *rpci; + 
int err; + + err = __rpc_create_common(dir, dentry, S_IFIFO | mode, i_fop, private); + if (err) + return err; + rpci = RPC_I(d_inode(dentry)); + rpci->private = private; + rpci->pipe = pipe; + fsnotify_create(dir, dentry); + return 0; +} + +static int __rpc_rmdir(struct inode *dir, struct dentry *dentry) +{ + int ret; + + dget(dentry); + ret = simple_rmdir(dir, dentry); + d_drop(dentry); + if (!ret) + fsnotify_rmdir(dir, dentry); + dput(dentry); + return ret; +} + +static int __rpc_unlink(struct inode *dir, struct dentry *dentry) +{ + int ret; + + dget(dentry); + ret = simple_unlink(dir, dentry); + d_drop(dentry); + if (!ret) + fsnotify_unlink(dir, dentry); + dput(dentry); + return ret; +} + +static int __rpc_rmpipe(struct inode *dir, struct dentry *dentry) +{ + struct inode *inode = d_inode(dentry); + + rpc_close_pipes(inode); + return __rpc_unlink(dir, dentry); +} + +static struct dentry *__rpc_lookup_create_exclusive(struct dentry *parent, + const char *name) +{ + struct qstr q = QSTR_INIT(name, strlen(name)); + struct dentry *dentry = d_hash_and_lookup(parent, &q); + if (!dentry) { + dentry = d_alloc(parent, &q); + if (!dentry) + return ERR_PTR(-ENOMEM); + } + if (d_really_is_negative(dentry)) + return dentry; + dput(dentry); + return ERR_PTR(-EEXIST); +} + +/* + * FIXME: This probably has races. + */ +static void __rpc_depopulate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof) +{ + struct inode *dir = d_inode(parent); + struct dentry *dentry; + struct qstr name; + int i; + + for (i = start; i < eof; i++) { + name.name = files[i].name; + name.len = strlen(files[i].name); + dentry = d_hash_and_lookup(parent, &name); + + if (dentry == NULL) + continue; + if (d_really_is_negative(dentry)) + goto next; + switch (d_inode(dentry)->i_mode & S_IFMT) { + default: + BUG(); + case S_IFREG: + __rpc_unlink(dir, dentry); + break; + case S_IFDIR: + __rpc_rmdir(dir, dentry); + } +next: + dput(dentry); + } +} + +static void rpc_depopulate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof) +{ + struct inode *dir = d_inode(parent); + + inode_lock_nested(dir, I_MUTEX_CHILD); + __rpc_depopulate(parent, files, start, eof); + inode_unlock(dir); +} + +static int rpc_populate(struct dentry *parent, + const struct rpc_filelist *files, + int start, int eof, + void *private) +{ + struct inode *dir = d_inode(parent); + struct dentry *dentry; + int i, err; + + inode_lock(dir); + for (i = start; i < eof; i++) { + dentry = __rpc_lookup_create_exclusive(parent, files[i].name); + err = PTR_ERR(dentry); + if (IS_ERR(dentry)) + goto out_bad; + switch (files[i].mode & S_IFMT) { + default: + BUG(); + case S_IFREG: + err = __rpc_create(dir, dentry, + files[i].mode, + files[i].i_fop, + private); + break; + case S_IFDIR: + err = __rpc_mkdir(dir, dentry, + files[i].mode, + NULL, + private); + } + if (err != 0) + goto out_bad; + } + inode_unlock(dir); + return 0; +out_bad: + __rpc_depopulate(parent, files, start, eof); + inode_unlock(dir); + printk(KERN_WARNING "%s: %s failed to populate directory %pd\n", + __FILE__, __func__, parent); + return err; +} + +static struct dentry *rpc_mkdir_populate(struct dentry *parent, + const char *name, umode_t mode, void *private, + int (*populate)(struct dentry *, void *), void *args_populate) +{ + struct dentry *dentry; + struct inode *dir = d_inode(parent); + int error; + + inode_lock_nested(dir, I_MUTEX_PARENT); + dentry = __rpc_lookup_create_exclusive(parent, name); + if (IS_ERR(dentry)) + goto out; + error = __rpc_mkdir(dir, 
dentry, mode, NULL, private); + if (error != 0) + goto out_err; + if (populate != NULL) { + error = populate(dentry, args_populate); + if (error) + goto err_rmdir; + } +out: + inode_unlock(dir); + return dentry; +err_rmdir: + __rpc_rmdir(dir, dentry); +out_err: + dentry = ERR_PTR(error); + goto out; +} + +static int rpc_rmdir_depopulate(struct dentry *dentry, + void (*depopulate)(struct dentry *)) +{ + struct dentry *parent; + struct inode *dir; + int error; + + parent = dget_parent(dentry); + dir = d_inode(parent); + inode_lock_nested(dir, I_MUTEX_PARENT); + if (depopulate != NULL) + depopulate(dentry); + error = __rpc_rmdir(dir, dentry); + inode_unlock(dir); + dput(parent); + return error; +} + +/** + * rpc_mkpipe_dentry - make an rpc_pipefs file for kernel<->userspace + * communication + * @parent: dentry of directory to create new "pipe" in + * @name: name of pipe + * @private: private data to associate with the pipe, for the caller's use + * @pipe: &rpc_pipe containing input parameters + * + * Data is made available for userspace to read by calls to + * rpc_queue_upcall(). The actual reads will result in calls to + * @ops->upcall, which will be called with the file pointer, + * message, and userspace buffer to copy to. + * + * Writes can come at any time, and do not necessarily have to be + * responses to upcalls. They will result in calls to @msg->downcall. + * + * The @private argument passed here will be available to all these methods + * from the file pointer, via RPC_I(file_inode(file))->private. + */ +struct dentry *rpc_mkpipe_dentry(struct dentry *parent, const char *name, + void *private, struct rpc_pipe *pipe) +{ + struct dentry *dentry; + struct inode *dir = d_inode(parent); + umode_t umode = S_IFIFO | 0600; + int err; + + if (pipe->ops->upcall == NULL) + umode &= ~0444; + if (pipe->ops->downcall == NULL) + umode &= ~0222; + + inode_lock_nested(dir, I_MUTEX_PARENT); + dentry = __rpc_lookup_create_exclusive(parent, name); + if (IS_ERR(dentry)) + goto out; + err = __rpc_mkpipe_dentry(dir, dentry, umode, &rpc_pipe_fops, + private, pipe); + if (err) + goto out_err; +out: + inode_unlock(dir); + return dentry; +out_err: + dentry = ERR_PTR(err); + printk(KERN_WARNING "%s: %s() failed to create pipe %pd/%s (errno = %d)\n", + __FILE__, __func__, parent, name, + err); + goto out; +} +EXPORT_SYMBOL_GPL(rpc_mkpipe_dentry); + +/** + * rpc_unlink - remove a pipe + * @dentry: dentry for the pipe, as returned from rpc_mkpipe + * + * After this call, lookups will no longer find the pipe, and any + * attempts to read or write using preexisting opens of the pipe will + * return -EPIPE. 
+ */ +int +rpc_unlink(struct dentry *dentry) +{ + struct dentry *parent; + struct inode *dir; + int error = 0; + + parent = dget_parent(dentry); + dir = d_inode(parent); + inode_lock_nested(dir, I_MUTEX_PARENT); + error = __rpc_rmpipe(dir, dentry); + inode_unlock(dir); + dput(parent); + return error; +} +EXPORT_SYMBOL_GPL(rpc_unlink); + +/** + * rpc_init_pipe_dir_head - initialise a struct rpc_pipe_dir_head + * @pdh: pointer to struct rpc_pipe_dir_head + */ +void rpc_init_pipe_dir_head(struct rpc_pipe_dir_head *pdh) +{ + INIT_LIST_HEAD(&pdh->pdh_entries); + pdh->pdh_dentry = NULL; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_head); + +/** + * rpc_init_pipe_dir_object - initialise a struct rpc_pipe_dir_object + * @pdo: pointer to struct rpc_pipe_dir_object + * @pdo_ops: pointer to const struct rpc_pipe_dir_object_ops + * @pdo_data: pointer to caller-defined data + */ +void rpc_init_pipe_dir_object(struct rpc_pipe_dir_object *pdo, + const struct rpc_pipe_dir_object_ops *pdo_ops, + void *pdo_data) +{ + INIT_LIST_HEAD(&pdo->pdo_head); + pdo->pdo_ops = pdo_ops; + pdo->pdo_data = pdo_data; +} +EXPORT_SYMBOL_GPL(rpc_init_pipe_dir_object); + +static int +rpc_add_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (pdh->pdh_dentry) + ret = pdo->pdo_ops->create(pdh->pdh_dentry, pdo); + if (ret == 0) + list_add_tail(&pdo->pdo_head, &pdh->pdh_entries); + return ret; +} + +static void +rpc_remove_pipe_dir_object_locked(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (pdh->pdh_dentry) + pdo->pdo_ops->destroy(pdh->pdh_dentry, pdo); + list_del_init(&pdo->pdo_head); +} + +/** + * rpc_add_pipe_dir_object - associate a rpc_pipe_dir_object to a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +int +rpc_add_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + int ret = 0; + + if (list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + ret = rpc_add_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } + return ret; +} +EXPORT_SYMBOL_GPL(rpc_add_pipe_dir_object); + +/** + * rpc_remove_pipe_dir_object - remove a rpc_pipe_dir_object from a directory + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @pdo: pointer to struct rpc_pipe_dir_object + * + */ +void +rpc_remove_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + struct rpc_pipe_dir_object *pdo) +{ + if (!list_empty(&pdo->pdo_head)) { + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + rpc_remove_pipe_dir_object_locked(net, pdh, pdo); + mutex_unlock(&sn->pipefs_sb_lock); + } +} +EXPORT_SYMBOL_GPL(rpc_remove_pipe_dir_object); + +/** + * rpc_find_or_alloc_pipe_dir_object + * @net: pointer to struct net + * @pdh: pointer to struct rpc_pipe_dir_head + * @match: match struct rpc_pipe_dir_object to data + * @alloc: allocate a new struct rpc_pipe_dir_object + * @data: user defined data for match() and alloc() + * + */ +struct rpc_pipe_dir_object * +rpc_find_or_alloc_pipe_dir_object(struct net *net, + struct rpc_pipe_dir_head *pdh, + int (*match)(struct rpc_pipe_dir_object *, void *), + struct rpc_pipe_dir_object *(*alloc)(void *), + void *data) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + 
struct rpc_pipe_dir_object *pdo; + + mutex_lock(&sn->pipefs_sb_lock); + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) { + if (!match(pdo, data)) + continue; + goto out; + } + pdo = alloc(data); + if (!pdo) + goto out; + rpc_add_pipe_dir_object_locked(net, pdh, pdo); +out: + mutex_unlock(&sn->pipefs_sb_lock); + return pdo; +} +EXPORT_SYMBOL_GPL(rpc_find_or_alloc_pipe_dir_object); + +static void +rpc_create_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->create(dir, pdo); +} + +static void +rpc_destroy_pipe_dir_objects(struct rpc_pipe_dir_head *pdh) +{ + struct rpc_pipe_dir_object *pdo; + struct dentry *dir = pdh->pdh_dentry; + + list_for_each_entry(pdo, &pdh->pdh_entries, pdo_head) + pdo->pdo_ops->destroy(dir, pdo); +} + +enum { + RPCAUTH_info, + RPCAUTH_EOF +}; + +static const struct rpc_filelist authfiles[] = { + [RPCAUTH_info] = { + .name = "info", + .i_fop = &rpc_info_operations, + .mode = S_IFREG | 0400, + }, +}; + +static int rpc_clntdir_populate(struct dentry *dentry, void *private) +{ + return rpc_populate(dentry, + authfiles, RPCAUTH_info, RPCAUTH_EOF, + private); +} + +static void rpc_clntdir_depopulate(struct dentry *dentry) +{ + rpc_depopulate(dentry, authfiles, RPCAUTH_info, RPCAUTH_EOF); +} + +/** + * rpc_create_client_dir - Create a new rpc_client directory in rpc_pipefs + * @dentry: the parent of new directory + * @name: the name of new directory + * @rpc_client: rpc client to associate with this directory + * + * This creates a directory at the given @path associated with + * @rpc_clnt, which will contain a file named "info" with some basic + * information about the client, together with any "pipes" that may + * later be created using rpc_mkpipe(). 
+ */ +struct dentry *rpc_create_client_dir(struct dentry *dentry, + const char *name, + struct rpc_clnt *rpc_client) +{ + struct dentry *ret; + + ret = rpc_mkdir_populate(dentry, name, 0555, NULL, + rpc_clntdir_populate, rpc_client); + if (!IS_ERR(ret)) { + rpc_client->cl_pipedir_objects.pdh_dentry = ret; + rpc_create_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + } + return ret; +} + +/** + * rpc_remove_client_dir - Remove a directory created with rpc_create_client_dir() + * @rpc_client: rpc_client for the pipe + */ +int rpc_remove_client_dir(struct rpc_clnt *rpc_client) +{ + struct dentry *dentry = rpc_client->cl_pipedir_objects.pdh_dentry; + + if (dentry == NULL) + return 0; + rpc_destroy_pipe_dir_objects(&rpc_client->cl_pipedir_objects); + rpc_client->cl_pipedir_objects.pdh_dentry = NULL; + return rpc_rmdir_depopulate(dentry, rpc_clntdir_depopulate); +} + +static const struct rpc_filelist cache_pipefs_files[3] = { + [0] = { + .name = "channel", + .i_fop = &cache_file_operations_pipefs, + .mode = S_IFREG | 0600, + }, + [1] = { + .name = "content", + .i_fop = &content_file_operations_pipefs, + .mode = S_IFREG | 0400, + }, + [2] = { + .name = "flush", + .i_fop = &cache_flush_operations_pipefs, + .mode = S_IFREG | 0600, + }, +}; + +static int rpc_cachedir_populate(struct dentry *dentry, void *private) +{ + return rpc_populate(dentry, + cache_pipefs_files, 0, 3, + private); +} + +static void rpc_cachedir_depopulate(struct dentry *dentry) +{ + rpc_depopulate(dentry, cache_pipefs_files, 0, 3); +} + +struct dentry *rpc_create_cache_dir(struct dentry *parent, const char *name, + umode_t umode, struct cache_detail *cd) +{ + return rpc_mkdir_populate(parent, name, umode, NULL, + rpc_cachedir_populate, cd); +} + +void rpc_remove_cache_dir(struct dentry *dentry) +{ + rpc_rmdir_depopulate(dentry, rpc_cachedir_depopulate); +} + +/* + * populate the filesystem + */ +static const struct super_operations s_ops = { + .alloc_inode = rpc_alloc_inode, + .free_inode = rpc_free_inode, + .statfs = simple_statfs, +}; + +#define RPCAUTH_GSSMAGIC 0x67596969 + +/* + * We have a single directory with 1 node in it. + */ +enum { + RPCAUTH_lockd, + RPCAUTH_mount, + RPCAUTH_nfs, + RPCAUTH_portmap, + RPCAUTH_statd, + RPCAUTH_nfsd4_cb, + RPCAUTH_cache, + RPCAUTH_nfsd, + RPCAUTH_gssd, + RPCAUTH_RootEOF +}; + +static const struct rpc_filelist files[] = { + [RPCAUTH_lockd] = { + .name = "lockd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_mount] = { + .name = "mount", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfs] = { + .name = "nfs", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_portmap] = { + .name = "portmap", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_statd] = { + .name = "statd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfsd4_cb] = { + .name = "nfsd4_cb", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_cache] = { + .name = "cache", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_nfsd] = { + .name = "nfsd", + .mode = S_IFDIR | 0555, + }, + [RPCAUTH_gssd] = { + .name = "gssd", + .mode = S_IFDIR | 0555, + }, +}; + +/* + * This call can be used only in RPC pipefs mount notification hooks. 
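+ * A typical hook might do something along these lines ("nfs" stands in
+ * for whatever directory the caller is interested in):
+ *
+ *	dentry = rpc_d_lookup_sb(sb, "nfs");
+ *	if (dentry) {
+ *		... populate or tear down objects under the directory ...
+ *		dput(dentry);
+ *	}
+ *
+ * A non-NULL return carries a dentry reference that the caller must drop.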
+ */ +struct dentry *rpc_d_lookup_sb(const struct super_block *sb, + const unsigned char *dir_name) +{ + struct qstr dir = QSTR_INIT(dir_name, strlen(dir_name)); + return d_hash_and_lookup(sb->s_root, &dir); +} +EXPORT_SYMBOL_GPL(rpc_d_lookup_sb); + +int rpc_pipefs_init_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + sn->gssd_dummy = rpc_mkpipe_data(&gssd_dummy_pipe_ops, 0); + if (IS_ERR(sn->gssd_dummy)) + return PTR_ERR(sn->gssd_dummy); + + mutex_init(&sn->pipefs_sb_lock); + sn->pipe_version = -1; + return 0; +} + +void rpc_pipefs_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_destroy_pipe_data(sn->gssd_dummy); +} + +/* + * This call will be used for per network namespace operations calls. + * Note: Function will be returned with pipefs_sb_lock taken if superblock was + * found. This lock have to be released by rpc_put_sb_net() when all operations + * will be completed. + */ +struct super_block *rpc_get_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb) + return sn->pipefs_sb; + mutex_unlock(&sn->pipefs_sb_lock); + return NULL; +} +EXPORT_SYMBOL_GPL(rpc_get_sb_net); + +void rpc_put_sb_net(const struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + WARN_ON(sn->pipefs_sb == NULL); + mutex_unlock(&sn->pipefs_sb_lock); +} +EXPORT_SYMBOL_GPL(rpc_put_sb_net); + +static const struct rpc_filelist gssd_dummy_clnt_dir[] = { + [0] = { + .name = "clntXX", + .mode = S_IFDIR | 0555, + }, +}; + +static ssize_t +dummy_downcall(struct file *filp, const char __user *src, size_t len) +{ + return -EINVAL; +} + +static const struct rpc_pipe_ops gssd_dummy_pipe_ops = { + .upcall = rpc_pipe_generic_upcall, + .downcall = dummy_downcall, +}; + +/* + * Here we present a bogus "info" file to keep rpc.gssd happy. We don't expect + * that it will ever use this info to handle an upcall, but rpc.gssd expects + * that this file will be there and have a certain format. + */ +static int +rpc_dummy_info_show(struct seq_file *m, void *v) +{ + seq_printf(m, "RPC server: %s\n", utsname()->nodename); + seq_printf(m, "service: foo (1) version 0\n"); + seq_printf(m, "address: 127.0.0.1\n"); + seq_printf(m, "protocol: tcp\n"); + seq_printf(m, "port: 0\n"); + return 0; +} +DEFINE_SHOW_ATTRIBUTE(rpc_dummy_info); + +static const struct rpc_filelist gssd_dummy_info_file[] = { + [0] = { + .name = "info", + .i_fop = &rpc_dummy_info_fops, + .mode = S_IFREG | 0400, + }, +}; + +/** + * rpc_gssd_dummy_populate - create a dummy gssd pipe + * @root: root of the rpc_pipefs filesystem + * @pipe_data: pipe data created when netns is initialized + * + * Create a dummy set of directories and a pipe that gssd can hold open to + * indicate that it is up and running. 
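+ *
+ * The resulting layout under @root looks roughly like this (derived from
+ * the helpers below):
+ *
+ *	gssd/
+ *		clntXX/
+ *			info	- canned contents from rpc_dummy_info_show()
+ *			gssd	- the pipe backed by @pipe_data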
+ */ +static struct dentry * +rpc_gssd_dummy_populate(struct dentry *root, struct rpc_pipe *pipe_data) +{ + int ret = 0; + struct dentry *gssd_dentry; + struct dentry *clnt_dentry = NULL; + struct dentry *pipe_dentry = NULL; + struct qstr q = QSTR_INIT(files[RPCAUTH_gssd].name, + strlen(files[RPCAUTH_gssd].name)); + + /* We should never get this far if "gssd" doesn't exist */ + gssd_dentry = d_hash_and_lookup(root, &q); + if (!gssd_dentry) + return ERR_PTR(-ENOENT); + + ret = rpc_populate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1, NULL); + if (ret) { + pipe_dentry = ERR_PTR(ret); + goto out; + } + + q.name = gssd_dummy_clnt_dir[0].name; + q.len = strlen(gssd_dummy_clnt_dir[0].name); + clnt_dentry = d_hash_and_lookup(gssd_dentry, &q); + if (!clnt_dentry) { + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + pipe_dentry = ERR_PTR(-ENOENT); + goto out; + } + + ret = rpc_populate(clnt_dentry, gssd_dummy_info_file, 0, 1, NULL); + if (ret) { + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + pipe_dentry = ERR_PTR(ret); + goto out; + } + + pipe_dentry = rpc_mkpipe_dentry(clnt_dentry, "gssd", NULL, pipe_data); + if (IS_ERR(pipe_dentry)) { + __rpc_depopulate(clnt_dentry, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dentry, gssd_dummy_clnt_dir, 0, 1); + } +out: + dput(clnt_dentry); + dput(gssd_dentry); + return pipe_dentry; +} + +static void +rpc_gssd_dummy_depopulate(struct dentry *pipe_dentry) +{ + struct dentry *clnt_dir = pipe_dentry->d_parent; + struct dentry *gssd_dir = clnt_dir->d_parent; + + dget(pipe_dentry); + __rpc_rmpipe(d_inode(clnt_dir), pipe_dentry); + __rpc_depopulate(clnt_dir, gssd_dummy_info_file, 0, 1); + __rpc_depopulate(gssd_dir, gssd_dummy_clnt_dir, 0, 1); + dput(pipe_dentry); +} + +static int +rpc_fill_super(struct super_block *sb, struct fs_context *fc) +{ + struct inode *inode; + struct dentry *root, *gssd_dentry; + struct net *net = sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + int err; + + sb->s_blocksize = PAGE_SIZE; + sb->s_blocksize_bits = PAGE_SHIFT; + sb->s_magic = RPCAUTH_GSSMAGIC; + sb->s_op = &s_ops; + sb->s_d_op = &simple_dentry_operations; + sb->s_time_gran = 1; + + inode = rpc_get_inode(sb, S_IFDIR | 0555); + sb->s_root = root = d_make_root(inode); + if (!root) + return -ENOMEM; + if (rpc_populate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF, NULL)) + return -ENOMEM; + + gssd_dentry = rpc_gssd_dummy_populate(root, sn->gssd_dummy); + if (IS_ERR(gssd_dentry)) { + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + return PTR_ERR(gssd_dentry); + } + + dprintk("RPC: sending pipefs MOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); + mutex_lock(&sn->pipefs_sb_lock); + sn->pipefs_sb = sb; + err = blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_MOUNT, + sb); + if (err) + goto err_depopulate; + mutex_unlock(&sn->pipefs_sb_lock); + return 0; + +err_depopulate: + rpc_gssd_dummy_depopulate(gssd_dentry); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + sn->pipefs_sb = NULL; + __rpc_depopulate(root, files, RPCAUTH_lockd, RPCAUTH_RootEOF); + mutex_unlock(&sn->pipefs_sb_lock); + return err; +} + +bool +gssd_running(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_pipe *pipe = sn->gssd_dummy; + + return pipe->nreaders || pipe->nwriters; +} +EXPORT_SYMBOL_GPL(gssd_running); + +static int rpc_fs_get_tree(struct fs_context *fc) +{ + return get_tree_keyed(fc, rpc_fill_super, get_net(fc->net_ns)); 
+} + +static void rpc_fs_free_fc(struct fs_context *fc) +{ + if (fc->s_fs_info) + put_net(fc->s_fs_info); +} + +static const struct fs_context_operations rpc_fs_context_ops = { + .free = rpc_fs_free_fc, + .get_tree = rpc_fs_get_tree, +}; + +static int rpc_init_fs_context(struct fs_context *fc) +{ + put_user_ns(fc->user_ns); + fc->user_ns = get_user_ns(fc->net_ns->user_ns); + fc->ops = &rpc_fs_context_ops; + return 0; +} + +static void rpc_kill_sb(struct super_block *sb) +{ + struct net *net = sb->s_fs_info; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + mutex_lock(&sn->pipefs_sb_lock); + if (sn->pipefs_sb != sb) { + mutex_unlock(&sn->pipefs_sb_lock); + goto out; + } + sn->pipefs_sb = NULL; + dprintk("RPC: sending pipefs UMOUNT notification for net %x%s\n", + net->ns.inum, NET_NAME(net)); + blocking_notifier_call_chain(&rpc_pipefs_notifier_list, + RPC_PIPEFS_UMOUNT, + sb); + mutex_unlock(&sn->pipefs_sb_lock); +out: + kill_litter_super(sb); + put_net(net); +} + +static struct file_system_type rpc_pipe_fs_type = { + .owner = THIS_MODULE, + .name = "rpc_pipefs", + .init_fs_context = rpc_init_fs_context, + .kill_sb = rpc_kill_sb, +}; +MODULE_ALIAS_FS("rpc_pipefs"); +MODULE_ALIAS("rpc_pipefs"); + +static void +init_once(void *foo) +{ + struct rpc_inode *rpci = (struct rpc_inode *) foo; + + inode_init_once(&rpci->vfs_inode); + rpci->private = NULL; + rpci->pipe = NULL; + init_waitqueue_head(&rpci->waitq); +} + +int register_rpc_pipefs(void) +{ + int err; + + rpc_inode_cachep = kmem_cache_create("rpc_inode_cache", + sizeof(struct rpc_inode), + 0, (SLAB_HWCACHE_ALIGN|SLAB_RECLAIM_ACCOUNT| + SLAB_MEM_SPREAD|SLAB_ACCOUNT), + init_once); + if (!rpc_inode_cachep) + return -ENOMEM; + err = rpc_clients_notifier_register(); + if (err) + goto err_notifier; + err = register_filesystem(&rpc_pipe_fs_type); + if (err) + goto err_register; + return 0; + +err_register: + rpc_clients_notifier_unregister(); +err_notifier: + kmem_cache_destroy(rpc_inode_cachep); + return err; +} + +void unregister_rpc_pipefs(void) +{ + rpc_clients_notifier_unregister(); + unregister_filesystem(&rpc_pipe_fs_type); + kmem_cache_destroy(rpc_inode_cachep); +} diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c new file mode 100644 index 000000000..82afb5669 --- /dev/null +++ b/net/sunrpc/rpcb_clnt.c @@ -0,0 +1,1098 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * In-kernel rpcbind client supporting versions 2, 3, and 4 of the rpcbind + * protocol + * + * Based on RFC 1833: "Binding Protocols for ONC RPC Version 2" and + * RFC 3530: "Network File System (NFS) version 4 Protocol" + * + * Original: Gilles Quillard, Bull Open Source, 2005 + * Updated: Chuck Lever, Oracle Corporation, 2007 + * + * Descended from net/sunrpc/pmap_clnt.c, + * Copyright (C) 1996, Olaf Kirch + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include "netns.h" + +#define RPCBIND_SOCK_PATHNAME "/var/run/rpcbind.sock" + +#define RPCBIND_PROGRAM (100000u) +#define RPCBIND_PORT (111u) + +#define RPCBVERS_2 (2u) +#define RPCBVERS_3 (3u) +#define RPCBVERS_4 (4u) + +enum { + RPCBPROC_NULL, + RPCBPROC_SET, + RPCBPROC_UNSET, + RPCBPROC_GETPORT, + RPCBPROC_GETADDR = 3, /* alias for GETPORT */ + RPCBPROC_DUMP, + RPCBPROC_CALLIT, + RPCBPROC_BCAST = 5, /* alias for CALLIT */ + RPCBPROC_GETTIME, + RPCBPROC_UADDR2TADDR, + RPCBPROC_TADDR2UADDR, + RPCBPROC_GETVERSADDR, + RPCBPROC_INDIRECT, + RPCBPROC_GETADDRLIST, + 
RPCBPROC_GETSTAT, +}; + +/* + * r_owner + * + * The "owner" is allowed to unset a service in the rpcbind database. + * + * For AF_LOCAL SET/UNSET requests, rpcbind treats this string as a + * UID which it maps to a local user name via a password lookup. + * In all other cases it is ignored. + * + * For SET/UNSET requests, user space provides a value, even for + * network requests, and GETADDR uses an empty string. We follow + * those precedents here. + */ +#define RPCB_OWNER_STRING "0" +#define RPCB_MAXOWNERLEN sizeof(RPCB_OWNER_STRING) + +/* + * XDR data type sizes + */ +#define RPCB_program_sz (1) +#define RPCB_version_sz (1) +#define RPCB_protocol_sz (1) +#define RPCB_port_sz (1) +#define RPCB_boolean_sz (1) + +#define RPCB_netid_sz (1 + XDR_QUADLEN(RPCBIND_MAXNETIDLEN)) +#define RPCB_addr_sz (1 + XDR_QUADLEN(RPCBIND_MAXUADDRLEN)) +#define RPCB_ownerstring_sz (1 + XDR_QUADLEN(RPCB_MAXOWNERLEN)) + +/* + * XDR argument and result sizes + */ +#define RPCB_mappingargs_sz (RPCB_program_sz + RPCB_version_sz + \ + RPCB_protocol_sz + RPCB_port_sz) +#define RPCB_getaddrargs_sz (RPCB_program_sz + RPCB_version_sz + \ + RPCB_netid_sz + RPCB_addr_sz + \ + RPCB_ownerstring_sz) + +#define RPCB_getportres_sz RPCB_port_sz +#define RPCB_setres_sz RPCB_boolean_sz + +/* + * Note that RFC 1833 does not put any size restrictions on the + * address string returned by the remote rpcbind database. + */ +#define RPCB_getaddrres_sz RPCB_addr_sz + +static void rpcb_getport_done(struct rpc_task *, void *); +static void rpcb_map_release(void *data); +static const struct rpc_program rpcb_program; + +struct rpcbind_args { + struct rpc_xprt * r_xprt; + + u32 r_prog; + u32 r_vers; + u32 r_prot; + unsigned short r_port; + const char * r_netid; + const char * r_addr; + const char * r_owner; + + int r_status; +}; + +static const struct rpc_procinfo rpcb_procedures2[]; +static const struct rpc_procinfo rpcb_procedures3[]; +static const struct rpc_procinfo rpcb_procedures4[]; + +struct rpcb_info { + u32 rpc_vers; + const struct rpc_procinfo *rpc_proc; +}; + +static const struct rpcb_info rpcb_next_version[]; +static const struct rpcb_info rpcb_next_version6[]; + +static const struct rpc_call_ops rpcb_getport_ops = { + .rpc_call_done = rpcb_getport_done, + .rpc_release = rpcb_map_release, +}; + +static void rpcb_wake_rpcbind_waiters(struct rpc_xprt *xprt, int status) +{ + xprt_clear_binding(xprt); + rpc_wake_up_status(&xprt->binding, status); +} + +static void rpcb_map_release(void *data) +{ + struct rpcbind_args *map = data; + + rpcb_wake_rpcbind_waiters(map->r_xprt, map->r_status); + xprt_put(map->r_xprt); + kfree(map->r_addr); + kfree(map); +} + +static int rpcb_get_local(struct net *net) +{ + int cnt; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) + sn->rpcb_users++; + cnt = sn->rpcb_users; + spin_unlock(&sn->rpcb_clnt_lock); + + return cnt; +} + +void rpcb_put_local(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct rpc_clnt *clnt = sn->rpcb_local_clnt; + struct rpc_clnt *clnt4 = sn->rpcb_local_clnt4; + int shutdown = 0; + + spin_lock(&sn->rpcb_clnt_lock); + if (sn->rpcb_users) { + if (--sn->rpcb_users == 0) { + sn->rpcb_local_clnt = NULL; + sn->rpcb_local_clnt4 = NULL; + } + shutdown = !sn->rpcb_users; + } + spin_unlock(&sn->rpcb_clnt_lock); + + if (shutdown) { + /* + * cleanup_rpcb_clnt - remove xprtsock's sysctls, unregister + */ + if (clnt4) + rpc_shutdown_client(clnt4); + if (clnt) + rpc_shutdown_client(clnt); 
+ } +} + +static void rpcb_set_local(struct net *net, struct rpc_clnt *clnt, + struct rpc_clnt *clnt4, + bool is_af_local) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + /* Protected by rpcb_create_local_mutex */ + sn->rpcb_local_clnt = clnt; + sn->rpcb_local_clnt4 = clnt4; + sn->rpcb_is_af_local = is_af_local ? 1 : 0; + smp_wmb(); + sn->rpcb_users = 1; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +static int rpcb_create_local_unix(struct net *net) +{ + static const struct sockaddr_un rpcb_localaddr_rpcbind = { + .sun_family = AF_LOCAL, + .sun_path = RPCBIND_SOCK_PATHNAME, + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_LOCAL, + .address = (struct sockaddr *)&rpcb_localaddr_rpcbind, + .addrsize = sizeof(rpcb_localaddr_rpcbind), + .servername = "localhost", + .program = &rpcb_program, + .version = RPCBVERS_2, + .authflavor = RPC_AUTH_NULL, + .cred = current_cred(), + /* + * We turn off the idle timeout to prevent the kernel + * from automatically disconnecting the socket. + * Otherwise, we'd have to cache the mount namespace + * of the caller and somehow pass that to the socket + * reconnect code. + */ + .flags = RPC_CLNT_CREATE_NO_IDLE_TIMEOUT, + }; + struct rpc_clnt *clnt, *clnt4; + int result = 0; + + /* + * Because we requested an RPC PING at transport creation time, + * this works only if the user space portmapper is rpcbind, and + * it's listening on AF_LOCAL on the named socket. + */ + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + result = PTR_ERR(clnt); + goto out; + } + + clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); + if (IS_ERR(clnt4)) + clnt4 = NULL; + + rpcb_set_local(net, clnt, clnt4, true); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. + */ +static int rpcb_create_local_net(struct net *net) +{ + static const struct sockaddr_in rpcb_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_LOOPBACK), + .sin_port = htons(RPCBIND_PORT), + }; + struct rpc_create_args args = { + .net = net, + .protocol = XPRT_TRANSPORT_TCP, + .address = (struct sockaddr *)&rpcb_inaddr_loopback, + .addrsize = sizeof(rpcb_inaddr_loopback), + .servername = "localhost", + .program = &rpcb_program, + .version = RPCBVERS_2, + .authflavor = RPC_AUTH_UNIX, + .cred = current_cred(), + .flags = RPC_CLNT_CREATE_NOPING, + }; + struct rpc_clnt *clnt, *clnt4; + int result = 0; + + clnt = rpc_create(&args); + if (IS_ERR(clnt)) { + result = PTR_ERR(clnt); + goto out; + } + + /* + * This results in an RPC ping. On systems running portmapper, + * the v4 ping will fail. Proceed anyway, but disallow rpcb + * v4 upcalls. + */ + clnt4 = rpc_bind_new_program(clnt, &rpcb_program, RPCBVERS_4); + if (IS_ERR(clnt4)) + clnt4 = NULL; + + rpcb_set_local(net, clnt, clnt4, false); + +out: + return result; +} + +/* + * Returns zero on success, otherwise a negative errno value + * is returned. 
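+ *
+ * Callers pair this with rpcb_put_local() once they no longer need the
+ * local rpcbind clients, roughly:
+ *
+ *	error = rpcb_create_local(net);
+ *	if (error)
+ *		return error;
+ *	...
+ *	rpcb_put_local(net);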
+ */ +int rpcb_create_local(struct net *net) +{ + static DEFINE_MUTEX(rpcb_create_local_mutex); + int result = 0; + + if (rpcb_get_local(net)) + return result; + + mutex_lock(&rpcb_create_local_mutex); + if (rpcb_get_local(net)) + goto out; + + if (rpcb_create_local_unix(net) != 0) + result = rpcb_create_local_net(net); + +out: + mutex_unlock(&rpcb_create_local_mutex); + return result; +} + +static struct rpc_clnt *rpcb_create(struct net *net, const char *nodename, + const char *hostname, + struct sockaddr *srvaddr, size_t salen, + int proto, u32 version, + const struct cred *cred, + const struct rpc_timeout *timeo) +{ + struct rpc_create_args args = { + .net = net, + .protocol = proto, + .address = srvaddr, + .addrsize = salen, + .timeout = timeo, + .servername = hostname, + .nodename = nodename, + .program = &rpcb_program, + .version = version, + .authflavor = RPC_AUTH_UNIX, + .cred = cred, + .flags = (RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_NONPRIVPORT), + }; + + switch (srvaddr->sa_family) { + case AF_INET: + ((struct sockaddr_in *)srvaddr)->sin_port = htons(RPCBIND_PORT); + break; + case AF_INET6: + ((struct sockaddr_in6 *)srvaddr)->sin6_port = htons(RPCBIND_PORT); + break; + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + return rpc_create(&args); +} + +static int rpcb_register_call(struct sunrpc_net *sn, struct rpc_clnt *clnt, struct rpc_message *msg, bool is_set) +{ + int flags = RPC_TASK_NOCONNECT; + int error, result = 0; + + if (is_set || !sn->rpcb_is_af_local) + flags = RPC_TASK_SOFTCONN; + msg->rpc_resp = &result; + + error = rpc_call_sync(clnt, msg, flags); + if (error < 0) + return error; + + if (!result) + return -EACCES; + return 0; +} + +/** + * rpcb_register - set or unset a port registration with the local rpcbind svc + * @net: target network namespace + * @prog: RPC program number to bind + * @vers: RPC version number to bind + * @prot: transport protocol to register + * @port: port value to register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, transport] + * tuple they wish to advertise. + * + * Callers may also unregister RPC services that are no longer + * available by setting the passed-in port to zero. This removes + * all registered transports for [program, version] from the local + * rpcbind database. + * + * This function uses rpcbind protocol version 2 to contact the + * local rpcbind daemon. + * + * Registration works over both AF_INET and AF_INET6, and services + * registered via this function are advertised as available for any + * address. If the local rpcbind daemon is listening on AF_INET6, + * services registered via this function will be advertised on + * IN6ADDR_ANY (ie available for all AF_INET and AF_INET6 + * addresses). 
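+ *
+ * For example (the program, version and port numbers are illustrative),
+ * an NFSv3 service listening on TCP port 2049 could be advertised with:
+ *
+ *	error = rpcb_register(net, 100003, 3, IPPROTO_TCP, 2049);
+ *
+ * and later withdrawn by repeating the call with a port of zero.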
+ */ +int rpcb_register(struct net *net, u32 prog, u32 vers, int prot, unsigned short port) +{ + struct rpcbind_args map = { + .r_prog = prog, + .r_vers = vers, + .r_prot = prot, + .r_port = port, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + bool is_set = false; + + trace_pmap_register(prog, vers, prot, port); + + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET]; + if (port != 0) { + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET]; + is_set = true; + } + + return rpcb_register_call(sn, sn->rpcb_local_clnt, &msg, is_set); +} + +/* + * Fill in AF_INET family-specific arguments to register + */ +static int rpcb_register_inet4(struct sunrpc_net *sn, + const struct sockaddr *sap, + struct rpc_message *msg) +{ + const struct sockaddr_in *sin = (const struct sockaddr_in *)sap; + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(sin->sin_port); + bool is_set = false; + int result; + + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port != 0) { + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } + + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); + kfree(map->r_addr); + return result; +} + +/* + * Fill in AF_INET6 family-specific arguments to register + */ +static int rpcb_register_inet6(struct sunrpc_net *sn, + const struct sockaddr *sap, + struct rpc_message *msg) +{ + const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sap; + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(sin6->sin6_port); + bool is_set = false; + int result; + + map->r_addr = rpc_sockaddr2uaddr(sap, GFP_KERNEL); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port != 0) { + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + is_set = true; + } + + result = rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, is_set); + kfree(map->r_addr); + return result; +} + +static int rpcb_unregister_all_protofamilies(struct sunrpc_net *sn, + struct rpc_message *msg) +{ + struct rpcbind_args *map = msg->rpc_argp; + + trace_rpcb_unregister(map->r_prog, map->r_vers, map->r_netid); + + map->r_addr = ""; + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + + return rpcb_register_call(sn, sn->rpcb_local_clnt4, msg, false); +} + +/** + * rpcb_v4_register - set or unset a port registration with the local rpcbind + * @net: target network namespace + * @program: RPC program number of service to (un)register + * @version: RPC version number of service to (un)register + * @address: address family, IP address, and port to (un)register + * @netid: netid of transport protocol to (un)register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, address, + * netid] tuple they wish to advertise. + * + * Callers may also unregister RPC services that are registered at a + * specific address by setting the port number in @address to zero. + * They may unregister all registered protocol families at once for + * a service by passing a NULL @address argument. 
If @netid is "" + * then all netids for [program, version, address] are unregistered. + * + * This function uses rpcbind protocol version 4 to contact the + * local rpcbind daemon. The local rpcbind daemon must support + * version 4 of the rpcbind protocol in order for these functions + * to register a service successfully. + * + * Supported netids include "udp" and "tcp" for UDP and TCP over + * IPv4, and "udp6" and "tcp6" for UDP and TCP over IPv6, + * respectively. + * + * The contents of @address determine the address family and the + * port to be registered. The usual practice is to pass INADDR_ANY + * as the raw address, but specifying a non-zero address is also + * supported by this API if the caller wishes to advertise an RPC + * service on a specific network interface. + * + * Note that passing in INADDR_ANY does not create the same service + * registration as IN6ADDR_ANY. The former advertises an RPC + * service on any IPv4 address, but not on IPv6. The latter + * advertises the service on all IPv4 and IPv6 addresses. + */ +int rpcb_v4_register(struct net *net, const u32 program, const u32 version, + const struct sockaddr *address, const char *netid) +{ + struct rpcbind_args map = { + .r_prog = program, + .r_vers = version, + .r_netid = netid, + .r_owner = RPCB_OWNER_STRING, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + if (sn->rpcb_local_clnt4 == NULL) + return -EPROTONOSUPPORT; + + if (address == NULL) + return rpcb_unregister_all_protofamilies(sn, &msg); + + trace_rpcb_register(map.r_prog, map.r_vers, map.r_addr, map.r_netid); + + switch (address->sa_family) { + case AF_INET: + return rpcb_register_inet4(sn, address, &msg); + case AF_INET6: + return rpcb_register_inet6(sn, address, &msg); + } + + return -EAFNOSUPPORT; +} + +static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, + struct rpcbind_args *map, const struct rpc_procinfo *proc) +{ + struct rpc_message msg = { + .rpc_proc = proc, + .rpc_argp = map, + .rpc_resp = map, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = rpcb_clnt, + .rpc_message = &msg, + .callback_ops = &rpcb_getport_ops, + .callback_data = map, + .flags = RPC_TASK_ASYNC | RPC_TASK_SOFTCONN, + }; + + return rpc_run_task(&task_setup_data); +} + +/* + * In the case where rpc clients have been cloned, we want to make + * sure that we use the program number/version etc of the actual + * owner of the xprt. To do so, we walk back up the tree of parents + * to find whoever created the transport and/or whoever has the + * autobind flag set. + */ +static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) +{ + struct rpc_clnt *parent = clnt->cl_parent; + struct rpc_xprt_switch *xps = rcu_access_pointer(clnt->cl_xpi.xpi_xpswitch); + + while (parent != clnt) { + if (rcu_access_pointer(parent->cl_xpi.xpi_xpswitch) != xps) + break; + if (clnt->cl_autobind) + break; + clnt = parent; + parent = parent->cl_parent; + } + return clnt; +} + +/** + * rpcb_getport_async - obtain the port for a given RPC service on a given host + * @task: task that is waiting for portmapper request + * + * This one can be called for an ongoing RPC request, and can be used in + * an async (rpciod) context. 
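+ * In practice it is normally reached through the transport's ->rpcbind()
+ * method during the client's bind step rather than being called directly;
+ * that wiring lives outside this file.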
+ */ +void rpcb_getport_async(struct rpc_task *task) +{ + struct rpc_clnt *clnt; + const struct rpc_procinfo *proc; + u32 bind_version; + struct rpc_xprt *xprt; + struct rpc_clnt *rpcb_clnt; + struct rpcbind_args *map; + struct rpc_task *child; + struct sockaddr_storage addr; + struct sockaddr *sap = (struct sockaddr *)&addr; + size_t salen; + int status; + + rcu_read_lock(); + clnt = rpcb_find_transport_owner(task->tk_client); + rcu_read_unlock(); + xprt = xprt_get(task->tk_xprt); + + /* Put self on the wait queue to ensure we get notified if + * some other task is already attempting to bind the port */ + rpc_sleep_on_timeout(&xprt->binding, task, + NULL, jiffies + xprt->bind_timeout); + + if (xprt_test_and_set_binding(xprt)) { + xprt_put(xprt); + return; + } + + /* Someone else may have bound if we slept */ + if (xprt_bound(xprt)) { + status = 0; + goto bailout_nofree; + } + + /* Parent transport's destination address */ + salen = rpc_peeraddr(clnt, sap, sizeof(addr)); + + /* Don't ever use rpcbind v2 for AF_INET6 requests */ + switch (sap->sa_family) { + case AF_INET: + proc = rpcb_next_version[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version[xprt->bind_index].rpc_vers; + break; + case AF_INET6: + proc = rpcb_next_version6[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version6[xprt->bind_index].rpc_vers; + break; + default: + status = -EAFNOSUPPORT; + goto bailout_nofree; + } + if (proc == NULL) { + xprt->bind_index = 0; + status = -EPFNOSUPPORT; + goto bailout_nofree; + } + + trace_rpcb_getport(clnt, task, bind_version); + + rpcb_clnt = rpcb_create(xprt->xprt_net, + clnt->cl_nodename, + xprt->servername, sap, salen, + xprt->prot, bind_version, + clnt->cl_cred, + task->tk_client->cl_timeout); + if (IS_ERR(rpcb_clnt)) { + status = PTR_ERR(rpcb_clnt); + goto bailout_nofree; + } + + map = kzalloc(sizeof(struct rpcbind_args), rpc_task_gfp_mask()); + if (!map) { + status = -ENOMEM; + goto bailout_release_client; + } + map->r_prog = clnt->cl_prog; + map->r_vers = clnt->cl_vers; + map->r_prot = xprt->prot; + map->r_port = 0; + map->r_xprt = xprt; + map->r_status = -EIO; + + switch (bind_version) { + case RPCBVERS_4: + case RPCBVERS_3: + map->r_netid = xprt->address_strings[RPC_DISPLAY_NETID]; + map->r_addr = rpc_sockaddr2uaddr(sap, rpc_task_gfp_mask()); + if (!map->r_addr) { + status = -ENOMEM; + goto bailout_free_args; + } + map->r_owner = ""; + break; + case RPCBVERS_2: + map->r_addr = NULL; + break; + default: + BUG(); + } + + child = rpcb_call_async(rpcb_clnt, map, proc); + rpc_release_client(rpcb_clnt); + if (IS_ERR(child)) { + /* rpcb_map_release() has freed the arguments */ + return; + } + + xprt->stat.bind_count++; + rpc_put_task(child); + return; + +bailout_free_args: + kfree(map); +bailout_release_client: + rpc_release_client(rpcb_clnt); +bailout_nofree: + rpcb_wake_rpcbind_waiters(xprt, status); + task->tk_status = status; + xprt_put(xprt); +} +EXPORT_SYMBOL_GPL(rpcb_getport_async); + +/* + * Rpcbind child task calls this callback via tk_exit. + */ +static void rpcb_getport_done(struct rpc_task *child, void *data) +{ + struct rpcbind_args *map = data; + struct rpc_xprt *xprt = map->r_xprt; + + map->r_status = child->tk_status; + + /* Garbage reply: retry with a lesser rpcbind version */ + if (map->r_status == -EIO) + map->r_status = -EPROTONOSUPPORT; + + /* rpcbind server doesn't support this rpcbind protocol version */ + if (map->r_status == -EPROTONOSUPPORT) + xprt->bind_index++; + + if (map->r_status < 0) { + /* rpcbind server not available on remote host? 
*/ + map->r_port = 0; + + } else if (map->r_port == 0) { + /* Requested RPC service wasn't registered on remote host */ + map->r_status = -EACCES; + } else { + /* Succeeded */ + map->r_status = 0; + } + + trace_rpcb_setport(child, map->r_status, map->r_port); + xprt->ops->set_port(xprt, map->r_port); + if (map->r_port) + xprt_set_bound(xprt); +} + +/* + * XDR functions for rpcbind + */ + +static void rpcb_enc_mapping(struct rpc_rqst *req, struct xdr_stream *xdr, + const void *data) +{ + const struct rpcbind_args *rpcb = data; + __be32 *p; + + p = xdr_reserve_space(xdr, RPCB_mappingargs_sz << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p++ = cpu_to_be32(rpcb->r_vers); + *p++ = cpu_to_be32(rpcb->r_prot); + *p = cpu_to_be32(rpcb->r_port); +} + +static int rpcb_dec_getport(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + struct rpcbind_args *rpcb = data; + unsigned long port; + __be32 *p; + + rpcb->r_port = 0; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -EIO; + + port = be32_to_cpup(p); + if (unlikely(port > USHRT_MAX)) + return -EIO; + + rpcb->r_port = port; + return 0; +} + +static int rpcb_dec_set(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + unsigned int *boolp = data; + __be32 *p; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + return -EIO; + + *boolp = 0; + if (*p != xdr_zero) + *boolp = 1; + return 0; +} + +static void encode_rpcb_string(struct xdr_stream *xdr, const char *string, + const u32 maxstrlen) +{ + __be32 *p; + u32 len; + + len = strlen(string); + WARN_ON_ONCE(len > maxstrlen); + if (len > maxstrlen) + /* truncate and hope for the best */ + len = maxstrlen; + p = xdr_reserve_space(xdr, 4 + len); + xdr_encode_opaque(p, string, len); +} + +static void rpcb_enc_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, + const void *data) +{ + const struct rpcbind_args *rpcb = data; + __be32 *p; + + p = xdr_reserve_space(xdr, (RPCB_program_sz + RPCB_version_sz) << 2); + *p++ = cpu_to_be32(rpcb->r_prog); + *p = cpu_to_be32(rpcb->r_vers); + + encode_rpcb_string(xdr, rpcb->r_netid, RPCBIND_MAXNETIDLEN); + encode_rpcb_string(xdr, rpcb->r_addr, RPCBIND_MAXUADDRLEN); + encode_rpcb_string(xdr, rpcb->r_owner, RPCB_MAXOWNERLEN); +} + +static int rpcb_dec_getaddr(struct rpc_rqst *req, struct xdr_stream *xdr, + void *data) +{ + struct rpcbind_args *rpcb = data; + struct sockaddr_storage address; + struct sockaddr *sap = (struct sockaddr *)&address; + __be32 *p; + u32 len; + + rpcb->r_port = 0; + + p = xdr_inline_decode(xdr, 4); + if (unlikely(p == NULL)) + goto out_fail; + len = be32_to_cpup(p); + + /* + * If the returned universal address is a null string, + * the requested RPC service was not registered. + */ + if (len == 0) + return 0; + + if (unlikely(len > RPCBIND_MAXUADDRLEN)) + goto out_fail; + + p = xdr_inline_decode(xdr, len); + if (unlikely(p == NULL)) + goto out_fail; + + if (rpc_uaddr2sockaddr(req->rq_xprt->xprt_net, (char *)p, len, + sap, sizeof(address)) == 0) + goto out_fail; + rpcb->r_port = rpc_get_port(sap); + + return 0; + +out_fail: + return -EIO; +} + +/* + * Not all rpcbind procedures described in RFC 1833 are implemented + * since the Linux kernel RPC code requires only these. 
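+ * Only SET, UNSET and GETPORT (version 2) or GETADDR (versions 3 and 4)
+ * appear in the tables below; DUMP, CALLIT, GETTIME and the remaining
+ * procedures have no in-kernel users.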
+ */ + +static const struct rpc_procinfo rpcb_procedures2[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETPORT] = { + .p_proc = RPCBPROC_GETPORT, + .p_encode = rpcb_enc_mapping, + .p_decode = rpcb_dec_getport, + .p_arglen = RPCB_mappingargs_sz, + .p_replen = RPCB_getportres_sz, + .p_statidx = RPCBPROC_GETPORT, + .p_timer = 0, + .p_name = "GETPORT", + }, +}; + +static const struct rpc_procinfo rpcb_procedures3[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETADDR] = { + .p_proc = RPCBPROC_GETADDR, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_getaddrres_sz, + .p_statidx = RPCBPROC_GETADDR, + .p_timer = 0, + .p_name = "GETADDR", + }, +}; + +static const struct rpc_procinfo rpcb_procedures4[] = { + [RPCBPROC_SET] = { + .p_proc = RPCBPROC_SET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_SET, + .p_timer = 0, + .p_name = "SET", + }, + [RPCBPROC_UNSET] = { + .p_proc = RPCBPROC_UNSET, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_set, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_setres_sz, + .p_statidx = RPCBPROC_UNSET, + .p_timer = 0, + .p_name = "UNSET", + }, + [RPCBPROC_GETADDR] = { + .p_proc = RPCBPROC_GETADDR, + .p_encode = rpcb_enc_getaddr, + .p_decode = rpcb_dec_getaddr, + .p_arglen = RPCB_getaddrargs_sz, + .p_replen = RPCB_getaddrres_sz, + .p_statidx = RPCBPROC_GETADDR, + .p_timer = 0, + .p_name = "GETADDR", + }, +}; + +static const struct rpcb_info rpcb_next_version[] = { + { + .rpc_vers = RPCBVERS_2, + .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], + }, + { + .rpc_proc = NULL, + }, +}; + +static const struct rpcb_info rpcb_next_version6[] = { + { + .rpc_vers = RPCBVERS_4, + .rpc_proc = &rpcb_procedures4[RPCBPROC_GETADDR], + }, + { + .rpc_vers = RPCBVERS_3, + .rpc_proc = &rpcb_procedures3[RPCBPROC_GETADDR], + }, + { + .rpc_proc = NULL, + }, +}; + +static unsigned int rpcb_version2_counts[ARRAY_SIZE(rpcb_procedures2)]; +static const struct rpc_version rpcb_version2 = { + .number = RPCBVERS_2, + .nrprocs = ARRAY_SIZE(rpcb_procedures2), + .procs = rpcb_procedures2, + .counts = rpcb_version2_counts, +}; + +static unsigned int rpcb_version3_counts[ARRAY_SIZE(rpcb_procedures3)]; +static const struct rpc_version rpcb_version3 = { + .number = RPCBVERS_3, + .nrprocs = ARRAY_SIZE(rpcb_procedures3), + .procs = rpcb_procedures3, + .counts = rpcb_version3_counts, +}; + +static unsigned int rpcb_version4_counts[ARRAY_SIZE(rpcb_procedures4)]; +static const struct rpc_version rpcb_version4 = { + .number 
= RPCBVERS_4, + .nrprocs = ARRAY_SIZE(rpcb_procedures4), + .procs = rpcb_procedures4, + .counts = rpcb_version4_counts, +}; + +static const struct rpc_version *rpcb_version[] = { + NULL, + NULL, + &rpcb_version2, + &rpcb_version3, + &rpcb_version4 +}; + +static struct rpc_stat rpcb_stats; + +static const struct rpc_program rpcb_program = { + .name = "rpcbind", + .number = RPCBIND_PROGRAM, + .nrvers = ARRAY_SIZE(rpcb_version), + .version = rpcb_version, + .stats = &rpcb_stats, +}; diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c new file mode 100644 index 000000000..6debf4fd4 --- /dev/null +++ b/net/sunrpc/sched.c @@ -0,0 +1,1361 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/sched.c + * + * Scheduling for synchronous and asynchronous RPC requests. + * + * Copyright (C) 1996 Olaf Kirch, + * + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland + */ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "sunrpc.h" + +#define CREATE_TRACE_POINTS +#include + +/* + * RPC slabs and memory pools + */ +#define RPC_BUFFER_MAXSIZE (2048) +#define RPC_BUFFER_POOLSIZE (8) +#define RPC_TASK_POOLSIZE (8) +static struct kmem_cache *rpc_task_slabp __read_mostly; +static struct kmem_cache *rpc_buffer_slabp __read_mostly; +static mempool_t *rpc_task_mempool __read_mostly; +static mempool_t *rpc_buffer_mempool __read_mostly; + +static void rpc_async_schedule(struct work_struct *); +static void rpc_release_task(struct rpc_task *task); +static void __rpc_queue_timer_fn(struct work_struct *); + +/* + * RPC tasks sit here while waiting for conditions to improve. + */ +static struct rpc_wait_queue delay_queue; + +/* + * rpciod-related stuff + */ +struct workqueue_struct *rpciod_workqueue __read_mostly; +struct workqueue_struct *xprtiod_workqueue __read_mostly; +EXPORT_SYMBOL_GPL(xprtiod_workqueue); + +gfp_t rpc_task_gfp_mask(void) +{ + if (current->flags & PF_WQ_WORKER) + return GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN; + return GFP_KERNEL; +} +EXPORT_SYMBOL_GPL(rpc_task_gfp_mask); + +bool rpc_task_set_rpc_status(struct rpc_task *task, int rpc_status) +{ + if (cmpxchg(&task->tk_rpc_status, 0, rpc_status) == 0) + return true; + return false; +} + +unsigned long +rpc_task_timeout(const struct rpc_task *task) +{ + unsigned long timeout = READ_ONCE(task->tk_timeout); + + if (timeout != 0) { + unsigned long now = jiffies; + if (time_before(now, timeout)) + return timeout - now; + } + return 0; +} +EXPORT_SYMBOL_GPL(rpc_task_timeout); + +/* + * Disable the timer for a given RPC task. Should be called with + * queue->lock and bh_disabled in order to avoid races within + * rpc_run_timer(). + */ +static void +__rpc_disable_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (list_empty(&task->u.tk_wait.timer_list)) + return; + task->tk_timeout = 0; + list_del(&task->u.tk_wait.timer_list); + if (list_empty(&queue->timer_list.list)) + cancel_delayed_work(&queue->timer_list.dwork); +} + +static void +rpc_set_queue_timer(struct rpc_wait_queue *queue, unsigned long expires) +{ + unsigned long now = jiffies; + queue->timer_list.expires = expires; + if (time_before_eq(expires, now)) + expires = 0; + else + expires -= now; + mod_delayed_work(rpciod_workqueue, &queue->timer_list.dwork, expires); +} + +/* + * Set up a timer for the current task. 
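+ * The queue keeps a single delayed work item (__rpc_queue_timer_fn) armed
+ * for the earliest expiry on its timer_list; any queued task whose
+ * deadline has passed is woken with -ETIMEDOUT.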
+ */ +static void +__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task, + unsigned long timeout) +{ + task->tk_timeout = timeout; + if (list_empty(&queue->timer_list.list) || time_before(timeout, queue->timer_list.expires)) + rpc_set_queue_timer(queue, timeout); + list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); +} + +static void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) +{ + if (queue->priority != priority) { + queue->priority = priority; + queue->nr = 1U << priority; + } +} + +static void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) +{ + rpc_set_waitqueue_priority(queue, queue->maxpriority); +} + +/* + * Add a request to a queue list + */ +static void +__rpc_list_enqueue_task(struct list_head *q, struct rpc_task *task) +{ + struct rpc_task *t; + + list_for_each_entry(t, q, u.tk_wait.list) { + if (t->tk_owner == task->tk_owner) { + list_add_tail(&task->u.tk_wait.links, + &t->u.tk_wait.links); + /* Cache the queue head in task->u.tk_wait.list */ + task->u.tk_wait.list.next = q; + task->u.tk_wait.list.prev = NULL; + return; + } + } + INIT_LIST_HEAD(&task->u.tk_wait.links); + list_add_tail(&task->u.tk_wait.list, q); +} + +/* + * Remove request from a queue list + */ +static void +__rpc_list_dequeue_task(struct rpc_task *task) +{ + struct list_head *q; + struct rpc_task *t; + + if (task->u.tk_wait.list.prev == NULL) { + list_del(&task->u.tk_wait.links); + return; + } + if (!list_empty(&task->u.tk_wait.links)) { + t = list_first_entry(&task->u.tk_wait.links, + struct rpc_task, + u.tk_wait.links); + /* Assume __rpc_list_enqueue_task() cached the queue head */ + q = t->u.tk_wait.list.next; + list_add_tail(&t->u.tk_wait.list, q); + list_del(&task->u.tk_wait.links); + } + list_del(&task->u.tk_wait.list); +} + +/* + * Add new request to a priority queue. + */ +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (unlikely(queue_priority > queue->maxpriority)) + queue_priority = queue->maxpriority; + __rpc_list_enqueue_task(&queue->tasks[queue_priority], task); +} + +/* + * Add new request to wait queue. + */ +static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, + struct rpc_task *task, + unsigned char queue_priority) +{ + INIT_LIST_HEAD(&task->u.tk_wait.timer_list); + if (RPC_IS_PRIORITY(queue)) + __rpc_add_wait_queue_priority(queue, task, queue_priority); + else + list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]); + task->tk_waitqueue = queue; + queue->qlen++; + /* barrier matches the read in rpc_wake_up_task_queue_locked() */ + smp_wmb(); + rpc_set_queued(task); +} + +/* + * Remove request from a priority queue. + */ +static void __rpc_remove_wait_queue_priority(struct rpc_task *task) +{ + __rpc_list_dequeue_task(task); +} + +/* + * Remove request from queue. + * Note: must be called with spin lock held. 
+ */ +static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + __rpc_disable_timer(queue, task); + if (RPC_IS_PRIORITY(queue)) + __rpc_remove_wait_queue_priority(task); + else + list_del(&task->u.tk_wait.list); + queue->qlen--; +} + +static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues) +{ + int i; + + spin_lock_init(&queue->lock); + for (i = 0; i < ARRAY_SIZE(queue->tasks); i++) + INIT_LIST_HEAD(&queue->tasks[i]); + queue->maxpriority = nr_queues - 1; + rpc_reset_waitqueue_priority(queue); + queue->qlen = 0; + queue->timer_list.expires = 0; + INIT_DELAYED_WORK(&queue->timer_list.dwork, __rpc_queue_timer_fn); + INIT_LIST_HEAD(&queue->timer_list.list); + rpc_assign_waitqueue_name(queue, qname); +} + +void rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, RPC_NR_PRIORITY); +} +EXPORT_SYMBOL_GPL(rpc_init_priority_wait_queue); + +void rpc_init_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, 1); +} +EXPORT_SYMBOL_GPL(rpc_init_wait_queue); + +void rpc_destroy_wait_queue(struct rpc_wait_queue *queue) +{ + cancel_delayed_work_sync(&queue->timer_list.dwork); +} +EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue); + +static int rpc_wait_bit_killable(struct wait_bit_key *key, int mode) +{ + schedule(); + if (signal_pending_state(mode, current)) + return -ERESTARTSYS; + return 0; +} + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) || IS_ENABLED(CONFIG_TRACEPOINTS) +static void rpc_task_set_debuginfo(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + /* Might be a task carrying a reverse-direction operation */ + if (!clnt) { + static atomic_t rpc_pid; + + task->tk_pid = atomic_inc_return(&rpc_pid); + return; + } + + task->tk_pid = atomic_inc_return(&clnt->cl_pid); +} +#else +static inline void rpc_task_set_debuginfo(struct rpc_task *task) +{ +} +#endif + +static void rpc_set_active(struct rpc_task *task) +{ + rpc_task_set_debuginfo(task); + set_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + trace_rpc_task_begin(task, NULL); +} + +/* + * Mark an RPC call as having completed by clearing the 'active' bit + * and then waking up all tasks that were sleeping. + */ +static int rpc_complete_task(struct rpc_task *task) +{ + void *m = &task->tk_runstate; + wait_queue_head_t *wq = bit_waitqueue(m, RPC_TASK_ACTIVE); + struct wait_bit_key k = __WAIT_BIT_KEY_INITIALIZER(m, RPC_TASK_ACTIVE); + unsigned long flags; + int ret; + + trace_rpc_task_complete(task, NULL); + + spin_lock_irqsave(&wq->lock, flags); + clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + ret = atomic_dec_and_test(&task->tk_count); + if (waitqueue_active(wq)) + __wake_up_locked_key(wq, TASK_NORMAL, &k); + spin_unlock_irqrestore(&wq->lock, flags); + return ret; +} + +/* + * Allow callers to wait for completion of an RPC call + * + * Note the use of out_of_line_wait_on_bit() rather than wait_on_bit() + * to enforce taking of the wq->lock and hence avoid races with + * rpc_complete_task(). + */ +int rpc_wait_for_completion_task(struct rpc_task *task) +{ + return out_of_line_wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, + rpc_wait_bit_killable, TASK_KILLABLE|TASK_FREEZABLE_UNSAFE); +} +EXPORT_SYMBOL_GPL(rpc_wait_for_completion_task); + +/* + * Make an RPC task runnable. 
+ * + * Note: If the task is ASYNC, and is being made runnable after sitting on an + * rpc_wait_queue, this must be called with the queue spinlock held to protect + * the wait queue operation. + * Note the ordering of rpc_test_and_set_running() and rpc_clear_queued(), + * which is needed to ensure that __rpc_execute() doesn't loop (due to the + * lockless RPC_IS_QUEUED() test) before we've had a chance to test + * the RPC_TASK_RUNNING flag. + */ +static void rpc_make_runnable(struct workqueue_struct *wq, + struct rpc_task *task) +{ + bool need_wakeup = !rpc_test_and_set_running(task); + + rpc_clear_queued(task); + if (!need_wakeup) + return; + if (RPC_IS_ASYNC(task)) { + INIT_WORK(&task->u.tk_work, rpc_async_schedule); + queue_work(wq, &task->u.tk_work); + } else + wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); +} + +/* + * Prepare for sleeping on a wait queue. + * By always appending tasks to the list we ensure FIFO behavior. + * NB: An RPC task will only receive interrupt-driven events as long + * as it's on a wait queue. + */ +static void __rpc_do_sleep_on_priority(struct rpc_wait_queue *q, + struct rpc_task *task, + unsigned char queue_priority) +{ + trace_rpc_task_sleep(task, q); + + __rpc_add_wait_queue(q, task, queue_priority); +} + +static void __rpc_sleep_on_priority(struct rpc_wait_queue *q, + struct rpc_task *task, + unsigned char queue_priority) +{ + if (WARN_ON_ONCE(RPC_IS_QUEUED(task))) + return; + __rpc_do_sleep_on_priority(q, task, queue_priority); +} + +static void __rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q, + struct rpc_task *task, unsigned long timeout, + unsigned char queue_priority) +{ + if (WARN_ON_ONCE(RPC_IS_QUEUED(task))) + return; + if (time_is_after_jiffies(timeout)) { + __rpc_do_sleep_on_priority(q, task, queue_priority); + __rpc_add_timer(q, task, timeout); + } else + task->tk_status = -ETIMEDOUT; +} + +static void rpc_set_tk_callback(struct rpc_task *task, rpc_action action) +{ + if (action && !WARN_ON_ONCE(task->tk_callback != NULL)) + task->tk_callback = action; +} + +static bool rpc_sleep_check_activated(struct rpc_task *task) +{ + /* We shouldn't ever put an inactive task to sleep */ + if (WARN_ON_ONCE(!RPC_IS_ACTIVATED(task))) { + task->tk_status = -EIO; + rpc_put_task_async(task); + return false; + } + return true; +} + +void rpc_sleep_on_timeout(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action, unsigned long timeout) +{ + if (!rpc_sleep_check_activated(task)) + return; + + rpc_set_tk_callback(task, action); + + /* + * Protect the queue operations. + */ + spin_lock(&q->lock); + __rpc_sleep_on_priority_timeout(q, task, timeout, task->tk_priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_timeout); + +void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action) +{ + if (!rpc_sleep_check_activated(task)) + return; + + rpc_set_tk_callback(task, action); + + WARN_ON_ONCE(task->tk_timeout != 0); + /* + * Protect the queue operations. + */ + spin_lock(&q->lock); + __rpc_sleep_on_priority(q, task, task->tk_priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on); + +void rpc_sleep_on_priority_timeout(struct rpc_wait_queue *q, + struct rpc_task *task, unsigned long timeout, int priority) +{ + if (!rpc_sleep_check_activated(task)) + return; + + priority -= RPC_PRIORITY_LOW; + /* + * Protect the queue operations. 
+ */ + spin_lock(&q->lock); + __rpc_sleep_on_priority_timeout(q, task, timeout, priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_priority_timeout); + +void rpc_sleep_on_priority(struct rpc_wait_queue *q, struct rpc_task *task, + int priority) +{ + if (!rpc_sleep_check_activated(task)) + return; + + WARN_ON_ONCE(task->tk_timeout != 0); + priority -= RPC_PRIORITY_LOW; + /* + * Protect the queue operations. + */ + spin_lock(&q->lock); + __rpc_sleep_on_priority(q, task, priority); + spin_unlock(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on_priority); + +/** + * __rpc_do_wake_up_task_on_wq - wake up a single rpc_task + * @wq: workqueue on which to run task + * @queue: wait queue + * @task: task to be woken up + * + * Caller must hold queue->lock, and have cleared the task queued flag. + */ +static void __rpc_do_wake_up_task_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + struct rpc_task *task) +{ + /* Has the task been executed yet? If not, we cannot wake it up! */ + if (!RPC_IS_ACTIVATED(task)) { + printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task); + return; + } + + trace_rpc_task_wakeup(task, queue); + + __rpc_remove_wait_queue(queue, task); + + rpc_make_runnable(wq, task); +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static struct rpc_task * +rpc_wake_up_task_on_wq_queue_action_locked(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, struct rpc_task *task, + bool (*action)(struct rpc_task *, void *), void *data) +{ + if (RPC_IS_QUEUED(task)) { + smp_rmb(); + if (task->tk_waitqueue == queue) { + if (action == NULL || action(task, data)) { + __rpc_do_wake_up_task_on_wq(wq, queue, task); + return task; + } + } + } + return NULL; +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, + struct rpc_task *task) +{ + rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue, + task, NULL, NULL); +} + +/* + * Wake up a task on a specific queue + */ +void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (!RPC_IS_QUEUED(task)) + return; + spin_lock(&queue->lock); + rpc_wake_up_task_queue_locked(queue, task); + spin_unlock(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); + +static bool rpc_task_action_set_status(struct rpc_task *task, void *status) +{ + task->tk_status = *(int *)status; + return true; +} + +static void +rpc_wake_up_task_queue_set_status_locked(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + rpc_wake_up_task_on_wq_queue_action_locked(rpciod_workqueue, queue, + task, rpc_task_action_set_status, &status); +} + +/** + * rpc_wake_up_queued_task_set_status - wake up a task and set task->tk_status + * @queue: pointer to rpc_wait_queue + * @task: pointer to rpc_task + * @status: integer error value + * + * If @task is queued on @queue, then it is woken up, and @task->tk_status is + * set to the value of @status. + */ +void +rpc_wake_up_queued_task_set_status(struct rpc_wait_queue *queue, + struct rpc_task *task, int status) +{ + if (!RPC_IS_QUEUED(task)) + return; + spin_lock(&queue->lock); + rpc_wake_up_task_queue_set_status_locked(queue, task, status); + spin_unlock(&queue->lock); +} + +/* + * Wake up the next task on a priority queue. + */ +static struct rpc_task *__rpc_find_next_queued_priority(struct rpc_wait_queue *queue) +{ + struct list_head *q; + struct rpc_task *task; + + /* + * Service the privileged queue. 
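+	 * Tasks queued at the privileged level are served first when present;
+	 * otherwise up to queue->nr tasks are taken from the current priority
+	 * level before the search rotates to the next one.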
+ */ + q = &queue->tasks[RPC_NR_PRIORITY - 1]; + if (queue->maxpriority > RPC_PRIORITY_PRIVILEGED && !list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; + } + + /* + * Service a batch of tasks from a single owner. + */ + q = &queue->tasks[queue->priority]; + if (!list_empty(q) && queue->nr) { + queue->nr--; + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto out; + } + + /* + * Service the next queue. + */ + do { + if (q == &queue->tasks[0]) + q = &queue->tasks[queue->maxpriority]; + else + q = q - 1; + if (!list_empty(q)) { + task = list_first_entry(q, struct rpc_task, u.tk_wait.list); + goto new_queue; + } + } while (q != &queue->tasks[queue->priority]); + + rpc_reset_waitqueue_priority(queue); + return NULL; + +new_queue: + rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0])); +out: + return task; +} + +static struct rpc_task *__rpc_find_next_queued(struct rpc_wait_queue *queue) +{ + if (RPC_IS_PRIORITY(queue)) + return __rpc_find_next_queued_priority(queue); + if (!list_empty(&queue->tasks[0])) + return list_first_entry(&queue->tasks[0], struct rpc_task, u.tk_wait.list); + return NULL; +} + +/* + * Wake up the first task on the wait queue. + */ +struct rpc_task *rpc_wake_up_first_on_wq(struct workqueue_struct *wq, + struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) +{ + struct rpc_task *task = NULL; + + spin_lock(&queue->lock); + task = __rpc_find_next_queued(queue); + if (task != NULL) + task = rpc_wake_up_task_on_wq_queue_action_locked(wq, queue, + task, func, data); + spin_unlock(&queue->lock); + + return task; +} + +/* + * Wake up the first task on the wait queue. + */ +struct rpc_task *rpc_wake_up_first(struct rpc_wait_queue *queue, + bool (*func)(struct rpc_task *, void *), void *data) +{ + return rpc_wake_up_first_on_wq(rpciod_workqueue, queue, func, data); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_first); + +static bool rpc_wake_up_next_func(struct rpc_task *task, void *data) +{ + return true; +} + +/* + * Wake up the next task on the wait queue. +*/ +struct rpc_task *rpc_wake_up_next(struct rpc_wait_queue *queue) +{ + return rpc_wake_up_first(queue, rpc_wake_up_next_func, NULL); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_next); + +/** + * rpc_wake_up_locked - wake up all rpc_tasks + * @queue: rpc_wait_queue on which the tasks are sleeping + * + */ +static void rpc_wake_up_locked(struct rpc_wait_queue *queue) +{ + struct rpc_task *task; + + for (;;) { + task = __rpc_find_next_queued(queue); + if (task == NULL) + break; + rpc_wake_up_task_queue_locked(queue, task); + } +} + +/** + * rpc_wake_up - wake up all rpc_tasks + * @queue: rpc_wait_queue on which the tasks are sleeping + * + * Grabs queue->lock + */ +void rpc_wake_up(struct rpc_wait_queue *queue) +{ + spin_lock(&queue->lock); + rpc_wake_up_locked(queue); + spin_unlock(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up); + +/** + * rpc_wake_up_status_locked - wake up all rpc_tasks and set their status value. + * @queue: rpc_wait_queue on which the tasks are sleeping + * @status: status value to set + */ +static void rpc_wake_up_status_locked(struct rpc_wait_queue *queue, int status) +{ + struct rpc_task *task; + + for (;;) { + task = __rpc_find_next_queued(queue); + if (task == NULL) + break; + rpc_wake_up_task_queue_set_status_locked(queue, task, status); + } +} + +/** + * rpc_wake_up_status - wake up all rpc_tasks and set their status value. 
+ * @queue: rpc_wait_queue on which the tasks are sleeping + * @status: status value to set + * + * Grabs queue->lock + */ +void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) +{ + spin_lock(&queue->lock); + rpc_wake_up_status_locked(queue, status); + spin_unlock(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_status); + +static void __rpc_queue_timer_fn(struct work_struct *work) +{ + struct rpc_wait_queue *queue = container_of(work, + struct rpc_wait_queue, + timer_list.dwork.work); + struct rpc_task *task, *n; + unsigned long expires, now, timeo; + + spin_lock(&queue->lock); + expires = now = jiffies; + list_for_each_entry_safe(task, n, &queue->timer_list.list, u.tk_wait.timer_list) { + timeo = task->tk_timeout; + if (time_after_eq(now, timeo)) { + trace_rpc_task_timeout(task, task->tk_action); + task->tk_status = -ETIMEDOUT; + rpc_wake_up_task_queue_locked(queue, task); + continue; + } + if (expires == now || time_after(expires, timeo)) + expires = timeo; + } + if (!list_empty(&queue->timer_list.list)) + rpc_set_queue_timer(queue, expires); + spin_unlock(&queue->lock); +} + +static void __rpc_atrun(struct rpc_task *task) +{ + if (task->tk_status == -ETIMEDOUT) + task->tk_status = 0; +} + +/* + * Run a task at a later time + */ +void rpc_delay(struct rpc_task *task, unsigned long delay) +{ + rpc_sleep_on_timeout(&delay_queue, task, __rpc_atrun, jiffies + delay); +} +EXPORT_SYMBOL_GPL(rpc_delay); + +/* + * Helper to call task->tk_ops->rpc_call_prepare + */ +void rpc_prepare_task(struct rpc_task *task) +{ + task->tk_ops->rpc_call_prepare(task, task->tk_calldata); +} + +static void +rpc_init_task_statistics(struct rpc_task *task) +{ + /* Initialize retry counters */ + task->tk_garb_retry = 2; + task->tk_cred_retry = 2; + + /* starting timestamp */ + task->tk_start = ktime_get(); +} + +static void +rpc_reset_task_statistics(struct rpc_task *task) +{ + task->tk_timeouts = 0; + task->tk_flags &= ~(RPC_CALL_MAJORSEEN|RPC_TASK_SENT); + rpc_init_task_statistics(task); +} + +/* + * Helper that calls task->tk_ops->rpc_call_done if it exists + */ +void rpc_exit_task(struct rpc_task *task) +{ + trace_rpc_task_end(task, task->tk_action); + task->tk_action = NULL; + if (task->tk_ops->rpc_count_stats) + task->tk_ops->rpc_count_stats(task, task->tk_calldata); + else if (task->tk_client) + rpc_count_iostats(task, task->tk_client->cl_metrics); + if (task->tk_ops->rpc_call_done != NULL) { + trace_rpc_task_call_done(task, task->tk_ops->rpc_call_done); + task->tk_ops->rpc_call_done(task, task->tk_calldata); + if (task->tk_action != NULL) { + /* Always release the RPC slot and buffer memory */ + xprt_release(task); + rpc_reset_task_statistics(task); + } + } +} + +void rpc_signal_task(struct rpc_task *task) +{ + struct rpc_wait_queue *queue; + + if (!RPC_IS_ACTIVATED(task)) + return; + + if (!rpc_task_set_rpc_status(task, -ERESTARTSYS)) + return; + trace_rpc_task_signalled(task, task->tk_action); + set_bit(RPC_TASK_SIGNALLED, &task->tk_runstate); + smp_mb__after_atomic(); + queue = READ_ONCE(task->tk_waitqueue); + if (queue) + rpc_wake_up_queued_task(queue, task); +} + +void rpc_task_try_cancel(struct rpc_task *task, int error) +{ + struct rpc_wait_queue *queue; + + if (!rpc_task_set_rpc_status(task, error)) + return; + queue = READ_ONCE(task->tk_waitqueue); + if (queue) + rpc_wake_up_queued_task(queue, task); +} + +void rpc_exit(struct rpc_task *task, int status) +{ + task->tk_status = status; + task->tk_action = rpc_exit_task; + rpc_wake_up_queued_task(task->tk_waitqueue, task); +} 
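+/*
+ * Usage sketch (illustrative only; the names below are hypothetical and do
+ * not refer to an upstream caller): an rpc_call_prepare callback can force
+ * early completion by routing the state machine to rpc_exit_task via
+ * rpc_exit():
+ *
+ *	static void foo_rpc_prepare(struct rpc_task *task, void *calldata)
+ *	{
+ *		struct foo_call *call = calldata;	// hypothetical per-call state
+ *
+ *		if (!call->still_wanted)		// hypothetical cancellation check
+ *			rpc_exit(task, -ESTALE);	// next FSM step runs rpc_exit_task
+ *	}
+ */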
+EXPORT_SYMBOL_GPL(rpc_exit); + +void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata) +{ + if (ops->rpc_release != NULL) + ops->rpc_release(calldata); +} + +static bool xprt_needs_memalloc(struct rpc_xprt *xprt, struct rpc_task *tk) +{ + if (!xprt) + return false; + if (!atomic_read(&xprt->swapper)) + return false; + return test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == tk; +} + +/* + * This is the RPC `scheduler' (or rather, the finite state machine). + */ +static void __rpc_execute(struct rpc_task *task) +{ + struct rpc_wait_queue *queue; + int task_is_async = RPC_IS_ASYNC(task); + int status = 0; + unsigned long pflags = current->flags; + + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + if (RPC_IS_QUEUED(task)) + return; + + for (;;) { + void (*do_action)(struct rpc_task *); + + /* + * Perform the next FSM step or a pending callback. + * + * tk_action may be NULL if the task has been killed. + */ + do_action = task->tk_action; + /* Tasks with an RPC error status should exit */ + if (do_action && do_action != rpc_exit_task && + (status = READ_ONCE(task->tk_rpc_status)) != 0) { + task->tk_status = status; + do_action = rpc_exit_task; + } + /* Callbacks override all actions */ + if (task->tk_callback) { + do_action = task->tk_callback; + task->tk_callback = NULL; + } + if (!do_action) + break; + if (RPC_IS_SWAPPER(task) || + xprt_needs_memalloc(task->tk_xprt, task)) + current->flags |= PF_MEMALLOC; + + trace_rpc_task_run_action(task, do_action); + do_action(task); + + /* + * Lockless check for whether task is sleeping or not. + */ + if (!RPC_IS_QUEUED(task)) { + cond_resched(); + continue; + } + + /* + * The queue->lock protects against races with + * rpc_make_runnable(). + * + * Note that once we clear RPC_TASK_RUNNING on an asynchronous + * rpc_task, rpc_make_runnable() can assign it to a + * different workqueue. We therefore cannot assume that the + * rpc_task pointer may still be dereferenced. + */ + queue = task->tk_waitqueue; + spin_lock(&queue->lock); + if (!RPC_IS_QUEUED(task)) { + spin_unlock(&queue->lock); + continue; + } + /* Wake up any task that has an exit status */ + if (READ_ONCE(task->tk_rpc_status) != 0) { + rpc_wake_up_task_queue_locked(queue, task); + spin_unlock(&queue->lock); + continue; + } + rpc_clear_running(task); + spin_unlock(&queue->lock); + if (task_is_async) + goto out; + + /* sync task: sleep here */ + trace_rpc_task_sync_sleep(task, task->tk_action); + status = out_of_line_wait_on_bit(&task->tk_runstate, + RPC_TASK_QUEUED, rpc_wait_bit_killable, + TASK_KILLABLE|TASK_FREEZABLE); + if (status < 0) { + /* + * When a sync task receives a signal, it exits with + * -ERESTARTSYS. In order to catch any callbacks that + * clean up after sleeping on some queue, we don't + * break the loop here, but go around once more. + */ + rpc_signal_task(task); + } + trace_rpc_task_sync_wake(task, task->tk_action); + } + + /* Release all resources associated with the task */ + rpc_release_task(task); +out: + current_restore_flags(pflags, PF_MEMALLOC); +} + +/* + * User-visible entry point to the scheduler. + * + * This may be called recursively if e.g. an async NFS task updates + * the attributes and finds that dirty pages must be flushed. + * NOTE: Upon exit of this function the task is guaranteed to be + * released. In particular note that tk_release() will have + * been called, so your task memory may have been freed. 
+ */ +void rpc_execute(struct rpc_task *task) +{ + bool is_async = RPC_IS_ASYNC(task); + + rpc_set_active(task); + rpc_make_runnable(rpciod_workqueue, task); + if (!is_async) { + unsigned int pflags = memalloc_nofs_save(); + __rpc_execute(task); + memalloc_nofs_restore(pflags); + } +} + +static void rpc_async_schedule(struct work_struct *work) +{ + unsigned int pflags = memalloc_nofs_save(); + + __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); + memalloc_nofs_restore(pflags); +} + +/** + * rpc_malloc - allocate RPC buffer resources + * @task: RPC task + * + * A single memory region is allocated, which is split between the + * RPC call and RPC reply that this task is being used for. When + * this RPC is retired, the memory is released by calling rpc_free. + * + * To prevent rpciod from hanging, this allocator never sleeps, + * returning -ENOMEM and suppressing warning if the request cannot + * be serviced immediately. The caller can arrange to sleep in a + * way that is safe for rpciod. + * + * Most requests are 'small' (under 2KiB) and can be serviced from a + * mempool, ensuring that NFS reads and writes can always proceed, + * and that there is good locality of reference for these buffers. + */ +int rpc_malloc(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize + rqst->rq_rcvsize; + struct rpc_buffer *buf; + gfp_t gfp = rpc_task_gfp_mask(); + + size += sizeof(struct rpc_buffer); + if (size <= RPC_BUFFER_MAXSIZE) { + buf = kmem_cache_alloc(rpc_buffer_slabp, gfp); + /* Reach for the mempool if dynamic allocation fails */ + if (!buf && RPC_IS_ASYNC(task)) + buf = mempool_alloc(rpc_buffer_mempool, GFP_NOWAIT); + } else + buf = kmalloc(size, gfp); + + if (!buf) + return -ENOMEM; + + buf->len = size; + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; +} +EXPORT_SYMBOL_GPL(rpc_malloc); + +/** + * rpc_free - free RPC buffer resources allocated via rpc_malloc + * @task: RPC task + * + */ +void rpc_free(struct rpc_task *task) +{ + void *buffer = task->tk_rqstp->rq_buffer; + size_t size; + struct rpc_buffer *buf; + + buf = container_of(buffer, struct rpc_buffer, data); + size = buf->len; + + if (size <= RPC_BUFFER_MAXSIZE) + mempool_free(buf, rpc_buffer_mempool); + else + kfree(buf); +} +EXPORT_SYMBOL_GPL(rpc_free); + +/* + * Creation and deletion of RPC task structures + */ +static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *task_setup_data) +{ + memset(task, 0, sizeof(*task)); + atomic_set(&task->tk_count, 1); + task->tk_flags = task_setup_data->flags; + task->tk_ops = task_setup_data->callback_ops; + task->tk_calldata = task_setup_data->callback_data; + INIT_LIST_HEAD(&task->tk_task); + + task->tk_priority = task_setup_data->priority - RPC_PRIORITY_LOW; + task->tk_owner = current->tgid; + + /* Initialize workqueue for async tasks */ + task->tk_workqueue = task_setup_data->workqueue; + + task->tk_xprt = rpc_task_get_xprt(task_setup_data->rpc_client, + xprt_get(task_setup_data->rpc_xprt)); + + task->tk_op_cred = get_rpccred(task_setup_data->rpc_op_cred); + + if (task->tk_ops->rpc_call_prepare != NULL) + task->tk_action = rpc_prepare_task; + + rpc_init_task_statistics(task); +} + +static struct rpc_task *rpc_alloc_task(void) +{ + struct rpc_task *task; + + task = kmem_cache_alloc(rpc_task_slabp, rpc_task_gfp_mask()); + if (task) + return task; + return mempool_alloc(rpc_task_mempool, GFP_NOWAIT); +} + +/* + * Create a new task for the specified client. 
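+ * If setup_data->task is NULL, a task is allocated here and flagged
+ * RPC_TASK_DYNAMIC so that rpc_free_task() later returns it to the task
+ * mempool; caller-supplied, embedded tasks are never freed by that path.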
+ */ +struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) +{ + struct rpc_task *task = setup_data->task; + unsigned short flags = 0; + + if (task == NULL) { + task = rpc_alloc_task(); + if (task == NULL) { + rpc_release_calldata(setup_data->callback_ops, + setup_data->callback_data); + return ERR_PTR(-ENOMEM); + } + flags = RPC_TASK_DYNAMIC; + } + + rpc_init_task(task, setup_data); + task->tk_flags |= flags; + return task; +} + +/* + * rpc_free_task - release rpc task and perform cleanups + * + * Note that we free up the rpc_task _after_ rpc_release_calldata() + * in order to work around a workqueue dependency issue. + * + * Tejun Heo states: + * "Workqueue currently considers two work items to be the same if they're + * on the same address and won't execute them concurrently - ie. it + * makes a work item which is queued again while being executed wait + * for the previous execution to complete. + * + * If a work function frees the work item, and then waits for an event + * which should be performed by another work item and *that* work item + * recycles the freed work item, it can create a false dependency loop. + * There really is no reliable way to detect this short of verifying + * every memory free." + * + */ +static void rpc_free_task(struct rpc_task *task) +{ + unsigned short tk_flags = task->tk_flags; + + put_rpccred(task->tk_op_cred); + rpc_release_calldata(task->tk_ops, task->tk_calldata); + + if (tk_flags & RPC_TASK_DYNAMIC) + mempool_free(task, rpc_task_mempool); +} + +static void rpc_async_release(struct work_struct *work) +{ + unsigned int pflags = memalloc_nofs_save(); + + rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); + memalloc_nofs_restore(pflags); +} + +static void rpc_release_resources_task(struct rpc_task *task) +{ + xprt_release(task); + if (task->tk_msg.rpc_cred) { + if (!(task->tk_flags & RPC_TASK_CRED_NOREF)) + put_cred(task->tk_msg.rpc_cred); + task->tk_msg.rpc_cred = NULL; + } + rpc_task_release_client(task); +} + +static void rpc_final_put_task(struct rpc_task *task, + struct workqueue_struct *q) +{ + if (q != NULL) { + INIT_WORK(&task->u.tk_work, rpc_async_release); + queue_work(q, &task->u.tk_work); + } else + rpc_free_task(task); +} + +static void rpc_do_put_task(struct rpc_task *task, struct workqueue_struct *q) +{ + if (atomic_dec_and_test(&task->tk_count)) { + rpc_release_resources_task(task); + rpc_final_put_task(task, q); + } +} + +void rpc_put_task(struct rpc_task *task) +{ + rpc_do_put_task(task, NULL); +} +EXPORT_SYMBOL_GPL(rpc_put_task); + +void rpc_put_task_async(struct rpc_task *task) +{ + rpc_do_put_task(task, task->tk_workqueue); +} +EXPORT_SYMBOL_GPL(rpc_put_task_async); + +static void rpc_release_task(struct rpc_task *task) +{ + WARN_ON_ONCE(RPC_IS_QUEUED(task)); + + rpc_release_resources_task(task); + + /* + * Note: at this point we have been removed from rpc_clnt->cl_tasks, + * so it should be safe to use task->tk_count as a test for whether + * or not any other processes still hold references to our rpc_task. + */ + if (atomic_read(&task->tk_count) != 1 + !RPC_IS_ASYNC(task)) { + /* Wake up anyone who may be waiting for task completion */ + if (!rpc_complete_task(task)) + return; + } else { + if (!atomic_dec_and_test(&task->tk_count)) + return; + } + rpc_final_put_task(task, task->tk_workqueue); +} + +int rpciod_up(void) +{ + return try_module_get(THIS_MODULE) ? 0 : -EINVAL; +} + +void rpciod_down(void) +{ + module_put(THIS_MODULE); +} + +/* + * Start up the rpciod workqueue. 
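+ * Two unbound workqueues are created: "rpciod" runs asynchronous RPC
+ * tasks and "xprtiod" carries transport work; both set WQ_MEM_RECLAIM so
+ * a rescuer thread keeps them making progress under memory pressure.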
+ */ +static int rpciod_start(void) +{ + struct workqueue_struct *wq; + + /* + * Create the rpciod thread and wait for it to start. + */ + wq = alloc_workqueue("rpciod", WQ_MEM_RECLAIM | WQ_UNBOUND, 0); + if (!wq) + goto out_failed; + rpciod_workqueue = wq; + wq = alloc_workqueue("xprtiod", WQ_UNBOUND | WQ_MEM_RECLAIM, 0); + if (!wq) + goto free_rpciod; + xprtiod_workqueue = wq; + return 1; +free_rpciod: + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); +out_failed: + return 0; +} + +static void rpciod_stop(void) +{ + struct workqueue_struct *wq = NULL; + + if (rpciod_workqueue == NULL) + return; + + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); + wq = xprtiod_workqueue; + xprtiod_workqueue = NULL; + destroy_workqueue(wq); +} + +void +rpc_destroy_mempool(void) +{ + rpciod_stop(); + mempool_destroy(rpc_buffer_mempool); + mempool_destroy(rpc_task_mempool); + kmem_cache_destroy(rpc_task_slabp); + kmem_cache_destroy(rpc_buffer_slabp); + rpc_destroy_wait_queue(&delay_queue); +} + +int +rpc_init_mempool(void) +{ + /* + * The following is not strictly a mempool initialisation, + * but there is no harm in doing it here + */ + rpc_init_wait_queue(&delay_queue, "delayq"); + if (!rpciod_start()) + goto err_nomem; + + rpc_task_slabp = kmem_cache_create("rpc_tasks", + sizeof(struct rpc_task), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_task_slabp) + goto err_nomem; + rpc_buffer_slabp = kmem_cache_create("rpc_buffers", + RPC_BUFFER_MAXSIZE, + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_buffer_slabp) + goto err_nomem; + rpc_task_mempool = mempool_create_slab_pool(RPC_TASK_POOLSIZE, + rpc_task_slabp); + if (!rpc_task_mempool) + goto err_nomem; + rpc_buffer_mempool = mempool_create_slab_pool(RPC_BUFFER_POOLSIZE, + rpc_buffer_slabp); + if (!rpc_buffer_mempool) + goto err_nomem; + return 0; +err_nomem: + rpc_destroy_mempool(); + return -ENOMEM; +} diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c new file mode 100644 index 000000000..1b2b84fee --- /dev/null +++ b/net/sunrpc/socklib.c @@ -0,0 +1,324 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/socklib.c + * + * Common socket helper routines for RPC client and server + * + * Copyright (C) 1995, 1996 Olaf Kirch + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "socklib.h" + +/* + * Helper structure for copying from an sk_buff. + */ +struct xdr_skb_reader { + struct sk_buff *skb; + unsigned int offset; + size_t count; + __wsum csum; +}; + +typedef size_t (*xdr_skb_read_actor)(struct xdr_skb_reader *desc, void *to, + size_t len); + +/** + * xdr_skb_read_bits - copy some data bits from skb to internal buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Possibly called several times to iterate over an sk_buff and copy + * data out of it. + */ +static size_t +xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + if (len > desc->count) + len = desc->count; + if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len))) + return 0; + desc->count -= len; + desc->offset += len; + return len; +} + +/** + * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Same as skb_read_bits, but calculate a checksum at the same time. 
+ */ +static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + unsigned int pos; + __wsum csum2; + + if (len > desc->count) + len = desc->count; + pos = desc->offset; + csum2 = skb_copy_and_csum_bits(desc->skb, pos, to, len); + desc->csum = csum_block_add(desc->csum, csum2, pos); + desc->count -= len; + desc->offset += len; + return len; +} + +/** + * xdr_partial_copy_from_skb - copy data out of an skb + * @xdr: target XDR buffer + * @base: starting offset + * @desc: sk_buff copy helper + * @copy_actor: virtual method for copying data + * + */ +static ssize_t +xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) +{ + struct page **ppage = xdr->pages; + unsigned int len, pglen = xdr->page_len; + ssize_t copied = 0; + size_t ret; + + len = xdr->head[0].iov_len; + if (base < len) { + len -= base; + ret = copy_actor(desc, (char *)xdr->head[0].iov_base + base, len); + copied += ret; + if (ret != len || !desc->count) + goto out; + base = 0; + } else + base -= len; + + if (unlikely(pglen == 0)) + goto copy_tail; + if (unlikely(base >= pglen)) { + base -= pglen; + goto copy_tail; + } + if (base || xdr->page_base) { + pglen -= base; + base += xdr->page_base; + ppage += base >> PAGE_SHIFT; + base &= ~PAGE_MASK; + } + do { + char *kaddr; + + /* ACL likes to be lazy in allocating pages - ACLs + * are small by default but can get huge. */ + if ((xdr->flags & XDRBUF_SPARSE_PAGES) && *ppage == NULL) { + *ppage = alloc_page(GFP_NOWAIT | __GFP_NOWARN); + if (unlikely(*ppage == NULL)) { + if (copied == 0) + copied = -ENOMEM; + goto out; + } + } + + len = PAGE_SIZE; + kaddr = kmap_atomic(*ppage); + if (base) { + len -= base; + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr + base, len); + base = 0; + } else { + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr, len); + } + flush_dcache_page(*ppage); + kunmap_atomic(kaddr); + copied += ret; + if (ret != len || !desc->count) + goto out; + ppage++; + } while ((pglen -= len) != 0); +copy_tail: + len = xdr->tail[0].iov_len; + if (base < len) + copied += copy_actor(desc, (char *)xdr->tail[0].iov_base + base, len - base); +out: + return copied; +} + +/** + * csum_partial_copy_to_xdr - checksum and copy data + * @xdr: target XDR buffer + * @skb: source skb + * + * We have set things up such that we perform the checksum of the UDP + * packet in parallel with the copies into the RPC client iovec. 
-DaveM + */ +int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) +{ + struct xdr_skb_reader desc; + + desc.skb = skb; + desc.offset = 0; + desc.count = skb->len - desc.offset; + + if (skb_csum_unnecessary(skb)) + goto no_checksum; + + desc.csum = csum_partial(skb->data, desc.offset, skb->csum); + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_and_csum_bits) < 0) + return -1; + if (desc.offset != skb->len) { + __wsum csum2; + csum2 = skb_checksum(skb, desc.offset, skb->len - desc.offset, 0); + desc.csum = csum_block_add(desc.csum, csum2, desc.offset); + } + if (desc.count) + return -1; + if (csum_fold(desc.csum)) + return -1; + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE) && + !skb->csum_complete_sw) + netdev_rx_csum_fault(skb->dev, skb); + return 0; +no_checksum: + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) + return -1; + if (desc.count) + return -1; + return 0; +} +EXPORT_SYMBOL_GPL(csum_partial_copy_to_xdr); + +static inline int xprt_sendmsg(struct socket *sock, struct msghdr *msg, + size_t seek) +{ + if (seek) + iov_iter_advance(&msg->msg_iter, seek); + return sock_sendmsg(sock, msg); +} + +static int xprt_send_kvec(struct socket *sock, struct msghdr *msg, + struct kvec *vec, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, vec, 1, vec->iov_len); + return xprt_sendmsg(sock, msg, seek); +} + +static int xprt_send_pagedata(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, size_t base) +{ + iov_iter_bvec(&msg->msg_iter, ITER_SOURCE, xdr->bvec, xdr_buf_pagecount(xdr), + xdr->page_len + xdr->page_base); + return xprt_sendmsg(sock, msg, base + xdr->page_base); +} + +/* Common case: + * - stream transport + * - sending from byte 0 of the message + * - the message is wholly contained in @xdr's head iovec + */ +static int xprt_send_rm_and_kvec(struct socket *sock, struct msghdr *msg, + rpc_fraghdr marker, struct kvec *vec, + size_t base) +{ + struct kvec iov[2] = { + [0] = { + .iov_base = &marker, + .iov_len = sizeof(marker) + }, + [1] = *vec, + }; + size_t len = iov[0].iov_len + iov[1].iov_len; + + iov_iter_kvec(&msg->msg_iter, ITER_SOURCE, iov, 2, len); + return xprt_sendmsg(sock, msg, base); +} + +/** + * xprt_sock_sendmsg - write an xdr_buf directly to a socket + * @sock: open socket to send on + * @msg: socket message metadata + * @xdr: xdr_buf containing this request + * @base: starting position in the buffer + * @marker: stream record marker field + * @sent_p: return the total number of bytes successfully queued for sending + * + * Return values: + * On success, returns zero and fills in @sent_p. + * %-ENOTSOCK if @sock is not a struct socket. + */ +int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, unsigned int base, + rpc_fraghdr marker, unsigned int *sent_p) +{ + unsigned int rmsize = marker ? 
sizeof(marker) : 0; + unsigned int remainder = rmsize + xdr->len - base; + unsigned int want; + int err = 0; + + *sent_p = 0; + + if (unlikely(!sock)) + return -ENOTSOCK; + + msg->msg_flags |= MSG_MORE; + want = xdr->head[0].iov_len + rmsize; + if (base < want) { + unsigned int len = want - base; + + remainder -= len; + if (remainder == 0) + msg->msg_flags &= ~MSG_MORE; + if (rmsize) + err = xprt_send_rm_and_kvec(sock, msg, marker, + &xdr->head[0], base); + else + err = xprt_send_kvec(sock, msg, &xdr->head[0], base); + if (remainder == 0 || err != len) + goto out; + *sent_p += err; + base = 0; + } else { + base -= want; + } + + if (base < xdr->page_len) { + unsigned int len = xdr->page_len - base; + + remainder -= len; + if (remainder == 0) + msg->msg_flags &= ~MSG_MORE; + err = xprt_send_pagedata(sock, msg, xdr, base); + if (remainder == 0 || err != len) + goto out; + *sent_p += err; + base = 0; + } else { + base -= xdr->page_len; + } + + if (base >= xdr->tail[0].iov_len) + return 0; + msg->msg_flags &= ~MSG_MORE; + err = xprt_send_kvec(sock, msg, &xdr->tail[0], base); +out: + if (err > 0) { + *sent_p += err; + err = 0; + } + return err; +} diff --git a/net/sunrpc/socklib.h b/net/sunrpc/socklib.h new file mode 100644 index 000000000..c48114ad6 --- /dev/null +++ b/net/sunrpc/socklib.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (C) 1995-1997 Olaf Kirch + * Copyright (C) 2020, Oracle. + */ + +#ifndef _NET_SUNRPC_SOCKLIB_H_ +#define _NET_SUNRPC_SOCKLIB_H_ + +int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb); +int xprt_sock_sendmsg(struct socket *sock, struct msghdr *msg, + struct xdr_buf *xdr, unsigned int base, + rpc_fraghdr marker, unsigned int *sent_p); + +#endif /* _NET_SUNRPC_SOCKLIB_H_ */ diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c new file mode 100644 index 000000000..52908f9e6 --- /dev/null +++ b/net/sunrpc/stats.c @@ -0,0 +1,343 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/stats.c + * + * procfs-based user access to generic RPC statistics. The stats files + * reside in /proc/net/rpc. + * + * The read routines assume that the buffer passed in is just big enough. + * If you implement an RPC service that has its own stats routine which + * appends the generic RPC stats, make sure you don't exceed the PAGE_SIZE + * limit. 
+ * + * Copyright (C) 1995, 1996, 1997 Olaf Kirch + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "netns.h" + +#define RPCDBG_FACILITY RPCDBG_MISC + +/* + * Get RPC client stats + */ +static int rpc_proc_show(struct seq_file *seq, void *v) { + const struct rpc_stat *statp = seq->private; + const struct rpc_program *prog = statp->program; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u\n", + statp->rpccnt, + statp->rpcretrans, + statp->rpcauthrefresh); + + for (i = 0; i < prog->nrvers; i++) { + const struct rpc_version *vers = prog->version[i]; + if (!vers) + continue; + seq_printf(seq, "proc%u %u", + vers->number, vers->nrprocs); + for (j = 0; j < vers->nrprocs; j++) + seq_printf(seq, " %u", vers->counts[j]); + seq_putc(seq, '\n'); + } + return 0; +} + +static int rpc_proc_open(struct inode *inode, struct file *file) +{ + return single_open(file, rpc_proc_show, pde_data(inode)); +} + +static const struct proc_ops rpc_proc_ops = { + .proc_open = rpc_proc_open, + .proc_read = seq_read, + .proc_lseek = seq_lseek, + .proc_release = single_release, +}; + +/* + * Get RPC server stats + */ +void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) +{ + const struct svc_program *prog = statp->program; + const struct svc_version *vers; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u %u %u\n", + statp->rpccnt, + statp->rpcbadfmt+statp->rpcbadauth+statp->rpcbadclnt, + statp->rpcbadfmt, + statp->rpcbadauth, + statp->rpcbadclnt); + + for (i = 0; i < prog->pg_nvers; i++) { + vers = prog->pg_vers[i]; + if (!vers) + continue; + seq_printf(seq, "proc%d %u", i, vers->vs_nproc); + for (j = 0; j < vers->vs_nproc; j++) + seq_printf(seq, " %u", vers->vs_count[j]); + seq_putc(seq, '\n'); + } +} +EXPORT_SYMBOL_GPL(svc_seq_show); + +/** + * rpc_alloc_iostats - allocate an rpc_iostats structure + * @clnt: RPC program, version, and xprt + * + */ +struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt) +{ + struct rpc_iostats *stats; + int i; + + stats = kcalloc(clnt->cl_maxproc, sizeof(*stats), GFP_KERNEL); + if (stats) { + for (i = 0; i < clnt->cl_maxproc; i++) + spin_lock_init(&stats[i].om_lock); + } + return stats; +} +EXPORT_SYMBOL_GPL(rpc_alloc_iostats); + +/** + * rpc_free_iostats - release an rpc_iostats structure + * @stats: doomed rpc_iostats structure + * + */ +void rpc_free_iostats(struct rpc_iostats *stats) +{ + kfree(stats); +} +EXPORT_SYMBOL_GPL(rpc_free_iostats); + +/** + * rpc_count_iostats_metrics - tally up per-task stats + * @task: completed rpc_task + * @op_metrics: stat structure for OP that will accumulate stats from @task + */ +void rpc_count_iostats_metrics(const struct rpc_task *task, + struct rpc_iostats *op_metrics) +{ + struct rpc_rqst *req = task->tk_rqstp; + ktime_t backlog, execute, now; + + if (!op_metrics || !req) + return; + + now = ktime_get(); + spin_lock(&op_metrics->om_lock); + + op_metrics->om_ops++; + /* kernel API: om_ops must never become larger than om_ntrans */ + op_metrics->om_ntrans += max(req->rq_ntrans, 1); + op_metrics->om_timeouts += task->tk_timeouts; + + op_metrics->om_bytes_sent += req->rq_xmit_bytes_sent; + op_metrics->om_bytes_recv += req->rq_reply_bytes_recvd; + + backlog = 0; + if 
(ktime_to_ns(req->rq_xtime)) { + backlog = ktime_sub(req->rq_xtime, task->tk_start); + op_metrics->om_queue = ktime_add(op_metrics->om_queue, backlog); + } + + op_metrics->om_rtt = ktime_add(op_metrics->om_rtt, req->rq_rtt); + + execute = ktime_sub(now, task->tk_start); + op_metrics->om_execute = ktime_add(op_metrics->om_execute, execute); + if (task->tk_status < 0) + op_metrics->om_error_status++; + + spin_unlock(&op_metrics->om_lock); + + trace_rpc_stats_latency(req->rq_task, backlog, req->rq_rtt, execute); +} +EXPORT_SYMBOL_GPL(rpc_count_iostats_metrics); + +/** + * rpc_count_iostats - tally up per-task stats + * @task: completed rpc_task + * @stats: array of stat structures + * + * Uses the statidx from @task + */ +void rpc_count_iostats(const struct rpc_task *task, struct rpc_iostats *stats) +{ + rpc_count_iostats_metrics(task, + &stats[task->tk_msg.rpc_proc->p_statidx]); +} +EXPORT_SYMBOL_GPL(rpc_count_iostats); + +static void _print_name(struct seq_file *seq, unsigned int op, + const struct rpc_procinfo *procs) +{ + if (procs[op].p_name) + seq_printf(seq, "\t%12s: ", procs[op].p_name); + else if (op == 0) + seq_printf(seq, "\t NULL: "); + else + seq_printf(seq, "\t%12u: ", op); +} + +static void _add_rpc_iostats(struct rpc_iostats *a, struct rpc_iostats *b) +{ + a->om_ops += b->om_ops; + a->om_ntrans += b->om_ntrans; + a->om_timeouts += b->om_timeouts; + a->om_bytes_sent += b->om_bytes_sent; + a->om_bytes_recv += b->om_bytes_recv; + a->om_queue = ktime_add(a->om_queue, b->om_queue); + a->om_rtt = ktime_add(a->om_rtt, b->om_rtt); + a->om_execute = ktime_add(a->om_execute, b->om_execute); + a->om_error_status += b->om_error_status; +} + +static void _print_rpc_iostats(struct seq_file *seq, struct rpc_iostats *stats, + int op, const struct rpc_procinfo *procs) +{ + _print_name(seq, op, procs); + seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %llu %lu\n", + stats->om_ops, + stats->om_ntrans, + stats->om_timeouts, + stats->om_bytes_sent, + stats->om_bytes_recv, + ktime_to_ms(stats->om_queue), + ktime_to_ms(stats->om_rtt), + ktime_to_ms(stats->om_execute), + stats->om_error_status); +} + +static int do_print_stats(struct rpc_clnt *clnt, struct rpc_xprt *xprt, void *seqv) +{ + struct seq_file *seq = seqv; + + xprt->ops->print_stats(xprt, seq); + return 0; +} + +void rpc_clnt_show_stats(struct seq_file *seq, struct rpc_clnt *clnt) +{ + unsigned int op, maxproc = clnt->cl_maxproc; + + if (!clnt->cl_metrics) + return; + + seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS); + seq_printf(seq, "p/v: %u/%u (%s)\n", + clnt->cl_prog, clnt->cl_vers, clnt->cl_program->name); + + rpc_clnt_iterate_for_each_xprt(clnt, do_print_stats, seq); + + seq_printf(seq, "\tper-op statistics\n"); + for (op = 0; op < maxproc; op++) { + struct rpc_iostats stats = {}; + struct rpc_clnt *next = clnt; + do { + _add_rpc_iostats(&stats, &next->cl_metrics[op]); + if (next == next->cl_parent) + break; + next = next->cl_parent; + } while (next); + _print_rpc_iostats(seq, &stats, op, clnt->cl_procinfo); + } +} +EXPORT_SYMBOL_GPL(rpc_clnt_show_stats); + +/* + * Register/unregister RPC proc files + */ +static inline struct proc_dir_entry * +do_register(struct net *net, const char *name, void *data, + const struct proc_ops *proc_ops) +{ + struct sunrpc_net *sn; + + dprintk("RPC: registering /proc/net/rpc/%s\n", name); + sn = net_generic(net, sunrpc_net_id); + return proc_create_data(name, 0, sn->proc_net_rpc, proc_ops, data); +} + +struct proc_dir_entry * +rpc_proc_register(struct net *net, struct rpc_stat *statp) 
+{ + return do_register(net, statp->program->name, statp, &rpc_proc_ops); +} +EXPORT_SYMBOL_GPL(rpc_proc_register); + +void +rpc_proc_unregister(struct net *net, const char *name) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); +} +EXPORT_SYMBOL_GPL(rpc_proc_unregister); + +struct proc_dir_entry * +svc_proc_register(struct net *net, struct svc_stat *statp, const struct proc_ops *proc_ops) +{ + return do_register(net, statp->program->pg_name, statp, proc_ops); +} +EXPORT_SYMBOL_GPL(svc_proc_register); + +void +svc_proc_unregister(struct net *net, const char *name) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + remove_proc_entry(name, sn->proc_net_rpc); +} +EXPORT_SYMBOL_GPL(svc_proc_unregister); + +int rpc_proc_init(struct net *net) +{ + struct sunrpc_net *sn; + + dprintk("RPC: registering /proc/net/rpc\n"); + sn = net_generic(net, sunrpc_net_id); + sn->proc_net_rpc = proc_mkdir("rpc", net->proc_net); + if (sn->proc_net_rpc == NULL) + return -ENOMEM; + + return 0; +} + +void rpc_proc_exit(struct net *net) +{ + dprintk("RPC: unregistering /proc/net/rpc\n"); + remove_proc_entry("rpc", net->proc_net); +} diff --git a/net/sunrpc/sunrpc.h b/net/sunrpc/sunrpc.h new file mode 100644 index 000000000..d4a362c9e --- /dev/null +++ b/net/sunrpc/sunrpc.h @@ -0,0 +1,42 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/****************************************************************************** + +(c) 2008 NetApp. All Rights Reserved. + + +******************************************************************************/ + +/* + * Functions and macros used internally by RPC + */ + +#ifndef _NET_SUNRPC_SUNRPC_H +#define _NET_SUNRPC_SUNRPC_H + +#include + +/* + * Header for dynamically allocated rpc buffers. + */ +struct rpc_buffer { + size_t len; + char data[]; +}; + +static inline int sock_is_loopback(struct sock *sk) +{ + struct dst_entry *dst; + int loopback = 0; + rcu_read_lock(); + dst = rcu_dereference(sk->sk_dst_cache); + if (dst && dst->dev && + (dst->dev->features & NETIF_F_LOOPBACK)) + loopback = 1; + rcu_read_unlock(); + return loopback; +} + +int rpc_clients_notifier_register(void); +void rpc_clients_notifier_unregister(void); +void auth_domain_cleanup(void); +#endif /* _NET_SUNRPC_SUNRPC_H */ diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c new file mode 100644 index 000000000..691c0000e --- /dev/null +++ b/net/sunrpc/sunrpc_syms.c @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/sunrpc_syms.c + * + * Symbols exported by the sunrpc module. 
+ * + * Copyright (C) 1997 Olaf Kirch + */ + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sunrpc.h" +#include "sysfs.h" +#include "netns.h" + +unsigned int sunrpc_net_id; +EXPORT_SYMBOL_GPL(sunrpc_net_id); + +static __net_init int sunrpc_init_net(struct net *net) +{ + int err; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + err = rpc_proc_init(net); + if (err) + goto err_proc; + + err = ip_map_cache_create(net); + if (err) + goto err_ipmap; + + err = unix_gid_cache_create(net); + if (err) + goto err_unixgid; + + err = rpc_pipefs_init_net(net); + if (err) + goto err_pipefs; + + INIT_LIST_HEAD(&sn->all_clients); + spin_lock_init(&sn->rpc_client_lock); + spin_lock_init(&sn->rpcb_clnt_lock); + return 0; + +err_pipefs: + unix_gid_cache_destroy(net); +err_unixgid: + ip_map_cache_destroy(net); +err_ipmap: + rpc_proc_exit(net); +err_proc: + return err; +} + +static __net_exit void sunrpc_exit_net(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + rpc_pipefs_exit_net(net); + unix_gid_cache_destroy(net); + ip_map_cache_destroy(net); + rpc_proc_exit(net); + WARN_ON_ONCE(!list_empty(&sn->all_clients)); +} + +static struct pernet_operations sunrpc_net_ops = { + .init = sunrpc_init_net, + .exit = sunrpc_exit_net, + .id = &sunrpc_net_id, + .size = sizeof(struct sunrpc_net), +}; + +static int __init +init_sunrpc(void) +{ + int err = rpc_init_mempool(); + if (err) + goto out; + err = rpcauth_init_module(); + if (err) + goto out2; + + cache_initialize(); + + err = register_pernet_subsys(&sunrpc_net_ops); + if (err) + goto out3; + + err = register_rpc_pipefs(); + if (err) + goto out4; + + err = rpc_sysfs_init(); + if (err) + goto out5; + + sunrpc_debugfs_init(); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + rpc_register_sysctl(); +#endif + svc_init_xprt_sock(); /* svc sock transport */ + init_socket_xprt(); /* clnt sock transport */ + return 0; + +out5: + unregister_rpc_pipefs(); +out4: + unregister_pernet_subsys(&sunrpc_net_ops); +out3: + rpcauth_remove_module(); +out2: + rpc_destroy_mempool(); +out: + return err; +} + +static void __exit +cleanup_sunrpc(void) +{ + rpc_sysfs_exit(); + rpc_cleanup_clids(); + xprt_cleanup_ids(); + xprt_multipath_cleanup_ids(); + rpcauth_remove_module(); + cleanup_socket_xprt(); + svc_cleanup_xprt_sock(); + sunrpc_debugfs_exit(); + unregister_rpc_pipefs(); + rpc_destroy_mempool(); + unregister_pernet_subsys(&sunrpc_net_ops); + auth_domain_cleanup(); +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + rpc_unregister_sysctl(); +#endif + rcu_barrier(); /* Wait for completion of call_rcu()'s */ +} +MODULE_LICENSE("GPL"); +fs_initcall(init_sunrpc); /* Ensure we're initialised before nfs */ +module_exit(cleanup_sunrpc); diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c new file mode 100644 index 000000000..9b0b21ccc --- /dev/null +++ b/net/sunrpc/svc.c @@ -0,0 +1,1698 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/svc.c + * + * High-level RPC service routines + * + * Copyright (C) 1995, 1996 Olaf Kirch + * + * Multiple threads pools and NUMAisation + * Copyright (c) 2006 Silicon Graphics, Inc. 
+ * by Greg Banks + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include "fail.h" + +#define RPCDBG_FACILITY RPCDBG_SVCDSP + +static void svc_unregister(const struct svc_serv *serv, struct net *net); + +#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL + +/* + * Mode for mapping cpus to pools. + */ +enum { + SVC_POOL_AUTO = -1, /* choose one of the others */ + SVC_POOL_GLOBAL, /* no mapping, just a single global pool + * (legacy & UP mode) */ + SVC_POOL_PERCPU, /* one pool per cpu */ + SVC_POOL_PERNODE /* one pool per numa node */ +}; + +/* + * Structure for mapping cpus to pools and vice versa. + * Setup once during sunrpc initialisation. + */ + +struct svc_pool_map { + int count; /* How many svc_servs use us */ + int mode; /* Note: int not enum to avoid + * warnings about "enumeration value + * not handled in switch" */ + unsigned int npools; + unsigned int *pool_to; /* maps pool id to cpu or node */ + unsigned int *to_pool; /* maps cpu or node to pool id */ +}; + +static struct svc_pool_map svc_pool_map = { + .mode = SVC_POOL_DEFAULT +}; + +static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */ + +static int +param_set_pool_mode(const char *val, const struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + struct svc_pool_map *m = &svc_pool_map; + int err; + + mutex_lock(&svc_pool_map_mutex); + + err = -EBUSY; + if (m->count) + goto out; + + err = 0; + if (!strncmp(val, "auto", 4)) + *ip = SVC_POOL_AUTO; + else if (!strncmp(val, "global", 6)) + *ip = SVC_POOL_GLOBAL; + else if (!strncmp(val, "percpu", 6)) + *ip = SVC_POOL_PERCPU; + else if (!strncmp(val, "pernode", 7)) + *ip = SVC_POOL_PERNODE; + else + err = -EINVAL; + +out: + mutex_unlock(&svc_pool_map_mutex); + return err; +} + +static int +param_get_pool_mode(char *buf, const struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + + switch (*ip) + { + case SVC_POOL_AUTO: + return strlcpy(buf, "auto\n", 20); + case SVC_POOL_GLOBAL: + return strlcpy(buf, "global\n", 20); + case SVC_POOL_PERCPU: + return strlcpy(buf, "percpu\n", 20); + case SVC_POOL_PERNODE: + return strlcpy(buf, "pernode\n", 20); + default: + return sprintf(buf, "%d\n", *ip); + } +} + +module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode, + &svc_pool_map.mode, 0644); + +/* + * Detect best pool mapping mode heuristically, + * according to the machine's topology. + */ +static int +svc_pool_map_choose_mode(void) +{ + unsigned int node; + + if (nr_online_nodes > 1) { + /* + * Actually have multiple NUMA nodes, + * so split pools on NUMA node boundaries + */ + return SVC_POOL_PERNODE; + } + + node = first_online_node; + if (nr_cpus_node(node) > 2) { + /* + * Non-trivial SMP, or CONFIG_NUMA on + * non-NUMA hardware, e.g. with a generic + * x86_64 kernel on Xeons. In this case we + * want to divide the pools on cpu boundaries. + */ + return SVC_POOL_PERCPU; + } + + /* default: one global pool */ + return SVC_POOL_GLOBAL; +} + +/* + * Allocate the to_pool[] and pool_to[] arrays. + * Returns 0 on success or an errno. 
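+ * Both arrays get maxpools entries: to_pool[] maps a cpu or node id to a
+ * pool index and pool_to[] is the inverse mapping (for example, in percpu
+ * mode on a machine whose online cpus are numbered 0..3, both end up as
+ * the identity map over those four ids).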
+ */ +static int +svc_pool_map_alloc_arrays(struct svc_pool_map *m, unsigned int maxpools) +{ + m->to_pool = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->to_pool) + goto fail; + m->pool_to = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->pool_to) + goto fail_free; + + return 0; + +fail_free: + kfree(m->to_pool); + m->to_pool = NULL; +fail: + return -ENOMEM; +} + +/* + * Initialise the pool map for SVC_POOL_PERCPU mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_percpu(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_cpu_ids; + unsigned int pidx = 0; + unsigned int cpu; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_online_cpu(cpu) { + BUG_ON(pidx >= maxpools); + m->to_pool[cpu] = pidx; + m->pool_to[pidx] = cpu; + pidx++; + } + /* cpus brought online later all get mapped to pool0, sorry */ + + return pidx; +}; + + +/* + * Initialise the pool map for SVC_POOL_PERNODE mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_pernode(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_node_ids; + unsigned int pidx = 0; + unsigned int node; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_node_with_cpus(node) { + /* some architectures (e.g. SN2) have cpuless nodes */ + BUG_ON(pidx > maxpools); + m->to_pool[node] = pidx; + m->pool_to[pidx] = node; + pidx++; + } + /* nodes brought online later all get mapped to pool0, sorry */ + + return pidx; +} + + +/* + * Add a reference to the global map of cpus to pools (and + * vice versa) if pools are in use. + * Initialise the map if we're the first user. + * Returns the number of pools. If this is '1', no reference + * was taken. + */ +static unsigned int +svc_pool_map_get(void) +{ + struct svc_pool_map *m = &svc_pool_map; + int npools = -1; + + mutex_lock(&svc_pool_map_mutex); + + if (m->count++) { + mutex_unlock(&svc_pool_map_mutex); + WARN_ON_ONCE(m->npools <= 1); + return m->npools; + } + + if (m->mode == SVC_POOL_AUTO) + m->mode = svc_pool_map_choose_mode(); + + switch (m->mode) { + case SVC_POOL_PERCPU: + npools = svc_pool_map_init_percpu(m); + break; + case SVC_POOL_PERNODE: + npools = svc_pool_map_init_pernode(m); + break; + } + + if (npools <= 0) { + /* default, or memory allocation failure */ + npools = 1; + m->mode = SVC_POOL_GLOBAL; + } + m->npools = npools; + + if (npools == 1) + /* service is unpooled, so doesn't hold a reference */ + m->count--; + + mutex_unlock(&svc_pool_map_mutex); + return npools; +} + +/* + * Drop a reference to the global map of cpus to pools, if + * pools were in use, i.e. if npools > 1. + * When the last reference is dropped, the map data is + * freed; this allows the sysadmin to change the pool + * mode using the pool_mode module option without + * rebooting or re-loading sunrpc.ko. 
+ */ +static void +svc_pool_map_put(int npools) +{ + struct svc_pool_map *m = &svc_pool_map; + + if (npools <= 1) + return; + mutex_lock(&svc_pool_map_mutex); + + if (!--m->count) { + kfree(m->to_pool); + m->to_pool = NULL; + kfree(m->pool_to); + m->pool_to = NULL; + m->npools = 0; + } + + mutex_unlock(&svc_pool_map_mutex); +} + +static int svc_pool_map_get_node(unsigned int pidx) +{ + const struct svc_pool_map *m = &svc_pool_map; + + if (m->count) { + if (m->mode == SVC_POOL_PERCPU) + return cpu_to_node(m->pool_to[pidx]); + if (m->mode == SVC_POOL_PERNODE) + return m->pool_to[pidx]; + } + return NUMA_NO_NODE; +} +/* + * Set the given thread's cpus_allowed mask so that it + * will only run on cpus in the given pool. + */ +static inline void +svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx) +{ + struct svc_pool_map *m = &svc_pool_map; + unsigned int node = m->pool_to[pidx]; + + /* + * The caller checks for sv_nrpools > 1, which + * implies that we've been initialized. + */ + WARN_ON_ONCE(m->count == 0); + if (m->count == 0) + return; + + switch (m->mode) { + case SVC_POOL_PERCPU: + { + set_cpus_allowed_ptr(task, cpumask_of(node)); + break; + } + case SVC_POOL_PERNODE: + { + set_cpus_allowed_ptr(task, cpumask_of_node(node)); + break; + } + } +} + +/** + * svc_pool_for_cpu - Select pool to run a thread on this cpu + * @serv: An RPC service + * + * Use the active CPU and the svc_pool_map's mode setting to + * select the svc thread pool to use. Once initialized, the + * svc_pool_map does not change. + * + * Return value: + * A pointer to an svc_pool + */ +struct svc_pool *svc_pool_for_cpu(struct svc_serv *serv) +{ + struct svc_pool_map *m = &svc_pool_map; + int cpu = raw_smp_processor_id(); + unsigned int pidx = 0; + + if (serv->sv_nrpools <= 1) + return serv->sv_pools; + + switch (m->mode) { + case SVC_POOL_PERCPU: + pidx = m->to_pool[cpu]; + break; + case SVC_POOL_PERNODE: + pidx = m->to_pool[cpu_to_node(cpu)]; + break; + } + + return &serv->sv_pools[pidx % serv->sv_nrpools]; +} + +int svc_rpcb_setup(struct svc_serv *serv, struct net *net) +{ + int err; + + err = rpcb_create_local(net); + if (err) + return err; + + /* Remove any stale portmap registrations */ + svc_unregister(serv, net); + return 0; +} +EXPORT_SYMBOL_GPL(svc_rpcb_setup); + +void svc_rpcb_cleanup(struct svc_serv *serv, struct net *net) +{ + svc_unregister(serv, net); + rpcb_put_local(net); +} +EXPORT_SYMBOL_GPL(svc_rpcb_cleanup); + +static int svc_uses_rpcbind(struct svc_serv *serv) +{ + struct svc_program *progp; + unsigned int i; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + if (!progp->pg_vers[i]->vs_hidden) + return 1; + } + } + + return 0; +} + +int svc_bind(struct svc_serv *serv, struct net *net) +{ + if (!svc_uses_rpcbind(serv)) + return 0; + return svc_rpcb_setup(serv, net); +} +EXPORT_SYMBOL_GPL(svc_bind); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static void +__svc_init_bc(struct svc_serv *serv) +{ + INIT_LIST_HEAD(&serv->sv_cb_list); + spin_lock_init(&serv->sv_cb_lock); + init_waitqueue_head(&serv->sv_cb_waitq); +} +#else +static void +__svc_init_bc(struct svc_serv *serv) +{ +} +#endif + +/* + * Create an RPC service + */ +static struct svc_serv * +__svc_create(struct svc_program *prog, unsigned int bufsize, int npools, + int (*threadfn)(void *data)) +{ + struct svc_serv *serv; + unsigned int vers; + unsigned int xdrsize; + unsigned int i; + + if (!(serv = kzalloc(sizeof(*serv), 
GFP_KERNEL))) + return NULL; + serv->sv_name = prog->pg_name; + serv->sv_program = prog; + kref_init(&serv->sv_refcnt); + serv->sv_stats = prog->pg_stats; + if (bufsize > RPCSVC_MAXPAYLOAD) + bufsize = RPCSVC_MAXPAYLOAD; + serv->sv_max_payload = bufsize? bufsize : 4096; + serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); + serv->sv_threadfn = threadfn; + xdrsize = 0; + while (prog) { + prog->pg_lovers = prog->pg_nvers-1; + for (vers=0; verspg_nvers ; vers++) + if (prog->pg_vers[vers]) { + prog->pg_hivers = vers; + if (prog->pg_lovers > vers) + prog->pg_lovers = vers; + if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = prog->pg_vers[vers]->vs_xdrsize; + } + prog = prog->pg_next; + } + serv->sv_xdrsize = xdrsize; + INIT_LIST_HEAD(&serv->sv_tempsocks); + INIT_LIST_HEAD(&serv->sv_permsocks); + timer_setup(&serv->sv_temptimer, NULL, 0); + spin_lock_init(&serv->sv_lock); + + __svc_init_bc(serv); + + serv->sv_nrpools = npools; + serv->sv_pools = + kcalloc(serv->sv_nrpools, sizeof(struct svc_pool), + GFP_KERNEL); + if (!serv->sv_pools) { + kfree(serv); + return NULL; + } + + for (i = 0; i < serv->sv_nrpools; i++) { + struct svc_pool *pool = &serv->sv_pools[i]; + + dprintk("svc: initialising pool %u for %s\n", + i, serv->sv_name); + + pool->sp_id = i; + INIT_LIST_HEAD(&pool->sp_sockets); + INIT_LIST_HEAD(&pool->sp_all_threads); + spin_lock_init(&pool->sp_lock); + } + + return serv; +} + +/** + * svc_create - Create an RPC service + * @prog: the RPC program the new service will handle + * @bufsize: maximum message size for @prog + * @threadfn: a function to service RPC requests for @prog + * + * Returns an instantiated struct svc_serv object or NULL. + */ +struct svc_serv *svc_create(struct svc_program *prog, unsigned int bufsize, + int (*threadfn)(void *data)) +{ + return __svc_create(prog, bufsize, 1, threadfn); +} +EXPORT_SYMBOL_GPL(svc_create); + +/** + * svc_create_pooled - Create an RPC service with pooled threads + * @prog: the RPC program the new service will handle + * @bufsize: maximum message size for @prog + * @threadfn: a function to service RPC requests for @prog + * + * Returns an instantiated struct svc_serv object or NULL. + */ +struct svc_serv *svc_create_pooled(struct svc_program *prog, + unsigned int bufsize, + int (*threadfn)(void *data)) +{ + struct svc_serv *serv; + unsigned int npools = svc_pool_map_get(); + + serv = __svc_create(prog, bufsize, npools, threadfn); + if (!serv) + goto out_err; + return serv; +out_err: + svc_pool_map_put(npools); + return NULL; +} +EXPORT_SYMBOL_GPL(svc_create_pooled); + +/* + * Destroy an RPC service. Should be called with appropriate locking to + * protect sv_permsocks and sv_tempsocks. + */ +void +svc_destroy(struct kref *ref) +{ + struct svc_serv *serv = container_of(ref, struct svc_serv, sv_refcnt); + + dprintk("svc: svc_destroy(%s)\n", serv->sv_program->pg_name); + del_timer_sync(&serv->sv_temptimer); + + /* + * The last user is gone and thus all sockets have to be destroyed to + * the point. Check this. + */ + BUG_ON(!list_empty(&serv->sv_permsocks)); + BUG_ON(!list_empty(&serv->sv_tempsocks)); + + cache_clean_deferred(serv); + + svc_pool_map_put(serv->sv_nrpools); + + kfree(serv->sv_pools); + kfree(serv); +} +EXPORT_SYMBOL_GPL(svc_destroy); + +/* + * Allocate an RPC server's buffer space. + * We allocate pages and place them in rq_pages. 
+ */ +static int +svc_init_buffer(struct svc_rqst *rqstp, unsigned int size, int node) +{ + unsigned int pages, arghi; + + /* bc_xprt uses fore channel allocated buffers */ + if (svc_is_backchannel(rqstp)) + return 1; + + pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply. + * We assume one is at most one page + */ + arghi = 0; + WARN_ON_ONCE(pages > RPCSVC_MAXPAGES); + if (pages > RPCSVC_MAXPAGES) + pages = RPCSVC_MAXPAGES; + while (pages) { + struct page *p = alloc_pages_node(node, GFP_KERNEL, 0); + if (!p) + break; + rqstp->rq_pages[arghi++] = p; + pages--; + } + return pages == 0; +} + +/* + * Release an RPC server buffer + */ +static void +svc_release_buffer(struct svc_rqst *rqstp) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++) + if (rqstp->rq_pages[i]) + put_page(rqstp->rq_pages[i]); +} + +struct svc_rqst * +svc_rqst_alloc(struct svc_serv *serv, struct svc_pool *pool, int node) +{ + struct svc_rqst *rqstp; + + rqstp = kzalloc_node(sizeof(*rqstp), GFP_KERNEL, node); + if (!rqstp) + return rqstp; + + __set_bit(RQ_BUSY, &rqstp->rq_flags); + spin_lock_init(&rqstp->rq_lock); + rqstp->rq_server = serv; + rqstp->rq_pool = pool; + + rqstp->rq_scratch_page = alloc_pages_node(node, GFP_KERNEL, 0); + if (!rqstp->rq_scratch_page) + goto out_enomem; + + rqstp->rq_argp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); + if (!rqstp->rq_argp) + goto out_enomem; + + rqstp->rq_resp = kmalloc_node(serv->sv_xdrsize, GFP_KERNEL, node); + if (!rqstp->rq_resp) + goto out_enomem; + + if (!svc_init_buffer(rqstp, serv->sv_max_mesg, node)) + goto out_enomem; + + return rqstp; +out_enomem: + svc_rqst_free(rqstp); + return NULL; +} +EXPORT_SYMBOL_GPL(svc_rqst_alloc); + +static struct svc_rqst * +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool, int node) +{ + struct svc_rqst *rqstp; + + rqstp = svc_rqst_alloc(serv, pool, node); + if (!rqstp) + return ERR_PTR(-ENOMEM); + + svc_get(serv); + spin_lock_bh(&serv->sv_lock); + serv->sv_nrthreads += 1; + spin_unlock_bh(&serv->sv_lock); + + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads++; + list_add_rcu(&rqstp->rq_all, &pool->sp_all_threads); + spin_unlock_bh(&pool->sp_lock); + return rqstp; +} + +/* + * Choose a pool in which to create a new thread, for svc_set_num_threads + */ +static inline struct svc_pool * +choose_pool(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + if (pool != NULL) + return pool; + + return &serv->sv_pools[(*state)++ % serv->sv_nrpools]; +} + +/* + * Choose a thread to kill, for svc_set_num_threads + */ +static inline struct task_struct * +choose_victim(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + unsigned int i; + struct task_struct *task = NULL; + + if (pool != NULL) { + spin_lock_bh(&pool->sp_lock); + } else { + /* choose a pool in round-robin fashion */ + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; + spin_lock_bh(&pool->sp_lock); + if (!list_empty(&pool->sp_all_threads)) + goto found_pool; + spin_unlock_bh(&pool->sp_lock); + } + return NULL; + } + +found_pool: + if (!list_empty(&pool->sp_all_threads)) { + struct svc_rqst *rqstp; + + /* + * Remove from the pool->sp_all_threads list + * so we don't try to kill it again. 
+ */ + rqstp = list_entry(pool->sp_all_threads.next, struct svc_rqst, rq_all); + set_bit(RQ_VICTIM, &rqstp->rq_flags); + list_del_rcu(&rqstp->rq_all); + task = rqstp->rq_task; + } + spin_unlock_bh(&pool->sp_lock); + + return task; +} + +/* create new threads */ +static int +svc_start_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct svc_rqst *rqstp; + struct task_struct *task; + struct svc_pool *chosen_pool; + unsigned int state = serv->sv_nrthreads-1; + int node; + + do { + nrservs--; + chosen_pool = choose_pool(serv, pool, &state); + + node = svc_pool_map_get_node(chosen_pool->sp_id); + rqstp = svc_prepare_thread(serv, chosen_pool, node); + if (IS_ERR(rqstp)) + return PTR_ERR(rqstp); + + task = kthread_create_on_node(serv->sv_threadfn, rqstp, + node, "%s", serv->sv_name); + if (IS_ERR(task)) { + svc_exit_thread(rqstp); + return PTR_ERR(task); + } + + rqstp->rq_task = task; + if (serv->sv_nrpools > 1) + svc_pool_map_set_cpumask(task, chosen_pool->sp_id); + + svc_sock_update_bufs(serv); + wake_up_process(task); + } while (nrservs > 0); + + return 0; +} + +/* + * Create or destroy enough new threads to make the number + * of threads the given number. If `pool' is non-NULL, applies + * only to threads in that pool, otherwise round-robins between + * all pools. Caller must ensure that mutual exclusion between this and + * server startup or shutdown. + */ + +/* destroy old threads */ +static int +svc_stop_kthreads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct svc_rqst *rqstp; + struct task_struct *task; + unsigned int state = serv->sv_nrthreads-1; + + /* destroy old threads */ + do { + task = choose_victim(serv, pool, &state); + if (task == NULL) + break; + rqstp = kthread_data(task); + /* Did we lose a race to svo_function threadfn? */ + if (kthread_stop(task) == -EINTR) + svc_exit_thread(rqstp); + nrservs++; + } while (nrservs < 0); + return 0; +} + +int +svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + if (pool == NULL) { + nrservs -= serv->sv_nrthreads; + } else { + spin_lock_bh(&pool->sp_lock); + nrservs -= pool->sp_nrthreads; + spin_unlock_bh(&pool->sp_lock); + } + + if (nrservs > 0) + return svc_start_kthreads(serv, pool, nrservs); + if (nrservs < 0) + return svc_stop_kthreads(serv, pool, nrservs); + return 0; +} +EXPORT_SYMBOL_GPL(svc_set_num_threads); + +/** + * svc_rqst_replace_page - Replace one page in rq_pages[] + * @rqstp: svc_rqst with pages to replace + * @page: replacement page + * + * When replacing a page in rq_pages, batch the release of the + * replaced pages to avoid hammering the page allocator. + */ +void svc_rqst_replace_page(struct svc_rqst *rqstp, struct page *page) +{ + if (*rqstp->rq_next_page) { + if (!pagevec_space(&rqstp->rq_pvec)) + __pagevec_release(&rqstp->rq_pvec); + pagevec_add(&rqstp->rq_pvec, *rqstp->rq_next_page); + } + + get_page(page); + *(rqstp->rq_next_page++) = page; +} +EXPORT_SYMBOL_GPL(svc_rqst_replace_page); + +/* + * Called from a server thread as it's exiting. Caller must hold the "service + * mutex" for the service. 
+ */ +void +svc_rqst_free(struct svc_rqst *rqstp) +{ + svc_release_buffer(rqstp); + if (rqstp->rq_scratch_page) + put_page(rqstp->rq_scratch_page); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + kfree_rcu(rqstp, rq_rcu_head); +} +EXPORT_SYMBOL_GPL(svc_rqst_free); + +void +svc_exit_thread(struct svc_rqst *rqstp) +{ + struct svc_serv *serv = rqstp->rq_server; + struct svc_pool *pool = rqstp->rq_pool; + + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads--; + if (!test_and_set_bit(RQ_VICTIM, &rqstp->rq_flags)) + list_del_rcu(&rqstp->rq_all); + spin_unlock_bh(&pool->sp_lock); + + spin_lock_bh(&serv->sv_lock); + serv->sv_nrthreads -= 1; + spin_unlock_bh(&serv->sv_lock); + svc_sock_update_bufs(serv); + + svc_rqst_free(rqstp); + + svc_put(serv); +} +EXPORT_SYMBOL_GPL(svc_exit_thread); + +/* + * Register an "inet" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register4(struct net *net, const u32 program, + const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + const struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; + const char *netid; + int error; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP; + break; + default: + return -ENOPROTOOPT; + } + + error = rpcb_v4_register(net, program, version, + (const struct sockaddr *)&sin, netid); + + /* + * User space didn't support rpcbind v4, so retry this + * registration request with the legacy rpcbind v2 protocol. + */ + if (error == -EPROTONOSUPPORT) + error = rpcb_register(net, program, version, protocol, port); + + return error; +} + +#if IS_ENABLED(CONFIG_IPV6) +/* + * Register an "inet6" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register6(struct net *net, const u32 program, + const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + const struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; + const char *netid; + int error; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP6; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP6; + break; + default: + return -ENOPROTOOPT; + } + + error = rpcb_v4_register(net, program, version, + (const struct sockaddr *)&sin6, netid); + + /* + * User space didn't support rpcbind version 4, so we won't + * use a PF_INET6 listener. + */ + if (error == -EPROTONOSUPPORT) + error = -EAFNOSUPPORT; + + return error; +} +#endif /* IS_ENABLED(CONFIG_IPV6) */ + +/* + * Register a kernel RPC service via rpcbind version 4. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. 
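+ *
+ * (Illustrative aside: the protocol-to-netid mapping used by the two
+ * helpers above amounts to four entries. An equivalent sketch, written
+ * with the literal netid strings instead of the RPCBIND_NETID_* constants
+ * and a hypothetical helper name:
+ *
+ *     static const char *hypothetical_netid(int family, unsigned short proto)
+ *     {
+ *             if (proto == IPPROTO_UDP)
+ *                     return family == PF_INET6 ? "udp6" : "udp";
+ *             if (proto == IPPROTO_TCP)
+ *                     return family == PF_INET6 ? "tcp6" : "tcp";
+ *             return NULL;    // no netid: cannot register with rpcbind
+ *     }
+ *
+ * The real code keeps the IPv4 and IPv6 cases in separate functions so the
+ * IPv6 path can be compiled out.)
+ *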
+ */ +static int __svc_register(struct net *net, const char *progname, + const u32 program, const u32 version, + const int family, + const unsigned short protocol, + const unsigned short port) +{ + int error = -EAFNOSUPPORT; + + switch (family) { + case PF_INET: + error = __svc_rpcb_register4(net, program, version, + protocol, port); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + error = __svc_rpcb_register6(net, program, version, + protocol, port); +#endif + } + + trace_svc_register(progname, version, family, protocol, port, error); + return error; +} + +int svc_rpcbind_set_version(struct net *net, + const struct svc_program *progp, + u32 version, int family, + unsigned short proto, + unsigned short port) +{ + return __svc_register(net, progp->pg_name, progp->pg_prog, + version, family, proto, port); + +} +EXPORT_SYMBOL_GPL(svc_rpcbind_set_version); + +int svc_generic_rpcbind_set(struct net *net, + const struct svc_program *progp, + u32 version, int family, + unsigned short proto, + unsigned short port) +{ + const struct svc_version *vers = progp->pg_vers[version]; + int error; + + if (vers == NULL) + return 0; + + if (vers->vs_hidden) { + trace_svc_noregister(progp->pg_name, version, proto, + port, family, 0); + return 0; + } + + /* + * Don't register a UDP port if we need congestion + * control. + */ + if (vers->vs_need_cong_ctrl && proto == IPPROTO_UDP) + return 0; + + error = svc_rpcbind_set_version(net, progp, version, + family, proto, port); + + return (vers->vs_rpcb_optnl) ? 0 : error; +} +EXPORT_SYMBOL_GPL(svc_generic_rpcbind_set); + +/** + * svc_register - register an RPC service with the local portmapper + * @serv: svc_serv struct for the service to register + * @net: net namespace for the service to register + * @family: protocol family of service's listener socket + * @proto: transport protocol number to advertise + * @port: port to advertise + * + * Service is registered for any address in the passed-in protocol family + */ +int svc_register(const struct svc_serv *serv, struct net *net, + const int family, const unsigned short proto, + const unsigned short port) +{ + struct svc_program *progp; + unsigned int i; + int error = 0; + + WARN_ON_ONCE(proto == 0 && port == 0); + if (proto == 0 && port == 0) + return -EINVAL; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + + error = progp->pg_rpcbind_set(net, progp, i, + family, proto, port); + if (error < 0) { + printk(KERN_WARNING "svc: failed to register " + "%sv%u RPC service (errno %d).\n", + progp->pg_name, i, -error); + break; + } + } + } + + return error; +} + +/* + * If user space is running rpcbind, it should take the v4 UNSET + * and clear everything for this [program, version]. If user space + * is running portmap, it will reject the v4 UNSET, but won't have + * any "inet6" entries anyway. So a PMAP_UNSET should be sufficient + * in this case to clear all existing entries for [program, version]. + */ +static void __svc_unregister(struct net *net, const u32 program, const u32 version, + const char *progname) +{ + int error; + + error = rpcb_v4_register(net, program, version, NULL, ""); + + /* + * User space didn't support rpcbind v4, so retry this + * request with the legacy rpcbind v2 protocol. 
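+ *
+ * (Illustrative aside: this is the same "try rpcbind v4, fall back to the
+ * legacy v2 portmapper" shape already used in the register path above,
+ * schematically:
+ *
+ *     error = rpcb_v4_register(net, program, version, sap, netid);
+ *     if (error == -EPROTONOSUPPORT)          // old user space: portmap only
+ *             error = rpcb_register(net, program, version, protocol, port);
+ *
+ * For unregistration the v4 call passes a NULL address and an empty netid,
+ * and the v2 fallback below passes protocol 0 and port 0.)
+ *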
+ */
+ if (error == -EPROTONOSUPPORT)
+ error = rpcb_register(net, program, version, 0, 0);
+
+ trace_svc_unregister(progname, version, error);
+}
+
+/*
+ * All netids, bind addresses and ports registered for [program, version]
+ * are removed from the local rpcbind database (if the service is not
+ * hidden) to make way for a new instance of the service.
+ *
+ * The result of unregistration is reported via dprintk for those who want
+ * verification of the result, but is otherwise not important.
+ */
+static void svc_unregister(const struct svc_serv *serv, struct net *net)
+{
+ struct svc_program *progp;
+ unsigned long flags;
+ unsigned int i;
+
+ clear_thread_flag(TIF_SIGPENDING);
+
+ for (progp = serv->sv_program; progp; progp = progp->pg_next) {
+ for (i = 0; i < progp->pg_nvers; i++) {
+ if (progp->pg_vers[i] == NULL)
+ continue;
+ if (progp->pg_vers[i]->vs_hidden)
+ continue;
+ __svc_unregister(net, progp->pg_prog, i, progp->pg_name);
+ }
+ }
+
+ spin_lock_irqsave(&current->sighand->siglock, flags);
+ recalc_sigpending();
+ spin_unlock_irqrestore(&current->sighand->siglock, flags);
+}
+
+/*
+ * dprintk the given error with the address of the client that caused it.
+ */
+#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
+static __printf(2, 3)
+void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...)
+{
+ struct va_format vaf;
+ va_list args;
+ char buf[RPC_MAX_ADDRBUFLEN];
+
+ va_start(args, fmt);
+
+ vaf.fmt = fmt;
+ vaf.va = &args;
+
+ dprintk("svc: %s: %pV", svc_print_addr(rqstp, buf, sizeof(buf)), &vaf);
+
+ va_end(args);
+}
+#else
+static __printf(2,3) void svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) {}
+#endif
+
+__be32
+svc_generic_init_request(struct svc_rqst *rqstp,
+ const struct svc_program *progp,
+ struct svc_process_info *ret)
+{
+ const struct svc_version *versp = NULL; /* compiler food */
+ const struct svc_procedure *procp = NULL;
+
+ if (rqstp->rq_vers >= progp->pg_nvers )
+ goto err_bad_vers;
+ versp = progp->pg_vers[rqstp->rq_vers];
+ if (!versp)
+ goto err_bad_vers;
+
+ /*
+ * Some protocol versions (namely NFSv4) require some form of
+ * congestion control. (See RFC 7530 section 3.1 paragraph 2)
+ * In other words, UDP is not allowed. We mark those when setting
+ * up the svc_xprt, and verify that here.
+ *
+ * The spec is not very clear about what error should be returned
+ * when someone tries to access a server that is listening on UDP
+ * for lower versions. RPC_PROG_MISMATCH seems to be the closest
+ * fit.
+ */
+ if (versp->vs_need_cong_ctrl && rqstp->rq_xprt &&
+ !test_bit(XPT_CONG_CTRL, &rqstp->rq_xprt->xpt_flags))
+ goto err_bad_vers;
+
+ if (rqstp->rq_proc >= versp->vs_nproc)
+ goto err_bad_proc;
+ rqstp->rq_procinfo = procp = &versp->vs_proc[rqstp->rq_proc];
+ if (!procp)
+ goto err_bad_proc;
+
+ /* Initialize storage for argp and resp */
+ memset(rqstp->rq_argp, 0, procp->pc_argzero);
+ memset(rqstp->rq_resp, 0, procp->pc_ressize);
+
+ /* Bump per-procedure stats counter */
+ versp->vs_count[rqstp->rq_proc]++;
+
+ ret->dispatch = versp->vs_dispatch;
+ return rpc_success;
+err_bad_vers:
+ ret->mismatch.lovers = progp->pg_lovers;
+ ret->mismatch.hivers = progp->pg_hivers;
+ return rpc_prog_mismatch;
+err_bad_proc:
+ return rpc_proc_unavail;
+}
+EXPORT_SYMBOL_GPL(svc_generic_init_request);
+
+/*
+ * Common routine for processing the RPC request.
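+ *
+ * (Illustrative aside: svc_generic_init_request() and
+ * svc_generic_rpcbind_set() above are the defaults most RPC programs wire
+ * into their struct svc_program. A bare-bones sketch with hypothetical
+ * names and a made-up program number; real programs also fill in their
+ * version table, auth callback and stats:
+ *
+ *     static const struct svc_version *hypothetical_versions[2];
+ *
+ *     static struct svc_program hypothetical_program = {
+ *             .pg_prog         = 400400,      // made-up program number
+ *             .pg_nvers        = 2,
+ *             .pg_vers         = hypothetical_versions,
+ *             .pg_name         = "hypothetical",
+ *             .pg_init_request = svc_generic_init_request,
+ *             .pg_rpcbind_set  = svc_generic_rpcbind_set,
+ *     };
+ *
+ * Every pg_* member shown here is referenced by the generic code in this
+ * file.)
+ *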
+ */ +static int +svc_process_common(struct svc_rqst *rqstp, struct kvec *argv, struct kvec *resv) +{ + struct svc_program *progp; + const struct svc_procedure *procp = NULL; + struct svc_serv *serv = rqstp->rq_server; + struct svc_process_info process; + __be32 *statp; + u32 prog, vers; + __be32 rpc_stat; + int auth_res, rc; + __be32 *reply_statp; + + rpc_stat = rpc_success; + + if (argv->iov_len < 6*4) + goto err_short_len; + + /* Will be turned off by GSS integrity and privacy services */ + set_bit(RQ_SPLICE_OK, &rqstp->rq_flags); + /* Will be turned off only when NFSv4 Sessions are used */ + set_bit(RQ_USEDEFERRAL, &rqstp->rq_flags); + clear_bit(RQ_DROPME, &rqstp->rq_flags); + + svc_putu32(resv, rqstp->rq_xid); + + vers = svc_getnl(argv); + + /* First words of reply: */ + svc_putnl(resv, 1); /* REPLY */ + + if (vers != 2) /* RPC version number */ + goto err_bad_rpc; + + /* Save position in case we later decide to reject: */ + reply_statp = resv->iov_base + resv->iov_len; + + svc_putnl(resv, 0); /* ACCEPT */ + + rqstp->rq_prog = prog = svc_getnl(argv); /* program number */ + rqstp->rq_vers = svc_getnl(argv); /* version number */ + rqstp->rq_proc = svc_getnl(argv); /* procedure number */ + + for (progp = serv->sv_program; progp; progp = progp->pg_next) + if (prog == progp->pg_prog) + break; + + /* + * Decode auth data, and add verifier to reply buffer. + * We do this before anything else in order to get a decent + * auth verifier. + */ + auth_res = svc_authenticate(rqstp); + /* Also give the program a chance to reject this call: */ + if (auth_res == SVC_OK && progp) + auth_res = progp->pg_authenticate(rqstp); + if (auth_res != SVC_OK) + trace_svc_authenticate(rqstp, auth_res); + switch (auth_res) { + case SVC_OK: + break; + case SVC_GARBAGE: + goto err_garbage; + case SVC_SYSERR: + rpc_stat = rpc_system_err; + goto err_bad; + case SVC_DENIED: + goto err_bad_auth; + case SVC_CLOSE: + goto close; + case SVC_DROP: + goto dropit; + case SVC_COMPLETE: + goto sendit; + } + + if (progp == NULL) + goto err_bad_prog; + + rpc_stat = progp->pg_init_request(rqstp, progp, &process); + switch (rpc_stat) { + case rpc_success: + break; + case rpc_prog_unavail: + goto err_bad_prog; + case rpc_prog_mismatch: + goto err_bad_vers; + case rpc_proc_unavail: + goto err_bad_proc; + } + + procp = rqstp->rq_procinfo; + /* Should this check go into the dispatcher? */ + if (!procp || !procp->pc_func) + goto err_bad_proc; + + /* Syntactic check complete */ + serv->sv_stats->rpccnt++; + trace_svc_process(rqstp, progp->pg_name); + + /* Build the reply header. */ + statp = resv->iov_base +resv->iov_len; + svc_putnl(resv, RPC_SUCCESS); + + /* un-reserve some of the out-queue now that we have a + * better idea of reply size + */ + if (procp->pc_xdrressize) + svc_reserve_auth(rqstp, procp->pc_xdrressize<<2); + + /* Call the function that processes the request. 
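+ *
+ * (Illustrative aside: process.dispatch points at the version's
+ * ->vs_dispatch callback, filled in by pg_init_request above. Its contract,
+ * as used below, is: return 0 to drop the request, non-zero to send the
+ * reply, with the accept status left in *statp. A deliberately minimal,
+ * hypothetical dispatcher would be:
+ *
+ *     static int hypothetical_dispatch(struct svc_rqst *rqstp, __be32 *statp)
+ *     {
+ *             *statp = rqstp->rq_procinfo->pc_func(rqstp);
+ *             return 1;       // always send a reply
+ *     }
+ *
+ * Real dispatchers additionally decode the arguments and encode the
+ * results around the pc_func() call.)
+ *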
*/ + rc = process.dispatch(rqstp, statp); + if (procp->pc_release) + procp->pc_release(rqstp); + if (!rc) + goto dropit; + if (rqstp->rq_auth_stat != rpc_auth_ok) + goto err_bad_auth; + + /* Check RPC status result */ + if (*statp != rpc_success) + resv->iov_len = ((void*)statp) - resv->iov_base + 4; + + if (procp->pc_encode == NULL) + goto dropit; + + sendit: + if (svc_authorise(rqstp)) + goto close_xprt; + return 1; /* Caller can now send it */ + + dropit: + svc_authorise(rqstp); /* doesn't hurt to call this twice */ + dprintk("svc: svc_process dropit\n"); + return 0; + + close: + svc_authorise(rqstp); +close_xprt: + if (rqstp->rq_xprt && test_bit(XPT_TEMP, &rqstp->rq_xprt->xpt_flags)) + svc_xprt_close(rqstp->rq_xprt); + dprintk("svc: svc_process close\n"); + return 0; + +err_short_len: + svc_printk(rqstp, "short len %zd, dropping request\n", + argv->iov_len); + goto close_xprt; + +err_bad_rpc: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 0); /* RPC_MISMATCH */ + svc_putnl(resv, 2); /* Only RPCv2 supported */ + svc_putnl(resv, 2); + goto sendit; + +err_bad_auth: + dprintk("svc: authentication failed (%d)\n", + be32_to_cpu(rqstp->rq_auth_stat)); + serv->sv_stats->rpcbadauth++; + /* Restore write pointer to location of accept status: */ + xdr_ressize_check(rqstp, reply_statp); + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 1); /* AUTH_ERROR */ + svc_putu32(resv, rqstp->rq_auth_stat); /* status */ + goto sendit; + +err_bad_prog: + dprintk("svc: unknown program %d\n", prog); + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_UNAVAIL); + goto sendit; + +err_bad_vers: + svc_printk(rqstp, "unknown version (%d for prog %d, %s)\n", + rqstp->rq_vers, rqstp->rq_prog, progp->pg_name); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_MISMATCH); + svc_putnl(resv, process.mismatch.lovers); + svc_putnl(resv, process.mismatch.hivers); + goto sendit; + +err_bad_proc: + svc_printk(rqstp, "unknown procedure (%d)\n", rqstp->rq_proc); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROC_UNAVAIL); + goto sendit; + +err_garbage: + svc_printk(rqstp, "failed to decode args\n"); + + rpc_stat = rpc_garbage_args; +err_bad: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, ntohl(rpc_stat)); + goto sendit; +} + +/* + * Process the RPC request. + */ +int +svc_process(struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + __be32 dir; + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) + if (!fail_sunrpc.ignore_server_disconnect && + should_fail(&fail_sunrpc.attr, 1)) + svc_xprt_deferred_close(rqstp->rq_xprt); +#endif + + /* + * Setup response xdr_buf. 
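+ *
+ * (For orientation: an xdr_buf is a head kvec, an array of whole pages and
+ * a tail kvec. The assignments below give the reply an empty head backed
+ * by the first response page and no page or tail data yet, roughly:
+ *
+ *     head[0].iov_base = page_address(rq_respages[0]), iov_len = 0
+ *     pages            = rq_next_page (= &rq_respages[1]), page_len = 0
+ *     tail[0]          = empty, buflen = PAGE_SIZE
+ *
+ * as set up just below.)
+ *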
+ * Initially it has just one page + */ + rqstp->rq_next_page = &rqstp->rq_respages[1]; + resv->iov_base = page_address(rqstp->rq_respages[0]); + resv->iov_len = 0; + rqstp->rq_res.pages = rqstp->rq_next_page; + rqstp->rq_res.len = 0; + rqstp->rq_res.page_base = 0; + rqstp->rq_res.page_len = 0; + rqstp->rq_res.buflen = PAGE_SIZE; + rqstp->rq_res.tail[0].iov_base = NULL; + rqstp->rq_res.tail[0].iov_len = 0; + + dir = svc_getu32(argv); + if (dir != rpc_call) + goto out_baddir; + if (!svc_process_common(rqstp, argv, resv)) + goto out_drop; + return svc_send(rqstp); + +out_baddir: + svc_printk(rqstp, "bad direction 0x%08x, dropping request\n", + be32_to_cpu(dir)); + rqstp->rq_server->sv_stats->rpcbadfmt++; +out_drop: + svc_drop(rqstp); + return 0; +} +EXPORT_SYMBOL_GPL(svc_process); + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +/* + * Process a backchannel RPC request that arrived over an existing + * outbound connection + */ +int +bc_svc_process(struct svc_serv *serv, struct rpc_rqst *req, + struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct rpc_task *task; + int proc_error; + int error; + + dprintk("svc: %s(%p)\n", __func__, req); + + /* Build the svc_rqst used by the common processing routine */ + rqstp->rq_xid = req->rq_xid; + rqstp->rq_prot = req->rq_xprt->prot; + rqstp->rq_server = serv; + rqstp->rq_bc_net = req->rq_xprt->xprt_net; + + rqstp->rq_addrlen = sizeof(req->rq_xprt->addr); + memcpy(&rqstp->rq_addr, &req->rq_xprt->addr, rqstp->rq_addrlen); + memcpy(&rqstp->rq_arg, &req->rq_rcv_buf, sizeof(rqstp->rq_arg)); + memcpy(&rqstp->rq_res, &req->rq_snd_buf, sizeof(rqstp->rq_res)); + + /* Adjust the argument buffer length */ + rqstp->rq_arg.len = req->rq_private_buf.len; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; + rqstp->rq_arg.page_len = 0; + } else if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len) + rqstp->rq_arg.page_len = rqstp->rq_arg.len - + rqstp->rq_arg.head[0].iov_len; + else + rqstp->rq_arg.len = rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len; + + /* reset result send buffer "put" position */ + resv->iov_len = 0; + + /* + * Skip the next two words because they've already been + * processed in the transport + */ + svc_getu32(argv); /* XID */ + svc_getnl(argv); /* CALLDIR */ + + /* Parse and execute the bc call */ + proc_error = svc_process_common(rqstp, argv, resv); + + atomic_dec(&req->rq_xprt->bc_slot_count); + if (!proc_error) { + /* Processing error: drop the request */ + xprt_free_bc_request(req); + error = -EINVAL; + goto out; + } + /* Finally, send the reply synchronously */ + memcpy(&req->rq_snd_buf, &rqstp->rq_res, sizeof(req->rq_snd_buf)); + task = rpc_run_bc_task(req); + if (IS_ERR(task)) { + error = PTR_ERR(task); + goto out; + } + + WARN_ON_ONCE(atomic_read(&task->tk_count) != 1); + error = task->tk_status; + rpc_put_task(task); + +out: + dprintk("svc: %s(), error=%d\n", __func__, error); + return error; +} +EXPORT_SYMBOL_GPL(bc_svc_process); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/** + * svc_max_payload - Return transport-specific limit on the RPC payload + * @rqstp: RPC transaction context + * + * Returns the maximum number of payload bytes the current transport + * allows. 
+ */ +u32 svc_max_payload(const struct svc_rqst *rqstp) +{ + u32 max = rqstp->rq_xprt->xpt_class->xcl_max_payload; + + if (rqstp->rq_server->sv_max_payload < max) + max = rqstp->rq_server->sv_max_payload; + return max; +} +EXPORT_SYMBOL_GPL(svc_max_payload); + +/** + * svc_proc_name - Return RPC procedure name in string form + * @rqstp: svc_rqst to operate on + * + * Return value: + * Pointer to a NUL-terminated string + */ +const char *svc_proc_name(const struct svc_rqst *rqstp) +{ + if (rqstp && rqstp->rq_procinfo) + return rqstp->rq_procinfo->pc_name; + return "unknown"; +} + + +/** + * svc_encode_result_payload - mark a range of bytes as a result payload + * @rqstp: svc_rqst to operate on + * @offset: payload's byte offset in rqstp->rq_res + * @length: size of payload, in bytes + * + * Returns zero on success, or a negative errno if a permanent + * error occurred. + */ +int svc_encode_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) +{ + return rqstp->rq_xprt->xpt_ops->xpo_result_payload(rqstp, offset, + length); +} +EXPORT_SYMBOL_GPL(svc_encode_result_payload); + +/** + * svc_fill_write_vector - Construct data argument for VFS write call + * @rqstp: svc_rqst to operate on + * @payload: xdr_buf containing only the write data payload + * + * Fills in rqstp::rq_vec, and returns the number of elements. + */ +unsigned int svc_fill_write_vector(struct svc_rqst *rqstp, + struct xdr_buf *payload) +{ + struct page **pages = payload->pages; + struct kvec *first = payload->head; + struct kvec *vec = rqstp->rq_vec; + size_t total = payload->len; + unsigned int i; + + /* Some types of transport can present the write payload + * entirely in rq_arg.pages. In this case, @first is empty. + */ + i = 0; + if (first->iov_len) { + vec[i].iov_base = first->iov_base; + vec[i].iov_len = min_t(size_t, total, first->iov_len); + total -= vec[i].iov_len; + ++i; + } + + while (total) { + vec[i].iov_base = page_address(*pages); + vec[i].iov_len = min_t(size_t, total, PAGE_SIZE); + total -= vec[i].iov_len; + ++i; + ++pages; + } + + WARN_ON_ONCE(i > ARRAY_SIZE(rqstp->rq_vec)); + return i; +} +EXPORT_SYMBOL_GPL(svc_fill_write_vector); + +/** + * svc_fill_symlink_pathname - Construct pathname argument for VFS symlink call + * @rqstp: svc_rqst to operate on + * @first: buffer containing first section of pathname + * @p: buffer containing remaining section of pathname + * @total: total length of the pathname argument + * + * The VFS symlink API demands a NUL-terminated pathname in mapped memory. + * Returns pointer to a NUL-terminated string, or an ERR_PTR. Caller must free + * the returned string. + */ +char *svc_fill_symlink_pathname(struct svc_rqst *rqstp, struct kvec *first, + void *p, size_t total) +{ + size_t len, remaining; + char *result, *dst; + + result = kmalloc(total + 1, GFP_KERNEL); + if (!result) + return ERR_PTR(-ESERVERFAULT); + + dst = result; + remaining = total; + + len = min_t(size_t, total, first->iov_len); + if (len) { + memcpy(dst, first->iov_base, len); + dst += len; + remaining -= len; + } + + if (remaining) { + len = min_t(size_t, remaining, PAGE_SIZE); + memcpy(dst, p, len); + dst += len; + } + + *dst = '\0'; + + /* Sanity check: Linux doesn't allow the pathname argument to + * contain a NUL byte. 
+ */ + if (strlen(result) != total) { + kfree(result); + return ERR_PTR(-EINVAL); + } + return result; +} +EXPORT_SYMBOL_GPL(svc_fill_symlink_pathname); diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c new file mode 100644 index 000000000..8117d0e08 --- /dev/null +++ b/net/sunrpc/svc_xprt.c @@ -0,0 +1,1484 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/svc_xprt.c + * + * Author: Tom Tucker + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static unsigned int svc_rpc_per_connection_limit __read_mostly; +module_param(svc_rpc_per_connection_limit, uint, 0644); + + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt); +static int svc_deferred_recv(struct svc_rqst *rqstp); +static struct cache_deferred_req *svc_defer(struct cache_req *req); +static void svc_age_temp_xprts(struct timer_list *t); +static void svc_delete_xprt(struct svc_xprt *xprt); + +/* apparently the "standard" is that clients close + * idle connections after 5 minutes, servers after + * 6 minutes + * http://nfsv4bat.org/Documents/ConnectAThon/1996/nfstcp.pdf + */ +static int svc_conn_age_period = 6*60; + +/* List of registered transport classes */ +static DEFINE_SPINLOCK(svc_xprt_class_lock); +static LIST_HEAD(svc_xprt_class_list); + +/* SMP locking strategy: + * + * svc_pool->sp_lock protects most of the fields of that pool. + * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt. + * when both need to be taken (rare), svc_serv->sv_lock is first. + * The "service mutex" protects svc_serv->sv_nrthread. + * svc_sock->sk_lock protects the svc_sock->sk_deferred list + * and the ->sk_info_authunix cache. + * + * The XPT_BUSY bit in xprt->xpt_flags prevents a transport being + * enqueued multiply. During normal transport processing this bit + * is set by svc_xprt_enqueue and cleared by svc_xprt_received. + * Providers should not manipulate this bit directly. + * + * Some flags can be set to certain values at any time + * providing that certain rules are followed: + * + * XPT_CONN, XPT_DATA: + * - Can be set or cleared at any time. + * - After a set, svc_xprt_enqueue must be called to enqueue + * the transport for processing. + * - After a clear, the transport must be read/accepted. + * If this succeeds, it must be set again. + * XPT_CLOSE: + * - Can set at any time. It is never cleared. + * XPT_DEAD: + * - Can only be set while XPT_BUSY is held which ensures + * that no other thread will be using the transport or will + * try to set XPT_DEAD. 
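+ *
+ * (Illustrative aside: the XPT_BUSY hand-off described above looks roughly
+ * like this, schematically:
+ *
+ *     if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags))
+ *             return;                 // someone else owns the transport
+ *     // ... read from / accept on the transport ...
+ *     svc_xprt_received(xprt);        // clears XPT_BUSY, may re-enqueue
+ *
+ * svc_xprt_enqueue() takes the bit, and the thread that ends up handling
+ * the transport releases it via svc_xprt_received().)
+ *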
+ */ +int svc_reg_xprt_class(struct svc_xprt_class *xcl) +{ + struct svc_xprt_class *cl; + int res = -EEXIST; + + dprintk("svc: Adding svc transport class '%s'\n", xcl->xcl_name); + + INIT_LIST_HEAD(&xcl->xcl_list); + spin_lock(&svc_xprt_class_lock); + /* Make sure there isn't already a class with the same name */ + list_for_each_entry(cl, &svc_xprt_class_list, xcl_list) { + if (strcmp(xcl->xcl_name, cl->xcl_name) == 0) + goto out; + } + list_add_tail(&xcl->xcl_list, &svc_xprt_class_list); + res = 0; +out: + spin_unlock(&svc_xprt_class_lock); + return res; +} +EXPORT_SYMBOL_GPL(svc_reg_xprt_class); + +void svc_unreg_xprt_class(struct svc_xprt_class *xcl) +{ + dprintk("svc: Removing svc transport class '%s'\n", xcl->xcl_name); + spin_lock(&svc_xprt_class_lock); + list_del_init(&xcl->xcl_list); + spin_unlock(&svc_xprt_class_lock); +} +EXPORT_SYMBOL_GPL(svc_unreg_xprt_class); + +/** + * svc_print_xprts - Format the transport list for printing + * @buf: target buffer for formatted address + * @maxlen: length of target buffer + * + * Fills in @buf with a string containing a list of transport names, each name + * terminated with '\n'. If the buffer is too small, some entries may be + * missing, but it is guaranteed that all lines in the output buffer are + * complete. + * + * Returns positive length of the filled-in string. + */ +int svc_print_xprts(char *buf, int maxlen) +{ + struct svc_xprt_class *xcl; + char tmpstr[80]; + int len = 0; + buf[0] = '\0'; + + spin_lock(&svc_xprt_class_lock); + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { + int slen; + + slen = snprintf(tmpstr, sizeof(tmpstr), "%s %d\n", + xcl->xcl_name, xcl->xcl_max_payload); + if (slen >= sizeof(tmpstr) || len + slen >= maxlen) + break; + len += slen; + strcat(buf, tmpstr); + } + spin_unlock(&svc_xprt_class_lock); + + return len; +} + +/** + * svc_xprt_deferred_close - Close a transport + * @xprt: transport instance + * + * Used in contexts that need to defer the work of shutting down + * the transport to an nfsd thread. + */ +void svc_xprt_deferred_close(struct svc_xprt *xprt) +{ + if (!test_and_set_bit(XPT_CLOSE, &xprt->xpt_flags)) + svc_xprt_enqueue(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_deferred_close); + +static void svc_xprt_free(struct kref *kref) +{ + struct svc_xprt *xprt = + container_of(kref, struct svc_xprt, xpt_ref); + struct module *owner = xprt->xpt_class->xcl_owner; + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) + svcauth_unix_info_release(xprt); + put_cred(xprt->xpt_cred); + put_net_track(xprt->xpt_net, &xprt->ns_tracker); + /* See comment on corresponding get in xs_setup_bc_tcp(): */ + if (xprt->xpt_bc_xprt) + xprt_put(xprt->xpt_bc_xprt); + if (xprt->xpt_bc_xps) + xprt_switch_put(xprt->xpt_bc_xps); + trace_svc_xprt_free(xprt); + xprt->xpt_ops->xpo_free(xprt); + module_put(owner); +} + +void svc_xprt_put(struct svc_xprt *xprt) +{ + kref_put(&xprt->xpt_ref, svc_xprt_free); +} +EXPORT_SYMBOL_GPL(svc_xprt_put); + +/* + * Called by transport drivers to initialize the transport independent + * portion of the transport instance. 
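+ *
+ * (Illustrative aside: a transport provider typically declares a
+ * struct svc_xprt_class and registers it from its module init; a
+ * bare-bones sketch with hypothetical names, xcl_ops omitted:
+ *
+ *     static struct svc_xprt_class hypothetical_xprt_class = {
+ *             .xcl_name        = "hypothetical",
+ *             .xcl_owner       = THIS_MODULE,
+ *             .xcl_max_payload = 1 << 12,     // shown by svc_print_xprts()
+ *     };
+ *
+ *     static int __init hypothetical_xprt_init(void)
+ *     {
+ *             return svc_reg_xprt_class(&hypothetical_xprt_class);
+ *     }
+ *
+ * svc_unreg_xprt_class() undoes the registration at module exit.)
+ *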
+ */ +void svc_xprt_init(struct net *net, struct svc_xprt_class *xcl, + struct svc_xprt *xprt, struct svc_serv *serv) +{ + memset(xprt, 0, sizeof(*xprt)); + xprt->xpt_class = xcl; + xprt->xpt_ops = xcl->xcl_ops; + kref_init(&xprt->xpt_ref); + xprt->xpt_server = serv; + INIT_LIST_HEAD(&xprt->xpt_list); + INIT_LIST_HEAD(&xprt->xpt_ready); + INIT_LIST_HEAD(&xprt->xpt_deferred); + INIT_LIST_HEAD(&xprt->xpt_users); + mutex_init(&xprt->xpt_mutex); + spin_lock_init(&xprt->xpt_lock); + set_bit(XPT_BUSY, &xprt->xpt_flags); + xprt->xpt_net = get_net_track(net, &xprt->ns_tracker, GFP_ATOMIC); + strcpy(xprt->xpt_remotebuf, "uninitialized"); +} +EXPORT_SYMBOL_GPL(svc_xprt_init); + +static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, + struct svc_serv *serv, + struct net *net, + const int family, + const unsigned short port, + int flags) +{ + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; +#if IS_ENABLED(CONFIG_IPV6) + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; +#endif + struct svc_xprt *xprt; + struct sockaddr *sap; + size_t len; + + switch (family) { + case PF_INET: + sap = (struct sockaddr *)&sin; + len = sizeof(sin); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + sap = (struct sockaddr *)&sin6; + len = sizeof(sin6); + break; +#endif + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + xprt = xcl->xcl_ops->xpo_create(serv, net, sap, len, flags); + if (IS_ERR(xprt)) + trace_svc_xprt_create_err(serv->sv_program->pg_name, + xcl->xcl_name, sap, len, xprt); + return xprt; +} + +/** + * svc_xprt_received - start next receiver thread + * @xprt: controlling transport + * + * The caller must hold the XPT_BUSY bit and must + * not thereafter touch transport data. + * + * Note: XPT_DATA only gets cleared when a read-attempt finds no (or + * insufficient) data. 
+ */ +void svc_xprt_received(struct svc_xprt *xprt) +{ + if (!test_bit(XPT_BUSY, &xprt->xpt_flags)) { + WARN_ONCE(1, "xprt=0x%p already busy!", xprt); + return; + } + + /* As soon as we clear busy, the xprt could be closed and + * 'put', so we need a reference to call svc_xprt_enqueue with: + */ + svc_xprt_get(xprt); + smp_mb__before_atomic(); + clear_bit(XPT_BUSY, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_received); + +void svc_add_new_perm_xprt(struct svc_serv *serv, struct svc_xprt *new) +{ + clear_bit(XPT_TEMP, &new->xpt_flags); + spin_lock_bh(&serv->sv_lock); + list_add(&new->xpt_list, &serv->sv_permsocks); + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(new); +} + +static int _svc_xprt_create(struct svc_serv *serv, const char *xprt_name, + struct net *net, const int family, + const unsigned short port, int flags, + const struct cred *cred) +{ + struct svc_xprt_class *xcl; + + spin_lock(&svc_xprt_class_lock); + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { + struct svc_xprt *newxprt; + unsigned short newport; + + if (strcmp(xprt_name, xcl->xcl_name)) + continue; + + if (!try_module_get(xcl->xcl_owner)) + goto err; + + spin_unlock(&svc_xprt_class_lock); + newxprt = __svc_xpo_create(xcl, serv, net, family, port, flags); + if (IS_ERR(newxprt)) { + module_put(xcl->xcl_owner); + return PTR_ERR(newxprt); + } + newxprt->xpt_cred = get_cred(cred); + svc_add_new_perm_xprt(serv, newxprt); + newport = svc_xprt_local_port(newxprt); + return newport; + } + err: + spin_unlock(&svc_xprt_class_lock); + /* This errno is exposed to user space. Provide a reasonable + * perror msg for a bad transport. */ + return -EPROTONOSUPPORT; +} + +/** + * svc_xprt_create - Add a new listener to @serv + * @serv: target RPC service + * @xprt_name: transport class name + * @net: network namespace + * @family: network address family + * @port: listener port + * @flags: SVC_SOCK flags + * @cred: credential to bind to this transport + * + * Return values: + * %0: New listener added successfully + * %-EPROTONOSUPPORT: Requested transport type not supported + */ +int svc_xprt_create(struct svc_serv *serv, const char *xprt_name, + struct net *net, const int family, + const unsigned short port, int flags, + const struct cred *cred) +{ + int err; + + err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred); + if (err == -EPROTONOSUPPORT) { + request_module("svc%s", xprt_name); + err = _svc_xprt_create(serv, xprt_name, net, family, port, flags, cred); + } + return err; +} +EXPORT_SYMBOL_GPL(svc_xprt_create); + +/* + * Copy the local and remote xprt addresses to the rqstp structure + */ +void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + memcpy(&rqstp->rq_addr, &xprt->xpt_remote, xprt->xpt_remotelen); + rqstp->rq_addrlen = xprt->xpt_remotelen; + + /* + * Destination address in request is needed for binding the + * source address in RPC replies/callbacks later. 
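+ *
+ * (Illustrative aside: a typical caller of svc_xprt_create() above opens
+ * one listener per transport it wants to serve; with hypothetical values:
+ *
+ *     err = svc_xprt_create(serv, "tcp", net, PF_INET, 2049,
+ *                           SVC_SOCK_DEFAULTS, cred);
+ *     if (err < 0)
+ *             return err;     // e.g. -EPROTONOSUPPORT: no such class
+ *
+ * The "tcp" string must match an xcl_name registered via
+ * svc_reg_xprt_class(); the request_module() fallback above lets the
+ * matching transport module be loaded on demand.)
+ *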
+ */ + memcpy(&rqstp->rq_daddr, &xprt->xpt_local, xprt->xpt_locallen); + rqstp->rq_daddrlen = xprt->xpt_locallen; +} +EXPORT_SYMBOL_GPL(svc_xprt_copy_addrs); + +/** + * svc_print_addr - Format rq_addr field for printing + * @rqstp: svc_rqst struct containing address to print + * @buf: target buffer for formatted address + * @len: length of target buffer + * + */ +char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len) +{ + return __svc_print_addr(svc_addr(rqstp), buf, len); +} +EXPORT_SYMBOL_GPL(svc_print_addr); + +static bool svc_xprt_slots_in_range(struct svc_xprt *xprt) +{ + unsigned int limit = svc_rpc_per_connection_limit; + int nrqsts = atomic_read(&xprt->xpt_nr_rqsts); + + return limit == 0 || (nrqsts >= 0 && nrqsts < limit); +} + +static bool svc_xprt_reserve_slot(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + if (!test_bit(RQ_DATA, &rqstp->rq_flags)) { + if (!svc_xprt_slots_in_range(xprt)) + return false; + atomic_inc(&xprt->xpt_nr_rqsts); + set_bit(RQ_DATA, &rqstp->rq_flags); + } + return true; +} + +static void svc_xprt_release_slot(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + if (test_and_clear_bit(RQ_DATA, &rqstp->rq_flags)) { + atomic_dec(&xprt->xpt_nr_rqsts); + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ + svc_xprt_enqueue(xprt); + } +} + +static bool svc_xprt_ready(struct svc_xprt *xprt) +{ + unsigned long xpt_flags; + + /* + * If another cpu has recently updated xpt_flags, + * sk_sock->flags, xpt_reserved, or xpt_nr_rqsts, we need to + * know about it; otherwise it's possible that both that cpu and + * this one could call svc_xprt_enqueue() without either + * svc_xprt_enqueue() recognizing that the conditions below + * are satisfied, and we could stall indefinitely: + */ + smp_rmb(); + xpt_flags = READ_ONCE(xprt->xpt_flags); + + if (xpt_flags & BIT(XPT_BUSY)) + return false; + if (xpt_flags & (BIT(XPT_CONN) | BIT(XPT_CLOSE))) + return true; + if (xpt_flags & (BIT(XPT_DATA) | BIT(XPT_DEFERRED))) { + if (xprt->xpt_ops->xpo_has_wspace(xprt) && + svc_xprt_slots_in_range(xprt)) + return true; + trace_svc_xprt_no_write_space(xprt); + return false; + } + return false; +} + +/** + * svc_xprt_enqueue - Queue a transport on an idle nfsd thread + * @xprt: transport with data pending + * + */ +void svc_xprt_enqueue(struct svc_xprt *xprt) +{ + struct svc_pool *pool; + struct svc_rqst *rqstp = NULL; + + if (!svc_xprt_ready(xprt)) + return; + + /* Mark transport as busy. It will remain in this state until + * the provider calls svc_xprt_received. We update XPT_BUSY + * atomically because it also guards against trying to enqueue + * the transport twice. + */ + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + return; + + pool = svc_pool_for_cpu(xprt->xpt_server); + + atomic_long_inc(&pool->sp_stats.packets); + + spin_lock_bh(&pool->sp_lock); + list_add_tail(&xprt->xpt_ready, &pool->sp_sockets); + pool->sp_stats.sockets_queued++; + spin_unlock_bh(&pool->sp_lock); + + /* find a thread for this xprt */ + rcu_read_lock(); + list_for_each_entry_rcu(rqstp, &pool->sp_all_threads, rq_all) { + if (test_and_set_bit(RQ_BUSY, &rqstp->rq_flags)) + continue; + atomic_long_inc(&pool->sp_stats.threads_woken); + rqstp->rq_qtime = ktime_get(); + wake_up_process(rqstp->rq_task); + goto out_unlock; + } + set_bit(SP_CONGESTED, &pool->sp_flags); + rqstp = NULL; +out_unlock: + rcu_read_unlock(); + trace_svc_xprt_enqueue(xprt, rqstp); +} +EXPORT_SYMBOL_GPL(svc_xprt_enqueue); + +/* + * Dequeue the first transport, if there is one. 
+ */ +static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) +{ + struct svc_xprt *xprt = NULL; + + if (list_empty(&pool->sp_sockets)) + goto out; + + spin_lock_bh(&pool->sp_lock); + if (likely(!list_empty(&pool->sp_sockets))) { + xprt = list_first_entry(&pool->sp_sockets, + struct svc_xprt, xpt_ready); + list_del_init(&xprt->xpt_ready); + svc_xprt_get(xprt); + } + spin_unlock_bh(&pool->sp_lock); +out: + return xprt; +} + +/** + * svc_reserve - change the space reserved for the reply to a request. + * @rqstp: The request in question + * @space: new max space to reserve + * + * Each request reserves some space on the output queue of the transport + * to make sure the reply fits. This function reduces that reserved + * space to be the amount of space used already, plus @space. + * + */ +void svc_reserve(struct svc_rqst *rqstp, int space) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + space += rqstp->rq_res.head[0].iov_len; + + if (xprt && space < rqstp->rq_reserved) { + atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved); + rqstp->rq_reserved = space; + smp_wmb(); /* See smp_rmb() in svc_xprt_ready() */ + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL_GPL(svc_reserve); + +static void free_deferred(struct svc_xprt *xprt, struct svc_deferred_req *dr) +{ + if (!dr) + return; + + xprt->xpt_ops->xpo_release_ctxt(xprt, dr->xprt_ctxt); + kfree(dr); +} + +static void svc_xprt_release(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + xprt->xpt_ops->xpo_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; + + free_deferred(xprt, rqstp->rq_deferred); + rqstp->rq_deferred = NULL; + + pagevec_release(&rqstp->rq_pvec); + svc_free_res_pages(rqstp); + rqstp->rq_res.page_len = 0; + rqstp->rq_res.page_base = 0; + + /* Reset response buffer and release + * the reservation. + * But first, check that enough space was reserved + * for the reply, otherwise we have a bug! + */ + if ((rqstp->rq_res.len) > rqstp->rq_reserved) + printk(KERN_ERR "RPC request reserved %d but used %d\n", + rqstp->rq_reserved, + rqstp->rq_res.len); + + rqstp->rq_res.head[0].iov_len = 0; + svc_reserve(rqstp, 0); + svc_xprt_release_slot(rqstp); + rqstp->rq_xprt = NULL; + svc_xprt_put(xprt); +} + +/* + * Some svc_serv's will have occasional work to do, even when a xprt is not + * waiting to be serviced. This function is there to "kick" a task in one of + * those services so that it can wake up and do that work. Note that we only + * bother with pool 0 as we don't need to wake up more than one thread for + * this purpose. + */ +void svc_wake_up(struct svc_serv *serv) +{ + struct svc_rqst *rqstp; + struct svc_pool *pool; + + pool = &serv->sv_pools[0]; + + rcu_read_lock(); + list_for_each_entry_rcu(rqstp, &pool->sp_all_threads, rq_all) { + /* skip any that aren't queued */ + if (test_bit(RQ_BUSY, &rqstp->rq_flags)) + continue; + rcu_read_unlock(); + wake_up_process(rqstp->rq_task); + trace_svc_wake_up(rqstp->rq_task->pid); + return; + } + rcu_read_unlock(); + + /* No free entries available */ + set_bit(SP_TASK_PENDING, &pool->sp_flags); + smp_wmb(); + trace_svc_wake_up(0); +} +EXPORT_SYMBOL_GPL(svc_wake_up); + +int svc_port_is_privileged(struct sockaddr *sin) +{ + switch (sin->sa_family) { + case AF_INET: + return ntohs(((struct sockaddr_in *)sin)->sin_port) + < PROT_SOCK; + case AF_INET6: + return ntohs(((struct sockaddr_in6 *)sin)->sin6_port) + < PROT_SOCK; + default: + return 0; + } +} + +/* + * Make sure that we don't have too many active connections. 
If we have, + * something must be dropped. It's not clear what will happen if we allow + * "too many" connections, but when dealing with network-facing software, + * we have to code defensively. Here we do that by imposing hard limits. + * + * There's no point in trying to do random drop here for DoS + * prevention. The NFS clients does 1 reconnect in 15 seconds. An + * attacker can easily beat that. + * + * The only somewhat efficient mechanism would be if drop old + * connections from the same IP first. But right now we don't even + * record the client IP in svc_sock. + * + * single-threaded services that expect a lot of clients will probably + * need to set sv_maxconn to override the default value which is based + * on the number of threads + */ +static void svc_check_conn_limits(struct svc_serv *serv) +{ + unsigned int limit = serv->sv_maxconn ? serv->sv_maxconn : + (serv->sv_nrthreads+3) * 20; + + if (serv->sv_tmpcnt > limit) { + struct svc_xprt *xprt = NULL; + spin_lock_bh(&serv->sv_lock); + if (!list_empty(&serv->sv_tempsocks)) { + /* Try to help the admin */ + net_notice_ratelimited("%s: too many open connections, consider increasing the %s\n", + serv->sv_name, serv->sv_maxconn ? + "max number of connections" : + "number of threads"); + /* + * Always select the oldest connection. It's not fair, + * but so is life + */ + xprt = list_entry(serv->sv_tempsocks.prev, + struct svc_xprt, + xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_get(xprt); + } + spin_unlock_bh(&serv->sv_lock); + + if (xprt) { + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + } + } +} + +static int svc_alloc_arg(struct svc_rqst *rqstp) +{ + struct svc_serv *serv = rqstp->rq_server; + struct xdr_buf *arg = &rqstp->rq_arg; + unsigned long pages, filled, ret; + + pagevec_init(&rqstp->rq_pvec); + + pages = (serv->sv_max_mesg + 2 * PAGE_SIZE) >> PAGE_SHIFT; + if (pages > RPCSVC_MAXPAGES) { + pr_warn_once("svc: warning: pages=%lu > RPCSVC_MAXPAGES=%lu\n", + pages, RPCSVC_MAXPAGES); + /* use as many pages as possible */ + pages = RPCSVC_MAXPAGES; + } + + for (filled = 0; filled < pages; filled = ret) { + ret = alloc_pages_bulk_array(GFP_KERNEL, pages, + rqstp->rq_pages); + if (ret > filled) + /* Made progress, don't sleep yet */ + continue; + + set_current_state(TASK_INTERRUPTIBLE); + if (signalled() || kthread_should_stop()) { + set_current_state(TASK_RUNNING); + return -EINTR; + } + trace_svc_alloc_arg_err(pages, ret); + memalloc_retry_wait(GFP_KERNEL); + } + rqstp->rq_page_end = &rqstp->rq_pages[pages]; + rqstp->rq_pages[pages] = NULL; /* this might be seen in nfsd_splice_actor() */ + + /* Make arg->head point to first page and arg->pages point to rest */ + arg->head[0].iov_base = page_address(rqstp->rq_pages[0]); + arg->head[0].iov_len = PAGE_SIZE; + arg->pages = rqstp->rq_pages + 1; + arg->page_base = 0; + /* save at least one page for response */ + arg->page_len = (pages-2)*PAGE_SIZE; + arg->len = (pages-1)*PAGE_SIZE; + arg->tail[0].iov_len = 0; + return 0; +} + +static bool +rqst_should_sleep(struct svc_rqst *rqstp) +{ + struct svc_pool *pool = rqstp->rq_pool; + + /* did someone call svc_wake_up? */ + if (test_and_clear_bit(SP_TASK_PENDING, &pool->sp_flags)) + return false; + + /* was a socket queued? */ + if (!list_empty(&pool->sp_sockets)) + return false; + + /* are we shutting down? */ + if (signalled() || kthread_should_stop()) + return false; + + /* are we freezing? 
*/ + if (freezing(current)) + return false; + + return true; +} + +static struct svc_xprt *svc_get_next_xprt(struct svc_rqst *rqstp, long timeout) +{ + struct svc_pool *pool = rqstp->rq_pool; + long time_left = 0; + + /* rq_xprt should be clear on entry */ + WARN_ON_ONCE(rqstp->rq_xprt); + + rqstp->rq_xprt = svc_xprt_dequeue(pool); + if (rqstp->rq_xprt) + goto out_found; + + /* + * We have to be able to interrupt this wait + * to bring down the daemons ... + */ + set_current_state(TASK_INTERRUPTIBLE); + smp_mb__before_atomic(); + clear_bit(SP_CONGESTED, &pool->sp_flags); + clear_bit(RQ_BUSY, &rqstp->rq_flags); + smp_mb__after_atomic(); + + if (likely(rqst_should_sleep(rqstp))) + time_left = schedule_timeout(timeout); + else + __set_current_state(TASK_RUNNING); + + try_to_freeze(); + + set_bit(RQ_BUSY, &rqstp->rq_flags); + smp_mb__after_atomic(); + rqstp->rq_xprt = svc_xprt_dequeue(pool); + if (rqstp->rq_xprt) + goto out_found; + + if (!time_left) + atomic_long_inc(&pool->sp_stats.threads_timedout); + + if (signalled() || kthread_should_stop()) + return ERR_PTR(-EINTR); + return ERR_PTR(-EAGAIN); +out_found: + /* Normally we will wait up to 5 seconds for any required + * cache information to be provided. + */ + if (!test_bit(SP_CONGESTED, &pool->sp_flags)) + rqstp->rq_chandle.thread_wait = 5*HZ; + else + rqstp->rq_chandle.thread_wait = 1*HZ; + trace_svc_xprt_dequeue(rqstp); + return rqstp->rq_xprt; +} + +static void svc_add_new_temp_xprt(struct svc_serv *serv, struct svc_xprt *newxpt) +{ + spin_lock_bh(&serv->sv_lock); + set_bit(XPT_TEMP, &newxpt->xpt_flags); + list_add(&newxpt->xpt_list, &serv->sv_tempsocks); + serv->sv_tmpcnt++; + if (serv->sv_temptimer.function == NULL) { + /* setup timer to age temp transports */ + serv->sv_temptimer.function = svc_age_temp_xprts; + mod_timer(&serv->sv_temptimer, + jiffies + svc_conn_age_period * HZ); + } + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(newxpt); +} + +static int svc_handle_xprt(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + struct svc_serv *serv = rqstp->rq_server; + int len = 0; + + if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) { + if (test_and_clear_bit(XPT_KILL_TEMP, &xprt->xpt_flags)) + xprt->xpt_ops->xpo_kill_temp_xprt(xprt); + svc_delete_xprt(xprt); + /* Leave XPT_BUSY set on the dead xprt: */ + goto out; + } + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + struct svc_xprt *newxpt; + /* + * We know this module_get will succeed because the + * listener holds a reference too + */ + __module_get(xprt->xpt_class->xcl_owner); + svc_check_conn_limits(xprt->xpt_server); + newxpt = xprt->xpt_ops->xpo_accept(xprt); + if (newxpt) { + newxpt->xpt_cred = get_cred(xprt->xpt_cred); + svc_add_new_temp_xprt(serv, newxpt); + trace_svc_xprt_accept(newxpt, serv->sv_name); + } else { + module_put(xprt->xpt_class->xcl_owner); + } + svc_xprt_received(xprt); + } else if (svc_xprt_reserve_slot(rqstp, xprt)) { + /* XPT_DATA|XPT_DEFERRED case: */ + dprintk("svc: server %p, pool %u, transport %p, inuse=%d\n", + rqstp, rqstp->rq_pool->sp_id, xprt, + kref_read(&xprt->xpt_ref)); + rqstp->rq_deferred = svc_deferred_dequeue(xprt); + if (rqstp->rq_deferred) + len = svc_deferred_recv(rqstp); + else + len = xprt->xpt_ops->xpo_recvfrom(rqstp); + rqstp->rq_stime = ktime_get(); + rqstp->rq_reserved = serv->sv_max_mesg; + atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); + } else + svc_xprt_received(xprt); + +out: + return len; +} + +/* + * Receive the next request on any transport. 
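+ *
+ * (Illustrative aside: the expected caller is a service's kthread main
+ * loop, schematically; real services add signal handling and statistics:
+ *
+ *     int err;
+ *
+ *     while (!kthread_should_stop()) {
+ *             err = svc_recv(rqstp, 60 * 60 * HZ);    // long, arbitrary timeout
+ *             if (err == -EINTR)
+ *                     break;                          // shutting down
+ *             if (err == -EAGAIN)
+ *                     continue;                       // nothing usable arrived
+ *             svc_process(rqstp);                     // dispatch and send reply
+ *     }
+ *
+ * svc_process() ends by calling svc_send(), so the loop itself never
+ * sends.)
+ *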
This code is carefully + * organised not to touch any cachelines in the shared svc_serv + * structure, only cachelines in the local svc_pool. + */ +int svc_recv(struct svc_rqst *rqstp, long timeout) +{ + struct svc_xprt *xprt = NULL; + struct svc_serv *serv = rqstp->rq_server; + int len, err; + + err = svc_alloc_arg(rqstp); + if (err) + goto out; + + try_to_freeze(); + cond_resched(); + err = -EINTR; + if (signalled() || kthread_should_stop()) + goto out; + + xprt = svc_get_next_xprt(rqstp, timeout); + if (IS_ERR(xprt)) { + err = PTR_ERR(xprt); + goto out; + } + + len = svc_handle_xprt(rqstp, xprt); + + /* No data, incomplete (TCP) read, or accept() */ + err = -EAGAIN; + if (len <= 0) + goto out_release; + trace_svc_xdr_recvfrom(&rqstp->rq_arg); + + clear_bit(XPT_OLD, &xprt->xpt_flags); + + xprt->xpt_ops->xpo_secure_port(rqstp); + rqstp->rq_chandle.defer = svc_defer; + rqstp->rq_xid = svc_getu32(&rqstp->rq_arg.head[0]); + + if (serv->sv_stats) + serv->sv_stats->netcnt++; + return len; +out_release: + rqstp->rq_res.len = 0; + svc_xprt_release(rqstp); +out: + return err; +} +EXPORT_SYMBOL_GPL(svc_recv); + +/* + * Drop request + */ +void svc_drop(struct svc_rqst *rqstp) +{ + trace_svc_drop(rqstp); + svc_xprt_release(rqstp); +} +EXPORT_SYMBOL_GPL(svc_drop); + +/* + * Return reply to client. + */ +int svc_send(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt; + int len = -EFAULT; + struct xdr_buf *xb; + + xprt = rqstp->rq_xprt; + if (!xprt) + goto out; + + /* calculate over-all length */ + xb = &rqstp->rq_res; + xb->len = xb->head[0].iov_len + + xb->page_len + + xb->tail[0].iov_len; + trace_svc_xdr_sendto(rqstp->rq_xid, xb); + trace_svc_stats_latency(rqstp); + + len = xprt->xpt_ops->xpo_sendto(rqstp); + + trace_svc_send(rqstp, len); + svc_xprt_release(rqstp); + + if (len == -ECONNREFUSED || len == -ENOTCONN || len == -EAGAIN) + len = 0; +out: + return len; +} + +/* + * Timer function to close old temporary transports, using + * a mark-and-sweep algorithm. + */ +static void svc_age_temp_xprts(struct timer_list *t) +{ + struct svc_serv *serv = from_timer(serv, t, sv_temptimer); + struct svc_xprt *xprt; + struct list_head *le, *next; + + dprintk("svc_age_temp_xprts\n"); + + if (!spin_trylock_bh(&serv->sv_lock)) { + /* busy, try again 1 sec later */ + dprintk("svc_age_temp_xprts: busy\n"); + mod_timer(&serv->sv_temptimer, jiffies + HZ); + return; + } + + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + + /* First time through, just mark it OLD. Second time + * through, close it. */ + if (!test_and_set_bit(XPT_OLD, &xprt->xpt_flags)) + continue; + if (kref_read(&xprt->xpt_ref) > 1 || + test_bit(XPT_BUSY, &xprt->xpt_flags)) + continue; + list_del_init(le); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + dprintk("queuing xprt %p for closing\n", xprt); + + /* a thread will dequeue and close it soon */ + svc_xprt_enqueue(xprt); + } + spin_unlock_bh(&serv->sv_lock); + + mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); +} + +/* Close temporary transports whose xpt_local matches server_addr immediately + * instead of waiting for them to be picked up by the timer. + * + * This is meant to be called from a notifier_block that runs when an ip + * address is deleted. 
+ */ +void svc_age_temp_xprts_now(struct svc_serv *serv, struct sockaddr *server_addr) +{ + struct svc_xprt *xprt; + struct list_head *le, *next; + LIST_HEAD(to_be_closed); + + spin_lock_bh(&serv->sv_lock); + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + if (rpc_cmp_addr(server_addr, (struct sockaddr *) + &xprt->xpt_local)) { + dprintk("svc_age_temp_xprts_now: found %p\n", xprt); + list_move(le, &to_be_closed); + } + } + spin_unlock_bh(&serv->sv_lock); + + while (!list_empty(&to_be_closed)) { + le = to_be_closed.next; + list_del_init(le); + xprt = list_entry(le, struct svc_xprt, xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + set_bit(XPT_KILL_TEMP, &xprt->xpt_flags); + dprintk("svc_age_temp_xprts_now: queuing xprt %p for closing\n", + xprt); + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL_GPL(svc_age_temp_xprts_now); + +static void call_xpt_users(struct svc_xprt *xprt) +{ + struct svc_xpt_user *u; + + spin_lock(&xprt->xpt_lock); + while (!list_empty(&xprt->xpt_users)) { + u = list_first_entry(&xprt->xpt_users, struct svc_xpt_user, list); + list_del_init(&u->list); + u->callback(u); + } + spin_unlock(&xprt->xpt_lock); +} + +/* + * Remove a dead transport + */ +static void svc_delete_xprt(struct svc_xprt *xprt) +{ + struct svc_serv *serv = xprt->xpt_server; + struct svc_deferred_req *dr; + + if (test_and_set_bit(XPT_DEAD, &xprt->xpt_flags)) + return; + + trace_svc_xprt_detach(xprt); + xprt->xpt_ops->xpo_detach(xprt); + if (xprt->xpt_bc_xprt) + xprt->xpt_bc_xprt->ops->close(xprt->xpt_bc_xprt); + + spin_lock_bh(&serv->sv_lock); + list_del_init(&xprt->xpt_list); + WARN_ON_ONCE(!list_empty(&xprt->xpt_ready)); + if (test_bit(XPT_TEMP, &xprt->xpt_flags)) + serv->sv_tmpcnt--; + spin_unlock_bh(&serv->sv_lock); + + while ((dr = svc_deferred_dequeue(xprt)) != NULL) + free_deferred(xprt, dr); + + call_xpt_users(xprt); + svc_xprt_put(xprt); +} + +/** + * svc_xprt_close - Close a client connection + * @xprt: transport to disconnect + * + */ +void svc_xprt_close(struct svc_xprt *xprt) +{ + trace_svc_xprt_close(xprt); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + /* someone else will have to effect the close */ + return; + /* + * We expect svc_close_xprt() to work even when no threads are + * running (e.g., while configuring the server before starting + * any threads), so if the transport isn't busy, we delete + * it ourself: + */ + svc_delete_xprt(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_close); + +static int svc_close_list(struct svc_serv *serv, struct list_head *xprt_list, struct net *net) +{ + struct svc_xprt *xprt; + int ret = 0; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, xprt_list, xpt_list) { + if (xprt->xpt_net != net) + continue; + ret++; + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + } + spin_unlock_bh(&serv->sv_lock); + return ret; +} + +static struct svc_xprt *svc_dequeue_net(struct svc_serv *serv, struct net *net) +{ + struct svc_pool *pool; + struct svc_xprt *xprt; + struct svc_xprt *tmp; + int i; + + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[i]; + + spin_lock_bh(&pool->sp_lock); + list_for_each_entry_safe(xprt, tmp, &pool->sp_sockets, xpt_ready) { + if (xprt->xpt_net != net) + continue; + list_del_init(&xprt->xpt_ready); + spin_unlock_bh(&pool->sp_lock); + return xprt; + } + spin_unlock_bh(&pool->sp_lock); + } + return NULL; +} + +static void svc_clean_up_xprts(struct svc_serv *serv, struct net *net) +{ + struct svc_xprt 
*xprt; + + while ((xprt = svc_dequeue_net(serv, net))) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_delete_xprt(xprt); + } +} + +/** + * svc_xprt_destroy_all - Destroy transports associated with @serv + * @serv: RPC service to be shut down + * @net: target network namespace + * + * Server threads may still be running (especially in the case where the + * service is still running in other network namespaces). + * + * So we shut down sockets the same way we would on a running server, by + * setting XPT_CLOSE, enqueuing, and letting a thread pick it up to do + * the close. In the case there are no such other threads, + * threads running, svc_clean_up_xprts() does a simple version of a + * server's main event loop, and in the case where there are other + * threads, we may need to wait a little while and then check again to + * see if they're done. + */ +void svc_xprt_destroy_all(struct svc_serv *serv, struct net *net) +{ + int delay = 0; + + while (svc_close_list(serv, &serv->sv_permsocks, net) + + svc_close_list(serv, &serv->sv_tempsocks, net)) { + + svc_clean_up_xprts(serv, net); + msleep(delay++); + } +} +EXPORT_SYMBOL_GPL(svc_xprt_destroy_all); + +/* + * Handle defer and revisit of requests + */ + +static void svc_revisit(struct cache_deferred_req *dreq, int too_many) +{ + struct svc_deferred_req *dr = + container_of(dreq, struct svc_deferred_req, handle); + struct svc_xprt *xprt = dr->xprt; + + spin_lock(&xprt->xpt_lock); + set_bit(XPT_DEFERRED, &xprt->xpt_flags); + if (too_many || test_bit(XPT_DEAD, &xprt->xpt_flags)) { + spin_unlock(&xprt->xpt_lock); + trace_svc_defer_drop(dr); + free_deferred(xprt, dr); + svc_xprt_put(xprt); + return; + } + dr->xprt = NULL; + list_add(&dr->handle.recent, &xprt->xpt_deferred); + spin_unlock(&xprt->xpt_lock); + trace_svc_defer_queue(dr); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); +} + +/* + * Save the request off for later processing. The request buffer looks + * like this: + * + * + * + * This code can only handle requests that consist of an xprt-header + * and rpc-header. 
+ */ +static struct cache_deferred_req *svc_defer(struct cache_req *req) +{ + struct svc_rqst *rqstp = container_of(req, struct svc_rqst, rq_chandle); + struct svc_deferred_req *dr; + + if (rqstp->rq_arg.page_len || !test_bit(RQ_USEDEFERRAL, &rqstp->rq_flags)) + return NULL; /* if more than a page, give up FIXME */ + if (rqstp->rq_deferred) { + dr = rqstp->rq_deferred; + rqstp->rq_deferred = NULL; + } else { + size_t skip; + size_t size; + /* FIXME maybe discard if size too large */ + size = sizeof(struct svc_deferred_req) + rqstp->rq_arg.len; + dr = kmalloc(size, GFP_KERNEL); + if (dr == NULL) + return NULL; + + dr->handle.owner = rqstp->rq_server; + dr->prot = rqstp->rq_prot; + memcpy(&dr->addr, &rqstp->rq_addr, rqstp->rq_addrlen); + dr->addrlen = rqstp->rq_addrlen; + dr->daddr = rqstp->rq_daddr; + dr->argslen = rqstp->rq_arg.len >> 2; + + /* back up head to the start of the buffer and copy */ + skip = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; + memcpy(dr->args, rqstp->rq_arg.head[0].iov_base - skip, + dr->argslen << 2); + } + dr->xprt_ctxt = rqstp->rq_xprt_ctxt; + rqstp->rq_xprt_ctxt = NULL; + trace_svc_defer(rqstp); + svc_xprt_get(rqstp->rq_xprt); + dr->xprt = rqstp->rq_xprt; + set_bit(RQ_DROPME, &rqstp->rq_flags); + + dr->handle.revisit = svc_revisit; + return &dr->handle; +} + +/* + * recv data from a deferred request into an active one + */ +static noinline int svc_deferred_recv(struct svc_rqst *rqstp) +{ + struct svc_deferred_req *dr = rqstp->rq_deferred; + + trace_svc_defer_recv(dr); + + /* setup iov_base past transport header */ + rqstp->rq_arg.head[0].iov_base = dr->args; + /* The iov_len does not include the transport header bytes */ + rqstp->rq_arg.head[0].iov_len = dr->argslen << 2; + rqstp->rq_arg.page_len = 0; + /* The rq_arg.len includes the transport header bytes */ + rqstp->rq_arg.len = dr->argslen << 2; + rqstp->rq_prot = dr->prot; + memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen); + rqstp->rq_addrlen = dr->addrlen; + /* Save off transport header len in case we get deferred again */ + rqstp->rq_daddr = dr->daddr; + rqstp->rq_respages = rqstp->rq_pages; + rqstp->rq_xprt_ctxt = dr->xprt_ctxt; + + dr->xprt_ctxt = NULL; + svc_xprt_received(rqstp->rq_xprt); + return dr->argslen << 2; +} + + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) +{ + struct svc_deferred_req *dr = NULL; + + if (!test_bit(XPT_DEFERRED, &xprt->xpt_flags)) + return NULL; + spin_lock(&xprt->xpt_lock); + if (!list_empty(&xprt->xpt_deferred)) { + dr = list_entry(xprt->xpt_deferred.next, + struct svc_deferred_req, + handle.recent); + list_del_init(&dr->handle.recent); + } else + clear_bit(XPT_DEFERRED, &xprt->xpt_flags); + spin_unlock(&xprt->xpt_lock); + return dr; +} + +/** + * svc_find_xprt - find an RPC transport instance + * @serv: pointer to svc_serv to search + * @xcl_name: C string containing transport's class name + * @net: owner net pointer + * @af: Address family of transport's local address + * @port: transport's IP port number + * + * Return the transport instance pointer for the endpoint accepting + * connections/peer traffic from the specified transport class, + * address family and port. + * + * Specifying 0 for the address family or port is effectively a + * wild-card, and will result in matching the first transport in the + * service's list that has a matching class name. 
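+ *
+ * (Illustrative aside, hypothetical usage: to check whether the service
+ * already has a TCP listener in a namespace, on any port and any address
+ * family:
+ *
+ *     xprt = svc_find_xprt(serv, "tcp", net, AF_UNSPEC, 0);
+ *     if (xprt) {
+ *             // ... inspect xprt ...
+ *             svc_xprt_put(xprt);     // svc_find_xprt() took a reference
+ *     }
+ *
+ * Note the reference: the found transport is pinned with svc_xprt_get()
+ * before the list lock is dropped.)
+ *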
+ */ +struct svc_xprt *svc_find_xprt(struct svc_serv *serv, const char *xcl_name, + struct net *net, const sa_family_t af, + const unsigned short port) +{ + struct svc_xprt *xprt; + struct svc_xprt *found = NULL; + + /* Sanity check the args */ + if (serv == NULL || xcl_name == NULL) + return found; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (xprt->xpt_net != net) + continue; + if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) + continue; + if (af != AF_UNSPEC && af != xprt->xpt_local.ss_family) + continue; + if (port != 0 && port != svc_xprt_local_port(xprt)) + continue; + found = xprt; + svc_xprt_get(xprt); + break; + } + spin_unlock_bh(&serv->sv_lock); + return found; +} +EXPORT_SYMBOL_GPL(svc_find_xprt); + +static int svc_one_xprt_name(const struct svc_xprt *xprt, + char *pos, int remaining) +{ + int len; + + len = snprintf(pos, remaining, "%s %u\n", + xprt->xpt_class->xcl_name, + svc_xprt_local_port(xprt)); + if (len >= remaining) + return -ENAMETOOLONG; + return len; +} + +/** + * svc_xprt_names - format a buffer with a list of transport names + * @serv: pointer to an RPC service + * @buf: pointer to a buffer to be filled in + * @buflen: length of buffer to be filled in + * + * Fills in @buf with a string containing a list of transport names, + * each name terminated with '\n'. + * + * Returns positive length of the filled-in string on success; otherwise + * a negative errno value is returned if an error occurs. + */ +int svc_xprt_names(struct svc_serv *serv, char *buf, const int buflen) +{ + struct svc_xprt *xprt; + int len, totlen; + char *pos; + + /* Sanity check args */ + if (!serv) + return 0; + + spin_lock_bh(&serv->sv_lock); + + pos = buf; + totlen = 0; + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + len = svc_one_xprt_name(xprt, pos, buflen - totlen); + if (len < 0) { + *buf = '\0'; + totlen = len; + } + if (len <= 0) + break; + + pos += len; + totlen += len; + } + + spin_unlock_bh(&serv->sv_lock); + return totlen; +} +EXPORT_SYMBOL_GPL(svc_xprt_names); + + +/*----------------------------------------------------------------------------*/ + +static void *svc_pool_stats_start(struct seq_file *m, loff_t *pos) +{ + unsigned int pidx = (unsigned int)*pos; + struct svc_serv *serv = m->private; + + dprintk("svc_pool_stats_start, *pidx=%u\n", pidx); + + if (!pidx) + return SEQ_START_TOKEN; + return (pidx > serv->sv_nrpools ? 
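svc_xprt_names() accumulates one "class port\n" record per permanent socket and reports -ENAMETOOLONG (clearing the buffer) when a record will not fit. A userspace sketch of the same accumulate-or-truncate pattern, assuming nothing beyond snprintf semantics:

#include <stdio.h>
#include <string.h>
#include <errno.h>

/* Append "name port\n" records into buf; on overflow clear the buffer
 * and return -ENAMETOOLONG, mirroring the pattern above. */
static int format_names(char *buf, int buflen,
			const char *names[], const int ports[], int n)
{
	int i, len, totlen = 0;

	for (i = 0; i < n; i++) {
		len = snprintf(buf + totlen, buflen - totlen, "%s %u\n",
			       names[i], ports[i]);
		if (len >= buflen - totlen) {
			*buf = '\0';
			return -ENAMETOOLONG;
		}
		totlen += len;
	}
	return totlen;
}

int main(void)
{
	const char *names[] = { "tcp", "udp" };
	const int ports[] = { 2049, 2049 };
	char buf[64];
	int ret = format_names(buf, sizeof(buf), names, ports, 2);

	if (ret > 0)
		printf("%s", buf);
	return 0;
}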
NULL : &serv->sv_pools[pidx-1]); +} + +static void *svc_pool_stats_next(struct seq_file *m, void *p, loff_t *pos) +{ + struct svc_pool *pool = p; + struct svc_serv *serv = m->private; + + dprintk("svc_pool_stats_next, *pos=%llu\n", *pos); + + if (p == SEQ_START_TOKEN) { + pool = &serv->sv_pools[0]; + } else { + unsigned int pidx = (pool - &serv->sv_pools[0]); + if (pidx < serv->sv_nrpools-1) + pool = &serv->sv_pools[pidx+1]; + else + pool = NULL; + } + ++*pos; + return pool; +} + +static void svc_pool_stats_stop(struct seq_file *m, void *p) +{ +} + +static int svc_pool_stats_show(struct seq_file *m, void *p) +{ + struct svc_pool *pool = p; + + if (p == SEQ_START_TOKEN) { + seq_puts(m, "# pool packets-arrived sockets-enqueued threads-woken threads-timedout\n"); + return 0; + } + + seq_printf(m, "%u %lu %lu %lu %lu\n", + pool->sp_id, + (unsigned long)atomic_long_read(&pool->sp_stats.packets), + pool->sp_stats.sockets_queued, + (unsigned long)atomic_long_read(&pool->sp_stats.threads_woken), + (unsigned long)atomic_long_read(&pool->sp_stats.threads_timedout)); + + return 0; +} + +static const struct seq_operations svc_pool_stats_seq_ops = { + .start = svc_pool_stats_start, + .next = svc_pool_stats_next, + .stop = svc_pool_stats_stop, + .show = svc_pool_stats_show, +}; + +int svc_pool_stats_open(struct svc_serv *serv, struct file *file) +{ + int err; + + err = seq_open(file, &svc_pool_stats_seq_ops); + if (!err) + ((struct seq_file *) file->private_data)->private = serv; + return err; +} +EXPORT_SYMBOL(svc_pool_stats_open); + +/*----------------------------------------------------------------------------*/ diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c new file mode 100644 index 000000000..e72ba2f13 --- /dev/null +++ b/net/sunrpc/svcauth.c @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/svcauth.c + * + * The generic interface for RPC authentication on the server side. 
+ * + * Copyright (C) 1995, 1996 Olaf Kirch + * + * CHANGES + * 19-Apr-2000 Chris Evans - Security fix + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "sunrpc.h" + +#define RPCDBG_FACILITY RPCDBG_AUTH + + +/* + * Table of authenticators + */ +extern struct auth_ops svcauth_null; +extern struct auth_ops svcauth_unix; +extern struct auth_ops svcauth_tls; + +static struct auth_ops __rcu *authtab[RPC_AUTH_MAXFLAVOR] = { + [RPC_AUTH_NULL] = (struct auth_ops __force __rcu *)&svcauth_null, + [RPC_AUTH_UNIX] = (struct auth_ops __force __rcu *)&svcauth_unix, + [RPC_AUTH_TLS] = (struct auth_ops __force __rcu *)&svcauth_tls, +}; + +static struct auth_ops * +svc_get_auth_ops(rpc_authflavor_t flavor) +{ + struct auth_ops *aops; + + if (flavor >= RPC_AUTH_MAXFLAVOR) + return NULL; + rcu_read_lock(); + aops = rcu_dereference(authtab[flavor]); + if (aops != NULL && !try_module_get(aops->owner)) + aops = NULL; + rcu_read_unlock(); + return aops; +} + +static void +svc_put_auth_ops(struct auth_ops *aops) +{ + module_put(aops->owner); +} + +int +svc_authenticate(struct svc_rqst *rqstp) +{ + rpc_authflavor_t flavor; + struct auth_ops *aops; + + rqstp->rq_auth_stat = rpc_auth_ok; + + flavor = svc_getnl(&rqstp->rq_arg.head[0]); + + dprintk("svc: svc_authenticate (%d)\n", flavor); + + aops = svc_get_auth_ops(flavor); + if (aops == NULL) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } + + rqstp->rq_auth_slack = 0; + init_svc_cred(&rqstp->rq_cred); + + rqstp->rq_authop = aops; + return aops->accept(rqstp); +} +EXPORT_SYMBOL_GPL(svc_authenticate); + +int svc_set_client(struct svc_rqst *rqstp) +{ + rqstp->rq_client = NULL; + return rqstp->rq_authop->set_client(rqstp); +} +EXPORT_SYMBOL_GPL(svc_set_client); + +/* A request, which was authenticated, has now executed. + * Time to finalise the credentials and verifier + * and release and resources + */ +int svc_authorise(struct svc_rqst *rqstp) +{ + struct auth_ops *aops = rqstp->rq_authop; + int rv = 0; + + rqstp->rq_authop = NULL; + + if (aops) { + rv = aops->release(rqstp); + svc_put_auth_ops(aops); + } + return rv; +} + +int +svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops) +{ + struct auth_ops *old; + int rv = -EINVAL; + + if (flavor < RPC_AUTH_MAXFLAVOR) { + old = cmpxchg((struct auth_ops ** __force)&authtab[flavor], NULL, aops); + if (old == NULL || old == aops) + rv = 0; + } + return rv; +} +EXPORT_SYMBOL_GPL(svc_auth_register); + +void +svc_auth_unregister(rpc_authflavor_t flavor) +{ + if (flavor < RPC_AUTH_MAXFLAVOR) + rcu_assign_pointer(authtab[flavor], NULL); +} +EXPORT_SYMBOL_GPL(svc_auth_unregister); + +/************************************************** + * 'auth_domains' are stored in a hash table indexed by name. + * When the last reference to an 'auth_domain' is dropped, + * the object is unhashed and freed. + * If auth_domain_lookup fails to find an entry, it will return + * it's second argument 'new'. If this is non-null, it will + * have been atomically linked into the table. 
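The authenticator table above is a flat array indexed by RPC flavor number and bounded by RPC_AUTH_MAXFLAVOR; registration succeeds only when the slot is free or already holds the same ops. A simplified, single-threaded sketch of that contract (no RCU, no module reference counting) with stand-in names:

#include <stdio.h>
#include <stddef.h>
#include <errno.h>

#define MAXFLAVOR 8		/* stand-in for RPC_AUTH_MAXFLAVOR */

struct auth_ops_stub {		/* stand-in for struct auth_ops */
	const char *name;
};

static struct auth_ops_stub *authtab_stub[MAXFLAVOR];

static int auth_register(unsigned int flavor, struct auth_ops_stub *aops)
{
	if (flavor >= MAXFLAVOR)
		return -EINVAL;
	if (authtab_stub[flavor] && authtab_stub[flavor] != aops)
		return -EINVAL;	/* slot already taken by someone else */
	authtab_stub[flavor] = aops;
	return 0;
}

static struct auth_ops_stub *auth_lookup(unsigned int flavor)
{
	return flavor < MAXFLAVOR ? authtab_stub[flavor] : NULL;
}

int main(void)
{
	static struct auth_ops_stub null_ops = { "null" };

	auth_register(0, &null_ops);		/* AUTH_NULL is flavor 0 */
	printf("flavor 0 -> %s\n", auth_lookup(0) ? auth_lookup(0)->name : "none");
	printf("flavor 99 -> %s\n", auth_lookup(99) ? "set" : "rejected");
	return 0;
}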
+ */ + +#define DN_HASHBITS 6 +#define DN_HASHMAX (1<hash); + dom->flavour->domain_release(dom); + spin_unlock(&auth_domain_lock); +} + +void auth_domain_put(struct auth_domain *dom) +{ + kref_put_lock(&dom->ref, auth_domain_release, &auth_domain_lock); +} +EXPORT_SYMBOL_GPL(auth_domain_put); + +struct auth_domain * +auth_domain_lookup(char *name, struct auth_domain *new) +{ + struct auth_domain *hp; + struct hlist_head *head; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + spin_lock(&auth_domain_lock); + + hlist_for_each_entry(hp, head, hash) { + if (strcmp(hp->name, name)==0) { + kref_get(&hp->ref); + spin_unlock(&auth_domain_lock); + return hp; + } + } + if (new) + hlist_add_head_rcu(&new->hash, head); + spin_unlock(&auth_domain_lock); + return new; +} +EXPORT_SYMBOL_GPL(auth_domain_lookup); + +struct auth_domain *auth_domain_find(char *name) +{ + struct auth_domain *hp; + struct hlist_head *head; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + rcu_read_lock(); + hlist_for_each_entry_rcu(hp, head, hash) { + if (strcmp(hp->name, name)==0) { + if (!kref_get_unless_zero(&hp->ref)) + hp = NULL; + rcu_read_unlock(); + return hp; + } + } + rcu_read_unlock(); + return NULL; +} +EXPORT_SYMBOL_GPL(auth_domain_find); + +/** + * auth_domain_cleanup - check that the auth_domain table is empty + * + * On module unload the auth_domain_table must be empty. To make it + * easier to catch bugs which don't clean up domains properly, we + * warn if anything remains in the table at cleanup time. + * + * Note that we cannot proactively remove the domains at this stage. + * The ->release() function might be in a module that has already been + * unloaded. + */ + +void auth_domain_cleanup(void) +{ + int h; + struct auth_domain *hp; + + for (h = 0; h < DN_HASHMAX; h++) + hlist_for_each_entry(hp, &auth_domain_table[h], hash) + pr_warn("svc: domain %s still present at module unload.\n", + hp->name); +} diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c new file mode 100644 index 000000000..609ade4fb --- /dev/null +++ b/net/sunrpc/svcauth_unix.c @@ -0,0 +1,987 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define RPCDBG_FACILITY RPCDBG_AUTH + + +#include "netns.h" + +/* + * AUTHUNIX and AUTHNULL credentials are both handled here. + * AUTHNULL is treated just like AUTHUNIX except that the uid/gid + * are always nobody (-2). i.e. we do the same IP address checks for + * AUTHNULL as for AUTHUNIX, and that is done here. 
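auth_domain_lookup() implements a lookup-or-insert contract: return an existing entry with the same name, otherwise link the caller-supplied one into its hash bucket. A minimal single-threaded sketch of the same pattern, ignoring the kref and RCU details:

#include <stdio.h>
#include <string.h>

#define NBUCKETS 64		/* 1 << DN_HASHBITS in the code above */

struct domain {			/* stand-in for struct auth_domain */
	const char *name;
	struct domain *next;
};

static struct domain *table[NBUCKETS];

static unsigned int hash_name(const char *s)
{
	unsigned int h = 0;

	while (*s)
		h = h * 31 + (unsigned char)*s++;
	return h % NBUCKETS;
}

/* Return the existing entry with this name, or link @new_dom and return it. */
static struct domain *lookup_or_insert(struct domain *new_dom)
{
	unsigned int b = hash_name(new_dom->name);
	struct domain *d;

	for (d = table[b]; d; d = d->next)
		if (strcmp(d->name, new_dom->name) == 0)
			return d;
	new_dom->next = table[b];
	table[b] = new_dom;
	return new_dom;
}

int main(void)
{
	struct domain a = { "client-a", NULL }, dup = { "client-a", NULL };

	printf("first insert:  %p\n", (void *)lookup_or_insert(&a));
	printf("second lookup: %p (same entry)\n", (void *)lookup_or_insert(&dup));
	return 0;
}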
+ */ + + +struct unix_domain { + struct auth_domain h; + /* other stuff later */ +}; + +extern struct auth_ops svcauth_null; +extern struct auth_ops svcauth_unix; +extern struct auth_ops svcauth_tls; + +static void svcauth_unix_domain_release_rcu(struct rcu_head *head) +{ + struct auth_domain *dom = container_of(head, struct auth_domain, rcu_head); + struct unix_domain *ud = container_of(dom, struct unix_domain, h); + + kfree(dom->name); + kfree(ud); +} + +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + call_rcu(&dom->rcu_head, svcauth_unix_domain_release_rcu); +} + +struct auth_domain *unix_domain_find(char *name) +{ + struct auth_domain *rv; + struct unix_domain *new = NULL; + + rv = auth_domain_find(name); + while(1) { + if (rv) { + if (new && rv != &new->h) + svcauth_unix_domain_release(&new->h); + + if (rv->flavour != &svcauth_unix) { + auth_domain_put(rv); + return NULL; + } + return rv; + } + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (new == NULL) + return NULL; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (new->h.name == NULL) { + kfree(new); + return NULL; + } + new->h.flavour = &svcauth_unix; + rv = auth_domain_lookup(name, &new->h); + } +} +EXPORT_SYMBOL_GPL(unix_domain_find); + + +/************************************************** + * cache for IP address to unix_domain + * as needed by AUTH_UNIX + */ +#define IP_HASHBITS 8 +#define IP_HASHMAX (1<flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + auth_domain_put(&im->m_client->h); + kfree_rcu(im, m_rcu); +} + +static inline int hash_ip6(const struct in6_addr *ip) +{ + return hash_32(ipv6_addr_hash(ip), IP_HASHBITS); +} +static int ip_map_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct ip_map *orig = container_of(corig, struct ip_map, h); + struct ip_map *new = container_of(cnew, struct ip_map, h); + return strcmp(orig->m_class, new->m_class) == 0 && + ipv6_addr_equal(&orig->m_addr, &new->m_addr); +} +static void ip_map_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + strcpy(new->m_class, item->m_class); + new->m_addr = item->m_addr; +} +static void update(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + kref_get(&item->m_client->h.ref); + new->m_client = item->m_client; +} +static struct cache_head *ip_map_alloc(void) +{ + struct ip_map *i = kmalloc(sizeof(*i), GFP_KERNEL); + if (i) + return &i->h; + else + return NULL; +} + +static int ip_map_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall(cd, h); +} + +static void ip_map_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char text_addr[40]; + struct ip_map *im = container_of(h, struct ip_map, h); + + if (ipv6_addr_v4mapped(&(im->m_addr))) { + snprintf(text_addr, 20, "%pI4", &im->m_addr.s6_addr32[3]); + } else { + snprintf(text_addr, 40, "%pI6", &im->m_addr); + } + qword_add(bpp, blen, im->m_class); + qword_add(bpp, blen, text_addr); + (*bpp)[-1] = '\n'; +} + +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, struct in6_addr *addr); +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, struct unix_domain *udom, time64_t expiry); + +static int ip_map_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + 
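The ip_map cache keys clients by class plus an IPv6 address, storing IPv4 peers as v4-mapped addresses; that is why ip_map_request() prints the last 32-bit word with %pI4 for mapped entries. A userspace sketch building and printing such a key (the address chosen is arbitrary):

#include <stdio.h>
#include <string.h>
#include <arpa/inet.h>
#include <netinet/in.h>

int main(void)
{
	struct in_addr v4;
	struct in6_addr key = IN6ADDR_ANY_INIT;
	char text[INET6_ADDRSTRLEN];

	inet_pton(AF_INET, "192.0.2.7", &v4);

	/* Build the v4-mapped form ::ffff:192.0.2.7 used as the cache key */
	key.s6_addr[10] = 0xff;
	key.s6_addr[11] = 0xff;
	memcpy(&key.s6_addr[12], &v4, 4);

	if (IN6_IS_ADDR_V4MAPPED(&key))
		inet_ntop(AF_INET, &key.s6_addr[12], text, sizeof(text));
	else
		inet_ntop(AF_INET6, &key, text, sizeof(text));

	printf("cache key address: %s\n", text);
	return 0;
}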
/* class ipaddress [domainname] */ + /* should be safe just to use the start of the input buffer + * for scratch: */ + char *buf = mesg; + int len; + char class[8]; + union { + struct sockaddr sa; + struct sockaddr_in s4; + struct sockaddr_in6 s6; + } address; + struct sockaddr_in6 sin6; + int err; + + struct ip_map *ipmp; + struct auth_domain *dom; + time64_t expiry; + + if (mesg[mlen-1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + /* class */ + len = qword_get(&mesg, class, sizeof(class)); + if (len <= 0) return -EINVAL; + + /* ip address */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) return -EINVAL; + + if (rpc_pton(cd->net, buf, len, &address.sa, sizeof(address)) == 0) + return -EINVAL; + switch (address.sa.sa_family) { + case AF_INET: + /* Form a mapped IPv4 address in sin6 */ + sin6.sin6_family = AF_INET6; + ipv6_addr_set_v4mapped(address.s4.sin_addr.s_addr, + &sin6.sin6_addr); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + memcpy(&sin6, &address.s6, sizeof(sin6)); + break; +#endif + default: + return -EINVAL; + } + + expiry = get_expiry(&mesg); + if (expiry ==0) + return -EINVAL; + + /* domainname, or empty for NEGATIVE */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) return -EINVAL; + + if (len) { + dom = unix_domain_find(buf); + if (dom == NULL) + return -ENOENT; + } else + dom = NULL; + + /* IPv6 scope IDs are ignored for now */ + ipmp = __ip_map_lookup(cd, class, &sin6.sin6_addr); + if (ipmp) { + err = __ip_map_update(cd, ipmp, + container_of(dom, struct unix_domain, h), + expiry); + } else + err = -ENOMEM; + + if (dom) + auth_domain_put(dom); + + cache_flush(); + return err; +} + +static int ip_map_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct ip_map *im; + struct in6_addr addr; + char *dom = "-no-domain-"; + + if (h == NULL) { + seq_puts(m, "#class IP domain\n"); + return 0; + } + im = container_of(h, struct ip_map, h); + /* class addr domain */ + addr = im->m_addr; + + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + dom = im->m_client->h.name; + + if (ipv6_addr_v4mapped(&addr)) { + seq_printf(m, "%s %pI4 %s\n", + im->m_class, &addr.s6_addr32[3], dom); + } else { + seq_printf(m, "%s %pI6 %s\n", im->m_class, &addr, dom); + } + return 0; +} + + +static struct ip_map *__ip_map_lookup(struct cache_detail *cd, char *class, + struct in6_addr *addr) +{ + struct ip_map ip; + struct cache_head *ch; + + strcpy(ip.m_class, class); + ip.m_addr = *addr; + ch = sunrpc_cache_lookup_rcu(cd, &ip.h, + hash_str(class, IP_HASHBITS) ^ + hash_ip6(addr)); + + if (ch) + return container_of(ch, struct ip_map, h); + else + return NULL; +} + +static int __ip_map_update(struct cache_detail *cd, struct ip_map *ipm, + struct unix_domain *udom, time64_t expiry) +{ + struct ip_map ip; + struct cache_head *ch; + + ip.m_client = udom; + ip.h.flags = 0; + if (!udom) + set_bit(CACHE_NEGATIVE, &ip.h.flags); + ip.h.expiry_time = expiry; + ch = sunrpc_cache_update(cd, &ip.h, &ipm->h, + hash_str(ipm->m_class, IP_HASHBITS) ^ + hash_ip6(&ipm->m_addr)); + if (!ch) + return -ENOMEM; + cache_put(ch, cd); + return 0; +} + +void svcauth_unix_purge(struct net *net) +{ + struct sunrpc_net *sn; + + sn = net_generic(net, sunrpc_net_id); + cache_purge(sn->ip_map_cache); +} +EXPORT_SYMBOL_GPL(svcauth_unix_purge); + +static inline struct ip_map * +ip_map_cached_get(struct svc_xprt *xprt) +{ + struct ip_map *ipm = NULL; + struct sunrpc_net *sn; + + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + 
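ip_map_parse() consumes one line of the cache channel: class, address, expiry, and an optional domain name whose absence marks a negative entry. A rough userspace tokenizer for a line of that shape; the real code uses qword_get(), get_expiry(), and rpc_pton() rather than strtok_r/atoll.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Parse "class address expiry [domain]\n"; returns 0 on success. */
static int parse_ip_map_line(char *line, char **class, char **addr,
			     long long *expiry, char **domain)
{
	char *save = NULL;
	char *exp;

	*class = strtok_r(line, " \n", &save);
	*addr = strtok_r(NULL, " \n", &save);
	exp = strtok_r(NULL, " \n", &save);
	*domain = strtok_r(NULL, " \n", &save);	/* NULL => negative entry */

	if (!*class || !*addr || !exp)
		return -1;
	*expiry = atoll(exp);
	return 0;
}

int main(void)
{
	char line[] = "nfsd 192.0.2.7 1700000000 example.com\n";
	char *class, *addr, *domain;
	long long expiry;

	if (parse_ip_map_line(line, &class, &addr, &expiry, &domain))
		return 1;
	printf("class=%s addr=%s expiry=%lld domain=%s\n",
	       class, addr, expiry, domain ? domain : "(negative)");
	return 0;
}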
spin_lock(&xprt->xpt_lock); + ipm = xprt->xpt_auth_cache; + if (ipm != NULL) { + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + if (cache_is_expired(sn->ip_map_cache, &ipm->h)) { + /* + * The entry has been invalidated since it was + * remembered, e.g. by a second mount from the + * same IP address. + */ + xprt->xpt_auth_cache = NULL; + spin_unlock(&xprt->xpt_lock); + cache_put(&ipm->h, sn->ip_map_cache); + return NULL; + } + cache_get(&ipm->h); + } + spin_unlock(&xprt->xpt_lock); + } + return ipm; +} + +static inline void +ip_map_cached_put(struct svc_xprt *xprt, struct ip_map *ipm) +{ + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + spin_lock(&xprt->xpt_lock); + if (xprt->xpt_auth_cache == NULL) { + /* newly cached, keep the reference */ + xprt->xpt_auth_cache = ipm; + ipm = NULL; + } + spin_unlock(&xprt->xpt_lock); + } + if (ipm) { + struct sunrpc_net *sn; + + sn = net_generic(xprt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } +} + +void +svcauth_unix_info_release(struct svc_xprt *xpt) +{ + struct ip_map *ipm; + + ipm = xpt->xpt_auth_cache; + if (ipm != NULL) { + struct sunrpc_net *sn; + + sn = net_generic(xpt->xpt_net, sunrpc_net_id); + cache_put(&ipm->h, sn->ip_map_cache); + } +} + +/**************************************************************************** + * auth.unix.gid cache + * simple cache to map a UID to a list of GIDs + * because AUTH_UNIX aka AUTH_SYS has a max of UNX_NGROUPS + */ +#define GID_HASHBITS 8 +#define GID_HASHMAX (1<h; + + if (test_bit(CACHE_VALID, &item->flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + put_group_info(ug->gi); + kfree(ug); +} + +static void unix_gid_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct unix_gid *ug = container_of(item, struct unix_gid, h); + + call_rcu(&ug->rcu, unix_gid_free); +} + +static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct unix_gid *orig = container_of(corig, struct unix_gid, h); + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + return uid_eq(orig->uid, new->uid); +} +static void unix_gid_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + new->uid = item->uid; +} +static void unix_gid_update(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + + get_group_info(item->gi); + new->gi = item->gi; +} +static struct cache_head *unix_gid_alloc(void) +{ + struct unix_gid *g = kmalloc(sizeof(*g), GFP_KERNEL); + if (g) + return &g->h; + else + return NULL; +} + +static int unix_gid_upcall(struct cache_detail *cd, struct cache_head *h) +{ + return sunrpc_cache_pipe_upcall_timeout(cd, h); +} + +static void unix_gid_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char tuid[20]; + struct unix_gid *ug = container_of(h, struct unix_gid, h); + + snprintf(tuid, 20, "%u", from_kuid(&init_user_ns, ug->uid)); + qword_add(bpp, blen, tuid); + (*bpp)[-1] = '\n'; +} + +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid); + +static int unix_gid_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* uid expiry Ngid gid0 gid1 ... 
gidN-1 */ + int id; + kuid_t uid; + int gids; + int rv; + int i; + int err; + time64_t expiry; + struct unix_gid ug, *ugp; + + if (mesg[mlen - 1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + rv = get_int(&mesg, &id); + if (rv) + return -EINVAL; + uid = make_kuid(current_user_ns(), id); + ug.uid = uid; + + expiry = get_expiry(&mesg); + if (expiry == 0) + return -EINVAL; + + rv = get_int(&mesg, &gids); + if (rv || gids < 0 || gids > 8192) + return -EINVAL; + + ug.gi = groups_alloc(gids); + if (!ug.gi) + return -ENOMEM; + + for (i = 0 ; i < gids ; i++) { + int gid; + kgid_t kgid; + rv = get_int(&mesg, &gid); + err = -EINVAL; + if (rv) + goto out; + kgid = make_kgid(current_user_ns(), gid); + if (!gid_valid(kgid)) + goto out; + ug.gi->gid[i] = kgid; + } + + groups_sort(ug.gi); + ugp = unix_gid_lookup(cd, uid); + if (ugp) { + struct cache_head *ch; + ug.h.flags = 0; + ug.h.expiry_time = expiry; + ch = sunrpc_cache_update(cd, + &ug.h, &ugp->h, + unix_gid_hash(uid)); + if (!ch) + err = -ENOMEM; + else { + err = 0; + cache_put(ch, cd); + } + } else + err = -ENOMEM; + out: + if (ug.gi) + put_group_info(ug.gi); + return err; +} + +static int unix_gid_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct user_namespace *user_ns = m->file->f_cred->user_ns; + struct unix_gid *ug; + int i; + int glen; + + if (h == NULL) { + seq_puts(m, "#uid cnt: gids...\n"); + return 0; + } + ug = container_of(h, struct unix_gid, h); + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + glen = ug->gi->ngroups; + else + glen = 0; + + seq_printf(m, "%u %d:", from_kuid_munged(user_ns, ug->uid), glen); + for (i = 0; i < glen; i++) + seq_printf(m, " %d", from_kgid_munged(user_ns, ug->gi->gid[i])); + seq_printf(m, "\n"); + return 0; +} + +static const struct cache_detail unix_gid_cache_template = { + .owner = THIS_MODULE, + .hash_size = GID_HASHMAX, + .name = "auth.unix.gid", + .cache_put = unix_gid_put, + .cache_upcall = unix_gid_upcall, + .cache_request = unix_gid_request, + .cache_parse = unix_gid_parse, + .cache_show = unix_gid_show, + .match = unix_gid_match, + .init = unix_gid_init, + .update = unix_gid_update, + .alloc = unix_gid_alloc, +}; + +int unix_gid_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&unix_gid_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->unix_gid_cache = cd; + return 0; +} + +void unix_gid_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->unix_gid_cache; + + sn->unix_gid_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} + +static struct unix_gid *unix_gid_lookup(struct cache_detail *cd, kuid_t uid) +{ + struct unix_gid ug; + struct cache_head *ch; + + ug.uid = uid; + ch = sunrpc_cache_lookup_rcu(cd, &ug.h, unix_gid_hash(uid)); + if (ch) + return container_of(ch, struct unix_gid, h); + else + return NULL; +} + +static struct group_info *unix_gid_find(kuid_t uid, struct svc_rqst *rqstp) +{ + struct unix_gid *ug; + struct group_info *gi; + int ret; + struct sunrpc_net *sn = net_generic(rqstp->rq_xprt->xpt_net, + sunrpc_net_id); + + ug = unix_gid_lookup(sn->unix_gid_cache, uid); + if (!ug) + return ERR_PTR(-EAGAIN); + ret = cache_check(sn->unix_gid_cache, &ug->h, &rqstp->rq_chandle); + 
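The auth.unix.gid channel line decoded above carries a uid, an expiry, a group count (capped at 8192), and that many gids, which are then sorted. A standalone sketch parsing and sorting a line of that shape (sample values are made up):

#include <stdio.h>
#include <stdlib.h>

static int cmp_gid(const void *a, const void *b)
{
	unsigned int x = *(const unsigned int *)a, y = *(const unsigned int *)b;

	return (x > y) - (x < y);
}

int main(void)
{
	/* "uid expiry Ngid gid0 gid1 ... gidN-1" */
	const char *line = "1000 1700000000 3 100 20 4";
	unsigned int uid, ngids, gids[8192];
	long long expiry;
	int off, n;

	if (sscanf(line, "%u %lld %u%n", &uid, &expiry, &ngids, &off) != 3 ||
	    ngids > 8192)
		return 1;
	for (n = 0; n < (int)ngids; n++) {
		int used;

		if (sscanf(line + off, "%u%n", &gids[n], &used) != 1)
			return 1;
		off += used;
	}
	qsort(gids, ngids, sizeof(gids[0]), cmp_gid);	/* like groups_sort() */

	printf("uid %u: %u gids, smallest %u\n", uid, ngids, gids[0]);
	return 0;
}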
switch (ret) { + case -ENOENT: + return ERR_PTR(-ENOENT); + case -ETIMEDOUT: + return ERR_PTR(-ESHUTDOWN); + case 0: + gi = get_group_info(ug->gi); + cache_put(&ug->h, sn->unix_gid_cache); + return gi; + default: + return ERR_PTR(-EAGAIN); + } +} + +int +svcauth_unix_set_client(struct svc_rqst *rqstp) +{ + struct sockaddr_in *sin; + struct sockaddr_in6 *sin6, sin6_storage; + struct ip_map *ipm; + struct group_info *gi; + struct svc_cred *cred = &rqstp->rq_cred; + struct svc_xprt *xprt = rqstp->rq_xprt; + struct net *net = xprt->xpt_net; + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + + switch (rqstp->rq_addr.ss_family) { + case AF_INET: + sin = svc_addr_in(rqstp); + sin6 = &sin6_storage; + ipv6_addr_set_v4mapped(sin->sin_addr.s_addr, &sin6->sin6_addr); + break; + case AF_INET6: + sin6 = svc_addr_in6(rqstp); + break; + default: + BUG(); + } + + rqstp->rq_client = NULL; + if (rqstp->rq_proc == 0) + goto out; + + rqstp->rq_auth_stat = rpc_autherr_badcred; + ipm = ip_map_cached_get(xprt); + if (ipm == NULL) + ipm = __ip_map_lookup(sn->ip_map_cache, rqstp->rq_server->sv_program->pg_class, + &sin6->sin6_addr); + + if (ipm == NULL) + return SVC_DENIED; + + switch (cache_check(sn->ip_map_cache, &ipm->h, &rqstp->rq_chandle)) { + default: + BUG(); + case -ETIMEDOUT: + return SVC_CLOSE; + case -EAGAIN: + return SVC_DROP; + case -ENOENT: + return SVC_DENIED; + case 0: + rqstp->rq_client = &ipm->m_client->h; + kref_get(&rqstp->rq_client->ref); + ip_map_cached_put(xprt, ipm); + break; + } + + gi = unix_gid_find(cred->cr_uid, rqstp); + switch (PTR_ERR(gi)) { + case -EAGAIN: + return SVC_DROP; + case -ESHUTDOWN: + return SVC_CLOSE; + case -ENOENT: + break; + default: + put_group_info(cred->cr_group_info); + cred->cr_group_info = gi; + } + +out: + rqstp->rq_auth_stat = rpc_auth_ok; + return SVC_OK; +} + +EXPORT_SYMBOL_GPL(svcauth_unix_set_client); + +static int +svcauth_null_accept(struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + + if (argv->iov_len < 3*4) + return SVC_GARBAGE; + + if (svc_getu32(argv) != 0) { + dprintk("svc: bad null cred\n"); + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + dprintk("svc: bad null verf\n"); + rqstp->rq_auth_stat = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Signal that mapping to nobody uid/gid is required */ + cred->cr_uid = INVALID_UID; + cred->cr_gid = INVALID_GID; + cred->cr_group_info = groups_alloc(0); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; /* kmalloc failure - client must retry */ + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_cred.cr_flavor = RPC_AUTH_NULL; + return SVC_OK; +} + +static int +svcauth_null_release(struct svc_rqst *rqstp) +{ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; /* don't drop */ +} + + +struct auth_ops svcauth_null = { + .name = "null", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_NULL, + .accept = svcauth_null_accept, + .release = svcauth_null_release, + .set_client = svcauth_unix_set_client, +}; + + +static int +svcauth_tls_accept(struct svc_rqst *rqstp) +{ + struct svc_cred *cred = &rqstp->rq_cred; + struct kvec *argv = rqstp->rq_arg.head; + struct kvec *resv = 
rqstp->rq_res.head; + + if (argv->iov_len < XDR_UNIT * 3) + return SVC_GARBAGE; + + /* Call's cred length */ + if (svc_getu32(argv) != xdr_zero) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } + + /* Call's verifier flavor and its length */ + if (svc_getu32(argv) != rpc_auth_null || + svc_getu32(argv) != xdr_zero) { + rqstp->rq_auth_stat = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* AUTH_TLS is not valid on non-NULL procedures */ + if (rqstp->rq_proc != 0) { + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; + } + + /* Mapping to nobody uid/gid is required */ + cred->cr_uid = INVALID_UID; + cred->cr_gid = INVALID_GID; + cred->cr_group_info = groups_alloc(0); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; /* kmalloc failure - client must retry */ + + /* Reply's verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + if (rqstp->rq_xprt->xpt_ops->xpo_start_tls) { + svc_putnl(resv, 8); + memcpy(resv->iov_base + resv->iov_len, "STARTTLS", 8); + resv->iov_len += 8; + } else + svc_putnl(resv, 0); + + rqstp->rq_cred.cr_flavor = RPC_AUTH_TLS; + return SVC_OK; +} + +struct auth_ops svcauth_tls = { + .name = "tls", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_TLS, + .accept = svcauth_tls_accept, + .release = svcauth_null_release, + .set_client = svcauth_unix_set_client, +}; + + +static int +svcauth_unix_accept(struct svc_rqst *rqstp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + struct user_namespace *userns; + u32 slen, i; + int len = argv->iov_len; + + if ((len -= 3*4) < 0) + return SVC_GARBAGE; + + svc_getu32(argv); /* length */ + svc_getu32(argv); /* time stamp */ + slen = XDR_QUADLEN(svc_getnl(argv)); /* machname length */ + if (slen > 64 || (len -= (slen + 3)*4) < 0) + goto badcred; + argv->iov_base = (void*)((__be32*)argv->iov_base + slen); /* skip machname */ + argv->iov_len -= slen*4; + /* + * Note: we skip uid_valid()/gid_valid() checks here for + * backwards compatibility with clients that use -1 id's. + * Instead, -1 uid or gid is later mapped to the + * (export-specific) anonymous id by nfsd_setuser. + * Supplementary gid's will be left alone. + */ + userns = (rqstp->rq_xprt && rqstp->rq_xprt->xpt_cred) ? + rqstp->rq_xprt->xpt_cred->user_ns : &init_user_ns; + cred->cr_uid = make_kuid(userns, svc_getnl(argv)); /* uid */ + cred->cr_gid = make_kgid(userns, svc_getnl(argv)); /* gid */ + slen = svc_getnl(argv); /* gids length */ + if (slen > UNX_NGROUPS || (len -= (slen + 2)*4) < 0) + goto badcred; + cred->cr_group_info = groups_alloc(slen); + if (cred->cr_group_info == NULL) + return SVC_CLOSE; + for (i = 0; i < slen; i++) { + kgid_t kgid = make_kgid(userns, svc_getnl(argv)); + cred->cr_group_info->gid[i] = kgid; + } + groups_sort(cred->cr_group_info); + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + rqstp->rq_auth_stat = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_cred.cr_flavor = RPC_AUTH_UNIX; + return SVC_OK; + +badcred: + rqstp->rq_auth_stat = rpc_autherr_badcred; + return SVC_DENIED; +} + +static int +svcauth_unix_release(struct svc_rqst *rqstp) +{ + /* Verifier (such as it is) is already in place. 
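svcauth_unix_accept(), below, walks the AUTH_UNIX credential body: stamp, a length-prefixed machine name padded to whole XDR quads, uid, gid, and up to UNX_NGROUPS supplementary gids. Here is a userspace sketch decoding the same layout from big-endian words; the sample values are made up.

#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

#define XDR_QUADLEN(l)	(((l) + 3) >> 2)	/* bytes -> 4-byte words */

int main(void)
{
	/* AUTH_UNIX cred body: stamp, machname ("host", len 4), uid, gid,
	 * gid count, gids. All values are big-endian on the wire. */
	uint32_t cred[] = {
		htonl(0),			/* stamp */
		htonl(4), htonl(0x686f7374),	/* machname length + "host" */
		htonl(1000), htonl(1000),	/* uid, gid */
		htonl(2), htonl(100), htonl(20),/* 2 supplementary gids */
	};
	uint32_t *p = cred;
	uint32_t stamp, mlen, uid, gid, ngids, i;

	stamp = ntohl(*p++);
	mlen  = ntohl(*p++);
	p += XDR_QUADLEN(mlen);			/* skip the machine name */
	uid   = ntohl(*p++);
	gid   = ntohl(*p++);
	ngids = ntohl(*p++);

	printf("stamp=%u uid=%u gid=%u ngids=%u:", stamp, uid, gid, ngids);
	for (i = 0; i < ngids; i++)
		printf(" %u", ntohl(*p++));
	printf("\n");
	return 0;
}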
+ */ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; +} + + +struct auth_ops svcauth_unix = { + .name = "unix", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_UNIX, + .accept = svcauth_unix_accept, + .release = svcauth_unix_release, + .domain_release = svcauth_unix_domain_release, + .set_client = svcauth_unix_set_client, +}; + +static const struct cache_detail ip_map_cache_template = { + .owner = THIS_MODULE, + .hash_size = IP_HASHMAX, + .name = "auth.unix.ip", + .cache_put = ip_map_put, + .cache_upcall = ip_map_upcall, + .cache_request = ip_map_request, + .cache_parse = ip_map_parse, + .cache_show = ip_map_show, + .match = ip_map_match, + .init = ip_map_init, + .update = update, + .alloc = ip_map_alloc, +}; + +int ip_map_cache_create(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd; + int err; + + cd = cache_create_net(&ip_map_cache_template, net); + if (IS_ERR(cd)) + return PTR_ERR(cd); + err = cache_register_net(cd, net); + if (err) { + cache_destroy_net(cd, net); + return err; + } + sn->ip_map_cache = cd; + return 0; +} + +void ip_map_cache_destroy(struct net *net) +{ + struct sunrpc_net *sn = net_generic(net, sunrpc_net_id); + struct cache_detail *cd = sn->ip_map_cache; + + sn->ip_map_cache = NULL; + cache_purge(cd); + cache_unregister_net(cd, net); + cache_destroy_net(cd, net); +} diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c new file mode 100644 index 000000000..23b4c728d --- /dev/null +++ b/net/sunrpc/svcsock.c @@ -0,0 +1,1521 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/svcsock.c + * + * These are the RPC server socket internals. + * + * The server scheduling algorithm does not always distribute the load + * evenly when servicing a single client. May need to modify the + * svc_xprt_enqueue procedure... + * + * TCP support is largely untested and may be a little slow. The problem + * is that we currently do two separate recvfrom's, one for the 4-byte + * record length, and the second for the actual record. This could possibly + * be improved by always reading a minimum size of around 100 bytes and + * tucking any superfluous bytes away in a temporary store. Still, that + * leaves write requests out in the rain. An alternative may be to peek at + * the first skb in the queue, and if it matches the next TCP sequence + * number, to extract the record marker. Yuck. 
+ * + * Copyright (C) 1995, 1996 Olaf Kirch + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "socklib.h" +#include "sunrpc.h" + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + + +static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, + int flags); +static int svc_udp_recvfrom(struct svc_rqst *); +static int svc_udp_sendto(struct svc_rqst *); +static void svc_sock_detach(struct svc_xprt *); +static void svc_tcp_sock_detach(struct svc_xprt *); +static void svc_sock_free(struct svc_xprt *); + +static struct svc_xprt *svc_create_socket(struct svc_serv *, int, + struct net *, struct sockaddr *, + int, int); +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key svc_key[2]; +static struct lock_class_key svc_slock_key[2]; + +static void svc_reclassify_socket(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (WARN_ON_ONCE(!sock_allow_reclassification(sk))) + return; + + switch (sk->sk_family) { + case AF_INET: + sock_lock_init_class_and_name(sk, "slock-AF_INET-NFSD", + &svc_slock_key[0], + "sk_xprt.xpt_lock-AF_INET-NFSD", + &svc_key[0]); + break; + + case AF_INET6: + sock_lock_init_class_and_name(sk, "slock-AF_INET6-NFSD", + &svc_slock_key[1], + "sk_xprt.xpt_lock-AF_INET6-NFSD", + &svc_key[1]); + break; + + default: + BUG(); + } +} +#else +static void svc_reclassify_socket(struct socket *sock) +{ +} +#endif + +/** + * svc_tcp_release_ctxt - Release transport-related resources + * @xprt: the transport which owned the context + * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * + */ +static void svc_tcp_release_ctxt(struct svc_xprt *xprt, void *ctxt) +{ +} + +/** + * svc_udp_release_ctxt - Release transport-related resources + * @xprt: the transport which owned the context + * @ctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * + */ +static void svc_udp_release_ctxt(struct svc_xprt *xprt, void *ctxt) +{ + struct sk_buff *skb = ctxt; + + if (skb) + consume_skb(skb); +} + +union svc_pktinfo_u { + struct in_pktinfo pkti; + struct in6_pktinfo pkti6; +}; +#define SVC_PKTINFO_SPACE \ + CMSG_SPACE(sizeof(union svc_pktinfo_u)) + +static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + switch (svsk->sk_sk->sk_family) { + case AF_INET: { + struct in_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IP; + cmh->cmsg_type = IP_PKTINFO; + pki->ipi_ifindex = 0; + pki->ipi_spec_dst.s_addr = + svc_daddr_in(rqstp)->sin_addr.s_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + + case AF_INET6: { + struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); + + cmh->cmsg_level = SOL_IPV6; + cmh->cmsg_type = IPV6_PKTINFO; + pki->ipi6_ifindex = daddr->sin6_scope_id; + pki->ipi6_addr = daddr->sin6_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + } +} + +static int svc_sock_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) +{ + return 0; +} + +/* + * Report socket names for nfsdfs + */ +static int svc_one_sock_name(struct svc_sock *svsk, char *buf, int remaining) +{ + const struct sock *sk = svsk->sk_sk; + const char *proto_name = sk->sk_protocol 
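svc_set_cmsg_data() pins the source address of each UDP reply by attaching an IP_PKTINFO (or IPV6_PKTINFO) control message to the outgoing msghdr. A userspace sketch assembling the IPv4 variant with the standard cmsg macros; the address is arbitrary and the message is never actually sent.

#define _GNU_SOURCE		/* for struct in_pktinfo on glibc */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

int main(void)
{
	char cbuf[CMSG_SPACE(sizeof(struct in_pktinfo))];
	struct msghdr msg;
	struct cmsghdr *cmh;
	struct in_pktinfo *pki;

	memset(&msg, 0, sizeof(msg));
	memset(cbuf, 0, sizeof(cbuf));
	msg.msg_control = cbuf;
	msg.msg_controllen = sizeof(cbuf);

	cmh = CMSG_FIRSTHDR(&msg);
	cmh->cmsg_level = IPPROTO_IP;	/* same value as SOL_IP on Linux */
	cmh->cmsg_type = IP_PKTINFO;
	cmh->cmsg_len = CMSG_LEN(sizeof(*pki));
	pki = (struct in_pktinfo *)CMSG_DATA(cmh);
	pki->ipi_ifindex = 0;
	inet_pton(AF_INET, "192.0.2.1", &pki->ipi_spec_dst);

	printf("control message uses %zu bytes\n", sizeof(cbuf));
	return 0;
}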
== IPPROTO_UDP ? + "udp" : "tcp"; + int len; + + switch (sk->sk_family) { + case PF_INET: + len = snprintf(buf, remaining, "ipv4 %s %pI4 %d\n", + proto_name, + &inet_sk(sk)->inet_rcv_saddr, + inet_sk(sk)->inet_num); + break; +#if IS_ENABLED(CONFIG_IPV6) + case PF_INET6: + len = snprintf(buf, remaining, "ipv6 %s %pI6 %d\n", + proto_name, + &sk->sk_v6_rcv_saddr, + inet_sk(sk)->inet_num); + break; +#endif + default: + len = snprintf(buf, remaining, "*unknown-%d*\n", + sk->sk_family); + } + + if (len >= remaining) { + *buf = '\0'; + return -ENAMETOOLONG; + } + return len; +} + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE +static void svc_flush_bvec(const struct bio_vec *bvec, size_t size, size_t seek) +{ + struct bvec_iter bi = { + .bi_size = size + seek, + }; + struct bio_vec bv; + + bvec_iter_advance(bvec, &bi, seek & PAGE_MASK); + for_each_bvec(bv, bvec, bi, bi) + flush_dcache_page(bv.bv_page); +} +#else +static inline void svc_flush_bvec(const struct bio_vec *bvec, size_t size, + size_t seek) +{ +} +#endif + +/* + * Read from @rqstp's transport socket. The incoming message fills whole + * pages in @rqstp's rq_pages array until the last page of the message + * has been received into a partial page. + */ +static ssize_t svc_tcp_read_msg(struct svc_rqst *rqstp, size_t buflen, + size_t seek) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct bio_vec *bvec = rqstp->rq_bvec; + struct msghdr msg = { NULL }; + unsigned int i; + ssize_t len; + size_t t; + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + for (i = 0, t = 0; t < buflen; i++, t += PAGE_SIZE) { + bvec[i].bv_page = rqstp->rq_pages[i]; + bvec[i].bv_len = PAGE_SIZE; + bvec[i].bv_offset = 0; + } + rqstp->rq_respages = &rqstp->rq_pages[i]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + + iov_iter_bvec(&msg.msg_iter, ITER_DEST, bvec, i, buflen); + if (seek) { + iov_iter_advance(&msg.msg_iter, seek); + buflen -= seek; + } + len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT); + if (len > 0) + svc_flush_bvec(bvec, len, seek); + + /* If we read a full record, then assume there may be more + * data to read (stream based sockets only!) + */ + if (len == buflen) + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + return len; +} + +/* + * Set socket snd and rcv buffer lengths + */ +static void svc_sock_setbufsize(struct svc_sock *svsk, unsigned int nreqs) +{ + unsigned int max_mesg = svsk->sk_xprt.xpt_server->sv_max_mesg; + struct socket *sock = svsk->sk_sock; + + nreqs = min(nreqs, INT_MAX / 2 / max_mesg); + + lock_sock(sock->sk); + sock->sk->sk_sndbuf = nreqs * max_mesg * 2; + sock->sk->sk_rcvbuf = nreqs * max_mesg * 2; + sock->sk->sk_write_space(sock->sk); + release_sock(sock->sk); +} + +static void svc_sock_secure_port(struct svc_rqst *rqstp) +{ + if (svc_port_is_privileged(svc_addr(rqstp))) + set_bit(RQ_SECURE, &rqstp->rq_flags); + else + clear_bit(RQ_SECURE, &rqstp->rq_flags); +} + +/* + * INET callback when data has been received on the socket. + */ +static void svc_data_ready(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + trace_svcsock_data_ready(&svsk->sk_xprt, 0); + if (!test_and_set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags)) + svc_xprt_enqueue(&svsk->sk_xprt); + } +} + +/* + * INET callback when space is newly available on the socket. 
+ */ +static void svc_write_space(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data); + + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + trace_svcsock_write_space(&svsk->sk_xprt, 0); + svsk->sk_owspace(sk); + svc_xprt_enqueue(&svsk->sk_xprt); + } +} + +static int svc_tcp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) + return 1; + return !test_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); +} + +static void svc_tcp_kill_temp_xprt(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + sock_no_linger(svsk->sk_sock->sk); +} + +/* + * See net/ipv6/ip_sockglue.c : ip_cmsg_recv_pktinfo + */ +static int svc_udp_get_dest_address4(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + struct in_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in *daddr = svc_daddr_in(rqstp); + + if (cmh->cmsg_type != IP_PKTINFO) + return 0; + + daddr->sin_family = AF_INET; + daddr->sin_addr.s_addr = pki->ipi_spec_dst.s_addr; + return 1; +} + +/* + * See net/ipv6/datagram.c : ip6_datagram_recv_ctl + */ +static int svc_udp_get_dest_address6(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + struct in6_pktinfo *pki = CMSG_DATA(cmh); + struct sockaddr_in6 *daddr = svc_daddr_in6(rqstp); + + if (cmh->cmsg_type != IPV6_PKTINFO) + return 0; + + daddr->sin6_family = AF_INET6; + daddr->sin6_addr = pki->ipi6_addr; + daddr->sin6_scope_id = pki->ipi6_ifindex; + return 1; +} + +/* + * Copy the UDP datagram's destination address to the rqstp structure. + * The 'destination' address in this case is the address to which the + * peer sent the datagram, i.e. our local address. For multihomed + * hosts, this can change from msg to msg. Note that only the IP + * address changes, the port number should remain the same. + */ +static int svc_udp_get_dest_address(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + switch (cmh->cmsg_level) { + case SOL_IP: + return svc_udp_get_dest_address4(rqstp, cmh); + case SOL_IPV6: + return svc_udp_get_dest_address6(rqstp, cmh); + } + + return 0; +} + +/** + * svc_udp_recvfrom - Receive a datagram from a UDP socket. + * @rqstp: request structure into which to receive an RPC Call + * + * Called in a loop when XPT_DATA has been set. + * + * Returns: + * On success, the number of bytes in a received RPC Call, or + * %0 if a complete RPC Call message was not ready to return + */ +static int svc_udp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct sk_buff *skb; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + struct msghdr msg = { + .msg_name = svc_addr(rqstp), + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_DONTWAIT, + }; + size_t len; + int err; + + if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags)) + /* udp sockets need large rcvbuf as all pending + * requests are still in that buffer. sndbuf must + * also be large enough that there is enough space + * for one reply per thread. We count all threads + * rather than threads in a particular pool, which + * provides an upper bound on the number of threads + * which will access the socket. 
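svc_sock_setbufsize() asks for roughly one maximum-size message of send and receive space per expected request, clamping the request count so the multiplication cannot overflow an int. A tiny sketch of that arithmetic, with made-up values for sv_max_mesg and the thread count:

#include <stdio.h>
#include <limits.h>

static unsigned int clamp_nreqs(unsigned int nreqs, unsigned int max_mesg)
{
	unsigned int cap = INT_MAX / 2 / max_mesg;

	return nreqs < cap ? nreqs : cap;	/* avoid int overflow below */
}

int main(void)
{
	unsigned int max_mesg = 1048576;	/* e.g. a 1 MB sv_max_mesg */
	unsigned int nthreads = 8;
	unsigned int nreqs = clamp_nreqs(nthreads + 3, max_mesg);
	unsigned int bufsize = nreqs * max_mesg * 2;

	printf("snd/rcv buffer request: %u bytes for %u requests\n",
	       bufsize, nreqs);
	return 0;
}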
+ */ + svc_sock_setbufsize(svsk, serv->sv_nrthreads + 3); + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + err = kernel_recvmsg(svsk->sk_sock, &msg, NULL, + 0, 0, MSG_PEEK | MSG_DONTWAIT); + if (err < 0) + goto out_recv_err; + skb = skb_recv_udp(svsk->sk_sk, MSG_DONTWAIT, &err); + if (!skb) + goto out_recv_err; + + len = svc_addr_len(svc_addr(rqstp)); + rqstp->rq_addrlen = len; + if (skb->tstamp == 0) { + skb->tstamp = ktime_get_real(); + /* Don't enable netstamp, sunrpc doesn't + need that much accuracy */ + } + sock_write_timestamp(svsk->sk_sk, skb->tstamp); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */ + + len = skb->len; + rqstp->rq_arg.len = len; + trace_svcsock_udp_recv(&svsk->sk_xprt, len); + + rqstp->rq_prot = IPPROTO_UDP; + + if (!svc_udp_get_dest_address(rqstp, cmh)) + goto out_cmsg_err; + rqstp->rq_daddrlen = svc_addr_len(svc_daddr(rqstp)); + + if (skb_is_nonlinear(skb)) { + /* we have to copy */ + local_bh_disable(); + if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) + goto out_bh_enable; + local_bh_enable(); + consume_skb(skb); + } else { + /* we can use it in-place */ + rqstp->rq_arg.head[0].iov_base = skb->data; + rqstp->rq_arg.head[0].iov_len = len; + if (skb_checksum_complete(skb)) + goto out_free; + rqstp->rq_xprt_ctxt = skb; + } + + rqstp->rq_arg.page_base = 0; + if (len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = len; + rqstp->rq_arg.page_len = 0; + rqstp->rq_respages = rqstp->rq_pages+1; + } else { + rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len; + rqstp->rq_respages = rqstp->rq_pages + 1 + + DIV_ROUND_UP(rqstp->rq_arg.page_len, PAGE_SIZE); + } + rqstp->rq_next_page = rqstp->rq_respages+1; + + if (serv->sv_stats) + serv->sv_stats->netudpcnt++; + + svc_xprt_received(rqstp->rq_xprt); + return len; + +out_recv_err: + if (err != -EAGAIN) { + /* possibly an icmp error */ + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + } + trace_svcsock_udp_recv_err(&svsk->sk_xprt, err); + goto out_clear_busy; +out_cmsg_err: + net_warn_ratelimited("svc: received unknown control message %d/%d; dropping RPC reply datagram\n", + cmh->cmsg_level, cmh->cmsg_type); + goto out_free; +out_bh_enable: + local_bh_enable(); +out_free: + kfree_skb(skb); +out_clear_busy: + svc_xprt_received(rqstp->rq_xprt); + return 0; +} + +/** + * svc_udp_sendto - Send out a reply on a UDP socket + * @rqstp: completed svc_rqst + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. + * + * Returns the number of bytes sent, or a negative errno. + */ +static int svc_udp_sendto(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct xdr_buf *xdr = &rqstp->rq_res; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + struct msghdr msg = { + .msg_name = &rqstp->rq_addr, + .msg_namelen = rqstp->rq_addrlen, + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + }; + unsigned int sent; + int err; + + svc_udp_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; + + svc_set_cmsg_data(rqstp, cmh); + + mutex_lock(&xprt->xpt_mutex); + + if (svc_xprt_is_dead(xprt)) + goto out_notconn; + + err = xdr_alloc_bvec(xdr, GFP_KERNEL); + if (err < 0) + goto out_unlock; + + err = xprt_sock_sendmsg(svsk->sk_sock, &msg, xdr, 0, 0, &sent); + if (err == -ECONNREFUSED) { + /* ICMP error on earlier request. 
*/ + err = xprt_sock_sendmsg(svsk->sk_sock, &msg, xdr, 0, 0, &sent); + } + xdr_free_bvec(xdr); + trace_svcsock_udp_send(xprt, err); +out_unlock: + mutex_unlock(&xprt->xpt_mutex); + if (err < 0) + return err; + return sent; + +out_notconn: + mutex_unlock(&xprt->xpt_mutex); + return -ENOTCONN; +} + +static int svc_udp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = xprt->xpt_server; + unsigned long required; + + /* + * Set the SOCK_NOSPACE flag before checking the available + * sock space. + */ + set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + required = atomic_read(&svsk->sk_xprt.xpt_reserved) + serv->sv_max_mesg; + if (required*2 > sock_wspace(svsk->sk_sk)) + return 0; + clear_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + return 1; +} + +static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt) +{ + BUG(); + return NULL; +} + +static void svc_udp_kill_temp_xprt(struct svc_xprt *xprt) +{ +} + +static struct svc_xprt *svc_udp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_UDP, net, sa, salen, flags); +} + +static const struct svc_xprt_ops svc_udp_ops = { + .xpo_create = svc_udp_create, + .xpo_recvfrom = svc_udp_recvfrom, + .xpo_sendto = svc_udp_sendto, + .xpo_result_payload = svc_sock_result_payload, + .xpo_release_ctxt = svc_udp_release_ctxt, + .xpo_detach = svc_sock_detach, + .xpo_free = svc_sock_free, + .xpo_has_wspace = svc_udp_has_wspace, + .xpo_accept = svc_udp_accept, + .xpo_secure_port = svc_sock_secure_port, + .xpo_kill_temp_xprt = svc_udp_kill_temp_xprt, +}; + +static struct svc_xprt_class svc_udp_class = { + .xcl_name = "udp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_udp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_UDP, + .xcl_ident = XPRT_TRANSPORT_UDP, +}; + +static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_udp_class, + &svsk->sk_xprt, serv); + clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + svsk->sk_sk->sk_data_ready = svc_data_ready; + svsk->sk_sk->sk_write_space = svc_write_space; + + /* initialise setting must have enough space to + * receive and respond to one request. + * svc_udp_recvfrom will re-adjust if necessary + */ + svc_sock_setbufsize(svsk, 3); + + /* data might have come in before data_ready set up */ + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + + /* make sure we get destination address info */ + switch (svsk->sk_sk->sk_family) { + case AF_INET: + ip_sock_set_pktinfo(svsk->sk_sock->sk); + break; + case AF_INET6: + ip6_sock_set_recvpktinfo(svsk->sk_sock->sk); + break; + default: + BUG(); + } +} + +/* + * A data_ready event on a listening socket means there's a connection + * pending. Do not use state_change as a substitute for it. + */ +static void svc_tcp_listen_data_ready(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + /* + * This callback may called twice when a new connection + * is established as a child socket inherits everything + * from a parent LISTEN socket. + * 1) data_ready method of the parent socket will be called + * when one of child sockets become ESTABLISHED. + * 2) data_ready method of the child socket may be called + * when it receives data before the socket is accepted. + * In case of 2, we should ignore it silently and DO NOT + * dereference svsk. 
+ */ + if (sk->sk_state != TCP_LISTEN) + return; + + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_odata(sk); + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } +} + +/* + * A state change on a connected socket means it's dying or dead. + */ +static void svc_tcp_state_change(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + if (svsk) { + /* Refer to svc_setup_socket() for details. */ + rmb(); + svsk->sk_ostate(sk); + trace_svcsock_tcp_state(&svsk->sk_xprt, svsk->sk_sock); + if (sk->sk_state != TCP_ESTABLISHED) + svc_xprt_deferred_close(&svsk->sk_xprt); + } +} + +/* + * Accept a TCP connection + */ +static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *) &addr; + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct socket *sock = svsk->sk_sock; + struct socket *newsock; + struct svc_sock *newsvsk; + int err, slen; + + if (!sock) + return NULL; + + clear_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + err = kernel_accept(sock, &newsock, O_NONBLOCK); + if (err < 0) { + if (err == -ENOMEM) + printk(KERN_WARNING "%s: no more sockets!\n", + serv->sv_name); + else if (err != -EAGAIN) + net_warn_ratelimited("%s: accept failed (err %d)!\n", + serv->sv_name, -err); + trace_svcsock_accept_err(xprt, serv->sv_name, err); + return NULL; + } + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + + err = kernel_getpeername(newsock, sin); + if (err < 0) { + trace_svcsock_getpeername_err(xprt, serv->sv_name, err); + goto failed; /* aborted connection or whatever */ + } + slen = err; + + /* Reset the inherited callbacks before calling svc_setup_socket */ + newsock->sk->sk_state_change = svsk->sk_ostate; + newsock->sk->sk_data_ready = svsk->sk_odata; + newsock->sk->sk_write_space = svsk->sk_owspace; + + /* make sure that a write doesn't block forever when + * low on memory + */ + newsock->sk->sk_sndtimeo = HZ*30; + + newsvsk = svc_setup_socket(serv, newsock, + (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)); + if (IS_ERR(newsvsk)) + goto failed; + svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen); + err = kernel_getsockname(newsock, sin); + slen = err; + if (unlikely(err < 0)) + slen = offsetof(struct sockaddr, sa_data); + svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen); + + if (sock_is_loopback(newsock->sk)) + set_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + else + clear_bit(XPT_LOCAL, &newsvsk->sk_xprt.xpt_flags); + if (serv->sv_stats) + serv->sv_stats->nettcpconn++; + + return &newsvsk->sk_xprt; + +failed: + sock_release(newsock); + return NULL; +} + +static size_t svc_tcp_restore_pages(struct svc_sock *svsk, + struct svc_rqst *rqstp) +{ + size_t len = svsk->sk_datalen; + unsigned int i, npages; + + if (!len) + return 0; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (rqstp->rq_pages[i] != NULL) + put_page(rqstp->rq_pages[i]); + BUG_ON(svsk->sk_pages[i] == NULL); + rqstp->rq_pages[i] = svsk->sk_pages[i]; + svsk->sk_pages[i] = NULL; + } + rqstp->rq_arg.head[0].iov_base = page_address(rqstp->rq_pages[0]); + return len; +} + +static void svc_tcp_save_pages(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + return; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + svsk->sk_pages[i] = 
rqstp->rq_pages[i]; + rqstp->rq_pages[i] = NULL; + } +} + +static void svc_tcp_clear_pages(struct svc_sock *svsk) +{ + unsigned int i, len, npages; + + if (svsk->sk_datalen == 0) + goto out; + len = svsk->sk_datalen; + npages = (len + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (svsk->sk_pages[i] == NULL) { + WARN_ON_ONCE(1); + continue; + } + put_page(svsk->sk_pages[i]); + svsk->sk_pages[i] = NULL; + } +out: + svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; +} + +/* + * Receive fragment record header into sk_marker. + */ +static ssize_t svc_tcp_read_marker(struct svc_sock *svsk, + struct svc_rqst *rqstp) +{ + ssize_t want, len; + + /* If we haven't gotten the record length yet, + * get the next four bytes. + */ + if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) { + struct msghdr msg = { NULL }; + struct kvec iov; + + want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; + iov.iov_base = ((char *)&svsk->sk_marker) + svsk->sk_tcplen; + iov.iov_len = want; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, want); + len = sock_recvmsg(svsk->sk_sock, &msg, MSG_DONTWAIT); + if (len < 0) + return len; + svsk->sk_tcplen += len; + if (len < want) { + /* call again to read the remaining bytes */ + goto err_short; + } + trace_svcsock_marker(&svsk->sk_xprt, svsk->sk_marker); + if (svc_sock_reclen(svsk) + svsk->sk_datalen > + svsk->sk_xprt.xpt_server->sv_max_mesg) + goto err_too_large; + } + return svc_sock_reclen(svsk); + +err_too_large: + net_notice_ratelimited("svc: %s %s RPC fragment too large: %d\n", + __func__, svsk->sk_xprt.xpt_server->sv_name, + svc_sock_reclen(svsk)); + svc_xprt_deferred_close(&svsk->sk_xprt); +err_short: + return -EAGAIN; +} + +static int receive_cb_reply(struct svc_sock *svsk, struct svc_rqst *rqstp) +{ + struct rpc_xprt *bc_xprt = svsk->sk_xprt.xpt_bc_xprt; + struct rpc_rqst *req = NULL; + struct kvec *src, *dst; + __be32 *p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + __be32 xid; + __be32 calldir; + + xid = *p++; + calldir = *p; + + if (!bc_xprt) + return -EAGAIN; + spin_lock(&bc_xprt->queue_lock); + req = xprt_lookup_rqst(bc_xprt, xid); + if (!req) + goto unlock_notfound; + + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + /* + * XXX!: cheating for now! Only copying HEAD. + * But we know this is good enough for now (in fact, for any + * callback reply in the forseeable future). + */ + dst = &req->rq_private_buf.head[0]; + src = &rqstp->rq_arg.head[0]; + if (dst->iov_len < src->iov_len) + goto unlock_eagain; /* whatever; just giving up. */ + memcpy(dst->iov_base, src->iov_base, src->iov_len); + xprt_complete_rqst(req->rq_task, rqstp->rq_arg.len); + rqstp->rq_arg.len = 0; + spin_unlock(&bc_xprt->queue_lock); + return 0; +unlock_notfound: + printk(KERN_NOTICE + "%s: Got unrecognized reply: " + "calldir 0x%x xpt_bc_xprt %p xid %08x\n", + __func__, ntohl(calldir), + bc_xprt, ntohl(xid)); +unlock_eagain: + spin_unlock(&bc_xprt->queue_lock); + return -EAGAIN; +} + +static void svc_tcp_fragment_received(struct svc_sock *svsk) +{ + /* If we have more data, signal svc_xprt_enqueue() to try again */ + svsk->sk_tcplen = 0; + svsk->sk_marker = xdr_zero; +} + +/** + * svc_tcp_recvfrom - Receive data from a TCP socket + * @rqstp: request structure into which to receive an RPC Call + * + * Called in a loop when XPT_DATA has been set. + * + * Read the 4-byte stream record marker, then use the record length + * in that marker to set up exactly the resources needed to receive + * the next RPC message into @rqstp. 
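The four-byte stream record marker read by svc_tcp_read_marker() carries the fragment length in its low 31 bits and a last-fragment flag in the top bit. A userspace sketch decoding one marker; the constant names mirror the usual RPC record-marking layout rather than the kernel's symbols.

#include <stdio.h>
#include <stdint.h>
#include <arpa/inet.h>

#define RPC_LAST_FRAG	0x80000000u	/* high bit: final fragment */
#define RPC_FRAG_MASK	0x7fffffffu	/* low 31 bits: fragment length */

int main(void)
{
	uint32_t marker = htonl(RPC_LAST_FRAG | 1200);	/* as read off the wire */
	uint32_t host = ntohl(marker);
	uint32_t len = host & RPC_FRAG_MASK;
	int last = (host & RPC_LAST_FRAG) != 0;

	printf("fragment length %u, %s fragment\n",
	       len, last ? "final" : "intermediate");
	return 0;
}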
+ * + * Returns: + * On success, the number of bytes in a received RPC Call, or + * %0 if a complete RPC Call message was not ready to return + * + * The zero return case handles partial receives and callback Replies. + * The state of a partial receive is preserved in the svc_sock for + * the next call to svc_tcp_recvfrom. + */ +static int svc_tcp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + size_t want, base; + ssize_t len; + __be32 *p; + __be32 calldir; + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + len = svc_tcp_read_marker(svsk, rqstp); + if (len < 0) + goto error; + + base = svc_tcp_restore_pages(svsk, rqstp); + want = len - (svsk->sk_tcplen - sizeof(rpc_fraghdr)); + len = svc_tcp_read_msg(rqstp, base + want, base); + if (len >= 0) { + trace_svcsock_tcp_recv(&svsk->sk_xprt, len); + svsk->sk_tcplen += len; + svsk->sk_datalen += len; + } + if (len != want || !svc_sock_final_rec(svsk)) + goto err_incomplete; + if (svsk->sk_datalen < 8) + goto err_nuts; + + rqstp->rq_arg.len = svsk->sk_datalen; + rqstp->rq_arg.page_base = 0; + if (rqstp->rq_arg.len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = rqstp->rq_arg.len; + rqstp->rq_arg.page_len = 0; + } else + rqstp->rq_arg.page_len = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; + + rqstp->rq_xprt_ctxt = NULL; + rqstp->rq_prot = IPPROTO_TCP; + if (test_bit(XPT_LOCAL, &svsk->sk_xprt.xpt_flags)) + set_bit(RQ_LOCAL, &rqstp->rq_flags); + else + clear_bit(RQ_LOCAL, &rqstp->rq_flags); + + p = (__be32 *)rqstp->rq_arg.head[0].iov_base; + calldir = p[1]; + if (calldir) + len = receive_cb_reply(svsk, rqstp); + + /* Reset TCP read info */ + svsk->sk_datalen = 0; + svc_tcp_fragment_received(svsk); + + if (len < 0) + goto error; + + svc_xprt_copy_addrs(rqstp, &svsk->sk_xprt); + if (serv->sv_stats) + serv->sv_stats->nettcpcnt++; + + svc_xprt_received(rqstp->rq_xprt); + return rqstp->rq_arg.len; + +err_incomplete: + svc_tcp_save_pages(svsk, rqstp); + if (len < 0 && len != -EAGAIN) + goto err_delete; + if (len == want) + svc_tcp_fragment_received(svsk); + else + trace_svcsock_tcp_recv_short(&svsk->sk_xprt, + svc_sock_reclen(svsk), + svsk->sk_tcplen - sizeof(rpc_fraghdr)); + goto err_noclose; +error: + if (len != -EAGAIN) + goto err_delete; + trace_svcsock_tcp_recv_eagain(&svsk->sk_xprt, 0); + goto err_noclose; +err_nuts: + svsk->sk_datalen = 0; +err_delete: + trace_svcsock_tcp_recv_err(&svsk->sk_xprt, len); + svc_xprt_deferred_close(&svsk->sk_xprt); +err_noclose: + svc_xprt_received(rqstp->rq_xprt); + return 0; /* record not complete */ +} + +static int svc_tcp_send_kvec(struct socket *sock, const struct kvec *vec, + int flags) +{ + return kernel_sendpage(sock, virt_to_page(vec->iov_base), + offset_in_page(vec->iov_base), + vec->iov_len, flags); +} + +/* + * kernel_sendpage() is used exclusively to reduce the number of + * copy operations in this path. Therefore the caller must ensure + * that the pages backing @xdr are unchanging. + * + * In addition, the logic assumes that * .bv_len is never larger + * than PAGE_SIZE. 
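 *
 * For reference, the xdr_buf written out below has three regions, sent
 * in order:
 *
 *	head[0] (kvec)	RPC reply header and small inline data
 *	pages[]		bulk payload, starting page_base bytes into the
 *			first page and spanning page_len bytes in total
 *	tail[0] (kvec)	XDR padding and trailing items
 *
 * buf->len counts the bytes actually used across the three regions,
 * while buf->buflen is the capacity backing them (see struct xdr_buf in
 * include/linux/sunrpc/xdr.h).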
+ */ +static int svc_tcp_sendmsg(struct socket *sock, struct xdr_buf *xdr, + rpc_fraghdr marker, unsigned int *sentp) +{ + const struct kvec *head = xdr->head; + const struct kvec *tail = xdr->tail; + struct kvec rm = { + .iov_base = &marker, + .iov_len = sizeof(marker), + }; + struct msghdr msg = { + .msg_flags = 0, + }; + int ret; + + *sentp = 0; + ret = xdr_alloc_bvec(xdr, GFP_KERNEL); + if (ret < 0) + return ret; + + ret = kernel_sendmsg(sock, &msg, &rm, 1, rm.iov_len); + if (ret < 0) + return ret; + *sentp += ret; + if (ret != rm.iov_len) + return -EAGAIN; + + ret = svc_tcp_send_kvec(sock, head, 0); + if (ret < 0) + return ret; + *sentp += ret; + if (ret != head->iov_len) + goto out; + + if (xdr->page_len) { + unsigned int offset, len, remaining; + struct bio_vec *bvec; + + bvec = xdr->bvec + (xdr->page_base >> PAGE_SHIFT); + offset = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining > 0) { + len = min(remaining, bvec->bv_len - offset); + ret = kernel_sendpage(sock, bvec->bv_page, + bvec->bv_offset + offset, + len, 0); + if (ret < 0) + return ret; + *sentp += ret; + if (ret != len) + goto out; + remaining -= len; + offset = 0; + bvec++; + } + } + + if (tail->iov_len) { + ret = svc_tcp_send_kvec(sock, tail, 0); + if (ret < 0) + return ret; + *sentp += ret; + } + +out: + return 0; +} + +/** + * svc_tcp_sendto - Send out a reply on a TCP socket + * @rqstp: completed svc_rqst + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. + * + * Returns the number of bytes sent, or a negative errno. + */ +static int svc_tcp_sendto(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct xdr_buf *xdr = &rqstp->rq_res; + rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | + (u32)xdr->len); + unsigned int sent; + int err; + + svc_tcp_release_ctxt(xprt, rqstp->rq_xprt_ctxt); + rqstp->rq_xprt_ctxt = NULL; + + atomic_inc(&svsk->sk_sendqlen); + mutex_lock(&xprt->xpt_mutex); + if (svc_xprt_is_dead(xprt)) + goto out_notconn; + tcp_sock_set_cork(svsk->sk_sk, true); + err = svc_tcp_sendmsg(svsk->sk_sock, xdr, marker, &sent); + xdr_free_bvec(xdr); + trace_svcsock_tcp_send(xprt, err < 0 ? (long)err : sent); + if (err < 0 || sent != (xdr->len + sizeof(marker))) + goto out_close; + if (atomic_dec_and_test(&svsk->sk_sendqlen)) + tcp_sock_set_cork(svsk->sk_sk, false); + mutex_unlock(&xprt->xpt_mutex); + return sent; + +out_notconn: + atomic_dec(&svsk->sk_sendqlen); + mutex_unlock(&xprt->xpt_mutex); + return -ENOTCONN; +out_close: + pr_notice("rpc-srv/tcp: %s: %s %d when sending %d bytes - shutting down socket\n", + xprt->xpt_server->sv_name, + (err < 0) ? "got error" : "sent", + (err < 0) ? 
err : sent, xdr->len); + svc_xprt_deferred_close(xprt); + atomic_dec(&svsk->sk_sendqlen); + mutex_unlock(&xprt->xpt_mutex); + return -EAGAIN; +} + +static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_TCP, net, sa, salen, flags); +} + +static const struct svc_xprt_ops svc_tcp_ops = { + .xpo_create = svc_tcp_create, + .xpo_recvfrom = svc_tcp_recvfrom, + .xpo_sendto = svc_tcp_sendto, + .xpo_result_payload = svc_sock_result_payload, + .xpo_release_ctxt = svc_tcp_release_ctxt, + .xpo_detach = svc_tcp_sock_detach, + .xpo_free = svc_sock_free, + .xpo_has_wspace = svc_tcp_has_wspace, + .xpo_accept = svc_tcp_accept, + .xpo_secure_port = svc_sock_secure_port, + .xpo_kill_temp_xprt = svc_tcp_kill_temp_xprt, +}; + +static struct svc_xprt_class svc_tcp_class = { + .xcl_name = "tcp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_tcp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, + .xcl_ident = XPRT_TRANSPORT_TCP, +}; + +void svc_init_xprt_sock(void) +{ + svc_reg_xprt_class(&svc_tcp_class); + svc_reg_xprt_class(&svc_udp_class); +} + +void svc_cleanup_xprt_sock(void) +{ + svc_unreg_xprt_class(&svc_tcp_class); + svc_unreg_xprt_class(&svc_udp_class); +} + +static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + struct sock *sk = svsk->sk_sk; + + svc_xprt_init(sock_net(svsk->sk_sock->sk), &svc_tcp_class, + &svsk->sk_xprt, serv); + set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CONG_CTRL, &svsk->sk_xprt.xpt_flags); + if (sk->sk_state == TCP_LISTEN) { + strcpy(svsk->sk_xprt.xpt_remotebuf, "listener"); + set_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags); + sk->sk_data_ready = svc_tcp_listen_data_ready; + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + } else { + sk->sk_state_change = svc_tcp_state_change; + sk->sk_data_ready = svc_data_ready; + sk->sk_write_space = svc_write_space; + + svsk->sk_marker = xdr_zero; + svsk->sk_tcplen = 0; + svsk->sk_datalen = 0; + memset(&svsk->sk_pages[0], 0, sizeof(svsk->sk_pages)); + + tcp_sock_set_nodelay(sk); + + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + switch (sk->sk_state) { + case TCP_SYN_RECV: + case TCP_ESTABLISHED: + break; + default: + svc_xprt_deferred_close(&svsk->sk_xprt); + } + } +} + +void svc_sock_update_bufs(struct svc_serv *serv) +{ + /* + * The number of server threads has changed. 
Update + * rcvbuf and sndbuf accordingly on all sockets + */ + struct svc_sock *svsk; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list) + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + spin_unlock_bh(&serv->sv_lock); +} +EXPORT_SYMBOL_GPL(svc_sock_update_bufs); + +/* + * Initialize socket for RPC use and create svc_sock struct + */ +static struct svc_sock *svc_setup_socket(struct svc_serv *serv, + struct socket *sock, + int flags) +{ + struct svc_sock *svsk; + struct sock *inet; + int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + int err = 0; + + svsk = kzalloc(sizeof(*svsk), GFP_KERNEL); + if (!svsk) + return ERR_PTR(-ENOMEM); + + inet = sock->sk; + + /* Register socket with portmapper */ + if (pmap_register) + err = svc_register(serv, sock_net(sock->sk), inet->sk_family, + inet->sk_protocol, + ntohs(inet_sk(inet)->inet_sport)); + + if (err < 0) { + kfree(svsk); + return ERR_PTR(err); + } + + svsk->sk_sock = sock; + svsk->sk_sk = inet; + svsk->sk_ostate = inet->sk_state_change; + svsk->sk_odata = inet->sk_data_ready; + svsk->sk_owspace = inet->sk_write_space; + /* + * This barrier is necessary in order to prevent race condition + * with svc_data_ready(), svc_listen_data_ready() and others + * when calling callbacks above. + */ + wmb(); + inet->sk_user_data = svsk; + + /* Initialize the socket */ + if (sock->type == SOCK_DGRAM) + svc_udp_init(svsk, serv); + else + svc_tcp_init(svsk, serv); + + trace_svcsock_new_socket(sock); + return svsk; +} + +/** + * svc_addsock - add a listener socket to an RPC service + * @serv: pointer to RPC service to which to add a new listener + * @net: caller's network namespace + * @fd: file descriptor of the new listener + * @name_return: pointer to buffer to fill in with name of listener + * @len: size of the buffer + * @cred: credential + * + * Fills in socket name and returns positive length of name if successful. + * Name is terminated with '\n'. On error, returns a negative errno + * value. + */ +int svc_addsock(struct svc_serv *serv, struct net *net, const int fd, + char *name_return, const size_t len, const struct cred *cred) +{ + int err = 0; + struct socket *so = sockfd_lookup(fd, &err); + struct svc_sock *svsk = NULL; + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *)&addr; + int salen; + + if (!so) + return err; + err = -EINVAL; + if (sock_net(so->sk) != net) + goto out; + err = -EAFNOSUPPORT; + if ((so->sk->sk_family != PF_INET) && (so->sk->sk_family != PF_INET6)) + goto out; + err = -EPROTONOSUPPORT; + if (so->sk->sk_protocol != IPPROTO_TCP && + so->sk->sk_protocol != IPPROTO_UDP) + goto out; + err = -EISCONN; + if (so->state > SS_UNCONNECTED) + goto out; + err = -ENOENT; + if (!try_module_get(THIS_MODULE)) + goto out; + svsk = svc_setup_socket(serv, so, SVC_SOCK_DEFAULTS); + if (IS_ERR(svsk)) { + module_put(THIS_MODULE); + err = PTR_ERR(svsk); + goto out; + } + salen = kernel_getsockname(svsk->sk_sock, sin); + if (salen >= 0) + svc_xprt_set_local(&svsk->sk_xprt, sin, salen); + svsk->sk_xprt.xpt_cred = get_cred(cred); + svc_add_new_perm_xprt(serv, &svsk->sk_xprt); + return svc_one_sock_name(svsk, name_return, len); +out: + sockfd_put(so); + return err; +} +EXPORT_SYMBOL_GPL(svc_addsock); + +/* + * Create socket for RPC service. 
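/*
 * A minimal stand-alone sketch of the ordering that the wmb() in
 * svc_setup_socket() above and the rmb() in the socket callbacks rely on:
 * the svc_sock is published through sk_user_data only after its saved
 * callback fields have been stored, and a reader issues a read barrier
 * between loading the pointer and dereferencing those fields.  Names are
 * illustrative, and a compiler full barrier stands in for wmb()/rmb().
 */
struct svsk_sketch {
	void (*sk_odata)(void);			/* saved original callback */
};

struct sk_sketch {
	struct svsk_sketch *sk_user_data;
};

static void sketch_setup(struct sk_sketch *sk, struct svsk_sketch *svsk,
			 void (*old_data_ready)(void))
{
	svsk->sk_odata = old_data_ready;
	__sync_synchronize();			/* publish barrier (wmb) */
	sk->sk_user_data = svsk;
}

static void sketch_data_ready(struct sk_sketch *sk)
{
	struct svsk_sketch *svsk = sk->sk_user_data;

	if (svsk) {
		__sync_synchronize();		/* consume barrier (rmb) */
		svsk->sk_odata();		/* saved field is now visible */
	}
}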
+ */ +static struct svc_xprt *svc_create_socket(struct svc_serv *serv, + int protocol, + struct net *net, + struct sockaddr *sin, int len, + int flags) +{ + struct svc_sock *svsk; + struct socket *sock; + int error; + int type; + struct sockaddr_storage addr; + struct sockaddr *newsin = (struct sockaddr *)&addr; + int newlen; + int family; + + if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) { + printk(KERN_WARNING "svc: only UDP and TCP " + "sockets supported\n"); + return ERR_PTR(-EINVAL); + } + + type = (protocol == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM; + switch (sin->sa_family) { + case AF_INET6: + family = PF_INET6; + break; + case AF_INET: + family = PF_INET; + break; + default: + return ERR_PTR(-EINVAL); + } + + error = __sock_create(net, family, type, protocol, &sock, 1); + if (error < 0) + return ERR_PTR(error); + + svc_reclassify_socket(sock); + + /* + * If this is an PF_INET6 listener, we want to avoid + * getting requests from IPv4 remotes. Those should + * be shunted to a PF_INET listener via rpcbind. + */ + if (family == PF_INET6) + ip6_sock_set_v6only(sock->sk); + if (type == SOCK_STREAM) + sock->sk->sk_reuse = SK_CAN_REUSE; /* allow address reuse */ + error = kernel_bind(sock, sin, len); + if (error < 0) + goto bummer; + + error = kernel_getsockname(sock, newsin); + if (error < 0) + goto bummer; + newlen = error; + + if (protocol == IPPROTO_TCP) { + if ((error = kernel_listen(sock, 64)) < 0) + goto bummer; + } + + svsk = svc_setup_socket(serv, sock, flags); + if (IS_ERR(svsk)) { + error = PTR_ERR(svsk); + goto bummer; + } + svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); + return (struct svc_xprt *)svsk; +bummer: + sock_release(sock); + return ERR_PTR(error); +} + +/* + * Detach the svc_sock from the socket so that no + * more callbacks occur. + */ +static void svc_sock_detach(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sock *sk = svsk->sk_sk; + + /* put back the old socket callbacks */ + lock_sock(sk); + sk->sk_state_change = svsk->sk_ostate; + sk->sk_data_ready = svsk->sk_odata; + sk->sk_write_space = svsk->sk_owspace; + sk->sk_user_data = NULL; + release_sock(sk); +} + +/* + * Disconnect the socket, and reset the callbacks + */ +static void svc_tcp_sock_detach(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + svc_sock_detach(xprt); + + if (!test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + svc_tcp_clear_pages(svsk); + kernel_sock_shutdown(svsk->sk_sock, SHUT_RDWR); + } +} + +/* + * Free the svc_sock's socket resources and the svc_sock itself. + */ +static void svc_sock_free(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + + if (svsk->sk_sock->file) + sockfd_put(svsk->sk_sock); + else + sock_release(svsk->sk_sock); + kfree(svsk); +} diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c new file mode 100644 index 000000000..3aad6ef18 --- /dev/null +++ b/net/sunrpc/sysctl.c @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/sysctl.c + * + * Sysctl interface to sunrpc module. + * + * I would prefer to register the sunrpc table below sys/net, but that's + * impossible at the moment. 
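 *
 * For reference, the tables below land under /proc/sys/sunrpc once
 * registered, so the debug masks can be inspected and set from user
 * space, e.g.
 *
 *	cat /proc/sys/sunrpc/rpc_debug		-> "0x0000"
 *	echo 0x7fff > /proc/sys/sunrpc/rpc_debug
 *	cat /proc/sys/sunrpc/transports		(read-only transport list)
 *
 * Writing rpc_debug also dumps the currently queued RPC tasks via
 * rpc_show_tasks(), as proc_dodebug() below notes.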
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "netns.h" + +/* + * Declare the debug flags here + */ +unsigned int rpc_debug; +EXPORT_SYMBOL_GPL(rpc_debug); + +unsigned int nfs_debug; +EXPORT_SYMBOL_GPL(nfs_debug); + +unsigned int nfsd_debug; +EXPORT_SYMBOL_GPL(nfsd_debug); + +unsigned int nlm_debug; +EXPORT_SYMBOL_GPL(nlm_debug); + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + +static struct ctl_table_header *sunrpc_table_header; +static struct ctl_table sunrpc_table[]; + +void +rpc_register_sysctl(void) +{ + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +} + +void +rpc_unregister_sysctl(void) +{ + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +} + +static int proc_do_xprt(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + char tmpbuf[256]; + ssize_t len; + + if (write || *ppos) { + *lenp = 0; + return 0; + } + len = svc_print_xprts(tmpbuf, sizeof(tmpbuf)); + len = memory_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len); + + if (len < 0) { + *lenp = 0; + return -EINVAL; + } + *lenp = len; + return 0; +} + +static int +proc_dodebug(struct ctl_table *table, int write, void *buffer, size_t *lenp, + loff_t *ppos) +{ + char tmpbuf[20], *s = NULL; + char *p; + unsigned int value; + size_t left, len; + + if ((*ppos && !write) || !*lenp) { + *lenp = 0; + return 0; + } + + left = *lenp; + + if (write) { + p = buffer; + while (left && isspace(*p)) { + left--; + p++; + } + if (!left) + goto done; + + if (left > sizeof(tmpbuf) - 1) + return -EINVAL; + memcpy(tmpbuf, p, left); + tmpbuf[left] = '\0'; + + value = simple_strtol(tmpbuf, &s, 0); + if (s) { + left -= (s - tmpbuf); + if (left && !isspace(*s)) + return -EINVAL; + while (left && isspace(*s)) { + left--; + s++; + } + } else + left = 0; + *(unsigned int *) table->data = value; + /* Display the RPC tasks on writing to rpc_debug */ + if (strcmp(table->procname, "rpc_debug") == 0) + rpc_show_tasks(&init_net); + } else { + len = sprintf(tmpbuf, "0x%04x", *(unsigned int *) table->data); + if (len > left) + len = left; + memcpy(buffer, tmpbuf, len); + if ((left -= len) > 0) { + *((char *)buffer + len) = '\n'; + left--; + } + } + +done: + *lenp -= left; + *ppos += *lenp; + return 0; +} + + +static struct ctl_table debug_table[] = { + { + .procname = "rpc_debug", + .data = &rpc_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nfs_debug", + .data = &nfs_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nfsd_debug", + .data = &nfsd_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "nlm_debug", + .data = &nlm_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dodebug + }, + { + .procname = "transports", + .maxlen = 256, + .mode = 0444, + .proc_handler = proc_do_xprt, + }, + { } +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = debug_table + }, + { } +}; + +#endif diff --git a/net/sunrpc/sysfs.c b/net/sunrpc/sysfs.c new file mode 100644 index 000000000..c1f559892 --- /dev/null +++ b/net/sunrpc/sysfs.c @@ -0,0 +1,626 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Anna Schumaker + */ +#include +#include +#include +#include + +#include "sysfs.h" + +struct xprt_addr { + const char *addr; + struct 
rcu_head rcu; +}; + +static void free_xprt_addr(struct rcu_head *head) +{ + struct xprt_addr *addr = container_of(head, struct xprt_addr, rcu); + + kfree(addr->addr); + kfree(addr); +} + +static struct kset *rpc_sunrpc_kset; +static struct kobject *rpc_sunrpc_client_kobj, *rpc_sunrpc_xprt_switch_kobj; + +static void rpc_sysfs_object_release(struct kobject *kobj) +{ + kfree(kobj); +} + +static const struct kobj_ns_type_operations * +rpc_sysfs_object_child_ns_type(struct kobject *kobj) +{ + return &net_ns_type_operations; +} + +static struct kobj_type rpc_sysfs_object_type = { + .release = rpc_sysfs_object_release, + .sysfs_ops = &kobj_sysfs_ops, + .child_ns_type = rpc_sysfs_object_child_ns_type, +}; + +static struct kobject *rpc_sysfs_object_alloc(const char *name, + struct kset *kset, + struct kobject *parent) +{ + struct kobject *kobj; + + kobj = kzalloc(sizeof(*kobj), GFP_KERNEL); + if (kobj) { + kobj->kset = kset; + if (kobject_init_and_add(kobj, &rpc_sysfs_object_type, + parent, "%s", name) == 0) + return kobj; + kobject_put(kobj); + } + return NULL; +} + +static inline struct rpc_xprt * +rpc_sysfs_xprt_kobj_get_xprt(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *x = container_of(kobj, + struct rpc_sysfs_xprt, kobject); + + return xprt_get(x->xprt); +} + +static inline struct rpc_xprt_switch * +rpc_sysfs_xprt_kobj_get_xprt_switch(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *x = container_of(kobj, + struct rpc_sysfs_xprt, kobject); + + return xprt_switch_get(x->xprt_switch); +} + +static inline struct rpc_xprt_switch * +rpc_sysfs_xprt_switch_kobj_get_xprt(struct kobject *kobj) +{ + struct rpc_sysfs_xprt_switch *x = container_of(kobj, + struct rpc_sysfs_xprt_switch, kobject); + + return xprt_switch_get(x->xprt_switch); +} + +static ssize_t rpc_sysfs_xprt_dstaddr_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt) { + ret = sprintf(buf, "\n"); + goto out; + } + ret = sprintf(buf, "%s\n", xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt_put(xprt); +out: + return ret; +} + +static ssize_t rpc_sysfs_xprt_srcaddr_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + size_t buflen = PAGE_SIZE; + ssize_t ret; + + if (!xprt || !xprt_connected(xprt)) { + ret = sprintf(buf, "\n"); + } else if (xprt->ops->get_srcaddr) { + ret = xprt->ops->get_srcaddr(xprt, buf, buflen); + if (ret > 0) { + if (ret < buflen - 1) { + buf[ret] = '\n'; + ret++; + buf[ret] = '\0'; + } + } else + ret = sprintf(buf, "\n"); + } else + ret = sprintf(buf, "\n"); + xprt_put(xprt); + return ret; +} + +static ssize_t rpc_sysfs_xprt_info_show(struct kobject *kobj, + struct kobj_attribute *attr, char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + unsigned short srcport = 0; + size_t buflen = PAGE_SIZE; + ssize_t ret; + + if (!xprt || !xprt_connected(xprt)) { + ret = sprintf(buf, "\n"); + goto out; + } + + if (xprt->ops->get_srcport) + srcport = xprt->ops->get_srcport(xprt); + + ret = snprintf(buf, buflen, + "last_used=%lu\ncur_cong=%lu\ncong_win=%lu\n" + "max_num_slots=%u\nmin_num_slots=%u\nnum_reqs=%u\n" + "binding_q_len=%u\nsending_q_len=%u\npending_q_len=%u\n" + "backlog_q_len=%u\nmain_xprt=%d\nsrc_port=%u\n" + "tasks_queuelen=%ld\ndst_port=%s\n", + xprt->last_used, xprt->cong, xprt->cwnd, xprt->max_reqs, + xprt->min_reqs, xprt->num_reqs, xprt->binding.qlen, + xprt->sending.qlen, xprt->pending.qlen, 
+ xprt->backlog.qlen, xprt->main, srcport, + atomic_long_read(&xprt->queuelen), + xprt->address_strings[RPC_DISPLAY_PORT]); +out: + xprt_put(xprt); + return ret; +} + +static ssize_t rpc_sysfs_xprt_state_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + ssize_t ret; + int locked, connected, connecting, close_wait, bound, binding, + closing, congested, cwnd_wait, write_space, offline, remove; + + if (!(xprt && xprt->state)) { + ret = sprintf(buf, "state=CLOSED\n"); + } else { + locked = test_bit(XPRT_LOCKED, &xprt->state); + connected = test_bit(XPRT_CONNECTED, &xprt->state); + connecting = test_bit(XPRT_CONNECTING, &xprt->state); + close_wait = test_bit(XPRT_CLOSE_WAIT, &xprt->state); + bound = test_bit(XPRT_BOUND, &xprt->state); + binding = test_bit(XPRT_BINDING, &xprt->state); + closing = test_bit(XPRT_CLOSING, &xprt->state); + congested = test_bit(XPRT_CONGESTED, &xprt->state); + cwnd_wait = test_bit(XPRT_CWND_WAIT, &xprt->state); + write_space = test_bit(XPRT_WRITE_SPACE, &xprt->state); + offline = test_bit(XPRT_OFFLINE, &xprt->state); + remove = test_bit(XPRT_REMOVE, &xprt->state); + + ret = sprintf(buf, "state=%s %s %s %s %s %s %s %s %s %s %s %s\n", + locked ? "LOCKED" : "", + connected ? "CONNECTED" : "", + connecting ? "CONNECTING" : "", + close_wait ? "CLOSE_WAIT" : "", + bound ? "BOUND" : "", + binding ? "BOUNDING" : "", + closing ? "CLOSING" : "", + congested ? "CONGESTED" : "", + cwnd_wait ? "CWND_WAIT" : "", + write_space ? "WRITE_SPACE" : "", + offline ? "OFFLINE" : "", + remove ? "REMOVE" : ""); + } + + xprt_put(xprt); + return ret; +} + +static ssize_t rpc_sysfs_xprt_switch_info_show(struct kobject *kobj, + struct kobj_attribute *attr, + char *buf) +{ + struct rpc_xprt_switch *xprt_switch = + rpc_sysfs_xprt_switch_kobj_get_xprt(kobj); + ssize_t ret; + + if (!xprt_switch) + return 0; + ret = sprintf(buf, "num_xprts=%u\nnum_active=%u\n" + "num_unique_destaddr=%u\nqueue_len=%ld\n", + xprt_switch->xps_nxprts, xprt_switch->xps_nactive, + xprt_switch->xps_nunique_destaddr_xprts, + atomic_long_read(&xprt_switch->xps_queuelen)); + xprt_switch_put(xprt_switch); + return ret; +} + +static ssize_t rpc_sysfs_xprt_dstaddr_store(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + struct sockaddr *saddr; + char *dst_addr; + int port; + struct xprt_addr *saved_addr; + size_t buf_len; + + if (!xprt) + return 0; + if (!(xprt->xprt_class->ident == XPRT_TRANSPORT_TCP || + xprt->xprt_class->ident == XPRT_TRANSPORT_RDMA)) { + xprt_put(xprt); + return -EOPNOTSUPP; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + saddr = (struct sockaddr *)&xprt->addr; + port = rpc_get_port(saddr); + + /* buf_len is the len until the first occurence of either + * '\n' or '\0' + */ + buf_len = strcspn(buf, "\n"); + + dst_addr = kstrndup(buf, buf_len, GFP_KERNEL); + if (!dst_addr) + goto out_err; + saved_addr = kzalloc(sizeof(*saved_addr), GFP_KERNEL); + if (!saved_addr) + goto out_err_free; + saved_addr->addr = + rcu_dereference_raw(xprt->address_strings[RPC_DISPLAY_ADDR]); + rcu_assign_pointer(xprt->address_strings[RPC_DISPLAY_ADDR], dst_addr); + call_rcu(&saved_addr->rcu, free_xprt_addr); + xprt->addrlen = rpc_pton(xprt->xprt_net, buf, buf_len, saddr, + sizeof(*saddr)); + rpc_set_port(saddr, port); + + xprt_force_disconnect(xprt); +out: + xprt_release_write(xprt, 
NULL); +out_put: + xprt_put(xprt); + return count; +out_err_free: + kfree(dst_addr); +out_err: + count = -ENOMEM; + goto out; +} + +static ssize_t rpc_sysfs_xprt_state_change(struct kobject *kobj, + struct kobj_attribute *attr, + const char *buf, size_t count) +{ + struct rpc_xprt *xprt = rpc_sysfs_xprt_kobj_get_xprt(kobj); + int offline = 0, online = 0, remove = 0; + struct rpc_xprt_switch *xps = rpc_sysfs_xprt_kobj_get_xprt_switch(kobj); + + if (!xprt || !xps) { + count = 0; + goto out_put; + } + + if (!strncmp(buf, "offline", 7)) + offline = 1; + else if (!strncmp(buf, "online", 6)) + online = 1; + else if (!strncmp(buf, "remove", 6)) + remove = 1; + else { + count = -EINVAL; + goto out_put; + } + + if (wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_KILLABLE)) { + count = -EINTR; + goto out_put; + } + if (xprt->main) { + count = -EINVAL; + goto release_tasks; + } + if (offline) { + xprt_set_offline_locked(xprt, xps); + } else if (online) { + xprt_set_online_locked(xprt, xps); + } else if (remove) { + if (test_bit(XPRT_OFFLINE, &xprt->state)) + xprt_delete_locked(xprt, xps); + else + count = -EINVAL; + } + +release_tasks: + xprt_release_write(xprt, NULL); +out_put: + xprt_put(xprt); + xprt_switch_put(xps); + return count; +} + +int rpc_sysfs_init(void) +{ + rpc_sunrpc_kset = kset_create_and_add("sunrpc", NULL, kernel_kobj); + if (!rpc_sunrpc_kset) + return -ENOMEM; + rpc_sunrpc_client_kobj = + rpc_sysfs_object_alloc("rpc-clients", rpc_sunrpc_kset, NULL); + if (!rpc_sunrpc_client_kobj) + goto err_client; + rpc_sunrpc_xprt_switch_kobj = + rpc_sysfs_object_alloc("xprt-switches", rpc_sunrpc_kset, NULL); + if (!rpc_sunrpc_xprt_switch_kobj) + goto err_switch; + return 0; +err_switch: + kobject_put(rpc_sunrpc_client_kobj); + rpc_sunrpc_client_kobj = NULL; +err_client: + kset_unregister(rpc_sunrpc_kset); + rpc_sunrpc_kset = NULL; + return -ENOMEM; +} + +static void rpc_sysfs_client_release(struct kobject *kobj) +{ + struct rpc_sysfs_client *c; + + c = container_of(kobj, struct rpc_sysfs_client, kobject); + kfree(c); +} + +static void rpc_sysfs_xprt_switch_release(struct kobject *kobj) +{ + struct rpc_sysfs_xprt_switch *xprt_switch; + + xprt_switch = container_of(kobj, struct rpc_sysfs_xprt_switch, kobject); + kfree(xprt_switch); +} + +static void rpc_sysfs_xprt_release(struct kobject *kobj) +{ + struct rpc_sysfs_xprt *xprt; + + xprt = container_of(kobj, struct rpc_sysfs_xprt, kobject); + kfree(xprt); +} + +static const void *rpc_sysfs_client_namespace(struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_client, kobject)->net; +} + +static const void *rpc_sysfs_xprt_switch_namespace(struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_xprt_switch, kobject)->net; +} + +static const void *rpc_sysfs_xprt_namespace(struct kobject *kobj) +{ + return container_of(kobj, struct rpc_sysfs_xprt, + kobject)->xprt->xprt_net; +} + +static struct kobj_attribute rpc_sysfs_xprt_dstaddr = __ATTR(dstaddr, + 0644, rpc_sysfs_xprt_dstaddr_show, rpc_sysfs_xprt_dstaddr_store); + +static struct kobj_attribute rpc_sysfs_xprt_srcaddr = __ATTR(srcaddr, + 0644, rpc_sysfs_xprt_srcaddr_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_info = __ATTR(xprt_info, + 0444, rpc_sysfs_xprt_info_show, NULL); + +static struct kobj_attribute rpc_sysfs_xprt_change_state = __ATTR(xprt_state, + 0644, rpc_sysfs_xprt_state_show, rpc_sysfs_xprt_state_change); + +static struct attribute *rpc_sysfs_xprt_attrs[] = { + &rpc_sysfs_xprt_dstaddr.attr, + &rpc_sysfs_xprt_srcaddr.attr, + 
&rpc_sysfs_xprt_info.attr, + &rpc_sysfs_xprt_change_state.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_xprt); + +static struct kobj_attribute rpc_sysfs_xprt_switch_info = + __ATTR(xprt_switch_info, 0444, rpc_sysfs_xprt_switch_info_show, NULL); + +static struct attribute *rpc_sysfs_xprt_switch_attrs[] = { + &rpc_sysfs_xprt_switch_info.attr, + NULL, +}; +ATTRIBUTE_GROUPS(rpc_sysfs_xprt_switch); + +static struct kobj_type rpc_sysfs_client_type = { + .release = rpc_sysfs_client_release, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_client_namespace, +}; + +static struct kobj_type rpc_sysfs_xprt_switch_type = { + .release = rpc_sysfs_xprt_switch_release, + .default_groups = rpc_sysfs_xprt_switch_groups, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_xprt_switch_namespace, +}; + +static struct kobj_type rpc_sysfs_xprt_type = { + .release = rpc_sysfs_xprt_release, + .default_groups = rpc_sysfs_xprt_groups, + .sysfs_ops = &kobj_sysfs_ops, + .namespace = rpc_sysfs_xprt_namespace, +}; + +void rpc_sysfs_exit(void) +{ + kobject_put(rpc_sunrpc_client_kobj); + kobject_put(rpc_sunrpc_xprt_switch_kobj); + kset_unregister(rpc_sunrpc_kset); +} + +static struct rpc_sysfs_client *rpc_sysfs_client_alloc(struct kobject *parent, + struct net *net, + int clid) +{ + struct rpc_sysfs_client *p; + + p = kzalloc(sizeof(*p), GFP_KERNEL); + if (p) { + p->net = net; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, &rpc_sysfs_client_type, + parent, "clnt-%d", clid) == 0) + return p; + kobject_put(&p->kobject); + } + return NULL; +} + +static struct rpc_sysfs_xprt_switch * +rpc_sysfs_xprt_switch_alloc(struct kobject *parent, + struct rpc_xprt_switch *xprt_switch, + struct net *net, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt_switch *p; + + p = kzalloc(sizeof(*p), gfp_flags); + if (p) { + p->net = net; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, + &rpc_sysfs_xprt_switch_type, + parent, "switch-%d", + xprt_switch->xps_id) == 0) + return p; + kobject_put(&p->kobject); + } + return NULL; +} + +static struct rpc_sysfs_xprt *rpc_sysfs_xprt_alloc(struct kobject *parent, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt *p; + + p = kzalloc(sizeof(*p), gfp_flags); + if (!p) + goto out; + p->kobject.kset = rpc_sunrpc_kset; + if (kobject_init_and_add(&p->kobject, &rpc_sysfs_xprt_type, + parent, "xprt-%d-%s", xprt->id, + xprt->address_strings[RPC_DISPLAY_PROTO]) == 0) + return p; + kobject_put(&p->kobject); +out: + return NULL; +} + +void rpc_sysfs_client_setup(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xprt_switch, + struct net *net) +{ + struct rpc_sysfs_client *rpc_client; + struct rpc_sysfs_xprt_switch *xswitch = + (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs; + + if (!xswitch) + return; + + rpc_client = rpc_sysfs_client_alloc(rpc_sunrpc_client_kobj, + net, clnt->cl_clid); + if (rpc_client) { + char name[] = "switch"; + int ret; + + clnt->cl_sysfs = rpc_client; + rpc_client->clnt = clnt; + rpc_client->xprt_switch = xprt_switch; + kobject_uevent(&rpc_client->kobject, KOBJ_ADD); + ret = sysfs_create_link_nowarn(&rpc_client->kobject, + &xswitch->kobject, name); + if (ret) + pr_warn("can't create link to %s in sysfs (%d)\n", + name, ret); + } +} + +void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt_switch *rpc_xprt_switch; + struct net *net; + + if (xprt_switch->xps_net) + net = xprt_switch->xps_net; + else + net = 
xprt->xprt_net; + rpc_xprt_switch = + rpc_sysfs_xprt_switch_alloc(rpc_sunrpc_xprt_switch_kobj, + xprt_switch, net, gfp_flags); + if (rpc_xprt_switch) { + xprt_switch->xps_sysfs = rpc_xprt_switch; + rpc_xprt_switch->xprt_switch = xprt_switch; + rpc_xprt_switch->xprt = xprt; + kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_ADD); + } else { + xprt_switch->xps_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_sysfs_xprt *rpc_xprt; + struct rpc_sysfs_xprt_switch *switch_obj = + (struct rpc_sysfs_xprt_switch *)xprt_switch->xps_sysfs; + + if (!switch_obj) + return; + + rpc_xprt = rpc_sysfs_xprt_alloc(&switch_obj->kobject, xprt, gfp_flags); + if (rpc_xprt) { + xprt->xprt_sysfs = rpc_xprt; + rpc_xprt->xprt = xprt; + rpc_xprt->xprt_switch = xprt_switch; + kobject_uevent(&rpc_xprt->kobject, KOBJ_ADD); + } +} + +void rpc_sysfs_client_destroy(struct rpc_clnt *clnt) +{ + struct rpc_sysfs_client *rpc_client = clnt->cl_sysfs; + + if (rpc_client) { + char name[] = "switch"; + + sysfs_remove_link(&rpc_client->kobject, name); + kobject_uevent(&rpc_client->kobject, KOBJ_REMOVE); + kobject_del(&rpc_client->kobject); + kobject_put(&rpc_client->kobject); + clnt->cl_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt_switch) +{ + struct rpc_sysfs_xprt_switch *rpc_xprt_switch = xprt_switch->xps_sysfs; + + if (rpc_xprt_switch) { + kobject_uevent(&rpc_xprt_switch->kobject, KOBJ_REMOVE); + kobject_del(&rpc_xprt_switch->kobject); + kobject_put(&rpc_xprt_switch->kobject); + xprt_switch->xps_sysfs = NULL; + } +} + +void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt) +{ + struct rpc_sysfs_xprt *rpc_xprt = xprt->xprt_sysfs; + + if (rpc_xprt) { + kobject_uevent(&rpc_xprt->kobject, KOBJ_REMOVE); + kobject_del(&rpc_xprt->kobject); + kobject_put(&rpc_xprt->kobject); + xprt->xprt_sysfs = NULL; + } +} diff --git a/net/sunrpc/sysfs.h b/net/sunrpc/sysfs.h new file mode 100644 index 000000000..6620cebd1 --- /dev/null +++ b/net/sunrpc/sysfs.h @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Anna Schumaker + */ +#ifndef __SUNRPC_SYSFS_H +#define __SUNRPC_SYSFS_H + +struct rpc_sysfs_client { + struct kobject kobject; + struct net *net; + struct rpc_clnt *clnt; + struct rpc_xprt_switch *xprt_switch; +}; + +struct rpc_sysfs_xprt_switch { + struct kobject kobject; + struct net *net; + struct rpc_xprt_switch *xprt_switch; + struct rpc_xprt *xprt; +}; + +struct rpc_sysfs_xprt { + struct kobject kobject; + struct rpc_xprt *xprt; + struct rpc_xprt_switch *xprt_switch; +}; + +int rpc_sysfs_init(void); +void rpc_sysfs_exit(void); + +void rpc_sysfs_client_setup(struct rpc_clnt *clnt, + struct rpc_xprt_switch *xprt_switch, + struct net *net); +void rpc_sysfs_client_destroy(struct rpc_clnt *clnt); +void rpc_sysfs_xprt_switch_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, gfp_t gfp_flags); +void rpc_sysfs_xprt_switch_destroy(struct rpc_xprt_switch *xprt); +void rpc_sysfs_xprt_setup(struct rpc_xprt_switch *xprt_switch, + struct rpc_xprt *xprt, gfp_t gfp_flags); +void rpc_sysfs_xprt_destroy(struct rpc_xprt *xprt); + +#endif diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c new file mode 100644 index 000000000..81ae35b37 --- /dev/null +++ b/net/sunrpc/timer.c @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/timer.c + * + * Estimate RPC request round trip time. 
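/*
 * A minimal stand-alone sketch of the scaled arithmetic implemented by
 * rpc_update_rtt() and rpc_calc_rto() below (names here are illustrative;
 * the kernel keeps one srtt/sdrtt pair per request class): srtt stores
 * 8 * SRTT and sdrtt stores 4 * MDEV, so the estimated timeout comes out
 * as roughly SRTT + 4 * MDEV.
 */
struct rtt_sketch {
	long srtt;				/* 8 * smoothed RTT */
	long sdrtt;				/* 4 * mean deviation */
};

static void rtt_sketch_update(struct rtt_sketch *rt, long sample)
{
	long err = sample - (rt->srtt >> 3);

	rt->srtt += err;			/* SRTT += err / 8 */
	if (err < 0)
		err = -err;
	rt->sdrtt += err - (rt->sdrtt >> 2);	/* MDEV += (|err| - MDEV) / 4 */
}

static long rtt_sketch_rto(const struct rtt_sketch *rt)
{
	return ((rt->srtt + 7) >> 3) + rt->sdrtt;	/* ~ SRTT + 4 * MDEV */
}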
+ * + * Based on packet round-trip and variance estimator algorithms described + * in appendix A of "Congestion Avoidance and Control" by Van Jacobson + * and Michael J. Karels (ACM Computer Communication Review; Proceedings + * of the Sigcomm '88 Symposium in Stanford, CA, August, 1988). + * + * This RTT estimator is used only for RPC over datagram protocols. + * + * Copyright (C) 2002 Trond Myklebust + */ + +#include + +#include +#include +#include + +#include + +#define RPC_RTO_MAX (60*HZ) +#define RPC_RTO_INIT (HZ/5) +#define RPC_RTO_MIN (HZ/10) + +/** + * rpc_init_rtt - Initialize an RPC RTT estimator context + * @rt: context to initialize + * @timeo: initial timeout value, in jiffies + * + */ +void rpc_init_rtt(struct rpc_rtt *rt, unsigned long timeo) +{ + unsigned long init = 0; + unsigned int i; + + rt->timeo = timeo; + + if (timeo > RPC_RTO_INIT) + init = (timeo - RPC_RTO_INIT) << 3; + for (i = 0; i < 5; i++) { + rt->srtt[i] = init; + rt->sdrtt[i] = RPC_RTO_INIT; + rt->ntimeouts[i] = 0; + } +} +EXPORT_SYMBOL_GPL(rpc_init_rtt); + +/** + * rpc_update_rtt - Update an RPC RTT estimator context + * @rt: context to update + * @timer: timer array index (request type) + * @m: recent actual RTT, in jiffies + * + * NB: When computing the smoothed RTT and standard deviation, + * be careful not to produce negative intermediate results. + */ +void rpc_update_rtt(struct rpc_rtt *rt, unsigned int timer, long m) +{ + long *srtt, *sdrtt; + + if (timer-- == 0) + return; + + /* jiffies wrapped; ignore this one */ + if (m < 0) + return; + + if (m == 0) + m = 1L; + + srtt = (long *)&rt->srtt[timer]; + m -= *srtt >> 3; + *srtt += m; + + if (m < 0) + m = -m; + + sdrtt = (long *)&rt->sdrtt[timer]; + m -= *sdrtt >> 2; + *sdrtt += m; + + /* Set lower bound on the variance */ + if (*sdrtt < RPC_RTO_MIN) + *sdrtt = RPC_RTO_MIN; +} +EXPORT_SYMBOL_GPL(rpc_update_rtt); + +/** + * rpc_calc_rto - Provide an estimated timeout value + * @rt: context to use for calculation + * @timer: timer array index (request type) + * + * Estimate RTO for an NFS RPC sent via an unreliable datagram. Use + * the mean and mean deviation of RTT for the appropriate type of RPC + * for frequently issued RPCs, and a fixed default for the others. + * + * The justification for doing "other" this way is that these RPCs + * happen so infrequently that timer estimation would probably be + * stale. Also, since many of these RPCs are non-idempotent, a + * conservative timeout is desired. + * + * getattr, lookup, + * read, write, commit - A+4D + * other - timeo + */ +unsigned long rpc_calc_rto(struct rpc_rtt *rt, unsigned int timer) +{ + unsigned long res; + + if (timer-- == 0) + return rt->timeo; + + res = ((rt->srtt[timer] + 7) >> 3) + rt->sdrtt[timer]; + if (res > RPC_RTO_MAX) + res = RPC_RTO_MAX; + + return res; +} +EXPORT_SYMBOL_GPL(rpc_calc_rto); diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c new file mode 100644 index 000000000..336a7c783 --- /dev/null +++ b/net/sunrpc/xdr.c @@ -0,0 +1,2272 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/xdr.c + * + * Generic XDR support. 
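 *
 * For reference: XDR (RFC 4506) serializes data in big-endian 32-bit
 * units, and a variable-length opaque is a length word followed by the
 * bytes, zero-padded to the next 4-byte boundary.  Encoding the 5-byte
 * opaque "hello" with xdr_encode_opaque() below therefore emits 12 bytes:
 *
 *	00 00 00 05  68 65 6c 6c  6f 00 00 00
 *	  length=5    h  e  l  l   o  (3 pad)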
+ * + * Copyright (C) 1995, 1996 Olaf Kirch + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static void _copy_to_pages(struct page **, size_t, const char *, size_t); + + +/* + * XDR functions for basic NFS types + */ +__be32 * +xdr_encode_netobj(__be32 *p, const struct xdr_netobj *obj) +{ + unsigned int quadlen = XDR_QUADLEN(obj->len); + + p[quadlen] = 0; /* zero trailing bytes */ + *p++ = cpu_to_be32(obj->len); + memcpy(p, obj->data, obj->len); + return p + XDR_QUADLEN(obj->len); +} +EXPORT_SYMBOL_GPL(xdr_encode_netobj); + +__be32 * +xdr_decode_netobj(__be32 *p, struct xdr_netobj *obj) +{ + unsigned int len; + + if ((len = be32_to_cpu(*p++)) > XDR_MAX_NETOBJ) + return NULL; + obj->len = len; + obj->data = (u8 *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL_GPL(xdr_decode_netobj); + +/** + * xdr_encode_opaque_fixed - Encode fixed length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. + * + * Copy the array of data of length nbytes at ptr to the XDR buffer + * at position p, then align to the next 32-bit boundary by padding + * with zero bytes (see RFC1832). + * Note: if ptr is NULL, only the padding is performed. + * + * Returns the updated current XDR buffer position + * + */ +__be32 *xdr_encode_opaque_fixed(__be32 *p, const void *ptr, unsigned int nbytes) +{ + if (likely(nbytes != 0)) { + unsigned int quadlen = XDR_QUADLEN(nbytes); + unsigned int padding = (quadlen << 2) - nbytes; + + if (ptr != NULL) + memcpy(p, ptr, nbytes); + if (padding != 0) + memset((char *)p + nbytes, 0, padding); + p += quadlen; + } + return p; +} +EXPORT_SYMBOL_GPL(xdr_encode_opaque_fixed); + +/** + * xdr_encode_opaque - Encode variable length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. 
+ * + * Returns the updated current XDR buffer position + */ +__be32 *xdr_encode_opaque(__be32 *p, const void *ptr, unsigned int nbytes) +{ + *p++ = cpu_to_be32(nbytes); + return xdr_encode_opaque_fixed(p, ptr, nbytes); +} +EXPORT_SYMBOL_GPL(xdr_encode_opaque); + +__be32 * +xdr_encode_string(__be32 *p, const char *string) +{ + return xdr_encode_array(p, string, strlen(string)); +} +EXPORT_SYMBOL_GPL(xdr_encode_string); + +__be32 * +xdr_decode_string_inplace(__be32 *p, char **sp, + unsigned int *lenp, unsigned int maxlen) +{ + u32 len; + + len = be32_to_cpu(*p++); + if (len > maxlen) + return NULL; + *lenp = len; + *sp = (char *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL_GPL(xdr_decode_string_inplace); + +/** + * xdr_terminate_string - '\0'-terminate a string residing in an xdr_buf + * @buf: XDR buffer where string resides + * @len: length of string, in bytes + * + */ +void xdr_terminate_string(const struct xdr_buf *buf, const u32 len) +{ + char *kaddr; + + kaddr = kmap_atomic(buf->pages[0]); + kaddr[buf->page_base + len] = '\0'; + kunmap_atomic(kaddr); +} +EXPORT_SYMBOL_GPL(xdr_terminate_string); + +size_t xdr_buf_pagecount(const struct xdr_buf *buf) +{ + if (!buf->page_len) + return 0; + return (buf->page_base + buf->page_len + PAGE_SIZE - 1) >> PAGE_SHIFT; +} + +int +xdr_alloc_bvec(struct xdr_buf *buf, gfp_t gfp) +{ + size_t i, n = xdr_buf_pagecount(buf); + + if (n != 0 && buf->bvec == NULL) { + buf->bvec = kmalloc_array(n, sizeof(buf->bvec[0]), gfp); + if (!buf->bvec) + return -ENOMEM; + for (i = 0; i < n; i++) { + buf->bvec[i].bv_page = buf->pages[i]; + buf->bvec[i].bv_len = PAGE_SIZE; + buf->bvec[i].bv_offset = 0; + } + } + return 0; +} + +void +xdr_free_bvec(struct xdr_buf *buf) +{ + kfree(buf->bvec); + buf->bvec = NULL; +} + +/** + * xdr_inline_pages - Prepare receive buffer for a large reply + * @xdr: xdr_buf into which reply will be placed + * @offset: expected offset where data payload will start, in bytes + * @pages: vector of struct page pointers + * @base: offset in first page where receive should start, in bytes + * @len: expected size of the upper layer data payload, in bytes + * + */ +void +xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, + struct page **pages, unsigned int base, unsigned int len) +{ + struct kvec *head = xdr->head; + struct kvec *tail = xdr->tail; + char *buf = (char *)head->iov_base; + unsigned int buflen = head->iov_len; + + head->iov_len = offset; + + xdr->pages = pages; + xdr->page_base = base; + xdr->page_len = len; + + tail->iov_base = buf + offset; + tail->iov_len = buflen - offset; + xdr->buflen += len; +} +EXPORT_SYMBOL_GPL(xdr_inline_pages); + +/* + * Helper routines for doing 'memmove' like operations on a struct xdr_buf + */ + +/** + * _shift_data_left_pages + * @pages: vector of pages containing both the source and dest memory area. + * @pgto_base: page vector address of destination + * @pgfrom_base: page vector address of source + * @len: number of bytes to copy + * + * Note: the addresses pgto_base and pgfrom_base are both calculated in + * the same way: + * if a memory area starts at byte 'base' in page 'pages[i]', + * then its address is given as (i << PAGE_CACHE_SHIFT) + base + * Alse note: pgto_base must be < pgfrom_base, but the memory areas + * they point to may overlap. 
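 *
 * Worked example: with 4096-byte pages, page-vector address 5000 refers
 * to pages[5000 >> PAGE_SHIFT] = pages[1] at offset 5000 & ~PAGE_MASK =
 * 904, which is how the helpers below split pgto_base and pgfrom_base
 * into a page index plus an in-page offset.  (PAGE_CACHE_SHIFT above is
 * the older name for what is now PAGE_SHIFT.)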
+ */ +static void +_shift_data_left_pages(struct page **pages, size_t pgto_base, + size_t pgfrom_base, size_t len) +{ + struct page **pgfrom, **pgto; + char *vfrom, *vto; + size_t copy; + + BUG_ON(pgfrom_base <= pgto_base); + + if (!len) + return; + + pgto = pages + (pgto_base >> PAGE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_SHIFT); + + pgto_base &= ~PAGE_MASK; + pgfrom_base &= ~PAGE_MASK; + + do { + if (pgto_base >= PAGE_SIZE) { + pgto_base = 0; + pgto++; + } + if (pgfrom_base >= PAGE_SIZE){ + pgfrom_base = 0; + pgfrom++; + } + + copy = len; + if (copy > (PAGE_SIZE - pgto_base)) + copy = PAGE_SIZE - pgto_base; + if (copy > (PAGE_SIZE - pgfrom_base)) + copy = PAGE_SIZE - pgfrom_base; + + vto = kmap_atomic(*pgto); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); + flush_dcache_page(*pgto); + kunmap_atomic(vto); + + pgto_base += copy; + pgfrom_base += copy; + + } while ((len -= copy) != 0); +} + +/** + * _shift_data_right_pages + * @pages: vector of pages containing both the source and dest memory area. + * @pgto_base: page vector address of destination + * @pgfrom_base: page vector address of source + * @len: number of bytes to copy + * + * Note: the addresses pgto_base and pgfrom_base are both calculated in + * the same way: + * if a memory area starts at byte 'base' in page 'pages[i]', + * then its address is given as (i << PAGE_SHIFT) + base + * Also note: pgfrom_base must be < pgto_base, but the memory areas + * they point to may overlap. + */ +static void +_shift_data_right_pages(struct page **pages, size_t pgto_base, + size_t pgfrom_base, size_t len) +{ + struct page **pgfrom, **pgto; + char *vfrom, *vto; + size_t copy; + + BUG_ON(pgto_base <= pgfrom_base); + + if (!len) + return; + + pgto_base += len; + pgfrom_base += len; + + pgto = pages + (pgto_base >> PAGE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_SHIFT); + + pgto_base &= ~PAGE_MASK; + pgfrom_base &= ~PAGE_MASK; + + do { + /* Are any pointers crossing a page boundary? */ + if (pgto_base == 0) { + pgto_base = PAGE_SIZE; + pgto--; + } + if (pgfrom_base == 0) { + pgfrom_base = PAGE_SIZE; + pgfrom--; + } + + copy = len; + if (copy > pgto_base) + copy = pgto_base; + if (copy > pgfrom_base) + copy = pgfrom_base; + pgto_base -= copy; + pgfrom_base -= copy; + + vto = kmap_atomic(*pgto); + if (*pgto != *pgfrom) { + vfrom = kmap_atomic(*pgfrom); + memcpy(vto + pgto_base, vfrom + pgfrom_base, copy); + kunmap_atomic(vfrom); + } else + memmove(vto + pgto_base, vto + pgfrom_base, copy); + flush_dcache_page(*pgto); + kunmap_atomic(vto); + + } while ((len -= copy) != 0); +} + +/** + * _copy_to_pages + * @pages: array of pages + * @pgbase: page vector address of destination + * @p: pointer to source data + * @len: length + * + * Copies data from an arbitrary memory location into an array of pages + * The copy is assumed to be non-overlapping. 
+ */ +static void +_copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) +{ + struct page **pgto; + char *vto; + size_t copy; + + if (!len) + return; + + pgto = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + for (;;) { + copy = PAGE_SIZE - pgbase; + if (copy > len) + copy = len; + + vto = kmap_atomic(*pgto); + memcpy(vto + pgbase, p, copy); + kunmap_atomic(vto); + + len -= copy; + if (len == 0) + break; + + pgbase += copy; + if (pgbase == PAGE_SIZE) { + flush_dcache_page(*pgto); + pgbase = 0; + pgto++; + } + p += copy; + } + flush_dcache_page(*pgto); +} + +/** + * _copy_from_pages + * @p: pointer to destination + * @pages: array of pages + * @pgbase: offset of source data + * @len: length + * + * Copies data into an arbitrary memory location from an array of pages + * The copy is assumed to be non-overlapping. + */ +void +_copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) +{ + struct page **pgfrom; + char *vfrom; + size_t copy; + + if (!len) + return; + + pgfrom = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + do { + copy = PAGE_SIZE - pgbase; + if (copy > len) + copy = len; + + vfrom = kmap_atomic(*pgfrom); + memcpy(p, vfrom + pgbase, copy); + kunmap_atomic(vfrom); + + pgbase += copy; + if (pgbase == PAGE_SIZE) { + pgbase = 0; + pgfrom++; + } + p += copy; + + } while ((len -= copy) != 0); +} +EXPORT_SYMBOL_GPL(_copy_from_pages); + +static void xdr_buf_iov_zero(const struct kvec *iov, unsigned int base, + unsigned int len) +{ + if (base >= iov->iov_len) + return; + if (len > iov->iov_len - base) + len = iov->iov_len - base; + memset(iov->iov_base + base, 0, len); +} + +/** + * xdr_buf_pages_zero + * @buf: xdr_buf + * @pgbase: beginning offset + * @len: length + */ +static void xdr_buf_pages_zero(const struct xdr_buf *buf, unsigned int pgbase, + unsigned int len) +{ + struct page **pages = buf->pages; + struct page **page; + char *vpage; + unsigned int zero; + + if (!len) + return; + if (pgbase >= buf->page_len) { + xdr_buf_iov_zero(buf->tail, pgbase - buf->page_len, len); + return; + } + if (pgbase + len > buf->page_len) { + xdr_buf_iov_zero(buf->tail, 0, pgbase + len - buf->page_len); + len = buf->page_len - pgbase; + } + + pgbase += buf->page_base; + + page = pages + (pgbase >> PAGE_SHIFT); + pgbase &= ~PAGE_MASK; + + do { + zero = PAGE_SIZE - pgbase; + if (zero > len) + zero = len; + + vpage = kmap_atomic(*page); + memset(vpage + pgbase, 0, zero); + kunmap_atomic(vpage); + + flush_dcache_page(*page); + pgbase = 0; + page++; + + } while ((len -= zero) != 0); +} + +static unsigned int xdr_buf_pages_fill_sparse(const struct xdr_buf *buf, + unsigned int buflen, gfp_t gfp) +{ + unsigned int i, npages, pagelen; + + if (!(buf->flags & XDRBUF_SPARSE_PAGES)) + return buflen; + if (buflen <= buf->head->iov_len) + return buflen; + pagelen = buflen - buf->head->iov_len; + if (pagelen > buf->page_len) + pagelen = buf->page_len; + npages = (pagelen + buf->page_base + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + if (!buf->pages[i]) + continue; + buf->pages[i] = alloc_page(gfp); + if (likely(buf->pages[i])) + continue; + buflen -= pagelen; + pagelen = i << PAGE_SHIFT; + if (pagelen > buf->page_base) + buflen += pagelen - buf->page_base; + break; + } + return buflen; +} + +static void xdr_buf_try_expand(struct xdr_buf *buf, unsigned int len) +{ + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + unsigned int sum = head->iov_len + buf->page_len + tail->iov_len; + unsigned int free_space, newlen; + 
+ if (sum > buf->len) { + free_space = min_t(unsigned int, sum - buf->len, len); + newlen = xdr_buf_pages_fill_sparse(buf, buf->len + free_space, + GFP_KERNEL); + free_space = newlen - buf->len; + buf->len = newlen; + len -= free_space; + if (!len) + return; + } + + if (buf->buflen > sum) { + /* Expand the tail buffer */ + free_space = min_t(unsigned int, buf->buflen - sum, len); + tail->iov_len += free_space; + buf->len += free_space; + } +} + +static void xdr_buf_tail_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + + if (to >= tail->iov_len) + return; + if (len + to > tail->iov_len) + len = tail->iov_len - to; + memmove(tail->iov_base + to, tail->iov_base + base, len); +} + +static void xdr_buf_pages_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + unsigned int pglen = 0; + unsigned int talen = 0, tato = 0; + + if (base >= buf->page_len) + return; + if (len > buf->page_len - base) + len = buf->page_len - base; + if (to >= buf->page_len) { + tato = to - buf->page_len; + if (tail->iov_len >= len + tato) + talen = len; + else if (tail->iov_len > tato) + talen = tail->iov_len - tato; + } else if (len + to >= buf->page_len) { + pglen = buf->page_len - to; + talen = len - pglen; + if (talen > tail->iov_len) + talen = tail->iov_len; + } else + pglen = len; + + _copy_from_pages(tail->iov_base + tato, buf->pages, + buf->page_base + base + pglen, talen); + _shift_data_right_pages(buf->pages, buf->page_base + to, + buf->page_base + base, pglen); +} + +static void xdr_buf_head_copy_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + const struct kvec *tail = buf->tail; + unsigned int to = base + shift; + unsigned int pglen = 0, pgto = 0; + unsigned int talen = 0, tato = 0; + + if (base >= head->iov_len) + return; + if (len > head->iov_len - base) + len = head->iov_len - base; + if (to >= buf->page_len + head->iov_len) { + tato = to - buf->page_len - head->iov_len; + talen = len; + } else if (to >= head->iov_len) { + pgto = to - head->iov_len; + pglen = len; + if (pgto + pglen > buf->page_len) { + talen = pgto + pglen - buf->page_len; + pglen -= talen; + } + } else { + pglen = len - to; + if (pglen > buf->page_len) { + talen = pglen - buf->page_len; + pglen = buf->page_len; + } + } + + len -= talen; + base += len; + if (talen + tato > tail->iov_len) + talen = tail->iov_len > tato ? 
tail->iov_len - tato : 0; + memcpy(tail->iov_base + tato, head->iov_base + base, talen); + + len -= pglen; + base -= pglen; + _copy_to_pages(buf->pages, buf->page_base + pgto, head->iov_base + base, + pglen); + + base -= len; + memmove(head->iov_base + to, head->iov_base + base, len); +} + +static void xdr_buf_tail_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *tail = buf->tail; + + if (base >= tail->iov_len || !shift || !len) + return; + xdr_buf_tail_copy_right(buf, base, len, shift); +} + +static void xdr_buf_pages_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + if (base >= buf->page_len) { + xdr_buf_tail_shift_right(buf, base - buf->page_len, len, shift); + return; + } + if (base + len > buf->page_len) + xdr_buf_tail_shift_right(buf, 0, base + len - buf->page_len, + shift); + xdr_buf_pages_copy_right(buf, base, len, shift); +} + +static void xdr_buf_head_shift_right(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + + if (!shift) + return; + if (base >= head->iov_len) { + xdr_buf_pages_shift_right(buf, head->iov_len - base, len, + shift); + return; + } + if (base + len > head->iov_len) + xdr_buf_pages_shift_right(buf, 0, base + len - head->iov_len, + shift); + xdr_buf_head_copy_right(buf, base, len, shift); +} + +static void xdr_buf_tail_copy_left(const struct xdr_buf *buf, unsigned int base, + unsigned int len, unsigned int shift) +{ + const struct kvec *tail = buf->tail; + + if (base >= tail->iov_len) + return; + if (len > tail->iov_len - base) + len = tail->iov_len - base; + /* Shift data into head */ + if (shift > buf->page_len + base) { + const struct kvec *head = buf->head; + unsigned int hdto = + head->iov_len + buf->page_len + base - shift; + unsigned int hdlen = len; + + if (WARN_ONCE(shift > head->iov_len + buf->page_len + base, + "SUNRPC: Misaligned data.\n")) + return; + if (hdto + hdlen > head->iov_len) + hdlen = head->iov_len - hdto; + memcpy(head->iov_base + hdto, tail->iov_base + base, hdlen); + base += hdlen; + len -= hdlen; + if (!len) + return; + } + /* Shift data into pages */ + if (shift > base) { + unsigned int pgto = buf->page_len + base - shift; + unsigned int pglen = len; + + if (pgto + pglen > buf->page_len) + pglen = buf->page_len - pgto; + _copy_to_pages(buf->pages, buf->page_base + pgto, + tail->iov_base + base, pglen); + base += pglen; + len -= pglen; + if (!len) + return; + } + memmove(tail->iov_base + base - shift, tail->iov_base + base, len); +} + +static void xdr_buf_pages_copy_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + unsigned int pgto; + + if (base >= buf->page_len) + return; + if (len > buf->page_len - base) + len = buf->page_len - base; + /* Shift data into head */ + if (shift > base) { + const struct kvec *head = buf->head; + unsigned int hdto = head->iov_len + base - shift; + unsigned int hdlen = len; + + if (WARN_ONCE(shift > head->iov_len + base, + "SUNRPC: Misaligned data.\n")) + return; + if (hdto + hdlen > head->iov_len) + hdlen = head->iov_len - hdto; + _copy_from_pages(head->iov_base + hdto, buf->pages, + buf->page_base + base, hdlen); + base += hdlen; + len -= hdlen; + if (!len) + return; + } + pgto = base - shift; + _shift_data_left_pages(buf->pages, buf->page_base + pgto, + buf->page_base + base, len); +} + +static void 
xdr_buf_tail_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + xdr_buf_tail_copy_left(buf, base, len, shift); +} + +static void xdr_buf_pages_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + if (!shift || !len) + return; + if (base >= buf->page_len) { + xdr_buf_tail_shift_left(buf, base - buf->page_len, len, shift); + return; + } + xdr_buf_pages_copy_left(buf, base, len, shift); + len += base; + if (len <= buf->page_len) + return; + xdr_buf_tail_copy_left(buf, 0, len - buf->page_len, shift); +} + +static void xdr_buf_head_shift_left(const struct xdr_buf *buf, + unsigned int base, unsigned int len, + unsigned int shift) +{ + const struct kvec *head = buf->head; + unsigned int bytes; + + if (!shift || !len) + return; + + if (shift > base) { + bytes = (shift - base); + if (bytes >= len) + return; + base += bytes; + len -= bytes; + } + + if (base < head->iov_len) { + bytes = min_t(unsigned int, len, head->iov_len - base); + memmove(head->iov_base + (base - shift), + head->iov_base + base, bytes); + base += bytes; + len -= bytes; + } + xdr_buf_pages_shift_left(buf, base - head->iov_len, len, shift); +} + +/** + * xdr_shrink_bufhead + * @buf: xdr_buf + * @len: new length of buf->head[0] + * + * Shrinks XDR buffer's header kvec buf->head[0], setting it to + * 'len' bytes. The extra data is not lost, but is instead + * moved into the inlined pages and/or the tail. + */ +static unsigned int xdr_shrink_bufhead(struct xdr_buf *buf, unsigned int len) +{ + struct kvec *head = buf->head; + unsigned int shift, buflen = max(buf->len, len); + + WARN_ON_ONCE(len > head->iov_len); + if (head->iov_len > buflen) { + buf->buflen -= head->iov_len - buflen; + head->iov_len = buflen; + } + if (len >= head->iov_len) + return 0; + shift = head->iov_len - len; + xdr_buf_try_expand(buf, shift); + xdr_buf_head_shift_right(buf, len, buflen - len, shift); + head->iov_len = len; + buf->buflen -= shift; + buf->len -= shift; + return shift; +} + +/** + * xdr_shrink_pagelen - shrinks buf->pages to @len bytes + * @buf: xdr_buf + * @len: new page buffer length + * + * The extra data is not lost, but is instead moved into buf->tail. + * Returns the actual number of bytes moved. + */ +static unsigned int xdr_shrink_pagelen(struct xdr_buf *buf, unsigned int len) +{ + unsigned int shift, buflen = buf->len - buf->head->iov_len; + + WARN_ON_ONCE(len > buf->page_len); + if (buf->head->iov_len >= buf->len || len > buflen) + buflen = len; + if (buf->page_len > buflen) { + buf->buflen -= buf->page_len - buflen; + buf->page_len = buflen; + } + if (len >= buf->page_len) + return 0; + shift = buf->page_len - len; + xdr_buf_try_expand(buf, shift); + xdr_buf_pages_shift_right(buf, len, buflen - len, shift); + buf->page_len = len; + buf->len -= shift; + buf->buflen -= shift; + return shift; +} + +void +xdr_shift_buf(struct xdr_buf *buf, size_t len) +{ + xdr_shrink_bufhead(buf, buf->head->iov_len - len); +} +EXPORT_SYMBOL_GPL(xdr_shift_buf); + +/** + * xdr_stream_pos - Return the current offset from the start of the xdr_stream + * @xdr: pointer to struct xdr_stream + */ +unsigned int xdr_stream_pos(const struct xdr_stream *xdr) +{ + return (unsigned int)(XDR_QUADLEN(xdr->buf->len) - xdr->nwords) << 2; +} +EXPORT_SYMBOL_GPL(xdr_stream_pos); + +static void xdr_stream_set_pos(struct xdr_stream *xdr, unsigned int pos) +{ + unsigned int blen = xdr->buf->len; + + xdr->nwords = blen > pos ? 
XDR_QUADLEN(blen) - XDR_QUADLEN(pos) : 0; +} + +static void xdr_stream_page_set_pos(struct xdr_stream *xdr, unsigned int pos) +{ + xdr_stream_set_pos(xdr, pos + xdr->buf->head[0].iov_len); +} + +/** + * xdr_page_pos - Return the current offset from the start of the xdr pages + * @xdr: pointer to struct xdr_stream + */ +unsigned int xdr_page_pos(const struct xdr_stream *xdr) +{ + unsigned int pos = xdr_stream_pos(xdr); + + WARN_ON(pos < xdr->buf->head[0].iov_len); + return pos - xdr->buf->head[0].iov_len; +} +EXPORT_SYMBOL_GPL(xdr_page_pos); + +/** + * xdr_init_encode - Initialize a struct xdr_stream for sending data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer in which to encode data + * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging + * + * Note: at the moment the RPC client only passes the length of our + * scratch buffer in the xdr_buf's header kvec. Previously this + * meant we needed to call xdr_adjust_iovec() after encoding the + * data. With the new scheme, the xdr_stream manages the details + * of the buffer length, and takes care of adjusting the kvec + * length for us. + */ +void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) +{ + struct kvec *iov = buf->head; + int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; + + xdr_reset_scratch_buffer(xdr); + BUG_ON(scratch_len < 0); + xdr->buf = buf; + xdr->iov = iov; + xdr->p = (__be32 *)((char *)iov->iov_base + iov->iov_len); + xdr->end = (__be32 *)((char *)iov->iov_base + scratch_len); + BUG_ON(iov->iov_len > scratch_len); + + if (p != xdr->p && p != NULL) { + size_t len; + + BUG_ON(p < xdr->p || p > xdr->end); + len = (char *)p - (char *)xdr->p; + xdr->p = p; + buf->len += len; + iov->iov_len += len; + } + xdr->rqst = rqst; +} +EXPORT_SYMBOL_GPL(xdr_init_encode); + +/** + * xdr_init_encode_pages - Initialize an xdr_stream for encoding into pages + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer into which to encode data + * @pages: list of pages to decode into + * @rqst: pointer to controlling rpc_rqst, for debugging + * + */ +void xdr_init_encode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, + struct page **pages, struct rpc_rqst *rqst) +{ + xdr_reset_scratch_buffer(xdr); + + xdr->buf = buf; + xdr->page_ptr = pages; + xdr->iov = NULL; + xdr->p = page_address(*pages); + xdr->end = (void *)xdr->p + min_t(u32, buf->buflen, PAGE_SIZE); + xdr->rqst = rqst; +} +EXPORT_SYMBOL_GPL(xdr_init_encode_pages); + +/** + * __xdr_commit_encode - Ensure all data is written to buffer + * @xdr: pointer to xdr_stream + * + * We handle encoding across page boundaries by giving the caller a + * temporary location to write to, then later copying the data into + * place; xdr_commit_encode does that copying. + * + * Normally the caller doesn't need to call this directly, as the + * following xdr_reserve_space will do it. But an explicit call may be + * required at the end of encoding, or any other time when the xdr_buf + * data might be read. 
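+ *
+ * As an illustrative sketch only (encode_my_results() and
+ * checksum_xdr_buf() are placeholder names, not helpers defined in
+ * this file), a caller that reads the xdr_buf immediately after
+ * encoding would do:
+ *
+ *	encode_my_results(&xdr, resp);
+ *	xdr_commit_encode(&xdr);
+ *	checksum_xdr_buf(xdr.buf);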
+ */ +void __xdr_commit_encode(struct xdr_stream *xdr) +{ + size_t shift = xdr->scratch.iov_len; + void *page; + + page = page_address(*xdr->page_ptr); + memcpy(xdr->scratch.iov_base, page, shift); + memmove(page, page + shift, (void *)xdr->p - page); + xdr_reset_scratch_buffer(xdr); +} +EXPORT_SYMBOL_GPL(__xdr_commit_encode); + +/* + * The buffer space to be reserved crosses the boundary between + * xdr->buf->head and xdr->buf->pages, or between two pages + * in xdr->buf->pages. + */ +static noinline __be32 *xdr_get_next_encode_buffer(struct xdr_stream *xdr, + size_t nbytes) +{ + int space_left; + int frag1bytes, frag2bytes; + void *p; + + if (nbytes > PAGE_SIZE) + goto out_overflow; /* Bigger buffers require special handling */ + if (xdr->buf->len + nbytes > xdr->buf->buflen) + goto out_overflow; /* Sorry, we're totally out of space */ + frag1bytes = (xdr->end - xdr->p) << 2; + frag2bytes = nbytes - frag1bytes; + if (xdr->iov) + xdr->iov->iov_len += frag1bytes; + else + xdr->buf->page_len += frag1bytes; + xdr->page_ptr++; + xdr->iov = NULL; + + /* + * If the last encode didn't end exactly on a page boundary, the + * next one will straddle boundaries. Encode into the next + * page, then copy it back later in xdr_commit_encode. We use + * the "scratch" iov to track any temporarily unused fragment of + * space at the end of the previous buffer: + */ + xdr_set_scratch_buffer(xdr, xdr->p, frag1bytes); + + /* + * xdr->p is where the next encode will start after + * xdr_commit_encode() has shifted this one back: + */ + p = page_address(*xdr->page_ptr); + xdr->p = p + frag2bytes; + space_left = xdr->buf->buflen - xdr->buf->len; + if (space_left - frag1bytes >= PAGE_SIZE) + xdr->end = p + PAGE_SIZE; + else + xdr->end = p + space_left - frag1bytes; + + xdr->buf->page_len += frag2bytes; + xdr->buf->len += nbytes; + return p; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; +} + +/** + * xdr_reserve_space - Reserve buffer space for sending + * @xdr: pointer to xdr_stream + * @nbytes: number of bytes to reserve + * + * Checks that we have enough buffer space to encode 'nbytes' more + * bytes of data. If so, update the total xdr_buf length, and + * adjust the length of the current kvec. + */ +__be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p = xdr->p; + __be32 *q; + + xdr_commit_encode(xdr); + /* align nbytes on the next 32-bit boundary */ + nbytes += 3; + nbytes &= ~3; + q = p + (nbytes >> 2); + if (unlikely(q > xdr->end || q < p)) + return xdr_get_next_encode_buffer(xdr, nbytes); + xdr->p = q; + if (xdr->iov) + xdr->iov->iov_len += nbytes; + else + xdr->buf->page_len += nbytes; + xdr->buf->len += nbytes; + return p; +} +EXPORT_SYMBOL_GPL(xdr_reserve_space); + + +/** + * xdr_reserve_space_vec - Reserves a large amount of buffer space for sending + * @xdr: pointer to xdr_stream + * @vec: pointer to a kvec array + * @nbytes: number of bytes to reserve + * + * Reserves enough buffer space to encode 'nbytes' of data and stores the + * pointers in 'vec'. The size argument passed to xdr_reserve_space() is + * determined based on the number of bytes remaining in the current page to + * avoid invalidating iov_base pointers when xdr_commit_encode() is called. + */ +int xdr_reserve_space_vec(struct xdr_stream *xdr, struct kvec *vec, size_t nbytes) +{ + int thislen; + int v = 0; + __be32 *p; + + /* + * svcrdma requires every READ payload to start somewhere + * in xdr->pages. 
+ */ + if (xdr->iov == xdr->buf->head) { + xdr->iov = NULL; + xdr->end = xdr->p; + } + + while (nbytes) { + thislen = xdr->buf->page_len % PAGE_SIZE; + thislen = min_t(size_t, nbytes, PAGE_SIZE - thislen); + + p = xdr_reserve_space(xdr, thislen); + if (!p) + return -EIO; + + vec[v].iov_base = p; + vec[v].iov_len = thislen; + v++; + nbytes -= thislen; + } + + return v; +} +EXPORT_SYMBOL_GPL(xdr_reserve_space_vec); + +/** + * xdr_truncate_encode - truncate an encode buffer + * @xdr: pointer to xdr_stream + * @len: new length of buffer + * + * Truncates the xdr stream, so that xdr->buf->len == len, + * and xdr->p points at offset len from the start of the buffer, and + * head, tail, and page lengths are adjusted to correspond. + * + * If this means moving xdr->p to a different buffer, we assume that + * the end pointer should be set to the end of the current page, + * except in the case of the head buffer when we assume the head + * buffer's current length represents the end of the available buffer. + * + * This is *not* safe to use on a buffer that already has inlined page + * cache pages (as in a zero-copy server read reply), except for the + * simple case of truncating from one position in the tail to another. + * + */ +void xdr_truncate_encode(struct xdr_stream *xdr, size_t len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *head = buf->head; + struct kvec *tail = buf->tail; + int fraglen; + int new; + + if (len > buf->len) { + WARN_ON_ONCE(1); + return; + } + xdr_commit_encode(xdr); + + fraglen = min_t(int, buf->len - len, tail->iov_len); + tail->iov_len -= fraglen; + buf->len -= fraglen; + if (tail->iov_len) { + xdr->p = tail->iov_base + tail->iov_len; + WARN_ON_ONCE(!xdr->end); + WARN_ON_ONCE(!xdr->iov); + return; + } + WARN_ON_ONCE(fraglen); + fraglen = min_t(int, buf->len - len, buf->page_len); + buf->page_len -= fraglen; + buf->len -= fraglen; + + new = buf->page_base + buf->page_len; + + xdr->page_ptr = buf->pages + (new >> PAGE_SHIFT); + + if (buf->page_len) { + xdr->p = page_address(*xdr->page_ptr); + xdr->end = (void *)xdr->p + PAGE_SIZE; + xdr->p = (void *)xdr->p + (new % PAGE_SIZE); + WARN_ON_ONCE(xdr->iov); + return; + } + if (fraglen) + xdr->end = head->iov_base + head->iov_len; + /* (otherwise assume xdr->end is already set) */ + xdr->page_ptr--; + head->iov_len = len; + buf->len = len; + xdr->p = head->iov_base + head->iov_len; + xdr->iov = buf->head; +} +EXPORT_SYMBOL(xdr_truncate_encode); + +/** + * xdr_restrict_buflen - decrease available buffer space + * @xdr: pointer to xdr_stream + * @newbuflen: new maximum number of bytes available + * + * Adjust our idea of how much space is available in the buffer. + * If we've already used too much space in the buffer, returns -1. + * If the available space is already smaller than newbuflen, returns 0 + * and does nothing. Otherwise, adjusts xdr->buf->buflen to newbuflen + * and ensures xdr->end is set at most offset newbuflen from the start + * of the buffer. 
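+ *
+ * For example (the numbers are purely illustrative): with buf->len == 100
+ * and buf->buflen == 4096, newbuflen == 512 shrinks the buffer to 512
+ * bytes and returns 0; newbuflen == 8192 returns 0 without changing
+ * anything; newbuflen == 50 returns -1 because 100 bytes are already in
+ * use.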
+ */ +int xdr_restrict_buflen(struct xdr_stream *xdr, int newbuflen) +{ + struct xdr_buf *buf = xdr->buf; + int left_in_this_buf = (void *)xdr->end - (void *)xdr->p; + int end_offset = buf->len + left_in_this_buf; + + if (newbuflen < 0 || newbuflen < buf->len) + return -1; + if (newbuflen > buf->buflen) + return 0; + if (newbuflen < end_offset) + xdr->end = (void *)xdr->end + newbuflen - end_offset; + buf->buflen = newbuflen; + return 0; +} +EXPORT_SYMBOL(xdr_restrict_buflen); + +/** + * xdr_write_pages - Insert a list of pages into an XDR buffer for sending + * @xdr: pointer to xdr_stream + * @pages: list of pages + * @base: offset of first byte + * @len: length of data in bytes + * + */ +void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base, + unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov = buf->tail; + buf->pages = pages; + buf->page_base = base; + buf->page_len = len; + + iov->iov_base = (char *)xdr->p; + iov->iov_len = 0; + xdr->iov = iov; + + if (len & 3) { + unsigned int pad = 4 - (len & 3); + + BUG_ON(xdr->p >= xdr->end); + iov->iov_base = (char *)xdr->p + (len & 3); + iov->iov_len += pad; + len += pad; + *xdr->p++ = 0; + } + buf->buflen += len; + buf->len += len; +} +EXPORT_SYMBOL_GPL(xdr_write_pages); + +static unsigned int xdr_set_iov(struct xdr_stream *xdr, struct kvec *iov, + unsigned int base, unsigned int len) +{ + if (len > iov->iov_len) + len = iov->iov_len; + if (unlikely(base > len)) + base = len; + xdr->p = (__be32*)(iov->iov_base + base); + xdr->end = (__be32*)(iov->iov_base + len); + xdr->iov = iov; + xdr->page_ptr = NULL; + return len - base; +} + +static unsigned int xdr_set_tail_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + + xdr_stream_set_pos(xdr, base + buf->page_len + buf->head->iov_len); + return xdr_set_iov(xdr, buf->tail, base, len); +} + +static unsigned int xdr_set_page_base(struct xdr_stream *xdr, + unsigned int base, unsigned int len) +{ + unsigned int pgnr; + unsigned int maxlen; + unsigned int pgoff; + unsigned int pgend; + void *kaddr; + + maxlen = xdr->buf->page_len; + if (base >= maxlen) + return 0; + else + maxlen -= base; + if (len > maxlen) + len = maxlen; + + xdr_stream_page_set_pos(xdr, base); + base += xdr->buf->page_base; + + pgnr = base >> PAGE_SHIFT; + xdr->page_ptr = &xdr->buf->pages[pgnr]; + kaddr = page_address(*xdr->page_ptr); + + pgoff = base & ~PAGE_MASK; + xdr->p = (__be32*)(kaddr + pgoff); + + pgend = pgoff + len; + if (pgend > PAGE_SIZE) + pgend = PAGE_SIZE; + xdr->end = (__be32*)(kaddr + pgend); + xdr->iov = NULL; + return len; +} + +static void xdr_set_page(struct xdr_stream *xdr, unsigned int base, + unsigned int len) +{ + if (xdr_set_page_base(xdr, base, len) == 0) { + base -= xdr->buf->page_len; + xdr_set_tail_base(xdr, base, len); + } +} + +static void xdr_set_next_page(struct xdr_stream *xdr) +{ + unsigned int newbase; + + newbase = (1 + xdr->page_ptr - xdr->buf->pages) << PAGE_SHIFT; + newbase -= xdr->buf->page_base; + if (newbase < xdr->buf->page_len) + xdr_set_page_base(xdr, newbase, xdr_stream_remaining(xdr)); + else + xdr_set_tail_base(xdr, 0, xdr_stream_remaining(xdr)); +} + +static bool xdr_set_next_buffer(struct xdr_stream *xdr) +{ + if (xdr->page_ptr != NULL) + xdr_set_next_page(xdr); + else if (xdr->iov == xdr->buf->head) + xdr_set_page(xdr, 0, xdr_stream_remaining(xdr)); + return xdr->p != xdr->end; +} + +/** + * xdr_init_decode - Initialize an xdr_stream for decoding data. 
+ * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @p: current pointer inside XDR buffer + * @rqst: pointer to controlling rpc_rqst, for debugging + */ +void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p, + struct rpc_rqst *rqst) +{ + xdr->buf = buf; + xdr_reset_scratch_buffer(xdr); + xdr->nwords = XDR_QUADLEN(buf->len); + if (xdr_set_iov(xdr, buf->head, 0, buf->len) == 0 && + xdr_set_page_base(xdr, 0, buf->len) == 0) + xdr_set_iov(xdr, buf->tail, 0, buf->len); + if (p != NULL && p > xdr->p && xdr->end >= p) { + xdr->nwords -= p - xdr->p; + xdr->p = p; + } + xdr->rqst = rqst; +} +EXPORT_SYMBOL_GPL(xdr_init_decode); + +/** + * xdr_init_decode_pages - Initialize an xdr_stream for decoding into pages + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @pages: list of pages to decode into + * @len: length in bytes of buffer in pages + */ +void xdr_init_decode_pages(struct xdr_stream *xdr, struct xdr_buf *buf, + struct page **pages, unsigned int len) +{ + memset(buf, 0, sizeof(*buf)); + buf->pages = pages; + buf->page_len = len; + buf->buflen = len; + buf->len = len; + xdr_init_decode(xdr, buf, NULL, NULL); +} +EXPORT_SYMBOL_GPL(xdr_init_decode_pages); + +static __be32 * __xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + unsigned int nwords = XDR_QUADLEN(nbytes); + __be32 *p = xdr->p; + __be32 *q = p + nwords; + + if (unlikely(nwords > xdr->nwords || q > xdr->end || q < p)) + return NULL; + xdr->p = q; + xdr->nwords -= nwords; + return p; +} + +static __be32 *xdr_copy_to_scratch(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p; + char *cpdest = xdr->scratch.iov_base; + size_t cplen = (char *)xdr->end - (char *)xdr->p; + + if (nbytes > xdr->scratch.iov_len) + goto out_overflow; + p = __xdr_inline_decode(xdr, cplen); + if (p == NULL) + return NULL; + memcpy(cpdest, p, cplen); + if (!xdr_set_next_buffer(xdr)) + goto out_overflow; + cpdest += cplen; + nbytes -= cplen; + p = __xdr_inline_decode(xdr, nbytes); + if (p == NULL) + return NULL; + memcpy(cpdest, p, nbytes); + return xdr->scratch.iov_base; +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; +} + +/** + * xdr_inline_decode - Retrieve XDR data to decode + * @xdr: pointer to xdr_stream struct + * @nbytes: number of bytes of data to decode + * + * Check if the input buffer is long enough to enable us to decode + * 'nbytes' more bytes of data starting at the current position. + * If so return the current pointer, then update the current + * pointer position. 
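+ *
+ * A minimal illustrative decode of one 32-bit word ('count' is just an
+ * example local variable):
+ *
+ *	p = xdr_inline_decode(xdr, 4);
+ *	if (unlikely(!p))
+ *		return -EBADMSG;
+ *	count = be32_to_cpup(p);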
+ */ +__be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p; + + if (unlikely(nbytes == 0)) + return xdr->p; + if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) + goto out_overflow; + p = __xdr_inline_decode(xdr, nbytes); + if (p != NULL) + return p; + return xdr_copy_to_scratch(xdr, nbytes); +out_overflow: + trace_rpc_xdr_overflow(xdr, nbytes); + return NULL; +} +EXPORT_SYMBOL_GPL(xdr_inline_decode); + +static void xdr_realign_pages(struct xdr_stream *xdr) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov = buf->head; + unsigned int cur = xdr_stream_pos(xdr); + unsigned int copied; + + /* Realign pages to current pointer position */ + if (iov->iov_len > cur) { + copied = xdr_shrink_bufhead(buf, cur); + trace_rpc_xdr_alignment(xdr, cur, copied); + xdr_set_page(xdr, 0, buf->page_len); + } +} + +static unsigned int xdr_align_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + unsigned int nwords = XDR_QUADLEN(len); + unsigned int copied; + + if (xdr->nwords == 0) + return 0; + + xdr_realign_pages(xdr); + if (nwords > xdr->nwords) { + nwords = xdr->nwords; + len = nwords << 2; + } + if (buf->page_len <= len) + len = buf->page_len; + else if (nwords < xdr->nwords) { + /* Truncate page data and move it into the tail */ + copied = xdr_shrink_pagelen(buf, len); + trace_rpc_xdr_alignment(xdr, len, copied); + } + return len; +} + +/** + * xdr_read_pages - align page-based XDR data to current pointer position + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + @len + * bytes is moved into the XDR tail[]. The xdr_stream current position is + * then advanced past that data to align to the next XDR object in the tail. + * + * Returns the number of XDR encoded bytes now contained in the pages + */ +unsigned int xdr_read_pages(struct xdr_stream *xdr, unsigned int len) +{ + unsigned int nwords = XDR_QUADLEN(len); + unsigned int base, end, pglen; + + pglen = xdr_align_pages(xdr, nwords << 2); + if (pglen == 0) + return 0; + + base = (nwords << 2) - pglen; + end = xdr_stream_remaining(xdr) - pglen; + + xdr_set_tail_base(xdr, base, end); + return len <= pglen ? len : pglen; +} +EXPORT_SYMBOL_GPL(xdr_read_pages); + +/** + * xdr_set_pagelen - Sets the length of the XDR pages + * @xdr: pointer to xdr_stream struct + * @len: new length of the XDR page data + * + * Either grows or shrinks the length of the xdr pages by setting pagelen to + * @len bytes. When shrinking, any extra data is moved into buf->tail, whereas + * when growing any data beyond the current pointer is moved into the tail. + * + * Returns True if the operation was successful, and False otherwise. 
+ */ +void xdr_set_pagelen(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + size_t remaining = xdr_stream_remaining(xdr); + size_t base = 0; + + if (len < buf->page_len) { + base = buf->page_len - len; + xdr_shrink_pagelen(buf, len); + } else { + xdr_buf_head_shift_right(buf, xdr_stream_pos(xdr), + buf->page_len, remaining); + if (len > buf->page_len) + xdr_buf_try_expand(buf, len - buf->page_len); + } + xdr_set_tail_base(xdr, base, remaining); +} +EXPORT_SYMBOL_GPL(xdr_set_pagelen); + +/** + * xdr_enter_page - decode data from the XDR page + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + "len" + * bytes is moved into the XDR tail[]. The current pointer is then + * repositioned at the beginning of the first XDR page. + */ +void xdr_enter_page(struct xdr_stream *xdr, unsigned int len) +{ + len = xdr_align_pages(xdr, len); + /* + * Position current pointer at beginning of tail, and + * set remaining message length. + */ + if (len != 0) + xdr_set_page_base(xdr, 0, len); +} +EXPORT_SYMBOL_GPL(xdr_enter_page); + +static const struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0}; + +void xdr_buf_from_iov(const struct kvec *iov, struct xdr_buf *buf) +{ + buf->head[0] = *iov; + buf->tail[0] = empty_iov; + buf->page_len = 0; + buf->buflen = buf->len = iov->iov_len; +} +EXPORT_SYMBOL_GPL(xdr_buf_from_iov); + +/** + * xdr_buf_subsegment - set subbuf to a portion of buf + * @buf: an xdr buffer + * @subbuf: the result buffer + * @base: beginning of range in bytes + * @len: length of range in bytes + * + * sets @subbuf to an xdr buffer representing the portion of @buf of + * length @len starting at offset @base. + * + * @buf and @subbuf may be pointers to the same struct xdr_buf. + * + * Returns -1 if base or length are out of bounds. + */ +int xdr_buf_subsegment(const struct xdr_buf *buf, struct xdr_buf *subbuf, + unsigned int base, unsigned int len) +{ + subbuf->buflen = subbuf->len = len; + if (base < buf->head[0].iov_len) { + subbuf->head[0].iov_base = buf->head[0].iov_base + base; + subbuf->head[0].iov_len = min_t(unsigned int, len, + buf->head[0].iov_len - base); + len -= subbuf->head[0].iov_len; + base = 0; + } else { + base -= buf->head[0].iov_len; + subbuf->head[0].iov_base = buf->head[0].iov_base; + subbuf->head[0].iov_len = 0; + } + + if (base < buf->page_len) { + subbuf->page_len = min(buf->page_len - base, len); + base += buf->page_base; + subbuf->page_base = base & ~PAGE_MASK; + subbuf->pages = &buf->pages[base >> PAGE_SHIFT]; + len -= subbuf->page_len; + base = 0; + } else { + base -= buf->page_len; + subbuf->pages = buf->pages; + subbuf->page_base = 0; + subbuf->page_len = 0; + } + + if (base < buf->tail[0].iov_len) { + subbuf->tail[0].iov_base = buf->tail[0].iov_base + base; + subbuf->tail[0].iov_len = min_t(unsigned int, len, + buf->tail[0].iov_len - base); + len -= subbuf->tail[0].iov_len; + base = 0; + } else { + base -= buf->tail[0].iov_len; + subbuf->tail[0].iov_base = buf->tail[0].iov_base; + subbuf->tail[0].iov_len = 0; + } + + if (base || len) + return -1; + return 0; +} +EXPORT_SYMBOL_GPL(xdr_buf_subsegment); + +/** + * xdr_stream_subsegment - set @subbuf to a portion of @xdr + * @xdr: an xdr_stream set up for decoding + * @subbuf: the result buffer + * @nbytes: length of @xdr to extract, in bytes + * + * Sets up @subbuf to represent a portion of @xdr. 
The portion + * starts at the current offset in @xdr, and extends for a length + * of @nbytes. If this is successful, @xdr is advanced to the next + * XDR data item following that portion. + * + * Return values: + * %true: @subbuf has been initialized, and @xdr has been advanced. + * %false: a bounds error has occurred + */ +bool xdr_stream_subsegment(struct xdr_stream *xdr, struct xdr_buf *subbuf, + unsigned int nbytes) +{ + unsigned int start = xdr_stream_pos(xdr); + unsigned int remaining, len; + + /* Extract @subbuf and bounds-check the fn arguments */ + if (xdr_buf_subsegment(xdr->buf, subbuf, start, nbytes)) + return false; + + /* Advance @xdr by @nbytes */ + for (remaining = nbytes; remaining;) { + if (xdr->p == xdr->end && !xdr_set_next_buffer(xdr)) + return false; + + len = (char *)xdr->end - (char *)xdr->p; + if (remaining <= len) { + xdr->p = (__be32 *)((char *)xdr->p + + (remaining + xdr_pad_size(nbytes))); + break; + } + + xdr->p = (__be32 *)((char *)xdr->p + len); + xdr->end = xdr->p; + remaining -= len; + } + + xdr_stream_set_pos(xdr, start + nbytes); + return true; +} +EXPORT_SYMBOL_GPL(xdr_stream_subsegment); + +/** + * xdr_stream_move_subsegment - Move part of a stream to another position + * @xdr: the source xdr_stream + * @offset: the source offset of the segment + * @target: the target offset of the segment + * @length: the number of bytes to move + * + * Moves @length bytes from @offset to @target in the xdr_stream, overwriting + * anything in its space. Returns the number of bytes in the segment. + */ +unsigned int xdr_stream_move_subsegment(struct xdr_stream *xdr, unsigned int offset, + unsigned int target, unsigned int length) +{ + struct xdr_buf buf; + unsigned int shift; + + if (offset < target) { + shift = target - offset; + if (xdr_buf_subsegment(xdr->buf, &buf, offset, shift + length) < 0) + return 0; + xdr_buf_head_shift_right(&buf, 0, length, shift); + } else if (offset > target) { + shift = offset - target; + if (xdr_buf_subsegment(xdr->buf, &buf, target, shift + length) < 0) + return 0; + xdr_buf_head_shift_left(&buf, shift, length, shift); + } + return length; +} +EXPORT_SYMBOL_GPL(xdr_stream_move_subsegment); + +/** + * xdr_stream_zero - zero out a portion of an xdr_stream + * @xdr: an xdr_stream to zero out + * @offset: the starting point in the stream + * @length: the number of bytes to zero + */ +unsigned int xdr_stream_zero(struct xdr_stream *xdr, unsigned int offset, + unsigned int length) +{ + struct xdr_buf buf; + + if (xdr_buf_subsegment(xdr->buf, &buf, offset, length) < 0) + return 0; + if (buf.head[0].iov_len) + xdr_buf_iov_zero(buf.head, 0, buf.head[0].iov_len); + if (buf.page_len > 0) + xdr_buf_pages_zero(&buf, 0, buf.page_len); + if (buf.tail[0].iov_len) + xdr_buf_iov_zero(buf.tail, 0, buf.tail[0].iov_len); + return length; +} +EXPORT_SYMBOL_GPL(xdr_stream_zero); + +/** + * xdr_buf_trim - lop at most "len" bytes off the end of "buf" + * @buf: buf to be trimmed + * @len: number of bytes to reduce "buf" by + * + * Trim an xdr_buf by the given number of bytes by fixing up the lengths. Note + * that it's possible that we'll trim less than that amount if the xdr_buf is + * too small, or if (for instance) it's all in the head and the parser has + * already read too far into it. 
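+ *
+ * For instance, a caller that has just verified a 4-byte trailing
+ * checksum and wants it excluded from further processing could call
+ * xdr_buf_trim(buf, 4) (the checksum length here is only an example).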
+ */ +void xdr_buf_trim(struct xdr_buf *buf, unsigned int len) +{ + size_t cur; + unsigned int trim = len; + + if (buf->tail[0].iov_len) { + cur = min_t(size_t, buf->tail[0].iov_len, trim); + buf->tail[0].iov_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->page_len) { + cur = min_t(unsigned int, buf->page_len, trim); + buf->page_len -= cur; + trim -= cur; + if (!trim) + goto fix_len; + } + + if (buf->head[0].iov_len) { + cur = min_t(size_t, buf->head[0].iov_len, trim); + buf->head[0].iov_len -= cur; + trim -= cur; + } +fix_len: + buf->len -= (len - trim); +} +EXPORT_SYMBOL_GPL(xdr_buf_trim); + +static void __read_bytes_from_xdr_buf(const struct xdr_buf *subbuf, + void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(obj, subbuf->head[0].iov_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(obj, subbuf->tail[0].iov_base, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int read_bytes_from_xdr_buf(const struct xdr_buf *buf, unsigned int base, + void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __read_bytes_from_xdr_buf(&subbuf, obj, len); + return 0; +} +EXPORT_SYMBOL_GPL(read_bytes_from_xdr_buf); + +static void __write_bytes_to_xdr_buf(const struct xdr_buf *subbuf, + void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(subbuf->head[0].iov_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(subbuf->tail[0].iov_base, obj, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int write_bytes_to_xdr_buf(const struct xdr_buf *buf, unsigned int base, + void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __write_bytes_to_xdr_buf(&subbuf, obj, len); + return 0; +} +EXPORT_SYMBOL_GPL(write_bytes_to_xdr_buf); + +int xdr_decode_word(const struct xdr_buf *buf, unsigned int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = be32_to_cpu(raw); + return 0; +} +EXPORT_SYMBOL_GPL(xdr_decode_word); + +int xdr_encode_word(const struct xdr_buf *buf, unsigned int base, u32 obj) +{ + __be32 raw = cpu_to_be32(obj); + + return write_bytes_to_xdr_buf(buf, base, &raw, sizeof(obj)); +} +EXPORT_SYMBOL_GPL(xdr_encode_word); + +/* Returns 0 on success, or else a negative error code. 
*/ +static int xdr_xcode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc, int encode) +{ + char *elem = NULL, *c; + unsigned int copied = 0, todo, avail_here; + struct page **ppages = NULL; + int err; + + if (encode) { + if (xdr_encode_word(buf, base, desc->array_len) != 0) + return -EINVAL; + } else { + if (xdr_decode_word(buf, base, &desc->array_len) != 0 || + desc->array_len > desc->array_maxlen || + (unsigned long) base + 4 + desc->array_len * + desc->elem_size > buf->len) + return -EINVAL; + } + base += 4; + + if (!desc->xcode) + return 0; + + todo = desc->array_len * desc->elem_size; + + /* process head */ + if (todo && base < buf->head->iov_len) { + c = buf->head->iov_base + base; + avail_here = min_t(unsigned int, todo, + buf->head->iov_len - base); + todo -= avail_here; + + while (avail_here >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_here -= desc->elem_size; + } + if (avail_here) { + if (!elem) { + elem = kmalloc(desc->elem_size, GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + err = desc->xcode(desc, elem); + if (err) + goto out; + memcpy(c, elem, avail_here); + } else + memcpy(elem, c, avail_here); + copied = avail_here; + } + base = buf->head->iov_len; /* align to start of pages */ + } + + /* process pages array */ + base -= buf->head->iov_len; + if (todo && base < buf->page_len) { + unsigned int avail_page; + + avail_here = min(todo, buf->page_len - base); + todo -= avail_here; + + base += buf->page_base; + ppages = buf->pages + (base >> PAGE_SHIFT); + base &= ~PAGE_MASK; + avail_page = min_t(unsigned int, PAGE_SIZE - base, + avail_here); + c = kmap(*ppages) + base; + + while (avail_here) { + avail_here -= avail_page; + if (copied || avail_page < desc->elem_size) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + avail_page -= l; + c += l; + } + while (avail_page >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_page -= desc->elem_size; + } + if (avail_page) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + } + if (avail_here) { + kunmap(*ppages); + ppages++; + c = kmap(*ppages); + } + + avail_page = min(avail_here, + (unsigned int) PAGE_SIZE); + } + base = buf->page_len; /* align to start of tail */ + } + + /* process tail */ + base -= buf->page_len; + if (todo) { + c = buf->tail->iov_base + base; + if (copied) { + unsigned int l = desc->elem_size - copied; + + if (encode) + memcpy(c, elem + copied, l); + else { + memcpy(elem + copied, c, l); + 
err = desc->xcode(desc, elem); + if (err) + goto out; + } + todo -= l; + c += l; + } + while (todo) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + todo -= desc->elem_size; + } + } + err = 0; + +out: + kfree(elem); + if (ppages) + kunmap(*ppages); + return err; +} + +int xdr_decode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if (base >= buf->len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 0); +} +EXPORT_SYMBOL_GPL(xdr_decode_array2); + +int xdr_encode_array2(const struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if ((unsigned long) base + 4 + desc->array_len * desc->elem_size > + buf->head->iov_len + buf->page_len + buf->tail->iov_len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 1); +} +EXPORT_SYMBOL_GPL(xdr_encode_array2); + +int xdr_process_buf(const struct xdr_buf *buf, unsigned int offset, + unsigned int len, + int (*actor)(struct scatterlist *, void *), void *data) +{ + int i, ret = 0; + unsigned int page_len, thislen, page_offset; + struct scatterlist sg[1]; + + sg_init_table(sg, 1); + + if (offset >= buf->head[0].iov_len) { + offset -= buf->head[0].iov_len; + } else { + thislen = buf->head[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->head[0].iov_base + offset, thislen); + ret = actor(sg, data); + if (ret) + goto out; + offset = 0; + len -= thislen; + } + if (len == 0) + goto out; + + if (offset >= buf->page_len) { + offset -= buf->page_len; + } else { + page_len = buf->page_len - offset; + if (page_len > len) + page_len = len; + len -= page_len; + page_offset = (offset + buf->page_base) & (PAGE_SIZE - 1); + i = (offset + buf->page_base) >> PAGE_SHIFT; + thislen = PAGE_SIZE - page_offset; + do { + if (thislen > page_len) + thislen = page_len; + sg_set_page(sg, buf->pages[i], thislen, page_offset); + ret = actor(sg, data); + if (ret) + goto out; + page_len -= thislen; + i++; + page_offset = 0; + thislen = PAGE_SIZE; + } while (page_len != 0); + offset = 0; + } + if (len == 0) + goto out; + if (offset < buf->tail[0].iov_len) { + thislen = buf->tail[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->tail[0].iov_base + offset, thislen); + ret = actor(sg, data); + len -= thislen; + } + if (len != 0) + ret = -EINVAL; +out: + return ret; +} +EXPORT_SYMBOL_GPL(xdr_process_buf); + +/** + * xdr_stream_decode_opaque - Decode variable length opaque + * @xdr: pointer to xdr_stream + * @ptr: location to store opaque data + * @size: size of storage buffer @ptr + * + * Return values: + * On success, returns size of object stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE on overflow of storage buffer @ptr + */ +ssize_t xdr_stream_decode_opaque(struct xdr_stream *xdr, void *ptr, size_t size) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, size); + if (ret <= 0) + return ret; + memcpy(ptr, p, ret); + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque); + +/** + * xdr_stream_decode_opaque_dup - Decode and duplicate variable length opaque + * @xdr: pointer to xdr_stream + * @ptr: location to store pointer to opaque data + * @maxlen: maximum acceptable object size + * @gfp_flags: GFP mask to use + * + * Return values: + * On success, returns size of object stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of the object would exceed @maxlen + * %-ENOMEM on memory allocation failure + */ +ssize_t 
xdr_stream_decode_opaque_dup(struct xdr_stream *xdr, void **ptr, + size_t maxlen, gfp_t gfp_flags) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen); + if (ret > 0) { + *ptr = kmemdup(p, ret, gfp_flags); + if (*ptr != NULL) + return ret; + ret = -ENOMEM; + } + *ptr = NULL; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_opaque_dup); + +/** + * xdr_stream_decode_string - Decode variable length string + * @xdr: pointer to xdr_stream + * @str: location to store string + * @size: size of storage buffer @str + * + * Return values: + * On success, returns length of NUL-terminated string stored in *@str + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE on overflow of storage buffer @str + */ +ssize_t xdr_stream_decode_string(struct xdr_stream *xdr, char *str, size_t size) +{ + ssize_t ret; + void *p; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, size); + if (ret > 0) { + memcpy(str, p, ret); + str[ret] = '\0'; + return strlen(str); + } + *str = '\0'; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_string); + +/** + * xdr_stream_decode_string_dup - Decode and duplicate variable length string + * @xdr: pointer to xdr_stream + * @str: location to store pointer to string + * @maxlen: maximum acceptable string length + * @gfp_flags: GFP mask to use + * + * Return values: + * On success, returns length of NUL-terminated string stored in *@ptr + * %-EBADMSG on XDR buffer overflow + * %-EMSGSIZE if the size of the string would exceed @maxlen + * %-ENOMEM on memory allocation failure + */ +ssize_t xdr_stream_decode_string_dup(struct xdr_stream *xdr, char **str, + size_t maxlen, gfp_t gfp_flags) +{ + void *p; + ssize_t ret; + + ret = xdr_stream_decode_opaque_inline(xdr, &p, maxlen); + if (ret > 0) { + char *s = kmemdup_nul(p, ret, gfp_flags); + if (s != NULL) { + *str = s; + return strlen(s); + } + ret = -ENOMEM; + } + *str = NULL; + return ret; +} +EXPORT_SYMBOL_GPL(xdr_stream_decode_string_dup); diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c new file mode 100644 index 000000000..656cec208 --- /dev/null +++ b/net/sunrpc/xprt.c @@ -0,0 +1,2192 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * linux/net/sunrpc/xprt.c + * + * This is a generic RPC call interface supporting congestion avoidance, + * and asynchronous calls. + * + * The interface works like this: + * + * - When a process places a call, it allocates a request slot if + * one is available. Otherwise, it sleeps on the backlog queue + * (xprt_reserve). + * - Next, the caller puts together the RPC message, stuffs it into + * the request struct, and calls xprt_transmit(). + * - xprt_transmit sends the message and installs the caller on the + * transport's wait list. At the same time, if a reply is expected, + * it installs a timer that is run after the packet's timeout has + * expired. + * - When a packet arrives, the data_ready handler walks the list of + * pending requests for that transport. If a matching XID is found, the + * caller is woken up, and the timer removed. + * - When no reply arrives within the timeout interval, the timer is + * fired by the kernel and runs xprt_timer(). It either adjusts the + * timeout values (minor timeout) or wakes up the caller with a status + * of -ETIMEDOUT. + * - When the caller receives a notification from RPC that a reply arrived, + * it should release the RPC slot, and process the reply. + * If the call timed out, it may choose to retry the operation by + * adjusting the initial timeout value, and simply calling rpc_call + * again. 
+ * + * Support for async RPC is done through a set of RPC-specific scheduling + * primitives that `transparently' work for processes as well as async + * tasks that rely on callbacks. + * + * Copyright (C) 1995-1997, Olaf Kirch + * + * Transport switch API copyright (C) 2005, Chuck Lever + */ + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include "sunrpc.h" +#include "sysfs.h" +#include "fail.h" + +/* + * Local variables + */ + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# define RPCDBG_FACILITY RPCDBG_XPRT +#endif + +/* + * Local functions + */ +static void xprt_init(struct rpc_xprt *xprt, struct net *net); +static __be32 xprt_alloc_xid(struct rpc_xprt *xprt); +static void xprt_destroy(struct rpc_xprt *xprt); +static void xprt_request_init(struct rpc_task *task); +static int xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf); + +static DEFINE_SPINLOCK(xprt_list_lock); +static LIST_HEAD(xprt_list); + +static unsigned long xprt_request_timeout(const struct rpc_rqst *req) +{ + unsigned long timeout = jiffies + req->rq_timeout; + + if (time_before(timeout, req->rq_majortimeo)) + return timeout; + return req->rq_majortimeo; +} + +/** + * xprt_register_transport - register a transport implementation + * @transport: transport to register + * + * If a transport implementation is loaded as a kernel module, it can + * call this interface to make itself known to the RPC client. + * + * Returns: + * 0: transport successfully registered + * -EEXIST: transport already registered + * -EINVAL: transport module being unloaded + */ +int xprt_register_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = -EEXIST; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + /* don't register the same transport class twice */ + if (t->ident == transport->ident) + goto out; + } + + list_add_tail(&transport->list, &xprt_list); + printk(KERN_INFO "RPC: Registered %s transport module.\n", + transport->name); + result = 0; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_register_transport); + +/** + * xprt_unregister_transport - unregister a transport implementation + * @transport: transport to unregister + * + * Returns: + * 0: transport successfully unregistered + * -ENOENT: transport never registered + */ +int xprt_unregister_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = 0; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + if (t == transport) { + printk(KERN_INFO + "RPC: Unregistered %s transport module.\n", + transport->name); + list_del_init(&transport->list); + goto out; + } + } + result = -ENOENT; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_unregister_transport); + +static void +xprt_class_release(const struct xprt_class *t) +{ + module_put(t->owner); +} + +static const struct xprt_class * +xprt_class_find_by_ident_locked(int ident) +{ + const struct xprt_class *t; + + list_for_each_entry(t, &xprt_list, list) { + if (t->ident != ident) + continue; + if (!try_module_get(t->owner)) + continue; + return t; + } + return NULL; +} + +static const struct xprt_class * +xprt_class_find_by_ident(int ident) +{ + const struct xprt_class *t; + + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_ident_locked(ident); + spin_unlock(&xprt_list_lock); + return t; +} + +static const struct xprt_class * 
+xprt_class_find_by_netid_locked(const char *netid) +{ + const struct xprt_class *t; + unsigned int i; + + list_for_each_entry(t, &xprt_list, list) { + for (i = 0; t->netid[i][0] != '\0'; i++) { + if (strcmp(t->netid[i], netid) != 0) + continue; + if (!try_module_get(t->owner)) + continue; + return t; + } + } + return NULL; +} + +static const struct xprt_class * +xprt_class_find_by_netid(const char *netid) +{ + const struct xprt_class *t; + + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + if (!t) { + spin_unlock(&xprt_list_lock); + request_module("rpc%s", netid); + spin_lock(&xprt_list_lock); + t = xprt_class_find_by_netid_locked(netid); + } + spin_unlock(&xprt_list_lock); + return t; +} + +/** + * xprt_find_transport_ident - convert a netid into a transport identifier + * @netid: transport to load + * + * Returns: + * > 0: transport identifier + * -ENOENT: transport module not available + */ +int xprt_find_transport_ident(const char *netid) +{ + const struct xprt_class *t; + int ret; + + t = xprt_class_find_by_netid(netid); + if (!t) + return -ENOENT; + ret = t->ident; + xprt_class_release(t); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_find_transport_ident); + +static void xprt_clear_locked(struct rpc_xprt *xprt) +{ + xprt->snd_task = NULL; + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state)) + clear_bit_unlock(XPRT_LOCKED, &xprt->state); + else + queue_work(xprtiod_workqueue, &xprt->task_cleanup); +} + +/** + * xprt_reserve_xprt - serialize write access to transports + * @task: task that is requesting access to the transport + * @xprt: pointer to the target transport + * + * This prevents mixing the payload of separate requests, and prevents + * transport connects from colliding with writes. No congestion control + * is provided. + */ +int xprt_reserve_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + goto out_locked; + goto out_sleep; + } + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + xprt->snd_task = task; + +out_locked: + trace_xprt_reserve_xprt(xprt, task); + return 1; + +out_unlock: + xprt_clear_locked(xprt); +out_sleep: + task->tk_status = -EAGAIN; + if (RPC_IS_SOFT(task)) + rpc_sleep_on_timeout(&xprt->sending, task, NULL, + xprt_request_timeout(req)); + else + rpc_sleep_on(&xprt->sending, task, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt); + +static bool +xprt_need_congestion_window_wait(struct rpc_xprt *xprt) +{ + return test_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_set_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!list_empty(&xprt->xmit_queue)) { + /* Peek at head of queue to see if it can make progress */ + if (list_first_entry(&xprt->xmit_queue, struct rpc_rqst, + rq_xmit)->rq_cong) + return; + } + set_bit(XPRT_CWND_WAIT, &xprt->state); +} + +static void +xprt_test_and_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (!RPCXPRT_CONGESTED(xprt)) + clear_bit(XPRT_CWND_WAIT, &xprt->state); +} + +/* + * xprt_reserve_xprt_cong - serialize write access to transports + * @task: task that is requesting access to the transport + * + * Same as xprt_reserve_xprt, but Van Jacobson congestion control is + * integrated into the decision of whether a request is allowed to be + * woken up and given access to the transport. + * Note that the lock is only granted if we know there are free slots. 
+ */ +int xprt_reserve_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + goto out_locked; + goto out_sleep; + } + if (req == NULL) { + xprt->snd_task = task; + goto out_locked; + } + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (!xprt_need_congestion_window_wait(xprt)) { + xprt->snd_task = task; + goto out_locked; + } +out_unlock: + xprt_clear_locked(xprt); +out_sleep: + task->tk_status = -EAGAIN; + if (RPC_IS_SOFT(task)) + rpc_sleep_on_timeout(&xprt->sending, task, NULL, + xprt_request_timeout(req)); + else + rpc_sleep_on(&xprt->sending, task, NULL); + return 0; +out_locked: + trace_xprt_reserve_cong(xprt, task); + return 1; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); + +static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + int retval; + + if (test_bit(XPRT_LOCKED, &xprt->state) && xprt->snd_task == task) + return 1; + spin_lock(&xprt->transport_lock); + retval = xprt->ops->reserve_xprt(xprt, task); + spin_unlock(&xprt->transport_lock); + return retval; +} + +static bool __xprt_lock_write_func(struct rpc_task *task, void *data) +{ + struct rpc_xprt *xprt = data; + + xprt->snd_task = task; + return true; +} + +static void __xprt_lock_write_next(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_func, xprt)) + return; +out_unlock: + xprt_clear_locked(xprt); +} + +static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + if (test_bit(XPRT_WRITE_SPACE, &xprt->state)) + goto out_unlock; + if (xprt_need_congestion_window_wait(xprt)) + goto out_unlock; + if (rpc_wake_up_first_on_wq(xprtiod_workqueue, &xprt->sending, + __xprt_lock_write_func, xprt)) + return; +out_unlock: + xprt_clear_locked(xprt); +} + +/** + * xprt_release_xprt - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. No congestion control is provided. + */ +void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_clear_locked(xprt); + __xprt_lock_write_next(xprt); + } + trace_xprt_release_xprt(xprt, task); +} +EXPORT_SYMBOL_GPL(xprt_release_xprt); + +/** + * xprt_release_xprt_cong - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. Another task is awoken to use the + * transport if the transport's congestion window allows it. + */ +void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_clear_locked(xprt); + __xprt_lock_write_next_cong(xprt); + } + trace_xprt_release_cong(xprt, task); +} +EXPORT_SYMBOL_GPL(xprt_release_xprt_cong); + +void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task != task) + return; + spin_lock(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + spin_unlock(&xprt->transport_lock); +} + +/* + * Van Jacobson congestion avoidance. Check if the congestion window + * overflowed. Put the task to sleep if this is the case. 
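+ * (The transport is treated as congested when the RPCXPRT_CONGESTED()
+ * test below is true, i.e. when the scaled congestion count xprt->cong
+ * has caught up with the current window xprt->cwnd.)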
+ */ +static int +__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (req->rq_cong) + return 1; + trace_xprt_get_cong(xprt, req->rq_task); + if (RPCXPRT_CONGESTED(xprt)) { + xprt_set_congestion_window_wait(xprt); + return 0; + } + req->rq_cong = 1; + xprt->cong += RPC_CWNDSCALE; + return 1; +} + +/* + * Adjust the congestion window, and wake up the next task + * that has been sleeping due to congestion + */ +static void +__xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (!req->rq_cong) + return; + req->rq_cong = 0; + xprt->cong -= RPC_CWNDSCALE; + xprt_test_and_clear_congestion_window_wait(xprt); + trace_xprt_put_cong(xprt, req->rq_task); + __xprt_lock_write_next_cong(xprt); +} + +/** + * xprt_request_get_cong - Request congestion control credits + * @xprt: pointer to transport + * @req: pointer to RPC request + * + * Useful for transports that require congestion control. + */ +bool +xprt_request_get_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + bool ret = false; + + if (req->rq_cong) + return true; + spin_lock(&xprt->transport_lock); + ret = __xprt_get_cong(xprt, req) != 0; + spin_unlock(&xprt->transport_lock); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_request_get_cong); + +/** + * xprt_release_rqst_cong - housekeeping when request is complete + * @task: RPC request that recently completed + * + * Useful for transports that require congestion control. + */ +void xprt_release_rqst_cong(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + __xprt_put_cong(req->rq_xprt, req); +} +EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); + +static void xprt_clear_congestion_window_wait_locked(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) + __xprt_lock_write_next_cong(xprt); +} + +/* + * Clear the congestion window wait flag and wake up the next + * entry on xprt->sending + */ +static void +xprt_clear_congestion_window_wait(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_CWND_WAIT, &xprt->state)) { + spin_lock(&xprt->transport_lock); + __xprt_lock_write_next_cong(xprt); + spin_unlock(&xprt->transport_lock); + } +} + +/** + * xprt_adjust_cwnd - adjust transport congestion window + * @xprt: pointer to xprt + * @task: recently completed RPC request used to adjust window + * @result: result code of completed RPC request + * + * The transport code maintains an estimate on the maximum number of out- + * standing RPC requests, using a smoothed version of the congestion + * avoidance implemented in 44BSD. This is basically the Van Jacobson + * congestion algorithm: If a retransmit occurs, the congestion window is + * halved; otherwise, it is incremented by 1/cwnd when + * + * - a reply is received and + * - a full number of requests are outstanding and + * - the congestion window hasn't been updated recently. + */ +void xprt_adjust_cwnd(struct rpc_xprt *xprt, struct rpc_task *task, int result) +{ + struct rpc_rqst *req = task->tk_rqstp; + unsigned long cwnd = xprt->cwnd; + + if (result >= 0 && cwnd <= xprt->cong) { + /* The (cwnd >> 1) term makes sure + * the result gets rounded properly. 
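+ * For example, if RPC_CWNDSCALE is 256 and cwnd is 512 (a window of
+ * two requests), the increment below works out to
+ * (256 * 256 + 256) / 512 == 128, i.e. the window grows by half a
+ * request's worth per reply.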
*/ + cwnd += (RPC_CWNDSCALE * RPC_CWNDSCALE + (cwnd >> 1)) / cwnd; + if (cwnd > RPC_MAXCWND(xprt)) + cwnd = RPC_MAXCWND(xprt); + __xprt_lock_write_next_cong(xprt); + } else if (result == -ETIMEDOUT) { + cwnd >>= 1; + if (cwnd < RPC_CWNDSCALE) + cwnd = RPC_CWNDSCALE; + } + dprintk("RPC: cong %ld, cwnd was %ld, now %ld\n", + xprt->cong, xprt->cwnd, cwnd); + xprt->cwnd = cwnd; + __xprt_put_cong(xprt, req); +} +EXPORT_SYMBOL_GPL(xprt_adjust_cwnd); + +/** + * xprt_wake_pending_tasks - wake all tasks on a transport's pending queue + * @xprt: transport with waiting tasks + * @status: result code to plant in each task before waking it + * + */ +void xprt_wake_pending_tasks(struct rpc_xprt *xprt, int status) +{ + if (status < 0) + rpc_wake_up_status(&xprt->pending, status); + else + rpc_wake_up(&xprt->pending); +} +EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); + +/** + * xprt_wait_for_buffer_space - wait for transport output buffer to clear + * @xprt: transport + * + * Note that we only set the timer for the case of RPC_IS_SOFT(), since + * we don't in general want to force a socket disconnection due to + * an incomplete RPC call transmission. + */ +void xprt_wait_for_buffer_space(struct rpc_xprt *xprt) +{ + set_bit(XPRT_WRITE_SPACE, &xprt->state); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); + +static bool +xprt_clear_write_space_locked(struct rpc_xprt *xprt) +{ + if (test_and_clear_bit(XPRT_WRITE_SPACE, &xprt->state)) { + __xprt_lock_write_next(xprt); + dprintk("RPC: write space: waking waiting task on " + "xprt %p\n", xprt); + return true; + } + return false; +} + +/** + * xprt_write_space - wake the task waiting for transport output buffer space + * @xprt: transport with waiting tasks + * + * Can be called in a soft IRQ context, so xprt_write_space never sleeps. + */ +bool xprt_write_space(struct rpc_xprt *xprt) +{ + bool ret; + + if (!test_bit(XPRT_WRITE_SPACE, &xprt->state)) + return false; + spin_lock(&xprt->transport_lock); + ret = xprt_clear_write_space_locked(xprt); + spin_unlock(&xprt->transport_lock); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_write_space); + +static unsigned long xprt_abs_ktime_to_jiffies(ktime_t abstime) +{ + s64 delta = ktime_to_ns(ktime_get() - abstime); + return likely(delta >= 0) ? 
+ jiffies - nsecs_to_jiffies(delta) : + jiffies + nsecs_to_jiffies(-delta); +} + +static unsigned long xprt_calc_majortimeo(struct rpc_rqst *req) +{ + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + unsigned long majortimeo = req->rq_timeout; + + if (to->to_exponential) + majortimeo <<= to->to_retries; + else + majortimeo += to->to_increment * to->to_retries; + if (majortimeo > to->to_maxval || majortimeo == 0) + majortimeo = to->to_maxval; + return majortimeo; +} + +static void xprt_reset_majortimeo(struct rpc_rqst *req) +{ + req->rq_majortimeo += xprt_calc_majortimeo(req); +} + +static void xprt_reset_minortimeo(struct rpc_rqst *req) +{ + req->rq_minortimeo += req->rq_timeout; +} + +static void xprt_init_majortimeo(struct rpc_task *task, struct rpc_rqst *req) +{ + unsigned long time_init; + struct rpc_xprt *xprt = req->rq_xprt; + + if (likely(xprt && xprt_connected(xprt))) + time_init = jiffies; + else + time_init = xprt_abs_ktime_to_jiffies(task->tk_start); + req->rq_timeout = task->tk_client->cl_timeout->to_initval; + req->rq_majortimeo = time_init + xprt_calc_majortimeo(req); + req->rq_minortimeo = time_init + req->rq_timeout; +} + +/** + * xprt_adjust_timeout - adjust timeout values for next retransmit + * @req: RPC request containing parameters to use for the adjustment + * + */ +int xprt_adjust_timeout(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + int status = 0; + + if (time_before(jiffies, req->rq_majortimeo)) { + if (time_before(jiffies, req->rq_minortimeo)) + return status; + if (to->to_exponential) + req->rq_timeout <<= 1; + else + req->rq_timeout += to->to_increment; + if (to->to_maxval && req->rq_timeout >= to->to_maxval) + req->rq_timeout = to->to_maxval; + req->rq_retries++; + } else { + req->rq_timeout = to->to_initval; + req->rq_retries = 0; + xprt_reset_majortimeo(req); + /* Reset the RTT counters == "slow start" */ + spin_lock(&xprt->transport_lock); + rpc_init_rtt(req->rq_task->tk_client->cl_rtt, to->to_initval); + spin_unlock(&xprt->transport_lock); + status = -ETIMEDOUT; + } + xprt_reset_minortimeo(req); + + if (req->rq_timeout == 0) { + printk(KERN_WARNING "xprt_adjust_timeout: rq_timeout = 0!\n"); + req->rq_timeout = 5 * HZ; + } + return status; +} + +static void xprt_autoclose(struct work_struct *work) +{ + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + unsigned int pflags = memalloc_nofs_save(); + + trace_xprt_disconnect_auto(xprt); + xprt->connect_cookie++; + smp_mb__before_atomic(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + xprt->ops->close(xprt); + xprt_release_write(xprt, NULL); + wake_up_bit(&xprt->state, XPRT_LOCKED); + memalloc_nofs_restore(pflags); +} + +/** + * xprt_disconnect_done - mark a transport as disconnected + * @xprt: transport to flag for disconnect + * + */ +void xprt_disconnect_done(struct rpc_xprt *xprt) +{ + trace_xprt_disconnect_done(xprt); + spin_lock(&xprt->transport_lock); + xprt_clear_connected(xprt); + xprt_clear_write_space_locked(xprt); + xprt_clear_congestion_window_wait_locked(xprt); + xprt_wake_pending_tasks(xprt, -ENOTCONN); + spin_unlock(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_disconnect_done); + +/** + * xprt_schedule_autoclose_locked - Try to schedule an autoclose RPC call + * @xprt: transport to disconnect + */ +static void xprt_schedule_autoclose_locked(struct rpc_xprt *xprt) +{ + if (test_and_set_bit(XPRT_CLOSE_WAIT, &xprt->state)) + return; + if 
(test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(xprtiod_workqueue, &xprt->task_cleanup); + else if (xprt->snd_task && !test_bit(XPRT_SND_IS_COOKIE, &xprt->state)) + rpc_wake_up_queued_task_set_status(&xprt->pending, + xprt->snd_task, -ENOTCONN); +} + +/** + * xprt_force_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * + */ +void xprt_force_disconnect(struct rpc_xprt *xprt) +{ + trace_xprt_disconnect_force(xprt); + + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock(&xprt->transport_lock); + xprt_schedule_autoclose_locked(xprt); + spin_unlock(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_force_disconnect); + +static unsigned int +xprt_connect_cookie(struct rpc_xprt *xprt) +{ + return READ_ONCE(xprt->connect_cookie); +} + +static bool +xprt_request_retransmit_after_disconnect(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + return req->rq_connect_cookie != xprt_connect_cookie(xprt) || + !xprt_connected(xprt); +} + +/** + * xprt_conditional_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * @cookie: 'connection cookie' + * + * This attempts to break the connection if and only if 'cookie' matches + * the current transport 'connection cookie'. It ensures that we don't + * try to break the connection more than once when we need to retransmit + * a batch of RPC requests. + * + */ +void xprt_conditional_disconnect(struct rpc_xprt *xprt, unsigned int cookie) +{ + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock(&xprt->transport_lock); + if (cookie != xprt->connect_cookie) + goto out; + if (test_bit(XPRT_CLOSING, &xprt->state)) + goto out; + xprt_schedule_autoclose_locked(xprt); +out: + spin_unlock(&xprt->transport_lock); +} + +static bool +xprt_has_timer(const struct rpc_xprt *xprt) +{ + return xprt->idle_timeout != 0; +} + +static void +xprt_schedule_autodisconnect(struct rpc_xprt *xprt) + __must_hold(&xprt->transport_lock) +{ + xprt->last_used = jiffies; + if (RB_EMPTY_ROOT(&xprt->recv_queue) && xprt_has_timer(xprt)) + mod_timer(&xprt->timer, xprt->last_used + xprt->idle_timeout); +} + +static void +xprt_init_autodisconnect(struct timer_list *t) +{ + struct rpc_xprt *xprt = from_timer(xprt, t, timer); + + if (!RB_EMPTY_ROOT(&xprt->recv_queue)) + return; + /* Reset xprt->last_used to avoid connect/autodisconnect cycling */ + xprt->last_used = jiffies; + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + queue_work(xprtiod_workqueue, &xprt->task_cleanup); +} + +#if IS_ENABLED(CONFIG_FAIL_SUNRPC) +static void xprt_inject_disconnect(struct rpc_xprt *xprt) +{ + if (!fail_sunrpc.ignore_client_disconnect && + should_fail(&fail_sunrpc.attr, 1)) + xprt->ops->inject_disconnect(xprt); +} +#else +static inline void xprt_inject_disconnect(struct rpc_xprt *xprt) +{ +} +#endif + +bool xprt_lock_connect(struct rpc_xprt *xprt, + struct rpc_task *task, + void *cookie) +{ + bool ret = false; + + spin_lock(&xprt->transport_lock); + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + if (xprt->snd_task != task) + goto out; + set_bit(XPRT_SND_IS_COOKIE, &xprt->state); + xprt->snd_task = cookie; + ret = true; +out: + spin_unlock(&xprt->transport_lock); + return ret; +} +EXPORT_SYMBOL_GPL(xprt_lock_connect); + +void xprt_unlock_connect(struct rpc_xprt *xprt, void *cookie) +{ + spin_lock(&xprt->transport_lock); + if (xprt->snd_task != cookie) + goto out; + if (!test_bit(XPRT_LOCKED, &xprt->state)) + goto out; + 
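/* The lock is still being held on behalf of @cookie: drop the cookie + * placeholder, release the write lock back to the transport, and + * re-arm the autodisconnect timer. + */ +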
xprt->snd_task =NULL; + clear_bit(XPRT_SND_IS_COOKIE, &xprt->state); + xprt->ops->release_xprt(xprt, NULL); + xprt_schedule_autodisconnect(xprt); +out: + spin_unlock(&xprt->transport_lock); + wake_up_bit(&xprt->state, XPRT_LOCKED); +} +EXPORT_SYMBOL_GPL(xprt_unlock_connect); + +/** + * xprt_connect - schedule a transport connect operation + * @task: RPC task that is requesting the connect + * + */ +void xprt_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + trace_xprt_connect(xprt); + + if (!xprt_bound(xprt)) { + task->tk_status = -EAGAIN; + return; + } + if (!xprt_lock_write(xprt, task)) + return; + + if (!xprt_connected(xprt) && !test_bit(XPRT_CLOSE_WAIT, &xprt->state)) { + task->tk_rqstp->rq_connect_cookie = xprt->connect_cookie; + rpc_sleep_on_timeout(&xprt->pending, task, NULL, + xprt_request_timeout(task->tk_rqstp)); + + if (test_bit(XPRT_CLOSING, &xprt->state)) + return; + if (xprt_test_and_set_connecting(xprt)) + return; + /* Race breaker */ + if (!xprt_connected(xprt)) { + xprt->stat.connect_start = jiffies; + xprt->ops->connect(xprt, task); + } else { + xprt_clear_connecting(xprt); + task->tk_status = 0; + rpc_wake_up_queued_task(&xprt->pending, task); + } + } + xprt_release_write(xprt, task); +} + +/** + * xprt_reconnect_delay - compute the wait before scheduling a connect + * @xprt: transport instance + * + */ +unsigned long xprt_reconnect_delay(const struct rpc_xprt *xprt) +{ + unsigned long start, now = jiffies; + + start = xprt->stat.connect_start + xprt->reestablish_timeout; + if (time_after(start, now)) + return start - now; + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reconnect_delay); + +/** + * xprt_reconnect_backoff - compute the new re-establish timeout + * @xprt: transport instance + * @init_to: initial reestablish timeout + * + */ +void xprt_reconnect_backoff(struct rpc_xprt *xprt, unsigned long init_to) +{ + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > xprt->max_reconnect_timeout) + xprt->reestablish_timeout = xprt->max_reconnect_timeout; + if (xprt->reestablish_timeout < init_to) + xprt->reestablish_timeout = init_to; +} +EXPORT_SYMBOL_GPL(xprt_reconnect_backoff); + +enum xprt_xid_rb_cmp { + XID_RB_EQUAL, + XID_RB_LEFT, + XID_RB_RIGHT, +}; +static enum xprt_xid_rb_cmp +xprt_xid_cmp(__be32 xid1, __be32 xid2) +{ + if (xid1 == xid2) + return XID_RB_EQUAL; + if ((__force u32)xid1 < (__force u32)xid2) + return XID_RB_LEFT; + return XID_RB_RIGHT; +} + +static struct rpc_rqst * +xprt_request_rb_find(struct rpc_xprt *xprt, __be32 xid) +{ + struct rb_node *n = xprt->recv_queue.rb_node; + struct rpc_rqst *req; + + while (n != NULL) { + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch (xprt_xid_cmp(xid, req->rq_xid)) { + case XID_RB_LEFT: + n = n->rb_left; + break; + case XID_RB_RIGHT: + n = n->rb_right; + break; + case XID_RB_EQUAL: + return req; + } + } + return NULL; +} + +static void +xprt_request_rb_insert(struct rpc_xprt *xprt, struct rpc_rqst *new) +{ + struct rb_node **p = &xprt->recv_queue.rb_node; + struct rb_node *n = NULL; + struct rpc_rqst *req; + + while (*p != NULL) { + n = *p; + req = rb_entry(n, struct rpc_rqst, rq_recv); + switch(xprt_xid_cmp(new->rq_xid, req->rq_xid)) { + case XID_RB_LEFT: + p = &n->rb_left; + break; + case XID_RB_RIGHT: + p = &n->rb_right; + break; + case XID_RB_EQUAL: + WARN_ON_ONCE(new != req); + return; + } + } + rb_link_node(&new->rq_recv, n, p); + rb_insert_color(&new->rq_recv, &xprt->recv_queue); +} + +static void +xprt_request_rb_remove(struct rpc_xprt *xprt, struct 
rpc_rqst *req) +{ + rb_erase(&req->rq_recv, &xprt->recv_queue); +} + +/** + * xprt_lookup_rqst - find an RPC request corresponding to an XID + * @xprt: transport on which the original request was transmitted + * @xid: RPC XID of incoming reply + * + * Caller holds xprt->queue_lock. + */ +struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) +{ + struct rpc_rqst *entry; + + entry = xprt_request_rb_find(xprt, xid); + if (entry != NULL) { + trace_xprt_lookup_rqst(xprt, xid, 0); + entry->rq_rtt = ktime_sub(ktime_get(), entry->rq_xtime); + return entry; + } + + dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", + ntohl(xid)); + trace_xprt_lookup_rqst(xprt, xid, -ENOENT); + xprt->stat.bad_xids++; + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_lookup_rqst); + +static bool +xprt_is_pinned_rqst(struct rpc_rqst *req) +{ + return atomic_read(&req->rq_pin) != 0; +} + +/** + * xprt_pin_rqst - Pin a request on the transport receive list + * @req: Request to pin + * + * Caller must ensure this is atomic with the call to xprt_lookup_rqst() + * so should be holding xprt->queue_lock. + */ +void xprt_pin_rqst(struct rpc_rqst *req) +{ + atomic_inc(&req->rq_pin); +} +EXPORT_SYMBOL_GPL(xprt_pin_rqst); + +/** + * xprt_unpin_rqst - Unpin a request on the transport receive list + * @req: Request to pin + * + * Caller should be holding xprt->queue_lock. + */ +void xprt_unpin_rqst(struct rpc_rqst *req) +{ + if (!test_bit(RPC_TASK_MSG_PIN_WAIT, &req->rq_task->tk_runstate)) { + atomic_dec(&req->rq_pin); + return; + } + if (atomic_dec_and_test(&req->rq_pin)) + wake_up_var(&req->rq_pin); +} +EXPORT_SYMBOL_GPL(xprt_unpin_rqst); + +static void xprt_wait_on_pinned_rqst(struct rpc_rqst *req) +{ + wait_var_event(&req->rq_pin, !xprt_is_pinned_rqst(req)); +} + +static bool +xprt_request_data_received(struct rpc_task *task) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) != 0; +} + +static bool +xprt_request_need_enqueue_receive(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) && + READ_ONCE(task->tk_rqstp->rq_reply_bytes_recvd) == 0; +} + +/** + * xprt_request_enqueue_receive - Add an request to the receive queue + * @task: RPC task + * + */ +int +xprt_request_enqueue_receive(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int ret; + + if (!xprt_request_need_enqueue_receive(task, req)) + return 0; + + ret = xprt_request_prepare(task->tk_rqstp, &req->rq_rcv_buf); + if (ret) + return ret; + spin_lock(&xprt->queue_lock); + + /* Update the softirq receive buffer */ + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + + /* Add request to the receive list */ + xprt_request_rb_insert(xprt, req); + set_bit(RPC_TASK_NEED_RECV, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + + /* Turn off autodisconnect */ + del_singleshot_timer_sync(&xprt->timer); + return 0; +} + +/** + * xprt_request_dequeue_receive_locked - Remove a request from the receive queue + * @task: RPC task + * + * Caller must hold xprt->queue_lock. + */ +static void +xprt_request_dequeue_receive_locked(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_clear_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + xprt_request_rb_remove(req->rq_xprt, req); +} + +/** + * xprt_update_rtt - Update RPC RTT statistics + * @task: RPC request that recently completed + * + * Caller holds xprt->queue_lock. 
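+ *
+ * Only the first transmission of a request contributes an RTT sample
+ * (rq_ntrans == 1), much as Karn's algorithm prescribes, so replies to
+ * retransmitted requests never skew the estimator; rpc_set_timeo()
+ * instead records how often the request was retransmitted so that
+ * xprt_wait_for_reply_request_rtt() can scale the next timeout.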
+ */ +void xprt_update_rtt(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_rtt *rtt = task->tk_client->cl_rtt; + unsigned int timer = task->tk_msg.rpc_proc->p_timer; + long m = usecs_to_jiffies(ktime_to_us(req->rq_rtt)); + + if (timer) { + if (req->rq_ntrans == 1) + rpc_update_rtt(rtt, timer, m); + rpc_set_timeo(rtt, timer, req->rq_ntrans - 1); + } +} +EXPORT_SYMBOL_GPL(xprt_update_rtt); + +/** + * xprt_complete_rqst - called when reply processing is complete + * @task: RPC request that recently completed + * @copied: actual number of bytes received from the transport + * + * Caller holds xprt->queue_lock. + */ +void xprt_complete_rqst(struct rpc_task *task, int copied) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + xprt->stat.recvs++; + + xdr_free_bvec(&req->rq_rcv_buf); + req->rq_private_buf.bvec = NULL; + req->rq_private_buf.len = copied; + /* Ensure all writes are done before we update */ + /* req->rq_reply_bytes_recvd */ + smp_wmb(); + req->rq_reply_bytes_recvd = copied; + xprt_request_dequeue_receive_locked(task); + rpc_wake_up_queued_task(&xprt->pending, task); +} +EXPORT_SYMBOL_GPL(xprt_complete_rqst); + +static void xprt_timer(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (task->tk_status != -ETIMEDOUT) + return; + + trace_xprt_timer(xprt, req->rq_xid, task->tk_status); + if (!req->rq_reply_bytes_recvd) { + if (xprt->ops->timer) + xprt->ops->timer(xprt, task); + } else + task->tk_status = 0; +} + +/** + * xprt_wait_for_reply_request_def - wait for reply + * @task: pointer to rpc_task + * + * Set a request's retransmit timeout based on the transport's + * default timeout parameters. Used by transports that don't adjust + * the retransmit timeout based on round-trip time estimation, + * and put the task to sleep on the pending queue. + */ +void xprt_wait_for_reply_request_def(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer, + xprt_request_timeout(req)); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_def); + +/** + * xprt_wait_for_reply_request_rtt - wait for reply using RTT estimator + * @task: pointer to rpc_task + * + * Set a request's retransmit timeout using the RTT estimator, + * and put the task to sleep on the pending queue. + */ +void xprt_wait_for_reply_request_rtt(struct rpc_task *task) +{ + int timer = task->tk_msg.rpc_proc->p_timer; + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rtt *rtt = clnt->cl_rtt; + struct rpc_rqst *req = task->tk_rqstp; + unsigned long max_timeout = clnt->cl_timeout->to_maxval; + unsigned long timeout; + + timeout = rpc_calc_rto(rtt, timer); + timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries; + if (timeout > max_timeout || timeout == 0) + timeout = max_timeout; + rpc_sleep_on_timeout(&req->rq_xprt->pending, task, xprt_timer, + jiffies + timeout); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_reply_request_rtt); + +/** + * xprt_request_wait_receive - wait for the reply to an RPC request + * @task: RPC task about to send a request + * + */ +void xprt_request_wait_receive(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (!test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) + return; + /* + * Sleep on the pending queue if we're expecting a reply. + * The spinlock ensures atomicity between the test of + * req->rq_reply_bytes_recvd, and the call to rpc_sleep_on(). 
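+ * RPC_TASK_NEED_RECV is re-tested under xprt->queue_lock below because
+ * xprt_complete_rqst() clears that bit and wakes the task while holding
+ * the same lock, which closes the window in which a reply could arrive
+ * before the task has gone to sleep.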
+ */ + spin_lock(&xprt->queue_lock); + if (test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate)) { + xprt->ops->wait_for_reply_request(task); + /* + * Send an extra queue wakeup call if the + * connection was dropped in case the call to + * rpc_sleep_on() raced. + */ + if (xprt_request_retransmit_after_disconnect(task)) + rpc_wake_up_queued_task_set_status(&xprt->pending, + task, -ENOTCONN); + } + spin_unlock(&xprt->queue_lock); +} + +static bool +xprt_request_need_enqueue_transmit(struct rpc_task *task, struct rpc_rqst *req) +{ + return !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); +} + +/** + * xprt_request_enqueue_transmit - queue a task for transmission + * @task: pointer to rpc_task + * + * Add a task to the transmission queue. + */ +void +xprt_request_enqueue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *pos, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int ret; + + if (xprt_request_need_enqueue_transmit(task, req)) { + ret = xprt_request_prepare(task->tk_rqstp, &req->rq_snd_buf); + if (ret) { + task->tk_status = ret; + return; + } + req->rq_bytes_sent = 0; + spin_lock(&xprt->queue_lock); + /* + * Requests that carry congestion control credits are added + * to the head of the list to avoid starvation issues. + */ + if (req->rq_cong) { + xprt_clear_congestion_window_wait(xprt); + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_cong) + continue; + /* Note: req is added _before_ pos */ + list_add_tail(&req->rq_xmit, &pos->rq_xmit); + INIT_LIST_HEAD(&req->rq_xmit2); + goto out; + } + } else if (!req->rq_seqno) { + list_for_each_entry(pos, &xprt->xmit_queue, rq_xmit) { + if (pos->rq_task->tk_owner != task->tk_owner) + continue; + list_add_tail(&req->rq_xmit2, &pos->rq_xmit2); + INIT_LIST_HEAD(&req->rq_xmit); + goto out; + } + } + list_add_tail(&req->rq_xmit, &xprt->xmit_queue); + INIT_LIST_HEAD(&req->rq_xmit2); +out: + atomic_long_inc(&xprt->xmit_queuelen); + set_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + } +} + +/** + * xprt_request_dequeue_transmit_locked - remove a task from the transmission queue + * @task: pointer to rpc_task + * + * Remove a task from the transmission queue + * Caller must hold xprt->queue_lock + */ +static void +xprt_request_dequeue_transmit_locked(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (!test_and_clear_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + return; + if (!list_empty(&req->rq_xmit)) { + list_del(&req->rq_xmit); + if (!list_empty(&req->rq_xmit2)) { + struct rpc_rqst *next = list_first_entry(&req->rq_xmit2, + struct rpc_rqst, rq_xmit2); + list_del(&req->rq_xmit2); + list_add_tail(&next->rq_xmit, &next->rq_xprt->xmit_queue); + } + } else + list_del(&req->rq_xmit2); + atomic_long_dec(&req->rq_xprt->xmit_queuelen); + xdr_free_bvec(&req->rq_snd_buf); +} + +/** + * xprt_request_dequeue_transmit - remove a task from the transmission queue + * @task: pointer to rpc_task + * + * Remove a task from the transmission queue + */ +static void +xprt_request_dequeue_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + spin_lock(&xprt->queue_lock); + xprt_request_dequeue_transmit_locked(task); + spin_unlock(&xprt->queue_lock); +} + +/** + * xprt_request_dequeue_xprt - remove a task from the transmit+receive queue + * @task: pointer to rpc_task + * + * Remove a task from the transmit and receive queues, and ensure that + * it is not pinned by the receive work item. 
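+ *
+ * While the request is still pinned, RPC_TASK_MSG_PIN_WAIT is set and
+ * xprt->queue_lock is dropped so that the receive path holding the pin
+ * can make progress; xprt_unpin_rqst() wakes this waiter once the pin
+ * count drops to zero.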
+ */ +void +xprt_request_dequeue_xprt(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate) || + test_bit(RPC_TASK_NEED_RECV, &task->tk_runstate) || + xprt_is_pinned_rqst(req)) { + spin_lock(&xprt->queue_lock); + while (xprt_is_pinned_rqst(req)) { + set_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + spin_unlock(&xprt->queue_lock); + xprt_wait_on_pinned_rqst(req); + spin_lock(&xprt->queue_lock); + clear_bit(RPC_TASK_MSG_PIN_WAIT, &task->tk_runstate); + } + xprt_request_dequeue_transmit_locked(task); + xprt_request_dequeue_receive_locked(task); + spin_unlock(&xprt->queue_lock); + xdr_free_bvec(&req->rq_rcv_buf); + } +} + +/** + * xprt_request_prepare - prepare an encoded request for transport + * @req: pointer to rpc_rqst + * @buf: pointer to send/rcv xdr_buf + * + * Calls into the transport layer to do whatever is needed to prepare + * the request for transmission or receive. + * Returns error, or zero. + */ +static int +xprt_request_prepare(struct rpc_rqst *req, struct xdr_buf *buf) +{ + struct rpc_xprt *xprt = req->rq_xprt; + + if (xprt->ops->prepare_request) + return xprt->ops->prepare_request(req, buf); + return 0; +} + +/** + * xprt_request_need_retransmit - Test if a task needs retransmission + * @task: pointer to rpc_task + * + * Test for whether a connection breakage requires the task to retransmit + */ +bool +xprt_request_need_retransmit(struct rpc_task *task) +{ + return xprt_request_retransmit_after_disconnect(task); +} + +/** + * xprt_prepare_transmit - reserve the transport before sending a request + * @task: RPC task about to send a request + * + */ +bool xprt_prepare_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (!xprt_lock_write(xprt, task)) { + /* Race breaker: someone may have transmitted us */ + if (!test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + rpc_wake_up_queued_task_set_status(&xprt->sending, + task, 0); + return false; + + } + if (atomic_read(&xprt->swapper)) + /* This will be clear in __rpc_execute */ + current->flags |= PF_MEMALLOC; + return true; +} + +void xprt_end_transmit(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_rqstp->rq_xprt; + + xprt_inject_disconnect(xprt); + xprt_release_write(xprt, task); +} + +/** + * xprt_request_transmit - send an RPC request on a transport + * @req: pointer to request to transmit + * @snd_task: RPC task that owns the transport lock + * + * This performs the transmission of a single request. + * Note that if the request is not the same as snd_task, then it + * does need to be pinned. + * Returns '0' on success. + */ +static int +xprt_request_transmit(struct rpc_rqst *req, struct rpc_task *snd_task) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct rpc_task *task = req->rq_task; + unsigned int connect_cookie; + int is_retrans = RPC_WAS_SENT(task); + int status; + + if (!req->rq_bytes_sent) { + if (xprt_request_data_received(task)) { + status = 0; + goto out_dequeue; + } + /* Verify that our message lies in the RPCSEC_GSS window */ + if (rpcauth_xmit_need_reencode(task)) { + status = -EBADMSG; + goto out_dequeue; + } + if (RPC_SIGNALLED(task)) { + status = -ERESTARTSYS; + goto out_dequeue; + } + } + + /* + * Update req->rq_ntrans before transmitting to avoid races with + * xprt_update_rtt(), which needs to know that it is recording a + * reply to the first transmission. 
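+ * If ->send_request() fails, rq_ntrans is decremented again below, so
+ * only transmissions the transport accepted remain counted.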
+ */ + req->rq_ntrans++; + + trace_rpc_xdr_sendto(task, &req->rq_snd_buf); + connect_cookie = xprt->connect_cookie; + status = xprt->ops->send_request(req); + if (status != 0) { + req->rq_ntrans--; + trace_xprt_transmit(req, status); + return status; + } + + if (is_retrans) { + task->tk_client->cl_stats->rpcretrans++; + trace_xprt_retransmit(req); + } + + xprt_inject_disconnect(xprt); + + task->tk_flags |= RPC_TASK_SENT; + spin_lock(&xprt->transport_lock); + + xprt->stat.sends++; + xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; + xprt->stat.bklog_u += xprt->backlog.qlen; + xprt->stat.sending_u += xprt->sending.qlen; + xprt->stat.pending_u += xprt->pending.qlen; + spin_unlock(&xprt->transport_lock); + + req->rq_connect_cookie = connect_cookie; +out_dequeue: + trace_xprt_transmit(req, status); + xprt_request_dequeue_transmit(task); + rpc_wake_up_queued_task_set_status(&xprt->sending, task, status); + return status; +} + +/** + * xprt_transmit - send an RPC request on a transport + * @task: controlling RPC task + * + * Attempts to drain the transmit queue. On exit, either the transport + * signalled an error that needs to be handled before transmission can + * resume, or @task finished transmitting, and detected that it already + * received a reply. + */ +void +xprt_transmit(struct rpc_task *task) +{ + struct rpc_rqst *next, *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int status; + + spin_lock(&xprt->queue_lock); + for (;;) { + next = list_first_entry_or_null(&xprt->xmit_queue, + struct rpc_rqst, rq_xmit); + if (!next) + break; + xprt_pin_rqst(next); + spin_unlock(&xprt->queue_lock); + status = xprt_request_transmit(next, task); + if (status == -EBADMSG && next != req) + status = 0; + spin_lock(&xprt->queue_lock); + xprt_unpin_rqst(next); + if (status < 0) { + if (test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + task->tk_status = status; + break; + } + /* Was @task transmitted, and has it received a reply? 
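+ * If so, stop draining the queue here; anything still on
+ * xprt->xmit_queue will be sent by the next task that takes the
+ * transport write lock and calls xprt_transmit().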
*/ + if (xprt_request_data_received(task) && + !test_bit(RPC_TASK_NEED_XMIT, &task->tk_runstate)) + break; + cond_resched_lock(&xprt->queue_lock); + } + spin_unlock(&xprt->queue_lock); +} + +static void xprt_complete_request_init(struct rpc_task *task) +{ + if (task->tk_rqstp) + xprt_request_init(task); +} + +void xprt_add_backlog(struct rpc_xprt *xprt, struct rpc_task *task) +{ + set_bit(XPRT_CONGESTED, &xprt->state); + rpc_sleep_on(&xprt->backlog, task, xprt_complete_request_init); +} +EXPORT_SYMBOL_GPL(xprt_add_backlog); + +static bool __xprt_set_rq(struct rpc_task *task, void *data) +{ + struct rpc_rqst *req = data; + + if (task->tk_rqstp == NULL) { + memset(req, 0, sizeof(*req)); /* mark unused */ + task->tk_rqstp = req; + return true; + } + return false; +} + +bool xprt_wake_up_backlog(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (rpc_wake_up_first(&xprt->backlog, __xprt_set_rq, req) == NULL) { + clear_bit(XPRT_CONGESTED, &xprt->state); + return false; + } + return true; +} +EXPORT_SYMBOL_GPL(xprt_wake_up_backlog); + +static bool xprt_throttle_congested(struct rpc_xprt *xprt, struct rpc_task *task) +{ + bool ret = false; + + if (!test_bit(XPRT_CONGESTED, &xprt->state)) + goto out; + spin_lock(&xprt->reserve_lock); + if (test_bit(XPRT_CONGESTED, &xprt->state)) { + xprt_add_backlog(xprt, task); + ret = true; + } + spin_unlock(&xprt->reserve_lock); +out: + return ret; +} + +static struct rpc_rqst *xprt_dynamic_alloc_slot(struct rpc_xprt *xprt) +{ + struct rpc_rqst *req = ERR_PTR(-EAGAIN); + + if (xprt->num_reqs >= xprt->max_reqs) + goto out; + ++xprt->num_reqs; + spin_unlock(&xprt->reserve_lock); + req = kzalloc(sizeof(*req), rpc_task_gfp_mask()); + spin_lock(&xprt->reserve_lock); + if (req != NULL) + goto out; + --xprt->num_reqs; + req = ERR_PTR(-ENOMEM); +out: + return req; +} + +static bool xprt_dynamic_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (xprt->num_reqs > xprt->min_reqs) { + --xprt->num_reqs; + kfree(req); + return true; + } + return false; +} + +void xprt_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req; + + spin_lock(&xprt->reserve_lock); + if (!list_empty(&xprt->free)) { + req = list_entry(xprt->free.next, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + goto out_init_req; + } + req = xprt_dynamic_alloc_slot(xprt); + if (!IS_ERR(req)) + goto out_init_req; + switch (PTR_ERR(req)) { + case -ENOMEM: + dprintk("RPC: dynamic allocation of request slot " + "failed! 
Retrying\n"); + task->tk_status = -ENOMEM; + break; + case -EAGAIN: + xprt_add_backlog(xprt, task); + dprintk("RPC: waiting for request slot\n"); + fallthrough; + default: + task->tk_status = -EAGAIN; + } + spin_unlock(&xprt->reserve_lock); + return; +out_init_req: + xprt->stat.max_slots = max_t(unsigned int, xprt->stat.max_slots, + xprt->num_reqs); + spin_unlock(&xprt->reserve_lock); + + task->tk_status = 0; + task->tk_rqstp = req; +} +EXPORT_SYMBOL_GPL(xprt_alloc_slot); + +void xprt_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + spin_lock(&xprt->reserve_lock); + if (!xprt_wake_up_backlog(xprt, req) && + !xprt_dynamic_free_slot(xprt, req)) { + memset(req, 0, sizeof(*req)); /* mark unused */ + list_add(&req->rq_list, &xprt->free); + } + spin_unlock(&xprt->reserve_lock); +} +EXPORT_SYMBOL_GPL(xprt_free_slot); + +static void xprt_free_all_slots(struct rpc_xprt *xprt) +{ + struct rpc_rqst *req; + while (!list_empty(&xprt->free)) { + req = list_first_entry(&xprt->free, struct rpc_rqst, rq_list); + list_del(&req->rq_list); + kfree(req); + } +} + +static DEFINE_IDA(rpc_xprt_ids); + +void xprt_cleanup_ids(void) +{ + ida_destroy(&rpc_xprt_ids); +} + +static int xprt_alloc_id(struct rpc_xprt *xprt) +{ + int id; + + id = ida_alloc(&rpc_xprt_ids, GFP_KERNEL); + if (id < 0) + return id; + + xprt->id = id; + return 0; +} + +static void xprt_free_id(struct rpc_xprt *xprt) +{ + ida_free(&rpc_xprt_ids, xprt->id); +} + +struct rpc_xprt *xprt_alloc(struct net *net, size_t size, + unsigned int num_prealloc, + unsigned int max_alloc) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req; + int i; + + xprt = kzalloc(size, GFP_KERNEL); + if (xprt == NULL) + goto out; + + xprt_alloc_id(xprt); + xprt_init(xprt, net); + + for (i = 0; i < num_prealloc; i++) { + req = kzalloc(sizeof(struct rpc_rqst), GFP_KERNEL); + if (!req) + goto out_free; + list_add(&req->rq_list, &xprt->free); + } + xprt->max_reqs = max_t(unsigned int, max_alloc, num_prealloc); + xprt->min_reqs = num_prealloc; + xprt->num_reqs = num_prealloc; + + return xprt; + +out_free: + xprt_free(xprt); +out: + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_alloc); + +void xprt_free(struct rpc_xprt *xprt) +{ + put_net_track(xprt->xprt_net, &xprt->ns_tracker); + xprt_free_all_slots(xprt); + xprt_free_id(xprt); + rpc_sysfs_xprt_destroy(xprt); + kfree_rcu(xprt, rcu); +} +EXPORT_SYMBOL_GPL(xprt_free); + +static void +xprt_init_connect_cookie(struct rpc_rqst *req, struct rpc_xprt *xprt) +{ + req->rq_connect_cookie = xprt_connect_cookie(xprt) - 1; +} + +static __be32 +xprt_alloc_xid(struct rpc_xprt *xprt) +{ + __be32 xid; + + spin_lock(&xprt->reserve_lock); + xid = (__force __be32)xprt->xid++; + spin_unlock(&xprt->reserve_lock); + return xid; +} + +static void +xprt_init_xid(struct rpc_xprt *xprt) +{ + xprt->xid = get_random_u32(); +} + +static void +xprt_request_init(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + req->rq_task = task; + req->rq_xprt = xprt; + req->rq_buffer = NULL; + req->rq_xid = xprt_alloc_xid(xprt); + xprt_init_connect_cookie(req, xprt); + req->rq_snd_buf.len = 0; + req->rq_snd_buf.buflen = 0; + req->rq_rcv_buf.len = 0; + req->rq_rcv_buf.buflen = 0; + req->rq_snd_buf.bvec = NULL; + req->rq_rcv_buf.bvec = NULL; + req->rq_release_snd_buf = NULL; + xprt_init_majortimeo(task, req); + + trace_xprt_reserve(req); +} + +static void +xprt_do_reserve(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt->ops->alloc_slot(xprt, task); + if (task->tk_rqstp != NULL) + 
xprt_request_init(task); +} + +/** + * xprt_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If the transport is marked as being congested, or if no more + * slots are available, place the task on the transport's + * backlog queue. + */ +void xprt_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_status = -EAGAIN; + if (!xprt_throttle_congested(xprt, task)) + xprt_do_reserve(xprt, task); +} + +/** + * xprt_retry_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If no more slots are available, place the task on the transport's + * backlog queue. + * Note that the only difference with xprt_reserve is that we now + * ignore the value of the XPRT_CONGESTED flag. + */ +void xprt_retry_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = 0; + if (task->tk_rqstp != NULL) + return; + + task->tk_status = -EAGAIN; + xprt_do_reserve(xprt, task); +} + +/** + * xprt_release - release an RPC request slot + * @task: task which is finished with the slot + * + */ +void xprt_release(struct rpc_task *task) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req = task->tk_rqstp; + + if (req == NULL) { + if (task->tk_client) { + xprt = task->tk_xprt; + xprt_release_write(xprt, task); + } + return; + } + + xprt = req->rq_xprt; + xprt_request_dequeue_xprt(task); + spin_lock(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + if (xprt->ops->release_request) + xprt->ops->release_request(task); + xprt_schedule_autodisconnect(xprt); + spin_unlock(&xprt->transport_lock); + if (req->rq_buffer) + xprt->ops->buf_free(task); + if (req->rq_cred != NULL) + put_rpccred(req->rq_cred); + if (req->rq_release_snd_buf) + req->rq_release_snd_buf(req); + + task->tk_rqstp = NULL; + if (likely(!bc_prealloc(req))) + xprt->ops->free_slot(xprt, req); + else + xprt_free_bc_request(req); +} + +#ifdef CONFIG_SUNRPC_BACKCHANNEL +void +xprt_init_bc_request(struct rpc_rqst *req, struct rpc_task *task) +{ + struct xdr_buf *xbufp = &req->rq_snd_buf; + + task->tk_rqstp = req; + req->rq_task = task; + xprt_init_connect_cookie(req, req->rq_xprt); + /* + * Set up the xdr_buf length. + * This also indicates that the buffer is XDR encoded already. 
+ */ + xbufp->len = xbufp->head[0].iov_len + xbufp->page_len + + xbufp->tail[0].iov_len; +} +#endif + +static void xprt_init(struct rpc_xprt *xprt, struct net *net) +{ + kref_init(&xprt->kref); + + spin_lock_init(&xprt->transport_lock); + spin_lock_init(&xprt->reserve_lock); + spin_lock_init(&xprt->queue_lock); + + INIT_LIST_HEAD(&xprt->free); + xprt->recv_queue = RB_ROOT; + INIT_LIST_HEAD(&xprt->xmit_queue); +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + spin_lock_init(&xprt->bc_pa_lock); + INIT_LIST_HEAD(&xprt->bc_pa_list); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + INIT_LIST_HEAD(&xprt->xprt_switch); + + xprt->last_used = jiffies; + xprt->cwnd = RPC_INITCWND; + xprt->bind_index = 0; + + rpc_init_wait_queue(&xprt->binding, "xprt_binding"); + rpc_init_wait_queue(&xprt->pending, "xprt_pending"); + rpc_init_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); + + xprt_init_xid(xprt); + + xprt->xprt_net = get_net_track(net, &xprt->ns_tracker, GFP_KERNEL); +} + +/** + * xprt_create_transport - create an RPC transport + * @args: rpc transport creation arguments + * + */ +struct rpc_xprt *xprt_create_transport(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + const struct xprt_class *t; + + t = xprt_class_find_by_ident(args->ident); + if (!t) { + dprintk("RPC: transport (%d) not supported\n", args->ident); + return ERR_PTR(-EIO); + } + + xprt = t->setup(args); + xprt_class_release(t); + + if (IS_ERR(xprt)) + goto out; + if (args->flags & XPRT_CREATE_NO_IDLE_TIMEOUT) + xprt->idle_timeout = 0; + INIT_WORK(&xprt->task_cleanup, xprt_autoclose); + if (xprt_has_timer(xprt)) + timer_setup(&xprt->timer, xprt_init_autodisconnect, 0); + else + timer_setup(&xprt->timer, NULL, 0); + + if (strlen(args->servername) > RPC_MAXNETNAMELEN) { + xprt_destroy(xprt); + return ERR_PTR(-EINVAL); + } + xprt->servername = kstrdup(args->servername, GFP_KERNEL); + if (xprt->servername == NULL) { + xprt_destroy(xprt); + return ERR_PTR(-ENOMEM); + } + + rpc_xprt_debugfs_register(xprt); + + trace_xprt_create(xprt); +out: + return xprt; +} + +static void xprt_destroy_cb(struct work_struct *work) +{ + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + + trace_xprt_destroy(xprt); + + rpc_xprt_debugfs_unregister(xprt); + rpc_destroy_wait_queue(&xprt->binding); + rpc_destroy_wait_queue(&xprt->pending); + rpc_destroy_wait_queue(&xprt->sending); + rpc_destroy_wait_queue(&xprt->backlog); + kfree(xprt->servername); + /* + * Destroy any existing back channel + */ + xprt_destroy_backchannel(xprt, UINT_MAX); + + /* + * Tear down transport state and free the rpc_xprt + */ + xprt->ops->destroy(xprt); +} + +/** + * xprt_destroy - destroy an RPC transport, killing off all requests. + * @xprt: transport to destroy + * + */ +static void xprt_destroy(struct rpc_xprt *xprt) +{ + /* + * Exclude transport connect/disconnect handlers and autoclose + */ + wait_on_bit_lock(&xprt->state, XPRT_LOCKED, TASK_UNINTERRUPTIBLE); + + /* + * xprt_schedule_autodisconnect() can run after XPRT_LOCKED + * is cleared. We use ->transport_lock to ensure the mod_timer() + * can only run *before* del_time_sync(), never after. + */ + spin_lock(&xprt->transport_lock); + del_timer_sync(&xprt->timer); + spin_unlock(&xprt->transport_lock); + + /* + * Destroy sockets etc from the system workqueue so they can + * safely flush receive work running on rpciod. 
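+ * The task_cleanup work item is re-initialised to run xprt_destroy_cb()
+ * instead of xprt_autoclose(); this is safe because XPRT_LOCKED was
+ * taken above, so the autoclose paths can no longer queue it.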
+ */ + INIT_WORK(&xprt->task_cleanup, xprt_destroy_cb); + schedule_work(&xprt->task_cleanup); +} + +static void xprt_destroy_kref(struct kref *kref) +{ + xprt_destroy(container_of(kref, struct rpc_xprt, kref)); +} + +/** + * xprt_get - return a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +struct rpc_xprt *xprt_get(struct rpc_xprt *xprt) +{ + if (xprt != NULL && kref_get_unless_zero(&xprt->kref)) + return xprt; + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_get); + +/** + * xprt_put - release a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +void xprt_put(struct rpc_xprt *xprt) +{ + if (xprt != NULL) + kref_put(&xprt->kref, xprt_destroy_kref); +} +EXPORT_SYMBOL_GPL(xprt_put); + +void xprt_set_offline_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (!test_and_set_bit(XPRT_OFFLINE, &xprt->state)) { + spin_lock(&xps->xps_lock); + xps->xps_nactive--; + spin_unlock(&xps->xps_lock); + } +} + +void xprt_set_online_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (test_and_clear_bit(XPRT_OFFLINE, &xprt->state)) { + spin_lock(&xps->xps_lock); + xps->xps_nactive++; + spin_unlock(&xps->xps_lock); + } +} + +void xprt_delete_locked(struct rpc_xprt *xprt, struct rpc_xprt_switch *xps) +{ + if (test_and_set_bit(XPRT_REMOVE, &xprt->state)) + return; + + xprt_force_disconnect(xprt); + if (!test_bit(XPRT_CONNECTED, &xprt->state)) + return; + + if (!xprt->sending.qlen && !xprt->pending.qlen && + !xprt->backlog.qlen && !atomic_long_read(&xprt->queuelen)) + rpc_xprt_switch_remove_xprt(xps, xprt, true); +} diff --git a/net/sunrpc/xprtmultipath.c b/net/sunrpc/xprtmultipath.c new file mode 100644 index 000000000..74ee22712 --- /dev/null +++ b/net/sunrpc/xprtmultipath.c @@ -0,0 +1,655 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Multipath support for RPC + * + * Copyright (c) 2015, 2016, Primary Data, Inc. All rights reserved. + * + * Trond Myklebust + * + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "sysfs.h" + +typedef struct rpc_xprt *(*xprt_switch_find_xprt_t)(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur); + +static const struct rpc_xprt_iter_ops rpc_xprt_iter_singular; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_listall; +static const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline; + +static void xprt_switch_add_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (unlikely(xprt_get(xprt) == NULL)) + return; + list_add_tail_rcu(&xprt->xprt_switch, &xps->xps_xprt_list); + smp_wmb(); + if (xps->xps_nxprts == 0) + xps->xps_net = xprt->xprt_net; + xps->xps_nxprts++; + xps->xps_nactive++; +} + +/** + * rpc_xprt_switch_add_xprt - Add a new rpc_xprt to an rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * + * Adds xprt to the end of the list of struct rpc_xprt in xps. 
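+ * The xprt is only linked into the list if it belongs to the same
+ * network namespace as the switch (or if the switch is still empty);
+ * otherwise it is left off the list.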
+ */ +void rpc_xprt_switch_add_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt) +{ + if (xprt == NULL) + return; + spin_lock(&xps->xps_lock); + if (xps->xps_net == xprt->xprt_net || xps->xps_net == NULL) + xprt_switch_add_xprt_locked(xps, xprt); + spin_unlock(&xps->xps_lock); + rpc_sysfs_xprt_setup(xps, xprt, GFP_KERNEL); +} + +static void xprt_switch_remove_xprt_locked(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, bool offline) +{ + if (unlikely(xprt == NULL)) + return; + if (!test_bit(XPRT_OFFLINE, &xprt->state) && offline) + xps->xps_nactive--; + xps->xps_nxprts--; + if (xps->xps_nxprts == 0) + xps->xps_net = NULL; + smp_wmb(); + list_del_rcu(&xprt->xprt_switch); +} + +/** + * rpc_xprt_switch_remove_xprt - Removes an rpc_xprt from a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * @offline: indicates if the xprt that's being removed is in an offline state + * + * Removes xprt from the list of struct rpc_xprt in xps. + */ +void rpc_xprt_switch_remove_xprt(struct rpc_xprt_switch *xps, + struct rpc_xprt *xprt, bool offline) +{ + spin_lock(&xps->xps_lock); + xprt_switch_remove_xprt_locked(xps, xprt, offline); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); +} + +static DEFINE_IDA(rpc_xprtswitch_ids); + +void xprt_multipath_cleanup_ids(void) +{ + ida_destroy(&rpc_xprtswitch_ids); +} + +static int xprt_switch_alloc_id(struct rpc_xprt_switch *xps, gfp_t gfp_flags) +{ + int id; + + id = ida_alloc(&rpc_xprtswitch_ids, gfp_flags); + if (id < 0) + return id; + + xps->xps_id = id; + return 0; +} + +static void xprt_switch_free_id(struct rpc_xprt_switch *xps) +{ + ida_free(&rpc_xprtswitch_ids, xps->xps_id); +} + +/** + * xprt_switch_alloc - Allocate a new struct rpc_xprt_switch + * @xprt: pointer to struct rpc_xprt + * @gfp_flags: allocation flags + * + * On success, returns an initialised struct rpc_xprt_switch, containing + * the entry xprt. Returns NULL on failure. + */ +struct rpc_xprt_switch *xprt_switch_alloc(struct rpc_xprt *xprt, + gfp_t gfp_flags) +{ + struct rpc_xprt_switch *xps; + + xps = kmalloc(sizeof(*xps), gfp_flags); + if (xps != NULL) { + spin_lock_init(&xps->xps_lock); + kref_init(&xps->xps_kref); + xprt_switch_alloc_id(xps, gfp_flags); + xps->xps_nxprts = xps->xps_nactive = 0; + atomic_long_set(&xps->xps_queuelen, 0); + xps->xps_net = NULL; + INIT_LIST_HEAD(&xps->xps_xprt_list); + xps->xps_iter_ops = &rpc_xprt_iter_singular; + rpc_sysfs_xprt_switch_setup(xps, xprt, gfp_flags); + xprt_switch_add_xprt_locked(xps, xprt); + xps->xps_nunique_destaddr_xprts = 1; + rpc_sysfs_xprt_setup(xps, xprt, gfp_flags); + } + + return xps; +} + +static void xprt_switch_free_entries(struct rpc_xprt_switch *xps) +{ + spin_lock(&xps->xps_lock); + while (!list_empty(&xps->xps_xprt_list)) { + struct rpc_xprt *xprt; + + xprt = list_first_entry(&xps->xps_xprt_list, + struct rpc_xprt, xprt_switch); + xprt_switch_remove_xprt_locked(xps, xprt, true); + spin_unlock(&xps->xps_lock); + xprt_put(xprt); + spin_lock(&xps->xps_lock); + } + spin_unlock(&xps->xps_lock); +} + +static void xprt_switch_free(struct kref *kref) +{ + struct rpc_xprt_switch *xps = container_of(kref, + struct rpc_xprt_switch, xps_kref); + + xprt_switch_free_entries(xps); + rpc_sysfs_xprt_switch_destroy(xps); + xprt_switch_free_id(xps); + kfree_rcu(xps, xps_rcu); +} + +/** + * xprt_switch_get - Return a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Returns a reference to xps unless the refcount is already zero. 
+ */ +struct rpc_xprt_switch *xprt_switch_get(struct rpc_xprt_switch *xps) +{ + if (xps != NULL && kref_get_unless_zero(&xps->xps_kref)) + return xps; + return NULL; +} + +/** + * xprt_switch_put - Release a reference to a rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Release the reference to xps, and free it once the refcount is zero. + */ +void xprt_switch_put(struct rpc_xprt_switch *xps) +{ + if (xps != NULL) + kref_put(&xps->xps_kref, xprt_switch_free); +} + +/** + * rpc_xprt_switch_set_roundrobin - Set a round-robin policy on rpc_xprt_switch + * @xps: pointer to struct rpc_xprt_switch + * + * Sets a round-robin default policy for iterators acting on xps. + */ +void rpc_xprt_switch_set_roundrobin(struct rpc_xprt_switch *xps) +{ + if (READ_ONCE(xps->xps_iter_ops) != &rpc_xprt_iter_roundrobin) + WRITE_ONCE(xps->xps_iter_ops, &rpc_xprt_iter_roundrobin); +} + +static +const struct rpc_xprt_iter_ops *xprt_iter_ops(const struct rpc_xprt_iter *xpi) +{ + if (xpi->xpi_ops != NULL) + return xpi->xpi_ops; + return rcu_dereference(xpi->xpi_xpswitch)->xps_iter_ops; +} + +static +void xprt_iter_no_rewind(struct rpc_xprt_iter *xpi) +{ +} + +static +void xprt_iter_default_rewind(struct rpc_xprt_iter *xpi) +{ + WRITE_ONCE(xpi->xpi_cursor, NULL); +} + +static +bool xprt_is_active(const struct rpc_xprt *xprt) +{ + return (kref_read(&xprt->kref) != 0 && + !test_bit(XPRT_OFFLINE, &xprt->state)); +} + +static +struct rpc_xprt *xprt_switch_find_first_entry(struct list_head *head) +{ + struct rpc_xprt *pos; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (xprt_is_active(pos)) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_find_first_entry_offline(struct list_head *head) +{ + struct rpc_xprt *pos; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (!xprt_is_active(pos)) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_iter_first_entry(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_find_first_entry(&xps->xps_xprt_list); +} + +static +struct rpc_xprt *_xprt_switch_find_current_entry(struct list_head *head, + const struct rpc_xprt *cur, + bool find_active) +{ + struct rpc_xprt *pos; + bool found = false; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == pos) + found = true; + if (found && ((find_active && xprt_is_active(pos)) || + (!find_active && !xprt_is_active(pos)))) + return pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_find_current_entry(struct list_head *head, + const struct rpc_xprt *cur) +{ + return _xprt_switch_find_current_entry(head, cur, true); +} + +static +struct rpc_xprt * _xprt_iter_current_entry(struct rpc_xprt_iter *xpi, + struct rpc_xprt *first_entry(struct list_head *head), + struct rpc_xprt *current_entry(struct list_head *head, + const struct rpc_xprt *cur)) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + struct list_head *head; + + if (xps == NULL) + return NULL; + head = &xps->xps_xprt_list; + if (xpi->xpi_cursor == NULL || xps->xps_nxprts < 2) + return first_entry(head); + return current_entry(head, xpi->xpi_cursor); +} + +static +struct rpc_xprt *xprt_iter_current_entry(struct rpc_xprt_iter *xpi) +{ + return _xprt_iter_current_entry(xpi, xprt_switch_find_first_entry, + xprt_switch_find_current_entry); +} + +static +struct rpc_xprt *xprt_switch_find_current_entry_offline(struct list_head *head, + const struct rpc_xprt *cur) +{ + 
return _xprt_switch_find_current_entry(head, cur, false); +} + +static +struct rpc_xprt *xprt_iter_current_entry_offline(struct rpc_xprt_iter *xpi) +{ + return _xprt_iter_current_entry(xpi, + xprt_switch_find_first_entry_offline, + xprt_switch_find_current_entry_offline); +} + +bool rpc_xprt_switch_has_addr(struct rpc_xprt_switch *xps, + const struct sockaddr *sap) +{ + struct list_head *head; + struct rpc_xprt *pos; + + if (xps == NULL || sap == NULL) + return false; + + head = &xps->xps_xprt_list; + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (rpc_cmp_addr_port(sap, (struct sockaddr *)&pos->addr)) { + pr_info("RPC: addr %s already in xprt switch\n", + pos->address_strings[RPC_DISPLAY_ADDR]); + return true; + } + } + return false; +} + +static +struct rpc_xprt *xprt_switch_find_next_entry(struct list_head *head, + const struct rpc_xprt *cur, bool check_active) +{ + struct rpc_xprt *pos, *prev = NULL; + bool found = false; + + list_for_each_entry_rcu(pos, head, xprt_switch) { + if (cur == prev) + found = true; + /* for request to return active transports return only + * active, for request to return offline transports + * return only offline + */ + if (found && ((check_active && xprt_is_active(pos)) || + (!check_active && !xprt_is_active(pos)))) + return pos; + prev = pos; + } + return NULL; +} + +static +struct rpc_xprt *xprt_switch_set_next_cursor(struct rpc_xprt_switch *xps, + struct rpc_xprt **cursor, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt *pos, *old; + + old = smp_load_acquire(cursor); + pos = find_next(xps, old); + smp_store_release(cursor, pos); + return pos; +} + +static +struct rpc_xprt *xprt_iter_next_entry_multiple(struct rpc_xprt_iter *xpi, + xprt_switch_find_xprt_t find_next) +{ + struct rpc_xprt_switch *xps = rcu_dereference(xpi->xpi_xpswitch); + + if (xps == NULL) + return NULL; + return xprt_switch_set_next_cursor(xps, &xpi->xpi_cursor, find_next); +} + +static +struct rpc_xprt *__xprt_switch_find_next_entry_roundrobin(struct list_head *head, + const struct rpc_xprt *cur) +{ + struct rpc_xprt *ret; + + ret = xprt_switch_find_next_entry(head, cur, true); + if (ret != NULL) + return ret; + return xprt_switch_find_first_entry(head); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_roundrobin(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + struct list_head *head = &xps->xps_xprt_list; + struct rpc_xprt *xprt; + unsigned int nactive; + + for (;;) { + unsigned long xprt_queuelen, xps_queuelen; + + xprt = __xprt_switch_find_next_entry_roundrobin(head, cur); + if (!xprt) + break; + xprt_queuelen = atomic_long_read(&xprt->queuelen); + xps_queuelen = atomic_long_read(&xps->xps_queuelen); + nactive = READ_ONCE(xps->xps_nactive); + /* Exit loop if xprt_queuelen <= average queue length */ + if (xprt_queuelen * nactive <= xps_queuelen) + break; + cur = xprt; + } + return xprt; +} + +static +struct rpc_xprt *xprt_iter_next_entry_roundrobin(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_roundrobin); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_all(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, true); +} + +static +struct rpc_xprt *xprt_switch_find_next_entry_offline(struct rpc_xprt_switch *xps, + const struct rpc_xprt *cur) +{ + return xprt_switch_find_next_entry(&xps->xps_xprt_list, cur, false); +} + +static +struct rpc_xprt *xprt_iter_next_entry_all(struct rpc_xprt_iter *xpi) +{ + 
return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_all); +} + +static +struct rpc_xprt *xprt_iter_next_entry_offline(struct rpc_xprt_iter *xpi) +{ + return xprt_iter_next_entry_multiple(xpi, + xprt_switch_find_next_entry_offline); +} + +/* + * xprt_iter_rewind - Resets the xprt iterator + * @xpi: pointer to rpc_xprt_iter + * + * Resets xpi to ensure that it points to the first entry in the list + * of transports. + */ +void xprt_iter_rewind(struct rpc_xprt_iter *xpi) +{ + rcu_read_lock(); + xprt_iter_ops(xpi)->xpi_rewind(xpi); + rcu_read_unlock(); +} + +static void __xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps, + const struct rpc_xprt_iter_ops *ops) +{ + rcu_assign_pointer(xpi->xpi_xpswitch, xprt_switch_get(xps)); + xpi->xpi_cursor = NULL; + xpi->xpi_ops = ops; +} + +/** + * xprt_iter_init - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to use the default iterator ops + * as set in xps. This function is mainly intended for internal + * use in the rpc_client. + */ +void xprt_iter_init(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, NULL); +} + +/** + * xprt_iter_init_listall - Initialise an xprt iterator + * @xpi: pointer to rpc_xprt_iter + * @xps: pointer to rpc_xprt_switch + * + * Initialises the iterator to iterate once through the entire list + * of entries in xps. + */ +void xprt_iter_init_listall(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listall); +} + +void xprt_iter_init_listoffline(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *xps) +{ + __xprt_iter_init(xpi, xps, &rpc_xprt_iter_listoffline); +} + +/** + * xprt_iter_xchg_switch - Atomically swap out the rpc_xprt_switch + * @xpi: pointer to rpc_xprt_iter + * @newswitch: pointer to a new rpc_xprt_switch or NULL + * + * Swaps out the existing xpi->xpi_xpswitch with a new value. + */ +struct rpc_xprt_switch *xprt_iter_xchg_switch(struct rpc_xprt_iter *xpi, + struct rpc_xprt_switch *newswitch) +{ + struct rpc_xprt_switch __rcu *oldswitch; + + /* Atomically swap out the old xpswitch */ + oldswitch = xchg(&xpi->xpi_xpswitch, RCU_INITIALIZER(newswitch)); + if (newswitch != NULL) + xprt_iter_rewind(xpi); + return rcu_dereference_protected(oldswitch, true); +} + +/** + * xprt_iter_destroy - Destroys the xprt iterator + * @xpi: pointer to rpc_xprt_iter + */ +void xprt_iter_destroy(struct rpc_xprt_iter *xpi) +{ + xprt_switch_put(xprt_iter_xchg_switch(xpi, NULL)); +} + +/** + * xprt_iter_xprt - Returns the rpc_xprt pointed to by the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a pointer to the struct rpc_xprt that is currently + * pointed to by the cursor. + * Caller must be holding rcu_read_lock(). + */ +struct rpc_xprt *xprt_iter_xprt(struct rpc_xprt_iter *xpi) +{ + WARN_ON_ONCE(!rcu_read_lock_held()); + return xprt_iter_ops(xpi)->xpi_xprt(xpi); +} + +static +struct rpc_xprt *xprt_iter_get_helper(struct rpc_xprt_iter *xpi, + struct rpc_xprt *(*fn)(struct rpc_xprt_iter *)) +{ + struct rpc_xprt *ret; + + do { + ret = fn(xpi); + if (ret == NULL) + break; + ret = xprt_get(ret); + } while (ret == NULL); + return ret; +} + +/** + * xprt_iter_get_xprt - Returns the rpc_xprt pointed to by the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a reference to the struct rpc_xprt that is currently + * pointed to by the cursor. 
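+ * Unlike xprt_iter_xprt(), this takes rcu_read_lock() internally and
+ * returns a counted reference, which the caller must release with
+ * xprt_put() when done.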
+ */ +struct rpc_xprt *xprt_iter_get_xprt(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_xprt); + rcu_read_unlock(); + return xprt; +} + +/** + * xprt_iter_get_next - Returns the next rpc_xprt following the cursor + * @xpi: pointer to rpc_xprt_iter + * + * Returns a reference to the struct rpc_xprt that immediately follows the + * entry pointed to by the cursor. + */ +struct rpc_xprt *xprt_iter_get_next(struct rpc_xprt_iter *xpi) +{ + struct rpc_xprt *xprt; + + rcu_read_lock(); + xprt = xprt_iter_get_helper(xpi, xprt_iter_ops(xpi)->xpi_next); + rcu_read_unlock(); + return xprt; +} + +/* Policy for always returning the first entry in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_singular = { + .xpi_rewind = xprt_iter_no_rewind, + .xpi_xprt = xprt_iter_first_entry, + .xpi_next = xprt_iter_first_entry, +}; + +/* Policy for round-robin iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_roundrobin = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_roundrobin, +}; + +/* Policy for once-through iteration of entries in the rpc_xprt_switch */ +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_listall = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry, + .xpi_next = xprt_iter_next_entry_all, +}; + +static +const struct rpc_xprt_iter_ops rpc_xprt_iter_listoffline = { + .xpi_rewind = xprt_iter_default_rewind, + .xpi_xprt = xprt_iter_current_entry_offline, + .xpi_next = xprt_iter_next_entry_offline, +}; diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile new file mode 100644 index 000000000..55b21bae8 --- /dev/null +++ b/net/sunrpc/xprtrdma/Makefile @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += rpcrdma.o + +rpcrdma-y := transport.o rpc_rdma.o verbs.o frwr_ops.o \ + svc_rdma.o svc_rdma_backchannel.o svc_rdma_transport.o \ + svc_rdma_sendto.o svc_rdma_recvfrom.o svc_rdma_rw.o \ + svc_rdma_pcl.o module.o +rpcrdma-$(CONFIG_SUNRPC_BACKCHANNEL) += backchannel.o diff --git a/net/sunrpc/xprtrdma/backchannel.c b/net/sunrpc/xprtrdma/backchannel.c new file mode 100644 index 000000000..e4d84a13c --- /dev/null +++ b/net/sunrpc/xprtrdma/backchannel.c @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015-2020, Oracle and/or its affiliates. + * + * Support for reverse-direction RPCs on RPC/RDMA. 
+ */ + +#include +#include +#include +#include + +#include "xprt_rdma.h" +#include + +#undef RPCRDMA_BACKCHANNEL_DEBUG + +/** + * xprt_rdma_bc_setup - Pre-allocate resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of concurrent incoming requests to expect + * + * Returns 0 on success; otherwise a negative errno + */ +int xprt_rdma_bc_setup(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + r_xprt->rx_buf.rb_bc_srv_max_requests = RPCRDMA_BACKWARD_WRS >> 1; + trace_xprtrdma_cb_setup(r_xprt, reqs); + return 0; +} + +/** + * xprt_rdma_bc_maxpayload - Return maximum backchannel message size + * @xprt: transport + * + * Returns maximum size, in bytes, of a backchannel message + */ +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_ep *ep = r_xprt->rx_ep; + size_t maxmsg; + + maxmsg = min_t(unsigned int, ep->re_inline_send, ep->re_inline_recv); + maxmsg = min_t(unsigned int, maxmsg, PAGE_SIZE); + return maxmsg - RPCRDMA_HDRLEN_MIN; +} + +unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *xprt) +{ + return RPCRDMA_BACKWARD_WRS >> 1; +} + +static int rpcrdma_bc_marshal_reply(struct rpc_rqst *rqst) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + __be32 *p; + + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(&req->rl_stream, &req->rl_hdrbuf, + rdmab_data(req->rl_rdmabuf), rqst); + + p = xdr_reserve_space(&req->rl_stream, 28); + if (unlikely(!p)) + return -EIO; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_srv_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + if (rpcrdma_prepare_send_sges(r_xprt, req, RPCRDMA_HDRLEN_MIN, + &rqst->rq_snd_buf, rpcrdma_noch_pullup)) + return -EIO; + + trace_xprtrdma_cb_reply(r_xprt, rqst); + return 0; +} + +/** + * xprt_rdma_bc_send_reply - marshal and send a backchannel reply + * @rqst: RPC rqst with a backchannel RPC reply in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EIO if a permanent error occurred and the request was not + * sent. Do not try to send this message again. 
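[Editorial aside, not part of the patch: the 28 bytes reserved by rpcrdma_bc_marshal_reply() above are the seven XDR words of a minimal RPC-over-RDMA header (RPCRDMA_HDRLEN_MIN). The struct below is purely illustrative; the driver encodes these words through the xdr_stream API rather than a C struct.]

/* Illustrative layout only -- not used by the code above. */
struct example_rpcrdma_bc_hdr {
        __be32 xid;             /* rqst->rq_xid */
        __be32 vers;            /* rpcrdma_version */
        __be32 credits;         /* rb_bc_srv_max_requests */
        __be32 proc;            /* rdma_msg */
        __be32 read_list;       /* xdr_zero: no Read list */
        __be32 write_list;      /* xdr_zero: no Write list */
        __be32 reply_chunk;     /* xdr_zero: no Reply chunk */
};                              /* 7 * sizeof(__be32) == 28 bytes */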
+ */ +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + int rc; + + if (!xprt_connected(xprt)) + return -ENOTCONN; + + if (!xprt_request_get_cong(xprt, rqst)) + return -EBADSLT; + + rc = rpcrdma_bc_marshal_reply(rqst); + if (rc < 0) + goto failed_marshal; + + if (frwr_send(r_xprt, req)) + goto drop_connection; + return 0; + +failed_marshal: + if (rc != -ENOTCONN) + return rc; +drop_connection: + xprt_rdma_close(xprt); + return -ENOTCONN; +} + +/** + * xprt_rdma_bc_destroy - Release resources for handling backchannel requests + * @xprt: transport associated with these backchannel resources + * @reqs: number of incoming requests to destroy; ignored + */ +void xprt_rdma_bc_destroy(struct rpc_xprt *xprt, unsigned int reqs) +{ + struct rpc_rqst *rqst, *tmp; + + spin_lock(&xprt->bc_pa_lock); + list_for_each_entry_safe(rqst, tmp, &xprt->bc_pa_list, rq_bc_pa_list) { + list_del(&rqst->rq_bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + + rpcrdma_req_destroy(rpcr_to_rdmar(rqst)); + + spin_lock(&xprt->bc_pa_lock); + } + spin_unlock(&xprt->bc_pa_lock); +} + +/** + * xprt_rdma_bc_free_rqst - Release a backchannel rqst + * @rqst: request to release + */ +void xprt_rdma_bc_free_rqst(struct rpc_rqst *rqst) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_rep *rep = req->rl_reply; + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + rpcrdma_rep_put(&r_xprt->rx_buf, rep); + req->rl_reply = NULL; + + spin_lock(&xprt->bc_pa_lock); + list_add_tail(&rqst->rq_bc_pa_list, &xprt->bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + xprt_put(xprt); +} + +static struct rpc_rqst *rpcrdma_bc_rqst_get(struct rpcrdma_xprt *r_xprt) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + size_t size; + + spin_lock(&xprt->bc_pa_lock); + rqst = list_first_entry_or_null(&xprt->bc_pa_list, struct rpc_rqst, + rq_bc_pa_list); + if (!rqst) + goto create_req; + list_del(&rqst->rq_bc_pa_list); + spin_unlock(&xprt->bc_pa_lock); + return rqst; + +create_req: + spin_unlock(&xprt->bc_pa_lock); + + /* Set a limit to prevent a remote from overrunning our resources. + */ + if (xprt->bc_alloc_count >= RPCRDMA_BACKWARD_WRS) + return NULL; + + size = min_t(size_t, r_xprt->rx_ep->re_inline_recv, PAGE_SIZE); + req = rpcrdma_req_create(r_xprt, size); + if (!req) + return NULL; + if (rpcrdma_req_setup(r_xprt, req)) { + rpcrdma_req_destroy(req); + return NULL; + } + + xprt->bc_alloc_count++; + rqst = &req->rl_slot; + rqst->rq_xprt = xprt; + __set_bit(RPC_BC_PA_IN_USE, &rqst->rq_bc_pa_state); + xdr_buf_init(&rqst->rq_snd_buf, rdmab_data(req->rl_sendbuf), size); + return rqst; +} + +/** + * rpcrdma_bc_receive_call - Handle a reverse-direction Call + * @r_xprt: transport receiving the call + * @rep: receive buffer containing the call + * + * Operational assumptions: + * o Backchannel credits are ignored, just as the NFS server + * forechannel currently does + * o The ULP manages a replay cache (eg, NFSv4.1 sessions). 
+ * No replay detection is done at the transport level + */ +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_rep *rep) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct svc_serv *bc_serv; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + struct xdr_buf *buf; + size_t size; + __be32 *p; + + p = xdr_inline_decode(&rep->rr_stream, 0); + size = xdr_stream_remaining(&rep->rr_stream); + +#ifdef RPCRDMA_BACKCHANNEL_DEBUG + pr_info("RPC: %s: callback XID %08x, length=%u\n", + __func__, be32_to_cpup(p), size); + pr_info("RPC: %s: %*ph\n", __func__, size, p); +#endif + + rqst = rpcrdma_bc_rqst_get(r_xprt); + if (!rqst) + goto out_overflow; + + rqst->rq_reply_bytes_recvd = 0; + rqst->rq_xid = *p; + + rqst->rq_private_buf.len = size; + + buf = &rqst->rq_rcv_buf; + memset(buf, 0, sizeof(*buf)); + buf->head[0].iov_base = p; + buf->head[0].iov_len = size; + buf->len = size; + + /* The receive buffer has to be hooked to the rpcrdma_req + * so that it is not released while the req is pointing + * to its buffer, and so that it can be reposted after + * the Upper Layer is done decoding it. + */ + req = rpcr_to_rdmar(rqst); + req->rl_reply = rep; + trace_xprtrdma_cb_call(r_xprt, rqst); + + /* Queue rqst for ULP's callback service */ + bc_serv = xprt->bc_serv; + xprt_get(xprt); + spin_lock(&bc_serv->sv_cb_lock); + list_add(&rqst->rq_bc_list, &bc_serv->sv_cb_list); + spin_unlock(&bc_serv->sv_cb_lock); + + wake_up(&bc_serv->sv_cb_waitq); + + r_xprt->rx_stats.bcall_count++; + return; + +out_overflow: + pr_warn("RPC/RDMA backchannel overflow\n"); + xprt_force_disconnect(xprt); + /* This receive buffer gets reposted automatically + * when the connection is re-established. + */ + return; +} diff --git a/net/sunrpc/xprtrdma/frwr_ops.c b/net/sunrpc/xprtrdma/frwr_ops.c new file mode 100644 index 000000000..ffbf99894 --- /dev/null +++ b/net/sunrpc/xprtrdma/frwr_ops.c @@ -0,0 +1,696 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + */ + +/* Lightweight memory registration using Fast Registration Work + * Requests (FRWR). + * + * FRWR features ordered asynchronous registration and invalidation + * of arbitrarily-sized memory regions. This is the fastest and safest + * but most complex memory registration mode. + */ + +/* Normal operation + * + * A Memory Region is prepared for RDMA Read or Write using a FAST_REG + * Work Request (frwr_map). When the RDMA operation is finished, this + * Memory Region is invalidated using a LOCAL_INV Work Request + * (frwr_unmap_async and frwr_unmap_sync). + * + * Typically FAST_REG Work Requests are not signaled, and neither are + * RDMA Send Work Requests (with the exception of signaling occasionally + * to prevent provider work queue overflows). This greatly reduces HCA + * interrupt workload. + */ + +/* Transport recovery + * + * frwr_map and frwr_unmap_* cannot run at the same time the transport + * connect worker is running. The connect worker holds the transport + * send lock, just as ->send_request does. This prevents frwr_map and + * the connect worker from running concurrently. When a connection is + * closed, the Receive completion queue is drained before the allowing + * the connect worker to get control. This prevents frwr_unmap and the + * connect worker from running concurrently. + * + * When the underlying transport disconnects, MRs that are in flight + * are flushed and are likely unusable. 
Thus all MRs are destroyed. + * New MRs are created on demand. + */ + +#include + +#include "xprt_rdma.h" +#include + +static void frwr_cid_init(struct rpcrdma_ep *ep, + struct rpcrdma_mr *mr) +{ + struct rpc_rdma_cid *cid = &mr->mr_cid; + + cid->ci_queue_id = ep->re_attr.send_cq->res.id; + cid->ci_completion_id = mr->mr_ibmr->res.id; +} + +static void frwr_mr_unmap(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr) +{ + if (mr->mr_device) { + trace_xprtrdma_mr_unmap(mr); + ib_dma_unmap_sg(mr->mr_device, mr->mr_sg, mr->mr_nents, + mr->mr_dir); + mr->mr_device = NULL; + } +} + +/** + * frwr_mr_release - Destroy one MR + * @mr: MR allocated by frwr_mr_init + * + */ +void frwr_mr_release(struct rpcrdma_mr *mr) +{ + int rc; + + frwr_mr_unmap(mr->mr_xprt, mr); + + rc = ib_dereg_mr(mr->mr_ibmr); + if (rc) + trace_xprtrdma_frwr_dereg(mr, rc); + kfree(mr->mr_sg); + kfree(mr); +} + +static void frwr_mr_put(struct rpcrdma_mr *mr) +{ + frwr_mr_unmap(mr->mr_xprt, mr); + + /* The MR is returned to the req's MR free list instead + * of to the xprt's MR free list. No spinlock is needed. + */ + rpcrdma_mr_push(mr, &mr->mr_req->rl_free_mrs); +} + +/* frwr_reset - Place MRs back on the free list + * @req: request to reset + * + * Used after a failed marshal. For FRWR, this means the MRs + * don't have to be fully released and recreated. + * + * NB: This is safe only as long as none of @req's MRs are + * involved with an ongoing asynchronous FAST_REG or LOCAL_INV + * Work Request. + */ +void frwr_reset(struct rpcrdma_req *req) +{ + struct rpcrdma_mr *mr; + + while ((mr = rpcrdma_mr_pop(&req->rl_registered))) + frwr_mr_put(mr); +} + +/** + * frwr_mr_init - Initialize one MR + * @r_xprt: controlling transport instance + * @mr: generic MR to prepare for FRWR + * + * Returns zero if successful. Otherwise a negative errno + * is returned. + */ +int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned int depth = ep->re_max_fr_depth; + struct scatterlist *sg; + struct ib_mr *frmr; + + sg = kcalloc_node(depth, sizeof(*sg), XPRTRDMA_GFP_FLAGS, + ibdev_to_node(ep->re_id->device)); + if (!sg) + return -ENOMEM; + + frmr = ib_alloc_mr(ep->re_pd, ep->re_mrtype, depth); + if (IS_ERR(frmr)) + goto out_mr_err; + + mr->mr_xprt = r_xprt; + mr->mr_ibmr = frmr; + mr->mr_device = NULL; + INIT_LIST_HEAD(&mr->mr_list); + init_completion(&mr->mr_linv_done); + frwr_cid_init(ep, mr); + + sg_init_table(sg, depth); + mr->mr_sg = sg; + return 0; + +out_mr_err: + kfree(sg); + trace_xprtrdma_frwr_alloc(mr, PTR_ERR(frmr)); + return PTR_ERR(frmr); +} + +/** + * frwr_query_device - Prepare a transport for use with FRWR + * @ep: endpoint to fill in + * @device: RDMA device to query + * + * On success, sets: + * ep->re_attr + * ep->re_max_requests + * ep->re_max_rdma_segs + * ep->re_max_fr_depth + * ep->re_mrtype + * + * Return values: + * On success, returns zero. 
+ * %-EINVAL - the device does not support FRWR memory registration + * %-ENOMEM - the device is not sufficiently capable for NFS/RDMA + */ +int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device) +{ + const struct ib_device_attr *attrs = &device->attrs; + int max_qp_wr, depth, delta; + unsigned int max_sge; + + if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) || + attrs->max_fast_reg_page_list_len == 0) { + pr_err("rpcrdma: 'frwr' mode is not supported by device %s\n", + device->name); + return -EINVAL; + } + + max_sge = min_t(unsigned int, attrs->max_send_sge, + RPCRDMA_MAX_SEND_SGES); + if (max_sge < RPCRDMA_MIN_SEND_SGES) { + pr_err("rpcrdma: HCA provides only %u send SGEs\n", max_sge); + return -ENOMEM; + } + ep->re_attr.cap.max_send_sge = max_sge; + ep->re_attr.cap.max_recv_sge = 1; + + ep->re_mrtype = IB_MR_TYPE_MEM_REG; + if (attrs->kernel_cap_flags & IBK_SG_GAPS_REG) + ep->re_mrtype = IB_MR_TYPE_SG_GAPS; + + /* Quirk: Some devices advertise a large max_fast_reg_page_list_len + * capability, but perform optimally when the MRs are not larger + * than a page. + */ + if (attrs->max_sge_rd > RPCRDMA_MAX_HDR_SEGS) + ep->re_max_fr_depth = attrs->max_sge_rd; + else + ep->re_max_fr_depth = attrs->max_fast_reg_page_list_len; + if (ep->re_max_fr_depth > RPCRDMA_MAX_DATA_SEGS) + ep->re_max_fr_depth = RPCRDMA_MAX_DATA_SEGS; + + /* Add room for frwr register and invalidate WRs. + * 1. FRWR reg WR for head + * 2. FRWR invalidate WR for head + * 3. N FRWR reg WRs for pagelist + * 4. N FRWR invalidate WRs for pagelist + * 5. FRWR reg WR for tail + * 6. FRWR invalidate WR for tail + * 7. The RDMA_SEND WR + */ + depth = 7; + + /* Calculate N if the device max FRWR depth is smaller than + * RPCRDMA_MAX_DATA_SEGS. + */ + if (ep->re_max_fr_depth < RPCRDMA_MAX_DATA_SEGS) { + delta = RPCRDMA_MAX_DATA_SEGS - ep->re_max_fr_depth; + do { + depth += 2; /* FRWR reg + invalidate */ + delta -= ep->re_max_fr_depth; + } while (delta > 0); + } + + max_qp_wr = attrs->max_qp_wr; + max_qp_wr -= RPCRDMA_BACKWARD_WRS; + max_qp_wr -= 1; + if (max_qp_wr < RPCRDMA_MIN_SLOT_TABLE) + return -ENOMEM; + if (ep->re_max_requests > max_qp_wr) + ep->re_max_requests = max_qp_wr; + ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth; + if (ep->re_attr.cap.max_send_wr > max_qp_wr) { + ep->re_max_requests = max_qp_wr / depth; + if (!ep->re_max_requests) + return -ENOMEM; + ep->re_attr.cap.max_send_wr = ep->re_max_requests * depth; + } + ep->re_attr.cap.max_send_wr += RPCRDMA_BACKWARD_WRS; + ep->re_attr.cap.max_send_wr += 1; /* for ib_drain_sq */ + ep->re_attr.cap.max_recv_wr = ep->re_max_requests; + ep->re_attr.cap.max_recv_wr += RPCRDMA_BACKWARD_WRS; + ep->re_attr.cap.max_recv_wr += RPCRDMA_MAX_RECV_BATCH; + ep->re_attr.cap.max_recv_wr += 1; /* for ib_drain_rq */ + + ep->re_max_rdma_segs = + DIV_ROUND_UP(RPCRDMA_MAX_DATA_SEGS, ep->re_max_fr_depth); + /* Reply chunks require segments for head and tail buffers */ + ep->re_max_rdma_segs += 2; + if (ep->re_max_rdma_segs > RPCRDMA_MAX_HDR_SEGS) + ep->re_max_rdma_segs = RPCRDMA_MAX_HDR_SEGS; + + /* Ensure the underlying device is capable of conveying the + * largest r/wsize NFS will ask for. This guarantees that + * failing over from one RDMA device to another will not + * break NFS I/O. 
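[Editorial aside, not part of the patch: a worked example of the Send Queue depth arithmetic in frwr_query_device() above, using hypothetical values. The helper mirrors the kernel loop but is illustrative only.]

/* With a device FRWR depth of 16 and a 64-segment maximum payload,
 * the base 7 WRs grow by one reg/invalidate pair per extra chunk:
 * example_frwr_send_depth(16, 64) == 13.
 */
static unsigned int example_frwr_send_depth(unsigned int max_fr_depth,
                                            unsigned int max_data_segs)
{
        unsigned int depth = 7;
        int delta;

        if (max_fr_depth < max_data_segs) {
                delta = max_data_segs - max_fr_depth;
                do {
                        depth += 2;     /* one more FRWR reg + invalidate */
                        delta -= max_fr_depth;
                } while (delta > 0);
        }
        return depth;
}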
+ */ + if ((ep->re_max_rdma_segs * ep->re_max_fr_depth) < RPCRDMA_MAX_SEGS) + return -ENOMEM; + + return 0; +} + +/** + * frwr_map - Register a memory region + * @r_xprt: controlling transport + * @seg: memory region co-ordinates + * @nsegs: number of segments remaining + * @writing: true when RDMA Write will be used + * @xid: XID of RPC using the registered memory + * @mr: MR to fill in + * + * Prepare a REG_MR Work Request to register a memory region + * for remote access via RDMA READ or RDMA WRITE. + * + * Returns the next segment or a negative errno pointer. + * On success, @mr is filled in. + */ +struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, __be32 xid, + struct rpcrdma_mr *mr) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_reg_wr *reg_wr; + int i, n, dma_nents; + struct ib_mr *ibmr; + u8 key; + + if (nsegs > ep->re_max_fr_depth) + nsegs = ep->re_max_fr_depth; + for (i = 0; i < nsegs;) { + sg_set_page(&mr->mr_sg[i], seg->mr_page, + seg->mr_len, seg->mr_offset); + + ++seg; + ++i; + if (ep->re_mrtype == IB_MR_TYPE_SG_GAPS) + continue; + if ((i < nsegs && seg->mr_offset) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + mr->mr_dir = rpcrdma_data_dir(writing); + mr->mr_nents = i; + + dma_nents = ib_dma_map_sg(ep->re_id->device, mr->mr_sg, mr->mr_nents, + mr->mr_dir); + if (!dma_nents) + goto out_dmamap_err; + mr->mr_device = ep->re_id->device; + + ibmr = mr->mr_ibmr; + n = ib_map_mr_sg(ibmr, mr->mr_sg, dma_nents, NULL, PAGE_SIZE); + if (n != dma_nents) + goto out_mapmr_err; + + ibmr->iova &= 0x00000000ffffffff; + ibmr->iova |= ((u64)be32_to_cpu(xid)) << 32; + key = (u8)(ibmr->rkey & 0x000000FF); + ib_update_fast_reg_key(ibmr, ++key); + + reg_wr = &mr->mr_regwr; + reg_wr->mr = ibmr; + reg_wr->key = ibmr->rkey; + reg_wr->access = writing ? + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE : + IB_ACCESS_REMOTE_READ; + + mr->mr_handle = ibmr->rkey; + mr->mr_length = ibmr->length; + mr->mr_offset = ibmr->iova; + trace_xprtrdma_mr_map(mr); + + return seg; + +out_dmamap_err: + trace_xprtrdma_frwr_sgerr(mr, i); + return ERR_PTR(-EIO); + +out_mapmr_err: + trace_xprtrdma_frwr_maperr(mr, n); + return ERR_PTR(-EIO); +} + +/** + * frwr_wc_fastreg - Invoked by RDMA provider for a flushed FastReg WC + * @cq: completion queue + * @wc: WCE for a completed FastReg WR + * + * Each flushed MR gets destroyed after the QP has drained. + */ +static void frwr_wc_fastreg(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_fastreg(wc, &mr->mr_cid); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_send - post Send WRs containing the RPC Call message + * @r_xprt: controlling transport instance + * @req: prepared RPC Call + * + * For FRWR, chain any FastReg WRs to the Send WR. Only a + * single ib_post_send call is needed to register memory + * and then post the Send WR. + * + * Returns the return code from ib_post_send. + * + * Caller must hold the transport send lock to ensure that the + * pointers to the transport's rdma_cm_id and QP are stable. 
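[Editorial aside, not part of the patch: two details of frwr_map() above are worth calling out. The RPC XID is folded into the upper 32 bits of the MR's iova, apparently so registrations can be correlated with the owning RPC in traces, and the low-order byte of the rkey is bumped on each registration so a stale rkey from an earlier use of the same MR is rejected by the HCA. Extracted here as an illustrative helper.]

/* Illustrative only: mirrors the tagging done inside frwr_map(). */
static void example_tag_registration(struct ib_mr *ibmr, __be32 xid)
{
        u8 key;

        ibmr->iova &= 0x00000000ffffffff;
        ibmr->iova |= ((u64)be32_to_cpu(xid)) << 32;

        key = (u8)(ibmr->rkey & 0x000000FF);
        ib_update_fast_reg_key(ibmr, ++key);
}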
+ */ +int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *post_wr, *send_wr = &req->rl_wr; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr *mr; + unsigned int num_wrs; + int ret; + + num_wrs = 1; + post_wr = send_wr; + list_for_each_entry(mr, &req->rl_registered, mr_list) { + trace_xprtrdma_mr_fastreg(mr); + + mr->mr_cqe.done = frwr_wc_fastreg; + mr->mr_regwr.wr.next = post_wr; + mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe; + mr->mr_regwr.wr.num_sge = 0; + mr->mr_regwr.wr.opcode = IB_WR_REG_MR; + mr->mr_regwr.wr.send_flags = 0; + post_wr = &mr->mr_regwr.wr; + ++num_wrs; + } + + if ((kref_read(&req->rl_kref) > 1) || num_wrs > ep->re_send_count) { + send_wr->send_flags |= IB_SEND_SIGNALED; + ep->re_send_count = min_t(unsigned int, ep->re_send_batch, + num_wrs - ep->re_send_count); + } else { + send_wr->send_flags &= ~IB_SEND_SIGNALED; + ep->re_send_count -= num_wrs; + } + + trace_xprtrdma_post_send(req); + ret = ib_post_send(ep->re_id->qp, post_wr, NULL); + if (ret) + trace_xprtrdma_post_send_err(r_xprt, req, ret); + return ret; +} + +/** + * frwr_reminv - handle a remotely invalidated mr on the @mrs list + * @rep: Received reply + * @mrs: list of MRs to check + * + */ +void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs) +{ + struct rpcrdma_mr *mr; + + list_for_each_entry(mr, mrs, mr_list) + if (mr->mr_handle == rep->rr_inv_rkey) { + list_del_init(&mr->mr_list); + trace_xprtrdma_mr_reminv(mr); + frwr_mr_put(mr); + break; /* only one invalidated MR per RPC */ + } +} + +static void frwr_mr_done(struct ib_wc *wc, struct rpcrdma_mr *mr) +{ + if (likely(wc->status == IB_WC_SUCCESS)) + frwr_mr_put(mr); +} + +/** + * frwr_wc_localinv - Invoked by RDMA provider for a LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + */ +static void frwr_wc_localinv(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li(wc, &mr->mr_cid); + frwr_mr_done(wc, mr); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_wc_localinv_wake - Invoked by RDMA provider for a LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + * Awaken anyone waiting for an MR to finish being fenced. + */ +static void frwr_wc_localinv_wake(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li_wake(wc, &mr->mr_cid); + frwr_mr_done(wc, mr); + complete(&mr->mr_linv_done); + + rpcrdma_flush_disconnect(cq->cq_context, wc); +} + +/** + * frwr_unmap_sync - invalidate memory regions that were registered for @req + * @r_xprt: controlling transport instance + * @req: rpcrdma_req with a non-empty list of MRs to process + * + * Sleeps until it is safe for the host CPU to access the previously mapped + * memory regions. This guarantees that registered MRs are properly fenced + * from the server before the RPC consumer accesses the data in them. It + * also ensures proper Send flow control: waking the next RPC waits until + * this RPC has relinquished all its Send Queue entries. 
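[Editorial aside, not part of the patch: a simplified model of the completion-signaling budget used by frwr_send() above. Most Sends are posted unsignaled; once the budget is spent, one signaled Send retires the accumulated WRs and the budget is refilled. The hypothetical helper below ignores the additional kref condition that also forces signaling.]

/* Illustrative only: returns true when this post should request a
 * completion. @budget models ep->re_send_count, @batch models
 * ep->re_send_batch.
 */
static bool example_should_signal(unsigned int *budget, unsigned int batch,
                                  unsigned int num_wrs)
{
        if (num_wrs > *budget) {
                *budget = min(batch, num_wrs - *budget);
                return true;            /* signal this Send */
        }
        *budget -= num_wrs;
        return false;                   /* leave it unsignaled */
}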
+ */ +void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *first, **prev, *last; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + const struct ib_send_wr *bad_wr; + struct rpcrdma_mr *mr; + int rc; + + /* ORDER: Invalidate all of the MRs first + * + * Chain the LOCAL_INV Work Requests and post them with + * a single ib_post_send() call. + */ + prev = &first; + mr = rpcrdma_mr_pop(&req->rl_registered); + do { + trace_xprtrdma_mr_localinv(mr); + r_xprt->rx_stats.local_inv_needed++; + + last = &mr->mr_invwr; + last->next = NULL; + last->wr_cqe = &mr->mr_cqe; + last->sg_list = NULL; + last->num_sge = 0; + last->opcode = IB_WR_LOCAL_INV; + last->send_flags = IB_SEND_SIGNALED; + last->ex.invalidate_rkey = mr->mr_handle; + + last->wr_cqe->done = frwr_wc_localinv; + + *prev = last; + prev = &last->next; + } while ((mr = rpcrdma_mr_pop(&req->rl_registered))); + + mr = container_of(last, struct rpcrdma_mr, mr_invwr); + + /* Strong send queue ordering guarantees that when the + * last WR in the chain completes, all WRs in the chain + * are complete. + */ + last->wr_cqe->done = frwr_wc_localinv_wake; + reinit_completion(&mr->mr_linv_done); + + /* Transport disconnect drains the receive CQ before it + * replaces the QP. The RPC reply handler won't call us + * unless re_id->qp is a valid pointer. + */ + bad_wr = NULL; + rc = ib_post_send(ep->re_id->qp, first, &bad_wr); + + /* The final LOCAL_INV WR in the chain is supposed to + * do the wake. If it was never posted, the wake will + * not happen, so don't wait in that case. + */ + if (bad_wr != first) + wait_for_completion(&mr->mr_linv_done); + if (!rc) + return; + + /* On error, the MRs get destroyed once the QP has drained. */ + trace_xprtrdma_post_linv_err(req, rc); + + /* Force a connection loss to ensure complete recovery. + */ + rpcrdma_force_disconnect(ep); +} + +/** + * frwr_wc_localinv_done - Invoked by RDMA provider for a signaled LOCAL_INV WC + * @cq: completion queue + * @wc: WCE for a completed LocalInv WR + * + */ +static void frwr_wc_localinv_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_mr *mr = container_of(cqe, struct rpcrdma_mr, mr_cqe); + struct rpcrdma_rep *rep; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_li_done(wc, &mr->mr_cid); + + /* Ensure that @rep is generated before the MR is released */ + rep = mr->mr_req->rl_reply; + smp_rmb(); + + if (wc->status != IB_WC_SUCCESS) { + if (rep) + rpcrdma_unpin_rqst(rep); + rpcrdma_flush_disconnect(cq->cq_context, wc); + return; + } + frwr_mr_put(mr); + rpcrdma_complete_rqst(rep); +} + +/** + * frwr_unmap_async - invalidate memory regions that were registered for @req + * @r_xprt: controlling transport instance + * @req: rpcrdma_req with a non-empty list of MRs to process + * + * This guarantees that registered MRs are properly fenced from the + * server before the RPC consumer accesses the data in them. It also + * ensures proper Send flow control: waking the next RPC waits until + * this RPC has relinquished all its Send Queue entries. + */ +void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct ib_send_wr *first, *last, **prev; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr *mr; + int rc; + + /* Chain the LOCAL_INV Work Requests and post them with + * a single ib_post_send() call. 
+ */ + prev = &first; + mr = rpcrdma_mr_pop(&req->rl_registered); + do { + trace_xprtrdma_mr_localinv(mr); + r_xprt->rx_stats.local_inv_needed++; + + last = &mr->mr_invwr; + last->next = NULL; + last->wr_cqe = &mr->mr_cqe; + last->sg_list = NULL; + last->num_sge = 0; + last->opcode = IB_WR_LOCAL_INV; + last->send_flags = IB_SEND_SIGNALED; + last->ex.invalidate_rkey = mr->mr_handle; + + last->wr_cqe->done = frwr_wc_localinv; + + *prev = last; + prev = &last->next; + } while ((mr = rpcrdma_mr_pop(&req->rl_registered))); + + /* Strong send queue ordering guarantees that when the + * last WR in the chain completes, all WRs in the chain + * are complete. The last completion will wake up the + * RPC waiter. + */ + last->wr_cqe->done = frwr_wc_localinv_done; + + /* Transport disconnect drains the receive CQ before it + * replaces the QP. The RPC reply handler won't call us + * unless re_id->qp is a valid pointer. + */ + rc = ib_post_send(ep->re_id->qp, first, NULL); + if (!rc) + return; + + /* On error, the MRs get destroyed once the QP has drained. */ + trace_xprtrdma_post_linv_err(req, rc); + + /* The final LOCAL_INV WR in the chain is supposed to + * do the wake. If it was never posted, the wake does + * not happen. Unpin the rqst in preparation for its + * retransmission. + */ + rpcrdma_unpin_rqst(req->rl_reply); + + /* Force a connection loss to ensure complete recovery. + */ + rpcrdma_force_disconnect(ep); +} + +/** + * frwr_wp_create - Create an MR for padding Write chunks + * @r_xprt: transport resources to use + * + * Return 0 on success, negative errno on failure. + */ +int frwr_wp_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr_seg seg; + struct rpcrdma_mr *mr; + + mr = rpcrdma_mr_get(r_xprt); + if (!mr) + return -EAGAIN; + mr->mr_req = NULL; + ep->re_write_pad_mr = mr; + + seg.mr_len = XDR_UNIT; + seg.mr_page = virt_to_page(ep->re_write_pad); + seg.mr_offset = offset_in_page(ep->re_write_pad); + if (IS_ERR(frwr_map(r_xprt, &seg, 1, true, xdr_zero, mr))) + return -EIO; + trace_xprtrdma_mr_fastreg(mr); + + mr->mr_cqe.done = frwr_wc_fastreg; + mr->mr_regwr.wr.next = NULL; + mr->mr_regwr.wr.wr_cqe = &mr->mr_cqe; + mr->mr_regwr.wr.num_sge = 0; + mr->mr_regwr.wr.opcode = IB_WR_REG_MR; + mr->mr_regwr.wr.send_flags = 0; + + return ib_post_send(ep->re_id->qp, &mr->mr_regwr.wr, NULL); +} diff --git a/net/sunrpc/xprtrdma/module.c b/net/sunrpc/xprtrdma/module.c new file mode 100644 index 000000000..45c5b41ac --- /dev/null +++ b/net/sunrpc/xprtrdma/module.c @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015, 2017 Oracle. All rights reserved. 
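[Editorial aside, not part of the patch: the pad MR registered by frwr_wp_create() above is consumed later in rpcrdma_encode_write_list(), which appends one extra XDR_UNIT-sized segment whenever the Write chunk payload is not a multiple of four bytes, giving the server somewhere to place the XDR pad. The trivial hypothetical check below is equivalent to the xdr_pad_size() test used there.]

/* Illustrative only: true when a Write chunk of @payload_len bytes
 * needs the extra pad segment registered by frwr_wp_create().
 */
static bool example_write_chunk_needs_pad(u32 payload_len)
{
        return (payload_len & 3) != 0;
}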
+ */ + +/* rpcrdma.ko module initialization + */ + +#include +#include +#include +#include +#include + +#include + +#include "xprt_rdma.h" + +#define CREATE_TRACE_POINTS +#include + +MODULE_AUTHOR("Open Grid Computing and Network Appliance, Inc."); +MODULE_DESCRIPTION("RPC/RDMA Transport"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS("svcrdma"); +MODULE_ALIAS("xprtrdma"); +MODULE_ALIAS("rpcrdma6"); + +static void __exit rpc_rdma_cleanup(void) +{ + xprt_rdma_cleanup(); + svc_rdma_cleanup(); +} + +static int __init rpc_rdma_init(void) +{ + int rc; + + rc = svc_rdma_init(); + if (rc) + goto out; + + rc = xprt_rdma_init(); + if (rc) + svc_rdma_cleanup(); + +out: + return rc; +} + +module_init(rpc_rdma_init); +module_exit(rpc_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c new file mode 100644 index 000000000..190a4de23 --- /dev/null +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -0,0 +1,1510 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2020, Oracle and/or its affiliates. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * rpc_rdma.c + * + * This file contains the guts of the RPC RDMA protocol, and + * does marshaling/unmarshaling, etc. It is also where interfacing + * to the Linux RPC framework lives. + */ + +#include + +#include + +#include "xprt_rdma.h" +#include + +/* Returns size of largest RPC-over-RDMA header in a Call message + * + * The largest Call header contains a full-size Read list and a + * minimal Reply chunk. 
+ */ +static unsigned int rpcrdma_max_call_header_size(unsigned int maxsegs) +{ + unsigned int size; + + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; + + /* Maximum Read list size */ + size += maxsegs * rpcrdma_readchunk_maxsz * sizeof(__be32); + + /* Minimal Read chunk size */ + size += sizeof(__be32); /* segment count */ + size += rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + return size; +} + +/* Returns size of largest RPC-over-RDMA header in a Reply message + * + * There is only one Write list or one Reply chunk per Reply + * message. The larger list is the Write list. + */ +static unsigned int rpcrdma_max_reply_header_size(unsigned int maxsegs) +{ + unsigned int size; + + /* Fixed header fields and list discriminators */ + size = RPCRDMA_HDRLEN_MIN; + + /* Maximum Write list size */ + size += sizeof(__be32); /* segment count */ + size += maxsegs * rpcrdma_segment_maxsz * sizeof(__be32); + size += sizeof(__be32); /* list discriminator */ + + return size; +} + +/** + * rpcrdma_set_max_header_sizes - Initialize inline payload sizes + * @ep: endpoint to initialize + * + * The max_inline fields contain the maximum size of an RPC message + * so the marshaling code doesn't have to repeat this calculation + * for every RPC. + */ +void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep) +{ + unsigned int maxsegs = ep->re_max_rdma_segs; + + ep->re_max_inline_send = + ep->re_inline_send - rpcrdma_max_call_header_size(maxsegs); + ep->re_max_inline_recv = + ep->re_inline_recv - rpcrdma_max_reply_header_size(maxsegs); +} + +/* The client can send a request inline as long as the RPCRDMA header + * plus the RPC call fit under the transport's inline limit. If the + * combined call message size exceeds that limit, the client must use + * a Read chunk for this operation. + * + * A Read chunk is also required if sending the RPC call inline would + * exceed this device's max_sge limit. + */ +static bool rpcrdma_args_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + struct xdr_buf *xdr = &rqst->rq_snd_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned int count, remaining, offset; + + if (xdr->len > ep->re_max_inline_send) + return false; + + if (xdr->page_len) { + remaining = xdr->page_len; + offset = offset_in_page(xdr->page_base); + count = RPCRDMA_MIN_SEND_SGES; + while (remaining) { + remaining -= min_t(unsigned int, + PAGE_SIZE - offset, remaining); + offset = 0; + if (++count > ep->re_attr.cap.max_send_sge) + return false; + } + } + + return true; +} + +/* The client can't know how large the actual reply will be. Thus it + * plans for the largest possible reply for that particular ULP + * operation. If the maximum combined reply message size exceeds that + * limit, the client must provide a write list or a reply chunk for + * this request. + */ +static bool rpcrdma_results_inline(struct rpcrdma_xprt *r_xprt, + struct rpc_rqst *rqst) +{ + return rqst->rq_rcv_buf.buflen <= r_xprt->rx_ep->re_max_inline_recv; +} + +/* The client is required to provide a Reply chunk if the maximum + * size of the non-payload part of the RPC Reply is larger than + * the inline threshold. + */ +static bool +rpcrdma_nonpayload_inline(const struct rpcrdma_xprt *r_xprt, + const struct rpc_rqst *rqst) +{ + const struct xdr_buf *buf = &rqst->rq_rcv_buf; + + return (buf->head[0].iov_len + buf->tail[0].iov_len) < + r_xprt->rx_ep->re_max_inline_recv; +} + +/* ACL likes to be lazy in allocating pages. 
For TCP, these + * pages can be allocated during receive processing. Not true + * for RDMA, which must always provision receive buffers + * up front. + */ +static noinline int +rpcrdma_alloc_sparse_pages(struct xdr_buf *buf) +{ + struct page **ppages; + int len; + + len = buf->page_len; + ppages = buf->pages + (buf->page_base >> PAGE_SHIFT); + while (len > 0) { + if (!*ppages) + *ppages = alloc_page(GFP_NOWAIT | __GFP_NOWARN); + if (!*ppages) + return -ENOBUFS; + ppages++; + len -= PAGE_SIZE; + } + + return 0; +} + +/* Convert @vec to a single SGL element. + * + * Returns pointer to next available SGE, and bumps the total number + * of SGEs consumed. + */ +static struct rpcrdma_mr_seg * +rpcrdma_convert_kvec(struct kvec *vec, struct rpcrdma_mr_seg *seg, + unsigned int *n) +{ + seg->mr_page = virt_to_page(vec->iov_base); + seg->mr_offset = offset_in_page(vec->iov_base); + seg->mr_len = vec->iov_len; + ++seg; + ++(*n); + return seg; +} + +/* Convert @xdrbuf into SGEs no larger than a page each. As they + * are registered, these SGEs are then coalesced into RDMA segments + * when the selected memreg mode supports it. + * + * Returns positive number of SGEs consumed, or a negative errno. + */ + +static int +rpcrdma_convert_iovs(struct rpcrdma_xprt *r_xprt, struct xdr_buf *xdrbuf, + unsigned int pos, enum rpcrdma_chunktype type, + struct rpcrdma_mr_seg *seg) +{ + unsigned long page_base; + unsigned int len, n; + struct page **ppages; + + n = 0; + if (pos == 0) + seg = rpcrdma_convert_kvec(&xdrbuf->head[0], seg, &n); + + len = xdrbuf->page_len; + ppages = xdrbuf->pages + (xdrbuf->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdrbuf->page_base); + while (len) { + seg->mr_page = *ppages; + seg->mr_offset = page_base; + seg->mr_len = min_t(u32, PAGE_SIZE - page_base, len); + len -= seg->mr_len; + ++ppages; + ++seg; + ++n; + page_base = 0; + } + + if (type == rpcrdma_readch || type == rpcrdma_writech) + goto out; + + if (xdrbuf->tail[0].iov_len) + rpcrdma_convert_kvec(&xdrbuf->tail[0], seg, &n); + +out: + if (unlikely(n > RPCRDMA_MAX_SEGS)) + return -EIO; + return n; +} + +static int +encode_rdma_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + xdr_encode_rdma_segment(p, mr->mr_handle, mr->mr_length, mr->mr_offset); + return 0; +} + +static int +encode_read_segment(struct xdr_stream *xdr, struct rpcrdma_mr *mr, + u32 position) +{ + __be32 *p; + + p = xdr_reserve_space(xdr, 6 * sizeof(*p)); + if (unlikely(!p)) + return -EMSGSIZE; + + *p++ = xdr_one; /* Item present */ + xdr_encode_read_segment(p, position, mr->mr_handle, mr->mr_length, + mr->mr_offset); + return 0; +} + +static struct rpcrdma_mr_seg *rpcrdma_mr_prepare(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, + struct rpcrdma_mr **mr) +{ + *mr = rpcrdma_mr_pop(&req->rl_free_mrs); + if (!*mr) { + *mr = rpcrdma_mr_get(r_xprt); + if (!*mr) + goto out_getmr_err; + (*mr)->mr_req = req; + } + + rpcrdma_mr_push(*mr, &req->rl_registered); + return frwr_map(r_xprt, seg, nsegs, writing, req->rl_slot.rq_xid, *mr); + +out_getmr_err: + trace_xprtrdma_nomrs_err(r_xprt, req); + xprt_wait_for_buffer_space(&r_xprt->rx_xprt); + rpcrdma_mrs_refresh(r_xprt); + return ERR_PTR(-EAGAIN); +} + +/* Register and XDR encode the Read list. Supports encoding a list of read + * segments that belong to a single read chunk. 
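[Editorial aside, not part of the patch: a worked example of how rpcrdma_convert_iovs() above carves an xdr_buf page list into per-page segments. The numbers and the helper are hypothetical and assume a 4096-byte PAGE_SIZE.]

/* Illustrative only: a 9000-byte page list starting at byte 100 of its
 * first page yields three segments of 3996, 4096, and 908 bytes:
 * example_count_page_segs(100, 9000) == 3.
 */
static unsigned int example_count_page_segs(unsigned int page_base,
                                            unsigned int page_len)
{
        unsigned int n = 0;

        while (page_len) {
                unsigned int len = min_t(unsigned int,
                                         PAGE_SIZE - page_base, page_len);

                page_len -= len;
                page_base = 0;
                n++;
        }
        return n;
}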
+ * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Read chunklist (a linked list): + * N elements, position P (same P for all chunks of same arg!): + * 1 - PHLOO - 1 - PHLOO - ... - 1 - PHLOO - 0 + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single @pos value is currently supported. + */ +static int rpcrdma_encode_read_list(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype rtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + unsigned int pos; + int nsegs; + + if (rtype == rpcrdma_noch_pullup || rtype == rpcrdma_noch_mapped) + goto done; + + pos = rqst->rq_snd_buf.head[0].iov_len; + if (rtype == rpcrdma_areadch) + pos = 0; + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_snd_buf, pos, + rtype, seg); + if (nsegs < 0) + return nsegs; + + do { + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, false, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_read_segment(xdr, mr, pos) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_read(rqst->rq_task, pos, mr, nsegs); + r_xprt->rx_stats.read_chunk_count++; + nsegs -= mr->mr_nents; + } while (nsegs); + +done: + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; +} + +/* Register and XDR encode the Write list. Supports encoding a list + * containing one array of plain segments that belong to a single + * write chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Write chunklist (a list of (one) counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO - 0 + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. + * + * Only a single Write chunk is currently supported. 
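[Editorial aside, not part of the patch: a concrete instance of the Read-list encoding key described above. The handle, length, and position values are made up; each segment item is the six XDR words emitted by encode_read_segment(), and the list is terminated by a single absent-item word.]

/* Illustrative only: Read list carrying one two-segment chunk whose
 * payload starts at byte position 36 of the Call message.
 */
static const __be32 example_read_list[] = {
        cpu_to_be32(1),                 /* item present */
        cpu_to_be32(36),                /* P: position */
        cpu_to_be32(0x1234),            /* H: handle (rkey) */
        cpu_to_be32(65536),             /* L: length */
        cpu_to_be32(0), cpu_to_be32(0), /* OO: 64-bit offset */
        cpu_to_be32(1),                 /* item present */
        cpu_to_be32(36),                /* P */
        cpu_to_be32(0x5678),            /* H */
        cpu_to_be32(4096),              /* L */
        cpu_to_be32(0), cpu_to_be32(0), /* OO */
        cpu_to_be32(0),                 /* end of Read list */
};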
+ */ +static int rpcrdma_encode_write_list(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype wtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + if (wtype != rpcrdma_writech) + goto done; + + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, + rqst->rq_rcv_buf.head[0].iov_len, + wtype, seg); + if (nsegs < 0) + return nsegs; + + if (xdr_stream_encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + + nchunks = 0; + do { + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_write(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } while (nsegs); + + if (xdr_pad_size(rqst->rq_rcv_buf.page_len)) { + if (encode_rdma_segment(xdr, ep->re_write_pad_mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_wp(rqst->rq_task, ep->re_write_pad_mr, + nsegs); + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } + + /* Update count of segments in this Write chunk */ + *segcount = cpu_to_be32(nchunks); + +done: + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; +} + +/* Register and XDR encode the Reply chunk. Supports encoding an array + * of plain segments that belong to a single write (reply) chunk. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Reply chunk (a counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO + * + * Returns zero on success, or a negative errno if a failure occurred. + * @xdr is advanced to the next position in the stream. 
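[Editorial aside, not part of the patch: the corresponding wire form for the Write list built above. Unlike the Read list, which is a linked list of segment items, a Write chunk is a counted array of HLOO segments; the trailing zero says no further Write chunks follow. All values are made up.]

/* Illustrative only: Write list containing one two-segment Write chunk. */
static const __be32 example_write_list[] = {
        cpu_to_be32(1),                 /* one Write chunk follows */
        cpu_to_be32(2),                 /* N: segments in this chunk */
        cpu_to_be32(0x9abc),            /* H */
        cpu_to_be32(131072),            /* L */
        cpu_to_be32(0), cpu_to_be32(0), /* OO */
        cpu_to_be32(0xdef0),            /* H */
        cpu_to_be32(8192),              /* L */
        cpu_to_be32(0), cpu_to_be32(0), /* OO */
        cpu_to_be32(0),                 /* no further Write chunks */
};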
+ */ +static int rpcrdma_encode_reply_chunk(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct rpc_rqst *rqst, + enum rpcrdma_chunktype wtype) +{ + struct xdr_stream *xdr = &req->rl_stream; + struct rpcrdma_mr_seg *seg; + struct rpcrdma_mr *mr; + int nsegs, nchunks; + __be32 *segcount; + + if (wtype != rpcrdma_replych) { + if (xdr_stream_encode_item_absent(xdr) < 0) + return -EMSGSIZE; + return 0; + } + + seg = req->rl_segments; + nsegs = rpcrdma_convert_iovs(r_xprt, &rqst->rq_rcv_buf, 0, wtype, seg); + if (nsegs < 0) + return nsegs; + + if (xdr_stream_encode_item_present(xdr) < 0) + return -EMSGSIZE; + segcount = xdr_reserve_space(xdr, sizeof(*segcount)); + if (unlikely(!segcount)) + return -EMSGSIZE; + /* Actual value encoded below */ + + nchunks = 0; + do { + seg = rpcrdma_mr_prepare(r_xprt, req, seg, nsegs, true, &mr); + if (IS_ERR(seg)) + return PTR_ERR(seg); + + if (encode_rdma_segment(xdr, mr) < 0) + return -EMSGSIZE; + + trace_xprtrdma_chunk_reply(rqst->rq_task, mr, nsegs); + r_xprt->rx_stats.reply_chunk_count++; + r_xprt->rx_stats.total_rdma_request += mr->mr_length; + nchunks++; + nsegs -= mr->mr_nents; + } while (nsegs); + + /* Update count of segments in the Reply chunk */ + *segcount = cpu_to_be32(nchunks); + + return 0; +} + +static void rpcrdma_sendctx_done(struct kref *kref) +{ + struct rpcrdma_req *req = + container_of(kref, struct rpcrdma_req, rl_kref); + struct rpcrdma_rep *rep = req->rl_reply; + + rpcrdma_complete_rqst(rep); + rep->rr_rxprt->rx_stats.reply_waits_for_send++; +} + +/** + * rpcrdma_sendctx_unmap - DMA-unmap Send buffer + * @sc: sendctx containing SGEs to unmap + * + */ +void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc) +{ + struct rpcrdma_regbuf *rb = sc->sc_req->rl_sendbuf; + struct ib_sge *sge; + + if (!sc->sc_unmap_count) + return; + + /* The first two SGEs contain the transport header and + * the inline buffer. These are always left mapped so + * they can be cheaply re-used. + */ + for (sge = &sc->sc_sges[2]; sc->sc_unmap_count; + ++sge, --sc->sc_unmap_count) + ib_dma_unmap_page(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); + + kref_put(&sc->sc_req->rl_kref, rpcrdma_sendctx_done); +} + +/* Prepare an SGE for the RPC-over-RDMA transport header. + */ +static void rpcrdma_prepare_hdr_sge(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct rpcrdma_regbuf *rb = req->rl_rdmabuf; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; + + sge->addr = rdmab_addr(rb); + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); +} + +/* The head iovec is straightforward, as it is usually already + * DMA-mapped. Sync the content that has changed. + */ +static bool rpcrdma_prepare_head_iov(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, unsigned int len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + + if (!rpcrdma_regbuf_dma_map(r_xprt, rb)) + return false; + + sge->addr = rdmab_addr(rb); + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + ib_dma_sync_single_for_device(rdmab_device(rb), sge->addr, sge->length, + DMA_TO_DEVICE); + return true; +} + +/* If there is a page list present, DMA map and prepare an + * SGE for each page to be sent. 
+ */ +static bool rpcrdma_prepare_pagelist(struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + unsigned int page_base, len, remaining; + struct page **ppages; + struct ib_sge *sge; + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + sge = &sc->sc_sges[req->rl_wr.num_sge++]; + len = min_t(unsigned int, PAGE_SIZE - page_base, remaining); + sge->addr = ib_dma_map_page(rdmab_device(rb), *ppages, + page_base, len, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdmab_device(rb), sge->addr)) + goto out_mapping_err; + + sge->length = len; + sge->lkey = rdmab_lkey(rb); + + sc->sc_unmap_count++; + ppages++; + remaining -= len; + page_base = 0; + } + + return true; + +out_mapping_err: + trace_xprtrdma_dma_maperr(sge->addr); + return false; +} + +/* The tail iovec may include an XDR pad for the page list, + * as well as additional content, and may not reside in the + * same page as the head iovec. + */ +static bool rpcrdma_prepare_tail_iov(struct rpcrdma_req *req, + struct xdr_buf *xdr, + unsigned int page_base, unsigned int len) +{ + struct rpcrdma_sendctx *sc = req->rl_sendctx; + struct ib_sge *sge = &sc->sc_sges[req->rl_wr.num_sge++]; + struct rpcrdma_regbuf *rb = req->rl_sendbuf; + struct page *page = virt_to_page(xdr->tail[0].iov_base); + + sge->addr = ib_dma_map_page(rdmab_device(rb), page, page_base, len, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdmab_device(rb), sge->addr)) + goto out_mapping_err; + + sge->length = len; + sge->lkey = rdmab_lkey(rb); + ++sc->sc_unmap_count; + return true; + +out_mapping_err: + trace_xprtrdma_dma_maperr(sge->addr); + return false; +} + +/* Copy the tail to the end of the head buffer. + */ +static void rpcrdma_pullup_tail_iov(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + unsigned char *dst; + + dst = (unsigned char *)xdr->head[0].iov_base; + dst += xdr->head[0].iov_len + xdr->page_len; + memmove(dst, xdr->tail[0].iov_base, xdr->tail[0].iov_len); + r_xprt->rx_stats.pullup_copy_count += xdr->tail[0].iov_len; +} + +/* Copy pagelist content into the head buffer. + */ +static void rpcrdma_pullup_pagelist(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + unsigned int len, page_base, remaining; + struct page **ppages; + unsigned char *src, *dst; + + dst = (unsigned char *)xdr->head[0].iov_base; + dst += xdr->head[0].iov_len; + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + page_base = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + src = page_address(*ppages); + src += page_base; + len = min_t(unsigned int, PAGE_SIZE - page_base, remaining); + memcpy(dst, src, len); + r_xprt->rx_stats.pullup_copy_count += len; + + ppages++; + dst += len; + remaining -= len; + page_base = 0; + } +} + +/* Copy the contents of @xdr into @rl_sendbuf and DMA sync it. + * When the head, pagelist, and tail are small, a pull-up copy + * is considerably less costly than DMA mapping the components + * of @xdr. + * + * Assumptions: + * - the caller has already verified that the total length + * of the RPC Call body will fit into @rl_sendbuf. 
+ */ +static bool rpcrdma_prepare_noch_pullup(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + if (unlikely(xdr->tail[0].iov_len)) + rpcrdma_pullup_tail_iov(r_xprt, req, xdr); + + if (unlikely(xdr->page_len)) + rpcrdma_pullup_pagelist(r_xprt, req, xdr); + + /* The whole RPC message resides in the head iovec now */ + return rpcrdma_prepare_head_iov(r_xprt, req, xdr->len); +} + +static bool rpcrdma_prepare_noch_mapped(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + struct kvec *tail = &xdr->tail[0]; + + if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len)) + return false; + if (xdr->page_len) + if (!rpcrdma_prepare_pagelist(req, xdr)) + return false; + if (tail->iov_len) + if (!rpcrdma_prepare_tail_iov(req, xdr, + offset_in_page(tail->iov_base), + tail->iov_len)) + return false; + + if (req->rl_sendctx->sc_unmap_count) + kref_get(&req->rl_kref); + return true; +} + +static bool rpcrdma_prepare_readch(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, + struct xdr_buf *xdr) +{ + if (!rpcrdma_prepare_head_iov(r_xprt, req, xdr->head[0].iov_len)) + return false; + + /* If there is a Read chunk, the page list is being handled + * via explicit RDMA, and thus is skipped here. + */ + + /* Do not include the tail if it is only an XDR pad */ + if (xdr->tail[0].iov_len > 3) { + unsigned int page_base, len; + + /* If the content in the page list is an odd length, + * xdr_write_pages() adds a pad at the beginning of + * the tail iovec. Force the tail's non-pad content to + * land at the next XDR position in the Send message. + */ + page_base = offset_in_page(xdr->tail[0].iov_base); + len = xdr->tail[0].iov_len; + page_base += len & 3; + len -= len & 3; + if (!rpcrdma_prepare_tail_iov(req, xdr, page_base, len)) + return false; + kref_get(&req->rl_kref); + } + + return true; +} + +/** + * rpcrdma_prepare_send_sges - Construct SGEs for a Send WR + * @r_xprt: controlling transport + * @req: context of RPC Call being marshalled + * @hdrlen: size of transport header, in bytes + * @xdr: xdr_buf containing RPC Call + * @rtype: chunk type being encoded + * + * Returns 0 on success; otherwise a negative errno is returned. 
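[Editorial aside, not part of the patch: whether the pull-up path or the DMA-mapped path above is taken is decided later, in rpcrdma_marshal_req(), by comparing the marshalled Call length against the size of the pre-mapped send buffer. A hypothetical restatement of that test:]

/* Illustrative only: true when the whole Call body fits in rl_sendbuf
 * and is therefore pulled up (rpcrdma_noch_pullup) rather than
 * DMA-mapped component by component (rpcrdma_noch_mapped).
 */
static bool example_use_pullup(struct rpc_rqst *rqst)
{
        struct rpcrdma_req *req = rpcr_to_rdmar(rqst);

        return rqst->rq_snd_buf.len < rdmab_length(req->rl_sendbuf);
}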
+ */ +inline int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, + enum rpcrdma_chunktype rtype) +{ + int ret; + + ret = -EAGAIN; + req->rl_sendctx = rpcrdma_sendctx_get_locked(r_xprt); + if (!req->rl_sendctx) + goto out_nosc; + req->rl_sendctx->sc_unmap_count = 0; + req->rl_sendctx->sc_req = req; + kref_init(&req->rl_kref); + req->rl_wr.wr_cqe = &req->rl_sendctx->sc_cqe; + req->rl_wr.sg_list = req->rl_sendctx->sc_sges; + req->rl_wr.num_sge = 0; + req->rl_wr.opcode = IB_WR_SEND; + + rpcrdma_prepare_hdr_sge(r_xprt, req, hdrlen); + + ret = -EIO; + switch (rtype) { + case rpcrdma_noch_pullup: + if (!rpcrdma_prepare_noch_pullup(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_noch_mapped: + if (!rpcrdma_prepare_noch_mapped(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_readch: + if (!rpcrdma_prepare_readch(r_xprt, req, xdr)) + goto out_unmap; + break; + case rpcrdma_areadch: + break; + default: + goto out_unmap; + } + + return 0; + +out_unmap: + rpcrdma_sendctx_unmap(req->rl_sendctx); +out_nosc: + trace_xprtrdma_prepsend_failed(&req->rl_slot, ret); + return ret; +} + +/** + * rpcrdma_marshal_req - Marshal and send one RPC request + * @r_xprt: controlling transport + * @rqst: RPC request to be marshaled + * + * For the RPC in "rqst", this function: + * - Chooses the transfer mode (eg., RDMA_MSG or RDMA_NOMSG) + * - Registers Read, Write, and Reply chunks + * - Constructs the transport header + * - Posts a Send WR to send the transport header and request + * + * Returns: + * %0 if the RPC was sent successfully, + * %-ENOTCONN if the connection was lost, + * %-EAGAIN if the caller should call again with the same arguments, + * %-ENOBUFS if the caller should call again after a delay, + * %-EMSGSIZE if the transport header is too small, + * %-EIO if a permanent problem occurred while marshaling. + */ +int +rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct xdr_stream *xdr = &req->rl_stream; + enum rpcrdma_chunktype rtype, wtype; + struct xdr_buf *buf = &rqst->rq_snd_buf; + bool ddp_allowed; + __be32 *p; + int ret; + + if (unlikely(rqst->rq_rcv_buf.flags & XDRBUF_SPARSE_PAGES)) { + ret = rpcrdma_alloc_sparse_pages(&rqst->rq_rcv_buf); + if (ret) + return ret; + } + + rpcrdma_set_xdrlen(&req->rl_hdrbuf, 0); + xdr_init_encode(xdr, &req->rl_hdrbuf, rdmab_data(req->rl_rdmabuf), + rqst); + + /* Fixed header fields */ + ret = -EMSGSIZE; + p = xdr_reserve_space(xdr, 4 * sizeof(*p)); + if (!p) + goto out_err; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = r_xprt->rx_buf.rb_max_requests; + + /* When the ULP employs a GSS flavor that guarantees integrity + * or privacy, direct data placement of individual data items + * is not allowed. + */ + ddp_allowed = !test_bit(RPCAUTH_AUTH_DATATOUCH, + &rqst->rq_cred->cr_auth->au_flags); + + /* + * Chunks needed for results? + * + * o If the expected result is under the inline threshold, all ops + * return as inline. + * o Large read ops return data as write chunk(s), header as + * inline. + * o Large non-read ops return as a single reply chunk. + */ + if (rpcrdma_results_inline(r_xprt, rqst)) + wtype = rpcrdma_noch; + else if ((ddp_allowed && rqst->rq_rcv_buf.flags & XDRBUF_READ) && + rpcrdma_nonpayload_inline(r_xprt, rqst)) + wtype = rpcrdma_writech; + else + wtype = rpcrdma_replych; + + /* + * Chunks needed for arguments? 
+ * + * o If the total request is under the inline threshold, all ops + * are sent as inline. + * o Large write ops transmit data as read chunk(s), header as + * inline. + * o Large non-write ops are sent with the entire message as a + * single read chunk (protocol 0-position special case). + * + * This assumes that the upper layer does not present a request + * that both has a data payload, and whose non-data arguments + * by themselves are larger than the inline threshold. + */ + if (rpcrdma_args_inline(r_xprt, rqst)) { + *p++ = rdma_msg; + rtype = buf->len < rdmab_length(req->rl_sendbuf) ? + rpcrdma_noch_pullup : rpcrdma_noch_mapped; + } else if (ddp_allowed && buf->flags & XDRBUF_WRITE) { + *p++ = rdma_msg; + rtype = rpcrdma_readch; + } else { + r_xprt->rx_stats.nomsg_call_count++; + *p++ = rdma_nomsg; + rtype = rpcrdma_areadch; + } + + /* This implementation supports the following combinations + * of chunk lists in one RPC-over-RDMA Call message: + * + * - Read list + * - Write list + * - Reply chunk + * - Read list + Reply chunk + * + * It might not yet support the following combinations: + * + * - Read list + Write list + * + * It does not support the following combinations: + * + * - Write list + Reply chunk + * - Read list + Write list + Reply chunk + * + * This implementation supports only a single chunk in each + * Read or Write list. Thus for example the client cannot + * send a Call message with a Position Zero Read chunk and a + * regular Read chunk at the same time. + */ + ret = rpcrdma_encode_read_list(r_xprt, req, rqst, rtype); + if (ret) + goto out_err; + ret = rpcrdma_encode_write_list(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + ret = rpcrdma_encode_reply_chunk(r_xprt, req, rqst, wtype); + if (ret) + goto out_err; + + ret = rpcrdma_prepare_send_sges(r_xprt, req, req->rl_hdrbuf.len, + buf, rtype); + if (ret) + goto out_err; + + trace_xprtrdma_marshal(req, rtype, wtype); + return 0; + +out_err: + trace_xprtrdma_marshal_failed(rqst, ret); + r_xprt->rx_stats.failed_marshal_count++; + frwr_reset(req); + return ret; +} + +static void __rpcrdma_update_cwnd_locked(struct rpc_xprt *xprt, + struct rpcrdma_buffer *buf, + u32 grant) +{ + buf->rb_credits = grant; + xprt->cwnd = grant << RPC_CWNDSHIFT; +} + +static void rpcrdma_update_cwnd(struct rpcrdma_xprt *r_xprt, u32 grant) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + + spin_lock(&xprt->transport_lock); + __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, grant); + spin_unlock(&xprt->transport_lock); +} + +/** + * rpcrdma_reset_cwnd - Reset the xprt's congestion window + * @r_xprt: controlling transport instance + * + * Prepare @r_xprt for the next connection by reinitializing + * its credit grant to one (see RFC 8166, Section 3.3.3). + */ +void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + + spin_lock(&xprt->transport_lock); + xprt->cong = 0; + __rpcrdma_update_cwnd_locked(xprt, &r_xprt->rx_buf, 1); + spin_unlock(&xprt->transport_lock); +} + +/** + * rpcrdma_inline_fixup - Scatter inline received data into rqst's iovecs + * @rqst: controlling RPC request + * @srcp: points to RPC message payload in receive buffer + * @copy_len: remaining length of receive buffer content + * @pad: Write chunk pad bytes needed (zero for pure inline) + * + * The upper layer has set the maximum number of bytes it can + * receive in each component of rq_rcv_buf. These values are set in + * the head.iov_len, page_len, tail.iov_len, and buflen fields. 
+ * + * Unlike the TCP equivalent (xdr_partial_copy_from_skb), in + * many cases this function simply updates iov_base pointers in + * rq_rcv_buf to point directly to the received reply data, to + * avoid copying reply data. + * + * Returns the count of bytes which had to be memcopied. + */ +static unsigned long +rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) +{ + unsigned long fixup_copy_count; + int i, npages, curlen; + char *destp; + struct page **ppages; + int page_base; + + /* The head iovec is redirected to the RPC reply message + * in the receive buffer, to avoid a memcopy. + */ + rqst->rq_rcv_buf.head[0].iov_base = srcp; + rqst->rq_private_buf.head[0].iov_base = srcp; + + /* The contents of the receive buffer that follow + * head.iov_len bytes are copied into the page list. + */ + curlen = rqst->rq_rcv_buf.head[0].iov_len; + if (curlen > copy_len) + curlen = copy_len; + srcp += curlen; + copy_len -= curlen; + + ppages = rqst->rq_rcv_buf.pages + + (rqst->rq_rcv_buf.page_base >> PAGE_SHIFT); + page_base = offset_in_page(rqst->rq_rcv_buf.page_base); + fixup_copy_count = 0; + if (copy_len && rqst->rq_rcv_buf.page_len) { + int pagelist_len; + + pagelist_len = rqst->rq_rcv_buf.page_len; + if (pagelist_len > copy_len) + pagelist_len = copy_len; + npages = PAGE_ALIGN(page_base + pagelist_len) >> PAGE_SHIFT; + for (i = 0; i < npages; i++) { + curlen = PAGE_SIZE - page_base; + if (curlen > pagelist_len) + curlen = pagelist_len; + + destp = kmap_atomic(ppages[i]); + memcpy(destp + page_base, srcp, curlen); + flush_dcache_page(ppages[i]); + kunmap_atomic(destp); + srcp += curlen; + copy_len -= curlen; + fixup_copy_count += curlen; + pagelist_len -= curlen; + if (!pagelist_len) + break; + page_base = 0; + } + + /* Implicit padding for the last segment in a Write + * chunk is inserted inline at the front of the tail + * iovec. The upper layer ignores the content of + * the pad. Simply ensure inline content in the tail + * that follows the Write chunk is properly aligned. + */ + if (pad) + srcp -= pad; + } + + /* The tail iovec is redirected to the remaining data + * in the receive buffer, to avoid a memcopy. + */ + if (copy_len || pad) { + rqst->rq_rcv_buf.tail[0].iov_base = srcp; + rqst->rq_private_buf.tail[0].iov_base = srcp; + } + + if (fixup_copy_count) + trace_xprtrdma_fixup(rqst, fixup_copy_count); + return fixup_copy_count; +} + +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool +rpcrdma_is_bcall(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + if (rep->rr_proc != rdma_msg) + return false; + + /* Peek at stream contents without advancing. */ + p = xdr_inline_decode(xdr, 0); + + /* Chunk lists */ + if (xdr_item_is_present(p++)) + return false; + if (xdr_item_is_present(p++)) + return false; + if (xdr_item_is_present(p++)) + return false; + + /* RPC header */ + if (*p++ != rep->rr_xid) + return false; + if (*p != cpu_to_be32(RPC_CALL)) + return false; + + /* No bc service. */ + if (xprt->bc_serv == NULL) + return false; + + /* Now that we are sure this is a backchannel call, + * advance to the RPC header. 
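The page-list branch of rpcrdma_inline_fixup() above copies received bytes into the rq_rcv_buf page array, starting at a page offset and wrapping at each page boundary. Here is a small, self-contained sketch of that copy loop using plain buffers in place of kmapped pages; the tiny PAGE_SZ and the pages array are illustration-only stand-ins.

#include <stdio.h>
#include <string.h>

#define PAGE_SZ 8		/* tiny "page" size, for illustration only */
#define NPAGES 4

/* Copy up to pagelist_len bytes of src into pages[], starting at
 * page_base within the first page and wrapping at page boundaries,
 * as the fixup loop above does with kmap_atomic()/memcpy().
 */
static size_t scatter_into_pages(char pages[NPAGES][PAGE_SZ],
				 unsigned int page_base,
				 const char *src, size_t pagelist_len)
{
	size_t copied = 0;
	unsigned int i;

	for (i = 0; i < NPAGES && pagelist_len; i++) {
		size_t curlen = PAGE_SZ - page_base;

		if (curlen > pagelist_len)
			curlen = pagelist_len;
		memcpy(&pages[i][page_base], src, curlen);
		src += curlen;
		copied += curlen;
		pagelist_len -= curlen;
		page_base = 0;	/* only the first page starts mid-page */
	}
	return copied;
}

int main(void)
{
	char pages[NPAGES][PAGE_SZ] = { { 0 } };
	const char msg[] = "abcdefghijklmnopqrst";

	printf("copied %zu bytes\n",
	       scatter_into_pages(pages, 3, msg, sizeof(msg) - 1));
	return 0;
}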
+ */ + p = xdr_inline_decode(xdr, 3 * sizeof(*p)); + if (unlikely(!p)) + return true; + + rpcrdma_bc_receive_call(r_xprt, rep); + return true; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +{ + return false; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static int decode_rdma_segment(struct xdr_stream *xdr, u32 *length) +{ + u32 handle; + u64 offset; + __be32 *p; + + p = xdr_inline_decode(xdr, 4 * sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + xdr_decode_rdma_segment(p, &handle, length, &offset); + trace_xprtrdma_decode_seg(handle, *length, offset); + return 0; +} + +static int decode_write_chunk(struct xdr_stream *xdr, u32 *length) +{ + u32 segcount, seglength; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + *length = 0; + segcount = be32_to_cpup(p); + while (segcount--) { + if (decode_rdma_segment(xdr, &seglength)) + return -EIO; + *length += seglength; + } + + return 0; +} + +/* In RPC-over-RDMA Version One replies, a Read list is never + * expected. This decoder is a stub that returns an error if + * a Read list is present. + */ +static int decode_read_list(struct xdr_stream *xdr) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (unlikely(xdr_item_is_present(p))) + return -EIO; + return 0; +} + +/* Supports only one Write chunk in the Write list + */ +static int decode_write_list(struct xdr_stream *xdr, u32 *length) +{ + u32 chunklen; + bool first; + __be32 *p; + + *length = 0; + first = true; + do { + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + if (xdr_item_is_absent(p)) + break; + if (!first) + return -EIO; + + if (decode_write_chunk(xdr, &chunklen)) + return -EIO; + *length += chunklen; + first = false; + } while (true); + return 0; +} + +static int decode_reply_chunk(struct xdr_stream *xdr, u32 *length) +{ + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + *length = 0; + if (xdr_item_is_present(p)) + if (decode_write_chunk(xdr, length)) + return -EIO; + return 0; +} + +static int +rpcrdma_decode_msg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk, rpclen; + char *base; + + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_MSG sanity checks */ + if (unlikely(replychunk)) + return -EIO; + + /* Build the RPC reply's Payload stream in rqst->rq_rcv_buf */ + base = (char *)xdr_inline_decode(xdr, 0); + rpclen = xdr_stream_remaining(xdr); + r_xprt->rx_stats.fixup_copy_count += + rpcrdma_inline_fixup(rqst, base, rpclen, writelist & 3); + + r_xprt->rx_stats.total_rdma_reply += writelist; + return rpclen + xdr_align_size(writelist); +} + +static noinline int +rpcrdma_decode_nomsg(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep) +{ + struct xdr_stream *xdr = &rep->rr_stream; + u32 writelist, replychunk; + + /* Decode the chunk lists */ + if (decode_read_list(xdr)) + return -EIO; + if (decode_write_list(xdr, &writelist)) + return -EIO; + if (decode_reply_chunk(xdr, &replychunk)) + return -EIO; + + /* RDMA_NOMSG sanity checks */ + if (unlikely(writelist)) + return -EIO; + if (unlikely(!replychunk)) + return -EIO; + + /* Reply chunk buffer already is the reply vector */ + r_xprt->rx_stats.total_rdma_reply += replychunk; + return replychunk; +} + +static noinline 
int +rpcrdma_decode_error(struct rpcrdma_xprt *r_xprt, struct rpcrdma_rep *rep, + struct rpc_rqst *rqst) +{ + struct xdr_stream *xdr = &rep->rr_stream; + __be32 *p; + + p = xdr_inline_decode(xdr, sizeof(*p)); + if (unlikely(!p)) + return -EIO; + + switch (*p) { + case err_vers: + p = xdr_inline_decode(xdr, 2 * sizeof(*p)); + if (!p) + break; + trace_xprtrdma_err_vers(rqst, p, p + 1); + break; + case err_chunk: + trace_xprtrdma_err_chunk(rqst); + break; + default: + trace_xprtrdma_err_unrecognized(rqst, p); + } + + return -EIO; +} + +/** + * rpcrdma_unpin_rqst - Release rqst without completing it + * @rep: RPC/RDMA Receive context + * + * This is done when a connection is lost so that a Reply + * can be dropped and its matching Call can be subsequently + * retransmitted on a new connection. + */ +void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep) +{ + struct rpc_xprt *xprt = &rep->rr_rxprt->rx_xprt; + struct rpc_rqst *rqst = rep->rr_rqst; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + + req->rl_reply = NULL; + rep->rr_rqst = NULL; + + spin_lock(&xprt->queue_lock); + xprt_unpin_rqst(rqst); + spin_unlock(&xprt->queue_lock); +} + +/** + * rpcrdma_complete_rqst - Pass completed rqst back to RPC + * @rep: RPC/RDMA Receive context + * + * Reconstruct the RPC reply and complete the transaction + * while @rqst is still pinned to ensure the rep, rqst, and + * rq_task pointers remain stable. + */ +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpc_rqst *rqst = rep->rr_rqst; + int status; + + switch (rep->rr_proc) { + case rdma_msg: + status = rpcrdma_decode_msg(r_xprt, rep, rqst); + break; + case rdma_nomsg: + status = rpcrdma_decode_nomsg(r_xprt, rep); + break; + case rdma_error: + status = rpcrdma_decode_error(r_xprt, rep, rqst); + break; + default: + status = -EIO; + } + if (status < 0) + goto out_badheader; + +out: + spin_lock(&xprt->queue_lock); + xprt_complete_rqst(rqst->rq_task, status); + xprt_unpin_rqst(rqst); + spin_unlock(&xprt->queue_lock); + return; + +out_badheader: + trace_xprtrdma_reply_hdr_err(rep); + r_xprt->rx_stats.bad_reply_count++; + rqst->rq_task->tk_status = status; + status = 0; + goto out; +} + +static void rpcrdma_reply_done(struct kref *kref) +{ + struct rpcrdma_req *req = + container_of(kref, struct rpcrdma_req, rl_kref); + + rpcrdma_complete_rqst(req->rl_reply); +} + +/** + * rpcrdma_reply_handler - Process received RPC/RDMA messages + * @rep: Incoming rpcrdma_rep object to process + * + * Errors must result in the RPC task either being awakened, or + * allowed to timeout, to discover the errors at that time. + */ +void rpcrdma_reply_handler(struct rpcrdma_rep *rep) +{ + struct rpcrdma_xprt *r_xprt = rep->rr_rxprt; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + u32 credits; + __be32 *p; + + /* Any data means we had a useful conversation, so + * then we don't need to delay the next reconnect. 
+ */ + if (xprt->reestablish_timeout) + xprt->reestablish_timeout = 0; + + /* Fixed transport header fields */ + xdr_init_decode(&rep->rr_stream, &rep->rr_hdrbuf, + rep->rr_hdrbuf.head[0].iov_base, NULL); + p = xdr_inline_decode(&rep->rr_stream, 4 * sizeof(*p)); + if (unlikely(!p)) + goto out_shortreply; + rep->rr_xid = *p++; + rep->rr_vers = *p++; + credits = be32_to_cpu(*p++); + rep->rr_proc = *p++; + + if (rep->rr_vers != rpcrdma_version) + goto out_badversion; + + if (rpcrdma_is_bcall(r_xprt, rep)) + return; + + /* Match incoming rpcrdma_rep to an rpcrdma_req to + * get context for handling any incoming chunks. + */ + spin_lock(&xprt->queue_lock); + rqst = xprt_lookup_rqst(xprt, rep->rr_xid); + if (!rqst) + goto out_norqst; + xprt_pin_rqst(rqst); + spin_unlock(&xprt->queue_lock); + + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > r_xprt->rx_ep->re_max_requests) + credits = r_xprt->rx_ep->re_max_requests; + rpcrdma_post_recvs(r_xprt, credits + (buf->rb_bc_srv_max_requests << 1), + false); + if (buf->rb_credits != credits) + rpcrdma_update_cwnd(r_xprt, credits); + + req = rpcr_to_rdmar(rqst); + if (unlikely(req->rl_reply)) + rpcrdma_rep_put(buf, req->rl_reply); + req->rl_reply = rep; + rep->rr_rqst = rqst; + + trace_xprtrdma_reply(rqst->rq_task, rep, credits); + + if (rep->rr_wc_flags & IB_WC_WITH_INVALIDATE) + frwr_reminv(rep, &req->rl_registered); + if (!list_empty(&req->rl_registered)) + frwr_unmap_async(r_xprt, req); + /* LocalInv completion will complete the RPC */ + else + kref_put(&req->rl_kref, rpcrdma_reply_done); + return; + +out_badversion: + trace_xprtrdma_reply_vers_err(rep); + goto out; + +out_norqst: + spin_unlock(&xprt->queue_lock); + trace_xprtrdma_reply_rqst_err(rep); + goto out; + +out_shortreply: + trace_xprtrdma_reply_short_err(rep); + +out: + rpcrdma_rep_put(buf, rep); +} diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c new file mode 100644 index 000000000..5bc20e9d0 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -0,0 +1,300 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker + */ + +#include +#include +#include +#include +#include +#include +#include + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* RPC/RDMA parameters */ +unsigned int svcrdma_ord = 16; /* historical default */ +static unsigned int min_ord = 1; +static unsigned int max_ord = 255; +unsigned int svcrdma_max_requests = RPCRDMA_MAX_REQUESTS; +unsigned int svcrdma_max_bc_requests = RPCRDMA_MAX_BC_REQUESTS; +static unsigned int min_max_requests = 4; +static unsigned int max_max_requests = 16384; +unsigned int svcrdma_max_req_size = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int min_max_inline = RPCRDMA_DEF_INLINE_THRESH; +static unsigned int max_max_inline = RPCRDMA_MAX_INLINE_THRESH; +static unsigned int svcrdma_stat_unused; +static unsigned int zero; + +struct percpu_counter svcrdma_stat_read; +struct percpu_counter svcrdma_stat_recv; +struct percpu_counter svcrdma_stat_sq_starve; +struct percpu_counter svcrdma_stat_write; + +enum { + SVCRDMA_COUNTER_BUFSIZ = sizeof(unsigned long long), +}; + +static int svcrdma_counter_handler(struct ctl_table *table, int write, + void *buffer, size_t *lenp, loff_t *ppos) +{ + struct percpu_counter *stat = (struct percpu_counter *)table->data; + char tmp[SVCRDMA_COUNTER_BUFSIZ + 1]; + int len; + + if (write) { + percpu_counter_set(stat, 0); + return 0; + } + + len = snprintf(tmp, SVCRDMA_COUNTER_BUFSIZ, "%lld\n", + percpu_counter_sum_positive(stat)); + if (len >= SVCRDMA_COUNTER_BUFSIZ) + return -EFAULT; + len = strlen(tmp); + if (*ppos > len) { + *lenp = 0; + return 0; + } + len -= *ppos; + if (len > *lenp) + len = *lenp; + if (len) + memcpy(buffer, tmp, len); + *lenp = len; + *ppos += len; + + return 0; +} + +static struct ctl_table_header *svcrdma_table_header; +static struct ctl_table svcrdma_parm_table[] = { + { + .procname = "max_requests", + .data = &svcrdma_max_requests, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_max_requests, + .extra2 = &max_max_requests + }, + { + .procname = "max_req_size", + .data = &svcrdma_max_req_size, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_max_inline, + .extra2 = &max_max_inline + }, + { + .procname = "max_outbound_read_requests", + .data = &svcrdma_ord, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_ord, + .extra2 = &max_ord, + }, + + { + .procname = "rdma_stat_read", + .data = &svcrdma_stat_read, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, + .mode = 0644, + .proc_handler = svcrdma_counter_handler, + }, + { + .procname = "rdma_stat_recv", + .data = &svcrdma_stat_recv, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, + .mode = 0644, + .proc_handler = svcrdma_counter_handler, + }, + { + .procname = "rdma_stat_write", + .data = &svcrdma_stat_write, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, + .mode = 0644, + .proc_handler = svcrdma_counter_handler, + }, + { + .procname = "rdma_stat_sq_starve", + .data = 
&svcrdma_stat_sq_starve, + .maxlen = SVCRDMA_COUNTER_BUFSIZ, + .mode = 0644, + .proc_handler = svcrdma_counter_handler, + }, + { + .procname = "rdma_stat_rq_starve", + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, + }, + { + .procname = "rdma_stat_rq_poll", + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, + }, + { + .procname = "rdma_stat_rq_prod", + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, + }, + { + .procname = "rdma_stat_sq_poll", + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, + }, + { + .procname = "rdma_stat_sq_prod", + .data = &svcrdma_stat_unused, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &zero, + .extra2 = &zero, + }, + { }, +}; + +static struct ctl_table svcrdma_table[] = { + { + .procname = "svc_rdma", + .mode = 0555, + .child = svcrdma_parm_table + }, + { }, +}; + +static struct ctl_table svcrdma_root_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = svcrdma_table + }, + { }, +}; + +static void svc_rdma_proc_cleanup(void) +{ + if (!svcrdma_table_header) + return; + unregister_sysctl_table(svcrdma_table_header); + svcrdma_table_header = NULL; + + percpu_counter_destroy(&svcrdma_stat_write); + percpu_counter_destroy(&svcrdma_stat_sq_starve); + percpu_counter_destroy(&svcrdma_stat_recv); + percpu_counter_destroy(&svcrdma_stat_read); +} + +static int svc_rdma_proc_init(void) +{ + int rc; + + if (svcrdma_table_header) + return 0; + + rc = percpu_counter_init(&svcrdma_stat_read, 0, GFP_KERNEL); + if (rc) + goto out_err; + rc = percpu_counter_init(&svcrdma_stat_recv, 0, GFP_KERNEL); + if (rc) + goto out_err; + rc = percpu_counter_init(&svcrdma_stat_sq_starve, 0, GFP_KERNEL); + if (rc) + goto out_err; + rc = percpu_counter_init(&svcrdma_stat_write, 0, GFP_KERNEL); + if (rc) + goto out_err; + + svcrdma_table_header = register_sysctl_table(svcrdma_root_table); + return 0; + +out_err: + percpu_counter_destroy(&svcrdma_stat_sq_starve); + percpu_counter_destroy(&svcrdma_stat_recv); + percpu_counter_destroy(&svcrdma_stat_read); + return rc; +} + +void svc_rdma_cleanup(void) +{ + dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); + svc_unreg_xprt_class(&svc_rdma_class); + svc_rdma_proc_cleanup(); +} + +int svc_rdma_init(void) +{ + int rc; + + dprintk("SVCRDMA Module Init, register RPC RDMA transport\n"); + dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord); + dprintk("\tmax_requests : %u\n", svcrdma_max_requests); + dprintk("\tmax_bc_requests : %u\n", svcrdma_max_bc_requests); + dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); + + rc = svc_rdma_proc_init(); + if (rc) + return rc; + + /* Register RDMA with the SVC transport switch */ + svc_reg_xprt_class(&svc_rdma_class); + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_backchannel.c b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c new file mode 100644 index 000000000..aa2227a7e --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_backchannel.c @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. 
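svc_rdma_proc_init() above initializes several per-CPU counters and tears down the ones already set up when a later step fails. The sketch below shows the general staged-unwind idiom with ordinary heap-backed counters; the counter type and helpers are invented for illustration, whereas the kernel code uses percpu_counter_init()/percpu_counter_destroy().

#include <stdio.h>
#include <stdlib.h>

struct counter { long *vals; };

static int counter_init(struct counter *c)
{
	c->vals = calloc(4, sizeof(*c->vals));
	return c->vals ? 0 : -1;
}

static void counter_destroy(struct counter *c)
{
	free(c->vals);
	c->vals = NULL;
}

static struct counter stat_read, stat_recv, stat_write;

/* Classic staged unwind: each failure label releases only what was
 * successfully initialized before the failing step.
 */
static int stats_init(void)
{
	if (counter_init(&stat_read))
		goto out_err;
	if (counter_init(&stat_recv))
		goto out_destroy_read;
	if (counter_init(&stat_write))
		goto out_destroy_recv;
	return 0;

out_destroy_recv:
	counter_destroy(&stat_recv);
out_destroy_read:
	counter_destroy(&stat_read);
out_err:
	return -1;
}

int main(void)
{
	printf("stats_init: %d\n", stats_init());
	return 0;
}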
+ * + * Support for reverse-direction RPCs on RPC/RDMA (server-side). + */ + +#include + +#include "xprt_rdma.h" +#include + +/** + * svc_rdma_handle_bc_reply - Process incoming backchannel Reply + * @rqstp: resources for handling the Reply + * @rctxt: Received message + * + */ +void svc_rdma_handle_bc_reply(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *rctxt) +{ + struct svc_xprt *sxprt = rqstp->rq_xprt; + struct rpc_xprt *xprt = sxprt->xpt_bc_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct xdr_buf *rcvbuf = &rqstp->rq_arg; + struct kvec *dst, *src = &rcvbuf->head[0]; + __be32 *rdma_resp = rctxt->rc_recv_buf; + struct rpc_rqst *req; + u32 credits; + + spin_lock(&xprt->queue_lock); + req = xprt_lookup_rqst(xprt, *rdma_resp); + if (!req) + goto out_unlock; + + dst = &req->rq_private_buf.head[0]; + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, sizeof(struct xdr_buf)); + if (dst->iov_len < src->iov_len) + goto out_unlock; + memcpy(dst->iov_base, src->iov_base, src->iov_len); + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); + + credits = be32_to_cpup(rdma_resp + 2); + if (credits == 0) + credits = 1; /* don't deadlock */ + else if (credits > r_xprt->rx_buf.rb_bc_max_requests) + credits = r_xprt->rx_buf.rb_bc_max_requests; + spin_lock(&xprt->transport_lock); + xprt->cwnd = credits << RPC_CWNDSHIFT; + spin_unlock(&xprt->transport_lock); + + spin_lock(&xprt->queue_lock); + xprt_complete_rqst(req->rq_task, rcvbuf->len); + xprt_unpin_rqst(req); + rcvbuf->len = 0; + +out_unlock: + spin_unlock(&xprt->queue_lock); +} + +/* Send a reverse-direction RPC Call. + * + * Caller holds the connection's mutex and has already marshaled + * the RPC/RDMA request. + * + * This is similar to svc_rdma_send_reply_msg, but takes a struct + * rpc_rqst instead, does not support chunks, and avoids blocking + * memory allocation. + * + * XXX: There is still an opportunity to block in svc_rdma_send() + * if there are no SQ entries to post the Send. This may occur if + * the adapter has a small maximum SQ depth. + */ +static int svc_rdma_bc_sendto(struct svcxprt_rdma *rdma, + struct rpc_rqst *rqst, + struct svc_rdma_send_ctxt *sctxt) +{ + struct svc_rdma_recv_ctxt *rctxt; + int ret; + + rctxt = svc_rdma_recv_ctxt_get(rdma); + if (!rctxt) + return -EIO; + + ret = svc_rdma_map_reply_msg(rdma, sctxt, rctxt, &rqst->rq_snd_buf); + svc_rdma_recv_ctxt_put(rdma, rctxt); + if (ret < 0) + return -EIO; + + /* Bump page refcnt so Send completion doesn't release + * the rq_buffer before all retransmits are complete. + */ + get_page(virt_to_page(rqst->rq_buffer)); + sctxt->sc_send_wr.opcode = IB_WR_SEND; + ret = svc_rdma_send(rdma, sctxt); + if (ret < 0) + return ret; + + ret = wait_for_completion_killable(&sctxt->sc_done); + svc_rdma_send_ctxt_put(rdma, sctxt); + return ret; +} + +/* Server-side transport endpoint wants a whole page for its send + * buffer. The client RPC code constructs the RPC header in this + * buffer before it invokes ->send_request. 
+ */ +static int +xprt_rdma_bc_allocate(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; + struct page *page; + + if (size > PAGE_SIZE) { + WARN_ONCE(1, "svcrdma: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } + + page = alloc_page(GFP_NOIO | __GFP_NOWARN); + if (!page) + return -ENOMEM; + rqst->rq_buffer = page_address(page); + + rqst->rq_rbuffer = kmalloc(rqst->rq_rcvsize, GFP_NOIO | __GFP_NOWARN); + if (!rqst->rq_rbuffer) { + put_page(page); + return -ENOMEM; + } + return 0; +} + +static void +xprt_rdma_bc_free(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + + put_page(virt_to_page(rqst->rq_buffer)); + kfree(rqst->rq_rbuffer); +} + +static int +rpcrdma_bc_send_request(struct svcxprt_rdma *rdma, struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct svc_rdma_send_ctxt *ctxt; + __be32 *p; + int rc; + + ctxt = svc_rdma_send_ctxt_get(rdma); + if (!ctxt) + goto drop_connection; + + p = xdr_reserve_space(&ctxt->sc_stream, RPCRDMA_HDRLEN_MIN); + if (!p) + goto put_ctxt; + *p++ = rqst->rq_xid; + *p++ = rpcrdma_version; + *p++ = cpu_to_be32(r_xprt->rx_buf.rb_bc_max_requests); + *p++ = rdma_msg; + *p++ = xdr_zero; + *p++ = xdr_zero; + *p = xdr_zero; + + rqst->rq_xtime = ktime_get(); + rc = svc_rdma_bc_sendto(rdma, rqst, ctxt); + if (rc) + goto put_ctxt; + return 0; + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, ctxt); + +drop_connection: + return -ENOTCONN; +} + +/** + * xprt_rdma_bc_send_request - Send a reverse-direction Call + * @rqst: rpc_rqst containing Call message to be sent + * + * Return values: + * %0 if the message was sent successfully + * %ENOTCONN if the message was not sent + */ +static int xprt_rdma_bc_send_request(struct rpc_rqst *rqst) +{ + struct svc_xprt *sxprt = rqst->rq_xprt->bc_xprt; + struct svcxprt_rdma *rdma = + container_of(sxprt, struct svcxprt_rdma, sc_xprt); + int ret; + + if (test_bit(XPT_DEAD, &sxprt->xpt_flags)) + return -ENOTCONN; + + ret = rpcrdma_bc_send_request(rdma, rqst); + if (ret == -ENOTCONN) + svc_xprt_close(sxprt); + return ret; +} + +static void +xprt_rdma_bc_close(struct rpc_xprt *xprt) +{ + xprt_disconnect_done(xprt); + xprt->cwnd = RPC_CWNDSHIFT; +} + +static void +xprt_rdma_bc_put(struct rpc_xprt *xprt) +{ + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); +} + +static const struct rpc_xprt_ops xprt_rdma_bc_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .release_request = xprt_release_rqst_cong, + .buf_alloc = xprt_rdma_bc_allocate, + .buf_free = xprt_rdma_bc_free, + .send_request = xprt_rdma_bc_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = xprt_rdma_bc_close, + .destroy = xprt_rdma_bc_put, + .print_stats = xprt_rdma_print_stats +}; + +static const struct rpc_timeout xprt_rdma_bc_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/* It shouldn't matter if the number of backchannel session slots + * doesn't match the number of RPC/RDMA credits. That just means + * one or the other will have extra slots that aren't used. 
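rpcrdma_bc_send_request() above builds the fixed, chunk-less transport header for a reverse-direction Call: XID, protocol version, credit value, rdma_msg, and three absent chunk lists. The sketch below encodes the same seven words into a big-endian buffer; the constants are taken from RFC 8166 (version 1, rdma_msg = 0), and everything else is a simplified stand-in rather than the kernel's xdr_stream machinery.

#include <arpa/inet.h>	/* htonl(), ntohl() */
#include <stdint.h>
#include <stdio.h>

#define RPCRDMA_VERSION	1	/* RPC-over-RDMA version one */
#define RDMA_MSG	0	/* rdma_msg procedure, per RFC 8166 */

/* Encode the fixed backchannel header: XID, version, credits,
 * message type, and three empty (absent) chunk lists.
 */
static void encode_bc_header(uint32_t *p, uint32_t xid, uint32_t credits)
{
	*p++ = htonl(xid);
	*p++ = htonl(RPCRDMA_VERSION);
	*p++ = htonl(credits);
	*p++ = htonl(RDMA_MSG);
	*p++ = 0;	/* empty Read list */
	*p++ = 0;	/* empty Write list */
	*p   = 0;	/* no Reply chunk */
}

int main(void)
{
	uint32_t hdr[7];
	int i;

	encode_bc_header(hdr, 0x12345678, 8);
	for (i = 0; i < 7; i++)
		printf("%08x\n", (unsigned int)ntohl(hdr[i]));
	return 0;
}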
+ */ +static struct rpc_xprt * +xprt_setup_rdma_bc(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + + if (args->addrlen > sizeof(xprt->addr)) + return ERR_PTR(-EBADF); + + xprt = xprt_alloc(args->net, sizeof(*new_xprt), + RPCRDMA_MAX_BC_REQUESTS, + RPCRDMA_MAX_BC_REQUESTS); + if (!xprt) + return ERR_PTR(-ENOMEM); + + xprt->timeout = &xprt_rdma_bc_timeout; + xprt_set_bound(xprt); + xprt_set_connected(xprt); + xprt->bind_timeout = 0; + xprt->reestablish_timeout = 0; + xprt->idle_timeout = 0; + + xprt->prot = XPRT_TRANSPORT_BC_RDMA; + xprt->ops = &xprt_rdma_bc_procs; + + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + xprt_rdma_format_addresses(xprt, (struct sockaddr *)&xprt->addr); + xprt->resvport = 0; + + xprt->max_payload = xprt_rdma_max_inline_read; + + new_xprt = rpcx_to_rdmax(xprt); + new_xprt->rx_buf.rb_bc_max_requests = xprt->max_reqs; + + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + + /* Final put for backchannel xprt is in __svc_rdma_free */ + xprt_get(xprt); + return xprt; +} + +struct xprt_class xprt_rdma_bc = { + .list = LIST_HEAD_INIT(xprt_rdma_bc.list), + .name = "rdma backchannel", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_BC_RDMA, + .setup = xprt_setup_rdma_bc, +}; diff --git a/net/sunrpc/xprtrdma/svc_rdma_pcl.c b/net/sunrpc/xprtrdma/svc_rdma_pcl.c new file mode 100644 index 000000000..b63cfeaa2 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_pcl.c @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Copyright (c) 2020 Oracle. All rights reserved. + */ + +#include +#include + +#include "xprt_rdma.h" +#include + +/** + * pcl_free - Release all memory associated with a parsed chunk list + * @pcl: parsed chunk list + * + */ +void pcl_free(struct svc_rdma_pcl *pcl) +{ + while (!list_empty(&pcl->cl_chunks)) { + struct svc_rdma_chunk *chunk; + + chunk = pcl_first_chunk(pcl); + list_del(&chunk->ch_list); + kfree(chunk); + } +} + +static struct svc_rdma_chunk *pcl_alloc_chunk(u32 segcount, u32 position) +{ + struct svc_rdma_chunk *chunk; + + chunk = kmalloc(struct_size(chunk, ch_segments, segcount), GFP_KERNEL); + if (!chunk) + return NULL; + + chunk->ch_position = position; + chunk->ch_length = 0; + chunk->ch_payload_length = 0; + chunk->ch_segcount = 0; + return chunk; +} + +static struct svc_rdma_chunk * +pcl_lookup_position(struct svc_rdma_pcl *pcl, u32 position) +{ + struct svc_rdma_chunk *pos; + + pcl_for_each_chunk(pos, pcl) { + if (pos->ch_position == position) + return pos; + } + return NULL; +} + +static void pcl_insert_position(struct svc_rdma_pcl *pcl, + struct svc_rdma_chunk *chunk) +{ + struct svc_rdma_chunk *pos; + + pcl_for_each_chunk(pos, pcl) { + if (pos->ch_position > chunk->ch_position) + break; + } + __list_add(&chunk->ch_list, pos->ch_list.prev, &pos->ch_list); + pcl->cl_count++; +} + +static void pcl_set_read_segment(const struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_chunk *chunk, + u32 handle, u32 length, u64 offset) +{ + struct svc_rdma_segment *segment; + + segment = &chunk->ch_segments[chunk->ch_segcount]; + segment->rs_handle = handle; + segment->rs_length = length; + segment->rs_offset = offset; + + trace_svcrdma_decode_rseg(&rctxt->rc_cid, chunk, segment); + + chunk->ch_length += length; + chunk->ch_segcount++; +} + +/** + * pcl_alloc_call - Construct a parsed chunk list for the Call body + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list + * + * Assumptions: + * - The incoming Read 
list has already been sanity checked. + * - cl_count is already set to the number of segments in + * the un-decoded list. + * - The list might not be in order by position. + * + * Return values: + * %true: Parsed chunk list was successfully constructed, and + * cl_count is updated to be the number of chunks (ie. + * unique positions) in the Read list. + * %false: Memory allocation failed. + */ +bool pcl_alloc_call(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + struct svc_rdma_pcl *pcl = &rctxt->rc_call_pcl; + unsigned int i, segcount = pcl->cl_count; + + pcl->cl_count = 0; + for (i = 0; i < segcount; i++) { + struct svc_rdma_chunk *chunk; + u32 position, handle, length; + u64 offset; + + p++; /* skip the list discriminator */ + p = xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position != 0) + continue; + + if (pcl_is_empty(pcl)) { + chunk = pcl_alloc_chunk(segcount, position); + if (!chunk) + return false; + pcl_insert_position(pcl, chunk); + } else { + chunk = list_first_entry(&pcl->cl_chunks, + struct svc_rdma_chunk, + ch_list); + } + + pcl_set_read_segment(rctxt, chunk, handle, length, offset); + } + + return true; +} + +/** + * pcl_alloc_read - Construct a parsed chunk list for normal Read chunks + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list + * + * Assumptions: + * - The incoming Read list has already been sanity checked. + * - cl_count is already set to the number of segments in + * the un-decoded list. + * - The list might not be in order by position. + * + * Return values: + * %true: Parsed chunk list was successfully constructed, and + * cl_count is updated to be the number of chunks (ie. + * unique position values) in the Read list. + * %false: Memory allocation failed. + * + * TODO: + * - Check for chunk range overlaps + */ +bool pcl_alloc_read(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + struct svc_rdma_pcl *pcl = &rctxt->rc_read_pcl; + unsigned int i, segcount = pcl->cl_count; + + pcl->cl_count = 0; + for (i = 0; i < segcount; i++) { + struct svc_rdma_chunk *chunk; + u32 position, handle, length; + u64 offset; + + p++; /* skip the list discriminator */ + p = xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position == 0) + continue; + + chunk = pcl_lookup_position(pcl, position); + if (!chunk) { + chunk = pcl_alloc_chunk(segcount, position); + if (!chunk) + return false; + pcl_insert_position(pcl, chunk); + } + + pcl_set_read_segment(rctxt, chunk, handle, length, offset); + } + + return true; +} + +/** + * pcl_alloc_write - Construct a parsed chunk list from a Write list + * @rctxt: Ingress receive context + * @pcl: Parsed chunk list to populate + * @p: Start of an un-decoded Write list + * + * Assumptions: + * - The incoming Write list has already been sanity checked, and + * - cl_count is set to the number of chunks in the un-decoded list. + * + * Return values: + * %true: Parsed chunk list was successfully constructed. + * %false: Memory allocation failed. 
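pcl_insert_position() above keeps the parsed chunk list ordered by XDR position: it walks the list until it finds a chunk with a larger position and links the new chunk in front of it. A minimal, runnable version of that ordered insert over a plain singly linked list looks like this; the struct is a stand-in for svc_rdma_chunk.

#include <stdio.h>
#include <stdlib.h>

struct chunk {
	unsigned int position;
	struct chunk *next;
};

/* Insert so the list stays ordered by position; equal positions go
 * after existing entries, mirroring pcl_insert_position().
 */
static void insert_by_position(struct chunk **head, struct chunk *nc)
{
	struct chunk **pp = head;

	while (*pp && (*pp)->position <= nc->position)
		pp = &(*pp)->next;
	nc->next = *pp;
	*pp = nc;
}

int main(void)
{
	unsigned int positions[] = { 40, 0, 16, 8 };
	struct chunk *head = NULL, *c, *next;
	size_t i;

	for (i = 0; i < sizeof(positions) / sizeof(positions[0]); i++) {
		c = calloc(1, sizeof(*c));
		if (!c)
			return 1;
		c->position = positions[i];
		insert_by_position(&head, c);
	}
	for (c = head; c; c = next) {
		next = c->next;
		printf("%u ", c->position);	/* prints 0 8 16 40 */
		free(c);
	}
	printf("\n");
	return 0;
}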
+ */ +bool pcl_alloc_write(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_pcl *pcl, __be32 *p) +{ + struct svc_rdma_segment *segment; + struct svc_rdma_chunk *chunk; + unsigned int i, j; + u32 segcount; + + for (i = 0; i < pcl->cl_count; i++) { + p++; /* skip the list discriminator */ + segcount = be32_to_cpup(p++); + + chunk = pcl_alloc_chunk(segcount, 0); + if (!chunk) + return false; + list_add_tail(&chunk->ch_list, &pcl->cl_chunks); + + for (j = 0; j < segcount; j++) { + segment = &chunk->ch_segments[j]; + p = xdr_decode_rdma_segment(p, &segment->rs_handle, + &segment->rs_length, + &segment->rs_offset); + trace_svcrdma_decode_wseg(&rctxt->rc_cid, chunk, j); + + chunk->ch_length += segment->rs_length; + chunk->ch_segcount++; + } + } + return true; +} + +static int pcl_process_region(const struct xdr_buf *xdr, + unsigned int offset, unsigned int length, + int (*actor)(const struct xdr_buf *, void *), + void *data) +{ + struct xdr_buf subbuf; + + if (!length) + return 0; + if (xdr_buf_subsegment(xdr, &subbuf, offset, length)) + return -EMSGSIZE; + return actor(&subbuf, data); +} + +/** + * pcl_process_nonpayloads - Process non-payload regions inside @xdr + * @pcl: Chunk list to process + * @xdr: xdr_buf to process + * @actor: Function to invoke on each non-payload region + * @data: Arguments for @actor + * + * This mechanism must ignore not only result payloads that were already + * sent via RDMA Write, but also XDR padding for those payloads that + * the upper layer has added. + * + * Assumptions: + * The xdr->len and ch_position fields are aligned to 4-byte multiples. + * + * Returns: + * On success, zero, + * %-EMSGSIZE on XDR buffer overflow, or + * The return value of @actor + */ +int pcl_process_nonpayloads(const struct svc_rdma_pcl *pcl, + const struct xdr_buf *xdr, + int (*actor)(const struct xdr_buf *, void *), + void *data) +{ + struct svc_rdma_chunk *chunk, *next; + unsigned int start; + int ret; + + chunk = pcl_first_chunk(pcl); + + /* No result payloads were generated */ + if (!chunk || !chunk->ch_payload_length) + return actor(xdr, data); + + /* Process the region before the first result payload */ + ret = pcl_process_region(xdr, 0, chunk->ch_position, actor, data); + if (ret < 0) + return ret; + + /* Process the regions between each middle result payload */ + while ((next = pcl_next_chunk(pcl, chunk))) { + if (!next->ch_payload_length) + break; + + start = pcl_chunk_end_offset(chunk); + ret = pcl_process_region(xdr, start, next->ch_position - start, + actor, data); + if (ret < 0) + return ret; + + chunk = next; + } + + /* Process the region after the last result payload */ + start = pcl_chunk_end_offset(chunk); + ret = pcl_process_region(xdr, start, xdr->len - start, actor, data); + if (ret < 0) + return ret; + + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c new file mode 100644 index 000000000..b2dd01e52 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -0,0 +1,868 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker + */ + +/* Operation + * + * The main entry point is svc_rdma_recvfrom. This is called from + * svc_recv when the transport indicates there is incoming data to + * be read. "Data Ready" is signaled when an RDMA Receive completes, + * or when a set of RDMA Reads complete. + * + * An svc_rqst is passed in. This structure contains an array of + * free pages (rq_pages) that will contain the incoming RPC message. + * + * Short messages are moved directly into svc_rqst::rq_arg, and + * the RPC Call is ready to be processed by the Upper Layer. + * svc_rdma_recvfrom returns the length of the RPC Call message, + * completing the reception of the RPC Call. + * + * However, when an incoming message has Read chunks, + * svc_rdma_recvfrom must post RDMA Reads to pull the RPC Call's + * data payload from the client. svc_rdma_recvfrom sets up the + * RDMA Reads using pages in svc_rqst::rq_pages, which are + * transferred to an svc_rdma_recv_ctxt for the duration of the + * I/O. svc_rdma_recvfrom then returns zero, since the RPC message + * is still not yet ready. + * + * When the Read chunk payloads have become available on the + * server, "Data Ready" is raised again, and svc_recv calls + * svc_rdma_recvfrom again. This second call may use a different + * svc_rqst than the first one, thus any information that needs + * to be preserved across these two calls is kept in an + * svc_rdma_recv_ctxt. + * + * The second call to svc_rdma_recvfrom performs final assembly + * of the RPC Call message, using the RDMA Read sink pages kept in + * the svc_rdma_recv_ctxt. The xdr_buf is copied from the + * svc_rdma_recv_ctxt to the second svc_rqst. The second call returns + * the length of the completed RPC Call message. 
+ * + * Page Management + * + * Pages under I/O must be transferred from the first svc_rqst to an + * svc_rdma_recv_ctxt before the first svc_rdma_recvfrom call returns. + * + * The first svc_rqst supplies pages for RDMA Reads. These are moved + * from rqstp::rq_pages into ctxt::pages. The consumed elements of + * the rq_pages array are set to NULL and refilled with the first + * svc_rdma_recvfrom call returns. + * + * During the second svc_rdma_recvfrom call, RDMA Read sink pages + * are transferred from the svc_rdma_recv_ctxt to the second svc_rqst. + */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "xprt_rdma.h" +#include + +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc); + +static inline struct svc_rdma_recv_ctxt * +svc_rdma_next_recv_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_recv_ctxt, + rc_list); +} + +static void svc_rdma_recv_cid_init(struct svcxprt_rdma *rdma, + struct rpc_rdma_cid *cid) +{ + cid->ci_queue_id = rdma->sc_rq_cq->res.id; + cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids); +} + +static struct svc_rdma_recv_ctxt * +svc_rdma_recv_ctxt_alloc(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + dma_addr_t addr; + void *buffer; + + ctxt = kmalloc(sizeof(*ctxt), GFP_KERNEL); + if (!ctxt) + goto fail0; + buffer = kmalloc(rdma->sc_max_req_size, GFP_KERNEL); + if (!buffer) + goto fail1; + addr = ib_dma_map_single(rdma->sc_pd->device, buffer, + rdma->sc_max_req_size, DMA_FROM_DEVICE); + if (ib_dma_mapping_error(rdma->sc_pd->device, addr)) + goto fail2; + + svc_rdma_recv_cid_init(rdma, &ctxt->rc_cid); + pcl_init(&ctxt->rc_call_pcl); + pcl_init(&ctxt->rc_read_pcl); + pcl_init(&ctxt->rc_write_pcl); + pcl_init(&ctxt->rc_reply_pcl); + + ctxt->rc_recv_wr.next = NULL; + ctxt->rc_recv_wr.wr_cqe = &ctxt->rc_cqe; + ctxt->rc_recv_wr.sg_list = &ctxt->rc_recv_sge; + ctxt->rc_recv_wr.num_sge = 1; + ctxt->rc_cqe.done = svc_rdma_wc_receive; + ctxt->rc_recv_sge.addr = addr; + ctxt->rc_recv_sge.length = rdma->sc_max_req_size; + ctxt->rc_recv_sge.lkey = rdma->sc_pd->local_dma_lkey; + ctxt->rc_recv_buf = buffer; + ctxt->rc_temp = false; + return ctxt; + +fail2: + kfree(buffer); +fail1: + kfree(ctxt); +fail0: + return NULL; +} + +static void svc_rdma_recv_ctxt_destroy(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + ib_dma_unmap_single(rdma->sc_pd->device, ctxt->rc_recv_sge.addr, + ctxt->rc_recv_sge.length, DMA_FROM_DEVICE); + kfree(ctxt->rc_recv_buf); + kfree(ctxt); +} + +/** + * svc_rdma_recv_ctxts_destroy - Release all recv_ctxt's for an xprt + * @rdma: svcxprt_rdma being torn down + * + */ +void svc_rdma_recv_ctxts_destroy(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + struct llist_node *node; + + while ((node = llist_del_first(&rdma->sc_recv_ctxts))) { + ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node); + svc_rdma_recv_ctxt_destroy(rdma, ctxt); + } +} + +/** + * svc_rdma_recv_ctxt_get - Allocate a recv_ctxt + * @rdma: controlling svcxprt_rdma + * + * Returns a recv_ctxt or (rarely) NULL if none are available. 
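The receive-context cache described here pops a previously used context from a free list and falls back to allocating a fresh one only when the list is empty; put returns the context for reuse. The kernel version uses a lock-free llist; the single-threaded sketch below shows the same get/put shape with an ordinary pointer and made-up types.

#include <stdio.h>
#include <stdlib.h>

struct ctxt {
	struct ctxt *next;
	int id;
};

static struct ctxt *free_list;
static int next_id;

/* Pop a cached context, or allocate a fresh one when the cache is
 * empty, in the spirit of svc_rdma_recv_ctxt_get().
 */
static struct ctxt *ctxt_get(void)
{
	struct ctxt *c = free_list;

	if (c) {
		free_list = c->next;
		return c;
	}
	c = calloc(1, sizeof(*c));
	if (c)
		c->id = ++next_id;
	return c;
}

/* Return a context to the cache, in the spirit of svc_rdma_recv_ctxt_put(). */
static void ctxt_put(struct ctxt *c)
{
	c->next = free_list;
	free_list = c;
}

int main(void)
{
	struct ctxt *a = ctxt_get(), *b = ctxt_get();

	printf("a=%d b=%d\n", a->id, b->id);
	ctxt_put(a);
	printf("reused=%d\n", ctxt_get()->id);	/* same id as a */
	return 0;
}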
+ */ +struct svc_rdma_recv_ctxt *svc_rdma_recv_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + struct llist_node *node; + + node = llist_del_first(&rdma->sc_recv_ctxts); + if (!node) + goto out_empty; + ctxt = llist_entry(node, struct svc_rdma_recv_ctxt, rc_node); + +out: + ctxt->rc_page_count = 0; + return ctxt; + +out_empty: + ctxt = svc_rdma_recv_ctxt_alloc(rdma); + if (!ctxt) + return NULL; + goto out; +} + +/** + * svc_rdma_recv_ctxt_put - Return recv_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + * + */ +void svc_rdma_recv_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + pcl_free(&ctxt->rc_call_pcl); + pcl_free(&ctxt->rc_read_pcl); + pcl_free(&ctxt->rc_write_pcl); + pcl_free(&ctxt->rc_reply_pcl); + + if (!ctxt->rc_temp) + llist_add(&ctxt->rc_node, &rdma->sc_recv_ctxts); + else + svc_rdma_recv_ctxt_destroy(rdma, ctxt); +} + +/** + * svc_rdma_release_ctxt - Release transport-specific per-rqst resources + * @xprt: the transport which owned the context + * @vctxt: the context from rqstp->rq_xprt_ctxt or dr->xprt_ctxt + * + * Ensure that the recv_ctxt is released whether or not a Reply + * was sent. For example, the client could close the connection, + * or svc_process could drop an RPC, before the Reply is sent. + */ +void svc_rdma_release_ctxt(struct svc_xprt *xprt, void *vctxt) +{ + struct svc_rdma_recv_ctxt *ctxt = vctxt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + if (ctxt) + svc_rdma_recv_ctxt_put(rdma, ctxt); +} + +static bool svc_rdma_refresh_recvs(struct svcxprt_rdma *rdma, + unsigned int wanted, bool temp) +{ + const struct ib_recv_wr *bad_wr = NULL; + struct svc_rdma_recv_ctxt *ctxt; + struct ib_recv_wr *recv_chain; + int ret; + + if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) + return false; + + recv_chain = NULL; + while (wanted--) { + ctxt = svc_rdma_recv_ctxt_get(rdma); + if (!ctxt) + break; + + trace_svcrdma_post_recv(ctxt); + ctxt->rc_temp = temp; + ctxt->rc_recv_wr.next = recv_chain; + recv_chain = &ctxt->rc_recv_wr; + rdma->sc_pending_recvs++; + } + if (!recv_chain) + return false; + + ret = ib_post_recv(rdma->sc_qp, recv_chain, &bad_wr); + if (ret) + goto err_free; + return true; + +err_free: + trace_svcrdma_rq_post_err(rdma, ret); + while (bad_wr) { + ctxt = container_of(bad_wr, struct svc_rdma_recv_ctxt, + rc_recv_wr); + bad_wr = bad_wr->next; + svc_rdma_recv_ctxt_put(rdma, ctxt); + } + /* Since we're destroying the xprt, no need to reset + * sc_pending_recvs. */ + return false; +} + +/** + * svc_rdma_post_recvs - Post initial set of Recv WRs + * @rdma: fresh svcxprt_rdma + * + * Returns true if successful, otherwise false. 
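svc_rdma_refresh_recvs() above links each new Receive WR onto the front of a chain and posts the whole chain with a single verb call; if posting fails partway, it walks the returned bad_wr pointer to release the contexts that were never accepted. The following sketch shows that batch-post-and-unwind pattern with a pretend poster; the types and the capacity-limited post function are invented for illustration, not the verbs API.

#include <stdio.h>
#include <stdlib.h>

struct recv_wr {
	struct recv_wr *next;
	int id;
};

/* "Posts" WRs until it hits one it cannot accept, then hands back the
 * remainder through *bad_wr, loosely like ib_post_recv() does.
 */
static int post_chain(struct recv_wr *chain, int capacity,
		      struct recv_wr **bad_wr)
{
	while (chain && capacity--) {
		printf("posted wr %d\n", chain->id);
		chain = chain->next;
	}
	*bad_wr = chain;
	return chain ? -1 : 0;
}

int main(void)
{
	struct recv_wr *chain = NULL, *bad_wr = NULL, *wr;
	int i;

	/* Build the chain in LIFO order, as the refresh loop above does. */
	for (i = 1; i <= 4; i++) {
		wr = calloc(1, sizeof(*wr));
		if (!wr)
			return 1;
		wr->id = i;
		wr->next = chain;
		chain = wr;
	}

	if (post_chain(chain, 2, &bad_wr)) {
		/* Release everything from the first rejected WR onward. */
		while (bad_wr) {
			wr = bad_wr;
			bad_wr = bad_wr->next;
			printf("released wr %d\n", wr->id);
			free(wr);
		}
	}
	return 0;
}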
+ */ +bool svc_rdma_post_recvs(struct svcxprt_rdma *rdma) +{ + return svc_rdma_refresh_recvs(rdma, rdma->sc_max_requests, true); +} + +/** + * svc_rdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: Completion Queue context + * @wc: Work Completion object + * + */ +static void svc_rdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_recv_ctxt *ctxt; + + rdma->sc_pending_recvs--; + + /* WARNING: Only wc->wr_cqe and wc->status are reliable */ + ctxt = container_of(cqe, struct svc_rdma_recv_ctxt, rc_cqe); + + if (wc->status != IB_WC_SUCCESS) + goto flushed; + trace_svcrdma_wc_recv(wc, &ctxt->rc_cid); + + /* If receive posting fails, the connection is about to be + * lost anyway. The server will not be able to send a reply + * for this RPC, and the client will retransmit this RPC + * anyway when it reconnects. + * + * Therefore we drop the Receive, even if status was SUCCESS + * to reduce the likelihood of replayed requests once the + * client reconnects. + */ + if (rdma->sc_pending_recvs < rdma->sc_max_requests) + if (!svc_rdma_refresh_recvs(rdma, rdma->sc_recv_batch, false)) + goto dropped; + + /* All wc fields are now known to be valid */ + ctxt->rc_byte_len = wc->byte_len; + + spin_lock(&rdma->sc_rq_dto_lock); + list_add_tail(&ctxt->rc_list, &rdma->sc_rq_dto_q); + /* Note the unlock pairs with the smp_rmb in svc_xprt_ready: */ + set_bit(XPT_DATA, &rdma->sc_xprt.xpt_flags); + spin_unlock(&rdma->sc_rq_dto_lock); + if (!test_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags)) + svc_xprt_enqueue(&rdma->sc_xprt); + return; + +flushed: + if (wc->status == IB_WC_WR_FLUSH_ERR) + trace_svcrdma_wc_recv_flush(wc, &ctxt->rc_cid); + else + trace_svcrdma_wc_recv_err(wc, &ctxt->rc_cid); +dropped: + svc_rdma_recv_ctxt_put(rdma, ctxt); + svc_xprt_deferred_close(&rdma->sc_xprt); +} + +/** + * svc_rdma_flush_recv_queues - Drain pending Receive work + * @rdma: svcxprt_rdma being shut down + * + */ +void svc_rdma_flush_recv_queues(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_recv_ctxt *ctxt; + + while ((ctxt = svc_rdma_next_recv_ctxt(&rdma->sc_rq_dto_q))) { + list_del(&ctxt->rc_list); + svc_rdma_recv_ctxt_put(rdma, ctxt); + } +} + +static void svc_rdma_build_arg_xdr(struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct xdr_buf *arg = &rqstp->rq_arg; + + arg->head[0].iov_base = ctxt->rc_recv_buf; + arg->head[0].iov_len = ctxt->rc_byte_len; + arg->tail[0].iov_base = NULL; + arg->tail[0].iov_len = 0; + arg->page_len = 0; + arg->page_base = 0; + arg->buflen = ctxt->rc_byte_len; + arg->len = ctxt->rc_byte_len; +} + +/** + * xdr_count_read_segments - Count number of Read segments in Read list + * @rctxt: Ingress receive context + * @p: Start of an un-decoded Read list + * + * Before allocating anything, ensure the ingress Read list is safe + * to use. + * + * The segment count is limited to how many segments can fit in the + * transport header without overflowing the buffer. That's about 40 + * Read segments for a 1KB inline threshold. + * + * Return values: + * %true: Read list is valid. @rctxt's xdr_stream is updated to point + * to the first byte past the Read list. rc_read_pcl and + * rc_call_pcl cl_count fields are set to the number of + * Read segments in the list. + * %false: Read list is corrupt. @rctxt's xdr_stream is left in an + * unknown state. 
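The Read-list counting described above walks a discriminated XDR list: each entry is a one-word "present" marker followed by a read segment, and a zero marker ends the list; position-zero segments belong to the Call chunk, the rest to regular Read chunks. Below is a simplified, self-contained decoder over a raw big-endian word array; the 6-word entry layout follows RFC 8166 (discriminator, position, handle, length, 64-bit offset), but the bounds handling is deliberately cruder than the kernel's xdr_stream helpers.

#include <arpa/inet.h>
#include <stdint.h>
#include <stdio.h>

/* Count segments in a discriminator-terminated Read list, classifying
 * position-zero entries separately. Returns 0 on success, -1 if the
 * list would run past the end of the buffer.
 */
static int count_read_segments(const uint32_t *p, const uint32_t *end,
			       unsigned int *call_segs, unsigned int *read_segs)
{
	*call_segs = 0;
	*read_segs = 0;

	while (p < end && ntohl(*p) != 0) {
		if (end - p < 6)	/* discriminator + 5-word segment */
			return -1;
		if (ntohl(p[1]) == 0)
			(*call_segs)++;
		else
			(*read_segs)++;
		p += 6;
	}
	return (p < end) ? 0 : -1;	/* must end with a terminator word */
}

int main(void)
{
	/* Two segments: one at position 0, one at position 64, then end. */
	uint32_t list[] = {
		htonl(1), htonl(0),  htonl(0xabc), htonl(4096), 0, 0,
		htonl(1), htonl(64), htonl(0xdef), htonl(8192), 0, 0,
		htonl(0),
	};
	unsigned int call_segs, read_segs;

	if (!count_read_segments(list, list + 13, &call_segs, &read_segs))
		printf("call segs %u, read segs %u\n", call_segs, read_segs);
	return 0;
}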
+ */ +static bool xdr_count_read_segments(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + rctxt->rc_call_pcl.cl_count = 0; + rctxt->rc_read_pcl.cl_count = 0; + while (xdr_item_is_present(p)) { + u32 position, handle, length; + u64 offset; + + p = xdr_inline_decode(&rctxt->rc_stream, + rpcrdma_readseg_maxsz * sizeof(*p)); + if (!p) + return false; + + xdr_decode_read_segment(p, &position, &handle, + &length, &offset); + if (position) { + if (position & 3) + return false; + ++rctxt->rc_read_pcl.cl_count; + } else { + ++rctxt->rc_call_pcl.cl_count; + } + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + } + return true; +} + +/* Sanity check the Read list. + * + * Sanity checks: + * - Read list does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Read list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Read list. + * %false: Read list is corrupt. @rctxt's xdr_stream is left + * in an unknown state. + */ +static bool xdr_check_read_list(struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p; + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + if (!xdr_count_read_segments(rctxt, p)) + return false; + if (!pcl_alloc_call(rctxt, p)) + return false; + return pcl_alloc_read(rctxt, p); +} + +static bool xdr_check_write_chunk(struct svc_rdma_recv_ctxt *rctxt) +{ + u32 segcount; + __be32 *p; + + if (xdr_stream_decode_u32(&rctxt->rc_stream, &segcount)) + return false; + + /* A bogus segcount causes this buffer overflow check to fail. */ + p = xdr_inline_decode(&rctxt->rc_stream, + segcount * rpcrdma_segment_maxsz * sizeof(*p)); + return p != NULL; +} + +/** + * xdr_count_write_chunks - Count number of Write chunks in Write list + * @rctxt: Received header and decoding state + * @p: start of an un-decoded Write list + * + * Before allocating anything, ensure the ingress Write list is + * safe to use. + * + * Return values: + * %true: Write list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Write list, and + * the number of Write chunks is in rc_write_pcl.cl_count. + * %false: Write list is corrupt. @rctxt's xdr_stream is left + * in an indeterminate state. + */ +static bool xdr_count_write_chunks(struct svc_rdma_recv_ctxt *rctxt, __be32 *p) +{ + rctxt->rc_write_pcl.cl_count = 0; + while (xdr_item_is_present(p)) { + if (!xdr_check_write_chunk(rctxt)) + return false; + ++rctxt->rc_write_pcl.cl_count; + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + } + return true; +} + +/* Sanity check the Write list. + * + * Implementation limits: + * - This implementation currently supports only one Write chunk. + * + * Sanity checks: + * - Write list does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Write list is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Write list. + * %false: Write list is corrupt. @rctxt's xdr_stream is left + * in an unknown state. + */ +static bool xdr_check_write_list(struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p; + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + if (!xdr_count_write_chunks(rctxt, p)) + return false; + if (!pcl_alloc_write(rctxt, &rctxt->rc_write_pcl, p)) + return false; + + rctxt->rc_cur_result_payload = pcl_first_chunk(&rctxt->rc_write_pcl); + return true; +} + +/* Sanity check the Reply chunk. 
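xdr_check_write_chunk() above defends against a bogus segment count by attempting to decode segcount whole segments and letting the buffer-bounds check fail. The one-liner below expresses the same property with explicit arithmetic, dividing rather than multiplying so a huge count cannot overflow; the constant and helper name are stand-ins for illustration.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

#define SEGMENT_WORDS 4U	/* handle, length, 64-bit offset */

/* True if segcount segments fit in the words remaining in the receive
 * buffer; division avoids overflow from an absurd segcount.
 */
static bool segcount_fits(uint32_t segcount, size_t remaining_words)
{
	return segcount <= remaining_words / SEGMENT_WORDS;
}

int main(void)
{
	printf("%d %d %d\n",
	       segcount_fits(3, 16),		/* fits */
	       segcount_fits(5, 16),		/* too many segments */
	       segcount_fits(0xffffffff, 16));	/* bogus count rejected */
	return 0;
}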
+ * + * Sanity checks: + * - Reply chunk does not overflow Receive buffer. + * - Chunk size limited by largest NFS data payload. + * + * Return values: + * %true: Reply chunk is valid. @rctxt's xdr_stream is updated + * to point to the first byte past the Reply chunk. + * %false: Reply chunk is corrupt. @rctxt's xdr_stream is left + * in an unknown state. + */ +static bool xdr_check_reply_chunk(struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p; + + p = xdr_inline_decode(&rctxt->rc_stream, sizeof(*p)); + if (!p) + return false; + + if (!xdr_item_is_present(p)) + return true; + if (!xdr_check_write_chunk(rctxt)) + return false; + + rctxt->rc_reply_pcl.cl_count = 1; + return pcl_alloc_write(rctxt, &rctxt->rc_reply_pcl, p); +} + +/* RPC-over-RDMA Version One private extension: Remote Invalidation. + * Responder's choice: requester signals it can handle Send With + * Invalidate, and responder chooses one R_key to invalidate. + * + * If there is exactly one distinct R_key in the received transport + * header, set rc_inv_rkey to that R_key. Otherwise, set it to zero. + */ +static void svc_rdma_get_inv_rkey(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *ctxt) +{ + struct svc_rdma_segment *segment; + struct svc_rdma_chunk *chunk; + u32 inv_rkey; + + ctxt->rc_inv_rkey = 0; + + if (!rdma->sc_snd_w_inv) + return; + + inv_rkey = 0; + pcl_for_each_chunk(chunk, &ctxt->rc_call_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; + } + } + pcl_for_each_chunk(chunk, &ctxt->rc_read_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; + } + } + pcl_for_each_chunk(chunk, &ctxt->rc_write_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; + } + } + pcl_for_each_chunk(chunk, &ctxt->rc_reply_pcl) { + pcl_for_each_segment(segment, chunk) { + if (inv_rkey == 0) + inv_rkey = segment->rs_handle; + else if (inv_rkey != segment->rs_handle) + return; + } + } + ctxt->rc_inv_rkey = inv_rkey; +} + +/** + * svc_rdma_xdr_decode_req - Decode the transport header + * @rq_arg: xdr_buf containing ingress RPC/RDMA message + * @rctxt: state of decoding + * + * On entry, xdr->head[0].iov_base points to first byte of the + * RPC-over-RDMA transport header. + * + * On successful exit, head[0] points to first byte past the + * RPC-over-RDMA header. For RDMA_MSG, this is the RPC message. + * + * The length of the RPC-over-RDMA header is returned. + * + * Assumptions: + * - The transport header is entirely contained in the head iovec. 
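svc_rdma_get_inv_rkey() above scans every segment of every chunk and nominates an R_key for Send With Invalidate only when all segments share a single distinct value. A compact version of that scan over a plain array looks like the sketch below; the array stands in for the parsed chunk lists.

#include <stdint.h>
#include <stdio.h>

/* Return the single R_key shared by all segments, or 0 when the
 * segments carry more than one distinct R_key (or there are none),
 * mirroring the scan in svc_rdma_get_inv_rkey().
 */
static uint32_t single_rkey(const uint32_t *handles, size_t count)
{
	uint32_t rkey = 0;
	size_t i;

	for (i = 0; i < count; i++) {
		if (rkey == 0)
			rkey = handles[i];
		else if (rkey != handles[i])
			return 0;
	}
	return rkey;
}

int main(void)
{
	uint32_t same[] = { 0x1234, 0x1234, 0x1234 };
	uint32_t mixed[] = { 0x1234, 0x5678 };

	printf("%#x %#x\n",
	       (unsigned int)single_rkey(same, 3),
	       (unsigned int)single_rkey(mixed, 2));
	return 0;
}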
+ */ +static int svc_rdma_xdr_decode_req(struct xdr_buf *rq_arg, + struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p, *rdma_argp; + unsigned int hdr_len; + + rdma_argp = rq_arg->head[0].iov_base; + xdr_init_decode(&rctxt->rc_stream, rq_arg, rdma_argp, NULL); + + p = xdr_inline_decode(&rctxt->rc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (unlikely(!p)) + goto out_short; + p++; + if (*p != rpcrdma_version) + goto out_version; + p += 2; + rctxt->rc_msgtype = *p; + switch (rctxt->rc_msgtype) { + case rdma_msg: + break; + case rdma_nomsg: + break; + case rdma_done: + goto out_drop; + case rdma_error: + goto out_drop; + default: + goto out_proc; + } + + if (!xdr_check_read_list(rctxt)) + goto out_inval; + if (!xdr_check_write_list(rctxt)) + goto out_inval; + if (!xdr_check_reply_chunk(rctxt)) + goto out_inval; + + rq_arg->head[0].iov_base = rctxt->rc_stream.p; + hdr_len = xdr_stream_pos(&rctxt->rc_stream); + rq_arg->head[0].iov_len -= hdr_len; + rq_arg->len -= hdr_len; + trace_svcrdma_decode_rqst(rctxt, rdma_argp, hdr_len); + return hdr_len; + +out_short: + trace_svcrdma_decode_short_err(rctxt, rq_arg->len); + return -EINVAL; + +out_version: + trace_svcrdma_decode_badvers_err(rctxt, rdma_argp); + return -EPROTONOSUPPORT; + +out_drop: + trace_svcrdma_decode_drop_err(rctxt, rdma_argp); + return 0; + +out_proc: + trace_svcrdma_decode_badproc_err(rctxt, rdma_argp); + return -EINVAL; + +out_inval: + trace_svcrdma_decode_parse_err(rctxt, rdma_argp); + return -EINVAL; +} + +static void svc_rdma_send_error(struct svcxprt_rdma *rdma, + struct svc_rdma_recv_ctxt *rctxt, + int status) +{ + struct svc_rdma_send_ctxt *sctxt; + + sctxt = svc_rdma_send_ctxt_get(rdma); + if (!sctxt) + return; + svc_rdma_send_error_msg(rdma, sctxt, rctxt, status); +} + +/* By convention, backchannel calls arrive via rdma_msg type + * messages, and never populate the chunk lists. This makes + * the RPC/RDMA header small and fixed in size, so it is + * straightforward to check the RPC header's direction field. + */ +static bool svc_rdma_is_reverse_direction_reply(struct svc_xprt *xprt, + struct svc_rdma_recv_ctxt *rctxt) +{ + __be32 *p = rctxt->rc_recv_buf; + + if (!xprt->xpt_bc_xprt) + return false; + + if (rctxt->rc_msgtype != rdma_msg) + return false; + + if (!pcl_is_empty(&rctxt->rc_call_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_read_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_write_pcl)) + return false; + if (!pcl_is_empty(&rctxt->rc_reply_pcl)) + return false; + + /* RPC call direction */ + if (*(p + 8) == cpu_to_be32(RPC_CALL)) + return false; + + return true; +} + +/** + * svc_rdma_recvfrom - Receive an RPC call + * @rqstp: request structure into which to receive an RPC Call + * + * Returns: + * The positive number of bytes in the RPC Call message, + * %0 if there were no Calls ready to return, + * %-EINVAL if the Read chunk data is too large, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + * + * Called in a loop when XPT_DATA is set. XPT_DATA is cleared only + * when there are no remaining ctxt's to process. + * + * The next ctxt is removed from the "receive" lists. + * + * - If the ctxt completes a Read, then finish assembling the Call + * message and return the number of bytes in the message. + * + * - If the ctxt completes a Receive, then construct the Call + * message from the contents of the Receive buffer. 
+ *
+ *   - If there are no Read chunks in this message, then finish
+ *     assembling the Call message and return the number of bytes
+ *     in the message.
+ *
+ *   - If there are Read chunks in this message, post Read WRs to
+ *     pull that payload and return 0.
+ */
+int svc_rdma_recvfrom(struct svc_rqst *rqstp)
+{
+	struct svc_xprt *xprt = rqstp->rq_xprt;
+	struct svcxprt_rdma *rdma_xprt =
+		container_of(xprt, struct svcxprt_rdma, sc_xprt);
+	struct svc_rdma_recv_ctxt *ctxt;
+	int ret;
+
+	/* Prevent svc_xprt_release() from releasing pages in rq_pages
+	 * when returning 0 or an error.
+	 */
+	rqstp->rq_respages = rqstp->rq_pages;
+	rqstp->rq_next_page = rqstp->rq_respages;
+
+	rqstp->rq_xprt_ctxt = NULL;
+
+	ctxt = NULL;
+	spin_lock(&rdma_xprt->sc_rq_dto_lock);
+	ctxt = svc_rdma_next_recv_ctxt(&rdma_xprt->sc_rq_dto_q);
+	if (ctxt)
+		list_del(&ctxt->rc_list);
+	else
+		/* No new incoming requests, terminate the loop */
+		clear_bit(XPT_DATA, &xprt->xpt_flags);
+	spin_unlock(&rdma_xprt->sc_rq_dto_lock);
+
+	/* Unblock the transport for the next receive */
+	svc_xprt_received(xprt);
+	if (!ctxt)
+		return 0;
+
+	percpu_counter_inc(&svcrdma_stat_recv);
+	ib_dma_sync_single_for_cpu(rdma_xprt->sc_pd->device,
+				   ctxt->rc_recv_sge.addr, ctxt->rc_byte_len,
+				   DMA_FROM_DEVICE);
+	svc_rdma_build_arg_xdr(rqstp, ctxt);
+
+	ret = svc_rdma_xdr_decode_req(&rqstp->rq_arg, ctxt);
+	if (ret < 0)
+		goto out_err;
+	if (ret == 0)
+		goto out_drop;
+
+	if (svc_rdma_is_reverse_direction_reply(xprt, ctxt))
+		goto out_backchannel;
+
+	svc_rdma_get_inv_rkey(rdma_xprt, ctxt);
+
+	if (!pcl_is_empty(&ctxt->rc_read_pcl) ||
+	    !pcl_is_empty(&ctxt->rc_call_pcl)) {
+		ret = svc_rdma_process_read_list(rdma_xprt, rqstp, ctxt);
+		if (ret < 0)
+			goto out_readfail;
+	}
+
+	rqstp->rq_xprt_ctxt = ctxt;
+	rqstp->rq_prot = IPPROTO_MAX;
+	svc_xprt_copy_addrs(rqstp, xprt);
+	return rqstp->rq_arg.len;
+
+out_err:
+	svc_rdma_send_error(rdma_xprt, ctxt, ret);
+	svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+	return 0;
+
+out_readfail:
+	if (ret == -EINVAL)
+		svc_rdma_send_error(rdma_xprt, ctxt, ret);
+	svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+	svc_xprt_deferred_close(xprt);
+	return -ENOTCONN;
+
+out_backchannel:
+	svc_rdma_handle_bc_reply(rqstp, ctxt);
+out_drop:
+	svc_rdma_recv_ctxt_put(rdma_xprt, ctxt);
+	return 0;
+}
diff --git a/net/sunrpc/xprtrdma/svc_rdma_rw.c b/net/sunrpc/xprtrdma/svc_rdma_rw.c
new file mode 100644
index 000000000..11cf7c646
--- /dev/null
+++ b/net/sunrpc/xprtrdma/svc_rdma_rw.c
@@ -0,0 +1,1165 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (c) 2016-2018 Oracle. All rights reserved.
+ *
+ * Use the core R/W API to move RPC-over-RDMA Read and Write chunks.
+ */
+
+#include <rdma/rw.h>
+
+#include <linux/sunrpc/xdr.h>
+#include <linux/sunrpc/rpc_rdma.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc);
+static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc);
+
+/* Each R/W context contains state for one chain of RDMA Read or
+ * Write Work Requests.
+ *
+ * Each WR chain handles a single contiguous server-side buffer,
+ * because scatterlist entries after the first have to start on
+ * page alignment. xdr_buf iovecs cannot guarantee alignment.
+ *
+ * Each WR chain handles only one R_key. Each RPC-over-RDMA segment
+ * from a client may contain a unique R_key, so each WR chain moves
+ * up to one segment at a time.
+ *
+ * The scatterlist makes this data structure over 4KB in size.
To + * make it less likely to fail, and to handle the allocation for + * smaller I/O requests without disabling bottom-halves, these + * contexts are created on demand, but cached and reused until the + * controlling svcxprt_rdma is destroyed. + */ +struct svc_rdma_rw_ctxt { + struct llist_node rw_node; + struct list_head rw_list; + struct rdma_rw_ctx rw_ctx; + unsigned int rw_nents; + struct sg_table rw_sg_table; + struct scatterlist rw_first_sgl[]; +}; + +static inline struct svc_rdma_rw_ctxt * +svc_rdma_next_ctxt(struct list_head *list) +{ + return list_first_entry_or_null(list, struct svc_rdma_rw_ctxt, + rw_list); +} + +static struct svc_rdma_rw_ctxt * +svc_rdma_get_rw_ctxt(struct svcxprt_rdma *rdma, unsigned int sges) +{ + struct svc_rdma_rw_ctxt *ctxt; + struct llist_node *node; + + spin_lock(&rdma->sc_rw_ctxt_lock); + node = llist_del_first(&rdma->sc_rw_ctxts); + spin_unlock(&rdma->sc_rw_ctxt_lock); + if (node) { + ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node); + } else { + ctxt = kmalloc(struct_size(ctxt, rw_first_sgl, SG_CHUNK_SIZE), + GFP_KERNEL); + if (!ctxt) + goto out_noctx; + + INIT_LIST_HEAD(&ctxt->rw_list); + } + + ctxt->rw_sg_table.sgl = ctxt->rw_first_sgl; + if (sg_alloc_table_chained(&ctxt->rw_sg_table, sges, + ctxt->rw_sg_table.sgl, + SG_CHUNK_SIZE)) + goto out_free; + return ctxt; + +out_free: + kfree(ctxt); +out_noctx: + trace_svcrdma_no_rwctx_err(rdma, sges); + return NULL; +} + +static void __svc_rdma_put_rw_ctxt(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt, + struct llist_head *list) +{ + sg_free_table_chained(&ctxt->rw_sg_table, SG_CHUNK_SIZE); + llist_add(&ctxt->rw_node, list); +} + +static void svc_rdma_put_rw_ctxt(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt) +{ + __svc_rdma_put_rw_ctxt(rdma, ctxt, &rdma->sc_rw_ctxts); +} + +/** + * svc_rdma_destroy_rw_ctxts - Free accumulated R/W contexts + * @rdma: transport about to be destroyed + * + */ +void svc_rdma_destroy_rw_ctxts(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_rw_ctxt *ctxt; + struct llist_node *node; + + while ((node = llist_del_first(&rdma->sc_rw_ctxts)) != NULL) { + ctxt = llist_entry(node, struct svc_rdma_rw_ctxt, rw_node); + kfree(ctxt); + } +} + +/** + * svc_rdma_rw_ctx_init - Prepare a R/W context for I/O + * @rdma: controlling transport instance + * @ctxt: R/W context to prepare + * @offset: RDMA offset + * @handle: RDMA tag/handle + * @direction: I/O direction + * + * Returns on success, the number of WQEs that will be needed + * on the workqueue, or a negative errno. + */ +static int svc_rdma_rw_ctx_init(struct svcxprt_rdma *rdma, + struct svc_rdma_rw_ctxt *ctxt, + u64 offset, u32 handle, + enum dma_data_direction direction) +{ + int ret; + + ret = rdma_rw_ctx_init(&ctxt->rw_ctx, rdma->sc_qp, rdma->sc_port_num, + ctxt->rw_sg_table.sgl, ctxt->rw_nents, + 0, offset, handle, direction); + if (unlikely(ret < 0)) { + svc_rdma_put_rw_ctxt(rdma, ctxt); + trace_svcrdma_dma_map_rw_err(rdma, ctxt->rw_nents, ret); + } + return ret; +} + +/* A chunk context tracks all I/O for moving one Read or Write + * chunk. This is a set of rdma_rw's that handle data movement + * for all segments of one chunk. + * + * These are small, acquired with a single allocator call, and + * no more than one is needed per chunk. They are allocated on + * demand, and not cached. 
+ */ +struct svc_rdma_chunk_ctxt { + struct rpc_rdma_cid cc_cid; + struct ib_cqe cc_cqe; + struct svcxprt_rdma *cc_rdma; + struct list_head cc_rwctxts; + ktime_t cc_posttime; + int cc_sqecount; + enum ib_wc_status cc_status; + struct completion cc_done; +}; + +static void svc_rdma_cc_cid_init(struct svcxprt_rdma *rdma, + struct rpc_rdma_cid *cid) +{ + cid->ci_queue_id = rdma->sc_sq_cq->res.id; + cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids); +} + +static void svc_rdma_cc_init(struct svcxprt_rdma *rdma, + struct svc_rdma_chunk_ctxt *cc) +{ + svc_rdma_cc_cid_init(rdma, &cc->cc_cid); + cc->cc_rdma = rdma; + + INIT_LIST_HEAD(&cc->cc_rwctxts); + cc->cc_sqecount = 0; +} + +/* + * The consumed rw_ctx's are cleaned and placed on a local llist so + * that only one atomic llist operation is needed to put them all + * back on the free list. + */ +static void svc_rdma_cc_release(struct svc_rdma_chunk_ctxt *cc, + enum dma_data_direction dir) +{ + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct llist_node *first, *last; + struct svc_rdma_rw_ctxt *ctxt; + LLIST_HEAD(free); + + first = last = NULL; + while ((ctxt = svc_rdma_next_ctxt(&cc->cc_rwctxts)) != NULL) { + list_del(&ctxt->rw_list); + + rdma_rw_ctx_destroy(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, ctxt->rw_sg_table.sgl, + ctxt->rw_nents, dir); + __svc_rdma_put_rw_ctxt(rdma, ctxt, &free); + + ctxt->rw_node.next = first; + first = &ctxt->rw_node; + if (!last) + last = first; + } + if (first) + llist_add_batch(first, last, &rdma->sc_rw_ctxts); +} + +/* State for sending a Write or Reply chunk. + * - Tracks progress of writing one chunk over all its segments + * - Stores arguments for the SGL constructor functions + */ +struct svc_rdma_write_info { + const struct svc_rdma_chunk *wi_chunk; + + /* write state of this chunk */ + unsigned int wi_seg_off; + unsigned int wi_seg_no; + + /* SGL constructor arguments */ + const struct xdr_buf *wi_xdr; + unsigned char *wi_base; + unsigned int wi_next_off; + + struct svc_rdma_chunk_ctxt wi_cc; +}; + +static struct svc_rdma_write_info * +svc_rdma_write_info_alloc(struct svcxprt_rdma *rdma, + const struct svc_rdma_chunk *chunk) +{ + struct svc_rdma_write_info *info; + + info = kmalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return info; + + info->wi_chunk = chunk; + info->wi_seg_off = 0; + info->wi_seg_no = 0; + svc_rdma_cc_init(rdma, &info->wi_cc); + info->wi_cc.cc_cqe.done = svc_rdma_write_done; + return info; +} + +static void svc_rdma_write_info_free(struct svc_rdma_write_info *info) +{ + svc_rdma_cc_release(&info->wi_cc, DMA_TO_DEVICE); + kfree(info); +} + +/** + * svc_rdma_write_done - Write chunk completion + * @cq: controlling Completion Queue + * @wc: Work Completion + * + * Pages under I/O are freed by a subsequent Send completion. 
+ */ +static void svc_rdma_write_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct svc_rdma_write_info *info = + container_of(cc, struct svc_rdma_write_info, wi_cc); + + switch (wc->status) { + case IB_WC_SUCCESS: + trace_svcrdma_wc_write(wc, &cc->cc_cid); + break; + case IB_WC_WR_FLUSH_ERR: + trace_svcrdma_wc_write_flush(wc, &cc->cc_cid); + break; + default: + trace_svcrdma_wc_write_err(wc, &cc->cc_cid); + } + + svc_rdma_wake_send_waiters(rdma, cc->cc_sqecount); + + if (unlikely(wc->status != IB_WC_SUCCESS)) + svc_xprt_deferred_close(&rdma->sc_xprt); + + svc_rdma_write_info_free(info); +} + +/* State for pulling a Read chunk. + */ +struct svc_rdma_read_info { + struct svc_rqst *ri_rqst; + struct svc_rdma_recv_ctxt *ri_readctxt; + unsigned int ri_pageno; + unsigned int ri_pageoff; + unsigned int ri_totalbytes; + + struct svc_rdma_chunk_ctxt ri_cc; +}; + +static struct svc_rdma_read_info * +svc_rdma_read_info_alloc(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_read_info *info; + + info = kmalloc(sizeof(*info), GFP_KERNEL); + if (!info) + return info; + + svc_rdma_cc_init(rdma, &info->ri_cc); + info->ri_cc.cc_cqe.done = svc_rdma_wc_read_done; + return info; +} + +static void svc_rdma_read_info_free(struct svc_rdma_read_info *info) +{ + svc_rdma_cc_release(&info->ri_cc, DMA_FROM_DEVICE); + kfree(info); +} + +/** + * svc_rdma_wc_read_done - Handle completion of an RDMA Read ctx + * @cq: controlling Completion Queue + * @wc: Work Completion + * + */ +static void svc_rdma_wc_read_done(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_chunk_ctxt *cc = + container_of(cqe, struct svc_rdma_chunk_ctxt, cc_cqe); + struct svc_rdma_read_info *info; + + switch (wc->status) { + case IB_WC_SUCCESS: + info = container_of(cc, struct svc_rdma_read_info, ri_cc); + trace_svcrdma_wc_read(wc, &cc->cc_cid, info->ri_totalbytes, + cc->cc_posttime); + break; + case IB_WC_WR_FLUSH_ERR: + trace_svcrdma_wc_read_flush(wc, &cc->cc_cid); + break; + default: + trace_svcrdma_wc_read_err(wc, &cc->cc_cid); + } + + svc_rdma_wake_send_waiters(cc->cc_rdma, cc->cc_sqecount); + cc->cc_status = wc->status; + complete(&cc->cc_done); + return; +} + +/* This function sleeps when the transport's Send Queue is congested. + * + * Assumptions: + * - If ib_post_send() succeeds, only one completion is expected, + * even if one or more WRs are flushed. This is true when posting + * an rdma_rw_ctx or when posting a single signaled WR. 
+ */ +static int svc_rdma_post_chunk_ctxt(struct svc_rdma_chunk_ctxt *cc) +{ + struct svcxprt_rdma *rdma = cc->cc_rdma; + struct ib_send_wr *first_wr; + const struct ib_send_wr *bad_wr; + struct list_head *tmp; + struct ib_cqe *cqe; + int ret; + + if (cc->cc_sqecount > rdma->sc_sq_depth) + return -EINVAL; + + first_wr = NULL; + cqe = &cc->cc_cqe; + list_for_each(tmp, &cc->cc_rwctxts) { + struct svc_rdma_rw_ctxt *ctxt; + + ctxt = list_entry(tmp, struct svc_rdma_rw_ctxt, rw_list); + first_wr = rdma_rw_ctx_wrs(&ctxt->rw_ctx, rdma->sc_qp, + rdma->sc_port_num, cqe, first_wr); + cqe = NULL; + } + + do { + if (atomic_sub_return(cc->cc_sqecount, + &rdma->sc_sq_avail) > 0) { + cc->cc_posttime = ktime_get(); + ret = ib_post_send(rdma->sc_qp, first_wr, &bad_wr); + if (ret) + break; + return 0; + } + + percpu_counter_inc(&svcrdma_stat_sq_starve); + trace_svcrdma_sq_full(rdma); + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > cc->cc_sqecount); + trace_svcrdma_sq_retry(rdma); + } while (1); + + trace_svcrdma_sq_post_err(rdma, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + + /* If even one was posted, there will be a completion. */ + if (bad_wr != first_wr) + return 0; + + atomic_add(cc->cc_sqecount, &rdma->sc_sq_avail); + wake_up(&rdma->sc_send_wait); + return -ENOTCONN; +} + +/* Build and DMA-map an SGL that covers one kvec in an xdr_buf + */ +static void svc_rdma_vec_to_sg(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt) +{ + struct scatterlist *sg = ctxt->rw_sg_table.sgl; + + sg_set_buf(&sg[0], info->wi_base, len); + info->wi_base += len; + + ctxt->rw_nents = 1; +} + +/* Build and DMA-map an SGL that covers part of an xdr_buf's pagelist. + */ +static void svc_rdma_pagelist_to_sg(struct svc_rdma_write_info *info, + unsigned int remaining, + struct svc_rdma_rw_ctxt *ctxt) +{ + unsigned int sge_no, sge_bytes, page_off, page_no; + const struct xdr_buf *xdr = info->wi_xdr; + struct scatterlist *sg; + struct page **page; + + page_off = info->wi_next_off + xdr->page_base; + page_no = page_off >> PAGE_SHIFT; + page_off = offset_in_page(page_off); + page = xdr->pages + page_no; + info->wi_next_off += remaining; + sg = ctxt->rw_sg_table.sgl; + sge_no = 0; + do { + sge_bytes = min_t(unsigned int, remaining, + PAGE_SIZE - page_off); + sg_set_page(sg, *page, sge_bytes, page_off); + + remaining -= sge_bytes; + sg = sg_next(sg); + page_off = 0; + sge_no++; + page++; + } while (remaining); + + ctxt->rw_nents = sge_no; +} + +/* Construct RDMA Write WRs to send a portion of an xdr_buf containing + * an RPC Reply. 
+ */ +static int +svc_rdma_build_writes(struct svc_rdma_write_info *info, + void (*constructor)(struct svc_rdma_write_info *info, + unsigned int len, + struct svc_rdma_rw_ctxt *ctxt), + unsigned int remaining) +{ + struct svc_rdma_chunk_ctxt *cc = &info->wi_cc; + struct svcxprt_rdma *rdma = cc->cc_rdma; + const struct svc_rdma_segment *seg; + struct svc_rdma_rw_ctxt *ctxt; + int ret; + + do { + unsigned int write_len; + u64 offset; + + if (info->wi_seg_no >= info->wi_chunk->ch_segcount) + goto out_overflow; + + seg = &info->wi_chunk->ch_segments[info->wi_seg_no]; + write_len = min(remaining, seg->rs_length - info->wi_seg_off); + if (!write_len) + goto out_overflow; + ctxt = svc_rdma_get_rw_ctxt(rdma, + (write_len >> PAGE_SHIFT) + 2); + if (!ctxt) + return -ENOMEM; + + constructor(info, write_len, ctxt); + offset = seg->rs_offset + info->wi_seg_off; + ret = svc_rdma_rw_ctx_init(rdma, ctxt, offset, seg->rs_handle, + DMA_TO_DEVICE); + if (ret < 0) + return -EIO; + percpu_counter_inc(&svcrdma_stat_write); + + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + if (write_len == seg->rs_length - info->wi_seg_off) { + info->wi_seg_no++; + info->wi_seg_off = 0; + } else { + info->wi_seg_off += write_len; + } + remaining -= write_len; + } while (remaining); + + return 0; + +out_overflow: + trace_svcrdma_small_wrch_err(rdma, remaining, info->wi_seg_no, + info->wi_chunk->ch_segcount); + return -E2BIG; +} + +/** + * svc_rdma_iov_write - Construct RDMA Writes from an iov + * @info: pointer to write arguments + * @iov: kvec to write + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_iov_write(struct svc_rdma_write_info *info, + const struct kvec *iov) +{ + info->wi_base = iov->iov_base; + return svc_rdma_build_writes(info, svc_rdma_vec_to_sg, + iov->iov_len); +} + +/** + * svc_rdma_pages_write - Construct RDMA Writes from pages + * @info: pointer to write arguments + * @xdr: xdr_buf with pages to write + * @offset: offset into the content of @xdr + * @length: number of bytes to write + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_pages_write(struct svc_rdma_write_info *info, + const struct xdr_buf *xdr, + unsigned int offset, + unsigned long length) +{ + info->wi_xdr = xdr; + info->wi_next_off = offset - xdr->head[0].iov_len; + return svc_rdma_build_writes(info, svc_rdma_pagelist_to_sg, + length); +} + +/** + * svc_rdma_xb_write - Construct RDMA Writes to write an xdr_buf + * @xdr: xdr_buf to write + * @data: pointer to write arguments + * + * Returns: + * On success, returns zero + * %-E2BIG if the client-provided Write chunk is too small + * %-ENOMEM if a resource has been exhausted + * %-EIO if an rdma-rw error occurred + */ +static int svc_rdma_xb_write(const struct xdr_buf *xdr, void *data) +{ + struct svc_rdma_write_info *info = data; + int ret; + + if (xdr->head[0].iov_len) { + ret = svc_rdma_iov_write(info, &xdr->head[0]); + if (ret < 0) + return ret; + } + + if (xdr->page_len) { + ret = svc_rdma_pages_write(info, xdr, xdr->head[0].iov_len, + xdr->page_len); + if (ret < 0) + return ret; + } + + if (xdr->tail[0].iov_len) { + ret = svc_rdma_iov_write(info, &xdr->tail[0]); + if (ret < 0) + return ret; + } + + return xdr->len; +} + +/** + * 
svc_rdma_send_write_chunk - Write all segments in a Write chunk + * @rdma: controlling RDMA transport + * @chunk: Write chunk provided by the client + * @xdr: xdr_buf containing the data payload + * + * Returns a non-negative number of bytes the chunk consumed, or + * %-E2BIG if the payload was larger than the Write chunk, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_send_write_chunk(struct svcxprt_rdma *rdma, + const struct svc_rdma_chunk *chunk, + const struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info; + struct svc_rdma_chunk_ctxt *cc; + int ret; + + info = svc_rdma_write_info_alloc(rdma, chunk); + if (!info) + return -ENOMEM; + cc = &info->wi_cc; + + ret = svc_rdma_xb_write(xdr, info); + if (ret != xdr->len) + goto out_err; + + trace_svcrdma_post_write_chunk(&cc->cc_cid, cc->cc_sqecount); + ret = svc_rdma_post_chunk_ctxt(cc); + if (ret < 0) + goto out_err; + return xdr->len; + +out_err: + svc_rdma_write_info_free(info); + return ret; +} + +/** + * svc_rdma_send_reply_chunk - Write all segments in the Reply chunk + * @rdma: controlling RDMA transport + * @rctxt: Write and Reply chunks from client + * @xdr: xdr_buf containing an RPC Reply + * + * Returns a non-negative number of bytes the chunk consumed, or + * %-E2BIG if the payload was larger than the Reply chunk, + * %-EINVAL if client provided too many segments, + * %-ENOMEM if rdma_rw context pool was exhausted, + * %-ENOTCONN if posting failed (connection is lost), + * %-EIO if rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_send_reply_chunk(struct svcxprt_rdma *rdma, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) +{ + struct svc_rdma_write_info *info; + struct svc_rdma_chunk_ctxt *cc; + struct svc_rdma_chunk *chunk; + int ret; + + if (pcl_is_empty(&rctxt->rc_reply_pcl)) + return 0; + + chunk = pcl_first_chunk(&rctxt->rc_reply_pcl); + info = svc_rdma_write_info_alloc(rdma, chunk); + if (!info) + return -ENOMEM; + cc = &info->wi_cc; + + ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr, + svc_rdma_xb_write, info); + if (ret < 0) + goto out_err; + + trace_svcrdma_post_reply_chunk(&cc->cc_cid, cc->cc_sqecount); + ret = svc_rdma_post_chunk_ctxt(cc); + if (ret < 0) + goto out_err; + + return xdr->len; + +out_err: + svc_rdma_write_info_free(info); + return ret; +} + +/** + * svc_rdma_build_read_segment - Build RDMA Read WQEs to pull one RDMA segment + * @info: context for ongoing I/O + * @segment: co-ordinates of remote memory to be read + * + * Returns: + * %0: the Read WR chain was constructed successfully + * %-EINVAL: there were not enough rq_pages to finish + * %-ENOMEM: allocating a local resources failed + * %-EIO: a DMA mapping error occurred + */ +static int svc_rdma_build_read_segment(struct svc_rdma_read_info *info, + const struct svc_rdma_segment *segment) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + struct svc_rdma_chunk_ctxt *cc = &info->ri_cc; + struct svc_rqst *rqstp = info->ri_rqst; + unsigned int sge_no, seg_len, len; + struct svc_rdma_rw_ctxt *ctxt; + struct scatterlist *sg; + int ret; + + len = segment->rs_length; + sge_no = PAGE_ALIGN(info->ri_pageoff + len) >> PAGE_SHIFT; + ctxt = svc_rdma_get_rw_ctxt(cc->cc_rdma, sge_no); + if (!ctxt) + return -ENOMEM; + ctxt->rw_nents = sge_no; + + sg = ctxt->rw_sg_table.sgl; + for (sge_no = 0; sge_no < 
ctxt->rw_nents; sge_no++) { + seg_len = min_t(unsigned int, len, + PAGE_SIZE - info->ri_pageoff); + + if (!info->ri_pageoff) + head->rc_page_count++; + + sg_set_page(sg, rqstp->rq_pages[info->ri_pageno], + seg_len, info->ri_pageoff); + sg = sg_next(sg); + + info->ri_pageoff += seg_len; + if (info->ri_pageoff == PAGE_SIZE) { + info->ri_pageno++; + info->ri_pageoff = 0; + } + len -= seg_len; + + /* Safety check */ + if (len && + &rqstp->rq_pages[info->ri_pageno + 1] > rqstp->rq_page_end) + goto out_overrun; + } + + ret = svc_rdma_rw_ctx_init(cc->cc_rdma, ctxt, segment->rs_offset, + segment->rs_handle, DMA_FROM_DEVICE); + if (ret < 0) + return -EIO; + percpu_counter_inc(&svcrdma_stat_read); + + list_add(&ctxt->rw_list, &cc->cc_rwctxts); + cc->cc_sqecount += ret; + return 0; + +out_overrun: + trace_svcrdma_page_overrun_err(cc->cc_rdma, rqstp, info->ri_pageno); + return -EINVAL; +} + +/** + * svc_rdma_build_read_chunk - Build RDMA Read WQEs to pull one RDMA chunk + * @info: context for ongoing I/O + * @chunk: Read chunk to pull + * + * Return values: + * %0: the Read WR chain was constructed successfully + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: allocating a local resources failed + * %-EIO: a DMA mapping error occurred + */ +static int svc_rdma_build_read_chunk(struct svc_rdma_read_info *info, + const struct svc_rdma_chunk *chunk) +{ + const struct svc_rdma_segment *segment; + int ret; + + ret = -EINVAL; + pcl_for_each_segment(segment, chunk) { + ret = svc_rdma_build_read_segment(info, segment); + if (ret < 0) + break; + info->ri_totalbytes += segment->rs_length; + } + return ret; +} + +/** + * svc_rdma_copy_inline_range - Copy part of the inline content into pages + * @info: context for RDMA Reads + * @offset: offset into the Receive buffer of region to copy + * @remaining: length of region to copy + * + * Take a page at a time from rqstp->rq_pages and copy the inline + * content from the Receive buffer into that page. Update + * info->ri_pageno and info->ri_pageoff so that the next RDMA Read + * result will land contiguously with the copied content. + * + * Return values: + * %0: Inline content was successfully copied + * %-EINVAL: offset or length was incorrect + */ +static int svc_rdma_copy_inline_range(struct svc_rdma_read_info *info, + unsigned int offset, + unsigned int remaining) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + unsigned char *dst, *src = head->rc_recv_buf; + struct svc_rqst *rqstp = info->ri_rqst; + unsigned int page_no, numpages; + + numpages = PAGE_ALIGN(info->ri_pageoff + remaining) >> PAGE_SHIFT; + for (page_no = 0; page_no < numpages; page_no++) { + unsigned int page_len; + + page_len = min_t(unsigned int, remaining, + PAGE_SIZE - info->ri_pageoff); + + if (!info->ri_pageoff) + head->rc_page_count++; + + dst = page_address(rqstp->rq_pages[info->ri_pageno]); + memcpy(dst + info->ri_pageno, src + offset, page_len); + + info->ri_totalbytes += page_len; + info->ri_pageoff += page_len; + if (info->ri_pageoff == PAGE_SIZE) { + info->ri_pageno++; + info->ri_pageoff = 0; + } + remaining -= page_len; + offset += page_len; + } + + return -EINVAL; +} + +/** + * svc_rdma_read_multiple_chunks - Construct RDMA Reads to pull data item Read chunks + * @info: context for RDMA Reads + * + * The chunk data lands in rqstp->rq_arg as a series of contiguous pages, + * like an incoming TCP call. 
+ * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static noinline int svc_rdma_read_multiple_chunks(struct svc_rdma_read_info *info) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + const struct svc_rdma_pcl *pcl = &head->rc_read_pcl; + struct xdr_buf *buf = &info->ri_rqst->rq_arg; + struct svc_rdma_chunk *chunk, *next; + unsigned int start, length; + int ret; + + start = 0; + chunk = pcl_first_chunk(pcl); + length = chunk->ch_position; + ret = svc_rdma_copy_inline_range(info, start, length); + if (ret < 0) + return ret; + + pcl_for_each_chunk(chunk, pcl) { + ret = svc_rdma_build_read_chunk(info, chunk); + if (ret < 0) + return ret; + + next = pcl_next_chunk(pcl, chunk); + if (!next) + break; + + start += length; + length = next->ch_position - info->ri_totalbytes; + ret = svc_rdma_copy_inline_range(info, start, length); + if (ret < 0) + return ret; + } + + start += length; + length = head->rc_byte_len - start; + ret = svc_rdma_copy_inline_range(info, start, length); + if (ret < 0) + return ret; + + buf->len += info->ri_totalbytes; + buf->buflen += info->ri_totalbytes; + + buf->head[0].iov_base = page_address(info->ri_rqst->rq_pages[0]); + buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, info->ri_totalbytes); + buf->pages = &info->ri_rqst->rq_pages[1]; + buf->page_len = info->ri_totalbytes - buf->head[0].iov_len; + return 0; +} + +/** + * svc_rdma_read_data_item - Construct RDMA Reads to pull data item Read chunks + * @info: context for RDMA Reads + * + * The chunk data lands in the page list of rqstp->rq_arg.pages. + * + * Currently NFSD does not look at the rqstp->rq_arg.tail[0] kvec. + * Therefore, XDR round-up of the Read chunk and trailing + * inline content must both be added at the end of the pagelist. + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_data_item(struct svc_rdma_read_info *info) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + struct xdr_buf *buf = &info->ri_rqst->rq_arg; + struct svc_rdma_chunk *chunk; + unsigned int length; + int ret; + + chunk = pcl_first_chunk(&head->rc_read_pcl); + ret = svc_rdma_build_read_chunk(info, chunk); + if (ret < 0) + goto out; + + /* Split the Receive buffer between the head and tail + * buffers at Read chunk's position. XDR roundup of the + * chunk is not included in either the pagelist or in + * the tail. + */ + buf->tail[0].iov_base = buf->head[0].iov_base + chunk->ch_position; + buf->tail[0].iov_len = buf->head[0].iov_len - chunk->ch_position; + buf->head[0].iov_len = chunk->ch_position; + + /* Read chunk may need XDR roundup (see RFC 8166, s. 3.4.5.2). + * + * If the client already rounded up the chunk length, the + * length does not change. Otherwise, the length of the page + * list is increased to include XDR round-up. + * + * Currently these chunks always start at page offset 0, + * thus the rounded-up length never crosses a page boundary. 
+ */ + buf->pages = &info->ri_rqst->rq_pages[0]; + length = xdr_align_size(chunk->ch_length); + buf->page_len = length; + buf->len += length; + buf->buflen += length; + +out: + return ret; +} + +/** + * svc_rdma_read_chunk_range - Build RDMA Read WQEs for portion of a chunk + * @info: context for RDMA Reads + * @chunk: parsed Call chunk to pull + * @offset: offset of region to pull + * @length: length of region to pull + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_chunk_range(struct svc_rdma_read_info *info, + const struct svc_rdma_chunk *chunk, + unsigned int offset, unsigned int length) +{ + const struct svc_rdma_segment *segment; + int ret; + + ret = -EINVAL; + pcl_for_each_segment(segment, chunk) { + struct svc_rdma_segment dummy; + + if (offset > segment->rs_length) { + offset -= segment->rs_length; + continue; + } + + dummy.rs_handle = segment->rs_handle; + dummy.rs_length = min_t(u32, length, segment->rs_length) - offset; + dummy.rs_offset = segment->rs_offset + offset; + + ret = svc_rdma_build_read_segment(info, &dummy); + if (ret < 0) + break; + + info->ri_totalbytes += dummy.rs_length; + length -= dummy.rs_length; + offset = 0; + } + return ret; +} + +/** + * svc_rdma_read_call_chunk - Build RDMA Read WQEs to pull a Long Message + * @info: context for RDMA Reads + * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: there were not enough resources to finish + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static int svc_rdma_read_call_chunk(struct svc_rdma_read_info *info) +{ + struct svc_rdma_recv_ctxt *head = info->ri_readctxt; + const struct svc_rdma_chunk *call_chunk = + pcl_first_chunk(&head->rc_call_pcl); + const struct svc_rdma_pcl *pcl = &head->rc_read_pcl; + struct svc_rdma_chunk *chunk, *next; + unsigned int start, length; + int ret; + + if (pcl_is_empty(pcl)) + return svc_rdma_build_read_chunk(info, call_chunk); + + start = 0; + chunk = pcl_first_chunk(pcl); + length = chunk->ch_position; + ret = svc_rdma_read_chunk_range(info, call_chunk, start, length); + if (ret < 0) + return ret; + + pcl_for_each_chunk(chunk, pcl) { + ret = svc_rdma_build_read_chunk(info, chunk); + if (ret < 0) + return ret; + + next = pcl_next_chunk(pcl, chunk); + if (!next) + break; + + start += length; + length = next->ch_position - info->ri_totalbytes; + ret = svc_rdma_read_chunk_range(info, call_chunk, + start, length); + if (ret < 0) + return ret; + } + + start += length; + length = call_chunk->ch_length - start; + return svc_rdma_read_chunk_range(info, call_chunk, start, length); +} + +/** + * svc_rdma_read_special - Build RDMA Read WQEs to pull a Long Message + * @info: context for RDMA Reads + * + * The start of the data lands in the first page just after the + * Transport header, and the rest lands in rqstp->rq_arg.pages. + * + * Assumptions: + * - A PZRC is never sent in an RDMA_MSG message, though it's + * allowed by spec. 
+ * + * Return values: + * %0: RDMA Read WQEs were successfully built + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +static noinline int svc_rdma_read_special(struct svc_rdma_read_info *info) +{ + struct xdr_buf *buf = &info->ri_rqst->rq_arg; + int ret; + + ret = svc_rdma_read_call_chunk(info); + if (ret < 0) + goto out; + + buf->len += info->ri_totalbytes; + buf->buflen += info->ri_totalbytes; + + buf->head[0].iov_base = page_address(info->ri_rqst->rq_pages[0]); + buf->head[0].iov_len = min_t(size_t, PAGE_SIZE, info->ri_totalbytes); + buf->pages = &info->ri_rqst->rq_pages[1]; + buf->page_len = info->ri_totalbytes - buf->head[0].iov_len; + +out: + return ret; +} + +/** + * svc_rdma_process_read_list - Pull list of Read chunks from the client + * @rdma: controlling RDMA transport + * @rqstp: set of pages to use as Read sink buffers + * @head: pages under I/O collect here + * + * The RPC/RDMA protocol assumes that the upper layer's XDR decoders + * pull each Read chunk as they decode an incoming RPC message. + * + * On Linux, however, the server needs to have a fully-constructed RPC + * message in rqstp->rq_arg when there is a positive return code from + * ->xpo_recvfrom. So the Read list is safety-checked immediately when + * it is received, then here the whole Read list is pulled all at once. + * The ingress RPC message is fully reconstructed once all associated + * RDMA Reads have completed. + * + * Return values: + * %1: all needed RDMA Reads were posted successfully, + * %-EINVAL: client provided too many chunks or segments, + * %-ENOMEM: rdma_rw context pool was exhausted, + * %-ENOTCONN: posting failed (connection is lost), + * %-EIO: rdma_rw initialization failed (DMA mapping, etc). + */ +int svc_rdma_process_read_list(struct svcxprt_rdma *rdma, + struct svc_rqst *rqstp, + struct svc_rdma_recv_ctxt *head) +{ + struct svc_rdma_read_info *info; + struct svc_rdma_chunk_ctxt *cc; + int ret; + + info = svc_rdma_read_info_alloc(rdma); + if (!info) + return -ENOMEM; + cc = &info->ri_cc; + info->ri_rqst = rqstp; + info->ri_readctxt = head; + info->ri_pageno = 0; + info->ri_pageoff = 0; + info->ri_totalbytes = 0; + + if (pcl_is_empty(&head->rc_call_pcl)) { + if (head->rc_read_pcl.cl_count == 1) + ret = svc_rdma_read_data_item(info); + else + ret = svc_rdma_read_multiple_chunks(info); + } else + ret = svc_rdma_read_special(info); + if (ret < 0) + goto out_err; + + trace_svcrdma_post_read_chunk(&cc->cc_cid, cc->cc_sqecount); + init_completion(&cc->cc_done); + ret = svc_rdma_post_chunk_ctxt(cc); + if (ret < 0) + goto out_err; + + ret = 1; + wait_for_completion(&cc->cc_done); + if (cc->cc_status != IB_WC_SUCCESS) + ret = -EIO; + + /* rq_respages starts after the last arg page */ + rqstp->rq_respages = &rqstp->rq_pages[head->rc_page_count]; + rqstp->rq_next_page = rqstp->rq_respages + 1; + + /* Ensure svc_rdma_recv_ctxt_put() does not try to release pages */ + head->rc_page_count = 0; + +out_err: + svc_rdma_read_info_free(info); + return ret; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c new file mode 100644 index 000000000..22a871e6f --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -0,0 +1,1044 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2016-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. 
All rights reserved. + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker + */ + +/* Operation + * + * The main entry point is svc_rdma_sendto. This is called by the + * RPC server when an RPC Reply is ready to be transmitted to a client. + * + * The passed-in svc_rqst contains a struct xdr_buf which holds an + * XDR-encoded RPC Reply message. sendto must construct the RPC-over-RDMA + * transport header, post all Write WRs needed for this Reply, then post + * a Send WR conveying the transport header and the RPC message itself to + * the client. + * + * svc_rdma_sendto must fully transmit the Reply before returning, as + * the svc_rqst will be recycled as soon as sendto returns. Remaining + * resources referred to by the svc_rqst are also recycled at that time. + * Therefore any resources that must remain longer must be detached + * from the svc_rqst and released later. + * + * Page Management + * + * The I/O that performs Reply transmission is asynchronous, and may + * complete well after sendto returns. Thus pages under I/O must be + * removed from the svc_rqst before sendto returns. + * + * The logic here depends on Send Queue and completion ordering. Since + * the Send WR is always posted last, it will always complete last. Thus + * when it completes, it is guaranteed that all previous Write WRs have + * also completed. + * + * Write WRs are constructed and posted. Each Write segment gets its own + * svc_rdma_rw_ctxt, allowing the Write completion handler to find and + * DMA-unmap the pages under I/O for that Write segment. The Write + * completion handler does not release any pages. 
+ *
+ * When the Send WR is constructed, it also gets its own svc_rdma_send_ctxt.
+ * The ownership of all of the Reply's pages are transferred into that
+ * ctxt, the Send WR is posted, and sendto returns.
+ *
+ * The svc_rdma_send_ctxt is presented when the Send WR completes. The
+ * Send completion handler finally releases the Reply's pages.
+ *
+ * This mechanism also assumes that completions on the transport's Send
+ * Completion Queue do not run in parallel. Otherwise a Write completion
+ * and Send completion running at the same time could release pages that
+ * are still DMA-mapped.
+ *
+ * Error Handling
+ *
+ * - If the Send WR is posted successfully, it will either complete
+ *   successfully, or get flushed. Either way, the Send completion
+ *   handler releases the Reply's pages.
+ * - If the Send WR cannot be posted, the forward path releases
+ *   the Reply's pages.
+ *
+ * This handles the case, without the use of page reference counting,
+ * where two different Write segments send portions of the same page.
+ */
+
+#include <linux/spinlock.h>
+#include <asm/unaligned.h>
+
+#include <rdma/ib_verbs.h>
+#include <rdma/rdma_cm.h>
+
+#include <linux/sunrpc/debug.h>
+#include <linux/sunrpc/svc_rdma.h>
+
+#include "xprt_rdma.h"
+#include <trace/events/rpcrdma.h>
+
+static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc);
+
+static void svc_rdma_send_cid_init(struct svcxprt_rdma *rdma,
+				   struct rpc_rdma_cid *cid)
+{
+	cid->ci_queue_id = rdma->sc_sq_cq->res.id;
+	cid->ci_completion_id = atomic_inc_return(&rdma->sc_completion_ids);
+}
+
+static struct svc_rdma_send_ctxt *
+svc_rdma_send_ctxt_alloc(struct svcxprt_rdma *rdma)
+{
+	struct svc_rdma_send_ctxt *ctxt;
+	dma_addr_t addr;
+	void *buffer;
+	size_t size;
+	int i;
+
+	size = sizeof(*ctxt);
+	size += rdma->sc_max_send_sges * sizeof(struct ib_sge);
+	ctxt = kmalloc(size, GFP_KERNEL);
+	if (!ctxt)
+		goto fail0;
+	buffer = kmalloc(rdma->sc_max_req_size, GFP_KERNEL);
+	if (!buffer)
+		goto fail1;
+	addr = ib_dma_map_single(rdma->sc_pd->device, buffer,
+				 rdma->sc_max_req_size, DMA_TO_DEVICE);
+	if (ib_dma_mapping_error(rdma->sc_pd->device, addr))
+		goto fail2;
+
+	svc_rdma_send_cid_init(rdma, &ctxt->sc_cid);
+
+	ctxt->sc_send_wr.next = NULL;
+	ctxt->sc_send_wr.wr_cqe = &ctxt->sc_cqe;
+	ctxt->sc_send_wr.sg_list = ctxt->sc_sges;
+	ctxt->sc_send_wr.send_flags = IB_SEND_SIGNALED;
+	init_completion(&ctxt->sc_done);
+	ctxt->sc_cqe.done = svc_rdma_wc_send;
+	ctxt->sc_xprt_buf = buffer;
+	xdr_buf_init(&ctxt->sc_hdrbuf, ctxt->sc_xprt_buf,
+		     rdma->sc_max_req_size);
+	ctxt->sc_sges[0].addr = addr;
+
+	for (i = 0; i < rdma->sc_max_send_sges; i++)
+		ctxt->sc_sges[i].lkey = rdma->sc_pd->local_dma_lkey;
+	return ctxt;
+
+fail2:
+	kfree(buffer);
+fail1:
+	kfree(ctxt);
+fail0:
+	return NULL;
+}
+
+/**
+ * svc_rdma_send_ctxts_destroy - Release all send_ctxt's for an xprt
+ * @rdma: svcxprt_rdma being torn down
+ *
+ */
+void svc_rdma_send_ctxts_destroy(struct svcxprt_rdma *rdma)
+{
+	struct svc_rdma_send_ctxt *ctxt;
+	struct llist_node *node;
+
+	while ((node = llist_del_first(&rdma->sc_send_ctxts)) != NULL) {
+		ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node);
+		ib_dma_unmap_single(rdma->sc_pd->device,
+				    ctxt->sc_sges[0].addr,
+				    rdma->sc_max_req_size,
+				    DMA_TO_DEVICE);
+		kfree(ctxt->sc_xprt_buf);
+		kfree(ctxt);
+	}
+}
+
+/**
+ * svc_rdma_send_ctxt_get - Get a free send_ctxt
+ * @rdma: controlling svcxprt_rdma
+ *
+ * Returns a ready-to-use send_ctxt, or NULL if none are
+ * available and a fresh one cannot be allocated.
+ */ +struct svc_rdma_send_ctxt *svc_rdma_send_ctxt_get(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_send_ctxt *ctxt; + struct llist_node *node; + + spin_lock(&rdma->sc_send_lock); + node = llist_del_first(&rdma->sc_send_ctxts); + if (!node) + goto out_empty; + ctxt = llist_entry(node, struct svc_rdma_send_ctxt, sc_node); + spin_unlock(&rdma->sc_send_lock); + +out: + rpcrdma_set_xdrlen(&ctxt->sc_hdrbuf, 0); + xdr_init_encode(&ctxt->sc_stream, &ctxt->sc_hdrbuf, + ctxt->sc_xprt_buf, NULL); + + ctxt->sc_send_wr.num_sge = 0; + ctxt->sc_cur_sge_no = 0; + return ctxt; + +out_empty: + spin_unlock(&rdma->sc_send_lock); + ctxt = svc_rdma_send_ctxt_alloc(rdma); + if (!ctxt) + return NULL; + goto out; +} + +/** + * svc_rdma_send_ctxt_put - Return send_ctxt to free list + * @rdma: controlling svcxprt_rdma + * @ctxt: object to return to the free list + */ +void svc_rdma_send_ctxt_put(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *ctxt) +{ + struct ib_device *device = rdma->sc_cm_id->device; + unsigned int i; + + /* The first SGE contains the transport header, which + * remains mapped until @ctxt is destroyed. + */ + for (i = 1; i < ctxt->sc_send_wr.num_sge; i++) { + ib_dma_unmap_page(device, + ctxt->sc_sges[i].addr, + ctxt->sc_sges[i].length, + DMA_TO_DEVICE); + trace_svcrdma_dma_unmap_page(rdma, + ctxt->sc_sges[i].addr, + ctxt->sc_sges[i].length); + } + + llist_add(&ctxt->sc_node, &rdma->sc_send_ctxts); +} + +/** + * svc_rdma_wake_send_waiters - manage Send Queue accounting + * @rdma: controlling transport + * @avail: Number of additional SQEs that are now available + * + */ +void svc_rdma_wake_send_waiters(struct svcxprt_rdma *rdma, int avail) +{ + atomic_add(avail, &rdma->sc_sq_avail); + smp_mb__after_atomic(); + if (unlikely(waitqueue_active(&rdma->sc_send_wait))) + wake_up(&rdma->sc_send_wait); +} + +/** + * svc_rdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: Completion Queue context + * @wc: Work Completion object + * + * NB: The svc_xprt/svcxprt_rdma is pinned whenever it's possible that + * the Send completion handler could be running. + */ +static void svc_rdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) +{ + struct svcxprt_rdma *rdma = cq->cq_context; + struct ib_cqe *cqe = wc->wr_cqe; + struct svc_rdma_send_ctxt *ctxt = + container_of(cqe, struct svc_rdma_send_ctxt, sc_cqe); + + svc_rdma_wake_send_waiters(rdma, 1); + complete(&ctxt->sc_done); + + if (unlikely(wc->status != IB_WC_SUCCESS)) + goto flushed; + + trace_svcrdma_wc_send(wc, &ctxt->sc_cid); + return; + +flushed: + if (wc->status != IB_WC_WR_FLUSH_ERR) + trace_svcrdma_wc_send_err(wc, &ctxt->sc_cid); + else + trace_svcrdma_wc_send_flush(wc, &ctxt->sc_cid); + svc_xprt_deferred_close(&rdma->sc_xprt); +} + +/** + * svc_rdma_send - Post a single Send WR + * @rdma: transport on which to post the WR + * @ctxt: send ctxt with a Send WR ready to post + * + * Returns zero if the Send WR was posted successfully. Otherwise, a + * negative errno is returned. 
+ */ +int svc_rdma_send(struct svcxprt_rdma *rdma, struct svc_rdma_send_ctxt *ctxt) +{ + struct ib_send_wr *wr = &ctxt->sc_send_wr; + int ret; + + reinit_completion(&ctxt->sc_done); + + /* Sync the transport header buffer */ + ib_dma_sync_single_for_device(rdma->sc_pd->device, + wr->sg_list[0].addr, + wr->sg_list[0].length, + DMA_TO_DEVICE); + + /* If the SQ is full, wait until an SQ entry is available */ + while (1) { + if ((atomic_dec_return(&rdma->sc_sq_avail) < 0)) { + percpu_counter_inc(&svcrdma_stat_sq_starve); + trace_svcrdma_sq_full(rdma); + atomic_inc(&rdma->sc_sq_avail); + wait_event(rdma->sc_send_wait, + atomic_read(&rdma->sc_sq_avail) > 1); + if (test_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags)) + return -ENOTCONN; + trace_svcrdma_sq_retry(rdma); + continue; + } + + trace_svcrdma_post_send(ctxt); + ret = ib_post_send(rdma->sc_qp, wr, NULL); + if (ret) + break; + return 0; + } + + trace_svcrdma_sq_post_err(rdma, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + wake_up(&rdma->sc_send_wait); + return ret; +} + +/** + * svc_rdma_encode_read_list - Encode RPC Reply's Read chunk list + * @sctxt: Send context for the RPC Reply + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply Read list + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_read_list(struct svc_rdma_send_ctxt *sctxt) +{ + /* RPC-over-RDMA version 1 replies never have a Read list. */ + return xdr_stream_encode_item_absent(&sctxt->sc_stream); +} + +/** + * svc_rdma_encode_write_segment - Encode one Write segment + * @sctxt: Send context for the RPC Reply + * @chunk: Write chunk to push + * @remaining: remaining bytes of the payload left in the Write chunk + * @segno: which segment in the chunk + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Write segment, and updates @remaining + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_write_segment(struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_chunk *chunk, + u32 *remaining, unsigned int segno) +{ + const struct svc_rdma_segment *segment = &chunk->ch_segments[segno]; + const size_t len = rpcrdma_segment_maxsz * sizeof(__be32); + u32 length; + __be32 *p; + + p = xdr_reserve_space(&sctxt->sc_stream, len); + if (!p) + return -EMSGSIZE; + + length = min_t(u32, *remaining, segment->rs_length); + *remaining -= length; + xdr_encode_rdma_segment(p, segment->rs_handle, length, + segment->rs_offset); + trace_svcrdma_encode_wseg(sctxt, segno, segment->rs_handle, length, + segment->rs_offset); + return len; +} + +/** + * svc_rdma_encode_write_chunk - Encode one Write chunk + * @sctxt: Send context for the RPC Reply + * @chunk: Write chunk to push + * + * Copy a Write chunk from the Call transport header to the + * Reply transport header. Update each segment's length field + * to reflect the number of bytes written in that segment. 
+ * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Write chunk + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_write_chunk(struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_chunk *chunk) +{ + u32 remaining = chunk->ch_payload_length; + unsigned int segno; + ssize_t len, ret; + + len = 0; + ret = xdr_stream_encode_item_present(&sctxt->sc_stream); + if (ret < 0) + return ret; + len += ret; + + ret = xdr_stream_encode_u32(&sctxt->sc_stream, chunk->ch_segcount); + if (ret < 0) + return ret; + len += ret; + + for (segno = 0; segno < chunk->ch_segcount; segno++) { + ret = svc_rdma_encode_write_segment(sctxt, chunk, &remaining, segno); + if (ret < 0) + return ret; + len += ret; + } + + return len; +} + +/** + * svc_rdma_encode_write_list - Encode RPC Reply's Write chunk list + * @rctxt: Reply context with information about the RPC Call + * @sctxt: Send context for the RPC Reply + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply's Write list + * %-EMSGSIZE on XDR buffer overflow + */ +static ssize_t svc_rdma_encode_write_list(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_send_ctxt *sctxt) +{ + struct svc_rdma_chunk *chunk; + ssize_t len, ret; + + len = 0; + pcl_for_each_chunk(chunk, &rctxt->rc_write_pcl) { + ret = svc_rdma_encode_write_chunk(sctxt, chunk); + if (ret < 0) + return ret; + len += ret; + } + + /* Terminate the Write list */ + ret = xdr_stream_encode_item_absent(&sctxt->sc_stream); + if (ret < 0) + return ret; + + return len + ret; +} + +/** + * svc_rdma_encode_reply_chunk - Encode RPC Reply's Reply chunk + * @rctxt: Reply context with information about the RPC Call + * @sctxt: Send context for the RPC Reply + * @length: size in bytes of the payload in the Reply chunk + * + * Return values: + * On success, returns length in bytes of the Reply XDR buffer + * that was consumed by the Reply's Reply chunk + * %-EMSGSIZE on XDR buffer overflow + * %-E2BIG if the RPC message is larger than the Reply chunk + */ +static ssize_t +svc_rdma_encode_reply_chunk(struct svc_rdma_recv_ctxt *rctxt, + struct svc_rdma_send_ctxt *sctxt, + unsigned int length) +{ + struct svc_rdma_chunk *chunk; + + if (pcl_is_empty(&rctxt->rc_reply_pcl)) + return xdr_stream_encode_item_absent(&sctxt->sc_stream); + + chunk = pcl_first_chunk(&rctxt->rc_reply_pcl); + if (length > chunk->ch_length) + return -E2BIG; + + chunk->ch_payload_length = length; + return svc_rdma_encode_write_chunk(sctxt, chunk); +} + +struct svc_rdma_map_data { + struct svcxprt_rdma *md_rdma; + struct svc_rdma_send_ctxt *md_ctxt; +}; + +/** + * svc_rdma_page_dma_map - DMA map one page + * @data: pointer to arguments + * @page: struct page to DMA map + * @offset: offset into the page + * @len: number of bytes to map + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if the page cannot be DMA mapped + */ +static int svc_rdma_page_dma_map(void *data, struct page *page, + unsigned long offset, unsigned int len) +{ + struct svc_rdma_map_data *args = data; + struct svcxprt_rdma *rdma = args->md_rdma; + struct svc_rdma_send_ctxt *ctxt = args->md_ctxt; + struct ib_device *dev = rdma->sc_cm_id->device; + dma_addr_t dma_addr; + + ++ctxt->sc_cur_sge_no; + + dma_addr = ib_dma_map_page(dev, page, offset, len, DMA_TO_DEVICE); + if (ib_dma_mapping_error(dev, dma_addr)) + goto out_maperr; + + trace_svcrdma_dma_map_page(rdma, dma_addr, len); + 
ctxt->sc_sges[ctxt->sc_cur_sge_no].addr = dma_addr; + ctxt->sc_sges[ctxt->sc_cur_sge_no].length = len; + ctxt->sc_send_wr.num_sge++; + return 0; + +out_maperr: + trace_svcrdma_dma_map_err(rdma, dma_addr, len); + return -EIO; +} + +/** + * svc_rdma_iov_dma_map - DMA map an iovec + * @data: pointer to arguments + * @iov: kvec to DMA map + * + * ib_dma_map_page() is used here because svc_rdma_dma_unmap() + * handles DMA-unmap and it uses ib_dma_unmap_page() exclusively. + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if the iovec cannot be DMA mapped + */ +static int svc_rdma_iov_dma_map(void *data, const struct kvec *iov) +{ + if (!iov->iov_len) + return 0; + return svc_rdma_page_dma_map(data, virt_to_page(iov->iov_base), + offset_in_page(iov->iov_base), + iov->iov_len); +} + +/** + * svc_rdma_xb_dma_map - DMA map all segments of an xdr_buf + * @xdr: xdr_buf containing portion of an RPC message to transmit + * @data: pointer to arguments + * + * Returns: + * %0 if DMA mapping was successful + * %-EIO if DMA mapping failed + * + * On failure, any DMA mappings that have been already done must be + * unmapped by the caller. + */ +static int svc_rdma_xb_dma_map(const struct xdr_buf *xdr, void *data) +{ + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; + int ret; + + ret = svc_rdma_iov_dma_map(data, &xdr->head[0]); + if (ret < 0) + return ret; + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + + ret = svc_rdma_page_dma_map(data, *ppages++, pageoff, len); + if (ret < 0) + return ret; + + remaining -= len; + pageoff = 0; + } + + ret = svc_rdma_iov_dma_map(data, &xdr->tail[0]); + if (ret < 0) + return ret; + + return xdr->len; +} + +struct svc_rdma_pullup_data { + u8 *pd_dest; + unsigned int pd_length; + unsigned int pd_num_sges; +}; + +/** + * svc_rdma_xb_count_sges - Count how many SGEs will be needed + * @xdr: xdr_buf containing portion of an RPC message to transmit + * @data: pointer to arguments + * + * Returns: + * Number of SGEs needed to Send the contents of @xdr inline + */ +static int svc_rdma_xb_count_sges(const struct xdr_buf *xdr, + void *data) +{ + struct svc_rdma_pullup_data *args = data; + unsigned int remaining; + unsigned long offset; + + if (xdr->head[0].iov_len) + ++args->pd_num_sges; + + offset = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + ++args->pd_num_sges; + remaining -= min_t(u32, PAGE_SIZE - offset, remaining); + offset = 0; + } + + if (xdr->tail[0].iov_len) + ++args->pd_num_sges; + + args->pd_length += xdr->len; + return 0; +} + +/** + * svc_rdma_pull_up_needed - Determine whether to use pull-up + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR + * @rctxt: Write and Reply chunks provided by client + * @xdr: xdr_buf containing RPC message to transmit + * + * Returns: + * %true if pull-up must be used + * %false otherwise + */ +static bool svc_rdma_pull_up_needed(const struct svcxprt_rdma *rdma, + const struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) +{ + /* Resources needed for the transport header */ + struct svc_rdma_pullup_data args = { + .pd_length = sctxt->sc_hdrbuf.len, + .pd_num_sges = 1, + }; + int ret; + + ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr, + svc_rdma_xb_count_sges, &args); + if (ret < 0) + return false; + + if (args.pd_length 
< RPCRDMA_PULLUP_THRESH) + return true; + return args.pd_num_sges >= rdma->sc_max_send_sges; +} + +/** + * svc_rdma_xb_linearize - Copy region of xdr_buf to flat buffer + * @xdr: xdr_buf containing portion of an RPC message to copy + * @data: pointer to arguments + * + * Returns: + * Always zero. + */ +static int svc_rdma_xb_linearize(const struct xdr_buf *xdr, + void *data) +{ + struct svc_rdma_pullup_data *args = data; + unsigned int len, remaining; + unsigned long pageoff; + struct page **ppages; + + if (xdr->head[0].iov_len) { + memcpy(args->pd_dest, xdr->head[0].iov_base, xdr->head[0].iov_len); + args->pd_dest += xdr->head[0].iov_len; + } + + ppages = xdr->pages + (xdr->page_base >> PAGE_SHIFT); + pageoff = offset_in_page(xdr->page_base); + remaining = xdr->page_len; + while (remaining) { + len = min_t(u32, PAGE_SIZE - pageoff, remaining); + memcpy(args->pd_dest, page_address(*ppages) + pageoff, len); + remaining -= len; + args->pd_dest += len; + pageoff = 0; + ppages++; + } + + if (xdr->tail[0].iov_len) { + memcpy(args->pd_dest, xdr->tail[0].iov_base, xdr->tail[0].iov_len); + args->pd_dest += xdr->tail[0].iov_len; + } + + args->pd_length += xdr->len; + return 0; +} + +/** + * svc_rdma_pull_up_reply_msg - Copy Reply into a single buffer + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR; xprt hdr is already prepared + * @rctxt: Write and Reply chunks provided by client + * @xdr: prepared xdr_buf containing RPC message + * + * The device is not capable of sending the reply directly. + * Assemble the elements of @xdr into the transport header buffer. + * + * Assumptions: + * pull_up_needed has determined that @xdr will fit in the buffer. + * + * Returns: + * %0 if pull-up was successful + * %-EMSGSIZE if a buffer manipulation problem occurred + */ +static int svc_rdma_pull_up_reply_msg(const struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) +{ + struct svc_rdma_pullup_data args = { + .pd_dest = sctxt->sc_xprt_buf + sctxt->sc_hdrbuf.len, + }; + int ret; + + ret = pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr, + svc_rdma_xb_linearize, &args); + if (ret < 0) + return ret; + + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len + args.pd_length; + trace_svcrdma_send_pullup(sctxt, args.pd_length); + return 0; +} + +/* svc_rdma_map_reply_msg - DMA map the buffer holding RPC message + * @rdma: controlling transport + * @sctxt: send_ctxt for the Send WR + * @rctxt: Write and Reply chunks provided by client + * @xdr: prepared xdr_buf containing RPC message + * + * Returns: + * %0 if DMA mapping was successful. + * %-EMSGSIZE if a buffer manipulation problem occurred + * %-EIO if DMA mapping failed + * + * The Send WR's num_sge field is set in all cases. + */ +int svc_rdma_map_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_recv_ctxt *rctxt, + const struct xdr_buf *xdr) +{ + struct svc_rdma_map_data args = { + .md_rdma = rdma, + .md_ctxt = sctxt, + }; + + /* Set up the (persistently-mapped) transport header SGE. */ + sctxt->sc_send_wr.num_sge = 1; + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len; + + /* If there is a Reply chunk, nothing follows the transport + * header, and we're done here. + */ + if (!pcl_is_empty(&rctxt->rc_reply_pcl)) + return 0; + + /* For pull-up, svc_rdma_send() will sync the transport header. + * No additional DMA mapping is necessary. 
+ */ + if (svc_rdma_pull_up_needed(rdma, sctxt, rctxt, xdr)) + return svc_rdma_pull_up_reply_msg(rdma, sctxt, rctxt, xdr); + + return pcl_process_nonpayloads(&rctxt->rc_write_pcl, xdr, + svc_rdma_xb_dma_map, &args); +} + +/* Prepare the portion of the RPC Reply that will be transmitted + * via RDMA Send. The RPC-over-RDMA transport header is prepared + * in sc_sges[0], and the RPC xdr_buf is prepared in following sges. + * + * Depending on whether a Write list or Reply chunk is present, + * the server may send all, a portion of, or none of the xdr_buf. + * In the latter case, only the transport header (sc_sges[0]) is + * transmitted. + * + * RDMA Send is the last step of transmitting an RPC reply. Pages + * involved in the earlier RDMA Writes are here transferred out + * of the rqstp and into the sctxt's page array. These pages are + * DMA unmapped by each Write completion, but the subsequent Send + * completion finally releases these pages. + * + * Assumptions: + * - The Reply's transport header will never be larger than a page. + */ +static int svc_rdma_send_reply_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + const struct svc_rdma_recv_ctxt *rctxt, + struct svc_rqst *rqstp) +{ + int ret; + + ret = svc_rdma_map_reply_msg(rdma, sctxt, rctxt, &rqstp->rq_res); + if (ret < 0) + return ret; + + if (rctxt->rc_inv_rkey) { + sctxt->sc_send_wr.opcode = IB_WR_SEND_WITH_INV; + sctxt->sc_send_wr.ex.invalidate_rkey = rctxt->rc_inv_rkey; + } else { + sctxt->sc_send_wr.opcode = IB_WR_SEND; + } + + ret = svc_rdma_send(rdma, sctxt); + if (ret < 0) + return ret; + + ret = wait_for_completion_killable(&sctxt->sc_done); + svc_rdma_send_ctxt_put(rdma, sctxt); + return ret; +} + +/** + * svc_rdma_send_error_msg - Send an RPC/RDMA v1 error response + * @rdma: controlling transport context + * @sctxt: Send context for the response + * @rctxt: Receive context for incoming bad message + * @status: negative errno indicating error that occurred + * + * Given the client-provided Read, Write, and Reply chunks, the + * server was not able to parse the Call or form a complete Reply. + * Return an RDMA_ERROR message so the client can retire the RPC + * transaction. + * + * The caller does not have to release @sctxt. It is released by + * Send completion, or by this function on error. + */ +void svc_rdma_send_error_msg(struct svcxprt_rdma *rdma, + struct svc_rdma_send_ctxt *sctxt, + struct svc_rdma_recv_ctxt *rctxt, + int status) +{ + __be32 *rdma_argp = rctxt->rc_recv_buf; + __be32 *p; + + rpcrdma_set_xdrlen(&sctxt->sc_hdrbuf, 0); + xdr_init_encode(&sctxt->sc_stream, &sctxt->sc_hdrbuf, + sctxt->sc_xprt_buf, NULL); + + p = xdr_reserve_space(&sctxt->sc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (!p) + goto put_ctxt; + + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = rdma->sc_fc_credits; + *p = rdma_error; + + switch (status) { + case -EPROTONOSUPPORT: + p = xdr_reserve_space(&sctxt->sc_stream, 3 * sizeof(*p)); + if (!p) + goto put_ctxt; + + *p++ = err_vers; + *p++ = rpcrdma_version; + *p = rpcrdma_version; + trace_svcrdma_err_vers(*rdma_argp); + break; + default: + p = xdr_reserve_space(&sctxt->sc_stream, sizeof(*p)); + if (!p) + goto put_ctxt; + + *p = err_chunk; + trace_svcrdma_err_chunk(*rdma_argp); + } + + /* Remote Invalidation is skipped for simplicity. 
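 *
 * As an aside (not part of this patch), the RDMA_ERROR message built above
 * is a short run of XDR words. A hypothetical user-space encoder for the
 * ERR_VERS case might look like the sketch below; the numeric values are
 * stand-ins for the kernel's constants.
 */

#include <arpa/inet.h>
#include <stddef.h>
#include <stdint.h>

enum { EX_RDMA_ERROR = 4, EX_ERR_VERS = 1, EX_RPCRDMA_VERSION = 1 };

/* Lay out xid, version, credits, proc, then the error code and the range
 * of protocol versions the server is willing to speak.
 */
static size_t example_encode_err_vers(uint32_t *buf, uint32_t wire_xid,
				      uint32_t wire_vers, uint32_t credits)
{
	size_t n = 0;

	buf[n++] = wire_xid;			/* copied from the failing Call */
	buf[n++] = wire_vers;			/* copied from the failing Call */
	buf[n++] = htonl(credits);
	buf[n++] = htonl(EX_RDMA_ERROR);
	buf[n++] = htonl(EX_ERR_VERS);
	buf[n++] = htonl(EX_RPCRDMA_VERSION);	/* lowest supported */
	buf[n++] = htonl(EX_RPCRDMA_VERSION);	/* highest supported */
	return n * sizeof(*buf);
}

/* (end of illustrative sketch)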
*/ + sctxt->sc_send_wr.num_sge = 1; + sctxt->sc_send_wr.opcode = IB_WR_SEND; + sctxt->sc_sges[0].length = sctxt->sc_hdrbuf.len; + if (svc_rdma_send(rdma, sctxt)) + goto put_ctxt; + + wait_for_completion_killable(&sctxt->sc_done); + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, sctxt); +} + +/** + * svc_rdma_sendto - Transmit an RPC reply + * @rqstp: processed RPC request, reply XDR already in ::rq_res + * + * Any resources still associated with @rqstp are released upon return. + * If no reply message was possible, the connection is closed. + * + * Returns: + * %0 if an RPC reply has been successfully posted, + * %-ENOMEM if a resource shortage occurred (connection is lost), + * %-ENOTCONN if posting failed (connection is lost). + */ +int svc_rdma_sendto(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt; + __be32 *rdma_argp = rctxt->rc_recv_buf; + struct svc_rdma_send_ctxt *sctxt; + unsigned int rc_size; + __be32 *p; + int ret; + + ret = -ENOTCONN; + if (svc_xprt_is_dead(xprt)) + goto drop_connection; + + ret = -ENOMEM; + sctxt = svc_rdma_send_ctxt_get(rdma); + if (!sctxt) + goto drop_connection; + + ret = -EMSGSIZE; + p = xdr_reserve_space(&sctxt->sc_stream, + rpcrdma_fixed_maxsz * sizeof(*p)); + if (!p) + goto put_ctxt; + + ret = svc_rdma_send_reply_chunk(rdma, rctxt, &rqstp->rq_res); + if (ret < 0) + goto reply_chunk; + rc_size = ret; + + *p++ = *rdma_argp; + *p++ = *(rdma_argp + 1); + *p++ = rdma->sc_fc_credits; + *p = pcl_is_empty(&rctxt->rc_reply_pcl) ? rdma_msg : rdma_nomsg; + + ret = svc_rdma_encode_read_list(sctxt); + if (ret < 0) + goto put_ctxt; + ret = svc_rdma_encode_write_list(rctxt, sctxt); + if (ret < 0) + goto put_ctxt; + ret = svc_rdma_encode_reply_chunk(rctxt, sctxt, rc_size); + if (ret < 0) + goto put_ctxt; + + ret = svc_rdma_send_reply_msg(rdma, sctxt, rctxt, rqstp); + if (ret < 0) + goto put_ctxt; + + /* Prevent svc_xprt_release() from releasing the page backing + * rq_res.head[0].iov_base. It's no longer being accessed by + * the I/O device. 
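 *
 * For illustration only (not part of this patch): the choice between
 * rdma_msg and rdma_nomsg made a few lines above depends on whether the
 * client supplied a Reply chunk. A hypothetical stand-alone version:
 */

#include <stdbool.h>
#include <stdint.h>

enum { EX_RDMA_MSG = 0, EX_RDMA_NOMSG = 1 };

/* With a Reply chunk, the RPC message body has already been pushed to the
 * client with RDMA Writes, so the Send carries only the transport header.
 */
static uint32_t example_choose_proc(bool reply_chunk_present)
{
	return reply_chunk_present ? EX_RDMA_NOMSG : EX_RDMA_MSG;
}

/* (end of illustrative sketch)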
*/ + rqstp->rq_respages++; + return 0; + +reply_chunk: + if (ret != -E2BIG && ret != -EINVAL) + goto put_ctxt; + + svc_rdma_send_error_msg(rdma, sctxt, rctxt, ret); + return 0; + +put_ctxt: + svc_rdma_send_ctxt_put(rdma, sctxt); +drop_connection: + trace_svcrdma_send_err(rqstp, ret); + svc_xprt_deferred_close(&rdma->sc_xprt); + return -ENOTCONN; +} + +/** + * svc_rdma_result_payload - special processing for a result payload + * @rqstp: svc_rqst to operate on + * @offset: payload's byte offset in @xdr + * @length: size of payload, in bytes + * + * Return values: + * %0 if successful or nothing needed to be done + * %-EMSGSIZE on XDR buffer overflow + * %-E2BIG if the payload was larger than the Write chunk + * %-EINVAL if client provided too many segments + * %-ENOMEM if rdma_rw context pool was exhausted + * %-ENOTCONN if posting failed (connection is lost) + * %-EIO if rdma_rw initialization failed (DMA mapping, etc) + */ +int svc_rdma_result_payload(struct svc_rqst *rqstp, unsigned int offset, + unsigned int length) +{ + struct svc_rdma_recv_ctxt *rctxt = rqstp->rq_xprt_ctxt; + struct svc_rdma_chunk *chunk; + struct svcxprt_rdma *rdma; + struct xdr_buf subbuf; + int ret; + + chunk = rctxt->rc_cur_result_payload; + if (!length || !chunk) + return 0; + rctxt->rc_cur_result_payload = + pcl_next_chunk(&rctxt->rc_write_pcl, chunk); + if (length > chunk->ch_length) + return -E2BIG; + + chunk->ch_position = offset; + chunk->ch_payload_length = length; + + if (xdr_buf_subsegment(&rqstp->rq_res, &subbuf, offset, length)) + return -EMSGSIZE; + + rdma = container_of(rqstp->rq_xprt, struct svcxprt_rdma, sc_xprt); + ret = svc_rdma_send_write_chunk(rdma, chunk, &subbuf); + if (ret < 0) + return ret; + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c new file mode 100644 index 000000000..f776f0cb4 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -0,0 +1,610 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2015-2018 Oracle. All rights reserved. + * Copyright (c) 2014 Open Grid Computing, Inc. All rights reserved. + * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include "xprt_rdma.h" +#include + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net); +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags); +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt); +static void svc_rdma_detach(struct svc_xprt *xprt); +static void svc_rdma_free(struct svc_xprt *xprt); +static int svc_rdma_has_wspace(struct svc_xprt *xprt); +static void svc_rdma_secure_port(struct svc_rqst *); +static void svc_rdma_kill_temp_xprt(struct svc_xprt *); + +static const struct svc_xprt_ops svc_rdma_ops = { + .xpo_create = svc_rdma_create, + .xpo_recvfrom = svc_rdma_recvfrom, + .xpo_sendto = svc_rdma_sendto, + .xpo_result_payload = svc_rdma_result_payload, + .xpo_release_ctxt = svc_rdma_release_ctxt, + .xpo_detach = svc_rdma_detach, + .xpo_free = svc_rdma_free, + .xpo_has_wspace = svc_rdma_has_wspace, + .xpo_accept = svc_rdma_accept, + .xpo_secure_port = svc_rdma_secure_port, + .xpo_kill_temp_xprt = svc_rdma_kill_temp_xprt, +}; + +struct svc_xprt_class svc_rdma_class = { + .xcl_name = "rdma", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_rdma_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_RDMA, + .xcl_ident = XPRT_TRANSPORT_RDMA, +}; + +/* QP event handler */ +static void qp_event_handler(struct ib_event *event, void *context) +{ + struct svc_xprt *xprt = context; + + trace_svcrdma_qp_error(event, (struct sockaddr *)&xprt->xpt_remote); + switch (event->event) { + /* These are considered benign events */ + case IB_EVENT_PATH_MIG: + case IB_EVENT_COMM_EST: + case IB_EVENT_SQ_DRAINED: + case IB_EVENT_QP_LAST_WQE_REACHED: + break; + + /* These are considered fatal events */ + case IB_EVENT_PATH_MIG_ERR: + case IB_EVENT_QP_FATAL: + case IB_EVENT_QP_REQ_ERR: + case IB_EVENT_QP_ACCESS_ERR: + case IB_EVENT_DEVICE_FATAL: + default: + svc_xprt_deferred_close(xprt); + break; + } +} + +static struct svcxprt_rdma *svc_rdma_create_xprt(struct svc_serv *serv, + struct net *net) +{ + struct svcxprt_rdma *cma_xprt = kzalloc(sizeof *cma_xprt, GFP_KERNEL); + + if (!cma_xprt) { + dprintk("svcrdma: failed to create new transport\n"); + return NULL; + } + svc_xprt_init(net, &svc_rdma_class, &cma_xprt->sc_xprt, serv); + INIT_LIST_HEAD(&cma_xprt->sc_accept_q); + INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q); + init_llist_head(&cma_xprt->sc_send_ctxts); + init_llist_head(&cma_xprt->sc_recv_ctxts); + init_llist_head(&cma_xprt->sc_rw_ctxts); + init_waitqueue_head(&cma_xprt->sc_send_wait); + + spin_lock_init(&cma_xprt->sc_lock); + spin_lock_init(&cma_xprt->sc_rq_dto_lock); + spin_lock_init(&cma_xprt->sc_send_lock); + spin_lock_init(&cma_xprt->sc_rw_ctxt_lock); + + /* + * Note that this implies that the underlying transport support + * has some form of 
congestion control (see RFC 7530 section 3.1 + * paragraph 2). For now, we assume that all supported RDMA + * transports are suitable here. + */ + set_bit(XPT_CONG_CTRL, &cma_xprt->sc_xprt.xpt_flags); + + return cma_xprt; +} + +static void +svc_rdma_parse_connect_private(struct svcxprt_rdma *newxprt, + struct rdma_conn_param *param) +{ + const struct rpcrdma_connect_private *pmsg = param->private_data; + + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + newxprt->sc_snd_w_inv = pmsg->cp_flags & + RPCRDMA_CMP_F_SND_W_INV_OK; + + dprintk("svcrdma: client send_size %u, recv_size %u " + "remote inv %ssupported\n", + rpcrdma_decode_buffer_size(pmsg->cp_send_size), + rpcrdma_decode_buffer_size(pmsg->cp_recv_size), + newxprt->sc_snd_w_inv ? "" : "un"); + } +} + +/* + * This function handles the CONNECT_REQUEST event on a listening + * endpoint. It is passed the cma_id for the _new_ connection. The context in + * this cma_id is inherited from the listening cma_id and is the svc_xprt + * structure for the listening endpoint. + * + * This function creates a new xprt for the new connection and enqueues it on + * the accept queue for the listent xprt. When the listen thread is kicked, it + * will call the recvfrom method on the listen xprt which will accept the new + * connection. + */ +static void handle_connect_req(struct rdma_cm_id *new_cma_id, + struct rdma_conn_param *param) +{ + struct svcxprt_rdma *listen_xprt = new_cma_id->context; + struct svcxprt_rdma *newxprt; + struct sockaddr *sa; + + /* Create a new transport */ + newxprt = svc_rdma_create_xprt(listen_xprt->sc_xprt.xpt_server, + listen_xprt->sc_xprt.xpt_net); + if (!newxprt) + return; + newxprt->sc_cm_id = new_cma_id; + new_cma_id->context = newxprt; + svc_rdma_parse_connect_private(newxprt, param); + + /* Save client advertised inbound read limit for use later in accept. */ + newxprt->sc_ord = param->initiator_depth; + + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + newxprt->sc_xprt.xpt_remotelen = svc_addr_len(sa); + memcpy(&newxprt->sc_xprt.xpt_remote, sa, + newxprt->sc_xprt.xpt_remotelen); + snprintf(newxprt->sc_xprt.xpt_remotebuf, + sizeof(newxprt->sc_xprt.xpt_remotebuf) - 1, "%pISc", sa); + + /* The remote port is arbitrary and not under the control of the + * client ULP. Set it to a fixed value so that the DRC continues + * to be effective after a reconnect. + */ + rpc_set_port((struct sockaddr *)&newxprt->sc_xprt.xpt_remote, 0); + + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + + /* + * Enqueue the new transport on the accept queue of the listening + * transport + */ + spin_lock(&listen_xprt->sc_lock); + list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q); + spin_unlock(&listen_xprt->sc_lock); + + set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags); + svc_xprt_enqueue(&listen_xprt->sc_xprt); +} + +/** + * svc_rdma_listen_handler - Handle CM events generated on a listening endpoint + * @cma_id: the server's listener rdma_cm_id + * @event: details of the event + * + * Return values: + * %0: Do not destroy @cma_id + * %1: Destroy @cma_id (never returned here) + * + * NB: There is never a DEVICE_REMOVAL event for INADDR_ANY listeners. 
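 *
 * As an illustration (not part of this patch), the private-data check done
 * by svc_rdma_parse_connect_private() above reduces to a magic/version
 * comparison. The structure below is a hypothetical stand-in, not the
 * on-the-wire layout.
 */

#include <stdbool.h>
#include <stdint.h>

struct example_cm_private {
	uint32_t cp_magic;
	uint8_t  cp_version;
	uint8_t  cp_flags;
	uint8_t  cp_send_size;
	uint8_t  cp_recv_size;
};

/* Missing or unrecognized private data is not an error; it simply means
 * the peer is a baseline RPC-over-RDMA v1 implementation.
 */
static bool example_private_is_valid(const struct example_cm_private *pmsg,
				     uint32_t magic, uint8_t version)
{
	return pmsg && pmsg->cp_magic == magic && pmsg->cp_version == version;
}

/* (end of illustrative sketch)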
+ */ +static int svc_rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + switch (event->event) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + handle_connect_req(cma_id, &event->param.conn); + break; + default: + break; + } + return 0; +} + +/** + * svc_rdma_cma_handler - Handle CM events on client connections + * @cma_id: the server's listener rdma_cm_id + * @event: details of the event + * + * Return values: + * %0: Do not destroy @cma_id + * %1: Destroy @cma_id (never returned here) + */ +static int svc_rdma_cma_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + struct svcxprt_rdma *rdma = cma_id->context; + struct svc_xprt *xprt = &rdma->sc_xprt; + + switch (event->event) { + case RDMA_CM_EVENT_ESTABLISHED: + clear_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags); + + /* Handle any requests that were received while + * CONN_PENDING was set. */ + svc_xprt_enqueue(xprt); + break; + case RDMA_CM_EVENT_DISCONNECTED: + case RDMA_CM_EVENT_DEVICE_REMOVAL: + svc_xprt_deferred_close(xprt); + break; + default: + break; + } + return 0; +} + +/* + * Create a listening RDMA service endpoint. + */ +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct net *net, + struct sockaddr *sa, int salen, + int flags) +{ + struct rdma_cm_id *listen_id; + struct svcxprt_rdma *cma_xprt; + int ret; + + if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6) + return ERR_PTR(-EAFNOSUPPORT); + cma_xprt = svc_rdma_create_xprt(serv, net); + if (!cma_xprt) + return ERR_PTR(-ENOMEM); + set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); + strcpy(cma_xprt->sc_xprt.xpt_remotebuf, "listener"); + + listen_id = rdma_create_id(net, svc_rdma_listen_handler, cma_xprt, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(listen_id)) { + ret = PTR_ERR(listen_id); + goto err0; + } + + /* Allow both IPv4 and IPv6 sockets to bind a single port + * at the same time. + */ +#if IS_ENABLED(CONFIG_IPV6) + ret = rdma_set_afonly(listen_id, 1); + if (ret) + goto err1; +#endif + ret = rdma_bind_addr(listen_id, sa); + if (ret) + goto err1; + cma_xprt->sc_cm_id = listen_id; + + ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); + if (ret) + goto err1; + + /* + * We need to use the address from the cm_id in case the + * caller specified 0 for the port number. + */ + sa = (struct sockaddr *)&cma_xprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen); + + return &cma_xprt->sc_xprt; + + err1: + rdma_destroy_id(listen_id); + err0: + kfree(cma_xprt); + return ERR_PTR(ret); +} + +/* + * This is the xpo_recvfrom function for listening endpoints. Its + * purpose is to accept incoming connections. The CMA callback handler + * has already created a new transport and attached it to the new CMA + * ID. + * + * There is a queue of pending connections hung on the listening + * transport. This queue contains the new svc_xprt structure. This + * function takes svc_xprt structures off the accept_q and completes + * the connection. 
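 *
 * As an illustration (not part of this patch), the dequeue pattern used
 * below can be modelled in user space: take one entry under the lock and
 * leave the "more work" flag set if the queue is still non-empty, so the
 * next pass accepts the next pending connection. All names are hypothetical.
 */

#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>

struct example_pending { struct example_pending *next; };

struct example_accept_q {
	pthread_mutex_t		lock;
	struct example_pending	*head;
	bool			work_pending;
};

static struct example_pending *example_dequeue_one(struct example_accept_q *q)
{
	struct example_pending *p;

	pthread_mutex_lock(&q->lock);
	p = q->head;
	if (p)
		q->head = p->next;
	/* Re-arm the flag so another pass handles the next connection. */
	q->work_pending = (q->head != NULL);
	pthread_mutex_unlock(&q->lock);
	return p;
}

/* (end of illustrative sketch)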
+ */ +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *listen_rdma; + struct svcxprt_rdma *newxprt = NULL; + struct rdma_conn_param conn_param; + struct rpcrdma_connect_private pmsg; + struct ib_qp_init_attr qp_attr; + unsigned int ctxts, rq_depth; + struct ib_device *dev; + int ret = 0; + RPC_IFDEBUG(struct sockaddr *sap); + + listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); + clear_bit(XPT_CONN, &xprt->xpt_flags); + /* Get the next entry off the accept list */ + spin_lock(&listen_rdma->sc_lock); + if (!list_empty(&listen_rdma->sc_accept_q)) { + newxprt = list_entry(listen_rdma->sc_accept_q.next, + struct svcxprt_rdma, sc_accept_q); + list_del_init(&newxprt->sc_accept_q); + } + if (!list_empty(&listen_rdma->sc_accept_q)) + set_bit(XPT_CONN, &listen_rdma->sc_xprt.xpt_flags); + spin_unlock(&listen_rdma->sc_lock); + if (!newxprt) + return NULL; + + dev = newxprt->sc_cm_id->device; + newxprt->sc_port_num = newxprt->sc_cm_id->port_num; + + /* Qualify the transport resource defaults with the + * capabilities of this particular device */ + /* Transport header, head iovec, tail iovec */ + newxprt->sc_max_send_sges = 3; + /* Add one SGE per page list entry */ + newxprt->sc_max_send_sges += (svcrdma_max_req_size / PAGE_SIZE) + 1; + if (newxprt->sc_max_send_sges > dev->attrs.max_send_sge) + newxprt->sc_max_send_sges = dev->attrs.max_send_sge; + newxprt->sc_max_req_size = svcrdma_max_req_size; + newxprt->sc_max_requests = svcrdma_max_requests; + newxprt->sc_max_bc_requests = svcrdma_max_bc_requests; + newxprt->sc_recv_batch = RPCRDMA_MAX_RECV_BATCH; + rq_depth = newxprt->sc_max_requests + newxprt->sc_max_bc_requests + + newxprt->sc_recv_batch; + if (rq_depth > dev->attrs.max_qp_wr) { + pr_warn("svcrdma: reducing receive depth to %d\n", + dev->attrs.max_qp_wr); + rq_depth = dev->attrs.max_qp_wr; + newxprt->sc_recv_batch = 1; + newxprt->sc_max_requests = rq_depth - 2; + newxprt->sc_max_bc_requests = 2; + } + newxprt->sc_fc_credits = cpu_to_be32(newxprt->sc_max_requests); + ctxts = rdma_rw_mr_factor(dev, newxprt->sc_port_num, RPCSVC_MAXPAGES); + ctxts *= newxprt->sc_max_requests; + newxprt->sc_sq_depth = rq_depth + ctxts; + if (newxprt->sc_sq_depth > dev->attrs.max_qp_wr) { + pr_warn("svcrdma: reducing send depth to %d\n", + dev->attrs.max_qp_wr); + newxprt->sc_sq_depth = dev->attrs.max_qp_wr; + } + atomic_set(&newxprt->sc_sq_avail, newxprt->sc_sq_depth); + + newxprt->sc_pd = ib_alloc_pd(dev, 0); + if (IS_ERR(newxprt->sc_pd)) { + trace_svcrdma_pd_err(newxprt, PTR_ERR(newxprt->sc_pd)); + goto errout; + } + newxprt->sc_sq_cq = ib_alloc_cq_any(dev, newxprt, newxprt->sc_sq_depth, + IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_sq_cq)) + goto errout; + newxprt->sc_rq_cq = + ib_alloc_cq_any(dev, newxprt, rq_depth, IB_POLL_WORKQUEUE); + if (IS_ERR(newxprt->sc_rq_cq)) + goto errout; + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.event_handler = qp_event_handler; + qp_attr.qp_context = &newxprt->sc_xprt; + qp_attr.port_num = newxprt->sc_port_num; + qp_attr.cap.max_rdma_ctxs = ctxts; + qp_attr.cap.max_send_wr = newxprt->sc_sq_depth - ctxts; + qp_attr.cap.max_recv_wr = rq_depth; + qp_attr.cap.max_send_sge = newxprt->sc_max_send_sges; + qp_attr.cap.max_recv_sge = 1; + qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + qp_attr.qp_type = IB_QPT_RC; + qp_attr.send_cq = newxprt->sc_sq_cq; + qp_attr.recv_cq = newxprt->sc_rq_cq; + dprintk("svcrdma: newxprt->sc_cm_id=%p, newxprt->sc_pd=%p\n", + newxprt->sc_cm_id, newxprt->sc_pd); + dprintk(" cap.max_send_wr = %d, 
cap.max_recv_wr = %d\n", + qp_attr.cap.max_send_wr, qp_attr.cap.max_recv_wr); + dprintk(" cap.max_send_sge = %d, cap.max_recv_sge = %d\n", + qp_attr.cap.max_send_sge, qp_attr.cap.max_recv_sge); + + ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, &qp_attr); + if (ret) { + trace_svcrdma_qp_err(newxprt, ret); + goto errout; + } + newxprt->sc_qp = newxprt->sc_cm_id->qp; + + if (!(dev->attrs.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS)) + newxprt->sc_snd_w_inv = false; + if (!rdma_protocol_iwarp(dev, newxprt->sc_port_num) && + !rdma_ib_or_roce(dev, newxprt->sc_port_num)) { + trace_svcrdma_fabric_err(newxprt, -EINVAL); + goto errout; + } + + if (!svc_rdma_post_recvs(newxprt)) + goto errout; + + /* Construct RDMA-CM private message */ + pmsg.cp_magic = rpcrdma_cmp_magic; + pmsg.cp_version = RPCRDMA_CMP_VERSION; + pmsg.cp_flags = 0; + pmsg.cp_send_size = pmsg.cp_recv_size = + rpcrdma_encode_buffer_size(newxprt->sc_max_req_size); + + /* Accept Connection */ + set_bit(RDMAXPRT_CONN_PENDING, &newxprt->sc_flags); + memset(&conn_param, 0, sizeof conn_param); + conn_param.responder_resources = 0; + conn_param.initiator_depth = min_t(int, newxprt->sc_ord, + dev->attrs.max_qp_init_rd_atom); + if (!conn_param.initiator_depth) { + ret = -EINVAL; + trace_svcrdma_initdepth_err(newxprt, ret); + goto errout; + } + conn_param.private_data = &pmsg; + conn_param.private_data_len = sizeof(pmsg); + rdma_lock_handler(newxprt->sc_cm_id); + newxprt->sc_cm_id->event_handler = svc_rdma_cma_handler; + ret = rdma_accept(newxprt->sc_cm_id, &conn_param); + rdma_unlock_handler(newxprt->sc_cm_id); + if (ret) { + trace_svcrdma_accept_err(newxprt, ret); + goto errout; + } + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + dprintk("svcrdma: new connection %p accepted:\n", newxprt); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + dprintk(" local address : %pIS:%u\n", sap, rpc_get_port(sap)); + sap = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + dprintk(" remote address : %pIS:%u\n", sap, rpc_get_port(sap)); + dprintk(" max_sge : %d\n", newxprt->sc_max_send_sges); + dprintk(" sq_depth : %d\n", newxprt->sc_sq_depth); + dprintk(" rdma_rw_ctxs : %d\n", ctxts); + dprintk(" max_requests : %d\n", newxprt->sc_max_requests); + dprintk(" ord : %d\n", conn_param.initiator_depth); +#endif + + return &newxprt->sc_xprt; + + errout: + /* Take a reference in case the DTO handler runs */ + svc_xprt_get(&newxprt->sc_xprt); + if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp)) + ib_destroy_qp(newxprt->sc_qp); + rdma_destroy_id(newxprt->sc_cm_id); + /* This call to put will destroy the transport */ + svc_xprt_put(&newxprt->sc_xprt); + return NULL; +} + +static void svc_rdma_detach(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + rdma_disconnect(rdma->sc_cm_id); +} + +static void __svc_rdma_free(struct work_struct *work) +{ + struct svcxprt_rdma *rdma = + container_of(work, struct svcxprt_rdma, sc_work); + + /* This blocks until the Completion Queues are empty */ + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_drain_qp(rdma->sc_qp); + + svc_rdma_flush_recv_queues(rdma); + + svc_rdma_destroy_rw_ctxts(rdma); + svc_rdma_send_ctxts_destroy(rdma); + svc_rdma_recv_ctxts_destroy(rdma); + + /* Destroy the QP if present (not a listener) */ + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_destroy_qp(rdma->sc_qp); + + if (rdma->sc_sq_cq && !IS_ERR(rdma->sc_sq_cq)) + ib_free_cq(rdma->sc_sq_cq); + + if (rdma->sc_rq_cq && !IS_ERR(rdma->sc_rq_cq)) + 
ib_free_cq(rdma->sc_rq_cq); + + if (rdma->sc_pd && !IS_ERR(rdma->sc_pd)) + ib_dealloc_pd(rdma->sc_pd); + + /* Destroy the CM ID */ + rdma_destroy_id(rdma->sc_cm_id); + + kfree(rdma); +} + +static void svc_rdma_free(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + INIT_WORK(&rdma->sc_work, __svc_rdma_free); + schedule_work(&rdma->sc_work); +} + +static int svc_rdma_has_wspace(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + /* + * If there are already waiters on the SQ, + * return false. + */ + if (waitqueue_active(&rdma->sc_send_wait)) + return 0; + + /* Otherwise return true. */ + return 1; +} + +static void svc_rdma_secure_port(struct svc_rqst *rqstp) +{ + set_bit(RQ_SECURE, &rqstp->rq_flags); +} + +static void svc_rdma_kill_temp_xprt(struct svc_xprt *xprt) +{ +} diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c new file mode 100644 index 000000000..10bb2b929 --- /dev/null +++ b/net/sunrpc/xprtrdma/transport.c @@ -0,0 +1,805 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * transport.c + * + * This file contains the top-level implementation of an RPC RDMA + * transport. + * + * Naming convention: functions beginning with xprt_ are part of the + * transport switch. All others are RPC RDMA internal. 
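 *
 * For illustration only (not part of this patch): the "transport switch" is
 * an operations table of function pointers (struct rpc_xprt_ops in the
 * kernel) through which the RPC core calls into a transport without knowing
 * its type. A hypothetical miniature version:
 */

struct example_xprt_ops {
	int	(*send_request)(const void *buf, unsigned int len);
	void	(*close)(void);
};

static int example_rdma_send(const void *buf, unsigned int len)
{
	(void)buf;
	(void)len;
	return 0;	/* a real transport would post the message here */
}

static void example_rdma_close(void)
{
}

static const struct example_xprt_ops example_rdma_ops = {
	.send_request	= example_rdma_send,
	.close		= example_rdma_close,
};

/* (end of illustrative sketch)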
+ */ + +#include +#include +#include +#include + +#include +#include + +#include "xprt_rdma.h" +#include + +/* + * tunables + */ + +static unsigned int xprt_rdma_slot_table_entries = RPCRDMA_DEF_SLOT_TABLE; +unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE; +unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE; +unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRWR; +int xprt_rdma_pad_optimize; +static struct xprt_class xprt_rdma; + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + +static unsigned int min_slot_table_size = RPCRDMA_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPCRDMA_MAX_SLOT_TABLE; +static unsigned int min_inline_size = RPCRDMA_MIN_INLINE; +static unsigned int max_inline_size = RPCRDMA_MAX_INLINE; +static unsigned int max_padding = PAGE_SIZE; +static unsigned int min_memreg = RPCRDMA_BOUNCEBUFFERS; +static unsigned int max_memreg = RPCRDMA_LAST - 1; +static unsigned int dummy; + +static struct ctl_table_header *sunrpc_table_header; + +static struct ctl_table xr_tunables_table[] = { + { + .procname = "rdma_slot_table_entries", + .data = &xprt_rdma_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "rdma_max_inline_read", + .data = &xprt_rdma_max_inline_read, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, + }, + { + .procname = "rdma_max_inline_write", + .data = &xprt_rdma_max_inline_write, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_inline_size, + .extra2 = &max_inline_size, + }, + { + .procname = "rdma_inline_write_padding", + .data = &dummy, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = &max_padding, + }, + { + .procname = "rdma_memreg_strategy", + .data = &xprt_rdma_memreg_strategy, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_memreg, + .extra2 = &max_memreg, + }, + { + .procname = "rdma_pad_optimize", + .data = &xprt_rdma_pad_optimize, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { }, +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = xr_tunables_table + }, + { }, +}; + +#endif + +static const struct rpc_xprt_ops xprt_rdma_procs; + +static void +xprt_rdma_format_addresses4(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + struct sockaddr_in *sin = (struct sockaddr_in *)sap; + char buf[20]; + + snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA; +} + +static void +xprt_rdma_format_addresses6(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)sap; + char buf[40]; + + snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_NETID] = RPCBIND_NETID_RDMA6; +} + +void +xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap) +{ + char buf[128]; + + switch (sap->sa_family) { + case AF_INET: + xprt_rdma_format_addresses4(xprt, sap); + break; + case AF_INET6: + 
xprt_rdma_format_addresses6(xprt, sap); + break; + default: + pr_err("rpcrdma: Unrecognized address family\n"); + return; + } + + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); + + xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma"; +} + +void +xprt_rdma_free_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +/** + * xprt_rdma_connect_worker - establish connection in the background + * @work: worker thread context + * + * Requester holds the xprt's send lock to prevent activity on this + * transport while a fresh connection is being established. RPC tasks + * sleep on the xprt's pending queue waiting for connect to complete. + */ +static void +xprt_rdma_connect_worker(struct work_struct *work) +{ + struct rpcrdma_xprt *r_xprt = container_of(work, struct rpcrdma_xprt, + rx_connect_worker.work); + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + unsigned int pflags = current->flags; + int rc; + + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; + rc = rpcrdma_xprt_connect(r_xprt); + xprt_clear_connecting(xprt); + if (!rc) { + xprt->connect_cookie++; + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_set_connected(xprt); + rc = -EAGAIN; + } else + rpcrdma_xprt_disconnect(r_xprt); + xprt_unlock_connect(xprt, r_xprt); + xprt_wake_pending_tasks(xprt, rc); + current_restore_flags(pflags, PF_MEMALLOC); +} + +/** + * xprt_rdma_inject_disconnect - inject a connection fault + * @xprt: transport context + * + * If @xprt is connected, disconnect it to simulate spurious + * connection loss. Caller must hold @xprt's send lock to + * ensure that data structures and hardware resources are + * stable during the rdma_disconnect() call. + */ +static void +xprt_rdma_inject_disconnect(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + trace_xprtrdma_op_inject_dsc(r_xprt); + rdma_disconnect(r_xprt->rx_ep->re_id); +} + +/** + * xprt_rdma_destroy - Full tear down of transport + * @xprt: doomed transport context + * + * Caller guarantees there will be no more calls to us with + * this @xprt. 
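 *
 * As an illustration (not part of this patch), the flag handling in
 * xprt_rdma_connect_worker() above follows a save/set/restore pattern so
 * that a swap-backed transport may dip into emergency memory reserves while
 * reconnecting. The sketch below uses a hypothetical flags word rather than
 * current->flags, and a made-up bit value.
 */

static void example_connect_with_reserves(unsigned int *task_flags,
					  int backs_swap,
					  void (*connect_fn)(void))
{
	const unsigned int EX_MEMALLOC = 0x0800;	/* hypothetical bit */
	unsigned int saved = *task_flags;

	if (backs_swap)
		*task_flags |= EX_MEMALLOC;
	connect_fn();
	/* Restore only the bit that may have been changed here. */
	*task_flags = (*task_flags & ~EX_MEMALLOC) | (saved & EX_MEMALLOC);
}

/* (end of illustrative sketch)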
+ */ +static void +xprt_rdma_destroy(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + cancel_delayed_work_sync(&r_xprt->rx_connect_worker); + + rpcrdma_xprt_disconnect(r_xprt); + rpcrdma_buffer_destroy(&r_xprt->rx_buf); + + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + + module_put(THIS_MODULE); +} + +/* 60 second timeout, no retries */ +static const struct rpc_timeout xprt_rdma_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/** + * xprt_setup_rdma - Set up transport to use RDMA + * + * @args: rpc transport arguments + */ +static struct rpc_xprt * +xprt_setup_rdma(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + struct sockaddr *sap; + int rc; + + if (args->addrlen > sizeof(xprt->addr)) + return ERR_PTR(-EBADF); + + if (!try_module_get(THIS_MODULE)) + return ERR_PTR(-EIO); + + xprt = xprt_alloc(args->net, sizeof(struct rpcrdma_xprt), 0, + xprt_rdma_slot_table_entries); + if (!xprt) { + module_put(THIS_MODULE); + return ERR_PTR(-ENOMEM); + } + + xprt->timeout = &xprt_rdma_default_timeout; + xprt->connect_timeout = xprt->timeout->to_initval; + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->bind_timeout = RPCRDMA_BIND_TO; + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + xprt->idle_timeout = RPCRDMA_IDLE_DISC_TO; + + xprt->resvport = 0; /* privileged port not needed */ + xprt->ops = &xprt_rdma_procs; + + /* + * Set up RDMA-specific connect data. + */ + sap = args->dstaddr; + + /* Ensure xprt->addr holds valid server TCP (not RDMA) + * address, for any side protocols which peek at it */ + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xprt_rdma; + xprt->addrlen = args->addrlen; + memcpy(&xprt->addr, sap, xprt->addrlen); + + if (rpc_get_port(sap)) + xprt_set_bound(xprt); + xprt_rdma_format_addresses(xprt, sap); + + new_xprt = rpcx_to_rdmax(xprt); + rc = rpcrdma_buffer_create(new_xprt); + if (rc) { + xprt_rdma_free_addresses(xprt); + xprt_free(xprt); + module_put(THIS_MODULE); + return ERR_PTR(rc); + } + + INIT_DELAYED_WORK(&new_xprt->rx_connect_worker, + xprt_rdma_connect_worker); + + xprt->max_payload = RPCRDMA_MAX_DATA_SEGS << PAGE_SHIFT; + + return xprt; +} + +/** + * xprt_rdma_close - close a transport connection + * @xprt: transport context + * + * Called during autoclose or device removal. + * + * Caller holds @xprt's send lock to prevent activity on this + * transport while the connection is torn down. + */ +void xprt_rdma_close(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + rpcrdma_xprt_disconnect(r_xprt); + + xprt->reestablish_timeout = 0; + ++xprt->connect_cookie; + xprt_disconnect_done(xprt); +} + +/** + * xprt_rdma_set_port - update server port with rpcbind result + * @xprt: controlling RPC transport + * @port: new port value + * + * Transport connect status is unchanged. 
+ */ +static void +xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) +{ + struct sockaddr *sap = (struct sockaddr *)&xprt->addr; + char buf[8]; + + rpc_set_port(sap, port); + + kfree(xprt->address_strings[RPC_DISPLAY_PORT]); + snprintf(buf, sizeof(buf), "%u", port); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]); + snprintf(buf, sizeof(buf), "%4hx", port); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); +} + +/** + * xprt_rdma_timer - invoked when an RPC times out + * @xprt: controlling RPC transport + * @task: RPC task that timed out + * + * Invoked when the transport is still connected, but an RPC + * retransmit timeout occurs. + * + * Since RDMA connections don't have a keep-alive, forcibly + * disconnect and retry to connect. This drives full + * detection of the network path, and retransmissions of + * all pending RPCs. + */ +static void +xprt_rdma_timer(struct rpc_xprt *xprt, struct rpc_task *task) +{ + xprt_force_disconnect(xprt); +} + +/** + * xprt_rdma_set_connect_timeout - set timeouts for establishing a connection + * @xprt: controlling transport instance + * @connect_timeout: reconnect timeout after client disconnects + * @reconnect_timeout: reconnect timeout after server disconnects + * + */ +static void xprt_rdma_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + trace_xprtrdma_op_set_cto(r_xprt, connect_timeout, reconnect_timeout); + + spin_lock(&xprt->transport_lock); + + if (connect_timeout < xprt->connect_timeout) { + struct rpc_timeout to; + unsigned long initval; + + to = *xprt->timeout; + initval = connect_timeout; + if (initval < RPCRDMA_INIT_REEST_TO << 1) + initval = RPCRDMA_INIT_REEST_TO << 1; + to.to_initval = initval; + to.to_maxval = initval; + r_xprt->rx_timeout = to; + xprt->timeout = &r_xprt->rx_timeout; + xprt->connect_timeout = connect_timeout; + } + + if (reconnect_timeout < xprt->max_reconnect_timeout) + xprt->max_reconnect_timeout = reconnect_timeout; + + spin_unlock(&xprt->transport_lock); +} + +/** + * xprt_rdma_connect - schedule an attempt to reconnect + * @xprt: transport state + * @task: RPC scheduler context (unused) + * + */ +static void +xprt_rdma_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_ep *ep = r_xprt->rx_ep; + unsigned long delay; + + WARN_ON_ONCE(!xprt_lock_connect(xprt, task, r_xprt)); + + delay = 0; + if (ep && ep->re_connect_status != 0) { + delay = xprt_reconnect_delay(xprt); + xprt_reconnect_backoff(xprt, RPCRDMA_INIT_REEST_TO); + } + trace_xprtrdma_op_connect(r_xprt, delay); + queue_delayed_work(system_long_wq, &r_xprt->rx_connect_worker, delay); +} + +/** + * xprt_rdma_alloc_slot - allocate an rpc_rqst + * @xprt: controlling RPC transport + * @task: RPC task requesting a fresh rpc_rqst + * + * tk_status values: + * %0 if task->tk_rqstp points to a fresh rpc_rqst + * %-EAGAIN if no rpc_rqst is available; queued on backlog + */ +static void +xprt_rdma_alloc_slot(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req; + + req = rpcrdma_buffer_get(&r_xprt->rx_buf); + if (!req) + goto out_sleep; + task->tk_rqstp = &req->rl_slot; + task->tk_status = 0; + return; + +out_sleep: + task->tk_status = -ENOMEM; + xprt_add_backlog(xprt, task); +} + +/** + * 
xprt_rdma_free_slot - release an rpc_rqst + * @xprt: controlling RPC transport + * @rqst: rpc_rqst to release + * + */ +static void +xprt_rdma_free_slot(struct rpc_xprt *xprt, struct rpc_rqst *rqst) +{ + struct rpcrdma_xprt *r_xprt = + container_of(xprt, struct rpcrdma_xprt, rx_xprt); + + rpcrdma_reply_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst)); + if (!xprt_wake_up_backlog(xprt, rqst)) { + memset(rqst, 0, sizeof(*rqst)); + rpcrdma_buffer_put(&r_xprt->rx_buf, rpcr_to_rdmar(rqst)); + } +} + +static bool rpcrdma_check_regbuf(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb, size_t size, + gfp_t flags) +{ + if (unlikely(rdmab_length(rb) < size)) { + if (!rpcrdma_regbuf_realloc(rb, size, flags)) + return false; + r_xprt->rx_stats.hardway_register_count += size; + } + return true; +} + +/** + * xprt_rdma_allocate - allocate transport resources for an RPC + * @task: RPC task + * + * Return values: + * 0: Success; rq_buffer points to RPC buffer to use + * ENOMEM: Out of memory, call again later + * EIO: A permanent error occurred, do not retry + */ +static int +xprt_rdma_allocate(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + gfp_t flags = rpc_task_gfp_mask(); + + if (!rpcrdma_check_regbuf(r_xprt, req->rl_sendbuf, rqst->rq_callsize, + flags)) + goto out_fail; + if (!rpcrdma_check_regbuf(r_xprt, req->rl_recvbuf, rqst->rq_rcvsize, + flags)) + goto out_fail; + + rqst->rq_buffer = rdmab_data(req->rl_sendbuf); + rqst->rq_rbuffer = rdmab_data(req->rl_recvbuf); + return 0; + +out_fail: + return -ENOMEM; +} + +/** + * xprt_rdma_free - release resources allocated by xprt_rdma_allocate + * @task: RPC task + * + * Caller guarantees rqst->rq_buffer is non-NULL. + */ +static void +xprt_rdma_free(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + + if (unlikely(!list_empty(&req->rl_registered))) { + trace_xprtrdma_mrs_zap(task); + frwr_unmap_sync(rpcx_to_rdmax(rqst->rq_xprt), req); + } + + /* XXX: If the RPC is completing because of a signal and + * not because a reply was received, we ought to ensure + * that the Send completion has fired, so that memory + * involved with the Send is not still visible to the NIC. + */ +} + +/** + * xprt_rdma_send_request - marshal and send an RPC request + * @rqst: RPC message in rq_snd_buf + * + * Caller holds the transport's write lock. + * + * Returns: + * %0 if the RPC message has been sent + * %-ENOTCONN if the caller should reconnect and call again + * %-EAGAIN if the caller should call again + * %-ENOBUFS if the caller should call again after a delay + * %-EMSGSIZE if encoding ran out of buffer space. The request + * was not sent. Do not try to send this message again. + * %-EIO if an I/O error occurred. The request was not sent. + * Do not try to send this message again. 
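 *
 * As an illustration (not part of this patch), a caller might map these
 * return values onto retry behaviour as sketched below; the helper
 * functions are hypothetical stubs.
 */

#include <errno.h>

static void example_reconnect(void) { }
static void example_retry_later(void) { }
static void example_fail_rpc(int status) { (void)status; }

static void example_handle_send_status(int status)
{
	switch (status) {
	case 0:
		break;				/* sent; await the reply */
	case -ENOTCONN:
		example_reconnect();		/* reconnect, then resend */
		break;
	case -EAGAIN:
	case -ENOBUFS:
		example_retry_later();		/* transient; back off */
		break;
	case -EMSGSIZE:
	case -EIO:
	default:
		example_fail_rpc(status);	/* permanent; do not resend */
	}
}

/* (end of illustrative sketch)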
+ */ +static int +xprt_rdma_send_request(struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_xprt; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc = 0; + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + if (unlikely(!rqst->rq_buffer)) + return xprt_rdma_bc_send_reply(rqst); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + + if (!xprt_connected(xprt)) + return -ENOTCONN; + + if (!xprt_request_get_cong(xprt, rqst)) + return -EBADSLT; + + rc = rpcrdma_marshal_req(r_xprt, rqst); + if (rc < 0) + goto failed_marshal; + + /* Must suppress retransmit to maintain credits */ + if (rqst->rq_connect_cookie == xprt->connect_cookie) + goto drop_connection; + rqst->rq_xtime = ktime_get(); + + if (frwr_send(r_xprt, req)) + goto drop_connection; + + rqst->rq_xmit_bytes_sent += rqst->rq_snd_buf.len; + + /* An RPC with no reply will throw off credit accounting, + * so drop the connection to reset the credit grant. + */ + if (!rpc_reply_expected(rqst->rq_task)) + goto drop_connection; + return 0; + +failed_marshal: + if (rc != -ENOTCONN) + return rc; +drop_connection: + xprt_rdma_close(xprt); + return -ENOTCONN; +} + +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_puts(seq, "\txprt:\trdma "); + seq_printf(seq, "%u %lu %lu %lu %ld %lu %lu %lu %llu %llu ", + 0, /* need a local port? */ + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time / HZ, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u); + seq_printf(seq, "%lu %lu %lu %llu %llu %llu %llu %lu %lu %lu %lu ", + r_xprt->rx_stats.read_chunk_count, + r_xprt->rx_stats.write_chunk_count, + r_xprt->rx_stats.reply_chunk_count, + r_xprt->rx_stats.total_rdma_request, + r_xprt->rx_stats.total_rdma_reply, + r_xprt->rx_stats.pullup_copy_count, + r_xprt->rx_stats.fixup_copy_count, + r_xprt->rx_stats.hardway_register_count, + r_xprt->rx_stats.failed_marshal_count, + r_xprt->rx_stats.bad_reply_count, + r_xprt->rx_stats.nomsg_call_count); + seq_printf(seq, "%lu %lu %lu %lu %lu %lu\n", + r_xprt->rx_stats.mrs_recycled, + r_xprt->rx_stats.mrs_orphaned, + r_xprt->rx_stats.mrs_allocated, + r_xprt->rx_stats.local_inv_needed, + r_xprt->rx_stats.empty_sendctx_q, + r_xprt->rx_stats.reply_waits_for_send); +} + +static int +xprt_rdma_enable_swap(struct rpc_xprt *xprt) +{ + return 0; +} + +static void +xprt_rdma_disable_swap(struct rpc_xprt *xprt) +{ +} + +/* + * Plumbing for rpc transport switch and kernel module + */ + +static const struct rpc_xprt_ops xprt_rdma_procs = { + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */ + .alloc_slot = xprt_rdma_alloc_slot, + .free_slot = xprt_rdma_free_slot, + .release_request = xprt_release_rqst_cong, /* ditto */ + .wait_for_reply_request = xprt_wait_for_reply_request_def, /* ditto */ + .timer = xprt_rdma_timer, + .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */ + .set_port = xprt_rdma_set_port, + .connect = xprt_rdma_connect, + .buf_alloc = xprt_rdma_allocate, + .buf_free = xprt_rdma_free, + .send_request = xprt_rdma_send_request, + .close = xprt_rdma_close, + .destroy = xprt_rdma_destroy, + .set_connect_timeout = xprt_rdma_set_connect_timeout, + .print_stats = xprt_rdma_print_stats, + .enable_swap = xprt_rdma_enable_swap, + .disable_swap = 
xprt_rdma_disable_swap, + .inject_disconnect = xprt_rdma_inject_disconnect, +#if defined(CONFIG_SUNRPC_BACKCHANNEL) + .bc_setup = xprt_rdma_bc_setup, + .bc_maxpayload = xprt_rdma_bc_maxpayload, + .bc_num_slots = xprt_rdma_bc_max_slots, + .bc_free_rqst = xprt_rdma_bc_free_rqst, + .bc_destroy = xprt_rdma_bc_destroy, +#endif +}; + +static struct xprt_class xprt_rdma = { + .list = LIST_HEAD_INIT(xprt_rdma.list), + .name = "rdma", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_RDMA, + .setup = xprt_setup_rdma, + .netid = { "rdma", "rdma6", "" }, +}; + +void xprt_rdma_cleanup(void) +{ +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +#endif + + xprt_unregister_transport(&xprt_rdma); + xprt_unregister_transport(&xprt_rdma_bc); +} + +int xprt_rdma_init(void) +{ + int rc; + + rc = xprt_register_transport(&xprt_rdma); + if (rc) + return rc; + + rc = xprt_register_transport(&xprt_rdma_bc); + if (rc) { + xprt_unregister_transport(&xprt_rdma); + return rc; + } + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +#endif + return 0; +} diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c new file mode 100644 index 000000000..28c0771c4 --- /dev/null +++ b/net/sunrpc/xprtrdma/verbs.c @@ -0,0 +1,1396 @@ +// SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +/* + * verbs.c + * + * Encapsulates the major functions managing: + * o adapters + * o endpoints + * o connections + * o buffer memory + */ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "xprt_rdma.h" +#include + +static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_sendctx *sc); +static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep); +static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt); +static void rpcrdma_ep_get(struct rpcrdma_ep *ep); +static int rpcrdma_ep_put(struct rpcrdma_ep *ep); +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction); +static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb); +static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb); + +/* Wait for outstanding transport work to finish. ib_drain_qp + * handles the drains in the wrong order for us, so open code + * them here. + */ +static void rpcrdma_xprt_drain(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rdma_cm_id *id = ep->re_id; + + /* Wait for rpcrdma_post_recvs() to leave its critical + * section. + */ + if (atomic_inc_return(&ep->re_receiving) > 1) + wait_for_completion(&ep->re_done); + + /* Flush Receives, then wait for deferred Reply work + * to complete. + */ + ib_drain_rq(id->qp); + + /* Deferred Reply processing might have scheduled + * local invalidations. + */ + ib_drain_sq(id->qp); + + rpcrdma_ep_put(ep); +} + +/* Ensure xprt_force_disconnect() is invoked exactly once when a + * connection is closed or lost. (The important thing is it needs + * to be invoked "at least" once). + */ +void rpcrdma_force_disconnect(struct rpcrdma_ep *ep) +{ + if (atomic_add_unless(&ep->re_force_disconnect, 1, 1)) + xprt_force_disconnect(ep->re_xprt); +} + +/** + * rpcrdma_flush_disconnect - Disconnect on flushed completion + * @r_xprt: transport to disconnect + * @wc: work completion entry + * + * Must be called in process context. 
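 *
 * As an illustration (not part of this patch), the "at most once" behaviour
 * of rpcrdma_force_disconnect() above can be modelled with a C11 atomic
 * flag in place of atomic_add_unless(); names are hypothetical.
 */

#include <stdatomic.h>
#include <stdbool.h>

static void example_force_disconnect_once(atomic_bool *already_done,
					  void (*disconnect_fn)(void))
{
	bool expected = false;

	/* Only the first caller flips the flag and does the work; later
	 * or concurrent callers find it already set and return.
	 */
	if (atomic_compare_exchange_strong(already_done, &expected, true))
		disconnect_fn();
}

/* (end of illustrative sketch)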
+ */ +void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc) +{ + if (wc->status != IB_WC_SUCCESS) + rpcrdma_force_disconnect(r_xprt->rx_ep); +} + +/** + * rpcrdma_wc_send - Invoked by RDMA provider for each polled Send WC + * @cq: completion queue + * @wc: WCE for a completed Send WR + * + */ +static void rpcrdma_wc_send(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_sendctx *sc = + container_of(cqe, struct rpcrdma_sendctx, sc_cqe); + struct rpcrdma_xprt *r_xprt = cq->cq_context; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_send(wc, &sc->sc_cid); + rpcrdma_sendctx_put_locked(r_xprt, sc); + rpcrdma_flush_disconnect(r_xprt, wc); +} + +/** + * rpcrdma_wc_receive - Invoked by RDMA provider for each polled Receive WC + * @cq: completion queue + * @wc: WCE for a completed Receive WR + * + */ +static void rpcrdma_wc_receive(struct ib_cq *cq, struct ib_wc *wc) +{ + struct ib_cqe *cqe = wc->wr_cqe; + struct rpcrdma_rep *rep = container_of(cqe, struct rpcrdma_rep, + rr_cqe); + struct rpcrdma_xprt *r_xprt = cq->cq_context; + + /* WARNING: Only wr_cqe and status are reliable at this point */ + trace_xprtrdma_wc_receive(wc, &rep->rr_cid); + --r_xprt->rx_ep->re_receive_count; + if (wc->status != IB_WC_SUCCESS) + goto out_flushed; + + /* status == SUCCESS means all fields in wc are trustworthy */ + rpcrdma_set_xdrlen(&rep->rr_hdrbuf, wc->byte_len); + rep->rr_wc_flags = wc->wc_flags; + rep->rr_inv_rkey = wc->ex.invalidate_rkey; + + ib_dma_sync_single_for_cpu(rdmab_device(rep->rr_rdmabuf), + rdmab_addr(rep->rr_rdmabuf), + wc->byte_len, DMA_FROM_DEVICE); + + rpcrdma_reply_handler(rep); + return; + +out_flushed: + rpcrdma_flush_disconnect(r_xprt, wc); + rpcrdma_rep_put(&r_xprt->rx_buf, rep); +} + +static void rpcrdma_update_cm_private(struct rpcrdma_ep *ep, + struct rdma_conn_param *param) +{ + const struct rpcrdma_connect_private *pmsg = param->private_data; + unsigned int rsize, wsize; + + /* Default settings for RPC-over-RDMA Version One */ + rsize = RPCRDMA_V1_DEF_INLINE_SIZE; + wsize = RPCRDMA_V1_DEF_INLINE_SIZE; + + if (pmsg && + pmsg->cp_magic == rpcrdma_cmp_magic && + pmsg->cp_version == RPCRDMA_CMP_VERSION) { + rsize = rpcrdma_decode_buffer_size(pmsg->cp_send_size); + wsize = rpcrdma_decode_buffer_size(pmsg->cp_recv_size); + } + + if (rsize < ep->re_inline_recv) + ep->re_inline_recv = rsize; + if (wsize < ep->re_inline_send) + ep->re_inline_send = wsize; + + rpcrdma_set_max_header_sizes(ep); +} + +/** + * rpcrdma_cm_event_handler - Handle RDMA CM events + * @id: rdma_cm_id on which an event has occurred + * @event: details of the event + * + * Called with @id's mutex held. Returns 1 if caller should + * destroy @id, otherwise 0. 
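 *
 * As an illustration (not part of this patch), the completion handlers
 * above recover their per-request context by embedding the CQE cookie in a
 * larger structure and walking back with container_of(). A user-space
 * equivalent using offsetof(), with hypothetical types:
 */

#include <stddef.h>

struct example_cqe {
	int placeholder;
};

struct example_sendctx {
	int			id;
	struct example_cqe	cqe;	/* handed to the provider as wr_cqe */
};

static struct example_sendctx *example_sendctx_from_cqe(struct example_cqe *cqe)
{
	/* Walk back from the embedded member to the containing object. */
	return (struct example_sendctx *)((char *)cqe -
			offsetof(struct example_sendctx, cqe));
}

/* (end of illustrative sketch)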
+ */ +static int +rpcrdma_cm_event_handler(struct rdma_cm_id *id, struct rdma_cm_event *event) +{ + struct sockaddr *sap = (struct sockaddr *)&id->route.addr.dst_addr; + struct rpcrdma_ep *ep = id->context; + + might_sleep(); + + switch (event->event) { + case RDMA_CM_EVENT_ADDR_RESOLVED: + case RDMA_CM_EVENT_ROUTE_RESOLVED: + ep->re_async_rc = 0; + complete(&ep->re_done); + return 0; + case RDMA_CM_EVENT_ADDR_ERROR: + ep->re_async_rc = -EPROTO; + complete(&ep->re_done); + return 0; + case RDMA_CM_EVENT_ROUTE_ERROR: + ep->re_async_rc = -ENETUNREACH; + complete(&ep->re_done); + return 0; + case RDMA_CM_EVENT_DEVICE_REMOVAL: + pr_info("rpcrdma: removing device %s for %pISpc\n", + ep->re_id->device->name, sap); + fallthrough; + case RDMA_CM_EVENT_ADDR_CHANGE: + ep->re_connect_status = -ENODEV; + goto disconnected; + case RDMA_CM_EVENT_ESTABLISHED: + rpcrdma_ep_get(ep); + ep->re_connect_status = 1; + rpcrdma_update_cm_private(ep, &event->param.conn); + trace_xprtrdma_inline_thresh(ep); + wake_up_all(&ep->re_connect_wait); + break; + case RDMA_CM_EVENT_CONNECT_ERROR: + ep->re_connect_status = -ENOTCONN; + goto wake_connect_worker; + case RDMA_CM_EVENT_UNREACHABLE: + ep->re_connect_status = -ENETUNREACH; + goto wake_connect_worker; + case RDMA_CM_EVENT_REJECTED: + ep->re_connect_status = -ECONNREFUSED; + if (event->status == IB_CM_REJ_STALE_CONN) + ep->re_connect_status = -ENOTCONN; +wake_connect_worker: + wake_up_all(&ep->re_connect_wait); + return 0; + case RDMA_CM_EVENT_DISCONNECTED: + ep->re_connect_status = -ECONNABORTED; +disconnected: + rpcrdma_force_disconnect(ep); + return rpcrdma_ep_put(ep); + default: + break; + } + + return 0; +} + +static struct rdma_cm_id *rpcrdma_create_id(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_ep *ep) +{ + unsigned long wtimeout = msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1; + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rdma_cm_id *id; + int rc; + + init_completion(&ep->re_done); + + id = rdma_create_id(xprt->xprt_net, rpcrdma_cm_event_handler, ep, + RDMA_PS_TCP, IB_QPT_RC); + if (IS_ERR(id)) + return id; + + ep->re_async_rc = -ETIMEDOUT; + rc = rdma_resolve_addr(id, NULL, (struct sockaddr *)&xprt->addr, + RDMA_RESOLVE_TIMEOUT); + if (rc) + goto out; + rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout); + if (rc < 0) + goto out; + + rc = ep->re_async_rc; + if (rc) + goto out; + + ep->re_async_rc = -ETIMEDOUT; + rc = rdma_resolve_route(id, RDMA_RESOLVE_TIMEOUT); + if (rc) + goto out; + rc = wait_for_completion_interruptible_timeout(&ep->re_done, wtimeout); + if (rc < 0) + goto out; + rc = ep->re_async_rc; + if (rc) + goto out; + + return id; + +out: + rdma_destroy_id(id); + return ERR_PTR(rc); +} + +static void rpcrdma_ep_destroy(struct kref *kref) +{ + struct rpcrdma_ep *ep = container_of(kref, struct rpcrdma_ep, re_kref); + + if (ep->re_id->qp) { + rdma_destroy_qp(ep->re_id); + ep->re_id->qp = NULL; + } + + if (ep->re_attr.recv_cq) + ib_free_cq(ep->re_attr.recv_cq); + ep->re_attr.recv_cq = NULL; + if (ep->re_attr.send_cq) + ib_free_cq(ep->re_attr.send_cq); + ep->re_attr.send_cq = NULL; + + if (ep->re_pd) + ib_dealloc_pd(ep->re_pd); + ep->re_pd = NULL; + + kfree(ep); + module_put(THIS_MODULE); +} + +static noinline void rpcrdma_ep_get(struct rpcrdma_ep *ep) +{ + kref_get(&ep->re_kref); +} + +/* Returns: + * %0 if @ep still has a positive kref count, or + * %1 if @ep was destroyed successfully. 
+ */ +static noinline int rpcrdma_ep_put(struct rpcrdma_ep *ep) +{ + return kref_put(&ep->re_kref, rpcrdma_ep_destroy); +} + +static int rpcrdma_ep_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_connect_private *pmsg; + struct ib_device *device; + struct rdma_cm_id *id; + struct rpcrdma_ep *ep; + int rc; + + ep = kzalloc(sizeof(*ep), XPRTRDMA_GFP_FLAGS); + if (!ep) + return -ENOTCONN; + ep->re_xprt = &r_xprt->rx_xprt; + kref_init(&ep->re_kref); + + id = rpcrdma_create_id(r_xprt, ep); + if (IS_ERR(id)) { + kfree(ep); + return PTR_ERR(id); + } + __module_get(THIS_MODULE); + device = id->device; + ep->re_id = id; + reinit_completion(&ep->re_done); + + ep->re_max_requests = r_xprt->rx_xprt.max_reqs; + ep->re_inline_send = xprt_rdma_max_inline_write; + ep->re_inline_recv = xprt_rdma_max_inline_read; + rc = frwr_query_device(ep, device); + if (rc) + goto out_destroy; + + r_xprt->rx_buf.rb_max_requests = cpu_to_be32(ep->re_max_requests); + + ep->re_attr.srq = NULL; + ep->re_attr.cap.max_inline_data = 0; + ep->re_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + ep->re_attr.qp_type = IB_QPT_RC; + ep->re_attr.port_num = ~0; + + ep->re_send_batch = ep->re_max_requests >> 3; + ep->re_send_count = ep->re_send_batch; + init_waitqueue_head(&ep->re_connect_wait); + + ep->re_attr.send_cq = ib_alloc_cq_any(device, r_xprt, + ep->re_attr.cap.max_send_wr, + IB_POLL_WORKQUEUE); + if (IS_ERR(ep->re_attr.send_cq)) { + rc = PTR_ERR(ep->re_attr.send_cq); + ep->re_attr.send_cq = NULL; + goto out_destroy; + } + + ep->re_attr.recv_cq = ib_alloc_cq_any(device, r_xprt, + ep->re_attr.cap.max_recv_wr, + IB_POLL_WORKQUEUE); + if (IS_ERR(ep->re_attr.recv_cq)) { + rc = PTR_ERR(ep->re_attr.recv_cq); + ep->re_attr.recv_cq = NULL; + goto out_destroy; + } + ep->re_receive_count = 0; + + /* Initialize cma parameters */ + memset(&ep->re_remote_cma, 0, sizeof(ep->re_remote_cma)); + + /* Prepare RDMA-CM private message */ + pmsg = &ep->re_cm_private; + pmsg->cp_magic = rpcrdma_cmp_magic; + pmsg->cp_version = RPCRDMA_CMP_VERSION; + pmsg->cp_flags |= RPCRDMA_CMP_F_SND_W_INV_OK; + pmsg->cp_send_size = rpcrdma_encode_buffer_size(ep->re_inline_send); + pmsg->cp_recv_size = rpcrdma_encode_buffer_size(ep->re_inline_recv); + ep->re_remote_cma.private_data = pmsg; + ep->re_remote_cma.private_data_len = sizeof(*pmsg); + + /* Client offers RDMA Read but does not initiate */ + ep->re_remote_cma.initiator_depth = 0; + ep->re_remote_cma.responder_resources = + min_t(int, U8_MAX, device->attrs.max_qp_rd_atom); + + /* Limit transport retries so client can detect server + * GID changes quickly. RPC layer handles re-establishing + * transport connection and retransmission. + */ + ep->re_remote_cma.retry_count = 6; + + /* RPC-over-RDMA handles its own flow control. In addition, + * make all RNR NAKs visible so we know that RPC-over-RDMA + * flow control is working correctly (no NAKs should be seen). + */ + ep->re_remote_cma.flow_control = 0; + ep->re_remote_cma.rnr_retry_count = 0; + + ep->re_pd = ib_alloc_pd(device, 0); + if (IS_ERR(ep->re_pd)) { + rc = PTR_ERR(ep->re_pd); + ep->re_pd = NULL; + goto out_destroy; + } + + rc = rdma_create_qp(id, ep->re_pd, &ep->re_attr); + if (rc) + goto out_destroy; + + r_xprt->rx_ep = ep; + return 0; + +out_destroy: + rpcrdma_ep_put(ep); + rdma_destroy_id(id); + return rc; +} + +/** + * rpcrdma_xprt_connect - Connect an unconnected transport + * @r_xprt: controlling transport instance + * + * Returns 0 on success or a negative errno. 
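+ *
+ * On success, the transport has a connected QP, Receive buffers
+ * posted, and its sendctx, req, and MR pools populated.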
+ */ +int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt) +{ + struct rpc_xprt *xprt = &r_xprt->rx_xprt; + struct rpcrdma_ep *ep; + int rc; + + rc = rpcrdma_ep_create(r_xprt); + if (rc) + return rc; + ep = r_xprt->rx_ep; + + xprt_clear_connected(xprt); + rpcrdma_reset_cwnd(r_xprt); + + /* Bump the ep's reference count while there are + * outstanding Receives. + */ + rpcrdma_ep_get(ep); + rpcrdma_post_recvs(r_xprt, 1, true); + + rc = rdma_connect(ep->re_id, &ep->re_remote_cma); + if (rc) + goto out; + + if (xprt->reestablish_timeout < RPCRDMA_INIT_REEST_TO) + xprt->reestablish_timeout = RPCRDMA_INIT_REEST_TO; + wait_event_interruptible(ep->re_connect_wait, + ep->re_connect_status != 0); + if (ep->re_connect_status <= 0) { + rc = ep->re_connect_status; + goto out; + } + + rc = rpcrdma_sendctxs_create(r_xprt); + if (rc) { + rc = -ENOTCONN; + goto out; + } + + rc = rpcrdma_reqs_setup(r_xprt); + if (rc) { + rc = -ENOTCONN; + goto out; + } + rpcrdma_mrs_create(r_xprt); + frwr_wp_create(r_xprt); + +out: + trace_xprtrdma_connect(r_xprt, rc); + return rc; +} + +/** + * rpcrdma_xprt_disconnect - Disconnect underlying transport + * @r_xprt: controlling transport instance + * + * Caller serializes. Either the transport send lock is held, + * or we're being called to destroy the transport. + * + * On return, @r_xprt is completely divested of all hardware + * resources and prepared for the next ->connect operation. + */ +void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct rdma_cm_id *id; + int rc; + + if (!ep) + return; + + id = ep->re_id; + rc = rdma_disconnect(id); + trace_xprtrdma_disconnect(r_xprt, rc); + + rpcrdma_xprt_drain(r_xprt); + rpcrdma_reps_unmap(r_xprt); + rpcrdma_reqs_reset(r_xprt); + rpcrdma_mrs_destroy(r_xprt); + rpcrdma_sendctxs_destroy(r_xprt); + + if (rpcrdma_ep_put(ep)) + rdma_destroy_id(id); + + r_xprt->rx_ep = NULL; +} + +/* Fixed-size circular FIFO queue. This implementation is wait-free and + * lock-free. + * + * Consumer is the code path that posts Sends. This path dequeues a + * sendctx for use by a Send operation. Multiple consumer threads + * are serialized by the RPC transport lock, which allows only one + * ->send_request call at a time. + * + * Producer is the code path that handles Send completions. This path + * enqueues a sendctx that has been completed. Multiple producer + * threads are serialized by the ib_poll_cq() function. + */ + +/* rpcrdma_sendctxs_destroy() assumes caller has already quiesced + * queue activity, and rpcrdma_xprt_drain has flushed all remaining + * Send requests. 
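+ * No Send completion can still reference a sendctx at that point,
+ * so the array and its entries are freed without further locking.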
+ */ +static void rpcrdma_sendctxs_destroy(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + unsigned long i; + + if (!buf->rb_sc_ctxs) + return; + for (i = 0; i <= buf->rb_sc_last; i++) + kfree(buf->rb_sc_ctxs[i]); + kfree(buf->rb_sc_ctxs); + buf->rb_sc_ctxs = NULL; +} + +static struct rpcrdma_sendctx *rpcrdma_sendctx_create(struct rpcrdma_ep *ep) +{ + struct rpcrdma_sendctx *sc; + + sc = kzalloc(struct_size(sc, sc_sges, ep->re_attr.cap.max_send_sge), + XPRTRDMA_GFP_FLAGS); + if (!sc) + return NULL; + + sc->sc_cqe.done = rpcrdma_wc_send; + sc->sc_cid.ci_queue_id = ep->re_attr.send_cq->res.id; + sc->sc_cid.ci_completion_id = + atomic_inc_return(&ep->re_completion_ids); + return sc; +} + +static int rpcrdma_sendctxs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_sendctx *sc; + unsigned long i; + + /* Maximum number of concurrent outstanding Send WRs. Capping + * the circular queue size stops Send Queue overflow by causing + * the ->send_request call to fail temporarily before too many + * Sends are posted. + */ + i = r_xprt->rx_ep->re_max_requests + RPCRDMA_MAX_BC_REQUESTS; + buf->rb_sc_ctxs = kcalloc(i, sizeof(sc), XPRTRDMA_GFP_FLAGS); + if (!buf->rb_sc_ctxs) + return -ENOMEM; + + buf->rb_sc_last = i - 1; + for (i = 0; i <= buf->rb_sc_last; i++) { + sc = rpcrdma_sendctx_create(r_xprt->rx_ep); + if (!sc) + return -ENOMEM; + + buf->rb_sc_ctxs[i] = sc; + } + + buf->rb_sc_head = 0; + buf->rb_sc_tail = 0; + return 0; +} + +/* The sendctx queue is not guaranteed to have a size that is a + * power of two, thus the helpers in circ_buf.h cannot be used. + * The other option is to use modulus (%), which can be expensive. + */ +static unsigned long rpcrdma_sendctx_next(struct rpcrdma_buffer *buf, + unsigned long item) +{ + return likely(item < buf->rb_sc_last) ? item + 1 : 0; +} + +/** + * rpcrdma_sendctx_get_locked - Acquire a send context + * @r_xprt: controlling transport instance + * + * Returns pointer to a free send completion context; or NULL if + * the queue is empty. + * + * Usage: Called to acquire an SGE array before preparing a Send WR. + * + * The caller serializes calls to this function (per transport), and + * provides an effective memory barrier that flushes the new value + * of rb_sc_head. + */ +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_sendctx *sc; + unsigned long next_head; + + next_head = rpcrdma_sendctx_next(buf, buf->rb_sc_head); + + if (next_head == READ_ONCE(buf->rb_sc_tail)) + goto out_emptyq; + + /* ORDER: item must be accessed _before_ head is updated */ + sc = buf->rb_sc_ctxs[next_head]; + + /* Releasing the lock in the caller acts as a memory + * barrier that flushes rb_sc_head. + */ + buf->rb_sc_head = next_head; + + return sc; + +out_emptyq: + /* The queue is "empty" if there have not been enough Send + * completions recently. This is a sign the Send Queue is + * backing up. Cause the caller to pause and try again. + */ + xprt_wait_for_buffer_space(&r_xprt->rx_xprt); + r_xprt->rx_stats.empty_sendctx_q++; + return NULL; +} + +/** + * rpcrdma_sendctx_put_locked - Release a send context + * @r_xprt: controlling transport instance + * @sc: send context to release + * + * Usage: Called from Send completion to return a sendctxt + * to the queue. + * + * The caller serializes calls to this function (per transport). 
+ */ +static void rpcrdma_sendctx_put_locked(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_sendctx *sc) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + unsigned long next_tail; + + /* Unmap SGEs of previously completed but unsignaled + * Sends by walking up the queue until @sc is found. + */ + next_tail = buf->rb_sc_tail; + do { + next_tail = rpcrdma_sendctx_next(buf, next_tail); + + /* ORDER: item must be accessed _before_ tail is updated */ + rpcrdma_sendctx_unmap(buf->rb_sc_ctxs[next_tail]); + + } while (buf->rb_sc_ctxs[next_tail] != sc); + + /* Paired with READ_ONCE */ + smp_store_release(&buf->rb_sc_tail, next_tail); + + xprt_write_space(&r_xprt->rx_xprt); +} + +static void +rpcrdma_mrs_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_device *device = ep->re_id->device; + unsigned int count; + + /* Try to allocate enough to perform one full-sized I/O */ + for (count = 0; count < ep->re_max_rdma_segs; count++) { + struct rpcrdma_mr *mr; + int rc; + + mr = kzalloc_node(sizeof(*mr), XPRTRDMA_GFP_FLAGS, + ibdev_to_node(device)); + if (!mr) + break; + + rc = frwr_mr_init(r_xprt, mr); + if (rc) { + kfree(mr); + break; + } + + spin_lock(&buf->rb_lock); + rpcrdma_mr_push(mr, &buf->rb_mrs); + list_add(&mr->mr_all, &buf->rb_all_mrs); + spin_unlock(&buf->rb_lock); + } + + r_xprt->rx_stats.mrs_allocated += count; + trace_xprtrdma_createmrs(r_xprt, count); +} + +static void +rpcrdma_mr_refresh_worker(struct work_struct *work) +{ + struct rpcrdma_buffer *buf = container_of(work, struct rpcrdma_buffer, + rb_refresh_worker); + struct rpcrdma_xprt *r_xprt = container_of(buf, struct rpcrdma_xprt, + rx_buf); + + rpcrdma_mrs_create(r_xprt); + xprt_write_space(&r_xprt->rx_xprt); +} + +/** + * rpcrdma_mrs_refresh - Wake the MR refresh worker + * @r_xprt: controlling transport instance + * + */ +void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + + /* If there is no underlying connection, it's no use + * to wake the refresh worker. + */ + if (ep->re_connect_status != 1) + return; + queue_work(system_highpri_wq, &buf->rb_refresh_worker); +} + +/** + * rpcrdma_req_create - Allocate an rpcrdma_req object + * @r_xprt: controlling r_xprt + * @size: initial size, in bytes, of send and receive buffers + * + * Returns an allocated and fully initialized rpcrdma_req or NULL. + */ +struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt, + size_t size) +{ + struct rpcrdma_buffer *buffer = &r_xprt->rx_buf; + struct rpcrdma_req *req; + + req = kzalloc(sizeof(*req), XPRTRDMA_GFP_FLAGS); + if (req == NULL) + goto out1; + + req->rl_sendbuf = rpcrdma_regbuf_alloc(size, DMA_TO_DEVICE); + if (!req->rl_sendbuf) + goto out2; + + req->rl_recvbuf = rpcrdma_regbuf_alloc(size, DMA_NONE); + if (!req->rl_recvbuf) + goto out3; + + INIT_LIST_HEAD(&req->rl_free_mrs); + INIT_LIST_HEAD(&req->rl_registered); + spin_lock(&buffer->rb_lock); + list_add(&req->rl_all, &buffer->rb_allreqs); + spin_unlock(&buffer->rb_lock); + return req; + +out3: + rpcrdma_regbuf_free(req->rl_sendbuf); +out2: + kfree(req); +out1: + return NULL; +} + +/** + * rpcrdma_req_setup - Per-connection instance setup of an rpcrdma_req object + * @r_xprt: controlling transport instance + * @req: rpcrdma_req object to set up + * + * Returns zero on success, and a negative errno on failure. 
+ */ +int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req) +{ + struct rpcrdma_regbuf *rb; + size_t maxhdrsize; + + /* Compute maximum header buffer size in bytes */ + maxhdrsize = rpcrdma_fixed_maxsz + 3 + + r_xprt->rx_ep->re_max_rdma_segs * rpcrdma_readchunk_maxsz; + maxhdrsize *= sizeof(__be32); + rb = rpcrdma_regbuf_alloc(__roundup_pow_of_two(maxhdrsize), + DMA_TO_DEVICE); + if (!rb) + goto out; + + if (!__rpcrdma_regbuf_dma_map(r_xprt, rb)) + goto out_free; + + req->rl_rdmabuf = rb; + xdr_buf_init(&req->rl_hdrbuf, rdmab_data(rb), rdmab_length(rb)); + return 0; + +out_free: + rpcrdma_regbuf_free(rb); +out: + return -ENOMEM; +} + +/* ASSUMPTION: the rb_allreqs list is stable for the duration, + * and thus can be walked without holding rb_lock. Eg. the + * caller is holding the transport send lock to exclude + * device removal or disconnection. + */ +static int rpcrdma_reqs_setup(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + int rc; + + list_for_each_entry(req, &buf->rb_allreqs, rl_all) { + rc = rpcrdma_req_setup(r_xprt, req); + if (rc) + return rc; + } + return 0; +} + +static void rpcrdma_req_reset(struct rpcrdma_req *req) +{ + /* Credits are valid for only one connection */ + req->rl_slot.rq_cong = 0; + + rpcrdma_regbuf_free(req->rl_rdmabuf); + req->rl_rdmabuf = NULL; + + rpcrdma_regbuf_dma_unmap(req->rl_sendbuf); + rpcrdma_regbuf_dma_unmap(req->rl_recvbuf); + + frwr_reset(req); +} + +/* ASSUMPTION: the rb_allreqs list is stable for the duration, + * and thus can be walked without holding rb_lock. Eg. the + * caller is holding the transport send lock to exclude + * device removal or disconnection. + */ +static void rpcrdma_reqs_reset(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_req *req; + + list_for_each_entry(req, &buf->rb_allreqs, rl_all) + rpcrdma_req_reset(req); +} + +static noinline +struct rpcrdma_rep *rpcrdma_rep_create(struct rpcrdma_xprt *r_xprt, + bool temp) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_rep *rep; + + rep = kzalloc(sizeof(*rep), XPRTRDMA_GFP_FLAGS); + if (rep == NULL) + goto out; + + rep->rr_rdmabuf = rpcrdma_regbuf_alloc(r_xprt->rx_ep->re_inline_recv, + DMA_FROM_DEVICE); + if (!rep->rr_rdmabuf) + goto out_free; + + rep->rr_cid.ci_completion_id = + atomic_inc_return(&r_xprt->rx_ep->re_completion_ids); + + xdr_buf_init(&rep->rr_hdrbuf, rdmab_data(rep->rr_rdmabuf), + rdmab_length(rep->rr_rdmabuf)); + rep->rr_cqe.done = rpcrdma_wc_receive; + rep->rr_rxprt = r_xprt; + rep->rr_recv_wr.next = NULL; + rep->rr_recv_wr.wr_cqe = &rep->rr_cqe; + rep->rr_recv_wr.sg_list = &rep->rr_rdmabuf->rg_iov; + rep->rr_recv_wr.num_sge = 1; + rep->rr_temp = temp; + + spin_lock(&buf->rb_lock); + list_add(&rep->rr_all, &buf->rb_all_reps); + spin_unlock(&buf->rb_lock); + return rep; + +out_free: + kfree(rep); +out: + return NULL; +} + +static void rpcrdma_rep_free(struct rpcrdma_rep *rep) +{ + rpcrdma_regbuf_free(rep->rr_rdmabuf); + kfree(rep); +} + +static void rpcrdma_rep_destroy(struct rpcrdma_rep *rep) +{ + struct rpcrdma_buffer *buf = &rep->rr_rxprt->rx_buf; + + spin_lock(&buf->rb_lock); + list_del(&rep->rr_all); + spin_unlock(&buf->rb_lock); + + rpcrdma_rep_free(rep); +} + +static struct rpcrdma_rep *rpcrdma_rep_get_locked(struct rpcrdma_buffer *buf) +{ + struct llist_node *node; + + /* Calls to llist_del_first are required to be serialized */ + node = llist_del_first(&buf->rb_free_reps); + if (!node) + return NULL; + 
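+ /* Convert the llist node back into its containing rep; entries
+ * are returned to this lock-free list by rpcrdma_rep_put().
+ */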
return llist_entry(node, struct rpcrdma_rep, rr_node); +} + +/** + * rpcrdma_rep_put - Release rpcrdma_rep back to free list + * @buf: buffer pool + * @rep: rep to release + * + */ +void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep) +{ + llist_add(&rep->rr_node, &buf->rb_free_reps); +} + +/* Caller must ensure the QP is quiescent (RQ is drained) before + * invoking this function, to guarantee rb_all_reps is not + * changing. + */ +static void rpcrdma_reps_unmap(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_rep *rep; + + list_for_each_entry(rep, &buf->rb_all_reps, rr_all) { + rpcrdma_regbuf_dma_unmap(rep->rr_rdmabuf); + rep->rr_temp = true; /* Mark this rep for destruction */ + } +} + +static void rpcrdma_reps_destroy(struct rpcrdma_buffer *buf) +{ + struct rpcrdma_rep *rep; + + spin_lock(&buf->rb_lock); + while ((rep = list_first_entry_or_null(&buf->rb_all_reps, + struct rpcrdma_rep, + rr_all)) != NULL) { + list_del(&rep->rr_all); + spin_unlock(&buf->rb_lock); + + rpcrdma_rep_free(rep); + + spin_lock(&buf->rb_lock); + } + spin_unlock(&buf->rb_lock); +} + +/** + * rpcrdma_buffer_create - Create initial set of req/rep objects + * @r_xprt: transport instance to (re)initialize + * + * Returns zero on success, otherwise a negative errno. + */ +int rpcrdma_buffer_create(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + int i, rc; + + buf->rb_bc_srv_max_requests = 0; + spin_lock_init(&buf->rb_lock); + INIT_LIST_HEAD(&buf->rb_mrs); + INIT_LIST_HEAD(&buf->rb_all_mrs); + INIT_WORK(&buf->rb_refresh_worker, rpcrdma_mr_refresh_worker); + + INIT_LIST_HEAD(&buf->rb_send_bufs); + INIT_LIST_HEAD(&buf->rb_allreqs); + INIT_LIST_HEAD(&buf->rb_all_reps); + + rc = -ENOMEM; + for (i = 0; i < r_xprt->rx_xprt.max_reqs; i++) { + struct rpcrdma_req *req; + + req = rpcrdma_req_create(r_xprt, + RPCRDMA_V1_DEF_INLINE_SIZE * 2); + if (!req) + goto out; + list_add(&req->rl_list, &buf->rb_send_bufs); + } + + init_llist_head(&buf->rb_free_reps); + + return 0; +out: + rpcrdma_buffer_destroy(buf); + return rc; +} + +/** + * rpcrdma_req_destroy - Destroy an rpcrdma_req object + * @req: unused object to be destroyed + * + * Relies on caller holding the transport send lock to protect + * removing req->rl_all from buf->rb_all_reqs safely. + */ +void rpcrdma_req_destroy(struct rpcrdma_req *req) +{ + struct rpcrdma_mr *mr; + + list_del(&req->rl_all); + + while ((mr = rpcrdma_mr_pop(&req->rl_free_mrs))) { + struct rpcrdma_buffer *buf = &mr->mr_xprt->rx_buf; + + spin_lock(&buf->rb_lock); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); + + frwr_mr_release(mr); + } + + rpcrdma_regbuf_free(req->rl_recvbuf); + rpcrdma_regbuf_free(req->rl_sendbuf); + rpcrdma_regbuf_free(req->rl_rdmabuf); + kfree(req); +} + +/** + * rpcrdma_mrs_destroy - Release all of a transport's MRs + * @r_xprt: controlling transport instance + * + * Relies on caller holding the transport send lock to protect + * removing mr->mr_list from req->rl_free_mrs safely. 
+ */ +static void rpcrdma_mrs_destroy(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_mr *mr; + + cancel_work_sync(&buf->rb_refresh_worker); + + spin_lock(&buf->rb_lock); + while ((mr = list_first_entry_or_null(&buf->rb_all_mrs, + struct rpcrdma_mr, + mr_all)) != NULL) { + list_del(&mr->mr_list); + list_del(&mr->mr_all); + spin_unlock(&buf->rb_lock); + + frwr_mr_release(mr); + + spin_lock(&buf->rb_lock); + } + spin_unlock(&buf->rb_lock); +} + +/** + * rpcrdma_buffer_destroy - Release all hw resources + * @buf: root control block for resources + * + * ORDERING: relies on a prior rpcrdma_xprt_drain : + * - No more Send or Receive completions can occur + * - All MRs, reps, and reqs are returned to their free lists + */ +void +rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) +{ + rpcrdma_reps_destroy(buf); + + while (!list_empty(&buf->rb_send_bufs)) { + struct rpcrdma_req *req; + + req = list_first_entry(&buf->rb_send_bufs, + struct rpcrdma_req, rl_list); + list_del(&req->rl_list); + rpcrdma_req_destroy(req); + } +} + +/** + * rpcrdma_mr_get - Allocate an rpcrdma_mr object + * @r_xprt: controlling transport + * + * Returns an initialized rpcrdma_mr or NULL if no free + * rpcrdma_mr objects are available. + */ +struct rpcrdma_mr * +rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_mr *mr; + + spin_lock(&buf->rb_lock); + mr = rpcrdma_mr_pop(&buf->rb_mrs); + spin_unlock(&buf->rb_lock); + return mr; +} + +/** + * rpcrdma_reply_put - Put reply buffers back into pool + * @buffers: buffer pool + * @req: object to return + * + */ +void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req) +{ + if (req->rl_reply) { + rpcrdma_rep_put(buffers, req->rl_reply); + req->rl_reply = NULL; + } +} + +/** + * rpcrdma_buffer_get - Get a request buffer + * @buffers: Buffer pool from which to obtain a buffer + * + * Returns a fresh rpcrdma_req, or NULL if none are available. + */ +struct rpcrdma_req * +rpcrdma_buffer_get(struct rpcrdma_buffer *buffers) +{ + struct rpcrdma_req *req; + + spin_lock(&buffers->rb_lock); + req = list_first_entry_or_null(&buffers->rb_send_bufs, + struct rpcrdma_req, rl_list); + if (req) + list_del_init(&req->rl_list); + spin_unlock(&buffers->rb_lock); + return req; +} + +/** + * rpcrdma_buffer_put - Put request/reply buffers back into pool + * @buffers: buffer pool + * @req: object to return + * + */ +void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req) +{ + rpcrdma_reply_put(buffers, req); + + spin_lock(&buffers->rb_lock); + list_add(&req->rl_list, &buffers->rb_send_bufs); + spin_unlock(&buffers->rb_lock); +} + +/* Returns a pointer to a rpcrdma_regbuf object, or NULL. + * + * xprtrdma uses a regbuf for posting an outgoing RDMA SEND, or for + * receiving the payload of RDMA RECV operations. During Long Calls + * or Replies they may be registered externally via frwr_map. 
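+ * The backing memory is ordinary kmalloc'd memory; it is DMA-mapped
+ * separately via rpcrdma_regbuf_dma_map() before the buffer is
+ * first used.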
+ */ +static struct rpcrdma_regbuf * +rpcrdma_regbuf_alloc(size_t size, enum dma_data_direction direction) +{ + struct rpcrdma_regbuf *rb; + + rb = kmalloc(sizeof(*rb), XPRTRDMA_GFP_FLAGS); + if (!rb) + return NULL; + rb->rg_data = kmalloc(size, XPRTRDMA_GFP_FLAGS); + if (!rb->rg_data) { + kfree(rb); + return NULL; + } + + rb->rg_device = NULL; + rb->rg_direction = direction; + rb->rg_iov.length = size; + return rb; +} + +/** + * rpcrdma_regbuf_realloc - re-allocate a SEND/RECV buffer + * @rb: regbuf to reallocate + * @size: size of buffer to be allocated, in bytes + * @flags: GFP flags + * + * Returns true if reallocation was successful. If false is + * returned, @rb is left untouched. + */ +bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size, gfp_t flags) +{ + void *buf; + + buf = kmalloc(size, flags); + if (!buf) + return false; + + rpcrdma_regbuf_dma_unmap(rb); + kfree(rb->rg_data); + + rb->rg_data = buf; + rb->rg_iov.length = size; + return true; +} + +/** + * __rpcrdma_regbuf_dma_map - DMA-map a regbuf + * @r_xprt: controlling transport instance + * @rb: regbuf to be mapped + * + * Returns true if the buffer is now DMA mapped to @r_xprt's device + */ +bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb) +{ + struct ib_device *device = r_xprt->rx_ep->re_id->device; + + if (rb->rg_direction == DMA_NONE) + return false; + + rb->rg_iov.addr = ib_dma_map_single(device, rdmab_data(rb), + rdmab_length(rb), rb->rg_direction); + if (ib_dma_mapping_error(device, rdmab_addr(rb))) { + trace_xprtrdma_dma_maperr(rdmab_addr(rb)); + return false; + } + + rb->rg_device = device; + rb->rg_iov.lkey = r_xprt->rx_ep->re_pd->local_dma_lkey; + return true; +} + +static void rpcrdma_regbuf_dma_unmap(struct rpcrdma_regbuf *rb) +{ + if (!rb) + return; + + if (!rpcrdma_regbuf_is_mapped(rb)) + return; + + ib_dma_unmap_single(rb->rg_device, rdmab_addr(rb), rdmab_length(rb), + rb->rg_direction); + rb->rg_device = NULL; +} + +static void rpcrdma_regbuf_free(struct rpcrdma_regbuf *rb) +{ + rpcrdma_regbuf_dma_unmap(rb); + if (rb) + kfree(rb->rg_data); + kfree(rb); +} + +/** + * rpcrdma_post_recvs - Refill the Receive Queue + * @r_xprt: controlling transport instance + * @needed: current credit grant + * @temp: mark Receive buffers to be deleted after one use + * + */ +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp) +{ + struct rpcrdma_buffer *buf = &r_xprt->rx_buf; + struct rpcrdma_ep *ep = r_xprt->rx_ep; + struct ib_recv_wr *wr, *bad_wr; + struct rpcrdma_rep *rep; + int count, rc; + + rc = 0; + count = 0; + + if (likely(ep->re_receive_count > needed)) + goto out; + needed -= ep->re_receive_count; + if (!temp) + needed += RPCRDMA_MAX_RECV_BATCH; + + if (atomic_inc_return(&ep->re_receiving) > 1) + goto out; + + /* fast path: all needed reps can be found on the free list */ + wr = NULL; + while (needed) { + rep = rpcrdma_rep_get_locked(buf); + if (rep && rep->rr_temp) { + rpcrdma_rep_destroy(rep); + continue; + } + if (!rep) + rep = rpcrdma_rep_create(r_xprt, temp); + if (!rep) + break; + if (!rpcrdma_regbuf_dma_map(r_xprt, rep->rr_rdmabuf)) { + rpcrdma_rep_put(buf, rep); + break; + } + + rep->rr_cid.ci_queue_id = ep->re_attr.recv_cq->res.id; + trace_xprtrdma_post_recv(rep); + rep->rr_recv_wr.next = wr; + wr = &rep->rr_recv_wr; + --needed; + ++count; + } + if (!wr) + goto out; + + rc = ib_post_recv(ep->re_id->qp, wr, + (const struct ib_recv_wr **)&bad_wr); + if (rc) { + trace_xprtrdma_post_recvs_err(r_xprt, rc); + for (wr = bad_wr; wr;) { + 
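+ /* Receives from bad_wr onward were not posted; put
+ * their reps back and do not count them as posted.
+ */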
struct rpcrdma_rep *rep; + + rep = container_of(wr, struct rpcrdma_rep, rr_recv_wr); + wr = wr->next; + rpcrdma_rep_put(buf, rep); + --count; + } + } + if (atomic_dec_return(&ep->re_receiving) > 0) + complete(&ep->re_done); + +out: + trace_xprtrdma_post_recvs(r_xprt, count); + ep->re_receive_count += count; + return; +} diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h new file mode 100644 index 000000000..5e5ff6784 --- /dev/null +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -0,0 +1,605 @@ +/* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */ +/* + * Copyright (c) 2014-2017 Oracle. All rights reserved. + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _LINUX_SUNRPC_XPRT_RDMA_H +#define _LINUX_SUNRPC_XPRT_RDMA_H + +#include /* wait_queue_head_t, etc */ +#include /* spinlock_t, etc */ +#include /* atomic_t, etc */ +#include /* struct kref */ +#include /* struct work_struct */ +#include + +#include /* RDMA connection api */ +#include /* RDMA verbs api */ + +#include /* rpc_xprt */ +#include /* completion IDs */ +#include /* RPC/RDMA protocol */ +#include /* xprt parameters */ + +#define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */ +#define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */ + +#define RPCRDMA_BIND_TO (60U * HZ) +#define RPCRDMA_INIT_REEST_TO (5U * HZ) +#define RPCRDMA_MAX_REEST_TO (30U * HZ) +#define RPCRDMA_IDLE_DISC_TO (5U * 60 * HZ) + +/* + * RDMA Endpoint -- connection endpoint details + */ +struct rpcrdma_mr; +struct rpcrdma_ep { + struct kref re_kref; + struct rdma_cm_id *re_id; + struct ib_pd *re_pd; + unsigned int re_max_rdma_segs; + unsigned int re_max_fr_depth; + struct rpcrdma_mr *re_write_pad_mr; + enum ib_mr_type re_mrtype; + struct completion re_done; + unsigned int re_send_count; + unsigned int re_send_batch; + unsigned int re_max_inline_send; + unsigned int re_max_inline_recv; + int re_async_rc; + int re_connect_status; + atomic_t re_receiving; + atomic_t re_force_disconnect; + struct ib_qp_init_attr re_attr; + wait_queue_head_t re_connect_wait; + struct rpc_xprt *re_xprt; + struct rpcrdma_connect_private + re_cm_private; + struct rdma_conn_param re_remote_cma; + int re_receive_count; + unsigned int re_max_requests; /* depends on device */ + unsigned int re_inline_send; /* negotiated */ + unsigned int re_inline_recv; /* negotiated */ + + atomic_t re_completion_ids; + + char re_write_pad[XDR_UNIT]; +}; + +/* Pre-allocate extra Work Requests for handling reverse-direction + * Receives and Sends. This is a fixed value because the Work Queues + * are allocated when the forward channel is set up, long before the + * backchannel is provisioned. This value is two times + * NFS4_DEF_CB_SLOT_TABLE_SIZE. + */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +#define RPCRDMA_BACKWARD_WRS (32) +#else +#define RPCRDMA_BACKWARD_WRS (0) +#endif + +/* Registered buffer -- registered kmalloc'd memory for RDMA SEND/RECV + */ + +struct rpcrdma_regbuf { + struct ib_sge rg_iov; + struct ib_device *rg_device; + enum dma_data_direction rg_direction; + void *rg_data; +}; + +static inline u64 rdmab_addr(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.addr; +} + +static inline u32 rdmab_length(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.length; +} + +static inline u32 rdmab_lkey(struct rpcrdma_regbuf *rb) +{ + return rb->rg_iov.lkey; +} + +static inline struct ib_device *rdmab_device(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device; +} + +static inline void *rdmab_data(const struct rpcrdma_regbuf *rb) +{ + return rb->rg_data; +} + +/* Do not use emergency memory reserves, and fail quickly if memory + * cannot be allocated easily. These flags may be used wherever there + * is robust logic to handle a failure to allocate. + */ +#define XPRTRDMA_GFP_FLAGS (__GFP_NOMEMALLOC | __GFP_NORETRY | __GFP_NOWARN) + +/* To ensure a transport can always make forward progress, + * the number of RDMA segments allowed in header chunk lists + * is capped at 16. This prevents less-capable devices from + * overrunning the Send buffer while building chunk lists. + * + * Elements of the Read list take up more room than the + * Write list or Reply chunk. 
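+ * (Each read segment carries a position field in addition to the
+ * handle, length, and offset fields of a plain segment, which is
+ * why it is larger.)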
16 read segments means the + * chunk lists cannot consume more than + * + * ((16 + 2) * read segment size) + 1 XDR words, + * + * or about 400 bytes. The fixed part of the header is + * another 24 bytes. Thus when the inline threshold is + * 1024 bytes, at least 600 bytes are available for RPC + * message bodies. + */ +enum { + RPCRDMA_MAX_HDR_SEGS = 16, +}; + +/* + * struct rpcrdma_rep -- this structure encapsulates state required + * to receive and complete an RPC Reply, asychronously. It needs + * several pieces of state: + * + * o receive buffer and ib_sge (donated to provider) + * o status of receive (success or not, length, inv rkey) + * o bookkeeping state to get run by reply handler (XDR stream) + * + * These structures are allocated during transport initialization. + * N of these are associated with a transport instance, managed by + * struct rpcrdma_buffer. N is the max number of outstanding RPCs. + */ + +struct rpcrdma_rep { + struct ib_cqe rr_cqe; + struct rpc_rdma_cid rr_cid; + + __be32 rr_xid; + __be32 rr_vers; + __be32 rr_proc; + int rr_wc_flags; + u32 rr_inv_rkey; + bool rr_temp; + struct rpcrdma_regbuf *rr_rdmabuf; + struct rpcrdma_xprt *rr_rxprt; + struct rpc_rqst *rr_rqst; + struct xdr_buf rr_hdrbuf; + struct xdr_stream rr_stream; + struct llist_node rr_node; + struct ib_recv_wr rr_recv_wr; + struct list_head rr_all; +}; + +/* To reduce the rate at which a transport invokes ib_post_recv + * (and thus the hardware doorbell rate), xprtrdma posts Receive + * WRs in batches. + * + * Setting this to zero disables Receive post batching. + */ +enum { + RPCRDMA_MAX_RECV_BATCH = 7, +}; + +/* struct rpcrdma_sendctx - DMA mapped SGEs to unmap after Send completes + */ +struct rpcrdma_req; +struct rpcrdma_sendctx { + struct ib_cqe sc_cqe; + struct rpc_rdma_cid sc_cid; + struct rpcrdma_req *sc_req; + unsigned int sc_unmap_count; + struct ib_sge sc_sges[]; +}; + +/* + * struct rpcrdma_mr - external memory region metadata + * + * An external memory region is any buffer or page that is registered + * on the fly (ie, not pre-registered). + */ +struct rpcrdma_req; +struct rpcrdma_mr { + struct list_head mr_list; + struct rpcrdma_req *mr_req; + + struct ib_mr *mr_ibmr; + struct ib_device *mr_device; + struct scatterlist *mr_sg; + int mr_nents; + enum dma_data_direction mr_dir; + struct ib_cqe mr_cqe; + struct completion mr_linv_done; + union { + struct ib_reg_wr mr_regwr; + struct ib_send_wr mr_invwr; + }; + struct rpcrdma_xprt *mr_xprt; + u32 mr_handle; + u32 mr_length; + u64 mr_offset; + struct list_head mr_all; + struct rpc_rdma_cid mr_cid; +}; + +/* + * struct rpcrdma_req -- structure central to the request/reply sequence. + * + * N of these are associated with a transport instance, and stored in + * struct rpcrdma_buffer. N is the max number of outstanding requests. + * + * It includes pre-registered buffer memory for send AND recv. + * The recv buffer, however, is not owned by this structure, and + * is "donated" to the hardware when a recv is posted. When a + * reply is handled, the recv buffer used is given back to the + * struct rpcrdma_req associated with the request. + * + * In addition to the basic memory, this structure includes an array + * of iovs for send operations. The reason is that the iovs passed to + * ib_post_{send,recv} must not be modified until the work request + * completes. + */ + +/* Maximum number of page-sized "segments" per chunk list to be + * registered or invalidated. 
Must handle a Reply chunk: + */ +enum { + RPCRDMA_MAX_IOV_SEGS = 3, + RPCRDMA_MAX_DATA_SEGS = ((1 * 1024 * 1024) / PAGE_SIZE) + 1, + RPCRDMA_MAX_SEGS = RPCRDMA_MAX_DATA_SEGS + + RPCRDMA_MAX_IOV_SEGS, +}; + +/* Arguments for DMA mapping and registration */ +struct rpcrdma_mr_seg { + u32 mr_len; /* length of segment */ + struct page *mr_page; /* underlying struct page */ + u64 mr_offset; /* IN: page offset, OUT: iova */ +}; + +/* The Send SGE array is provisioned to send a maximum size + * inline request: + * - RPC-over-RDMA header + * - xdr_buf head iovec + * - RPCRDMA_MAX_INLINE bytes, in pages + * - xdr_buf tail iovec + * + * The actual number of array elements consumed by each RPC + * depends on the device's max_sge limit. + */ +enum { + RPCRDMA_MIN_SEND_SGES = 3, + RPCRDMA_MAX_PAGE_SGES = RPCRDMA_MAX_INLINE >> PAGE_SHIFT, + RPCRDMA_MAX_SEND_SGES = 1 + 1 + RPCRDMA_MAX_PAGE_SGES + 1, +}; + +struct rpcrdma_buffer; +struct rpcrdma_req { + struct list_head rl_list; + struct rpc_rqst rl_slot; + struct rpcrdma_rep *rl_reply; + struct xdr_stream rl_stream; + struct xdr_buf rl_hdrbuf; + struct ib_send_wr rl_wr; + struct rpcrdma_sendctx *rl_sendctx; + struct rpcrdma_regbuf *rl_rdmabuf; /* xprt header */ + struct rpcrdma_regbuf *rl_sendbuf; /* rq_snd_buf */ + struct rpcrdma_regbuf *rl_recvbuf; /* rq_rcv_buf */ + + struct list_head rl_all; + struct kref rl_kref; + + struct list_head rl_free_mrs; + struct list_head rl_registered; + struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS]; +}; + +static inline struct rpcrdma_req * +rpcr_to_rdmar(const struct rpc_rqst *rqst) +{ + return container_of(rqst, struct rpcrdma_req, rl_slot); +} + +static inline void +rpcrdma_mr_push(struct rpcrdma_mr *mr, struct list_head *list) +{ + list_add(&mr->mr_list, list); +} + +static inline struct rpcrdma_mr * +rpcrdma_mr_pop(struct list_head *list) +{ + struct rpcrdma_mr *mr; + + mr = list_first_entry_or_null(list, struct rpcrdma_mr, mr_list); + if (mr) + list_del_init(&mr->mr_list); + return mr; +} + +/* + * struct rpcrdma_buffer -- holds list/queue of pre-registered memory for + * inline requests/replies, and client/server credits. 
+ * + * One of these is associated with a transport instance + */ +struct rpcrdma_buffer { + spinlock_t rb_lock; + struct list_head rb_send_bufs; + struct list_head rb_mrs; + + unsigned long rb_sc_head; + unsigned long rb_sc_tail; + unsigned long rb_sc_last; + struct rpcrdma_sendctx **rb_sc_ctxs; + + struct list_head rb_allreqs; + struct list_head rb_all_mrs; + struct list_head rb_all_reps; + + struct llist_head rb_free_reps; + + __be32 rb_max_requests; + u32 rb_credits; /* most recent credit grant */ + + u32 rb_bc_srv_max_requests; + u32 rb_bc_max_requests; + + struct work_struct rb_refresh_worker; +}; + +/* + * Statistics for RPCRDMA + */ +struct rpcrdma_stats { + /* accessed when sending a call */ + unsigned long read_chunk_count; + unsigned long write_chunk_count; + unsigned long reply_chunk_count; + unsigned long long total_rdma_request; + + /* rarely accessed error counters */ + unsigned long long pullup_copy_count; + unsigned long hardway_register_count; + unsigned long failed_marshal_count; + unsigned long bad_reply_count; + unsigned long mrs_recycled; + unsigned long mrs_orphaned; + unsigned long mrs_allocated; + unsigned long empty_sendctx_q; + + /* accessed when receiving a reply */ + unsigned long long total_rdma_reply; + unsigned long long fixup_copy_count; + unsigned long reply_waits_for_send; + unsigned long local_inv_needed; + unsigned long nomsg_call_count; + unsigned long bcall_count; +}; + +/* + * RPCRDMA transport -- encapsulates the structures above for + * integration with RPC. + * + * The contained structures are embedded, not pointers, + * for convenience. This structure need not be visible externally. + * + * It is allocated and initialized during mount, and released + * during unmount. + */ +struct rpcrdma_xprt { + struct rpc_xprt rx_xprt; + struct rpcrdma_ep *rx_ep; + struct rpcrdma_buffer rx_buf; + struct delayed_work rx_connect_worker; + struct rpc_timeout rx_timeout; + struct rpcrdma_stats rx_stats; +}; + +#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, rx_xprt) + +static inline const char * +rpcrdma_addrstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_ADDR]; +} + +static inline const char * +rpcrdma_portstr(const struct rpcrdma_xprt *r_xprt) +{ + return r_xprt->rx_xprt.address_strings[RPC_DISPLAY_PORT]; +} + +/* Setting this to 0 ensures interoperability with early servers. + * Setting this to 1 enhances certain unaligned read/write performance. + * Default is 0, see sysctl entry and rpc_rdma.c rpcrdma_convert_iovs() */ +extern int xprt_rdma_pad_optimize; + +/* This setting controls the hunt for a supported memory + * registration strategy. 
+ */ +extern unsigned int xprt_rdma_memreg_strategy; + +/* + * Endpoint calls - xprtrdma/verbs.c + */ +void rpcrdma_force_disconnect(struct rpcrdma_ep *ep); +void rpcrdma_flush_disconnect(struct rpcrdma_xprt *r_xprt, struct ib_wc *wc); +int rpcrdma_xprt_connect(struct rpcrdma_xprt *r_xprt); +void rpcrdma_xprt_disconnect(struct rpcrdma_xprt *r_xprt); + +void rpcrdma_post_recvs(struct rpcrdma_xprt *r_xprt, int needed, bool temp); + +/* + * Buffer calls - xprtrdma/verbs.c + */ +struct rpcrdma_req *rpcrdma_req_create(struct rpcrdma_xprt *r_xprt, + size_t size); +int rpcrdma_req_setup(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void rpcrdma_req_destroy(struct rpcrdma_req *req); +int rpcrdma_buffer_create(struct rpcrdma_xprt *); +void rpcrdma_buffer_destroy(struct rpcrdma_buffer *); +struct rpcrdma_sendctx *rpcrdma_sendctx_get_locked(struct rpcrdma_xprt *r_xprt); + +struct rpcrdma_mr *rpcrdma_mr_get(struct rpcrdma_xprt *r_xprt); +void rpcrdma_mrs_refresh(struct rpcrdma_xprt *r_xprt); + +struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *); +void rpcrdma_buffer_put(struct rpcrdma_buffer *buffers, + struct rpcrdma_req *req); +void rpcrdma_rep_put(struct rpcrdma_buffer *buf, struct rpcrdma_rep *rep); +void rpcrdma_reply_put(struct rpcrdma_buffer *buffers, struct rpcrdma_req *req); + +bool rpcrdma_regbuf_realloc(struct rpcrdma_regbuf *rb, size_t size, + gfp_t flags); +bool __rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb); + +/** + * rpcrdma_regbuf_is_mapped - check if buffer is DMA mapped + * + * Returns true if the buffer is now mapped to rb->rg_device. + */ +static inline bool rpcrdma_regbuf_is_mapped(struct rpcrdma_regbuf *rb) +{ + return rb->rg_device != NULL; +} + +/** + * rpcrdma_regbuf_dma_map - DMA-map a regbuf + * @r_xprt: controlling transport instance + * @rb: regbuf to be mapped + * + * Returns true if the buffer is currently DMA mapped. + */ +static inline bool rpcrdma_regbuf_dma_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_regbuf *rb) +{ + if (likely(rpcrdma_regbuf_is_mapped(rb))) + return true; + return __rpcrdma_regbuf_dma_map(r_xprt, rb); +} + +/* + * Wrappers for chunk registration, shared by read/write chunk code. + */ + +static inline enum dma_data_direction +rpcrdma_data_dir(bool writing) +{ + return writing ? 
DMA_FROM_DEVICE : DMA_TO_DEVICE; +} + +/* Memory registration calls xprtrdma/frwr_ops.c + */ +void frwr_reset(struct rpcrdma_req *req); +int frwr_query_device(struct rpcrdma_ep *ep, const struct ib_device *device); +int frwr_mr_init(struct rpcrdma_xprt *r_xprt, struct rpcrdma_mr *mr); +void frwr_mr_release(struct rpcrdma_mr *mr); +struct rpcrdma_mr_seg *frwr_map(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_mr_seg *seg, + int nsegs, bool writing, __be32 xid, + struct rpcrdma_mr *mr); +int frwr_send(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void frwr_reminv(struct rpcrdma_rep *rep, struct list_head *mrs); +void frwr_unmap_sync(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +void frwr_unmap_async(struct rpcrdma_xprt *r_xprt, struct rpcrdma_req *req); +int frwr_wp_create(struct rpcrdma_xprt *r_xprt); + +/* + * RPC/RDMA protocol calls - xprtrdma/rpc_rdma.c + */ + +enum rpcrdma_chunktype { + rpcrdma_noch = 0, + rpcrdma_noch_pullup, + rpcrdma_noch_mapped, + rpcrdma_readch, + rpcrdma_areadch, + rpcrdma_writech, + rpcrdma_replych +}; + +int rpcrdma_prepare_send_sges(struct rpcrdma_xprt *r_xprt, + struct rpcrdma_req *req, u32 hdrlen, + struct xdr_buf *xdr, + enum rpcrdma_chunktype rtype); +void rpcrdma_sendctx_unmap(struct rpcrdma_sendctx *sc); +int rpcrdma_marshal_req(struct rpcrdma_xprt *r_xprt, struct rpc_rqst *rqst); +void rpcrdma_set_max_header_sizes(struct rpcrdma_ep *ep); +void rpcrdma_reset_cwnd(struct rpcrdma_xprt *r_xprt); +void rpcrdma_complete_rqst(struct rpcrdma_rep *rep); +void rpcrdma_unpin_rqst(struct rpcrdma_rep *rep); +void rpcrdma_reply_handler(struct rpcrdma_rep *rep); + +static inline void rpcrdma_set_xdrlen(struct xdr_buf *xdr, size_t len) +{ + xdr->head[0].iov_len = len; + xdr->len = len; +} + +/* RPC/RDMA module init - xprtrdma/transport.c + */ +extern unsigned int xprt_rdma_max_inline_read; +extern unsigned int xprt_rdma_max_inline_write; +void xprt_rdma_format_addresses(struct rpc_xprt *xprt, struct sockaddr *sap); +void xprt_rdma_free_addresses(struct rpc_xprt *xprt); +void xprt_rdma_close(struct rpc_xprt *xprt); +void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq); +int xprt_rdma_init(void); +void xprt_rdma_cleanup(void); + +/* Backchannel calls - xprtrdma/backchannel.c + */ +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +int xprt_rdma_bc_setup(struct rpc_xprt *, unsigned int); +size_t xprt_rdma_bc_maxpayload(struct rpc_xprt *); +unsigned int xprt_rdma_bc_max_slots(struct rpc_xprt *); +int rpcrdma_bc_post_recv(struct rpcrdma_xprt *, unsigned int); +void rpcrdma_bc_receive_call(struct rpcrdma_xprt *, struct rpcrdma_rep *); +int xprt_rdma_bc_send_reply(struct rpc_rqst *rqst); +void xprt_rdma_bc_free_rqst(struct rpc_rqst *); +void xprt_rdma_bc_destroy(struct rpc_xprt *, unsigned int); +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +extern struct xprt_class xprt_rdma_bc; + +#endif /* _LINUX_SUNRPC_XPRT_RDMA_H */ diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c new file mode 100644 index 000000000..05aa32696 --- /dev/null +++ b/net/sunrpc/xprtsock.c @@ -0,0 +1,3257 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * linux/net/sunrpc/xprtsock.c + * + * Client-side transport implementation for sockets. + * + * TCP callback races fixes (C) 1998 Red Hat + * TCP send fixes (C) 1998 Red Hat + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland + * + * Rewrite of larges part of the code in order to stabilize TCP stuff. + * Fix behaviour when socket buffer is full. 
+ * (C) 1999 Trond Myklebust + * + * IP socket transport implementation, (C) 2005 Chuck Lever + * + * IPv6 support contributed by Gilles Quillard, Bull Open Source, 2005. + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef CONFIG_SUNRPC_BACKCHANNEL +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "socklib.h" +#include "sunrpc.h" + +static void xs_close(struct rpc_xprt *xprt); +static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock); +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock); + +/* + * xprtsock tunables + */ +static unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE; +static unsigned int xprt_tcp_slot_table_entries = RPC_MIN_SLOT_TABLE; +static unsigned int xprt_max_tcp_slot_table_entries = RPC_MAX_SLOT_TABLE; + +static unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT; +static unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT; + +#define XS_TCP_LINGER_TO (15U * HZ) +static unsigned int xs_tcp_fin_timeout __read_mostly = XS_TCP_LINGER_TO; + +/* + * We can register our own files under /proc/sys/sunrpc by + * calling register_sysctl_table() again. The files in that + * directory become the union of all files registered there. + * + * We simply need to make sure that we don't collide with + * someone else's file names! + */ + +static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE; +static unsigned int max_tcp_slot_table_limit = RPC_MAX_SLOT_TABLE_LIMIT; +static unsigned int xprt_min_resvport_limit = RPC_MIN_RESVPORT; +static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT; + +static struct ctl_table_header *sunrpc_table_header; + +static struct xprt_class xs_local_transport; +static struct xprt_class xs_udp_transport; +static struct xprt_class xs_tcp_transport; +static struct xprt_class xs_bc_tcp_transport; + +/* + * FIXME: changing the UDP slot table size should also resize the UDP + * socket buffers for existing UDP transports + */ +static struct ctl_table xs_tunables_table[] = { + { + .procname = "udp_slot_table_entries", + .data = &xprt_udp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "tcp_slot_table_entries", + .data = &xprt_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .procname = "tcp_max_slot_table_entries", + .data = &xprt_max_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_slot_table_size, + .extra2 = &max_tcp_slot_table_limit + }, + { + .procname = "min_resvport", + .data = &xprt_min_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .procname = "max_resvport", + .data = &xprt_max_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .procname = 
"tcp_fin_timeout", + .data = &xs_tcp_fin_timeout, + .maxlen = sizeof(xs_tcp_fin_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { }, +}; + +static struct ctl_table sunrpc_table[] = { + { + .procname = "sunrpc", + .mode = 0555, + .child = xs_tunables_table + }, + { }, +}; + +/* + * Wait duration for a reply from the RPC portmapper. + */ +#define XS_BIND_TO (60U * HZ) + +/* + * Delay if a UDP socket connect error occurs. This is most likely some + * kind of resource problem on the local host. + */ +#define XS_UDP_REEST_TO (2U * HZ) + +/* + * The reestablish timeout allows clients to delay for a bit before attempting + * to reconnect to a server that just dropped our connection. + * + * We implement an exponential backoff when trying to reestablish a TCP + * transport connection with the server. Some servers like to drop a TCP + * connection when they are overworked, so we start with a short timeout and + * increase over time if the server is down or not responding. + */ +#define XS_TCP_INIT_REEST_TO (3U * HZ) + +/* + * TCP idle timeout; client drops the transport socket if it is idle + * for this long. Note that we also timeout UDP sockets to prevent + * holding port numbers when there is no RPC traffic. + */ +#define XS_IDLE_DISC_TO (5U * 60 * HZ) + +#if IS_ENABLED(CONFIG_SUNRPC_DEBUG) +# undef RPC_DEBUG_DATA +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +#ifdef RPC_DEBUG_DATA +static void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + u8 *buf = (u8 *) packet; + int j; + + dprintk("RPC: %s\n", msg); + for (j = 0; j < count && j < 128; j += 4) { + if (!(j & 31)) { + if (j) + dprintk("\n"); + dprintk("0x%04x ", j); + } + dprintk("%02x%02x%02x%02x ", + buf[j], buf[j+1], buf[j+2], buf[j+3]); + } + dprintk("\n"); +} +#else +static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + /* NOP */ +} +#endif + +static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +{ + return (struct rpc_xprt *) sk->sk_user_data; +} + +static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt) +{ + return (struct sockaddr *) &xprt->addr; +} + +static inline struct sockaddr_un *xs_addr_un(struct rpc_xprt *xprt) +{ + return (struct sockaddr_un *) &xprt->addr; +} + +static inline struct sockaddr_in *xs_addr_in(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in *) &xprt->addr; +} + +static inline struct sockaddr_in6 *xs_addr_in6(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in6 *) &xprt->addr; +} + +static void xs_format_common_peer_addresses(struct rpc_xprt *xprt) +{ + struct sockaddr *sap = xs_addr(xprt); + struct sockaddr_in6 *sin6; + struct sockaddr_in *sin; + struct sockaddr_un *sun; + char buf[128]; + + switch (sap->sa_family) { + case AF_LOCAL: + sun = xs_addr_un(xprt); + strscpy(buf, sun->sun_path, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + break; + case AF_INET: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + sin = xs_addr_in(xprt); + snprintf(buf, sizeof(buf), "%08x", ntohl(sin->sin_addr.s_addr)); + break; + case AF_INET6: + (void)rpc_ntop(sap, buf, sizeof(buf)); + xprt->address_strings[RPC_DISPLAY_ADDR] = + kstrdup(buf, GFP_KERNEL); + sin6 = xs_addr_in6(xprt); + snprintf(buf, sizeof(buf), "%pi6", &sin6->sin6_addr); + break; + default: + BUG(); + } + + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = kstrdup(buf, GFP_KERNEL); +} + +static void xs_format_common_peer_ports(struct rpc_xprt *xprt) +{ + struct sockaddr *sap = 
xs_addr(xprt); + char buf[128]; + + snprintf(buf, sizeof(buf), "%u", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_PORT] = kstrdup(buf, GFP_KERNEL); + + snprintf(buf, sizeof(buf), "%4hx", rpc_get_port(sap)); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = kstrdup(buf, GFP_KERNEL); +} + +static void xs_format_peer_addresses(struct rpc_xprt *xprt, + const char *protocol, + const char *netid) +{ + xprt->address_strings[RPC_DISPLAY_PROTO] = protocol; + xprt->address_strings[RPC_DISPLAY_NETID] = netid; + xs_format_common_peer_addresses(xprt); + xs_format_common_peer_ports(xprt); +} + +static void xs_update_peer_port(struct rpc_xprt *xprt) +{ + kfree(xprt->address_strings[RPC_DISPLAY_HEX_PORT]); + kfree(xprt->address_strings[RPC_DISPLAY_PORT]); + + xs_format_common_peer_ports(xprt); +} + +static void xs_free_peer_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +static size_t +xs_alloc_sparse_pages(struct xdr_buf *buf, size_t want, gfp_t gfp) +{ + size_t i,n; + + if (!want || !(buf->flags & XDRBUF_SPARSE_PAGES)) + return want; + n = (buf->page_base + want + PAGE_SIZE - 1) >> PAGE_SHIFT; + for (i = 0; i < n; i++) { + if (buf->pages[i]) + continue; + buf->bvec[i].bv_page = buf->pages[i] = alloc_page(gfp); + if (!buf->pages[i]) { + i *= PAGE_SIZE; + return i > buf->page_base ? i - buf->page_base : 0; + } + } + return want; +} + +static ssize_t +xs_sock_recvmsg(struct socket *sock, struct msghdr *msg, int flags, size_t seek) +{ + ssize_t ret; + if (seek != 0) + iov_iter_advance(&msg->msg_iter, seek); + ret = sock_recvmsg(sock, msg, flags); + return ret > 0 ? ret + seek : ret; +} + +static ssize_t +xs_read_kvec(struct socket *sock, struct msghdr *msg, int flags, + struct kvec *kvec, size_t count, size_t seek) +{ + iov_iter_kvec(&msg->msg_iter, ITER_DEST, kvec, 1, count); + return xs_sock_recvmsg(sock, msg, flags, seek); +} + +static ssize_t +xs_read_bvec(struct socket *sock, struct msghdr *msg, int flags, + struct bio_vec *bvec, unsigned long nr, size_t count, + size_t seek) +{ + iov_iter_bvec(&msg->msg_iter, ITER_DEST, bvec, nr, count); + return xs_sock_recvmsg(sock, msg, flags, seek); +} + +static ssize_t +xs_read_discard(struct socket *sock, struct msghdr *msg, int flags, + size_t count) +{ + iov_iter_discard(&msg->msg_iter, ITER_DEST, count); + return sock_recvmsg(sock, msg, flags); +} + +#if ARCH_IMPLEMENTS_FLUSH_DCACHE_PAGE +static void +xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek) +{ + struct bvec_iter bi = { + .bi_size = count, + }; + struct bio_vec bv; + + bvec_iter_advance(bvec, &bi, seek & PAGE_MASK); + for_each_bvec(bv, bvec, bi, bi) + flush_dcache_page(bv.bv_page); +} +#else +static inline void +xs_flush_bvec(const struct bio_vec *bvec, size_t count, size_t seek) +{ +} +#endif + +static ssize_t +xs_read_xdr_buf(struct socket *sock, struct msghdr *msg, int flags, + struct xdr_buf *buf, size_t count, size_t seek, size_t *read) +{ + size_t want, seek_init = seek, offset = 0; + ssize_t ret; + + want = min_t(size_t, count, buf->head[0].iov_len); + if (seek < want) { + ret = xs_read_kvec(sock, msg, flags, &buf->head[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto out; + seek = 0; + } else { + seek -= want; + offset += want; + } + + want = xs_alloc_sparse_pages( + 
buf, min_t(size_t, count - offset, buf->page_len), + GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); + if (seek < want) { + ret = xs_read_bvec(sock, msg, flags, buf->bvec, + xdr_buf_pagecount(buf), + want + buf->page_base, + seek + buf->page_base); + if (ret <= 0) + goto sock_err; + xs_flush_bvec(buf->bvec, ret, seek + buf->page_base); + ret -= buf->page_base; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto out; + seek = 0; + } else { + seek -= want; + offset += want; + } + + want = min_t(size_t, count - offset, buf->tail[0].iov_len); + if (seek < want) { + ret = xs_read_kvec(sock, msg, flags, &buf->tail[0], want, seek); + if (ret <= 0) + goto sock_err; + offset += ret; + if (offset == count || msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + goto out; + if (ret != want) + goto out; + } else if (offset < seek_init) + offset = seek_init; + ret = -EMSGSIZE; +out: + *read = offset - seek_init; + return ret; +sock_err: + offset += seek; + goto out; +} + +static void +xs_read_header(struct sock_xprt *transport, struct xdr_buf *buf) +{ + if (!transport->recv.copied) { + if (buf->head[0].iov_len >= transport->recv.offset) + memcpy(buf->head[0].iov_base, + &transport->recv.xid, + transport->recv.offset); + transport->recv.copied = transport->recv.offset; + } +} + +static bool +xs_read_stream_request_done(struct sock_xprt *transport) +{ + return transport->recv.fraghdr & cpu_to_be32(RPC_LAST_STREAM_FRAGMENT); +} + +static void +xs_read_stream_check_eor(struct sock_xprt *transport, + struct msghdr *msg) +{ + if (xs_read_stream_request_done(transport)) + msg->msg_flags |= MSG_EOR; +} + +static ssize_t +xs_read_stream_request(struct sock_xprt *transport, struct msghdr *msg, + int flags, struct rpc_rqst *req) +{ + struct xdr_buf *buf = &req->rq_private_buf; + size_t want, read; + ssize_t ret; + + xs_read_header(transport, buf); + + want = transport->recv.len - transport->recv.offset; + if (want != 0) { + ret = xs_read_xdr_buf(transport->sock, msg, flags, buf, + transport->recv.copied + want, + transport->recv.copied, + &read); + transport->recv.offset += read; + transport->recv.copied += read; + } + + if (transport->recv.offset == transport->recv.len) + xs_read_stream_check_eor(transport, msg); + + if (want == 0) + return 0; + + switch (ret) { + default: + break; + case -EFAULT: + case -EMSGSIZE: + msg->msg_flags |= MSG_TRUNC; + return read; + case 0: + return -ESHUTDOWN; + } + return ret < 0 ? ret : read; +} + +static size_t +xs_read_stream_headersize(bool isfrag) +{ + if (isfrag) + return sizeof(__be32); + return 3 * sizeof(__be32); +} + +static ssize_t +xs_read_stream_header(struct sock_xprt *transport, struct msghdr *msg, + int flags, size_t want, size_t seek) +{ + struct kvec kvec = { + .iov_base = &transport->recv.fraghdr, + .iov_len = want, + }; + return xs_read_kvec(transport->sock, msg, flags, &kvec, want, seek); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret; + + /* Is this transport associated with the backchannel? 
*/ + if (!xprt->bc_serv) + return -ESHUTDOWN; + + /* Look up and lock the request corresponding to the given XID */ + req = xprt_lookup_bc_request(xprt, transport->recv.xid); + if (!req) { + printk(KERN_WARNING "Callback slot table overflowed\n"); + return -ESHUTDOWN; + } + if (transport->recv.copied && !req->rq_private_buf.len) + return -ESHUTDOWN; + + ret = xs_read_stream_request(transport, msg, flags, req); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_bc_request(req, transport->recv.copied); + else + req->rq_private_buf.len = transport->recv.copied; + + return ret; +} +#else /* CONFIG_SUNRPC_BACKCHANNEL */ +static ssize_t +xs_read_stream_call(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + return -ESHUTDOWN; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +static ssize_t +xs_read_stream_reply(struct sock_xprt *transport, struct msghdr *msg, int flags) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct rpc_rqst *req; + ssize_t ret = 0; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->queue_lock); + req = xprt_lookup_rqst(xprt, transport->recv.xid); + if (!req || (transport->recv.copied && !req->rq_private_buf.len)) { + msg->msg_flags |= MSG_TRUNC; + goto out; + } + xprt_pin_rqst(req); + spin_unlock(&xprt->queue_lock); + + ret = xs_read_stream_request(transport, msg, flags, req); + + spin_lock(&xprt->queue_lock); + if (msg->msg_flags & (MSG_EOR|MSG_TRUNC)) + xprt_complete_rqst(req->rq_task, transport->recv.copied); + else + req->rq_private_buf.len = transport->recv.copied; + xprt_unpin_rqst(req); +out: + spin_unlock(&xprt->queue_lock); + return ret; +} + +static ssize_t +xs_read_stream(struct sock_xprt *transport, int flags) +{ + struct msghdr msg = { 0 }; + size_t want, read = 0; + ssize_t ret = 0; + + if (transport->recv.len == 0) { + want = xs_read_stream_headersize(transport->recv.copied != 0); + ret = xs_read_stream_header(transport, &msg, flags, want, + transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset = ret; + if (transport->recv.offset != want) + return transport->recv.offset; + transport->recv.len = be32_to_cpu(transport->recv.fraghdr) & + RPC_FRAGMENT_SIZE_MASK; + transport->recv.offset -= sizeof(transport->recv.fraghdr); + read = ret; + } + + switch (be32_to_cpu(transport->recv.calldir)) { + default: + msg.msg_flags |= MSG_TRUNC; + break; + case RPC_CALL: + ret = xs_read_stream_call(transport, &msg, flags); + break; + case RPC_REPLY: + ret = xs_read_stream_reply(transport, &msg, flags); + } + if (msg.msg_flags & MSG_TRUNC) { + transport->recv.calldir = cpu_to_be32(-1); + transport->recv.copied = -1; + } + if (ret < 0) + goto out_err; + read += ret; + if (transport->recv.offset < transport->recv.len) { + if (!(msg.msg_flags & MSG_TRUNC)) + return read; + msg.msg_flags = 0; + ret = xs_read_discard(transport->sock, &msg, flags, + transport->recv.len - transport->recv.offset); + if (ret <= 0) + goto out_err; + transport->recv.offset += ret; + read += ret; + if (transport->recv.offset != transport->recv.len) + return read; + } + if (xs_read_stream_request_done(transport)) { + trace_xs_stream_read_request(transport); + transport->recv.copied = 0; + } + transport->recv.offset = 0; + transport->recv.len = 0; + return read; +out_err: + return ret != 0 ? 
ret : -ESHUTDOWN; +} + +static __poll_t xs_poll_socket(struct sock_xprt *transport) +{ + return transport->sock->ops->poll(transport->file, transport->sock, + NULL); +} + +static bool xs_poll_socket_readable(struct sock_xprt *transport) +{ + __poll_t events = xs_poll_socket(transport); + + return (events & (EPOLLIN | EPOLLRDNORM)) && !(events & EPOLLRDHUP); +} + +static void xs_poll_check_readable(struct sock_xprt *transport) +{ + + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + if (!xs_poll_socket_readable(transport)) + return; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); +} + +static void xs_stream_data_receive(struct sock_xprt *transport) +{ + size_t read = 0; + ssize_t ret = 0; + + mutex_lock(&transport->recv_mutex); + if (transport->sock == NULL) + goto out; + for (;;) { + ret = xs_read_stream(transport, MSG_DONTWAIT); + if (ret < 0) + break; + read += ret; + cond_resched(); + } + if (ret == -ESHUTDOWN) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + else + xs_poll_check_readable(transport); +out: + mutex_unlock(&transport->recv_mutex); + trace_xs_stream_read_data(&transport->xprt, ret, read); +} + +static void xs_stream_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); + + xs_stream_data_receive(transport); + memalloc_nofs_restore(pflags); +} + +static void +xs_stream_reset_connect(struct sock_xprt *transport) +{ + transport->recv.offset = 0; + transport->recv.len = 0; + transport->recv.copied = 0; + transport->xmit.offset = 0; +} + +static void +xs_stream_start_connect(struct sock_xprt *transport) +{ + transport->xprt.stat.connect_count++; + transport->xprt.stat.connect_start = jiffies; +} + +#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) + +/** + * xs_nospace - handle transmit was incomplete + * @req: pointer to RPC request + * @transport: pointer to struct sock_xprt + * + */ +static int xs_nospace(struct rpc_rqst *req, struct sock_xprt *transport) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + trace_rpc_socket_nospace(req, transport); + + /* Protect against races with write_space */ + spin_lock(&xprt->transport_lock); + + /* Don't race with disconnect */ + if (xprt_connected(xprt)) { + /* wait for more buffer space */ + set_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + sk->sk_write_pending++; + xprt_wait_for_buffer_space(xprt); + } else + ret = -ENOTCONN; + + spin_unlock(&xprt->transport_lock); + return ret; +} + +static int xs_sock_nospace(struct rpc_rqst *req) +{ + struct sock_xprt *transport = + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + lock_sock(sk); + if (!sock_writeable(sk)) + ret = xs_nospace(req, transport); + release_sock(sk); + return ret; +} + +static int xs_stream_nospace(struct rpc_rqst *req, bool vm_wait) +{ + struct sock_xprt *transport = + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + int ret = -EAGAIN; + + if (vm_wait) + return -ENOBUFS; + lock_sock(sk); + if (!sk_stream_memory_free(sk)) + ret = xs_nospace(req, transport); + release_sock(sk); + return ret; +} + +static int xs_stream_prepare_request(struct rpc_rqst *req, struct xdr_buf *buf) +{ + return xdr_alloc_bvec(buf, rpc_task_gfp_mask()); +} + +/* 
+ * Determine if the previous message in the stream was aborted before it + * could complete transmission. + */ +static bool +xs_send_request_was_aborted(struct sock_xprt *transport, struct rpc_rqst *req) +{ + return transport->xmit.offset != 0 && req->rq_bytes_sent == 0; +} + +/* + * Return the stream record marker field for a record of length < 2^31-1 + */ +static rpc_fraghdr +xs_stream_record_marker(struct xdr_buf *xdr) +{ + if (!xdr->len) + return 0; + return cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | (u32)xdr->len); +} + +/** + * xs_local_send_request - write an RPC request to an AF_LOCAL socket + * @req: pointer to RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + */ +static int xs_local_send_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = + container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + rpc_fraghdr rm = xs_stream_record_marker(xdr); + unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen; + struct msghdr msg = { + .msg_flags = XS_SENDMSG_FLAGS, + }; + bool vm_wait; + unsigned int sent; + int status; + + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + xprt_force_disconnect(xprt); + return -ENOTCONN; + } + + xs_pktdump("packet data:", + req->rq_svec->iov_base, req->rq_svec->iov_len); + + vm_wait = sk_stream_is_writeable(transport->inet) ? true : false; + + req->rq_xtime = ktime_get(); + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, + transport->xmit.offset, rm, &sent); + dprintk("RPC: %s(%u) = %d\n", + __func__, xdr->len - transport->xmit.offset, status); + + if (likely(sent > 0) || status == 0) { + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; + if (likely(req->rq_bytes_sent >= msglen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; + transport->xmit.offset = 0; + return 0; + } + status = -EAGAIN; + vm_wait = false; + } + + switch (status) { + case -EAGAIN: + status = xs_stream_nospace(req, vm_wait); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + fallthrough; + case -EPIPE: + xprt_force_disconnect(xprt); + status = -ENOTCONN; + } + + return status; +} + +/** + * xs_udp_send_request - write an RPC request to a UDP socket + * @req: pointer to RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + */ +static int xs_udp_send_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + struct msghdr msg = { + .msg_name = xs_addr(xprt), + .msg_namelen = xprt->addrlen, + .msg_flags = XS_SENDMSG_FLAGS, + }; + unsigned int sent; + int status; + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + + if (!xprt_bound(xprt)) + return -ENOTCONN; + + if (!xprt_request_get_cong(xprt, req)) + return -EBADSLT; + + status = xdr_alloc_bvec(xdr, rpc_task_gfp_mask()); + if (status < 0) + return status; + 
req->rq_xtime = ktime_get(); + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, 0, &sent); + + dprintk("RPC: xs_udp_send_request(%u) = %d\n", + xdr->len, status); + + /* firewall is blocking us, don't return -EAGAIN or we end up looping */ + if (status == -EPERM) + goto process_status; + + if (status == -EAGAIN && sock_writeable(transport->inet)) + status = -ENOBUFS; + + if (sent > 0 || status == 0) { + req->rq_xmit_bytes_sent += sent; + if (sent >= req->rq_slen) + return 0; + /* Still some bytes left; set up for a retry later. */ + status = -EAGAIN; + } + +process_status: + switch (status) { + case -ENOTSOCK: + status = -ENOTCONN; + /* Should we call xs_close() here? */ + break; + case -EAGAIN: + status = xs_sock_nospace(req); + break; + case -ENETUNREACH: + case -ENOBUFS: + case -EPIPE: + case -ECONNREFUSED: + case -EPERM: + /* When the server has died, an ICMP port unreachable message + * prompts ECONNREFUSED. */ + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + } + + return status; +} + +/** + * xs_tcp_send_request - write an RPC request to a TCP socket + * @req: pointer to RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occurred, the request was not sent + * + * XXX: In the case of soft timeouts, should we eventually give up + * if sendmsg is not able to make progress? + */ +static int xs_tcp_send_request(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + rpc_fraghdr rm = xs_stream_record_marker(xdr); + unsigned int msglen = rm ? req->rq_slen + sizeof(rm) : req->rq_slen; + struct msghdr msg = { + .msg_flags = XS_SENDMSG_FLAGS, + }; + bool vm_wait; + unsigned int sent; + int status; + + /* Close the stream if the previous transmission was incomplete */ + if (xs_send_request_was_aborted(transport, req)) { + if (transport->sock != NULL) + kernel_sock_shutdown(transport->sock, SHUT_RDWR); + return -ENOTCONN; + } + if (!transport->inet) + return -ENOTCONN; + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + + if (test_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state)) + xs_tcp_set_socket_timeouts(xprt, transport->sock); + + xs_set_srcport(transport, transport->sock); + + /* Continue transmitting the packet/record. We must be careful + * to cope with writespace callbacks arriving _after_ we have + * called sendmsg(). */ + req->rq_xtime = ktime_get(); + tcp_sock_set_cork(transport->inet, true); + + vm_wait = sk_stream_is_writeable(transport->inet) ? true : false; + + do { + status = xprt_sock_sendmsg(transport->sock, &msg, xdr, + transport->xmit.offset, rm, &sent); + + dprintk("RPC: xs_tcp_send_request(%u) = %d\n", + xdr->len - transport->xmit.offset, status); + + /* If we've sent the entire packet, immediately + * reset the count of bytes sent. 
*/ + transport->xmit.offset += sent; + req->rq_bytes_sent = transport->xmit.offset; + if (likely(req->rq_bytes_sent >= msglen)) { + req->rq_xmit_bytes_sent += transport->xmit.offset; + transport->xmit.offset = 0; + if (atomic_long_read(&xprt->xmit_queuelen) == 1) + tcp_sock_set_cork(transport->inet, false); + return 0; + } + + WARN_ON_ONCE(sent == 0 && status == 0); + + if (sent > 0) + vm_wait = false; + + } while (status == 0); + + switch (status) { + case -ENOTSOCK: + status = -ENOTCONN; + /* Should we call xs_close() here? */ + break; + case -EAGAIN: + status = xs_stream_nospace(req, vm_wait); + break; + case -ECONNRESET: + case -ECONNREFUSED: + case -ENOTCONN: + case -EADDRINUSE: + case -ENOBUFS: + case -EPIPE: + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + } + + return status; +} + +static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + transport->old_data_ready = sk->sk_data_ready; + transport->old_state_change = sk->sk_state_change; + transport->old_write_space = sk->sk_write_space; + transport->old_error_report = sk->sk_error_report; +} + +static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + sk->sk_data_ready = transport->old_data_ready; + sk->sk_state_change = transport->old_state_change; + sk->sk_write_space = transport->old_write_space; + sk->sk_error_report = transport->old_error_report; +} + +static void xs_sock_reset_state_flags(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + clear_bit(XPRT_SOCK_DATA_READY, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state); + clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state); + clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state); +} + +static void xs_run_error_worker(struct sock_xprt *transport, unsigned int nr) +{ + set_bit(nr, &transport->sock_state); + queue_work(xprtiod_workqueue, &transport->error_worker); +} + +static void xs_sock_reset_connection_flags(struct rpc_xprt *xprt) +{ + xprt->connect_cookie++; + smp_mb__before_atomic(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + clear_bit(XPRT_CLOSING, &xprt->state); + xs_sock_reset_state_flags(xprt); + smp_mb__after_atomic(); +} + +/** + * xs_error_report - callback to handle TCP socket state errors + * @sk: socket + * + * Note: we don't call sock_error() since there may be a rpc_task + * using the socket, and so we don't want to clear sk->sk_err. + */ +static void xs_error_report(struct sock *sk) +{ + struct sock_xprt *transport; + struct rpc_xprt *xprt; + + if (!(xprt = xprt_from_sock(sk))) + return; + + transport = container_of(xprt, struct sock_xprt, xprt); + transport->xprt_err = -sk->sk_err; + if (transport->xprt_err == 0) + return; + dprintk("RPC: xs_error_report client %p, error=%d...\n", + xprt, -transport->xprt_err); + trace_rpc_socket_error(xprt, sk->sk_socket, transport->xprt_err); + + /* barrier ensures xprt_err is set before XPRT_SOCK_WAKE_ERROR */ + smp_mb__before_atomic(); + xs_run_error_worker(transport, XPRT_SOCK_WAKE_ERROR); +} + +static void xs_reset_transport(struct sock_xprt *transport) +{ + struct socket *sock = transport->sock; + struct sock *sk = transport->inet; + struct rpc_xprt *xprt = &transport->xprt; + struct file *filp = transport->file; + + if (sk == NULL) + return; + /* + * Make sure we're calling this in a context from which it is safe + * to call __fput_sync(). 
In practice that means rpciod and the + * system workqueue. + */ + if (!(current->flags & PF_WQ_WORKER)) { + WARN_ON_ONCE(1); + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + return; + } + + if (atomic_read(&transport->xprt.swapper)) + sk_clear_memalloc(sk); + + kernel_sock_shutdown(sock, SHUT_RDWR); + + mutex_lock(&transport->recv_mutex); + lock_sock(sk); + transport->inet = NULL; + transport->sock = NULL; + transport->file = NULL; + + sk->sk_user_data = NULL; + + xs_restore_old_callbacks(transport, sk); + xprt_clear_connected(xprt); + xs_sock_reset_connection_flags(xprt); + /* Reset stream record info */ + xs_stream_reset_connect(transport); + release_sock(sk); + mutex_unlock(&transport->recv_mutex); + + trace_rpc_socket_close(xprt, sock); + __fput_sync(filp); + + xprt_disconnect_done(xprt); +} + +/** + * xs_close - close a socket + * @xprt: transport + * + * This is used when all requests are complete; ie, no DRC state remains + * on the server we want to save. + * + * The caller _must_ be holding XPRT_LOCKED in order to avoid issues with + * xs_reset_transport() zeroing the socket from underneath a writer. + */ +static void xs_close(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + dprintk("RPC: xs_close xprt %p\n", xprt); + + xs_reset_transport(transport); + xprt->reestablish_timeout = 0; +} + +static void xs_inject_disconnect(struct rpc_xprt *xprt) +{ + dprintk("RPC: injecting transport disconnect on xprt=%p\n", + xprt); + xprt_disconnect_done(xprt); +} + +static void xs_xprt_free(struct rpc_xprt *xprt) +{ + xs_free_peer_addresses(xprt); + xprt_free(xprt); +} + +/** + * xs_destroy - prepare to shutdown a transport + * @xprt: doomed transport + * + */ +static void xs_destroy(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); + dprintk("RPC: xs_destroy xprt %p\n", xprt); + + cancel_delayed_work_sync(&transport->connect_worker); + xs_close(xprt); + cancel_work_sync(&transport->recv_worker); + cancel_work_sync(&transport->error_worker); + xs_xprt_free(xprt); + module_put(THIS_MODULE); +} + +/** + * xs_udp_data_read_skb - receive callback for UDP sockets + * @xprt: transport + * @sk: socket + * @skb: skbuff + * + */ +static void xs_udp_data_read_skb(struct rpc_xprt *xprt, + struct sock *sk, + struct sk_buff *skb) +{ + struct rpc_task *task; + struct rpc_rqst *rovr; + int repsize, copied; + u32 _xid; + __be32 *xp; + + repsize = skb->len; + if (repsize < 4) { + dprintk("RPC: impossible RPC reply size %d!\n", repsize); + return; + } + + /* Copy the XID from the skb... */ + xp = skb_header_pointer(skb, 0, sizeof(_xid), &_xid); + if (xp == NULL) + return; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->queue_lock); + rovr = xprt_lookup_rqst(xprt, *xp); + if (!rovr) + goto out_unlock; + xprt_pin_rqst(rovr); + xprt_update_rtt(rovr->rq_task); + spin_unlock(&xprt->queue_lock); + task = rovr->rq_task; + + if ((copied = rovr->rq_private_buf.buflen) > repsize) + copied = repsize; + + /* Suck it into the iovec, verify checksum if not done by hw. 
*/ + if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) { + spin_lock(&xprt->queue_lock); + __UDPX_INC_STATS(sk, UDP_MIB_INERRORS); + goto out_unpin; + } + + + spin_lock(&xprt->transport_lock); + xprt_adjust_cwnd(xprt, task, copied); + spin_unlock(&xprt->transport_lock); + spin_lock(&xprt->queue_lock); + xprt_complete_rqst(task, copied); + __UDPX_INC_STATS(sk, UDP_MIB_INDATAGRAMS); +out_unpin: + xprt_unpin_rqst(rovr); + out_unlock: + spin_unlock(&xprt->queue_lock); +} + +static void xs_udp_data_receive(struct sock_xprt *transport) +{ + struct sk_buff *skb; + struct sock *sk; + int err; + + mutex_lock(&transport->recv_mutex); + sk = transport->inet; + if (sk == NULL) + goto out; + for (;;) { + skb = skb_recv_udp(sk, MSG_DONTWAIT, &err); + if (skb == NULL) + break; + xs_udp_data_read_skb(&transport->xprt, sk, skb); + consume_skb(skb); + cond_resched(); + } + xs_poll_check_readable(transport); +out: + mutex_unlock(&transport->recv_mutex); +} + +static void xs_udp_data_receive_workfn(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, recv_worker); + unsigned int pflags = memalloc_nofs_save(); + + xs_udp_data_receive(transport); + memalloc_nofs_restore(pflags); +} + +/** + * xs_data_ready - "data ready" callback for sockets + * @sk: socket with data to read + * + */ +static void xs_data_ready(struct sock *sk) +{ + struct rpc_xprt *xprt; + + xprt = xprt_from_sock(sk); + if (xprt != NULL) { + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); + + trace_xs_data_ready(xprt); + + transport->old_data_ready(sk); + /* Any data means we had a useful conversation, so + * then we don't need to delay the next reconnect + */ + if (xprt->reestablish_timeout) + xprt->reestablish_timeout = 0; + if (!test_and_set_bit(XPRT_SOCK_DATA_READY, &transport->sock_state)) + queue_work(xprtiod_workqueue, &transport->recv_worker); + } +} + +/* + * Helper function to force a TCP close if the server is sending + * junk and/or it has put us in CLOSE_WAIT + */ +static void xs_tcp_force_close(struct rpc_xprt *xprt) +{ + xprt_force_disconnect(xprt); +} + +#if defined(CONFIG_SUNRPC_BACKCHANNEL) +static size_t xs_tcp_bc_maxpayload(struct rpc_xprt *xprt) +{ + return PAGE_SIZE; +} +#endif /* CONFIG_SUNRPC_BACKCHANNEL */ + +/** + * xs_local_state_change - callback to handle AF_LOCAL socket state changes + * @sk: socket whose state has changed + * + */ +static void xs_local_state_change(struct sock *sk) +{ + struct rpc_xprt *xprt; + struct sock_xprt *transport; + + if (!(xprt = xprt_from_sock(sk))) + return; + transport = container_of(xprt, struct sock_xprt, xprt); + if (sk->sk_shutdown & SHUTDOWN_MASK) { + clear_bit(XPRT_CONNECTED, &xprt->state); + /* Trigger the socket release */ + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); + } +} + +/** + * xs_tcp_state_change - callback to handle TCP socket state changes + * @sk: socket whose state has changed + * + */ +static void xs_tcp_state_change(struct sock *sk) +{ + struct rpc_xprt *xprt; + struct sock_xprt *transport; + + if (!(xprt = xprt_from_sock(sk))) + return; + dprintk("RPC: xs_tcp_state_change client %p...\n", xprt); + dprintk("RPC: state %x conn %d dead %d zapped %d sk_shutdown %d\n", + sk->sk_state, xprt_connected(xprt), + sock_flag(sk, SOCK_DEAD), + sock_flag(sk, SOCK_ZAPPED), + sk->sk_shutdown); + + transport = container_of(xprt, struct sock_xprt, xprt); + trace_rpc_socket_state_change(xprt, sk->sk_socket); + switch (sk->sk_state) { + case TCP_ESTABLISHED: + if 
(!xprt_test_and_set_connected(xprt)) { + xprt->connect_cookie++; + clear_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + xprt_clear_connecting(xprt); + + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xs_run_error_worker(transport, XPRT_SOCK_WAKE_PENDING); + } + break; + case TCP_FIN_WAIT1: + /* The client initiated a shutdown of the socket */ + xprt->connect_cookie++; + xprt->reestablish_timeout = 0; + set_bit(XPRT_CLOSING, &xprt->state); + smp_mb__before_atomic(); + clear_bit(XPRT_CONNECTED, &xprt->state); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + smp_mb__after_atomic(); + break; + case TCP_CLOSE_WAIT: + /* The server initiated a shutdown of the socket */ + xprt->connect_cookie++; + clear_bit(XPRT_CONNECTED, &xprt->state); + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); + fallthrough; + case TCP_CLOSING: + /* + * If the server closed down the connection, make sure that + * we back off before reconnecting + */ + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + break; + case TCP_LAST_ACK: + set_bit(XPRT_CLOSING, &xprt->state); + smp_mb__before_atomic(); + clear_bit(XPRT_CONNECTED, &xprt->state); + smp_mb__after_atomic(); + break; + case TCP_CLOSE: + if (test_and_clear_bit(XPRT_SOCK_CONNECTING, + &transport->sock_state)) + xprt_clear_connecting(xprt); + clear_bit(XPRT_CLOSING, &xprt->state); + /* Trigger the socket release */ + xs_run_error_worker(transport, XPRT_SOCK_WAKE_DISCONNECT); + } +} + +static void xs_write_space(struct sock *sk) +{ + struct sock_xprt *transport; + struct rpc_xprt *xprt; + + if (!sk->sk_socket) + return; + clear_bit(SOCK_NOSPACE, &sk->sk_socket->flags); + + if (unlikely(!(xprt = xprt_from_sock(sk)))) + return; + transport = container_of(xprt, struct sock_xprt, xprt); + if (!test_and_clear_bit(XPRT_SOCK_NOSPACE, &transport->sock_state)) + return; + xs_run_error_worker(transport, XPRT_SOCK_WAKE_WRITE); + sk->sk_write_pending--; +} + +/** + * xs_udp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. + */ +static void xs_udp_write_space(struct sock *sk) +{ + /* from net/core/sock.c:sock_def_write_space */ + if (sock_writeable(sk)) + xs_write_space(sk); +} + +/** + * xs_tcp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. 
+ */ +static void xs_tcp_write_space(struct sock *sk) +{ + /* from net/core/stream.c:sk_stream_write_space */ + if (sk_stream_is_writeable(sk)) + xs_write_space(sk); +} + +static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + + if (transport->rcvsize) { + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + sk->sk_rcvbuf = transport->rcvsize * xprt->max_reqs * 2; + } + if (transport->sndsize) { + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + sk->sk_sndbuf = transport->sndsize * xprt->max_reqs * 2; + sk->sk_write_space(sk); + } +} + +/** + * xs_udp_set_buffer_size - set send and receive limits + * @xprt: generic transport + * @sndsize: requested size of send buffer, in bytes + * @rcvsize: requested size of receive buffer, in bytes + * + * Set socket send and receive buffer size limits. + */ +static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t rcvsize) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + transport->sndsize = 0; + if (sndsize) + transport->sndsize = sndsize + 1024; + transport->rcvsize = 0; + if (rcvsize) + transport->rcvsize = rcvsize + 1024; + + xs_udp_do_set_buffer_size(xprt); +} + +/** + * xs_udp_timer - called when a retransmit timeout occurs on a UDP transport + * @xprt: controlling transport + * @task: task that timed out + * + * Adjust the congestion window after a retransmit timeout has occurred. + */ +static void xs_udp_timer(struct rpc_xprt *xprt, struct rpc_task *task) +{ + spin_lock(&xprt->transport_lock); + xprt_adjust_cwnd(xprt, task, -ETIMEDOUT); + spin_unlock(&xprt->transport_lock); +} + +static int xs_get_random_port(void) +{ + unsigned short min = xprt_min_resvport, max = xprt_max_resvport; + unsigned short range; + unsigned short rand; + + if (max < min) + return -EADDRINUSE; + range = max - min + 1; + rand = prandom_u32_max(range); + return rand + min; +} + +static unsigned short xs_sock_getport(struct socket *sock) +{ + struct sockaddr_storage buf; + unsigned short port = 0; + + if (kernel_getsockname(sock, (struct sockaddr *)&buf) < 0) + goto out; + switch (buf.ss_family) { + case AF_INET6: + port = ntohs(((struct sockaddr_in6 *)&buf)->sin6_port); + break; + case AF_INET: + port = ntohs(((struct sockaddr_in *)&buf)->sin_port); + } +out: + return port; +} + +/** + * xs_set_port - reset the port number in the remote endpoint address + * @xprt: generic transport + * @port: new port number + * + */ +static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) +{ + dprintk("RPC: setting port for xprt %p to %u\n", xprt, port); + + rpc_set_port(xs_addr(xprt), port); + xs_update_peer_port(xprt); +} + +static void xs_set_srcport(struct sock_xprt *transport, struct socket *sock) +{ + if (transport->srcport == 0 && transport->xprt.reuseport) + transport->srcport = xs_sock_getport(sock); +} + +static int xs_get_srcport(struct sock_xprt *transport) +{ + int port = transport->srcport; + + if (port == 0 && transport->xprt.resvport) + port = xs_get_random_port(); + return port; +} + +static unsigned short xs_sock_srcport(struct rpc_xprt *xprt) +{ + struct sock_xprt *sock = container_of(xprt, struct sock_xprt, xprt); + unsigned short ret = 0; + mutex_lock(&sock->recv_mutex); + if (sock->sock) + ret = xs_sock_getport(sock->sock); + mutex_unlock(&sock->recv_mutex); + return ret; +} + +static int xs_sock_srcaddr(struct rpc_xprt *xprt, char *buf, size_t buflen) +{ + struct sock_xprt *sock = 
container_of(xprt, struct sock_xprt, xprt); + union { + struct sockaddr sa; + struct sockaddr_storage st; + } saddr; + int ret = -ENOTCONN; + + mutex_lock(&sock->recv_mutex); + if (sock->sock) { + ret = kernel_getsockname(sock->sock, &saddr.sa); + if (ret >= 0) + ret = snprintf(buf, buflen, "%pISc", &saddr.sa); + } + mutex_unlock(&sock->recv_mutex); + return ret; +} + +static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port) +{ + if (transport->srcport != 0) + transport->srcport = 0; + if (!transport->xprt.resvport) + return 0; + if (port <= xprt_min_resvport || port > xprt_max_resvport) + return xprt_max_resvport; + return --port; +} +static int xs_bind(struct sock_xprt *transport, struct socket *sock) +{ + struct sockaddr_storage myaddr; + int err, nloop = 0; + int port = xs_get_srcport(transport); + unsigned short last; + + /* + * If we are asking for any ephemeral port (i.e. port == 0 && + * transport->xprt.resvport == 0), don't bind. Let the local + * port selection happen implicitly when the socket is used + * (for example at connect time). + * + * This ensures that we can continue to establish TCP + * connections even when all local ephemeral ports are already + * a part of some TCP connection. This makes no difference + * for UDP sockets, but also doesn't harm them. + * + * If we're asking for any reserved port (i.e. port == 0 && + * transport->xprt.resvport == 1) xs_get_srcport above will + * ensure that port is non-zero and we will bind as needed. + */ + if (port <= 0) + return port; + + memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen); + do { + rpc_set_port((struct sockaddr *)&myaddr, port); + err = kernel_bind(sock, (struct sockaddr *)&myaddr, + transport->xprt.addrlen); + if (err == 0) { + if (transport->xprt.reuseport) + transport->srcport = port; + break; + } + last = port; + port = xs_next_srcport(transport, port); + if (port > last) + nloop++; + } while (err == -EADDRINUSE && nloop != 2); + + if (myaddr.ss_family == AF_INET) + dprintk("RPC: %s %pI4:%u: %s (%d)\n", __func__, + &((struct sockaddr_in *)&myaddr)->sin_addr, + port, err ? "failed" : "ok", err); + else + dprintk("RPC: %s %pI6:%u: %s (%d)\n", __func__, + &((struct sockaddr_in6 *)&myaddr)->sin6_addr, + port, err ? 
"failed" : "ok", err); + return err; +} + +/* + * We don't support autobind on AF_LOCAL sockets + */ +static void xs_local_rpcbind(struct rpc_task *task) +{ + xprt_set_bound(task->tk_xprt); +} + +static void xs_local_set_port(struct rpc_xprt *xprt, unsigned short port) +{ +} + +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key xs_key[3]; +static struct lock_class_key xs_slock_key[3]; + +static inline void xs_reclassify_socketu(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_LOCAL-RPC", + &xs_slock_key[0], "sk_lock-AF_LOCAL-RPC", &xs_key[0]); +} + +static inline void xs_reclassify_socket4(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC", + &xs_slock_key[1], "sk_lock-AF_INET-RPC", &xs_key[1]); +} + +static inline void xs_reclassify_socket6(struct socket *sock) +{ + struct sock *sk = sock->sk; + + sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC", + &xs_slock_key[2], "sk_lock-AF_INET6-RPC", &xs_key[2]); +} + +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ + if (WARN_ON_ONCE(!sock_allow_reclassification(sock->sk))) + return; + + switch (family) { + case AF_LOCAL: + xs_reclassify_socketu(sock); + break; + case AF_INET: + xs_reclassify_socket4(sock); + break; + case AF_INET6: + xs_reclassify_socket6(sock); + break; + } +} +#else +static inline void xs_reclassify_socket(int family, struct socket *sock) +{ +} +#endif + +static void xs_dummy_setup_socket(struct work_struct *work) +{ +} + +static struct socket *xs_create_sock(struct rpc_xprt *xprt, + struct sock_xprt *transport, int family, int type, + int protocol, bool reuseport) +{ + struct file *filp; + struct socket *sock; + int err; + + err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1); + if (err < 0) { + dprintk("RPC: can't create %d transport socket (%d).\n", + protocol, -err); + goto out; + } + xs_reclassify_socket(family, sock); + + if (reuseport) + sock_set_reuseport(sock->sk); + + err = xs_bind(transport, sock); + if (err) { + sock_release(sock); + goto out; + } + + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) + return ERR_CAST(filp); + transport->file = filp; + + return sock; +out: + return ERR_PTR(err); +} + +static int xs_local_finish_connecting(struct rpc_xprt *xprt, + struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + lock_sock(sk); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_write_space = xs_udp_write_space; + sk->sk_state_change = xs_local_state_change; + sk->sk_error_report = xs_error_report; + + xprt_clear_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + release_sock(sk); + } + + xs_stream_start_connect(transport); + + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, 0); +} + +/** + * xs_local_setup_socket - create AF_LOCAL socket, connect to a local endpoint + * @transport: socket transport to connect + */ +static int xs_local_setup_socket(struct sock_xprt *transport) +{ + struct rpc_xprt *xprt = &transport->xprt; + struct file *filp; + struct socket *sock; + int status; + + status = __sock_create(xprt->xprt_net, AF_LOCAL, + SOCK_STREAM, 0, &sock, 1); + if (status < 0) { + dprintk("RPC: can't create AF_LOCAL " + "transport socket (%d).\n", -status); + goto out; + } + 
xs_reclassify_socket(AF_LOCAL, sock); + + filp = sock_alloc_file(sock, O_NONBLOCK, NULL); + if (IS_ERR(filp)) { + status = PTR_ERR(filp); + goto out; + } + transport->file = filp; + + dprintk("RPC: worker connecting xprt %p via AF_LOCAL to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + + status = xs_local_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); + switch (status) { + case 0: + dprintk("RPC: xprt %p connected to %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - + xprt->stat.connect_start; + xprt_set_connected(xprt); + break; + case -ENOBUFS: + break; + case -ENOENT: + dprintk("RPC: xprt %p: socket %s does not exist\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + case -ECONNREFUSED: + dprintk("RPC: xprt %p: connection refused for %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ADDR]); + break; + default: + printk(KERN_ERR "%s: unhandled error (%d) connecting to %s\n", + __func__, -status, + xprt->address_strings[RPC_DISPLAY_ADDR]); + } + +out: + xprt_clear_connecting(xprt); + xprt_wake_pending_tasks(xprt, status); + return status; +} + +static void xs_local_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + int ret; + + if (transport->file) + goto force_disconnect; + + if (RPC_IS_ASYNC(task)) { + /* + * We want the AF_LOCAL connect to be resolved in the + * filesystem namespace of the process making the rpc + * call. Thus we connect synchronously. + * + * If we want to support asynchronous AF_LOCAL calls, + * we'll need to figure out how to pass a namespace to + * connect. + */ + rpc_task_set_rpc_status(task, -ENOTCONN); + goto out_wake; + } + ret = xs_local_setup_socket(transport); + if (ret && !RPC_IS_SOFTCONN(task)) + msleep_interruptible(15000); + return; +force_disconnect: + xprt_force_disconnect(xprt); +out_wake: + xprt_clear_connecting(xprt); + xprt_wake_pending_tasks(xprt, -ENOTCONN); +} + +#if IS_ENABLED(CONFIG_SUNRPC_SWAP) +/* + * Note that this should be called with XPRT_LOCKED held, or recv_mutex + * held, or when we otherwise know that we have exclusive access to the + * socket, to guard against races with xs_reset_transport. + */ +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, + xprt); + + /* + * If there's no sock, then we have nothing to set. The + * reconnecting process will get it for us. + */ + if (!transport->inet) + return; + if (atomic_read(&xprt->swapper)) + sk_set_memalloc(transport->inet); +} + +/** + * xs_enable_swap - Tag this transport as being used for swap. + * @xprt: transport to tag + * + * Take a reference to this transport on behalf of the rpc_clnt, and + * optionally mark it for swapping if it wasn't already. + */ +static int +xs_enable_swap(struct rpc_xprt *xprt) +{ + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); + + mutex_lock(&xs->recv_mutex); + if (atomic_inc_return(&xprt->swapper) == 1 && + xs->inet) + sk_set_memalloc(xs->inet); + mutex_unlock(&xs->recv_mutex); + return 0; +} + +/** + * xs_disable_swap - Untag this transport as being used for swap. + * @xprt: transport to tag + * + * Drop a "swapper" reference to this xprt on behalf of the rpc_clnt. If the + * swapper refcount goes to 0, untag the socket as a memalloc socket. 
+ */ +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ + struct sock_xprt *xs = container_of(xprt, struct sock_xprt, xprt); + + mutex_lock(&xs->recv_mutex); + if (atomic_dec_and_test(&xprt->swapper) && + xs->inet) + sk_clear_memalloc(xs->inet); + mutex_unlock(&xs->recv_mutex); +} +#else +static void xs_set_memalloc(struct rpc_xprt *xprt) +{ +} + +static int +xs_enable_swap(struct rpc_xprt *xprt) +{ + return -EINVAL; +} + +static void +xs_disable_swap(struct rpc_xprt *xprt) +{ +} +#endif + +static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + lock_sock(sk); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_write_space = xs_udp_write_space; + + xprt_set_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + xs_set_memalloc(xprt); + + release_sock(sk); + } + xs_udp_do_set_buffer_size(xprt); + + xprt->stat.connect_start = jiffies; +} + +static void xs_udp_setup_socket(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock; + int status = -EIO; + unsigned int pflags = current->flags; + + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; + sock = xs_create_sock(xprt, transport, + xs_addr(xprt)->sa_family, SOCK_DGRAM, + IPPROTO_UDP, false); + if (IS_ERR(sock)) + goto out; + + dprintk("RPC: worker connecting xprt %p via %s to " + "%s (port %s)\n", xprt, + xprt->address_strings[RPC_DISPLAY_PROTO], + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT]); + + xs_udp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, 0); + status = 0; +out: + xprt_clear_connecting(xprt); + xprt_unlock_connect(xprt, transport); + xprt_wake_pending_tasks(xprt, status); + current_restore_flags(pflags, PF_MEMALLOC); +} + +/** + * xs_tcp_shutdown - gracefully shut down a TCP socket + * @xprt: transport + * + * Initiates a graceful shutdown of the TCP socket by calling the + * equivalent of shutdown(SHUT_RDWR); + */ +static void xs_tcp_shutdown(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + int skst = transport->inet ? 
transport->inet->sk_state : TCP_CLOSE; + + if (sock == NULL) + return; + if (!xprt->reuseport) { + xs_close(xprt); + return; + } + switch (skst) { + case TCP_FIN_WAIT1: + case TCP_FIN_WAIT2: + case TCP_LAST_ACK: + break; + case TCP_ESTABLISHED: + case TCP_CLOSE_WAIT: + kernel_sock_shutdown(sock, SHUT_RDWR); + trace_rpc_socket_shutdown(xprt, sock); + break; + default: + xs_reset_transport(transport); + } +} + +static void xs_tcp_set_socket_timeouts(struct rpc_xprt *xprt, + struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + unsigned int keepidle; + unsigned int keepcnt; + unsigned int timeo; + + spin_lock(&xprt->transport_lock); + keepidle = DIV_ROUND_UP(xprt->timeout->to_initval, HZ); + keepcnt = xprt->timeout->to_retries + 1; + timeo = jiffies_to_msecs(xprt->timeout->to_initval) * + (xprt->timeout->to_retries + 1); + clear_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock(&xprt->transport_lock); + + /* TCP Keepalive options */ + sock_set_keepalive(sock->sk); + tcp_sock_set_keepidle(sock->sk, keepidle); + tcp_sock_set_keepintvl(sock->sk, keepidle); + tcp_sock_set_keepcnt(sock->sk, keepcnt); + + /* TCP user timeout (see RFC5482) */ + tcp_sock_set_user_timeout(sock->sk, timeo); +} + +static void xs_tcp_set_connect_timeout(struct rpc_xprt *xprt, + unsigned long connect_timeout, + unsigned long reconnect_timeout) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct rpc_timeout to; + unsigned long initval; + + spin_lock(&xprt->transport_lock); + if (reconnect_timeout < xprt->max_reconnect_timeout) + xprt->max_reconnect_timeout = reconnect_timeout; + if (connect_timeout < xprt->connect_timeout) { + memcpy(&to, xprt->timeout, sizeof(to)); + initval = DIV_ROUND_UP(connect_timeout, to.to_retries + 1); + /* Arbitrary lower limit */ + if (initval < XS_TCP_INIT_REEST_TO << 1) + initval = XS_TCP_INIT_REEST_TO << 1; + to.to_initval = initval; + to.to_maxval = initval; + memcpy(&transport->tcp_timeout, &to, + sizeof(transport->tcp_timeout)); + xprt->timeout = &transport->tcp_timeout; + xprt->connect_timeout = connect_timeout; + } + set_bit(XPRT_SOCK_UPD_TIMEOUT, &transport->sock_state); + spin_unlock(&xprt->transport_lock); +} + +static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + /* Avoid temporary address, they are bad for long-lived + * connections such as NFS mounts. + * RFC4941, section 3.6 suggests that: + * Individual applications, which have specific + * knowledge about the normal duration of connections, + * MAY override this as appropriate. 
+ */ + if (xs_addr(xprt)->sa_family == PF_INET6) { + ip6_sock_set_addr_preferences(sk, + IPV6_PREFER_SRC_PUBLIC); + } + + xs_tcp_set_socket_timeouts(xprt, sock); + tcp_sock_set_nodelay(sk); + + lock_sock(sk); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_data_ready; + sk->sk_state_change = xs_tcp_state_change; + sk->sk_write_space = xs_tcp_write_space; + sk->sk_error_report = xs_error_report; + + /* socket options */ + sock_reset_flag(sk, SOCK_LINGER); + + xprt_clear_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + release_sock(sk); + } + + if (!xprt_bound(xprt)) + return -ENOTCONN; + + xs_set_memalloc(xprt); + + xs_stream_start_connect(transport); + + /* Tell the socket layer to start connecting... */ + set_bit(XPRT_SOCK_CONNECTING, &transport->sock_state); + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); +} + +/** + * xs_tcp_setup_socket - create a TCP socket and connect to a remote endpoint + * @work: queued work item + * + * Invoked by a work queue tasklet. + */ +static void xs_tcp_setup_socket(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct socket *sock = transport->sock; + struct rpc_xprt *xprt = &transport->xprt; + int status; + unsigned int pflags = current->flags; + + if (atomic_read(&xprt->swapper)) + current->flags |= PF_MEMALLOC; + + if (xprt_connected(xprt)) + goto out; + if (test_and_clear_bit(XPRT_SOCK_CONNECT_SENT, + &transport->sock_state) || + !sock) { + xs_reset_transport(transport); + sock = xs_create_sock(xprt, transport, xs_addr(xprt)->sa_family, + SOCK_STREAM, IPPROTO_TCP, true); + if (IS_ERR(sock)) { + xprt_wake_pending_tasks(xprt, PTR_ERR(sock)); + goto out; + } + } + + dprintk("RPC: worker connecting xprt %p via %s to " + "%s (port %s)\n", xprt, + xprt->address_strings[RPC_DISPLAY_PROTO], + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT]); + + status = xs_tcp_finish_connecting(xprt, sock); + trace_rpc_socket_connect(xprt, sock, status); + dprintk("RPC: %p connect status %d connected %d sock state %d\n", + xprt, -status, xprt_connected(xprt), + sock->sk->sk_state); + switch (status) { + case 0: + case -EINPROGRESS: + /* SYN_SENT! */ + set_bit(XPRT_SOCK_CONNECT_SENT, &transport->sock_state); + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + fallthrough; + case -EALREADY: + goto out_unlock; + case -EADDRNOTAVAIL: + /* Source port number is unavailable. Try a new one! */ + transport->srcport = 0; + status = -EAGAIN; + break; + case -EINVAL: + /* Happens, for instance, if the user specified a link + * local IPv6 address without a scope-id. + */ + case -ECONNREFUSED: + case -ECONNRESET: + case -ENETDOWN: + case -ENETUNREACH: + case -EHOSTUNREACH: + case -EADDRINUSE: + case -ENOBUFS: + break; + default: + printk("%s: connect returned unhandled error %d\n", + __func__, status); + status = -EAGAIN; + } + + /* xs_tcp_force_close() wakes tasks with a fixed error code. + * We need to wake them first to ensure the correct error code. 
+ */ + xprt_wake_pending_tasks(xprt, status); + xs_tcp_force_close(xprt); +out: + xprt_clear_connecting(xprt); +out_unlock: + xprt_unlock_connect(xprt, transport); + current_restore_flags(pflags, PF_MEMALLOC); +} + +/** + * xs_connect - connect a socket to a remote endpoint + * @xprt: pointer to transport structure + * @task: address of RPC task that manages state of connect request + * + * TCP: If the remote end dropped the connection, delay reconnecting. + * + * UDP socket connects are synchronous, but we use a work queue anyway + * to guarantee that even unprivileged user processes can set up a + * socket on a privileged port. + * + * If a UDP socket connect fails, the delay behavior here prevents + * retry floods (hard mounts). + */ +static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + unsigned long delay = 0; + + WARN_ON_ONCE(!xprt_lock_connect(xprt, task, transport)); + + if (transport->sock != NULL) { + dprintk("RPC: xs_connect delayed xprt %p for %lu " + "seconds\n", xprt, xprt->reestablish_timeout / HZ); + + delay = xprt_reconnect_delay(xprt); + xprt_reconnect_backoff(xprt, XS_TCP_INIT_REEST_TO); + + } else + dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + + queue_delayed_work(xprtiod_workqueue, + &transport->connect_worker, + delay); +} + +static void xs_wake_disconnect(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_DISCONNECT, &transport->sock_state)) + xs_tcp_force_close(&transport->xprt); +} + +static void xs_wake_write(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_WRITE, &transport->sock_state)) + xprt_write_space(&transport->xprt); +} + +static void xs_wake_error(struct sock_xprt *transport) +{ + int sockerr; + + if (!test_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state)) + return; + mutex_lock(&transport->recv_mutex); + if (transport->sock == NULL) + goto out; + if (!test_and_clear_bit(XPRT_SOCK_WAKE_ERROR, &transport->sock_state)) + goto out; + sockerr = xchg(&transport->xprt_err, 0); + if (sockerr < 0) + xprt_wake_pending_tasks(&transport->xprt, sockerr); +out: + mutex_unlock(&transport->recv_mutex); +} + +static void xs_wake_pending(struct sock_xprt *transport) +{ + if (test_and_clear_bit(XPRT_SOCK_WAKE_PENDING, &transport->sock_state)) + xprt_wake_pending_tasks(&transport->xprt, -EAGAIN); +} + +static void xs_error_handle(struct work_struct *work) +{ + struct sock_xprt *transport = container_of(work, + struct sock_xprt, error_worker); + + xs_wake_disconnect(transport); + xs_wake_write(transport); + xs_wake_error(transport); + xs_wake_pending(transport); +} + +/** + * xs_local_print_stats - display AF_LOCAL socket-specific stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_local_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\tlocal %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time / HZ, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/** + * xs_udp_print_stats - display UDP socket-specific stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void 
xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %llu %llu " + "%lu %llu %llu\n", + transport->srcport, + xprt->stat.bind_count, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/** + * xs_tcp_print_stats - display TCP socket-specific stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu " + "%llu %llu %lu %llu %llu\n", + transport->srcport, + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time / HZ, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + xprt->stat.max_slots, + xprt->stat.sending_u, + xprt->stat.pending_u); +} + +/* + * Allocate a bunch of pages for a scratch buffer for the rpc code. The reason + * we allocate pages instead doing a kmalloc like rpc_malloc is because we want + * to use the server side send routines. + */ +static int bc_malloc(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + size_t size = rqst->rq_callsize; + struct page *page; + struct rpc_buffer *buf; + + if (size > PAGE_SIZE - sizeof(struct rpc_buffer)) { + WARN_ONCE(1, "xprtsock: large bc buffer request (size %zu)\n", + size); + return -EINVAL; + } + + page = alloc_page(GFP_KERNEL | __GFP_NORETRY | __GFP_NOWARN); + if (!page) + return -ENOMEM; + + buf = page_address(page); + buf->len = PAGE_SIZE; + + rqst->rq_buffer = buf->data; + rqst->rq_rbuffer = (char *)rqst->rq_buffer + rqst->rq_callsize; + return 0; +} + +/* + * Free the space allocated in the bc_alloc routine + */ +static void bc_free(struct rpc_task *task) +{ + void *buffer = task->tk_rqstp->rq_buffer; + struct rpc_buffer *buf; + + buf = container_of(buffer, struct rpc_buffer, data); + free_page((unsigned long)buf); +} + +static int bc_sendto(struct rpc_rqst *req) +{ + struct xdr_buf *xdr = &req->rq_snd_buf; + struct sock_xprt *transport = + container_of(req->rq_xprt, struct sock_xprt, xprt); + struct msghdr msg = { + .msg_flags = 0, + }; + rpc_fraghdr marker = cpu_to_be32(RPC_LAST_STREAM_FRAGMENT | + (u32)xdr->len); + unsigned int sent = 0; + int err; + + req->rq_xtime = ktime_get(); + err = xdr_alloc_bvec(xdr, rpc_task_gfp_mask()); + if (err < 0) + return err; + err = xprt_sock_sendmsg(transport->sock, &msg, xdr, 0, marker, &sent); + xdr_free_bvec(xdr); + if (err < 0 || sent != (xdr->len + sizeof(marker))) + return -EAGAIN; + return sent; +} + +/** + * bc_send_request - Send a backchannel Call on a TCP socket + * @req: rpc_rqst containing Call message to be sent + * + * xpt_mutex ensures @rqstp's whole message is written to the socket + * without interruption. 
+ * + * Return values: + * %0 if the message was sent successfully + * %ENOTCONN if the message was not sent + */ +static int bc_send_request(struct rpc_rqst *req) +{ + struct svc_xprt *xprt; + int len; + + /* + * Get the server socket associated with this callback xprt + */ + xprt = req->rq_xprt->bc_xprt; + + /* + * Grab the mutex to serialize data as the connection is shared + * with the fore channel + */ + mutex_lock(&xprt->xpt_mutex); + if (test_bit(XPT_DEAD, &xprt->xpt_flags)) + len = -ENOTCONN; + else + len = bc_sendto(req); + mutex_unlock(&xprt->xpt_mutex); + + if (len > 0) + len = 0; + + return len; +} + +/* + * The close routine. Since this is client initiated, we do nothing + */ + +static void bc_close(struct rpc_xprt *xprt) +{ + xprt_disconnect_done(xprt); +} + +/* + * The xprt destroy routine. Again, because this connection is client + * initiated, we do nothing + */ + +static void bc_destroy(struct rpc_xprt *xprt) +{ + dprintk("RPC: bc_destroy xprt %p\n", xprt); + + xs_xprt_free(xprt); + module_put(THIS_MODULE); +} + +static const struct rpc_xprt_ops xs_local_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = xs_local_rpcbind, + .set_port = xs_local_set_port, + .connect = xs_local_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, + .send_request = xs_local_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_local_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, +}; + +static const struct rpc_xprt_ops xs_udp_ops = { + .set_buffer_size = xs_udp_set_buffer_size, + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_connect, + .get_srcaddr = xs_sock_srcaddr, + .get_srcport = xs_sock_srcport, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_udp_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_rtt, + .timer = xs_udp_timer, + .release_request = xprt_release_rqst_cong, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_udp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +}; + +static const struct rpc_xprt_ops xs_tcp_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_connect, + .get_srcaddr = xs_sock_srcaddr, + .get_srcport = xs_sock_srcport, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .prepare_request = xs_stream_prepare_request, + .send_request = xs_tcp_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = xs_tcp_shutdown, + .destroy = xs_destroy, + .set_connect_timeout = xs_tcp_set_connect_timeout, + .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +#ifdef CONFIG_SUNRPC_BACKCHANNEL + .bc_setup = xprt_setup_bc, + .bc_maxpayload = xs_tcp_bc_maxpayload, + .bc_num_slots = xprt_bc_max_slots, + .bc_free_rqst = xprt_free_bc_rqst, + .bc_destroy = xprt_destroy_bc, +#endif +}; + +/* + * The rpc_xprt_ops 
for the server backchannel + */ + +static const struct rpc_xprt_ops bc_tcp_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xprt_release_xprt, + .alloc_slot = xprt_alloc_slot, + .free_slot = xprt_free_slot, + .buf_alloc = bc_malloc, + .buf_free = bc_free, + .send_request = bc_send_request, + .wait_for_reply_request = xprt_wait_for_reply_request_def, + .close = bc_close, + .destroy = bc_destroy, + .print_stats = xs_tcp_print_stats, + .enable_swap = xs_enable_swap, + .disable_swap = xs_disable_swap, + .inject_disconnect = xs_inject_disconnect, +}; + +static int xs_init_anyaddr(const int family, struct sockaddr *sap) +{ + static const struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + }; + static const struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + }; + + switch (family) { + case AF_LOCAL: + break; + case AF_INET: + memcpy(sap, &sin, sizeof(sin)); + break; + case AF_INET6: + memcpy(sap, &sin6, sizeof(sin6)); + break; + default: + dprintk("RPC: %s: Bad address family\n", __func__); + return -EAFNOSUPPORT; + } + return 0; +} + +static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, + unsigned int slot_table_size, + unsigned int max_slot_table_size) +{ + struct rpc_xprt *xprt; + struct sock_xprt *new; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: xs_setup_xprt: address too large\n"); + return ERR_PTR(-EBADF); + } + + xprt = xprt_alloc(args->net, sizeof(*new), slot_table_size, + max_slot_table_size); + if (xprt == NULL) { + dprintk("RPC: xs_setup_xprt: couldn't allocate " + "rpc_xprt\n"); + return ERR_PTR(-ENOMEM); + } + + new = container_of(xprt, struct sock_xprt, xprt); + mutex_init(&new->recv_mutex); + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + if (args->srcaddr) + memcpy(&new->srcaddr, args->srcaddr, args->addrlen); + else { + int err; + err = xs_init_anyaddr(args->dstaddr->sa_family, + (struct sockaddr *)&new->srcaddr); + if (err != 0) { + xprt_free(xprt); + return ERR_PTR(err); + } + } + + return xprt; +} + +static const struct rpc_timeout xs_local_default_timeout = { + .to_initval = 10 * HZ, + .to_maxval = 10 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_local - Set up transport to use an AF_LOCAL socket + * @args: rpc transport creation arguments + * + * AF_LOCAL is a "tpi_cots_ord" transport, just like TCP + */ +static struct rpc_xprt *xs_setup_local(struct xprt_create *args) +{ + struct sockaddr_un *sun = (struct sockaddr_un *)args->dstaddr; + struct sock_xprt *transport; + struct rpc_xprt *xprt; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_max_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = 0; + xprt->xprt_class = &xs_local_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_local_ops; + xprt->timeout = &xs_local_default_timeout; + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_dummy_setup_socket); + + switch (sun->sun_family) { + case AF_LOCAL: + if (sun->sun_path[0] != '/') { + dprintk("RPC: bad AF_LOCAL address: %s\n", + sun->sun_path); + ret = ERR_PTR(-EINVAL); + goto out_err; + } + xprt_set_bound(xprt); + 
xs_format_peer_addresses(xprt, "local", RPCBIND_NETID_LOCAL); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + dprintk("RPC: set up xprt to %s via AF_LOCAL\n", + xprt->address_strings[RPC_DISPLAY_ADDR]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static const struct rpc_timeout xs_udp_default_timeout = { + .to_initval = 5 * HZ, + .to_maxval = 30 * HZ, + .to_increment = 5 * HZ, + .to_retries = 5, +}; + +/** + * xs_setup_udp - Set up transport to use a UDP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries, + xprt_udp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_UDP; + xprt->xprt_class = &xs_udp_transport; + /* XXX: header size can vary due to auth type, IPv6, etc. */ + xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_UDP_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_udp_ops; + + xprt->timeout = &xs_udp_default_timeout; + + INIT_WORK(&transport->recv_worker, xs_udp_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_udp_setup_socket); + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static const struct rpc_timeout xs_tcp_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_tcp - Set up transport to use a TCP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct rpc_xprt *ret; + unsigned int max_slot_table_size = xprt_max_tcp_slot_table_entries; + + if (args->flags & XPRT_CREATE_INFINITE_SLOTS) + max_slot_table_size = RPC_MAX_SLOT_TABLE_LIMIT; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + max_slot_table_size); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops 
= &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + xprt->max_reconnect_timeout = xprt->timeout->to_maxval; + xprt->connect_timeout = xprt->timeout->to_initval * + (xprt->timeout->to_retries + 1); + + INIT_WORK(&transport->recv_worker, xs_stream_data_receive_workfn); + INIT_WORK(&transport->error_worker, xs_error_handle); + INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_setup_socket); + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + if (xprt_bound(xprt)) + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + else + dprintk("RPC: set up xprt to %s (autobind) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + if (try_module_get(THIS_MODULE)) + return xprt; + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +/** + * xs_setup_bc_tcp - Set up transport to use a TCP backchannel socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + struct svc_sock *bc_sock; + struct rpc_xprt *ret; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries, + xprt_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->xprt_class = &xs_bc_tcp_transport; + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + xprt->timeout = &xs_tcp_default_timeout; + + /* backchannel */ + xprt_set_bound(xprt); + xprt->bind_timeout = 0; + xprt->reestablish_timeout = 0; + xprt->idle_timeout = 0; + + xprt->ops = &bc_tcp_ops; + + switch (addr->sa_family) { + case AF_INET: + xs_format_peer_addresses(xprt, "tcp", + RPCBIND_NETID_TCP); + break; + case AF_INET6: + xs_format_peer_addresses(xprt, "tcp", + RPCBIND_NETID_TCP6); + break; + default: + ret = ERR_PTR(-EAFNOSUPPORT); + goto out_err; + } + + dprintk("RPC: set up xprt to %s (port %s) via %s\n", + xprt->address_strings[RPC_DISPLAY_ADDR], + xprt->address_strings[RPC_DISPLAY_PORT], + xprt->address_strings[RPC_DISPLAY_PROTO]); + + /* + * Once we've associated a backchannel xprt with a connection, + * we want to keep it around as long as the connection lasts, + * in case we need to start using it for a backchannel again; + * this reference won't be dropped until bc_xprt is destroyed. 
+ */ + xprt_get(xprt); + args->bc_xprt->xpt_bc_xprt = xprt; + xprt->bc_xprt = args->bc_xprt; + bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt); + transport->sock = bc_sock->sk_sock; + transport->inet = bc_sock->sk_sk; + + /* + * Since we don't want connections for the backchannel, we set + * the xprt status to connected + */ + xprt_set_connected(xprt); + + if (try_module_get(THIS_MODULE)) + return xprt; + + args->bc_xprt->xpt_bc_xprt = NULL; + args->bc_xprt->xpt_bc_xps = NULL; + xprt_put(xprt); + ret = ERR_PTR(-EINVAL); +out_err: + xs_xprt_free(xprt); + return ret; +} + +static struct xprt_class xs_local_transport = { + .list = LIST_HEAD_INIT(xs_local_transport.list), + .name = "named UNIX socket", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_LOCAL, + .setup = xs_setup_local, + .netid = { "" }, +}; + +static struct xprt_class xs_udp_transport = { + .list = LIST_HEAD_INIT(xs_udp_transport.list), + .name = "udp", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_UDP, + .setup = xs_setup_udp, + .netid = { "udp", "udp6", "" }, +}; + +static struct xprt_class xs_tcp_transport = { + .list = LIST_HEAD_INIT(xs_tcp_transport.list), + .name = "tcp", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_TCP, + .setup = xs_setup_tcp, + .netid = { "tcp", "tcp6", "" }, +}; + +static struct xprt_class xs_bc_tcp_transport = { + .list = LIST_HEAD_INIT(xs_bc_tcp_transport.list), + .name = "tcp NFSv4.1 backchannel", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_BC_TCP, + .setup = xs_setup_bc_tcp, + .netid = { "" }, +}; + +/** + * init_socket_xprt - set up xprtsock's sysctls, register with RPC client + * + */ +int init_socket_xprt(void) +{ + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); + + xprt_register_transport(&xs_local_transport); + xprt_register_transport(&xs_udp_transport); + xprt_register_transport(&xs_tcp_transport); + xprt_register_transport(&xs_bc_tcp_transport); + + return 0; +} + +/** + * cleanup_socket_xprt - remove xprtsock's sysctls, unregister + * + */ +void cleanup_socket_xprt(void) +{ + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } + + xprt_unregister_transport(&xs_local_transport); + xprt_unregister_transport(&xs_udp_transport); + xprt_unregister_transport(&xs_tcp_transport); + xprt_unregister_transport(&xs_bc_tcp_transport); +} + +static int param_set_portnr(const char *val, const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_RESVPORT, + RPC_MAX_RESVPORT); +} + +static const struct kernel_param_ops param_ops_portnr = { + .set = param_set_portnr, + .get = param_get_uint, +}; + +#define param_check_portnr(name, p) \ + __param_check(name, p, unsigned int); + +module_param_named(min_resvport, xprt_min_resvport, portnr, 0644); +module_param_named(max_resvport, xprt_max_resvport, portnr, 0644); + +static int param_set_slot_table_size(const char *val, + const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_SLOT_TABLE, + RPC_MAX_SLOT_TABLE); +} + +static const struct kernel_param_ops param_ops_slot_table_size = { + .set = param_set_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_slot_table_size(name, p) \ + __param_check(name, p, unsigned int); + +static int param_set_max_slot_table_size(const char *val, + const struct kernel_param *kp) +{ + return param_set_uint_minmax(val, kp, + RPC_MIN_SLOT_TABLE, + RPC_MAX_SLOT_TABLE_LIMIT); +} + +static const struct kernel_param_ops 
param_ops_max_slot_table_size = { + .set = param_set_max_slot_table_size, + .get = param_get_uint, +}; + +#define param_check_max_slot_table_size(name, p) \ + __param_check(name, p, unsigned int); + +module_param_named(tcp_slot_table_entries, xprt_tcp_slot_table_entries, + slot_table_size, 0644); +module_param_named(tcp_max_slot_table_entries, xprt_max_tcp_slot_table_entries, + max_slot_table_size, 0644); +module_param_named(udp_slot_table_entries, xprt_udp_slot_table_entries, + slot_table_size, 0644); diff --git a/net/switchdev/Kconfig b/net/switchdev/Kconfig new file mode 100644 index 000000000..18a2d980e --- /dev/null +++ b/net/switchdev/Kconfig @@ -0,0 +1,14 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Configuration for Switch device support +# + +config NET_SWITCHDEV + bool "Switch (and switch-ish) device support" + depends on INET + help + This module provides glue between core networking code and device + drivers in order to support hardware switch chips in very generic + meaning of the word "switch". This include devices supporting L2/L3 but + also various flow offloading chips, including switches embedded into + SR-IOV NICs. diff --git a/net/switchdev/Makefile b/net/switchdev/Makefile new file mode 100644 index 000000000..c5561d7f3 --- /dev/null +++ b/net/switchdev/Makefile @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the Switch device API +# + +obj-y += switchdev.o diff --git a/net/switchdev/switchdev.c b/net/switchdev/switchdev.c new file mode 100644 index 000000000..8cc42aea1 --- /dev/null +++ b/net/switchdev/switchdev.c @@ -0,0 +1,864 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * net/switchdev/switchdev.c - Switch device API + * Copyright (c) 2014-2015 Jiri Pirko + * Copyright (c) 2014-2015 Scott Feldman + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static LIST_HEAD(deferred); +static DEFINE_SPINLOCK(deferred_lock); + +typedef void switchdev_deferred_func_t(struct net_device *dev, + const void *data); + +struct switchdev_deferred_item { + struct list_head list; + struct net_device *dev; + netdevice_tracker dev_tracker; + switchdev_deferred_func_t *func; + unsigned long data[]; +}; + +static struct switchdev_deferred_item *switchdev_deferred_dequeue(void) +{ + struct switchdev_deferred_item *dfitem; + + spin_lock_bh(&deferred_lock); + if (list_empty(&deferred)) { + dfitem = NULL; + goto unlock; + } + dfitem = list_first_entry(&deferred, + struct switchdev_deferred_item, list); + list_del(&dfitem->list); +unlock: + spin_unlock_bh(&deferred_lock); + return dfitem; +} + +/** + * switchdev_deferred_process - Process ops in deferred queue + * + * Called to flush the ops currently queued in deferred ops queue. + * rtnl_lock must be held. 
+ */ +void switchdev_deferred_process(void) +{ + struct switchdev_deferred_item *dfitem; + + ASSERT_RTNL(); + + while ((dfitem = switchdev_deferred_dequeue())) { + dfitem->func(dfitem->dev, dfitem->data); + netdev_put(dfitem->dev, &dfitem->dev_tracker); + kfree(dfitem); + } +} +EXPORT_SYMBOL_GPL(switchdev_deferred_process); + +static void switchdev_deferred_process_work(struct work_struct *work) +{ + rtnl_lock(); + switchdev_deferred_process(); + rtnl_unlock(); +} + +static DECLARE_WORK(deferred_process_work, switchdev_deferred_process_work); + +static int switchdev_deferred_enqueue(struct net_device *dev, + const void *data, size_t data_len, + switchdev_deferred_func_t *func) +{ + struct switchdev_deferred_item *dfitem; + + dfitem = kmalloc(struct_size(dfitem, data, data_len), GFP_ATOMIC); + if (!dfitem) + return -ENOMEM; + dfitem->dev = dev; + dfitem->func = func; + memcpy(dfitem->data, data, data_len); + netdev_hold(dev, &dfitem->dev_tracker, GFP_ATOMIC); + spin_lock_bh(&deferred_lock); + list_add_tail(&dfitem->list, &deferred); + spin_unlock_bh(&deferred_lock); + schedule_work(&deferred_process_work); + return 0; +} + +static int switchdev_port_attr_notify(enum switchdev_notifier_type nt, + struct net_device *dev, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack) +{ + int err; + int rc; + + struct switchdev_notifier_port_attr_info attr_info = { + .attr = attr, + .handled = false, + }; + + rc = call_switchdev_blocking_notifiers(nt, dev, + &attr_info.info, extack); + err = notifier_to_errno(rc); + if (err) { + WARN_ON(!attr_info.handled); + return err; + } + + if (!attr_info.handled) + return -EOPNOTSUPP; + + return 0; +} + +static int switchdev_port_attr_set_now(struct net_device *dev, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack) +{ + return switchdev_port_attr_notify(SWITCHDEV_PORT_ATTR_SET, dev, attr, + extack); +} + +static void switchdev_port_attr_set_deferred(struct net_device *dev, + const void *data) +{ + const struct switchdev_attr *attr = data; + int err; + + err = switchdev_port_attr_set_now(dev, attr, NULL); + if (err && err != -EOPNOTSUPP) + netdev_err(dev, "failed (err=%d) to set attribute (id=%d)\n", + err, attr->id); + if (attr->complete) + attr->complete(dev, err, attr->complete_priv); +} + +static int switchdev_port_attr_set_defer(struct net_device *dev, + const struct switchdev_attr *attr) +{ + return switchdev_deferred_enqueue(dev, attr, sizeof(*attr), + switchdev_port_attr_set_deferred); +} + +/** + * switchdev_port_attr_set - Set port attribute + * + * @dev: port device + * @attr: attribute to set + * @extack: netlink extended ack, for error message propagation + * + * rtnl_lock must be held and must not be in atomic section, + * in case SWITCHDEV_F_DEFER flag is not set. 
+ */ +int switchdev_port_attr_set(struct net_device *dev, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack) +{ + if (attr->flags & SWITCHDEV_F_DEFER) + return switchdev_port_attr_set_defer(dev, attr); + ASSERT_RTNL(); + return switchdev_port_attr_set_now(dev, attr, extack); +} +EXPORT_SYMBOL_GPL(switchdev_port_attr_set); + +static size_t switchdev_obj_size(const struct switchdev_obj *obj) +{ + switch (obj->id) { + case SWITCHDEV_OBJ_ID_PORT_VLAN: + return sizeof(struct switchdev_obj_port_vlan); + case SWITCHDEV_OBJ_ID_PORT_MDB: + return sizeof(struct switchdev_obj_port_mdb); + case SWITCHDEV_OBJ_ID_HOST_MDB: + return sizeof(struct switchdev_obj_port_mdb); + default: + BUG(); + } + return 0; +} + +static int switchdev_port_obj_notify(enum switchdev_notifier_type nt, + struct net_device *dev, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack) +{ + int rc; + int err; + + struct switchdev_notifier_port_obj_info obj_info = { + .obj = obj, + .handled = false, + }; + + rc = call_switchdev_blocking_notifiers(nt, dev, &obj_info.info, extack); + err = notifier_to_errno(rc); + if (err) { + WARN_ON(!obj_info.handled); + return err; + } + if (!obj_info.handled) + return -EOPNOTSUPP; + return 0; +} + +static void switchdev_port_obj_add_deferred(struct net_device *dev, + const void *data) +{ + const struct switchdev_obj *obj = data; + int err; + + ASSERT_RTNL(); + err = switchdev_port_obj_notify(SWITCHDEV_PORT_OBJ_ADD, + dev, obj, NULL); + if (err && err != -EOPNOTSUPP) + netdev_err(dev, "failed (err=%d) to add object (id=%d)\n", + err, obj->id); + if (obj->complete) + obj->complete(dev, err, obj->complete_priv); +} + +static int switchdev_port_obj_add_defer(struct net_device *dev, + const struct switchdev_obj *obj) +{ + return switchdev_deferred_enqueue(dev, obj, switchdev_obj_size(obj), + switchdev_port_obj_add_deferred); +} + +/** + * switchdev_port_obj_add - Add port object + * + * @dev: port device + * @obj: object to add + * @extack: netlink extended ack + * + * rtnl_lock must be held and must not be in atomic section, + * in case SWITCHDEV_F_DEFER flag is not set. + */ +int switchdev_port_obj_add(struct net_device *dev, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack) +{ + if (obj->flags & SWITCHDEV_F_DEFER) + return switchdev_port_obj_add_defer(dev, obj); + ASSERT_RTNL(); + return switchdev_port_obj_notify(SWITCHDEV_PORT_OBJ_ADD, + dev, obj, extack); +} +EXPORT_SYMBOL_GPL(switchdev_port_obj_add); + +static int switchdev_port_obj_del_now(struct net_device *dev, + const struct switchdev_obj *obj) +{ + return switchdev_port_obj_notify(SWITCHDEV_PORT_OBJ_DEL, + dev, obj, NULL); +} + +static void switchdev_port_obj_del_deferred(struct net_device *dev, + const void *data) +{ + const struct switchdev_obj *obj = data; + int err; + + err = switchdev_port_obj_del_now(dev, obj); + if (err && err != -EOPNOTSUPP) + netdev_err(dev, "failed (err=%d) to del object (id=%d)\n", + err, obj->id); + if (obj->complete) + obj->complete(dev, err, obj->complete_priv); +} + +static int switchdev_port_obj_del_defer(struct net_device *dev, + const struct switchdev_obj *obj) +{ + return switchdev_deferred_enqueue(dev, obj, switchdev_obj_size(obj), + switchdev_port_obj_del_deferred); +} + +/** + * switchdev_port_obj_del - Delete port object + * + * @dev: port device + * @obj: object to delete + * + * rtnl_lock must be held and must not be in atomic section, + * in case SWITCHDEV_F_DEFER flag is not set. 
+ */ +int switchdev_port_obj_del(struct net_device *dev, + const struct switchdev_obj *obj) +{ + if (obj->flags & SWITCHDEV_F_DEFER) + return switchdev_port_obj_del_defer(dev, obj); + ASSERT_RTNL(); + return switchdev_port_obj_del_now(dev, obj); +} +EXPORT_SYMBOL_GPL(switchdev_port_obj_del); + +static ATOMIC_NOTIFIER_HEAD(switchdev_notif_chain); +static BLOCKING_NOTIFIER_HEAD(switchdev_blocking_notif_chain); + +/** + * register_switchdev_notifier - Register notifier + * @nb: notifier_block + * + * Register switch device notifier. + */ +int register_switchdev_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_register(&switchdev_notif_chain, nb); +} +EXPORT_SYMBOL_GPL(register_switchdev_notifier); + +/** + * unregister_switchdev_notifier - Unregister notifier + * @nb: notifier_block + * + * Unregister switch device notifier. + */ +int unregister_switchdev_notifier(struct notifier_block *nb) +{ + return atomic_notifier_chain_unregister(&switchdev_notif_chain, nb); +} +EXPORT_SYMBOL_GPL(unregister_switchdev_notifier); + +/** + * call_switchdev_notifiers - Call notifiers + * @val: value passed unmodified to notifier function + * @dev: port device + * @info: notifier information data + * @extack: netlink extended ack + * Call all network notifier blocks. + */ +int call_switchdev_notifiers(unsigned long val, struct net_device *dev, + struct switchdev_notifier_info *info, + struct netlink_ext_ack *extack) +{ + info->dev = dev; + info->extack = extack; + return atomic_notifier_call_chain(&switchdev_notif_chain, val, info); +} +EXPORT_SYMBOL_GPL(call_switchdev_notifiers); + +int register_switchdev_blocking_notifier(struct notifier_block *nb) +{ + struct blocking_notifier_head *chain = &switchdev_blocking_notif_chain; + + return blocking_notifier_chain_register(chain, nb); +} +EXPORT_SYMBOL_GPL(register_switchdev_blocking_notifier); + +int unregister_switchdev_blocking_notifier(struct notifier_block *nb) +{ + struct blocking_notifier_head *chain = &switchdev_blocking_notif_chain; + + return blocking_notifier_chain_unregister(chain, nb); +} +EXPORT_SYMBOL_GPL(unregister_switchdev_blocking_notifier); + +int call_switchdev_blocking_notifiers(unsigned long val, struct net_device *dev, + struct switchdev_notifier_info *info, + struct netlink_ext_ack *extack) +{ + info->dev = dev; + info->extack = extack; + return blocking_notifier_call_chain(&switchdev_blocking_notif_chain, + val, info); +} +EXPORT_SYMBOL_GPL(call_switchdev_blocking_notifiers); + +struct switchdev_nested_priv { + bool (*check_cb)(const struct net_device *dev); + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev); + const struct net_device *dev; + struct net_device *lower_dev; +}; + +static int switchdev_lower_dev_walk(struct net_device *lower_dev, + struct netdev_nested_priv *priv) +{ + struct switchdev_nested_priv *switchdev_priv = priv->data; + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev); + bool (*check_cb)(const struct net_device *dev); + const struct net_device *dev; + + check_cb = switchdev_priv->check_cb; + foreign_dev_check_cb = switchdev_priv->foreign_dev_check_cb; + dev = switchdev_priv->dev; + + if (check_cb(lower_dev) && !foreign_dev_check_cb(lower_dev, dev)) { + switchdev_priv->lower_dev = lower_dev; + return 1; + } + + return 0; +} + +static struct net_device * +switchdev_lower_dev_find_rcu(struct net_device *dev, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const 
struct net_device *dev, + const struct net_device *foreign_dev)) +{ + struct switchdev_nested_priv switchdev_priv = { + .check_cb = check_cb, + .foreign_dev_check_cb = foreign_dev_check_cb, + .dev = dev, + .lower_dev = NULL, + }; + struct netdev_nested_priv priv = { + .data = &switchdev_priv, + }; + + netdev_walk_all_lower_dev_rcu(dev, switchdev_lower_dev_walk, &priv); + + return switchdev_priv.lower_dev; +} + +static struct net_device * +switchdev_lower_dev_find(struct net_device *dev, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev)) +{ + struct switchdev_nested_priv switchdev_priv = { + .check_cb = check_cb, + .foreign_dev_check_cb = foreign_dev_check_cb, + .dev = dev, + .lower_dev = NULL, + }; + struct netdev_nested_priv priv = { + .data = &switchdev_priv, + }; + + netdev_walk_all_lower_dev(dev, switchdev_lower_dev_walk, &priv); + + return switchdev_priv.lower_dev; +} + +static int __switchdev_handle_fdb_event_to_device(struct net_device *dev, + struct net_device *orig_dev, unsigned long event, + const struct switchdev_notifier_fdb_info *fdb_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*mod_cb)(struct net_device *dev, struct net_device *orig_dev, + unsigned long event, const void *ctx, + const struct switchdev_notifier_fdb_info *fdb_info)) +{ + const struct switchdev_notifier_info *info = &fdb_info->info; + struct net_device *br, *lower_dev, *switchdev; + struct list_head *iter; + int err = -EOPNOTSUPP; + + if (check_cb(dev)) + return mod_cb(dev, orig_dev, event, info->ctx, fdb_info); + + /* Recurse through lower interfaces in case the FDB entry is pointing + * towards a bridge or a LAG device. + */ + netdev_for_each_lower_dev(dev, lower_dev, iter) { + /* Do not propagate FDB entries across bridges */ + if (netif_is_bridge_master(lower_dev)) + continue; + + /* Bridge ports might be either us, or LAG interfaces + * that we offload. + */ + if (!check_cb(lower_dev) && + !switchdev_lower_dev_find_rcu(lower_dev, check_cb, + foreign_dev_check_cb)) + continue; + + err = __switchdev_handle_fdb_event_to_device(lower_dev, orig_dev, + event, fdb_info, check_cb, + foreign_dev_check_cb, + mod_cb); + if (err && err != -EOPNOTSUPP) + return err; + } + + /* Event is neither on a bridge nor a LAG. Check whether it is on an + * interface that is in a bridge with us. 
+ */ + br = netdev_master_upper_dev_get_rcu(dev); + if (!br || !netif_is_bridge_master(br)) + return 0; + + switchdev = switchdev_lower_dev_find_rcu(br, check_cb, foreign_dev_check_cb); + if (!switchdev) + return 0; + + if (!foreign_dev_check_cb(switchdev, dev)) + return err; + + return __switchdev_handle_fdb_event_to_device(br, orig_dev, event, fdb_info, + check_cb, foreign_dev_check_cb, + mod_cb); +} + +int switchdev_handle_fdb_event_to_device(struct net_device *dev, unsigned long event, + const struct switchdev_notifier_fdb_info *fdb_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*mod_cb)(struct net_device *dev, struct net_device *orig_dev, + unsigned long event, const void *ctx, + const struct switchdev_notifier_fdb_info *fdb_info)) +{ + int err; + + err = __switchdev_handle_fdb_event_to_device(dev, dev, event, fdb_info, + check_cb, foreign_dev_check_cb, + mod_cb); + if (err == -EOPNOTSUPP) + err = 0; + + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_fdb_event_to_device); + +static int __switchdev_handle_port_obj_add(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*add_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack)) +{ + struct switchdev_notifier_info *info = &port_obj_info->info; + struct net_device *br, *lower_dev, *switchdev; + struct netlink_ext_ack *extack; + struct list_head *iter; + int err = -EOPNOTSUPP; + + extack = switchdev_notifier_info_to_extack(info); + + if (check_cb(dev)) { + err = add_cb(dev, info->ctx, port_obj_info->obj, extack); + if (err != -EOPNOTSUPP) + port_obj_info->handled = true; + return err; + } + + /* Switch ports might be stacked under e.g. a LAG. Ignore the + * unsupported devices, another driver might be able to handle them. But + * propagate to the callers any hard errors. + * + * If the driver does its own bookkeeping of stacked ports, it's not + * necessary to go through this helper. + */ + netdev_for_each_lower_dev(dev, lower_dev, iter) { + if (netif_is_bridge_master(lower_dev)) + continue; + + /* When searching for switchdev interfaces that are neighbors + * of foreign ones, and @dev is a bridge, do not recurse on the + * foreign interface again, it was already visited. + */ + if (foreign_dev_check_cb && !check_cb(lower_dev) && + !switchdev_lower_dev_find(lower_dev, check_cb, foreign_dev_check_cb)) + continue; + + err = __switchdev_handle_port_obj_add(lower_dev, port_obj_info, + check_cb, foreign_dev_check_cb, + add_cb); + if (err && err != -EOPNOTSUPP) + return err; + } + + /* Event is neither on a bridge nor a LAG. Check whether it is on an + * interface that is in a bridge with us. 
+ */ + if (!foreign_dev_check_cb) + return err; + + br = netdev_master_upper_dev_get(dev); + if (!br || !netif_is_bridge_master(br)) + return err; + + switchdev = switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb); + if (!switchdev) + return err; + + if (!foreign_dev_check_cb(switchdev, dev)) + return err; + + return __switchdev_handle_port_obj_add(br, port_obj_info, check_cb, + foreign_dev_check_cb, add_cb); +} + +/* Pass through a port object addition, if @dev passes @check_cb, or replicate + * it towards all lower interfaces of @dev that pass @check_cb, if @dev is a + * bridge or a LAG. + */ +int switchdev_handle_port_obj_add(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + int (*add_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack)) +{ + int err; + + err = __switchdev_handle_port_obj_add(dev, port_obj_info, check_cb, + NULL, add_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_obj_add); + +/* Same as switchdev_handle_port_obj_add(), except if object is notified on a + * @dev that passes @foreign_dev_check_cb, it is replicated towards all devices + * that pass @check_cb and are in the same bridge as @dev. + */ +int switchdev_handle_port_obj_add_foreign(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*add_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj, + struct netlink_ext_ack *extack)) +{ + int err; + + err = __switchdev_handle_port_obj_add(dev, port_obj_info, check_cb, + foreign_dev_check_cb, add_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_obj_add_foreign); + +static int __switchdev_handle_port_obj_del(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*del_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj)) +{ + struct switchdev_notifier_info *info = &port_obj_info->info; + struct net_device *br, *lower_dev, *switchdev; + struct list_head *iter; + int err = -EOPNOTSUPP; + + if (check_cb(dev)) { + err = del_cb(dev, info->ctx, port_obj_info->obj); + if (err != -EOPNOTSUPP) + port_obj_info->handled = true; + return err; + } + + /* Switch ports might be stacked under e.g. a LAG. Ignore the + * unsupported devices, another driver might be able to handle them. But + * propagate to the callers any hard errors. + * + * If the driver does its own bookkeeping of stacked ports, it's not + * necessary to go through this helper. + */ + netdev_for_each_lower_dev(dev, lower_dev, iter) { + if (netif_is_bridge_master(lower_dev)) + continue; + + /* When searching for switchdev interfaces that are neighbors + * of foreign ones, and @dev is a bridge, do not recurse on the + * foreign interface again, it was already visited. 
+ */ + if (foreign_dev_check_cb && !check_cb(lower_dev) && + !switchdev_lower_dev_find(lower_dev, check_cb, foreign_dev_check_cb)) + continue; + + err = __switchdev_handle_port_obj_del(lower_dev, port_obj_info, + check_cb, foreign_dev_check_cb, + del_cb); + if (err && err != -EOPNOTSUPP) + return err; + } + + /* Event is neither on a bridge nor a LAG. Check whether it is on an + * interface that is in a bridge with us. + */ + if (!foreign_dev_check_cb) + return err; + + br = netdev_master_upper_dev_get(dev); + if (!br || !netif_is_bridge_master(br)) + return err; + + switchdev = switchdev_lower_dev_find(br, check_cb, foreign_dev_check_cb); + if (!switchdev) + return err; + + if (!foreign_dev_check_cb(switchdev, dev)) + return err; + + return __switchdev_handle_port_obj_del(br, port_obj_info, check_cb, + foreign_dev_check_cb, del_cb); +} + +/* Pass through a port object deletion, if @dev passes @check_cb, or replicate + * it towards all lower interfaces of @dev that pass @check_cb, if @dev is a + * bridge or a LAG. + */ +int switchdev_handle_port_obj_del(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + int (*del_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj)) +{ + int err; + + err = __switchdev_handle_port_obj_del(dev, port_obj_info, check_cb, + NULL, del_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_obj_del); + +/* Same as switchdev_handle_port_obj_del(), except if object is notified on a + * @dev that passes @foreign_dev_check_cb, it is replicated towards all devices + * that pass @check_cb and are in the same bridge as @dev. + */ +int switchdev_handle_port_obj_del_foreign(struct net_device *dev, + struct switchdev_notifier_port_obj_info *port_obj_info, + bool (*check_cb)(const struct net_device *dev), + bool (*foreign_dev_check_cb)(const struct net_device *dev, + const struct net_device *foreign_dev), + int (*del_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_obj *obj)) +{ + int err; + + err = __switchdev_handle_port_obj_del(dev, port_obj_info, check_cb, + foreign_dev_check_cb, del_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_obj_del_foreign); + +static int __switchdev_handle_port_attr_set(struct net_device *dev, + struct switchdev_notifier_port_attr_info *port_attr_info, + bool (*check_cb)(const struct net_device *dev), + int (*set_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack)) +{ + struct switchdev_notifier_info *info = &port_attr_info->info; + struct netlink_ext_ack *extack; + struct net_device *lower_dev; + struct list_head *iter; + int err = -EOPNOTSUPP; + + extack = switchdev_notifier_info_to_extack(info); + + if (check_cb(dev)) { + err = set_cb(dev, info->ctx, port_attr_info->attr, extack); + if (err != -EOPNOTSUPP) + port_attr_info->handled = true; + return err; + } + + /* Switch ports might be stacked under e.g. a LAG. Ignore the + * unsupported devices, another driver might be able to handle them. But + * propagate to the callers any hard errors. + * + * If the driver does its own bookkeeping of stacked ports, it's not + * necessary to go through this helper. 
+ */ +	netdev_for_each_lower_dev(dev, lower_dev, iter) { + if (netif_is_bridge_master(lower_dev)) + continue; + + err = __switchdev_handle_port_attr_set(lower_dev, port_attr_info, + check_cb, set_cb); + if (err && err != -EOPNOTSUPP) + return err; + } + + return err; +} + +int switchdev_handle_port_attr_set(struct net_device *dev, + struct switchdev_notifier_port_attr_info *port_attr_info, + bool (*check_cb)(const struct net_device *dev), + int (*set_cb)(struct net_device *dev, const void *ctx, + const struct switchdev_attr *attr, + struct netlink_ext_ack *extack)) +{ + int err; + + err = __switchdev_handle_port_attr_set(dev, port_attr_info, check_cb, + set_cb); + if (err == -EOPNOTSUPP) + err = 0; + return err; +} +EXPORT_SYMBOL_GPL(switchdev_handle_port_attr_set); + +int switchdev_bridge_port_offload(struct net_device *brport_dev, + struct net_device *dev, const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb, + bool tx_fwd_offload, + struct netlink_ext_ack *extack) +{ + struct switchdev_notifier_brport_info brport_info = { + .brport = { + .dev = dev, + .ctx = ctx, + .atomic_nb = atomic_nb, + .blocking_nb = blocking_nb, + .tx_fwd_offload = tx_fwd_offload, + }, + }; + int err; + + ASSERT_RTNL(); + + err = call_switchdev_blocking_notifiers(SWITCHDEV_BRPORT_OFFLOADED, + brport_dev, &brport_info.info, + extack); + return notifier_to_errno(err); +} +EXPORT_SYMBOL_GPL(switchdev_bridge_port_offload); + +void switchdev_bridge_port_unoffload(struct net_device *brport_dev, + const void *ctx, + struct notifier_block *atomic_nb, + struct notifier_block *blocking_nb) +{ + struct switchdev_notifier_brport_info brport_info = { + .brport = { + .ctx = ctx, + .atomic_nb = atomic_nb, + .blocking_nb = blocking_nb, + }, + }; + + ASSERT_RTNL(); + + call_switchdev_blocking_notifiers(SWITCHDEV_BRPORT_UNOFFLOADED, + brport_dev, &brport_info.info, + NULL); +} +EXPORT_SYMBOL_GPL(switchdev_bridge_port_unoffload); diff --git a/net/sysctl_net.c b/net/sysctl_net.c new file mode 100644 index 000000000..4b45ed631 --- /dev/null +++ b/net/sysctl_net.c @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* -*- linux-c -*- + * sysctl_net.c: sysctl interface to net subsystem. + * + * Begun April 1, 1996, Mike Shaver. + * Added /proc/sys/net directories for each protocol family. [MS] + * + * Revision 1.2 1996/05/08 20:24:40 shaver + * Added bits for NET_BRIDGE and the NET_IPV4_ARP stuff and + * NET_IPV4_IP_FORWARD. + * + * + */ + +#include +#include +#include +#include + +#include + +#ifdef CONFIG_INET +#include +#endif + +#ifdef CONFIG_NET +#include +#endif + +static struct ctl_table_set * +net_ctl_header_lookup(struct ctl_table_root *root) +{ + return &current->nsproxy->net_ns->sysctls; +} + +static int is_seen(struct ctl_table_set *set) +{ + return &current->nsproxy->net_ns->sysctls == set; +} + +/* Return standard mode bits for table entry. */ +static int net_ctl_permissions(struct ctl_table_header *head, + struct ctl_table *table) +{ + struct net *net = container_of(head->set, struct net, sysctls); + + /* Allow network administrator to have same access as root.
*/ + if (ns_capable_noaudit(net->user_ns, CAP_NET_ADMIN)) { + int mode = (table->mode >> 6) & 7; + return (mode << 6) | (mode << 3) | mode; + } + + return table->mode; +} + +static void net_ctl_set_ownership(struct ctl_table_header *head, + struct ctl_table *table, + kuid_t *uid, kgid_t *gid) +{ + struct net *net = container_of(head->set, struct net, sysctls); + kuid_t ns_root_uid; + kgid_t ns_root_gid; + + ns_root_uid = make_kuid(net->user_ns, 0); + if (uid_valid(ns_root_uid)) + *uid = ns_root_uid; + + ns_root_gid = make_kgid(net->user_ns, 0); + if (gid_valid(ns_root_gid)) + *gid = ns_root_gid; +} + +static struct ctl_table_root net_sysctl_root = { + .lookup = net_ctl_header_lookup, + .permissions = net_ctl_permissions, + .set_ownership = net_ctl_set_ownership, +}; + +static int __net_init sysctl_net_init(struct net *net) +{ + setup_sysctl_set(&net->sysctls, &net_sysctl_root, is_seen); + return 0; +} + +static void __net_exit sysctl_net_exit(struct net *net) +{ + retire_sysctl_set(&net->sysctls); +} + +static struct pernet_operations sysctl_pernet_ops = { + .init = sysctl_net_init, + .exit = sysctl_net_exit, +}; + +static struct ctl_table_header *net_header; +__init int net_sysctl_init(void) +{ + static struct ctl_table empty[1]; + int ret = -ENOMEM; + /* Avoid limitations in the sysctl implementation by + * registering "/proc/sys/net" as an empty directory not in a + * network namespace. + */ + net_header = register_sysctl("net", empty); + if (!net_header) + goto out; + ret = register_pernet_subsys(&sysctl_pernet_ops); + if (ret) + goto out1; +out: + return ret; +out1: + unregister_sysctl_table(net_header); + net_header = NULL; + goto out; +} + +/* Verify that sysctls for non-init netns are safe by either: + * 1) being read-only, or + * 2) having a data pointer which points outside of the global kernel/module + * data segment, and rather into the heap where a per-net object was + * allocated. + */ +static void ensure_safe_net_sysctl(struct net *net, const char *path, + struct ctl_table *table) +{ + struct ctl_table *ent; + + pr_debug("Registering net sysctl (net %p): %s\n", net, path); + for (ent = table; ent->procname; ent++) { + unsigned long addr; + const char *where; + + pr_debug(" procname=%s mode=%o proc_handler=%ps data=%p\n", + ent->procname, ent->mode, ent->proc_handler, ent->data); + + /* If it's not writable inside the netns, then it can't hurt. */ + if ((ent->mode & 0222) == 0) { + pr_debug(" Not writable by anyone\n"); + continue; + } + + /* Where does data point? */ + addr = (unsigned long)ent->data; + if (is_module_address(addr)) + where = "module"; + else if (is_kernel_core_data(addr)) + where = "kernel"; + else + continue; + + /* If it is writable and points to kernel/module global + * data, then it's probably a netns leak. 
+ */ + WARN(1, "sysctl %s/%s: data points to %s global data: %ps\n", + path, ent->procname, where, ent->data); + + /* Make it "safe" by dropping writable perms */ + ent->mode &= ~0222; + } +} + +struct ctl_table_header *register_net_sysctl(struct net *net, + const char *path, struct ctl_table *table) +{ + if (!net_eq(net, &init_net)) + ensure_safe_net_sysctl(net, path, table); + + return __register_sysctl_table(&net->sysctls, path, table); +} +EXPORT_SYMBOL_GPL(register_net_sysctl); + +void unregister_net_sysctl_table(struct ctl_table_header *header) +{ + unregister_sysctl_table(header); +} +EXPORT_SYMBOL_GPL(unregister_net_sysctl_table); diff --git a/net/tipc/Kconfig b/net/tipc/Kconfig new file mode 100644 index 000000000..be1c4003d --- /dev/null +++ b/net/tipc/Kconfig @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# TIPC configuration +# + +menuconfig TIPC + tristate "The TIPC Protocol" + depends on INET + depends on IPV6 || IPV6=n + help + The Transparent Inter Process Communication (TIPC) protocol is + specially designed for intra cluster communication. This protocol + originates from Ericsson where it has been used in carrier grade + cluster applications for many years. + + For more information about TIPC, see http://tipc.sourceforge.net. + + This protocol support is also available as a module ( = code which + can be inserted in and removed from the running kernel whenever you + want). The module will be called tipc. If you want to compile it + as a module, say M here and read . + + If in doubt, say N. + +config TIPC_MEDIA_IB + bool "InfiniBand media type support" + depends on TIPC && INFINIBAND_IPOIB + help + Saying Y here will enable support for running TIPC on + IP-over-InfiniBand devices. +config TIPC_MEDIA_UDP + bool "IP/UDP media type support" + depends on TIPC + select NET_UDP_TUNNEL + help + Saying Y here will enable support for running TIPC over IP/UDP + bool + default y +config TIPC_CRYPTO + bool "TIPC encryption support" + depends on TIPC + select CRYPTO + select CRYPTO_AES + select CRYPTO_GCM + help + Saying Y here will enable support for TIPC encryption. + All TIPC messages will be encrypted/decrypted by using the currently most + advanced algorithm: AEAD AES-GCM (like IPSec or TLS) before leaving/ + entering the TIPC stack. + Key setting from user-space is performed via netlink by a user program + (e.g. the iproute2 'tipc' tool). + bool + default y + +config TIPC_DIAG + tristate "TIPC: socket monitoring interface" + depends on TIPC + default y + help + Support for TIPC socket monitoring interface used by ss tool. + If unsure, say Y. 
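As an aside to the per-netns sysctl interface added in net/sysctl_net.c above: ensure_safe_net_sysctl() expects a writable entry's ->data to point at per-net heap state rather than kernel/module global data. The following is an illustrative sketch only, not part of this patch; the "example_*" names and the "net/example" path are hypothetical, and only register_net_sysctl() and the standard ctl_table helpers come from the code above.

/* Illustrative sketch -- hypothetical subsystem registering a per-net sysctl */
#include <linux/slab.h>
#include <linux/string.h>
#include <linux/sysctl.h>
#include <net/net_namespace.h>

struct example_pernet_state {			/* hypothetical per-net state */
	int example_value;
	struct ctl_table_header *hdr;
};

static struct ctl_table example_table_template[] = {
	{
		.procname	= "example_value",
		.maxlen		= sizeof(int),
		.mode		= 0644,
		.proc_handler	= proc_dointvec,
	},
	{ }
};

static int example_sysctl_net_init(struct net *net,
				   struct example_pernet_state *st)
{
	struct ctl_table *table;

	/* Duplicate the template so each netns gets its own table... */
	table = kmemdup(example_table_template,
			sizeof(example_table_template), GFP_KERNEL);
	if (!table)
		return -ENOMEM;

	/* ...and point ->data at per-net heap state, which is what the
	 * ensure_safe_net_sysctl() check above is looking for.
	 */
	table[0].data = &st->example_value;

	st->hdr = register_net_sysctl(net, "net/example", table);
	if (!st->hdr) {
		kfree(table);
		return -ENOMEM;
	}
	return 0;
}

Teardown would pair this with unregister_net_sysctl_table(st->hdr) followed by kfree(table).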
diff --git a/net/tipc/Makefile b/net/tipc/Makefile new file mode 100644 index 000000000..ee49a9f1d --- /dev/null +++ b/net/tipc/Makefile @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux TIPC layer +# + +obj-$(CONFIG_TIPC) := tipc.o + +tipc-y += addr.o bcast.o bearer.o \ + core.o link.o discover.o msg.o \ + name_distr.o subscr.o monitor.o name_table.o net.o \ + netlink.o netlink_compat.o node.o socket.o eth_media.o \ + topsrv.o group.o trace.o + +CFLAGS_trace.o += -I$(src) + +tipc-$(CONFIG_TIPC_MEDIA_UDP) += udp_media.o +tipc-$(CONFIG_TIPC_MEDIA_IB) += ib_media.o +tipc-$(CONFIG_SYSCTL) += sysctl.o +tipc-$(CONFIG_TIPC_CRYPTO) += crypto.o + + +obj-$(CONFIG_TIPC_DIAG) += diag.o diff --git a/net/tipc/addr.c b/net/tipc/addr.c new file mode 100644 index 000000000..fd0796269 --- /dev/null +++ b/net/tipc/addr.c @@ -0,0 +1,124 @@ +/* + * net/tipc/addr.c: TIPC address utility routines + * + * Copyright (c) 2000-2006, 2018, Ericsson AB + * Copyright (c) 2004-2005, 2010-2011, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include "addr.h" +#include "core.h" + +bool tipc_in_scope(bool legacy_format, u32 domain, u32 addr) +{ + if (!domain || (domain == addr)) + return true; + if (!legacy_format) + return false; + if (domain == tipc_cluster_mask(addr)) /* domain */ + return true; + if (domain == (addr & TIPC_ZONE_CLUSTER_MASK)) /* domain */ + return true; + if (domain == (addr & TIPC_ZONE_MASK)) /* domain */ + return true; + return false; +} + +void tipc_set_node_id(struct net *net, u8 *id) +{ + struct tipc_net *tn = tipc_net(net); + + memcpy(tn->node_id, id, NODE_ID_LEN); + tipc_nodeid2string(tn->node_id_string, id); + tn->trial_addr = hash128to32(id); + pr_info("Node identity %s, cluster identity %u\n", + tipc_own_id_string(net), tn->net_id); +} + +void tipc_set_node_addr(struct net *net, u32 addr) +{ + struct tipc_net *tn = tipc_net(net); + u8 node_id[NODE_ID_LEN] = {0,}; + + tn->node_addr = addr; + if (!tipc_own_id(net)) { + sprintf(node_id, "%x", addr); + tipc_set_node_id(net, node_id); + } + tn->trial_addr = addr; + tn->addr_trial_end = jiffies; + pr_info("Node number set to %u\n", addr); +} + +char *tipc_nodeid2string(char *str, u8 *id) +{ + int i; + u8 c; + + /* Already a string ? */ + for (i = 0; i < NODE_ID_LEN; i++) { + c = id[i]; + if (c >= '0' && c <= '9') + continue; + if (c >= 'A' && c <= 'Z') + continue; + if (c >= 'a' && c <= 'z') + continue; + if (c == '.') + continue; + if (c == ':') + continue; + if (c == '_') + continue; + if (c == '-') + continue; + if (c == '@') + continue; + if (c != 0) + break; + } + if (i == NODE_ID_LEN) { + memcpy(str, id, NODE_ID_LEN); + str[NODE_ID_LEN] = 0; + return str; + } + + /* Translate to hex string */ + for (i = 0; i < NODE_ID_LEN; i++) + sprintf(&str[2 * i], "%02x", id[i]); + + /* Strip off trailing zeroes */ + for (i = NODE_ID_STR_LEN - 2; str[i] == '0'; i--) + str[i] = 0; + + return str; +} diff --git a/net/tipc/addr.h b/net/tipc/addr.h new file mode 100644 index 000000000..0772cfada --- /dev/null +++ b/net/tipc/addr.h @@ -0,0 +1,136 @@ +/* + * net/tipc/addr.h: Include file for TIPC address utility routines + * + * Copyright (c) 2000-2006, 2018, Ericsson AB + * Copyright (c) 2004-2005, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_ADDR_H +#define _TIPC_ADDR_H + +#include +#include +#include +#include +#include "core.h" + +/* Struct tipc_uaddr: internal version of struct sockaddr_tipc. + * Must be kept aligned both regarding field positions and size. + */ +struct tipc_uaddr { + unsigned short family; + unsigned char addrtype; + signed char scope; + union { + struct { + struct tipc_service_addr sa; + u32 lookup_node; + }; + struct tipc_service_range sr; + struct tipc_socket_addr sk; + }; +}; + +static inline void tipc_uaddr(struct tipc_uaddr *ua, u32 atype, u32 scope, + u32 type, u32 lower, u32 upper) +{ + ua->family = AF_TIPC; + ua->addrtype = atype; + ua->scope = scope; + ua->sr.type = type; + ua->sr.lower = lower; + ua->sr.upper = upper; +} + +static inline bool tipc_uaddr_valid(struct tipc_uaddr *ua, int len) +{ + u32 atype; + + if (len < sizeof(struct sockaddr_tipc)) + return false; + atype = ua->addrtype; + if (ua->family != AF_TIPC) + return false; + if (atype == TIPC_SERVICE_ADDR || atype == TIPC_SOCKET_ADDR) + return true; + if (atype == TIPC_SERVICE_RANGE) + return ua->sr.upper >= ua->sr.lower; + return false; +} + +static inline u32 tipc_own_addr(struct net *net) +{ + return tipc_net(net)->node_addr; +} + +static inline u8 *tipc_own_id(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + + if (!strlen(tn->node_id_string)) + return NULL; + return tn->node_id; +} + +static inline char *tipc_own_id_string(struct net *net) +{ + return tipc_net(net)->node_id_string; +} + +static inline u32 tipc_cluster_mask(u32 addr) +{ + return addr & TIPC_ZONE_CLUSTER_MASK; +} + +static inline int tipc_node2scope(u32 node) +{ + return node ? TIPC_NODE_SCOPE : TIPC_CLUSTER_SCOPE; +} + +static inline int tipc_scope2node(struct net *net, int sc) +{ + return sc != TIPC_NODE_SCOPE ? 0 : tipc_own_addr(net); +} + +static inline int in_own_node(struct net *net, u32 addr) +{ + return addr == tipc_own_addr(net) || !addr; +} + +bool tipc_in_scope(bool legacy_format, u32 domain, u32 addr); +void tipc_set_node_id(struct net *net, u8 *id); +void tipc_set_node_addr(struct net *net, u32 addr); +char *tipc_nodeid2string(char *str, u8 *id); +u32 tipc_node_id2hash(u8 *id128); + +#endif diff --git a/net/tipc/bcast.c b/net/tipc/bcast.c new file mode 100644 index 000000000..593846d25 --- /dev/null +++ b/net/tipc/bcast.c @@ -0,0 +1,864 @@ +/* + * net/tipc/bcast.c: TIPC broadcast code + * + * Copyright (c) 2004-2006, 2014-2017, Ericsson AB + * Copyright (c) 2004, Intel Corporation. + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include "socket.h" +#include "msg.h" +#include "bcast.h" +#include "link.h" +#include "name_table.h" + +#define BCLINK_WIN_DEFAULT 50 /* bcast link window size (default) */ +#define BCLINK_WIN_MIN 32 /* bcast minimum link window size */ + +const char tipc_bclink_name[] = "broadcast-link"; +unsigned long sysctl_tipc_bc_retruni __read_mostly; + +/** + * struct tipc_bc_base - base structure for keeping broadcast send state + * @link: broadcast send link structure + * @inputq: data input queue; will only carry SOCK_WAKEUP messages + * @dests: array keeping number of reachable destinations per bearer + * @primary_bearer: a bearer having links to all broadcast destinations, if any + * @bcast_support: indicates if primary bearer, if any, supports broadcast + * @force_bcast: forces broadcast for multicast traffic + * @rcast_support: indicates if all peer nodes support replicast + * @force_rcast: forces replicast for multicast traffic + * @rc_ratio: dest count as percentage of cluster size where send method changes + * @bc_threshold: calculated from rc_ratio; if dests > threshold use broadcast + */ +struct tipc_bc_base { + struct tipc_link *link; + struct sk_buff_head inputq; + int dests[MAX_BEARERS]; + int primary_bearer; + bool bcast_support; + bool force_bcast; + bool rcast_support; + bool force_rcast; + int rc_ratio; + int bc_threshold; +}; + +static struct tipc_bc_base *tipc_bc_base(struct net *net) +{ + return tipc_net(net)->bcbase; +} + +/* tipc_bcast_get_mtu(): -get the MTU currently used by broadcast link + * Note: the MTU is decremented to give room for a tunnel header, in + * case the message needs to be sent as replicast + */ +int tipc_bcast_get_mtu(struct net *net) +{ + return tipc_link_mss(tipc_bc_sndlink(net)); +} + +void tipc_bcast_toggle_rcast(struct net *net, bool supp) +{ + tipc_bc_base(net)->rcast_support = supp; +} + +static void tipc_bcbase_calc_bc_threshold(struct net *net) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + int cluster_size = tipc_link_bc_peers(tipc_bc_sndlink(net)); + + bb->bc_threshold = 1 + (cluster_size * bb->rc_ratio / 100); +} + +/* tipc_bcbase_select_primary(): 
find a bearer with links to all destinations, + * if any, and make it primary bearer + */ +static void tipc_bcbase_select_primary(struct net *net) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + int all_dests = tipc_link_bc_peers(bb->link); + int max_win = tipc_link_max_win(bb->link); + int min_win = tipc_link_min_win(bb->link); + int i, mtu, prim; + + bb->primary_bearer = INVALID_BEARER_ID; + bb->bcast_support = true; + + if (!all_dests) + return; + + for (i = 0; i < MAX_BEARERS; i++) { + if (!bb->dests[i]) + continue; + + mtu = tipc_bearer_mtu(net, i); + if (mtu < tipc_link_mtu(bb->link)) { + tipc_link_set_mtu(bb->link, mtu); + tipc_link_set_queue_limits(bb->link, + min_win, + max_win); + } + bb->bcast_support &= tipc_bearer_bcast_support(net, i); + if (bb->dests[i] < all_dests) + continue; + + bb->primary_bearer = i; + + /* Reduce risk that all nodes select same primary */ + if ((i ^ tipc_own_addr(net)) & 1) + break; + } + prim = bb->primary_bearer; + if (prim != INVALID_BEARER_ID) + bb->bcast_support = tipc_bearer_bcast_support(net, prim); +} + +void tipc_bcast_inc_bearer_dst_cnt(struct net *net, int bearer_id) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + tipc_bcast_lock(net); + bb->dests[bearer_id]++; + tipc_bcbase_select_primary(net); + tipc_bcast_unlock(net); +} + +void tipc_bcast_dec_bearer_dst_cnt(struct net *net, int bearer_id) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + tipc_bcast_lock(net); + bb->dests[bearer_id]--; + tipc_bcbase_select_primary(net); + tipc_bcast_unlock(net); +} + +/* tipc_bcbase_xmit - broadcast a packet queue across one or more bearers + * + * Note that number of reachable destinations, as indicated in the dests[] + * array, may transitionally differ from the number of destinations indicated + * in each sent buffer. We can sustain this. Excess destination nodes will + * drop and never acknowledge the unexpected packets, and missing destinations + * will either require retransmission (if they are just about to be added to + * the bearer), or be removed from the buffer's 'ackers' counter (if they + * just went down) + */ +static void tipc_bcbase_xmit(struct net *net, struct sk_buff_head *xmitq) +{ + int bearer_id; + struct tipc_bc_base *bb = tipc_bc_base(net); + struct sk_buff *skb, *_skb; + struct sk_buff_head _xmitq; + + if (skb_queue_empty(xmitq)) + return; + + /* The typical case: at least one bearer has links to all nodes */ + bearer_id = bb->primary_bearer; + if (bearer_id >= 0) { + tipc_bearer_bc_xmit(net, bearer_id, xmitq); + return; + } + + /* We have to transmit across all bearers */ + __skb_queue_head_init(&_xmitq); + for (bearer_id = 0; bearer_id < MAX_BEARERS; bearer_id++) { + if (!bb->dests[bearer_id]) + continue; + + skb_queue_walk(xmitq, skb) { + _skb = pskb_copy_for_clone(skb, GFP_ATOMIC); + if (!_skb) + break; + __skb_queue_tail(&_xmitq, _skb); + } + tipc_bearer_bc_xmit(net, bearer_id, &_xmitq); + } + __skb_queue_purge(xmitq); + __skb_queue_purge(&_xmitq); +} + +static void tipc_bcast_select_xmit_method(struct net *net, int dests, + struct tipc_mc_method *method) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + unsigned long exp = method->expires; + + /* Broadcast supported by used bearer/bearers? */ + if (!bb->bcast_support) { + method->rcast = true; + return; + } + /* Any destinations which don't support replicast ? */ + if (!bb->rcast_support) { + method->rcast = false; + return; + } + /* Can current method be changed ? 
*/ + method->expires = jiffies + TIPC_METHOD_EXPIRE; + if (method->mandatory) + return; + + if (!(tipc_net(net)->capabilities & TIPC_MCAST_RBCTL) && + time_before(jiffies, exp)) + return; + + /* Configuration as force 'broadcast' method */ + if (bb->force_bcast) { + method->rcast = false; + return; + } + /* Configuration as force 'replicast' method */ + if (bb->force_rcast) { + method->rcast = true; + return; + } + /* Configuration as 'autoselect' or default method */ + /* Determine method to use now */ + method->rcast = dests <= bb->bc_threshold; +} + +/* tipc_bcast_xmit - broadcast the buffer chain to all external nodes + * @net: the applicable net namespace + * @pkts: chain of buffers containing message + * @cong_link_cnt: set to 1 if broadcast link is congested, otherwise 0 + * Consumes the buffer chain. + * Returns 0 if success, otherwise errno: -EHOSTUNREACH,-EMSGSIZE + */ +int tipc_bcast_xmit(struct net *net, struct sk_buff_head *pkts, + u16 *cong_link_cnt) +{ + struct tipc_link *l = tipc_bc_sndlink(net); + struct sk_buff_head xmitq; + int rc = 0; + + __skb_queue_head_init(&xmitq); + tipc_bcast_lock(net); + if (tipc_link_bc_peers(l)) + rc = tipc_link_xmit(l, pkts, &xmitq); + tipc_bcast_unlock(net); + tipc_bcbase_xmit(net, &xmitq); + __skb_queue_purge(pkts); + if (rc == -ELINKCONG) { + *cong_link_cnt = 1; + rc = 0; + } + return rc; +} + +/* tipc_rcast_xmit - replicate and send a message to given destination nodes + * @net: the applicable net namespace + * @pkts: chain of buffers containing message + * @dests: list of destination nodes + * @cong_link_cnt: returns number of congested links + * @cong_links: returns identities of congested links + * Returns 0 if success, otherwise errno + */ +static int tipc_rcast_xmit(struct net *net, struct sk_buff_head *pkts, + struct tipc_nlist *dests, u16 *cong_link_cnt) +{ + struct tipc_dest *dst, *tmp; + struct sk_buff_head _pkts; + u32 dnode, selector; + + selector = msg_link_selector(buf_msg(skb_peek(pkts))); + __skb_queue_head_init(&_pkts); + + list_for_each_entry_safe(dst, tmp, &dests->list, list) { + dnode = dst->node; + if (!tipc_msg_pskb_copy(dnode, pkts, &_pkts)) + return -ENOMEM; + + /* Any other return value than -ELINKCONG is ignored */ + if (tipc_node_xmit(net, &_pkts, dnode, selector) == -ELINKCONG) + (*cong_link_cnt)++; + } + return 0; +} + +/* tipc_mcast_send_sync - deliver a dummy message with SYN bit + * @net: the applicable net namespace + * @skb: socket buffer to copy + * @method: send method to be used + * @dests: destination nodes for message. + * Returns 0 if success, otherwise errno + */ +static int tipc_mcast_send_sync(struct net *net, struct sk_buff *skb, + struct tipc_mc_method *method, + struct tipc_nlist *dests) +{ + struct tipc_msg *hdr, *_hdr; + struct sk_buff_head tmpq; + struct sk_buff *_skb; + u16 cong_link_cnt; + int rc = 0; + + /* Is a cluster supporting with new capabilities ? 
*/ + if (!(tipc_net(net)->capabilities & TIPC_MCAST_RBCTL)) + return 0; + + hdr = buf_msg(skb); + if (msg_user(hdr) == MSG_FRAGMENTER) + hdr = msg_inner_hdr(hdr); + if (msg_type(hdr) != TIPC_MCAST_MSG) + return 0; + + /* Allocate dummy message */ + _skb = tipc_buf_acquire(MCAST_H_SIZE, GFP_KERNEL); + if (!_skb) + return -ENOMEM; + + /* Preparing for 'synching' header */ + msg_set_syn(hdr, 1); + + /* Copy skb's header into a dummy header */ + skb_copy_to_linear_data(_skb, hdr, MCAST_H_SIZE); + skb_orphan(_skb); + + /* Reverse method for dummy message */ + _hdr = buf_msg(_skb); + msg_set_size(_hdr, MCAST_H_SIZE); + msg_set_is_rcast(_hdr, !msg_is_rcast(hdr)); + msg_set_errcode(_hdr, TIPC_ERR_NO_PORT); + + __skb_queue_head_init(&tmpq); + __skb_queue_tail(&tmpq, _skb); + if (method->rcast) + rc = tipc_bcast_xmit(net, &tmpq, &cong_link_cnt); + else + rc = tipc_rcast_xmit(net, &tmpq, dests, &cong_link_cnt); + + /* This queue should normally be empty by now */ + __skb_queue_purge(&tmpq); + + return rc; +} + +/* tipc_mcast_xmit - deliver message to indicated destination nodes + * and to identified node local sockets + * @net: the applicable net namespace + * @pkts: chain of buffers containing message + * @method: send method to be used + * @dests: destination nodes for message. + * @cong_link_cnt: returns number of encountered congested destination links + * Consumes buffer chain. + * Returns 0 if success, otherwise errno + */ +int tipc_mcast_xmit(struct net *net, struct sk_buff_head *pkts, + struct tipc_mc_method *method, struct tipc_nlist *dests, + u16 *cong_link_cnt) +{ + struct sk_buff_head inputq, localq; + bool rcast = method->rcast; + struct tipc_msg *hdr; + struct sk_buff *skb; + int rc = 0; + + skb_queue_head_init(&inputq); + __skb_queue_head_init(&localq); + + /* Clone packets before they are consumed by next call */ + if (dests->local && !tipc_msg_reassemble(pkts, &localq)) { + rc = -ENOMEM; + goto exit; + } + /* Send according to determined transmit method */ + if (dests->remote) { + tipc_bcast_select_xmit_method(net, dests->remote, method); + + skb = skb_peek(pkts); + hdr = buf_msg(skb); + if (msg_user(hdr) == MSG_FRAGMENTER) + hdr = msg_inner_hdr(hdr); + msg_set_is_rcast(hdr, method->rcast); + + /* Switch method ? */ + if (rcast != method->rcast) { + rc = tipc_mcast_send_sync(net, skb, method, dests); + if (unlikely(rc)) { + pr_err("Unable to send SYN: method %d, rc %d\n", + rcast, rc); + goto exit; + } + } + + if (method->rcast) + rc = tipc_rcast_xmit(net, pkts, dests, cong_link_cnt); + else + rc = tipc_bcast_xmit(net, pkts, cong_link_cnt); + } + + if (dests->local) { + tipc_loopback_trace(net, &localq); + tipc_sk_mcast_rcv(net, &localq, &inputq); + } +exit: + /* This queue should normally be empty by now */ + __skb_queue_purge(pkts); + return rc; +} + +/* tipc_bcast_rcv - receive a broadcast packet, and deliver to rcv link + * + * RCU is locked, no other locks set + */ +int tipc_bcast_rcv(struct net *net, struct tipc_link *l, struct sk_buff *skb) +{ + struct tipc_msg *hdr = buf_msg(skb); + struct sk_buff_head *inputq = &tipc_bc_base(net)->inputq; + struct sk_buff_head xmitq; + int rc; + + __skb_queue_head_init(&xmitq); + + if (msg_mc_netid(hdr) != tipc_netid(net) || !tipc_link_is_up(l)) { + kfree_skb(skb); + return 0; + } + + tipc_bcast_lock(net); + if (msg_user(hdr) == BCAST_PROTOCOL) + rc = tipc_link_bc_nack_rcv(l, skb, &xmitq); + else + rc = tipc_link_rcv(l, skb, NULL); + tipc_bcast_unlock(net); + + tipc_bcbase_xmit(net, &xmitq); + + /* Any socket wakeup messages ? 
*/ + if (!skb_queue_empty(inputq)) + tipc_sk_rcv(net, inputq); + + return rc; +} + +/* tipc_bcast_ack_rcv - receive and handle a broadcast acknowledge + * + * RCU is locked, no other locks set + */ +void tipc_bcast_ack_rcv(struct net *net, struct tipc_link *l, + struct tipc_msg *hdr) +{ + struct sk_buff_head *inputq = &tipc_bc_base(net)->inputq; + u16 acked = msg_bcast_ack(hdr); + struct sk_buff_head xmitq; + + /* Ignore bc acks sent by peer before bcast synch point was received */ + if (msg_bc_ack_invalid(hdr)) + return; + + __skb_queue_head_init(&xmitq); + + tipc_bcast_lock(net); + tipc_link_bc_ack_rcv(l, acked, 0, NULL, &xmitq, NULL); + tipc_bcast_unlock(net); + + tipc_bcbase_xmit(net, &xmitq); + + /* Any socket wakeup messages ? */ + if (!skb_queue_empty(inputq)) + tipc_sk_rcv(net, inputq); +} + +/* tipc_bcast_synch_rcv - check and update rcv link with peer's send state + * + * RCU is locked, no other locks set + */ +int tipc_bcast_sync_rcv(struct net *net, struct tipc_link *l, + struct tipc_msg *hdr, + struct sk_buff_head *retrq) +{ + struct sk_buff_head *inputq = &tipc_bc_base(net)->inputq; + struct tipc_gap_ack_blks *ga; + struct sk_buff_head xmitq; + int rc = 0; + + __skb_queue_head_init(&xmitq); + + tipc_bcast_lock(net); + if (msg_type(hdr) != STATE_MSG) { + tipc_link_bc_init_rcv(l, hdr); + } else if (!msg_bc_ack_invalid(hdr)) { + tipc_get_gap_ack_blks(&ga, l, hdr, false); + if (!sysctl_tipc_bc_retruni) + retrq = &xmitq; + rc = tipc_link_bc_ack_rcv(l, msg_bcast_ack(hdr), + msg_bc_gap(hdr), ga, &xmitq, + retrq); + rc |= tipc_link_bc_sync_rcv(l, hdr, &xmitq); + } + tipc_bcast_unlock(net); + + tipc_bcbase_xmit(net, &xmitq); + + /* Any socket wakeup messages ? */ + if (!skb_queue_empty(inputq)) + tipc_sk_rcv(net, inputq); + return rc; +} + +/* tipc_bcast_add_peer - add a peer node to broadcast link and bearer + * + * RCU is locked, node lock is set + */ +void tipc_bcast_add_peer(struct net *net, struct tipc_link *uc_l, + struct sk_buff_head *xmitq) +{ + struct tipc_link *snd_l = tipc_bc_sndlink(net); + + tipc_bcast_lock(net); + tipc_link_add_bc_peer(snd_l, uc_l, xmitq); + tipc_bcbase_select_primary(net); + tipc_bcbase_calc_bc_threshold(net); + tipc_bcast_unlock(net); +} + +/* tipc_bcast_remove_peer - remove a peer node from broadcast link and bearer + * + * RCU is locked, node lock is set + */ +void tipc_bcast_remove_peer(struct net *net, struct tipc_link *rcv_l) +{ + struct tipc_link *snd_l = tipc_bc_sndlink(net); + struct sk_buff_head *inputq = &tipc_bc_base(net)->inputq; + struct sk_buff_head xmitq; + + __skb_queue_head_init(&xmitq); + + tipc_bcast_lock(net); + tipc_link_remove_bc_peer(snd_l, rcv_l, &xmitq); + tipc_bcbase_select_primary(net); + tipc_bcbase_calc_bc_threshold(net); + tipc_bcast_unlock(net); + + tipc_bcbase_xmit(net, &xmitq); + + /* Any socket wakeup messages ? 
*/ + if (!skb_queue_empty(inputq)) + tipc_sk_rcv(net, inputq); +} + +int tipc_bclink_reset_stats(struct net *net, struct tipc_link *l) +{ + if (!l) + return -ENOPROTOOPT; + + tipc_bcast_lock(net); + tipc_link_reset_stats(l); + tipc_bcast_unlock(net); + return 0; +} + +static int tipc_bc_link_set_queue_limits(struct net *net, u32 max_win) +{ + struct tipc_link *l = tipc_bc_sndlink(net); + + if (!l) + return -ENOPROTOOPT; + if (max_win < BCLINK_WIN_MIN) + max_win = BCLINK_WIN_MIN; + if (max_win > TIPC_MAX_LINK_WIN) + return -EINVAL; + tipc_bcast_lock(net); + tipc_link_set_queue_limits(l, tipc_link_min_win(l), max_win); + tipc_bcast_unlock(net); + return 0; +} + +static int tipc_bc_link_set_broadcast_mode(struct net *net, u32 bc_mode) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + switch (bc_mode) { + case BCLINK_MODE_BCAST: + if (!bb->bcast_support) + return -ENOPROTOOPT; + + bb->force_bcast = true; + bb->force_rcast = false; + break; + case BCLINK_MODE_RCAST: + if (!bb->rcast_support) + return -ENOPROTOOPT; + + bb->force_bcast = false; + bb->force_rcast = true; + break; + case BCLINK_MODE_SEL: + if (!bb->bcast_support || !bb->rcast_support) + return -ENOPROTOOPT; + + bb->force_bcast = false; + bb->force_rcast = false; + break; + default: + return -EINVAL; + } + + return 0; +} + +static int tipc_bc_link_set_broadcast_ratio(struct net *net, u32 bc_ratio) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + if (!bb->bcast_support || !bb->rcast_support) + return -ENOPROTOOPT; + + if (bc_ratio > 100 || bc_ratio <= 0) + return -EINVAL; + + bb->rc_ratio = bc_ratio; + tipc_bcast_lock(net); + tipc_bcbase_calc_bc_threshold(net); + tipc_bcast_unlock(net); + + return 0; +} + +int tipc_nl_bc_link_set(struct net *net, struct nlattr *attrs[]) +{ + int err; + u32 win; + u32 bc_mode; + u32 bc_ratio; + struct nlattr *props[TIPC_NLA_PROP_MAX + 1]; + + if (!attrs[TIPC_NLA_LINK_PROP]) + return -EINVAL; + + err = tipc_nl_parse_link_prop(attrs[TIPC_NLA_LINK_PROP], props); + if (err) + return err; + + if (!props[TIPC_NLA_PROP_WIN] && + !props[TIPC_NLA_PROP_BROADCAST] && + !props[TIPC_NLA_PROP_BROADCAST_RATIO]) { + return -EOPNOTSUPP; + } + + if (props[TIPC_NLA_PROP_BROADCAST]) { + bc_mode = nla_get_u32(props[TIPC_NLA_PROP_BROADCAST]); + err = tipc_bc_link_set_broadcast_mode(net, bc_mode); + } + + if (!err && props[TIPC_NLA_PROP_BROADCAST_RATIO]) { + bc_ratio = nla_get_u32(props[TIPC_NLA_PROP_BROADCAST_RATIO]); + err = tipc_bc_link_set_broadcast_ratio(net, bc_ratio); + } + + if (!err && props[TIPC_NLA_PROP_WIN]) { + win = nla_get_u32(props[TIPC_NLA_PROP_WIN]); + err = tipc_bc_link_set_queue_limits(net, win); + } + + return err; +} + +int tipc_bcast_init(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_bc_base *bb = NULL; + struct tipc_link *l = NULL; + + bb = kzalloc(sizeof(*bb), GFP_KERNEL); + if (!bb) + goto enomem; + tn->bcbase = bb; + spin_lock_init(&tipc_net(net)->bclock); + + if (!tipc_link_bc_create(net, 0, 0, NULL, + one_page_mtu, + BCLINK_WIN_DEFAULT, + BCLINK_WIN_DEFAULT, + 0, + &bb->inputq, + NULL, + NULL, + &l)) + goto enomem; + bb->link = l; + tn->bcl = l; + bb->rc_ratio = 10; + bb->rcast_support = true; + return 0; +enomem: + kfree(bb); + kfree(l); + return -ENOMEM; +} + +void tipc_bcast_stop(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + + synchronize_net(); + kfree(tn->bcbase); + kfree(tn->bcl); +} + +void tipc_nlist_init(struct tipc_nlist *nl, u32 self) +{ + memset(nl, 0, sizeof(*nl)); + INIT_LIST_HEAD(&nl->list); + nl->self = self; +} + 
+void tipc_nlist_add(struct tipc_nlist *nl, u32 node) +{ + if (node == nl->self) + nl->local = true; + else if (tipc_dest_push(&nl->list, node, 0)) + nl->remote++; +} + +void tipc_nlist_del(struct tipc_nlist *nl, u32 node) +{ + if (node == nl->self) + nl->local = false; + else if (tipc_dest_del(&nl->list, node, 0)) + nl->remote--; +} + +void tipc_nlist_purge(struct tipc_nlist *nl) +{ + tipc_dest_list_purge(&nl->list); + nl->remote = 0; + nl->local = false; +} + +u32 tipc_bcast_get_mode(struct net *net) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + if (bb->force_bcast) + return BCLINK_MODE_BCAST; + + if (bb->force_rcast) + return BCLINK_MODE_RCAST; + + if (bb->bcast_support && bb->rcast_support) + return BCLINK_MODE_SEL; + + return 0; +} + +u32 tipc_bcast_get_broadcast_ratio(struct net *net) +{ + struct tipc_bc_base *bb = tipc_bc_base(net); + + return bb->rc_ratio; +} + +void tipc_mcast_filter_msg(struct net *net, struct sk_buff_head *defq, + struct sk_buff_head *inputq) +{ + struct sk_buff *skb, *_skb, *tmp; + struct tipc_msg *hdr, *_hdr; + bool match = false; + u32 node, port; + + skb = skb_peek(inputq); + if (!skb) + return; + + hdr = buf_msg(skb); + + if (likely(!msg_is_syn(hdr) && skb_queue_empty(defq))) + return; + + node = msg_orignode(hdr); + if (node == tipc_own_addr(net)) + return; + + port = msg_origport(hdr); + + /* Has the twin SYN message already arrived ? */ + skb_queue_walk(defq, _skb) { + _hdr = buf_msg(_skb); + if (msg_orignode(_hdr) != node) + continue; + if (msg_origport(_hdr) != port) + continue; + match = true; + break; + } + + if (!match) { + if (!msg_is_syn(hdr)) + return; + __skb_dequeue(inputq); + __skb_queue_tail(defq, skb); + return; + } + + /* Deliver non-SYN message from other link, otherwise queue it */ + if (!msg_is_syn(hdr)) { + if (msg_is_rcast(hdr) != msg_is_rcast(_hdr)) + return; + __skb_dequeue(inputq); + __skb_queue_tail(defq, skb); + return; + } + + /* Queue non-SYN/SYN message from same link */ + if (msg_is_rcast(hdr) == msg_is_rcast(_hdr)) { + __skb_dequeue(inputq); + __skb_queue_tail(defq, skb); + return; + } + + /* Matching SYN messages => return the one with data, if any */ + __skb_unlink(_skb, defq); + if (msg_data_sz(hdr)) { + kfree_skb(_skb); + } else { + __skb_dequeue(inputq); + kfree_skb(skb); + __skb_queue_tail(inputq, _skb); + } + + /* Deliver subsequent non-SYN messages from same peer */ + skb_queue_walk_safe(defq, _skb, tmp) { + _hdr = buf_msg(_skb); + if (msg_orignode(_hdr) != node) + continue; + if (msg_origport(_hdr) != port) + continue; + if (msg_is_syn(_hdr)) + break; + __skb_unlink(_skb, defq); + __skb_queue_tail(inputq, _skb); + } +} diff --git a/net/tipc/bcast.h b/net/tipc/bcast.h new file mode 100644 index 000000000..2d9352dc7 --- /dev/null +++ b/net/tipc/bcast.h @@ -0,0 +1,127 @@ +/* + * net/tipc/bcast.h: Include file for TIPC broadcast code + * + * Copyright (c) 2003-2006, 2014-2015, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. 
Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_BCAST_H +#define _TIPC_BCAST_H + +#include "core.h" + +struct tipc_node; +struct tipc_msg; +struct tipc_nl_msg; +struct tipc_nlist; +struct tipc_nitem; +extern const char tipc_bclink_name[]; +extern unsigned long sysctl_tipc_bc_retruni; + +#define TIPC_METHOD_EXPIRE msecs_to_jiffies(5000) + +#define BCLINK_MODE_BCAST 0x1 +#define BCLINK_MODE_RCAST 0x2 +#define BCLINK_MODE_SEL 0x4 + +struct tipc_nlist { + struct list_head list; + u32 self; + u16 remote; + bool local; +}; + +void tipc_nlist_init(struct tipc_nlist *nl, u32 self); +void tipc_nlist_purge(struct tipc_nlist *nl); +void tipc_nlist_add(struct tipc_nlist *nl, u32 node); +void tipc_nlist_del(struct tipc_nlist *nl, u32 node); + +/* Cookie to be used between socket and broadcast layer + * @rcast: replicast (instead of broadcast) was used at previous xmit + * @mandatory: broadcast/replicast indication was set by user + * @deferredq: defer queue to make message in order + * @expires: re-evaluate non-mandatory transmit method if we are past this + */ +struct tipc_mc_method { + bool rcast; + bool mandatory; + struct sk_buff_head deferredq; + unsigned long expires; +}; + +int tipc_bcast_init(struct net *net); +void tipc_bcast_stop(struct net *net); +void tipc_bcast_add_peer(struct net *net, struct tipc_link *l, + struct sk_buff_head *xmitq); +void tipc_bcast_remove_peer(struct net *net, struct tipc_link *rcv_bcl); +void tipc_bcast_inc_bearer_dst_cnt(struct net *net, int bearer_id); +void tipc_bcast_dec_bearer_dst_cnt(struct net *net, int bearer_id); +int tipc_bcast_get_mtu(struct net *net); +void tipc_bcast_toggle_rcast(struct net *net, bool supp); +int tipc_mcast_xmit(struct net *net, struct sk_buff_head *pkts, + struct tipc_mc_method *method, struct tipc_nlist *dests, + u16 *cong_link_cnt); +int tipc_bcast_xmit(struct net *net, struct sk_buff_head *pkts, + u16 *cong_link_cnt); +int tipc_bcast_rcv(struct net *net, struct tipc_link *l, struct sk_buff *skb); +void tipc_bcast_ack_rcv(struct net *net, struct tipc_link *l, + struct tipc_msg *hdr); +int tipc_bcast_sync_rcv(struct net *net, struct tipc_link *l, + struct tipc_msg *hdr, + struct sk_buff_head *retrq); +int tipc_nl_add_bc_link(struct net *net, struct tipc_nl_msg *msg, + struct tipc_link *bcl); +int tipc_nl_bc_link_set(struct net *net, struct nlattr *attrs[]); +int tipc_bclink_reset_stats(struct net *net, 
struct tipc_link *l); + +u32 tipc_bcast_get_mode(struct net *net); +u32 tipc_bcast_get_broadcast_ratio(struct net *net); + +void tipc_mcast_filter_msg(struct net *net, struct sk_buff_head *defq, + struct sk_buff_head *inputq); + +static inline void tipc_bcast_lock(struct net *net) +{ + spin_lock_bh(&tipc_net(net)->bclock); +} + +static inline void tipc_bcast_unlock(struct net *net) +{ + spin_unlock_bh(&tipc_net(net)->bclock); +} + +static inline struct tipc_link *tipc_bc_sndlink(struct net *net) +{ + return tipc_net(net)->bcl; +} + +#endif diff --git a/net/tipc/bearer.c b/net/tipc/bearer.c new file mode 100644 index 000000000..cdcd27318 --- /dev/null +++ b/net/tipc/bearer.c @@ -0,0 +1,1372 @@ +/* + * net/tipc/bearer.c: TIPC bearer code + * + * Copyright (c) 1996-2006, 2013-2016, Ericsson AB + * Copyright (c) 2004-2006, 2010-2013, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include "core.h" +#include "bearer.h" +#include "link.h" +#include "discover.h" +#include "monitor.h" +#include "bcast.h" +#include "netlink.h" +#include "udp_media.h" +#include "trace.h" +#include "crypto.h" + +#define MAX_ADDR_STR 60 + +static struct tipc_media * const media_info_array[] = { + ð_media_info, +#ifdef CONFIG_TIPC_MEDIA_IB + &ib_media_info, +#endif +#ifdef CONFIG_TIPC_MEDIA_UDP + &udp_media_info, +#endif + NULL +}; + +static struct tipc_bearer *bearer_get(struct net *net, int bearer_id) +{ + struct tipc_net *tn = tipc_net(net); + + return rcu_dereference(tn->bearer_list[bearer_id]); +} + +static void bearer_disable(struct net *net, struct tipc_bearer *b); +static int tipc_l2_rcv_msg(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev); + +/** + * tipc_media_find - locates specified media object by name + * @name: name to locate + */ +struct tipc_media *tipc_media_find(const char *name) +{ + u32 i; + + for (i = 0; media_info_array[i] != NULL; i++) { + if (!strcmp(media_info_array[i]->name, name)) + break; + } + return media_info_array[i]; +} + +/** + * media_find_id - locates specified media object by type identifier + * @type: type identifier to locate + */ +static struct tipc_media *media_find_id(u8 type) +{ + u32 i; + + for (i = 0; media_info_array[i] != NULL; i++) { + if (media_info_array[i]->type_id == type) + break; + } + return media_info_array[i]; +} + +/** + * tipc_media_addr_printf - record media address in print buffer + * @buf: output buffer + * @len: output buffer size remaining + * @a: input media address + */ +int tipc_media_addr_printf(char *buf, int len, struct tipc_media_addr *a) +{ + char addr_str[MAX_ADDR_STR]; + struct tipc_media *m; + int ret; + + m = media_find_id(a->media_id); + + if (m && !m->addr2str(a, addr_str, sizeof(addr_str))) + ret = scnprintf(buf, len, "%s(%s)", m->name, addr_str); + else { + u32 i; + + ret = scnprintf(buf, len, "UNKNOWN(%u)", a->media_id); + for (i = 0; i < sizeof(a->value); i++) + ret += scnprintf(buf + ret, len - ret, + "-%x", a->value[i]); + } + return ret; +} + +/** + * bearer_name_validate - validate & (optionally) deconstruct bearer name + * @name: ptr to bearer name string + * @name_parts: ptr to area for bearer name components (or NULL if not needed) + * + * Return: 1 if bearer name is valid, otherwise 0. 
+ */ +static int bearer_name_validate(const char *name, + struct tipc_bearer_names *name_parts) +{ + char name_copy[TIPC_MAX_BEARER_NAME]; + char *media_name; + char *if_name; + u32 media_len; + u32 if_len; + + /* copy bearer name & ensure length is OK */ + if (strscpy(name_copy, name, TIPC_MAX_BEARER_NAME) < 0) + return 0; + + /* ensure all component parts of bearer name are present */ + media_name = name_copy; + if_name = strchr(media_name, ':'); + if (if_name == NULL) + return 0; + *(if_name++) = 0; + media_len = if_name - media_name; + if_len = strlen(if_name) + 1; + + /* validate component parts of bearer name */ + if ((media_len <= 1) || (media_len > TIPC_MAX_MEDIA_NAME) || + (if_len <= 1) || (if_len > TIPC_MAX_IF_NAME)) + return 0; + + /* return bearer name components, if necessary */ + if (name_parts) { + strcpy(name_parts->media_name, media_name); + strcpy(name_parts->if_name, if_name); + } + return 1; +} + +/** + * tipc_bearer_find - locates bearer object with matching bearer name + * @net: the applicable net namespace + * @name: bearer name to locate + */ +struct tipc_bearer *tipc_bearer_find(struct net *net, const char *name) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_bearer *b; + u32 i; + + for (i = 0; i < MAX_BEARERS; i++) { + b = rtnl_dereference(tn->bearer_list[i]); + if (b && (!strcmp(b->name, name))) + return b; + } + return NULL; +} + +/* tipc_bearer_get_name - get the bearer name from its id. + * @net: network namespace + * @name: a pointer to the buffer where the name will be stored. + * @bearer_id: the id to get the name from. + */ +int tipc_bearer_get_name(struct net *net, char *name, u32 bearer_id) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_bearer *b; + + if (bearer_id >= MAX_BEARERS) + return -EINVAL; + + b = rtnl_dereference(tn->bearer_list[bearer_id]); + if (!b) + return -EINVAL; + + strcpy(name, b->name); + return 0; +} + +void tipc_bearer_add_dest(struct net *net, u32 bearer_id, u32 dest) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_bearer *b; + + rcu_read_lock(); + b = rcu_dereference(tn->bearer_list[bearer_id]); + if (b) + tipc_disc_add_dest(b->disc); + rcu_read_unlock(); +} + +void tipc_bearer_remove_dest(struct net *net, u32 bearer_id, u32 dest) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_bearer *b; + + rcu_read_lock(); + b = rcu_dereference(tn->bearer_list[bearer_id]); + if (b) + tipc_disc_remove_dest(b->disc); + rcu_read_unlock(); +} + +/** + * tipc_enable_bearer - enable bearer with the given name + * @net: the applicable net namespace + * @name: bearer name to enable + * @disc_domain: bearer domain + * @prio: bearer priority + * @attr: nlattr array + * @extack: netlink extended ack + */ +static int tipc_enable_bearer(struct net *net, const char *name, + u32 disc_domain, u32 prio, + struct nlattr *attr[], + struct netlink_ext_ack *extack) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_bearer_names b_names; + int with_this_prio = 1; + struct tipc_bearer *b; + struct tipc_media *m; + struct sk_buff *skb; + int bearer_id = 0; + int res = -EINVAL; + char *errstr = ""; + u32 i; + + if (!bearer_name_validate(name, &b_names)) { + NL_SET_ERR_MSG(extack, "Illegal name"); + return res; + } + + if (prio > TIPC_MAX_LINK_PRI && prio != TIPC_MEDIA_LINK_PRI) { + errstr = "illegal priority"; + NL_SET_ERR_MSG(extack, "Illegal priority"); + goto rejected; + } + + m = tipc_media_find(b_names.media_name); + if (!m) { + errstr = "media not registered"; + 
NL_SET_ERR_MSG(extack, "Media not registered"); + goto rejected; + } + + if (prio == TIPC_MEDIA_LINK_PRI) + prio = m->priority; + + /* Check new bearer vs existing ones and find free bearer id if any */ + bearer_id = MAX_BEARERS; + i = MAX_BEARERS; + while (i-- != 0) { + b = rtnl_dereference(tn->bearer_list[i]); + if (!b) { + bearer_id = i; + continue; + } + if (!strcmp(name, b->name)) { + errstr = "already enabled"; + NL_SET_ERR_MSG(extack, "Already enabled"); + goto rejected; + } + + if (b->priority == prio && + (++with_this_prio > 2)) { + pr_warn("Bearer <%s>: already 2 bearers with priority %u\n", + name, prio); + + if (prio == TIPC_MIN_LINK_PRI) { + errstr = "cannot adjust to lower"; + NL_SET_ERR_MSG(extack, "Cannot adjust to lower"); + goto rejected; + } + + pr_warn("Bearer <%s>: trying with adjusted priority\n", + name); + prio--; + bearer_id = MAX_BEARERS; + i = MAX_BEARERS; + with_this_prio = 1; + } + } + + if (bearer_id >= MAX_BEARERS) { + errstr = "max 3 bearers permitted"; + NL_SET_ERR_MSG(extack, "Max 3 bearers permitted"); + goto rejected; + } + + b = kzalloc(sizeof(*b), GFP_ATOMIC); + if (!b) + return -ENOMEM; + + strcpy(b->name, name); + b->media = m; + res = m->enable_media(net, b, attr); + if (res) { + kfree(b); + errstr = "failed to enable media"; + NL_SET_ERR_MSG(extack, "Failed to enable media"); + goto rejected; + } + + b->identity = bearer_id; + b->tolerance = m->tolerance; + b->min_win = m->min_win; + b->max_win = m->max_win; + b->domain = disc_domain; + b->net_plane = bearer_id + 'A'; + b->priority = prio; + refcount_set(&b->refcnt, 1); + + res = tipc_disc_create(net, b, &b->bcast_addr, &skb); + if (res) { + bearer_disable(net, b); + errstr = "failed to create discoverer"; + NL_SET_ERR_MSG(extack, "Failed to create discoverer"); + goto rejected; + } + + /* Create monitoring data before accepting activate messages */ + if (tipc_mon_create(net, bearer_id)) { + bearer_disable(net, b); + kfree_skb(skb); + return -ENOMEM; + } + + test_and_set_bit_lock(0, &b->up); + rcu_assign_pointer(tn->bearer_list[bearer_id], b); + if (skb) + tipc_bearer_xmit_skb(net, bearer_id, skb, &b->bcast_addr); + + pr_info("Enabled bearer <%s>, priority %u\n", name, prio); + + return res; +rejected: + pr_warn("Enabling of bearer <%s> rejected, %s\n", name, errstr); + return res; +} + +/** + * tipc_reset_bearer - Reset all links established over this bearer + * @net: the applicable net namespace + * @b: the target bearer + */ +static int tipc_reset_bearer(struct net *net, struct tipc_bearer *b) +{ + pr_info("Resetting bearer <%s>\n", b->name); + tipc_node_delete_links(net, b->identity); + tipc_disc_reset(net, b); + return 0; +} + +bool tipc_bearer_hold(struct tipc_bearer *b) +{ + return (b && refcount_inc_not_zero(&b->refcnt)); +} + +void tipc_bearer_put(struct tipc_bearer *b) +{ + if (b && refcount_dec_and_test(&b->refcnt)) + kfree_rcu(b, rcu); +} + +/** + * bearer_disable - disable this bearer + * @net: the applicable net namespace + * @b: the bearer to disable + * + * Note: This routine assumes caller holds RTNL lock. 
+ */ +static void bearer_disable(struct net *net, struct tipc_bearer *b) +{ + struct tipc_net *tn = tipc_net(net); + int bearer_id = b->identity; + + pr_info("Disabling bearer <%s>\n", b->name); + clear_bit_unlock(0, &b->up); + tipc_node_delete_links(net, bearer_id); + b->media->disable_media(b); + RCU_INIT_POINTER(b->media_ptr, NULL); + if (b->disc) + tipc_disc_delete(b->disc); + RCU_INIT_POINTER(tn->bearer_list[bearer_id], NULL); + tipc_bearer_put(b); + tipc_mon_delete(net, bearer_id); +} + +int tipc_enable_l2_media(struct net *net, struct tipc_bearer *b, + struct nlattr *attr[]) +{ + char *dev_name = strchr((const char *)b->name, ':') + 1; + int hwaddr_len = b->media->hwaddr_len; + u8 node_id[NODE_ID_LEN] = {0,}; + struct net_device *dev; + + /* Find device with specified name */ + dev = dev_get_by_name(net, dev_name); + if (!dev) + return -ENODEV; + if (tipc_mtu_bad(dev, 0)) { + dev_put(dev); + return -EINVAL; + } + if (dev == net->loopback_dev) { + dev_put(dev); + pr_info("Enabling <%s> not permitted\n", b->name); + return -EINVAL; + } + + /* Autoconfigure own node identity if needed */ + if (!tipc_own_id(net) && hwaddr_len <= NODE_ID_LEN) { + memcpy(node_id, dev->dev_addr, hwaddr_len); + tipc_net_init(net, node_id, 0); + } + if (!tipc_own_id(net)) { + dev_put(dev); + pr_warn("Failed to obtain node identity\n"); + return -EINVAL; + } + + /* Associate TIPC bearer with L2 bearer */ + rcu_assign_pointer(b->media_ptr, dev); + b->pt.dev = dev; + b->pt.type = htons(ETH_P_TIPC); + b->pt.func = tipc_l2_rcv_msg; + dev_add_pack(&b->pt); + memset(&b->bcast_addr, 0, sizeof(b->bcast_addr)); + memcpy(b->bcast_addr.value, dev->broadcast, hwaddr_len); + b->bcast_addr.media_id = b->media->type_id; + b->bcast_addr.broadcast = TIPC_BROADCAST_SUPPORT; + b->mtu = dev->mtu; + b->media->raw2addr(b, &b->addr, (const char *)dev->dev_addr); + rcu_assign_pointer(dev->tipc_ptr, b); + return 0; +} + +/* tipc_disable_l2_media - detach TIPC bearer from an L2 interface + * @b: the target bearer + * + * Mark L2 bearer as inactive so that incoming buffers are thrown away + */ +void tipc_disable_l2_media(struct tipc_bearer *b) +{ + struct net_device *dev; + + dev = (struct net_device *)rtnl_dereference(b->media_ptr); + dev_remove_pack(&b->pt); + RCU_INIT_POINTER(dev->tipc_ptr, NULL); + synchronize_net(); + dev_put(dev); +} + +/** + * tipc_l2_send_msg - send a TIPC packet out over an L2 interface + * @net: the associated network namespace + * @skb: the packet to be sent + * @b: the bearer through which the packet is to be sent + * @dest: peer destination address + */ +int tipc_l2_send_msg(struct net *net, struct sk_buff *skb, + struct tipc_bearer *b, struct tipc_media_addr *dest) +{ + struct net_device *dev; + int delta; + + dev = (struct net_device *)rcu_dereference(b->media_ptr); + if (!dev) + return 0; + + delta = SKB_DATA_ALIGN(dev->hard_header_len - skb_headroom(skb)); + if ((delta > 0) && pskb_expand_head(skb, delta, 0, GFP_ATOMIC)) { + kfree_skb(skb); + return 0; + } + skb_reset_network_header(skb); + skb->dev = dev; + skb->protocol = htons(ETH_P_TIPC); + dev_hard_header(skb, dev, ETH_P_TIPC, dest->value, + dev->dev_addr, skb->len); + dev_queue_xmit(skb); + return 0; +} + +bool tipc_bearer_bcast_support(struct net *net, u32 bearer_id) +{ + bool supp = false; + struct tipc_bearer *b; + + rcu_read_lock(); + b = bearer_get(net, bearer_id); + if (b) + supp = (b->bcast_addr.broadcast == TIPC_BROADCAST_SUPPORT); + rcu_read_unlock(); + return supp; +} + +int tipc_bearer_mtu(struct net *net, u32 bearer_id) +{ + int mtu = 
0; + struct tipc_bearer *b; + + rcu_read_lock(); + b = rcu_dereference(tipc_net(net)->bearer_list[bearer_id]); + if (b) + mtu = b->mtu; + rcu_read_unlock(); + return mtu; +} + +int tipc_bearer_min_mtu(struct net *net, u32 bearer_id) +{ + int mtu = TIPC_MIN_BEARER_MTU; + struct tipc_bearer *b; + + rcu_read_lock(); + b = bearer_get(net, bearer_id); + if (b) + mtu += b->encap_hlen; + rcu_read_unlock(); + return mtu; +} + +/* tipc_bearer_xmit_skb - sends buffer to destination over bearer + */ +void tipc_bearer_xmit_skb(struct net *net, u32 bearer_id, + struct sk_buff *skb, + struct tipc_media_addr *dest) +{ + struct tipc_msg *hdr = buf_msg(skb); + struct tipc_bearer *b; + + rcu_read_lock(); + b = bearer_get(net, bearer_id); + if (likely(b && (test_bit(0, &b->up) || msg_is_reset(hdr)))) { +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_xmit(net, &skb, b, dest, NULL); + if (skb) +#endif + b->media->send_msg(net, skb, b, dest); + } else { + kfree_skb(skb); + } + rcu_read_unlock(); +} + +/* tipc_bearer_xmit() -send buffer to destination over bearer + */ +void tipc_bearer_xmit(struct net *net, u32 bearer_id, + struct sk_buff_head *xmitq, + struct tipc_media_addr *dst, + struct tipc_node *__dnode) +{ + struct tipc_bearer *b; + struct sk_buff *skb, *tmp; + + if (skb_queue_empty(xmitq)) + return; + + rcu_read_lock(); + b = bearer_get(net, bearer_id); + if (unlikely(!b)) + __skb_queue_purge(xmitq); + skb_queue_walk_safe(xmitq, skb, tmp) { + __skb_dequeue(xmitq); + if (likely(test_bit(0, &b->up) || msg_is_reset(buf_msg(skb)))) { +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_xmit(net, &skb, b, dst, __dnode); + if (skb) +#endif + b->media->send_msg(net, skb, b, dst); + } else { + kfree_skb(skb); + } + } + rcu_read_unlock(); +} + +/* tipc_bearer_bc_xmit() - broadcast buffers to all destinations + */ +void tipc_bearer_bc_xmit(struct net *net, u32 bearer_id, + struct sk_buff_head *xmitq) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_media_addr *dst; + int net_id = tn->net_id; + struct tipc_bearer *b; + struct sk_buff *skb, *tmp; + struct tipc_msg *hdr; + + rcu_read_lock(); + b = bearer_get(net, bearer_id); + if (unlikely(!b || !test_bit(0, &b->up))) + __skb_queue_purge(xmitq); + skb_queue_walk_safe(xmitq, skb, tmp) { + hdr = buf_msg(skb); + msg_set_non_seq(hdr, 1); + msg_set_mc_netid(hdr, net_id); + __skb_dequeue(xmitq); + dst = &b->bcast_addr; +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_xmit(net, &skb, b, dst, NULL); + if (skb) +#endif + b->media->send_msg(net, skb, b, dst); + } + rcu_read_unlock(); +} + +/** + * tipc_l2_rcv_msg - handle incoming TIPC message from an interface + * @skb: the received message + * @dev: the net device that the packet was received on + * @pt: the packet_type structure which was used to register this handler + * @orig_dev: the original receive net device in case the device is a bond + * + * Accept only packets explicitly sent to this node, or broadcast packets; + * ignores packets sent using interface multicast, and traffic sent to other + * nodes (which can happen if interface is running in promiscuous mode). 
+ */ +static int tipc_l2_rcv_msg(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *orig_dev) +{ + struct tipc_bearer *b; + + rcu_read_lock(); + b = rcu_dereference(dev->tipc_ptr) ?: + rcu_dereference(orig_dev->tipc_ptr); + if (likely(b && test_bit(0, &b->up) && + (skb->pkt_type <= PACKET_MULTICAST))) { + skb_mark_not_on_list(skb); + TIPC_SKB_CB(skb)->flags = 0; + tipc_rcv(dev_net(b->pt.dev), skb, b); + rcu_read_unlock(); + return NET_RX_SUCCESS; + } + rcu_read_unlock(); + kfree_skb(skb); + return NET_RX_DROP; +} + +/** + * tipc_l2_device_event - handle device events from network device + * @nb: the context of the notification + * @evt: the type of event + * @ptr: the net device that the event was on + * + * This function is called by the Ethernet driver in case of link + * change event. + */ +static int tipc_l2_device_event(struct notifier_block *nb, unsigned long evt, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct tipc_bearer *b; + + b = rtnl_dereference(dev->tipc_ptr); + if (!b) + return NOTIFY_DONE; + + trace_tipc_l2_device_event(dev, b, evt); + switch (evt) { + case NETDEV_CHANGE: + if (netif_carrier_ok(dev) && netif_oper_up(dev)) { + test_and_set_bit_lock(0, &b->up); + break; + } + fallthrough; + case NETDEV_GOING_DOWN: + clear_bit_unlock(0, &b->up); + tipc_reset_bearer(net, b); + break; + case NETDEV_UP: + test_and_set_bit_lock(0, &b->up); + break; + case NETDEV_CHANGEMTU: + if (tipc_mtu_bad(dev, 0)) { + bearer_disable(net, b); + break; + } + b->mtu = dev->mtu; + tipc_reset_bearer(net, b); + break; + case NETDEV_CHANGEADDR: + b->media->raw2addr(b, &b->addr, + (const char *)dev->dev_addr); + tipc_reset_bearer(net, b); + break; + case NETDEV_UNREGISTER: + case NETDEV_CHANGENAME: + bearer_disable(net, b); + break; + } + return NOTIFY_OK; +} + +static struct notifier_block notifier = { + .notifier_call = tipc_l2_device_event, + .priority = 0, +}; + +int tipc_bearer_setup(void) +{ + return register_netdevice_notifier(¬ifier); +} + +void tipc_bearer_cleanup(void) +{ + unregister_netdevice_notifier(¬ifier); +} + +void tipc_bearer_stop(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_bearer *b; + u32 i; + + for (i = 0; i < MAX_BEARERS; i++) { + b = rtnl_dereference(tn->bearer_list[i]); + if (b) { + bearer_disable(net, b); + tn->bearer_list[i] = NULL; + } + } +} + +void tipc_clone_to_loopback(struct net *net, struct sk_buff_head *pkts) +{ + struct net_device *dev = net->loopback_dev; + struct sk_buff *skb, *_skb; + int exp; + + skb_queue_walk(pkts, _skb) { + skb = pskb_copy(_skb, GFP_ATOMIC); + if (!skb) + continue; + + exp = SKB_DATA_ALIGN(dev->hard_header_len - skb_headroom(skb)); + if (exp > 0 && pskb_expand_head(skb, exp, 0, GFP_ATOMIC)) { + kfree_skb(skb); + continue; + } + + skb_reset_network_header(skb); + dev_hard_header(skb, dev, ETH_P_TIPC, dev->dev_addr, + dev->dev_addr, skb->len); + skb->dev = dev; + skb->pkt_type = PACKET_HOST; + skb->ip_summed = CHECKSUM_UNNECESSARY; + skb->protocol = eth_type_trans(skb, dev); + netif_rx(skb); + } +} + +static int tipc_loopback_rcv_pkt(struct sk_buff *skb, struct net_device *dev, + struct packet_type *pt, struct net_device *od) +{ + consume_skb(skb); + return NET_RX_SUCCESS; +} + +int tipc_attach_loopback(struct net *net) +{ + struct net_device *dev = net->loopback_dev; + struct tipc_net *tn = tipc_net(net); + + if (!dev) + return -ENODEV; + + netdev_hold(dev, &tn->loopback_pt.dev_tracker, 
GFP_KERNEL); + tn->loopback_pt.dev = dev; + tn->loopback_pt.type = htons(ETH_P_TIPC); + tn->loopback_pt.func = tipc_loopback_rcv_pkt; + dev_add_pack(&tn->loopback_pt); + return 0; +} + +void tipc_detach_loopback(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + + dev_remove_pack(&tn->loopback_pt); + netdev_put(net->loopback_dev, &tn->loopback_pt.dev_tracker); +} + +/* Caller should hold rtnl_lock to protect the bearer */ +static int __tipc_nl_add_bearer(struct tipc_nl_msg *msg, + struct tipc_bearer *bearer, int nlflags) +{ + void *hdr; + struct nlattr *attrs; + struct nlattr *prop; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + nlflags, TIPC_NL_BEARER_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_BEARER); + if (!attrs) + goto msg_full; + + if (nla_put_string(msg->skb, TIPC_NLA_BEARER_NAME, bearer->name)) + goto attr_msg_full; + + prop = nla_nest_start_noflag(msg->skb, TIPC_NLA_BEARER_PROP); + if (!prop) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_PRIO, bearer->priority)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_TOL, bearer->tolerance)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_WIN, bearer->max_win)) + goto prop_msg_full; + if (bearer->media->type_id == TIPC_MEDIA_TYPE_UDP) + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_MTU, bearer->mtu)) + goto prop_msg_full; + + nla_nest_end(msg->skb, prop); + +#ifdef CONFIG_TIPC_MEDIA_UDP + if (bearer->media->type_id == TIPC_MEDIA_TYPE_UDP) { + if (tipc_udp_nl_add_bearer_data(msg, bearer)) + goto attr_msg_full; + } +#endif + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +prop_msg_full: + nla_nest_cancel(msg->skb, prop); +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_bearer_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int err; + int i = cb->args[0]; + struct tipc_bearer *bearer; + struct tipc_nl_msg msg; + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = net_generic(net, tipc_net_id); + + if (i == MAX_BEARERS) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rtnl_lock(); + for (i = 0; i < MAX_BEARERS; i++) { + bearer = rtnl_dereference(tn->bearer_list[i]); + if (!bearer) + continue; + + err = __tipc_nl_add_bearer(&msg, bearer, NLM_F_MULTI); + if (err) + break; + } + rtnl_unlock(); + + cb->args[0] = i; + return skb->len; +} + +int tipc_nl_bearer_get(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *name; + struct sk_buff *rep; + struct tipc_bearer *bearer; + struct tipc_nl_msg msg; + struct nlattr *attrs[TIPC_NLA_BEARER_MAX + 1]; + struct net *net = genl_info_net(info); + + if (!info->attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_BEARER_MAX, + info->attrs[TIPC_NLA_BEARER], + tipc_nl_bearer_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + name = nla_data(attrs[TIPC_NLA_BEARER_NAME]); + + rep = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!rep) + return -ENOMEM; + + msg.skb = rep; + msg.portid = info->snd_portid; + msg.seq = info->snd_seq; + + rtnl_lock(); + bearer = tipc_bearer_find(net, name); + if (!bearer) { + err = -EINVAL; + NL_SET_ERR_MSG(info->extack, "Bearer not found"); + goto err_out; + } + + err = __tipc_nl_add_bearer(&msg, bearer, 0); + if (err) + goto err_out; + 
rtnl_unlock(); + + return genlmsg_reply(rep, info); +err_out: + rtnl_unlock(); + nlmsg_free(rep); + + return err; +} + +int __tipc_nl_bearer_disable(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *name; + struct tipc_bearer *bearer; + struct nlattr *attrs[TIPC_NLA_BEARER_MAX + 1]; + struct net *net = sock_net(skb->sk); + + if (!info->attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_BEARER_MAX, + info->attrs[TIPC_NLA_BEARER], + tipc_nl_bearer_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + + name = nla_data(attrs[TIPC_NLA_BEARER_NAME]); + + bearer = tipc_bearer_find(net, name); + if (!bearer) { + NL_SET_ERR_MSG(info->extack, "Bearer not found"); + return -EINVAL; + } + + bearer_disable(net, bearer); + + return 0; +} + +int tipc_nl_bearer_disable(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_bearer_disable(skb, info); + rtnl_unlock(); + + return err; +} + +int __tipc_nl_bearer_enable(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *bearer; + struct nlattr *attrs[TIPC_NLA_BEARER_MAX + 1]; + struct net *net = sock_net(skb->sk); + u32 domain = 0; + u32 prio; + + prio = TIPC_MEDIA_LINK_PRI; + + if (!info->attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_BEARER_MAX, + info->attrs[TIPC_NLA_BEARER], + tipc_nl_bearer_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + + bearer = nla_data(attrs[TIPC_NLA_BEARER_NAME]); + + if (attrs[TIPC_NLA_BEARER_DOMAIN]) + domain = nla_get_u32(attrs[TIPC_NLA_BEARER_DOMAIN]); + + if (attrs[TIPC_NLA_BEARER_PROP]) { + struct nlattr *props[TIPC_NLA_PROP_MAX + 1]; + + err = tipc_nl_parse_link_prop(attrs[TIPC_NLA_BEARER_PROP], + props); + if (err) + return err; + + if (props[TIPC_NLA_PROP_PRIO]) + prio = nla_get_u32(props[TIPC_NLA_PROP_PRIO]); + } + + return tipc_enable_bearer(net, bearer, domain, prio, attrs, + info->extack); +} + +int tipc_nl_bearer_enable(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_bearer_enable(skb, info); + rtnl_unlock(); + + return err; +} + +int tipc_nl_bearer_add(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *name; + struct tipc_bearer *b; + struct nlattr *attrs[TIPC_NLA_BEARER_MAX + 1]; + struct net *net = sock_net(skb->sk); + + if (!info->attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_BEARER_MAX, + info->attrs[TIPC_NLA_BEARER], + tipc_nl_bearer_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + name = nla_data(attrs[TIPC_NLA_BEARER_NAME]); + + rtnl_lock(); + b = tipc_bearer_find(net, name); + if (!b) { + rtnl_unlock(); + NL_SET_ERR_MSG(info->extack, "Bearer not found"); + return -EINVAL; + } + +#ifdef CONFIG_TIPC_MEDIA_UDP + if (attrs[TIPC_NLA_BEARER_UDP_OPTS]) { + err = tipc_udp_nl_bearer_add(b, + attrs[TIPC_NLA_BEARER_UDP_OPTS]); + if (err) { + rtnl_unlock(); + return err; + } + } +#endif + rtnl_unlock(); + + return 0; +} + +int __tipc_nl_bearer_set(struct sk_buff *skb, struct genl_info *info) +{ + struct tipc_bearer *b; + struct nlattr *attrs[TIPC_NLA_BEARER_MAX + 1]; + struct net *net = sock_net(skb->sk); + char *name; + int err; + + if (!info->attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_BEARER_MAX, + info->attrs[TIPC_NLA_BEARER], + 
tipc_nl_bearer_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + name = nla_data(attrs[TIPC_NLA_BEARER_NAME]); + + b = tipc_bearer_find(net, name); + if (!b) { + NL_SET_ERR_MSG(info->extack, "Bearer not found"); + return -EINVAL; + } + + if (attrs[TIPC_NLA_BEARER_PROP]) { + struct nlattr *props[TIPC_NLA_PROP_MAX + 1]; + + err = tipc_nl_parse_link_prop(attrs[TIPC_NLA_BEARER_PROP], + props); + if (err) + return err; + + if (props[TIPC_NLA_PROP_TOL]) { + b->tolerance = nla_get_u32(props[TIPC_NLA_PROP_TOL]); + tipc_node_apply_property(net, b, TIPC_NLA_PROP_TOL); + } + if (props[TIPC_NLA_PROP_PRIO]) + b->priority = nla_get_u32(props[TIPC_NLA_PROP_PRIO]); + if (props[TIPC_NLA_PROP_WIN]) + b->max_win = nla_get_u32(props[TIPC_NLA_PROP_WIN]); + if (props[TIPC_NLA_PROP_MTU]) { + if (b->media->type_id != TIPC_MEDIA_TYPE_UDP) { + NL_SET_ERR_MSG(info->extack, + "MTU property is unsupported"); + return -EINVAL; + } +#ifdef CONFIG_TIPC_MEDIA_UDP + if (nla_get_u32(props[TIPC_NLA_PROP_MTU]) < + b->encap_hlen + TIPC_MIN_BEARER_MTU) { + NL_SET_ERR_MSG(info->extack, + "MTU value is out-of-range"); + return -EINVAL; + } + b->mtu = nla_get_u32(props[TIPC_NLA_PROP_MTU]); + tipc_node_apply_property(net, b, TIPC_NLA_PROP_MTU); +#endif + } + } + + return 0; +} + +int tipc_nl_bearer_set(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_bearer_set(skb, info); + rtnl_unlock(); + + return err; +} + +static int __tipc_nl_add_media(struct tipc_nl_msg *msg, + struct tipc_media *media, int nlflags) +{ + void *hdr; + struct nlattr *attrs; + struct nlattr *prop; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + nlflags, TIPC_NL_MEDIA_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_MEDIA); + if (!attrs) + goto msg_full; + + if (nla_put_string(msg->skb, TIPC_NLA_MEDIA_NAME, media->name)) + goto attr_msg_full; + + prop = nla_nest_start_noflag(msg->skb, TIPC_NLA_MEDIA_PROP); + if (!prop) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_PRIO, media->priority)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_TOL, media->tolerance)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_WIN, media->max_win)) + goto prop_msg_full; + if (media->type_id == TIPC_MEDIA_TYPE_UDP) + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_MTU, media->mtu)) + goto prop_msg_full; + + nla_nest_end(msg->skb, prop); + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +prop_msg_full: + nla_nest_cancel(msg->skb, prop); +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_media_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int err; + int i = cb->args[0]; + struct tipc_nl_msg msg; + + if (i == MAX_MEDIA) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rtnl_lock(); + for (; media_info_array[i] != NULL; i++) { + err = __tipc_nl_add_media(&msg, media_info_array[i], + NLM_F_MULTI); + if (err) + break; + } + rtnl_unlock(); + + cb->args[0] = i; + return skb->len; +} + +int tipc_nl_media_get(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *name; + struct tipc_nl_msg msg; + struct tipc_media *media; + struct sk_buff *rep; + struct nlattr *attrs[TIPC_NLA_MEDIA_MAX + 1]; + + if (!info->attrs[TIPC_NLA_MEDIA]) + return -EINVAL; + + err = 
nla_parse_nested_deprecated(attrs, TIPC_NLA_MEDIA_MAX, + info->attrs[TIPC_NLA_MEDIA], + tipc_nl_media_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_MEDIA_NAME]) + return -EINVAL; + name = nla_data(attrs[TIPC_NLA_MEDIA_NAME]); + + rep = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!rep) + return -ENOMEM; + + msg.skb = rep; + msg.portid = info->snd_portid; + msg.seq = info->snd_seq; + + rtnl_lock(); + media = tipc_media_find(name); + if (!media) { + NL_SET_ERR_MSG(info->extack, "Media not found"); + err = -EINVAL; + goto err_out; + } + + err = __tipc_nl_add_media(&msg, media, 0); + if (err) + goto err_out; + rtnl_unlock(); + + return genlmsg_reply(rep, info); +err_out: + rtnl_unlock(); + nlmsg_free(rep); + + return err; +} + +int __tipc_nl_media_set(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *name; + struct tipc_media *m; + struct nlattr *attrs[TIPC_NLA_MEDIA_MAX + 1]; + + if (!info->attrs[TIPC_NLA_MEDIA]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_MEDIA_MAX, + info->attrs[TIPC_NLA_MEDIA], + tipc_nl_media_policy, info->extack); + + if (!attrs[TIPC_NLA_MEDIA_NAME]) + return -EINVAL; + name = nla_data(attrs[TIPC_NLA_MEDIA_NAME]); + + m = tipc_media_find(name); + if (!m) { + NL_SET_ERR_MSG(info->extack, "Media not found"); + return -EINVAL; + } + if (attrs[TIPC_NLA_MEDIA_PROP]) { + struct nlattr *props[TIPC_NLA_PROP_MAX + 1]; + + err = tipc_nl_parse_link_prop(attrs[TIPC_NLA_MEDIA_PROP], + props); + if (err) + return err; + + if (props[TIPC_NLA_PROP_TOL]) + m->tolerance = nla_get_u32(props[TIPC_NLA_PROP_TOL]); + if (props[TIPC_NLA_PROP_PRIO]) + m->priority = nla_get_u32(props[TIPC_NLA_PROP_PRIO]); + if (props[TIPC_NLA_PROP_WIN]) + m->max_win = nla_get_u32(props[TIPC_NLA_PROP_WIN]); + if (props[TIPC_NLA_PROP_MTU]) { + if (m->type_id != TIPC_MEDIA_TYPE_UDP) { + NL_SET_ERR_MSG(info->extack, + "MTU property is unsupported"); + return -EINVAL; + } +#ifdef CONFIG_TIPC_MEDIA_UDP + if (tipc_udp_mtu_bad(nla_get_u32 + (props[TIPC_NLA_PROP_MTU]))) { + NL_SET_ERR_MSG(info->extack, + "MTU value is out-of-range"); + return -EINVAL; + } + m->mtu = nla_get_u32(props[TIPC_NLA_PROP_MTU]); +#endif + } + } + + return 0; +} + +int tipc_nl_media_set(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_media_set(skb, info); + rtnl_unlock(); + + return err; +} diff --git a/net/tipc/bearer.h b/net/tipc/bearer.h new file mode 100644 index 000000000..bd0cc5c28 --- /dev/null +++ b/net/tipc/bearer.h @@ -0,0 +1,268 @@ +/* + * net/tipc/bearer.h: Include file for TIPC bearer code + * + * Copyright (c) 1996-2006, 2013-2016, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_BEARER_H +#define _TIPC_BEARER_H + +#include "netlink.h" +#include "core.h" +#include "msg.h" +#include + +#define MAX_MEDIA 3 + +/* Identifiers associated with TIPC message header media address info + * - address info field is 32 bytes long + * - the field's actual content and length is defined per media + * - remaining unused bytes in the field are set to zero + */ +#define TIPC_MEDIA_INFO_SIZE 32 +#define TIPC_MEDIA_TYPE_OFFSET 3 +#define TIPC_MEDIA_ADDR_OFFSET 4 + +/* + * Identifiers of supported TIPC media types + */ +#define TIPC_MEDIA_TYPE_ETH 1 +#define TIPC_MEDIA_TYPE_IB 2 +#define TIPC_MEDIA_TYPE_UDP 3 + +/* Minimum bearer MTU */ +#define TIPC_MIN_BEARER_MTU (MAX_H_SIZE + INT_H_SIZE) + +/* Identifiers for distinguishing between broadcast/multicast and replicast + */ +#define TIPC_BROADCAST_SUPPORT 1 +#define TIPC_REPLICAST_SUPPORT 2 + +/** + * struct tipc_media_addr - destination address used by TIPC bearers + * @value: address info (format defined by media) + * @media_id: TIPC media type identifier + * @broadcast: non-zero if address is a broadcast address + */ +struct tipc_media_addr { + u8 value[TIPC_MEDIA_INFO_SIZE]; + u8 media_id; + u8 broadcast; +}; + +struct tipc_bearer; + +/** + * struct tipc_media - Media specific info exposed to generic bearer layer + * @send_msg: routine which handles buffer transmission + * @enable_media: routine which enables a media + * @disable_media: routine which disables a media + * @addr2str: convert media address format to string + * @addr2msg: convert from media addr format to discovery msg addr format + * @msg2addr: convert from discovery msg addr format to media addr format + * @raw2addr: convert from raw addr format to media addr format + * @priority: default link (and bearer) priority + * @tolerance: default time (in ms) before declaring link failure + * @min_win: minimum window (in packets) before declaring link congestion + * @max_win: maximum window (in packets) before declaring link congestion + * @mtu: max packet size bearer can support for media type not dependent on + * underlying device MTU + * @type_id: TIPC media identifier + * @hwaddr_len: TIPC media address len + * @name: media name + */ +struct tipc_media { + int (*send_msg)(struct net *net, struct sk_buff *buf, + struct tipc_bearer *b, + struct tipc_media_addr *dest); + int (*enable_media)(struct net *net, struct tipc_bearer *b, + struct nlattr *attr[]); + void (*disable_media)(struct tipc_bearer *b); + int (*addr2str)(struct tipc_media_addr *addr, + char *strbuf, + int bufsz); + int 
(*addr2msg)(char *msg, struct tipc_media_addr *addr); + int (*msg2addr)(struct tipc_bearer *b, + struct tipc_media_addr *addr, + char *msg); + int (*raw2addr)(struct tipc_bearer *b, + struct tipc_media_addr *addr, + const char *raw); + u32 priority; + u32 tolerance; + u32 min_win; + u32 max_win; + u32 mtu; + u32 type_id; + u32 hwaddr_len; + char name[TIPC_MAX_MEDIA_NAME]; +}; + +/** + * struct tipc_bearer - Generic TIPC bearer structure + * @media_ptr: pointer to additional media-specific information about bearer + * @mtu: max packet size bearer can support + * @addr: media-specific address associated with bearer + * @name: bearer name (format = media:interface) + * @media: ptr to media structure associated with bearer + * @bcast_addr: media address used in broadcasting + * @pt: packet type for bearer + * @rcu: rcu struct for tipc_bearer + * @priority: default link priority for bearer + * @min_win: minimum window (in packets) before declaring link congestion + * @max_win: maximum window (in packets) before declaring link congestion + * @tolerance: default link tolerance for bearer + * @domain: network domain to which links can be established + * @identity: array index of this bearer within TIPC bearer array + * @disc: ptr to link setup request + * @net_plane: network plane ('A' through 'H') currently associated with bearer + * @encap_hlen: encap headers length + * @up: bearer up flag (bit 0) + * @refcnt: tipc_bearer reference counter + * + * Note: media-specific code is responsible for initialization of the fields + * indicated below when a bearer is enabled; TIPC's generic bearer code takes + * care of initializing all other fields. + */ +struct tipc_bearer { + void __rcu *media_ptr; /* initialized by media */ + u32 mtu; /* initialized by media */ + struct tipc_media_addr addr; /* initialized by media */ + char name[TIPC_MAX_BEARER_NAME]; + struct tipc_media *media; + struct tipc_media_addr bcast_addr; + struct packet_type pt; + struct rcu_head rcu; + u32 priority; + u32 min_win; + u32 max_win; + u32 tolerance; + u32 domain; + u32 identity; + struct tipc_discoverer *disc; + char net_plane; + u16 encap_hlen; + unsigned long up; + refcount_t refcnt; +}; + +struct tipc_bearer_names { + char media_name[TIPC_MAX_MEDIA_NAME]; + char if_name[TIPC_MAX_IF_NAME]; +}; + +/* + * TIPC routines available to supported media types + */ + +void tipc_rcv(struct net *net, struct sk_buff *skb, struct tipc_bearer *b); + +/* + * Routines made available to TIPC by supported media types + */ +extern struct tipc_media eth_media_info; + +#ifdef CONFIG_TIPC_MEDIA_IB +extern struct tipc_media ib_media_info; +#endif +#ifdef CONFIG_TIPC_MEDIA_UDP +extern struct tipc_media udp_media_info; +#endif + +int tipc_nl_bearer_disable(struct sk_buff *skb, struct genl_info *info); +int __tipc_nl_bearer_disable(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_bearer_enable(struct sk_buff *skb, struct genl_info *info); +int __tipc_nl_bearer_enable(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_bearer_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_bearer_get(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_bearer_set(struct sk_buff *skb, struct genl_info *info); +int __tipc_nl_bearer_set(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_bearer_add(struct sk_buff *skb, struct genl_info *info); + +int tipc_nl_media_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_media_get(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_media_set(struct 
sk_buff *skb, struct genl_info *info); +int __tipc_nl_media_set(struct sk_buff *skb, struct genl_info *info); + +int tipc_media_set_priority(const char *name, u32 new_value); +int tipc_media_set_window(const char *name, u32 new_value); +int tipc_media_addr_printf(char *buf, int len, struct tipc_media_addr *a); +int tipc_enable_l2_media(struct net *net, struct tipc_bearer *b, + struct nlattr *attrs[]); +bool tipc_bearer_hold(struct tipc_bearer *b); +void tipc_bearer_put(struct tipc_bearer *b); +void tipc_disable_l2_media(struct tipc_bearer *b); +int tipc_l2_send_msg(struct net *net, struct sk_buff *buf, + struct tipc_bearer *b, struct tipc_media_addr *dest); + +void tipc_bearer_add_dest(struct net *net, u32 bearer_id, u32 dest); +void tipc_bearer_remove_dest(struct net *net, u32 bearer_id, u32 dest); +struct tipc_bearer *tipc_bearer_find(struct net *net, const char *name); +int tipc_bearer_get_name(struct net *net, char *name, u32 bearer_id); +struct tipc_media *tipc_media_find(const char *name); +int tipc_bearer_setup(void); +void tipc_bearer_cleanup(void); +void tipc_bearer_stop(struct net *net); +int tipc_bearer_mtu(struct net *net, u32 bearer_id); +int tipc_bearer_min_mtu(struct net *net, u32 bearer_id); +bool tipc_bearer_bcast_support(struct net *net, u32 bearer_id); +void tipc_bearer_xmit_skb(struct net *net, u32 bearer_id, + struct sk_buff *skb, + struct tipc_media_addr *dest); +void tipc_bearer_xmit(struct net *net, u32 bearer_id, + struct sk_buff_head *xmitq, + struct tipc_media_addr *dst, + struct tipc_node *__dnode); +void tipc_bearer_bc_xmit(struct net *net, u32 bearer_id, + struct sk_buff_head *xmitq); +void tipc_clone_to_loopback(struct net *net, struct sk_buff_head *pkts); +int tipc_attach_loopback(struct net *net); +void tipc_detach_loopback(struct net *net); + +static inline void tipc_loopback_trace(struct net *net, + struct sk_buff_head *pkts) +{ + if (unlikely(dev_nit_active(net->loopback_dev))) + tipc_clone_to_loopback(net, pkts); +} + +/* check if device MTU is too low for tipc headers */ +static inline bool tipc_mtu_bad(struct net_device *dev, unsigned int reserve) +{ + if (dev->mtu >= TIPC_MIN_BEARER_MTU + reserve) + return false; + netdev_warn(dev, "MTU too low for tipc bearer\n"); + return true; +} + +#endif /* _TIPC_BEARER_H */ diff --git a/net/tipc/core.c b/net/tipc/core.c new file mode 100644 index 000000000..434e70eab --- /dev/null +++ b/net/tipc/core.c @@ -0,0 +1,229 @@ +/* + * net/tipc/core.c: TIPC module code + * + * Copyright (c) 2003-2006, 2013, Ericsson AB + * Copyright (c) 2005-2006, 2010-2013, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "name_table.h" +#include "subscr.h" +#include "bearer.h" +#include "net.h" +#include "socket.h" +#include "bcast.h" +#include "node.h" +#include "crypto.h" + +#include + +/* configurable TIPC parameters */ +unsigned int tipc_net_id __read_mostly; +int sysctl_tipc_rmem[3] __read_mostly; /* min/default/max */ + +static int __net_init tipc_init_net(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + int err; + + tn->net_id = 4711; + tn->node_addr = 0; + tn->trial_addr = 0; + tn->addr_trial_end = 0; + tn->capabilities = TIPC_NODE_CAPABILITIES; + INIT_WORK(&tn->work, tipc_net_finalize_work); + memset(tn->node_id, 0, sizeof(tn->node_id)); + memset(tn->node_id_string, 0, sizeof(tn->node_id_string)); + tn->mon_threshold = TIPC_DEF_MON_THRESHOLD; + get_random_bytes(&tn->random, sizeof(int)); + INIT_LIST_HEAD(&tn->node_list); + spin_lock_init(&tn->node_list_lock); + +#ifdef CONFIG_TIPC_CRYPTO + err = tipc_crypto_start(&tn->crypto_tx, net, NULL); + if (err) + goto out_crypto; +#endif + err = tipc_sk_rht_init(net); + if (err) + goto out_sk_rht; + + err = tipc_nametbl_init(net); + if (err) + goto out_nametbl; + + err = tipc_bcast_init(net); + if (err) + goto out_bclink; + + err = tipc_attach_loopback(net); + if (err) + goto out_bclink; + + return 0; + +out_bclink: + tipc_nametbl_stop(net); +out_nametbl: + tipc_sk_rht_destroy(net); +out_sk_rht: + +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_stop(&tn->crypto_tx); +out_crypto: +#endif + return err; +} + +static void __net_exit tipc_exit_net(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + + tipc_detach_loopback(net); + tipc_net_stop(net); + /* Make sure the tipc_net_finalize_work() finished */ + cancel_work_sync(&tn->work); + tipc_bcast_stop(net); + tipc_nametbl_stop(net); + tipc_sk_rht_destroy(net); +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_stop(&tipc_net(net)->crypto_tx); +#endif + while (atomic_read(&tn->wq_count)) + cond_resched(); +} + +static void __net_exit tipc_pernet_pre_exit(struct net *net) +{ + tipc_node_pre_cleanup_net(net); +} + +static struct pernet_operations tipc_pernet_pre_exit_ops = { + .pre_exit = tipc_pernet_pre_exit, +}; + +static struct pernet_operations tipc_net_ops = { + .init = tipc_init_net, + .exit = tipc_exit_net, + .id = &tipc_net_id, + .size = sizeof(struct tipc_net), +}; + +static struct pernet_operations tipc_topsrv_net_ops = { + .init = tipc_topsrv_init_net, + .exit = tipc_topsrv_exit_net, +}; + +static int __init tipc_init(void) +{ + int err; + + pr_info("Activated (version " TIPC_MOD_VER ")\n"); + + sysctl_tipc_rmem[0] = RCVBUF_MIN; + sysctl_tipc_rmem[1] = RCVBUF_DEF; + sysctl_tipc_rmem[2] = RCVBUF_MAX; + + err = tipc_register_sysctl(); + if (err) + goto out_sysctl; 
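+ /* The remaining subsystems are brought up in strict order below
+ * (pernet device, sockets, topology-server pernet ops, pernet
+ * pre-exit hook, bearers, netlink, netlink compat); on any failure
+ * the goto ladder unwinds only the steps that already succeeded,
+ * in reverse order.
+ */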
+ + err = register_pernet_device(&tipc_net_ops); + if (err) + goto out_pernet; + + err = tipc_socket_init(); + if (err) + goto out_socket; + + err = register_pernet_device(&tipc_topsrv_net_ops); + if (err) + goto out_pernet_topsrv; + + err = register_pernet_subsys(&tipc_pernet_pre_exit_ops); + if (err) + goto out_register_pernet_subsys; + + err = tipc_bearer_setup(); + if (err) + goto out_bearer; + + err = tipc_netlink_start(); + if (err) + goto out_netlink; + + err = tipc_netlink_compat_start(); + if (err) + goto out_netlink_compat; + + pr_info("Started in single node mode\n"); + return 0; + +out_netlink_compat: + tipc_netlink_stop(); +out_netlink: + tipc_bearer_cleanup(); +out_bearer: + unregister_pernet_subsys(&tipc_pernet_pre_exit_ops); +out_register_pernet_subsys: + unregister_pernet_device(&tipc_topsrv_net_ops); +out_pernet_topsrv: + tipc_socket_stop(); +out_socket: + unregister_pernet_device(&tipc_net_ops); +out_pernet: + tipc_unregister_sysctl(); +out_sysctl: + pr_err("Unable to start in single node mode\n"); + return err; +} + +static void __exit tipc_exit(void) +{ + tipc_netlink_compat_stop(); + tipc_netlink_stop(); + tipc_bearer_cleanup(); + unregister_pernet_subsys(&tipc_pernet_pre_exit_ops); + unregister_pernet_device(&tipc_topsrv_net_ops); + tipc_socket_stop(); + unregister_pernet_device(&tipc_net_ops); + tipc_unregister_sysctl(); + + pr_info("Deactivated\n"); +} + +module_init(tipc_init); +module_exit(tipc_exit); + +MODULE_DESCRIPTION("TIPC: Transparent Inter Process Communication"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_VERSION(TIPC_MOD_VER); diff --git a/net/tipc/core.h b/net/tipc/core.h new file mode 100644 index 000000000..0a3f7a70a --- /dev/null +++ b/net/tipc/core.h @@ -0,0 +1,228 @@ +/* + * net/tipc/core.h: Include file for TIPC global declarations + * + * Copyright (c) 2005-2006, 2013-2018 Ericsson AB + * Copyright (c) 2005-2007, 2010-2013, Wind River Systems + * Copyright (c) 2020, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_CORE_H +#define _TIPC_CORE_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef pr_fmt +#undef pr_fmt +#endif + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +struct tipc_node; +struct tipc_bearer; +struct tipc_bc_base; +struct tipc_link; +struct tipc_name_table; +struct tipc_topsrv; +struct tipc_monitor; +#ifdef CONFIG_TIPC_CRYPTO +struct tipc_crypto; +#endif + +#define TIPC_MOD_VER "2.0.0" + +#define NODE_HTABLE_SIZE 512 +#define MAX_BEARERS 3 +#define TIPC_DEF_MON_THRESHOLD 32 +#define NODE_ID_LEN 16 +#define NODE_ID_STR_LEN (NODE_ID_LEN * 2 + 1) + +extern unsigned int tipc_net_id __read_mostly; +extern int sysctl_tipc_rmem[3] __read_mostly; +extern int sysctl_tipc_named_timeout __read_mostly; + +struct tipc_net { + u8 node_id[NODE_ID_LEN]; + u32 node_addr; + u32 trial_addr; + unsigned long addr_trial_end; + char node_id_string[NODE_ID_STR_LEN]; + int net_id; + int random; + bool legacy_addr_format; + + /* Node table and node list */ + spinlock_t node_list_lock; + struct hlist_head node_htable[NODE_HTABLE_SIZE]; + struct list_head node_list; + u32 num_nodes; + u32 num_links; + + /* Neighbor monitoring list */ + struct tipc_monitor *monitors[MAX_BEARERS]; + int mon_threshold; + + /* Bearer list */ + struct tipc_bearer __rcu *bearer_list[MAX_BEARERS + 1]; + + /* Broadcast link */ + spinlock_t bclock; + struct tipc_bc_base *bcbase; + struct tipc_link *bcl; + + /* Socket hash table */ + struct rhashtable sk_rht; + + /* Name table */ + spinlock_t nametbl_lock; + struct name_table *nametbl; + + /* Topology subscription server */ + struct tipc_topsrv *topsrv; + atomic_t subscription_count; + + /* Cluster capabilities */ + u16 capabilities; + + /* Tracing of node internal messages */ + struct packet_type loopback_pt; + +#ifdef CONFIG_TIPC_CRYPTO + /* TX crypto handler */ + struct tipc_crypto *crypto_tx; +#endif + /* Work item for net finalize */ + struct work_struct work; + /* The numbers of work queues in schedule */ + atomic_t wq_count; +}; + +static inline struct tipc_net *tipc_net(struct net *net) +{ + return net_generic(net, tipc_net_id); +} + +static inline int tipc_netid(struct net *net) +{ + return tipc_net(net)->net_id; +} + +static inline struct list_head *tipc_nodes(struct net *net) +{ + return &tipc_net(net)->node_list; +} + +static inline struct name_table *tipc_name_table(struct net *net) +{ + return tipc_net(net)->nametbl; +} + +static inline struct tipc_topsrv *tipc_topsrv(struct net *net) +{ + return tipc_net(net)->topsrv; +} + +static inline unsigned int tipc_hashfn(u32 addr) +{ + return addr & (NODE_HTABLE_SIZE - 1); +} + +static inline u16 mod(u16 x) +{ + return x & 0xffffu; +} + +static inline int less_eq(u16 left, u16 right) +{ + return mod(right - left) < 32768u; +} + +static inline int more(u16 left, u16 right) +{ + return 
!less_eq(left, right); +} + +static inline int less(u16 left, u16 right) +{ + return less_eq(left, right) && (mod(right) != mod(left)); +} + +static inline int in_range(u16 val, u16 min, u16 max) +{ + return !less(val, min) && !more(val, max); +} + +static inline u32 tipc_net_hash_mixes(struct net *net, int tn_rand) +{ + return net_hash_mix(&init_net) ^ net_hash_mix(net) ^ tn_rand; +} + +static inline u32 hash128to32(char *bytes) +{ + __be32 *tmp = (__be32 *)bytes; + u32 res; + + res = ntohl(tmp[0] ^ tmp[1] ^ tmp[2] ^ tmp[3]); + if (likely(res)) + return res; + return ntohl(tmp[0] | tmp[1] | tmp[2] | tmp[3]); +} + +#ifdef CONFIG_SYSCTL +int tipc_register_sysctl(void); +void tipc_unregister_sysctl(void); +#else +#define tipc_register_sysctl() 0 +#define tipc_unregister_sysctl() +#endif +#endif diff --git a/net/tipc/crypto.c b/net/tipc/crypto.c new file mode 100644 index 000000000..65f59739a --- /dev/null +++ b/net/tipc/crypto.c @@ -0,0 +1,2475 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * net/tipc/crypto.c: TIPC crypto for key handling & packet en/decryption + * + * Copyright (c) 2019, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include +#include +#include "crypto.h" +#include "msg.h" +#include "bcast.h" + +#define TIPC_TX_GRACE_PERIOD msecs_to_jiffies(5000) /* 5s */ +#define TIPC_TX_LASTING_TIME msecs_to_jiffies(10000) /* 10s */ +#define TIPC_RX_ACTIVE_LIM msecs_to_jiffies(3000) /* 3s */ +#define TIPC_RX_PASSIVE_LIM msecs_to_jiffies(15000) /* 15s */ + +#define TIPC_MAX_TFMS_DEF 10 +#define TIPC_MAX_TFMS_LIM 1000 + +#define TIPC_REKEYING_INTV_DEF (60 * 24) /* default: 1 day */ + +/* + * TIPC Key ids + */ +enum { + KEY_MASTER = 0, + KEY_MIN = KEY_MASTER, + KEY_1 = 1, + KEY_2, + KEY_3, + KEY_MAX = KEY_3, +}; + +/* + * TIPC Crypto statistics + */ +enum { + STAT_OK, + STAT_NOK, + STAT_ASYNC, + STAT_ASYNC_OK, + STAT_ASYNC_NOK, + STAT_BADKEYS, /* tx only */ + STAT_BADMSGS = STAT_BADKEYS, /* rx only */ + STAT_NOKEYS, + STAT_SWITCHES, + + MAX_STATS, +}; + +/* TIPC crypto statistics' header */ +static const char *hstats[MAX_STATS] = {"ok", "nok", "async", "async_ok", + "async_nok", "badmsgs", "nokeys", + "switches"}; + +/* Max TFMs number per key */ +int sysctl_tipc_max_tfms __read_mostly = TIPC_MAX_TFMS_DEF; +/* Key exchange switch, default: on */ +int sysctl_tipc_key_exchange_enabled __read_mostly = 1; + +/* + * struct tipc_key - TIPC keys' status indicator + * + * 7 6 5 4 3 2 1 0 + * +-----+-----+-----+-----+-----+-----+-----+-----+ + * key: | (reserved)|passive idx| active idx|pending idx| + * +-----+-----+-----+-----+-----+-----+-----+-----+ + */ +struct tipc_key { +#define KEY_BITS (2) +#define KEY_MASK ((1 << KEY_BITS) - 1) + union { + struct { +#if defined(__LITTLE_ENDIAN_BITFIELD) + u8 pending:2, + active:2, + passive:2, /* rx only */ + reserved:2; +#elif defined(__BIG_ENDIAN_BITFIELD) + u8 reserved:2, + passive:2, /* rx only */ + active:2, + pending:2; +#else +#error "Please fix " +#endif + } __packed; + u8 keys; + }; +}; + +/** + * struct tipc_tfm - TIPC TFM structure to form a list of TFMs + * @tfm: cipher handle/key + * @list: linked list of TFMs + */ +struct tipc_tfm { + struct crypto_aead *tfm; + struct list_head list; +}; + +/** + * struct tipc_aead - TIPC AEAD key structure + * @tfm_entry: per-cpu pointer to one entry in TFM list + * @crypto: TIPC crypto owns this key + * @cloned: reference to the source key in case cloning + * @users: the number of the key users (TX/RX) + * @salt: the key's SALT value + * @authsize: authentication tag size (max = 16) + * @mode: crypto mode is applied to the key + * @hint: a hint for user key + * @rcu: struct rcu_head + * @key: the aead key + * @gen: the key's generation + * @seqno: the key seqno (cluster scope) + * @refcnt: the key reference counter + */ +struct tipc_aead { +#define TIPC_AEAD_HINT_LEN (5) + struct tipc_tfm * __percpu *tfm_entry; + struct tipc_crypto *crypto; + struct tipc_aead *cloned; + atomic_t users; + u32 salt; + u8 authsize; + u8 mode; + char hint[2 * TIPC_AEAD_HINT_LEN + 1]; + struct rcu_head rcu; + struct tipc_aead_key *key; + u16 gen; + + atomic64_t seqno ____cacheline_aligned; + refcount_t refcnt ____cacheline_aligned; + +} ____cacheline_aligned; + +/** + * struct tipc_crypto_stats - TIPC Crypto statistics + * @stat: array of crypto statistics + */ +struct tipc_crypto_stats { + unsigned int stat[MAX_STATS]; +}; + +/** + * struct tipc_crypto - TIPC TX/RX crypto structure + * @net: struct net + * @node: TIPC node (RX) + * @aead: array of pointers to AEAD keys for encryption/decryption + * @peer_rx_active: replicated peer RX active key index + * @key_gen: TX/RX key generation + * @key: the key states + * @skey_mode: session key's 
mode + * @skey: received session key + * @wq: common workqueue on TX crypto + * @work: delayed work sched for TX/RX + * @key_distr: key distributing state + * @rekeying_intv: rekeying interval (in minutes) + * @stats: the crypto statistics + * @name: the crypto name + * @sndnxt: the per-peer sndnxt (TX) + * @timer1: general timer 1 (jiffies) + * @timer2: general timer 2 (jiffies) + * @working: the crypto is working or not + * @key_master: flag indicates if master key exists + * @legacy_user: flag indicates if a peer joins w/o master key (for bwd comp.) + * @nokey: no key indication + * @flags: combined flags field + * @lock: tipc_key lock + */ +struct tipc_crypto { + struct net *net; + struct tipc_node *node; + struct tipc_aead __rcu *aead[KEY_MAX + 1]; + atomic_t peer_rx_active; + u16 key_gen; + struct tipc_key key; + u8 skey_mode; + struct tipc_aead_key *skey; + struct workqueue_struct *wq; + struct delayed_work work; +#define KEY_DISTR_SCHED 1 +#define KEY_DISTR_COMPL 2 + atomic_t key_distr; + u32 rekeying_intv; + + struct tipc_crypto_stats __percpu *stats; + char name[48]; + + atomic64_t sndnxt ____cacheline_aligned; + unsigned long timer1; + unsigned long timer2; + union { + struct { + u8 working:1; + u8 key_master:1; + u8 legacy_user:1; + u8 nokey: 1; + }; + u8 flags; + }; + spinlock_t lock; /* crypto lock */ + +} ____cacheline_aligned; + +/* struct tipc_crypto_tx_ctx - TX context for callbacks */ +struct tipc_crypto_tx_ctx { + struct tipc_aead *aead; + struct tipc_bearer *bearer; + struct tipc_media_addr dst; +}; + +/* struct tipc_crypto_rx_ctx - RX context for callbacks */ +struct tipc_crypto_rx_ctx { + struct tipc_aead *aead; + struct tipc_bearer *bearer; +}; + +static struct tipc_aead *tipc_aead_get(struct tipc_aead __rcu *aead); +static inline void tipc_aead_put(struct tipc_aead *aead); +static void tipc_aead_free(struct rcu_head *rp); +static int tipc_aead_users(struct tipc_aead __rcu *aead); +static void tipc_aead_users_inc(struct tipc_aead __rcu *aead, int lim); +static void tipc_aead_users_dec(struct tipc_aead __rcu *aead, int lim); +static void tipc_aead_users_set(struct tipc_aead __rcu *aead, int val); +static struct crypto_aead *tipc_aead_tfm_next(struct tipc_aead *aead); +static int tipc_aead_init(struct tipc_aead **aead, struct tipc_aead_key *ukey, + u8 mode); +static int tipc_aead_clone(struct tipc_aead **dst, struct tipc_aead *src); +static void *tipc_aead_mem_alloc(struct crypto_aead *tfm, + unsigned int crypto_ctx_size, + u8 **iv, struct aead_request **req, + struct scatterlist **sg, int nsg); +static int tipc_aead_encrypt(struct tipc_aead *aead, struct sk_buff *skb, + struct tipc_bearer *b, + struct tipc_media_addr *dst, + struct tipc_node *__dnode); +static void tipc_aead_encrypt_done(struct crypto_async_request *base, int err); +static int tipc_aead_decrypt(struct net *net, struct tipc_aead *aead, + struct sk_buff *skb, struct tipc_bearer *b); +static void tipc_aead_decrypt_done(struct crypto_async_request *base, int err); +static inline int tipc_ehdr_size(struct tipc_ehdr *ehdr); +static int tipc_ehdr_build(struct net *net, struct tipc_aead *aead, + u8 tx_key, struct sk_buff *skb, + struct tipc_crypto *__rx); +static inline void tipc_crypto_key_set_state(struct tipc_crypto *c, + u8 new_passive, + u8 new_active, + u8 new_pending); +static int tipc_crypto_key_attach(struct tipc_crypto *c, + struct tipc_aead *aead, u8 pos, + bool master_key); +static bool tipc_crypto_key_try_align(struct tipc_crypto *rx, u8 new_pending); +static struct tipc_aead 
*tipc_crypto_key_pick_tx(struct tipc_crypto *tx, + struct tipc_crypto *rx, + struct sk_buff *skb, + u8 tx_key); +static void tipc_crypto_key_synch(struct tipc_crypto *rx, struct sk_buff *skb); +static int tipc_crypto_key_revoke(struct net *net, u8 tx_key); +static inline void tipc_crypto_clone_msg(struct net *net, struct sk_buff *_skb, + struct tipc_bearer *b, + struct tipc_media_addr *dst, + struct tipc_node *__dnode, u8 type); +static void tipc_crypto_rcv_complete(struct net *net, struct tipc_aead *aead, + struct tipc_bearer *b, + struct sk_buff **skb, int err); +static void tipc_crypto_do_cmd(struct net *net, int cmd); +static char *tipc_crypto_key_dump(struct tipc_crypto *c, char *buf); +static char *tipc_key_change_dump(struct tipc_key old, struct tipc_key new, + char *buf); +static int tipc_crypto_key_xmit(struct net *net, struct tipc_aead_key *skey, + u16 gen, u8 mode, u32 dnode); +static bool tipc_crypto_key_rcv(struct tipc_crypto *rx, struct tipc_msg *hdr); +static void tipc_crypto_work_tx(struct work_struct *work); +static void tipc_crypto_work_rx(struct work_struct *work); +static int tipc_aead_key_generate(struct tipc_aead_key *skey); + +#define is_tx(crypto) (!(crypto)->node) +#define is_rx(crypto) (!is_tx(crypto)) + +#define key_next(cur) ((cur) % KEY_MAX + 1) + +#define tipc_aead_rcu_ptr(rcu_ptr, lock) \ + rcu_dereference_protected((rcu_ptr), lockdep_is_held(lock)) + +#define tipc_aead_rcu_replace(rcu_ptr, ptr, lock) \ +do { \ + struct tipc_aead *__tmp = rcu_dereference_protected((rcu_ptr), \ + lockdep_is_held(lock)); \ + rcu_assign_pointer((rcu_ptr), (ptr)); \ + tipc_aead_put(__tmp); \ +} while (0) + +#define tipc_crypto_key_detach(rcu_ptr, lock) \ + tipc_aead_rcu_replace((rcu_ptr), NULL, lock) + +/** + * tipc_aead_key_validate - Validate a AEAD user key + * @ukey: pointer to user key data + * @info: netlink info pointer + */ +int tipc_aead_key_validate(struct tipc_aead_key *ukey, struct genl_info *info) +{ + int keylen; + + /* Check if algorithm exists */ + if (unlikely(!crypto_has_alg(ukey->alg_name, 0, 0))) { + GENL_SET_ERR_MSG(info, "unable to load the algorithm (module existed?)"); + return -ENODEV; + } + + /* Currently, we only support the "gcm(aes)" cipher algorithm */ + if (strcmp(ukey->alg_name, "gcm(aes)")) { + GENL_SET_ERR_MSG(info, "not supported yet the algorithm"); + return -ENOTSUPP; + } + + /* Check if key size is correct */ + keylen = ukey->keylen - TIPC_AES_GCM_SALT_SIZE; + if (unlikely(keylen != TIPC_AES_GCM_KEY_SIZE_128 && + keylen != TIPC_AES_GCM_KEY_SIZE_192 && + keylen != TIPC_AES_GCM_KEY_SIZE_256)) { + GENL_SET_ERR_MSG(info, "incorrect key length (20, 28 or 36 octets?)"); + return -EKEYREJECTED; + } + + return 0; +} + +/** + * tipc_aead_key_generate - Generate new session key + * @skey: input/output key with new content + * + * Return: 0 in case of success, otherwise < 0 + */ +static int tipc_aead_key_generate(struct tipc_aead_key *skey) +{ + int rc = 0; + + /* Fill the key's content with a random value via RNG cipher */ + rc = crypto_get_default_rng(); + if (likely(!rc)) { + rc = crypto_rng_get_bytes(crypto_default_rng, skey->key, + skey->keylen); + crypto_put_default_rng(); + } + + return rc; +} + +static struct tipc_aead *tipc_aead_get(struct tipc_aead __rcu *aead) +{ + struct tipc_aead *tmp; + + rcu_read_lock(); + tmp = rcu_dereference(aead); + if (unlikely(!tmp || !refcount_inc_not_zero(&tmp->refcnt))) + tmp = NULL; + rcu_read_unlock(); + + return tmp; +} + +static inline void tipc_aead_put(struct tipc_aead *aead) +{ + if (aead && 
refcount_dec_and_test(&aead->refcnt)) + call_rcu(&aead->rcu, tipc_aead_free); +} + +/** + * tipc_aead_free - Release AEAD key incl. all the TFMs in the list + * @rp: rcu head pointer + */ +static void tipc_aead_free(struct rcu_head *rp) +{ + struct tipc_aead *aead = container_of(rp, struct tipc_aead, rcu); + struct tipc_tfm *tfm_entry, *head, *tmp; + + if (aead->cloned) { + tipc_aead_put(aead->cloned); + } else { + head = *get_cpu_ptr(aead->tfm_entry); + put_cpu_ptr(aead->tfm_entry); + list_for_each_entry_safe(tfm_entry, tmp, &head->list, list) { + crypto_free_aead(tfm_entry->tfm); + list_del(&tfm_entry->list); + kfree(tfm_entry); + } + /* Free the head */ + crypto_free_aead(head->tfm); + list_del(&head->list); + kfree(head); + } + free_percpu(aead->tfm_entry); + kfree_sensitive(aead->key); + kfree(aead); +} + +static int tipc_aead_users(struct tipc_aead __rcu *aead) +{ + struct tipc_aead *tmp; + int users = 0; + + rcu_read_lock(); + tmp = rcu_dereference(aead); + if (tmp) + users = atomic_read(&tmp->users); + rcu_read_unlock(); + + return users; +} + +static void tipc_aead_users_inc(struct tipc_aead __rcu *aead, int lim) +{ + struct tipc_aead *tmp; + + rcu_read_lock(); + tmp = rcu_dereference(aead); + if (tmp) + atomic_add_unless(&tmp->users, 1, lim); + rcu_read_unlock(); +} + +static void tipc_aead_users_dec(struct tipc_aead __rcu *aead, int lim) +{ + struct tipc_aead *tmp; + + rcu_read_lock(); + tmp = rcu_dereference(aead); + if (tmp) + atomic_add_unless(&rcu_dereference(aead)->users, -1, lim); + rcu_read_unlock(); +} + +static void tipc_aead_users_set(struct tipc_aead __rcu *aead, int val) +{ + struct tipc_aead *tmp; + int cur; + + rcu_read_lock(); + tmp = rcu_dereference(aead); + if (tmp) { + do { + cur = atomic_read(&tmp->users); + if (cur == val) + break; + } while (atomic_cmpxchg(&tmp->users, cur, val) != cur); + } + rcu_read_unlock(); +} + +/** + * tipc_aead_tfm_next - Move TFM entry to the next one in list and return it + * @aead: the AEAD key pointer + */ +static struct crypto_aead *tipc_aead_tfm_next(struct tipc_aead *aead) +{ + struct tipc_tfm **tfm_entry; + struct crypto_aead *tfm; + + tfm_entry = get_cpu_ptr(aead->tfm_entry); + *tfm_entry = list_next_entry(*tfm_entry, list); + tfm = (*tfm_entry)->tfm; + put_cpu_ptr(tfm_entry); + + return tfm; +} + +/** + * tipc_aead_init - Initiate TIPC AEAD + * @aead: returned new TIPC AEAD key handle pointer + * @ukey: pointer to user key data + * @mode: the key mode + * + * Allocate a (list of) new cipher transformation (TFM) with the specific user + * key data if valid. The number of the allocated TFMs can be set via the sysfs + * "net/tipc/max_tfms" first. + * Also, all the other AEAD data are also initialized. 
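+ * Each TFM in the list is a fully keyed cipher handle; keeping several of
+ * them lets CPUs encrypt/decrypt in parallel without serializing on a
+ * single transform (tipc_aead_tfm_next() hands them out per CPU in a
+ * round-robin fashion).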
+ * + * Return: 0 if the initiation is successful, otherwise: < 0 + */ +static int tipc_aead_init(struct tipc_aead **aead, struct tipc_aead_key *ukey, + u8 mode) +{ + struct tipc_tfm *tfm_entry, *head; + struct crypto_aead *tfm; + struct tipc_aead *tmp; + int keylen, err, cpu; + int tfm_cnt = 0; + + if (unlikely(*aead)) + return -EEXIST; + + /* Allocate a new AEAD */ + tmp = kzalloc(sizeof(*tmp), GFP_ATOMIC); + if (unlikely(!tmp)) + return -ENOMEM; + + /* The key consists of two parts: [AES-KEY][SALT] */ + keylen = ukey->keylen - TIPC_AES_GCM_SALT_SIZE; + + /* Allocate per-cpu TFM entry pointer */ + tmp->tfm_entry = alloc_percpu(struct tipc_tfm *); + if (!tmp->tfm_entry) { + kfree_sensitive(tmp); + return -ENOMEM; + } + + /* Make a list of TFMs with the user key data */ + do { + tfm = crypto_alloc_aead(ukey->alg_name, 0, 0); + if (IS_ERR(tfm)) { + err = PTR_ERR(tfm); + break; + } + + if (unlikely(!tfm_cnt && + crypto_aead_ivsize(tfm) != TIPC_AES_GCM_IV_SIZE)) { + crypto_free_aead(tfm); + err = -ENOTSUPP; + break; + } + + err = crypto_aead_setauthsize(tfm, TIPC_AES_GCM_TAG_SIZE); + err |= crypto_aead_setkey(tfm, ukey->key, keylen); + if (unlikely(err)) { + crypto_free_aead(tfm); + break; + } + + tfm_entry = kmalloc(sizeof(*tfm_entry), GFP_KERNEL); + if (unlikely(!tfm_entry)) { + crypto_free_aead(tfm); + err = -ENOMEM; + break; + } + INIT_LIST_HEAD(&tfm_entry->list); + tfm_entry->tfm = tfm; + + /* First entry? */ + if (!tfm_cnt) { + head = tfm_entry; + for_each_possible_cpu(cpu) { + *per_cpu_ptr(tmp->tfm_entry, cpu) = head; + } + } else { + list_add_tail(&tfm_entry->list, &head->list); + } + + } while (++tfm_cnt < sysctl_tipc_max_tfms); + + /* Not any TFM is allocated? */ + if (!tfm_cnt) { + free_percpu(tmp->tfm_entry); + kfree_sensitive(tmp); + return err; + } + + /* Form a hex string of some last bytes as the key's hint */ + bin2hex(tmp->hint, ukey->key + keylen - TIPC_AEAD_HINT_LEN, + TIPC_AEAD_HINT_LEN); + + /* Initialize the other data */ + tmp->mode = mode; + tmp->cloned = NULL; + tmp->authsize = TIPC_AES_GCM_TAG_SIZE; + tmp->key = kmemdup(ukey, tipc_aead_key_size(ukey), GFP_KERNEL); + if (!tmp->key) { + tipc_aead_free(&tmp->rcu); + return -ENOMEM; + } + memcpy(&tmp->salt, ukey->key + keylen, TIPC_AES_GCM_SALT_SIZE); + atomic_set(&tmp->users, 0); + atomic64_set(&tmp->seqno, 0); + refcount_set(&tmp->refcnt, 1); + + *aead = tmp; + return 0; +} + +/** + * tipc_aead_clone - Clone a TIPC AEAD key + * @dst: dest key for the cloning + * @src: source key to clone from + * + * Make a "copy" of the source AEAD key data to the dest, the TFMs list is + * common for the keys. + * A reference to the source is hold in the "cloned" pointer for the later + * freeing purposes. + * + * Note: this must be done in cluster-key mode only! 
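+ * The clone shares the source key's per-cpu TFM list, salt and mode; the
+ * reference taken on @src keeps those shared TFMs alive for the clone's
+ * whole lifetime (see tipc_aead_free()).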
+ * Return: 0 in case of success, otherwise < 0 + */ +static int tipc_aead_clone(struct tipc_aead **dst, struct tipc_aead *src) +{ + struct tipc_aead *aead; + int cpu; + + if (!src) + return -ENOKEY; + + if (src->mode != CLUSTER_KEY) + return -EINVAL; + + if (unlikely(*dst)) + return -EEXIST; + + aead = kzalloc(sizeof(*aead), GFP_ATOMIC); + if (unlikely(!aead)) + return -ENOMEM; + + aead->tfm_entry = alloc_percpu_gfp(struct tipc_tfm *, GFP_ATOMIC); + if (unlikely(!aead->tfm_entry)) { + kfree_sensitive(aead); + return -ENOMEM; + } + + for_each_possible_cpu(cpu) { + *per_cpu_ptr(aead->tfm_entry, cpu) = + *per_cpu_ptr(src->tfm_entry, cpu); + } + + memcpy(aead->hint, src->hint, sizeof(src->hint)); + aead->mode = src->mode; + aead->salt = src->salt; + aead->authsize = src->authsize; + atomic_set(&aead->users, 0); + atomic64_set(&aead->seqno, 0); + refcount_set(&aead->refcnt, 1); + + WARN_ON(!refcount_inc_not_zero(&src->refcnt)); + aead->cloned = src; + + *dst = aead; + return 0; +} + +/** + * tipc_aead_mem_alloc - Allocate memory for AEAD request operations + * @tfm: cipher handle to be registered with the request + * @crypto_ctx_size: size of crypto context for callback + * @iv: returned pointer to IV data + * @req: returned pointer to AEAD request data + * @sg: returned pointer to SG lists + * @nsg: number of SG lists to be allocated + * + * Allocate memory to store the crypto context data, AEAD request, IV and SG + * lists, the memory layout is as follows: + * crypto_ctx || iv || aead_req || sg[] + * + * Return: the pointer to the memory areas in case of success, otherwise NULL + */ +static void *tipc_aead_mem_alloc(struct crypto_aead *tfm, + unsigned int crypto_ctx_size, + u8 **iv, struct aead_request **req, + struct scatterlist **sg, int nsg) +{ + unsigned int iv_size, req_size; + unsigned int len; + u8 *mem; + + iv_size = crypto_aead_ivsize(tfm); + req_size = sizeof(**req) + crypto_aead_reqsize(tfm); + + len = crypto_ctx_size; + len += iv_size; + len += crypto_aead_alignmask(tfm) & ~(crypto_tfm_ctx_alignment() - 1); + len = ALIGN(len, crypto_tfm_ctx_alignment()); + len += req_size; + len = ALIGN(len, __alignof__(struct scatterlist)); + len += nsg * sizeof(**sg); + + mem = kmalloc(len, GFP_ATOMIC); + if (!mem) + return NULL; + + *iv = (u8 *)PTR_ALIGN(mem + crypto_ctx_size, + crypto_aead_alignmask(tfm) + 1); + *req = (struct aead_request *)PTR_ALIGN(*iv + iv_size, + crypto_tfm_ctx_alignment()); + *sg = (struct scatterlist *)PTR_ALIGN((u8 *)*req + req_size, + __alignof__(struct scatterlist)); + + return (void *)mem; +} + +/** + * tipc_aead_encrypt - Encrypt a message + * @aead: TIPC AEAD key for the message encryption + * @skb: the input/output skb + * @b: TIPC bearer where the message will be delivered after the encryption + * @dst: the destination media address + * @__dnode: TIPC dest node if "known" + * + * Return: + * * 0 : if the encryption has completed + * * -EINPROGRESS/-EBUSY : if a callback will be performed + * * < 0 : the encryption has failed + */ +static int tipc_aead_encrypt(struct tipc_aead *aead, struct sk_buff *skb, + struct tipc_bearer *b, + struct tipc_media_addr *dst, + struct tipc_node *__dnode) +{ + struct crypto_aead *tfm = tipc_aead_tfm_next(aead); + struct tipc_crypto_tx_ctx *tx_ctx; + struct aead_request *req; + struct sk_buff *trailer; + struct scatterlist *sg; + struct tipc_ehdr *ehdr; + int ehsz, len, tailen, nsg, rc; + void *ctx; + u32 salt; + u8 *iv; + + /* Make sure message len at least 4-byte aligned */ + len = ALIGN(skb->len, 4); + tailen = len - skb->len + 
aead->authsize; + + /* Expand skb tail for authentication tag: + * As for simplicity, we'd have made sure skb having enough tailroom + * for authentication tag @skb allocation. Even when skb is nonlinear + * but there is no frag_list, it should be still fine! + * Otherwise, we must cow it to be a writable buffer with the tailroom. + */ + SKB_LINEAR_ASSERT(skb); + if (tailen > skb_tailroom(skb)) { + pr_debug("TX(): skb tailroom is not enough: %d, requires: %d\n", + skb_tailroom(skb), tailen); + } + + nsg = skb_cow_data(skb, tailen, &trailer); + if (unlikely(nsg < 0)) { + pr_err("TX: skb_cow_data() returned %d\n", nsg); + return nsg; + } + + pskb_put(skb, trailer, tailen); + + /* Allocate memory for the AEAD operation */ + ctx = tipc_aead_mem_alloc(tfm, sizeof(*tx_ctx), &iv, &req, &sg, nsg); + if (unlikely(!ctx)) + return -ENOMEM; + TIPC_SKB_CB(skb)->crypto_ctx = ctx; + + /* Map skb to the sg lists */ + sg_init_table(sg, nsg); + rc = skb_to_sgvec(skb, sg, 0, skb->len); + if (unlikely(rc < 0)) { + pr_err("TX: skb_to_sgvec() returned %d, nsg %d!\n", rc, nsg); + goto exit; + } + + /* Prepare IV: [SALT (4 octets)][SEQNO (8 octets)] + * In case we're in cluster-key mode, SALT is varied by xor-ing with + * the source address (or w0 of id), otherwise with the dest address + * if dest is known. + */ + ehdr = (struct tipc_ehdr *)skb->data; + salt = aead->salt; + if (aead->mode == CLUSTER_KEY) + salt ^= __be32_to_cpu(ehdr->addr); + else if (__dnode) + salt ^= tipc_node_get_addr(__dnode); + memcpy(iv, &salt, 4); + memcpy(iv + 4, (u8 *)&ehdr->seqno, 8); + + /* Prepare request */ + ehsz = tipc_ehdr_size(ehdr); + aead_request_set_tfm(req, tfm); + aead_request_set_ad(req, ehsz); + aead_request_set_crypt(req, sg, sg, len - ehsz, iv); + + /* Set callback function & data */ + aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG, + tipc_aead_encrypt_done, skb); + tx_ctx = (struct tipc_crypto_tx_ctx *)ctx; + tx_ctx->aead = aead; + tx_ctx->bearer = b; + memcpy(&tx_ctx->dst, dst, sizeof(*dst)); + + /* Hold bearer */ + if (unlikely(!tipc_bearer_hold(b))) { + rc = -ENODEV; + goto exit; + } + + /* Now, do encrypt */ + rc = crypto_aead_encrypt(req); + if (rc == -EINPROGRESS || rc == -EBUSY) + return rc; + + tipc_bearer_put(b); + +exit: + kfree(ctx); + TIPC_SKB_CB(skb)->crypto_ctx = NULL; + return rc; +} + +static void tipc_aead_encrypt_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + struct tipc_crypto_tx_ctx *tx_ctx = TIPC_SKB_CB(skb)->crypto_ctx; + struct tipc_bearer *b = tx_ctx->bearer; + struct tipc_aead *aead = tx_ctx->aead; + struct tipc_crypto *tx = aead->crypto; + struct net *net = tx->net; + + switch (err) { + case 0: + this_cpu_inc(tx->stats->stat[STAT_ASYNC_OK]); + rcu_read_lock(); + if (likely(test_bit(0, &b->up))) + b->media->send_msg(net, skb, b, &tx_ctx->dst); + else + kfree_skb(skb); + rcu_read_unlock(); + break; + case -EINPROGRESS: + return; + default: + this_cpu_inc(tx->stats->stat[STAT_ASYNC_NOK]); + kfree_skb(skb); + break; + } + + kfree(tx_ctx); + tipc_bearer_put(b); + tipc_aead_put(aead); +} + +/** + * tipc_aead_decrypt - Decrypt an encrypted message + * @net: struct net + * @aead: TIPC AEAD for the message decryption + * @skb: the input/output skb + * @b: TIPC bearer where the message has been received + * + * Return: + * * 0 : if the decryption has completed + * * -EINPROGRESS/-EBUSY : if a callback will be performed + * * < 0 : the decryption has failed + */ +static int tipc_aead_decrypt(struct net *net, struct tipc_aead *aead, + struct sk_buff 
*skb, struct tipc_bearer *b) +{ + struct tipc_crypto_rx_ctx *rx_ctx; + struct aead_request *req; + struct crypto_aead *tfm; + struct sk_buff *unused; + struct scatterlist *sg; + struct tipc_ehdr *ehdr; + int ehsz, nsg, rc; + void *ctx; + u32 salt; + u8 *iv; + + if (unlikely(!aead)) + return -ENOKEY; + + nsg = skb_cow_data(skb, 0, &unused); + if (unlikely(nsg < 0)) { + pr_err("RX: skb_cow_data() returned %d\n", nsg); + return nsg; + } + + /* Allocate memory for the AEAD operation */ + tfm = tipc_aead_tfm_next(aead); + ctx = tipc_aead_mem_alloc(tfm, sizeof(*rx_ctx), &iv, &req, &sg, nsg); + if (unlikely(!ctx)) + return -ENOMEM; + TIPC_SKB_CB(skb)->crypto_ctx = ctx; + + /* Map skb to the sg lists */ + sg_init_table(sg, nsg); + rc = skb_to_sgvec(skb, sg, 0, skb->len); + if (unlikely(rc < 0)) { + pr_err("RX: skb_to_sgvec() returned %d, nsg %d\n", rc, nsg); + goto exit; + } + + /* Reconstruct IV: */ + ehdr = (struct tipc_ehdr *)skb->data; + salt = aead->salt; + if (aead->mode == CLUSTER_KEY) + salt ^= __be32_to_cpu(ehdr->addr); + else if (ehdr->destined) + salt ^= tipc_own_addr(net); + memcpy(iv, &salt, 4); + memcpy(iv + 4, (u8 *)&ehdr->seqno, 8); + + /* Prepare request */ + ehsz = tipc_ehdr_size(ehdr); + aead_request_set_tfm(req, tfm); + aead_request_set_ad(req, ehsz); + aead_request_set_crypt(req, sg, sg, skb->len - ehsz, iv); + + /* Set callback function & data */ + aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG, + tipc_aead_decrypt_done, skb); + rx_ctx = (struct tipc_crypto_rx_ctx *)ctx; + rx_ctx->aead = aead; + rx_ctx->bearer = b; + + /* Hold bearer */ + if (unlikely(!tipc_bearer_hold(b))) { + rc = -ENODEV; + goto exit; + } + + /* Now, do decrypt */ + rc = crypto_aead_decrypt(req); + if (rc == -EINPROGRESS || rc == -EBUSY) + return rc; + + tipc_bearer_put(b); + +exit: + kfree(ctx); + TIPC_SKB_CB(skb)->crypto_ctx = NULL; + return rc; +} + +static void tipc_aead_decrypt_done(struct crypto_async_request *base, int err) +{ + struct sk_buff *skb = base->data; + struct tipc_crypto_rx_ctx *rx_ctx = TIPC_SKB_CB(skb)->crypto_ctx; + struct tipc_bearer *b = rx_ctx->bearer; + struct tipc_aead *aead = rx_ctx->aead; + struct tipc_crypto_stats __percpu *stats = aead->crypto->stats; + struct net *net = aead->crypto->net; + + switch (err) { + case 0: + this_cpu_inc(stats->stat[STAT_ASYNC_OK]); + break; + case -EINPROGRESS: + return; + default: + this_cpu_inc(stats->stat[STAT_ASYNC_NOK]); + break; + } + + kfree(rx_ctx); + tipc_crypto_rcv_complete(net, aead, b, &skb, err); + if (likely(skb)) { + if (likely(test_bit(0, &b->up))) + tipc_rcv(net, skb, b); + else + kfree_skb(skb); + } + + tipc_bearer_put(b); +} + +static inline int tipc_ehdr_size(struct tipc_ehdr *ehdr) +{ + return (ehdr->user != LINK_CONFIG) ? 
EHDR_SIZE : EHDR_CFG_SIZE; +} + +/** + * tipc_ehdr_validate - Validate an encryption message + * @skb: the message buffer + * + * Return: "true" if this is a valid encryption message, otherwise "false" + */ +bool tipc_ehdr_validate(struct sk_buff *skb) +{ + struct tipc_ehdr *ehdr; + int ehsz; + + if (unlikely(!pskb_may_pull(skb, EHDR_MIN_SIZE))) + return false; + + ehdr = (struct tipc_ehdr *)skb->data; + if (unlikely(ehdr->version != TIPC_EVERSION)) + return false; + ehsz = tipc_ehdr_size(ehdr); + if (unlikely(!pskb_may_pull(skb, ehsz))) + return false; + if (unlikely(skb->len <= ehsz + TIPC_AES_GCM_TAG_SIZE)) + return false; + + return true; +} + +/** + * tipc_ehdr_build - Build TIPC encryption message header + * @net: struct net + * @aead: TX AEAD key to be used for the message encryption + * @tx_key: key id used for the message encryption + * @skb: input/output message skb + * @__rx: RX crypto handle if dest is "known" + * + * Return: the header size if the building is successful, otherwise < 0 + */ +static int tipc_ehdr_build(struct net *net, struct tipc_aead *aead, + u8 tx_key, struct sk_buff *skb, + struct tipc_crypto *__rx) +{ + struct tipc_msg *hdr = buf_msg(skb); + struct tipc_ehdr *ehdr; + u32 user = msg_user(hdr); + u64 seqno; + int ehsz; + + /* Make room for encryption header */ + ehsz = (user != LINK_CONFIG) ? EHDR_SIZE : EHDR_CFG_SIZE; + WARN_ON(skb_headroom(skb) < ehsz); + ehdr = (struct tipc_ehdr *)skb_push(skb, ehsz); + + /* Obtain a seqno first: + * Use the key seqno (= cluster wise) if dest is unknown or we're in + * cluster key mode, otherwise it's better for a per-peer seqno! + */ + if (!__rx || aead->mode == CLUSTER_KEY) + seqno = atomic64_inc_return(&aead->seqno); + else + seqno = atomic64_inc_return(&__rx->sndnxt); + + /* Revoke the key if seqno is wrapped around */ + if (unlikely(!seqno)) + return tipc_crypto_key_revoke(net, tx_key); + + /* Word 1-2 */ + ehdr->seqno = cpu_to_be64(seqno); + + /* Words 0, 3- */ + ehdr->version = TIPC_EVERSION; + ehdr->user = 0; + ehdr->keepalive = 0; + ehdr->tx_key = tx_key; + ehdr->destined = (__rx) ? 1 : 0; + ehdr->rx_key_active = (__rx) ? __rx->key.active : 0; + ehdr->rx_nokey = (__rx) ? __rx->nokey : 0; + ehdr->master_key = aead->crypto->key_master; + ehdr->reserved_1 = 0; + ehdr->reserved_2 = 0; + + switch (user) { + case LINK_CONFIG: + ehdr->user = LINK_CONFIG; + memcpy(ehdr->id, tipc_own_id(net), NODE_ID_LEN); + break; + default: + if (user == LINK_PROTOCOL && msg_type(hdr) == STATE_MSG) { + ehdr->user = LINK_PROTOCOL; + ehdr->keepalive = msg_is_keepalive(hdr); + } + ehdr->addr = hdr->hdr[3]; + break; + } + + return ehsz; +} + +static inline void tipc_crypto_key_set_state(struct tipc_crypto *c, + u8 new_passive, + u8 new_active, + u8 new_pending) +{ + struct tipc_key old = c->key; + char buf[32]; + + c->key.keys = ((new_passive & KEY_MASK) << (KEY_BITS * 2)) | + ((new_active & KEY_MASK) << (KEY_BITS)) | + ((new_pending & KEY_MASK)); + + pr_debug("%s: key changing %s ::%pS\n", c->name, + tipc_key_change_dump(old, c->key, buf), + __builtin_return_address(0)); +} + +/** + * tipc_crypto_key_init - Initiate a new user / AEAD key + * @c: TIPC crypto to which new key is attached + * @ukey: the user key + * @mode: the key mode (CLUSTER_KEY or PER_NODE_KEY) + * @master_key: specify this is a cluster master key + * + * A new TIPC AEAD key will be allocated and initiated with the specified user + * key, then attached to the TIPC crypto. 
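+ * If attaching the key fails (e.g. all usable key slots are occupied), the
+ * freshly built AEAD is released again via tipc_aead_free() before the
+ * error is returned.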
+ * + * Return: new key id in case of success, otherwise: < 0 + */ +int tipc_crypto_key_init(struct tipc_crypto *c, struct tipc_aead_key *ukey, + u8 mode, bool master_key) +{ + struct tipc_aead *aead = NULL; + int rc = 0; + + /* Initiate with the new user key */ + rc = tipc_aead_init(&aead, ukey, mode); + + /* Attach it to the crypto */ + if (likely(!rc)) { + rc = tipc_crypto_key_attach(c, aead, 0, master_key); + if (rc < 0) + tipc_aead_free(&aead->rcu); + } + + return rc; +} + +/** + * tipc_crypto_key_attach - Attach a new AEAD key to TIPC crypto + * @c: TIPC crypto to which the new AEAD key is attached + * @aead: the new AEAD key pointer + * @pos: desired slot in the crypto key array, = 0 if any! + * @master_key: specify this is a cluster master key + * + * Return: new key id in case of success, otherwise: -EBUSY + */ +static int tipc_crypto_key_attach(struct tipc_crypto *c, + struct tipc_aead *aead, u8 pos, + bool master_key) +{ + struct tipc_key key; + int rc = -EBUSY; + u8 new_key; + + spin_lock_bh(&c->lock); + key = c->key; + if (master_key) { + new_key = KEY_MASTER; + goto attach; + } + if (key.active && key.passive) + goto exit; + if (key.pending) { + if (tipc_aead_users(c->aead[key.pending]) > 0) + goto exit; + /* if (pos): ok with replacing, will be aligned when needed */ + /* Replace it */ + new_key = key.pending; + } else { + if (pos) { + if (key.active && pos != key_next(key.active)) { + key.passive = pos; + new_key = pos; + goto attach; + } else if (!key.active && !key.passive) { + key.pending = pos; + new_key = pos; + goto attach; + } + } + key.pending = key_next(key.active ?: key.passive); + new_key = key.pending; + } + +attach: + aead->crypto = c; + aead->gen = (is_tx(c)) ? ++c->key_gen : c->key_gen; + tipc_aead_rcu_replace(c->aead[new_key], aead, &c->lock); + if (likely(c->key.keys != key.keys)) + tipc_crypto_key_set_state(c, key.passive, key.active, + key.pending); + c->working = 1; + c->nokey = 0; + c->key_master |= master_key; + rc = new_key; + +exit: + spin_unlock_bh(&c->lock); + return rc; +} + +void tipc_crypto_key_flush(struct tipc_crypto *c) +{ + struct tipc_crypto *tx, *rx; + int k; + + spin_lock_bh(&c->lock); + if (is_rx(c)) { + /* Try to cancel pending work */ + rx = c; + tx = tipc_net(rx->net)->crypto_tx; + if (cancel_delayed_work(&rx->work)) { + kfree(rx->skey); + rx->skey = NULL; + atomic_xchg(&rx->key_distr, 0); + tipc_node_put(rx->node); + } + /* RX stopping => decrease TX key users if any */ + k = atomic_xchg(&rx->peer_rx_active, 0); + if (k) { + tipc_aead_users_dec(tx->aead[k], 0); + /* Mark the point TX key users changed */ + tx->timer1 = jiffies; + } + } + + c->flags = 0; + tipc_crypto_key_set_state(c, 0, 0, 0); + for (k = KEY_MIN; k <= KEY_MAX; k++) + tipc_crypto_key_detach(c->aead[k], &c->lock); + atomic64_set(&c->sndnxt, 0); + spin_unlock_bh(&c->lock); +} + +/** + * tipc_crypto_key_try_align - Align RX keys if possible + * @rx: RX crypto handle + * @new_pending: new pending slot if aligned (= TX key from peer) + * + * Peer has used an unknown key slot, this only happens when peer has left and + * rejoned, or we are newcomer. + * That means, there must be no active key but a pending key at unaligned slot. + * If so, we try to move the pending key to the new slot. + * Note: A potential passive key can exist, it will be shifted correspondingly! 
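+ * Example: with pending = 1, passive = 2 and the peer using slot 3, the
+ * pending key is moved to slot 3 and the passive key is shifted by the
+ * same distance, i.e. new_passive = (2 - 1 + 3) % KEY_MAX = 1.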
+ * + * Return: "true" if key is successfully aligned, otherwise "false" + */ +static bool tipc_crypto_key_try_align(struct tipc_crypto *rx, u8 new_pending) +{ + struct tipc_aead *tmp1, *tmp2 = NULL; + struct tipc_key key; + bool aligned = false; + u8 new_passive = 0; + int x; + + spin_lock(&rx->lock); + key = rx->key; + if (key.pending == new_pending) { + aligned = true; + goto exit; + } + if (key.active) + goto exit; + if (!key.pending) + goto exit; + if (tipc_aead_users(rx->aead[key.pending]) > 0) + goto exit; + + /* Try to "isolate" this pending key first */ + tmp1 = tipc_aead_rcu_ptr(rx->aead[key.pending], &rx->lock); + if (!refcount_dec_if_one(&tmp1->refcnt)) + goto exit; + rcu_assign_pointer(rx->aead[key.pending], NULL); + + /* Move passive key if any */ + if (key.passive) { + tmp2 = rcu_replace_pointer(rx->aead[key.passive], tmp2, lockdep_is_held(&rx->lock)); + x = (key.passive - key.pending + new_pending) % KEY_MAX; + new_passive = (x <= 0) ? x + KEY_MAX : x; + } + + /* Re-allocate the key(s) */ + tipc_crypto_key_set_state(rx, new_passive, 0, new_pending); + rcu_assign_pointer(rx->aead[new_pending], tmp1); + if (new_passive) + rcu_assign_pointer(rx->aead[new_passive], tmp2); + refcount_set(&tmp1->refcnt, 1); + aligned = true; + pr_info_ratelimited("%s: key[%d] -> key[%d]\n", rx->name, key.pending, + new_pending); + +exit: + spin_unlock(&rx->lock); + return aligned; +} + +/** + * tipc_crypto_key_pick_tx - Pick one TX key for message decryption + * @tx: TX crypto handle + * @rx: RX crypto handle (can be NULL) + * @skb: the message skb which will be decrypted later + * @tx_key: peer TX key id + * + * This function looks up the existing TX keys and pick one which is suitable + * for the message decryption, that must be a cluster key and not used before + * on the same message (i.e. recursive). + * + * Return: the TX AEAD key handle in case of success, otherwise NULL + */ +static struct tipc_aead *tipc_crypto_key_pick_tx(struct tipc_crypto *tx, + struct tipc_crypto *rx, + struct sk_buff *skb, + u8 tx_key) +{ + struct tipc_skb_cb *skb_cb = TIPC_SKB_CB(skb); + struct tipc_aead *aead = NULL; + struct tipc_key key = tx->key; + u8 k, i = 0; + + /* Initialize data if not yet */ + if (!skb_cb->tx_clone_deferred) { + skb_cb->tx_clone_deferred = 1; + memset(&skb_cb->tx_clone_ctx, 0, sizeof(skb_cb->tx_clone_ctx)); + } + + skb_cb->tx_clone_ctx.rx = rx; + if (++skb_cb->tx_clone_ctx.recurs > 2) + return NULL; + + /* Pick one TX key */ + spin_lock(&tx->lock); + if (tx_key == KEY_MASTER) { + aead = tipc_aead_rcu_ptr(tx->aead[KEY_MASTER], &tx->lock); + goto done; + } + do { + k = (i == 0) ? key.pending : + ((i == 1) ? key.active : key.passive); + if (!k) + continue; + aead = tipc_aead_rcu_ptr(tx->aead[k], &tx->lock); + if (!aead) + continue; + if (aead->mode != CLUSTER_KEY || + aead == skb_cb->tx_clone_ctx.last) { + aead = NULL; + continue; + } + /* Ok, found one cluster key */ + skb_cb->tx_clone_ctx.last = aead; + WARN_ON(skb->next); + skb->next = skb_clone(skb, GFP_ATOMIC); + if (unlikely(!skb->next)) + pr_warn("Failed to clone skb for next round if any\n"); + break; + } while (++i < 3); + +done: + if (likely(aead)) + WARN_ON(!refcount_inc_not_zero(&aead->refcnt)); + spin_unlock(&tx->lock); + + return aead; +} + +/** + * tipc_crypto_key_synch: Synch own key data according to peer key status + * @rx: RX crypto handle + * @skb: TIPCv2 message buffer (incl. 
the ehdr from peer) + * + * This function updates the peer node related data as the peer RX active key + * has changed, so the number of TX keys' users on this node are increased and + * decreased correspondingly. + * + * It also considers if peer has no key, then we need to make own master key + * (if any) taking over i.e. starting grace period and also trigger key + * distributing process. + * + * The "per-peer" sndnxt is also reset when the peer key has switched. + */ +static void tipc_crypto_key_synch(struct tipc_crypto *rx, struct sk_buff *skb) +{ + struct tipc_ehdr *ehdr = (struct tipc_ehdr *)skb_network_header(skb); + struct tipc_crypto *tx = tipc_net(rx->net)->crypto_tx; + struct tipc_msg *hdr = buf_msg(skb); + u32 self = tipc_own_addr(rx->net); + u8 cur, new; + unsigned long delay; + + /* Update RX 'key_master' flag according to peer, also mark "legacy" if + * a peer has no master key. + */ + rx->key_master = ehdr->master_key; + if (!rx->key_master) + tx->legacy_user = 1; + + /* For later cases, apply only if message is destined to this node */ + if (!ehdr->destined || msg_short(hdr) || msg_destnode(hdr) != self) + return; + + /* Case 1: Peer has no keys, let's make master key take over */ + if (ehdr->rx_nokey) { + /* Set or extend grace period */ + tx->timer2 = jiffies; + /* Schedule key distributing for the peer if not yet */ + if (tx->key.keys && + !atomic_cmpxchg(&rx->key_distr, 0, KEY_DISTR_SCHED)) { + get_random_bytes(&delay, 2); + delay %= 5; + delay = msecs_to_jiffies(500 * ++delay); + if (queue_delayed_work(tx->wq, &rx->work, delay)) + tipc_node_get(rx->node); + } + } else { + /* Cancel a pending key distributing if any */ + atomic_xchg(&rx->key_distr, 0); + } + + /* Case 2: Peer RX active key has changed, let's update own TX users */ + cur = atomic_read(&rx->peer_rx_active); + new = ehdr->rx_key_active; + if (tx->key.keys && + cur != new && + atomic_cmpxchg(&rx->peer_rx_active, cur, new) == cur) { + if (new) + tipc_aead_users_inc(tx->aead[new], INT_MAX); + if (cur) + tipc_aead_users_dec(tx->aead[cur], 0); + + atomic64_set(&rx->sndnxt, 0); + /* Mark the point TX key users changed */ + tx->timer1 = jiffies; + + pr_debug("%s: key users changed %d-- %d++, peer %s\n", + tx->name, cur, new, rx->name); + } +} + +static int tipc_crypto_key_revoke(struct net *net, u8 tx_key) +{ + struct tipc_crypto *tx = tipc_net(net)->crypto_tx; + struct tipc_key key; + + spin_lock_bh(&tx->lock); + key = tx->key; + WARN_ON(!key.active || tx_key != key.active); + + /* Free the active key */ + tipc_crypto_key_set_state(tx, key.passive, 0, key.pending); + tipc_crypto_key_detach(tx->aead[key.active], &tx->lock); + spin_unlock_bh(&tx->lock); + + pr_warn("%s: key is revoked\n", tx->name); + return -EKEYREVOKED; +} + +int tipc_crypto_start(struct tipc_crypto **crypto, struct net *net, + struct tipc_node *node) +{ + struct tipc_crypto *c; + + if (*crypto) + return -EEXIST; + + /* Allocate crypto */ + c = kzalloc(sizeof(*c), GFP_ATOMIC); + if (!c) + return -ENOMEM; + + /* Allocate workqueue on TX */ + if (!node) { + c->wq = alloc_ordered_workqueue("tipc_crypto", 0); + if (!c->wq) { + kfree(c); + return -ENOMEM; + } + } + + /* Allocate statistic structure */ + c->stats = alloc_percpu_gfp(struct tipc_crypto_stats, GFP_ATOMIC); + if (!c->stats) { + if (c->wq) + destroy_workqueue(c->wq); + kfree_sensitive(c); + return -ENOMEM; + } + + c->flags = 0; + c->net = net; + c->node = node; + get_random_bytes(&c->key_gen, 2); + tipc_crypto_key_set_state(c, 0, 0, 0); + atomic_set(&c->key_distr, 0); + 
atomic_set(&c->peer_rx_active, 0); + atomic64_set(&c->sndnxt, 0); + c->timer1 = jiffies; + c->timer2 = jiffies; + c->rekeying_intv = TIPC_REKEYING_INTV_DEF; + spin_lock_init(&c->lock); + scnprintf(c->name, 48, "%s(%s)", (is_rx(c)) ? "RX" : "TX", + (is_rx(c)) ? tipc_node_get_id_str(c->node) : + tipc_own_id_string(c->net)); + + if (is_rx(c)) + INIT_DELAYED_WORK(&c->work, tipc_crypto_work_rx); + else + INIT_DELAYED_WORK(&c->work, tipc_crypto_work_tx); + + *crypto = c; + return 0; +} + +void tipc_crypto_stop(struct tipc_crypto **crypto) +{ + struct tipc_crypto *c = *crypto; + u8 k; + + if (!c) + return; + + /* Flush any queued works & destroy wq */ + if (is_tx(c)) { + c->rekeying_intv = 0; + cancel_delayed_work_sync(&c->work); + destroy_workqueue(c->wq); + } + + /* Release AEAD keys */ + rcu_read_lock(); + for (k = KEY_MIN; k <= KEY_MAX; k++) + tipc_aead_put(rcu_dereference(c->aead[k])); + rcu_read_unlock(); + pr_debug("%s: has been stopped\n", c->name); + + /* Free this crypto statistics */ + free_percpu(c->stats); + + *crypto = NULL; + kfree_sensitive(c); +} + +void tipc_crypto_timeout(struct tipc_crypto *rx) +{ + struct tipc_net *tn = tipc_net(rx->net); + struct tipc_crypto *tx = tn->crypto_tx; + struct tipc_key key; + int cmd; + + /* TX pending: taking all users & stable -> active */ + spin_lock(&tx->lock); + key = tx->key; + if (key.active && tipc_aead_users(tx->aead[key.active]) > 0) + goto s1; + if (!key.pending || tipc_aead_users(tx->aead[key.pending]) <= 0) + goto s1; + if (time_before(jiffies, tx->timer1 + TIPC_TX_LASTING_TIME)) + goto s1; + + tipc_crypto_key_set_state(tx, key.passive, key.pending, 0); + if (key.active) + tipc_crypto_key_detach(tx->aead[key.active], &tx->lock); + this_cpu_inc(tx->stats->stat[STAT_SWITCHES]); + pr_info("%s: key[%d] is activated\n", tx->name, key.pending); + +s1: + spin_unlock(&tx->lock); + + /* RX pending: having user -> active */ + spin_lock(&rx->lock); + key = rx->key; + if (!key.pending || tipc_aead_users(rx->aead[key.pending]) <= 0) + goto s2; + + if (key.active) + key.passive = key.active; + key.active = key.pending; + rx->timer2 = jiffies; + tipc_crypto_key_set_state(rx, key.passive, key.active, 0); + this_cpu_inc(rx->stats->stat[STAT_SWITCHES]); + pr_info("%s: key[%d] is activated\n", rx->name, key.pending); + goto s5; + +s2: + /* RX pending: not working -> remove */ + if (!key.pending || tipc_aead_users(rx->aead[key.pending]) > -10) + goto s3; + + tipc_crypto_key_set_state(rx, key.passive, key.active, 0); + tipc_crypto_key_detach(rx->aead[key.pending], &rx->lock); + pr_debug("%s: key[%d] is removed\n", rx->name, key.pending); + goto s5; + +s3: + /* RX active: timed out or no user -> pending */ + if (!key.active) + goto s4; + if (time_before(jiffies, rx->timer1 + TIPC_RX_ACTIVE_LIM) && + tipc_aead_users(rx->aead[key.active]) > 0) + goto s4; + + if (key.pending) + key.passive = key.active; + else + key.pending = key.active; + rx->timer2 = jiffies; + tipc_crypto_key_set_state(rx, key.passive, 0, key.pending); + tipc_aead_users_set(rx->aead[key.pending], 0); + pr_debug("%s: key[%d] is deactivated\n", rx->name, key.active); + goto s5; + +s4: + /* RX passive: outdated or not working -> free */ + if (!key.passive) + goto s5; + if (time_before(jiffies, rx->timer2 + TIPC_RX_PASSIVE_LIM) && + tipc_aead_users(rx->aead[key.passive]) > -10) + goto s5; + + tipc_crypto_key_set_state(rx, 0, key.active, key.pending); + tipc_crypto_key_detach(rx->aead[key.passive], &rx->lock); + pr_debug("%s: key[%d] is freed\n", rx->name, key.passive); + +s5: + 
spin_unlock(&rx->lock); + + /* Relax it here, the flag will be set again if it really is, but only + * when we are not in grace period for safety! + */ + if (time_after(jiffies, tx->timer2 + TIPC_TX_GRACE_PERIOD)) + tx->legacy_user = 0; + + /* Limit max_tfms & do debug commands if needed */ + if (likely(sysctl_tipc_max_tfms <= TIPC_MAX_TFMS_LIM)) + return; + + cmd = sysctl_tipc_max_tfms; + sysctl_tipc_max_tfms = TIPC_MAX_TFMS_DEF; + tipc_crypto_do_cmd(rx->net, cmd); +} + +static inline void tipc_crypto_clone_msg(struct net *net, struct sk_buff *_skb, + struct tipc_bearer *b, + struct tipc_media_addr *dst, + struct tipc_node *__dnode, u8 type) +{ + struct sk_buff *skb; + + skb = skb_clone(_skb, GFP_ATOMIC); + if (skb) { + TIPC_SKB_CB(skb)->xmit_type = type; + tipc_crypto_xmit(net, &skb, b, dst, __dnode); + if (skb) + b->media->send_msg(net, skb, b, dst); + } +} + +/** + * tipc_crypto_xmit - Build & encrypt TIPC message for xmit + * @net: struct net + * @skb: input/output message skb pointer + * @b: bearer used for xmit later + * @dst: destination media address + * @__dnode: destination node for reference if any + * + * First, build an encryption message header on the top of the message, then + * encrypt the original TIPC message by using the pending, master or active + * key with this preference order. + * If the encryption is successful, the encrypted skb is returned directly or + * via the callback. + * Otherwise, the skb is freed! + * + * Return: + * * 0 : the encryption has succeeded (or no encryption) + * * -EINPROGRESS/-EBUSY : the encryption is ongoing, a callback will be made + * * -ENOKEK : the encryption has failed due to no key + * * -EKEYREVOKED : the encryption has failed due to key revoked + * * -ENOMEM : the encryption has failed due to no memory + * * < 0 : the encryption has failed due to other reasons + */ +int tipc_crypto_xmit(struct net *net, struct sk_buff **skb, + struct tipc_bearer *b, struct tipc_media_addr *dst, + struct tipc_node *__dnode) +{ + struct tipc_crypto *__rx = tipc_node_crypto_rx(__dnode); + struct tipc_crypto *tx = tipc_net(net)->crypto_tx; + struct tipc_crypto_stats __percpu *stats = tx->stats; + struct tipc_msg *hdr = buf_msg(*skb); + struct tipc_key key = tx->key; + struct tipc_aead *aead = NULL; + u32 user = msg_user(hdr); + u32 type = msg_type(hdr); + int rc = -ENOKEY; + u8 tx_key = 0; + + /* No encryption? 
*/ + if (!tx->working) + return 0; + + /* Pending key if peer has active on it or probing time */ + if (unlikely(key.pending)) { + tx_key = key.pending; + if (!tx->key_master && !key.active) + goto encrypt; + if (__rx && atomic_read(&__rx->peer_rx_active) == tx_key) + goto encrypt; + if (TIPC_SKB_CB(*skb)->xmit_type == SKB_PROBING) { + pr_debug("%s: probing for key[%d]\n", tx->name, + key.pending); + goto encrypt; + } + if (user == LINK_CONFIG || user == LINK_PROTOCOL) + tipc_crypto_clone_msg(net, *skb, b, dst, __dnode, + SKB_PROBING); + } + + /* Master key if this is a *vital* message or in grace period */ + if (tx->key_master) { + tx_key = KEY_MASTER; + if (!key.active) + goto encrypt; + if (TIPC_SKB_CB(*skb)->xmit_type == SKB_GRACING) { + pr_debug("%s: gracing for msg (%d %d)\n", tx->name, + user, type); + goto encrypt; + } + if (user == LINK_CONFIG || + (user == LINK_PROTOCOL && type == RESET_MSG) || + (user == MSG_CRYPTO && type == KEY_DISTR_MSG) || + time_before(jiffies, tx->timer2 + TIPC_TX_GRACE_PERIOD)) { + if (__rx && __rx->key_master && + !atomic_read(&__rx->peer_rx_active)) + goto encrypt; + if (!__rx) { + if (likely(!tx->legacy_user)) + goto encrypt; + tipc_crypto_clone_msg(net, *skb, b, dst, + __dnode, SKB_GRACING); + } + } + } + + /* Else, use the active key if any */ + if (likely(key.active)) { + tx_key = key.active; + goto encrypt; + } + + goto exit; + +encrypt: + aead = tipc_aead_get(tx->aead[tx_key]); + if (unlikely(!aead)) + goto exit; + rc = tipc_ehdr_build(net, aead, tx_key, *skb, __rx); + if (likely(rc > 0)) + rc = tipc_aead_encrypt(aead, *skb, b, dst, __dnode); + +exit: + switch (rc) { + case 0: + this_cpu_inc(stats->stat[STAT_OK]); + break; + case -EINPROGRESS: + case -EBUSY: + this_cpu_inc(stats->stat[STAT_ASYNC]); + *skb = NULL; + return rc; + default: + this_cpu_inc(stats->stat[STAT_NOK]); + if (rc == -ENOKEY) + this_cpu_inc(stats->stat[STAT_NOKEYS]); + else if (rc == -EKEYREVOKED) + this_cpu_inc(stats->stat[STAT_BADKEYS]); + kfree_skb(*skb); + *skb = NULL; + break; + } + + tipc_aead_put(aead); + return rc; +} + +/** + * tipc_crypto_rcv - Decrypt an encrypted TIPC message from peer + * @net: struct net + * @rx: RX crypto handle + * @skb: input/output message skb pointer + * @b: bearer where the message has been received + * + * If the decryption is successful, the decrypted skb is returned directly or + * as the callback, the encryption header and auth tag will be trimed out + * before forwarding to tipc_rcv() via the tipc_crypto_rcv_complete(). + * Otherwise, the skb will be freed! + * Note: RX key(s) can be re-aligned, or in case of no key suitable, TX + * cluster key(s) can be taken for decryption (- recursive). + * + * Return: + * * 0 : the decryption has successfully completed + * * -EINPROGRESS/-EBUSY : the decryption is ongoing, a callback will be made + * * -ENOKEY : the decryption has failed due to no key + * * -EBADMSG : the decryption has failed due to bad message + * * -ENOMEM : the decryption has failed due to no memory + * * < 0 : the decryption has failed due to other reasons + */ +int tipc_crypto_rcv(struct net *net, struct tipc_crypto *rx, + struct sk_buff **skb, struct tipc_bearer *b) +{ + struct tipc_crypto *tx = tipc_net(net)->crypto_tx; + struct tipc_crypto_stats __percpu *stats; + struct tipc_aead *aead = NULL; + struct tipc_key key; + int rc = -ENOKEY; + u8 tx_key, n; + + tx_key = ((struct tipc_ehdr *)(*skb)->data)->tx_key; + + /* New peer? + * Let's try with TX key (i.e. cluster mode) & verify the skb first! 
+ */ + if (unlikely(!rx || tx_key == KEY_MASTER)) + goto pick_tx; + + /* Pick RX key according to TX key if any */ + key = rx->key; + if (tx_key == key.active || tx_key == key.pending || + tx_key == key.passive) + goto decrypt; + + /* Unknown key, let's try to align RX key(s) */ + if (tipc_crypto_key_try_align(rx, tx_key)) + goto decrypt; + +pick_tx: + /* No key suitable? Try to pick one from TX... */ + aead = tipc_crypto_key_pick_tx(tx, rx, *skb, tx_key); + if (aead) + goto decrypt; + goto exit; + +decrypt: + rcu_read_lock(); + if (!aead) + aead = tipc_aead_get(rx->aead[tx_key]); + rc = tipc_aead_decrypt(net, aead, *skb, b); + rcu_read_unlock(); + +exit: + stats = ((rx) ?: tx)->stats; + switch (rc) { + case 0: + this_cpu_inc(stats->stat[STAT_OK]); + break; + case -EINPROGRESS: + case -EBUSY: + this_cpu_inc(stats->stat[STAT_ASYNC]); + *skb = NULL; + return rc; + default: + this_cpu_inc(stats->stat[STAT_NOK]); + if (rc == -ENOKEY) { + kfree_skb(*skb); + *skb = NULL; + if (rx) { + /* Mark rx->nokey only if we dont have a + * pending received session key, nor a newer + * one i.e. in the next slot. + */ + n = key_next(tx_key); + rx->nokey = !(rx->skey || + rcu_access_pointer(rx->aead[n])); + pr_debug_ratelimited("%s: nokey %d, key %d/%x\n", + rx->name, rx->nokey, + tx_key, rx->key.keys); + tipc_node_put(rx->node); + } + this_cpu_inc(stats->stat[STAT_NOKEYS]); + return rc; + } else if (rc == -EBADMSG) { + this_cpu_inc(stats->stat[STAT_BADMSGS]); + } + break; + } + + tipc_crypto_rcv_complete(net, aead, b, skb, rc); + return rc; +} + +static void tipc_crypto_rcv_complete(struct net *net, struct tipc_aead *aead, + struct tipc_bearer *b, + struct sk_buff **skb, int err) +{ + struct tipc_skb_cb *skb_cb = TIPC_SKB_CB(*skb); + struct tipc_crypto *rx = aead->crypto; + struct tipc_aead *tmp = NULL; + struct tipc_ehdr *ehdr; + struct tipc_node *n; + + /* Is this completed by TX? */ + if (unlikely(is_tx(aead->crypto))) { + rx = skb_cb->tx_clone_ctx.rx; + pr_debug("TX->RX(%s): err %d, aead %p, skb->next %p, flags %x\n", + (rx) ? tipc_node_get_id_str(rx->node) : "-", err, aead, + (*skb)->next, skb_cb->flags); + pr_debug("skb_cb [recurs %d, last %p], tx->aead [%p %p %p]\n", + skb_cb->tx_clone_ctx.recurs, skb_cb->tx_clone_ctx.last, + aead->crypto->aead[1], aead->crypto->aead[2], + aead->crypto->aead[3]); + if (unlikely(err)) { + if (err == -EBADMSG && (*skb)->next) + tipc_rcv(net, (*skb)->next, b); + goto free_skb; + } + + if (likely((*skb)->next)) { + kfree_skb((*skb)->next); + (*skb)->next = NULL; + } + ehdr = (struct tipc_ehdr *)(*skb)->data; + if (!rx) { + WARN_ON(ehdr->user != LINK_CONFIG); + n = tipc_node_create(net, 0, ehdr->id, 0xffffu, 0, + true); + rx = tipc_node_crypto_rx(n); + if (unlikely(!rx)) + goto free_skb; + } + + /* Ignore cloning if it was TX master key */ + if (ehdr->tx_key == KEY_MASTER) + goto rcv; + if (tipc_aead_clone(&tmp, aead) < 0) + goto rcv; + WARN_ON(!refcount_inc_not_zero(&tmp->refcnt)); + if (tipc_crypto_key_attach(rx, tmp, ehdr->tx_key, false) < 0) { + tipc_aead_free(&tmp->rcu); + goto rcv; + } + tipc_aead_put(aead); + aead = tmp; + } + + if (unlikely(err)) { + tipc_aead_users_dec((struct tipc_aead __force __rcu *)aead, INT_MIN); + goto free_skb; + } + + /* Set the RX key's user */ + tipc_aead_users_set((struct tipc_aead __force __rcu *)aead, 1); + + /* Mark this point, RX works */ + rx->timer1 = jiffies; + +rcv: + /* Remove ehdr & auth. 
tag prior to tipc_rcv() */ + ehdr = (struct tipc_ehdr *)(*skb)->data; + + /* Mark this point, RX passive still works */ + if (rx->key.passive && ehdr->tx_key == rx->key.passive) + rx->timer2 = jiffies; + + skb_reset_network_header(*skb); + skb_pull(*skb, tipc_ehdr_size(ehdr)); + if (pskb_trim(*skb, (*skb)->len - aead->authsize)) + goto free_skb; + + /* Validate TIPCv2 message */ + if (unlikely(!tipc_msg_validate(skb))) { + pr_err_ratelimited("Packet dropped after decryption!\n"); + goto free_skb; + } + + /* Ok, everything's fine, try to synch own keys according to peers' */ + tipc_crypto_key_synch(rx, *skb); + + /* Re-fetch skb cb as skb might be changed in tipc_msg_validate */ + skb_cb = TIPC_SKB_CB(*skb); + + /* Mark skb decrypted */ + skb_cb->decrypted = 1; + + /* Clear clone cxt if any */ + if (likely(!skb_cb->tx_clone_deferred)) + goto exit; + skb_cb->tx_clone_deferred = 0; + memset(&skb_cb->tx_clone_ctx, 0, sizeof(skb_cb->tx_clone_ctx)); + goto exit; + +free_skb: + kfree_skb(*skb); + *skb = NULL; + +exit: + tipc_aead_put(aead); + if (rx) + tipc_node_put(rx->node); +} + +static void tipc_crypto_do_cmd(struct net *net, int cmd) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_crypto *tx = tn->crypto_tx, *rx; + struct list_head *p; + unsigned int stat; + int i, j, cpu; + char buf[200]; + + /* Currently only one command is supported */ + switch (cmd) { + case 0xfff1: + goto print_stats; + default: + return; + } + +print_stats: + /* Print a header */ + pr_info("\n=============== TIPC Crypto Statistics ===============\n\n"); + + /* Print key status */ + pr_info("Key status:\n"); + pr_info("TX(%7.7s)\n%s", tipc_own_id_string(net), + tipc_crypto_key_dump(tx, buf)); + + rcu_read_lock(); + for (p = tn->node_list.next; p != &tn->node_list; p = p->next) { + rx = tipc_node_crypto_rx_by_list(p); + pr_info("RX(%7.7s)\n%s", tipc_node_get_id_str(rx->node), + tipc_crypto_key_dump(rx, buf)); + } + rcu_read_unlock(); + + /* Print crypto statistics */ + for (i = 0, j = 0; i < MAX_STATS; i++) + j += scnprintf(buf + j, 200 - j, "|%11s ", hstats[i]); + pr_info("Counter %s", buf); + + memset(buf, '-', 115); + buf[115] = '\0'; + pr_info("%s\n", buf); + + j = scnprintf(buf, 200, "TX(%7.7s) ", tipc_own_id_string(net)); + for_each_possible_cpu(cpu) { + for (i = 0; i < MAX_STATS; i++) { + stat = per_cpu_ptr(tx->stats, cpu)->stat[i]; + j += scnprintf(buf + j, 200 - j, "|%11d ", stat); + } + pr_info("%s", buf); + j = scnprintf(buf, 200, "%12s", " "); + } + + rcu_read_lock(); + for (p = tn->node_list.next; p != &tn->node_list; p = p->next) { + rx = tipc_node_crypto_rx_by_list(p); + j = scnprintf(buf, 200, "RX(%7.7s) ", + tipc_node_get_id_str(rx->node)); + for_each_possible_cpu(cpu) { + for (i = 0; i < MAX_STATS; i++) { + stat = per_cpu_ptr(rx->stats, cpu)->stat[i]; + j += scnprintf(buf + j, 200 - j, "|%11d ", + stat); + } + pr_info("%s", buf); + j = scnprintf(buf, 200, "%12s", " "); + } + } + rcu_read_unlock(); + + pr_info("\n======================== Done ========================\n"); +} + +static char *tipc_crypto_key_dump(struct tipc_crypto *c, char *buf) +{ + struct tipc_key key = c->key; + struct tipc_aead *aead; + int k, i = 0; + char *s; + + for (k = KEY_MIN; k <= KEY_MAX; k++) { + if (k == KEY_MASTER) { + if (is_rx(c)) + continue; + if (time_before(jiffies, + c->timer2 + TIPC_TX_GRACE_PERIOD)) + s = "ACT"; + else + s = "PAS"; + } else { + if (k == key.passive) + s = "PAS"; + else if (k == key.active) + s = "ACT"; + else if (k == key.pending) + s = "PEN"; + else + s = "-"; + } + i += scnprintf(buf + i, 
200 - i, "\tKey%d: %s", k, s); + + rcu_read_lock(); + aead = rcu_dereference(c->aead[k]); + if (aead) + i += scnprintf(buf + i, 200 - i, + "{\"0x...%s\", \"%s\"}/%d:%d", + aead->hint, + (aead->mode == CLUSTER_KEY) ? "c" : "p", + atomic_read(&aead->users), + refcount_read(&aead->refcnt)); + rcu_read_unlock(); + i += scnprintf(buf + i, 200 - i, "\n"); + } + + if (is_rx(c)) + i += scnprintf(buf + i, 200 - i, "\tPeer RX active: %d\n", + atomic_read(&c->peer_rx_active)); + + return buf; +} + +static char *tipc_key_change_dump(struct tipc_key old, struct tipc_key new, + char *buf) +{ + struct tipc_key *key = &old; + int k, i = 0; + char *s; + + /* Output format: "[%s %s %s] -> [%s %s %s]", max len = 32 */ +again: + i += scnprintf(buf + i, 32 - i, "["); + for (k = KEY_1; k <= KEY_3; k++) { + if (k == key->passive) + s = "pas"; + else if (k == key->active) + s = "act"; + else if (k == key->pending) + s = "pen"; + else + s = "-"; + i += scnprintf(buf + i, 32 - i, + (k != KEY_3) ? "%s " : "%s", s); + } + if (key != &new) { + i += scnprintf(buf + i, 32 - i, "] -> "); + key = &new; + goto again; + } + i += scnprintf(buf + i, 32 - i, "]"); + return buf; +} + +/** + * tipc_crypto_msg_rcv - Common 'MSG_CRYPTO' processing point + * @net: the struct net + * @skb: the receiving message buffer + */ +void tipc_crypto_msg_rcv(struct net *net, struct sk_buff *skb) +{ + struct tipc_crypto *rx; + struct tipc_msg *hdr; + + if (unlikely(skb_linearize(skb))) + goto exit; + + hdr = buf_msg(skb); + rx = tipc_node_crypto_rx_by_addr(net, msg_prevnode(hdr)); + if (unlikely(!rx)) + goto exit; + + switch (msg_type(hdr)) { + case KEY_DISTR_MSG: + if (tipc_crypto_key_rcv(rx, hdr)) + goto exit; + break; + default: + break; + } + + tipc_node_put(rx->node); + +exit: + kfree_skb(skb); +} + +/** + * tipc_crypto_key_distr - Distribute a TX key + * @tx: the TX crypto + * @key: the key's index + * @dest: the destination tipc node, = NULL if distributing to all nodes + * + * Return: 0 in case of success, otherwise < 0 + */ +int tipc_crypto_key_distr(struct tipc_crypto *tx, u8 key, + struct tipc_node *dest) +{ + struct tipc_aead *aead; + u32 dnode = tipc_node_get_addr(dest); + int rc = -ENOKEY; + + if (!sysctl_tipc_key_exchange_enabled) + return 0; + + if (key) { + rcu_read_lock(); + aead = tipc_aead_get(tx->aead[key]); + if (likely(aead)) { + rc = tipc_crypto_key_xmit(tx->net, aead->key, + aead->gen, aead->mode, + dnode); + tipc_aead_put(aead); + } + rcu_read_unlock(); + } + + return rc; +} + +/** + * tipc_crypto_key_xmit - Send a session key + * @net: the struct net + * @skey: the session key to be sent + * @gen: the key's generation + * @mode: the key's mode + * @dnode: the destination node address, = 0 if broadcasting to all nodes + * + * The session key 'skey' is packed in a TIPC v2 'MSG_CRYPTO/KEY_DISTR_MSG' + * as its data section, then xmit-ed through the uc/bc link. 
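The data section described here, and built by tipc_crypto_key_xmit() just below, can be sketched as a plain userspace packing helper. TIPC_AEAD_ALG_NAME = 32 and the "gcm(aes)" algorithm name are assumptions drawn from common TIPC usage rather than from this hunk.

/* Sketch only: MSG_CRYPTO/KEY_DISTR_MSG data layout as assumed here:
 *   [ alg_name (TIPC_AEAD_ALG_NAME bytes) ][ keylen (be32) ][ key bytes ]
 */
#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include <arpa/inet.h>

#define TIPC_AEAD_ALG_NAME 32	/* assumed size of the algorithm name field */

static size_t pack_skey(uint8_t *data, const char *alg,
			const uint8_t *key, uint32_t keylen)
{
	uint32_t be = htonl(keylen);

	memset(data, 0, TIPC_AEAD_ALG_NAME);		/* zero-padded name */
	memcpy(data, alg, strlen(alg));
	memcpy(data + TIPC_AEAD_ALG_NAME, &be, sizeof(be));
	memcpy(data + TIPC_AEAD_ALG_NAME + sizeof(be), key, keylen);
	return TIPC_AEAD_ALG_NAME + sizeof(be) + keylen;
}

int main(void)
{
	uint8_t buf[128], key[16] = { 0 };

	/* "gcm(aes)" is only an example algorithm name */
	printf("packed %zu bytes\n", pack_skey(buf, "gcm(aes)", key, sizeof(key)));
	return 0;
}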
+ * + * Return: 0 in case of success, otherwise < 0 + */ +static int tipc_crypto_key_xmit(struct net *net, struct tipc_aead_key *skey, + u16 gen, u8 mode, u32 dnode) +{ + struct sk_buff_head pkts; + struct tipc_msg *hdr; + struct sk_buff *skb; + u16 size, cong_link_cnt; + u8 *data; + int rc; + + size = tipc_aead_key_size(skey); + skb = tipc_buf_acquire(INT_H_SIZE + size, GFP_ATOMIC); + if (!skb) + return -ENOMEM; + + hdr = buf_msg(skb); + tipc_msg_init(tipc_own_addr(net), hdr, MSG_CRYPTO, KEY_DISTR_MSG, + INT_H_SIZE, dnode); + msg_set_size(hdr, INT_H_SIZE + size); + msg_set_key_gen(hdr, gen); + msg_set_key_mode(hdr, mode); + + data = msg_data(hdr); + *((__be32 *)(data + TIPC_AEAD_ALG_NAME)) = htonl(skey->keylen); + memcpy(data, skey->alg_name, TIPC_AEAD_ALG_NAME); + memcpy(data + TIPC_AEAD_ALG_NAME + sizeof(__be32), skey->key, + skey->keylen); + + __skb_queue_head_init(&pkts); + __skb_queue_tail(&pkts, skb); + if (dnode) + rc = tipc_node_xmit(net, &pkts, dnode, 0); + else + rc = tipc_bcast_xmit(net, &pkts, &cong_link_cnt); + + return rc; +} + +/** + * tipc_crypto_key_rcv - Receive a session key + * @rx: the RX crypto + * @hdr: the TIPC v2 message incl. the receiving session key in its data + * + * This function retrieves the session key in the message from peer, then + * schedules a RX work to attach the key to the corresponding RX crypto. + * + * Return: "true" if the key has been scheduled for attaching, otherwise + * "false". + */ +static bool tipc_crypto_key_rcv(struct tipc_crypto *rx, struct tipc_msg *hdr) +{ + struct tipc_crypto *tx = tipc_net(rx->net)->crypto_tx; + struct tipc_aead_key *skey = NULL; + u16 key_gen = msg_key_gen(hdr); + u32 size = msg_data_sz(hdr); + u8 *data = msg_data(hdr); + unsigned int keylen; + + /* Verify whether the size can exist in the packet */ + if (unlikely(size < sizeof(struct tipc_aead_key) + TIPC_AEAD_KEYLEN_MIN)) { + pr_debug("%s: message data size is too small\n", rx->name); + goto exit; + } + + keylen = ntohl(*((__be32 *)(data + TIPC_AEAD_ALG_NAME))); + + /* Verify the supplied size values */ + if (unlikely(size != keylen + sizeof(struct tipc_aead_key) || + keylen > TIPC_AEAD_KEY_SIZE_MAX)) { + pr_debug("%s: invalid MSG_CRYPTO key size\n", rx->name); + goto exit; + } + + spin_lock(&rx->lock); + if (unlikely(rx->skey || (key_gen == rx->key_gen && rx->key.keys))) { + pr_err("%s: key existed <%p>, gen %d vs %d\n", rx->name, + rx->skey, key_gen, rx->key_gen); + goto exit_unlock; + } + + /* Allocate memory for the key */ + skey = kmalloc(size, GFP_ATOMIC); + if (unlikely(!skey)) { + pr_err("%s: unable to allocate memory for skey\n", rx->name); + goto exit_unlock; + } + + /* Copy key from msg data */ + skey->keylen = keylen; + memcpy(skey->alg_name, data, TIPC_AEAD_ALG_NAME); + memcpy(skey->key, data + TIPC_AEAD_ALG_NAME + sizeof(__be32), + skey->keylen); + + rx->key_gen = key_gen; + rx->skey_mode = msg_key_mode(hdr); + rx->skey = skey; + rx->nokey = 0; + mb(); /* for nokey flag */ + +exit_unlock: + spin_unlock(&rx->lock); + +exit: + /* Schedule the key attaching on this crypto */ + if (likely(skey && queue_delayed_work(tx->wq, &rx->work, 0))) + return true; + + return false; +} + +/** + * tipc_crypto_work_rx - Scheduled RX works handler + * @work: the struct RX work + * + * The function processes the previous scheduled works i.e. distributing TX key + * or attaching a received session key on RX crypto. 
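The work handled below is queued from tipc_crypto_key_synch() with a small randomized delay (two random bytes, reduced modulo 5, in 500 ms steps). A standalone sketch of that computation, with values in milliseconds:

/* Sketch only: mirrors "delay %= 5; delay = msecs_to_jiffies(500 * ++delay)"
 * used when scheduling the key-distribution work for a peer.
 */
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

static unsigned int distr_delay_ms(unsigned int rnd)
{
	rnd %= 5;			/* 0..4 */
	return 500 * (rnd + 1);		/* 500, 1000, ..., 2500 ms */
}

int main(void)
{
	srand((unsigned int)time(NULL));
	printf("schedule key distribution in %u ms\n",
	       distr_delay_ms((unsigned int)rand()));
	return 0;
}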
+ */ +static void tipc_crypto_work_rx(struct work_struct *work) +{ + struct delayed_work *dwork = to_delayed_work(work); + struct tipc_crypto *rx = container_of(dwork, struct tipc_crypto, work); + struct tipc_crypto *tx = tipc_net(rx->net)->crypto_tx; + unsigned long delay = msecs_to_jiffies(5000); + bool resched = false; + u8 key; + int rc; + + /* Case 1: Distribute TX key to peer if scheduled */ + if (atomic_cmpxchg(&rx->key_distr, + KEY_DISTR_SCHED, + KEY_DISTR_COMPL) == KEY_DISTR_SCHED) { + /* Always pick the newest one for distributing */ + key = tx->key.pending ?: tx->key.active; + rc = tipc_crypto_key_distr(tx, key, rx->node); + if (unlikely(rc)) + pr_warn("%s: unable to distr key[%d] to %s, err %d\n", + tx->name, key, tipc_node_get_id_str(rx->node), + rc); + + /* Sched for key_distr releasing */ + resched = true; + } else { + atomic_cmpxchg(&rx->key_distr, KEY_DISTR_COMPL, 0); + } + + /* Case 2: Attach a pending received session key from peer if any */ + if (rx->skey) { + rc = tipc_crypto_key_init(rx, rx->skey, rx->skey_mode, false); + if (unlikely(rc < 0)) + pr_warn("%s: unable to attach received skey, err %d\n", + rx->name, rc); + switch (rc) { + case -EBUSY: + case -ENOMEM: + /* Resched the key attaching */ + resched = true; + break; + default: + synchronize_rcu(); + kfree(rx->skey); + rx->skey = NULL; + break; + } + } + + if (resched && queue_delayed_work(tx->wq, &rx->work, delay)) + return; + + tipc_node_put(rx->node); +} + +/** + * tipc_crypto_rekeying_sched - (Re)schedule rekeying w/o new interval + * @tx: TX crypto + * @changed: if the rekeying needs to be rescheduled with new interval + * @new_intv: new rekeying interval (when "changed" = true) + */ +void tipc_crypto_rekeying_sched(struct tipc_crypto *tx, bool changed, + u32 new_intv) +{ + unsigned long delay; + bool now = false; + + if (changed) { + if (new_intv == TIPC_REKEYING_NOW) + now = true; + else + tx->rekeying_intv = new_intv; + cancel_delayed_work_sync(&tx->work); + } + + if (tx->rekeying_intv || now) { + delay = (now) ? 0 : tx->rekeying_intv * 60 * 1000; + queue_delayed_work(tx->wq, &tx->work, msecs_to_jiffies(delay)); + } +} + +/** + * tipc_crypto_work_tx - Scheduled TX works handler + * @work: the struct TX work + * + * The function processes the previous scheduled work, i.e. key rekeying, by + * generating a new session key based on current one, then attaching it to the + * TX crypto and finally distributing it to peers. It also re-schedules the + * rekeying if needed. 
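A standalone sketch of the delay computation used by tipc_crypto_rekeying_sched() above: the interval is kept in minutes and converted to milliseconds, with the "now" case forcing an immediate run. The 60-minute interval below is purely illustrative; the real default, TIPC_REKEYING_INTV_DEF, is set in tipc_crypto_start() and its value is not shown in this hunk.

/* Sketch only: mirrors "delay = (now) ? 0 : tx->rekeying_intv * 60 * 1000". */
#include <stdbool.h>
#include <stdio.h>

static unsigned long rekeying_delay_ms(unsigned int intv_min, bool now)
{
	return now ? 0 : (unsigned long)intv_min * 60UL * 1000UL;
}

int main(void)
{
	printf("%lu ms\n", rekeying_delay_ms(60, false));	/* 3600000 */
	printf("%lu ms\n", rekeying_delay_ms(60, true));	/* 0, rekey now */
	return 0;
}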
+ */ +static void tipc_crypto_work_tx(struct work_struct *work) +{ + struct delayed_work *dwork = to_delayed_work(work); + struct tipc_crypto *tx = container_of(dwork, struct tipc_crypto, work); + struct tipc_aead_key *skey = NULL; + struct tipc_key key = tx->key; + struct tipc_aead *aead; + int rc = -ENOMEM; + + if (unlikely(key.pending)) + goto resched; + + /* Take current key as a template */ + rcu_read_lock(); + aead = rcu_dereference(tx->aead[key.active ?: KEY_MASTER]); + if (unlikely(!aead)) { + rcu_read_unlock(); + /* At least one key should exist for securing */ + return; + } + + /* Lets duplicate it first */ + skey = kmemdup(aead->key, tipc_aead_key_size(aead->key), GFP_ATOMIC); + rcu_read_unlock(); + + /* Now, generate new key, initiate & distribute it */ + if (likely(skey)) { + rc = tipc_aead_key_generate(skey) ?: + tipc_crypto_key_init(tx, skey, PER_NODE_KEY, false); + if (likely(rc > 0)) + rc = tipc_crypto_key_distr(tx, rc, NULL); + kfree_sensitive(skey); + } + + if (unlikely(rc)) + pr_warn_ratelimited("%s: rekeying returns %d\n", tx->name, rc); + +resched: + /* Re-schedule rekeying if any */ + tipc_crypto_rekeying_sched(tx, false, 0); +} diff --git a/net/tipc/crypto.h b/net/tipc/crypto.h new file mode 100644 index 000000000..ce7d4cc8a --- /dev/null +++ b/net/tipc/crypto.h @@ -0,0 +1,200 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * net/tipc/crypto.h: Include file for TIPC crypto + * + * Copyright (c) 2019, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ +#ifdef CONFIG_TIPC_CRYPTO +#ifndef _TIPC_CRYPTO_H +#define _TIPC_CRYPTO_H + +#include "core.h" +#include "node.h" +#include "msg.h" +#include "bearer.h" + +#define TIPC_EVERSION 7 + +/* AEAD aes(gcm) */ +#define TIPC_AES_GCM_KEY_SIZE_128 16 +#define TIPC_AES_GCM_KEY_SIZE_192 24 +#define TIPC_AES_GCM_KEY_SIZE_256 32 + +#define TIPC_AES_GCM_SALT_SIZE 4 +#define TIPC_AES_GCM_IV_SIZE 12 +#define TIPC_AES_GCM_TAG_SIZE 16 + +/* + * TIPC crypto modes: + * - CLUSTER_KEY: + * One single key is used for both TX & RX in all nodes in the cluster. + * - PER_NODE_KEY: + * Each nodes in the cluster has one TX key, for RX a node needs to know + * its peers' TX key for the decryption of messages from those nodes. + */ +enum { + CLUSTER_KEY = 1, + PER_NODE_KEY = (1 << 1), +}; + +extern int sysctl_tipc_max_tfms __read_mostly; +extern int sysctl_tipc_key_exchange_enabled __read_mostly; + +/* + * TIPC encryption message format: + * + * 3 3 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 + * 1 0 9 8 7 6 5 4|3 2 1 0 9 8 7 6|5 4 3 2 1 0 9 8|7 6 5 4 3 2 1 0 + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * w0:|Ver=7| User |D|TX |RX |K|M|N| Rsvd | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * w1:| Seqno | + * w2:| (8 octets) | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * w3:\ Prevnode \ + * / (4 or 16 octets) / + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * \ \ + * / Encrypted complete TIPC V2 header and user data / + * \ \ + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * | | + * | AuthTag | + * | (16 octets) | + * | | + * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + * + * Word0: + * Ver : = 7 i.e. TIPC encryption message version + * User : = 7 (for LINK_PROTOCOL); = 13 (for LINK_CONFIG) or = 0 + * D : The destined bit i.e. the message's destination node is + * "known" or not at the message encryption + * TX : TX key used for the message encryption + * RX : Currently RX active key corresponding to the destination + * node's TX key (when the "D" bit is set) + * K : Keep-alive bit (for RPS, LINK_PROTOCOL/STATE_MSG only) + * M : Bit indicates if sender has master key + * N : Bit indicates if sender has no RX keys corresponding to the + * receiver's TX (when the "D" bit is set) + * Rsvd : Reserved bit, field + * Word1-2: + * Seqno : The 64-bit sequence number of the encrypted message, also + * part of the nonce used for the message encryption/decryption + * Word3-: + * Prevnode: The source node address, or ID in case LINK_CONFIG only + * AuthTag : The authentication tag for the message integrity checking + * generated by the message encryption + */ +struct tipc_ehdr { + union { + struct { +#if defined(__LITTLE_ENDIAN_BITFIELD) + __u8 destined:1, + user:4, + version:3; + __u8 reserved_1:1, + rx_nokey:1, + master_key:1, + keepalive:1, + rx_key_active:2, + tx_key:2; +#elif defined(__BIG_ENDIAN_BITFIELD) + __u8 version:3, + user:4, + destined:1; + __u8 tx_key:2, + rx_key_active:2, + keepalive:1, + master_key:1, + rx_nokey:1, + reserved_1:1; +#else +#error "Please fix " +#endif + __be16 reserved_2; + } __packed; + __be32 w0; + }; + __be64 seqno; + union { + __be32 addr; + __u8 id[NODE_ID_LEN]; /* For a LINK_CONFIG message only! 
*/ + }; +#define EHDR_SIZE (offsetof(struct tipc_ehdr, addr) + sizeof(__be32)) +#define EHDR_CFG_SIZE (sizeof(struct tipc_ehdr)) +#define EHDR_MIN_SIZE (EHDR_SIZE) +#define EHDR_MAX_SIZE (EHDR_CFG_SIZE) +#define EMSG_OVERHEAD (EHDR_SIZE + TIPC_AES_GCM_TAG_SIZE) +} __packed; + +int tipc_crypto_start(struct tipc_crypto **crypto, struct net *net, + struct tipc_node *node); +void tipc_crypto_stop(struct tipc_crypto **crypto); +void tipc_crypto_timeout(struct tipc_crypto *rx); +int tipc_crypto_xmit(struct net *net, struct sk_buff **skb, + struct tipc_bearer *b, struct tipc_media_addr *dst, + struct tipc_node *__dnode); +int tipc_crypto_rcv(struct net *net, struct tipc_crypto *rx, + struct sk_buff **skb, struct tipc_bearer *b); +int tipc_crypto_key_init(struct tipc_crypto *c, struct tipc_aead_key *ukey, + u8 mode, bool master_key); +void tipc_crypto_key_flush(struct tipc_crypto *c); +int tipc_crypto_key_distr(struct tipc_crypto *tx, u8 key, + struct tipc_node *dest); +void tipc_crypto_msg_rcv(struct net *net, struct sk_buff *skb); +void tipc_crypto_rekeying_sched(struct tipc_crypto *tx, bool changed, + u32 new_intv); +int tipc_aead_key_validate(struct tipc_aead_key *ukey, struct genl_info *info); +bool tipc_ehdr_validate(struct sk_buff *skb); + +static inline u32 msg_key_gen(struct tipc_msg *m) +{ + return msg_bits(m, 4, 16, 0xffff); +} + +static inline void msg_set_key_gen(struct tipc_msg *m, u32 gen) +{ + msg_set_bits(m, 4, 16, 0xffff, gen); +} + +static inline u32 msg_key_mode(struct tipc_msg *m) +{ + return msg_bits(m, 4, 0, 0xf); +} + +static inline void msg_set_key_mode(struct tipc_msg *m, u32 mode) +{ + msg_set_bits(m, 4, 0, 0xf, mode); +} + +#endif /* _TIPC_CRYPTO_H */ +#endif diff --git a/net/tipc/diag.c b/net/tipc/diag.c new file mode 100644 index 000000000..73137f4ae --- /dev/null +++ b/net/tipc/diag.c @@ -0,0 +1,116 @@ +/* + * net/tipc/diag.c: TIPC socket diag + * + * Copyright (c) 2018, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "ASIS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "socket.h" +#include +#include + +static u64 __tipc_diag_gen_cookie(struct sock *sk) +{ + u32 res[2]; + + sock_diag_save_cookie(sk, res); + return *((u64 *)res); +} + +static int __tipc_add_sock_diag(struct sk_buff *skb, + struct netlink_callback *cb, + struct tipc_sock *tsk) +{ + struct tipc_sock_diag_req *req = nlmsg_data(cb->nlh); + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put_answer(skb, cb, SOCK_DIAG_BY_FAMILY, 0, + NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + + err = tipc_sk_fill_sock_diag(skb, cb, tsk, req->tidiag_states, + __tipc_diag_gen_cookie); + if (err) + return err; + + nlmsg_end(skb, nlh); + return 0; +} + +static int tipc_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + return tipc_nl_sk_walk(skb, cb, __tipc_add_sock_diag); +} + +static int tipc_sock_diag_handler_dump(struct sk_buff *skb, + struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct tipc_sock_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .start = tipc_dump_start, + .dump = tipc_diag_dump, + .done = tipc_dump_done, + }; + netlink_dump_start(net->diag_nlsk, skb, h, &c); + return 0; + } + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler tipc_sock_diag_handler = { + .family = AF_TIPC, + .dump = tipc_sock_diag_handler_dump, +}; + +static int __init tipc_diag_init(void) +{ + return sock_diag_register(&tipc_sock_diag_handler); +} + +static void __exit tipc_diag_exit(void) +{ + sock_diag_unregister(&tipc_sock_diag_handler); +} + +module_init(tipc_diag_init); +module_exit(tipc_diag_exit); + +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, AF_TIPC); diff --git a/net/tipc/discover.c b/net/tipc/discover.c new file mode 100644 index 000000000..e8dcdf267 --- /dev/null +++ b/net/tipc/discover.c @@ -0,0 +1,420 @@ +/* + * net/tipc/discover.c + * + * Copyright (c) 2003-2006, 2014-2018, Ericsson AB + * Copyright (c) 2005-2006, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "node.h" +#include "discover.h" + +/* min delay during bearer start up */ +#define TIPC_DISC_INIT msecs_to_jiffies(125) +/* max delay if bearer has no links */ +#define TIPC_DISC_FAST msecs_to_jiffies(1000) +/* max delay if bearer has links */ +#define TIPC_DISC_SLOW msecs_to_jiffies(60000) +/* indicates no timer in use */ +#define TIPC_DISC_INACTIVE 0xffffffff + +/** + * struct tipc_discoverer - information about an ongoing link setup request + * @bearer_id: identity of bearer issuing requests + * @net: network namespace instance + * @dest: destination address for request messages + * @domain: network domain to which links can be established + * @num_nodes: number of nodes currently discovered (i.e. with an active link) + * @lock: spinlock for controlling access to requests + * @skb: request message to be (repeatedly) sent + * @timer: timer governing period between requests + * @timer_intv: current interval between requests (in ms) + */ +struct tipc_discoverer { + u32 bearer_id; + struct tipc_media_addr dest; + struct net *net; + u32 domain; + int num_nodes; + spinlock_t lock; + struct sk_buff *skb; + struct timer_list timer; + unsigned long timer_intv; +}; + +/** + * tipc_disc_init_msg - initialize a link setup message + * @net: the applicable net namespace + * @skb: buffer containing message + * @mtyp: message type (request or response) + * @b: ptr to bearer issuing message + */ +static void tipc_disc_init_msg(struct net *net, struct sk_buff *skb, + u32 mtyp, struct tipc_bearer *b) +{ + struct tipc_net *tn = tipc_net(net); + u32 dest_domain = b->domain; + struct tipc_msg *hdr; + + hdr = buf_msg(skb); + tipc_msg_init(tn->trial_addr, hdr, LINK_CONFIG, mtyp, + MAX_H_SIZE, dest_domain); + msg_set_size(hdr, MAX_H_SIZE + NODE_ID_LEN); + msg_set_non_seq(hdr, 1); + msg_set_node_sig(hdr, tn->random); + msg_set_node_capabilities(hdr, TIPC_NODE_CAPABILITIES); + msg_set_dest_domain(hdr, dest_domain); + msg_set_bc_netid(hdr, tn->net_id); + b->media->addr2msg(msg_media_addr(hdr), &b->addr); + msg_set_peer_net_hash(hdr, tipc_net_hash_mixes(net, tn->random)); + msg_set_node_id(hdr, tipc_own_id(net)); +} + +static void tipc_disc_msg_xmit(struct net *net, u32 mtyp, u32 dst, + u32 src, u32 sugg_addr, + struct tipc_media_addr *maddr, + struct tipc_bearer *b) +{ + struct tipc_msg *hdr; + struct sk_buff *skb; + + skb = tipc_buf_acquire(MAX_H_SIZE + NODE_ID_LEN, GFP_ATOMIC); + if (!skb) + return; + hdr = buf_msg(skb); + tipc_disc_init_msg(net, skb, mtyp, b); + msg_set_sugg_node_addr(hdr, sugg_addr); + 
msg_set_dest_domain(hdr, dst); + tipc_bearer_xmit_skb(net, b->identity, skb, maddr); +} + +/** + * disc_dupl_alert - issue node address duplication alert + * @b: pointer to bearer detecting duplication + * @node_addr: duplicated node address + * @media_addr: media address advertised by duplicated node + */ +static void disc_dupl_alert(struct tipc_bearer *b, u32 node_addr, + struct tipc_media_addr *media_addr) +{ + char media_addr_str[64]; + + tipc_media_addr_printf(media_addr_str, sizeof(media_addr_str), + media_addr); + pr_warn("Duplicate %x using %s seen on <%s>\n", node_addr, + media_addr_str, b->name); +} + +/* tipc_disc_addr_trial(): - handle an address uniqueness trial from peer + * Returns true if message should be dropped by caller, i.e., if it is a + * trial message or we are inside trial period. Otherwise false. + */ +static bool tipc_disc_addr_trial_msg(struct tipc_discoverer *d, + struct tipc_media_addr *maddr, + struct tipc_bearer *b, + u32 dst, u32 src, + u32 sugg_addr, + u8 *peer_id, + int mtyp) +{ + struct net *net = d->net; + struct tipc_net *tn = tipc_net(net); + u32 self = tipc_own_addr(net); + bool trial = time_before(jiffies, tn->addr_trial_end) && !self; + + if (mtyp == DSC_TRIAL_FAIL_MSG) { + if (!trial) + return true; + + /* Ignore if somebody else already gave new suggestion */ + if (dst != tn->trial_addr) + return true; + + /* Otherwise update trial address and restart trial period */ + tn->trial_addr = sugg_addr; + msg_set_prevnode(buf_msg(d->skb), sugg_addr); + tn->addr_trial_end = jiffies + msecs_to_jiffies(1000); + return true; + } + + /* Apply trial address if we just left trial period */ + if (!trial && !self) { + schedule_work(&tn->work); + msg_set_prevnode(buf_msg(d->skb), tn->trial_addr); + msg_set_type(buf_msg(d->skb), DSC_REQ_MSG); + } + + /* Accept regular link requests/responses only after trial period */ + if (mtyp != DSC_TRIAL_MSG) + return trial; + + sugg_addr = tipc_node_try_addr(net, peer_id, src); + if (sugg_addr) + tipc_disc_msg_xmit(net, DSC_TRIAL_FAIL_MSG, src, + self, sugg_addr, maddr, b); + return true; +} + +/** + * tipc_disc_rcv - handle incoming discovery message (request or response) + * @net: applicable net namespace + * @skb: buffer containing message + * @b: bearer that message arrived on + */ +void tipc_disc_rcv(struct net *net, struct sk_buff *skb, + struct tipc_bearer *b) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_msg *hdr = buf_msg(skb); + u32 pnet_hash = msg_peer_net_hash(hdr); + u16 caps = msg_node_capabilities(hdr); + bool legacy = tn->legacy_addr_format; + u32 sugg = msg_sugg_node_addr(hdr); + u32 signature = msg_node_sig(hdr); + u8 peer_id[NODE_ID_LEN] = {0,}; + u32 dst = msg_dest_domain(hdr); + u32 net_id = msg_bc_netid(hdr); + struct tipc_media_addr maddr; + u32 src = msg_prevnode(hdr); + u32 mtyp = msg_type(hdr); + bool dupl_addr = false; + bool respond = false; + u32 self; + int err; + + if (skb_linearize(skb)) { + kfree_skb(skb); + return; + } + hdr = buf_msg(skb); + + if (caps & TIPC_NODE_ID128) + memcpy(peer_id, msg_node_id(hdr), NODE_ID_LEN); + else + sprintf(peer_id, "%x", src); + + err = b->media->msg2addr(b, &maddr, msg_media_addr(hdr)); + kfree_skb(skb); + if (err || maddr.broadcast) { + pr_warn_ratelimited("Rcv corrupt discovery message\n"); + return; + } + /* Ignore discovery messages from own node */ + if (!memcmp(&maddr, &b->addr, sizeof(maddr))) + return; + if (net_id != tn->net_id) + return; + if (tipc_disc_addr_trial_msg(b->disc, &maddr, b, dst, + src, sugg, peer_id, mtyp)) + return; + self = 
tipc_own_addr(net); + + /* Message from somebody using this node's address */ + if (in_own_node(net, src)) { + disc_dupl_alert(b, self, &maddr); + return; + } + if (!tipc_in_scope(legacy, dst, self)) + return; + if (!tipc_in_scope(legacy, b->domain, src)) + return; + tipc_node_check_dest(net, src, peer_id, b, caps, signature, pnet_hash, + &maddr, &respond, &dupl_addr); + if (dupl_addr) + disc_dupl_alert(b, src, &maddr); + if (!respond) + return; + if (mtyp != DSC_REQ_MSG) + return; + tipc_disc_msg_xmit(net, DSC_RESP_MSG, src, self, 0, &maddr, b); +} + +/* tipc_disc_add_dest - increment set of discovered nodes + */ +void tipc_disc_add_dest(struct tipc_discoverer *d) +{ + spin_lock_bh(&d->lock); + d->num_nodes++; + spin_unlock_bh(&d->lock); +} + +/* tipc_disc_remove_dest - decrement set of discovered nodes + */ +void tipc_disc_remove_dest(struct tipc_discoverer *d) +{ + int intv, num; + + spin_lock_bh(&d->lock); + d->num_nodes--; + num = d->num_nodes; + intv = d->timer_intv; + if (!num && (intv == TIPC_DISC_INACTIVE || intv > TIPC_DISC_FAST)) { + d->timer_intv = TIPC_DISC_INIT; + mod_timer(&d->timer, jiffies + d->timer_intv); + } + spin_unlock_bh(&d->lock); +} + +/* tipc_disc_timeout - send a periodic link setup request + * Called whenever a link setup request timer associated with a bearer expires. + * - Keep doubling time between sent request until limit is reached; + * - Hold at fast polling rate if we don't have any associated nodes + * - Otherwise hold at slow polling rate + */ +static void tipc_disc_timeout(struct timer_list *t) +{ + struct tipc_discoverer *d = from_timer(d, t, timer); + struct tipc_net *tn = tipc_net(d->net); + struct tipc_media_addr maddr; + struct sk_buff *skb = NULL; + struct net *net = d->net; + u32 bearer_id; + + spin_lock_bh(&d->lock); + + /* Stop searching if only desired node has been found */ + if (tipc_node(d->domain) && d->num_nodes) { + d->timer_intv = TIPC_DISC_INACTIVE; + goto exit; + } + + /* Did we just leave trial period ? */ + if (!time_before(jiffies, tn->addr_trial_end) && !tipc_own_addr(net)) { + mod_timer(&d->timer, jiffies + TIPC_DISC_INIT); + spin_unlock_bh(&d->lock); + schedule_work(&tn->work); + return; + } + + /* Adjust timeout interval according to discovery phase */ + if (time_before(jiffies, tn->addr_trial_end)) { + d->timer_intv = TIPC_DISC_INIT; + } else { + d->timer_intv *= 2; + if (d->num_nodes && d->timer_intv > TIPC_DISC_SLOW) + d->timer_intv = TIPC_DISC_SLOW; + else if (!d->num_nodes && d->timer_intv > TIPC_DISC_FAST) + d->timer_intv = TIPC_DISC_FAST; + msg_set_type(buf_msg(d->skb), DSC_REQ_MSG); + msg_set_prevnode(buf_msg(d->skb), tn->trial_addr); + } + + mod_timer(&d->timer, jiffies + d->timer_intv); + memcpy(&maddr, &d->dest, sizeof(maddr)); + skb = skb_clone(d->skb, GFP_ATOMIC); + bearer_id = d->bearer_id; +exit: + spin_unlock_bh(&d->lock); + if (skb) + tipc_bearer_xmit_skb(net, bearer_id, skb, &maddr); +} + +/** + * tipc_disc_create - create object to send periodic link setup requests + * @net: the applicable net namespace + * @b: ptr to bearer issuing requests + * @dest: destination address for request messages + * @skb: pointer to created frame + * + * Return: 0 if successful, otherwise -errno. 
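The doubling and clamping of the request interval in tipc_disc_timeout() above can be reproduced in a few lines; the millisecond values mirror the TIPC_DISC_* constants defined at the top of discover.c.

/* Sketch only: discovery timer backoff. The interval doubles on every expiry
 * and is clamped to TIPC_DISC_SLOW once links exist, or to TIPC_DISC_FAST
 * while the bearer still has no links.
 */
#include <stdio.h>

#define TIPC_DISC_INIT	125	/* min delay during bearer start up (ms) */
#define TIPC_DISC_FAST	1000	/* max delay if bearer has no links (ms) */
#define TIPC_DISC_SLOW	60000	/* max delay if bearer has links (ms) */

static unsigned long next_intv(unsigned long intv, int num_nodes)
{
	intv *= 2;
	if (num_nodes && intv > TIPC_DISC_SLOW)
		return TIPC_DISC_SLOW;
	if (!num_nodes && intv > TIPC_DISC_FAST)
		return TIPC_DISC_FAST;
	return intv;
}

int main(void)
{
	unsigned long intv = TIPC_DISC_INIT;
	int i;

	for (i = 0; i < 6; i++) {
		printf("expiry %d: next request in %lu ms\n", i, intv);
		intv = next_intv(intv, 0);	/* no links discovered yet */
	}
	return 0;
}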
+ */ +int tipc_disc_create(struct net *net, struct tipc_bearer *b, + struct tipc_media_addr *dest, struct sk_buff **skb) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_discoverer *d; + + d = kmalloc(sizeof(*d), GFP_ATOMIC); + if (!d) + return -ENOMEM; + d->skb = tipc_buf_acquire(MAX_H_SIZE + NODE_ID_LEN, GFP_ATOMIC); + if (!d->skb) { + kfree(d); + return -ENOMEM; + } + tipc_disc_init_msg(net, d->skb, DSC_REQ_MSG, b); + + /* Do we need an address trial period first ? */ + if (!tipc_own_addr(net)) { + tn->addr_trial_end = jiffies + msecs_to_jiffies(1000); + msg_set_type(buf_msg(d->skb), DSC_TRIAL_MSG); + } + memcpy(&d->dest, dest, sizeof(*dest)); + d->net = net; + d->bearer_id = b->identity; + d->domain = b->domain; + d->num_nodes = 0; + d->timer_intv = TIPC_DISC_INIT; + spin_lock_init(&d->lock); + timer_setup(&d->timer, tipc_disc_timeout, 0); + mod_timer(&d->timer, jiffies + d->timer_intv); + b->disc = d; + *skb = skb_clone(d->skb, GFP_ATOMIC); + return 0; +} + +/** + * tipc_disc_delete - destroy object sending periodic link setup requests + * @d: ptr to link dest structure + */ +void tipc_disc_delete(struct tipc_discoverer *d) +{ + del_timer_sync(&d->timer); + kfree_skb(d->skb); + kfree(d); +} + +/** + * tipc_disc_reset - reset object to send periodic link setup requests + * @net: the applicable net namespace + * @b: ptr to bearer issuing requests + */ +void tipc_disc_reset(struct net *net, struct tipc_bearer *b) +{ + struct tipc_discoverer *d = b->disc; + struct tipc_media_addr maddr; + struct sk_buff *skb; + + spin_lock_bh(&d->lock); + tipc_disc_init_msg(net, d->skb, DSC_REQ_MSG, b); + d->net = net; + d->bearer_id = b->identity; + d->domain = b->domain; + d->num_nodes = 0; + d->timer_intv = TIPC_DISC_INIT; + memcpy(&maddr, &d->dest, sizeof(maddr)); + mod_timer(&d->timer, jiffies + d->timer_intv); + skb = skb_clone(d->skb, GFP_ATOMIC); + spin_unlock_bh(&d->lock); + if (skb) + tipc_bearer_xmit_skb(net, b->identity, skb, &maddr); +} diff --git a/net/tipc/discover.h b/net/tipc/discover.h new file mode 100644 index 000000000..521d96c41 --- /dev/null +++ b/net/tipc/discover.h @@ -0,0 +1,51 @@ +/* + * net/tipc/discover.h + * + * Copyright (c) 2003-2006, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_DISCOVER_H +#define _TIPC_DISCOVER_H + +struct tipc_discoverer; + +int tipc_disc_create(struct net *net, struct tipc_bearer *b_ptr, + struct tipc_media_addr *dest, struct sk_buff **skb); +void tipc_disc_delete(struct tipc_discoverer *req); +void tipc_disc_reset(struct net *net, struct tipc_bearer *b_ptr); +void tipc_disc_add_dest(struct tipc_discoverer *req); +void tipc_disc_remove_dest(struct tipc_discoverer *req); +void tipc_disc_rcv(struct net *net, struct sk_buff *buf, + struct tipc_bearer *b_ptr); + +#endif diff --git a/net/tipc/eth_media.c b/net/tipc/eth_media.c new file mode 100644 index 000000000..cb0d185e0 --- /dev/null +++ b/net/tipc/eth_media.c @@ -0,0 +1,98 @@ +/* + * net/tipc/eth_media.c: Ethernet bearer support for TIPC + * + * Copyright (c) 2001-2007, 2013-2014, Ericsson AB + * Copyright (c) 2005-2008, 2011-2013, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include "core.h" +#include "bearer.h" + +/* Convert Ethernet address (media address format) to string */ +static int tipc_eth_addr2str(struct tipc_media_addr *addr, + char *strbuf, int bufsz) +{ + if (bufsz < 18) /* 18 = strlen("aa:bb:cc:dd:ee:ff\0") */ + return 1; + + sprintf(strbuf, "%pM", addr->value); + return 0; +} + +/* Convert from media address format to discovery message addr format */ +static int tipc_eth_addr2msg(char *msg, struct tipc_media_addr *addr) +{ + memset(msg, 0, TIPC_MEDIA_INFO_SIZE); + msg[TIPC_MEDIA_TYPE_OFFSET] = TIPC_MEDIA_TYPE_ETH; + memcpy(msg + TIPC_MEDIA_ADDR_OFFSET, addr->value, ETH_ALEN); + return 0; +} + +/* Convert raw mac address format to media addr format */ +static int tipc_eth_raw2addr(struct tipc_bearer *b, + struct tipc_media_addr *addr, + const char *msg) +{ + memset(addr, 0, sizeof(*addr)); + ether_addr_copy(addr->value, msg); + addr->media_id = TIPC_MEDIA_TYPE_ETH; + addr->broadcast = is_broadcast_ether_addr(addr->value); + return 0; +} + +/* Convert discovery msg addr format to Ethernet media addr format */ +static int tipc_eth_msg2addr(struct tipc_bearer *b, + struct tipc_media_addr *addr, + char *msg) +{ + /* Skip past preamble: */ + msg += TIPC_MEDIA_ADDR_OFFSET; + return tipc_eth_raw2addr(b, addr, msg); +} + +/* Ethernet media registration info */ +struct tipc_media eth_media_info = { + .send_msg = tipc_l2_send_msg, + .enable_media = tipc_enable_l2_media, + .disable_media = tipc_disable_l2_media, + .addr2str = tipc_eth_addr2str, + .addr2msg = tipc_eth_addr2msg, + .msg2addr = tipc_eth_msg2addr, + .raw2addr = tipc_eth_raw2addr, + .priority = TIPC_DEF_LINK_PRI, + .tolerance = TIPC_DEF_LINK_TOL, + .min_win = TIPC_DEF_LINK_WIN, + .max_win = TIPC_MAX_LINK_WIN, + .type_id = TIPC_MEDIA_TYPE_ETH, + .hwaddr_len = ETH_ALEN, + .name = "eth" +}; diff --git a/net/tipc/group.c b/net/tipc/group.c new file mode 100644 index 000000000..3e137d8c9 --- /dev/null +++ b/net/tipc/group.c @@ -0,0 +1,959 @@ +/* + * net/tipc/group.c: TIPC group messaging code + * + * Copyright (c) 2017, Ericsson AB + * Copyright (c) 2020, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "addr.h" +#include "group.h" +#include "bcast.h" +#include "topsrv.h" +#include "msg.h" +#include "socket.h" +#include "node.h" +#include "name_table.h" +#include "subscr.h" + +#define ADV_UNIT (((MAX_MSG_SIZE + MAX_H_SIZE) / FLOWCTL_BLK_SZ) + 1) +#define ADV_IDLE ADV_UNIT +#define ADV_ACTIVE (ADV_UNIT * 12) + +enum mbr_state { + MBR_JOINING, + MBR_PUBLISHED, + MBR_JOINED, + MBR_PENDING, + MBR_ACTIVE, + MBR_RECLAIMING, + MBR_REMITTED, + MBR_LEAVING +}; + +struct tipc_member { + struct rb_node tree_node; + struct list_head list; + struct list_head small_win; + struct sk_buff_head deferredq; + struct tipc_group *group; + u32 node; + u32 port; + u32 instance; + enum mbr_state state; + u16 advertised; + u16 window; + u16 bc_rcv_nxt; + u16 bc_syncpt; + u16 bc_acked; +}; + +struct tipc_group { + struct rb_root members; + struct list_head small_win; + struct list_head pending; + struct list_head active; + struct tipc_nlist dests; + struct net *net; + int subid; + u32 type; + u32 instance; + u32 scope; + u32 portid; + u16 member_cnt; + u16 active_cnt; + u16 max_active; + u16 bc_snd_nxt; + u16 bc_ackers; + bool *open; + bool loopback; + bool events; +}; + +static void tipc_group_proto_xmit(struct tipc_group *grp, struct tipc_member *m, + int mtyp, struct sk_buff_head *xmitq); + +static void tipc_group_open(struct tipc_member *m, bool *wakeup) +{ + *wakeup = false; + if (list_empty(&m->small_win)) + return; + list_del_init(&m->small_win); + *m->group->open = true; + *wakeup = true; +} + +static void tipc_group_decr_active(struct tipc_group *grp, + struct tipc_member *m) +{ + if (m->state == MBR_ACTIVE || m->state == MBR_RECLAIMING || + m->state == MBR_REMITTED) + grp->active_cnt--; +} + +static int tipc_group_rcvbuf_limit(struct tipc_group *grp) +{ + int max_active, active_pool, idle_pool; + int mcnt = grp->member_cnt + 1; + + /* Limit simultaneous reception from other members */ + max_active = min(mcnt / 8, 64); + max_active = max(max_active, 16); + grp->max_active = max_active; + + /* Reserve blocks for active and idle members */ + active_pool = max_active * ADV_ACTIVE; + idle_pool = (mcnt - max_active) * ADV_IDLE; + + /* Scale to bytes, considering worst-case truesize/msgsize ratio */ + return (active_pool + idle_pool) * FLOWCTL_BLK_SZ * 4; +} + +u16 tipc_group_bc_snd_nxt(struct tipc_group *grp) +{ + return grp->bc_snd_nxt; +} + +static bool tipc_group_is_receiver(struct tipc_member *m) +{ + return m && m->state != MBR_JOINING && m->state != MBR_LEAVING; +} + +static bool tipc_group_is_sender(struct tipc_member *m) +{ + return m && m->state != MBR_JOINING && m->state != MBR_PUBLISHED; +} + +u32 tipc_group_exclude(struct tipc_group *grp) +{ + if (!grp->loopback) + return grp->portid; + return 0; +} + +struct tipc_group *tipc_group_create(struct net *net, u32 portid, + struct tipc_group_req *mreq, + bool *group_is_open) +{ + u32 filter = TIPC_SUB_PORTS | TIPC_SUB_NO_STATUS; + bool global = mreq->scope != TIPC_NODE_SCOPE; + struct 
tipc_group *grp; + u32 type = mreq->type; + + grp = kzalloc(sizeof(*grp), GFP_ATOMIC); + if (!grp) + return NULL; + tipc_nlist_init(&grp->dests, tipc_own_addr(net)); + INIT_LIST_HEAD(&grp->small_win); + INIT_LIST_HEAD(&grp->active); + INIT_LIST_HEAD(&grp->pending); + grp->members = RB_ROOT; + grp->net = net; + grp->portid = portid; + grp->type = type; + grp->instance = mreq->instance; + grp->scope = mreq->scope; + grp->loopback = mreq->flags & TIPC_GROUP_LOOPBACK; + grp->events = mreq->flags & TIPC_GROUP_MEMBER_EVTS; + grp->open = group_is_open; + *grp->open = false; + filter |= global ? TIPC_SUB_CLUSTER_SCOPE : TIPC_SUB_NODE_SCOPE; + if (tipc_topsrv_kern_subscr(net, portid, type, 0, ~0, + filter, &grp->subid)) + return grp; + kfree(grp); + return NULL; +} + +void tipc_group_join(struct net *net, struct tipc_group *grp, int *sk_rcvbuf) +{ + struct rb_root *tree = &grp->members; + struct tipc_member *m, *tmp; + struct sk_buff_head xmitq; + + __skb_queue_head_init(&xmitq); + rbtree_postorder_for_each_entry_safe(m, tmp, tree, tree_node) { + tipc_group_proto_xmit(grp, m, GRP_JOIN_MSG, &xmitq); + tipc_group_update_member(m, 0); + } + tipc_node_distr_xmit(net, &xmitq); + *sk_rcvbuf = tipc_group_rcvbuf_limit(grp); +} + +void tipc_group_delete(struct net *net, struct tipc_group *grp) +{ + struct rb_root *tree = &grp->members; + struct tipc_member *m, *tmp; + struct sk_buff_head xmitq; + + __skb_queue_head_init(&xmitq); + + rbtree_postorder_for_each_entry_safe(m, tmp, tree, tree_node) { + tipc_group_proto_xmit(grp, m, GRP_LEAVE_MSG, &xmitq); + __skb_queue_purge(&m->deferredq); + list_del(&m->list); + kfree(m); + } + tipc_node_distr_xmit(net, &xmitq); + tipc_nlist_purge(&grp->dests); + tipc_topsrv_kern_unsubscr(net, grp->subid); + kfree(grp); +} + +static struct tipc_member *tipc_group_find_member(struct tipc_group *grp, + u32 node, u32 port) +{ + struct rb_node *n = grp->members.rb_node; + u64 nkey, key = (u64)node << 32 | port; + struct tipc_member *m; + + while (n) { + m = container_of(n, struct tipc_member, tree_node); + nkey = (u64)m->node << 32 | m->port; + if (key < nkey) + n = n->rb_left; + else if (key > nkey) + n = n->rb_right; + else + return m; + } + return NULL; +} + +static struct tipc_member *tipc_group_find_dest(struct tipc_group *grp, + u32 node, u32 port) +{ + struct tipc_member *m; + + m = tipc_group_find_member(grp, node, port); + if (m && tipc_group_is_receiver(m)) + return m; + return NULL; +} + +static struct tipc_member *tipc_group_find_node(struct tipc_group *grp, + u32 node) +{ + struct tipc_member *m; + struct rb_node *n; + + for (n = rb_first(&grp->members); n; n = rb_next(n)) { + m = container_of(n, struct tipc_member, tree_node); + if (m->node == node) + return m; + } + return NULL; +} + +static int tipc_group_add_to_tree(struct tipc_group *grp, + struct tipc_member *m) +{ + u64 nkey, key = (u64)m->node << 32 | m->port; + struct rb_node **n, *parent = NULL; + struct tipc_member *tmp; + + n = &grp->members.rb_node; + while (*n) { + tmp = container_of(*n, struct tipc_member, tree_node); + parent = *n; + tmp = container_of(parent, struct tipc_member, tree_node); + nkey = (u64)tmp->node << 32 | tmp->port; + if (key < nkey) + n = &(*n)->rb_left; + else if (key > nkey) + n = &(*n)->rb_right; + else + return -EEXIST; + } + rb_link_node(&m->tree_node, parent, n); + rb_insert_color(&m->tree_node, &grp->members); + return 0; +} + +static struct tipc_member *tipc_group_create_member(struct tipc_group *grp, + u32 node, u32 port, + u32 instance, int state) +{ + struct tipc_member 
*m; + int ret; + + m = kzalloc(sizeof(*m), GFP_ATOMIC); + if (!m) + return NULL; + INIT_LIST_HEAD(&m->list); + INIT_LIST_HEAD(&m->small_win); + __skb_queue_head_init(&m->deferredq); + m->group = grp; + m->node = node; + m->port = port; + m->instance = instance; + m->bc_acked = grp->bc_snd_nxt - 1; + ret = tipc_group_add_to_tree(grp, m); + if (ret < 0) { + kfree(m); + return NULL; + } + grp->member_cnt++; + tipc_nlist_add(&grp->dests, m->node); + m->state = state; + return m; +} + +void tipc_group_add_member(struct tipc_group *grp, u32 node, + u32 port, u32 instance) +{ + tipc_group_create_member(grp, node, port, instance, MBR_PUBLISHED); +} + +static void tipc_group_delete_member(struct tipc_group *grp, + struct tipc_member *m) +{ + rb_erase(&m->tree_node, &grp->members); + grp->member_cnt--; + + /* Check if we were waiting for replicast ack from this member */ + if (grp->bc_ackers && less(m->bc_acked, grp->bc_snd_nxt - 1)) + grp->bc_ackers--; + + list_del_init(&m->list); + list_del_init(&m->small_win); + tipc_group_decr_active(grp, m); + + /* If last member on a node, remove node from dest list */ + if (!tipc_group_find_node(grp, m->node)) + tipc_nlist_del(&grp->dests, m->node); + + kfree(m); +} + +struct tipc_nlist *tipc_group_dests(struct tipc_group *grp) +{ + return &grp->dests; +} + +void tipc_group_self(struct tipc_group *grp, struct tipc_service_range *seq, + int *scope) +{ + seq->type = grp->type; + seq->lower = grp->instance; + seq->upper = grp->instance; + *scope = grp->scope; +} + +void tipc_group_update_member(struct tipc_member *m, int len) +{ + struct tipc_group *grp = m->group; + struct tipc_member *_m, *tmp; + + if (!tipc_group_is_receiver(m)) + return; + + m->window -= len; + + if (m->window >= ADV_IDLE) + return; + + list_del_init(&m->small_win); + + /* Sort member into small_window members' list */ + list_for_each_entry_safe(_m, tmp, &grp->small_win, small_win) { + if (_m->window > m->window) + break; + } + list_add_tail(&m->small_win, &_m->small_win); +} + +void tipc_group_update_bc_members(struct tipc_group *grp, int len, bool ack) +{ + u16 prev = grp->bc_snd_nxt - 1; + struct tipc_member *m; + struct rb_node *n; + u16 ackers = 0; + + for (n = rb_first(&grp->members); n; n = rb_next(n)) { + m = container_of(n, struct tipc_member, tree_node); + if (tipc_group_is_receiver(m)) { + tipc_group_update_member(m, len); + m->bc_acked = prev; + ackers++; + } + } + + /* Mark number of acknowledges to expect, if any */ + if (ack) + grp->bc_ackers = ackers; + grp->bc_snd_nxt++; +} + +bool tipc_group_cong(struct tipc_group *grp, u32 dnode, u32 dport, + int len, struct tipc_member **mbr) +{ + struct sk_buff_head xmitq; + struct tipc_member *m; + int adv, state; + + m = tipc_group_find_dest(grp, dnode, dport); + if (!tipc_group_is_receiver(m)) { + *mbr = NULL; + return false; + } + *mbr = m; + + if (m->window >= len) + return false; + + *grp->open = false; + + /* If not fully advertised, do it now to prevent mutual blocking */ + adv = m->advertised; + state = m->state; + if (state == MBR_JOINED && adv == ADV_IDLE) + return true; + if (state == MBR_ACTIVE && adv == ADV_ACTIVE) + return true; + if (state == MBR_PENDING && adv == ADV_IDLE) + return true; + __skb_queue_head_init(&xmitq); + tipc_group_proto_xmit(grp, m, GRP_ADV_MSG, &xmitq); + tipc_node_distr_xmit(grp->net, &xmitq); + return true; +} + +bool tipc_group_bc_cong(struct tipc_group *grp, int len) +{ + struct tipc_member *m = NULL; + + /* If prev bcast was replicast, reject until all receivers have acked */ + if 
(grp->bc_ackers) { + *grp->open = false; + return true; + } + if (list_empty(&grp->small_win)) + return false; + + m = list_first_entry(&grp->small_win, struct tipc_member, small_win); + if (m->window >= len) + return false; + + return tipc_group_cong(grp, m->node, m->port, len, &m); +} + +/* tipc_group_sort_msg() - sort msg into queue by bcast sequence number + */ +static void tipc_group_sort_msg(struct sk_buff *skb, struct sk_buff_head *defq) +{ + struct tipc_msg *_hdr, *hdr = buf_msg(skb); + u16 bc_seqno = msg_grp_bc_seqno(hdr); + struct sk_buff *_skb, *tmp; + int mtyp = msg_type(hdr); + + /* Bcast/mcast may be bypassed by ucast or other bcast, - sort it in */ + if (mtyp == TIPC_GRP_BCAST_MSG || mtyp == TIPC_GRP_MCAST_MSG) { + skb_queue_walk_safe(defq, _skb, tmp) { + _hdr = buf_msg(_skb); + if (!less(bc_seqno, msg_grp_bc_seqno(_hdr))) + continue; + __skb_queue_before(defq, _skb, skb); + return; + } + /* Bcast was not bypassed, - add to tail */ + } + /* Unicasts are never bypassed, - always add to tail */ + __skb_queue_tail(defq, skb); +} + +/* tipc_group_filter_msg() - determine if we should accept arriving message + */ +void tipc_group_filter_msg(struct tipc_group *grp, struct sk_buff_head *inputq, + struct sk_buff_head *xmitq) +{ + struct sk_buff *skb = __skb_dequeue(inputq); + bool ack, deliver, update, leave = false; + struct sk_buff_head *defq; + struct tipc_member *m; + struct tipc_msg *hdr; + u32 node, port; + int mtyp, blks; + + if (!skb) + return; + + hdr = buf_msg(skb); + node = msg_orignode(hdr); + port = msg_origport(hdr); + + if (!msg_in_group(hdr)) + goto drop; + + m = tipc_group_find_member(grp, node, port); + if (!tipc_group_is_sender(m)) + goto drop; + + if (less(msg_grp_bc_seqno(hdr), m->bc_rcv_nxt)) + goto drop; + + TIPC_SKB_CB(skb)->orig_member = m->instance; + defq = &m->deferredq; + tipc_group_sort_msg(skb, defq); + + while ((skb = skb_peek(defq))) { + hdr = buf_msg(skb); + mtyp = msg_type(hdr); + blks = msg_blocks(hdr); + deliver = true; + ack = false; + update = false; + + if (more(msg_grp_bc_seqno(hdr), m->bc_rcv_nxt)) + break; + + /* Decide what to do with message */ + switch (mtyp) { + case TIPC_GRP_MCAST_MSG: + if (msg_nameinst(hdr) != grp->instance) { + update = true; + deliver = false; + } + fallthrough; + case TIPC_GRP_BCAST_MSG: + m->bc_rcv_nxt++; + ack = msg_grp_bc_ack_req(hdr); + break; + case TIPC_GRP_UCAST_MSG: + break; + case TIPC_GRP_MEMBER_EVT: + if (m->state == MBR_LEAVING) + leave = true; + if (!grp->events) + deliver = false; + break; + default: + break; + } + + /* Execute decisions */ + __skb_dequeue(defq); + if (deliver) + __skb_queue_tail(inputq, skb); + else + kfree_skb(skb); + + if (ack) + tipc_group_proto_xmit(grp, m, GRP_ACK_MSG, xmitq); + + if (leave) { + __skb_queue_purge(defq); + tipc_group_delete_member(grp, m); + break; + } + if (!update) + continue; + + tipc_group_update_rcv_win(grp, blks, node, port, xmitq); + } + return; +drop: + kfree_skb(skb); +} + +void tipc_group_update_rcv_win(struct tipc_group *grp, int blks, u32 node, + u32 port, struct sk_buff_head *xmitq) +{ + struct list_head *active = &grp->active; + int max_active = grp->max_active; + int reclaim_limit = max_active * 3 / 4; + int active_cnt = grp->active_cnt; + struct tipc_member *m, *rm, *pm; + + m = tipc_group_find_member(grp, node, port); + if (!m) + return; + + m->advertised -= blks; + + switch (m->state) { + case MBR_JOINED: + /* First, decide if member can go active */ + if (active_cnt <= max_active) { + m->state = MBR_ACTIVE; + list_add_tail(&m->list, active); + 
grp->active_cnt++; + tipc_group_proto_xmit(grp, m, GRP_ADV_MSG, xmitq); + } else { + m->state = MBR_PENDING; + list_add_tail(&m->list, &grp->pending); + } + + if (active_cnt < reclaim_limit) + break; + + /* Reclaim from oldest active member, if possible */ + if (!list_empty(active)) { + rm = list_first_entry(active, struct tipc_member, list); + rm->state = MBR_RECLAIMING; + list_del_init(&rm->list); + tipc_group_proto_xmit(grp, rm, GRP_RECLAIM_MSG, xmitq); + break; + } + /* Nobody to reclaim from; - revert oldest pending to JOINED */ + pm = list_first_entry(&grp->pending, struct tipc_member, list); + list_del_init(&pm->list); + pm->state = MBR_JOINED; + tipc_group_proto_xmit(grp, pm, GRP_ADV_MSG, xmitq); + break; + case MBR_ACTIVE: + if (!list_is_last(&m->list, &grp->active)) + list_move_tail(&m->list, &grp->active); + if (m->advertised > (ADV_ACTIVE * 3 / 4)) + break; + tipc_group_proto_xmit(grp, m, GRP_ADV_MSG, xmitq); + break; + case MBR_REMITTED: + if (m->advertised > ADV_IDLE) + break; + m->state = MBR_JOINED; + grp->active_cnt--; + if (m->advertised < ADV_IDLE) { + pr_warn_ratelimited("Rcv unexpected msg after REMIT\n"); + tipc_group_proto_xmit(grp, m, GRP_ADV_MSG, xmitq); + } + + if (list_empty(&grp->pending)) + return; + + /* Set oldest pending member to active and advertise */ + pm = list_first_entry(&grp->pending, struct tipc_member, list); + pm->state = MBR_ACTIVE; + list_move_tail(&pm->list, &grp->active); + grp->active_cnt++; + tipc_group_proto_xmit(grp, pm, GRP_ADV_MSG, xmitq); + break; + case MBR_RECLAIMING: + case MBR_JOINING: + case MBR_LEAVING: + default: + break; + } +} + +static void tipc_group_create_event(struct tipc_group *grp, + struct tipc_member *m, + u32 event, u16 seqno, + struct sk_buff_head *inputq) +{ u32 dnode = tipc_own_addr(grp->net); + struct tipc_event evt; + struct sk_buff *skb; + struct tipc_msg *hdr; + + memset(&evt, 0, sizeof(evt)); + evt.event = event; + evt.found_lower = m->instance; + evt.found_upper = m->instance; + evt.port.ref = m->port; + evt.port.node = m->node; + evt.s.seq.type = grp->type; + evt.s.seq.lower = m->instance; + evt.s.seq.upper = m->instance; + + skb = tipc_msg_create(TIPC_CRITICAL_IMPORTANCE, TIPC_GRP_MEMBER_EVT, + GROUP_H_SIZE, sizeof(evt), dnode, m->node, + grp->portid, m->port, 0); + if (!skb) + return; + + hdr = buf_msg(skb); + msg_set_nametype(hdr, grp->type); + msg_set_grp_evt(hdr, event); + msg_set_dest_droppable(hdr, true); + msg_set_grp_bc_seqno(hdr, seqno); + memcpy(msg_data(hdr), &evt, sizeof(evt)); + TIPC_SKB_CB(skb)->orig_member = m->instance; + __skb_queue_tail(inputq, skb); +} + +static void tipc_group_proto_xmit(struct tipc_group *grp, struct tipc_member *m, + int mtyp, struct sk_buff_head *xmitq) +{ + struct tipc_msg *hdr; + struct sk_buff *skb; + int adv = 0; + + skb = tipc_msg_create(GROUP_PROTOCOL, mtyp, INT_H_SIZE, 0, + m->node, tipc_own_addr(grp->net), + m->port, grp->portid, 0); + if (!skb) + return; + + if (m->state == MBR_ACTIVE) + adv = ADV_ACTIVE - m->advertised; + else if (m->state == MBR_JOINED || m->state == MBR_PENDING) + adv = ADV_IDLE - m->advertised; + + hdr = buf_msg(skb); + + if (mtyp == GRP_JOIN_MSG) { + msg_set_grp_bc_syncpt(hdr, grp->bc_snd_nxt); + msg_set_adv_win(hdr, adv); + m->advertised += adv; + } else if (mtyp == GRP_LEAVE_MSG) { + msg_set_grp_bc_syncpt(hdr, grp->bc_snd_nxt); + } else if (mtyp == GRP_ADV_MSG) { + msg_set_adv_win(hdr, adv); + m->advertised += adv; + } else if (mtyp == GRP_ACK_MSG) { + msg_set_grp_bc_acked(hdr, m->bc_rcv_nxt); + } else if (mtyp == GRP_REMIT_MSG) { + 
msg_set_grp_remitted(hdr, m->window); + } + msg_set_dest_droppable(hdr, true); + __skb_queue_tail(xmitq, skb); +} + +void tipc_group_proto_rcv(struct tipc_group *grp, bool *usr_wakeup, + struct tipc_msg *hdr, struct sk_buff_head *inputq, + struct sk_buff_head *xmitq) +{ + u32 node = msg_orignode(hdr); + u32 port = msg_origport(hdr); + struct tipc_member *m, *pm; + u16 remitted, in_flight; + + if (!grp) + return; + + if (grp->scope == TIPC_NODE_SCOPE && node != tipc_own_addr(grp->net)) + return; + + m = tipc_group_find_member(grp, node, port); + + switch (msg_type(hdr)) { + case GRP_JOIN_MSG: + if (!m) + m = tipc_group_create_member(grp, node, port, + 0, MBR_JOINING); + if (!m) + return; + m->bc_syncpt = msg_grp_bc_syncpt(hdr); + m->bc_rcv_nxt = m->bc_syncpt; + m->window += msg_adv_win(hdr); + + /* Wait until PUBLISH event is received if necessary */ + if (m->state != MBR_PUBLISHED) + return; + + /* Member can be taken into service */ + m->state = MBR_JOINED; + tipc_group_open(m, usr_wakeup); + tipc_group_update_member(m, 0); + tipc_group_proto_xmit(grp, m, GRP_ADV_MSG, xmitq); + tipc_group_create_event(grp, m, TIPC_PUBLISHED, + m->bc_syncpt, inputq); + return; + case GRP_LEAVE_MSG: + if (!m) + return; + m->bc_syncpt = msg_grp_bc_syncpt(hdr); + list_del_init(&m->list); + tipc_group_open(m, usr_wakeup); + tipc_group_decr_active(grp, m); + m->state = MBR_LEAVING; + tipc_group_create_event(grp, m, TIPC_WITHDRAWN, + m->bc_syncpt, inputq); + return; + case GRP_ADV_MSG: + if (!m) + return; + m->window += msg_adv_win(hdr); + tipc_group_open(m, usr_wakeup); + return; + case GRP_ACK_MSG: + if (!m) + return; + m->bc_acked = msg_grp_bc_acked(hdr); + if (--grp->bc_ackers) + return; + list_del_init(&m->small_win); + *m->group->open = true; + *usr_wakeup = true; + tipc_group_update_member(m, 0); + return; + case GRP_RECLAIM_MSG: + if (!m) + return; + tipc_group_proto_xmit(grp, m, GRP_REMIT_MSG, xmitq); + m->window = ADV_IDLE; + tipc_group_open(m, usr_wakeup); + return; + case GRP_REMIT_MSG: + if (!m || m->state != MBR_RECLAIMING) + return; + + remitted = msg_grp_remitted(hdr); + + /* Messages preceding the REMIT still in receive queue */ + if (m->advertised > remitted) { + m->state = MBR_REMITTED; + in_flight = m->advertised - remitted; + m->advertised = ADV_IDLE + in_flight; + return; + } + /* This should never happen */ + if (m->advertised < remitted) + pr_warn_ratelimited("Unexpected REMIT msg\n"); + + /* All messages preceding the REMIT have been read */ + m->state = MBR_JOINED; + grp->active_cnt--; + m->advertised = ADV_IDLE; + + /* Set oldest pending member to active and advertise */ + if (list_empty(&grp->pending)) + return; + pm = list_first_entry(&grp->pending, struct tipc_member, list); + pm->state = MBR_ACTIVE; + list_move_tail(&pm->list, &grp->active); + grp->active_cnt++; + if (pm->advertised <= (ADV_ACTIVE * 3 / 4)) + tipc_group_proto_xmit(grp, pm, GRP_ADV_MSG, xmitq); + return; + default: + pr_warn("Received unknown GROUP_PROTO message\n"); + } +} + +/* tipc_group_member_evt() - receive and handle a member up/down event + */ +void tipc_group_member_evt(struct tipc_group *grp, + bool *usr_wakeup, + int *sk_rcvbuf, + struct tipc_msg *hdr, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq) +{ + struct tipc_event *evt = (void *)msg_data(hdr); + u32 instance = evt->found_lower; + u32 node = evt->port.node; + u32 port = evt->port.ref; + int event = evt->event; + struct tipc_member *m; + struct net *net; + u32 self; + + if (!grp) + return; + + net = grp->net; + self = 
tipc_own_addr(net); + if (!grp->loopback && node == self && port == grp->portid) + return; + + m = tipc_group_find_member(grp, node, port); + + switch (event) { + case TIPC_PUBLISHED: + /* Send and wait for arrival of JOIN message if necessary */ + if (!m) { + m = tipc_group_create_member(grp, node, port, instance, + MBR_PUBLISHED); + if (!m) + break; + tipc_group_update_member(m, 0); + tipc_group_proto_xmit(grp, m, GRP_JOIN_MSG, xmitq); + break; + } + + if (m->state != MBR_JOINING) + break; + + /* Member can be taken into service */ + m->instance = instance; + m->state = MBR_JOINED; + tipc_group_open(m, usr_wakeup); + tipc_group_update_member(m, 0); + tipc_group_proto_xmit(grp, m, GRP_JOIN_MSG, xmitq); + tipc_group_create_event(grp, m, TIPC_PUBLISHED, + m->bc_syncpt, inputq); + break; + case TIPC_WITHDRAWN: + if (!m) + break; + + tipc_group_decr_active(grp, m); + m->state = MBR_LEAVING; + list_del_init(&m->list); + tipc_group_open(m, usr_wakeup); + + /* Only send event if no LEAVE message can be expected */ + if (!tipc_node_is_up(net, node)) + tipc_group_create_event(grp, m, TIPC_WITHDRAWN, + m->bc_rcv_nxt, inputq); + break; + default: + break; + } + *sk_rcvbuf = tipc_group_rcvbuf_limit(grp); +} + +int tipc_group_fill_sock_diag(struct tipc_group *grp, struct sk_buff *skb) +{ + struct nlattr *group = nla_nest_start_noflag(skb, TIPC_NLA_SOCK_GROUP); + + if (!group) + return -EMSGSIZE; + + if (nla_put_u32(skb, TIPC_NLA_SOCK_GROUP_ID, + grp->type) || + nla_put_u32(skb, TIPC_NLA_SOCK_GROUP_INSTANCE, + grp->instance) || + nla_put_u32(skb, TIPC_NLA_SOCK_GROUP_BC_SEND_NEXT, + grp->bc_snd_nxt)) + goto group_msg_cancel; + + if (grp->scope == TIPC_NODE_SCOPE) + if (nla_put_flag(skb, TIPC_NLA_SOCK_GROUP_NODE_SCOPE)) + goto group_msg_cancel; + + if (grp->scope == TIPC_CLUSTER_SCOPE) + if (nla_put_flag(skb, TIPC_NLA_SOCK_GROUP_CLUSTER_SCOPE)) + goto group_msg_cancel; + + if (*grp->open) + if (nla_put_flag(skb, TIPC_NLA_SOCK_GROUP_OPEN)) + goto group_msg_cancel; + + nla_nest_end(skb, group); + return 0; + +group_msg_cancel: + nla_nest_cancel(skb, group); + return -1; +} diff --git a/net/tipc/group.h b/net/tipc/group.h new file mode 100644 index 000000000..ea4c3be64 --- /dev/null +++ b/net/tipc/group.h @@ -0,0 +1,77 @@ +/* + * net/tipc/group.h: Include file for TIPC group unicast/multicast functions + * + * Copyright (c) 2017, Ericsson AB + * Copyright (c) 2020, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_GROUP_H +#define _TIPC_GROUP_H + +#include "core.h" + +struct tipc_group; +struct tipc_member; +struct tipc_msg; + +struct tipc_group *tipc_group_create(struct net *net, u32 portid, + struct tipc_group_req *mreq, + bool *group_is_open); +void tipc_group_join(struct net *net, struct tipc_group *grp, int *sk_rcv_buf); +void tipc_group_delete(struct net *net, struct tipc_group *grp); +void tipc_group_add_member(struct tipc_group *grp, u32 node, + u32 port, u32 instance); +struct tipc_nlist *tipc_group_dests(struct tipc_group *grp); +void tipc_group_self(struct tipc_group *grp, struct tipc_service_range *seq, + int *scope); +u32 tipc_group_exclude(struct tipc_group *grp); +void tipc_group_filter_msg(struct tipc_group *grp, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq); +void tipc_group_member_evt(struct tipc_group *grp, bool *wakeup, + int *sk_rcvbuf, struct tipc_msg *hdr, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq); +void tipc_group_proto_rcv(struct tipc_group *grp, bool *wakeup, + struct tipc_msg *hdr, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq); +void tipc_group_update_bc_members(struct tipc_group *grp, int len, bool ack); +bool tipc_group_cong(struct tipc_group *grp, u32 dnode, u32 dport, + int len, struct tipc_member **m); +bool tipc_group_bc_cong(struct tipc_group *grp, int len); +void tipc_group_update_rcv_win(struct tipc_group *grp, int blks, u32 node, + u32 port, struct sk_buff_head *xmitq); +u16 tipc_group_bc_snd_nxt(struct tipc_group *grp); +void tipc_group_update_member(struct tipc_member *m, int len); +int tipc_group_fill_sock_diag(struct tipc_group *grp, struct sk_buff *skb); +#endif diff --git a/net/tipc/ib_media.c b/net/tipc/ib_media.c new file mode 100644 index 000000000..b9ad0434c --- /dev/null +++ b/net/tipc/ib_media.c @@ -0,0 +1,104 @@ +/* + * net/tipc/ib_media.c: Infiniband bearer support for TIPC + * + * Copyright (c) 2013 Patrick McHardy + * + * Based on eth_media.c, which carries the following copyright notice: + * + * Copyright (c) 2001-2007, Ericsson AB + * Copyright (c) 2005-2008, 2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. 
Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include "core.h" +#include "bearer.h" + +#define TIPC_MAX_IB_LINK_WIN 500 + +/* convert InfiniBand address (media address format) media address to string */ +static int tipc_ib_addr2str(struct tipc_media_addr *a, char *str_buf, + int str_size) +{ + if (str_size < 60) /* 60 = 19 * strlen("xx:") + strlen("xx\0") */ + return 1; + + sprintf(str_buf, "%20phC", a->value); + + return 0; +} + +/* Convert from media address format to discovery message addr format */ +static int tipc_ib_addr2msg(char *msg, struct tipc_media_addr *addr) +{ + memset(msg, 0, TIPC_MEDIA_INFO_SIZE); + memcpy(msg, addr->value, INFINIBAND_ALEN); + return 0; +} + +/* Convert raw InfiniBand address format to media addr format */ +static int tipc_ib_raw2addr(struct tipc_bearer *b, + struct tipc_media_addr *addr, + const char *msg) +{ + memset(addr, 0, sizeof(*addr)); + memcpy(addr->value, msg, INFINIBAND_ALEN); + addr->media_id = TIPC_MEDIA_TYPE_IB; + addr->broadcast = !memcmp(msg, b->bcast_addr.value, + INFINIBAND_ALEN); + return 0; +} + +/* Convert discovery msg addr format to InfiniBand media addr format */ +static int tipc_ib_msg2addr(struct tipc_bearer *b, + struct tipc_media_addr *addr, + char *msg) +{ + return tipc_ib_raw2addr(b, addr, msg); +} + +/* InfiniBand media registration info */ +struct tipc_media ib_media_info = { + .send_msg = tipc_l2_send_msg, + .enable_media = tipc_enable_l2_media, + .disable_media = tipc_disable_l2_media, + .addr2str = tipc_ib_addr2str, + .addr2msg = tipc_ib_addr2msg, + .msg2addr = tipc_ib_msg2addr, + .raw2addr = tipc_ib_raw2addr, + .priority = TIPC_DEF_LINK_PRI, + .tolerance = TIPC_DEF_LINK_TOL, + .min_win = TIPC_DEF_LINK_WIN, + .max_win = TIPC_MAX_IB_LINK_WIN, + .type_id = TIPC_MEDIA_TYPE_IB, + .hwaddr_len = INFINIBAND_ALEN, + .name = "ib" +}; diff --git a/net/tipc/link.c b/net/tipc/link.c new file mode 100644 index 000000000..8715c9b05 --- /dev/null +++ b/net/tipc/link.c @@ -0,0 +1,3009 @@ +/* + * net/tipc/link.c: TIPC link code + * + * Copyright (c) 1996-2007, 2012-2016, Ericsson AB + * Copyright (c) 2004-2007, 2010-2013, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "subscr.h" +#include "link.h" +#include "bcast.h" +#include "socket.h" +#include "name_distr.h" +#include "discover.h" +#include "netlink.h" +#include "monitor.h" +#include "trace.h" +#include "crypto.h" + +#include + +struct tipc_stats { + u32 sent_pkts; + u32 recv_pkts; + u32 sent_states; + u32 recv_states; + u32 sent_probes; + u32 recv_probes; + u32 sent_nacks; + u32 recv_nacks; + u32 sent_acks; + u32 sent_bundled; + u32 sent_bundles; + u32 recv_bundled; + u32 recv_bundles; + u32 retransmitted; + u32 sent_fragmented; + u32 sent_fragments; + u32 recv_fragmented; + u32 recv_fragments; + u32 link_congs; /* # port sends blocked by congestion */ + u32 deferred_recv; + u32 duplicates; + u32 max_queue_sz; /* send queue size high water mark */ + u32 accu_queue_sz; /* used for send queue size profiling */ + u32 queue_sz_counts; /* used for send queue size profiling */ + u32 msg_length_counts; /* used for message length profiling */ + u32 msg_lengths_total; /* used for message length profiling */ + u32 msg_length_profile[7]; /* used for msg. 
length profiling */ +}; + +/** + * struct tipc_link - TIPC link data structure + * @addr: network address of link's peer node + * @name: link name character string + * @media_addr: media address to use when sending messages over link + * @timer: link timer + * @net: pointer to namespace struct + * @refcnt: reference counter for permanent references (owner node & timer) + * @peer_session: link session # being used by peer end of link + * @peer_bearer_id: bearer id used by link's peer endpoint + * @bearer_id: local bearer id used by link + * @tolerance: minimum link continuity loss needed to reset link [in ms] + * @abort_limit: # of unacknowledged continuity probes needed to reset link + * @state: current state of link FSM + * @peer_caps: bitmap describing capabilities of peer node + * @silent_intv_cnt: # of timer intervals without any reception from peer + * @proto_msg: template for control messages generated by link + * @pmsg: convenience pointer to "proto_msg" field + * @priority: current link priority + * @net_plane: current link network plane ('A' through 'H') + * @mon_state: cookie with information needed by link monitor + * @backlog_limit: backlog queue congestion thresholds (indexed by importance) + * @exp_msg_count: # of tunnelled messages expected during link changeover + * @reset_rcv_checkpt: seq # of last acknowledged message at time of link reset + * @mtu: current maximum packet size for this link + * @advertised_mtu: advertised own mtu when link is being established + * @transmitq: queue for sent, non-acked messages + * @backlogq: queue for messages waiting to be sent + * @snt_nxt: next sequence number to use for outbound messages + * @ackers: # of peers that needs to ack each packet before it can be released + * @acked: # last packet acked by a certain peer. Used for broadcast. 
+ * @rcv_nxt: next sequence number to expect for inbound messages + * @deferred_queue: deferred queue saved OOS b'cast message received from node + * @unacked_window: # of inbound messages rx'd without ack'ing back to peer + * @inputq: buffer queue for messages to be delivered upwards + * @namedq: buffer queue for name table messages to be delivered upwards + * @next_out: ptr to first unsent outbound message in queue + * @wakeupq: linked list of wakeup msgs waiting for link congestion to abate + * @long_msg_seq_no: next identifier to use for outbound fragmented messages + * @reasm_buf: head of partially reassembled inbound message fragments + * @bc_rcvr: marks that this is a broadcast receiver link + * @stats: collects statistics regarding link activity + * @session: session to be used by link + * @snd_nxt_state: next send seq number + * @rcv_nxt_state: next rcv seq number + * @in_session: have received ACTIVATE_MSG from peer + * @active: link is active + * @if_name: associated interface name + * @rst_cnt: link reset counter + * @drop_point: seq number for failover handling (FIXME) + * @failover_reasm_skb: saved failover msg ptr (FIXME) + * @failover_deferdq: deferred message queue for failover processing (FIXME) + * @transmq: the link's transmit queue + * @backlog: link's backlog by priority (importance) + * @snd_nxt: next sequence number to be used + * @rcv_unacked: # messages read by user, but not yet acked back to peer + * @deferdq: deferred receive queue + * @window: sliding window size for congestion handling + * @min_win: minimal send window to be used by link + * @ssthresh: slow start threshold for congestion handling + * @max_win: maximal send window to be used by link + * @cong_acks: congestion acks for congestion avoidance (FIXME) + * @checkpoint: seq number for congestion window size handling + * @reasm_tnlmsg: fragmentation/reassembly area for tunnel protocol message + * @last_gap: last gap ack blocks for bcast (FIXME) + * @last_ga: ptr to gap ack blocks + * @bc_rcvlink: the peer specific link used for broadcast reception + * @bc_sndlink: the namespace global link used for broadcast sending + * @nack_state: bcast nack state + * @bc_peer_is_up: peer has acked the bcast init msg + */ +struct tipc_link { + u32 addr; + char name[TIPC_MAX_LINK_NAME]; + struct net *net; + + /* Management and link supervision data */ + u16 peer_session; + u16 session; + u16 snd_nxt_state; + u16 rcv_nxt_state; + u32 peer_bearer_id; + u32 bearer_id; + u32 tolerance; + u32 abort_limit; + u32 state; + u16 peer_caps; + bool in_session; + bool active; + u32 silent_intv_cnt; + char if_name[TIPC_MAX_IF_NAME]; + u32 priority; + char net_plane; + struct tipc_mon_state mon_state; + u16 rst_cnt; + + /* Failover/synch */ + u16 drop_point; + struct sk_buff *failover_reasm_skb; + struct sk_buff_head failover_deferdq; + + /* Max packet negotiation */ + u16 mtu; + u16 advertised_mtu; + + /* Sending */ + struct sk_buff_head transmq; + struct sk_buff_head backlogq; + struct { + u16 len; + u16 limit; + struct sk_buff *target_bskb; + } backlog[5]; + u16 snd_nxt; + + /* Reception */ + u16 rcv_nxt; + u32 rcv_unacked; + struct sk_buff_head deferdq; + struct sk_buff_head *inputq; + struct sk_buff_head *namedq; + + /* Congestion handling */ + struct sk_buff_head wakeupq; + u16 window; + u16 min_win; + u16 ssthresh; + u16 max_win; + u16 cong_acks; + u16 checkpoint; + + /* Fragmentation/reassembly */ + struct sk_buff *reasm_buf; + struct sk_buff *reasm_tnlmsg; + + /* Broadcast */ + u16 ackers; + u16 acked; + u16 last_gap; + 
struct tipc_gap_ack_blks *last_ga; + struct tipc_link *bc_rcvlink; + struct tipc_link *bc_sndlink; + u8 nack_state; + bool bc_peer_is_up; + + /* Statistics */ + struct tipc_stats stats; +}; + +/* + * Error message prefixes + */ +static const char *link_co_err = "Link tunneling error, "; +static const char *link_rst_msg = "Resetting link "; + +/* Send states for broadcast NACKs + */ +enum { + BC_NACK_SND_CONDITIONAL, + BC_NACK_SND_UNCONDITIONAL, + BC_NACK_SND_SUPPRESS, +}; + +#define TIPC_BC_RETR_LIM (jiffies + msecs_to_jiffies(10)) +#define TIPC_UC_RETR_TIME (jiffies + msecs_to_jiffies(1)) + +/* Link FSM states: + */ +enum { + LINK_ESTABLISHED = 0xe, + LINK_ESTABLISHING = 0xe << 4, + LINK_RESET = 0x1 << 8, + LINK_RESETTING = 0x2 << 12, + LINK_PEER_RESET = 0xd << 16, + LINK_FAILINGOVER = 0xf << 20, + LINK_SYNCHING = 0xc << 24 +}; + +/* Link FSM state checking routines + */ +static int link_is_up(struct tipc_link *l) +{ + return l->state & (LINK_ESTABLISHED | LINK_SYNCHING); +} + +static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq); +static void tipc_link_build_proto_msg(struct tipc_link *l, int mtyp, bool probe, + bool probe_reply, u16 rcvgap, + int tolerance, int priority, + struct sk_buff_head *xmitq); +static void link_print(struct tipc_link *l, const char *str); +static int tipc_link_build_nack_msg(struct tipc_link *l, + struct sk_buff_head *xmitq); +static void tipc_link_build_bc_init_msg(struct tipc_link *l, + struct sk_buff_head *xmitq); +static u8 __tipc_build_gap_ack_blks(struct tipc_gap_ack_blks *ga, + struct tipc_link *l, u8 start_index); +static u16 tipc_build_gap_ack_blks(struct tipc_link *l, struct tipc_msg *hdr); +static int tipc_link_advance_transmq(struct tipc_link *l, struct tipc_link *r, + u16 acked, u16 gap, + struct tipc_gap_ack_blks *ga, + struct sk_buff_head *xmitq, + bool *retransmitted, int *rc); +static void tipc_link_update_cwin(struct tipc_link *l, int released, + bool retransmitted); +/* + * Simple non-static link routines (i.e. 
referenced outside this file) + */ +bool tipc_link_is_up(struct tipc_link *l) +{ + return link_is_up(l); +} + +bool tipc_link_peer_is_down(struct tipc_link *l) +{ + return l->state == LINK_PEER_RESET; +} + +bool tipc_link_is_reset(struct tipc_link *l) +{ + return l->state & (LINK_RESET | LINK_FAILINGOVER | LINK_ESTABLISHING); +} + +bool tipc_link_is_establishing(struct tipc_link *l) +{ + return l->state == LINK_ESTABLISHING; +} + +bool tipc_link_is_synching(struct tipc_link *l) +{ + return l->state == LINK_SYNCHING; +} + +bool tipc_link_is_failingover(struct tipc_link *l) +{ + return l->state == LINK_FAILINGOVER; +} + +bool tipc_link_is_blocked(struct tipc_link *l) +{ + return l->state & (LINK_RESETTING | LINK_PEER_RESET | LINK_FAILINGOVER); +} + +static bool link_is_bc_sndlink(struct tipc_link *l) +{ + return !l->bc_sndlink; +} + +static bool link_is_bc_rcvlink(struct tipc_link *l) +{ + return ((l->bc_rcvlink == l) && !link_is_bc_sndlink(l)); +} + +void tipc_link_set_active(struct tipc_link *l, bool active) +{ + l->active = active; +} + +u32 tipc_link_id(struct tipc_link *l) +{ + return l->peer_bearer_id << 16 | l->bearer_id; +} + +int tipc_link_min_win(struct tipc_link *l) +{ + return l->min_win; +} + +int tipc_link_max_win(struct tipc_link *l) +{ + return l->max_win; +} + +int tipc_link_prio(struct tipc_link *l) +{ + return l->priority; +} + +unsigned long tipc_link_tolerance(struct tipc_link *l) +{ + return l->tolerance; +} + +struct sk_buff_head *tipc_link_inputq(struct tipc_link *l) +{ + return l->inputq; +} + +char tipc_link_plane(struct tipc_link *l) +{ + return l->net_plane; +} + +struct net *tipc_link_net(struct tipc_link *l) +{ + return l->net; +} + +void tipc_link_update_caps(struct tipc_link *l, u16 capabilities) +{ + l->peer_caps = capabilities; +} + +void tipc_link_add_bc_peer(struct tipc_link *snd_l, + struct tipc_link *uc_l, + struct sk_buff_head *xmitq) +{ + struct tipc_link *rcv_l = uc_l->bc_rcvlink; + + snd_l->ackers++; + rcv_l->acked = snd_l->snd_nxt - 1; + snd_l->state = LINK_ESTABLISHED; + tipc_link_build_bc_init_msg(uc_l, xmitq); +} + +void tipc_link_remove_bc_peer(struct tipc_link *snd_l, + struct tipc_link *rcv_l, + struct sk_buff_head *xmitq) +{ + u16 ack = snd_l->snd_nxt - 1; + + snd_l->ackers--; + rcv_l->bc_peer_is_up = true; + rcv_l->state = LINK_ESTABLISHED; + tipc_link_bc_ack_rcv(rcv_l, ack, 0, NULL, xmitq, NULL); + trace_tipc_link_reset(rcv_l, TIPC_DUMP_ALL, "bclink removed!"); + tipc_link_reset(rcv_l); + rcv_l->state = LINK_RESET; + if (!snd_l->ackers) { + trace_tipc_link_reset(snd_l, TIPC_DUMP_ALL, "zero ackers!"); + tipc_link_reset(snd_l); + snd_l->state = LINK_RESET; + __skb_queue_purge(xmitq); + } +} + +int tipc_link_bc_peers(struct tipc_link *l) +{ + return l->ackers; +} + +static u16 link_bc_rcv_gap(struct tipc_link *l) +{ + struct sk_buff *skb = skb_peek(&l->deferdq); + u16 gap = 0; + + if (more(l->snd_nxt, l->rcv_nxt)) + gap = l->snd_nxt - l->rcv_nxt; + if (skb) + gap = buf_seqno(skb) - l->rcv_nxt; + return gap; +} + +void tipc_link_set_mtu(struct tipc_link *l, int mtu) +{ + l->mtu = mtu; +} + +int tipc_link_mtu(struct tipc_link *l) +{ + return l->mtu; +} + +int tipc_link_mss(struct tipc_link *l) +{ +#ifdef CONFIG_TIPC_CRYPTO + return l->mtu - INT_H_SIZE - EMSG_OVERHEAD; +#else + return l->mtu - INT_H_SIZE; +#endif +} + +u16 tipc_link_rcv_nxt(struct tipc_link *l) +{ + return l->rcv_nxt; +} + +u16 tipc_link_acked(struct tipc_link *l) +{ + return l->acked; +} + +char *tipc_link_name(struct tipc_link *l) +{ + return l->name; +} + +u32 
tipc_link_state(struct tipc_link *l) +{ + return l->state; +} + +/** + * tipc_link_create - create a new link + * @net: pointer to associated network namespace + * @if_name: associated interface name + * @bearer_id: id (index) of associated bearer + * @tolerance: link tolerance to be used by link + * @net_plane: network plane (A,B,c..) this link belongs to + * @mtu: mtu to be advertised by link + * @priority: priority to be used by link + * @min_win: minimal send window to be used by link + * @max_win: maximal send window to be used by link + * @session: session to be used by link + * @peer: node id of peer node + * @peer_caps: bitmap describing peer node capabilities + * @bc_sndlink: the namespace global link used for broadcast sending + * @bc_rcvlink: the peer specific link used for broadcast reception + * @inputq: queue to put messages ready for delivery + * @namedq: queue to put binding table update messages ready for delivery + * @link: return value, pointer to put the created link + * @self: local unicast link id + * @peer_id: 128-bit ID of peer + * + * Return: true if link was created, otherwise false + */ +bool tipc_link_create(struct net *net, char *if_name, int bearer_id, + int tolerance, char net_plane, u32 mtu, int priority, + u32 min_win, u32 max_win, u32 session, u32 self, + u32 peer, u8 *peer_id, u16 peer_caps, + struct tipc_link *bc_sndlink, + struct tipc_link *bc_rcvlink, + struct sk_buff_head *inputq, + struct sk_buff_head *namedq, + struct tipc_link **link) +{ + char peer_str[NODE_ID_STR_LEN] = {0,}; + char self_str[NODE_ID_STR_LEN] = {0,}; + struct tipc_link *l; + + l = kzalloc(sizeof(*l), GFP_ATOMIC); + if (!l) + return false; + *link = l; + l->session = session; + + /* Set link name for unicast links only */ + if (peer_id) { + tipc_nodeid2string(self_str, tipc_own_id(net)); + if (strlen(self_str) > 16) + sprintf(self_str, "%x", self); + tipc_nodeid2string(peer_str, peer_id); + if (strlen(peer_str) > 16) + sprintf(peer_str, "%x", peer); + } + /* Peer i/f name will be completed by reset/activate message */ + snprintf(l->name, sizeof(l->name), "%s:%s-%s:unknown", + self_str, if_name, peer_str); + + strcpy(l->if_name, if_name); + l->addr = peer; + l->peer_caps = peer_caps; + l->net = net; + l->in_session = false; + l->bearer_id = bearer_id; + l->tolerance = tolerance; + if (bc_rcvlink) + bc_rcvlink->tolerance = tolerance; + l->net_plane = net_plane; + l->advertised_mtu = mtu; + l->mtu = mtu; + l->priority = priority; + tipc_link_set_queue_limits(l, min_win, max_win); + l->ackers = 1; + l->bc_sndlink = bc_sndlink; + l->bc_rcvlink = bc_rcvlink; + l->inputq = inputq; + l->namedq = namedq; + l->state = LINK_RESETTING; + __skb_queue_head_init(&l->transmq); + __skb_queue_head_init(&l->backlogq); + __skb_queue_head_init(&l->deferdq); + __skb_queue_head_init(&l->failover_deferdq); + skb_queue_head_init(&l->wakeupq); + skb_queue_head_init(l->inputq); + return true; +} + +/** + * tipc_link_bc_create - create new link to be used for broadcast + * @net: pointer to associated network namespace + * @mtu: mtu to be used initially if no peers + * @min_win: minimal send window to be used by link + * @max_win: maximal send window to be used by link + * @inputq: queue to put messages ready for delivery + * @namedq: queue to put binding table update messages ready for delivery + * @link: return value, pointer to put the created link + * @ownnode: identity of own node + * @peer: node id of peer node + * @peer_id: 128-bit ID of peer + * @peer_caps: bitmap describing peer node capabilities + * 
@bc_sndlink: the namespace global link used for broadcast sending + * + * Return: true if link was created, otherwise false + */ +bool tipc_link_bc_create(struct net *net, u32 ownnode, u32 peer, u8 *peer_id, + int mtu, u32 min_win, u32 max_win, u16 peer_caps, + struct sk_buff_head *inputq, + struct sk_buff_head *namedq, + struct tipc_link *bc_sndlink, + struct tipc_link **link) +{ + struct tipc_link *l; + + if (!tipc_link_create(net, "", MAX_BEARERS, 0, 'Z', mtu, 0, min_win, + max_win, 0, ownnode, peer, NULL, peer_caps, + bc_sndlink, NULL, inputq, namedq, link)) + return false; + + l = *link; + if (peer_id) { + char peer_str[NODE_ID_STR_LEN] = {0,}; + + tipc_nodeid2string(peer_str, peer_id); + if (strlen(peer_str) > 16) + sprintf(peer_str, "%x", peer); + /* Broadcast receiver link name: "broadcast-link:" */ + snprintf(l->name, sizeof(l->name), "%s:%s", tipc_bclink_name, + peer_str); + } else { + strcpy(l->name, tipc_bclink_name); + } + trace_tipc_link_reset(l, TIPC_DUMP_ALL, "bclink created!"); + tipc_link_reset(l); + l->state = LINK_RESET; + l->ackers = 0; + l->bc_rcvlink = l; + + /* Broadcast send link is always up */ + if (link_is_bc_sndlink(l)) + l->state = LINK_ESTABLISHED; + + /* Disable replicast if even a single peer doesn't support it */ + if (link_is_bc_rcvlink(l) && !(peer_caps & TIPC_BCAST_RCAST)) + tipc_bcast_toggle_rcast(net, false); + + return true; +} + +/** + * tipc_link_fsm_evt - link finite state machine + * @l: pointer to link + * @evt: state machine event to be processed + */ +int tipc_link_fsm_evt(struct tipc_link *l, int evt) +{ + int rc = 0; + int old_state = l->state; + + switch (l->state) { + case LINK_RESETTING: + switch (evt) { + case LINK_PEER_RESET_EVT: + l->state = LINK_PEER_RESET; + break; + case LINK_RESET_EVT: + l->state = LINK_RESET; + break; + case LINK_FAILURE_EVT: + case LINK_FAILOVER_BEGIN_EVT: + case LINK_ESTABLISH_EVT: + case LINK_FAILOVER_END_EVT: + case LINK_SYNCH_BEGIN_EVT: + case LINK_SYNCH_END_EVT: + default: + goto illegal_evt; + } + break; + case LINK_RESET: + switch (evt) { + case LINK_PEER_RESET_EVT: + l->state = LINK_ESTABLISHING; + break; + case LINK_FAILOVER_BEGIN_EVT: + l->state = LINK_FAILINGOVER; + break; + case LINK_FAILURE_EVT: + case LINK_RESET_EVT: + case LINK_ESTABLISH_EVT: + case LINK_FAILOVER_END_EVT: + break; + case LINK_SYNCH_BEGIN_EVT: + case LINK_SYNCH_END_EVT: + default: + goto illegal_evt; + } + break; + case LINK_PEER_RESET: + switch (evt) { + case LINK_RESET_EVT: + l->state = LINK_ESTABLISHING; + break; + case LINK_PEER_RESET_EVT: + case LINK_ESTABLISH_EVT: + case LINK_FAILURE_EVT: + break; + case LINK_SYNCH_BEGIN_EVT: + case LINK_SYNCH_END_EVT: + case LINK_FAILOVER_BEGIN_EVT: + case LINK_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case LINK_FAILINGOVER: + switch (evt) { + case LINK_FAILOVER_END_EVT: + l->state = LINK_RESET; + break; + case LINK_PEER_RESET_EVT: + case LINK_RESET_EVT: + case LINK_ESTABLISH_EVT: + case LINK_FAILURE_EVT: + break; + case LINK_FAILOVER_BEGIN_EVT: + case LINK_SYNCH_BEGIN_EVT: + case LINK_SYNCH_END_EVT: + default: + goto illegal_evt; + } + break; + case LINK_ESTABLISHING: + switch (evt) { + case LINK_ESTABLISH_EVT: + l->state = LINK_ESTABLISHED; + break; + case LINK_FAILOVER_BEGIN_EVT: + l->state = LINK_FAILINGOVER; + break; + case LINK_RESET_EVT: + l->state = LINK_RESET; + break; + case LINK_FAILURE_EVT: + case LINK_PEER_RESET_EVT: + case LINK_SYNCH_BEGIN_EVT: + case LINK_FAILOVER_END_EVT: + break; + case LINK_SYNCH_END_EVT: + default: + goto illegal_evt; + } + break; + 
case LINK_ESTABLISHED: + switch (evt) { + case LINK_PEER_RESET_EVT: + l->state = LINK_PEER_RESET; + rc |= TIPC_LINK_DOWN_EVT; + break; + case LINK_FAILURE_EVT: + l->state = LINK_RESETTING; + rc |= TIPC_LINK_DOWN_EVT; + break; + case LINK_RESET_EVT: + l->state = LINK_RESET; + break; + case LINK_ESTABLISH_EVT: + case LINK_SYNCH_END_EVT: + break; + case LINK_SYNCH_BEGIN_EVT: + l->state = LINK_SYNCHING; + break; + case LINK_FAILOVER_BEGIN_EVT: + case LINK_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case LINK_SYNCHING: + switch (evt) { + case LINK_PEER_RESET_EVT: + l->state = LINK_PEER_RESET; + rc |= TIPC_LINK_DOWN_EVT; + break; + case LINK_FAILURE_EVT: + l->state = LINK_RESETTING; + rc |= TIPC_LINK_DOWN_EVT; + break; + case LINK_RESET_EVT: + l->state = LINK_RESET; + break; + case LINK_ESTABLISH_EVT: + case LINK_SYNCH_BEGIN_EVT: + break; + case LINK_SYNCH_END_EVT: + l->state = LINK_ESTABLISHED; + break; + case LINK_FAILOVER_BEGIN_EVT: + case LINK_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + default: + pr_err("Unknown FSM state %x in %s\n", l->state, l->name); + } + trace_tipc_link_fsm(l->name, old_state, l->state, evt); + return rc; +illegal_evt: + pr_err("Illegal FSM event %x in state %x on link %s\n", + evt, l->state, l->name); + trace_tipc_link_fsm(l->name, old_state, l->state, evt); + return rc; +} + +/* link_profile_stats - update statistical profiling of traffic + */ +static void link_profile_stats(struct tipc_link *l) +{ + struct sk_buff *skb; + struct tipc_msg *msg; + int length; + + /* Update counters used in statistical profiling of send traffic */ + l->stats.accu_queue_sz += skb_queue_len(&l->transmq); + l->stats.queue_sz_counts++; + + skb = skb_peek(&l->transmq); + if (!skb) + return; + msg = buf_msg(skb); + length = msg_size(msg); + + if (msg_user(msg) == MSG_FRAGMENTER) { + if (msg_type(msg) != FIRST_FRAGMENT) + return; + length = msg_size(msg_inner_hdr(msg)); + } + l->stats.msg_lengths_total += length; + l->stats.msg_length_counts++; + if (length <= 64) + l->stats.msg_length_profile[0]++; + else if (length <= 256) + l->stats.msg_length_profile[1]++; + else if (length <= 1024) + l->stats.msg_length_profile[2]++; + else if (length <= 4096) + l->stats.msg_length_profile[3]++; + else if (length <= 16384) + l->stats.msg_length_profile[4]++; + else if (length <= 32768) + l->stats.msg_length_profile[5]++; + else + l->stats.msg_length_profile[6]++; +} + +/** + * tipc_link_too_silent - check if link is "too silent" + * @l: tipc link to be checked + * + * Return: true if the link 'silent_intv_cnt' is about to reach the + * 'abort_limit' value, otherwise false + */ +bool tipc_link_too_silent(struct tipc_link *l) +{ + return (l->silent_intv_cnt + 2 > l->abort_limit); +} + +/* tipc_link_timeout - perform periodic task as instructed from node timeout + */ +int tipc_link_timeout(struct tipc_link *l, struct sk_buff_head *xmitq) +{ + int mtyp = 0; + int rc = 0; + bool state = false; + bool probe = false; + bool setup = false; + u16 bc_snt = l->bc_sndlink->snd_nxt - 1; + u16 bc_acked = l->bc_rcvlink->acked; + struct tipc_mon_state *mstate = &l->mon_state; + + trace_tipc_link_timeout(l, TIPC_DUMP_NONE, " "); + trace_tipc_link_too_silent(l, TIPC_DUMP_ALL, " "); + switch (l->state) { + case LINK_ESTABLISHED: + case LINK_SYNCHING: + mtyp = STATE_MSG; + link_profile_stats(l); + tipc_mon_get_state(l->net, l->addr, mstate, l->bearer_id); + if (mstate->reset || (l->silent_intv_cnt > l->abort_limit)) + return tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + state = bc_acked 
!= bc_snt; + state |= l->bc_rcvlink->rcv_unacked; + state |= l->rcv_unacked; + state |= !skb_queue_empty(&l->transmq); + probe = mstate->probing; + probe |= l->silent_intv_cnt; + if (probe || mstate->monitoring) + l->silent_intv_cnt++; + probe |= !skb_queue_empty(&l->deferdq); + if (l->snd_nxt == l->checkpoint) { + tipc_link_update_cwin(l, 0, 0); + probe = true; + } + l->checkpoint = l->snd_nxt; + break; + case LINK_RESET: + setup = l->rst_cnt++ <= 4; + setup |= !(l->rst_cnt % 16); + mtyp = RESET_MSG; + break; + case LINK_ESTABLISHING: + setup = true; + mtyp = ACTIVATE_MSG; + break; + case LINK_PEER_RESET: + case LINK_RESETTING: + case LINK_FAILINGOVER: + break; + default: + break; + } + + if (state || probe || setup) + tipc_link_build_proto_msg(l, mtyp, probe, 0, 0, 0, 0, xmitq); + + return rc; +} + +/** + * link_schedule_user - schedule a message sender for wakeup after congestion + * @l: congested link + * @hdr: header of message that is being sent + * Create pseudo msg to send back to user when congestion abates + */ +static int link_schedule_user(struct tipc_link *l, struct tipc_msg *hdr) +{ + u32 dnode = tipc_own_addr(l->net); + u32 dport = msg_origport(hdr); + struct sk_buff *skb; + + /* Create and schedule wakeup pseudo message */ + skb = tipc_msg_create(SOCK_WAKEUP, 0, INT_H_SIZE, 0, + dnode, l->addr, dport, 0, 0); + if (!skb) + return -ENOBUFS; + msg_set_dest_droppable(buf_msg(skb), true); + TIPC_SKB_CB(skb)->chain_imp = msg_importance(hdr); + skb_queue_tail(&l->wakeupq, skb); + l->stats.link_congs++; + trace_tipc_link_conges(l, TIPC_DUMP_ALL, "wakeup scheduled!"); + return -ELINKCONG; +} + +/** + * link_prepare_wakeup - prepare users for wakeup after congestion + * @l: congested link + * Wake up a number of waiting users, as permitted by available space + * in the send queue + */ +static void link_prepare_wakeup(struct tipc_link *l) +{ + struct sk_buff_head *wakeupq = &l->wakeupq; + struct sk_buff_head *inputq = l->inputq; + struct sk_buff *skb, *tmp; + struct sk_buff_head tmpq; + int avail[5] = {0,}; + int imp = 0; + + __skb_queue_head_init(&tmpq); + + for (; imp <= TIPC_SYSTEM_IMPORTANCE; imp++) + avail[imp] = l->backlog[imp].limit - l->backlog[imp].len; + + skb_queue_walk_safe(wakeupq, skb, tmp) { + imp = TIPC_SKB_CB(skb)->chain_imp; + if (avail[imp] <= 0) + continue; + avail[imp]--; + __skb_unlink(skb, wakeupq); + __skb_queue_tail(&tmpq, skb); + } + + spin_lock_bh(&inputq->lock); + skb_queue_splice_tail(&tmpq, inputq); + spin_unlock_bh(&inputq->lock); + +} + +/** + * tipc_link_set_skb_retransmit_time - set the time at which retransmission of + * the given skb should be next attempted + * @skb: skb to set a future retransmission time for + * @l: link the skb will be transmitted on + */ +static void tipc_link_set_skb_retransmit_time(struct sk_buff *skb, + struct tipc_link *l) +{ + if (link_is_bc_sndlink(l)) + TIPC_SKB_CB(skb)->nxt_retr = TIPC_BC_RETR_LIM; + else + TIPC_SKB_CB(skb)->nxt_retr = TIPC_UC_RETR_TIME; +} + +void tipc_link_reset(struct tipc_link *l) +{ + struct sk_buff_head list; + u32 imp; + + __skb_queue_head_init(&list); + + l->in_session = false; + /* Force re-synch of peer session number before establishing */ + l->peer_session--; + l->session++; + l->mtu = l->advertised_mtu; + + spin_lock_bh(&l->wakeupq.lock); + skb_queue_splice_init(&l->wakeupq, &list); + spin_unlock_bh(&l->wakeupq.lock); + + spin_lock_bh(&l->inputq->lock); + skb_queue_splice_init(&list, l->inputq); + spin_unlock_bh(&l->inputq->lock); + + __skb_queue_purge(&l->transmq); + 
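link_prepare_wakeup() above wakes blocked senders only while their importance level still has headroom, i.e. at most (limit - current backlog length) waiters per level, and skips levels that are still congested. The following stand-alone sketch reproduces just that budgeting loop with plain arrays; the limits, backlog lengths and the list of waiting importance levels are made-up sample data.

#include <stdio.h>

#define LEVELS 5

int main(void)
{
	int limit[LEVELS]   = { 100, 200, 300, 400, 150 };
	int backlog[LEVELS] = {  98, 150, 300, 120, 149 };
	int waiting[] = { 0, 1, 2, 3, 4, 0, 1 };  /* importance of each queued waiter */
	int avail[LEVELS];
	int woken = 0;

	/* Per-level budget: how many waiters this level may still admit */
	for (int i = 0; i < LEVELS; i++)
		avail[i] = limit[i] - backlog[i];

	for (unsigned int i = 0; i < sizeof(waiting) / sizeof(waiting[0]); i++) {
		int imp = waiting[i];

		if (avail[imp] <= 0)
			continue;          /* this level is still congested */
		avail[imp]--;
		woken++;
	}
	printf("woken: %d\n", woken);      /* prints 6 of 7: level 2 has no headroom */
	return 0;
}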
__skb_queue_purge(&l->deferdq); + __skb_queue_purge(&l->backlogq); + __skb_queue_purge(&l->failover_deferdq); + for (imp = 0; imp <= TIPC_SYSTEM_IMPORTANCE; imp++) { + l->backlog[imp].len = 0; + l->backlog[imp].target_bskb = NULL; + } + kfree_skb(l->reasm_buf); + kfree_skb(l->reasm_tnlmsg); + kfree_skb(l->failover_reasm_skb); + l->reasm_buf = NULL; + l->reasm_tnlmsg = NULL; + l->failover_reasm_skb = NULL; + l->rcv_unacked = 0; + l->snd_nxt = 1; + l->rcv_nxt = 1; + l->snd_nxt_state = 1; + l->rcv_nxt_state = 1; + l->acked = 0; + l->last_gap = 0; + kfree(l->last_ga); + l->last_ga = NULL; + l->silent_intv_cnt = 0; + l->rst_cnt = 0; + l->bc_peer_is_up = false; + memset(&l->mon_state, 0, sizeof(l->mon_state)); + tipc_link_reset_stats(l); +} + +/** + * tipc_link_xmit(): enqueue buffer list according to queue situation + * @l: link to use + * @list: chain of buffers containing message + * @xmitq: returned list of packets to be sent by caller + * + * Consumes the buffer chain. + * Messages at TIPC_SYSTEM_IMPORTANCE are always accepted + * Return: 0 if success, or errno: -ELINKCONG, -EMSGSIZE or -ENOBUFS + */ +int tipc_link_xmit(struct tipc_link *l, struct sk_buff_head *list, + struct sk_buff_head *xmitq) +{ + struct sk_buff_head *backlogq = &l->backlogq; + struct sk_buff_head *transmq = &l->transmq; + struct sk_buff *skb, *_skb; + u16 bc_ack = l->bc_rcvlink->rcv_nxt - 1; + u16 ack = l->rcv_nxt - 1; + u16 seqno = l->snd_nxt; + int pkt_cnt = skb_queue_len(list); + unsigned int mss = tipc_link_mss(l); + unsigned int cwin = l->window; + unsigned int mtu = l->mtu; + struct tipc_msg *hdr; + bool new_bundle; + int rc = 0; + int imp; + + if (pkt_cnt <= 0) + return 0; + + hdr = buf_msg(skb_peek(list)); + if (unlikely(msg_size(hdr) > mtu)) { + pr_warn("Too large msg, purging xmit list %d %d %d %d %d!\n", + skb_queue_len(list), msg_user(hdr), + msg_type(hdr), msg_size(hdr), mtu); + __skb_queue_purge(list); + return -EMSGSIZE; + } + + imp = msg_importance(hdr); + /* Allow oversubscription of one data msg per source at congestion */ + if (unlikely(l->backlog[imp].len >= l->backlog[imp].limit)) { + if (imp == TIPC_SYSTEM_IMPORTANCE) { + pr_warn("%s<%s>, link overflow", link_rst_msg, l->name); + return -ENOBUFS; + } + rc = link_schedule_user(l, hdr); + } + + if (pkt_cnt > 1) { + l->stats.sent_fragmented++; + l->stats.sent_fragments += pkt_cnt; + } + + /* Prepare each packet for sending, and add to relevant queue: */ + while ((skb = __skb_dequeue(list))) { + if (likely(skb_queue_len(transmq) < cwin)) { + hdr = buf_msg(skb); + msg_set_seqno(hdr, seqno); + msg_set_ack(hdr, ack); + msg_set_bcast_ack(hdr, bc_ack); + _skb = skb_clone(skb, GFP_ATOMIC); + if (!_skb) { + kfree_skb(skb); + __skb_queue_purge(list); + return -ENOBUFS; + } + __skb_queue_tail(transmq, skb); + tipc_link_set_skb_retransmit_time(skb, l); + __skb_queue_tail(xmitq, _skb); + TIPC_SKB_CB(skb)->ackers = l->ackers; + l->rcv_unacked = 0; + l->stats.sent_pkts++; + seqno++; + continue; + } + if (tipc_msg_try_bundle(l->backlog[imp].target_bskb, &skb, + mss, l->addr, &new_bundle)) { + if (skb) { + /* Keep a ref. 
to the skb for next try */ + l->backlog[imp].target_bskb = skb; + l->backlog[imp].len++; + __skb_queue_tail(backlogq, skb); + } else { + if (new_bundle) { + l->stats.sent_bundles++; + l->stats.sent_bundled++; + } + l->stats.sent_bundled++; + } + continue; + } + l->backlog[imp].target_bskb = NULL; + l->backlog[imp].len += (1 + skb_queue_len(list)); + __skb_queue_tail(backlogq, skb); + skb_queue_splice_tail_init(list, backlogq); + } + l->snd_nxt = seqno; + return rc; +} + +static void tipc_link_update_cwin(struct tipc_link *l, int released, + bool retransmitted) +{ + int bklog_len = skb_queue_len(&l->backlogq); + struct sk_buff_head *txq = &l->transmq; + int txq_len = skb_queue_len(txq); + u16 cwin = l->window; + + /* Enter fast recovery */ + if (unlikely(retransmitted)) { + l->ssthresh = max_t(u16, l->window / 2, 300); + l->window = min_t(u16, l->ssthresh, l->window); + return; + } + /* Enter slow start */ + if (unlikely(!released)) { + l->ssthresh = max_t(u16, l->window / 2, 300); + l->window = l->min_win; + return; + } + /* Don't increase window if no pressure on the transmit queue */ + if (txq_len + bklog_len < cwin) + return; + + /* Don't increase window if there are holes the transmit queue */ + if (txq_len && l->snd_nxt - buf_seqno(skb_peek(txq)) != txq_len) + return; + + l->cong_acks += released; + + /* Slow start */ + if (cwin <= l->ssthresh) { + l->window = min_t(u16, cwin + released, l->max_win); + return; + } + /* Congestion avoidance */ + if (l->cong_acks < cwin) + return; + l->window = min_t(u16, ++cwin, l->max_win); + l->cong_acks = 0; +} + +static void tipc_link_advance_backlog(struct tipc_link *l, + struct sk_buff_head *xmitq) +{ + u16 bc_ack = l->bc_rcvlink->rcv_nxt - 1; + struct sk_buff_head *txq = &l->transmq; + struct sk_buff *skb, *_skb; + u16 ack = l->rcv_nxt - 1; + u16 seqno = l->snd_nxt; + struct tipc_msg *hdr; + u16 cwin = l->window; + u32 imp; + + while (skb_queue_len(txq) < cwin) { + skb = skb_peek(&l->backlogq); + if (!skb) + break; + _skb = skb_clone(skb, GFP_ATOMIC); + if (!_skb) + break; + __skb_dequeue(&l->backlogq); + hdr = buf_msg(skb); + imp = msg_importance(hdr); + l->backlog[imp].len--; + if (unlikely(skb == l->backlog[imp].target_bskb)) + l->backlog[imp].target_bskb = NULL; + __skb_queue_tail(&l->transmq, skb); + tipc_link_set_skb_retransmit_time(skb, l); + + __skb_queue_tail(xmitq, _skb); + TIPC_SKB_CB(skb)->ackers = l->ackers; + msg_set_seqno(hdr, seqno); + msg_set_ack(hdr, ack); + msg_set_bcast_ack(hdr, bc_ack); + l->rcv_unacked = 0; + l->stats.sent_pkts++; + seqno++; + } + l->snd_nxt = seqno; +} + +/** + * link_retransmit_failure() - Detect repeated retransmit failures + * @l: tipc link sender + * @r: tipc link receiver (= l in case of unicast) + * @rc: returned code + * + * Return: true if the repeated retransmit failures happens, otherwise + * false + */ +static bool link_retransmit_failure(struct tipc_link *l, struct tipc_link *r, + int *rc) +{ + struct sk_buff *skb = skb_peek(&l->transmq); + struct tipc_msg *hdr; + + if (!skb) + return false; + + if (!TIPC_SKB_CB(skb)->retr_cnt) + return false; + + if (!time_after(jiffies, TIPC_SKB_CB(skb)->retr_stamp + + msecs_to_jiffies(r->tolerance * 10))) + return false; + + hdr = buf_msg(skb); + if (link_is_bc_sndlink(l) && !less(r->acked, msg_seqno(hdr))) + return false; + + pr_warn("Retransmission failure on link <%s>\n", l->name); + link_print(l, "State of link "); + pr_info("Failed msg: usr %u, typ %u, len %u, err %u\n", + msg_user(hdr), msg_type(hdr), msg_size(hdr), msg_errcode(hdr)); + pr_info("sqno 
%u, prev: %x, dest: %x\n", + msg_seqno(hdr), msg_prevnode(hdr), msg_destnode(hdr)); + pr_info("retr_stamp %d, retr_cnt %d\n", + jiffies_to_msecs(TIPC_SKB_CB(skb)->retr_stamp), + TIPC_SKB_CB(skb)->retr_cnt); + + trace_tipc_list_dump(&l->transmq, true, "retrans failure!"); + trace_tipc_link_dump(l, TIPC_DUMP_NONE, "retrans failure!"); + trace_tipc_link_dump(r, TIPC_DUMP_NONE, "retrans failure!"); + + if (link_is_bc_sndlink(l)) { + r->state = LINK_RESET; + *rc |= TIPC_LINK_DOWN_EVT; + } else { + *rc |= tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + } + + return true; +} + +/* tipc_data_input - deliver data and name distr msgs to upper layer + * + * Consumes buffer if message is of right type + * Node lock must be held + */ +static bool tipc_data_input(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *inputq) +{ + struct sk_buff_head *mc_inputq = l->bc_rcvlink->inputq; + struct tipc_msg *hdr = buf_msg(skb); + + switch (msg_user(hdr)) { + case TIPC_LOW_IMPORTANCE: + case TIPC_MEDIUM_IMPORTANCE: + case TIPC_HIGH_IMPORTANCE: + case TIPC_CRITICAL_IMPORTANCE: + if (unlikely(msg_in_group(hdr) || msg_mcast(hdr))) { + skb_queue_tail(mc_inputq, skb); + return true; + } + fallthrough; + case CONN_MANAGER: + skb_queue_tail(inputq, skb); + return true; + case GROUP_PROTOCOL: + skb_queue_tail(mc_inputq, skb); + return true; + case NAME_DISTRIBUTOR: + l->bc_rcvlink->state = LINK_ESTABLISHED; + skb_queue_tail(l->namedq, skb); + return true; + case MSG_BUNDLER: + case TUNNEL_PROTOCOL: + case MSG_FRAGMENTER: + case BCAST_PROTOCOL: + return false; +#ifdef CONFIG_TIPC_CRYPTO + case MSG_CRYPTO: + if (sysctl_tipc_key_exchange_enabled && + TIPC_SKB_CB(skb)->decrypted) { + tipc_crypto_msg_rcv(l->net, skb); + return true; + } + fallthrough; +#endif + default: + pr_warn("Dropping received illegal msg type\n"); + kfree_skb(skb); + return true; + } +} + +/* tipc_link_input - process packet that has passed link protocol check + * + * Consumes buffer + */ +static int tipc_link_input(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *inputq, + struct sk_buff **reasm_skb) +{ + struct tipc_msg *hdr = buf_msg(skb); + struct sk_buff *iskb; + struct sk_buff_head tmpq; + int usr = msg_user(hdr); + int pos = 0; + + if (usr == MSG_BUNDLER) { + skb_queue_head_init(&tmpq); + l->stats.recv_bundles++; + l->stats.recv_bundled += msg_msgcnt(hdr); + while (tipc_msg_extract(skb, &iskb, &pos)) + tipc_data_input(l, iskb, &tmpq); + tipc_skb_queue_splice_tail(&tmpq, inputq); + return 0; + } else if (usr == MSG_FRAGMENTER) { + l->stats.recv_fragments++; + if (tipc_buf_append(reasm_skb, &skb)) { + l->stats.recv_fragmented++; + tipc_data_input(l, skb, inputq); + } else if (!*reasm_skb && !link_is_bc_rcvlink(l)) { + pr_warn_ratelimited("Unable to build fragment list\n"); + return tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + } + return 0; + } else if (usr == BCAST_PROTOCOL) { + tipc_bcast_lock(l->net); + tipc_link_bc_init_rcv(l->bc_rcvlink, hdr); + tipc_bcast_unlock(l->net); + } + + kfree_skb(skb); + return 0; +} + +/* tipc_link_tnl_rcv() - receive TUNNEL_PROTOCOL message, drop or process the + * inner message along with the ones in the old link's + * deferdq + * @l: tunnel link + * @skb: TUNNEL_PROTOCOL message + * @inputq: queue to put messages ready for delivery + */ +static int tipc_link_tnl_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *inputq) +{ + struct sk_buff **reasm_skb = &l->failover_reasm_skb; + struct sk_buff **reasm_tnlmsg = &l->reasm_tnlmsg; + struct sk_buff_head *fdefq = 
&l->failover_deferdq; + struct tipc_msg *hdr = buf_msg(skb); + struct sk_buff *iskb; + int ipos = 0; + int rc = 0; + u16 seqno; + + if (msg_type(hdr) == SYNCH_MSG) { + kfree_skb(skb); + return 0; + } + + /* Not a fragment? */ + if (likely(!msg_nof_fragms(hdr))) { + if (unlikely(!tipc_msg_extract(skb, &iskb, &ipos))) { + pr_warn_ratelimited("Unable to extract msg, defq: %d\n", + skb_queue_len(fdefq)); + return 0; + } + kfree_skb(skb); + } else { + /* Set fragment type for buf_append */ + if (msg_fragm_no(hdr) == 1) + msg_set_type(hdr, FIRST_FRAGMENT); + else if (msg_fragm_no(hdr) < msg_nof_fragms(hdr)) + msg_set_type(hdr, FRAGMENT); + else + msg_set_type(hdr, LAST_FRAGMENT); + + if (!tipc_buf_append(reasm_tnlmsg, &skb)) { + /* Successful but non-complete reassembly? */ + if (*reasm_tnlmsg || link_is_bc_rcvlink(l)) + return 0; + pr_warn_ratelimited("Unable to reassemble tunnel msg\n"); + return tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + } + iskb = skb; + } + + do { + seqno = buf_seqno(iskb); + if (unlikely(less(seqno, l->drop_point))) { + kfree_skb(iskb); + continue; + } + if (unlikely(seqno != l->drop_point)) { + __tipc_skb_queue_sorted(fdefq, seqno, iskb); + continue; + } + + l->drop_point++; + if (!tipc_data_input(l, iskb, inputq)) + rc |= tipc_link_input(l, iskb, inputq, reasm_skb); + if (unlikely(rc)) + break; + } while ((iskb = __tipc_skb_dequeue(fdefq, l->drop_point))); + + return rc; +} + +/** + * tipc_get_gap_ack_blks - get Gap ACK blocks from PROTOCOL/STATE_MSG + * @ga: returned pointer to the Gap ACK blocks if any + * @l: the tipc link + * @hdr: the PROTOCOL/STATE_MSG header + * @uc: desired Gap ACK blocks type, i.e. unicast (= 1) or broadcast (= 0) + * + * Return: the total Gap ACK blocks size + */ +u16 tipc_get_gap_ack_blks(struct tipc_gap_ack_blks **ga, struct tipc_link *l, + struct tipc_msg *hdr, bool uc) +{ + struct tipc_gap_ack_blks *p; + u16 sz = 0; + + /* Does peer support the Gap ACK blocks feature? */ + if (l->peer_caps & TIPC_GAP_ACK_BLOCK) { + p = (struct tipc_gap_ack_blks *)msg_data(hdr); + sz = ntohs(p->len); + /* Sanity check */ + if (sz == struct_size(p, gacks, size_add(p->ugack_cnt, p->bgack_cnt))) { + /* Good, check if the desired type exists */ + if ((uc && p->ugack_cnt) || (!uc && p->bgack_cnt)) + goto ok; + /* Backward compatible: peer might not support bc, but uc? */ + } else if (uc && sz == struct_size(p, gacks, p->ugack_cnt)) { + if (p->ugack_cnt) { + p->bgack_cnt = 0; + goto ok; + } + } + } + /* Other cases: ignore! 
*/ + p = NULL; + +ok: + *ga = p; + return sz; +} + +static u8 __tipc_build_gap_ack_blks(struct tipc_gap_ack_blks *ga, + struct tipc_link *l, u8 start_index) +{ + struct tipc_gap_ack *gacks = &ga->gacks[start_index]; + struct sk_buff *skb = skb_peek(&l->deferdq); + u16 expect, seqno = 0; + u8 n = 0; + + if (!skb) + return 0; + + expect = buf_seqno(skb); + skb_queue_walk(&l->deferdq, skb) { + seqno = buf_seqno(skb); + if (unlikely(more(seqno, expect))) { + gacks[n].ack = htons(expect - 1); + gacks[n].gap = htons(seqno - expect); + if (++n >= MAX_GAP_ACK_BLKS / 2) { + pr_info_ratelimited("Gacks on %s: %d, ql: %d!\n", + l->name, n, + skb_queue_len(&l->deferdq)); + return n; + } + } else if (unlikely(less(seqno, expect))) { + pr_warn("Unexpected skb in deferdq!\n"); + continue; + } + expect = seqno + 1; + } + + /* last block */ + gacks[n].ack = htons(seqno); + gacks[n].gap = 0; + n++; + return n; +} + +/* tipc_build_gap_ack_blks - build Gap ACK blocks + * @l: tipc unicast link + * @hdr: the tipc message buffer to store the Gap ACK blocks after built + * + * The function builds Gap ACK blocks for both the unicast & broadcast receiver + * links of a certain peer, the buffer after built has the network data format + * as found at the struct tipc_gap_ack_blks definition. + * + * returns the actual allocated memory size + */ +static u16 tipc_build_gap_ack_blks(struct tipc_link *l, struct tipc_msg *hdr) +{ + struct tipc_link *bcl = l->bc_rcvlink; + struct tipc_gap_ack_blks *ga; + u16 len; + + ga = (struct tipc_gap_ack_blks *)msg_data(hdr); + + /* Start with broadcast link first */ + tipc_bcast_lock(bcl->net); + msg_set_bcast_ack(hdr, bcl->rcv_nxt - 1); + msg_set_bc_gap(hdr, link_bc_rcv_gap(bcl)); + ga->bgack_cnt = __tipc_build_gap_ack_blks(ga, bcl, 0); + tipc_bcast_unlock(bcl->net); + + /* Now for unicast link, but an explicit NACK only (???) */ + ga->ugack_cnt = (msg_seq_gap(hdr)) ? + __tipc_build_gap_ack_blks(ga, l, ga->bgack_cnt) : 0; + + /* Total len */ + len = struct_size(ga, gacks, size_add(ga->bgack_cnt, ga->ugack_cnt)); + ga->len = htons(len); + return len; +} + +/* tipc_link_advance_transmq - advance TIPC link transmq queue by releasing + * acked packets, also doing retransmissions if + * gaps found + * @l: tipc link with transmq queue to be advanced + * @r: tipc link "receiver" i.e. in case of broadcast (= "l" if unicast) + * @acked: seqno of last packet acked by peer without any gaps before + * @gap: # of gap packets + * @ga: buffer pointer to Gap ACK blocks from peer + * @xmitq: queue for accumulating the retransmitted packets if any + * @retransmitted: returned boolean value if a retransmission is really issued + * @rc: returned code e.g. TIPC_LINK_DOWN_EVT if a repeated retransmit failures + * happens (- unlikely case) + * + * Return: the number of packets released from the link transmq + */ +static int tipc_link_advance_transmq(struct tipc_link *l, struct tipc_link *r, + u16 acked, u16 gap, + struct tipc_gap_ack_blks *ga, + struct sk_buff_head *xmitq, + bool *retransmitted, int *rc) +{ + struct tipc_gap_ack_blks *last_ga = r->last_ga, *this_ga = NULL; + struct tipc_gap_ack *gacks = NULL; + struct sk_buff *skb, *_skb, *tmp; + struct tipc_msg *hdr; + u32 qlen = skb_queue_len(&l->transmq); + u16 nacked = acked, ngap = gap, gack_cnt = 0; + u16 bc_ack = l->bc_rcvlink->rcv_nxt - 1; + u16 ack = l->rcv_nxt - 1; + u16 seqno, n = 0; + u16 end = r->acked, start = end, offset = r->last_gap; + u16 si = (last_ga) ? 
last_ga->start_index : 0; + bool is_uc = !link_is_bc_sndlink(l); + bool bc_has_acked = false; + + trace_tipc_link_retrans(r, acked + 1, acked + gap, &l->transmq); + + /* Determine Gap ACK blocks if any for the particular link */ + if (ga && is_uc) { + /* Get the Gap ACKs, uc part */ + gack_cnt = ga->ugack_cnt; + gacks = &ga->gacks[ga->bgack_cnt]; + } else if (ga) { + /* Copy the Gap ACKs, bc part, for later renewal if needed */ + this_ga = kmemdup(ga, struct_size(ga, gacks, ga->bgack_cnt), + GFP_ATOMIC); + if (likely(this_ga)) { + this_ga->start_index = 0; + /* Start with the bc Gap ACKs */ + gack_cnt = this_ga->bgack_cnt; + gacks = &this_ga->gacks[0]; + } else { + /* Hmm, we can get in trouble..., simply ignore it */ + pr_warn_ratelimited("Ignoring bc Gap ACKs, no memory\n"); + } + } + + /* Advance the link transmq */ + skb_queue_walk_safe(&l->transmq, skb, tmp) { + seqno = buf_seqno(skb); + +next_gap_ack: + if (less_eq(seqno, nacked)) { + if (is_uc) + goto release; + /* Skip packets peer has already acked */ + if (!more(seqno, r->acked)) + continue; + /* Get the next of last Gap ACK blocks */ + while (more(seqno, end)) { + if (!last_ga || si >= last_ga->bgack_cnt) + break; + start = end + offset + 1; + end = ntohs(last_ga->gacks[si].ack); + offset = ntohs(last_ga->gacks[si].gap); + si++; + WARN_ONCE(more(start, end) || + (!offset && + si < last_ga->bgack_cnt) || + si > MAX_GAP_ACK_BLKS, + "Corrupted Gap ACK: %d %d %d %d %d\n", + start, end, offset, si, + last_ga->bgack_cnt); + } + /* Check against the last Gap ACK block */ + if (in_range(seqno, start, end)) + continue; + /* Update/release the packet peer is acking */ + bc_has_acked = true; + if (--TIPC_SKB_CB(skb)->ackers) + continue; +release: + /* release skb */ + __skb_unlink(skb, &l->transmq); + kfree_skb(skb); + } else if (less_eq(seqno, nacked + ngap)) { + /* First gap: check if repeated retrans failures? 
*/ + if (unlikely(seqno == acked + 1 && + link_retransmit_failure(l, r, rc))) { + /* Ignore this bc Gap ACKs if any */ + kfree(this_ga); + this_ga = NULL; + break; + } + /* retransmit skb if unrestricted*/ + if (time_before(jiffies, TIPC_SKB_CB(skb)->nxt_retr)) + continue; + tipc_link_set_skb_retransmit_time(skb, l); + _skb = pskb_copy(skb, GFP_ATOMIC); + if (!_skb) + continue; + hdr = buf_msg(_skb); + msg_set_ack(hdr, ack); + msg_set_bcast_ack(hdr, bc_ack); + _skb->priority = TC_PRIO_CONTROL; + __skb_queue_tail(xmitq, _skb); + l->stats.retransmitted++; + if (!is_uc) + r->stats.retransmitted++; + *retransmitted = true; + /* Increase actual retrans counter & mark first time */ + if (!TIPC_SKB_CB(skb)->retr_cnt++) + TIPC_SKB_CB(skb)->retr_stamp = jiffies; + } else { + /* retry with Gap ACK blocks if any */ + if (n >= gack_cnt) + break; + nacked = ntohs(gacks[n].ack); + ngap = ntohs(gacks[n].gap); + n++; + goto next_gap_ack; + } + } + + /* Renew last Gap ACK blocks for bc if needed */ + if (bc_has_acked) { + if (this_ga) { + kfree(last_ga); + r->last_ga = this_ga; + r->last_gap = gap; + } else if (last_ga) { + if (less(acked, start)) { + si--; + offset = start - acked - 1; + } else if (less(acked, end)) { + acked = end; + } + if (si < last_ga->bgack_cnt) { + last_ga->start_index = si; + r->last_gap = offset; + } else { + kfree(last_ga); + r->last_ga = NULL; + r->last_gap = 0; + } + } else { + r->last_gap = 0; + } + r->acked = acked; + } else { + kfree(this_ga); + } + + return qlen - skb_queue_len(&l->transmq); +} + +/* tipc_link_build_state_msg: prepare link state message for transmission + * + * Note that sending of broadcast ack is coordinated among nodes, to reduce + * risk of ack storms towards the sender + */ +int tipc_link_build_state_msg(struct tipc_link *l, struct sk_buff_head *xmitq) +{ + if (!l) + return 0; + + /* Broadcast ACK must be sent via a unicast link => defer to caller */ + if (link_is_bc_rcvlink(l)) { + if (((l->rcv_nxt ^ tipc_own_addr(l->net)) & 0xf) != 0xf) + return 0; + l->rcv_unacked = 0; + + /* Use snd_nxt to store peer's snd_nxt in broadcast rcv link */ + l->snd_nxt = l->rcv_nxt; + return TIPC_LINK_SND_STATE; + } + /* Unicast ACK */ + l->rcv_unacked = 0; + l->stats.sent_acks++; + tipc_link_build_proto_msg(l, STATE_MSG, 0, 0, 0, 0, 0, xmitq); + return 0; +} + +/* tipc_link_build_reset_msg: prepare link RESET or ACTIVATE message + */ +void tipc_link_build_reset_msg(struct tipc_link *l, struct sk_buff_head *xmitq) +{ + int mtyp = RESET_MSG; + struct sk_buff *skb; + + if (l->state == LINK_ESTABLISHING) + mtyp = ACTIVATE_MSG; + + tipc_link_build_proto_msg(l, mtyp, 0, 0, 0, 0, 0, xmitq); + + /* Inform peer that this endpoint is going down if applicable */ + skb = skb_peek_tail(xmitq); + if (skb && (l->state == LINK_RESET)) + msg_set_peer_stopping(buf_msg(skb), 1); +} + +/* tipc_link_build_nack_msg: prepare link nack message for transmission + * Note that sending of broadcast NACK is coordinated among nodes, to + * reduce the risk of NACK storms towards the sender + */ +static int tipc_link_build_nack_msg(struct tipc_link *l, + struct sk_buff_head *xmitq) +{ + u32 def_cnt = ++l->stats.deferred_recv; + struct sk_buff_head *dfq = &l->deferdq; + u32 defq_len = skb_queue_len(dfq); + int match1, match2; + + if (link_is_bc_rcvlink(l)) { + match1 = def_cnt & 0xf; + match2 = tipc_own_addr(l->net) & 0xf; + if (match1 == match2) + return TIPC_LINK_SND_STATE; + return 0; + } + + if (defq_len >= 3 && !((defq_len - 3) % 16)) { + u16 rcvgap = buf_seqno(skb_peek(dfq)) - l->rcv_nxt; + + 
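tipc_link_build_nack_msg() throttles NACKs so that a long sequence gap does not produce one NACK per arriving packet: a unicast NACK is built only when the deferred-queue length is 3, 19, 35, and so on, while on the broadcast receive link the decision is spread across nodes by matching the low bits of the deferral counter against the node address. The unicast arithmetic is small enough to show as a stand-alone sketch; the should_nack() helper name is invented for the example.

#include <stdio.h>
#include <stdbool.h>

static bool should_nack(unsigned int defq_len)
{
	/* Fire on lengths 3, 19, 35, ... only */
	return defq_len >= 3 && ((defq_len - 3) % 16) == 0;
}

int main(void)
{
	for (unsigned int len = 0; len <= 40; len++)
		if (should_nack(len))
			printf("NACK at deferred length %u\n", len);  /* 3, 19, 35 */
	return 0;
}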
tipc_link_build_proto_msg(l, STATE_MSG, 0, 0, + rcvgap, 0, 0, xmitq); + } + return 0; +} + +/* tipc_link_rcv - process TIPC packets/messages arriving from off-node + * @l: the link that should handle the message + * @skb: TIPC packet + * @xmitq: queue to place packets to be sent after this call + */ +int tipc_link_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + struct sk_buff_head *defq = &l->deferdq; + struct tipc_msg *hdr = buf_msg(skb); + u16 seqno, rcv_nxt, win_lim; + int released = 0; + int rc = 0; + + /* Verify and update link state */ + if (unlikely(msg_user(hdr) == LINK_PROTOCOL)) + return tipc_link_proto_rcv(l, skb, xmitq); + + /* Don't send probe at next timeout expiration */ + l->silent_intv_cnt = 0; + + do { + hdr = buf_msg(skb); + seqno = msg_seqno(hdr); + rcv_nxt = l->rcv_nxt; + win_lim = rcv_nxt + TIPC_MAX_LINK_WIN; + + if (unlikely(!link_is_up(l))) { + if (l->state == LINK_ESTABLISHING) + rc = TIPC_LINK_UP_EVT; + kfree_skb(skb); + break; + } + + /* Drop if outside receive window */ + if (unlikely(less(seqno, rcv_nxt) || more(seqno, win_lim))) { + l->stats.duplicates++; + kfree_skb(skb); + break; + } + released += tipc_link_advance_transmq(l, l, msg_ack(hdr), 0, + NULL, NULL, NULL, NULL); + + /* Defer delivery if sequence gap */ + if (unlikely(seqno != rcv_nxt)) { + if (!__tipc_skb_queue_sorted(defq, seqno, skb)) + l->stats.duplicates++; + rc |= tipc_link_build_nack_msg(l, xmitq); + break; + } + + /* Deliver packet */ + l->rcv_nxt++; + l->stats.recv_pkts++; + + if (unlikely(msg_user(hdr) == TUNNEL_PROTOCOL)) + rc |= tipc_link_tnl_rcv(l, skb, l->inputq); + else if (!tipc_data_input(l, skb, l->inputq)) + rc |= tipc_link_input(l, skb, l->inputq, &l->reasm_buf); + if (unlikely(++l->rcv_unacked >= TIPC_MIN_LINK_WIN)) + rc |= tipc_link_build_state_msg(l, xmitq); + if (unlikely(rc & ~TIPC_LINK_SND_STATE)) + break; + } while ((skb = __tipc_skb_dequeue(defq, l->rcv_nxt))); + + /* Forward queues and wake up waiting users */ + if (released) { + tipc_link_update_cwin(l, released, 0); + tipc_link_advance_backlog(l, xmitq); + if (unlikely(!skb_queue_empty(&l->wakeupq))) + link_prepare_wakeup(l); + } + return rc; +} + +static void tipc_link_build_proto_msg(struct tipc_link *l, int mtyp, bool probe, + bool probe_reply, u16 rcvgap, + int tolerance, int priority, + struct sk_buff_head *xmitq) +{ + struct tipc_mon_state *mstate = &l->mon_state; + struct sk_buff_head *dfq = &l->deferdq; + struct tipc_link *bcl = l->bc_rcvlink; + struct tipc_msg *hdr; + struct sk_buff *skb; + bool node_up = link_is_up(bcl); + u16 glen = 0, bc_rcvgap = 0; + int dlen = 0; + void *data; + + /* Don't send protocol message during reset or link failover */ + if (tipc_link_is_blocked(l)) + return; + + if (!tipc_link_is_up(l) && (mtyp == STATE_MSG)) + return; + + if ((probe || probe_reply) && !skb_queue_empty(dfq)) + rcvgap = buf_seqno(skb_peek(dfq)) - l->rcv_nxt; + + skb = tipc_msg_create(LINK_PROTOCOL, mtyp, INT_H_SIZE, + tipc_max_domain_size + MAX_GAP_ACK_BLKS_SZ, + l->addr, tipc_own_addr(l->net), 0, 0, 0); + if (!skb) + return; + + hdr = buf_msg(skb); + data = msg_data(hdr); + msg_set_session(hdr, l->session); + msg_set_bearer_id(hdr, l->bearer_id); + msg_set_net_plane(hdr, l->net_plane); + msg_set_next_sent(hdr, l->snd_nxt); + msg_set_ack(hdr, l->rcv_nxt - 1); + msg_set_bcast_ack(hdr, bcl->rcv_nxt - 1); + msg_set_bc_ack_invalid(hdr, !node_up); + msg_set_last_bcast(hdr, l->bc_sndlink->snd_nxt - 1); + msg_set_link_tolerance(hdr, tolerance); + msg_set_linkprio(hdr, priority); + 
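The receive path in tipc_link_rcv() above drops anything outside the window [rcv_nxt, rcv_nxt + TIPC_MAX_LINK_WIN] using wrap-around-safe comparisons on 16-bit sequence numbers (the less()/more() helpers). A generic serial-number test of that kind can be sketched as follows; seq_less() is an illustrative helper written for this example and is not necessarily the exact implementation behind less()/more().

#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>

/* "a precedes b" in modulo-2^16 sequence space */
static bool seq_less(uint16_t a, uint16_t b)
{
	return (uint16_t)(a - b) > 0x8000u;
}

int main(void)
{
	uint16_t rcv_nxt = 65530, win_lim = rcv_nxt + 100;   /* win_lim wraps to 94 */
	uint16_t seqno = 65531;
	bool in_window = !seq_less(seqno, rcv_nxt) && !seq_less(win_lim, seqno);

	printf("seqno %u in window: %d\n", (unsigned int)seqno, (int)in_window); /* 1 */

	seqno = 10;                        /* past the wrap point, still inside */
	in_window = !seq_less(seqno, rcv_nxt) && !seq_less(win_lim, seqno);
	printf("seqno %u in window: %d\n", (unsigned int)seqno, (int)in_window); /* 1 */
	return 0;
}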
msg_set_redundant_link(hdr, node_up); + msg_set_seq_gap(hdr, 0); + msg_set_seqno(hdr, l->snd_nxt + U16_MAX / 2); + + if (mtyp == STATE_MSG) { + if (l->peer_caps & TIPC_LINK_PROTO_SEQNO) + msg_set_seqno(hdr, l->snd_nxt_state++); + msg_set_seq_gap(hdr, rcvgap); + bc_rcvgap = link_bc_rcv_gap(bcl); + msg_set_bc_gap(hdr, bc_rcvgap); + msg_set_probe(hdr, probe); + msg_set_is_keepalive(hdr, probe || probe_reply); + if (l->peer_caps & TIPC_GAP_ACK_BLOCK) + glen = tipc_build_gap_ack_blks(l, hdr); + tipc_mon_prep(l->net, data + glen, &dlen, mstate, l->bearer_id); + msg_set_size(hdr, INT_H_SIZE + glen + dlen); + skb_trim(skb, INT_H_SIZE + glen + dlen); + l->stats.sent_states++; + l->rcv_unacked = 0; + } else { + /* RESET_MSG or ACTIVATE_MSG */ + if (mtyp == ACTIVATE_MSG) { + msg_set_dest_session_valid(hdr, 1); + msg_set_dest_session(hdr, l->peer_session); + } + msg_set_max_pkt(hdr, l->advertised_mtu); + strcpy(data, l->if_name); + msg_set_size(hdr, INT_H_SIZE + TIPC_MAX_IF_NAME); + skb_trim(skb, INT_H_SIZE + TIPC_MAX_IF_NAME); + } + if (probe) + l->stats.sent_probes++; + if (rcvgap) + l->stats.sent_nacks++; + if (bc_rcvgap) + bcl->stats.sent_nacks++; + skb->priority = TC_PRIO_CONTROL; + __skb_queue_tail(xmitq, skb); + trace_tipc_proto_build(skb, false, l->name); +} + +void tipc_link_create_dummy_tnl_msg(struct tipc_link *l, + struct sk_buff_head *xmitq) +{ + u32 onode = tipc_own_addr(l->net); + struct tipc_msg *hdr, *ihdr; + struct sk_buff_head tnlq; + struct sk_buff *skb; + u32 dnode = l->addr; + + __skb_queue_head_init(&tnlq); + skb = tipc_msg_create(TUNNEL_PROTOCOL, FAILOVER_MSG, + INT_H_SIZE, BASIC_H_SIZE, + dnode, onode, 0, 0, 0); + if (!skb) { + pr_warn("%sunable to create tunnel packet\n", link_co_err); + return; + } + + hdr = buf_msg(skb); + msg_set_msgcnt(hdr, 1); + msg_set_bearer_id(hdr, l->peer_bearer_id); + + ihdr = (struct tipc_msg *)msg_data(hdr); + tipc_msg_init(onode, ihdr, TIPC_LOW_IMPORTANCE, TIPC_DIRECT_MSG, + BASIC_H_SIZE, dnode); + msg_set_errcode(ihdr, TIPC_ERR_NO_PORT); + __skb_queue_tail(&tnlq, skb); + tipc_link_xmit(l, &tnlq, xmitq); +} + +/* tipc_link_tnl_prepare(): prepare and return a list of tunnel packets + * with contents of the link's transmit and backlog queues. + */ +void tipc_link_tnl_prepare(struct tipc_link *l, struct tipc_link *tnl, + int mtyp, struct sk_buff_head *xmitq) +{ + struct sk_buff_head *fdefq = &tnl->failover_deferdq; + struct sk_buff *skb, *tnlskb; + struct tipc_msg *hdr, tnlhdr; + struct sk_buff_head *queue = &l->transmq; + struct sk_buff_head tmpxq, tnlq, frags; + u16 pktlen, pktcnt, seqno = l->snd_nxt; + bool pktcnt_need_update = false; + u16 syncpt; + int rc; + + if (!tnl) + return; + + __skb_queue_head_init(&tnlq); + /* Link Synching: + * From now on, send only one single ("dummy") SYNCH message + * to peer. The SYNCH message does not contain any data, just + * a header conveying the synch point to the peer. 
+ */ + if (mtyp == SYNCH_MSG && (tnl->peer_caps & TIPC_TUNNEL_ENHANCED)) { + tnlskb = tipc_msg_create(TUNNEL_PROTOCOL, SYNCH_MSG, + INT_H_SIZE, 0, l->addr, + tipc_own_addr(l->net), + 0, 0, 0); + if (!tnlskb) { + pr_warn("%sunable to create dummy SYNCH_MSG\n", + link_co_err); + return; + } + + hdr = buf_msg(tnlskb); + syncpt = l->snd_nxt + skb_queue_len(&l->backlogq) - 1; + msg_set_syncpt(hdr, syncpt); + msg_set_bearer_id(hdr, l->peer_bearer_id); + __skb_queue_tail(&tnlq, tnlskb); + tipc_link_xmit(tnl, &tnlq, xmitq); + return; + } + + __skb_queue_head_init(&tmpxq); + __skb_queue_head_init(&frags); + /* At least one packet required for safe algorithm => add dummy */ + skb = tipc_msg_create(TIPC_LOW_IMPORTANCE, TIPC_DIRECT_MSG, + BASIC_H_SIZE, 0, l->addr, tipc_own_addr(l->net), + 0, 0, TIPC_ERR_NO_PORT); + if (!skb) { + pr_warn("%sunable to create tunnel packet\n", link_co_err); + return; + } + __skb_queue_tail(&tnlq, skb); + tipc_link_xmit(l, &tnlq, &tmpxq); + __skb_queue_purge(&tmpxq); + + /* Initialize reusable tunnel packet header */ + tipc_msg_init(tipc_own_addr(l->net), &tnlhdr, TUNNEL_PROTOCOL, + mtyp, INT_H_SIZE, l->addr); + if (mtyp == SYNCH_MSG) + pktcnt = l->snd_nxt - buf_seqno(skb_peek(&l->transmq)); + else + pktcnt = skb_queue_len(&l->transmq); + pktcnt += skb_queue_len(&l->backlogq); + msg_set_msgcnt(&tnlhdr, pktcnt); + msg_set_bearer_id(&tnlhdr, l->peer_bearer_id); +tnl: + /* Wrap each packet into a tunnel packet */ + skb_queue_walk(queue, skb) { + hdr = buf_msg(skb); + if (queue == &l->backlogq) + msg_set_seqno(hdr, seqno++); + pktlen = msg_size(hdr); + + /* Tunnel link MTU is not large enough? This could be + * due to: + * 1) Link MTU has just changed or set differently; + * 2) Or FAILOVER on the top of a SYNCH message + * + * The 2nd case should not happen if peer supports + * TIPC_TUNNEL_ENHANCED + */ + if (pktlen > tnl->mtu - INT_H_SIZE) { + if (mtyp == FAILOVER_MSG && + (tnl->peer_caps & TIPC_TUNNEL_ENHANCED)) { + rc = tipc_msg_fragment(skb, &tnlhdr, tnl->mtu, + &frags); + if (rc) { + pr_warn("%sunable to frag msg: rc %d\n", + link_co_err, rc); + return; + } + pktcnt += skb_queue_len(&frags) - 1; + pktcnt_need_update = true; + skb_queue_splice_tail_init(&frags, &tnlq); + continue; + } + /* Unluckily, peer doesn't have TIPC_TUNNEL_ENHANCED + * => Just warn it and return! 
+ */ + pr_warn_ratelimited("%stoo large msg <%d, %d>: %d!\n", + link_co_err, msg_user(hdr), + msg_type(hdr), msg_size(hdr)); + return; + } + + msg_set_size(&tnlhdr, pktlen + INT_H_SIZE); + tnlskb = tipc_buf_acquire(pktlen + INT_H_SIZE, GFP_ATOMIC); + if (!tnlskb) { + pr_warn("%sunable to send packet\n", link_co_err); + return; + } + skb_copy_to_linear_data(tnlskb, &tnlhdr, INT_H_SIZE); + skb_copy_to_linear_data_offset(tnlskb, INT_H_SIZE, hdr, pktlen); + __skb_queue_tail(&tnlq, tnlskb); + } + if (queue != &l->backlogq) { + queue = &l->backlogq; + goto tnl; + } + + if (pktcnt_need_update) + skb_queue_walk(&tnlq, skb) { + hdr = buf_msg(skb); + msg_set_msgcnt(hdr, pktcnt); + } + + tipc_link_xmit(tnl, &tnlq, xmitq); + + if (mtyp == FAILOVER_MSG) { + tnl->drop_point = l->rcv_nxt; + tnl->failover_reasm_skb = l->reasm_buf; + l->reasm_buf = NULL; + + /* Failover the link's deferdq */ + if (unlikely(!skb_queue_empty(fdefq))) { + pr_warn("Link failover deferdq not empty: %d!\n", + skb_queue_len(fdefq)); + __skb_queue_purge(fdefq); + } + skb_queue_splice_init(&l->deferdq, fdefq); + } +} + +/** + * tipc_link_failover_prepare() - prepare tnl for link failover + * + * This is a special version of the precursor - tipc_link_tnl_prepare(), + * see the tipc_node_link_failover() for details + * + * @l: failover link + * @tnl: tunnel link + * @xmitq: queue for messages to be xmited + */ +void tipc_link_failover_prepare(struct tipc_link *l, struct tipc_link *tnl, + struct sk_buff_head *xmitq) +{ + struct sk_buff_head *fdefq = &tnl->failover_deferdq; + + tipc_link_create_dummy_tnl_msg(tnl, xmitq); + + /* This failover link endpoint was never established before, + * so it has not received anything from peer. + * Otherwise, it must be a normal failover situation or the + * node has entered SELF_DOWN_PEER_LEAVING and both peer nodes + * would have to start over from scratch instead. + */ + tnl->drop_point = 1; + tnl->failover_reasm_skb = NULL; + + /* Initiate the link's failover deferdq */ + if (unlikely(!skb_queue_empty(fdefq))) { + pr_warn("Link failover deferdq not empty: %d!\n", + skb_queue_len(fdefq)); + __skb_queue_purge(fdefq); + } +} + +/* tipc_link_validate_msg(): validate message against current link state + * Returns true if message should be accepted, otherwise false + */ +bool tipc_link_validate_msg(struct tipc_link *l, struct tipc_msg *hdr) +{ + u16 curr_session = l->peer_session; + u16 session = msg_session(hdr); + int mtyp = msg_type(hdr); + + if (msg_user(hdr) != LINK_PROTOCOL) + return true; + + switch (mtyp) { + case RESET_MSG: + if (!l->in_session) + return true; + /* Accept only RESET with new session number */ + return more(session, curr_session); + case ACTIVATE_MSG: + if (!l->in_session) + return true; + /* Accept only ACTIVATE with new or current session number */ + return !less(session, curr_session); + case STATE_MSG: + /* Accept only STATE with current session number */ + if (!l->in_session) + return false; + if (session != curr_session) + return false; + /* Extra sanity check */ + if (!link_is_up(l) && msg_ack(hdr)) + return false; + if (!(l->peer_caps & TIPC_LINK_PROTO_SEQNO)) + return true; + /* Accept only STATE with new sequence number */ + return !less(msg_seqno(hdr), l->rcv_nxt_state); + default: + return false; + } +} + +/* tipc_link_proto_rcv(): receive link level protocol message : + * Note that network plane id propagates through the network, and may + * change at any time. 
The node with lowest numerical id determines + * network plane + */ +static int tipc_link_proto_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + struct tipc_msg *hdr = buf_msg(skb); + struct tipc_gap_ack_blks *ga = NULL; + bool reply = msg_probe(hdr), retransmitted = false; + u32 dlen = msg_data_sz(hdr), glen = 0, msg_max; + u16 peers_snd_nxt = msg_next_sent(hdr); + u16 peers_tol = msg_link_tolerance(hdr); + u16 peers_prio = msg_linkprio(hdr); + u16 gap = msg_seq_gap(hdr); + u16 ack = msg_ack(hdr); + u16 rcv_nxt = l->rcv_nxt; + u16 rcvgap = 0; + int mtyp = msg_type(hdr); + int rc = 0, released; + char *if_name; + void *data; + + trace_tipc_proto_rcv(skb, false, l->name); + + if (dlen > U16_MAX) + goto exit; + + if (tipc_link_is_blocked(l) || !xmitq) + goto exit; + + if (tipc_own_addr(l->net) > msg_prevnode(hdr)) + l->net_plane = msg_net_plane(hdr); + + if (skb_linearize(skb)) + goto exit; + + hdr = buf_msg(skb); + data = msg_data(hdr); + + if (!tipc_link_validate_msg(l, hdr)) { + trace_tipc_skb_dump(skb, false, "PROTO invalid (1)!"); + trace_tipc_link_dump(l, TIPC_DUMP_NONE, "PROTO invalid (1)!"); + goto exit; + } + + switch (mtyp) { + case RESET_MSG: + case ACTIVATE_MSG: + msg_max = msg_max_pkt(hdr); + if (msg_max < tipc_bearer_min_mtu(l->net, l->bearer_id)) + break; + /* Complete own link name with peer's interface name */ + if_name = strrchr(l->name, ':') + 1; + if (sizeof(l->name) - (if_name - l->name) <= TIPC_MAX_IF_NAME) + break; + if (msg_data_sz(hdr) < TIPC_MAX_IF_NAME) + break; + strncpy(if_name, data, TIPC_MAX_IF_NAME); + + /* Update own tolerance if peer indicates a non-zero value */ + if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) { + l->tolerance = peers_tol; + l->bc_rcvlink->tolerance = peers_tol; + } + /* Update own priority if peer's priority is higher */ + if (in_range(peers_prio, l->priority + 1, TIPC_MAX_LINK_PRI)) + l->priority = peers_prio; + + /* If peer is going down we want full re-establish cycle */ + if (msg_peer_stopping(hdr)) { + rc = tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + break; + } + + /* If this endpoint was re-created while peer was ESTABLISHING + * it doesn't know current session number. Force re-synch. 
+ */ + if (mtyp == ACTIVATE_MSG && msg_dest_session_valid(hdr) && + l->session != msg_dest_session(hdr)) { + if (less(l->session, msg_dest_session(hdr))) + l->session = msg_dest_session(hdr) + 1; + break; + } + + /* ACTIVATE_MSG serves as PEER_RESET if link is already down */ + if (mtyp == RESET_MSG || !link_is_up(l)) + rc = tipc_link_fsm_evt(l, LINK_PEER_RESET_EVT); + + /* ACTIVATE_MSG takes up link if it was already locally reset */ + if (mtyp == ACTIVATE_MSG && l->state == LINK_ESTABLISHING) + rc = TIPC_LINK_UP_EVT; + + l->peer_session = msg_session(hdr); + l->in_session = true; + l->peer_bearer_id = msg_bearer_id(hdr); + if (l->mtu > msg_max) + l->mtu = msg_max; + break; + + case STATE_MSG: + /* Validate Gap ACK blocks, drop if invalid */ + glen = tipc_get_gap_ack_blks(&ga, l, hdr, true); + if (glen > dlen) + break; + + l->rcv_nxt_state = msg_seqno(hdr) + 1; + + /* Update own tolerance if peer indicates a non-zero value */ + if (in_range(peers_tol, TIPC_MIN_LINK_TOL, TIPC_MAX_LINK_TOL)) { + l->tolerance = peers_tol; + l->bc_rcvlink->tolerance = peers_tol; + } + /* Update own prio if peer indicates a different value */ + if ((peers_prio != l->priority) && + in_range(peers_prio, 1, TIPC_MAX_LINK_PRI)) { + l->priority = peers_prio; + rc = tipc_link_fsm_evt(l, LINK_FAILURE_EVT); + } + + l->silent_intv_cnt = 0; + l->stats.recv_states++; + if (msg_probe(hdr)) + l->stats.recv_probes++; + + if (!link_is_up(l)) { + if (l->state == LINK_ESTABLISHING) + rc = TIPC_LINK_UP_EVT; + break; + } + + tipc_mon_rcv(l->net, data + glen, dlen - glen, l->addr, + &l->mon_state, l->bearer_id); + + /* Send NACK if peer has sent pkts we haven't received yet */ + if ((reply || msg_is_keepalive(hdr)) && + more(peers_snd_nxt, rcv_nxt) && + !tipc_link_is_synching(l) && + skb_queue_empty(&l->deferdq)) + rcvgap = peers_snd_nxt - l->rcv_nxt; + if (rcvgap || reply) + tipc_link_build_proto_msg(l, STATE_MSG, 0, reply, + rcvgap, 0, 0, xmitq); + + released = tipc_link_advance_transmq(l, l, ack, gap, ga, xmitq, + &retransmitted, &rc); + if (gap) + l->stats.recv_nacks++; + if (released || retransmitted) + tipc_link_update_cwin(l, released, retransmitted); + if (released) + tipc_link_advance_backlog(l, xmitq); + if (unlikely(!skb_queue_empty(&l->wakeupq))) + link_prepare_wakeup(l); + } +exit: + kfree_skb(skb); + return rc; +} + +/* tipc_link_build_bc_proto_msg() - create broadcast protocol message + */ +static bool tipc_link_build_bc_proto_msg(struct tipc_link *l, bool bcast, + u16 peers_snd_nxt, + struct sk_buff_head *xmitq) +{ + struct sk_buff *skb; + struct tipc_msg *hdr; + struct sk_buff *dfrd_skb = skb_peek(&l->deferdq); + u16 ack = l->rcv_nxt - 1; + u16 gap_to = peers_snd_nxt - 1; + + skb = tipc_msg_create(BCAST_PROTOCOL, STATE_MSG, INT_H_SIZE, + 0, l->addr, tipc_own_addr(l->net), 0, 0, 0); + if (!skb) + return false; + hdr = buf_msg(skb); + msg_set_last_bcast(hdr, l->bc_sndlink->snd_nxt - 1); + msg_set_bcast_ack(hdr, ack); + msg_set_bcgap_after(hdr, ack); + if (dfrd_skb) + gap_to = buf_seqno(dfrd_skb) - 1; + msg_set_bcgap_to(hdr, gap_to); + msg_set_non_seq(hdr, bcast); + __skb_queue_tail(xmitq, skb); + return true; +} + +/* tipc_link_build_bc_init_msg() - synchronize broadcast link endpoints. + * + * Give a newly added peer node the sequence number where it should + * start receiving and acking broadcast packets. 
+ */ +static void tipc_link_build_bc_init_msg(struct tipc_link *l, + struct sk_buff_head *xmitq) +{ + struct sk_buff_head list; + + __skb_queue_head_init(&list); + if (!tipc_link_build_bc_proto_msg(l->bc_rcvlink, false, 0, &list)) + return; + msg_set_bc_ack_invalid(buf_msg(skb_peek(&list)), true); + tipc_link_xmit(l, &list, xmitq); +} + +/* tipc_link_bc_init_rcv - receive initial broadcast synch data from peer + */ +void tipc_link_bc_init_rcv(struct tipc_link *l, struct tipc_msg *hdr) +{ + int mtyp = msg_type(hdr); + u16 peers_snd_nxt = msg_bc_snd_nxt(hdr); + + if (link_is_up(l)) + return; + + if (msg_user(hdr) == BCAST_PROTOCOL) { + l->rcv_nxt = peers_snd_nxt; + l->state = LINK_ESTABLISHED; + return; + } + + if (l->peer_caps & TIPC_BCAST_SYNCH) + return; + + if (msg_peer_node_is_up(hdr)) + return; + + /* Compatibility: accept older, less safe initial synch data */ + if ((mtyp == RESET_MSG) || (mtyp == ACTIVATE_MSG)) + l->rcv_nxt = peers_snd_nxt; +} + +/* tipc_link_bc_sync_rcv - update rcv link according to peer's send state + */ +int tipc_link_bc_sync_rcv(struct tipc_link *l, struct tipc_msg *hdr, + struct sk_buff_head *xmitq) +{ + u16 peers_snd_nxt = msg_bc_snd_nxt(hdr); + int rc = 0; + + if (!link_is_up(l)) + return rc; + + if (!msg_peer_node_is_up(hdr)) + return rc; + + /* Open when peer acknowledges our bcast init msg (pkt #1) */ + if (msg_ack(hdr)) + l->bc_peer_is_up = true; + + if (!l->bc_peer_is_up) + return rc; + + /* Ignore if peers_snd_nxt goes beyond receive window */ + if (more(peers_snd_nxt, l->rcv_nxt + l->window)) + return rc; + + l->snd_nxt = peers_snd_nxt; + if (link_bc_rcv_gap(l)) + rc |= TIPC_LINK_SND_STATE; + + /* Return now if sender supports nack via STATE messages */ + if (l->peer_caps & TIPC_BCAST_STATE_NACK) + return rc; + + /* Otherwise, be backwards compatible */ + + if (!more(peers_snd_nxt, l->rcv_nxt)) { + l->nack_state = BC_NACK_SND_CONDITIONAL; + return 0; + } + + /* Don't NACK if one was recently sent or peeked */ + if (l->nack_state == BC_NACK_SND_SUPPRESS) { + l->nack_state = BC_NACK_SND_UNCONDITIONAL; + return 0; + } + + /* Conditionally delay NACK sending until next synch rcv */ + if (l->nack_state == BC_NACK_SND_CONDITIONAL) { + l->nack_state = BC_NACK_SND_UNCONDITIONAL; + if ((peers_snd_nxt - l->rcv_nxt) < TIPC_MIN_LINK_WIN) + return 0; + } + + /* Send NACK now but suppress next one */ + tipc_link_build_bc_proto_msg(l, true, peers_snd_nxt, xmitq); + l->nack_state = BC_NACK_SND_SUPPRESS; + return 0; +} + +int tipc_link_bc_ack_rcv(struct tipc_link *r, u16 acked, u16 gap, + struct tipc_gap_ack_blks *ga, + struct sk_buff_head *xmitq, + struct sk_buff_head *retrq) +{ + struct tipc_link *l = r->bc_sndlink; + bool unused = false; + int rc = 0; + + if (!link_is_up(r) || !r->bc_peer_is_up) + return 0; + + if (gap) { + l->stats.recv_nacks++; + r->stats.recv_nacks++; + } + + if (less(acked, r->acked) || (acked == r->acked && !gap && !ga)) + return 0; + + trace_tipc_link_bc_ack(r, acked, gap, &l->transmq); + tipc_link_advance_transmq(l, r, acked, gap, ga, retrq, &unused, &rc); + + tipc_link_advance_backlog(l, xmitq); + if (unlikely(!skb_queue_empty(&l->wakeupq))) + link_prepare_wakeup(l); + + return rc; +} + +/* tipc_link_bc_nack_rcv(): receive broadcast nack message + * This function is here for backwards compatibility, since + * no BCAST_PROTOCOL/STATE messages occur from TIPC v2.5. 
+ */ +int tipc_link_bc_nack_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + struct tipc_msg *hdr = buf_msg(skb); + u32 dnode = msg_destnode(hdr); + int mtyp = msg_type(hdr); + u16 acked = msg_bcast_ack(hdr); + u16 from = acked + 1; + u16 to = msg_bcgap_to(hdr); + u16 peers_snd_nxt = to + 1; + int rc = 0; + + kfree_skb(skb); + + if (!tipc_link_is_up(l) || !l->bc_peer_is_up) + return 0; + + if (mtyp != STATE_MSG) + return 0; + + if (dnode == tipc_own_addr(l->net)) { + rc = tipc_link_bc_ack_rcv(l, acked, to - acked, NULL, xmitq, + xmitq); + l->stats.recv_nacks++; + return rc; + } + + /* Msg for other node => suppress own NACK at next sync if applicable */ + if (more(peers_snd_nxt, l->rcv_nxt) && !less(l->rcv_nxt, from)) + l->nack_state = BC_NACK_SND_SUPPRESS; + + return 0; +} + +void tipc_link_set_queue_limits(struct tipc_link *l, u32 min_win, u32 max_win) +{ + int max_bulk = TIPC_MAX_PUBL / (l->mtu / ITEM_SIZE); + + l->min_win = min_win; + l->ssthresh = max_win; + l->max_win = max_win; + l->window = min_win; + l->backlog[TIPC_LOW_IMPORTANCE].limit = min_win * 2; + l->backlog[TIPC_MEDIUM_IMPORTANCE].limit = min_win * 4; + l->backlog[TIPC_HIGH_IMPORTANCE].limit = min_win * 6; + l->backlog[TIPC_CRITICAL_IMPORTANCE].limit = min_win * 8; + l->backlog[TIPC_SYSTEM_IMPORTANCE].limit = max_bulk; +} + +/** + * tipc_link_reset_stats - reset link statistics + * @l: pointer to link + */ +void tipc_link_reset_stats(struct tipc_link *l) +{ + memset(&l->stats, 0, sizeof(l->stats)); +} + +static void link_print(struct tipc_link *l, const char *str) +{ + struct sk_buff *hskb = skb_peek(&l->transmq); + u16 head = hskb ? msg_seqno(buf_msg(hskb)) : l->snd_nxt - 1; + u16 tail = l->snd_nxt - 1; + + pr_info("%s Link <%s> state %x\n", str, l->name, l->state); + pr_info("XMTQ: %u [%u-%u], BKLGQ: %u, SNDNX: %u, RCVNX: %u\n", + skb_queue_len(&l->transmq), head, tail, + skb_queue_len(&l->backlogq), l->snd_nxt, l->rcv_nxt); +} + +/* Parse and validate nested (link) properties valid for media, bearer and link + */ +int tipc_nl_parse_link_prop(struct nlattr *prop, struct nlattr *props[]) +{ + int err; + + err = nla_parse_nested_deprecated(props, TIPC_NLA_PROP_MAX, prop, + tipc_nl_prop_policy, NULL); + if (err) + return err; + + if (props[TIPC_NLA_PROP_PRIO]) { + u32 prio; + + prio = nla_get_u32(props[TIPC_NLA_PROP_PRIO]); + if (prio > TIPC_MAX_LINK_PRI) + return -EINVAL; + } + + if (props[TIPC_NLA_PROP_TOL]) { + u32 tol; + + tol = nla_get_u32(props[TIPC_NLA_PROP_TOL]); + if ((tol < TIPC_MIN_LINK_TOL) || (tol > TIPC_MAX_LINK_TOL)) + return -EINVAL; + } + + if (props[TIPC_NLA_PROP_WIN]) { + u32 max_win; + + max_win = nla_get_u32(props[TIPC_NLA_PROP_WIN]); + if (max_win < TIPC_DEF_LINK_WIN || max_win > TIPC_MAX_LINK_WIN) + return -EINVAL; + } + + return 0; +} + +static int __tipc_nl_add_stats(struct sk_buff *skb, struct tipc_stats *s) +{ + int i; + struct nlattr *stats; + + struct nla_map { + u32 key; + u32 val; + }; + + struct nla_map map[] = { + {TIPC_NLA_STATS_RX_INFO, 0}, + {TIPC_NLA_STATS_RX_FRAGMENTS, s->recv_fragments}, + {TIPC_NLA_STATS_RX_FRAGMENTED, s->recv_fragmented}, + {TIPC_NLA_STATS_RX_BUNDLES, s->recv_bundles}, + {TIPC_NLA_STATS_RX_BUNDLED, s->recv_bundled}, + {TIPC_NLA_STATS_TX_INFO, 0}, + {TIPC_NLA_STATS_TX_FRAGMENTS, s->sent_fragments}, + {TIPC_NLA_STATS_TX_FRAGMENTED, s->sent_fragmented}, + {TIPC_NLA_STATS_TX_BUNDLES, s->sent_bundles}, + {TIPC_NLA_STATS_TX_BUNDLED, s->sent_bundled}, + {TIPC_NLA_STATS_MSG_PROF_TOT, (s->msg_length_counts) ? 
+ s->msg_length_counts : 1}, + {TIPC_NLA_STATS_MSG_LEN_CNT, s->msg_length_counts}, + {TIPC_NLA_STATS_MSG_LEN_TOT, s->msg_lengths_total}, + {TIPC_NLA_STATS_MSG_LEN_P0, s->msg_length_profile[0]}, + {TIPC_NLA_STATS_MSG_LEN_P1, s->msg_length_profile[1]}, + {TIPC_NLA_STATS_MSG_LEN_P2, s->msg_length_profile[2]}, + {TIPC_NLA_STATS_MSG_LEN_P3, s->msg_length_profile[3]}, + {TIPC_NLA_STATS_MSG_LEN_P4, s->msg_length_profile[4]}, + {TIPC_NLA_STATS_MSG_LEN_P5, s->msg_length_profile[5]}, + {TIPC_NLA_STATS_MSG_LEN_P6, s->msg_length_profile[6]}, + {TIPC_NLA_STATS_RX_STATES, s->recv_states}, + {TIPC_NLA_STATS_RX_PROBES, s->recv_probes}, + {TIPC_NLA_STATS_RX_NACKS, s->recv_nacks}, + {TIPC_NLA_STATS_RX_DEFERRED, s->deferred_recv}, + {TIPC_NLA_STATS_TX_STATES, s->sent_states}, + {TIPC_NLA_STATS_TX_PROBES, s->sent_probes}, + {TIPC_NLA_STATS_TX_NACKS, s->sent_nacks}, + {TIPC_NLA_STATS_TX_ACKS, s->sent_acks}, + {TIPC_NLA_STATS_RETRANSMITTED, s->retransmitted}, + {TIPC_NLA_STATS_DUPLICATES, s->duplicates}, + {TIPC_NLA_STATS_LINK_CONGS, s->link_congs}, + {TIPC_NLA_STATS_MAX_QUEUE, s->max_queue_sz}, + {TIPC_NLA_STATS_AVG_QUEUE, s->queue_sz_counts ? + (s->accu_queue_sz / s->queue_sz_counts) : 0} + }; + + stats = nla_nest_start_noflag(skb, TIPC_NLA_LINK_STATS); + if (!stats) + return -EMSGSIZE; + + for (i = 0; i < ARRAY_SIZE(map); i++) + if (nla_put_u32(skb, map[i].key, map[i].val)) + goto msg_full; + + nla_nest_end(skb, stats); + + return 0; +msg_full: + nla_nest_cancel(skb, stats); + + return -EMSGSIZE; +} + +/* Caller should hold appropriate locks to protect the link */ +int __tipc_nl_add_link(struct net *net, struct tipc_nl_msg *msg, + struct tipc_link *link, int nlflags) +{ + u32 self = tipc_own_addr(net); + struct nlattr *attrs; + struct nlattr *prop; + void *hdr; + int err; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + nlflags, TIPC_NL_LINK_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_LINK); + if (!attrs) + goto msg_full; + + if (nla_put_string(msg->skb, TIPC_NLA_LINK_NAME, link->name)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_DEST, tipc_cluster_mask(self))) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_MTU, link->mtu)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_RX, link->stats.recv_pkts)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_TX, link->stats.sent_pkts)) + goto attr_msg_full; + + if (tipc_link_is_up(link)) + if (nla_put_flag(msg->skb, TIPC_NLA_LINK_UP)) + goto attr_msg_full; + if (link->active) + if (nla_put_flag(msg->skb, TIPC_NLA_LINK_ACTIVE)) + goto attr_msg_full; + + prop = nla_nest_start_noflag(msg->skb, TIPC_NLA_LINK_PROP); + if (!prop) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_PRIO, link->priority)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_TOL, link->tolerance)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_WIN, + link->window)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_PRIO, link->priority)) + goto prop_msg_full; + nla_nest_end(msg->skb, prop); + + err = __tipc_nl_add_stats(msg->skb, &link->stats); + if (err) + goto attr_msg_full; + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +prop_msg_full: + nla_nest_cancel(msg->skb, prop); +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +static int __tipc_nl_add_bc_link_stat(struct sk_buff *skb, + struct 
tipc_stats *stats) +{ + int i; + struct nlattr *nest; + + struct nla_map { + __u32 key; + __u32 val; + }; + + struct nla_map map[] = { + {TIPC_NLA_STATS_RX_INFO, stats->recv_pkts}, + {TIPC_NLA_STATS_RX_FRAGMENTS, stats->recv_fragments}, + {TIPC_NLA_STATS_RX_FRAGMENTED, stats->recv_fragmented}, + {TIPC_NLA_STATS_RX_BUNDLES, stats->recv_bundles}, + {TIPC_NLA_STATS_RX_BUNDLED, stats->recv_bundled}, + {TIPC_NLA_STATS_TX_INFO, stats->sent_pkts}, + {TIPC_NLA_STATS_TX_FRAGMENTS, stats->sent_fragments}, + {TIPC_NLA_STATS_TX_FRAGMENTED, stats->sent_fragmented}, + {TIPC_NLA_STATS_TX_BUNDLES, stats->sent_bundles}, + {TIPC_NLA_STATS_TX_BUNDLED, stats->sent_bundled}, + {TIPC_NLA_STATS_RX_NACKS, stats->recv_nacks}, + {TIPC_NLA_STATS_RX_DEFERRED, stats->deferred_recv}, + {TIPC_NLA_STATS_TX_NACKS, stats->sent_nacks}, + {TIPC_NLA_STATS_TX_ACKS, stats->sent_acks}, + {TIPC_NLA_STATS_RETRANSMITTED, stats->retransmitted}, + {TIPC_NLA_STATS_DUPLICATES, stats->duplicates}, + {TIPC_NLA_STATS_LINK_CONGS, stats->link_congs}, + {TIPC_NLA_STATS_MAX_QUEUE, stats->max_queue_sz}, + {TIPC_NLA_STATS_AVG_QUEUE, stats->queue_sz_counts ? + (stats->accu_queue_sz / stats->queue_sz_counts) : 0} + }; + + nest = nla_nest_start_noflag(skb, TIPC_NLA_LINK_STATS); + if (!nest) + return -EMSGSIZE; + + for (i = 0; i < ARRAY_SIZE(map); i++) + if (nla_put_u32(skb, map[i].key, map[i].val)) + goto msg_full; + + nla_nest_end(skb, nest); + + return 0; +msg_full: + nla_nest_cancel(skb, nest); + + return -EMSGSIZE; +} + +int tipc_nl_add_bc_link(struct net *net, struct tipc_nl_msg *msg, + struct tipc_link *bcl) +{ + int err; + void *hdr; + struct nlattr *attrs; + struct nlattr *prop; + u32 bc_mode = tipc_bcast_get_mode(net); + u32 bc_ratio = tipc_bcast_get_broadcast_ratio(net); + + if (!bcl) + return 0; + + tipc_bcast_lock(net); + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + NLM_F_MULTI, TIPC_NL_LINK_GET); + if (!hdr) { + tipc_bcast_unlock(net); + return -EMSGSIZE; + } + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_LINK); + if (!attrs) + goto msg_full; + + /* The broadcast link is always up */ + if (nla_put_flag(msg->skb, TIPC_NLA_LINK_UP)) + goto attr_msg_full; + + if (nla_put_flag(msg->skb, TIPC_NLA_LINK_BROADCAST)) + goto attr_msg_full; + if (nla_put_string(msg->skb, TIPC_NLA_LINK_NAME, bcl->name)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_RX, 0)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_LINK_TX, 0)) + goto attr_msg_full; + + prop = nla_nest_start_noflag(msg->skb, TIPC_NLA_LINK_PROP); + if (!prop) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_WIN, bcl->max_win)) + goto prop_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_BROADCAST, bc_mode)) + goto prop_msg_full; + if (bc_mode & BCLINK_MODE_SEL) + if (nla_put_u32(msg->skb, TIPC_NLA_PROP_BROADCAST_RATIO, + bc_ratio)) + goto prop_msg_full; + nla_nest_end(msg->skb, prop); + + err = __tipc_nl_add_bc_link_stat(msg->skb, &bcl->stats); + if (err) + goto attr_msg_full; + + tipc_bcast_unlock(net); + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +prop_msg_full: + nla_nest_cancel(msg->skb, prop); +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + tipc_bcast_unlock(net); + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +void tipc_link_set_tolerance(struct tipc_link *l, u32 tol, + struct sk_buff_head *xmitq) +{ + l->tolerance = tol; + if (l->bc_rcvlink) + l->bc_rcvlink->tolerance = tol; + if (link_is_up(l)) + tipc_link_build_proto_msg(l, STATE_MSG, 0, 0, 
0, tol, 0, xmitq); +} + +void tipc_link_set_prio(struct tipc_link *l, u32 prio, + struct sk_buff_head *xmitq) +{ + l->priority = prio; + tipc_link_build_proto_msg(l, STATE_MSG, 0, 0, 0, 0, prio, xmitq); +} + +void tipc_link_set_abort_limit(struct tipc_link *l, u32 limit) +{ + l->abort_limit = limit; +} + +/** + * tipc_link_dump - dump TIPC link data + * @l: tipc link to be dumped + * @dqueues: bitmask to decide if any link queue to be dumped? + * - TIPC_DUMP_NONE: don't dump link queues + * - TIPC_DUMP_TRANSMQ: dump link transmq queue + * - TIPC_DUMP_BACKLOGQ: dump link backlog queue + * - TIPC_DUMP_DEFERDQ: dump link deferd queue + * - TIPC_DUMP_INPUTQ: dump link input queue + * - TIPC_DUMP_WAKEUP: dump link wakeup queue + * - TIPC_DUMP_ALL: dump all the link queues above + * @buf: returned buffer of dump data in format + */ +int tipc_link_dump(struct tipc_link *l, u16 dqueues, char *buf) +{ + int i = 0; + size_t sz = (dqueues) ? LINK_LMAX : LINK_LMIN; + struct sk_buff_head *list; + struct sk_buff *hskb, *tskb; + u32 len; + + if (!l) { + i += scnprintf(buf, sz, "link data: (null)\n"); + return i; + } + + i += scnprintf(buf, sz, "link data: %x", l->addr); + i += scnprintf(buf + i, sz - i, " %x", l->state); + i += scnprintf(buf + i, sz - i, " %u", l->in_session); + i += scnprintf(buf + i, sz - i, " %u", l->session); + i += scnprintf(buf + i, sz - i, " %u", l->peer_session); + i += scnprintf(buf + i, sz - i, " %u", l->snd_nxt); + i += scnprintf(buf + i, sz - i, " %u", l->rcv_nxt); + i += scnprintf(buf + i, sz - i, " %u", l->snd_nxt_state); + i += scnprintf(buf + i, sz - i, " %u", l->rcv_nxt_state); + i += scnprintf(buf + i, sz - i, " %x", l->peer_caps); + i += scnprintf(buf + i, sz - i, " %u", l->silent_intv_cnt); + i += scnprintf(buf + i, sz - i, " %u", l->rst_cnt); + i += scnprintf(buf + i, sz - i, " %u", 0); + i += scnprintf(buf + i, sz - i, " %u", 0); + i += scnprintf(buf + i, sz - i, " %u", l->acked); + + list = &l->transmq; + len = skb_queue_len(list); + hskb = skb_peek(list); + tskb = skb_peek_tail(list); + i += scnprintf(buf + i, sz - i, " | %u %u %u", len, + (hskb) ? msg_seqno(buf_msg(hskb)) : 0, + (tskb) ? msg_seqno(buf_msg(tskb)) : 0); + + list = &l->deferdq; + len = skb_queue_len(list); + hskb = skb_peek(list); + tskb = skb_peek_tail(list); + i += scnprintf(buf + i, sz - i, " | %u %u %u", len, + (hskb) ? msg_seqno(buf_msg(hskb)) : 0, + (tskb) ? msg_seqno(buf_msg(tskb)) : 0); + + list = &l->backlogq; + len = skb_queue_len(list); + hskb = skb_peek(list); + tskb = skb_peek_tail(list); + i += scnprintf(buf + i, sz - i, " | %u %u %u", len, + (hskb) ? msg_seqno(buf_msg(hskb)) : 0, + (tskb) ? msg_seqno(buf_msg(tskb)) : 0); + + list = l->inputq; + len = skb_queue_len(list); + hskb = skb_peek(list); + tskb = skb_peek_tail(list); + i += scnprintf(buf + i, sz - i, " | %u %u %u\n", len, + (hskb) ? msg_seqno(buf_msg(hskb)) : 0, + (tskb) ? 
msg_seqno(buf_msg(tskb)) : 0); + + if (dqueues & TIPC_DUMP_TRANSMQ) { + i += scnprintf(buf + i, sz - i, "transmq: "); + i += tipc_list_dump(&l->transmq, false, buf + i); + } + if (dqueues & TIPC_DUMP_BACKLOGQ) { + i += scnprintf(buf + i, sz - i, + "backlogq: <%u %u %u %u %u>, ", + l->backlog[TIPC_LOW_IMPORTANCE].len, + l->backlog[TIPC_MEDIUM_IMPORTANCE].len, + l->backlog[TIPC_HIGH_IMPORTANCE].len, + l->backlog[TIPC_CRITICAL_IMPORTANCE].len, + l->backlog[TIPC_SYSTEM_IMPORTANCE].len); + i += tipc_list_dump(&l->backlogq, false, buf + i); + } + if (dqueues & TIPC_DUMP_DEFERDQ) { + i += scnprintf(buf + i, sz - i, "deferdq: "); + i += tipc_list_dump(&l->deferdq, false, buf + i); + } + if (dqueues & TIPC_DUMP_INPUTQ) { + i += scnprintf(buf + i, sz - i, "inputq: "); + i += tipc_list_dump(l->inputq, false, buf + i); + } + if (dqueues & TIPC_DUMP_WAKEUP) { + i += scnprintf(buf + i, sz - i, "wakeup: "); + i += tipc_list_dump(&l->wakeupq, false, buf + i); + } + + return i; +} diff --git a/net/tipc/link.h b/net/tipc/link.h new file mode 100644 index 000000000..a16f401fd --- /dev/null +++ b/net/tipc/link.h @@ -0,0 +1,160 @@ +/* + * net/tipc/link.h: Include file for TIPC link code + * + * Copyright (c) 1995-2006, 2013-2014, Ericsson AB + * Copyright (c) 2004-2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_LINK_H +#define _TIPC_LINK_H + +#include +#include "msg.h" +#include "node.h" + +/* TIPC-specific error codes +*/ +#define ELINKCONG EAGAIN /* link congestion <=> resource unavailable */ + +/* Link FSM events: + */ +enum { + LINK_ESTABLISH_EVT = 0xec1ab1e, + LINK_PEER_RESET_EVT = 0x9eed0e, + LINK_FAILURE_EVT = 0xfa110e, + LINK_RESET_EVT = 0x10ca1d0e, + LINK_FAILOVER_BEGIN_EVT = 0xfa110bee, + LINK_FAILOVER_END_EVT = 0xfa110ede, + LINK_SYNCH_BEGIN_EVT = 0xc1ccbee, + LINK_SYNCH_END_EVT = 0xc1ccede +}; + +/* Events returned from link at packet reception or at timeout + */ +enum { + TIPC_LINK_UP_EVT = 1, + TIPC_LINK_DOWN_EVT = (1 << 1), + TIPC_LINK_SND_STATE = (1 << 2) +}; + +/* Starting value for maximum packet size negotiation on unicast links + * (unless bearer MTU is less) + */ +#define MAX_PKT_DEFAULT 1500 + +bool tipc_link_create(struct net *net, char *if_name, int bearer_id, + int tolerance, char net_plane, u32 mtu, int priority, + u32 min_win, u32 max_win, u32 session, u32 ownnode, + u32 peer, u8 *peer_id, u16 peer_caps, + struct tipc_link *bc_sndlink, + struct tipc_link *bc_rcvlink, + struct sk_buff_head *inputq, + struct sk_buff_head *namedq, + struct tipc_link **link); +bool tipc_link_bc_create(struct net *net, u32 ownnode, u32 peer, u8 *peer_id, + int mtu, u32 min_win, u32 max_win, u16 peer_caps, + struct sk_buff_head *inputq, + struct sk_buff_head *namedq, + struct tipc_link *bc_sndlink, + struct tipc_link **link); +void tipc_link_tnl_prepare(struct tipc_link *l, struct tipc_link *tnl, + int mtyp, struct sk_buff_head *xmitq); +void tipc_link_create_dummy_tnl_msg(struct tipc_link *tnl, + struct sk_buff_head *xmitq); +void tipc_link_failover_prepare(struct tipc_link *l, struct tipc_link *tnl, + struct sk_buff_head *xmitq); +void tipc_link_build_reset_msg(struct tipc_link *l, struct sk_buff_head *xmitq); +int tipc_link_fsm_evt(struct tipc_link *l, int evt); +bool tipc_link_is_up(struct tipc_link *l); +bool tipc_link_peer_is_down(struct tipc_link *l); +bool tipc_link_is_reset(struct tipc_link *l); +bool tipc_link_is_establishing(struct tipc_link *l); +bool tipc_link_is_synching(struct tipc_link *l); +bool tipc_link_is_failingover(struct tipc_link *l); +bool tipc_link_is_blocked(struct tipc_link *l); +void tipc_link_set_active(struct tipc_link *l, bool active); +void tipc_link_reset(struct tipc_link *l); +void tipc_link_reset_stats(struct tipc_link *l); +int tipc_link_xmit(struct tipc_link *link, struct sk_buff_head *list, + struct sk_buff_head *xmitq); +struct sk_buff_head *tipc_link_inputq(struct tipc_link *l); +u16 tipc_link_rcv_nxt(struct tipc_link *l); +u16 tipc_link_acked(struct tipc_link *l); +u32 tipc_link_id(struct tipc_link *l); +char *tipc_link_name(struct tipc_link *l); +u32 tipc_link_state(struct tipc_link *l); +char tipc_link_plane(struct tipc_link *l); +int tipc_link_prio(struct tipc_link *l); +int tipc_link_min_win(struct tipc_link *l); +int tipc_link_max_win(struct tipc_link *l); +void tipc_link_update_caps(struct tipc_link *l, u16 capabilities); +bool tipc_link_validate_msg(struct tipc_link *l, struct tipc_msg *hdr); +unsigned long tipc_link_tolerance(struct tipc_link *l); +void tipc_link_set_tolerance(struct tipc_link *l, u32 tol, + struct sk_buff_head *xmitq); +void tipc_link_set_prio(struct tipc_link *l, u32 prio, + struct sk_buff_head *xmitq); +void tipc_link_set_abort_limit(struct tipc_link *l, u32 limit); +void tipc_link_set_queue_limits(struct tipc_link *l, u32 min_win, u32 max_win); +int __tipc_nl_add_link(struct net *net, struct 
tipc_nl_msg *msg, + struct tipc_link *link, int nlflags); +int tipc_nl_parse_link_prop(struct nlattr *prop, struct nlattr *props[]); +int tipc_link_timeout(struct tipc_link *l, struct sk_buff_head *xmitq); +int tipc_link_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq); +int tipc_link_build_state_msg(struct tipc_link *l, struct sk_buff_head *xmitq); +void tipc_link_add_bc_peer(struct tipc_link *snd_l, + struct tipc_link *uc_l, + struct sk_buff_head *xmitq); +void tipc_link_remove_bc_peer(struct tipc_link *snd_l, + struct tipc_link *rcv_l, + struct sk_buff_head *xmitq); +int tipc_link_bc_peers(struct tipc_link *l); +void tipc_link_set_mtu(struct tipc_link *l, int mtu); +int tipc_link_mtu(struct tipc_link *l); +int tipc_link_mss(struct tipc_link *l); +u16 tipc_get_gap_ack_blks(struct tipc_gap_ack_blks **ga, struct tipc_link *l, + struct tipc_msg *hdr, bool uc); +int tipc_link_bc_ack_rcv(struct tipc_link *l, u16 acked, u16 gap, + struct tipc_gap_ack_blks *ga, + struct sk_buff_head *xmitq, + struct sk_buff_head *retrq); +void tipc_link_build_bc_sync_msg(struct tipc_link *l, + struct sk_buff_head *xmitq); +void tipc_link_bc_init_rcv(struct tipc_link *l, struct tipc_msg *hdr); +int tipc_link_bc_sync_rcv(struct tipc_link *l, struct tipc_msg *hdr, + struct sk_buff_head *xmitq); +int tipc_link_bc_nack_rcv(struct tipc_link *l, struct sk_buff *skb, + struct sk_buff_head *xmitq); +bool tipc_link_too_silent(struct tipc_link *l); +struct net *tipc_link_net(struct tipc_link *l); +#endif diff --git a/net/tipc/monitor.c b/net/tipc/monitor.c new file mode 100644 index 000000000..9618e4429 --- /dev/null +++ b/net/tipc/monitor.c @@ -0,0 +1,874 @@ +/* + * net/tipc/monitor.c + * + * Copyright (c) 2016, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include "core.h" +#include "addr.h" +#include "monitor.h" +#include "bearer.h" + +#define MAX_MON_DOMAIN 64 +#define MON_TIMEOUT 120000 +#define MAX_PEER_DOWN_EVENTS 4 + +/* struct tipc_mon_domain: domain record to be transferred between peers + * @len: actual size of domain record + * @gen: current generation of sender's domain + * @ack_gen: most recent generation of self's domain acked by peer + * @member_cnt: number of domain member nodes described in this record + * @up_map: bit map indicating which of the members the sender considers up + * @members: identity of the domain members + */ +struct tipc_mon_domain { + u16 len; + u16 gen; + u16 ack_gen; + u16 member_cnt; + u64 up_map; + u32 members[MAX_MON_DOMAIN]; +}; + +/* struct tipc_peer: state of a peer node and its domain + * @addr: tipc node identity of peer + * @head_map: shows which other nodes currently consider peer 'up' + * @domain: most recent domain record from peer + * @hash: position in hashed lookup list + * @list: position in linked list, in circular ascending order by 'addr' + * @applied: number of reported domain members applied on this monitor list + * @is_up: peer is up as seen from this node + * @is_head: peer is assigned domain head as seen from this node + * @is_local: peer is in local domain and should be continuously monitored + * @down_cnt: - numbers of other peers which have reported this on lost + */ +struct tipc_peer { + u32 addr; + struct tipc_mon_domain *domain; + struct hlist_node hash; + struct list_head list; + u8 applied; + u8 down_cnt; + bool is_up; + bool is_head; + bool is_local; +}; + +struct tipc_monitor { + struct hlist_head peers[NODE_HTABLE_SIZE]; + int peer_cnt; + struct tipc_peer *self; + rwlock_t lock; + struct tipc_mon_domain cache; + u16 list_gen; + u16 dom_gen; + struct net *net; + struct timer_list timer; + unsigned long timer_intv; +}; + +static struct tipc_monitor *tipc_monitor(struct net *net, int bearer_id) +{ + return tipc_net(net)->monitors[bearer_id]; +} + +const int tipc_max_domain_size = sizeof(struct tipc_mon_domain); + +static inline u16 mon_cpu_to_le16(u16 val) +{ + return (__force __u16)htons(val); +} + +static inline u32 mon_cpu_to_le32(u32 val) +{ + return (__force __u32)htonl(val); +} + +static inline u64 mon_cpu_to_le64(u64 val) +{ + return (__force __u64)cpu_to_be64(val); +} + +static inline u16 mon_le16_to_cpu(u16 val) +{ + return ntohs((__force __be16)val); +} + +static inline u32 mon_le32_to_cpu(u32 val) +{ + return ntohl((__force __be32)val); +} + +static inline u64 mon_le64_to_cpu(u64 val) +{ + return be64_to_cpu((__force __be64)val); +} + +/* dom_rec_len(): actual length of domain record for transport + */ +static int dom_rec_len(struct tipc_mon_domain *dom, u16 mcnt) +{ + return (offsetof(struct tipc_mon_domain, members)) + (mcnt * sizeof(u32)); +} + +/* dom_size() : calculate size of own domain based on number of peers + */ +static int dom_size(int peers) +{ + int i = 0; + + while ((i * i) < peers) + i++; + return i < MAX_MON_DOMAIN ? 
i : MAX_MON_DOMAIN; +} + +static void map_set(u64 *up_map, int i, unsigned int v) +{ + *up_map &= ~(1ULL << i); + *up_map |= ((u64)v << i); +} + +static int map_get(u64 up_map, int i) +{ + return (up_map & (1ULL << i)) >> i; +} + +static struct tipc_peer *peer_prev(struct tipc_peer *peer) +{ + return list_last_entry(&peer->list, struct tipc_peer, list); +} + +static struct tipc_peer *peer_nxt(struct tipc_peer *peer) +{ + return list_first_entry(&peer->list, struct tipc_peer, list); +} + +static struct tipc_peer *peer_head(struct tipc_peer *peer) +{ + while (!peer->is_head) + peer = peer_prev(peer); + return peer; +} + +static struct tipc_peer *get_peer(struct tipc_monitor *mon, u32 addr) +{ + struct tipc_peer *peer; + unsigned int thash = tipc_hashfn(addr); + + hlist_for_each_entry(peer, &mon->peers[thash], hash) { + if (peer->addr == addr) + return peer; + } + return NULL; +} + +static struct tipc_peer *get_self(struct net *net, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + + return mon->self; +} + +static inline bool tipc_mon_is_active(struct net *net, struct tipc_monitor *mon) +{ + struct tipc_net *tn = tipc_net(net); + + return mon->peer_cnt > tn->mon_threshold; +} + +/* mon_identify_lost_members() : - identify and mark potentially lost members + */ +static void mon_identify_lost_members(struct tipc_peer *peer, + struct tipc_mon_domain *dom_bef, + int applied_bef) +{ + struct tipc_peer *member = peer; + struct tipc_mon_domain *dom_aft = peer->domain; + int applied_aft = peer->applied; + int i; + + for (i = 0; i < applied_bef; i++) { + member = peer_nxt(member); + + /* Do nothing if self or peer already see member as down */ + if (!member->is_up || !map_get(dom_bef->up_map, i)) + continue; + + /* Loss of local node must be detected by active probing */ + if (member->is_local) + continue; + + /* Start probing if member was removed from applied domain */ + if (!applied_aft || (applied_aft < i)) { + member->down_cnt = 1; + continue; + } + + /* Member loss is confirmed if it is still in applied domain */ + if (!map_get(dom_aft->up_map, i)) + member->down_cnt++; + } +} + +/* mon_apply_domain() : match a peer's domain record against monitor list + */ +static void mon_apply_domain(struct tipc_monitor *mon, + struct tipc_peer *peer) +{ + struct tipc_mon_domain *dom = peer->domain; + struct tipc_peer *member; + u32 addr; + int i; + + if (!dom || !peer->is_up) + return; + + /* Scan across domain members and match against monitor list */ + peer->applied = 0; + member = peer_nxt(peer); + for (i = 0; i < dom->member_cnt; i++) { + addr = dom->members[i]; + if (addr != member->addr) + return; + peer->applied++; + member = peer_nxt(member); + } +} + +/* mon_update_local_domain() : update after peer addition/removal/up/down + */ +static void mon_update_local_domain(struct tipc_monitor *mon) +{ + struct tipc_peer *self = mon->self; + struct tipc_mon_domain *cache = &mon->cache; + struct tipc_mon_domain *dom = self->domain; + struct tipc_peer *peer = self; + u64 prev_up_map = dom->up_map; + u16 member_cnt, i; + bool diff; + + /* Update local domain size based on current size of cluster */ + member_cnt = dom_size(mon->peer_cnt) - 1; + self->applied = member_cnt; + + /* Update native and cached outgoing local domain records */ + dom->len = dom_rec_len(dom, member_cnt); + diff = dom->member_cnt != member_cnt; + dom->member_cnt = member_cnt; + for (i = 0; i < member_cnt; i++) { + peer = peer_nxt(peer); + diff |= dom->members[i] != peer->addr; + dom->members[i] = peer->addr; +
map_set(&dom->up_map, i, peer->is_up); + cache->members[i] = mon_cpu_to_le32(peer->addr); + } + diff |= dom->up_map != prev_up_map; + if (!diff) + return; + dom->gen = ++mon->dom_gen; + cache->len = mon_cpu_to_le16(dom->len); + cache->gen = mon_cpu_to_le16(dom->gen); + cache->member_cnt = mon_cpu_to_le16(member_cnt); + cache->up_map = mon_cpu_to_le64(dom->up_map); + mon_apply_domain(mon, self); +} + +/* mon_update_neighbors() : update preceding neighbors of added/removed peer + */ +static void mon_update_neighbors(struct tipc_monitor *mon, + struct tipc_peer *peer) +{ + int dz, i; + + dz = dom_size(mon->peer_cnt); + for (i = 0; i < dz; i++) { + mon_apply_domain(mon, peer); + peer = peer_prev(peer); + } +} + +/* mon_assign_roles() : reassign peer roles after a network change + * The monitor list is consistent at this stage; i.e., each peer is monitoring + * a set of domain members as matched between domain record and the monitor list + */ +static void mon_assign_roles(struct tipc_monitor *mon, struct tipc_peer *head) +{ + struct tipc_peer *peer = peer_nxt(head); + struct tipc_peer *self = mon->self; + int i = 0; + + for (; peer != self; peer = peer_nxt(peer)) { + peer->is_local = false; + + /* Update domain member */ + if (i++ < head->applied) { + peer->is_head = false; + if (head == self) + peer->is_local = true; + continue; + } + /* Assign next domain head */ + if (!peer->is_up) + continue; + if (peer->is_head) + break; + head = peer; + head->is_head = true; + i = 0; + } + mon->list_gen++; +} + +void tipc_mon_remove_peer(struct net *net, u32 addr, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *self; + struct tipc_peer *peer, *prev, *head; + + if (!mon) + return; + + self = get_self(net, bearer_id); + write_lock_bh(&mon->lock); + peer = get_peer(mon, addr); + if (!peer) + goto exit; + prev = peer_prev(peer); + list_del(&peer->list); + hlist_del(&peer->hash); + kfree(peer->domain); + kfree(peer); + mon->peer_cnt--; + head = peer_head(prev); + if (head == self) + mon_update_local_domain(mon); + mon_update_neighbors(mon, prev); + + /* Revert to full-mesh monitoring if we reach threshold */ + if (!tipc_mon_is_active(net, mon)) { + list_for_each_entry(peer, &self->list, list) { + kfree(peer->domain); + peer->domain = NULL; + peer->applied = 0; + } + } + mon_assign_roles(mon, head); +exit: + write_unlock_bh(&mon->lock); +} + +static bool tipc_mon_add_peer(struct tipc_monitor *mon, u32 addr, + struct tipc_peer **peer) +{ + struct tipc_peer *self = mon->self; + struct tipc_peer *cur, *prev, *p; + + p = kzalloc(sizeof(*p), GFP_ATOMIC); + *peer = p; + if (!p) + return false; + p->addr = addr; + + /* Add new peer to lookup list */ + INIT_LIST_HEAD(&p->list); + hlist_add_head(&p->hash, &mon->peers[tipc_hashfn(addr)]); + + /* Sort new peer into iterator list, in ascending circular order */ + prev = self; + list_for_each_entry(cur, &self->list, list) { + if ((addr > prev->addr) && (addr < cur->addr)) + break; + if (((addr < cur->addr) || (addr > prev->addr)) && + (prev->addr > cur->addr)) + break; + prev = cur; + } + list_add_tail(&p->list, &cur->list); + mon->peer_cnt++; + mon_update_neighbors(mon, p); + return true; +} + +void tipc_mon_peer_up(struct net *net, u32 addr, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *self = get_self(net, bearer_id); + struct tipc_peer *peer, *head; + + write_lock_bh(&mon->lock); + peer = get_peer(mon, addr); + if (!peer && !tipc_mon_add_peer(mon, addr, &peer)) + goto 
exit; + peer->is_up = true; + head = peer_head(peer); + if (head == self) + mon_update_local_domain(mon); + mon_assign_roles(mon, head); +exit: + write_unlock_bh(&mon->lock); +} + +void tipc_mon_peer_down(struct net *net, u32 addr, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *self; + struct tipc_peer *peer, *head; + struct tipc_mon_domain *dom; + int applied; + + if (!mon) + return; + + self = get_self(net, bearer_id); + write_lock_bh(&mon->lock); + peer = get_peer(mon, addr); + if (!peer) { + pr_warn("Mon: unknown link %x/%u DOWN\n", addr, bearer_id); + goto exit; + } + applied = peer->applied; + peer->applied = 0; + dom = peer->domain; + peer->domain = NULL; + if (peer->is_head) + mon_identify_lost_members(peer, dom, applied); + kfree(dom); + peer->is_up = false; + peer->is_head = false; + peer->is_local = false; + peer->down_cnt = 0; + head = peer_head(peer); + if (head == self) + mon_update_local_domain(mon); + mon_assign_roles(mon, head); +exit: + write_unlock_bh(&mon->lock); +} + +/* tipc_mon_rcv - process monitor domain event message + */ +void tipc_mon_rcv(struct net *net, void *data, u16 dlen, u32 addr, + struct tipc_mon_state *state, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_mon_domain *arrv_dom = data; + struct tipc_mon_domain dom_bef; + struct tipc_mon_domain *dom; + struct tipc_peer *peer; + u16 new_member_cnt = mon_le16_to_cpu(arrv_dom->member_cnt); + int new_dlen = dom_rec_len(arrv_dom, new_member_cnt); + u16 new_gen = mon_le16_to_cpu(arrv_dom->gen); + u16 acked_gen = mon_le16_to_cpu(arrv_dom->ack_gen); + u16 arrv_dlen = mon_le16_to_cpu(arrv_dom->len); + bool probing = state->probing; + int i, applied_bef; + + state->probing = false; + + /* Sanity check received domain record */ + if (new_member_cnt > MAX_MON_DOMAIN) + return; + if (dlen < dom_rec_len(arrv_dom, 0)) + return; + if (dlen != dom_rec_len(arrv_dom, new_member_cnt)) + return; + if (dlen < new_dlen || arrv_dlen != new_dlen) + return; + + /* Synch generation numbers with peer if link just came up */ + if (!state->synched) { + state->peer_gen = new_gen - 1; + state->acked_gen = acked_gen; + state->synched = true; + } + + if (more(acked_gen, state->acked_gen)) + state->acked_gen = acked_gen; + + /* Drop duplicate unless we are waiting for a probe response */ + if (!more(new_gen, state->peer_gen) && !probing) + return; + + write_lock_bh(&mon->lock); + peer = get_peer(mon, addr); + if (!peer || !peer->is_up) + goto exit; + + /* Peer is confirmed, stop any ongoing probing */ + peer->down_cnt = 0; + + /* Task is done for duplicate record */ + if (!more(new_gen, state->peer_gen)) + goto exit; + + state->peer_gen = new_gen; + + /* Cache current domain record for later use */ + dom_bef.member_cnt = 0; + dom = peer->domain; + if (dom) + memcpy(&dom_bef, dom, dom->len); + + /* Transform and store received domain record */ + if (!dom || (dom->len < new_dlen)) { + kfree(dom); + dom = kmalloc(new_dlen, GFP_ATOMIC); + peer->domain = dom; + if (!dom) + goto exit; + } + dom->len = new_dlen; + dom->gen = new_gen; + dom->member_cnt = new_member_cnt; + dom->up_map = mon_le64_to_cpu(arrv_dom->up_map); + for (i = 0; i < new_member_cnt; i++) + dom->members[i] = mon_le32_to_cpu(arrv_dom->members[i]); + + /* Update peers affected by this domain record */ + applied_bef = peer->applied; + mon_apply_domain(mon, peer); + mon_identify_lost_members(peer, &dom_bef, applied_bef); + mon_assign_roles(mon, peer_head(peer)); +exit: + 
write_unlock_bh(&mon->lock); +} + +void tipc_mon_prep(struct net *net, void *data, int *dlen, + struct tipc_mon_state *state, int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_mon_domain *dom = data; + u16 gen = mon->dom_gen; + u16 len; + + /* Send invalid record if not active */ + if (!tipc_mon_is_active(net, mon)) { + dom->len = 0; + return; + } + + /* Send only a dummy record with ack if peer has acked our last sent */ + if (likely(state->acked_gen == gen)) { + len = dom_rec_len(dom, 0); + *dlen = len; + dom->len = mon_cpu_to_le16(len); + dom->gen = mon_cpu_to_le16(gen); + dom->ack_gen = mon_cpu_to_le16(state->peer_gen); + dom->member_cnt = 0; + return; + } + /* Send the full record */ + read_lock_bh(&mon->lock); + len = mon_le16_to_cpu(mon->cache.len); + *dlen = len; + memcpy(data, &mon->cache, len); + read_unlock_bh(&mon->lock); + dom->ack_gen = mon_cpu_to_le16(state->peer_gen); +} + +void tipc_mon_get_state(struct net *net, u32 addr, + struct tipc_mon_state *state, + int bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *peer; + + if (!tipc_mon_is_active(net, mon)) { + state->probing = false; + state->monitoring = true; + return; + } + + /* Used cached state if table has not changed */ + if (!state->probing && + (state->list_gen == mon->list_gen) && + (state->acked_gen == mon->dom_gen)) + return; + + read_lock_bh(&mon->lock); + peer = get_peer(mon, addr); + if (peer) { + state->probing = state->acked_gen != mon->dom_gen; + state->probing |= peer->down_cnt; + state->reset |= peer->down_cnt >= MAX_PEER_DOWN_EVENTS; + state->monitoring = peer->is_local; + state->monitoring |= peer->is_head; + state->list_gen = mon->list_gen; + } + read_unlock_bh(&mon->lock); +} + +static void mon_timeout(struct timer_list *t) +{ + struct tipc_monitor *mon = from_timer(mon, t, timer); + struct tipc_peer *self; + int best_member_cnt = dom_size(mon->peer_cnt) - 1; + + write_lock_bh(&mon->lock); + self = mon->self; + if (self && (best_member_cnt != self->applied)) { + mon_update_local_domain(mon); + mon_assign_roles(mon, self); + } + write_unlock_bh(&mon->lock); + mod_timer(&mon->timer, jiffies + mon->timer_intv); +} + +int tipc_mon_create(struct net *net, int bearer_id) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_monitor *mon; + struct tipc_peer *self; + struct tipc_mon_domain *dom; + + if (tn->monitors[bearer_id]) + return 0; + + mon = kzalloc(sizeof(*mon), GFP_ATOMIC); + self = kzalloc(sizeof(*self), GFP_ATOMIC); + dom = kzalloc(sizeof(*dom), GFP_ATOMIC); + if (!mon || !self || !dom) { + kfree(mon); + kfree(self); + kfree(dom); + return -ENOMEM; + } + tn->monitors[bearer_id] = mon; + rwlock_init(&mon->lock); + mon->net = net; + mon->peer_cnt = 1; + mon->self = self; + self->domain = dom; + self->addr = tipc_own_addr(net); + self->is_up = true; + self->is_head = true; + INIT_LIST_HEAD(&self->list); + timer_setup(&mon->timer, mon_timeout, 0); + mon->timer_intv = msecs_to_jiffies(MON_TIMEOUT + (tn->random & 0xffff)); + mod_timer(&mon->timer, jiffies + mon->timer_intv); + return 0; +} + +void tipc_mon_delete(struct net *net, int bearer_id) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *self; + struct tipc_peer *peer, *tmp; + + if (!mon) + return; + + self = get_self(net, bearer_id); + write_lock_bh(&mon->lock); + tn->monitors[bearer_id] = NULL; + list_for_each_entry_safe(peer, tmp, &self->list, list) { + list_del(&peer->list); + 
hlist_del(&peer->hash); + kfree(peer->domain); + kfree(peer); + } + mon->self = NULL; + write_unlock_bh(&mon->lock); + del_timer_sync(&mon->timer); + kfree(self->domain); + kfree(self); + kfree(mon); +} + +void tipc_mon_reinit_self(struct net *net) +{ + struct tipc_monitor *mon; + int bearer_id; + + for (bearer_id = 0; bearer_id < MAX_BEARERS; bearer_id++) { + mon = tipc_monitor(net, bearer_id); + if (!mon) + continue; + write_lock_bh(&mon->lock); + mon->self->addr = tipc_own_addr(net); + write_unlock_bh(&mon->lock); + } +} + +int tipc_nl_monitor_set_threshold(struct net *net, u32 cluster_size) +{ + struct tipc_net *tn = tipc_net(net); + + if (cluster_size > TIPC_CLUSTER_SIZE) + return -EINVAL; + + tn->mon_threshold = cluster_size; + + return 0; +} + +int tipc_nl_monitor_get_threshold(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + + return tn->mon_threshold; +} + +static int __tipc_nl_add_monitor_peer(struct tipc_peer *peer, + struct tipc_nl_msg *msg) +{ + struct tipc_mon_domain *dom = peer->domain; + struct nlattr *attrs; + void *hdr; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + NLM_F_MULTI, TIPC_NL_MON_PEER_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_MON_PEER); + if (!attrs) + goto msg_full; + + if (nla_put_u32(msg->skb, TIPC_NLA_MON_PEER_ADDR, peer->addr)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_MON_PEER_APPLIED, peer->applied)) + goto attr_msg_full; + + if (peer->is_up) + if (nla_put_flag(msg->skb, TIPC_NLA_MON_PEER_UP)) + goto attr_msg_full; + if (peer->is_local) + if (nla_put_flag(msg->skb, TIPC_NLA_MON_PEER_LOCAL)) + goto attr_msg_full; + if (peer->is_head) + if (nla_put_flag(msg->skb, TIPC_NLA_MON_PEER_HEAD)) + goto attr_msg_full; + + if (dom) { + if (nla_put_u32(msg->skb, TIPC_NLA_MON_PEER_DOMGEN, dom->gen)) + goto attr_msg_full; + if (nla_put_u64_64bit(msg->skb, TIPC_NLA_MON_PEER_UPMAP, + dom->up_map, TIPC_NLA_MON_PEER_PAD)) + goto attr_msg_full; + if (nla_put(msg->skb, TIPC_NLA_MON_PEER_MEMBERS, + dom->member_cnt * sizeof(u32), &dom->members)) + goto attr_msg_full; + } + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + return 0; + +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_add_monitor_peer(struct net *net, struct tipc_nl_msg *msg, + u32 bearer_id, u32 *prev_node) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + struct tipc_peer *peer; + + if (!mon) + return -EINVAL; + + read_lock_bh(&mon->lock); + peer = mon->self; + do { + if (*prev_node) { + if (peer->addr == *prev_node) + *prev_node = 0; + else + continue; + } + if (__tipc_nl_add_monitor_peer(peer, msg)) { + *prev_node = peer->addr; + read_unlock_bh(&mon->lock); + return -EMSGSIZE; + } + } while ((peer = peer_nxt(peer)) != mon->self); + read_unlock_bh(&mon->lock); + + return 0; +} + +int __tipc_nl_add_monitor(struct net *net, struct tipc_nl_msg *msg, + u32 bearer_id) +{ + struct tipc_monitor *mon = tipc_monitor(net, bearer_id); + char bearer_name[TIPC_MAX_BEARER_NAME]; + struct nlattr *attrs; + void *hdr; + int ret; + + ret = tipc_bearer_get_name(net, bearer_name, bearer_id); + if (ret || !mon) + return 0; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + NLM_F_MULTI, TIPC_NL_MON_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_MON); + if (!attrs) + goto msg_full; + + read_lock_bh(&mon->lock); + if (nla_put_u32(msg->skb, 
TIPC_NLA_MON_REF, bearer_id)) + goto attr_msg_full; + if (tipc_mon_is_active(net, mon)) + if (nla_put_flag(msg->skb, TIPC_NLA_MON_ACTIVE)) + goto attr_msg_full; + if (nla_put_string(msg->skb, TIPC_NLA_MON_BEARER_NAME, bearer_name)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_MON_PEERCNT, mon->peer_cnt)) + goto attr_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_MON_LISTGEN, mon->list_gen)) + goto attr_msg_full; + + read_unlock_bh(&mon->lock); + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +attr_msg_full: + read_unlock_bh(&mon->lock); + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} diff --git a/net/tipc/monitor.h b/net/tipc/monitor.h new file mode 100644 index 000000000..ed63d2e65 --- /dev/null +++ b/net/tipc/monitor.h @@ -0,0 +1,83 @@ +/* + * net/tipc/monitor.h + * + * Copyright (c) 2015, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_MONITOR_H +#define _TIPC_MONITOR_H + +#include "netlink.h" + +/* struct tipc_mon_state: link instance's cache of monitor list and domain state + * @list_gen: current generation of this node's monitor list + * @gen: current generation of this node's local domain + * @peer_gen: most recent domain generation received from peer + * @acked_gen: most recent generation of self's domain acked by peer + * @monitoring: this peer endpoint should be continuously monitored + * @probing: peer endpoint should be temporarily probed for potential loss + * @reset: peer loss has been confirmed by enough other peers; link endpoint should be reset + * @synched: domain record's generation has been synched with peer after reset + */ +struct tipc_mon_state { + u16 list_gen; + u16 peer_gen; + u16 acked_gen; + bool monitoring :1; + bool probing :1; + bool reset :1; + bool synched :1; +}; + +int tipc_mon_create(struct net *net, int bearer_id); +void tipc_mon_delete(struct net *net, int bearer_id); + +void tipc_mon_peer_up(struct net *net, u32 addr, int bearer_id); +void tipc_mon_peer_down(struct net *net, u32 addr, int bearer_id); +void tipc_mon_prep(struct net *net, void *data, int *dlen, + struct tipc_mon_state *state, int bearer_id); +void tipc_mon_rcv(struct net *net, void *data, u16 dlen, u32 addr, + struct tipc_mon_state *state, int bearer_id); +void tipc_mon_get_state(struct net *net, u32 addr, + struct tipc_mon_state *state, + int bearer_id); +void tipc_mon_remove_peer(struct net *net, u32 addr, int bearer_id); + +int tipc_nl_monitor_set_threshold(struct net *net, u32 cluster_size); +int tipc_nl_monitor_get_threshold(struct net *net); +int __tipc_nl_add_monitor(struct net *net, struct tipc_nl_msg *msg, + u32 bearer_id); +int tipc_nl_add_monitor_peer(struct net *net, struct tipc_nl_msg *msg, + u32 bearer_id, u32 *prev_node); +void tipc_mon_reinit_self(struct net *net); + +extern const int tipc_max_domain_size; +#endif diff --git a/net/tipc/msg.c b/net/tipc/msg.c new file mode 100644 index 000000000..5c9fd4791 --- /dev/null +++ b/net/tipc/msg.c @@ -0,0 +1,851 @@ +/* + * net/tipc/msg.c: TIPC message header routines + * + * Copyright (c) 2000-2006, 2014-2015, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include +#include "core.h" +#include "msg.h" +#include "addr.h" +#include "name_table.h" +#include "crypto.h" + +#define BUF_ALIGN(x) ALIGN(x, 4) +#define MAX_FORWARD_SIZE 1024 +#ifdef CONFIG_TIPC_CRYPTO +#define BUF_HEADROOM ALIGN(((LL_MAX_HEADER + 48) + EHDR_MAX_SIZE), 16) +#define BUF_OVERHEAD (BUF_HEADROOM + TIPC_AES_GCM_TAG_SIZE) +#else +#define BUF_HEADROOM (LL_MAX_HEADER + 48) +#define BUF_OVERHEAD BUF_HEADROOM +#endif + +const int one_page_mtu = PAGE_SIZE - SKB_DATA_ALIGN(BUF_OVERHEAD) - + SKB_DATA_ALIGN(sizeof(struct skb_shared_info)); + +/** + * tipc_buf_acquire - creates a TIPC message buffer + * @size: message size (including TIPC header) + * @gfp: memory allocation flags + * + * Return: a new buffer with data pointers set to the specified size. + * + * NOTE: + * Headroom is reserved to allow prepending of a data link header. + * There may also be unrequested tailroom present at the buffer's end. + */ +struct sk_buff *tipc_buf_acquire(u32 size, gfp_t gfp) +{ + struct sk_buff *skb; + + skb = alloc_skb_fclone(BUF_OVERHEAD + size, gfp); + if (skb) { + skb_reserve(skb, BUF_HEADROOM); + skb_put(skb, size); + skb->next = NULL; + } + return skb; +} + +void tipc_msg_init(u32 own_node, struct tipc_msg *m, u32 user, u32 type, + u32 hsize, u32 dnode) +{ + memset(m, 0, hsize); + msg_set_version(m); + msg_set_user(m, user); + msg_set_hdr_sz(m, hsize); + msg_set_size(m, hsize); + msg_set_prevnode(m, own_node); + msg_set_type(m, type); + if (hsize > SHORT_H_SIZE) { + msg_set_orignode(m, own_node); + msg_set_destnode(m, dnode); + } +} + +struct sk_buff *tipc_msg_create(uint user, uint type, + uint hdr_sz, uint data_sz, u32 dnode, + u32 onode, u32 dport, u32 oport, int errcode) +{ + struct tipc_msg *msg; + struct sk_buff *buf; + + buf = tipc_buf_acquire(hdr_sz + data_sz, GFP_ATOMIC); + if (unlikely(!buf)) + return NULL; + + msg = buf_msg(buf); + tipc_msg_init(onode, msg, user, type, hdr_sz, dnode); + msg_set_size(msg, hdr_sz + data_sz); + msg_set_origport(msg, oport); + msg_set_destport(msg, dport); + msg_set_errcode(msg, errcode); + return buf; +} + +/* tipc_buf_append(): Append a buffer to the fragment list of another buffer + * @*headbuf: in: NULL for first frag, otherwise value returned from prev call + * out: set when successful non-complete reassembly, otherwise NULL + * @*buf: in: the buffer to append. 
Always defined + * out: head buf after successful complete reassembly, otherwise NULL + * Returns 1 when reassembly complete, otherwise 0 + */ +int tipc_buf_append(struct sk_buff **headbuf, struct sk_buff **buf) +{ + struct sk_buff *head = *headbuf; + struct sk_buff *frag = *buf; + struct sk_buff *tail = NULL; + struct tipc_msg *msg; + u32 fragid; + int delta; + bool headstolen; + + if (!frag) + goto err; + + msg = buf_msg(frag); + fragid = msg_type(msg); + frag->next = NULL; + skb_pull(frag, msg_hdr_sz(msg)); + + if (fragid == FIRST_FRAGMENT) { + if (unlikely(head)) + goto err; + *buf = NULL; + if (skb_has_frag_list(frag) && __skb_linearize(frag)) + goto err; + frag = skb_unshare(frag, GFP_ATOMIC); + if (unlikely(!frag)) + goto err; + head = *headbuf = frag; + TIPC_SKB_CB(head)->tail = NULL; + return 0; + } + + if (!head) + goto err; + + if (skb_try_coalesce(head, frag, &headstolen, &delta)) { + kfree_skb_partial(frag, headstolen); + } else { + tail = TIPC_SKB_CB(head)->tail; + if (!skb_has_frag_list(head)) + skb_shinfo(head)->frag_list = frag; + else + tail->next = frag; + head->truesize += frag->truesize; + head->data_len += frag->len; + head->len += frag->len; + TIPC_SKB_CB(head)->tail = frag; + } + + if (fragid == LAST_FRAGMENT) { + TIPC_SKB_CB(head)->validated = 0; + if (unlikely(!tipc_msg_validate(&head))) + goto err; + *buf = head; + TIPC_SKB_CB(head)->tail = NULL; + *headbuf = NULL; + return 1; + } + *buf = NULL; + return 0; +err: + kfree_skb(*buf); + kfree_skb(*headbuf); + *buf = *headbuf = NULL; + return 0; +} + +/** + * tipc_msg_append(): Append data to tail of an existing buffer queue + * @_hdr: header to be used + * @m: the data to be appended + * @mss: max allowable size of buffer + * @dlen: size of data to be appended + * @txq: queue to append to + * + * Return: the number of 1k blocks appended or errno value + */ +int tipc_msg_append(struct tipc_msg *_hdr, struct msghdr *m, int dlen, + int mss, struct sk_buff_head *txq) +{ + struct sk_buff *skb; + int accounted, total, curr; + int mlen, cpy, rem = dlen; + struct tipc_msg *hdr; + + skb = skb_peek_tail(txq); + accounted = skb ? msg_blocks(buf_msg(skb)) : 0; + total = accounted; + + do { + if (!skb || skb->len >= mss) { + skb = tipc_buf_acquire(mss, GFP_KERNEL); + if (unlikely(!skb)) + return -ENOMEM; + skb_orphan(skb); + skb_trim(skb, MIN_H_SIZE); + hdr = buf_msg(skb); + skb_copy_to_linear_data(skb, _hdr, MIN_H_SIZE); + msg_set_hdr_sz(hdr, MIN_H_SIZE); + msg_set_size(hdr, MIN_H_SIZE); + __skb_queue_tail(txq, skb); + total += 1; + } + hdr = buf_msg(skb); + curr = msg_blocks(hdr); + mlen = msg_size(hdr); + cpy = min_t(size_t, rem, mss - mlen); + if (cpy != copy_from_iter(skb->data + mlen, cpy, &m->msg_iter)) + return -EFAULT; + msg_set_size(hdr, mlen + cpy); + skb_put(skb, cpy); + rem -= cpy; + total += msg_blocks(hdr) - curr; + } while (rem > 0); + return total - accounted; +} + +/* tipc_msg_validate - validate basic format of received message + * + * This routine ensures a TIPC message has an acceptable header, and at least + * as much data as the header indicates it should. The routine also ensures + * that the entire message header is stored in the main fragment of the message + * buffer, to simplify future access to message header fields. + * + * Note: Having extra info present in the message header or data areas is OK. + * TIPC will ignore the excess, under the assumption that it is optional info + * introduced by a later release of the protocol. 
+ */ +bool tipc_msg_validate(struct sk_buff **_skb) +{ + struct sk_buff *skb = *_skb; + struct tipc_msg *hdr; + int msz, hsz; + + /* Ensure that flow control ratio condition is satisfied */ + if (unlikely(skb->truesize / buf_roundup_len(skb) >= 4)) { + skb = skb_copy_expand(skb, BUF_HEADROOM, 0, GFP_ATOMIC); + if (!skb) + return false; + kfree_skb(*_skb); + *_skb = skb; + } + + if (unlikely(TIPC_SKB_CB(skb)->validated)) + return true; + + if (unlikely(!pskb_may_pull(skb, MIN_H_SIZE))) + return false; + + hsz = msg_hdr_sz(buf_msg(skb)); + if (unlikely(hsz < MIN_H_SIZE) || (hsz > MAX_H_SIZE)) + return false; + if (unlikely(!pskb_may_pull(skb, hsz))) + return false; + + hdr = buf_msg(skb); + if (unlikely(msg_version(hdr) != TIPC_VERSION)) + return false; + + msz = msg_size(hdr); + if (unlikely(msz < hsz)) + return false; + if (unlikely((msz - hsz) > TIPC_MAX_USER_MSG_SIZE)) + return false; + if (unlikely(skb->len < msz)) + return false; + + TIPC_SKB_CB(skb)->validated = 1; + return true; +} + +/** + * tipc_msg_fragment - build a fragment skb list for TIPC message + * + * @skb: TIPC message skb + * @hdr: internal msg header to be put on the top of the fragments + * @pktmax: max size of a fragment incl. the header + * @frags: returned fragment skb list + * + * Return: 0 if the fragmentation is successful, otherwise: -EINVAL + * or -ENOMEM + */ +int tipc_msg_fragment(struct sk_buff *skb, const struct tipc_msg *hdr, + int pktmax, struct sk_buff_head *frags) +{ + int pktno, nof_fragms, dsz, dmax, eat; + struct tipc_msg *_hdr; + struct sk_buff *_skb; + u8 *data; + + /* Non-linear buffer? */ + if (skb_linearize(skb)) + return -ENOMEM; + + data = (u8 *)skb->data; + dsz = msg_size(buf_msg(skb)); + dmax = pktmax - INT_H_SIZE; + if (dsz <= dmax || !dmax) + return -EINVAL; + + nof_fragms = dsz / dmax + 1; + for (pktno = 1; pktno <= nof_fragms; pktno++) { + if (pktno < nof_fragms) + eat = dmax; + else + eat = dsz % dmax; + /* Allocate a new fragment */ + _skb = tipc_buf_acquire(INT_H_SIZE + eat, GFP_ATOMIC); + if (!_skb) + goto error; + skb_orphan(_skb); + __skb_queue_tail(frags, _skb); + /* Copy header & data to the fragment */ + skb_copy_to_linear_data(_skb, hdr, INT_H_SIZE); + skb_copy_to_linear_data_offset(_skb, INT_H_SIZE, data, eat); + data += eat; + /* Update the fragment's header */ + _hdr = buf_msg(_skb); + msg_set_fragm_no(_hdr, pktno); + msg_set_nof_fragms(_hdr, nof_fragms); + msg_set_size(_hdr, INT_H_SIZE + eat); + } + return 0; + +error: + __skb_queue_purge(frags); + __skb_queue_head_init(frags); + return -ENOMEM; +} + +/** + * tipc_msg_build - create buffer chain containing specified header and data + * @mhdr: Message header, to be prepended to data + * @m: User message + * @offset: buffer offset for fragmented messages (FIXME) + * @dsz: Total length of user data + * @pktmax: Max packet size that can be used + * @list: Buffer or chain of buffers to be returned to caller + * + * Note that the recursive call we are making here is safe, since it can + * logically go only one further level down. + * + * Return: message data size or errno: -ENOMEM, -EFAULT + */ +int tipc_msg_build(struct tipc_msg *mhdr, struct msghdr *m, int offset, + int dsz, int pktmax, struct sk_buff_head *list) +{ + int mhsz = msg_hdr_sz(mhdr); + struct tipc_msg pkthdr; + int msz = mhsz + dsz; + int pktrem = pktmax; + struct sk_buff *skb; + int drem = dsz; + int pktno = 1; + char *pktpos; + int pktsz; + int rc; + + msg_set_size(mhdr, msz); + + /* No fragmentation needed? 
*/ + if (likely(msz <= pktmax)) { + skb = tipc_buf_acquire(msz, GFP_KERNEL); + + /* Fall back to smaller MTU if node local message */ + if (unlikely(!skb)) { + if (pktmax != MAX_MSG_SIZE) + return -ENOMEM; + rc = tipc_msg_build(mhdr, m, offset, dsz, + one_page_mtu, list); + if (rc != dsz) + return rc; + if (tipc_msg_assemble(list)) + return dsz; + return -ENOMEM; + } + skb_orphan(skb); + __skb_queue_tail(list, skb); + skb_copy_to_linear_data(skb, mhdr, mhsz); + pktpos = skb->data + mhsz; + if (copy_from_iter_full(pktpos, dsz, &m->msg_iter)) + return dsz; + rc = -EFAULT; + goto error; + } + + /* Prepare reusable fragment header */ + tipc_msg_init(msg_prevnode(mhdr), &pkthdr, MSG_FRAGMENTER, + FIRST_FRAGMENT, INT_H_SIZE, msg_destnode(mhdr)); + msg_set_size(&pkthdr, pktmax); + msg_set_fragm_no(&pkthdr, pktno); + msg_set_importance(&pkthdr, msg_importance(mhdr)); + + /* Prepare first fragment */ + skb = tipc_buf_acquire(pktmax, GFP_KERNEL); + if (!skb) + return -ENOMEM; + skb_orphan(skb); + __skb_queue_tail(list, skb); + pktpos = skb->data; + skb_copy_to_linear_data(skb, &pkthdr, INT_H_SIZE); + pktpos += INT_H_SIZE; + pktrem -= INT_H_SIZE; + skb_copy_to_linear_data_offset(skb, INT_H_SIZE, mhdr, mhsz); + pktpos += mhsz; + pktrem -= mhsz; + + do { + if (drem < pktrem) + pktrem = drem; + + if (!copy_from_iter_full(pktpos, pktrem, &m->msg_iter)) { + rc = -EFAULT; + goto error; + } + drem -= pktrem; + + if (!drem) + break; + + /* Prepare new fragment: */ + if (drem < (pktmax - INT_H_SIZE)) + pktsz = drem + INT_H_SIZE; + else + pktsz = pktmax; + skb = tipc_buf_acquire(pktsz, GFP_KERNEL); + if (!skb) { + rc = -ENOMEM; + goto error; + } + skb_orphan(skb); + __skb_queue_tail(list, skb); + msg_set_type(&pkthdr, FRAGMENT); + msg_set_size(&pkthdr, pktsz); + msg_set_fragm_no(&pkthdr, ++pktno); + skb_copy_to_linear_data(skb, &pkthdr, INT_H_SIZE); + pktpos = skb->data + INT_H_SIZE; + pktrem = pktsz - INT_H_SIZE; + + } while (1); + msg_set_type(buf_msg(skb), LAST_FRAGMENT); + return dsz; +error: + __skb_queue_purge(list); + __skb_queue_head_init(list); + return rc; +} + +/** + * tipc_msg_bundle - Append contents of a buffer to tail of an existing one + * @bskb: the bundle buffer to append to + * @msg: message to be appended + * @max: max allowable size for the bundle buffer + * + * Return: "true" if bundling has been performed, otherwise "false" + */ +static bool tipc_msg_bundle(struct sk_buff *bskb, struct tipc_msg *msg, + u32 max) +{ + struct tipc_msg *bmsg = buf_msg(bskb); + u32 msz, bsz, offset, pad; + + msz = msg_size(msg); + bsz = msg_size(bmsg); + offset = BUF_ALIGN(bsz); + pad = offset - bsz; + + if (unlikely(skb_tailroom(bskb) < (pad + msz))) + return false; + if (unlikely(max < (offset + msz))) + return false; + + skb_put(bskb, pad + msz); + skb_copy_to_linear_data_offset(bskb, offset, msg, msz); + msg_set_size(bmsg, offset + msz); + msg_set_msgcnt(bmsg, msg_msgcnt(bmsg) + 1); + return true; +} + +/** + * tipc_msg_try_bundle - Try to bundle a new message to the last one + * @tskb: the last/target message to which the new one will be appended + * @skb: the new message skb pointer + * @mss: max message size (header inclusive) + * @dnode: destination node for the message + * @new_bundle: if this call made a new bundle or not + * + * Return: "true" if the new message skb is potential for bundling this time or + * later, in the case a bundling has been done this time, the skb is consumed + * (the skb pointer = NULL). + * Otherwise, "false" if the skb cannot be bundled at all. 
+ */ +bool tipc_msg_try_bundle(struct sk_buff *tskb, struct sk_buff **skb, u32 mss, + u32 dnode, bool *new_bundle) +{ + struct tipc_msg *msg, *inner, *outer; + u32 tsz; + + /* First, check if the new buffer is suitable for bundling */ + msg = buf_msg(*skb); + if (msg_user(msg) == MSG_FRAGMENTER) + return false; + if (msg_user(msg) == TUNNEL_PROTOCOL) + return false; + if (msg_user(msg) == BCAST_PROTOCOL) + return false; + if (mss <= INT_H_SIZE + msg_size(msg)) + return false; + + /* Ok, but the last/target buffer can be empty? */ + if (unlikely(!tskb)) + return true; + + /* Is it a bundle already? Try to bundle the new message to it */ + if (msg_user(buf_msg(tskb)) == MSG_BUNDLER) { + *new_bundle = false; + goto bundle; + } + + /* Make a new bundle of the two messages if possible */ + tsz = msg_size(buf_msg(tskb)); + if (unlikely(mss < BUF_ALIGN(INT_H_SIZE + tsz) + msg_size(msg))) + return true; + if (unlikely(pskb_expand_head(tskb, INT_H_SIZE, mss - tsz - INT_H_SIZE, + GFP_ATOMIC))) + return true; + inner = buf_msg(tskb); + skb_push(tskb, INT_H_SIZE); + outer = buf_msg(tskb); + tipc_msg_init(msg_prevnode(inner), outer, MSG_BUNDLER, 0, INT_H_SIZE, + dnode); + msg_set_importance(outer, msg_importance(inner)); + msg_set_size(outer, INT_H_SIZE + tsz); + msg_set_msgcnt(outer, 1); + *new_bundle = true; + +bundle: + if (likely(tipc_msg_bundle(tskb, msg, mss))) { + consume_skb(*skb); + *skb = NULL; + } + return true; +} + +/** + * tipc_msg_extract(): extract bundled inner packet from buffer + * @skb: buffer to be extracted from. + * @iskb: extracted inner buffer, to be returned + * @pos: position in outer message of msg to be extracted. + * Returns position of next msg. + * Consumes outer buffer when last packet extracted + * Return: true when there is an extracted buffer, otherwise false + */ +bool tipc_msg_extract(struct sk_buff *skb, struct sk_buff **iskb, int *pos) +{ + struct tipc_msg *hdr, *ihdr; + int imsz; + + *iskb = NULL; + if (unlikely(skb_linearize(skb))) + goto none; + + hdr = buf_msg(skb); + if (unlikely(*pos > (msg_data_sz(hdr) - MIN_H_SIZE))) + goto none; + + ihdr = (struct tipc_msg *)(msg_data(hdr) + *pos); + imsz = msg_size(ihdr); + + if ((*pos + imsz) > msg_data_sz(hdr)) + goto none; + + *iskb = tipc_buf_acquire(imsz, GFP_ATOMIC); + if (!*iskb) + goto none; + + skb_copy_to_linear_data(*iskb, ihdr, imsz); + if (unlikely(!tipc_msg_validate(iskb))) + goto none; + + *pos += BUF_ALIGN(imsz); + return true; +none: + kfree_skb(skb); + kfree_skb(*iskb); + *iskb = NULL; + return false; +} + +/** + * tipc_msg_reverse(): swap source and destination addresses and add error code + * @own_node: originating node id for reversed message + * @skb: buffer containing message to be reversed; will be consumed + * @err: error code to be set in message, if any + * Replaces consumed buffer with new one when successful + * Return: true if success, otherwise false + */ +bool tipc_msg_reverse(u32 own_node, struct sk_buff **skb, int err) +{ + struct sk_buff *_skb = *skb; + struct tipc_msg *_hdr, *hdr; + int hlen, dlen; + + if (skb_linearize(_skb)) + goto exit; + _hdr = buf_msg(_skb); + dlen = min_t(uint, msg_data_sz(_hdr), MAX_FORWARD_SIZE); + hlen = msg_hdr_sz(_hdr); + + if (msg_dest_droppable(_hdr)) + goto exit; + if (msg_errcode(_hdr)) + goto exit; + + /* Never return SHORT header */ + if (hlen == SHORT_H_SIZE) + hlen = BASIC_H_SIZE; + + /* Don't return data along with SYN+, - sender has a clone */ + if (msg_is_syn(_hdr) && err == TIPC_ERR_OVERLOAD) + dlen = 0; + + /* Allocate new buffer to return */ + 
*skb = tipc_buf_acquire(hlen + dlen, GFP_ATOMIC); + if (!*skb) + goto exit; + memcpy((*skb)->data, _skb->data, msg_hdr_sz(_hdr)); + memcpy((*skb)->data + hlen, msg_data(_hdr), dlen); + + /* Build reverse header in new buffer */ + hdr = buf_msg(*skb); + msg_set_hdr_sz(hdr, hlen); + msg_set_errcode(hdr, err); + msg_set_non_seq(hdr, 0); + msg_set_origport(hdr, msg_destport(_hdr)); + msg_set_destport(hdr, msg_origport(_hdr)); + msg_set_destnode(hdr, msg_prevnode(_hdr)); + msg_set_prevnode(hdr, own_node); + msg_set_orignode(hdr, own_node); + msg_set_size(hdr, hlen + dlen); + skb_orphan(_skb); + kfree_skb(_skb); + return true; +exit: + kfree_skb(_skb); + *skb = NULL; + return false; +} + +bool tipc_msg_skb_clone(struct sk_buff_head *msg, struct sk_buff_head *cpy) +{ + struct sk_buff *skb, *_skb; + + skb_queue_walk(msg, skb) { + _skb = skb_clone(skb, GFP_ATOMIC); + if (!_skb) { + __skb_queue_purge(cpy); + pr_err_ratelimited("Failed to clone buffer chain\n"); + return false; + } + __skb_queue_tail(cpy, _skb); + } + return true; +} + +/** + * tipc_msg_lookup_dest(): try to find new destination for named message + * @net: pointer to associated network namespace + * @skb: the buffer containing the message. + * @err: error code to be used by caller if lookup fails + * Does not consume buffer + * Return: true if a destination is found, false otherwise + */ +bool tipc_msg_lookup_dest(struct net *net, struct sk_buff *skb, int *err) +{ + struct tipc_msg *msg = buf_msg(skb); + u32 scope = msg_lookup_scope(msg); + u32 self = tipc_own_addr(net); + u32 inst = msg_nameinst(msg); + struct tipc_socket_addr sk; + struct tipc_uaddr ua; + + if (!msg_isdata(msg)) + return false; + if (!msg_named(msg)) + return false; + if (msg_errcode(msg)) + return false; + *err = TIPC_ERR_NO_NAME; + if (skb_linearize(skb)) + return false; + msg = buf_msg(skb); + if (msg_reroute_cnt(msg)) + return false; + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, scope, + msg_nametype(msg), inst, inst); + sk.node = tipc_scope2node(net, scope); + if (!tipc_nametbl_lookup_anycast(net, &ua, &sk)) + return false; + msg_incr_reroute_cnt(msg); + if (sk.node != self) + msg_set_prevnode(msg, self); + msg_set_destnode(msg, sk.node); + msg_set_destport(msg, sk.ref); + *err = TIPC_OK; + + return true; +} + +/* tipc_msg_assemble() - assemble chain of fragments into one message + */ +bool tipc_msg_assemble(struct sk_buff_head *list) +{ + struct sk_buff *skb, *tmp = NULL; + + if (skb_queue_len(list) == 1) + return true; + + while ((skb = __skb_dequeue(list))) { + skb->next = NULL; + if (tipc_buf_append(&tmp, &skb)) { + __skb_queue_tail(list, skb); + return true; + } + if (!tmp) + break; + } + __skb_queue_purge(list); + __skb_queue_head_init(list); + pr_warn("Failed do assemble buffer\n"); + return false; +} + +/* tipc_msg_reassemble() - clone a buffer chain of fragments and + * reassemble the clones into one message + */ +bool tipc_msg_reassemble(struct sk_buff_head *list, struct sk_buff_head *rcvq) +{ + struct sk_buff *skb, *_skb; + struct sk_buff *frag = NULL; + struct sk_buff *head = NULL; + int hdr_len; + + /* Copy header if single buffer */ + if (skb_queue_len(list) == 1) { + skb = skb_peek(list); + hdr_len = skb_headroom(skb) + msg_hdr_sz(buf_msg(skb)); + _skb = __pskb_copy(skb, hdr_len, GFP_ATOMIC); + if (!_skb) + return false; + __skb_queue_tail(rcvq, _skb); + return true; + } + + /* Clone all fragments and reassemble */ + skb_queue_walk(list, skb) { + frag = skb_clone(skb, GFP_ATOMIC); + if (!frag) + goto error; + frag->next = NULL; + if 
(tipc_buf_append(&head, &frag)) + break; + if (!head) + goto error; + } + __skb_queue_tail(rcvq, frag); + return true; +error: + pr_warn("Failed do clone local mcast rcv buffer\n"); + kfree_skb(head); + return false; +} + +bool tipc_msg_pskb_copy(u32 dst, struct sk_buff_head *msg, + struct sk_buff_head *cpy) +{ + struct sk_buff *skb, *_skb; + + skb_queue_walk(msg, skb) { + _skb = pskb_copy(skb, GFP_ATOMIC); + if (!_skb) { + __skb_queue_purge(cpy); + return false; + } + msg_set_destnode(buf_msg(_skb), dst); + __skb_queue_tail(cpy, _skb); + } + return true; +} + +/* tipc_skb_queue_sorted(); sort pkt into list according to sequence number + * @list: list to be appended to + * @seqno: sequence number of buffer to add + * @skb: buffer to add + */ +bool __tipc_skb_queue_sorted(struct sk_buff_head *list, u16 seqno, + struct sk_buff *skb) +{ + struct sk_buff *_skb, *tmp; + + if (skb_queue_empty(list) || less(seqno, buf_seqno(skb_peek(list)))) { + __skb_queue_head(list, skb); + return true; + } + + if (more(seqno, buf_seqno(skb_peek_tail(list)))) { + __skb_queue_tail(list, skb); + return true; + } + + skb_queue_walk_safe(list, _skb, tmp) { + if (more(seqno, buf_seqno(_skb))) + continue; + if (seqno == buf_seqno(_skb)) + break; + __skb_queue_before(list, _skb, skb); + return true; + } + kfree_skb(skb); + return false; +} + +void tipc_skb_reject(struct net *net, int err, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + if (tipc_msg_reverse(tipc_own_addr(net), &skb, err)) + __skb_queue_tail(xmitq, skb); +} diff --git a/net/tipc/msg.h b/net/tipc/msg.h new file mode 100644 index 000000000..c5eec1621 --- /dev/null +++ b/net/tipc/msg.h @@ -0,0 +1,1310 @@ +/* + * net/tipc/msg.h: Include file for TIPC message header routines + * + * Copyright (c) 2000-2007, 2014-2017 Ericsson AB + * Copyright (c) 2005-2008, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_MSG_H +#define _TIPC_MSG_H + +#include +#include "core.h" + +/* + * Constants and routines used to read and write TIPC payload message headers + * + * Note: Some items are also used with TIPC internal message headers + */ +#define TIPC_VERSION 2 +struct plist; + +/* + * Payload message users are defined in TIPC's public API: + * - TIPC_LOW_IMPORTANCE + * - TIPC_MEDIUM_IMPORTANCE + * - TIPC_HIGH_IMPORTANCE + * - TIPC_CRITICAL_IMPORTANCE + */ +#define TIPC_SYSTEM_IMPORTANCE 4 + + +/* + * Payload message types + */ +#define TIPC_CONN_MSG 0 +#define TIPC_MCAST_MSG 1 +#define TIPC_NAMED_MSG 2 +#define TIPC_DIRECT_MSG 3 +#define TIPC_GRP_MEMBER_EVT 4 +#define TIPC_GRP_BCAST_MSG 5 +#define TIPC_GRP_MCAST_MSG 6 +#define TIPC_GRP_UCAST_MSG 7 + +/* + * Internal message users + */ +#define BCAST_PROTOCOL 5 +#define MSG_BUNDLER 6 +#define LINK_PROTOCOL 7 +#define CONN_MANAGER 8 +#define GROUP_PROTOCOL 9 +#define TUNNEL_PROTOCOL 10 +#define NAME_DISTRIBUTOR 11 +#define MSG_FRAGMENTER 12 +#define LINK_CONFIG 13 +#define MSG_CRYPTO 14 +#define SOCK_WAKEUP 14 /* pseudo user */ +#define TOP_SRV 15 /* pseudo user */ + +/* + * Message header sizes + */ +#define SHORT_H_SIZE 24 /* In-cluster basic payload message */ +#define BASIC_H_SIZE 32 /* Basic payload message */ +#define NAMED_H_SIZE 40 /* Named payload message */ +#define MCAST_H_SIZE 44 /* Multicast payload message */ +#define GROUP_H_SIZE 44 /* Group payload message */ +#define INT_H_SIZE 40 /* Internal messages */ +#define MIN_H_SIZE 24 /* Smallest legal TIPC header size */ +#define MAX_H_SIZE 60 /* Largest possible TIPC header size */ + +#define MAX_MSG_SIZE (MAX_H_SIZE + TIPC_MAX_USER_MSG_SIZE) +#define TIPC_MEDIA_INFO_OFFSET 5 + +extern const int one_page_mtu; + +struct tipc_skb_cb { + union { + struct { + struct sk_buff *tail; + unsigned long nxt_retr; + unsigned long retr_stamp; + u32 bytes_read; + u32 orig_member; + u16 chain_imp; + u16 ackers; + u16 retr_cnt; + } __packed; +#ifdef CONFIG_TIPC_CRYPTO + struct { + struct tipc_crypto *rx; + struct tipc_aead *last; + u8 recurs; + } tx_clone_ctx __packed; +#endif + } __packed; + union { + struct { + u8 validated:1; +#ifdef CONFIG_TIPC_CRYPTO + u8 encrypted:1; + u8 decrypted:1; +#define SKB_PROBING 1 +#define SKB_GRACING 2 + u8 xmit_type:2; + u8 tx_clone_deferred:1; +#endif + }; + u8 flags; + }; + u8 reserved; +#ifdef CONFIG_TIPC_CRYPTO + void *crypto_ctx; +#endif +} __packed; + +#define TIPC_SKB_CB(__skb) ((struct tipc_skb_cb *)&((__skb)->cb[0])) + +struct tipc_msg { + __be32 hdr[15]; +}; + +/* struct tipc_gap_ack - TIPC Gap ACK block + * @ack: seqno of the last consecutive packet in link deferdq + * @gap: number of gap packets since the last ack + * + * E.g: + * link deferdq: 1 2 3 4 10 11 13 14 15 20 + * --> Gap ACK blocks: <4, 5>, <11, 1>, <15, 4>, <20, 0> + */ +struct tipc_gap_ack { + __be16 ack; + __be16 gap; +}; + +/* struct tipc_gap_ack_blks + * @len: actual length of the record + * @ugack_cnt: number of Gap ACK blocks for unicast 
(following the broadcast + * ones) + * @start_index: starting index for "valid" broadcast Gap ACK blocks + * @bgack_cnt: number of Gap ACK blocks for broadcast in the record + * @gacks: array of Gap ACK blocks + * + * 31 16 15 0 + * +-------------+-------------+-------------+-------------+ + * | bgack_cnt | ugack_cnt | len | + * +-------------+-------------+-------------+-------------+ - + * | gap | ack | | + * +-------------+-------------+-------------+-------------+ > bc gacks + * : : : | + * +-------------+-------------+-------------+-------------+ - + * | gap | ack | | + * +-------------+-------------+-------------+-------------+ > uc gacks + * : : : | + * +-------------+-------------+-------------+-------------+ - + */ +struct tipc_gap_ack_blks { + __be16 len; + union { + u8 ugack_cnt; + u8 start_index; + }; + u8 bgack_cnt; + struct tipc_gap_ack gacks[]; +}; + +#define MAX_GAP_ACK_BLKS 128 +#define MAX_GAP_ACK_BLKS_SZ (sizeof(struct tipc_gap_ack_blks) + \ + sizeof(struct tipc_gap_ack) * MAX_GAP_ACK_BLKS) + +static inline struct tipc_msg *buf_msg(struct sk_buff *skb) +{ + return (struct tipc_msg *)skb->data; +} + +static inline u32 msg_word(struct tipc_msg *m, u32 pos) +{ + return ntohl(m->hdr[pos]); +} + +static inline void msg_set_word(struct tipc_msg *m, u32 w, u32 val) +{ + m->hdr[w] = htonl(val); +} + +static inline u32 msg_bits(struct tipc_msg *m, u32 w, u32 pos, u32 mask) +{ + return (msg_word(m, w) >> pos) & mask; +} + +static inline void msg_set_bits(struct tipc_msg *m, u32 w, + u32 pos, u32 mask, u32 val) +{ + val = (val & mask) << pos; + mask = mask << pos; + m->hdr[w] &= ~htonl(mask); + m->hdr[w] |= htonl(val); +} + +/* + * Word 0 + */ +static inline u32 msg_version(struct tipc_msg *m) +{ + return msg_bits(m, 0, 29, 7); +} + +static inline void msg_set_version(struct tipc_msg *m) +{ + msg_set_bits(m, 0, 29, 7, TIPC_VERSION); +} + +static inline u32 msg_user(struct tipc_msg *m) +{ + return msg_bits(m, 0, 25, 0xf); +} + +static inline u32 msg_isdata(struct tipc_msg *m) +{ + return msg_user(m) <= TIPC_CRITICAL_IMPORTANCE; +} + +static inline void msg_set_user(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 0, 25, 0xf, n); +} + +static inline u32 msg_hdr_sz(struct tipc_msg *m) +{ + return msg_bits(m, 0, 21, 0xf) << 2; +} + +static inline void msg_set_hdr_sz(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 0, 21, 0xf, n>>2); +} + +static inline u32 msg_size(struct tipc_msg *m) +{ + return msg_bits(m, 0, 0, 0x1ffff); +} + +static inline u32 msg_blocks(struct tipc_msg *m) +{ + return (msg_size(m) / 1024) + 1; +} + +static inline u32 msg_data_sz(struct tipc_msg *m) +{ + return msg_size(m) - msg_hdr_sz(m); +} + +static inline int msg_non_seq(struct tipc_msg *m) +{ + return msg_bits(m, 0, 20, 1); +} + +static inline void msg_set_non_seq(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 0, 20, 1, n); +} + +static inline int msg_is_syn(struct tipc_msg *m) +{ + return msg_bits(m, 0, 17, 1); +} + +static inline void msg_set_syn(struct tipc_msg *m, u32 d) +{ + msg_set_bits(m, 0, 17, 1, d); +} + +static inline int msg_dest_droppable(struct tipc_msg *m) +{ + return msg_bits(m, 0, 19, 1); +} + +static inline void msg_set_dest_droppable(struct tipc_msg *m, u32 d) +{ + msg_set_bits(m, 0, 19, 1, d); +} + +static inline int msg_is_keepalive(struct tipc_msg *m) +{ + return msg_bits(m, 0, 19, 1); +} + +static inline void msg_set_is_keepalive(struct tipc_msg *m, u32 d) +{ + msg_set_bits(m, 0, 19, 1, d); +} + +static inline int msg_src_droppable(struct tipc_msg *m) +{ + return msg_bits(m, 0, 18, 1); +} 
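+
+/* Editorial sketch, not part of upstream TIPC: the accessors above all reduce
+ * to msg_bits()/msg_set_bits() operating on 32-bit big-endian header words.
+ * The hypothetical helper below (the "example_" name is an editorial
+ * assumption, not kernel API) illustrates the word-0 round trip: message user
+ * in bits 25-28, header size stored in 4-byte units in bits 21-24, and one of
+ * the single-bit flags sharing the same word.
+ */
+static inline bool example_msg_word0_roundtrip(void)
+{
+	struct tipc_msg m = {};
+
+	msg_set_version(&m);			/* bits 29-31 <- TIPC_VERSION */
+	msg_set_user(&m, MSG_BUNDLER);		/* bits 25-28 <- 6 */
+	msg_set_hdr_sz(&m, INT_H_SIZE);		/* bits 21-24 <- 40 / 4 */
+	msg_set_dest_droppable(&m, 1);		/* bit 19 */
+
+	return msg_version(&m) == TIPC_VERSION &&
+	       msg_user(&m) == MSG_BUNDLER &&
+	       msg_hdr_sz(&m) == INT_H_SIZE &&
+	       msg_dest_droppable(&m);
+}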
+ +static inline void msg_set_src_droppable(struct tipc_msg *m, u32 d) +{ + msg_set_bits(m, 0, 18, 1, d); +} + +static inline int msg_ack_required(struct tipc_msg *m) +{ + return msg_bits(m, 0, 18, 1); +} + +static inline void msg_set_ack_required(struct tipc_msg *m) +{ + msg_set_bits(m, 0, 18, 1, 1); +} + +static inline int msg_nagle_ack(struct tipc_msg *m) +{ + return msg_bits(m, 0, 18, 1); +} + +static inline void msg_set_nagle_ack(struct tipc_msg *m) +{ + msg_set_bits(m, 0, 18, 1, 1); +} + +static inline bool msg_is_rcast(struct tipc_msg *m) +{ + return msg_bits(m, 0, 18, 0x1); +} + +static inline void msg_set_is_rcast(struct tipc_msg *m, bool d) +{ + msg_set_bits(m, 0, 18, 0x1, d); +} + +static inline void msg_set_size(struct tipc_msg *m, u32 sz) +{ + m->hdr[0] = htonl((msg_word(m, 0) & ~0x1ffff) | sz); +} + +static inline unchar *msg_data(struct tipc_msg *m) +{ + return ((unchar *)m) + msg_hdr_sz(m); +} + +static inline struct tipc_msg *msg_inner_hdr(struct tipc_msg *m) +{ + return (struct tipc_msg *)msg_data(m); +} + +/* + * Word 1 + */ +static inline u32 msg_type(struct tipc_msg *m) +{ + return msg_bits(m, 1, 29, 0x7); +} + +static inline void msg_set_type(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 1, 29, 0x7, n); +} + +static inline int msg_in_group(struct tipc_msg *m) +{ + int mtyp = msg_type(m); + + return mtyp >= TIPC_GRP_MEMBER_EVT && mtyp <= TIPC_GRP_UCAST_MSG; +} + +static inline bool msg_is_grp_evt(struct tipc_msg *m) +{ + return msg_type(m) == TIPC_GRP_MEMBER_EVT; +} + +static inline u32 msg_named(struct tipc_msg *m) +{ + return msg_type(m) == TIPC_NAMED_MSG; +} + +static inline u32 msg_mcast(struct tipc_msg *m) +{ + int mtyp = msg_type(m); + + return ((mtyp == TIPC_MCAST_MSG) || (mtyp == TIPC_GRP_BCAST_MSG) || + (mtyp == TIPC_GRP_MCAST_MSG)); +} + +static inline u32 msg_connected(struct tipc_msg *m) +{ + return msg_type(m) == TIPC_CONN_MSG; +} + +static inline u32 msg_direct(struct tipc_msg *m) +{ + return msg_type(m) == TIPC_DIRECT_MSG; +} + +static inline u32 msg_errcode(struct tipc_msg *m) +{ + return msg_bits(m, 1, 25, 0xf); +} + +static inline void msg_set_errcode(struct tipc_msg *m, u32 err) +{ + msg_set_bits(m, 1, 25, 0xf, err); +} + +static inline void msg_set_bulk(struct tipc_msg *m) +{ + msg_set_bits(m, 1, 28, 0x1, 1); +} + +static inline u32 msg_is_bulk(struct tipc_msg *m) +{ + return msg_bits(m, 1, 28, 0x1); +} + +static inline void msg_set_last_bulk(struct tipc_msg *m) +{ + msg_set_bits(m, 1, 27, 0x1, 1); +} + +static inline u32 msg_is_last_bulk(struct tipc_msg *m) +{ + return msg_bits(m, 1, 27, 0x1); +} + +static inline void msg_set_non_legacy(struct tipc_msg *m) +{ + msg_set_bits(m, 1, 26, 0x1, 1); +} + +static inline u32 msg_is_legacy(struct tipc_msg *m) +{ + return !msg_bits(m, 1, 26, 0x1); +} + +static inline u32 msg_reroute_cnt(struct tipc_msg *m) +{ + return msg_bits(m, 1, 21, 0xf); +} + +static inline void msg_incr_reroute_cnt(struct tipc_msg *m) +{ + msg_set_bits(m, 1, 21, 0xf, msg_reroute_cnt(m) + 1); +} + +static inline u32 msg_lookup_scope(struct tipc_msg *m) +{ + return msg_bits(m, 1, 19, 0x3); +} + +static inline void msg_set_lookup_scope(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 1, 19, 0x3, n); +} + +static inline u16 msg_bcast_ack(struct tipc_msg *m) +{ + return msg_bits(m, 1, 0, 0xffff); +} + +static inline void msg_set_bcast_ack(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 1, 0, 0xffff, n); +} + +/* Note: reusing bits in word 1 for ACTIVATE_MSG only, to re-synch + * link peer session number + */ +static inline bool 
msg_dest_session_valid(struct tipc_msg *m) +{ + return msg_bits(m, 1, 16, 0x1); +} + +static inline void msg_set_dest_session_valid(struct tipc_msg *m, bool valid) +{ + msg_set_bits(m, 1, 16, 0x1, valid); +} + +static inline u16 msg_dest_session(struct tipc_msg *m) +{ + return msg_bits(m, 1, 0, 0xffff); +} + +static inline void msg_set_dest_session(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 1, 0, 0xffff, n); +} + +/* + * Word 2 + */ +static inline u16 msg_ack(struct tipc_msg *m) +{ + return msg_bits(m, 2, 16, 0xffff); +} + +static inline void msg_set_ack(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 2, 16, 0xffff, n); +} + +static inline u16 msg_seqno(struct tipc_msg *m) +{ + return msg_bits(m, 2, 0, 0xffff); +} + +static inline void msg_set_seqno(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 2, 0, 0xffff, n); +} + +/* + * Words 3-10 + */ +static inline u32 msg_importance(struct tipc_msg *m) +{ + int usr = msg_user(m); + + if (likely((usr <= TIPC_CRITICAL_IMPORTANCE) && !msg_errcode(m))) + return usr; + if ((usr == MSG_FRAGMENTER) || (usr == MSG_BUNDLER)) + return msg_bits(m, 9, 0, 0x7); + return TIPC_SYSTEM_IMPORTANCE; +} + +static inline void msg_set_importance(struct tipc_msg *m, u32 i) +{ + int usr = msg_user(m); + + if (likely((usr == MSG_FRAGMENTER) || (usr == MSG_BUNDLER))) + msg_set_bits(m, 9, 0, 0x7, i); + else if (i < TIPC_SYSTEM_IMPORTANCE) + msg_set_user(m, i); + else + pr_warn("Trying to set illegal importance in message\n"); +} + +static inline u32 msg_prevnode(struct tipc_msg *m) +{ + return msg_word(m, 3); +} + +static inline void msg_set_prevnode(struct tipc_msg *m, u32 a) +{ + msg_set_word(m, 3, a); +} + +static inline u32 msg_origport(struct tipc_msg *m) +{ + if (msg_user(m) == MSG_FRAGMENTER) + m = msg_inner_hdr(m); + return msg_word(m, 4); +} + +static inline void msg_set_origport(struct tipc_msg *m, u32 p) +{ + msg_set_word(m, 4, p); +} + +static inline u16 msg_named_seqno(struct tipc_msg *m) +{ + return msg_bits(m, 4, 0, 0xffff); +} + +static inline void msg_set_named_seqno(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 4, 0, 0xffff, n); +} + +static inline u32 msg_destport(struct tipc_msg *m) +{ + return msg_word(m, 5); +} + +static inline void msg_set_destport(struct tipc_msg *m, u32 p) +{ + msg_set_word(m, 5, p); +} + +static inline u32 msg_mc_netid(struct tipc_msg *m) +{ + return msg_word(m, 5); +} + +static inline void msg_set_mc_netid(struct tipc_msg *m, u32 p) +{ + msg_set_word(m, 5, p); +} + +static inline int msg_short(struct tipc_msg *m) +{ + return msg_hdr_sz(m) == SHORT_H_SIZE; +} + +static inline u32 msg_orignode(struct tipc_msg *m) +{ + if (likely(msg_short(m))) + return msg_prevnode(m); + return msg_word(m, 6); +} + +static inline void msg_set_orignode(struct tipc_msg *m, u32 a) +{ + msg_set_word(m, 6, a); +} + +static inline u32 msg_destnode(struct tipc_msg *m) +{ + return msg_word(m, 7); +} + +static inline void msg_set_destnode(struct tipc_msg *m, u32 a) +{ + msg_set_word(m, 7, a); +} + +static inline u32 msg_nametype(struct tipc_msg *m) +{ + return msg_word(m, 8); +} + +static inline void msg_set_nametype(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 8, n); +} + +static inline u32 msg_nameinst(struct tipc_msg *m) +{ + return msg_word(m, 9); +} + +static inline u32 msg_namelower(struct tipc_msg *m) +{ + return msg_nameinst(m); +} + +static inline void msg_set_namelower(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 9, n); +} + +static inline void msg_set_nameinst(struct tipc_msg *m, u32 n) +{ + msg_set_namelower(m, n); +} + +static 
inline u32 msg_nameupper(struct tipc_msg *m) +{ + return msg_word(m, 10); +} + +static inline void msg_set_nameupper(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 10, n); +} + +/* + * Constants and routines used to read and write TIPC internal message headers + */ + +/* + * Connection management protocol message types + */ +#define CONN_PROBE 0 +#define CONN_PROBE_REPLY 1 +#define CONN_ACK 2 + +/* + * Name distributor message types + */ +#define PUBLICATION 0 +#define WITHDRAWAL 1 + +/* + * Segmentation message types + */ +#define FIRST_FRAGMENT 0 +#define FRAGMENT 1 +#define LAST_FRAGMENT 2 + +/* + * Link management protocol message types + */ +#define STATE_MSG 0 +#define RESET_MSG 1 +#define ACTIVATE_MSG 2 + +/* + * Changeover tunnel message types + */ +#define SYNCH_MSG 0 +#define FAILOVER_MSG 1 + +/* + * Config protocol message types + */ +#define DSC_REQ_MSG 0 +#define DSC_RESP_MSG 1 +#define DSC_TRIAL_MSG 2 +#define DSC_TRIAL_FAIL_MSG 3 + +/* + * Group protocol message types + */ +#define GRP_JOIN_MSG 0 +#define GRP_LEAVE_MSG 1 +#define GRP_ADV_MSG 2 +#define GRP_ACK_MSG 3 +#define GRP_RECLAIM_MSG 4 +#define GRP_REMIT_MSG 5 + +/* Crypto message types */ +#define KEY_DISTR_MSG 0 + +/* + * Word 1 + */ +static inline u32 msg_seq_gap(struct tipc_msg *m) +{ + return msg_bits(m, 1, 16, 0x1fff); +} + +static inline void msg_set_seq_gap(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 1, 16, 0x1fff, n); +} + +static inline u32 msg_node_sig(struct tipc_msg *m) +{ + return msg_bits(m, 1, 0, 0xffff); +} + +static inline void msg_set_node_sig(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 1, 0, 0xffff, n); +} + +static inline u32 msg_node_capabilities(struct tipc_msg *m) +{ + return msg_bits(m, 1, 15, 0x1fff); +} + +static inline void msg_set_node_capabilities(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 1, 15, 0x1fff, n); +} + +/* + * Word 2 + */ +static inline u32 msg_dest_domain(struct tipc_msg *m) +{ + return msg_word(m, 2); +} + +static inline void msg_set_dest_domain(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 2, n); +} + +static inline void msg_set_bcgap_after(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 2, 16, 0xffff, n); +} + +static inline u32 msg_bcgap_to(struct tipc_msg *m) +{ + return msg_bits(m, 2, 0, 0xffff); +} + +static inline void msg_set_bcgap_to(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 2, 0, 0xffff, n); +} + +/* + * Word 4 + */ +static inline u32 msg_last_bcast(struct tipc_msg *m) +{ + return msg_bits(m, 4, 16, 0xffff); +} + +static inline u32 msg_bc_snd_nxt(struct tipc_msg *m) +{ + return msg_last_bcast(m) + 1; +} + +static inline void msg_set_last_bcast(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 4, 16, 0xffff, n); +} + +static inline u32 msg_nof_fragms(struct tipc_msg *m) +{ + return msg_bits(m, 4, 0, 0xffff); +} + +static inline void msg_set_nof_fragms(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 4, 0, 0xffff, n); +} + +static inline u32 msg_fragm_no(struct tipc_msg *m) +{ + return msg_bits(m, 4, 16, 0xffff); +} + +static inline void msg_set_fragm_no(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 4, 16, 0xffff, n); +} + +static inline u16 msg_next_sent(struct tipc_msg *m) +{ + return msg_bits(m, 4, 0, 0xffff); +} + +static inline void msg_set_next_sent(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 4, 0, 0xffff, n); +} + +static inline u32 msg_bc_netid(struct tipc_msg *m) +{ + return msg_word(m, 4); +} + +static inline void msg_set_bc_netid(struct tipc_msg *m, u32 id) +{ + msg_set_word(m, 4, id); +} + +static inline u32 
msg_link_selector(struct tipc_msg *m) +{ + if (msg_user(m) == MSG_FRAGMENTER) + m = (void *)msg_data(m); + return msg_bits(m, 4, 0, 1); +} + +/* + * Word 5 + */ +static inline u16 msg_session(struct tipc_msg *m) +{ + return msg_bits(m, 5, 16, 0xffff); +} + +static inline void msg_set_session(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 5, 16, 0xffff, n); +} + +static inline u32 msg_probe(struct tipc_msg *m) +{ + return msg_bits(m, 5, 0, 1); +} + +static inline void msg_set_probe(struct tipc_msg *m, u32 val) +{ + msg_set_bits(m, 5, 0, 1, val); +} + +static inline char msg_net_plane(struct tipc_msg *m) +{ + return msg_bits(m, 5, 1, 7) + 'A'; +} + +static inline void msg_set_net_plane(struct tipc_msg *m, char n) +{ + msg_set_bits(m, 5, 1, 7, (n - 'A')); +} + +static inline u32 msg_linkprio(struct tipc_msg *m) +{ + return msg_bits(m, 5, 4, 0x1f); +} + +static inline void msg_set_linkprio(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 5, 4, 0x1f, n); +} + +static inline u32 msg_bearer_id(struct tipc_msg *m) +{ + return msg_bits(m, 5, 9, 0x7); +} + +static inline void msg_set_bearer_id(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 5, 9, 0x7, n); +} + +static inline u32 msg_redundant_link(struct tipc_msg *m) +{ + return msg_bits(m, 5, 12, 0x1); +} + +static inline void msg_set_redundant_link(struct tipc_msg *m, u32 r) +{ + msg_set_bits(m, 5, 12, 0x1, r); +} + +static inline u32 msg_peer_stopping(struct tipc_msg *m) +{ + return msg_bits(m, 5, 13, 0x1); +} + +static inline void msg_set_peer_stopping(struct tipc_msg *m, u32 s) +{ + msg_set_bits(m, 5, 13, 0x1, s); +} + +static inline bool msg_bc_ack_invalid(struct tipc_msg *m) +{ + switch (msg_user(m)) { + case BCAST_PROTOCOL: + case NAME_DISTRIBUTOR: + case LINK_PROTOCOL: + return msg_bits(m, 5, 14, 0x1); + default: + return false; + } +} + +static inline void msg_set_bc_ack_invalid(struct tipc_msg *m, bool invalid) +{ + msg_set_bits(m, 5, 14, 0x1, invalid); +} + +static inline char *msg_media_addr(struct tipc_msg *m) +{ + return (char *)&m->hdr[TIPC_MEDIA_INFO_OFFSET]; +} + +static inline u32 msg_bc_gap(struct tipc_msg *m) +{ + return msg_bits(m, 8, 0, 0x3ff); +} + +static inline void msg_set_bc_gap(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 8, 0, 0x3ff, n); +} + +/* + * Word 9 + */ +static inline u16 msg_msgcnt(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_msgcnt(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +static inline u16 msg_syncpt(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_syncpt(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +static inline u32 msg_conn_ack(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_conn_ack(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +static inline u16 msg_adv_win(struct tipc_msg *m) +{ + return msg_bits(m, 9, 0, 0xffff); +} + +static inline void msg_set_adv_win(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 0, 0xffff, n); +} + +static inline u32 msg_max_pkt(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff) * 4; +} + +static inline void msg_set_max_pkt(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 9, 16, 0xffff, (n / 4)); +} + +static inline u32 msg_link_tolerance(struct tipc_msg *m) +{ + return msg_bits(m, 9, 0, 0xffff); +} + +static inline void msg_set_link_tolerance(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 9, 0, 0xffff, n); +} + +static inline u16 
msg_grp_bc_syncpt(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_grp_bc_syncpt(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +static inline u16 msg_grp_bc_acked(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_grp_bc_acked(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +static inline u16 msg_grp_remitted(struct tipc_msg *m) +{ + return msg_bits(m, 9, 16, 0xffff); +} + +static inline void msg_set_grp_remitted(struct tipc_msg *m, u16 n) +{ + msg_set_bits(m, 9, 16, 0xffff, n); +} + +/* Word 10 + */ +static inline u16 msg_grp_evt(struct tipc_msg *m) +{ + return msg_bits(m, 10, 0, 0x3); +} + +static inline void msg_set_grp_evt(struct tipc_msg *m, int n) +{ + msg_set_bits(m, 10, 0, 0x3, n); +} + +static inline u16 msg_grp_bc_ack_req(struct tipc_msg *m) +{ + return msg_bits(m, 10, 0, 0x1); +} + +static inline void msg_set_grp_bc_ack_req(struct tipc_msg *m, bool n) +{ + msg_set_bits(m, 10, 0, 0x1, n); +} + +static inline u16 msg_grp_bc_seqno(struct tipc_msg *m) +{ + return msg_bits(m, 10, 16, 0xffff); +} + +static inline void msg_set_grp_bc_seqno(struct tipc_msg *m, u32 n) +{ + msg_set_bits(m, 10, 16, 0xffff, n); +} + +static inline bool msg_peer_link_is_up(struct tipc_msg *m) +{ + if (likely(msg_user(m) != LINK_PROTOCOL)) + return true; + if (msg_type(m) == STATE_MSG) + return true; + return false; +} + +static inline bool msg_peer_node_is_up(struct tipc_msg *m) +{ + if (msg_peer_link_is_up(m)) + return true; + return msg_redundant_link(m); +} + +static inline bool msg_is_reset(struct tipc_msg *hdr) +{ + return (msg_user(hdr) == LINK_PROTOCOL) && (msg_type(hdr) == RESET_MSG); +} + +/* Word 13 + */ +static inline void msg_set_peer_net_hash(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 13, n); +} + +static inline u32 msg_peer_net_hash(struct tipc_msg *m) +{ + return msg_word(m, 13); +} + +/* Word 14 + */ +static inline u32 msg_sugg_node_addr(struct tipc_msg *m) +{ + return msg_word(m, 14); +} + +static inline void msg_set_sugg_node_addr(struct tipc_msg *m, u32 n) +{ + msg_set_word(m, 14, n); +} + +static inline void msg_set_node_id(struct tipc_msg *hdr, u8 *id) +{ + memcpy(msg_data(hdr), id, 16); +} + +static inline u8 *msg_node_id(struct tipc_msg *hdr) +{ + return (u8 *)msg_data(hdr); +} + +struct sk_buff *tipc_buf_acquire(u32 size, gfp_t gfp); +bool tipc_msg_validate(struct sk_buff **_skb); +bool tipc_msg_reverse(u32 own_addr, struct sk_buff **skb, int err); +void tipc_skb_reject(struct net *net, int err, struct sk_buff *skb, + struct sk_buff_head *xmitq); +void tipc_msg_init(u32 own_addr, struct tipc_msg *m, u32 user, u32 type, + u32 hsize, u32 destnode); +struct sk_buff *tipc_msg_create(uint user, uint type, uint hdr_sz, + uint data_sz, u32 dnode, u32 onode, + u32 dport, u32 oport, int errcode); +int tipc_buf_append(struct sk_buff **headbuf, struct sk_buff **buf); +bool tipc_msg_try_bundle(struct sk_buff *tskb, struct sk_buff **skb, u32 mss, + u32 dnode, bool *new_bundle); +bool tipc_msg_extract(struct sk_buff *skb, struct sk_buff **iskb, int *pos); +int tipc_msg_fragment(struct sk_buff *skb, const struct tipc_msg *hdr, + int pktmax, struct sk_buff_head *frags); +int tipc_msg_build(struct tipc_msg *mhdr, struct msghdr *m, + int offset, int dsz, int mtu, struct sk_buff_head *list); +int tipc_msg_append(struct tipc_msg *hdr, struct msghdr *m, int dlen, + int mss, struct sk_buff_head *txq); +bool tipc_msg_lookup_dest(struct net *net, struct sk_buff *skb, 
int *err); +bool tipc_msg_assemble(struct sk_buff_head *list); +bool tipc_msg_reassemble(struct sk_buff_head *list, struct sk_buff_head *rcvq); +bool tipc_msg_pskb_copy(u32 dst, struct sk_buff_head *msg, + struct sk_buff_head *cpy); +bool __tipc_skb_queue_sorted(struct sk_buff_head *list, u16 seqno, + struct sk_buff *skb); +bool tipc_msg_skb_clone(struct sk_buff_head *msg, struct sk_buff_head *cpy); + +static inline u16 buf_seqno(struct sk_buff *skb) +{ + return msg_seqno(buf_msg(skb)); +} + +static inline int buf_roundup_len(struct sk_buff *skb) +{ + return (skb->len / 1024 + 1) * 1024; +} + +/* tipc_skb_peek(): peek and reserve first buffer in list + * @list: list to be peeked in + * Returns pointer to first buffer in list, if any + */ +static inline struct sk_buff *tipc_skb_peek(struct sk_buff_head *list, + spinlock_t *lock) +{ + struct sk_buff *skb; + + spin_lock_bh(lock); + skb = skb_peek(list); + if (skb) + skb_get(skb); + spin_unlock_bh(lock); + return skb; +} + +/* tipc_skb_peek_port(): find a destination port, ignoring all destinations + * up to and including 'filter'. + * Note: ignoring previously tried destinations minimizes the risk of + * contention on the socket lock + * @list: list to be peeked in + * @filter: last destination to be ignored from search + * Returns a destination port number, of applicable. + */ +static inline u32 tipc_skb_peek_port(struct sk_buff_head *list, u32 filter) +{ + struct sk_buff *skb; + u32 dport = 0; + bool ignore = true; + + spin_lock_bh(&list->lock); + skb_queue_walk(list, skb) { + dport = msg_destport(buf_msg(skb)); + if (!filter || skb_queue_is_last(list, skb)) + break; + if (dport == filter) + ignore = false; + else if (!ignore) + break; + } + spin_unlock_bh(&list->lock); + return dport; +} + +/* tipc_skb_dequeue(): unlink first buffer with dest 'dport' from list + * @list: list to be unlinked from + * @dport: selection criteria for buffer to unlink + */ +static inline struct sk_buff *tipc_skb_dequeue(struct sk_buff_head *list, + u32 dport) +{ + struct sk_buff *_skb, *tmp, *skb = NULL; + + spin_lock_bh(&list->lock); + skb_queue_walk_safe(list, _skb, tmp) { + if (msg_destport(buf_msg(_skb)) == dport) { + __skb_unlink(_skb, list); + skb = _skb; + break; + } + } + spin_unlock_bh(&list->lock); + return skb; +} + +/* tipc_skb_queue_splice_tail - append an skb list to lock protected list + * @list: the new list to append. Not lock protected + * @head: target list. Lock protected. + */ +static inline void tipc_skb_queue_splice_tail(struct sk_buff_head *list, + struct sk_buff_head *head) +{ + spin_lock_bh(&head->lock); + skb_queue_splice_tail(list, head); + spin_unlock_bh(&head->lock); +} + +/* tipc_skb_queue_splice_tail_init - merge two lock protected skb lists + * @list: the new list to add. Lock protected. Will be reinitialized + * @head: target list. Lock protected. 
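+ *
+ * (Editorial note: the implementation below first splices @list into an
+ * on-stack temporary queue and only then appends that to @head, so the two
+ * queue locks are never held at the same time.)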
+ */ +static inline void tipc_skb_queue_splice_tail_init(struct sk_buff_head *list, + struct sk_buff_head *head) +{ + struct sk_buff_head tmp; + + __skb_queue_head_init(&tmp); + + spin_lock_bh(&list->lock); + skb_queue_splice_tail_init(list, &tmp); + spin_unlock_bh(&list->lock); + tipc_skb_queue_splice_tail(&tmp, head); +} + +/* __tipc_skb_dequeue() - dequeue the head skb according to expected seqno + * @list: list to be dequeued from + * @seqno: seqno of the expected msg + * + * returns skb dequeued from the list if its seqno is less than or equal to + * the expected one, otherwise the skb is still hold + * + * Note: must be used with appropriate locks held only + */ +static inline struct sk_buff *__tipc_skb_dequeue(struct sk_buff_head *list, + u16 seqno) +{ + struct sk_buff *skb = skb_peek(list); + + if (skb && less_eq(buf_seqno(skb), seqno)) { + __skb_unlink(skb, list); + return skb; + } + return NULL; +} + +#endif diff --git a/net/tipc/name_distr.c b/net/tipc/name_distr.c new file mode 100644 index 000000000..190b49c5c --- /dev/null +++ b/net/tipc/name_distr.c @@ -0,0 +1,411 @@ +/* + * net/tipc/name_distr.c: TIPC name distribution code + * + * Copyright (c) 2000-2006, 2014-2019, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include "core.h" +#include "link.h" +#include "name_distr.h" + +int sysctl_tipc_named_timeout __read_mostly = 2000; + +/** + * publ_to_item - add publication info to a publication message + * @p: publication info + * @i: location of item in the message + */ +static void publ_to_item(struct distr_item *i, struct publication *p) +{ + i->type = htonl(p->sr.type); + i->lower = htonl(p->sr.lower); + i->upper = htonl(p->sr.upper); + i->port = htonl(p->sk.ref); + i->key = htonl(p->key); +} + +/** + * named_prepare_buf - allocate & initialize a publication message + * @net: the associated network namespace + * @type: message type + * @size: payload size + * @dest: destination node + * + * The buffer returned is of size INT_H_SIZE + payload size + */ +static struct sk_buff *named_prepare_buf(struct net *net, u32 type, u32 size, + u32 dest) +{ + struct sk_buff *buf = tipc_buf_acquire(INT_H_SIZE + size, GFP_ATOMIC); + u32 self = tipc_own_addr(net); + struct tipc_msg *msg; + + if (buf != NULL) { + msg = buf_msg(buf); + tipc_msg_init(self, msg, NAME_DISTRIBUTOR, + type, INT_H_SIZE, dest); + msg_set_size(msg, INT_H_SIZE + size); + } + return buf; +} + +/** + * tipc_named_publish - tell other nodes about a new publication by this node + * @net: the associated network namespace + * @p: the new publication + */ +struct sk_buff *tipc_named_publish(struct net *net, struct publication *p) +{ + struct name_table *nt = tipc_name_table(net); + struct distr_item *item; + struct sk_buff *skb; + + if (p->scope == TIPC_NODE_SCOPE) { + list_add_tail_rcu(&p->binding_node, &nt->node_scope); + return NULL; + } + write_lock_bh(&nt->cluster_scope_lock); + list_add_tail(&p->binding_node, &nt->cluster_scope); + write_unlock_bh(&nt->cluster_scope_lock); + skb = named_prepare_buf(net, PUBLICATION, ITEM_SIZE, 0); + if (!skb) { + pr_warn("Publication distribution failure\n"); + return NULL; + } + msg_set_named_seqno(buf_msg(skb), nt->snd_nxt++); + msg_set_non_legacy(buf_msg(skb)); + item = (struct distr_item *)msg_data(buf_msg(skb)); + publ_to_item(item, p); + return skb; +} + +/** + * tipc_named_withdraw - tell other nodes about a withdrawn publication by this node + * @net: the associated network namespace + * @p: the withdrawn publication + */ +struct sk_buff *tipc_named_withdraw(struct net *net, struct publication *p) +{ + struct name_table *nt = tipc_name_table(net); + struct distr_item *item; + struct sk_buff *skb; + + write_lock_bh(&nt->cluster_scope_lock); + list_del(&p->binding_node); + write_unlock_bh(&nt->cluster_scope_lock); + if (p->scope == TIPC_NODE_SCOPE) + return NULL; + + skb = named_prepare_buf(net, WITHDRAWAL, ITEM_SIZE, 0); + if (!skb) { + pr_warn("Withdrawal distribution failure\n"); + return NULL; + } + msg_set_named_seqno(buf_msg(skb), nt->snd_nxt++); + msg_set_non_legacy(buf_msg(skb)); + item = (struct distr_item *)msg_data(buf_msg(skb)); + publ_to_item(item, p); + return skb; +} + +/** + * named_distribute - prepare name info for bulk distribution to another node + * @net: the associated network namespace + * @list: list of messages (buffers) to be returned from this function + * @dnode: node to be updated + * @pls: linked list of publication items to be packed into buffer chain + * @seqno: sequence number for this message + */ +static void named_distribute(struct net *net, struct sk_buff_head *list, + u32 dnode, struct list_head *pls, u16 seqno) +{ + struct publication *publ; + struct sk_buff *skb = NULL; + struct distr_item *item = NULL; + u32 msg_dsz = ((tipc_node_get_mtu(net, dnode, 0, 
false) - INT_H_SIZE) / + ITEM_SIZE) * ITEM_SIZE; + u32 msg_rem = msg_dsz; + struct tipc_msg *hdr; + + list_for_each_entry(publ, pls, binding_node) { + /* Prepare next buffer: */ + if (!skb) { + skb = named_prepare_buf(net, PUBLICATION, msg_rem, + dnode); + if (!skb) { + pr_warn("Bulk publication failure\n"); + return; + } + hdr = buf_msg(skb); + msg_set_bc_ack_invalid(hdr, true); + msg_set_bulk(hdr); + msg_set_non_legacy(hdr); + item = (struct distr_item *)msg_data(hdr); + } + + /* Pack publication into message: */ + publ_to_item(item, publ); + item++; + msg_rem -= ITEM_SIZE; + + /* Append full buffer to list: */ + if (!msg_rem) { + __skb_queue_tail(list, skb); + skb = NULL; + msg_rem = msg_dsz; + } + } + if (skb) { + hdr = buf_msg(skb); + msg_set_size(hdr, INT_H_SIZE + (msg_dsz - msg_rem)); + skb_trim(skb, INT_H_SIZE + (msg_dsz - msg_rem)); + __skb_queue_tail(list, skb); + } + hdr = buf_msg(skb_peek_tail(list)); + msg_set_last_bulk(hdr); + msg_set_named_seqno(hdr, seqno); +} + +/** + * tipc_named_node_up - tell specified node about all publications by this node + * @net: the associated network namespace + * @dnode: destination node + * @capabilities: peer node's capabilities + */ +void tipc_named_node_up(struct net *net, u32 dnode, u16 capabilities) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + struct sk_buff_head head; + u16 seqno; + + __skb_queue_head_init(&head); + spin_lock_bh(&tn->nametbl_lock); + if (!(capabilities & TIPC_NAMED_BCAST)) + nt->rc_dests++; + seqno = nt->snd_nxt; + spin_unlock_bh(&tn->nametbl_lock); + + read_lock_bh(&nt->cluster_scope_lock); + named_distribute(net, &head, dnode, &nt->cluster_scope, seqno); + tipc_node_xmit(net, &head, dnode, 0); + read_unlock_bh(&nt->cluster_scope_lock); +} + +/** + * tipc_publ_purge - remove publication associated with a failed node + * @net: the associated network namespace + * @p: the publication to remove + * @addr: failed node's address + * + * Invoked for each publication issued by a newly failed node. + * Removes publication structure from name table & deletes it. + */ +static void tipc_publ_purge(struct net *net, struct publication *p, u32 addr) +{ + struct tipc_net *tn = tipc_net(net); + struct publication *_p; + struct tipc_uaddr ua; + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, p->scope, p->sr.type, + p->sr.lower, p->sr.upper); + spin_lock_bh(&tn->nametbl_lock); + _p = tipc_nametbl_remove_publ(net, &ua, &p->sk, p->key); + if (_p) + tipc_node_unsubscribe(net, &_p->binding_node, addr); + spin_unlock_bh(&tn->nametbl_lock); + if (_p) + kfree_rcu(_p, rcu); +} + +void tipc_publ_notify(struct net *net, struct list_head *nsub_list, + u32 addr, u16 capabilities) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + + struct publication *publ, *tmp; + + list_for_each_entry_safe(publ, tmp, nsub_list, binding_node) + tipc_publ_purge(net, publ, addr); + spin_lock_bh(&tn->nametbl_lock); + if (!(capabilities & TIPC_NAMED_BCAST)) + nt->rc_dests--; + spin_unlock_bh(&tn->nametbl_lock); +} + +/** + * tipc_update_nametbl - try to process a nametable update and notify + * subscribers + * @net: the associated network namespace + * @i: location of item in the message + * @node: node address + * @dtype: name distributor message type + * + * tipc_nametbl_lock must be held. + * Return: the publication item if successful, otherwise NULL. 
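+ *
+ * (Editorial note: as implemented below, the function returns true when the
+ * publication could be inserted or withdrawn, false otherwise.)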
+ */ +static bool tipc_update_nametbl(struct net *net, struct distr_item *i, + u32 node, u32 dtype) +{ + struct publication *p = NULL; + struct tipc_socket_addr sk; + struct tipc_uaddr ua; + u32 key = ntohl(i->key); + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_CLUSTER_SCOPE, + ntohl(i->type), ntohl(i->lower), ntohl(i->upper)); + sk.ref = ntohl(i->port); + sk.node = node; + + if (dtype == PUBLICATION) { + p = tipc_nametbl_insert_publ(net, &ua, &sk, key); + if (p) { + tipc_node_subscribe(net, &p->binding_node, node); + return true; + } + } else if (dtype == WITHDRAWAL) { + p = tipc_nametbl_remove_publ(net, &ua, &sk, key); + if (p) { + tipc_node_unsubscribe(net, &p->binding_node, node); + kfree_rcu(p, rcu); + return true; + } + pr_warn_ratelimited("Failed to remove binding %u,%u from %u\n", + ua.sr.type, ua.sr.lower, node); + } else { + pr_warn_ratelimited("Unknown name table message received\n"); + } + return false; +} + +static struct sk_buff *tipc_named_dequeue(struct sk_buff_head *namedq, + u16 *rcv_nxt, bool *open) +{ + struct sk_buff *skb, *tmp; + struct tipc_msg *hdr; + u16 seqno; + + spin_lock_bh(&namedq->lock); + skb_queue_walk_safe(namedq, skb, tmp) { + if (unlikely(skb_linearize(skb))) { + __skb_unlink(skb, namedq); + kfree_skb(skb); + continue; + } + hdr = buf_msg(skb); + seqno = msg_named_seqno(hdr); + if (msg_is_last_bulk(hdr)) { + *rcv_nxt = seqno; + *open = true; + } + + if (msg_is_bulk(hdr) || msg_is_legacy(hdr)) { + __skb_unlink(skb, namedq); + spin_unlock_bh(&namedq->lock); + return skb; + } + + if (*open && (*rcv_nxt == seqno)) { + (*rcv_nxt)++; + __skb_unlink(skb, namedq); + spin_unlock_bh(&namedq->lock); + return skb; + } + + if (less(seqno, *rcv_nxt)) { + __skb_unlink(skb, namedq); + kfree_skb(skb); + continue; + } + } + spin_unlock_bh(&namedq->lock); + return NULL; +} + +/** + * tipc_named_rcv - process name table update messages sent by another node + * @net: the associated network namespace + * @namedq: queue to receive from + * @rcv_nxt: store last received seqno here + * @open: last bulk msg was received (FIXME) + */ +void tipc_named_rcv(struct net *net, struct sk_buff_head *namedq, + u16 *rcv_nxt, bool *open) +{ + struct tipc_net *tn = tipc_net(net); + struct distr_item *item; + struct tipc_msg *hdr; + struct sk_buff *skb; + u32 count, node; + + spin_lock_bh(&tn->nametbl_lock); + while ((skb = tipc_named_dequeue(namedq, rcv_nxt, open))) { + hdr = buf_msg(skb); + node = msg_orignode(hdr); + item = (struct distr_item *)msg_data(hdr); + count = msg_data_sz(hdr) / ITEM_SIZE; + while (count--) { + tipc_update_nametbl(net, item, node, msg_type(hdr)); + item++; + } + kfree_skb(skb); + } + spin_unlock_bh(&tn->nametbl_lock); +} + +/** + * tipc_named_reinit - re-initialize local publications + * @net: the associated network namespace + * + * This routine is called whenever TIPC networking is enabled. + * All name table entries published by this node are updated to reflect + * the node's new network address. 
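+ *
+ * (Editorial note: both the node-scope and cluster-scope binding lists are
+ * walked under the name table lock and each publication's socket address is
+ * rebound to the node's own address.)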
+ */ +void tipc_named_reinit(struct net *net) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + struct publication *p; + u32 self = tipc_own_addr(net); + + spin_lock_bh(&tn->nametbl_lock); + + list_for_each_entry_rcu(p, &nt->node_scope, binding_node) + p->sk.node = self; + list_for_each_entry_rcu(p, &nt->cluster_scope, binding_node) + p->sk.node = self; + nt->rc_dests = 0; + spin_unlock_bh(&tn->nametbl_lock); +} diff --git a/net/tipc/name_distr.h b/net/tipc/name_distr.h new file mode 100644 index 000000000..e231e6964 --- /dev/null +++ b/net/tipc/name_distr.h @@ -0,0 +1,80 @@ +/* + * net/tipc/name_distr.h: Include file for TIPC name distribution code + * + * Copyright (c) 2000-2006, Ericsson AB + * Copyright (c) 2005, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_NAME_DISTR_H +#define _TIPC_NAME_DISTR_H + +#include "name_table.h" + +#define ITEM_SIZE sizeof(struct distr_item) + +/** + * struct distr_item - publication info distributed to other nodes + * @type: name sequence type + * @lower: name sequence lower bound + * @upper: name sequence upper bound + * @port: publishing port reference + * @key: publication key + * + * ===> All fields are stored in network byte order. <=== + * + * First 3 fields identify (name or) name sequence being published. + * Reference field uniquely identifies port that published name sequence. + * Key field uniquely identifies publication, in the event a port has + * multiple publications of the same name sequence. + * + * Note: There is no field that identifies the publishing node because it is + * the same for all items contained within a publication message. 
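+ *
+ * (Editorial note: the item is five 32-bit words, and ITEM_SIZE above is the
+ * unit used by named_distribute() when packing items into bulk PUBLICATION
+ * messages up to the link MTU.)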
+ */ +struct distr_item { + __be32 type; + __be32 lower; + __be32 upper; + __be32 port; + __be32 key; +}; + +void tipc_named_bcast(struct net *net, struct sk_buff *skb); +struct sk_buff *tipc_named_publish(struct net *net, struct publication *publ); +struct sk_buff *tipc_named_withdraw(struct net *net, struct publication *publ); +void tipc_named_node_up(struct net *net, u32 dnode, u16 capabilities); +void tipc_named_rcv(struct net *net, struct sk_buff_head *namedq, + u16 *rcv_nxt, bool *open); +void tipc_named_reinit(struct net *net); +void tipc_publ_notify(struct net *net, struct list_head *nsub_list, + u32 addr, u16 capabilities); + +#endif diff --git a/net/tipc/name_table.c b/net/tipc/name_table.c new file mode 100644 index 000000000..d1180370f --- /dev/null +++ b/net/tipc/name_table.c @@ -0,0 +1,1204 @@ +/* + * net/tipc/name_table.c: TIPC name table code + * + * Copyright (c) 2000-2006, 2014-2018, Ericsson AB + * Copyright (c) 2004-2008, 2010-2014, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include +#include +#include "core.h" +#include "netlink.h" +#include "name_table.h" +#include "name_distr.h" +#include "subscr.h" +#include "bcast.h" +#include "addr.h" +#include "node.h" +#include "group.h" + +/** + * struct service_range - container for all bindings of a service range + * @lower: service range lower bound + * @upper: service range upper bound + * @tree_node: member of service range RB tree + * @max: largest 'upper' in this node subtree + * @local_publ: list of identical publications made from this node + * Used by closest_first lookup and multicast lookup algorithm + * @all_publ: all publications identical to this one, whatever node and scope + * Used by round-robin lookup algorithm + */ +struct service_range { + u32 lower; + u32 upper; + struct rb_node tree_node; + u32 max; + struct list_head local_publ; + struct list_head all_publ; +}; + +/** + * struct tipc_service - container for all published instances of a service type + * @type: 32 bit 'type' value for service + * @publ_cnt: increasing counter for publications in this service + * @ranges: rb tree containing all service ranges for this service + * @service_list: links to adjacent name ranges in hash chain + * @subscriptions: list of subscriptions for this service type + * @lock: spinlock controlling access to pertaining service ranges/publications + * @rcu: RCU callback head used for deferred freeing + */ +struct tipc_service { + u32 type; + u32 publ_cnt; + struct rb_root ranges; + struct hlist_node service_list; + struct list_head subscriptions; + spinlock_t lock; /* Covers service range list */ + struct rcu_head rcu; +}; + +#define service_range_upper(sr) ((sr)->upper) +RB_DECLARE_CALLBACKS_MAX(static, sr_callbacks, + struct service_range, tree_node, u32, max, + service_range_upper) + +#define service_range_entry(rbtree_node) \ + (container_of(rbtree_node, struct service_range, tree_node)) + +#define service_range_overlap(sr, start, end) \ + ((sr)->lower <= (end) && (sr)->upper >= (start)) + +/** + * service_range_foreach_match - iterate over tipc service rbtree for each + * range match + * @sr: the service range pointer as a loop cursor + * @sc: the pointer to tipc service which holds the service range rbtree + * @start: beginning of the search range (end >= start) for matching + * @end: end of the search range (end >= start) for matching + */ +#define service_range_foreach_match(sr, sc, start, end) \ + for (sr = service_range_match_first((sc)->ranges.rb_node, \ + start, \ + end); \ + sr; \ + sr = service_range_match_next(&(sr)->tree_node, \ + start, \ + end)) + +/** + * service_range_match_first - find first service range matching a range + * @n: the root node of service range rbtree for searching + * @start: beginning of the search range (end >= start) for matching + * @end: end of the search range (end >= start) for matching + * + * Return: the leftmost service range node in the rbtree that overlaps the + * specific range if any. Otherwise, returns NULL. + */ +static struct service_range *service_range_match_first(struct rb_node *n, + u32 start, u32 end) +{ + struct service_range *sr; + struct rb_node *l, *r; + + /* Non overlaps in tree at all? */ + if (!n || service_range_entry(n)->max < start) + return NULL; + + while (n) { + l = n->rb_left; + if (l && service_range_entry(l)->max >= start) { + /* A leftmost overlap range node must be one in the left + * subtree. If not, it has lower > end, then nodes on + * the right side cannot satisfy the condition either. 
+ */ + n = l; + continue; + } + + /* No one in the left subtree can match, return if this node is + * an overlap i.e. leftmost. + */ + sr = service_range_entry(n); + if (service_range_overlap(sr, start, end)) + return sr; + + /* Ok, try to lookup on the right side */ + r = n->rb_right; + if (sr->lower <= end && + r && service_range_entry(r)->max >= start) { + n = r; + continue; + } + break; + } + + return NULL; +} + +/** + * service_range_match_next - find next service range matching a range + * @n: a node in service range rbtree from which the searching starts + * @start: beginning of the search range (end >= start) for matching + * @end: end of the search range (end >= start) for matching + * + * Return: the next service range node to the given node in the rbtree that + * overlaps the specific range if any. Otherwise, returns NULL. + */ +static struct service_range *service_range_match_next(struct rb_node *n, + u32 start, u32 end) +{ + struct service_range *sr; + struct rb_node *p, *r; + + while (n) { + r = n->rb_right; + if (r && service_range_entry(r)->max >= start) + /* A next overlap range node must be one in the right + * subtree. If not, it has lower > end, then any next + * successor (- an ancestor) of this node cannot + * satisfy the condition either. + */ + return service_range_match_first(r, start, end); + + /* No one in the right subtree can match, go up to find an + * ancestor of this node which is parent of a left-hand child. + */ + while ((p = rb_parent(n)) && n == p->rb_right) + n = p; + if (!p) + break; + + /* Return if this ancestor is an overlap */ + sr = service_range_entry(p); + if (service_range_overlap(sr, start, end)) + return sr; + + /* Ok, try to lookup more from this ancestor */ + if (sr->lower <= end) { + n = p; + continue; + } + break; + } + + return NULL; +} + +static int hash(int x) +{ + return x & (TIPC_NAMETBL_SIZE - 1); +} + +/** + * tipc_publ_create - create a publication structure + * @ua: the service range the user is binding to + * @sk: the address of the socket that is bound + * @key: publication key + */ +static struct publication *tipc_publ_create(struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, + u32 key) +{ + struct publication *p = kzalloc(sizeof(*p), GFP_ATOMIC); + + if (!p) + return NULL; + + p->sr = ua->sr; + p->sk = *sk; + p->scope = ua->scope; + p->key = key; + INIT_LIST_HEAD(&p->binding_sock); + INIT_LIST_HEAD(&p->binding_node); + INIT_LIST_HEAD(&p->local_publ); + INIT_LIST_HEAD(&p->all_publ); + INIT_LIST_HEAD(&p->list); + return p; +} + +/** + * tipc_service_create - create a service structure for the specified 'type' + * @net: network namespace + * @ua: address representing the service to be bound + * + * Allocates a single range structure and sets it to all 0's. 
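The hash() above relies on TIPC_NAMETBL_SIZE being a power of two (name_table.h notes "must be a power of 2"), so masking with size - 1 behaves like a modulo but costs only an AND. A small stand-alone illustration, with a local constant standing in for the kernel's:

#include <assert.h>
#include <stdio.h>

#define NAMETBL_SIZE 1024	/* mirrors TIPC_NAMETBL_SIZE; must be 2^n */

static unsigned int bucket(unsigned int type)
{
	return type & (NAMETBL_SIZE - 1);	/* == type % NAMETBL_SIZE */
}

int main(void)
{
	assert((NAMETBL_SIZE & (NAMETBL_SIZE - 1)) == 0);	/* power of two */
	printf("type 4711        -> bucket %u\n", bucket(4711));
	printf("type 4711 + 1024 -> bucket %u\n", bucket(4711 + NAMETBL_SIZE));
	return 0;
}

Two service types that differ by a multiple of the table size share a bucket, which is why each bucket is a chained hlist of tipc_service entries.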
+ */ +static struct tipc_service *tipc_service_create(struct net *net, + struct tipc_uaddr *ua) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_service *service; + struct hlist_head *hd; + + service = kzalloc(sizeof(*service), GFP_ATOMIC); + if (!service) { + pr_warn("Service creation failed, no memory\n"); + return NULL; + } + + spin_lock_init(&service->lock); + service->type = ua->sr.type; + service->ranges = RB_ROOT; + INIT_HLIST_NODE(&service->service_list); + INIT_LIST_HEAD(&service->subscriptions); + hd = &nt->services[hash(ua->sr.type)]; + hlist_add_head_rcu(&service->service_list, hd); + return service; +} + +/* tipc_service_find_range - find service range matching publication parameters + */ +static struct service_range *tipc_service_find_range(struct tipc_service *sc, + struct tipc_uaddr *ua) +{ + struct service_range *sr; + + service_range_foreach_match(sr, sc, ua->sr.lower, ua->sr.upper) { + /* Look for exact match */ + if (sr->lower == ua->sr.lower && sr->upper == ua->sr.upper) + return sr; + } + + return NULL; +} + +static struct service_range *tipc_service_create_range(struct tipc_service *sc, + struct publication *p) +{ + struct rb_node **n, *parent = NULL; + struct service_range *sr; + u32 lower = p->sr.lower; + u32 upper = p->sr.upper; + + n = &sc->ranges.rb_node; + while (*n) { + parent = *n; + sr = service_range_entry(parent); + if (lower == sr->lower && upper == sr->upper) + return sr; + if (sr->max < upper) + sr->max = upper; + if (lower <= sr->lower) + n = &parent->rb_left; + else + n = &parent->rb_right; + } + sr = kzalloc(sizeof(*sr), GFP_ATOMIC); + if (!sr) + return NULL; + sr->lower = lower; + sr->upper = upper; + sr->max = upper; + INIT_LIST_HEAD(&sr->local_publ); + INIT_LIST_HEAD(&sr->all_publ); + rb_link_node(&sr->tree_node, parent, n); + rb_insert_augmented(&sr->tree_node, &sc->ranges, &sr_callbacks); + return sr; +} + +static bool tipc_service_insert_publ(struct net *net, + struct tipc_service *sc, + struct publication *p) +{ + struct tipc_subscription *sub, *tmp; + struct service_range *sr; + struct publication *_p; + u32 node = p->sk.node; + bool first = false; + bool res = false; + u32 key = p->key; + + spin_lock_bh(&sc->lock); + sr = tipc_service_create_range(sc, p); + if (!sr) + goto exit; + + first = list_empty(&sr->all_publ); + + /* Return if the publication already exists */ + list_for_each_entry(_p, &sr->all_publ, all_publ) { + if (_p->key == key && (!_p->sk.node || _p->sk.node == node)) { + pr_debug("Failed to bind duplicate %u,%u,%u/%u:%u/%u\n", + p->sr.type, p->sr.lower, p->sr.upper, + node, p->sk.ref, key); + goto exit; + } + } + + if (in_own_node(net, p->sk.node)) + list_add(&p->local_publ, &sr->local_publ); + list_add(&p->all_publ, &sr->all_publ); + p->id = sc->publ_cnt++; + + /* Any subscriptions waiting for notification? 
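tipc_service_create_range() keeps the augmented 'max' field valid by bumping it on every node it passes on the way down; rb_link_node() and rb_insert_augmented() then do the actual linking and rebalancing. A reduced sketch of that bookkeeping on an invented, non-rebalancing tree (illustrative only, not TIPC code):

#include <stdio.h>
#include <stdlib.h>

struct range_node {
	unsigned int lower, upper;
	unsigned int max;		/* largest 'upper' in this subtree */
	struct range_node *left, *right;
};

static struct range_node *range_insert(struct range_node **root,
				       unsigned int lower, unsigned int upper)
{
	struct range_node **link = root, *parent, *n;

	while (*link) {
		parent = *link;
		/* An existing identical range is simply reused. */
		if (lower == parent->lower && upper == parent->upper)
			return parent;
		/* Every node on the search path now also covers the new
		 * range, so its cached subtree maximum may have to grow.
		 */
		if (parent->max < upper)
			parent->max = upper;
		link = lower <= parent->lower ? &parent->left : &parent->right;
	}
	n = calloc(1, sizeof(*n));
	if (!n)
		return NULL;
	n->lower = lower;
	n->upper = upper;
	n->max = upper;
	*link = n;	/* the kernel also rebalances via rb_insert_augmented() */
	return n;
}

int main(void)
{
	struct range_node *root = NULL;

	range_insert(&root, 100, 200);
	range_insert(&root, 50, 300);	/* pushes the root's max up to 300 */
	printf("root [%u,%u] max=%u\n", root->lower, root->upper, root->max);
	return 0;
}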
*/ + list_for_each_entry_safe(sub, tmp, &sc->subscriptions, service_list) { + tipc_sub_report_overlap(sub, p, TIPC_PUBLISHED, first); + } + res = true; +exit: + if (!res) + pr_warn("Failed to bind to %u,%u,%u\n", + p->sr.type, p->sr.lower, p->sr.upper); + spin_unlock_bh(&sc->lock); + return res; +} + +/** + * tipc_service_remove_publ - remove a publication from a service + * @r: service_range to remove publication from + * @sk: address publishing socket + * @key: target publication key + */ +static struct publication *tipc_service_remove_publ(struct service_range *r, + struct tipc_socket_addr *sk, + u32 key) +{ + struct publication *p; + u32 node = sk->node; + + list_for_each_entry(p, &r->all_publ, all_publ) { + if (p->key != key || (node && node != p->sk.node)) + continue; + list_del(&p->all_publ); + list_del(&p->local_publ); + return p; + } + return NULL; +} + +/* + * Code reused: time_after32() for the same purpose + */ +#define publication_after(pa, pb) time_after32((pa)->id, (pb)->id) +static int tipc_publ_sort(void *priv, const struct list_head *a, + const struct list_head *b) +{ + struct publication *pa, *pb; + + pa = container_of(a, struct publication, list); + pb = container_of(b, struct publication, list); + return publication_after(pa, pb); +} + +/** + * tipc_service_subscribe - attach a subscription, and optionally + * issue the prescribed number of events if there is any service + * range overlapping with the requested range + * @service: the tipc_service to attach the @sub to + * @sub: the subscription to attach + */ +static void tipc_service_subscribe(struct tipc_service *service, + struct tipc_subscription *sub) +{ + struct publication *p, *first, *tmp; + struct list_head publ_list; + struct service_range *sr; + u32 filter, lower, upper; + + filter = sub->s.filter; + lower = sub->s.seq.lower; + upper = sub->s.seq.upper; + + tipc_sub_get(sub); + list_add(&sub->service_list, &service->subscriptions); + + if (filter & TIPC_SUB_NO_STATUS) + return; + + INIT_LIST_HEAD(&publ_list); + service_range_foreach_match(sr, service, lower, upper) { + first = NULL; + list_for_each_entry(p, &sr->all_publ, all_publ) { + if (filter & TIPC_SUB_PORTS) + list_add_tail(&p->list, &publ_list); + else if (!first || publication_after(first, p)) + /* Pick this range's *first* publication */ + first = p; + } + if (first) + list_add_tail(&first->list, &publ_list); + } + + /* Sort the publications before reporting */ + list_sort(NULL, &publ_list, tipc_publ_sort); + list_for_each_entry_safe(p, tmp, &publ_list, list) { + tipc_sub_report_overlap(sub, p, TIPC_PUBLISHED, true); + list_del_init(&p->list); + } +} + +static struct tipc_service *tipc_service_find(struct net *net, + struct tipc_uaddr *ua) +{ + struct name_table *nt = tipc_name_table(net); + struct hlist_head *service_head; + struct tipc_service *service; + + service_head = &nt->services[hash(ua->sr.type)]; + hlist_for_each_entry_rcu(service, service_head, service_list) { + if (service->type == ua->sr.type) + return service; + } + return NULL; +}; + +struct publication *tipc_nametbl_insert_publ(struct net *net, + struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, + u32 key) +{ + struct tipc_service *sc; + struct publication *p; + + p = tipc_publ_create(ua, sk, key); + if (!p) + return NULL; + + sc = tipc_service_find(net, ua); + if (!sc) + sc = tipc_service_create(net, ua); + if (sc && tipc_service_insert_publ(net, sc, p)) + return p; + kfree(p); + return NULL; +} + +struct publication *tipc_nametbl_remove_publ(struct net *net, + struct 
tipc_uaddr *ua, + struct tipc_socket_addr *sk, + u32 key) +{ + struct tipc_subscription *sub, *tmp; + struct publication *p = NULL; + struct service_range *sr; + struct tipc_service *sc; + bool last; + + sc = tipc_service_find(net, ua); + if (!sc) + goto exit; + + spin_lock_bh(&sc->lock); + sr = tipc_service_find_range(sc, ua); + if (!sr) + goto unlock; + p = tipc_service_remove_publ(sr, sk, key); + if (!p) + goto unlock; + + /* Notify any waiting subscriptions */ + last = list_empty(&sr->all_publ); + list_for_each_entry_safe(sub, tmp, &sc->subscriptions, service_list) { + tipc_sub_report_overlap(sub, p, TIPC_WITHDRAWN, last); + } + + /* Remove service range item if this was its last publication */ + if (list_empty(&sr->all_publ)) { + rb_erase_augmented(&sr->tree_node, &sc->ranges, &sr_callbacks); + kfree(sr); + } + + /* Delete service item if no more publications and subscriptions */ + if (RB_EMPTY_ROOT(&sc->ranges) && list_empty(&sc->subscriptions)) { + hlist_del_init_rcu(&sc->service_list); + kfree_rcu(sc, rcu); + } +unlock: + spin_unlock_bh(&sc->lock); +exit: + if (!p) { + pr_err("Failed to remove unknown binding: %u,%u,%u/%u:%u/%u\n", + ua->sr.type, ua->sr.lower, ua->sr.upper, + sk->node, sk->ref, key); + } + return p; +} + +/** + * tipc_nametbl_lookup_anycast - perform service instance to socket translation + * @net: network namespace + * @ua: service address to look up + * @sk: address to socket we want to find + * + * On entry, a non-zero 'sk->node' indicates the node where we want lookup to be + * performed, which may not be this one. + * + * On exit: + * + * - If lookup is deferred to another node, leave 'sk->node' unchanged and + * return 'true'. + * - If lookup is successful, set the 'sk->node' and 'sk->ref' (== portid) which + * represent the bound socket and return 'true'. + * - If lookup fails, return 'false' + * + * Note that for legacy users (node configured with Z.C.N address format) the + * 'closest-first' lookup algorithm must be maintained, i.e., if sk.node is 0 + * we must look in the local binding list first + */ +bool tipc_nametbl_lookup_anycast(struct net *net, + struct tipc_uaddr *ua, + struct tipc_socket_addr *sk) +{ + struct tipc_net *tn = tipc_net(net); + bool legacy = tn->legacy_addr_format; + u32 self = tipc_own_addr(net); + u32 inst = ua->sa.instance; + struct service_range *r; + struct tipc_service *sc; + struct publication *p; + struct list_head *l; + bool res = false; + + if (!tipc_in_scope(legacy, sk->node, self)) + return true; + + rcu_read_lock(); + sc = tipc_service_find(net, ua); + if (unlikely(!sc)) + goto exit; + + spin_lock_bh(&sc->lock); + service_range_foreach_match(r, sc, inst, inst) { + /* Select lookup algo: local, closest-first or round-robin */ + if (sk->node == self) { + l = &r->local_publ; + if (list_empty(l)) + continue; + p = list_first_entry(l, struct publication, local_publ); + list_move_tail(&p->local_publ, &r->local_publ); + } else if (legacy && !sk->node && !list_empty(&r->local_publ)) { + l = &r->local_publ; + p = list_first_entry(l, struct publication, local_publ); + list_move_tail(&p->local_publ, &r->local_publ); + } else { + l = &r->all_publ; + p = list_first_entry(l, struct publication, all_publ); + list_move_tail(&p->all_publ, &r->all_publ); + } + *sk = p->sk; + res = true; + /* Todo: as for legacy, pick the first matching range only, a + * "true" round-robin will be performed as needed. 
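The round-robin branch above is simply "take the current head, then rotate it to the tail", so successive lookups cycle through all matching publications. A stand-alone sketch of that idiom, with a hand-rolled circular list standing in for list_head, list_first_entry() and list_move_tail(); the struct and port numbers are invented:

#include <stddef.h>
#include <stdio.h>

/* Minimal circular doubly linked list, shaped like the kernel's list_head. */
struct list {
	struct list *prev, *next;
};

static void list_init(struct list *h)
{
	h->prev = h->next = h;
}

static void list_add_tail(struct list *n, struct list *h)
{
	n->prev = h->prev;
	n->next = h;
	h->prev->next = n;
	h->prev = n;
}

static void list_move_tail(struct list *n, struct list *h)
{
	n->prev->next = n->next;	/* unlink */
	n->next->prev = n->prev;
	list_add_tail(n, h);		/* relink at the tail */
}

struct publ {			/* stand-in for struct publication */
	unsigned int port;
	struct list link;	/* member of the range's publication list */
};

/* Pick the head and rotate it to the tail: repeated calls walk through all
 * publications in turn, which is the lookup_anycast() round-robin idiom.
 */
static struct publ *pick_round_robin(struct list *head)
{
	struct list *first = head->next;

	list_move_tail(first, head);
	return (struct publ *)((char *)first - offsetof(struct publ, link));
}

int main(void)
{
	struct publ a = { .port = 1001 }, b = { .port = 1002 }, c = { .port = 1003 };
	struct list head;
	int i;

	list_init(&head);
	list_add_tail(&a.link, &head);
	list_add_tail(&b.link, &head);
	list_add_tail(&c.link, &head);

	for (i = 0; i < 5; i++)
		printf("lookup %d -> port %u\n", i, pick_round_robin(&head)->port);
	return 0;
}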
+ */ + break; + } + spin_unlock_bh(&sc->lock); + +exit: + rcu_read_unlock(); + return res; +} + +/* tipc_nametbl_lookup_group(): lookup destinaton(s) in a communication group + * Returns a list of one (== group anycast) or more (== group multicast) + * destination socket/node pairs matching the given address. + * The requester may or may not want to exclude himself from the list. + */ +bool tipc_nametbl_lookup_group(struct net *net, struct tipc_uaddr *ua, + struct list_head *dsts, int *dstcnt, + u32 exclude, bool mcast) +{ + u32 self = tipc_own_addr(net); + u32 inst = ua->sa.instance; + struct service_range *sr; + struct tipc_service *sc; + struct publication *p; + + *dstcnt = 0; + rcu_read_lock(); + sc = tipc_service_find(net, ua); + if (unlikely(!sc)) + goto exit; + + spin_lock_bh(&sc->lock); + + /* Todo: a full search i.e. service_range_foreach_match() instead? */ + sr = service_range_match_first(sc->ranges.rb_node, inst, inst); + if (!sr) + goto no_match; + + list_for_each_entry(p, &sr->all_publ, all_publ) { + if (p->scope != ua->scope) + continue; + if (p->sk.ref == exclude && p->sk.node == self) + continue; + tipc_dest_push(dsts, p->sk.node, p->sk.ref); + (*dstcnt)++; + if (mcast) + continue; + list_move_tail(&p->all_publ, &sr->all_publ); + break; + } +no_match: + spin_unlock_bh(&sc->lock); +exit: + rcu_read_unlock(); + return !list_empty(dsts); +} + +/* tipc_nametbl_lookup_mcast_sockets(): look up node local destinaton sockets + * matching the given address + * Used on nodes which have received a multicast/broadcast message + * Returns a list of local sockets + */ +void tipc_nametbl_lookup_mcast_sockets(struct net *net, struct tipc_uaddr *ua, + struct list_head *dports) +{ + struct service_range *sr; + struct tipc_service *sc; + struct publication *p; + u8 scope = ua->scope; + + rcu_read_lock(); + sc = tipc_service_find(net, ua); + if (!sc) + goto exit; + + spin_lock_bh(&sc->lock); + service_range_foreach_match(sr, sc, ua->sr.lower, ua->sr.upper) { + list_for_each_entry(p, &sr->local_publ, local_publ) { + if (scope == p->scope || scope == TIPC_ANY_SCOPE) + tipc_dest_push(dports, 0, p->sk.ref); + } + } + spin_unlock_bh(&sc->lock); +exit: + rcu_read_unlock(); +} + +/* tipc_nametbl_lookup_mcast_nodes(): look up all destination nodes matching + * the given address. Used in sending node. 
+ * Used on nodes which are sending out a multicast/broadcast message + * Returns a list of nodes, including own node if applicable + */ +void tipc_nametbl_lookup_mcast_nodes(struct net *net, struct tipc_uaddr *ua, + struct tipc_nlist *nodes) +{ + struct service_range *sr; + struct tipc_service *sc; + struct publication *p; + + rcu_read_lock(); + sc = tipc_service_find(net, ua); + if (!sc) + goto exit; + + spin_lock_bh(&sc->lock); + service_range_foreach_match(sr, sc, ua->sr.lower, ua->sr.upper) { + list_for_each_entry(p, &sr->all_publ, all_publ) { + tipc_nlist_add(nodes, p->sk.node); + } + } + spin_unlock_bh(&sc->lock); +exit: + rcu_read_unlock(); +} + +/* tipc_nametbl_build_group - build list of communication group members + */ +void tipc_nametbl_build_group(struct net *net, struct tipc_group *grp, + struct tipc_uaddr *ua) +{ + struct service_range *sr; + struct tipc_service *sc; + struct publication *p; + struct rb_node *n; + + rcu_read_lock(); + sc = tipc_service_find(net, ua); + if (!sc) + goto exit; + + spin_lock_bh(&sc->lock); + for (n = rb_first(&sc->ranges); n; n = rb_next(n)) { + sr = container_of(n, struct service_range, tree_node); + list_for_each_entry(p, &sr->all_publ, all_publ) { + if (p->scope != ua->scope) + continue; + tipc_group_add_member(grp, p->sk.node, p->sk.ref, + p->sr.lower); + } + } + spin_unlock_bh(&sc->lock); +exit: + rcu_read_unlock(); +} + +/* tipc_nametbl_publish - add service binding to name table + */ +struct publication *tipc_nametbl_publish(struct net *net, struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, u32 key) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + struct publication *p = NULL; + struct sk_buff *skb = NULL; + u32 rc_dests; + + spin_lock_bh(&tn->nametbl_lock); + + if (nt->local_publ_count >= TIPC_MAX_PUBL) { + pr_warn("Bind failed, max limit %u reached\n", TIPC_MAX_PUBL); + goto exit; + } + + p = tipc_nametbl_insert_publ(net, ua, sk, key); + if (p) { + nt->local_publ_count++; + skb = tipc_named_publish(net, p); + } + rc_dests = nt->rc_dests; +exit: + spin_unlock_bh(&tn->nametbl_lock); + + if (skb) + tipc_node_broadcast(net, skb, rc_dests); + return p; + +} + +/** + * tipc_nametbl_withdraw - withdraw a service binding + * @net: network namespace + * @ua: service address/range being unbound + * @sk: address of the socket being unbound from + * @key: target publication key + */ +void tipc_nametbl_withdraw(struct net *net, struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, u32 key) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + struct sk_buff *skb = NULL; + struct publication *p; + u32 rc_dests; + + spin_lock_bh(&tn->nametbl_lock); + + p = tipc_nametbl_remove_publ(net, ua, sk, key); + if (p) { + nt->local_publ_count--; + skb = tipc_named_withdraw(net, p); + list_del_init(&p->binding_sock); + kfree_rcu(p, rcu); + } + rc_dests = nt->rc_dests; + spin_unlock_bh(&tn->nametbl_lock); + + if (skb) + tipc_node_broadcast(net, skb, rc_dests); +} + +/** + * tipc_nametbl_subscribe - add a subscription object to the name table + * @sub: subscription to add + */ +bool tipc_nametbl_subscribe(struct tipc_subscription *sub) +{ + struct tipc_net *tn = tipc_net(sub->net); + u32 type = sub->s.seq.type; + struct tipc_service *sc; + struct tipc_uaddr ua; + bool res = true; + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_NODE_SCOPE, type, + sub->s.seq.lower, sub->s.seq.upper); + spin_lock_bh(&tn->nametbl_lock); + sc = tipc_service_find(sub->net, &ua); + if (!sc) + sc = 
tipc_service_create(sub->net, &ua); + if (sc) { + spin_lock_bh(&sc->lock); + tipc_service_subscribe(sc, sub); + spin_unlock_bh(&sc->lock); + } else { + pr_warn("Failed to subscribe for {%u,%u,%u}\n", + type, sub->s.seq.lower, sub->s.seq.upper); + res = false; + } + spin_unlock_bh(&tn->nametbl_lock); + return res; +} + +/** + * tipc_nametbl_unsubscribe - remove a subscription object from name table + * @sub: subscription to remove + */ +void tipc_nametbl_unsubscribe(struct tipc_subscription *sub) +{ + struct tipc_net *tn = tipc_net(sub->net); + struct tipc_service *sc; + struct tipc_uaddr ua; + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_NODE_SCOPE, + sub->s.seq.type, sub->s.seq.lower, sub->s.seq.upper); + spin_lock_bh(&tn->nametbl_lock); + sc = tipc_service_find(sub->net, &ua); + if (!sc) + goto exit; + + spin_lock_bh(&sc->lock); + list_del_init(&sub->service_list); + tipc_sub_put(sub); + + /* Delete service item if no more publications and subscriptions */ + if (RB_EMPTY_ROOT(&sc->ranges) && list_empty(&sc->subscriptions)) { + hlist_del_init_rcu(&sc->service_list); + kfree_rcu(sc, rcu); + } + spin_unlock_bh(&sc->lock); +exit: + spin_unlock_bh(&tn->nametbl_lock); +} + +int tipc_nametbl_init(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + struct name_table *nt; + int i; + + nt = kzalloc(sizeof(*nt), GFP_KERNEL); + if (!nt) + return -ENOMEM; + + for (i = 0; i < TIPC_NAMETBL_SIZE; i++) + INIT_HLIST_HEAD(&nt->services[i]); + + INIT_LIST_HEAD(&nt->node_scope); + INIT_LIST_HEAD(&nt->cluster_scope); + rwlock_init(&nt->cluster_scope_lock); + tn->nametbl = nt; + spin_lock_init(&tn->nametbl_lock); + return 0; +} + +/** + * tipc_service_delete - purge all publications for a service and delete it + * @net: the associated network namespace + * @sc: tipc_service to delete + */ +static void tipc_service_delete(struct net *net, struct tipc_service *sc) +{ + struct service_range *sr, *tmpr; + struct publication *p, *tmp; + + spin_lock_bh(&sc->lock); + rbtree_postorder_for_each_entry_safe(sr, tmpr, &sc->ranges, tree_node) { + list_for_each_entry_safe(p, tmp, &sr->all_publ, all_publ) { + tipc_service_remove_publ(sr, &p->sk, p->key); + kfree_rcu(p, rcu); + } + rb_erase_augmented(&sr->tree_node, &sc->ranges, &sr_callbacks); + kfree(sr); + } + hlist_del_init_rcu(&sc->service_list); + spin_unlock_bh(&sc->lock); + kfree_rcu(sc, rcu); +} + +void tipc_nametbl_stop(struct net *net) +{ + struct name_table *nt = tipc_name_table(net); + struct tipc_net *tn = tipc_net(net); + struct hlist_head *service_head; + struct tipc_service *service; + u32 i; + + /* Verify name table is empty and purge any lingering + * publications, then release the name table + */ + spin_lock_bh(&tn->nametbl_lock); + for (i = 0; i < TIPC_NAMETBL_SIZE; i++) { + if (hlist_empty(&nt->services[i])) + continue; + service_head = &nt->services[i]; + hlist_for_each_entry_rcu(service, service_head, service_list) { + tipc_service_delete(net, service); + } + } + spin_unlock_bh(&tn->nametbl_lock); + + synchronize_net(); + kfree(nt); +} + +static int __tipc_nl_add_nametable_publ(struct tipc_nl_msg *msg, + struct tipc_service *service, + struct service_range *sr, + u32 *last_key) +{ + struct publication *p; + struct nlattr *attrs; + struct nlattr *b; + void *hdr; + + if (*last_key) { + list_for_each_entry(p, &sr->all_publ, all_publ) + if (p->key == *last_key) + break; + if (list_entry_is_head(p, &sr->all_publ, all_publ)) + return -EPIPE; + } else { + p = list_first_entry(&sr->all_publ, + struct publication, + all_publ); + } + + 
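The *last_key handling just above, together with the list_for_each_entry_from() loop that follows, forms a resumable dump cursor: each batch records the key of the publication it is about to emit, and the next batch re-finds that publication and retries from it (or fails with -EPIPE if it has disappeared). A stand-alone sketch of the same cursor discipline over a plain array; the record layout, per-call budget and return values are invented:

#include <stdio.h>

struct rec {
	unsigned int key;
	unsigned int value;
};

/* Emit at most 'budget' records per call, resuming at *last_key.
 * Returns 0 when done, 1 when the caller must call again, -1 when the
 * remembered cursor record no longer exists.
 */
static int dump_batch(const struct rec *tbl, int n, unsigned int *last_key,
		      int budget)
{
	int i = 0;

	if (*last_key) {		/* re-find where the last batch stopped */
		while (i < n && tbl[i].key != *last_key)
			i++;
		if (i == n)
			return -1;	/* cursor record disappeared */
	}
	for (; i < n; i++) {
		*last_key = tbl[i].key;	/* remember before trying to emit */
		if (budget-- == 0)
			return 1;	/* no room: retry this record next call */
		printf("key %u -> %u\n", tbl[i].key, tbl[i].value);
	}
	*last_key = 0;			/* finished: reset the cursor */
	return 0;
}

int main(void)
{
	const struct rec tbl[] = {
		{ 10, 100 }, { 20, 200 }, { 30, 300 }, { 40, 400 }, { 50, 500 },
	};
	unsigned int cursor = 0;

	while (dump_batch(tbl, 5, &cursor, 2) == 1)
		;	/* each iteration emits one more batch */
	return 0;
}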
list_for_each_entry_from(p, &sr->all_publ, all_publ) { + *last_key = p->key; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, + &tipc_genl_family, NLM_F_MULTI, + TIPC_NL_NAME_TABLE_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_NAME_TABLE); + if (!attrs) + goto msg_full; + + b = nla_nest_start_noflag(msg->skb, TIPC_NLA_NAME_TABLE_PUBL); + if (!b) + goto attr_msg_full; + + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_TYPE, service->type)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_LOWER, sr->lower)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_UPPER, sr->upper)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_SCOPE, p->scope)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_NODE, p->sk.node)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_REF, p->sk.ref)) + goto publ_msg_full; + if (nla_put_u32(msg->skb, TIPC_NLA_PUBL_KEY, p->key)) + goto publ_msg_full; + + nla_nest_end(msg->skb, b); + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + } + *last_key = 0; + + return 0; + +publ_msg_full: + nla_nest_cancel(msg->skb, b); +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +static int __tipc_nl_service_range_list(struct tipc_nl_msg *msg, + struct tipc_service *sc, + u32 *last_lower, u32 *last_key) +{ + struct service_range *sr; + struct rb_node *n; + int err; + + for (n = rb_first(&sc->ranges); n; n = rb_next(n)) { + sr = container_of(n, struct service_range, tree_node); + if (sr->lower < *last_lower) + continue; + err = __tipc_nl_add_nametable_publ(msg, sc, sr, last_key); + if (err) { + *last_lower = sr->lower; + return err; + } + } + *last_lower = 0; + return 0; +} + +static int tipc_nl_service_list(struct net *net, struct tipc_nl_msg *msg, + u32 *last_type, u32 *last_lower, u32 *last_key) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_service *service = NULL; + struct hlist_head *head; + struct tipc_uaddr ua; + int err; + int i; + + if (*last_type) + i = hash(*last_type); + else + i = 0; + + for (; i < TIPC_NAMETBL_SIZE; i++) { + head = &tn->nametbl->services[i]; + + if (*last_type || + (!i && *last_key && (*last_lower == *last_key))) { + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_NODE_SCOPE, + *last_type, *last_lower, *last_lower); + service = tipc_service_find(net, &ua); + if (!service) + return -EPIPE; + } else { + hlist_for_each_entry_rcu(service, head, service_list) + break; + if (!service) + continue; + } + + hlist_for_each_entry_from_rcu(service, service_list) { + spin_lock_bh(&service->lock); + err = __tipc_nl_service_range_list(msg, service, + last_lower, + last_key); + + if (err) { + *last_type = service->type; + spin_unlock_bh(&service->lock); + return err; + } + spin_unlock_bh(&service->lock); + } + *last_type = 0; + } + return 0; +} + +int tipc_nl_name_table_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + u32 last_type = cb->args[0]; + u32 last_lower = cb->args[1]; + u32 last_key = cb->args[2]; + int done = cb->args[3]; + struct tipc_nl_msg msg; + int err; + + if (done) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rcu_read_lock(); + err = tipc_nl_service_list(net, &msg, &last_type, + &last_lower, &last_key); + if (!err) { + done = 1; + } else if (err != -EMSGSIZE) { + /* We never set seq or call nl_dump_check_consistent() this + * means that 
setting prev_seq here will cause the consistence + * check to fail in the netlink callback handler. Resulting in + * the NLMSG_DONE message having the NLM_F_DUMP_INTR flag set if + * we got an error. + */ + cb->prev_seq = 1; + } + rcu_read_unlock(); + + cb->args[0] = last_type; + cb->args[1] = last_lower; + cb->args[2] = last_key; + cb->args[3] = done; + + return skb->len; +} + +struct tipc_dest *tipc_dest_find(struct list_head *l, u32 node, u32 port) +{ + struct tipc_dest *dst; + + list_for_each_entry(dst, l, list) { + if (dst->node == node && dst->port == port) + return dst; + } + return NULL; +} + +bool tipc_dest_push(struct list_head *l, u32 node, u32 port) +{ + struct tipc_dest *dst; + + if (tipc_dest_find(l, node, port)) + return false; + + dst = kmalloc(sizeof(*dst), GFP_ATOMIC); + if (unlikely(!dst)) + return false; + dst->node = node; + dst->port = port; + list_add(&dst->list, l); + return true; +} + +bool tipc_dest_pop(struct list_head *l, u32 *node, u32 *port) +{ + struct tipc_dest *dst; + + if (list_empty(l)) + return false; + dst = list_first_entry(l, typeof(*dst), list); + if (port) + *port = dst->port; + if (node) + *node = dst->node; + list_del(&dst->list); + kfree(dst); + return true; +} + +bool tipc_dest_del(struct list_head *l, u32 node, u32 port) +{ + struct tipc_dest *dst; + + dst = tipc_dest_find(l, node, port); + if (!dst) + return false; + list_del(&dst->list); + kfree(dst); + return true; +} + +void tipc_dest_list_purge(struct list_head *l) +{ + struct tipc_dest *dst, *tmp; + + list_for_each_entry_safe(dst, tmp, l, list) { + list_del(&dst->list); + kfree(dst); + } +} diff --git a/net/tipc/name_table.h b/net/tipc/name_table.h new file mode 100644 index 000000000..3bcd9ef8c --- /dev/null +++ b/net/tipc/name_table.h @@ -0,0 +1,155 @@ +/* + * net/tipc/name_table.h: Include file for TIPC name table code + * + * Copyright (c) 2000-2006, 2014-2018, Ericsson AB + * Copyright (c) 2004-2005, 2010-2011, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
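The tipc_dest_*() helpers above maintain a de-duplicated destination list that the lookup functions fill and the send path then drains. A stand-alone sketch of the same contract, using an invented singly linked list instead of list_head and plain malloc() instead of GFP_ATOMIC allocations:

#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>

struct dest {				/* stand-in for struct tipc_dest */
	unsigned int node, port;
	struct dest *next;
};

/* Add (node, port) unless it is already present, as tipc_dest_push() does. */
static bool dest_push(struct dest **l, unsigned int node, unsigned int port)
{
	struct dest *d;

	for (d = *l; d; d = d->next)
		if (d->node == node && d->port == port)
			return false;	/* duplicate: keep the list a set */
	d = malloc(sizeof(*d));
	if (!d)
		return false;
	d->node = node;
	d->port = port;
	d->next = *l;
	*l = d;
	return true;
}

/* Remove one destination and hand it to the caller, as tipc_dest_pop() does. */
static bool dest_pop(struct dest **l, unsigned int *node, unsigned int *port)
{
	struct dest *d = *l;

	if (!d)
		return false;
	*node = d->node;
	*port = d->port;
	*l = d->next;
	free(d);
	return true;
}

int main(void)
{
	struct dest *dports = NULL;
	unsigned int node, port;

	dest_push(&dports, 0, 4001);
	dest_push(&dports, 0, 4002);
	dest_push(&dports, 0, 4001);	/* duplicate, silently ignored */

	while (dest_pop(&dports, &node, &port))
		printf("deliver to port %u on node %u\n", port, node);
	return 0;
}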
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_NAME_TABLE_H +#define _TIPC_NAME_TABLE_H + +struct tipc_subscription; +struct tipc_plist; +struct tipc_nlist; +struct tipc_group; +struct tipc_uaddr; + +/* + * TIPC name types reserved for internal TIPC use (both current and planned) + */ +#define TIPC_ZM_SRV 3 /* zone master service name type */ +#define TIPC_PUBL_SCOPE_NUM (TIPC_NODE_SCOPE + 1) +#define TIPC_NAMETBL_SIZE 1024 /* must be a power of 2 */ + +#define TIPC_ANY_SCOPE 10 /* Both node and cluster scope will match */ + +/** + * struct publication - info about a published service address or range + * @sr: service range represented by this publication + * @sk: address of socket bound to this publication + * @scope: scope of publication, TIPC_NODE_SCOPE or TIPC_CLUSTER_SCOPE + * @key: publication key, unique across the cluster + * @id: publication id + * @binding_node: all publications from the same node which bound this one + * - Remote publications: in node->publ_list; + * Used by node/name distr to withdraw publications when node is lost + * - Local/node scope publications: in name_table->node_scope list + * - Local/cluster scope publications: in name_table->cluster_scope list + * @binding_sock: all publications from the same socket which bound this one + * Used by socket to withdraw publications when socket is unbound/released + * @local_publ: list of identical publications made from this node + * Used by closest_first and multicast receive lookup algorithms + * @all_publ: all publications identical to this one, whatever node and scope + * Used by round-robin lookup algorithm + * @list: to form a list of publications in temporal order + * @rcu: RCU callback head used for deferred freeing + */ +struct publication { + struct tipc_service_range sr; + struct tipc_socket_addr sk; + u16 scope; + u32 key; + u32 id; + struct list_head binding_node; + struct list_head binding_sock; + struct list_head local_publ; + struct list_head all_publ; + struct list_head list; + struct rcu_head rcu; +}; + +/** + * struct name_table - table containing all existing port name publications + * @services: name sequence hash lists + * @node_scope: all local publications with node scope + * - used by name_distr during re-init of name table + * @cluster_scope: all local publications with cluster scope + * - used by name_distr to send bulk updates to new nodes + * - used by name_distr during re-init of name table + * @cluster_scope_lock: lock for accessing @cluster_scope + * @local_publ_count: number of publications issued by this node + * @rc_dests: destination node counter + * @snd_nxt: next sequence number to be used + */ +struct name_table { + struct hlist_head services[TIPC_NAMETBL_SIZE]; + struct list_head node_scope; + struct list_head cluster_scope; + rwlock_t cluster_scope_lock; + u32 local_publ_count; + u32 rc_dests; + u32 snd_nxt; +}; + +int tipc_nl_name_table_dump(struct sk_buff *skb, struct netlink_callback *cb); +bool tipc_nametbl_lookup_anycast(struct net *net, struct tipc_uaddr *ua, + 
struct tipc_socket_addr *sk); +void tipc_nametbl_lookup_mcast_sockets(struct net *net, struct tipc_uaddr *ua, + struct list_head *dports); +void tipc_nametbl_lookup_mcast_nodes(struct net *net, struct tipc_uaddr *ua, + struct tipc_nlist *nodes); +bool tipc_nametbl_lookup_group(struct net *net, struct tipc_uaddr *ua, + struct list_head *dsts, int *dstcnt, + u32 exclude, bool mcast); +void tipc_nametbl_build_group(struct net *net, struct tipc_group *grp, + struct tipc_uaddr *ua); +struct publication *tipc_nametbl_publish(struct net *net, struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, u32 key); +void tipc_nametbl_withdraw(struct net *net, struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, u32 key); +struct publication *tipc_nametbl_insert_publ(struct net *net, + struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, + u32 key); +struct publication *tipc_nametbl_remove_publ(struct net *net, + struct tipc_uaddr *ua, + struct tipc_socket_addr *sk, + u32 key); +bool tipc_nametbl_subscribe(struct tipc_subscription *s); +void tipc_nametbl_unsubscribe(struct tipc_subscription *s); +int tipc_nametbl_init(struct net *net); +void tipc_nametbl_stop(struct net *net); + +struct tipc_dest { + struct list_head list; + u32 port; + u32 node; +}; + +struct tipc_dest *tipc_dest_find(struct list_head *l, u32 node, u32 port); +bool tipc_dest_push(struct list_head *l, u32 node, u32 port); +bool tipc_dest_pop(struct list_head *l, u32 *node, u32 *port); +bool tipc_dest_del(struct list_head *l, u32 node, u32 port); +void tipc_dest_list_purge(struct list_head *l); + +#endif diff --git a/net/tipc/net.c b/net/tipc/net.c new file mode 100644 index 000000000..0e95572e5 --- /dev/null +++ b/net/tipc/net.c @@ -0,0 +1,345 @@ +/* + * net/tipc/net.c: TIPC network routing code + * + * Copyright (c) 1995-2006, 2014, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "net.h" +#include "name_distr.h" +#include "subscr.h" +#include "socket.h" +#include "node.h" +#include "bcast.h" +#include "link.h" +#include "netlink.h" +#include "monitor.h" + +/* + * The TIPC locking policy is designed to ensure a very fine locking + * granularity, permitting complete parallel access to individual + * port and node/link instances. The code consists of four major + * locking domains, each protected with their own disjunct set of locks. + * + * 1: The bearer level. + * RTNL lock is used to serialize the process of configuring bearer + * on update side, and RCU lock is applied on read side to make + * bearer instance valid on both paths of message transmission and + * reception. + * + * 2: The node and link level. + * All node instances are saved into two tipc_node_list and node_htable + * lists. The two lists are protected by node_list_lock on write side, + * and they are guarded with RCU lock on read side. Especially node + * instance is destroyed only when TIPC module is removed, and we can + * confirm that there has no any user who is accessing the node at the + * moment. Therefore, Except for iterating the two lists within RCU + * protection, it's no needed to hold RCU that we access node instance + * in other places. + * + * In addition, all members in node structure including link instances + * are protected by node spin lock. + * + * 3: The transport level of the protocol. + * This consists of the structures port, (and its user level + * representations, such as user_port and tipc_sock), reference and + * tipc_user (port.c, reg.c, socket.c). + * + * This layer has four different locks: + * - The tipc_port spin_lock. This is protecting each port instance + * from parallel data access and removal. Since we can not place + * this lock in the port itself, it has been placed in the + * corresponding reference table entry, which has the same life + * cycle as the module. This entry is difficult to access from + * outside the TIPC core, however, so a pointer to the lock has + * been added in the port instance, -to be used for unlocking + * only. + * - A read/write lock to protect the reference table itself (teg.c). + * (Nobody is using read-only access to this, so it can just as + * well be changed to a spin_lock) + * - A spin lock to protect the registry of kernel/driver users (reg.c) + * - A global spin_lock (tipc_port_lock), which only task is to ensure + * consistency where more than one port is involved in an operation, + * i.e., when a port is part of a linked list of ports. + * There are two such lists; 'port_list', which is used for management, + * and 'wait_list', which is used to queue ports during congestion. + * + * 4: The name table (name_table.c, name_distr.c, subscription.c) + * - There is one big read/write-lock (tipc_nametbl_lock) protecting the + * overall name table structure. Nothing must be added/removed to + * this structure without holding write access to it. 
+ * - There is one local spin_lock per sub_sequence, which can be seen + * as a sub-domain to the tipc_nametbl_lock domain. It is used only + * for translation operations, and is needed because a translation + * steps the root of the 'publication' linked list between each lookup. + * This is always used within the scope of a tipc_nametbl_lock(read). + * - A local spin_lock protecting the queue of subscriber events. +*/ + +static void tipc_net_finalize(struct net *net, u32 addr); + +int tipc_net_init(struct net *net, u8 *node_id, u32 addr) +{ + if (tipc_own_id(net)) { + pr_info("Cannot configure node identity twice\n"); + return -1; + } + pr_info("Started in network mode\n"); + + if (node_id) + tipc_set_node_id(net, node_id); + if (addr) + tipc_net_finalize(net, addr); + return 0; +} + +static void tipc_net_finalize(struct net *net, u32 addr) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_socket_addr sk = {0, addr}; + struct tipc_uaddr ua; + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_CLUSTER_SCOPE, + TIPC_NODE_STATE, addr, addr); + + if (cmpxchg(&tn->node_addr, 0, addr)) + return; + tipc_set_node_addr(net, addr); + tipc_named_reinit(net); + tipc_sk_reinit(net); + tipc_mon_reinit_self(net); + tipc_nametbl_publish(net, &ua, &sk, addr); +} + +void tipc_net_finalize_work(struct work_struct *work) +{ + struct tipc_net *tn = container_of(work, struct tipc_net, work); + + tipc_net_finalize(tipc_link_net(tn->bcl), tn->trial_addr); +} + +void tipc_net_stop(struct net *net) +{ + if (!tipc_own_id(net)) + return; + + rtnl_lock(); + tipc_bearer_stop(net); + tipc_node_stop(net); + rtnl_unlock(); + + pr_info("Left network mode\n"); +} + +static int __tipc_nl_add_net(struct net *net, struct tipc_nl_msg *msg) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + u64 *w0 = (u64 *)&tn->node_id[0]; + u64 *w1 = (u64 *)&tn->node_id[8]; + struct nlattr *attrs; + void *hdr; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + NLM_F_MULTI, TIPC_NL_NET_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_NET); + if (!attrs) + goto msg_full; + + if (nla_put_u32(msg->skb, TIPC_NLA_NET_ID, tn->net_id)) + goto attr_msg_full; + if (nla_put_u64_64bit(msg->skb, TIPC_NLA_NET_NODEID, *w0, 0)) + goto attr_msg_full; + if (nla_put_u64_64bit(msg->skb, TIPC_NLA_NET_NODEID_W1, *w1, 0)) + goto attr_msg_full; + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_net_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int err; + int done = cb->args[0]; + struct tipc_nl_msg msg; + + if (done) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + err = __tipc_nl_add_net(net, &msg); + if (err) + goto out; + + done = 1; +out: + cb->args[0] = done; + + return skb->len; +} + +int __tipc_nl_net_set(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[TIPC_NLA_NET_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = tipc_net(net); + int err; + + if (!info->attrs[TIPC_NLA_NET]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_NET_MAX, + info->attrs[TIPC_NLA_NET], + tipc_nl_net_policy, info->extack); + + if (err) + return err; + + /* Can't change net id once TIPC has joined a network */ + if (tipc_own_addr(net)) + return -EPERM; + + if 
(attrs[TIPC_NLA_NET_ID]) { + u32 val; + + val = nla_get_u32(attrs[TIPC_NLA_NET_ID]); + if (val < 1 || val > 9999) + return -EINVAL; + + tn->net_id = val; + } + + if (attrs[TIPC_NLA_NET_ADDR]) { + u32 addr; + + addr = nla_get_u32(attrs[TIPC_NLA_NET_ADDR]); + if (!addr) + return -EINVAL; + tn->legacy_addr_format = true; + tipc_net_init(net, NULL, addr); + } + + if (attrs[TIPC_NLA_NET_NODEID]) { + u8 node_id[NODE_ID_LEN]; + u64 *w0 = (u64 *)&node_id[0]; + u64 *w1 = (u64 *)&node_id[8]; + + if (!attrs[TIPC_NLA_NET_NODEID_W1]) + return -EINVAL; + *w0 = nla_get_u64(attrs[TIPC_NLA_NET_NODEID]); + *w1 = nla_get_u64(attrs[TIPC_NLA_NET_NODEID_W1]); + tipc_net_init(net, node_id, 0); + } + return 0; +} + +int tipc_nl_net_set(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_net_set(skb, info); + rtnl_unlock(); + + return err; +} + +static int __tipc_nl_addr_legacy_get(struct net *net, struct tipc_nl_msg *msg) +{ + struct tipc_net *tn = tipc_net(net); + struct nlattr *attrs; + void *hdr; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + 0, TIPC_NL_ADDR_LEGACY_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start(msg->skb, TIPC_NLA_NET); + if (!attrs) + goto msg_full; + + if (tn->legacy_addr_format) + if (nla_put_flag(msg->skb, TIPC_NLA_NET_ADDR_LEGACY)) + goto attr_msg_full; + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_net_addr_legacy_get(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct tipc_nl_msg msg; + struct sk_buff *rep; + int err; + + rep = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!rep) + return -ENOMEM; + + msg.skb = rep; + msg.portid = info->snd_portid; + msg.seq = info->snd_seq; + + err = __tipc_nl_addr_legacy_get(net, &msg); + if (err) { + nlmsg_free(msg.skb); + return err; + } + + return genlmsg_reply(msg.skb, info); +} diff --git a/net/tipc/net.h b/net/tipc/net.h new file mode 100644 index 000000000..d0c91d2df --- /dev/null +++ b/net/tipc/net.h @@ -0,0 +1,53 @@ +/* + * net/tipc/net.h: Include file for TIPC network routing code + * + * Copyright (c) 1995-2006, 2014, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _TIPC_NET_H +#define _TIPC_NET_H + +#include + +extern const struct nla_policy tipc_nl_net_policy[]; + +int tipc_net_init(struct net *net, u8 *node_id, u32 addr); +void tipc_net_finalize_work(struct work_struct *work); +void tipc_sched_net_finalize(struct net *net, u32 addr); +void tipc_net_stop(struct net *net); +int tipc_nl_net_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_net_set(struct sk_buff *skb, struct genl_info *info); +int __tipc_nl_net_set(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_net_addr_legacy_get(struct sk_buff *skb, struct genl_info *info); + +#endif diff --git a/net/tipc/netlink.c b/net/tipc/netlink.c new file mode 100644 index 000000000..1a9a5bdac --- /dev/null +++ b/net/tipc/netlink.c @@ -0,0 +1,315 @@ +/* + * net/tipc/netlink.c: TIPC configuration handling + * + * Copyright (c) 2005-2006, 2014, Ericsson AB + * Copyright (c) 2005-2007, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include "core.h" +#include "socket.h" +#include "name_table.h" +#include "bearer.h" +#include "link.h" +#include "node.h" +#include "net.h" +#include "udp_media.h" +#include + +static const struct nla_policy tipc_nl_policy[TIPC_NLA_MAX + 1] = { + [TIPC_NLA_UNSPEC] = { .type = NLA_UNSPEC, }, + [TIPC_NLA_BEARER] = { .type = NLA_NESTED, }, + [TIPC_NLA_SOCK] = { .type = NLA_NESTED, }, + [TIPC_NLA_PUBL] = { .type = NLA_NESTED, }, + [TIPC_NLA_LINK] = { .type = NLA_NESTED, }, + [TIPC_NLA_MEDIA] = { .type = NLA_NESTED, }, + [TIPC_NLA_NODE] = { .type = NLA_NESTED, }, + [TIPC_NLA_NET] = { .type = NLA_NESTED, }, + [TIPC_NLA_NAME_TABLE] = { .type = NLA_NESTED, }, + [TIPC_NLA_MON] = { .type = NLA_NESTED, }, +}; + +const struct nla_policy +tipc_nl_name_table_policy[TIPC_NLA_NAME_TABLE_MAX + 1] = { + [TIPC_NLA_NAME_TABLE_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_NAME_TABLE_PUBL] = { .type = NLA_NESTED } +}; + +const struct nla_policy tipc_nl_monitor_policy[TIPC_NLA_MON_MAX + 1] = { + [TIPC_NLA_MON_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_MON_REF] = { .type = NLA_U32 }, + [TIPC_NLA_MON_ACTIVATION_THRESHOLD] = { .type = NLA_U32 }, +}; + +const struct nla_policy tipc_nl_sock_policy[TIPC_NLA_SOCK_MAX + 1] = { + [TIPC_NLA_SOCK_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_SOCK_ADDR] = { .type = NLA_U32 }, + [TIPC_NLA_SOCK_REF] = { .type = NLA_U32 }, + [TIPC_NLA_SOCK_CON] = { .type = NLA_NESTED }, + [TIPC_NLA_SOCK_HAS_PUBL] = { .type = NLA_FLAG } +}; + +const struct nla_policy tipc_nl_net_policy[TIPC_NLA_NET_MAX + 1] = { + [TIPC_NLA_NET_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_NET_ID] = { .type = NLA_U32 }, + [TIPC_NLA_NET_ADDR] = { .type = NLA_U32 }, + [TIPC_NLA_NET_NODEID] = { .type = NLA_U64 }, + [TIPC_NLA_NET_NODEID_W1] = { .type = NLA_U64 }, + [TIPC_NLA_NET_ADDR_LEGACY] = { .type = NLA_FLAG } +}; + +const struct nla_policy tipc_nl_link_policy[TIPC_NLA_LINK_MAX + 1] = { + [TIPC_NLA_LINK_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_LINK_NAME] = { .type = NLA_NUL_STRING, + .len = TIPC_MAX_LINK_NAME }, + [TIPC_NLA_LINK_MTU] = { .type = NLA_U32 }, + [TIPC_NLA_LINK_BROADCAST] = { .type = NLA_FLAG }, + [TIPC_NLA_LINK_UP] = { .type = NLA_FLAG }, + [TIPC_NLA_LINK_ACTIVE] = { .type = NLA_FLAG }, + [TIPC_NLA_LINK_PROP] = { .type = NLA_NESTED }, + [TIPC_NLA_LINK_STATS] = { .type = NLA_NESTED }, + [TIPC_NLA_LINK_RX] = { .type = NLA_U32 }, + [TIPC_NLA_LINK_TX] = { .type = NLA_U32 } +}; + +const struct nla_policy tipc_nl_node_policy[TIPC_NLA_NODE_MAX + 1] = { + [TIPC_NLA_NODE_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_NODE_ADDR] = { .type = NLA_U32 }, + [TIPC_NLA_NODE_UP] = { .type = NLA_FLAG }, + [TIPC_NLA_NODE_ID] = { .type = NLA_BINARY, + .len = TIPC_NODEID_LEN}, + [TIPC_NLA_NODE_KEY] = { .type = NLA_BINARY, + .len = TIPC_AEAD_KEY_SIZE_MAX}, + [TIPC_NLA_NODE_KEY_MASTER] = { .type = NLA_FLAG }, + [TIPC_NLA_NODE_REKEYING] = { .type = NLA_U32 }, +}; + +/* Properties valid for media, bearer and link */ +const struct nla_policy tipc_nl_prop_policy[TIPC_NLA_PROP_MAX + 1] = { + [TIPC_NLA_PROP_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_PROP_PRIO] = { .type = NLA_U32 }, + [TIPC_NLA_PROP_TOL] = { .type = NLA_U32 }, + [TIPC_NLA_PROP_WIN] = { .type = NLA_U32 }, + [TIPC_NLA_PROP_MTU] = { .type = NLA_U32 }, + [TIPC_NLA_PROP_BROADCAST] = { .type = NLA_U32 }, + [TIPC_NLA_PROP_BROADCAST_RATIO] = { .type = NLA_U32 } +}; + +const struct nla_policy tipc_nl_bearer_policy[TIPC_NLA_BEARER_MAX + 1] = { + [TIPC_NLA_BEARER_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_BEARER_NAME] = { .type = NLA_NUL_STRING, + 
.len = TIPC_MAX_BEARER_NAME }, + [TIPC_NLA_BEARER_PROP] = { .type = NLA_NESTED }, + [TIPC_NLA_BEARER_DOMAIN] = { .type = NLA_U32 } +}; + +const struct nla_policy tipc_nl_media_policy[TIPC_NLA_MEDIA_MAX + 1] = { + [TIPC_NLA_MEDIA_UNSPEC] = { .type = NLA_UNSPEC }, + [TIPC_NLA_MEDIA_NAME] = { .type = NLA_STRING }, + [TIPC_NLA_MEDIA_PROP] = { .type = NLA_NESTED } +}; + +const struct nla_policy tipc_nl_udp_policy[TIPC_NLA_UDP_MAX + 1] = { + [TIPC_NLA_UDP_UNSPEC] = {.type = NLA_UNSPEC}, + [TIPC_NLA_UDP_LOCAL] = {.type = NLA_BINARY, + .len = sizeof(struct sockaddr_storage)}, + [TIPC_NLA_UDP_REMOTE] = {.type = NLA_BINARY, + .len = sizeof(struct sockaddr_storage)}, +}; + +/* Users of the legacy API (tipc-config) can't handle that we add operations, + * so we have a separate genl handling for the new API. + */ +static const struct genl_ops tipc_genl_v2_ops[] = { + { + .cmd = TIPC_NL_BEARER_DISABLE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_bearer_disable, + }, + { + .cmd = TIPC_NL_BEARER_ENABLE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_bearer_enable, + }, + { + .cmd = TIPC_NL_BEARER_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_bearer_get, + .dumpit = tipc_nl_bearer_dump, + }, + { + .cmd = TIPC_NL_BEARER_ADD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_bearer_add, + }, + { + .cmd = TIPC_NL_BEARER_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_bearer_set, + }, + { + .cmd = TIPC_NL_SOCK_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .start = tipc_dump_start, + .dumpit = tipc_nl_sk_dump, + .done = tipc_dump_done, + }, + { + .cmd = TIPC_NL_PUBL_GET, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = tipc_nl_publ_dump, + }, + { + .cmd = TIPC_NL_LINK_GET, + .validate = GENL_DONT_VALIDATE_STRICT, + .doit = tipc_nl_node_get_link, + .dumpit = tipc_nl_node_dump_link, + }, + { + .cmd = TIPC_NL_LINK_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_set_link, + }, + { + .cmd = TIPC_NL_LINK_RESET_STATS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_reset_link_stats, + }, + { + .cmd = TIPC_NL_MEDIA_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_media_get, + .dumpit = tipc_nl_media_dump, + }, + { + .cmd = TIPC_NL_MEDIA_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_media_set, + }, + { + .cmd = TIPC_NL_NODE_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = tipc_nl_node_dump, + }, + { + .cmd = TIPC_NL_NET_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = tipc_nl_net_dump, + }, + { + .cmd = TIPC_NL_NET_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_net_set, + }, + { + .cmd = TIPC_NL_NAME_TABLE_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = tipc_nl_name_table_dump, + }, + { + .cmd = TIPC_NL_MON_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_set_monitor, + }, + { + .cmd = TIPC_NL_MON_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_get_monitor, + .dumpit = tipc_nl_node_dump_monitor, + }, + { + .cmd = TIPC_NL_MON_PEER_GET, + .validate = 
GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = tipc_nl_node_dump_monitor_peer, + }, + { + .cmd = TIPC_NL_PEER_REMOVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_peer_rm, + }, +#ifdef CONFIG_TIPC_MEDIA_UDP + { + .cmd = TIPC_NL_UDP_GET_REMOTEIP, + .validate = GENL_DONT_VALIDATE_STRICT | + GENL_DONT_VALIDATE_DUMP_STRICT, + .dumpit = tipc_udp_nl_dump_remoteip, + }, +#endif +#ifdef CONFIG_TIPC_CRYPTO + { + .cmd = TIPC_NL_KEY_SET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_set_key, + }, + { + .cmd = TIPC_NL_KEY_FLUSH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_node_flush_key, + }, +#endif + { + .cmd = TIPC_NL_ADDR_LEGACY_GET, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_net_addr_legacy_get, + }, +}; + +struct genl_family tipc_genl_family __ro_after_init = { + .name = TIPC_GENL_V2_NAME, + .version = TIPC_GENL_V2_VERSION, + .hdrsize = 0, + .maxattr = TIPC_NLA_MAX, + .policy = tipc_nl_policy, + .netnsok = true, + .module = THIS_MODULE, + .ops = tipc_genl_v2_ops, + .n_ops = ARRAY_SIZE(tipc_genl_v2_ops), + .resv_start_op = TIPC_NL_ADDR_LEGACY_GET + 1, +}; + +int __init tipc_netlink_start(void) +{ + int res; + + res = genl_register_family(&tipc_genl_family); + if (res) { + pr_err("Failed to register netlink interface\n"); + return res; + } + return 0; +} + +void tipc_netlink_stop(void) +{ + genl_unregister_family(&tipc_genl_family); +} diff --git a/net/tipc/netlink.h b/net/tipc/netlink.h new file mode 100644 index 000000000..7cf777723 --- /dev/null +++ b/net/tipc/netlink.h @@ -0,0 +1,64 @@ +/* + * net/tipc/netlink.h: Include file for TIPC netlink code + * + * Copyright (c) 2014, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
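The tipc_genl_v2_ops[] array above is a declarative command table: each TIPC_NL_* command names its doit/dumpit handlers and validation flags, and generic netlink dispatches on that table, so adding a command means adding an entry rather than growing a switch statement. A tiny, generic sketch of table-driven dispatch; the command numbers and handlers are invented and none of genetlink's validation or dump machinery is modelled:

#include <stdio.h>

struct op {
	int cmd;
	int (*doit)(const char *arg);
};

static int cmd_bearer_enable(const char *arg)
{
	printf("enable bearer %s\n", arg);
	return 0;
}

static int cmd_bearer_disable(const char *arg)
{
	printf("disable bearer %s\n", arg);
	return 0;
}

enum { CMD_BEARER_ENABLE = 1, CMD_BEARER_DISABLE };

static const struct op ops[] = {
	{ CMD_BEARER_ENABLE,  cmd_bearer_enable },
	{ CMD_BEARER_DISABLE, cmd_bearer_disable },
};

static int dispatch(int cmd, const char *arg)
{
	unsigned int i;

	for (i = 0; i < sizeof(ops) / sizeof(ops[0]); i++)
		if (ops[i].cmd == cmd)
			return ops[i].doit(arg);
	return -1;	/* unknown command */
}

int main(void)
{
	dispatch(CMD_BEARER_ENABLE, "eth:eth0");
	dispatch(CMD_BEARER_DISABLE, "eth:eth0");
	return 0;
}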
+ */ + +#ifndef _TIPC_NETLINK_H +#define _TIPC_NETLINK_H +#include + +extern struct genl_family tipc_genl_family; + +struct tipc_nl_msg { + struct sk_buff *skb; + u32 portid; + u32 seq; +}; + +extern const struct nla_policy tipc_nl_name_table_policy[]; +extern const struct nla_policy tipc_nl_sock_policy[]; +extern const struct nla_policy tipc_nl_net_policy[]; +extern const struct nla_policy tipc_nl_link_policy[]; +extern const struct nla_policy tipc_nl_node_policy[]; +extern const struct nla_policy tipc_nl_prop_policy[]; +extern const struct nla_policy tipc_nl_bearer_policy[]; +extern const struct nla_policy tipc_nl_media_policy[]; +extern const struct nla_policy tipc_nl_udp_policy[]; +extern const struct nla_policy tipc_nl_monitor_policy[]; + +int tipc_netlink_start(void); +int tipc_netlink_compat_start(void); +void tipc_netlink_stop(void); +void tipc_netlink_compat_stop(void); + +#endif diff --git a/net/tipc/netlink_compat.c b/net/tipc/netlink_compat.c new file mode 100644 index 000000000..9eb7cab6b --- /dev/null +++ b/net/tipc/netlink_compat.c @@ -0,0 +1,1380 @@ +/* + * Copyright (c) 2014, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "bearer.h" +#include "link.h" +#include "name_table.h" +#include "socket.h" +#include "node.h" +#include "net.h" +#include +#include + +/* The legacy API had an artificial message length limit called + * ULTRA_STRING_MAX_LEN. 
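+ * Replies built by this compat layer are therefore capped at
+ * TLV_SPACE(ULTRA_STRING_MAX_LEN) bytes (TIPC_SKB_MAX below):
+ * tipc_skb_tailroom() never reports more room than that, and a dump
+ * which fills the buffer is cut short and marked with REPLY_TRUNCATED
+ * rather than being extended.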
+ */ +#define ULTRA_STRING_MAX_LEN 32768 + +#define TIPC_SKB_MAX TLV_SPACE(ULTRA_STRING_MAX_LEN) + +#define REPLY_TRUNCATED "\n" + +struct tipc_nl_compat_msg { + u16 cmd; + int rep_type; + int rep_size; + int req_type; + int req_size; + struct net *net; + struct sk_buff *rep; + struct tlv_desc *req; + struct sock *dst_sk; +}; + +struct tipc_nl_compat_cmd_dump { + int (*header)(struct tipc_nl_compat_msg *); + int (*dumpit)(struct sk_buff *, struct netlink_callback *); + int (*format)(struct tipc_nl_compat_msg *msg, struct nlattr **attrs); +}; + +struct tipc_nl_compat_cmd_doit { + int (*doit)(struct sk_buff *skb, struct genl_info *info); + int (*transcode)(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, struct tipc_nl_compat_msg *msg); +}; + +static int tipc_skb_tailroom(struct sk_buff *skb) +{ + int tailroom; + int limit; + + tailroom = skb_tailroom(skb); + limit = TIPC_SKB_MAX - skb->len; + + if (tailroom < limit) + return tailroom; + + return limit; +} + +static inline int TLV_GET_DATA_LEN(struct tlv_desc *tlv) +{ + return TLV_GET_LEN(tlv) - TLV_SPACE(0); +} + +static int tipc_add_tlv(struct sk_buff *skb, u16 type, void *data, u16 len) +{ + struct tlv_desc *tlv = (struct tlv_desc *)skb_tail_pointer(skb); + + if (tipc_skb_tailroom(skb) < TLV_SPACE(len)) + return -EMSGSIZE; + + skb_put(skb, TLV_SPACE(len)); + memset(tlv, 0, TLV_SPACE(len)); + tlv->tlv_type = htons(type); + tlv->tlv_len = htons(TLV_LENGTH(len)); + if (len && data) + memcpy(TLV_DATA(tlv), data, len); + + return 0; +} + +static void tipc_tlv_init(struct sk_buff *skb, u16 type) +{ + struct tlv_desc *tlv = (struct tlv_desc *)skb->data; + + TLV_SET_LEN(tlv, 0); + TLV_SET_TYPE(tlv, type); + skb_put(skb, sizeof(struct tlv_desc)); +} + +static __printf(2, 3) int tipc_tlv_sprintf(struct sk_buff *skb, + const char *fmt, ...) +{ + int n; + u16 len; + u32 rem; + char *buf; + struct tlv_desc *tlv; + va_list args; + + rem = tipc_skb_tailroom(skb); + + tlv = (struct tlv_desc *)skb->data; + len = TLV_GET_LEN(tlv); + buf = TLV_DATA(tlv) + len; + + va_start(args, fmt); + n = vscnprintf(buf, rem, fmt, args); + va_end(args); + + TLV_SET_LEN(tlv, n + len); + skb_put(skb, n); + + return n; +} + +static struct sk_buff *tipc_tlv_alloc(int size) +{ + int hdr_len; + struct sk_buff *buf; + + size = TLV_SPACE(size); + hdr_len = nlmsg_total_size(GENL_HDRLEN + TIPC_GENL_HDRLEN); + + buf = alloc_skb(hdr_len + size, GFP_KERNEL); + if (!buf) + return NULL; + + skb_reserve(buf, hdr_len); + + return buf; +} + +static struct sk_buff *tipc_get_err_tlv(char *str) +{ + int str_len = strlen(str) + 1; + struct sk_buff *buf; + + buf = tipc_tlv_alloc(TLV_SPACE(str_len)); + if (buf) + tipc_add_tlv(buf, TIPC_TLV_ERROR_STRING, str, str_len); + + return buf; +} + +static inline bool string_is_valid(char *s, int len) +{ + return memchr(s, '\0', len) ? 
true : false; +} + +static int __tipc_nl_compat_dumpit(struct tipc_nl_compat_cmd_dump *cmd, + struct tipc_nl_compat_msg *msg, + struct sk_buff *arg) +{ + struct genl_dumpit_info info; + int len = 0; + int err; + struct sk_buff *buf; + struct nlmsghdr *nlmsg; + struct netlink_callback cb; + struct nlattr **attrbuf; + + memset(&cb, 0, sizeof(cb)); + cb.nlh = (struct nlmsghdr *)arg->data; + cb.skb = arg; + cb.data = &info; + + buf = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + buf->sk = msg->dst_sk; + if (__tipc_dump_start(&cb, msg->net)) { + kfree_skb(buf); + return -ENOMEM; + } + + attrbuf = kcalloc(tipc_genl_family.maxattr + 1, + sizeof(struct nlattr *), GFP_KERNEL); + if (!attrbuf) { + err = -ENOMEM; + goto err_out; + } + + info.attrs = attrbuf; + + if (nlmsg_len(cb.nlh) > 0) { + err = nlmsg_parse_deprecated(cb.nlh, GENL_HDRLEN, attrbuf, + tipc_genl_family.maxattr, + tipc_genl_family.policy, NULL); + if (err) + goto err_out; + } + do { + int rem; + + len = (*cmd->dumpit)(buf, &cb); + + nlmsg_for_each_msg(nlmsg, nlmsg_hdr(buf), len, rem) { + err = nlmsg_parse_deprecated(nlmsg, GENL_HDRLEN, + attrbuf, + tipc_genl_family.maxattr, + tipc_genl_family.policy, + NULL); + if (err) + goto err_out; + + err = (*cmd->format)(msg, attrbuf); + if (err) + goto err_out; + + if (tipc_skb_tailroom(msg->rep) <= 1) { + err = -EMSGSIZE; + goto err_out; + } + } + + skb_reset_tail_pointer(buf); + buf->len = 0; + + } while (len); + + err = 0; + +err_out: + kfree(attrbuf); + tipc_dump_done(&cb); + kfree_skb(buf); + + if (err == -EMSGSIZE) { + /* The legacy API only considered messages filling + * "ULTRA_STRING_MAX_LEN" to be truncated. + */ + if ((TIPC_SKB_MAX - msg->rep->len) <= 1) { + char *tail = skb_tail_pointer(msg->rep); + + if (*tail != '\0') + sprintf(tail - sizeof(REPLY_TRUNCATED) - 1, + REPLY_TRUNCATED); + } + + return 0; + } + + return err; +} + +static int tipc_nl_compat_dumpit(struct tipc_nl_compat_cmd_dump *cmd, + struct tipc_nl_compat_msg *msg) +{ + struct nlmsghdr *nlh; + struct sk_buff *arg; + int err; + + if (msg->req_type && (!msg->req_size || + !TLV_CHECK_TYPE(msg->req, msg->req_type))) + return -EINVAL; + + msg->rep = tipc_tlv_alloc(msg->rep_size); + if (!msg->rep) + return -ENOMEM; + + if (msg->rep_type) + tipc_tlv_init(msg->rep, msg->rep_type); + + if (cmd->header) { + err = (*cmd->header)(msg); + if (err) { + kfree_skb(msg->rep); + msg->rep = NULL; + return err; + } + } + + arg = nlmsg_new(0, GFP_KERNEL); + if (!arg) { + kfree_skb(msg->rep); + msg->rep = NULL; + return -ENOMEM; + } + + nlh = nlmsg_put(arg, 0, 0, tipc_genl_family.id, 0, NLM_F_MULTI); + if (!nlh) { + kfree_skb(arg); + kfree_skb(msg->rep); + msg->rep = NULL; + return -EMSGSIZE; + } + nlmsg_end(arg, nlh); + + err = __tipc_nl_compat_dumpit(cmd, msg, arg); + if (err) { + kfree_skb(msg->rep); + msg->rep = NULL; + } + kfree_skb(arg); + + return err; +} + +static int __tipc_nl_compat_doit(struct tipc_nl_compat_cmd_doit *cmd, + struct tipc_nl_compat_msg *msg) +{ + int err; + struct sk_buff *doit_buf; + struct sk_buff *trans_buf; + struct nlattr **attrbuf; + struct genl_info info; + + trans_buf = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!trans_buf) + return -ENOMEM; + + attrbuf = kmalloc_array(tipc_genl_family.maxattr + 1, + sizeof(struct nlattr *), + GFP_KERNEL); + if (!attrbuf) { + err = -ENOMEM; + goto trans_out; + } + + doit_buf = alloc_skb(NLMSG_GOODSIZE, GFP_KERNEL); + if (!doit_buf) { + err = -ENOMEM; + goto attrbuf_out; + } + + memset(&info, 0, sizeof(info)); + info.attrs = attrbuf; + + 
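+ /* Everything below runs under rtnl_lock: the legacy TLV request is
+  * transcoded into netlink attributes in trans_buf, those attributes
+  * are parsed into attrbuf, and the regular doit handler is finally
+  * invoked with this locally built genl_info.
+  */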
rtnl_lock(); + err = (*cmd->transcode)(cmd, trans_buf, msg); + if (err) + goto doit_out; + + err = nla_parse_deprecated(attrbuf, tipc_genl_family.maxattr, + (const struct nlattr *)trans_buf->data, + trans_buf->len, NULL, NULL); + if (err) + goto doit_out; + + doit_buf->sk = msg->dst_sk; + + err = (*cmd->doit)(doit_buf, &info); +doit_out: + rtnl_unlock(); + + kfree_skb(doit_buf); +attrbuf_out: + kfree(attrbuf); +trans_out: + kfree_skb(trans_buf); + + return err; +} + +static int tipc_nl_compat_doit(struct tipc_nl_compat_cmd_doit *cmd, + struct tipc_nl_compat_msg *msg) +{ + int err; + + if (msg->req_type && (!msg->req_size || + !TLV_CHECK_TYPE(msg->req, msg->req_type))) + return -EINVAL; + + err = __tipc_nl_compat_doit(cmd, msg); + if (err) + return err; + + /* The legacy API considered an empty message a success message */ + msg->rep = tipc_tlv_alloc(0); + if (!msg->rep) + return -ENOMEM; + + return 0; +} + +static int tipc_nl_compat_bearer_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + struct nlattr *bearer[TIPC_NLA_BEARER_MAX + 1]; + int err; + + if (!attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(bearer, TIPC_NLA_BEARER_MAX, + attrs[TIPC_NLA_BEARER], NULL, NULL); + if (err) + return err; + + return tipc_add_tlv(msg->rep, TIPC_TLV_BEARER_NAME, + nla_data(bearer[TIPC_NLA_BEARER_NAME]), + nla_len(bearer[TIPC_NLA_BEARER_NAME])); +} + +static int tipc_nl_compat_bearer_enable(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + struct nlattr *prop; + struct nlattr *bearer; + struct tipc_bearer_config *b; + int len; + + b = (struct tipc_bearer_config *)TLV_DATA(msg->req); + + bearer = nla_nest_start_noflag(skb, TIPC_NLA_BEARER); + if (!bearer) + return -EMSGSIZE; + + len = TLV_GET_DATA_LEN(msg->req); + len -= offsetof(struct tipc_bearer_config, name); + if (len <= 0) + return -EINVAL; + + len = min_t(int, len, TIPC_MAX_BEARER_NAME); + if (!string_is_valid(b->name, len)) + return -EINVAL; + + if (nla_put_string(skb, TIPC_NLA_BEARER_NAME, b->name)) + return -EMSGSIZE; + + if (nla_put_u32(skb, TIPC_NLA_BEARER_DOMAIN, ntohl(b->disc_domain))) + return -EMSGSIZE; + + if (ntohl(b->priority) <= TIPC_MAX_LINK_PRI) { + prop = nla_nest_start_noflag(skb, TIPC_NLA_BEARER_PROP); + if (!prop) + return -EMSGSIZE; + if (nla_put_u32(skb, TIPC_NLA_PROP_PRIO, ntohl(b->priority))) + return -EMSGSIZE; + nla_nest_end(skb, prop); + } + nla_nest_end(skb, bearer); + + return 0; +} + +static int tipc_nl_compat_bearer_disable(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + char *name; + struct nlattr *bearer; + int len; + + name = (char *)TLV_DATA(msg->req); + + bearer = nla_nest_start_noflag(skb, TIPC_NLA_BEARER); + if (!bearer) + return -EMSGSIZE; + + len = TLV_GET_DATA_LEN(msg->req); + if (len <= 0) + return -EINVAL; + + len = min_t(int, len, TIPC_MAX_BEARER_NAME); + if (!string_is_valid(name, len)) + return -EINVAL; + + if (nla_put_string(skb, TIPC_NLA_BEARER_NAME, name)) + return -EMSGSIZE; + + nla_nest_end(skb, bearer); + + return 0; +} + +static inline u32 perc(u32 count, u32 total) +{ + return (count * 100 + (total / 2)) / total; +} + +static void __fill_bc_link_stat(struct tipc_nl_compat_msg *msg, + struct nlattr *prop[], struct nlattr *stats[]) +{ + tipc_tlv_sprintf(msg->rep, " Window:%u packets\n", + nla_get_u32(prop[TIPC_NLA_PROP_WIN])); + + tipc_tlv_sprintf(msg->rep, + " RX packets:%u fragments:%u/%u bundles:%u/%u\n", + 
nla_get_u32(stats[TIPC_NLA_STATS_RX_INFO]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_FRAGMENTS]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_FRAGMENTED]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_BUNDLES]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_BUNDLED])); + + tipc_tlv_sprintf(msg->rep, + " TX packets:%u fragments:%u/%u bundles:%u/%u\n", + nla_get_u32(stats[TIPC_NLA_STATS_TX_INFO]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_FRAGMENTS]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_FRAGMENTED]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_BUNDLES]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_BUNDLED])); + + tipc_tlv_sprintf(msg->rep, " RX naks:%u defs:%u dups:%u\n", + nla_get_u32(stats[TIPC_NLA_STATS_RX_NACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_DEFERRED]), + nla_get_u32(stats[TIPC_NLA_STATS_DUPLICATES])); + + tipc_tlv_sprintf(msg->rep, " TX naks:%u acks:%u dups:%u\n", + nla_get_u32(stats[TIPC_NLA_STATS_TX_NACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_ACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_RETRANSMITTED])); + + tipc_tlv_sprintf(msg->rep, + " Congestion link:%u Send queue max:%u avg:%u", + nla_get_u32(stats[TIPC_NLA_STATS_LINK_CONGS]), + nla_get_u32(stats[TIPC_NLA_STATS_MAX_QUEUE]), + nla_get_u32(stats[TIPC_NLA_STATS_AVG_QUEUE])); +} + +static int tipc_nl_compat_link_stat_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + char *name; + struct nlattr *link[TIPC_NLA_LINK_MAX + 1]; + struct nlattr *prop[TIPC_NLA_PROP_MAX + 1]; + struct nlattr *stats[TIPC_NLA_STATS_MAX + 1]; + int err; + int len; + + if (!attrs[TIPC_NLA_LINK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(link, TIPC_NLA_LINK_MAX, + attrs[TIPC_NLA_LINK], NULL, NULL); + if (err) + return err; + + if (!link[TIPC_NLA_LINK_PROP]) + return -EINVAL; + + err = nla_parse_nested_deprecated(prop, TIPC_NLA_PROP_MAX, + link[TIPC_NLA_LINK_PROP], NULL, + NULL); + if (err) + return err; + + if (!link[TIPC_NLA_LINK_STATS]) + return -EINVAL; + + err = nla_parse_nested_deprecated(stats, TIPC_NLA_STATS_MAX, + link[TIPC_NLA_LINK_STATS], NULL, + NULL); + if (err) + return err; + + name = (char *)TLV_DATA(msg->req); + + len = TLV_GET_DATA_LEN(msg->req); + if (len <= 0) + return -EINVAL; + + len = min_t(int, len, TIPC_MAX_LINK_NAME); + if (!string_is_valid(name, len)) + return -EINVAL; + + if (strcmp(name, nla_data(link[TIPC_NLA_LINK_NAME])) != 0) + return 0; + + tipc_tlv_sprintf(msg->rep, "\nLink <%s>\n", + (char *)nla_data(link[TIPC_NLA_LINK_NAME])); + + if (link[TIPC_NLA_LINK_BROADCAST]) { + __fill_bc_link_stat(msg, prop, stats); + return 0; + } + + if (link[TIPC_NLA_LINK_ACTIVE]) + tipc_tlv_sprintf(msg->rep, " ACTIVE"); + else if (link[TIPC_NLA_LINK_UP]) + tipc_tlv_sprintf(msg->rep, " STANDBY"); + else + tipc_tlv_sprintf(msg->rep, " DEFUNCT"); + + tipc_tlv_sprintf(msg->rep, " MTU:%u Priority:%u", + nla_get_u32(link[TIPC_NLA_LINK_MTU]), + nla_get_u32(prop[TIPC_NLA_PROP_PRIO])); + + tipc_tlv_sprintf(msg->rep, " Tolerance:%u ms Window:%u packets\n", + nla_get_u32(prop[TIPC_NLA_PROP_TOL]), + nla_get_u32(prop[TIPC_NLA_PROP_WIN])); + + tipc_tlv_sprintf(msg->rep, + " RX packets:%u fragments:%u/%u bundles:%u/%u\n", + nla_get_u32(link[TIPC_NLA_LINK_RX]) - + nla_get_u32(stats[TIPC_NLA_STATS_RX_INFO]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_FRAGMENTS]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_FRAGMENTED]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_BUNDLES]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_BUNDLED])); + + tipc_tlv_sprintf(msg->rep, + " TX packets:%u fragments:%u/%u bundles:%u/%u\n", + nla_get_u32(link[TIPC_NLA_LINK_TX]) - + 
nla_get_u32(stats[TIPC_NLA_STATS_TX_INFO]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_FRAGMENTS]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_FRAGMENTED]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_BUNDLES]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_BUNDLED])); + + tipc_tlv_sprintf(msg->rep, + " TX profile sample:%u packets average:%u octets\n", + nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_CNT]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_TOT]) / + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])); + + tipc_tlv_sprintf(msg->rep, + " 0-64:%u%% -256:%u%% -1024:%u%% -4096:%u%% ", + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P0]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])), + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P1]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])), + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P2]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])), + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P3]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT]))); + + tipc_tlv_sprintf(msg->rep, "-16384:%u%% -32768:%u%% -66000:%u%%\n", + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P4]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])), + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P5]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT])), + perc(nla_get_u32(stats[TIPC_NLA_STATS_MSG_LEN_P6]), + nla_get_u32(stats[TIPC_NLA_STATS_MSG_PROF_TOT]))); + + tipc_tlv_sprintf(msg->rep, + " RX states:%u probes:%u naks:%u defs:%u dups:%u\n", + nla_get_u32(stats[TIPC_NLA_STATS_RX_STATES]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_PROBES]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_NACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_RX_DEFERRED]), + nla_get_u32(stats[TIPC_NLA_STATS_DUPLICATES])); + + tipc_tlv_sprintf(msg->rep, + " TX states:%u probes:%u naks:%u acks:%u dups:%u\n", + nla_get_u32(stats[TIPC_NLA_STATS_TX_STATES]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_PROBES]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_NACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_TX_ACKS]), + nla_get_u32(stats[TIPC_NLA_STATS_RETRANSMITTED])); + + tipc_tlv_sprintf(msg->rep, + " Congestion link:%u Send queue max:%u avg:%u", + nla_get_u32(stats[TIPC_NLA_STATS_LINK_CONGS]), + nla_get_u32(stats[TIPC_NLA_STATS_MAX_QUEUE]), + nla_get_u32(stats[TIPC_NLA_STATS_AVG_QUEUE])); + + return 0; +} + +static int tipc_nl_compat_link_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + struct nlattr *link[TIPC_NLA_LINK_MAX + 1]; + struct tipc_link_info link_info; + int err; + + if (!attrs[TIPC_NLA_LINK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(link, TIPC_NLA_LINK_MAX, + attrs[TIPC_NLA_LINK], NULL, NULL); + if (err) + return err; + + link_info.dest = htonl(nla_get_flag(link[TIPC_NLA_LINK_DEST])); + link_info.up = htonl(nla_get_flag(link[TIPC_NLA_LINK_UP])); + nla_strscpy(link_info.str, link[TIPC_NLA_LINK_NAME], + TIPC_MAX_LINK_NAME); + + return tipc_add_tlv(msg->rep, TIPC_TLV_LINK_INFO, + &link_info, sizeof(link_info)); +} + +static int __tipc_add_link_prop(struct sk_buff *skb, + struct tipc_nl_compat_msg *msg, + struct tipc_link_config *lc) +{ + switch (msg->cmd) { + case TIPC_CMD_SET_LINK_PRI: + return nla_put_u32(skb, TIPC_NLA_PROP_PRIO, ntohl(lc->value)); + case TIPC_CMD_SET_LINK_TOL: + return nla_put_u32(skb, TIPC_NLA_PROP_TOL, ntohl(lc->value)); + case TIPC_CMD_SET_LINK_WINDOW: + return nla_put_u32(skb, TIPC_NLA_PROP_WIN, ntohl(lc->value)); + } + + return -EINVAL; +} + +static int tipc_nl_compat_media_set(struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + struct nlattr *prop; + struct nlattr *media; + 
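+ /* Reached from tipc_nl_compat_link_set() when the name in the legacy
+  * tipc_link_config TLV matches a media rather than a link; the
+  * requested property is re-encoded under TIPC_NLA_MEDIA/TIPC_NLA_MEDIA_PROP
+  * so that __tipc_nl_media_set() can process it.
+  */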
struct tipc_link_config *lc; + + lc = (struct tipc_link_config *)TLV_DATA(msg->req); + + media = nla_nest_start_noflag(skb, TIPC_NLA_MEDIA); + if (!media) + return -EMSGSIZE; + + if (nla_put_string(skb, TIPC_NLA_MEDIA_NAME, lc->name)) + return -EMSGSIZE; + + prop = nla_nest_start_noflag(skb, TIPC_NLA_MEDIA_PROP); + if (!prop) + return -EMSGSIZE; + + __tipc_add_link_prop(skb, msg, lc); + nla_nest_end(skb, prop); + nla_nest_end(skb, media); + + return 0; +} + +static int tipc_nl_compat_bearer_set(struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + struct nlattr *prop; + struct nlattr *bearer; + struct tipc_link_config *lc; + + lc = (struct tipc_link_config *)TLV_DATA(msg->req); + + bearer = nla_nest_start_noflag(skb, TIPC_NLA_BEARER); + if (!bearer) + return -EMSGSIZE; + + if (nla_put_string(skb, TIPC_NLA_BEARER_NAME, lc->name)) + return -EMSGSIZE; + + prop = nla_nest_start_noflag(skb, TIPC_NLA_BEARER_PROP); + if (!prop) + return -EMSGSIZE; + + __tipc_add_link_prop(skb, msg, lc); + nla_nest_end(skb, prop); + nla_nest_end(skb, bearer); + + return 0; +} + +static int __tipc_nl_compat_link_set(struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + struct nlattr *prop; + struct nlattr *link; + struct tipc_link_config *lc; + + lc = (struct tipc_link_config *)TLV_DATA(msg->req); + + link = nla_nest_start_noflag(skb, TIPC_NLA_LINK); + if (!link) + return -EMSGSIZE; + + if (nla_put_string(skb, TIPC_NLA_LINK_NAME, lc->name)) + return -EMSGSIZE; + + prop = nla_nest_start_noflag(skb, TIPC_NLA_LINK_PROP); + if (!prop) + return -EMSGSIZE; + + __tipc_add_link_prop(skb, msg, lc); + nla_nest_end(skb, prop); + nla_nest_end(skb, link); + + return 0; +} + +static int tipc_nl_compat_link_set(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + struct tipc_link_config *lc; + struct tipc_bearer *bearer; + struct tipc_media *media; + int len; + + lc = (struct tipc_link_config *)TLV_DATA(msg->req); + + len = TLV_GET_DATA_LEN(msg->req); + len -= offsetof(struct tipc_link_config, name); + if (len <= 0) + return -EINVAL; + + len = min_t(int, len, TIPC_MAX_LINK_NAME); + if (!string_is_valid(lc->name, len)) + return -EINVAL; + + media = tipc_media_find(lc->name); + if (media) { + cmd->doit = &__tipc_nl_media_set; + return tipc_nl_compat_media_set(skb, msg); + } + + bearer = tipc_bearer_find(msg->net, lc->name); + if (bearer) { + cmd->doit = &__tipc_nl_bearer_set; + return tipc_nl_compat_bearer_set(skb, msg); + } + + return __tipc_nl_compat_link_set(skb, msg); +} + +static int tipc_nl_compat_link_reset_stats(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + char *name; + struct nlattr *link; + int len; + + name = (char *)TLV_DATA(msg->req); + + link = nla_nest_start_noflag(skb, TIPC_NLA_LINK); + if (!link) + return -EMSGSIZE; + + len = TLV_GET_DATA_LEN(msg->req); + if (len <= 0) + return -EINVAL; + + len = min_t(int, len, TIPC_MAX_LINK_NAME); + if (!string_is_valid(name, len)) + return -EINVAL; + + if (nla_put_string(skb, TIPC_NLA_LINK_NAME, name)) + return -EMSGSIZE; + + nla_nest_end(skb, link); + + return 0; +} + +static int tipc_nl_compat_name_table_dump_header(struct tipc_nl_compat_msg *msg) +{ + int i; + u32 depth; + struct tipc_name_table_query *ntq; + static const char * const header[] = { + "Type ", + "Lower Upper ", + "Port Identity ", + "Publication Scope" + }; + + ntq = (struct tipc_name_table_query *)TLV_DATA(msg->req); + if (TLV_GET_DATA_LEN(msg->req) < (int)sizeof(struct tipc_name_table_query)) + 
return -EINVAL; + + depth = ntohl(ntq->depth); + + if (depth > 4) + depth = 4; + for (i = 0; i < depth; i++) + tipc_tlv_sprintf(msg->rep, header[i]); + tipc_tlv_sprintf(msg->rep, "\n"); + + return 0; +} + +static int tipc_nl_compat_name_table_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + char port_str[27]; + struct tipc_name_table_query *ntq; + struct nlattr *nt[TIPC_NLA_NAME_TABLE_MAX + 1]; + struct nlattr *publ[TIPC_NLA_PUBL_MAX + 1]; + u32 node, depth, type, lowbound, upbound; + static const char * const scope_str[] = {"", " zone", " cluster", + " node"}; + int err; + + if (!attrs[TIPC_NLA_NAME_TABLE]) + return -EINVAL; + + err = nla_parse_nested_deprecated(nt, TIPC_NLA_NAME_TABLE_MAX, + attrs[TIPC_NLA_NAME_TABLE], NULL, + NULL); + if (err) + return err; + + if (!nt[TIPC_NLA_NAME_TABLE_PUBL]) + return -EINVAL; + + err = nla_parse_nested_deprecated(publ, TIPC_NLA_PUBL_MAX, + nt[TIPC_NLA_NAME_TABLE_PUBL], NULL, + NULL); + if (err) + return err; + + ntq = (struct tipc_name_table_query *)TLV_DATA(msg->req); + + depth = ntohl(ntq->depth); + type = ntohl(ntq->type); + lowbound = ntohl(ntq->lowbound); + upbound = ntohl(ntq->upbound); + + if (!(depth & TIPC_NTQ_ALLTYPES) && + (type != nla_get_u32(publ[TIPC_NLA_PUBL_TYPE]))) + return 0; + if (lowbound && (lowbound > nla_get_u32(publ[TIPC_NLA_PUBL_UPPER]))) + return 0; + if (upbound && (upbound < nla_get_u32(publ[TIPC_NLA_PUBL_LOWER]))) + return 0; + + tipc_tlv_sprintf(msg->rep, "%-10u ", + nla_get_u32(publ[TIPC_NLA_PUBL_TYPE])); + + if (depth == 1) + goto out; + + tipc_tlv_sprintf(msg->rep, "%-10u %-10u ", + nla_get_u32(publ[TIPC_NLA_PUBL_LOWER]), + nla_get_u32(publ[TIPC_NLA_PUBL_UPPER])); + + if (depth == 2) + goto out; + + node = nla_get_u32(publ[TIPC_NLA_PUBL_NODE]); + sprintf(port_str, "<%u.%u.%u:%u>", tipc_zone(node), tipc_cluster(node), + tipc_node(node), nla_get_u32(publ[TIPC_NLA_PUBL_REF])); + tipc_tlv_sprintf(msg->rep, "%-26s ", port_str); + + if (depth == 3) + goto out; + + tipc_tlv_sprintf(msg->rep, "%-10u %s", + nla_get_u32(publ[TIPC_NLA_PUBL_KEY]), + scope_str[nla_get_u32(publ[TIPC_NLA_PUBL_SCOPE])]); +out: + tipc_tlv_sprintf(msg->rep, "\n"); + + return 0; +} + +static int __tipc_nl_compat_publ_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + u32 type, lower, upper; + struct nlattr *publ[TIPC_NLA_PUBL_MAX + 1]; + int err; + + if (!attrs[TIPC_NLA_PUBL]) + return -EINVAL; + + err = nla_parse_nested_deprecated(publ, TIPC_NLA_PUBL_MAX, + attrs[TIPC_NLA_PUBL], NULL, NULL); + if (err) + return err; + + type = nla_get_u32(publ[TIPC_NLA_PUBL_TYPE]); + lower = nla_get_u32(publ[TIPC_NLA_PUBL_LOWER]); + upper = nla_get_u32(publ[TIPC_NLA_PUBL_UPPER]); + + if (lower == upper) + tipc_tlv_sprintf(msg->rep, " {%u,%u}", type, lower); + else + tipc_tlv_sprintf(msg->rep, " {%u,%u,%u}", type, lower, upper); + + return 0; +} + +static int tipc_nl_compat_publ_dump(struct tipc_nl_compat_msg *msg, u32 sock) +{ + int err; + void *hdr; + struct nlattr *nest; + struct sk_buff *args; + struct tipc_nl_compat_cmd_dump dump; + + args = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!args) + return -ENOMEM; + + hdr = genlmsg_put(args, 0, 0, &tipc_genl_family, NLM_F_MULTI, + TIPC_NL_PUBL_GET); + if (!hdr) { + kfree_skb(args); + return -EMSGSIZE; + } + + nest = nla_nest_start_noflag(args, TIPC_NLA_SOCK); + if (!nest) { + kfree_skb(args); + return -EMSGSIZE; + } + + if (nla_put_u32(args, TIPC_NLA_SOCK_REF, sock)) { + kfree_skb(args); + return -EMSGSIZE; + } + + nla_nest_end(args, nest); + genlmsg_end(args, hdr); + + dump.dumpit = 
tipc_nl_publ_dump; + dump.format = __tipc_nl_compat_publ_dump; + + err = __tipc_nl_compat_dumpit(&dump, msg, args); + + kfree_skb(args); + + return err; +} + +static int tipc_nl_compat_sk_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + int err; + u32 sock_ref; + struct nlattr *sock[TIPC_NLA_SOCK_MAX + 1]; + + if (!attrs[TIPC_NLA_SOCK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(sock, TIPC_NLA_SOCK_MAX, + attrs[TIPC_NLA_SOCK], NULL, NULL); + if (err) + return err; + + sock_ref = nla_get_u32(sock[TIPC_NLA_SOCK_REF]); + tipc_tlv_sprintf(msg->rep, "%u:", sock_ref); + + if (sock[TIPC_NLA_SOCK_CON]) { + u32 node; + struct nlattr *con[TIPC_NLA_CON_MAX + 1]; + + err = nla_parse_nested_deprecated(con, TIPC_NLA_CON_MAX, + sock[TIPC_NLA_SOCK_CON], + NULL, NULL); + + if (err) + return err; + + node = nla_get_u32(con[TIPC_NLA_CON_NODE]); + tipc_tlv_sprintf(msg->rep, " connected to <%u.%u.%u:%u>", + tipc_zone(node), + tipc_cluster(node), + tipc_node(node), + nla_get_u32(con[TIPC_NLA_CON_SOCK])); + + if (con[TIPC_NLA_CON_FLAG]) + tipc_tlv_sprintf(msg->rep, " via {%u,%u}\n", + nla_get_u32(con[TIPC_NLA_CON_TYPE]), + nla_get_u32(con[TIPC_NLA_CON_INST])); + else + tipc_tlv_sprintf(msg->rep, "\n"); + } else if (sock[TIPC_NLA_SOCK_HAS_PUBL]) { + tipc_tlv_sprintf(msg->rep, " bound to"); + + err = tipc_nl_compat_publ_dump(msg, sock_ref); + if (err) + return err; + } + tipc_tlv_sprintf(msg->rep, "\n"); + + return 0; +} + +static int tipc_nl_compat_media_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + struct nlattr *media[TIPC_NLA_MEDIA_MAX + 1]; + int err; + + if (!attrs[TIPC_NLA_MEDIA]) + return -EINVAL; + + err = nla_parse_nested_deprecated(media, TIPC_NLA_MEDIA_MAX, + attrs[TIPC_NLA_MEDIA], NULL, NULL); + if (err) + return err; + + return tipc_add_tlv(msg->rep, TIPC_TLV_MEDIA_NAME, + nla_data(media[TIPC_NLA_MEDIA_NAME]), + nla_len(media[TIPC_NLA_MEDIA_NAME])); +} + +static int tipc_nl_compat_node_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + struct tipc_node_info node_info; + struct nlattr *node[TIPC_NLA_NODE_MAX + 1]; + int err; + + if (!attrs[TIPC_NLA_NODE]) + return -EINVAL; + + err = nla_parse_nested_deprecated(node, TIPC_NLA_NODE_MAX, + attrs[TIPC_NLA_NODE], NULL, NULL); + if (err) + return err; + + node_info.addr = htonl(nla_get_u32(node[TIPC_NLA_NODE_ADDR])); + node_info.up = htonl(nla_get_flag(node[TIPC_NLA_NODE_UP])); + + return tipc_add_tlv(msg->rep, TIPC_TLV_NODE_INFO, &node_info, + sizeof(node_info)); +} + +static int tipc_nl_compat_net_set(struct tipc_nl_compat_cmd_doit *cmd, + struct sk_buff *skb, + struct tipc_nl_compat_msg *msg) +{ + u32 val; + struct nlattr *net; + + val = ntohl(*(__be32 *)TLV_DATA(msg->req)); + + net = nla_nest_start_noflag(skb, TIPC_NLA_NET); + if (!net) + return -EMSGSIZE; + + if (msg->cmd == TIPC_CMD_SET_NODE_ADDR) { + if (nla_put_u32(skb, TIPC_NLA_NET_ADDR, val)) + return -EMSGSIZE; + } else if (msg->cmd == TIPC_CMD_SET_NETID) { + if (nla_put_u32(skb, TIPC_NLA_NET_ID, val)) + return -EMSGSIZE; + } + nla_nest_end(skb, net); + + return 0; +} + +static int tipc_nl_compat_net_dump(struct tipc_nl_compat_msg *msg, + struct nlattr **attrs) +{ + __be32 id; + struct nlattr *net[TIPC_NLA_NET_MAX + 1]; + int err; + + if (!attrs[TIPC_NLA_NET]) + return -EINVAL; + + err = nla_parse_nested_deprecated(net, TIPC_NLA_NET_MAX, + attrs[TIPC_NLA_NET], NULL, NULL); + if (err) + return err; + + id = htonl(nla_get_u32(net[TIPC_NLA_NET_ID])); + + return tipc_add_tlv(msg->rep, TIPC_TLV_UNSIGNED, &id, sizeof(id)); +} + +static 
int tipc_cmd_show_stats_compat(struct tipc_nl_compat_msg *msg) +{ + msg->rep = tipc_tlv_alloc(ULTRA_STRING_MAX_LEN); + if (!msg->rep) + return -ENOMEM; + + tipc_tlv_init(msg->rep, TIPC_TLV_ULTRA_STRING); + tipc_tlv_sprintf(msg->rep, "TIPC version " TIPC_MOD_VER "\n"); + + return 0; +} + +static int tipc_nl_compat_handle(struct tipc_nl_compat_msg *msg) +{ + struct tipc_nl_compat_cmd_dump dump; + struct tipc_nl_compat_cmd_doit doit; + + memset(&dump, 0, sizeof(dump)); + memset(&doit, 0, sizeof(doit)); + + switch (msg->cmd) { + case TIPC_CMD_NOOP: + msg->rep = tipc_tlv_alloc(0); + if (!msg->rep) + return -ENOMEM; + return 0; + case TIPC_CMD_GET_BEARER_NAMES: + msg->rep_size = MAX_BEARERS * TLV_SPACE(TIPC_MAX_BEARER_NAME); + dump.dumpit = tipc_nl_bearer_dump; + dump.format = tipc_nl_compat_bearer_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_ENABLE_BEARER: + msg->req_type = TIPC_TLV_BEARER_CONFIG; + doit.doit = __tipc_nl_bearer_enable; + doit.transcode = tipc_nl_compat_bearer_enable; + return tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_DISABLE_BEARER: + msg->req_type = TIPC_TLV_BEARER_NAME; + doit.doit = __tipc_nl_bearer_disable; + doit.transcode = tipc_nl_compat_bearer_disable; + return tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_SHOW_LINK_STATS: + msg->req_type = TIPC_TLV_LINK_NAME; + msg->rep_size = ULTRA_STRING_MAX_LEN; + msg->rep_type = TIPC_TLV_ULTRA_STRING; + dump.dumpit = tipc_nl_node_dump_link; + dump.format = tipc_nl_compat_link_stat_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_GET_LINKS: + msg->req_type = TIPC_TLV_NET_ADDR; + msg->rep_size = ULTRA_STRING_MAX_LEN; + dump.dumpit = tipc_nl_node_dump_link; + dump.format = tipc_nl_compat_link_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_SET_LINK_TOL: + case TIPC_CMD_SET_LINK_PRI: + case TIPC_CMD_SET_LINK_WINDOW: + msg->req_type = TIPC_TLV_LINK_CONFIG; + doit.doit = tipc_nl_node_set_link; + doit.transcode = tipc_nl_compat_link_set; + return tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_RESET_LINK_STATS: + msg->req_type = TIPC_TLV_LINK_NAME; + doit.doit = tipc_nl_node_reset_link_stats; + doit.transcode = tipc_nl_compat_link_reset_stats; + return tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_SHOW_NAME_TABLE: + msg->req_type = TIPC_TLV_NAME_TBL_QUERY; + msg->rep_size = ULTRA_STRING_MAX_LEN; + msg->rep_type = TIPC_TLV_ULTRA_STRING; + dump.header = tipc_nl_compat_name_table_dump_header; + dump.dumpit = tipc_nl_name_table_dump; + dump.format = tipc_nl_compat_name_table_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_SHOW_PORTS: + msg->rep_size = ULTRA_STRING_MAX_LEN; + msg->rep_type = TIPC_TLV_ULTRA_STRING; + dump.dumpit = tipc_nl_sk_dump; + dump.format = tipc_nl_compat_sk_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_GET_MEDIA_NAMES: + msg->rep_size = MAX_MEDIA * TLV_SPACE(TIPC_MAX_MEDIA_NAME); + dump.dumpit = tipc_nl_media_dump; + dump.format = tipc_nl_compat_media_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_GET_NODES: + msg->rep_size = ULTRA_STRING_MAX_LEN; + dump.dumpit = tipc_nl_node_dump; + dump.format = tipc_nl_compat_node_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_SET_NODE_ADDR: + msg->req_type = TIPC_TLV_NET_ADDR; + doit.doit = __tipc_nl_net_set; + doit.transcode = tipc_nl_compat_net_set; + return tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_SET_NETID: + msg->req_type = TIPC_TLV_UNSIGNED; + doit.doit = __tipc_nl_net_set; + doit.transcode = tipc_nl_compat_net_set; + return 
tipc_nl_compat_doit(&doit, msg); + case TIPC_CMD_GET_NETID: + msg->rep_size = sizeof(u32); + dump.dumpit = tipc_nl_net_dump; + dump.format = tipc_nl_compat_net_dump; + return tipc_nl_compat_dumpit(&dump, msg); + case TIPC_CMD_SHOW_STATS: + return tipc_cmd_show_stats_compat(msg); + } + + return -EOPNOTSUPP; +} + +static int tipc_nl_compat_recv(struct sk_buff *skb, struct genl_info *info) +{ + int err; + int len; + struct tipc_nl_compat_msg msg; + struct nlmsghdr *req_nlh; + struct nlmsghdr *rep_nlh; + struct tipc_genlmsghdr *req_userhdr = info->userhdr; + + memset(&msg, 0, sizeof(msg)); + + req_nlh = (struct nlmsghdr *)skb->data; + msg.req = nlmsg_data(req_nlh) + GENL_HDRLEN + TIPC_GENL_HDRLEN; + msg.cmd = req_userhdr->cmd; + msg.net = genl_info_net(info); + msg.dst_sk = skb->sk; + + if ((msg.cmd & 0xC000) && (!netlink_net_capable(skb, CAP_NET_ADMIN))) { + msg.rep = tipc_get_err_tlv(TIPC_CFG_NOT_NET_ADMIN); + err = -EACCES; + goto send; + } + + msg.req_size = nlmsg_attrlen(req_nlh, GENL_HDRLEN + TIPC_GENL_HDRLEN); + if (msg.req_size && !TLV_OK(msg.req, msg.req_size)) { + msg.rep = tipc_get_err_tlv(TIPC_CFG_NOT_SUPPORTED); + err = -EOPNOTSUPP; + goto send; + } + + err = tipc_nl_compat_handle(&msg); + if ((err == -EOPNOTSUPP) || (err == -EPERM)) + msg.rep = tipc_get_err_tlv(TIPC_CFG_NOT_SUPPORTED); + else if (err == -EINVAL) + msg.rep = tipc_get_err_tlv(TIPC_CFG_TLV_ERROR); +send: + if (!msg.rep) + return err; + + len = nlmsg_total_size(GENL_HDRLEN + TIPC_GENL_HDRLEN); + skb_push(msg.rep, len); + rep_nlh = nlmsg_hdr(msg.rep); + memcpy(rep_nlh, info->nlhdr, len); + rep_nlh->nlmsg_len = msg.rep->len; + genlmsg_unicast(msg.net, msg.rep, NETLINK_CB(skb).portid); + + return err; +} + +static const struct genl_small_ops tipc_genl_compat_ops[] = { + { + .cmd = TIPC_GENL_CMD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = tipc_nl_compat_recv, + }, +}; + +static struct genl_family tipc_genl_compat_family __ro_after_init = { + .name = TIPC_GENL_NAME, + .version = TIPC_GENL_VERSION, + .hdrsize = TIPC_GENL_HDRLEN, + .maxattr = 0, + .netnsok = true, + .module = THIS_MODULE, + .small_ops = tipc_genl_compat_ops, + .n_small_ops = ARRAY_SIZE(tipc_genl_compat_ops), + .resv_start_op = TIPC_GENL_CMD + 1, +}; + +int __init tipc_netlink_compat_start(void) +{ + int res; + + res = genl_register_family(&tipc_genl_compat_family); + if (res) { + pr_err("Failed to register legacy compat interface\n"); + return res; + } + + return 0; +} + +void tipc_netlink_compat_stop(void) +{ + genl_unregister_family(&tipc_genl_compat_family); +} diff --git a/net/tipc/node.c b/net/tipc/node.c new file mode 100644 index 000000000..a9c5b6594 --- /dev/null +++ b/net/tipc/node.c @@ -0,0 +1,3166 @@ +/* + * net/tipc/node.c: TIPC node management routines + * + * Copyright (c) 2000-2006, 2012-2016, Ericsson AB + * Copyright (c) 2005-2006, 2010-2014, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. 
Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "link.h" +#include "node.h" +#include "name_distr.h" +#include "socket.h" +#include "bcast.h" +#include "monitor.h" +#include "discover.h" +#include "netlink.h" +#include "trace.h" +#include "crypto.h" + +#define INVALID_NODE_SIG 0x10000 +#define NODE_CLEANUP_AFTER 300000 + +/* Flags used to take different actions according to flag type + * TIPC_NOTIFY_NODE_DOWN: notify node is down + * TIPC_NOTIFY_NODE_UP: notify node is up + * TIPC_DISTRIBUTE_NAME: publish or withdraw link state name type + */ +enum { + TIPC_NOTIFY_NODE_DOWN = (1 << 3), + TIPC_NOTIFY_NODE_UP = (1 << 4), + TIPC_NOTIFY_LINK_UP = (1 << 6), + TIPC_NOTIFY_LINK_DOWN = (1 << 7) +}; + +struct tipc_link_entry { + struct tipc_link *link; + spinlock_t lock; /* per link */ + u32 mtu; + struct sk_buff_head inputq; + struct tipc_media_addr maddr; +}; + +struct tipc_bclink_entry { + struct tipc_link *link; + struct sk_buff_head inputq1; + struct sk_buff_head arrvq; + struct sk_buff_head inputq2; + struct sk_buff_head namedq; + u16 named_rcv_nxt; + bool named_open; +}; + +/** + * struct tipc_node - TIPC node structure + * @addr: network address of node + * @kref: reference counter to node object + * @lock: rwlock governing access to structure + * @net: the applicable net namespace + * @hash: links to adjacent nodes in unsorted hash chain + * @inputq: pointer to input queue containing messages for msg event + * @namedq: pointer to name table input queue with name table messages + * @active_links: bearer ids of active links, used as index into links[] array + * @links: array containing references to all links to node + * @bc_entry: broadcast link entry + * @action_flags: bit mask of different types of node actions + * @state: connectivity state vs peer node + * @preliminary: a preliminary node or not + * @failover_sent: failover sent or not + * @sync_point: sequence number where synch/failover is finished + * @list: links to adjacent nodes in sorted list of cluster's nodes + * @working_links: number of working links to node (both active and standby) + * @link_cnt: number of links to node + * @capabilities: bitmap, indicating peer node's functional capabilities + * @signature: node instance identifier + * @link_id: local and remote bearer ids of changing link, if any + * @peer_id: 128-bit ID of peer + * @peer_id_string: ID string of peer + * 
@publ_list: list of publications + * @conn_sks: list of connections (FIXME) + * @timer: node's keepalive timer + * @keepalive_intv: keepalive interval in milliseconds + * @rcu: rcu struct for tipc_node + * @delete_at: indicates the time for deleting a down node + * @peer_net: peer's net namespace + * @peer_hash_mix: hash for this peer (FIXME) + * @crypto_rx: RX crypto handler + */ +struct tipc_node { + u32 addr; + struct kref kref; + rwlock_t lock; + struct net *net; + struct hlist_node hash; + int active_links[2]; + struct tipc_link_entry links[MAX_BEARERS]; + struct tipc_bclink_entry bc_entry; + int action_flags; + struct list_head list; + int state; + bool preliminary; + bool failover_sent; + u16 sync_point; + int link_cnt; + u16 working_links; + u16 capabilities; + u32 signature; + u32 link_id; + u8 peer_id[16]; + char peer_id_string[NODE_ID_STR_LEN]; + struct list_head publ_list; + struct list_head conn_sks; + unsigned long keepalive_intv; + struct timer_list timer; + struct rcu_head rcu; + unsigned long delete_at; + struct net *peer_net; + u32 peer_hash_mix; +#ifdef CONFIG_TIPC_CRYPTO + struct tipc_crypto *crypto_rx; +#endif +}; + +/* Node FSM states and events: + */ +enum { + SELF_DOWN_PEER_DOWN = 0xdd, + SELF_UP_PEER_UP = 0xaa, + SELF_DOWN_PEER_LEAVING = 0xd1, + SELF_UP_PEER_COMING = 0xac, + SELF_COMING_PEER_UP = 0xca, + SELF_LEAVING_PEER_DOWN = 0x1d, + NODE_FAILINGOVER = 0xf0, + NODE_SYNCHING = 0xcc +}; + +enum { + SELF_ESTABL_CONTACT_EVT = 0xece, + SELF_LOST_CONTACT_EVT = 0x1ce, + PEER_ESTABL_CONTACT_EVT = 0x9ece, + PEER_LOST_CONTACT_EVT = 0x91ce, + NODE_FAILOVER_BEGIN_EVT = 0xfbe, + NODE_FAILOVER_END_EVT = 0xfee, + NODE_SYNCH_BEGIN_EVT = 0xcbe, + NODE_SYNCH_END_EVT = 0xcee +}; + +static void __tipc_node_link_down(struct tipc_node *n, int *bearer_id, + struct sk_buff_head *xmitq, + struct tipc_media_addr **maddr); +static void tipc_node_link_down(struct tipc_node *n, int bearer_id, + bool delete); +static void node_lost_contact(struct tipc_node *n, struct sk_buff_head *inputq); +static void tipc_node_delete(struct tipc_node *node); +static void tipc_node_timeout(struct timer_list *t); +static void tipc_node_fsm_evt(struct tipc_node *n, int evt); +static struct tipc_node *tipc_node_find(struct net *net, u32 addr); +static struct tipc_node *tipc_node_find_by_id(struct net *net, u8 *id); +static bool node_is_up(struct tipc_node *n); +static void tipc_node_delete_from_list(struct tipc_node *node); + +struct tipc_sock_conn { + u32 port; + u32 peer_port; + u32 peer_node; + struct list_head list; +}; + +static struct tipc_link *node_active_link(struct tipc_node *n, int sel) +{ + int bearer_id = n->active_links[sel & 1]; + + if (unlikely(bearer_id == INVALID_BEARER_ID)) + return NULL; + + return n->links[bearer_id].link; +} + +int tipc_node_get_mtu(struct net *net, u32 addr, u32 sel, bool connected) +{ + struct tipc_node *n; + int bearer_id; + unsigned int mtu = MAX_MSG_SIZE; + + n = tipc_node_find(net, addr); + if (unlikely(!n)) + return mtu; + + /* Allow MAX_MSG_SIZE when building connection oriented message + * if they are in the same core network + */ + if (n->peer_net && connected) { + tipc_node_put(n); + return mtu; + } + + bearer_id = n->active_links[sel & 1]; + if (likely(bearer_id != INVALID_BEARER_ID)) + mtu = n->links[bearer_id].mtu; + tipc_node_put(n); + return mtu; +} + +bool tipc_node_get_id(struct net *net, u32 addr, u8 *id) +{ + u8 *own_id = tipc_own_id(net); + struct tipc_node *n; + + if (!own_id) + return true; + + if (addr == tipc_own_addr(net)) { + memcpy(id, own_id, 
TIPC_NODEID_LEN); + return true; + } + n = tipc_node_find(net, addr); + if (!n) + return false; + + memcpy(id, &n->peer_id, TIPC_NODEID_LEN); + tipc_node_put(n); + return true; +} + +u16 tipc_node_get_capabilities(struct net *net, u32 addr) +{ + struct tipc_node *n; + u16 caps; + + n = tipc_node_find(net, addr); + if (unlikely(!n)) + return TIPC_NODE_CAPABILITIES; + caps = n->capabilities; + tipc_node_put(n); + return caps; +} + +u32 tipc_node_get_addr(struct tipc_node *node) +{ + return (node) ? node->addr : 0; +} + +char *tipc_node_get_id_str(struct tipc_node *node) +{ + return node->peer_id_string; +} + +#ifdef CONFIG_TIPC_CRYPTO +/** + * tipc_node_crypto_rx - Retrieve crypto RX handle from node + * @__n: target tipc_node + * Note: node ref counter must be held first! + */ +struct tipc_crypto *tipc_node_crypto_rx(struct tipc_node *__n) +{ + return (__n) ? __n->crypto_rx : NULL; +} + +struct tipc_crypto *tipc_node_crypto_rx_by_list(struct list_head *pos) +{ + return container_of(pos, struct tipc_node, list)->crypto_rx; +} + +struct tipc_crypto *tipc_node_crypto_rx_by_addr(struct net *net, u32 addr) +{ + struct tipc_node *n; + + n = tipc_node_find(net, addr); + return (n) ? n->crypto_rx : NULL; +} +#endif + +static void tipc_node_free(struct rcu_head *rp) +{ + struct tipc_node *n = container_of(rp, struct tipc_node, rcu); + +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_stop(&n->crypto_rx); +#endif + kfree(n); +} + +static void tipc_node_kref_release(struct kref *kref) +{ + struct tipc_node *n = container_of(kref, struct tipc_node, kref); + + kfree(n->bc_entry.link); + call_rcu(&n->rcu, tipc_node_free); +} + +void tipc_node_put(struct tipc_node *node) +{ + kref_put(&node->kref, tipc_node_kref_release); +} + +void tipc_node_get(struct tipc_node *node) +{ + kref_get(&node->kref); +} + +/* + * tipc_node_find - locate specified node object, if it exists + */ +static struct tipc_node *tipc_node_find(struct net *net, u32 addr) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_node *node; + unsigned int thash = tipc_hashfn(addr); + + rcu_read_lock(); + hlist_for_each_entry_rcu(node, &tn->node_htable[thash], hash) { + if (node->addr != addr || node->preliminary) + continue; + if (!kref_get_unless_zero(&node->kref)) + node = NULL; + break; + } + rcu_read_unlock(); + return node; +} + +/* tipc_node_find_by_id - locate specified node object by its 128-bit id + * Note: this function is called only when a discovery request failed + * to find the node by its 32-bit id, and is not time critical + */ +static struct tipc_node *tipc_node_find_by_id(struct net *net, u8 *id) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_node *n; + bool found = false; + + rcu_read_lock(); + list_for_each_entry_rcu(n, &tn->node_list, list) { + read_lock_bh(&n->lock); + if (!memcmp(id, n->peer_id, 16) && + kref_get_unless_zero(&n->kref)) + found = true; + read_unlock_bh(&n->lock); + if (found) + break; + } + rcu_read_unlock(); + return found ? 
n : NULL; +} + +static void tipc_node_read_lock(struct tipc_node *n) + __acquires(n->lock) +{ + read_lock_bh(&n->lock); +} + +static void tipc_node_read_unlock(struct tipc_node *n) + __releases(n->lock) +{ + read_unlock_bh(&n->lock); +} + +static void tipc_node_write_lock(struct tipc_node *n) + __acquires(n->lock) +{ + write_lock_bh(&n->lock); +} + +static void tipc_node_write_unlock_fast(struct tipc_node *n) + __releases(n->lock) +{ + write_unlock_bh(&n->lock); +} + +static void tipc_node_write_unlock(struct tipc_node *n) + __releases(n->lock) +{ + struct tipc_socket_addr sk; + struct net *net = n->net; + u32 flags = n->action_flags; + struct list_head *publ_list; + struct tipc_uaddr ua; + u32 bearer_id, node; + + if (likely(!flags)) { + write_unlock_bh(&n->lock); + return; + } + + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, TIPC_NODE_SCOPE, + TIPC_LINK_STATE, n->addr, n->addr); + sk.ref = n->link_id; + sk.node = tipc_own_addr(net); + node = n->addr; + bearer_id = n->link_id & 0xffff; + publ_list = &n->publ_list; + + n->action_flags &= ~(TIPC_NOTIFY_NODE_DOWN | TIPC_NOTIFY_NODE_UP | + TIPC_NOTIFY_LINK_DOWN | TIPC_NOTIFY_LINK_UP); + + write_unlock_bh(&n->lock); + + if (flags & TIPC_NOTIFY_NODE_DOWN) + tipc_publ_notify(net, publ_list, node, n->capabilities); + + if (flags & TIPC_NOTIFY_NODE_UP) + tipc_named_node_up(net, node, n->capabilities); + + if (flags & TIPC_NOTIFY_LINK_UP) { + tipc_mon_peer_up(net, node, bearer_id); + tipc_nametbl_publish(net, &ua, &sk, sk.ref); + } + if (flags & TIPC_NOTIFY_LINK_DOWN) { + tipc_mon_peer_down(net, node, bearer_id); + tipc_nametbl_withdraw(net, &ua, &sk, sk.ref); + } +} + +static void tipc_node_assign_peer_net(struct tipc_node *n, u32 hash_mixes) +{ + int net_id = tipc_netid(n->net); + struct tipc_net *tn_peer; + struct net *tmp; + u32 hash_chk; + + if (n->peer_net) + return; + + for_each_net_rcu(tmp) { + tn_peer = tipc_net(tmp); + if (!tn_peer) + continue; + /* Integrity checking whether node exists in namespace or not */ + if (tn_peer->net_id != net_id) + continue; + if (memcmp(n->peer_id, tn_peer->node_id, NODE_ID_LEN)) + continue; + hash_chk = tipc_net_hash_mixes(tmp, tn_peer->random); + if (hash_mixes ^ hash_chk) + continue; + n->peer_net = tmp; + n->peer_hash_mix = hash_mixes; + break; + } +} + +struct tipc_node *tipc_node_create(struct net *net, u32 addr, u8 *peer_id, + u16 capabilities, u32 hash_mixes, + bool preliminary) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_link *l, *snd_l = tipc_bc_sndlink(net); + struct tipc_node *n, *temp_node; + unsigned long intv; + int bearer_id; + int i; + + spin_lock_bh(&tn->node_list_lock); + n = tipc_node_find(net, addr) ?: + tipc_node_find_by_id(net, peer_id); + if (n) { + if (!n->preliminary) + goto update; + if (preliminary) + goto exit; + /* A preliminary node becomes "real" now, refresh its data */ + tipc_node_write_lock(n); + if (!tipc_link_bc_create(net, tipc_own_addr(net), addr, peer_id, U16_MAX, + tipc_link_min_win(snd_l), tipc_link_max_win(snd_l), + n->capabilities, &n->bc_entry.inputq1, + &n->bc_entry.namedq, snd_l, &n->bc_entry.link)) { + pr_warn("Broadcast rcv link refresh failed, no memory\n"); + tipc_node_write_unlock_fast(n); + tipc_node_put(n); + n = NULL; + goto exit; + } + n->preliminary = false; + n->addr = addr; + hlist_del_rcu(&n->hash); + hlist_add_head_rcu(&n->hash, + &tn->node_htable[tipc_hashfn(addr)]); + list_del_rcu(&n->list); + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + if (n->addr < temp_node->addr) + break; + } + list_add_tail_rcu(&n->list, 
&temp_node->list); + tipc_node_write_unlock_fast(n); + +update: + if (n->peer_hash_mix ^ hash_mixes) + tipc_node_assign_peer_net(n, hash_mixes); + if (n->capabilities == capabilities) + goto exit; + /* Same node may come back with new capabilities */ + tipc_node_write_lock(n); + n->capabilities = capabilities; + for (bearer_id = 0; bearer_id < MAX_BEARERS; bearer_id++) { + l = n->links[bearer_id].link; + if (l) + tipc_link_update_caps(l, capabilities); + } + tipc_node_write_unlock_fast(n); + + /* Calculate cluster capabilities */ + tn->capabilities = TIPC_NODE_CAPABILITIES; + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + tn->capabilities &= temp_node->capabilities; + } + + tipc_bcast_toggle_rcast(net, + (tn->capabilities & TIPC_BCAST_RCAST)); + + goto exit; + } + n = kzalloc(sizeof(*n), GFP_ATOMIC); + if (!n) { + pr_warn("Node creation failed, no memory\n"); + goto exit; + } + tipc_nodeid2string(n->peer_id_string, peer_id); +#ifdef CONFIG_TIPC_CRYPTO + if (unlikely(tipc_crypto_start(&n->crypto_rx, net, n))) { + pr_warn("Failed to start crypto RX(%s)!\n", n->peer_id_string); + kfree(n); + n = NULL; + goto exit; + } +#endif + n->addr = addr; + n->preliminary = preliminary; + memcpy(&n->peer_id, peer_id, 16); + n->net = net; + n->peer_net = NULL; + n->peer_hash_mix = 0; + /* Assign kernel local namespace if exists */ + tipc_node_assign_peer_net(n, hash_mixes); + n->capabilities = capabilities; + kref_init(&n->kref); + rwlock_init(&n->lock); + INIT_HLIST_NODE(&n->hash); + INIT_LIST_HEAD(&n->list); + INIT_LIST_HEAD(&n->publ_list); + INIT_LIST_HEAD(&n->conn_sks); + skb_queue_head_init(&n->bc_entry.namedq); + skb_queue_head_init(&n->bc_entry.inputq1); + __skb_queue_head_init(&n->bc_entry.arrvq); + skb_queue_head_init(&n->bc_entry.inputq2); + for (i = 0; i < MAX_BEARERS; i++) + spin_lock_init(&n->links[i].lock); + n->state = SELF_DOWN_PEER_LEAVING; + n->delete_at = jiffies + msecs_to_jiffies(NODE_CLEANUP_AFTER); + n->signature = INVALID_NODE_SIG; + n->active_links[0] = INVALID_BEARER_ID; + n->active_links[1] = INVALID_BEARER_ID; + if (!preliminary && + !tipc_link_bc_create(net, tipc_own_addr(net), addr, peer_id, U16_MAX, + tipc_link_min_win(snd_l), tipc_link_max_win(snd_l), + n->capabilities, &n->bc_entry.inputq1, + &n->bc_entry.namedq, snd_l, &n->bc_entry.link)) { + pr_warn("Broadcast rcv link creation failed, no memory\n"); + tipc_node_put(n); + n = NULL; + goto exit; + } + tipc_node_get(n); + timer_setup(&n->timer, tipc_node_timeout, 0); + /* Start a slow timer anyway, crypto needs it */ + n->keepalive_intv = 10000; + intv = jiffies + msecs_to_jiffies(n->keepalive_intv); + if (!mod_timer(&n->timer, intv)) + tipc_node_get(n); + hlist_add_head_rcu(&n->hash, &tn->node_htable[tipc_hashfn(addr)]); + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + if (n->addr < temp_node->addr) + break; + } + list_add_tail_rcu(&n->list, &temp_node->list); + /* Calculate cluster capabilities */ + tn->capabilities = TIPC_NODE_CAPABILITIES; + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + tn->capabilities &= temp_node->capabilities; + } + tipc_bcast_toggle_rcast(net, (tn->capabilities & TIPC_BCAST_RCAST)); + trace_tipc_node_create(n, true, " "); +exit: + spin_unlock_bh(&tn->node_list_lock); + return n; +} + +static void tipc_node_calculate_timer(struct tipc_node *n, struct tipc_link *l) +{ + unsigned long tol = tipc_link_tolerance(l); + unsigned long intv = ((tol / 4) > 500) ? 
500 : tol / 4; + + /* Link with lowest tolerance determines timer interval */ + if (intv < n->keepalive_intv) + n->keepalive_intv = intv; + + /* Ensure link's abort limit corresponds to current tolerance */ + tipc_link_set_abort_limit(l, tol / n->keepalive_intv); +} + +static void tipc_node_delete_from_list(struct tipc_node *node) +{ +#ifdef CONFIG_TIPC_CRYPTO + tipc_crypto_key_flush(node->crypto_rx); +#endif + list_del_rcu(&node->list); + hlist_del_rcu(&node->hash); + tipc_node_put(node); +} + +static void tipc_node_delete(struct tipc_node *node) +{ + trace_tipc_node_delete(node, true, " "); + tipc_node_delete_from_list(node); + + del_timer_sync(&node->timer); + tipc_node_put(node); +} + +void tipc_node_stop(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_node *node, *t_node; + + spin_lock_bh(&tn->node_list_lock); + list_for_each_entry_safe(node, t_node, &tn->node_list, list) + tipc_node_delete(node); + spin_unlock_bh(&tn->node_list_lock); +} + +void tipc_node_subscribe(struct net *net, struct list_head *subscr, u32 addr) +{ + struct tipc_node *n; + + if (in_own_node(net, addr)) + return; + + n = tipc_node_find(net, addr); + if (!n) { + pr_warn("Node subscribe rejected, unknown node 0x%x\n", addr); + return; + } + tipc_node_write_lock(n); + list_add_tail(subscr, &n->publ_list); + tipc_node_write_unlock_fast(n); + tipc_node_put(n); +} + +void tipc_node_unsubscribe(struct net *net, struct list_head *subscr, u32 addr) +{ + struct tipc_node *n; + + if (in_own_node(net, addr)) + return; + + n = tipc_node_find(net, addr); + if (!n) { + pr_warn("Node unsubscribe rejected, unknown node 0x%x\n", addr); + return; + } + tipc_node_write_lock(n); + list_del_init(subscr); + tipc_node_write_unlock_fast(n); + tipc_node_put(n); +} + +int tipc_node_add_conn(struct net *net, u32 dnode, u32 port, u32 peer_port) +{ + struct tipc_node *node; + struct tipc_sock_conn *conn; + int err = 0; + + if (in_own_node(net, dnode)) + return 0; + + node = tipc_node_find(net, dnode); + if (!node) { + pr_warn("Connecting sock to node 0x%x failed\n", dnode); + return -EHOSTUNREACH; + } + conn = kmalloc(sizeof(*conn), GFP_ATOMIC); + if (!conn) { + err = -EHOSTUNREACH; + goto exit; + } + conn->peer_node = dnode; + conn->port = port; + conn->peer_port = peer_port; + + tipc_node_write_lock(node); + list_add_tail(&conn->list, &node->conn_sks); + tipc_node_write_unlock(node); +exit: + tipc_node_put(node); + return err; +} + +void tipc_node_remove_conn(struct net *net, u32 dnode, u32 port) +{ + struct tipc_node *node; + struct tipc_sock_conn *conn, *safe; + + if (in_own_node(net, dnode)) + return; + + node = tipc_node_find(net, dnode); + if (!node) + return; + + tipc_node_write_lock(node); + list_for_each_entry_safe(conn, safe, &node->conn_sks, list) { + if (port != conn->port) + continue; + list_del(&conn->list); + kfree(conn); + } + tipc_node_write_unlock(node); + tipc_node_put(node); +} + +static void tipc_node_clear_links(struct tipc_node *node) +{ + int i; + + for (i = 0; i < MAX_BEARERS; i++) { + struct tipc_link_entry *le = &node->links[i]; + + if (le->link) { + kfree(le->link); + le->link = NULL; + node->link_cnt--; + } + } +} + +/* tipc_node_cleanup - delete nodes that does not + * have active links for NODE_CLEANUP_AFTER time + */ +static bool tipc_node_cleanup(struct tipc_node *peer) +{ + struct tipc_node *temp_node; + struct tipc_net *tn = tipc_net(peer->net); + bool deleted = false; + + /* If lock held by tipc_node_stop() the node will be deleted anyway */ + if 
(!spin_trylock_bh(&tn->node_list_lock)) + return false; + + tipc_node_write_lock(peer); + + if (!node_is_up(peer) && time_after(jiffies, peer->delete_at)) { + tipc_node_clear_links(peer); + tipc_node_delete_from_list(peer); + deleted = true; + } + tipc_node_write_unlock(peer); + + if (!deleted) { + spin_unlock_bh(&tn->node_list_lock); + return deleted; + } + + /* Calculate cluster capabilities */ + tn->capabilities = TIPC_NODE_CAPABILITIES; + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + tn->capabilities &= temp_node->capabilities; + } + tipc_bcast_toggle_rcast(peer->net, + (tn->capabilities & TIPC_BCAST_RCAST)); + spin_unlock_bh(&tn->node_list_lock); + return deleted; +} + +/* tipc_node_timeout - handle expiration of node timer + */ +static void tipc_node_timeout(struct timer_list *t) +{ + struct tipc_node *n = from_timer(n, t, timer); + struct tipc_link_entry *le; + struct sk_buff_head xmitq; + int remains = n->link_cnt; + int bearer_id; + int rc = 0; + + trace_tipc_node_timeout(n, false, " "); + if (!node_is_up(n) && tipc_node_cleanup(n)) { + /*Removing the reference of Timer*/ + tipc_node_put(n); + return; + } + +#ifdef CONFIG_TIPC_CRYPTO + /* Take any crypto key related actions first */ + tipc_crypto_timeout(n->crypto_rx); +#endif + __skb_queue_head_init(&xmitq); + + /* Initial node interval to value larger (10 seconds), then it will be + * recalculated with link lowest tolerance + */ + tipc_node_read_lock(n); + n->keepalive_intv = 10000; + tipc_node_read_unlock(n); + for (bearer_id = 0; remains && (bearer_id < MAX_BEARERS); bearer_id++) { + tipc_node_read_lock(n); + le = &n->links[bearer_id]; + if (le->link) { + spin_lock_bh(&le->lock); + /* Link tolerance may change asynchronously: */ + tipc_node_calculate_timer(n, le->link); + rc = tipc_link_timeout(le->link, &xmitq); + spin_unlock_bh(&le->lock); + remains--; + } + tipc_node_read_unlock(n); + tipc_bearer_xmit(n->net, bearer_id, &xmitq, &le->maddr, n); + if (rc & TIPC_LINK_DOWN_EVT) + tipc_node_link_down(n, bearer_id, false); + } + mod_timer(&n->timer, jiffies + msecs_to_jiffies(n->keepalive_intv)); +} + +/** + * __tipc_node_link_up - handle addition of link + * @n: target tipc_node + * @bearer_id: id of the bearer + * @xmitq: queue for messages to be xmited on + * Node lock must be held by caller + * Link becomes active (alone or shared) or standby, depending on its priority. + */ +static void __tipc_node_link_up(struct tipc_node *n, int bearer_id, + struct sk_buff_head *xmitq) +{ + int *slot0 = &n->active_links[0]; + int *slot1 = &n->active_links[1]; + struct tipc_link *ol = node_active_link(n, 0); + struct tipc_link *nl = n->links[bearer_id].link; + + if (!nl || tipc_link_is_up(nl)) + return; + + tipc_link_fsm_evt(nl, LINK_ESTABLISH_EVT); + if (!tipc_link_is_up(nl)) + return; + + n->working_links++; + n->action_flags |= TIPC_NOTIFY_LINK_UP; + n->link_id = tipc_link_id(nl); + + /* Leave room for tunnel header when returning 'mtu' to users: */ + n->links[bearer_id].mtu = tipc_link_mss(nl); + + tipc_bearer_add_dest(n->net, bearer_id, n->addr); + tipc_bcast_inc_bearer_dst_cnt(n->net, bearer_id); + + pr_debug("Established link <%s> on network plane %c\n", + tipc_link_name(nl), tipc_link_plane(nl)); + trace_tipc_node_link_up(n, true, " "); + + /* Ensure that a STATE message goes first */ + tipc_link_build_state_msg(nl, xmitq); + + /* First link? 
=> give it both slots */ + if (!ol) { + *slot0 = bearer_id; + *slot1 = bearer_id; + tipc_node_fsm_evt(n, SELF_ESTABL_CONTACT_EVT); + n->action_flags |= TIPC_NOTIFY_NODE_UP; + tipc_link_set_active(nl, true); + tipc_bcast_add_peer(n->net, nl, xmitq); + return; + } + + /* Second link => redistribute slots */ + if (tipc_link_prio(nl) > tipc_link_prio(ol)) { + pr_debug("Old link <%s> becomes standby\n", tipc_link_name(ol)); + *slot0 = bearer_id; + *slot1 = bearer_id; + tipc_link_set_active(nl, true); + tipc_link_set_active(ol, false); + } else if (tipc_link_prio(nl) == tipc_link_prio(ol)) { + tipc_link_set_active(nl, true); + *slot1 = bearer_id; + } else { + pr_debug("New link <%s> is standby\n", tipc_link_name(nl)); + } + + /* Prepare synchronization with first link */ + tipc_link_tnl_prepare(ol, nl, SYNCH_MSG, xmitq); +} + +/** + * tipc_node_link_up - handle addition of link + * @n: target tipc_node + * @bearer_id: id of the bearer + * @xmitq: queue for messages to be xmited on + * + * Link becomes active (alone or shared) or standby, depending on its priority. + */ +static void tipc_node_link_up(struct tipc_node *n, int bearer_id, + struct sk_buff_head *xmitq) +{ + struct tipc_media_addr *maddr; + + tipc_node_write_lock(n); + __tipc_node_link_up(n, bearer_id, xmitq); + maddr = &n->links[bearer_id].maddr; + tipc_bearer_xmit(n->net, bearer_id, xmitq, maddr, n); + tipc_node_write_unlock(n); +} + +/** + * tipc_node_link_failover() - start failover in case "half-failover" + * + * This function is only called in a very special situation where link + * failover can be already started on peer node but not on this node. + * This can happen when e.g.:: + * + * 1. Both links <1A-2A>, <1B-2B> down + * 2. Link endpoint 2A up, but 1A still down (e.g. due to network + * disturbance, wrong session, etc.) + * 3. Link <1B-2B> up + * 4. Link endpoint 2A down (e.g. due to link tolerance timeout) + * 5. Node 2 starts failover onto link <1B-2B> + * + * ==> Node 1 does never start link/node failover! 
+ * + * @n: tipc node structure + * @l: link peer endpoint failingover (- can be NULL) + * @tnl: tunnel link + * @xmitq: queue for messages to be xmited on tnl link later + */ +static void tipc_node_link_failover(struct tipc_node *n, struct tipc_link *l, + struct tipc_link *tnl, + struct sk_buff_head *xmitq) +{ + /* Avoid to be "self-failover" that can never end */ + if (!tipc_link_is_up(tnl)) + return; + + /* Don't rush, failure link may be in the process of resetting */ + if (l && !tipc_link_is_reset(l)) + return; + + tipc_link_fsm_evt(tnl, LINK_SYNCH_END_EVT); + tipc_node_fsm_evt(n, NODE_SYNCH_END_EVT); + + n->sync_point = tipc_link_rcv_nxt(tnl) + (U16_MAX / 2 - 1); + tipc_link_failover_prepare(l, tnl, xmitq); + + if (l) + tipc_link_fsm_evt(l, LINK_FAILOVER_BEGIN_EVT); + tipc_node_fsm_evt(n, NODE_FAILOVER_BEGIN_EVT); +} + +/** + * __tipc_node_link_down - handle loss of link + * @n: target tipc_node + * @bearer_id: id of the bearer + * @xmitq: queue for messages to be xmited on + * @maddr: output media address of the bearer + */ +static void __tipc_node_link_down(struct tipc_node *n, int *bearer_id, + struct sk_buff_head *xmitq, + struct tipc_media_addr **maddr) +{ + struct tipc_link_entry *le = &n->links[*bearer_id]; + int *slot0 = &n->active_links[0]; + int *slot1 = &n->active_links[1]; + int i, highest = 0, prio; + struct tipc_link *l, *_l, *tnl; + + l = n->links[*bearer_id].link; + if (!l || tipc_link_is_reset(l)) + return; + + n->working_links--; + n->action_flags |= TIPC_NOTIFY_LINK_DOWN; + n->link_id = tipc_link_id(l); + + tipc_bearer_remove_dest(n->net, *bearer_id, n->addr); + + pr_debug("Lost link <%s> on network plane %c\n", + tipc_link_name(l), tipc_link_plane(l)); + + /* Select new active link if any available */ + *slot0 = INVALID_BEARER_ID; + *slot1 = INVALID_BEARER_ID; + for (i = 0; i < MAX_BEARERS; i++) { + _l = n->links[i].link; + if (!_l || !tipc_link_is_up(_l)) + continue; + if (_l == l) + continue; + prio = tipc_link_prio(_l); + if (prio < highest) + continue; + if (prio > highest) { + highest = prio; + *slot0 = i; + *slot1 = i; + continue; + } + *slot1 = i; + } + + if (!node_is_up(n)) { + if (tipc_link_peer_is_down(l)) + tipc_node_fsm_evt(n, PEER_LOST_CONTACT_EVT); + tipc_node_fsm_evt(n, SELF_LOST_CONTACT_EVT); + trace_tipc_link_reset(l, TIPC_DUMP_ALL, "link down!"); + tipc_link_fsm_evt(l, LINK_RESET_EVT); + tipc_link_reset(l); + tipc_link_build_reset_msg(l, xmitq); + *maddr = &n->links[*bearer_id].maddr; + node_lost_contact(n, &le->inputq); + tipc_bcast_dec_bearer_dst_cnt(n->net, *bearer_id); + return; + } + tipc_bcast_dec_bearer_dst_cnt(n->net, *bearer_id); + + /* There is still a working link => initiate failover */ + *bearer_id = n->active_links[0]; + tnl = n->links[*bearer_id].link; + tipc_link_fsm_evt(tnl, LINK_SYNCH_END_EVT); + tipc_node_fsm_evt(n, NODE_SYNCH_END_EVT); + n->sync_point = tipc_link_rcv_nxt(tnl) + (U16_MAX / 2 - 1); + tipc_link_tnl_prepare(l, tnl, FAILOVER_MSG, xmitq); + trace_tipc_link_reset(l, TIPC_DUMP_ALL, "link down -> failover!"); + tipc_link_reset(l); + tipc_link_fsm_evt(l, LINK_RESET_EVT); + tipc_link_fsm_evt(l, LINK_FAILOVER_BEGIN_EVT); + tipc_node_fsm_evt(n, NODE_FAILOVER_BEGIN_EVT); + *maddr = &n->links[*bearer_id].maddr; +} + +static void tipc_node_link_down(struct tipc_node *n, int bearer_id, bool delete) +{ + struct tipc_link_entry *le = &n->links[bearer_id]; + struct tipc_media_addr *maddr = NULL; + struct tipc_link *l = le->link; + int old_bearer_id = bearer_id; + struct sk_buff_head xmitq; + + if (!l) + return; + + 
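+	/* Descriptive note (added for clarity): teardown sequence below is,
+	 * under the node write lock, either let __tipc_node_link_down()
+	 * pick a new active link (or declare loss of contact), or simply
+	 * reset a link that is still establishing; optionally free the
+	 * link entry, then transmit any queued protocol messages and
+	 * deliver pending socket buffers outside the lock.
+	 */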
__skb_queue_head_init(&xmitq); + + tipc_node_write_lock(n); + if (!tipc_link_is_establishing(l)) { + __tipc_node_link_down(n, &bearer_id, &xmitq, &maddr); + } else { + /* Defuse pending tipc_node_link_up() */ + tipc_link_reset(l); + tipc_link_fsm_evt(l, LINK_RESET_EVT); + } + if (delete) { + kfree(l); + le->link = NULL; + n->link_cnt--; + } + trace_tipc_node_link_down(n, true, "node link down or deleted!"); + tipc_node_write_unlock(n); + if (delete) + tipc_mon_remove_peer(n->net, n->addr, old_bearer_id); + if (!skb_queue_empty(&xmitq)) + tipc_bearer_xmit(n->net, bearer_id, &xmitq, maddr, n); + tipc_sk_rcv(n->net, &le->inputq); +} + +static bool node_is_up(struct tipc_node *n) +{ + return n->active_links[0] != INVALID_BEARER_ID; +} + +bool tipc_node_is_up(struct net *net, u32 addr) +{ + struct tipc_node *n; + bool retval = false; + + if (in_own_node(net, addr)) + return true; + + n = tipc_node_find(net, addr); + if (!n) + return false; + retval = node_is_up(n); + tipc_node_put(n); + return retval; +} + +static u32 tipc_node_suggest_addr(struct net *net, u32 addr) +{ + struct tipc_node *n; + + addr ^= tipc_net(net)->random; + while ((n = tipc_node_find(net, addr))) { + tipc_node_put(n); + addr++; + } + return addr; +} + +/* tipc_node_try_addr(): Check if addr can be used by peer, suggest other if not + * Returns suggested address if any, otherwise 0 + */ +u32 tipc_node_try_addr(struct net *net, u8 *id, u32 addr) +{ + struct tipc_net *tn = tipc_net(net); + struct tipc_node *n; + bool preliminary; + u32 sugg_addr; + + /* Suggest new address if some other peer is using this one */ + n = tipc_node_find(net, addr); + if (n) { + if (!memcmp(n->peer_id, id, NODE_ID_LEN)) + addr = 0; + tipc_node_put(n); + if (!addr) + return 0; + return tipc_node_suggest_addr(net, addr); + } + + /* Suggest previously used address if peer is known */ + n = tipc_node_find_by_id(net, id); + if (n) { + sugg_addr = n->addr; + preliminary = n->preliminary; + tipc_node_put(n); + if (!preliminary) + return sugg_addr; + } + + /* Even this node may be in conflict */ + if (tn->trial_addr == addr) + return tipc_node_suggest_addr(net, addr); + + return 0; +} + +void tipc_node_check_dest(struct net *net, u32 addr, + u8 *peer_id, struct tipc_bearer *b, + u16 capabilities, u32 signature, u32 hash_mixes, + struct tipc_media_addr *maddr, + bool *respond, bool *dupl_addr) +{ + struct tipc_node *n; + struct tipc_link *l; + struct tipc_link_entry *le; + bool addr_match = false; + bool sign_match = false; + bool link_up = false; + bool link_is_reset = false; + bool accept_addr = false; + bool reset = false; + char *if_name; + unsigned long intv; + u16 session; + + *dupl_addr = false; + *respond = false; + + n = tipc_node_create(net, addr, peer_id, capabilities, hash_mixes, + false); + if (!n) + return; + + tipc_node_write_lock(n); + + le = &n->links[b->identity]; + + /* Prepare to validate requesting node's signature and media address */ + l = le->link; + link_up = l && tipc_link_is_up(l); + link_is_reset = l && tipc_link_is_reset(l); + addr_match = l && !memcmp(&le->maddr, maddr, sizeof(*maddr)); + sign_match = (signature == n->signature); + + /* These three flags give us eight permutations: */ + + if (sign_match && addr_match && link_up) { + /* All is fine. Ignore requests. */ + /* Peer node is not a container/local namespace */ + if (!n->peer_hash_mix) + n->peer_hash_mix = hash_mixes; + } else if (sign_match && addr_match && !link_up) { + /* Respond. 
The link will come up in due time */ + *respond = true; + } else if (sign_match && !addr_match && link_up) { + /* Peer has changed i/f address without rebooting. + * If so, the link will reset soon, and the next + * discovery will be accepted. So we can ignore it. + * It may also be a cloned or malicious peer having + * chosen the same node address and signature as an + * existing one. + * Ignore requests until the link goes down, if ever. + */ + *dupl_addr = true; + } else if (sign_match && !addr_match && !link_up) { + /* Peer link has changed i/f address without rebooting. + * It may also be a cloned or malicious peer; we can't + * distinguish between the two. + * The signature is correct, so we must accept. + */ + accept_addr = true; + *respond = true; + reset = true; + } else if (!sign_match && addr_match && link_up) { + /* Peer node rebooted. Two possibilities: + * - Delayed re-discovery; this link endpoint has already + * reset and re-established contact with the peer, before + * receiving a discovery message from that node. + * (The peer happened to receive one from this node first). + * - The peer came back so fast that our side has not + * discovered it yet. Probing from this side will soon + * reset the link, since there can be no working link + * endpoint at the peer end, and the link will re-establish. + * Accept the signature, since it comes from a known peer. + */ + n->signature = signature; + } else if (!sign_match && addr_match && !link_up) { + /* The peer node has rebooted. + * Accept signature, since it is a known peer. + */ + n->signature = signature; + *respond = true; + } else if (!sign_match && !addr_match && link_up) { + /* Peer rebooted with new address, or a new/duplicate peer. + * Ignore until the link goes down, if ever. + */ + *dupl_addr = true; + } else if (!sign_match && !addr_match && !link_up) { + /* Peer rebooted with new address, or it is a new peer. + * Accept signature and address. 
+ */ + n->signature = signature; + accept_addr = true; + *respond = true; + reset = true; + } + + if (!accept_addr) + goto exit; + + /* Now create new link if not already existing */ + if (!l) { + if (n->link_cnt == 2) + goto exit; + + if_name = strchr(b->name, ':') + 1; + get_random_bytes(&session, sizeof(u16)); + if (!tipc_link_create(net, if_name, b->identity, b->tolerance, + b->net_plane, b->mtu, b->priority, + b->min_win, b->max_win, session, + tipc_own_addr(net), addr, peer_id, + n->capabilities, + tipc_bc_sndlink(n->net), n->bc_entry.link, + &le->inputq, + &n->bc_entry.namedq, &l)) { + *respond = false; + goto exit; + } + trace_tipc_link_reset(l, TIPC_DUMP_ALL, "link created!"); + tipc_link_reset(l); + tipc_link_fsm_evt(l, LINK_RESET_EVT); + if (n->state == NODE_FAILINGOVER) + tipc_link_fsm_evt(l, LINK_FAILOVER_BEGIN_EVT); + link_is_reset = tipc_link_is_reset(l); + le->link = l; + n->link_cnt++; + tipc_node_calculate_timer(n, l); + if (n->link_cnt == 1) { + intv = jiffies + msecs_to_jiffies(n->keepalive_intv); + if (!mod_timer(&n->timer, intv)) + tipc_node_get(n); + } + } + memcpy(&le->maddr, maddr, sizeof(*maddr)); +exit: + tipc_node_write_unlock(n); + if (reset && !link_is_reset) + tipc_node_link_down(n, b->identity, false); + tipc_node_put(n); +} + +void tipc_node_delete_links(struct net *net, int bearer_id) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_node *n; + + rcu_read_lock(); + list_for_each_entry_rcu(n, &tn->node_list, list) { + tipc_node_link_down(n, bearer_id, true); + } + rcu_read_unlock(); +} + +static void tipc_node_reset_links(struct tipc_node *n) +{ + int i; + + pr_warn("Resetting all links to %x\n", n->addr); + + trace_tipc_node_reset_links(n, true, " "); + for (i = 0; i < MAX_BEARERS; i++) { + tipc_node_link_down(n, i, false); + } +} + +/* tipc_node_fsm_evt - node finite state machine + * Determines when contact is allowed with peer node + */ +static void tipc_node_fsm_evt(struct tipc_node *n, int evt) +{ + int state = n->state; + + switch (state) { + case SELF_DOWN_PEER_DOWN: + switch (evt) { + case SELF_ESTABL_CONTACT_EVT: + state = SELF_UP_PEER_COMING; + break; + case PEER_ESTABL_CONTACT_EVT: + state = SELF_COMING_PEER_UP; + break; + case SELF_LOST_CONTACT_EVT: + case PEER_LOST_CONTACT_EVT: + break; + case NODE_SYNCH_END_EVT: + case NODE_SYNCH_BEGIN_EVT: + case NODE_FAILOVER_BEGIN_EVT: + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case SELF_UP_PEER_UP: + switch (evt) { + case SELF_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_LEAVING; + break; + case PEER_LOST_CONTACT_EVT: + state = SELF_LEAVING_PEER_DOWN; + break; + case NODE_SYNCH_BEGIN_EVT: + state = NODE_SYNCHING; + break; + case NODE_FAILOVER_BEGIN_EVT: + state = NODE_FAILINGOVER; + break; + case SELF_ESTABL_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + case NODE_SYNCH_END_EVT: + case NODE_FAILOVER_END_EVT: + break; + default: + goto illegal_evt; + } + break; + case SELF_DOWN_PEER_LEAVING: + switch (evt) { + case PEER_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_DOWN; + break; + case SELF_ESTABL_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + case SELF_LOST_CONTACT_EVT: + break; + case NODE_SYNCH_END_EVT: + case NODE_SYNCH_BEGIN_EVT: + case NODE_FAILOVER_BEGIN_EVT: + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case SELF_UP_PEER_COMING: + switch (evt) { + case PEER_ESTABL_CONTACT_EVT: + state = SELF_UP_PEER_UP; + break; + case SELF_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_DOWN; + break; + case SELF_ESTABL_CONTACT_EVT: + case 
PEER_LOST_CONTACT_EVT: + case NODE_SYNCH_END_EVT: + case NODE_FAILOVER_BEGIN_EVT: + break; + case NODE_SYNCH_BEGIN_EVT: + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case SELF_COMING_PEER_UP: + switch (evt) { + case SELF_ESTABL_CONTACT_EVT: + state = SELF_UP_PEER_UP; + break; + case PEER_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_DOWN; + break; + case SELF_LOST_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + break; + case NODE_SYNCH_END_EVT: + case NODE_SYNCH_BEGIN_EVT: + case NODE_FAILOVER_BEGIN_EVT: + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case SELF_LEAVING_PEER_DOWN: + switch (evt) { + case SELF_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_DOWN; + break; + case SELF_ESTABL_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + case PEER_LOST_CONTACT_EVT: + break; + case NODE_SYNCH_END_EVT: + case NODE_SYNCH_BEGIN_EVT: + case NODE_FAILOVER_BEGIN_EVT: + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + case NODE_FAILINGOVER: + switch (evt) { + case SELF_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_LEAVING; + break; + case PEER_LOST_CONTACT_EVT: + state = SELF_LEAVING_PEER_DOWN; + break; + case NODE_FAILOVER_END_EVT: + state = SELF_UP_PEER_UP; + break; + case NODE_FAILOVER_BEGIN_EVT: + case SELF_ESTABL_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + break; + case NODE_SYNCH_BEGIN_EVT: + case NODE_SYNCH_END_EVT: + default: + goto illegal_evt; + } + break; + case NODE_SYNCHING: + switch (evt) { + case SELF_LOST_CONTACT_EVT: + state = SELF_DOWN_PEER_LEAVING; + break; + case PEER_LOST_CONTACT_EVT: + state = SELF_LEAVING_PEER_DOWN; + break; + case NODE_SYNCH_END_EVT: + state = SELF_UP_PEER_UP; + break; + case NODE_FAILOVER_BEGIN_EVT: + state = NODE_FAILINGOVER; + break; + case NODE_SYNCH_BEGIN_EVT: + case SELF_ESTABL_CONTACT_EVT: + case PEER_ESTABL_CONTACT_EVT: + break; + case NODE_FAILOVER_END_EVT: + default: + goto illegal_evt; + } + break; + default: + pr_err("Unknown node fsm state %x\n", state); + break; + } + trace_tipc_node_fsm(n->peer_id, n->state, state, evt); + n->state = state; + return; + +illegal_evt: + pr_err("Illegal node fsm evt %x in state %x\n", evt, state); + trace_tipc_node_fsm(n->peer_id, n->state, state, evt); +} + +static void node_lost_contact(struct tipc_node *n, + struct sk_buff_head *inputq) +{ + struct tipc_sock_conn *conn, *safe; + struct tipc_link *l; + struct list_head *conns = &n->conn_sks; + struct sk_buff *skb; + uint i; + + pr_debug("Lost contact with %x\n", n->addr); + n->delete_at = jiffies + msecs_to_jiffies(NODE_CLEANUP_AFTER); + trace_tipc_node_lost_contact(n, true, " "); + + /* Clean up broadcast state */ + tipc_bcast_remove_peer(n->net, n->bc_entry.link); + skb_queue_purge(&n->bc_entry.namedq); + + /* Abort any ongoing link failover */ + for (i = 0; i < MAX_BEARERS; i++) { + l = n->links[i].link; + if (l) + tipc_link_fsm_evt(l, LINK_FAILOVER_END_EVT); + } + + /* Notify publications from this node */ + n->action_flags |= TIPC_NOTIFY_NODE_DOWN; + n->peer_net = NULL; + n->peer_hash_mix = 0; + /* Notify sockets connected to node */ + list_for_each_entry_safe(conn, safe, conns, list) { + skb = tipc_msg_create(TIPC_CRITICAL_IMPORTANCE, TIPC_CONN_MSG, + SHORT_H_SIZE, 0, tipc_own_addr(n->net), + conn->peer_node, conn->port, + conn->peer_port, TIPC_ERR_NO_NODE); + if (likely(skb)) + skb_queue_tail(inputq, skb); + list_del(&conn->list); + kfree(conn); + } +} + +/** + * tipc_node_get_linkname - get the name of a link + * + * @net: the applicable net namespace + * @bearer_id: id of the 
bearer + * @addr: peer node address + * @linkname: link name output buffer + * @len: size of @linkname output buffer + * + * Return: 0 on success + */ +int tipc_node_get_linkname(struct net *net, u32 bearer_id, u32 addr, + char *linkname, size_t len) +{ + struct tipc_link *link; + int err = -EINVAL; + struct tipc_node *node = tipc_node_find(net, addr); + + if (!node) + return err; + + if (bearer_id >= MAX_BEARERS) + goto exit; + + tipc_node_read_lock(node); + link = node->links[bearer_id].link; + if (link) { + strncpy(linkname, tipc_link_name(link), len); + err = 0; + } + tipc_node_read_unlock(node); +exit: + tipc_node_put(node); + return err; +} + +/* Caller should hold node lock for the passed node */ +static int __tipc_nl_add_node(struct tipc_nl_msg *msg, struct tipc_node *node) +{ + void *hdr; + struct nlattr *attrs; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + NLM_F_MULTI, TIPC_NL_NODE_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_NODE); + if (!attrs) + goto msg_full; + + if (nla_put_u32(msg->skb, TIPC_NLA_NODE_ADDR, node->addr)) + goto attr_msg_full; + if (node_is_up(node)) + if (nla_put_flag(msg->skb, TIPC_NLA_NODE_UP)) + goto attr_msg_full; + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +static void tipc_lxc_xmit(struct net *peer_net, struct sk_buff_head *list) +{ + struct tipc_msg *hdr = buf_msg(skb_peek(list)); + struct sk_buff_head inputq; + + switch (msg_user(hdr)) { + case TIPC_LOW_IMPORTANCE: + case TIPC_MEDIUM_IMPORTANCE: + case TIPC_HIGH_IMPORTANCE: + case TIPC_CRITICAL_IMPORTANCE: + if (msg_connected(hdr) || msg_named(hdr) || + msg_direct(hdr)) { + tipc_loopback_trace(peer_net, list); + spin_lock_init(&list->lock); + tipc_sk_rcv(peer_net, list); + return; + } + if (msg_mcast(hdr)) { + tipc_loopback_trace(peer_net, list); + skb_queue_head_init(&inputq); + tipc_sk_mcast_rcv(peer_net, list, &inputq); + __skb_queue_purge(list); + skb_queue_purge(&inputq); + return; + } + return; + case MSG_FRAGMENTER: + if (tipc_msg_assemble(list)) { + tipc_loopback_trace(peer_net, list); + skb_queue_head_init(&inputq); + tipc_sk_mcast_rcv(peer_net, list, &inputq); + __skb_queue_purge(list); + skb_queue_purge(&inputq); + } + return; + case GROUP_PROTOCOL: + case CONN_MANAGER: + tipc_loopback_trace(peer_net, list); + spin_lock_init(&list->lock); + tipc_sk_rcv(peer_net, list); + return; + case LINK_PROTOCOL: + case NAME_DISTRIBUTOR: + case TUNNEL_PROTOCOL: + case BCAST_PROTOCOL: + return; + default: + return; + } +} + +/** + * tipc_node_xmit() - general link level function for message sending + * @net: the applicable net namespace + * @list: chain of buffers containing message + * @dnode: address of destination node + * @selector: a number used for deterministic link selection + * Consumes the buffer chain. 
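+ *
+ * A minimal usage sketch (it mirrors tipc_node_xmit_skb() below; skb,
+ * dnode and selector are assumed to be supplied by the caller):
+ *
+ *	struct sk_buff_head list;
+ *
+ *	__skb_queue_head_init(&list);
+ *	__skb_queue_tail(&list, skb);
+ *	rc = tipc_node_xmit(net, &list, dnode, selector);
+ *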
+ * Return: 0 if success, otherwise: -ELINKCONG,-EHOSTUNREACH,-EMSGSIZE,-ENOBUF + */ +int tipc_node_xmit(struct net *net, struct sk_buff_head *list, + u32 dnode, int selector) +{ + struct tipc_link_entry *le = NULL; + struct tipc_node *n; + struct sk_buff_head xmitq; + bool node_up = false; + struct net *peer_net; + int bearer_id; + int rc; + + if (in_own_node(net, dnode)) { + tipc_loopback_trace(net, list); + spin_lock_init(&list->lock); + tipc_sk_rcv(net, list); + return 0; + } + + n = tipc_node_find(net, dnode); + if (unlikely(!n)) { + __skb_queue_purge(list); + return -EHOSTUNREACH; + } + + rcu_read_lock(); + tipc_node_read_lock(n); + node_up = node_is_up(n); + peer_net = n->peer_net; + tipc_node_read_unlock(n); + if (node_up && peer_net && check_net(peer_net)) { + /* xmit inner linux container */ + tipc_lxc_xmit(peer_net, list); + if (likely(skb_queue_empty(list))) { + rcu_read_unlock(); + tipc_node_put(n); + return 0; + } + } + rcu_read_unlock(); + + tipc_node_read_lock(n); + bearer_id = n->active_links[selector & 1]; + if (unlikely(bearer_id == INVALID_BEARER_ID)) { + tipc_node_read_unlock(n); + tipc_node_put(n); + __skb_queue_purge(list); + return -EHOSTUNREACH; + } + + __skb_queue_head_init(&xmitq); + le = &n->links[bearer_id]; + spin_lock_bh(&le->lock); + rc = tipc_link_xmit(le->link, list, &xmitq); + spin_unlock_bh(&le->lock); + tipc_node_read_unlock(n); + + if (unlikely(rc == -ENOBUFS)) + tipc_node_link_down(n, bearer_id, false); + else + tipc_bearer_xmit(net, bearer_id, &xmitq, &le->maddr, n); + + tipc_node_put(n); + + return rc; +} + +/* tipc_node_xmit_skb(): send single buffer to destination + * Buffers sent via this function are generally TIPC_SYSTEM_IMPORTANCE + * messages, which will not be rejected + * The only exception is datagram messages rerouted after secondary + * lookup, which are rare and safe to dispose of anyway. 
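+ * The buffer is wrapped in a one-element queue and handed to
+ * tipc_node_xmit(); the return code of that call is not propagated.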
+ */ +int tipc_node_xmit_skb(struct net *net, struct sk_buff *skb, u32 dnode, + u32 selector) +{ + struct sk_buff_head head; + + __skb_queue_head_init(&head); + __skb_queue_tail(&head, skb); + tipc_node_xmit(net, &head, dnode, selector); + return 0; +} + +/* tipc_node_distr_xmit(): send single buffer msgs to individual destinations + * Note: this is only for SYSTEM_IMPORTANCE messages, which cannot be rejected + */ +int tipc_node_distr_xmit(struct net *net, struct sk_buff_head *xmitq) +{ + struct sk_buff *skb; + u32 selector, dnode; + + while ((skb = __skb_dequeue(xmitq))) { + selector = msg_origport(buf_msg(skb)); + dnode = msg_destnode(buf_msg(skb)); + tipc_node_xmit_skb(net, skb, dnode, selector); + } + return 0; +} + +void tipc_node_broadcast(struct net *net, struct sk_buff *skb, int rc_dests) +{ + struct sk_buff_head xmitq; + struct sk_buff *txskb; + struct tipc_node *n; + u16 dummy; + u32 dst; + + /* Use broadcast if all nodes support it */ + if (!rc_dests && tipc_bcast_get_mode(net) != BCLINK_MODE_RCAST) { + __skb_queue_head_init(&xmitq); + __skb_queue_tail(&xmitq, skb); + tipc_bcast_xmit(net, &xmitq, &dummy); + return; + } + + /* Otherwise use legacy replicast method */ + rcu_read_lock(); + list_for_each_entry_rcu(n, tipc_nodes(net), list) { + dst = n->addr; + if (in_own_node(net, dst)) + continue; + if (!node_is_up(n)) + continue; + txskb = pskb_copy(skb, GFP_ATOMIC); + if (!txskb) + break; + msg_set_destnode(buf_msg(txskb), dst); + tipc_node_xmit_skb(net, txskb, dst, 0); + } + rcu_read_unlock(); + kfree_skb(skb); +} + +static void tipc_node_mcast_rcv(struct tipc_node *n) +{ + struct tipc_bclink_entry *be = &n->bc_entry; + + /* 'arrvq' is under inputq2's lock protection */ + spin_lock_bh(&be->inputq2.lock); + spin_lock_bh(&be->inputq1.lock); + skb_queue_splice_tail_init(&be->inputq1, &be->arrvq); + spin_unlock_bh(&be->inputq1.lock); + spin_unlock_bh(&be->inputq2.lock); + tipc_sk_mcast_rcv(n->net, &be->arrvq, &be->inputq2); +} + +static void tipc_node_bc_sync_rcv(struct tipc_node *n, struct tipc_msg *hdr, + int bearer_id, struct sk_buff_head *xmitq) +{ + struct tipc_link *ucl; + int rc; + + rc = tipc_bcast_sync_rcv(n->net, n->bc_entry.link, hdr, xmitq); + + if (rc & TIPC_LINK_DOWN_EVT) { + tipc_node_reset_links(n); + return; + } + + if (!(rc & TIPC_LINK_SND_STATE)) + return; + + /* If probe message, a STATE response will be sent anyway */ + if (msg_probe(hdr)) + return; + + /* Produce a STATE message carrying broadcast NACK */ + tipc_node_read_lock(n); + ucl = n->links[bearer_id].link; + if (ucl) + tipc_link_build_state_msg(ucl, xmitq); + tipc_node_read_unlock(n); +} + +/** + * tipc_node_bc_rcv - process TIPC broadcast packet arriving from off-node + * @net: the applicable net namespace + * @skb: TIPC packet + * @bearer_id: id of bearer message arrived on + * + * Invoked with no locks held. 
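+ * Broadcast acknowledges requested by the receive path are sent back as
+ * STATE messages on the unicast link of the arrival bearer.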
+ */ +static void tipc_node_bc_rcv(struct net *net, struct sk_buff *skb, int bearer_id) +{ + int rc; + struct sk_buff_head xmitq; + struct tipc_bclink_entry *be; + struct tipc_link_entry *le; + struct tipc_msg *hdr = buf_msg(skb); + int usr = msg_user(hdr); + u32 dnode = msg_destnode(hdr); + struct tipc_node *n; + + __skb_queue_head_init(&xmitq); + + /* If NACK for other node, let rcv link for that node peek into it */ + if ((usr == BCAST_PROTOCOL) && (dnode != tipc_own_addr(net))) + n = tipc_node_find(net, dnode); + else + n = tipc_node_find(net, msg_prevnode(hdr)); + if (!n) { + kfree_skb(skb); + return; + } + be = &n->bc_entry; + le = &n->links[bearer_id]; + + rc = tipc_bcast_rcv(net, be->link, skb); + + /* Broadcast ACKs are sent on a unicast link */ + if (rc & TIPC_LINK_SND_STATE) { + tipc_node_read_lock(n); + tipc_link_build_state_msg(le->link, &xmitq); + tipc_node_read_unlock(n); + } + + if (!skb_queue_empty(&xmitq)) + tipc_bearer_xmit(net, bearer_id, &xmitq, &le->maddr, n); + + if (!skb_queue_empty(&be->inputq1)) + tipc_node_mcast_rcv(n); + + /* Handle NAME_DISTRIBUTOR messages sent from 1.7 nodes */ + if (!skb_queue_empty(&n->bc_entry.namedq)) + tipc_named_rcv(net, &n->bc_entry.namedq, + &n->bc_entry.named_rcv_nxt, + &n->bc_entry.named_open); + + /* If reassembly or retransmission failure => reset all links to peer */ + if (rc & TIPC_LINK_DOWN_EVT) + tipc_node_reset_links(n); + + tipc_node_put(n); +} + +/** + * tipc_node_check_state - check and if necessary update node state + * @n: target tipc_node + * @skb: TIPC packet + * @bearer_id: identity of bearer delivering the packet + * @xmitq: queue for messages to be xmited on + * Return: true if state and msg are ok, otherwise false + */ +static bool tipc_node_check_state(struct tipc_node *n, struct sk_buff *skb, + int bearer_id, struct sk_buff_head *xmitq) +{ + struct tipc_msg *hdr = buf_msg(skb); + int usr = msg_user(hdr); + int mtyp = msg_type(hdr); + u16 oseqno = msg_seqno(hdr); + u16 exp_pkts = msg_msgcnt(hdr); + u16 rcv_nxt, syncpt, dlv_nxt, inputq_len; + int state = n->state; + struct tipc_link *l, *tnl, *pl = NULL; + struct tipc_media_addr *maddr; + int pb_id; + + if (trace_tipc_node_check_state_enabled()) { + trace_tipc_skb_dump(skb, false, "skb for node state check"); + trace_tipc_node_check_state(n, true, " "); + } + l = n->links[bearer_id].link; + if (!l) + return false; + rcv_nxt = tipc_link_rcv_nxt(l); + + + if (likely((state == SELF_UP_PEER_UP) && (usr != TUNNEL_PROTOCOL))) + return true; + + /* Find parallel link, if any */ + for (pb_id = 0; pb_id < MAX_BEARERS; pb_id++) { + if ((pb_id != bearer_id) && n->links[pb_id].link) { + pl = n->links[pb_id].link; + break; + } + } + + if (!tipc_link_validate_msg(l, hdr)) { + trace_tipc_skb_dump(skb, false, "PROTO invalid (2)!"); + trace_tipc_link_dump(l, TIPC_DUMP_NONE, "PROTO invalid (2)!"); + return false; + } + + /* Check and update node accesibility if applicable */ + if (state == SELF_UP_PEER_COMING) { + if (!tipc_link_is_up(l)) + return true; + if (!msg_peer_link_is_up(hdr)) + return true; + tipc_node_fsm_evt(n, PEER_ESTABL_CONTACT_EVT); + } + + if (state == SELF_DOWN_PEER_LEAVING) { + if (msg_peer_node_is_up(hdr)) + return false; + tipc_node_fsm_evt(n, PEER_LOST_CONTACT_EVT); + return true; + } + + if (state == SELF_LEAVING_PEER_DOWN) + return false; + + /* Ignore duplicate packets */ + if ((usr != LINK_PROTOCOL) && less(oseqno, rcv_nxt)) + return true; + + /* Initiate or update failover mode if applicable */ + if ((usr == TUNNEL_PROTOCOL) && (mtyp == FAILOVER_MSG)) { + 
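+		/* Synch point is the seqno of the last expected tunnelled
+		 * packet: this packet's seqno plus the announced packet
+		 * count minus one.
+		 */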
syncpt = oseqno + exp_pkts - 1; + if (pl && !tipc_link_is_reset(pl)) { + __tipc_node_link_down(n, &pb_id, xmitq, &maddr); + trace_tipc_node_link_down(n, true, + "node link down <- failover!"); + tipc_skb_queue_splice_tail_init(tipc_link_inputq(pl), + tipc_link_inputq(l)); + } + + /* If parallel link was already down, and this happened before + * the tunnel link came up, node failover was never started. + * Ensure that a FAILOVER_MSG is sent to get peer out of + * NODE_FAILINGOVER state, also this node must accept + * TUNNEL_MSGs from peer. + */ + if (n->state != NODE_FAILINGOVER) + tipc_node_link_failover(n, pl, l, xmitq); + + /* If pkts arrive out of order, use lowest calculated syncpt */ + if (less(syncpt, n->sync_point)) + n->sync_point = syncpt; + } + + /* Open parallel link when tunnel link reaches synch point */ + if ((n->state == NODE_FAILINGOVER) && tipc_link_is_up(l)) { + if (!more(rcv_nxt, n->sync_point)) + return true; + tipc_node_fsm_evt(n, NODE_FAILOVER_END_EVT); + if (pl) + tipc_link_fsm_evt(pl, LINK_FAILOVER_END_EVT); + return true; + } + + /* No syncing needed if only one link */ + if (!pl || !tipc_link_is_up(pl)) + return true; + + /* Initiate synch mode if applicable */ + if ((usr == TUNNEL_PROTOCOL) && (mtyp == SYNCH_MSG) && (oseqno == 1)) { + if (n->capabilities & TIPC_TUNNEL_ENHANCED) + syncpt = msg_syncpt(hdr); + else + syncpt = msg_seqno(msg_inner_hdr(hdr)) + exp_pkts - 1; + if (!tipc_link_is_up(l)) + __tipc_node_link_up(n, bearer_id, xmitq); + if (n->state == SELF_UP_PEER_UP) { + n->sync_point = syncpt; + tipc_link_fsm_evt(l, LINK_SYNCH_BEGIN_EVT); + tipc_node_fsm_evt(n, NODE_SYNCH_BEGIN_EVT); + } + } + + /* Open tunnel link when parallel link reaches synch point */ + if (n->state == NODE_SYNCHING) { + if (tipc_link_is_synching(l)) { + tnl = l; + } else { + tnl = pl; + pl = l; + } + inputq_len = skb_queue_len(tipc_link_inputq(pl)); + dlv_nxt = tipc_link_rcv_nxt(pl) - inputq_len; + if (more(dlv_nxt, n->sync_point)) { + tipc_link_fsm_evt(tnl, LINK_SYNCH_END_EVT); + tipc_node_fsm_evt(n, NODE_SYNCH_END_EVT); + return true; + } + if (l == pl) + return true; + if ((usr == TUNNEL_PROTOCOL) && (mtyp == SYNCH_MSG)) + return true; + if (usr == LINK_PROTOCOL) + return true; + return false; + } + return true; +} + +/** + * tipc_rcv - process TIPC packets/messages arriving from off-node + * @net: the applicable net namespace + * @skb: TIPC packet + * @b: pointer to bearer message arrived on + * + * Invoked with no locks held. Bearer pointer must point to a valid bearer + * structure (i.e. cannot be NULL), but bearer can be inactive. + */ +void tipc_rcv(struct net *net, struct sk_buff *skb, struct tipc_bearer *b) +{ + struct sk_buff_head xmitq; + struct tipc_link_entry *le; + struct tipc_msg *hdr; + struct tipc_node *n; + int bearer_id = b->identity; + u32 self = tipc_own_addr(net); + int usr, rc = 0; + u16 bc_ack; +#ifdef CONFIG_TIPC_CRYPTO + struct tipc_ehdr *ehdr; + + /* Check if message must be decrypted first */ + if (TIPC_SKB_CB(skb)->decrypted || !tipc_ehdr_validate(skb)) + goto rcv; + + ehdr = (struct tipc_ehdr *)skb->data; + if (likely(ehdr->user != LINK_CONFIG)) { + n = tipc_node_find(net, ntohl(ehdr->addr)); + if (unlikely(!n)) + goto discard; + } else { + n = tipc_node_find_by_id(net, ehdr->id); + } + tipc_crypto_rcv(net, (n) ? 
n->crypto_rx : NULL, &skb, b); + if (!skb) + return; + +rcv: +#endif + /* Ensure message is well-formed before touching the header */ + if (unlikely(!tipc_msg_validate(&skb))) + goto discard; + __skb_queue_head_init(&xmitq); + hdr = buf_msg(skb); + usr = msg_user(hdr); + bc_ack = msg_bcast_ack(hdr); + + /* Handle arrival of discovery or broadcast packet */ + if (unlikely(msg_non_seq(hdr))) { + if (unlikely(usr == LINK_CONFIG)) + return tipc_disc_rcv(net, skb, b); + else + return tipc_node_bc_rcv(net, skb, bearer_id); + } + + /* Discard unicast link messages destined for another node */ + if (unlikely(!msg_short(hdr) && (msg_destnode(hdr) != self))) + goto discard; + + /* Locate neighboring node that sent packet */ + n = tipc_node_find(net, msg_prevnode(hdr)); + if (unlikely(!n)) + goto discard; + le = &n->links[bearer_id]; + + /* Ensure broadcast reception is in synch with peer's send state */ + if (unlikely(usr == LINK_PROTOCOL)) { + if (unlikely(skb_linearize(skb))) { + tipc_node_put(n); + goto discard; + } + hdr = buf_msg(skb); + tipc_node_bc_sync_rcv(n, hdr, bearer_id, &xmitq); + } else if (unlikely(tipc_link_acked(n->bc_entry.link) != bc_ack)) { + tipc_bcast_ack_rcv(net, n->bc_entry.link, hdr); + } + + /* Receive packet directly if conditions permit */ + tipc_node_read_lock(n); + if (likely((n->state == SELF_UP_PEER_UP) && (usr != TUNNEL_PROTOCOL))) { + spin_lock_bh(&le->lock); + if (le->link) { + rc = tipc_link_rcv(le->link, skb, &xmitq); + skb = NULL; + } + spin_unlock_bh(&le->lock); + } + tipc_node_read_unlock(n); + + /* Check/update node state before receiving */ + if (unlikely(skb)) { + if (unlikely(skb_linearize(skb))) + goto out_node_put; + tipc_node_write_lock(n); + if (tipc_node_check_state(n, skb, bearer_id, &xmitq)) { + if (le->link) { + rc = tipc_link_rcv(le->link, skb, &xmitq); + skb = NULL; + } + } + tipc_node_write_unlock(n); + } + + if (unlikely(rc & TIPC_LINK_UP_EVT)) + tipc_node_link_up(n, bearer_id, &xmitq); + + if (unlikely(rc & TIPC_LINK_DOWN_EVT)) + tipc_node_link_down(n, bearer_id, false); + + if (unlikely(!skb_queue_empty(&n->bc_entry.namedq))) + tipc_named_rcv(net, &n->bc_entry.namedq, + &n->bc_entry.named_rcv_nxt, + &n->bc_entry.named_open); + + if (unlikely(!skb_queue_empty(&n->bc_entry.inputq1))) + tipc_node_mcast_rcv(n); + + if (!skb_queue_empty(&le->inputq)) + tipc_sk_rcv(net, &le->inputq); + + if (!skb_queue_empty(&xmitq)) + tipc_bearer_xmit(net, bearer_id, &xmitq, &le->maddr, n); + +out_node_put: + tipc_node_put(n); +discard: + kfree_skb(skb); +} + +void tipc_node_apply_property(struct net *net, struct tipc_bearer *b, + int prop) +{ + struct tipc_net *tn = tipc_net(net); + int bearer_id = b->identity; + struct sk_buff_head xmitq; + struct tipc_link_entry *e; + struct tipc_node *n; + + __skb_queue_head_init(&xmitq); + + rcu_read_lock(); + + list_for_each_entry_rcu(n, &tn->node_list, list) { + tipc_node_write_lock(n); + e = &n->links[bearer_id]; + if (e->link) { + if (prop == TIPC_NLA_PROP_TOL) + tipc_link_set_tolerance(e->link, b->tolerance, + &xmitq); + else if (prop == TIPC_NLA_PROP_MTU) + tipc_link_set_mtu(e->link, b->mtu); + + /* Update MTU for node link entry */ + e->mtu = tipc_link_mss(e->link); + } + + tipc_node_write_unlock(n); + tipc_bearer_xmit(net, bearer_id, &xmitq, &e->maddr, NULL); + } + + rcu_read_unlock(); +} + +int tipc_nl_peer_rm(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct nlattr *attrs[TIPC_NLA_NET_MAX + 1]; + struct tipc_node 
*peer, *temp_node; + u8 node_id[NODE_ID_LEN]; + u64 *w0 = (u64 *)&node_id[0]; + u64 *w1 = (u64 *)&node_id[8]; + u32 addr; + int err; + + /* We identify the peer by its net */ + if (!info->attrs[TIPC_NLA_NET]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_NET_MAX, + info->attrs[TIPC_NLA_NET], + tipc_nl_net_policy, info->extack); + if (err) + return err; + + /* attrs[TIPC_NLA_NET_NODEID] and attrs[TIPC_NLA_NET_ADDR] are + * mutually exclusive cases + */ + if (attrs[TIPC_NLA_NET_ADDR]) { + addr = nla_get_u32(attrs[TIPC_NLA_NET_ADDR]); + if (!addr) + return -EINVAL; + } + + if (attrs[TIPC_NLA_NET_NODEID]) { + if (!attrs[TIPC_NLA_NET_NODEID_W1]) + return -EINVAL; + *w0 = nla_get_u64(attrs[TIPC_NLA_NET_NODEID]); + *w1 = nla_get_u64(attrs[TIPC_NLA_NET_NODEID_W1]); + addr = hash128to32(node_id); + } + + if (in_own_node(net, addr)) + return -ENOTSUPP; + + spin_lock_bh(&tn->node_list_lock); + peer = tipc_node_find(net, addr); + if (!peer) { + spin_unlock_bh(&tn->node_list_lock); + return -ENXIO; + } + + tipc_node_write_lock(peer); + if (peer->state != SELF_DOWN_PEER_DOWN && + peer->state != SELF_DOWN_PEER_LEAVING) { + tipc_node_write_unlock(peer); + err = -EBUSY; + goto err_out; + } + + tipc_node_clear_links(peer); + tipc_node_write_unlock(peer); + tipc_node_delete(peer); + + /* Calculate cluster capabilities */ + tn->capabilities = TIPC_NODE_CAPABILITIES; + list_for_each_entry_rcu(temp_node, &tn->node_list, list) { + tn->capabilities &= temp_node->capabilities; + } + tipc_bcast_toggle_rcast(net, (tn->capabilities & TIPC_BCAST_RCAST)); + err = 0; +err_out: + tipc_node_put(peer); + spin_unlock_bh(&tn->node_list_lock); + + return err; +} + +int tipc_nl_node_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int err; + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = net_generic(net, tipc_net_id); + int done = cb->args[0]; + int last_addr = cb->args[1]; + struct tipc_node *node; + struct tipc_nl_msg msg; + + if (done) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rcu_read_lock(); + if (last_addr) { + node = tipc_node_find(net, last_addr); + if (!node) { + rcu_read_unlock(); + /* We never set seq or call nl_dump_check_consistent() + * this means that setting prev_seq here will cause the + * consistence check to fail in the netlink callback + * handler. Resulting in the NLMSG_DONE message having + * the NLM_F_DUMP_INTR flag set if the node state + * changed while we released the lock. + */ + cb->prev_seq = 1; + return -EPIPE; + } + tipc_node_put(node); + } + + list_for_each_entry_rcu(node, &tn->node_list, list) { + if (node->preliminary) + continue; + if (last_addr) { + if (node->addr == last_addr) + last_addr = 0; + else + continue; + } + + tipc_node_read_lock(node); + err = __tipc_nl_add_node(&msg, node); + if (err) { + last_addr = node->addr; + tipc_node_read_unlock(node); + goto out; + } + + tipc_node_read_unlock(node); + } + done = 1; +out: + cb->args[0] = done; + cb->args[1] = last_addr; + rcu_read_unlock(); + + return skb->len; +} + +/* tipc_node_find_by_name - locate owner node of link by link's name + * @net: the applicable net namespace + * @name: pointer to link name string + * @bearer_id: pointer to index in 'node->links' array where the link was found. + * + * Returns pointer to node owning the link, or 0 if no matching link is found. 
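+ * The node list is scanned under the RCU read lock, and each node's read
+ * lock is taken while its link names are compared.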
+ */ +static struct tipc_node *tipc_node_find_by_name(struct net *net, + const char *link_name, + unsigned int *bearer_id) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_link *l; + struct tipc_node *n; + struct tipc_node *found_node = NULL; + int i; + + *bearer_id = 0; + rcu_read_lock(); + list_for_each_entry_rcu(n, &tn->node_list, list) { + tipc_node_read_lock(n); + for (i = 0; i < MAX_BEARERS; i++) { + l = n->links[i].link; + if (l && !strcmp(tipc_link_name(l), link_name)) { + *bearer_id = i; + found_node = n; + break; + } + } + tipc_node_read_unlock(n); + if (found_node) + break; + } + rcu_read_unlock(); + + return found_node; +} + +int tipc_nl_node_set_link(struct sk_buff *skb, struct genl_info *info) +{ + int err; + int res = 0; + int bearer_id; + char *name; + struct tipc_link *link; + struct tipc_node *node; + struct sk_buff_head xmitq; + struct nlattr *attrs[TIPC_NLA_LINK_MAX + 1]; + struct net *net = sock_net(skb->sk); + + __skb_queue_head_init(&xmitq); + + if (!info->attrs[TIPC_NLA_LINK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_LINK_MAX, + info->attrs[TIPC_NLA_LINK], + tipc_nl_link_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_LINK_NAME]) + return -EINVAL; + + name = nla_data(attrs[TIPC_NLA_LINK_NAME]); + + if (strcmp(name, tipc_bclink_name) == 0) + return tipc_nl_bc_link_set(net, attrs); + + node = tipc_node_find_by_name(net, name, &bearer_id); + if (!node) + return -EINVAL; + + tipc_node_read_lock(node); + + link = node->links[bearer_id].link; + if (!link) { + res = -EINVAL; + goto out; + } + + if (attrs[TIPC_NLA_LINK_PROP]) { + struct nlattr *props[TIPC_NLA_PROP_MAX + 1]; + + err = tipc_nl_parse_link_prop(attrs[TIPC_NLA_LINK_PROP], props); + if (err) { + res = err; + goto out; + } + + if (props[TIPC_NLA_PROP_TOL]) { + u32 tol; + + tol = nla_get_u32(props[TIPC_NLA_PROP_TOL]); + tipc_link_set_tolerance(link, tol, &xmitq); + } + if (props[TIPC_NLA_PROP_PRIO]) { + u32 prio; + + prio = nla_get_u32(props[TIPC_NLA_PROP_PRIO]); + tipc_link_set_prio(link, prio, &xmitq); + } + if (props[TIPC_NLA_PROP_WIN]) { + u32 max_win; + + max_win = nla_get_u32(props[TIPC_NLA_PROP_WIN]); + tipc_link_set_queue_limits(link, + tipc_link_min_win(link), + max_win); + } + } + +out: + tipc_node_read_unlock(node); + tipc_bearer_xmit(net, bearer_id, &xmitq, &node->links[bearer_id].maddr, + NULL); + return res; +} + +int tipc_nl_node_get_link(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = genl_info_net(info); + struct nlattr *attrs[TIPC_NLA_LINK_MAX + 1]; + struct tipc_nl_msg msg; + char *name; + int err; + + msg.portid = info->snd_portid; + msg.seq = info->snd_seq; + + if (!info->attrs[TIPC_NLA_LINK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_LINK_MAX, + info->attrs[TIPC_NLA_LINK], + tipc_nl_link_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_LINK_NAME]) + return -EINVAL; + + name = nla_data(attrs[TIPC_NLA_LINK_NAME]); + + msg.skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!msg.skb) + return -ENOMEM; + + if (strcmp(name, tipc_bclink_name) == 0) { + err = tipc_nl_add_bc_link(net, &msg, tipc_net(net)->bcl); + if (err) + goto err_free; + } else { + int bearer_id; + struct tipc_node *node; + struct tipc_link *link; + + node = tipc_node_find_by_name(net, name, &bearer_id); + if (!node) { + err = -EINVAL; + goto err_free; + } + + tipc_node_read_lock(node); + link = node->links[bearer_id].link; + if (!link) { + tipc_node_read_unlock(node); + err = 
-EINVAL; + goto err_free; + } + + err = __tipc_nl_add_link(net, &msg, link, 0); + tipc_node_read_unlock(node); + if (err) + goto err_free; + } + + return genlmsg_reply(msg.skb, info); + +err_free: + nlmsg_free(msg.skb); + return err; +} + +int tipc_nl_node_reset_link_stats(struct sk_buff *skb, struct genl_info *info) +{ + int err; + char *link_name; + unsigned int bearer_id; + struct tipc_link *link; + struct tipc_node *node; + struct nlattr *attrs[TIPC_NLA_LINK_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = tipc_net(net); + struct tipc_link_entry *le; + + if (!info->attrs[TIPC_NLA_LINK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_LINK_MAX, + info->attrs[TIPC_NLA_LINK], + tipc_nl_link_policy, info->extack); + if (err) + return err; + + if (!attrs[TIPC_NLA_LINK_NAME]) + return -EINVAL; + + link_name = nla_data(attrs[TIPC_NLA_LINK_NAME]); + + err = -EINVAL; + if (!strcmp(link_name, tipc_bclink_name)) { + err = tipc_bclink_reset_stats(net, tipc_bc_sndlink(net)); + if (err) + return err; + return 0; + } else if (strstr(link_name, tipc_bclink_name)) { + rcu_read_lock(); + list_for_each_entry_rcu(node, &tn->node_list, list) { + tipc_node_read_lock(node); + link = node->bc_entry.link; + if (link && !strcmp(link_name, tipc_link_name(link))) { + err = tipc_bclink_reset_stats(net, link); + tipc_node_read_unlock(node); + break; + } + tipc_node_read_unlock(node); + } + rcu_read_unlock(); + return err; + } + + node = tipc_node_find_by_name(net, link_name, &bearer_id); + if (!node) + return -EINVAL; + + le = &node->links[bearer_id]; + tipc_node_read_lock(node); + spin_lock_bh(&le->lock); + link = node->links[bearer_id].link; + if (!link) { + spin_unlock_bh(&le->lock); + tipc_node_read_unlock(node); + return -EINVAL; + } + tipc_link_reset_stats(link); + spin_unlock_bh(&le->lock); + tipc_node_read_unlock(node); + return 0; +} + +/* Caller should hold node lock */ +static int __tipc_nl_add_node_links(struct net *net, struct tipc_nl_msg *msg, + struct tipc_node *node, u32 *prev_link, + bool bc_link) +{ + u32 i; + int err; + + for (i = *prev_link; i < MAX_BEARERS; i++) { + *prev_link = i; + + if (!node->links[i].link) + continue; + + err = __tipc_nl_add_link(net, msg, + node->links[i].link, NLM_F_MULTI); + if (err) + return err; + } + + if (bc_link) { + *prev_link = i; + err = tipc_nl_add_bc_link(net, msg, node->bc_entry.link); + if (err) + return err; + } + + *prev_link = 0; + + return 0; +} + +int tipc_nl_node_dump_link(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct nlattr **attrs = genl_dumpit_info(cb)->attrs; + struct nlattr *link[TIPC_NLA_LINK_MAX + 1]; + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_node *node; + struct tipc_nl_msg msg; + u32 prev_node = cb->args[0]; + u32 prev_link = cb->args[1]; + int done = cb->args[2]; + bool bc_link = cb->args[3]; + int err; + + if (done) + return 0; + + if (!prev_node) { + /* Check if broadcast-receiver links dumping is needed */ + if (attrs && attrs[TIPC_NLA_LINK]) { + err = nla_parse_nested_deprecated(link, + TIPC_NLA_LINK_MAX, + attrs[TIPC_NLA_LINK], + tipc_nl_link_policy, + NULL); + if (unlikely(err)) + return err; + if (unlikely(!link[TIPC_NLA_LINK_BROADCAST])) + return -EINVAL; + bc_link = true; + } + } + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rcu_read_lock(); + if (prev_node) { + node = tipc_node_find(net, prev_node); + if (!node) { + /* We never set seq or call 
nl_dump_check_consistent() + * this means that setting prev_seq here will cause the + * consistence check to fail in the netlink callback + * handler. Resulting in the last NLMSG_DONE message + * having the NLM_F_DUMP_INTR flag set. + */ + cb->prev_seq = 1; + goto out; + } + tipc_node_put(node); + + list_for_each_entry_continue_rcu(node, &tn->node_list, + list) { + tipc_node_read_lock(node); + err = __tipc_nl_add_node_links(net, &msg, node, + &prev_link, bc_link); + tipc_node_read_unlock(node); + if (err) + goto out; + + prev_node = node->addr; + } + } else { + err = tipc_nl_add_bc_link(net, &msg, tn->bcl); + if (err) + goto out; + + list_for_each_entry_rcu(node, &tn->node_list, list) { + tipc_node_read_lock(node); + err = __tipc_nl_add_node_links(net, &msg, node, + &prev_link, bc_link); + tipc_node_read_unlock(node); + if (err) + goto out; + + prev_node = node->addr; + } + } + done = 1; +out: + rcu_read_unlock(); + + cb->args[0] = prev_node; + cb->args[1] = prev_link; + cb->args[2] = done; + cb->args[3] = bc_link; + + return skb->len; +} + +int tipc_nl_node_set_monitor(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[TIPC_NLA_MON_MAX + 1]; + struct net *net = sock_net(skb->sk); + int err; + + if (!info->attrs[TIPC_NLA_MON]) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, TIPC_NLA_MON_MAX, + info->attrs[TIPC_NLA_MON], + tipc_nl_monitor_policy, + info->extack); + if (err) + return err; + + if (attrs[TIPC_NLA_MON_ACTIVATION_THRESHOLD]) { + u32 val; + + val = nla_get_u32(attrs[TIPC_NLA_MON_ACTIVATION_THRESHOLD]); + err = tipc_nl_monitor_set_threshold(net, val); + if (err) + return err; + } + + return 0; +} + +static int __tipc_nl_add_monitor_prop(struct net *net, struct tipc_nl_msg *msg) +{ + struct nlattr *attrs; + void *hdr; + u32 val; + + hdr = genlmsg_put(msg->skb, msg->portid, msg->seq, &tipc_genl_family, + 0, TIPC_NL_MON_GET); + if (!hdr) + return -EMSGSIZE; + + attrs = nla_nest_start_noflag(msg->skb, TIPC_NLA_MON); + if (!attrs) + goto msg_full; + + val = tipc_nl_monitor_get_threshold(net); + + if (nla_put_u32(msg->skb, TIPC_NLA_MON_ACTIVATION_THRESHOLD, val)) + goto attr_msg_full; + + nla_nest_end(msg->skb, attrs); + genlmsg_end(msg->skb, hdr); + + return 0; + +attr_msg_full: + nla_nest_cancel(msg->skb, attrs); +msg_full: + genlmsg_cancel(msg->skb, hdr); + + return -EMSGSIZE; +} + +int tipc_nl_node_get_monitor(struct sk_buff *skb, struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct tipc_nl_msg msg; + int err; + + msg.skb = nlmsg_new(NLMSG_GOODSIZE, GFP_KERNEL); + if (!msg.skb) + return -ENOMEM; + msg.portid = info->snd_portid; + msg.seq = info->snd_seq; + + err = __tipc_nl_add_monitor_prop(net, &msg); + if (err) { + nlmsg_free(msg.skb); + return err; + } + + return genlmsg_reply(msg.skb, info); +} + +int tipc_nl_node_dump_monitor(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + u32 prev_bearer = cb->args[0]; + struct tipc_nl_msg msg; + int bearer_id; + int err; + + if (prev_bearer == MAX_BEARERS) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rtnl_lock(); + for (bearer_id = prev_bearer; bearer_id < MAX_BEARERS; bearer_id++) { + err = __tipc_nl_add_monitor(net, &msg, bearer_id); + if (err) + break; + } + rtnl_unlock(); + cb->args[0] = bearer_id; + + return skb->len; +} + +int tipc_nl_node_dump_monitor_peer(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + u32 prev_node = 
cb->args[1]; + u32 bearer_id = cb->args[2]; + int done = cb->args[0]; + struct tipc_nl_msg msg; + int err; + + if (!prev_node) { + struct nlattr **attrs = genl_dumpit_info(cb)->attrs; + struct nlattr *mon[TIPC_NLA_MON_MAX + 1]; + + if (!attrs[TIPC_NLA_MON]) + return -EINVAL; + + err = nla_parse_nested_deprecated(mon, TIPC_NLA_MON_MAX, + attrs[TIPC_NLA_MON], + tipc_nl_monitor_policy, + NULL); + if (err) + return err; + + if (!mon[TIPC_NLA_MON_REF]) + return -EINVAL; + + bearer_id = nla_get_u32(mon[TIPC_NLA_MON_REF]); + + if (bearer_id >= MAX_BEARERS) + return -EINVAL; + } + + if (done) + return 0; + + msg.skb = skb; + msg.portid = NETLINK_CB(cb->skb).portid; + msg.seq = cb->nlh->nlmsg_seq; + + rtnl_lock(); + err = tipc_nl_add_monitor_peer(net, &msg, bearer_id, &prev_node); + if (!err) + done = 1; + + rtnl_unlock(); + cb->args[0] = done; + cb->args[1] = prev_node; + cb->args[2] = bearer_id; + + return skb->len; +} + +#ifdef CONFIG_TIPC_CRYPTO +static int tipc_nl_retrieve_key(struct nlattr **attrs, + struct tipc_aead_key **pkey) +{ + struct nlattr *attr = attrs[TIPC_NLA_NODE_KEY]; + struct tipc_aead_key *key; + + if (!attr) + return -ENODATA; + + if (nla_len(attr) < sizeof(*key)) + return -EINVAL; + key = (struct tipc_aead_key *)nla_data(attr); + if (key->keylen > TIPC_AEAD_KEYLEN_MAX || + nla_len(attr) < tipc_aead_key_size(key)) + return -EINVAL; + + *pkey = key; + return 0; +} + +static int tipc_nl_retrieve_nodeid(struct nlattr **attrs, u8 **node_id) +{ + struct nlattr *attr = attrs[TIPC_NLA_NODE_ID]; + + if (!attr) + return -ENODATA; + + if (nla_len(attr) < TIPC_NODEID_LEN) + return -EINVAL; + + *node_id = (u8 *)nla_data(attr); + return 0; +} + +static int tipc_nl_retrieve_rekeying(struct nlattr **attrs, u32 *intv) +{ + struct nlattr *attr = attrs[TIPC_NLA_NODE_REKEYING]; + + if (!attr) + return -ENODATA; + + *intv = nla_get_u32(attr); + return 0; +} + +static int __tipc_nl_node_set_key(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *attrs[TIPC_NLA_NODE_MAX + 1]; + struct net *net = sock_net(skb->sk); + struct tipc_crypto *tx = tipc_net(net)->crypto_tx, *c = tx; + struct tipc_node *n = NULL; + struct tipc_aead_key *ukey; + bool rekeying = true, master_key = false; + u8 *id, *own_id, mode; + u32 intv = 0; + int rc = 0; + + if (!info->attrs[TIPC_NLA_NODE]) + return -EINVAL; + + rc = nla_parse_nested(attrs, TIPC_NLA_NODE_MAX, + info->attrs[TIPC_NLA_NODE], + tipc_nl_node_policy, info->extack); + if (rc) + return rc; + + own_id = tipc_own_id(net); + if (!own_id) { + GENL_SET_ERR_MSG(info, "not found own node identity (set id?)"); + return -EPERM; + } + + rc = tipc_nl_retrieve_rekeying(attrs, &intv); + if (rc == -ENODATA) + rekeying = false; + + rc = tipc_nl_retrieve_key(attrs, &ukey); + if (rc == -ENODATA && rekeying) + goto rekeying; + else if (rc) + return rc; + + rc = tipc_aead_key_validate(ukey, info); + if (rc) + return rc; + + rc = tipc_nl_retrieve_nodeid(attrs, &id); + switch (rc) { + case -ENODATA: + mode = CLUSTER_KEY; + master_key = !!(attrs[TIPC_NLA_NODE_KEY_MASTER]); + break; + case 0: + mode = PER_NODE_KEY; + if (memcmp(id, own_id, NODE_ID_LEN)) { + n = tipc_node_find_by_id(net, id) ?: + tipc_node_create(net, 0, id, 0xffffu, 0, true); + if (unlikely(!n)) + return -ENOMEM; + c = n->crypto_rx; + } + break; + default: + return rc; + } + + /* Initiate the TX/RX key */ + rc = tipc_crypto_key_init(c, ukey, mode, master_key); + if (n) + tipc_node_put(n); + + if (unlikely(rc < 0)) { + GENL_SET_ERR_MSG(info, "unable to initiate or attach new key"); + return rc; + } else if 
(c == tx) { + /* Distribute TX key but not master one */ + if (!master_key && tipc_crypto_key_distr(tx, rc, NULL)) + GENL_SET_ERR_MSG(info, "failed to replicate new key"); +rekeying: + /* Schedule TX rekeying if needed */ + tipc_crypto_rekeying_sched(tx, rekeying, intv); + } + + return 0; +} + +int tipc_nl_node_set_key(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_node_set_key(skb, info); + rtnl_unlock(); + + return err; +} + +static int __tipc_nl_node_flush_key(struct sk_buff *skb, + struct genl_info *info) +{ + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = tipc_net(net); + struct tipc_node *n; + + tipc_crypto_key_flush(tn->crypto_tx); + rcu_read_lock(); + list_for_each_entry_rcu(n, &tn->node_list, list) + tipc_crypto_key_flush(n->crypto_rx); + rcu_read_unlock(); + + return 0; +} + +int tipc_nl_node_flush_key(struct sk_buff *skb, struct genl_info *info) +{ + int err; + + rtnl_lock(); + err = __tipc_nl_node_flush_key(skb, info); + rtnl_unlock(); + + return err; +} +#endif + +/** + * tipc_node_dump - dump TIPC node data + * @n: tipc node to be dumped + * @more: dump more? + * - false: dump only tipc node data + * - true: dump node link data as well + * @buf: returned buffer of dump data in format + */ +int tipc_node_dump(struct tipc_node *n, bool more, char *buf) +{ + int i = 0; + size_t sz = (more) ? NODE_LMAX : NODE_LMIN; + + if (!n) { + i += scnprintf(buf, sz, "node data: (null)\n"); + return i; + } + + i += scnprintf(buf, sz, "node data: %x", n->addr); + i += scnprintf(buf + i, sz - i, " %x", n->state); + i += scnprintf(buf + i, sz - i, " %d", n->active_links[0]); + i += scnprintf(buf + i, sz - i, " %d", n->active_links[1]); + i += scnprintf(buf + i, sz - i, " %x", n->action_flags); + i += scnprintf(buf + i, sz - i, " %u", n->failover_sent); + i += scnprintf(buf + i, sz - i, " %u", n->sync_point); + i += scnprintf(buf + i, sz - i, " %d", n->link_cnt); + i += scnprintf(buf + i, sz - i, " %u", n->working_links); + i += scnprintf(buf + i, sz - i, " %x", n->capabilities); + i += scnprintf(buf + i, sz - i, " %lu\n", n->keepalive_intv); + + if (!more) + return i; + + i += scnprintf(buf + i, sz - i, "link_entry[0]:\n"); + i += scnprintf(buf + i, sz - i, " mtu: %u\n", n->links[0].mtu); + i += scnprintf(buf + i, sz - i, " media: "); + i += tipc_media_addr_printf(buf + i, sz - i, &n->links[0].maddr); + i += scnprintf(buf + i, sz - i, "\n"); + i += tipc_link_dump(n->links[0].link, TIPC_DUMP_NONE, buf + i); + i += scnprintf(buf + i, sz - i, " inputq: "); + i += tipc_list_dump(&n->links[0].inputq, false, buf + i); + + i += scnprintf(buf + i, sz - i, "link_entry[1]:\n"); + i += scnprintf(buf + i, sz - i, " mtu: %u\n", n->links[1].mtu); + i += scnprintf(buf + i, sz - i, " media: "); + i += tipc_media_addr_printf(buf + i, sz - i, &n->links[1].maddr); + i += scnprintf(buf + i, sz - i, "\n"); + i += tipc_link_dump(n->links[1].link, TIPC_DUMP_NONE, buf + i); + i += scnprintf(buf + i, sz - i, " inputq: "); + i += tipc_list_dump(&n->links[1].inputq, false, buf + i); + + i += scnprintf(buf + i, sz - i, "bclink:\n "); + i += tipc_link_dump(n->bc_entry.link, TIPC_DUMP_NONE, buf + i); + + return i; +} + +void tipc_node_pre_cleanup_net(struct net *exit_net) +{ + struct tipc_node *n; + struct tipc_net *tn; + struct net *tmp; + + rcu_read_lock(); + for_each_net_rcu(tmp) { + if (tmp == exit_net) + continue; + tn = tipc_net(tmp); + if (!tn) + continue; + spin_lock_bh(&tn->node_list_lock); + list_for_each_entry_rcu(n, &tn->node_list, list) { + if 
(!n->peer_net) + continue; + if (n->peer_net != exit_net) + continue; + tipc_node_write_lock(n); + n->peer_net = NULL; + n->peer_hash_mix = 0; + tipc_node_write_unlock_fast(n); + break; + } + spin_unlock_bh(&tn->node_list_lock); + } + rcu_read_unlock(); +} diff --git a/net/tipc/node.h b/net/tipc/node.h new file mode 100644 index 000000000..154a5bbb0 --- /dev/null +++ b/net/tipc/node.h @@ -0,0 +1,131 @@ +/* + * net/tipc/node.h: Include file for TIPC node management routines + * + * Copyright (c) 2000-2006, 2014-2016, Ericsson AB + * Copyright (c) 2005, 2010-2014, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_NODE_H +#define _TIPC_NODE_H + +#include "addr.h" +#include "net.h" +#include "bearer.h" +#include "msg.h" + +/* Optional capabilities supported by this code version + */ +enum { + TIPC_SYN_BIT = (1), + TIPC_BCAST_SYNCH = (1 << 1), + TIPC_BCAST_STATE_NACK = (1 << 2), + TIPC_BLOCK_FLOWCTL = (1 << 3), + TIPC_BCAST_RCAST = (1 << 4), + TIPC_NODE_ID128 = (1 << 5), + TIPC_LINK_PROTO_SEQNO = (1 << 6), + TIPC_MCAST_RBCTL = (1 << 7), + TIPC_GAP_ACK_BLOCK = (1 << 8), + TIPC_TUNNEL_ENHANCED = (1 << 9), + TIPC_NAGLE = (1 << 10), + TIPC_NAMED_BCAST = (1 << 11) +}; + +#define TIPC_NODE_CAPABILITIES (TIPC_SYN_BIT | \ + TIPC_BCAST_SYNCH | \ + TIPC_BCAST_STATE_NACK | \ + TIPC_BCAST_RCAST | \ + TIPC_BLOCK_FLOWCTL | \ + TIPC_NODE_ID128 | \ + TIPC_LINK_PROTO_SEQNO | \ + TIPC_MCAST_RBCTL | \ + TIPC_GAP_ACK_BLOCK | \ + TIPC_TUNNEL_ENHANCED | \ + TIPC_NAGLE | \ + TIPC_NAMED_BCAST) + +#define INVALID_BEARER_ID -1 + +void tipc_node_stop(struct net *net); +bool tipc_node_get_id(struct net *net, u32 addr, u8 *id); +u32 tipc_node_get_addr(struct tipc_node *node); +char *tipc_node_get_id_str(struct tipc_node *node); +void tipc_node_put(struct tipc_node *node); +void tipc_node_get(struct tipc_node *node); +struct tipc_node *tipc_node_create(struct net *net, u32 addr, u8 *peer_id, + u16 capabilities, u32 hash_mixes, + bool preliminary); +#ifdef CONFIG_TIPC_CRYPTO +struct tipc_crypto *tipc_node_crypto_rx(struct tipc_node *__n); +struct tipc_crypto *tipc_node_crypto_rx_by_list(struct list_head *pos); +struct tipc_crypto *tipc_node_crypto_rx_by_addr(struct net *net, u32 addr); +#endif +u32 tipc_node_try_addr(struct net *net, u8 *id, u32 addr); +void tipc_node_check_dest(struct net *net, u32 onode, u8 *peer_id128, + struct tipc_bearer *bearer, + u16 capabilities, u32 signature, u32 hash_mixes, + struct tipc_media_addr *maddr, + bool *respond, bool *dupl_addr); +void tipc_node_delete_links(struct net *net, int bearer_id); +void tipc_node_apply_property(struct net *net, struct tipc_bearer *b, int prop); +int tipc_node_get_linkname(struct net *net, u32 bearer_id, u32 node, + char *linkname, size_t len); +int tipc_node_xmit(struct net *net, struct sk_buff_head *list, u32 dnode, + int selector); +int tipc_node_distr_xmit(struct net *net, struct sk_buff_head *list); +int tipc_node_xmit_skb(struct net *net, struct sk_buff *skb, u32 dest, + u32 selector); +void tipc_node_subscribe(struct net *net, struct list_head *subscr, u32 addr); +void tipc_node_unsubscribe(struct net *net, struct list_head *subscr, u32 addr); +void tipc_node_broadcast(struct net *net, struct sk_buff *skb, int rc_dests); +int tipc_node_add_conn(struct net *net, u32 dnode, u32 port, u32 peer_port); +void tipc_node_remove_conn(struct net *net, u32 dnode, u32 port); +int tipc_node_get_mtu(struct net *net, u32 addr, u32 sel, bool connected); +bool tipc_node_is_up(struct net *net, u32 addr); +u16 tipc_node_get_capabilities(struct net *net, u32 addr); +int tipc_nl_node_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_node_dump_link(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_node_reset_link_stats(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_node_get_link(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_node_set_link(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_peer_rm(struct sk_buff *skb, struct genl_info *info); + +int tipc_nl_node_set_monitor(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_node_get_monitor(struct sk_buff *skb, struct genl_info *info); +int 
tipc_nl_node_dump_monitor(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_node_dump_monitor_peer(struct sk_buff *skb, + struct netlink_callback *cb); +#ifdef CONFIG_TIPC_CRYPTO +int tipc_nl_node_set_key(struct sk_buff *skb, struct genl_info *info); +int tipc_nl_node_flush_key(struct sk_buff *skb, struct genl_info *info); +#endif +void tipc_node_pre_cleanup_net(struct net *exit_net); +#endif diff --git a/net/tipc/socket.c b/net/tipc/socket.c new file mode 100644 index 000000000..14027a7a7 --- /dev/null +++ b/net/tipc/socket.c @@ -0,0 +1,4019 @@ +/* + * net/tipc/socket.c: TIPC socket API + * + * Copyright (c) 2001-2007, 2012-2019, Ericsson AB + * Copyright (c) 2004-2008, 2010-2013, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include <linux/rhashtable.h> +#include <linux/sched/signal.h> + +#include "core.h" +#include "name_table.h" +#include "node.h" +#include "link.h" +#include "name_distr.h" +#include "socket.h" +#include "bcast.h" +#include "netlink.h" +#include "group.h" +#include "trace.h" + +#define NAGLE_START_INIT 4 +#define NAGLE_START_MAX 1024 +#define CONN_TIMEOUT_DEFAULT 8000 /* default connect timeout = 8s */ +#define CONN_PROBING_INTV msecs_to_jiffies(3600000) /* [ms] => 1 h */ +#define TIPC_MAX_PORT 0xffffffff +#define TIPC_MIN_PORT 1 +#define TIPC_ACK_RATE 4 /* ACK at 1/4 of rcv window size */ + +enum { + TIPC_LISTEN = TCP_LISTEN, + TIPC_ESTABLISHED = TCP_ESTABLISHED, + TIPC_OPEN = TCP_CLOSE, + TIPC_DISCONNECTING = TCP_CLOSE_WAIT, + TIPC_CONNECTING = TCP_SYN_SENT, +}; + +struct sockaddr_pair { + struct sockaddr_tipc sock; + struct sockaddr_tipc member; +}; + +/** + * struct tipc_sock - TIPC socket structure + * @sk: socket - interacts with 'port' and with user via the socket API + * @max_pkt: maximum packet size "hint" used when building messages sent by port + * @maxnagle: maximum size of msg which can be subject to nagle + * @portid: unique port identity in TIPC socket hash table + * @phdr: preformatted message header used when sending messages + * @cong_links: list of congested links + * @publications: list of publications for port + * @blocking_link: address of the congested link we are currently sleeping on + * @pub_count: total # of publications port has made during its lifetime + * @conn_timeout: the time we can wait for an unresponded setup request + * @probe_unacked: probe has not received ack yet + * @dupl_rcvcnt: number of bytes counted twice, in both backlog and rcv queue + * @cong_link_cnt: number of congested links + * @snt_unacked: # messages sent by socket, and not yet acked by peer + * @snd_win: send window size + * @peer_caps: peer capabilities mask + * @rcv_unacked: # messages read by user, but not yet acked back to peer + * @rcv_win: receive window size + * @peer: 'connected' peer for dgram/rdm + * @node: hash table node + * @mc_method: cookie for use between socket and broadcast layer + * @rcu: rcu struct for tipc_sock + * @group: TIPC communications group + * @oneway: message count in one direction (FIXME) + * @nagle_start: current nagle value + * @snd_backlog: send backlog count + * @msg_acc: messages accepted; used in managing backlog and nagle + * @pkt_cnt: TIPC socket packet count + * @expect_ack: whether this TIPC socket is expecting an ack + * @nodelay: setsockopt() TIPC_NODELAY setting + * @group_is_open: TIPC socket group is fully open (FIXME) + * @published: true if port has one or more associated names + * @conn_addrtype: address type used when establishing connection + */ +struct tipc_sock { + struct sock sk; + u32 max_pkt; + u32 maxnagle; + u32 portid; + struct tipc_msg phdr; + struct list_head cong_links; + struct list_head publications; + u32 pub_count; + atomic_t dupl_rcvcnt; + u16 conn_timeout; + bool probe_unacked; + u16 cong_link_cnt; + u16 snt_unacked; + u16 snd_win; + u16 peer_caps; + u16 rcv_unacked; + u16 rcv_win; + struct sockaddr_tipc peer; + struct rhash_head node; + struct tipc_mc_method mc_method; + struct rcu_head rcu; + struct tipc_group *group; + u32 oneway; + u32 nagle_start; + u16 snd_backlog; + u16 msg_acc; + u16 pkt_cnt; + bool expect_ack; + bool nodelay; + bool group_is_open; + bool published; + u8 conn_addrtype; +}; + +static int tipc_sk_backlog_rcv(struct sock *sk, struct sk_buff *skb); +static void tipc_data_ready(struct sock *sk); +static void tipc_write_space(struct
sock *sk); +static void tipc_sock_destruct(struct sock *sk); +static int tipc_release(struct socket *sock); +static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags, + bool kern); +static void tipc_sk_timeout(struct timer_list *t); +static int tipc_sk_publish(struct tipc_sock *tsk, struct tipc_uaddr *ua); +static int tipc_sk_withdraw(struct tipc_sock *tsk, struct tipc_uaddr *ua); +static int tipc_sk_leave(struct tipc_sock *tsk); +static struct tipc_sock *tipc_sk_lookup(struct net *net, u32 portid); +static int tipc_sk_insert(struct tipc_sock *tsk); +static void tipc_sk_remove(struct tipc_sock *tsk); +static int __tipc_sendstream(struct socket *sock, struct msghdr *m, size_t dsz); +static int __tipc_sendmsg(struct socket *sock, struct msghdr *m, size_t dsz); +static void tipc_sk_push_backlog(struct tipc_sock *tsk, bool nagle_ack); +static int tipc_wait_for_connect(struct socket *sock, long *timeo_p); + +static const struct proto_ops packet_ops; +static const struct proto_ops stream_ops; +static const struct proto_ops msg_ops; +static struct proto tipc_proto; +static const struct rhashtable_params tsk_rht_params; + +static u32 tsk_own_node(struct tipc_sock *tsk) +{ + return msg_prevnode(&tsk->phdr); +} + +static u32 tsk_peer_node(struct tipc_sock *tsk) +{ + return msg_destnode(&tsk->phdr); +} + +static u32 tsk_peer_port(struct tipc_sock *tsk) +{ + return msg_destport(&tsk->phdr); +} + +static bool tsk_unreliable(struct tipc_sock *tsk) +{ + return msg_src_droppable(&tsk->phdr) != 0; +} + +static void tsk_set_unreliable(struct tipc_sock *tsk, bool unreliable) +{ + msg_set_src_droppable(&tsk->phdr, unreliable ? 1 : 0); +} + +static bool tsk_unreturnable(struct tipc_sock *tsk) +{ + return msg_dest_droppable(&tsk->phdr) != 0; +} + +static void tsk_set_unreturnable(struct tipc_sock *tsk, bool unreturnable) +{ + msg_set_dest_droppable(&tsk->phdr, unreturnable ? 
1 : 0); +} + +static int tsk_importance(struct tipc_sock *tsk) +{ + return msg_importance(&tsk->phdr); +} + +static struct tipc_sock *tipc_sk(const struct sock *sk) +{ + return container_of(sk, struct tipc_sock, sk); +} + +int tsk_set_importance(struct sock *sk, int imp) +{ + if (imp > TIPC_CRITICAL_IMPORTANCE) + return -EINVAL; + msg_set_importance(&tipc_sk(sk)->phdr, (u32)imp); + return 0; +} + +static bool tsk_conn_cong(struct tipc_sock *tsk) +{ + return tsk->snt_unacked > tsk->snd_win; +} + +static u16 tsk_blocks(int len) +{ + return ((len / FLOWCTL_BLK_SZ) + 1); +} + +/* tsk_blocks(): translate a buffer size in bytes to number of + * advertisable blocks, taking into account the ratio truesize(len)/len + * We can trust that this ratio is always < 4 for len >= FLOWCTL_BLK_SZ + */ +static u16 tsk_adv_blocks(int len) +{ + return len / FLOWCTL_BLK_SZ / 4; +} + +/* tsk_inc(): increment counter for sent or received data + * - If block based flow control is not supported by peer we + * fall back to message based ditto, incrementing the counter + */ +static u16 tsk_inc(struct tipc_sock *tsk, int msglen) +{ + if (likely(tsk->peer_caps & TIPC_BLOCK_FLOWCTL)) + return ((msglen / FLOWCTL_BLK_SZ) + 1); + return 1; +} + +/* tsk_set_nagle - enable/disable nagle property by manipulating maxnagle + */ +static void tsk_set_nagle(struct tipc_sock *tsk) +{ + struct sock *sk = &tsk->sk; + + tsk->maxnagle = 0; + if (sk->sk_type != SOCK_STREAM) + return; + if (tsk->nodelay) + return; + if (!(tsk->peer_caps & TIPC_NAGLE)) + return; + /* Limit node local buffer size to avoid receive queue overflow */ + if (tsk->max_pkt == MAX_MSG_SIZE) + tsk->maxnagle = 1500; + else + tsk->maxnagle = tsk->max_pkt; +} + +/** + * tsk_advance_rx_queue - discard first buffer in socket receive queue + * @sk: network socket + * + * Caller must hold socket lock + */ +static void tsk_advance_rx_queue(struct sock *sk) +{ + trace_tipc_sk_advance_rx(sk, NULL, TIPC_DUMP_SK_RCVQ, " "); + kfree_skb(__skb_dequeue(&sk->sk_receive_queue)); +} + +/* tipc_sk_respond() : send response message back to sender + */ +static void tipc_sk_respond(struct sock *sk, struct sk_buff *skb, int err) +{ + u32 selector; + u32 dnode; + u32 onode = tipc_own_addr(sock_net(sk)); + + if (!tipc_msg_reverse(onode, &skb, err)) + return; + + trace_tipc_sk_rej_msg(sk, skb, TIPC_DUMP_NONE, "@sk_respond!"); + dnode = msg_destnode(buf_msg(skb)); + selector = msg_origport(buf_msg(skb)); + tipc_node_xmit_skb(sock_net(sk), skb, dnode, selector); +} + +/** + * tsk_rej_rx_queue - reject all buffers in socket receive queue + * @sk: network socket + * @error: response error code + * + * Caller must hold socket lock + */ +static void tsk_rej_rx_queue(struct sock *sk, int error) +{ + struct sk_buff *skb; + + while ((skb = __skb_dequeue(&sk->sk_receive_queue))) + tipc_sk_respond(sk, skb, error); +} + +static bool tipc_sk_connected(const struct sock *sk) +{ + return READ_ONCE(sk->sk_state) == TIPC_ESTABLISHED; +} + +/* tipc_sk_type_connectionless - check if the socket is datagram socket + * @sk: socket + * + * Returns true if connection less, false otherwise + */ +static bool tipc_sk_type_connectionless(struct sock *sk) +{ + return sk->sk_type == SOCK_RDM || sk->sk_type == SOCK_DGRAM; +} + +/* tsk_peer_msg - verify if message was sent by connected port's peer + * + * Handles cases where the node's network address has changed from + * the default of <0.0.0> to its configured setting. 
+ */ +static bool tsk_peer_msg(struct tipc_sock *tsk, struct tipc_msg *msg) +{ + struct sock *sk = &tsk->sk; + u32 self = tipc_own_addr(sock_net(sk)); + u32 peer_port = tsk_peer_port(tsk); + u32 orig_node, peer_node; + + if (unlikely(!tipc_sk_connected(sk))) + return false; + + if (unlikely(msg_origport(msg) != peer_port)) + return false; + + orig_node = msg_orignode(msg); + peer_node = tsk_peer_node(tsk); + + if (likely(orig_node == peer_node)) + return true; + + if (!orig_node && peer_node == self) + return true; + + if (!peer_node && orig_node == self) + return true; + + return false; +} + +/* tipc_set_sk_state - set the sk_state of the socket + * @sk: socket + * + * Caller must hold socket lock + * + * Returns 0 on success, errno otherwise + */ +static int tipc_set_sk_state(struct sock *sk, int state) +{ + int oldsk_state = sk->sk_state; + int res = -EINVAL; + + switch (state) { + case TIPC_OPEN: + res = 0; + break; + case TIPC_LISTEN: + case TIPC_CONNECTING: + if (oldsk_state == TIPC_OPEN) + res = 0; + break; + case TIPC_ESTABLISHED: + if (oldsk_state == TIPC_CONNECTING || + oldsk_state == TIPC_OPEN) + res = 0; + break; + case TIPC_DISCONNECTING: + if (oldsk_state == TIPC_CONNECTING || + oldsk_state == TIPC_ESTABLISHED) + res = 0; + break; + } + + if (!res) + sk->sk_state = state; + + return res; +} + +static int tipc_sk_sock_err(struct socket *sock, long *timeout) +{ + struct sock *sk = sock->sk; + int err = sock_error(sk); + int typ = sock->type; + + if (err) + return err; + if (typ == SOCK_STREAM || typ == SOCK_SEQPACKET) { + if (sk->sk_state == TIPC_DISCONNECTING) + return -EPIPE; + else if (!tipc_sk_connected(sk)) + return -ENOTCONN; + } + if (!*timeout) + return -EAGAIN; + if (signal_pending(current)) + return sock_intr_errno(*timeout); + + return 0; +} + +#define tipc_wait_for_cond(sock_, timeo_, condition_) \ +({ \ + DEFINE_WAIT_FUNC(wait_, woken_wake_function); \ + struct sock *sk_; \ + int rc_; \ + \ + while ((rc_ = !(condition_))) { \ + /* coupled with smp_wmb() in tipc_sk_proto_rcv() */ \ + smp_rmb(); \ + sk_ = (sock_)->sk; \ + rc_ = tipc_sk_sock_err((sock_), timeo_); \ + if (rc_) \ + break; \ + add_wait_queue(sk_sleep(sk_), &wait_); \ + release_sock(sk_); \ + *(timeo_) = wait_woken(&wait_, TASK_INTERRUPTIBLE, *(timeo_)); \ + sched_annotate_sleep(); \ + lock_sock(sk_); \ + remove_wait_queue(sk_sleep(sk_), &wait_); \ + } \ + rc_; \ +}) + +/** + * tipc_sk_create - create a TIPC socket + * @net: network namespace (must be default network) + * @sock: pre-allocated socket structure + * @protocol: protocol indicator (must be 0) + * @kern: caused by kernel or by userspace? + * + * This routine creates additional data structures used by the TIPC socket, + * initializes them, and links them together. 
+ * + * Return: 0 on success, errno otherwise + */ +static int tipc_sk_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + const struct proto_ops *ops; + struct sock *sk; + struct tipc_sock *tsk; + struct tipc_msg *msg; + + /* Validate arguments */ + if (unlikely(protocol != 0)) + return -EPROTONOSUPPORT; + + switch (sock->type) { + case SOCK_STREAM: + ops = &stream_ops; + break; + case SOCK_SEQPACKET: + ops = &packet_ops; + break; + case SOCK_DGRAM: + case SOCK_RDM: + ops = &msg_ops; + break; + default: + return -EPROTOTYPE; + } + + /* Allocate socket's protocol area */ + sk = sk_alloc(net, AF_TIPC, GFP_KERNEL, &tipc_proto, kern); + if (sk == NULL) + return -ENOMEM; + + tsk = tipc_sk(sk); + tsk->max_pkt = MAX_PKT_DEFAULT; + tsk->maxnagle = 0; + tsk->nagle_start = NAGLE_START_INIT; + INIT_LIST_HEAD(&tsk->publications); + INIT_LIST_HEAD(&tsk->cong_links); + msg = &tsk->phdr; + + /* Finish initializing socket data structures */ + sock->ops = ops; + sock_init_data(sock, sk); + tipc_set_sk_state(sk, TIPC_OPEN); + if (tipc_sk_insert(tsk)) { + sk_free(sk); + pr_warn("Socket create failed; port number exhausted\n"); + return -EINVAL; + } + + /* Ensure tsk is visible before we read own_addr. */ + smp_mb(); + + tipc_msg_init(tipc_own_addr(net), msg, TIPC_LOW_IMPORTANCE, + TIPC_NAMED_MSG, NAMED_H_SIZE, 0); + + msg_set_origport(msg, tsk->portid); + timer_setup(&sk->sk_timer, tipc_sk_timeout, 0); + sk->sk_shutdown = 0; + sk->sk_backlog_rcv = tipc_sk_backlog_rcv; + sk->sk_rcvbuf = READ_ONCE(sysctl_tipc_rmem[1]); + sk->sk_data_ready = tipc_data_ready; + sk->sk_write_space = tipc_write_space; + sk->sk_destruct = tipc_sock_destruct; + tsk->conn_timeout = CONN_TIMEOUT_DEFAULT; + tsk->group_is_open = true; + atomic_set(&tsk->dupl_rcvcnt, 0); + + /* Start out with safe limits until we receive an advertised window */ + tsk->snd_win = tsk_adv_blocks(RCVBUF_MIN); + tsk->rcv_win = tsk->snd_win; + + if (tipc_sk_type_connectionless(sk)) { + tsk_set_unreturnable(tsk, true); + if (sock->type == SOCK_DGRAM) + tsk_set_unreliable(tsk, true); + } + __skb_queue_head_init(&tsk->mc_method.deferredq); + trace_tipc_sk_create(sk, NULL, TIPC_DUMP_NONE, " "); + return 0; +} + +static void tipc_sk_callback(struct rcu_head *head) +{ + struct tipc_sock *tsk = container_of(head, struct tipc_sock, rcu); + + sock_put(&tsk->sk); +} + +/* Caller should hold socket lock for the socket. 
*/ +static void __tipc_shutdown(struct socket *sock, int error) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct net *net = sock_net(sk); + long timeout = msecs_to_jiffies(CONN_TIMEOUT_DEFAULT); + u32 dnode = tsk_peer_node(tsk); + struct sk_buff *skb; + + /* Avoid that hi-prio shutdown msgs bypass msgs in link wakeup queue */ + tipc_wait_for_cond(sock, &timeout, (!tsk->cong_link_cnt && + !tsk_conn_cong(tsk))); + + /* Push out delayed messages if in Nagle mode */ + tipc_sk_push_backlog(tsk, false); + /* Remove pending SYN */ + __skb_queue_purge(&sk->sk_write_queue); + + /* Remove partially received buffer if any */ + skb = skb_peek(&sk->sk_receive_queue); + if (skb && TIPC_SKB_CB(skb)->bytes_read) { + __skb_unlink(skb, &sk->sk_receive_queue); + kfree_skb(skb); + } + + /* Reject all unreceived messages if connectionless */ + if (tipc_sk_type_connectionless(sk)) { + tsk_rej_rx_queue(sk, error); + return; + } + + switch (sk->sk_state) { + case TIPC_CONNECTING: + case TIPC_ESTABLISHED: + tipc_set_sk_state(sk, TIPC_DISCONNECTING); + tipc_node_remove_conn(net, dnode, tsk->portid); + /* Send a FIN+/- to its peer */ + skb = __skb_dequeue(&sk->sk_receive_queue); + if (skb) { + __skb_queue_purge(&sk->sk_receive_queue); + tipc_sk_respond(sk, skb, error); + break; + } + skb = tipc_msg_create(TIPC_CRITICAL_IMPORTANCE, + TIPC_CONN_MSG, SHORT_H_SIZE, 0, dnode, + tsk_own_node(tsk), tsk_peer_port(tsk), + tsk->portid, error); + if (skb) + tipc_node_xmit_skb(net, skb, dnode, tsk->portid); + break; + case TIPC_LISTEN: + /* Reject all SYN messages */ + tsk_rej_rx_queue(sk, error); + break; + default: + __skb_queue_purge(&sk->sk_receive_queue); + break; + } +} + +/** + * tipc_release - destroy a TIPC socket + * @sock: socket to destroy + * + * This routine cleans up any messages that are still queued on the socket. + * For DGRAM and RDM socket types, all queued messages are rejected. + * For SEQPACKET and STREAM socket types, the first message is rejected + * and any others are discarded. (If the first message on a STREAM socket + * is partially-read, it is discarded and the next one is rejected instead.) + * + * NOTE: Rejected messages are not necessarily returned to the sender! They + * are returned or discarded according to the "destination droppable" setting + * specified for the message by the sender. 
+ * + * Return: 0 on success, errno otherwise + */ +static int tipc_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk; + + /* + * Exit if socket isn't fully initialized (occurs when a failed accept() + * releases a pre-allocated child socket that was never used) + */ + if (sk == NULL) + return 0; + + tsk = tipc_sk(sk); + lock_sock(sk); + + trace_tipc_sk_release(sk, NULL, TIPC_DUMP_ALL, " "); + __tipc_shutdown(sock, TIPC_ERR_NO_PORT); + sk->sk_shutdown = SHUTDOWN_MASK; + tipc_sk_leave(tsk); + tipc_sk_withdraw(tsk, NULL); + __skb_queue_purge(&tsk->mc_method.deferredq); + sk_stop_timer(sk, &sk->sk_timer); + tipc_sk_remove(tsk); + + sock_orphan(sk); + /* Reject any messages that accumulated in backlog queue */ + release_sock(sk); + tipc_dest_list_purge(&tsk->cong_links); + tsk->cong_link_cnt = 0; + call_rcu(&tsk->rcu, tipc_sk_callback); + sock->sk = NULL; + + return 0; +} + +/** + * __tipc_bind - associate or disassocate TIPC name(s) with a socket + * @sock: socket structure + * @skaddr: socket address describing name(s) and desired operation + * @alen: size of socket address data structure + * + * Name and name sequence binding are indicated using a positive scope value; + * a negative scope value unbinds the specified name. Specifying no name + * (i.e. a socket address length of 0) unbinds all names from the socket. + * + * Return: 0 on success, errno otherwise + * + * NOTE: This routine doesn't need to take the socket lock since it doesn't + * access any non-constant socket information. + */ +static int __tipc_bind(struct socket *sock, struct sockaddr *skaddr, int alen) +{ + struct tipc_uaddr *ua = (struct tipc_uaddr *)skaddr; + struct tipc_sock *tsk = tipc_sk(sock->sk); + bool unbind = false; + + if (unlikely(!alen)) + return tipc_sk_withdraw(tsk, NULL); + + if (ua->addrtype == TIPC_SERVICE_ADDR) { + ua->addrtype = TIPC_SERVICE_RANGE; + ua->sr.upper = ua->sr.lower; + } + if (ua->scope < 0) { + unbind = true; + ua->scope = -ua->scope; + } + /* Users may still use deprecated TIPC_ZONE_SCOPE */ + if (ua->scope != TIPC_NODE_SCOPE) + ua->scope = TIPC_CLUSTER_SCOPE; + + if (tsk->group) + return -EACCES; + + if (unbind) + return tipc_sk_withdraw(tsk, ua); + return tipc_sk_publish(tsk, ua); +} + +int tipc_sk_bind(struct socket *sock, struct sockaddr *skaddr, int alen) +{ + int res; + + lock_sock(sock->sk); + res = __tipc_bind(sock, skaddr, alen); + release_sock(sock->sk); + return res; +} + +static int tipc_bind(struct socket *sock, struct sockaddr *skaddr, int alen) +{ + struct tipc_uaddr *ua = (struct tipc_uaddr *)skaddr; + u32 atype = ua->addrtype; + + if (alen) { + if (!tipc_uaddr_valid(ua, alen)) + return -EINVAL; + if (atype == TIPC_SOCKET_ADDR) + return -EAFNOSUPPORT; + if (ua->sr.type < TIPC_RESERVED_TYPES) { + pr_warn_once("Can't bind to reserved service type %u\n", + ua->sr.type); + return -EACCES; + } + } + return tipc_sk_bind(sock, skaddr, alen); +} + +/** + * tipc_getname - get port ID of socket or peer socket + * @sock: socket structure + * @uaddr: area for returned socket address + * @peer: 0 = own ID, 1 = current peer ID, 2 = current/former peer ID + * + * Return: 0 on success, errno otherwise + * + * NOTE: This routine doesn't need to take the socket lock since it only + * accesses socket information that is unchanging (or which changes in + * a completely predictable manner). 
+ */ +static int tipc_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_tipc *addr = (struct sockaddr_tipc *)uaddr; + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + + memset(addr, 0, sizeof(*addr)); + if (peer) { + if ((!tipc_sk_connected(sk)) && + ((peer != 2) || (sk->sk_state != TIPC_DISCONNECTING))) + return -ENOTCONN; + addr->addr.id.ref = tsk_peer_port(tsk); + addr->addr.id.node = tsk_peer_node(tsk); + } else { + addr->addr.id.ref = tsk->portid; + addr->addr.id.node = tipc_own_addr(sock_net(sk)); + } + + addr->addrtype = TIPC_SOCKET_ADDR; + addr->family = AF_TIPC; + addr->scope = 0; + addr->addr.name.domain = 0; + + return sizeof(*addr); +} + +/** + * tipc_poll - read and possibly block on pollmask + * @file: file structure associated with the socket + * @sock: socket for which to calculate the poll bits + * @wait: ??? + * + * Return: pollmask value + * + * COMMENTARY: + * It appears that the usual socket locking mechanisms are not useful here + * since the pollmask info is potentially out-of-date the moment this routine + * exits. TCP and other protocols seem to rely on higher level poll routines + * to handle any preventable race conditions, so TIPC will do the same ... + * + * IMPORTANT: The fact that a read or write operation is indicated does NOT + * imply that the operation will succeed, merely that it should be performed + * and will not block. + */ +static __poll_t tipc_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + __poll_t revents = 0; + + sock_poll_wait(file, sock, wait); + trace_tipc_sk_poll(sk, NULL, TIPC_DUMP_ALL, " "); + + if (sk->sk_shutdown & RCV_SHUTDOWN) + revents |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + if (sk->sk_shutdown == SHUTDOWN_MASK) + revents |= EPOLLHUP; + + switch (sk->sk_state) { + case TIPC_ESTABLISHED: + if (!tsk->cong_link_cnt && !tsk_conn_cong(tsk)) + revents |= EPOLLOUT; + fallthrough; + case TIPC_LISTEN: + case TIPC_CONNECTING: + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + revents |= EPOLLIN | EPOLLRDNORM; + break; + case TIPC_OPEN: + if (tsk->group_is_open && !tsk->cong_link_cnt) + revents |= EPOLLOUT; + if (!tipc_sk_type_connectionless(sk)) + break; + if (skb_queue_empty_lockless(&sk->sk_receive_queue)) + break; + revents |= EPOLLIN | EPOLLRDNORM; + break; + case TIPC_DISCONNECTING: + revents = EPOLLIN | EPOLLRDNORM | EPOLLHUP; + break; + } + return revents; +} + +/** + * tipc_sendmcast - send multicast message + * @sock: socket structure + * @ua: destination address struct + * @msg: message to send + * @dlen: length of data to send + * @timeout: timeout to wait for wakeup + * + * Called from function tipc_sendmsg(), which has done all sanity checks + * Return: the number of bytes sent on success, or errno + */ +static int tipc_sendmcast(struct socket *sock, struct tipc_uaddr *ua, + struct msghdr *msg, size_t dlen, long timeout) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_msg *hdr = &tsk->phdr; + struct net *net = sock_net(sk); + int mtu = tipc_bcast_get_mtu(net); + struct sk_buff_head pkts; + struct tipc_nlist dsts; + int rc; + + if (tsk->group) + return -EACCES; + + /* Block or return if any destination link is congested */ + rc = tipc_wait_for_cond(sock, &timeout, !tsk->cong_link_cnt); + if (unlikely(rc)) + return rc; + + /* Lookup destination nodes */ + tipc_nlist_init(&dsts, tipc_own_addr(net)); + tipc_nametbl_lookup_mcast_nodes(net, ua, 
&dsts); + if (!dsts.local && !dsts.remote) + return -EHOSTUNREACH; + + /* Build message header */ + msg_set_type(hdr, TIPC_MCAST_MSG); + msg_set_hdr_sz(hdr, MCAST_H_SIZE); + msg_set_lookup_scope(hdr, TIPC_CLUSTER_SCOPE); + msg_set_destport(hdr, 0); + msg_set_destnode(hdr, 0); + msg_set_nametype(hdr, ua->sr.type); + msg_set_namelower(hdr, ua->sr.lower); + msg_set_nameupper(hdr, ua->sr.upper); + + /* Build message as chain of buffers */ + __skb_queue_head_init(&pkts); + rc = tipc_msg_build(hdr, msg, 0, dlen, mtu, &pkts); + + /* Send message if build was successful */ + if (unlikely(rc == dlen)) { + trace_tipc_sk_sendmcast(sk, skb_peek(&pkts), + TIPC_DUMP_SK_SNDQ, " "); + rc = tipc_mcast_xmit(net, &pkts, &tsk->mc_method, &dsts, + &tsk->cong_link_cnt); + } + + tipc_nlist_purge(&dsts); + + return rc ? rc : dlen; +} + +/** + * tipc_send_group_msg - send a message to a member in the group + * @net: network namespace + * @tsk: tipc socket + * @m: message to send + * @mb: group member + * @dnode: destination node + * @dport: destination port + * @dlen: total length of message data + */ +static int tipc_send_group_msg(struct net *net, struct tipc_sock *tsk, + struct msghdr *m, struct tipc_member *mb, + u32 dnode, u32 dport, int dlen) +{ + u16 bc_snd_nxt = tipc_group_bc_snd_nxt(tsk->group); + struct tipc_mc_method *method = &tsk->mc_method; + int blks = tsk_blocks(GROUP_H_SIZE + dlen); + struct tipc_msg *hdr = &tsk->phdr; + struct sk_buff_head pkts; + int mtu, rc; + + /* Complete message header */ + msg_set_type(hdr, TIPC_GRP_UCAST_MSG); + msg_set_hdr_sz(hdr, GROUP_H_SIZE); + msg_set_destport(hdr, dport); + msg_set_destnode(hdr, dnode); + msg_set_grp_bc_seqno(hdr, bc_snd_nxt); + + /* Build message as chain of buffers */ + __skb_queue_head_init(&pkts); + mtu = tipc_node_get_mtu(net, dnode, tsk->portid, false); + rc = tipc_msg_build(hdr, m, 0, dlen, mtu, &pkts); + if (unlikely(rc != dlen)) + return rc; + + /* Send message */ + rc = tipc_node_xmit(net, &pkts, dnode, tsk->portid); + if (unlikely(rc == -ELINKCONG)) { + tipc_dest_push(&tsk->cong_links, dnode, 0); + tsk->cong_link_cnt++; + } + + /* Update send window */ + tipc_group_update_member(mb, blks); + + /* A broadcast sent within next EXPIRE period must follow same path */ + method->rcast = true; + method->mandatory = true; + return dlen; +} + +/** + * tipc_send_group_unicast - send message to a member in the group + * @sock: socket structure + * @m: message to send + * @dlen: total length of message data + * @timeout: timeout to wait for wakeup + * + * Called from function tipc_sendmsg(), which has done all sanity checks + * Return: the number of bytes sent on success, or errno + */ +static int tipc_send_group_unicast(struct socket *sock, struct msghdr *m, + int dlen, long timeout) +{ + struct sock *sk = sock->sk; + struct tipc_uaddr *ua = (struct tipc_uaddr *)m->msg_name; + int blks = tsk_blocks(GROUP_H_SIZE + dlen); + struct tipc_sock *tsk = tipc_sk(sk); + struct net *net = sock_net(sk); + struct tipc_member *mb = NULL; + u32 node, port; + int rc; + + node = ua->sk.node; + port = ua->sk.ref; + if (!port && !node) + return -EHOSTUNREACH; + + /* Block or return if destination link or member is congested */ + rc = tipc_wait_for_cond(sock, &timeout, + !tipc_dest_find(&tsk->cong_links, node, 0) && + tsk->group && + !tipc_group_cong(tsk->group, node, port, blks, + &mb)); + if (unlikely(rc)) + return rc; + + if (unlikely(!mb)) + return -EHOSTUNREACH; + + rc = tipc_send_group_msg(net, tsk, m, mb, node, port, dlen); + + return rc ? 
rc : dlen; +} + +/** + * tipc_send_group_anycast - send message to any member with given identity + * @sock: socket structure + * @m: message to send + * @dlen: total length of message data + * @timeout: timeout to wait for wakeup + * + * Called from function tipc_sendmsg(), which has done all sanity checks + * Return: the number of bytes sent on success, or errno + */ +static int tipc_send_group_anycast(struct socket *sock, struct msghdr *m, + int dlen, long timeout) +{ + struct tipc_uaddr *ua = (struct tipc_uaddr *)m->msg_name; + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct list_head *cong_links = &tsk->cong_links; + int blks = tsk_blocks(GROUP_H_SIZE + dlen); + struct tipc_msg *hdr = &tsk->phdr; + struct tipc_member *first = NULL; + struct tipc_member *mbr = NULL; + struct net *net = sock_net(sk); + u32 node, port, exclude; + struct list_head dsts; + int lookups = 0; + int dstcnt, rc; + bool cong; + + INIT_LIST_HEAD(&dsts); + ua->sa.type = msg_nametype(hdr); + ua->scope = msg_lookup_scope(hdr); + + while (++lookups < 4) { + exclude = tipc_group_exclude(tsk->group); + + first = NULL; + + /* Look for a non-congested destination member, if any */ + while (1) { + if (!tipc_nametbl_lookup_group(net, ua, &dsts, &dstcnt, + exclude, false)) + return -EHOSTUNREACH; + tipc_dest_pop(&dsts, &node, &port); + cong = tipc_group_cong(tsk->group, node, port, blks, + &mbr); + if (!cong) + break; + if (mbr == first) + break; + if (!first) + first = mbr; + } + + /* Start over if destination was not in member list */ + if (unlikely(!mbr)) + continue; + + if (likely(!cong && !tipc_dest_find(cong_links, node, 0))) + break; + + /* Block or return if destination link or member is congested */ + rc = tipc_wait_for_cond(sock, &timeout, + !tipc_dest_find(cong_links, node, 0) && + tsk->group && + !tipc_group_cong(tsk->group, node, port, + blks, &mbr)); + if (unlikely(rc)) + return rc; + + /* Send, unless destination disappeared while waiting */ + if (likely(mbr)) + break; + } + + if (unlikely(lookups >= 4)) + return -EHOSTUNREACH; + + rc = tipc_send_group_msg(net, tsk, m, mbr, node, port, dlen); + + return rc ? 
rc : dlen; +} + +/** + * tipc_send_group_bcast - send message to all members in communication group + * @sock: socket structure + * @m: message to send + * @dlen: total length of message data + * @timeout: timeout to wait for wakeup + * + * Called from function tipc_sendmsg(), which has done all sanity checks + * Return: the number of bytes sent on success, or errno + */ +static int tipc_send_group_bcast(struct socket *sock, struct msghdr *m, + int dlen, long timeout) +{ + struct tipc_uaddr *ua = (struct tipc_uaddr *)m->msg_name; + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_nlist *dsts; + struct tipc_mc_method *method = &tsk->mc_method; + bool ack = method->mandatory && method->rcast; + int blks = tsk_blocks(MCAST_H_SIZE + dlen); + struct tipc_msg *hdr = &tsk->phdr; + int mtu = tipc_bcast_get_mtu(net); + struct sk_buff_head pkts; + int rc = -EHOSTUNREACH; + + /* Block or return if any destination link or member is congested */ + rc = tipc_wait_for_cond(sock, &timeout, + !tsk->cong_link_cnt && tsk->group && + !tipc_group_bc_cong(tsk->group, blks)); + if (unlikely(rc)) + return rc; + + dsts = tipc_group_dests(tsk->group); + if (!dsts->local && !dsts->remote) + return -EHOSTUNREACH; + + /* Complete message header */ + if (ua) { + msg_set_type(hdr, TIPC_GRP_MCAST_MSG); + msg_set_nameinst(hdr, ua->sa.instance); + } else { + msg_set_type(hdr, TIPC_GRP_BCAST_MSG); + msg_set_nameinst(hdr, 0); + } + msg_set_hdr_sz(hdr, GROUP_H_SIZE); + msg_set_destport(hdr, 0); + msg_set_destnode(hdr, 0); + msg_set_grp_bc_seqno(hdr, tipc_group_bc_snd_nxt(tsk->group)); + + /* Avoid getting stuck with repeated forced replicasts */ + msg_set_grp_bc_ack_req(hdr, ack); + + /* Build message as chain of buffers */ + __skb_queue_head_init(&pkts); + rc = tipc_msg_build(hdr, m, 0, dlen, mtu, &pkts); + if (unlikely(rc != dlen)) + return rc; + + /* Send message */ + rc = tipc_mcast_xmit(net, &pkts, method, dsts, &tsk->cong_link_cnt); + if (unlikely(rc)) + return rc; + + /* Update broadcast sequence number and send windows */ + tipc_group_update_bc_members(tsk->group, blks, ack); + + /* Broadcast link is now free to choose method for next broadcast */ + method->mandatory = false; + method->expires = jiffies; + + return dlen; +} + +/** + * tipc_send_group_mcast - send message to all members with given identity + * @sock: socket structure + * @m: message to send + * @dlen: total length of message data + * @timeout: timeout to wait for wakeup + * + * Called from function tipc_sendmsg(), which has done all sanity checks + * Return: the number of bytes sent on success, or errno + */ +static int tipc_send_group_mcast(struct socket *sock, struct msghdr *m, + int dlen, long timeout) +{ + struct tipc_uaddr *ua = (struct tipc_uaddr *)m->msg_name; + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_group *grp = tsk->group; + struct tipc_msg *hdr = &tsk->phdr; + struct net *net = sock_net(sk); + struct list_head dsts; + u32 dstcnt, exclude; + + INIT_LIST_HEAD(&dsts); + ua->sa.type = msg_nametype(hdr); + ua->scope = msg_lookup_scope(hdr); + exclude = tipc_group_exclude(grp); + + if (!tipc_nametbl_lookup_group(net, ua, &dsts, &dstcnt, exclude, true)) + return -EHOSTUNREACH; + + if (dstcnt == 1) { + tipc_dest_pop(&dsts, &ua->sk.node, &ua->sk.ref); + return tipc_send_group_unicast(sock, m, dlen, timeout); + } + + tipc_dest_list_purge(&dsts); + return tipc_send_group_bcast(sock, m, dlen, timeout); +} + +/** + * tipc_sk_mcast_rcv - Deliver 
multicast messages to all destination sockets + * @net: the associated network namespace + * @arrvq: queue with arriving messages, to be cloned after destination lookup + * @inputq: queue with cloned messages, delivered to socket after dest lookup + * + * Multi-threaded: parallel calls with reference to same queues may occur + */ +void tipc_sk_mcast_rcv(struct net *net, struct sk_buff_head *arrvq, + struct sk_buff_head *inputq) +{ + u32 self = tipc_own_addr(net); + struct sk_buff *skb, *_skb; + u32 portid, onode; + struct sk_buff_head tmpq; + struct list_head dports; + struct tipc_msg *hdr; + struct tipc_uaddr ua; + int user, mtyp, hlen; + + __skb_queue_head_init(&tmpq); + INIT_LIST_HEAD(&dports); + ua.addrtype = TIPC_SERVICE_RANGE; + + /* tipc_skb_peek() increments the head skb's reference counter */ + skb = tipc_skb_peek(arrvq, &inputq->lock); + for (; skb; skb = tipc_skb_peek(arrvq, &inputq->lock)) { + hdr = buf_msg(skb); + user = msg_user(hdr); + mtyp = msg_type(hdr); + hlen = skb_headroom(skb) + msg_hdr_sz(hdr); + onode = msg_orignode(hdr); + ua.sr.type = msg_nametype(hdr); + ua.sr.lower = msg_namelower(hdr); + ua.sr.upper = msg_nameupper(hdr); + if (onode == self) + ua.scope = TIPC_ANY_SCOPE; + else + ua.scope = TIPC_CLUSTER_SCOPE; + + if (mtyp == TIPC_GRP_UCAST_MSG || user == GROUP_PROTOCOL) { + spin_lock_bh(&inputq->lock); + if (skb_peek(arrvq) == skb) { + __skb_dequeue(arrvq); + __skb_queue_tail(inputq, skb); + } + kfree_skb(skb); + spin_unlock_bh(&inputq->lock); + continue; + } + + /* Group messages require exact scope match */ + if (msg_in_group(hdr)) { + ua.sr.lower = 0; + ua.sr.upper = ~0; + ua.scope = msg_lookup_scope(hdr); + } + + /* Create destination port list: */ + tipc_nametbl_lookup_mcast_sockets(net, &ua, &dports); + + /* Clone message per destination */ + while (tipc_dest_pop(&dports, NULL, &portid)) { + _skb = __pskb_copy(skb, hlen, GFP_ATOMIC); + if (_skb) { + msg_set_destport(buf_msg(_skb), portid); + __skb_queue_tail(&tmpq, _skb); + continue; + } + pr_warn("Failed to clone mcast rcv buffer\n"); + } + /* Append clones to inputq only if skb is still head of arrvq */ + spin_lock_bh(&inputq->lock); + if (skb_peek(arrvq) == skb) { + skb_queue_splice_tail_init(&tmpq, inputq); + /* Decrement the skb's refcnt */ + kfree_skb(__skb_dequeue(arrvq)); + } + spin_unlock_bh(&inputq->lock); + __skb_queue_purge(&tmpq); + kfree_skb(skb); + } + tipc_sk_rcv(net, inputq); +} + +/* tipc_sk_push_backlog(): send accumulated buffers in socket write queue + * when socket is in Nagle mode + */ +static void tipc_sk_push_backlog(struct tipc_sock *tsk, bool nagle_ack) +{ + struct sk_buff_head *txq = &tsk->sk.sk_write_queue; + struct sk_buff *skb = skb_peek_tail(txq); + struct net *net = sock_net(&tsk->sk); + u32 dnode = tsk_peer_node(tsk); + int rc; + + if (nagle_ack) { + tsk->pkt_cnt += skb_queue_len(txq); + if (!tsk->pkt_cnt || tsk->msg_acc / tsk->pkt_cnt < 2) { + tsk->oneway = 0; + if (tsk->nagle_start < NAGLE_START_MAX) + tsk->nagle_start *= 2; + tsk->expect_ack = false; + pr_debug("tsk %10u: bad nagle %u -> %u, next start %u!\n", + tsk->portid, tsk->msg_acc, tsk->pkt_cnt, + tsk->nagle_start); + } else { + tsk->nagle_start = NAGLE_START_INIT; + if (skb) { + msg_set_ack_required(buf_msg(skb)); + tsk->expect_ack = true; + } else { + tsk->expect_ack = false; + } + } + tsk->msg_acc = 0; + tsk->pkt_cnt = 0; + } + + if (!skb || tsk->cong_link_cnt) + return; + + /* Do not send SYN again after congestion */ + if (msg_is_syn(buf_msg(skb))) + return; + + if (tsk->msg_acc) + tsk->pkt_cnt += 
skb_queue_len(txq); + tsk->snt_unacked += tsk->snd_backlog; + tsk->snd_backlog = 0; + rc = tipc_node_xmit(net, txq, dnode, tsk->portid); + if (rc == -ELINKCONG) + tsk->cong_link_cnt = 1; +} + +/** + * tipc_sk_conn_proto_rcv - receive a connection mng protocol message + * @tsk: receiving socket + * @skb: pointer to message buffer. + * @inputq: buffer list containing the buffers + * @xmitq: output message area + */ +static void tipc_sk_conn_proto_rcv(struct tipc_sock *tsk, struct sk_buff *skb, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq) +{ + struct tipc_msg *hdr = buf_msg(skb); + u32 onode = tsk_own_node(tsk); + struct sock *sk = &tsk->sk; + int mtyp = msg_type(hdr); + bool was_cong; + + /* Ignore if connection cannot be validated: */ + if (!tsk_peer_msg(tsk, hdr)) { + trace_tipc_sk_drop_msg(sk, skb, TIPC_DUMP_NONE, "@proto_rcv!"); + goto exit; + } + + if (unlikely(msg_errcode(hdr))) { + tipc_set_sk_state(sk, TIPC_DISCONNECTING); + tipc_node_remove_conn(sock_net(sk), tsk_peer_node(tsk), + tsk_peer_port(tsk)); + sk->sk_state_change(sk); + + /* State change is ignored if socket already awake, + * - convert msg to abort msg and add to inqueue + */ + msg_set_user(hdr, TIPC_CRITICAL_IMPORTANCE); + msg_set_type(hdr, TIPC_CONN_MSG); + msg_set_size(hdr, BASIC_H_SIZE); + msg_set_hdr_sz(hdr, BASIC_H_SIZE); + __skb_queue_tail(inputq, skb); + return; + } + + tsk->probe_unacked = false; + + if (mtyp == CONN_PROBE) { + msg_set_type(hdr, CONN_PROBE_REPLY); + if (tipc_msg_reverse(onode, &skb, TIPC_OK)) + __skb_queue_tail(xmitq, skb); + return; + } else if (mtyp == CONN_ACK) { + was_cong = tsk_conn_cong(tsk); + tipc_sk_push_backlog(tsk, msg_nagle_ack(hdr)); + tsk->snt_unacked -= msg_conn_ack(hdr); + if (tsk->peer_caps & TIPC_BLOCK_FLOWCTL) + tsk->snd_win = msg_adv_win(hdr); + if (was_cong && !tsk_conn_cong(tsk)) + sk->sk_write_space(sk); + } else if (mtyp != CONN_PROBE_REPLY) { + pr_warn("Received unknown CONN_PROTO msg\n"); + } +exit: + kfree_skb(skb); +} + +/** + * tipc_sendmsg - send message in connectionless manner + * @sock: socket structure + * @m: message to send + * @dsz: amount of user data to be sent + * + * Message must have an destination specified explicitly. + * Used for SOCK_RDM and SOCK_DGRAM messages, + * and for 'SYN' messages on SOCK_SEQPACKET and SOCK_STREAM connections. + * (Note: 'SYN+' is prohibited on SOCK_STREAM.) 
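+ * + * For illustration only, a minimal userspace sketch of this connectionless path, using declarations from sys/socket.h and linux/tipc.h; the service type 18888, instance 17, and the buf/len variables are arbitrary example values assumed to be supplied by the caller: + * + *	int sd = socket(AF_TIPC, SOCK_RDM, 0); + *	struct sockaddr_tipc dst = { .family = AF_TIPC, .addrtype = TIPC_SERVICE_ADDR }; + * + *	dst.addr.name.name.type = 18888; + *	dst.addr.name.name.instance = 17; + *	sendto(sd, buf, len, 0, (struct sockaddr *)&dst, sizeof(dst));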
+ * + * Return: the number of bytes sent on success, or errno otherwise + */ +static int tipc_sendmsg(struct socket *sock, + struct msghdr *m, size_t dsz) +{ + struct sock *sk = sock->sk; + int ret; + + lock_sock(sk); + ret = __tipc_sendmsg(sock, m, dsz); + release_sock(sk); + + return ret; +} + +static int __tipc_sendmsg(struct socket *sock, struct msghdr *m, size_t dlen) +{ + struct sock *sk = sock->sk; + struct net *net = sock_net(sk); + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_uaddr *ua = (struct tipc_uaddr *)m->msg_name; + long timeout = sock_sndtimeo(sk, m->msg_flags & MSG_DONTWAIT); + struct list_head *clinks = &tsk->cong_links; + bool syn = !tipc_sk_type_connectionless(sk); + struct tipc_group *grp = tsk->group; + struct tipc_msg *hdr = &tsk->phdr; + struct tipc_socket_addr skaddr; + struct sk_buff_head pkts; + int atype, mtu, rc; + + if (unlikely(dlen > TIPC_MAX_USER_MSG_SIZE)) + return -EMSGSIZE; + + if (ua) { + if (!tipc_uaddr_valid(ua, m->msg_namelen)) + return -EINVAL; + atype = ua->addrtype; + } + + /* If socket belongs to a communication group follow other paths */ + if (grp) { + if (!ua) + return tipc_send_group_bcast(sock, m, dlen, timeout); + if (atype == TIPC_SERVICE_ADDR) + return tipc_send_group_anycast(sock, m, dlen, timeout); + if (atype == TIPC_SOCKET_ADDR) + return tipc_send_group_unicast(sock, m, dlen, timeout); + if (atype == TIPC_SERVICE_RANGE) + return tipc_send_group_mcast(sock, m, dlen, timeout); + return -EINVAL; + } + + if (!ua) { + ua = (struct tipc_uaddr *)&tsk->peer; + if (!syn && ua->family != AF_TIPC) + return -EDESTADDRREQ; + atype = ua->addrtype; + } + + if (unlikely(syn)) { + if (sk->sk_state == TIPC_LISTEN) + return -EPIPE; + if (sk->sk_state != TIPC_OPEN) + return -EISCONN; + if (tsk->published) + return -EOPNOTSUPP; + if (atype == TIPC_SERVICE_ADDR) + tsk->conn_addrtype = atype; + msg_set_syn(hdr, 1); + } + + memset(&skaddr, 0, sizeof(skaddr)); + + /* Determine destination */ + if (atype == TIPC_SERVICE_RANGE) { + return tipc_sendmcast(sock, ua, m, dlen, timeout); + } else if (atype == TIPC_SERVICE_ADDR) { + skaddr.node = ua->lookup_node; + ua->scope = tipc_node2scope(skaddr.node); + if (!tipc_nametbl_lookup_anycast(net, ua, &skaddr)) + return -EHOSTUNREACH; + } else if (atype == TIPC_SOCKET_ADDR) { + skaddr = ua->sk; + } else { + return -EINVAL; + } + + /* Block or return if destination link is congested */ + rc = tipc_wait_for_cond(sock, &timeout, + !tipc_dest_find(clinks, skaddr.node, 0)); + if (unlikely(rc)) + return rc; + + /* Finally build message header */ + msg_set_destnode(hdr, skaddr.node); + msg_set_destport(hdr, skaddr.ref); + if (atype == TIPC_SERVICE_ADDR) { + msg_set_type(hdr, TIPC_NAMED_MSG); + msg_set_hdr_sz(hdr, NAMED_H_SIZE); + msg_set_nametype(hdr, ua->sa.type); + msg_set_nameinst(hdr, ua->sa.instance); + msg_set_lookup_scope(hdr, ua->scope); + } else { /* TIPC_SOCKET_ADDR */ + msg_set_type(hdr, TIPC_DIRECT_MSG); + msg_set_lookup_scope(hdr, 0); + msg_set_hdr_sz(hdr, BASIC_H_SIZE); + } + + /* Add message body */ + __skb_queue_head_init(&pkts); + mtu = tipc_node_get_mtu(net, skaddr.node, tsk->portid, true); + rc = tipc_msg_build(hdr, m, 0, dlen, mtu, &pkts); + if (unlikely(rc != dlen)) + return rc; + if (unlikely(syn && !tipc_msg_skb_clone(&pkts, &sk->sk_write_queue))) { + __skb_queue_purge(&pkts); + return -ENOMEM; + } + + /* Send message */ + trace_tipc_sk_sendmsg(sk, skb_peek(&pkts), TIPC_DUMP_SK_SNDQ, " "); + rc = tipc_node_xmit(net, &pkts, skaddr.node, tsk->portid); + if (unlikely(rc == -ELINKCONG)) { + 
tipc_dest_push(clinks, skaddr.node, 0); + tsk->cong_link_cnt++; + rc = 0; + } + + if (unlikely(syn && !rc)) { + tipc_set_sk_state(sk, TIPC_CONNECTING); + if (dlen && timeout) { + timeout = msecs_to_jiffies(timeout); + tipc_wait_for_connect(sock, &timeout); + } + } + + return rc ? rc : dlen; +} + +/** + * tipc_sendstream - send stream-oriented data + * @sock: socket structure + * @m: data to send + * @dsz: total length of data to be transmitted + * + * Used for SOCK_STREAM data. + * + * Return: the number of bytes sent on success (or partial success), + * or errno if no data sent + */ +static int tipc_sendstream(struct socket *sock, struct msghdr *m, size_t dsz) +{ + struct sock *sk = sock->sk; + int ret; + + lock_sock(sk); + ret = __tipc_sendstream(sock, m, dsz); + release_sock(sk); + + return ret; +} + +static int __tipc_sendstream(struct socket *sock, struct msghdr *m, size_t dlen) +{ + struct sock *sk = sock->sk; + DECLARE_SOCKADDR(struct sockaddr_tipc *, dest, m->msg_name); + long timeout = sock_sndtimeo(sk, m->msg_flags & MSG_DONTWAIT); + struct sk_buff_head *txq = &sk->sk_write_queue; + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_msg *hdr = &tsk->phdr; + struct net *net = sock_net(sk); + struct sk_buff *skb; + u32 dnode = tsk_peer_node(tsk); + int maxnagle = tsk->maxnagle; + int maxpkt = tsk->max_pkt; + int send, sent = 0; + int blocks, rc = 0; + + if (unlikely(dlen > INT_MAX)) + return -EMSGSIZE; + + /* Handle implicit connection setup */ + if (unlikely(dest && sk->sk_state == TIPC_OPEN)) { + rc = __tipc_sendmsg(sock, m, dlen); + if (dlen && dlen == rc) { + tsk->peer_caps = tipc_node_get_capabilities(net, dnode); + tsk->snt_unacked = tsk_inc(tsk, dlen + msg_hdr_sz(hdr)); + } + return rc; + } + + do { + rc = tipc_wait_for_cond(sock, &timeout, + (!tsk->cong_link_cnt && + !tsk_conn_cong(tsk) && + tipc_sk_connected(sk))); + if (unlikely(rc)) + break; + send = min_t(size_t, dlen - sent, TIPC_MAX_USER_MSG_SIZE); + blocks = tsk->snd_backlog; + if (tsk->oneway++ >= tsk->nagle_start && maxnagle && + send <= maxnagle) { + rc = tipc_msg_append(hdr, m, send, maxnagle, txq); + if (unlikely(rc < 0)) + break; + blocks += rc; + tsk->msg_acc++; + if (blocks <= 64 && tsk->expect_ack) { + tsk->snd_backlog = blocks; + sent += send; + break; + } else if (blocks > 64) { + tsk->pkt_cnt += skb_queue_len(txq); + } else { + skb = skb_peek_tail(txq); + if (skb) { + msg_set_ack_required(buf_msg(skb)); + tsk->expect_ack = true; + } else { + tsk->expect_ack = false; + } + tsk->msg_acc = 0; + tsk->pkt_cnt = 0; + } + } else { + rc = tipc_msg_build(hdr, m, sent, send, maxpkt, txq); + if (unlikely(rc != send)) + break; + blocks += tsk_inc(tsk, send + MIN_H_SIZE); + } + trace_tipc_sk_sendstream(sk, skb_peek(txq), + TIPC_DUMP_SK_SNDQ, " "); + rc = tipc_node_xmit(net, txq, dnode, tsk->portid); + if (unlikely(rc == -ELINKCONG)) { + tsk->cong_link_cnt = 1; + rc = 0; + } + if (likely(!rc)) { + tsk->snt_unacked += blocks; + tsk->snd_backlog = 0; + sent += send; + } + } while (sent < dlen && !rc); + + return sent ? sent : rc; +} + +/** + * tipc_send_packet - send a connection-oriented message + * @sock: socket structure + * @m: message to send + * @dsz: length of data to be transmitted + * + * Used for SOCK_SEQPACKET messages. 
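+ * + * For illustration only, a minimal userspace sketch of the connection-oriented path (again with an arbitrary example service type/instance and caller-supplied buf/len); tipc_sk_finish_conn() below completes the setup on the kernel side once the peer responds: + * + *	int sd = socket(AF_TIPC, SOCK_SEQPACKET, 0); + *	struct sockaddr_tipc srv = { .family = AF_TIPC, .addrtype = TIPC_SERVICE_ADDR }; + * + *	srv.addr.name.name.type = 18888; + *	srv.addr.name.name.instance = 17; + *	if (connect(sd, (struct sockaddr *)&srv, sizeof(srv)) == 0) + *		send(sd, buf, len, 0);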
+ * + * Return: the number of bytes sent on success, or errno otherwise + */ +static int tipc_send_packet(struct socket *sock, struct msghdr *m, size_t dsz) +{ + if (dsz > TIPC_MAX_USER_MSG_SIZE) + return -EMSGSIZE; + + return tipc_sendstream(sock, m, dsz); +} + +/* tipc_sk_finish_conn - complete the setup of a connection + */ +static void tipc_sk_finish_conn(struct tipc_sock *tsk, u32 peer_port, + u32 peer_node) +{ + struct sock *sk = &tsk->sk; + struct net *net = sock_net(sk); + struct tipc_msg *msg = &tsk->phdr; + + msg_set_syn(msg, 0); + msg_set_destnode(msg, peer_node); + msg_set_destport(msg, peer_port); + msg_set_type(msg, TIPC_CONN_MSG); + msg_set_lookup_scope(msg, 0); + msg_set_hdr_sz(msg, SHORT_H_SIZE); + + sk_reset_timer(sk, &sk->sk_timer, jiffies + CONN_PROBING_INTV); + tipc_set_sk_state(sk, TIPC_ESTABLISHED); + tipc_node_add_conn(net, peer_node, tsk->portid, peer_port); + tsk->max_pkt = tipc_node_get_mtu(net, peer_node, tsk->portid, true); + tsk->peer_caps = tipc_node_get_capabilities(net, peer_node); + tsk_set_nagle(tsk); + __skb_queue_purge(&sk->sk_write_queue); + if (tsk->peer_caps & TIPC_BLOCK_FLOWCTL) + return; + + /* Fall back to message based flow control */ + tsk->rcv_win = FLOWCTL_MSG_WIN; + tsk->snd_win = FLOWCTL_MSG_WIN; +} + +/** + * tipc_sk_set_orig_addr - capture sender's address for received message + * @m: descriptor for message info + * @skb: received message + * + * Note: Address is not captured if not requested by receiver. + */ +static void tipc_sk_set_orig_addr(struct msghdr *m, struct sk_buff *skb) +{ + DECLARE_SOCKADDR(struct sockaddr_pair *, srcaddr, m->msg_name); + struct tipc_msg *hdr = buf_msg(skb); + + if (!srcaddr) + return; + + srcaddr->sock.family = AF_TIPC; + srcaddr->sock.addrtype = TIPC_SOCKET_ADDR; + srcaddr->sock.scope = 0; + srcaddr->sock.addr.id.ref = msg_origport(hdr); + srcaddr->sock.addr.id.node = msg_orignode(hdr); + srcaddr->sock.addr.name.domain = 0; + m->msg_namelen = sizeof(struct sockaddr_tipc); + + if (!msg_in_group(hdr)) + return; + + /* Group message users may also want to know sending member's id */ + srcaddr->member.family = AF_TIPC; + srcaddr->member.addrtype = TIPC_SERVICE_ADDR; + srcaddr->member.scope = 0; + srcaddr->member.addr.name.name.type = msg_nametype(hdr); + srcaddr->member.addr.name.name.instance = TIPC_SKB_CB(skb)->orig_member; + srcaddr->member.addr.name.domain = 0; + m->msg_namelen = sizeof(*srcaddr); +} + +/** + * tipc_sk_anc_data_recv - optionally capture ancillary data for received message + * @m: descriptor for message info + * @skb: received message buffer + * @tsk: TIPC port associated with message + * + * Note: Ancillary data is not captured if not requested by receiver. 
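+ *
+ * A receiver opts in by passing a control buffer to recvmsg(). Sketch of
+ * the matching user-space loop (illustrative only; the data iovec is
+ * omitted and handle_errinfo()/handle_destname() stand in for application
+ * code):
+ *
+ *     char cbuf[256];
+ *     struct msghdr m = { .msg_control = cbuf,
+ *                         .msg_controllen = sizeof(cbuf) };
+ *     struct cmsghdr *c;
+ *
+ *     recvmsg(sd, &m, 0);
+ *     for (c = CMSG_FIRSTHDR(&m); c; c = CMSG_NXTHDR(&m, c)) {
+ *             if (c->cmsg_level != SOL_TIPC)
+ *                     continue;
+ *             if (c->cmsg_type == TIPC_ERRINFO)
+ *                     handle_errinfo((__u32 *)CMSG_DATA(c));
+ *             else if (c->cmsg_type == TIPC_DESTNAME)
+ *                     handle_destname((__u32 *)CMSG_DATA(c));
+ *     }
+ *
+ * TIPC_ERRINFO carries two 32-bit words (error code, length of returned
+ * data) and TIPC_DESTNAME three (service type, lower, upper), matching the
+ * put_cmsg() calls below.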
+ * + * Return: 0 if successful, otherwise errno + */ +static int tipc_sk_anc_data_recv(struct msghdr *m, struct sk_buff *skb, + struct tipc_sock *tsk) +{ + struct tipc_msg *hdr; + u32 data[3] = {0,}; + bool has_addr; + int dlen, rc; + + if (likely(m->msg_controllen == 0)) + return 0; + + hdr = buf_msg(skb); + dlen = msg_data_sz(hdr); + + /* Capture errored message object, if any */ + if (msg_errcode(hdr)) { + if (skb_linearize(skb)) + return -ENOMEM; + hdr = buf_msg(skb); + data[0] = msg_errcode(hdr); + data[1] = dlen; + rc = put_cmsg(m, SOL_TIPC, TIPC_ERRINFO, 8, data); + if (rc || !dlen) + return rc; + rc = put_cmsg(m, SOL_TIPC, TIPC_RETDATA, dlen, msg_data(hdr)); + if (rc) + return rc; + } + + /* Capture TIPC_SERVICE_ADDR/RANGE destination address, if any */ + switch (msg_type(hdr)) { + case TIPC_NAMED_MSG: + has_addr = true; + data[0] = msg_nametype(hdr); + data[1] = msg_namelower(hdr); + data[2] = data[1]; + break; + case TIPC_MCAST_MSG: + has_addr = true; + data[0] = msg_nametype(hdr); + data[1] = msg_namelower(hdr); + data[2] = msg_nameupper(hdr); + break; + case TIPC_CONN_MSG: + has_addr = !!tsk->conn_addrtype; + data[0] = msg_nametype(&tsk->phdr); + data[1] = msg_nameinst(&tsk->phdr); + data[2] = data[1]; + break; + default: + has_addr = false; + } + if (!has_addr) + return 0; + return put_cmsg(m, SOL_TIPC, TIPC_DESTNAME, 12, data); +} + +static struct sk_buff *tipc_sk_build_ack(struct tipc_sock *tsk) +{ + struct sock *sk = &tsk->sk; + struct sk_buff *skb = NULL; + struct tipc_msg *msg; + u32 peer_port = tsk_peer_port(tsk); + u32 dnode = tsk_peer_node(tsk); + + if (!tipc_sk_connected(sk)) + return NULL; + skb = tipc_msg_create(CONN_MANAGER, CONN_ACK, INT_H_SIZE, 0, + dnode, tsk_own_node(tsk), peer_port, + tsk->portid, TIPC_OK); + if (!skb) + return NULL; + msg = buf_msg(skb); + msg_set_conn_ack(msg, tsk->rcv_unacked); + tsk->rcv_unacked = 0; + + /* Adjust to and advertize the correct window limit */ + if (tsk->peer_caps & TIPC_BLOCK_FLOWCTL) { + tsk->rcv_win = tsk_adv_blocks(tsk->sk.sk_rcvbuf); + msg_set_adv_win(msg, tsk->rcv_win); + } + return skb; +} + +static void tipc_sk_send_ack(struct tipc_sock *tsk) +{ + struct sk_buff *skb; + + skb = tipc_sk_build_ack(tsk); + if (!skb) + return; + + tipc_node_xmit_skb(sock_net(&tsk->sk), skb, tsk_peer_node(tsk), + msg_link_selector(buf_msg(skb))); +} + +static int tipc_wait_for_rcvmsg(struct socket *sock, long *timeop) +{ + struct sock *sk = sock->sk; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + long timeo = *timeop; + int err = sock_error(sk); + + if (err) + return err; + + for (;;) { + if (timeo && skb_queue_empty(&sk->sk_receive_queue)) { + if (sk->sk_shutdown & RCV_SHUTDOWN) { + err = -ENOTCONN; + break; + } + add_wait_queue(sk_sleep(sk), &wait); + release_sock(sk); + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + sched_annotate_sleep(); + lock_sock(sk); + remove_wait_queue(sk_sleep(sk), &wait); + } + err = 0; + if (!skb_queue_empty(&sk->sk_receive_queue)) + break; + err = -EAGAIN; + if (!timeo) + break; + err = sock_intr_errno(timeo); + if (signal_pending(current)) + break; + + err = sock_error(sk); + if (err) + break; + } + *timeop = timeo; + return err; +} + +/** + * tipc_recvmsg - receive packet-oriented message + * @sock: network socket + * @m: descriptor for message info + * @buflen: length of user buffer area + * @flags: receive flags + * + * Used for SOCK_DGRAM, SOCK_RDM, and SOCK_SEQPACKET messages. + * If the complete message doesn't fit in user area, truncate it. 
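+ *
+ * Typical user-space call for a connectionless socket (sketch; the buffer
+ * size is arbitrary and error handling is omitted):
+ *
+ *     struct sockaddr_tipc from;
+ *     socklen_t alen = sizeof(from);
+ *     char buf[1024];
+ *     int n = recvfrom(sd, buf, sizeof(buf), 0,
+ *                      (struct sockaddr *)&from, &alen);
+ *
+ * On return, 'from' holds the sender's socket address as filled in by
+ * tipc_sk_set_orig_addr() above.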
+ * + * Return: size of returned message data, errno otherwise + */ +static int tipc_recvmsg(struct socket *sock, struct msghdr *m, + size_t buflen, int flags) +{ + struct sock *sk = sock->sk; + bool connected = !tipc_sk_type_connectionless(sk); + struct tipc_sock *tsk = tipc_sk(sk); + int rc, err, hlen, dlen, copy; + struct tipc_skb_cb *skb_cb; + struct sk_buff_head xmitq; + struct tipc_msg *hdr; + struct sk_buff *skb; + bool grp_evt; + long timeout; + + /* Catch invalid receive requests */ + if (unlikely(!buflen)) + return -EINVAL; + + lock_sock(sk); + if (unlikely(connected && sk->sk_state == TIPC_OPEN)) { + rc = -ENOTCONN; + goto exit; + } + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + /* Step rcv queue to first msg with data or error; wait if necessary */ + do { + rc = tipc_wait_for_rcvmsg(sock, &timeout); + if (unlikely(rc)) + goto exit; + skb = skb_peek(&sk->sk_receive_queue); + skb_cb = TIPC_SKB_CB(skb); + hdr = buf_msg(skb); + dlen = msg_data_sz(hdr); + hlen = msg_hdr_sz(hdr); + err = msg_errcode(hdr); + grp_evt = msg_is_grp_evt(hdr); + if (likely(dlen || err)) + break; + tsk_advance_rx_queue(sk); + } while (1); + + /* Collect msg meta data, including error code and rejected data */ + tipc_sk_set_orig_addr(m, skb); + rc = tipc_sk_anc_data_recv(m, skb, tsk); + if (unlikely(rc)) + goto exit; + hdr = buf_msg(skb); + + /* Capture data if non-error msg, otherwise just set return value */ + if (likely(!err)) { + int offset = skb_cb->bytes_read; + + copy = min_t(int, dlen - offset, buflen); + rc = skb_copy_datagram_msg(skb, hlen + offset, m, copy); + if (unlikely(rc)) + goto exit; + if (unlikely(offset + copy < dlen)) { + if (flags & MSG_EOR) { + if (!(flags & MSG_PEEK)) + skb_cb->bytes_read = offset + copy; + } else { + m->msg_flags |= MSG_TRUNC; + skb_cb->bytes_read = 0; + } + } else { + if (flags & MSG_EOR) + m->msg_flags |= MSG_EOR; + skb_cb->bytes_read = 0; + } + } else { + copy = 0; + rc = 0; + if (err != TIPC_CONN_SHUTDOWN && connected && !m->msg_control) { + rc = -ECONNRESET; + goto exit; + } + } + + /* Mark message as group event if applicable */ + if (unlikely(grp_evt)) { + if (msg_grp_evt(hdr) == TIPC_WITHDRAWN) + m->msg_flags |= MSG_EOR; + m->msg_flags |= MSG_OOB; + copy = 0; + } + + /* Caption of data or error code/rejected data was successful */ + if (unlikely(flags & MSG_PEEK)) + goto exit; + + /* Send group flow control advertisement when applicable */ + if (tsk->group && msg_in_group(hdr) && !grp_evt) { + __skb_queue_head_init(&xmitq); + tipc_group_update_rcv_win(tsk->group, tsk_blocks(hlen + dlen), + msg_orignode(hdr), msg_origport(hdr), + &xmitq); + tipc_node_distr_xmit(sock_net(sk), &xmitq); + } + + if (skb_cb->bytes_read) + goto exit; + + tsk_advance_rx_queue(sk); + + if (likely(!connected)) + goto exit; + + /* Send connection flow control advertisement when applicable */ + tsk->rcv_unacked += tsk_inc(tsk, hlen + dlen); + if (tsk->rcv_unacked >= tsk->rcv_win / TIPC_ACK_RATE) + tipc_sk_send_ack(tsk); +exit: + release_sock(sk); + return rc ? rc : copy; +} + +/** + * tipc_recvstream - receive stream-oriented data + * @sock: network socket + * @m: descriptor for message info + * @buflen: total size of user buffer area + * @flags: receive flags + * + * Used for SOCK_STREAM messages only. If not enough data is available + * will optionally wait for more; never truncates data. 
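+ *
+ * Since a stream read may return fewer bytes than requested, user space
+ * either loops or uses MSG_WAITALL/SO_RCVLOWAT, which this function honours
+ * via sock_rcvlowat(). Sketch (illustrative, error handling omitted):
+ *
+ *     char buf[256];
+ *     int want = sizeof(buf), got = 0;
+ *
+ *     while (got < want) {
+ *             int n = recv(sd, buf + got, want - got, 0);
+ *             if (n <= 0)
+ *                     break;
+ *             got += n;
+ *     }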
+ * + * Return: size of returned message data, errno otherwise + */ +static int tipc_recvstream(struct socket *sock, struct msghdr *m, + size_t buflen, int flags) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct sk_buff *skb; + struct tipc_msg *hdr; + struct tipc_skb_cb *skb_cb; + bool peek = flags & MSG_PEEK; + int offset, required, copy, copied = 0; + int hlen, dlen, err, rc; + long timeout; + + /* Catch invalid receive attempts */ + if (unlikely(!buflen)) + return -EINVAL; + + lock_sock(sk); + + if (unlikely(sk->sk_state == TIPC_OPEN)) { + rc = -ENOTCONN; + goto exit; + } + required = sock_rcvlowat(sk, flags & MSG_WAITALL, buflen); + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + do { + /* Look at first msg in receive queue; wait if necessary */ + rc = tipc_wait_for_rcvmsg(sock, &timeout); + if (unlikely(rc)) + break; + skb = skb_peek(&sk->sk_receive_queue); + skb_cb = TIPC_SKB_CB(skb); + hdr = buf_msg(skb); + dlen = msg_data_sz(hdr); + hlen = msg_hdr_sz(hdr); + err = msg_errcode(hdr); + + /* Discard any empty non-errored (SYN-) message */ + if (unlikely(!dlen && !err)) { + tsk_advance_rx_queue(sk); + continue; + } + + /* Collect msg meta data, incl. error code and rejected data */ + if (!copied) { + tipc_sk_set_orig_addr(m, skb); + rc = tipc_sk_anc_data_recv(m, skb, tsk); + if (rc) + break; + hdr = buf_msg(skb); + } + + /* Copy data if msg ok, otherwise return error/partial data */ + if (likely(!err)) { + offset = skb_cb->bytes_read; + copy = min_t(int, dlen - offset, buflen - copied); + rc = skb_copy_datagram_msg(skb, hlen + offset, m, copy); + if (unlikely(rc)) + break; + copied += copy; + offset += copy; + if (unlikely(offset < dlen)) { + if (!peek) + skb_cb->bytes_read = offset; + break; + } + } else { + rc = 0; + if ((err != TIPC_CONN_SHUTDOWN) && !m->msg_control) + rc = -ECONNRESET; + if (copied || rc) + break; + } + + if (unlikely(peek)) + break; + + tsk_advance_rx_queue(sk); + + /* Send connection flow control advertisement when applicable */ + tsk->rcv_unacked += tsk_inc(tsk, hlen + dlen); + if (tsk->rcv_unacked >= tsk->rcv_win / TIPC_ACK_RATE) + tipc_sk_send_ack(tsk); + + /* Exit if all requested data or FIN/error received */ + if (copied == buflen || err) + break; + + } while (!skb_queue_empty(&sk->sk_receive_queue) || copied < required); +exit: + release_sock(sk); + return copied ? 
copied : rc; +} + +/** + * tipc_write_space - wake up thread if port congestion is released + * @sk: socket + */ +static void tipc_write_space(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLOUT | + EPOLLWRNORM | EPOLLWRBAND); + rcu_read_unlock(); +} + +/** + * tipc_data_ready - wake up threads to indicate messages have been received + * @sk: socket + */ +static void tipc_data_ready(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, EPOLLIN | + EPOLLRDNORM | EPOLLRDBAND); + rcu_read_unlock(); +} + +static void tipc_sock_destruct(struct sock *sk) +{ + __skb_queue_purge(&sk->sk_receive_queue); +} + +static void tipc_sk_proto_rcv(struct sock *sk, + struct sk_buff_head *inputq, + struct sk_buff_head *xmitq) +{ + struct sk_buff *skb = __skb_dequeue(inputq); + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_msg *hdr = buf_msg(skb); + struct tipc_group *grp = tsk->group; + bool wakeup = false; + + switch (msg_user(hdr)) { + case CONN_MANAGER: + tipc_sk_conn_proto_rcv(tsk, skb, inputq, xmitq); + return; + case SOCK_WAKEUP: + tipc_dest_del(&tsk->cong_links, msg_orignode(hdr), 0); + /* coupled with smp_rmb() in tipc_wait_for_cond() */ + smp_wmb(); + tsk->cong_link_cnt--; + wakeup = true; + tipc_sk_push_backlog(tsk, false); + break; + case GROUP_PROTOCOL: + tipc_group_proto_rcv(grp, &wakeup, hdr, inputq, xmitq); + break; + case TOP_SRV: + tipc_group_member_evt(tsk->group, &wakeup, &sk->sk_rcvbuf, + hdr, inputq, xmitq); + break; + default: + break; + } + + if (wakeup) + sk->sk_write_space(sk); + + kfree_skb(skb); +} + +/** + * tipc_sk_filter_connect - check incoming message for a connection-based socket + * @tsk: TIPC socket + * @skb: pointer to message buffer. 
+ * @xmitq: for Nagle ACK if any + * Return: true if message should be added to receive queue, false otherwise + */ +static bool tipc_sk_filter_connect(struct tipc_sock *tsk, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + struct sock *sk = &tsk->sk; + struct net *net = sock_net(sk); + struct tipc_msg *hdr = buf_msg(skb); + bool con_msg = msg_connected(hdr); + u32 pport = tsk_peer_port(tsk); + u32 pnode = tsk_peer_node(tsk); + u32 oport = msg_origport(hdr); + u32 onode = msg_orignode(hdr); + int err = msg_errcode(hdr); + unsigned long delay; + + if (unlikely(msg_mcast(hdr))) + return false; + tsk->oneway = 0; + + switch (sk->sk_state) { + case TIPC_CONNECTING: + /* Setup ACK */ + if (likely(con_msg)) { + if (err) + break; + tipc_sk_finish_conn(tsk, oport, onode); + msg_set_importance(&tsk->phdr, msg_importance(hdr)); + /* ACK+ message with data is added to receive queue */ + if (msg_data_sz(hdr)) + return true; + /* Empty ACK-, - wake up sleeping connect() and drop */ + sk->sk_state_change(sk); + msg_set_dest_droppable(hdr, 1); + return false; + } + /* Ignore connectionless message if not from listening socket */ + if (oport != pport || onode != pnode) + return false; + + /* Rejected SYN */ + if (err != TIPC_ERR_OVERLOAD) + break; + + /* Prepare for new setup attempt if we have a SYN clone */ + if (skb_queue_empty(&sk->sk_write_queue)) + break; + get_random_bytes(&delay, 2); + delay %= (tsk->conn_timeout / 4); + delay = msecs_to_jiffies(delay + 100); + sk_reset_timer(sk, &sk->sk_timer, jiffies + delay); + return false; + case TIPC_OPEN: + case TIPC_DISCONNECTING: + return false; + case TIPC_LISTEN: + /* Accept only SYN message */ + if (!msg_is_syn(hdr) && + tipc_node_get_capabilities(net, onode) & TIPC_SYN_BIT) + return false; + if (!con_msg && !err) + return true; + return false; + case TIPC_ESTABLISHED: + if (!skb_queue_empty(&sk->sk_write_queue)) + tipc_sk_push_backlog(tsk, false); + /* Accept only connection-based messages sent by peer */ + if (likely(con_msg && !err && pport == oport && + pnode == onode)) { + if (msg_ack_required(hdr)) { + struct sk_buff *skb; + + skb = tipc_sk_build_ack(tsk); + if (skb) { + msg_set_nagle_ack(buf_msg(skb)); + __skb_queue_tail(xmitq, skb); + } + } + return true; + } + if (!tsk_peer_msg(tsk, hdr)) + return false; + if (!err) + return true; + tipc_set_sk_state(sk, TIPC_DISCONNECTING); + tipc_node_remove_conn(net, pnode, tsk->portid); + sk->sk_state_change(sk); + return true; + default: + pr_err("Unknown sk_state %u\n", sk->sk_state); + } + /* Abort connection setup attempt */ + tipc_set_sk_state(sk, TIPC_DISCONNECTING); + sk->sk_err = ECONNREFUSED; + sk->sk_state_change(sk); + return true; +} + +/** + * rcvbuf_limit - get proper overload limit of socket receive queue + * @sk: socket + * @skb: message + * + * For connection oriented messages, irrespective of importance, + * default queue limit is 2 MB. 
+ * + * For connectionless messages, queue limits are based on message + * importance as follows: + * + * TIPC_LOW_IMPORTANCE (2 MB) + * TIPC_MEDIUM_IMPORTANCE (4 MB) + * TIPC_HIGH_IMPORTANCE (8 MB) + * TIPC_CRITICAL_IMPORTANCE (16 MB) + * + * Return: overload limit according to corresponding message importance + */ +static unsigned int rcvbuf_limit(struct sock *sk, struct sk_buff *skb) +{ + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_msg *hdr = buf_msg(skb); + + if (unlikely(msg_in_group(hdr))) + return READ_ONCE(sk->sk_rcvbuf); + + if (unlikely(!msg_connected(hdr))) + return READ_ONCE(sk->sk_rcvbuf) << msg_importance(hdr); + + if (likely(tsk->peer_caps & TIPC_BLOCK_FLOWCTL)) + return READ_ONCE(sk->sk_rcvbuf); + + return FLOWCTL_MSG_LIM; +} + +/** + * tipc_sk_filter_rcv - validate incoming message + * @sk: socket + * @skb: pointer to message. + * @xmitq: output message area (FIXME) + * + * Enqueues message on receive queue if acceptable; optionally handles + * disconnect indication for a connected socket. + * + * Called with socket lock already taken + */ +static void tipc_sk_filter_rcv(struct sock *sk, struct sk_buff *skb, + struct sk_buff_head *xmitq) +{ + bool sk_conn = !tipc_sk_type_connectionless(sk); + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_group *grp = tsk->group; + struct tipc_msg *hdr = buf_msg(skb); + struct net *net = sock_net(sk); + struct sk_buff_head inputq; + int mtyp = msg_type(hdr); + int limit, err = TIPC_OK; + + trace_tipc_sk_filter_rcv(sk, skb, TIPC_DUMP_ALL, " "); + TIPC_SKB_CB(skb)->bytes_read = 0; + __skb_queue_head_init(&inputq); + __skb_queue_tail(&inputq, skb); + + if (unlikely(!msg_isdata(hdr))) + tipc_sk_proto_rcv(sk, &inputq, xmitq); + + if (unlikely(grp)) + tipc_group_filter_msg(grp, &inputq, xmitq); + + if (unlikely(!grp) && mtyp == TIPC_MCAST_MSG) + tipc_mcast_filter_msg(net, &tsk->mc_method.deferredq, &inputq); + + /* Validate and add to receive buffer if there is space */ + while ((skb = __skb_dequeue(&inputq))) { + hdr = buf_msg(skb); + limit = rcvbuf_limit(sk, skb); + if ((sk_conn && !tipc_sk_filter_connect(tsk, skb, xmitq)) || + (!sk_conn && msg_connected(hdr)) || + (!grp && msg_in_group(hdr))) + err = TIPC_ERR_NO_PORT; + else if (sk_rmem_alloc_get(sk) + skb->truesize >= limit) { + trace_tipc_sk_dump(sk, skb, TIPC_DUMP_ALL, + "err_overload2!"); + atomic_inc(&sk->sk_drops); + err = TIPC_ERR_OVERLOAD; + } + + if (unlikely(err)) { + if (tipc_msg_reverse(tipc_own_addr(net), &skb, err)) { + trace_tipc_sk_rej_msg(sk, skb, TIPC_DUMP_NONE, + "@filter_rcv!"); + __skb_queue_tail(xmitq, skb); + } + err = TIPC_OK; + continue; + } + __skb_queue_tail(&sk->sk_receive_queue, skb); + skb_set_owner_r(skb, sk); + trace_tipc_sk_overlimit2(sk, skb, TIPC_DUMP_ALL, + "rcvq >90% allocated!"); + sk->sk_data_ready(sk); + } +} + +/** + * tipc_sk_backlog_rcv - handle incoming message from backlog queue + * @sk: socket + * @skb: message + * + * Caller must hold socket lock + */ +static int tipc_sk_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + unsigned int before = sk_rmem_alloc_get(sk); + struct sk_buff_head xmitq; + unsigned int added; + + __skb_queue_head_init(&xmitq); + + tipc_sk_filter_rcv(sk, skb, &xmitq); + added = sk_rmem_alloc_get(sk) - before; + atomic_add(added, &tipc_sk(sk)->dupl_rcvcnt); + + /* Send pending response/rejected messages, if any */ + tipc_node_distr_xmit(sock_net(sk), &xmitq); + return 0; +} + +/** + * tipc_sk_enqueue - extract all buffers with destination 'dport' from + * inputq and try adding them to socket or backlog queue + 
* @inputq: list of incoming buffers with potentially different destinations + * @sk: socket where the buffers should be enqueued + * @dport: port number for the socket + * @xmitq: output queue + * + * Caller must hold socket lock + */ +static void tipc_sk_enqueue(struct sk_buff_head *inputq, struct sock *sk, + u32 dport, struct sk_buff_head *xmitq) +{ + unsigned long time_limit = jiffies + usecs_to_jiffies(20000); + struct sk_buff *skb; + unsigned int lim; + atomic_t *dcnt; + u32 onode; + + while (skb_queue_len(inputq)) { + if (unlikely(time_after_eq(jiffies, time_limit))) + return; + + skb = tipc_skb_dequeue(inputq, dport); + if (unlikely(!skb)) + return; + + /* Add message directly to receive queue if possible */ + if (!sock_owned_by_user(sk)) { + tipc_sk_filter_rcv(sk, skb, xmitq); + continue; + } + + /* Try backlog, compensating for double-counted bytes */ + dcnt = &tipc_sk(sk)->dupl_rcvcnt; + if (!sk->sk_backlog.len) + atomic_set(dcnt, 0); + lim = rcvbuf_limit(sk, skb) + atomic_read(dcnt); + if (likely(!sk_add_backlog(sk, skb, lim))) { + trace_tipc_sk_overlimit1(sk, skb, TIPC_DUMP_ALL, + "bklg & rcvq >90% allocated!"); + continue; + } + + trace_tipc_sk_dump(sk, skb, TIPC_DUMP_ALL, "err_overload!"); + /* Overload => reject message back to sender */ + onode = tipc_own_addr(sock_net(sk)); + atomic_inc(&sk->sk_drops); + if (tipc_msg_reverse(onode, &skb, TIPC_ERR_OVERLOAD)) { + trace_tipc_sk_rej_msg(sk, skb, TIPC_DUMP_ALL, + "@sk_enqueue!"); + __skb_queue_tail(xmitq, skb); + } + break; + } +} + +/** + * tipc_sk_rcv - handle a chain of incoming buffers + * @net: the associated network namespace + * @inputq: buffer list containing the buffers + * Consumes all buffers in list until inputq is empty + * Note: may be called in multiple threads referring to the same queue + */ +void tipc_sk_rcv(struct net *net, struct sk_buff_head *inputq) +{ + struct sk_buff_head xmitq; + u32 dnode, dport = 0; + int err; + struct tipc_sock *tsk; + struct sock *sk; + struct sk_buff *skb; + + __skb_queue_head_init(&xmitq); + while (skb_queue_len(inputq)) { + dport = tipc_skb_peek_port(inputq, dport); + tsk = tipc_sk_lookup(net, dport); + + if (likely(tsk)) { + sk = &tsk->sk; + if (likely(spin_trylock_bh(&sk->sk_lock.slock))) { + tipc_sk_enqueue(inputq, sk, dport, &xmitq); + spin_unlock_bh(&sk->sk_lock.slock); + } + /* Send pending response/rejected messages, if any */ + tipc_node_distr_xmit(sock_net(sk), &xmitq); + sock_put(sk); + continue; + } + /* No destination socket => dequeue skb if still there */ + skb = tipc_skb_dequeue(inputq, dport); + if (!skb) + return; + + /* Try secondary lookup if unresolved named message */ + err = TIPC_ERR_NO_PORT; + if (tipc_msg_lookup_dest(net, skb, &err)) + goto xmit; + + /* Prepare for message rejection */ + if (!tipc_msg_reverse(tipc_own_addr(net), &skb, err)) + continue; + + trace_tipc_sk_rej_msg(NULL, skb, TIPC_DUMP_NONE, "@sk_rcv!"); +xmit: + dnode = msg_destnode(buf_msg(skb)); + tipc_node_xmit_skb(net, skb, dnode, dport); + } +} + +static int tipc_wait_for_connect(struct socket *sock, long *timeo_p) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct sock *sk = sock->sk; + int done; + + do { + int err = sock_error(sk); + if (err) + return err; + if (!*timeo_p) + return -ETIMEDOUT; + if (signal_pending(current)) + return sock_intr_errno(*timeo_p); + if (sk->sk_state == TIPC_DISCONNECTING) + break; + + add_wait_queue(sk_sleep(sk), &wait); + done = sk_wait_event(sk, timeo_p, tipc_sk_connected(sk), + &wait); + remove_wait_queue(sk_sleep(sk), &wait); + } while (!done); 
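+	/* Getting here with no error means the socket either became connected
+	 * or moved to TIPC_DISCONNECTING without sk_err being set; a rejected
+	 * connect is normally reported by the sock_error() check above.
+	 */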
+ return 0; +} + +static bool tipc_sockaddr_is_sane(struct sockaddr_tipc *addr) +{ + if (addr->family != AF_TIPC) + return false; + if (addr->addrtype == TIPC_SERVICE_RANGE) + return (addr->addr.nameseq.lower <= addr->addr.nameseq.upper); + return (addr->addrtype == TIPC_SERVICE_ADDR || + addr->addrtype == TIPC_SOCKET_ADDR); +} + +/** + * tipc_connect - establish a connection to another TIPC port + * @sock: socket structure + * @dest: socket address for destination port + * @destlen: size of socket address data structure + * @flags: file-related flags associated with socket + * + * Return: 0 on success, errno otherwise + */ +static int tipc_connect(struct socket *sock, struct sockaddr *dest, + int destlen, int flags) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct sockaddr_tipc *dst = (struct sockaddr_tipc *)dest; + struct msghdr m = {NULL,}; + long timeout = (flags & O_NONBLOCK) ? 0 : tsk->conn_timeout; + int previous; + int res = 0; + + if (destlen != sizeof(struct sockaddr_tipc)) + return -EINVAL; + + lock_sock(sk); + + if (tsk->group) { + res = -EINVAL; + goto exit; + } + + if (dst->family == AF_UNSPEC) { + memset(&tsk->peer, 0, sizeof(struct sockaddr_tipc)); + if (!tipc_sk_type_connectionless(sk)) + res = -EINVAL; + goto exit; + } + if (!tipc_sockaddr_is_sane(dst)) { + res = -EINVAL; + goto exit; + } + /* DGRAM/RDM connect(), just save the destaddr */ + if (tipc_sk_type_connectionless(sk)) { + memcpy(&tsk->peer, dest, destlen); + goto exit; + } else if (dst->addrtype == TIPC_SERVICE_RANGE) { + res = -EINVAL; + goto exit; + } + + previous = sk->sk_state; + + switch (sk->sk_state) { + case TIPC_OPEN: + /* Send a 'SYN-' to destination */ + m.msg_name = dest; + m.msg_namelen = destlen; + iov_iter_kvec(&m.msg_iter, ITER_SOURCE, NULL, 0, 0); + + /* If connect is in non-blocking case, set MSG_DONTWAIT to + * indicate send_msg() is never blocked. + */ + if (!timeout) + m.msg_flags = MSG_DONTWAIT; + + res = __tipc_sendmsg(sock, &m, 0); + if ((res < 0) && (res != -EWOULDBLOCK)) + goto exit; + + /* Just entered TIPC_CONNECTING state; the only + * difference is that return value in non-blocking + * case is EINPROGRESS, rather than EALREADY. + */ + res = -EINPROGRESS; + fallthrough; + case TIPC_CONNECTING: + if (!timeout) { + if (previous == TIPC_CONNECTING) + res = -EALREADY; + goto exit; + } + timeout = msecs_to_jiffies(timeout); + /* Wait until an 'ACK' or 'RST' arrives, or a timeout occurs */ + res = tipc_wait_for_connect(sock, &timeout); + break; + case TIPC_ESTABLISHED: + res = -EISCONN; + break; + default: + res = -EINVAL; + } + +exit: + release_sock(sk); + return res; +} + +/** + * tipc_listen - allow socket to listen for incoming connections + * @sock: socket structure + * @len: (unused) + * + * Return: 0 on success, errno otherwise + */ +static int tipc_listen(struct socket *sock, int len) +{ + struct sock *sk = sock->sk; + int res; + + lock_sock(sk); + res = tipc_set_sk_state(sk, TIPC_LISTEN); + release_sock(sk); + + return res; +} + +static int tipc_wait_for_accept(struct socket *sock, long timeo) +{ + struct sock *sk = sock->sk; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int err; + + /* True wake-one mechanism for incoming connections: only + * one process gets woken up, not the 'whole herd'. + * Since we do not 'race & poll' for established sockets + * anymore, the common case will execute the loop only once. 
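+	 *
+	 * A matching user-space server sketch (illustrative only; the bind
+	 * address values are arbitrary examples and error handling is
+	 * omitted):
+	 *
+	 *     int lsd = socket(AF_TIPC, SOCK_SEQPACKET, 0);
+	 *     struct sockaddr_tipc srv = {
+	 *             .family = AF_TIPC,
+	 *             .addrtype = TIPC_SERVICE_ADDR,
+	 *             .scope = TIPC_CLUSTER_SCOPE,
+	 *             .addr.name.name.type = 18888,
+	 *             .addr.name.name.instance = 17,
+	 *     };
+	 *
+	 *     bind(lsd, (struct sockaddr *)&srv, sizeof(srv));
+	 *     listen(lsd, 0);
+	 *     int csd = accept(lsd, NULL, NULL);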
+ */ + for (;;) { + if (timeo && skb_queue_empty(&sk->sk_receive_queue)) { + add_wait_queue(sk_sleep(sk), &wait); + release_sock(sk); + timeo = wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + lock_sock(sk); + remove_wait_queue(sk_sleep(sk), &wait); + } + err = 0; + if (!skb_queue_empty(&sk->sk_receive_queue)) + break; + err = -EAGAIN; + if (!timeo) + break; + err = sock_intr_errno(timeo); + if (signal_pending(current)) + break; + } + return err; +} + +/** + * tipc_accept - wait for connection request + * @sock: listening socket + * @new_sock: new socket that is to be connected + * @flags: file-related flags associated with socket + * @kern: caused by kernel or by userspace? + * + * Return: 0 on success, errno otherwise + */ +static int tipc_accept(struct socket *sock, struct socket *new_sock, int flags, + bool kern) +{ + struct sock *new_sk, *sk = sock->sk; + struct tipc_sock *new_tsock; + struct msghdr m = {NULL,}; + struct tipc_msg *msg; + struct sk_buff *buf; + long timeo; + int res; + + lock_sock(sk); + + if (sk->sk_state != TIPC_LISTEN) { + res = -EINVAL; + goto exit; + } + timeo = sock_rcvtimeo(sk, flags & O_NONBLOCK); + res = tipc_wait_for_accept(sock, timeo); + if (res) + goto exit; + + buf = skb_peek(&sk->sk_receive_queue); + + res = tipc_sk_create(sock_net(sock->sk), new_sock, 0, kern); + if (res) + goto exit; + security_sk_clone(sock->sk, new_sock->sk); + + new_sk = new_sock->sk; + new_tsock = tipc_sk(new_sk); + msg = buf_msg(buf); + + /* we lock on new_sk; but lockdep sees the lock on sk */ + lock_sock_nested(new_sk, SINGLE_DEPTH_NESTING); + + /* + * Reject any stray messages received by new socket + * before the socket lock was taken (very, very unlikely) + */ + tsk_rej_rx_queue(new_sk, TIPC_ERR_NO_PORT); + + /* Connect new socket to it's peer */ + tipc_sk_finish_conn(new_tsock, msg_origport(msg), msg_orignode(msg)); + + tsk_set_importance(new_sk, msg_importance(msg)); + if (msg_named(msg)) { + new_tsock->conn_addrtype = TIPC_SERVICE_ADDR; + msg_set_nametype(&new_tsock->phdr, msg_nametype(msg)); + msg_set_nameinst(&new_tsock->phdr, msg_nameinst(msg)); + } + + /* + * Respond to 'SYN-' by discarding it & returning 'ACK'. + * Respond to 'SYN+' by queuing it on new socket & returning 'ACK'. + */ + if (!msg_data_sz(msg)) { + tsk_advance_rx_queue(sk); + } else { + __skb_dequeue(&sk->sk_receive_queue); + __skb_queue_head(&new_sk->sk_receive_queue, buf); + skb_set_owner_r(buf, new_sk); + } + iov_iter_kvec(&m.msg_iter, ITER_SOURCE, NULL, 0, 0); + __tipc_sendstream(new_sock, &m, 0); + release_sock(new_sk); +exit: + release_sock(sk); + return res; +} + +/** + * tipc_shutdown - shutdown socket connection + * @sock: socket structure + * @how: direction to close (must be SHUT_RDWR) + * + * Terminates connection (if necessary), then purges socket's receive queue. + * + * Return: 0 on success, errno otherwise + */ +static int tipc_shutdown(struct socket *sock, int how) +{ + struct sock *sk = sock->sk; + int res; + + if (how != SHUT_RDWR) + return -EINVAL; + + lock_sock(sk); + + trace_tipc_sk_shutdown(sk, NULL, TIPC_DUMP_ALL, " "); + __tipc_shutdown(sock, TIPC_CONN_SHUTDOWN); + sk->sk_shutdown = SHUTDOWN_MASK; + + if (sk->sk_state == TIPC_DISCONNECTING) { + /* Discard any unreceived messages */ + __skb_queue_purge(&sk->sk_receive_queue); + + res = 0; + } else { + res = -ENOTCONN; + } + /* Wake up anyone sleeping in poll. 
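+	 * Note that TIPC only supports SHUT_RDWR (anything else was rejected
+	 * with -EINVAL above), so both directions are now closed and any
+	 * blocked reader or writer must be woken.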
*/ + sk->sk_state_change(sk); + + release_sock(sk); + return res; +} + +static void tipc_sk_check_probing_state(struct sock *sk, + struct sk_buff_head *list) +{ + struct tipc_sock *tsk = tipc_sk(sk); + u32 pnode = tsk_peer_node(tsk); + u32 pport = tsk_peer_port(tsk); + u32 self = tsk_own_node(tsk); + u32 oport = tsk->portid; + struct sk_buff *skb; + + if (tsk->probe_unacked) { + tipc_set_sk_state(sk, TIPC_DISCONNECTING); + sk->sk_err = ECONNABORTED; + tipc_node_remove_conn(sock_net(sk), pnode, pport); + sk->sk_state_change(sk); + return; + } + /* Prepare new probe */ + skb = tipc_msg_create(CONN_MANAGER, CONN_PROBE, INT_H_SIZE, 0, + pnode, self, pport, oport, TIPC_OK); + if (skb) + __skb_queue_tail(list, skb); + tsk->probe_unacked = true; + sk_reset_timer(sk, &sk->sk_timer, jiffies + CONN_PROBING_INTV); +} + +static void tipc_sk_retry_connect(struct sock *sk, struct sk_buff_head *list) +{ + struct tipc_sock *tsk = tipc_sk(sk); + + /* Try again later if dest link is congested */ + if (tsk->cong_link_cnt) { + sk_reset_timer(sk, &sk->sk_timer, + jiffies + msecs_to_jiffies(100)); + return; + } + /* Prepare SYN for retransmit */ + tipc_msg_skb_clone(&sk->sk_write_queue, list); +} + +static void tipc_sk_timeout(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + struct tipc_sock *tsk = tipc_sk(sk); + u32 pnode = tsk_peer_node(tsk); + struct sk_buff_head list; + int rc = 0; + + __skb_queue_head_init(&list); + bh_lock_sock(sk); + + /* Try again later if socket is busy */ + if (sock_owned_by_user(sk)) { + sk_reset_timer(sk, &sk->sk_timer, jiffies + HZ / 20); + bh_unlock_sock(sk); + sock_put(sk); + return; + } + + if (sk->sk_state == TIPC_ESTABLISHED) + tipc_sk_check_probing_state(sk, &list); + else if (sk->sk_state == TIPC_CONNECTING) + tipc_sk_retry_connect(sk, &list); + + bh_unlock_sock(sk); + + if (!skb_queue_empty(&list)) + rc = tipc_node_xmit(sock_net(sk), &list, pnode, tsk->portid); + + /* SYN messages may cause link congestion */ + if (rc == -ELINKCONG) { + tipc_dest_push(&tsk->cong_links, pnode, 0); + tsk->cong_link_cnt = 1; + } + sock_put(sk); +} + +static int tipc_sk_publish(struct tipc_sock *tsk, struct tipc_uaddr *ua) +{ + struct sock *sk = &tsk->sk; + struct net *net = sock_net(sk); + struct tipc_socket_addr skaddr; + struct publication *p; + u32 key; + + if (tipc_sk_connected(sk)) + return -EINVAL; + key = tsk->portid + tsk->pub_count + 1; + if (key == tsk->portid) + return -EADDRINUSE; + skaddr.ref = tsk->portid; + skaddr.node = tipc_own_addr(net); + p = tipc_nametbl_publish(net, ua, &skaddr, key); + if (unlikely(!p)) + return -EINVAL; + + list_add(&p->binding_sock, &tsk->publications); + tsk->pub_count++; + tsk->published = true; + return 0; +} + +static int tipc_sk_withdraw(struct tipc_sock *tsk, struct tipc_uaddr *ua) +{ + struct net *net = sock_net(&tsk->sk); + struct publication *safe, *p; + struct tipc_uaddr _ua; + int rc = -EINVAL; + + list_for_each_entry_safe(p, safe, &tsk->publications, binding_sock) { + if (!ua) { + tipc_uaddr(&_ua, TIPC_SERVICE_RANGE, p->scope, + p->sr.type, p->sr.lower, p->sr.upper); + tipc_nametbl_withdraw(net, &_ua, &p->sk, p->key); + continue; + } + /* Unbind specific publication */ + if (p->scope != ua->scope) + continue; + if (p->sr.type != ua->sr.type) + continue; + if (p->sr.lower != ua->sr.lower) + continue; + if (p->sr.upper != ua->sr.upper) + break; + tipc_nametbl_withdraw(net, ua, &p->sk, p->key); + rc = 0; + break; + } + if (list_empty(&tsk->publications)) { + tsk->published = 0; + rc = 0; + } + return rc; +} + +/* 
tipc_sk_reinit: set non-zero address in all existing sockets + * when we go from standalone to network mode. + */ +void tipc_sk_reinit(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct rhashtable_iter iter; + struct tipc_sock *tsk; + struct tipc_msg *msg; + + rhashtable_walk_enter(&tn->sk_rht, &iter); + + do { + rhashtable_walk_start(&iter); + + while ((tsk = rhashtable_walk_next(&iter)) && !IS_ERR(tsk)) { + sock_hold(&tsk->sk); + rhashtable_walk_stop(&iter); + lock_sock(&tsk->sk); + msg = &tsk->phdr; + msg_set_prevnode(msg, tipc_own_addr(net)); + msg_set_orignode(msg, tipc_own_addr(net)); + release_sock(&tsk->sk); + rhashtable_walk_start(&iter); + sock_put(&tsk->sk); + } + + rhashtable_walk_stop(&iter); + } while (tsk == ERR_PTR(-EAGAIN)); + + rhashtable_walk_exit(&iter); +} + +static struct tipc_sock *tipc_sk_lookup(struct net *net, u32 portid) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + struct tipc_sock *tsk; + + rcu_read_lock(); + tsk = rhashtable_lookup(&tn->sk_rht, &portid, tsk_rht_params); + if (tsk) + sock_hold(&tsk->sk); + rcu_read_unlock(); + + return tsk; +} + +static int tipc_sk_insert(struct tipc_sock *tsk) +{ + struct sock *sk = &tsk->sk; + struct net *net = sock_net(sk); + struct tipc_net *tn = net_generic(net, tipc_net_id); + u32 remaining = (TIPC_MAX_PORT - TIPC_MIN_PORT) + 1; + u32 portid = prandom_u32_max(remaining) + TIPC_MIN_PORT; + + while (remaining--) { + portid++; + if ((portid < TIPC_MIN_PORT) || (portid > TIPC_MAX_PORT)) + portid = TIPC_MIN_PORT; + tsk->portid = portid; + sock_hold(&tsk->sk); + if (!rhashtable_lookup_insert_fast(&tn->sk_rht, &tsk->node, + tsk_rht_params)) + return 0; + sock_put(&tsk->sk); + } + + return -1; +} + +static void tipc_sk_remove(struct tipc_sock *tsk) +{ + struct sock *sk = &tsk->sk; + struct tipc_net *tn = net_generic(sock_net(sk), tipc_net_id); + + if (!rhashtable_remove_fast(&tn->sk_rht, &tsk->node, tsk_rht_params)) { + WARN_ON(refcount_read(&sk->sk_refcnt) == 1); + __sock_put(sk); + } +} + +static const struct rhashtable_params tsk_rht_params = { + .nelem_hint = 192, + .head_offset = offsetof(struct tipc_sock, node), + .key_offset = offsetof(struct tipc_sock, portid), + .key_len = sizeof(u32), /* portid */ + .max_size = 1048576, + .min_size = 256, + .automatic_shrinking = true, +}; + +int tipc_sk_rht_init(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + + return rhashtable_init(&tn->sk_rht, &tsk_rht_params); +} + +void tipc_sk_rht_destroy(struct net *net) +{ + struct tipc_net *tn = net_generic(net, tipc_net_id); + + /* Wait for socket readers to complete */ + synchronize_net(); + + rhashtable_destroy(&tn->sk_rht); +} + +static int tipc_sk_join(struct tipc_sock *tsk, struct tipc_group_req *mreq) +{ + struct net *net = sock_net(&tsk->sk); + struct tipc_group *grp = tsk->group; + struct tipc_msg *hdr = &tsk->phdr; + struct tipc_uaddr ua; + int rc; + + if (mreq->type < TIPC_RESERVED_TYPES) + return -EACCES; + if (mreq->scope > TIPC_NODE_SCOPE) + return -EINVAL; + if (mreq->scope != TIPC_NODE_SCOPE) + mreq->scope = TIPC_CLUSTER_SCOPE; + if (grp) + return -EACCES; + grp = tipc_group_create(net, tsk->portid, mreq, &tsk->group_is_open); + if (!grp) + return -ENOMEM; + tsk->group = grp; + msg_set_lookup_scope(hdr, mreq->scope); + msg_set_nametype(hdr, mreq->type); + msg_set_dest_droppable(hdr, true); + tipc_uaddr(&ua, TIPC_SERVICE_RANGE, mreq->scope, + mreq->type, mreq->instance, mreq->instance); + tipc_nametbl_build_group(net, grp, &ua); + rc = 
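+	     /* Publish the member's own instance (built as the single-value
+	      * range [instance, instance] above) so that other group members
+	      * can find it in the name table.
+	      */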
tipc_sk_publish(tsk, &ua); + if (rc) { + tipc_group_delete(net, grp); + tsk->group = NULL; + return rc; + } + /* Eliminate any risk that a broadcast overtakes sent JOINs */ + tsk->mc_method.rcast = true; + tsk->mc_method.mandatory = true; + tipc_group_join(net, grp, &tsk->sk.sk_rcvbuf); + return rc; +} + +static int tipc_sk_leave(struct tipc_sock *tsk) +{ + struct net *net = sock_net(&tsk->sk); + struct tipc_group *grp = tsk->group; + struct tipc_uaddr ua; + int scope; + + if (!grp) + return -EINVAL; + ua.addrtype = TIPC_SERVICE_RANGE; + tipc_group_self(grp, &ua.sr, &scope); + ua.scope = scope; + tipc_group_delete(net, grp); + tsk->group = NULL; + tipc_sk_withdraw(tsk, &ua); + return 0; +} + +/** + * tipc_setsockopt - set socket option + * @sock: socket structure + * @lvl: option level + * @opt: option identifier + * @ov: pointer to new option value + * @ol: length of option value + * + * For stream sockets only, accepts and ignores all IPPROTO_TCP options + * (to ease compatibility). + * + * Return: 0 on success, errno otherwise + */ +static int tipc_setsockopt(struct socket *sock, int lvl, int opt, + sockptr_t ov, unsigned int ol) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_group_req mreq; + u32 value = 0; + int res = 0; + + if ((lvl == IPPROTO_TCP) && (sock->type == SOCK_STREAM)) + return 0; + if (lvl != SOL_TIPC) + return -ENOPROTOOPT; + + switch (opt) { + case TIPC_IMPORTANCE: + case TIPC_SRC_DROPPABLE: + case TIPC_DEST_DROPPABLE: + case TIPC_CONN_TIMEOUT: + case TIPC_NODELAY: + if (ol < sizeof(value)) + return -EINVAL; + if (copy_from_sockptr(&value, ov, sizeof(u32))) + return -EFAULT; + break; + case TIPC_GROUP_JOIN: + if (ol < sizeof(mreq)) + return -EINVAL; + if (copy_from_sockptr(&mreq, ov, sizeof(mreq))) + return -EFAULT; + break; + default: + if (!sockptr_is_null(ov) || ol) + return -EINVAL; + } + + lock_sock(sk); + + switch (opt) { + case TIPC_IMPORTANCE: + res = tsk_set_importance(sk, value); + break; + case TIPC_SRC_DROPPABLE: + if (sock->type != SOCK_STREAM) + tsk_set_unreliable(tsk, value); + else + res = -ENOPROTOOPT; + break; + case TIPC_DEST_DROPPABLE: + tsk_set_unreturnable(tsk, value); + break; + case TIPC_CONN_TIMEOUT: + tipc_sk(sk)->conn_timeout = value; + break; + case TIPC_MCAST_BROADCAST: + tsk->mc_method.rcast = false; + tsk->mc_method.mandatory = true; + break; + case TIPC_MCAST_REPLICAST: + tsk->mc_method.rcast = true; + tsk->mc_method.mandatory = true; + break; + case TIPC_GROUP_JOIN: + res = tipc_sk_join(tsk, &mreq); + break; + case TIPC_GROUP_LEAVE: + res = tipc_sk_leave(tsk); + break; + case TIPC_NODELAY: + tsk->nodelay = !!value; + tsk_set_nagle(tsk); + break; + default: + res = -EINVAL; + } + + release_sock(sk); + + return res; +} + +/** + * tipc_getsockopt - get socket option + * @sock: socket structure + * @lvl: option level + * @opt: option identifier + * @ov: receptacle for option value + * @ol: receptacle for length of option value + * + * For stream sockets only, returns 0 length result for all IPPROTO_TCP options + * (to ease compatibility). 
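+ *
+ * Example of the SOL_TIPC option interface from user space (sketch; the
+ * values are arbitrary):
+ *
+ *     __u32 timeout_ms = 5000;
+ *     __u32 depth;
+ *     socklen_t len = sizeof(depth);
+ *
+ *     setsockopt(sd, SOL_TIPC, TIPC_CONN_TIMEOUT, &timeout_ms,
+ *                sizeof(timeout_ms));
+ *     getsockopt(sd, SOL_TIPC, TIPC_SOCK_RECVQ_DEPTH, &depth, &len);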
+ * + * Return: 0 on success, errno otherwise + */ +static int tipc_getsockopt(struct socket *sock, int lvl, int opt, + char __user *ov, int __user *ol) +{ + struct sock *sk = sock->sk; + struct tipc_sock *tsk = tipc_sk(sk); + struct tipc_service_range seq; + int len, scope; + u32 value; + int res; + + if ((lvl == IPPROTO_TCP) && (sock->type == SOCK_STREAM)) + return put_user(0, ol); + if (lvl != SOL_TIPC) + return -ENOPROTOOPT; + res = get_user(len, ol); + if (res) + return res; + + lock_sock(sk); + + switch (opt) { + case TIPC_IMPORTANCE: + value = tsk_importance(tsk); + break; + case TIPC_SRC_DROPPABLE: + value = tsk_unreliable(tsk); + break; + case TIPC_DEST_DROPPABLE: + value = tsk_unreturnable(tsk); + break; + case TIPC_CONN_TIMEOUT: + value = tsk->conn_timeout; + /* no need to set "res", since already 0 at this point */ + break; + case TIPC_NODE_RECVQ_DEPTH: + value = 0; /* was tipc_queue_size, now obsolete */ + break; + case TIPC_SOCK_RECVQ_DEPTH: + value = skb_queue_len(&sk->sk_receive_queue); + break; + case TIPC_SOCK_RECVQ_USED: + value = sk_rmem_alloc_get(sk); + break; + case TIPC_GROUP_JOIN: + seq.type = 0; + if (tsk->group) + tipc_group_self(tsk->group, &seq, &scope); + value = seq.type; + break; + default: + res = -EINVAL; + } + + release_sock(sk); + + if (res) + return res; /* "get" failed */ + + if (len < sizeof(value)) + return -EINVAL; + + if (copy_to_user(ov, &value, sizeof(value))) + return -EFAULT; + + return put_user(sizeof(value), ol); +} + +static int tipc_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct net *net = sock_net(sock->sk); + struct tipc_sioc_nodeid_req nr = {0}; + struct tipc_sioc_ln_req lnr; + void __user *argp = (void __user *)arg; + + switch (cmd) { + case SIOCGETLINKNAME: + if (copy_from_user(&lnr, argp, sizeof(lnr))) + return -EFAULT; + if (!tipc_node_get_linkname(net, + lnr.bearer_id & 0xffff, lnr.peer, + lnr.linkname, TIPC_MAX_LINK_NAME)) { + if (copy_to_user(argp, &lnr, sizeof(lnr))) + return -EFAULT; + return 0; + } + return -EADDRNOTAVAIL; + case SIOCGETNODEID: + if (copy_from_user(&nr, argp, sizeof(nr))) + return -EFAULT; + if (!tipc_node_get_id(net, nr.peer, nr.node_id)) + return -EADDRNOTAVAIL; + if (copy_to_user(argp, &nr, sizeof(nr))) + return -EFAULT; + return 0; + default: + return -ENOIOCTLCMD; + } +} + +static int tipc_socketpair(struct socket *sock1, struct socket *sock2) +{ + struct tipc_sock *tsk2 = tipc_sk(sock2->sk); + struct tipc_sock *tsk1 = tipc_sk(sock1->sk); + u32 onode = tipc_own_addr(sock_net(sock1->sk)); + + tsk1->peer.family = AF_TIPC; + tsk1->peer.addrtype = TIPC_SOCKET_ADDR; + tsk1->peer.scope = TIPC_NODE_SCOPE; + tsk1->peer.addr.id.ref = tsk2->portid; + tsk1->peer.addr.id.node = onode; + tsk2->peer.family = AF_TIPC; + tsk2->peer.addrtype = TIPC_SOCKET_ADDR; + tsk2->peer.scope = TIPC_NODE_SCOPE; + tsk2->peer.addr.id.ref = tsk1->portid; + tsk2->peer.addr.id.node = onode; + + tipc_sk_finish_conn(tsk1, tsk2->portid, onode); + tipc_sk_finish_conn(tsk2, tsk1->portid, onode); + return 0; +} + +/* Protocol switches for the various types of TIPC sockets */ + +static const struct proto_ops msg_ops = { + .owner = THIS_MODULE, + .family = AF_TIPC, + .release = tipc_release, + .bind = tipc_bind, + .connect = tipc_connect, + .socketpair = tipc_socketpair, + .accept = sock_no_accept, + .getname = tipc_getname, + .poll = tipc_poll, + .ioctl = tipc_ioctl, + .listen = sock_no_listen, + .shutdown = tipc_shutdown, + .setsockopt = tipc_setsockopt, + .getsockopt = tipc_getsockopt, + .sendmsg = tipc_sendmsg, + 
.recvmsg = tipc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage +}; + +static const struct proto_ops packet_ops = { + .owner = THIS_MODULE, + .family = AF_TIPC, + .release = tipc_release, + .bind = tipc_bind, + .connect = tipc_connect, + .socketpair = tipc_socketpair, + .accept = tipc_accept, + .getname = tipc_getname, + .poll = tipc_poll, + .ioctl = tipc_ioctl, + .listen = tipc_listen, + .shutdown = tipc_shutdown, + .setsockopt = tipc_setsockopt, + .getsockopt = tipc_getsockopt, + .sendmsg = tipc_send_packet, + .recvmsg = tipc_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage +}; + +static const struct proto_ops stream_ops = { + .owner = THIS_MODULE, + .family = AF_TIPC, + .release = tipc_release, + .bind = tipc_bind, + .connect = tipc_connect, + .socketpair = tipc_socketpair, + .accept = tipc_accept, + .getname = tipc_getname, + .poll = tipc_poll, + .ioctl = tipc_ioctl, + .listen = tipc_listen, + .shutdown = tipc_shutdown, + .setsockopt = tipc_setsockopt, + .getsockopt = tipc_getsockopt, + .sendmsg = tipc_sendstream, + .recvmsg = tipc_recvstream, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage +}; + +static const struct net_proto_family tipc_family_ops = { + .owner = THIS_MODULE, + .family = AF_TIPC, + .create = tipc_sk_create +}; + +static struct proto tipc_proto = { + .name = "TIPC", + .owner = THIS_MODULE, + .obj_size = sizeof(struct tipc_sock), + .sysctl_rmem = sysctl_tipc_rmem +}; + +/** + * tipc_socket_init - initialize TIPC socket interface + * + * Return: 0 on success, errno otherwise + */ +int tipc_socket_init(void) +{ + int res; + + res = proto_register(&tipc_proto, 1); + if (res) { + pr_err("Failed to register TIPC protocol type\n"); + goto out; + } + + res = sock_register(&tipc_family_ops); + if (res) { + pr_err("Failed to register TIPC socket type\n"); + proto_unregister(&tipc_proto); + goto out; + } + out: + return res; +} + +/** + * tipc_socket_stop - stop TIPC socket interface + */ +void tipc_socket_stop(void) +{ + sock_unregister(tipc_family_ops.family); + proto_unregister(&tipc_proto); +} + +/* Caller should hold socket lock for the passed tipc socket. 
*/ +static int __tipc_nl_add_sk_con(struct sk_buff *skb, struct tipc_sock *tsk) +{ + u32 peer_node, peer_port; + u32 conn_type, conn_instance; + struct nlattr *nest; + + peer_node = tsk_peer_node(tsk); + peer_port = tsk_peer_port(tsk); + conn_type = msg_nametype(&tsk->phdr); + conn_instance = msg_nameinst(&tsk->phdr); + nest = nla_nest_start_noflag(skb, TIPC_NLA_SOCK_CON); + if (!nest) + return -EMSGSIZE; + + if (nla_put_u32(skb, TIPC_NLA_CON_NODE, peer_node)) + goto msg_full; + if (nla_put_u32(skb, TIPC_NLA_CON_SOCK, peer_port)) + goto msg_full; + + if (tsk->conn_addrtype != 0) { + if (nla_put_flag(skb, TIPC_NLA_CON_FLAG)) + goto msg_full; + if (nla_put_u32(skb, TIPC_NLA_CON_TYPE, conn_type)) + goto msg_full; + if (nla_put_u32(skb, TIPC_NLA_CON_INST, conn_instance)) + goto msg_full; + } + nla_nest_end(skb, nest); + + return 0; + +msg_full: + nla_nest_cancel(skb, nest); + + return -EMSGSIZE; +} + +static int __tipc_nl_add_sk_info(struct sk_buff *skb, struct tipc_sock + *tsk) +{ + struct net *net = sock_net(skb->sk); + struct sock *sk = &tsk->sk; + + if (nla_put_u32(skb, TIPC_NLA_SOCK_REF, tsk->portid) || + nla_put_u32(skb, TIPC_NLA_SOCK_ADDR, tipc_own_addr(net))) + return -EMSGSIZE; + + if (tipc_sk_connected(sk)) { + if (__tipc_nl_add_sk_con(skb, tsk)) + return -EMSGSIZE; + } else if (!list_empty(&tsk->publications)) { + if (nla_put_flag(skb, TIPC_NLA_SOCK_HAS_PUBL)) + return -EMSGSIZE; + } + return 0; +} + +/* Caller should hold socket lock for the passed tipc socket. */ +static int __tipc_nl_add_sk(struct sk_buff *skb, struct netlink_callback *cb, + struct tipc_sock *tsk) +{ + struct nlattr *attrs; + void *hdr; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &tipc_genl_family, NLM_F_MULTI, TIPC_NL_SOCK_GET); + if (!hdr) + goto msg_cancel; + + attrs = nla_nest_start_noflag(skb, TIPC_NLA_SOCK); + if (!attrs) + goto genlmsg_cancel; + + if (__tipc_nl_add_sk_info(skb, tsk)) + goto attr_msg_cancel; + + nla_nest_end(skb, attrs); + genlmsg_end(skb, hdr); + + return 0; + +attr_msg_cancel: + nla_nest_cancel(skb, attrs); +genlmsg_cancel: + genlmsg_cancel(skb, hdr); +msg_cancel: + return -EMSGSIZE; +} + +int tipc_nl_sk_walk(struct sk_buff *skb, struct netlink_callback *cb, + int (*skb_handler)(struct sk_buff *skb, + struct netlink_callback *cb, + struct tipc_sock *tsk)) +{ + struct rhashtable_iter *iter = (void *)cb->args[4]; + struct tipc_sock *tsk; + int err; + + rhashtable_walk_start(iter); + while ((tsk = rhashtable_walk_next(iter)) != NULL) { + if (IS_ERR(tsk)) { + err = PTR_ERR(tsk); + if (err == -EAGAIN) { + err = 0; + continue; + } + break; + } + + sock_hold(&tsk->sk); + rhashtable_walk_stop(iter); + lock_sock(&tsk->sk); + err = skb_handler(skb, cb, tsk); + if (err) { + release_sock(&tsk->sk); + sock_put(&tsk->sk); + goto out; + } + release_sock(&tsk->sk); + rhashtable_walk_start(iter); + sock_put(&tsk->sk); + } + rhashtable_walk_stop(iter); +out: + return skb->len; +} +EXPORT_SYMBOL(tipc_nl_sk_walk); + +int tipc_dump_start(struct netlink_callback *cb) +{ + return __tipc_dump_start(cb, sock_net(cb->skb->sk)); +} +EXPORT_SYMBOL(tipc_dump_start); + +int __tipc_dump_start(struct netlink_callback *cb, struct net *net) +{ + /* tipc_nl_name_table_dump() uses cb->args[0...3]. 
*/ + struct rhashtable_iter *iter = (void *)cb->args[4]; + struct tipc_net *tn = tipc_net(net); + + if (!iter) { + iter = kmalloc(sizeof(*iter), GFP_KERNEL); + if (!iter) + return -ENOMEM; + + cb->args[4] = (long)iter; + } + + rhashtable_walk_enter(&tn->sk_rht, iter); + return 0; +} + +int tipc_dump_done(struct netlink_callback *cb) +{ + struct rhashtable_iter *hti = (void *)cb->args[4]; + + rhashtable_walk_exit(hti); + kfree(hti); + return 0; +} +EXPORT_SYMBOL(tipc_dump_done); + +int tipc_sk_fill_sock_diag(struct sk_buff *skb, struct netlink_callback *cb, + struct tipc_sock *tsk, u32 sk_filter_state, + u64 (*tipc_diag_gen_cookie)(struct sock *sk)) +{ + struct sock *sk = &tsk->sk; + struct nlattr *attrs; + struct nlattr *stat; + + /*filter response w.r.t sk_state*/ + if (!(sk_filter_state & (1 << sk->sk_state))) + return 0; + + attrs = nla_nest_start_noflag(skb, TIPC_NLA_SOCK); + if (!attrs) + goto msg_cancel; + + if (__tipc_nl_add_sk_info(skb, tsk)) + goto attr_msg_cancel; + + if (nla_put_u32(skb, TIPC_NLA_SOCK_TYPE, (u32)sk->sk_type) || + nla_put_u32(skb, TIPC_NLA_SOCK_TIPC_STATE, (u32)sk->sk_state) || + nla_put_u32(skb, TIPC_NLA_SOCK_INO, sock_i_ino(sk)) || + nla_put_u32(skb, TIPC_NLA_SOCK_UID, + from_kuid_munged(sk_user_ns(NETLINK_CB(cb->skb).sk), + sock_i_uid(sk))) || + nla_put_u64_64bit(skb, TIPC_NLA_SOCK_COOKIE, + tipc_diag_gen_cookie(sk), + TIPC_NLA_SOCK_PAD)) + goto attr_msg_cancel; + + stat = nla_nest_start_noflag(skb, TIPC_NLA_SOCK_STAT); + if (!stat) + goto attr_msg_cancel; + + if (nla_put_u32(skb, TIPC_NLA_SOCK_STAT_RCVQ, + skb_queue_len(&sk->sk_receive_queue)) || + nla_put_u32(skb, TIPC_NLA_SOCK_STAT_SENDQ, + skb_queue_len(&sk->sk_write_queue)) || + nla_put_u32(skb, TIPC_NLA_SOCK_STAT_DROP, + atomic_read(&sk->sk_drops))) + goto stat_msg_cancel; + + if (tsk->cong_link_cnt && + nla_put_flag(skb, TIPC_NLA_SOCK_STAT_LINK_CONG)) + goto stat_msg_cancel; + + if (tsk_conn_cong(tsk) && + nla_put_flag(skb, TIPC_NLA_SOCK_STAT_CONN_CONG)) + goto stat_msg_cancel; + + nla_nest_end(skb, stat); + + if (tsk->group) + if (tipc_group_fill_sock_diag(tsk->group, skb)) + goto stat_msg_cancel; + + nla_nest_end(skb, attrs); + + return 0; + +stat_msg_cancel: + nla_nest_cancel(skb, stat); +attr_msg_cancel: + nla_nest_cancel(skb, attrs); +msg_cancel: + return -EMSGSIZE; +} +EXPORT_SYMBOL(tipc_sk_fill_sock_diag); + +int tipc_nl_sk_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + return tipc_nl_sk_walk(skb, cb, __tipc_nl_add_sk); +} + +/* Caller should hold socket lock for the passed tipc socket. */ +static int __tipc_nl_add_sk_publ(struct sk_buff *skb, + struct netlink_callback *cb, + struct publication *publ) +{ + void *hdr; + struct nlattr *attrs; + + hdr = genlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, + &tipc_genl_family, NLM_F_MULTI, TIPC_NL_PUBL_GET); + if (!hdr) + goto msg_cancel; + + attrs = nla_nest_start_noflag(skb, TIPC_NLA_PUBL); + if (!attrs) + goto genlmsg_cancel; + + if (nla_put_u32(skb, TIPC_NLA_PUBL_KEY, publ->key)) + goto attr_msg_cancel; + if (nla_put_u32(skb, TIPC_NLA_PUBL_TYPE, publ->sr.type)) + goto attr_msg_cancel; + if (nla_put_u32(skb, TIPC_NLA_PUBL_LOWER, publ->sr.lower)) + goto attr_msg_cancel; + if (nla_put_u32(skb, TIPC_NLA_PUBL_UPPER, publ->sr.upper)) + goto attr_msg_cancel; + + nla_nest_end(skb, attrs); + genlmsg_end(skb, hdr); + + return 0; + +attr_msg_cancel: + nla_nest_cancel(skb, attrs); +genlmsg_cancel: + genlmsg_cancel(skb, hdr); +msg_cancel: + return -EMSGSIZE; +} + +/* Caller should hold socket lock for the passed tipc socket. 
*/ +static int __tipc_nl_list_sk_publ(struct sk_buff *skb, + struct netlink_callback *cb, + struct tipc_sock *tsk, u32 *last_publ) +{ + int err; + struct publication *p; + + if (*last_publ) { + list_for_each_entry(p, &tsk->publications, binding_sock) { + if (p->key == *last_publ) + break; + } + if (list_entry_is_head(p, &tsk->publications, binding_sock)) { + /* We never set seq or call nl_dump_check_consistent() + * this means that setting prev_seq here will cause the + * consistence check to fail in the netlink callback + * handler. Resulting in the last NLMSG_DONE message + * having the NLM_F_DUMP_INTR flag set. + */ + cb->prev_seq = 1; + *last_publ = 0; + return -EPIPE; + } + } else { + p = list_first_entry(&tsk->publications, struct publication, + binding_sock); + } + + list_for_each_entry_from(p, &tsk->publications, binding_sock) { + err = __tipc_nl_add_sk_publ(skb, cb, p); + if (err) { + *last_publ = p->key; + return err; + } + } + *last_publ = 0; + + return 0; +} + +int tipc_nl_publ_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + int err; + u32 tsk_portid = cb->args[0]; + u32 last_publ = cb->args[1]; + u32 done = cb->args[2]; + struct net *net = sock_net(skb->sk); + struct tipc_sock *tsk; + + if (!tsk_portid) { + struct nlattr **attrs = genl_dumpit_info(cb)->attrs; + struct nlattr *sock[TIPC_NLA_SOCK_MAX + 1]; + + if (!attrs[TIPC_NLA_SOCK]) + return -EINVAL; + + err = nla_parse_nested_deprecated(sock, TIPC_NLA_SOCK_MAX, + attrs[TIPC_NLA_SOCK], + tipc_nl_sock_policy, NULL); + if (err) + return err; + + if (!sock[TIPC_NLA_SOCK_REF]) + return -EINVAL; + + tsk_portid = nla_get_u32(sock[TIPC_NLA_SOCK_REF]); + } + + if (done) + return 0; + + tsk = tipc_sk_lookup(net, tsk_portid); + if (!tsk) + return -EINVAL; + + lock_sock(&tsk->sk); + err = __tipc_nl_list_sk_publ(skb, cb, tsk, &last_publ); + if (!err) + done = 1; + release_sock(&tsk->sk); + sock_put(&tsk->sk); + + cb->args[0] = tsk_portid; + cb->args[1] = last_publ; + cb->args[2] = done; + + return skb->len; +} + +/** + * tipc_sk_filtering - check if a socket should be traced + * @sk: the socket to be examined + * + * @sysctl_tipc_sk_filter is used as the socket tuple for filtering: + * (portid, sock type, name type, name lower, name upper) + * + * Return: true if the socket meets the socket tuple data + * (value 0 = 'any') or when there is no tuple set (all = 0), + * otherwise false + */ +bool tipc_sk_filtering(struct sock *sk) +{ + struct tipc_sock *tsk; + struct publication *p; + u32 _port, _sktype, _type, _lower, _upper; + u32 type = 0, lower = 0, upper = 0; + + if (!sk) + return true; + + tsk = tipc_sk(sk); + + _port = sysctl_tipc_sk_filter[0]; + _sktype = sysctl_tipc_sk_filter[1]; + _type = sysctl_tipc_sk_filter[2]; + _lower = sysctl_tipc_sk_filter[3]; + _upper = sysctl_tipc_sk_filter[4]; + + if (!_port && !_sktype && !_type && !_lower && !_upper) + return true; + + if (_port) + return (_port == tsk->portid); + + if (_sktype && _sktype != sk->sk_type) + return false; + + if (tsk->published) { + p = list_first_entry_or_null(&tsk->publications, + struct publication, binding_sock); + if (p) { + type = p->sr.type; + lower = p->sr.lower; + upper = p->sr.upper; + } + } + + if (!tipc_sk_type_connectionless(sk)) { + type = msg_nametype(&tsk->phdr); + lower = msg_nameinst(&tsk->phdr); + upper = lower; + } + + if ((_type && _type != type) || (_lower && _lower != lower) || + (_upper && _upper != upper)) + return false; + + return true; +} + +u32 tipc_sock_get_portid(struct sock *sk) +{ + return (sk) ? 
(tipc_sk(sk))->portid : 0; +} + +/** + * tipc_sk_overlimit1 - check if socket rx queue is about to be overloaded, + * both the rcv and backlog queues are considered + * @sk: tipc sk to be checked + * @skb: tipc msg to be checked + * + * Return: true if the socket rx queue allocation is > 90%, otherwise false + */ + +bool tipc_sk_overlimit1(struct sock *sk, struct sk_buff *skb) +{ + atomic_t *dcnt = &tipc_sk(sk)->dupl_rcvcnt; + unsigned int lim = rcvbuf_limit(sk, skb) + atomic_read(dcnt); + unsigned int qsize = sk->sk_backlog.len + sk_rmem_alloc_get(sk); + + return (qsize > lim * 90 / 100); +} + +/** + * tipc_sk_overlimit2 - check if socket rx queue is about to be overloaded, + * only the rcv queue is considered + * @sk: tipc sk to be checked + * @skb: tipc msg to be checked + * + * Return: true if the socket rx queue allocation is > 90%, otherwise false + */ + +bool tipc_sk_overlimit2(struct sock *sk, struct sk_buff *skb) +{ + unsigned int lim = rcvbuf_limit(sk, skb); + unsigned int qsize = sk_rmem_alloc_get(sk); + + return (qsize > lim * 90 / 100); +} + +/** + * tipc_sk_dump - dump TIPC socket + * @sk: tipc sk to be dumped + * @dqueues: bitmask to decide if any socket queue to be dumped? + * - TIPC_DUMP_NONE: don't dump socket queues + * - TIPC_DUMP_SK_SNDQ: dump socket send queue + * - TIPC_DUMP_SK_RCVQ: dump socket rcv queue + * - TIPC_DUMP_SK_BKLGQ: dump socket backlog queue + * - TIPC_DUMP_ALL: dump all the socket queues above + * @buf: returned buffer of dump data in format + */ +int tipc_sk_dump(struct sock *sk, u16 dqueues, char *buf) +{ + int i = 0; + size_t sz = (dqueues) ? SK_LMAX : SK_LMIN; + u32 conn_type, conn_instance; + struct tipc_sock *tsk; + struct publication *p; + bool tsk_connected; + + if (!sk) { + i += scnprintf(buf, sz, "sk data: (null)\n"); + return i; + } + + tsk = tipc_sk(sk); + tsk_connected = !tipc_sk_type_connectionless(sk); + + i += scnprintf(buf, sz, "sk data: %u", sk->sk_type); + i += scnprintf(buf + i, sz - i, " %d", sk->sk_state); + i += scnprintf(buf + i, sz - i, " %x", tsk_own_node(tsk)); + i += scnprintf(buf + i, sz - i, " %u", tsk->portid); + i += scnprintf(buf + i, sz - i, " | %u", tsk_connected); + if (tsk_connected) { + i += scnprintf(buf + i, sz - i, " %x", tsk_peer_node(tsk)); + i += scnprintf(buf + i, sz - i, " %u", tsk_peer_port(tsk)); + conn_type = msg_nametype(&tsk->phdr); + conn_instance = msg_nameinst(&tsk->phdr); + i += scnprintf(buf + i, sz - i, " %u", conn_type); + i += scnprintf(buf + i, sz - i, " %u", conn_instance); + } + i += scnprintf(buf + i, sz - i, " | %u", tsk->published); + if (tsk->published) { + p = list_first_entry_or_null(&tsk->publications, + struct publication, binding_sock); + i += scnprintf(buf + i, sz - i, " %u", (p) ? p->sr.type : 0); + i += scnprintf(buf + i, sz - i, " %u", (p) ? p->sr.lower : 0); + i += scnprintf(buf + i, sz - i, " %u", (p) ? 
p->sr.upper : 0); + } + i += scnprintf(buf + i, sz - i, " | %u", tsk->snd_win); + i += scnprintf(buf + i, sz - i, " %u", tsk->rcv_win); + i += scnprintf(buf + i, sz - i, " %u", tsk->max_pkt); + i += scnprintf(buf + i, sz - i, " %x", tsk->peer_caps); + i += scnprintf(buf + i, sz - i, " %u", tsk->cong_link_cnt); + i += scnprintf(buf + i, sz - i, " %u", tsk->snt_unacked); + i += scnprintf(buf + i, sz - i, " %u", tsk->rcv_unacked); + i += scnprintf(buf + i, sz - i, " %u", atomic_read(&tsk->dupl_rcvcnt)); + i += scnprintf(buf + i, sz - i, " %u", sk->sk_shutdown); + i += scnprintf(buf + i, sz - i, " | %d", sk_wmem_alloc_get(sk)); + i += scnprintf(buf + i, sz - i, " %d", sk->sk_sndbuf); + i += scnprintf(buf + i, sz - i, " | %d", sk_rmem_alloc_get(sk)); + i += scnprintf(buf + i, sz - i, " %d", sk->sk_rcvbuf); + i += scnprintf(buf + i, sz - i, " | %d\n", READ_ONCE(sk->sk_backlog.len)); + + if (dqueues & TIPC_DUMP_SK_SNDQ) { + i += scnprintf(buf + i, sz - i, "sk_write_queue: "); + i += tipc_list_dump(&sk->sk_write_queue, false, buf + i); + } + + if (dqueues & TIPC_DUMP_SK_RCVQ) { + i += scnprintf(buf + i, sz - i, "sk_receive_queue: "); + i += tipc_list_dump(&sk->sk_receive_queue, false, buf + i); + } + + if (dqueues & TIPC_DUMP_SK_BKLGQ) { + i += scnprintf(buf + i, sz - i, "sk_backlog:\n head "); + i += tipc_skb_dump(sk->sk_backlog.head, false, buf + i); + if (sk->sk_backlog.tail != sk->sk_backlog.head) { + i += scnprintf(buf + i, sz - i, " tail "); + i += tipc_skb_dump(sk->sk_backlog.tail, false, + buf + i); + } + } + + return i; +} diff --git a/net/tipc/socket.h b/net/tipc/socket.h new file mode 100644 index 000000000..02cdf1668 --- /dev/null +++ b/net/tipc/socket.h @@ -0,0 +1,80 @@ +/* net/tipc/socket.h: Include file for TIPC socket code + * + * Copyright (c) 2014-2016, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_SOCK_H +#define _TIPC_SOCK_H + +#include +#include + +/* Compatibility values for deprecated message based flow control */ +#define FLOWCTL_MSG_WIN 512 +#define FLOWCTL_MSG_LIM ((FLOWCTL_MSG_WIN * 2 + 1) * SKB_TRUESIZE(MAX_MSG_SIZE)) + +#define FLOWCTL_BLK_SZ 1024 + +/* Socket receive buffer sizes */ +#define RCVBUF_MIN (FLOWCTL_BLK_SZ * 512) +#define RCVBUF_DEF (FLOWCTL_BLK_SZ * 1024 * 2) +#define RCVBUF_MAX (FLOWCTL_BLK_SZ * 1024 * 16) + +struct tipc_sock; + +int tipc_socket_init(void); +void tipc_socket_stop(void); +void tipc_sk_rcv(struct net *net, struct sk_buff_head *inputq); +void tipc_sk_mcast_rcv(struct net *net, struct sk_buff_head *arrvq, + struct sk_buff_head *inputq); +void tipc_sk_reinit(struct net *net); +int tipc_sk_rht_init(struct net *net); +void tipc_sk_rht_destroy(struct net *net); +int tipc_nl_sk_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_nl_publ_dump(struct sk_buff *skb, struct netlink_callback *cb); +int tipc_sk_fill_sock_diag(struct sk_buff *skb, struct netlink_callback *cb, + struct tipc_sock *tsk, u32 sk_filter_state, + u64 (*tipc_diag_gen_cookie)(struct sock *sk)); +int tipc_nl_sk_walk(struct sk_buff *skb, struct netlink_callback *cb, + int (*skb_handler)(struct sk_buff *skb, + struct netlink_callback *cb, + struct tipc_sock *tsk)); +int tipc_dump_start(struct netlink_callback *cb); +int __tipc_dump_start(struct netlink_callback *cb, struct net *net); +int tipc_dump_done(struct netlink_callback *cb); +u32 tipc_sock_get_portid(struct sock *sk); +bool tipc_sk_overlimit1(struct sock *sk, struct sk_buff *skb); +bool tipc_sk_overlimit2(struct sock *sk, struct sk_buff *skb); +int tipc_sk_bind(struct socket *sock, struct sockaddr *skaddr, int alen); +int tsk_set_importance(struct sock *sk, int imp); + +#endif diff --git a/net/tipc/subscr.c b/net/tipc/subscr.c new file mode 100644 index 000000000..05d49ad81 --- /dev/null +++ b/net/tipc/subscr.c @@ -0,0 +1,183 @@ +/* + * net/tipc/subscr.c: TIPC network topology service + * + * Copyright (c) 2000-2017, Ericsson AB + * Copyright (c) 2005-2007, 2010-2013, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "name_table.h" +#include "subscr.h" + +static void tipc_sub_send_event(struct tipc_subscription *sub, + struct publication *p, + u32 event) +{ + struct tipc_subscr *s = &sub->evt.s; + struct tipc_event *evt = &sub->evt; + + if (sub->inactive) + return; + tipc_evt_write(evt, event, event); + if (p) { + tipc_evt_write(evt, found_lower, p->sr.lower); + tipc_evt_write(evt, found_upper, p->sr.upper); + tipc_evt_write(evt, port.ref, p->sk.ref); + tipc_evt_write(evt, port.node, p->sk.node); + } else { + tipc_evt_write(evt, found_lower, s->seq.lower); + tipc_evt_write(evt, found_upper, s->seq.upper); + tipc_evt_write(evt, port.ref, 0); + tipc_evt_write(evt, port.node, 0); + } + tipc_topsrv_queue_evt(sub->net, sub->conid, event, evt); +} + +/** + * tipc_sub_check_overlap - test for subscription overlap with the given values + * @subscribed: the service range subscribed for + * @found: the service range we are checking for match + * + * Returns true if there is overlap, otherwise false. + */ +static bool tipc_sub_check_overlap(struct tipc_service_range *subscribed, + struct tipc_service_range *found) +{ + u32 found_lower = found->lower; + u32 found_upper = found->upper; + + if (found_lower < subscribed->lower) + found_lower = subscribed->lower; + if (found_upper > subscribed->upper) + found_upper = subscribed->upper; + return found_lower <= found_upper; +} + +void tipc_sub_report_overlap(struct tipc_subscription *sub, + struct publication *p, + u32 event, bool must) +{ + struct tipc_service_range *sr = &sub->s.seq; + u32 filter = sub->s.filter; + + if (!tipc_sub_check_overlap(sr, &p->sr)) + return; + if (!must && !(filter & TIPC_SUB_PORTS)) + return; + if (filter & TIPC_SUB_CLUSTER_SCOPE && p->scope == TIPC_NODE_SCOPE) + return; + if (filter & TIPC_SUB_NODE_SCOPE && p->scope != TIPC_NODE_SCOPE) + return; + spin_lock(&sub->lock); + tipc_sub_send_event(sub, p, event); + spin_unlock(&sub->lock); +} + +static void tipc_sub_timeout(struct timer_list *t) +{ + struct tipc_subscription *sub = from_timer(sub, t, timer); + + spin_lock(&sub->lock); + tipc_sub_send_event(sub, NULL, TIPC_SUBSCR_TIMEOUT); + sub->inactive = true; + spin_unlock(&sub->lock); +} + +static void tipc_sub_kref_release(struct kref *kref) +{ + kfree(container_of(kref, struct tipc_subscription, kref)); +} + +void tipc_sub_put(struct tipc_subscription *subscription) +{ + kref_put(&subscription->kref, tipc_sub_kref_release); +} + +void tipc_sub_get(struct tipc_subscription *subscription) +{ + kref_get(&subscription->kref); +} + +struct tipc_subscription *tipc_sub_subscribe(struct net *net, + struct tipc_subscr *s, + int conid) +{ + u32 lower = tipc_sub_read(s, seq.lower); + u32 upper = tipc_sub_read(s, seq.upper); + u32 filter = tipc_sub_read(s, filter); + struct tipc_subscription *sub; + u32 timeout; + + if ((filter & TIPC_SUB_PORTS && filter & TIPC_SUB_SERVICE) || + lower > upper) { + pr_warn("Subscription rejected, illegal request\n"); + return NULL; + } + sub = 
kmalloc(sizeof(*sub), GFP_ATOMIC); + if (!sub) { + pr_warn("Subscription rejected, no memory\n"); + return NULL; + } + INIT_LIST_HEAD(&sub->service_list); + INIT_LIST_HEAD(&sub->sub_list); + sub->net = net; + sub->conid = conid; + sub->inactive = false; + memcpy(&sub->evt.s, s, sizeof(*s)); + sub->s.seq.type = tipc_sub_read(s, seq.type); + sub->s.seq.lower = lower; + sub->s.seq.upper = upper; + sub->s.filter = filter; + sub->s.timeout = tipc_sub_read(s, timeout); + memcpy(sub->s.usr_handle, s->usr_handle, 8); + spin_lock_init(&sub->lock); + kref_init(&sub->kref); + if (!tipc_nametbl_subscribe(sub)) { + kfree(sub); + return NULL; + } + timer_setup(&sub->timer, tipc_sub_timeout, 0); + timeout = tipc_sub_read(&sub->evt.s, timeout); + if (timeout != TIPC_WAIT_FOREVER) + mod_timer(&sub->timer, jiffies + msecs_to_jiffies(timeout)); + return sub; +} + +void tipc_sub_unsubscribe(struct tipc_subscription *sub) +{ + tipc_nametbl_unsubscribe(sub); + if (sub->evt.s.timeout != TIPC_WAIT_FOREVER) + del_timer_sync(&sub->timer); + list_del(&sub->sub_list); + tipc_sub_put(sub); +} diff --git a/net/tipc/subscr.h b/net/tipc/subscr.h new file mode 100644 index 000000000..60b877531 --- /dev/null +++ b/net/tipc/subscr.h @@ -0,0 +1,122 @@ +/* + * net/tipc/subscr.h: Include file for TIPC network topology service + * + * Copyright (c) 2003-2017, Ericsson AB + * Copyright (c) 2005-2007, 2012-2013, Wind River Systems + * Copyright (c) 2020-2021, Red Hat Inc + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_SUBSCR_H +#define _TIPC_SUBSCR_H + +#include "topsrv.h" + +#define TIPC_MAX_SUBSCR 65535 +#define TIPC_MAX_PUBL 65535 + +struct publication; +struct tipc_subscription; +struct tipc_conn; + +/** + * struct tipc_subscription - TIPC network topology subscription object + * @s: host-endian copy of the user subscription + * @evt: template for events generated by subscription + * @kref: reference count for this subscription + * @net: network namespace associated with subscription + * @timer: timer governing subscription duration (optional) + * @service_list: adjacent subscriptions in name sequence's subscription list + * @sub_list: adjacent subscriptions in subscriber's subscription list + * @conid: connection identifier of topology server + * @inactive: true if this subscription is inactive + * @lock: serialize up/down and timer events + */ +struct tipc_subscription { + struct tipc_subscr s; + struct tipc_event evt; + struct kref kref; + struct net *net; + struct timer_list timer; + struct list_head service_list; + struct list_head sub_list; + int conid; + bool inactive; + spinlock_t lock; +}; + +struct tipc_subscription *tipc_sub_subscribe(struct net *net, + struct tipc_subscr *s, + int conid); +void tipc_sub_unsubscribe(struct tipc_subscription *sub); +void tipc_sub_report_overlap(struct tipc_subscription *sub, + struct publication *p, + u32 event, bool must); + +int __net_init tipc_topsrv_init_net(struct net *net); +void __net_exit tipc_topsrv_exit_net(struct net *net); + +void tipc_sub_put(struct tipc_subscription *subscription); +void tipc_sub_get(struct tipc_subscription *subscription); + +#define TIPC_FILTER_MASK (TIPC_SUB_PORTS | TIPC_SUB_SERVICE | TIPC_SUB_CANCEL) + +/* tipc_sub_read - return field_ of struct sub_ in host endian format + */ +#define tipc_sub_read(sub_, field_) \ + ({ \ + struct tipc_subscr *sub__ = sub_; \ + u32 val__ = (sub__)->field_; \ + int swap_ = !((sub__)->filter & TIPC_FILTER_MASK); \ + (swap_ ? swab32(val__) : val__); \ + }) + +/* tipc_sub_write - write val_ to field_ of struct sub_ in user endian format + */ +#define tipc_sub_write(sub_, field_, val_) \ + ({ \ + struct tipc_subscr *sub__ = sub_; \ + u32 val__ = val_; \ + int swap_ = !((sub__)->filter & TIPC_FILTER_MASK); \ + (sub__)->field_ = swap_ ? swab32(val__) : val__; \ + }) + +/* tipc_evt_write - write val_ to field_ of struct evt_ in user endian format + */ +#define tipc_evt_write(evt_, field_, val_) \ + ({ \ + struct tipc_event *evt__ = evt_; \ + u32 val__ = val_; \ + int swap_ = !((evt__)->s.filter & (TIPC_FILTER_MASK)); \ + (evt__)->field_ = swap_ ? swab32(val__) : val__; \ + }) + +#endif diff --git a/net/tipc/sysctl.c b/net/tipc/sysctl.c new file mode 100644 index 000000000..9fb65c988 --- /dev/null +++ b/net/tipc/sysctl.c @@ -0,0 +1,108 @@ +/* + * net/tipc/sysctl.c: sysctl interface to TIPC subsystem + * + * Copyright (c) 2013, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. 
Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "core.h" +#include "trace.h" +#include "crypto.h" +#include "bcast.h" +#include + +static struct ctl_table_header *tipc_ctl_hdr; + +static struct ctl_table tipc_table[] = { + { + .procname = "tipc_rmem", + .data = &sysctl_tipc_rmem, + .maxlen = sizeof(sysctl_tipc_rmem), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "named_timeout", + .data = &sysctl_tipc_named_timeout, + .maxlen = sizeof(sysctl_tipc_named_timeout), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + }, + { + .procname = "sk_filter", + .data = &sysctl_tipc_sk_filter, + .maxlen = sizeof(sysctl_tipc_sk_filter), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, +#ifdef CONFIG_TIPC_CRYPTO + { + .procname = "max_tfms", + .data = &sysctl_tipc_max_tfms, + .maxlen = sizeof(sysctl_tipc_max_tfms), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ONE, + }, + { + .procname = "key_exchange_enabled", + .data = &sysctl_tipc_key_exchange_enabled, + .maxlen = sizeof(sysctl_tipc_key_exchange_enabled), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = SYSCTL_ZERO, + .extra2 = SYSCTL_ONE, + }, +#endif + { + .procname = "bc_retruni", + .data = &sysctl_tipc_bc_retruni, + .maxlen = sizeof(sysctl_tipc_bc_retruni), + .mode = 0644, + .proc_handler = proc_doulongvec_minmax, + }, + {} +}; + +int tipc_register_sysctl(void) +{ + tipc_ctl_hdr = register_net_sysctl(&init_net, "net/tipc", tipc_table); + if (tipc_ctl_hdr == NULL) + return -ENOMEM; + return 0; +} + +void tipc_unregister_sysctl(void) +{ + unregister_net_sysctl_table(tipc_ctl_hdr); +} diff --git a/net/tipc/topsrv.c b/net/tipc/topsrv.c new file mode 100644 index 000000000..69c88cc03 --- /dev/null +++ b/net/tipc/topsrv.c @@ -0,0 +1,726 @@ +/* + * net/tipc/server.c: TIPC server infrastructure + * + * Copyright (c) 2012-2013, Wind River Systems + * Copyright (c) 2017-2018, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. 
Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#include "subscr.h" +#include "topsrv.h" +#include "core.h" +#include "socket.h" +#include "addr.h" +#include "msg.h" +#include "bearer.h" +#include +#include + +/* Number of messages to send before rescheduling */ +#define MAX_SEND_MSG_COUNT 25 +#define MAX_RECV_MSG_COUNT 25 +#define CF_CONNECTED 1 + +#define TIPC_SERVER_NAME_LEN 32 + +/** + * struct tipc_topsrv - TIPC server structure + * @conn_idr: identifier set of connection + * @idr_lock: protect the connection identifier set + * @idr_in_use: amount of allocated identifier entry + * @net: network namspace instance + * @awork: accept work item + * @rcv_wq: receive workqueue + * @send_wq: send workqueue + * @listener: topsrv listener socket + * @name: server name + */ +struct tipc_topsrv { + struct idr conn_idr; + spinlock_t idr_lock; /* for idr list */ + int idr_in_use; + struct net *net; + struct work_struct awork; + struct workqueue_struct *rcv_wq; + struct workqueue_struct *send_wq; + struct socket *listener; + char name[TIPC_SERVER_NAME_LEN]; +}; + +/** + * struct tipc_conn - TIPC connection structure + * @kref: reference counter to connection object + * @conid: connection identifier + * @sock: socket handler associated with connection + * @flags: indicates connection state + * @server: pointer to connected server + * @sub_list: lsit to all pertaing subscriptions + * @sub_lock: lock protecting the subscription list + * @rwork: receive work item + * @outqueue: pointer to first outbound message in queue + * @outqueue_lock: control access to the outqueue + * @swork: send work item + */ +struct tipc_conn { + struct kref kref; + int conid; + struct socket *sock; + unsigned long flags; + struct tipc_topsrv *server; + struct list_head sub_list; + spinlock_t sub_lock; /* for subscription list */ + struct work_struct rwork; + struct list_head outqueue; + spinlock_t outqueue_lock; /* for outqueue */ + struct work_struct swork; +}; + +/* An entry waiting to be sent */ +struct outqueue_entry { + bool inactive; + struct tipc_event evt; + struct list_head list; +}; + +static void tipc_conn_recv_work(struct work_struct *work); +static void tipc_conn_send_work(struct 
work_struct *work); +static void tipc_topsrv_kern_evt(struct net *net, struct tipc_event *evt); +static void tipc_conn_delete_sub(struct tipc_conn *con, struct tipc_subscr *s); + +static bool connected(struct tipc_conn *con) +{ + return con && test_bit(CF_CONNECTED, &con->flags); +} + +static void tipc_conn_kref_release(struct kref *kref) +{ + struct tipc_conn *con = container_of(kref, struct tipc_conn, kref); + struct tipc_topsrv *s = con->server; + struct outqueue_entry *e, *safe; + + spin_lock_bh(&s->idr_lock); + idr_remove(&s->conn_idr, con->conid); + s->idr_in_use--; + spin_unlock_bh(&s->idr_lock); + if (con->sock) + sock_release(con->sock); + + spin_lock_bh(&con->outqueue_lock); + list_for_each_entry_safe(e, safe, &con->outqueue, list) { + list_del(&e->list); + kfree(e); + } + spin_unlock_bh(&con->outqueue_lock); + kfree(con); +} + +static void conn_put(struct tipc_conn *con) +{ + kref_put(&con->kref, tipc_conn_kref_release); +} + +static void conn_get(struct tipc_conn *con) +{ + kref_get(&con->kref); +} + +static void tipc_conn_close(struct tipc_conn *con) +{ + struct sock *sk = con->sock->sk; + bool disconnect = false; + + write_lock_bh(&sk->sk_callback_lock); + disconnect = test_and_clear_bit(CF_CONNECTED, &con->flags); + + if (disconnect) { + sk->sk_user_data = NULL; + tipc_conn_delete_sub(con, NULL); + } + write_unlock_bh(&sk->sk_callback_lock); + + /* Handle concurrent calls from sending and receiving threads */ + if (!disconnect) + return; + + /* Don't flush pending works, -just let them expire */ + kernel_sock_shutdown(con->sock, SHUT_RDWR); + + conn_put(con); +} + +static struct tipc_conn *tipc_conn_alloc(struct tipc_topsrv *s, struct socket *sock) +{ + struct tipc_conn *con; + int ret; + + con = kzalloc(sizeof(*con), GFP_ATOMIC); + if (!con) + return ERR_PTR(-ENOMEM); + + kref_init(&con->kref); + INIT_LIST_HEAD(&con->outqueue); + INIT_LIST_HEAD(&con->sub_list); + spin_lock_init(&con->outqueue_lock); + spin_lock_init(&con->sub_lock); + INIT_WORK(&con->swork, tipc_conn_send_work); + INIT_WORK(&con->rwork, tipc_conn_recv_work); + + spin_lock_bh(&s->idr_lock); + ret = idr_alloc(&s->conn_idr, con, 0, 0, GFP_ATOMIC); + if (ret < 0) { + kfree(con); + spin_unlock_bh(&s->idr_lock); + return ERR_PTR(-ENOMEM); + } + con->conid = ret; + s->idr_in_use++; + + set_bit(CF_CONNECTED, &con->flags); + con->server = s; + con->sock = sock; + conn_get(con); + spin_unlock_bh(&s->idr_lock); + + return con; +} + +static struct tipc_conn *tipc_conn_lookup(struct tipc_topsrv *s, int conid) +{ + struct tipc_conn *con; + + spin_lock_bh(&s->idr_lock); + con = idr_find(&s->conn_idr, conid); + if (!connected(con) || !kref_get_unless_zero(&con->kref)) + con = NULL; + spin_unlock_bh(&s->idr_lock); + return con; +} + +/* tipc_conn_delete_sub - delete a specific or all subscriptions + * for a given subscriber + */ +static void tipc_conn_delete_sub(struct tipc_conn *con, struct tipc_subscr *s) +{ + struct tipc_net *tn = tipc_net(con->server->net); + struct list_head *sub_list = &con->sub_list; + struct tipc_subscription *sub, *tmp; + + spin_lock_bh(&con->sub_lock); + list_for_each_entry_safe(sub, tmp, sub_list, sub_list) { + if (!s || !memcmp(s, &sub->evt.s, sizeof(*s))) { + tipc_sub_unsubscribe(sub); + atomic_dec(&tn->subscription_count); + if (s) + break; + } + } + spin_unlock_bh(&con->sub_lock); +} + +static void tipc_conn_send_to_sock(struct tipc_conn *con) +{ + struct list_head *queue = &con->outqueue; + struct tipc_topsrv *srv = con->server; + struct outqueue_entry *e; + struct tipc_event *evt; + struct 
msghdr msg; + struct kvec iov; + int count = 0; + int ret; + + spin_lock_bh(&con->outqueue_lock); + + while (!list_empty(queue)) { + e = list_first_entry(queue, struct outqueue_entry, list); + evt = &e->evt; + spin_unlock_bh(&con->outqueue_lock); + + if (e->inactive) + tipc_conn_delete_sub(con, &evt->s); + + memset(&msg, 0, sizeof(msg)); + msg.msg_flags = MSG_DONTWAIT; + iov.iov_base = evt; + iov.iov_len = sizeof(*evt); + msg.msg_name = NULL; + + if (con->sock) { + ret = kernel_sendmsg(con->sock, &msg, &iov, + 1, sizeof(*evt)); + if (ret == -EWOULDBLOCK || ret == 0) { + cond_resched(); + return; + } else if (ret < 0) { + return tipc_conn_close(con); + } + } else { + tipc_topsrv_kern_evt(srv->net, evt); + } + + /* Don't starve users filling buffers */ + if (++count >= MAX_SEND_MSG_COUNT) { + cond_resched(); + count = 0; + } + spin_lock_bh(&con->outqueue_lock); + list_del(&e->list); + kfree(e); + } + spin_unlock_bh(&con->outqueue_lock); +} + +static void tipc_conn_send_work(struct work_struct *work) +{ + struct tipc_conn *con = container_of(work, struct tipc_conn, swork); + + if (connected(con)) + tipc_conn_send_to_sock(con); + + conn_put(con); +} + +/* tipc_topsrv_queue_evt() - interrupt level call from a subscription instance + * The queued work is launched into tipc_conn_send_work()->tipc_conn_send_to_sock() + */ +void tipc_topsrv_queue_evt(struct net *net, int conid, + u32 event, struct tipc_event *evt) +{ + struct tipc_topsrv *srv = tipc_topsrv(net); + struct outqueue_entry *e; + struct tipc_conn *con; + + con = tipc_conn_lookup(srv, conid); + if (!con) + return; + + if (!connected(con)) + goto err; + + e = kmalloc(sizeof(*e), GFP_ATOMIC); + if (!e) + goto err; + e->inactive = (event == TIPC_SUBSCR_TIMEOUT); + memcpy(&e->evt, evt, sizeof(*evt)); + spin_lock_bh(&con->outqueue_lock); + list_add_tail(&e->list, &con->outqueue); + spin_unlock_bh(&con->outqueue_lock); + + if (queue_work(srv->send_wq, &con->swork)) + return; +err: + conn_put(con); +} + +/* tipc_conn_write_space - interrupt callback after a sendmsg EAGAIN + * Indicates that there now is more space in the send buffer + * The queued work is launched into tipc_send_work()->tipc_conn_send_to_sock() + */ +static void tipc_conn_write_space(struct sock *sk) +{ + struct tipc_conn *con; + + read_lock_bh(&sk->sk_callback_lock); + con = sk->sk_user_data; + if (connected(con)) { + conn_get(con); + if (!queue_work(con->server->send_wq, &con->swork)) + conn_put(con); + } + read_unlock_bh(&sk->sk_callback_lock); +} + +static int tipc_conn_rcv_sub(struct tipc_topsrv *srv, + struct tipc_conn *con, + struct tipc_subscr *s) +{ + struct tipc_net *tn = tipc_net(srv->net); + struct tipc_subscription *sub; + u32 s_filter = tipc_sub_read(s, filter); + + if (s_filter & TIPC_SUB_CANCEL) { + tipc_sub_write(s, filter, s_filter & ~TIPC_SUB_CANCEL); + tipc_conn_delete_sub(con, s); + return 0; + } + if (atomic_read(&tn->subscription_count) >= TIPC_MAX_SUBSCR) { + pr_warn("Subscription rejected, max (%u)\n", TIPC_MAX_SUBSCR); + return -1; + } + sub = tipc_sub_subscribe(srv->net, s, con->conid); + if (!sub) + return -1; + atomic_inc(&tn->subscription_count); + spin_lock_bh(&con->sub_lock); + list_add(&sub->sub_list, &con->sub_list); + spin_unlock_bh(&con->sub_lock); + return 0; +} + +static int tipc_conn_rcv_from_sock(struct tipc_conn *con) +{ + struct tipc_topsrv *srv = con->server; + struct sock *sk = con->sock->sk; + struct msghdr msg = {}; + struct tipc_subscr s; + struct kvec iov; + int ret; + + iov.iov_base = &s; + iov.iov_len = sizeof(s); + msg.msg_name 
= NULL; + iov_iter_kvec(&msg.msg_iter, ITER_DEST, &iov, 1, iov.iov_len); + ret = sock_recvmsg(con->sock, &msg, MSG_DONTWAIT); + if (ret == -EWOULDBLOCK) + return -EWOULDBLOCK; + if (ret == sizeof(s)) { + read_lock_bh(&sk->sk_callback_lock); + /* RACE: the connection can be closed in the meantime */ + if (likely(connected(con))) + ret = tipc_conn_rcv_sub(srv, con, &s); + read_unlock_bh(&sk->sk_callback_lock); + if (!ret) + return 0; + } + + tipc_conn_close(con); + return ret; +} + +static void tipc_conn_recv_work(struct work_struct *work) +{ + struct tipc_conn *con = container_of(work, struct tipc_conn, rwork); + int count = 0; + + while (connected(con)) { + if (tipc_conn_rcv_from_sock(con)) + break; + + /* Don't flood Rx machine */ + if (++count >= MAX_RECV_MSG_COUNT) { + cond_resched(); + count = 0; + } + } + conn_put(con); +} + +/* tipc_conn_data_ready - interrupt callback indicating the socket has data + * The queued work is launched into tipc_recv_work()->tipc_conn_rcv_from_sock() + */ +static void tipc_conn_data_ready(struct sock *sk) +{ + struct tipc_conn *con; + + read_lock_bh(&sk->sk_callback_lock); + con = sk->sk_user_data; + if (connected(con)) { + conn_get(con); + if (!queue_work(con->server->rcv_wq, &con->rwork)) + conn_put(con); + } + read_unlock_bh(&sk->sk_callback_lock); +} + +static void tipc_topsrv_accept(struct work_struct *work) +{ + struct tipc_topsrv *srv = container_of(work, struct tipc_topsrv, awork); + struct socket *newsock, *lsock; + struct tipc_conn *con; + struct sock *newsk; + int ret; + + spin_lock_bh(&srv->idr_lock); + if (!srv->listener) { + spin_unlock_bh(&srv->idr_lock); + return; + } + lsock = srv->listener; + spin_unlock_bh(&srv->idr_lock); + + while (1) { + ret = kernel_accept(lsock, &newsock, O_NONBLOCK); + if (ret < 0) + return; + con = tipc_conn_alloc(srv, newsock); + if (IS_ERR(con)) { + ret = PTR_ERR(con); + sock_release(newsock); + return; + } + /* Register callbacks */ + newsk = newsock->sk; + write_lock_bh(&newsk->sk_callback_lock); + newsk->sk_data_ready = tipc_conn_data_ready; + newsk->sk_write_space = tipc_conn_write_space; + newsk->sk_user_data = con; + write_unlock_bh(&newsk->sk_callback_lock); + + /* Wake up receive process in case of 'SYN+' message */ + newsk->sk_data_ready(newsk); + conn_put(con); + } +} + +/* tipc_topsrv_listener_data_ready - interrupt callback with connection request + * The queued job is launched into tipc_topsrv_accept() + */ +static void tipc_topsrv_listener_data_ready(struct sock *sk) +{ + struct tipc_topsrv *srv; + + read_lock_bh(&sk->sk_callback_lock); + srv = sk->sk_user_data; + if (srv) + queue_work(srv->rcv_wq, &srv->awork); + read_unlock_bh(&sk->sk_callback_lock); +} + +static int tipc_topsrv_create_listener(struct tipc_topsrv *srv) +{ + struct socket *lsock = NULL; + struct sockaddr_tipc saddr; + struct sock *sk; + int rc; + + rc = sock_create_kern(srv->net, AF_TIPC, SOCK_SEQPACKET, 0, &lsock); + if (rc < 0) + return rc; + + srv->listener = lsock; + sk = lsock->sk; + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = tipc_topsrv_listener_data_ready; + sk->sk_user_data = srv; + write_unlock_bh(&sk->sk_callback_lock); + + lock_sock(sk); + rc = tsk_set_importance(sk, TIPC_CRITICAL_IMPORTANCE); + release_sock(sk); + if (rc < 0) + goto err; + + saddr.family = AF_TIPC; + saddr.addrtype = TIPC_SERVICE_RANGE; + saddr.addr.nameseq.type = TIPC_TOP_SRV; + saddr.addr.nameseq.lower = TIPC_TOP_SRV; + saddr.addr.nameseq.upper = TIPC_TOP_SRV; + saddr.scope = TIPC_NODE_SCOPE; + + rc = tipc_sk_bind(lsock, (struct 
sockaddr *)&saddr, sizeof(saddr)); + if (rc < 0) + goto err; + rc = kernel_listen(lsock, 0); + if (rc < 0) + goto err; + + /* As server's listening socket owner and creator is the same module, + * we have to decrease TIPC module reference count to guarantee that + * it remains zero after the server socket is created, otherwise, + * executing "rmmod" command is unable to make TIPC module deleted + * after TIPC module is inserted successfully. + * + * However, the reference count is ever increased twice in + * sock_create_kern(): one is to increase the reference count of owner + * of TIPC socket's proto_ops struct; another is to increment the + * reference count of owner of TIPC proto struct. Therefore, we must + * decrement the module reference count twice to ensure that it keeps + * zero after server's listening socket is created. Of course, we + * must bump the module reference count twice as well before the socket + * is closed. + */ + module_put(lsock->ops->owner); + module_put(sk->sk_prot_creator->owner); + + return 0; +err: + sock_release(lsock); + return -EINVAL; +} + +bool tipc_topsrv_kern_subscr(struct net *net, u32 port, u32 type, u32 lower, + u32 upper, u32 filter, int *conid) +{ + struct tipc_subscr sub; + struct tipc_conn *con; + int rc; + + sub.seq.type = type; + sub.seq.lower = lower; + sub.seq.upper = upper; + sub.timeout = TIPC_WAIT_FOREVER; + sub.filter = filter; + *(u64 *)&sub.usr_handle = (u64)port; + + con = tipc_conn_alloc(tipc_topsrv(net), NULL); + if (IS_ERR(con)) + return false; + + *conid = con->conid; + rc = tipc_conn_rcv_sub(tipc_topsrv(net), con, &sub); + if (rc) + conn_put(con); + + conn_put(con); + return !rc; +} + +void tipc_topsrv_kern_unsubscr(struct net *net, int conid) +{ + struct tipc_conn *con; + + con = tipc_conn_lookup(tipc_topsrv(net), conid); + if (!con) + return; + + test_and_clear_bit(CF_CONNECTED, &con->flags); + tipc_conn_delete_sub(con, NULL); + conn_put(con); + conn_put(con); +} + +static void tipc_topsrv_kern_evt(struct net *net, struct tipc_event *evt) +{ + u32 port = *(u32 *)&evt->s.usr_handle; + u32 self = tipc_own_addr(net); + struct sk_buff_head evtq; + struct sk_buff *skb; + + skb = tipc_msg_create(TOP_SRV, 0, INT_H_SIZE, sizeof(*evt), + self, self, port, port, 0); + if (!skb) + return; + msg_set_dest_droppable(buf_msg(skb), true); + memcpy(msg_data(buf_msg(skb)), evt, sizeof(*evt)); + skb_queue_head_init(&evtq); + __skb_queue_tail(&evtq, skb); + tipc_loopback_trace(net, &evtq); + tipc_sk_rcv(net, &evtq); +} + +static int tipc_topsrv_work_start(struct tipc_topsrv *s) +{ + s->rcv_wq = alloc_ordered_workqueue("tipc_rcv", 0); + if (!s->rcv_wq) { + pr_err("can't start tipc receive workqueue\n"); + return -ENOMEM; + } + + s->send_wq = alloc_ordered_workqueue("tipc_send", 0); + if (!s->send_wq) { + pr_err("can't start tipc send workqueue\n"); + destroy_workqueue(s->rcv_wq); + return -ENOMEM; + } + + return 0; +} + +static void tipc_topsrv_work_stop(struct tipc_topsrv *s) +{ + destroy_workqueue(s->rcv_wq); + destroy_workqueue(s->send_wq); +} + +static int tipc_topsrv_start(struct net *net) +{ + struct tipc_net *tn = tipc_net(net); + const char name[] = "topology_server"; + struct tipc_topsrv *srv; + int ret; + + srv = kzalloc(sizeof(*srv), GFP_ATOMIC); + if (!srv) + return -ENOMEM; + + srv->net = net; + INIT_WORK(&srv->awork, tipc_topsrv_accept); + + strscpy(srv->name, name, sizeof(srv->name)); + tn->topsrv = srv; + atomic_set(&tn->subscription_count, 0); + + spin_lock_init(&srv->idr_lock); + idr_init(&srv->conn_idr); + srv->idr_in_use = 0; + 
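+	/* Ordering note: the receive/send workqueues must exist before the
+	 * listener socket is created below, because the listener's data-ready
+	 * callback immediately queues the accept work onto rcv_wq.
+	 */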
+ ret = tipc_topsrv_work_start(srv); + if (ret < 0) + goto err_start; + + ret = tipc_topsrv_create_listener(srv); + if (ret < 0) + goto err_create; + + return 0; + +err_create: + tipc_topsrv_work_stop(srv); +err_start: + kfree(srv); + return ret; +} + +static void tipc_topsrv_stop(struct net *net) +{ + struct tipc_topsrv *srv = tipc_topsrv(net); + struct socket *lsock = srv->listener; + struct tipc_conn *con; + int id; + + spin_lock_bh(&srv->idr_lock); + for (id = 0; srv->idr_in_use; id++) { + con = idr_find(&srv->conn_idr, id); + if (con) { + spin_unlock_bh(&srv->idr_lock); + tipc_conn_close(con); + spin_lock_bh(&srv->idr_lock); + } + } + __module_get(lsock->ops->owner); + __module_get(lsock->sk->sk_prot_creator->owner); + srv->listener = NULL; + spin_unlock_bh(&srv->idr_lock); + + tipc_topsrv_work_stop(srv); + sock_release(lsock); + idr_destroy(&srv->conn_idr); + kfree(srv); +} + +int __net_init tipc_topsrv_init_net(struct net *net) +{ + return tipc_topsrv_start(net); +} + +void __net_exit tipc_topsrv_exit_net(struct net *net) +{ + tipc_topsrv_stop(net); +} diff --git a/net/tipc/topsrv.h b/net/tipc/topsrv.h new file mode 100644 index 000000000..c7ea71293 --- /dev/null +++ b/net/tipc/topsrv.h @@ -0,0 +1,54 @@ +/* + * net/tipc/server.h: Include file for TIPC server code + * + * Copyright (c) 2012-2013, Wind River Systems + * Copyright (c) 2017, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#ifndef _TIPC_SERVER_H +#define _TIPC_SERVER_H + +#include "core.h" + +#define TIPC_SERVER_NAME_LEN 32 +#define TIPC_SUB_CLUSTER_SCOPE 0x20 +#define TIPC_SUB_NODE_SCOPE 0x40 +#define TIPC_SUB_NO_STATUS 0x80 + +void tipc_topsrv_queue_evt(struct net *net, int conid, + u32 event, struct tipc_event *evt); + +bool tipc_topsrv_kern_subscr(struct net *net, u32 port, u32 type, u32 lower, + u32 upper, u32 filter, int *conid); +void tipc_topsrv_kern_unsubscr(struct net *net, int conid); + +#endif diff --git a/net/tipc/trace.c b/net/tipc/trace.c new file mode 100644 index 000000000..7d2931521 --- /dev/null +++ b/net/tipc/trace.c @@ -0,0 +1,206 @@ +/* + * net/tipc/trace.c: TIPC tracepoints code + * + * Copyright (c) 2018, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "ASIS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#define CREATE_TRACE_POINTS +#include "trace.h" + +/* + * socket tuples for filtering in socket traces: + * (portid, sock type, name type, name lower, name upper) + */ +unsigned long sysctl_tipc_sk_filter[5] __read_mostly = {0, }; + +/** + * tipc_skb_dump - dump TIPC skb data + * @skb: skb to be dumped + * @more: dump more? + * - false: dump only tipc msg data + * - true: dump kernel-related skb data and tipc cb[] array as well + * @buf: returned buffer of dump data in format + */ +int tipc_skb_dump(struct sk_buff *skb, bool more, char *buf) +{ + int i = 0; + size_t sz = (more) ? 
SKB_LMAX : SKB_LMIN; + struct tipc_msg *hdr; + struct tipc_skb_cb *skbcb; + + if (!skb) { + i += scnprintf(buf, sz, "msg: (null)\n"); + return i; + } + + hdr = buf_msg(skb); + skbcb = TIPC_SKB_CB(skb); + + /* tipc msg data section */ + i += scnprintf(buf, sz, "msg: %u", msg_user(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_type(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_hdr_sz(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_data_sz(hdr)); + i += scnprintf(buf + i, sz - i, " %x", msg_orignode(hdr)); + i += scnprintf(buf + i, sz - i, " %x", msg_destnode(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_seqno(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_ack(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_bcast_ack(hdr)); + switch (msg_user(hdr)) { + case LINK_PROTOCOL: + i += scnprintf(buf + i, sz - i, " %c", msg_net_plane(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_probe(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_peer_stopping(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_session(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_next_sent(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_seq_gap(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_bc_snd_nxt(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_bc_gap(hdr)); + break; + case TIPC_LOW_IMPORTANCE: + case TIPC_MEDIUM_IMPORTANCE: + case TIPC_HIGH_IMPORTANCE: + case TIPC_CRITICAL_IMPORTANCE: + case CONN_MANAGER: + case SOCK_WAKEUP: + i += scnprintf(buf + i, sz - i, " | %u", msg_origport(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_destport(hdr)); + switch (msg_type(hdr)) { + case TIPC_NAMED_MSG: + i += scnprintf(buf + i, sz - i, " %u", + msg_nametype(hdr)); + i += scnprintf(buf + i, sz - i, " %u", + msg_nameinst(hdr)); + break; + case TIPC_MCAST_MSG: + i += scnprintf(buf + i, sz - i, " %u", + msg_nametype(hdr)); + i += scnprintf(buf + i, sz - i, " %u", + msg_namelower(hdr)); + i += scnprintf(buf + i, sz - i, " %u", + msg_nameupper(hdr)); + break; + default: + break; + } + i += scnprintf(buf + i, sz - i, " | %u", + msg_src_droppable(hdr)); + i += scnprintf(buf + i, sz - i, " %u", + msg_dest_droppable(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_errcode(hdr)); + i += scnprintf(buf + i, sz - i, " %u", msg_reroute_cnt(hdr)); + break; + default: + /* need more? */ + break; + } + + i += scnprintf(buf + i, sz - i, "\n"); + if (!more) + return i; + + /* kernel-related skb data section */ + i += scnprintf(buf + i, sz - i, "skb: %s", + (skb->dev) ? 
skb->dev->name : "n/a"); + i += scnprintf(buf + i, sz - i, " %u", skb->len); + i += scnprintf(buf + i, sz - i, " %u", skb->data_len); + i += scnprintf(buf + i, sz - i, " %u", skb->hdr_len); + i += scnprintf(buf + i, sz - i, " %u", skb->truesize); + i += scnprintf(buf + i, sz - i, " %u", skb_cloned(skb)); + i += scnprintf(buf + i, sz - i, " %p", skb->sk); + i += scnprintf(buf + i, sz - i, " %u", skb_shinfo(skb)->nr_frags); + i += scnprintf(buf + i, sz - i, " %llx", + ktime_to_ms(skb_get_ktime(skb))); + i += scnprintf(buf + i, sz - i, " %llx\n", + ktime_to_ms(skb_hwtstamps(skb)->hwtstamp)); + + /* tipc skb cb[] data section */ + i += scnprintf(buf + i, sz - i, "cb[]: %u", skbcb->bytes_read); + i += scnprintf(buf + i, sz - i, " %u", skbcb->orig_member); + i += scnprintf(buf + i, sz - i, " %u", + jiffies_to_msecs(skbcb->nxt_retr)); + i += scnprintf(buf + i, sz - i, " %u", skbcb->validated); + i += scnprintf(buf + i, sz - i, " %u", skbcb->chain_imp); + i += scnprintf(buf + i, sz - i, " %u\n", skbcb->ackers); + + return i; +} + +/** + * tipc_list_dump - dump TIPC skb list/queue + * @list: list of skbs to be dumped + * @more: dump more? + * - false: dump only the head & tail skbs + * - true: dump the first & last 5 skbs + * @buf: returned buffer of dump data in format + */ +int tipc_list_dump(struct sk_buff_head *list, bool more, char *buf) +{ + int i = 0; + size_t sz = (more) ? LIST_LMAX : LIST_LMIN; + u32 count, len; + struct sk_buff *hskb, *tskb, *skb, *tmp; + + if (!list) { + i += scnprintf(buf, sz, "(null)\n"); + return i; + } + + len = skb_queue_len(list); + i += scnprintf(buf, sz, "len = %d\n", len); + + if (!len) + return i; + + if (!more) { + hskb = skb_peek(list); + i += scnprintf(buf + i, sz - i, " head "); + i += tipc_skb_dump(hskb, false, buf + i); + if (len > 1) { + tskb = skb_peek_tail(list); + i += scnprintf(buf + i, sz - i, " tail "); + i += tipc_skb_dump(tskb, false, buf + i); + } + } else { + count = 0; + skb_queue_walk_safe(list, skb, tmp) { + count++; + if (count == 6) + i += scnprintf(buf + i, sz - i, " .\n .\n"); + if (count > 5 && count <= len - 5) + continue; + i += scnprintf(buf + i, sz - i, " #%d ", count); + i += tipc_skb_dump(skb, false, buf + i); + } + } + return i; +} diff --git a/net/tipc/trace.h b/net/tipc/trace.h new file mode 100644 index 000000000..04af83f05 --- /dev/null +++ b/net/tipc/trace.h @@ -0,0 +1,434 @@ +/* + * net/tipc/trace.h: TIPC tracepoints + * + * Copyright (c) 2018, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "ASIS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM tipc + +#if !defined(_TIPC_TRACE_H) || defined(TRACE_HEADER_MULTI_READ) +#define _TIPC_TRACE_H + +#include +#include "core.h" +#include "link.h" +#include "socket.h" +#include "node.h" + +#define SKB_LMIN (100) +#define SKB_LMAX (SKB_LMIN * 2) +#define LIST_LMIN (SKB_LMIN * 3) +#define LIST_LMAX (SKB_LMIN * 11) +#define SK_LMIN (SKB_LMIN * 2) +#define SK_LMAX (SKB_LMIN * 11) +#define LINK_LMIN (SKB_LMIN) +#define LINK_LMAX (SKB_LMIN * 16) +#define NODE_LMIN (SKB_LMIN) +#define NODE_LMAX (SKB_LMIN * 11) + +#ifndef __TIPC_TRACE_ENUM +#define __TIPC_TRACE_ENUM +enum { + TIPC_DUMP_NONE = 0, + + TIPC_DUMP_TRANSMQ = 1, + TIPC_DUMP_BACKLOGQ = (1 << 1), + TIPC_DUMP_DEFERDQ = (1 << 2), + TIPC_DUMP_INPUTQ = (1 << 3), + TIPC_DUMP_WAKEUP = (1 << 4), + + TIPC_DUMP_SK_SNDQ = (1 << 8), + TIPC_DUMP_SK_RCVQ = (1 << 9), + TIPC_DUMP_SK_BKLGQ = (1 << 10), + TIPC_DUMP_ALL = 0xffffu +}; +#endif + +/* Link & Node FSM states: */ +#define state_sym(val) \ + __print_symbolic(val, \ + {(0xe), "ESTABLISHED" },\ + {(0xe << 4), "ESTABLISHING" },\ + {(0x1 << 8), "RESET" },\ + {(0x2 << 12), "RESETTING" },\ + {(0xd << 16), "PEER_RESET" },\ + {(0xf << 20), "FAILINGOVER" },\ + {(0xc << 24), "SYNCHING" },\ + {(0xdd), "SELF_DOWN_PEER_DOWN" },\ + {(0xaa), "SELF_UP_PEER_UP" },\ + {(0xd1), "SELF_DOWN_PEER_LEAVING" },\ + {(0xac), "SELF_UP_PEER_COMING" },\ + {(0xca), "SELF_COMING_PEER_UP" },\ + {(0x1d), "SELF_LEAVING_PEER_DOWN" },\ + {(0xf0), "FAILINGOVER" },\ + {(0xcc), "SYNCHING" }) + +/* Link & Node FSM events: */ +#define evt_sym(val) \ + __print_symbolic(val, \ + {(0xec1ab1e), "ESTABLISH_EVT" },\ + {(0x9eed0e), "PEER_RESET_EVT" },\ + {(0xfa110e), "FAILURE_EVT" },\ + {(0x10ca1d0e), "RESET_EVT" },\ + {(0xfa110bee), "FAILOVER_BEGIN_EVT" },\ + {(0xfa110ede), "FAILOVER_END_EVT" },\ + {(0xc1ccbee), "SYNCH_BEGIN_EVT" },\ + {(0xc1ccede), "SYNCH_END_EVT" },\ + {(0xece), "SELF_ESTABL_CONTACT_EVT" },\ + {(0x1ce), "SELF_LOST_CONTACT_EVT" },\ + {(0x9ece), "PEER_ESTABL_CONTACT_EVT" },\ + {(0x91ce), "PEER_LOST_CONTACT_EVT" },\ + {(0xfbe), "FAILOVER_BEGIN_EVT" },\ + {(0xfee), "FAILOVER_END_EVT" },\ + {(0xcbe), "SYNCH_BEGIN_EVT" },\ + {(0xcee), "SYNCH_END_EVT" }) + +/* Bearer, net device events: */ +#define dev_evt_sym(val) \ + __print_symbolic(val, \ + {(NETDEV_CHANGE), "NETDEV_CHANGE" },\ + {(NETDEV_GOING_DOWN), "NETDEV_GOING_DOWN" },\ + {(NETDEV_UP), "NETDEV_UP" },\ + {(NETDEV_CHANGEMTU), "NETDEV_CHANGEMTU" },\ + {(NETDEV_CHANGEADDR), "NETDEV_CHANGEADDR" },\ + {(NETDEV_UNREGISTER), "NETDEV_UNREGISTER" },\ + {(NETDEV_CHANGENAME), "NETDEV_CHANGENAME" }) + +extern unsigned long sysctl_tipc_sk_filter[5] __read_mostly; + +int tipc_skb_dump(struct sk_buff *skb, bool more, char *buf); +int tipc_list_dump(struct 
sk_buff_head *list, bool more, char *buf); +int tipc_sk_dump(struct sock *sk, u16 dqueues, char *buf); +int tipc_link_dump(struct tipc_link *l, u16 dqueues, char *buf); +int tipc_node_dump(struct tipc_node *n, bool more, char *buf); +bool tipc_sk_filtering(struct sock *sk); + +DECLARE_EVENT_CLASS(tipc_skb_class, + + TP_PROTO(struct sk_buff *skb, bool more, const char *header), + + TP_ARGS(skb, more, header), + + TP_STRUCT__entry( + __string(header, header) + __dynamic_array(char, buf, (more) ? SKB_LMAX : SKB_LMIN) + ), + + TP_fast_assign( + __assign_str(header, header); + tipc_skb_dump(skb, more, __get_str(buf)); + ), + + TP_printk("%s\n%s", __get_str(header), __get_str(buf)) +) + +#define DEFINE_SKB_EVENT(name) \ +DEFINE_EVENT(tipc_skb_class, name, \ + TP_PROTO(struct sk_buff *skb, bool more, const char *header), \ + TP_ARGS(skb, more, header)) +DEFINE_SKB_EVENT(tipc_skb_dump); +DEFINE_SKB_EVENT(tipc_proto_build); +DEFINE_SKB_EVENT(tipc_proto_rcv); + +DECLARE_EVENT_CLASS(tipc_list_class, + + TP_PROTO(struct sk_buff_head *list, bool more, const char *header), + + TP_ARGS(list, more, header), + + TP_STRUCT__entry( + __string(header, header) + __dynamic_array(char, buf, (more) ? LIST_LMAX : LIST_LMIN) + ), + + TP_fast_assign( + __assign_str(header, header); + tipc_list_dump(list, more, __get_str(buf)); + ), + + TP_printk("%s\n%s", __get_str(header), __get_str(buf)) +); + +#define DEFINE_LIST_EVENT(name) \ +DEFINE_EVENT(tipc_list_class, name, \ + TP_PROTO(struct sk_buff_head *list, bool more, const char *header), \ + TP_ARGS(list, more, header)) +DEFINE_LIST_EVENT(tipc_list_dump); + +DECLARE_EVENT_CLASS(tipc_sk_class, + + TP_PROTO(struct sock *sk, struct sk_buff *skb, u16 dqueues, + const char *header), + + TP_ARGS(sk, skb, dqueues, header), + + TP_STRUCT__entry( + __string(header, header) + __field(u32, portid) + __dynamic_array(char, buf, (dqueues) ? SK_LMAX : SK_LMIN) + __dynamic_array(char, skb_buf, (skb) ? 
SKB_LMIN : 1) + ), + + TP_fast_assign( + __assign_str(header, header); + __entry->portid = tipc_sock_get_portid(sk); + tipc_sk_dump(sk, dqueues, __get_str(buf)); + if (skb) + tipc_skb_dump(skb, false, __get_str(skb_buf)); + else + *(__get_str(skb_buf)) = '\0'; + ), + + TP_printk("<%u> %s\n%s%s", __entry->portid, __get_str(header), + __get_str(skb_buf), __get_str(buf)) +); + +#define DEFINE_SK_EVENT_FILTER(name) \ +DEFINE_EVENT_CONDITION(tipc_sk_class, name, \ + TP_PROTO(struct sock *sk, struct sk_buff *skb, u16 dqueues, \ + const char *header), \ + TP_ARGS(sk, skb, dqueues, header), \ + TP_CONDITION(tipc_sk_filtering(sk))) +DEFINE_SK_EVENT_FILTER(tipc_sk_dump); +DEFINE_SK_EVENT_FILTER(tipc_sk_create); +DEFINE_SK_EVENT_FILTER(tipc_sk_sendmcast); +DEFINE_SK_EVENT_FILTER(tipc_sk_sendmsg); +DEFINE_SK_EVENT_FILTER(tipc_sk_sendstream); +DEFINE_SK_EVENT_FILTER(tipc_sk_poll); +DEFINE_SK_EVENT_FILTER(tipc_sk_filter_rcv); +DEFINE_SK_EVENT_FILTER(tipc_sk_advance_rx); +DEFINE_SK_EVENT_FILTER(tipc_sk_rej_msg); +DEFINE_SK_EVENT_FILTER(tipc_sk_drop_msg); +DEFINE_SK_EVENT_FILTER(tipc_sk_release); +DEFINE_SK_EVENT_FILTER(tipc_sk_shutdown); + +#define DEFINE_SK_EVENT_FILTER_COND(name, cond) \ +DEFINE_EVENT_CONDITION(tipc_sk_class, name, \ + TP_PROTO(struct sock *sk, struct sk_buff *skb, u16 dqueues, \ + const char *header), \ + TP_ARGS(sk, skb, dqueues, header), \ + TP_CONDITION(tipc_sk_filtering(sk) && (cond))) +DEFINE_SK_EVENT_FILTER_COND(tipc_sk_overlimit1, tipc_sk_overlimit1(sk, skb)); +DEFINE_SK_EVENT_FILTER_COND(tipc_sk_overlimit2, tipc_sk_overlimit2(sk, skb)); + +DECLARE_EVENT_CLASS(tipc_link_class, + + TP_PROTO(struct tipc_link *l, u16 dqueues, const char *header), + + TP_ARGS(l, dqueues, header), + + TP_STRUCT__entry( + __string(header, header) + __array(char, name, TIPC_MAX_LINK_NAME) + __dynamic_array(char, buf, (dqueues) ? LINK_LMAX : LINK_LMIN) + ), + + TP_fast_assign( + __assign_str(header, header); + memcpy(__entry->name, tipc_link_name(l), TIPC_MAX_LINK_NAME); + tipc_link_dump(l, dqueues, __get_str(buf)); + ), + + TP_printk("<%s> %s\n%s", __entry->name, __get_str(header), + __get_str(buf)) +); + +#define DEFINE_LINK_EVENT(name) \ +DEFINE_EVENT(tipc_link_class, name, \ + TP_PROTO(struct tipc_link *l, u16 dqueues, const char *header), \ + TP_ARGS(l, dqueues, header)) +DEFINE_LINK_EVENT(tipc_link_dump); +DEFINE_LINK_EVENT(tipc_link_conges); +DEFINE_LINK_EVENT(tipc_link_timeout); +DEFINE_LINK_EVENT(tipc_link_reset); + +#define DEFINE_LINK_EVENT_COND(name, cond) \ +DEFINE_EVENT_CONDITION(tipc_link_class, name, \ + TP_PROTO(struct tipc_link *l, u16 dqueues, const char *header), \ + TP_ARGS(l, dqueues, header), \ + TP_CONDITION(cond)) +DEFINE_LINK_EVENT_COND(tipc_link_too_silent, tipc_link_too_silent(l)); + +DECLARE_EVENT_CLASS(tipc_link_transmq_class, + + TP_PROTO(struct tipc_link *r, u16 f, u16 t, struct sk_buff_head *tq), + + TP_ARGS(r, f, t, tq), + + TP_STRUCT__entry( + __array(char, name, TIPC_MAX_LINK_NAME) + __field(u16, from) + __field(u16, to) + __field(u32, len) + __field(u16, fseqno) + __field(u16, lseqno) + ), + + TP_fast_assign( + memcpy(__entry->name, tipc_link_name(r), TIPC_MAX_LINK_NAME); + __entry->from = f; + __entry->to = t; + __entry->len = skb_queue_len(tq); + __entry->fseqno = __entry->len ? + msg_seqno(buf_msg(skb_peek(tq))) : 0; + __entry->lseqno = __entry->len ? 
+ msg_seqno(buf_msg(skb_peek_tail(tq))) : 0; + ), + + TP_printk("<%s> retrans req: [%u-%u] transmq: %u [%u-%u]\n", + __entry->name, __entry->from, __entry->to, + __entry->len, __entry->fseqno, __entry->lseqno) +); + +DEFINE_EVENT_CONDITION(tipc_link_transmq_class, tipc_link_retrans, + TP_PROTO(struct tipc_link *r, u16 f, u16 t, struct sk_buff_head *tq), + TP_ARGS(r, f, t, tq), + TP_CONDITION(less_eq(f, t)) +); + +DEFINE_EVENT_PRINT(tipc_link_transmq_class, tipc_link_bc_ack, + TP_PROTO(struct tipc_link *r, u16 f, u16 t, struct sk_buff_head *tq), + TP_ARGS(r, f, t, tq), + TP_printk("<%s> acked: %u gap: %u transmq: %u [%u-%u]\n", + __entry->name, __entry->from, __entry->to, + __entry->len, __entry->fseqno, __entry->lseqno) +); + +DECLARE_EVENT_CLASS(tipc_node_class, + + TP_PROTO(struct tipc_node *n, bool more, const char *header), + + TP_ARGS(n, more, header), + + TP_STRUCT__entry( + __string(header, header) + __field(u32, addr) + __dynamic_array(char, buf, (more) ? NODE_LMAX : NODE_LMIN) + ), + + TP_fast_assign( + __assign_str(header, header); + __entry->addr = tipc_node_get_addr(n); + tipc_node_dump(n, more, __get_str(buf)); + ), + + TP_printk("<%x> %s\n%s", __entry->addr, __get_str(header), + __get_str(buf)) +); + +#define DEFINE_NODE_EVENT(name) \ +DEFINE_EVENT(tipc_node_class, name, \ + TP_PROTO(struct tipc_node *n, bool more, const char *header), \ + TP_ARGS(n, more, header)) +DEFINE_NODE_EVENT(tipc_node_dump); +DEFINE_NODE_EVENT(tipc_node_create); +DEFINE_NODE_EVENT(tipc_node_delete); +DEFINE_NODE_EVENT(tipc_node_lost_contact); +DEFINE_NODE_EVENT(tipc_node_timeout); +DEFINE_NODE_EVENT(tipc_node_link_up); +DEFINE_NODE_EVENT(tipc_node_link_down); +DEFINE_NODE_EVENT(tipc_node_reset_links); +DEFINE_NODE_EVENT(tipc_node_check_state); + +DECLARE_EVENT_CLASS(tipc_fsm_class, + + TP_PROTO(const char *name, u32 os, u32 ns, int evt), + + TP_ARGS(name, os, ns, evt), + + TP_STRUCT__entry( + __string(name, name) + __field(u32, os) + __field(u32, ns) + __field(u32, evt) + ), + + TP_fast_assign( + __assign_str(name, name); + __entry->os = os; + __entry->ns = ns; + __entry->evt = evt; + ), + + TP_printk("<%s> %s--(%s)->%s\n", __get_str(name), + state_sym(__entry->os), evt_sym(__entry->evt), + state_sym(__entry->ns)) +); + +#define DEFINE_FSM_EVENT(fsm_name) \ +DEFINE_EVENT(tipc_fsm_class, fsm_name, \ + TP_PROTO(const char *name, u32 os, u32 ns, int evt), \ + TP_ARGS(name, os, ns, evt)) +DEFINE_FSM_EVENT(tipc_link_fsm); +DEFINE_FSM_EVENT(tipc_node_fsm); + +TRACE_EVENT(tipc_l2_device_event, + + TP_PROTO(struct net_device *dev, struct tipc_bearer *b, + unsigned long evt), + + TP_ARGS(dev, b, evt), + + TP_STRUCT__entry( + __string(dev_name, dev->name) + __string(b_name, b->name) + __field(unsigned long, evt) + __field(u8, b_up) + __field(u8, carrier) + __field(u8, oper) + ), + + TP_fast_assign( + __assign_str(dev_name, dev->name); + __assign_str(b_name, b->name); + __entry->evt = evt; + __entry->b_up = test_bit(0, &b->up); + __entry->carrier = netif_carrier_ok(dev); + __entry->oper = netif_oper_up(dev); + ), + + TP_printk("%s on: <%s>/<%s> oper: %s carrier: %s bearer: %s\n", + dev_evt_sym(__entry->evt), __get_str(dev_name), + __get_str(b_name), (__entry->oper) ? "up" : "down", + (__entry->carrier) ? "ok" : "notok", + (__entry->b_up) ? "up" : "down") +); + +#endif /* _TIPC_TRACE_H */ + +/* This part must be outside protection */ +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . 
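Each DEFINE_*_EVENT() name declared above becomes a trace_<name>() call that takes the same arguments as its event class. As a minimal illustration (a sketch only, not part of the patch; the link and node pointers are assumed to come from the caller), firing two of these events looks like:

/* Illustrative call sites for the trace events declared above; assumes
 * valid 'l' (struct tipc_link) and 'n' (struct tipc_node) pointers.
 */
static void tipc_trace_example(struct tipc_link *l, struct tipc_node *n)
{
	/* dump the link together with all of its queues */
	trace_tipc_link_reset(l, TIPC_DUMP_ALL, "link reset: ");

	/* dump only the node summary (more == false) */
	trace_tipc_node_timeout(n, false, "node timeout: ");
}

The dqueues bitmask selects which queues tipc_link_dump() renders into the event's dynamically sized buffer, which is why the buffer length switches between LINK_LMIN and LINK_LMAX above.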
+#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include <trace/define_trace.h> diff --git a/net/tipc/udp_media.c b/net/tipc/udp_media.c new file mode 100644 index 000000000..0a85244fd --- /dev/null +++ b/net/tipc/udp_media.c @@ -0,0 +1,859 @@ +/* net/tipc/udp_media.c: IP bearer support for TIPC + * + * Copyright (c) 2015, Ericsson AB + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "addr.h" +#include "net.h" +#include "bearer.h" +#include "netlink.h" +#include "msg.h" +#include "udp_media.h" + +/* IANA assigned UDP port */ +#define UDP_PORT_DEFAULT 6118 + +#define UDP_MIN_HEADROOM 48 + +/** + * struct udp_media_addr - IP/UDP addressing information + * + * This is the bearer level originating address used in neighbor discovery + * messages, and all fields should be in network byte order + * + * @proto: Ethernet protocol in use + * @port: port being used + * @ipv4: IPv4 address of neighbor + * @ipv6: IPv6 address of neighbor + */ +struct udp_media_addr { + __be16 proto; + __be16 port; + union { + struct in_addr ipv4; + struct in6_addr ipv6; + }; +}; + +/* struct udp_replicast - container for UDP remote addresses */ +struct udp_replicast { + struct udp_media_addr addr; + struct dst_cache dst_cache; + struct rcu_head rcu; + struct list_head list; +}; + +/** + * struct udp_bearer - ip/udp bearer data structure + * @bearer: associated generic tipc bearer + * @ubsock: bearer associated socket + * @ifindex: local address scope + * @work: used to schedule deferred work on a bearer + * @rcast: associated udp_replicast container + */ +struct udp_bearer { + struct tipc_bearer __rcu *bearer; + struct socket *ubsock; + u32 ifindex; + struct work_struct work; + struct udp_replicast rcast; +}; + +static int tipc_udp_is_mcast_addr(struct udp_media_addr *addr) +{ + if (ntohs(addr->proto) == ETH_P_IP) + return ipv4_is_multicast(addr->ipv4.s_addr); +#if IS_ENABLED(CONFIG_IPV6) + else + return ipv6_addr_is_multicast(&addr->ipv6); +#endif + return 0; +} + +/* udp_media_addr_set - convert a ip/udp address to a TIPC media address */ +static void tipc_udp_media_addr_set(struct tipc_media_addr *addr, + struct udp_media_addr *ua) +{ + memset(addr, 0, sizeof(struct tipc_media_addr)); + addr->media_id = TIPC_MEDIA_TYPE_UDP; + memcpy(addr->value, ua, sizeof(struct udp_media_addr)); + + if (tipc_udp_is_mcast_addr(ua)) + addr->broadcast = TIPC_BROADCAST_SUPPORT; +} + +/* tipc_udp_addr2str - convert ip/udp address to string */ +static int tipc_udp_addr2str(struct tipc_media_addr *a, char *buf, int size) +{ + struct udp_media_addr *ua = (struct udp_media_addr *)&a->value; + + if (ntohs(ua->proto) == ETH_P_IP) + snprintf(buf, size, "%pI4:%u", &ua->ipv4, ntohs(ua->port)); + else if (ntohs(ua->proto) == ETH_P_IPV6) + snprintf(buf, size, "%pI6:%u", &ua->ipv6, ntohs(ua->port)); + else + pr_err("Invalid UDP media address\n"); + return 0; +} + +/* tipc_udp_msg2addr - extract an ip/udp address from a TIPC ndisc message */ +static int tipc_udp_msg2addr(struct tipc_bearer *b, struct tipc_media_addr *a, + char *msg) +{ + struct udp_media_addr *ua; + + ua = (struct udp_media_addr *) (msg + TIPC_MEDIA_ADDR_OFFSET); + if (msg[TIPC_MEDIA_TYPE_OFFSET] != TIPC_MEDIA_TYPE_UDP) + return -EINVAL; + tipc_udp_media_addr_set(a, ua); + return 0; +} + +/* tipc_udp_addr2msg - write an ip/udp address to a TIPC ndisc message */ +static int tipc_udp_addr2msg(char *msg, struct tipc_media_addr *a) +{ + memset(msg, 0, TIPC_MEDIA_INFO_SIZE); + msg[TIPC_MEDIA_TYPE_OFFSET] = TIPC_MEDIA_TYPE_UDP; + memcpy(msg + TIPC_MEDIA_ADDR_OFFSET, a->value, + sizeof(struct udp_media_addr)); + return 0; +} + +/* tipc_send_msg - enqueue a send request */ +static int tipc_udp_xmit(struct net *net, struct sk_buff *skb, + struct udp_bearer *ub, struct udp_media_addr *src, + struct 
udp_media_addr *dst, struct dst_cache *cache) +{ + struct dst_entry *ndst; + int ttl, err = 0; + + local_bh_disable(); + ndst = dst_cache_get(cache); + if (dst->proto == htons(ETH_P_IP)) { + struct rtable *rt = (struct rtable *)ndst; + + if (!rt) { + struct flowi4 fl = { + .daddr = dst->ipv4.s_addr, + .saddr = src->ipv4.s_addr, + .flowi4_mark = skb->mark, + .flowi4_proto = IPPROTO_UDP + }; + rt = ip_route_output_key(net, &fl); + if (IS_ERR(rt)) { + err = PTR_ERR(rt); + goto tx_error; + } + dst_cache_set_ip4(cache, &rt->dst, fl.saddr); + } + + ttl = ip4_dst_hoplimit(&rt->dst); + udp_tunnel_xmit_skb(rt, ub->ubsock->sk, skb, src->ipv4.s_addr, + dst->ipv4.s_addr, 0, ttl, 0, src->port, + dst->port, false, true); +#if IS_ENABLED(CONFIG_IPV6) + } else { + if (!ndst) { + struct flowi6 fl6 = { + .flowi6_oif = ub->ifindex, + .daddr = dst->ipv6, + .saddr = src->ipv6, + .flowi6_proto = IPPROTO_UDP + }; + ndst = ipv6_stub->ipv6_dst_lookup_flow(net, + ub->ubsock->sk, + &fl6, NULL); + if (IS_ERR(ndst)) { + err = PTR_ERR(ndst); + goto tx_error; + } + dst_cache_set_ip6(cache, ndst, &fl6.saddr); + } + ttl = ip6_dst_hoplimit(ndst); + err = udp_tunnel6_xmit_skb(ndst, ub->ubsock->sk, skb, NULL, + &src->ipv6, &dst->ipv6, 0, ttl, 0, + src->port, dst->port, false); +#endif + } + local_bh_enable(); + return err; + +tx_error: + local_bh_enable(); + kfree_skb(skb); + return err; +} + +static int tipc_udp_send_msg(struct net *net, struct sk_buff *skb, + struct tipc_bearer *b, + struct tipc_media_addr *addr) +{ + struct udp_media_addr *src = (struct udp_media_addr *)&b->addr.value; + struct udp_media_addr *dst = (struct udp_media_addr *)&addr->value; + struct udp_replicast *rcast; + struct udp_bearer *ub; + int err = 0; + + if (skb_headroom(skb) < UDP_MIN_HEADROOM) { + err = pskb_expand_head(skb, UDP_MIN_HEADROOM, 0, GFP_ATOMIC); + if (err) + goto out; + } + + skb_set_inner_protocol(skb, htons(ETH_P_TIPC)); + ub = rcu_dereference(b->media_ptr); + if (!ub) { + err = -ENODEV; + goto out; + } + + if (addr->broadcast != TIPC_REPLICAST_SUPPORT) + return tipc_udp_xmit(net, skb, ub, src, dst, + &ub->rcast.dst_cache); + + /* Replicast, send an skb to each configured IP address */ + list_for_each_entry_rcu(rcast, &ub->rcast.list, list) { + struct sk_buff *_skb; + + _skb = pskb_copy(skb, GFP_ATOMIC); + if (!_skb) { + err = -ENOMEM; + goto out; + } + + err = tipc_udp_xmit(net, _skb, ub, src, &rcast->addr, + &rcast->dst_cache); + if (err) + goto out; + } + err = 0; +out: + kfree_skb(skb); + return err; +} + +static bool tipc_udp_is_known_peer(struct tipc_bearer *b, + struct udp_media_addr *addr) +{ + struct udp_replicast *rcast, *tmp; + struct udp_bearer *ub; + + ub = rcu_dereference_rtnl(b->media_ptr); + if (!ub) { + pr_err_ratelimited("UDP bearer instance not found\n"); + return false; + } + + list_for_each_entry_safe(rcast, tmp, &ub->rcast.list, list) { + if (!memcmp(&rcast->addr, addr, sizeof(struct udp_media_addr))) + return true; + } + + return false; +} + +static int tipc_udp_rcast_add(struct tipc_bearer *b, + struct udp_media_addr *addr) +{ + struct udp_replicast *rcast; + struct udp_bearer *ub; + + ub = rcu_dereference_rtnl(b->media_ptr); + if (!ub) + return -ENODEV; + + rcast = kmalloc(sizeof(*rcast), GFP_ATOMIC); + if (!rcast) + return -ENOMEM; + + if (dst_cache_init(&rcast->dst_cache, GFP_ATOMIC)) { + kfree(rcast); + return -ENOMEM; + } + + memcpy(&rcast->addr, addr, sizeof(struct udp_media_addr)); + + if (ntohs(addr->proto) == ETH_P_IP) + pr_info("New replicast peer: %pI4\n", &rcast->addr.ipv4); +#if 
IS_ENABLED(CONFIG_IPV6) + else if (ntohs(addr->proto) == ETH_P_IPV6) + pr_info("New replicast peer: %pI6\n", &rcast->addr.ipv6); +#endif + b->bcast_addr.broadcast = TIPC_REPLICAST_SUPPORT; + list_add_rcu(&rcast->list, &ub->rcast.list); + return 0; +} + +static int tipc_udp_rcast_disc(struct tipc_bearer *b, struct sk_buff *skb) +{ + struct udp_media_addr src = {0}; + struct udp_media_addr *dst; + + dst = (struct udp_media_addr *)&b->bcast_addr.value; + if (tipc_udp_is_mcast_addr(dst)) + return 0; + + src.port = udp_hdr(skb)->source; + + if (ip_hdr(skb)->version == 4) { + struct iphdr *iphdr = ip_hdr(skb); + + src.proto = htons(ETH_P_IP); + src.ipv4.s_addr = iphdr->saddr; + if (ipv4_is_multicast(iphdr->daddr)) + return 0; +#if IS_ENABLED(CONFIG_IPV6) + } else if (ip_hdr(skb)->version == 6) { + struct ipv6hdr *iphdr = ipv6_hdr(skb); + + src.proto = htons(ETH_P_IPV6); + src.ipv6 = iphdr->saddr; + if (ipv6_addr_is_multicast(&iphdr->daddr)) + return 0; +#endif + } else { + return 0; + } + + if (likely(tipc_udp_is_known_peer(b, &src))) + return 0; + + return tipc_udp_rcast_add(b, &src); +} + +/* tipc_udp_recv - read data from bearer socket */ +static int tipc_udp_recv(struct sock *sk, struct sk_buff *skb) +{ + struct udp_bearer *ub; + struct tipc_bearer *b; + struct tipc_msg *hdr; + int err; + + ub = rcu_dereference_sk_user_data(sk); + if (!ub) { + pr_err_ratelimited("Failed to get UDP bearer reference"); + goto out; + } + skb_pull(skb, sizeof(struct udphdr)); + hdr = buf_msg(skb); + + b = rcu_dereference(ub->bearer); + if (!b) + goto out; + + if (b && test_bit(0, &b->up)) { + TIPC_SKB_CB(skb)->flags = 0; + tipc_rcv(sock_net(sk), skb, b); + return 0; + } + + if (unlikely(msg_user(hdr) == LINK_CONFIG)) { + err = tipc_udp_rcast_disc(b, skb); + if (err) + goto out; + } + +out: + kfree_skb(skb); + return 0; +} + +static int enable_mcast(struct udp_bearer *ub, struct udp_media_addr *remote) +{ + int err = 0; + struct ip_mreqn mreqn; + struct sock *sk = ub->ubsock->sk; + + if (ntohs(remote->proto) == ETH_P_IP) { + mreqn.imr_multiaddr = remote->ipv4; + mreqn.imr_ifindex = ub->ifindex; + err = ip_mc_join_group(sk, &mreqn); +#if IS_ENABLED(CONFIG_IPV6) + } else { + lock_sock(sk); + err = ipv6_stub->ipv6_sock_mc_join(sk, ub->ifindex, + &remote->ipv6); + release_sock(sk); +#endif + } + return err; +} + +static int __tipc_nl_add_udp_addr(struct sk_buff *skb, + struct udp_media_addr *addr, int nla_t) +{ + if (ntohs(addr->proto) == ETH_P_IP) { + struct sockaddr_in ip4; + + memset(&ip4, 0, sizeof(ip4)); + ip4.sin_family = AF_INET; + ip4.sin_port = addr->port; + ip4.sin_addr.s_addr = addr->ipv4.s_addr; + if (nla_put(skb, nla_t, sizeof(ip4), &ip4)) + return -EMSGSIZE; + +#if IS_ENABLED(CONFIG_IPV6) + } else if (ntohs(addr->proto) == ETH_P_IPV6) { + struct sockaddr_in6 ip6; + + memset(&ip6, 0, sizeof(ip6)); + ip6.sin6_family = AF_INET6; + ip6.sin6_port = addr->port; + memcpy(&ip6.sin6_addr, &addr->ipv6, sizeof(struct in6_addr)); + if (nla_put(skb, nla_t, sizeof(ip6), &ip6)) + return -EMSGSIZE; +#endif + } + + return 0; +} + +int tipc_udp_nl_dump_remoteip(struct sk_buff *skb, struct netlink_callback *cb) +{ + u32 bid = cb->args[0]; + u32 skip_cnt = cb->args[1]; + u32 portid = NETLINK_CB(cb->skb).portid; + struct udp_replicast *rcast, *tmp; + struct tipc_bearer *b; + struct udp_bearer *ub; + void *hdr; + int err; + int i; + + if (!bid && !skip_cnt) { + struct nlattr **attrs = genl_dumpit_info(cb)->attrs; + struct net *net = sock_net(skb->sk); + struct nlattr *battrs[TIPC_NLA_BEARER_MAX + 1]; + char *bname; + + if 
(!attrs[TIPC_NLA_BEARER]) + return -EINVAL; + + err = nla_parse_nested_deprecated(battrs, TIPC_NLA_BEARER_MAX, + attrs[TIPC_NLA_BEARER], + tipc_nl_bearer_policy, NULL); + if (err) + return err; + + if (!battrs[TIPC_NLA_BEARER_NAME]) + return -EINVAL; + + bname = nla_data(battrs[TIPC_NLA_BEARER_NAME]); + + rtnl_lock(); + b = tipc_bearer_find(net, bname); + if (!b) { + rtnl_unlock(); + return -EINVAL; + } + bid = b->identity; + } else { + struct net *net = sock_net(skb->sk); + struct tipc_net *tn = net_generic(net, tipc_net_id); + + rtnl_lock(); + b = rtnl_dereference(tn->bearer_list[bid]); + if (!b) { + rtnl_unlock(); + return -EINVAL; + } + } + + ub = rtnl_dereference(b->media_ptr); + if (!ub) { + rtnl_unlock(); + return -EINVAL; + } + + i = 0; + list_for_each_entry_safe(rcast, tmp, &ub->rcast.list, list) { + if (i < skip_cnt) + goto count; + + hdr = genlmsg_put(skb, portid, cb->nlh->nlmsg_seq, + &tipc_genl_family, NLM_F_MULTI, + TIPC_NL_BEARER_GET); + if (!hdr) + goto done; + + err = __tipc_nl_add_udp_addr(skb, &rcast->addr, + TIPC_NLA_UDP_REMOTE); + if (err) { + genlmsg_cancel(skb, hdr); + goto done; + } + genlmsg_end(skb, hdr); +count: + i++; + } +done: + rtnl_unlock(); + cb->args[0] = bid; + cb->args[1] = i; + + return skb->len; +} + +int tipc_udp_nl_add_bearer_data(struct tipc_nl_msg *msg, struct tipc_bearer *b) +{ + struct udp_media_addr *src = (struct udp_media_addr *)&b->addr.value; + struct udp_media_addr *dst; + struct udp_bearer *ub; + struct nlattr *nest; + + ub = rtnl_dereference(b->media_ptr); + if (!ub) + return -ENODEV; + + nest = nla_nest_start_noflag(msg->skb, TIPC_NLA_BEARER_UDP_OPTS); + if (!nest) + goto msg_full; + + if (__tipc_nl_add_udp_addr(msg->skb, src, TIPC_NLA_UDP_LOCAL)) + goto msg_full; + + dst = (struct udp_media_addr *)&b->bcast_addr.value; + if (__tipc_nl_add_udp_addr(msg->skb, dst, TIPC_NLA_UDP_REMOTE)) + goto msg_full; + + if (!list_empty(&ub->rcast.list)) { + if (nla_put_flag(msg->skb, TIPC_NLA_UDP_MULTI_REMOTEIP)) + goto msg_full; + } + + nla_nest_end(msg->skb, nest); + return 0; +msg_full: + nla_nest_cancel(msg->skb, nest); + return -EMSGSIZE; +} + +/** + * tipc_parse_udp_addr - build udp media address from netlink data + * @nla: netlink attribute containing sockaddr storage aligned address + * @addr: tipc media address to fill with address, port and protocol type + * @scope_id: IPv6 scope id pointer, not NULL indicates it's required + */ + +static int tipc_parse_udp_addr(struct nlattr *nla, struct udp_media_addr *addr, + u32 *scope_id) +{ + struct sockaddr_storage sa; + + nla_memcpy(&sa, nla, sizeof(sa)); + if (sa.ss_family == AF_INET) { + struct sockaddr_in *ip4 = (struct sockaddr_in *)&sa; + + addr->proto = htons(ETH_P_IP); + addr->port = ip4->sin_port; + addr->ipv4.s_addr = ip4->sin_addr.s_addr; + return 0; + +#if IS_ENABLED(CONFIG_IPV6) + } else if (sa.ss_family == AF_INET6) { + struct sockaddr_in6 *ip6 = (struct sockaddr_in6 *)&sa; + + addr->proto = htons(ETH_P_IPV6); + addr->port = ip6->sin6_port; + memcpy(&addr->ipv6, &ip6->sin6_addr, sizeof(struct in6_addr)); + + /* Scope ID is only interesting for local addresses */ + if (scope_id) { + int atype; + + atype = ipv6_addr_type(&ip6->sin6_addr); + if (__ipv6_addr_needs_scope_id(atype) && + !ip6->sin6_scope_id) { + return -EINVAL; + } + + *scope_id = ip6->sin6_scope_id ? 
: 0; + } + + return 0; +#endif + } + return -EADDRNOTAVAIL; +} + +int tipc_udp_nl_bearer_add(struct tipc_bearer *b, struct nlattr *attr) +{ + int err; + struct udp_media_addr addr = {0}; + struct nlattr *opts[TIPC_NLA_UDP_MAX + 1]; + struct udp_media_addr *dst; + + if (nla_parse_nested_deprecated(opts, TIPC_NLA_UDP_MAX, attr, tipc_nl_udp_policy, NULL)) + return -EINVAL; + + if (!opts[TIPC_NLA_UDP_REMOTE]) + return -EINVAL; + + err = tipc_parse_udp_addr(opts[TIPC_NLA_UDP_REMOTE], &addr, NULL); + if (err) + return err; + + dst = (struct udp_media_addr *)&b->bcast_addr.value; + if (tipc_udp_is_mcast_addr(dst)) { + pr_err("Can't add remote ip to TIPC UDP multicast bearer\n"); + return -EINVAL; + } + + if (tipc_udp_is_known_peer(b, &addr)) + return 0; + + return tipc_udp_rcast_add(b, &addr); +} + +/** + * tipc_udp_enable - callback to create a new udp bearer instance + * @net: network namespace + * @b: pointer to generic tipc_bearer + * @attrs: netlink bearer configuration + * + * validate the bearer parameters and initialize the udp bearer + * rtnl_lock should be held + */ +static int tipc_udp_enable(struct net *net, struct tipc_bearer *b, + struct nlattr *attrs[]) +{ + int err = -EINVAL; + struct udp_bearer *ub; + struct udp_media_addr remote = {0}; + struct udp_media_addr local = {0}; + struct udp_port_cfg udp_conf = {0}; + struct udp_tunnel_sock_cfg tuncfg = {NULL}; + struct nlattr *opts[TIPC_NLA_UDP_MAX + 1]; + u8 node_id[NODE_ID_LEN] = {0,}; + struct net_device *dev; + int rmcast = 0; + + ub = kzalloc(sizeof(*ub), GFP_ATOMIC); + if (!ub) + return -ENOMEM; + + INIT_LIST_HEAD(&ub->rcast.list); + + if (!attrs[TIPC_NLA_BEARER_UDP_OPTS]) + goto err; + + if (nla_parse_nested_deprecated(opts, TIPC_NLA_UDP_MAX, attrs[TIPC_NLA_BEARER_UDP_OPTS], tipc_nl_udp_policy, NULL)) + goto err; + + if (!opts[TIPC_NLA_UDP_LOCAL] || !opts[TIPC_NLA_UDP_REMOTE]) { + pr_err("Invalid UDP bearer configuration"); + err = -EINVAL; + goto err; + } + + err = tipc_parse_udp_addr(opts[TIPC_NLA_UDP_LOCAL], &local, + &ub->ifindex); + if (err) + goto err; + + err = tipc_parse_udp_addr(opts[TIPC_NLA_UDP_REMOTE], &remote, NULL); + if (err) + goto err; + + if (remote.proto != local.proto) { + err = -EINVAL; + goto err; + } + + /* Checking remote ip address */ + rmcast = tipc_udp_is_mcast_addr(&remote); + + /* Autoconfigure own node identity if needed */ + if (!tipc_own_id(net)) { + memcpy(node_id, local.ipv6.in6_u.u6_addr8, 16); + tipc_net_init(net, node_id, 0); + } + if (!tipc_own_id(net)) { + pr_warn("Failed to set node id, please configure manually\n"); + err = -EINVAL; + goto err; + } + + b->bcast_addr.media_id = TIPC_MEDIA_TYPE_UDP; + b->bcast_addr.broadcast = TIPC_BROADCAST_SUPPORT; + rcu_assign_pointer(b->media_ptr, ub); + rcu_assign_pointer(ub->bearer, b); + tipc_udp_media_addr_set(&b->addr, &local); + if (local.proto == htons(ETH_P_IP)) { + dev = __ip_dev_find(net, local.ipv4.s_addr, false); + if (!dev) { + err = -ENODEV; + goto err; + } + udp_conf.family = AF_INET; + + /* Switch to use ANY to receive packets from group */ + if (rmcast) + udp_conf.local_ip.s_addr = htonl(INADDR_ANY); + else + udp_conf.local_ip.s_addr = local.ipv4.s_addr; + udp_conf.use_udp_checksums = false; + ub->ifindex = dev->ifindex; + b->encap_hlen = sizeof(struct iphdr) + sizeof(struct udphdr); + if (tipc_mtu_bad(dev, b->encap_hlen)) { + err = -EINVAL; + goto err; + } + b->mtu = b->media->mtu; +#if IS_ENABLED(CONFIG_IPV6) + } else if (local.proto == htons(ETH_P_IPV6)) { + dev = ub->ifindex ? 
__dev_get_by_index(net, ub->ifindex) : NULL; + dev = ipv6_dev_find(net, &local.ipv6, dev); + if (!dev) { + err = -ENODEV; + goto err; + } + udp_conf.family = AF_INET6; + udp_conf.use_udp6_tx_checksums = true; + udp_conf.use_udp6_rx_checksums = true; + if (rmcast) + udp_conf.local_ip6 = in6addr_any; + else + udp_conf.local_ip6 = local.ipv6; + ub->ifindex = dev->ifindex; + b->encap_hlen = sizeof(struct ipv6hdr) + sizeof(struct udphdr); + b->mtu = 1280; +#endif + } else { + err = -EAFNOSUPPORT; + goto err; + } + udp_conf.local_udp_port = local.port; + err = udp_sock_create(net, &udp_conf, &ub->ubsock); + if (err) + goto err; + tuncfg.sk_user_data = ub; + tuncfg.encap_type = 1; + tuncfg.encap_rcv = tipc_udp_recv; + tuncfg.encap_destroy = NULL; + setup_udp_tunnel_sock(net, ub->ubsock, &tuncfg); + + err = dst_cache_init(&ub->rcast.dst_cache, GFP_ATOMIC); + if (err) + goto free; + + /* + * The bcast media address port is used for all peers and the ip + * is used if it's a multicast address. + */ + memcpy(&b->bcast_addr.value, &remote, sizeof(remote)); + if (rmcast) + err = enable_mcast(ub, &remote); + else + err = tipc_udp_rcast_add(b, &remote); + if (err) + goto free; + + return 0; + +free: + dst_cache_destroy(&ub->rcast.dst_cache); + udp_tunnel_sock_release(ub->ubsock); +err: + kfree(ub); + return err; +} + +/* cleanup_bearer - break the socket/bearer association */ +static void cleanup_bearer(struct work_struct *work) +{ + struct udp_bearer *ub = container_of(work, struct udp_bearer, work); + struct udp_replicast *rcast, *tmp; + + list_for_each_entry_safe(rcast, tmp, &ub->rcast.list, list) { + dst_cache_destroy(&rcast->dst_cache); + list_del_rcu(&rcast->list); + kfree_rcu(rcast, rcu); + } + + atomic_dec(&tipc_net(sock_net(ub->ubsock->sk))->wq_count); + dst_cache_destroy(&ub->rcast.dst_cache); + udp_tunnel_sock_release(ub->ubsock); + synchronize_net(); + kfree(ub); +} + +/* tipc_udp_disable - detach bearer from socket */ +static void tipc_udp_disable(struct tipc_bearer *b) +{ + struct udp_bearer *ub; + + ub = rtnl_dereference(b->media_ptr); + if (!ub) { + pr_err("UDP bearer instance not found\n"); + return; + } + sock_set_flag(ub->ubsock->sk, SOCK_DEAD); + RCU_INIT_POINTER(ub->bearer, NULL); + + /* sock_release need to be done outside of rtnl lock */ + atomic_inc(&tipc_net(sock_net(ub->ubsock->sk))->wq_count); + INIT_WORK(&ub->work, cleanup_bearer); + schedule_work(&ub->work); +} + +struct tipc_media udp_media_info = { + .send_msg = tipc_udp_send_msg, + .enable_media = tipc_udp_enable, + .disable_media = tipc_udp_disable, + .addr2str = tipc_udp_addr2str, + .addr2msg = tipc_udp_addr2msg, + .msg2addr = tipc_udp_msg2addr, + .priority = TIPC_DEF_LINK_PRI, + .tolerance = TIPC_DEF_LINK_TOL, + .min_win = TIPC_DEF_LINK_WIN, + .max_win = TIPC_DEF_LINK_WIN, + .mtu = TIPC_DEF_LINK_UDP_MTU, + .type_id = TIPC_MEDIA_TYPE_UDP, + .hwaddr_len = 0, + .name = "udp" +}; diff --git a/net/tipc/udp_media.h b/net/tipc/udp_media.h new file mode 100644 index 000000000..e7455cc73 --- /dev/null +++ b/net/tipc/udp_media.h @@ -0,0 +1,60 @@ +/* + * net/tipc/udp_media.h: Include file for UDP bearer media + * + * Copyright (c) 1996-2006, 2013-2016, Ericsson AB + * Copyright (c) 2005, 2010-2011, Wind River Systems + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the names of the copyright holders nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * Alternatively, this software may be distributed under the terms of the + * GNU General Public License ("GPL") version 2 as published by the Free + * Software Foundation. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +#ifdef CONFIG_TIPC_MEDIA_UDP +#ifndef _TIPC_UDP_MEDIA_H +#define _TIPC_UDP_MEDIA_H + +#include +#include + +int tipc_udp_nl_bearer_add(struct tipc_bearer *b, struct nlattr *attr); +int tipc_udp_nl_add_bearer_data(struct tipc_nl_msg *msg, struct tipc_bearer *b); +int tipc_udp_nl_dump_remoteip(struct sk_buff *skb, struct netlink_callback *cb); + +/* check if configured MTU is too low for tipc headers */ +static inline bool tipc_udp_mtu_bad(u32 mtu) +{ + if (mtu >= (TIPC_MIN_BEARER_MTU + sizeof(struct iphdr) + + sizeof(struct udphdr))) + return false; + + pr_warn("MTU too low for tipc bearer\n"); + return true; +} + +#endif +#endif diff --git a/net/tls/Kconfig b/net/tls/Kconfig new file mode 100644 index 000000000..0cdc1f7b6 --- /dev/null +++ b/net/tls/Kconfig @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# TLS configuration +# +config TLS + tristate "Transport Layer Security support" + depends on INET + select CRYPTO + select CRYPTO_AES + select CRYPTO_GCM + select STREAM_PARSER + select NET_SOCK_MSG + default n + help + Enable kernel support for TLS protocol. This allows symmetric + encryption handling of the TLS protocol to be done in-kernel. + + If unsure, say N. + +config TLS_DEVICE + bool "Transport Layer Security HW offload" + depends on TLS + select SOCK_VALIDATE_XMIT + select SOCK_RX_QUEUE_MAPPING + default n + help + Enable kernel support for HW offload of the TLS protocol. + + If unsure, say N. + +config TLS_TOE + bool "Transport Layer Security TCP stack bypass" + depends on TLS + default n + help + Enable kernel support for legacy HW offload of the TLS protocol, + which is incompatible with the Linux networking stack semantics. + + If unsure, say N. diff --git a/net/tls/Makefile b/net/tls/Makefile new file mode 100644 index 000000000..e41c80048 --- /dev/null +++ b/net/tls/Makefile @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Makefile for the TLS subsystem. 
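For context on what the CONFIG_TLS option above enables: user space attaches the kernel TLS ULP to a connected TCP socket and then installs the negotiated keys with setsockopt(). The sketch below is illustrative only, using the standard uapi constants from linux/tls.h; the key material and the completed handshake are assumed to come from a user-space TLS library, and error handling is trimmed.

/* User-space sketch: enable kernel TLS transmit crypto on a connected TCP
 * socket 'fd'. The key/iv/salt/rec_seq buffers are assumed to hold values
 * negotiated by a user-space TLS 1.2 handshake using AES-128-GCM.
 */
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#ifndef TCP_ULP
#define TCP_ULP 31	/* attach an upper layer protocol to a TCP socket */
#endif
#ifndef SOL_TLS
#define SOL_TLS 282	/* socket level for the TLS_TX/TLS_RX options */
#endif

static int enable_ktls_tx(int fd, const unsigned char *key,
			  const unsigned char *iv, const unsigned char *salt,
			  const unsigned char *rec_seq)
{
	struct tls12_crypto_info_aes_gcm_128 ci;

	memset(&ci, 0, sizeof(ci));
	ci.info.version = TLS_1_2_VERSION;
	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

	if (setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;
	return setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
}

The receive direction is configured the same way with TLS_RX; CONFIG_TLS_DEVICE changes where the record crypto runs (in the NIC rather than in software), not this socket interface.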
+# + +CFLAGS_trace.o := -I$(src) + +obj-$(CONFIG_TLS) += tls.o + +tls-y := tls_main.o tls_sw.o tls_proc.o trace.o tls_strp.o + +tls-$(CONFIG_TLS_TOE) += tls_toe.o +tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o diff --git a/net/tls/tls.h b/net/tls/tls.h new file mode 100644 index 000000000..0672acab2 --- /dev/null +++ b/net/tls/tls.h @@ -0,0 +1,328 @@ +/* + * Copyright (c) 2016 Tom Herbert + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#ifndef _TLS_INT_H +#define _TLS_INT_H + +#include +#include +#include +#include + +#define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER, \ + TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT)) + +#define __TLS_INC_STATS(net, field) \ + __SNMP_INC_STATS((net)->mib.tls_statistics, field) +#define TLS_INC_STATS(net, field) \ + SNMP_INC_STATS((net)->mib.tls_statistics, field) +#define TLS_DEC_STATS(net, field) \ + SNMP_DEC_STATS((net)->mib.tls_statistics, field) + +/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages + * allocated or mapped for each TLS record. After encryption, the records are + * stored in a linked list. 
+ */ +struct tls_rec { + struct list_head list; + int tx_ready; + int tx_flags; + + struct sk_msg msg_plaintext; + struct sk_msg msg_encrypted; + + /* AAD | msg_plaintext.sg.data | sg_tag */ + struct scatterlist sg_aead_in[2]; + /* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */ + struct scatterlist sg_aead_out[2]; + + char content_type; + struct scatterlist sg_content_type; + + struct sock *sk; + + char aad_space[TLS_AAD_SPACE_SIZE]; + u8 iv_data[MAX_IV_SIZE]; + struct aead_request aead_req; + u8 aead_req_ctx[]; +}; + +int __net_init tls_proc_init(struct net *net); +void __net_exit tls_proc_fini(struct net *net); + +struct tls_context *tls_ctx_create(struct sock *sk); +void tls_ctx_free(struct sock *sk, struct tls_context *ctx); +void update_sk_prot(struct sock *sk, struct tls_context *ctx); + +int wait_on_pending_writer(struct sock *sk, long *timeo); +int tls_sk_query(struct sock *sk, int optname, char __user *optval, + int __user *optlen); +int tls_sk_attach(struct sock *sk, int optname, char __user *optval, + unsigned int optlen); +void tls_err_abort(struct sock *sk, int err); + +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); +void tls_update_rx_zc_capable(struct tls_context *tls_ctx); +void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx); +void tls_sw_strparser_done(struct tls_context *tls_ctx); +int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +int tls_sw_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +void tls_sw_cancel_work_tx(struct tls_context *tls_ctx); +void tls_sw_release_resources_tx(struct sock *sk); +void tls_sw_free_ctx_tx(struct tls_context *tls_ctx); +void tls_sw_free_resources_rx(struct sock *sk); +void tls_sw_release_resources_rx(struct sock *sk); +void tls_sw_free_ctx_rx(struct tls_context *tls_ctx); +int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len); +bool tls_sw_sock_is_readable(struct sock *sk); +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags); + +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags); +int tls_tx_records(struct sock *sk, int flags); + +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx); +void tls_device_write_space(struct sock *sk, struct tls_context *ctx); + +int tls_process_cmsg(struct sock *sk, struct msghdr *msg, + unsigned char *record_type); +int decrypt_skb(struct sock *sk, struct scatterlist *sgout); + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context_tx *offload_ctx, + struct tls_crypto_info *crypto_info); + +int tls_strp_dev_init(void); +void tls_strp_dev_exit(void); + +void tls_strp_done(struct tls_strparser *strp); +void tls_strp_stop(struct tls_strparser *strp); +int tls_strp_init(struct tls_strparser *strp, struct sock *sk); +void tls_strp_data_ready(struct tls_strparser *strp); + +void tls_strp_check_rcv(struct tls_strparser *strp); +void tls_strp_msg_done(struct tls_strparser *strp); + +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb); +void tls_rx_msg_ready(struct tls_strparser *strp); + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh); +int tls_strp_msg_cow(struct tls_sw_context_rx 
*ctx); +struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx); +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst); + +static inline struct tls_msg *tls_msg(struct sk_buff *skb) +{ + struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb; + + return &scb->tls; +} + +static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx) +{ + DEBUG_NET_WARN_ON_ONCE(!ctx->strp.msg_ready || !ctx->strp.anchor->len); + return ctx->strp.anchor; +} + +static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx) +{ + return ctx->strp.msg_ready; +} + +static inline bool tls_strp_msg_mixed_decrypted(struct tls_sw_context_rx *ctx) +{ + return ctx->strp.mixed_decrypted; +} + +#ifdef CONFIG_TLS_DEVICE +int tls_device_init(void); +void tls_device_cleanup(void); +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx); +void tls_device_free_resources_tx(struct sock *sk); +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); +void tls_device_offload_cleanup_rx(struct sock *sk); +void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); +int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx); +#else +static inline int tls_device_init(void) { return 0; } +static inline void tls_device_cleanup(void) {} + +static inline int +tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +{ + return -EOPNOTSUPP; +} + +static inline void tls_device_free_resources_tx(struct sock *sk) {} + +static inline int +tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +{ + return -EOPNOTSUPP; +} + +static inline void tls_device_offload_cleanup_rx(struct sock *sk) {} +static inline void +tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {} + +static inline int +tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) +{ + return 0; +} +#endif + +int tls_push_sg(struct sock *sk, struct tls_context *ctx, + struct scatterlist *sg, u16 first_offset, + int flags); +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags); +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx); + +static inline bool tls_is_partially_sent_record(struct tls_context *ctx) +{ + return !!ctx->partially_sent_record; +} + +static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx) +{ + return tls_ctx->pending_open_record_frags; +} + +static inline bool tls_bigint_increment(unsigned char *seq, int len) +{ + int i; + + for (i = len - 1; i >= 0; i--) { + ++seq[i]; + if (seq[i] != 0) + break; + } + + return (i == -1); +} + +static inline void tls_bigint_subtract(unsigned char *seq, int n) +{ + u64 rcd_sn; + __be64 *p; + + BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8); + + p = (__be64 *)seq; + rcd_sn = be64_to_cpu(*p); + *p = cpu_to_be64(rcd_sn - n); +} + +static inline void +tls_advance_record_sn(struct sock *sk, struct tls_prot_info *prot, + struct cipher_context *ctx) +{ + if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size)) + tls_err_abort(sk, -EBADMSG); + + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) + tls_bigint_increment(ctx->iv + prot->salt_size, + prot->iv_size); +} + +static inline void +tls_xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq) +{ + int i; + + if (prot->version == TLS_1_3_VERSION || + prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { + for (i = 0; i < 8; i++) + iv[i + 4] ^= seq[i]; + } +} + +static inline void +tls_fill_prepend(struct tls_context *ctx, 
char *buf, size_t plaintext_len, + unsigned char record_type) +{ + struct tls_prot_info *prot = &ctx->prot_info; + size_t pkt_len, iv_size = prot->iv_size; + + pkt_len = plaintext_len + prot->tag_size; + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) { + pkt_len += iv_size; + + memcpy(buf + TLS_NONCE_OFFSET, + ctx->tx.iv + prot->salt_size, iv_size); + } + + /* we cover nonce explicit here as well, so buf should be of + * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE + */ + buf[0] = prot->version == TLS_1_3_VERSION ? + TLS_RECORD_TYPE_DATA : record_type; + /* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */ + buf[1] = TLS_1_2_VERSION_MINOR; + buf[2] = TLS_1_2_VERSION_MAJOR; + /* we can use IV for nonce explicit according to spec */ + buf[3] = pkt_len >> 8; + buf[4] = pkt_len & 0xFF; +} + +static inline +void tls_make_aad(char *buf, size_t size, char *record_sequence, + unsigned char record_type, struct tls_prot_info *prot) +{ + if (prot->version != TLS_1_3_VERSION) { + memcpy(buf, record_sequence, prot->rec_seq_size); + buf += 8; + } else { + size += prot->tag_size; + } + + buf[0] = prot->version == TLS_1_3_VERSION ? + TLS_RECORD_TYPE_DATA : record_type; + buf[1] = TLS_1_2_VERSION_MAJOR; + buf[2] = TLS_1_2_VERSION_MINOR; + buf[3] = size >> 8; + buf[4] = size & 0xFF; +} + +#endif diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c new file mode 100644 index 000000000..184982788 --- /dev/null +++ b/net/tls/tls_device.c @@ -0,0 +1,1487 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tls.h" +#include "trace.h" + +/* device_offload_lock is used to synchronize tls_dev_add + * against NETDEV_DOWN notifications. 
+ */ +static DECLARE_RWSEM(device_offload_lock); + +static struct workqueue_struct *destruct_wq __read_mostly; + +static LIST_HEAD(tls_device_list); +static LIST_HEAD(tls_device_down_list); +static DEFINE_SPINLOCK(tls_device_lock); + +static struct page *dummy_page; + +static void tls_device_free_ctx(struct tls_context *ctx) +{ + if (ctx->tx_conf == TLS_HW) { + kfree(tls_offload_ctx_tx(ctx)); + kfree(ctx->tx.rec_seq); + kfree(ctx->tx.iv); + } + + if (ctx->rx_conf == TLS_HW) + kfree(tls_offload_ctx_rx(ctx)); + + tls_ctx_free(NULL, ctx); +} + +static void tls_device_tx_del_task(struct work_struct *work) +{ + struct tls_offload_context_tx *offload_ctx = + container_of(work, struct tls_offload_context_tx, destruct_work); + struct tls_context *ctx = offload_ctx->ctx; + struct net_device *netdev; + + /* Safe, because this is the destroy flow, refcount is 0, so + * tls_device_down can't store this field in parallel. + */ + netdev = rcu_dereference_protected(ctx->netdev, + !refcount_read(&ctx->refcount)); + + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX); + dev_put(netdev); + ctx->netdev = NULL; + tls_device_free_ctx(ctx); +} + +static void tls_device_queue_ctx_destruction(struct tls_context *ctx) +{ + struct net_device *netdev; + unsigned long flags; + bool async_cleanup; + + spin_lock_irqsave(&tls_device_lock, flags); + if (unlikely(!refcount_dec_and_test(&ctx->refcount))) { + spin_unlock_irqrestore(&tls_device_lock, flags); + return; + } + + list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */ + + /* Safe, because this is the destroy flow, refcount is 0, so + * tls_device_down can't store this field in parallel. + */ + netdev = rcu_dereference_protected(ctx->netdev, + !refcount_read(&ctx->refcount)); + + async_cleanup = netdev && ctx->tx_conf == TLS_HW; + if (async_cleanup) { + struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx); + + /* queue_work inside the spinlock + * to make sure tls_device_down waits for that work. 
+ */ + queue_work(destruct_wq, &offload_ctx->destruct_work); + } + spin_unlock_irqrestore(&tls_device_lock, flags); + + if (!async_cleanup) + tls_device_free_ctx(ctx); +} + +/* We assume that the socket is already connected */ +static struct net_device *get_netdev_for_sock(struct sock *sk) +{ + struct dst_entry *dst = sk_dst_get(sk); + struct net_device *netdev = NULL; + + if (likely(dst)) { + netdev = netdev_sk_get_lowest_dev(dst->dev, sk); + dev_hold(netdev); + } + + dst_release(dst); + + return netdev; +} + +static void destroy_record(struct tls_record_info *record) +{ + int i; + + for (i = 0; i < record->num_frags; i++) + __skb_frag_unref(&record->frags[i], false); + kfree(record); +} + +static void delete_all_records(struct tls_offload_context_tx *offload_ctx) +{ + struct tls_record_info *info, *temp; + + list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) { + list_del(&info->list); + destroy_record(info); + } + + offload_ctx->retransmit_hint = NULL; +} + +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_record_info *info, *temp; + struct tls_offload_context_tx *ctx; + u64 deleted_records = 0; + unsigned long flags; + + if (!tls_ctx) + return; + + ctx = tls_offload_ctx_tx(tls_ctx); + + spin_lock_irqsave(&ctx->lock, flags); + info = ctx->retransmit_hint; + if (info && !before(acked_seq, info->end_seq)) + ctx->retransmit_hint = NULL; + + list_for_each_entry_safe(info, temp, &ctx->records_list, list) { + if (before(acked_seq, info->end_seq)) + break; + list_del(&info->list); + + destroy_record(info); + deleted_records++; + } + + ctx->unacked_record_sn += deleted_records; + spin_unlock_irqrestore(&ctx->lock, flags); +} + +/* At this point, there should be no references on this + * socket and no in-flight SKBs associated with this + * socket, so it is safe to free all the resources. 
+ */ +void tls_device_sk_destruct(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); + + tls_ctx->sk_destruct(sk); + + if (tls_ctx->tx_conf == TLS_HW) { + if (ctx->open_record) + destroy_record(ctx->open_record); + delete_all_records(ctx); + crypto_free_aead(ctx->aead_send); + clean_acked_data_disable(inet_csk(sk)); + } + + tls_device_queue_ctx_destruction(tls_ctx); +} +EXPORT_SYMBOL_GPL(tls_device_sk_destruct); + +void tls_device_free_resources_tx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + + tls_free_partial_record(sk, tls_ctx); +} + +void tls_offload_tx_resync_request(struct sock *sk, u32 got_seq, u32 exp_seq) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + + trace_tls_device_tx_resync_req(sk, got_seq, exp_seq); + WARN_ON(test_and_set_bit(TLS_TX_SYNC_SCHED, &tls_ctx->flags)); +} +EXPORT_SYMBOL_GPL(tls_offload_tx_resync_request); + +static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx, + u32 seq) +{ + struct net_device *netdev; + struct sk_buff *skb; + int err = 0; + u8 *rcd_sn; + + skb = tcp_write_queue_tail(sk); + if (skb) + TCP_SKB_CB(skb)->eor = 1; + + rcd_sn = tls_ctx->tx.rec_seq; + + trace_tls_device_tx_resync_send(sk, seq, rcd_sn); + down_read(&device_offload_lock); + netdev = rcu_dereference_protected(tls_ctx->netdev, + lockdep_is_held(&device_offload_lock)); + if (netdev) + err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, + rcd_sn, + TLS_OFFLOAD_CTX_DIR_TX); + up_read(&device_offload_lock); + if (err) + return; + + clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags); +} + +static void tls_append_frag(struct tls_record_info *record, + struct page_frag *pfrag, + int size) +{ + skb_frag_t *frag; + + frag = &record->frags[record->num_frags - 1]; + if (skb_frag_page(frag) == pfrag->page && + skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) { + skb_frag_size_add(frag, size); + } else { + ++frag; + __skb_frag_set_page(frag, pfrag->page); + skb_frag_off_set(frag, pfrag->offset); + skb_frag_size_set(frag, size); + ++record->num_frags; + get_page(pfrag->page); + } + + pfrag->offset += size; + record->len += size; +} + +static int tls_push_record(struct sock *sk, + struct tls_context *ctx, + struct tls_offload_context_tx *offload_ctx, + struct tls_record_info *record, + int flags) +{ + struct tls_prot_info *prot = &ctx->prot_info; + struct tcp_sock *tp = tcp_sk(sk); + skb_frag_t *frag; + int i; + + record->end_seq = tp->write_seq + record->len; + list_add_tail_rcu(&record->list, &offload_ctx->records_list); + offload_ctx->open_record = NULL; + + if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags)) + tls_device_resync_tx(sk, ctx, tp->write_seq); + + tls_advance_record_sn(sk, prot, &ctx->tx); + + for (i = 0; i < record->num_frags; i++) { + frag = &record->frags[i]; + sg_unmark_end(&offload_ctx->sg_tx_data[i]); + sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag), + skb_frag_size(frag), skb_frag_off(frag)); + sk_mem_charge(sk, skb_frag_size(frag)); + get_page(skb_frag_page(frag)); + } + sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]); + + /* all ready, send */ + return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags); +} + +static void tls_device_record_close(struct sock *sk, + struct tls_context *ctx, + struct tls_record_info *record, + struct page_frag *pfrag, + unsigned char record_type) +{ + struct tls_prot_info *prot = &ctx->prot_info; + struct page_frag dummy_tag_frag; + + /* append tag + * device will 
fill in the tag, we just need to append a placeholder + * use socket memory to improve coalescing (re-using a single buffer + * increases frag count) + * if we can't allocate memory now use the dummy page + */ + if (unlikely(pfrag->size - pfrag->offset < prot->tag_size) && + !skb_page_frag_refill(prot->tag_size, pfrag, sk->sk_allocation)) { + dummy_tag_frag.page = dummy_page; + dummy_tag_frag.offset = 0; + pfrag = &dummy_tag_frag; + } + tls_append_frag(record, pfrag, prot->tag_size); + + /* fill prepend */ + tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]), + record->len - prot->overhead_size, + record_type); +} + +static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + struct tls_record_info *record; + skb_frag_t *frag; + + record = kmalloc(sizeof(*record), GFP_KERNEL); + if (!record) + return -ENOMEM; + + frag = &record->frags[0]; + __skb_frag_set_page(frag, pfrag->page); + skb_frag_off_set(frag, pfrag->offset); + skb_frag_size_set(frag, prepend_size); + + get_page(pfrag->page); + pfrag->offset += prepend_size; + + record->num_frags = 1; + record->len = prepend_size; + offload_ctx->open_record = record; + return 0; +} + +static int tls_do_allocation(struct sock *sk, + struct tls_offload_context_tx *offload_ctx, + struct page_frag *pfrag, + size_t prepend_size) +{ + int ret; + + if (!offload_ctx->open_record) { + if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, + sk->sk_allocation))) { + READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk); + sk_stream_moderate_sndbuf(sk); + return -ENOMEM; + } + + ret = tls_create_new_record(offload_ctx, pfrag, prepend_size); + if (ret) + return ret; + + if (pfrag->size > pfrag->offset) + return 0; + } + + if (!sk_page_frag_refill(sk, pfrag)) + return -ENOMEM; + + return 0; +} + +static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i) +{ + size_t pre_copy, nocache; + + pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1); + if (pre_copy) { + pre_copy = min(pre_copy, bytes); + if (copy_from_iter(addr, pre_copy, i) != pre_copy) + return -EFAULT; + bytes -= pre_copy; + addr += pre_copy; + } + + nocache = round_down(bytes, SMP_CACHE_BYTES); + if (copy_from_iter_nocache(addr, nocache, i) != nocache) + return -EFAULT; + bytes -= nocache; + addr += nocache; + + if (bytes && copy_from_iter(addr, bytes, i) != bytes) + return -EFAULT; + + return 0; +} + +union tls_iter_offset { + struct iov_iter *msg_iter; + int offset; +}; + +static int tls_push_data(struct sock *sk, + union tls_iter_offset iter_offset, + size_t size, int flags, + unsigned char record_type, + struct page *zc_page) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); + struct tls_record_info *record; + int tls_push_record_flags; + struct page_frag *pfrag; + size_t orig_size = size; + u32 max_open_record_len; + bool more = false; + bool done = false; + int copy, rc = 0; + long timeo; + + if (flags & + ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST)) + return -EOPNOTSUPP; + + if (unlikely(sk->sk_err)) + return -sk->sk_err; + + flags |= MSG_SENDPAGE_DECRYPTED; + tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; + + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + if (tls_is_partially_sent_record(tls_ctx)) { + rc = tls_push_partial_record(sk, tls_ctx, flags); + if (rc < 0) + return rc; + } + + pfrag = sk_page_frag(sk); + + /* TLS_HEADER_SIZE is 
not counted as part of the TLS record, and + * we need to leave room for an authentication tag. + */ + max_open_record_len = TLS_MAX_PAYLOAD_SIZE + + prot->prepend_size; + do { + rc = tls_do_allocation(sk, ctx, pfrag, prot->prepend_size); + if (unlikely(rc)) { + rc = sk_stream_wait_memory(sk, &timeo); + if (!rc) + continue; + + record = ctx->open_record; + if (!record) + break; +handle_error: + if (record_type != TLS_RECORD_TYPE_DATA) { + /* avoid sending partial + * record with type != + * application_data + */ + size = orig_size; + destroy_record(record); + ctx->open_record = NULL; + } else if (record->len > prot->prepend_size) { + goto last_record; + } + + break; + } + + record = ctx->open_record; + + copy = min_t(size_t, size, max_open_record_len - record->len); + if (copy && zc_page) { + struct page_frag zc_pfrag; + + zc_pfrag.page = zc_page; + zc_pfrag.offset = iter_offset.offset; + zc_pfrag.size = copy; + tls_append_frag(record, &zc_pfrag, copy); + + iter_offset.offset += copy; + } else if (copy) { + copy = min_t(size_t, copy, pfrag->size - pfrag->offset); + + rc = tls_device_copy_data(page_address(pfrag->page) + + pfrag->offset, copy, + iter_offset.msg_iter); + if (rc) + goto handle_error; + tls_append_frag(record, pfrag, copy); + } + + size -= copy; + if (!size) { +last_record: + tls_push_record_flags = flags; + if (flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE)) { + more = true; + break; + } + + done = true; + } + + if (done || record->len >= max_open_record_len || + (record->num_frags >= MAX_SKB_FRAGS - 1)) { + tls_device_record_close(sk, tls_ctx, record, + pfrag, record_type); + + rc = tls_push_record(sk, + tls_ctx, + ctx, + record, + tls_push_record_flags); + if (rc < 0) + break; + } + } while (!done); + + tls_ctx->pending_open_record_frags = more; + + if (orig_size - size > 0) + rc = orig_size - size; + + return rc; +} + +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + unsigned char record_type = TLS_RECORD_TYPE_DATA; + struct tls_context *tls_ctx = tls_get_ctx(sk); + union tls_iter_offset iter; + int rc; + + mutex_lock(&tls_ctx->tx_lock); + lock_sock(sk); + + if (unlikely(msg->msg_controllen)) { + rc = tls_process_cmsg(sk, msg, &record_type); + if (rc) + goto out; + } + + iter.msg_iter = &msg->msg_iter; + rc = tls_push_data(sk, iter, size, msg->msg_flags, record_type, NULL); + +out: + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + return rc; +} + +int tls_device_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + union tls_iter_offset iter_offset; + struct iov_iter msg_iter; + char *kaddr; + struct kvec iov; + int rc; + + if (flags & MSG_SENDPAGE_NOTLAST) + flags |= MSG_MORE; + + mutex_lock(&tls_ctx->tx_lock); + lock_sock(sk); + + if (flags & MSG_OOB) { + rc = -EOPNOTSUPP; + goto out; + } + + if (tls_ctx->zerocopy_sendfile) { + iter_offset.offset = offset; + rc = tls_push_data(sk, iter_offset, size, + flags, TLS_RECORD_TYPE_DATA, page); + goto out; + } + + kaddr = kmap(page); + iov.iov_base = kaddr + offset; + iov.iov_len = size; + iov_iter_kvec(&msg_iter, ITER_SOURCE, &iov, 1, size); + iter_offset.msg_iter = &msg_iter; + rc = tls_push_data(sk, iter_offset, size, flags, TLS_RECORD_TYPE_DATA, + NULL); + kunmap(page); + +out: + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + return rc; +} + +struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, + u32 seq, u64 *p_record_sn) +{ + u64 record_sn = context->hint_record_sn; + struct 
tls_record_info *info, *last; + + info = context->retransmit_hint; + if (!info || + before(seq, info->end_seq - info->len)) { + /* if retransmit_hint is irrelevant start + * from the beginning of the list + */ + info = list_first_entry_or_null(&context->records_list, + struct tls_record_info, list); + if (!info) + return NULL; + /* send the start_marker record if seq number is before the + * tls offload start marker sequence number. This record is + * required to handle TCP packets which are before TLS offload + * started. + * And if it's not start marker, look if this seq number + * belongs to the list. + */ + if (likely(!tls_record_is_start_marker(info))) { + /* we have the first record, get the last record to see + * if this seq number belongs to the list. + */ + last = list_last_entry(&context->records_list, + struct tls_record_info, list); + + if (!between(seq, tls_record_start_seq(info), + last->end_seq)) + return NULL; + } + record_sn = context->unacked_record_sn; + } + + /* We just need the _rcu for the READ_ONCE() */ + rcu_read_lock(); + list_for_each_entry_from_rcu(info, &context->records_list, list) { + if (before(seq, info->end_seq)) { + if (!context->retransmit_hint || + after(info->end_seq, + context->retransmit_hint->end_seq)) { + context->hint_record_sn = record_sn; + context->retransmit_hint = info; + } + *p_record_sn = record_sn; + goto exit_rcu_unlock; + } + record_sn++; + } + info = NULL; + +exit_rcu_unlock: + rcu_read_unlock(); + return info; +} +EXPORT_SYMBOL(tls_get_record); + +static int tls_device_push_pending_record(struct sock *sk, int flags) +{ + union tls_iter_offset iter; + struct iov_iter msg_iter; + + iov_iter_kvec(&msg_iter, ITER_SOURCE, NULL, 0, 0); + iter.msg_iter = &msg_iter; + return tls_push_data(sk, iter, 0, flags, TLS_RECORD_TYPE_DATA, NULL); +} + +void tls_device_write_space(struct sock *sk, struct tls_context *ctx) +{ + if (tls_is_partially_sent_record(ctx)) { + gfp_t sk_allocation = sk->sk_allocation; + + WARN_ON_ONCE(sk->sk_write_pending); + + sk->sk_allocation = GFP_ATOMIC; + tls_push_partial_record(sk, ctx, + MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_DECRYPTED); + sk->sk_allocation = sk_allocation; + } +} + +static void tls_device_resync_rx(struct tls_context *tls_ctx, + struct sock *sk, u32 seq, u8 *rcd_sn) +{ + struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx); + struct net_device *netdev; + + trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type); + rcu_read_lock(); + netdev = rcu_dereference(tls_ctx->netdev); + if (netdev) + netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn, + TLS_OFFLOAD_CTX_DIR_RX); + rcu_read_unlock(); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC); +} + +static bool +tls_device_rx_resync_async(struct tls_offload_resync_async *resync_async, + s64 resync_req, u32 *seq, u16 *rcd_delta) +{ + u32 is_async = resync_req & RESYNC_REQ_ASYNC; + u32 req_seq = resync_req >> 32; + u32 req_end = req_seq + ((resync_req >> 16) & 0xffff); + u16 i; + + *rcd_delta = 0; + + if (is_async) { + /* shouldn't get to wraparound: + * too long in async stage, something bad happened + */ + if (WARN_ON_ONCE(resync_async->rcd_delta == USHRT_MAX)) + return false; + + /* asynchronous stage: log all headers seq such that + * req_seq <= seq <= end_seq, and wait for real resync request + */ + if (before(*seq, req_seq)) + return false; + if (!after(*seq, req_end) && + resync_async->loglen < TLS_DEVICE_RESYNC_ASYNC_LOGMAX) + resync_async->log[resync_async->loglen++] = *seq; + + 
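+	/* Every record header seen while the driver's request is still in
+	 * the asynchronous stage bumps rcd_delta. When the real resync
+	 * request arrives, the current record sequence is rewound by the
+	 * matching delta (tls_bigint_subtract() in
+	 * tls_device_rx_resync_new_rec()) so it corresponds to the logged
+	 * TCP sequence number.
+	 */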
resync_async->rcd_delta++; + + return false; + } + + /* synchronous stage: check against the logged entries and + * proceed to check the next entries if no match was found + */ + for (i = 0; i < resync_async->loglen; i++) + if (req_seq == resync_async->log[i] && + atomic64_try_cmpxchg(&resync_async->req, &resync_req, 0)) { + *rcd_delta = resync_async->rcd_delta - i; + *seq = req_seq; + resync_async->loglen = 0; + resync_async->rcd_delta = 0; + return true; + } + + resync_async->loglen = 0; + resync_async->rcd_delta = 0; + + if (req_seq == *seq && + atomic64_try_cmpxchg(&resync_async->req, + &resync_req, 0)) + return true; + + return false; +} + +void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context_rx *rx_ctx; + u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE]; + u32 sock_data, is_req_pending; + struct tls_prot_info *prot; + s64 resync_req; + u16 rcd_delta; + u32 req_seq; + + if (tls_ctx->rx_conf != TLS_HW) + return; + if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) + return; + + prot = &tls_ctx->prot_info; + rx_ctx = tls_offload_ctx_rx(tls_ctx); + memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size); + + switch (rx_ctx->resync_type) { + case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ: + resync_req = atomic64_read(&rx_ctx->resync_req); + req_seq = resync_req >> 32; + seq += TLS_HEADER_SIZE - 1; + is_req_pending = resync_req; + + if (likely(!is_req_pending) || req_seq != seq || + !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0)) + return; + break; + case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT: + if (likely(!rx_ctx->resync_nh_do_now)) + return; + + /* head of next rec is already in, note that the sock_inq will + * include the currently parsed message when called from parser + */ + sock_data = tcp_inq(sk); + if (sock_data > rcd_len) { + trace_tls_device_rx_resync_nh_delay(sk, sock_data, + rcd_len); + return; + } + + rx_ctx->resync_nh_do_now = 0; + seq += rcd_len; + tls_bigint_increment(rcd_sn, prot->rec_seq_size); + break; + case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ_ASYNC: + resync_req = atomic64_read(&rx_ctx->resync_async->req); + is_req_pending = resync_req; + if (likely(!is_req_pending)) + return; + + if (!tls_device_rx_resync_async(rx_ctx->resync_async, + resync_req, &seq, &rcd_delta)) + return; + tls_bigint_subtract(rcd_sn, rcd_delta); + break; + } + + tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn); +} + +static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx, + struct tls_offload_context_rx *ctx, + struct sock *sk, struct sk_buff *skb) +{ + struct strp_msg *rxm; + + /* device will request resyncs by itself based on stream scan */ + if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT) + return; + /* already scheduled */ + if (ctx->resync_nh_do_now) + return; + /* seen decrypted fragments since last fully-failed record */ + if (ctx->resync_nh_reset) { + ctx->resync_nh_reset = 0; + ctx->resync_nh.decrypted_failed = 1; + ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL; + return; + } + + if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt) + return; + + /* doing resync, bump the next target in case it fails */ + if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL) + ctx->resync_nh.decrypted_tgt *= 2; + else + ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL; + + rxm = strp_msg(skb); + + /* head of next rec is already in, parser will sync for us */ + if (tcp_inq(sk) > rxm->full_len) { + 
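+		/* More queued bytes than the current record means the next
+		 * record header is already in the receive queue. Setting
+		 * resync_nh_do_now defers the resync to
+		 * tls_device_rx_resync_new_rec() (the
+		 * TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT case), which runs once
+		 * that header has been parsed.
+		 */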
trace_tls_device_rx_resync_nh_schedule(sk); + ctx->resync_nh_do_now = 1; + } else { + struct tls_prot_info *prot = &tls_ctx->prot_info; + u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE]; + + memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size); + tls_bigint_increment(rcd_sn, prot->rec_seq_size); + + tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq, + rcd_sn); + } +} + +static int +tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); + const struct tls_cipher_size_desc *cipher_sz; + int err, offset, copy, data_len, pos; + struct sk_buff *skb, *skb_iter; + struct scatterlist sg[1]; + struct strp_msg *rxm; + char *orig_buf, *buf; + + switch (tls_ctx->crypto_recv.info.cipher_type) { + case TLS_CIPHER_AES_GCM_128: + case TLS_CIPHER_AES_GCM_256: + break; + default: + return -EINVAL; + } + cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_recv.info.cipher_type]; + + rxm = strp_msg(tls_strp_msg(sw_ctx)); + orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv, + sk->sk_allocation); + if (!orig_buf) + return -ENOMEM; + buf = orig_buf; + + err = tls_strp_msg_cow(sw_ctx); + if (unlikely(err)) + goto free_buf; + + skb = tls_strp_msg(sw_ctx); + rxm = strp_msg(skb); + offset = rxm->offset; + + sg_init_table(sg, 1); + sg_set_buf(&sg[0], buf, + rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv); + err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_sz->iv); + if (err) + goto free_buf; + + /* We are interested only in the decrypted data not the auth */ + err = decrypt_skb(sk, sg); + if (err != -EBADMSG) + goto free_buf; + else + err = 0; + + data_len = rxm->full_len - cipher_sz->tag; + + if (skb_pagelen(skb) > offset) { + copy = min_t(int, skb_pagelen(skb) - offset, data_len); + + if (skb->decrypted) { + err = skb_store_bits(skb, offset, buf, copy); + if (err) + goto free_buf; + } + + offset += copy; + buf += copy; + } + + pos = skb_pagelen(skb); + skb_walk_frags(skb, skb_iter) { + int frag_pos; + + /* Practically all frags must belong to msg if reencrypt + * is needed with current strparser and coalescing logic, + * but strparser may "get optimized", so let's be safe. 
+ */ + if (pos + skb_iter->len <= offset) + goto done_with_frag; + if (pos >= data_len + rxm->offset) + break; + + frag_pos = offset - pos; + copy = min_t(int, skb_iter->len - frag_pos, + data_len + rxm->offset - offset); + + if (skb_iter->decrypted) { + err = skb_store_bits(skb_iter, frag_pos, buf, copy); + if (err) + goto free_buf; + } + + offset += copy; + buf += copy; +done_with_frag: + pos += skb_iter->len; + } + +free_buf: + kfree(orig_buf); + return err; +} + +int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx); + struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_buff *skb = tls_strp_msg(sw_ctx); + struct strp_msg *rxm = strp_msg(skb); + int is_decrypted, is_encrypted; + + if (!tls_strp_msg_mixed_decrypted(sw_ctx)) { + is_decrypted = skb->decrypted; + is_encrypted = !is_decrypted; + } else { + is_decrypted = 0; + is_encrypted = 0; + } + + trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len, + tls_ctx->rx.rec_seq, rxm->full_len, + is_encrypted, is_decrypted); + + if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) { + if (likely(is_encrypted || is_decrypted)) + return is_decrypted; + + /* After tls_device_down disables the offload, the next SKB will + * likely have initial fragments decrypted, and final ones not + * decrypted. We need to reencrypt that single SKB. + */ + return tls_device_reencrypt(sk, tls_ctx); + } + + /* Return immediately if the record is either entirely plaintext or + * entirely ciphertext. Otherwise handle reencrypt partially decrypted + * record. + */ + if (is_decrypted) { + ctx->resync_nh_reset = 1; + return is_decrypted; + } + if (is_encrypted) { + tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb); + return 0; + } + + ctx->resync_nh_reset = 1; + return tls_device_reencrypt(sk, tls_ctx); +} + +static void tls_device_attach(struct tls_context *ctx, struct sock *sk, + struct net_device *netdev) +{ + if (sk->sk_destruct != tls_device_sk_destruct) { + refcount_set(&ctx->refcount, 1); + dev_hold(netdev); + RCU_INIT_POINTER(ctx->netdev, netdev); + spin_lock_irq(&tls_device_lock); + list_add_tail(&ctx->list, &tls_device_list); + spin_unlock_irq(&tls_device_lock); + + ctx->sk_destruct = sk->sk_destruct; + smp_store_release(&sk->sk_destruct, tls_device_sk_destruct); + } +} + +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + const struct tls_cipher_size_desc *cipher_sz; + struct tls_record_info *start_marker_record; + struct tls_offload_context_tx *offload_ctx; + struct tls_crypto_info *crypto_info; + struct net_device *netdev; + char *iv, *rec_seq; + struct sk_buff *skb; + __be64 rcd_sn; + int rc; + + if (!ctx) + return -EINVAL; + + if (ctx->priv_ctx_tx) + return -EEXIST; + + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + return -EINVAL; + } + + if (!(netdev->features & NETIF_F_HW_TLS_TX)) { + rc = -EOPNOTSUPP; + goto release_netdev; + } + + crypto_info = &ctx->crypto_send.info; + if (crypto_info->version != TLS_1_2_VERSION) { + rc = -EOPNOTSUPP; + goto release_netdev; + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; + rec_seq = + ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; + break; + case TLS_CIPHER_AES_GCM_256: + iv = ((struct 
tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv; + rec_seq = + ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq; + break; + default: + rc = -EINVAL; + goto release_netdev; + } + cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type]; + + /* Sanity-check the rec_seq_size for stack allocations */ + if (cipher_sz->rec_seq > TLS_MAX_REC_SEQ_SIZE) { + rc = -EINVAL; + goto release_netdev; + } + + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + cipher_sz->iv; + prot->tag_size = cipher_sz->tag; + prot->overhead_size = prot->prepend_size + prot->tag_size; + prot->iv_size = cipher_sz->iv; + prot->salt_size = cipher_sz->salt; + ctx->tx.iv = kmalloc(cipher_sz->iv + cipher_sz->salt, GFP_KERNEL); + if (!ctx->tx.iv) { + rc = -ENOMEM; + goto release_netdev; + } + + memcpy(ctx->tx.iv + cipher_sz->salt, iv, cipher_sz->iv); + + prot->rec_seq_size = cipher_sz->rec_seq; + ctx->tx.rec_seq = kmemdup(rec_seq, cipher_sz->rec_seq, GFP_KERNEL); + if (!ctx->tx.rec_seq) { + rc = -ENOMEM; + goto free_iv; + } + + start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL); + if (!start_marker_record) { + rc = -ENOMEM; + goto free_rec_seq; + } + + offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL); + if (!offload_ctx) { + rc = -ENOMEM; + goto free_marker_record; + } + + rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info); + if (rc) + goto free_offload_ctx; + + /* start at rec_seq - 1 to account for the start marker record */ + memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn)); + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; + + start_marker_record->end_seq = tcp_sk(sk)->write_seq; + start_marker_record->len = 0; + start_marker_record->num_frags = 0; + + INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task); + offload_ctx->ctx = ctx; + + INIT_LIST_HEAD(&offload_ctx->records_list); + list_add_tail(&start_marker_record->list, &offload_ctx->records_list); + spin_lock_init(&offload_ctx->lock); + sg_init_table(offload_ctx->sg_tx_data, + ARRAY_SIZE(offload_ctx->sg_tx_data)); + + clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked); + ctx->push_pending_record = tls_device_push_pending_record; + + /* TLS offload is greatly simplified if we don't send + * SKBs where only part of the payload needs to be encrypted. + * So mark the last skb in the write queue as end of record. + */ + skb = tcp_write_queue_tail(sk); + if (skb) + TCP_SKB_CB(skb)->eor = 1; + + /* Avoid offloading if the device is down + * We don't want to offload new flows after + * the NETDEV_DOWN event + * + * device_offload_lock is taken in tls_devices's NETDEV_DOWN + * handler thus protecting from the device going down before + * ctx was added to tls_device_list. + */ + down_read(&device_offload_lock); + if (!(netdev->flags & IFF_UP)) { + rc = -EINVAL; + goto release_lock; + } + + ctx->priv_ctx_tx = offload_ctx; + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, + &ctx->crypto_send.info, + tcp_sk(sk)->write_seq); + trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_TX, + tcp_sk(sk)->write_seq, rec_seq, rc); + if (rc) + goto release_lock; + + tls_device_attach(ctx, sk, netdev); + up_read(&device_offload_lock); + + /* following this assignment tls_is_sk_tx_device_offloaded + * will return true and the context might be accessed + * by the netdev's xmit function. 
+ */ + smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb); + dev_put(netdev); + + return 0; + +release_lock: + up_read(&device_offload_lock); + clean_acked_data_disable(inet_csk(sk)); + crypto_free_aead(offload_ctx->aead_send); +free_offload_ctx: + kfree(offload_ctx); + ctx->priv_ctx_tx = NULL; +free_marker_record: + kfree(start_marker_record); +free_rec_seq: + kfree(ctx->tx.rec_seq); +free_iv: + kfree(ctx->tx.iv); +release_netdev: + dev_put(netdev); + return rc; +} + +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +{ + struct tls12_crypto_info_aes_gcm_128 *info; + struct tls_offload_context_rx *context; + struct net_device *netdev; + int rc = 0; + + if (ctx->crypto_recv.info.version != TLS_1_2_VERSION) + return -EOPNOTSUPP; + + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + return -EINVAL; + } + + if (!(netdev->features & NETIF_F_HW_TLS_RX)) { + rc = -EOPNOTSUPP; + goto release_netdev; + } + + /* Avoid offloading if the device is down + * We don't want to offload new flows after + * the NETDEV_DOWN event + * + * device_offload_lock is taken in tls_devices's NETDEV_DOWN + * handler thus protecting from the device going down before + * ctx was added to tls_device_list. + */ + down_read(&device_offload_lock); + if (!(netdev->flags & IFF_UP)) { + rc = -EINVAL; + goto release_lock; + } + + context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL); + if (!context) { + rc = -ENOMEM; + goto release_lock; + } + context->resync_nh_reset = 1; + + ctx->priv_ctx_rx = context; + rc = tls_set_sw_offload(sk, ctx, 0); + if (rc) + goto release_ctx; + + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX, + &ctx->crypto_recv.info, + tcp_sk(sk)->copied_seq); + info = (void *)&ctx->crypto_recv.info; + trace_tls_device_offload_set(sk, TLS_OFFLOAD_CTX_DIR_RX, + tcp_sk(sk)->copied_seq, info->rec_seq, rc); + if (rc) + goto free_sw_resources; + + tls_device_attach(ctx, sk, netdev); + up_read(&device_offload_lock); + + dev_put(netdev); + + return 0; + +free_sw_resources: + up_read(&device_offload_lock); + tls_sw_free_resources_rx(sk); + down_read(&device_offload_lock); +release_ctx: + ctx->priv_ctx_rx = NULL; +release_lock: + up_read(&device_offload_lock); +release_netdev: + dev_put(netdev); + return rc; +} + +void tls_device_offload_cleanup_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct net_device *netdev; + + down_read(&device_offload_lock); + netdev = rcu_dereference_protected(tls_ctx->netdev, + lockdep_is_held(&device_offload_lock)); + if (!netdev) + goto out; + + netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx, + TLS_OFFLOAD_CTX_DIR_RX); + + if (tls_ctx->tx_conf != TLS_HW) { + dev_put(netdev); + rcu_assign_pointer(tls_ctx->netdev, NULL); + } else { + set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags); + } +out: + up_read(&device_offload_lock); + tls_sw_release_resources_rx(sk); +} + +static int tls_device_down(struct net_device *netdev) +{ + struct tls_context *ctx, *tmp; + unsigned long flags; + LIST_HEAD(list); + + /* Request a write lock to block new offload attempts */ + down_write(&device_offload_lock); + + spin_lock_irqsave(&tls_device_lock, flags); + list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) { + struct net_device *ctx_netdev = + rcu_dereference_protected(ctx->netdev, + lockdep_is_held(&device_offload_lock)); + + if (ctx_netdev != netdev || + !refcount_inc_not_zero(&ctx->refcount)) + continue; + + list_move(&ctx->list, &list); + } + 
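+	/* Every context that matched this netdev now holds an extra
+	 * reference (refcount_inc_not_zero above) and sits on the private
+	 * list, so the teardown below can run without tls_device_lock held.
+	 */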
spin_unlock_irqrestore(&tls_device_lock, flags); + + list_for_each_entry_safe(ctx, tmp, &list, list) { + /* Stop offloaded TX and switch to the fallback. + * tls_is_sk_tx_device_offloaded will return false. + */ + WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw); + + /* Stop the RX and TX resync. + * tls_dev_resync must not be called after tls_dev_del. + */ + rcu_assign_pointer(ctx->netdev, NULL); + + /* Start skipping the RX resync logic completely. */ + set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags); + + /* Sync with inflight packets. After this point: + * TX: no non-encrypted packets will be passed to the driver. + * RX: resync requests from the driver will be ignored. + */ + synchronize_net(); + + /* Release the offload context on the driver side. */ + if (ctx->tx_conf == TLS_HW) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + if (ctx->rx_conf == TLS_HW && + !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags)) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_RX); + + dev_put(netdev); + + /* Move the context to a separate list for two reasons: + * 1. When the context is deallocated, list_del is called. + * 2. It's no longer an offloaded context, so we don't want to + * run offload-specific code on this context. + */ + spin_lock_irqsave(&tls_device_lock, flags); + list_move_tail(&ctx->list, &tls_device_down_list); + spin_unlock_irqrestore(&tls_device_lock, flags); + + /* Device contexts for RX and TX will be freed in on sk_destruct + * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW. + * Now release the ref taken above. + */ + if (refcount_dec_and_test(&ctx->refcount)) { + /* sk_destruct ran after tls_device_down took a ref, and + * it returned early. Complete the destruction here. + */ + list_del(&ctx->list); + tls_device_free_ctx(ctx); + } + } + + up_write(&device_offload_lock); + + flush_workqueue(destruct_wq); + + return NOTIFY_DONE; +} + +static int tls_dev_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (!dev->tlsdev_ops && + !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX))) + return NOTIFY_DONE; + + switch (event) { + case NETDEV_REGISTER: + case NETDEV_FEAT_CHANGE: + if (netif_is_bond_master(dev)) + return NOTIFY_DONE; + if ((dev->features & NETIF_F_HW_TLS_RX) && + !dev->tlsdev_ops->tls_dev_resync) + return NOTIFY_BAD; + + if (dev->tlsdev_ops && + dev->tlsdev_ops->tls_dev_add && + dev->tlsdev_ops->tls_dev_del) + return NOTIFY_DONE; + else + return NOTIFY_BAD; + case NETDEV_DOWN: + return tls_device_down(dev); + } + return NOTIFY_DONE; +} + +static struct notifier_block tls_dev_notifier = { + .notifier_call = tls_dev_event, +}; + +int __init tls_device_init(void) +{ + int err; + + dummy_page = alloc_page(GFP_KERNEL); + if (!dummy_page) + return -ENOMEM; + + destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0); + if (!destruct_wq) { + err = -ENOMEM; + goto err_free_dummy; + } + + err = register_netdevice_notifier(&tls_dev_notifier); + if (err) + goto err_destroy_wq; + + return 0; + +err_destroy_wq: + destroy_workqueue(destruct_wq); +err_free_dummy: + put_page(dummy_page); + return err; +} + +void __exit tls_device_cleanup(void) +{ + unregister_netdevice_notifier(&tls_dev_notifier); + destroy_workqueue(destruct_wq); + clean_acked_data_flush(); + put_page(dummy_page); +} diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c new file mode 100644 index 000000000..7fbb1d0b6 --- /dev/null +++ 
b/net/tls/tls_device_fallback.c @@ -0,0 +1,513 @@ +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include +#include +#include + +#include "tls.h" + +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk) +{ + struct scatterlist *src = walk->sg; + int diff = walk->offset - src->offset; + + sg_set_page(sg, sg_page(src), + src->length - diff, walk->offset); + + scatterwalk_crypto_chain(sg, sg_next(src), 2); +} + +static int tls_enc_record(struct aead_request *aead_req, + struct crypto_aead *aead, char *aad, + char *iv, __be64 rcd_sn, + struct scatter_walk *in, + struct scatter_walk *out, int *in_len, + struct tls_prot_info *prot) +{ + unsigned char buf[TLS_HEADER_SIZE + MAX_IV_SIZE]; + const struct tls_cipher_size_desc *cipher_sz; + struct scatterlist sg_in[3]; + struct scatterlist sg_out[3]; + unsigned int buf_size; + u16 len; + int rc; + + switch (prot->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + case TLS_CIPHER_AES_GCM_256: + break; + default: + return -EINVAL; + } + cipher_sz = &tls_cipher_size_desc[prot->cipher_type]; + + buf_size = TLS_HEADER_SIZE + cipher_sz->iv; + len = min_t(int, *in_len, buf_size); + + scatterwalk_copychunks(buf, in, len, 0); + scatterwalk_copychunks(buf, out, len, 1); + + *in_len -= len; + if (!*in_len) + return 0; + + scatterwalk_pagedone(in, 0, 1); + scatterwalk_pagedone(out, 1, 1); + + len = buf[4] | (buf[3] << 8); + len -= cipher_sz->iv; + + tls_make_aad(aad, len - cipher_sz->tag, (char *)&rcd_sn, buf[0], prot); + + memcpy(iv + cipher_sz->salt, buf + TLS_HEADER_SIZE, cipher_sz->iv); + + sg_init_table(sg_in, ARRAY_SIZE(sg_in)); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); + sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); + chain_to_walk(sg_in + 1, in); + chain_to_walk(sg_out + 1, out); + + *in_len -= len; + if (*in_len < 0) { + *in_len += cipher_sz->tag; + /* the input buffer doesn't contain the entire record. + * trim len accordingly. 
The resulting authentication tag + * will contain garbage, but we don't care, so we won't + * include any of it in the output skb + * Note that we assume the output buffer length + * is larger then input buffer length + tag size + */ + if (*in_len < 0) + len += *in_len; + + *in_len = 0; + } + + if (*in_len) { + scatterwalk_copychunks(NULL, in, len, 2); + scatterwalk_pagedone(in, 0, 1); + scatterwalk_copychunks(NULL, out, len, 2); + scatterwalk_pagedone(out, 1, 1); + } + + len -= cipher_sz->tag; + aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv); + + rc = crypto_aead_encrypt(aead_req); + + return rc; +} + +static void tls_init_aead_request(struct aead_request *aead_req, + struct crypto_aead *aead) +{ + aead_request_set_tfm(aead_req, aead); + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); +} + +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, + gfp_t flags) +{ + unsigned int req_size = sizeof(struct aead_request) + + crypto_aead_reqsize(aead); + struct aead_request *aead_req; + + aead_req = kzalloc(req_size, flags); + if (aead_req) + tls_init_aead_request(aead_req, aead); + return aead_req; +} + +static int tls_enc_records(struct aead_request *aead_req, + struct crypto_aead *aead, struct scatterlist *sg_in, + struct scatterlist *sg_out, char *aad, char *iv, + u64 rcd_sn, int len, struct tls_prot_info *prot) +{ + struct scatter_walk out, in; + int rc; + + scatterwalk_start(&in, sg_in); + scatterwalk_start(&out, sg_out); + + do { + rc = tls_enc_record(aead_req, aead, aad, iv, + cpu_to_be64(rcd_sn), &in, &out, &len, prot); + rcd_sn++; + + } while (rc == 0 && len); + + scatterwalk_done(&in, 0, 0); + scatterwalk_done(&out, 1, 0); + + return rc; +} + +/* Can't use icsk->icsk_af_ops->send_check here because the ip addresses + * might have been changed by NAT. + */ +static void update_chksum(struct sk_buff *skb, int headln) +{ + struct tcphdr *th = tcp_hdr(skb); + int datalen = skb->len - headln; + const struct ipv6hdr *ipv6h; + const struct iphdr *iph; + + /* We only changed the payload so if we are using partial we don't + * need to update anything. 
+ */ + if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) + return; + + skb->ip_summed = CHECKSUM_PARTIAL; + skb->csum_start = skb_transport_header(skb) - skb->head; + skb->csum_offset = offsetof(struct tcphdr, check); + + if (skb->sk->sk_family == AF_INET6) { + ipv6h = ipv6_hdr(skb); + th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, + datalen, IPPROTO_TCP, 0); + } else { + iph = ip_hdr(skb); + th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, + IPPROTO_TCP, 0); + } +} + +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln) +{ + struct sock *sk = skb->sk; + int delta; + + skb_copy_header(nskb, skb); + + skb_put(nskb, skb->len); + memcpy(nskb->data, skb->data, headln); + + nskb->destructor = skb->destructor; + nskb->sk = sk; + skb->destructor = NULL; + skb->sk = NULL; + + update_chksum(nskb, headln); + + /* sock_efree means skb must gone through skb_orphan_partial() */ + if (nskb->destructor == sock_efree) + return; + + delta = nskb->truesize - skb->truesize; + if (likely(delta < 0)) + WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc)); + else if (delta) + refcount_add(delta, &sk->sk_wmem_alloc); +} + +/* This function may be called after the user socket is already + * closed so make sure we don't use anything freed during + * tls_sk_proto_close here + */ + +static int fill_sg_in(struct scatterlist *sg_in, + struct sk_buff *skb, + struct tls_offload_context_tx *ctx, + u64 *rcd_sn, + s32 *sync_size, + int *resync_sgs) +{ + int tcp_payload_offset = skb_tcp_all_headers(skb); + int payload_len = skb->len - tcp_payload_offset; + u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); + struct tls_record_info *record; + unsigned long flags; + int remaining; + int i; + + spin_lock_irqsave(&ctx->lock, flags); + record = tls_get_record(ctx, tcp_seq, rcd_sn); + if (!record) { + spin_unlock_irqrestore(&ctx->lock, flags); + return -EINVAL; + } + + *sync_size = tcp_seq - tls_record_start_seq(record); + if (*sync_size < 0) { + int is_start_marker = tls_record_is_start_marker(record); + + spin_unlock_irqrestore(&ctx->lock, flags); + /* This should only occur if the relevant record was + * already acked. In that case it should be ok + * to drop the packet and avoid retransmission. + * + * There is a corner case where the packet contains + * both an acked and a non-acked record. + * We currently don't handle that case and rely + * on TCP to retranmit a packet that doesn't contain + * already acked payload. 
+ */ + if (!is_start_marker) + *sync_size = 0; + return -EINVAL; + } + + remaining = *sync_size; + for (i = 0; remaining > 0; i++) { + skb_frag_t *frag = &record->frags[i]; + + __skb_frag_ref(frag); + sg_set_page(sg_in + i, skb_frag_page(frag), + skb_frag_size(frag), skb_frag_off(frag)); + + remaining -= skb_frag_size(frag); + + if (remaining < 0) + sg_in[i].length += remaining; + } + *resync_sgs = i; + + spin_unlock_irqrestore(&ctx->lock, flags); + if (skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, payload_len) < 0) + return -EINVAL; + + return 0; +} + +static void fill_sg_out(struct scatterlist sg_out[3], void *buf, + struct tls_context *tls_ctx, + struct sk_buff *nskb, + int tcp_payload_offset, + int payload_len, + int sync_size, + void *dummy_buf) +{ + const struct tls_cipher_size_desc *cipher_sz = + &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type]; + + sg_set_buf(&sg_out[0], dummy_buf, sync_size); + sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len); + /* Add room for authentication tag produced by crypto */ + dummy_buf += sync_size; + sg_set_buf(&sg_out[2], dummy_buf, cipher_sz->tag); +} + +static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx, + struct scatterlist sg_out[3], + struct scatterlist *sg_in, + struct sk_buff *skb, + s32 sync_size, u64 rcd_sn) +{ + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); + int tcp_payload_offset = skb_tcp_all_headers(skb); + int payload_len = skb->len - tcp_payload_offset; + const struct tls_cipher_size_desc *cipher_sz; + void *buf, *iv, *aad, *dummy_buf, *salt; + struct aead_request *aead_req; + struct sk_buff *nskb = NULL; + int buf_len; + + aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC); + if (!aead_req) + return NULL; + + switch (tls_ctx->crypto_send.info.cipher_type) { + case TLS_CIPHER_AES_GCM_128: + salt = tls_ctx->crypto_send.aes_gcm_128.salt; + break; + case TLS_CIPHER_AES_GCM_256: + salt = tls_ctx->crypto_send.aes_gcm_256.salt; + break; + default: + goto free_req; + } + cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type]; + buf_len = cipher_sz->salt + cipher_sz->iv + TLS_AAD_SPACE_SIZE + + sync_size + cipher_sz->tag; + buf = kmalloc(buf_len, GFP_ATOMIC); + if (!buf) + goto free_req; + + iv = buf; + memcpy(iv, salt, cipher_sz->salt); + aad = buf + cipher_sz->salt + cipher_sz->iv; + dummy_buf = aad + TLS_AAD_SPACE_SIZE; + + nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); + if (!nskb) + goto free_buf; + + skb_reserve(nskb, skb_headroom(skb)); + + fill_sg_out(sg_out, buf, tls_ctx, nskb, tcp_payload_offset, + payload_len, sync_size, dummy_buf); + + if (tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, + rcd_sn, sync_size + payload_len, + &tls_ctx->prot_info) < 0) + goto free_nskb; + + complete_skb(nskb, skb, tcp_payload_offset); + + /* validate_xmit_skb_list assumes that if the skb wasn't segmented + * nskb->prev will point to the skb itself + */ + nskb->prev = nskb; + +free_buf: + kfree(buf); +free_req: + kfree(aead_req); + return nskb; +free_nskb: + kfree_skb(nskb); + nskb = NULL; + goto free_buf; +} + +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb) +{ + int tcp_payload_offset = skb_tcp_all_headers(skb); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); + int payload_len = skb->len - tcp_payload_offset; + struct scatterlist *sg_in, sg_out[3]; + struct sk_buff *nskb = NULL; + int sg_in_max_elements; + int resync_sgs = 0; + 
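+	/* sync_size will hold the number of record bytes that precede this
+	 * skb's TCP payload and must also be fed to the AEAD; resync_sgs
+	 * counts the record frags referenced for that prefix so their pages
+	 * can be released once encryption is done (see fill_sg_in()).
+	 */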
s32 sync_size = 0; + u64 rcd_sn; + + /* worst case is: + * MAX_SKB_FRAGS in tls_record_info + * MAX_SKB_FRAGS + 1 in SKB head and frags. + */ + sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1; + + if (!payload_len) + return skb; + + sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC); + if (!sg_in) + goto free_orig; + + sg_init_table(sg_in, sg_in_max_elements); + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); + + if (fill_sg_in(sg_in, skb, ctx, &rcd_sn, &sync_size, &resync_sgs)) { + /* bypass packets before kernel TLS socket option was set */ + if (sync_size < 0 && payload_len <= -sync_size) + nskb = skb_get(skb); + goto put_sg; + } + + nskb = tls_enc_skb(tls_ctx, sg_out, sg_in, skb, sync_size, rcd_sn); + +put_sg: + while (resync_sgs) + put_page(sg_page(&sg_in[--resync_sgs])); + kfree(sg_in); +free_orig: + if (nskb) + consume_skb(skb); + else + kfree_skb(skb); + return nskb; +} + +struct sk_buff *tls_validate_xmit_skb(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb) +{ + if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev) || + netif_is_bond_master(dev)) + return skb; + + return tls_sw_fallback(sk, skb); +} +EXPORT_SYMBOL_GPL(tls_validate_xmit_skb); + +struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk, + struct net_device *dev, + struct sk_buff *skb) +{ + return tls_sw_fallback(sk, skb); +} + +struct sk_buff *tls_encrypt_skb(struct sk_buff *skb) +{ + return tls_sw_fallback(skb->sk, skb); +} +EXPORT_SYMBOL_GPL(tls_encrypt_skb); + +int tls_sw_fallback_init(struct sock *sk, + struct tls_offload_context_tx *offload_ctx, + struct tls_crypto_info *crypto_info) +{ + const struct tls_cipher_size_desc *cipher_sz; + const u8 *key; + int rc; + + offload_ctx->aead_send = + crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(offload_ctx->aead_send)) { + rc = PTR_ERR(offload_ctx->aead_send); + pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc); + offload_ctx->aead_send = NULL; + goto err_out; + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key; + break; + case TLS_CIPHER_AES_GCM_256: + key = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->key; + break; + default: + rc = -EINVAL; + goto free_aead; + } + cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type]; + + rc = crypto_aead_setkey(offload_ctx->aead_send, key, cipher_sz->key); + if (rc) + goto free_aead; + + rc = crypto_aead_setauthsize(offload_ctx->aead_send, cipher_sz->tag); + if (rc) + goto free_aead; + + return 0; +free_aead: + crypto_free_aead(offload_ctx->aead_send); +err_out: + return rc; +} diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c new file mode 100644 index 000000000..338a443fa --- /dev/null +++ b/net/tls/tls_main.c @@ -0,0 +1,1246 @@ +/* + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. 
+ * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "tls.h" + +MODULE_AUTHOR("Mellanox Technologies"); +MODULE_DESCRIPTION("Transport Layer Security Support"); +MODULE_LICENSE("Dual BSD/GPL"); +MODULE_ALIAS_TCP_ULP("tls"); + +enum { + TLSV4, + TLSV6, + TLS_NUM_PROTS, +}; + +#define CIPHER_SIZE_DESC(cipher) [cipher] = { \ + .iv = cipher ## _IV_SIZE, \ + .key = cipher ## _KEY_SIZE, \ + .salt = cipher ## _SALT_SIZE, \ + .tag = cipher ## _TAG_SIZE, \ + .rec_seq = cipher ## _REC_SEQ_SIZE, \ +} + +const struct tls_cipher_size_desc tls_cipher_size_desc[] = { + CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128), + CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256), + CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128), + CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305), + CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM), + CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM), +}; + +static const struct proto *saved_tcpv6_prot; +static DEFINE_MUTEX(tcpv6_prot_mutex); +static const struct proto *saved_tcpv4_prot; +static DEFINE_MUTEX(tcpv4_prot_mutex); +static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; +static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + const struct proto *base); + +void update_sk_prot(struct sock *sk, struct tls_context *ctx) +{ + int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; + + WRITE_ONCE(sk->sk_prot, + &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); + WRITE_ONCE(sk->sk_socket->ops, + &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); +} + +int wait_on_pending_writer(struct sock *sk, long *timeo) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret, rc = 0; + + add_wait_queue(sk_sleep(sk), &wait); + while (1) { + if (!*timeo) { + rc = -EAGAIN; + break; + } + + if (signal_pending(current)) { + rc = sock_intr_errno(*timeo); + break; + } + + ret = sk_wait_event(sk, timeo, + !READ_ONCE(sk->sk_write_pending), &wait); + if (ret) { + if (ret < 0) + rc = ret; + break; + } + } + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +int tls_push_sg(struct sock *sk, + struct tls_context *ctx, + struct scatterlist *sg, + u16 first_offset, + int flags) +{ + int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST; + int ret = 0; + struct page *p; + size_t size; + int offset = first_offset; + + size = sg->length - offset; + offset += sg->offset; + + ctx->in_tcp_sendpages = true; + while (1) { + if (sg_is_last(sg)) + sendpage_flags = flags; + + /* is sending application-limited? 
*/ + tcp_rate_check_app_limited(sk); + p = sg_page(sg); +retry: + ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags); + + if (ret != size) { + if (ret > 0) { + offset += ret; + size -= ret; + goto retry; + } + + offset -= sg->offset; + ctx->partially_sent_offset = offset; + ctx->partially_sent_record = (void *)sg; + ctx->in_tcp_sendpages = false; + return ret; + } + + put_page(p); + sk_mem_uncharge(sk, sg->length); + sg = sg_next(sg); + if (!sg) + break; + + offset = sg->offset; + size = sg->length; + } + + ctx->in_tcp_sendpages = false; + + return 0; +} + +static int tls_handle_open_record(struct sock *sk, int flags) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (tls_is_pending_open_record(ctx)) + return ctx->push_pending_record(sk, flags); + + return 0; +} + +int tls_process_cmsg(struct sock *sk, struct msghdr *msg, + unsigned char *record_type) +{ + struct cmsghdr *cmsg; + int rc = -EINVAL; + + for_each_cmsghdr(cmsg, msg) { + if (!CMSG_OK(msg, cmsg)) + return -EINVAL; + if (cmsg->cmsg_level != SOL_TLS) + continue; + + switch (cmsg->cmsg_type) { + case TLS_SET_RECORD_TYPE: + if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) + return -EINVAL; + + if (msg->msg_flags & MSG_MORE) + return -EINVAL; + + rc = tls_handle_open_record(sk, msg->msg_flags); + if (rc) + return rc; + + *record_type = *(unsigned char *)CMSG_DATA(cmsg); + rc = 0; + break; + default: + return -EINVAL; + } + } + + return rc; +} + +int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, + int flags) +{ + struct scatterlist *sg; + u16 offset; + + sg = ctx->partially_sent_record; + offset = ctx->partially_sent_offset; + + ctx->partially_sent_record = NULL; + return tls_push_sg(sk, ctx, sg, offset, flags); +} + +void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) +{ + struct scatterlist *sg; + + for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { + put_page(sg_page(sg)); + sk_mem_uncharge(sk, sg->length); + } + ctx->partially_sent_record = NULL; +} + +static void tls_write_space(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + /* If in_tcp_sendpages call lower protocol write space handler + * to ensure we wake up any waiting operations there. For example + * if do_tcp_sendpages where to call sk_wait_event. + */ + if (ctx->in_tcp_sendpages) { + ctx->sk_write_space(sk); + return; + } + +#ifdef CONFIG_TLS_DEVICE + if (ctx->tx_conf == TLS_HW) + tls_device_write_space(sk, ctx); + else +#endif + tls_sw_write_space(sk, ctx); + + ctx->sk_write_space(sk); +} + +/** + * tls_ctx_free() - free TLS ULP context + * @sk: socket to with @ctx is attached + * @ctx: TLS context structure + * + * Free TLS context. If @sk is %NULL caller guarantees that the socket + * to which @ctx was attached has no outstanding references. 
+ */ +void tls_ctx_free(struct sock *sk, struct tls_context *ctx) +{ + if (!ctx) + return; + + memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); + memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); + mutex_destroy(&ctx->tx_lock); + + if (sk) + kfree_rcu(ctx, rcu); + else + kfree(ctx); +} + +static void tls_sk_proto_cleanup(struct sock *sk, + struct tls_context *ctx, long timeo) +{ + if (unlikely(sk->sk_write_pending) && + !wait_on_pending_writer(sk, &timeo)) + tls_handle_open_record(sk, 0); + + /* We need these for tls_sw_fallback handling of other packets */ + if (ctx->tx_conf == TLS_SW) { + kfree(ctx->tx.rec_seq); + kfree(ctx->tx.iv); + tls_sw_release_resources_tx(sk); + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + } else if (ctx->tx_conf == TLS_HW) { + tls_device_free_resources_tx(sk); + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); + } + + if (ctx->rx_conf == TLS_SW) { + tls_sw_release_resources_rx(sk); + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + } else if (ctx->rx_conf == TLS_HW) { + tls_device_offload_cleanup_rx(sk); + TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + } +} + +static void tls_sk_proto_close(struct sock *sk, long timeout) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tls_context *ctx = tls_get_ctx(sk); + long timeo = sock_sndtimeo(sk, 0); + bool free_ctx; + + if (ctx->tx_conf == TLS_SW) + tls_sw_cancel_work_tx(ctx); + + lock_sock(sk); + free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; + + if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) + tls_sk_proto_cleanup(sk, ctx, timeo); + + write_lock_bh(&sk->sk_callback_lock); + if (free_ctx) + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + WRITE_ONCE(sk->sk_prot, ctx->sk_proto); + if (sk->sk_write_space == tls_write_space) + sk->sk_write_space = ctx->sk_write_space; + write_unlock_bh(&sk->sk_callback_lock); + release_sock(sk); + if (ctx->tx_conf == TLS_SW) + tls_sw_free_ctx_tx(ctx); + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) + tls_sw_strparser_done(ctx); + if (ctx->rx_conf == TLS_SW) + tls_sw_free_ctx_rx(ctx); + ctx->sk_proto->close(sk, timeout); + + if (free_ctx) + tls_ctx_free(sk, ctx); +} + +static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, + int __user *optlen, int tx) +{ + int rc = 0; + struct tls_context *ctx = tls_get_ctx(sk); + struct tls_crypto_info *crypto_info; + struct cipher_context *cctx; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + if (!optval || (len < sizeof(*crypto_info))) { + rc = -EINVAL; + goto out; + } + + if (!ctx) { + rc = -EBUSY; + goto out; + } + + /* get user crypto info */ + if (tx) { + crypto_info = &ctx->crypto_send.info; + cctx = &ctx->tx; + } else { + crypto_info = &ctx->crypto_recv.info; + cctx = &ctx->rx; + } + + if (!TLS_CRYPTO_INFO_READY(crypto_info)) { + rc = -EBUSY; + goto out; + } + + if (len == sizeof(*crypto_info)) { + if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) + rc = -EFAULT; + goto out; + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: { + struct tls12_crypto_info_aes_gcm_128 * + crypto_info_aes_gcm_128 = + container_of(crypto_info, + struct tls12_crypto_info_aes_gcm_128, + info); + + if (len != sizeof(*crypto_info_aes_gcm_128)) { + rc = -EINVAL; + goto out; + } + memcpy(crypto_info_aes_gcm_128->iv, + cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + TLS_CIPHER_AES_GCM_128_IV_SIZE); + memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq, + TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE); + if 
(copy_to_user(optval, + crypto_info_aes_gcm_128, + sizeof(*crypto_info_aes_gcm_128))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_AES_GCM_256: { + struct tls12_crypto_info_aes_gcm_256 * + crypto_info_aes_gcm_256 = + container_of(crypto_info, + struct tls12_crypto_info_aes_gcm_256, + info); + + if (len != sizeof(*crypto_info_aes_gcm_256)) { + rc = -EINVAL; + goto out; + } + memcpy(crypto_info_aes_gcm_256->iv, + cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE, + TLS_CIPHER_AES_GCM_256_IV_SIZE); + memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq, + TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE); + if (copy_to_user(optval, + crypto_info_aes_gcm_256, + sizeof(*crypto_info_aes_gcm_256))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_AES_CCM_128: { + struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 = + container_of(crypto_info, + struct tls12_crypto_info_aes_ccm_128, info); + + if (len != sizeof(*aes_ccm_128)) { + rc = -EINVAL; + goto out; + } + memcpy(aes_ccm_128->iv, + cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE, + TLS_CIPHER_AES_CCM_128_IV_SIZE); + memcpy(aes_ccm_128->rec_seq, cctx->rec_seq, + TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE); + if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_CHACHA20_POLY1305: { + struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 = + container_of(crypto_info, + struct tls12_crypto_info_chacha20_poly1305, + info); + + if (len != sizeof(*chacha20_poly1305)) { + rc = -EINVAL; + goto out; + } + memcpy(chacha20_poly1305->iv, + cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE, + TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE); + memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq, + TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE); + if (copy_to_user(optval, chacha20_poly1305, + sizeof(*chacha20_poly1305))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_SM4_GCM: { + struct tls12_crypto_info_sm4_gcm *sm4_gcm_info = + container_of(crypto_info, + struct tls12_crypto_info_sm4_gcm, info); + + if (len != sizeof(*sm4_gcm_info)) { + rc = -EINVAL; + goto out; + } + memcpy(sm4_gcm_info->iv, + cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE, + TLS_CIPHER_SM4_GCM_IV_SIZE); + memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq, + TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE); + if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_SM4_CCM: { + struct tls12_crypto_info_sm4_ccm *sm4_ccm_info = + container_of(crypto_info, + struct tls12_crypto_info_sm4_ccm, info); + + if (len != sizeof(*sm4_ccm_info)) { + rc = -EINVAL; + goto out; + } + memcpy(sm4_ccm_info->iv, + cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE, + TLS_CIPHER_SM4_CCM_IV_SIZE); + memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq, + TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE); + if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_ARIA_GCM_128: { + struct tls12_crypto_info_aria_gcm_128 * + crypto_info_aria_gcm_128 = + container_of(crypto_info, + struct tls12_crypto_info_aria_gcm_128, + info); + + if (len != sizeof(*crypto_info_aria_gcm_128)) { + rc = -EINVAL; + goto out; + } + memcpy(crypto_info_aria_gcm_128->iv, + cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE, + TLS_CIPHER_ARIA_GCM_128_IV_SIZE); + memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq, + TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE); + if (copy_to_user(optval, + crypto_info_aria_gcm_128, + sizeof(*crypto_info_aria_gcm_128))) + rc = -EFAULT; + break; + } + case TLS_CIPHER_ARIA_GCM_256: { + struct tls12_crypto_info_aria_gcm_256 * + crypto_info_aria_gcm_256 = + 
container_of(crypto_info, + struct tls12_crypto_info_aria_gcm_256, + info); + + if (len != sizeof(*crypto_info_aria_gcm_256)) { + rc = -EINVAL; + goto out; + } + memcpy(crypto_info_aria_gcm_256->iv, + cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE, + TLS_CIPHER_ARIA_GCM_256_IV_SIZE); + memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq, + TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE); + if (copy_to_user(optval, + crypto_info_aria_gcm_256, + sizeof(*crypto_info_aria_gcm_256))) + rc = -EFAULT; + break; + } + default: + rc = -EINVAL; + } + +out: + return rc; +} + +static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, + int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + unsigned int value; + int len; + + if (get_user(len, optlen)) + return -EFAULT; + + if (len != sizeof(value)) + return -EINVAL; + + value = ctx->zerocopy_sendfile; + if (copy_to_user(optval, &value, sizeof(value))) + return -EFAULT; + + return 0; +} + +static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, + int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + int value, len; + + if (ctx->prot_info.version != TLS_1_3_VERSION) + return -EINVAL; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < sizeof(value)) + return -EINVAL; + + value = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) + value = ctx->rx_no_pad; + if (value < 0) + return value; + + if (put_user(sizeof(value), optlen)) + return -EFAULT; + if (copy_to_user(optval, &value, sizeof(value))) + return -EFAULT; + + return 0; +} + +static int do_tls_getsockopt(struct sock *sk, int optname, + char __user *optval, int __user *optlen) +{ + int rc = 0; + + lock_sock(sk); + + switch (optname) { + case TLS_TX: + case TLS_RX: + rc = do_tls_getsockopt_conf(sk, optval, optlen, + optname == TLS_TX); + break; + case TLS_TX_ZEROCOPY_RO: + rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); + break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_getsockopt_no_pad(sk, optval, optlen); + break; + default: + rc = -ENOPROTOOPT; + break; + } + + release_sock(sk); + + return rc; +} + +static int tls_getsockopt(struct sock *sk, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (level != SOL_TLS) + return ctx->sk_proto->getsockopt(sk, level, + optname, optval, optlen); + + return do_tls_getsockopt(sk, optname, optval, optlen); +} + +static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, + unsigned int optlen, int tx) +{ + struct tls_crypto_info *crypto_info; + struct tls_crypto_info *alt_crypto_info; + struct tls_context *ctx = tls_get_ctx(sk); + size_t optsize; + int rc = 0; + int conf; + + if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) + return -EINVAL; + + if (tx) { + crypto_info = &ctx->crypto_send.info; + alt_crypto_info = &ctx->crypto_recv.info; + } else { + crypto_info = &ctx->crypto_recv.info; + alt_crypto_info = &ctx->crypto_send.info; + } + + /* Currently we don't support set crypto info more than one time */ + if (TLS_CRYPTO_INFO_READY(crypto_info)) + return -EBUSY; + + rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); + if (rc) { + rc = -EFAULT; + goto err_crypto_info; + } + + /* check version */ + if (crypto_info->version != TLS_1_2_VERSION && + crypto_info->version != TLS_1_3_VERSION) { + rc = -EINVAL; + goto err_crypto_info; + } + + /* Ensure that TLS version and ciphers are same in both directions */ + if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { + if 
(alt_crypto_info->version != crypto_info->version || + alt_crypto_info->cipher_type != crypto_info->cipher_type) { + rc = -EINVAL; + goto err_crypto_info; + } + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: + optsize = sizeof(struct tls12_crypto_info_aes_gcm_128); + break; + case TLS_CIPHER_AES_GCM_256: { + optsize = sizeof(struct tls12_crypto_info_aes_gcm_256); + break; + } + case TLS_CIPHER_AES_CCM_128: + optsize = sizeof(struct tls12_crypto_info_aes_ccm_128); + break; + case TLS_CIPHER_CHACHA20_POLY1305: + optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305); + break; + case TLS_CIPHER_SM4_GCM: + optsize = sizeof(struct tls12_crypto_info_sm4_gcm); + break; + case TLS_CIPHER_SM4_CCM: + optsize = sizeof(struct tls12_crypto_info_sm4_ccm); + break; + case TLS_CIPHER_ARIA_GCM_128: + if (crypto_info->version != TLS_1_2_VERSION) { + rc = -EINVAL; + goto err_crypto_info; + } + optsize = sizeof(struct tls12_crypto_info_aria_gcm_128); + break; + case TLS_CIPHER_ARIA_GCM_256: + if (crypto_info->version != TLS_1_2_VERSION) { + rc = -EINVAL; + goto err_crypto_info; + } + optsize = sizeof(struct tls12_crypto_info_aria_gcm_256); + break; + default: + rc = -EINVAL; + goto err_crypto_info; + } + + if (optlen != optsize) { + rc = -EINVAL; + goto err_crypto_info; + } + + rc = copy_from_sockptr_offset(crypto_info + 1, optval, + sizeof(*crypto_info), + optlen - sizeof(*crypto_info)); + if (rc) { + rc = -EFAULT; + goto err_crypto_info; + } + + if (tx) { + rc = tls_set_device_offload(sk, ctx); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); + } else { + rc = tls_set_sw_offload(sk, ctx, 1); + if (rc) + goto err_crypto_info; + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); + conf = TLS_SW; + } + } else { + rc = tls_set_device_offload_rx(sk, ctx); + conf = TLS_HW; + if (!rc) { + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); + } else { + rc = tls_set_sw_offload(sk, ctx, 0); + if (rc) + goto err_crypto_info; + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); + conf = TLS_SW; + } + tls_sw_strparser_arm(sk, ctx); + } + + if (tx) + ctx->tx_conf = conf; + else + ctx->rx_conf = conf; + update_sk_prot(sk, ctx); + if (tx) { + ctx->sk_write_space = sk->sk_write_space; + sk->sk_write_space = tls_write_space; + } else { + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); + + tls_strp_check_rcv(&rx_ctx->strp); + } + return 0; + +err_crypto_info: + memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); + return rc; +} + +static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, + unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + unsigned int value; + + if (sockptr_is_null(optval) || optlen != sizeof(value)) + return -EINVAL; + + if (copy_from_sockptr(&value, optval, sizeof(value))) + return -EFAULT; + + if (value > 1) + return -EINVAL; + + ctx->zerocopy_sendfile = value; + + return 0; +} + +static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, + unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + u32 val; + int rc; + + if (ctx->prot_info.version != TLS_1_3_VERSION || + sockptr_is_null(optval) || optlen < sizeof(val)) + return -EINVAL; + + rc = copy_from_sockptr(&val, optval, sizeof(val)); + if (rc) + return -EFAULT; + if (val > 1) + return -EINVAL; 
+ rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); + if (rc < 1) + return rc == 0 ? -EINVAL : rc; + + lock_sock(sk); + rc = -EINVAL; + if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { + ctx->rx_no_pad = val; + tls_update_rx_zc_capable(ctx); + rc = 0; + } + release_sock(sk); + + return rc; +} + +static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, + unsigned int optlen) +{ + int rc = 0; + + switch (optname) { + case TLS_TX: + case TLS_RX: + lock_sock(sk); + rc = do_tls_setsockopt_conf(sk, optval, optlen, + optname == TLS_TX); + release_sock(sk); + break; + case TLS_TX_ZEROCOPY_RO: + lock_sock(sk); + rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); + release_sock(sk); + break; + case TLS_RX_EXPECT_NO_PAD: + rc = do_tls_setsockopt_no_pad(sk, optval, optlen); + break; + default: + rc = -ENOPROTOOPT; + break; + } + return rc; +} + +static int tls_setsockopt(struct sock *sk, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct tls_context *ctx = tls_get_ctx(sk); + + if (level != SOL_TLS) + return ctx->sk_proto->setsockopt(sk, level, optname, optval, + optlen); + + return do_tls_setsockopt(sk, optname, optval, optlen); +} + +struct tls_context *tls_ctx_create(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tls_context *ctx; + + ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); + if (!ctx) + return NULL; + + mutex_init(&ctx->tx_lock); + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + ctx->sk_proto = READ_ONCE(sk->sk_prot); + ctx->sk = sk; + return ctx; +} + +static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + const struct proto_ops *base) +{ + ops[TLS_BASE][TLS_BASE] = *base; + + ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked; + + ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; + + ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; + ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; + +#ifdef CONFIG_TLS_DEVICE + ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; + ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL; + + ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; + ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL; + + ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; + + ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; + + ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; + ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL; +#endif +#ifdef CONFIG_TLS_TOE + ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; +#endif +} + +static void tls_build_proto(struct sock *sk) +{ + int ip_ver = sk->sk_family == AF_INET6 ? 
TLSV6 : TLSV4; + struct proto *prot = READ_ONCE(sk->sk_prot); + + /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ + if (ip_ver == TLSV6 && + unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { + mutex_lock(&tcpv6_prot_mutex); + if (likely(prot != saved_tcpv6_prot)) { + build_protos(tls_prots[TLSV6], prot); + build_proto_ops(tls_proto_ops[TLSV6], + sk->sk_socket->ops); + smp_store_release(&saved_tcpv6_prot, prot); + } + mutex_unlock(&tcpv6_prot_mutex); + } + + if (ip_ver == TLSV4 && + unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { + mutex_lock(&tcpv4_prot_mutex); + if (likely(prot != saved_tcpv4_prot)) { + build_protos(tls_prots[TLSV4], prot); + build_proto_ops(tls_proto_ops[TLSV4], + sk->sk_socket->ops); + smp_store_release(&saved_tcpv4_prot, prot); + } + mutex_unlock(&tcpv4_prot_mutex); + } +} + +static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], + const struct proto *base) +{ + prot[TLS_BASE][TLS_BASE] = *base; + prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; + prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; + prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; + prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; + + prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; + prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; + + prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; + prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; + prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; + prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; + +#ifdef CONFIG_TLS_DEVICE + prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; + prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; + + prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; + prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; + prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; + + prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; + + prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; + + prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; +#endif +#ifdef CONFIG_TLS_TOE + prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; + prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; + prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; +#endif +} + +static int tls_init(struct sock *sk) +{ + struct tls_context *ctx; + int rc = 0; + + tls_build_proto(sk); + +#ifdef CONFIG_TLS_TOE + if (tls_toe_bypass(sk)) + return 0; +#endif + + /* The TLS ulp is currently supported only for TCP sockets + * in ESTABLISHED state. + * Supporting sockets in LISTEN state will require us + * to modify the accept implementation to clone rather then + * share the ulp context. 
+ */ + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + /* allocate tls context */ + write_lock_bh(&sk->sk_callback_lock); + ctx = tls_ctx_create(sk); + if (!ctx) { + rc = -ENOMEM; + goto out; + } + + ctx->tx_conf = TLS_BASE; + ctx->rx_conf = TLS_BASE; + update_sk_prot(sk, ctx); +out: + write_unlock_bh(&sk->sk_callback_lock); + return rc; +} + +static void tls_update(struct sock *sk, struct proto *p, + void (*write_space)(struct sock *sk)) +{ + struct tls_context *ctx; + + WARN_ON_ONCE(sk->sk_prot == p); + + ctx = tls_get_ctx(sk); + if (likely(ctx)) { + ctx->sk_write_space = write_space; + ctx->sk_proto = p; + } else { + /* Pairs with lockless read in sk_clone_lock(). */ + WRITE_ONCE(sk->sk_prot, p); + sk->sk_write_space = write_space; + } +} + +static u16 tls_user_config(struct tls_context *ctx, bool tx) +{ + u16 config = tx ? ctx->tx_conf : ctx->rx_conf; + + switch (config) { + case TLS_BASE: + return TLS_CONF_BASE; + case TLS_SW: + return TLS_CONF_SW; + case TLS_HW: + return TLS_CONF_HW; + case TLS_HW_RECORD: + return TLS_CONF_HW_RECORD; + } + return 0; +} + +static int tls_get_info(const struct sock *sk, struct sk_buff *skb) +{ + u16 version, cipher_type; + struct tls_context *ctx; + struct nlattr *start; + int err; + + start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); + if (!start) + return -EMSGSIZE; + + rcu_read_lock(); + ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); + if (!ctx) { + err = 0; + goto nla_failure; + } + version = ctx->prot_info.version; + if (version) { + err = nla_put_u16(skb, TLS_INFO_VERSION, version); + if (err) + goto nla_failure; + } + cipher_type = ctx->prot_info.cipher_type; + if (cipher_type) { + err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); + if (err) + goto nla_failure; + } + err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); + if (err) + goto nla_failure; + + err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); + if (err) + goto nla_failure; + + if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { + err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); + if (err) + goto nla_failure; + } + if (ctx->rx_no_pad) { + err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); + if (err) + goto nla_failure; + } + + rcu_read_unlock(); + nla_nest_end(skb, start); + return 0; + +nla_failure: + rcu_read_unlock(); + nla_nest_cancel(skb, start); + return err; +} + +static size_t tls_get_info_size(const struct sock *sk) +{ + size_t size = 0; + + size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ + nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ + nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ + nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ + 0; + + return size; +} + +static int __net_init tls_init_net(struct net *net) +{ + int err; + + net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); + if (!net->mib.tls_statistics) + return -ENOMEM; + + err = tls_proc_init(net); + if (err) + goto err_free_stats; + + return 0; +err_free_stats: + free_percpu(net->mib.tls_statistics); + return err; +} + +static void __net_exit tls_exit_net(struct net *net) +{ + tls_proc_fini(net); + free_percpu(net->mib.tls_statistics); +} + +static struct pernet_operations tls_proc_ops = { + .init = tls_init_net, + .exit = tls_exit_net, +}; + +static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { + .name = "tls", + .owner = THIS_MODULE, + .init = tls_init, + .update = tls_update, + 
.get_info = tls_get_info, + .get_info_size = tls_get_info_size, +}; + +static int __init tls_register(void) +{ + int err; + + err = register_pernet_subsys(&tls_proc_ops); + if (err) + return err; + + err = tls_strp_dev_init(); + if (err) + goto err_pernet; + + err = tls_device_init(); + if (err) + goto err_strp; + + tcp_register_ulp(&tcp_tls_ulp_ops); + + return 0; +err_strp: + tls_strp_dev_exit(); +err_pernet: + unregister_pernet_subsys(&tls_proc_ops); + return err; +} + +static void __exit tls_unregister(void) +{ + tcp_unregister_ulp(&tcp_tls_ulp_ops); + tls_strp_dev_exit(); + tls_device_cleanup(); + unregister_pernet_subsys(&tls_proc_ops); +} + +module_init(tls_register); +module_exit(tls_unregister); diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c new file mode 100644 index 000000000..68982728f --- /dev/null +++ b/net/tls/tls_proc.c @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause) +/* Copyright (C) 2019 Netronome Systems, Inc. */ + +#include +#include +#include +#include + +#include "tls.h" + +#ifdef CONFIG_PROC_FS +static const struct snmp_mib tls_mib_list[] = { + SNMP_MIB_ITEM("TlsCurrTxSw", LINUX_MIB_TLSCURRTXSW), + SNMP_MIB_ITEM("TlsCurrRxSw", LINUX_MIB_TLSCURRRXSW), + SNMP_MIB_ITEM("TlsCurrTxDevice", LINUX_MIB_TLSCURRTXDEVICE), + SNMP_MIB_ITEM("TlsCurrRxDevice", LINUX_MIB_TLSCURRRXDEVICE), + SNMP_MIB_ITEM("TlsTxSw", LINUX_MIB_TLSTXSW), + SNMP_MIB_ITEM("TlsRxSw", LINUX_MIB_TLSRXSW), + SNMP_MIB_ITEM("TlsTxDevice", LINUX_MIB_TLSTXDEVICE), + SNMP_MIB_ITEM("TlsRxDevice", LINUX_MIB_TLSRXDEVICE), + SNMP_MIB_ITEM("TlsDecryptError", LINUX_MIB_TLSDECRYPTERROR), + SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC), + SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIB_TLSDECRYPTRETRY), + SNMP_MIB_ITEM("TlsRxNoPadViolation", LINUX_MIB_TLSRXNOPADVIOL), + SNMP_MIB_SENTINEL +}; + +static int tls_statistics_seq_show(struct seq_file *seq, void *v) +{ + unsigned long buf[LINUX_MIB_TLSMAX] = {}; + struct net *net = seq->private; + int i; + + snmp_get_cpu_field_batch(buf, tls_mib_list, net->mib.tls_statistics); + for (i = 0; tls_mib_list[i].name; i++) + seq_printf(seq, "%-32s\t%lu\n", tls_mib_list[i].name, buf[i]); + + return 0; +} +#endif + +int __net_init tls_proc_init(struct net *net) +{ +#ifdef CONFIG_PROC_FS + if (!proc_create_net_single("tls_stat", 0444, net->proc_net, + tls_statistics_seq_show, NULL)) + return -ENOMEM; +#endif /* CONFIG_PROC_FS */ + + return 0; +} + +void __net_exit tls_proc_fini(struct net *net) +{ + remove_proc_entry("tls_stat", net->proc_net); +} diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c new file mode 100644 index 000000000..f37f4a0fc --- /dev/null +++ b/net/tls/tls_strp.c @@ -0,0 +1,633 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* Copyright (c) 2016 Tom Herbert */ + +#include +#include +#include +#include +#include +#include + +#include "tls.h" + +static struct workqueue_struct *tls_strp_wq; + +static void tls_strp_abort_strp(struct tls_strparser *strp, int err) +{ + if (strp->stopped) + return; + + strp->stopped = 1; + + /* Report an error on the lower socket */ + WRITE_ONCE(strp->sk->sk_err, -err); + /* Paired with smp_rmb() in tcp_poll() */ + smp_wmb(); + sk_error_report(strp->sk); +} + +static void tls_strp_anchor_free(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + if (!strp->copy_mode) + shinfo->frag_list = NULL; + consume_skb(strp->anchor); + strp->anchor = NULL; +} + +static struct sk_buff * 
+tls_strp_skb_copy(struct tls_strparser *strp, struct sk_buff *in_skb, + int offset, int len) +{ + struct sk_buff *skb; + int i, err; + + skb = alloc_skb_with_frags(0, len, TLS_PAGE_ORDER, + &err, strp->sk->sk_allocation); + if (!skb) + return NULL; + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + skb_frag_t *frag = &skb_shinfo(skb)->frags[i]; + + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag), + skb_frag_size(frag))); + offset += skb_frag_size(frag); + } + + skb->len = len; + skb->data_len = len; + skb_copy_header(skb, in_skb); + return skb; +} + +/* Create a new skb with the contents of input copied to its page frags */ +static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp) +{ + struct strp_msg *rxm; + struct sk_buff *skb; + + skb = tls_strp_skb_copy(strp, strp->anchor, strp->stm.offset, + strp->stm.full_len); + if (!skb) + return NULL; + + rxm = strp_msg(skb); + rxm->offset = 0; + return skb; +} + +/* Steal the input skb, input msg is invalid after calling this function */ +struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx) +{ + struct tls_strparser *strp = &ctx->strp; + +#ifdef CONFIG_TLS_DEVICE + DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted); +#else + /* This function turns an input into an output, + * that can only happen if we have offload. + */ + WARN_ON(1); +#endif + + if (strp->copy_mode) { + struct sk_buff *skb; + + /* Replace anchor with an empty skb, this is a little + * dangerous but __tls_cur_msg() warns on empty skbs + * so hopefully we'll catch abuses. + */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return NULL; + + swap(strp->anchor, skb); + return skb; + } + + return tls_strp_msg_make_copy(strp); +} + +/* Force the input skb to be in copy mode. The data ownership remains + * with the input skb itself (meaning unpause will wipe it) but it can + * be modified. + */ +int tls_strp_msg_cow(struct tls_sw_context_rx *ctx) +{ + struct tls_strparser *strp = &ctx->strp; + struct sk_buff *skb; + + if (strp->copy_mode) + return 0; + + skb = tls_strp_msg_make_copy(strp); + if (!skb) + return -ENOMEM; + + tls_strp_anchor_free(strp); + strp->anchor = skb; + + tcp_read_done(strp->sk, strp->stm.full_len); + strp->copy_mode = 1; + + return 0; +} + +/* Make a clone (in the skb sense) of the input msg to keep a reference + * to the underlying data. The reference-holding skbs get placed on + * @dst. 
+ */ +int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + + if (strp->copy_mode) { + struct sk_buff *skb; + + WARN_ON_ONCE(!shinfo->nr_frags); + + /* We can't skb_clone() the anchor, it gets wiped by unpause */ + skb = alloc_skb(0, strp->sk->sk_allocation); + if (!skb) + return -ENOMEM; + + __skb_queue_tail(dst, strp->anchor); + strp->anchor = skb; + } else { + struct sk_buff *iter, *clone; + int chunk, len, offset; + + offset = strp->stm.offset; + len = strp->stm.full_len; + iter = shinfo->frag_list; + + while (len > 0) { + if (iter->len <= offset) { + offset -= iter->len; + goto next; + } + + chunk = iter->len - offset; + offset = 0; + + clone = skb_clone(iter, strp->sk->sk_allocation); + if (!clone) + return -ENOMEM; + __skb_queue_tail(dst, clone); + + len -= chunk; +next: + iter = iter->next; + } + } + + return 0; +} + +static void tls_strp_flush_anchor_copy(struct tls_strparser *strp) +{ + struct skb_shared_info *shinfo = skb_shinfo(strp->anchor); + int i; + + DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1); + + for (i = 0; i < shinfo->nr_frags; i++) + __skb_frag_unref(&shinfo->frags[i], false); + shinfo->nr_frags = 0; + if (strp->copy_mode) { + kfree_skb_list(shinfo->frag_list); + shinfo->frag_list = NULL; + } + strp->copy_mode = 0; + strp->mixed_decrypted = 0; +} + +static int tls_strp_copyin_frag(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) +{ + size_t len, chunk; + skb_frag_t *frag; + int sz; + + frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE]; + + len = in_len; + /* First make sure we got the header */ + if (!strp->stm.full_len) { + /* Assume one page is more than enough for headers */ + chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + + sz = tls_rx_msg_size(strp, skb); + if (sz < 0) + return sz; + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (unlikely(sz && sz < skb->len)) { + int over = skb->len - sz; + + WARN_ON_ONCE(over > chunk); + skb->len -= over; + skb->data_len -= over; + skb_frag_size_add(frag, -over); + + chunk -= over; + } + + frag++; + len -= chunk; + offset += chunk; + + strp->stm.full_len = sz; + if (!strp->stm.full_len) + goto read_done; + } + + /* Load up more data */ + while (len && strp->stm.full_len > skb->len) { + chunk = min_t(size_t, len, strp->stm.full_len - skb->len); + chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag)); + WARN_ON_ONCE(skb_copy_bits(in_skb, offset, + skb_frag_address(frag) + + skb_frag_size(frag), + chunk)); + + skb->len += chunk; + skb->data_len += chunk; + skb_frag_size_add(frag, chunk); + frag++; + len -= chunk; + offset += chunk; + } + +read_done: + return in_len - len; +} + +static int tls_strp_copyin_skb(struct tls_strparser *strp, struct sk_buff *skb, + struct sk_buff *in_skb, unsigned int offset, + size_t in_len) +{ + struct sk_buff *nskb, *first, *last; + struct skb_shared_info *shinfo; + size_t chunk; + int sz; + + if (strp->stm.full_len) + chunk = strp->stm.full_len - skb->len; + else + chunk = TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + chunk = min(chunk, in_len); + + nskb = tls_strp_skb_copy(strp, in_skb, offset, chunk); + if (!nskb) + return -ENOMEM; + + shinfo = skb_shinfo(skb); + if (!shinfo->frag_list) { + shinfo->frag_list = 
nskb; + nskb->prev = nskb; + } else { + first = shinfo->frag_list; + last = first->prev; + last->next = nskb; + first->prev = nskb; + } + + skb->len += chunk; + skb->data_len += chunk; + + if (!strp->stm.full_len) { + sz = tls_rx_msg_size(strp, skb); + if (sz < 0) + return sz; + + /* We may have over-read, sz == 0 is guaranteed under-read */ + if (unlikely(sz && sz < skb->len)) { + int over = skb->len - sz; + + WARN_ON_ONCE(over > chunk); + skb->len -= over; + skb->data_len -= over; + __pskb_trim(nskb, nskb->len - over); + + chunk -= over; + } + + strp->stm.full_len = sz; + } + + return chunk; +} + +static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb, + unsigned int offset, size_t in_len) +{ + struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data; + struct sk_buff *skb; + int ret; + + if (strp->msg_ready) + return 0; + + skb = strp->anchor; + if (!skb->len) + skb_copy_decrypted(skb, in_skb); + else + strp->mixed_decrypted |= !!skb_cmp_decrypted(skb, in_skb); + + if (IS_ENABLED(CONFIG_TLS_DEVICE) && strp->mixed_decrypted) + ret = tls_strp_copyin_skb(strp, skb, in_skb, offset, in_len); + else + ret = tls_strp_copyin_frag(strp, skb, in_skb, offset, in_len); + if (ret < 0) { + desc->error = ret; + ret = 0; + } + + if (strp->stm.full_len && strp->stm.full_len == skb->len) { + desc->count = 0; + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + } + + return ret; +} + +static int tls_strp_read_copyin(struct tls_strparser *strp) +{ + struct socket *sock = strp->sk->sk_socket; + read_descriptor_t desc; + + desc.arg.data = strp; + desc.error = 0; + desc.count = 1; /* give more than one skb per call */ + + /* sk should be locked here, so okay to do read_sock */ + sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); + + return desc.error; +} + +static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort) +{ + struct skb_shared_info *shinfo; + struct page *page; + int need_spc, len; + + /* If the rbuf is small or rcv window has collapsed to 0 we need + * to read the data out. Otherwise the connection will stall. + * Without pressure threshold of INT_MAX will never be ready. + */ + if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX))) + return 0; + + shinfo = skb_shinfo(strp->anchor); + shinfo->frag_list = NULL; + + /* If we don't know the length go max plus page for cipher overhead */ + need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE; + + for (len = need_spc; len > 0; len -= PAGE_SIZE) { + page = alloc_page(strp->sk->sk_allocation); + if (!page) { + tls_strp_flush_anchor_copy(strp); + return -ENOMEM; + } + + skb_fill_page_desc(strp->anchor, shinfo->nr_frags++, + page, 0, 0); + } + + strp->copy_mode = 1; + strp->stm.offset = 0; + + strp->anchor->len = 0; + strp->anchor->data_len = 0; + strp->anchor->truesize = round_up(need_spc, PAGE_SIZE); + + tls_strp_read_copyin(strp); + + return 0; +} + +static bool tls_strp_check_queue_ok(struct tls_strparser *strp) +{ + unsigned int len = strp->stm.offset + strp->stm.full_len; + struct sk_buff *first, *skb; + u32 seq; + + first = skb_shinfo(strp->anchor)->frag_list; + skb = first; + seq = TCP_SKB_CB(first)->seq; + + /* Make sure there's no duplicate data in the queue, + * and the decrypted status matches. 
+ */ + while (skb->len < len) { + seq += skb->len; + len -= skb->len; + skb = skb->next; + + if (TCP_SKB_CB(skb)->seq != seq) + return false; + if (skb_cmp_decrypted(first, skb)) + return false; + } + + return true; +} + +static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len) +{ + struct tcp_sock *tp = tcp_sk(strp->sk); + struct sk_buff *first; + u32 offset; + + first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset); + if (WARN_ON_ONCE(!first)) + return; + + /* Bestow the state onto the anchor */ + strp->anchor->len = offset + len; + strp->anchor->data_len = offset + len; + strp->anchor->truesize = offset + len; + + skb_shinfo(strp->anchor)->frag_list = first; + + skb_copy_header(strp->anchor, first); + strp->anchor->destructor = NULL; + + strp->stm.offset = offset; +} + +void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh) +{ + struct strp_msg *rxm; + struct tls_msg *tlm; + + DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready); + DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len); + + if (!strp->copy_mode && force_refresh) { + if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len)) + return; + + tls_strp_load_anchor_with_queue(strp, strp->stm.full_len); + } + + rxm = strp_msg(strp->anchor); + rxm->full_len = strp->stm.full_len; + rxm->offset = strp->stm.offset; + tlm = tls_msg(strp->anchor); + tlm->control = strp->mark; +} + +/* Called with lock held on lower socket */ +static int tls_strp_read_sock(struct tls_strparser *strp) +{ + int sz, inq; + + inq = tcp_inq(strp->sk); + if (inq < 1) + return 0; + + if (unlikely(strp->copy_mode)) + return tls_strp_read_copyin(strp); + + if (inq < strp->stm.full_len) + return tls_strp_read_copy(strp, true); + + if (!strp->stm.full_len) { + tls_strp_load_anchor_with_queue(strp, inq); + + sz = tls_rx_msg_size(strp, strp->anchor); + if (sz < 0) { + tls_strp_abort_strp(strp, sz); + return sz; + } + + strp->stm.full_len = sz; + + if (!strp->stm.full_len || inq < strp->stm.full_len) + return tls_strp_read_copy(strp, true); + } + + if (!tls_strp_check_queue_ok(strp)) + return tls_strp_read_copy(strp, false); + + strp->msg_ready = 1; + tls_rx_msg_ready(strp); + + return 0; +} + +void tls_strp_check_rcv(struct tls_strparser *strp) +{ + if (unlikely(strp->stopped) || strp->msg_ready) + return; + + if (tls_strp_read_sock(strp) == -ENOMEM) + queue_work(tls_strp_wq, &strp->work); +} + +/* Lower sock lock held */ +void tls_strp_data_ready(struct tls_strparser *strp) +{ + /* This check is needed to synchronize with do_tls_strp_work. + * do_tls_strp_work acquires a process lock (lock_sock) whereas + * the lock held here is bh_lock_sock. The two locks can be + * held by different threads at the same time, but bh_lock_sock + * allows a thread in BH context to safely check if the process + * lock is held. In this case, if the lock is held, queue work. 
+ */ + if (sock_owned_by_user_nocheck(strp->sk)) { + queue_work(tls_strp_wq, &strp->work); + return; + } + + tls_strp_check_rcv(strp); +} + +static void tls_strp_work(struct work_struct *w) +{ + struct tls_strparser *strp = + container_of(w, struct tls_strparser, work); + + lock_sock(strp->sk); + tls_strp_check_rcv(strp); + release_sock(strp->sk); +} + +void tls_strp_msg_done(struct tls_strparser *strp) +{ + WARN_ON(!strp->stm.full_len); + + if (likely(!strp->copy_mode)) + tcp_read_done(strp->sk, strp->stm.full_len); + else + tls_strp_flush_anchor_copy(strp); + + strp->msg_ready = 0; + memset(&strp->stm, 0, sizeof(strp->stm)); + + tls_strp_check_rcv(strp); +} + +void tls_strp_stop(struct tls_strparser *strp) +{ + strp->stopped = 1; +} + +int tls_strp_init(struct tls_strparser *strp, struct sock *sk) +{ + memset(strp, 0, sizeof(*strp)); + + strp->sk = sk; + + strp->anchor = alloc_skb(0, GFP_KERNEL); + if (!strp->anchor) + return -ENOMEM; + + INIT_WORK(&strp->work, tls_strp_work); + + return 0; +} + +/* strp must already be stopped so that tls_strp_recv will no longer be called. + * Note that tls_strp_done is not called with the lower socket held. + */ +void tls_strp_done(struct tls_strparser *strp) +{ + WARN_ON(!strp->stopped); + + cancel_work_sync(&strp->work); + tls_strp_anchor_free(strp); +} + +int __init tls_strp_dev_init(void) +{ + tls_strp_wq = create_workqueue("tls-strp"); + if (unlikely(!tls_strp_wq)) + return -ENOMEM; + + return 0; +} + +void tls_strp_dev_exit(void) +{ + destroy_workqueue(tls_strp_wq); +} diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c new file mode 100644 index 000000000..0323040d3 --- /dev/null +++ b/net/tls/tls_sw.c @@ -0,0 +1,2818 @@ +/* + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * Copyright (c) 2016-2017, Lance Chao . All rights reserved. + * Copyright (c) 2016, Fridolin Pokorny . All rights reserved. + * Copyright (c) 2016, Nikos Mavrogiannopoulos . All rights reserved. + * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "tls.h" + +struct tls_decrypt_arg { + struct_group(inargs, + bool zc; + bool async; + u8 tail; + ); + + struct sk_buff *skb; +}; + +struct tls_decrypt_ctx { + struct sock *sk; + u8 iv[MAX_IV_SIZE]; + u8 aad[TLS_MAX_AAD_SIZE]; + u8 tail; + struct scatterlist sg[]; +}; + +noinline void tls_err_abort(struct sock *sk, int err) +{ + WARN_ON_ONCE(err >= 0); + /* sk->sk_err should contain a positive error code. */ + WRITE_ONCE(sk->sk_err, -err); + /* Paired with smp_rmb() in tcp_poll() */ + smp_wmb(); + sk_error_report(sk); +} + +static int __skb_nsg(struct sk_buff *skb, int offset, int len, + unsigned int recursion_level) +{ + int start = skb_headlen(skb); + int i, chunk = start - offset; + struct sk_buff *frag_iter; + int elt = 0; + + if (unlikely(recursion_level >= 24)) + return -EMSGSIZE; + + if (chunk > 0) { + if (chunk > len) + chunk = len; + elt++; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + + for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { + int end; + + WARN_ON(start > offset + len); + + end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); + chunk = end - offset; + if (chunk > 0) { + if (chunk > len) + chunk = len; + elt++; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + start = end; + } + + if (unlikely(skb_has_frag_list(skb))) { + skb_walk_frags(skb, frag_iter) { + int end, ret; + + WARN_ON(start > offset + len); + + end = start + frag_iter->len; + chunk = end - offset; + if (chunk > 0) { + if (chunk > len) + chunk = len; + ret = __skb_nsg(frag_iter, offset - start, chunk, + recursion_level + 1); + if (unlikely(ret < 0)) + return ret; + elt += ret; + len -= chunk; + if (len == 0) + return elt; + offset += chunk; + } + start = end; + } + } + BUG_ON(len); + return elt; +} + +/* Return the number of scatterlist elements required to completely map the + * skb, or -EMSGSIZE if the recursion depth is exceeded. + */ +static int skb_nsg(struct sk_buff *skb, int offset, int len) +{ + return __skb_nsg(skb, offset, len, 0); +} + +static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, + struct tls_decrypt_arg *darg) +{ + struct strp_msg *rxm = strp_msg(skb); + struct tls_msg *tlm = tls_msg(skb); + int sub = 0; + + /* Determine zero-padding length */ + if (prot->version == TLS_1_3_VERSION) { + int offset = rxm->full_len - TLS_TAG_SIZE - 1; + char content_type = darg->zc ? 
darg->tail : 0; + int err; + + while (content_type == 0) { + if (offset < prot->prepend_size) + return -EBADMSG; + err = skb_copy_bits(skb, rxm->offset + offset, + &content_type, 1); + if (err) + return err; + if (content_type) + break; + sub++; + offset--; + } + tlm->control = content_type; + } + return sub; +} + +static void tls_decrypt_done(crypto_completion_data_t *data, int err) +{ + struct aead_request *aead_req = crypto_get_completion_data(data); + struct crypto_aead *aead = crypto_aead_reqtfm(aead_req); + struct scatterlist *sgout = aead_req->dst; + struct scatterlist *sgin = aead_req->src; + struct tls_sw_context_rx *ctx; + struct tls_decrypt_ctx *dctx; + struct tls_context *tls_ctx; + struct scatterlist *sg; + unsigned int pages; + struct sock *sk; + int aead_size; + + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead); + aead_size = ALIGN(aead_size, __alignof__(*dctx)); + dctx = (void *)((u8 *)aead_req + aead_size); + + sk = dctx->sk; + tls_ctx = tls_get_ctx(sk); + ctx = tls_sw_ctx_rx(tls_ctx); + + /* Propagate if there was an err */ + if (err) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); + ctx->async_wait.err = err; + tls_err_abort(sk, err); + } + + /* Free the destination pages if skb was not decrypted inplace */ + if (sgout != sgin) { + /* Skip the first S/G entry as it points to AAD */ + for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { + if (!sg) + break; + put_page(sg_page(sg)); + } + } + + kfree(aead_req); + + spin_lock_bh(&ctx->decrypt_compl_lock); + if (!atomic_dec_return(&ctx->decrypt_pending)) + complete(&ctx->async_wait.completion); + spin_unlock_bh(&ctx->decrypt_compl_lock); +} + +static int tls_do_decryption(struct sock *sk, + struct scatterlist *sgin, + struct scatterlist *sgout, + char *iv_recv, + size_t data_len, + struct aead_request *aead_req, + struct tls_decrypt_arg *darg) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + int ret; + + aead_request_set_tfm(aead_req, ctx->aead_recv); + aead_request_set_ad(aead_req, prot->aad_size); + aead_request_set_crypt(aead_req, sgin, sgout, + data_len + prot->tag_size, + (u8 *)iv_recv); + + if (darg->async) { + aead_request_set_callback(aead_req, + CRYPTO_TFM_REQ_MAY_BACKLOG, + tls_decrypt_done, aead_req); + atomic_inc(&ctx->decrypt_pending); + } else { + aead_request_set_callback(aead_req, + CRYPTO_TFM_REQ_MAY_BACKLOG, + crypto_req_done, &ctx->async_wait); + } + + ret = crypto_aead_decrypt(aead_req); + if (ret == -EINPROGRESS) { + if (darg->async) + return 0; + + ret = crypto_wait_req(ret, &ctx->async_wait); + } + darg->async = false; + + return ret; +} + +static void tls_trim_both_msgs(struct sock *sk, int target_size) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + + sk_msg_trim(sk, &rec->msg_plaintext, target_size); + if (target_size > 0) + target_size += prot->overhead_size; + sk_msg_trim(sk, &rec->msg_encrypted, target_size); +} + +static int tls_alloc_encrypted_msg(struct sock *sk, int len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_en = &rec->msg_encrypted; + + return sk_msg_alloc(sk, msg_en, len, 0); +} + +static int tls_clone_plaintext_msg(struct sock *sk, int required) +{ + 
struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_pl = &rec->msg_plaintext; + struct sk_msg *msg_en = &rec->msg_encrypted; + int skip, len; + + /* We add page references worth len bytes from encrypted sg + * at the end of plaintext sg. It is guaranteed that msg_en + * has enough required room (ensured by caller). + */ + len = required - msg_pl->sg.size; + + /* Skip initial bytes in msg_en's data to be able to use + * same offset of both plain and encrypted data. + */ + skip = prot->prepend_size + msg_pl->sg.size; + + return sk_msg_clone(sk, msg_pl, msg_en, skip, len); +} + +static struct tls_rec *tls_get_rec(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct sk_msg *msg_pl, *msg_en; + struct tls_rec *rec; + int mem_size; + + mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); + + rec = kzalloc(mem_size, sk->sk_allocation); + if (!rec) + return NULL; + + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + sk_msg_init(msg_pl); + sk_msg_init(msg_en); + + sg_init_table(rec->sg_aead_in, 2); + sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); + sg_unmark_end(&rec->sg_aead_in[1]); + + sg_init_table(rec->sg_aead_out, 2); + sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); + sg_unmark_end(&rec->sg_aead_out[1]); + + rec->sk = sk; + + return rec; +} + +static void tls_free_rec(struct sock *sk, struct tls_rec *rec) +{ + sk_msg_free(sk, &rec->msg_encrypted); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); +} + +static void tls_free_open_rec(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + + if (rec) { + tls_free_rec(sk, rec); + ctx->open_rec = NULL; + } +} + +int tls_tx_records(struct sock *sk, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec, *tmp; + struct sk_msg *msg_en; + int tx_flags, rc = 0; + + if (tls_is_partially_sent_record(tls_ctx)) { + rec = list_first_entry(&ctx->tx_list, + struct tls_rec, list); + + if (flags == -1) + tx_flags = rec->tx_flags; + else + tx_flags = flags; + + rc = tls_push_partial_record(sk, tls_ctx, tx_flags); + if (rc) + goto tx_err; + + /* Full record has been transmitted. 
+ * Remove the head of tx_list + */ + list_del(&rec->list); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); + } + + /* Tx all ready records */ + list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { + if (READ_ONCE(rec->tx_ready)) { + if (flags == -1) + tx_flags = rec->tx_flags; + else + tx_flags = flags; + + msg_en = &rec->msg_encrypted; + rc = tls_push_sg(sk, tls_ctx, + &msg_en->sg.data[msg_en->sg.curr], + 0, tx_flags); + if (rc) + goto tx_err; + + list_del(&rec->list); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); + } else { + break; + } + } + +tx_err: + if (rc < 0 && rc != -EAGAIN) + tls_err_abort(sk, -EBADMSG); + + return rc; +} + +static void tls_encrypt_done(crypto_completion_data_t *data, int err) +{ + struct aead_request *aead_req = crypto_get_completion_data(data); + struct tls_sw_context_tx *ctx; + struct tls_context *tls_ctx; + struct tls_prot_info *prot; + struct scatterlist *sge; + struct sk_msg *msg_en; + struct tls_rec *rec; + bool ready = false; + struct sock *sk; + int pending; + + rec = container_of(aead_req, struct tls_rec, aead_req); + msg_en = &rec->msg_encrypted; + + sk = rec->sk; + tls_ctx = tls_get_ctx(sk); + prot = &tls_ctx->prot_info; + ctx = tls_sw_ctx_tx(tls_ctx); + + sge = sk_msg_elem(msg_en, msg_en->sg.curr); + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; + + /* Check if error is previously set on socket */ + if (err || sk->sk_err) { + rec = NULL; + + /* If err is already set on socket, return the same code */ + if (sk->sk_err) { + ctx->async_wait.err = -sk->sk_err; + } else { + ctx->async_wait.err = err; + tls_err_abort(sk, err); + } + } + + if (rec) { + struct tls_rec *first_rec; + + /* Mark the record as ready for transmission */ + smp_store_mb(rec->tx_ready, true); + + /* If received record is at head of tx_list, schedule tx */ + first_rec = list_first_entry(&ctx->tx_list, + struct tls_rec, list); + if (rec == first_rec) + ready = true; + } + + spin_lock_bh(&ctx->encrypt_compl_lock); + pending = atomic_dec_return(&ctx->encrypt_pending); + + if (!pending && ctx->async_notify) + complete(&ctx->async_wait.completion); + spin_unlock_bh(&ctx->encrypt_compl_lock); + + if (!ready) + return; + + /* Schedule the transmission */ + if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + schedule_delayed_work(&ctx->tx_work.work, 1); +} + +static int tls_do_encryption(struct sock *sk, + struct tls_context *tls_ctx, + struct tls_sw_context_tx *ctx, + struct aead_request *aead_req, + size_t data_len, u32 start) +{ + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_en = &rec->msg_encrypted; + struct scatterlist *sge = sk_msg_elem(msg_en, start); + int rc, iv_offset = 0; + + /* For CCM based ciphers, first byte of IV is a constant */ + switch (prot->cipher_type) { + case TLS_CIPHER_AES_CCM_128: + rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + case TLS_CIPHER_SM4_CCM: + rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + } + + memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, + prot->iv_size + prot->salt_size); + + tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset, + tls_ctx->tx.rec_seq); + + sge->offset += prot->prepend_size; + sge->length -= prot->prepend_size; + + msg_en->sg.curr = start; + + aead_request_set_tfm(aead_req, ctx->aead_send); + aead_request_set_ad(aead_req, prot->aad_size); + aead_request_set_crypt(aead_req, rec->sg_aead_in, + rec->sg_aead_out, + data_len, rec->iv_data); + + 
aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, + tls_encrypt_done, aead_req); + + /* Add the record in tx_list */ + list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); + atomic_inc(&ctx->encrypt_pending); + + rc = crypto_aead_encrypt(aead_req); + if (!rc || rc != -EINPROGRESS) { + atomic_dec(&ctx->encrypt_pending); + sge->offset -= prot->prepend_size; + sge->length += prot->prepend_size; + } + + if (!rc) { + WRITE_ONCE(rec->tx_ready, true); + } else if (rc != -EINPROGRESS) { + list_del(&rec->list); + return rc; + } + + /* Unhook the record from context if encryption is not failure */ + ctx->open_rec = NULL; + tls_advance_record_sn(sk, prot, &tls_ctx->tx); + return rc; +} + +static int tls_split_open_record(struct sock *sk, struct tls_rec *from, + struct tls_rec **to, struct sk_msg *msg_opl, + struct sk_msg *msg_oen, u32 split_point, + u32 tx_overhead_size, u32 *orig_end) +{ + u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; + struct scatterlist *sge, *osge, *nsge; + u32 orig_size = msg_opl->sg.size; + struct scatterlist tmp = { }; + struct sk_msg *msg_npl; + struct tls_rec *new; + int ret; + + new = tls_get_rec(sk); + if (!new) + return -ENOMEM; + ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + + tx_overhead_size, 0); + if (ret < 0) { + tls_free_rec(sk, new); + return ret; + } + + *orig_end = msg_opl->sg.end; + i = msg_opl->sg.start; + sge = sk_msg_elem(msg_opl, i); + while (apply && sge->length) { + if (sge->length > apply) { + u32 len = sge->length - apply; + + get_page(sg_page(sge)); + sg_set_page(&tmp, sg_page(sge), len, + sge->offset + apply); + sge->length = apply; + bytes += apply; + apply = 0; + } else { + apply -= sge->length; + bytes += sge->length; + } + + sk_msg_iter_var_next(i); + if (i == msg_opl->sg.end) + break; + sge = sk_msg_elem(msg_opl, i); + } + + msg_opl->sg.end = i; + msg_opl->sg.curr = i; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = 0; + msg_opl->sg.size = bytes; + + msg_npl = &new->msg_plaintext; + msg_npl->apply_bytes = apply; + msg_npl->sg.size = orig_size - bytes; + + j = msg_npl->sg.start; + nsge = sk_msg_elem(msg_npl, j); + if (tmp.length) { + memcpy(nsge, &tmp, sizeof(*nsge)); + sk_msg_iter_var_next(j); + nsge = sk_msg_elem(msg_npl, j); + } + + osge = sk_msg_elem(msg_opl, i); + while (osge->length) { + memcpy(nsge, osge, sizeof(*nsge)); + sg_unmark_end(nsge); + sk_msg_iter_var_next(i); + sk_msg_iter_var_next(j); + if (i == *orig_end) + break; + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + } + + msg_npl->sg.end = j; + msg_npl->sg.curr = j; + msg_npl->sg.copybreak = 0; + + *to = new; + return 0; +} + +static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, + struct tls_rec *from, u32 orig_end) +{ + struct sk_msg *msg_npl = &from->msg_plaintext; + struct sk_msg *msg_opl = &to->msg_plaintext; + struct scatterlist *osge, *nsge; + u32 i, j; + + i = msg_opl->sg.end; + sk_msg_iter_var_prev(i); + j = msg_npl->sg.start; + + osge = sk_msg_elem(msg_opl, i); + nsge = sk_msg_elem(msg_npl, j); + + if (sg_page(osge) == sg_page(nsge) && + osge->offset + osge->length == nsge->offset) { + osge->length += nsge->length; + put_page(sg_page(nsge)); + } + + msg_opl->sg.end = orig_end; + msg_opl->sg.curr = orig_end; + msg_opl->sg.copybreak = 0; + msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; + msg_opl->sg.size += msg_npl->sg.size; + + sk_msg_free(sk, &to->msg_encrypted); + sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); + + kfree(from); +} + +static int 
tls_push_record(struct sock *sk, int flags, + unsigned char record_type) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec, *tmp = NULL; + u32 i, split_point, orig_end; + struct sk_msg *msg_pl, *msg_en; + struct aead_request *req; + bool split; + int rc; + + if (!rec) + return 0; + + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + split_point = msg_pl->apply_bytes; + split = split_point && split_point < msg_pl->sg.size; + if (unlikely((!split && + msg_pl->sg.size + + prot->overhead_size > msg_en->sg.size) || + (split && + split_point + + prot->overhead_size > msg_en->sg.size))) { + split = true; + split_point = msg_en->sg.size; + } + if (split) { + rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, + split_point, prot->overhead_size, + &orig_end); + if (rc < 0) + return rc; + /* This can happen if above tls_split_open_record allocates + * a single large encryption buffer instead of two smaller + * ones. In this case adjust pointers and continue without + * split. + */ + if (!msg_pl->sg.size) { + tls_merge_open_record(sk, rec, tmp, orig_end); + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + split = false; + } + sk_msg_trim(sk, msg_en, msg_pl->sg.size + + prot->overhead_size); + } + + rec->tx_flags = flags; + req = &rec->aead_req; + + i = msg_pl->sg.end; + sk_msg_iter_var_prev(i); + + rec->content_type = record_type; + if (prot->version == TLS_1_3_VERSION) { + /* Add content type to end of message. No padding added */ + sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); + sg_mark_end(&rec->sg_content_type); + sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, + &rec->sg_content_type); + } else { + sg_mark_end(sk_msg_elem(msg_pl, i)); + } + + if (msg_pl->sg.end < msg_pl->sg.start) { + sg_chain(&msg_pl->sg.data[msg_pl->sg.start], + MAX_SKB_FRAGS - msg_pl->sg.start + 1, + msg_pl->sg.data); + } + + i = msg_pl->sg.start; + sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); + + i = msg_en->sg.end; + sk_msg_iter_var_prev(i); + sg_mark_end(sk_msg_elem(msg_en, i)); + + i = msg_en->sg.start; + sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); + + tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, + tls_ctx->tx.rec_seq, record_type, prot); + + tls_fill_prepend(tls_ctx, + page_address(sg_page(&msg_en->sg.data[i])) + + msg_en->sg.data[i].offset, + msg_pl->sg.size + prot->tail_size, + record_type); + + tls_ctx->pending_open_record_frags = false; + + rc = tls_do_encryption(sk, tls_ctx, ctx, req, + msg_pl->sg.size + prot->tail_size, i); + if (rc < 0) { + if (rc != -EINPROGRESS) { + tls_err_abort(sk, -EBADMSG); + if (split) { + tls_ctx->pending_open_record_frags = true; + tls_merge_open_record(sk, rec, tmp, orig_end); + } + } + ctx->async_capable = 1; + return rc; + } else if (split) { + msg_pl = &tmp->msg_plaintext; + msg_en = &tmp->msg_encrypted; + sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); + tls_ctx->pending_open_record_frags = true; + ctx->open_rec = tmp; + } + + return tls_tx_records(sk, flags); +} + +static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, + bool full_record, u8 record_type, + ssize_t *copied, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct sk_msg msg_redir = { }; + struct sk_psock *psock; + struct sock *sk_redir; + struct tls_rec *rec; + bool enospc, policy, 
redir_ingress; + int err = 0, send; + u32 delta = 0; + + policy = !(flags & MSG_SENDPAGE_NOPOLICY); + psock = sk_psock_get(sk); + if (!psock || !policy) { + err = tls_push_record(sk, flags, record_type); + if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + err = -sk->sk_err; + } + if (psock) + sk_psock_put(sk, psock); + return err; + } +more_data: + enospc = sk_msg_full(msg); + if (psock->eval == __SK_NONE) { + delta = msg->sg.size; + psock->eval = sk_psock_msg_verdict(sk, psock, msg); + delta -= msg->sg.size; + } + if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && + !enospc && !full_record) { + err = -ENOSPC; + goto out_err; + } + msg->cork_bytes = 0; + send = msg->sg.size; + if (msg->apply_bytes && msg->apply_bytes < send) + send = msg->apply_bytes; + + switch (psock->eval) { + case __SK_PASS: + err = tls_push_record(sk, flags, record_type); + if (err && err != -EINPROGRESS && sk->sk_err == EBADMSG) { + *copied -= sk_msg_free(sk, msg); + tls_free_open_rec(sk); + err = -sk->sk_err; + goto out_err; + } + break; + case __SK_REDIRECT: + redir_ingress = psock->redir_ingress; + sk_redir = psock->sk_redir; + memcpy(&msg_redir, msg, sizeof(*msg)); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + sk_msg_return_zero(sk, msg, send); + msg->sg.size -= send; + release_sock(sk); + err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, + &msg_redir, send, flags); + lock_sock(sk); + if (err < 0) { + *copied -= sk_msg_free_nocharge(sk, &msg_redir); + msg->sg.size = 0; + } + if (msg->sg.size == 0) + tls_free_open_rec(sk); + break; + case __SK_DROP: + default: + sk_msg_free_partial(sk, msg, send); + if (msg->apply_bytes < send) + msg->apply_bytes = 0; + else + msg->apply_bytes -= send; + if (msg->sg.size == 0) + tls_free_open_rec(sk); + *copied -= (send + delta); + err = -EACCES; + } + + if (likely(!err)) { + bool reset_eval = !ctx->open_rec; + + rec = ctx->open_rec; + if (rec) { + msg = &rec->msg_plaintext; + if (!msg->apply_bytes) + reset_eval = true; + } + if (reset_eval) { + psock->eval = __SK_NONE; + if (psock->sk_redir) { + sock_put(psock->sk_redir); + psock->sk_redir = NULL; + } + } + if (rec) + goto more_data; + } + out_err: + sk_psock_put(sk, psock); + return err; +} + +static int tls_sw_push_pending_record(struct sock *sk, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec = ctx->open_rec; + struct sk_msg *msg_pl; + size_t copied; + + if (!rec) + return 0; + + msg_pl = &rec->msg_plaintext; + copied = msg_pl->sg.size; + if (!copied) + return 0; + + return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, + &copied, flags); +} + +int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + bool async_capable = ctx->async_capable; + unsigned char record_type = TLS_RECORD_TYPE_DATA; + bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); + bool eor = !(msg->msg_flags & MSG_MORE); + size_t try_to_copy; + ssize_t copied = 0; + struct sk_msg *msg_pl, *msg_en; + struct tls_rec *rec; + int required_size; + int num_async = 0; + bool full_record; + int record_room; + int num_zc = 0; + int orig_size; + int ret = 0; + int pending; + + if (msg->msg_flags & ~(MSG_MORE 
| MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_CMSG_COMPAT)) + return -EOPNOTSUPP; + + ret = mutex_lock_interruptible(&tls_ctx->tx_lock); + if (ret) + return ret; + lock_sock(sk); + + if (unlikely(msg->msg_controllen)) { + ret = tls_process_cmsg(sk, msg, &record_type); + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret != -EAGAIN) + goto send_end; + } + } + + while (msg_data_left(msg)) { + if (sk->sk_err) { + ret = -sk->sk_err; + goto send_end; + } + + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); + if (!rec) { + ret = -ENOMEM; + goto send_end; + } + + msg_pl = &rec->msg_plaintext; + msg_en = &rec->msg_encrypted; + + orig_size = msg_pl->sg.size; + full_record = false; + try_to_copy = msg_data_left(msg); + record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; + if (try_to_copy >= record_room) { + try_to_copy = record_room; + full_record = true; + } + + required_size = msg_pl->sg.size + try_to_copy + + prot->overhead_size; + + if (!sk_stream_memory_free(sk)) + goto wait_for_sndbuf; + +alloc_encrypted: + ret = tls_alloc_encrypted_msg(sk, required_size); + if (ret) { + if (ret != -ENOSPC) + goto wait_for_memory; + + /* Adjust try_to_copy according to the amount that was + * actually allocated. The difference is due + * to max sg elements limit + */ + try_to_copy -= required_size - msg_en->sg.size; + full_record = true; + } + + if (!is_kvec && (full_record || eor) && !async_capable) { + u32 first = msg_pl->sg.end; + + ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, + msg_pl, try_to_copy); + if (ret) + goto fallback_to_reg_send; + + num_zc++; + copied += try_to_copy; + + sk_msg_sg_copy_set(msg_pl, first); + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ctx->open_rec && ret == -ENOSPC) + goto rollback_iter; + else if (ret != -EAGAIN) + goto send_end; + } + continue; +rollback_iter: + copied -= try_to_copy; + sk_msg_sg_copy_clear(msg_pl, first); + iov_iter_revert(&msg->msg_iter, + msg_pl->sg.size - orig_size); +fallback_to_reg_send: + sk_msg_trim(sk, msg_pl, orig_size); + } + + required_size = msg_pl->sg.size + try_to_copy; + + ret = tls_clone_plaintext_msg(sk, required_size); + if (ret) { + if (ret != -ENOSPC) + goto send_end; + + /* Adjust try_to_copy according to the amount that was + * actually allocated. The difference is due + * to max sg elements limit + */ + try_to_copy -= required_size - msg_pl->sg.size; + full_record = true; + sk_msg_trim(sk, msg_en, + msg_pl->sg.size + prot->overhead_size); + } + + if (try_to_copy) { + ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, + msg_pl, try_to_copy); + if (ret < 0) + goto trim_sgl; + } + + /* Open records defined only if successfully copied, otherwise + * we would trim the sg but not reset the open record frags. 
+ */ + tls_ctx->pending_open_record_frags = true; + copied += try_to_copy; + if (full_record || eor) { + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, + msg->msg_flags); + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; + goto send_end; + } + } + } + + continue; + +wait_for_sndbuf: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); +wait_for_memory: + ret = sk_stream_wait_memory(sk, &timeo); + if (ret) { +trim_sgl: + if (ctx->open_rec) + tls_trim_both_msgs(sk, orig_size); + goto send_end; + } + + if (ctx->open_rec && msg_en->sg.size < required_size) + goto alloc_encrypted; + } + + if (!num_async) { + goto send_end; + } else if (num_zc) { + /* Wait for pending encryptions to get completed */ + spin_lock_bh(&ctx->encrypt_compl_lock); + ctx->async_notify = true; + + pending = atomic_read(&ctx->encrypt_pending); + spin_unlock_bh(&ctx->encrypt_compl_lock); + if (pending) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + else + reinit_completion(&ctx->async_wait.completion); + + /* There can be no concurrent accesses, since we have no + * pending encrypt operations + */ + WRITE_ONCE(ctx->async_notify, false); + + if (ctx->async_wait.err) { + ret = ctx->async_wait.err; + copied = 0; + } + } + + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, msg->msg_flags); + } + +send_end: + ret = sk_stream_error(sk, msg->msg_flags, ret); + + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + return copied > 0 ? copied : ret; +} + +static int tls_sw_do_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + unsigned char record_type = TLS_RECORD_TYPE_DATA; + struct sk_msg *msg_pl; + struct tls_rec *rec; + int num_async = 0; + ssize_t copied = 0; + bool full_record; + int record_room; + int ret = 0; + bool eor; + + eor = !(flags & MSG_SENDPAGE_NOTLAST); + sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + /* Call the sk_stream functions to manage the sndbuf mem. */ + while (size > 0) { + size_t copy, required_size; + + if (sk->sk_err) { + ret = -sk->sk_err; + goto sendpage_end; + } + + if (ctx->open_rec) + rec = ctx->open_rec; + else + rec = ctx->open_rec = tls_get_rec(sk); + if (!rec) { + ret = -ENOMEM; + goto sendpage_end; + } + + msg_pl = &rec->msg_plaintext; + + full_record = false; + record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; + copy = size; + if (copy >= record_room) { + copy = record_room; + full_record = true; + } + + required_size = msg_pl->sg.size + copy + prot->overhead_size; + + if (!sk_stream_memory_free(sk)) + goto wait_for_sndbuf; +alloc_payload: + ret = tls_alloc_encrypted_msg(sk, required_size); + if (ret) { + if (ret != -ENOSPC) + goto wait_for_memory; + + /* Adjust copy according to the amount that was + * actually allocated. 
The difference is due + * to max sg elements limit + */ + copy -= required_size - msg_pl->sg.size; + full_record = true; + } + + sk_msg_page_add(msg_pl, page, copy, offset); + msg_pl->sg.copybreak = 0; + msg_pl->sg.curr = msg_pl->sg.end; + sk_mem_charge(sk, copy); + + offset += copy; + size -= copy; + copied += copy; + + tls_ctx->pending_open_record_frags = true; + if (full_record || eor || sk_msg_full(msg_pl)) { + ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, + record_type, &copied, flags); + if (ret) { + if (ret == -EINPROGRESS) + num_async++; + else if (ret == -ENOMEM) + goto wait_for_memory; + else if (ret != -EAGAIN) { + if (ret == -ENOSPC) + ret = 0; + goto sendpage_end; + } + } + } + continue; +wait_for_sndbuf: + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); +wait_for_memory: + ret = sk_stream_wait_memory(sk, &timeo); + if (ret) { + if (ctx->open_rec) + tls_trim_both_msgs(sk, msg_pl->sg.size); + goto sendpage_end; + } + + if (ctx->open_rec) + goto alloc_payload; + } + + if (num_async) { + /* Transmit if any encryptions have completed */ + if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + cancel_delayed_work(&ctx->tx_work.work); + tls_tx_records(sk, flags); + } + } +sendpage_end: + ret = sk_stream_error(sk, flags, ret); + return copied > 0 ? copied : ret; +} + +int tls_sw_sendpage_locked(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY | + MSG_NO_SHARED_FRAGS)) + return -EOPNOTSUPP; + + return tls_sw_do_sendpage(sk, page, offset, size, flags); +} + +int tls_sw_sendpage(struct sock *sk, struct page *page, + int offset, size_t size, int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + int ret; + + if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | + MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) + return -EOPNOTSUPP; + + ret = mutex_lock_interruptible(&tls_ctx->tx_lock); + if (ret) + return ret; + lock_sock(sk); + ret = tls_sw_do_sendpage(sk, page, offset, size, flags); + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + return ret; +} + +static int +tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, + bool released) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + DEFINE_WAIT_FUNC(wait, woken_wake_function); + int ret = 0; + long timeo; + + timeo = sock_rcvtimeo(sk, nonblock); + + while (!tls_strp_msg_ready(ctx)) { + if (!sk_psock_queue_empty(psock)) + return 0; + + if (sk->sk_err) + return sock_error(sk); + + if (ret < 0) + return ret; + + if (!skb_queue_empty(&sk->sk_receive_queue)) { + tls_strp_check_rcv(&ctx->strp); + if (tls_strp_msg_ready(ctx)) + break; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 0; + + if (sock_flag(sk, SOCK_DONE)) + return 0; + + if (!timeo) + return -EAGAIN; + + released = true; + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + ret = sk_wait_event(sk, &timeo, + tls_strp_msg_ready(ctx) || + !sk_psock_queue_empty(psock), + &wait); + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + + /* Handle signals */ + if (signal_pending(current)) + return sock_intr_errno(timeo); + } + + tls_strp_msg_load(&ctx->strp, released); + + return 1; +} + +static int tls_setup_from_iter(struct iov_iter *from, + int length, int *pages_used, + struct scatterlist *to, + int to_max_pages) +{ + int rc = 0, i = 0, num_elem = *pages_used, maxpages; + struct page 
*pages[MAX_SKB_FRAGS]; + unsigned int size = 0; + ssize_t copied, use; + size_t offset; + + while (length > 0) { + i = 0; + maxpages = to_max_pages - num_elem; + if (maxpages == 0) { + rc = -EFAULT; + goto out; + } + copied = iov_iter_get_pages2(from, pages, + length, + maxpages, &offset); + if (copied <= 0) { + rc = -EFAULT; + goto out; + } + + length -= copied; + size += copied; + while (copied) { + use = min_t(int, copied, PAGE_SIZE - offset); + + sg_set_page(&to[num_elem], + pages[i], use, offset); + sg_unmark_end(&to[num_elem]); + /* We do not uncharge memory from this API */ + + offset = 0; + copied -= use; + + i++; + num_elem++; + } + } + /* Mark the end in the last sg entry if newly added */ + if (num_elem > *pages_used) + sg_mark_end(&to[num_elem - 1]); +out: + if (rc) + iov_iter_revert(from, size); + *pages_used = num_elem; + + return rc; +} + +static struct sk_buff * +tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb, + unsigned int full_len) +{ + struct strp_msg *clr_rxm; + struct sk_buff *clr_skb; + int err; + + clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER, + &err, sk->sk_allocation); + if (!clr_skb) + return NULL; + + skb_copy_header(clr_skb, skb); + clr_skb->len = full_len; + clr_skb->data_len = full_len; + + clr_rxm = strp_msg(clr_skb); + clr_rxm->offset = 0; + + return clr_skb; +} + +/* Decrypt handlers + * + * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers. + * They must transform the darg in/out argument are as follows: + * | Input | Output + * ------------------------------------------------------------------- + * zc | Zero-copy decrypt allowed | Zero-copy performed + * async | Async decrypt allowed | Async crypto used / in progress + * skb | * | Output skb + * + * If ZC decryption was performed darg.skb will point to the input skb. + */ + +/* This function decrypts the input skb into either out_iov or in out_sg + * or in skb buffers itself. The input parameter 'darg->zc' indicates if + * zero-copy mode needs to be tried or not. With zero-copy mode, either + * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are + * NULL, then the decryption happens inside skb buffers itself, i.e. + * zero-copy gets disabled and 'darg->zc' is updated. 
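+ * For example, a caller that wants zero-copy into user memory sets + * 'darg->zc' and passes a non-NULL out_iov; if 'darg->zc' is still set + * on return, the plaintext was written to out_iov rather than into a + * freshly allocated skb.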
+ */ +static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, + struct scatterlist *out_sg, + struct tls_decrypt_arg *darg) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + int n_sgin, n_sgout, aead_size, err, pages = 0; + struct sk_buff *skb = tls_strp_msg(ctx); + const struct strp_msg *rxm = strp_msg(skb); + const struct tls_msg *tlm = tls_msg(skb); + struct aead_request *aead_req; + struct scatterlist *sgin = NULL; + struct scatterlist *sgout = NULL; + const int data_len = rxm->full_len - prot->overhead_size; + int tail_pages = !!prot->tail_size; + struct tls_decrypt_ctx *dctx; + struct sk_buff *clear_skb; + int iv_offset = 0; + u8 *mem; + + n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); + if (n_sgin < 1) + return n_sgin ?: -EBADMSG; + + if (darg->zc && (out_iov || out_sg)) { + clear_skb = NULL; + + if (out_iov) + n_sgout = 1 + tail_pages + + iov_iter_npages_cap(out_iov, INT_MAX, data_len); + else + n_sgout = sg_nents(out_sg); + } else { + darg->zc = false; + + clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len); + if (!clear_skb) + return -ENOMEM; + + n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags; + } + + /* Increment to accommodate AAD */ + n_sgin = n_sgin + 1; + + /* Allocate a single block of memory which contains + * aead_req || tls_decrypt_ctx. + * Both structs are variable length. + */ + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); + aead_size = ALIGN(aead_size, __alignof__(*dctx)); + mem = kmalloc(aead_size + struct_size(dctx, sg, size_add(n_sgin, n_sgout)), + sk->sk_allocation); + if (!mem) { + err = -ENOMEM; + goto exit_free_skb; + } + + /* Segment the allocated memory */ + aead_req = (struct aead_request *)mem; + dctx = (struct tls_decrypt_ctx *)(mem + aead_size); + dctx->sk = sk; + sgin = &dctx->sg[0]; + sgout = &dctx->sg[n_sgin]; + + /* For CCM based ciphers, first byte of nonce+iv is a constant */ + switch (prot->cipher_type) { + case TLS_CIPHER_AES_CCM_128: + dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + case TLS_CIPHER_SM4_CCM: + dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE; + iv_offset = 1; + break; + } + + /* Prepare IV */ + if (prot->version == TLS_1_3_VERSION || + prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { + memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, + prot->iv_size + prot->salt_size); + } else { + err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + &dctx->iv[iv_offset] + prot->salt_size, + prot->iv_size); + if (err < 0) + goto exit_free; + memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size); + } + tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq); + + /* Prepare AAD */ + tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size + + prot->tail_size, + tls_ctx->rx.rec_seq, tlm->control, prot); + + /* Prepare sgin */ + sg_init_table(sgin, n_sgin); + sg_set_buf(&sgin[0], dctx->aad, prot->aad_size); + err = skb_to_sgvec(skb, &sgin[1], + rxm->offset + prot->prepend_size, + rxm->full_len - prot->prepend_size); + if (err < 0) + goto exit_free; + + if (clear_skb) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); + + err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size, + data_len + prot->tail_size); + if (err < 0) + goto exit_free; + } else if (out_iov) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); + + err = 
tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1], + (n_sgout - 1 - tail_pages)); + if (err < 0) + goto exit_free_pages; + + if (prot->tail_size) { + sg_unmark_end(&sgout[pages]); + sg_set_buf(&sgout[pages + 1], &dctx->tail, + prot->tail_size); + sg_mark_end(&sgout[pages + 1]); + } + } else if (out_sg) { + memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); + } + + /* Prepare and submit AEAD request */ + err = tls_do_decryption(sk, sgin, sgout, dctx->iv, + data_len + prot->tail_size, aead_req, darg); + if (err) + goto exit_free_pages; + + darg->skb = clear_skb ?: tls_strp_msg(ctx); + clear_skb = NULL; + + if (unlikely(darg->async)) { + err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold); + if (err) + __skb_queue_tail(&ctx->async_hold, darg->skb); + return err; + } + + if (prot->tail_size) + darg->tail = dctx->tail; + +exit_free_pages: + /* Release the pages in case iov was mapped to pages */ + for (; pages > 0; pages--) + put_page(sg_page(&sgout[pages])); +exit_free: + kfree(mem); +exit_free_skb: + consume_skb(clear_skb); + return err; +} + +static int +tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx, + struct msghdr *msg, struct tls_decrypt_arg *darg) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int pad, err; + + err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg); + if (err < 0) { + if (err == -EBADMSG) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); + return err; + } + /* keep going even for ->async, the code below is TLS 1.3 */ + + /* If opportunistic TLS 1.3 ZC failed retry without ZC */ + if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && + darg->tail != TLS_RECORD_TYPE_DATA)) { + darg->zc = false; + if (!darg->tail) + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); + TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); + return tls_decrypt_sw(sk, tls_ctx, msg, darg); + } + + pad = tls_padding_length(prot, darg->skb, darg); + if (pad < 0) { + if (darg->skb != tls_strp_msg(ctx)) + consume_skb(darg->skb); + return pad; + } + + rxm = strp_msg(darg->skb); + rxm->full_len -= pad; + + return 0; +} + +static int +tls_decrypt_device(struct sock *sk, struct msghdr *msg, + struct tls_context *tls_ctx, struct tls_decrypt_arg *darg) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int pad, err; + + if (tls_ctx->rx_conf != TLS_HW) + return 0; + + err = tls_device_decrypted(sk, tls_ctx); + if (err <= 0) + return err; + + pad = tls_padding_length(prot, tls_strp_msg(ctx), darg); + if (pad < 0) + return pad; + + darg->async = false; + darg->skb = tls_strp_msg(ctx); + /* ->zc downgrade check, in case TLS 1.3 gets here */ + darg->zc &= !(prot->version == TLS_1_3_VERSION && + tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA); + + rxm = strp_msg(darg->skb); + rxm->full_len -= pad; + + if (!darg->zc) { + /* Non-ZC case needs a real skb */ + darg->skb = tls_strp_msg_detach(ctx); + if (!darg->skb) + return -ENOMEM; + } else { + unsigned int off, len; + + /* In ZC case nobody cares about the output skb. + * Just copy the data here. Note the skb is not fully trimmed. 
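+ * (The device already decrypted this record in place, so the cleartext + * payload can be copied straight into the caller's msghdr.)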
+ */ + off = rxm->offset + prot->prepend_size; + len = rxm->full_len - prot->overhead_size; + + err = skb_copy_datagram_msg(darg->skb, off, msg, len); + if (err) + return err; + } + return 1; +} + +static int tls_rx_one_record(struct sock *sk, struct msghdr *msg, + struct tls_decrypt_arg *darg) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct strp_msg *rxm; + int err; + + err = tls_decrypt_device(sk, msg, tls_ctx, darg); + if (!err) + err = tls_decrypt_sw(sk, tls_ctx, msg, darg); + if (err < 0) + return err; + + rxm = strp_msg(darg->skb); + rxm->offset += prot->prepend_size; + rxm->full_len -= prot->overhead_size; + tls_advance_record_sn(sk, prot, &tls_ctx->rx); + + return 0; +} + +int decrypt_skb(struct sock *sk, struct scatterlist *sgout) +{ + struct tls_decrypt_arg darg = { .zc = true, }; + + return tls_decrypt_sg(sk, NULL, sgout, &darg); +} + +static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, + u8 *control) +{ + int err; + + if (!*control) { + *control = tlm->control; + if (!*control) + return -EBADMSG; + + err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, + sizeof(*control), control); + if (*control != TLS_RECORD_TYPE_DATA) { + if (err || msg->msg_flags & MSG_CTRUNC) + return -EIO; + } + } else if (*control != tlm->control) { + return 0; + } + + return 1; +} + +static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) +{ + tls_strp_msg_done(&ctx->strp); +} + +/* This function traverses the rx_list in tls receive context to copies the + * decrypted records into the buffer provided by caller zero copy is not + * true. Further, the records are removed from the rx_list if it is not a peek + * case and the record has been consumed completely. + */ +static int process_rx_list(struct tls_sw_context_rx *ctx, + struct msghdr *msg, + u8 *control, + size_t skip, + size_t len, + bool is_peek) +{ + struct sk_buff *skb = skb_peek(&ctx->rx_list); + struct tls_msg *tlm; + ssize_t copied = 0; + int err; + + while (skip && skb) { + struct strp_msg *rxm = strp_msg(skb); + tlm = tls_msg(skb); + + err = tls_record_content_type(msg, tlm, control); + if (err <= 0) + goto out; + + if (skip < rxm->full_len) + break; + + skip = skip - rxm->full_len; + skb = skb_peek_next(skb, &ctx->rx_list); + } + + while (len && skb) { + struct sk_buff *next_skb; + struct strp_msg *rxm = strp_msg(skb); + int chunk = min_t(unsigned int, rxm->full_len - skip, len); + + tlm = tls_msg(skb); + + err = tls_record_content_type(msg, tlm, control); + if (err <= 0) + goto out; + + err = skb_copy_datagram_msg(skb, rxm->offset + skip, + msg, chunk); + if (err < 0) + goto out; + + len = len - chunk; + copied = copied + chunk; + + /* Consume the data from record if it is non-peek case*/ + if (!is_peek) { + rxm->offset = rxm->offset + chunk; + rxm->full_len = rxm->full_len - chunk; + + /* Return if there is unconsumed data in the record */ + if (rxm->full_len - skip) + break; + } + + /* The remaining skip-bytes must lie in 1st record in rx_list. + * So from the 2nd record, 'skip' should be 0. + */ + skip = 0; + + if (msg) + msg->msg_flags |= MSG_EOR; + + next_skb = skb_peek_next(skb, &ctx->rx_list); + + if (!is_peek) { + __skb_unlink(skb, &ctx->rx_list); + consume_skb(skb); + } + + skb = next_skb; + } + err = 0; + +out: + return copied ? 
: err; +} + +static bool +tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, + size_t len_left, size_t decrypted, ssize_t done, + size_t *flushed_at) +{ + size_t max_rec; + + if (len_left <= decrypted) + return false; + + max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; + if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) + return false; + + *flushed_at = done; + return sk_flush_backlog(sk); +} + +static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx, + bool nonblock) +{ + long timeo; + int ret; + + timeo = sock_rcvtimeo(sk, nonblock); + + while (unlikely(ctx->reader_present)) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + ctx->reader_contended = 1; + + add_wait_queue(&ctx->wq, &wait); + ret = sk_wait_event(sk, &timeo, + !READ_ONCE(ctx->reader_present), &wait); + remove_wait_queue(&ctx->wq, &wait); + + if (timeo <= 0) + return -EAGAIN; + if (signal_pending(current)) + return sock_intr_errno(timeo); + if (ret < 0) + return ret; + } + + WRITE_ONCE(ctx->reader_present, 1); + + return 0; +} + +static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, + bool nonblock) +{ + int err; + + lock_sock(sk); + err = tls_rx_reader_acquire(sk, ctx, nonblock); + if (err) + release_sock(sk); + return err; +} + +static void tls_rx_reader_release(struct sock *sk, struct tls_sw_context_rx *ctx) +{ + if (unlikely(ctx->reader_contended)) { + if (wq_has_sleeper(&ctx->wq)) + wake_up(&ctx->wq); + else + ctx->reader_contended = 0; + + WARN_ON_ONCE(!ctx->reader_present); + } + + WRITE_ONCE(ctx->reader_present, 0); +} + +static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) +{ + tls_rx_reader_release(sk, ctx); + release_sock(sk); +} + +int tls_sw_recvmsg(struct sock *sk, + struct msghdr *msg, + size_t len, + int flags, + int *addr_len) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct tls_prot_info *prot = &tls_ctx->prot_info; + ssize_t decrypted = 0, async_copy_bytes = 0; + struct sk_psock *psock; + unsigned char control = 0; + size_t flushed_at = 0; + struct strp_msg *rxm; + struct tls_msg *tlm; + ssize_t copied = 0; + bool async = false; + int target, err; + bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); + bool is_peek = flags & MSG_PEEK; + bool released = true; + bool bpf_strp_enabled; + bool zc_capable; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); + + psock = sk_psock_get(sk); + err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT); + if (err < 0) + return err; + bpf_strp_enabled = sk_psock_strp_enabled(psock); + + /* If crypto failed the connection is broken */ + err = ctx->async_wait.err; + if (err) + goto end; + + /* Process pending decrypted records. 
It must be non-zero-copy */ + err = process_rx_list(ctx, msg, &control, 0, len, is_peek); + if (err < 0) + goto end; + + copied = err; + if (len <= copied) + goto end; + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + len = len - copied; + + zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && + ctx->zc_capable; + decrypted = 0; + while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) { + struct tls_decrypt_arg darg; + int to_decrypt, chunk; + + err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, + released); + if (err <= 0) { + if (psock) { + chunk = sk_msg_recvmsg(sk, psock, msg, len, + flags); + if (chunk > 0) { + decrypted += chunk; + len -= chunk; + continue; + } + } + goto recv_end; + } + + memset(&darg.inargs, 0, sizeof(darg.inargs)); + + rxm = strp_msg(tls_strp_msg(ctx)); + tlm = tls_msg(tls_strp_msg(ctx)); + + to_decrypt = rxm->full_len - prot->overhead_size; + + if (zc_capable && to_decrypt <= len && + tlm->control == TLS_RECORD_TYPE_DATA) + darg.zc = true; + + /* Do not use async mode if record is non-data */ + if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) + darg.async = ctx->async_capable; + else + darg.async = false; + + err = tls_rx_one_record(sk, msg, &darg); + if (err < 0) { + tls_err_abort(sk, -EBADMSG); + goto recv_end; + } + + async |= darg.async; + + /* If the type of records being processed is not known yet, + * set it to record type just dequeued. If it is already known, + * but does not match the record type just dequeued, go to end. + * We always get record type here since for tls1.2, record type + * is known just after record is dequeued from stream parser. + * For tls1.3, we disable async. + */ + err = tls_record_content_type(msg, tls_msg(darg.skb), &control); + if (err <= 0) { + DEBUG_NET_WARN_ON_ONCE(darg.zc); + tls_rx_rec_done(ctx); +put_on_rx_list_err: + __skb_queue_tail(&ctx->rx_list, darg.skb); + goto recv_end; + } + + /* periodically flush backlog, and feed strparser */ + released = tls_read_flush_backlog(sk, prot, len, to_decrypt, + decrypted + copied, + &flushed_at); + + /* TLS 1.3 may have updated the length by more than overhead */ + rxm = strp_msg(darg.skb); + chunk = rxm->full_len; + tls_rx_rec_done(ctx); + + if (!darg.zc) { + bool partially_consumed = chunk > len; + struct sk_buff *skb = darg.skb; + + DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor); + + if (async) { + /* TLS 1.2-only, to_decrypt must be text len */ + chunk = min_t(int, to_decrypt, len); + async_copy_bytes += chunk; +put_on_rx_list: + decrypted += chunk; + len -= chunk; + __skb_queue_tail(&ctx->rx_list, skb); + continue; + } + + if (bpf_strp_enabled) { + released = true; + err = sk_psock_tls_strp_read(psock, skb); + if (err != __SK_PASS) { + rxm->offset = rxm->offset + rxm->full_len; + rxm->full_len = 0; + if (err == __SK_DROP) + consume_skb(skb); + continue; + } + } + + if (partially_consumed) + chunk = len; + + err = skb_copy_datagram_msg(skb, rxm->offset, + msg, chunk); + if (err < 0) + goto put_on_rx_list_err; + + if (is_peek) + goto put_on_rx_list; + + if (partially_consumed) { + rxm->offset += chunk; + rxm->full_len -= chunk; + goto put_on_rx_list; + } + + consume_skb(skb); + } + + decrypted += chunk; + len -= chunk; + + /* Return full control message to userspace before trying + * to parse another message type + */ + msg->msg_flags |= MSG_EOR; + if (control != TLS_RECORD_TYPE_DATA) + break; + } + +recv_end: + if (async) { + int ret, pending; + + /* Wait for all previously submitted records to be decrypted */ + 
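+ /* ctx->decrypt_pending counts async decrypt requests still in flight */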
spin_lock_bh(&ctx->decrypt_compl_lock); + reinit_completion(&ctx->async_wait.completion); + pending = atomic_read(&ctx->decrypt_pending); + spin_unlock_bh(&ctx->decrypt_compl_lock); + ret = 0; + if (pending) + ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + __skb_queue_purge(&ctx->async_hold); + + if (ret) { + if (err >= 0 || err == -EINPROGRESS) + err = ret; + decrypted = 0; + goto end; + } + + /* Drain records from the rx_list & copy if required */ + if (is_peek || is_kvec) + err = process_rx_list(ctx, msg, &control, copied, + decrypted, is_peek); + else + err = process_rx_list(ctx, msg, &control, 0, + async_copy_bytes, is_peek); + decrypted += max(err, 0); + } + + copied += decrypted; + +end: + tls_rx_reader_unlock(sk, ctx); + if (psock) + sk_psock_put(sk, psock); + return copied ? : err; +} + +ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t len, unsigned int flags) +{ + struct tls_context *tls_ctx = tls_get_ctx(sock->sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct strp_msg *rxm = NULL; + struct sock *sk = sock->sk; + struct tls_msg *tlm; + struct sk_buff *skb; + ssize_t copied = 0; + int chunk; + int err; + + err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK); + if (err < 0) + return err; + + if (!skb_queue_empty(&ctx->rx_list)) { + skb = __skb_dequeue(&ctx->rx_list); + } else { + struct tls_decrypt_arg darg; + + err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, + true); + if (err <= 0) + goto splice_read_end; + + memset(&darg.inargs, 0, sizeof(darg.inargs)); + + err = tls_rx_one_record(sk, NULL, &darg); + if (err < 0) { + tls_err_abort(sk, -EBADMSG); + goto splice_read_end; + } + + tls_rx_rec_done(ctx); + skb = darg.skb; + } + + rxm = strp_msg(skb); + tlm = tls_msg(skb); + + /* splice does not support reading control messages */ + if (tlm->control != TLS_RECORD_TYPE_DATA) { + err = -EINVAL; + goto splice_requeue; + } + + chunk = min_t(unsigned int, rxm->full_len, len); + copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); + if (copied < 0) + goto splice_requeue; + + if (chunk < rxm->full_len) { + rxm->offset += len; + rxm->full_len -= len; + goto splice_requeue; + } + + consume_skb(skb); + +splice_read_end: + tls_rx_reader_unlock(sk, ctx); + return copied ? : err; + +splice_requeue: + __skb_queue_head(&ctx->rx_list, skb); + goto splice_read_end; +} + +bool tls_sw_sock_is_readable(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + bool ingress_empty = true; + struct sk_psock *psock; + + rcu_read_lock(); + psock = sk_psock(sk); + if (psock) + ingress_empty = list_empty(&psock->ingress_msg); + rcu_read_unlock(); + + return !ingress_empty || tls_strp_msg_ready(ctx) || + !skb_queue_empty(&ctx->rx_list); +} + +int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(strp->sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; + size_t cipher_overhead; + size_t data_len = 0; + int ret; + + /* Verify that we have a full TLS header, or wait for more data */ + if (strp->stm.offset + prot->prepend_size > skb->len) + return 0; + + /* Sanity-check size of on-stack buffer. 
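+ * (prot->prepend_size bytes of the record header are linearized into + * header[] below: byte 0 carries the content type, bytes 1-2 the record + * version and bytes 3-4 the payload length in network byte order.)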
*/ + if (WARN_ON(prot->prepend_size > sizeof(header))) { + ret = -EINVAL; + goto read_failure; + } + + /* Linearize header to local buffer */ + ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size); + if (ret < 0) + goto read_failure; + + strp->mark = header[0]; + + data_len = ((header[4] & 0xFF) | (header[3] << 8)); + + cipher_overhead = prot->tag_size; + if (prot->version != TLS_1_3_VERSION && + prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) + cipher_overhead += prot->iv_size; + + if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + + prot->tail_size) { + ret = -EMSGSIZE; + goto read_failure; + } + if (data_len < cipher_overhead) { + ret = -EBADMSG; + goto read_failure; + } + + /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */ + if (header[1] != TLS_1_2_VERSION_MINOR || + header[2] != TLS_1_2_VERSION_MAJOR) { + ret = -EINVAL; + goto read_failure; + } + + tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, + TCP_SKB_CB(skb)->seq + strp->stm.offset); + return data_len + TLS_HEADER_SIZE; + +read_failure: + tls_err_abort(strp->sk, ret); + + return ret; +} + +void tls_rx_msg_ready(struct tls_strparser *strp) +{ + struct tls_sw_context_rx *ctx; + + ctx = container_of(strp, struct tls_sw_context_rx, strp); + ctx->saved_data_ready(strp->sk); +} + +static void tls_data_ready(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct sk_psock *psock; + gfp_t alloc_save; + + alloc_save = sk->sk_allocation; + sk->sk_allocation = GFP_ATOMIC; + tls_strp_data_ready(&ctx->strp); + sk->sk_allocation = alloc_save; + + psock = sk_psock_get(sk); + if (psock) { + if (!list_empty(&psock->ingress_msg)) + ctx->saved_data_ready(sk); + sk_psock_put(sk, psock); + } +} + +void tls_sw_cancel_work_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask); + set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask); + cancel_delayed_work_sync(&ctx->tx_work.work); +} + +void tls_sw_release_resources_tx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + struct tls_rec *rec, *tmp; + int pending; + + /* Wait for any pending async encryptions to complete */ + spin_lock_bh(&ctx->encrypt_compl_lock); + ctx->async_notify = true; + pending = atomic_read(&ctx->encrypt_pending); + spin_unlock_bh(&ctx->encrypt_compl_lock); + + if (pending) + crypto_wait_req(-EINPROGRESS, &ctx->async_wait); + + tls_tx_records(sk, -1); + + /* Free up un-sent records in tx_list. First, free + * the partially sent record if any at head of tx_list. 
+ */ + if (tls_ctx->partially_sent_record) { + tls_free_partial_record(sk, tls_ctx); + rec = list_first_entry(&ctx->tx_list, + struct tls_rec, list); + list_del(&rec->list); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); + } + + list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { + list_del(&rec->list); + sk_msg_free(sk, &rec->msg_encrypted); + sk_msg_free(sk, &rec->msg_plaintext); + kfree(rec); + } + + crypto_free_aead(ctx->aead_send); + tls_free_open_rec(sk); +} + +void tls_sw_free_ctx_tx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); + + kfree(ctx); +} + +void tls_sw_release_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + + kfree(tls_ctx->rx.rec_seq); + kfree(tls_ctx->rx.iv); + + if (ctx->aead_recv) { + __skb_queue_purge(&ctx->rx_list); + crypto_free_aead(ctx->aead_recv); + tls_strp_stop(&ctx->strp); + /* If tls_sw_strparser_arm() was not called (cleanup paths) + * we still want to tls_strp_stop(), but sk->sk_data_ready was + * never swapped. + */ + if (ctx->saved_data_ready) { + write_lock_bh(&sk->sk_callback_lock); + sk->sk_data_ready = ctx->saved_data_ready; + write_unlock_bh(&sk->sk_callback_lock); + } + } +} + +void tls_sw_strparser_done(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + + tls_strp_done(&ctx->strp); +} + +void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + + kfree(ctx); +} + +void tls_sw_free_resources_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + + tls_sw_release_resources_rx(sk); + tls_sw_free_ctx_rx(tls_ctx); +} + +/* The work handler to transmitt the encrypted records in tx_list */ +static void tx_work_handler(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct tx_work *tx_work = container_of(delayed_work, + struct tx_work, work); + struct sock *sk = tx_work->sk; + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_tx *ctx; + + if (unlikely(!tls_ctx)) + return; + + ctx = tls_sw_ctx_tx(tls_ctx); + if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask)) + return; + + if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) + return; + + if (mutex_trylock(&tls_ctx->tx_lock)) { + lock_sock(sk); + tls_tx_records(sk, -1); + release_sock(sk); + mutex_unlock(&tls_ctx->tx_lock); + } else if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { + /* Someone is holding the tx_lock, they will likely run Tx + * and cancel the work on their way out of the lock section. + * Schedule a long delay just in case. 
+ */ + schedule_delayed_work(&ctx->tx_work.work, msecs_to_jiffies(10)); + } +} + +static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) +{ + struct tls_rec *rec; + + rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list); + if (!rec) + return false; + + return READ_ONCE(rec->tx_ready); +} + +void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) +{ + struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); + + /* Schedule the transmission if tx list is ready */ + if (tls_is_tx_ready(tx_ctx) && + !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) + schedule_delayed_work(&tx_ctx->tx_work.work, 0); +} + +void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + write_lock_bh(&sk->sk_callback_lock); + rx_ctx->saved_data_ready = sk->sk_data_ready; + sk->sk_data_ready = tls_data_ready; + write_unlock_bh(&sk->sk_callback_lock); +} + +void tls_update_rx_zc_capable(struct tls_context *tls_ctx) +{ + struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); + + rx_ctx->zc_capable = tls_ctx->rx_no_pad || + tls_ctx->prot_info.version != TLS_1_3_VERSION; +} + +int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_prot_info *prot = &tls_ctx->prot_info; + struct tls_crypto_info *crypto_info; + struct tls_sw_context_tx *sw_ctx_tx = NULL; + struct tls_sw_context_rx *sw_ctx_rx = NULL; + struct cipher_context *cctx; + struct crypto_aead **aead; + u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; + struct crypto_tfm *tfm; + char *iv, *rec_seq, *key, *salt, *cipher_name; + size_t keysize; + int rc = 0; + + if (!ctx) { + rc = -EINVAL; + goto out; + } + + if (tx) { + if (!ctx->priv_ctx_tx) { + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); + if (!sw_ctx_tx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_tx = sw_ctx_tx; + } else { + sw_ctx_tx = + (struct tls_sw_context_tx *)ctx->priv_ctx_tx; + } + } else { + if (!ctx->priv_ctx_rx) { + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); + if (!sw_ctx_rx) { + rc = -ENOMEM; + goto out; + } + ctx->priv_ctx_rx = sw_ctx_rx; + } else { + sw_ctx_rx = + (struct tls_sw_context_rx *)ctx->priv_ctx_rx; + } + } + + if (tx) { + crypto_init_wait(&sw_ctx_tx->async_wait); + spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); + crypto_info = &ctx->crypto_send.info; + cctx = &ctx->tx; + aead = &sw_ctx_tx->aead_send; + INIT_LIST_HEAD(&sw_ctx_tx->tx_list); + INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); + sw_ctx_tx->tx_work.sk = sk; + } else { + crypto_init_wait(&sw_ctx_rx->async_wait); + spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); + init_waitqueue_head(&sw_ctx_rx->wq); + crypto_info = &ctx->crypto_recv.info; + cctx = &ctx->rx; + skb_queue_head_init(&sw_ctx_rx->rx_list); + skb_queue_head_init(&sw_ctx_rx->async_hold); + aead = &sw_ctx_rx->aead_recv; + } + + switch (crypto_info->cipher_type) { + case TLS_CIPHER_AES_GCM_128: { + struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; + + gcm_128_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; + iv = gcm_128_info->iv; + rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; + rec_seq = gcm_128_info->rec_seq; + keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; + key = gcm_128_info->key; + salt = gcm_128_info->salt; + salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE; + cipher_name = "gcm(aes)"; + break; + } + 
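+ /* Like the AES-GCM-128 case above, the remaining cases only record + * the cipher geometry (key, IV, salt, tag and record-sequence sizes) + * and the crypto API algorithm name from the user-supplied crypto_info; + * the allocation and AEAD setup after the switch is common to all + * ciphers. + */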
case TLS_CIPHER_AES_GCM_256: { + struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; + + gcm_256_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; + tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE; + iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; + iv = gcm_256_info->iv; + rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE; + rec_seq = gcm_256_info->rec_seq; + keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; + key = gcm_256_info->key; + salt = gcm_256_info->salt; + salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE; + cipher_name = "gcm(aes)"; + break; + } + case TLS_CIPHER_AES_CCM_128: { + struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; + + ccm_128_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; + tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; + iv = ccm_128_info->iv; + rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE; + rec_seq = ccm_128_info->rec_seq; + keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE; + key = ccm_128_info->key; + salt = ccm_128_info->salt; + salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE; + cipher_name = "ccm(aes)"; + break; + } + case TLS_CIPHER_CHACHA20_POLY1305: { + struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info; + + chacha20_poly1305_info = (void *)crypto_info; + nonce_size = 0; + tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE; + iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE; + iv = chacha20_poly1305_info->iv; + rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE; + rec_seq = chacha20_poly1305_info->rec_seq; + keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE; + key = chacha20_poly1305_info->key; + salt = chacha20_poly1305_info->salt; + salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE; + cipher_name = "rfc7539(chacha20,poly1305)"; + break; + } + case TLS_CIPHER_SM4_GCM: { + struct tls12_crypto_info_sm4_gcm *sm4_gcm_info; + + sm4_gcm_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE; + tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE; + iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE; + iv = sm4_gcm_info->iv; + rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE; + rec_seq = sm4_gcm_info->rec_seq; + keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE; + key = sm4_gcm_info->key; + salt = sm4_gcm_info->salt; + salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE; + cipher_name = "gcm(sm4)"; + break; + } + case TLS_CIPHER_SM4_CCM: { + struct tls12_crypto_info_sm4_ccm *sm4_ccm_info; + + sm4_ccm_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE; + tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE; + iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE; + iv = sm4_ccm_info->iv; + rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE; + rec_seq = sm4_ccm_info->rec_seq; + keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE; + key = sm4_ccm_info->key; + salt = sm4_ccm_info->salt; + salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE; + cipher_name = "ccm(sm4)"; + break; + } + case TLS_CIPHER_ARIA_GCM_128: { + struct tls12_crypto_info_aria_gcm_128 *aria_gcm_128_info; + + aria_gcm_128_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE; + tag_size = TLS_CIPHER_ARIA_GCM_128_TAG_SIZE; + iv_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE; + iv = aria_gcm_128_info->iv; + rec_seq_size = TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE; + rec_seq = aria_gcm_128_info->rec_seq; + keysize = TLS_CIPHER_ARIA_GCM_128_KEY_SIZE; + key = aria_gcm_128_info->key; + salt = aria_gcm_128_info->salt; + salt_size = TLS_CIPHER_ARIA_GCM_128_SALT_SIZE; + cipher_name = "gcm(aria)"; + break; + } + case TLS_CIPHER_ARIA_GCM_256: { + struct tls12_crypto_info_aria_gcm_256 
*gcm_256_info; + + gcm_256_info = (void *)crypto_info; + nonce_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE; + tag_size = TLS_CIPHER_ARIA_GCM_256_TAG_SIZE; + iv_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE; + iv = gcm_256_info->iv; + rec_seq_size = TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE; + rec_seq = gcm_256_info->rec_seq; + keysize = TLS_CIPHER_ARIA_GCM_256_KEY_SIZE; + key = gcm_256_info->key; + salt = gcm_256_info->salt; + salt_size = TLS_CIPHER_ARIA_GCM_256_SALT_SIZE; + cipher_name = "gcm(aria)"; + break; + } + default: + rc = -EINVAL; + goto free_priv; + } + + if (crypto_info->version == TLS_1_3_VERSION) { + nonce_size = 0; + prot->aad_size = TLS_HEADER_SIZE; + prot->tail_size = 1; + } else { + prot->aad_size = TLS_AAD_SPACE_SIZE; + prot->tail_size = 0; + } + + /* Sanity-check the sizes for stack allocations. */ + if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE || + rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE || + prot->aad_size > TLS_MAX_AAD_SIZE) { + rc = -EINVAL; + goto free_priv; + } + + prot->version = crypto_info->version; + prot->cipher_type = crypto_info->cipher_type; + prot->prepend_size = TLS_HEADER_SIZE + nonce_size; + prot->tag_size = tag_size; + prot->overhead_size = prot->prepend_size + + prot->tag_size + prot->tail_size; + prot->iv_size = iv_size; + prot->salt_size = salt_size; + cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL); + if (!cctx->iv) { + rc = -ENOMEM; + goto free_priv; + } + /* Note: 128 & 256 bit salt are the same size */ + prot->rec_seq_size = rec_seq_size; + memcpy(cctx->iv, salt, salt_size); + memcpy(cctx->iv + salt_size, iv, iv_size); + cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); + if (!cctx->rec_seq) { + rc = -ENOMEM; + goto free_iv; + } + + if (!*aead) { + *aead = crypto_alloc_aead(cipher_name, 0, 0); + if (IS_ERR(*aead)) { + rc = PTR_ERR(*aead); + *aead = NULL; + goto free_rec_seq; + } + } + + ctx->push_pending_record = tls_sw_push_pending_record; + + rc = crypto_aead_setkey(*aead, key, keysize); + + if (rc) + goto free_aead; + + rc = crypto_aead_setauthsize(*aead, prot->tag_size); + if (rc) + goto free_aead; + + if (sw_ctx_rx) { + tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); + + tls_update_rx_zc_capable(ctx); + sw_ctx_rx->async_capable = + crypto_info->version != TLS_1_3_VERSION && + !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); + + rc = tls_strp_init(&sw_ctx_rx->strp, sk); + if (rc) + goto free_aead; + } + + goto out; + +free_aead: + crypto_free_aead(*aead); + *aead = NULL; +free_rec_seq: + kfree(cctx->rec_seq); + cctx->rec_seq = NULL; +free_iv: + kfree(cctx->iv); + cctx->iv = NULL; +free_priv: + if (tx) { + kfree(ctx->priv_ctx_tx); + ctx->priv_ctx_tx = NULL; + } else { + kfree(ctx->priv_ctx_rx); + ctx->priv_ctx_rx = NULL; + } +out: + return rc; +} diff --git a/net/tls/tls_toe.c b/net/tls/tls_toe.c new file mode 100644 index 000000000..825669e1a --- /dev/null +++ b/net/tls/tls_toe.c @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. + * Copyright (c) 2016-2017, Dave Watson . All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. 
You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the + * OpenIB.org BSD license below: + * + * Redistribution and use in source and binary forms, with or + * without modification, are permitted provided that the following + * conditions are met: + * + * - Redistributions of source code must retain the above + * copyright notice, this list of conditions and the following + * disclaimer. + * + * - Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials + * provided with the distribution. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include + +#include "tls.h" + +static LIST_HEAD(device_list); +static DEFINE_SPINLOCK(device_spinlock); + +static void tls_toe_sk_destruct(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct tls_context *ctx = tls_get_ctx(sk); + + ctx->sk_destruct(sk); + /* Free ctx */ + rcu_assign_pointer(icsk->icsk_ulp_data, NULL); + tls_ctx_free(sk, ctx); +} + +int tls_toe_bypass(struct sock *sk) +{ + struct tls_toe_device *dev; + struct tls_context *ctx; + int rc = 0; + + spin_lock_bh(&device_spinlock); + list_for_each_entry(dev, &device_list, dev_list) { + if (dev->feature && dev->feature(dev)) { + ctx = tls_ctx_create(sk); + if (!ctx) + goto out; + + ctx->sk_destruct = sk->sk_destruct; + sk->sk_destruct = tls_toe_sk_destruct; + ctx->rx_conf = TLS_HW_RECORD; + ctx->tx_conf = TLS_HW_RECORD; + update_sk_prot(sk, ctx); + rc = 1; + break; + } + } +out: + spin_unlock_bh(&device_spinlock); + return rc; +} + +void tls_toe_unhash(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + struct tls_toe_device *dev; + + spin_lock_bh(&device_spinlock); + list_for_each_entry(dev, &device_list, dev_list) { + if (dev->unhash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); + dev->unhash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } + } + spin_unlock_bh(&device_spinlock); + ctx->sk_proto->unhash(sk); +} + +int tls_toe_hash(struct sock *sk) +{ + struct tls_context *ctx = tls_get_ctx(sk); + struct tls_toe_device *dev; + int err; + + err = ctx->sk_proto->hash(sk); + spin_lock_bh(&device_spinlock); + list_for_each_entry(dev, &device_list, dev_list) { + if (dev->hash) { + kref_get(&dev->kref); + spin_unlock_bh(&device_spinlock); + err |= dev->hash(dev, sk); + kref_put(&dev->kref, dev->release); + spin_lock_bh(&device_spinlock); + } + } + spin_unlock_bh(&device_spinlock); + + if (err) + tls_toe_unhash(sk); + return err; +} + +void tls_toe_register_device(struct tls_toe_device *device) +{ + spin_lock_bh(&device_spinlock); + list_add_tail(&device->dev_list, &device_list); + spin_unlock_bh(&device_spinlock); +} +EXPORT_SYMBOL(tls_toe_register_device); + +void tls_toe_unregister_device(struct tls_toe_device *device) +{ + spin_lock_bh(&device_spinlock); + 
list_del(&device->dev_list); + spin_unlock_bh(&device_spinlock); +} +EXPORT_SYMBOL(tls_toe_unregister_device); diff --git a/net/tls/trace.c b/net/tls/trace.c new file mode 100644 index 000000000..e374913cf --- /dev/null +++ b/net/tls/trace.c @@ -0,0 +1,10 @@ +// SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause) +/* Copyright (C) 2019 Netronome Systems, Inc. */ + +#include + +#ifndef __CHECKER__ +#define CREATE_TRACE_POINTS +#include "trace.h" + +#endif diff --git a/net/tls/trace.h b/net/tls/trace.h new file mode 100644 index 000000000..9ba5f600e --- /dev/null +++ b/net/tls/trace.h @@ -0,0 +1,202 @@ +/* SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause) */ +/* Copyright (C) 2019 Netronome Systems, Inc. */ + +#undef TRACE_SYSTEM +#define TRACE_SYSTEM tls + +#if !defined(_TLS_TRACE_H_) || defined(TRACE_HEADER_MULTI_READ) +#define _TLS_TRACE_H_ + +#include +#include + +struct sock; + +TRACE_EVENT(tls_device_offload_set, + + TP_PROTO(struct sock *sk, int dir, u32 tcp_seq, u8 *rec_no, int ret), + + TP_ARGS(sk, dir, tcp_seq, rec_no, ret), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u64, rec_no ) + __field( int, dir ) + __field( u32, tcp_seq ) + __field( int, ret ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->rec_no = get_unaligned_be64(rec_no); + __entry->dir = dir; + __entry->tcp_seq = tcp_seq; + __entry->ret = ret; + ), + + TP_printk( + "sk=%p direction=%d tcp_seq=%u rec_no=%llu ret=%d", + __entry->sk, __entry->dir, __entry->tcp_seq, __entry->rec_no, + __entry->ret + ) +); + +TRACE_EVENT(tls_device_decrypted, + + TP_PROTO(struct sock *sk, u32 tcp_seq, u8 *rec_no, u32 rec_len, + bool encrypted, bool decrypted), + + TP_ARGS(sk, tcp_seq, rec_no, rec_len, encrypted, decrypted), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u64, rec_no ) + __field( u32, tcp_seq ) + __field( u32, rec_len ) + __field( bool, encrypted ) + __field( bool, decrypted ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->rec_no = get_unaligned_be64(rec_no); + __entry->tcp_seq = tcp_seq; + __entry->rec_len = rec_len; + __entry->encrypted = encrypted; + __entry->decrypted = decrypted; + ), + + TP_printk( + "sk=%p tcp_seq=%u rec_no=%llu len=%u encrypted=%d decrypted=%d", + __entry->sk, __entry->tcp_seq, + __entry->rec_no, __entry->rec_len, + __entry->encrypted, __entry->decrypted + ) +); + +TRACE_EVENT(tls_device_rx_resync_send, + + TP_PROTO(struct sock *sk, u32 tcp_seq, u8 *rec_no, int sync_type), + + TP_ARGS(sk, tcp_seq, rec_no, sync_type), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u64, rec_no ) + __field( u32, tcp_seq ) + __field( int, sync_type ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->rec_no = get_unaligned_be64(rec_no); + __entry->tcp_seq = tcp_seq; + __entry->sync_type = sync_type; + ), + + TP_printk( + "sk=%p tcp_seq=%u rec_no=%llu sync_type=%d", + __entry->sk, __entry->tcp_seq, __entry->rec_no, + __entry->sync_type + ) +); + +TRACE_EVENT(tls_device_rx_resync_nh_schedule, + + TP_PROTO(struct sock *sk), + + TP_ARGS(sk), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + ), + + TP_fast_assign( + __entry->sk = sk; + ), + + TP_printk( + "sk=%p", __entry->sk + ) +); + +TRACE_EVENT(tls_device_rx_resync_nh_delay, + + TP_PROTO(struct sock *sk, u32 sock_data, u32 rec_len), + + TP_ARGS(sk, sock_data, rec_len), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u32, sock_data ) + __field( u32, rec_len ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->sock_data = sock_data; + __entry->rec_len = rec_len; + 
), + + TP_printk( + "sk=%p sock_data=%u rec_len=%u", + __entry->sk, __entry->sock_data, __entry->rec_len + ) +); + +TRACE_EVENT(tls_device_tx_resync_req, + + TP_PROTO(struct sock *sk, u32 tcp_seq, u32 exp_tcp_seq), + + TP_ARGS(sk, tcp_seq, exp_tcp_seq), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u32, tcp_seq ) + __field( u32, exp_tcp_seq ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->tcp_seq = tcp_seq; + __entry->exp_tcp_seq = exp_tcp_seq; + ), + + TP_printk( + "sk=%p tcp_seq=%u exp_tcp_seq=%u", + __entry->sk, __entry->tcp_seq, __entry->exp_tcp_seq + ) +); + +TRACE_EVENT(tls_device_tx_resync_send, + + TP_PROTO(struct sock *sk, u32 tcp_seq, u8 *rec_no), + + TP_ARGS(sk, tcp_seq, rec_no), + + TP_STRUCT__entry( + __field( struct sock *, sk ) + __field( u64, rec_no ) + __field( u32, tcp_seq ) + ), + + TP_fast_assign( + __entry->sk = sk; + __entry->rec_no = get_unaligned_be64(rec_no); + __entry->tcp_seq = tcp_seq; + ), + + TP_printk( + "sk=%p tcp_seq=%u rec_no=%llu", + __entry->sk, __entry->tcp_seq, __entry->rec_no + ) +); + +#endif /* _TLS_TRACE_H_ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . +#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace + +#include diff --git a/net/unix/Kconfig b/net/unix/Kconfig new file mode 100644 index 000000000..b7f811216 --- /dev/null +++ b/net/unix/Kconfig @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Unix Domain Sockets +# + +config UNIX + tristate "Unix domain sockets" + help + If you say Y here, you will include support for Unix domain sockets; + sockets are the standard Unix mechanism for establishing and + accessing network connections. Many commonly used programs such as + the X Window system and syslog use these sockets even if your + machine is not connected to any network. Unless you are working on + an embedded system or something similar, you therefore definitely + want to say Y here. + + To compile this driver as a module, choose M here: the module will be + called unix. Note that several important services won't work + correctly if you say M here and then neglect to load the module. + + Say Y unless you know what you are doing. + +config UNIX_SCM + bool + depends on UNIX + default y + +config AF_UNIX_OOB + bool + depends on UNIX + default y + +config UNIX_DIAG + tristate "UNIX: socket monitoring interface" + depends on UNIX + default n + help + Support for UNIX socket monitoring interface used by the ss tool. + If unsure, say Y. diff --git a/net/unix/Makefile b/net/unix/Makefile new file mode 100644 index 000000000..20491825b --- /dev/null +++ b/net/unix/Makefile @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux unix domain socket layer. +# + +obj-$(CONFIG_UNIX) += unix.o + +unix-y := af_unix.o garbage.o +unix-$(CONFIG_SYSCTL) += sysctl_net_unix.o +unix-$(CONFIG_BPF_SYSCALL) += unix_bpf.o + +obj-$(CONFIG_UNIX_DIAG) += unix_diag.o +unix_diag-y := diag.o + +obj-$(CONFIG_UNIX_SCM) += scm.o diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c new file mode 100644 index 000000000..be2ed7b0f --- /dev/null +++ b/net/unix/af_unix.c @@ -0,0 +1,3785 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET4: Implementation of BSD Unix domain sockets. + * + * Authors: Alan Cox, + * + * Fixes: + * Linus Torvalds : Assorted bug cures. + * Niibe Yutaka : async I/O support. + * Carsten Paeth : PF_UNIX check, address fixes. + * Alan Cox : Limit size of allocated blocks. + * Alan Cox : Fixed the stupid socketpair bug. 
+ * Alan Cox : BSD compatibility fine tuning. + * Alan Cox : Fixed a bug in connect when interrupted. + * Alan Cox : Sorted out a proper draft version of + * file descriptor passing hacked up from + * Mike Shaver's work. + * Marty Leisner : Fixes to fd passing + * Nick Nevin : recvmsg bugfix. + * Alan Cox : Started proper garbage collector + * Heiko EiBfeldt : Missing verify_area check + * Alan Cox : Started POSIXisms + * Andreas Schwab : Replace inode by dentry for proper + * reference counting + * Kirk Petersen : Made this a module + * Christoph Rohland : Elegant non-blocking accept/connect algorithm. + * Lots of bug fixes. + * Alexey Kuznetosv : Repaired (I hope) bugs introduces + * by above two patches. + * Andrea Arcangeli : If possible we block in connect(2) + * if the max backlog of the listen socket + * is been reached. This won't break + * old apps and it will avoid huge amount + * of socks hashed (this for unix_gc() + * performances reasons). + * Security fix that limits the max + * number of socks to 2*max_files and + * the number of skb queueable in the + * dgram receiver. + * Artur Skawina : Hash function optimizations + * Alexey Kuznetsov : Full scale SMP. Lot of bugs are introduced 8) + * Malcolm Beattie : Set peercred for socketpair + * Michal Ostrowski : Module initialization cleanup. + * Arnaldo C. Melo : Remove MOD_{INC,DEC}_USE_COUNT, + * the core infrastructure is doing that + * for all net proto families now (2.5.69+) + * + * Known differences from reference BSD that was tested: + * + * [TO FIX] + * ECONNREFUSED is not returned from one end of a connected() socket to the + * other the moment one end closes. + * fstat() doesn't return st_dev=0, and give the blksize as high water mark + * and a fake inode identifier (nor the BSD first socket fstat twice bug). + * [NOT TO FIX] + * accept() returns a path name even if the connecting socket has closed + * in the meantime (BSD loses the path and gives up). + * accept() returns 0 length path for an unbound connector. BSD returns 16 + * and a null first byte in the path (but not for gethost/peername - BSD bug ??) + * socketpair(...SOCK_RAW..) doesn't panic the kernel. + * BSD af_unix apparently has connect forgetting to block properly. + * (need to check this with the POSIX spec in detail) + * + * Differences from 2.0.0-11-... (ANK) + * Bug fixes and improvements. + * - client shutdown killed server socket. + * - removed all useless cli/sti pairs. + * + * Semantic changes/extensions. + * - generic control message passing. + * - SCM_CREDENTIALS control message. + * - "Abstract" (not FS based) socket bindings. + * Abstract names are sequences of bytes (not zero terminated) + * started by 0, so that this name space does not intersect + * with BSD names. + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scm.h" + +static atomic_long_t unix_nr_socks; +static struct hlist_head bsd_socket_buckets[UNIX_HASH_SIZE / 2]; +static spinlock_t bsd_socket_locks[UNIX_HASH_SIZE / 2]; + +/* SMP locking strategy: + * hash table is protected with spinlock. + * each socket state is protected by separate spinlock. 
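+ * When two hash buckets must be held at once, unix_table_double_lock() + * below takes them in ascending hash order to avoid deadlock.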
+ */ + +static unsigned int unix_unbound_hash(struct sock *sk) +{ + unsigned long hash = (unsigned long)sk; + + hash ^= hash >> 16; + hash ^= hash >> 8; + hash ^= sk->sk_type; + + return hash & UNIX_HASH_MOD; +} + +static unsigned int unix_bsd_hash(struct inode *i) +{ + return i->i_ino & UNIX_HASH_MOD; +} + +static unsigned int unix_abstract_hash(struct sockaddr_un *sunaddr, + int addr_len, int type) +{ + __wsum csum = csum_partial(sunaddr, addr_len, 0); + unsigned int hash; + + hash = (__force unsigned int)csum_fold(csum); + hash ^= hash >> 8; + hash ^= type; + + return UNIX_HASH_MOD + 1 + (hash & UNIX_HASH_MOD); +} + +static void unix_table_double_lock(struct net *net, + unsigned int hash1, unsigned int hash2) +{ + if (hash1 == hash2) { + spin_lock(&net->unx.table.locks[hash1]); + return; + } + + if (hash1 > hash2) + swap(hash1, hash2); + + spin_lock(&net->unx.table.locks[hash1]); + spin_lock_nested(&net->unx.table.locks[hash2], SINGLE_DEPTH_NESTING); +} + +static void unix_table_double_unlock(struct net *net, + unsigned int hash1, unsigned int hash2) +{ + if (hash1 == hash2) { + spin_unlock(&net->unx.table.locks[hash1]); + return; + } + + spin_unlock(&net->unx.table.locks[hash1]); + spin_unlock(&net->unx.table.locks[hash2]); +} + +#ifdef CONFIG_SECURITY_NETWORK +static void unix_get_secdata(struct scm_cookie *scm, struct sk_buff *skb) +{ + UNIXCB(skb).secid = scm->secid; +} + +static inline void unix_set_secdata(struct scm_cookie *scm, struct sk_buff *skb) +{ + scm->secid = UNIXCB(skb).secid; +} + +static inline bool unix_secdata_eq(struct scm_cookie *scm, struct sk_buff *skb) +{ + return (scm->secid == UNIXCB(skb).secid); +} +#else +static inline void unix_get_secdata(struct scm_cookie *scm, struct sk_buff *skb) +{ } + +static inline void unix_set_secdata(struct scm_cookie *scm, struct sk_buff *skb) +{ } + +static inline bool unix_secdata_eq(struct scm_cookie *scm, struct sk_buff *skb) +{ + return true; +} +#endif /* CONFIG_SECURITY_NETWORK */ + +static inline int unix_our_peer(struct sock *sk, struct sock *osk) +{ + return unix_peer(osk) == sk; +} + +static inline int unix_may_send(struct sock *sk, struct sock *osk) +{ + return unix_peer(osk) == NULL || unix_our_peer(sk, osk); +} + +static inline int unix_recvq_full(const struct sock *sk) +{ + return skb_queue_len(&sk->sk_receive_queue) > sk->sk_max_ack_backlog; +} + +static inline int unix_recvq_full_lockless(const struct sock *sk) +{ + return skb_queue_len_lockless(&sk->sk_receive_queue) > + READ_ONCE(sk->sk_max_ack_backlog); +} + +struct sock *unix_peer_get(struct sock *s) +{ + struct sock *peer; + + unix_state_lock(s); + peer = unix_peer(s); + if (peer) + sock_hold(peer); + unix_state_unlock(s); + return peer; +} +EXPORT_SYMBOL_GPL(unix_peer_get); + +static struct unix_address *unix_create_addr(struct sockaddr_un *sunaddr, + int addr_len) +{ + struct unix_address *addr; + + addr = kmalloc(sizeof(*addr) + addr_len, GFP_KERNEL); + if (!addr) + return NULL; + + refcount_set(&addr->refcnt, 1); + addr->len = addr_len; + memcpy(addr->name, sunaddr, addr_len); + + return addr; +} + +static inline void unix_release_addr(struct unix_address *addr) +{ + if (refcount_dec_and_test(&addr->refcnt)) + kfree(addr); +} + +/* + * Check unix socket name: + * - should be not zero length. + * - if started by not zero, should be NULL terminated (FS object) + * - if started by zero, it is abstract name. 
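The address check described above separates filesystem names from the Linux-specific abstract namespace (a leading NUL byte, with the name length given explicitly rather than by a terminator). A userspace sketch of binding an abstract name; illustrative only, and the name "example" is arbitrary:

/* Bind an AF_UNIX socket in the abstract namespace (Linux-specific). */
#include <stdio.h>
#include <string.h>
#include <stddef.h>
#include <sys/socket.h>
#include <sys/un.h>

int main(void)
{
	struct sockaddr_un addr;
	int fd = socket(AF_UNIX, SOCK_STREAM, 0);
	socklen_t len;

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_UNIX;
	/* sun_path[0] == '\0' selects the abstract namespace; the name is
	 * the byte sequence after it and is not NUL-terminated.
	 */
	memcpy(addr.sun_path + 1, "example", 7);
	len = offsetof(struct sockaddr_un, sun_path) + 1 + 7;

	if (bind(fd, (struct sockaddr *)&addr, len) < 0)
		perror("bind");
	else
		printf("bound to abstract name \"example\"\n");
	return 0;
}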
+ */ + +static int unix_validate_addr(struct sockaddr_un *sunaddr, int addr_len) +{ + if (addr_len <= offsetof(struct sockaddr_un, sun_path) || + addr_len > sizeof(*sunaddr)) + return -EINVAL; + + if (sunaddr->sun_family != AF_UNIX) + return -EINVAL; + + return 0; +} + +static void unix_mkname_bsd(struct sockaddr_un *sunaddr, int addr_len) +{ + /* This may look like an off by one error but it is a bit more + * subtle. 108 is the longest valid AF_UNIX path for a binding. + * sun_path[108] doesn't as such exist. However in kernel space + * we are guaranteed that it is a valid memory location in our + * kernel address buffer because syscall functions always pass + * a pointer of struct sockaddr_storage which has a bigger buffer + * than 108. + */ + ((char *)sunaddr)[addr_len] = 0; +} + +static void __unix_remove_socket(struct sock *sk) +{ + sk_del_node_init(sk); +} + +static void __unix_insert_socket(struct net *net, struct sock *sk) +{ + DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk)); + sk_add_node(sk, &net->unx.table.buckets[sk->sk_hash]); +} + +static void __unix_set_addr_hash(struct net *net, struct sock *sk, + struct unix_address *addr, unsigned int hash) +{ + __unix_remove_socket(sk); + smp_store_release(&unix_sk(sk)->addr, addr); + + sk->sk_hash = hash; + __unix_insert_socket(net, sk); +} + +static void unix_remove_socket(struct net *net, struct sock *sk) +{ + spin_lock(&net->unx.table.locks[sk->sk_hash]); + __unix_remove_socket(sk); + spin_unlock(&net->unx.table.locks[sk->sk_hash]); +} + +static void unix_insert_unbound_socket(struct net *net, struct sock *sk) +{ + spin_lock(&net->unx.table.locks[sk->sk_hash]); + __unix_insert_socket(net, sk); + spin_unlock(&net->unx.table.locks[sk->sk_hash]); +} + +static void unix_insert_bsd_socket(struct sock *sk) +{ + spin_lock(&bsd_socket_locks[sk->sk_hash]); + sk_add_bind_node(sk, &bsd_socket_buckets[sk->sk_hash]); + spin_unlock(&bsd_socket_locks[sk->sk_hash]); +} + +static void unix_remove_bsd_socket(struct sock *sk) +{ + if (!hlist_unhashed(&sk->sk_bind_node)) { + spin_lock(&bsd_socket_locks[sk->sk_hash]); + __sk_del_bind_node(sk); + spin_unlock(&bsd_socket_locks[sk->sk_hash]); + + sk_node_init(&sk->sk_bind_node); + } +} + +static struct sock *__unix_find_socket_byname(struct net *net, + struct sockaddr_un *sunname, + int len, unsigned int hash) +{ + struct sock *s; + + sk_for_each(s, &net->unx.table.buckets[hash]) { + struct unix_sock *u = unix_sk(s); + + if (u->addr->len == len && + !memcmp(u->addr->name, sunname, len)) + return s; + } + return NULL; +} + +static inline struct sock *unix_find_socket_byname(struct net *net, + struct sockaddr_un *sunname, + int len, unsigned int hash) +{ + struct sock *s; + + spin_lock(&net->unx.table.locks[hash]); + s = __unix_find_socket_byname(net, sunname, len, hash); + if (s) + sock_hold(s); + spin_unlock(&net->unx.table.locks[hash]); + return s; +} + +static struct sock *unix_find_socket_byinode(struct inode *i) +{ + unsigned int hash = unix_bsd_hash(i); + struct sock *s; + + spin_lock(&bsd_socket_locks[hash]); + sk_for_each_bound(s, &bsd_socket_buckets[hash]) { + struct dentry *dentry = unix_sk(s)->path.dentry; + + if (dentry && d_backing_inode(dentry) == i) { + sock_hold(s); + spin_unlock(&bsd_socket_locks[hash]); + return s; + } + } + spin_unlock(&bsd_socket_locks[hash]); + return NULL; +} + +/* Support code for asymmetrically connected dgram sockets + * + * If a datagram socket is connected to a socket not itself connected + * to the first socket (eg, /dev/log), clients may only enqueue more + * messages 
if the present receive queue of the server socket is not + * "too large". This means there's a second writeability condition + * poll and sendmsg need to test. The dgram recv code will do a wake + * up on the peer_wait wait queue of a socket upon reception of a + * datagram which needs to be propagated to sleeping would-be writers + * since these might not have sent anything so far. This can't be + * accomplished via poll_wait because the lifetime of the server + * socket might be less than that of its clients if these break their + * association with it or if the server socket is closed while clients + * are still connected to it and there's no way to inform "a polling + * implementation" that it should let go of a certain wait queue + * + * In order to propagate a wake up, a wait_queue_entry_t of the client + * socket is enqueued on the peer_wait queue of the server socket + * whose wake function does a wake_up on the ordinary client socket + * wait queue. This connection is established whenever a write (or + * poll for write) hit the flow control condition and broken when the + * association to the server socket is dissolved or after a wake up + * was relayed. + */ + +static int unix_dgram_peer_wake_relay(wait_queue_entry_t *q, unsigned mode, int flags, + void *key) +{ + struct unix_sock *u; + wait_queue_head_t *u_sleep; + + u = container_of(q, struct unix_sock, peer_wake); + + __remove_wait_queue(&unix_sk(u->peer_wake.private)->peer_wait, + q); + u->peer_wake.private = NULL; + + /* relaying can only happen while the wq still exists */ + u_sleep = sk_sleep(&u->sk); + if (u_sleep) + wake_up_interruptible_poll(u_sleep, key_to_poll(key)); + + return 0; +} + +static int unix_dgram_peer_wake_connect(struct sock *sk, struct sock *other) +{ + struct unix_sock *u, *u_other; + int rc; + + u = unix_sk(sk); + u_other = unix_sk(other); + rc = 0; + spin_lock(&u_other->peer_wait.lock); + + if (!u->peer_wake.private) { + u->peer_wake.private = other; + __add_wait_queue(&u_other->peer_wait, &u->peer_wake); + + rc = 1; + } + + spin_unlock(&u_other->peer_wait.lock); + return rc; +} + +static void unix_dgram_peer_wake_disconnect(struct sock *sk, + struct sock *other) +{ + struct unix_sock *u, *u_other; + + u = unix_sk(sk); + u_other = unix_sk(other); + spin_lock(&u_other->peer_wait.lock); + + if (u->peer_wake.private == other) { + __remove_wait_queue(&u_other->peer_wait, &u->peer_wake); + u->peer_wake.private = NULL; + } + + spin_unlock(&u_other->peer_wait.lock); +} + +static void unix_dgram_peer_wake_disconnect_wakeup(struct sock *sk, + struct sock *other) +{ + unix_dgram_peer_wake_disconnect(sk, other); + wake_up_interruptible_poll(sk_sleep(sk), + EPOLLOUT | + EPOLLWRNORM | + EPOLLWRBAND); +} + +/* preconditions: + * - unix_peer(sk) == other + * - association is stable + */ +static int unix_dgram_peer_wake_me(struct sock *sk, struct sock *other) +{ + int connected; + + connected = unix_dgram_peer_wake_connect(sk, other); + + /* If other is SOCK_DEAD, we want to make sure we signal + * POLLOUT, such that a subsequent write() can get a + * -ECONNREFUSED. Otherwise, if we haven't queued any skbs + * to other and its full, we will hang waiting for POLLOUT. 
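From userspace, the second writeability condition described above shows up as EAGAIN on a non-blocking datagram send followed by a later POLLOUT wakeup once the receiver drains its queue. A small sketch, assuming a connected non-blocking SOCK_DGRAM socket (illustrative only):

/* Non-blocking datagram send with poll() for writeability. */
#include <errno.h>
#include <poll.h>
#include <stddef.h>
#include <sys/socket.h>

static int send_or_wait(int fd, const void *buf, size_t len)
{
	for (;;) {
		if (send(fd, buf, len, MSG_DONTWAIT) >= 0)
			return 0;
		if (errno != EAGAIN && errno != EWOULDBLOCK)
			return -1;	/* e.g. ECONNREFUSED: peer is gone */

		struct pollfd pfd = { .fd = fd, .events = POLLOUT };

		/* Sleep until the peer's receive queue has room again. */
		if (poll(&pfd, 1, -1) < 0)
			return -1;
	}
}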
+ */ + if (unix_recvq_full_lockless(other) && !sock_flag(other, SOCK_DEAD)) + return 1; + + if (connected) + unix_dgram_peer_wake_disconnect(sk, other); + + return 0; +} + +static int unix_writable(const struct sock *sk) +{ + return sk->sk_state != TCP_LISTEN && + (refcount_read(&sk->sk_wmem_alloc) << 2) <= sk->sk_sndbuf; +} + +static void unix_write_space(struct sock *sk) +{ + struct socket_wq *wq; + + rcu_read_lock(); + if (unix_writable(sk)) { + wq = rcu_dereference(sk->sk_wq); + if (skwq_has_sleeper(wq)) + wake_up_interruptible_sync_poll(&wq->wait, + EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND); + sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT); + } + rcu_read_unlock(); +} + +/* When dgram socket disconnects (or changes its peer), we clear its receive + * queue of packets arrived from previous peer. First, it allows to do + * flow control based only on wmem_alloc; second, sk connected to peer + * may receive messages only from that peer. */ +static void unix_dgram_disconnected(struct sock *sk, struct sock *other) +{ + if (!skb_queue_empty(&sk->sk_receive_queue)) { + skb_queue_purge(&sk->sk_receive_queue); + wake_up_interruptible_all(&unix_sk(sk)->peer_wait); + + /* If one link of bidirectional dgram pipe is disconnected, + * we signal error. Messages are lost. Do not make this, + * when peer was not connected to us. + */ + if (!sock_flag(other, SOCK_DEAD) && unix_peer(other) == sk) { + other->sk_err = ECONNRESET; + sk_error_report(other); + } + } + other->sk_state = TCP_CLOSE; +} + +static void unix_sock_destructor(struct sock *sk) +{ + struct unix_sock *u = unix_sk(sk); + + skb_queue_purge(&sk->sk_receive_queue); + + DEBUG_NET_WARN_ON_ONCE(refcount_read(&sk->sk_wmem_alloc)); + DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk)); + DEBUG_NET_WARN_ON_ONCE(sk->sk_socket); + if (!sock_flag(sk, SOCK_DEAD)) { + pr_info("Attempt to release alive unix socket: %p\n", sk); + return; + } + + if (u->addr) + unix_release_addr(u->addr); + + atomic_long_dec(&unix_nr_socks); + sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); +#ifdef UNIX_REFCNT_DEBUG + pr_debug("UNIX %p is destroyed, %ld are still alive.\n", sk, + atomic_long_read(&unix_nr_socks)); +#endif +} + +static void unix_release_sock(struct sock *sk, int embrion) +{ + struct unix_sock *u = unix_sk(sk); + struct sock *skpair; + struct sk_buff *skb; + struct path path; + int state; + + unix_remove_socket(sock_net(sk), sk); + unix_remove_bsd_socket(sk); + + /* Clear state */ + unix_state_lock(sk); + sock_orphan(sk); + WRITE_ONCE(sk->sk_shutdown, SHUTDOWN_MASK); + path = u->path; + u->path.dentry = NULL; + u->path.mnt = NULL; + state = sk->sk_state; + sk->sk_state = TCP_CLOSE; + + skpair = unix_peer(sk); + unix_peer(sk) = NULL; + + unix_state_unlock(sk); + +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + if (u->oob_skb) { + kfree_skb(u->oob_skb); + u->oob_skb = NULL; + } +#endif + + wake_up_interruptible_all(&u->peer_wait); + + if (skpair != NULL) { + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) { + unix_state_lock(skpair); + /* No more writes */ + WRITE_ONCE(skpair->sk_shutdown, SHUTDOWN_MASK); + if (!skb_queue_empty(&sk->sk_receive_queue) || embrion) + skpair->sk_err = ECONNRESET; + unix_state_unlock(skpair); + skpair->sk_state_change(skpair); + sk_wake_async(skpair, SOCK_WAKE_WAITD, POLL_HUP); + } + + unix_dgram_peer_wake_disconnect(sk, skpair); + sock_put(skpair); /* It may now die */ + } + + /* Try to flush out this socket. 
Throw out buffers at least */ + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + if (state == TCP_LISTEN) + unix_release_sock(skb->sk, 1); + /* passed fds are erased in the kfree_skb hook */ + UNIXCB(skb).consumed = skb->len; + kfree_skb(skb); + } + + if (path.dentry) + path_put(&path); + + sock_put(sk); + + /* ---- Socket is dead now and most probably destroyed ---- */ + + /* + * Fixme: BSD difference: In BSD all sockets connected to us get + * ECONNRESET and we die on the spot. In Linux we behave + * like files and pipes do and wait for the last + * dereference. + * + * Can't we simply set sock->err? + * + * What the above comment does talk about? --ANK(980817) + */ + + if (READ_ONCE(unix_tot_inflight)) + unix_gc(); /* Garbage collect fds */ +} + +static void init_peercred(struct sock *sk) +{ + const struct cred *old_cred; + struct pid *old_pid; + + spin_lock(&sk->sk_peer_lock); + old_pid = sk->sk_peer_pid; + old_cred = sk->sk_peer_cred; + sk->sk_peer_pid = get_pid(task_tgid(current)); + sk->sk_peer_cred = get_current_cred(); + spin_unlock(&sk->sk_peer_lock); + + put_pid(old_pid); + put_cred(old_cred); +} + +static void copy_peercred(struct sock *sk, struct sock *peersk) +{ + const struct cred *old_cred; + struct pid *old_pid; + + if (sk < peersk) { + spin_lock(&sk->sk_peer_lock); + spin_lock_nested(&peersk->sk_peer_lock, SINGLE_DEPTH_NESTING); + } else { + spin_lock(&peersk->sk_peer_lock); + spin_lock_nested(&sk->sk_peer_lock, SINGLE_DEPTH_NESTING); + } + old_pid = sk->sk_peer_pid; + old_cred = sk->sk_peer_cred; + sk->sk_peer_pid = get_pid(peersk->sk_peer_pid); + sk->sk_peer_cred = get_cred(peersk->sk_peer_cred); + + spin_unlock(&sk->sk_peer_lock); + spin_unlock(&peersk->sk_peer_lock); + + put_pid(old_pid); + put_cred(old_cred); +} + +static int unix_listen(struct socket *sock, int backlog) +{ + int err; + struct sock *sk = sock->sk; + struct unix_sock *u = unix_sk(sk); + + err = -EOPNOTSUPP; + if (sock->type != SOCK_STREAM && sock->type != SOCK_SEQPACKET) + goto out; /* Only stream/seqpacket sockets accept */ + err = -EINVAL; + if (!u->addr) + goto out; /* No listens on an unbound socket */ + unix_state_lock(sk); + if (sk->sk_state != TCP_CLOSE && sk->sk_state != TCP_LISTEN) + goto out_unlock; + if (backlog > sk->sk_max_ack_backlog) + wake_up_interruptible_all(&u->peer_wait); + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + /* set credentials so connect can copy them */ + init_peercred(sk); + err = 0; + +out_unlock: + unix_state_unlock(sk); +out: + return err; +} + +static int unix_release(struct socket *); +static int unix_bind(struct socket *, struct sockaddr *, int); +static int unix_stream_connect(struct socket *, struct sockaddr *, + int addr_len, int flags); +static int unix_socketpair(struct socket *, struct socket *); +static int unix_accept(struct socket *, struct socket *, int, bool); +static int unix_getname(struct socket *, struct sockaddr *, int); +static __poll_t unix_poll(struct file *, struct socket *, poll_table *); +static __poll_t unix_dgram_poll(struct file *, struct socket *, + poll_table *); +static int unix_ioctl(struct socket *, unsigned int, unsigned long); +#ifdef CONFIG_COMPAT +static int unix_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg); +#endif +static int unix_shutdown(struct socket *, int); +static int unix_stream_sendmsg(struct socket *, struct msghdr *, size_t); +static int unix_stream_recvmsg(struct socket *, struct msghdr *, size_t, int); +static ssize_t unix_stream_sendpage(struct socket *, 
struct page *, int offset, + size_t size, int flags); +static ssize_t unix_stream_splice_read(struct socket *, loff_t *ppos, + struct pipe_inode_info *, size_t size, + unsigned int flags); +static int unix_dgram_sendmsg(struct socket *, struct msghdr *, size_t); +static int unix_dgram_recvmsg(struct socket *, struct msghdr *, size_t, int); +static int unix_read_skb(struct sock *sk, skb_read_actor_t recv_actor); +static int unix_stream_read_skb(struct sock *sk, skb_read_actor_t recv_actor); +static int unix_dgram_connect(struct socket *, struct sockaddr *, + int, int); +static int unix_seqpacket_sendmsg(struct socket *, struct msghdr *, size_t); +static int unix_seqpacket_recvmsg(struct socket *, struct msghdr *, size_t, + int); + +static int unix_set_peek_off(struct sock *sk, int val) +{ + struct unix_sock *u = unix_sk(sk); + + if (mutex_lock_interruptible(&u->iolock)) + return -EINTR; + + WRITE_ONCE(sk->sk_peek_off, val); + mutex_unlock(&u->iolock); + + return 0; +} + +#ifdef CONFIG_PROC_FS +static int unix_count_nr_fds(struct sock *sk) +{ + struct sk_buff *skb; + struct unix_sock *u; + int nr_fds = 0; + + spin_lock(&sk->sk_receive_queue.lock); + skb = skb_peek(&sk->sk_receive_queue); + while (skb) { + u = unix_sk(skb->sk); + nr_fds += atomic_read(&u->scm_stat.nr_fds); + skb = skb_peek_next(skb, &sk->sk_receive_queue); + } + spin_unlock(&sk->sk_receive_queue.lock); + + return nr_fds; +} + +static void unix_show_fdinfo(struct seq_file *m, struct socket *sock) +{ + struct sock *sk = sock->sk; + struct unix_sock *u; + int nr_fds; + + if (sk) { + u = unix_sk(sk); + if (sock->type == SOCK_DGRAM) { + nr_fds = atomic_read(&u->scm_stat.nr_fds); + goto out_print; + } + + unix_state_lock(sk); + if (sk->sk_state != TCP_LISTEN) + nr_fds = atomic_read(&u->scm_stat.nr_fds); + else + nr_fds = unix_count_nr_fds(sk); + unix_state_unlock(sk); +out_print: + seq_printf(m, "scm_fds: %u\n", nr_fds); + } +} +#else +#define unix_show_fdinfo NULL +#endif + +static const struct proto_ops unix_stream_ops = { + .family = PF_UNIX, + .owner = THIS_MODULE, + .release = unix_release, + .bind = unix_bind, + .connect = unix_stream_connect, + .socketpair = unix_socketpair, + .accept = unix_accept, + .getname = unix_getname, + .poll = unix_poll, + .ioctl = unix_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = unix_compat_ioctl, +#endif + .listen = unix_listen, + .shutdown = unix_shutdown, + .sendmsg = unix_stream_sendmsg, + .recvmsg = unix_stream_recvmsg, + .read_skb = unix_stream_read_skb, + .mmap = sock_no_mmap, + .sendpage = unix_stream_sendpage, + .splice_read = unix_stream_splice_read, + .set_peek_off = unix_set_peek_off, + .show_fdinfo = unix_show_fdinfo, +}; + +static const struct proto_ops unix_dgram_ops = { + .family = PF_UNIX, + .owner = THIS_MODULE, + .release = unix_release, + .bind = unix_bind, + .connect = unix_dgram_connect, + .socketpair = unix_socketpair, + .accept = sock_no_accept, + .getname = unix_getname, + .poll = unix_dgram_poll, + .ioctl = unix_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = unix_compat_ioctl, +#endif + .listen = sock_no_listen, + .shutdown = unix_shutdown, + .sendmsg = unix_dgram_sendmsg, + .read_skb = unix_read_skb, + .recvmsg = unix_dgram_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, + .set_peek_off = unix_set_peek_off, + .show_fdinfo = unix_show_fdinfo, +}; + +static const struct proto_ops unix_seqpacket_ops = { + .family = PF_UNIX, + .owner = THIS_MODULE, + .release = unix_release, + .bind = unix_bind, + .connect = unix_stream_connect, + .socketpair = 
unix_socketpair, + .accept = unix_accept, + .getname = unix_getname, + .poll = unix_dgram_poll, + .ioctl = unix_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = unix_compat_ioctl, +#endif + .listen = unix_listen, + .shutdown = unix_shutdown, + .sendmsg = unix_seqpacket_sendmsg, + .recvmsg = unix_seqpacket_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, + .set_peek_off = unix_set_peek_off, + .show_fdinfo = unix_show_fdinfo, +}; + +static void unix_close(struct sock *sk, long timeout) +{ + /* Nothing to do here, unix socket does not need a ->close(). + * This is merely for sockmap. + */ +} + +static void unix_unhash(struct sock *sk) +{ + /* Nothing to do here, unix socket does not need a ->unhash(). + * This is merely for sockmap. + */ +} + +struct proto unix_dgram_proto = { + .name = "UNIX", + .owner = THIS_MODULE, + .obj_size = sizeof(struct unix_sock), + .close = unix_close, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = unix_dgram_bpf_update_proto, +#endif +}; + +struct proto unix_stream_proto = { + .name = "UNIX-STREAM", + .owner = THIS_MODULE, + .obj_size = sizeof(struct unix_sock), + .close = unix_close, + .unhash = unix_unhash, +#ifdef CONFIG_BPF_SYSCALL + .psock_update_sk_prot = unix_stream_bpf_update_proto, +#endif +}; + +static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type) +{ + struct unix_sock *u; + struct sock *sk; + int err; + + atomic_long_inc(&unix_nr_socks); + if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files()) { + err = -ENFILE; + goto err; + } + + if (type == SOCK_STREAM) + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern); + else /*dgram and seqpacket */ + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern); + + if (!sk) { + err = -ENOMEM; + goto err; + } + + sock_init_data(sock, sk); + + sk->sk_hash = unix_unbound_hash(sk); + sk->sk_allocation = GFP_KERNEL_ACCOUNT; + sk->sk_write_space = unix_write_space; + sk->sk_max_ack_backlog = net->unx.sysctl_max_dgram_qlen; + sk->sk_destruct = unix_sock_destructor; + u = unix_sk(sk); + u->path.dentry = NULL; + u->path.mnt = NULL; + spin_lock_init(&u->lock); + atomic_long_set(&u->inflight, 0); + INIT_LIST_HEAD(&u->link); + mutex_init(&u->iolock); /* single task reading lock */ + mutex_init(&u->bindlock); /* single task binding lock */ + init_waitqueue_head(&u->peer_wait); + init_waitqueue_func_entry(&u->peer_wake, unix_dgram_peer_wake_relay); + memset(&u->scm_stat, 0, sizeof(struct scm_stat)); + unix_insert_unbound_socket(net, sk); + + sock_prot_inuse_add(net, sk->sk_prot, 1); + + return sk; + +err: + atomic_long_dec(&unix_nr_socks); + return ERR_PTR(err); +} + +static int unix_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + + if (protocol && protocol != PF_UNIX) + return -EPROTONOSUPPORT; + + sock->state = SS_UNCONNECTED; + + switch (sock->type) { + case SOCK_STREAM: + sock->ops = &unix_stream_ops; + break; + /* + * Believe it or not BSD has AF_UNIX, SOCK_RAW though + * nothing uses it. 
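As the comment notes (and the SOCK_RAW fallthrough just below implements), AF_UNIX accepts SOCK_RAW but treats it as SOCK_DGRAM. A userspace sketch that checks the effective type via SO_TYPE; illustrative only, and the expectation that SO_TYPE reads back as SOCK_DGRAM follows from the fallthrough, not from anything else in this patch:

/* socket(AF_UNIX, SOCK_RAW, 0) succeeds but behaves as SOCK_DGRAM. */
#include <stdio.h>
#include <sys/socket.h>

int main(void)
{
	int type = -1;
	socklen_t len = sizeof(type);
	int fd = socket(AF_UNIX, SOCK_RAW, 0);

	if (fd < 0) {
		perror("socket");
		return 1;
	}
	getsockopt(fd, SOL_SOCKET, SO_TYPE, &type, &len);
	printf("SO_TYPE = %d (SOCK_DGRAM = %d)\n", type, SOCK_DGRAM);
	return 0;
}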
+ */ + case SOCK_RAW: + sock->type = SOCK_DGRAM; + fallthrough; + case SOCK_DGRAM: + sock->ops = &unix_dgram_ops; + break; + case SOCK_SEQPACKET: + sock->ops = &unix_seqpacket_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sk = unix_create1(net, sock, kern, sock->type); + if (IS_ERR(sk)) + return PTR_ERR(sk); + + return 0; +} + +static int unix_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + + if (!sk) + return 0; + + sk->sk_prot->close(sk, 0); + unix_release_sock(sk, 0); + sock->sk = NULL; + + return 0; +} + +static struct sock *unix_find_bsd(struct sockaddr_un *sunaddr, int addr_len, + int type) +{ + struct inode *inode; + struct path path; + struct sock *sk; + int err; + + unix_mkname_bsd(sunaddr, addr_len); + err = kern_path(sunaddr->sun_path, LOOKUP_FOLLOW, &path); + if (err) + goto fail; + + err = path_permission(&path, MAY_WRITE); + if (err) + goto path_put; + + err = -ECONNREFUSED; + inode = d_backing_inode(path.dentry); + if (!S_ISSOCK(inode->i_mode)) + goto path_put; + + sk = unix_find_socket_byinode(inode); + if (!sk) + goto path_put; + + err = -EPROTOTYPE; + if (sk->sk_type == type) + touch_atime(&path); + else + goto sock_put; + + path_put(&path); + + return sk; + +sock_put: + sock_put(sk); +path_put: + path_put(&path); +fail: + return ERR_PTR(err); +} + +static struct sock *unix_find_abstract(struct net *net, + struct sockaddr_un *sunaddr, + int addr_len, int type) +{ + unsigned int hash = unix_abstract_hash(sunaddr, addr_len, type); + struct dentry *dentry; + struct sock *sk; + + sk = unix_find_socket_byname(net, sunaddr, addr_len, hash); + if (!sk) + return ERR_PTR(-ECONNREFUSED); + + dentry = unix_sk(sk)->path.dentry; + if (dentry) + touch_atime(&unix_sk(sk)->path); + + return sk; +} + +static struct sock *unix_find_other(struct net *net, + struct sockaddr_un *sunaddr, + int addr_len, int type) +{ + struct sock *sk; + + if (sunaddr->sun_path[0]) + sk = unix_find_bsd(sunaddr, addr_len, type); + else + sk = unix_find_abstract(net, sunaddr, addr_len, type); + + return sk; +} + +static int unix_autobind(struct sock *sk) +{ + unsigned int new_hash, old_hash = sk->sk_hash; + struct unix_sock *u = unix_sk(sk); + struct net *net = sock_net(sk); + struct unix_address *addr; + u32 lastnum, ordernum; + int err; + + err = mutex_lock_interruptible(&u->bindlock); + if (err) + return err; + + if (u->addr) + goto out; + + err = -ENOMEM; + addr = kzalloc(sizeof(*addr) + + offsetof(struct sockaddr_un, sun_path) + 16, GFP_KERNEL); + if (!addr) + goto out; + + addr->len = offsetof(struct sockaddr_un, sun_path) + 6; + addr->name->sun_family = AF_UNIX; + refcount_set(&addr->refcnt, 1); + + ordernum = get_random_u32(); + lastnum = ordernum & 0xFFFFF; +retry: + ordernum = (ordernum + 1) & 0xFFFFF; + sprintf(addr->name->sun_path + 1, "%05x", ordernum); + + new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type); + unix_table_double_lock(net, old_hash, new_hash); + + if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash)) { + unix_table_double_unlock(net, old_hash, new_hash); + + /* __unix_find_socket_byname() may take long time if many names + * are already in use. + */ + cond_resched(); + + if (ordernum == lastnum) { + /* Give up if all names seems to be in use. 
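Autobind, as implemented above, is requested from userspace by binding with only the address family (or implicitly for unbound SOCK_PASSCRED senders) and assigns an abstract name of five hex digits. A sketch that triggers autobind and reads the kernel-chosen name back with getsockname(2); illustrative only:

/* Autobind: bind with only sun_family set and the minimal length. */
#include <stdio.h>
#include <stddef.h>
#include <sys/socket.h>
#include <sys/un.h>

int main(void)
{
	struct sockaddr_un addr = { .sun_family = AF_UNIX };
	socklen_t len = sizeof(sa_family_t);	/* no sun_path at all */
	int fd = socket(AF_UNIX, SOCK_DGRAM, 0);

	if (bind(fd, (struct sockaddr *)&addr, len) < 0) {
		perror("bind");
		return 1;
	}

	len = sizeof(addr);
	getsockname(fd, (struct sockaddr *)&addr, &len);
	/* addr.sun_path[0] is '\0'; the assigned name follows it. */
	printf("autobound to \"\\0%.*s\"\n",
	       (int)(len - offsetof(struct sockaddr_un, sun_path) - 1),
	       addr.sun_path + 1);
	return 0;
}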
*/ + err = -ENOSPC; + unix_release_addr(addr); + goto out; + } + + goto retry; + } + + __unix_set_addr_hash(net, sk, addr, new_hash); + unix_table_double_unlock(net, old_hash, new_hash); + err = 0; + +out: mutex_unlock(&u->bindlock); + return err; +} + +static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr, + int addr_len) +{ + umode_t mode = S_IFSOCK | + (SOCK_INODE(sk->sk_socket)->i_mode & ~current_umask()); + unsigned int new_hash, old_hash = sk->sk_hash; + struct unix_sock *u = unix_sk(sk); + struct net *net = sock_net(sk); + struct user_namespace *ns; // barf... + struct unix_address *addr; + struct dentry *dentry; + struct path parent; + int err; + + unix_mkname_bsd(sunaddr, addr_len); + addr_len = strlen(sunaddr->sun_path) + + offsetof(struct sockaddr_un, sun_path) + 1; + + addr = unix_create_addr(sunaddr, addr_len); + if (!addr) + return -ENOMEM; + + /* + * Get the parent directory, calculate the hash for last + * component. + */ + dentry = kern_path_create(AT_FDCWD, addr->name->sun_path, &parent, 0); + if (IS_ERR(dentry)) { + err = PTR_ERR(dentry); + goto out; + } + + /* + * All right, let's create it. + */ + ns = mnt_user_ns(parent.mnt); + err = security_path_mknod(&parent, dentry, mode, 0); + if (!err) + err = vfs_mknod(ns, d_inode(parent.dentry), dentry, mode, 0); + if (err) + goto out_path; + err = mutex_lock_interruptible(&u->bindlock); + if (err) + goto out_unlink; + if (u->addr) + goto out_unlock; + + new_hash = unix_bsd_hash(d_backing_inode(dentry)); + unix_table_double_lock(net, old_hash, new_hash); + u->path.mnt = mntget(parent.mnt); + u->path.dentry = dget(dentry); + __unix_set_addr_hash(net, sk, addr, new_hash); + unix_table_double_unlock(net, old_hash, new_hash); + unix_insert_bsd_socket(sk); + mutex_unlock(&u->bindlock); + done_path_create(&parent, dentry); + return 0; + +out_unlock: + mutex_unlock(&u->bindlock); + err = -EINVAL; +out_unlink: + /* failed after successful mknod? unlink what we'd created... */ + vfs_unlink(ns, d_inode(parent.dentry), dentry, NULL); +out_path: + done_path_create(&parent, dentry); +out: + unix_release_addr(addr); + return err == -EEXIST ? 
-EADDRINUSE : err; +} + +static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr, + int addr_len) +{ + unsigned int new_hash, old_hash = sk->sk_hash; + struct unix_sock *u = unix_sk(sk); + struct net *net = sock_net(sk); + struct unix_address *addr; + int err; + + addr = unix_create_addr(sunaddr, addr_len); + if (!addr) + return -ENOMEM; + + err = mutex_lock_interruptible(&u->bindlock); + if (err) + goto out; + + if (u->addr) { + err = -EINVAL; + goto out_mutex; + } + + new_hash = unix_abstract_hash(addr->name, addr->len, sk->sk_type); + unix_table_double_lock(net, old_hash, new_hash); + + if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash)) + goto out_spin; + + __unix_set_addr_hash(net, sk, addr, new_hash); + unix_table_double_unlock(net, old_hash, new_hash); + mutex_unlock(&u->bindlock); + return 0; + +out_spin: + unix_table_double_unlock(net, old_hash, new_hash); + err = -EADDRINUSE; +out_mutex: + mutex_unlock(&u->bindlock); +out: + unix_release_addr(addr); + return err; +} + +static int unix_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sockaddr_un *sunaddr = (struct sockaddr_un *)uaddr; + struct sock *sk = sock->sk; + int err; + + if (addr_len == offsetof(struct sockaddr_un, sun_path) && + sunaddr->sun_family == AF_UNIX) + return unix_autobind(sk); + + err = unix_validate_addr(sunaddr, addr_len); + if (err) + return err; + + if (sunaddr->sun_path[0]) + err = unix_bind_bsd(sk, sunaddr, addr_len); + else + err = unix_bind_abstract(sk, sunaddr, addr_len); + + return err; +} + +static void unix_state_double_lock(struct sock *sk1, struct sock *sk2) +{ + if (unlikely(sk1 == sk2) || !sk2) { + unix_state_lock(sk1); + return; + } + if (sk1 < sk2) { + unix_state_lock(sk1); + unix_state_lock_nested(sk2); + } else { + unix_state_lock(sk2); + unix_state_lock_nested(sk1); + } +} + +static void unix_state_double_unlock(struct sock *sk1, struct sock *sk2) +{ + if (unlikely(sk1 == sk2) || !sk2) { + unix_state_unlock(sk1); + return; + } + unix_state_unlock(sk1); + unix_state_unlock(sk2); +} + +static int unix_dgram_connect(struct socket *sock, struct sockaddr *addr, + int alen, int flags) +{ + struct sockaddr_un *sunaddr = (struct sockaddr_un *)addr; + struct sock *sk = sock->sk; + struct sock *other; + int err; + + err = -EINVAL; + if (alen < offsetofend(struct sockaddr, sa_family)) + goto out; + + if (addr->sa_family != AF_UNSPEC) { + err = unix_validate_addr(sunaddr, alen); + if (err) + goto out; + + if (test_bit(SOCK_PASSCRED, &sock->flags) && + !unix_sk(sk)->addr) { + err = unix_autobind(sk); + if (err) + goto out; + } + +restart: + other = unix_find_other(sock_net(sk), sunaddr, alen, sock->type); + if (IS_ERR(other)) { + err = PTR_ERR(other); + goto out; + } + + unix_state_double_lock(sk, other); + + /* Apparently VFS overslept socket death. Retry. */ + if (sock_flag(other, SOCK_DEAD)) { + unix_state_double_unlock(sk, other); + sock_put(other); + goto restart; + } + + err = -EPERM; + if (!unix_may_send(sk, other)) + goto out_unlock; + + err = security_unix_may_send(sk->sk_socket, other->sk_socket); + if (err) + goto out_unlock; + + sk->sk_state = other->sk_state = TCP_ESTABLISHED; + } else { + /* + * 1003.1g breaking connected state with AF_UNSPEC + */ + other = NULL; + unix_state_double_lock(sk, other); + } + + /* + * If it was connected, reconnect. 
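The 1003.1g behaviour handled above means a connected datagram socket can be both re-pointed at a new peer and disconnected entirely by connecting to AF_UNSPEC. A userspace sketch; illustrative only, and "/tmp/srv.sock" is a hypothetical path that would need a bound receiver:

/* Connect a datagram socket, then dissolve the association. */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>

int main(void)
{
	struct sockaddr_un srv = { .sun_family = AF_UNIX };
	struct sockaddr unspec = { .sa_family = AF_UNSPEC };
	int fd = socket(AF_UNIX, SOCK_DGRAM, 0);

	strncpy(srv.sun_path, "/tmp/srv.sock", sizeof(srv.sun_path) - 1);

	if (connect(fd, (struct sockaddr *)&srv, sizeof(srv)) < 0)
		perror("connect");

	/* Break the association: later send() calls need an explicit
	 * destination again, and datagrams queued from the old peer
	 * are dropped.
	 */
	if (connect(fd, &unspec, sizeof(unspec)) < 0)
		perror("disconnect");
	return 0;
}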
+ */ + if (unix_peer(sk)) { + struct sock *old_peer = unix_peer(sk); + + unix_peer(sk) = other; + if (!other) + sk->sk_state = TCP_CLOSE; + unix_dgram_peer_wake_disconnect_wakeup(sk, old_peer); + + unix_state_double_unlock(sk, other); + + if (other != old_peer) + unix_dgram_disconnected(sk, old_peer); + sock_put(old_peer); + } else { + unix_peer(sk) = other; + unix_state_double_unlock(sk, other); + } + + return 0; + +out_unlock: + unix_state_double_unlock(sk, other); + sock_put(other); +out: + return err; +} + +static long unix_wait_for_peer(struct sock *other, long timeo) + __releases(&unix_sk(other)->lock) +{ + struct unix_sock *u = unix_sk(other); + int sched; + DEFINE_WAIT(wait); + + prepare_to_wait_exclusive(&u->peer_wait, &wait, TASK_INTERRUPTIBLE); + + sched = !sock_flag(other, SOCK_DEAD) && + !(other->sk_shutdown & RCV_SHUTDOWN) && + unix_recvq_full_lockless(other); + + unix_state_unlock(other); + + if (sched) + timeo = schedule_timeout(timeo); + + finish_wait(&u->peer_wait, &wait); + return timeo; +} + +static int unix_stream_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sockaddr_un *sunaddr = (struct sockaddr_un *)uaddr; + struct sock *sk = sock->sk, *newsk = NULL, *other = NULL; + struct unix_sock *u = unix_sk(sk), *newu, *otheru; + struct net *net = sock_net(sk); + struct sk_buff *skb = NULL; + long timeo; + int err; + int st; + + err = unix_validate_addr(sunaddr, addr_len); + if (err) + goto out; + + if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) { + err = unix_autobind(sk); + if (err) + goto out; + } + + timeo = sock_sndtimeo(sk, flags & O_NONBLOCK); + + /* First of all allocate resources. + If we will make it after state is locked, + we will have to recheck all again in any case. + */ + + /* create new sock for complete connection */ + newsk = unix_create1(net, NULL, 0, sock->type); + if (IS_ERR(newsk)) { + err = PTR_ERR(newsk); + newsk = NULL; + goto out; + } + + err = -ENOMEM; + + /* Allocate skb for sending to listening sock */ + skb = sock_wmalloc(newsk, 1, 0, GFP_KERNEL); + if (skb == NULL) + goto out; + +restart: + /* Find listening sock. */ + other = unix_find_other(net, sunaddr, addr_len, sk->sk_type); + if (IS_ERR(other)) { + err = PTR_ERR(other); + other = NULL; + goto out; + } + + /* Latch state of peer */ + unix_state_lock(other); + + /* Apparently VFS overslept socket death. Retry. */ + if (sock_flag(other, SOCK_DEAD)) { + unix_state_unlock(other); + sock_put(other); + goto restart; + } + + err = -ECONNREFUSED; + if (other->sk_state != TCP_LISTEN) + goto out_unlock; + if (other->sk_shutdown & RCV_SHUTDOWN) + goto out_unlock; + + if (unix_recvq_full(other)) { + err = -EAGAIN; + if (!timeo) + goto out_unlock; + + timeo = unix_wait_for_peer(other, timeo); + + err = sock_intr_errno(timeo); + if (signal_pending(current)) + goto out; + sock_put(other); + goto restart; + } + + /* Latch our state. + + It is tricky place. We need to grab our state lock and cannot + drop lock on peer. It is dangerous because deadlock is + possible. Connect to self case and simultaneous + attempt to connect are eliminated by checking socket + state. other is TCP_LISTEN, if sk is TCP_LISTEN we + check this before attempt to grab lock. + + Well, and we have to recheck the state after socket locked. + */ + st = sk->sk_state; + + switch (st) { + case TCP_CLOSE: + /* This is ok... 
continue with connect */ + break; + case TCP_ESTABLISHED: + /* Socket is already connected */ + err = -EISCONN; + goto out_unlock; + default: + err = -EINVAL; + goto out_unlock; + } + + unix_state_lock_nested(sk); + + if (sk->sk_state != st) { + unix_state_unlock(sk); + unix_state_unlock(other); + sock_put(other); + goto restart; + } + + err = security_unix_stream_connect(sk, other, newsk); + if (err) { + unix_state_unlock(sk); + goto out_unlock; + } + + /* The way is open! Fastly set all the necessary fields... */ + + sock_hold(sk); + unix_peer(newsk) = sk; + newsk->sk_state = TCP_ESTABLISHED; + newsk->sk_type = sk->sk_type; + init_peercred(newsk); + newu = unix_sk(newsk); + RCU_INIT_POINTER(newsk->sk_wq, &newu->peer_wq); + otheru = unix_sk(other); + + /* copy address information from listening to new sock + * + * The contents of *(otheru->addr) and otheru->path + * are seen fully set up here, since we have found + * otheru in hash under its lock. Insertion into the + * hash chain we'd found it in had been done in an + * earlier critical area protected by the chain's lock, + * the same one where we'd set *(otheru->addr) contents, + * as well as otheru->path and otheru->addr itself. + * + * Using smp_store_release() here to set newu->addr + * is enough to make those stores, as well as stores + * to newu->path visible to anyone who gets newu->addr + * by smp_load_acquire(). IOW, the same warranties + * as for unix_sock instances bound in unix_bind() or + * in unix_autobind(). + */ + if (otheru->path.dentry) { + path_get(&otheru->path); + newu->path = otheru->path; + } + refcount_inc(&otheru->addr->refcnt); + smp_store_release(&newu->addr, otheru->addr); + + /* Set credentials */ + copy_peercred(sk, other); + + sock->state = SS_CONNECTED; + sk->sk_state = TCP_ESTABLISHED; + sock_hold(newsk); + + smp_mb__after_atomic(); /* sock_hold() does an atomic_inc() */ + unix_peer(sk) = newsk; + + unix_state_unlock(sk); + + /* take ten and send info to listening sock */ + spin_lock(&other->sk_receive_queue.lock); + __skb_queue_tail(&other->sk_receive_queue, skb); + spin_unlock(&other->sk_receive_queue.lock); + unix_state_unlock(other); + other->sk_data_ready(other); + sock_put(other); + return 0; + +out_unlock: + if (other) + unix_state_unlock(other); + +out: + kfree_skb(skb); + if (newsk) + unix_release_sock(newsk, 0); + if (other) + sock_put(other); + return err; +} + +static int unix_socketpair(struct socket *socka, struct socket *sockb) +{ + struct sock *ska = socka->sk, *skb = sockb->sk; + + /* Join our sockets back to back */ + sock_hold(ska); + sock_hold(skb); + unix_peer(ska) = skb; + unix_peer(skb) = ska; + init_peercred(ska); + init_peercred(skb); + + ska->sk_state = TCP_ESTABLISHED; + skb->sk_state = TCP_ESTABLISHED; + socka->state = SS_CONNECTED; + sockb->state = SS_CONNECTED; + return 0; +} + +static void unix_sock_inherit_flags(const struct socket *old, + struct socket *new) +{ + if (test_bit(SOCK_PASSCRED, &old->flags)) + set_bit(SOCK_PASSCRED, &new->flags); + if (test_bit(SOCK_PASSSEC, &old->flags)) + set_bit(SOCK_PASSSEC, &new->flags); +} + +static int unix_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *sk = sock->sk; + struct sock *tsk; + struct sk_buff *skb; + int err; + + err = -EOPNOTSUPP; + if (sock->type != SOCK_STREAM && sock->type != SOCK_SEQPACKET) + goto out; + + err = -EINVAL; + if (sk->sk_state != TCP_LISTEN) + goto out; + + /* If socket state is TCP_LISTEN it cannot change (for now...), + * so that no locks are necessary. 
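The connect/accept handshake implemented above corresponds to the familiar userspace sequence for pathname sockets. A minimal server-side sketch; illustrative only, with "/tmp/demo.sock" as a hypothetical path:

/* Minimal AF_UNIX stream server: bind to a pathname, listen, accept. */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

int main(void)
{
	struct sockaddr_un addr = { .sun_family = AF_UNIX };
	int lfd = socket(AF_UNIX, SOCK_STREAM, 0);
	int cfd;

	strncpy(addr.sun_path, "/tmp/demo.sock", sizeof(addr.sun_path) - 1);
	unlink(addr.sun_path);		/* otherwise bind() fails with EADDRINUSE */

	if (bind(lfd, (struct sockaddr *)&addr, sizeof(addr)) < 0 ||
	    listen(lfd, 16) < 0) {
		perror("bind/listen");
		return 1;
	}

	cfd = accept(lfd, NULL, NULL);	/* blocks until a client connects */
	if (cfd >= 0) {
		write(cfd, "hi\n", 3);
		close(cfd);
	}
	close(lfd);
	return 0;
}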
+ */ + + skb = skb_recv_datagram(sk, (flags & O_NONBLOCK) ? MSG_DONTWAIT : 0, + &err); + if (!skb) { + /* This means receive shutdown. */ + if (err == 0) + err = -EINVAL; + goto out; + } + + tsk = skb->sk; + skb_free_datagram(sk, skb); + wake_up_interruptible(&unix_sk(sk)->peer_wait); + + /* attach accepted sock to socket */ + unix_state_lock(tsk); + newsock->state = SS_CONNECTED; + unix_sock_inherit_flags(sock, newsock); + sock_graft(tsk, newsock); + unix_state_unlock(tsk); + return 0; + +out: + return err; +} + + +static int unix_getname(struct socket *sock, struct sockaddr *uaddr, int peer) +{ + struct sock *sk = sock->sk; + struct unix_address *addr; + DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, uaddr); + int err = 0; + + if (peer) { + sk = unix_peer_get(sk); + + err = -ENOTCONN; + if (!sk) + goto out; + err = 0; + } else { + sock_hold(sk); + } + + addr = smp_load_acquire(&unix_sk(sk)->addr); + if (!addr) { + sunaddr->sun_family = AF_UNIX; + sunaddr->sun_path[0] = 0; + err = offsetof(struct sockaddr_un, sun_path); + } else { + err = addr->len; + memcpy(sunaddr, addr->name, addr->len); + } + sock_put(sk); +out: + return err; +} + +static void unix_peek_fds(struct scm_cookie *scm, struct sk_buff *skb) +{ + scm->fp = scm_fp_dup(UNIXCB(skb).fp); + + /* + * Garbage collection of unix sockets starts by selecting a set of + * candidate sockets which have reference only from being in flight + * (total_refs == inflight_refs). This condition is checked once during + * the candidate collection phase, and candidates are marked as such, so + * that non-candidates can later be ignored. While inflight_refs is + * protected by unix_gc_lock, total_refs (file count) is not, hence this + * is an instantaneous decision. + * + * Once a candidate, however, the socket must not be reinstalled into a + * file descriptor while the garbage collection is in progress. + * + * If the above conditions are met, then the directed graph of + * candidates (*) does not change while unix_gc_lock is held. + * + * Any operations that changes the file count through file descriptors + * (dup, close, sendmsg) does not change the graph since candidates are + * not installed in fds. + * + * Dequeing a candidate via recvmsg would install it into an fd, but + * that takes unix_gc_lock to decrement the inflight count, so it's + * serialized with garbage collection. + * + * MSG_PEEK is special in that it does not change the inflight count, + * yet does install the socket into an fd. The following lock/unlock + * pair is to ensure serialization with garbage collection. It must be + * done between incrementing the file count and installing the file into + * an fd. + * + * If garbage collection starts after the barrier provided by the + * lock/unlock, then it will see the elevated refcount and not mark this + * as a candidate. If a garbage collection is already in progress + * before the file count was incremented, then the lock/unlock pair will + * ensure that garbage collection is finished before progressing to + * installing the fd. + * + * (*) A -> B where B is on the queue of A or B is on the queue of C + * which is on the queue of listening socket A. 
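The garbage-collection rules discussed above exist because file descriptors travel between sockets inside SCM_RIGHTS control messages. A userspace sketch of the sending side; illustrative only, and send_fd() is a hypothetical helper name:

/* Pass a file descriptor with SCM_RIGHTS over a connected AF_UNIX socket. */
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

static int send_fd(int sock, int fd_to_send)
{
	char dummy = 'x';
	struct iovec iov = { .iov_base = &dummy, .iov_len = 1 };
	union {
		char buf[CMSG_SPACE(sizeof(int))];
		struct cmsghdr align;
	} u;
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = u.buf,
		.msg_controllen = sizeof(u.buf),
	};
	struct cmsghdr *cmsg;

	memset(u.buf, 0, sizeof(u.buf));
	cmsg = CMSG_FIRSTHDR(&msg);
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_len = CMSG_LEN(sizeof(int));
	memcpy(CMSG_DATA(cmsg), &fd_to_send, sizeof(int));

	return sendmsg(sock, &msg, 0) < 0 ? -1 : 0;
}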
+ */ + spin_lock(&unix_gc_lock); + spin_unlock(&unix_gc_lock); +} + +static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool send_fds) +{ + int err = 0; + + UNIXCB(skb).pid = get_pid(scm->pid); + UNIXCB(skb).uid = scm->creds.uid; + UNIXCB(skb).gid = scm->creds.gid; + UNIXCB(skb).fp = NULL; + unix_get_secdata(scm, skb); + if (scm->fp && send_fds) + err = unix_attach_fds(scm, skb); + + skb->destructor = unix_destruct_scm; + return err; +} + +static bool unix_passcred_enabled(const struct socket *sock, + const struct sock *other) +{ + return test_bit(SOCK_PASSCRED, &sock->flags) || + !other->sk_socket || + test_bit(SOCK_PASSCRED, &other->sk_socket->flags); +} + +/* + * Some apps rely on write() giving SCM_CREDENTIALS + * We include credentials if source or destination socket + * asserted SOCK_PASSCRED. + */ +static void maybe_add_creds(struct sk_buff *skb, const struct socket *sock, + const struct sock *other) +{ + if (UNIXCB(skb).pid) + return; + if (unix_passcred_enabled(sock, other)) { + UNIXCB(skb).pid = get_pid(task_tgid(current)); + current_uid_gid(&UNIXCB(skb).uid, &UNIXCB(skb).gid); + } +} + +static int maybe_init_creds(struct scm_cookie *scm, + struct socket *socket, + const struct sock *other) +{ + int err; + struct msghdr msg = { .msg_controllen = 0 }; + + err = scm_send(socket, &msg, scm, false); + if (err) + return err; + + if (unix_passcred_enabled(socket, other)) { + scm->pid = get_pid(task_tgid(current)); + current_uid_gid(&scm->creds.uid, &scm->creds.gid); + } + return err; +} + +static bool unix_skb_scm_eq(struct sk_buff *skb, + struct scm_cookie *scm) +{ + return UNIXCB(skb).pid == scm->pid && + uid_eq(UNIXCB(skb).uid, scm->creds.uid) && + gid_eq(UNIXCB(skb).gid, scm->creds.gid) && + unix_secdata_eq(scm, skb); +} + +static void scm_stat_add(struct sock *sk, struct sk_buff *skb) +{ + struct scm_fp_list *fp = UNIXCB(skb).fp; + struct unix_sock *u = unix_sk(sk); + + if (unlikely(fp && fp->count)) + atomic_add(fp->count, &u->scm_stat.nr_fds); +} + +static void scm_stat_del(struct sock *sk, struct sk_buff *skb) +{ + struct scm_fp_list *fp = UNIXCB(skb).fp; + struct unix_sock *u = unix_sk(sk); + + if (unlikely(fp && fp->count)) + atomic_sub(fp->count, &u->scm_stat.nr_fds); +} + +/* + * Send AF_UNIX data. 
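Credentials are attached when either end has asserted SOCK_PASSCRED, as maybe_add_creds() above implements; the receiver sees them as an SCM_CREDENTIALS control message. A userspace sketch of the receiving side; illustrative only, and recv_with_creds() is a hypothetical helper name:

/* Receive SCM_CREDENTIALS: enable SO_PASSCRED, then parse the cmsg. */
#define _GNU_SOURCE		/* for struct ucred */
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>

static void recv_with_creds(int sock)
{
	char data[128];
	char cbuf[CMSG_SPACE(sizeof(struct ucred))];
	struct iovec iov = { .iov_base = data, .iov_len = sizeof(data) };
	struct msghdr msg = {
		.msg_iov = &iov, .msg_iovlen = 1,
		.msg_control = cbuf, .msg_controllen = sizeof(cbuf),
	};
	struct cmsghdr *cmsg;
	int on = 1;

	setsockopt(sock, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on));

	if (recvmsg(sock, &msg, 0) < 0)
		return;

	for (cmsg = CMSG_FIRSTHDR(&msg); cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		if (cmsg->cmsg_level == SOL_SOCKET &&
		    cmsg->cmsg_type == SCM_CREDENTIALS) {
			struct ucred uc;

			memcpy(&uc, CMSG_DATA(cmsg), sizeof(uc));
			printf("pid=%d uid=%u gid=%u\n", uc.pid, uc.uid, uc.gid);
		}
	}
}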
+ */ + +static int unix_dgram_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, msg->msg_name); + struct sock *sk = sock->sk, *other = NULL; + struct unix_sock *u = unix_sk(sk); + struct scm_cookie scm; + struct sk_buff *skb; + int data_len = 0; + int sk_locked; + long timeo; + int err; + + wait_for_unix_gc(); + err = scm_send(sock, msg, &scm, false); + if (err < 0) + return err; + + err = -EOPNOTSUPP; + if (msg->msg_flags&MSG_OOB) + goto out; + + if (msg->msg_namelen) { + err = unix_validate_addr(sunaddr, msg->msg_namelen); + if (err) + goto out; + } else { + sunaddr = NULL; + err = -ENOTCONN; + other = unix_peer_get(sk); + if (!other) + goto out; + } + + if (test_bit(SOCK_PASSCRED, &sock->flags) && !u->addr) { + err = unix_autobind(sk); + if (err) + goto out; + } + + err = -EMSGSIZE; + if (len > sk->sk_sndbuf - 32) + goto out; + + if (len > SKB_MAX_ALLOC) { + data_len = min_t(size_t, + len - SKB_MAX_ALLOC, + MAX_SKB_FRAGS * PAGE_SIZE); + data_len = PAGE_ALIGN(data_len); + + BUILD_BUG_ON(SKB_MAX_ALLOC < PAGE_SIZE); + } + + skb = sock_alloc_send_pskb(sk, len - data_len, data_len, + msg->msg_flags & MSG_DONTWAIT, &err, + PAGE_ALLOC_COSTLY_ORDER); + if (skb == NULL) + goto out; + + err = unix_scm_to_skb(&scm, skb, true); + if (err < 0) + goto out_free; + + skb_put(skb, len - data_len); + skb->data_len = data_len; + skb->len = len; + err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, len); + if (err) + goto out_free; + + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + +restart: + if (!other) { + err = -ECONNRESET; + if (sunaddr == NULL) + goto out_free; + + other = unix_find_other(sock_net(sk), sunaddr, msg->msg_namelen, + sk->sk_type); + if (IS_ERR(other)) { + err = PTR_ERR(other); + other = NULL; + goto out_free; + } + } + + if (sk_filter(other, skb) < 0) { + /* Toss the packet but do not return any error to the sender */ + err = len; + goto out_free; + } + + sk_locked = 0; + unix_state_lock(other); +restart_locked: + err = -EPERM; + if (!unix_may_send(sk, other)) + goto out_unlock; + + if (unlikely(sock_flag(other, SOCK_DEAD))) { + /* + * Check with 1003.1g - what should + * datagram error + */ + unix_state_unlock(other); + sock_put(other); + + if (!sk_locked) + unix_state_lock(sk); + + err = 0; + if (sk->sk_type == SOCK_SEQPACKET) { + /* We are here only when racing with unix_release_sock() + * is clearing @other. Never change state to TCP_CLOSE + * unlike SOCK_DGRAM wants. 
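From userspace, the SOCK_DEAD handling around here is visible as ECONNREFUSED when a connected SOCK_DGRAM sender's peer has gone away (SOCK_SEQPACKET gets EPIPE instead). A small sketch; illustrative only, and the expected errno is inferred from the code above, so treat it as an assumption:

/* A connected SOCK_DGRAM sender whose peer has closed gets an error. */
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	int fds[2];

	socketpair(AF_UNIX, SOCK_DGRAM, 0, fds);
	close(fds[1]);			/* the peer disappears */

	if (send(fds[0], "x", 1, 0) < 0)
		printf("send: %s\n", strerror(errno));	/* expect ECONNREFUSED */

	close(fds[0]);
	return 0;
}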
+ */ + unix_state_unlock(sk); + err = -EPIPE; + } else if (unix_peer(sk) == other) { + unix_peer(sk) = NULL; + unix_dgram_peer_wake_disconnect_wakeup(sk, other); + + sk->sk_state = TCP_CLOSE; + unix_state_unlock(sk); + + unix_dgram_disconnected(sk, other); + sock_put(other); + err = -ECONNREFUSED; + } else { + unix_state_unlock(sk); + } + + other = NULL; + if (err) + goto out_free; + goto restart; + } + + err = -EPIPE; + if (other->sk_shutdown & RCV_SHUTDOWN) + goto out_unlock; + + if (sk->sk_type != SOCK_SEQPACKET) { + err = security_unix_may_send(sk->sk_socket, other->sk_socket); + if (err) + goto out_unlock; + } + + /* other == sk && unix_peer(other) != sk if + * - unix_peer(sk) == NULL, destination address bound to sk + * - unix_peer(sk) == sk by time of get but disconnected before lock + */ + if (other != sk && + unlikely(unix_peer(other) != sk && + unix_recvq_full_lockless(other))) { + if (timeo) { + timeo = unix_wait_for_peer(other, timeo); + + err = sock_intr_errno(timeo); + if (signal_pending(current)) + goto out_free; + + goto restart; + } + + if (!sk_locked) { + unix_state_unlock(other); + unix_state_double_lock(sk, other); + } + + if (unix_peer(sk) != other || + unix_dgram_peer_wake_me(sk, other)) { + err = -EAGAIN; + sk_locked = 1; + goto out_unlock; + } + + if (!sk_locked) { + sk_locked = 1; + goto restart_locked; + } + } + + if (unlikely(sk_locked)) + unix_state_unlock(sk); + + if (sock_flag(other, SOCK_RCVTSTAMP)) + __net_timestamp(skb); + maybe_add_creds(skb, sock, other); + scm_stat_add(other, skb); + skb_queue_tail(&other->sk_receive_queue, skb); + unix_state_unlock(other); + other->sk_data_ready(other); + sock_put(other); + scm_destroy(&scm); + return len; + +out_unlock: + if (sk_locked) + unix_state_unlock(sk); + unix_state_unlock(other); +out_free: + kfree_skb(skb); +out: + if (other) + sock_put(other); + scm_destroy(&scm); + return err; +} + +/* We use paged skbs for stream sockets, and limit occupancy to 32768 + * bytes, and a minimum of a full page. 
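With CONFIG_AF_UNIX_OOB (default y per the Kconfig above), stream sockets accept a single out-of-band byte, handled by queue_oob() in the block that follows. A userspace sketch; illustrative only, assuming a kernel with the option enabled:

/* Send and receive one out-of-band byte on an AF_UNIX stream pair. */
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	int fds[2];
	char c = 0;

	if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) < 0)
		return 1;

	send(fds[0], "ab", 2, 0);		/* in-band data */
	send(fds[0], "!", 1, MSG_OOB);		/* the OOB byte */

	/* The OOB byte is fetched separately unless SO_OOBINLINE is set. */
	if (recv(fds[1], &c, 1, MSG_OOB) == 1)
		printf("OOB byte: %c\n", c);

	close(fds[0]);
	close(fds[1]);
	return 0;
}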
+ */ +#define UNIX_SKB_FRAGS_SZ (PAGE_SIZE << get_order(32768)) + +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) +static int queue_oob(struct socket *sock, struct msghdr *msg, struct sock *other, + struct scm_cookie *scm, bool fds_sent) +{ + struct unix_sock *ousk = unix_sk(other); + struct sk_buff *skb; + int err = 0; + + skb = sock_alloc_send_skb(sock->sk, 1, msg->msg_flags & MSG_DONTWAIT, &err); + + if (!skb) + return err; + + err = unix_scm_to_skb(scm, skb, !fds_sent); + if (err < 0) { + kfree_skb(skb); + return err; + } + skb_put(skb, 1); + err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, 1); + + if (err) { + kfree_skb(skb); + return err; + } + + unix_state_lock(other); + + if (sock_flag(other, SOCK_DEAD) || + (other->sk_shutdown & RCV_SHUTDOWN)) { + unix_state_unlock(other); + kfree_skb(skb); + return -EPIPE; + } + + maybe_add_creds(skb, sock, other); + skb_get(skb); + + if (ousk->oob_skb) + consume_skb(ousk->oob_skb); + + WRITE_ONCE(ousk->oob_skb, skb); + + scm_stat_add(other, skb); + skb_queue_tail(&other->sk_receive_queue, skb); + sk_send_sigurg(other); + unix_state_unlock(other); + other->sk_data_ready(other); + + return err; +} +#endif + +static int unix_stream_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk = sock->sk; + struct sock *other = NULL; + int err, size; + struct sk_buff *skb; + int sent = 0; + struct scm_cookie scm; + bool fds_sent = false; + int data_len; + + wait_for_unix_gc(); + err = scm_send(sock, msg, &scm, false); + if (err < 0) + return err; + + err = -EOPNOTSUPP; + if (msg->msg_flags & MSG_OOB) { +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + if (len) + len--; + else +#endif + goto out_err; + } + + if (msg->msg_namelen) { + err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP; + goto out_err; + } else { + err = -ENOTCONN; + other = unix_peer(sk); + if (!other) + goto out_err; + } + + if (sk->sk_shutdown & SEND_SHUTDOWN) + goto pipe_err; + + while (sent < len) { + size = len - sent; + + /* Keep two messages in the pipe so it schedules better */ + size = min_t(int, size, (sk->sk_sndbuf >> 1) - 64); + + /* allow fallback to order-0 allocations */ + size = min_t(int, size, SKB_MAX_HEAD(0) + UNIX_SKB_FRAGS_SZ); + + data_len = max_t(int, 0, size - SKB_MAX_HEAD(0)); + + data_len = min_t(size_t, size, PAGE_ALIGN(data_len)); + + skb = sock_alloc_send_pskb(sk, size - data_len, data_len, + msg->msg_flags & MSG_DONTWAIT, &err, + get_order(UNIX_SKB_FRAGS_SZ)); + if (!skb) + goto out_err; + + /* Only send the fds in the first buffer */ + err = unix_scm_to_skb(&scm, skb, !fds_sent); + if (err < 0) { + kfree_skb(skb); + goto out_err; + } + fds_sent = true; + + skb_put(skb, size - data_len); + skb->data_len = data_len; + skb->len = size; + err = skb_copy_datagram_from_iter(skb, 0, &msg->msg_iter, size); + if (err) { + kfree_skb(skb); + goto out_err; + } + + unix_state_lock(other); + + if (sock_flag(other, SOCK_DEAD) || + (other->sk_shutdown & RCV_SHUTDOWN)) + goto pipe_err_free; + + maybe_add_creds(skb, sock, other); + scm_stat_add(other, skb); + skb_queue_tail(&other->sk_receive_queue, skb); + unix_state_unlock(other); + other->sk_data_ready(other); + sent += size; + } + +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + if (msg->msg_flags & MSG_OOB) { + err = queue_oob(sock, msg, other, &scm, fds_sent); + if (err) + goto out_err; + sent++; + } +#endif + + scm_destroy(&scm); + + return sent; + +pipe_err_free: + unix_state_unlock(other); + kfree_skb(skb); +pipe_err: + if (sent == 0 && !(msg->msg_flags&MSG_NOSIGNAL)) + send_sig(SIGPIPE, current, 0); + 
err = -EPIPE; +out_err: + scm_destroy(&scm); + return sent ? : err; +} + +static ssize_t unix_stream_sendpage(struct socket *socket, struct page *page, + int offset, size_t size, int flags) +{ + int err; + bool send_sigpipe = false; + bool init_scm = true; + struct scm_cookie scm; + struct sock *other, *sk = socket->sk; + struct sk_buff *skb, *newskb = NULL, *tail = NULL; + + if (flags & MSG_OOB) + return -EOPNOTSUPP; + + other = unix_peer(sk); + if (!other || sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + if (false) { +alloc_skb: + spin_unlock(&other->sk_receive_queue.lock); + unix_state_unlock(other); + mutex_unlock(&unix_sk(other)->iolock); + newskb = sock_alloc_send_pskb(sk, 0, 0, flags & MSG_DONTWAIT, + &err, 0); + if (!newskb) + goto err; + } + + /* we must acquire iolock as we modify already present + * skbs in the sk_receive_queue and mess with skb->len + */ + err = mutex_lock_interruptible(&unix_sk(other)->iolock); + if (err) { + err = flags & MSG_DONTWAIT ? -EAGAIN : -ERESTARTSYS; + goto err; + } + + if (sk->sk_shutdown & SEND_SHUTDOWN) { + err = -EPIPE; + send_sigpipe = true; + goto err_unlock; + } + + unix_state_lock(other); + + if (sock_flag(other, SOCK_DEAD) || + other->sk_shutdown & RCV_SHUTDOWN) { + err = -EPIPE; + send_sigpipe = true; + goto err_state_unlock; + } + + if (init_scm) { + err = maybe_init_creds(&scm, socket, other); + if (err) + goto err_state_unlock; + init_scm = false; + } + + spin_lock(&other->sk_receive_queue.lock); + skb = skb_peek_tail(&other->sk_receive_queue); + if (tail && tail == skb) { + skb = newskb; + } else if (!skb || !unix_skb_scm_eq(skb, &scm)) { + if (newskb) { + skb = newskb; + } else { + tail = skb; + goto alloc_skb; + } + } else if (newskb) { + /* this is fast path, we don't necessarily need to + * call to kfree_skb even though with newskb == NULL + * this - does no harm + */ + consume_skb(newskb); + newskb = NULL; + } + + if (skb_append_pagefrags(skb, page, offset, size)) { + tail = skb; + goto alloc_skb; + } + + skb->len += size; + skb->data_len += size; + skb->truesize += size; + refcount_add(size, &sk->sk_wmem_alloc); + + if (newskb) { + unix_scm_to_skb(&scm, skb, false); + __skb_queue_tail(&other->sk_receive_queue, newskb); + } + + spin_unlock(&other->sk_receive_queue.lock); + unix_state_unlock(other); + mutex_unlock(&unix_sk(other)->iolock); + + other->sk_data_ready(other); + scm_destroy(&scm); + return size; + +err_state_unlock: + unix_state_unlock(other); +err_unlock: + mutex_unlock(&unix_sk(other)->iolock); +err: + kfree_skb(newskb); + if (send_sigpipe && !(flags & MSG_NOSIGNAL)) + send_sig(SIGPIPE, current, 0); + if (!init_scm) + scm_destroy(&scm); + return err; +} + +static int unix_seqpacket_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + int err; + struct sock *sk = sock->sk; + + err = sock_error(sk); + if (err) + return err; + + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + if (msg->msg_namelen) + msg->msg_namelen = 0; + + return unix_dgram_sendmsg(sock, msg, len); +} + +static int unix_seqpacket_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct sock *sk = sock->sk; + + if (sk->sk_state != TCP_ESTABLISHED) + return -ENOTCONN; + + return unix_dgram_recvmsg(sock, msg, size, flags); +} + +static void unix_copy_addr(struct msghdr *msg, struct sock *sk) +{ + struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr); + + if (addr) { + msg->msg_namelen = addr->len; + memcpy(msg->msg_name, addr->name, addr->len); + } +} + +int 
__unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size, + int flags) +{ + struct scm_cookie scm; + struct socket *sock = sk->sk_socket; + struct unix_sock *u = unix_sk(sk); + struct sk_buff *skb, *last; + long timeo; + int skip; + int err; + + err = -EOPNOTSUPP; + if (flags&MSG_OOB) + goto out; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + do { + mutex_lock(&u->iolock); + + skip = sk_peek_offset(sk, flags); + skb = __skb_try_recv_datagram(sk, &sk->sk_receive_queue, flags, + &skip, &err, &last); + if (skb) { + if (!(flags & MSG_PEEK)) + scm_stat_del(sk, skb); + break; + } + + mutex_unlock(&u->iolock); + + if (err != -EAGAIN) + break; + } while (timeo && + !__skb_wait_for_more_packets(sk, &sk->sk_receive_queue, + &err, &timeo, last)); + + if (!skb) { /* implies iolock unlocked */ + unix_state_lock(sk); + /* Signal EOF on disconnected non-blocking SEQPACKET socket. */ + if (sk->sk_type == SOCK_SEQPACKET && err == -EAGAIN && + (sk->sk_shutdown & RCV_SHUTDOWN)) + err = 0; + unix_state_unlock(sk); + goto out; + } + + if (wq_has_sleeper(&u->peer_wait)) + wake_up_interruptible_sync_poll(&u->peer_wait, + EPOLLOUT | EPOLLWRNORM | + EPOLLWRBAND); + + if (msg->msg_name) + unix_copy_addr(msg, skb->sk); + + if (size > skb->len - skip) + size = skb->len - skip; + else if (size < skb->len - skip) + msg->msg_flags |= MSG_TRUNC; + + err = skb_copy_datagram_msg(skb, skip, msg, size); + if (err) + goto out_free; + + if (sock_flag(sk, SOCK_RCVTSTAMP)) + __sock_recv_timestamp(msg, sk, skb); + + memset(&scm, 0, sizeof(scm)); + + scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid); + unix_set_secdata(&scm, skb); + + if (!(flags & MSG_PEEK)) { + if (UNIXCB(skb).fp) + unix_detach_fds(&scm, skb); + + sk_peek_offset_bwd(sk, skb->len); + } else { + /* It is questionable: on PEEK we could: + - do not return fds - good, but too simple 8) + - return fds, and do not return them on read (old strategy, + apparently wrong) + - clone fds (I chose it for now, it is the most universal + solution) + + POSIX 1003.1g does not actually define this clearly + at all. POSIX 1003.1g doesn't define a lot of things + clearly however! + + */ + + sk_peek_offset_fwd(sk, size); + + if (UNIXCB(skb).fp) + unix_peek_fds(&scm, skb); + } + err = (flags & MSG_TRUNC) ? skb->len - skip : size; + + scm_recv(sock, msg, &scm, flags); + +out_free: + skb_free_datagram(sk, skb); + mutex_unlock(&u->iolock); +out: + return err; +} + +static int unix_dgram_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + +#ifdef CONFIG_BPF_SYSCALL + const struct proto *prot = READ_ONCE(sk->sk_prot); + + if (prot != &unix_dgram_proto) + return prot->recvmsg(sk, msg, size, flags, NULL); +#endif + return __unix_dgram_recvmsg(sk, msg, size, flags); +} + +static int unix_read_skb(struct sock *sk, skb_read_actor_t recv_actor) +{ + struct unix_sock *u = unix_sk(sk); + struct sk_buff *skb; + int err; + + mutex_lock(&u->iolock); + skb = skb_recv_datagram(sk, MSG_DONTWAIT, &err); + mutex_unlock(&u->iolock); + if (!skb) + return err; + + return recv_actor(sk, skb); +} + +/* + * Sleep until more data has arrived. But check for races.. 
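The peek-offset bookkeeping in the receive path above (sk_peek_offset_fwd/bwd) backs the SO_PEEK_OFF socket option that unix_set_peek_off() implements. A userspace sketch of peeking through a stream without consuming it; illustrative only, and the exact offset rewind on a consuming read is inferred from the code above:

/* SO_PEEK_OFF: successive MSG_PEEK reads advance through the queue
 * without consuming it; a normal read still sees all the data.
 */
#include <sys/socket.h>
#include <unistd.h>

int main(void)
{
	int fds[2], off = 0;
	char buf[4];

	socketpair(AF_UNIX, SOCK_STREAM, 0, fds);
	send(fds[0], "abcdef", 6, 0);

	setsockopt(fds[1], SOL_SOCKET, SO_PEEK_OFF, &off, sizeof(off));

	recv(fds[1], buf, 3, MSG_PEEK);	/* peeks "abc", offset advances to 3 */
	recv(fds[1], buf, 3, MSG_PEEK);	/* peeks "def", offset advances to 6 */
	recv(fds[1], buf, 3, 0);	/* consumes "abc", offset moves back */

	close(fds[0]);
	close(fds[1]);
	return 0;
}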
+ */ +static long unix_stream_data_wait(struct sock *sk, long timeo, + struct sk_buff *last, unsigned int last_len, + bool freezable) +{ + unsigned int state = TASK_INTERRUPTIBLE | freezable * TASK_FREEZABLE; + struct sk_buff *tail; + DEFINE_WAIT(wait); + + unix_state_lock(sk); + + for (;;) { + prepare_to_wait(sk_sleep(sk), &wait, state); + + tail = skb_peek_tail(&sk->sk_receive_queue); + if (tail != last || + (tail && tail->len != last_len) || + sk->sk_err || + (sk->sk_shutdown & RCV_SHUTDOWN) || + signal_pending(current) || + !timeo) + break; + + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + unix_state_unlock(sk); + timeo = schedule_timeout(timeo); + unix_state_lock(sk); + + if (sock_flag(sk, SOCK_DEAD)) + break; + + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + } + + finish_wait(sk_sleep(sk), &wait); + unix_state_unlock(sk); + return timeo; +} + +static unsigned int unix_skb_len(const struct sk_buff *skb) +{ + return skb->len - UNIXCB(skb).consumed; +} + +struct unix_stream_read_state { + int (*recv_actor)(struct sk_buff *, int, int, + struct unix_stream_read_state *); + struct socket *socket; + struct msghdr *msg; + struct pipe_inode_info *pipe; + size_t size; + int flags; + unsigned int splice_flags; +}; + +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) +static int unix_stream_recv_urg(struct unix_stream_read_state *state) +{ + struct socket *sock = state->socket; + struct sock *sk = sock->sk; + struct unix_sock *u = unix_sk(sk); + int chunk = 1; + struct sk_buff *oob_skb; + + mutex_lock(&u->iolock); + unix_state_lock(sk); + + if (sock_flag(sk, SOCK_URGINLINE) || !u->oob_skb) { + unix_state_unlock(sk); + mutex_unlock(&u->iolock); + return -EINVAL; + } + + oob_skb = u->oob_skb; + + if (!(state->flags & MSG_PEEK)) + WRITE_ONCE(u->oob_skb, NULL); + else + skb_get(oob_skb); + unix_state_unlock(sk); + + chunk = state->recv_actor(oob_skb, 0, chunk, state); + + if (!(state->flags & MSG_PEEK)) + UNIXCB(oob_skb).consumed += 1; + + consume_skb(oob_skb); + + mutex_unlock(&u->iolock); + + if (chunk < 0) + return -EFAULT; + + state->msg->msg_flags |= MSG_OOB; + return 1; +} + +static struct sk_buff *manage_oob(struct sk_buff *skb, struct sock *sk, + int flags, int copied) +{ + struct unix_sock *u = unix_sk(sk); + + if (!unix_skb_len(skb) && !(flags & MSG_PEEK)) { + skb_unlink(skb, &sk->sk_receive_queue); + consume_skb(skb); + skb = NULL; + } else { + if (skb == u->oob_skb) { + if (copied) { + skb = NULL; + } else if (sock_flag(sk, SOCK_URGINLINE)) { + if (!(flags & MSG_PEEK)) { + WRITE_ONCE(u->oob_skb, NULL); + consume_skb(skb); + } + } else if (!(flags & MSG_PEEK)) { + skb_unlink(skb, &sk->sk_receive_queue); + consume_skb(skb); + skb = skb_peek(&sk->sk_receive_queue); + } + } + } + return skb; +} +#endif + +static int unix_stream_read_skb(struct sock *sk, skb_read_actor_t recv_actor) +{ + if (unlikely(sk->sk_state != TCP_ESTABLISHED)) + return -ENOTCONN; + + return unix_read_skb(sk, recv_actor); +} + +static int unix_stream_read_generic(struct unix_stream_read_state *state, + bool freezable) +{ + struct scm_cookie scm; + struct socket *sock = state->socket; + struct sock *sk = sock->sk; + struct unix_sock *u = unix_sk(sk); + int copied = 0; + int flags = state->flags; + int noblock = flags & MSG_DONTWAIT; + bool check_creds = false; + int target; + int err = 0; + long timeo; + int skip; + size_t size = state->size; + unsigned int last_len; + + if (unlikely(sk->sk_state != TCP_ESTABLISHED)) { + err = -EINVAL; + goto out; + } + + if (unlikely(flags & MSG_OOB)) { + err = -EOPNOTSUPP; +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) 
+ err = unix_stream_recv_urg(state); +#endif + goto out; + } + + target = sock_rcvlowat(sk, flags & MSG_WAITALL, size); + timeo = sock_rcvtimeo(sk, noblock); + + memset(&scm, 0, sizeof(scm)); + + /* Lock the socket to prevent queue disordering + * while sleeps in memcpy_tomsg + */ + mutex_lock(&u->iolock); + + skip = max(sk_peek_offset(sk, flags), 0); + + do { + int chunk; + bool drop_skb; + struct sk_buff *skb, *last; + +redo: + unix_state_lock(sk); + if (sock_flag(sk, SOCK_DEAD)) { + err = -ECONNRESET; + goto unlock; + } + last = skb = skb_peek(&sk->sk_receive_queue); + last_len = last ? last->len : 0; + +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + if (skb) { + skb = manage_oob(skb, sk, flags, copied); + if (!skb) { + unix_state_unlock(sk); + if (copied) + break; + goto redo; + } + } +#endif +again: + if (skb == NULL) { + if (copied >= target) + goto unlock; + + /* + * POSIX 1003.1g mandates this order. + */ + + err = sock_error(sk); + if (err) + goto unlock; + if (sk->sk_shutdown & RCV_SHUTDOWN) + goto unlock; + + unix_state_unlock(sk); + if (!timeo) { + err = -EAGAIN; + break; + } + + mutex_unlock(&u->iolock); + + timeo = unix_stream_data_wait(sk, timeo, last, + last_len, freezable); + + if (signal_pending(current)) { + err = sock_intr_errno(timeo); + scm_destroy(&scm); + goto out; + } + + mutex_lock(&u->iolock); + goto redo; +unlock: + unix_state_unlock(sk); + break; + } + + while (skip >= unix_skb_len(skb)) { + skip -= unix_skb_len(skb); + last = skb; + last_len = skb->len; + skb = skb_peek_next(skb, &sk->sk_receive_queue); + if (!skb) + goto again; + } + + unix_state_unlock(sk); + + if (check_creds) { + /* Never glue messages from different writers */ + if (!unix_skb_scm_eq(skb, &scm)) + break; + } else if (test_bit(SOCK_PASSCRED, &sock->flags)) { + /* Copy credentials */ + scm_set_cred(&scm, UNIXCB(skb).pid, UNIXCB(skb).uid, UNIXCB(skb).gid); + unix_set_secdata(&scm, skb); + check_creds = true; + } + + /* Copy address just once */ + if (state->msg && state->msg->msg_name) { + DECLARE_SOCKADDR(struct sockaddr_un *, sunaddr, + state->msg->msg_name); + unix_copy_addr(state->msg, skb->sk); + sunaddr = NULL; + } + + chunk = min_t(unsigned int, unix_skb_len(skb) - skip, size); + skb_get(skb); + chunk = state->recv_actor(skb, skip, chunk, state); + drop_skb = !unix_skb_len(skb); + /* skb is only safe to use if !drop_skb */ + consume_skb(skb); + if (chunk < 0) { + if (copied == 0) + copied = -EFAULT; + break; + } + copied += chunk; + size -= chunk; + + if (drop_skb) { + /* the skb was touched by a concurrent reader; + * we should not expect anything from this skb + * anymore and assume it invalid - we can be + * sure it was dropped from the socket queue + * + * let's report a short read + */ + err = 0; + break; + } + + /* Mark read part of skb as used */ + if (!(flags & MSG_PEEK)) { + UNIXCB(skb).consumed += chunk; + + sk_peek_offset_bwd(sk, chunk); + + if (UNIXCB(skb).fp) { + scm_stat_del(sk, skb); + unix_detach_fds(&scm, skb); + } + + if (unix_skb_len(skb)) + break; + + skb_unlink(skb, &sk->sk_receive_queue); + consume_skb(skb); + + if (scm.fp) + break; + } else { + /* It is questionable, see note in unix_dgram_recvmsg. 
+ */ + if (UNIXCB(skb).fp) + unix_peek_fds(&scm, skb); + + sk_peek_offset_fwd(sk, chunk); + + if (UNIXCB(skb).fp) + break; + + skip = 0; + last = skb; + last_len = skb->len; + unix_state_lock(sk); + skb = skb_peek_next(skb, &sk->sk_receive_queue); + if (skb) + goto again; + unix_state_unlock(sk); + break; + } + } while (size); + + mutex_unlock(&u->iolock); + if (state->msg) + scm_recv(sock, state->msg, &scm, flags); + else + scm_destroy(&scm); +out: + return copied ? : err; +} + +static int unix_stream_read_actor(struct sk_buff *skb, + int skip, int chunk, + struct unix_stream_read_state *state) +{ + int ret; + + ret = skb_copy_datagram_msg(skb, UNIXCB(skb).consumed + skip, + state->msg, chunk); + return ret ?: chunk; +} + +int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, + size_t size, int flags) +{ + struct unix_stream_read_state state = { + .recv_actor = unix_stream_read_actor, + .socket = sk->sk_socket, + .msg = msg, + .size = size, + .flags = flags + }; + + return unix_stream_read_generic(&state, true); +} + +static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg, + size_t size, int flags) +{ + struct unix_stream_read_state state = { + .recv_actor = unix_stream_read_actor, + .socket = sock, + .msg = msg, + .size = size, + .flags = flags + }; + +#ifdef CONFIG_BPF_SYSCALL + struct sock *sk = sock->sk; + const struct proto *prot = READ_ONCE(sk->sk_prot); + + if (prot != &unix_stream_proto) + return prot->recvmsg(sk, msg, size, flags, NULL); +#endif + return unix_stream_read_generic(&state, true); +} + +static int unix_stream_splice_actor(struct sk_buff *skb, + int skip, int chunk, + struct unix_stream_read_state *state) +{ + return skb_splice_bits(skb, state->socket->sk, + UNIXCB(skb).consumed + skip, + state->pipe, chunk, state->splice_flags); +} + +static ssize_t unix_stream_splice_read(struct socket *sock, loff_t *ppos, + struct pipe_inode_info *pipe, + size_t size, unsigned int flags) +{ + struct unix_stream_read_state state = { + .recv_actor = unix_stream_splice_actor, + .socket = sock, + .pipe = pipe, + .size = size, + .splice_flags = flags, + }; + + if (unlikely(*ppos)) + return -ESPIPE; + + if (sock->file->f_flags & O_NONBLOCK || + flags & SPLICE_F_NONBLOCK) + state.flags = MSG_DONTWAIT; + + return unix_stream_read_generic(&state, false); +} + +static int unix_shutdown(struct socket *sock, int mode) +{ + struct sock *sk = sock->sk; + struct sock *other; + + if (mode < SHUT_RD || mode > SHUT_RDWR) + return -EINVAL; + /* This maps: + * SHUT_RD (0) -> RCV_SHUTDOWN (1) + * SHUT_WR (1) -> SEND_SHUTDOWN (2) + * SHUT_RDWR (2) -> SHUTDOWN_MASK (3) + */ + ++mode; + + unix_state_lock(sk); + WRITE_ONCE(sk->sk_shutdown, sk->sk_shutdown | mode); + other = unix_peer(sk); + if (other) + sock_hold(other); + unix_state_unlock(sk); + sk->sk_state_change(sk); + + if (other && + (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) { + + int peer_mode = 0; + const struct proto *prot = READ_ONCE(other->sk_prot); + + if (prot->unhash) + prot->unhash(other); + if (mode&RCV_SHUTDOWN) + peer_mode |= SEND_SHUTDOWN; + if (mode&SEND_SHUTDOWN) + peer_mode |= RCV_SHUTDOWN; + unix_state_lock(other); + WRITE_ONCE(other->sk_shutdown, other->sk_shutdown | peer_mode); + unix_state_unlock(other); + other->sk_state_change(other); + if (peer_mode == SHUTDOWN_MASK) + sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP); + else if (peer_mode & RCV_SHUTDOWN) + sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN); + } + if (other) + sock_put(other); + + return 0; +} + +long unix_inq_len(struct 
sock *sk) +{ + struct sk_buff *skb; + long amount = 0; + + if (sk->sk_state == TCP_LISTEN) + return -EINVAL; + + spin_lock(&sk->sk_receive_queue.lock); + if (sk->sk_type == SOCK_STREAM || + sk->sk_type == SOCK_SEQPACKET) { + skb_queue_walk(&sk->sk_receive_queue, skb) + amount += unix_skb_len(skb); + } else { + skb = skb_peek(&sk->sk_receive_queue); + if (skb) + amount = skb->len; + } + spin_unlock(&sk->sk_receive_queue.lock); + + return amount; +} +EXPORT_SYMBOL_GPL(unix_inq_len); + +long unix_outq_len(struct sock *sk) +{ + return sk_wmem_alloc_get(sk); +} +EXPORT_SYMBOL_GPL(unix_outq_len); + +static int unix_open_file(struct sock *sk) +{ + struct path path; + struct file *f; + int fd; + + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!smp_load_acquire(&unix_sk(sk)->addr)) + return -ENOENT; + + path = unix_sk(sk)->path; + if (!path.dentry) + return -ENOENT; + + path_get(&path); + + fd = get_unused_fd_flags(O_CLOEXEC); + if (fd < 0) + goto out; + + f = dentry_open(&path, O_PATH, current_cred()); + if (IS_ERR(f)) { + put_unused_fd(fd); + fd = PTR_ERR(f); + goto out; + } + + fd_install(fd, f); +out: + path_put(&path); + + return fd; +} + +static int unix_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + long amount = 0; + int err; + + switch (cmd) { + case SIOCOUTQ: + amount = unix_outq_len(sk); + err = put_user(amount, (int __user *)arg); + break; + case SIOCINQ: + amount = unix_inq_len(sk); + if (amount < 0) + err = amount; + else + err = put_user(amount, (int __user *)arg); + break; + case SIOCUNIXFILE: + err = unix_open_file(sk); + break; +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + case SIOCATMARK: + { + struct sk_buff *skb; + int answ = 0; + + skb = skb_peek(&sk->sk_receive_queue); + if (skb && skb == READ_ONCE(unix_sk(sk)->oob_skb)) + answ = 1; + err = put_user(answ, (int __user *)arg); + } + break; +#endif + default: + err = -ENOIOCTLCMD; + break; + } + return err; +} + +#ifdef CONFIG_COMPAT +static int unix_compat_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + return unix_ioctl(sock, cmd, (unsigned long)compat_ptr(arg)); +} +#endif + +static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wait) +{ + struct sock *sk = sock->sk; + __poll_t mask; + u8 shutdown; + + sock_poll_wait(file, sock, wait); + mask = 0; + shutdown = READ_ONCE(sk->sk_shutdown); + + /* exceptional events? */ + if (sk->sk_err) + mask |= EPOLLERR; + if (shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + + /* readable? */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; +#if IS_ENABLED(CONFIG_AF_UNIX_OOB) + if (READ_ONCE(unix_sk(sk)->oob_skb)) + mask |= EPOLLPRI; +#endif + + /* Connection-based need to check for termination and startup */ + if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) && + sk->sk_state == TCP_CLOSE) + mask |= EPOLLHUP; + + /* + * we set writable also when the other side has shut down the + * connection. This prevents stuck sockets. 
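[Aside: unix_ioctl() above wires SIOCINQ to unix_inq_len() (unread payload bytes) and SIOCOUTQ to unix_outq_len() (the sender's not-yet-freed skb memory); SIOCATMARK additionally needs CONFIG_AF_UNIX_OOB. A hedged userspace sketch, not part of this patch:

/* unix_ioctl_demo.c - query queue sizes the way unix_ioctl() exposes them. */
#include <stdio.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <linux/sockios.h>
#include <unistd.h>

int main(void)
{
	int sv[2], inq = -1, outq = -1;

	if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) < 0) {
		perror("socketpair");
		return 1;
	}
	if (write(sv[0], "hello", 5) != 5) {
		perror("write");
		return 1;
	}

	/* SIOCINQ on the reader counts unread bytes (unix_inq_len());
	 * SIOCOUTQ on the writer reports sk_wmem_alloc_get(), i.e. skb
	 * truesize, so it will be larger than the 5 payload bytes. */
	if (ioctl(sv[1], SIOCINQ, &inq) < 0 ||
	    ioctl(sv[0], SIOCOUTQ, &outq) < 0) {
		perror("ioctl");
		return 1;
	}

	printf("SIOCINQ=%d SIOCOUTQ=%d\n", inq, outq);
	close(sv[0]);
	close(sv[1]);
	return 0;
}

End of aside.]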
+ */ + if (unix_writable(sk)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + return mask; +} + +static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk = sock->sk, *other; + unsigned int writable; + __poll_t mask; + u8 shutdown; + + sock_poll_wait(file, sock, wait); + mask = 0; + shutdown = READ_ONCE(sk->sk_shutdown); + + /* exceptional events? */ + if (sk->sk_err || !skb_queue_empty_lockless(&sk->sk_error_queue)) + mask |= EPOLLERR | + (sock_flag(sk, SOCK_SELECT_ERR_QUEUE) ? EPOLLPRI : 0); + + if (shutdown & RCV_SHUTDOWN) + mask |= EPOLLRDHUP | EPOLLIN | EPOLLRDNORM; + if (shutdown == SHUTDOWN_MASK) + mask |= EPOLLHUP; + + /* readable? */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* Connection-based need to check for termination and startup */ + if (sk->sk_type == SOCK_SEQPACKET) { + if (sk->sk_state == TCP_CLOSE) + mask |= EPOLLHUP; + /* connection hasn't started yet? */ + if (sk->sk_state == TCP_SYN_SENT) + return mask; + } + + /* No write status requested, avoid expensive OUT tests. */ + if (!(poll_requested_events(wait) & (EPOLLWRBAND|EPOLLWRNORM|EPOLLOUT))) + return mask; + + writable = unix_writable(sk); + if (writable) { + unix_state_lock(sk); + + other = unix_peer(sk); + if (other && unix_peer(other) != sk && + unix_recvq_full_lockless(other) && + unix_dgram_peer_wake_me(sk, other)) + writable = 0; + + unix_state_unlock(sk); + } + + if (writable) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + else + sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk); + + return mask; +} + +#ifdef CONFIG_PROC_FS + +#define BUCKET_SPACE (BITS_PER_LONG - (UNIX_HASH_BITS + 1) - 1) + +#define get_bucket(x) ((x) >> BUCKET_SPACE) +#define get_offset(x) ((x) & ((1UL << BUCKET_SPACE) - 1)) +#define set_bucket_offset(b, o) ((b) << BUCKET_SPACE | (o)) + +static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos) +{ + unsigned long offset = get_offset(*pos); + unsigned long bucket = get_bucket(*pos); + unsigned long count = 0; + struct sock *sk; + + for (sk = sk_head(&seq_file_net(seq)->unx.table.buckets[bucket]); + sk; sk = sk_next(sk)) { + if (++count == offset) + break; + } + + return sk; +} + +static struct sock *unix_get_first(struct seq_file *seq, loff_t *pos) +{ + unsigned long bucket = get_bucket(*pos); + struct net *net = seq_file_net(seq); + struct sock *sk; + + while (bucket < UNIX_HASH_SIZE) { + spin_lock(&net->unx.table.locks[bucket]); + + sk = unix_from_bucket(seq, pos); + if (sk) + return sk; + + spin_unlock(&net->unx.table.locks[bucket]); + + *pos = set_bucket_offset(++bucket, 1); + } + + return NULL; +} + +static struct sock *unix_get_next(struct seq_file *seq, struct sock *sk, + loff_t *pos) +{ + unsigned long bucket = get_bucket(*pos); + + sk = sk_next(sk); + if (sk) + return sk; + + + spin_unlock(&seq_file_net(seq)->unx.table.locks[bucket]); + + *pos = set_bucket_offset(++bucket, 1); + + return unix_get_first(seq, pos); +} + +static void *unix_seq_start(struct seq_file *seq, loff_t *pos) +{ + if (!*pos) + return SEQ_START_TOKEN; + + return unix_get_first(seq, pos); +} + +static void *unix_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + ++*pos; + + if (v == SEQ_START_TOKEN) + return unix_get_first(seq, pos); + + return unix_get_next(seq, v, pos); +} + +static void unix_seq_stop(struct seq_file *seq, void *v) +{ + struct sock *sk = v; + + if (sk) + spin_unlock(&seq_file_net(seq)->unx.table.locks[sk->sk_hash]); 
+} + +static int unix_seq_show(struct seq_file *seq, void *v) +{ + + if (v == SEQ_START_TOKEN) + seq_puts(seq, "Num RefCount Protocol Flags Type St " + "Inode Path\n"); + else { + struct sock *s = v; + struct unix_sock *u = unix_sk(s); + unix_state_lock(s); + + seq_printf(seq, "%pK: %08X %08X %08X %04X %02X %5lu", + s, + refcount_read(&s->sk_refcnt), + 0, + s->sk_state == TCP_LISTEN ? __SO_ACCEPTCON : 0, + s->sk_type, + s->sk_socket ? + (s->sk_state == TCP_ESTABLISHED ? SS_CONNECTED : SS_UNCONNECTED) : + (s->sk_state == TCP_ESTABLISHED ? SS_CONNECTING : SS_DISCONNECTING), + sock_i_ino(s)); + + if (u->addr) { // under a hash table lock here + int i, len; + seq_putc(seq, ' '); + + i = 0; + len = u->addr->len - + offsetof(struct sockaddr_un, sun_path); + if (u->addr->name->sun_path[0]) { + len--; + } else { + seq_putc(seq, '@'); + i++; + } + for ( ; i < len; i++) + seq_putc(seq, u->addr->name->sun_path[i] ?: + '@'); + } + unix_state_unlock(s); + seq_putc(seq, '\n'); + } + + return 0; +} + +static const struct seq_operations unix_seq_ops = { + .start = unix_seq_start, + .next = unix_seq_next, + .stop = unix_seq_stop, + .show = unix_seq_show, +}; + +#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) +struct bpf_unix_iter_state { + struct seq_net_private p; + unsigned int cur_sk; + unsigned int end_sk; + unsigned int max_sk; + struct sock **batch; + bool st_bucket_done; +}; + +struct bpf_iter__unix { + __bpf_md_ptr(struct bpf_iter_meta *, meta); + __bpf_md_ptr(struct unix_sock *, unix_sk); + uid_t uid __aligned(8); +}; + +static int unix_prog_seq_show(struct bpf_prog *prog, struct bpf_iter_meta *meta, + struct unix_sock *unix_sk, uid_t uid) +{ + struct bpf_iter__unix ctx; + + meta->seq_num--; /* skip SEQ_START_TOKEN */ + ctx.meta = meta; + ctx.unix_sk = unix_sk; + ctx.uid = uid; + return bpf_iter_run_prog(prog, &ctx); +} + +static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk) + +{ + struct bpf_unix_iter_state *iter = seq->private; + unsigned int expected = 1; + struct sock *sk; + + sock_hold(start_sk); + iter->batch[iter->end_sk++] = start_sk; + + for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) { + if (iter->end_sk < iter->max_sk) { + sock_hold(sk); + iter->batch[iter->end_sk++] = sk; + } + + expected++; + } + + spin_unlock(&seq_file_net(seq)->unx.table.locks[start_sk->sk_hash]); + + return expected; +} + +static void bpf_iter_unix_put_batch(struct bpf_unix_iter_state *iter) +{ + while (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++]); +} + +static int bpf_iter_unix_realloc_batch(struct bpf_unix_iter_state *iter, + unsigned int new_batch_sz) +{ + struct sock **new_batch; + + new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, + GFP_USER | __GFP_NOWARN); + if (!new_batch) + return -ENOMEM; + + bpf_iter_unix_put_batch(iter); + kvfree(iter->batch); + iter->batch = new_batch; + iter->max_sk = new_batch_sz; + + return 0; +} + +static struct sock *bpf_iter_unix_batch(struct seq_file *seq, + loff_t *pos) +{ + struct bpf_unix_iter_state *iter = seq->private; + unsigned int expected; + bool resized = false; + struct sock *sk; + + if (iter->st_bucket_done) + *pos = set_bucket_offset(get_bucket(*pos) + 1, 1); + +again: + /* Get a new batch */ + iter->cur_sk = 0; + iter->end_sk = 0; + + sk = unix_get_first(seq, pos); + if (!sk) + return NULL; /* Done */ + + expected = bpf_iter_unix_hold_batch(seq, sk); + + if (iter->end_sk == expected) { + iter->st_bucket_done = true; + return sk; + } + + if (!resized && !bpf_iter_unix_realloc_batch(iter, 
expected * 3 / 2)) { + resized = true; + goto again; + } + + return sk; +} + +static void *bpf_iter_unix_seq_start(struct seq_file *seq, loff_t *pos) +{ + if (!*pos) + return SEQ_START_TOKEN; + + /* bpf iter does not support lseek, so it always + * continue from where it was stop()-ped. + */ + return bpf_iter_unix_batch(seq, pos); +} + +static void *bpf_iter_unix_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct bpf_unix_iter_state *iter = seq->private; + struct sock *sk; + + /* Whenever seq_next() is called, the iter->cur_sk is + * done with seq_show(), so advance to the next sk in + * the batch. + */ + if (iter->cur_sk < iter->end_sk) + sock_put(iter->batch[iter->cur_sk++]); + + ++*pos; + + if (iter->cur_sk < iter->end_sk) + sk = iter->batch[iter->cur_sk]; + else + sk = bpf_iter_unix_batch(seq, pos); + + return sk; +} + +static int bpf_iter_unix_seq_show(struct seq_file *seq, void *v) +{ + struct bpf_iter_meta meta; + struct bpf_prog *prog; + struct sock *sk = v; + uid_t uid; + bool slow; + int ret; + + if (v == SEQ_START_TOKEN) + return 0; + + slow = lock_sock_fast(sk); + + if (unlikely(sk_unhashed(sk))) { + ret = SEQ_SKIP; + goto unlock; + } + + uid = from_kuid_munged(seq_user_ns(seq), sock_i_uid(sk)); + meta.seq = seq; + prog = bpf_iter_get_info(&meta, false); + ret = unix_prog_seq_show(prog, &meta, v, uid); +unlock: + unlock_sock_fast(sk, slow); + return ret; +} + +static void bpf_iter_unix_seq_stop(struct seq_file *seq, void *v) +{ + struct bpf_unix_iter_state *iter = seq->private; + struct bpf_iter_meta meta; + struct bpf_prog *prog; + + if (!v) { + meta.seq = seq; + prog = bpf_iter_get_info(&meta, true); + if (prog) + (void)unix_prog_seq_show(prog, &meta, v, 0); + } + + if (iter->cur_sk < iter->end_sk) + bpf_iter_unix_put_batch(iter); +} + +static const struct seq_operations bpf_iter_unix_seq_ops = { + .start = bpf_iter_unix_seq_start, + .next = bpf_iter_unix_seq_next, + .stop = bpf_iter_unix_seq_stop, + .show = bpf_iter_unix_seq_show, +}; +#endif +#endif + +static const struct net_proto_family unix_family_ops = { + .family = PF_UNIX, + .create = unix_create, + .owner = THIS_MODULE, +}; + + +static int __net_init unix_net_init(struct net *net) +{ + int i; + + net->unx.sysctl_max_dgram_qlen = 10; + if (unix_sysctl_register(net)) + goto out; + +#ifdef CONFIG_PROC_FS + if (!proc_create_net("unix", 0, net->proc_net, &unix_seq_ops, + sizeof(struct seq_net_private))) + goto err_sysctl; +#endif + + net->unx.table.locks = kvmalloc_array(UNIX_HASH_SIZE, + sizeof(spinlock_t), GFP_KERNEL); + if (!net->unx.table.locks) + goto err_proc; + + net->unx.table.buckets = kvmalloc_array(UNIX_HASH_SIZE, + sizeof(struct hlist_head), + GFP_KERNEL); + if (!net->unx.table.buckets) + goto free_locks; + + for (i = 0; i < UNIX_HASH_SIZE; i++) { + spin_lock_init(&net->unx.table.locks[i]); + INIT_HLIST_HEAD(&net->unx.table.buckets[i]); + } + + return 0; + +free_locks: + kvfree(net->unx.table.locks); +err_proc: +#ifdef CONFIG_PROC_FS + remove_proc_entry("unix", net->proc_net); +err_sysctl: +#endif + unix_sysctl_unregister(net); +out: + return -ENOMEM; +} + +static void __net_exit unix_net_exit(struct net *net) +{ + kvfree(net->unx.table.buckets); + kvfree(net->unx.table.locks); + unix_sysctl_unregister(net); + remove_proc_entry("unix", net->proc_net); +} + +static struct pernet_operations unix_net_ops = { + .init = unix_net_init, + .exit = unix_net_exit, +}; + +#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) +DEFINE_BPF_ITER_FUNC(unix, struct bpf_iter_meta 
*meta, + struct unix_sock *unix_sk, uid_t uid) + +#define INIT_BATCH_SZ 16 + +static int bpf_iter_init_unix(void *priv_data, struct bpf_iter_aux_info *aux) +{ + struct bpf_unix_iter_state *iter = priv_data; + int err; + + err = bpf_iter_init_seq_net(priv_data, aux); + if (err) + return err; + + err = bpf_iter_unix_realloc_batch(iter, INIT_BATCH_SZ); + if (err) { + bpf_iter_fini_seq_net(priv_data); + return err; + } + + return 0; +} + +static void bpf_iter_fini_unix(void *priv_data) +{ + struct bpf_unix_iter_state *iter = priv_data; + + bpf_iter_fini_seq_net(priv_data); + kvfree(iter->batch); +} + +static const struct bpf_iter_seq_info unix_seq_info = { + .seq_ops = &bpf_iter_unix_seq_ops, + .init_seq_private = bpf_iter_init_unix, + .fini_seq_private = bpf_iter_fini_unix, + .seq_priv_size = sizeof(struct bpf_unix_iter_state), +}; + +static const struct bpf_func_proto * +bpf_iter_unix_get_func_proto(enum bpf_func_id func_id, + const struct bpf_prog *prog) +{ + switch (func_id) { + case BPF_FUNC_setsockopt: + return &bpf_sk_setsockopt_proto; + case BPF_FUNC_getsockopt: + return &bpf_sk_getsockopt_proto; + default: + return NULL; + } +} + +static struct bpf_iter_reg unix_reg_info = { + .target = "unix", + .ctx_arg_info_size = 1, + .ctx_arg_info = { + { offsetof(struct bpf_iter__unix, unix_sk), + PTR_TO_BTF_ID_OR_NULL }, + }, + .get_func_proto = bpf_iter_unix_get_func_proto, + .seq_info = &unix_seq_info, +}; + +static void __init bpf_iter_register(void) +{ + unix_reg_info.ctx_arg_info[0].btf_id = btf_sock_ids[BTF_SOCK_TYPE_UNIX]; + if (bpf_iter_reg_target(&unix_reg_info)) + pr_warn("Warning: could not register bpf iterator unix\n"); +} +#endif + +static int __init af_unix_init(void) +{ + int i, rc = -1; + + BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb)); + + for (i = 0; i < UNIX_HASH_SIZE / 2; i++) { + spin_lock_init(&bsd_socket_locks[i]); + INIT_HLIST_HEAD(&bsd_socket_buckets[i]); + } + + rc = proto_register(&unix_dgram_proto, 1); + if (rc != 0) { + pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); + goto out; + } + + rc = proto_register(&unix_stream_proto, 1); + if (rc != 0) { + pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); + proto_unregister(&unix_dgram_proto); + goto out; + } + + sock_register(&unix_family_ops); + register_pernet_subsys(&unix_net_ops); + unix_bpf_build_proto(); + +#if IS_BUILTIN(CONFIG_UNIX) && defined(CONFIG_BPF_SYSCALL) && defined(CONFIG_PROC_FS) + bpf_iter_register(); +#endif + +out: + return rc; +} + +static void __exit af_unix_exit(void) +{ + sock_unregister(PF_UNIX); + proto_unregister(&unix_dgram_proto); + proto_unregister(&unix_stream_proto); + unregister_pernet_subsys(&unix_net_ops); +} + +/* Earlier than device_initcall() so that other drivers invoking + request_module() don't end up in a loop when modprobe tries + to use a UNIX socket. 
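[Aside: unix_seq_show() earlier in this file prints one row per socket under /proc/net/unix ("Num RefCount Protocol Flags Type St Inode Path"), with the leading NUL of an abstract address rendered as '@'. A small userspace sketch that binds an abstract name and locates its row — illustrative only, the "demo-abstract" name is arbitrary:

/* proc_unix_demo.c - show an abstract AF_UNIX name as /proc/net/unix prints it. */
#include <stddef.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

int main(void)
{
	struct sockaddr_un sun = { .sun_family = AF_UNIX };
	socklen_t len;
	char line[512];
	FILE *f;
	int s;

	s = socket(AF_UNIX, SOCK_STREAM, 0);
	if (s < 0) {
		perror("socket");
		return 1;
	}

	/* Abstract namespace: sun_path starts with a NUL byte. */
	strcpy(sun.sun_path + 1, "demo-abstract");
	len = offsetof(struct sockaddr_un, sun_path) + 1 + strlen("demo-abstract");
	if (bind(s, (struct sockaddr *)&sun, len) < 0) {
		perror("bind");
		return 1;
	}

	f = fopen("/proc/net/unix", "r");
	if (!f) {
		perror("fopen");
		return 1;
	}
	while (fgets(line, sizeof(line), f))
		if (strstr(line, "@demo-abstract"))
			fputs(line, stdout);	/* row emitted by unix_seq_show() */
	fclose(f);
	close(s);
	return 0;
}

End of aside.]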
But later than subsys_initcall() because + we depend on stuff initialised there */ +fs_initcall(af_unix_init); +module_exit(af_unix_exit); + +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_UNIX); diff --git a/net/unix/diag.c b/net/unix/diag.c new file mode 100644 index 000000000..616b55c5b --- /dev/null +++ b/net/unix/diag.c @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: GPL-2.0-only +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int sk_diag_dump_name(struct sock *sk, struct sk_buff *nlskb) +{ + /* might or might not have a hash table lock */ + struct unix_address *addr = smp_load_acquire(&unix_sk(sk)->addr); + + if (!addr) + return 0; + + return nla_put(nlskb, UNIX_DIAG_NAME, + addr->len - offsetof(struct sockaddr_un, sun_path), + addr->name->sun_path); +} + +static int sk_diag_dump_vfs(struct sock *sk, struct sk_buff *nlskb) +{ + struct dentry *dentry = unix_sk(sk)->path.dentry; + + if (dentry) { + struct unix_diag_vfs uv = { + .udiag_vfs_ino = d_backing_inode(dentry)->i_ino, + .udiag_vfs_dev = dentry->d_sb->s_dev, + }; + + return nla_put(nlskb, UNIX_DIAG_VFS, sizeof(uv), &uv); + } + + return 0; +} + +static int sk_diag_dump_peer(struct sock *sk, struct sk_buff *nlskb) +{ + struct sock *peer; + int ino; + + peer = unix_peer_get(sk); + if (peer) { + unix_state_lock(peer); + ino = sock_i_ino(peer); + unix_state_unlock(peer); + sock_put(peer); + + return nla_put_u32(nlskb, UNIX_DIAG_PEER, ino); + } + + return 0; +} + +static int sk_diag_dump_icons(struct sock *sk, struct sk_buff *nlskb) +{ + struct sk_buff *skb; + struct nlattr *attr; + u32 *buf; + int i; + + if (sk->sk_state == TCP_LISTEN) { + spin_lock(&sk->sk_receive_queue.lock); + + attr = nla_reserve(nlskb, UNIX_DIAG_ICONS, + sk->sk_receive_queue.qlen * sizeof(u32)); + if (!attr) + goto errout; + + buf = nla_data(attr); + i = 0; + skb_queue_walk(&sk->sk_receive_queue, skb) { + struct sock *req, *peer; + + req = skb->sk; + /* + * The state lock is outer for the same sk's + * queue lock. With the other's queue locked it's + * OK to lock the state. + */ + unix_state_lock_nested(req); + peer = unix_sk(req)->peer; + buf[i++] = (peer ? 
sock_i_ino(peer) : 0); + unix_state_unlock(req); + } + spin_unlock(&sk->sk_receive_queue.lock); + } + + return 0; + +errout: + spin_unlock(&sk->sk_receive_queue.lock); + return -EMSGSIZE; +} + +static int sk_diag_show_rqlen(struct sock *sk, struct sk_buff *nlskb) +{ + struct unix_diag_rqlen rql; + + if (sk->sk_state == TCP_LISTEN) { + rql.udiag_rqueue = sk->sk_receive_queue.qlen; + rql.udiag_wqueue = sk->sk_max_ack_backlog; + } else { + rql.udiag_rqueue = (u32) unix_inq_len(sk); + rql.udiag_wqueue = (u32) unix_outq_len(sk); + } + + return nla_put(nlskb, UNIX_DIAG_RQLEN, sizeof(rql), &rql); +} + +static int sk_diag_dump_uid(struct sock *sk, struct sk_buff *nlskb, + struct user_namespace *user_ns) +{ + uid_t uid = from_kuid_munged(user_ns, sock_i_uid(sk)); + return nla_put(nlskb, UNIX_DIAG_UID, sizeof(uid_t), &uid); +} + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) +{ + struct nlmsghdr *nlh; + struct unix_diag_msg *rep; + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), + flags); + if (!nlh) + return -EMSGSIZE; + + rep = nlmsg_data(nlh); + rep->udiag_family = AF_UNIX; + rep->udiag_type = sk->sk_type; + rep->udiag_state = sk->sk_state; + rep->pad = 0; + rep->udiag_ino = sk_ino; + sock_diag_save_cookie(sk, rep->udiag_cookie); + + if ((req->udiag_show & UDIAG_SHOW_NAME) && + sk_diag_dump_name(sk, skb)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_VFS) && + sk_diag_dump_vfs(sk, skb)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_PEER) && + sk_diag_dump_peer(sk, skb)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_ICONS) && + sk_diag_dump_icons(sk, skb)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_RQLEN) && + sk_diag_show_rqlen(sk, skb)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_MEMINFO) && + sock_diag_put_meminfo(sk, skb, UNIX_DIAG_MEMINFO)) + goto out_nlmsg_trim; + + if (nla_put_u8(skb, UNIX_DIAG_SHUTDOWN, sk->sk_shutdown)) + goto out_nlmsg_trim; + + if ((req->udiag_show & UDIAG_SHOW_UID) && + sk_diag_dump_uid(sk, skb, user_ns)) + goto out_nlmsg_trim; + + nlmsg_end(skb, nlh); + return 0; + +out_nlmsg_trim: + nlmsg_cancel(skb, nlh); + return -EMSGSIZE; +} + +static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct unix_diag_req *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags) +{ + int sk_ino; + + unix_state_lock(sk); + sk_ino = sock_i_ino(sk); + unix_state_unlock(sk); + + if (!sk_ino) + return 0; + + return sk_diag_fill(sk, skb, req, user_ns, portid, seq, flags, sk_ino); +} + +static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + int num, s_num, slot, s_slot; + struct unix_diag_req *req; + + req = nlmsg_data(cb->nlh); + + s_slot = cb->args[0]; + num = s_num = cb->args[1]; + + for (slot = s_slot; slot < UNIX_HASH_SIZE; s_num = 0, slot++) { + struct sock *sk; + + num = 0; + spin_lock(&net->unx.table.locks[slot]); + sk_for_each(sk, &net->unx.table.buckets[slot]) { + if (num < s_num) + goto next; + if (!(req->udiag_states & (1 << sk->sk_state))) + goto next; + if (sk_diag_dump(sk, skb, req, sk_user_ns(skb->sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) { + spin_unlock(&net->unx.table.locks[slot]); + goto done; + } +next: + num++; + } + spin_unlock(&net->unx.table.locks[slot]); + } +done: + cb->args[0] = slot; + cb->args[1] = num; + + return 
skb->len; +} + +static struct sock *unix_lookup_by_ino(struct net *net, unsigned int ino) +{ + struct sock *sk; + int i; + + for (i = 0; i < UNIX_HASH_SIZE; i++) { + spin_lock(&net->unx.table.locks[i]); + sk_for_each(sk, &net->unx.table.buckets[i]) { + if (ino == sock_i_ino(sk)) { + sock_hold(sk); + spin_unlock(&net->unx.table.locks[i]); + return sk; + } + } + spin_unlock(&net->unx.table.locks[i]); + } + return NULL; +} + +static int unix_diag_get_exact(struct sk_buff *in_skb, + const struct nlmsghdr *nlh, + struct unix_diag_req *req) +{ + struct net *net = sock_net(in_skb->sk); + unsigned int extra_len; + struct sk_buff *rep; + struct sock *sk; + int err; + + err = -EINVAL; + if (req->udiag_ino == 0) + goto out_nosk; + + sk = unix_lookup_by_ino(net, req->udiag_ino); + err = -ENOENT; + if (sk == NULL) + goto out_nosk; + + err = sock_diag_check_cookie(sk, req->udiag_cookie); + if (err) + goto out; + + extra_len = 256; +again: + err = -ENOMEM; + rep = nlmsg_new(sizeof(struct unix_diag_msg) + extra_len, GFP_KERNEL); + if (!rep) + goto out; + + err = sk_diag_fill(sk, rep, req, sk_user_ns(NETLINK_CB(in_skb).sk), + NETLINK_CB(in_skb).portid, + nlh->nlmsg_seq, 0, req->udiag_ino); + if (err < 0) { + nlmsg_free(rep); + extra_len += 256; + if (extra_len >= PAGE_SIZE) + goto out; + + goto again; + } + err = nlmsg_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).portid); + +out: + if (sk) + sock_put(sk); +out_nosk: + return err; +} + +static int unix_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct unix_diag_req); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = unix_diag_dump, + }; + return netlink_dump_start(sock_net(skb->sk)->diag_nlsk, skb, h, &c); + } else + return unix_diag_get_exact(skb, h, nlmsg_data(h)); +} + +static const struct sock_diag_handler unix_diag_handler = { + .family = AF_UNIX, + .dump = unix_diag_handler_dump, +}; + +static int __init unix_diag_init(void) +{ + return sock_diag_register(&unix_diag_handler); +} + +static void __exit unix_diag_exit(void) +{ + sock_diag_unregister(&unix_diag_handler); +} + +module_init(unix_diag_init); +module_exit(unix_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, 1 /* AF_LOCAL */); diff --git a/net/unix/garbage.c b/net/unix/garbage.c new file mode 100644 index 000000000..dc2763540 --- /dev/null +++ b/net/unix/garbage.c @@ -0,0 +1,335 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET3: Garbage Collector For AF_UNIX sockets + * + * Garbage Collector: + * Copyright (C) Barak A. Pearlmutter. + * + * Chopped about by Alan Cox 22/3/96 to make it fit the AF_UNIX socket problem. + * If it doesn't work blame me, it worked when Barak sent it. + * + * Assumptions: + * + * - object w/ a bit + * - free list + * + * Current optimizations: + * + * - explicit stack instead of recursion + * - tail recurse on first born instead of immediate push/pop + * - we gather the stuff that should not be killed into tree + * and stack is just a path from root to the current pointer. + * + * Future optimizations: + * + * - don't just push entire root set; process in place + * + * Fixes: + * Alan Cox 07 Sept 1997 Vmalloc internal stack as needed. + * Cope with changing max_files. + * Al Viro 11 Oct 1998 + * Graph may have cycles. That is, we can send the descriptor + * of foo to bar and vice versa. Current code chokes on that. 
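[Aside: the net/unix/diag.c handler above answers the same NETLINK_SOCK_DIAG dump requests that ss(8) issues. A rough userspace sketch of such a dump — error handling and attribute parsing are omitted, and robust code would inspect NLMSG_ERROR payloads; structures come from <linux/unix_diag.h>:

/* unix_diag_dump.c - minimal SOCK_DIAG_BY_FAMILY dump for AF_UNIX sockets. */
#include <stdio.h>
#include <unistd.h>
#include <sys/socket.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <linux/unix_diag.h>

int main(void)
{
	struct {
		struct nlmsghdr nlh;
		struct unix_diag_req req;
	} msg = {
		.nlh = {
			.nlmsg_len = sizeof(msg),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.req = {
			.sdiag_family = AF_UNIX,
			.udiag_states = -1,	/* all states */
			.udiag_show = UDIAG_SHOW_NAME | UDIAG_SHOW_PEER |
				      UDIAG_SHOW_RQLEN,
		},
	};
	char buf[32768];
	int fd, count = 0;
	ssize_t len;

	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
	if (fd < 0) {
		perror("socket");
		return 1;
	}
	if (send(fd, &msg, sizeof(msg), 0) < 0) {
		perror("send");
		return 1;
	}

	while ((len = recv(fd, buf, sizeof(buf), 0)) > 0) {
		struct nlmsghdr *h = (struct nlmsghdr *)buf;

		for (; NLMSG_OK(h, len); h = NLMSG_NEXT(h, len)) {
			if (h->nlmsg_type == NLMSG_DONE ||
			    h->nlmsg_type == NLMSG_ERROR)
				goto out;
			count++;	/* one unix_diag_msg per socket */
		}
	}
out:
	printf("%d AF_UNIX sockets reported\n", count);
	close(fd);
	return 0;
}

End of aside.]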
+ * Fix: move SCM_RIGHTS ones into the separate list and then + * skb_free() them all instead of doing explicit fput's. + * Another problem: since fput() may block somebody may + * create a new unix_socket when we are in the middle of sweep + * phase. Fix: revert the logic wrt MARKED. Mark everything + * upon the beginning and unmark non-junk ones. + * + * [12 Oct 1998] AAARGH! New code purges all SCM_RIGHTS + * sent to connect()'ed but still not accept()'ed sockets. + * Fixed. Old code had slightly different problem here: + * extra fput() in situation when we passed the descriptor via + * such socket and closed it (descriptor). That would happen on + * each unix_gc() until the accept(). Since the struct file in + * question would go to the free list and might be reused... + * That might be the reason of random oopses on filp_close() + * in unrelated processes. + * + * AV 28 Feb 1999 + * Kill the explicit allocation of stack. Now we keep the tree + * with root in dummy + pointer (gc_current) to one of the nodes. + * Stack is represented as path from gc_current to dummy. Unmark + * now means "add to tree". Push == "make it a son of gc_current". + * Pop == "move gc_current to parent". We keep only pointers to + * parents (->gc_tree). + * AV 1 Mar 1999 + * Damn. Added missing check for ->dead in listen queues scanning. + * + * Miklos Szeredi 25 Jun 2007 + * Reimplement with a cycle collecting algorithm. This should + * solve several problems with the previous code, like being racy + * wrt receive and holding up unrelated socket operations. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "scm.h" + +/* Internal data structures and random procedures: */ + +static LIST_HEAD(gc_candidates); +static DECLARE_WAIT_QUEUE_HEAD(unix_gc_wait); + +static void scan_inflight(struct sock *x, void (*func)(struct unix_sock *), + struct sk_buff_head *hitlist) +{ + struct sk_buff *skb; + struct sk_buff *next; + + spin_lock(&x->sk_receive_queue.lock); + skb_queue_walk_safe(&x->sk_receive_queue, skb, next) { + /* Do we have file descriptors ? */ + if (UNIXCB(skb).fp) { + bool hit = false; + /* Process the descriptors of this socket */ + int nfd = UNIXCB(skb).fp->count; + struct file **fp = UNIXCB(skb).fp->fp; + + while (nfd--) { + /* Get the socket the fd matches if it indeed does so */ + struct sock *sk = unix_get_socket(*fp++); + + if (sk) { + struct unix_sock *u = unix_sk(sk); + + /* Ignore non-candidates, they could + * have been added to the queues after + * starting the garbage collection + */ + if (test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) { + hit = true; + + func(u); + } + } + } + if (hit && hitlist != NULL) { + __skb_unlink(skb, &x->sk_receive_queue); + __skb_queue_tail(hitlist, skb); + } + } + } + spin_unlock(&x->sk_receive_queue.lock); +} + +static void scan_children(struct sock *x, void (*func)(struct unix_sock *), + struct sk_buff_head *hitlist) +{ + if (x->sk_state != TCP_LISTEN) { + scan_inflight(x, func, hitlist); + } else { + struct sk_buff *skb; + struct sk_buff *next; + struct unix_sock *u; + LIST_HEAD(embryos); + + /* For a listening socket collect the queued embryos + * and perform a scan on them as well. + */ + spin_lock(&x->sk_receive_queue.lock); + skb_queue_walk_safe(&x->sk_receive_queue, skb, next) { + u = unix_sk(skb->sk); + + /* An embryo cannot be in-flight, so it's safe + * to use the list link. 
+ */ + BUG_ON(!list_empty(&u->link)); + list_add_tail(&u->link, &embryos); + } + spin_unlock(&x->sk_receive_queue.lock); + + while (!list_empty(&embryos)) { + u = list_entry(embryos.next, struct unix_sock, link); + scan_inflight(&u->sk, func, hitlist); + list_del_init(&u->link); + } + } +} + +static void dec_inflight(struct unix_sock *usk) +{ + atomic_long_dec(&usk->inflight); +} + +static void inc_inflight(struct unix_sock *usk) +{ + atomic_long_inc(&usk->inflight); +} + +static void inc_inflight_move_tail(struct unix_sock *u) +{ + atomic_long_inc(&u->inflight); + /* If this still might be part of a cycle, move it to the end + * of the list, so that it's checked even if it was already + * passed over + */ + if (test_bit(UNIX_GC_MAYBE_CYCLE, &u->gc_flags)) + list_move_tail(&u->link, &gc_candidates); +} + +static bool gc_in_progress; +#define UNIX_INFLIGHT_TRIGGER_GC 16000 + +void wait_for_unix_gc(void) +{ + /* If number of inflight sockets is insane, + * force a garbage collect right now. + * Paired with the WRITE_ONCE() in unix_inflight(), + * unix_notinflight() and gc_in_progress(). + */ + if (READ_ONCE(unix_tot_inflight) > UNIX_INFLIGHT_TRIGGER_GC && + !READ_ONCE(gc_in_progress)) + unix_gc(); + wait_event(unix_gc_wait, gc_in_progress == false); +} + +/* The external entry point: unix_gc() */ +void unix_gc(void) +{ + struct sk_buff *next_skb, *skb; + struct unix_sock *u; + struct unix_sock *next; + struct sk_buff_head hitlist; + struct list_head cursor; + LIST_HEAD(not_cycle_list); + + spin_lock(&unix_gc_lock); + + /* Avoid a recursive GC. */ + if (gc_in_progress) + goto out; + + /* Paired with READ_ONCE() in wait_for_unix_gc(). */ + WRITE_ONCE(gc_in_progress, true); + + /* First, select candidates for garbage collection. Only + * in-flight sockets are considered, and from those only ones + * which don't have any external reference. + * + * Holding unix_gc_lock will protect these candidates from + * being detached, and hence from gaining an external + * reference. Since there are no possible receivers, all + * buffers currently on the candidates' queues stay there + * during the garbage collection. + * + * We also know that no new candidate can be added onto the + * receive queues. Other, non candidate sockets _can_ be + * added to queue, so we must make sure only to touch + * candidates. + */ + list_for_each_entry_safe(u, next, &gc_inflight_list, link) { + long total_refs; + long inflight_refs; + + total_refs = file_count(u->sk.sk_socket->file); + inflight_refs = atomic_long_read(&u->inflight); + + BUG_ON(inflight_refs < 1); + BUG_ON(total_refs < inflight_refs); + if (total_refs == inflight_refs) { + list_move_tail(&u->link, &gc_candidates); + __set_bit(UNIX_GC_CANDIDATE, &u->gc_flags); + __set_bit(UNIX_GC_MAYBE_CYCLE, &u->gc_flags); + } + } + + /* Now remove all internal in-flight reference to children of + * the candidates. + */ + list_for_each_entry(u, &gc_candidates, link) + scan_children(&u->sk, dec_inflight, NULL); + + /* Restore the references for children of all candidates, + * which have remaining references. Do this recursively, so + * only those remain, which form cyclic references. + * + * Use a "cursor" link, to make the list traversal safe, even + * though elements might be moved about. + */ + list_add(&cursor, &gc_candidates); + while (cursor.next != &gc_candidates) { + u = list_entry(cursor.next, struct unix_sock, link); + + /* Move cursor to after the current position. 
*/ + list_move(&cursor, &u->link); + + if (atomic_long_read(&u->inflight) > 0) { + list_move_tail(&u->link, ¬_cycle_list); + __clear_bit(UNIX_GC_MAYBE_CYCLE, &u->gc_flags); + scan_children(&u->sk, inc_inflight_move_tail, NULL); + } + } + list_del(&cursor); + + /* Now gc_candidates contains only garbage. Restore original + * inflight counters for these as well, and remove the skbuffs + * which are creating the cycle(s). + */ + skb_queue_head_init(&hitlist); + list_for_each_entry(u, &gc_candidates, link) + scan_children(&u->sk, inc_inflight, &hitlist); + + /* not_cycle_list contains those sockets which do not make up a + * cycle. Restore these to the inflight list. + */ + while (!list_empty(¬_cycle_list)) { + u = list_entry(not_cycle_list.next, struct unix_sock, link); + __clear_bit(UNIX_GC_CANDIDATE, &u->gc_flags); + list_move_tail(&u->link, &gc_inflight_list); + } + + spin_unlock(&unix_gc_lock); + + /* We need io_uring to clean its registered files, ignore all io_uring + * originated skbs. It's fine as io_uring doesn't keep references to + * other io_uring instances and so killing all other files in the cycle + * will put all io_uring references forcing it to go through normal + * release.path eventually putting registered files. + */ + skb_queue_walk_safe(&hitlist, skb, next_skb) { + if (skb->scm_io_uring) { + __skb_unlink(skb, &hitlist); + skb_queue_tail(&skb->sk->sk_receive_queue, skb); + } + } + + /* Here we are. Hitlist is filled. Die. */ + __skb_queue_purge(&hitlist); + + spin_lock(&unix_gc_lock); + + /* There could be io_uring registered files, just push them back to + * the inflight list + */ + list_for_each_entry_safe(u, next, &gc_candidates, link) + list_move_tail(&u->link, &gc_inflight_list); + + /* All candidates should have been detached by now. */ + BUG_ON(!list_empty(&gc_candidates)); + + /* Paired with READ_ONCE() in wait_for_unix_gc(). */ + WRITE_ONCE(gc_in_progress, false); + + wake_up(&unix_gc_wait); + + out: + spin_unlock(&unix_gc_lock); +} diff --git a/net/unix/scm.c b/net/unix/scm.c new file mode 100644 index 000000000..e8e2a00bb --- /dev/null +++ b/net/unix/scm.c @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "scm.h" + +unsigned int unix_tot_inflight; +EXPORT_SYMBOL(unix_tot_inflight); + +LIST_HEAD(gc_inflight_list); +EXPORT_SYMBOL(gc_inflight_list); + +DEFINE_SPINLOCK(unix_gc_lock); +EXPORT_SYMBOL(unix_gc_lock); + +struct sock *unix_get_socket(struct file *filp) +{ + struct sock *u_sock = NULL; + struct inode *inode = file_inode(filp); + + /* Socket ? */ + if (S_ISSOCK(inode->i_mode) && !(filp->f_mode & FMODE_PATH)) { + struct socket *sock = SOCKET_I(inode); + struct sock *s = sock->sk; + + /* PF_UNIX ? */ + if (s && sock->ops && sock->ops->family == PF_UNIX) + u_sock = s; + } else { + /* Could be an io_uring instance */ + u_sock = io_uring_get_socket(filp); + } + return u_sock; +} +EXPORT_SYMBOL(unix_get_socket); + +/* Keep the number of times in flight count for the file + * descriptor if it is for an AF_UNIX socket. 
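[Aside: the garbage collector above and the in-flight accounting just below (unix_inflight()/unix_notinflight(), too_many_unix_fds()) exist because descriptors can travel inside AF_UNIX messages via SCM_RIGHTS. A hedged userspace sketch of such a transfer — illustrative only, /dev/null stands in for any file:

/* scm_rights_demo.c - pass a file descriptor over AF_UNIX with SCM_RIGHTS. */
#include <fcntl.h>
#include <stdio.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <unistd.h>

static int send_fd(int sock, int fd)
{
	char data = 'x';
	struct iovec iov = { .iov_base = &data, .iov_len = 1 };
	union { struct cmsghdr align; char buf[CMSG_SPACE(sizeof(int))]; } u;
	struct msghdr mh = {
		.msg_iov = &iov, .msg_iovlen = 1,
		.msg_control = u.buf, .msg_controllen = sizeof(u.buf),
	};
	struct cmsghdr *cm = CMSG_FIRSTHDR(&mh);

	cm->cmsg_level = SOL_SOCKET;
	cm->cmsg_type = SCM_RIGHTS;
	cm->cmsg_len = CMSG_LEN(sizeof(int));
	memcpy(CMSG_DATA(cm), &fd, sizeof(int));
	return sendmsg(sock, &mh, 0) == 1 ? 0 : -1;
}

static int recv_fd(int sock)
{
	char data;
	struct iovec iov = { .iov_base = &data, .iov_len = 1 };
	union { struct cmsghdr align; char buf[CMSG_SPACE(sizeof(int))]; } u;
	struct msghdr mh = {
		.msg_iov = &iov, .msg_iovlen = 1,
		.msg_control = u.buf, .msg_controllen = sizeof(u.buf),
	};
	struct cmsghdr *cm;
	int fd = -1;

	if (recvmsg(sock, &mh, 0) != 1)
		return -1;
	cm = CMSG_FIRSTHDR(&mh);
	if (cm && cm->cmsg_type == SCM_RIGHTS)
		memcpy(&fd, CMSG_DATA(cm), sizeof(int));
	return fd;
}

int main(void)
{
	int sv[2], nullfd, fd;

	if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv) < 0)
		return 1;

	/* While the message sits unread in sv[1]'s receive queue, the passed
	 * descriptor is "in flight" and charged to the sending user. */
	nullfd = open("/dev/null", O_RDONLY);
	if (nullfd < 0 || send_fd(sv[0], nullfd) < 0)
		return 1;

	fd = recv_fd(sv[1]);
	printf("received fd %d\n", fd);
	return 0;
}

End of aside.]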
+ */ +void unix_inflight(struct user_struct *user, struct file *fp) +{ + struct sock *s = unix_get_socket(fp); + + spin_lock(&unix_gc_lock); + + if (s) { + struct unix_sock *u = unix_sk(s); + + if (atomic_long_inc_return(&u->inflight) == 1) { + BUG_ON(!list_empty(&u->link)); + list_add_tail(&u->link, &gc_inflight_list); + } else { + BUG_ON(list_empty(&u->link)); + } + /* Paired with READ_ONCE() in wait_for_unix_gc() */ + WRITE_ONCE(unix_tot_inflight, unix_tot_inflight + 1); + } + WRITE_ONCE(user->unix_inflight, user->unix_inflight + 1); + spin_unlock(&unix_gc_lock); +} + +void unix_notinflight(struct user_struct *user, struct file *fp) +{ + struct sock *s = unix_get_socket(fp); + + spin_lock(&unix_gc_lock); + + if (s) { + struct unix_sock *u = unix_sk(s); + + BUG_ON(!atomic_long_read(&u->inflight)); + BUG_ON(list_empty(&u->link)); + + if (atomic_long_dec_and_test(&u->inflight)) + list_del_init(&u->link); + /* Paired with READ_ONCE() in wait_for_unix_gc() */ + WRITE_ONCE(unix_tot_inflight, unix_tot_inflight - 1); + } + WRITE_ONCE(user->unix_inflight, user->unix_inflight - 1); + spin_unlock(&unix_gc_lock); +} + +/* + * The "user->unix_inflight" variable is protected by the garbage + * collection lock, and we just read it locklessly here. If you go + * over the limit, there might be a tiny race in actually noticing + * it across threads. Tough. + */ +static inline bool too_many_unix_fds(struct task_struct *p) +{ + struct user_struct *user = current_user(); + + if (unlikely(READ_ONCE(user->unix_inflight) > task_rlimit(p, RLIMIT_NOFILE))) + return !capable(CAP_SYS_RESOURCE) && !capable(CAP_SYS_ADMIN); + return false; +} + +int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) +{ + int i; + + if (too_many_unix_fds(current)) + return -ETOOMANYREFS; + + /* + * Need to duplicate file references for the sake of garbage + * collection. Otherwise a socket in the fps might become a + * candidate for GC while the skb is not yet queued. + */ + UNIXCB(skb).fp = scm_fp_dup(scm->fp); + if (!UNIXCB(skb).fp) + return -ENOMEM; + + for (i = scm->fp->count - 1; i >= 0; i--) + unix_inflight(scm->fp->user, scm->fp->fp[i]); + return 0; +} +EXPORT_SYMBOL(unix_attach_fds); + +void unix_detach_fds(struct scm_cookie *scm, struct sk_buff *skb) +{ + int i; + + scm->fp = UNIXCB(skb).fp; + UNIXCB(skb).fp = NULL; + + for (i = scm->fp->count-1; i >= 0; i--) + unix_notinflight(scm->fp->user, scm->fp->fp[i]); +} +EXPORT_SYMBOL(unix_detach_fds); + +void unix_destruct_scm(struct sk_buff *skb) +{ + struct scm_cookie scm; + + memset(&scm, 0, sizeof(scm)); + scm.pid = UNIXCB(skb).pid; + if (UNIXCB(skb).fp) + unix_detach_fds(&scm, skb); + + /* Alas, it calls VFS */ + /* So fscking what? fput() had been SMP-safe since the last Summer */ + scm_destroy(&scm); + sock_wfree(skb); +} +EXPORT_SYMBOL(unix_destruct_scm); diff --git a/net/unix/scm.h b/net/unix/scm.h new file mode 100644 index 000000000..5a255a477 --- /dev/null +++ b/net/unix/scm.h @@ -0,0 +1,10 @@ +#ifndef NET_UNIX_SCM_H +#define NET_UNIX_SCM_H + +extern struct list_head gc_inflight_list; +extern spinlock_t unix_gc_lock; + +int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb); +void unix_detach_fds(struct scm_cookie *scm, struct sk_buff *skb); + +#endif diff --git a/net/unix/sysctl_net_unix.c b/net/unix/sysctl_net_unix.c new file mode 100644 index 000000000..500129aa7 --- /dev/null +++ b/net/unix/sysctl_net_unix.c @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * NET4: Sysctl interface to net af_unix subsystem. 
+ * + * Authors: Mike Shaver. + */ + +#include +#include +#include + +#include + +static struct ctl_table unix_table[] = { + { + .procname = "max_dgram_qlen", + .data = &init_net.unx.sysctl_max_dgram_qlen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { } +}; + +int __net_init unix_sysctl_register(struct net *net) +{ + struct ctl_table *table; + + if (net_eq(net, &init_net)) { + table = unix_table; + } else { + table = kmemdup(unix_table, sizeof(unix_table), GFP_KERNEL); + if (!table) + goto err_alloc; + + table[0].data = &net->unx.sysctl_max_dgram_qlen; + } + + net->unx.ctl = register_net_sysctl(net, "net/unix", table); + if (net->unx.ctl == NULL) + goto err_reg; + + return 0; + +err_reg: + if (!net_eq(net, &init_net)) + kfree(table); +err_alloc: + return -ENOMEM; +} + +void unix_sysctl_unregister(struct net *net) +{ + struct ctl_table *table; + + table = net->unx.ctl->ctl_table_arg; + unregister_net_sysctl_table(net->unx.ctl); + if (!net_eq(net, &init_net)) + kfree(table); +} diff --git a/net/unix/unix_bpf.c b/net/unix/unix_bpf.c new file mode 100644 index 000000000..bd84785bf --- /dev/null +++ b/net/unix/unix_bpf.c @@ -0,0 +1,198 @@ +// SPDX-License-Identifier: GPL-2.0 +/* Copyright (c) 2021 Cong Wang */ + +#include +#include +#include +#include + +#define unix_sk_has_data(__sk, __psock) \ + ({ !skb_queue_empty(&__sk->sk_receive_queue) || \ + !skb_queue_empty(&__psock->ingress_skb) || \ + !list_empty(&__psock->ingress_msg); \ + }) + +static int unix_msg_wait_data(struct sock *sk, struct sk_psock *psock, + long timeo) +{ + DEFINE_WAIT_FUNC(wait, woken_wake_function); + struct unix_sock *u = unix_sk(sk); + int ret = 0; + + if (sk->sk_shutdown & RCV_SHUTDOWN) + return 1; + + if (!timeo) + return ret; + + add_wait_queue(sk_sleep(sk), &wait); + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); + if (!unix_sk_has_data(sk, psock)) { + mutex_unlock(&u->iolock); + wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); + mutex_lock(&u->iolock); + ret = unix_sk_has_data(sk, psock); + } + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); + remove_wait_queue(sk_sleep(sk), &wait); + return ret; +} + +static int __unix_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags) +{ + if (sk->sk_type == SOCK_DGRAM) + return __unix_dgram_recvmsg(sk, msg, len, flags); + else + return __unix_stream_recvmsg(sk, msg, len, flags); +} + +static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags, int *addr_len) +{ + struct unix_sock *u = unix_sk(sk); + struct sk_psock *psock; + int copied; + + if (!len) + return 0; + + psock = sk_psock_get(sk); + if (unlikely(!psock)) + return __unix_recvmsg(sk, msg, len, flags); + + mutex_lock(&u->iolock); + if (!skb_queue_empty(&sk->sk_receive_queue) && + sk_psock_queue_empty(psock)) { + mutex_unlock(&u->iolock); + sk_psock_put(sk, psock); + return __unix_recvmsg(sk, msg, len, flags); + } + +msg_bytes_ready: + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); + if (!copied) { + long timeo; + int data; + + timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + data = unix_msg_wait_data(sk, psock, timeo); + if (data) { + if (!sk_psock_queue_empty(psock)) + goto msg_bytes_ready; + mutex_unlock(&u->iolock); + sk_psock_put(sk, psock); + return __unix_recvmsg(sk, msg, len, flags); + } + copied = -EAGAIN; + } + mutex_unlock(&u->iolock); + sk_psock_put(sk, psock); + return copied; +} + +static struct proto *unix_dgram_prot_saved __read_mostly; +static DEFINE_SPINLOCK(unix_dgram_prot_lock); +static struct proto unix_dgram_bpf_prot; + 
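[Aside: the sysctl table registered above in net/unix/sysctl_net_unix.c exposes the per-netns datagram backlog limit; unix_net_init() elsewhere in this patch defaults it to 10, and senders block (or see -EAGAIN) once a SOCK_DGRAM receiver has that many unread datagrams queued. A trivial hedged sketch that just reads the knob:

/* max_dgram_qlen_demo.c - read the sysctl registered by unix_sysctl_register(). */
#include <stdio.h>

int main(void)
{
	FILE *f = fopen("/proc/sys/net/unix/max_dgram_qlen", "r");
	int qlen;

	if (!f || fscanf(f, "%d", &qlen) != 1) {
		perror("max_dgram_qlen");
		return 1;
	}
	printf("net.unix.max_dgram_qlen = %d\n", qlen);	/* default 10 */
	fclose(f);
	return 0;
}

End of aside.]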
+static struct proto *unix_stream_prot_saved __read_mostly; +static DEFINE_SPINLOCK(unix_stream_prot_lock); +static struct proto unix_stream_bpf_prot; + +static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base) +{ + *prot = *base; + prot->close = sock_map_close; + prot->recvmsg = unix_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; +} + +static void unix_stream_bpf_rebuild_protos(struct proto *prot, + const struct proto *base) +{ + *prot = *base; + prot->close = sock_map_close; + prot->recvmsg = unix_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; + prot->unhash = sock_map_unhash; +} + +static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops) +{ + if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) { + spin_lock_bh(&unix_dgram_prot_lock); + if (likely(ops != unix_dgram_prot_saved)) { + unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops); + smp_store_release(&unix_dgram_prot_saved, ops); + } + spin_unlock_bh(&unix_dgram_prot_lock); + } +} + +static void unix_stream_bpf_check_needs_rebuild(struct proto *ops) +{ + if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) { + spin_lock_bh(&unix_stream_prot_lock); + if (likely(ops != unix_stream_prot_saved)) { + unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops); + smp_store_release(&unix_stream_prot_saved, ops); + } + spin_unlock_bh(&unix_stream_prot_lock); + } +} + +int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) +{ + if (sk->sk_type != SOCK_DGRAM) + return -EOPNOTSUPP; + + if (restore) { + sk->sk_write_space = psock->saved_write_space; + sock_replace_proto(sk, psock->sk_proto); + return 0; + } + + unix_dgram_bpf_check_needs_rebuild(psock->sk_proto); + sock_replace_proto(sk, &unix_dgram_bpf_prot); + return 0; +} + +int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) +{ + struct sock *sk_pair; + + /* Restore does not decrement the sk_pair reference yet because we must + * keep the a reference to the socket until after an RCU grace period + * and any pending sends have completed. + */ + if (restore) { + sk->sk_write_space = psock->saved_write_space; + sock_replace_proto(sk, psock->sk_proto); + return 0; + } + + /* psock_update_sk_prot can be called multiple times if psock is + * added to multiple maps and/or slots in the same map. There is + * also an edge case where replacing a psock with itself can trigger + * an extra psock_update_sk_prot during the insert process. So it + * must be safe to do multiple calls. Here we need to ensure we don't + * increment the refcnt through sock_hold many times. There will only + * be a single matching destroy operation. 
+ */ + if (!psock->sk_pair) { + sk_pair = unix_peer(sk); + sock_hold(sk_pair); + psock->sk_pair = sk_pair; + } + + unix_stream_bpf_check_needs_rebuild(psock->sk_proto); + sock_replace_proto(sk, &unix_stream_bpf_prot); + return 0; +} + +void __init unix_bpf_build_proto(void) +{ + unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto); + unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto); + +} diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig new file mode 100644 index 000000000..56356d298 --- /dev/null +++ b/net/vmw_vsock/Kconfig @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# Vsock protocol +# + +config VSOCKETS + tristate "Virtual Socket protocol" + help + Virtual Socket Protocol is a socket protocol similar to TCP/IP + allowing communication between Virtual Machines and hypervisor + or host. + + You should also select one or more hypervisor-specific transports + below. + + To compile this driver as a module, choose M here: the module + will be called vsock. If unsure, say N. + +config VSOCKETS_DIAG + tristate "Virtual Sockets monitoring interface" + depends on VSOCKETS + default y + help + Support for PF_VSOCK sockets monitoring interface used by the ss tool. + If unsure, say Y. + + Enable this module so userspace applications can query open sockets. + +config VSOCKETS_LOOPBACK + tristate "Virtual Sockets loopback transport" + depends on VSOCKETS + default y + select VIRTIO_VSOCKETS_COMMON + help + This module implements a loopback transport for Virtual Sockets, + using vmw_vsock_virtio_transport_common. + + To compile this driver as a module, choose M here: the module + will be called vsock_loopback. If unsure, say N. + +config VMWARE_VMCI_VSOCKETS + tristate "VMware VMCI transport for Virtual Sockets" + depends on VSOCKETS && VMWARE_VMCI + help + This module implements a VMCI transport for Virtual Sockets. + + Enable this transport if your Virtual Machine runs on a VMware + hypervisor. + + To compile this driver as a module, choose M here: the module + will be called vmw_vsock_vmci_transport. If unsure, say N. + +config VIRTIO_VSOCKETS + tristate "virtio transport for Virtual Sockets" + depends on VSOCKETS && VIRTIO + select VIRTIO_VSOCKETS_COMMON + help + This module implements a virtio transport for Virtual Sockets. + + Enable this transport if your Virtual Machine host supports Virtual + Sockets over virtio. + + To compile this driver as a module, choose M here: the module will be + called vmw_vsock_virtio_transport. If unsure, say N. + +config VIRTIO_VSOCKETS_COMMON + tristate + help + This option is selected by any driver which needs to access + the virtio_vsock. The module will be called + vmw_vsock_virtio_transport_common. + +config HYPERV_VSOCKETS + tristate "Hyper-V transport for Virtual Sockets" + depends on VSOCKETS && HYPERV + help + This module implements a Hyper-V transport for Virtual Sockets. + + Enable this transport if your Virtual Machine host supports Virtual + Sockets over Hyper-V VMBus. + + To compile this driver as a module, choose M here: the module will be + called hv_sock. If unsure, say N. 
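[Aside: the Kconfig entries above only select transports; from userspace every transport is reached through the same AF_VSOCK socket API. A hedged sketch that exercises the loopback transport (CONFIG_VSOCKETS_LOOPBACK) by connecting to VMADDR_CID_LOCAL from the same process — port 1234 is an arbitrary choice and error handling is minimal:

/* vsock_loopback_demo.c - minimal AF_VSOCK exchange over the loopback transport. */
#include <stdio.h>
#include <sys/socket.h>
#include <unistd.h>
#include <linux/vm_sockets.h>

int main(void)
{
	struct sockaddr_vm addr = {
		.svm_family = AF_VSOCK,
		.svm_cid = VMADDR_CID_ANY,
		.svm_port = 1234,
	};
	int lsn, cli, srv;
	char buf[8] = "";

	lsn = socket(AF_VSOCK, SOCK_STREAM, 0);
	cli = socket(AF_VSOCK, SOCK_STREAM, 0);
	if (lsn < 0 || cli < 0) {
		perror("socket");
		return 1;
	}
	if (bind(lsn, (struct sockaddr *)&addr, sizeof(addr)) < 0 ||
	    listen(lsn, 1) < 0) {
		perror("bind/listen");
		return 1;
	}

	/* Connect back to ourselves through the loopback transport. */
	addr.svm_cid = VMADDR_CID_LOCAL;
	if (connect(cli, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
		perror("connect");
		return 1;
	}

	srv = accept(lsn, NULL, NULL);
	if (srv < 0) {
		perror("accept");
		return 1;
	}

	if (write(cli, "ping", 4) != 4 || read(srv, buf, sizeof(buf) - 1) <= 0)
		return 1;
	printf("server read: %s\n", buf);

	close(srv);
	close(cli);
	close(lsn);
	return 0;
}

End of aside.]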
diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile new file mode 100644 index 000000000..6a943ec95 --- /dev/null +++ b/net/vmw_vsock/Makefile @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_VSOCKETS) += vsock.o +obj-$(CONFIG_VSOCKETS_DIAG) += vsock_diag.o +obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o +obj-$(CONFIG_HYPERV_VSOCKETS) += hv_sock.o +obj-$(CONFIG_VSOCKETS_LOOPBACK) += vsock_loopback.o + +vsock-y += af_vsock.o af_vsock_tap.o vsock_addr.o + +vsock_diag-y += diag.o + +vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ + vmci_transport_notify_qstate.o + +vmw_vsock_virtio_transport-y += virtio_transport.o + +vmw_vsock_virtio_transport_common-y += virtio_transport_common.o + +hv_sock-y += hyperv_transport.o diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c new file mode 100644 index 000000000..84471745c --- /dev/null +++ b/net/vmw_vsock/af_vsock.c @@ -0,0 +1,2462 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. + */ + +/* Implementation notes: + * + * - There are two kinds of sockets: those created by user action (such as + * calling socket(2)) and those created by incoming connection request packets. + * + * - There are two "global" tables, one for bound sockets (sockets that have + * specified an address that they are responsible for) and one for connected + * sockets (sockets that have established a connection with another socket). + * These tables are "global" in that all sockets on the system are placed + * within them. - Note, though, that the bound table contains an extra entry + * for a list of unbound sockets and SOCK_DGRAM sockets will always remain in + * that list. The bound table is used solely for lookup of sockets when packets + * are received and that's not necessary for SOCK_DGRAM sockets since we create + * a datagram handle for each and need not perform a lookup. Keeping SOCK_DGRAM + * sockets out of the bound hash buckets will reduce the chance of collisions + * when looking for SOCK_STREAM sockets and prevents us from having to check the + * socket type in the hash table lookups. + * + * - Sockets created by user action will either be "client" sockets that + * initiate a connection or "server" sockets that listen for connections; we do + * not support simultaneous connects (two "client" sockets connecting). + * + * - "Server" sockets are referred to as listener sockets throughout this + * implementation because they are in the TCP_LISTEN state. When a + * connection request is received (the second kind of socket mentioned above), + * we create a new socket and refer to it as a pending socket. These pending + * sockets are placed on the pending connection list of the listener socket. + * When future packets are received for the address the listener socket is + * bound to, we check if the source of the packet is from one that has an + * existing pending connection. If it does, we process the packet for the + * pending socket. When that socket reaches the connected state, it is removed + * from the listener socket's pending list and enqueued in the listener + * socket's accept queue. Callers of accept(2) will accept connected sockets + * from the listener socket's accept queue. 
If the socket cannot be accepted + * for some reason then it is marked rejected. Once the connection is + * accepted, it is owned by the user process and the responsibility for cleanup + * falls with that user process. + * + * - It is possible that these pending sockets will never reach the connected + * state; in fact, we may never receive another packet after the connection + * request. Because of this, we must schedule a cleanup function to run in the + * future, after some amount of time passes where a connection should have been + * established. This function ensures that the socket is off all lists so it + * cannot be retrieved, then drops all references to the socket so it is cleaned + * up (sock_put() -> sk_free() -> our sk_destruct implementation). Note this + * function will also cleanup rejected sockets, those that reach the connected + * state but leave it before they have been accepted. + * + * - Lock ordering for pending or accept queue sockets is: + * + * lock_sock(listener); + * lock_sock_nested(pending, SINGLE_DEPTH_NESTING); + * + * Using explicit nested locking keeps lockdep happy since normally only one + * lock of a given class may be taken at a time. + * + * - Sockets created by user action will be cleaned up when the user process + * calls close(2), causing our release implementation to be called. Our release + * implementation will perform some cleanup then drop the last reference so our + * sk_destruct implementation is invoked. Our sk_destruct implementation will + * perform additional cleanup that's common for both types of sockets. + * + * - A socket's reference count is what ensures that the structure won't be + * freed. Each entry in a list (such as the "global" bound and connected tables + * and the listener socket's pending list and connected queue) ensures a + * reference. When we defer work until process context and pass a socket as our + * argument, we must ensure the reference count is increased to ensure the + * socket isn't freed before the function is run; the deferred function will + * then drop the reference. + * + * - sk->sk_state uses the TCP state constants because they are widely used by + * other address families and exposed to userspace tools like ss(8): + * + * TCP_CLOSE - unconnected + * TCP_SYN_SENT - connecting + * TCP_ESTABLISHED - connected + * TCP_CLOSING - disconnecting + * TCP_LISTEN - listening + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr); +static void vsock_sk_destruct(struct sock *sk); +static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); + +/* Protocol family. */ +static struct proto vsock_proto = { + .name = "AF_VSOCK", + .owner = THIS_MODULE, + .obj_size = sizeof(struct vsock_sock), +}; + +/* The default peer timeout indicates how long we will wait for a peer response + * to a control message. 
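+ *
+ * As a worked example: VSOCK_DEFAULT_CONNECT_TIMEOUT below is (2 * HZ)
+ * jiffies. Since HZ is the number of jiffies per second, that is two
+ * seconds regardless of the CONFIG_HZ value.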
+ */ +#define VSOCK_DEFAULT_CONNECT_TIMEOUT (2 * HZ) + +#define VSOCK_DEFAULT_BUFFER_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MAX_SIZE (1024 * 256) +#define VSOCK_DEFAULT_BUFFER_MIN_SIZE 128 + +/* Transport used for host->guest communication */ +static const struct vsock_transport *transport_h2g; +/* Transport used for guest->host communication */ +static const struct vsock_transport *transport_g2h; +/* Transport used for DGRAM communication */ +static const struct vsock_transport *transport_dgram; +/* Transport used for local communication */ +static const struct vsock_transport *transport_local; +static DEFINE_MUTEX(vsock_register_mutex); + +/**** UTILS ****/ + +/* Each bound VSocket is stored in the bind hash table and each connected + * VSocket is stored in the connected hash table. + * + * Unbound sockets are all put on the same list attached to the end of the hash + * table (vsock_unbound_sockets). Bound sockets are added to the hash table in + * the bucket that their local address hashes to (vsock_bound_sockets(addr) + * represents the list that addr hashes to). + * + * Specifically, we initialize the vsock_bind_table array to a size of + * VSOCK_HASH_SIZE + 1 so that vsock_bind_table[0] through + * vsock_bind_table[VSOCK_HASH_SIZE - 1] are for bound sockets and + * vsock_bind_table[VSOCK_HASH_SIZE] is for unbound sockets. The hash function + * mods with VSOCK_HASH_SIZE to ensure this. + */ +#define MAX_PORT_RETRIES 24 + +#define VSOCK_HASH(addr) ((addr)->svm_port % VSOCK_HASH_SIZE) +#define vsock_bound_sockets(addr) (&vsock_bind_table[VSOCK_HASH(addr)]) +#define vsock_unbound_sockets (&vsock_bind_table[VSOCK_HASH_SIZE]) + +/* XXX This can probably be implemented in a better way. */ +#define VSOCK_CONN_HASH(src, dst) \ + (((src)->svm_cid ^ (dst)->svm_port) % VSOCK_HASH_SIZE) +#define vsock_connected_sockets(src, dst) \ + (&vsock_connected_table[VSOCK_CONN_HASH(src, dst)]) +#define vsock_connected_sockets_vsk(vsk) \ + vsock_connected_sockets(&(vsk)->remote_addr, &(vsk)->local_addr) + +struct list_head vsock_bind_table[VSOCK_HASH_SIZE + 1]; +EXPORT_SYMBOL_GPL(vsock_bind_table); +struct list_head vsock_connected_table[VSOCK_HASH_SIZE]; +EXPORT_SYMBOL_GPL(vsock_connected_table); +DEFINE_SPINLOCK(vsock_table_lock); +EXPORT_SYMBOL_GPL(vsock_table_lock); + +/* Autobind this socket to the local address if necessary. 
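+ *
+ * For example, a connect() or sendmsg() on a socket that was never bound
+ * lands here and binds it to (VMADDR_CID_ANY, VMADDR_PORT_ANY); for
+ * connectible sockets __vsock_bind_connectible() below then picks a free
+ * port above LAST_RESERVED_PORT.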
*/ +static int vsock_auto_bind(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + struct sockaddr_vm local_addr; + + if (vsock_addr_bound(&vsk->local_addr)) + return 0; + vsock_addr_init(&local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + return __vsock_bind(sk, &local_addr); +} + +static void vsock_init_tables(void) +{ + int i; + + for (i = 0; i < ARRAY_SIZE(vsock_bind_table); i++) + INIT_LIST_HEAD(&vsock_bind_table[i]); + + for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) + INIT_LIST_HEAD(&vsock_connected_table[i]); +} + +static void __vsock_insert_bound(struct list_head *list, + struct vsock_sock *vsk) +{ + sock_hold(&vsk->sk); + list_add(&vsk->bound_table, list); +} + +static void __vsock_insert_connected(struct list_head *list, + struct vsock_sock *vsk) +{ + sock_hold(&vsk->sk); + list_add(&vsk->connected_table, list); +} + +static void __vsock_remove_bound(struct vsock_sock *vsk) +{ + list_del_init(&vsk->bound_table); + sock_put(&vsk->sk); +} + +static void __vsock_remove_connected(struct vsock_sock *vsk) +{ + list_del_init(&vsk->connected_table); + sock_put(&vsk->sk); +} + +static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr) +{ + struct vsock_sock *vsk; + + list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) { + if (vsock_addr_equals_addr(addr, &vsk->local_addr)) + return sk_vsock(vsk); + + if (addr->svm_port == vsk->local_addr.svm_port && + (vsk->local_addr.svm_cid == VMADDR_CID_ANY || + addr->svm_cid == VMADDR_CID_ANY)) + return sk_vsock(vsk); + } + + return NULL; +} + +static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src, + struct sockaddr_vm *dst) +{ + struct vsock_sock *vsk; + + list_for_each_entry(vsk, vsock_connected_sockets(src, dst), + connected_table) { + if (vsock_addr_equals_addr(src, &vsk->remote_addr) && + dst->svm_port == vsk->local_addr.svm_port) { + return sk_vsock(vsk); + } + } + + return NULL; +} + +static void vsock_insert_unbound(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + __vsock_insert_bound(vsock_unbound_sockets, vsk); + spin_unlock_bh(&vsock_table_lock); +} + +void vsock_insert_connected(struct vsock_sock *vsk) +{ + struct list_head *list = vsock_connected_sockets( + &vsk->remote_addr, &vsk->local_addr); + + spin_lock_bh(&vsock_table_lock); + __vsock_insert_connected(list, vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_insert_connected); + +void vsock_remove_bound(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + if (__vsock_in_bound_table(vsk)) + __vsock_remove_bound(vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_remove_bound); + +void vsock_remove_connected(struct vsock_sock *vsk) +{ + spin_lock_bh(&vsock_table_lock); + if (__vsock_in_connected_table(vsk)) + __vsock_remove_connected(vsk); + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_remove_connected); + +struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr) +{ + struct sock *sk; + + spin_lock_bh(&vsock_table_lock); + sk = __vsock_find_bound_socket(addr); + if (sk) + sock_hold(sk); + + spin_unlock_bh(&vsock_table_lock); + + return sk; +} +EXPORT_SYMBOL_GPL(vsock_find_bound_socket); + +struct sock *vsock_find_connected_socket(struct sockaddr_vm *src, + struct sockaddr_vm *dst) +{ + struct sock *sk; + + spin_lock_bh(&vsock_table_lock); + sk = __vsock_find_connected_socket(src, dst); + if (sk) + sock_hold(sk); + + spin_unlock_bh(&vsock_table_lock); + + return sk; +} +EXPORT_SYMBOL_GPL(vsock_find_connected_socket); + +void 
vsock_remove_sock(struct vsock_sock *vsk) +{ + vsock_remove_bound(vsk); + vsock_remove_connected(vsk); +} +EXPORT_SYMBOL_GPL(vsock_remove_sock); + +void vsock_for_each_connected_socket(struct vsock_transport *transport, + void (*fn)(struct sock *sk)) +{ + int i; + + spin_lock_bh(&vsock_table_lock); + + for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) { + struct vsock_sock *vsk; + list_for_each_entry(vsk, &vsock_connected_table[i], + connected_table) { + if (vsk->transport != transport) + continue; + + fn(sk_vsock(vsk)); + } + } + + spin_unlock_bh(&vsock_table_lock); +} +EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket); + +void vsock_add_pending(struct sock *listener, struct sock *pending) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vpending; + + vlistener = vsock_sk(listener); + vpending = vsock_sk(pending); + + sock_hold(pending); + sock_hold(listener); + list_add_tail(&vpending->pending_links, &vlistener->pending_links); +} +EXPORT_SYMBOL_GPL(vsock_add_pending); + +void vsock_remove_pending(struct sock *listener, struct sock *pending) +{ + struct vsock_sock *vpending = vsock_sk(pending); + + list_del_init(&vpending->pending_links); + sock_put(listener); + sock_put(pending); +} +EXPORT_SYMBOL_GPL(vsock_remove_pending); + +void vsock_enqueue_accept(struct sock *listener, struct sock *connected) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vconnected; + + vlistener = vsock_sk(listener); + vconnected = vsock_sk(connected); + + sock_hold(connected); + sock_hold(listener); + list_add_tail(&vconnected->accept_queue, &vlistener->accept_queue); +} +EXPORT_SYMBOL_GPL(vsock_enqueue_accept); + +static bool vsock_use_local_transport(unsigned int remote_cid) +{ + if (!transport_local) + return false; + + if (remote_cid == VMADDR_CID_LOCAL) + return true; + + if (transport_g2h) { + return remote_cid == transport_g2h->get_local_cid(); + } else { + return remote_cid == VMADDR_CID_HOST; + } +} + +static void vsock_deassign_transport(struct vsock_sock *vsk) +{ + if (!vsk->transport) + return; + + vsk->transport->destruct(vsk); + module_put(vsk->transport->module); + vsk->transport = NULL; +} + +/* Assign a transport to a socket and call the .init transport callback. + * + * Note: for connection oriented socket this must be called when vsk->remote_addr + * is set (e.g. during the connect() or when a connection request on a listener + * socket is received). + * The vsk->remote_addr is used to decide which transport to use: + * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if + * g2h is not loaded, will use local transport; + * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field + * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; + * - remote CID > VMADDR_CID_HOST will use host->guest transport; + */ +int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + const struct vsock_transport *new_transport; + struct sock *sk = sk_vsock(vsk); + unsigned int remote_cid = vsk->remote_addr.svm_cid; + __u8 remote_flags; + int ret; + + /* If the packet is coming with the source and destination CIDs higher + * than VMADDR_CID_HOST, then a vsock channel where all the packets are + * forwarded to the host should be established. Then the host will + * need to forward the packets to the guest. + * + * The flag is set on the (listen) receive path (psk is not NULL). On + * the connect path the flag can be set by the user space application. 
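+ *
+ * For illustration (not required by this code): a nested guest can request
+ * such forwarding by setting svm_flags = VMADDR_FLAG_TO_HOST in the
+ * sockaddr_vm it passes to connect().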
+ */ + if (psk && vsk->local_addr.svm_cid > VMADDR_CID_HOST && + vsk->remote_addr.svm_cid > VMADDR_CID_HOST) + vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + + remote_flags = vsk->remote_addr.svm_flags; + + switch (sk->sk_type) { + case SOCK_DGRAM: + new_transport = transport_dgram; + break; + case SOCK_STREAM: + case SOCK_SEQPACKET: + if (vsock_use_local_transport(remote_cid)) + new_transport = transport_local; + else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g || + (remote_flags & VMADDR_FLAG_TO_HOST)) + new_transport = transport_g2h; + else + new_transport = transport_h2g; + break; + default: + return -ESOCKTNOSUPPORT; + } + + if (vsk->transport) { + if (vsk->transport == new_transport) + return 0; + + /* transport->release() must be called with sock lock acquired. + * This path can only be taken during vsock_connect(), where we + * have already held the sock lock. In the other cases, this + * function is called on a new socket which is not assigned to + * any transport. + */ + vsk->transport->release(vsk); + vsock_deassign_transport(vsk); + } + + /* We increase the module refcnt to prevent the transport unloading + * while there are open sockets assigned to it. + */ + if (!new_transport || !try_module_get(new_transport->module)) + return -ENODEV; + + if (sk->sk_type == SOCK_SEQPACKET) { + if (!new_transport->seqpacket_allow || + !new_transport->seqpacket_allow(remote_cid)) { + module_put(new_transport->module); + return -ESOCKTNOSUPPORT; + } + } + + ret = new_transport->init(vsk, psk); + if (ret) { + module_put(new_transport->module); + return ret; + } + + vsk->transport = new_transport; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_assign_transport); + +bool vsock_find_cid(unsigned int cid) +{ + if (transport_g2h && cid == transport_g2h->get_local_cid()) + return true; + + if (transport_h2g && cid == VMADDR_CID_HOST) + return true; + + if (transport_local && cid == VMADDR_CID_LOCAL) + return true; + + return false; +} +EXPORT_SYMBOL_GPL(vsock_find_cid); + +static struct sock *vsock_dequeue_accept(struct sock *listener) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vconnected; + + vlistener = vsock_sk(listener); + + if (list_empty(&vlistener->accept_queue)) + return NULL; + + vconnected = list_entry(vlistener->accept_queue.next, + struct vsock_sock, accept_queue); + + list_del_init(&vconnected->accept_queue); + sock_put(listener); + /* The caller will need a reference on the connected socket so we let + * it call sock_put(). 
+ */ + + return sk_vsock(vconnected); +} + +static bool vsock_is_accept_queue_empty(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + return list_empty(&vsk->accept_queue); +} + +static bool vsock_is_pending(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + return !list_empty(&vsk->pending_links); +} + +static int vsock_send_shutdown(struct sock *sk, int mode) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (!vsk->transport) + return -ENODEV; + + return vsk->transport->shutdown(vsk, mode); +} + +static void vsock_pending_work(struct work_struct *work) +{ + struct sock *sk; + struct sock *listener; + struct vsock_sock *vsk; + bool cleanup; + + vsk = container_of(work, struct vsock_sock, pending_work.work); + sk = sk_vsock(vsk); + listener = vsk->listener; + cleanup = true; + + lock_sock(listener); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); + + if (vsock_is_pending(sk)) { + vsock_remove_pending(listener, sk); + + sk_acceptq_removed(listener); + } else if (!vsk->rejected) { + /* We are not on the pending list and accept() did not reject + * us, so we must have been accepted by our user process. We + * just need to drop our references to the sockets and be on + * our way. + */ + cleanup = false; + goto out; + } + + /* We need to remove ourself from the global connected sockets list so + * incoming packets can't find this socket, and to reduce the reference + * count. + */ + vsock_remove_connected(vsk); + + sk->sk_state = TCP_CLOSE; + +out: + release_sock(sk); + release_sock(listener); + if (cleanup) + sock_put(sk); + + sock_put(sk); + sock_put(listener); +} + +/**** SOCKET OPERATIONS ****/ + +static int __vsock_bind_connectible(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + static u32 port; + struct sockaddr_vm new_addr; + + if (!port) + port = LAST_RESERVED_PORT + 1 + + prandom_u32_max(U32_MAX - LAST_RESERVED_PORT); + + vsock_addr_init(&new_addr, addr->svm_cid, addr->svm_port); + + if (addr->svm_port == VMADDR_PORT_ANY) { + bool found = false; + unsigned int i; + + for (i = 0; i < MAX_PORT_RETRIES; i++) { + if (port <= LAST_RESERVED_PORT) + port = LAST_RESERVED_PORT + 1; + + new_addr.svm_port = port++; + + if (!__vsock_find_bound_socket(&new_addr)) { + found = true; + break; + } + } + + if (!found) + return -EADDRNOTAVAIL; + } else { + /* If port is in reserved range, ensure caller + * has necessary privileges. + */ + if (addr->svm_port <= LAST_RESERVED_PORT && + !capable(CAP_NET_BIND_SERVICE)) { + return -EACCES; + } + + if (__vsock_find_bound_socket(&new_addr)) + return -EADDRINUSE; + } + + vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port); + + /* Remove connection oriented sockets from the unbound list and add them + * to the hash table for easy lookup by its address. The unbound list + * is simply an extra entry at the end of the hash table, a trick used + * by AF_UNIX. + */ + __vsock_remove_bound(vsk); + __vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk); + + return 0; +} + +static int __vsock_bind_dgram(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return vsk->transport->dgram_bind(vsk, addr); +} + +static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr) +{ + struct vsock_sock *vsk = vsock_sk(sk); + int retval; + + /* First ensure this socket isn't already bound. */ + if (vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + /* Now bind to the provided address or select appropriate values if + * none are provided (VMADDR_CID_ANY and VMADDR_PORT_ANY). 
Note that + * like AF_INET prevents binding to a non-local IP address (in most + * cases), we only allow binding to a local CID. + */ + if (addr->svm_cid != VMADDR_CID_ANY && !vsock_find_cid(addr->svm_cid)) + return -EADDRNOTAVAIL; + + switch (sk->sk_socket->type) { + case SOCK_STREAM: + case SOCK_SEQPACKET: + spin_lock_bh(&vsock_table_lock); + retval = __vsock_bind_connectible(vsk, addr); + spin_unlock_bh(&vsock_table_lock); + break; + + case SOCK_DGRAM: + retval = __vsock_bind_dgram(vsk, addr); + break; + + default: + retval = -EINVAL; + break; + } + + return retval; +} + +static void vsock_connect_timeout(struct work_struct *work); + +static struct sock *__vsock_create(struct net *net, + struct socket *sock, + struct sock *parent, + gfp_t priority, + unsigned short type, + int kern) +{ + struct sock *sk; + struct vsock_sock *psk; + struct vsock_sock *vsk; + + sk = sk_alloc(net, AF_VSOCK, priority, &vsock_proto, kern); + if (!sk) + return NULL; + + sock_init_data(sock, sk); + + /* sk->sk_type is normally set in sock_init_data, but only if sock is + * non-NULL. We make sure that our sockets always have a type by + * setting it here if needed. + */ + if (!sock) + sk->sk_type = type; + + vsk = vsock_sk(sk); + vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + + sk->sk_destruct = vsock_sk_destruct; + sk->sk_backlog_rcv = vsock_queue_rcv_skb; + sock_reset_flag(sk, SOCK_DONE); + + INIT_LIST_HEAD(&vsk->bound_table); + INIT_LIST_HEAD(&vsk->connected_table); + vsk->listener = NULL; + INIT_LIST_HEAD(&vsk->pending_links); + INIT_LIST_HEAD(&vsk->accept_queue); + vsk->rejected = false; + vsk->sent_request = false; + vsk->ignore_connecting_rst = false; + vsk->peer_shutdown = 0; + INIT_DELAYED_WORK(&vsk->connect_work, vsock_connect_timeout); + INIT_DELAYED_WORK(&vsk->pending_work, vsock_pending_work); + + psk = parent ? vsock_sk(parent) : NULL; + if (parent) { + vsk->trusted = psk->trusted; + vsk->owner = get_cred(psk->owner); + vsk->connect_timeout = psk->connect_timeout; + vsk->buffer_size = psk->buffer_size; + vsk->buffer_min_size = psk->buffer_min_size; + vsk->buffer_max_size = psk->buffer_max_size; + security_sk_clone(parent, sk); + } else { + vsk->trusted = ns_capable_noaudit(&init_user_ns, CAP_NET_ADMIN); + vsk->owner = get_current_cred(); + vsk->connect_timeout = VSOCK_DEFAULT_CONNECT_TIMEOUT; + vsk->buffer_size = VSOCK_DEFAULT_BUFFER_SIZE; + vsk->buffer_min_size = VSOCK_DEFAULT_BUFFER_MIN_SIZE; + vsk->buffer_max_size = VSOCK_DEFAULT_BUFFER_MAX_SIZE; + } + + return sk; +} + +static bool sock_type_connectible(u16 type) +{ + return (type == SOCK_STREAM) || (type == SOCK_SEQPACKET); +} + +static void __vsock_release(struct sock *sk, int level) +{ + if (sk) { + struct sock *pending; + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + pending = NULL; /* Compiler warning. */ + + /* When "level" is SINGLE_DEPTH_NESTING, use the nested + * version to avoid the warning "possible recursive locking + * detected". When "level" is 0, lock_sock_nested(sk, level) + * is the same as lock_sock(sk). + */ + lock_sock_nested(sk, level); + + if (vsk->transport) + vsk->transport->release(vsk); + else if (sock_type_connectible(sk->sk_type)) + vsock_remove_sock(vsk); + + sock_orphan(sk); + sk->sk_shutdown = SHUTDOWN_MASK; + + skb_queue_purge(&sk->sk_receive_queue); + + /* Clean up any sockets that never were accepted. 
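+ * (i.e. children still sitting in this listener's accept_queue when it is
+ * released; vsock_dequeue_accept() below drains that queue.)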
*/ + while ((pending = vsock_dequeue_accept(sk)) != NULL) { + __vsock_release(pending, SINGLE_DEPTH_NESTING); + sock_put(pending); + } + + release_sock(sk); + sock_put(sk); + } +} + +static void vsock_sk_destruct(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + vsock_deassign_transport(vsk); + + /* When clearing these addresses, there's no need to set the family and + * possibly register the address family with the kernel. + */ + vsock_addr_init(&vsk->local_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); + + put_cred(vsk->owner); +} + +static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb) +{ + int err; + + err = sock_queue_rcv_skb(sk, skb); + if (err) + kfree_skb(skb); + + return err; +} + +struct sock *vsock_create_connected(struct sock *parent) +{ + return __vsock_create(sock_net(parent), NULL, parent, GFP_KERNEL, + parent->sk_type, 0); +} +EXPORT_SYMBOL_GPL(vsock_create_connected); + +s64 vsock_stream_has_data(struct vsock_sock *vsk) +{ + return vsk->transport->stream_has_data(vsk); +} +EXPORT_SYMBOL_GPL(vsock_stream_has_data); + +static s64 vsock_connectible_has_data(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + + if (sk->sk_type == SOCK_SEQPACKET) + return vsk->transport->seqpacket_has_data(vsk); + else + return vsock_stream_has_data(vsk); +} + +s64 vsock_stream_has_space(struct vsock_sock *vsk) +{ + return vsk->transport->stream_has_space(vsk); +} +EXPORT_SYMBOL_GPL(vsock_stream_has_space); + +void vsock_data_ready(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= sk->sk_rcvlowat || + sock_flag(sk, SOCK_DONE)) + sk->sk_data_ready(sk); +} +EXPORT_SYMBOL_GPL(vsock_data_ready); + +static int vsock_release(struct socket *sock) +{ + __vsock_release(sock->sk, 0); + sock->sk = NULL; + sock->state = SS_FREE; + + return 0; +} + +static int +vsock_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + int err; + struct sock *sk; + struct sockaddr_vm *vm_addr; + + sk = sock->sk; + + if (vsock_addr_cast(addr, addr_len, &vm_addr) != 0) + return -EINVAL; + + lock_sock(sk); + err = __vsock_bind(sk, vm_addr); + release_sock(sk); + + return err; +} + +static int vsock_getname(struct socket *sock, + struct sockaddr *addr, int peer) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *vm_addr; + + sk = sock->sk; + vsk = vsock_sk(sk); + err = 0; + + lock_sock(sk); + + if (peer) { + if (sock->state != SS_CONNECTED) { + err = -ENOTCONN; + goto out; + } + vm_addr = &vsk->remote_addr; + } else { + vm_addr = &vsk->local_addr; + } + + if (!vm_addr) { + err = -EINVAL; + goto out; + } + + /* sys_getsockname() and sys_getpeername() pass us a + * MAX_SOCK_ADDR-sized buffer and don't set addr_len. Unfortunately + * that macro is defined in socket.c instead of .h, so we hardcode its + * value here. + */ + BUILD_BUG_ON(sizeof(*vm_addr) > 128); + memcpy(addr, vm_addr, sizeof(*vm_addr)); + err = sizeof(*vm_addr); + +out: + release_sock(sk); + return err; +} + +static int vsock_shutdown(struct socket *sock, int mode) +{ + int err; + struct sock *sk; + + /* User level uses SHUT_RD (0) and SHUT_WR (1), but the kernel uses + * RCV_SHUTDOWN (1) and SEND_SHUTDOWN (2), so we must increment mode + * here like the other address families do. Note also that the + * increment makes SHUT_RDWR (2) into RCV_SHUTDOWN | SEND_SHUTDOWN (3), + * which is what we want. 
+ */ + mode++; + + if ((mode & ~SHUTDOWN_MASK) || !mode) + return -EINVAL; + + /* If this is a connection oriented socket and it is not connected then + * bail out immediately. If it is a DGRAM socket then we must first + * kick the socket so that it wakes up from any sleeping calls, for + * example recv(), and then afterwards return the error. + */ + + sk = sock->sk; + + lock_sock(sk); + if (sock->state == SS_UNCONNECTED) { + err = -ENOTCONN; + if (sock_type_connectible(sk->sk_type)) + goto out; + } else { + sock->state = SS_DISCONNECTING; + err = 0; + } + + /* Receive and send shutdowns are treated alike. */ + mode = mode & (RCV_SHUTDOWN | SEND_SHUTDOWN); + if (mode) { + sk->sk_shutdown |= mode; + sk->sk_state_change(sk); + + if (sock_type_connectible(sk->sk_type)) { + sock_reset_flag(sk, SOCK_DONE); + vsock_send_shutdown(sk, mode); + } + } + +out: + release_sock(sk); + return err; +} + +static __poll_t vsock_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + struct sock *sk; + __poll_t mask; + struct vsock_sock *vsk; + + sk = sock->sk; + vsk = vsock_sk(sk); + + poll_wait(file, sk_sleep(sk), wait); + mask = 0; + + if (sk->sk_err) + /* Signify that there has been an error on this socket. */ + mask |= EPOLLERR; + + /* INET sockets treat local write shutdown and peer write shutdown as a + * case of EPOLLHUP set. + */ + if ((sk->sk_shutdown == SHUTDOWN_MASK) || + ((sk->sk_shutdown & SEND_SHUTDOWN) && + (vsk->peer_shutdown & SEND_SHUTDOWN))) { + mask |= EPOLLHUP; + } + + if (sk->sk_shutdown & RCV_SHUTDOWN || + vsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= EPOLLRDHUP; + } + + if (sock->type == SOCK_DGRAM) { + /* For datagram sockets we can read if there is something in + * the queue and write as long as the socket isn't shutdown for + * sending. + */ + if (!skb_queue_empty_lockless(&sk->sk_receive_queue) || + (sk->sk_shutdown & RCV_SHUTDOWN)) { + mask |= EPOLLIN | EPOLLRDNORM; + } + + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) + mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND; + + } else if (sock_type_connectible(sk->sk_type)) { + const struct vsock_transport *transport; + + lock_sock(sk); + + transport = vsk->transport; + + /* Listening sockets that have connections in their accept + * queue can be read. + */ + if (sk->sk_state == TCP_LISTEN + && !vsock_is_accept_queue_empty(sk)) + mask |= EPOLLIN | EPOLLRDNORM; + + /* If there is something in the queue then we can read. */ + if (transport && transport->stream_is_active(vsk) && + !(sk->sk_shutdown & RCV_SHUTDOWN)) { + bool data_ready_now = false; + int target = sock_rcvlowat(sk, 0, INT_MAX); + int ret = transport->notify_poll_in( + vsk, target, &data_ready_now); + if (ret < 0) { + mask |= EPOLLERR; + } else { + if (data_ready_now) + mask |= EPOLLIN | EPOLLRDNORM; + + } + } + + /* Sockets whose connections have been closed, reset, or + * terminated should also be considered read, and we check the + * shutdown flag for that. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN || + vsk->peer_shutdown & SEND_SHUTDOWN) { + mask |= EPOLLIN | EPOLLRDNORM; + } + + /* Connected sockets that can produce data can be written. */ + if (transport && sk->sk_state == TCP_ESTABLISHED) { + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) { + bool space_avail_now = false; + int ret = transport->notify_poll_out( + vsk, 1, &space_avail_now); + if (ret < 0) { + mask |= EPOLLERR; + } else { + if (space_avail_now) + /* Remove EPOLLWRBAND since INET + * sockets are not setting it. 
+ */ + mask |= EPOLLOUT | EPOLLWRNORM; + + } + } + } + + /* Simulate INET socket poll behaviors, which sets + * EPOLLOUT|EPOLLWRNORM when peer is closed and nothing to read, + * but local send is not shutdown. + */ + if (sk->sk_state == TCP_CLOSE || sk->sk_state == TCP_CLOSING) { + if (!(sk->sk_shutdown & SEND_SHUTDOWN)) + mask |= EPOLLOUT | EPOLLWRNORM; + + } + + release_sock(sk); + } + + return mask; +} + +static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *remote_addr; + const struct vsock_transport *transport; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + /* For now, MSG_DONTWAIT is always assumed... */ + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + transport = vsk->transport; + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + + /* If the provided message contains an address, use that. Otherwise + * fall back on the socket's remote handle (if it has been connected). + */ + if (msg->msg_name && + vsock_addr_cast(msg->msg_name, msg->msg_namelen, + &remote_addr) == 0) { + /* Ensure this address is of the right type and is a valid + * destination. + */ + + if (remote_addr->svm_cid == VMADDR_CID_ANY) + remote_addr->svm_cid = transport->get_local_cid(); + + if (!vsock_addr_bound(remote_addr)) { + err = -EINVAL; + goto out; + } + } else if (sock->state == SS_CONNECTED) { + remote_addr = &vsk->remote_addr; + + if (remote_addr->svm_cid == VMADDR_CID_ANY) + remote_addr->svm_cid = transport->get_local_cid(); + + /* XXX Should connect() or this function ensure remote_addr is + * bound? + */ + if (!vsock_addr_bound(&vsk->remote_addr)) { + err = -EINVAL; + goto out; + } + } else { + err = -EINVAL; + goto out; + } + + if (!transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -EINVAL; + goto out; + } + + err = transport->dgram_enqueue(vsk, remote_addr, msg, len); + +out: + release_sock(sk); + return err; +} + +static int vsock_dgram_connect(struct socket *sock, + struct sockaddr *addr, int addr_len, int flags) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + struct sockaddr_vm *remote_addr; + + sk = sock->sk; + vsk = vsock_sk(sk); + + err = vsock_addr_cast(addr, addr_len, &remote_addr); + if (err == -EAFNOSUPPORT && remote_addr->svm_family == AF_UNSPEC) { + lock_sock(sk); + vsock_addr_init(&vsk->remote_addr, VMADDR_CID_ANY, + VMADDR_PORT_ANY); + sock->state = SS_UNCONNECTED; + release_sock(sk); + return 0; + } else if (err != 0) + return -EINVAL; + + lock_sock(sk); + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + if (!vsk->transport->dgram_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -EINVAL; + goto out; + } + + memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); + sock->state = SS_CONNECTED; + +out: + release_sock(sk); + return err; +} + +static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, + size_t len, int flags) +{ + struct vsock_sock *vsk = vsock_sk(sock->sk); + + return vsk->transport->dgram_dequeue(vsk, msg, len, flags); +} + +static const struct proto_ops vsock_dgram_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_dgram_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = vsock_shutdown, + .sendmsg = vsock_dgram_sendmsg, + 
.recvmsg = vsock_dgram_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) +{ + const struct vsock_transport *transport = vsk->transport; + + if (!transport || !transport->cancel_pkt) + return -EOPNOTSUPP; + + return transport->cancel_pkt(vsk); +} + +static void vsock_connect_timeout(struct work_struct *work) +{ + struct sock *sk; + struct vsock_sock *vsk; + + vsk = container_of(work, struct vsock_sock, connect_work.work); + sk = sk_vsock(vsk); + + lock_sock(sk); + if (sk->sk_state == TCP_SYN_SENT && + (sk->sk_shutdown != SHUTDOWN_MASK)) { + sk->sk_state = TCP_CLOSE; + sk->sk_socket->state = SS_UNCONNECTED; + sk->sk_err = ETIMEDOUT; + sk_error_report(sk); + vsock_transport_cancel_pkt(vsk); + } + release_sock(sk); + + sock_put(sk); +} + +static int vsock_connect(struct socket *sock, struct sockaddr *addr, + int addr_len, int flags) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + struct sockaddr_vm *remote_addr; + long timeout; + DEFINE_WAIT(wait); + + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + /* XXX AF_UNSPEC should make us disconnect like AF_INET. */ + switch (sock->state) { + case SS_CONNECTED: + err = -EISCONN; + goto out; + case SS_DISCONNECTING: + err = -EINVAL; + goto out; + case SS_CONNECTING: + /* This continues on so we can move sock into the SS_CONNECTED + * state once the connection has completed (at which point err + * will be set to zero also). Otherwise, we will either wait + * for the connection or return -EALREADY should this be a + * non-blocking call. + */ + err = -EALREADY; + if (flags & O_NONBLOCK) + goto out; + break; + default: + if ((sk->sk_state == TCP_LISTEN) || + vsock_addr_cast(addr, addr_len, &remote_addr) != 0) { + err = -EINVAL; + goto out; + } + + /* Set the remote address that we are connecting to. */ + memcpy(&vsk->remote_addr, remote_addr, + sizeof(vsk->remote_addr)); + + err = vsock_assign_transport(vsk, NULL); + if (err) + goto out; + + transport = vsk->transport; + + /* The hypervisor and well-known contexts do not have socket + * endpoints. + */ + if (!transport || + !transport->stream_allow(remote_addr->svm_cid, + remote_addr->svm_port)) { + err = -ENETUNREACH; + goto out; + } + + err = vsock_auto_bind(vsk); + if (err) + goto out; + + sk->sk_state = TCP_SYN_SENT; + + err = transport->connect(vsk); + if (err < 0) + goto out; + + /* Mark sock as connecting and set the error code to in + * progress in case this is a non-blocking connect. + */ + sock->state = SS_CONNECTING; + err = -EINPROGRESS; + } + + /* The receive path will handle all communication until we are able to + * enter the connected state. Here we wait for the connection to be + * completed or a notification of an error. + */ + timeout = vsk->connect_timeout; + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + + while (sk->sk_state != TCP_ESTABLISHED && sk->sk_err == 0) { + if (flags & O_NONBLOCK) { + /* If we're not going to block, we schedule a timeout + * function to generate a timeout on the connection + * attempt, in case the peer doesn't respond in a + * timely manner. We hold on to the socket until the + * timeout fires. + */ + sock_hold(sk); + + /* If the timeout function is already scheduled, + * reschedule it, then ungrab the socket refcount to + * keep it balanced. + */ + if (mod_delayed_work(system_wq, &vsk->connect_work, + timeout)) + sock_put(sk); + + /* Skip ahead to preserve error code set above. 
*/ + goto out_wait; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + sk->sk_state = sk->sk_state == TCP_ESTABLISHED ? TCP_CLOSING : TCP_CLOSE; + sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + vsock_remove_connected(vsk); + goto out_wait; + } else if ((sk->sk_state != TCP_ESTABLISHED) && (timeout == 0)) { + err = -ETIMEDOUT; + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); + goto out_wait; + } + + prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE); + } + + if (sk->sk_err) { + err = -sk->sk_err; + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + } else { + err = 0; + } + +out_wait: + finish_wait(sk_sleep(sk), &wait); +out: + release_sock(sk); + return err; +} + +static int vsock_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *listener; + int err; + struct sock *connected; + struct vsock_sock *vconnected; + long timeout; + DEFINE_WAIT(wait); + + err = 0; + listener = sock->sk; + + lock_sock(listener); + + if (!sock_type_connectible(sock->type)) { + err = -EOPNOTSUPP; + goto out; + } + + if (listener->sk_state != TCP_LISTEN) { + err = -EINVAL; + goto out; + } + + /* Wait for children sockets to appear; these are the new sockets + * created upon connection establishment. + */ + timeout = sock_rcvtimeo(listener, flags & O_NONBLOCK); + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + + while ((connected = vsock_dequeue_accept(listener)) == NULL && + listener->sk_err == 0) { + release_sock(listener); + timeout = schedule_timeout(timeout); + finish_wait(sk_sleep(listener), &wait); + lock_sock(listener); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + goto out; + } else if (timeout == 0) { + err = -EAGAIN; + goto out; + } + + prepare_to_wait(sk_sleep(listener), &wait, TASK_INTERRUPTIBLE); + } + finish_wait(sk_sleep(listener), &wait); + + if (listener->sk_err) + err = -listener->sk_err; + + if (connected) { + sk_acceptq_removed(listener); + + lock_sock_nested(connected, SINGLE_DEPTH_NESTING); + vconnected = vsock_sk(connected); + + /* If the listener socket has received an error, then we should + * reject this socket and return. Note that we simply mark the + * socket rejected, drop our reference, and let the cleanup + * function handle the cleanup; the fact that we found it in + * the listener's accept queue guarantees that the cleanup + * function hasn't run yet. 
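+ * (That cleanup function is vsock_pending_work() above, which also tears
+ * down sockets marked rejected here.)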
+ */ + if (err) { + vconnected->rejected = true; + } else { + newsock->state = SS_CONNECTED; + sock_graft(connected, newsock); + } + + release_sock(connected); + sock_put(connected); + } + +out: + release_sock(listener); + return err; +} + +static int vsock_listen(struct socket *sock, int backlog) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + + sk = sock->sk; + + lock_sock(sk); + + if (!sock_type_connectible(sk->sk_type)) { + err = -EOPNOTSUPP; + goto out; + } + + if (sock->state != SS_UNCONNECTED) { + err = -EINVAL; + goto out; + } + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) { + err = -EINVAL; + goto out; + } + + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + + err = 0; + +out: + release_sock(sk); + return err; +} + +static void vsock_update_buffer_size(struct vsock_sock *vsk, + const struct vsock_transport *transport, + u64 val) +{ + if (val > vsk->buffer_max_size) + val = vsk->buffer_max_size; + + if (val < vsk->buffer_min_size) + val = vsk->buffer_min_size; + + if (val != vsk->buffer_size && + transport && transport->notify_buffer_size) + transport->notify_buffer_size(vsk, &val); + + vsk->buffer_size = val; +} + +static int vsock_connectible_setsockopt(struct socket *sock, + int level, + int optname, + sockptr_t optval, + unsigned int optlen) +{ + int err; + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + u64 val; + + if (level != AF_VSOCK) + return -ENOPROTOOPT; + +#define COPY_IN(_v) \ + do { \ + if (optlen < sizeof(_v)) { \ + err = -EINVAL; \ + goto exit; \ + } \ + if (copy_from_sockptr(&_v, optval, sizeof(_v)) != 0) { \ + err = -EFAULT; \ + goto exit; \ + } \ + } while (0) + + err = 0; + sk = sock->sk; + vsk = vsock_sk(sk); + + lock_sock(sk); + + transport = vsk->transport; + + switch (optname) { + case SO_VM_SOCKETS_BUFFER_SIZE: + COPY_IN(val); + vsock_update_buffer_size(vsk, transport, val); + break; + + case SO_VM_SOCKETS_BUFFER_MAX_SIZE: + COPY_IN(val); + vsk->buffer_max_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); + break; + + case SO_VM_SOCKETS_BUFFER_MIN_SIZE: + COPY_IN(val); + vsk->buffer_min_size = val; + vsock_update_buffer_size(vsk, transport, vsk->buffer_size); + break; + + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: { + struct __kernel_sock_timeval tv; + + err = sock_copy_user_timeval(&tv, optval, optlen, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); + if (err) + break; + if (tv.tv_sec >= 0 && tv.tv_usec < USEC_PER_SEC && + tv.tv_sec < (MAX_SCHEDULE_TIMEOUT / HZ - 1)) { + vsk->connect_timeout = tv.tv_sec * HZ + + DIV_ROUND_UP((unsigned long)tv.tv_usec, (USEC_PER_SEC / HZ)); + if (vsk->connect_timeout == 0) + vsk->connect_timeout = + VSOCK_DEFAULT_CONNECT_TIMEOUT; + + } else { + err = -ERANGE; + } + break; + } + + default: + err = -ENOPROTOOPT; + break; + } + +#undef COPY_IN + +exit: + release_sock(sk); + return err; +} + +static int vsock_connectible_getsockopt(struct socket *sock, + int level, int optname, + char __user *optval, + int __user *optlen) +{ + struct sock *sk = sock->sk; + struct vsock_sock *vsk = vsock_sk(sk); + + union { + u64 val64; + struct old_timeval32 tm32; + struct __kernel_old_timeval tm; + struct __kernel_sock_timeval stm; + } v; + + int lv = sizeof(v.val64); + int len; + + if (level != AF_VSOCK) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + + memset(&v, 0, sizeof(v)); + + switch (optname) { + case SO_VM_SOCKETS_BUFFER_SIZE: + v.val64 = vsk->buffer_size; 
+ break; + + case SO_VM_SOCKETS_BUFFER_MAX_SIZE: + v.val64 = vsk->buffer_max_size; + break; + + case SO_VM_SOCKETS_BUFFER_MIN_SIZE: + v.val64 = vsk->buffer_min_size; + break; + + case SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW: + case SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD: + lv = sock_get_timeout(vsk->connect_timeout, &v, + optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD); + break; + + default: + return -ENOPROTOOPT; + } + + if (len < lv) + return -EINVAL; + if (len > lv) + len = lv; + if (copy_to_user(optval, &v, len)) + return -EFAULT; + + if (put_user(len, optlen)) + return -EFAULT; + + return 0; +} + +static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg, + size_t len) +{ + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + ssize_t total_written; + long timeout; + int err; + struct vsock_transport_send_notify_data send_data; + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + sk = sock->sk; + vsk = vsock_sk(sk); + total_written = 0; + err = 0; + + if (msg->msg_flags & MSG_OOB) + return -EOPNOTSUPP; + + lock_sock(sk); + + transport = vsk->transport; + + /* Callers should not provide a destination with connection oriented + * sockets. + */ + if (msg->msg_namelen) { + err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP; + goto out; + } + + /* Send data only if both sides are not shutdown in the direction. */ + if (sk->sk_shutdown & SEND_SHUTDOWN || + vsk->peer_shutdown & RCV_SHUTDOWN) { + err = -EPIPE; + goto out; + } + + if (!transport || sk->sk_state != TCP_ESTABLISHED || + !vsock_addr_bound(&vsk->local_addr)) { + err = -ENOTCONN; + goto out; + } + + if (!vsock_addr_bound(&vsk->remote_addr)) { + err = -EDESTADDRREQ; + goto out; + } + + /* Wait for room in the produce queue to enqueue our user's data. */ + timeout = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + + err = transport->notify_send_init(vsk, &send_data); + if (err < 0) + goto out; + + while (total_written < len) { + ssize_t written; + + add_wait_queue(sk_sleep(sk), &wait); + while (vsock_stream_has_space(vsk) == 0 && + sk->sk_err == 0 && + !(sk->sk_shutdown & SEND_SHUTDOWN) && + !(vsk->peer_shutdown & RCV_SHUTDOWN)) { + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + err = -EAGAIN; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + + err = transport->notify_send_pre_block(vsk, &send_data); + if (err < 0) { + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + + release_sock(sk); + timeout = wait_woken(&wait, TASK_INTERRUPTIBLE, timeout); + lock_sock(sk); + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } else if (timeout == 0) { + err = -EAGAIN; + remove_wait_queue(sk_sleep(sk), &wait); + goto out_err; + } + } + remove_wait_queue(sk_sleep(sk), &wait); + + /* These checks occur both as part of and after the loop + * conditional since we need to check before and after + * sleeping. + */ + if (sk->sk_err) { + err = -sk->sk_err; + goto out_err; + } else if ((sk->sk_shutdown & SEND_SHUTDOWN) || + (vsk->peer_shutdown & RCV_SHUTDOWN)) { + err = -EPIPE; + goto out_err; + } + + err = transport->notify_send_pre_enqueue(vsk, &send_data); + if (err < 0) + goto out_err; + + /* Note that enqueue will only write as many bytes as are free + * in the produce queue, so we don't need to ensure len is + * smaller than the queue size. It is the caller's + * responsibility to check how many bytes we were able to send. 
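+ *
+ * For example, trying to enqueue 1 MiB against the default 256 KiB buffer
+ * simply returns a short count here; the surrounding loop keeps calling
+ * enqueue with len - total_written until everything is sent or an
+ * error/shutdown is hit.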
+ */ + + if (sk->sk_type == SOCK_SEQPACKET) { + written = transport->seqpacket_enqueue(vsk, + msg, len - total_written); + } else { + written = transport->stream_enqueue(vsk, + msg, len - total_written); + } + if (written < 0) { + err = -ENOMEM; + goto out_err; + } + + total_written += written; + + err = transport->notify_send_post_enqueue( + vsk, written, &send_data); + if (err < 0) + goto out_err; + + } + +out_err: + if (total_written > 0) { + /* Return number of written bytes only if: + * 1) SOCK_STREAM socket. + * 2) SOCK_SEQPACKET socket when whole buffer is sent. + */ + if (sk->sk_type == SOCK_STREAM || total_written == len) + err = total_written; + } +out: + release_sock(sk); + return err; +} + +static int vsock_connectible_wait_data(struct sock *sk, + struct wait_queue_entry *wait, + long timeout, + struct vsock_transport_recv_notify_data *recv_data, + size_t target) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + s64 data; + int err; + + vsk = vsock_sk(sk); + err = 0; + transport = vsk->transport; + + while (1) { + prepare_to_wait(sk_sleep(sk), wait, TASK_INTERRUPTIBLE); + data = vsock_connectible_has_data(vsk); + if (data != 0) + break; + + if (sk->sk_err != 0 || + (sk->sk_shutdown & RCV_SHUTDOWN) || + (vsk->peer_shutdown & SEND_SHUTDOWN)) { + break; + } + + /* Don't wait for non-blocking sockets. */ + if (timeout == 0) { + err = -EAGAIN; + break; + } + + if (recv_data) { + err = transport->notify_recv_pre_block(vsk, target, recv_data); + if (err < 0) + break; + } + + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + + if (signal_pending(current)) { + err = sock_intr_errno(timeout); + break; + } else if (timeout == 0) { + err = -EAGAIN; + break; + } + } + + finish_wait(sk_sleep(sk), wait); + + if (err) + return err; + + /* Internal transport error when checking for available + * data. XXX This should be changed to a connection + * reset in a later change. + */ + if (data < 0) + return -ENOMEM; + + return data; +} + +static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags) +{ + struct vsock_transport_recv_notify_data recv_data; + const struct vsock_transport *transport; + struct vsock_sock *vsk; + ssize_t copied; + size_t target; + long timeout; + int err; + + DEFINE_WAIT(wait); + + vsk = vsock_sk(sk); + transport = vsk->transport; + + /* We must not copy less than target bytes into the user's buffer + * before returning successfully, so we wait for the consume queue to + * have that much data to consume before dequeueing. Note that this + * makes it impossible to handle cases where target is greater than the + * queue size. 
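+ *
+ * For example, a MSG_WAITALL read larger than stream_rcvhiwat() makes
+ * target equal to len below and fails fast with -ENOMEM rather than
+ * blocking forever on data that can never be buffered.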
+ */ + target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); + if (target >= transport->stream_rcvhiwat(vsk)) { + err = -ENOMEM; + goto out; + } + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + copied = 0; + + err = transport->notify_recv_init(vsk, target, &recv_data); + if (err < 0) + goto out; + + + while (1) { + ssize_t read; + + err = vsock_connectible_wait_data(sk, &wait, timeout, + &recv_data, target); + if (err <= 0) + break; + + err = transport->notify_recv_pre_dequeue(vsk, target, + &recv_data); + if (err < 0) + break; + + read = transport->stream_dequeue(vsk, msg, len - copied, flags); + if (read < 0) { + err = -ENOMEM; + break; + } + + copied += read; + + err = transport->notify_recv_post_dequeue(vsk, target, read, + !(flags & MSG_PEEK), &recv_data); + if (err < 0) + goto out; + + if (read >= target || flags & MSG_PEEK) + break; + + target -= read; + } + + if (sk->sk_err) + err = -sk->sk_err; + else if (sk->sk_shutdown & RCV_SHUTDOWN) + err = 0; + + if (copied > 0) + err = copied; + +out: + return err; +} + +static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg, + size_t len, int flags) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + ssize_t msg_len; + long timeout; + int err = 0; + DEFINE_WAIT(wait); + + vsk = vsock_sk(sk); + transport = vsk->transport; + + timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); + + err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0); + if (err <= 0) + goto out; + + msg_len = transport->seqpacket_dequeue(vsk, msg, flags); + + if (msg_len < 0) { + err = -ENOMEM; + goto out; + } + + if (sk->sk_err) { + err = -sk->sk_err; + } else if (sk->sk_shutdown & RCV_SHUTDOWN) { + err = 0; + } else { + /* User sets MSG_TRUNC, so return real length of + * packet. + */ + if (flags & MSG_TRUNC) + err = msg_len; + else + err = len - msg_data_left(msg); + + /* Always set MSG_TRUNC if real length of packet is + * bigger than user's buffer. + */ + if (msg_len > len) + msg->msg_flags |= MSG_TRUNC; + } + +out: + return err; +} + +static int +vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, + int flags) +{ + struct sock *sk; + struct vsock_sock *vsk; + const struct vsock_transport *transport; + int err; + + sk = sock->sk; + + if (unlikely(flags & MSG_ERRQUEUE)) + return sock_recv_errqueue(sk, msg, len, SOL_VSOCK, VSOCK_RECVERR); + + vsk = vsock_sk(sk); + err = 0; + + lock_sock(sk); + + transport = vsk->transport; + + if (!transport || sk->sk_state != TCP_ESTABLISHED) { + /* Recvmsg is supposed to return 0 if a peer performs an + * orderly shutdown. Differentiate between that case and when a + * peer has not connected or a local shutdown occurred with the + * SOCK_DONE flag. + */ + if (sock_flag(sk, SOCK_DONE)) + err = 0; + else + err = -ENOTCONN; + + goto out; + } + + if (flags & MSG_OOB) { + err = -EOPNOTSUPP; + goto out; + } + + /* We don't check peer_shutdown flag here since peer may actually shut + * down, but there can be data in the queue that a local socket can + * receive. + */ + if (sk->sk_shutdown & RCV_SHUTDOWN) { + err = 0; + goto out; + } + + /* It is valid on Linux to pass in a zero-length receive buffer. This + * is not an error. We may as well bail out now. 
+ */ + if (!len) { + err = 0; + goto out; + } + + if (sk->sk_type == SOCK_STREAM) + err = __vsock_stream_recvmsg(sk, msg, len, flags); + else + err = __vsock_seqpacket_recvmsg(sk, msg, len, flags); + +out: + release_sock(sk); + return err; +} + +static int vsock_set_rcvlowat(struct sock *sk, int val) +{ + const struct vsock_transport *transport; + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + if (val > vsk->buffer_size) + return -EINVAL; + + transport = vsk->transport; + + if (transport && transport->set_rcvlowat) + return transport->set_rcvlowat(vsk, val); + + WRITE_ONCE(sk->sk_rcvlowat, val ? : 1); + return 0; +} + +static const struct proto_ops vsock_stream_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_connect, + .socketpair = sock_no_socketpair, + .accept = vsock_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = vsock_listen, + .shutdown = vsock_shutdown, + .setsockopt = vsock_connectible_setsockopt, + .getsockopt = vsock_connectible_getsockopt, + .sendmsg = vsock_connectible_sendmsg, + .recvmsg = vsock_connectible_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, + .set_rcvlowat = vsock_set_rcvlowat, +}; + +static const struct proto_ops vsock_seqpacket_ops = { + .family = PF_VSOCK, + .owner = THIS_MODULE, + .release = vsock_release, + .bind = vsock_bind, + .connect = vsock_connect, + .socketpair = sock_no_socketpair, + .accept = vsock_accept, + .getname = vsock_getname, + .poll = vsock_poll, + .ioctl = sock_no_ioctl, + .listen = vsock_listen, + .shutdown = vsock_shutdown, + .setsockopt = vsock_connectible_setsockopt, + .getsockopt = vsock_connectible_getsockopt, + .sendmsg = vsock_connectible_sendmsg, + .recvmsg = vsock_connectible_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static int vsock_create(struct net *net, struct socket *sock, + int protocol, int kern) +{ + struct vsock_sock *vsk; + struct sock *sk; + int ret; + + if (!sock) + return -EINVAL; + + if (protocol && protocol != PF_VSOCK) + return -EPROTONOSUPPORT; + + switch (sock->type) { + case SOCK_DGRAM: + sock->ops = &vsock_dgram_ops; + break; + case SOCK_STREAM: + sock->ops = &vsock_stream_ops; + break; + case SOCK_SEQPACKET: + sock->ops = &vsock_seqpacket_ops; + break; + default: + return -ESOCKTNOSUPPORT; + } + + sock->state = SS_UNCONNECTED; + + sk = __vsock_create(net, sock, NULL, GFP_KERNEL, 0, kern); + if (!sk) + return -ENOMEM; + + vsk = vsock_sk(sk); + + if (sock->type == SOCK_DGRAM) { + ret = vsock_assign_transport(vsk, NULL); + if (ret < 0) { + sock_put(sk); + return ret; + } + } + + vsock_insert_unbound(vsk); + + return 0; +} + +static const struct net_proto_family vsock_family_ops = { + .family = AF_VSOCK, + .create = vsock_create, + .owner = THIS_MODULE, +}; + +static long vsock_dev_do_ioctl(struct file *filp, + unsigned int cmd, void __user *ptr) +{ + u32 __user *p = ptr; + u32 cid = VMADDR_CID_ANY; + int retval = 0; + + switch (cmd) { + case IOCTL_VM_SOCKETS_GET_LOCAL_CID: + /* To be compatible with the VMCI behavior, we prioritize the + * guest CID instead of well-know host CID (VMADDR_CID_HOST). 
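+ *
+ * Roughly how userspace retrieves the local CID via this ioctl (an
+ * illustrative sketch, error handling omitted):
+ *
+ *	int fd = open("/dev/vsock", O_RDONLY);
+ *	unsigned int cid;
+ *
+ *	ioctl(fd, IOCTL_VM_SOCKETS_GET_LOCAL_CID, &cid);
+ *	close(fd);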
+ */ + if (transport_g2h) + cid = transport_g2h->get_local_cid(); + else if (transport_h2g) + cid = transport_h2g->get_local_cid(); + + if (put_user(cid, p) != 0) + retval = -EFAULT; + break; + + default: + retval = -ENOIOCTLCMD; + } + + return retval; +} + +static long vsock_dev_ioctl(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + return vsock_dev_do_ioctl(filp, cmd, (void __user *)arg); +} + +#ifdef CONFIG_COMPAT +static long vsock_dev_compat_ioctl(struct file *filp, + unsigned int cmd, unsigned long arg) +{ + return vsock_dev_do_ioctl(filp, cmd, compat_ptr(arg)); +} +#endif + +static const struct file_operations vsock_device_ops = { + .owner = THIS_MODULE, + .unlocked_ioctl = vsock_dev_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = vsock_dev_compat_ioctl, +#endif + .open = nonseekable_open, +}; + +static struct miscdevice vsock_device = { + .name = "vsock", + .fops = &vsock_device_ops, +}; + +static int __init vsock_init(void) +{ + int err = 0; + + vsock_init_tables(); + + vsock_proto.owner = THIS_MODULE; + vsock_device.minor = MISC_DYNAMIC_MINOR; + err = misc_register(&vsock_device); + if (err) { + pr_err("Failed to register misc device\n"); + goto err_reset_transport; + } + + err = proto_register(&vsock_proto, 1); /* we want our slab */ + if (err) { + pr_err("Cannot register vsock protocol\n"); + goto err_deregister_misc; + } + + err = sock_register(&vsock_family_ops); + if (err) { + pr_err("could not register af_vsock (%d) address family: %d\n", + AF_VSOCK, err); + goto err_unregister_proto; + } + + return 0; + +err_unregister_proto: + proto_unregister(&vsock_proto); +err_deregister_misc: + misc_deregister(&vsock_device); +err_reset_transport: + return err; +} + +static void __exit vsock_exit(void) +{ + misc_deregister(&vsock_device); + sock_unregister(AF_VSOCK); + proto_unregister(&vsock_proto); +} + +const struct vsock_transport *vsock_core_get_transport(struct vsock_sock *vsk) +{ + return vsk->transport; +} +EXPORT_SYMBOL_GPL(vsock_core_get_transport); + +int vsock_core_register(const struct vsock_transport *t, int features) +{ + const struct vsock_transport *t_h2g, *t_g2h, *t_dgram, *t_local; + int err = mutex_lock_interruptible(&vsock_register_mutex); + + if (err) + return err; + + t_h2g = transport_h2g; + t_g2h = transport_g2h; + t_dgram = transport_dgram; + t_local = transport_local; + + if (features & VSOCK_TRANSPORT_F_H2G) { + if (t_h2g) { + err = -EBUSY; + goto err_busy; + } + t_h2g = t; + } + + if (features & VSOCK_TRANSPORT_F_G2H) { + if (t_g2h) { + err = -EBUSY; + goto err_busy; + } + t_g2h = t; + } + + if (features & VSOCK_TRANSPORT_F_DGRAM) { + if (t_dgram) { + err = -EBUSY; + goto err_busy; + } + t_dgram = t; + } + + if (features & VSOCK_TRANSPORT_F_LOCAL) { + if (t_local) { + err = -EBUSY; + goto err_busy; + } + t_local = t; + } + + transport_h2g = t_h2g; + transport_g2h = t_g2h; + transport_dgram = t_dgram; + transport_local = t_local; + +err_busy: + mutex_unlock(&vsock_register_mutex); + return err; +} +EXPORT_SYMBOL_GPL(vsock_core_register); + +void vsock_core_unregister(const struct vsock_transport *t) +{ + mutex_lock(&vsock_register_mutex); + + if (transport_h2g == t) + transport_h2g = NULL; + + if (transport_g2h == t) + transport_g2h = NULL; + + if (transport_dgram == t) + transport_dgram = NULL; + + if (transport_local == t) + transport_local = NULL; + + mutex_unlock(&vsock_register_mutex); +} +EXPORT_SYMBOL_GPL(vsock_core_unregister); + +module_init(vsock_init); +module_exit(vsock_exit); + +MODULE_AUTHOR("VMware, Inc."); 
+MODULE_DESCRIPTION("VMware Virtual Socket Family"); +MODULE_VERSION("1.0.2.0-k"); +MODULE_LICENSE("GPL v2"); diff --git a/net/vmw_vsock/af_vsock_tap.c b/net/vmw_vsock/af_vsock_tap.c new file mode 100644 index 000000000..30ee7e4fe --- /dev/null +++ b/net/vmw_vsock/af_vsock_tap.c @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * Tap functions for AF_VSOCK sockets. + * + * Code based on net/netlink/af_netlink.c tap functions. + */ + +#include +#include +#include +#include + +static DEFINE_SPINLOCK(vsock_tap_lock); +static struct list_head vsock_tap_all __read_mostly = + LIST_HEAD_INIT(vsock_tap_all); + +int vsock_add_tap(struct vsock_tap *vt) +{ + if (unlikely(vt->dev->type != ARPHRD_VSOCKMON)) + return -EINVAL; + + __module_get(vt->module); + + spin_lock(&vsock_tap_lock); + list_add_rcu(&vt->list, &vsock_tap_all); + spin_unlock(&vsock_tap_lock); + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_add_tap); + +int vsock_remove_tap(struct vsock_tap *vt) +{ + struct vsock_tap *tmp; + bool found = false; + + spin_lock(&vsock_tap_lock); + + list_for_each_entry(tmp, &vsock_tap_all, list) { + if (vt == tmp) { + list_del_rcu(&vt->list); + found = true; + goto out; + } + } + + pr_warn("vsock_remove_tap: %p not found\n", vt); +out: + spin_unlock(&vsock_tap_lock); + + synchronize_net(); + + if (found) + module_put(vt->module); + + return found ? 0 : -ENODEV; +} +EXPORT_SYMBOL_GPL(vsock_remove_tap); + +static int __vsock_deliver_tap_skb(struct sk_buff *skb, + struct net_device *dev) +{ + int ret = 0; + struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC); + + if (nskb) { + dev_hold(dev); + + nskb->dev = dev; + ret = dev_queue_xmit(nskb); + if (unlikely(ret > 0)) + ret = net_xmit_errno(ret); + + dev_put(dev); + } + + return ret; +} + +static void __vsock_deliver_tap(struct sk_buff *skb) +{ + int ret; + struct vsock_tap *tmp; + + list_for_each_entry_rcu(tmp, &vsock_tap_all, list) { + ret = __vsock_deliver_tap_skb(skb, tmp->dev); + if (unlikely(ret)) + break; + } +} + +void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque) +{ + struct sk_buff *skb; + + rcu_read_lock(); + + if (likely(list_empty(&vsock_tap_all))) + goto out; + + skb = build_skb(opaque); + if (skb) { + __vsock_deliver_tap(skb); + consume_skb(skb); + } + +out: + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(vsock_deliver_tap); diff --git a/net/vmw_vsock/diag.c b/net/vmw_vsock/diag.c new file mode 100644 index 000000000..a2823b1c5 --- /dev/null +++ b/net/vmw_vsock/diag.c @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * vsock sock_diag(7) module + * + * Copyright (C) 2017 Red Hat, Inc. + * Author: Stefan Hajnoczi + */ + +#include +#include +#include +#include + +static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, + u32 portid, u32 seq, u32 flags) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_diag_msg *rep; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*rep), + flags); + if (!nlh) + return -EMSGSIZE; + + rep = nlmsg_data(nlh); + rep->vdiag_family = AF_VSOCK; + + /* Lock order dictates that sk_lock is acquired before + * vsock_table_lock, so we cannot lock here. Simply don't take + * sk_lock; sk is guaranteed to stay alive since vsock_table_lock is + * held. 
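+ *
+ * The snapshot filled in below may therefore be slightly stale or
+ * internally inconsistent; that is acceptable for diagnostics, since
+ * sock_diag consumers (e.g. ss(8) from iproute2) only need a
+ * best-effort view.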
+ */ + rep->vdiag_type = sk->sk_type; + rep->vdiag_state = sk->sk_state; + rep->vdiag_shutdown = sk->sk_shutdown; + rep->vdiag_src_cid = vsk->local_addr.svm_cid; + rep->vdiag_src_port = vsk->local_addr.svm_port; + rep->vdiag_dst_cid = vsk->remote_addr.svm_cid; + rep->vdiag_dst_port = vsk->remote_addr.svm_port; + rep->vdiag_ino = sock_i_ino(sk); + + sock_diag_save_cookie(sk, rep->vdiag_cookie); + + return 0; +} + +static int vsock_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct vsock_diag_req *req; + struct vsock_sock *vsk; + unsigned int bucket; + unsigned int last_i; + unsigned int table; + struct net *net; + unsigned int i; + + req = nlmsg_data(cb->nlh); + net = sock_net(skb->sk); + + /* State saved between calls: */ + table = cb->args[0]; + bucket = cb->args[1]; + i = last_i = cb->args[2]; + + /* TODO VMCI pending sockets? */ + + spin_lock_bh(&vsock_table_lock); + + /* Bind table (locally created sockets) */ + if (table == 0) { + while (bucket < ARRAY_SIZE(vsock_bind_table)) { + struct list_head *head = &vsock_bind_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, bound_table) { + struct sock *sk = sk_vsock(vsk); + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_bind; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_bind; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_bind: + i++; + } + last_i = 0; + bucket++; + } + + table++; + bucket = 0; + } + + /* Connected table (accepted connections) */ + while (bucket < ARRAY_SIZE(vsock_connected_table)) { + struct list_head *head = &vsock_connected_table[bucket]; + + i = 0; + list_for_each_entry(vsk, head, connected_table) { + struct sock *sk = sk_vsock(vsk); + + /* Skip sockets we've already seen above */ + if (__vsock_in_bound_table(vsk)) + continue; + + if (!net_eq(sock_net(sk), net)) + continue; + if (i < last_i) + goto next_connected; + if (!(req->vdiag_states & (1 << sk->sk_state))) + goto next_connected; + if (sk_diag_fill(sk, skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI) < 0) + goto done; +next_connected: + i++; + } + last_i = 0; + bucket++; + } + +done: + spin_unlock_bh(&vsock_table_lock); + + cb->args[0] = table; + cb->args[1] = bucket; + cb->args[2] = i; + + return skb->len; +} + +static int vsock_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +{ + int hdrlen = sizeof(struct vsock_diag_req); + struct net *net = sock_net(skb->sk); + + if (nlmsg_len(h) < hdrlen) + return -EINVAL; + + if (h->nlmsg_flags & NLM_F_DUMP) { + struct netlink_dump_control c = { + .dump = vsock_diag_dump, + }; + return netlink_dump_start(net->diag_nlsk, skb, h, &c); + } + + return -EOPNOTSUPP; +} + +static const struct sock_diag_handler vsock_diag_handler = { + .family = AF_VSOCK, + .dump = vsock_diag_handler_dump, +}; + +static int __init vsock_diag_init(void) +{ + return sock_diag_register(&vsock_diag_handler); +} + +static void __exit vsock_diag_exit(void) +{ + sock_diag_unregister(&vsock_diag_handler); +} + +module_init(vsock_diag_init); +module_exit(vsock_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, + 40 /* AF_VSOCK */); diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c new file mode 100644 index 000000000..59c3e2697 --- /dev/null +++ b/net/vmw_vsock/hyperv_transport.c @@ -0,0 +1,956 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * Hyper-V transport for vsock + * + * Hyper-V Sockets 
supplies a byte-stream based communication mechanism + * between the host and the VM. This driver implements the necessary + * support in the VM by introducing the new vsock transport. + * + * Copyright (c) 2017, Microsoft Corporation. + */ +#include +#include +#include +#include +#include +#include + +/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some + * stricter requirements on the hv_sock ring buffer size of six 4K pages. + * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this + * limitation; but, keep the defaults the same for compat. + */ +#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6) +#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64) + +/* The MTU is 16KB per the host side's design */ +#define HVS_MTU_SIZE (1024 * 16) + +/* How long to wait for graceful shutdown of a connection */ +#define HVS_CLOSE_TIMEOUT (8 * HZ) + +struct vmpipe_proto_header { + u32 pkt_type; + u32 data_size; +}; + +/* For recv, we use the VMBus in-place packet iterator APIs to directly copy + * data from the ringbuffer into the userspace buffer. + */ +struct hvs_recv_buf { + /* The header before the payload data */ + struct vmpipe_proto_header hdr; + + /* The payload */ + u8 data[HVS_MTU_SIZE]; +}; + +/* We can send up to HVS_MTU_SIZE bytes of payload to the host, but let's use + * a smaller size, i.e. HVS_SEND_BUF_SIZE, to maximize concurrency between the + * guest and the host processing as one VMBUS packet is the smallest processing + * unit. + * + * Note: the buffer can be eliminated in the future when we add new VMBus + * ringbuffer APIs that allow us to directly copy data from userspace buffer + * to VMBus ringbuffer. + */ +#define HVS_SEND_BUF_SIZE \ + (HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header)) + +struct hvs_send_buf { + /* The header before the payload data */ + struct vmpipe_proto_header hdr; + + /* The payload */ + u8 data[HVS_SEND_BUF_SIZE]; +}; + +#define HVS_HEADER_LEN (sizeof(struct vmpacket_descriptor) + \ + sizeof(struct vmpipe_proto_header)) + +/* See 'prev_indices' in hv_ringbuffer_read(), hv_ringbuffer_write(), and + * __hv_pkt_iter_next(). + */ +#define VMBUS_PKT_TRAILER_SIZE (sizeof(u64)) + +#define HVS_PKT_LEN(payload_len) (HVS_HEADER_LEN + \ + ALIGN((payload_len), 8) + \ + VMBUS_PKT_TRAILER_SIZE) + +/* Upper bound on the size of a VMbus packet for hv_sock */ +#define HVS_MAX_PKT_SIZE HVS_PKT_LEN(HVS_MTU_SIZE) + +union hvs_service_id { + guid_t srv_id; + + struct { + unsigned int svm_port; + unsigned char b[sizeof(guid_t) - sizeof(unsigned int)]; + }; +}; + +/* Per-socket state (accessed via vsk->trans) */ +struct hvsock { + struct vsock_sock *vsk; + + guid_t vm_srv_id; + guid_t host_srv_id; + + struct vmbus_channel *chan; + struct vmpacket_descriptor *recv_desc; + + /* The length of the payload not delivered to userland yet */ + u32 recv_data_len; + /* The offset of the payload */ + u32 recv_data_off; + + /* Have we sent the zero-length packet (FIN)? */ + bool fin_sent; +}; + +/* In the VM, we support Hyper-V Sockets with AF_VSOCK, and the endpoint is + * (see struct sockaddr_vm). Note: cid is not really used here: + * when we write apps to connect to the host, we can only use VMADDR_CID_ANY + * or VMADDR_CID_HOST (both are equivalent) as the remote cid, and when we + * write apps to bind() & listen() in the VM, we can only use VMADDR_CID_ANY + * as the local cid. 
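+ *
+ * Illustrative guest-side connect() to host service port 5000 (sketch,
+ * error handling omitted; the port number is arbitrary):
+ *
+ *	struct sockaddr_vm sa = {
+ *		.svm_family = AF_VSOCK,
+ *		.svm_cid    = VMADDR_CID_HOST,
+ *		.svm_port   = 5000,
+ *	};
+ *	int fd = socket(AF_VSOCK, SOCK_STREAM, 0);
+ *	connect(fd, (struct sockaddr *)&sa, sizeof(sa));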
+ * + * On the host, Hyper-V Sockets are supported by Winsock AF_HYPERV: + * https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user- + * guide/make-integration-service, and the endpoint is with + * the below sockaddr: + * + * struct SOCKADDR_HV + * { + * ADDRESS_FAMILY Family; + * USHORT Reserved; + * GUID VmId; + * GUID ServiceId; + * }; + * Note: VmID is not used by Linux VM and actually it isn't transmitted via + * VMBus, because here it's obvious the host and the VM can easily identify + * each other. Though the VmID is useful on the host, especially in the case + * of Windows container, Linux VM doesn't need it at all. + * + * To make use of the AF_VSOCK infrastructure in Linux VM, we have to limit + * the available GUID space of SOCKADDR_HV so that we can create a mapping + * between AF_VSOCK port and SOCKADDR_HV Service GUID. The rule of writing + * Hyper-V Sockets apps on the host and in Linux VM is: + * + **************************************************************************** + * The only valid Service GUIDs, from the perspectives of both the host and * + * Linux VM, that can be connected by the other end, must conform to this * + * format: -facb-11e6-bd58-64006a7986d3. * + **************************************************************************** + * + * When we write apps on the host to connect(), the GUID ServiceID is used. + * When we write apps in Linux VM to connect(), we only need to specify the + * port and the driver will form the GUID and use that to request the host. + * + */ + +/* 00000000-facb-11e6-bd58-64006a7986d3 */ +static const guid_t srv_id_template = + GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58, + 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3); + +static bool hvs_check_transport(struct vsock_sock *vsk); + +static bool is_valid_srv_id(const guid_t *id) +{ + return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4); +} + +static unsigned int get_port_by_srv_id(const guid_t *svr_id) +{ + return *((unsigned int *)svr_id); +} + +static void hvs_addr_init(struct sockaddr_vm *addr, const guid_t *svr_id) +{ + unsigned int port = get_port_by_srv_id(svr_id); + + vsock_addr_init(addr, VMADDR_CID_ANY, port); +} + +static void hvs_set_channel_pending_send_size(struct vmbus_channel *chan) +{ + set_channel_pending_send_size(chan, + HVS_PKT_LEN(HVS_SEND_BUF_SIZE)); + + virt_mb(); +} + +static bool hvs_channel_readable(struct vmbus_channel *chan) +{ + u32 readable = hv_get_bytes_to_read(&chan->inbound); + + /* 0-size payload means FIN */ + return readable >= HVS_PKT_LEN(0); +} + +static int hvs_channel_readable_payload(struct vmbus_channel *chan) +{ + u32 readable = hv_get_bytes_to_read(&chan->inbound); + + if (readable > HVS_PKT_LEN(0)) { + /* At least we have 1 byte to read. We don't need to return + * the exact readable bytes: see vsock_stream_recvmsg() -> + * vsock_stream_has_data(). + */ + return 1; + } + + if (readable == HVS_PKT_LEN(0)) { + /* 0-size payload means FIN */ + return 0; + } + + /* No payload or FIN */ + return -1; +} + +static size_t hvs_channel_writable_bytes(struct vmbus_channel *chan) +{ + u32 writeable = hv_get_bytes_to_write(&chan->outbound); + size_t ret; + + /* The ringbuffer mustn't be 100% full, and we should reserve a + * zero-length-payload packet for the FIN: see hv_ringbuffer_write() + * and hvs_shutdown(). 
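+ *
+ * Reserved space, worked out from the macros above:
+ *
+ *	HVS_PKT_LEN(n) = HVS_HEADER_LEN + ALIGN(n, 8) + VMBUS_PKT_TRAILER_SIZE
+ *
+ * so HVS_PKT_LEN(1) + HVS_PKT_LEN(0) keeps room for one minimal data
+ * packet plus the zero-length FIN (assuming a 16-byte vmpacket_descriptor
+ * and the 8-byte pipe header, that is 40 + 32 = 72 bytes).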
+ */ + if (writeable <= HVS_PKT_LEN(1) + HVS_PKT_LEN(0)) + return 0; + + ret = writeable - HVS_PKT_LEN(1) - HVS_PKT_LEN(0); + + return round_down(ret, 8); +} + +static int __hvs_send_data(struct vmbus_channel *chan, + struct vmpipe_proto_header *hdr, + size_t to_write) +{ + hdr->pkt_type = 1; + hdr->data_size = to_write; + return vmbus_sendpacket(chan, hdr, sizeof(*hdr) + to_write, + 0, VM_PKT_DATA_INBAND, 0); +} + +static int hvs_send_data(struct vmbus_channel *chan, + struct hvs_send_buf *send_buf, size_t to_write) +{ + return __hvs_send_data(chan, &send_buf->hdr, to_write); +} + +static void hvs_channel_cb(void *ctx) +{ + struct sock *sk = (struct sock *)ctx; + struct vsock_sock *vsk = vsock_sk(sk); + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + + if (hvs_channel_readable(chan)) + sk->sk_data_ready(sk); + + if (hv_get_bytes_to_write(&chan->outbound) > 0) + sk->sk_write_space(sk); +} + +static void hvs_do_close_lock_held(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + sk->sk_state_change(sk); + if (vsk->close_work_scheduled && + (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { + vsk->close_work_scheduled = false; + vsock_remove_sock(vsk); + + /* Release the reference taken while scheduling the timeout */ + sock_put(sk); + } +} + +static void hvs_close_connection(struct vmbus_channel *chan) +{ + struct sock *sk = get_per_channel_state(chan); + + lock_sock(sk); + hvs_do_close_lock_held(vsock_sk(sk), true); + release_sock(sk); + + /* Release the refcnt for the channel that's opened in + * hvs_open_connection(). + */ + sock_put(sk); +} + +static void hvs_open_connection(struct vmbus_channel *chan) +{ + guid_t *if_instance, *if_type; + unsigned char conn_from_host; + + struct sockaddr_vm addr; + struct sock *sk, *new = NULL; + struct vsock_sock *vnew = NULL; + struct hvsock *hvs = NULL; + struct hvsock *hvs_new = NULL; + int rcvbuf; + int ret; + int sndbuf; + + if_type = &chan->offermsg.offer.if_type; + if_instance = &chan->offermsg.offer.if_instance; + conn_from_host = chan->offermsg.offer.u.pipe.user_def[0]; + if (!is_valid_srv_id(if_type)) + return; + + hvs_addr_init(&addr, conn_from_host ? if_type : if_instance); + sk = vsock_find_bound_socket(&addr); + if (!sk) + return; + + lock_sock(sk); + if ((conn_from_host && sk->sk_state != TCP_LISTEN) || + (!conn_from_host && sk->sk_state != TCP_SYN_SENT)) + goto out; + + if (conn_from_host) { + if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) + goto out; + + new = vsock_create_connected(sk); + if (!new) + goto out; + + new->sk_state = TCP_SYN_SENT; + vnew = vsock_sk(new); + + hvs_addr_init(&vnew->local_addr, if_type); + + /* Remote peer is always the host */ + vsock_addr_init(&vnew->remote_addr, + VMADDR_CID_HOST, VMADDR_PORT_ANY); + vnew->remote_addr.svm_port = get_port_by_srv_id(if_instance); + ret = vsock_assign_transport(vnew, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the + * same where we received the request. + */ + if (ret || !hvs_check_transport(vnew)) { + sock_put(new); + goto out; + } + hvs_new = vnew->trans; + hvs_new->chan = chan; + } else { + hvs = vsock_sk(sk)->trans; + hvs->chan = chan; + } + + set_channel_read_mode(chan, HV_CALL_DIRECT); + + /* Use the socket buffer sizes as hints for the VMBUS ring size. 
For + * server side sockets, 'sk' is the parent socket and thus, this will + * allow the child sockets to inherit the size from the parent. Keep + * the mins to the default value and align to page size as per VMBUS + * requirements. + * For the max, the socket core library will limit the socket buffer + * size that can be set by the user, but, since currently, the hv_sock + * VMBUS ring buffer is physically contiguous allocation, restrict it + * further. + * Older versions of hv_sock host side code cannot handle bigger VMBUS + * ring buffer size. Use the version number to limit the change to newer + * versions. + */ + if (vmbus_proto_version < VERSION_WIN10_V5) { + sndbuf = RINGBUFFER_HVS_SND_SIZE; + rcvbuf = RINGBUFFER_HVS_RCV_SIZE; + } else { + sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE); + sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE); + sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE); + rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE); + rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE); + rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE); + } + + chan->max_pkt_size = HVS_MAX_PKT_SIZE; + + ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb, + conn_from_host ? new : sk); + if (ret != 0) { + if (conn_from_host) { + hvs_new->chan = NULL; + sock_put(new); + } else { + hvs->chan = NULL; + } + goto out; + } + + set_per_channel_state(chan, conn_from_host ? new : sk); + + /* This reference will be dropped by hvs_close_connection(). */ + sock_hold(conn_from_host ? new : sk); + vmbus_set_chn_rescind_callback(chan, hvs_close_connection); + + /* Set the pending send size to max packet size to always get + * notifications from the host when there is enough writable space. + * The host is optimized to send notifications only when the pending + * size boundary is crossed, and not always. + */ + hvs_set_channel_pending_send_size(chan); + + if (conn_from_host) { + new->sk_state = TCP_ESTABLISHED; + sk_acceptq_added(sk); + + hvs_new->vm_srv_id = *if_type; + hvs_new->host_srv_id = *if_instance; + + vsock_insert_connected(vnew); + + vsock_enqueue_accept(sk, new); + } else { + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + + vsock_insert_connected(vsock_sk(sk)); + } + + sk->sk_state_change(sk); + +out: + /* Release refcnt obtained when we called vsock_find_bound_socket() */ + sock_put(sk); + + release_sock(sk); +} + +static u32 hvs_get_local_cid(void) +{ + return VMADDR_CID_ANY; +} + +static int hvs_sock_init(struct vsock_sock *vsk, struct vsock_sock *psk) +{ + struct hvsock *hvs; + struct sock *sk = sk_vsock(vsk); + + hvs = kzalloc(sizeof(*hvs), GFP_KERNEL); + if (!hvs) + return -ENOMEM; + + vsk->trans = hvs; + hvs->vsk = vsk; + sk->sk_sndbuf = RINGBUFFER_HVS_SND_SIZE; + sk->sk_rcvbuf = RINGBUFFER_HVS_RCV_SIZE; + return 0; +} + +static int hvs_connect(struct vsock_sock *vsk) +{ + union hvs_service_id vm, host; + struct hvsock *h = vsk->trans; + + vm.srv_id = srv_id_template; + vm.svm_port = vsk->local_addr.svm_port; + h->vm_srv_id = vm.srv_id; + + host.srv_id = srv_id_template; + host.svm_port = vsk->remote_addr.svm_port; + h->host_srv_id = host.srv_id; + + return vmbus_send_tl_connect_request(&h->vm_srv_id, &h->host_srv_id); +} + +static void hvs_shutdown_lock_held(struct hvsock *hvs, int mode) +{ + struct vmpipe_proto_header hdr; + + if (hvs->fin_sent || !hvs->chan) + return; + + /* It can't fail: see hvs_channel_writable_bytes(). 
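+ * hvs_channel_writable_bytes() always holds back HVS_PKT_LEN(0) bytes of
+ * ring space from the stream send path, so this zero-length FIN packet
+ * fits even when senders have been told there is no room left.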
*/ + (void)__hvs_send_data(hvs->chan, &hdr, 0); + hvs->fin_sent = true; +} + +static int hvs_shutdown(struct vsock_sock *vsk, int mode) +{ + if (!(mode & SEND_SHUTDOWN)) + return 0; + + hvs_shutdown_lock_held(vsk->trans, mode); + return 0; +} + +static void hvs_close_timeout(struct work_struct *work) +{ + struct vsock_sock *vsk = + container_of(work, struct vsock_sock, close_work.work); + struct sock *sk = sk_vsock(vsk); + + sock_hold(sk); + lock_sock(sk); + if (!sock_flag(sk, SOCK_DONE)) + hvs_do_close_lock_held(vsk, false); + + vsk->close_work_scheduled = false; + release_sock(sk); + sock_put(sk); +} + +/* Returns true, if it is safe to remove socket; false otherwise */ +static bool hvs_close_lock_held(struct vsock_sock *vsk) +{ + struct sock *sk = sk_vsock(vsk); + + if (!(sk->sk_state == TCP_ESTABLISHED || + sk->sk_state == TCP_CLOSING)) + return true; + + if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) + hvs_shutdown_lock_held(vsk->trans, SHUTDOWN_MASK); + + if (sock_flag(sk, SOCK_DONE)) + return true; + + /* This reference will be dropped by the delayed close routine */ + sock_hold(sk); + INIT_DELAYED_WORK(&vsk->close_work, hvs_close_timeout); + vsk->close_work_scheduled = true; + schedule_delayed_work(&vsk->close_work, HVS_CLOSE_TIMEOUT); + return false; +} + +static void hvs_release(struct vsock_sock *vsk) +{ + bool remove_sock; + + remove_sock = hvs_close_lock_held(vsk); + if (remove_sock) + vsock_remove_sock(vsk); +} + +static void hvs_destruct(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + + if (chan) + vmbus_hvsock_device_unregister(chan); + + kfree(hvs); +} + +static int hvs_dgram_bind(struct vsock_sock *vsk, struct sockaddr_vm *addr) +{ + return -EOPNOTSUPP; +} + +static int hvs_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len, int flags) +{ + return -EOPNOTSUPP; +} + +static int hvs_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote, struct msghdr *msg, + size_t dgram_len) +{ + return -EOPNOTSUPP; +} + +static bool hvs_dgram_allow(u32 cid, u32 port) +{ + return false; +} + +static int hvs_update_recv_data(struct hvsock *hvs) +{ + struct hvs_recv_buf *recv_buf; + u32 pkt_len, payload_len; + + pkt_len = hv_pkt_len(hvs->recv_desc); + + if (pkt_len < HVS_HEADER_LEN) + return -EIO; + + recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1); + payload_len = recv_buf->hdr.data_size; + + if (payload_len > pkt_len - HVS_HEADER_LEN || + payload_len > HVS_MTU_SIZE) + return -EIO; + + if (payload_len == 0) + hvs->vsk->peer_shutdown |= SEND_SHUTDOWN; + + hvs->recv_data_len = payload_len; + hvs->recv_data_off = 0; + + return 0; +} + +static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len, int flags) +{ + struct hvsock *hvs = vsk->trans; + bool need_refill = !hvs->recv_desc; + struct hvs_recv_buf *recv_buf; + u32 to_read; + int ret; + + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + + if (need_refill) { + hvs->recv_desc = hv_pkt_iter_first(hvs->chan); + if (!hvs->recv_desc) + return -ENOBUFS; + ret = hvs_update_recv_data(hvs); + if (ret) + return ret; + } + + recv_buf = (struct hvs_recv_buf *)(hvs->recv_desc + 1); + to_read = min_t(u32, len, hvs->recv_data_len); + ret = memcpy_to_msg(msg, recv_buf->data + hvs->recv_data_off, to_read); + if (ret != 0) + return ret; + + hvs->recv_data_len -= to_read; + if (hvs->recv_data_len == 0) { + hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc); + if (hvs->recv_desc) { + ret = hvs_update_recv_data(hvs); + 
if (ret) + return ret; + } + } else { + hvs->recv_data_off += to_read; + } + + return to_read; +} + +static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg, + size_t len) +{ + struct hvsock *hvs = vsk->trans; + struct vmbus_channel *chan = hvs->chan; + struct hvs_send_buf *send_buf; + ssize_t to_write, max_writable; + ssize_t ret = 0; + ssize_t bytes_written = 0; + + BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE); + + send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL); + if (!send_buf) + return -ENOMEM; + + /* Reader(s) could be draining data from the channel as we write. + * Maximize bandwidth, by iterating until the channel is found to be + * full. + */ + while (len) { + max_writable = hvs_channel_writable_bytes(chan); + if (!max_writable) + break; + to_write = min_t(ssize_t, len, max_writable); + to_write = min_t(ssize_t, to_write, HVS_SEND_BUF_SIZE); + /* memcpy_from_msg is safe for loop as it advances the offsets + * within the message iterator. + */ + ret = memcpy_from_msg(send_buf->data, msg, to_write); + if (ret < 0) + goto out; + + ret = hvs_send_data(hvs->chan, send_buf, to_write); + if (ret < 0) + goto out; + + bytes_written += to_write; + len -= to_write; + } +out: + /* If any data has been sent, return that */ + if (bytes_written) + ret = bytes_written; + kfree(send_buf); + return ret; +} + +static s64 hvs_stream_has_data(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + s64 ret; + + if (hvs->recv_data_len > 0) + return 1; + + switch (hvs_channel_readable_payload(hvs->chan)) { + case 1: + ret = 1; + break; + case 0: + vsk->peer_shutdown |= SEND_SHUTDOWN; + ret = 0; + break; + default: /* -1 */ + ret = 0; + break; + } + + return ret; +} + +static s64 hvs_stream_has_space(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + + return hvs_channel_writable_bytes(hvs->chan); +} + +static u64 hvs_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return HVS_MTU_SIZE + 1; +} + +static bool hvs_stream_is_active(struct vsock_sock *vsk) +{ + struct hvsock *hvs = vsk->trans; + + return hvs->chan != NULL; +} + +static bool hvs_stream_allow(u32 cid, u32 port) +{ + if (cid == VMADDR_CID_HOST) + return true; + + return false; +} + +static +int hvs_notify_poll_in(struct vsock_sock *vsk, size_t target, bool *readable) +{ + struct hvsock *hvs = vsk->trans; + + *readable = hvs_channel_readable(hvs->chan); + return 0; +} + +static +int hvs_notify_poll_out(struct vsock_sock *vsk, size_t target, bool *writable) +{ + *writable = hvs_stream_has_space(vsk) > 0; + + return 0; +} + +static +int hvs_notify_recv_init(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_pre_block(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_pre_dequeue(struct vsock_sock *vsk, size_t target, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_recv_post_dequeue(struct vsock_sock *vsk, size_t target, + ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + 
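+
+/* The notify_*() callbacks above and below are deliberately no-ops: the
+ * AF_VSOCK core uses them for transport-specific flow-control bookkeeping
+ * (the VMCI transport's notify protocol, for instance), but hv_sock relies
+ * purely on the VMBus ring buffer state, so returning 0 is sufficient.
+ */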
+static +int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written, + struct vsock_transport_send_notify_data *d) +{ + return 0; +} + +static +int hvs_set_rcvlowat(struct vsock_sock *vsk, int val) +{ + return -EOPNOTSUPP; +} + +static struct vsock_transport hvs_transport = { + .module = THIS_MODULE, + + .get_local_cid = hvs_get_local_cid, + + .init = hvs_sock_init, + .destruct = hvs_destruct, + .release = hvs_release, + .connect = hvs_connect, + .shutdown = hvs_shutdown, + + .dgram_bind = hvs_dgram_bind, + .dgram_dequeue = hvs_dgram_dequeue, + .dgram_enqueue = hvs_dgram_enqueue, + .dgram_allow = hvs_dgram_allow, + + .stream_dequeue = hvs_stream_dequeue, + .stream_enqueue = hvs_stream_enqueue, + .stream_has_data = hvs_stream_has_data, + .stream_has_space = hvs_stream_has_space, + .stream_rcvhiwat = hvs_stream_rcvhiwat, + .stream_is_active = hvs_stream_is_active, + .stream_allow = hvs_stream_allow, + + .notify_poll_in = hvs_notify_poll_in, + .notify_poll_out = hvs_notify_poll_out, + .notify_recv_init = hvs_notify_recv_init, + .notify_recv_pre_block = hvs_notify_recv_pre_block, + .notify_recv_pre_dequeue = hvs_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = hvs_notify_recv_post_dequeue, + .notify_send_init = hvs_notify_send_init, + .notify_send_pre_block = hvs_notify_send_pre_block, + .notify_send_pre_enqueue = hvs_notify_send_pre_enqueue, + .notify_send_post_enqueue = hvs_notify_send_post_enqueue, + + .set_rcvlowat = hvs_set_rcvlowat +}; + +static bool hvs_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == &hvs_transport; +} + +static int hvs_probe(struct hv_device *hdev, + const struct hv_vmbus_device_id *dev_id) +{ + struct vmbus_channel *chan = hdev->channel; + + hvs_open_connection(chan); + + /* Always return success to suppress the unnecessary error message + * in vmbus_probe(): on error the host will rescind the device in + * 30 seconds and we can do cleanup at that time in + * vmbus_onoffer_rescind(). + */ + return 0; +} + +static int hvs_remove(struct hv_device *hdev) +{ + struct vmbus_channel *chan = hdev->channel; + + vmbus_close(chan); + + return 0; +} + +/* hv_sock connections can not persist across hibernation, and all the hv_sock + * channels are forced to be rescinded before hibernation: see + * vmbus_bus_suspend(). Here the dummy hvs_suspend() and hvs_resume() + * are only needed because hibernation requires that every vmbus device's + * driver should have a .suspend and .resume callback: see vmbus_suspend(). + */ +static int hvs_suspend(struct hv_device *hv_dev) +{ + /* Dummy */ + return 0; +} + +static int hvs_resume(struct hv_device *dev) +{ + /* Dummy */ + return 0; +} + +/* This isn't really used. 
See vmbus_match() and vmbus_probe() */ +static const struct hv_vmbus_device_id id_table[] = { + {}, +}; + +static struct hv_driver hvs_drv = { + .name = "hv_sock", + .hvsock = true, + .id_table = id_table, + .probe = hvs_probe, + .remove = hvs_remove, + .suspend = hvs_suspend, + .resume = hvs_resume, +}; + +static int __init hvs_init(void) +{ + int ret; + + if (vmbus_proto_version < VERSION_WIN10) + return -ENODEV; + + ret = vmbus_driver_register(&hvs_drv); + if (ret != 0) + return ret; + + ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H); + if (ret) { + vmbus_driver_unregister(&hvs_drv); + return ret; + } + + return 0; +} + +static void __exit hvs_exit(void) +{ + vsock_core_unregister(&hvs_transport); + vmbus_driver_unregister(&hvs_drv); +} + +module_init(hvs_init); +module_exit(hvs_exit); + +MODULE_DESCRIPTION("Hyper-V Sockets"); +MODULE_VERSION("1.0.0"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c new file mode 100644 index 000000000..16575ea83 --- /dev/null +++ b/net/vmw_vsock/virtio_transport.c @@ -0,0 +1,821 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * virtio transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He + * Stefan Hajnoczi + * + * Some of the code is take from Gerd Hoffmann 's + * early virtio-vsock proof-of-concept bits. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static struct workqueue_struct *virtio_vsock_workqueue; +static struct virtio_vsock __rcu *the_virtio_vsock; +static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ +static struct virtio_transport virtio_transport; /* forward declaration */ + +struct virtio_vsock { + struct virtio_device *vdev; + struct virtqueue *vqs[VSOCK_VQ_MAX]; + + /* Virtqueue processing is deferred to a workqueue */ + struct work_struct tx_work; + struct work_struct rx_work; + struct work_struct event_work; + + /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX] + * must be accessed with tx_lock held. + */ + struct mutex tx_lock; + bool tx_run; + + struct work_struct send_pkt_work; + struct sk_buff_head send_pkt_queue; + + atomic_t queued_replies; + + /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX] + * must be accessed with rx_lock held. + */ + struct mutex rx_lock; + bool rx_run; + int rx_buf_nr; + int rx_buf_max_nr; + + /* The following fields are protected by event_lock. + * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. 
+ */ + struct mutex event_lock; + bool event_run; + struct virtio_vsock_event event_list[8]; + + u32 guest_cid; + bool seqpacket_allow; +}; + +static u32 virtio_transport_get_local_cid(void) +{ + struct virtio_vsock *vsock; + u32 ret; + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + ret = VMADDR_CID_ANY; + goto out_rcu; + } + + ret = vsock->guest_cid; +out_rcu: + rcu_read_unlock(); + return ret; +} + +static void +virtio_transport_send_pkt_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, send_pkt_work); + struct virtqueue *vq; + bool added = false; + bool restart_rx = false; + + mutex_lock(&vsock->tx_lock); + + if (!vsock->tx_run) + goto out; + + vq = vsock->vqs[VSOCK_VQ_TX]; + + for (;;) { + struct scatterlist hdr, buf, *sgs[2]; + int ret, in_sg = 0, out_sg = 0; + struct sk_buff *skb; + bool reply; + + skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue); + if (!skb) + break; + + virtio_transport_deliver_tap_pkt(skb); + reply = virtio_vsock_skb_reply(skb); + + sg_init_one(&hdr, virtio_vsock_hdr(skb), sizeof(*virtio_vsock_hdr(skb))); + sgs[out_sg++] = &hdr; + if (skb->len > 0) { + sg_init_one(&buf, skb->data, skb->len); + sgs[out_sg++] = &buf; + } + + ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL); + /* Usually this means that there is no more space available in + * the vq + */ + if (ret < 0) { + virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb); + break; + } + + if (reply) { + struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + val = atomic_dec_return(&vsock->queued_replies); + + /* Do we now have resources to resume rx processing? */ + if (val + 1 == virtqueue_get_vring_size(rx_vq)) + restart_rx = true; + } + + added = true; + } + + if (added) + virtqueue_kick(vq); + +out: + mutex_unlock(&vsock->tx_lock); + + if (restart_rx) + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static int +virtio_transport_send_pkt(struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr; + struct virtio_vsock *vsock; + int len = skb->len; + + hdr = virtio_vsock_hdr(skb); + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + kfree_skb(skb); + len = -ENODEV; + goto out_rcu; + } + + if (le64_to_cpu(hdr->dst_cid) == vsock->guest_cid) { + kfree_skb(skb); + len = -ENODEV; + goto out_rcu; + } + + if (virtio_vsock_skb_reply(skb)) + atomic_inc(&vsock->queued_replies); + + virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb); + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + +out_rcu: + rcu_read_unlock(); + return len; +} + +static int +virtio_transport_cancel_pkt(struct vsock_sock *vsk) +{ + struct virtio_vsock *vsock; + int cnt = 0, ret; + + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (!vsock) { + ret = -ENODEV; + goto out_rcu; + } + + cnt = virtio_transport_purge_skbs(vsk, &vsock->send_pkt_queue); + + if (cnt) { + struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; + int new_cnt; + + new_cnt = atomic_sub_return(cnt, &vsock->queued_replies); + if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) && + new_cnt < virtqueue_get_vring_size(rx_vq)) + queue_work(virtio_vsock_workqueue, &vsock->rx_work); + } + + ret = 0; + +out_rcu: + rcu_read_unlock(); + return ret; +} + +static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) +{ + int total_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE + VIRTIO_VSOCK_SKB_HEADROOM; + struct scatterlist pkt, *p; + struct virtqueue *vq; + struct sk_buff *skb; + int ret; + + vq = 
vsock->vqs[VSOCK_VQ_RX]; + + do { + skb = virtio_vsock_alloc_skb(total_len, GFP_KERNEL); + if (!skb) + break; + + memset(skb->head, 0, VIRTIO_VSOCK_SKB_HEADROOM); + sg_init_one(&pkt, virtio_vsock_hdr(skb), total_len); + p = &pkt; + ret = virtqueue_add_sgs(vq, &p, 0, 1, skb, GFP_KERNEL); + if (ret < 0) { + kfree_skb(skb); + break; + } + + vsock->rx_buf_nr++; + } while (vq->num_free); + if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) + vsock->rx_buf_max_nr = vsock->rx_buf_nr; + virtqueue_kick(vq); +} + +static void virtio_transport_tx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, tx_work); + struct virtqueue *vq; + bool added = false; + + vq = vsock->vqs[VSOCK_VQ_TX]; + mutex_lock(&vsock->tx_lock); + + if (!vsock->tx_run) + goto out; + + do { + struct sk_buff *skb; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((skb = virtqueue_get_buf(vq, &len)) != NULL) { + consume_skb(skb); + added = true; + } + } while (!virtqueue_enable_cb(vq)); + +out: + mutex_unlock(&vsock->tx_lock); + + if (added) + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); +} + +/* Is there space left for replies to rx packets? */ +static bool virtio_transport_more_replies(struct virtio_vsock *vsock) +{ + struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ + val = atomic_read(&vsock->queued_replies); + + return val < virtqueue_get_vring_size(vq); +} + +/* event_lock must be held */ +static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + struct scatterlist sg; + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + sg_init_one(&sg, event, sizeof(*event)); + + return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL); +} + +/* event_lock must be held */ +static void virtio_vsock_event_fill(struct virtio_vsock *vsock) +{ + size_t i; + + for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) { + struct virtio_vsock_event *event = &vsock->event_list[i]; + + virtio_vsock_event_fill_one(vsock, event); + } + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); +} + +static void virtio_vsock_reset_sock(struct sock *sk) +{ + /* vmci_transport.c doesn't take sk_lock here either. At least we're + * under vsock_table_lock so the sock cannot disappear while we're + * executing. 
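+ *
+ * sk_error_report() below wakes anyone sleeping in poll(), recv() or
+ * send() on this socket, so blocked callers observe ECONNRESET promptly
+ * once the transport goes away.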
+ */ + + sk->sk_state = TCP_CLOSE; + sk->sk_err = ECONNRESET; + sk_error_report(sk); +} + +static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + __le64 guest_cid; + + vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), + &guest_cid, sizeof(guest_cid)); + vsock->guest_cid = le64_to_cpu(guest_cid); +} + +/* event_lock must be held */ +static void virtio_vsock_event_handle(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + switch (le32_to_cpu(event->id)) { + case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: + virtio_vsock_update_guest_cid(vsock); + vsock_for_each_connected_socket(&virtio_transport.transport, + virtio_vsock_reset_sock); + break; + } +} + +static void virtio_transport_event_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, event_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + mutex_lock(&vsock->event_lock); + + if (!vsock->event_run) + goto out; + + do { + struct virtio_vsock_event *event; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((event = virtqueue_get_buf(vq, &len)) != NULL) { + if (len == sizeof(*event)) + virtio_vsock_event_handle(vsock, event); + + virtio_vsock_event_fill_one(vsock, event); + } + } while (!virtqueue_enable_cb(vq)); + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); +out: + mutex_unlock(&vsock->event_lock); +} + +static void virtio_vsock_event_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->event_work); +} + +static void virtio_vsock_tx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->tx_work); +} + +static void virtio_vsock_rx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static bool virtio_transport_seqpacket_allow(u32 remote_cid); + +static struct virtio_transport virtio_transport = { + .transport = { + .module = THIS_MODULE, + + .get_local_cid = virtio_transport_get_local_cid, + + .init = virtio_transport_do_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + .cancel_pkt = virtio_transport_cancel_pkt, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, + .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, + .seqpacket_allow = virtio_transport_seqpacket_allow, + .seqpacket_has_data = virtio_transport_seqpacket_has_data, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + 
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + .notify_buffer_size = virtio_transport_notify_buffer_size, + }, + + .send_pkt = virtio_transport_send_pkt, +}; + +static bool virtio_transport_seqpacket_allow(u32 remote_cid) +{ + struct virtio_vsock *vsock; + bool seqpacket_allow; + + seqpacket_allow = false; + rcu_read_lock(); + vsock = rcu_dereference(the_virtio_vsock); + if (vsock) + seqpacket_allow = vsock->seqpacket_allow; + rcu_read_unlock(); + + return seqpacket_allow; +} + +static void virtio_transport_rx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, rx_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + mutex_lock(&vsock->rx_lock); + + if (!vsock->rx_run) + goto out; + + do { + virtqueue_disable_cb(vq); + for (;;) { + struct sk_buff *skb; + unsigned int len; + + if (!virtio_transport_more_replies(vsock)) { + /* Stop rx until the device processes already + * pending replies. Leave rx virtqueue + * callbacks disabled. + */ + goto out; + } + + skb = virtqueue_get_buf(vq, &len); + if (!skb) + break; + + vsock->rx_buf_nr--; + + /* Drop short/long packets */ + if (unlikely(len < sizeof(struct virtio_vsock_hdr) || + len > virtio_vsock_skb_len(skb))) { + kfree_skb(skb); + continue; + } + + virtio_vsock_skb_rx_put(skb); + virtio_transport_deliver_tap_pkt(skb); + virtio_transport_recv_pkt(&virtio_transport, skb); + } + } while (!virtqueue_enable_cb(vq)); + +out: + if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); +} + +static int virtio_vsock_vqs_init(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + static const char * const names[] = { + "rx", + "tx", + "event", + }; + vq_callback_t *callbacks[] = { + virtio_vsock_rx_done, + virtio_vsock_tx_done, + virtio_vsock_event_done, + }; + int ret; + + ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names, + NULL); + if (ret < 0) + return ret; + + virtio_vsock_update_guest_cid(vsock); + + virtio_device_ready(vdev); + + return 0; +} + +static void virtio_vsock_vqs_start(struct virtio_vsock *vsock) +{ + mutex_lock(&vsock->tx_lock); + vsock->tx_run = true; + mutex_unlock(&vsock->tx_lock); + + mutex_lock(&vsock->rx_lock); + virtio_vsock_rx_fill(vsock); + vsock->rx_run = true; + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->event_lock); + virtio_vsock_event_fill(vsock); + vsock->event_run = true; + mutex_unlock(&vsock->event_lock); + + /* virtio_transport_send_pkt() can queue packets once + * the_virtio_vsock is set, but they won't be processed until + * vsock->tx_run is set to true. We queue vsock->send_pkt_work + * when initialization finishes to send those packets queued + * earlier. + * We don't need to queue the other workers (rx, event) because + * as long as we don't fill the queues with empty buffers, the + * host can't send us any notification. 
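+ *
+ * Filling a virtqueue and setting its *_run flag happen under the same
+ * per-queue mutex above, so a callback raised right after the fill simply
+ * waits on that mutex and then finds its worker enabled.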
+ */ + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); +} + +static void virtio_vsock_vqs_del(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + struct sk_buff *skb; + + /* Reset all connected sockets when the VQs disappear */ + vsock_for_each_connected_socket(&virtio_transport.transport, + virtio_vsock_reset_sock); + + /* Stop all work handlers to make sure no one is accessing the device, + * so we can safely call virtio_reset_device(). + */ + mutex_lock(&vsock->rx_lock); + vsock->rx_run = false; + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->tx_lock); + vsock->tx_run = false; + mutex_unlock(&vsock->tx_lock); + + mutex_lock(&vsock->event_lock); + vsock->event_run = false; + mutex_unlock(&vsock->event_lock); + + /* Flush all device writes and interrupts, device will not use any + * more buffers. + */ + virtio_reset_device(vdev); + + mutex_lock(&vsock->rx_lock); + while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX]))) + kfree_skb(skb); + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->tx_lock); + while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX]))) + kfree_skb(skb); + mutex_unlock(&vsock->tx_lock); + + virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue); + + /* Delete virtqueues and flush outstanding callbacks if any */ + vdev->config->del_vqs(vdev); +} + +static int virtio_vsock_probe(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = NULL; + int ret; + + ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); + if (ret) + return ret; + + /* Only one virtio-vsock device per guest is supported */ + if (rcu_dereference_protected(the_virtio_vsock, + lockdep_is_held(&the_virtio_vsock_mutex))) { + ret = -EBUSY; + goto out; + } + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) { + ret = -ENOMEM; + goto out; + } + + vsock->vdev = vdev; + + vsock->rx_buf_nr = 0; + vsock->rx_buf_max_nr = 0; + atomic_set(&vsock->queued_replies, 0); + + mutex_init(&vsock->tx_lock); + mutex_init(&vsock->rx_lock); + mutex_init(&vsock->event_lock); + skb_queue_head_init(&vsock->send_pkt_queue); + INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); + INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); + INIT_WORK(&vsock->event_work, virtio_transport_event_work); + INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); + + if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET)) + vsock->seqpacket_allow = true; + + vdev->priv = vsock; + + ret = virtio_vsock_vqs_init(vsock); + if (ret < 0) + goto out; + + rcu_assign_pointer(the_virtio_vsock, vsock); + virtio_vsock_vqs_start(vsock); + + mutex_unlock(&the_virtio_vsock_mutex); + + return 0; + +out: + kfree(vsock); + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} + +static void virtio_vsock_remove(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + + mutex_lock(&the_virtio_vsock_mutex); + + vdev->priv = NULL; + rcu_assign_pointer(the_virtio_vsock, NULL); + synchronize_rcu(); + + virtio_vsock_vqs_del(vsock); + + /* Other works can be queued before 'config->del_vqs()', so we flush + * all works before to free the vsock object to avoid use after free. 
+ */ + flush_work(&vsock->rx_work); + flush_work(&vsock->tx_work); + flush_work(&vsock->event_work); + flush_work(&vsock->send_pkt_work); + + mutex_unlock(&the_virtio_vsock_mutex); + + kfree(vsock); +} + +#ifdef CONFIG_PM_SLEEP +static int virtio_vsock_freeze(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + + mutex_lock(&the_virtio_vsock_mutex); + + rcu_assign_pointer(the_virtio_vsock, NULL); + synchronize_rcu(); + + virtio_vsock_vqs_del(vsock); + + mutex_unlock(&the_virtio_vsock_mutex); + + return 0; +} + +static int virtio_vsock_restore(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + int ret; + + mutex_lock(&the_virtio_vsock_mutex); + + /* Only one virtio-vsock device per guest is supported */ + if (rcu_dereference_protected(the_virtio_vsock, + lockdep_is_held(&the_virtio_vsock_mutex))) { + ret = -EBUSY; + goto out; + } + + ret = virtio_vsock_vqs_init(vsock); + if (ret < 0) + goto out; + + rcu_assign_pointer(the_virtio_vsock, vsock); + virtio_vsock_vqs_start(vsock); + +out: + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} +#endif /* CONFIG_PM_SLEEP */ + +static struct virtio_device_id id_table[] = { + { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, + { 0 }, +}; + +static unsigned int features[] = { + VIRTIO_VSOCK_F_SEQPACKET +}; + +static struct virtio_driver virtio_vsock_driver = { + .feature_table = features, + .feature_table_size = ARRAY_SIZE(features), + .driver.name = KBUILD_MODNAME, + .driver.owner = THIS_MODULE, + .id_table = id_table, + .probe = virtio_vsock_probe, + .remove = virtio_vsock_remove, +#ifdef CONFIG_PM_SLEEP + .freeze = virtio_vsock_freeze, + .restore = virtio_vsock_restore, +#endif +}; + +static int __init virtio_vsock_init(void) +{ + int ret; + + virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); + if (!virtio_vsock_workqueue) + return -ENOMEM; + + ret = vsock_core_register(&virtio_transport.transport, + VSOCK_TRANSPORT_F_G2H); + if (ret) + goto out_wq; + + ret = register_virtio_driver(&virtio_vsock_driver); + if (ret) + goto out_vci; + + return 0; + +out_vci: + vsock_core_unregister(&virtio_transport.transport); +out_wq: + destroy_workqueue(virtio_vsock_workqueue); + return ret; +} + +static void __exit virtio_vsock_exit(void) +{ + unregister_virtio_driver(&virtio_vsock_driver); + vsock_core_unregister(&virtio_transport.transport); + destroy_workqueue(virtio_vsock_workqueue); +} + +module_init(virtio_vsock_init); +module_exit(virtio_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("virtio transport for vsock"); +MODULE_DEVICE_TABLE(virtio, id_table); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c new file mode 100644 index 000000000..9983b833b --- /dev/null +++ b/net/vmw_vsock/virtio_transport_common.c @@ -0,0 +1,1415 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * common code for virtio vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. 
+ * Author: Asias He + * Stefan Hajnoczi + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define CREATE_TRACE_POINTS +#include + +/* How long to wait for graceful shutdown of a connection */ +#define VSOCK_CLOSE_TIMEOUT (8 * HZ) + +/* Threshold for detecting small packets to copy */ +#define GOOD_COPY_LEN 128 + +static const struct virtio_transport * +virtio_transport_get_ops(struct vsock_sock *vsk) +{ + const struct vsock_transport *t = vsock_core_get_transport(vsk); + + if (WARN_ON(!t)) + return NULL; + + return container_of(t, struct virtio_transport, transport); +} + +/* Returns a new packet on success, otherwise returns NULL. + * + * If NULL is returned, errp is set to a negative errno. + */ +static struct sk_buff * +virtio_transport_alloc_skb(struct virtio_vsock_pkt_info *info, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port) +{ + const size_t skb_len = VIRTIO_VSOCK_SKB_HEADROOM + len; + struct virtio_vsock_hdr *hdr; + struct sk_buff *skb; + void *payload; + int err; + + skb = virtio_vsock_alloc_skb(skb_len, GFP_KERNEL); + if (!skb) + return NULL; + + hdr = virtio_vsock_hdr(skb); + hdr->type = cpu_to_le16(info->type); + hdr->op = cpu_to_le16(info->op); + hdr->src_cid = cpu_to_le64(src_cid); + hdr->dst_cid = cpu_to_le64(dst_cid); + hdr->src_port = cpu_to_le32(src_port); + hdr->dst_port = cpu_to_le32(dst_port); + hdr->flags = cpu_to_le32(info->flags); + hdr->len = cpu_to_le32(len); + hdr->buf_alloc = cpu_to_le32(0); + hdr->fwd_cnt = cpu_to_le32(0); + + if (info->msg && len > 0) { + payload = skb_put(skb, len); + err = memcpy_from_msg(payload, info->msg, len); + if (err) + goto out; + + if (msg_data_left(info->msg) == 0 && + info->type == VIRTIO_VSOCK_TYPE_SEQPACKET) { + hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOM); + + if (info->msg->msg_flags & MSG_EOR) + hdr->flags |= cpu_to_le32(VIRTIO_VSOCK_SEQ_EOR); + } + } + + if (info->reply) + virtio_vsock_skb_set_reply(skb); + + trace_virtio_transport_alloc_pkt(src_cid, src_port, + dst_cid, dst_port, + len, + info->type, + info->op, + info->flags); + + if (info->vsk && !skb_set_owner_sk_safe(skb, sk_vsock(info->vsk))) { + WARN_ONCE(1, "failed to allocate skb on vsock socket with sk_refcnt == 0\n"); + goto out; + } + + return skb; + +out: + kfree_skb(skb); + return NULL; +} + +/* Packet capture */ +static struct sk_buff *virtio_transport_build_skb(void *opaque) +{ + struct virtio_vsock_hdr *pkt_hdr; + struct sk_buff *pkt = opaque; + struct af_vsockmon_hdr *hdr; + struct sk_buff *skb; + size_t payload_len; + void *payload_buf; + + /* A packet could be split to fit the RX buffer, so we can retrieve + * the payload length from the header and the buffer pointer taking + * care of the offset in the original packet. 
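+ *
+ * The skb built here is laid out as af_vsockmon_hdr, then the raw
+ * virtio_vsock_hdr, then the payload, and is handed to vsock_deliver_tap()
+ * for transmission on any registered vsockmon device.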
+ */ + pkt_hdr = virtio_vsock_hdr(pkt); + payload_len = pkt->len; + payload_buf = pkt->data; + + skb = alloc_skb(sizeof(*hdr) + sizeof(*pkt_hdr) + payload_len, + GFP_ATOMIC); + if (!skb) + return NULL; + + hdr = skb_put(skb, sizeof(*hdr)); + + /* pkt->hdr is little-endian so no need to byteswap here */ + hdr->src_cid = pkt_hdr->src_cid; + hdr->src_port = pkt_hdr->src_port; + hdr->dst_cid = pkt_hdr->dst_cid; + hdr->dst_port = pkt_hdr->dst_port; + + hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO); + hdr->len = cpu_to_le16(sizeof(*pkt_hdr)); + memset(hdr->reserved, 0, sizeof(hdr->reserved)); + + switch (le16_to_cpu(pkt_hdr->op)) { + case VIRTIO_VSOCK_OP_REQUEST: + case VIRTIO_VSOCK_OP_RESPONSE: + hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT); + break; + case VIRTIO_VSOCK_OP_RST: + case VIRTIO_VSOCK_OP_SHUTDOWN: + hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT); + break; + case VIRTIO_VSOCK_OP_RW: + hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD); + break; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); + break; + default: + hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN); + break; + } + + skb_put_data(skb, pkt_hdr, sizeof(*pkt_hdr)); + + if (payload_len) { + skb_put_data(skb, payload_buf, payload_len); + } + + return skb; +} + +void virtio_transport_deliver_tap_pkt(struct sk_buff *skb) +{ + if (virtio_vsock_skb_tap_delivered(skb)) + return; + + vsock_deliver_tap(virtio_transport_build_skb, skb); + virtio_vsock_skb_set_tap_delivered(skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt); + +static u16 virtio_transport_get_type(struct sock *sk) +{ + if (sk->sk_type == SOCK_STREAM) + return VIRTIO_VSOCK_TYPE_STREAM; + else + return VIRTIO_VSOCK_TYPE_SEQPACKET; +} + +/* This function can only be used on connecting/connected sockets, + * since a socket assigned to a transport is required. + * + * Do not use on listener sockets! 
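+ * Listener sockets have no transport assigned, so the transport ops and
+ * vsk->trans dereferenced below would not be valid for them.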
+ */ +static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + const struct virtio_transport *t_ops; + struct virtio_vsock_sock *vvs; + u32 pkt_len = info->pkt_len; + struct sk_buff *skb; + + info->type = virtio_transport_get_type(sk_vsock(vsk)); + + t_ops = virtio_transport_get_ops(vsk); + if (unlikely(!t_ops)) + return -EFAULT; + + src_cid = t_ops->transport.get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + vvs = vsk->trans; + + /* we can send less than pkt_len bytes */ + if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + pkt_len = virtio_transport_get_credit(vvs, pkt_len); + + /* Do not send zero length OP_RW pkt */ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + skb = virtio_transport_alloc_skb(info, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!skb) { + virtio_transport_put_credit(vvs, pkt_len); + return -ENOMEM; + } + + virtio_transport_inc_tx_pkt(vvs, skb); + + return t_ops->send_pkt(skb); +} + +static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, + u32 len) +{ + if (vvs->rx_bytes + len > vvs->buf_alloc) + return false; + + vvs->rx_bytes += len; + return true; +} + +static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, + u32 len) +{ + vvs->rx_bytes -= len; + vvs->fwd_cnt += len; +} + +void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + + spin_lock_bh(&vvs->rx_lock); + vvs->last_fwd_cnt = vvs->fwd_cnt; + hdr->fwd_cnt = cpu_to_le32(vvs->fwd_cnt); + hdr->buf_alloc = cpu_to_le32(vvs->buf_alloc); + spin_unlock_bh(&vvs->rx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); + +u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + u32 ret; + + spin_lock_bh(&vvs->tx_lock); + ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (ret > credit) + ret = credit; + vvs->tx_cnt += ret; + spin_unlock_bh(&vvs->tx_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_credit); + +void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + spin_lock_bh(&vvs->tx_lock); + vvs->tx_cnt -= credit; + spin_unlock_bh(&vvs->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_put_credit); + +static int virtio_transport_send_credit_update(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +static ssize_t +virtio_transport_stream_do_peek(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + size_t bytes, total = 0, off; + struct sk_buff *skb, *tmp; + int err = -EFAULT; + + spin_lock_bh(&vvs->rx_lock); + + skb_queue_walk_safe(&vvs->rx_queue, skb, tmp) { + off = 0; + + if (total == len) + break; + + while (total < len && off < skb->len) { + bytes = len - total; + if (bytes > skb->len - off) + bytes = skb->len - off; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. 
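+ * The skb stays on rx_queue and rx_bytes/fwd_cnt are left untouched, so
+ * a MSG_PEEK read never returns credit to the peer.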
+ */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, skb->data + off, bytes); + if (err) + goto out; + + spin_lock_bh(&vvs->rx_lock); + + total += bytes; + off += bytes; + } + } + + spin_unlock_bh(&vvs->rx_lock); + + return total; + +out: + if (total) + err = total; + return err; +} + +static ssize_t +virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + size_t bytes, total = 0; + struct sk_buff *skb; + u32 fwd_cnt_delta; + bool low_rx_bytes; + int err = -EFAULT; + u32 free_space; + + spin_lock_bh(&vvs->rx_lock); + while (total < len && !skb_queue_empty(&vvs->rx_queue)) { + skb = skb_peek(&vvs->rx_queue); + + bytes = len - total; + if (bytes > skb->len) + bytes = skb->len; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, skb->data, bytes); + if (err) + goto out; + + spin_lock_bh(&vvs->rx_lock); + + total += bytes; + skb_pull(skb, bytes); + + if (skb->len == 0) { + u32 pkt_len = le32_to_cpu(virtio_vsock_hdr(skb)->len); + + virtio_transport_dec_rx_pkt(vvs, pkt_len); + __skb_unlink(skb, &vvs->rx_queue); + consume_skb(skb); + } + } + + fwd_cnt_delta = vvs->fwd_cnt - vvs->last_fwd_cnt; + free_space = vvs->buf_alloc - fwd_cnt_delta; + low_rx_bytes = (vvs->rx_bytes < + sock_rcvlowat(sk_vsock(vsk), 0, INT_MAX)); + + spin_unlock_bh(&vvs->rx_lock); + + /* To reduce the number of credit update messages, + * don't update credits as long as lots of space is available. + * Note: the limit chosen here is arbitrary. Setting the limit + * too high causes extra messages. Too low causes transmitter + * stalls. As stalls are in theory more expensive than extra + * messages, we set the limit to a high value. TODO: experiment + * with different values. Also send credit update message when + * number of bytes in rx queue is not enough to wake up reader. + */ + if (fwd_cnt_delta && + (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE || low_rx_bytes)) + virtio_transport_send_credit_update(vsk); + + return total; + +out: + if (total) + err = total; + return err; +} + +static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + int flags) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + int dequeued_len = 0; + size_t user_buf_len = msg_data_left(msg); + bool msg_ready = false; + struct sk_buff *skb; + + spin_lock_bh(&vvs->rx_lock); + + if (vvs->msg_count == 0) { + spin_unlock_bh(&vvs->rx_lock); + return 0; + } + + while (!msg_ready) { + struct virtio_vsock_hdr *hdr; + size_t pkt_len; + + skb = __skb_dequeue(&vvs->rx_queue); + if (!skb) + break; + hdr = virtio_vsock_hdr(skb); + pkt_len = (size_t)le32_to_cpu(hdr->len); + + if (dequeued_len >= 0) { + size_t bytes_to_copy; + + bytes_to_copy = min(user_buf_len, pkt_len); + + if (bytes_to_copy) { + int err; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, skb->data, bytes_to_copy); + if (err) { + /* Copy of message failed. Rest of + * fragments will be freed without copy. 
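+ * dequeued_len then carries the error; the loop still dequeues and
+ * frees the remaining fragments up to EOM, so the whole message is
+ * consumed and its credit returned after the loop.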
+ */ + dequeued_len = err; + } else { + user_buf_len -= bytes_to_copy; + } + + spin_lock_bh(&vvs->rx_lock); + } + + if (dequeued_len >= 0) + dequeued_len += pkt_len; + } + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) { + msg_ready = true; + vvs->msg_count--; + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOR) + msg->msg_flags |= MSG_EOR; + } + + virtio_transport_dec_rx_pkt(vvs, pkt_len); + kfree_skb(skb); + } + + spin_unlock_bh(&vvs->rx_lock); + + virtio_transport_send_credit_update(vsk); + + return dequeued_len; +} + +ssize_t +virtio_transport_stream_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + if (flags & MSG_PEEK) + return virtio_transport_stream_do_peek(vsk, msg, len); + else + return virtio_transport_stream_do_dequeue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); + +ssize_t +virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + int flags) +{ + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + + return virtio_transport_seqpacket_do_dequeue(vsk, msg, flags); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue); + +int +virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + spin_lock_bh(&vvs->tx_lock); + + if (len > vvs->peer_buf_alloc) { + spin_unlock_bh(&vvs->tx_lock); + return -EMSGSIZE; + } + + spin_unlock_bh(&vvs->tx_lock); + + return virtio_transport_stream_enqueue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_enqueue); + +int +virtio_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); + +s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->rx_lock); + bytes = vvs->rx_bytes; + spin_unlock_bh(&vvs->rx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); + +u32 virtio_transport_seqpacket_has_data(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + u32 msg_count; + + spin_lock_bh(&vvs->rx_lock); + msg_count = vvs->msg_count; + spin_unlock_bh(&vvs->rx_lock); + + return msg_count; +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_has_data); + +static s64 virtio_transport_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + bytes = (s64)vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (bytes < 0) + bytes = 0; + + return bytes; +} + +s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->tx_lock); + bytes = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); + +int virtio_transport_do_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + struct virtio_vsock_sock *vvs; + + vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); + if (!vvs) + return -ENOMEM; + + vsk->trans = vvs; + vvs->vsk = vsk; + if (psk && psk->trans) { + struct virtio_vsock_sock *ptrans = psk->trans; + + vvs->peer_buf_alloc = ptrans->peer_buf_alloc; + } + + if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE) + vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE; + + vvs->buf_alloc = vsk->buffer_size; + + spin_lock_init(&vvs->rx_lock); + spin_lock_init(&vvs->tx_lock); + skb_queue_head_init(&vvs->rx_queue); + + 
return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); + +/* sk_lock held by the caller */ +void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE) + *val = VIRTIO_VSOCK_MAX_BUF_SIZE; + + vvs->buf_alloc = *val; + + virtio_transport_send_credit_update(vsk); +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size); + +int +virtio_transport_notify_poll_in(struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + *data_ready_now = vsock_stream_has_data(vsk) >= target; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); + +int +virtio_transport_notify_poll_out(struct vsock_sock *vsk, + size_t target, + bool *space_avail_now) +{ + s64 free_space; + + free_space = vsock_stream_has_space(vsk); + if (free_space > 0) + *space_avail_now = true; + else if (free_space == 0) + *space_avail_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); + +int virtio_transport_notify_recv_init(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); + +int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); + +int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); + +int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, + size_t target, ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); + +int virtio_transport_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); + +int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); + +int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); + +int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, + ssize_t written, struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); + +u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return vsk->buffer_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); + +bool virtio_transport_stream_is_active(struct vsock_sock *vsk) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); + +bool virtio_transport_stream_allow(u32 cid, u32 port) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); + +int virtio_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); + +bool virtio_transport_dgram_allow(u32 cid, u32 port) +{ + return false; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); + +int virtio_transport_connect(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_info info = { + .op = 
VIRTIO_VSOCK_OP_REQUEST, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_connect); + +int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_SHUTDOWN, + .flags = (mode & RCV_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | + (mode & SEND_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_SEND : 0), + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_shutdown); + +int +virtio_transport_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t dgram_len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); + +ssize_t +virtio_transport_stream_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RW, + .msg = msg, + .pkt_len = len, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); + +void virtio_transport_destruct(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + kfree(vvs); +} +EXPORT_SYMBOL_GPL(virtio_transport_destruct); + +static int virtio_transport_reset(struct vsock_sock *vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .reply = !!skb, + .vsk = vsk, + }; + + /* Send RST only if the original pkt is not a RST pkt */ + if (skb && le16_to_cpu(virtio_vsock_hdr(skb)->op) == VIRTIO_VSOCK_OP_RST) + return 0; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +/* Normally packets are associated with a socket. There may be no socket if an + * attempt was made to connect to a socket that does not exist. + */ +static int virtio_transport_reset_no_sock(const struct virtio_transport *t, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .type = le16_to_cpu(hdr->type), + .reply = true, + }; + struct sk_buff *reply; + + /* Send RST only if the original pkt is not a RST pkt */ + if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) + return 0; + + reply = virtio_transport_alloc_skb(&info, 0, + le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port), + le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + if (!reply) + return -ENOMEM; + + if (!t) { + kfree_skb(reply); + return -ENOTCONN; + } + + return t->send_pkt(reply); +} + +/* This function should be called with sk_lock held and SOCK_DONE set */ +static void virtio_transport_remove_sock(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + /* We don't need to take rx_lock, as the socket is closing and we are + * removing it. 
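+ * Any skbs still sitting on rx_queue are simply freed and the socket is
+ * then dropped from the vsock lookup tables.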
+ */ + __skb_queue_purge(&vvs->rx_queue); + vsock_remove_sock(vsk); +} + +static void virtio_transport_wait_close(struct sock *sk, long timeout) +{ + if (timeout) { + DEFINE_WAIT_FUNC(wait, woken_wake_function); + + add_wait_queue(sk_sleep(sk), &wait); + + do { + if (sk_wait_event(sk, &timeout, + sock_flag(sk, SOCK_DONE), &wait)) + break; + } while (!signal_pending(current) && timeout); + + remove_wait_queue(sk_sleep(sk), &wait); + } +} + +static void virtio_transport_do_close(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + sk->sk_state_change(sk); + + if (vsk->close_work_scheduled && + (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { + vsk->close_work_scheduled = false; + + virtio_transport_remove_sock(vsk); + + /* Release refcnt obtained when we scheduled the timeout */ + sock_put(sk); + } +} + +static void virtio_transport_close_timeout(struct work_struct *work) +{ + struct vsock_sock *vsk = + container_of(work, struct vsock_sock, close_work.work); + struct sock *sk = sk_vsock(vsk); + + sock_hold(sk); + lock_sock(sk); + + if (!sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset(vsk, NULL); + + virtio_transport_do_close(vsk, false); + } + + vsk->close_work_scheduled = false; + + release_sock(sk); + sock_put(sk); +} + +/* User context, vsk->sk is locked */ +static bool virtio_transport_close(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + + if (!(sk->sk_state == TCP_ESTABLISHED || + sk->sk_state == TCP_CLOSING)) + return true; + + /* Already received SHUTDOWN from peer, reply with RST */ + if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { + (void)virtio_transport_reset(vsk, NULL); + return true; + } + + if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) + (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); + + if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) + virtio_transport_wait_close(sk, sk->sk_lingertime); + + if (sock_flag(sk, SOCK_DONE)) { + return true; + } + + sock_hold(sk); + INIT_DELAYED_WORK(&vsk->close_work, + virtio_transport_close_timeout); + vsk->close_work_scheduled = true; + schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); + return false; +} + +void virtio_transport_release(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + bool remove_sock = true; + + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) + remove_sock = virtio_transport_close(vsk); + + if (remove_sock) { + sock_set_flag(sk, SOCK_DONE); + virtio_transport_remove_sock(vsk); + } +} +EXPORT_SYMBOL_GPL(virtio_transport_release); + +static int +virtio_transport_recv_connecting(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + int skerr; + int err; + + switch (le16_to_cpu(hdr->op)) { + case VIRTIO_VSOCK_OP_RESPONSE: + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_INVALID: + break; + case VIRTIO_VSOCK_OP_RST: + skerr = ECONNRESET; + err = 0; + goto destroy; + default: + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + return 0; + +destroy: + virtio_transport_reset(vsk, skb); + sk->sk_state = TCP_CLOSE; + sk->sk_err = skerr; + sk_error_report(sk); + return err; +} + +static void +virtio_transport_recv_enqueue(struct vsock_sock 
*vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + bool can_enqueue, free_pkt = false; + struct virtio_vsock_hdr *hdr; + u32 len; + + hdr = virtio_vsock_hdr(skb); + len = le32_to_cpu(hdr->len); + + spin_lock_bh(&vvs->rx_lock); + + can_enqueue = virtio_transport_inc_rx_pkt(vvs, len); + if (!can_enqueue) { + free_pkt = true; + goto out; + } + + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SEQ_EOM) + vvs->msg_count++; + + /* Try to copy small packets into the buffer of last packet queued, + * to avoid wasting memory queueing the entire buffer with a small + * payload. + */ + if (len <= GOOD_COPY_LEN && !skb_queue_empty(&vvs->rx_queue)) { + struct virtio_vsock_hdr *last_hdr; + struct sk_buff *last_skb; + + last_skb = skb_peek_tail(&vvs->rx_queue); + last_hdr = virtio_vsock_hdr(last_skb); + + /* If there is space in the last packet queued, we copy the + * new packet in its buffer. We avoid this if the last packet + * queued has VIRTIO_VSOCK_SEQ_EOM set, because this is + * delimiter of SEQPACKET message, so 'pkt' is the first packet + * of a new message. + */ + if (skb->len < skb_tailroom(last_skb) && + !(le32_to_cpu(last_hdr->flags) & VIRTIO_VSOCK_SEQ_EOM)) { + memcpy(skb_put(last_skb, skb->len), skb->data, skb->len); + free_pkt = true; + last_hdr->flags |= hdr->flags; + le32_add_cpu(&last_hdr->len, len); + goto out; + } + } + + __skb_queue_tail(&vvs->rx_queue, skb); + +out: + spin_unlock_bh(&vvs->rx_lock); + if (free_pkt) + kfree_skb(skb); +} + +static int +virtio_transport_recv_connected(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + int err = 0; + + switch (le16_to_cpu(hdr->op)) { + case VIRTIO_VSOCK_OP_RW: + virtio_transport_recv_enqueue(vsk, skb); + vsock_data_ready(sk); + return err; + case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + virtio_transport_send_credit_update(vsk); + break; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + sk->sk_write_space(sk); + break; + case VIRTIO_VSOCK_OP_SHUTDOWN: + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) + vsk->peer_shutdown |= RCV_SHUTDOWN; + if (le32_to_cpu(hdr->flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) + vsk->peer_shutdown |= SEND_SHUTDOWN; + if (vsk->peer_shutdown == SHUTDOWN_MASK) { + if (vsock_stream_has_data(vsk) <= 0 && !sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset(vsk, NULL); + virtio_transport_do_close(vsk, true); + } + /* Remove this socket anyway because the remote peer sent + * the shutdown. This way a new connection will succeed + * if the remote peer uses the same source port, + * even if the old socket is still unreleased, but now disconnected. 
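+ * Only the lookup-table entries are removed; the socket itself stays
+ * allocated until userspace releases it.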
+ */ + vsock_remove_sock(vsk); + } + if (le32_to_cpu(virtio_vsock_hdr(skb)->flags)) + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_RST: + virtio_transport_do_close(vsk, true); + break; + default: + err = -EINVAL; + break; + } + + kfree_skb(skb); + return err; +} + +static void +virtio_transport_recv_disconnecting(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + + if (le16_to_cpu(hdr->op) == VIRTIO_VSOCK_OP_RST) + virtio_transport_do_close(vsk, true); +} + +static int +virtio_transport_send_response(struct vsock_sock *vsk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RESPONSE, + .remote_cid = le64_to_cpu(hdr->src_cid), + .remote_port = le32_to_cpu(hdr->src_port), + .reply = true, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +static bool virtio_transport_space_update(struct sock *sk, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_vsock_sock *vvs = vsk->trans; + bool space_available; + + /* Listener sockets are not associated with any transport, so we are + * not able to take the state to see if there is space available in the + * remote peer, but since they are only used to receive requests, we + * can assume that there is always space available in the other peer. + */ + if (!vvs) + return true; + + /* buf_alloc and fwd_cnt is always included in the hdr */ + spin_lock_bh(&vvs->tx_lock); + vvs->peer_buf_alloc = le32_to_cpu(hdr->buf_alloc); + vvs->peer_fwd_cnt = le32_to_cpu(hdr->fwd_cnt); + space_available = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + return space_available; +} + +/* Handle server socket */ +static int +virtio_transport_recv_listen(struct sock *sk, struct sk_buff *skb, + struct virtio_transport *t) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_sock *vchild; + struct sock *child; + int ret; + + if (le16_to_cpu(hdr->op) != VIRTIO_VSOCK_OP_REQUEST) { + virtio_transport_reset_no_sock(t, skb); + return -EINVAL; + } + + if (sk_acceptq_is_full(sk)) { + virtio_transport_reset_no_sock(t, skb); + return -ENOMEM; + } + + child = vsock_create_connected(sk); + if (!child) { + virtio_transport_reset_no_sock(t, skb); + return -ENOMEM; + } + + sk_acceptq_added(sk); + + lock_sock_nested(child, SINGLE_DEPTH_NESTING); + + child->sk_state = TCP_ESTABLISHED; + + vchild = vsock_sk(child); + vsock_addr_init(&vchild->local_addr, le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port)); + vsock_addr_init(&vchild->remote_addr, le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + + ret = vsock_assign_transport(vchild, vsk); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. 
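+ * If it is not (or the assignment failed), the request is answered with
+ * a reset and the half-created child socket is released.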
+ */ + if (ret || vchild->transport != &t->transport) { + release_sock(child); + virtio_transport_reset_no_sock(t, skb); + sock_put(child); + return ret; + } + + if (virtio_transport_space_update(child, skb)) + child->sk_write_space(child); + + vsock_insert_connected(vchild); + vsock_enqueue_accept(sk, child); + virtio_transport_send_response(vchild, skb); + + release_sock(child); + + sk->sk_data_ready(sk); + return 0; +} + +static bool virtio_transport_valid_type(u16 type) +{ + return (type == VIRTIO_VSOCK_TYPE_STREAM) || + (type == VIRTIO_VSOCK_TYPE_SEQPACKET); +} + +/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex + * lock. + */ +void virtio_transport_recv_pkt(struct virtio_transport *t, + struct sk_buff *skb) +{ + struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb); + struct sockaddr_vm src, dst; + struct vsock_sock *vsk; + struct sock *sk; + bool space_available; + + vsock_addr_init(&src, le64_to_cpu(hdr->src_cid), + le32_to_cpu(hdr->src_port)); + vsock_addr_init(&dst, le64_to_cpu(hdr->dst_cid), + le32_to_cpu(hdr->dst_port)); + + trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, + dst.svm_cid, dst.svm_port, + le32_to_cpu(hdr->len), + le16_to_cpu(hdr->type), + le16_to_cpu(hdr->op), + le32_to_cpu(hdr->flags), + le32_to_cpu(hdr->buf_alloc), + le32_to_cpu(hdr->fwd_cnt)); + + if (!virtio_transport_valid_type(le16_to_cpu(hdr->type))) { + (void)virtio_transport_reset_no_sock(t, skb); + goto free_pkt; + } + + /* The socket must be in connected or bound table + * otherwise send reset back + */ + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + (void)virtio_transport_reset_no_sock(t, skb); + goto free_pkt; + } + } + + if (virtio_transport_get_type(sk) != le16_to_cpu(hdr->type)) { + (void)virtio_transport_reset_no_sock(t, skb); + sock_put(sk); + goto free_pkt; + } + + if (!skb_set_owner_sk_safe(skb, sk)) { + WARN_ONCE(1, "receiving vsock socket has sk_refcnt == 0\n"); + goto free_pkt; + } + + vsk = vsock_sk(sk); + + lock_sock(sk); + + /* Check if sk has been closed before lock_sock */ + if (sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset_no_sock(t, skb); + release_sock(sk); + sock_put(sk); + goto free_pkt; + } + + space_available = virtio_transport_space_update(sk, skb); + + /* Update CID in case it has changed after a transport reset event */ + if (vsk->local_addr.svm_cid != VMADDR_CID_ANY) + vsk->local_addr.svm_cid = dst.svm_cid; + + if (space_available) + sk->sk_write_space(sk); + + switch (sk->sk_state) { + case TCP_LISTEN: + virtio_transport_recv_listen(sk, skb, t); + kfree_skb(skb); + break; + case TCP_SYN_SENT: + virtio_transport_recv_connecting(sk, skb); + kfree_skb(skb); + break; + case TCP_ESTABLISHED: + virtio_transport_recv_connected(sk, skb); + break; + case TCP_CLOSING: + virtio_transport_recv_disconnecting(sk, skb); + kfree_skb(skb); + break; + default: + (void)virtio_transport_reset_no_sock(t, skb); + kfree_skb(skb); + break; + } + + release_sock(sk); + + /* Release refcnt obtained when we fetched this socket out of the + * bound or connected list. + */ + sock_put(sk); + return; + +free_pkt: + kfree_skb(skb); +} +EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); + +/* Remove skbs found in a queue that have a vsk that matches. + * + * Each skb is freed. + * + * Returns the count of skbs that were reply packets. 
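+ * Matching skbs are unlinked with the queue lock held and freed only
+ * after the lock has been released.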
+ */ +int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *queue) +{ + struct sk_buff_head freeme; + struct sk_buff *skb, *tmp; + int cnt = 0; + + skb_queue_head_init(&freeme); + + spin_lock_bh(&queue->lock); + skb_queue_walk_safe(queue, skb, tmp) { + if (vsock_sk(skb->sk) != vsk) + continue; + + __skb_unlink(skb, queue); + __skb_queue_tail(&freeme, skb); + + if (virtio_vsock_skb_reply(skb)) + cnt++; + } + spin_unlock_bh(&queue->lock); + + __skb_queue_purge(&freeme); + + return cnt; +} +EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs); + +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("common code for virtio vsock"); diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c new file mode 100644 index 000000000..36eb16a40 --- /dev/null +++ b/net/vmw_vsock/vmci_transport.c @@ -0,0 +1,2147 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2013 VMware, Inc. All rights reserved. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "vmci_transport_notify.h" + +static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg); +static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg); +static void vmci_transport_peer_detach_cb(u32 sub_id, + const struct vmci_event_data *ed, + void *client_data); +static void vmci_transport_recv_pkt_work(struct work_struct *work); +static void vmci_transport_cleanup(struct work_struct *work); +static int vmci_transport_recv_listen(struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_server( + struct sock *sk, + struct sock *pending, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client_negotiate( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connecting_client_invalid( + struct sock *sk, + struct vmci_transport_packet *pkt); +static int vmci_transport_recv_connected(struct sock *sk, + struct vmci_transport_packet *pkt); +static bool vmci_transport_old_proto_override(bool *old_pkt_proto); +static u16 vmci_transport_new_proto_supported_versions(void); +static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto, + bool old_pkt_proto); +static bool vmci_check_transport(struct vsock_sock *vsk); + +struct vmci_transport_recv_pkt_info { + struct work_struct work; + struct sock *sk; + struct vmci_transport_packet pkt; +}; + +static LIST_HEAD(vmci_transport_cleanup_list); +static DEFINE_SPINLOCK(vmci_transport_cleanup_lock); +static DECLARE_WORK(vmci_transport_cleanup_work, vmci_transport_cleanup); + +static struct vmci_handle vmci_transport_stream_handle = { VMCI_INVALID_ID, + VMCI_INVALID_ID }; +static u32 vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + +static int PROTOCOL_OVERRIDE = -1; + +static struct vsock_transport vmci_transport; /* forward declaration */ + +/* Helper function to convert from a VMCI error code to a VSock error code. 
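+ * Any VMCI error without an explicit mapping below (including invalid
+ * arguments) is reported as -EINVAL.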
*/ + +static s32 vmci_transport_error_to_vsock_error(s32 vmci_error) +{ + switch (vmci_error) { + case VMCI_ERROR_NO_MEM: + return -ENOMEM; + case VMCI_ERROR_DUPLICATE_ENTRY: + case VMCI_ERROR_ALREADY_EXISTS: + return -EADDRINUSE; + case VMCI_ERROR_NO_ACCESS: + return -EPERM; + case VMCI_ERROR_NO_RESOURCES: + return -ENOBUFS; + case VMCI_ERROR_INVALID_RESOURCE: + return -EHOSTUNREACH; + case VMCI_ERROR_INVALID_ARGS: + default: + break; + } + return -EINVAL; +} + +static u32 vmci_transport_peer_rid(u32 peer_cid) +{ + if (VMADDR_CID_HYPERVISOR == peer_cid) + return VMCI_TRANSPORT_HYPERVISOR_PACKET_RID; + + return VMCI_TRANSPORT_PACKET_RID; +} + +static inline void +vmci_transport_packet_init(struct vmci_transport_packet *pkt, + struct sockaddr_vm *src, + struct sockaddr_vm *dst, + u8 type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + /* We register the stream control handler as an any cid handle so we + * must always send from a source address of VMADDR_CID_ANY + */ + pkt->dg.src = vmci_make_handle(VMADDR_CID_ANY, + VMCI_TRANSPORT_PACKET_RID); + pkt->dg.dst = vmci_make_handle(dst->svm_cid, + vmci_transport_peer_rid(dst->svm_cid)); + pkt->dg.payload_size = sizeof(*pkt) - sizeof(pkt->dg); + pkt->version = VMCI_TRANSPORT_PACKET_VERSION; + pkt->type = type; + pkt->src_port = src->svm_port; + pkt->dst_port = dst->svm_port; + memset(&pkt->proto, 0, sizeof(pkt->proto)); + memset(&pkt->_reserved2, 0, sizeof(pkt->_reserved2)); + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_INVALID: + pkt->u.size = 0; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_REQUEST: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: + pkt->u.size = size; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_OFFER: + case VMCI_TRANSPORT_PACKET_TYPE_ATTACH: + pkt->u.handle = handle; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + case VMCI_TRANSPORT_PACKET_TYPE_READ: + case VMCI_TRANSPORT_PACKET_TYPE_RST: + pkt->u.size = 0; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN: + pkt->u.mode = mode; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ: + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE: + memcpy(&pkt->u.wait, wait, sizeof(pkt->u.wait)); + break; + + case VMCI_TRANSPORT_PACKET_TYPE_REQUEST2: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + pkt->u.size = size; + pkt->proto = proto; + break; + } +} + +static inline void +vmci_transport_packet_get_addresses(struct vmci_transport_packet *pkt, + struct sockaddr_vm *local, + struct sockaddr_vm *remote) +{ + vsock_addr_init(local, pkt->dg.dst.context, pkt->dst_port); + vsock_addr_init(remote, pkt->dg.src.context, pkt->src_port); +} + +static int +__vmci_transport_send_control_pkt(struct vmci_transport_packet *pkt, + struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle, + bool convert_error) +{ + int err; + + vmci_transport_packet_init(pkt, src, dst, type, size, mode, wait, + proto, handle); + err = vmci_datagram_send(&pkt->dg); + if (convert_error && (err < 0)) + return vmci_transport_error_to_vsock_error(err); + + return err; +} + +static int +vmci_transport_reply_control_pkt_fast(struct vmci_transport_packet *pkt, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + struct vmci_handle handle) +{ + struct vmci_transport_packet reply; + struct sockaddr_vm src, dst; + + if (pkt->type == 
VMCI_TRANSPORT_PACKET_TYPE_RST) { + return 0; + } else { + vmci_transport_packet_get_addresses(pkt, &src, &dst); + return __vmci_transport_send_control_pkt(&reply, &src, &dst, + type, + size, mode, wait, + VSOCK_PROTO_INVALID, + handle, true); + } +} + +static int +vmci_transport_send_control_pkt_bh(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + struct vmci_handle handle) +{ + /* Note that it is safe to use a single packet across all CPUs since + * two tasklets of the same type are guaranteed to not ever run + * simultaneously. If that ever changes, or VMCI stops using tasklets, + * we can use per-cpu packets. + */ + static struct vmci_transport_packet pkt; + + return __vmci_transport_send_control_pkt(&pkt, src, dst, type, + size, mode, wait, + VSOCK_PROTO_INVALID, handle, + false); +} + +static int +vmci_transport_alloc_send_control_pkt(struct sockaddr_vm *src, + struct sockaddr_vm *dst, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + struct vmci_transport_packet *pkt; + int err; + + pkt = kmalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return -ENOMEM; + + err = __vmci_transport_send_control_pkt(pkt, src, dst, type, size, + mode, wait, proto, handle, + true); + kfree(pkt); + + return err; +} + +static int +vmci_transport_send_control_pkt(struct sock *sk, + enum vmci_transport_packet_type type, + u64 size, + u64 mode, + struct vmci_transport_waiting_info *wait, + u16 proto, + struct vmci_handle handle) +{ + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + if (!vsock_addr_bound(&vsk->remote_addr)) + return -EINVAL; + + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, + &vsk->remote_addr, + type, size, mode, + wait, proto, handle); +} + +static int vmci_transport_send_reset_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src, + struct vmci_transport_packet *pkt) +{ + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) + return 0; + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_RST, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_reset(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct sockaddr_vm *dst_ptr; + struct sockaddr_vm dst; + struct vsock_sock *vsk; + + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST) + return 0; + + vsk = vsock_sk(sk); + + if (!vsock_addr_bound(&vsk->local_addr)) + return -EINVAL; + + if (vsock_addr_bound(&vsk->remote_addr)) { + dst_ptr = &vsk->remote_addr; + } else { + vsock_addr_init(&dst, pkt->dg.src.context, + pkt->src_port); + dst_ptr = &dst; + } + return vmci_transport_alloc_send_control_pkt(&vsk->local_addr, dst_ptr, + VMCI_TRANSPORT_PACKET_TYPE_RST, + 0, 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_negotiate(struct sock *sk, size_t size) +{ + return vmci_transport_send_control_pkt( + sk, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE, + size, 0, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_negotiate2(struct sock *sk, size_t size, + u16 version) +{ + return vmci_transport_send_control_pkt( + sk, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2, + size, 0, NULL, version, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_qp_offer(struct sock *sk, + struct vmci_handle handle) +{ + return 
vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_OFFER, 0, + 0, NULL, + VSOCK_PROTO_INVALID, handle); +} + +static int vmci_transport_send_attach(struct sock *sk, + struct vmci_handle handle) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_ATTACH, + 0, 0, NULL, VSOCK_PROTO_INVALID, + handle); +} + +static int vmci_transport_reply_reset(struct vmci_transport_packet *pkt) +{ + return vmci_transport_reply_control_pkt_fast( + pkt, + VMCI_TRANSPORT_PACKET_TYPE_RST, + 0, 0, NULL, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_invalid_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_INVALID, + 0, 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_read_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ + return vmci_transport_send_control_pkt_bh( + dst, src, + VMCI_TRANSPORT_PACKET_TYPE_READ, 0, + 0, NULL, VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_wrote(struct sock *sk) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WROTE, 0, + 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_read(struct sock *sk) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_READ, 0, + 0, NULL, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_waiting_write(struct sock *sk, + struct vmci_transport_waiting_info *wait) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE, + 0, 0, wait, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +int vmci_transport_send_waiting_read(struct sock *sk, + struct vmci_transport_waiting_info *wait) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ, + 0, 0, wait, VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + return vmci_transport_send_control_pkt( + &vsk->sk, + VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN, + 0, mode, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_conn_request(struct sock *sk, size_t size) +{ + return vmci_transport_send_control_pkt(sk, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST, + size, 0, NULL, + VSOCK_PROTO_INVALID, + VMCI_INVALID_HANDLE); +} + +static int vmci_transport_send_conn_request2(struct sock *sk, size_t size, + u16 version) +{ + return vmci_transport_send_control_pkt( + sk, VMCI_TRANSPORT_PACKET_TYPE_REQUEST2, + size, 0, NULL, version, + VMCI_INVALID_HANDLE); +} + +static struct sock *vmci_transport_get_pending( + struct sock *listener, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vlistener; + struct vsock_sock *vpending; + struct sock *pending; + struct sockaddr_vm src; + + vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port); + + vlistener = vsock_sk(listener); + + list_for_each_entry(vpending, &vlistener->pending_links, + pending_links) { + if (vsock_addr_equals_addr(&src, &vpending->remote_addr) && + pkt->dst_port == vpending->local_addr.svm_port) { + pending = sk_vsock(vpending); + sock_hold(pending); + goto found; + } + } + + pending = NULL; +found: + return pending; + +} + +static void vmci_transport_release_pending(struct sock *pending) +{ 
+ sock_put(pending); +} + +/* We allow two kinds of sockets to communicate with a restricted VM: 1) + * trusted sockets 2) sockets from applications running as the same user as the + * VM (this is only true for the host side and only when using hosted products) + */ + +static bool vmci_transport_is_trusted(struct vsock_sock *vsock, u32 peer_cid) +{ + return vsock->trusted || + vmci_is_context_owner(peer_cid, vsock->owner->uid); +} + +/* We allow sending datagrams to and receiving datagrams from a restricted VM + * only if it is trusted as described in vmci_transport_is_trusted. + */ + +static bool vmci_transport_allow_dgram(struct vsock_sock *vsock, u32 peer_cid) +{ + if (VMADDR_CID_HYPERVISOR == peer_cid) + return true; + + if (vsock->cached_peer != peer_cid) { + vsock->cached_peer = peer_cid; + if (!vmci_transport_is_trusted(vsock, peer_cid) && + (vmci_context_get_priv_flags(peer_cid) & + VMCI_PRIVILEGE_FLAG_RESTRICTED)) { + vsock->cached_peer_allow_dgram = false; + } else { + vsock->cached_peer_allow_dgram = true; + } + } + + return vsock->cached_peer_allow_dgram; +} + +static int +vmci_transport_queue_pair_alloc(struct vmci_qp **qpair, + struct vmci_handle *handle, + u64 produce_size, + u64 consume_size, + u32 peer, u32 flags, bool trusted) +{ + int err = 0; + + if (trusted) { + /* Try to allocate our queue pair as trusted. This will only + * work if vsock is running in the host. + */ + + err = vmci_qpair_alloc(qpair, handle, produce_size, + consume_size, + peer, flags, + VMCI_PRIVILEGE_FLAG_TRUSTED); + if (err != VMCI_ERROR_NO_ACCESS) + goto out; + + } + + err = vmci_qpair_alloc(qpair, handle, produce_size, consume_size, + peer, flags, VMCI_NO_PRIVILEGE_FLAGS); +out: + if (err < 0) { + pr_err_once("Could not attach to queue pair with %d\n", err); + err = vmci_transport_error_to_vsock_error(err); + } + + return err; +} + +static int +vmci_transport_datagram_create_hnd(u32 resource_id, + u32 flags, + vmci_datagram_recv_cb recv_cb, + void *client_data, + struct vmci_handle *out_handle) +{ + int err = 0; + + /* Try to allocate our datagram handler as trusted. This will only work + * if vsock is running in the host. + */ + + err = vmci_datagram_create_handle_priv(resource_id, flags, + VMCI_PRIVILEGE_FLAG_TRUSTED, + recv_cb, + client_data, out_handle); + + if (err == VMCI_ERROR_NO_ACCESS) + err = vmci_datagram_create_handle(resource_id, flags, + recv_cb, client_data, + out_handle); + + return err; +} + +/* This is invoked as part of a tasklet that's scheduled when the VMCI + * interrupt fires. This is run in bottom-half context and if it ever needs to + * sleep it should defer that work to a work queue. + */ + +static int vmci_transport_recv_dgram_cb(void *data, struct vmci_datagram *dg) +{ + struct sock *sk; + size_t size; + struct sk_buff *skb; + struct vsock_sock *vsk; + + sk = (struct sock *)data; + + /* This handler is privileged when this module is running on the host. + * We will get datagrams from all endpoints (even VMs that are in a + * restricted context). If we get one from a restricted context then + * the destination socket must be trusted. + * + * NOTE: We access the socket struct without holding the lock here. + * This is ok because the field we are interested is never modified + * outside of the create and destruct socket functions. + */ + vsk = vsock_sk(sk); + if (!vmci_transport_allow_dgram(vsk, dg->src.context)) + return VMCI_ERROR_NO_ACCESS; + + size = VMCI_DG_SIZE(dg); + + /* Attach the packet to the socket's receive queue as an sk_buff. 
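+ * The whole struct vmci_datagram, header included, is copied into the
+ * skb before it is handed to sk_receive_skb().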
*/ + skb = alloc_skb(size, GFP_ATOMIC); + if (!skb) + return VMCI_ERROR_NO_MEM; + + /* sk_receive_skb() will do a sock_put(), so hold here. */ + sock_hold(sk); + skb_put(skb, size); + memcpy(skb->data, dg, size); + sk_receive_skb(sk, skb, 0); + + return VMCI_SUCCESS; +} + +static bool vmci_transport_stream_allow(u32 cid, u32 port) +{ + static const u32 non_socket_contexts[] = { + VMADDR_CID_LOCAL, + }; + int i; + + BUILD_BUG_ON(sizeof(cid) != sizeof(*non_socket_contexts)); + + for (i = 0; i < ARRAY_SIZE(non_socket_contexts); i++) { + if (cid == non_socket_contexts[i]) + return false; + } + + return true; +} + +/* This is invoked as part of a tasklet that's scheduled when the VMCI + * interrupt fires. This is run in bottom-half context but it defers most of + * its work to the packet handling work queue. + */ + +static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg) +{ + struct sock *sk; + struct sockaddr_vm dst; + struct sockaddr_vm src; + struct vmci_transport_packet *pkt; + struct vsock_sock *vsk; + bool bh_process_pkt; + int err; + + sk = NULL; + err = VMCI_SUCCESS; + bh_process_pkt = false; + + /* Ignore incoming packets from contexts without sockets, or resources + * that aren't vsock implementations. + */ + + if (!vmci_transport_stream_allow(dg->src.context, -1) + || vmci_transport_peer_rid(dg->src.context) != dg->src.resource) + return VMCI_ERROR_NO_ACCESS; + + if (VMCI_DG_SIZE(dg) < sizeof(*pkt)) + /* Drop datagrams that do not contain full VSock packets. */ + return VMCI_ERROR_INVALID_ARGS; + + pkt = (struct vmci_transport_packet *)dg; + + /* Find the socket that should handle this packet. First we look for a + * connected socket and if there is none we look for a socket bound to + * the destintation address. + */ + vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port); + vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port); + + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + /* We could not find a socket for this specified + * address. If this packet is a RST, we just drop it. + * If it is another packet, we send a RST. Note that + * we do not send a RST reply to RSTs so that we do not + * continually send RSTs between two endpoints. + * + * Note that since this is a reply, dst is src and src + * is dst. + */ + if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0) + pr_err("unable to send reset\n"); + + err = VMCI_ERROR_NOT_FOUND; + goto out; + } + } + + /* If the received packet type is beyond all types known to this + * implementation, reply with an invalid message. Hopefully this will + * help when implementing backwards compatibility in the future. + */ + if (pkt->type >= VMCI_TRANSPORT_PACKET_TYPE_MAX) { + vmci_transport_send_invalid_bh(&dst, &src); + err = VMCI_ERROR_INVALID_ARGS; + goto out; + } + + /* This handler is privileged when this module is running on the host. + * We will get datagram connect requests from all endpoints (even VMs + * that are in a restricted context). If we get one from a restricted + * context then the destination socket must be trusted. + * + * NOTE: We access the socket struct without holding the lock here. + * This is ok because the field we are interested is never modified + * outside of the create and destruct socket functions. 
+ */ + vsk = vsock_sk(sk); + if (!vmci_transport_allow_dgram(vsk, pkt->dg.src.context)) { + err = VMCI_ERROR_NO_ACCESS; + goto out; + } + + /* We do most everything in a work queue, but let's fast path the + * notification of reads and writes to help data transfer performance. + * We can only do this if there is no process context code executing + * for this socket since that may change the state. + */ + bh_lock_sock(sk); + + if (!sock_owned_by_user(sk)) { + /* The local context ID may be out of date, update it. */ + vsk->local_addr.svm_cid = dst.svm_cid; + + if (sk->sk_state == TCP_ESTABLISHED) + vmci_trans(vsk)->notify_ops->handle_notify_pkt( + sk, pkt, true, &dst, &src, + &bh_process_pkt); + } + + bh_unlock_sock(sk); + + if (!bh_process_pkt) { + struct vmci_transport_recv_pkt_info *recv_pkt_info; + + recv_pkt_info = kmalloc(sizeof(*recv_pkt_info), GFP_ATOMIC); + if (!recv_pkt_info) { + if (vmci_transport_send_reset_bh(&dst, &src, pkt) < 0) + pr_err("unable to send reset\n"); + + err = VMCI_ERROR_NO_MEM; + goto out; + } + + recv_pkt_info->sk = sk; + memcpy(&recv_pkt_info->pkt, pkt, sizeof(recv_pkt_info->pkt)); + INIT_WORK(&recv_pkt_info->work, vmci_transport_recv_pkt_work); + + schedule_work(&recv_pkt_info->work); + /* Clear sk so that the reference count incremented by one of + * the Find functions above is not decremented below. We need + * that reference count for the packet handler we've scheduled + * to run. + */ + sk = NULL; + } + +out: + if (sk) + sock_put(sk); + + return err; +} + +static void vmci_transport_handle_detach(struct sock *sk) +{ + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + if (!vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle)) { + sock_set_flag(sk, SOCK_DONE); + + /* On a detach the peer will not be sending or receiving + * anymore. + */ + vsk->peer_shutdown = SHUTDOWN_MASK; + + /* We should not be sending anymore since the peer won't be + * there to receive, but we can still receive if there is data + * left in our consume queue. If the local endpoint is a host, + * we can't call vsock_stream_has_data, since that may block, + * but a host endpoint can't read data once the VM has + * detached, so there is no available data in that case. + */ + if (vsk->local_addr.svm_cid == VMADDR_CID_HOST || + vsock_stream_has_data(vsk) <= 0) { + if (sk->sk_state == TCP_SYN_SENT) { + /* The peer may detach from a queue pair while + * we are still in the connecting state, i.e., + * if the peer VM is killed after attaching to + * a queue pair, but before we complete the + * handshake. In that case, we treat the detach + * event like a reset. + */ + + sk->sk_state = TCP_CLOSE; + sk->sk_err = ECONNRESET; + sk_error_report(sk); + return; + } + sk->sk_state = TCP_CLOSE; + } + sk->sk_state_change(sk); + } +} + +static void vmci_transport_peer_detach_cb(u32 sub_id, + const struct vmci_event_data *e_data, + void *client_data) +{ + struct vmci_transport *trans = client_data; + const struct vmci_event_payload_qp *e_payload; + + e_payload = vmci_event_data_const_payload(e_data); + + /* XXX This is lame, we should provide a way to lookup sockets by + * qp_handle. + */ + if (vmci_handle_is_invalid(e_payload->handle) || + !vmci_handle_is_equal(trans->qp_handle, e_payload->handle)) + return; + + /* We don't ask for delayed CBs when we subscribe to this event (we + * pass 0 as flags to vmci_event_subscribe()). VMCI makes no + * guarantees in that case about what context we might be running in, + * so it could be BH or process, blockable or non-blockable. 
So we + * need to account for all possible contexts here. + */ + spin_lock_bh(&trans->lock); + if (!trans->sk) + goto out; + + /* Apart from here, trans->lock is only grabbed as part of sk destruct, + * where trans->sk isn't locked. + */ + bh_lock_sock(trans->sk); + + vmci_transport_handle_detach(trans->sk); + + bh_unlock_sock(trans->sk); + out: + spin_unlock_bh(&trans->lock); +} + +static void vmci_transport_qp_resumed_cb(u32 sub_id, + const struct vmci_event_data *e_data, + void *client_data) +{ + vsock_for_each_connected_socket(&vmci_transport, + vmci_transport_handle_detach); +} + +static void vmci_transport_recv_pkt_work(struct work_struct *work) +{ + struct vmci_transport_recv_pkt_info *recv_pkt_info; + struct vmci_transport_packet *pkt; + struct sock *sk; + + recv_pkt_info = + container_of(work, struct vmci_transport_recv_pkt_info, work); + sk = recv_pkt_info->sk; + pkt = &recv_pkt_info->pkt; + + lock_sock(sk); + + /* The local context ID may be out of date. */ + vsock_sk(sk)->local_addr.svm_cid = pkt->dg.dst.context; + + switch (sk->sk_state) { + case TCP_LISTEN: + vmci_transport_recv_listen(sk, pkt); + break; + case TCP_SYN_SENT: + /* Processing of pending connections for servers goes through + * the listening socket, so see vmci_transport_recv_listen() + * for that path. + */ + vmci_transport_recv_connecting_client(sk, pkt); + break; + case TCP_ESTABLISHED: + vmci_transport_recv_connected(sk, pkt); + break; + default: + /* Because this function does not run in the same context as + * vmci_transport_recv_stream_cb it is possible that the + * socket has closed. We need to let the other side know or it + * could be sitting in a connect and hang forever. Send a + * reset to prevent that. + */ + vmci_transport_send_reset(sk, pkt); + break; + } + + release_sock(sk); + kfree(recv_pkt_info); + /* Release reference obtained in the stream callback when we fetched + * this socket out of the bound or connected list. + */ + sock_put(sk); +} + +static int vmci_transport_recv_listen(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct sock *pending; + struct vsock_sock *vpending; + int err; + u64 qp_size; + bool old_request = false; + bool old_pkt_proto = false; + + /* Because we are in the listen state, we could be receiving a packet + * for ourself or any previous connection requests that we received. + * If it's the latter, we try to find a socket in our list of pending + * connections and, if we do, call the appropriate handler for the + * state that socket is in. Otherwise we try to service the + * connection request. + */ + pending = vmci_transport_get_pending(sk, pkt); + if (pending) { + lock_sock(pending); + + /* The local context ID may be out of date. */ + vsock_sk(pending)->local_addr.svm_cid = pkt->dg.dst.context; + + switch (pending->sk_state) { + case TCP_SYN_SENT: + err = vmci_transport_recv_connecting_server(sk, + pending, + pkt); + break; + default: + vmci_transport_send_reset(pending, pkt); + err = -EINVAL; + } + + if (err < 0) + vsock_remove_pending(sk, pending); + + release_sock(pending); + vmci_transport_release_pending(pending); + + return err; + } + + /* The listen state only accepts connection requests. Reply with a + * reset unless we received a reset. 
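+ * Only REQUEST and REQUEST2 packets with a non-zero proposed queue pair
+ * size are accepted; anything else is answered with a reset.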
+ */ + + if (!(pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST || + pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2)) { + vmci_transport_reply_reset(pkt); + return -EINVAL; + } + + if (pkt->u.size == 0) { + vmci_transport_reply_reset(pkt); + return -EINVAL; + } + + /* If this socket can't accommodate this connection request, we send a + * reset. Otherwise we create and initialize a child socket and reply + * with a connection negotiation. + */ + if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog) { + vmci_transport_reply_reset(pkt); + return -ECONNREFUSED; + } + + pending = vsock_create_connected(sk); + if (!pending) { + vmci_transport_send_reset(sk, pkt); + return -ENOMEM; + } + + vpending = vsock_sk(pending); + + vsock_addr_init(&vpending->local_addr, pkt->dg.dst.context, + pkt->dst_port); + vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context, + pkt->src_port); + + err = vsock_assign_transport(vpending, vsock_sk(sk)); + /* Transport assigned (looking at remote_addr) must be the same + * where we received the request. + */ + if (err || !vmci_check_transport(vpending)) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + return err; + } + + /* If the proposed size fits within our min/max, accept it. Otherwise + * propose our own size. + */ + if (pkt->u.size >= vpending->buffer_min_size && + pkt->u.size <= vpending->buffer_max_size) { + qp_size = pkt->u.size; + } else { + qp_size = vpending->buffer_size; + } + + /* Figure out if we are using old or new requests based on the + * overrides pkt types sent by our peer. + */ + if (vmci_transport_old_proto_override(&old_pkt_proto)) { + old_request = old_pkt_proto; + } else { + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST) + old_request = true; + else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_REQUEST2) + old_request = false; + + } + + if (old_request) { + /* Handle a REQUEST (or override) */ + u16 version = VSOCK_PROTO_INVALID; + if (vmci_transport_proto_to_notify_struct( + pending, &version, true)) + err = vmci_transport_send_negotiate(pending, qp_size); + else + err = -EINVAL; + + } else { + /* Handle a REQUEST2 (or override) */ + int proto_int = pkt->proto; + int pos; + u16 active_proto_version = 0; + + /* The list of possible protocols is the intersection of all + * protocols the client supports ... plus all the protocols we + * support. + */ + proto_int &= vmci_transport_new_proto_supported_versions(); + + /* We choose the highest possible protocol version and use that + * one. + */ + pos = fls(proto_int); + if (pos) { + active_proto_version = (1 << (pos - 1)); + if (vmci_transport_proto_to_notify_struct( + pending, &active_proto_version, false)) + err = vmci_transport_send_negotiate2(pending, + qp_size, + active_proto_version); + else + err = -EINVAL; + + } else { + err = -EINVAL; + } + } + + if (err < 0) { + vmci_transport_send_reset(sk, pkt); + sock_put(pending); + err = vmci_transport_error_to_vsock_error(err); + goto out; + } + + vsock_add_pending(sk, pending); + sk_acceptq_added(sk); + + pending->sk_state = TCP_SYN_SENT; + vmci_trans(vpending)->produce_size = + vmci_trans(vpending)->consume_size = qp_size; + vpending->buffer_size = qp_size; + + vmci_trans(vpending)->notify_ops->process_request(pending); + + /* We might never receive another message for this socket and it's not + * connected to any process, so we have to ensure it gets cleaned up + * ourself. Our delayed work function will take care of that. 
Note + * that we do not ever cancel this function since we have few + * guarantees about its state when calling cancel_delayed_work(). + * Instead we hold a reference on the socket for that function and make + * it capable of handling cases where it needs to do nothing but + * release that reference. + */ + vpending->listener = sk; + sock_hold(sk); + sock_hold(pending); + schedule_delayed_work(&vpending->pending_work, HZ); + +out: + return err; +} + +static int +vmci_transport_recv_connecting_server(struct sock *listener, + struct sock *pending, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vpending; + struct vmci_handle handle; + struct vmci_qp *qpair; + bool is_local; + u32 flags; + u32 detach_sub_id; + int err; + int skerr; + + vpending = vsock_sk(pending); + detach_sub_id = VMCI_INVALID_ID; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_OFFER: + if (vmci_handle_is_invalid(pkt->u.handle)) { + vmci_transport_send_reset(pending, pkt); + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + break; + default: + /* Close and cleanup the connection. */ + vmci_transport_send_reset(pending, pkt); + skerr = EPROTO; + err = pkt->type == VMCI_TRANSPORT_PACKET_TYPE_RST ? 0 : -EINVAL; + goto destroy; + } + + /* In order to complete the connection we need to attach to the offered + * queue pair and send an attach notification. We also subscribe to the + * detach event so we know when our peer goes away, and we do that + * before attaching so we don't miss an event. If all this succeeds, + * we update our state and wakeup anything waiting in accept() for a + * connection. + */ + + /* We don't care about attach since we ensure the other side has + * attached by specifying the ATTACH_ONLY flag below. + */ + err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH, + vmci_transport_peer_detach_cb, + vmci_trans(vpending), &detach_sub_id); + if (err < VMCI_SUCCESS) { + vmci_transport_send_reset(pending, pkt); + err = vmci_transport_error_to_vsock_error(err); + skerr = -err; + goto destroy; + } + + vmci_trans(vpending)->detach_sub_id = detach_sub_id; + + /* Now attach to the queue pair the client created. */ + handle = pkt->u.handle; + + /* vpending->local_addr always has a context id so we do not need to + * worry about VMADDR_CID_ANY in this case. + */ + is_local = + vpending->remote_addr.svm_cid == vpending->local_addr.svm_cid; + flags = VMCI_QPFLAG_ATTACH_ONLY; + flags |= is_local ? VMCI_QPFLAG_LOCAL : 0; + + err = vmci_transport_queue_pair_alloc( + &qpair, + &handle, + vmci_trans(vpending)->produce_size, + vmci_trans(vpending)->consume_size, + pkt->dg.src.context, + flags, + vmci_transport_is_trusted( + vpending, + vpending->remote_addr.svm_cid)); + if (err < 0) { + vmci_transport_send_reset(pending, pkt); + skerr = -err; + goto destroy; + } + + vmci_trans(vpending)->qp_handle = handle; + vmci_trans(vpending)->qpair = qpair; + + /* When we send the attach message, we must be ready to handle incoming + * control messages on the newly connected socket. So we move the + * pending socket to the connected state before sending the attach + * message. Otherwise, an incoming packet triggered by the attach being + * received by the peer may be processed concurrently with what happens + * below after sending the attach message, and that incoming packet + * will find the listening socket instead of the (currently) pending + * socket. 
Note that enqueueing the socket increments the reference + * count, so even if a reset comes before the connection is accepted, + * the socket will be valid until it is removed from the queue. + * + * If we fail sending the attach below, we remove the socket from the + * connected list and move the socket to TCP_CLOSE before + * releasing the lock, so a pending slow path processing of an incoming + * packet will not see the socket in the connected state in that case. + */ + pending->sk_state = TCP_ESTABLISHED; + + vsock_insert_connected(vpending); + + /* Notify our peer of our attach. */ + err = vmci_transport_send_attach(pending, handle); + if (err < 0) { + vsock_remove_connected(vpending); + pr_err("Could not send attach\n"); + vmci_transport_send_reset(pending, pkt); + err = vmci_transport_error_to_vsock_error(err); + skerr = -err; + goto destroy; + } + + /* We have a connection. Move the now connected socket from the + * listener's pending list to the accept queue so callers of accept() + * can find it. + */ + vsock_remove_pending(listener, pending); + vsock_enqueue_accept(listener, pending); + + /* Callers of accept() will be waiting on the listening socket, not + * the pending socket. + */ + listener->sk_data_ready(listener); + + return 0; + +destroy: + pending->sk_err = skerr; + pending->sk_state = TCP_CLOSE; + /* As long as we drop our reference, all necessary cleanup will handle + * when the cleanup function drops its reference and our destruct + * implementation is called. Note that since the listen handler will + * remove pending from the pending list upon our failure, the cleanup + * function won't drop the additional reference, which is why we do it + * here. + */ + sock_put(pending); + + return err; +} + +static int +vmci_transport_recv_connecting_client(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vsk; + int err; + int skerr; + + vsk = vsock_sk(sk); + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_ATTACH: + if (vmci_handle_is_invalid(pkt->u.handle) || + !vmci_handle_is_equal(pkt->u.handle, + vmci_trans(vsk)->qp_handle)) { + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + + /* Signify the socket is connected and wakeup the waiter in + * connect(). Also place the socket in the connected table for + * accounting (it can already be found since it's in the bound + * table). + */ + sk->sk_state = TCP_ESTABLISHED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + + break; + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE: + case VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2: + if (pkt->u.size == 0 + || pkt->dg.src.context != vsk->remote_addr.svm_cid + || pkt->src_port != vsk->remote_addr.svm_port + || !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle) + || vmci_trans(vsk)->qpair + || vmci_trans(vsk)->produce_size != 0 + || vmci_trans(vsk)->consume_size != 0 + || vmci_trans(vsk)->detach_sub_id != VMCI_INVALID_ID) { + skerr = EPROTO; + err = -EINVAL; + + goto destroy; + } + + err = vmci_transport_recv_connecting_client_negotiate(sk, pkt); + if (err) { + skerr = -err; + goto destroy; + } + + break; + case VMCI_TRANSPORT_PACKET_TYPE_INVALID: + err = vmci_transport_recv_connecting_client_invalid(sk, pkt); + if (err) { + skerr = -err; + goto destroy; + } + + break; + case VMCI_TRANSPORT_PACKET_TYPE_RST: + /* Older versions of the linux code (WS 6.5 / ESX 4.0) used to + * continue processing here after they sent an INVALID packet. + * This meant that we got a RST after the INVALID. 
We ignore a + * RST after an INVALID. The common code doesn't send the RST + * ... so we can hang if an old version of the common code + * fails between getting a REQUEST and sending an OFFER back. + * Not much we can do about it... except hope that it doesn't + * happen. + */ + if (vsk->ignore_connecting_rst) { + vsk->ignore_connecting_rst = false; + } else { + skerr = ECONNRESET; + err = 0; + goto destroy; + } + + break; + default: + /* Close and cleanup the connection. */ + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + + return 0; + +destroy: + vmci_transport_send_reset(sk, pkt); + + sk->sk_state = TCP_CLOSE; + sk->sk_err = skerr; + sk_error_report(sk); + return err; +} + +static int vmci_transport_recv_connecting_client_negotiate( + struct sock *sk, + struct vmci_transport_packet *pkt) +{ + int err; + struct vsock_sock *vsk; + struct vmci_handle handle; + struct vmci_qp *qpair; + u32 detach_sub_id; + bool is_local; + u32 flags; + bool old_proto = true; + bool old_pkt_proto; + u16 version; + + vsk = vsock_sk(sk); + handle = VMCI_INVALID_HANDLE; + detach_sub_id = VMCI_INVALID_ID; + + /* If we have gotten here then we should be past the point where old + * linux vsock could have sent the bogus rst. + */ + vsk->sent_request = false; + vsk->ignore_connecting_rst = false; + + /* Verify that we're OK with the proposed queue pair size */ + if (pkt->u.size < vsk->buffer_min_size || + pkt->u.size > vsk->buffer_max_size) { + err = -EINVAL; + goto destroy; + } + + /* At this point we know the CID the peer is using to talk to us. */ + + if (vsk->local_addr.svm_cid == VMADDR_CID_ANY) + vsk->local_addr.svm_cid = pkt->dg.dst.context; + + /* Setup the notify ops to be the highest supported version that both + * the server and the client support. + */ + + if (vmci_transport_old_proto_override(&old_pkt_proto)) { + old_proto = old_pkt_proto; + } else { + if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE) + old_proto = true; + else if (pkt->type == VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2) + old_proto = false; + + } + + if (old_proto) + version = VSOCK_PROTO_INVALID; + else + version = pkt->proto; + + if (!vmci_transport_proto_to_notify_struct(sk, &version, old_proto)) { + err = -EINVAL; + goto destroy; + } + + /* Subscribe to detach events first. + * + * XXX We attach once for each queue pair created for now so it is easy + * to find the socket (it's provided), but later we should only + * subscribe once and add a way to lookup sockets by queue pair handle. + */ + err = vmci_event_subscribe(VMCI_EVENT_QP_PEER_DETACH, + vmci_transport_peer_detach_cb, + vmci_trans(vsk), &detach_sub_id); + if (err < VMCI_SUCCESS) { + err = vmci_transport_error_to_vsock_error(err); + goto destroy; + } + + /* Make VMCI select the handle for us. */ + handle = VMCI_INVALID_HANDLE; + is_local = vsk->remote_addr.svm_cid == vsk->local_addr.svm_cid; + flags = is_local ? 
VMCI_QPFLAG_LOCAL : 0; + + err = vmci_transport_queue_pair_alloc(&qpair, + &handle, + pkt->u.size, + pkt->u.size, + vsk->remote_addr.svm_cid, + flags, + vmci_transport_is_trusted( + vsk, + vsk-> + remote_addr.svm_cid)); + if (err < 0) + goto destroy; + + err = vmci_transport_send_qp_offer(sk, handle); + if (err < 0) { + err = vmci_transport_error_to_vsock_error(err); + goto destroy; + } + + vmci_trans(vsk)->qp_handle = handle; + vmci_trans(vsk)->qpair = qpair; + + vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = + pkt->u.size; + + vmci_trans(vsk)->detach_sub_id = detach_sub_id; + + vmci_trans(vsk)->notify_ops->process_negotiate(sk); + + return 0; + +destroy: + if (detach_sub_id != VMCI_INVALID_ID) + vmci_event_unsubscribe(detach_sub_id); + + if (!vmci_handle_is_invalid(handle)) + vmci_qpair_detach(&qpair); + + return err; +} + +static int +vmci_transport_recv_connecting_client_invalid(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + int err = 0; + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsk->sent_request) { + vsk->sent_request = false; + vsk->ignore_connecting_rst = true; + + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); + if (err < 0) + err = vmci_transport_error_to_vsock_error(err); + else + err = 0; + + } + + return err; +} + +static int vmci_transport_recv_connected(struct sock *sk, + struct vmci_transport_packet *pkt) +{ + struct vsock_sock *vsk; + bool pkt_processed = false; + + /* In cases where we are closing the connection, it's sufficient to + * mark the state change (and maybe error) and wake up any waiting + * threads. Since this is a connected socket, it's owned by a user + * process and will be cleaned up when the failure is passed back on + * the current or next system call. Our system call implementations + * must therefore check for error and state changes on entry and when + * being awoken. + */ + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN: + if (pkt->u.mode) { + vsk = vsock_sk(sk); + + vsk->peer_shutdown |= pkt->u.mode; + sk->sk_state_change(sk); + } + break; + + case VMCI_TRANSPORT_PACKET_TYPE_RST: + vsk = vsock_sk(sk); + /* It is possible that we sent our peer a message (e.g a + * WAITING_READ) right before we got notified that the peer had + * detached. If that happens then we can get a RST pkt back + * from our peer even though there is data available for us to + * read. In that case, don't shutdown the socket completely but + * instead allow the local client to finish reading data off + * the queuepair. Always treat a RST pkt in connected mode like + * a clean shutdown. 
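+ * For example, if unread bytes are still sitting in the consume
+ * queue when the RST arrives, the socket stays in TCP_ESTABLISHED so
+ * the local reader can drain them; only once the queue is empty do
+ * we move to TCP_CLOSING below.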
+ */ + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = TCP_CLOSING; + + sk->sk_state_change(sk); + break; + + default: + vsk = vsock_sk(sk); + vmci_trans(vsk)->notify_ops->handle_notify_pkt( + sk, pkt, false, NULL, NULL, + &pkt_processed); + if (!pkt_processed) + return -EINVAL; + + break; + } + + return 0; +} + +static int vmci_transport_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + vsk->trans = kmalloc(sizeof(struct vmci_transport), GFP_KERNEL); + if (!vsk->trans) + return -ENOMEM; + + vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; + vmci_trans(vsk)->qp_handle = VMCI_INVALID_HANDLE; + vmci_trans(vsk)->qpair = NULL; + vmci_trans(vsk)->produce_size = vmci_trans(vsk)->consume_size = 0; + vmci_trans(vsk)->detach_sub_id = VMCI_INVALID_ID; + vmci_trans(vsk)->notify_ops = NULL; + INIT_LIST_HEAD(&vmci_trans(vsk)->elem); + vmci_trans(vsk)->sk = &vsk->sk; + spin_lock_init(&vmci_trans(vsk)->lock); + + return 0; +} + +static void vmci_transport_free_resources(struct list_head *transport_list) +{ + while (!list_empty(transport_list)) { + struct vmci_transport *transport = + list_first_entry(transport_list, struct vmci_transport, + elem); + list_del(&transport->elem); + + if (transport->detach_sub_id != VMCI_INVALID_ID) { + vmci_event_unsubscribe(transport->detach_sub_id); + transport->detach_sub_id = VMCI_INVALID_ID; + } + + if (!vmci_handle_is_invalid(transport->qp_handle)) { + vmci_qpair_detach(&transport->qpair); + transport->qp_handle = VMCI_INVALID_HANDLE; + transport->produce_size = 0; + transport->consume_size = 0; + } + + kfree(transport); + } +} + +static void vmci_transport_cleanup(struct work_struct *work) +{ + LIST_HEAD(pending); + + spin_lock_bh(&vmci_transport_cleanup_lock); + list_replace_init(&vmci_transport_cleanup_list, &pending); + spin_unlock_bh(&vmci_transport_cleanup_lock); + vmci_transport_free_resources(&pending); +} + +static void vmci_transport_destruct(struct vsock_sock *vsk) +{ + /* transport can be NULL if we hit a failure at init() time */ + if (!vmci_trans(vsk)) + return; + + /* Ensure that the detach callback doesn't use the sk/vsk + * we are about to destruct. + */ + spin_lock_bh(&vmci_trans(vsk)->lock); + vmci_trans(vsk)->sk = NULL; + spin_unlock_bh(&vmci_trans(vsk)->lock); + + if (vmci_trans(vsk)->notify_ops) + vmci_trans(vsk)->notify_ops->socket_destruct(vsk); + + spin_lock_bh(&vmci_transport_cleanup_lock); + list_add(&vmci_trans(vsk)->elem, &vmci_transport_cleanup_list); + spin_unlock_bh(&vmci_transport_cleanup_lock); + schedule_work(&vmci_transport_cleanup_work); + + vsk->trans = NULL; +} + +static void vmci_transport_release(struct vsock_sock *vsk) +{ + vsock_remove_sock(vsk); + + if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { + vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); + vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; + } +} + +static int vmci_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + u32 port; + u32 flags; + int err; + + /* VMCI will select a resource ID for us if we provide + * VMCI_INVALID_ID. + */ + port = addr->svm_port == VMADDR_PORT_ANY ? + VMCI_INVALID_ID : addr->svm_port; + + if (port <= LAST_RESERVED_PORT && !capable(CAP_NET_BIND_SERVICE)) + return -EACCES; + + flags = addr->svm_cid == VMADDR_CID_ANY ? 
+ VMCI_FLAG_ANYCID_DG_HND : 0; + + err = vmci_transport_datagram_create_hnd(port, flags, + vmci_transport_recv_dgram_cb, + &vsk->sk, + &vmci_trans(vsk)->dg_handle); + if (err < VMCI_SUCCESS) + return vmci_transport_error_to_vsock_error(err); + vsock_addr_init(&vsk->local_addr, addr->svm_cid, + vmci_trans(vsk)->dg_handle.resource); + + return 0; +} + +static int vmci_transport_dgram_enqueue( + struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t len) +{ + int err; + struct vmci_datagram *dg; + + if (len > VMCI_MAX_DG_PAYLOAD_SIZE) + return -EMSGSIZE; + + if (!vmci_transport_allow_dgram(vsk, remote_addr->svm_cid)) + return -EPERM; + + /* Allocate a buffer for the user's message and our packet header. */ + dg = kmalloc(len + sizeof(*dg), GFP_KERNEL); + if (!dg) + return -ENOMEM; + + err = memcpy_from_msg(VMCI_DG_PAYLOAD(dg), msg, len); + if (err) { + kfree(dg); + return err; + } + + dg->dst = vmci_make_handle(remote_addr->svm_cid, + remote_addr->svm_port); + dg->src = vmci_make_handle(vsk->local_addr.svm_cid, + vsk->local_addr.svm_port); + dg->payload_size = len; + + err = vmci_datagram_send(dg); + kfree(dg); + if (err < 0) + return vmci_transport_error_to_vsock_error(err); + + return err - sizeof(*dg); +} + +static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, size_t len, + int flags) +{ + int err; + struct vmci_datagram *dg; + size_t payload_len; + struct sk_buff *skb; + + if (flags & MSG_OOB || flags & MSG_ERRQUEUE) + return -EOPNOTSUPP; + + /* Retrieve the head sk_buff from the socket's receive queue. */ + err = 0; + skb = skb_recv_datagram(&vsk->sk, flags, &err); + if (!skb) + return err; + + dg = (struct vmci_datagram *)skb->data; + if (!dg) + /* err is 0, meaning we read zero bytes. */ + goto out; + + payload_len = dg->payload_size; + /* Ensure the sk_buff matches the payload size claimed in the packet. */ + if (payload_len != skb->len - sizeof(*dg)) { + err = -EINVAL; + goto out; + } + + if (payload_len > len) { + payload_len = len; + msg->msg_flags |= MSG_TRUNC; + } + + /* Place the datagram payload in the user's iovec. */ + err = skb_copy_datagram_msg(skb, sizeof(*dg), msg, payload_len); + if (err) + goto out; + + if (msg->msg_name) { + /* Provide the address of the sender. */ + DECLARE_SOCKADDR(struct sockaddr_vm *, vm_addr, msg->msg_name); + vsock_addr_init(vm_addr, dg->src.context, dg->src.resource); + msg->msg_namelen = sizeof(*vm_addr); + } + err = payload_len; + +out: + skb_free_datagram(&vsk->sk, skb); + return err; +} + +static bool vmci_transport_dgram_allow(u32 cid, u32 port) +{ + if (cid == VMADDR_CID_HYPERVISOR) { + /* Registrations of PBRPC Servers do not modify VMX/Hypervisor + * state and are allowed. 
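+ * In other words, the only datagram port a socket may target on the
+ * hypervisor context is VMCI_UNITY_PBRPC_REGISTER; datagrams to any
+ * other CID are not filtered here.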
+ */ + return port == VMCI_UNITY_PBRPC_REGISTER; + } + + return true; +} + +static int vmci_transport_connect(struct vsock_sock *vsk) +{ + int err; + bool old_pkt_proto = false; + struct sock *sk = &vsk->sk; + + if (vmci_transport_old_proto_override(&old_pkt_proto) && + old_pkt_proto) { + err = vmci_transport_send_conn_request(sk, vsk->buffer_size); + if (err < 0) { + sk->sk_state = TCP_CLOSE; + return err; + } + } else { + int supported_proto_versions = + vmci_transport_new_proto_supported_versions(); + err = vmci_transport_send_conn_request2(sk, vsk->buffer_size, + supported_proto_versions); + if (err < 0) { + sk->sk_state = TCP_CLOSE; + return err; + } + + vsk->sent_request = true; + } + + return err; +} + +static ssize_t vmci_transport_stream_dequeue( + struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, + int flags) +{ + if (flags & MSG_PEEK) + return vmci_qpair_peekv(vmci_trans(vsk)->qpair, msg, len, 0); + else + return vmci_qpair_dequev(vmci_trans(vsk)->qpair, msg, len, 0); +} + +static ssize_t vmci_transport_stream_enqueue( + struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + return vmci_qpair_enquev(vmci_trans(vsk)->qpair, msg, len, 0); +} + +static s64 vmci_transport_stream_has_data(struct vsock_sock *vsk) +{ + return vmci_qpair_consume_buf_ready(vmci_trans(vsk)->qpair); +} + +static s64 vmci_transport_stream_has_space(struct vsock_sock *vsk) +{ + return vmci_qpair_produce_free_space(vmci_trans(vsk)->qpair); +} + +static u64 vmci_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + return vmci_trans(vsk)->consume_size; +} + +static bool vmci_transport_stream_is_active(struct vsock_sock *vsk) +{ + return !vmci_handle_is_invalid(vmci_trans(vsk)->qp_handle); +} + +static int vmci_transport_notify_poll_in( + struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + return vmci_trans(vsk)->notify_ops->poll_in( + &vsk->sk, target, data_ready_now); +} + +static int vmci_transport_notify_poll_out( + struct vsock_sock *vsk, + size_t target, + bool *space_available_now) +{ + return vmci_trans(vsk)->notify_ops->poll_out( + &vsk->sk, target, space_available_now); +} + +static int vmci_transport_notify_recv_init( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_init( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_pre_block( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_pre_block( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_pre_dequeue( + struct vsock_sock *vsk, + size_t target, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_pre_dequeue( + &vsk->sk, target, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_recv_post_dequeue( + struct vsock_sock *vsk, + size_t target, + ssize_t copied, + bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->recv_post_dequeue( + &vsk->sk, target, copied, data_read, + (struct vmci_transport_recv_notify_data *)data); +} + +static int vmci_transport_notify_send_init( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_init( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int 
vmci_transport_notify_send_pre_block( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_pre_block( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int vmci_transport_notify_send_pre_enqueue( + struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_pre_enqueue( + &vsk->sk, + (struct vmci_transport_send_notify_data *)data); +} + +static int vmci_transport_notify_send_post_enqueue( + struct vsock_sock *vsk, + ssize_t written, + struct vsock_transport_send_notify_data *data) +{ + return vmci_trans(vsk)->notify_ops->send_post_enqueue( + &vsk->sk, written, + (struct vmci_transport_send_notify_data *)data); +} + +static bool vmci_transport_old_proto_override(bool *old_pkt_proto) +{ + if (PROTOCOL_OVERRIDE != -1) { + if (PROTOCOL_OVERRIDE == 0) + *old_pkt_proto = true; + else + *old_pkt_proto = false; + + pr_info("Proto override in use\n"); + return true; + } + + return false; +} + +static bool vmci_transport_proto_to_notify_struct(struct sock *sk, + u16 *proto, + bool old_pkt_proto) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (old_pkt_proto) { + if (*proto != VSOCK_PROTO_INVALID) { + pr_err("Can't set both an old and new protocol\n"); + return false; + } + vmci_trans(vsk)->notify_ops = &vmci_transport_notify_pkt_ops; + goto exit; + } + + switch (*proto) { + case VSOCK_PROTO_PKT_ON_NOTIFY: + vmci_trans(vsk)->notify_ops = + &vmci_transport_notify_pkt_q_state_ops; + break; + default: + pr_err("Unknown notify protocol version\n"); + return false; + } + +exit: + vmci_trans(vsk)->notify_ops->socket_init(sk); + return true; +} + +static u16 vmci_transport_new_proto_supported_versions(void) +{ + if (PROTOCOL_OVERRIDE != -1) + return PROTOCOL_OVERRIDE; + + return VSOCK_PROTO_ALL_SUPPORTED; +} + +static u32 vmci_transport_get_local_cid(void) +{ + return vmci_get_context_id(); +} + +static struct vsock_transport vmci_transport = { + .module = THIS_MODULE, + .init = vmci_transport_socket_init, + .destruct = vmci_transport_destruct, + .release = vmci_transport_release, + .connect = vmci_transport_connect, + .dgram_bind = vmci_transport_dgram_bind, + .dgram_dequeue = vmci_transport_dgram_dequeue, + .dgram_enqueue = vmci_transport_dgram_enqueue, + .dgram_allow = vmci_transport_dgram_allow, + .stream_dequeue = vmci_transport_stream_dequeue, + .stream_enqueue = vmci_transport_stream_enqueue, + .stream_has_data = vmci_transport_stream_has_data, + .stream_has_space = vmci_transport_stream_has_space, + .stream_rcvhiwat = vmci_transport_stream_rcvhiwat, + .stream_is_active = vmci_transport_stream_is_active, + .stream_allow = vmci_transport_stream_allow, + .notify_poll_in = vmci_transport_notify_poll_in, + .notify_poll_out = vmci_transport_notify_poll_out, + .notify_recv_init = vmci_transport_notify_recv_init, + .notify_recv_pre_block = vmci_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = vmci_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = vmci_transport_notify_recv_post_dequeue, + .notify_send_init = vmci_transport_notify_send_init, + .notify_send_pre_block = vmci_transport_notify_send_pre_block, + .notify_send_pre_enqueue = vmci_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = vmci_transport_notify_send_post_enqueue, + .shutdown = vmci_transport_shutdown, + .get_local_cid = vmci_transport_get_local_cid, +}; + +static bool vmci_check_transport(struct vsock_sock *vsk) +{ + return vsk->transport == 
&vmci_transport; +} + +static void vmci_vsock_transport_cb(bool is_host) +{ + int features; + + if (is_host) + features = VSOCK_TRANSPORT_F_H2G; + else + features = VSOCK_TRANSPORT_F_G2H; + + vsock_core_register(&vmci_transport, features); +} + +static int __init vmci_transport_init(void) +{ + int err; + + /* Create the datagram handle that we will use to send and receive all + * VSocket control messages for this context. + */ + err = vmci_transport_datagram_create_hnd(VMCI_TRANSPORT_PACKET_RID, + VMCI_FLAG_ANYCID_DG_HND, + vmci_transport_recv_stream_cb, + NULL, + &vmci_transport_stream_handle); + if (err < VMCI_SUCCESS) { + pr_err("Unable to create datagram handle. (%d)\n", err); + return vmci_transport_error_to_vsock_error(err); + } + err = vmci_event_subscribe(VMCI_EVENT_QP_RESUMED, + vmci_transport_qp_resumed_cb, + NULL, &vmci_transport_qp_resumed_sub_id); + if (err < VMCI_SUCCESS) { + pr_err("Unable to subscribe to resumed event. (%d)\n", err); + err = vmci_transport_error_to_vsock_error(err); + vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + goto err_destroy_stream_handle; + } + + /* Register only with dgram feature, other features (H2G, G2H) will be + * registered when the first host or guest becomes active. + */ + err = vsock_core_register(&vmci_transport, VSOCK_TRANSPORT_F_DGRAM); + if (err < 0) + goto err_unsubscribe; + + err = vmci_register_vsock_callback(vmci_vsock_transport_cb); + if (err < 0) + goto err_unregister; + + return 0; + +err_unregister: + vsock_core_unregister(&vmci_transport); +err_unsubscribe: + vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); +err_destroy_stream_handle: + vmci_datagram_destroy_handle(vmci_transport_stream_handle); + return err; +} +module_init(vmci_transport_init); + +static void __exit vmci_transport_exit(void) +{ + cancel_work_sync(&vmci_transport_cleanup_work); + vmci_transport_free_resources(&vmci_transport_cleanup_list); + + if (!vmci_handle_is_invalid(vmci_transport_stream_handle)) { + if (vmci_datagram_destroy_handle( + vmci_transport_stream_handle) != VMCI_SUCCESS) + pr_err("Couldn't destroy datagram handle\n"); + vmci_transport_stream_handle = VMCI_INVALID_HANDLE; + } + + if (vmci_transport_qp_resumed_sub_id != VMCI_INVALID_ID) { + vmci_event_unsubscribe(vmci_transport_qp_resumed_sub_id); + vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID; + } + + vmci_register_vsock_callback(NULL); + vsock_core_unregister(&vmci_transport); +} +module_exit(vmci_transport_exit); + +MODULE_AUTHOR("VMware, Inc."); +MODULE_DESCRIPTION("VMCI transport for Virtual Sockets"); +MODULE_VERSION("1.0.5.0-k"); +MODULE_LICENSE("GPL v2"); +MODULE_ALIAS("vmware_vsock"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/vmw_vsock/vmci_transport.h b/net/vmw_vsock/vmci_transport.h new file mode 100644 index 000000000..b7b072194 --- /dev/null +++ b/net/vmw_vsock/vmci_transport.h @@ -0,0 +1,133 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * VMware vSockets Driver + * + * Copyright (C) 2013 VMware, Inc. All rights reserved. + */ + +#ifndef _VMCI_TRANSPORT_H_ +#define _VMCI_TRANSPORT_H_ + +#include +#include + +#include +#include + +/* If the packet format changes in a release then this should change too. */ +#define VMCI_TRANSPORT_PACKET_VERSION 1 + +/* The resource ID on which control packets are sent. */ +#define VMCI_TRANSPORT_PACKET_RID 1 + +/* The resource ID on which control packets are sent to the hypervisor. 
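+ * Control packets to regular peer contexts use
+ * VMCI_TRANSPORT_PACKET_RID above; packets destined for the
+ * hypervisor are sent to this resource ID instead.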
*/ +#define VMCI_TRANSPORT_HYPERVISOR_PACKET_RID 15 + +#define VSOCK_PROTO_INVALID 0 +#define VSOCK_PROTO_PKT_ON_NOTIFY (1 << 0) +#define VSOCK_PROTO_ALL_SUPPORTED (VSOCK_PROTO_PKT_ON_NOTIFY) + +#define vmci_trans(_vsk) ((struct vmci_transport *)((_vsk)->trans)) + +enum vmci_transport_packet_type { + VMCI_TRANSPORT_PACKET_TYPE_INVALID = 0, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE, + VMCI_TRANSPORT_PACKET_TYPE_OFFER, + VMCI_TRANSPORT_PACKET_TYPE_ATTACH, + VMCI_TRANSPORT_PACKET_TYPE_WROTE, + VMCI_TRANSPORT_PACKET_TYPE_READ, + VMCI_TRANSPORT_PACKET_TYPE_RST, + VMCI_TRANSPORT_PACKET_TYPE_SHUTDOWN, + VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE, + VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ, + VMCI_TRANSPORT_PACKET_TYPE_REQUEST2, + VMCI_TRANSPORT_PACKET_TYPE_NEGOTIATE2, + VMCI_TRANSPORT_PACKET_TYPE_MAX +}; + +struct vmci_transport_waiting_info { + u64 generation; + u64 offset; +}; + +/* Control packet type for STREAM sockets. DGRAMs have no control packets nor + * special packet header for data packets, they are just raw VMCI DGRAM + * messages. For STREAMs, control packets are sent over the control channel + * while data is written and read directly from queue pairs with no packet + * format. + */ +struct vmci_transport_packet { + struct vmci_datagram dg; + u8 version; + u8 type; + u16 proto; + u32 src_port; + u32 dst_port; + u32 _reserved2; + union { + u64 size; + u64 mode; + struct vmci_handle handle; + struct vmci_transport_waiting_info wait; + } u; +}; + +struct vmci_transport_notify_pkt { + u64 write_notify_window; + u64 write_notify_min_window; + bool peer_waiting_read; + bool peer_waiting_write; + bool peer_waiting_write_detected; + bool sent_waiting_read; + bool sent_waiting_write; + struct vmci_transport_waiting_info peer_waiting_read_info; + struct vmci_transport_waiting_info peer_waiting_write_info; + u64 produce_q_generation; + u64 consume_q_generation; +}; + +struct vmci_transport_notify_pkt_q_state { + u64 write_notify_window; + u64 write_notify_min_window; + bool peer_waiting_write; + bool peer_waiting_write_detected; +}; + +union vmci_transport_notify { + struct vmci_transport_notify_pkt pkt; + struct vmci_transport_notify_pkt_q_state pkt_q_state; +}; + +/* Our transport-specific data. */ +struct vmci_transport { + /* For DGRAMs. */ + struct vmci_handle dg_handle; + /* For STREAMs. */ + struct vmci_handle qp_handle; + struct vmci_qp *qpair; + u64 produce_size; + u64 consume_size; + u32 detach_sub_id; + union vmci_transport_notify notify; + const struct vmci_transport_notify_ops *notify_ops; + struct list_head elem; + struct sock *sk; + spinlock_t lock; /* protects sk. */ +}; + +int vmci_transport_register(void); +void vmci_transport_unregister(void); + +int vmci_transport_send_wrote_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src); +int vmci_transport_send_read_bh(struct sockaddr_vm *dst, + struct sockaddr_vm *src); +int vmci_transport_send_wrote(struct sock *sk); +int vmci_transport_send_read(struct sock *sk); +int vmci_transport_send_waiting_write(struct sock *sk, + struct vmci_transport_waiting_info *wait); +int vmci_transport_send_waiting_read(struct sock *sk, + struct vmci_transport_waiting_info *wait); + +#endif diff --git a/net/vmw_vsock/vmci_transport_notify.c b/net/vmw_vsock/vmci_transport_notify.c new file mode 100644 index 000000000..7c3a7db13 --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify.c @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. 
All rights reserved. + */ + +#include +#include +#include +#include + +#include "vmci_transport_notify.h" + +#define PKT_FIELD(vsk, field_name) (vmci_trans(vsk)->notify.pkt.field_name) + +static bool vmci_transport_notify_waiting_write(struct vsock_sock *vsk) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + bool retval; + u64 notify_limit; + + if (!PKT_FIELD(vsk, peer_waiting_write)) + return false; + +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + /* When the sender blocks, we take that as a sign that the sender is + * faster than the receiver. To reduce the transmit rate of the sender, + * we delay the sending of the read notification by decreasing the + * write_notify_window. The notification is delayed until the number of + * bytes used in the queue drops below the write_notify_window. + */ + + if (!PKT_FIELD(vsk, peer_waiting_write_detected)) { + PKT_FIELD(vsk, peer_waiting_write_detected) = true; + if (PKT_FIELD(vsk, write_notify_window) < PAGE_SIZE) { + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + } else { + PKT_FIELD(vsk, write_notify_window) -= PAGE_SIZE; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + + } + } + notify_limit = vmci_trans(vsk)->consume_size - + PKT_FIELD(vsk, write_notify_window); +#else + notify_limit = 0; +#endif + + /* For now we ignore the wait information and just see if the free + * space exceeds the notify limit. Note that improving this function + * to be more intelligent will not require a protocol change and will + * retain compatibility between endpoints with mixed versions of this + * function. + * + * The notify_limit is used to delay notifications in the case where + * flow control is enabled. Below the test is expressed in terms of + * free space in the queue: if free_space > ConsumeSize - + * write_notify_window then notify An alternate way of expressing this + * is to rewrite the expression to use the data ready in the receive + * queue: if write_notify_window > bufferReady then notify as + * free_space == ConsumeSize - bufferReady. + */ + retval = vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair) > + notify_limit; +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + if (retval) { + /* + * Once we notify the peer, we reset the detected flag so the + * next wait will again cause a decrease in the window size. + */ + + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + } +#endif + return retval; +#else + return true; +#endif +} + +static bool vmci_transport_notify_waiting_read(struct vsock_sock *vsk) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + if (!PKT_FIELD(vsk, peer_waiting_read)) + return false; + + /* For now we ignore the wait information and just see if there is any + * data for our peer to read. Note that improving this function to be + * more intelligent will not require a protocol change and will retain + * compatibility between endpoints with mixed versions of this + * function. 
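+ * I.e. the peer is considered ready to be notified as soon as a
+ * single byte is available in our produce queue, regardless of the
+ * offset and generation it reported in its WAITING_READ.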
+ */ + return vmci_qpair_produce_buf_ready(vmci_trans(vsk)->qpair) > 0; +#else + return true; +#endif +} + +static void +vmci_transport_handle_waiting_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + PKT_FIELD(vsk, peer_waiting_read) = true; + memcpy(&PKT_FIELD(vsk, peer_waiting_read_info), &pkt->u.wait, + sizeof(PKT_FIELD(vsk, peer_waiting_read_info))); + + if (vmci_transport_notify_waiting_read(vsk)) { + bool sent; + + if (bottom_half) + sent = vmci_transport_send_wrote_bh(dst, src) > 0; + else + sent = vmci_transport_send_wrote(sk) > 0; + + if (sent) + PKT_FIELD(vsk, peer_waiting_read) = false; + } +#endif +} + +static void +vmci_transport_handle_waiting_write(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + + PKT_FIELD(vsk, peer_waiting_write) = true; + memcpy(&PKT_FIELD(vsk, peer_waiting_write_info), &pkt->u.wait, + sizeof(PKT_FIELD(vsk, peer_waiting_write_info))); + + if (vmci_transport_notify_waiting_write(vsk)) { + bool sent; + + if (bottom_half) + sent = vmci_transport_send_read_bh(dst, src) > 0; + else + sent = vmci_transport_send_read(sk) > 0; + + if (sent) + PKT_FIELD(vsk, peer_waiting_write) = false; + } +#endif +} + +static void +vmci_transport_handle_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + + vsk = vsock_sk(sk); + PKT_FIELD(vsk, sent_waiting_write) = false; +#endif + + sk->sk_write_space(sk); +} + +static bool send_waiting_read(struct sock *sk, u64 room_needed) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + struct vmci_transport_waiting_info waiting_info; + u64 tail; + u64 head; + u64 room_left; + bool ret; + + vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, sent_waiting_read)) + return true; + + if (PKT_FIELD(vsk, write_notify_window) < + vmci_trans(vsk)->consume_size) + PKT_FIELD(vsk, write_notify_window) = + min(PKT_FIELD(vsk, write_notify_window) + PAGE_SIZE, + vmci_trans(vsk)->consume_size); + + vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair, &tail, &head); + room_left = vmci_trans(vsk)->consume_size - head; + if (room_needed >= room_left) { + waiting_info.offset = room_needed - room_left; + waiting_info.generation = + PKT_FIELD(vsk, consume_q_generation) + 1; + } else { + waiting_info.offset = head + room_needed; + waiting_info.generation = PKT_FIELD(vsk, consume_q_generation); + } + + ret = vmci_transport_send_waiting_read(sk, &waiting_info) > 0; + if (ret) + PKT_FIELD(vsk, sent_waiting_read) = true; + + return ret; +#else + return true; +#endif +} + +static bool send_waiting_write(struct sock *sk, u64 room_needed) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk; + struct vmci_transport_waiting_info waiting_info; + u64 tail; + u64 head; + u64 room_left; + bool ret; + + vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, sent_waiting_write)) + return true; + + vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair, &tail, &head); + room_left = vmci_trans(vsk)->produce_size - tail; + if (room_needed + 1 >= room_left) { + /* Wraps around to current generation. 
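+ * For example, with a 64KB produce queue, a tail at 62KB and a
+ * room_needed of 4KB, room_left is 2KB, so the wait point lands at
+ * offset 4097 - 2048 = 2049 of the current generation.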
*/ + waiting_info.offset = room_needed + 1 - room_left; + waiting_info.generation = PKT_FIELD(vsk, produce_q_generation); + } else { + waiting_info.offset = tail + room_needed + 1; + waiting_info.generation = + PKT_FIELD(vsk, produce_q_generation) - 1; + } + + ret = vmci_transport_send_waiting_write(sk, &waiting_info) > 0; + if (ret) + PKT_FIELD(vsk, sent_waiting_write) = true; + + return ret; +#else + return true; +#endif +} + +static int vmci_transport_send_read_notification(struct sock *sk) +{ + struct vsock_sock *vsk; + bool sent_read; + unsigned int retries; + int err; + + vsk = vsock_sk(sk); + sent_read = false; + retries = 0; + err = 0; + + if (vmci_transport_notify_waiting_write(vsk)) { + /* Notify the peer that we have read, retrying the send on + * failure up to our maximum value. XXX For now we just log + * the failure, but later we should schedule a work item to + * handle the resend until it succeeds. That would require + * keeping track of work items in the vsk and cleaning them up + * upon socket close. + */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_read && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_read(sk); + if (err >= 0) + sent_read = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) + pr_err("%p unable to send read notify to peer\n", sk); + else +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + PKT_FIELD(vsk, peer_waiting_write) = false; +#endif + + } + return err; +} + +static void +vmci_transport_handle_wrote(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + struct vsock_sock *vsk = vsock_sk(sk); + PKT_FIELD(vsk, sent_waiting_read) = false; +#endif + vsock_data_ready(sk); +} + +static void vmci_transport_notify_pkt_socket_init(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_read) = false; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + PKT_FIELD(vsk, sent_waiting_read) = false; + PKT_FIELD(vsk, sent_waiting_write) = false; + PKT_FIELD(vsk, produce_q_generation) = 0; + PKT_FIELD(vsk, consume_q_generation) = 0; + + memset(&PKT_FIELD(vsk, peer_waiting_read_info), 0, + sizeof(PKT_FIELD(vsk, peer_waiting_read_info))); + memset(&PKT_FIELD(vsk, peer_waiting_write_info), 0, + sizeof(PKT_FIELD(vsk, peer_waiting_write_info))); +} + +static void vmci_transport_notify_pkt_socket_destruct(struct vsock_sock *vsk) +{ +} + +static int +vmci_transport_notify_pkt_poll_in(struct sock *sk, + size_t target, bool *data_ready_now) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= target) { + *data_ready_now = true; + } else { + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. 
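+ * A WAITING_READ asking for a single byte is enough here: any write
+ * by the peer will then trigger a WROTE notification back to us.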
+ */ + if (sk->sk_state == TCP_ESTABLISHED) { + if (!send_waiting_read(sk, 1)) + return -1; + + } + *data_ready_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_poll_out(struct sock *sk, + size_t target, bool *space_avail_now) +{ + s64 produce_q_free_space; + struct vsock_sock *vsk = vsock_sk(sk); + + produce_q_free_space = vsock_stream_has_space(vsk); + if (produce_q_free_space > 0) { + *space_avail_now = true; + return 0; + } else if (produce_q_free_space == 0) { + /* This is a connected socket but we can't currently send data. + * Notify the peer that we are waiting if the queue is full. We + * only send a waiting write if the queue is full because + * otherwise we end up in an infinite WAITING_WRITE, READ, + * WAITING_WRITE, READ, etc. loop. Treat failing to send the + * notification as a socket error, passing that back through + * the mask. + */ + if (!send_waiting_write(sk, 1)) + return -1; + + *space_avail_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_init( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + +#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY + data->consume_head = 0; + data->produce_tail = 0; +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + data->notify_on_block = false; + + if (PKT_FIELD(vsk, write_notify_min_window) < target + 1) { + PKT_FIELD(vsk, write_notify_min_window) = target + 1; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) { + /* If the current window is smaller than the new + * minimal window size, we need to reevaluate whether + * we need to notify the sender. If the number of ready + * bytes are smaller than the new window, we need to + * send a notification to the sender before we block. + */ + + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + data->notify_on_block = true; + } + } +#endif +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_pre_block( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + int err = 0; + + /* Notify our peer that we are waiting for data to read. */ + if (!send_waiting_read(sk, target)) { + err = -EHOSTUNREACH; + return err; + } +#ifdef VSOCK_OPTIMIZATION_FLOW_CONTROL + if (data->notify_on_block) { + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + data->notify_on_block = false; + } +#endif + + return err; +} + +static int +vmci_transport_notify_pkt_recv_pre_dequeue( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + /* Now consume up to len bytes from the queue. Note that since we have + * the socket locked we should copy at least ready bytes. + */ +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + vmci_qpair_get_consume_indexes(vmci_trans(vsk)->qpair, + &data->produce_tail, + &data->consume_head); +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_post_dequeue( + struct sock *sk, + size_t target, + ssize_t copied, + bool data_read, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk; + int err; + + vsk = vsock_sk(sk); + err = 0; + + if (data_read) { +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + /* Detect a wrap-around to maintain queue generation. Note + * that this is safe since we hold the socket lock across the + * two queue pair operations. 
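+ * For example, with a 64KB consume queue and consume_head at 60KB,
+ * copying 8KB runs past the end of the queue, so the consume
+ * generation is bumped below.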
+ */ + if (copied >= + vmci_trans(vsk)->consume_size - data->consume_head) + PKT_FIELD(vsk, consume_q_generation)++; +#endif + + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + } + return err; +} + +static int +vmci_transport_notify_pkt_send_init( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ +#ifdef VSOCK_OPTIMIZATION_WAITING_NOTIFY + data->consume_head = 0; + data->produce_tail = 0; +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_send_pre_block( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + /* Notify our peer that we are waiting for room to write. */ + if (!send_waiting_write(sk, 1)) + return -EHOSTUNREACH; + + return 0; +} + +static int +vmci_transport_notify_pkt_send_pre_enqueue( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + vmci_qpair_get_produce_indexes(vmci_trans(vsk)->qpair, + &data->produce_tail, + &data->consume_head); +#endif + + return 0; +} + +static int +vmci_transport_notify_pkt_send_post_enqueue( + struct sock *sk, + ssize_t written, + struct vmci_transport_send_notify_data *data) +{ + int err = 0; + struct vsock_sock *vsk; + bool sent_wrote = false; + int retries = 0; + + vsk = vsock_sk(sk); + +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + /* Detect a wrap-around to maintain queue generation. Note that this + * is safe since we hold the socket lock across the two queue pair + * operations. + */ + if (written >= vmci_trans(vsk)->produce_size - data->produce_tail) + PKT_FIELD(vsk, produce_q_generation)++; + +#endif + + if (vmci_transport_notify_waiting_read(vsk)) { + /* Notify the peer that we have written, retrying the send on + * failure up to our maximum value. See the XXX comment for the + * corresponding piece of code in StreamRecvmsg() for potential + * improvements. 
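+ * As on the read side, we give up after
+ * VMCI_TRANSPORT_MAX_DGRAM_RESENDS attempts and report the failure.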
+ */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_wrote && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_wrote(sk); + if (err >= 0) + sent_wrote = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + pr_err("%p unable to send wrote notify to peer\n", sk); + return err; + } else { +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) + PKT_FIELD(vsk, peer_waiting_read) = false; +#endif + } + } + return err; +} + +static void +vmci_transport_notify_pkt_handle_pkt( + struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src, bool *pkt_processed) +{ + bool processed = false; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + vmci_transport_handle_wrote(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_READ: + vmci_transport_handle_read(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_WRITE: + vmci_transport_handle_waiting_write(sk, pkt, bottom_half, + dst, src); + processed = true; + break; + + case VMCI_TRANSPORT_PACKET_TYPE_WAITING_READ: + vmci_transport_handle_waiting_read(sk, pkt, bottom_half, + dst, src); + processed = true; + break; + } + + if (pkt_processed) + *pkt_processed = processed; +} + +static void vmci_transport_notify_pkt_process_request(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static void vmci_transport_notify_pkt_process_negotiate(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +/* Socket control packet based operations. */ +const struct vmci_transport_notify_ops vmci_transport_notify_pkt_ops = { + .socket_init = vmci_transport_notify_pkt_socket_init, + .socket_destruct = vmci_transport_notify_pkt_socket_destruct, + .poll_in = vmci_transport_notify_pkt_poll_in, + .poll_out = vmci_transport_notify_pkt_poll_out, + .handle_notify_pkt = vmci_transport_notify_pkt_handle_pkt, + .recv_init = vmci_transport_notify_pkt_recv_init, + .recv_pre_block = vmci_transport_notify_pkt_recv_pre_block, + .recv_pre_dequeue = vmci_transport_notify_pkt_recv_pre_dequeue, + .recv_post_dequeue = vmci_transport_notify_pkt_recv_post_dequeue, + .send_init = vmci_transport_notify_pkt_send_init, + .send_pre_block = vmci_transport_notify_pkt_send_pre_block, + .send_pre_enqueue = vmci_transport_notify_pkt_send_pre_enqueue, + .send_post_enqueue = vmci_transport_notify_pkt_send_post_enqueue, + .process_request = vmci_transport_notify_pkt_process_request, + .process_negotiate = vmci_transport_notify_pkt_process_negotiate, +}; diff --git a/net/vmw_vsock/vmci_transport_notify.h b/net/vmw_vsock/vmci_transport_notify.h new file mode 100644 index 000000000..a1aa5a998 --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify.h @@ -0,0 +1,75 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. All rights reserved. 
+ */ + +#ifndef __VMCI_TRANSPORT_NOTIFY_H__ +#define __VMCI_TRANSPORT_NOTIFY_H__ + +#include +#include +#include + +#include "vmci_transport.h" + +/* Comment this out to compare with old protocol. */ +#define VSOCK_OPTIMIZATION_WAITING_NOTIFY 1 +#if defined(VSOCK_OPTIMIZATION_WAITING_NOTIFY) +/* Comment this out to remove flow control for "new" protocol */ +#define VSOCK_OPTIMIZATION_FLOW_CONTROL 1 +#endif + +#define VMCI_TRANSPORT_MAX_DGRAM_RESENDS 10 + +struct vmci_transport_recv_notify_data { + u64 consume_head; + u64 produce_tail; + bool notify_on_block; +}; + +struct vmci_transport_send_notify_data { + u64 consume_head; + u64 produce_tail; +}; + +/* Socket notification callbacks. */ +struct vmci_transport_notify_ops { + void (*socket_init) (struct sock *sk); + void (*socket_destruct) (struct vsock_sock *vsk); + int (*poll_in) (struct sock *sk, size_t target, + bool *data_ready_now); + int (*poll_out) (struct sock *sk, size_t target, + bool *space_avail_now); + void (*handle_notify_pkt) (struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, struct sockaddr_vm *dst, + struct sockaddr_vm *src, + bool *pkt_processed); + int (*recv_init) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_pre_block) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_pre_dequeue) (struct sock *sk, size_t target, + struct vmci_transport_recv_notify_data *data); + int (*recv_post_dequeue) (struct sock *sk, size_t target, + ssize_t copied, bool data_read, + struct vmci_transport_recv_notify_data *data); + int (*send_init) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_pre_block) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_pre_enqueue) (struct sock *sk, + struct vmci_transport_send_notify_data *data); + int (*send_post_enqueue) (struct sock *sk, ssize_t written, + struct vmci_transport_send_notify_data *data); + void (*process_request) (struct sock *sk); + void (*process_negotiate) (struct sock *sk); +}; + +extern const struct vmci_transport_notify_ops vmci_transport_notify_pkt_ops; +extern const +struct vmci_transport_notify_ops vmci_transport_notify_pkt_q_state_ops; + +#endif /* __VMCI_TRANSPORT_NOTIFY_H__ */ diff --git a/net/vmw_vsock/vmci_transport_notify_qstate.c b/net/vmw_vsock/vmci_transport_notify_qstate.c new file mode 100644 index 000000000..e96a88d85 --- /dev/null +++ b/net/vmw_vsock/vmci_transport_notify_qstate.c @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2009-2013 VMware, Inc. All rights reserved. + */ + +#include +#include +#include +#include + +#include "vmci_transport_notify.h" + +#define PKT_FIELD(vsk, field_name) \ + (vmci_trans(vsk)->notify.pkt_q_state.field_name) + +static bool vmci_transport_notify_waiting_write(struct vsock_sock *vsk) +{ + bool retval; + u64 notify_limit; + + if (!PKT_FIELD(vsk, peer_waiting_write)) + return false; + + /* When the sender blocks, we take that as a sign that the sender is + * faster than the receiver. To reduce the transmit rate of the sender, + * we delay the sending of the read notification by decreasing the + * write_notify_window. The notification is delayed until the number of + * bytes used in the queue drops below the write_notify_window. 
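+ * For example, each new block report from the sender shrinks the
+ * window by one page (never below write_notify_min_window), so a
+ * consistently faster sender is notified only after progressively
+ * more of the queue has been drained.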
+ */ + + if (!PKT_FIELD(vsk, peer_waiting_write_detected)) { + PKT_FIELD(vsk, peer_waiting_write_detected) = true; + if (PKT_FIELD(vsk, write_notify_window) < PAGE_SIZE) { + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + } else { + PKT_FIELD(vsk, write_notify_window) -= PAGE_SIZE; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + + } + } + notify_limit = vmci_trans(vsk)->consume_size - + PKT_FIELD(vsk, write_notify_window); + + /* The notify_limit is used to delay notifications in the case where + * flow control is enabled. Below the test is expressed in terms of + * free space in the queue: if free_space > ConsumeSize - + * write_notify_window then notify An alternate way of expressing this + * is to rewrite the expression to use the data ready in the receive + * queue: if write_notify_window > bufferReady then notify as + * free_space == ConsumeSize - bufferReady. + */ + + retval = vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair) > + notify_limit; + + if (retval) { + /* Once we notify the peer, we reset the detected flag so the + * next wait will again cause a decrease in the window size. + */ + + PKT_FIELD(vsk, peer_waiting_write_detected) = false; + } + return retval; +} + +static void +vmci_transport_handle_read(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ + sk->sk_write_space(sk); +} + +static void +vmci_transport_handle_wrote(struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, struct sockaddr_vm *src) +{ + vsock_data_ready(sk); +} + +static void vsock_block_update_write_window(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (PKT_FIELD(vsk, write_notify_window) < vmci_trans(vsk)->consume_size) + PKT_FIELD(vsk, write_notify_window) = + min(PKT_FIELD(vsk, write_notify_window) + PAGE_SIZE, + vmci_trans(vsk)->consume_size); +} + +static int vmci_transport_send_read_notification(struct sock *sk) +{ + struct vsock_sock *vsk; + bool sent_read; + unsigned int retries; + int err; + + vsk = vsock_sk(sk); + sent_read = false; + retries = 0; + err = 0; + + if (vmci_transport_notify_waiting_write(vsk)) { + /* Notify the peer that we have read, retrying the send on + * failure up to our maximum value. XXX For now we just log + * the failure, but later we should schedule a work item to + * handle the resend until it succeeds. That would require + * keeping track of work items in the vsk and cleaning them up + * upon socket close. 
+ */ + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_read && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_read(sk); + if (err >= 0) + sent_read = true; + + retries++; + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS && !sent_read) + pr_err("%p unable to send read notification to peer\n", + sk); + else + PKT_FIELD(vsk, peer_waiting_write) = false; + + } + return err; +} + +static void vmci_transport_notify_pkt_socket_init(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; +} + +static void vmci_transport_notify_pkt_socket_destruct(struct vsock_sock *vsk) +{ + PKT_FIELD(vsk, write_notify_window) = PAGE_SIZE; + PKT_FIELD(vsk, write_notify_min_window) = PAGE_SIZE; + PKT_FIELD(vsk, peer_waiting_write) = false; + PKT_FIELD(vsk, peer_waiting_write_detected) = false; +} + +static int +vmci_transport_notify_pkt_poll_in(struct sock *sk, + size_t target, bool *data_ready_now) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (vsock_stream_has_data(vsk) >= target) { + *data_ready_now = true; + } else { + /* We can't read right now because there is not enough data + * in the queue. Ask for notifications when there is something + * to read. + */ + if (sk->sk_state == TCP_ESTABLISHED) + vsock_block_update_write_window(sk); + *data_ready_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_poll_out(struct sock *sk, + size_t target, bool *space_avail_now) +{ + s64 produce_q_free_space; + struct vsock_sock *vsk = vsock_sk(sk); + + produce_q_free_space = vsock_stream_has_space(vsk); + if (produce_q_free_space > 0) { + *space_avail_now = true; + return 0; + } else if (produce_q_free_space == 0) { + /* This is a connected socket but we can't currently send data. + * Nothing else to do. + */ + *space_avail_now = false; + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_init( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + data->consume_head = 0; + data->produce_tail = 0; + data->notify_on_block = false; + + if (PKT_FIELD(vsk, write_notify_min_window) < target + 1) { + PKT_FIELD(vsk, write_notify_min_window) = target + 1; + if (PKT_FIELD(vsk, write_notify_window) < + PKT_FIELD(vsk, write_notify_min_window)) { + /* If the current window is smaller than the new + * minimal window size, we need to reevaluate whether + * we need to notify the sender. If the number of ready + * bytes are smaller than the new window, we need to + * send a notification to the sender before we block. 
+ */ + + PKT_FIELD(vsk, write_notify_window) = + PKT_FIELD(vsk, write_notify_min_window); + data->notify_on_block = true; + } + } + + return 0; +} + +static int +vmci_transport_notify_pkt_recv_pre_block( + struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + int err = 0; + + vsock_block_update_write_window(sk); + + if (data->notify_on_block) { + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + data->notify_on_block = false; + } + + return err; +} + +static int +vmci_transport_notify_pkt_recv_post_dequeue( + struct sock *sk, + size_t target, + ssize_t copied, + bool data_read, + struct vmci_transport_recv_notify_data *data) +{ + struct vsock_sock *vsk; + int err; + bool was_full = false; + u64 free_space; + + vsk = vsock_sk(sk); + err = 0; + + if (data_read) { + smp_mb(); + + free_space = + vmci_qpair_consume_free_space(vmci_trans(vsk)->qpair); + was_full = free_space == copied; + + if (was_full) + PKT_FIELD(vsk, peer_waiting_write) = true; + + err = vmci_transport_send_read_notification(sk); + if (err < 0) + return err; + + /* See the comment in + * vmci_transport_notify_pkt_send_post_enqueue(). + */ + vsock_data_ready(sk); + } + + return err; +} + +static int +vmci_transport_notify_pkt_send_init( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + data->consume_head = 0; + data->produce_tail = 0; + + return 0; +} + +static int +vmci_transport_notify_pkt_send_post_enqueue( + struct sock *sk, + ssize_t written, + struct vmci_transport_send_notify_data *data) +{ + int err = 0; + struct vsock_sock *vsk; + bool sent_wrote = false; + bool was_empty; + int retries = 0; + + vsk = vsock_sk(sk); + + smp_mb(); + + was_empty = + vmci_qpair_produce_buf_ready(vmci_trans(vsk)->qpair) == written; + if (was_empty) { + while (!(vsk->peer_shutdown & RCV_SHUTDOWN) && + !sent_wrote && + retries < VMCI_TRANSPORT_MAX_DGRAM_RESENDS) { + err = vmci_transport_send_wrote(sk); + if (err >= 0) + sent_wrote = true; + + retries++; + } + } + + if (retries >= VMCI_TRANSPORT_MAX_DGRAM_RESENDS && !sent_wrote) { + pr_err("%p unable to send wrote notification to peer\n", + sk); + return err; + } + + return err; +} + +static void +vmci_transport_notify_pkt_handle_pkt( + struct sock *sk, + struct vmci_transport_packet *pkt, + bool bottom_half, + struct sockaddr_vm *dst, + struct sockaddr_vm *src, bool *pkt_processed) +{ + bool processed = false; + + switch (pkt->type) { + case VMCI_TRANSPORT_PACKET_TYPE_WROTE: + vmci_transport_handle_wrote(sk, pkt, bottom_half, dst, src); + processed = true; + break; + case VMCI_TRANSPORT_PACKET_TYPE_READ: + vmci_transport_handle_read(sk, pkt, bottom_half, dst, src); + processed = true; + break; + } + + if (pkt_processed) + *pkt_processed = processed; +} + +static void vmci_transport_notify_pkt_process_request(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static void vmci_transport_notify_pkt_process_negotiate(struct sock *sk) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + PKT_FIELD(vsk, write_notify_window) = vmci_trans(vsk)->consume_size; + if (vmci_trans(vsk)->consume_size < + PKT_FIELD(vsk, write_notify_min_window)) + PKT_FIELD(vsk, write_notify_min_window) = + vmci_trans(vsk)->consume_size; +} + +static int +vmci_transport_notify_pkt_recv_pre_dequeue( + 
struct sock *sk, + size_t target, + struct vmci_transport_recv_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +static int +vmci_transport_notify_pkt_send_pre_block( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +static int +vmci_transport_notify_pkt_send_pre_enqueue( + struct sock *sk, + struct vmci_transport_send_notify_data *data) +{ + return 0; /* NOP for QState. */ +} + +/* Socket always on control packet based operations. */ +const struct vmci_transport_notify_ops vmci_transport_notify_pkt_q_state_ops = { + .socket_init = vmci_transport_notify_pkt_socket_init, + .socket_destruct = vmci_transport_notify_pkt_socket_destruct, + .poll_in = vmci_transport_notify_pkt_poll_in, + .poll_out = vmci_transport_notify_pkt_poll_out, + .handle_notify_pkt = vmci_transport_notify_pkt_handle_pkt, + .recv_init = vmci_transport_notify_pkt_recv_init, + .recv_pre_block = vmci_transport_notify_pkt_recv_pre_block, + .recv_pre_dequeue = vmci_transport_notify_pkt_recv_pre_dequeue, + .recv_post_dequeue = vmci_transport_notify_pkt_recv_post_dequeue, + .send_init = vmci_transport_notify_pkt_send_init, + .send_pre_block = vmci_transport_notify_pkt_send_pre_block, + .send_pre_enqueue = vmci_transport_notify_pkt_send_pre_enqueue, + .send_post_enqueue = vmci_transport_notify_pkt_send_post_enqueue, + .process_request = vmci_transport_notify_pkt_process_request, + .process_negotiate = vmci_transport_notify_pkt_process_negotiate, +}; diff --git a/net/vmw_vsock/vsock_addr.c b/net/vmw_vsock/vsock_addr.c new file mode 100644 index 000000000..223b9660a --- /dev/null +++ b/net/vmw_vsock/vsock_addr.c @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * VMware vSockets Driver + * + * Copyright (C) 2007-2012 VMware, Inc. All rights reserved. 
+ */ + +#include +#include +#include +#include +#include + +void vsock_addr_init(struct sockaddr_vm *addr, u32 cid, u32 port) +{ + memset(addr, 0, sizeof(*addr)); + addr->svm_family = AF_VSOCK; + addr->svm_cid = cid; + addr->svm_port = port; +} +EXPORT_SYMBOL_GPL(vsock_addr_init); + +int vsock_addr_validate(const struct sockaddr_vm *addr) +{ + __u8 svm_valid_flags = VMADDR_FLAG_TO_HOST; + + if (!addr) + return -EFAULT; + + if (addr->svm_family != AF_VSOCK) + return -EAFNOSUPPORT; + + if (addr->svm_flags & ~svm_valid_flags) + return -EINVAL; + + return 0; +} +EXPORT_SYMBOL_GPL(vsock_addr_validate); + +bool vsock_addr_bound(const struct sockaddr_vm *addr) +{ + return addr->svm_port != VMADDR_PORT_ANY; +} +EXPORT_SYMBOL_GPL(vsock_addr_bound); + +void vsock_addr_unbind(struct sockaddr_vm *addr) +{ + vsock_addr_init(addr, VMADDR_CID_ANY, VMADDR_PORT_ANY); +} +EXPORT_SYMBOL_GPL(vsock_addr_unbind); + +bool vsock_addr_equals_addr(const struct sockaddr_vm *addr, + const struct sockaddr_vm *other) +{ + return addr->svm_cid == other->svm_cid && + addr->svm_port == other->svm_port; +} +EXPORT_SYMBOL_GPL(vsock_addr_equals_addr); + +int vsock_addr_cast(const struct sockaddr *addr, + size_t len, struct sockaddr_vm **out_addr) +{ + if (len < sizeof(**out_addr)) + return -EFAULT; + + *out_addr = (struct sockaddr_vm *)addr; + return vsock_addr_validate(*out_addr); +} +EXPORT_SYMBOL_GPL(vsock_addr_cast); diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c new file mode 100644 index 000000000..89905c092 --- /dev/null +++ b/net/vmw_vsock/vsock_loopback.c @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* loopback transport for vsock using virtio_transport_common APIs + * + * Copyright (C) 2013-2019 Red Hat, Inc. + * Authors: Asias He + * Stefan Hajnoczi + * Stefano Garzarella + * + */ +#include +#include +#include +#include + +struct vsock_loopback { + struct workqueue_struct *workqueue; + + struct sk_buff_head pkt_queue; + struct work_struct pkt_work; +}; + +static struct vsock_loopback the_vsock_loopback; + +static u32 vsock_loopback_get_local_cid(void) +{ + return VMADDR_CID_LOCAL; +} + +static int vsock_loopback_send_pkt(struct sk_buff *skb) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + int len = skb->len; + + skb_queue_tail(&vsock->pkt_queue, skb); + + queue_work(vsock->workqueue, &vsock->pkt_work); + + return len; +} + +static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + + virtio_transport_purge_skbs(vsk, &vsock->pkt_queue); + + return 0; +} + +static bool vsock_loopback_seqpacket_allow(u32 remote_cid); + +static struct virtio_transport loopback_transport = { + .transport = { + .module = THIS_MODULE, + + .get_local_cid = vsock_loopback_get_local_cid, + + .init = virtio_transport_do_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + .cancel_pkt = vsock_loopback_cancel_pkt, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = 
virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, + .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, + .seqpacket_allow = vsock_loopback_seqpacket_allow, + .seqpacket_has_data = virtio_transport_seqpacket_has_data, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + .notify_buffer_size = virtio_transport_notify_buffer_size, + }, + + .send_pkt = vsock_loopback_send_pkt, +}; + +static bool vsock_loopback_seqpacket_allow(u32 remote_cid) +{ + return true; +} + +static void vsock_loopback_work(struct work_struct *work) +{ + struct vsock_loopback *vsock = + container_of(work, struct vsock_loopback, pkt_work); + struct sk_buff_head pkts; + struct sk_buff *skb; + + skb_queue_head_init(&pkts); + + spin_lock_bh(&vsock->pkt_queue.lock); + skb_queue_splice_init(&vsock->pkt_queue, &pkts); + spin_unlock_bh(&vsock->pkt_queue.lock); + + while ((skb = __skb_dequeue(&pkts))) { + virtio_transport_deliver_tap_pkt(skb); + virtio_transport_recv_pkt(&loopback_transport, skb); + } +} + +static int __init vsock_loopback_init(void) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + int ret; + + vsock->workqueue = alloc_workqueue("vsock-loopback", 0, 0); + if (!vsock->workqueue) + return -ENOMEM; + + skb_queue_head_init(&vsock->pkt_queue); + INIT_WORK(&vsock->pkt_work, vsock_loopback_work); + + ret = vsock_core_register(&loopback_transport.transport, + VSOCK_TRANSPORT_F_LOCAL); + if (ret) + goto out_wq; + + return 0; + +out_wq: + destroy_workqueue(vsock->workqueue); + return ret; +} + +static void __exit vsock_loopback_exit(void) +{ + struct vsock_loopback *vsock = &the_vsock_loopback; + + vsock_core_unregister(&loopback_transport.transport); + + flush_work(&vsock->pkt_work); + + virtio_vsock_skb_queue_purge(&vsock->pkt_queue); + + destroy_workqueue(vsock->workqueue); +} + +module_init(vsock_loopback_init); +module_exit(vsock_loopback_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Stefano Garzarella "); +MODULE_DESCRIPTION("loopback transport for vsock"); +MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/wireless/.gitignore b/net/wireless/.gitignore new file mode 100644 index 000000000..1a29cd69d --- /dev/null +++ b/net/wireless/.gitignore @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: GPL-2.0-only +shipped-certs.c +extra-certs.c diff --git a/net/wireless/Kconfig b/net/wireless/Kconfig new file mode 100644 index 000000000..f620acd2a --- /dev/null +++ b/net/wireless/Kconfig @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: GPL-2.0-only +config WIRELESS_EXT + bool + +config WEXT_CORE + def_bool y + depends on CFG80211_WEXT || WIRELESS_EXT + +config WEXT_PROC + def_bool y + depends on PROC_FS + depends on WEXT_CORE + +config WEXT_SPY + bool + +config WEXT_PRIV + bool + +config CFG80211 + tristate "cfg80211 - wireless configuration API" + depends on RFKILL || !RFKILL + select FW_LOADER + select CRC32 + # may need to update this 
when certificates are changed and are + # using a different algorithm, though right now they shouldn't + # (this is here rather than below to allow it to be a module) + select CRYPTO_SHA256 if CFG80211_USE_KERNEL_REGDB_KEYS + help + cfg80211 is the Linux wireless LAN (802.11) configuration API. + Enable this if you have a wireless device. + + For more information refer to documentation on the wireless wiki: + + https://wireless.wiki.kernel.org/en/developers/Documentation/cfg80211 + + When built as a module it will be called cfg80211. + +if CFG80211 + +config NL80211_TESTMODE + bool "nl80211 testmode command" + help + The nl80211 testmode command helps implementing things like + factory calibration or validation tools for wireless chips. + + Select this option ONLY for kernels that are specifically + built for such purposes. + + Debugging tools that are supposed to end up in the hands of + users should better be implemented with debugfs. + + Say N. + +config CFG80211_DEVELOPER_WARNINGS + bool "enable developer warnings" + default n + help + This option enables some additional warnings that help + cfg80211 developers and driver developers, but beware that + they can also trigger due to races with userspace. + + For example, when a driver reports that it was disconnected + from the AP, but the user disconnects manually at the same + time, the warning might trigger spuriously due to races. + + Say Y only if you are developing cfg80211 or a driver based + on it (or mac80211). + + +config CFG80211_CERTIFICATION_ONUS + bool "cfg80211 certification onus" + depends on EXPERT + default n + help + You should disable this option unless you are both capable + and willing to ensure your system will remain regulatory + compliant with the features available under this option. + Some options may still be under heavy development and + for whatever reason regulatory compliance has not or + cannot yet be verified. Regulatory verification may at + times only be possible until you have the final system + in place. + + This option should only be enabled by system integrators + or distributions that have done work necessary to ensure + regulatory certification on the system with the enabled + features. Alternatively you can enable this option if + you are a wireless researcher and are working in a controlled + and approved environment by your local regulatory agency. + +config CFG80211_REQUIRE_SIGNED_REGDB + bool "require regdb signature" if CFG80211_CERTIFICATION_ONUS + default y + select SYSTEM_DATA_VERIFICATION + help + Require that in addition to the "regulatory.db" file a + "regulatory.db.p7s" can be loaded with a valid PKCS#7 + signature for the regulatory.db file made by one of the + keys in the certs/ directory. + +config CFG80211_USE_KERNEL_REGDB_KEYS + bool "allow regdb keys shipped with the kernel" if CFG80211_CERTIFICATION_ONUS + default y + depends on CFG80211_REQUIRE_SIGNED_REGDB + help + Allow the regulatory database to be signed by one of the keys for + which certificates are part of the kernel sources + (in net/wireless/certs/). + + This is currently only Seth Forshee's key, who is the regulatory + database maintainer. + +config CFG80211_EXTRA_REGDB_KEYDIR + string "additional regdb key directory" if CFG80211_CERTIFICATION_ONUS + depends on CFG80211_REQUIRE_SIGNED_REGDB + help + If selected, point to a directory with DER-encoded X.509 + certificates like in the kernel sources (net/wireless/certs/) + that shall be accepted for a signed regulatory database. 
+ + Note that you need to also select the correct CRYPTO_ modules + for your certificates, and if cfg80211 is built-in they also must be. + +config CFG80211_REG_CELLULAR_HINTS + bool "cfg80211 regulatory support for cellular base station hints" + depends on CFG80211_CERTIFICATION_ONUS + help + This option enables support for parsing regulatory hints + from cellular base stations. If enabled and at least one driver + claims support for parsing cellular base station hints the + regulatory core will allow and parse these regulatory hints. + The regulatory core will only apply these regulatory hints on + drivers that support this feature. You should only enable this + feature if you have tested and validated this feature on your + systems. + +config CFG80211_REG_RELAX_NO_IR + bool "cfg80211 support for NO_IR relaxation" + depends on CFG80211_CERTIFICATION_ONUS + help + This option enables support for relaxation of the NO_IR flag for + situations that certain regulatory bodies have provided clarifications + on how relaxation can occur. This feature has an inherent dependency on + userspace features which must have been properly tested and as such is + not enabled by default. + + A relaxation feature example is allowing the operation of a P2P group + owner (GO) on channels marked with NO_IR if there is an additional BSS + interface which associated to an AP which userspace assumes or confirms + to be an authorized master, i.e., with radar detection support and DFS + capabilities. However, note that in order to not create daisy chain + scenarios, this relaxation is not allowed in cases where the BSS client + is associated to P2P GO and in addition the P2P GO instantiated on + a channel due to this relaxation should not allow connection from + non P2P clients. + + The regulatory core will apply these relaxations only for drivers that + support this feature by declaring the appropriate channel flags and + capabilities in their registration flow. + +config CFG80211_DEFAULT_PS + bool "enable powersave by default" + default y + help + This option enables powersave mode by default. + + If this causes your applications to misbehave you should fix your + applications instead -- they need to register their network + latency requirement, see Documentation/power/pm_qos_interface.rst. + +config CFG80211_DEBUGFS + bool "cfg80211 DebugFS entries" + depends on DEBUG_FS + help + You can enable this if you want debugfs entries for cfg80211. + + If unsure, say N. + +config CFG80211_CRDA_SUPPORT + bool "support CRDA" if EXPERT + default y + help + You should enable this option unless you know for sure you have no + need for it, for example when using the regulatory database loaded as + a firmware file. + + If unsure, say Y. + +config CFG80211_WEXT + bool "cfg80211 wireless extensions compatibility" if !CFG80211_WEXT_EXPORT + select WEXT_CORE + default y if CFG80211_WEXT_EXPORT + help + Enable this option if you need old userspace for wireless + extensions with cfg80211-based drivers. + +config CFG80211_WEXT_EXPORT + bool + help + Drivers should select this option if they require cfg80211's + wext compatibility symbols to be exported. + +endif # CFG80211 + +config LIB80211 + tristate + default n + help + This options enables a library of common routines used + by IEEE802.11 wireless LAN drivers. + + Drivers should select this themselves if needed. 
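
As a minimal illustration of how the lib80211 library added by this patch is consumed, the sketch below shows a hypothetical kernel module (the name "lib80211_demo" is invented for illustration and is not part of this patch) looking up the "WEP" algorithm that the LIB80211_CRYPT_WEP backend registers. It assumes only the lib80211_get_crypto_ops() helper and the struct lib80211_crypto_ops definition from <net/lib80211.h>; a real driver would additionally take a reference on the backend module and call the ops' init/set_key hooks.

/*
 * Illustrative sketch only (hypothetical "lib80211_demo" module), showing
 * how a driver that selects LIB80211_CRYPT_WEP might resolve the crypto
 * ops registered by lib80211_crypt_wep at runtime.
 */
#include <linux/module.h>
#include <net/lib80211.h>

static int __init lib80211_demo_init(void)
{
	struct lib80211_crypto_ops *ops;

	/* The WEP backend registers its ops under the name "WEP". */
	ops = lib80211_get_crypto_ops("WEP");
	if (!ops) {
		/* Backend not loaded; real drivers typically request the
		 * backend module and retry the lookup before giving up.
		 */
		return -ENOENT;
	}

	pr_info("lib80211_demo: found crypto ops \"%s\"\n", ops->name);

	/* A real user would pin the backend (try_module_get(ops->owner))
	 * and instantiate the algorithm via ops->init() before use.
	 */
	return 0;
}

static void __exit lib80211_demo_exit(void)
{
}

module_init(lib80211_demo_init);
module_exit(lib80211_demo_exit);
MODULE_LICENSE("GPL");
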
+ +config LIB80211_CRYPT_WEP + tristate + select CRYPTO_LIB_ARC4 + +config LIB80211_CRYPT_CCMP + tristate + select CRYPTO + select CRYPTO_AES + select CRYPTO_CCM + +config LIB80211_CRYPT_TKIP + tristate + select CRYPTO_LIB_ARC4 + +config LIB80211_DEBUG + bool "lib80211 debugging messages" + depends on LIB80211 + default n + help + You can enable this if you want verbose debugging messages + from lib80211. + + If unsure, say N. diff --git a/net/wireless/Makefile b/net/wireless/Makefile new file mode 100644 index 000000000..527ae669f --- /dev/null +++ b/net/wireless/Makefile @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: GPL-2.0 +obj-$(CONFIG_CFG80211) += cfg80211.o +obj-$(CONFIG_LIB80211) += lib80211.o +obj-$(CONFIG_LIB80211_CRYPT_WEP) += lib80211_crypt_wep.o +obj-$(CONFIG_LIB80211_CRYPT_CCMP) += lib80211_crypt_ccmp.o +obj-$(CONFIG_LIB80211_CRYPT_TKIP) += lib80211_crypt_tkip.o + +obj-$(CONFIG_WEXT_CORE) += wext-core.o +obj-$(CONFIG_WEXT_PROC) += wext-proc.o +obj-$(CONFIG_WEXT_SPY) += wext-spy.o +obj-$(CONFIG_WEXT_PRIV) += wext-priv.o + +cfg80211-y += core.o sysfs.o radiotap.o util.o reg.o scan.o nl80211.o +cfg80211-y += mlme.o ibss.o sme.o chan.o ethtool.o mesh.o ap.o trace.o ocb.o +cfg80211-y += pmsr.o +cfg80211-$(CONFIG_OF) += of.o +cfg80211-$(CONFIG_CFG80211_DEBUGFS) += debugfs.o +cfg80211-$(CONFIG_CFG80211_WEXT) += wext-compat.o wext-sme.o + +CFLAGS_trace.o := -I$(src) + +cfg80211-$(CONFIG_CFG80211_USE_KERNEL_REGDB_KEYS) += shipped-certs.o +ifneq ($(CONFIG_CFG80211_EXTRA_REGDB_KEYDIR),) +cfg80211-y += extra-certs.o +endif + +$(obj)/shipped-certs.c: $(wildcard $(srctree)/$(src)/certs/*.hex) + @$(kecho) " GEN $@" + $(Q)(echo '#include "reg.h"'; \ + echo 'const u8 shipped_regdb_certs[] = {'; \ + echo | cat - $^ ; \ + echo '};'; \ + echo 'unsigned int shipped_regdb_certs_len = sizeof(shipped_regdb_certs);'; \ + ) > $@ + +$(obj)/extra-certs.c: $(CONFIG_CFG80211_EXTRA_REGDB_KEYDIR) \ + $(wildcard $(CONFIG_CFG80211_EXTRA_REGDB_KEYDIR)/*.x509) + @$(kecho) " GEN $@" + $(Q)(set -e; \ + allf=""; \ + for f in $^ ; do \ + test -f $$f || continue;\ + # similar to hexdump -v -e '1/1 "0x%.2x," "\n"' \ + thisf=$$(od -An -v -tx1 < $$f | \ + sed -e 's/ /\n/g' | \ + sed -e 's/^[0-9a-f]\+$$/\0/;t;d' | \ + sed -e 's/^/0x/;s/$$/,/'); \ + # file should not be empty - maybe command substitution failed? \ + test ! 
-z "$$thisf";\ + allf=$$allf$$thisf;\ + done; \ + ( \ + echo '#include "reg.h"'; \ + echo 'const u8 extra_regdb_certs[] = {'; \ + echo "$$allf"; \ + echo '};'; \ + echo 'unsigned int extra_regdb_certs_len = sizeof(extra_regdb_certs);'; \ + ) > $@) + +clean-files += shipped-certs.c extra-certs.c diff --git a/net/wireless/ap.c b/net/wireless/ap.c new file mode 100644 index 000000000..e68923200 --- /dev/null +++ b/net/wireless/ap.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Parts of this file are + * Copyright (C) 2022 Intel Corporation + */ +#include +#include +#include +#include "nl80211.h" +#include "core.h" +#include "rdev-ops.h" + + +static int ___cfg80211_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, unsigned int link_id, + bool notify) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (!rdev->ops->stop_ap) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + if (!wdev->links[link_id].ap.beacon_interval) + return -ENOENT; + + err = rdev_stop_ap(rdev, dev, link_id); + if (!err) { + wdev->conn_owner_nlportid = 0; + wdev->links[link_id].ap.beacon_interval = 0; + memset(&wdev->links[link_id].ap.chandef, 0, + sizeof(wdev->links[link_id].ap.chandef)); + wdev->u.ap.ssid_len = 0; + rdev_set_qos_map(rdev, dev, NULL); + if (notify) + nl80211_send_ap_stopped(wdev); + + /* Should we apply the grace period during beaconing interface + * shutdown also? + */ + cfg80211_sched_dfs_chan_update(rdev); + } + + schedule_work(&cfg80211_disconnect_work); + + return err; +} + +int __cfg80211_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, int link_id, + bool notify) +{ + unsigned int link; + int ret = 0; + + if (link_id >= 0) + return ___cfg80211_stop_ap(rdev, dev, link_id, notify); + + for_each_valid_link(dev->ieee80211_ptr, link) { + int ret1 = ___cfg80211_stop_ap(rdev, dev, link, notify); + + if (ret1) + ret = ret1; + /* try the next one also if one errored */ + } + + return ret; +} + +int cfg80211_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, int link_id, + bool notify) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + wdev_lock(wdev); + err = __cfg80211_stop_ap(rdev, dev, link_id, notify); + wdev_unlock(wdev); + + return err; +} diff --git a/net/wireless/certs/sforshee.hex b/net/wireless/certs/sforshee.hex new file mode 100644 index 000000000..14ea66643 --- /dev/null +++ b/net/wireless/certs/sforshee.hex @@ -0,0 +1,86 @@ +/* Seth Forshee's regdb certificate */ +0x30, 0x82, 0x02, 0xa4, 0x30, 0x82, 0x01, 0x8c, +0x02, 0x09, 0x00, 0xb2, 0x8d, 0xdf, 0x47, 0xae, +0xf9, 0xce, 0xa7, 0x30, 0x0d, 0x06, 0x09, 0x2a, +0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, +0x05, 0x00, 0x30, 0x13, 0x31, 0x11, 0x30, 0x0f, +0x06, 0x03, 0x55, 0x04, 0x03, 0x0c, 0x08, 0x73, +0x66, 0x6f, 0x72, 0x73, 0x68, 0x65, 0x65, 0x30, +0x20, 0x17, 0x0d, 0x31, 0x37, 0x31, 0x30, 0x30, +0x36, 0x31, 0x39, 0x34, 0x30, 0x33, 0x35, 0x5a, +0x18, 0x0f, 0x32, 0x31, 0x31, 0x37, 0x30, 0x39, +0x31, 0x32, 0x31, 0x39, 0x34, 0x30, 0x33, 0x35, +0x5a, 0x30, 0x13, 0x31, 0x11, 0x30, 0x0f, 0x06, +0x03, 0x55, 0x04, 0x03, 0x0c, 0x08, 0x73, 0x66, +0x6f, 0x72, 0x73, 0x68, 0x65, 0x65, 0x30, 0x82, +0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, +0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, +0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, +0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xb5, 
+0x40, 0xe3, 0x9c, 0x28, 0x84, 0x39, 0x03, 0xf2, +0x39, 0xd7, 0x66, 0x2c, 0x41, 0x38, 0x15, 0xac, +0x7e, 0xa5, 0x83, 0x71, 0x25, 0x7e, 0x90, 0x7c, +0x68, 0xdd, 0x6f, 0x3f, 0xd9, 0xd7, 0x59, 0x38, +0x9f, 0x7c, 0x6a, 0x52, 0xc2, 0x03, 0x2a, 0x2d, +0x7e, 0x66, 0xf4, 0x1e, 0xb3, 0x12, 0x70, 0x20, +0x5b, 0xd4, 0x97, 0x32, 0x3d, 0x71, 0x8b, 0x3b, +0x1b, 0x08, 0x17, 0x14, 0x6b, 0x61, 0xc4, 0x57, +0x8b, 0x96, 0x16, 0x1c, 0xfd, 0x24, 0xd5, 0x0b, +0x09, 0xf9, 0x68, 0x11, 0x84, 0xfb, 0xca, 0x51, +0x0c, 0xd1, 0x45, 0x19, 0xda, 0x10, 0x44, 0x8a, +0xd9, 0xfe, 0x76, 0xa9, 0xfd, 0x60, 0x2d, 0x18, +0x0b, 0x28, 0x95, 0xb2, 0x2d, 0xea, 0x88, 0x98, +0xb8, 0xd1, 0x56, 0x21, 0xf0, 0x53, 0x1f, 0xf1, +0x02, 0x6f, 0xe9, 0x46, 0x9b, 0x93, 0x5f, 0x28, +0x90, 0x0f, 0xac, 0x36, 0xfa, 0x68, 0x23, 0x71, +0x57, 0x56, 0xf6, 0xcc, 0xd3, 0xdf, 0x7d, 0x2a, +0xd9, 0x1b, 0x73, 0x45, 0xeb, 0xba, 0x27, 0x85, +0xef, 0x7a, 0x7f, 0xa5, 0xcb, 0x80, 0xc7, 0x30, +0x36, 0xd2, 0x53, 0xee, 0xec, 0xac, 0x1e, 0xe7, +0x31, 0xf1, 0x36, 0xa2, 0x9c, 0x63, 0xc6, 0x65, +0x5b, 0x7f, 0x25, 0x75, 0x68, 0xa1, 0xea, 0xd3, +0x7e, 0x00, 0x5c, 0x9a, 0x5e, 0xd8, 0x20, 0x18, +0x32, 0x77, 0x07, 0x29, 0x12, 0x66, 0x1e, 0x36, +0x73, 0xe7, 0x97, 0x04, 0x41, 0x37, 0xb1, 0xb1, +0x72, 0x2b, 0xf4, 0xa1, 0x29, 0x20, 0x7c, 0x96, +0x79, 0x0b, 0x2b, 0xd0, 0xd8, 0xde, 0xc8, 0x6c, +0x3f, 0x93, 0xfb, 0xc5, 0xee, 0x78, 0x52, 0x11, +0x15, 0x1b, 0x7a, 0xf6, 0xe2, 0x68, 0x99, 0xe7, +0xfb, 0x46, 0x16, 0x84, 0xe3, 0xc7, 0xa1, 0xe6, +0xe0, 0xd2, 0x46, 0xd5, 0xe1, 0xc4, 0x5f, 0xa0, +0x66, 0xf4, 0xda, 0xc4, 0xff, 0x95, 0x1d, 0x02, +0x03, 0x01, 0x00, 0x01, 0x30, 0x0d, 0x06, 0x09, +0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, +0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, +0x87, 0x03, 0xda, 0xf2, 0x82, 0xc2, 0xdd, 0xaf, +0x7c, 0x44, 0x2f, 0x86, 0xd3, 0x5f, 0x4c, 0x93, +0x48, 0xb9, 0xfe, 0x07, 0x17, 0xbb, 0x21, 0xf7, +0x25, 0x23, 0x4e, 0xaa, 0x22, 0x0c, 0x16, 0xb9, +0x73, 0xae, 0x9d, 0x46, 0x7c, 0x75, 0xd9, 0xc3, +0x49, 0x57, 0x47, 0xbf, 0x33, 0xb7, 0x97, 0xec, +0xf5, 0x40, 0x75, 0xc0, 0x46, 0x22, 0xf0, 0xa0, +0x5d, 0x9c, 0x79, 0x13, 0xa1, 0xff, 0xb8, 0xa3, +0x2f, 0x7b, 0x8e, 0x06, 0x3f, 0xc8, 0xb6, 0xe4, +0x6a, 0x28, 0xf2, 0x34, 0x5c, 0x23, 0x3f, 0x32, +0xc0, 0xe6, 0xad, 0x0f, 0xac, 0xcf, 0x55, 0x74, +0x47, 0x73, 0xd3, 0x01, 0x85, 0xb7, 0x0b, 0x22, +0x56, 0x24, 0x7d, 0x9f, 0x09, 0xa9, 0x0e, 0x86, +0x9e, 0x37, 0x5b, 0x9c, 0x6d, 0x02, 0xd9, 0x8c, +0xc8, 0x50, 0x6a, 0xe2, 0x59, 0xf3, 0x16, 0x06, +0xea, 0xb2, 0x42, 0xb5, 0x58, 0xfe, 0xba, 0xd1, +0x81, 0x57, 0x1a, 0xef, 0xb2, 0x38, 0x88, 0x58, +0xf6, 0xaa, 0xc4, 0x2e, 0x8b, 0x5a, 0x27, 0xe4, +0xa5, 0xe8, 0xa4, 0xca, 0x67, 0x5c, 0xac, 0x72, +0x67, 0xc3, 0x6f, 0x13, 0xc3, 0x2d, 0x35, 0x79, +0xd7, 0x8a, 0xe7, 0xf5, 0xd4, 0x21, 0x30, 0x4a, +0xd5, 0xf6, 0xa3, 0xd9, 0x79, 0x56, 0xf2, 0x0f, +0x10, 0xf7, 0x7d, 0xd0, 0x51, 0x93, 0x2f, 0x47, +0xf8, 0x7d, 0x4b, 0x0a, 0x84, 0x55, 0x12, 0x0a, +0x7d, 0x4e, 0x3b, 0x1f, 0x2b, 0x2f, 0xfc, 0x28, +0xb3, 0x69, 0x34, 0xe1, 0x80, 0x80, 0xbb, 0xe2, +0xaf, 0xb9, 0xd6, 0x30, 0xf1, 0x1d, 0x54, 0x87, +0x23, 0x99, 0x9f, 0x51, 0x03, 0x4c, 0x45, 0x7d, +0x02, 0x65, 0x73, 0xab, 0xfd, 0xcf, 0x94, 0xcc, +0x0d, 0x3a, 0x60, 0xfd, 0x3c, 0x14, 0x2f, 0x16, +0x33, 0xa9, 0x21, 0x1f, 0xcb, 0x50, 0xb1, 0x8f, +0x03, 0xee, 0xa0, 0x66, 0xa9, 0x16, 0x79, 0x14, diff --git a/net/wireless/certs/wens.hex b/net/wireless/certs/wens.hex new file mode 100644 index 000000000..0d50369be --- /dev/null +++ b/net/wireless/certs/wens.hex @@ -0,0 +1,87 @@ +/* Chen-Yu Tsai's regdb certificate */ +0x30, 0x82, 0x02, 0xa7, 0x30, 0x82, 0x01, 
0x8f, +0x02, 0x14, 0x61, 0xc0, 0x38, 0x65, 0x1a, 0xab, +0xdc, 0xf9, 0x4b, 0xd0, 0xac, 0x7f, 0xf0, 0x6c, +0x72, 0x48, 0xdb, 0x18, 0xc6, 0x00, 0x30, 0x0d, +0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, +0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x0f, 0x31, +0x0d, 0x30, 0x0b, 0x06, 0x03, 0x55, 0x04, 0x03, +0x0c, 0x04, 0x77, 0x65, 0x6e, 0x73, 0x30, 0x20, +0x17, 0x0d, 0x32, 0x33, 0x31, 0x32, 0x30, 0x31, +0x30, 0x37, 0x34, 0x31, 0x31, 0x34, 0x5a, 0x18, +0x0f, 0x32, 0x31, 0x32, 0x33, 0x31, 0x31, 0x30, +0x37, 0x30, 0x37, 0x34, 0x31, 0x31, 0x34, 0x5a, +0x30, 0x0f, 0x31, 0x0d, 0x30, 0x0b, 0x06, 0x03, +0x55, 0x04, 0x03, 0x0c, 0x04, 0x77, 0x65, 0x6e, +0x73, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, +0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, +0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, +0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, +0x01, 0x00, 0xa9, 0x7a, 0x2c, 0x78, 0x4d, 0xa7, +0x19, 0x2d, 0x32, 0x52, 0xa0, 0x2e, 0x6c, 0xef, +0x88, 0x7f, 0x15, 0xc5, 0xb6, 0x69, 0x54, 0x16, +0x43, 0x14, 0x79, 0x53, 0xb7, 0xae, 0x88, 0xfe, +0xc0, 0xb7, 0x5d, 0x47, 0x8e, 0x1a, 0xe1, 0xef, +0xb3, 0x90, 0x86, 0xda, 0xd3, 0x64, 0x81, 0x1f, +0xce, 0x5d, 0x9e, 0x4b, 0x6e, 0x58, 0x02, 0x3e, +0xb2, 0x6f, 0x5e, 0x42, 0x47, 0x41, 0xf4, 0x2c, +0xb8, 0xa8, 0xd4, 0xaa, 0xc0, 0x0e, 0xe6, 0x48, +0xf0, 0xa8, 0xce, 0xcb, 0x08, 0xae, 0x37, 0xaf, +0xf6, 0x40, 0x39, 0xcb, 0x55, 0x6f, 0x5b, 0x4f, +0x85, 0x34, 0xe6, 0x69, 0x10, 0x50, 0x72, 0x5e, +0x4e, 0x9d, 0x4c, 0xba, 0x38, 0x36, 0x0d, 0xce, +0x73, 0x38, 0xd7, 0x27, 0x02, 0x2a, 0x79, 0x03, +0xe1, 0xac, 0xcf, 0xb0, 0x27, 0x85, 0x86, 0x93, +0x17, 0xab, 0xec, 0x42, 0x77, 0x37, 0x65, 0x8a, +0x44, 0xcb, 0xd6, 0x42, 0x93, 0x92, 0x13, 0xe3, +0x39, 0x45, 0xc5, 0x6e, 0x00, 0x4a, 0x7f, 0xcb, +0x42, 0x17, 0x2b, 0x25, 0x8c, 0xb8, 0x17, 0x3b, +0x15, 0x36, 0x59, 0xde, 0x42, 0xce, 0x21, 0xe6, +0xb6, 0xc7, 0x6e, 0x5e, 0x26, 0x1f, 0xf7, 0x8a, +0x57, 0x9e, 0xa5, 0x96, 0x72, 0xb7, 0x02, 0x32, +0xeb, 0x07, 0x2b, 0x73, 0xe2, 0x4f, 0x66, 0x58, +0x9a, 0xeb, 0x0f, 0x07, 0xb6, 0xab, 0x50, 0x8b, +0xc3, 0x8f, 0x17, 0xfa, 0x0a, 0x99, 0xc2, 0x16, +0x25, 0xbf, 0x2d, 0x6b, 0x1a, 0xaa, 0xe6, 0x3e, +0x5f, 0xeb, 0x6d, 0x9b, 0x5d, 0x4d, 0x42, 0x83, +0x2d, 0x39, 0xb8, 0xc9, 0xac, 0xdb, 0x3a, 0x91, +0x50, 0xdf, 0xbb, 0xb1, 0x76, 0x6d, 0x15, 0x73, +0xfd, 0xc6, 0xe6, 0x6b, 0x71, 0x9e, 0x67, 0x36, +0x22, 0x83, 0x79, 0xb1, 0xd6, 0xb8, 0x84, 0x52, +0xaf, 0x96, 0x5b, 0xc3, 0x63, 0x02, 0x4e, 0x78, +0x70, 0x57, 0x02, 0x03, 0x01, 0x00, 0x01, 0x30, +0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, +0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, +0x01, 0x01, 0x00, 0x24, 0x28, 0xee, 0x22, 0x74, +0x7f, 0x7c, 0xfa, 0x6c, 0x1f, 0xb3, 0x18, 0xd1, +0xc2, 0x3d, 0x7d, 0x29, 0x42, 0x88, 0xad, 0x82, +0xa5, 0xb1, 0x8a, 0x05, 0xd0, 0xec, 0x5c, 0x91, +0x20, 0xf6, 0x82, 0xfd, 0xd5, 0x67, 0x60, 0x5f, +0x31, 0xf5, 0xbd, 0x88, 0x91, 0x70, 0xbd, 0xb8, +0xb9, 0x8c, 0x88, 0xfe, 0x53, 0xc9, 0x54, 0x9b, +0x43, 0xc4, 0x7a, 0x43, 0x74, 0x6b, 0xdd, 0xb0, +0xb1, 0x3b, 0x33, 0x45, 0x46, 0x78, 0xa3, 0x1c, +0xef, 0x54, 0x68, 0xf7, 0x85, 0x9c, 0xe4, 0x51, +0x6f, 0x06, 0xaf, 0x81, 0xdb, 0x2a, 0x7b, 0x7b, +0x6f, 0xa8, 0x9c, 0x67, 0xd8, 0xcb, 0xc9, 0x91, +0x40, 0x00, 0xae, 0xd9, 0xa1, 0x9f, 0xdd, 0xa6, +0x43, 0x0e, 0x28, 0x7b, 0xaa, 0x1b, 0xe9, 0x84, +0xdb, 0x76, 0x64, 0x42, 0x70, 0xc9, 0xc0, 0xeb, +0xae, 0x84, 0x11, 0x16, 0x68, 0x4e, 0x84, 0x9e, +0x7e, 0x92, 0x36, 0xee, 0x1c, 0x3b, 0x08, 0x63, +0xeb, 0x79, 0x84, 0x15, 0x08, 0x9d, 0xaf, 0xc8, +0x9a, 0xc7, 0x34, 0xd3, 0x94, 0x4b, 0xd1, 0x28, +0x97, 0xbe, 0xd1, 0x45, 0x75, 0xdc, 0x35, 0x62, +0xac, 0x1d, 0x1f, 
0xb7, 0xb7, 0x15, 0x87, 0xc8, +0x98, 0xc0, 0x24, 0x31, 0x56, 0x8d, 0xed, 0xdb, +0x06, 0xc6, 0x46, 0xbf, 0x4b, 0x6d, 0xa6, 0xd5, +0xab, 0xcc, 0x60, 0xfc, 0xe5, 0x37, 0xb6, 0x53, +0x7d, 0x58, 0x95, 0xa9, 0x56, 0xc7, 0xf7, 0xee, +0xc3, 0xa0, 0x76, 0xf7, 0x65, 0x4d, 0x53, 0xfa, +0xff, 0x5f, 0x76, 0x33, 0x5a, 0x08, 0xfa, 0x86, +0x92, 0x5a, 0x13, 0xfa, 0x1a, 0xfc, 0xf2, 0x1b, +0x8c, 0x7f, 0x42, 0x6d, 0xb7, 0x7e, 0xb7, 0xb4, +0xf0, 0xc7, 0x83, 0xbb, 0xa2, 0x81, 0x03, 0x2d, +0xd4, 0x2a, 0x63, 0x3f, 0xf7, 0x31, 0x2e, 0x40, +0x33, 0x5c, 0x46, 0xbc, 0x9b, 0xc1, 0x05, 0xa5, +0x45, 0x4e, 0xc3, diff --git a/net/wireless/chan.c b/net/wireless/chan.c new file mode 100644 index 000000000..0e5835cd8 --- /dev/null +++ b/net/wireless/chan.c @@ -0,0 +1,1462 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * This file contains helper code to handle channel + * settings and keeping track of what is possible at + * any point in time. + * + * Copyright 2009 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include "core.h" +#include "rdev-ops.h" + +static bool cfg80211_valid_60g_freq(u32 freq) +{ + return freq >= 58320 && freq <= 70200; +} + +void cfg80211_chandef_create(struct cfg80211_chan_def *chandef, + struct ieee80211_channel *chan, + enum nl80211_channel_type chan_type) +{ + if (WARN_ON(!chan)) + return; + + chandef->chan = chan; + chandef->freq1_offset = chan->freq_offset; + chandef->center_freq2 = 0; + chandef->edmg.bw_config = 0; + chandef->edmg.channels = 0; + + switch (chan_type) { + case NL80211_CHAN_NO_HT: + chandef->width = NL80211_CHAN_WIDTH_20_NOHT; + chandef->center_freq1 = chan->center_freq; + break; + case NL80211_CHAN_HT20: + chandef->width = NL80211_CHAN_WIDTH_20; + chandef->center_freq1 = chan->center_freq; + break; + case NL80211_CHAN_HT40PLUS: + chandef->width = NL80211_CHAN_WIDTH_40; + chandef->center_freq1 = chan->center_freq + 10; + break; + case NL80211_CHAN_HT40MINUS: + chandef->width = NL80211_CHAN_WIDTH_40; + chandef->center_freq1 = chan->center_freq - 10; + break; + default: + WARN_ON(1); + } +} +EXPORT_SYMBOL(cfg80211_chandef_create); + +static bool cfg80211_edmg_chandef_valid(const struct cfg80211_chan_def *chandef) +{ + int max_contiguous = 0; + int num_of_enabled = 0; + int contiguous = 0; + int i; + + if (!chandef->edmg.channels || !chandef->edmg.bw_config) + return false; + + if (!cfg80211_valid_60g_freq(chandef->chan->center_freq)) + return false; + + for (i = 0; i < 6; i++) { + if (chandef->edmg.channels & BIT(i)) { + contiguous++; + num_of_enabled++; + } else { + contiguous = 0; + } + + max_contiguous = max(contiguous, max_contiguous); + } + /* basic verification of edmg configuration according to + * IEEE P802.11ay/D4.0 section 9.4.2.251 + */ + /* check bw_config against contiguous edmg channels */ + switch (chandef->edmg.bw_config) { + case IEEE80211_EDMG_BW_CONFIG_4: + case IEEE80211_EDMG_BW_CONFIG_8: + case IEEE80211_EDMG_BW_CONFIG_12: + if (max_contiguous < 1) + return false; + break; + case IEEE80211_EDMG_BW_CONFIG_5: + case IEEE80211_EDMG_BW_CONFIG_9: + case IEEE80211_EDMG_BW_CONFIG_13: + if (max_contiguous < 2) + return false; + break; + case IEEE80211_EDMG_BW_CONFIG_6: + case IEEE80211_EDMG_BW_CONFIG_10: + case IEEE80211_EDMG_BW_CONFIG_14: + if (max_contiguous < 3) + return false; + break; + case IEEE80211_EDMG_BW_CONFIG_7: + case IEEE80211_EDMG_BW_CONFIG_11: + case IEEE80211_EDMG_BW_CONFIG_15: + if (max_contiguous < 4) + return false; + break; + + default: + return 
false; + } + + /* check bw_config against aggregated (non contiguous) edmg channels */ + switch (chandef->edmg.bw_config) { + case IEEE80211_EDMG_BW_CONFIG_4: + case IEEE80211_EDMG_BW_CONFIG_5: + case IEEE80211_EDMG_BW_CONFIG_6: + case IEEE80211_EDMG_BW_CONFIG_7: + break; + case IEEE80211_EDMG_BW_CONFIG_8: + case IEEE80211_EDMG_BW_CONFIG_9: + case IEEE80211_EDMG_BW_CONFIG_10: + case IEEE80211_EDMG_BW_CONFIG_11: + if (num_of_enabled < 2) + return false; + break; + case IEEE80211_EDMG_BW_CONFIG_12: + case IEEE80211_EDMG_BW_CONFIG_13: + case IEEE80211_EDMG_BW_CONFIG_14: + case IEEE80211_EDMG_BW_CONFIG_15: + if (num_of_enabled < 4 || max_contiguous < 2) + return false; + break; + default: + return false; + } + + return true; +} + +static int nl80211_chan_width_to_mhz(enum nl80211_chan_width chan_width) +{ + int mhz; + + switch (chan_width) { + case NL80211_CHAN_WIDTH_1: + mhz = 1; + break; + case NL80211_CHAN_WIDTH_2: + mhz = 2; + break; + case NL80211_CHAN_WIDTH_4: + mhz = 4; + break; + case NL80211_CHAN_WIDTH_8: + mhz = 8; + break; + case NL80211_CHAN_WIDTH_16: + mhz = 16; + break; + case NL80211_CHAN_WIDTH_5: + mhz = 5; + break; + case NL80211_CHAN_WIDTH_10: + mhz = 10; + break; + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_20_NOHT: + mhz = 20; + break; + case NL80211_CHAN_WIDTH_40: + mhz = 40; + break; + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_80: + mhz = 80; + break; + case NL80211_CHAN_WIDTH_160: + mhz = 160; + break; + case NL80211_CHAN_WIDTH_320: + mhz = 320; + break; + default: + WARN_ON_ONCE(1); + return -1; + } + return mhz; +} + +static int cfg80211_chandef_get_width(const struct cfg80211_chan_def *c) +{ + return nl80211_chan_width_to_mhz(c->width); +} + +bool cfg80211_chandef_valid(const struct cfg80211_chan_def *chandef) +{ + u32 control_freq, oper_freq; + int oper_width, control_width; + + if (!chandef->chan) + return false; + + if (chandef->freq1_offset >= 1000) + return false; + + control_freq = chandef->chan->center_freq; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_20_NOHT: + if (ieee80211_chandef_to_khz(chandef) != + ieee80211_channel_to_khz(chandef->chan)) + return false; + if (chandef->center_freq2) + return false; + break; + case NL80211_CHAN_WIDTH_1: + case NL80211_CHAN_WIDTH_2: + case NL80211_CHAN_WIDTH_4: + case NL80211_CHAN_WIDTH_8: + case NL80211_CHAN_WIDTH_16: + if (chandef->chan->band != NL80211_BAND_S1GHZ) + return false; + + control_freq = ieee80211_channel_to_khz(chandef->chan); + oper_freq = ieee80211_chandef_to_khz(chandef); + control_width = nl80211_chan_width_to_mhz( + ieee80211_s1g_channel_width( + chandef->chan)); + oper_width = cfg80211_chandef_get_width(chandef); + + if (oper_width < 0 || control_width < 0) + return false; + if (chandef->center_freq2) + return false; + + if (control_freq + MHZ_TO_KHZ(control_width) / 2 > + oper_freq + MHZ_TO_KHZ(oper_width) / 2) + return false; + + if (control_freq - MHZ_TO_KHZ(control_width) / 2 < + oper_freq - MHZ_TO_KHZ(oper_width) / 2) + return false; + break; + case NL80211_CHAN_WIDTH_80P80: + if (!chandef->center_freq2) + return false; + /* adjacent is not allowed -- that's a 160 MHz channel */ + if (chandef->center_freq1 - chandef->center_freq2 == 80 || + chandef->center_freq2 - chandef->center_freq1 == 80) + return false; + break; + default: + if (chandef->center_freq2) + return false; + break; + } + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + 
case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_1: + case NL80211_CHAN_WIDTH_2: + case NL80211_CHAN_WIDTH_4: + case NL80211_CHAN_WIDTH_8: + case NL80211_CHAN_WIDTH_16: + /* all checked above */ + break; + case NL80211_CHAN_WIDTH_320: + if (chandef->center_freq1 == control_freq + 150 || + chandef->center_freq1 == control_freq + 130 || + chandef->center_freq1 == control_freq + 110 || + chandef->center_freq1 == control_freq + 90 || + chandef->center_freq1 == control_freq - 90 || + chandef->center_freq1 == control_freq - 110 || + chandef->center_freq1 == control_freq - 130 || + chandef->center_freq1 == control_freq - 150) + break; + fallthrough; + case NL80211_CHAN_WIDTH_160: + if (chandef->center_freq1 == control_freq + 70 || + chandef->center_freq1 == control_freq + 50 || + chandef->center_freq1 == control_freq - 50 || + chandef->center_freq1 == control_freq - 70) + break; + fallthrough; + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_80: + if (chandef->center_freq1 == control_freq + 30 || + chandef->center_freq1 == control_freq - 30) + break; + fallthrough; + case NL80211_CHAN_WIDTH_40: + if (chandef->center_freq1 == control_freq + 10 || + chandef->center_freq1 == control_freq - 10) + break; + fallthrough; + default: + return false; + } + + /* channel 14 is only for IEEE 802.11b */ + if (chandef->center_freq1 == 2484 && + chandef->width != NL80211_CHAN_WIDTH_20_NOHT) + return false; + + if (cfg80211_chandef_is_edmg(chandef) && + !cfg80211_edmg_chandef_valid(chandef)) + return false; + + return true; +} +EXPORT_SYMBOL(cfg80211_chandef_valid); + +static void chandef_primary_freqs(const struct cfg80211_chan_def *c, + u32 *pri40, u32 *pri80, u32 *pri160) +{ + int tmp; + + switch (c->width) { + case NL80211_CHAN_WIDTH_40: + *pri40 = c->center_freq1; + *pri80 = 0; + *pri160 = 0; + break; + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + *pri160 = 0; + *pri80 = c->center_freq1; + /* n_P20 */ + tmp = (30 + c->chan->center_freq - c->center_freq1)/20; + /* n_P40 */ + tmp /= 2; + /* freq_P40 */ + *pri40 = c->center_freq1 - 20 + 40 * tmp; + break; + case NL80211_CHAN_WIDTH_160: + *pri160 = c->center_freq1; + /* n_P20 */ + tmp = (70 + c->chan->center_freq - c->center_freq1)/20; + /* n_P40 */ + tmp /= 2; + /* freq_P40 */ + *pri40 = c->center_freq1 - 60 + 40 * tmp; + /* n_P80 */ + tmp /= 2; + *pri80 = c->center_freq1 - 40 + 80 * tmp; + break; + case NL80211_CHAN_WIDTH_320: + /* n_P20 */ + tmp = (150 + c->chan->center_freq - c->center_freq1) / 20; + /* n_P40 */ + tmp /= 2; + /* freq_P40 */ + *pri40 = c->center_freq1 - 140 + 40 * tmp; + /* n_P80 */ + tmp /= 2; + *pri80 = c->center_freq1 - 120 + 80 * tmp; + /* n_P160 */ + tmp /= 2; + *pri160 = c->center_freq1 - 80 + 160 * tmp; + break; + default: + WARN_ON_ONCE(1); + } +} + +const struct cfg80211_chan_def * +cfg80211_chandef_compatible(const struct cfg80211_chan_def *c1, + const struct cfg80211_chan_def *c2) +{ + u32 c1_pri40, c1_pri80, c2_pri40, c2_pri80, c1_pri160, c2_pri160; + + /* If they are identical, return */ + if (cfg80211_chandef_identical(c1, c2)) + return c1; + + /* otherwise, must have same control channel */ + if (c1->chan != c2->chan) + return NULL; + + /* + * If they have the same width, but aren't identical, + * then they can't be compatible. + */ + if (c1->width == c2->width) + return NULL; + + /* + * can't be compatible if one of them is 5 or 10 MHz, + * but they don't have the same width. 
+ */ + if (c1->width == NL80211_CHAN_WIDTH_5 || + c1->width == NL80211_CHAN_WIDTH_10 || + c2->width == NL80211_CHAN_WIDTH_5 || + c2->width == NL80211_CHAN_WIDTH_10) + return NULL; + + if (c1->width == NL80211_CHAN_WIDTH_20_NOHT || + c1->width == NL80211_CHAN_WIDTH_20) + return c2; + + if (c2->width == NL80211_CHAN_WIDTH_20_NOHT || + c2->width == NL80211_CHAN_WIDTH_20) + return c1; + + chandef_primary_freqs(c1, &c1_pri40, &c1_pri80, &c1_pri160); + chandef_primary_freqs(c2, &c2_pri40, &c2_pri80, &c2_pri160); + + if (c1_pri40 != c2_pri40) + return NULL; + + if (c1->width == NL80211_CHAN_WIDTH_40) + return c2; + + if (c2->width == NL80211_CHAN_WIDTH_40) + return c1; + + if (c1_pri80 != c2_pri80) + return NULL; + + if (c1->width == NL80211_CHAN_WIDTH_80 && + c2->width > NL80211_CHAN_WIDTH_80) + return c2; + + if (c2->width == NL80211_CHAN_WIDTH_80 && + c1->width > NL80211_CHAN_WIDTH_80) + return c1; + + WARN_ON(!c1_pri160 && !c2_pri160); + if (c1_pri160 && c2_pri160 && c1_pri160 != c2_pri160) + return NULL; + + if (c1->width > c2->width) + return c1; + return c2; +} +EXPORT_SYMBOL(cfg80211_chandef_compatible); + +static void cfg80211_set_chans_dfs_state(struct wiphy *wiphy, u32 center_freq, + u32 bandwidth, + enum nl80211_dfs_state dfs_state) +{ + struct ieee80211_channel *c; + u32 freq; + + for (freq = center_freq - bandwidth/2 + 10; + freq <= center_freq + bandwidth/2 - 10; + freq += 20) { + c = ieee80211_get_channel(wiphy, freq); + if (!c || !(c->flags & IEEE80211_CHAN_RADAR)) + continue; + + c->dfs_state = dfs_state; + c->dfs_state_entered = jiffies; + } +} + +void cfg80211_set_dfs_state(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef, + enum nl80211_dfs_state dfs_state) +{ + int width; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return; + + width = cfg80211_chandef_get_width(chandef); + if (width < 0) + return; + + cfg80211_set_chans_dfs_state(wiphy, chandef->center_freq1, + width, dfs_state); + + if (!chandef->center_freq2) + return; + cfg80211_set_chans_dfs_state(wiphy, chandef->center_freq2, + width, dfs_state); +} + +static u32 cfg80211_get_start_freq(u32 center_freq, + u32 bandwidth) +{ + u32 start_freq; + + bandwidth = MHZ_TO_KHZ(bandwidth); + if (bandwidth <= MHZ_TO_KHZ(20)) + start_freq = center_freq; + else + start_freq = center_freq - bandwidth / 2 + MHZ_TO_KHZ(10); + + return start_freq; +} + +static u32 cfg80211_get_end_freq(u32 center_freq, + u32 bandwidth) +{ + u32 end_freq; + + bandwidth = MHZ_TO_KHZ(bandwidth); + if (bandwidth <= MHZ_TO_KHZ(20)) + end_freq = center_freq; + else + end_freq = center_freq + bandwidth / 2 - MHZ_TO_KHZ(10); + + return end_freq; +} + +static int cfg80211_get_chans_dfs_required(struct wiphy *wiphy, + u32 center_freq, + u32 bandwidth) +{ + struct ieee80211_channel *c; + u32 freq, start_freq, end_freq; + + start_freq = cfg80211_get_start_freq(center_freq, bandwidth); + end_freq = cfg80211_get_end_freq(center_freq, bandwidth); + + for (freq = start_freq; freq <= end_freq; freq += MHZ_TO_KHZ(20)) { + c = ieee80211_get_channel_khz(wiphy, freq); + if (!c) + return -EINVAL; + + if (c->flags & IEEE80211_CHAN_RADAR) + return 1; + } + return 0; +} + + +int cfg80211_chandef_dfs_required(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef, + enum nl80211_iftype iftype) +{ + int width; + int ret; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return -EINVAL; + + switch (iftype) { + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_MESH_POINT: + width = 
cfg80211_chandef_get_width(chandef); + if (width < 0) + return -EINVAL; + + ret = cfg80211_get_chans_dfs_required(wiphy, + ieee80211_chandef_to_khz(chandef), + width); + if (ret < 0) + return ret; + else if (ret > 0) + return BIT(chandef->width); + + if (!chandef->center_freq2) + return 0; + + ret = cfg80211_get_chans_dfs_required(wiphy, + MHZ_TO_KHZ(chandef->center_freq2), + width); + if (ret < 0) + return ret; + else if (ret > 0) + return BIT(chandef->width); + + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + break; + case NL80211_IFTYPE_WDS: + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + WARN_ON(1); + } + + return 0; +} +EXPORT_SYMBOL(cfg80211_chandef_dfs_required); + +static int cfg80211_get_chans_dfs_usable(struct wiphy *wiphy, + u32 center_freq, + u32 bandwidth) +{ + struct ieee80211_channel *c; + u32 freq, start_freq, end_freq; + int count = 0; + + start_freq = cfg80211_get_start_freq(center_freq, bandwidth); + end_freq = cfg80211_get_end_freq(center_freq, bandwidth); + + /* + * Check entire range of channels for the bandwidth. + * Check all channels are DFS channels (DFS_USABLE or + * DFS_AVAILABLE). Return number of usable channels + * (require CAC). Allow DFS and non-DFS channel mix. + */ + for (freq = start_freq; freq <= end_freq; freq += MHZ_TO_KHZ(20)) { + c = ieee80211_get_channel_khz(wiphy, freq); + if (!c) + return -EINVAL; + + if (c->flags & IEEE80211_CHAN_DISABLED) + return -EINVAL; + + if (c->flags & IEEE80211_CHAN_RADAR) { + if (c->dfs_state == NL80211_DFS_UNAVAILABLE) + return -EINVAL; + + if (c->dfs_state == NL80211_DFS_USABLE) + count++; + } + } + + return count; +} + +bool cfg80211_chandef_dfs_usable(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef) +{ + int width; + int r1, r2 = 0; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return false; + + width = cfg80211_chandef_get_width(chandef); + if (width < 0) + return false; + + r1 = cfg80211_get_chans_dfs_usable(wiphy, + MHZ_TO_KHZ(chandef->center_freq1), + width); + + if (r1 < 0) + return false; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_80P80: + WARN_ON(!chandef->center_freq2); + r2 = cfg80211_get_chans_dfs_usable(wiphy, + MHZ_TO_KHZ(chandef->center_freq2), + width); + if (r2 < 0) + return false; + break; + default: + WARN_ON(chandef->center_freq2); + break; + } + + return (r1 + r2 > 0); +} + +/* + * Checks if center frequency of chan falls with in the bandwidth + * range of chandef. 
+ */ +bool cfg80211_is_sub_chan(struct cfg80211_chan_def *chandef, + struct ieee80211_channel *chan, + bool primary_only) +{ + int width; + u32 freq; + + if (!chandef->chan) + return false; + + if (chandef->chan->center_freq == chan->center_freq) + return true; + + if (primary_only) + return false; + + width = cfg80211_chandef_get_width(chandef); + if (width <= 20) + return false; + + for (freq = chandef->center_freq1 - width / 2 + 10; + freq <= chandef->center_freq1 + width / 2 - 10; freq += 20) { + if (chan->center_freq == freq) + return true; + } + + if (!chandef->center_freq2) + return false; + + for (freq = chandef->center_freq2 - width / 2 + 10; + freq <= chandef->center_freq2 + width / 2 - 10; freq += 20) { + if (chan->center_freq == freq) + return true; + } + + return false; +} + +bool cfg80211_beaconing_iface_active(struct wireless_dev *wdev) +{ + unsigned int link; + + ASSERT_WDEV_LOCK(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + for_each_valid_link(wdev, link) { + if (wdev->links[link].ap.beacon_interval) + return true; + } + break; + case NL80211_IFTYPE_ADHOC: + if (wdev->u.ibss.ssid_len) + return true; + break; + case NL80211_IFTYPE_MESH_POINT: + if (wdev->u.mesh.id_len) + return true; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_DEVICE: + /* Can NAN type be considered as beaconing interface? */ + case NL80211_IFTYPE_NAN: + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NL80211_IFTYPE_WDS: + case NUM_NL80211_IFTYPES: + WARN_ON(1); + } + + return false; +} + +bool cfg80211_wdev_on_sub_chan(struct wireless_dev *wdev, + struct ieee80211_channel *chan, + bool primary_only) +{ + unsigned int link; + + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + for_each_valid_link(wdev, link) { + if (cfg80211_is_sub_chan(&wdev->links[link].ap.chandef, + chan, primary_only)) + return true; + } + break; + case NL80211_IFTYPE_ADHOC: + return cfg80211_is_sub_chan(&wdev->u.ibss.chandef, chan, + primary_only); + case NL80211_IFTYPE_MESH_POINT: + return cfg80211_is_sub_chan(&wdev->u.mesh.chandef, chan, + primary_only); + default: + break; + } + + return false; +} + +static bool cfg80211_is_wiphy_oper_chan(struct wiphy *wiphy, + struct ieee80211_channel *chan) +{ + struct wireless_dev *wdev; + + list_for_each_entry(wdev, &wiphy->wdev_list, list) { + wdev_lock(wdev); + if (!cfg80211_beaconing_iface_active(wdev)) { + wdev_unlock(wdev); + continue; + } + + if (cfg80211_wdev_on_sub_chan(wdev, chan, false)) { + wdev_unlock(wdev); + return true; + } + wdev_unlock(wdev); + } + + return false; +} + +static bool +cfg80211_offchan_chain_is_active(struct cfg80211_registered_device *rdev, + struct ieee80211_channel *channel) +{ + if (!rdev->background_radar_wdev) + return false; + + if (!cfg80211_chandef_valid(&rdev->background_radar_chandef)) + return false; + + return cfg80211_is_sub_chan(&rdev->background_radar_chandef, channel, + false); +} + +bool cfg80211_any_wiphy_oper_chan(struct wiphy *wiphy, + struct ieee80211_channel *chan) +{ + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + if (!(chan->flags & IEEE80211_CHAN_RADAR)) + return false; + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (!reg_dfs_domain_same(wiphy, &rdev->wiphy)) + continue; + + if (cfg80211_is_wiphy_oper_chan(&rdev->wiphy, chan)) + return true; + + if 
(cfg80211_offchan_chain_is_active(rdev, chan)) + return true; + } + + return false; +} + +static bool cfg80211_get_chans_dfs_available(struct wiphy *wiphy, + u32 center_freq, + u32 bandwidth) +{ + struct ieee80211_channel *c; + u32 freq, start_freq, end_freq; + bool dfs_offload; + + dfs_offload = wiphy_ext_feature_isset(wiphy, + NL80211_EXT_FEATURE_DFS_OFFLOAD); + + start_freq = cfg80211_get_start_freq(center_freq, bandwidth); + end_freq = cfg80211_get_end_freq(center_freq, bandwidth); + + /* + * Check entire range of channels for the bandwidth. + * If any channel in between is disabled or has not + * had gone through CAC return false + */ + for (freq = start_freq; freq <= end_freq; freq += MHZ_TO_KHZ(20)) { + c = ieee80211_get_channel_khz(wiphy, freq); + if (!c) + return false; + + if (c->flags & IEEE80211_CHAN_DISABLED) + return false; + + if ((c->flags & IEEE80211_CHAN_RADAR) && + (c->dfs_state != NL80211_DFS_AVAILABLE) && + !(c->dfs_state == NL80211_DFS_USABLE && dfs_offload)) + return false; + } + + return true; +} + +static bool cfg80211_chandef_dfs_available(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef) +{ + int width; + int r; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return false; + + width = cfg80211_chandef_get_width(chandef); + if (width < 0) + return false; + + r = cfg80211_get_chans_dfs_available(wiphy, + MHZ_TO_KHZ(chandef->center_freq1), + width); + + /* If any of channels unavailable for cf1 just return */ + if (!r) + return r; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_80P80: + WARN_ON(!chandef->center_freq2); + r = cfg80211_get_chans_dfs_available(wiphy, + MHZ_TO_KHZ(chandef->center_freq2), + width); + break; + default: + WARN_ON(chandef->center_freq2); + break; + } + + return r; +} + +static unsigned int cfg80211_get_chans_dfs_cac_time(struct wiphy *wiphy, + u32 center_freq, + u32 bandwidth) +{ + struct ieee80211_channel *c; + u32 start_freq, end_freq, freq; + unsigned int dfs_cac_ms = 0; + + start_freq = cfg80211_get_start_freq(center_freq, bandwidth); + end_freq = cfg80211_get_end_freq(center_freq, bandwidth); + + for (freq = start_freq; freq <= end_freq; freq += MHZ_TO_KHZ(20)) { + c = ieee80211_get_channel_khz(wiphy, freq); + if (!c) + return 0; + + if (c->flags & IEEE80211_CHAN_DISABLED) + return 0; + + if (!(c->flags & IEEE80211_CHAN_RADAR)) + continue; + + if (c->dfs_cac_ms > dfs_cac_ms) + dfs_cac_ms = c->dfs_cac_ms; + } + + return dfs_cac_ms; +} + +unsigned int +cfg80211_chandef_dfs_cac_time(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef) +{ + int width; + unsigned int t1 = 0, t2 = 0; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return 0; + + width = cfg80211_chandef_get_width(chandef); + if (width < 0) + return 0; + + t1 = cfg80211_get_chans_dfs_cac_time(wiphy, + MHZ_TO_KHZ(chandef->center_freq1), + width); + + if (!chandef->center_freq2) + return t1; + + t2 = cfg80211_get_chans_dfs_cac_time(wiphy, + MHZ_TO_KHZ(chandef->center_freq2), + width); + + return max(t1, t2); +} + +static bool cfg80211_secondary_chans_ok(struct wiphy *wiphy, + u32 center_freq, u32 bandwidth, + u32 prohibited_flags) +{ + struct ieee80211_channel *c; + u32 freq, start_freq, end_freq; + + start_freq = cfg80211_get_start_freq(center_freq, bandwidth); + end_freq = cfg80211_get_end_freq(center_freq, bandwidth); + + for (freq = start_freq; freq <= end_freq; freq += MHZ_TO_KHZ(20)) { + c = ieee80211_get_channel_khz(wiphy, freq); + if (!c || c->flags & prohibited_flags) + return false; + } + + return true; +} + +/* check if the 
operating channels are valid and supported */ +static bool cfg80211_edmg_usable(struct wiphy *wiphy, u8 edmg_channels, + enum ieee80211_edmg_bw_config edmg_bw_config, + int primary_channel, + struct ieee80211_edmg *edmg_cap) +{ + struct ieee80211_channel *chan; + int i, freq; + int channels_counter = 0; + + if (!edmg_channels && !edmg_bw_config) + return true; + + if ((!edmg_channels && edmg_bw_config) || + (edmg_channels && !edmg_bw_config)) + return false; + + if (!(edmg_channels & BIT(primary_channel - 1))) + return false; + + /* 60GHz channels 1..6 */ + for (i = 0; i < 6; i++) { + if (!(edmg_channels & BIT(i))) + continue; + + if (!(edmg_cap->channels & BIT(i))) + return false; + + channels_counter++; + + freq = ieee80211_channel_to_frequency(i + 1, + NL80211_BAND_60GHZ); + chan = ieee80211_get_channel(wiphy, freq); + if (!chan || chan->flags & IEEE80211_CHAN_DISABLED) + return false; + } + + /* IEEE802.11 allows max 4 channels */ + if (channels_counter > 4) + return false; + + /* check bw_config is a subset of what driver supports + * (see IEEE P802.11ay/D4.0 section 9.4.2.251, Table 13) + */ + if ((edmg_bw_config % 4) > (edmg_cap->bw_config % 4)) + return false; + + if (edmg_bw_config > edmg_cap->bw_config) + return false; + + return true; +} + +bool cfg80211_chandef_usable(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef, + u32 prohibited_flags) +{ + struct ieee80211_sta_ht_cap *ht_cap; + struct ieee80211_sta_vht_cap *vht_cap; + struct ieee80211_edmg *edmg_cap; + u32 width, control_freq, cap; + bool ext_nss_cap, support_80_80 = false, support_320 = false; + const struct ieee80211_sband_iftype_data *iftd; + struct ieee80211_supported_band *sband; + int i; + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return false; + + ht_cap = &wiphy->bands[chandef->chan->band]->ht_cap; + vht_cap = &wiphy->bands[chandef->chan->band]->vht_cap; + edmg_cap = &wiphy->bands[chandef->chan->band]->edmg_cap; + ext_nss_cap = __le16_to_cpu(vht_cap->vht_mcs.tx_highest) & + IEEE80211_VHT_EXT_NSS_BW_CAPABLE; + + if (edmg_cap->channels && + !cfg80211_edmg_usable(wiphy, + chandef->edmg.channels, + chandef->edmg.bw_config, + chandef->chan->hw_value, + edmg_cap)) + return false; + + control_freq = chandef->chan->center_freq; + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_1: + width = 1; + break; + case NL80211_CHAN_WIDTH_2: + width = 2; + break; + case NL80211_CHAN_WIDTH_4: + width = 4; + break; + case NL80211_CHAN_WIDTH_8: + width = 8; + break; + case NL80211_CHAN_WIDTH_16: + width = 16; + break; + case NL80211_CHAN_WIDTH_5: + width = 5; + break; + case NL80211_CHAN_WIDTH_10: + prohibited_flags |= IEEE80211_CHAN_NO_10MHZ; + width = 10; + break; + case NL80211_CHAN_WIDTH_20: + if (!ht_cap->ht_supported && + chandef->chan->band != NL80211_BAND_6GHZ) + return false; + fallthrough; + case NL80211_CHAN_WIDTH_20_NOHT: + prohibited_flags |= IEEE80211_CHAN_NO_20MHZ; + width = 20; + break; + case NL80211_CHAN_WIDTH_40: + width = 40; + if (chandef->chan->band == NL80211_BAND_6GHZ) + break; + if (!ht_cap->ht_supported) + return false; + if (!(ht_cap->cap & IEEE80211_HT_CAP_SUP_WIDTH_20_40) || + ht_cap->cap & IEEE80211_HT_CAP_40MHZ_INTOLERANT) + return false; + if (chandef->center_freq1 < control_freq && + chandef->chan->flags & IEEE80211_CHAN_NO_HT40MINUS) + return false; + if (chandef->center_freq1 > control_freq && + chandef->chan->flags & IEEE80211_CHAN_NO_HT40PLUS) + return false; + break; + case NL80211_CHAN_WIDTH_80P80: + cap = vht_cap->cap; + support_80_80 = + (cap & 
IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ) || + (cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ && + cap & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) || + (ext_nss_cap && + u32_get_bits(cap, IEEE80211_VHT_CAP_EXT_NSS_BW_MASK) > 1); + if (chandef->chan->band != NL80211_BAND_6GHZ && !support_80_80) + return false; + fallthrough; + case NL80211_CHAN_WIDTH_80: + prohibited_flags |= IEEE80211_CHAN_NO_80MHZ; + width = 80; + if (chandef->chan->band == NL80211_BAND_6GHZ) + break; + if (!vht_cap->vht_supported) + return false; + break; + case NL80211_CHAN_WIDTH_160: + prohibited_flags |= IEEE80211_CHAN_NO_160MHZ; + width = 160; + if (chandef->chan->band == NL80211_BAND_6GHZ) + break; + if (!vht_cap->vht_supported) + return false; + cap = vht_cap->cap & IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK; + if (cap != IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160MHZ && + cap != IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_160_80PLUS80MHZ && + !(ext_nss_cap && + (vht_cap->cap & IEEE80211_VHT_CAP_EXT_NSS_BW_MASK))) + return false; + break; + case NL80211_CHAN_WIDTH_320: + prohibited_flags |= IEEE80211_CHAN_NO_320MHZ; + width = 320; + + if (chandef->chan->band != NL80211_BAND_6GHZ) + return false; + + sband = wiphy->bands[NL80211_BAND_6GHZ]; + if (!sband) + return false; + + for (i = 0; i < sband->n_iftype_data; i++) { + iftd = &sband->iftype_data[i]; + if (!iftd->eht_cap.has_eht) + continue; + + if (iftd->eht_cap.eht_cap_elem.phy_cap_info[0] & + IEEE80211_EHT_PHY_CAP0_320MHZ_IN_6GHZ) { + support_320 = true; + break; + } + } + + if (!support_320) + return false; + break; + default: + WARN_ON_ONCE(1); + return false; + } + + /* + * TODO: What if there are only certain 80/160/80+80 MHz channels + * allowed by the driver, or only certain combinations? + * For 40 MHz the driver can set the NO_HT40 flags, but for + * 80/160 MHz and in particular 80+80 MHz this isn't really + * feasible and we only have NO_80MHZ/NO_160MHZ so far but + * no way to cover 80+80 MHz or more complex restrictions. + * Note that such restrictions also need to be advertised to + * userspace, for example for P2P channel selection. + */ + + if (width > 20) + prohibited_flags |= IEEE80211_CHAN_NO_OFDM; + + /* 5 and 10 MHz are only defined for the OFDM PHY */ + if (width < 20) + prohibited_flags |= IEEE80211_CHAN_NO_OFDM; + + + if (!cfg80211_secondary_chans_ok(wiphy, + ieee80211_chandef_to_khz(chandef), + width, prohibited_flags)) + return false; + + if (!chandef->center_freq2) + return true; + return cfg80211_secondary_chans_ok(wiphy, + MHZ_TO_KHZ(chandef->center_freq2), + width, prohibited_flags); +} +EXPORT_SYMBOL(cfg80211_chandef_usable); + +static bool cfg80211_ir_permissive_check_wdev(enum nl80211_iftype iftype, + struct wireless_dev *wdev, + struct ieee80211_channel *chan) +{ + struct ieee80211_channel *other_chan = NULL; + unsigned int link_id; + int r1, r2; + + for_each_valid_link(wdev, link_id) { + if (wdev->iftype == NL80211_IFTYPE_STATION && + wdev->links[link_id].client.current_bss) + other_chan = wdev->links[link_id].client.current_bss->pub.channel; + + /* + * If a GO already operates on the same GO_CONCURRENT channel, + * this one (maybe the same one) can beacon as well. We allow + * the operation even if the station we relied on with + * GO_CONCURRENT is disconnected now. But then we must make sure + * we're not outdoor on an indoor-only channel. 
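+		 * For that reason this GO-based relaxation is simply not
+		 * applied on channels flagged IEEE80211_CHAN_INDOOR_ONLY.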
+ */ + if (iftype == NL80211_IFTYPE_P2P_GO && + wdev->iftype == NL80211_IFTYPE_P2P_GO && + wdev->links[link_id].ap.beacon_interval && + !(chan->flags & IEEE80211_CHAN_INDOOR_ONLY)) + other_chan = wdev->links[link_id].ap.chandef.chan; + + if (!other_chan) + continue; + + if (chan == other_chan) + return true; + + if (chan->band != NL80211_BAND_5GHZ && + chan->band != NL80211_BAND_6GHZ) + continue; + + r1 = cfg80211_get_unii(chan->center_freq); + r2 = cfg80211_get_unii(other_chan->center_freq); + + if (r1 != -EINVAL && r1 == r2) { + /* + * At some locations channels 149-165 are considered a + * bundle, but at other locations, e.g., Indonesia, + * channels 149-161 are considered a bundle while + * channel 165 is left out and considered to be in a + * different bundle. Thus, in case that there is a + * station interface connected to an AP on channel 165, + * it is assumed that channels 149-161 are allowed for + * GO operations. However, having a station interface + * connected to an AP on channels 149-161, does not + * allow GO operation on channel 165. + */ + if (chan->center_freq == 5825 && + other_chan->center_freq != 5825) + continue; + return true; + } + } + + return false; +} + +/* + * Check if the channel can be used under permissive conditions mandated by + * some regulatory bodies, i.e., the channel is marked with + * IEEE80211_CHAN_IR_CONCURRENT and there is an additional station interface + * associated to an AP on the same channel or on the same UNII band + * (assuming that the AP is an authorized master). + * In addition allow operation on a channel on which indoor operation is + * allowed, iff we are currently operating in an indoor environment. + */ +static bool cfg80211_ir_permissive_chan(struct wiphy *wiphy, + enum nl80211_iftype iftype, + struct ieee80211_channel *chan) +{ + struct wireless_dev *wdev; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + lockdep_assert_held(&rdev->wiphy.mtx); + + if (!IS_ENABLED(CONFIG_CFG80211_REG_RELAX_NO_IR) || + !(wiphy->regulatory_flags & REGULATORY_ENABLE_RELAX_NO_IR)) + return false; + + /* only valid for GO and TDLS off-channel (station/p2p-CL) */ + if (iftype != NL80211_IFTYPE_P2P_GO && + iftype != NL80211_IFTYPE_STATION && + iftype != NL80211_IFTYPE_P2P_CLIENT) + return false; + + if (regulatory_indoor_allowed() && + (chan->flags & IEEE80211_CHAN_INDOOR_ONLY)) + return true; + + if (!(chan->flags & IEEE80211_CHAN_IR_CONCURRENT)) + return false; + + /* + * Generally, it is possible to rely on another device/driver to allow + * the IR concurrent relaxation, however, since the device can further + * enforce the relaxation (by doing a similar verifications as this), + * and thus fail the GO instantiation, consider only the interfaces of + * the current registered device. 
+ */ + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + bool ret; + + wdev_lock(wdev); + ret = cfg80211_ir_permissive_check_wdev(iftype, wdev, chan); + wdev_unlock(wdev); + + if (ret) + return ret; + } + + return false; +} + +static bool _cfg80211_reg_can_beacon(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + enum nl80211_iftype iftype, + bool check_no_ir) +{ + bool res; + u32 prohibited_flags = IEEE80211_CHAN_DISABLED | + IEEE80211_CHAN_RADAR; + + trace_cfg80211_reg_can_beacon(wiphy, chandef, iftype, check_no_ir); + + if (check_no_ir) + prohibited_flags |= IEEE80211_CHAN_NO_IR; + + if (cfg80211_chandef_dfs_required(wiphy, chandef, iftype) > 0 && + cfg80211_chandef_dfs_available(wiphy, chandef)) { + /* We can skip IEEE80211_CHAN_NO_IR if chandef dfs available */ + prohibited_flags = IEEE80211_CHAN_DISABLED; + } + + res = cfg80211_chandef_usable(wiphy, chandef, prohibited_flags); + + trace_cfg80211_return_bool(res); + return res; +} + +bool cfg80211_reg_can_beacon(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + enum nl80211_iftype iftype) +{ + return _cfg80211_reg_can_beacon(wiphy, chandef, iftype, true); +} +EXPORT_SYMBOL(cfg80211_reg_can_beacon); + +bool cfg80211_reg_can_beacon_relax(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + enum nl80211_iftype iftype) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + bool check_no_ir; + + lockdep_assert_held(&rdev->wiphy.mtx); + + /* + * Under certain conditions suggested by some regulatory bodies a + * GO/STA can IR on channels marked with IEEE80211_NO_IR. Set this flag + * only if such relaxations are not enabled and the conditions are not + * met. + */ + check_no_ir = !cfg80211_ir_permissive_chan(wiphy, iftype, + chandef->chan); + + return _cfg80211_reg_can_beacon(wiphy, chandef, iftype, check_no_ir); +} +EXPORT_SYMBOL(cfg80211_reg_can_beacon_relax); + +int cfg80211_set_monitor_channel(struct cfg80211_registered_device *rdev, + struct cfg80211_chan_def *chandef) +{ + if (!rdev->ops->set_monitor_channel) + return -EOPNOTSUPP; + if (!cfg80211_has_monitors_only(rdev)) + return -EBUSY; + + return rdev_set_monitor_channel(rdev, chandef); +} + +bool cfg80211_any_usable_channels(struct wiphy *wiphy, + unsigned long sband_mask, + u32 prohibited_flags) +{ + int idx; + + prohibited_flags |= IEEE80211_CHAN_DISABLED; + + for_each_set_bit(idx, &sband_mask, NUM_NL80211_BANDS) { + struct ieee80211_supported_band *sband = wiphy->bands[idx]; + int chanidx; + + if (!sband) + continue; + + for (chanidx = 0; chanidx < sband->n_channels; chanidx++) { + struct ieee80211_channel *chan; + + chan = &sband->channels[chanidx]; + + if (chan->flags & prohibited_flags) + continue; + + return true; + } + } + + return false; +} +EXPORT_SYMBOL(cfg80211_any_usable_channels); + +struct cfg80211_chan_def *wdev_chandef(struct wireless_dev *wdev, + unsigned int link_id) +{ + /* + * We need to sort out the locking here - in some cases + * where we get here we really just don't care (yet) + * about the valid links, but in others we do. But we + * get here with various driver cases, so we cannot + * easily require the wdev mutex. 
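	 * In practice the lock is only asserted below when a non-default
	 * link_id is requested or the wdev actually has valid MLO links.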
+ */ + if (link_id || wdev->valid_links & BIT(0)) { + ASSERT_WDEV_LOCK(wdev); + WARN_ON(!(wdev->valid_links & BIT(link_id))); + } + + switch (wdev->iftype) { + case NL80211_IFTYPE_MESH_POINT: + return &wdev->u.mesh.chandef; + case NL80211_IFTYPE_ADHOC: + return &wdev->u.ibss.chandef; + case NL80211_IFTYPE_OCB: + return &wdev->u.ocb.chandef; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + return &wdev->links[link_id].ap.chandef; + default: + return NULL; + } +} +EXPORT_SYMBOL(wdev_chandef); diff --git a/net/wireless/core.c b/net/wireless/core.c new file mode 100644 index 000000000..8809e668e --- /dev/null +++ b/net/wireless/core.c @@ -0,0 +1,1763 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is the linux wireless configuration interface. + * + * Copyright 2006-2010 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "nl80211.h" +#include "core.h" +#include "sysfs.h" +#include "debugfs.h" +#include "wext-compat.h" +#include "rdev-ops.h" + +/* name for sysfs, %d is appended */ +#define PHY_NAME "phy" + +MODULE_AUTHOR("Johannes Berg"); +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("wireless configuration support"); +MODULE_ALIAS_GENL_FAMILY(NL80211_GENL_NAME); + +/* RCU-protected (and RTNL for writers) */ +LIST_HEAD(cfg80211_rdev_list); +int cfg80211_rdev_list_generation; + +/* for debugfs */ +static struct dentry *ieee80211_debugfs_dir; + +/* for the cleanup, scan and event works */ +struct workqueue_struct *cfg80211_wq; + +static bool cfg80211_disable_40mhz_24ghz; +module_param(cfg80211_disable_40mhz_24ghz, bool, 0644); +MODULE_PARM_DESC(cfg80211_disable_40mhz_24ghz, + "Disable 40MHz support in the 2.4GHz band"); + +struct cfg80211_registered_device *cfg80211_rdev_by_wiphy_idx(int wiphy_idx) +{ + struct cfg80211_registered_device *result = NULL, *rdev; + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (rdev->wiphy_idx == wiphy_idx) { + result = rdev; + break; + } + } + + return result; +} + +int get_wiphy_idx(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + return rdev->wiphy_idx; +} + +struct wiphy *wiphy_idx_to_wiphy(int wiphy_idx) +{ + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + rdev = cfg80211_rdev_by_wiphy_idx(wiphy_idx); + if (!rdev) + return NULL; + return &rdev->wiphy; +} + +static int cfg80211_dev_check_name(struct cfg80211_registered_device *rdev, + const char *newname) +{ + struct cfg80211_registered_device *rdev2; + int wiphy_idx, taken = -1, digits; + + ASSERT_RTNL(); + + if (strlen(newname) > NL80211_WIPHY_NAME_MAXLEN) + return -EINVAL; + + /* prohibit calling the thing phy%d when %d is not its number */ + sscanf(newname, PHY_NAME "%d%n", &wiphy_idx, &taken); + if (taken == strlen(newname) && wiphy_idx != rdev->wiphy_idx) { + /* count number of places needed to print wiphy_idx */ + digits = 1; + while (wiphy_idx /= 10) + digits++; + /* + * deny the name if it is phy where is printed + * without leading zeroes. taken == strlen(newname) here + */ + if (taken == strlen(PHY_NAME) + digits) + return -EINVAL; + } + + /* Ensure another device does not already have this name. 
*/ + list_for_each_entry(rdev2, &cfg80211_rdev_list, list) + if (strcmp(newname, wiphy_name(&rdev2->wiphy)) == 0) + return -EINVAL; + + return 0; +} + +int cfg80211_dev_rename(struct cfg80211_registered_device *rdev, + char *newname) +{ + int result; + + ASSERT_RTNL(); + + /* Ignore nop renames */ + if (strcmp(newname, wiphy_name(&rdev->wiphy)) == 0) + return 0; + + result = cfg80211_dev_check_name(rdev, newname); + if (result < 0) + return result; + + result = device_rename(&rdev->wiphy.dev, newname); + if (result) + return result; + + if (!IS_ERR_OR_NULL(rdev->wiphy.debugfsdir)) + debugfs_rename(rdev->wiphy.debugfsdir->d_parent, + rdev->wiphy.debugfsdir, + rdev->wiphy.debugfsdir->d_parent, newname); + + nl80211_notify_wiphy(rdev, NL80211_CMD_NEW_WIPHY); + + return 0; +} + +int cfg80211_switch_netns(struct cfg80211_registered_device *rdev, + struct net *net) +{ + struct wireless_dev *wdev; + int err = 0; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_NETNS_OK)) + return -EOPNOTSUPP; + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (!wdev->netdev) + continue; + wdev->netdev->features &= ~NETIF_F_NETNS_LOCAL; + err = dev_change_net_namespace(wdev->netdev, net, "wlan%d"); + if (err) + break; + wdev->netdev->features |= NETIF_F_NETNS_LOCAL; + } + + if (err) { + /* failed -- clean up to old netns */ + net = wiphy_net(&rdev->wiphy); + + list_for_each_entry_continue_reverse(wdev, + &rdev->wiphy.wdev_list, + list) { + if (!wdev->netdev) + continue; + wdev->netdev->features &= ~NETIF_F_NETNS_LOCAL; + err = dev_change_net_namespace(wdev->netdev, net, + "wlan%d"); + WARN_ON(err); + wdev->netdev->features |= NETIF_F_NETNS_LOCAL; + } + + return err; + } + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (!wdev->netdev) + continue; + nl80211_notify_iface(rdev, wdev, NL80211_CMD_DEL_INTERFACE); + } + nl80211_notify_wiphy(rdev, NL80211_CMD_DEL_WIPHY); + + wiphy_net_set(&rdev->wiphy, net); + + err = device_rename(&rdev->wiphy.dev, dev_name(&rdev->wiphy.dev)); + WARN_ON(err); + + nl80211_notify_wiphy(rdev, NL80211_CMD_NEW_WIPHY); + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (!wdev->netdev) + continue; + nl80211_notify_iface(rdev, wdev, NL80211_CMD_NEW_INTERFACE); + } + + return 0; +} + +static void cfg80211_rfkill_poll(struct rfkill *rfkill, void *data) +{ + struct cfg80211_registered_device *rdev = data; + + wiphy_lock(&rdev->wiphy); + rdev_rfkill_poll(rdev); + wiphy_unlock(&rdev->wiphy); +} + +void cfg80211_stop_p2p_device(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_P2P_DEVICE)) + return; + + if (!wdev_running(wdev)) + return; + + rdev_stop_p2p_device(rdev, wdev); + wdev->is_running = false; + + rdev->opencount--; + + if (rdev->scan_req && rdev->scan_req->wdev == wdev) { + if (WARN_ON(!rdev->scan_req->notified && + (!rdev->int_scan_req || + !rdev->int_scan_req->notified))) + rdev->scan_req->info.aborted = true; + ___cfg80211_scan_done(rdev, false); + } +} + +void cfg80211_stop_nan(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_NAN)) + return; + + if (!wdev_running(wdev)) + return; + + rdev_stop_nan(rdev, wdev); + wdev->is_running = false; + + rdev->opencount--; +} + +void cfg80211_shutdown_all_interfaces(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct wireless_dev *wdev; + + 
ASSERT_RTNL(); + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (wdev->netdev) { + dev_close(wdev->netdev); + continue; + } + + /* otherwise, check iftype */ + + wiphy_lock(wiphy); + + switch (wdev->iftype) { + case NL80211_IFTYPE_P2P_DEVICE: + cfg80211_stop_p2p_device(rdev, wdev); + break; + case NL80211_IFTYPE_NAN: + cfg80211_stop_nan(rdev, wdev); + break; + default: + break; + } + + wiphy_unlock(wiphy); + } +} +EXPORT_SYMBOL_GPL(cfg80211_shutdown_all_interfaces); + +static int cfg80211_rfkill_set_block(void *data, bool blocked) +{ + struct cfg80211_registered_device *rdev = data; + + if (!blocked) + return 0; + + rtnl_lock(); + cfg80211_shutdown_all_interfaces(&rdev->wiphy); + rtnl_unlock(); + + return 0; +} + +static void cfg80211_rfkill_block_work(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + rfkill_block); + cfg80211_rfkill_set_block(rdev, true); +} + +static void cfg80211_event_work(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + event_work); + + wiphy_lock(&rdev->wiphy); + cfg80211_process_rdev_events(rdev); + wiphy_unlock(&rdev->wiphy); +} + +void cfg80211_destroy_ifaces(struct cfg80211_registered_device *rdev) +{ + struct wireless_dev *wdev, *tmp; + + ASSERT_RTNL(); + + list_for_each_entry_safe(wdev, tmp, &rdev->wiphy.wdev_list, list) { + if (wdev->nl_owner_dead) { + if (wdev->netdev) + dev_close(wdev->netdev); + + wiphy_lock(&rdev->wiphy); + cfg80211_leave(rdev, wdev); + cfg80211_remove_virtual_intf(rdev, wdev); + wiphy_unlock(&rdev->wiphy); + } + } +} + +static void cfg80211_destroy_iface_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + destroy_work); + + rtnl_lock(); + cfg80211_destroy_ifaces(rdev); + rtnl_unlock(); +} + +static void cfg80211_sched_scan_stop_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + struct cfg80211_sched_scan_request *req, *tmp; + + rdev = container_of(work, struct cfg80211_registered_device, + sched_scan_stop_wk); + + wiphy_lock(&rdev->wiphy); + list_for_each_entry_safe(req, tmp, &rdev->sched_scan_req_list, list) { + if (req->nl_owner_dead) + cfg80211_stop_sched_scan_req(rdev, req, false); + } + wiphy_unlock(&rdev->wiphy); +} + +static void cfg80211_propagate_radar_detect_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + propagate_radar_detect_wk); + + rtnl_lock(); + + regulatory_propagate_dfs_state(&rdev->wiphy, &rdev->radar_chandef, + NL80211_DFS_UNAVAILABLE, + NL80211_RADAR_DETECTED); + + rtnl_unlock(); +} + +static void cfg80211_propagate_cac_done_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + propagate_cac_done_wk); + + rtnl_lock(); + + regulatory_propagate_dfs_state(&rdev->wiphy, &rdev->cac_done_chandef, + NL80211_DFS_AVAILABLE, + NL80211_RADAR_CAC_FINISHED); + + rtnl_unlock(); +} + +static void cfg80211_wiphy_work(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + struct wiphy_work *wk; + + rdev = container_of(work, struct cfg80211_registered_device, wiphy_work); + + wiphy_lock(&rdev->wiphy); + if (rdev->suspended) + goto out; + + spin_lock_irq(&rdev->wiphy_work_lock); + wk = 
list_first_entry_or_null(&rdev->wiphy_work_list, + struct wiphy_work, entry); + if (wk) { + list_del_init(&wk->entry); + if (!list_empty(&rdev->wiphy_work_list)) + schedule_work(work); + spin_unlock_irq(&rdev->wiphy_work_lock); + + wk->func(&rdev->wiphy, wk); + } else { + spin_unlock_irq(&rdev->wiphy_work_lock); + } +out: + wiphy_unlock(&rdev->wiphy); +} + +/* exported functions */ + +struct wiphy *wiphy_new_nm(const struct cfg80211_ops *ops, int sizeof_priv, + const char *requested_name) +{ + static atomic_t wiphy_counter = ATOMIC_INIT(0); + + struct cfg80211_registered_device *rdev; + int alloc_size; + + WARN_ON(ops->add_key && (!ops->del_key || !ops->set_default_key)); + WARN_ON(ops->auth && (!ops->assoc || !ops->deauth || !ops->disassoc)); + WARN_ON(ops->connect && !ops->disconnect); + WARN_ON(ops->join_ibss && !ops->leave_ibss); + WARN_ON(ops->add_virtual_intf && !ops->del_virtual_intf); + WARN_ON(ops->add_station && !ops->del_station); + WARN_ON(ops->add_mpath && !ops->del_mpath); + WARN_ON(ops->join_mesh && !ops->leave_mesh); + WARN_ON(ops->start_p2p_device && !ops->stop_p2p_device); + WARN_ON(ops->start_ap && !ops->stop_ap); + WARN_ON(ops->join_ocb && !ops->leave_ocb); + WARN_ON(ops->suspend && !ops->resume); + WARN_ON(ops->sched_scan_start && !ops->sched_scan_stop); + WARN_ON(ops->remain_on_channel && !ops->cancel_remain_on_channel); + WARN_ON(ops->tdls_channel_switch && !ops->tdls_cancel_channel_switch); + WARN_ON(ops->add_tx_ts && !ops->del_tx_ts); + + alloc_size = sizeof(*rdev) + sizeof_priv; + + rdev = kzalloc(alloc_size, GFP_KERNEL); + if (!rdev) + return NULL; + + rdev->ops = ops; + + rdev->wiphy_idx = atomic_inc_return(&wiphy_counter); + + if (unlikely(rdev->wiphy_idx < 0)) { + /* ugh, wrapped! */ + atomic_dec(&wiphy_counter); + kfree(rdev); + return NULL; + } + + /* atomic_inc_return makes it start at 1, make it start at 0 */ + rdev->wiphy_idx--; + + /* give it a proper name */ + if (requested_name && requested_name[0]) { + int rv; + + rtnl_lock(); + rv = cfg80211_dev_check_name(rdev, requested_name); + + if (rv < 0) { + rtnl_unlock(); + goto use_default_name; + } + + rv = dev_set_name(&rdev->wiphy.dev, "%s", requested_name); + rtnl_unlock(); + if (rv) + goto use_default_name; + } else { + int rv; + +use_default_name: + /* NOTE: This is *probably* safe w/out holding rtnl because of + * the restrictions on phy names. Probably this call could + * fail if some other part of the kernel (re)named a device + * phyX. But, might should add some locking and check return + * value, and use a different name if this one exists? 
+ */ + rv = dev_set_name(&rdev->wiphy.dev, PHY_NAME "%d", rdev->wiphy_idx); + if (rv < 0) { + kfree(rdev); + return NULL; + } + } + + mutex_init(&rdev->wiphy.mtx); + INIT_LIST_HEAD(&rdev->wiphy.wdev_list); + INIT_LIST_HEAD(&rdev->beacon_registrations); + spin_lock_init(&rdev->beacon_registrations_lock); + spin_lock_init(&rdev->bss_lock); + INIT_LIST_HEAD(&rdev->bss_list); + INIT_LIST_HEAD(&rdev->sched_scan_req_list); + INIT_WORK(&rdev->scan_done_wk, __cfg80211_scan_done); + INIT_DELAYED_WORK(&rdev->dfs_update_channels_wk, + cfg80211_dfs_channels_update_work); +#ifdef CONFIG_CFG80211_WEXT + rdev->wiphy.wext = &cfg80211_wext_handler; +#endif + + device_initialize(&rdev->wiphy.dev); + rdev->wiphy.dev.class = &ieee80211_class; + rdev->wiphy.dev.platform_data = rdev; + device_enable_async_suspend(&rdev->wiphy.dev); + + INIT_WORK(&rdev->destroy_work, cfg80211_destroy_iface_wk); + INIT_WORK(&rdev->sched_scan_stop_wk, cfg80211_sched_scan_stop_wk); + INIT_WORK(&rdev->sched_scan_res_wk, cfg80211_sched_scan_results_wk); + INIT_WORK(&rdev->propagate_radar_detect_wk, + cfg80211_propagate_radar_detect_wk); + INIT_WORK(&rdev->propagate_cac_done_wk, cfg80211_propagate_cac_done_wk); + INIT_WORK(&rdev->mgmt_registrations_update_wk, + cfg80211_mgmt_registrations_update_wk); + spin_lock_init(&rdev->mgmt_registrations_lock); + +#ifdef CONFIG_CFG80211_DEFAULT_PS + rdev->wiphy.flags |= WIPHY_FLAG_PS_ON_BY_DEFAULT; +#endif + + wiphy_net_set(&rdev->wiphy, &init_net); + + rdev->rfkill_ops.set_block = cfg80211_rfkill_set_block; + rdev->wiphy.rfkill = rfkill_alloc(dev_name(&rdev->wiphy.dev), + &rdev->wiphy.dev, RFKILL_TYPE_WLAN, + &rdev->rfkill_ops, rdev); + + if (!rdev->wiphy.rfkill) { + wiphy_free(&rdev->wiphy); + return NULL; + } + + INIT_WORK(&rdev->wiphy_work, cfg80211_wiphy_work); + INIT_LIST_HEAD(&rdev->wiphy_work_list); + spin_lock_init(&rdev->wiphy_work_lock); + INIT_WORK(&rdev->rfkill_block, cfg80211_rfkill_block_work); + INIT_WORK(&rdev->conn_work, cfg80211_conn_work); + INIT_WORK(&rdev->event_work, cfg80211_event_work); + INIT_WORK(&rdev->background_cac_abort_wk, + cfg80211_background_cac_abort_wk); + INIT_DELAYED_WORK(&rdev->background_cac_done_wk, + cfg80211_background_cac_done_wk); + + init_waitqueue_head(&rdev->dev_wait); + + /* + * Initialize wiphy parameters to IEEE 802.11 MIB default values. + * Fragmentation and RTS threshold are disabled by default with the + * special -1 value. + */ + rdev->wiphy.retry_short = 7; + rdev->wiphy.retry_long = 4; + rdev->wiphy.frag_threshold = (u32) -1; + rdev->wiphy.rts_threshold = (u32) -1; + rdev->wiphy.coverage_class = 0; + + rdev->wiphy.max_num_csa_counters = 1; + + rdev->wiphy.max_sched_scan_plans = 1; + rdev->wiphy.max_sched_scan_plan_interval = U32_MAX; + + return &rdev->wiphy; +} +EXPORT_SYMBOL(wiphy_new_nm); + +static int wiphy_verify_combinations(struct wiphy *wiphy) +{ + const struct ieee80211_iface_combination *c; + int i, j; + + for (i = 0; i < wiphy->n_iface_combinations; i++) { + u32 cnt = 0; + u16 all_iftypes = 0; + + c = &wiphy->iface_combinations[i]; + + /* + * Combinations with just one interface aren't real, + * however we make an exception for DFS. + */ + if (WARN_ON((c->max_interfaces < 2) && !c->radar_detect_widths)) + return -EINVAL; + + /* Need at least one channel */ + if (WARN_ON(!c->num_different_channels)) + return -EINVAL; + + /* DFS only works on one channel. 
*/ + if (WARN_ON(c->radar_detect_widths && + (c->num_different_channels > 1))) + return -EINVAL; + + if (WARN_ON(!c->n_limits)) + return -EINVAL; + + for (j = 0; j < c->n_limits; j++) { + u16 types = c->limits[j].types; + + /* interface types shouldn't overlap */ + if (WARN_ON(types & all_iftypes)) + return -EINVAL; + all_iftypes |= types; + + if (WARN_ON(!c->limits[j].max)) + return -EINVAL; + + /* Shouldn't list software iftypes in combinations! */ + if (WARN_ON(wiphy->software_iftypes & types)) + return -EINVAL; + + /* Only a single P2P_DEVICE can be allowed */ + if (WARN_ON(types & BIT(NL80211_IFTYPE_P2P_DEVICE) && + c->limits[j].max > 1)) + return -EINVAL; + + /* Only a single NAN can be allowed */ + if (WARN_ON(types & BIT(NL80211_IFTYPE_NAN) && + c->limits[j].max > 1)) + return -EINVAL; + + /* + * This isn't well-defined right now. If you have an + * IBSS interface, then its beacon interval may change + * by joining other networks, and nothing prevents it + * from doing that. + * So technically we probably shouldn't even allow AP + * and IBSS in the same interface, but it seems that + * some drivers support that, possibly only with fixed + * beacon intervals for IBSS. + */ + if (WARN_ON(types & BIT(NL80211_IFTYPE_ADHOC) && + c->beacon_int_min_gcd)) { + return -EINVAL; + } + + cnt += c->limits[j].max; + /* + * Don't advertise an unsupported type + * in a combination. + */ + if (WARN_ON((wiphy->interface_modes & types) != types)) + return -EINVAL; + } + + if (WARN_ON(all_iftypes & BIT(NL80211_IFTYPE_WDS))) + return -EINVAL; + + /* You can't even choose that many! */ + if (WARN_ON(cnt < c->max_interfaces)) + return -EINVAL; + } + + return 0; +} + +int wiphy_register(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + int res; + enum nl80211_band band; + struct ieee80211_supported_band *sband; + bool have_band = false; + int i; + u16 ifmodes = wiphy->interface_modes; + +#ifdef CONFIG_PM + if (WARN_ON(wiphy->wowlan && + (wiphy->wowlan->flags & WIPHY_WOWLAN_GTK_REKEY_FAILURE) && + !(wiphy->wowlan->flags & WIPHY_WOWLAN_SUPPORTS_GTK_REKEY))) + return -EINVAL; + if (WARN_ON(wiphy->wowlan && + !wiphy->wowlan->flags && !wiphy->wowlan->n_patterns && + !wiphy->wowlan->tcp)) + return -EINVAL; +#endif + if (WARN_ON((wiphy->features & NL80211_FEATURE_TDLS_CHANNEL_SWITCH) && + (!rdev->ops->tdls_channel_switch || + !rdev->ops->tdls_cancel_channel_switch))) + return -EINVAL; + + if (WARN_ON((wiphy->interface_modes & BIT(NL80211_IFTYPE_NAN)) && + (!rdev->ops->start_nan || !rdev->ops->stop_nan || + !rdev->ops->add_nan_func || !rdev->ops->del_nan_func || + !(wiphy->nan_supported_bands & BIT(NL80211_BAND_2GHZ))))) + return -EINVAL; + + if (WARN_ON(wiphy->interface_modes & BIT(NL80211_IFTYPE_WDS))) + return -EINVAL; + + if (WARN_ON(wiphy->pmsr_capa && !wiphy->pmsr_capa->ftm.supported)) + return -EINVAL; + + if (wiphy->pmsr_capa && wiphy->pmsr_capa->ftm.supported) { + if (WARN_ON(!wiphy->pmsr_capa->ftm.asap && + !wiphy->pmsr_capa->ftm.non_asap)) + return -EINVAL; + if (WARN_ON(!wiphy->pmsr_capa->ftm.preambles || + !wiphy->pmsr_capa->ftm.bandwidths)) + return -EINVAL; + if (WARN_ON(wiphy->pmsr_capa->ftm.preambles & + ~(BIT(NL80211_PREAMBLE_LEGACY) | + BIT(NL80211_PREAMBLE_HT) | + BIT(NL80211_PREAMBLE_VHT) | + BIT(NL80211_PREAMBLE_HE) | + BIT(NL80211_PREAMBLE_DMG)))) + return -EINVAL; + if (WARN_ON((wiphy->pmsr_capa->ftm.trigger_based || + wiphy->pmsr_capa->ftm.non_trigger_based) && + !(wiphy->pmsr_capa->ftm.preambles & + BIT(NL80211_PREAMBLE_HE)))) + return -EINVAL; + 
if (WARN_ON(wiphy->pmsr_capa->ftm.bandwidths & + ~(BIT(NL80211_CHAN_WIDTH_20_NOHT) | + BIT(NL80211_CHAN_WIDTH_20) | + BIT(NL80211_CHAN_WIDTH_40) | + BIT(NL80211_CHAN_WIDTH_80) | + BIT(NL80211_CHAN_WIDTH_80P80) | + BIT(NL80211_CHAN_WIDTH_160) | + BIT(NL80211_CHAN_WIDTH_5) | + BIT(NL80211_CHAN_WIDTH_10)))) + return -EINVAL; + } + + if (WARN_ON((wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) && + (wiphy->regulatory_flags & + (REGULATORY_CUSTOM_REG | + REGULATORY_STRICT_REG | + REGULATORY_COUNTRY_IE_FOLLOW_POWER | + REGULATORY_COUNTRY_IE_IGNORE)))) + return -EINVAL; + + if (WARN_ON(wiphy->coalesce && + (!wiphy->coalesce->n_rules || + !wiphy->coalesce->n_patterns) && + (!wiphy->coalesce->pattern_min_len || + wiphy->coalesce->pattern_min_len > + wiphy->coalesce->pattern_max_len))) + return -EINVAL; + + if (WARN_ON(wiphy->ap_sme_capa && + !(wiphy->flags & WIPHY_FLAG_HAVE_AP_SME))) + return -EINVAL; + + if (WARN_ON(wiphy->addresses && !wiphy->n_addresses)) + return -EINVAL; + + if (WARN_ON(wiphy->addresses && + !is_zero_ether_addr(wiphy->perm_addr) && + memcmp(wiphy->perm_addr, wiphy->addresses[0].addr, + ETH_ALEN))) + return -EINVAL; + + if (WARN_ON(wiphy->max_acl_mac_addrs && + (!(wiphy->flags & WIPHY_FLAG_HAVE_AP_SME) || + !rdev->ops->set_mac_acl))) + return -EINVAL; + + /* assure only valid behaviours are flagged by driver + * hence subtract 2 as bit 0 is invalid. + */ + if (WARN_ON(wiphy->bss_select_support && + (wiphy->bss_select_support & ~(BIT(__NL80211_BSS_SELECT_ATTR_AFTER_LAST) - 2)))) + return -EINVAL; + + if (WARN_ON(wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X) && + (!rdev->ops->set_pmk || !rdev->ops->del_pmk))) + return -EINVAL; + + if (WARN_ON(!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_FW_ROAM) && + rdev->ops->update_connect_params)) + return -EINVAL; + + if (wiphy->addresses) + memcpy(wiphy->perm_addr, wiphy->addresses[0].addr, ETH_ALEN); + + /* sanity check ifmodes */ + WARN_ON(!ifmodes); + ifmodes &= ((1 << NUM_NL80211_IFTYPES) - 1) & ~1; + if (WARN_ON(ifmodes != wiphy->interface_modes)) + wiphy->interface_modes = ifmodes; + + res = wiphy_verify_combinations(wiphy); + if (res) + return res; + + /* sanity check supported bands/channels */ + for (band = 0; band < NUM_NL80211_BANDS; band++) { + u16 types = 0; + bool have_he = false; + + sband = wiphy->bands[band]; + if (!sband) + continue; + + sband->band = band; + if (WARN_ON(!sband->n_channels)) + return -EINVAL; + /* + * on 60GHz or sub-1Ghz band, there are no legacy rates, so + * n_bitrates is 0 + */ + if (WARN_ON((band != NL80211_BAND_60GHZ && + band != NL80211_BAND_S1GHZ) && + !sband->n_bitrates)) + return -EINVAL; + + if (WARN_ON(band == NL80211_BAND_6GHZ && + (sband->ht_cap.ht_supported || + sband->vht_cap.vht_supported))) + return -EINVAL; + + /* + * Since cfg80211_disable_40mhz_24ghz is global, we can + * modify the sband's ht data even if the driver uses a + * global structure for that. + */ + if (cfg80211_disable_40mhz_24ghz && + band == NL80211_BAND_2GHZ && + sband->ht_cap.ht_supported) { + sband->ht_cap.cap &= ~IEEE80211_HT_CAP_SUP_WIDTH_20_40; + sband->ht_cap.cap &= ~IEEE80211_HT_CAP_SGI_40; + } + + /* + * Since we use a u32 for rate bitmaps in + * ieee80211_get_response_rate, we cannot + * have more than 32 legacy rates. 
+ */ + if (WARN_ON(sband->n_bitrates > 32)) + return -EINVAL; + + for (i = 0; i < sband->n_channels; i++) { + sband->channels[i].orig_flags = + sband->channels[i].flags; + sband->channels[i].orig_mag = INT_MAX; + sband->channels[i].orig_mpwr = + sband->channels[i].max_power; + sband->channels[i].band = band; + + if (WARN_ON(sband->channels[i].freq_offset >= 1000)) + return -EINVAL; + } + + for (i = 0; i < sband->n_iftype_data; i++) { + const struct ieee80211_sband_iftype_data *iftd; + bool has_ap, has_non_ap; + u32 ap_bits = BIT(NL80211_IFTYPE_AP) | + BIT(NL80211_IFTYPE_P2P_GO); + + iftd = &sband->iftype_data[i]; + + if (WARN_ON(!iftd->types_mask)) + return -EINVAL; + if (WARN_ON(types & iftd->types_mask)) + return -EINVAL; + + /* at least one piece of information must be present */ + if (WARN_ON(!iftd->he_cap.has_he)) + return -EINVAL; + + types |= iftd->types_mask; + + if (i == 0) + have_he = iftd->he_cap.has_he; + else + have_he = have_he && + iftd->he_cap.has_he; + + has_ap = iftd->types_mask & ap_bits; + has_non_ap = iftd->types_mask & ~ap_bits; + + /* + * For EHT 20 MHz STA, the capabilities format differs + * but to simplify, don't check 20 MHz but rather check + * only if AP and non-AP were mentioned at the same time, + * reject if so. + */ + if (WARN_ON(iftd->eht_cap.has_eht && + has_ap && has_non_ap)) + return -EINVAL; + } + + if (WARN_ON(!have_he && band == NL80211_BAND_6GHZ)) + return -EINVAL; + + have_band = true; + } + + if (!have_band) { + WARN_ON(1); + return -EINVAL; + } + + for (i = 0; i < rdev->wiphy.n_vendor_commands; i++) { + /* + * Validate we have a policy (can be explicitly set to + * VENDOR_CMD_RAW_DATA which is non-NULL) and also that + * we have at least one of doit/dumpit. + */ + if (WARN_ON(!rdev->wiphy.vendor_commands[i].policy)) + return -EINVAL; + if (WARN_ON(!rdev->wiphy.vendor_commands[i].doit && + !rdev->wiphy.vendor_commands[i].dumpit)) + return -EINVAL; + } + +#ifdef CONFIG_PM + if (WARN_ON(rdev->wiphy.wowlan && rdev->wiphy.wowlan->n_patterns && + (!rdev->wiphy.wowlan->pattern_min_len || + rdev->wiphy.wowlan->pattern_min_len > + rdev->wiphy.wowlan->pattern_max_len))) + return -EINVAL; +#endif + + if (!wiphy->max_num_akm_suites) + wiphy->max_num_akm_suites = NL80211_MAX_NR_AKM_SUITES; + else if (wiphy->max_num_akm_suites < NL80211_MAX_NR_AKM_SUITES || + wiphy->max_num_akm_suites > CFG80211_MAX_NUM_AKM_SUITES) + return -EINVAL; + + /* check and set up bitrates */ + ieee80211_set_bitrate_flags(wiphy); + + rdev->wiphy.features |= NL80211_FEATURE_SCAN_FLUSH; + + rtnl_lock(); + res = device_add(&rdev->wiphy.dev); + if (res) { + rtnl_unlock(); + return res; + } + + list_add_rcu(&rdev->list, &cfg80211_rdev_list); + cfg80211_rdev_list_generation++; + + /* add to debugfs */ + rdev->wiphy.debugfsdir = debugfs_create_dir(wiphy_name(&rdev->wiphy), + ieee80211_debugfs_dir); + + cfg80211_debugfs_rdev_add(rdev); + nl80211_notify_wiphy(rdev, NL80211_CMD_NEW_WIPHY); + + /* set up regulatory info */ + wiphy_regulatory_register(wiphy); + + if (wiphy->regulatory_flags & REGULATORY_CUSTOM_REG) { + struct regulatory_request request; + + request.wiphy_idx = get_wiphy_idx(wiphy); + request.initiator = NL80211_REGDOM_SET_BY_DRIVER; + request.alpha2[0] = '9'; + request.alpha2[1] = '9'; + + nl80211_send_reg_change_event(&request); + } + + /* Check that nobody globally advertises any capabilities they do not + * advertise on all possible interface types. 
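+	 * In other words, for every byte j, each bit set in
+	 * wiphy->extended_capabilities[j] must also be set in the
+	 * corresponding byte of every entry in wiphy->iftype_ext_capab.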
+ */ + if (wiphy->extended_capabilities_len && + wiphy->num_iftype_ext_capab && + wiphy->iftype_ext_capab) { + u8 supported_on_all, j; + const struct wiphy_iftype_ext_capab *capab; + + capab = wiphy->iftype_ext_capab; + for (j = 0; j < wiphy->extended_capabilities_len; j++) { + if (capab[0].extended_capabilities_len > j) + supported_on_all = + capab[0].extended_capabilities[j]; + else + supported_on_all = 0x00; + for (i = 1; i < wiphy->num_iftype_ext_capab; i++) { + if (j >= capab[i].extended_capabilities_len) { + supported_on_all = 0x00; + break; + } + supported_on_all &= + capab[i].extended_capabilities[j]; + } + if (WARN_ON(wiphy->extended_capabilities[j] & + ~supported_on_all)) + break; + } + } + + rdev->wiphy.registered = true; + rtnl_unlock(); + + res = rfkill_register(rdev->wiphy.rfkill); + if (res) { + rfkill_destroy(rdev->wiphy.rfkill); + rdev->wiphy.rfkill = NULL; + wiphy_unregister(&rdev->wiphy); + return res; + } + + return 0; +} +EXPORT_SYMBOL(wiphy_register); + +void wiphy_rfkill_start_polling(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + if (!rdev->ops->rfkill_poll) + return; + rdev->rfkill_ops.poll = cfg80211_rfkill_poll; + rfkill_resume_polling(wiphy->rfkill); +} +EXPORT_SYMBOL(wiphy_rfkill_start_polling); + +void cfg80211_process_wiphy_works(struct cfg80211_registered_device *rdev, + struct wiphy_work *end) +{ + unsigned int runaway_limit = 100; + unsigned long flags; + + lockdep_assert_held(&rdev->wiphy.mtx); + + spin_lock_irqsave(&rdev->wiphy_work_lock, flags); + while (!list_empty(&rdev->wiphy_work_list)) { + struct wiphy_work *wk; + + wk = list_first_entry(&rdev->wiphy_work_list, + struct wiphy_work, entry); + list_del_init(&wk->entry); + spin_unlock_irqrestore(&rdev->wiphy_work_lock, flags); + + wk->func(&rdev->wiphy, wk); + + spin_lock_irqsave(&rdev->wiphy_work_lock, flags); + + if (wk == end) + break; + + if (WARN_ON(--runaway_limit == 0)) + INIT_LIST_HEAD(&rdev->wiphy_work_list); + } + spin_unlock_irqrestore(&rdev->wiphy_work_lock, flags); +} + +void wiphy_unregister(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + wait_event(rdev->dev_wait, ({ + int __count; + wiphy_lock(&rdev->wiphy); + __count = rdev->opencount; + wiphy_unlock(&rdev->wiphy); + __count == 0; })); + + if (rdev->wiphy.rfkill) + rfkill_unregister(rdev->wiphy.rfkill); + + rtnl_lock(); + wiphy_lock(&rdev->wiphy); + nl80211_notify_wiphy(rdev, NL80211_CMD_DEL_WIPHY); + rdev->wiphy.registered = false; + + WARN_ON(!list_empty(&rdev->wiphy.wdev_list)); + + /* + * First remove the hardware from everywhere, this makes + * it impossible to find from userspace. 
+ */ + debugfs_remove_recursive(rdev->wiphy.debugfsdir); + list_del_rcu(&rdev->list); + synchronize_rcu(); + + /* + * If this device got a regulatory hint tell core its + * free to listen now to a new shiny device regulatory hint + */ + wiphy_regulatory_deregister(wiphy); + + cfg80211_rdev_list_generation++; + device_del(&rdev->wiphy.dev); + +#ifdef CONFIG_PM + if (rdev->wiphy.wowlan_config && rdev->ops->set_wakeup) + rdev_set_wakeup(rdev, false); +#endif + + /* surely nothing is reachable now, clean up work */ + cfg80211_process_wiphy_works(rdev, NULL); + wiphy_unlock(&rdev->wiphy); + rtnl_unlock(); + + /* this has nothing to do now but make sure it's gone */ + cancel_work_sync(&rdev->wiphy_work); + + flush_work(&rdev->scan_done_wk); + cancel_work_sync(&rdev->conn_work); + flush_work(&rdev->event_work); + cancel_delayed_work_sync(&rdev->dfs_update_channels_wk); + cancel_delayed_work_sync(&rdev->background_cac_done_wk); + flush_work(&rdev->destroy_work); + flush_work(&rdev->sched_scan_stop_wk); + flush_work(&rdev->propagate_radar_detect_wk); + flush_work(&rdev->propagate_cac_done_wk); + flush_work(&rdev->mgmt_registrations_update_wk); + flush_work(&rdev->background_cac_abort_wk); + + cfg80211_rdev_free_wowlan(rdev); + cfg80211_rdev_free_coalesce(rdev); +} +EXPORT_SYMBOL(wiphy_unregister); + +void cfg80211_dev_free(struct cfg80211_registered_device *rdev) +{ + struct cfg80211_internal_bss *scan, *tmp; + struct cfg80211_beacon_registration *reg, *treg; + rfkill_destroy(rdev->wiphy.rfkill); + list_for_each_entry_safe(reg, treg, &rdev->beacon_registrations, list) { + list_del(®->list); + kfree(reg); + } + list_for_each_entry_safe(scan, tmp, &rdev->bss_list, list) + cfg80211_put_bss(&rdev->wiphy, &scan->pub); + mutex_destroy(&rdev->wiphy.mtx); + + /* + * The 'regd' can only be non-NULL if we never finished + * initializing the wiphy and thus never went through the + * unregister path - e.g. in failure scenarios. Thus, it + * cannot have been visible to anyone if non-NULL, so we + * can just free it here. 
+ */ + kfree(rcu_dereference_raw(rdev->wiphy.regd)); + + kfree(rdev); +} + +void wiphy_free(struct wiphy *wiphy) +{ + put_device(&wiphy->dev); +} +EXPORT_SYMBOL(wiphy_free); + +void wiphy_rfkill_set_hw_state_reason(struct wiphy *wiphy, bool blocked, + enum rfkill_hard_block_reasons reason) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + if (rfkill_set_hw_state_reason(wiphy->rfkill, blocked, reason)) + schedule_work(&rdev->rfkill_block); +} +EXPORT_SYMBOL(wiphy_rfkill_set_hw_state_reason); + +static void _cfg80211_unregister_wdev(struct wireless_dev *wdev, + bool unregister_netdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_cqm_config *cqm_config; + unsigned int link_id; + + ASSERT_RTNL(); + lockdep_assert_held(&rdev->wiphy.mtx); + + flush_work(&wdev->pmsr_free_wk); + + nl80211_notify_iface(rdev, wdev, NL80211_CMD_DEL_INTERFACE); + + wdev->registered = false; + + if (wdev->netdev) { + sysfs_remove_link(&wdev->netdev->dev.kobj, "phy80211"); + if (unregister_netdev) + unregister_netdevice(wdev->netdev); + } + + list_del_rcu(&wdev->list); + synchronize_net(); + rdev->devlist_generation++; + + cfg80211_mlme_purge_registrations(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_P2P_DEVICE: + cfg80211_stop_p2p_device(rdev, wdev); + break; + case NL80211_IFTYPE_NAN: + cfg80211_stop_nan(rdev, wdev); + break; + default: + break; + } + +#ifdef CONFIG_CFG80211_WEXT + kfree_sensitive(wdev->wext.keys); + wdev->wext.keys = NULL; +#endif + wiphy_work_cancel(wdev->wiphy, &wdev->cqm_rssi_work); + /* deleted from the list, so can't be found from nl80211 any more */ + cqm_config = rcu_access_pointer(wdev->cqm_config); + kfree_rcu(cqm_config, rcu_head); + + /* + * Ensure that all events have been processed and + * freed. 
+ */ + cfg80211_process_wdev_events(wdev); + + if (wdev->iftype == NL80211_IFTYPE_STATION || + wdev->iftype == NL80211_IFTYPE_P2P_CLIENT) { + for (link_id = 0; link_id < ARRAY_SIZE(wdev->links); link_id++) { + struct cfg80211_internal_bss *curbss; + + curbss = wdev->links[link_id].client.current_bss; + + if (WARN_ON(curbss)) { + cfg80211_unhold_bss(curbss); + cfg80211_put_bss(wdev->wiphy, &curbss->pub); + wdev->links[link_id].client.current_bss = NULL; + } + } + } + + wdev->connected = false; +} + +void cfg80211_unregister_wdev(struct wireless_dev *wdev) +{ + _cfg80211_unregister_wdev(wdev, true); +} +EXPORT_SYMBOL(cfg80211_unregister_wdev); + +static const struct device_type wiphy_type = { + .name = "wlan", +}; + +void cfg80211_update_iface_num(struct cfg80211_registered_device *rdev, + enum nl80211_iftype iftype, int num) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + rdev->num_running_ifaces += num; + if (iftype == NL80211_IFTYPE_MONITOR) + rdev->num_running_monitor_ifaces += num; +} + +void __cfg80211_leave(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + struct net_device *dev = wdev->netdev; + struct cfg80211_sched_scan_request *pos, *tmp; + + lockdep_assert_held(&rdev->wiphy.mtx); + ASSERT_WDEV_LOCK(wdev); + + cfg80211_pmsr_wdev_down(wdev); + + cfg80211_stop_background_radar_detection(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + __cfg80211_leave_ibss(rdev, dev, true); + break; + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_STATION: + list_for_each_entry_safe(pos, tmp, &rdev->sched_scan_req_list, + list) { + if (dev == pos->dev) + cfg80211_stop_sched_scan_req(rdev, pos, false); + } + +#ifdef CONFIG_CFG80211_WEXT + kfree(wdev->wext.ie); + wdev->wext.ie = NULL; + wdev->wext.ie_len = 0; + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_AUTOMATIC; +#endif + cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, true); + break; + case NL80211_IFTYPE_MESH_POINT: + __cfg80211_leave_mesh(rdev, dev); + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + __cfg80211_stop_ap(rdev, dev, -1, true); + break; + case NL80211_IFTYPE_OCB: + __cfg80211_leave_ocb(rdev, dev); + break; + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_NAN: + /* cannot happen, has no netdev */ + break; + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MONITOR: + /* nothing to do */ + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NL80211_IFTYPE_WDS: + case NUM_NL80211_IFTYPES: + /* invalid */ + break; + } +} + +void cfg80211_leave(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + wdev_lock(wdev); + __cfg80211_leave(rdev, wdev); + wdev_unlock(wdev); +} + +void cfg80211_stop_iface(struct wiphy *wiphy, struct wireless_dev *wdev, + gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_event *ev; + unsigned long flags; + + trace_cfg80211_stop_iface(wiphy, wdev); + + ev = kzalloc(sizeof(*ev), gfp); + if (!ev) + return; + + ev->type = EVENT_STOPPED; + + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); +} +EXPORT_SYMBOL(cfg80211_stop_iface); + +void cfg80211_init_wdev(struct wireless_dev *wdev) +{ + mutex_init(&wdev->mtx); + INIT_LIST_HEAD(&wdev->event_list); + spin_lock_init(&wdev->event_lock); + INIT_LIST_HEAD(&wdev->mgmt_registrations); + INIT_LIST_HEAD(&wdev->pmsr_list); + spin_lock_init(&wdev->pmsr_lock); + 
INIT_WORK(&wdev->pmsr_free_wk, cfg80211_pmsr_free_wk); + +#ifdef CONFIG_CFG80211_WEXT + wdev->wext.default_key = -1; + wdev->wext.default_mgmt_key = -1; + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_AUTOMATIC; +#endif + + wiphy_work_init(&wdev->cqm_rssi_work, cfg80211_cqm_rssi_notify_work); + + if (wdev->wiphy->flags & WIPHY_FLAG_PS_ON_BY_DEFAULT) + wdev->ps = true; + else + wdev->ps = false; + /* allow mac80211 to determine the timeout */ + wdev->ps_timeout = -1; + + if ((wdev->iftype == NL80211_IFTYPE_STATION || + wdev->iftype == NL80211_IFTYPE_P2P_CLIENT || + wdev->iftype == NL80211_IFTYPE_ADHOC) && !wdev->use_4addr) + wdev->netdev->priv_flags |= IFF_DONT_BRIDGE; + + INIT_WORK(&wdev->disconnect_wk, cfg80211_autodisconnect_wk); +} + +void cfg80211_register_wdev(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + ASSERT_RTNL(); + lockdep_assert_held(&rdev->wiphy.mtx); + + /* + * We get here also when the interface changes network namespaces, + * as it's registered into the new one, but we don't want it to + * change ID in that case. Checking if the ID is already assigned + * works, because 0 isn't considered a valid ID and the memory is + * 0-initialized. + */ + if (!wdev->identifier) + wdev->identifier = ++rdev->wdev_id; + list_add_rcu(&wdev->list, &rdev->wiphy.wdev_list); + rdev->devlist_generation++; + wdev->registered = true; + + if (wdev->netdev && + sysfs_create_link(&wdev->netdev->dev.kobj, &rdev->wiphy.dev.kobj, + "phy80211")) + pr_err("failed to add phy80211 symlink to netdev!\n"); + + nl80211_notify_iface(rdev, wdev, NL80211_CMD_NEW_INTERFACE); +} + +int cfg80211_register_netdevice(struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev; + int ret; + + ASSERT_RTNL(); + + if (WARN_ON(!wdev)) + return -EINVAL; + + rdev = wiphy_to_rdev(wdev->wiphy); + + lockdep_assert_held(&rdev->wiphy.mtx); + + /* we'll take care of this */ + wdev->registered = true; + wdev->registering = true; + ret = register_netdevice(dev); + if (ret) + goto out; + + cfg80211_register_wdev(rdev, wdev); + ret = 0; +out: + wdev->registering = false; + if (ret) + wdev->registered = false; + return ret; +} +EXPORT_SYMBOL(cfg80211_register_netdevice); + +static int cfg80211_netdev_notifier_call(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev; + struct cfg80211_sched_scan_request *pos, *tmp; + + if (!wdev) + return NOTIFY_DONE; + + rdev = wiphy_to_rdev(wdev->wiphy); + + WARN_ON(wdev->iftype == NL80211_IFTYPE_UNSPECIFIED); + + switch (state) { + case NETDEV_POST_INIT: + SET_NETDEV_DEVTYPE(dev, &wiphy_type); + wdev->netdev = dev; + /* can only change netns with wiphy */ + dev->features |= NETIF_F_NETNS_LOCAL; + + cfg80211_init_wdev(wdev); + break; + case NETDEV_REGISTER: + if (!wdev->registered) { + wiphy_lock(&rdev->wiphy); + cfg80211_register_wdev(rdev, wdev); + wiphy_unlock(&rdev->wiphy); + } + break; + case NETDEV_UNREGISTER: + /* + * It is possible to get NETDEV_UNREGISTER multiple times, + * so check wdev->registered. 
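+		 * wdev->registering is checked as well so that a wdev whose
+		 * registration is still in progress in
+		 * cfg80211_register_netdevice() is not torn down from here.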
+ */ + if (wdev->registered && !wdev->registering) { + wiphy_lock(&rdev->wiphy); + _cfg80211_unregister_wdev(wdev, false); + wiphy_unlock(&rdev->wiphy); + } + break; + case NETDEV_GOING_DOWN: + wiphy_lock(&rdev->wiphy); + cfg80211_leave(rdev, wdev); + cfg80211_remove_links(wdev); + wiphy_unlock(&rdev->wiphy); + /* since we just did cfg80211_leave() nothing to do there */ + cancel_work_sync(&wdev->disconnect_wk); + break; + case NETDEV_DOWN: + wiphy_lock(&rdev->wiphy); + cfg80211_update_iface_num(rdev, wdev->iftype, -1); + if (rdev->scan_req && rdev->scan_req->wdev == wdev) { + if (WARN_ON(!rdev->scan_req->notified && + (!rdev->int_scan_req || + !rdev->int_scan_req->notified))) + rdev->scan_req->info.aborted = true; + ___cfg80211_scan_done(rdev, false); + } + + list_for_each_entry_safe(pos, tmp, + &rdev->sched_scan_req_list, list) { + if (WARN_ON(pos->dev == wdev->netdev)) + cfg80211_stop_sched_scan_req(rdev, pos, false); + } + + rdev->opencount--; + wiphy_unlock(&rdev->wiphy); + wake_up(&rdev->dev_wait); + break; + case NETDEV_UP: + wiphy_lock(&rdev->wiphy); + cfg80211_update_iface_num(rdev, wdev->iftype, 1); + wdev_lock(wdev); + switch (wdev->iftype) { +#ifdef CONFIG_CFG80211_WEXT + case NL80211_IFTYPE_ADHOC: + cfg80211_ibss_wext_join(rdev, wdev); + break; + case NL80211_IFTYPE_STATION: + cfg80211_mgd_wext_connect(rdev, wdev); + break; +#endif +#ifdef CONFIG_MAC80211_MESH + case NL80211_IFTYPE_MESH_POINT: + { + /* backward compat code... */ + struct mesh_setup setup; + memcpy(&setup, &default_mesh_setup, + sizeof(setup)); + /* back compat only needed for mesh_id */ + setup.mesh_id = wdev->u.mesh.id; + setup.mesh_id_len = wdev->u.mesh.id_up_len; + if (wdev->u.mesh.id_up_len) + __cfg80211_join_mesh(rdev, dev, + &setup, + &default_mesh_config); + break; + } +#endif + default: + break; + } + wdev_unlock(wdev); + rdev->opencount++; + + /* + * Configure power management to the driver here so that its + * correctly set also after interface type changes etc. 
+ */ + if ((wdev->iftype == NL80211_IFTYPE_STATION || + wdev->iftype == NL80211_IFTYPE_P2P_CLIENT) && + rdev->ops->set_power_mgmt && + rdev_set_power_mgmt(rdev, dev, wdev->ps, + wdev->ps_timeout)) { + /* assume this means it's off */ + wdev->ps = false; + } + wiphy_unlock(&rdev->wiphy); + break; + case NETDEV_PRE_UP: + if (!cfg80211_iftype_allowed(wdev->wiphy, wdev->iftype, + wdev->use_4addr, 0)) + return notifier_from_errno(-EOPNOTSUPP); + + if (rfkill_blocked(rdev->wiphy.rfkill)) + return notifier_from_errno(-ERFKILL); + break; + default: + return NOTIFY_DONE; + } + + wireless_nlevent_flush(); + + return NOTIFY_OK; +} + +static struct notifier_block cfg80211_netdev_notifier = { + .notifier_call = cfg80211_netdev_notifier_call, +}; + +static void __net_exit cfg80211_pernet_exit(struct net *net) +{ + struct cfg80211_registered_device *rdev; + + rtnl_lock(); + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (net_eq(wiphy_net(&rdev->wiphy), net)) + WARN_ON(cfg80211_switch_netns(rdev, &init_net)); + } + rtnl_unlock(); +} + +static struct pernet_operations cfg80211_pernet_ops = { + .exit = cfg80211_pernet_exit, +}; + +void wiphy_work_queue(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + unsigned long flags; + + spin_lock_irqsave(&rdev->wiphy_work_lock, flags); + if (list_empty(&work->entry)) + list_add_tail(&work->entry, &rdev->wiphy_work_list); + spin_unlock_irqrestore(&rdev->wiphy_work_lock, flags); + + queue_work(system_unbound_wq, &rdev->wiphy_work); +} +EXPORT_SYMBOL_GPL(wiphy_work_queue); + +void wiphy_work_cancel(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + unsigned long flags; + + lockdep_assert_held(&wiphy->mtx); + + spin_lock_irqsave(&rdev->wiphy_work_lock, flags); + if (!list_empty(&work->entry)) + list_del_init(&work->entry); + spin_unlock_irqrestore(&rdev->wiphy_work_lock, flags); +} +EXPORT_SYMBOL_GPL(wiphy_work_cancel); + +void wiphy_work_flush(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + unsigned long flags; + bool run; + + spin_lock_irqsave(&rdev->wiphy_work_lock, flags); + run = !work || !list_empty(&work->entry); + spin_unlock_irqrestore(&rdev->wiphy_work_lock, flags); + + if (run) + cfg80211_process_wiphy_works(rdev, work); +} +EXPORT_SYMBOL_GPL(wiphy_work_flush); + +void wiphy_delayed_work_timer(struct timer_list *t) +{ + struct wiphy_delayed_work *dwork = from_timer(dwork, t, timer); + + wiphy_work_queue(dwork->wiphy, &dwork->work); +} +EXPORT_SYMBOL(wiphy_delayed_work_timer); + +void wiphy_delayed_work_queue(struct wiphy *wiphy, + struct wiphy_delayed_work *dwork, + unsigned long delay) +{ + if (!delay) { + wiphy_work_queue(wiphy, &dwork->work); + return; + } + + dwork->wiphy = wiphy; + mod_timer(&dwork->timer, jiffies + delay); +} +EXPORT_SYMBOL_GPL(wiphy_delayed_work_queue); + +void wiphy_delayed_work_cancel(struct wiphy *wiphy, + struct wiphy_delayed_work *dwork) +{ + lockdep_assert_held(&wiphy->mtx); + + del_timer_sync(&dwork->timer); + wiphy_work_cancel(wiphy, &dwork->work); +} +EXPORT_SYMBOL_GPL(wiphy_delayed_work_cancel); + +void wiphy_delayed_work_flush(struct wiphy *wiphy, + struct wiphy_delayed_work *dwork) +{ + lockdep_assert_held(&wiphy->mtx); + + del_timer_sync(&dwork->timer); + wiphy_work_flush(wiphy, &dwork->work); +} +EXPORT_SYMBOL_GPL(wiphy_delayed_work_flush); + +static int __init cfg80211_init(void) +{ + int err; + + err = 
register_pernet_device(&cfg80211_pernet_ops); + if (err) + goto out_fail_pernet; + + err = wiphy_sysfs_init(); + if (err) + goto out_fail_sysfs; + + err = register_netdevice_notifier(&cfg80211_netdev_notifier); + if (err) + goto out_fail_notifier; + + err = nl80211_init(); + if (err) + goto out_fail_nl80211; + + ieee80211_debugfs_dir = debugfs_create_dir("ieee80211", NULL); + + err = regulatory_init(); + if (err) + goto out_fail_reg; + + cfg80211_wq = alloc_ordered_workqueue("cfg80211", WQ_MEM_RECLAIM); + if (!cfg80211_wq) { + err = -ENOMEM; + goto out_fail_wq; + } + + return 0; + +out_fail_wq: + regulatory_exit(); +out_fail_reg: + debugfs_remove(ieee80211_debugfs_dir); + nl80211_exit(); +out_fail_nl80211: + unregister_netdevice_notifier(&cfg80211_netdev_notifier); +out_fail_notifier: + wiphy_sysfs_exit(); +out_fail_sysfs: + unregister_pernet_device(&cfg80211_pernet_ops); +out_fail_pernet: + return err; +} +fs_initcall(cfg80211_init); + +static void __exit cfg80211_exit(void) +{ + debugfs_remove(ieee80211_debugfs_dir); + nl80211_exit(); + unregister_netdevice_notifier(&cfg80211_netdev_notifier); + wiphy_sysfs_exit(); + regulatory_exit(); + unregister_pernet_device(&cfg80211_pernet_ops); + destroy_workqueue(cfg80211_wq); +} +module_exit(cfg80211_exit); diff --git a/net/wireless/core.h b/net/wireless/core.h new file mode 100644 index 000000000..ee980965a --- /dev/null +++ b/net/wireless/core.h @@ -0,0 +1,582 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Wireless configuration interface internals. + * + * Copyright 2006-2010 Johannes Berg + * Copyright (C) 2018-2022 Intel Corporation + */ +#ifndef __NET_WIRELESS_CORE_H +#define __NET_WIRELESS_CORE_H +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "reg.h" + + +#define WIPHY_IDX_INVALID -1 + +struct cfg80211_registered_device { + const struct cfg80211_ops *ops; + struct list_head list; + + /* rfkill support */ + struct rfkill_ops rfkill_ops; + struct work_struct rfkill_block; + + /* ISO / IEC 3166 alpha2 for which this device is receiving + * country IEs on, this can help disregard country IEs from APs + * on the same alpha2 quickly. The alpha2 may differ from + * cfg80211_regdomain's alpha2 when an intersection has occurred. + * If the AP is reconfigured this can also be used to tell us if + * the country on the country IE changed. */ + char country_ie_alpha2[2]; + + /* + * the driver requests the regulatory core to set this regulatory + * domain as the wiphy's. Only used for %REGULATORY_WIPHY_SELF_MANAGED + * devices using the regulatory_set_wiphy_regd() API + */ + const struct ieee80211_regdomain *requested_regd; + + /* If a Country IE has been received this tells us the environment + * which its telling us its in. 
This defaults to ENVIRON_ANY */ + enum environment_cap env; + + /* wiphy index, internal only */ + int wiphy_idx; + + /* protected by RTNL */ + int devlist_generation, wdev_id; + int opencount; + wait_queue_head_t dev_wait; + + struct list_head beacon_registrations; + spinlock_t beacon_registrations_lock; + + /* protected by RTNL only */ + int num_running_ifaces; + int num_running_monitor_ifaces; + u64 cookie_counter; + + /* BSSes/scanning */ + spinlock_t bss_lock; + struct list_head bss_list; + struct rb_root bss_tree; + u32 bss_generation; + u32 bss_entries; + struct cfg80211_scan_request *scan_req; /* protected by RTNL */ + struct cfg80211_scan_request *int_scan_req; + struct sk_buff *scan_msg; + struct list_head sched_scan_req_list; + time64_t suspend_at; + struct work_struct scan_done_wk; + + struct genl_info *cur_cmd_info; + + struct work_struct conn_work; + struct work_struct event_work; + + struct delayed_work dfs_update_channels_wk; + + struct wireless_dev *background_radar_wdev; + struct cfg80211_chan_def background_radar_chandef; + struct delayed_work background_cac_done_wk; + struct work_struct background_cac_abort_wk; + + /* netlink port which started critical protocol (0 means not started) */ + u32 crit_proto_nlportid; + + struct cfg80211_coalesce *coalesce; + + struct work_struct destroy_work; + struct work_struct sched_scan_stop_wk; + struct work_struct sched_scan_res_wk; + + struct cfg80211_chan_def radar_chandef; + struct work_struct propagate_radar_detect_wk; + + struct cfg80211_chan_def cac_done_chandef; + struct work_struct propagate_cac_done_wk; + + struct work_struct mgmt_registrations_update_wk; + /* lock for all wdev lists */ + spinlock_t mgmt_registrations_lock; + + struct work_struct wiphy_work; + struct list_head wiphy_work_list; + /* protects the list above */ + spinlock_t wiphy_work_lock; + bool suspended; + + /* must be last because of the way we do wiphy_priv(), + * and it should at least be aligned to NETDEV_ALIGN */ + struct wiphy wiphy __aligned(NETDEV_ALIGN); +}; + +static inline +struct cfg80211_registered_device *wiphy_to_rdev(struct wiphy *wiphy) +{ + BUG_ON(!wiphy); + return container_of(wiphy, struct cfg80211_registered_device, wiphy); +} + +static inline void +cfg80211_rdev_free_wowlan(struct cfg80211_registered_device *rdev) +{ +#ifdef CONFIG_PM + int i; + + if (!rdev->wiphy.wowlan_config) + return; + for (i = 0; i < rdev->wiphy.wowlan_config->n_patterns; i++) + kfree(rdev->wiphy.wowlan_config->patterns[i].mask); + kfree(rdev->wiphy.wowlan_config->patterns); + if (rdev->wiphy.wowlan_config->tcp && + rdev->wiphy.wowlan_config->tcp->sock) + sock_release(rdev->wiphy.wowlan_config->tcp->sock); + kfree(rdev->wiphy.wowlan_config->tcp); + kfree(rdev->wiphy.wowlan_config->nd_config); + kfree(rdev->wiphy.wowlan_config); +#endif +} + +static inline u64 cfg80211_assign_cookie(struct cfg80211_registered_device *rdev) +{ + u64 r = ++rdev->cookie_counter; + + if (WARN_ON(r == 0)) + r = ++rdev->cookie_counter; + + return r; +} + +extern struct workqueue_struct *cfg80211_wq; +extern struct list_head cfg80211_rdev_list; +extern int cfg80211_rdev_list_generation; + +struct cfg80211_internal_bss { + struct list_head list; + struct list_head hidden_list; + struct rb_node rbn; + u64 ts_boottime; + unsigned long ts; + unsigned long refcount; + atomic_t hold; + + /* time at the start of the reception of the first octet of the + * timestamp field of the last beacon/probe received for this BSS. + * The time is the TSF of the BSS specified by %parent_bssid. 
+ */ + u64 parent_tsf; + + /* the BSS according to which %parent_tsf is set. This is set to + * the BSS that the interface that requested the scan was connected to + * when the beacon/probe was received. + */ + u8 parent_bssid[ETH_ALEN] __aligned(2); + + /* must be last because of priv member */ + struct cfg80211_bss pub; +}; + +static inline struct cfg80211_internal_bss *bss_from_pub(struct cfg80211_bss *pub) +{ + return container_of(pub, struct cfg80211_internal_bss, pub); +} + +static inline void cfg80211_hold_bss(struct cfg80211_internal_bss *bss) +{ + atomic_inc(&bss->hold); + if (bss->pub.transmitted_bss) { + bss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, pub); + atomic_inc(&bss->hold); + } +} + +static inline void cfg80211_unhold_bss(struct cfg80211_internal_bss *bss) +{ + int r = atomic_dec_return(&bss->hold); + WARN_ON(r < 0); + if (bss->pub.transmitted_bss) { + bss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, pub); + r = atomic_dec_return(&bss->hold); + WARN_ON(r < 0); + } +} + + +struct cfg80211_registered_device *cfg80211_rdev_by_wiphy_idx(int wiphy_idx); +int get_wiphy_idx(struct wiphy *wiphy); + +struct wiphy *wiphy_idx_to_wiphy(int wiphy_idx); + +int cfg80211_switch_netns(struct cfg80211_registered_device *rdev, + struct net *net); + +void cfg80211_init_wdev(struct wireless_dev *wdev); +void cfg80211_register_wdev(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +static inline void wdev_lock(struct wireless_dev *wdev) + __acquires(wdev) +{ + mutex_lock(&wdev->mtx); + __acquire(wdev->mtx); +} + +static inline void wdev_unlock(struct wireless_dev *wdev) + __releases(wdev) +{ + __release(wdev->mtx); + mutex_unlock(&wdev->mtx); +} + +#define ASSERT_WDEV_LOCK(wdev) lockdep_assert_held(&(wdev)->mtx) + +static inline bool cfg80211_has_monitors_only(struct cfg80211_registered_device *rdev) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + return rdev->num_running_ifaces == rdev->num_running_monitor_ifaces && + rdev->num_running_ifaces > 0; +} + +enum cfg80211_event_type { + EVENT_CONNECT_RESULT, + EVENT_ROAMED, + EVENT_DISCONNECTED, + EVENT_IBSS_JOINED, + EVENT_STOPPED, + EVENT_PORT_AUTHORIZED, +}; + +struct cfg80211_event { + struct list_head list; + enum cfg80211_event_type type; + + union { + struct cfg80211_connect_resp_params cr; + struct cfg80211_roam_info rm; + struct { + const u8 *ie; + size_t ie_len; + u16 reason; + bool locally_generated; + } dc; + struct { + u8 bssid[ETH_ALEN]; + struct ieee80211_channel *channel; + } ij; + struct { + u8 bssid[ETH_ALEN]; + } pa; + }; +}; + +struct cfg80211_cached_keys { + struct key_params params[CFG80211_MAX_WEP_KEYS]; + u8 data[CFG80211_MAX_WEP_KEYS][WLAN_KEY_LEN_WEP104]; + int def; +}; + +struct cfg80211_beacon_registration { + struct list_head list; + u32 nlportid; +}; + +struct cfg80211_cqm_config { + struct rcu_head rcu_head; + u32 rssi_hyst; + s32 last_rssi_event_value; + enum nl80211_cqm_rssi_threshold_event last_rssi_event_type; + bool use_range_api; + int n_rssi_thresholds; + s32 rssi_thresholds[]; +}; + +void cfg80211_cqm_rssi_notify_work(struct wiphy *wiphy, + struct wiphy_work *work); + +void cfg80211_destroy_ifaces(struct cfg80211_registered_device *rdev); + +/* free object */ +void cfg80211_dev_free(struct cfg80211_registered_device *rdev); + +int cfg80211_dev_rename(struct cfg80211_registered_device *rdev, + char *newname); + +void ieee80211_set_bitrate_flags(struct wiphy *wiphy); + +void cfg80211_bss_expire(struct cfg80211_registered_device 
*rdev); +void cfg80211_bss_age(struct cfg80211_registered_device *rdev, + unsigned long age_secs); +void cfg80211_update_assoc_bss_entry(struct wireless_dev *wdev, + unsigned int link, + struct ieee80211_channel *channel); + +/* IBSS */ +int __cfg80211_join_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ibss_params *params, + struct cfg80211_cached_keys *connkeys); +void cfg80211_clear_ibss(struct net_device *dev, bool nowext); +int __cfg80211_leave_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool nowext); +int cfg80211_leave_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool nowext); +void __cfg80211_ibss_joined(struct net_device *dev, const u8 *bssid, + struct ieee80211_channel *channel); +int cfg80211_ibss_wext_join(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +/* mesh */ +extern const struct mesh_config default_mesh_config; +extern const struct mesh_setup default_mesh_setup; +int __cfg80211_join_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct mesh_setup *setup, + const struct mesh_config *conf); +int __cfg80211_leave_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev); +int cfg80211_leave_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev); +int cfg80211_set_mesh_channel(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_chan_def *chandef); + +/* OCB */ +int __cfg80211_join_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ocb_setup *setup); +int cfg80211_join_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ocb_setup *setup); +int __cfg80211_leave_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev); +int cfg80211_leave_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev); + +/* AP */ +int __cfg80211_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, int link, + bool notify); +int cfg80211_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, int link, + bool notify); + +/* MLME */ +int cfg80211_mlme_auth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_auth_request *req); +int cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_assoc_request *req); +int cfg80211_mlme_deauth(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *bssid, + const u8 *ie, int ie_len, u16 reason, + bool local_state_change); +int cfg80211_mlme_disassoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *ap_addr, + const u8 *ie, int ie_len, u16 reason, + bool local_state_change); +void cfg80211_mlme_down(struct cfg80211_registered_device *rdev, + struct net_device *dev); +int cfg80211_mlme_register_mgmt(struct wireless_dev *wdev, u32 snd_pid, + u16 frame_type, const u8 *match_data, + int match_len, bool multicast_rx, + struct netlink_ext_ack *extack); +void cfg80211_mgmt_registrations_update_wk(struct work_struct *wk); +void cfg80211_mlme_unregister_socket(struct wireless_dev *wdev, u32 nlpid); +void cfg80211_mlme_purge_registrations(struct wireless_dev *wdev); +int cfg80211_mlme_mgmt_tx(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params, + u64 *cookie); +void cfg80211_oper_and_ht_capa(struct ieee80211_ht_cap *ht_capa, + const struct 
ieee80211_ht_cap *ht_capa_mask); +void cfg80211_oper_and_vht_capa(struct ieee80211_vht_cap *vht_capa, + const struct ieee80211_vht_cap *vht_capa_mask); + +/* SME events */ +int cfg80211_connect(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_connect_params *connect, + struct cfg80211_cached_keys *connkeys, + const u8 *prev_bssid); +void __cfg80211_connect_result(struct net_device *dev, + struct cfg80211_connect_resp_params *params, + bool wextev); +void __cfg80211_disconnected(struct net_device *dev, const u8 *ie, + size_t ie_len, u16 reason, bool from_ap); +int cfg80211_disconnect(struct cfg80211_registered_device *rdev, + struct net_device *dev, u16 reason, + bool wextev); +void __cfg80211_roamed(struct wireless_dev *wdev, + struct cfg80211_roam_info *info); +void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid); +int cfg80211_mgd_wext_connect(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); +void cfg80211_autodisconnect_wk(struct work_struct *work); + +/* SME implementation */ +void cfg80211_conn_work(struct work_struct *work); +void cfg80211_sme_scan_done(struct net_device *dev); +bool cfg80211_sme_rx_assoc_resp(struct wireless_dev *wdev, u16 status); +void cfg80211_sme_rx_auth(struct wireless_dev *wdev, const u8 *buf, size_t len); +void cfg80211_sme_disassoc(struct wireless_dev *wdev); +void cfg80211_sme_deauth(struct wireless_dev *wdev); +void cfg80211_sme_auth_timeout(struct wireless_dev *wdev); +void cfg80211_sme_assoc_timeout(struct wireless_dev *wdev); +void cfg80211_sme_abandon_assoc(struct wireless_dev *wdev); + +/* internal helpers */ +bool cfg80211_supported_cipher_suite(struct wiphy *wiphy, u32 cipher); +bool cfg80211_valid_key_idx(struct cfg80211_registered_device *rdev, + int key_idx, bool pairwise); +int cfg80211_validate_key_settings(struct cfg80211_registered_device *rdev, + struct key_params *params, int key_idx, + bool pairwise, const u8 *mac_addr); +void __cfg80211_scan_done(struct work_struct *wk); +void ___cfg80211_scan_done(struct cfg80211_registered_device *rdev, + bool send_message); +void cfg80211_add_sched_scan_req(struct cfg80211_registered_device *rdev, + struct cfg80211_sched_scan_request *req); +int cfg80211_sched_scan_req_possible(struct cfg80211_registered_device *rdev, + bool want_multi); +void cfg80211_sched_scan_results_wk(struct work_struct *work); +int cfg80211_stop_sched_scan_req(struct cfg80211_registered_device *rdev, + struct cfg80211_sched_scan_request *req, + bool driver_initiated); +int __cfg80211_stop_sched_scan(struct cfg80211_registered_device *rdev, + u64 reqid, bool driver_initiated); +void cfg80211_upload_connect_keys(struct wireless_dev *wdev); +int cfg80211_change_iface(struct cfg80211_registered_device *rdev, + struct net_device *dev, enum nl80211_iftype ntype, + struct vif_params *params); +void cfg80211_process_rdev_events(struct cfg80211_registered_device *rdev); +void cfg80211_process_wiphy_works(struct cfg80211_registered_device *rdev, + struct wiphy_work *end); +void cfg80211_process_wdev_events(struct wireless_dev *wdev); + +bool cfg80211_does_bw_fit_range(const struct ieee80211_freq_range *freq_range, + u32 center_freq_khz, u32 bw_khz); + +int cfg80211_scan(struct cfg80211_registered_device *rdev); + +extern struct work_struct cfg80211_disconnect_work; + +/** + * cfg80211_chandef_dfs_usable - checks if chandef is DFS usable + * @wiphy: the wiphy to validate against + * @chandef: the channel definition to check + * + * Checks if chandef is usable 
and we can/need start CAC on such channel. + * + * Return: true if all channels available and at least + * one channel requires CAC (NL80211_DFS_USABLE) + */ +bool cfg80211_chandef_dfs_usable(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef); + +void cfg80211_set_dfs_state(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef, + enum nl80211_dfs_state dfs_state); + +void cfg80211_dfs_channels_update_work(struct work_struct *work); + +unsigned int +cfg80211_chandef_dfs_cac_time(struct wiphy *wiphy, + const struct cfg80211_chan_def *chandef); + +void cfg80211_sched_dfs_chan_update(struct cfg80211_registered_device *rdev); + +int +cfg80211_start_background_radar_detection(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_chan_def *chandef); + +void cfg80211_stop_background_radar_detection(struct wireless_dev *wdev); + +void cfg80211_background_cac_done_wk(struct work_struct *work); + +void cfg80211_background_cac_abort_wk(struct work_struct *work); + +bool cfg80211_any_wiphy_oper_chan(struct wiphy *wiphy, + struct ieee80211_channel *chan); + +bool cfg80211_beaconing_iface_active(struct wireless_dev *wdev); + +bool cfg80211_is_sub_chan(struct cfg80211_chan_def *chandef, + struct ieee80211_channel *chan, + bool primary_only); +bool cfg80211_wdev_on_sub_chan(struct wireless_dev *wdev, + struct ieee80211_channel *chan, + bool primary_only); + +static inline unsigned int elapsed_jiffies_msecs(unsigned long start) +{ + unsigned long end = jiffies; + + if (end >= start) + return jiffies_to_msecs(end - start); + + return jiffies_to_msecs(end + (ULONG_MAX - start) + 1); +} + +int cfg80211_set_monitor_channel(struct cfg80211_registered_device *rdev, + struct cfg80211_chan_def *chandef); + +int ieee80211_get_ratemask(struct ieee80211_supported_band *sband, + const u8 *rates, unsigned int n_rates, + u32 *mask); + +int cfg80211_validate_beacon_int(struct cfg80211_registered_device *rdev, + enum nl80211_iftype iftype, u32 beacon_int); + +void cfg80211_update_iface_num(struct cfg80211_registered_device *rdev, + enum nl80211_iftype iftype, int num); + +void __cfg80211_leave(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); +void cfg80211_leave(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +void cfg80211_stop_p2p_device(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +void cfg80211_stop_nan(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +struct cfg80211_internal_bss * +cfg80211_bss_update(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *tmp, + bool signal_valid, unsigned long ts); +#ifdef CONFIG_CFG80211_DEVELOPER_WARNINGS +#define CFG80211_DEV_WARN_ON(cond) WARN_ON(cond) +#else +/* + * Trick to enable using it as a condition, + * and also not give a warning when it's + * not used that way. 
+ */ +#define CFG80211_DEV_WARN_ON(cond) ({bool __r = (cond); __r; }) +#endif + +void cfg80211_release_pmsr(struct wireless_dev *wdev, u32 portid); +void cfg80211_pmsr_wdev_down(struct wireless_dev *wdev); +void cfg80211_pmsr_free_wk(struct work_struct *work); + +void cfg80211_remove_link(struct wireless_dev *wdev, unsigned int link_id); +void cfg80211_remove_links(struct wireless_dev *wdev); +int cfg80211_remove_virtual_intf(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); + +#endif /* __NET_WIRELESS_CORE_H */ diff --git a/net/wireless/debugfs.c b/net/wireless/debugfs.c new file mode 100644 index 000000000..0878b1628 --- /dev/null +++ b/net/wireless/debugfs.c @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * cfg80211 debugfs + * + * Copyright 2009 Luis R. Rodriguez + * Copyright 2007 Johannes Berg + */ + +#include +#include "core.h" +#include "debugfs.h" + +#define DEBUGFS_READONLY_FILE(name, buflen, fmt, value...) \ +static ssize_t name## _read(struct file *file, char __user *userbuf, \ + size_t count, loff_t *ppos) \ +{ \ + struct wiphy *wiphy = file->private_data; \ + char buf[buflen]; \ + int res; \ + \ + res = scnprintf(buf, buflen, fmt "\n", ##value); \ + return simple_read_from_buffer(userbuf, count, ppos, buf, res); \ +} \ + \ +static const struct file_operations name## _ops = { \ + .read = name## _read, \ + .open = simple_open, \ + .llseek = generic_file_llseek, \ +} + +DEBUGFS_READONLY_FILE(rts_threshold, 20, "%d", + wiphy->rts_threshold); +DEBUGFS_READONLY_FILE(fragmentation_threshold, 20, "%d", + wiphy->frag_threshold); +DEBUGFS_READONLY_FILE(short_retry_limit, 20, "%d", + wiphy->retry_short); +DEBUGFS_READONLY_FILE(long_retry_limit, 20, "%d", + wiphy->retry_long); + +static int ht_print_chan(struct ieee80211_channel *chan, + char *buf, int buf_size, int offset) +{ + if (WARN_ON(offset > buf_size)) + return 0; + + if (chan->flags & IEEE80211_CHAN_DISABLED) + return scnprintf(buf + offset, + buf_size - offset, + "%d Disabled\n", + chan->center_freq); + + return scnprintf(buf + offset, + buf_size - offset, + "%d HT40 %c%c\n", + chan->center_freq, + (chan->flags & IEEE80211_CHAN_NO_HT40MINUS) ? + ' ' : '-', + (chan->flags & IEEE80211_CHAN_NO_HT40PLUS) ? 
+ ' ' : '+'); +} + +static ssize_t ht40allow_map_read(struct file *file, + char __user *user_buf, + size_t count, loff_t *ppos) +{ + struct wiphy *wiphy = file->private_data; + char *buf; + unsigned int offset = 0, buf_size = PAGE_SIZE, i; + enum nl80211_band band; + struct ieee80211_supported_band *sband; + ssize_t r; + + buf = kzalloc(buf_size, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + sband = wiphy->bands[band]; + if (!sband) + continue; + for (i = 0; i < sband->n_channels; i++) + offset += ht_print_chan(&sband->channels[i], + buf, buf_size, offset); + } + + r = simple_read_from_buffer(user_buf, count, ppos, buf, offset); + + kfree(buf); + + return r; +} + +static const struct file_operations ht40allow_map_ops = { + .read = ht40allow_map_read, + .open = simple_open, + .llseek = default_llseek, +}; + +#define DEBUGFS_ADD(name) \ + debugfs_create_file(#name, 0444, phyd, &rdev->wiphy, &name## _ops) + +void cfg80211_debugfs_rdev_add(struct cfg80211_registered_device *rdev) +{ + struct dentry *phyd = rdev->wiphy.debugfsdir; + + DEBUGFS_ADD(rts_threshold); + DEBUGFS_ADD(fragmentation_threshold); + DEBUGFS_ADD(short_retry_limit); + DEBUGFS_ADD(long_retry_limit); + DEBUGFS_ADD(ht40allow_map); +} diff --git a/net/wireless/debugfs.h b/net/wireless/debugfs.h new file mode 100644 index 000000000..a8a135d94 --- /dev/null +++ b/net/wireless/debugfs.h @@ -0,0 +1,12 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __CFG80211_DEBUGFS_H +#define __CFG80211_DEBUGFS_H + +#ifdef CONFIG_CFG80211_DEBUGFS +void cfg80211_debugfs_rdev_add(struct cfg80211_registered_device *rdev); +#else +static inline +void cfg80211_debugfs_rdev_add(struct cfg80211_registered_device *rdev) {} +#endif + +#endif /* __CFG80211_DEBUGFS_H */ diff --git a/net/wireless/ethtool.c b/net/wireless/ethtool.c new file mode 100644 index 000000000..2613d6ac0 --- /dev/null +++ b/net/wireless/ethtool.c @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include "core.h" +#include "rdev-ops.h" + +void cfg80211_get_drvinfo(struct net_device *dev, struct ethtool_drvinfo *info) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct device *pdev = wiphy_dev(wdev->wiphy); + + if (pdev->driver) + strscpy(info->driver, pdev->driver->name, + sizeof(info->driver)); + else + strscpy(info->driver, "N/A", sizeof(info->driver)); + + strscpy(info->version, init_utsname()->release, sizeof(info->version)); + + if (wdev->wiphy->fw_version[0]) + strscpy(info->fw_version, wdev->wiphy->fw_version, + sizeof(info->fw_version)); + else + strscpy(info->fw_version, "N/A", sizeof(info->fw_version)); + + strscpy(info->bus_info, dev_name(wiphy_dev(wdev->wiphy)), + sizeof(info->bus_info)); +} +EXPORT_SYMBOL(cfg80211_get_drvinfo); diff --git a/net/wireless/ibss.c b/net/wireless/ibss.c new file mode 100644 index 000000000..edd062f10 --- /dev/null +++ b/net/wireless/ibss.c @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Some IBSS support code for cfg80211. 
+ * + * Copyright 2009 Johannes Berg + * Copyright (C) 2020-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include "wext-compat.h" +#include "nl80211.h" +#include "rdev-ops.h" + + +void __cfg80211_ibss_joined(struct net_device *dev, const u8 *bssid, + struct ieee80211_channel *channel) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_bss *bss; +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; +#endif + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return; + + if (!wdev->u.ibss.ssid_len) + return; + + bss = cfg80211_get_bss(wdev->wiphy, channel, bssid, NULL, 0, + IEEE80211_BSS_TYPE_IBSS, IEEE80211_PRIVACY_ANY); + + if (WARN_ON(!bss)) + return; + + if (wdev->u.ibss.current_bss) { + cfg80211_unhold_bss(wdev->u.ibss.current_bss); + cfg80211_put_bss(wdev->wiphy, &wdev->u.ibss.current_bss->pub); + } + + cfg80211_hold_bss(bss_from_pub(bss)); + wdev->u.ibss.current_bss = bss_from_pub(bss); + + if (!(wdev->wiphy->flags & WIPHY_FLAG_HAS_STATIC_WEP)) + cfg80211_upload_connect_keys(wdev); + + nl80211_send_ibss_bssid(wiphy_to_rdev(wdev->wiphy), dev, bssid, + GFP_KERNEL); +#ifdef CONFIG_CFG80211_WEXT + memset(&wrqu, 0, sizeof(wrqu)); + memcpy(wrqu.ap_addr.sa_data, bssid, ETH_ALEN); + wireless_send_event(dev, SIOCGIWAP, &wrqu, NULL); +#endif +} + +void cfg80211_ibss_joined(struct net_device *dev, const u8 *bssid, + struct ieee80211_channel *channel, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_event *ev; + unsigned long flags; + + trace_cfg80211_ibss_joined(dev, bssid, channel); + + if (WARN_ON(!channel)) + return; + + ev = kzalloc(sizeof(*ev), gfp); + if (!ev) + return; + + ev->type = EVENT_IBSS_JOINED; + memcpy(ev->ij.bssid, bssid, ETH_ALEN); + ev->ij.channel = channel; + + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); +} +EXPORT_SYMBOL(cfg80211_ibss_joined); + +int __cfg80211_join_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ibss_params *params, + struct cfg80211_cached_keys *connkeys) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + lockdep_assert_held(&rdev->wiphy.mtx); + ASSERT_WDEV_LOCK(wdev); + + if (wdev->u.ibss.ssid_len) + return -EALREADY; + + if (!params->basic_rates) { + /* + * If no rates were explicitly configured, + * use the mandatory rate set for 11b or + * 11a for maximum compatibility. 
+ */ + struct ieee80211_supported_band *sband; + enum nl80211_band band; + u32 flag; + int j; + + band = params->chandef.chan->band; + if (band == NL80211_BAND_5GHZ || + band == NL80211_BAND_6GHZ) + flag = IEEE80211_RATE_MANDATORY_A; + else + flag = IEEE80211_RATE_MANDATORY_B; + + sband = rdev->wiphy.bands[band]; + for (j = 0; j < sband->n_bitrates; j++) { + if (sband->bitrates[j].flags & flag) + params->basic_rates |= BIT(j); + } + } + + if (WARN_ON(connkeys && connkeys->def < 0)) + return -EINVAL; + + if (WARN_ON(wdev->connect_keys)) + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = connkeys; + + wdev->u.ibss.chandef = params->chandef; + if (connkeys) { + params->wep_keys = connkeys->params; + params->wep_tx_key = connkeys->def; + } + +#ifdef CONFIG_CFG80211_WEXT + wdev->wext.ibss.chandef = params->chandef; +#endif + err = rdev_join_ibss(rdev, dev, params); + if (err) { + wdev->connect_keys = NULL; + return err; + } + + memcpy(wdev->u.ibss.ssid, params->ssid, params->ssid_len); + wdev->u.ibss.ssid_len = params->ssid_len; + + return 0; +} + +static void __cfg80211_clear_ibss(struct net_device *dev, bool nowext) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int i; + + ASSERT_WDEV_LOCK(wdev); + + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = NULL; + + rdev_set_qos_map(rdev, dev, NULL); + + /* + * Delete all the keys ... pairwise keys can't really + * exist any more anyway, but default keys might. + */ + if (rdev->ops->del_key) + for (i = 0; i < 6; i++) + rdev_del_key(rdev, dev, -1, i, false, NULL); + + if (wdev->u.ibss.current_bss) { + cfg80211_unhold_bss(wdev->u.ibss.current_bss); + cfg80211_put_bss(wdev->wiphy, &wdev->u.ibss.current_bss->pub); + } + + wdev->u.ibss.current_bss = NULL; + wdev->u.ibss.ssid_len = 0; + memset(&wdev->u.ibss.chandef, 0, sizeof(wdev->u.ibss.chandef)); +#ifdef CONFIG_CFG80211_WEXT + if (!nowext) + wdev->wext.ibss.ssid_len = 0; +#endif + cfg80211_sched_dfs_chan_update(rdev); +} + +void cfg80211_clear_ibss(struct net_device *dev, bool nowext) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + wdev_lock(wdev); + __cfg80211_clear_ibss(dev, nowext); + wdev_unlock(wdev); +} + +int __cfg80211_leave_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool nowext) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->u.ibss.ssid_len) + return -ENOLINK; + + err = rdev_leave_ibss(rdev, dev); + + if (err) + return err; + + wdev->conn_owner_nlportid = 0; + __cfg80211_clear_ibss(dev, nowext); + + return 0; +} + +int cfg80211_leave_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool nowext) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + wdev_lock(wdev); + err = __cfg80211_leave_ibss(rdev, dev, nowext); + wdev_unlock(wdev); + + return err; +} + +#ifdef CONFIG_CFG80211_WEXT +int cfg80211_ibss_wext_join(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + struct cfg80211_cached_keys *ck = NULL; + enum nl80211_band band; + int i, err; + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->wext.ibss.beacon_interval) + wdev->wext.ibss.beacon_interval = 100; + + /* try to find an IBSS channel if none requested ... 
*/ + if (!wdev->wext.ibss.chandef.chan) { + struct ieee80211_channel *new_chan = NULL; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband; + struct ieee80211_channel *chan; + + sband = rdev->wiphy.bands[band]; + if (!sband) + continue; + + for (i = 0; i < sband->n_channels; i++) { + chan = &sband->channels[i]; + if (chan->flags & IEEE80211_CHAN_NO_IR) + continue; + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + new_chan = chan; + break; + } + + if (new_chan) + break; + } + + if (!new_chan) + return -EINVAL; + + cfg80211_chandef_create(&wdev->wext.ibss.chandef, new_chan, + NL80211_CHAN_NO_HT); + } + + /* don't join -- SSID is not there */ + if (!wdev->wext.ibss.ssid_len) + return 0; + + if (!netif_running(wdev->netdev)) + return 0; + + if (wdev->wext.keys) + wdev->wext.keys->def = wdev->wext.default_key; + + wdev->wext.ibss.privacy = wdev->wext.default_key != -1; + + if (wdev->wext.keys && wdev->wext.keys->def != -1) { + ck = kmemdup(wdev->wext.keys, sizeof(*ck), GFP_KERNEL); + if (!ck) + return -ENOMEM; + for (i = 0; i < CFG80211_MAX_WEP_KEYS; i++) + ck->params[i].key = ck->data[i]; + } + err = __cfg80211_join_ibss(rdev, wdev->netdev, + &wdev->wext.ibss, ck); + if (err) + kfree(ck); + + return err; +} + +int cfg80211_ibss_wext_siwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *wextfreq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct ieee80211_channel *chan = NULL; + int err, freq; + + /* call only for ibss! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + if (!rdev->ops->join_ibss) + return -EOPNOTSUPP; + + freq = cfg80211_wext_freq(wextfreq); + if (freq < 0) + return freq; + + if (freq) { + chan = ieee80211_get_channel(wdev->wiphy, freq); + if (!chan) + return -EINVAL; + if (chan->flags & IEEE80211_CHAN_NO_IR || + chan->flags & IEEE80211_CHAN_DISABLED) + return -EINVAL; + } + + if (wdev->wext.ibss.chandef.chan == chan) + return 0; + + wdev_lock(wdev); + err = 0; + if (wdev->u.ibss.ssid_len) + err = __cfg80211_leave_ibss(rdev, dev, true); + wdev_unlock(wdev); + + if (err) + return err; + + if (chan) { + cfg80211_chandef_create(&wdev->wext.ibss.chandef, chan, + NL80211_CHAN_NO_HT); + wdev->wext.ibss.channel_fixed = true; + } else { + /* cfg80211_ibss_wext_join will pick one if needed */ + wdev->wext.ibss.channel_fixed = false; + } + + wdev_lock(wdev); + err = cfg80211_ibss_wext_join(rdev, wdev); + wdev_unlock(wdev); + + return err; +} + +int cfg80211_ibss_wext_giwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct ieee80211_channel *chan = NULL; + + /* call only for ibss! 
*/ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + wdev_lock(wdev); + if (wdev->u.ibss.current_bss) + chan = wdev->u.ibss.current_bss->pub.channel; + else if (wdev->wext.ibss.chandef.chan) + chan = wdev->wext.ibss.chandef.chan; + wdev_unlock(wdev); + + if (chan) { + freq->m = chan->center_freq; + freq->e = 6; + return 0; + } + + /* no channel if not joining */ + return -EINVAL; +} + +int cfg80211_ibss_wext_siwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + size_t len = data->length; + int err; + + /* call only for ibss! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + if (!rdev->ops->join_ibss) + return -EOPNOTSUPP; + + wdev_lock(wdev); + err = 0; + if (wdev->u.ibss.ssid_len) + err = __cfg80211_leave_ibss(rdev, dev, true); + wdev_unlock(wdev); + + if (err) + return err; + + /* iwconfig uses nul termination in SSID.. */ + if (len > 0 && ssid[len - 1] == '\0') + len--; + + memcpy(wdev->u.ibss.ssid, ssid, len); + wdev->wext.ibss.ssid = wdev->u.ibss.ssid; + wdev->wext.ibss.ssid_len = len; + + wdev_lock(wdev); + err = cfg80211_ibss_wext_join(rdev, wdev); + wdev_unlock(wdev); + + return err; +} + +int cfg80211_ibss_wext_giwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + /* call only for ibss! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + data->flags = 0; + + wdev_lock(wdev); + if (wdev->u.ibss.ssid_len) { + data->flags = 1; + data->length = wdev->u.ibss.ssid_len; + memcpy(ssid, wdev->u.ibss.ssid, data->length); + } else if (wdev->wext.ibss.ssid && wdev->wext.ibss.ssid_len) { + data->flags = 1; + data->length = wdev->wext.ibss.ssid_len; + memcpy(ssid, wdev->wext.ibss.ssid, data->length); + } + wdev_unlock(wdev); + + return 0; +} + +int cfg80211_ibss_wext_siwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u8 *bssid = ap_addr->sa_data; + int err; + + /* call only for ibss! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + if (!rdev->ops->join_ibss) + return -EOPNOTSUPP; + + if (ap_addr->sa_family != ARPHRD_ETHER) + return -EINVAL; + + /* automatic mode */ + if (is_zero_ether_addr(bssid) || is_broadcast_ether_addr(bssid)) + bssid = NULL; + + if (bssid && !is_valid_ether_addr(bssid)) + return -EINVAL; + + /* both automatic */ + if (!bssid && !wdev->wext.ibss.bssid) + return 0; + + /* fixed already - and no change */ + if (wdev->wext.ibss.bssid && bssid && + ether_addr_equal(bssid, wdev->wext.ibss.bssid)) + return 0; + + wdev_lock(wdev); + err = 0; + if (wdev->u.ibss.ssid_len) + err = __cfg80211_leave_ibss(rdev, dev, true); + wdev_unlock(wdev); + + if (err) + return err; + + if (bssid) { + memcpy(wdev->wext.bssid, bssid, ETH_ALEN); + wdev->wext.ibss.bssid = wdev->wext.bssid; + } else + wdev->wext.ibss.bssid = NULL; + + wdev_lock(wdev); + err = cfg80211_ibss_wext_join(rdev, wdev); + wdev_unlock(wdev); + + return err; +} + +int cfg80211_ibss_wext_giwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + /* call only for ibss! 
*/ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + ap_addr->sa_family = ARPHRD_ETHER; + + wdev_lock(wdev); + if (wdev->u.ibss.current_bss) + memcpy(ap_addr->sa_data, wdev->u.ibss.current_bss->pub.bssid, + ETH_ALEN); + else if (wdev->wext.ibss.bssid) + memcpy(ap_addr->sa_data, wdev->wext.ibss.bssid, ETH_ALEN); + else + eth_zero_addr(ap_addr->sa_data); + + wdev_unlock(wdev); + + return 0; +} +#endif diff --git a/net/wireless/lib80211.c b/net/wireless/lib80211.c new file mode 100644 index 000000000..d66a91302 --- /dev/null +++ b/net/wireless/lib80211.c @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * lib80211 -- common bits for IEEE802.11 drivers + * + * Copyright(c) 2008 John W. Linville + * + * Portions copied from old ieee80211 component, w/ original copyright + * notices below: + * + * Host AP crypto routines + * + * Copyright (c) 2002-2003, Jouni Malinen + * Portions Copyright (C) 2004, Intel Corporation + * + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include + +#include + +#define DRV_DESCRIPTION "common routines for IEEE802.11 drivers" + +MODULE_DESCRIPTION(DRV_DESCRIPTION); +MODULE_AUTHOR("John W. Linville "); +MODULE_LICENSE("GPL"); + +struct lib80211_crypto_alg { + struct list_head list; + struct lib80211_crypto_ops *ops; +}; + +static LIST_HEAD(lib80211_crypto_algs); +static DEFINE_SPINLOCK(lib80211_crypto_lock); + +static void lib80211_crypt_deinit_entries(struct lib80211_crypt_info *info, + int force); +static void lib80211_crypt_quiescing(struct lib80211_crypt_info *info); +static void lib80211_crypt_deinit_handler(struct timer_list *t); + +int lib80211_crypt_info_init(struct lib80211_crypt_info *info, char *name, + spinlock_t *lock) +{ + memset(info, 0, sizeof(*info)); + + info->name = name; + info->lock = lock; + + INIT_LIST_HEAD(&info->crypt_deinit_list); + timer_setup(&info->crypt_deinit_timer, lib80211_crypt_deinit_handler, + 0); + + return 0; +} +EXPORT_SYMBOL(lib80211_crypt_info_init); + +void lib80211_crypt_info_free(struct lib80211_crypt_info *info) +{ + int i; + + lib80211_crypt_quiescing(info); + del_timer_sync(&info->crypt_deinit_timer); + lib80211_crypt_deinit_entries(info, 1); + + for (i = 0; i < NUM_WEP_KEYS; i++) { + struct lib80211_crypt_data *crypt = info->crypt[i]; + if (crypt) { + if (crypt->ops) { + crypt->ops->deinit(crypt->priv); + module_put(crypt->ops->owner); + } + kfree(crypt); + info->crypt[i] = NULL; + } + } +} +EXPORT_SYMBOL(lib80211_crypt_info_free); + +static void lib80211_crypt_deinit_entries(struct lib80211_crypt_info *info, + int force) +{ + struct lib80211_crypt_data *entry, *next; + unsigned long flags; + + spin_lock_irqsave(info->lock, flags); + list_for_each_entry_safe(entry, next, &info->crypt_deinit_list, list) { + if (atomic_read(&entry->refcnt) != 0 && !force) + continue; + + list_del(&entry->list); + + if (entry->ops) { + entry->ops->deinit(entry->priv); + module_put(entry->ops->owner); + } + kfree(entry); + } + spin_unlock_irqrestore(info->lock, flags); +} + +/* After this, crypt_deinit_list won't accept new members */ +static void lib80211_crypt_quiescing(struct lib80211_crypt_info *info) +{ + unsigned long flags; + + spin_lock_irqsave(info->lock, flags); + info->crypt_quiesced = 1; + spin_unlock_irqrestore(info->lock, flags); +} + +static void lib80211_crypt_deinit_handler(struct timer_list *t) +{ + struct lib80211_crypt_info *info = from_timer(info, t, + crypt_deinit_timer); + unsigned long flags; + + 
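/* free any entries whose refcount has dropped to zero; the timer is + * re-armed below if entries remain and we are not quiescing */ +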
lib80211_crypt_deinit_entries(info, 0); + + spin_lock_irqsave(info->lock, flags); + if (!list_empty(&info->crypt_deinit_list) && !info->crypt_quiesced) { + printk(KERN_DEBUG "%s: entries remaining in delayed crypt " + "deletion list\n", info->name); + info->crypt_deinit_timer.expires = jiffies + HZ; + add_timer(&info->crypt_deinit_timer); + } + spin_unlock_irqrestore(info->lock, flags); +} + +void lib80211_crypt_delayed_deinit(struct lib80211_crypt_info *info, + struct lib80211_crypt_data **crypt) +{ + struct lib80211_crypt_data *tmp; + unsigned long flags; + + if (*crypt == NULL) + return; + + tmp = *crypt; + *crypt = NULL; + + /* must not run ops->deinit() while there may be pending encrypt or + * decrypt operations. Use a list of delayed deinits to avoid needing + * locking. */ + + spin_lock_irqsave(info->lock, flags); + if (!info->crypt_quiesced) { + list_add(&tmp->list, &info->crypt_deinit_list); + if (!timer_pending(&info->crypt_deinit_timer)) { + info->crypt_deinit_timer.expires = jiffies + HZ; + add_timer(&info->crypt_deinit_timer); + } + } + spin_unlock_irqrestore(info->lock, flags); +} +EXPORT_SYMBOL(lib80211_crypt_delayed_deinit); + +int lib80211_register_crypto_ops(struct lib80211_crypto_ops *ops) +{ + unsigned long flags; + struct lib80211_crypto_alg *alg; + + alg = kzalloc(sizeof(*alg), GFP_KERNEL); + if (alg == NULL) + return -ENOMEM; + + alg->ops = ops; + + spin_lock_irqsave(&lib80211_crypto_lock, flags); + list_add(&alg->list, &lib80211_crypto_algs); + spin_unlock_irqrestore(&lib80211_crypto_lock, flags); + + printk(KERN_DEBUG "lib80211_crypt: registered algorithm '%s'\n", + ops->name); + + return 0; +} +EXPORT_SYMBOL(lib80211_register_crypto_ops); + +int lib80211_unregister_crypto_ops(struct lib80211_crypto_ops *ops) +{ + struct lib80211_crypto_alg *alg; + unsigned long flags; + + spin_lock_irqsave(&lib80211_crypto_lock, flags); + list_for_each_entry(alg, &lib80211_crypto_algs, list) { + if (alg->ops == ops) + goto found; + } + spin_unlock_irqrestore(&lib80211_crypto_lock, flags); + return -EINVAL; + + found: + printk(KERN_DEBUG "lib80211_crypt: unregistered algorithm '%s'\n", + ops->name); + list_del(&alg->list); + spin_unlock_irqrestore(&lib80211_crypto_lock, flags); + kfree(alg); + return 0; +} +EXPORT_SYMBOL(lib80211_unregister_crypto_ops); + +struct lib80211_crypto_ops *lib80211_get_crypto_ops(const char *name) +{ + struct lib80211_crypto_alg *alg; + unsigned long flags; + + spin_lock_irqsave(&lib80211_crypto_lock, flags); + list_for_each_entry(alg, &lib80211_crypto_algs, list) { + if (strcmp(alg->ops->name, name) == 0) + goto found; + } + spin_unlock_irqrestore(&lib80211_crypto_lock, flags); + return NULL; + + found: + spin_unlock_irqrestore(&lib80211_crypto_lock, flags); + return alg->ops; +} +EXPORT_SYMBOL(lib80211_get_crypto_ops); + +static void *lib80211_crypt_null_init(int keyidx) +{ + return (void *)1; +} + +static void lib80211_crypt_null_deinit(void *priv) +{ +} + +static struct lib80211_crypto_ops lib80211_crypt_null = { + .name = "NULL", + .init = lib80211_crypt_null_init, + .deinit = lib80211_crypt_null_deinit, + .owner = THIS_MODULE, +}; + +static int __init lib80211_init(void) +{ + pr_info(DRV_DESCRIPTION "\n"); + return lib80211_register_crypto_ops(&lib80211_crypt_null); +} + +static void __exit lib80211_exit(void) +{ + lib80211_unregister_crypto_ops(&lib80211_crypt_null); + BUG_ON(!list_empty(&lib80211_crypto_algs)); +} + +module_init(lib80211_init); +module_exit(lib80211_exit); diff --git a/net/wireless/lib80211_crypt_ccmp.c 
b/net/wireless/lib80211_crypt_ccmp.c new file mode 100644 index 000000000..cca5e1cf0 --- /dev/null +++ b/net/wireless/lib80211_crypt_ccmp.c @@ -0,0 +1,448 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * lib80211 crypt: host-based CCMP encryption implementation for lib80211 + * + * Copyright (c) 2003-2004, Jouni Malinen + * Copyright (c) 2008, John W. Linville + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include + +MODULE_AUTHOR("Jouni Malinen"); +MODULE_DESCRIPTION("Host AP crypt: CCMP"); +MODULE_LICENSE("GPL"); + +#define AES_BLOCK_LEN 16 +#define CCMP_HDR_LEN 8 +#define CCMP_MIC_LEN 8 +#define CCMP_TK_LEN 16 +#define CCMP_PN_LEN 6 + +struct lib80211_ccmp_data { + u8 key[CCMP_TK_LEN]; + int key_set; + + u8 tx_pn[CCMP_PN_LEN]; + u8 rx_pn[CCMP_PN_LEN]; + + u32 dot11RSNAStatsCCMPFormatErrors; + u32 dot11RSNAStatsCCMPReplays; + u32 dot11RSNAStatsCCMPDecryptErrors; + + int key_idx; + + struct crypto_aead *tfm; + + /* scratch buffers for virt_to_page() (crypto API) */ + u8 tx_aad[2 * AES_BLOCK_LEN]; + u8 rx_aad[2 * AES_BLOCK_LEN]; +}; + +static void *lib80211_ccmp_init(int key_idx) +{ + struct lib80211_ccmp_data *priv; + + priv = kzalloc(sizeof(*priv), GFP_ATOMIC); + if (priv == NULL) + goto fail; + priv->key_idx = key_idx; + + priv->tfm = crypto_alloc_aead("ccm(aes)", 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(priv->tfm)) { + priv->tfm = NULL; + goto fail; + } + + return priv; + + fail: + if (priv) { + if (priv->tfm) + crypto_free_aead(priv->tfm); + kfree(priv); + } + + return NULL; +} + +static void lib80211_ccmp_deinit(void *priv) +{ + struct lib80211_ccmp_data *_priv = priv; + if (_priv && _priv->tfm) + crypto_free_aead(_priv->tfm); + kfree(priv); +} + +static int ccmp_init_iv_and_aad(const struct ieee80211_hdr *hdr, + const u8 *pn, u8 *iv, u8 *aad) +{ + u8 *pos, qc = 0; + size_t aad_len; + int a4_included, qc_included; + + a4_included = ieee80211_has_a4(hdr->frame_control); + qc_included = ieee80211_is_data_qos(hdr->frame_control); + + aad_len = 22; + if (a4_included) + aad_len += 6; + if (qc_included) { + pos = (u8 *) & hdr->addr4; + if (a4_included) + pos += 6; + qc = *pos & 0x0f; + aad_len += 2; + } + + /* In CCM, the initial vectors (IV) used for CTR mode encryption and CBC + * mode authentication are not allowed to collide, yet both are derived + * from the same vector. We only set L := 1 here to indicate that the + * data size can be represented in (L+1) bytes. The CCM layer will take + * care of storing the data length in the top (L+1) bytes and setting + * and clearing the other bits as is required to derive the two IVs. + */ + iv[0] = 0x1; + + /* Nonce: QC | A2 | PN */ + iv[1] = qc; + memcpy(iv + 2, hdr->addr2, ETH_ALEN); + memcpy(iv + 8, pn, CCMP_PN_LEN); + + /* AAD: + * FC with bits 4..6 and 11..13 masked to zero; 14 is always one + * A1 | A2 | A3 + * SC with bits 4..15 (seq#) masked to zero + * A4 (if present) + * QC (if present) + */ + pos = (u8 *) hdr; + aad[0] = pos[0] & 0x8f; + aad[1] = pos[1] & 0xc7; + memcpy(aad + 2, &hdr->addrs, 3 * ETH_ALEN); + pos = (u8 *) & hdr->seq_ctrl; + aad[20] = pos[0] & 0x0f; + aad[21] = 0; /* all bits masked */ + memset(aad + 22, 0, 8); + if (a4_included) + memcpy(aad + 22, hdr->addr4, ETH_ALEN); + if (qc_included) { + aad[a4_included ? 
28 : 22] = qc; + /* rest of QC masked */ + } + return aad_len; +} + +static int lib80211_ccmp_hdr(struct sk_buff *skb, int hdr_len, + u8 *aeskey, int keylen, void *priv) +{ + struct lib80211_ccmp_data *key = priv; + int i; + u8 *pos; + + if (skb_headroom(skb) < CCMP_HDR_LEN || skb->len < hdr_len) + return -1; + + if (aeskey != NULL && keylen >= CCMP_TK_LEN) + memcpy(aeskey, key->key, CCMP_TK_LEN); + + pos = skb_push(skb, CCMP_HDR_LEN); + memmove(pos, pos + CCMP_HDR_LEN, hdr_len); + pos += hdr_len; + + i = CCMP_PN_LEN - 1; + while (i >= 0) { + key->tx_pn[i]++; + if (key->tx_pn[i] != 0) + break; + i--; + } + + *pos++ = key->tx_pn[5]; + *pos++ = key->tx_pn[4]; + *pos++ = 0; + *pos++ = (key->key_idx << 6) | (1 << 5) /* Ext IV included */ ; + *pos++ = key->tx_pn[3]; + *pos++ = key->tx_pn[2]; + *pos++ = key->tx_pn[1]; + *pos++ = key->tx_pn[0]; + + return CCMP_HDR_LEN; +} + +static int lib80211_ccmp_encrypt(struct sk_buff *skb, int hdr_len, void *priv) +{ + struct lib80211_ccmp_data *key = priv; + struct ieee80211_hdr *hdr; + struct aead_request *req; + struct scatterlist sg[2]; + u8 *aad = key->tx_aad; + u8 iv[AES_BLOCK_LEN]; + int len, data_len, aad_len; + int ret; + + if (skb_tailroom(skb) < CCMP_MIC_LEN || skb->len < hdr_len) + return -1; + + data_len = skb->len - hdr_len; + len = lib80211_ccmp_hdr(skb, hdr_len, NULL, 0, priv); + if (len < 0) + return -1; + + req = aead_request_alloc(key->tfm, GFP_ATOMIC); + if (!req) + return -ENOMEM; + + hdr = (struct ieee80211_hdr *)skb->data; + aad_len = ccmp_init_iv_and_aad(hdr, key->tx_pn, iv, aad); + + skb_put(skb, CCMP_MIC_LEN); + + sg_init_table(sg, 2); + sg_set_buf(&sg[0], aad, aad_len); + sg_set_buf(&sg[1], skb->data + hdr_len + CCMP_HDR_LEN, + data_len + CCMP_MIC_LEN); + + aead_request_set_callback(req, 0, NULL, NULL); + aead_request_set_ad(req, aad_len); + aead_request_set_crypt(req, sg, sg, data_len, iv); + + ret = crypto_aead_encrypt(req); + aead_request_free(req); + + return ret; +} + +/* + * deal with seq counter wrapping correctly. 
+ * refer to timer_after() for jiffies wrapping handling + */ +static inline int ccmp_replay_check(u8 *pn_n, u8 *pn_o) +{ + u32 iv32_n, iv16_n; + u32 iv32_o, iv16_o; + + iv32_n = (pn_n[0] << 24) | (pn_n[1] << 16) | (pn_n[2] << 8) | pn_n[3]; + iv16_n = (pn_n[4] << 8) | pn_n[5]; + + iv32_o = (pn_o[0] << 24) | (pn_o[1] << 16) | (pn_o[2] << 8) | pn_o[3]; + iv16_o = (pn_o[4] << 8) | pn_o[5]; + + if ((s32)iv32_n - (s32)iv32_o < 0 || + (iv32_n == iv32_o && iv16_n <= iv16_o)) + return 1; + return 0; +} + +static int lib80211_ccmp_decrypt(struct sk_buff *skb, int hdr_len, void *priv) +{ + struct lib80211_ccmp_data *key = priv; + u8 keyidx, *pos; + struct ieee80211_hdr *hdr; + struct aead_request *req; + struct scatterlist sg[2]; + u8 *aad = key->rx_aad; + u8 iv[AES_BLOCK_LEN]; + u8 pn[6]; + int aad_len, ret; + size_t data_len = skb->len - hdr_len - CCMP_HDR_LEN; + + if (skb->len < hdr_len + CCMP_HDR_LEN + CCMP_MIC_LEN) { + key->dot11RSNAStatsCCMPFormatErrors++; + return -1; + } + + hdr = (struct ieee80211_hdr *)skb->data; + pos = skb->data + hdr_len; + keyidx = pos[3]; + if (!(keyidx & (1 << 5))) { + net_dbg_ratelimited("CCMP: received packet without ExtIV flag from %pM\n", + hdr->addr2); + key->dot11RSNAStatsCCMPFormatErrors++; + return -2; + } + keyidx >>= 6; + if (key->key_idx != keyidx) { + net_dbg_ratelimited("CCMP: RX tkey->key_idx=%d frame keyidx=%d\n", + key->key_idx, keyidx); + return -6; + } + if (!key->key_set) { + net_dbg_ratelimited("CCMP: received packet from %pM with keyid=%d that does not have a configured key\n", + hdr->addr2, keyidx); + return -3; + } + + pn[0] = pos[7]; + pn[1] = pos[6]; + pn[2] = pos[5]; + pn[3] = pos[4]; + pn[4] = pos[1]; + pn[5] = pos[0]; + pos += 8; + + if (ccmp_replay_check(pn, key->rx_pn)) { +#ifdef CONFIG_LIB80211_DEBUG + net_dbg_ratelimited("CCMP: replay detected: STA=%pM previous PN %02x%02x%02x%02x%02x%02x received PN %02x%02x%02x%02x%02x%02x\n", + hdr->addr2, + key->rx_pn[0], key->rx_pn[1], key->rx_pn[2], + key->rx_pn[3], key->rx_pn[4], key->rx_pn[5], + pn[0], pn[1], pn[2], pn[3], pn[4], pn[5]); +#endif + key->dot11RSNAStatsCCMPReplays++; + return -4; + } + + req = aead_request_alloc(key->tfm, GFP_ATOMIC); + if (!req) + return -ENOMEM; + + aad_len = ccmp_init_iv_and_aad(hdr, pn, iv, aad); + + sg_init_table(sg, 2); + sg_set_buf(&sg[0], aad, aad_len); + sg_set_buf(&sg[1], pos, data_len); + + aead_request_set_callback(req, 0, NULL, NULL); + aead_request_set_ad(req, aad_len); + aead_request_set_crypt(req, sg, sg, data_len, iv); + + ret = crypto_aead_decrypt(req); + aead_request_free(req); + + if (ret) { + net_dbg_ratelimited("CCMP: decrypt failed: STA=%pM (%d)\n", + hdr->addr2, ret); + key->dot11RSNAStatsCCMPDecryptErrors++; + return -5; + } + + memcpy(key->rx_pn, pn, CCMP_PN_LEN); + + /* Remove hdr and MIC */ + memmove(skb->data + CCMP_HDR_LEN, skb->data, hdr_len); + skb_pull(skb, CCMP_HDR_LEN); + skb_trim(skb, skb->len - CCMP_MIC_LEN); + + return keyidx; +} + +static int lib80211_ccmp_set_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_ccmp_data *data = priv; + int keyidx; + struct crypto_aead *tfm = data->tfm; + + keyidx = data->key_idx; + memset(data, 0, sizeof(*data)); + data->key_idx = keyidx; + data->tfm = tfm; + if (len == CCMP_TK_LEN) { + memcpy(data->key, key, CCMP_TK_LEN); + data->key_set = 1; + if (seq) { + data->rx_pn[0] = seq[5]; + data->rx_pn[1] = seq[4]; + data->rx_pn[2] = seq[3]; + data->rx_pn[3] = seq[2]; + data->rx_pn[4] = seq[1]; + data->rx_pn[5] = seq[0]; + } + if (crypto_aead_setauthsize(data->tfm, CCMP_MIC_LEN) || 
+ crypto_aead_setkey(data->tfm, data->key, CCMP_TK_LEN)) + return -1; + } else if (len == 0) + data->key_set = 0; + else + return -1; + + return 0; +} + +static int lib80211_ccmp_get_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_ccmp_data *data = priv; + + if (len < CCMP_TK_LEN) + return -1; + + if (!data->key_set) + return 0; + memcpy(key, data->key, CCMP_TK_LEN); + + if (seq) { + seq[0] = data->tx_pn[5]; + seq[1] = data->tx_pn[4]; + seq[2] = data->tx_pn[3]; + seq[3] = data->tx_pn[2]; + seq[4] = data->tx_pn[1]; + seq[5] = data->tx_pn[0]; + } + + return CCMP_TK_LEN; +} + +static void lib80211_ccmp_print_stats(struct seq_file *m, void *priv) +{ + struct lib80211_ccmp_data *ccmp = priv; + + seq_printf(m, + "key[%d] alg=CCMP key_set=%d " + "tx_pn=%02x%02x%02x%02x%02x%02x " + "rx_pn=%02x%02x%02x%02x%02x%02x " + "format_errors=%d replays=%d decrypt_errors=%d\n", + ccmp->key_idx, ccmp->key_set, + ccmp->tx_pn[0], ccmp->tx_pn[1], ccmp->tx_pn[2], + ccmp->tx_pn[3], ccmp->tx_pn[4], ccmp->tx_pn[5], + ccmp->rx_pn[0], ccmp->rx_pn[1], ccmp->rx_pn[2], + ccmp->rx_pn[3], ccmp->rx_pn[4], ccmp->rx_pn[5], + ccmp->dot11RSNAStatsCCMPFormatErrors, + ccmp->dot11RSNAStatsCCMPReplays, + ccmp->dot11RSNAStatsCCMPDecryptErrors); +} + +static struct lib80211_crypto_ops lib80211_crypt_ccmp = { + .name = "CCMP", + .init = lib80211_ccmp_init, + .deinit = lib80211_ccmp_deinit, + .encrypt_mpdu = lib80211_ccmp_encrypt, + .decrypt_mpdu = lib80211_ccmp_decrypt, + .encrypt_msdu = NULL, + .decrypt_msdu = NULL, + .set_key = lib80211_ccmp_set_key, + .get_key = lib80211_ccmp_get_key, + .print_stats = lib80211_ccmp_print_stats, + .extra_mpdu_prefix_len = CCMP_HDR_LEN, + .extra_mpdu_postfix_len = CCMP_MIC_LEN, + .owner = THIS_MODULE, +}; + +static int __init lib80211_crypto_ccmp_init(void) +{ + return lib80211_register_crypto_ops(&lib80211_crypt_ccmp); +} + +static void __exit lib80211_crypto_ccmp_exit(void) +{ + lib80211_unregister_crypto_ops(&lib80211_crypt_ccmp); +} + +module_init(lib80211_crypto_ccmp_init); +module_exit(lib80211_crypto_ccmp_exit); diff --git a/net/wireless/lib80211_crypt_tkip.c b/net/wireless/lib80211_crypt_tkip.c new file mode 100644 index 000000000..1b4d6c87a --- /dev/null +++ b/net/wireless/lib80211_crypt_tkip.c @@ -0,0 +1,738 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * lib80211 crypt: host-based TKIP encryption implementation for lib80211 + * + * Copyright (c) 2003-2004, Jouni Malinen + * Copyright (c) 2008, John W. 
Linville + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include + +MODULE_AUTHOR("Jouni Malinen"); +MODULE_DESCRIPTION("lib80211 crypt: TKIP"); +MODULE_LICENSE("GPL"); + +#define TKIP_HDR_LEN 8 + +struct lib80211_tkip_data { +#define TKIP_KEY_LEN 32 + u8 key[TKIP_KEY_LEN]; + int key_set; + + u32 tx_iv32; + u16 tx_iv16; + u16 tx_ttak[5]; + int tx_phase1_done; + + u32 rx_iv32; + u16 rx_iv16; + u16 rx_ttak[5]; + int rx_phase1_done; + u32 rx_iv32_new; + u16 rx_iv16_new; + + u32 dot11RSNAStatsTKIPReplays; + u32 dot11RSNAStatsTKIPICVErrors; + u32 dot11RSNAStatsTKIPLocalMICFailures; + + int key_idx; + + struct arc4_ctx rx_ctx_arc4; + struct arc4_ctx tx_ctx_arc4; + struct crypto_shash *rx_tfm_michael; + struct crypto_shash *tx_tfm_michael; + + /* scratch buffers for virt_to_page() (crypto API) */ + u8 rx_hdr[16], tx_hdr[16]; + + unsigned long flags; +}; + +static unsigned long lib80211_tkip_set_flags(unsigned long flags, void *priv) +{ + struct lib80211_tkip_data *_priv = priv; + unsigned long old_flags = _priv->flags; + _priv->flags = flags; + return old_flags; +} + +static unsigned long lib80211_tkip_get_flags(void *priv) +{ + struct lib80211_tkip_data *_priv = priv; + return _priv->flags; +} + +static void *lib80211_tkip_init(int key_idx) +{ + struct lib80211_tkip_data *priv; + + if (fips_enabled) + return NULL; + + priv = kzalloc(sizeof(*priv), GFP_ATOMIC); + if (priv == NULL) + goto fail; + + priv->key_idx = key_idx; + + priv->tx_tfm_michael = crypto_alloc_shash("michael_mic", 0, 0); + if (IS_ERR(priv->tx_tfm_michael)) { + priv->tx_tfm_michael = NULL; + goto fail; + } + + priv->rx_tfm_michael = crypto_alloc_shash("michael_mic", 0, 0); + if (IS_ERR(priv->rx_tfm_michael)) { + priv->rx_tfm_michael = NULL; + goto fail; + } + + return priv; + + fail: + if (priv) { + crypto_free_shash(priv->tx_tfm_michael); + crypto_free_shash(priv->rx_tfm_michael); + kfree(priv); + } + + return NULL; +} + +static void lib80211_tkip_deinit(void *priv) +{ + struct lib80211_tkip_data *_priv = priv; + if (_priv) { + crypto_free_shash(_priv->tx_tfm_michael); + crypto_free_shash(_priv->rx_tfm_michael); + } + kfree_sensitive(priv); +} + +static inline u16 RotR1(u16 val) +{ + return (val >> 1) | (val << 15); +} + +static inline u8 Lo8(u16 val) +{ + return val & 0xff; +} + +static inline u8 Hi8(u16 val) +{ + return val >> 8; +} + +static inline u16 Lo16(u32 val) +{ + return val & 0xffff; +} + +static inline u16 Hi16(u32 val) +{ + return val >> 16; +} + +static inline u16 Mk16(u8 hi, u8 lo) +{ + return lo | (((u16) hi) << 8); +} + +static inline u16 Mk16_le(__le16 * v) +{ + return le16_to_cpu(*v); +} + +static const u16 Sbox[256] = { + 0xC6A5, 0xF884, 0xEE99, 0xF68D, 0xFF0D, 0xD6BD, 0xDEB1, 0x9154, + 0x6050, 0x0203, 0xCEA9, 0x567D, 0xE719, 0xB562, 0x4DE6, 0xEC9A, + 0x8F45, 0x1F9D, 0x8940, 0xFA87, 0xEF15, 0xB2EB, 0x8EC9, 0xFB0B, + 0x41EC, 0xB367, 0x5FFD, 0x45EA, 0x23BF, 0x53F7, 0xE496, 0x9B5B, + 0x75C2, 0xE11C, 0x3DAE, 0x4C6A, 0x6C5A, 0x7E41, 0xF502, 0x834F, + 0x685C, 0x51F4, 0xD134, 0xF908, 0xE293, 0xAB73, 0x6253, 0x2A3F, + 0x080C, 0x9552, 0x4665, 0x9D5E, 0x3028, 0x37A1, 0x0A0F, 0x2FB5, + 0x0E09, 0x2436, 0x1B9B, 0xDF3D, 0xCD26, 0x4E69, 0x7FCD, 0xEA9F, + 0x121B, 0x1D9E, 0x5874, 0x342E, 0x362D, 0xDCB2, 0xB4EE, 0x5BFB, + 0xA4F6, 0x764D, 0xB761, 0x7DCE, 0x527B, 0xDD3E, 0x5E71, 0x1397, + 0xA6F5, 0xB968, 0x0000, 
0xC12C, 0x4060, 0xE31F, 0x79C8, 0xB6ED, + 0xD4BE, 0x8D46, 0x67D9, 0x724B, 0x94DE, 0x98D4, 0xB0E8, 0x854A, + 0xBB6B, 0xC52A, 0x4FE5, 0xED16, 0x86C5, 0x9AD7, 0x6655, 0x1194, + 0x8ACF, 0xE910, 0x0406, 0xFE81, 0xA0F0, 0x7844, 0x25BA, 0x4BE3, + 0xA2F3, 0x5DFE, 0x80C0, 0x058A, 0x3FAD, 0x21BC, 0x7048, 0xF104, + 0x63DF, 0x77C1, 0xAF75, 0x4263, 0x2030, 0xE51A, 0xFD0E, 0xBF6D, + 0x814C, 0x1814, 0x2635, 0xC32F, 0xBEE1, 0x35A2, 0x88CC, 0x2E39, + 0x9357, 0x55F2, 0xFC82, 0x7A47, 0xC8AC, 0xBAE7, 0x322B, 0xE695, + 0xC0A0, 0x1998, 0x9ED1, 0xA37F, 0x4466, 0x547E, 0x3BAB, 0x0B83, + 0x8CCA, 0xC729, 0x6BD3, 0x283C, 0xA779, 0xBCE2, 0x161D, 0xAD76, + 0xDB3B, 0x6456, 0x744E, 0x141E, 0x92DB, 0x0C0A, 0x486C, 0xB8E4, + 0x9F5D, 0xBD6E, 0x43EF, 0xC4A6, 0x39A8, 0x31A4, 0xD337, 0xF28B, + 0xD532, 0x8B43, 0x6E59, 0xDAB7, 0x018C, 0xB164, 0x9CD2, 0x49E0, + 0xD8B4, 0xACFA, 0xF307, 0xCF25, 0xCAAF, 0xF48E, 0x47E9, 0x1018, + 0x6FD5, 0xF088, 0x4A6F, 0x5C72, 0x3824, 0x57F1, 0x73C7, 0x9751, + 0xCB23, 0xA17C, 0xE89C, 0x3E21, 0x96DD, 0x61DC, 0x0D86, 0x0F85, + 0xE090, 0x7C42, 0x71C4, 0xCCAA, 0x90D8, 0x0605, 0xF701, 0x1C12, + 0xC2A3, 0x6A5F, 0xAEF9, 0x69D0, 0x1791, 0x9958, 0x3A27, 0x27B9, + 0xD938, 0xEB13, 0x2BB3, 0x2233, 0xD2BB, 0xA970, 0x0789, 0x33A7, + 0x2DB6, 0x3C22, 0x1592, 0xC920, 0x8749, 0xAAFF, 0x5078, 0xA57A, + 0x038F, 0x59F8, 0x0980, 0x1A17, 0x65DA, 0xD731, 0x84C6, 0xD0B8, + 0x82C3, 0x29B0, 0x5A77, 0x1E11, 0x7BCB, 0xA8FC, 0x6DD6, 0x2C3A, +}; + +static inline u16 _S_(u16 v) +{ + u16 t = Sbox[Hi8(v)]; + return Sbox[Lo8(v)] ^ ((t << 8) | (t >> 8)); +} + +#define PHASE1_LOOP_COUNT 8 + +static void tkip_mixing_phase1(u16 * TTAK, const u8 * TK, const u8 * TA, + u32 IV32) +{ + int i, j; + + /* Initialize the 80-bit TTAK from TSC (IV32) and TA[0..5] */ + TTAK[0] = Lo16(IV32); + TTAK[1] = Hi16(IV32); + TTAK[2] = Mk16(TA[1], TA[0]); + TTAK[3] = Mk16(TA[3], TA[2]); + TTAK[4] = Mk16(TA[5], TA[4]); + + for (i = 0; i < PHASE1_LOOP_COUNT; i++) { + j = 2 * (i & 1); + TTAK[0] += _S_(TTAK[4] ^ Mk16(TK[1 + j], TK[0 + j])); + TTAK[1] += _S_(TTAK[0] ^ Mk16(TK[5 + j], TK[4 + j])); + TTAK[2] += _S_(TTAK[1] ^ Mk16(TK[9 + j], TK[8 + j])); + TTAK[3] += _S_(TTAK[2] ^ Mk16(TK[13 + j], TK[12 + j])); + TTAK[4] += _S_(TTAK[3] ^ Mk16(TK[1 + j], TK[0 + j])) + i; + } +} + +static void tkip_mixing_phase2(u8 * WEPSeed, const u8 * TK, const u16 * TTAK, + u16 IV16) +{ + /* Make temporary area overlap WEP seed so that the final copy can be + * avoided on little endian hosts. 
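+ *
+ * Illustrative layout of the 16-byte RC4 seed built here (derived from the
+ * assignments below, not normative):
+ *   WEPSeed[0..2]  = Hi8(IV16), (Hi8(IV16) | 0x20) & 0x7F, Lo8(IV16)  -- the WEP IV
+ *   WEPSeed[3]     = Lo8((PPK[5] ^ Mk16_le(&TK[0])) >> 1)
+ *   WEPSeed[4..15] = PPK[0..5] stored as little-endian 16-bit words, which is
+ *                    why the __BIG_ENDIAN block below byte-swaps PPK.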
*/ + u16 *PPK = (u16 *) & WEPSeed[4]; + + /* Step 1 - make copy of TTAK and bring in TSC */ + PPK[0] = TTAK[0]; + PPK[1] = TTAK[1]; + PPK[2] = TTAK[2]; + PPK[3] = TTAK[3]; + PPK[4] = TTAK[4]; + PPK[5] = TTAK[4] + IV16; + + /* Step 2 - 96-bit bijective mixing using S-box */ + PPK[0] += _S_(PPK[5] ^ Mk16_le((__le16 *) & TK[0])); + PPK[1] += _S_(PPK[0] ^ Mk16_le((__le16 *) & TK[2])); + PPK[2] += _S_(PPK[1] ^ Mk16_le((__le16 *) & TK[4])); + PPK[3] += _S_(PPK[2] ^ Mk16_le((__le16 *) & TK[6])); + PPK[4] += _S_(PPK[3] ^ Mk16_le((__le16 *) & TK[8])); + PPK[5] += _S_(PPK[4] ^ Mk16_le((__le16 *) & TK[10])); + + PPK[0] += RotR1(PPK[5] ^ Mk16_le((__le16 *) & TK[12])); + PPK[1] += RotR1(PPK[0] ^ Mk16_le((__le16 *) & TK[14])); + PPK[2] += RotR1(PPK[1]); + PPK[3] += RotR1(PPK[2]); + PPK[4] += RotR1(PPK[3]); + PPK[5] += RotR1(PPK[4]); + + /* Step 3 - bring in last of TK bits, assign 24-bit WEP IV value + * WEPSeed[0..2] is transmitted as WEP IV */ + WEPSeed[0] = Hi8(IV16); + WEPSeed[1] = (Hi8(IV16) | 0x20) & 0x7F; + WEPSeed[2] = Lo8(IV16); + WEPSeed[3] = Lo8((PPK[5] ^ Mk16_le((__le16 *) & TK[0])) >> 1); + +#ifdef __BIG_ENDIAN + { + int i; + for (i = 0; i < 6; i++) + PPK[i] = (PPK[i] << 8) | (PPK[i] >> 8); + } +#endif +} + +static int lib80211_tkip_hdr(struct sk_buff *skb, int hdr_len, + u8 * rc4key, int keylen, void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + u8 *pos; + struct ieee80211_hdr *hdr; + + hdr = (struct ieee80211_hdr *)skb->data; + + if (skb_headroom(skb) < TKIP_HDR_LEN || skb->len < hdr_len) + return -1; + + if (rc4key == NULL || keylen < 16) + return -1; + + if (!tkey->tx_phase1_done) { + tkip_mixing_phase1(tkey->tx_ttak, tkey->key, hdr->addr2, + tkey->tx_iv32); + tkey->tx_phase1_done = 1; + } + tkip_mixing_phase2(rc4key, tkey->key, tkey->tx_ttak, tkey->tx_iv16); + + pos = skb_push(skb, TKIP_HDR_LEN); + memmove(pos, pos + TKIP_HDR_LEN, hdr_len); + pos += hdr_len; + + *pos++ = *rc4key; + *pos++ = *(rc4key + 1); + *pos++ = *(rc4key + 2); + *pos++ = (tkey->key_idx << 6) | (1 << 5) /* Ext IV included */ ; + *pos++ = tkey->tx_iv32 & 0xff; + *pos++ = (tkey->tx_iv32 >> 8) & 0xff; + *pos++ = (tkey->tx_iv32 >> 16) & 0xff; + *pos++ = (tkey->tx_iv32 >> 24) & 0xff; + + tkey->tx_iv16++; + if (tkey->tx_iv16 == 0) { + tkey->tx_phase1_done = 0; + tkey->tx_iv32++; + } + + return TKIP_HDR_LEN; +} + +static int lib80211_tkip_encrypt(struct sk_buff *skb, int hdr_len, void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + int len; + u8 rc4key[16], *pos, *icv; + u32 crc; + + if (tkey->flags & IEEE80211_CRYPTO_TKIP_COUNTERMEASURES) { + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *)skb->data; + net_dbg_ratelimited("TKIP countermeasures: dropped TX packet to %pM\n", + hdr->addr1); + return -1; + } + + if (skb_tailroom(skb) < 4 || skb->len < hdr_len) + return -1; + + len = skb->len - hdr_len; + pos = skb->data + hdr_len; + + if ((lib80211_tkip_hdr(skb, hdr_len, rc4key, 16, priv)) < 0) + return -1; + + crc = ~crc32_le(~0, pos, len); + icv = skb_put(skb, 4); + icv[0] = crc; + icv[1] = crc >> 8; + icv[2] = crc >> 16; + icv[3] = crc >> 24; + + arc4_setkey(&tkey->tx_ctx_arc4, rc4key, 16); + arc4_crypt(&tkey->tx_ctx_arc4, pos, pos, len + 4); + + return 0; +} + +/* + * deal with seq counter wrapping correctly. 
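+ * e.g. (illustrative values only): if the last accepted TSC had
+ * iv32_o == 0xffffffff and the next frame arrives with iv32_n == 0 after the
+ * counter wrapped, the signed difference (s32)iv32_n - (s32)iv32_o is 1, so
+ * the frame is accepted rather than flagged as a replay;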
+ * refer to time_after() for jiffies wrapping handling
+ */
+static inline int tkip_replay_check(u32 iv32_n, u16 iv16_n,
+ u32 iv32_o, u16 iv16_o)
+{
+ if ((s32)iv32_n - (s32)iv32_o < 0 ||
+ (iv32_n == iv32_o && iv16_n <= iv16_o))
+ return 1;
+ return 0;
+}
+
+static int lib80211_tkip_decrypt(struct sk_buff *skb, int hdr_len, void *priv)
+{
+ struct lib80211_tkip_data *tkey = priv;
+ u8 rc4key[16];
+ u8 keyidx, *pos;
+ u32 iv32;
+ u16 iv16;
+ struct ieee80211_hdr *hdr;
+ u8 icv[4];
+ u32 crc;
+ int plen;
+
+ hdr = (struct ieee80211_hdr *)skb->data;
+
+ if (tkey->flags & IEEE80211_CRYPTO_TKIP_COUNTERMEASURES) {
+ net_dbg_ratelimited("TKIP countermeasures: dropped received packet from %pM\n",
+ hdr->addr2);
+ return -1;
+ }
+
+ if (skb->len < hdr_len + TKIP_HDR_LEN + 4)
+ return -1;
+
+ pos = skb->data + hdr_len;
+ keyidx = pos[3];
+ if (!(keyidx & (1 << 5))) {
+ net_dbg_ratelimited("TKIP: received packet without ExtIV flag from %pM\n",
+ hdr->addr2);
+ return -2;
+ }
+ keyidx >>= 6;
+ if (tkey->key_idx != keyidx) {
+ net_dbg_ratelimited("TKIP: RX tkey->key_idx=%d frame keyidx=%d\n",
+ tkey->key_idx, keyidx);
+ return -6;
+ }
+ if (!tkey->key_set) {
+ net_dbg_ratelimited("TKIP: received packet from %pM with keyid=%d that does not have a configured key\n",
+ hdr->addr2, keyidx);
+ return -3;
+ }
+ iv16 = (pos[0] << 8) | pos[2];
+ iv32 = pos[4] | (pos[5] << 8) | (pos[6] << 16) | (pos[7] << 24);
+ pos += TKIP_HDR_LEN;
+
+ if (tkip_replay_check(iv32, iv16, tkey->rx_iv32, tkey->rx_iv16)) {
+#ifdef CONFIG_LIB80211_DEBUG
+ net_dbg_ratelimited("TKIP: replay detected: STA=%pM previous TSC %08x%04x received TSC %08x%04x\n",
+ hdr->addr2, tkey->rx_iv32, tkey->rx_iv16,
+ iv32, iv16);
+#endif
+ tkey->dot11RSNAStatsTKIPReplays++;
+ return -4;
+ }
+
+ if (iv32 != tkey->rx_iv32 || !tkey->rx_phase1_done) {
+ tkip_mixing_phase1(tkey->rx_ttak, tkey->key, hdr->addr2, iv32);
+ tkey->rx_phase1_done = 1;
+ }
+ tkip_mixing_phase2(rc4key, tkey->key, tkey->rx_ttak, iv16);
+
+ plen = skb->len - hdr_len - 12;
+
+ arc4_setkey(&tkey->rx_ctx_arc4, rc4key, 16);
+ arc4_crypt(&tkey->rx_ctx_arc4, pos, pos, plen + 4);
+
+ crc = ~crc32_le(~0, pos, plen);
+ icv[0] = crc;
+ icv[1] = crc >> 8;
+ icv[2] = crc >> 16;
+ icv[3] = crc >> 24;
+ if (memcmp(icv, pos + plen, 4) != 0) {
+ if (iv32 != tkey->rx_iv32) {
+ /* Previously cached Phase1 result was already lost, so
+ * it needs to be recalculated for the next packet.
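+ * (tkip_mixing_phase1() depends only on the temporal key, the transmitter
+ * address and IV32, so its output is normally cached via rx_phase1_done and
+ * reused for every frame that shares the same IV32.)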
*/ + tkey->rx_phase1_done = 0; + } +#ifdef CONFIG_LIB80211_DEBUG + net_dbg_ratelimited("TKIP: ICV error detected: STA=%pM\n", + hdr->addr2); +#endif + tkey->dot11RSNAStatsTKIPICVErrors++; + return -5; + } + + /* Update real counters only after Michael MIC verification has + * completed */ + tkey->rx_iv32_new = iv32; + tkey->rx_iv16_new = iv16; + + /* Remove IV and ICV */ + memmove(skb->data + TKIP_HDR_LEN, skb->data, hdr_len); + skb_pull(skb, TKIP_HDR_LEN); + skb_trim(skb, skb->len - 4); + + return keyidx; +} + +static int michael_mic(struct crypto_shash *tfm_michael, u8 *key, u8 *hdr, + u8 *data, size_t data_len, u8 *mic) +{ + SHASH_DESC_ON_STACK(desc, tfm_michael); + int err; + + if (tfm_michael == NULL) { + pr_warn("%s(): tfm_michael == NULL\n", __func__); + return -1; + } + + desc->tfm = tfm_michael; + + if (crypto_shash_setkey(tfm_michael, key, 8)) + return -1; + + err = crypto_shash_init(desc); + if (err) + goto out; + err = crypto_shash_update(desc, hdr, 16); + if (err) + goto out; + err = crypto_shash_update(desc, data, data_len); + if (err) + goto out; + err = crypto_shash_final(desc, mic); + +out: + shash_desc_zero(desc); + return err; +} + +static void michael_mic_hdr(struct sk_buff *skb, u8 * hdr) +{ + struct ieee80211_hdr *hdr11; + + hdr11 = (struct ieee80211_hdr *)skb->data; + + switch (le16_to_cpu(hdr11->frame_control) & + (IEEE80211_FCTL_FROMDS | IEEE80211_FCTL_TODS)) { + case IEEE80211_FCTL_TODS: + memcpy(hdr, hdr11->addr3, ETH_ALEN); /* DA */ + memcpy(hdr + ETH_ALEN, hdr11->addr2, ETH_ALEN); /* SA */ + break; + case IEEE80211_FCTL_FROMDS: + memcpy(hdr, hdr11->addr1, ETH_ALEN); /* DA */ + memcpy(hdr + ETH_ALEN, hdr11->addr3, ETH_ALEN); /* SA */ + break; + case IEEE80211_FCTL_FROMDS | IEEE80211_FCTL_TODS: + memcpy(hdr, hdr11->addr3, ETH_ALEN); /* DA */ + memcpy(hdr + ETH_ALEN, hdr11->addr4, ETH_ALEN); /* SA */ + break; + default: + memcpy(hdr, hdr11->addr1, ETH_ALEN); /* DA */ + memcpy(hdr + ETH_ALEN, hdr11->addr2, ETH_ALEN); /* SA */ + break; + } + + if (ieee80211_is_data_qos(hdr11->frame_control)) { + hdr[12] = le16_to_cpu(*((__le16 *)ieee80211_get_qos_ctl(hdr11))) + & IEEE80211_QOS_CTL_TID_MASK; + } else + hdr[12] = 0; /* priority */ + + hdr[13] = hdr[14] = hdr[15] = 0; /* reserved */ +} + +static int lib80211_michael_mic_add(struct sk_buff *skb, int hdr_len, + void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + u8 *pos; + + if (skb_tailroom(skb) < 8 || skb->len < hdr_len) { + printk(KERN_DEBUG "Invalid packet for Michael MIC add " + "(tailroom=%d hdr_len=%d skb->len=%d)\n", + skb_tailroom(skb), hdr_len, skb->len); + return -1; + } + + michael_mic_hdr(skb, tkey->tx_hdr); + pos = skb_put(skb, 8); + if (michael_mic(tkey->tx_tfm_michael, &tkey->key[16], tkey->tx_hdr, + skb->data + hdr_len, skb->len - 8 - hdr_len, pos)) + return -1; + + return 0; +} + +static void lib80211_michael_mic_failure(struct net_device *dev, + struct ieee80211_hdr *hdr, + int keyidx) +{ + union iwreq_data wrqu; + struct iw_michaelmicfailure ev; + + /* TODO: needed parameters: count, keyid, key type, TSC */ + memset(&ev, 0, sizeof(ev)); + ev.flags = keyidx & IW_MICFAILURE_KEY_ID; + if (hdr->addr1[0] & 0x01) + ev.flags |= IW_MICFAILURE_GROUP; + else + ev.flags |= IW_MICFAILURE_PAIRWISE; + ev.src_addr.sa_family = ARPHRD_ETHER; + memcpy(ev.src_addr.sa_data, hdr->addr2, ETH_ALEN); + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = sizeof(ev); + wireless_send_event(dev, IWEVMICHAELMICFAILURE, &wrqu, (char *)&ev); +} + +static int lib80211_michael_mic_verify(struct sk_buff *skb, int keyidx, + int 
hdr_len, void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + u8 mic[8]; + + if (!tkey->key_set) + return -1; + + michael_mic_hdr(skb, tkey->rx_hdr); + if (michael_mic(tkey->rx_tfm_michael, &tkey->key[24], tkey->rx_hdr, + skb->data + hdr_len, skb->len - 8 - hdr_len, mic)) + return -1; + if (memcmp(mic, skb->data + skb->len - 8, 8) != 0) { + struct ieee80211_hdr *hdr; + hdr = (struct ieee80211_hdr *)skb->data; + printk(KERN_DEBUG "%s: Michael MIC verification failed for " + "MSDU from %pM keyidx=%d\n", + skb->dev ? skb->dev->name : "N/A", hdr->addr2, + keyidx); + if (skb->dev) + lib80211_michael_mic_failure(skb->dev, hdr, keyidx); + tkey->dot11RSNAStatsTKIPLocalMICFailures++; + return -1; + } + + /* Update TSC counters for RX now that the packet verification has + * completed. */ + tkey->rx_iv32 = tkey->rx_iv32_new; + tkey->rx_iv16 = tkey->rx_iv16_new; + + skb_trim(skb, skb->len - 8); + + return 0; +} + +static int lib80211_tkip_set_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + int keyidx; + struct crypto_shash *tfm = tkey->tx_tfm_michael; + struct arc4_ctx *tfm2 = &tkey->tx_ctx_arc4; + struct crypto_shash *tfm3 = tkey->rx_tfm_michael; + struct arc4_ctx *tfm4 = &tkey->rx_ctx_arc4; + + keyidx = tkey->key_idx; + memset(tkey, 0, sizeof(*tkey)); + tkey->key_idx = keyidx; + tkey->tx_tfm_michael = tfm; + tkey->tx_ctx_arc4 = *tfm2; + tkey->rx_tfm_michael = tfm3; + tkey->rx_ctx_arc4 = *tfm4; + if (len == TKIP_KEY_LEN) { + memcpy(tkey->key, key, TKIP_KEY_LEN); + tkey->key_set = 1; + tkey->tx_iv16 = 1; /* TSC is initialized to 1 */ + if (seq) { + tkey->rx_iv32 = (seq[5] << 24) | (seq[4] << 16) | + (seq[3] << 8) | seq[2]; + tkey->rx_iv16 = (seq[1] << 8) | seq[0]; + } + } else if (len == 0) + tkey->key_set = 0; + else + return -1; + + return 0; +} + +static int lib80211_tkip_get_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_tkip_data *tkey = priv; + + if (len < TKIP_KEY_LEN) + return -1; + + if (!tkey->key_set) + return 0; + memcpy(key, tkey->key, TKIP_KEY_LEN); + + if (seq) { + /* Return the sequence number of the last transmitted frame. 
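+ * (the TSC octets are written least-significant first, mirroring the
+ * ordering that lib80211_tkip_set_key() expects for the RX counter)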
*/ + u16 iv16 = tkey->tx_iv16; + u32 iv32 = tkey->tx_iv32; + if (iv16 == 0) + iv32--; + iv16--; + seq[0] = tkey->tx_iv16; + seq[1] = tkey->tx_iv16 >> 8; + seq[2] = tkey->tx_iv32; + seq[3] = tkey->tx_iv32 >> 8; + seq[4] = tkey->tx_iv32 >> 16; + seq[5] = tkey->tx_iv32 >> 24; + } + + return TKIP_KEY_LEN; +} + +static void lib80211_tkip_print_stats(struct seq_file *m, void *priv) +{ + struct lib80211_tkip_data *tkip = priv; + seq_printf(m, + "key[%d] alg=TKIP key_set=%d " + "tx_pn=%02x%02x%02x%02x%02x%02x " + "rx_pn=%02x%02x%02x%02x%02x%02x " + "replays=%d icv_errors=%d local_mic_failures=%d\n", + tkip->key_idx, tkip->key_set, + (tkip->tx_iv32 >> 24) & 0xff, + (tkip->tx_iv32 >> 16) & 0xff, + (tkip->tx_iv32 >> 8) & 0xff, + tkip->tx_iv32 & 0xff, + (tkip->tx_iv16 >> 8) & 0xff, + tkip->tx_iv16 & 0xff, + (tkip->rx_iv32 >> 24) & 0xff, + (tkip->rx_iv32 >> 16) & 0xff, + (tkip->rx_iv32 >> 8) & 0xff, + tkip->rx_iv32 & 0xff, + (tkip->rx_iv16 >> 8) & 0xff, + tkip->rx_iv16 & 0xff, + tkip->dot11RSNAStatsTKIPReplays, + tkip->dot11RSNAStatsTKIPICVErrors, + tkip->dot11RSNAStatsTKIPLocalMICFailures); +} + +static struct lib80211_crypto_ops lib80211_crypt_tkip = { + .name = "TKIP", + .init = lib80211_tkip_init, + .deinit = lib80211_tkip_deinit, + .encrypt_mpdu = lib80211_tkip_encrypt, + .decrypt_mpdu = lib80211_tkip_decrypt, + .encrypt_msdu = lib80211_michael_mic_add, + .decrypt_msdu = lib80211_michael_mic_verify, + .set_key = lib80211_tkip_set_key, + .get_key = lib80211_tkip_get_key, + .print_stats = lib80211_tkip_print_stats, + .extra_mpdu_prefix_len = 4 + 4, /* IV + ExtIV */ + .extra_mpdu_postfix_len = 4, /* ICV */ + .extra_msdu_postfix_len = 8, /* MIC */ + .get_flags = lib80211_tkip_get_flags, + .set_flags = lib80211_tkip_set_flags, + .owner = THIS_MODULE, +}; + +static int __init lib80211_crypto_tkip_init(void) +{ + return lib80211_register_crypto_ops(&lib80211_crypt_tkip); +} + +static void __exit lib80211_crypto_tkip_exit(void) +{ + lib80211_unregister_crypto_ops(&lib80211_crypt_tkip); +} + +module_init(lib80211_crypto_tkip_init); +module_exit(lib80211_crypto_tkip_exit); diff --git a/net/wireless/lib80211_crypt_wep.c b/net/wireless/lib80211_crypt_wep.c new file mode 100644 index 000000000..6ab9957b8 --- /dev/null +++ b/net/wireless/lib80211_crypt_wep.c @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * lib80211 crypt: host-based WEP encryption implementation for lib80211 + * + * Copyright (c) 2002-2004, Jouni Malinen + * Copyright (c) 2008, John W. 
Linville + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +MODULE_AUTHOR("Jouni Malinen"); +MODULE_DESCRIPTION("lib80211 crypt: WEP"); +MODULE_LICENSE("GPL"); + +struct lib80211_wep_data { + u32 iv; +#define WEP_KEY_LEN 13 + u8 key[WEP_KEY_LEN + 1]; + u8 key_len; + u8 key_idx; + struct arc4_ctx tx_ctx; + struct arc4_ctx rx_ctx; +}; + +static void *lib80211_wep_init(int keyidx) +{ + struct lib80211_wep_data *priv; + + if (fips_enabled) + return NULL; + + priv = kzalloc(sizeof(*priv), GFP_ATOMIC); + if (priv == NULL) + return NULL; + priv->key_idx = keyidx; + + /* start WEP IV from a random value */ + get_random_bytes(&priv->iv, 4); + + return priv; +} + +static void lib80211_wep_deinit(void *priv) +{ + kfree_sensitive(priv); +} + +/* Add WEP IV/key info to a frame that has at least 4 bytes of headroom */ +static int lib80211_wep_build_iv(struct sk_buff *skb, int hdr_len, + u8 *key, int keylen, void *priv) +{ + struct lib80211_wep_data *wep = priv; + u32 klen; + u8 *pos; + + if (skb_headroom(skb) < 4 || skb->len < hdr_len) + return -1; + + pos = skb_push(skb, 4); + memmove(pos, pos + 4, hdr_len); + pos += hdr_len; + + klen = 3 + wep->key_len; + + wep->iv++; + + /* Fluhrer, Mantin, and Shamir have reported weaknesses in the key + * scheduling algorithm of RC4. At least IVs (KeyByte + 3, 0xff, N) + * can be used to speedup attacks, so avoid using them. */ + if ((wep->iv & 0xff00) == 0xff00) { + u8 B = (wep->iv >> 16) & 0xff; + if (B >= 3 && B < klen) + wep->iv += 0x0100; + } + + /* Prepend 24-bit IV to RC4 key and TX frame */ + *pos++ = (wep->iv >> 16) & 0xff; + *pos++ = (wep->iv >> 8) & 0xff; + *pos++ = wep->iv & 0xff; + *pos++ = wep->key_idx << 6; + + return 0; +} + +/* Perform WEP encryption on given skb that has at least 4 bytes of headroom + * for IV and 4 bytes of tailroom for ICV. Both IV and ICV will be transmitted, + * so the payload length increases with 8 bytes. + * + * WEP frame payload: IV + TX key idx, RC4(data), ICV = RC4(CRC32(data)) + */ +static int lib80211_wep_encrypt(struct sk_buff *skb, int hdr_len, void *priv) +{ + struct lib80211_wep_data *wep = priv; + u32 crc, klen, len; + u8 *pos, *icv; + u8 key[WEP_KEY_LEN + 3]; + + /* other checks are in lib80211_wep_build_iv */ + if (skb_tailroom(skb) < 4) + return -1; + + /* add the IV to the frame */ + if (lib80211_wep_build_iv(skb, hdr_len, NULL, 0, priv)) + return -1; + + /* Copy the IV into the first 3 bytes of the key */ + skb_copy_from_linear_data_offset(skb, hdr_len, key, 3); + + /* Copy rest of the WEP key (the secret part) */ + memcpy(key + 3, wep->key, wep->key_len); + + len = skb->len - hdr_len - 4; + pos = skb->data + hdr_len + 4; + klen = 3 + wep->key_len; + + /* Append little-endian CRC32 over only the data and encrypt it to produce ICV */ + crc = ~crc32_le(~0, pos, len); + icv = skb_put(skb, 4); + icv[0] = crc; + icv[1] = crc >> 8; + icv[2] = crc >> 16; + icv[3] = crc >> 24; + + arc4_setkey(&wep->tx_ctx, key, klen); + arc4_crypt(&wep->tx_ctx, pos, pos, len + 4); + + return 0; +} + +/* Perform WEP decryption on given buffer. Buffer includes whole WEP part of + * the frame: IV (4 bytes), encrypted payload (including SNAP header), + * ICV (4 bytes). len includes both IV and ICV. + * + * Returns 0 if frame was decrypted successfully and ICV was correct and -1 on + * failure. If frame is OK, IV and ICV will be removed. 
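+ *
+ * Illustrative buffer layout handled here (derived from the code below):
+ *   [802.11 header, hdr_len bytes][IV0 IV1 IV2][key idx << 6][RC4(payload + ICV)]
+ * the three IV octets are prepended to the shared secret to form the RC4 key.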
+ */ +static int lib80211_wep_decrypt(struct sk_buff *skb, int hdr_len, void *priv) +{ + struct lib80211_wep_data *wep = priv; + u32 crc, klen, plen; + u8 key[WEP_KEY_LEN + 3]; + u8 keyidx, *pos, icv[4]; + + if (skb->len < hdr_len + 8) + return -1; + + pos = skb->data + hdr_len; + key[0] = *pos++; + key[1] = *pos++; + key[2] = *pos++; + keyidx = *pos++ >> 6; + if (keyidx != wep->key_idx) + return -1; + + klen = 3 + wep->key_len; + + /* Copy rest of the WEP key (the secret part) */ + memcpy(key + 3, wep->key, wep->key_len); + + /* Apply RC4 to data and compute CRC32 over decrypted data */ + plen = skb->len - hdr_len - 8; + + arc4_setkey(&wep->rx_ctx, key, klen); + arc4_crypt(&wep->rx_ctx, pos, pos, plen + 4); + + crc = ~crc32_le(~0, pos, plen); + icv[0] = crc; + icv[1] = crc >> 8; + icv[2] = crc >> 16; + icv[3] = crc >> 24; + if (memcmp(icv, pos + plen, 4) != 0) { + /* ICV mismatch - drop frame */ + return -2; + } + + /* Remove IV and ICV */ + memmove(skb->data + 4, skb->data, hdr_len); + skb_pull(skb, 4); + skb_trim(skb, skb->len - 4); + + return 0; +} + +static int lib80211_wep_set_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_wep_data *wep = priv; + + if (len < 0 || len > WEP_KEY_LEN) + return -1; + + memcpy(wep->key, key, len); + wep->key_len = len; + + return 0; +} + +static int lib80211_wep_get_key(void *key, int len, u8 * seq, void *priv) +{ + struct lib80211_wep_data *wep = priv; + + if (len < wep->key_len) + return -1; + + memcpy(key, wep->key, wep->key_len); + + return wep->key_len; +} + +static void lib80211_wep_print_stats(struct seq_file *m, void *priv) +{ + struct lib80211_wep_data *wep = priv; + seq_printf(m, "key[%d] alg=WEP len=%d\n", wep->key_idx, wep->key_len); +} + +static struct lib80211_crypto_ops lib80211_crypt_wep = { + .name = "WEP", + .init = lib80211_wep_init, + .deinit = lib80211_wep_deinit, + .encrypt_mpdu = lib80211_wep_encrypt, + .decrypt_mpdu = lib80211_wep_decrypt, + .encrypt_msdu = NULL, + .decrypt_msdu = NULL, + .set_key = lib80211_wep_set_key, + .get_key = lib80211_wep_get_key, + .print_stats = lib80211_wep_print_stats, + .extra_mpdu_prefix_len = 4, /* IV */ + .extra_mpdu_postfix_len = 4, /* ICV */ + .owner = THIS_MODULE, +}; + +static int __init lib80211_crypto_wep_init(void) +{ + return lib80211_register_crypto_ops(&lib80211_crypt_wep); +} + +static void __exit lib80211_crypto_wep_exit(void) +{ + lib80211_unregister_crypto_ops(&lib80211_crypt_wep); +} + +module_init(lib80211_crypto_wep_init); +module_exit(lib80211_crypto_wep_exit); diff --git a/net/wireless/mesh.c b/net/wireless/mesh.c new file mode 100644 index 000000000..59a3c5c09 --- /dev/null +++ b/net/wireless/mesh.c @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Portions + * Copyright (C) 2022 Intel Corporation + */ +#include +#include +#include +#include "nl80211.h" +#include "core.h" +#include "rdev-ops.h" + +/* Default values, timeouts in ms */ +#define MESH_TTL 31 +#define MESH_DEFAULT_ELEMENT_TTL 31 +#define MESH_MAX_RETR 3 +#define MESH_RET_T 100 +#define MESH_CONF_T 100 +#define MESH_HOLD_T 100 + +#define MESH_PATH_TIMEOUT 5000 +#define MESH_RANN_INTERVAL 5000 +#define MESH_PATH_TO_ROOT_TIMEOUT 6000 +#define MESH_ROOT_INTERVAL 5000 +#define MESH_ROOT_CONFIRMATION_INTERVAL 2000 +#define MESH_DEFAULT_PLINK_TIMEOUT 1800 /* timeout in seconds */ + +/* + * Minimum interval between two consecutive PREQs originated by the same + * interface + */ +#define MESH_PREQ_MIN_INT 10 +#define MESH_PERR_MIN_INT 100 +#define MESH_DIAM_TRAVERSAL_TIME 50 + +#define 
MESH_RSSI_THRESHOLD 0 + +/* + * A path will be refreshed if it is used PATH_REFRESH_TIME milliseconds + * before timing out. This way it will remain ACTIVE and no data frames + * will be unnecessarily held in the pending queue. + */ +#define MESH_PATH_REFRESH_TIME 1000 +#define MESH_MIN_DISCOVERY_TIMEOUT (2 * MESH_DIAM_TRAVERSAL_TIME) + +/* Default maximum number of established plinks per interface */ +#define MESH_MAX_ESTAB_PLINKS 32 + +#define MESH_MAX_PREQ_RETRIES 4 + +#define MESH_SYNC_NEIGHBOR_OFFSET_MAX 50 + +#define MESH_DEFAULT_BEACON_INTERVAL 1000 /* in 1024 us units (=TUs) */ +#define MESH_DEFAULT_DTIM_PERIOD 2 +#define MESH_DEFAULT_AWAKE_WINDOW 10 /* in 1024 us units (=TUs) */ + +const struct mesh_config default_mesh_config = { + .dot11MeshRetryTimeout = MESH_RET_T, + .dot11MeshConfirmTimeout = MESH_CONF_T, + .dot11MeshHoldingTimeout = MESH_HOLD_T, + .dot11MeshMaxRetries = MESH_MAX_RETR, + .dot11MeshTTL = MESH_TTL, + .element_ttl = MESH_DEFAULT_ELEMENT_TTL, + .auto_open_plinks = true, + .dot11MeshMaxPeerLinks = MESH_MAX_ESTAB_PLINKS, + .dot11MeshNbrOffsetMaxNeighbor = MESH_SYNC_NEIGHBOR_OFFSET_MAX, + .dot11MeshHWMPactivePathTimeout = MESH_PATH_TIMEOUT, + .dot11MeshHWMPpreqMinInterval = MESH_PREQ_MIN_INT, + .dot11MeshHWMPperrMinInterval = MESH_PERR_MIN_INT, + .dot11MeshHWMPnetDiameterTraversalTime = MESH_DIAM_TRAVERSAL_TIME, + .dot11MeshHWMPmaxPREQretries = MESH_MAX_PREQ_RETRIES, + .path_refresh_time = MESH_PATH_REFRESH_TIME, + .min_discovery_timeout = MESH_MIN_DISCOVERY_TIMEOUT, + .dot11MeshHWMPRannInterval = MESH_RANN_INTERVAL, + .dot11MeshGateAnnouncementProtocol = false, + .dot11MeshForwarding = true, + .rssi_threshold = MESH_RSSI_THRESHOLD, + .ht_opmode = IEEE80211_HT_OP_MODE_PROTECTION_NONHT_MIXED, + .dot11MeshHWMPactivePathToRootTimeout = MESH_PATH_TO_ROOT_TIMEOUT, + .dot11MeshHWMProotInterval = MESH_ROOT_INTERVAL, + .dot11MeshHWMPconfirmationInterval = MESH_ROOT_CONFIRMATION_INTERVAL, + .power_mode = NL80211_MESH_POWER_ACTIVE, + .dot11MeshAwakeWindowDuration = MESH_DEFAULT_AWAKE_WINDOW, + .plink_timeout = MESH_DEFAULT_PLINK_TIMEOUT, + .dot11MeshNolearn = false, +}; + +const struct mesh_setup default_mesh_setup = { + /* cfg80211_join_mesh() will pick a channel if needed */ + .sync_method = IEEE80211_SYNC_METHOD_NEIGHBOR_OFFSET, + .path_sel_proto = IEEE80211_PATH_PROTOCOL_HWMP, + .path_metric = IEEE80211_PATH_METRIC_AIRTIME, + .auth_id = 0, /* open */ + .ie = NULL, + .ie_len = 0, + .is_secure = false, + .user_mpm = false, + .beacon_interval = MESH_DEFAULT_BEACON_INTERVAL, + .dtim_period = MESH_DEFAULT_DTIM_PERIOD, +}; + +int __cfg80211_join_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct mesh_setup *setup, + const struct mesh_config *conf) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + BUILD_BUG_ON(IEEE80211_MAX_SSID_LEN != IEEE80211_MAX_MESH_ID_LEN); + + ASSERT_WDEV_LOCK(wdev); + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_MESH_AUTH) && + setup->is_secure) + return -EOPNOTSUPP; + + if (wdev->u.mesh.id_len) + return -EALREADY; + + if (!setup->mesh_id_len) + return -EINVAL; + + if (!rdev->ops->join_mesh) + return -EOPNOTSUPP; + + if (!setup->chandef.chan) { + /* if no channel explicitly given, use preset channel */ + setup->chandef = wdev->u.mesh.preset_chandef; + } + + if (!setup->chandef.chan) { + /* if we don't have that either, use the first usable channel */ + enum nl80211_band band; + + for (band = 0; band < NUM_NL80211_BANDS; band++) 
{ + struct ieee80211_supported_band *sband; + struct ieee80211_channel *chan; + int i; + + sband = rdev->wiphy.bands[band]; + if (!sband) + continue; + + for (i = 0; i < sband->n_channels; i++) { + chan = &sband->channels[i]; + if (chan->flags & (IEEE80211_CHAN_NO_IR | + IEEE80211_CHAN_DISABLED | + IEEE80211_CHAN_RADAR)) + continue; + setup->chandef.chan = chan; + break; + } + + if (setup->chandef.chan) + break; + } + + /* no usable channel ... */ + if (!setup->chandef.chan) + return -EINVAL; + + setup->chandef.width = NL80211_CHAN_WIDTH_20_NOHT; + setup->chandef.center_freq1 = setup->chandef.chan->center_freq; + } + + /* + * check if basic rates are available otherwise use mandatory rates as + * basic rates + */ + if (!setup->basic_rates) { + enum nl80211_bss_scan_width scan_width; + struct ieee80211_supported_band *sband = + rdev->wiphy.bands[setup->chandef.chan->band]; + + if (setup->chandef.chan->band == NL80211_BAND_2GHZ) { + int i; + + /* + * Older versions selected the mandatory rates for + * 2.4 GHz as well, but were broken in that only + * 1 Mbps was regarded as a mandatory rate. Keep + * using just 1 Mbps as the default basic rate for + * mesh to be interoperable with older versions. + */ + for (i = 0; i < sband->n_bitrates; i++) { + if (sband->bitrates[i].bitrate == 10) { + setup->basic_rates = BIT(i); + break; + } + } + } else { + scan_width = cfg80211_chandef_to_scan_width(&setup->chandef); + setup->basic_rates = ieee80211_mandatory_rates(sband, + scan_width); + } + } + + err = cfg80211_chandef_dfs_required(&rdev->wiphy, + &setup->chandef, + NL80211_IFTYPE_MESH_POINT); + if (err < 0) + return err; + if (err > 0 && !setup->userspace_handles_dfs) + return -EINVAL; + + if (!cfg80211_reg_can_beacon(&rdev->wiphy, &setup->chandef, + NL80211_IFTYPE_MESH_POINT)) + return -EINVAL; + + err = rdev_join_mesh(rdev, dev, conf, setup); + if (!err) { + memcpy(wdev->u.mesh.id, setup->mesh_id, setup->mesh_id_len); + wdev->u.mesh.id_len = setup->mesh_id_len; + wdev->u.mesh.chandef = setup->chandef; + wdev->u.mesh.beacon_interval = setup->beacon_interval; + } + + return err; +} + +int cfg80211_set_mesh_channel(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_chan_def *chandef) +{ + int err; + + /* + * Workaround for libertas (only!), it puts the interface + * into mesh mode but doesn't implement join_mesh. Instead, + * it is configured via sysfs and then joins the mesh when + * you set the channel. Note that the libertas mesh isn't + * compatible with 802.11 mesh. 
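+ * (for that driver, the path below only validates a 20 MHz no-HT chandef
+ * and calls rdev_libertas_set_mesh_channel() directly instead of going
+ * through rdev_join_mesh())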
+ */ + if (rdev->ops->libertas_set_mesh_channel) { + if (chandef->width != NL80211_CHAN_WIDTH_20_NOHT) + return -EINVAL; + + if (!netif_running(wdev->netdev)) + return -ENETDOWN; + + err = rdev_libertas_set_mesh_channel(rdev, wdev->netdev, + chandef->chan); + if (!err) + wdev->u.mesh.chandef = *chandef; + + return err; + } + + if (wdev->u.mesh.id_len) + return -EBUSY; + + wdev->u.mesh.preset_chandef = *chandef; + return 0; +} + +int __cfg80211_leave_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + if (!rdev->ops->leave_mesh) + return -EOPNOTSUPP; + + if (!wdev->u.mesh.id_len) + return -ENOTCONN; + + err = rdev_leave_mesh(rdev, dev); + if (!err) { + wdev->conn_owner_nlportid = 0; + wdev->u.mesh.id_len = 0; + wdev->u.mesh.beacon_interval = 0; + memset(&wdev->u.mesh.chandef, 0, + sizeof(wdev->u.mesh.chandef)); + rdev_set_qos_map(rdev, dev, NULL); + cfg80211_sched_dfs_chan_update(rdev); + } + + return err; +} + +int cfg80211_leave_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + wdev_lock(wdev); + err = __cfg80211_leave_mesh(rdev, dev); + wdev_unlock(wdev); + + return err; +} diff --git a/net/wireless/mlme.c b/net/wireless/mlme.c new file mode 100644 index 000000000..e7fa06083 --- /dev/null +++ b/net/wireless/mlme.c @@ -0,0 +1,1160 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * cfg80211 MLME SAP interface + * + * Copyright (c) 2009, Jouni Malinen + * Copyright (c) 2015 Intel Deutschland GmbH + * Copyright (C) 2019-2020, 2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "nl80211.h" +#include "rdev-ops.h" + + +void cfg80211_rx_assoc_resp(struct net_device *dev, + struct cfg80211_rx_assoc_resp *data) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)data->buf; + struct cfg80211_connect_resp_params cr = { + .timeout_reason = NL80211_TIMEOUT_UNSPECIFIED, + .req_ie = data->req_ies, + .req_ie_len = data->req_ies_len, + .resp_ie = mgmt->u.assoc_resp.variable, + .resp_ie_len = data->len - + offsetof(struct ieee80211_mgmt, + u.assoc_resp.variable), + .status = le16_to_cpu(mgmt->u.assoc_resp.status_code), + .ap_mld_addr = data->ap_mld_addr, + }; + unsigned int link_id; + + for (link_id = 0; link_id < ARRAY_SIZE(data->links); link_id++) { + cr.links[link_id].bss = data->links[link_id].bss; + if (!cr.links[link_id].bss) + continue; + cr.links[link_id].bssid = data->links[link_id].bss->bssid; + cr.links[link_id].addr = data->links[link_id].addr; + /* need to have local link addresses for MLO connections */ + WARN_ON(cr.ap_mld_addr && !cr.links[link_id].addr); + + BUG_ON(!cr.links[link_id].bss->channel); + + if (cr.links[link_id].bss->channel->band == NL80211_BAND_S1GHZ) { + WARN_ON(link_id); + cr.resp_ie = (u8 *)&mgmt->u.s1g_assoc_resp.variable; + cr.resp_ie_len = data->len - + offsetof(struct ieee80211_mgmt, + u.s1g_assoc_resp.variable); + } + + if (cr.ap_mld_addr) + cr.valid_links |= BIT(link_id); + } + + trace_cfg80211_send_rx_assoc(dev, data); + + /* + * This is a bit of a hack, we don't notify userspace of + * a (re-)association reply if we tried to 
send a reassoc + * and got a reject -- we only try again with an assoc + * frame instead of reassoc. + */ + if (cfg80211_sme_rx_assoc_resp(wdev, cr.status)) { + for (link_id = 0; link_id < ARRAY_SIZE(data->links); link_id++) { + struct cfg80211_bss *bss = data->links[link_id].bss; + + if (!bss) + continue; + + cfg80211_unhold_bss(bss_from_pub(bss)); + cfg80211_put_bss(wiphy, bss); + } + return; + } + + nl80211_send_rx_assoc(rdev, dev, data); + /* update current_bss etc., consumes the bss reference */ + __cfg80211_connect_result(dev, &cr, cr.status == WLAN_STATUS_SUCCESS); +} +EXPORT_SYMBOL(cfg80211_rx_assoc_resp); + +static void cfg80211_process_auth(struct wireless_dev *wdev, + const u8 *buf, size_t len) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + nl80211_send_rx_auth(rdev, wdev->netdev, buf, len, GFP_KERNEL); + cfg80211_sme_rx_auth(wdev, buf, len); +} + +static void cfg80211_process_deauth(struct wireless_dev *wdev, + const u8 *buf, size_t len, + bool reconnect) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)buf; + const u8 *bssid = mgmt->bssid; + u16 reason_code = le16_to_cpu(mgmt->u.deauth.reason_code); + bool from_ap = !ether_addr_equal(mgmt->sa, wdev->netdev->dev_addr); + + nl80211_send_deauth(rdev, wdev->netdev, buf, len, reconnect, GFP_KERNEL); + + if (!wdev->connected || !ether_addr_equal(wdev->u.client.connected_addr, bssid)) + return; + + __cfg80211_disconnected(wdev->netdev, NULL, 0, reason_code, from_ap); + cfg80211_sme_deauth(wdev); +} + +static void cfg80211_process_disassoc(struct wireless_dev *wdev, + const u8 *buf, size_t len, + bool reconnect) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)buf; + const u8 *bssid = mgmt->bssid; + u16 reason_code = le16_to_cpu(mgmt->u.disassoc.reason_code); + bool from_ap = !ether_addr_equal(mgmt->sa, wdev->netdev->dev_addr); + + nl80211_send_disassoc(rdev, wdev->netdev, buf, len, reconnect, + GFP_KERNEL); + + if (WARN_ON(!wdev->connected || + !ether_addr_equal(wdev->u.client.connected_addr, bssid))) + return; + + __cfg80211_disconnected(wdev->netdev, NULL, 0, reason_code, from_ap); + cfg80211_sme_disassoc(wdev); +} + +void cfg80211_rx_mlme_mgmt(struct net_device *dev, const u8 *buf, size_t len) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct ieee80211_mgmt *mgmt = (void *)buf; + + ASSERT_WDEV_LOCK(wdev); + + trace_cfg80211_rx_mlme_mgmt(dev, buf, len); + + if (WARN_ON(len < 2)) + return; + + if (ieee80211_is_auth(mgmt->frame_control)) + cfg80211_process_auth(wdev, buf, len); + else if (ieee80211_is_deauth(mgmt->frame_control)) + cfg80211_process_deauth(wdev, buf, len, false); + else if (ieee80211_is_disassoc(mgmt->frame_control)) + cfg80211_process_disassoc(wdev, buf, len, false); +} +EXPORT_SYMBOL(cfg80211_rx_mlme_mgmt); + +void cfg80211_auth_timeout(struct net_device *dev, const u8 *addr) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_send_auth_timeout(dev, addr); + + nl80211_send_auth_timeout(rdev, dev, addr, GFP_KERNEL); + cfg80211_sme_auth_timeout(wdev); +} +EXPORT_SYMBOL(cfg80211_auth_timeout); + +void cfg80211_assoc_failure(struct net_device *dev, + struct cfg80211_assoc_failure *data) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct 
cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + const u8 *addr = data->ap_mld_addr ?: data->bss[0]->bssid; + int i; + + trace_cfg80211_send_assoc_failure(dev, data); + + if (data->timeout) { + nl80211_send_assoc_timeout(rdev, dev, addr, GFP_KERNEL); + cfg80211_sme_assoc_timeout(wdev); + } else { + cfg80211_sme_abandon_assoc(wdev); + } + + for (i = 0; i < ARRAY_SIZE(data->bss); i++) { + struct cfg80211_bss *bss = data->bss[i]; + + if (!bss) + continue; + + cfg80211_unhold_bss(bss_from_pub(bss)); + cfg80211_put_bss(wiphy, bss); + } +} +EXPORT_SYMBOL(cfg80211_assoc_failure); + +void cfg80211_tx_mlme_mgmt(struct net_device *dev, const u8 *buf, size_t len, + bool reconnect) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct ieee80211_mgmt *mgmt = (void *)buf; + + ASSERT_WDEV_LOCK(wdev); + + trace_cfg80211_tx_mlme_mgmt(dev, buf, len, reconnect); + + if (WARN_ON(len < 2)) + return; + + if (ieee80211_is_deauth(mgmt->frame_control)) + cfg80211_process_deauth(wdev, buf, len, reconnect); + else + cfg80211_process_disassoc(wdev, buf, len, reconnect); +} +EXPORT_SYMBOL(cfg80211_tx_mlme_mgmt); + +void cfg80211_michael_mic_failure(struct net_device *dev, const u8 *addr, + enum nl80211_key_type key_type, int key_id, + const u8 *tsc, gfp_t gfp) +{ + struct wiphy *wiphy = dev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; + char *buf = kmalloc(128, gfp); + + if (buf) { + sprintf(buf, "MLME-MICHAELMICFAILURE.indication(" + "keyid=%d %scast addr=%pM)", key_id, + key_type == NL80211_KEYTYPE_GROUP ? "broad" : "uni", + addr); + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = strlen(buf); + wireless_send_event(dev, IWEVCUSTOM, &wrqu, buf); + kfree(buf); + } +#endif + + trace_cfg80211_michael_mic_failure(dev, addr, key_type, key_id, tsc); + nl80211_michael_mic_failure(rdev, dev, addr, key_type, key_id, tsc, gfp); +} +EXPORT_SYMBOL(cfg80211_michael_mic_failure); + +/* some MLME handling for userspace SME */ +int cfg80211_mlme_auth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_auth_request *req) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + ASSERT_WDEV_LOCK(wdev); + + if (!req->bss) + return -ENOENT; + + if (req->link_id >= 0 && + !(wdev->wiphy->flags & WIPHY_FLAG_SUPPORTS_MLO)) + return -EINVAL; + + if (req->auth_type == NL80211_AUTHTYPE_SHARED_KEY) { + if (!req->key || !req->key_len || + req->key_idx < 0 || req->key_idx > 3) + return -EINVAL; + } + + if (wdev->connected && + ether_addr_equal(req->bss->bssid, wdev->u.client.connected_addr)) + return -EALREADY; + + if (ether_addr_equal(req->bss->bssid, dev->dev_addr) || + (req->link_id >= 0 && + ether_addr_equal(req->ap_mld_addr, dev->dev_addr))) + return -EINVAL; + + return rdev_auth(rdev, dev, req); +} + +/* Do a logical ht_capa &= ht_capa_mask. */ +void cfg80211_oper_and_ht_capa(struct ieee80211_ht_cap *ht_capa, + const struct ieee80211_ht_cap *ht_capa_mask) +{ + int i; + u8 *p1, *p2; + if (!ht_capa_mask) { + memset(ht_capa, 0, sizeof(*ht_capa)); + return; + } + + p1 = (u8*)(ht_capa); + p2 = (u8*)(ht_capa_mask); + for (i = 0; i < sizeof(*ht_capa); i++) + p1[i] &= p2[i]; +} + +/* Do a logical vht_capa &= vht_capa_mask. 
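+ * (a byte-wise AND of the two structures; a NULL mask zeroes the
+ * capabilities entirely, matching cfg80211_oper_and_ht_capa() above)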
*/ +void cfg80211_oper_and_vht_capa(struct ieee80211_vht_cap *vht_capa, + const struct ieee80211_vht_cap *vht_capa_mask) +{ + int i; + u8 *p1, *p2; + if (!vht_capa_mask) { + memset(vht_capa, 0, sizeof(*vht_capa)); + return; + } + + p1 = (u8*)(vht_capa); + p2 = (u8*)(vht_capa_mask); + for (i = 0; i < sizeof(*vht_capa); i++) + p1[i] &= p2[i]; +} + +/* Note: caller must cfg80211_put_bss() regardless of result */ +int cfg80211_mlme_assoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_assoc_request *req) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err, i, j; + + ASSERT_WDEV_LOCK(wdev); + + for (i = 1; i < ARRAY_SIZE(req->links); i++) { + if (!req->links[i].bss) + continue; + for (j = 0; j < i; j++) { + if (req->links[i].bss == req->links[j].bss) + return -EINVAL; + } + + if (ether_addr_equal(req->links[i].bss->bssid, dev->dev_addr)) + return -EINVAL; + } + + if (wdev->connected && + (!req->prev_bssid || + !ether_addr_equal(wdev->u.client.connected_addr, req->prev_bssid))) + return -EALREADY; + + if ((req->bss && ether_addr_equal(req->bss->bssid, dev->dev_addr)) || + (req->link_id >= 0 && + ether_addr_equal(req->ap_mld_addr, dev->dev_addr))) + return -EINVAL; + + cfg80211_oper_and_ht_capa(&req->ht_capa_mask, + rdev->wiphy.ht_capa_mod_mask); + cfg80211_oper_and_vht_capa(&req->vht_capa_mask, + rdev->wiphy.vht_capa_mod_mask); + + err = rdev_assoc(rdev, dev, req); + if (!err) { + int link_id; + + if (req->bss) { + cfg80211_ref_bss(&rdev->wiphy, req->bss); + cfg80211_hold_bss(bss_from_pub(req->bss)); + } + + for (link_id = 0; link_id < ARRAY_SIZE(req->links); link_id++) { + if (!req->links[link_id].bss) + continue; + cfg80211_ref_bss(&rdev->wiphy, req->links[link_id].bss); + cfg80211_hold_bss(bss_from_pub(req->links[link_id].bss)); + } + } + return err; +} + +int cfg80211_mlme_deauth(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *bssid, + const u8 *ie, int ie_len, u16 reason, + bool local_state_change) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_deauth_request req = { + .bssid = bssid, + .reason_code = reason, + .ie = ie, + .ie_len = ie_len, + .local_state_change = local_state_change, + }; + + ASSERT_WDEV_LOCK(wdev); + + if (local_state_change && + (!wdev->connected || + !ether_addr_equal(wdev->u.client.connected_addr, bssid))) + return 0; + + if (ether_addr_equal(wdev->disconnect_bssid, bssid) || + (wdev->connected && + ether_addr_equal(wdev->u.client.connected_addr, bssid))) + wdev->conn_owner_nlportid = 0; + + return rdev_deauth(rdev, dev, &req); +} + +int cfg80211_mlme_disassoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *ap_addr, + const u8 *ie, int ie_len, u16 reason, + bool local_state_change) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_disassoc_request req = { + .reason_code = reason, + .local_state_change = local_state_change, + .ie = ie, + .ie_len = ie_len, + .ap_addr = ap_addr, + }; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->connected) + return -ENOTCONN; + + if (memcmp(wdev->u.client.connected_addr, ap_addr, ETH_ALEN)) + return -ENOTCONN; + + err = rdev_disassoc(rdev, dev, &req); + if (err) + return err; + + /* driver should have reported the disassoc */ + WARN_ON(wdev->connected); + return 0; +} + +void cfg80211_mlme_down(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + u8 bssid[ETH_ALEN]; + + ASSERT_WDEV_LOCK(wdev); + + if 
(!rdev->ops->deauth) + return; + + if (!wdev->connected) + return; + + memcpy(bssid, wdev->u.client.connected_addr, ETH_ALEN); + cfg80211_mlme_deauth(rdev, dev, bssid, NULL, 0, + WLAN_REASON_DEAUTH_LEAVING, false); +} + +struct cfg80211_mgmt_registration { + struct list_head list; + struct wireless_dev *wdev; + + u32 nlportid; + + int match_len; + + __le16 frame_type; + + bool multicast_rx; + + u8 match[]; +}; + +static void cfg80211_mgmt_registrations_update(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct wireless_dev *tmp; + struct cfg80211_mgmt_registration *reg; + struct mgmt_frame_regs upd = {}; + + lockdep_assert_held(&rdev->wiphy.mtx); + + spin_lock_bh(&rdev->mgmt_registrations_lock); + if (!wdev->mgmt_registrations_need_update) { + spin_unlock_bh(&rdev->mgmt_registrations_lock); + return; + } + + rcu_read_lock(); + list_for_each_entry_rcu(tmp, &rdev->wiphy.wdev_list, list) { + list_for_each_entry(reg, &tmp->mgmt_registrations, list) { + u32 mask = BIT(le16_to_cpu(reg->frame_type) >> 4); + u32 mcast_mask = 0; + + if (reg->multicast_rx) + mcast_mask = mask; + + upd.global_stypes |= mask; + upd.global_mcast_stypes |= mcast_mask; + + if (tmp == wdev) { + upd.interface_stypes |= mask; + upd.interface_mcast_stypes |= mcast_mask; + } + } + } + rcu_read_unlock(); + + wdev->mgmt_registrations_need_update = 0; + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + rdev_update_mgmt_frame_registrations(rdev, wdev, &upd); +} + +void cfg80211_mgmt_registrations_update_wk(struct work_struct *wk) +{ + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + + rdev = container_of(wk, struct cfg80211_registered_device, + mgmt_registrations_update_wk); + + wiphy_lock(&rdev->wiphy); + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) + cfg80211_mgmt_registrations_update(wdev); + wiphy_unlock(&rdev->wiphy); +} + +int cfg80211_mlme_register_mgmt(struct wireless_dev *wdev, u32 snd_portid, + u16 frame_type, const u8 *match_data, + int match_len, bool multicast_rx, + struct netlink_ext_ack *extack) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_mgmt_registration *reg, *nreg; + int err = 0; + u16 mgmt_type; + bool update_multicast = false; + + if (!wdev->wiphy->mgmt_stypes) + return -EOPNOTSUPP; + + if ((frame_type & IEEE80211_FCTL_FTYPE) != IEEE80211_FTYPE_MGMT) { + NL_SET_ERR_MSG(extack, "frame type not management"); + return -EINVAL; + } + + if (frame_type & ~(IEEE80211_FCTL_FTYPE | IEEE80211_FCTL_STYPE)) { + NL_SET_ERR_MSG(extack, "Invalid frame type"); + return -EINVAL; + } + + mgmt_type = (frame_type & IEEE80211_FCTL_STYPE) >> 4; + if (!(wdev->wiphy->mgmt_stypes[wdev->iftype].rx & BIT(mgmt_type))) { + NL_SET_ERR_MSG(extack, + "Registration to specific type not supported"); + return -EINVAL; + } + + /* + * To support Pre Association Security Negotiation (PASN), registration + * for authentication frames should be supported. However, as some + * versions of the user space daemons wrongly register to all types of + * authentication frames (which might result in unexpected behavior) + * allow such registration if the request is for a specific + * authentication algorithm number. 
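+ * (in an Authentication frame the body starts with the 2-octet
+ * authentication algorithm number, hence the match_len >= 2 check below)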
+ */ + if (wdev->iftype == NL80211_IFTYPE_STATION && + (frame_type & IEEE80211_FCTL_STYPE) == IEEE80211_STYPE_AUTH && + !(match_data && match_len >= 2)) { + NL_SET_ERR_MSG(extack, + "Authentication algorithm number required"); + return -EINVAL; + } + + nreg = kzalloc(sizeof(*reg) + match_len, GFP_KERNEL); + if (!nreg) + return -ENOMEM; + + spin_lock_bh(&rdev->mgmt_registrations_lock); + + list_for_each_entry(reg, &wdev->mgmt_registrations, list) { + int mlen = min(match_len, reg->match_len); + + if (frame_type != le16_to_cpu(reg->frame_type)) + continue; + + if (memcmp(reg->match, match_data, mlen) == 0) { + if (reg->multicast_rx != multicast_rx) { + update_multicast = true; + reg->multicast_rx = multicast_rx; + break; + } + NL_SET_ERR_MSG(extack, "Match already configured"); + err = -EALREADY; + break; + } + } + + if (err) + goto out; + + if (update_multicast) { + kfree(nreg); + } else { + memcpy(nreg->match, match_data, match_len); + nreg->match_len = match_len; + nreg->nlportid = snd_portid; + nreg->frame_type = cpu_to_le16(frame_type); + nreg->wdev = wdev; + nreg->multicast_rx = multicast_rx; + list_add(&nreg->list, &wdev->mgmt_registrations); + } + wdev->mgmt_registrations_need_update = 1; + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + cfg80211_mgmt_registrations_update(wdev); + + return 0; + + out: + kfree(nreg); + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + return err; +} + +void cfg80211_mlme_unregister_socket(struct wireless_dev *wdev, u32 nlportid) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_mgmt_registration *reg, *tmp; + + spin_lock_bh(&rdev->mgmt_registrations_lock); + + list_for_each_entry_safe(reg, tmp, &wdev->mgmt_registrations, list) { + if (reg->nlportid != nlportid) + continue; + + list_del(®->list); + kfree(reg); + + wdev->mgmt_registrations_need_update = 1; + schedule_work(&rdev->mgmt_registrations_update_wk); + } + + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + if (nlportid && rdev->crit_proto_nlportid == nlportid) { + rdev->crit_proto_nlportid = 0; + rdev_crit_proto_stop(rdev, wdev); + } + + if (nlportid == wdev->ap_unexpected_nlportid) + wdev->ap_unexpected_nlportid = 0; +} + +void cfg80211_mlme_purge_registrations(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_mgmt_registration *reg, *tmp; + + spin_lock_bh(&rdev->mgmt_registrations_lock); + list_for_each_entry_safe(reg, tmp, &wdev->mgmt_registrations, list) { + list_del(®->list); + kfree(reg); + } + wdev->mgmt_registrations_need_update = 1; + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + cfg80211_mgmt_registrations_update(wdev); +} + +static bool cfg80211_allowed_address(struct wireless_dev *wdev, const u8 *addr) +{ + int i; + + for_each_valid_link(wdev, i) { + if (ether_addr_equal(addr, wdev->links[i].addr)) + return true; + } + + return ether_addr_equal(addr, wdev_address(wdev)); +} + +int cfg80211_mlme_mgmt_tx(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params, u64 *cookie) +{ + const struct ieee80211_mgmt *mgmt; + u16 stype; + + if (!wdev->wiphy->mgmt_stypes) + return -EOPNOTSUPP; + + if (!rdev->ops->mgmt_tx) + return -EOPNOTSUPP; + + if (params->len < 24 + 1) + return -EINVAL; + + mgmt = (const struct ieee80211_mgmt *)params->buf; + + if (!ieee80211_is_mgmt(mgmt->frame_control)) + return -EINVAL; + + stype = le16_to_cpu(mgmt->frame_control) & IEEE80211_FCTL_STYPE; + if 
(!(wdev->wiphy->mgmt_stypes[wdev->iftype].tx & BIT(stype >> 4))) + return -EINVAL; + + if (ieee80211_is_action(mgmt->frame_control) && + mgmt->u.action.category != WLAN_CATEGORY_PUBLIC) { + int err = 0; + + wdev_lock(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + /* + * check for IBSS DA must be done by driver as + * cfg80211 doesn't track the stations + */ + if (!wdev->u.ibss.current_bss || + !ether_addr_equal(wdev->u.ibss.current_bss->pub.bssid, + mgmt->bssid)) { + err = -ENOTCONN; + break; + } + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + if (!wdev->connected) { + err = -ENOTCONN; + break; + } + + /* FIXME: MLD may address this differently */ + + if (!ether_addr_equal(wdev->u.client.connected_addr, + mgmt->bssid)) { + err = -ENOTCONN; + break; + } + + /* for station, check that DA is the AP */ + if (!ether_addr_equal(wdev->u.client.connected_addr, + mgmt->da)) { + err = -ENOTCONN; + break; + } + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_AP_VLAN: + if (!ether_addr_equal(mgmt->bssid, wdev_address(wdev))) + err = -EINVAL; + break; + case NL80211_IFTYPE_MESH_POINT: + if (!ether_addr_equal(mgmt->sa, mgmt->bssid)) { + err = -EINVAL; + break; + } + /* + * check for mesh DA must be done by driver as + * cfg80211 doesn't track the stations + */ + break; + case NL80211_IFTYPE_P2P_DEVICE: + /* + * fall through, P2P device only supports + * public action frames + */ + case NL80211_IFTYPE_NAN: + default: + err = -EOPNOTSUPP; + break; + } + wdev_unlock(wdev); + + if (err) + return err; + } + + if (!cfg80211_allowed_address(wdev, mgmt->sa)) { + /* Allow random TA to be used with Public Action frames if the + * driver has indicated support for this. Otherwise, only allow + * the local address to be used. + */ + if (!ieee80211_is_action(mgmt->frame_control) || + mgmt->u.action.category != WLAN_CATEGORY_PUBLIC) + return -EINVAL; + if (!wdev->connected && + !wiphy_ext_feature_isset( + &rdev->wiphy, + NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA)) + return -EINVAL; + if (wdev->connected && + !wiphy_ext_feature_isset( + &rdev->wiphy, + NL80211_EXT_FEATURE_MGMT_TX_RANDOM_TA_CONNECTED)) + return -EINVAL; + } + + /* Transmit the management frame as requested by user space */ + return rdev_mgmt_tx(rdev, wdev, params, cookie); +} + +bool cfg80211_rx_mgmt_ext(struct wireless_dev *wdev, + struct cfg80211_rx_info *info) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_mgmt_registration *reg; + const struct ieee80211_txrx_stypes *stypes = + &wiphy->mgmt_stypes[wdev->iftype]; + struct ieee80211_mgmt *mgmt = (void *)info->buf; + const u8 *data; + int data_len; + bool result = false; + __le16 ftype = mgmt->frame_control & + cpu_to_le16(IEEE80211_FCTL_FTYPE | IEEE80211_FCTL_STYPE); + u16 stype; + + trace_cfg80211_rx_mgmt(wdev, info); + stype = (le16_to_cpu(mgmt->frame_control) & IEEE80211_FCTL_STYPE) >> 4; + + if (!(stypes->rx & BIT(stype))) { + trace_cfg80211_return_bool(false); + return false; + } + + data = info->buf + ieee80211_hdrlen(mgmt->frame_control); + data_len = info->len - ieee80211_hdrlen(mgmt->frame_control); + + spin_lock_bh(&rdev->mgmt_registrations_lock); + + list_for_each_entry(reg, &wdev->mgmt_registrations, list) { + if (reg->frame_type != ftype) + continue; + + if (reg->match_len > data_len) + continue; + + if (memcmp(reg->match, data, reg->match_len)) + continue; + + /* found match! 
*/ + + /* Indicate the received Action frame to user space */ + if (nl80211_send_mgmt(rdev, wdev, reg->nlportid, info, + GFP_ATOMIC)) + continue; + + result = true; + break; + } + + spin_unlock_bh(&rdev->mgmt_registrations_lock); + + trace_cfg80211_return_bool(result); + return result; +} +EXPORT_SYMBOL(cfg80211_rx_mgmt_ext); + +void cfg80211_sched_dfs_chan_update(struct cfg80211_registered_device *rdev) +{ + cancel_delayed_work(&rdev->dfs_update_channels_wk); + queue_delayed_work(cfg80211_wq, &rdev->dfs_update_channels_wk, 0); +} + +void cfg80211_dfs_channels_update_work(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct cfg80211_registered_device *rdev; + struct cfg80211_chan_def chandef; + struct ieee80211_supported_band *sband; + struct ieee80211_channel *c; + struct wiphy *wiphy; + bool check_again = false; + unsigned long timeout, next_time = 0; + unsigned long time_dfs_update; + enum nl80211_radar_event radar_event; + int bandid, i; + + rdev = container_of(delayed_work, struct cfg80211_registered_device, + dfs_update_channels_wk); + wiphy = &rdev->wiphy; + + rtnl_lock(); + for (bandid = 0; bandid < NUM_NL80211_BANDS; bandid++) { + sband = wiphy->bands[bandid]; + if (!sband) + continue; + + for (i = 0; i < sband->n_channels; i++) { + c = &sband->channels[i]; + + if (!(c->flags & IEEE80211_CHAN_RADAR)) + continue; + + if (c->dfs_state != NL80211_DFS_UNAVAILABLE && + c->dfs_state != NL80211_DFS_AVAILABLE) + continue; + + if (c->dfs_state == NL80211_DFS_UNAVAILABLE) { + time_dfs_update = IEEE80211_DFS_MIN_NOP_TIME_MS; + radar_event = NL80211_RADAR_NOP_FINISHED; + } else { + if (regulatory_pre_cac_allowed(wiphy) || + cfg80211_any_wiphy_oper_chan(wiphy, c)) + continue; + + time_dfs_update = REG_PRE_CAC_EXPIRY_GRACE_MS; + radar_event = NL80211_RADAR_PRE_CAC_EXPIRED; + } + + timeout = c->dfs_state_entered + + msecs_to_jiffies(time_dfs_update); + + if (time_after_eq(jiffies, timeout)) { + c->dfs_state = NL80211_DFS_USABLE; + c->dfs_state_entered = jiffies; + + cfg80211_chandef_create(&chandef, c, + NL80211_CHAN_NO_HT); + + nl80211_radar_notify(rdev, &chandef, + radar_event, NULL, + GFP_ATOMIC); + + regulatory_propagate_dfs_state(wiphy, &chandef, + c->dfs_state, + radar_event); + continue; + } + + if (!check_again) + next_time = timeout - jiffies; + else + next_time = min(next_time, timeout - jiffies); + check_again = true; + } + } + rtnl_unlock(); + + /* reschedule if there are other channels waiting to be cleared again */ + if (check_again) + queue_delayed_work(cfg80211_wq, &rdev->dfs_update_channels_wk, + next_time); +} + + +void __cfg80211_radar_event(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + bool offchan, gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_radar_event(wiphy, chandef, offchan); + + /* only set the chandef supplied channel to unavailable, in + * case the radar is detected on only one of multiple channels + * spanned by the chandef. 
+ */ + cfg80211_set_dfs_state(wiphy, chandef, NL80211_DFS_UNAVAILABLE); + + if (offchan) + queue_work(cfg80211_wq, &rdev->background_cac_abort_wk); + + cfg80211_sched_dfs_chan_update(rdev); + + nl80211_radar_notify(rdev, chandef, NL80211_RADAR_DETECTED, NULL, gfp); + + memcpy(&rdev->radar_chandef, chandef, sizeof(struct cfg80211_chan_def)); + queue_work(cfg80211_wq, &rdev->propagate_radar_detect_wk); +} +EXPORT_SYMBOL(__cfg80211_radar_event); + +void cfg80211_cac_event(struct net_device *netdev, + const struct cfg80211_chan_def *chandef, + enum nl80211_radar_event event, gfp_t gfp) +{ + struct wireless_dev *wdev = netdev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + unsigned long timeout; + + /* not yet supported */ + if (wdev->valid_links) + return; + + trace_cfg80211_cac_event(netdev, event); + + if (WARN_ON(!wdev->cac_started && event != NL80211_RADAR_CAC_STARTED)) + return; + + switch (event) { + case NL80211_RADAR_CAC_FINISHED: + timeout = wdev->cac_start_time + + msecs_to_jiffies(wdev->cac_time_ms); + WARN_ON(!time_after_eq(jiffies, timeout)); + cfg80211_set_dfs_state(wiphy, chandef, NL80211_DFS_AVAILABLE); + memcpy(&rdev->cac_done_chandef, chandef, + sizeof(struct cfg80211_chan_def)); + queue_work(cfg80211_wq, &rdev->propagate_cac_done_wk); + cfg80211_sched_dfs_chan_update(rdev); + fallthrough; + case NL80211_RADAR_CAC_ABORTED: + wdev->cac_started = false; + break; + case NL80211_RADAR_CAC_STARTED: + wdev->cac_started = true; + break; + default: + WARN_ON(1); + return; + } + + nl80211_radar_notify(rdev, chandef, event, netdev, gfp); +} +EXPORT_SYMBOL(cfg80211_cac_event); + +static void +__cfg80211_background_cac_event(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + const struct cfg80211_chan_def *chandef, + enum nl80211_radar_event event) +{ + struct wiphy *wiphy = &rdev->wiphy; + struct net_device *netdev; + + lockdep_assert_wiphy(&rdev->wiphy); + + if (!cfg80211_chandef_valid(chandef)) + return; + + if (!rdev->background_radar_wdev) + return; + + switch (event) { + case NL80211_RADAR_CAC_FINISHED: + cfg80211_set_dfs_state(wiphy, chandef, NL80211_DFS_AVAILABLE); + memcpy(&rdev->cac_done_chandef, chandef, sizeof(*chandef)); + queue_work(cfg80211_wq, &rdev->propagate_cac_done_wk); + cfg80211_sched_dfs_chan_update(rdev); + wdev = rdev->background_radar_wdev; + break; + case NL80211_RADAR_CAC_ABORTED: + if (!cancel_delayed_work(&rdev->background_cac_done_wk)) + return; + wdev = rdev->background_radar_wdev; + break; + case NL80211_RADAR_CAC_STARTED: + break; + default: + return; + } + + netdev = wdev ? 
wdev->netdev : NULL; + nl80211_radar_notify(rdev, chandef, event, netdev, GFP_KERNEL); +} + +static void +cfg80211_background_cac_event(struct cfg80211_registered_device *rdev, + const struct cfg80211_chan_def *chandef, + enum nl80211_radar_event event) +{ + wiphy_lock(&rdev->wiphy); + __cfg80211_background_cac_event(rdev, rdev->background_radar_wdev, + chandef, event); + wiphy_unlock(&rdev->wiphy); +} + +void cfg80211_background_cac_done_wk(struct work_struct *work) +{ + struct delayed_work *delayed_work = to_delayed_work(work); + struct cfg80211_registered_device *rdev; + + rdev = container_of(delayed_work, struct cfg80211_registered_device, + background_cac_done_wk); + cfg80211_background_cac_event(rdev, &rdev->background_radar_chandef, + NL80211_RADAR_CAC_FINISHED); +} + +void cfg80211_background_cac_abort_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(work, struct cfg80211_registered_device, + background_cac_abort_wk); + cfg80211_background_cac_event(rdev, &rdev->background_radar_chandef, + NL80211_RADAR_CAC_ABORTED); +} + +void cfg80211_background_cac_abort(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + queue_work(cfg80211_wq, &rdev->background_cac_abort_wk); +} +EXPORT_SYMBOL(cfg80211_background_cac_abort); + +int +cfg80211_start_background_radar_detection(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_chan_def *chandef) +{ + unsigned int cac_time_ms; + int err; + + lockdep_assert_wiphy(&rdev->wiphy); + + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_RADAR_BACKGROUND)) + return -EOPNOTSUPP; + + /* Offchannel chain already locked by another wdev */ + if (rdev->background_radar_wdev && rdev->background_radar_wdev != wdev) + return -EBUSY; + + /* CAC already in progress on the offchannel chain */ + if (rdev->background_radar_wdev == wdev && + delayed_work_pending(&rdev->background_cac_done_wk)) + return -EBUSY; + + err = rdev_set_radar_background(rdev, chandef); + if (err) + return err; + + cac_time_ms = cfg80211_chandef_dfs_cac_time(&rdev->wiphy, chandef); + if (!cac_time_ms) + cac_time_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + + rdev->background_radar_chandef = *chandef; + rdev->background_radar_wdev = wdev; /* Get offchain ownership */ + + __cfg80211_background_cac_event(rdev, wdev, chandef, + NL80211_RADAR_CAC_STARTED); + queue_delayed_work(cfg80211_wq, &rdev->background_cac_done_wk, + msecs_to_jiffies(cac_time_ms)); + + return 0; +} + +void cfg80211_stop_background_radar_detection(struct wireless_dev *wdev) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + lockdep_assert_wiphy(wiphy); + + if (wdev != rdev->background_radar_wdev) + return; + + rdev_set_radar_background(rdev, NULL); + rdev->background_radar_wdev = NULL; /* Release offchain ownership */ + + __cfg80211_background_cac_event(rdev, wdev, + &rdev->background_radar_chandef, + NL80211_RADAR_CAC_ABORTED); +} diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c new file mode 100644 index 000000000..70fb14b8b --- /dev/null +++ b/net/wireless/nl80211.c @@ -0,0 +1,19837 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This is the new netlink-based wireless configuration interface. 
+ * + * Copyright 2006-2010 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2015-2017 Intel Deutschland GmbH + * Copyright (C) 2018-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "nl80211.h" +#include "reg.h" +#include "rdev-ops.h" + +static int nl80211_crypto_settings(struct cfg80211_registered_device *rdev, + struct genl_info *info, + struct cfg80211_crypto_settings *settings, + int cipher_limit); + +/* the netlink family */ +static struct genl_family nl80211_fam; + +/* multicast groups */ +enum nl80211_multicast_groups { + NL80211_MCGRP_CONFIG, + NL80211_MCGRP_SCAN, + NL80211_MCGRP_REGULATORY, + NL80211_MCGRP_MLME, + NL80211_MCGRP_VENDOR, + NL80211_MCGRP_NAN, + NL80211_MCGRP_TESTMODE /* keep last - ifdef! */ +}; + +static const struct genl_multicast_group nl80211_mcgrps[] = { + [NL80211_MCGRP_CONFIG] = { .name = NL80211_MULTICAST_GROUP_CONFIG }, + [NL80211_MCGRP_SCAN] = { .name = NL80211_MULTICAST_GROUP_SCAN }, + [NL80211_MCGRP_REGULATORY] = { .name = NL80211_MULTICAST_GROUP_REG }, + [NL80211_MCGRP_MLME] = { .name = NL80211_MULTICAST_GROUP_MLME }, + [NL80211_MCGRP_VENDOR] = { .name = NL80211_MULTICAST_GROUP_VENDOR }, + [NL80211_MCGRP_NAN] = { .name = NL80211_MULTICAST_GROUP_NAN }, +#ifdef CONFIG_NL80211_TESTMODE + [NL80211_MCGRP_TESTMODE] = { .name = NL80211_MULTICAST_GROUP_TESTMODE } +#endif +}; + +/* returns ERR_PTR values */ +static struct wireless_dev * +__cfg80211_wdev_from_attrs(struct cfg80211_registered_device *rdev, + struct net *netns, struct nlattr **attrs) +{ + struct wireless_dev *result = NULL; + bool have_ifidx = attrs[NL80211_ATTR_IFINDEX]; + bool have_wdev_id = attrs[NL80211_ATTR_WDEV]; + u64 wdev_id = 0; + int wiphy_idx = -1; + int ifidx = -1; + + if (!have_ifidx && !have_wdev_id) + return ERR_PTR(-EINVAL); + + if (have_ifidx) + ifidx = nla_get_u32(attrs[NL80211_ATTR_IFINDEX]); + if (have_wdev_id) { + wdev_id = nla_get_u64(attrs[NL80211_ATTR_WDEV]); + wiphy_idx = wdev_id >> 32; + } + + if (rdev) { + struct wireless_dev *wdev; + + lockdep_assert_held(&rdev->wiphy.mtx); + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (have_ifidx && wdev->netdev && + wdev->netdev->ifindex == ifidx) { + result = wdev; + break; + } + if (have_wdev_id && wdev->identifier == (u32)wdev_id) { + result = wdev; + break; + } + } + + return result ?: ERR_PTR(-ENODEV); + } + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + struct wireless_dev *wdev; + + if (wiphy_net(&rdev->wiphy) != netns) + continue; + + if (have_wdev_id && rdev->wiphy_idx != wiphy_idx) + continue; + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (have_ifidx && wdev->netdev && + wdev->netdev->ifindex == ifidx) { + result = wdev; + break; + } + if (have_wdev_id && wdev->identifier == (u32)wdev_id) { + result = wdev; + break; + } + } + + if (result) + break; + } + + if (result) + return result; + return ERR_PTR(-ENODEV); +} + +static struct cfg80211_registered_device * +__cfg80211_rdev_from_attrs(struct net *netns, struct nlattr **attrs) +{ + struct cfg80211_registered_device *rdev = NULL, *tmp; + struct net_device *netdev; + + ASSERT_RTNL(); + + if (!attrs[NL80211_ATTR_WIPHY] && + !attrs[NL80211_ATTR_IFINDEX] && + !attrs[NL80211_ATTR_WDEV]) + return ERR_PTR(-EINVAL); + + if (attrs[NL80211_ATTR_WIPHY]) + rdev = cfg80211_rdev_by_wiphy_idx( + 
nla_get_u32(attrs[NL80211_ATTR_WIPHY])); + + if (attrs[NL80211_ATTR_WDEV]) { + u64 wdev_id = nla_get_u64(attrs[NL80211_ATTR_WDEV]); + struct wireless_dev *wdev; + bool found = false; + + tmp = cfg80211_rdev_by_wiphy_idx(wdev_id >> 32); + if (tmp) { + /* make sure wdev exists */ + list_for_each_entry(wdev, &tmp->wiphy.wdev_list, list) { + if (wdev->identifier != (u32)wdev_id) + continue; + found = true; + break; + } + + if (!found) + tmp = NULL; + + if (rdev && tmp != rdev) + return ERR_PTR(-EINVAL); + rdev = tmp; + } + } + + if (attrs[NL80211_ATTR_IFINDEX]) { + int ifindex = nla_get_u32(attrs[NL80211_ATTR_IFINDEX]); + + netdev = __dev_get_by_index(netns, ifindex); + if (netdev) { + if (netdev->ieee80211_ptr) + tmp = wiphy_to_rdev( + netdev->ieee80211_ptr->wiphy); + else + tmp = NULL; + + /* not wireless device -- return error */ + if (!tmp) + return ERR_PTR(-EINVAL); + + /* mismatch -- return error */ + if (rdev && tmp != rdev) + return ERR_PTR(-EINVAL); + + rdev = tmp; + } + } + + if (!rdev) + return ERR_PTR(-ENODEV); + + if (netns != wiphy_net(&rdev->wiphy)) + return ERR_PTR(-ENODEV); + + return rdev; +} + +/* + * This function returns a pointer to the driver + * that the genl_info item that is passed refers to. + * + * The result of this can be a PTR_ERR and hence must + * be checked with IS_ERR() for errors. + */ +static struct cfg80211_registered_device * +cfg80211_get_dev_from_info(struct net *netns, struct genl_info *info) +{ + return __cfg80211_rdev_from_attrs(netns, info->attrs); +} + +static int validate_beacon_head(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const u8 *data = nla_data(attr); + unsigned int len = nla_len(attr); + const struct element *elem; + const struct ieee80211_mgmt *mgmt = (void *)data; + unsigned int fixedlen, hdrlen; + bool s1g_bcn; + + if (len < offsetofend(typeof(*mgmt), frame_control)) + goto err; + + s1g_bcn = ieee80211_is_s1g_beacon(mgmt->frame_control); + if (s1g_bcn) { + fixedlen = offsetof(struct ieee80211_ext, + u.s1g_beacon.variable); + hdrlen = offsetof(struct ieee80211_ext, u.s1g_beacon); + } else { + fixedlen = offsetof(struct ieee80211_mgmt, + u.beacon.variable); + hdrlen = offsetof(struct ieee80211_mgmt, u.beacon); + } + + if (len < fixedlen) + goto err; + + if (ieee80211_hdrlen(mgmt->frame_control) != hdrlen) + goto err; + + data += fixedlen; + len -= fixedlen; + + for_each_element(elem, data, len) { + /* nothing */ + } + + if (for_each_element_completed(elem, data, len)) + return 0; + +err: + NL_SET_ERR_MSG_ATTR(extack, attr, "malformed beacon head"); + return -EINVAL; +} + +static int validate_ie_attr(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + const u8 *data = nla_data(attr); + unsigned int len = nla_len(attr); + const struct element *elem; + + for_each_element(elem, data, len) { + /* nothing */ + } + + if (for_each_element_completed(elem, data, len)) + return 0; + + NL_SET_ERR_MSG_ATTR(extack, attr, "malformed information elements"); + return -EINVAL; +} + +static int validate_he_capa(const struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + if (!ieee80211_he_capa_size_ok(nla_data(attr), nla_len(attr))) + return -EINVAL; + + return 0; +} + +/* policy for the attributes */ +static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR]; + +static const struct nla_policy +nl80211_ftm_responder_policy[NL80211_FTM_RESP_ATTR_MAX + 1] = { + [NL80211_FTM_RESP_ATTR_ENABLED] = { .type = NLA_FLAG, }, + [NL80211_FTM_RESP_ATTR_LCI] = { .type = NLA_BINARY, + .len = U8_MAX }, + 
[NL80211_FTM_RESP_ATTR_CIVICLOC] = { .type = NLA_BINARY, + .len = U8_MAX }, +}; + +static const struct nla_policy +nl80211_pmsr_ftm_req_attr_policy[NL80211_PMSR_FTM_REQ_ATTR_MAX + 1] = { + [NL80211_PMSR_FTM_REQ_ATTR_ASAP] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE] = { .type = NLA_U32 }, + [NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP] = + NLA_POLICY_MAX(NLA_U8, 15), + [NL80211_PMSR_FTM_REQ_ATTR_BURST_PERIOD] = { .type = NLA_U16 }, + [NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION] = + NLA_POLICY_MAX(NLA_U8, 15), + [NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST] = + NLA_POLICY_MAX(NLA_U8, 31), + [NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES] = { .type = NLA_U8 }, + [NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_TRIGGER_BASED] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_NON_TRIGGER_BASED] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_LMR_FEEDBACK] = { .type = NLA_FLAG }, + [NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR] = { .type = NLA_U8 }, +}; + +static const struct nla_policy +nl80211_pmsr_req_data_policy[NL80211_PMSR_TYPE_MAX + 1] = { + [NL80211_PMSR_TYPE_FTM] = + NLA_POLICY_NESTED(nl80211_pmsr_ftm_req_attr_policy), +}; + +static const struct nla_policy +nl80211_pmsr_req_attr_policy[NL80211_PMSR_REQ_ATTR_MAX + 1] = { + [NL80211_PMSR_REQ_ATTR_DATA] = + NLA_POLICY_NESTED(nl80211_pmsr_req_data_policy), + [NL80211_PMSR_REQ_ATTR_GET_AP_TSF] = { .type = NLA_FLAG }, +}; + +static const struct nla_policy +nl80211_pmsr_peer_attr_policy[NL80211_PMSR_PEER_ATTR_MAX + 1] = { + [NL80211_PMSR_PEER_ATTR_ADDR] = NLA_POLICY_ETH_ADDR, + [NL80211_PMSR_PEER_ATTR_CHAN] = NLA_POLICY_NESTED(nl80211_policy), + [NL80211_PMSR_PEER_ATTR_REQ] = + NLA_POLICY_NESTED(nl80211_pmsr_req_attr_policy), + [NL80211_PMSR_PEER_ATTR_RESP] = { .type = NLA_REJECT }, +}; + +static const struct nla_policy +nl80211_pmsr_attr_policy[NL80211_PMSR_ATTR_MAX + 1] = { + [NL80211_PMSR_ATTR_MAX_PEERS] = { .type = NLA_REJECT }, + [NL80211_PMSR_ATTR_REPORT_AP_TSF] = { .type = NLA_REJECT }, + [NL80211_PMSR_ATTR_RANDOMIZE_MAC_ADDR] = { .type = NLA_REJECT }, + [NL80211_PMSR_ATTR_TYPE_CAPA] = { .type = NLA_REJECT }, + [NL80211_PMSR_ATTR_PEERS] = + NLA_POLICY_NESTED_ARRAY(nl80211_pmsr_peer_attr_policy), +}; + +static const struct nla_policy +he_obss_pd_policy[NL80211_HE_OBSS_PD_ATTR_MAX + 1] = { + [NL80211_HE_OBSS_PD_ATTR_MIN_OFFSET] = + NLA_POLICY_RANGE(NLA_U8, 1, 20), + [NL80211_HE_OBSS_PD_ATTR_MAX_OFFSET] = + NLA_POLICY_RANGE(NLA_U8, 1, 20), + [NL80211_HE_OBSS_PD_ATTR_NON_SRG_MAX_OFFSET] = + NLA_POLICY_RANGE(NLA_U8, 1, 20), + [NL80211_HE_OBSS_PD_ATTR_BSS_COLOR_BITMAP] = + NLA_POLICY_EXACT_LEN(8), + [NL80211_HE_OBSS_PD_ATTR_PARTIAL_BSSID_BITMAP] = + NLA_POLICY_EXACT_LEN(8), + [NL80211_HE_OBSS_PD_ATTR_SR_CTRL] = { .type = NLA_U8 }, +}; + +static const struct nla_policy +he_bss_color_policy[NL80211_HE_BSS_COLOR_ATTR_MAX + 1] = { + [NL80211_HE_BSS_COLOR_ATTR_COLOR] = NLA_POLICY_RANGE(NLA_U8, 1, 63), + [NL80211_HE_BSS_COLOR_ATTR_DISABLED] = { .type = NLA_FLAG }, + [NL80211_HE_BSS_COLOR_ATTR_PARTIAL] = { .type = NLA_FLAG }, +}; + +static const struct nla_policy nl80211_txattr_policy[NL80211_TXRATE_MAX + 1] = { + [NL80211_TXRATE_LEGACY] = { .type = NLA_BINARY, + .len = NL80211_MAX_SUPP_RATES }, + [NL80211_TXRATE_HT] = { .type = NLA_BINARY, + .len = NL80211_MAX_SUPP_HT_RATES }, + [NL80211_TXRATE_VHT] = NLA_POLICY_EXACT_LEN_WARN(sizeof(struct nl80211_txrate_vht)), + [NL80211_TXRATE_GI] = { .type = NLA_U8 }, + [NL80211_TXRATE_HE] = 
NLA_POLICY_EXACT_LEN(sizeof(struct nl80211_txrate_he)), + [NL80211_TXRATE_HE_GI] = NLA_POLICY_RANGE(NLA_U8, + NL80211_RATE_INFO_HE_GI_0_8, + NL80211_RATE_INFO_HE_GI_3_2), + [NL80211_TXRATE_HE_LTF] = NLA_POLICY_RANGE(NLA_U8, + NL80211_RATE_INFO_HE_1XLTF, + NL80211_RATE_INFO_HE_4XLTF), +}; + +static const struct nla_policy +nl80211_tid_config_attr_policy[NL80211_TID_CONFIG_ATTR_MAX + 1] = { + [NL80211_TID_CONFIG_ATTR_VIF_SUPP] = { .type = NLA_U64 }, + [NL80211_TID_CONFIG_ATTR_PEER_SUPP] = { .type = NLA_U64 }, + [NL80211_TID_CONFIG_ATTR_OVERRIDE] = { .type = NLA_FLAG }, + [NL80211_TID_CONFIG_ATTR_TIDS] = NLA_POLICY_RANGE(NLA_U16, 1, 0xff), + [NL80211_TID_CONFIG_ATTR_NOACK] = + NLA_POLICY_MAX(NLA_U8, NL80211_TID_CONFIG_DISABLE), + [NL80211_TID_CONFIG_ATTR_RETRY_SHORT] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_TID_CONFIG_ATTR_RETRY_LONG] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_TID_CONFIG_ATTR_AMPDU_CTRL] = + NLA_POLICY_MAX(NLA_U8, NL80211_TID_CONFIG_DISABLE), + [NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL] = + NLA_POLICY_MAX(NLA_U8, NL80211_TID_CONFIG_DISABLE), + [NL80211_TID_CONFIG_ATTR_AMSDU_CTRL] = + NLA_POLICY_MAX(NLA_U8, NL80211_TID_CONFIG_DISABLE), + [NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE] = + NLA_POLICY_MAX(NLA_U8, NL80211_TX_RATE_FIXED), + [NL80211_TID_CONFIG_ATTR_TX_RATE] = + NLA_POLICY_NESTED(nl80211_txattr_policy), +}; + +static const struct nla_policy +nl80211_fils_discovery_policy[NL80211_FILS_DISCOVERY_ATTR_MAX + 1] = { + [NL80211_FILS_DISCOVERY_ATTR_INT_MIN] = NLA_POLICY_MAX(NLA_U32, 10000), + [NL80211_FILS_DISCOVERY_ATTR_INT_MAX] = NLA_POLICY_MAX(NLA_U32, 10000), + [NL80211_FILS_DISCOVERY_ATTR_TMPL] = + NLA_POLICY_RANGE(NLA_BINARY, + NL80211_FILS_DISCOVERY_TMPL_MIN_LEN, + IEEE80211_MAX_DATA_LEN), +}; + +static const struct nla_policy +nl80211_unsol_bcast_probe_resp_policy[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_MAX + 1] = { + [NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_INT] = NLA_POLICY_MAX(NLA_U32, 20), + [NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_DATA_LEN } +}; + +static const struct nla_policy +sar_specs_policy[NL80211_SAR_ATTR_SPECS_MAX + 1] = { + [NL80211_SAR_ATTR_SPECS_POWER] = { .type = NLA_S32 }, + [NL80211_SAR_ATTR_SPECS_RANGE_INDEX] = {.type = NLA_U32 }, +}; + +static const struct nla_policy +sar_policy[NL80211_SAR_ATTR_MAX + 1] = { + [NL80211_SAR_ATTR_TYPE] = NLA_POLICY_MAX(NLA_U32, NUM_NL80211_SAR_TYPE), + [NL80211_SAR_ATTR_SPECS] = NLA_POLICY_NESTED_ARRAY(sar_specs_policy), +}; + +static const struct nla_policy +nl80211_mbssid_config_policy[NL80211_MBSSID_CONFIG_ATTR_MAX + 1] = { + [NL80211_MBSSID_CONFIG_ATTR_MAX_INTERFACES] = NLA_POLICY_MIN(NLA_U8, 2), + [NL80211_MBSSID_CONFIG_ATTR_MAX_EMA_PROFILE_PERIODICITY] = + NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_MBSSID_CONFIG_ATTR_INDEX] = { .type = NLA_U8 }, + [NL80211_MBSSID_CONFIG_ATTR_TX_IFINDEX] = { .type = NLA_U32 }, + [NL80211_MBSSID_CONFIG_ATTR_EMA] = { .type = NLA_FLAG }, +}; + +static const struct nla_policy +nl80211_sta_wme_policy[NL80211_STA_WME_MAX + 1] = { + [NL80211_STA_WME_UAPSD_QUEUES] = { .type = NLA_U8 }, + [NL80211_STA_WME_MAX_SP] = { .type = NLA_U8 }, +}; + +static const struct nla_policy nl80211_policy[NUM_NL80211_ATTR] = { + [0] = { .strict_start_type = NL80211_ATTR_HE_OBSS_PD }, + [NL80211_ATTR_WIPHY] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_NAME] = { .type = NLA_NUL_STRING, + .len = 20-1 }, + [NL80211_ATTR_WIPHY_TXQ_PARAMS] = { .type = NLA_NESTED }, + + [NL80211_ATTR_WIPHY_FREQ] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_CHANNEL_TYPE] = { .type = NLA_U32 }, + 
[NL80211_ATTR_WIPHY_EDMG_CHANNELS] = NLA_POLICY_RANGE(NLA_U8, + NL80211_EDMG_CHANNELS_MIN, + NL80211_EDMG_CHANNELS_MAX), + [NL80211_ATTR_WIPHY_EDMG_BW_CONFIG] = NLA_POLICY_RANGE(NLA_U8, + NL80211_EDMG_BW_CONFIG_MIN, + NL80211_EDMG_BW_CONFIG_MAX), + + [NL80211_ATTR_CHANNEL_WIDTH] = { .type = NLA_U32 }, + [NL80211_ATTR_CENTER_FREQ1] = { .type = NLA_U32 }, + [NL80211_ATTR_CENTER_FREQ1_OFFSET] = NLA_POLICY_RANGE(NLA_U32, 0, 999), + [NL80211_ATTR_CENTER_FREQ2] = { .type = NLA_U32 }, + + [NL80211_ATTR_WIPHY_RETRY_SHORT] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_ATTR_WIPHY_RETRY_LONG] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_ATTR_WIPHY_FRAG_THRESHOLD] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_RTS_THRESHOLD] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_COVERAGE_CLASS] = { .type = NLA_U8 }, + [NL80211_ATTR_WIPHY_DYN_ACK] = { .type = NLA_FLAG }, + + [NL80211_ATTR_IFTYPE] = NLA_POLICY_MAX(NLA_U32, NL80211_IFTYPE_MAX), + [NL80211_ATTR_IFINDEX] = { .type = NLA_U32 }, + [NL80211_ATTR_IFNAME] = { .type = NLA_NUL_STRING, .len = IFNAMSIZ-1 }, + + [NL80211_ATTR_MAC] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_ATTR_PREV_BSSID] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + + [NL80211_ATTR_KEY] = { .type = NLA_NESTED, }, + [NL80211_ATTR_KEY_DATA] = { .type = NLA_BINARY, + .len = WLAN_MAX_KEY_LEN }, + [NL80211_ATTR_KEY_IDX] = NLA_POLICY_MAX(NLA_U8, 7), + [NL80211_ATTR_KEY_CIPHER] = { .type = NLA_U32 }, + [NL80211_ATTR_KEY_DEFAULT] = { .type = NLA_FLAG }, + [NL80211_ATTR_KEY_SEQ] = { .type = NLA_BINARY, .len = 16 }, + [NL80211_ATTR_KEY_TYPE] = + NLA_POLICY_MAX(NLA_U32, NUM_NL80211_KEYTYPES), + + [NL80211_ATTR_BEACON_INTERVAL] = { .type = NLA_U32 }, + [NL80211_ATTR_DTIM_PERIOD] = { .type = NLA_U32 }, + [NL80211_ATTR_BEACON_HEAD] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_beacon_head, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_BEACON_TAIL] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_STA_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), + [NL80211_ATTR_STA_FLAGS] = { .type = NLA_NESTED }, + [NL80211_ATTR_STA_LISTEN_INTERVAL] = { .type = NLA_U16 }, + [NL80211_ATTR_STA_SUPPORTED_RATES] = { .type = NLA_BINARY, + .len = NL80211_MAX_SUPP_RATES }, + [NL80211_ATTR_STA_PLINK_ACTION] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_PLINK_ACTIONS - 1), + [NL80211_ATTR_STA_TX_POWER_SETTING] = + NLA_POLICY_RANGE(NLA_U8, + NL80211_TX_POWER_AUTOMATIC, + NL80211_TX_POWER_FIXED), + [NL80211_ATTR_STA_TX_POWER] = { .type = NLA_S16 }, + [NL80211_ATTR_STA_VLAN] = { .type = NLA_U32 }, + [NL80211_ATTR_MNTR_FLAGS] = { /* NLA_NESTED can't be empty */ }, + [NL80211_ATTR_MESH_ID] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_MESH_ID_LEN }, + [NL80211_ATTR_MPATH_NEXT_HOP] = NLA_POLICY_ETH_ADDR_COMPAT, + + /* allow 3 for NUL-termination, we used to declare this NLA_STRING */ + [NL80211_ATTR_REG_ALPHA2] = NLA_POLICY_RANGE(NLA_BINARY, 2, 3), + [NL80211_ATTR_REG_RULES] = { .type = NLA_NESTED }, + + [NL80211_ATTR_BSS_CTS_PROT] = { .type = NLA_U8 }, + [NL80211_ATTR_BSS_SHORT_PREAMBLE] = { .type = NLA_U8 }, + [NL80211_ATTR_BSS_SHORT_SLOT_TIME] = { .type = NLA_U8 }, + [NL80211_ATTR_BSS_BASIC_RATES] = { .type = NLA_BINARY, + .len = NL80211_MAX_SUPP_RATES }, + [NL80211_ATTR_BSS_HT_OPMODE] = { .type = NLA_U16 }, + + [NL80211_ATTR_MESH_CONFIG] = { .type = NLA_NESTED }, + [NL80211_ATTR_SUPPORT_MESH_AUTH] = { .type = NLA_FLAG }, + + [NL80211_ATTR_HT_CAPABILITY] = NLA_POLICY_EXACT_LEN_WARN(NL80211_HT_CAPABILITY_LEN), + + [NL80211_ATTR_MGMT_SUBTYPE] = { .type = NLA_U8 }, + [NL80211_ATTR_IE] = 
NLA_POLICY_VALIDATE_FN(NLA_BINARY, + validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_SCAN_FREQUENCIES] = { .type = NLA_NESTED }, + [NL80211_ATTR_SCAN_SSIDS] = { .type = NLA_NESTED }, + + [NL80211_ATTR_SSID] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_SSID_LEN }, + [NL80211_ATTR_AUTH_TYPE] = { .type = NLA_U32 }, + [NL80211_ATTR_REASON_CODE] = { .type = NLA_U16 }, + [NL80211_ATTR_FREQ_FIXED] = { .type = NLA_FLAG }, + [NL80211_ATTR_TIMED_OUT] = { .type = NLA_FLAG }, + [NL80211_ATTR_USE_MFP] = NLA_POLICY_RANGE(NLA_U32, + NL80211_MFP_NO, + NL80211_MFP_OPTIONAL), + [NL80211_ATTR_STA_FLAGS2] = + NLA_POLICY_EXACT_LEN_WARN(sizeof(struct nl80211_sta_flag_update)), + [NL80211_ATTR_CONTROL_PORT] = { .type = NLA_FLAG }, + [NL80211_ATTR_CONTROL_PORT_ETHERTYPE] = { .type = NLA_U16 }, + [NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT] = { .type = NLA_FLAG }, + [NL80211_ATTR_CONTROL_PORT_OVER_NL80211] = { .type = NLA_FLAG }, + [NL80211_ATTR_PRIVACY] = { .type = NLA_FLAG }, + [NL80211_ATTR_STATUS_CODE] = { .type = NLA_U16 }, + [NL80211_ATTR_CIPHER_SUITE_GROUP] = { .type = NLA_U32 }, + [NL80211_ATTR_WPA_VERSIONS] = { .type = NLA_U32 }, + [NL80211_ATTR_PID] = { .type = NLA_U32 }, + [NL80211_ATTR_4ADDR] = { .type = NLA_U8 }, + [NL80211_ATTR_PMKID] = NLA_POLICY_EXACT_LEN_WARN(WLAN_PMKID_LEN), + [NL80211_ATTR_DURATION] = { .type = NLA_U32 }, + [NL80211_ATTR_COOKIE] = { .type = NLA_U64 }, + [NL80211_ATTR_TX_RATES] = { .type = NLA_NESTED }, + [NL80211_ATTR_FRAME] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_ATTR_FRAME_MATCH] = { .type = NLA_BINARY, }, + [NL80211_ATTR_PS_STATE] = NLA_POLICY_RANGE(NLA_U32, + NL80211_PS_DISABLED, + NL80211_PS_ENABLED), + [NL80211_ATTR_CQM] = { .type = NLA_NESTED, }, + [NL80211_ATTR_LOCAL_STATE_CHANGE] = { .type = NLA_FLAG }, + [NL80211_ATTR_AP_ISOLATE] = { .type = NLA_U8 }, + [NL80211_ATTR_WIPHY_TX_POWER_SETTING] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_TX_POWER_LEVEL] = { .type = NLA_U32 }, + [NL80211_ATTR_FRAME_TYPE] = { .type = NLA_U16 }, + [NL80211_ATTR_WIPHY_ANTENNA_TX] = { .type = NLA_U32 }, + [NL80211_ATTR_WIPHY_ANTENNA_RX] = { .type = NLA_U32 }, + [NL80211_ATTR_MCAST_RATE] = { .type = NLA_U32 }, + [NL80211_ATTR_OFFCHANNEL_TX_OK] = { .type = NLA_FLAG }, + [NL80211_ATTR_KEY_DEFAULT_TYPES] = { .type = NLA_NESTED }, + [NL80211_ATTR_WOWLAN_TRIGGERS] = { .type = NLA_NESTED }, + [NL80211_ATTR_STA_PLINK_STATE] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_PLINK_STATES - 1), + [NL80211_ATTR_MEASUREMENT_DURATION] = { .type = NLA_U16 }, + [NL80211_ATTR_MEASUREMENT_DURATION_MANDATORY] = { .type = NLA_FLAG }, + [NL80211_ATTR_MESH_PEER_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), + [NL80211_ATTR_SCHED_SCAN_INTERVAL] = { .type = NLA_U32 }, + [NL80211_ATTR_REKEY_DATA] = { .type = NLA_NESTED }, + [NL80211_ATTR_SCAN_SUPP_RATES] = { .type = NLA_NESTED }, + [NL80211_ATTR_HIDDEN_SSID] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_HIDDEN_SSID_NOT_IN_USE, + NL80211_HIDDEN_SSID_ZERO_CONTENTS), + [NL80211_ATTR_IE_PROBE_RESP] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_IE_ASSOC_RESP] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_ATTR_ROAM_SUPPORT] = { .type = NLA_FLAG }, + [NL80211_ATTR_STA_WME] = NLA_POLICY_NESTED(nl80211_sta_wme_policy), + [NL80211_ATTR_SCHED_SCAN_MATCH] = { .type = NLA_NESTED }, + [NL80211_ATTR_TX_NO_CCK_RATE] = { .type = NLA_FLAG }, + [NL80211_ATTR_TDLS_ACTION] = { .type = NLA_U8 }, + [NL80211_ATTR_TDLS_DIALOG_TOKEN] = { .type = NLA_U8 }, + 
[NL80211_ATTR_TDLS_OPERATION] = { .type = NLA_U8 }, + [NL80211_ATTR_TDLS_SUPPORT] = { .type = NLA_FLAG }, + [NL80211_ATTR_TDLS_EXTERNAL_SETUP] = { .type = NLA_FLAG }, + [NL80211_ATTR_TDLS_INITIATOR] = { .type = NLA_FLAG }, + [NL80211_ATTR_DONT_WAIT_FOR_ACK] = { .type = NLA_FLAG }, + [NL80211_ATTR_PROBE_RESP] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_ATTR_DFS_REGION] = { .type = NLA_U8 }, + [NL80211_ATTR_DISABLE_HT] = { .type = NLA_FLAG }, + [NL80211_ATTR_HT_CAPABILITY_MASK] = { + .len = NL80211_HT_CAPABILITY_LEN + }, + [NL80211_ATTR_NOACK_MAP] = { .type = NLA_U16 }, + [NL80211_ATTR_INACTIVITY_TIMEOUT] = { .type = NLA_U16 }, + [NL80211_ATTR_BG_SCAN_PERIOD] = { .type = NLA_U16 }, + [NL80211_ATTR_WDEV] = { .type = NLA_U64 }, + [NL80211_ATTR_USER_REG_HINT_TYPE] = { .type = NLA_U32 }, + + /* need to include at least Auth Transaction and Status Code */ + [NL80211_ATTR_AUTH_DATA] = NLA_POLICY_MIN_LEN(4), + + [NL80211_ATTR_VHT_CAPABILITY] = NLA_POLICY_EXACT_LEN_WARN(NL80211_VHT_CAPABILITY_LEN), + [NL80211_ATTR_SCAN_FLAGS] = { .type = NLA_U32 }, + [NL80211_ATTR_P2P_CTWINDOW] = NLA_POLICY_MAX(NLA_U8, 127), + [NL80211_ATTR_P2P_OPPPS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_ATTR_LOCAL_MESH_POWER_MODE] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_MESH_POWER_UNKNOWN + 1, + NL80211_MESH_POWER_MAX), + [NL80211_ATTR_ACL_POLICY] = {. type = NLA_U32 }, + [NL80211_ATTR_MAC_ADDRS] = { .type = NLA_NESTED }, + [NL80211_ATTR_STA_CAPABILITY] = { .type = NLA_U16 }, + [NL80211_ATTR_STA_EXT_CAPABILITY] = { .type = NLA_BINARY, }, + [NL80211_ATTR_SPLIT_WIPHY_DUMP] = { .type = NLA_FLAG, }, + [NL80211_ATTR_DISABLE_VHT] = { .type = NLA_FLAG }, + [NL80211_ATTR_VHT_CAPABILITY_MASK] = { + .len = NL80211_VHT_CAPABILITY_LEN, + }, + [NL80211_ATTR_MDID] = { .type = NLA_U16 }, + [NL80211_ATTR_IE_RIC] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_DATA_LEN }, + [NL80211_ATTR_CRIT_PROT_ID] = { .type = NLA_U16 }, + [NL80211_ATTR_MAX_CRIT_PROT_DURATION] = + NLA_POLICY_MAX(NLA_U16, NL80211_CRIT_PROTO_MAX_DURATION), + [NL80211_ATTR_PEER_AID] = + NLA_POLICY_RANGE(NLA_U16, 1, IEEE80211_MAX_AID), + [NL80211_ATTR_CH_SWITCH_COUNT] = { .type = NLA_U32 }, + [NL80211_ATTR_CH_SWITCH_BLOCK_TX] = { .type = NLA_FLAG }, + [NL80211_ATTR_CSA_IES] = { .type = NLA_NESTED }, + [NL80211_ATTR_CNTDWN_OFFS_BEACON] = { .type = NLA_BINARY }, + [NL80211_ATTR_CNTDWN_OFFS_PRESP] = { .type = NLA_BINARY }, + [NL80211_ATTR_STA_SUPPORTED_CHANNELS] = NLA_POLICY_MIN_LEN(2), + /* + * The value of the Length field of the Supported Operating + * Classes element is between 2 and 253. 
+ */ + [NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES] = + NLA_POLICY_RANGE(NLA_BINARY, 2, 253), + [NL80211_ATTR_HANDLE_DFS] = { .type = NLA_FLAG }, + [NL80211_ATTR_OPMODE_NOTIF] = { .type = NLA_U8 }, + [NL80211_ATTR_VENDOR_ID] = { .type = NLA_U32 }, + [NL80211_ATTR_VENDOR_SUBCMD] = { .type = NLA_U32 }, + [NL80211_ATTR_VENDOR_DATA] = { .type = NLA_BINARY }, + [NL80211_ATTR_QOS_MAP] = NLA_POLICY_RANGE(NLA_BINARY, + IEEE80211_QOS_MAP_LEN_MIN, + IEEE80211_QOS_MAP_LEN_MAX), + [NL80211_ATTR_MAC_HINT] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_ATTR_WIPHY_FREQ_HINT] = { .type = NLA_U32 }, + [NL80211_ATTR_TDLS_PEER_CAPABILITY] = { .type = NLA_U32 }, + [NL80211_ATTR_SOCKET_OWNER] = { .type = NLA_FLAG }, + [NL80211_ATTR_CSA_C_OFFSETS_TX] = { .type = NLA_BINARY }, + [NL80211_ATTR_USE_RRM] = { .type = NLA_FLAG }, + [NL80211_ATTR_TSID] = NLA_POLICY_MAX(NLA_U8, IEEE80211_NUM_TIDS - 1), + [NL80211_ATTR_USER_PRIO] = + NLA_POLICY_MAX(NLA_U8, IEEE80211_NUM_UPS - 1), + [NL80211_ATTR_ADMITTED_TIME] = { .type = NLA_U16 }, + [NL80211_ATTR_SMPS_MODE] = { .type = NLA_U8 }, + [NL80211_ATTR_OPER_CLASS] = { .type = NLA_U8 }, + [NL80211_ATTR_MAC_MASK] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_ATTR_WIPHY_SELF_MANAGED_REG] = { .type = NLA_FLAG }, + [NL80211_ATTR_NETNS_FD] = { .type = NLA_U32 }, + [NL80211_ATTR_SCHED_SCAN_DELAY] = { .type = NLA_U32 }, + [NL80211_ATTR_REG_INDOOR] = { .type = NLA_FLAG }, + [NL80211_ATTR_PBSS] = { .type = NLA_FLAG }, + [NL80211_ATTR_BSS_SELECT] = { .type = NLA_NESTED }, + [NL80211_ATTR_STA_SUPPORT_P2P_PS] = + NLA_POLICY_MAX(NLA_U8, NUM_NL80211_P2P_PS_STATUS - 1), + [NL80211_ATTR_MU_MIMO_GROUP_DATA] = { + .len = VHT_MUMIMO_GROUPS_DATA_LEN + }, + [NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_ATTR_NAN_MASTER_PREF] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_ATTR_BANDS] = { .type = NLA_U32 }, + [NL80211_ATTR_NAN_FUNC] = { .type = NLA_NESTED }, + [NL80211_ATTR_FILS_KEK] = { .type = NLA_BINARY, + .len = FILS_MAX_KEK_LEN }, + [NL80211_ATTR_FILS_NONCES] = NLA_POLICY_EXACT_LEN_WARN(2 * FILS_NONCE_LEN), + [NL80211_ATTR_MULTICAST_TO_UNICAST_ENABLED] = { .type = NLA_FLAG, }, + [NL80211_ATTR_BSSID] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI] = { .type = NLA_S8 }, + [NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST] = { + .len = sizeof(struct nl80211_bss_select_rssi_adjust) + }, + [NL80211_ATTR_TIMEOUT_REASON] = { .type = NLA_U32 }, + [NL80211_ATTR_FILS_ERP_USERNAME] = { .type = NLA_BINARY, + .len = FILS_ERP_MAX_USERNAME_LEN }, + [NL80211_ATTR_FILS_ERP_REALM] = { .type = NLA_BINARY, + .len = FILS_ERP_MAX_REALM_LEN }, + [NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM] = { .type = NLA_U16 }, + [NL80211_ATTR_FILS_ERP_RRK] = { .type = NLA_BINARY, + .len = FILS_ERP_MAX_RRK_LEN }, + [NL80211_ATTR_FILS_CACHE_ID] = NLA_POLICY_EXACT_LEN_WARN(2), + [NL80211_ATTR_PMK] = { .type = NLA_BINARY, .len = PMK_MAX_LEN }, + [NL80211_ATTR_PMKR0_NAME] = NLA_POLICY_EXACT_LEN(WLAN_PMK_NAME_LEN), + [NL80211_ATTR_SCHED_SCAN_MULTI] = { .type = NLA_FLAG }, + [NL80211_ATTR_EXTERNAL_AUTH_SUPPORT] = { .type = NLA_FLAG }, + + [NL80211_ATTR_TXQ_LIMIT] = { .type = NLA_U32 }, + [NL80211_ATTR_TXQ_MEMORY_LIMIT] = { .type = NLA_U32 }, + [NL80211_ATTR_TXQ_QUANTUM] = { .type = NLA_U32 }, + [NL80211_ATTR_HE_CAPABILITY] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_he_capa, + NL80211_HE_MAX_CAPABILITY_LEN), + [NL80211_ATTR_FTM_RESPONDER] = + NLA_POLICY_NESTED(nl80211_ftm_responder_policy), + [NL80211_ATTR_TIMEOUT] = NLA_POLICY_MIN(NLA_U32, 1), + [NL80211_ATTR_PEER_MEASUREMENTS] = + 
NLA_POLICY_NESTED(nl80211_pmsr_attr_policy), + [NL80211_ATTR_AIRTIME_WEIGHT] = NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_ATTR_SAE_PASSWORD] = { .type = NLA_BINARY, + .len = SAE_PASSWORD_MAX_LEN }, + [NL80211_ATTR_TWT_RESPONDER] = { .type = NLA_FLAG }, + [NL80211_ATTR_HE_OBSS_PD] = NLA_POLICY_NESTED(he_obss_pd_policy), + [NL80211_ATTR_VLAN_ID] = NLA_POLICY_RANGE(NLA_U16, 1, VLAN_N_VID - 2), + [NL80211_ATTR_HE_BSS_COLOR] = NLA_POLICY_NESTED(he_bss_color_policy), + [NL80211_ATTR_TID_CONFIG] = + NLA_POLICY_NESTED_ARRAY(nl80211_tid_config_attr_policy), + [NL80211_ATTR_CONTROL_PORT_NO_PREAUTH] = { .type = NLA_FLAG }, + [NL80211_ATTR_PMK_LIFETIME] = NLA_POLICY_MIN(NLA_U32, 1), + [NL80211_ATTR_PMK_REAUTH_THRESHOLD] = NLA_POLICY_RANGE(NLA_U8, 1, 100), + [NL80211_ATTR_RECEIVE_MULTICAST] = { .type = NLA_FLAG }, + [NL80211_ATTR_WIPHY_FREQ_OFFSET] = NLA_POLICY_RANGE(NLA_U32, 0, 999), + [NL80211_ATTR_SCAN_FREQ_KHZ] = { .type = NLA_NESTED }, + [NL80211_ATTR_HE_6GHZ_CAPABILITY] = + NLA_POLICY_EXACT_LEN(sizeof(struct ieee80211_he_6ghz_capa)), + [NL80211_ATTR_FILS_DISCOVERY] = + NLA_POLICY_NESTED(nl80211_fils_discovery_policy), + [NL80211_ATTR_UNSOL_BCAST_PROBE_RESP] = + NLA_POLICY_NESTED(nl80211_unsol_bcast_probe_resp_policy), + [NL80211_ATTR_S1G_CAPABILITY] = + NLA_POLICY_EXACT_LEN(IEEE80211_S1G_CAPABILITY_LEN), + [NL80211_ATTR_S1G_CAPABILITY_MASK] = + NLA_POLICY_EXACT_LEN(IEEE80211_S1G_CAPABILITY_LEN), + [NL80211_ATTR_SAE_PWE] = + NLA_POLICY_RANGE(NLA_U8, NL80211_SAE_PWE_HUNT_AND_PECK, + NL80211_SAE_PWE_BOTH), + [NL80211_ATTR_RECONNECT_REQUESTED] = { .type = NLA_REJECT }, + [NL80211_ATTR_SAR_SPEC] = NLA_POLICY_NESTED(sar_policy), + [NL80211_ATTR_DISABLE_HE] = { .type = NLA_FLAG }, + [NL80211_ATTR_OBSS_COLOR_BITMAP] = { .type = NLA_U64 }, + [NL80211_ATTR_COLOR_CHANGE_COUNT] = { .type = NLA_U8 }, + [NL80211_ATTR_COLOR_CHANGE_COLOR] = { .type = NLA_U8 }, + [NL80211_ATTR_COLOR_CHANGE_ELEMS] = NLA_POLICY_NESTED(nl80211_policy), + [NL80211_ATTR_MBSSID_CONFIG] = + NLA_POLICY_NESTED(nl80211_mbssid_config_policy), + [NL80211_ATTR_MBSSID_ELEMS] = { .type = NLA_NESTED }, + [NL80211_ATTR_RADAR_BACKGROUND] = { .type = NLA_FLAG }, + [NL80211_ATTR_AP_SETTINGS_FLAGS] = { .type = NLA_U32 }, + [NL80211_ATTR_EHT_CAPABILITY] = + NLA_POLICY_RANGE(NLA_BINARY, + NL80211_EHT_MIN_CAPABILITY_LEN, + NL80211_EHT_MAX_CAPABILITY_LEN), + [NL80211_ATTR_DISABLE_EHT] = { .type = NLA_FLAG }, + [NL80211_ATTR_MLO_LINKS] = + NLA_POLICY_NESTED_ARRAY(nl80211_policy), + [NL80211_ATTR_MLO_LINK_ID] = + NLA_POLICY_RANGE(NLA_U8, 0, IEEE80211_MLD_MAX_NUM_LINKS), + [NL80211_ATTR_MLD_ADDR] = NLA_POLICY_EXACT_LEN(ETH_ALEN), + [NL80211_ATTR_MLO_SUPPORT] = { .type = NLA_FLAG }, + [NL80211_ATTR_MAX_NUM_AKM_SUITES] = { .type = NLA_REJECT }, +}; + +/* policy for the key attributes */ +static const struct nla_policy nl80211_key_policy[NL80211_KEY_MAX + 1] = { + [NL80211_KEY_DATA] = { .type = NLA_BINARY, .len = WLAN_MAX_KEY_LEN }, + [NL80211_KEY_IDX] = { .type = NLA_U8 }, + [NL80211_KEY_CIPHER] = { .type = NLA_U32 }, + [NL80211_KEY_SEQ] = { .type = NLA_BINARY, .len = 16 }, + [NL80211_KEY_DEFAULT] = { .type = NLA_FLAG }, + [NL80211_KEY_DEFAULT_MGMT] = { .type = NLA_FLAG }, + [NL80211_KEY_TYPE] = NLA_POLICY_MAX(NLA_U32, NUM_NL80211_KEYTYPES - 1), + [NL80211_KEY_DEFAULT_TYPES] = { .type = NLA_NESTED }, + [NL80211_KEY_MODE] = NLA_POLICY_RANGE(NLA_U8, 0, NL80211_KEY_SET_TX), +}; + +/* policy for the key default flags */ +static const struct nla_policy +nl80211_key_default_policy[NUM_NL80211_KEY_DEFAULT_TYPES] = { + [NL80211_KEY_DEFAULT_TYPE_UNICAST] = { .type = 
NLA_FLAG }, + [NL80211_KEY_DEFAULT_TYPE_MULTICAST] = { .type = NLA_FLAG }, +}; + +#ifdef CONFIG_PM +/* policy for WoWLAN attributes */ +static const struct nla_policy +nl80211_wowlan_policy[NUM_NL80211_WOWLAN_TRIG] = { + [NL80211_WOWLAN_TRIG_ANY] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_DISCONNECT] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_MAGIC_PKT] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_PKT_PATTERN] = { .type = NLA_NESTED }, + [NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_RFKILL_RELEASE] = { .type = NLA_FLAG }, + [NL80211_WOWLAN_TRIG_TCP_CONNECTION] = { .type = NLA_NESTED }, + [NL80211_WOWLAN_TRIG_NET_DETECT] = { .type = NLA_NESTED }, +}; + +static const struct nla_policy +nl80211_wowlan_tcp_policy[NUM_NL80211_WOWLAN_TCP] = { + [NL80211_WOWLAN_TCP_SRC_IPV4] = { .type = NLA_U32 }, + [NL80211_WOWLAN_TCP_DST_IPV4] = { .type = NLA_U32 }, + [NL80211_WOWLAN_TCP_DST_MAC] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_WOWLAN_TCP_SRC_PORT] = { .type = NLA_U16 }, + [NL80211_WOWLAN_TCP_DST_PORT] = { .type = NLA_U16 }, + [NL80211_WOWLAN_TCP_DATA_PAYLOAD] = NLA_POLICY_MIN_LEN(1), + [NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ] = { + .len = sizeof(struct nl80211_wowlan_tcp_data_seq) + }, + [NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN] = { + .len = sizeof(struct nl80211_wowlan_tcp_data_token) + }, + [NL80211_WOWLAN_TCP_DATA_INTERVAL] = { .type = NLA_U32 }, + [NL80211_WOWLAN_TCP_WAKE_PAYLOAD] = NLA_POLICY_MIN_LEN(1), + [NL80211_WOWLAN_TCP_WAKE_MASK] = NLA_POLICY_MIN_LEN(1), +}; +#endif /* CONFIG_PM */ + +/* policy for coalesce rule attributes */ +static const struct nla_policy +nl80211_coalesce_policy[NUM_NL80211_ATTR_COALESCE_RULE] = { + [NL80211_ATTR_COALESCE_RULE_DELAY] = { .type = NLA_U32 }, + [NL80211_ATTR_COALESCE_RULE_CONDITION] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_COALESCE_CONDITION_MATCH, + NL80211_COALESCE_CONDITION_NO_MATCH), + [NL80211_ATTR_COALESCE_RULE_PKT_PATTERN] = { .type = NLA_NESTED }, +}; + +/* policy for GTK rekey offload attributes */ +static const struct nla_policy +nl80211_rekey_policy[NUM_NL80211_REKEY_DATA] = { + [NL80211_REKEY_DATA_KEK] = { + .type = NLA_BINARY, + .len = NL80211_KEK_EXT_LEN + }, + [NL80211_REKEY_DATA_KCK] = { + .type = NLA_BINARY, + .len = NL80211_KCK_EXT_LEN + }, + [NL80211_REKEY_DATA_REPLAY_CTR] = NLA_POLICY_EXACT_LEN(NL80211_REPLAY_CTR_LEN), + [NL80211_REKEY_DATA_AKM] = { .type = NLA_U32 }, +}; + +static const struct nla_policy +nl80211_match_band_rssi_policy[NUM_NL80211_BANDS] = { + [NL80211_BAND_2GHZ] = { .type = NLA_S32 }, + [NL80211_BAND_5GHZ] = { .type = NLA_S32 }, + [NL80211_BAND_6GHZ] = { .type = NLA_S32 }, + [NL80211_BAND_60GHZ] = { .type = NLA_S32 }, + [NL80211_BAND_LC] = { .type = NLA_S32 }, +}; + +static const struct nla_policy +nl80211_match_policy[NL80211_SCHED_SCAN_MATCH_ATTR_MAX + 1] = { + [NL80211_SCHED_SCAN_MATCH_ATTR_SSID] = { .type = NLA_BINARY, + .len = IEEE80211_MAX_SSID_LEN }, + [NL80211_SCHED_SCAN_MATCH_ATTR_BSSID] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_SCHED_SCAN_MATCH_ATTR_RSSI] = { .type = NLA_U32 }, + [NL80211_SCHED_SCAN_MATCH_PER_BAND_RSSI] = + NLA_POLICY_NESTED(nl80211_match_band_rssi_policy), +}; + +static const struct nla_policy +nl80211_plan_policy[NL80211_SCHED_SCAN_PLAN_MAX + 1] = { + [NL80211_SCHED_SCAN_PLAN_INTERVAL] = { .type = NLA_U32 }, + [NL80211_SCHED_SCAN_PLAN_ITERATIONS] = { .type = NLA_U32 }, +}; + +static const struct 
nla_policy +nl80211_bss_select_policy[NL80211_BSS_SELECT_ATTR_MAX + 1] = { + [NL80211_BSS_SELECT_ATTR_RSSI] = { .type = NLA_FLAG }, + [NL80211_BSS_SELECT_ATTR_BAND_PREF] = { .type = NLA_U32 }, + [NL80211_BSS_SELECT_ATTR_RSSI_ADJUST] = { + .len = sizeof(struct nl80211_bss_select_rssi_adjust) + }, +}; + +/* policy for NAN function attributes */ +static const struct nla_policy +nl80211_nan_func_policy[NL80211_NAN_FUNC_ATTR_MAX + 1] = { + [NL80211_NAN_FUNC_TYPE] = + NLA_POLICY_MAX(NLA_U8, NL80211_NAN_FUNC_MAX_TYPE), + [NL80211_NAN_FUNC_SERVICE_ID] = { + .len = NL80211_NAN_FUNC_SERVICE_ID_LEN }, + [NL80211_NAN_FUNC_PUBLISH_TYPE] = { .type = NLA_U8 }, + [NL80211_NAN_FUNC_PUBLISH_BCAST] = { .type = NLA_FLAG }, + [NL80211_NAN_FUNC_SUBSCRIBE_ACTIVE] = { .type = NLA_FLAG }, + [NL80211_NAN_FUNC_FOLLOW_UP_ID] = { .type = NLA_U8 }, + [NL80211_NAN_FUNC_FOLLOW_UP_REQ_ID] = { .type = NLA_U8 }, + [NL80211_NAN_FUNC_FOLLOW_UP_DEST] = NLA_POLICY_EXACT_LEN_WARN(ETH_ALEN), + [NL80211_NAN_FUNC_CLOSE_RANGE] = { .type = NLA_FLAG }, + [NL80211_NAN_FUNC_TTL] = { .type = NLA_U32 }, + [NL80211_NAN_FUNC_SERVICE_INFO] = { .type = NLA_BINARY, + .len = NL80211_NAN_FUNC_SERVICE_SPEC_INFO_MAX_LEN }, + [NL80211_NAN_FUNC_SRF] = { .type = NLA_NESTED }, + [NL80211_NAN_FUNC_RX_MATCH_FILTER] = { .type = NLA_NESTED }, + [NL80211_NAN_FUNC_TX_MATCH_FILTER] = { .type = NLA_NESTED }, + [NL80211_NAN_FUNC_INSTANCE_ID] = { .type = NLA_U8 }, + [NL80211_NAN_FUNC_TERM_REASON] = { .type = NLA_U8 }, +}; + +/* policy for Service Response Filter attributes */ +static const struct nla_policy +nl80211_nan_srf_policy[NL80211_NAN_SRF_ATTR_MAX + 1] = { + [NL80211_NAN_SRF_INCLUDE] = { .type = NLA_FLAG }, + [NL80211_NAN_SRF_BF] = { .type = NLA_BINARY, + .len = NL80211_NAN_FUNC_SRF_MAX_LEN }, + [NL80211_NAN_SRF_BF_IDX] = { .type = NLA_U8 }, + [NL80211_NAN_SRF_MAC_ADDRS] = { .type = NLA_NESTED }, +}; + +/* policy for packet pattern attributes */ +static const struct nla_policy +nl80211_packet_pattern_policy[MAX_NL80211_PKTPAT + 1] = { + [NL80211_PKTPAT_MASK] = { .type = NLA_BINARY, }, + [NL80211_PKTPAT_PATTERN] = { .type = NLA_BINARY, }, + [NL80211_PKTPAT_OFFSET] = { .type = NLA_U32 }, +}; + +static int nl80211_prepare_wdev_dump(struct netlink_callback *cb, + struct cfg80211_registered_device **rdev, + struct wireless_dev **wdev, + struct nlattr **attrbuf) +{ + int err; + + if (!cb->args[0]) { + struct nlattr **attrbuf_free = NULL; + + if (!attrbuf) { + attrbuf = kcalloc(NUM_NL80211_ATTR, sizeof(*attrbuf), + GFP_KERNEL); + if (!attrbuf) + return -ENOMEM; + attrbuf_free = attrbuf; + } + + err = nlmsg_parse_deprecated(cb->nlh, + GENL_HDRLEN + nl80211_fam.hdrsize, + attrbuf, nl80211_fam.maxattr, + nl80211_policy, NULL); + if (err) { + kfree(attrbuf_free); + return err; + } + + rtnl_lock(); + *wdev = __cfg80211_wdev_from_attrs(NULL, sock_net(cb->skb->sk), + attrbuf); + kfree(attrbuf_free); + if (IS_ERR(*wdev)) { + rtnl_unlock(); + return PTR_ERR(*wdev); + } + *rdev = wiphy_to_rdev((*wdev)->wiphy); + mutex_lock(&(*rdev)->wiphy.mtx); + rtnl_unlock(); + /* 0 is the first index - add 1 to parse only once */ + cb->args[0] = (*rdev)->wiphy_idx + 1; + cb->args[1] = (*wdev)->identifier; + } else { + /* subtract the 1 again here */ + struct wiphy *wiphy; + struct wireless_dev *tmp; + + rtnl_lock(); + wiphy = wiphy_idx_to_wiphy(cb->args[0] - 1); + if (!wiphy) { + rtnl_unlock(); + return -ENODEV; + } + *rdev = wiphy_to_rdev(wiphy); + *wdev = NULL; + + list_for_each_entry(tmp, &(*rdev)->wiphy.wdev_list, list) { + if (tmp->identifier == cb->args[1]) { + *wdev = tmp; + 
break; + } + } + + if (!*wdev) { + rtnl_unlock(); + return -ENODEV; + } + mutex_lock(&(*rdev)->wiphy.mtx); + rtnl_unlock(); + } + + return 0; +} + +/* message building helper */ +void *nl80211hdr_put(struct sk_buff *skb, u32 portid, u32 seq, + int flags, u8 cmd) +{ + /* since there is no private header just add the generic one */ + return genlmsg_put(skb, portid, seq, &nl80211_fam, flags, cmd); +} + +static int nl80211_msg_put_wmm_rules(struct sk_buff *msg, + const struct ieee80211_reg_rule *rule) +{ + int j; + struct nlattr *nl_wmm_rules = + nla_nest_start_noflag(msg, NL80211_FREQUENCY_ATTR_WMM); + + if (!nl_wmm_rules) + goto nla_put_failure; + + for (j = 0; j < IEEE80211_NUM_ACS; j++) { + struct nlattr *nl_wmm_rule = nla_nest_start_noflag(msg, j); + + if (!nl_wmm_rule) + goto nla_put_failure; + + if (nla_put_u16(msg, NL80211_WMMR_CW_MIN, + rule->wmm_rule.client[j].cw_min) || + nla_put_u16(msg, NL80211_WMMR_CW_MAX, + rule->wmm_rule.client[j].cw_max) || + nla_put_u8(msg, NL80211_WMMR_AIFSN, + rule->wmm_rule.client[j].aifsn) || + nla_put_u16(msg, NL80211_WMMR_TXOP, + rule->wmm_rule.client[j].cot)) + goto nla_put_failure; + + nla_nest_end(msg, nl_wmm_rule); + } + nla_nest_end(msg, nl_wmm_rules); + + return 0; + +nla_put_failure: + return -ENOBUFS; +} + +static int nl80211_msg_put_channel(struct sk_buff *msg, struct wiphy *wiphy, + struct ieee80211_channel *chan, + bool large) +{ + /* Some channels must be completely excluded from the + * list to protect old user-space tools from breaking + */ + if (!large && chan->flags & + (IEEE80211_CHAN_NO_10MHZ | IEEE80211_CHAN_NO_20MHZ)) + return 0; + if (!large && chan->freq_offset) + return 0; + + if (nla_put_u32(msg, NL80211_FREQUENCY_ATTR_FREQ, + chan->center_freq)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_FREQUENCY_ATTR_OFFSET, chan->freq_offset)) + goto nla_put_failure; + + if ((chan->flags & IEEE80211_CHAN_DISABLED) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_DISABLED)) + goto nla_put_failure; + if (chan->flags & IEEE80211_CHAN_NO_IR) { + if (nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_IR)) + goto nla_put_failure; + if (nla_put_flag(msg, __NL80211_FREQUENCY_ATTR_NO_IBSS)) + goto nla_put_failure; + } + if (chan->flags & IEEE80211_CHAN_RADAR) { + if (nla_put_flag(msg, NL80211_FREQUENCY_ATTR_RADAR)) + goto nla_put_failure; + if (large) { + u32 time; + + time = elapsed_jiffies_msecs(chan->dfs_state_entered); + + if (nla_put_u32(msg, NL80211_FREQUENCY_ATTR_DFS_STATE, + chan->dfs_state)) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_FREQUENCY_ATTR_DFS_TIME, + time)) + goto nla_put_failure; + if (nla_put_u32(msg, + NL80211_FREQUENCY_ATTR_DFS_CAC_TIME, + chan->dfs_cac_ms)) + goto nla_put_failure; + } + } + + if (large) { + if ((chan->flags & IEEE80211_CHAN_NO_HT40MINUS) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_HT40_MINUS)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_HT40PLUS) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_HT40_PLUS)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_80MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_80MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_160MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_160MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_INDOOR_ONLY) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_INDOOR_ONLY)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_IR_CONCURRENT) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_IR_CONCURRENT)) + goto nla_put_failure; + if ((chan->flags & 
IEEE80211_CHAN_NO_20MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_20MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_10MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_10MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_HE) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_HE)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_1MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_1MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_2MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_2MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_4MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_4MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_8MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_8MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_16MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_16MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_320MHZ) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_320MHZ)) + goto nla_put_failure; + if ((chan->flags & IEEE80211_CHAN_NO_EHT) && + nla_put_flag(msg, NL80211_FREQUENCY_ATTR_NO_EHT)) + goto nla_put_failure; + } + + if (nla_put_u32(msg, NL80211_FREQUENCY_ATTR_MAX_TX_POWER, + DBM_TO_MBM(chan->max_power))) + goto nla_put_failure; + + if (large) { + const struct ieee80211_reg_rule *rule = + freq_reg_info(wiphy, MHZ_TO_KHZ(chan->center_freq)); + + if (!IS_ERR_OR_NULL(rule) && rule->has_wmm) { + if (nl80211_msg_put_wmm_rules(msg, rule)) + goto nla_put_failure; + } + } + + return 0; + + nla_put_failure: + return -ENOBUFS; +} + +static bool nl80211_put_txq_stats(struct sk_buff *msg, + struct cfg80211_txq_stats *txqstats, + int attrtype) +{ + struct nlattr *txqattr; + +#define PUT_TXQVAL_U32(attr, memb) do { \ + if (txqstats->filled & BIT(NL80211_TXQ_STATS_ ## attr) && \ + nla_put_u32(msg, NL80211_TXQ_STATS_ ## attr, txqstats->memb)) \ + return false; \ + } while (0) + + txqattr = nla_nest_start_noflag(msg, attrtype); + if (!txqattr) + return false; + + PUT_TXQVAL_U32(BACKLOG_BYTES, backlog_bytes); + PUT_TXQVAL_U32(BACKLOG_PACKETS, backlog_packets); + PUT_TXQVAL_U32(FLOWS, flows); + PUT_TXQVAL_U32(DROPS, drops); + PUT_TXQVAL_U32(ECN_MARKS, ecn_marks); + PUT_TXQVAL_U32(OVERLIMIT, overlimit); + PUT_TXQVAL_U32(OVERMEMORY, overmemory); + PUT_TXQVAL_U32(COLLISIONS, collisions); + PUT_TXQVAL_U32(TX_BYTES, tx_bytes); + PUT_TXQVAL_U32(TX_PACKETS, tx_packets); + PUT_TXQVAL_U32(MAX_FLOWS, max_flows); + nla_nest_end(msg, txqattr); + +#undef PUT_TXQVAL_U32 + return true; +} + +/* netlink command implementations */ + +/** + * nl80211_link_id - return link ID + * @attrs: attributes to look at + * + * Returns: the link ID or 0 if not given + * + * Note this function doesn't do any validation of the link + * ID validity wrt. links that were actually added, so it must + * be called only from ops with %NL80211_FLAG_MLO_VALID_LINK_ID + * or if additional validation is done. 
+ */ +static unsigned int nl80211_link_id(struct nlattr **attrs) +{ + struct nlattr *linkid = attrs[NL80211_ATTR_MLO_LINK_ID]; + + if (!linkid) + return 0; + + return nla_get_u8(linkid); +} + +static int nl80211_link_id_or_invalid(struct nlattr **attrs) +{ + struct nlattr *linkid = attrs[NL80211_ATTR_MLO_LINK_ID]; + + if (!linkid) + return -1; + + return nla_get_u8(linkid); +} + +struct key_parse { + struct key_params p; + int idx; + int type; + bool def, defmgmt, defbeacon; + bool def_uni, def_multi; +}; + +static int nl80211_parse_key_new(struct genl_info *info, struct nlattr *key, + struct key_parse *k) +{ + struct nlattr *tb[NL80211_KEY_MAX + 1]; + int err = nla_parse_nested_deprecated(tb, NL80211_KEY_MAX, key, + nl80211_key_policy, + info->extack); + if (err) + return err; + + k->def = !!tb[NL80211_KEY_DEFAULT]; + k->defmgmt = !!tb[NL80211_KEY_DEFAULT_MGMT]; + k->defbeacon = !!tb[NL80211_KEY_DEFAULT_BEACON]; + + if (k->def) { + k->def_uni = true; + k->def_multi = true; + } + if (k->defmgmt || k->defbeacon) + k->def_multi = true; + + if (tb[NL80211_KEY_IDX]) + k->idx = nla_get_u8(tb[NL80211_KEY_IDX]); + + if (tb[NL80211_KEY_DATA]) { + k->p.key = nla_data(tb[NL80211_KEY_DATA]); + k->p.key_len = nla_len(tb[NL80211_KEY_DATA]); + } + + if (tb[NL80211_KEY_SEQ]) { + k->p.seq = nla_data(tb[NL80211_KEY_SEQ]); + k->p.seq_len = nla_len(tb[NL80211_KEY_SEQ]); + } + + if (tb[NL80211_KEY_CIPHER]) + k->p.cipher = nla_get_u32(tb[NL80211_KEY_CIPHER]); + + if (tb[NL80211_KEY_TYPE]) + k->type = nla_get_u32(tb[NL80211_KEY_TYPE]); + + if (tb[NL80211_KEY_DEFAULT_TYPES]) { + struct nlattr *kdt[NUM_NL80211_KEY_DEFAULT_TYPES]; + + err = nla_parse_nested_deprecated(kdt, + NUM_NL80211_KEY_DEFAULT_TYPES - 1, + tb[NL80211_KEY_DEFAULT_TYPES], + nl80211_key_default_policy, + info->extack); + if (err) + return err; + + k->def_uni = kdt[NL80211_KEY_DEFAULT_TYPE_UNICAST]; + k->def_multi = kdt[NL80211_KEY_DEFAULT_TYPE_MULTICAST]; + } + + if (tb[NL80211_KEY_MODE]) + k->p.mode = nla_get_u8(tb[NL80211_KEY_MODE]); + + return 0; +} + +static int nl80211_parse_key_old(struct genl_info *info, struct key_parse *k) +{ + if (info->attrs[NL80211_ATTR_KEY_DATA]) { + k->p.key = nla_data(info->attrs[NL80211_ATTR_KEY_DATA]); + k->p.key_len = nla_len(info->attrs[NL80211_ATTR_KEY_DATA]); + } + + if (info->attrs[NL80211_ATTR_KEY_SEQ]) { + k->p.seq = nla_data(info->attrs[NL80211_ATTR_KEY_SEQ]); + k->p.seq_len = nla_len(info->attrs[NL80211_ATTR_KEY_SEQ]); + } + + if (info->attrs[NL80211_ATTR_KEY_IDX]) + k->idx = nla_get_u8(info->attrs[NL80211_ATTR_KEY_IDX]); + + if (info->attrs[NL80211_ATTR_KEY_CIPHER]) + k->p.cipher = nla_get_u32(info->attrs[NL80211_ATTR_KEY_CIPHER]); + + k->def = !!info->attrs[NL80211_ATTR_KEY_DEFAULT]; + k->defmgmt = !!info->attrs[NL80211_ATTR_KEY_DEFAULT_MGMT]; + + if (k->def) { + k->def_uni = true; + k->def_multi = true; + } + if (k->defmgmt) + k->def_multi = true; + + if (info->attrs[NL80211_ATTR_KEY_TYPE]) + k->type = nla_get_u32(info->attrs[NL80211_ATTR_KEY_TYPE]); + + if (info->attrs[NL80211_ATTR_KEY_DEFAULT_TYPES]) { + struct nlattr *kdt[NUM_NL80211_KEY_DEFAULT_TYPES]; + int err = nla_parse_nested_deprecated(kdt, + NUM_NL80211_KEY_DEFAULT_TYPES - 1, + info->attrs[NL80211_ATTR_KEY_DEFAULT_TYPES], + nl80211_key_default_policy, + info->extack); + if (err) + return err; + + k->def_uni = kdt[NL80211_KEY_DEFAULT_TYPE_UNICAST]; + k->def_multi = kdt[NL80211_KEY_DEFAULT_TYPE_MULTICAST]; + } + + return 0; +} + +static int nl80211_parse_key(struct genl_info *info, struct key_parse *k) +{ + int err; + + memset(k, 0, 
sizeof(*k)); + k->idx = -1; + k->type = -1; + + if (info->attrs[NL80211_ATTR_KEY]) + err = nl80211_parse_key_new(info, info->attrs[NL80211_ATTR_KEY], k); + else + err = nl80211_parse_key_old(info, k); + + if (err) + return err; + + if ((k->def ? 1 : 0) + (k->defmgmt ? 1 : 0) + + (k->defbeacon ? 1 : 0) > 1) { + GENL_SET_ERR_MSG(info, + "key with multiple default flags is invalid"); + return -EINVAL; + } + + if (k->defmgmt || k->defbeacon) { + if (k->def_uni || !k->def_multi) { + GENL_SET_ERR_MSG(info, + "defmgmt/defbeacon key must be mcast"); + return -EINVAL; + } + } + + if (k->idx != -1) { + if (k->defmgmt) { + if (k->idx < 4 || k->idx > 5) { + GENL_SET_ERR_MSG(info, + "defmgmt key idx not 4 or 5"); + return -EINVAL; + } + } else if (k->defbeacon) { + if (k->idx < 6 || k->idx > 7) { + GENL_SET_ERR_MSG(info, + "defbeacon key idx not 6 or 7"); + return -EINVAL; + } + } else if (k->def) { + if (k->idx < 0 || k->idx > 3) { + GENL_SET_ERR_MSG(info, "def key idx not 0-3"); + return -EINVAL; + } + } else { + if (k->idx < 0 || k->idx > 7) { + GENL_SET_ERR_MSG(info, "key idx not 0-7"); + return -EINVAL; + } + } + } + + return 0; +} + +static struct cfg80211_cached_keys * +nl80211_parse_connkeys(struct cfg80211_registered_device *rdev, + struct genl_info *info, bool *no_ht) +{ + struct nlattr *keys = info->attrs[NL80211_ATTR_KEYS]; + struct key_parse parse; + struct nlattr *key; + struct cfg80211_cached_keys *result; + int rem, err, def = 0; + bool have_key = false; + + nla_for_each_nested(key, keys, rem) { + have_key = true; + break; + } + + if (!have_key) + return NULL; + + result = kzalloc(sizeof(*result), GFP_KERNEL); + if (!result) + return ERR_PTR(-ENOMEM); + + result->def = -1; + + nla_for_each_nested(key, keys, rem) { + memset(&parse, 0, sizeof(parse)); + parse.idx = -1; + + err = nl80211_parse_key_new(info, key, &parse); + if (err) + goto error; + err = -EINVAL; + if (!parse.p.key) + goto error; + if (parse.idx < 0 || parse.idx > 3) { + GENL_SET_ERR_MSG(info, "key index out of range [0-3]"); + goto error; + } + if (parse.def) { + if (def) { + GENL_SET_ERR_MSG(info, + "only one key can be default"); + goto error; + } + def = 1; + result->def = parse.idx; + if (!parse.def_uni || !parse.def_multi) + goto error; + } else if (parse.defmgmt) + goto error; + err = cfg80211_validate_key_settings(rdev, &parse.p, + parse.idx, false, NULL); + if (err) + goto error; + if (parse.p.cipher != WLAN_CIPHER_SUITE_WEP40 && + parse.p.cipher != WLAN_CIPHER_SUITE_WEP104) { + GENL_SET_ERR_MSG(info, "connect key must be WEP"); + err = -EINVAL; + goto error; + } + result->params[parse.idx].cipher = parse.p.cipher; + result->params[parse.idx].key_len = parse.p.key_len; + result->params[parse.idx].key = result->data[parse.idx]; + memcpy(result->data[parse.idx], parse.p.key, parse.p.key_len); + + /* must be WEP key if we got here */ + if (no_ht) + *no_ht = true; + } + + if (result->def < 0) { + err = -EINVAL; + GENL_SET_ERR_MSG(info, "need a default/TX key"); + goto error; + } + + return result; + error: + kfree(result); + return ERR_PTR(err); +} + +static int nl80211_key_allowed(struct wireless_dev *wdev) +{ + ASSERT_WDEV_LOCK(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_MESH_POINT: + break; + case NL80211_IFTYPE_ADHOC: + if (wdev->u.ibss.current_bss) + return 0; + return -ENOLINK; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + if (wdev->connected) + return 0; + return -ENOLINK; + case 
NL80211_IFTYPE_UNSPECIFIED: + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_NAN: + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_WDS: + case NUM_NL80211_IFTYPES: + return -EINVAL; + } + + return 0; +} + +static struct ieee80211_channel *nl80211_get_valid_chan(struct wiphy *wiphy, + u32 freq) +{ + struct ieee80211_channel *chan; + + chan = ieee80211_get_channel_khz(wiphy, freq); + if (!chan || chan->flags & IEEE80211_CHAN_DISABLED) + return NULL; + return chan; +} + +static int nl80211_put_iftypes(struct sk_buff *msg, u32 attr, u16 ifmodes) +{ + struct nlattr *nl_modes = nla_nest_start_noflag(msg, attr); + int i; + + if (!nl_modes) + goto nla_put_failure; + + i = 0; + while (ifmodes) { + if ((ifmodes & 1) && nla_put_flag(msg, i)) + goto nla_put_failure; + ifmodes >>= 1; + i++; + } + + nla_nest_end(msg, nl_modes); + return 0; + +nla_put_failure: + return -ENOBUFS; +} + +static int nl80211_put_iface_combinations(struct wiphy *wiphy, + struct sk_buff *msg, + bool large) +{ + struct nlattr *nl_combis; + int i, j; + + nl_combis = nla_nest_start_noflag(msg, + NL80211_ATTR_INTERFACE_COMBINATIONS); + if (!nl_combis) + goto nla_put_failure; + + for (i = 0; i < wiphy->n_iface_combinations; i++) { + const struct ieee80211_iface_combination *c; + struct nlattr *nl_combi, *nl_limits; + + c = &wiphy->iface_combinations[i]; + + nl_combi = nla_nest_start_noflag(msg, i + 1); + if (!nl_combi) + goto nla_put_failure; + + nl_limits = nla_nest_start_noflag(msg, + NL80211_IFACE_COMB_LIMITS); + if (!nl_limits) + goto nla_put_failure; + + for (j = 0; j < c->n_limits; j++) { + struct nlattr *nl_limit; + + nl_limit = nla_nest_start_noflag(msg, j + 1); + if (!nl_limit) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_IFACE_LIMIT_MAX, + c->limits[j].max)) + goto nla_put_failure; + if (nl80211_put_iftypes(msg, NL80211_IFACE_LIMIT_TYPES, + c->limits[j].types)) + goto nla_put_failure; + nla_nest_end(msg, nl_limit); + } + + nla_nest_end(msg, nl_limits); + + if (c->beacon_int_infra_match && + nla_put_flag(msg, NL80211_IFACE_COMB_STA_AP_BI_MATCH)) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_IFACE_COMB_NUM_CHANNELS, + c->num_different_channels) || + nla_put_u32(msg, NL80211_IFACE_COMB_MAXNUM, + c->max_interfaces)) + goto nla_put_failure; + if (large && + (nla_put_u32(msg, NL80211_IFACE_COMB_RADAR_DETECT_WIDTHS, + c->radar_detect_widths) || + nla_put_u32(msg, NL80211_IFACE_COMB_RADAR_DETECT_REGIONS, + c->radar_detect_regions))) + goto nla_put_failure; + if (c->beacon_int_min_gcd && + nla_put_u32(msg, NL80211_IFACE_COMB_BI_MIN_GCD, + c->beacon_int_min_gcd)) + goto nla_put_failure; + + nla_nest_end(msg, nl_combi); + } + + nla_nest_end(msg, nl_combis); + + return 0; +nla_put_failure: + return -ENOBUFS; +} + +#ifdef CONFIG_PM +static int nl80211_send_wowlan_tcp_caps(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + const struct wiphy_wowlan_tcp_support *tcp = rdev->wiphy.wowlan->tcp; + struct nlattr *nl_tcp; + + if (!tcp) + return 0; + + nl_tcp = nla_nest_start_noflag(msg, + NL80211_WOWLAN_TRIG_TCP_CONNECTION); + if (!nl_tcp) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD, + tcp->data_payload_max)) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD, + tcp->data_payload_max)) + return -ENOBUFS; + + if (tcp->seq && nla_put_flag(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ)) + return -ENOBUFS; + + if (tcp->tok && nla_put(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN, + sizeof(*tcp->tok), tcp->tok)) + return 
-ENOBUFS; + + if (nla_put_u32(msg, NL80211_WOWLAN_TCP_DATA_INTERVAL, + tcp->data_interval_max)) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_WOWLAN_TCP_WAKE_PAYLOAD, + tcp->wake_payload_max)) + return -ENOBUFS; + + nla_nest_end(msg, nl_tcp); + return 0; +} + +static int nl80211_send_wowlan(struct sk_buff *msg, + struct cfg80211_registered_device *rdev, + bool large) +{ + struct nlattr *nl_wowlan; + + if (!rdev->wiphy.wowlan) + return 0; + + nl_wowlan = nla_nest_start_noflag(msg, + NL80211_ATTR_WOWLAN_TRIGGERS_SUPPORTED); + if (!nl_wowlan) + return -ENOBUFS; + + if (((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_ANY) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_ANY)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_DISCONNECT) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_DISCONNECT)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_MAGIC_PKT) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_MAGIC_PKT)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_SUPPORTS_GTK_REKEY) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_GTK_REKEY_SUPPORTED)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_GTK_REKEY_FAILURE) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_EAP_IDENTITY_REQ) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_4WAY_HANDSHAKE) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE)) || + ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_RFKILL_RELEASE) && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_RFKILL_RELEASE))) + return -ENOBUFS; + + if (rdev->wiphy.wowlan->n_patterns) { + struct nl80211_pattern_support pat = { + .max_patterns = rdev->wiphy.wowlan->n_patterns, + .min_pattern_len = rdev->wiphy.wowlan->pattern_min_len, + .max_pattern_len = rdev->wiphy.wowlan->pattern_max_len, + .max_pkt_offset = rdev->wiphy.wowlan->max_pkt_offset, + }; + + if (nla_put(msg, NL80211_WOWLAN_TRIG_PKT_PATTERN, + sizeof(pat), &pat)) + return -ENOBUFS; + } + + if ((rdev->wiphy.wowlan->flags & WIPHY_WOWLAN_NET_DETECT) && + nla_put_u32(msg, NL80211_WOWLAN_TRIG_NET_DETECT, + rdev->wiphy.wowlan->max_nd_match_sets)) + return -ENOBUFS; + + if (large && nl80211_send_wowlan_tcp_caps(rdev, msg)) + return -ENOBUFS; + + nla_nest_end(msg, nl_wowlan); + + return 0; +} +#endif + +static int nl80211_send_coalesce(struct sk_buff *msg, + struct cfg80211_registered_device *rdev) +{ + struct nl80211_coalesce_rule_support rule; + + if (!rdev->wiphy.coalesce) + return 0; + + rule.max_rules = rdev->wiphy.coalesce->n_rules; + rule.max_delay = rdev->wiphy.coalesce->max_delay; + rule.pat.max_patterns = rdev->wiphy.coalesce->n_patterns; + rule.pat.min_pattern_len = rdev->wiphy.coalesce->pattern_min_len; + rule.pat.max_pattern_len = rdev->wiphy.coalesce->pattern_max_len; + rule.pat.max_pkt_offset = rdev->wiphy.coalesce->max_pkt_offset; + + if (nla_put(msg, NL80211_ATTR_COALESCE_RULE, sizeof(rule), &rule)) + return -ENOBUFS; + + return 0; +} + +static int +nl80211_send_iftype_data(struct sk_buff *msg, + const struct ieee80211_supported_band *sband, + const struct ieee80211_sband_iftype_data *iftdata) +{ + const struct ieee80211_sta_he_cap *he_cap = &iftdata->he_cap; + const struct ieee80211_sta_eht_cap *eht_cap = &iftdata->eht_cap; + + if (nl80211_put_iftypes(msg, NL80211_BAND_IFTYPE_ATTR_IFTYPES, + iftdata->types_mask)) + return -ENOBUFS; + + if (he_cap->has_he) { + if (nla_put(msg, NL80211_BAND_IFTYPE_ATTR_HE_CAP_MAC, + sizeof(he_cap->he_cap_elem.mac_cap_info), + he_cap->he_cap_elem.mac_cap_info) || + nla_put(msg, 
NL80211_BAND_IFTYPE_ATTR_HE_CAP_PHY, + sizeof(he_cap->he_cap_elem.phy_cap_info), + he_cap->he_cap_elem.phy_cap_info) || + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_HE_CAP_MCS_SET, + sizeof(he_cap->he_mcs_nss_supp), + &he_cap->he_mcs_nss_supp) || + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_HE_CAP_PPE, + sizeof(he_cap->ppe_thres), he_cap->ppe_thres)) + return -ENOBUFS; + } + + if (eht_cap->has_eht && he_cap->has_he) { + u8 mcs_nss_size, ppe_thresh_size; + u16 ppe_thres_hdr; + bool is_ap; + + is_ap = iftdata->types_mask & BIT(NL80211_IFTYPE_AP) || + iftdata->types_mask & BIT(NL80211_IFTYPE_P2P_GO); + + mcs_nss_size = + ieee80211_eht_mcs_nss_size(&he_cap->he_cap_elem, + &eht_cap->eht_cap_elem, + is_ap); + + ppe_thres_hdr = get_unaligned_le16(&eht_cap->eht_ppe_thres[0]); + ppe_thresh_size = + ieee80211_eht_ppe_size(ppe_thres_hdr, + eht_cap->eht_cap_elem.phy_cap_info); + + if (nla_put(msg, NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MAC, + sizeof(eht_cap->eht_cap_elem.mac_cap_info), + eht_cap->eht_cap_elem.mac_cap_info) || + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PHY, + sizeof(eht_cap->eht_cap_elem.phy_cap_info), + eht_cap->eht_cap_elem.phy_cap_info) || + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_EHT_CAP_MCS_SET, + mcs_nss_size, &eht_cap->eht_mcs_nss_supp) || + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_EHT_CAP_PPE, + ppe_thresh_size, eht_cap->eht_ppe_thres)) + return -ENOBUFS; + } + + if (sband->band == NL80211_BAND_6GHZ && + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_HE_6GHZ_CAPA, + sizeof(iftdata->he_6ghz_capa), + &iftdata->he_6ghz_capa)) + return -ENOBUFS; + + if (iftdata->vendor_elems.data && iftdata->vendor_elems.len && + nla_put(msg, NL80211_BAND_IFTYPE_ATTR_VENDOR_ELEMS, + iftdata->vendor_elems.len, iftdata->vendor_elems.data)) + return -ENOBUFS; + + return 0; +} + +static int nl80211_send_band_rateinfo(struct sk_buff *msg, + struct ieee80211_supported_band *sband, + bool large) +{ + struct nlattr *nl_rates, *nl_rate; + struct ieee80211_rate *rate; + int i; + + /* add HT info */ + if (sband->ht_cap.ht_supported && + (nla_put(msg, NL80211_BAND_ATTR_HT_MCS_SET, + sizeof(sband->ht_cap.mcs), + &sband->ht_cap.mcs) || + nla_put_u16(msg, NL80211_BAND_ATTR_HT_CAPA, + sband->ht_cap.cap) || + nla_put_u8(msg, NL80211_BAND_ATTR_HT_AMPDU_FACTOR, + sband->ht_cap.ampdu_factor) || + nla_put_u8(msg, NL80211_BAND_ATTR_HT_AMPDU_DENSITY, + sband->ht_cap.ampdu_density))) + return -ENOBUFS; + + /* add VHT info */ + if (sband->vht_cap.vht_supported && + (nla_put(msg, NL80211_BAND_ATTR_VHT_MCS_SET, + sizeof(sband->vht_cap.vht_mcs), + &sband->vht_cap.vht_mcs) || + nla_put_u32(msg, NL80211_BAND_ATTR_VHT_CAPA, + sband->vht_cap.cap))) + return -ENOBUFS; + + if (large && sband->n_iftype_data) { + struct nlattr *nl_iftype_data = + nla_nest_start_noflag(msg, + NL80211_BAND_ATTR_IFTYPE_DATA); + int err; + + if (!nl_iftype_data) + return -ENOBUFS; + + for (i = 0; i < sband->n_iftype_data; i++) { + struct nlattr *iftdata; + + iftdata = nla_nest_start_noflag(msg, i + 1); + if (!iftdata) + return -ENOBUFS; + + err = nl80211_send_iftype_data(msg, sband, + &sband->iftype_data[i]); + if (err) + return err; + + nla_nest_end(msg, iftdata); + } + + nla_nest_end(msg, nl_iftype_data); + } + + /* add EDMG info */ + if (large && sband->edmg_cap.channels && + (nla_put_u8(msg, NL80211_BAND_ATTR_EDMG_CHANNELS, + sband->edmg_cap.channels) || + nla_put_u8(msg, NL80211_BAND_ATTR_EDMG_BW_CONFIG, + sband->edmg_cap.bw_config))) + + return -ENOBUFS; + + /* add bitrates */ + nl_rates = nla_nest_start_noflag(msg, NL80211_BAND_ATTR_RATES); + if (!nl_rates) + return 
-ENOBUFS; + + for (i = 0; i < sband->n_bitrates; i++) { + nl_rate = nla_nest_start_noflag(msg, i); + if (!nl_rate) + return -ENOBUFS; + + rate = &sband->bitrates[i]; + if (nla_put_u32(msg, NL80211_BITRATE_ATTR_RATE, + rate->bitrate)) + return -ENOBUFS; + if ((rate->flags & IEEE80211_RATE_SHORT_PREAMBLE) && + nla_put_flag(msg, + NL80211_BITRATE_ATTR_2GHZ_SHORTPREAMBLE)) + return -ENOBUFS; + + nla_nest_end(msg, nl_rate); + } + + nla_nest_end(msg, nl_rates); + + return 0; +} + +static int +nl80211_send_mgmt_stypes(struct sk_buff *msg, + const struct ieee80211_txrx_stypes *mgmt_stypes) +{ + u16 stypes; + struct nlattr *nl_ftypes, *nl_ifs; + enum nl80211_iftype ift; + int i; + + if (!mgmt_stypes) + return 0; + + nl_ifs = nla_nest_start_noflag(msg, NL80211_ATTR_TX_FRAME_TYPES); + if (!nl_ifs) + return -ENOBUFS; + + for (ift = 0; ift < NUM_NL80211_IFTYPES; ift++) { + nl_ftypes = nla_nest_start_noflag(msg, ift); + if (!nl_ftypes) + return -ENOBUFS; + i = 0; + stypes = mgmt_stypes[ift].tx; + while (stypes) { + if ((stypes & 1) && + nla_put_u16(msg, NL80211_ATTR_FRAME_TYPE, + (i << 4) | IEEE80211_FTYPE_MGMT)) + return -ENOBUFS; + stypes >>= 1; + i++; + } + nla_nest_end(msg, nl_ftypes); + } + + nla_nest_end(msg, nl_ifs); + + nl_ifs = nla_nest_start_noflag(msg, NL80211_ATTR_RX_FRAME_TYPES); + if (!nl_ifs) + return -ENOBUFS; + + for (ift = 0; ift < NUM_NL80211_IFTYPES; ift++) { + nl_ftypes = nla_nest_start_noflag(msg, ift); + if (!nl_ftypes) + return -ENOBUFS; + i = 0; + stypes = mgmt_stypes[ift].rx; + while (stypes) { + if ((stypes & 1) && + nla_put_u16(msg, NL80211_ATTR_FRAME_TYPE, + (i << 4) | IEEE80211_FTYPE_MGMT)) + return -ENOBUFS; + stypes >>= 1; + i++; + } + nla_nest_end(msg, nl_ftypes); + } + nla_nest_end(msg, nl_ifs); + + return 0; +} + +#define CMD(op, n) \ + do { \ + if (rdev->ops->op) { \ + i++; \ + if (nla_put_u32(msg, i, NL80211_CMD_ ## n)) \ + goto nla_put_failure; \ + } \ + } while (0) + +static int nl80211_add_commands_unsplit(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + int i = 0; + + /* + * do *NOT* add anything into this function, new things need to be + * advertised only to new versions of userspace that can deal with + * the split (and they can't possibly care about new features... 
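
As the comment above says, everything added past the unsplit command list is only advertised to userspace that opts in to split dumps. A rough libnl-3 sketch of what such a dump request might look like from userspace; the helper name is made up, the socket is assumed to be already connected with the nl80211 family id resolved via genl_ctrl_resolve(), and the receive callbacks that actually parse the NL80211_CMD_NEW_WIPHY replies are omitted.

    #include <errno.h>
    #include <netlink/netlink.h>
    #include <netlink/genl/genl.h>
    #include <linux/nl80211.h>

    /* request a split wiphy dump on an already-connected genl socket */
    static int request_split_wiphy_dump(struct nl_sock *sk, int nl80211_id)
    {
        struct nl_msg *msg = nlmsg_alloc();
        int err;

        if (!msg)
            return -ENOMEM;

        genlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, nl80211_id, 0,
                    NLM_F_DUMP, NL80211_CMD_GET_WIPHY, 0);
        /* opt in: the kernel then sends the newer, split-only attributes too */
        nla_put_flag(msg, NL80211_ATTR_SPLIT_WIPHY_DUMP);

        err = nl_send_auto(sk, msg);
        nlmsg_free(msg);
        if (err < 0)
            return err;

        /* each wiphy now arrives as several NL80211_CMD_NEW_WIPHY parts */
        return nl_recvmsgs_default(sk);
    }
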
+ */ + CMD(add_virtual_intf, NEW_INTERFACE); + CMD(change_virtual_intf, SET_INTERFACE); + CMD(add_key, NEW_KEY); + CMD(start_ap, START_AP); + CMD(add_station, NEW_STATION); + CMD(add_mpath, NEW_MPATH); + CMD(update_mesh_config, SET_MESH_CONFIG); + CMD(change_bss, SET_BSS); + CMD(auth, AUTHENTICATE); + CMD(assoc, ASSOCIATE); + CMD(deauth, DEAUTHENTICATE); + CMD(disassoc, DISASSOCIATE); + CMD(join_ibss, JOIN_IBSS); + CMD(join_mesh, JOIN_MESH); + CMD(set_pmksa, SET_PMKSA); + CMD(del_pmksa, DEL_PMKSA); + CMD(flush_pmksa, FLUSH_PMKSA); + if (rdev->wiphy.flags & WIPHY_FLAG_HAS_REMAIN_ON_CHANNEL) + CMD(remain_on_channel, REMAIN_ON_CHANNEL); + CMD(set_bitrate_mask, SET_TX_BITRATE_MASK); + CMD(mgmt_tx, FRAME); + CMD(mgmt_tx_cancel_wait, FRAME_WAIT_CANCEL); + if (rdev->wiphy.flags & WIPHY_FLAG_NETNS_OK) { + i++; + if (nla_put_u32(msg, i, NL80211_CMD_SET_WIPHY_NETNS)) + goto nla_put_failure; + } + if (rdev->ops->set_monitor_channel || rdev->ops->start_ap || + rdev->ops->join_mesh) { + i++; + if (nla_put_u32(msg, i, NL80211_CMD_SET_CHANNEL)) + goto nla_put_failure; + } + if (rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_TDLS) { + CMD(tdls_mgmt, TDLS_MGMT); + CMD(tdls_oper, TDLS_OPER); + } + if (rdev->wiphy.max_sched_scan_reqs) + CMD(sched_scan_start, START_SCHED_SCAN); + CMD(probe_client, PROBE_CLIENT); + CMD(set_noack_map, SET_NOACK_MAP); + if (rdev->wiphy.flags & WIPHY_FLAG_REPORTS_OBSS) { + i++; + if (nla_put_u32(msg, i, NL80211_CMD_REGISTER_BEACONS)) + goto nla_put_failure; + } + CMD(start_p2p_device, START_P2P_DEVICE); + CMD(set_mcast_rate, SET_MCAST_RATE); +#ifdef CONFIG_NL80211_TESTMODE + CMD(testmode_cmd, TESTMODE); +#endif + + if (rdev->ops->connect || rdev->ops->auth) { + i++; + if (nla_put_u32(msg, i, NL80211_CMD_CONNECT)) + goto nla_put_failure; + } + + if (rdev->ops->disconnect || rdev->ops->deauth) { + i++; + if (nla_put_u32(msg, i, NL80211_CMD_DISCONNECT)) + goto nla_put_failure; + } + + return i; + nla_put_failure: + return -ENOBUFS; +} + +static int +nl80211_send_pmsr_ftm_capa(const struct cfg80211_pmsr_capabilities *cap, + struct sk_buff *msg) +{ + struct nlattr *ftm; + + if (!cap->ftm.supported) + return 0; + + ftm = nla_nest_start_noflag(msg, NL80211_PMSR_TYPE_FTM); + if (!ftm) + return -ENOBUFS; + + if (cap->ftm.asap && nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_ASAP)) + return -ENOBUFS; + if (cap->ftm.non_asap && + nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_NON_ASAP)) + return -ENOBUFS; + if (cap->ftm.request_lci && + nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_REQ_LCI)) + return -ENOBUFS; + if (cap->ftm.request_civicloc && + nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_REQ_CIVICLOC)) + return -ENOBUFS; + if (nla_put_u32(msg, NL80211_PMSR_FTM_CAPA_ATTR_PREAMBLES, + cap->ftm.preambles)) + return -ENOBUFS; + if (nla_put_u32(msg, NL80211_PMSR_FTM_CAPA_ATTR_BANDWIDTHS, + cap->ftm.bandwidths)) + return -ENOBUFS; + if (cap->ftm.max_bursts_exponent >= 0 && + nla_put_u32(msg, NL80211_PMSR_FTM_CAPA_ATTR_MAX_BURSTS_EXPONENT, + cap->ftm.max_bursts_exponent)) + return -ENOBUFS; + if (cap->ftm.max_ftms_per_burst && + nla_put_u32(msg, NL80211_PMSR_FTM_CAPA_ATTR_MAX_FTMS_PER_BURST, + cap->ftm.max_ftms_per_burst)) + return -ENOBUFS; + if (cap->ftm.trigger_based && + nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_TRIGGER_BASED)) + return -ENOBUFS; + if (cap->ftm.non_trigger_based && + nla_put_flag(msg, NL80211_PMSR_FTM_CAPA_ATTR_NON_TRIGGER_BASED)) + return -ENOBUFS; + + nla_nest_end(msg, ftm); + return 0; +} + +static int nl80211_send_pmsr_capa(struct cfg80211_registered_device *rdev, + struct 
sk_buff *msg) +{ + const struct cfg80211_pmsr_capabilities *cap = rdev->wiphy.pmsr_capa; + struct nlattr *pmsr, *caps; + + if (!cap) + return 0; + + /* + * we don't need to clean up anything here since the caller + * will genlmsg_cancel() if we fail + */ + + pmsr = nla_nest_start_noflag(msg, NL80211_ATTR_PEER_MEASUREMENTS); + if (!pmsr) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_PMSR_ATTR_MAX_PEERS, cap->max_peers)) + return -ENOBUFS; + + if (cap->report_ap_tsf && + nla_put_flag(msg, NL80211_PMSR_ATTR_REPORT_AP_TSF)) + return -ENOBUFS; + + if (cap->randomize_mac_addr && + nla_put_flag(msg, NL80211_PMSR_ATTR_RANDOMIZE_MAC_ADDR)) + return -ENOBUFS; + + caps = nla_nest_start_noflag(msg, NL80211_PMSR_ATTR_TYPE_CAPA); + if (!caps) + return -ENOBUFS; + + if (nl80211_send_pmsr_ftm_capa(cap, msg)) + return -ENOBUFS; + + nla_nest_end(msg, caps); + nla_nest_end(msg, pmsr); + + return 0; +} + +static int +nl80211_put_iftype_akm_suites(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + int i; + struct nlattr *nested, *nested_akms; + const struct wiphy_iftype_akm_suites *iftype_akms; + + if (!rdev->wiphy.num_iftype_akm_suites || + !rdev->wiphy.iftype_akm_suites) + return 0; + + nested = nla_nest_start(msg, NL80211_ATTR_IFTYPE_AKM_SUITES); + if (!nested) + return -ENOBUFS; + + for (i = 0; i < rdev->wiphy.num_iftype_akm_suites; i++) { + nested_akms = nla_nest_start(msg, i + 1); + if (!nested_akms) + return -ENOBUFS; + + iftype_akms = &rdev->wiphy.iftype_akm_suites[i]; + + if (nl80211_put_iftypes(msg, NL80211_IFTYPE_AKM_ATTR_IFTYPES, + iftype_akms->iftypes_mask)) + return -ENOBUFS; + + if (nla_put(msg, NL80211_IFTYPE_AKM_ATTR_SUITES, + sizeof(u32) * iftype_akms->n_akm_suites, + iftype_akms->akm_suites)) { + return -ENOBUFS; + } + nla_nest_end(msg, nested_akms); + } + + nla_nest_end(msg, nested); + + return 0; +} + +static int +nl80211_put_tid_config_support(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + struct nlattr *supp; + + if (!rdev->wiphy.tid_config_support.vif && + !rdev->wiphy.tid_config_support.peer) + return 0; + + supp = nla_nest_start(msg, NL80211_ATTR_TID_CONFIG); + if (!supp) + return -ENOSPC; + + if (rdev->wiphy.tid_config_support.vif && + nla_put_u64_64bit(msg, NL80211_TID_CONFIG_ATTR_VIF_SUPP, + rdev->wiphy.tid_config_support.vif, + NL80211_TID_CONFIG_ATTR_PAD)) + goto fail; + + if (rdev->wiphy.tid_config_support.peer && + nla_put_u64_64bit(msg, NL80211_TID_CONFIG_ATTR_PEER_SUPP, + rdev->wiphy.tid_config_support.peer, + NL80211_TID_CONFIG_ATTR_PAD)) + goto fail; + + /* for now we just use the same value ... 
makes more sense */ + if (nla_put_u8(msg, NL80211_TID_CONFIG_ATTR_RETRY_SHORT, + rdev->wiphy.tid_config_support.max_retry)) + goto fail; + if (nla_put_u8(msg, NL80211_TID_CONFIG_ATTR_RETRY_LONG, + rdev->wiphy.tid_config_support.max_retry)) + goto fail; + + nla_nest_end(msg, supp); + + return 0; +fail: + nla_nest_cancel(msg, supp); + return -ENOBUFS; +} + +static int +nl80211_put_sar_specs(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + struct nlattr *sar_capa, *specs, *sub_freq_range; + u8 num_freq_ranges; + int i; + + if (!rdev->wiphy.sar_capa) + return 0; + + num_freq_ranges = rdev->wiphy.sar_capa->num_freq_ranges; + + sar_capa = nla_nest_start(msg, NL80211_ATTR_SAR_SPEC); + if (!sar_capa) + return -ENOSPC; + + if (nla_put_u32(msg, NL80211_SAR_ATTR_TYPE, rdev->wiphy.sar_capa->type)) + goto fail; + + specs = nla_nest_start(msg, NL80211_SAR_ATTR_SPECS); + if (!specs) + goto fail; + + /* report supported freq_ranges */ + for (i = 0; i < num_freq_ranges; i++) { + sub_freq_range = nla_nest_start(msg, i + 1); + if (!sub_freq_range) + goto fail; + + if (nla_put_u32(msg, NL80211_SAR_ATTR_SPECS_START_FREQ, + rdev->wiphy.sar_capa->freq_ranges[i].start_freq)) + goto fail; + + if (nla_put_u32(msg, NL80211_SAR_ATTR_SPECS_END_FREQ, + rdev->wiphy.sar_capa->freq_ranges[i].end_freq)) + goto fail; + + nla_nest_end(msg, sub_freq_range); + } + + nla_nest_end(msg, specs); + nla_nest_end(msg, sar_capa); + + return 0; +fail: + nla_nest_cancel(msg, sar_capa); + return -ENOBUFS; +} + +static int nl80211_put_mbssid_support(struct wiphy *wiphy, struct sk_buff *msg) +{ + struct nlattr *config; + + if (!wiphy->mbssid_max_interfaces) + return 0; + + config = nla_nest_start(msg, NL80211_ATTR_MBSSID_CONFIG); + if (!config) + return -ENOBUFS; + + if (nla_put_u8(msg, NL80211_MBSSID_CONFIG_ATTR_MAX_INTERFACES, + wiphy->mbssid_max_interfaces)) + goto fail; + + if (wiphy->ema_max_profile_periodicity && + nla_put_u8(msg, + NL80211_MBSSID_CONFIG_ATTR_MAX_EMA_PROFILE_PERIODICITY, + wiphy->ema_max_profile_periodicity)) + goto fail; + + nla_nest_end(msg, config); + return 0; + +fail: + nla_nest_cancel(msg, config); + return -ENOBUFS; +} + +struct nl80211_dump_wiphy_state { + s64 filter_wiphy; + long start; + long split_start, band_start, chan_start, capa_start; + bool split; +}; + +static int nl80211_send_wiphy(struct cfg80211_registered_device *rdev, + enum nl80211_commands cmd, + struct sk_buff *msg, u32 portid, u32 seq, + int flags, struct nl80211_dump_wiphy_state *state) +{ + void *hdr; + struct nlattr *nl_bands, *nl_band; + struct nlattr *nl_freqs, *nl_freq; + struct nlattr *nl_cmds; + enum nl80211_band band; + struct ieee80211_channel *chan; + int i; + const struct ieee80211_txrx_stypes *mgmt_stypes = + rdev->wiphy.mgmt_stypes; + u32 features; + + hdr = nl80211hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -ENOBUFS; + + if (WARN_ON(!state)) + return -EINVAL; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_string(msg, NL80211_ATTR_WIPHY_NAME, + wiphy_name(&rdev->wiphy)) || + nla_put_u32(msg, NL80211_ATTR_GENERATION, + cfg80211_rdev_list_generation)) + goto nla_put_failure; + + if (cmd != NL80211_CMD_NEW_WIPHY) + goto finish; + + switch (state->split_start) { + case 0: + if (nla_put_u8(msg, NL80211_ATTR_WIPHY_RETRY_SHORT, + rdev->wiphy.retry_short) || + nla_put_u8(msg, NL80211_ATTR_WIPHY_RETRY_LONG, + rdev->wiphy.retry_long) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_FRAG_THRESHOLD, + rdev->wiphy.frag_threshold) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_RTS_THRESHOLD, 
+ rdev->wiphy.rts_threshold) || + nla_put_u8(msg, NL80211_ATTR_WIPHY_COVERAGE_CLASS, + rdev->wiphy.coverage_class) || + nla_put_u8(msg, NL80211_ATTR_MAX_NUM_SCAN_SSIDS, + rdev->wiphy.max_scan_ssids) || + nla_put_u8(msg, NL80211_ATTR_MAX_NUM_SCHED_SCAN_SSIDS, + rdev->wiphy.max_sched_scan_ssids) || + nla_put_u16(msg, NL80211_ATTR_MAX_SCAN_IE_LEN, + rdev->wiphy.max_scan_ie_len) || + nla_put_u16(msg, NL80211_ATTR_MAX_SCHED_SCAN_IE_LEN, + rdev->wiphy.max_sched_scan_ie_len) || + nla_put_u8(msg, NL80211_ATTR_MAX_MATCH_SETS, + rdev->wiphy.max_match_sets)) + goto nla_put_failure; + + if ((rdev->wiphy.flags & WIPHY_FLAG_IBSS_RSN) && + nla_put_flag(msg, NL80211_ATTR_SUPPORT_IBSS_RSN)) + goto nla_put_failure; + if ((rdev->wiphy.flags & WIPHY_FLAG_MESH_AUTH) && + nla_put_flag(msg, NL80211_ATTR_SUPPORT_MESH_AUTH)) + goto nla_put_failure; + if ((rdev->wiphy.flags & WIPHY_FLAG_AP_UAPSD) && + nla_put_flag(msg, NL80211_ATTR_SUPPORT_AP_UAPSD)) + goto nla_put_failure; + if ((rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_FW_ROAM) && + nla_put_flag(msg, NL80211_ATTR_ROAM_SUPPORT)) + goto nla_put_failure; + if ((rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_TDLS) && + nla_put_flag(msg, NL80211_ATTR_TDLS_SUPPORT)) + goto nla_put_failure; + if ((rdev->wiphy.flags & WIPHY_FLAG_TDLS_EXTERNAL_SETUP) && + nla_put_flag(msg, NL80211_ATTR_TDLS_EXTERNAL_SETUP)) + goto nla_put_failure; + state->split_start++; + if (state->split) + break; + fallthrough; + case 1: + if (nla_put(msg, NL80211_ATTR_CIPHER_SUITES, + sizeof(u32) * rdev->wiphy.n_cipher_suites, + rdev->wiphy.cipher_suites)) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_ATTR_MAX_NUM_PMKIDS, + rdev->wiphy.max_num_pmkids)) + goto nla_put_failure; + + if ((rdev->wiphy.flags & WIPHY_FLAG_CONTROL_PORT_PROTOCOL) && + nla_put_flag(msg, NL80211_ATTR_CONTROL_PORT_ETHERTYPE)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY_ANTENNA_AVAIL_TX, + rdev->wiphy.available_antennas_tx) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_ANTENNA_AVAIL_RX, + rdev->wiphy.available_antennas_rx)) + goto nla_put_failure; + + if ((rdev->wiphy.flags & WIPHY_FLAG_AP_PROBE_RESP_OFFLOAD) && + nla_put_u32(msg, NL80211_ATTR_PROBE_RESP_OFFLOAD, + rdev->wiphy.probe_resp_offload)) + goto nla_put_failure; + + if ((rdev->wiphy.available_antennas_tx || + rdev->wiphy.available_antennas_rx) && + rdev->ops->get_antenna) { + u32 tx_ant = 0, rx_ant = 0; + int res; + + res = rdev_get_antenna(rdev, &tx_ant, &rx_ant); + if (!res) { + if (nla_put_u32(msg, + NL80211_ATTR_WIPHY_ANTENNA_TX, + tx_ant) || + nla_put_u32(msg, + NL80211_ATTR_WIPHY_ANTENNA_RX, + rx_ant)) + goto nla_put_failure; + } + } + + state->split_start++; + if (state->split) + break; + fallthrough; + case 2: + if (nl80211_put_iftypes(msg, NL80211_ATTR_SUPPORTED_IFTYPES, + rdev->wiphy.interface_modes)) + goto nla_put_failure; + state->split_start++; + if (state->split) + break; + fallthrough; + case 3: + nl_bands = nla_nest_start_noflag(msg, + NL80211_ATTR_WIPHY_BANDS); + if (!nl_bands) + goto nla_put_failure; + + for (band = state->band_start; + band < (state->split ? 
+ NUM_NL80211_BANDS : + NL80211_BAND_60GHZ + 1); + band++) { + struct ieee80211_supported_band *sband; + + /* omit higher bands for ancient software */ + if (band > NL80211_BAND_5GHZ && !state->split) + break; + + sband = rdev->wiphy.bands[band]; + + if (!sband) + continue; + + nl_band = nla_nest_start_noflag(msg, band); + if (!nl_band) + goto nla_put_failure; + + switch (state->chan_start) { + case 0: + if (nl80211_send_band_rateinfo(msg, sband, + state->split)) + goto nla_put_failure; + state->chan_start++; + if (state->split) + break; + fallthrough; + default: + /* add frequencies */ + nl_freqs = nla_nest_start_noflag(msg, + NL80211_BAND_ATTR_FREQS); + if (!nl_freqs) + goto nla_put_failure; + + for (i = state->chan_start - 1; + i < sband->n_channels; + i++) { + nl_freq = nla_nest_start_noflag(msg, + i); + if (!nl_freq) + goto nla_put_failure; + + chan = &sband->channels[i]; + + if (nl80211_msg_put_channel( + msg, &rdev->wiphy, chan, + state->split)) + goto nla_put_failure; + + nla_nest_end(msg, nl_freq); + if (state->split) + break; + } + if (i < sband->n_channels) + state->chan_start = i + 2; + else + state->chan_start = 0; + nla_nest_end(msg, nl_freqs); + } + + nla_nest_end(msg, nl_band); + + if (state->split) { + /* start again here */ + if (state->chan_start) + band--; + break; + } + } + nla_nest_end(msg, nl_bands); + + if (band < NUM_NL80211_BANDS) + state->band_start = band + 1; + else + state->band_start = 0; + + /* if bands & channels are done, continue outside */ + if (state->band_start == 0 && state->chan_start == 0) + state->split_start++; + if (state->split) + break; + fallthrough; + case 4: + nl_cmds = nla_nest_start_noflag(msg, + NL80211_ATTR_SUPPORTED_COMMANDS); + if (!nl_cmds) + goto nla_put_failure; + + i = nl80211_add_commands_unsplit(rdev, msg); + if (i < 0) + goto nla_put_failure; + if (state->split) { + CMD(crit_proto_start, CRIT_PROTOCOL_START); + CMD(crit_proto_stop, CRIT_PROTOCOL_STOP); + if (rdev->wiphy.flags & WIPHY_FLAG_HAS_CHANNEL_SWITCH) + CMD(channel_switch, CHANNEL_SWITCH); + CMD(set_qos_map, SET_QOS_MAP); + if (rdev->wiphy.features & + NL80211_FEATURE_SUPPORTS_WMM_ADMISSION) + CMD(add_tx_ts, ADD_TX_TS); + CMD(set_multicast_to_unicast, SET_MULTICAST_TO_UNICAST); + CMD(update_connect_params, UPDATE_CONNECT_PARAMS); + CMD(update_ft_ies, UPDATE_FT_IES); + if (rdev->wiphy.sar_capa) + CMD(set_sar_specs, SET_SAR_SPECS); + } +#undef CMD + + nla_nest_end(msg, nl_cmds); + state->split_start++; + if (state->split) + break; + fallthrough; + case 5: + if (rdev->ops->remain_on_channel && + (rdev->wiphy.flags & WIPHY_FLAG_HAS_REMAIN_ON_CHANNEL) && + nla_put_u32(msg, + NL80211_ATTR_MAX_REMAIN_ON_CHANNEL_DURATION, + rdev->wiphy.max_remain_on_channel_duration)) + goto nla_put_failure; + + if ((rdev->wiphy.flags & WIPHY_FLAG_OFFCHAN_TX) && + nla_put_flag(msg, NL80211_ATTR_OFFCHANNEL_TX_OK)) + goto nla_put_failure; + + state->split_start++; + if (state->split) + break; + fallthrough; + case 6: +#ifdef CONFIG_PM + if (nl80211_send_wowlan(msg, rdev, state->split)) + goto nla_put_failure; + state->split_start++; + if (state->split) + break; +#else + state->split_start++; +#endif + fallthrough; + case 7: + if (nl80211_put_iftypes(msg, NL80211_ATTR_SOFTWARE_IFTYPES, + rdev->wiphy.software_iftypes)) + goto nla_put_failure; + + if (nl80211_put_iface_combinations(&rdev->wiphy, msg, + state->split)) + goto nla_put_failure; + + state->split_start++; + if (state->split) + break; + fallthrough; + case 8: + if ((rdev->wiphy.flags & WIPHY_FLAG_HAVE_AP_SME) && + nla_put_u32(msg, 
NL80211_ATTR_DEVICE_AP_SME, + rdev->wiphy.ap_sme_capa)) + goto nla_put_failure; + + features = rdev->wiphy.features; + /* + * We can only add the per-channel limit information if the + * dump is split, otherwise it makes it too big. Therefore + * only advertise it in that case. + */ + if (state->split) + features |= NL80211_FEATURE_ADVERTISE_CHAN_LIMITS; + if (nla_put_u32(msg, NL80211_ATTR_FEATURE_FLAGS, features)) + goto nla_put_failure; + + if (rdev->wiphy.ht_capa_mod_mask && + nla_put(msg, NL80211_ATTR_HT_CAPABILITY_MASK, + sizeof(*rdev->wiphy.ht_capa_mod_mask), + rdev->wiphy.ht_capa_mod_mask)) + goto nla_put_failure; + + if (rdev->wiphy.flags & WIPHY_FLAG_HAVE_AP_SME && + rdev->wiphy.max_acl_mac_addrs && + nla_put_u32(msg, NL80211_ATTR_MAC_ACL_MAX, + rdev->wiphy.max_acl_mac_addrs)) + goto nla_put_failure; + + /* + * Any information below this point is only available to + * applications that can deal with it being split. This + * helps ensure that newly added capabilities don't break + * older tools by overrunning their buffers. + * + * We still increment split_start so that in the split + * case we'll continue with more data in the next round, + * but break unconditionally so unsplit data stops here. + */ + if (state->split) + state->split_start++; + else + state->split_start = 0; + break; + case 9: + if (nl80211_send_mgmt_stypes(msg, mgmt_stypes)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_MAX_NUM_SCHED_SCAN_PLANS, + rdev->wiphy.max_sched_scan_plans) || + nla_put_u32(msg, NL80211_ATTR_MAX_SCAN_PLAN_INTERVAL, + rdev->wiphy.max_sched_scan_plan_interval) || + nla_put_u32(msg, NL80211_ATTR_MAX_SCAN_PLAN_ITERATIONS, + rdev->wiphy.max_sched_scan_plan_iterations)) + goto nla_put_failure; + + if (rdev->wiphy.extended_capabilities && + (nla_put(msg, NL80211_ATTR_EXT_CAPA, + rdev->wiphy.extended_capabilities_len, + rdev->wiphy.extended_capabilities) || + nla_put(msg, NL80211_ATTR_EXT_CAPA_MASK, + rdev->wiphy.extended_capabilities_len, + rdev->wiphy.extended_capabilities_mask))) + goto nla_put_failure; + + if (rdev->wiphy.vht_capa_mod_mask && + nla_put(msg, NL80211_ATTR_VHT_CAPABILITY_MASK, + sizeof(*rdev->wiphy.vht_capa_mod_mask), + rdev->wiphy.vht_capa_mod_mask)) + goto nla_put_failure; + + if (nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, + rdev->wiphy.perm_addr)) + goto nla_put_failure; + + if (!is_zero_ether_addr(rdev->wiphy.addr_mask) && + nla_put(msg, NL80211_ATTR_MAC_MASK, ETH_ALEN, + rdev->wiphy.addr_mask)) + goto nla_put_failure; + + if (rdev->wiphy.n_addresses > 1) { + void *attr; + + attr = nla_nest_start(msg, NL80211_ATTR_MAC_ADDRS); + if (!attr) + goto nla_put_failure; + + for (i = 0; i < rdev->wiphy.n_addresses; i++) + if (nla_put(msg, i + 1, ETH_ALEN, + rdev->wiphy.addresses[i].addr)) + goto nla_put_failure; + + nla_nest_end(msg, attr); + } + + state->split_start++; + break; + case 10: + if (nl80211_send_coalesce(msg, rdev)) + goto nla_put_failure; + + if ((rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_5_10_MHZ) && + (nla_put_flag(msg, NL80211_ATTR_SUPPORT_5_MHZ) || + nla_put_flag(msg, NL80211_ATTR_SUPPORT_10_MHZ))) + goto nla_put_failure; + + if (rdev->wiphy.max_ap_assoc_sta && + nla_put_u32(msg, NL80211_ATTR_MAX_AP_ASSOC_STA, + rdev->wiphy.max_ap_assoc_sta)) + goto nla_put_failure; + + state->split_start++; + break; + case 11: + if (rdev->wiphy.n_vendor_commands) { + const struct nl80211_vendor_cmd_info *info; + struct nlattr *nested; + + nested = nla_nest_start_noflag(msg, + NL80211_ATTR_VENDOR_DATA); + if (!nested) + goto nla_put_failure; + + for (i = 0; i < 
rdev->wiphy.n_vendor_commands; i++) { + info = &rdev->wiphy.vendor_commands[i].info; + if (nla_put(msg, i + 1, sizeof(*info), info)) + goto nla_put_failure; + } + nla_nest_end(msg, nested); + } + + if (rdev->wiphy.n_vendor_events) { + const struct nl80211_vendor_cmd_info *info; + struct nlattr *nested; + + nested = nla_nest_start_noflag(msg, + NL80211_ATTR_VENDOR_EVENTS); + if (!nested) + goto nla_put_failure; + + for (i = 0; i < rdev->wiphy.n_vendor_events; i++) { + info = &rdev->wiphy.vendor_events[i]; + if (nla_put(msg, i + 1, sizeof(*info), info)) + goto nla_put_failure; + } + nla_nest_end(msg, nested); + } + state->split_start++; + break; + case 12: + if (rdev->wiphy.flags & WIPHY_FLAG_HAS_CHANNEL_SWITCH && + nla_put_u8(msg, NL80211_ATTR_MAX_CSA_COUNTERS, + rdev->wiphy.max_num_csa_counters)) + goto nla_put_failure; + + if (rdev->wiphy.regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED && + nla_put_flag(msg, NL80211_ATTR_WIPHY_SELF_MANAGED_REG)) + goto nla_put_failure; + + if (rdev->wiphy.max_sched_scan_reqs && + nla_put_u32(msg, NL80211_ATTR_SCHED_SCAN_MAX_REQS, + rdev->wiphy.max_sched_scan_reqs)) + goto nla_put_failure; + + if (nla_put(msg, NL80211_ATTR_EXT_FEATURES, + sizeof(rdev->wiphy.ext_features), + rdev->wiphy.ext_features)) + goto nla_put_failure; + + if (rdev->wiphy.bss_select_support) { + struct nlattr *nested; + u32 bss_select_support = rdev->wiphy.bss_select_support; + + nested = nla_nest_start_noflag(msg, + NL80211_ATTR_BSS_SELECT); + if (!nested) + goto nla_put_failure; + + i = 0; + while (bss_select_support) { + if ((bss_select_support & 1) && + nla_put_flag(msg, i)) + goto nla_put_failure; + i++; + bss_select_support >>= 1; + } + nla_nest_end(msg, nested); + } + + state->split_start++; + break; + case 13: + if (rdev->wiphy.num_iftype_ext_capab && + rdev->wiphy.iftype_ext_capab) { + struct nlattr *nested_ext_capab, *nested; + + nested = nla_nest_start_noflag(msg, + NL80211_ATTR_IFTYPE_EXT_CAPA); + if (!nested) + goto nla_put_failure; + + for (i = state->capa_start; + i < rdev->wiphy.num_iftype_ext_capab; i++) { + const struct wiphy_iftype_ext_capab *capab; + + capab = &rdev->wiphy.iftype_ext_capab[i]; + + nested_ext_capab = nla_nest_start_noflag(msg, + i); + if (!nested_ext_capab || + nla_put_u32(msg, NL80211_ATTR_IFTYPE, + capab->iftype) || + nla_put(msg, NL80211_ATTR_EXT_CAPA, + capab->extended_capabilities_len, + capab->extended_capabilities) || + nla_put(msg, NL80211_ATTR_EXT_CAPA_MASK, + capab->extended_capabilities_len, + capab->extended_capabilities_mask)) + goto nla_put_failure; + + if (rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_MLO && + (nla_put_u16(msg, + NL80211_ATTR_EML_CAPABILITY, + capab->eml_capabilities) || + nla_put_u16(msg, + NL80211_ATTR_MLD_CAPA_AND_OPS, + capab->mld_capa_and_ops))) + goto nla_put_failure; + + nla_nest_end(msg, nested_ext_capab); + if (state->split) + break; + } + nla_nest_end(msg, nested); + if (i < rdev->wiphy.num_iftype_ext_capab) { + state->capa_start = i + 1; + break; + } + } + + if (nla_put_u32(msg, NL80211_ATTR_BANDS, + rdev->wiphy.nan_supported_bands)) + goto nla_put_failure; + + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_TXQS)) { + struct cfg80211_txq_stats txqstats = {}; + int res; + + res = rdev_get_txq_stats(rdev, NULL, &txqstats); + if (!res && + !nl80211_put_txq_stats(msg, &txqstats, + NL80211_ATTR_TXQ_STATS)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_TXQ_LIMIT, + rdev->wiphy.txq_limit)) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_ATTR_TXQ_MEMORY_LIMIT, + 
rdev->wiphy.txq_memory_limit)) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_ATTR_TXQ_QUANTUM, + rdev->wiphy.txq_quantum)) + goto nla_put_failure; + } + + state->split_start++; + break; + case 14: + if (nl80211_send_pmsr_capa(rdev, msg)) + goto nla_put_failure; + + state->split_start++; + break; + case 15: + if (rdev->wiphy.akm_suites && + nla_put(msg, NL80211_ATTR_AKM_SUITES, + sizeof(u32) * rdev->wiphy.n_akm_suites, + rdev->wiphy.akm_suites)) + goto nla_put_failure; + + if (nl80211_put_iftype_akm_suites(rdev, msg)) + goto nla_put_failure; + + if (nl80211_put_tid_config_support(rdev, msg)) + goto nla_put_failure; + state->split_start++; + break; + case 16: + if (nl80211_put_sar_specs(rdev, msg)) + goto nla_put_failure; + + if (nl80211_put_mbssid_support(&rdev->wiphy, msg)) + goto nla_put_failure; + + if (nla_put_u16(msg, NL80211_ATTR_MAX_NUM_AKM_SUITES, + rdev->wiphy.max_num_akm_suites)) + goto nla_put_failure; + + if (rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_MLO) + nla_put_flag(msg, NL80211_ATTR_MLO_SUPPORT); + + /* done */ + state->split_start = 0; + break; + } + finish: + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_wiphy_parse(struct sk_buff *skb, + struct netlink_callback *cb, + struct nl80211_dump_wiphy_state *state) +{ + struct nlattr **tb = kcalloc(NUM_NL80211_ATTR, sizeof(*tb), GFP_KERNEL); + int ret; + + if (!tb) + return -ENOMEM; + + ret = nlmsg_parse_deprecated(cb->nlh, + GENL_HDRLEN + nl80211_fam.hdrsize, + tb, nl80211_fam.maxattr, + nl80211_policy, NULL); + /* ignore parse errors for backward compatibility */ + if (ret) { + ret = 0; + goto out; + } + + state->split = tb[NL80211_ATTR_SPLIT_WIPHY_DUMP]; + if (tb[NL80211_ATTR_WIPHY]) + state->filter_wiphy = nla_get_u32(tb[NL80211_ATTR_WIPHY]); + if (tb[NL80211_ATTR_WDEV]) + state->filter_wiphy = nla_get_u64(tb[NL80211_ATTR_WDEV]) >> 32; + if (tb[NL80211_ATTR_IFINDEX]) { + struct net_device *netdev; + struct cfg80211_registered_device *rdev; + int ifidx = nla_get_u32(tb[NL80211_ATTR_IFINDEX]); + + netdev = __dev_get_by_index(sock_net(skb->sk), ifidx); + if (!netdev) { + ret = -ENODEV; + goto out; + } + if (netdev->ieee80211_ptr) { + rdev = wiphy_to_rdev( + netdev->ieee80211_ptr->wiphy); + state->filter_wiphy = rdev->wiphy_idx; + } + } + + ret = 0; +out: + kfree(tb); + return ret; +} + +static int nl80211_dump_wiphy(struct sk_buff *skb, struct netlink_callback *cb) +{ + int idx = 0, ret; + struct nl80211_dump_wiphy_state *state = (void *)cb->args[0]; + struct cfg80211_registered_device *rdev; + + rtnl_lock(); + if (!state) { + state = kzalloc(sizeof(*state), GFP_KERNEL); + if (!state) { + rtnl_unlock(); + return -ENOMEM; + } + state->filter_wiphy = -1; + ret = nl80211_dump_wiphy_parse(skb, cb, state); + if (ret) { + kfree(state); + rtnl_unlock(); + return ret; + } + cb->args[0] = (long)state; + } + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (!net_eq(wiphy_net(&rdev->wiphy), sock_net(skb->sk))) + continue; + if (++idx <= state->start) + continue; + if (state->filter_wiphy != -1 && + state->filter_wiphy != rdev->wiphy_idx) + continue; + /* attempt to fit multiple wiphy data chunks into the skb */ + do { + ret = nl80211_send_wiphy(rdev, NL80211_CMD_NEW_WIPHY, + skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, + NLM_F_MULTI, state); + if (ret < 0) { + /* + * If sending the wiphy data didn't fit (ENOBUFS + * or EMSGSIZE returned), this SKB is still + * empty (so it's not too big because another + * wiphy dataset 
is already in the skb) and + * we've not tried to adjust the dump allocation + * yet ... then adjust the alloc size to be + * bigger, and return 1 but with the empty skb. + * This results in an empty message being RX'ed + * in userspace, but that is ignored. + * + * We can then retry with the larger buffer. + */ + if ((ret == -ENOBUFS || ret == -EMSGSIZE) && + !skb->len && !state->split && + cb->min_dump_alloc < 4096) { + cb->min_dump_alloc = 4096; + state->split_start = 0; + rtnl_unlock(); + return 1; + } + idx--; + break; + } + } while (state->split_start > 0); + break; + } + rtnl_unlock(); + + state->start = idx; + + return skb->len; +} + +static int nl80211_dump_wiphy_done(struct netlink_callback *cb) +{ + kfree((void *)cb->args[0]); + return 0; +} + +static int nl80211_get_wiphy(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct nl80211_dump_wiphy_state state = {}; + + msg = nlmsg_new(4096, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl80211_send_wiphy(rdev, NL80211_CMD_NEW_WIPHY, msg, + info->snd_portid, info->snd_seq, 0, + &state) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static const struct nla_policy txq_params_policy[NL80211_TXQ_ATTR_MAX + 1] = { + [NL80211_TXQ_ATTR_QUEUE] = { .type = NLA_U8 }, + [NL80211_TXQ_ATTR_TXOP] = { .type = NLA_U16 }, + [NL80211_TXQ_ATTR_CWMIN] = { .type = NLA_U16 }, + [NL80211_TXQ_ATTR_CWMAX] = { .type = NLA_U16 }, + [NL80211_TXQ_ATTR_AIFS] = { .type = NLA_U8 }, +}; + +static int parse_txq_params(struct nlattr *tb[], + struct ieee80211_txq_params *txq_params) +{ + u8 ac; + + if (!tb[NL80211_TXQ_ATTR_AC] || !tb[NL80211_TXQ_ATTR_TXOP] || + !tb[NL80211_TXQ_ATTR_CWMIN] || !tb[NL80211_TXQ_ATTR_CWMAX] || + !tb[NL80211_TXQ_ATTR_AIFS]) + return -EINVAL; + + ac = nla_get_u8(tb[NL80211_TXQ_ATTR_AC]); + txq_params->txop = nla_get_u16(tb[NL80211_TXQ_ATTR_TXOP]); + txq_params->cwmin = nla_get_u16(tb[NL80211_TXQ_ATTR_CWMIN]); + txq_params->cwmax = nla_get_u16(tb[NL80211_TXQ_ATTR_CWMAX]); + txq_params->aifs = nla_get_u8(tb[NL80211_TXQ_ATTR_AIFS]); + + if (ac >= NL80211_NUM_ACS) + return -EINVAL; + txq_params->ac = array_index_nospec(ac, NL80211_NUM_ACS); + return 0; +} + +static bool nl80211_can_set_dev_channel(struct wireless_dev *wdev) +{ + /* + * You can only set the channel explicitly for some interfaces, + * most have their channel managed via their respective + * "establish a connection" command (connect, join, ...) + * + * For AP/GO and mesh mode, the channel can be set with the + * channel userspace API, but is only stored and passed to the + * low-level driver when the AP starts or the mesh is joined. + * This is for backward compatibility, userspace can also give + * the channel in the start-ap or join-mesh commands instead. + * + * Monitors are special as they are normally slaved to + * whatever else is going on, so they have their own special + * operation to set the monitor channel if possible. 
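
For the monitor case described above, a minimal libnl-3 sketch of the corresponding userspace request might look like the following; the helper name is invented, the genl socket is assumed to be connected with nl80211_id resolved via genl_ctrl_resolve(), and ack handling is left out. The kernel side then runs the request through nl80211_parse_chandef() and cfg80211_set_monitor_channel() as in __nl80211_set_channel() below.

    #include <stdint.h>
    #include <net/if.h>
    #include <netlink/netlink.h>
    #include <netlink/genl/genl.h>
    #include <linux/nl80211.h>

    /* ask for a plain 20 MHz non-HT channel on a monitor interface */
    static int set_monitor_freq(struct nl_sock *sk, int nl80211_id,
                                const char *ifname, uint32_t freq_mhz)
    {
        struct nl_msg *msg = nlmsg_alloc();
        int err;

        if (!msg)
            return -1;

        genlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, nl80211_id, 0, 0,
                    NL80211_CMD_SET_CHANNEL, 0);
        nla_put_u32(msg, NL80211_ATTR_IFINDEX, if_nametoindex(ifname));
        nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ, freq_mhz);
        nla_put_u32(msg, NL80211_ATTR_CHANNEL_WIDTH, NL80211_CHAN_WIDTH_20_NOHT);

        err = nl_send_auto(sk, msg);
        nlmsg_free(msg);
        return err < 0 ? -1 : 0;
    }
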
+ */ + return !wdev || + wdev->iftype == NL80211_IFTYPE_AP || + wdev->iftype == NL80211_IFTYPE_MESH_POINT || + wdev->iftype == NL80211_IFTYPE_MONITOR || + wdev->iftype == NL80211_IFTYPE_P2P_GO; +} + +int nl80211_parse_chandef(struct cfg80211_registered_device *rdev, + struct genl_info *info, + struct cfg80211_chan_def *chandef) +{ + struct netlink_ext_ack *extack = info->extack; + struct nlattr **attrs = info->attrs; + u32 control_freq; + + if (!attrs[NL80211_ATTR_WIPHY_FREQ]) + return -EINVAL; + + control_freq = MHZ_TO_KHZ( + nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ])); + if (info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]) + control_freq += + nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]); + + memset(chandef, 0, sizeof(*chandef)); + chandef->chan = ieee80211_get_channel_khz(&rdev->wiphy, control_freq); + chandef->width = NL80211_CHAN_WIDTH_20_NOHT; + chandef->center_freq1 = KHZ_TO_MHZ(control_freq); + chandef->freq1_offset = control_freq % 1000; + chandef->center_freq2 = 0; + + /* Primary channel not allowed */ + if (!chandef->chan || chandef->chan->flags & IEEE80211_CHAN_DISABLED) { + NL_SET_ERR_MSG_ATTR(extack, attrs[NL80211_ATTR_WIPHY_FREQ], + "Channel is disabled"); + return -EINVAL; + } + + if (attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]) { + enum nl80211_channel_type chantype; + + chantype = nla_get_u32(attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE]); + + switch (chantype) { + case NL80211_CHAN_NO_HT: + case NL80211_CHAN_HT20: + case NL80211_CHAN_HT40PLUS: + case NL80211_CHAN_HT40MINUS: + cfg80211_chandef_create(chandef, chandef->chan, + chantype); + /* user input for center_freq is incorrect */ + if (attrs[NL80211_ATTR_CENTER_FREQ1] && + chandef->center_freq1 != nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ1])) { + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_CENTER_FREQ1], + "bad center frequency 1"); + return -EINVAL; + } + /* center_freq2 must be zero */ + if (attrs[NL80211_ATTR_CENTER_FREQ2] && + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ2])) { + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_CENTER_FREQ2], + "center frequency 2 can't be used"); + return -EINVAL; + } + break; + default: + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_WIPHY_CHANNEL_TYPE], + "invalid channel type"); + return -EINVAL; + } + } else if (attrs[NL80211_ATTR_CHANNEL_WIDTH]) { + chandef->width = + nla_get_u32(attrs[NL80211_ATTR_CHANNEL_WIDTH]); + if (chandef->chan->band == NL80211_BAND_S1GHZ) { + /* User input error for channel width doesn't match channel */ + if (chandef->width != ieee80211_s1g_channel_width(chandef->chan)) { + NL_SET_ERR_MSG_ATTR(extack, + attrs[NL80211_ATTR_CHANNEL_WIDTH], + "bad channel width"); + return -EINVAL; + } + } + if (attrs[NL80211_ATTR_CENTER_FREQ1]) { + chandef->center_freq1 = + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ1]); + if (attrs[NL80211_ATTR_CENTER_FREQ1_OFFSET]) + chandef->freq1_offset = nla_get_u32( + attrs[NL80211_ATTR_CENTER_FREQ1_OFFSET]); + else + chandef->freq1_offset = 0; + } + if (attrs[NL80211_ATTR_CENTER_FREQ2]) + chandef->center_freq2 = + nla_get_u32(attrs[NL80211_ATTR_CENTER_FREQ2]); + } + + if (info->attrs[NL80211_ATTR_WIPHY_EDMG_CHANNELS]) { + chandef->edmg.channels = + nla_get_u8(info->attrs[NL80211_ATTR_WIPHY_EDMG_CHANNELS]); + + if (info->attrs[NL80211_ATTR_WIPHY_EDMG_BW_CONFIG]) + chandef->edmg.bw_config = + nla_get_u8(info->attrs[NL80211_ATTR_WIPHY_EDMG_BW_CONFIG]); + } else { + chandef->edmg.bw_config = 0; + chandef->edmg.channels = 0; + } + + if (!cfg80211_chandef_valid(chandef)) { + NL_SET_ERR_MSG(extack, "invalid channel 
definition"); + return -EINVAL; + } + + if (!cfg80211_chandef_usable(&rdev->wiphy, chandef, + IEEE80211_CHAN_DISABLED)) { + NL_SET_ERR_MSG(extack, "(extension) channel is disabled"); + return -EINVAL; + } + + if ((chandef->width == NL80211_CHAN_WIDTH_5 || + chandef->width == NL80211_CHAN_WIDTH_10) && + !(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_5_10_MHZ)) { + NL_SET_ERR_MSG(extack, "5/10 MHz not supported"); + return -EINVAL; + } + + return 0; +} + +static int __nl80211_set_channel(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct genl_info *info, + int _link_id) +{ + struct cfg80211_chan_def chandef; + int result; + enum nl80211_iftype iftype = NL80211_IFTYPE_MONITOR; + struct wireless_dev *wdev = NULL; + int link_id = _link_id; + + if (dev) + wdev = dev->ieee80211_ptr; + if (!nl80211_can_set_dev_channel(wdev)) + return -EOPNOTSUPP; + if (wdev) + iftype = wdev->iftype; + + if (link_id < 0) { + if (wdev && wdev->valid_links) + return -EINVAL; + link_id = 0; + } + + result = nl80211_parse_chandef(rdev, info, &chandef); + if (result) + return result; + + switch (iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + if (!cfg80211_reg_can_beacon_relax(&rdev->wiphy, &chandef, + iftype)) + return -EINVAL; + if (wdev->links[link_id].ap.beacon_interval) { + struct ieee80211_channel *cur_chan; + + if (!dev || !rdev->ops->set_ap_chanwidth || + !(rdev->wiphy.features & + NL80211_FEATURE_AP_MODE_CHAN_WIDTH_CHANGE)) + return -EBUSY; + + /* Only allow dynamic channel width changes */ + cur_chan = wdev->links[link_id].ap.chandef.chan; + if (chandef.chan != cur_chan) + return -EBUSY; + + result = rdev_set_ap_chanwidth(rdev, dev, link_id, + &chandef); + if (result) + return result; + wdev->links[link_id].ap.chandef = chandef; + } else { + wdev->u.ap.preset_chandef = chandef; + } + return 0; + case NL80211_IFTYPE_MESH_POINT: + return cfg80211_set_mesh_channel(rdev, wdev, &chandef); + case NL80211_IFTYPE_MONITOR: + return cfg80211_set_monitor_channel(rdev, &chandef); + default: + break; + } + + return -EINVAL; +} + +static int nl80211_set_channel(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int link_id = nl80211_link_id_or_invalid(info->attrs); + struct net_device *netdev = info->user_ptr[1]; + int ret; + + wdev_lock(netdev->ieee80211_ptr); + ret = __nl80211_set_channel(rdev, netdev, info, link_id); + wdev_unlock(netdev->ieee80211_ptr); + + return ret; +} + +static int nl80211_set_wiphy(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = NULL; + struct net_device *netdev = NULL; + struct wireless_dev *wdev; + int result = 0, rem_txq_params = 0; + struct nlattr *nl_txq_params; + u32 changed; + u8 retry_short = 0, retry_long = 0; + u32 frag_threshold = 0, rts_threshold = 0; + u8 coverage_class = 0; + u32 txq_limit = 0, txq_memory_limit = 0, txq_quantum = 0; + + rtnl_lock(); + /* + * Try to find the wiphy and netdev. Normally this + * function shouldn't need the netdev, but this is + * done for backward compatibility -- previously + * setting the channel was done per wiphy, but now + * it is per netdev. Previous userland like hostapd + * also passed a netdev to set_wiphy, so that it is + * possible to let that go to the right netdev! 
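
Because of this compatibility path, userspace can keep addressing NL80211_CMD_SET_WIPHY by interface index and the kernel resolves the owning wiphy from it. A small hypothetical libnl-3 helper along those lines, under the same socket and family-id assumptions as the earlier sketches, with ack handling omitted:

    #include <stdint.h>
    #include <net/if.h>
    #include <netlink/netlink.h>
    #include <netlink/genl/genl.h>
    #include <linux/nl80211.h>

    /* set the RTS threshold, addressing the wiphy via one of its netdevs */
    static int set_rts_threshold(struct nl_sock *sk, int nl80211_id,
                                 const char *ifname, uint32_t rts)
    {
        struct nl_msg *msg = nlmsg_alloc();
        int err;

        if (!msg)
            return -1;

        genlmsg_put(msg, NL_AUTO_PORT, NL_AUTO_SEQ, nl80211_id, 0, 0,
                    NL80211_CMD_SET_WIPHY, 0);
        nla_put_u32(msg, NL80211_ATTR_IFINDEX, if_nametoindex(ifname));
        nla_put_u32(msg, NL80211_ATTR_WIPHY_RTS_THRESHOLD, rts);

        err = nl_send_auto(sk, msg);
        nlmsg_free(msg);
        return err < 0 ? -1 : 0;
    }
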
+ */ + + if (info->attrs[NL80211_ATTR_IFINDEX]) { + int ifindex = nla_get_u32(info->attrs[NL80211_ATTR_IFINDEX]); + + netdev = __dev_get_by_index(genl_info_net(info), ifindex); + if (netdev && netdev->ieee80211_ptr) + rdev = wiphy_to_rdev(netdev->ieee80211_ptr->wiphy); + else + netdev = NULL; + } + + if (!netdev) { + rdev = __cfg80211_rdev_from_attrs(genl_info_net(info), + info->attrs); + if (IS_ERR(rdev)) { + rtnl_unlock(); + return PTR_ERR(rdev); + } + wdev = NULL; + netdev = NULL; + result = 0; + } else + wdev = netdev->ieee80211_ptr; + + wiphy_lock(&rdev->wiphy); + + /* + * end workaround code, by now the rdev is available + * and locked, and wdev may or may not be NULL. + */ + + if (info->attrs[NL80211_ATTR_WIPHY_NAME]) + result = cfg80211_dev_rename( + rdev, nla_data(info->attrs[NL80211_ATTR_WIPHY_NAME])); + rtnl_unlock(); + + if (result) + goto out; + + if (info->attrs[NL80211_ATTR_WIPHY_TXQ_PARAMS]) { + struct ieee80211_txq_params txq_params; + struct nlattr *tb[NL80211_TXQ_ATTR_MAX + 1]; + + if (!rdev->ops->set_txq_params) { + result = -EOPNOTSUPP; + goto out; + } + + if (!netdev) { + result = -EINVAL; + goto out; + } + + if (netdev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + netdev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) { + result = -EINVAL; + goto out; + } + + if (!netif_running(netdev)) { + result = -ENETDOWN; + goto out; + } + + nla_for_each_nested(nl_txq_params, + info->attrs[NL80211_ATTR_WIPHY_TXQ_PARAMS], + rem_txq_params) { + result = nla_parse_nested_deprecated(tb, + NL80211_TXQ_ATTR_MAX, + nl_txq_params, + txq_params_policy, + info->extack); + if (result) + goto out; + result = parse_txq_params(tb, &txq_params); + if (result) + goto out; + + txq_params.link_id = + nl80211_link_id_or_invalid(info->attrs); + + wdev_lock(netdev->ieee80211_ptr); + if (txq_params.link_id >= 0 && + !(netdev->ieee80211_ptr->valid_links & + BIT(txq_params.link_id))) + result = -ENOLINK; + else if (txq_params.link_id >= 0 && + !netdev->ieee80211_ptr->valid_links) + result = -EINVAL; + else + result = rdev_set_txq_params(rdev, netdev, + &txq_params); + wdev_unlock(netdev->ieee80211_ptr); + if (result) + goto out; + } + } + + if (info->attrs[NL80211_ATTR_WIPHY_FREQ]) { + int link_id = nl80211_link_id_or_invalid(info->attrs); + + if (wdev) { + wdev_lock(wdev); + result = __nl80211_set_channel( + rdev, + nl80211_can_set_dev_channel(wdev) ? 
netdev : NULL, + info, link_id); + wdev_unlock(wdev); + } else { + result = __nl80211_set_channel(rdev, netdev, info, link_id); + } + + if (result) + goto out; + } + + if (info->attrs[NL80211_ATTR_WIPHY_TX_POWER_SETTING]) { + struct wireless_dev *txp_wdev = wdev; + enum nl80211_tx_power_setting type; + int idx, mbm = 0; + + if (!(rdev->wiphy.features & NL80211_FEATURE_VIF_TXPOWER)) + txp_wdev = NULL; + + if (!rdev->ops->set_tx_power) { + result = -EOPNOTSUPP; + goto out; + } + + idx = NL80211_ATTR_WIPHY_TX_POWER_SETTING; + type = nla_get_u32(info->attrs[idx]); + + if (!info->attrs[NL80211_ATTR_WIPHY_TX_POWER_LEVEL] && + (type != NL80211_TX_POWER_AUTOMATIC)) { + result = -EINVAL; + goto out; + } + + if (type != NL80211_TX_POWER_AUTOMATIC) { + idx = NL80211_ATTR_WIPHY_TX_POWER_LEVEL; + mbm = nla_get_u32(info->attrs[idx]); + } + + result = rdev_set_tx_power(rdev, txp_wdev, type, mbm); + if (result) + goto out; + } + + if (info->attrs[NL80211_ATTR_WIPHY_ANTENNA_TX] && + info->attrs[NL80211_ATTR_WIPHY_ANTENNA_RX]) { + u32 tx_ant, rx_ant; + + if ((!rdev->wiphy.available_antennas_tx && + !rdev->wiphy.available_antennas_rx) || + !rdev->ops->set_antenna) { + result = -EOPNOTSUPP; + goto out; + } + + tx_ant = nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_ANTENNA_TX]); + rx_ant = nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_ANTENNA_RX]); + + /* reject antenna configurations which don't match the + * available antenna masks, except for the "all" mask */ + if ((~tx_ant && (tx_ant & ~rdev->wiphy.available_antennas_tx)) || + (~rx_ant && (rx_ant & ~rdev->wiphy.available_antennas_rx))) { + result = -EINVAL; + goto out; + } + + tx_ant = tx_ant & rdev->wiphy.available_antennas_tx; + rx_ant = rx_ant & rdev->wiphy.available_antennas_rx; + + result = rdev_set_antenna(rdev, tx_ant, rx_ant); + if (result) + goto out; + } + + changed = 0; + + if (info->attrs[NL80211_ATTR_WIPHY_RETRY_SHORT]) { + retry_short = nla_get_u8( + info->attrs[NL80211_ATTR_WIPHY_RETRY_SHORT]); + + changed |= WIPHY_PARAM_RETRY_SHORT; + } + + if (info->attrs[NL80211_ATTR_WIPHY_RETRY_LONG]) { + retry_long = nla_get_u8( + info->attrs[NL80211_ATTR_WIPHY_RETRY_LONG]); + + changed |= WIPHY_PARAM_RETRY_LONG; + } + + if (info->attrs[NL80211_ATTR_WIPHY_FRAG_THRESHOLD]) { + frag_threshold = nla_get_u32( + info->attrs[NL80211_ATTR_WIPHY_FRAG_THRESHOLD]); + if (frag_threshold < 256) { + result = -EINVAL; + goto out; + } + + if (frag_threshold != (u32) -1) { + /* + * Fragments (apart from the last one) are required to + * have even length. Make the fragmentation code + * simpler by stripping LSB should someone try to use + * odd threshold value. 
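
Concretely, the stripping described above is the frag_threshold &= ~0x1 immediately below: an odd request such as 2347 is silently treated as 2346, while the special value (u32)-1, meaning fragmentation disabled, is left untouched. A tiny self-contained check of that behaviour:

    #include <assert.h>
    #include <stdint.h>

    int main(void)
    {
        uint32_t frag = 2347;

        if (frag != (uint32_t)-1)       /* (u32)-1 means "off", keep it */
            frag &= ~0x1u;              /* round odd thresholds down to even */

        assert(frag == 2346);
        return 0;
    }
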
+ */ + frag_threshold &= ~0x1; + } + changed |= WIPHY_PARAM_FRAG_THRESHOLD; + } + + if (info->attrs[NL80211_ATTR_WIPHY_RTS_THRESHOLD]) { + rts_threshold = nla_get_u32( + info->attrs[NL80211_ATTR_WIPHY_RTS_THRESHOLD]); + changed |= WIPHY_PARAM_RTS_THRESHOLD; + } + + if (info->attrs[NL80211_ATTR_WIPHY_COVERAGE_CLASS]) { + if (info->attrs[NL80211_ATTR_WIPHY_DYN_ACK]) { + result = -EINVAL; + goto out; + } + + coverage_class = nla_get_u8( + info->attrs[NL80211_ATTR_WIPHY_COVERAGE_CLASS]); + changed |= WIPHY_PARAM_COVERAGE_CLASS; + } + + if (info->attrs[NL80211_ATTR_WIPHY_DYN_ACK]) { + if (!(rdev->wiphy.features & NL80211_FEATURE_ACKTO_ESTIMATION)) { + result = -EOPNOTSUPP; + goto out; + } + + changed |= WIPHY_PARAM_DYN_ACK; + } + + if (info->attrs[NL80211_ATTR_TXQ_LIMIT]) { + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_TXQS)) { + result = -EOPNOTSUPP; + goto out; + } + txq_limit = nla_get_u32( + info->attrs[NL80211_ATTR_TXQ_LIMIT]); + changed |= WIPHY_PARAM_TXQ_LIMIT; + } + + if (info->attrs[NL80211_ATTR_TXQ_MEMORY_LIMIT]) { + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_TXQS)) { + result = -EOPNOTSUPP; + goto out; + } + txq_memory_limit = nla_get_u32( + info->attrs[NL80211_ATTR_TXQ_MEMORY_LIMIT]); + changed |= WIPHY_PARAM_TXQ_MEMORY_LIMIT; + } + + if (info->attrs[NL80211_ATTR_TXQ_QUANTUM]) { + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_TXQS)) { + result = -EOPNOTSUPP; + goto out; + } + txq_quantum = nla_get_u32( + info->attrs[NL80211_ATTR_TXQ_QUANTUM]); + changed |= WIPHY_PARAM_TXQ_QUANTUM; + } + + if (changed) { + u8 old_retry_short, old_retry_long; + u32 old_frag_threshold, old_rts_threshold; + u8 old_coverage_class; + u32 old_txq_limit, old_txq_memory_limit, old_txq_quantum; + + if (!rdev->ops->set_wiphy_params) { + result = -EOPNOTSUPP; + goto out; + } + + old_retry_short = rdev->wiphy.retry_short; + old_retry_long = rdev->wiphy.retry_long; + old_frag_threshold = rdev->wiphy.frag_threshold; + old_rts_threshold = rdev->wiphy.rts_threshold; + old_coverage_class = rdev->wiphy.coverage_class; + old_txq_limit = rdev->wiphy.txq_limit; + old_txq_memory_limit = rdev->wiphy.txq_memory_limit; + old_txq_quantum = rdev->wiphy.txq_quantum; + + if (changed & WIPHY_PARAM_RETRY_SHORT) + rdev->wiphy.retry_short = retry_short; + if (changed & WIPHY_PARAM_RETRY_LONG) + rdev->wiphy.retry_long = retry_long; + if (changed & WIPHY_PARAM_FRAG_THRESHOLD) + rdev->wiphy.frag_threshold = frag_threshold; + if (changed & WIPHY_PARAM_RTS_THRESHOLD) + rdev->wiphy.rts_threshold = rts_threshold; + if (changed & WIPHY_PARAM_COVERAGE_CLASS) + rdev->wiphy.coverage_class = coverage_class; + if (changed & WIPHY_PARAM_TXQ_LIMIT) + rdev->wiphy.txq_limit = txq_limit; + if (changed & WIPHY_PARAM_TXQ_MEMORY_LIMIT) + rdev->wiphy.txq_memory_limit = txq_memory_limit; + if (changed & WIPHY_PARAM_TXQ_QUANTUM) + rdev->wiphy.txq_quantum = txq_quantum; + + result = rdev_set_wiphy_params(rdev, changed); + if (result) { + rdev->wiphy.retry_short = old_retry_short; + rdev->wiphy.retry_long = old_retry_long; + rdev->wiphy.frag_threshold = old_frag_threshold; + rdev->wiphy.rts_threshold = old_rts_threshold; + rdev->wiphy.coverage_class = old_coverage_class; + rdev->wiphy.txq_limit = old_txq_limit; + rdev->wiphy.txq_memory_limit = old_txq_memory_limit; + rdev->wiphy.txq_quantum = old_txq_quantum; + goto out; + } + } + + result = 0; + +out: + wiphy_unlock(&rdev->wiphy); + return result; +} + +static int nl80211_send_chandef(struct sk_buff *msg, + const struct cfg80211_chan_def 
*chandef) +{ + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return -EINVAL; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ, + chandef->chan->center_freq)) + return -ENOBUFS; + if (nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ_OFFSET, + chandef->chan->freq_offset)) + return -ENOBUFS; + switch (chandef->width) { + case NL80211_CHAN_WIDTH_20_NOHT: + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_40: + if (nla_put_u32(msg, NL80211_ATTR_WIPHY_CHANNEL_TYPE, + cfg80211_get_chandef_type(chandef))) + return -ENOBUFS; + break; + default: + break; + } + if (nla_put_u32(msg, NL80211_ATTR_CHANNEL_WIDTH, chandef->width)) + return -ENOBUFS; + if (nla_put_u32(msg, NL80211_ATTR_CENTER_FREQ1, chandef->center_freq1)) + return -ENOBUFS; + if (chandef->center_freq2 && + nla_put_u32(msg, NL80211_ATTR_CENTER_FREQ2, chandef->center_freq2)) + return -ENOBUFS; + return 0; +} + +static int nl80211_send_iface(struct sk_buff *msg, u32 portid, u32 seq, int flags, + struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + enum nl80211_commands cmd) +{ + struct net_device *dev = wdev->netdev; + void *hdr; + + WARN_ON(cmd != NL80211_CMD_NEW_INTERFACE && + cmd != NL80211_CMD_DEL_INTERFACE && + cmd != NL80211_CMD_SET_INTERFACE); + + hdr = nl80211hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -1; + + if (dev && + (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_string(msg, NL80211_ATTR_IFNAME, dev->name))) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFTYPE, wdev->iftype) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, wdev_address(wdev)) || + nla_put_u32(msg, NL80211_ATTR_GENERATION, + rdev->devlist_generation ^ + (cfg80211_rdev_list_generation << 2)) || + nla_put_u8(msg, NL80211_ATTR_4ADDR, wdev->use_4addr)) + goto nla_put_failure; + + if (rdev->ops->get_channel && !wdev->valid_links) { + struct cfg80211_chan_def chandef = {}; + int ret; + + ret = rdev_get_channel(rdev, wdev, 0, &chandef); + if (ret == 0 && nl80211_send_chandef(msg, &chandef)) + goto nla_put_failure; + } + + if (rdev->ops->get_tx_power) { + int dbm, ret; + + ret = rdev_get_tx_power(rdev, wdev, &dbm); + if (ret == 0 && + nla_put_u32(msg, NL80211_ATTR_WIPHY_TX_POWER_LEVEL, + DBM_TO_MBM(dbm))) + goto nla_put_failure; + } + + wdev_lock(wdev); + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + if (wdev->u.ap.ssid_len && + nla_put(msg, NL80211_ATTR_SSID, wdev->u.ap.ssid_len, + wdev->u.ap.ssid)) + goto nla_put_failure_locked; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + if (wdev->u.client.ssid_len && + nla_put(msg, NL80211_ATTR_SSID, wdev->u.client.ssid_len, + wdev->u.client.ssid)) + goto nla_put_failure_locked; + break; + case NL80211_IFTYPE_ADHOC: + if (wdev->u.ibss.ssid_len && + nla_put(msg, NL80211_ATTR_SSID, wdev->u.ibss.ssid_len, + wdev->u.ibss.ssid)) + goto nla_put_failure_locked; + break; + default: + /* nothing */ + break; + } + wdev_unlock(wdev); + + if (rdev->ops->get_txq_stats) { + struct cfg80211_txq_stats txqstats = {}; + int ret = rdev_get_txq_stats(rdev, wdev, &txqstats); + + if (ret == 0 && + !nl80211_put_txq_stats(msg, &txqstats, + NL80211_ATTR_TXQ_STATS)) + goto nla_put_failure; + } + + if (wdev->valid_links) { + unsigned int link_id; + struct nlattr *links = nla_nest_start(msg, + NL80211_ATTR_MLO_LINKS); + + if (!links) + goto nla_put_failure; + + for_each_valid_link(wdev, 
link_id) { + struct nlattr *link = nla_nest_start(msg, link_id + 1); + struct cfg80211_chan_def chandef = {}; + int ret; + + if (!link) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, link_id)) + goto nla_put_failure; + if (nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, + wdev->links[link_id].addr)) + goto nla_put_failure; + + ret = rdev_get_channel(rdev, wdev, link_id, &chandef); + if (ret == 0 && nl80211_send_chandef(msg, &chandef)) + goto nla_put_failure; + + nla_nest_end(msg, link); + } + + nla_nest_end(msg, links); + } + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure_locked: + wdev_unlock(wdev); + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_interface(struct sk_buff *skb, struct netlink_callback *cb) +{ + int wp_idx = 0; + int if_idx = 0; + int wp_start = cb->args[0]; + int if_start = cb->args[1]; + int filter_wiphy = -1; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + int ret; + + rtnl_lock(); + if (!cb->args[2]) { + struct nl80211_dump_wiphy_state state = { + .filter_wiphy = -1, + }; + + ret = nl80211_dump_wiphy_parse(skb, cb, &state); + if (ret) + goto out_unlock; + + filter_wiphy = state.filter_wiphy; + + /* + * if filtering, set cb->args[2] to +1 since 0 is the default + * value needed to determine that parsing is necessary. + */ + if (filter_wiphy >= 0) + cb->args[2] = filter_wiphy + 1; + else + cb->args[2] = -1; + } else if (cb->args[2] > 0) { + filter_wiphy = cb->args[2] - 1; + } + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (!net_eq(wiphy_net(&rdev->wiphy), sock_net(skb->sk))) + continue; + if (wp_idx < wp_start) { + wp_idx++; + continue; + } + + if (filter_wiphy >= 0 && filter_wiphy != rdev->wiphy_idx) + continue; + + if_idx = 0; + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (if_idx < if_start) { + if_idx++; + continue; + } + if (nl80211_send_iface(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wdev, + NL80211_CMD_NEW_INTERFACE) < 0) { + goto out; + } + if_idx++; + } + + wp_idx++; + } + out: + cb->args[0] = wp_idx; + cb->args[1] = if_idx; + + ret = skb->len; + out_unlock: + rtnl_unlock(); + + return ret; +} + +static int nl80211_get_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct sk_buff *msg; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl80211_send_iface(msg, info->snd_portid, info->snd_seq, 0, + rdev, wdev, NL80211_CMD_NEW_INTERFACE) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static const struct nla_policy mntr_flags_policy[NL80211_MNTR_FLAG_MAX + 1] = { + [NL80211_MNTR_FLAG_FCSFAIL] = { .type = NLA_FLAG }, + [NL80211_MNTR_FLAG_PLCPFAIL] = { .type = NLA_FLAG }, + [NL80211_MNTR_FLAG_CONTROL] = { .type = NLA_FLAG }, + [NL80211_MNTR_FLAG_OTHER_BSS] = { .type = NLA_FLAG }, + [NL80211_MNTR_FLAG_COOK_FRAMES] = { .type = NLA_FLAG }, + [NL80211_MNTR_FLAG_ACTIVE] = { .type = NLA_FLAG }, +}; + +static int parse_monitor_flags(struct nlattr *nla, u32 *mntrflags) +{ + struct nlattr *flags[NL80211_MNTR_FLAG_MAX + 1]; + int flag; + + *mntrflags = 0; + + if (!nla) + return -EINVAL; + + if (nla_parse_nested_deprecated(flags, NL80211_MNTR_FLAG_MAX, nla, mntr_flags_policy, NULL)) + return -EINVAL; + + for (flag = 1; flag <= NL80211_MNTR_FLAG_MAX; flag++) + if (flags[flag]) + *mntrflags |= 
(1<<flag); + + *mntrflags |= MONITOR_FLAG_CHANGED; + + return 0; +} + +static int nl80211_parse_mon_options(struct cfg80211_registered_device *rdev, + enum nl80211_iftype type, + struct genl_info *info, + struct vif_params *params) +{ + bool change = false; + int err; + + if (info->attrs[NL80211_ATTR_MNTR_FLAGS]) { + if (type != NL80211_IFTYPE_MONITOR) + return -EINVAL; + + err = parse_monitor_flags(info->attrs[NL80211_ATTR_MNTR_FLAGS], + &params->flags); + if (err) + return err; + + change = true; + } + + if (params->flags & MONITOR_FLAG_ACTIVE && + !(rdev->wiphy.features & NL80211_FEATURE_ACTIVE_MONITOR)) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_MU_MIMO_GROUP_DATA]) { + const u8 *mumimo_groups; + u32 cap_flag = NL80211_EXT_FEATURE_MU_MIMO_AIR_SNIFFER; + + if (type != NL80211_IFTYPE_MONITOR) + return -EINVAL; + + if (!wiphy_ext_feature_isset(&rdev->wiphy, cap_flag)) + return -EOPNOTSUPP; + + mumimo_groups = + nla_data(info->attrs[NL80211_ATTR_MU_MIMO_GROUP_DATA]); + + /* bits 0 and 63 are reserved and must be zero */ + if ((mumimo_groups[0] & BIT(0)) || + (mumimo_groups[VHT_MUMIMO_GROUPS_DATA_LEN - 1] & BIT(7))) + return -EINVAL; + + params->vht_mumimo_groups = mumimo_groups; + change = true; + } + + if (info->attrs[NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR]) { + u32 cap_flag = NL80211_EXT_FEATURE_MU_MIMO_AIR_SNIFFER; + + if (type != NL80211_IFTYPE_MONITOR) + return -EINVAL; + + if (!wiphy_ext_feature_isset(&rdev->wiphy, cap_flag)) + return -EOPNOTSUPP; + + params->vht_mumimo_follow_addr = + nla_data(info->attrs[NL80211_ATTR_MU_MIMO_FOLLOW_MAC_ADDR]); + change = true; + } + + return change ? 1 : 0; +} + +static int nl80211_valid_4addr(struct cfg80211_registered_device *rdev, + struct net_device *netdev, u8 use_4addr, + enum nl80211_iftype iftype) +{ + if (!use_4addr) { + if (netdev && netif_is_bridge_port(netdev)) + return -EBUSY; + return 0; + } + + switch (iftype) { + case NL80211_IFTYPE_AP_VLAN: + if (rdev->wiphy.flags & WIPHY_FLAG_4ADDR_AP) + return 0; + break; + case NL80211_IFTYPE_STATION: + if (rdev->wiphy.flags & WIPHY_FLAG_4ADDR_STATION) + return 0; + break; + default: + break; + } + + return -EOPNOTSUPP; +} + +static int nl80211_set_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct vif_params params; + int err; + enum nl80211_iftype otype, ntype; + struct net_device *dev = info->user_ptr[1]; + bool change = false; + + memset(&params, 0, sizeof(params)); + + otype = ntype = dev->ieee80211_ptr->iftype; + + if (info->attrs[NL80211_ATTR_IFTYPE]) { + ntype = nla_get_u32(info->attrs[NL80211_ATTR_IFTYPE]); + if (otype != ntype) + change = true; + } + + if (info->attrs[NL80211_ATTR_MESH_ID]) { + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if (ntype != NL80211_IFTYPE_MESH_POINT) + return -EINVAL; + if (netif_running(dev)) + return -EBUSY; + + wdev_lock(wdev); + BUILD_BUG_ON(IEEE80211_MAX_SSID_LEN != + IEEE80211_MAX_MESH_ID_LEN); + wdev->u.mesh.id_up_len = + nla_len(info->attrs[NL80211_ATTR_MESH_ID]); + memcpy(wdev->u.mesh.id, + nla_data(info->attrs[NL80211_ATTR_MESH_ID]), + wdev->u.mesh.id_up_len); + wdev_unlock(wdev); + } + + if (info->attrs[NL80211_ATTR_4ADDR]) { + params.use_4addr = !!nla_get_u8(info->attrs[NL80211_ATTR_4ADDR]); + change = true; + err = nl80211_valid_4addr(rdev, dev, params.use_4addr, ntype); + if (err) + return err; + } else { + params.use_4addr = -1; + } + + err = nl80211_parse_mon_options(rdev, ntype, info, &params); + if (err < 0) + return err; + if (err > 0) + change = true; + + if (change) + err = cfg80211_change_iface(rdev, dev, ntype, &params); + else + err = 0; + + if (!err && params.use_4addr != -1) + dev->ieee80211_ptr->use_4addr = params.use_4addr; + + if (change && !err) { + struct wireless_dev *wdev = dev->ieee80211_ptr; + +
nl80211_notify_iface(rdev, wdev, NL80211_CMD_SET_INTERFACE); + } + + return err; +} + +static int _nl80211_new_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct vif_params params; + struct wireless_dev *wdev; + struct sk_buff *msg; + int err; + enum nl80211_iftype type = NL80211_IFTYPE_UNSPECIFIED; + + memset(&params, 0, sizeof(params)); + + if (!info->attrs[NL80211_ATTR_IFNAME]) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_IFTYPE]) + type = nla_get_u32(info->attrs[NL80211_ATTR_IFTYPE]); + + if (!rdev->ops->add_virtual_intf) + return -EOPNOTSUPP; + + if ((type == NL80211_IFTYPE_P2P_DEVICE || type == NL80211_IFTYPE_NAN || + rdev->wiphy.features & NL80211_FEATURE_MAC_ON_CREATE) && + info->attrs[NL80211_ATTR_MAC]) { + nla_memcpy(params.macaddr, info->attrs[NL80211_ATTR_MAC], + ETH_ALEN); + if (!is_valid_ether_addr(params.macaddr)) + return -EADDRNOTAVAIL; + } + + if (info->attrs[NL80211_ATTR_4ADDR]) { + params.use_4addr = !!nla_get_u8(info->attrs[NL80211_ATTR_4ADDR]); + err = nl80211_valid_4addr(rdev, NULL, params.use_4addr, type); + if (err) + return err; + } + + if (!cfg80211_iftype_allowed(&rdev->wiphy, type, params.use_4addr, 0)) + return -EOPNOTSUPP; + + err = nl80211_parse_mon_options(rdev, type, info, &params); + if (err < 0) + return err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + wdev = rdev_add_virtual_intf(rdev, + nla_data(info->attrs[NL80211_ATTR_IFNAME]), + NET_NAME_USER, type, &params); + if (WARN_ON(!wdev)) { + nlmsg_free(msg); + return -EPROTO; + } else if (IS_ERR(wdev)) { + nlmsg_free(msg); + return PTR_ERR(wdev); + } + + if (info->attrs[NL80211_ATTR_SOCKET_OWNER]) + wdev->owner_nlportid = info->snd_portid; + + switch (type) { + case NL80211_IFTYPE_MESH_POINT: + if (!info->attrs[NL80211_ATTR_MESH_ID]) + break; + wdev_lock(wdev); + BUILD_BUG_ON(IEEE80211_MAX_SSID_LEN != + IEEE80211_MAX_MESH_ID_LEN); + wdev->u.mesh.id_up_len = + nla_len(info->attrs[NL80211_ATTR_MESH_ID]); + memcpy(wdev->u.mesh.id, + nla_data(info->attrs[NL80211_ATTR_MESH_ID]), + wdev->u.mesh.id_up_len); + wdev_unlock(wdev); + break; + case NL80211_IFTYPE_NAN: + case NL80211_IFTYPE_P2P_DEVICE: + /* + * P2P Device and NAN do not have a netdev, so don't go + * through the netdev notifier and must be added here + */ + cfg80211_init_wdev(wdev); + cfg80211_register_wdev(rdev, wdev); + break; + default: + break; + } + + if (nl80211_send_iface(msg, info->snd_portid, info->snd_seq, 0, + rdev, wdev, NL80211_CMD_NEW_INTERFACE) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static int nl80211_new_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int ret; + + /* to avoid failing a new interface creation due to pending removal */ + cfg80211_destroy_ifaces(rdev); + + wiphy_lock(&rdev->wiphy); + ret = _nl80211_new_interface(skb, info); + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int nl80211_del_interface(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + if (!rdev->ops->del_virtual_intf) + return -EOPNOTSUPP; + + /* + * We hold RTNL, so this is safe, without RTNL opencount cannot + * reach 0, and thus the rdev cannot be deleted.
+ * + * We need to do it for the dev_close(), since that will call + * the netdev notifiers, and we need to acquire the mutex there + * but don't know if we get there from here or from some other + * place (e.g. "ip link set ... down"). + */ + mutex_unlock(&rdev->wiphy.mtx); + + /* + * If we remove a wireless device without a netdev then clear + * user_ptr[1] so that nl80211_post_doit won't dereference it + * to check if it needs to do dev_put(). Otherwise it crashes + * since the wdev has been freed, unlike with a netdev where + * we need the dev_put() for the netdev to really be freed. + */ + if (!wdev->netdev) + info->user_ptr[1] = NULL; + else + dev_close(wdev->netdev); + + mutex_lock(&rdev->wiphy.mtx); + + return cfg80211_remove_virtual_intf(rdev, wdev); +} + +static int nl80211_set_noack_map(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u16 noack_map; + + if (!info->attrs[NL80211_ATTR_NOACK_MAP]) + return -EINVAL; + + if (!rdev->ops->set_noack_map) + return -EOPNOTSUPP; + + noack_map = nla_get_u16(info->attrs[NL80211_ATTR_NOACK_MAP]); + + return rdev_set_noack_map(rdev, dev, noack_map); +} + +static int nl80211_validate_key_link_id(struct genl_info *info, + struct wireless_dev *wdev, + int link_id, bool pairwise) +{ + if (pairwise) { + if (link_id != -1) { + GENL_SET_ERR_MSG(info, + "link ID not allowed for pairwise key"); + return -EINVAL; + } + + return 0; + } + + if (wdev->valid_links) { + if (link_id == -1) { + GENL_SET_ERR_MSG(info, + "link ID must for MLO group key"); + return -EINVAL; + } + if (!(wdev->valid_links & BIT(link_id))) { + GENL_SET_ERR_MSG(info, "invalid link ID for MLO group key"); + return -EINVAL; + } + } else if (link_id != -1) { + GENL_SET_ERR_MSG(info, "link ID not allowed for non-MLO group key"); + return -EINVAL; + } + + return 0; +} + +struct get_key_cookie { + struct sk_buff *msg; + int error; + int idx; +}; + +static void get_key_callback(void *c, struct key_params *params) +{ + struct nlattr *key; + struct get_key_cookie *cookie = c; + + if ((params->key && + nla_put(cookie->msg, NL80211_ATTR_KEY_DATA, + params->key_len, params->key)) || + (params->seq && + nla_put(cookie->msg, NL80211_ATTR_KEY_SEQ, + params->seq_len, params->seq)) || + (params->cipher && + nla_put_u32(cookie->msg, NL80211_ATTR_KEY_CIPHER, + params->cipher))) + goto nla_put_failure; + + key = nla_nest_start_noflag(cookie->msg, NL80211_ATTR_KEY); + if (!key) + goto nla_put_failure; + + if ((params->key && + nla_put(cookie->msg, NL80211_KEY_DATA, + params->key_len, params->key)) || + (params->seq && + nla_put(cookie->msg, NL80211_KEY_SEQ, + params->seq_len, params->seq)) || + (params->cipher && + nla_put_u32(cookie->msg, NL80211_KEY_CIPHER, + params->cipher))) + goto nla_put_failure; + + if (nla_put_u8(cookie->msg, NL80211_KEY_IDX, cookie->idx)) + goto nla_put_failure; + + nla_nest_end(cookie->msg, key); + + return; + nla_put_failure: + cookie->error = 1; +} + +static int nl80211_get_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + u8 key_idx = 0; + const u8 *mac_addr = NULL; + bool pairwise; + struct get_key_cookie cookie = { + .error = 0, + }; + void *hdr; + struct sk_buff *msg; + bool bigtk_support = false; + int link_id = nl80211_link_id_or_invalid(info->attrs); + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if 
(wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION)) + bigtk_support = true; + + if ((wdev->iftype == NL80211_IFTYPE_STATION || + wdev->iftype == NL80211_IFTYPE_P2P_CLIENT) && + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION_CLIENT)) + bigtk_support = true; + + if (info->attrs[NL80211_ATTR_KEY_IDX]) { + key_idx = nla_get_u8(info->attrs[NL80211_ATTR_KEY_IDX]); + + if (key_idx >= 6 && key_idx <= 7 && !bigtk_support) { + GENL_SET_ERR_MSG(info, "BIGTK not supported"); + return -EINVAL; + } + } + + if (info->attrs[NL80211_ATTR_MAC]) + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + pairwise = !!mac_addr; + if (info->attrs[NL80211_ATTR_KEY_TYPE]) { + u32 kt = nla_get_u32(info->attrs[NL80211_ATTR_KEY_TYPE]); + + if (kt != NL80211_KEYTYPE_GROUP && + kt != NL80211_KEYTYPE_PAIRWISE) + return -EINVAL; + pairwise = kt == NL80211_KEYTYPE_PAIRWISE; + } + + if (!rdev->ops->get_key) + return -EOPNOTSUPP; + + if (!pairwise && mac_addr && !(rdev->wiphy.flags & WIPHY_FLAG_IBSS_RSN)) + return -ENOENT; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_NEW_KEY); + if (!hdr) + goto nla_put_failure; + + cookie.msg = msg; + cookie.idx = key_idx; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_u8(msg, NL80211_ATTR_KEY_IDX, key_idx)) + goto nla_put_failure; + if (mac_addr && + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, mac_addr)) + goto nla_put_failure; + + err = nl80211_validate_key_link_id(info, wdev, link_id, pairwise); + if (err) + goto free_msg; + + err = rdev_get_key(rdev, dev, link_id, key_idx, pairwise, mac_addr, + &cookie, get_key_callback); + + if (err) + goto free_msg; + + if (cookie.error) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + + nla_put_failure: + err = -ENOBUFS; + free_msg: + nlmsg_free(msg); + return err; +} + +static int nl80211_set_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct key_parse key; + int err; + struct net_device *dev = info->user_ptr[1]; + int link_id = nl80211_link_id_or_invalid(info->attrs); + struct wireless_dev *wdev = dev->ieee80211_ptr; + + err = nl80211_parse_key(info, &key); + if (err) + return err; + + if (key.idx < 0) + return -EINVAL; + + /* Only support setting default key and + * Extended Key ID action NL80211_KEY_SET_TX. 
+ */ + if (!key.def && !key.defmgmt && !key.defbeacon && + !(key.p.mode == NL80211_KEY_SET_TX)) + return -EINVAL; + + wdev_lock(wdev); + + if (key.def) { + if (!rdev->ops->set_default_key) { + err = -EOPNOTSUPP; + goto out; + } + + err = nl80211_key_allowed(wdev); + if (err) + goto out; + + err = nl80211_validate_key_link_id(info, wdev, link_id, false); + if (err) + goto out; + + err = rdev_set_default_key(rdev, dev, link_id, key.idx, + key.def_uni, key.def_multi); + + if (err) + goto out; + +#ifdef CONFIG_CFG80211_WEXT + wdev->wext.default_key = key.idx; +#endif + } else if (key.defmgmt) { + if (key.def_uni || !key.def_multi) { + err = -EINVAL; + goto out; + } + + if (!rdev->ops->set_default_mgmt_key) { + err = -EOPNOTSUPP; + goto out; + } + + err = nl80211_key_allowed(wdev); + if (err) + goto out; + + err = nl80211_validate_key_link_id(info, wdev, link_id, false); + if (err) + goto out; + + err = rdev_set_default_mgmt_key(rdev, dev, link_id, key.idx); + if (err) + goto out; + +#ifdef CONFIG_CFG80211_WEXT + wdev->wext.default_mgmt_key = key.idx; +#endif + } else if (key.defbeacon) { + if (key.def_uni || !key.def_multi) { + err = -EINVAL; + goto out; + } + + if (!rdev->ops->set_default_beacon_key) { + err = -EOPNOTSUPP; + goto out; + } + + err = nl80211_key_allowed(wdev); + if (err) + goto out; + + err = nl80211_validate_key_link_id(info, wdev, link_id, false); + if (err) + goto out; + + err = rdev_set_default_beacon_key(rdev, dev, link_id, key.idx); + if (err) + goto out; + } else if (key.p.mode == NL80211_KEY_SET_TX && + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_EXT_KEY_ID)) { + u8 *mac_addr = NULL; + + if (info->attrs[NL80211_ATTR_MAC]) + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!mac_addr || key.idx < 0 || key.idx > 1) { + err = -EINVAL; + goto out; + } + + err = nl80211_validate_key_link_id(info, wdev, link_id, true); + if (err) + goto out; + + err = rdev_add_key(rdev, dev, link_id, key.idx, + NL80211_KEYTYPE_PAIRWISE, + mac_addr, &key.p); + } else { + err = -EINVAL; + } + out: + wdev_unlock(wdev); + + return err; +} + +static int nl80211_new_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + struct key_parse key; + const u8 *mac_addr = NULL; + int link_id = nl80211_link_id_or_invalid(info->attrs); + struct wireless_dev *wdev = dev->ieee80211_ptr; + + err = nl80211_parse_key(info, &key); + if (err) + return err; + + if (!key.p.key) { + GENL_SET_ERR_MSG(info, "no key"); + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_MAC]) + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (key.type == -1) { + if (mac_addr) + key.type = NL80211_KEYTYPE_PAIRWISE; + else + key.type = NL80211_KEYTYPE_GROUP; + } + + /* for now */ + if (key.type != NL80211_KEYTYPE_PAIRWISE && + key.type != NL80211_KEYTYPE_GROUP) { + GENL_SET_ERR_MSG(info, "key type not pairwise or group"); + return -EINVAL; + } + + if (key.type == NL80211_KEYTYPE_GROUP && + info->attrs[NL80211_ATTR_VLAN_ID]) + key.p.vlan_id = nla_get_u16(info->attrs[NL80211_ATTR_VLAN_ID]); + + if (!rdev->ops->add_key) + return -EOPNOTSUPP; + + if (cfg80211_validate_key_settings(rdev, &key.p, key.idx, + key.type == NL80211_KEYTYPE_PAIRWISE, + mac_addr)) { + GENL_SET_ERR_MSG(info, "key setting validation failed"); + return -EINVAL; + } + + wdev_lock(wdev); + err = nl80211_key_allowed(wdev); + if (err) + GENL_SET_ERR_MSG(info, "key not allowed"); + + if (!err) + err = 
nl80211_validate_key_link_id(info, wdev, link_id, + key.type == NL80211_KEYTYPE_PAIRWISE); + + if (!err) { + err = rdev_add_key(rdev, dev, link_id, key.idx, + key.type == NL80211_KEYTYPE_PAIRWISE, + mac_addr, &key.p); + if (err) + GENL_SET_ERR_MSG(info, "key addition failed"); + } + wdev_unlock(wdev); + + return err; +} + +static int nl80211_del_key(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + u8 *mac_addr = NULL; + struct key_parse key; + int link_id = nl80211_link_id_or_invalid(info->attrs); + struct wireless_dev *wdev = dev->ieee80211_ptr; + + err = nl80211_parse_key(info, &key); + if (err) + return err; + + if (info->attrs[NL80211_ATTR_MAC]) + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (key.type == -1) { + if (mac_addr) + key.type = NL80211_KEYTYPE_PAIRWISE; + else + key.type = NL80211_KEYTYPE_GROUP; + } + + /* for now */ + if (key.type != NL80211_KEYTYPE_PAIRWISE && + key.type != NL80211_KEYTYPE_GROUP) + return -EINVAL; + + if (!cfg80211_valid_key_idx(rdev, key.idx, + key.type == NL80211_KEYTYPE_PAIRWISE)) + return -EINVAL; + + if (!rdev->ops->del_key) + return -EOPNOTSUPP; + + wdev_lock(wdev); + err = nl80211_key_allowed(wdev); + + if (key.type == NL80211_KEYTYPE_GROUP && mac_addr && + !(rdev->wiphy.flags & WIPHY_FLAG_IBSS_RSN)) + err = -ENOENT; + + if (!err) + err = nl80211_validate_key_link_id(info, wdev, link_id, + key.type == NL80211_KEYTYPE_PAIRWISE); + + if (!err) + err = rdev_del_key(rdev, dev, link_id, key.idx, + key.type == NL80211_KEYTYPE_PAIRWISE, + mac_addr); + +#ifdef CONFIG_CFG80211_WEXT + if (!err) { + if (key.idx == wdev->wext.default_key) + wdev->wext.default_key = -1; + else if (key.idx == wdev->wext.default_mgmt_key) + wdev->wext.default_mgmt_key = -1; + } +#endif + wdev_unlock(wdev); + + return err; +} + +/* This function returns an error or the number of nested attributes */ +static int validate_acl_mac_addrs(struct nlattr *nl_attr) +{ + struct nlattr *attr; + int n_entries = 0, tmp; + + nla_for_each_nested(attr, nl_attr, tmp) { + if (nla_len(attr) != ETH_ALEN) + return -EINVAL; + + n_entries++; + } + + return n_entries; +} + +/* + * This function parses ACL information and allocates memory for ACL data. + * On successful return, the calling function is responsible to free the + * ACL buffer returned by this function. 
+ */ +static struct cfg80211_acl_data *parse_acl_data(struct wiphy *wiphy, + struct genl_info *info) +{ + enum nl80211_acl_policy acl_policy; + struct nlattr *attr; + struct cfg80211_acl_data *acl; + int i = 0, n_entries, tmp; + + if (!wiphy->max_acl_mac_addrs) + return ERR_PTR(-EOPNOTSUPP); + + if (!info->attrs[NL80211_ATTR_ACL_POLICY]) + return ERR_PTR(-EINVAL); + + acl_policy = nla_get_u32(info->attrs[NL80211_ATTR_ACL_POLICY]); + if (acl_policy != NL80211_ACL_POLICY_ACCEPT_UNLESS_LISTED && + acl_policy != NL80211_ACL_POLICY_DENY_UNLESS_LISTED) + return ERR_PTR(-EINVAL); + + if (!info->attrs[NL80211_ATTR_MAC_ADDRS]) + return ERR_PTR(-EINVAL); + + n_entries = validate_acl_mac_addrs(info->attrs[NL80211_ATTR_MAC_ADDRS]); + if (n_entries < 0) + return ERR_PTR(n_entries); + + if (n_entries > wiphy->max_acl_mac_addrs) + return ERR_PTR(-ENOTSUPP); + + acl = kzalloc(struct_size(acl, mac_addrs, n_entries), GFP_KERNEL); + if (!acl) + return ERR_PTR(-ENOMEM); + + nla_for_each_nested(attr, info->attrs[NL80211_ATTR_MAC_ADDRS], tmp) { + memcpy(acl->mac_addrs[i].addr, nla_data(attr), ETH_ALEN); + i++; + } + + acl->n_acl_entries = n_entries; + acl->acl_policy = acl_policy; + + return acl; +} + +static int nl80211_set_mac_acl(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct cfg80211_acl_data *acl; + int err; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + if (!dev->ieee80211_ptr->links[0].ap.beacon_interval) + return -EINVAL; + + acl = parse_acl_data(&rdev->wiphy, info); + if (IS_ERR(acl)) + return PTR_ERR(acl); + + err = rdev_set_mac_acl(rdev, dev, acl); + + kfree(acl); + + return err; +} + +static u32 rateset_to_mask(struct ieee80211_supported_band *sband, + u8 *rates, u8 rates_len) +{ + u8 i; + u32 mask = 0; + + for (i = 0; i < rates_len; i++) { + int rate = (rates[i] & 0x7f) * 5; + int ridx; + + for (ridx = 0; ridx < sband->n_bitrates; ridx++) { + struct ieee80211_rate *srate = + &sband->bitrates[ridx]; + if (rate == srate->bitrate) { + mask |= 1 << ridx; + break; + } + } + if (ridx == sband->n_bitrates) + return 0; /* rate not found */ + } + + return mask; +} + +static bool ht_rateset_to_mask(struct ieee80211_supported_band *sband, + u8 *rates, u8 rates_len, + u8 mcs[IEEE80211_HT_MCS_MASK_LEN]) +{ + u8 i; + + memset(mcs, 0, IEEE80211_HT_MCS_MASK_LEN); + + for (i = 0; i < rates_len; i++) { + int ridx, rbit; + + ridx = rates[i] / 8; + rbit = BIT(rates[i] % 8); + + /* check validity */ + if ((ridx < 0) || (ridx >= IEEE80211_HT_MCS_MASK_LEN)) + return false; + + /* check availability */ + ridx = array_index_nospec(ridx, IEEE80211_HT_MCS_MASK_LEN); + if (sband->ht_cap.mcs.rx_mask[ridx] & rbit) + mcs[ridx] |= rbit; + else + return false; + } + + return true; +} + +static u16 vht_mcs_map_to_mcs_mask(u8 vht_mcs_map) +{ + u16 mcs_mask = 0; + + switch (vht_mcs_map) { + case IEEE80211_VHT_MCS_NOT_SUPPORTED: + break; + case IEEE80211_VHT_MCS_SUPPORT_0_7: + mcs_mask = 0x00FF; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_8: + mcs_mask = 0x01FF; + break; + case IEEE80211_VHT_MCS_SUPPORT_0_9: + mcs_mask = 0x03FF; + break; + default: + break; + } + + return mcs_mask; +} + +static void vht_build_mcs_mask(u16 vht_mcs_map, + u16 vht_mcs_mask[NL80211_VHT_NSS_MAX]) +{ + u8 nss; + + for (nss = 0; nss < NL80211_VHT_NSS_MAX; nss++) { + vht_mcs_mask[nss] = vht_mcs_map_to_mcs_mask(vht_mcs_map & 0x03); + vht_mcs_map >>= 2; 
+ } +} + +static bool vht_set_mcs_mask(struct ieee80211_supported_band *sband, + struct nl80211_txrate_vht *txrate, + u16 mcs[NL80211_VHT_NSS_MAX]) +{ + u16 tx_mcs_map = le16_to_cpu(sband->vht_cap.vht_mcs.tx_mcs_map); + u16 tx_mcs_mask[NL80211_VHT_NSS_MAX] = {}; + u8 i; + + if (!sband->vht_cap.vht_supported) + return false; + + memset(mcs, 0, sizeof(u16) * NL80211_VHT_NSS_MAX); + + /* Build vht_mcs_mask from VHT capabilities */ + vht_build_mcs_mask(tx_mcs_map, tx_mcs_mask); + + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) { + if ((tx_mcs_mask[i] & txrate->mcs[i]) == txrate->mcs[i]) + mcs[i] = txrate->mcs[i]; + else + return false; + } + + return true; +} + +static u16 he_mcs_map_to_mcs_mask(u8 he_mcs_map) +{ + switch (he_mcs_map) { + case IEEE80211_HE_MCS_NOT_SUPPORTED: + return 0; + case IEEE80211_HE_MCS_SUPPORT_0_7: + return 0x00FF; + case IEEE80211_HE_MCS_SUPPORT_0_9: + return 0x03FF; + case IEEE80211_HE_MCS_SUPPORT_0_11: + return 0xFFF; + default: + break; + } + return 0; +} + +static void he_build_mcs_mask(u16 he_mcs_map, + u16 he_mcs_mask[NL80211_HE_NSS_MAX]) +{ + u8 nss; + + for (nss = 0; nss < NL80211_HE_NSS_MAX; nss++) { + he_mcs_mask[nss] = he_mcs_map_to_mcs_mask(he_mcs_map & 0x03); + he_mcs_map >>= 2; + } +} + +static u16 he_get_txmcsmap(struct genl_info *info, unsigned int link_id, + const struct ieee80211_sta_he_cap *he_cap) +{ + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_chan_def *chandef; + __le16 tx_mcs; + + chandef = wdev_chandef(wdev, link_id); + if (!chandef) { + /* + * This is probably broken, but we never maintained + * a chandef in these cases, so it always was. + */ + return le16_to_cpu(he_cap->he_mcs_nss_supp.tx_mcs_80); + } + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_80P80: + tx_mcs = he_cap->he_mcs_nss_supp.tx_mcs_80p80; + break; + case NL80211_CHAN_WIDTH_160: + tx_mcs = he_cap->he_mcs_nss_supp.tx_mcs_160; + break; + default: + tx_mcs = he_cap->he_mcs_nss_supp.tx_mcs_80; + break; + } + + return le16_to_cpu(tx_mcs); +} + +static bool he_set_mcs_mask(struct genl_info *info, + struct wireless_dev *wdev, + struct ieee80211_supported_band *sband, + struct nl80211_txrate_he *txrate, + u16 mcs[NL80211_HE_NSS_MAX], + unsigned int link_id) +{ + const struct ieee80211_sta_he_cap *he_cap; + u16 tx_mcs_mask[NL80211_HE_NSS_MAX] = {}; + u16 tx_mcs_map = 0; + u8 i; + + he_cap = ieee80211_get_he_iftype_cap(sband, wdev->iftype); + if (!he_cap) + return false; + + memset(mcs, 0, sizeof(u16) * NL80211_HE_NSS_MAX); + + tx_mcs_map = he_get_txmcsmap(info, link_id, he_cap); + + /* Build he_mcs_mask from HE capabilities */ + he_build_mcs_mask(tx_mcs_map, tx_mcs_mask); + + for (i = 0; i < NL80211_HE_NSS_MAX; i++) { + if ((tx_mcs_mask[i] & txrate->mcs[i]) == txrate->mcs[i]) + mcs[i] = txrate->mcs[i]; + else + return false; + } + + return true; +} + +static int nl80211_parse_tx_bitrate_mask(struct genl_info *info, + struct nlattr *attrs[], + enum nl80211_attrs attr, + struct cfg80211_bitrate_mask *mask, + struct net_device *dev, + bool default_all_enabled, + unsigned int link_id) +{ + struct nlattr *tb[NL80211_TXRATE_MAX + 1]; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + int rem, i; + struct nlattr *tx_rates; + struct ieee80211_supported_band *sband; + u16 vht_tx_mcs_map, he_tx_mcs_map; + + memset(mask, 0, sizeof(*mask)); + /* Default to all rates enabled */ + for (i = 0; i < NUM_NL80211_BANDS; i++) { + const struct ieee80211_sta_he_cap *he_cap; 
+ + if (!default_all_enabled) + break; + + sband = rdev->wiphy.bands[i]; + + if (!sband) + continue; + + mask->control[i].legacy = (1 << sband->n_bitrates) - 1; + memcpy(mask->control[i].ht_mcs, + sband->ht_cap.mcs.rx_mask, + sizeof(mask->control[i].ht_mcs)); + + if (sband->vht_cap.vht_supported) { + vht_tx_mcs_map = le16_to_cpu(sband->vht_cap.vht_mcs.tx_mcs_map); + vht_build_mcs_mask(vht_tx_mcs_map, mask->control[i].vht_mcs); + } + + he_cap = ieee80211_get_he_iftype_cap(sband, wdev->iftype); + if (!he_cap) + continue; + + he_tx_mcs_map = he_get_txmcsmap(info, link_id, he_cap); + he_build_mcs_mask(he_tx_mcs_map, mask->control[i].he_mcs); + + mask->control[i].he_gi = 0xFF; + mask->control[i].he_ltf = 0xFF; + } + + /* if no rates are given set it back to the defaults */ + if (!attrs[attr]) + goto out; + + /* The nested attribute uses enum nl80211_band as the index. This maps + * directly to the enum nl80211_band values used in cfg80211. + */ + BUILD_BUG_ON(NL80211_MAX_SUPP_HT_RATES > IEEE80211_HT_MCS_MASK_LEN * 8); + nla_for_each_nested(tx_rates, attrs[attr], rem) { + enum nl80211_band band = nla_type(tx_rates); + int err; + + if (band < 0 || band >= NUM_NL80211_BANDS) + return -EINVAL; + sband = rdev->wiphy.bands[band]; + if (sband == NULL) + return -EINVAL; + err = nla_parse_nested_deprecated(tb, NL80211_TXRATE_MAX, + tx_rates, + nl80211_txattr_policy, + info->extack); + if (err) + return err; + if (tb[NL80211_TXRATE_LEGACY]) { + mask->control[band].legacy = rateset_to_mask( + sband, + nla_data(tb[NL80211_TXRATE_LEGACY]), + nla_len(tb[NL80211_TXRATE_LEGACY])); + if ((mask->control[band].legacy == 0) && + nla_len(tb[NL80211_TXRATE_LEGACY])) + return -EINVAL; + } + if (tb[NL80211_TXRATE_HT]) { + if (!ht_rateset_to_mask( + sband, + nla_data(tb[NL80211_TXRATE_HT]), + nla_len(tb[NL80211_TXRATE_HT]), + mask->control[band].ht_mcs)) + return -EINVAL; + } + + if (tb[NL80211_TXRATE_VHT]) { + if (!vht_set_mcs_mask( + sband, + nla_data(tb[NL80211_TXRATE_VHT]), + mask->control[band].vht_mcs)) + return -EINVAL; + } + + if (tb[NL80211_TXRATE_GI]) { + mask->control[band].gi = + nla_get_u8(tb[NL80211_TXRATE_GI]); + if (mask->control[band].gi > NL80211_TXRATE_FORCE_LGI) + return -EINVAL; + } + if (tb[NL80211_TXRATE_HE] && + !he_set_mcs_mask(info, wdev, sband, + nla_data(tb[NL80211_TXRATE_HE]), + mask->control[band].he_mcs, + link_id)) + return -EINVAL; + + if (tb[NL80211_TXRATE_HE_GI]) + mask->control[band].he_gi = + nla_get_u8(tb[NL80211_TXRATE_HE_GI]); + if (tb[NL80211_TXRATE_HE_LTF]) + mask->control[band].he_ltf = + nla_get_u8(tb[NL80211_TXRATE_HE_LTF]); + + if (mask->control[band].legacy == 0) { + /* don't allow empty legacy rates if HT, VHT or HE + * are not even supported. 
+ */ + if (!(rdev->wiphy.bands[band]->ht_cap.ht_supported || + rdev->wiphy.bands[band]->vht_cap.vht_supported || + ieee80211_get_he_iftype_cap(sband, wdev->iftype))) + return -EINVAL; + + for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++) + if (mask->control[band].ht_mcs[i]) + goto out; + + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) + if (mask->control[band].vht_mcs[i]) + goto out; + + for (i = 0; i < NL80211_HE_NSS_MAX; i++) + if (mask->control[band].he_mcs[i]) + goto out; + + /* legacy and mcs rates may not be both empty */ + return -EINVAL; + } + } + +out: + return 0; +} + +static int validate_beacon_tx_rate(struct cfg80211_registered_device *rdev, + enum nl80211_band band, + struct cfg80211_bitrate_mask *beacon_rate) +{ + u32 count_ht, count_vht, count_he, i; + u32 rate = beacon_rate->control[band].legacy; + + /* Allow only one rate */ + if (hweight32(rate) > 1) + return -EINVAL; + + count_ht = 0; + for (i = 0; i < IEEE80211_HT_MCS_MASK_LEN; i++) { + if (hweight8(beacon_rate->control[band].ht_mcs[i]) > 1) { + return -EINVAL; + } else if (beacon_rate->control[band].ht_mcs[i]) { + count_ht++; + if (count_ht > 1) + return -EINVAL; + } + if (count_ht && rate) + return -EINVAL; + } + + count_vht = 0; + for (i = 0; i < NL80211_VHT_NSS_MAX; i++) { + if (hweight16(beacon_rate->control[band].vht_mcs[i]) > 1) { + return -EINVAL; + } else if (beacon_rate->control[band].vht_mcs[i]) { + count_vht++; + if (count_vht > 1) + return -EINVAL; + } + if (count_vht && rate) + return -EINVAL; + } + + count_he = 0; + for (i = 0; i < NL80211_HE_NSS_MAX; i++) { + if (hweight16(beacon_rate->control[band].he_mcs[i]) > 1) { + return -EINVAL; + } else if (beacon_rate->control[band].he_mcs[i]) { + count_he++; + if (count_he > 1) + return -EINVAL; + } + if (count_he && rate) + return -EINVAL; + } + + if ((count_ht && count_vht && count_he) || + (!rate && !count_ht && !count_vht && !count_he)) + return -EINVAL; + + if (rate && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_RATE_LEGACY)) + return -EINVAL; + if (count_ht && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_RATE_HT)) + return -EINVAL; + if (count_vht && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_RATE_VHT)) + return -EINVAL; + if (count_he && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_RATE_HE)) + return -EINVAL; + + return 0; +} + +static int nl80211_parse_mbssid_config(struct wiphy *wiphy, + struct net_device *dev, + struct nlattr *attrs, + struct cfg80211_mbssid_config *config, + u8 num_elems) +{ + struct nlattr *tb[NL80211_MBSSID_CONFIG_ATTR_MAX + 1]; + + if (!wiphy->mbssid_max_interfaces) + return -EOPNOTSUPP; + + if (nla_parse_nested(tb, NL80211_MBSSID_CONFIG_ATTR_MAX, attrs, NULL, + NULL) || + !tb[NL80211_MBSSID_CONFIG_ATTR_INDEX]) + return -EINVAL; + + config->ema = nla_get_flag(tb[NL80211_MBSSID_CONFIG_ATTR_EMA]); + if (config->ema) { + if (!wiphy->ema_max_profile_periodicity) + return -EOPNOTSUPP; + + if (num_elems > wiphy->ema_max_profile_periodicity) + return -EINVAL; + } + + config->index = nla_get_u8(tb[NL80211_MBSSID_CONFIG_ATTR_INDEX]); + if (config->index >= wiphy->mbssid_max_interfaces || + (!config->index && !num_elems)) + return -EINVAL; + + if (tb[NL80211_MBSSID_CONFIG_ATTR_TX_IFINDEX]) { + u32 tx_ifindex = + nla_get_u32(tb[NL80211_MBSSID_CONFIG_ATTR_TX_IFINDEX]); + + if ((!config->index && tx_ifindex != dev->ifindex) || + (config->index && tx_ifindex == dev->ifindex)) + return -EINVAL; + + if (tx_ifindex != dev->ifindex) { + struct net_device 
*tx_netdev = + dev_get_by_index(wiphy_net(wiphy), tx_ifindex); + + if (!tx_netdev || !tx_netdev->ieee80211_ptr || + tx_netdev->ieee80211_ptr->wiphy != wiphy || + tx_netdev->ieee80211_ptr->iftype != + NL80211_IFTYPE_AP) { + dev_put(tx_netdev); + return -EINVAL; + } + + config->tx_wdev = tx_netdev->ieee80211_ptr; + } else { + config->tx_wdev = dev->ieee80211_ptr; + } + } else if (!config->index) { + config->tx_wdev = dev->ieee80211_ptr; + } else { + return -EINVAL; + } + + return 0; +} + +static struct cfg80211_mbssid_elems * +nl80211_parse_mbssid_elems(struct wiphy *wiphy, struct nlattr *attrs) +{ + struct nlattr *nl_elems; + struct cfg80211_mbssid_elems *elems; + int rem_elems; + u8 i = 0, num_elems = 0; + + if (!wiphy->mbssid_max_interfaces) + return ERR_PTR(-EINVAL); + + nla_for_each_nested(nl_elems, attrs, rem_elems) { + if (num_elems >= 255) + return ERR_PTR(-EINVAL); + num_elems++; + } + + elems = kzalloc(struct_size(elems, elem, num_elems), GFP_KERNEL); + if (!elems) + return ERR_PTR(-ENOMEM); + + nla_for_each_nested(nl_elems, attrs, rem_elems) { + elems->elem[i].data = nla_data(nl_elems); + elems->elem[i].len = nla_len(nl_elems); + i++; + } + elems->cnt = num_elems; + return elems; +} + +static int nl80211_parse_he_bss_color(struct nlattr *attrs, + struct cfg80211_he_bss_color *he_bss_color) +{ + struct nlattr *tb[NL80211_HE_BSS_COLOR_ATTR_MAX + 1]; + int err; + + err = nla_parse_nested(tb, NL80211_HE_BSS_COLOR_ATTR_MAX, attrs, + he_bss_color_policy, NULL); + if (err) + return err; + + if (!tb[NL80211_HE_BSS_COLOR_ATTR_COLOR]) + return -EINVAL; + + he_bss_color->color = + nla_get_u8(tb[NL80211_HE_BSS_COLOR_ATTR_COLOR]); + he_bss_color->enabled = + !nla_get_flag(tb[NL80211_HE_BSS_COLOR_ATTR_DISABLED]); + he_bss_color->partial = + nla_get_flag(tb[NL80211_HE_BSS_COLOR_ATTR_PARTIAL]); + + return 0; +} + +static int nl80211_parse_beacon(struct cfg80211_registered_device *rdev, + struct nlattr *attrs[], + struct cfg80211_beacon_data *bcn) +{ + bool haveinfo = false; + int err; + + memset(bcn, 0, sizeof(*bcn)); + + bcn->link_id = nl80211_link_id(attrs); + + if (attrs[NL80211_ATTR_BEACON_HEAD]) { + bcn->head = nla_data(attrs[NL80211_ATTR_BEACON_HEAD]); + bcn->head_len = nla_len(attrs[NL80211_ATTR_BEACON_HEAD]); + if (!bcn->head_len) + return -EINVAL; + haveinfo = true; + } + + if (attrs[NL80211_ATTR_BEACON_TAIL]) { + bcn->tail = nla_data(attrs[NL80211_ATTR_BEACON_TAIL]); + bcn->tail_len = nla_len(attrs[NL80211_ATTR_BEACON_TAIL]); + haveinfo = true; + } + + if (!haveinfo) + return -EINVAL; + + if (attrs[NL80211_ATTR_IE]) { + bcn->beacon_ies = nla_data(attrs[NL80211_ATTR_IE]); + bcn->beacon_ies_len = nla_len(attrs[NL80211_ATTR_IE]); + } + + if (attrs[NL80211_ATTR_IE_PROBE_RESP]) { + bcn->proberesp_ies = + nla_data(attrs[NL80211_ATTR_IE_PROBE_RESP]); + bcn->proberesp_ies_len = + nla_len(attrs[NL80211_ATTR_IE_PROBE_RESP]); + } + + if (attrs[NL80211_ATTR_IE_ASSOC_RESP]) { + bcn->assocresp_ies = + nla_data(attrs[NL80211_ATTR_IE_ASSOC_RESP]); + bcn->assocresp_ies_len = + nla_len(attrs[NL80211_ATTR_IE_ASSOC_RESP]); + } + + if (attrs[NL80211_ATTR_PROBE_RESP]) { + bcn->probe_resp = nla_data(attrs[NL80211_ATTR_PROBE_RESP]); + bcn->probe_resp_len = nla_len(attrs[NL80211_ATTR_PROBE_RESP]); + } + + if (attrs[NL80211_ATTR_FTM_RESPONDER]) { + struct nlattr *tb[NL80211_FTM_RESP_ATTR_MAX + 1]; + + err = nla_parse_nested_deprecated(tb, + NL80211_FTM_RESP_ATTR_MAX, + attrs[NL80211_ATTR_FTM_RESPONDER], + NULL, NULL); + if (err) + return err; + + if (tb[NL80211_FTM_RESP_ATTR_ENABLED] && + 
wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_ENABLE_FTM_RESPONDER)) + bcn->ftm_responder = 1; + else + return -EOPNOTSUPP; + + if (tb[NL80211_FTM_RESP_ATTR_LCI]) { + bcn->lci = nla_data(tb[NL80211_FTM_RESP_ATTR_LCI]); + bcn->lci_len = nla_len(tb[NL80211_FTM_RESP_ATTR_LCI]); + } + + if (tb[NL80211_FTM_RESP_ATTR_CIVICLOC]) { + bcn->civicloc = nla_data(tb[NL80211_FTM_RESP_ATTR_CIVICLOC]); + bcn->civicloc_len = nla_len(tb[NL80211_FTM_RESP_ATTR_CIVICLOC]); + } + } else { + bcn->ftm_responder = -1; + } + + if (attrs[NL80211_ATTR_HE_BSS_COLOR]) { + err = nl80211_parse_he_bss_color(attrs[NL80211_ATTR_HE_BSS_COLOR], + &bcn->he_bss_color); + if (err) + return err; + bcn->he_bss_color_valid = true; + } + + if (attrs[NL80211_ATTR_MBSSID_ELEMS]) { + struct cfg80211_mbssid_elems *mbssid = + nl80211_parse_mbssid_elems(&rdev->wiphy, + attrs[NL80211_ATTR_MBSSID_ELEMS]); + + if (IS_ERR(mbssid)) + return PTR_ERR(mbssid); + + bcn->mbssid_ies = mbssid; + } + + return 0; +} + +static int nl80211_parse_he_obss_pd(struct nlattr *attrs, + struct ieee80211_he_obss_pd *he_obss_pd) +{ + struct nlattr *tb[NL80211_HE_OBSS_PD_ATTR_MAX + 1]; + int err; + + err = nla_parse_nested(tb, NL80211_HE_OBSS_PD_ATTR_MAX, attrs, + he_obss_pd_policy, NULL); + if (err) + return err; + + if (!tb[NL80211_HE_OBSS_PD_ATTR_SR_CTRL]) + return -EINVAL; + + he_obss_pd->sr_ctrl = nla_get_u8(tb[NL80211_HE_OBSS_PD_ATTR_SR_CTRL]); + + if (tb[NL80211_HE_OBSS_PD_ATTR_MIN_OFFSET]) + he_obss_pd->min_offset = + nla_get_u8(tb[NL80211_HE_OBSS_PD_ATTR_MIN_OFFSET]); + if (tb[NL80211_HE_OBSS_PD_ATTR_MAX_OFFSET]) + he_obss_pd->max_offset = + nla_get_u8(tb[NL80211_HE_OBSS_PD_ATTR_MAX_OFFSET]); + if (tb[NL80211_HE_OBSS_PD_ATTR_NON_SRG_MAX_OFFSET]) + he_obss_pd->non_srg_max_offset = + nla_get_u8(tb[NL80211_HE_OBSS_PD_ATTR_NON_SRG_MAX_OFFSET]); + + if (he_obss_pd->min_offset > he_obss_pd->max_offset) + return -EINVAL; + + if (tb[NL80211_HE_OBSS_PD_ATTR_BSS_COLOR_BITMAP]) + memcpy(he_obss_pd->bss_color_bitmap, + nla_data(tb[NL80211_HE_OBSS_PD_ATTR_BSS_COLOR_BITMAP]), + sizeof(he_obss_pd->bss_color_bitmap)); + + if (tb[NL80211_HE_OBSS_PD_ATTR_PARTIAL_BSSID_BITMAP]) + memcpy(he_obss_pd->partial_bssid_bitmap, + nla_data(tb[NL80211_HE_OBSS_PD_ATTR_PARTIAL_BSSID_BITMAP]), + sizeof(he_obss_pd->partial_bssid_bitmap)); + + he_obss_pd->enable = true; + + return 0; +} + +static int nl80211_parse_fils_discovery(struct cfg80211_registered_device *rdev, + struct nlattr *attrs, + struct cfg80211_ap_settings *params) +{ + struct nlattr *tb[NL80211_FILS_DISCOVERY_ATTR_MAX + 1]; + int ret; + struct cfg80211_fils_discovery *fd = &params->fils_discovery; + + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_FILS_DISCOVERY)) + return -EINVAL; + + ret = nla_parse_nested(tb, NL80211_FILS_DISCOVERY_ATTR_MAX, attrs, + NULL, NULL); + if (ret) + return ret; + + if (!tb[NL80211_FILS_DISCOVERY_ATTR_INT_MIN] || + !tb[NL80211_FILS_DISCOVERY_ATTR_INT_MAX] || + !tb[NL80211_FILS_DISCOVERY_ATTR_TMPL]) + return -EINVAL; + + fd->tmpl_len = nla_len(tb[NL80211_FILS_DISCOVERY_ATTR_TMPL]); + fd->tmpl = nla_data(tb[NL80211_FILS_DISCOVERY_ATTR_TMPL]); + fd->min_interval = nla_get_u32(tb[NL80211_FILS_DISCOVERY_ATTR_INT_MIN]); + fd->max_interval = nla_get_u32(tb[NL80211_FILS_DISCOVERY_ATTR_INT_MAX]); + + return 0; +} + +static int +nl80211_parse_unsol_bcast_probe_resp(struct cfg80211_registered_device *rdev, + struct nlattr *attrs, + struct cfg80211_ap_settings *params) +{ + struct nlattr *tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_MAX + 1]; + int ret; + struct
cfg80211_unsol_bcast_probe_resp *presp = + &params->unsol_bcast_probe_resp; + + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_UNSOL_BCAST_PROBE_RESP)) + return -EINVAL; + + ret = nla_parse_nested(tb, NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_MAX, + attrs, NULL, NULL); + if (ret) + return ret; + + if (!tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_INT] || + !tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL]) + return -EINVAL; + + presp->tmpl = nla_data(tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL]); + presp->tmpl_len = nla_len(tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_TMPL]); + presp->interval = nla_get_u32(tb[NL80211_UNSOL_BCAST_PROBE_RESP_ATTR_INT]); + return 0; +} + +static void nl80211_check_ap_rate_selectors(struct cfg80211_ap_settings *params, + const struct element *rates) +{ + int i; + + if (!rates) + return; + + for (i = 0; i < rates->datalen; i++) { + if (rates->data[i] == BSS_MEMBERSHIP_SELECTOR_HT_PHY) + params->ht_required = true; + if (rates->data[i] == BSS_MEMBERSHIP_SELECTOR_VHT_PHY) + params->vht_required = true; + if (rates->data[i] == BSS_MEMBERSHIP_SELECTOR_HE_PHY) + params->he_required = true; + if (rates->data[i] == BSS_MEMBERSHIP_SELECTOR_SAE_H2E) + params->sae_h2e_required = true; + } +} + +/* + * Since the nl80211 API didn't include, from the beginning, attributes about + * HT/VHT requirements/capabilities, we parse them out of the IEs for the + * benefit of drivers that rebuild IEs in the firmware. + */ +static int nl80211_calculate_ap_params(struct cfg80211_ap_settings *params) +{ + const struct cfg80211_beacon_data *bcn = &params->beacon; + size_t ies_len = bcn->tail_len; + const u8 *ies = bcn->tail; + const struct element *rates; + const struct element *cap; + + rates = cfg80211_find_elem(WLAN_EID_SUPP_RATES, ies, ies_len); + nl80211_check_ap_rate_selectors(params, rates); + + rates = cfg80211_find_elem(WLAN_EID_EXT_SUPP_RATES, ies, ies_len); + nl80211_check_ap_rate_selectors(params, rates); + + cap = cfg80211_find_elem(WLAN_EID_HT_CAPABILITY, ies, ies_len); + if (cap && cap->datalen >= sizeof(*params->ht_cap)) + params->ht_cap = (void *)cap->data; + cap = cfg80211_find_elem(WLAN_EID_VHT_CAPABILITY, ies, ies_len); + if (cap && cap->datalen >= sizeof(*params->vht_cap)) + params->vht_cap = (void *)cap->data; + cap = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_CAPABILITY, ies, ies_len); + if (cap && cap->datalen >= sizeof(*params->he_cap) + 1) + params->he_cap = (void *)(cap->data + 1); + cap = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_OPERATION, ies, ies_len); + if (cap && cap->datalen >= sizeof(*params->he_oper) + 1) + params->he_oper = (void *)(cap->data + 1); + cap = cfg80211_find_ext_elem(WLAN_EID_EXT_EHT_CAPABILITY, ies, ies_len); + if (cap) { + if (!cap->datalen) + return -EINVAL; + params->eht_cap = (void *)(cap->data + 1); + if (!ieee80211_eht_capa_size_ok((const u8 *)params->he_cap, + (const u8 *)params->eht_cap, + cap->datalen - 1, true)) + return -EINVAL; + } + cap = cfg80211_find_ext_elem(WLAN_EID_EXT_EHT_OPERATION, ies, ies_len); + if (cap) { + if (!cap->datalen) + return -EINVAL; + params->eht_oper = (void *)(cap->data + 1); + if (!ieee80211_eht_oper_size_ok((const u8 *)params->eht_oper, + cap->datalen - 1)) + return -EINVAL; + } + return 0; +} + +static bool nl80211_get_ap_channel(struct cfg80211_registered_device *rdev, + struct cfg80211_ap_settings *params) +{ + struct wireless_dev *wdev; + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO) + continue; + + if
(!wdev->u.ap.preset_chandef.chan) + continue; + + params->chandef = wdev->u.ap.preset_chandef; + return true; + } + + return false; +} + +static bool nl80211_valid_auth_type(struct cfg80211_registered_device *rdev, + enum nl80211_auth_type auth_type, + enum nl80211_commands cmd) +{ + if (auth_type > NL80211_AUTHTYPE_MAX) + return false; + + switch (cmd) { + case NL80211_CMD_AUTHENTICATE: + if (!(rdev->wiphy.features & NL80211_FEATURE_SAE) && + auth_type == NL80211_AUTHTYPE_SAE) + return false; + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_FILS_STA) && + (auth_type == NL80211_AUTHTYPE_FILS_SK || + auth_type == NL80211_AUTHTYPE_FILS_SK_PFS || + auth_type == NL80211_AUTHTYPE_FILS_PK)) + return false; + return true; + case NL80211_CMD_CONNECT: + if (!(rdev->wiphy.features & NL80211_FEATURE_SAE) && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_SAE_OFFLOAD) && + auth_type == NL80211_AUTHTYPE_SAE) + return false; + + /* FILS with SK PFS or PK not supported yet */ + if (auth_type == NL80211_AUTHTYPE_FILS_SK_PFS || + auth_type == NL80211_AUTHTYPE_FILS_PK) + return false; + if (!wiphy_ext_feature_isset( + &rdev->wiphy, + NL80211_EXT_FEATURE_FILS_SK_OFFLOAD) && + auth_type == NL80211_AUTHTYPE_FILS_SK) + return false; + return true; + case NL80211_CMD_START_AP: + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_SAE_OFFLOAD_AP) && + auth_type == NL80211_AUTHTYPE_SAE) + return false; + /* FILS not supported yet */ + if (auth_type == NL80211_AUTHTYPE_FILS_SK || + auth_type == NL80211_AUTHTYPE_FILS_SK_PFS || + auth_type == NL80211_AUTHTYPE_FILS_PK) + return false; + return true; + default: + return false; + } +} + +static int nl80211_start_ap(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + unsigned int link_id = nl80211_link_id(info->attrs); + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_ap_settings *params; + int err; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + if (!rdev->ops->start_ap) + return -EOPNOTSUPP; + + if (wdev->links[link_id].ap.beacon_interval) + return -EALREADY; + + /* these are required for START_AP */ + if (!info->attrs[NL80211_ATTR_BEACON_INTERVAL] || + !info->attrs[NL80211_ATTR_DTIM_PERIOD] || + !info->attrs[NL80211_ATTR_BEACON_HEAD]) + return -EINVAL; + + params = kzalloc(sizeof(*params), GFP_KERNEL); + if (!params) + return -ENOMEM; + + err = nl80211_parse_beacon(rdev, info->attrs, &params->beacon); + if (err) + goto out; + + params->beacon_interval = + nla_get_u32(info->attrs[NL80211_ATTR_BEACON_INTERVAL]); + params->dtim_period = + nla_get_u32(info->attrs[NL80211_ATTR_DTIM_PERIOD]); + + err = cfg80211_validate_beacon_int(rdev, dev->ieee80211_ptr->iftype, + params->beacon_interval); + if (err) + goto out; + + /* + * In theory, some of these attributes should be required here + * but since they were not used when the command was originally + * added, keep them optional for old user space programs to let + * them continue to work with drivers that do not need the + * additional information -- drivers must check!
+ */ + if (info->attrs[NL80211_ATTR_SSID]) { + params->ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + params->ssid_len = + nla_len(info->attrs[NL80211_ATTR_SSID]); + if (params->ssid_len == 0) { + err = -EINVAL; + goto out; + } + + if (wdev->u.ap.ssid_len && + (wdev->u.ap.ssid_len != params->ssid_len || + memcmp(wdev->u.ap.ssid, params->ssid, params->ssid_len))) { + /* require identical SSID for MLO */ + err = -EINVAL; + goto out; + } + } else if (wdev->valid_links) { + /* require SSID for MLO */ + err = -EINVAL; + goto out; + } + + if (info->attrs[NL80211_ATTR_HIDDEN_SSID]) + params->hidden_ssid = nla_get_u32( + info->attrs[NL80211_ATTR_HIDDEN_SSID]); + + params->privacy = !!info->attrs[NL80211_ATTR_PRIVACY]; + + if (info->attrs[NL80211_ATTR_AUTH_TYPE]) { + params->auth_type = nla_get_u32( + info->attrs[NL80211_ATTR_AUTH_TYPE]); + if (!nl80211_valid_auth_type(rdev, params->auth_type, + NL80211_CMD_START_AP)) { + err = -EINVAL; + goto out; + } + } else + params->auth_type = NL80211_AUTHTYPE_AUTOMATIC; + + err = nl80211_crypto_settings(rdev, info, &params->crypto, + NL80211_MAX_NR_CIPHER_SUITES); + if (err) + goto out; + + if (info->attrs[NL80211_ATTR_INACTIVITY_TIMEOUT]) { + if (!(rdev->wiphy.features & NL80211_FEATURE_INACTIVITY_TIMER)) { + err = -EOPNOTSUPP; + goto out; + } + params->inactivity_timeout = nla_get_u16( + info->attrs[NL80211_ATTR_INACTIVITY_TIMEOUT]); + } + + if (info->attrs[NL80211_ATTR_P2P_CTWINDOW]) { + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) { + err = -EINVAL; + goto out; + } + params->p2p_ctwindow = + nla_get_u8(info->attrs[NL80211_ATTR_P2P_CTWINDOW]); + if (params->p2p_ctwindow != 0 && + !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_CTWIN)) { + err = -EINVAL; + goto out; + } + } + + if (info->attrs[NL80211_ATTR_P2P_OPPPS]) { + u8 tmp; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) { + err = -EINVAL; + goto out; + } + tmp = nla_get_u8(info->attrs[NL80211_ATTR_P2P_OPPPS]); + params->p2p_opp_ps = tmp; + if (params->p2p_opp_ps != 0 && + !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_OPPPS)) { + err = -EINVAL; + goto out; + } + } + + if (info->attrs[NL80211_ATTR_WIPHY_FREQ]) { + err = nl80211_parse_chandef(rdev, info, &params->chandef); + if (err) + goto out; + } else if (wdev->valid_links) { + /* with MLD need to specify the channel configuration */ + err = -EINVAL; + goto out; + } else if (wdev->u.ap.preset_chandef.chan) { + params->chandef = wdev->u.ap.preset_chandef; + } else if (!nl80211_get_ap_channel(rdev, params)) { + err = -EINVAL; + goto out; + } + + if (!cfg80211_reg_can_beacon_relax(&rdev->wiphy, &params->chandef, + wdev->iftype)) { + err = -EINVAL; + goto out; + } + + wdev_lock(wdev); + + if (info->attrs[NL80211_ATTR_TX_RATES]) { + err = nl80211_parse_tx_bitrate_mask(info, info->attrs, + NL80211_ATTR_TX_RATES, + &params->beacon_rate, + dev, false, link_id); + if (err) + goto out_unlock; + + err = validate_beacon_tx_rate(rdev, params->chandef.chan->band, + &params->beacon_rate); + if (err) + goto out_unlock; + } + + if (info->attrs[NL80211_ATTR_SMPS_MODE]) { + params->smps_mode = + nla_get_u8(info->attrs[NL80211_ATTR_SMPS_MODE]); + switch (params->smps_mode) { + case NL80211_SMPS_OFF: + break; + case NL80211_SMPS_STATIC: + if (!(rdev->wiphy.features & + NL80211_FEATURE_STATIC_SMPS)) { + err = -EINVAL; + goto out_unlock; + } + break; + case NL80211_SMPS_DYNAMIC: + if (!(rdev->wiphy.features & + NL80211_FEATURE_DYNAMIC_SMPS)) { + err = -EINVAL; + goto out_unlock; + } + break; + default: + err = -EINVAL; + goto out_unlock; + } + } else {
params->smps_mode = NL80211_SMPS_OFF; + } + + params->pbss = nla_get_flag(info->attrs[NL80211_ATTR_PBSS]); + if (params->pbss && !rdev->wiphy.bands[NL80211_BAND_60GHZ]) { + err = -EOPNOTSUPP; + goto out_unlock; + } + + if (info->attrs[NL80211_ATTR_ACL_POLICY]) { + params->acl = parse_acl_data(&rdev->wiphy, info); + if (IS_ERR(params->acl)) { + err = PTR_ERR(params->acl); + params->acl = NULL; + goto out_unlock; + } + } + + params->twt_responder = + nla_get_flag(info->attrs[NL80211_ATTR_TWT_RESPONDER]); + + if (info->attrs[NL80211_ATTR_HE_OBSS_PD]) { + err = nl80211_parse_he_obss_pd( + info->attrs[NL80211_ATTR_HE_OBSS_PD], + &params->he_obss_pd); + if (err) + goto out_unlock; + } + + if (info->attrs[NL80211_ATTR_FILS_DISCOVERY]) { + err = nl80211_parse_fils_discovery(rdev, + info->attrs[NL80211_ATTR_FILS_DISCOVERY], + params); + if (err) + goto out_unlock; + } + + if (info->attrs[NL80211_ATTR_UNSOL_BCAST_PROBE_RESP]) { + err = nl80211_parse_unsol_bcast_probe_resp( + rdev, info->attrs[NL80211_ATTR_UNSOL_BCAST_PROBE_RESP], + params); + if (err) + goto out_unlock; + } + + if (info->attrs[NL80211_ATTR_MBSSID_CONFIG]) { + err = nl80211_parse_mbssid_config(&rdev->wiphy, dev, + info->attrs[NL80211_ATTR_MBSSID_CONFIG], + &params->mbssid_config, + params->beacon.mbssid_ies ? + params->beacon.mbssid_ies->cnt : + 0); + if (err) + goto out_unlock; + } + + err = nl80211_calculate_ap_params(params); + if (err) + goto out_unlock; + + if (info->attrs[NL80211_ATTR_AP_SETTINGS_FLAGS]) + params->flags = nla_get_u32( + info->attrs[NL80211_ATTR_AP_SETTINGS_FLAGS]); + else if (info->attrs[NL80211_ATTR_EXTERNAL_AUTH_SUPPORT]) + params->flags |= NL80211_AP_SETTINGS_EXTERNAL_AUTH_SUPPORT; + + if (wdev->conn_owner_nlportid && + info->attrs[NL80211_ATTR_SOCKET_OWNER] && + wdev->conn_owner_nlportid != info->snd_portid) { + err = -EINVAL; + goto out_unlock; + } + + /* FIXME: validate MLO/link-id against driver capabilities */ + + err = rdev_start_ap(rdev, dev, params); + if (!err) { + wdev->links[link_id].ap.beacon_interval = params->beacon_interval; + wdev->links[link_id].ap.chandef = params->chandef; + wdev->u.ap.ssid_len = params->ssid_len; + memcpy(wdev->u.ap.ssid, params->ssid, + params->ssid_len); + + if (info->attrs[NL80211_ATTR_SOCKET_OWNER]) + wdev->conn_owner_nlportid = info->snd_portid; + } +out_unlock: + wdev_unlock(wdev); +out: + kfree(params->acl); + kfree(params->beacon.mbssid_ies); + if (params->mbssid_config.tx_wdev && + params->mbssid_config.tx_wdev->netdev && + params->mbssid_config.tx_wdev->netdev != dev) + dev_put(params->mbssid_config.tx_wdev->netdev); + kfree(params); + + return err; +} + +static int nl80211_set_beacon(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + unsigned int link_id = nl80211_link_id(info->attrs); + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_beacon_data params; + int err; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + if (!rdev->ops->change_beacon) + return -EOPNOTSUPP; + + if (!wdev->links[link_id].ap.beacon_interval) + return -EINVAL; + + err = nl80211_parse_beacon(rdev, info->attrs, &params); + if (err) + goto out; + + wdev_lock(wdev); + err = rdev_change_beacon(rdev, dev, &params); + wdev_unlock(wdev); + +out: + kfree(params.mbssid_ies); + return err; +} + +static int nl80211_stop_ap(struct sk_buff *skb, struct genl_info *info) +{ + struct 
cfg80211_registered_device *rdev = info->user_ptr[0]; + unsigned int link_id = nl80211_link_id(info->attrs); + struct net_device *dev = info->user_ptr[1]; + + return cfg80211_stop_ap(rdev, dev, link_id, false); +} + +static const struct nla_policy sta_flags_policy[NL80211_STA_FLAG_MAX + 1] = { + [NL80211_STA_FLAG_AUTHORIZED] = { .type = NLA_FLAG }, + [NL80211_STA_FLAG_SHORT_PREAMBLE] = { .type = NLA_FLAG }, + [NL80211_STA_FLAG_WME] = { .type = NLA_FLAG }, + [NL80211_STA_FLAG_MFP] = { .type = NLA_FLAG }, + [NL80211_STA_FLAG_AUTHENTICATED] = { .type = NLA_FLAG }, + [NL80211_STA_FLAG_TDLS_PEER] = { .type = NLA_FLAG }, +}; + +static int parse_station_flags(struct genl_info *info, + enum nl80211_iftype iftype, + struct station_parameters *params) +{ + struct nlattr *flags[NL80211_STA_FLAG_MAX + 1]; + struct nlattr *nla; + int flag; + + /* + * Try parsing the new attribute first so userspace + * can specify both for older kernels. + */ + nla = info->attrs[NL80211_ATTR_STA_FLAGS2]; + if (nla) { + struct nl80211_sta_flag_update *sta_flags; + + sta_flags = nla_data(nla); + params->sta_flags_mask = sta_flags->mask; + params->sta_flags_set = sta_flags->set; + params->sta_flags_set &= params->sta_flags_mask; + if ((params->sta_flags_mask | + params->sta_flags_set) & BIT(__NL80211_STA_FLAG_INVALID)) + return -EINVAL; + return 0; + } + + /* if present, parse the old attribute */ + + nla = info->attrs[NL80211_ATTR_STA_FLAGS]; + if (!nla) + return 0; + + if (nla_parse_nested_deprecated(flags, NL80211_STA_FLAG_MAX, nla, sta_flags_policy, info->extack)) + return -EINVAL; + + /* + * Only allow certain flags for interface types so that + * other attributes are silently ignored. Remember that + * this is backward compatibility code with old userspace + * and shouldn't be hit in other cases anyway. + */ + switch (iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + params->sta_flags_mask = BIT(NL80211_STA_FLAG_AUTHORIZED) | + BIT(NL80211_STA_FLAG_SHORT_PREAMBLE) | + BIT(NL80211_STA_FLAG_WME) | + BIT(NL80211_STA_FLAG_MFP); + break; + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_STATION: + params->sta_flags_mask = BIT(NL80211_STA_FLAG_AUTHORIZED) | + BIT(NL80211_STA_FLAG_TDLS_PEER); + break; + case NL80211_IFTYPE_MESH_POINT: + params->sta_flags_mask = BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_MFP) | + BIT(NL80211_STA_FLAG_AUTHORIZED); + break; + default: + return -EINVAL; + } + + for (flag = 1; flag <= NL80211_STA_FLAG_MAX; flag++) { + if (flags[flag]) { + params->sta_flags_set |= (1<<flag); + + /* no longer support new API additions in old API */ + if (flag > NL80211_STA_FLAG_MAX_OLD_API) + return -EINVAL; + } + } + + return 0; +} + +bool nl80211_put_sta_rate(struct sk_buff *msg, struct rate_info *info, int attr) +{ + struct nlattr *rate; + u32 bitrate; + u16 bitrate_compat; + enum nl80211_rate_info rate_flg; + + rate = nla_nest_start_noflag(msg, attr); + if (!rate) + return false; + + /* cfg80211_calculate_bitrate will return 0 for mcs >= 32 */ + bitrate = cfg80211_calculate_bitrate(info); + /* report 16-bit bitrate only if we can */ + bitrate_compat = bitrate < (1UL << 16) ? 
bitrate : 0; + if (bitrate > 0 && + nla_put_u32(msg, NL80211_RATE_INFO_BITRATE32, bitrate)) + return false; + if (bitrate_compat > 0 && + nla_put_u16(msg, NL80211_RATE_INFO_BITRATE, bitrate_compat)) + return false; + + switch (info->bw) { + case RATE_INFO_BW_5: + rate_flg = NL80211_RATE_INFO_5_MHZ_WIDTH; + break; + case RATE_INFO_BW_10: + rate_flg = NL80211_RATE_INFO_10_MHZ_WIDTH; + break; + default: + WARN_ON(1); + fallthrough; + case RATE_INFO_BW_20: + rate_flg = 0; + break; + case RATE_INFO_BW_40: + rate_flg = NL80211_RATE_INFO_40_MHZ_WIDTH; + break; + case RATE_INFO_BW_80: + rate_flg = NL80211_RATE_INFO_80_MHZ_WIDTH; + break; + case RATE_INFO_BW_160: + rate_flg = NL80211_RATE_INFO_160_MHZ_WIDTH; + break; + case RATE_INFO_BW_HE_RU: + rate_flg = 0; + WARN_ON(!(info->flags & RATE_INFO_FLAGS_HE_MCS)); + break; + case RATE_INFO_BW_320: + rate_flg = NL80211_RATE_INFO_320_MHZ_WIDTH; + break; + case RATE_INFO_BW_EHT_RU: + rate_flg = 0; + WARN_ON(!(info->flags & RATE_INFO_FLAGS_EHT_MCS)); + break; + } + + if (rate_flg && nla_put_flag(msg, rate_flg)) + return false; + + if (info->flags & RATE_INFO_FLAGS_MCS) { + if (nla_put_u8(msg, NL80211_RATE_INFO_MCS, info->mcs)) + return false; + if (info->flags & RATE_INFO_FLAGS_SHORT_GI && + nla_put_flag(msg, NL80211_RATE_INFO_SHORT_GI)) + return false; + } else if (info->flags & RATE_INFO_FLAGS_VHT_MCS) { + if (nla_put_u8(msg, NL80211_RATE_INFO_VHT_MCS, info->mcs)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_VHT_NSS, info->nss)) + return false; + if (info->flags & RATE_INFO_FLAGS_SHORT_GI && + nla_put_flag(msg, NL80211_RATE_INFO_SHORT_GI)) + return false; + } else if (info->flags & RATE_INFO_FLAGS_HE_MCS) { + if (nla_put_u8(msg, NL80211_RATE_INFO_HE_MCS, info->mcs)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_HE_NSS, info->nss)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_HE_GI, info->he_gi)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_HE_DCM, info->he_dcm)) + return false; + if (info->bw == RATE_INFO_BW_HE_RU && + nla_put_u8(msg, NL80211_RATE_INFO_HE_RU_ALLOC, + info->he_ru_alloc)) + return false; + } else if (info->flags & RATE_INFO_FLAGS_EHT_MCS) { + if (nla_put_u8(msg, NL80211_RATE_INFO_EHT_MCS, info->mcs)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_EHT_NSS, info->nss)) + return false; + if (nla_put_u8(msg, NL80211_RATE_INFO_EHT_GI, info->eht_gi)) + return false; + if (info->bw == RATE_INFO_BW_EHT_RU && + nla_put_u8(msg, NL80211_RATE_INFO_EHT_RU_ALLOC, + info->eht_ru_alloc)) + return false; + } + + nla_nest_end(msg, rate); + return true; +} + +static bool nl80211_put_signal(struct sk_buff *msg, u8 mask, s8 *signal, + int id) +{ + void *attr; + int i = 0; + + if (!mask) + return true; + + attr = nla_nest_start_noflag(msg, id); + if (!attr) + return false; + + for (i = 0; i < IEEE80211_MAX_CHAINS; i++) { + if (!(mask & BIT(i))) + continue; + + if (nla_put_u8(msg, i, signal[i])) + return false; + } + + nla_nest_end(msg, attr); + + return true; +} + +static int nl80211_send_station(struct sk_buff *msg, u32 cmd, u32 portid, + u32 seq, int flags, + struct cfg80211_registered_device *rdev, + struct net_device *dev, + const u8 *mac_addr, struct station_info *sinfo) +{ + void *hdr; + struct nlattr *sinfoattr, *bss_param; + + hdr = nl80211hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) { + cfg80211_sinfo_release_content(sinfo); + return -1; + } + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, mac_addr) || + nla_put_u32(msg, 
NL80211_ATTR_GENERATION, sinfo->generation)) + goto nla_put_failure; + + sinfoattr = nla_nest_start_noflag(msg, NL80211_ATTR_STA_INFO); + if (!sinfoattr) + goto nla_put_failure; + +#define PUT_SINFO(attr, memb, type) do { \ + BUILD_BUG_ON(sizeof(type) == sizeof(u64)); \ + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_ ## attr) && \ + nla_put_ ## type(msg, NL80211_STA_INFO_ ## attr, \ + sinfo->memb)) \ + goto nla_put_failure; \ + } while (0) +#define PUT_SINFO_U64(attr, memb) do { \ + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_ ## attr) && \ + nla_put_u64_64bit(msg, NL80211_STA_INFO_ ## attr, \ + sinfo->memb, NL80211_STA_INFO_PAD)) \ + goto nla_put_failure; \ + } while (0) + + PUT_SINFO(CONNECTED_TIME, connected_time, u32); + PUT_SINFO(INACTIVE_TIME, inactive_time, u32); + PUT_SINFO_U64(ASSOC_AT_BOOTTIME, assoc_at); + + if (sinfo->filled & (BIT_ULL(NL80211_STA_INFO_RX_BYTES) | + BIT_ULL(NL80211_STA_INFO_RX_BYTES64)) && + nla_put_u32(msg, NL80211_STA_INFO_RX_BYTES, + (u32)sinfo->rx_bytes)) + goto nla_put_failure; + + if (sinfo->filled & (BIT_ULL(NL80211_STA_INFO_TX_BYTES) | + BIT_ULL(NL80211_STA_INFO_TX_BYTES64)) && + nla_put_u32(msg, NL80211_STA_INFO_TX_BYTES, + (u32)sinfo->tx_bytes)) + goto nla_put_failure; + + PUT_SINFO_U64(RX_BYTES64, rx_bytes); + PUT_SINFO_U64(TX_BYTES64, tx_bytes); + PUT_SINFO(LLID, llid, u16); + PUT_SINFO(PLID, plid, u16); + PUT_SINFO(PLINK_STATE, plink_state, u8); + PUT_SINFO_U64(RX_DURATION, rx_duration); + PUT_SINFO_U64(TX_DURATION, tx_duration); + + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + PUT_SINFO(AIRTIME_WEIGHT, airtime_weight, u16); + + switch (rdev->wiphy.signal_type) { + case CFG80211_SIGNAL_TYPE_MBM: + PUT_SINFO(SIGNAL, signal, u8); + PUT_SINFO(SIGNAL_AVG, signal_avg, u8); + break; + default: + break; + } + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL)) { + if (!nl80211_put_signal(msg, sinfo->chains, + sinfo->chain_signal, + NL80211_STA_INFO_CHAIN_SIGNAL)) + goto nla_put_failure; + } + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_CHAIN_SIGNAL_AVG)) { + if (!nl80211_put_signal(msg, sinfo->chains, + sinfo->chain_signal_avg, + NL80211_STA_INFO_CHAIN_SIGNAL_AVG)) + goto nla_put_failure; + } + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE)) { + if (!nl80211_put_sta_rate(msg, &sinfo->txrate, + NL80211_STA_INFO_TX_BITRATE)) + goto nla_put_failure; + } + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_RX_BITRATE)) { + if (!nl80211_put_sta_rate(msg, &sinfo->rxrate, + NL80211_STA_INFO_RX_BITRATE)) + goto nla_put_failure; + } + + PUT_SINFO(RX_PACKETS, rx_packets, u32); + PUT_SINFO(TX_PACKETS, tx_packets, u32); + PUT_SINFO(TX_RETRIES, tx_retries, u32); + PUT_SINFO(TX_FAILED, tx_failed, u32); + PUT_SINFO(EXPECTED_THROUGHPUT, expected_throughput, u32); + PUT_SINFO(AIRTIME_LINK_METRIC, airtime_link_metric, u32); + PUT_SINFO(BEACON_LOSS, beacon_loss_count, u32); + PUT_SINFO(LOCAL_PM, local_pm, u32); + PUT_SINFO(PEER_PM, peer_pm, u32); + PUT_SINFO(NONPEER_PM, nonpeer_pm, u32); + PUT_SINFO(CONNECTED_TO_GATE, connected_to_gate, u8); + PUT_SINFO(CONNECTED_TO_AS, connected_to_as, u8); + + if (sinfo->filled & BIT_ULL(NL80211_STA_INFO_BSS_PARAM)) { + bss_param = nla_nest_start_noflag(msg, + NL80211_STA_INFO_BSS_PARAM); + if (!bss_param) + goto nla_put_failure; + + if (((sinfo->bss_param.flags & BSS_PARAM_FLAGS_CTS_PROT) && + nla_put_flag(msg, NL80211_STA_BSS_PARAM_CTS_PROT)) || + ((sinfo->bss_param.flags & BSS_PARAM_FLAGS_SHORT_PREAMBLE) && + nla_put_flag(msg, NL80211_STA_BSS_PARAM_SHORT_PREAMBLE)) || + 
((sinfo->bss_param.flags & BSS_PARAM_FLAGS_SHORT_SLOT_TIME) && + nla_put_flag(msg, NL80211_STA_BSS_PARAM_SHORT_SLOT_TIME)) || + nla_put_u8(msg, NL80211_STA_BSS_PARAM_DTIM_PERIOD, + sinfo->bss_param.dtim_period) || + nla_put_u16(msg, NL80211_STA_BSS_PARAM_BEACON_INTERVAL, + sinfo->bss_param.beacon_interval)) + goto nla_put_failure; + + nla_nest_end(msg, bss_param); + } + if ((sinfo->filled & BIT_ULL(NL80211_STA_INFO_STA_FLAGS)) && + nla_put(msg, NL80211_STA_INFO_STA_FLAGS, + sizeof(struct nl80211_sta_flag_update), + &sinfo->sta_flags)) + goto nla_put_failure; + + PUT_SINFO_U64(T_OFFSET, t_offset); + PUT_SINFO_U64(RX_DROP_MISC, rx_dropped_misc); + PUT_SINFO_U64(BEACON_RX, rx_beacon); + PUT_SINFO(BEACON_SIGNAL_AVG, rx_beacon_signal_avg, u8); + PUT_SINFO(RX_MPDUS, rx_mpdu_count, u32); + PUT_SINFO(FCS_ERROR_COUNT, fcs_err_count, u32); + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_ACK_SIGNAL_SUPPORT)) { + PUT_SINFO(ACK_SIGNAL, ack_signal, u8); + PUT_SINFO(ACK_SIGNAL_AVG, avg_ack_signal, s8); + } + +#undef PUT_SINFO +#undef PUT_SINFO_U64 + + if (sinfo->pertid) { + struct nlattr *tidsattr; + int tid; + + tidsattr = nla_nest_start_noflag(msg, + NL80211_STA_INFO_TID_STATS); + if (!tidsattr) + goto nla_put_failure; + + for (tid = 0; tid < IEEE80211_NUM_TIDS + 1; tid++) { + struct cfg80211_tid_stats *tidstats; + struct nlattr *tidattr; + + tidstats = &sinfo->pertid[tid]; + + if (!tidstats->filled) + continue; + + tidattr = nla_nest_start_noflag(msg, tid + 1); + if (!tidattr) + goto nla_put_failure; + +#define PUT_TIDVAL_U64(attr, memb) do { \ + if (tidstats->filled & BIT(NL80211_TID_STATS_ ## attr) && \ + nla_put_u64_64bit(msg, NL80211_TID_STATS_ ## attr, \ + tidstats->memb, NL80211_TID_STATS_PAD)) \ + goto nla_put_failure; \ + } while (0) + + PUT_TIDVAL_U64(RX_MSDU, rx_msdu); + PUT_TIDVAL_U64(TX_MSDU, tx_msdu); + PUT_TIDVAL_U64(TX_MSDU_RETRIES, tx_msdu_retries); + PUT_TIDVAL_U64(TX_MSDU_FAILED, tx_msdu_failed); + +#undef PUT_TIDVAL_U64 + if ((tidstats->filled & + BIT(NL80211_TID_STATS_TXQ_STATS)) && + !nl80211_put_txq_stats(msg, &tidstats->txq_stats, + NL80211_TID_STATS_TXQ_STATS)) + goto nla_put_failure; + + nla_nest_end(msg, tidattr); + } + + nla_nest_end(msg, tidsattr); + } + + nla_nest_end(msg, sinfoattr); + + if (sinfo->assoc_req_ies_len && + nla_put(msg, NL80211_ATTR_IE, sinfo->assoc_req_ies_len, + sinfo->assoc_req_ies)) + goto nla_put_failure; + + cfg80211_sinfo_release_content(sinfo); + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + cfg80211_sinfo_release_content(sinfo); + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_station(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct station_info sinfo; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + u8 mac_addr[ETH_ALEN]; + int sta_idx = cb->args[2]; + int err; + + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev, NULL); + if (err) + return err; + /* nl80211_prepare_wdev_dump acquired it in the successful case */ + __acquire(&rdev->wiphy.mtx); + + if (!wdev->netdev) { + err = -EINVAL; + goto out_err; + } + + if (!rdev->ops->dump_station) { + err = -EOPNOTSUPP; + goto out_err; + } + + while (1) { + memset(&sinfo, 0, sizeof(sinfo)); + err = rdev_dump_station(rdev, wdev->netdev, sta_idx, + mac_addr, &sinfo); + if (err == -ENOENT) + break; + if (err) + goto out_err; + + if (nl80211_send_station(skb, NL80211_CMD_NEW_STATION, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wdev->netdev, mac_addr, + &sinfo) < 0) + goto out; + + 
sta_idx++; + } + + out: + cb->args[2] = sta_idx; + err = skb->len; + out_err: + wiphy_unlock(&rdev->wiphy); + + return err; +} + +static int nl80211_get_station(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct station_info sinfo; + struct sk_buff *msg; + u8 *mac_addr = NULL; + int err; + + memset(&sinfo, 0, sizeof(sinfo)); + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!rdev->ops->get_station) + return -EOPNOTSUPP; + + err = rdev_get_station(rdev, dev, mac_addr, &sinfo); + if (err) + return err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + cfg80211_sinfo_release_content(&sinfo); + return -ENOMEM; + } + + if (nl80211_send_station(msg, NL80211_CMD_NEW_STATION, + info->snd_portid, info->snd_seq, 0, + rdev, dev, mac_addr, &sinfo) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +int cfg80211_check_station_change(struct wiphy *wiphy, + struct station_parameters *params, + enum cfg80211_station_type statype) +{ + if (params->listen_interval != -1 && + statype != CFG80211_STA_AP_CLIENT_UNASSOC) + return -EINVAL; + + if (params->support_p2p_ps != -1 && + statype != CFG80211_STA_AP_CLIENT_UNASSOC) + return -EINVAL; + + if (params->aid && + !(params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER)) && + statype != CFG80211_STA_AP_CLIENT_UNASSOC) + return -EINVAL; + + /* When you run into this, adjust the code below for the new flag */ + BUILD_BUG_ON(NL80211_STA_FLAG_MAX != 7); + + switch (statype) { + case CFG80211_STA_MESH_PEER_KERNEL: + case CFG80211_STA_MESH_PEER_USER: + /* + * No ignoring the TDLS flag here -- the userspace mesh + * code doesn't have the bug of including TDLS in the + * mask everywhere. + */ + if (params->sta_flags_mask & + ~(BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_MFP) | + BIT(NL80211_STA_FLAG_AUTHORIZED))) + return -EINVAL; + break; + case CFG80211_STA_TDLS_PEER_SETUP: + case CFG80211_STA_TDLS_PEER_ACTIVE: + if (!(params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER))) + return -EINVAL; + /* ignore since it can't change */ + params->sta_flags_mask &= ~BIT(NL80211_STA_FLAG_TDLS_PEER); + break; + default: + /* disallow mesh-specific things */ + if (params->plink_action != NL80211_PLINK_ACTION_NO_ACTION) + return -EINVAL; + if (params->local_pm) + return -EINVAL; + if (params->sta_modify_mask & STATION_PARAM_APPLY_PLINK_STATE) + return -EINVAL; + } + + if (statype != CFG80211_STA_TDLS_PEER_SETUP && + statype != CFG80211_STA_TDLS_PEER_ACTIVE) { + /* TDLS can't be set, ... */ + if (params->sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER)) + return -EINVAL; + /* + * ... but don't bother the driver with it. This works around + * a hostapd/wpa_supplicant issue -- it always includes the + * TLDS_PEER flag in the mask even for AP mode. 
+ */ + params->sta_flags_mask &= ~BIT(NL80211_STA_FLAG_TDLS_PEER); + } + + if (statype != CFG80211_STA_TDLS_PEER_SETUP && + statype != CFG80211_STA_AP_CLIENT_UNASSOC) { + /* reject other things that can't change */ + if (params->sta_modify_mask & STATION_PARAM_APPLY_UAPSD) + return -EINVAL; + if (params->sta_modify_mask & STATION_PARAM_APPLY_CAPABILITY) + return -EINVAL; + if (params->link_sta_params.supported_rates) + return -EINVAL; + if (params->ext_capab || params->link_sta_params.ht_capa || + params->link_sta_params.vht_capa || + params->link_sta_params.he_capa || + params->link_sta_params.eht_capa) + return -EINVAL; + } + + if (statype != CFG80211_STA_AP_CLIENT && + statype != CFG80211_STA_AP_CLIENT_UNASSOC) { + if (params->vlan) + return -EINVAL; + } + + switch (statype) { + case CFG80211_STA_AP_MLME_CLIENT: + /* Use this only for authorizing/unauthorizing a station */ + if (!(params->sta_flags_mask & BIT(NL80211_STA_FLAG_AUTHORIZED))) + return -EOPNOTSUPP; + break; + case CFG80211_STA_AP_CLIENT: + case CFG80211_STA_AP_CLIENT_UNASSOC: + /* accept only the listed bits */ + if (params->sta_flags_mask & + ~(BIT(NL80211_STA_FLAG_AUTHORIZED) | + BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED) | + BIT(NL80211_STA_FLAG_SHORT_PREAMBLE) | + BIT(NL80211_STA_FLAG_WME) | + BIT(NL80211_STA_FLAG_MFP))) + return -EINVAL; + + /* but authenticated/associated only if driver handles it */ + if (!(wiphy->features & NL80211_FEATURE_FULL_AP_CLIENT_STATE) && + params->sta_flags_mask & + (BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED))) + return -EINVAL; + break; + case CFG80211_STA_IBSS: + case CFG80211_STA_AP_STA: + /* reject any changes other than AUTHORIZED */ + if (params->sta_flags_mask & ~BIT(NL80211_STA_FLAG_AUTHORIZED)) + return -EINVAL; + break; + case CFG80211_STA_TDLS_PEER_SETUP: + /* reject any changes other than AUTHORIZED or WME */ + if (params->sta_flags_mask & ~(BIT(NL80211_STA_FLAG_AUTHORIZED) | + BIT(NL80211_STA_FLAG_WME))) + return -EINVAL; + /* force (at least) rates when authorizing */ + if (params->sta_flags_set & BIT(NL80211_STA_FLAG_AUTHORIZED) && + !params->link_sta_params.supported_rates) + return -EINVAL; + break; + case CFG80211_STA_TDLS_PEER_ACTIVE: + /* reject any changes */ + return -EINVAL; + case CFG80211_STA_MESH_PEER_KERNEL: + if (params->sta_modify_mask & STATION_PARAM_APPLY_PLINK_STATE) + return -EINVAL; + break; + case CFG80211_STA_MESH_PEER_USER: + if (params->plink_action != NL80211_PLINK_ACTION_NO_ACTION && + params->plink_action != NL80211_PLINK_ACTION_BLOCK) + return -EINVAL; + break; + } + + /* + * Older kernel versions ignored this attribute entirely, so don't + * reject attempts to update it but mark it as unused instead so the + * driver won't look at the data. + */ + if (statype != CFG80211_STA_AP_CLIENT_UNASSOC && + statype != CFG80211_STA_TDLS_PEER_SETUP) + params->link_sta_params.opmode_notif_used = false; + + return 0; +} +EXPORT_SYMBOL(cfg80211_check_station_change); + +/* + * Get vlan interface making sure it is running and on the right wiphy. 
+ */ +static struct net_device *get_vlan(struct genl_info *info, + struct cfg80211_registered_device *rdev) +{ + struct nlattr *vlanattr = info->attrs[NL80211_ATTR_STA_VLAN]; + struct net_device *v; + int ret; + + if (!vlanattr) + return NULL; + + v = dev_get_by_index(genl_info_net(info), nla_get_u32(vlanattr)); + if (!v) + return ERR_PTR(-ENODEV); + + if (!v->ieee80211_ptr || v->ieee80211_ptr->wiphy != &rdev->wiphy) { + ret = -EINVAL; + goto error; + } + + if (v->ieee80211_ptr->iftype != NL80211_IFTYPE_AP_VLAN && + v->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + v->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) { + ret = -EINVAL; + goto error; + } + + if (!netif_running(v)) { + ret = -ENETDOWN; + goto error; + } + + return v; + error: + dev_put(v); + return ERR_PTR(ret); +} + +static int nl80211_parse_sta_wme(struct genl_info *info, + struct station_parameters *params) +{ + struct nlattr *tb[NL80211_STA_WME_MAX + 1]; + struct nlattr *nla; + int err; + + /* parse WME attributes if present */ + if (!info->attrs[NL80211_ATTR_STA_WME]) + return 0; + + nla = info->attrs[NL80211_ATTR_STA_WME]; + err = nla_parse_nested_deprecated(tb, NL80211_STA_WME_MAX, nla, + nl80211_sta_wme_policy, + info->extack); + if (err) + return err; + + if (tb[NL80211_STA_WME_UAPSD_QUEUES]) + params->uapsd_queues = nla_get_u8( + tb[NL80211_STA_WME_UAPSD_QUEUES]); + if (params->uapsd_queues & ~IEEE80211_WMM_IE_STA_QOSINFO_AC_MASK) + return -EINVAL; + + if (tb[NL80211_STA_WME_MAX_SP]) + params->max_sp = nla_get_u8(tb[NL80211_STA_WME_MAX_SP]); + + if (params->max_sp & ~IEEE80211_WMM_IE_STA_QOSINFO_SP_MASK) + return -EINVAL; + + params->sta_modify_mask |= STATION_PARAM_APPLY_UAPSD; + + return 0; +} + +static int nl80211_parse_sta_channel_info(struct genl_info *info, + struct station_parameters *params) +{ + if (info->attrs[NL80211_ATTR_STA_SUPPORTED_CHANNELS]) { + params->supported_channels = + nla_data(info->attrs[NL80211_ATTR_STA_SUPPORTED_CHANNELS]); + params->supported_channels_len = + nla_len(info->attrs[NL80211_ATTR_STA_SUPPORTED_CHANNELS]); + /* + * Need to include at least one (first channel, number of + * channels) tuple for each subband (checked in policy), + * and must have proper tuples for the rest of the data as well. 
+ */ + if (params->supported_channels_len % 2) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES]) { + params->supported_oper_classes = + nla_data(info->attrs[NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES]); + params->supported_oper_classes_len = + nla_len(info->attrs[NL80211_ATTR_STA_SUPPORTED_OPER_CLASSES]); + } + return 0; +} + +static int nl80211_set_station_tdls(struct genl_info *info, + struct station_parameters *params) +{ + int err; + /* Dummy STA entry gets updated once the peer capabilities are known */ + if (info->attrs[NL80211_ATTR_PEER_AID]) + params->aid = nla_get_u16(info->attrs[NL80211_ATTR_PEER_AID]); + if (info->attrs[NL80211_ATTR_HT_CAPABILITY]) + params->link_sta_params.ht_capa = + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]); + if (info->attrs[NL80211_ATTR_VHT_CAPABILITY]) + params->link_sta_params.vht_capa = + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY]); + if (info->attrs[NL80211_ATTR_HE_CAPABILITY]) { + params->link_sta_params.he_capa = + nla_data(info->attrs[NL80211_ATTR_HE_CAPABILITY]); + params->link_sta_params.he_capa_len = + nla_len(info->attrs[NL80211_ATTR_HE_CAPABILITY]); + + if (info->attrs[NL80211_ATTR_EHT_CAPABILITY]) { + params->link_sta_params.eht_capa = + nla_data(info->attrs[NL80211_ATTR_EHT_CAPABILITY]); + params->link_sta_params.eht_capa_len = + nla_len(info->attrs[NL80211_ATTR_EHT_CAPABILITY]); + + if (!ieee80211_eht_capa_size_ok((const u8 *)params->link_sta_params.he_capa, + (const u8 *)params->link_sta_params.eht_capa, + params->link_sta_params.eht_capa_len, + false)) + return -EINVAL; + } + } + + err = nl80211_parse_sta_channel_info(info, params); + if (err) + return err; + + return nl80211_parse_sta_wme(info, params); +} + +static int nl80211_parse_sta_txpower_setting(struct genl_info *info, + struct sta_txpwr *txpwr, + bool *txpwr_set) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int idx; + + if (info->attrs[NL80211_ATTR_STA_TX_POWER_SETTING]) { + if (!rdev->ops->set_tx_power || + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_STA_TX_PWR)) + return -EOPNOTSUPP; + + idx = NL80211_ATTR_STA_TX_POWER_SETTING; + txpwr->type = nla_get_u8(info->attrs[idx]); + + if (txpwr->type == NL80211_TX_POWER_LIMITED) { + idx = NL80211_ATTR_STA_TX_POWER; + + if (info->attrs[idx]) + txpwr->power = nla_get_s16(info->attrs[idx]); + else + return -EINVAL; + } + + *txpwr_set = true; + } else { + *txpwr_set = false; + } + + return 0; +} + +static int nl80211_set_station(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct station_parameters params; + u8 *mac_addr; + int err; + + memset(¶ms, 0, sizeof(params)); + + if (!rdev->ops->change_station) + return -EOPNOTSUPP; + + /* + * AID and listen_interval properties can be set only for unassociated + * station. Include these parameters here and will check them in + * cfg80211_check_station_change(). 
+ */ + if (info->attrs[NL80211_ATTR_STA_AID]) + params.aid = nla_get_u16(info->attrs[NL80211_ATTR_STA_AID]); + + if (info->attrs[NL80211_ATTR_VLAN_ID]) + params.vlan_id = nla_get_u16(info->attrs[NL80211_ATTR_VLAN_ID]); + + if (info->attrs[NL80211_ATTR_STA_LISTEN_INTERVAL]) + params.listen_interval = + nla_get_u16(info->attrs[NL80211_ATTR_STA_LISTEN_INTERVAL]); + else + params.listen_interval = -1; + + if (info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]) + params.support_p2p_ps = + nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); + else + params.support_p2p_ps = -1; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + params.link_sta_params.link_id = + nl80211_link_id_or_invalid(info->attrs); + + if (info->attrs[NL80211_ATTR_MLD_ADDR]) { + /* If MLD_ADDR attribute is set then this is an MLD station + * and the MLD_ADDR attribute holds the MLD address and the + * MAC attribute holds for the LINK address. + * In that case, the link_id is also expected to be valid. + */ + if (params.link_sta_params.link_id < 0) + return -EINVAL; + + mac_addr = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]); + params.link_sta_params.mld_mac = mac_addr; + params.link_sta_params.link_mac = + nla_data(info->attrs[NL80211_ATTR_MAC]); + if (!is_valid_ether_addr(params.link_sta_params.link_mac)) + return -EINVAL; + } else { + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + } + + + if (info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]) { + params.link_sta_params.supported_rates = + nla_data(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]); + params.link_sta_params.supported_rates_len = + nla_len(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]); + } + + if (info->attrs[NL80211_ATTR_STA_CAPABILITY]) { + params.capability = + nla_get_u16(info->attrs[NL80211_ATTR_STA_CAPABILITY]); + params.sta_modify_mask |= STATION_PARAM_APPLY_CAPABILITY; + } + + if (info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]) { + params.ext_capab = + nla_data(info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]); + params.ext_capab_len = + nla_len(info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]); + } + + if (parse_station_flags(info, dev->ieee80211_ptr->iftype, &params)) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) + params.plink_action = + nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_ACTION]); + + if (info->attrs[NL80211_ATTR_STA_PLINK_STATE]) { + params.plink_state = + nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_STATE]); + if (info->attrs[NL80211_ATTR_MESH_PEER_AID]) + params.peer_aid = nla_get_u16( + info->attrs[NL80211_ATTR_MESH_PEER_AID]); + params.sta_modify_mask |= STATION_PARAM_APPLY_PLINK_STATE; + } + + if (info->attrs[NL80211_ATTR_LOCAL_MESH_POWER_MODE]) + params.local_pm = nla_get_u32( + info->attrs[NL80211_ATTR_LOCAL_MESH_POWER_MODE]); + + if (info->attrs[NL80211_ATTR_OPMODE_NOTIF]) { + params.link_sta_params.opmode_notif_used = true; + params.link_sta_params.opmode_notif = + nla_get_u8(info->attrs[NL80211_ATTR_OPMODE_NOTIF]); + } + + if (info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY]) + params.link_sta_params.he_6ghz_capa = + nla_data(info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY]); + + if (info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]) + params.airtime_weight = + nla_get_u16(info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]); + + if (params.airtime_weight && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + return -EOPNOTSUPP; + + err = nl80211_parse_sta_txpower_setting(info, + &params.link_sta_params.txpwr, + &params.link_sta_params.txpwr_set); + if (err) + return err; + + /* Include parameters for TDLS 
peer (will check later) */ + err = nl80211_set_station_tdls(info, &params); + if (err) + return err; + + params.vlan = get_vlan(info, rdev); + if (IS_ERR(params.vlan)) + return PTR_ERR(params.vlan); + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_MESH_POINT: + break; + default: + err = -EOPNOTSUPP; + goto out_put_vlan; + } + + /* driver will call cfg80211_check_station_change() */ + wdev_lock(dev->ieee80211_ptr); + err = rdev_change_station(rdev, dev, mac_addr, &params); + wdev_unlock(dev->ieee80211_ptr); + + out_put_vlan: + dev_put(params.vlan); + + return err; +} + +static int nl80211_new_station(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct station_parameters params; + u8 *mac_addr = NULL; + u32 auth_assoc = BIT(NL80211_STA_FLAG_AUTHENTICATED) | + BIT(NL80211_STA_FLAG_ASSOCIATED); + + memset(&params, 0, sizeof(params)); + + if (!rdev->ops->add_station) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_STA_LISTEN_INTERVAL]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_STA_AID] && + !info->attrs[NL80211_ATTR_PEER_AID]) + return -EINVAL; + + params.link_sta_params.link_id = + nl80211_link_id_or_invalid(info->attrs); + + if (info->attrs[NL80211_ATTR_MLD_ADDR]) { + mac_addr = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]); + params.link_sta_params.mld_mac = mac_addr; + params.link_sta_params.link_mac = + nla_data(info->attrs[NL80211_ATTR_MAC]); + if (!is_valid_ether_addr(params.link_sta_params.link_mac)) + return -EINVAL; + } else { + mac_addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + } + + params.link_sta_params.supported_rates = + nla_data(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]); + params.link_sta_params.supported_rates_len = + nla_len(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]); + params.listen_interval = + nla_get_u16(info->attrs[NL80211_ATTR_STA_LISTEN_INTERVAL]); + + if (info->attrs[NL80211_ATTR_VLAN_ID]) + params.vlan_id = nla_get_u16(info->attrs[NL80211_ATTR_VLAN_ID]); + + if (info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]) { + params.support_p2p_ps = + nla_get_u8(info->attrs[NL80211_ATTR_STA_SUPPORT_P2P_PS]); + } else { + /* + * if not specified, assume it's supported for P2P GO interface, + * and is NOT supported for AP interface + */ + params.support_p2p_ps = + dev->ieee80211_ptr->iftype == NL80211_IFTYPE_P2P_GO; + } + + if (info->attrs[NL80211_ATTR_PEER_AID]) + params.aid = nla_get_u16(info->attrs[NL80211_ATTR_PEER_AID]); + else + params.aid = nla_get_u16(info->attrs[NL80211_ATTR_STA_AID]); + + if (info->attrs[NL80211_ATTR_STA_CAPABILITY]) { + params.capability = + nla_get_u16(info->attrs[NL80211_ATTR_STA_CAPABILITY]); + params.sta_modify_mask |= STATION_PARAM_APPLY_CAPABILITY; + } + + if (info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]) { + params.ext_capab = + nla_data(info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]); + params.ext_capab_len = + nla_len(info->attrs[NL80211_ATTR_STA_EXT_CAPABILITY]); + } + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY]) + params.link_sta_params.ht_capa = + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]); + + if 
(info->attrs[NL80211_ATTR_VHT_CAPABILITY]) + params.link_sta_params.vht_capa = + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY]); + + if (info->attrs[NL80211_ATTR_HE_CAPABILITY]) { + params.link_sta_params.he_capa = + nla_data(info->attrs[NL80211_ATTR_HE_CAPABILITY]); + params.link_sta_params.he_capa_len = + nla_len(info->attrs[NL80211_ATTR_HE_CAPABILITY]); + + if (info->attrs[NL80211_ATTR_EHT_CAPABILITY]) { + params.link_sta_params.eht_capa = + nla_data(info->attrs[NL80211_ATTR_EHT_CAPABILITY]); + params.link_sta_params.eht_capa_len = + nla_len(info->attrs[NL80211_ATTR_EHT_CAPABILITY]); + + if (!ieee80211_eht_capa_size_ok((const u8 *)params.link_sta_params.he_capa, + (const u8 *)params.link_sta_params.eht_capa, + params.link_sta_params.eht_capa_len, + false)) + return -EINVAL; + } + } + + if (info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY]) + params.link_sta_params.he_6ghz_capa = + nla_data(info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY]); + + if (info->attrs[NL80211_ATTR_OPMODE_NOTIF]) { + params.link_sta_params.opmode_notif_used = true; + params.link_sta_params.opmode_notif = + nla_get_u8(info->attrs[NL80211_ATTR_OPMODE_NOTIF]); + } + + if (info->attrs[NL80211_ATTR_STA_PLINK_ACTION]) + params.plink_action = + nla_get_u8(info->attrs[NL80211_ATTR_STA_PLINK_ACTION]); + + if (info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]) + params.airtime_weight = + nla_get_u16(info->attrs[NL80211_ATTR_AIRTIME_WEIGHT]); + + if (params.airtime_weight && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AIRTIME_FAIRNESS)) + return -EOPNOTSUPP; + + err = nl80211_parse_sta_txpower_setting(info, + &params.link_sta_params.txpwr, + &params.link_sta_params.txpwr_set); + if (err) + return err; + + err = nl80211_parse_sta_channel_info(info, &params); + if (err) + return err; + + err = nl80211_parse_sta_wme(info, &params); + if (err) + return err; + + if (parse_station_flags(info, dev->ieee80211_ptr->iftype, &params)) + return -EINVAL; + + /* HT/VHT requires QoS, but if we don't have that just ignore HT/VHT + * as userspace might just pass through the capabilities from the IEs + * directly, rather than enforcing this restriction and returning an + * error in this case. 
+ */ + if (!(params.sta_flags_set & BIT(NL80211_STA_FLAG_WME))) { + params.link_sta_params.ht_capa = NULL; + params.link_sta_params.vht_capa = NULL; + + /* HE and EHT require WME */ + if (params.link_sta_params.he_capa_len || + params.link_sta_params.he_6ghz_capa || + params.link_sta_params.eht_capa_len) + return -EINVAL; + } + + /* Ensure that HT/VHT capabilities are not set for 6 GHz HE STA */ + if (params.link_sta_params.he_6ghz_capa && + (params.link_sta_params.ht_capa || params.link_sta_params.vht_capa)) + return -EINVAL; + + /* When you run into this, adjust the code below for the new flag */ + BUILD_BUG_ON(NL80211_STA_FLAG_MAX != 7); + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + /* ignore WME attributes if iface/sta is not capable */ + if (!(rdev->wiphy.flags & WIPHY_FLAG_AP_UAPSD) || + !(params.sta_flags_set & BIT(NL80211_STA_FLAG_WME))) + params.sta_modify_mask &= ~STATION_PARAM_APPLY_UAPSD; + + /* TDLS peers cannot be added */ + if ((params.sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER)) || + info->attrs[NL80211_ATTR_PEER_AID]) + return -EINVAL; + /* but don't bother the driver with it */ + params.sta_flags_mask &= ~BIT(NL80211_STA_FLAG_TDLS_PEER); + + /* allow authenticated/associated only if driver handles it */ + if (!(rdev->wiphy.features & + NL80211_FEATURE_FULL_AP_CLIENT_STATE) && + params.sta_flags_mask & auth_assoc) + return -EINVAL; + + /* Older userspace, or userspace wanting to be compatible with + * !NL80211_FEATURE_FULL_AP_CLIENT_STATE, will not set the auth + * and assoc flags in the mask, but assumes the station will be + * added as associated anyway since this was the required driver + * behaviour before NL80211_FEATURE_FULL_AP_CLIENT_STATE was + * introduced. + * In order to not bother drivers with this quirk in the API + * set the flags in both the mask and set for new stations in + * this case. + */ + if (!(params.sta_flags_mask & auth_assoc)) { + params.sta_flags_mask |= auth_assoc; + params.sta_flags_set |= auth_assoc; + } + + /* must be last in here for error handling */ + params.vlan = get_vlan(info, rdev); + if (IS_ERR(params.vlan)) + return PTR_ERR(params.vlan); + break; + case NL80211_IFTYPE_MESH_POINT: + /* ignore uAPSD data */ + params.sta_modify_mask &= ~STATION_PARAM_APPLY_UAPSD; + + /* associated is disallowed */ + if (params.sta_flags_mask & BIT(NL80211_STA_FLAG_ASSOCIATED)) + return -EINVAL; + /* TDLS peers cannot be added */ + if ((params.sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER)) || + info->attrs[NL80211_ATTR_PEER_AID]) + return -EINVAL; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + /* ignore uAPSD data */ + params.sta_modify_mask &= ~STATION_PARAM_APPLY_UAPSD; + + /* these are disallowed */ + if (params.sta_flags_mask & + (BIT(NL80211_STA_FLAG_ASSOCIATED) | + BIT(NL80211_STA_FLAG_AUTHENTICATED))) + return -EINVAL; + /* Only TDLS peers can be added */ + if (!(params.sta_flags_set & BIT(NL80211_STA_FLAG_TDLS_PEER))) + return -EINVAL; + /* Can only add if TDLS ... */ + if (!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_TDLS)) + return -EOPNOTSUPP; + /* ... with external setup is supported */ + if (!(rdev->wiphy.flags & WIPHY_FLAG_TDLS_EXTERNAL_SETUP)) + return -EOPNOTSUPP; + /* + * Older wpa_supplicant versions always mark the TDLS peer + * as authorized, but it shouldn't yet be. 
+ */ + params.sta_flags_mask &= ~BIT(NL80211_STA_FLAG_AUTHORIZED); + break; + default: + return -EOPNOTSUPP; + } + + /* be aware of params.vlan when changing code here */ + + wdev_lock(dev->ieee80211_ptr); + if (wdev->valid_links) { + if (params.link_sta_params.link_id < 0) { + err = -EINVAL; + goto out; + } + if (!(wdev->valid_links & BIT(params.link_sta_params.link_id))) { + err = -ENOLINK; + goto out; + } + } else { + if (params.link_sta_params.link_id >= 0) { + err = -EINVAL; + goto out; + } + } + err = rdev_add_station(rdev, dev, mac_addr, ¶ms); +out: + wdev_unlock(dev->ieee80211_ptr); + dev_put(params.vlan); + return err; +} + +static int nl80211_del_station(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct station_del_parameters params; + int ret; + + memset(¶ms, 0, sizeof(params)); + + if (info->attrs[NL80211_ATTR_MAC]) + params.mac = nla_data(info->attrs[NL80211_ATTR_MAC]); + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_P2P_GO: + /* always accept these */ + break; + case NL80211_IFTYPE_ADHOC: + /* conditionally accept */ + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_DEL_IBSS_STA)) + break; + return -EINVAL; + default: + return -EINVAL; + } + + if (!rdev->ops->del_station) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_MGMT_SUBTYPE]) { + params.subtype = + nla_get_u8(info->attrs[NL80211_ATTR_MGMT_SUBTYPE]); + if (params.subtype != IEEE80211_STYPE_DISASSOC >> 4 && + params.subtype != IEEE80211_STYPE_DEAUTH >> 4) + return -EINVAL; + } else { + /* Default to Deauthentication frame */ + params.subtype = IEEE80211_STYPE_DEAUTH >> 4; + } + + if (info->attrs[NL80211_ATTR_REASON_CODE]) { + params.reason_code = + nla_get_u16(info->attrs[NL80211_ATTR_REASON_CODE]); + if (params.reason_code == 0) + return -EINVAL; /* 0 is reserved */ + } else { + /* Default to reason code 2 */ + params.reason_code = WLAN_REASON_PREV_AUTH_NOT_VALID; + } + + wdev_lock(dev->ieee80211_ptr); + ret = rdev_del_station(rdev, dev, ¶ms); + wdev_unlock(dev->ieee80211_ptr); + + return ret; +} + +static int nl80211_send_mpath(struct sk_buff *msg, u32 portid, u32 seq, + int flags, struct net_device *dev, + u8 *dst, u8 *next_hop, + struct mpath_info *pinfo) +{ + void *hdr; + struct nlattr *pinfoattr; + + hdr = nl80211hdr_put(msg, portid, seq, flags, NL80211_CMD_NEW_MPATH); + if (!hdr) + return -1; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, dst) || + nla_put(msg, NL80211_ATTR_MPATH_NEXT_HOP, ETH_ALEN, next_hop) || + nla_put_u32(msg, NL80211_ATTR_GENERATION, pinfo->generation)) + goto nla_put_failure; + + pinfoattr = nla_nest_start_noflag(msg, NL80211_ATTR_MPATH_INFO); + if (!pinfoattr) + goto nla_put_failure; + if ((pinfo->filled & MPATH_INFO_FRAME_QLEN) && + nla_put_u32(msg, NL80211_MPATH_INFO_FRAME_QLEN, + pinfo->frame_qlen)) + goto nla_put_failure; + if (((pinfo->filled & MPATH_INFO_SN) && + nla_put_u32(msg, NL80211_MPATH_INFO_SN, pinfo->sn)) || + ((pinfo->filled & MPATH_INFO_METRIC) && + nla_put_u32(msg, NL80211_MPATH_INFO_METRIC, + pinfo->metric)) || + ((pinfo->filled & MPATH_INFO_EXPTIME) && + nla_put_u32(msg, NL80211_MPATH_INFO_EXPTIME, + pinfo->exptime)) || + ((pinfo->filled & MPATH_INFO_FLAGS) && + nla_put_u8(msg, NL80211_MPATH_INFO_FLAGS, + pinfo->flags)) || + ((pinfo->filled & 
MPATH_INFO_DISCOVERY_TIMEOUT) && + nla_put_u32(msg, NL80211_MPATH_INFO_DISCOVERY_TIMEOUT, + pinfo->discovery_timeout)) || + ((pinfo->filled & MPATH_INFO_DISCOVERY_RETRIES) && + nla_put_u8(msg, NL80211_MPATH_INFO_DISCOVERY_RETRIES, + pinfo->discovery_retries)) || + ((pinfo->filled & MPATH_INFO_HOP_COUNT) && + nla_put_u8(msg, NL80211_MPATH_INFO_HOP_COUNT, + pinfo->hop_count)) || + ((pinfo->filled & MPATH_INFO_PATH_CHANGE) && + nla_put_u32(msg, NL80211_MPATH_INFO_PATH_CHANGE, + pinfo->path_change_count))) + goto nla_put_failure; + + nla_nest_end(msg, pinfoattr); + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_mpath(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct mpath_info pinfo; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + u8 dst[ETH_ALEN]; + u8 next_hop[ETH_ALEN]; + int path_idx = cb->args[2]; + int err; + + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev, NULL); + if (err) + return err; + /* nl80211_prepare_wdev_dump acquired it in the successful case */ + __acquire(&rdev->wiphy.mtx); + + if (!rdev->ops->dump_mpath) { + err = -EOPNOTSUPP; + goto out_err; + } + + if (wdev->iftype != NL80211_IFTYPE_MESH_POINT) { + err = -EOPNOTSUPP; + goto out_err; + } + + while (1) { + err = rdev_dump_mpath(rdev, wdev->netdev, path_idx, dst, + next_hop, &pinfo); + if (err == -ENOENT) + break; + if (err) + goto out_err; + + if (nl80211_send_mpath(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + wdev->netdev, dst, next_hop, + &pinfo) < 0) + goto out; + + path_idx++; + } + + out: + cb->args[2] = path_idx; + err = skb->len; + out_err: + wiphy_unlock(&rdev->wiphy); + return err; +} + +static int nl80211_get_mpath(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + struct mpath_info pinfo; + struct sk_buff *msg; + u8 *dst = NULL; + u8 next_hop[ETH_ALEN]; + + memset(&pinfo, 0, sizeof(pinfo)); + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + dst = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!rdev->ops->get_mpath) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + err = rdev_get_mpath(rdev, dev, dst, next_hop, &pinfo); + if (err) + return err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl80211_send_mpath(msg, info->snd_portid, info->snd_seq, 0, + dev, dst, next_hop, &pinfo) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static int nl80211_set_mpath(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u8 *dst = NULL; + u8 *next_hop = NULL; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_MPATH_NEXT_HOP]) + return -EINVAL; + + dst = nla_data(info->attrs[NL80211_ATTR_MAC]); + next_hop = nla_data(info->attrs[NL80211_ATTR_MPATH_NEXT_HOP]); + + if (!rdev->ops->change_mpath) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + return rdev_change_mpath(rdev, dev, dst, next_hop); +} + +static int nl80211_new_mpath(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u8 *dst = 
NULL; + u8 *next_hop = NULL; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_MPATH_NEXT_HOP]) + return -EINVAL; + + dst = nla_data(info->attrs[NL80211_ATTR_MAC]); + next_hop = nla_data(info->attrs[NL80211_ATTR_MPATH_NEXT_HOP]); + + if (!rdev->ops->add_mpath) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + return rdev_add_mpath(rdev, dev, dst, next_hop); +} + +static int nl80211_del_mpath(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u8 *dst = NULL; + + if (info->attrs[NL80211_ATTR_MAC]) + dst = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!rdev->ops->del_mpath) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + return rdev_del_mpath(rdev, dev, dst); +} + +static int nl80211_get_mpp(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int err; + struct net_device *dev = info->user_ptr[1]; + struct mpath_info pinfo; + struct sk_buff *msg; + u8 *dst = NULL; + u8 mpp[ETH_ALEN]; + + memset(&pinfo, 0, sizeof(pinfo)); + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + dst = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!rdev->ops->get_mpp) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + err = rdev_get_mpp(rdev, dev, dst, mpp, &pinfo); + if (err) + return err; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + if (nl80211_send_mpath(msg, info->snd_portid, info->snd_seq, 0, + dev, dst, mpp, &pinfo) < 0) { + nlmsg_free(msg); + return -ENOBUFS; + } + + return genlmsg_reply(msg, info); +} + +static int nl80211_dump_mpp(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct mpath_info pinfo; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + u8 dst[ETH_ALEN]; + u8 mpp[ETH_ALEN]; + int path_idx = cb->args[2]; + int err; + + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev, NULL); + if (err) + return err; + /* nl80211_prepare_wdev_dump acquired it in the successful case */ + __acquire(&rdev->wiphy.mtx); + + if (!rdev->ops->dump_mpp) { + err = -EOPNOTSUPP; + goto out_err; + } + + if (wdev->iftype != NL80211_IFTYPE_MESH_POINT) { + err = -EOPNOTSUPP; + goto out_err; + } + + while (1) { + err = rdev_dump_mpp(rdev, wdev->netdev, path_idx, dst, + mpp, &pinfo); + if (err == -ENOENT) + break; + if (err) + goto out_err; + + if (nl80211_send_mpath(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + wdev->netdev, dst, mpp, + &pinfo) < 0) + goto out; + + path_idx++; + } + + out: + cb->args[2] = path_idx; + err = skb->len; + out_err: + wiphy_unlock(&rdev->wiphy); + return err; +} + +static int nl80211_set_bss(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct bss_parameters params; + int err; + + memset(¶ms, 0, sizeof(params)); + /* default to not changing parameters */ + params.use_cts_prot = -1; + params.use_short_preamble = -1; + params.use_short_slot_time = -1; + params.ap_isolate = -1; + params.ht_opmode = -1; + params.p2p_ctwindow = -1; + params.p2p_opp_ps = -1; + + if (info->attrs[NL80211_ATTR_BSS_CTS_PROT]) + params.use_cts_prot = + 
nla_get_u8(info->attrs[NL80211_ATTR_BSS_CTS_PROT]); + if (info->attrs[NL80211_ATTR_BSS_SHORT_PREAMBLE]) + params.use_short_preamble = + nla_get_u8(info->attrs[NL80211_ATTR_BSS_SHORT_PREAMBLE]); + if (info->attrs[NL80211_ATTR_BSS_SHORT_SLOT_TIME]) + params.use_short_slot_time = + nla_get_u8(info->attrs[NL80211_ATTR_BSS_SHORT_SLOT_TIME]); + if (info->attrs[NL80211_ATTR_BSS_BASIC_RATES]) { + params.basic_rates = + nla_data(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + params.basic_rates_len = + nla_len(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + } + if (info->attrs[NL80211_ATTR_AP_ISOLATE]) + params.ap_isolate = !!nla_get_u8(info->attrs[NL80211_ATTR_AP_ISOLATE]); + if (info->attrs[NL80211_ATTR_BSS_HT_OPMODE]) + params.ht_opmode = + nla_get_u16(info->attrs[NL80211_ATTR_BSS_HT_OPMODE]); + + if (info->attrs[NL80211_ATTR_P2P_CTWINDOW]) { + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EINVAL; + params.p2p_ctwindow = + nla_get_u8(info->attrs[NL80211_ATTR_P2P_CTWINDOW]); + if (params.p2p_ctwindow != 0 && + !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_CTWIN)) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_P2P_OPPPS]) { + u8 tmp; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EINVAL; + tmp = nla_get_u8(info->attrs[NL80211_ATTR_P2P_OPPPS]); + params.p2p_opp_ps = tmp; + if (params.p2p_opp_ps && + !(rdev->wiphy.features & NL80211_FEATURE_P2P_GO_OPPPS)) + return -EINVAL; + } + + if (!rdev->ops->change_bss) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + wdev_lock(wdev); + err = rdev_change_bss(rdev, dev, ¶ms); + wdev_unlock(wdev); + + return err; +} + +static int nl80211_req_set_reg(struct sk_buff *skb, struct genl_info *info) +{ + char *data = NULL; + bool is_indoor; + enum nl80211_user_reg_hint_type user_reg_hint_type; + u32 owner_nlportid; + + /* + * You should only get this when cfg80211 hasn't yet initialized + * completely when built-in to the kernel right between the time + * window between nl80211_init() and regulatory_init(), if that is + * even possible. 
+ */ + if (unlikely(!rcu_access_pointer(cfg80211_regdomain))) + return -EINPROGRESS; + + if (info->attrs[NL80211_ATTR_USER_REG_HINT_TYPE]) + user_reg_hint_type = + nla_get_u32(info->attrs[NL80211_ATTR_USER_REG_HINT_TYPE]); + else + user_reg_hint_type = NL80211_USER_REG_HINT_USER; + + switch (user_reg_hint_type) { + case NL80211_USER_REG_HINT_USER: + case NL80211_USER_REG_HINT_CELL_BASE: + if (!info->attrs[NL80211_ATTR_REG_ALPHA2]) + return -EINVAL; + + data = nla_data(info->attrs[NL80211_ATTR_REG_ALPHA2]); + return regulatory_hint_user(data, user_reg_hint_type); + case NL80211_USER_REG_HINT_INDOOR: + if (info->attrs[NL80211_ATTR_SOCKET_OWNER]) { + owner_nlportid = info->snd_portid; + is_indoor = !!info->attrs[NL80211_ATTR_REG_INDOOR]; + } else { + owner_nlportid = 0; + is_indoor = true; + } + + return regulatory_hint_indoor(is_indoor, owner_nlportid); + default: + return -EINVAL; + } +} + +static int nl80211_reload_regdb(struct sk_buff *skb, struct genl_info *info) +{ + return reg_reload_regdb(); +} + +static int nl80211_get_mesh_config(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct mesh_config cur_params; + int err = 0; + void *hdr; + struct nlattr *pinfoattr; + struct sk_buff *msg; + + if (wdev->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + if (!rdev->ops->get_mesh_config) + return -EOPNOTSUPP; + + wdev_lock(wdev); + /* If not connected, get default parameters */ + if (!wdev->u.mesh.id_len) + memcpy(&cur_params, &default_mesh_config, sizeof(cur_params)); + else + err = rdev_get_mesh_config(rdev, dev, &cur_params); + wdev_unlock(wdev); + + if (err) + return err; + + /* Draw up a netlink message to send back */ + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_MESH_CONFIG); + if (!hdr) + goto out; + pinfoattr = nla_nest_start_noflag(msg, NL80211_ATTR_MESH_CONFIG); + if (!pinfoattr) + goto nla_put_failure; + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_u16(msg, NL80211_MESHCONF_RETRY_TIMEOUT, + cur_params.dot11MeshRetryTimeout) || + nla_put_u16(msg, NL80211_MESHCONF_CONFIRM_TIMEOUT, + cur_params.dot11MeshConfirmTimeout) || + nla_put_u16(msg, NL80211_MESHCONF_HOLDING_TIMEOUT, + cur_params.dot11MeshHoldingTimeout) || + nla_put_u16(msg, NL80211_MESHCONF_MAX_PEER_LINKS, + cur_params.dot11MeshMaxPeerLinks) || + nla_put_u8(msg, NL80211_MESHCONF_MAX_RETRIES, + cur_params.dot11MeshMaxRetries) || + nla_put_u8(msg, NL80211_MESHCONF_TTL, + cur_params.dot11MeshTTL) || + nla_put_u8(msg, NL80211_MESHCONF_ELEMENT_TTL, + cur_params.element_ttl) || + nla_put_u8(msg, NL80211_MESHCONF_AUTO_OPEN_PLINKS, + cur_params.auto_open_plinks) || + nla_put_u32(msg, NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR, + cur_params.dot11MeshNbrOffsetMaxNeighbor) || + nla_put_u8(msg, NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES, + cur_params.dot11MeshHWMPmaxPREQretries) || + nla_put_u32(msg, NL80211_MESHCONF_PATH_REFRESH_TIME, + cur_params.path_refresh_time) || + nla_put_u16(msg, NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT, + cur_params.min_discovery_timeout) || + nla_put_u32(msg, NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT, + cur_params.dot11MeshHWMPactivePathTimeout) || + nla_put_u16(msg, NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL, + cur_params.dot11MeshHWMPpreqMinInterval) || + nla_put_u16(msg, 
NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL, + cur_params.dot11MeshHWMPperrMinInterval) || + nla_put_u16(msg, NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME, + cur_params.dot11MeshHWMPnetDiameterTraversalTime) || + nla_put_u8(msg, NL80211_MESHCONF_HWMP_ROOTMODE, + cur_params.dot11MeshHWMPRootMode) || + nla_put_u16(msg, NL80211_MESHCONF_HWMP_RANN_INTERVAL, + cur_params.dot11MeshHWMPRannInterval) || + nla_put_u8(msg, NL80211_MESHCONF_GATE_ANNOUNCEMENTS, + cur_params.dot11MeshGateAnnouncementProtocol) || + nla_put_u8(msg, NL80211_MESHCONF_FORWARDING, + cur_params.dot11MeshForwarding) || + nla_put_s32(msg, NL80211_MESHCONF_RSSI_THRESHOLD, + cur_params.rssi_threshold) || + nla_put_u32(msg, NL80211_MESHCONF_HT_OPMODE, + cur_params.ht_opmode) || + nla_put_u32(msg, NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT, + cur_params.dot11MeshHWMPactivePathToRootTimeout) || + nla_put_u16(msg, NL80211_MESHCONF_HWMP_ROOT_INTERVAL, + cur_params.dot11MeshHWMProotInterval) || + nla_put_u16(msg, NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL, + cur_params.dot11MeshHWMPconfirmationInterval) || + nla_put_u32(msg, NL80211_MESHCONF_POWER_MODE, + cur_params.power_mode) || + nla_put_u16(msg, NL80211_MESHCONF_AWAKE_WINDOW, + cur_params.dot11MeshAwakeWindowDuration) || + nla_put_u32(msg, NL80211_MESHCONF_PLINK_TIMEOUT, + cur_params.plink_timeout) || + nla_put_u8(msg, NL80211_MESHCONF_CONNECTED_TO_GATE, + cur_params.dot11MeshConnectedToMeshGate) || + nla_put_u8(msg, NL80211_MESHCONF_NOLEARN, + cur_params.dot11MeshNolearn) || + nla_put_u8(msg, NL80211_MESHCONF_CONNECTED_TO_AS, + cur_params.dot11MeshConnectedToAuthServer)) + goto nla_put_failure; + nla_nest_end(msg, pinfoattr); + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + + nla_put_failure: + out: + nlmsg_free(msg); + return -ENOBUFS; +} + +static const struct nla_policy +nl80211_meshconf_params_policy[NL80211_MESHCONF_ATTR_MAX+1] = { + [NL80211_MESHCONF_RETRY_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_CONFIRM_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_HOLDING_TIMEOUT] = + NLA_POLICY_RANGE(NLA_U16, 1, 255), + [NL80211_MESHCONF_MAX_PEER_LINKS] = + NLA_POLICY_RANGE(NLA_U16, 0, 255), + [NL80211_MESHCONF_MAX_RETRIES] = NLA_POLICY_MAX(NLA_U8, 16), + [NL80211_MESHCONF_TTL] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_MESHCONF_ELEMENT_TTL] = NLA_POLICY_MIN(NLA_U8, 1), + [NL80211_MESHCONF_AUTO_OPEN_PLINKS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR] = + NLA_POLICY_RANGE(NLA_U32, 1, 255), + [NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES] = { .type = NLA_U8 }, + [NL80211_MESHCONF_PATH_REFRESH_TIME] = { .type = NLA_U32 }, + [NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT] = NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT] = { .type = NLA_U32 }, + [NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_HWMP_ROOTMODE] = NLA_POLICY_MAX(NLA_U8, 4), + [NL80211_MESHCONF_HWMP_RANN_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_GATE_ANNOUNCEMENTS] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_FORWARDING] = NLA_POLICY_MAX(NLA_U8, 1), + [NL80211_MESHCONF_RSSI_THRESHOLD] = + NLA_POLICY_RANGE(NLA_S32, -255, 0), + [NL80211_MESHCONF_HT_OPMODE] = { .type = NLA_U16 }, + [NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT] = { .type = NLA_U32 }, + [NL80211_MESHCONF_HWMP_ROOT_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + 
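	/* Note (illustrative): each NLA_POLICY_* entry in this table is a
	 * declarative bound that the generic netlink attribute validator
	 * enforces before nl80211_parse_mesh_config() ever sees the value.
	 * A hypothetical entry such as
	 *
	 *	[MY_MESHCONF_ATTR] = NLA_POLICY_RANGE(NLA_U16, 1, 255),
	 *
	 * would make nla_parse_nested_deprecated() reject any u16 payload
	 * outside 1..255 (reporting the offending attribute via extack), so
	 * the fill-in code below can use such values without re-checking
	 * these simple bounds.
	 */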
[NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL] = + NLA_POLICY_MIN(NLA_U16, 1), + [NL80211_MESHCONF_POWER_MODE] = + NLA_POLICY_RANGE(NLA_U32, + NL80211_MESH_POWER_ACTIVE, + NL80211_MESH_POWER_MAX), + [NL80211_MESHCONF_AWAKE_WINDOW] = { .type = NLA_U16 }, + [NL80211_MESHCONF_PLINK_TIMEOUT] = { .type = NLA_U32 }, + [NL80211_MESHCONF_CONNECTED_TO_GATE] = NLA_POLICY_RANGE(NLA_U8, 0, 1), + [NL80211_MESHCONF_NOLEARN] = NLA_POLICY_RANGE(NLA_U8, 0, 1), + [NL80211_MESHCONF_CONNECTED_TO_AS] = NLA_POLICY_RANGE(NLA_U8, 0, 1), +}; + +static const struct nla_policy + nl80211_mesh_setup_params_policy[NL80211_MESH_SETUP_ATTR_MAX+1] = { + [NL80211_MESH_SETUP_ENABLE_VENDOR_SYNC] = { .type = NLA_U8 }, + [NL80211_MESH_SETUP_ENABLE_VENDOR_PATH_SEL] = { .type = NLA_U8 }, + [NL80211_MESH_SETUP_ENABLE_VENDOR_METRIC] = { .type = NLA_U8 }, + [NL80211_MESH_SETUP_USERSPACE_AUTH] = { .type = NLA_FLAG }, + [NL80211_MESH_SETUP_AUTH_PROTOCOL] = { .type = NLA_U8 }, + [NL80211_MESH_SETUP_USERSPACE_MPM] = { .type = NLA_FLAG }, + [NL80211_MESH_SETUP_IE] = + NLA_POLICY_VALIDATE_FN(NLA_BINARY, validate_ie_attr, + IEEE80211_MAX_DATA_LEN), + [NL80211_MESH_SETUP_USERSPACE_AMPE] = { .type = NLA_FLAG }, +}; + +static int nl80211_parse_mesh_config(struct genl_info *info, + struct mesh_config *cfg, + u32 *mask_out) +{ + struct nlattr *tb[NL80211_MESHCONF_ATTR_MAX + 1]; + u32 mask = 0; + u16 ht_opmode; + +#define FILL_IN_MESH_PARAM_IF_SET(tb, cfg, param, mask, attr, fn) \ +do { \ + if (tb[attr]) { \ + cfg->param = fn(tb[attr]); \ + mask |= BIT((attr) - 1); \ + } \ +} while (0) + + if (!info->attrs[NL80211_ATTR_MESH_CONFIG]) + return -EINVAL; + if (nla_parse_nested_deprecated(tb, NL80211_MESHCONF_ATTR_MAX, info->attrs[NL80211_ATTR_MESH_CONFIG], nl80211_meshconf_params_policy, info->extack)) + return -EINVAL; + + /* This makes sure that there aren't more than 32 mesh config + * parameters (otherwise our bitfield scheme would not work.) 
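	 * As a concrete sketch of the scheme used by
	 * FILL_IN_MESH_PARAM_IF_SET() above: a present attribute 'attr' is
	 * recorded as BIT(attr - 1) in 'mask', so attribute numbers 1..32
	 * map onto bits 0..31 of the u32 mask handed to the driver, e.g.
	 * NL80211_MESHCONF_RETRY_TIMEOUT (value 1) sets bit 0.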
*/ + BUILD_BUG_ON(NL80211_MESHCONF_ATTR_MAX > 32); + + /* Fill in the params struct */ + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshRetryTimeout, mask, + NL80211_MESHCONF_RETRY_TIMEOUT, nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshConfirmTimeout, mask, + NL80211_MESHCONF_CONFIRM_TIMEOUT, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHoldingTimeout, mask, + NL80211_MESHCONF_HOLDING_TIMEOUT, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxPeerLinks, mask, + NL80211_MESHCONF_MAX_PEER_LINKS, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshMaxRetries, mask, + NL80211_MESHCONF_MAX_RETRIES, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshTTL, mask, + NL80211_MESHCONF_TTL, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, element_ttl, mask, + NL80211_MESHCONF_ELEMENT_TTL, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, auto_open_plinks, mask, + NL80211_MESHCONF_AUTO_OPEN_PLINKS, + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshNbrOffsetMaxNeighbor, + mask, + NL80211_MESHCONF_SYNC_OFFSET_MAX_NEIGHBOR, + nla_get_u32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPmaxPREQretries, mask, + NL80211_MESHCONF_HWMP_MAX_PREQ_RETRIES, + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, path_refresh_time, mask, + NL80211_MESHCONF_PATH_REFRESH_TIME, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_PATH_REFRESH_TIME) && + (cfg->path_refresh_time < 1 || cfg->path_refresh_time > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, min_discovery_timeout, mask, + NL80211_MESHCONF_MIN_DISCOVERY_TIMEOUT, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPactivePathTimeout, + mask, + NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_HWMP_ACTIVE_PATH_TIMEOUT) && + (cfg->dot11MeshHWMPactivePathTimeout < 1 || + cfg->dot11MeshHWMPactivePathTimeout > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPpreqMinInterval, mask, + NL80211_MESHCONF_HWMP_PREQ_MIN_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPperrMinInterval, mask, + NL80211_MESHCONF_HWMP_PERR_MIN_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, + dot11MeshHWMPnetDiameterTraversalTime, mask, + NL80211_MESHCONF_HWMP_NET_DIAM_TRVS_TIME, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRootMode, mask, + NL80211_MESHCONF_HWMP_ROOTMODE, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPRannInterval, mask, + NL80211_MESHCONF_HWMP_RANN_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshGateAnnouncementProtocol, + mask, NL80211_MESHCONF_GATE_ANNOUNCEMENTS, + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshForwarding, mask, + NL80211_MESHCONF_FORWARDING, nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, rssi_threshold, mask, + NL80211_MESHCONF_RSSI_THRESHOLD, + nla_get_s32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshConnectedToMeshGate, mask, + NL80211_MESHCONF_CONNECTED_TO_GATE, + nla_get_u8); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshConnectedToAuthServer, mask, + NL80211_MESHCONF_CONNECTED_TO_AS, + nla_get_u8); + /* + * Check HT operation mode based on + * IEEE 802.11-2016 9.4.2.57 HT Operation element. 
+ */ + if (tb[NL80211_MESHCONF_HT_OPMODE]) { + ht_opmode = nla_get_u16(tb[NL80211_MESHCONF_HT_OPMODE]); + + if (ht_opmode & ~(IEEE80211_HT_OP_MODE_PROTECTION | + IEEE80211_HT_OP_MODE_NON_GF_STA_PRSNT | + IEEE80211_HT_OP_MODE_NON_HT_STA_PRSNT)) + return -EINVAL; + + /* NON_HT_STA bit is reserved, but some programs set it */ + ht_opmode &= ~IEEE80211_HT_OP_MODE_NON_HT_STA_PRSNT; + + cfg->ht_opmode = ht_opmode; + mask |= (1 << (NL80211_MESHCONF_HT_OPMODE - 1)); + } + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, + dot11MeshHWMPactivePathToRootTimeout, mask, + NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT, + nla_get_u32); + if (mask & BIT(NL80211_MESHCONF_HWMP_PATH_TO_ROOT_TIMEOUT) && + (cfg->dot11MeshHWMPactivePathToRootTimeout < 1 || + cfg->dot11MeshHWMPactivePathToRootTimeout > 65535)) + return -EINVAL; + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMProotInterval, mask, + NL80211_MESHCONF_HWMP_ROOT_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshHWMPconfirmationInterval, + mask, + NL80211_MESHCONF_HWMP_CONFIRMATION_INTERVAL, + nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, power_mode, mask, + NL80211_MESHCONF_POWER_MODE, nla_get_u32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshAwakeWindowDuration, mask, + NL80211_MESHCONF_AWAKE_WINDOW, nla_get_u16); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, plink_timeout, mask, + NL80211_MESHCONF_PLINK_TIMEOUT, nla_get_u32); + FILL_IN_MESH_PARAM_IF_SET(tb, cfg, dot11MeshNolearn, mask, + NL80211_MESHCONF_NOLEARN, nla_get_u8); + if (mask_out) + *mask_out = mask; + + return 0; + +#undef FILL_IN_MESH_PARAM_IF_SET +} + +static int nl80211_parse_mesh_setup(struct genl_info *info, + struct mesh_setup *setup) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct nlattr *tb[NL80211_MESH_SETUP_ATTR_MAX + 1]; + + if (!info->attrs[NL80211_ATTR_MESH_SETUP]) + return -EINVAL; + if (nla_parse_nested_deprecated(tb, NL80211_MESH_SETUP_ATTR_MAX, info->attrs[NL80211_ATTR_MESH_SETUP], nl80211_mesh_setup_params_policy, info->extack)) + return -EINVAL; + + if (tb[NL80211_MESH_SETUP_ENABLE_VENDOR_SYNC]) + setup->sync_method = + (nla_get_u8(tb[NL80211_MESH_SETUP_ENABLE_VENDOR_SYNC])) ? + IEEE80211_SYNC_METHOD_VENDOR : + IEEE80211_SYNC_METHOD_NEIGHBOR_OFFSET; + + if (tb[NL80211_MESH_SETUP_ENABLE_VENDOR_PATH_SEL]) + setup->path_sel_proto = + (nla_get_u8(tb[NL80211_MESH_SETUP_ENABLE_VENDOR_PATH_SEL])) ? + IEEE80211_PATH_PROTOCOL_VENDOR : + IEEE80211_PATH_PROTOCOL_HWMP; + + if (tb[NL80211_MESH_SETUP_ENABLE_VENDOR_METRIC]) + setup->path_metric = + (nla_get_u8(tb[NL80211_MESH_SETUP_ENABLE_VENDOR_METRIC])) ? 
+ IEEE80211_PATH_METRIC_VENDOR : + IEEE80211_PATH_METRIC_AIRTIME; + + if (tb[NL80211_MESH_SETUP_IE]) { + struct nlattr *ieattr = + tb[NL80211_MESH_SETUP_IE]; + setup->ie = nla_data(ieattr); + setup->ie_len = nla_len(ieattr); + } + if (tb[NL80211_MESH_SETUP_USERSPACE_MPM] && + !(rdev->wiphy.features & NL80211_FEATURE_USERSPACE_MPM)) + return -EINVAL; + setup->user_mpm = nla_get_flag(tb[NL80211_MESH_SETUP_USERSPACE_MPM]); + setup->is_authenticated = nla_get_flag(tb[NL80211_MESH_SETUP_USERSPACE_AUTH]); + setup->is_secure = nla_get_flag(tb[NL80211_MESH_SETUP_USERSPACE_AMPE]); + if (setup->is_secure) + setup->user_mpm = true; + + if (tb[NL80211_MESH_SETUP_AUTH_PROTOCOL]) { + if (!setup->user_mpm) + return -EINVAL; + setup->auth_id = + nla_get_u8(tb[NL80211_MESH_SETUP_AUTH_PROTOCOL]); + } + + return 0; +} + +static int nl80211_update_mesh_config(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct mesh_config cfg = {}; + u32 mask; + int err; + + if (wdev->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + if (!rdev->ops->update_mesh_config) + return -EOPNOTSUPP; + + err = nl80211_parse_mesh_config(info, &cfg, &mask); + if (err) + return err; + + wdev_lock(wdev); + if (!wdev->u.mesh.id_len) + err = -ENOLINK; + + if (!err) + err = rdev_update_mesh_config(rdev, dev, mask, &cfg); + + wdev_unlock(wdev); + + return err; +} + +static int nl80211_put_regdom(const struct ieee80211_regdomain *regdom, + struct sk_buff *msg) +{ + struct nlattr *nl_reg_rules; + unsigned int i; + + if (nla_put_string(msg, NL80211_ATTR_REG_ALPHA2, regdom->alpha2) || + (regdom->dfs_region && + nla_put_u8(msg, NL80211_ATTR_DFS_REGION, regdom->dfs_region))) + goto nla_put_failure; + + nl_reg_rules = nla_nest_start_noflag(msg, NL80211_ATTR_REG_RULES); + if (!nl_reg_rules) + goto nla_put_failure; + + for (i = 0; i < regdom->n_reg_rules; i++) { + struct nlattr *nl_reg_rule; + const struct ieee80211_reg_rule *reg_rule; + const struct ieee80211_freq_range *freq_range; + const struct ieee80211_power_rule *power_rule; + unsigned int max_bandwidth_khz; + + reg_rule = ®dom->reg_rules[i]; + freq_range = ®_rule->freq_range; + power_rule = ®_rule->power_rule; + + nl_reg_rule = nla_nest_start_noflag(msg, i); + if (!nl_reg_rule) + goto nla_put_failure; + + max_bandwidth_khz = freq_range->max_bandwidth_khz; + if (!max_bandwidth_khz) + max_bandwidth_khz = reg_get_max_bandwidth(regdom, + reg_rule); + + if (nla_put_u32(msg, NL80211_ATTR_REG_RULE_FLAGS, + reg_rule->flags) || + nla_put_u32(msg, NL80211_ATTR_FREQ_RANGE_START, + freq_range->start_freq_khz) || + nla_put_u32(msg, NL80211_ATTR_FREQ_RANGE_END, + freq_range->end_freq_khz) || + nla_put_u32(msg, NL80211_ATTR_FREQ_RANGE_MAX_BW, + max_bandwidth_khz) || + nla_put_u32(msg, NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN, + power_rule->max_antenna_gain) || + nla_put_u32(msg, NL80211_ATTR_POWER_RULE_MAX_EIRP, + power_rule->max_eirp) || + nla_put_u32(msg, NL80211_ATTR_DFS_CAC_TIME, + reg_rule->dfs_cac_ms)) + goto nla_put_failure; + + nla_nest_end(msg, nl_reg_rule); + } + + nla_nest_end(msg, nl_reg_rules); + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static int nl80211_get_reg_do(struct sk_buff *skb, struct genl_info *info) +{ + const struct ieee80211_regdomain *regdom = NULL; + struct cfg80211_registered_device *rdev; + struct wiphy *wiphy = NULL; + struct sk_buff *msg; + int err = -EMSGSIZE; + void *hdr; + + msg = 
nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOBUFS; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_REG); + if (!hdr) + goto put_failure; + + rtnl_lock(); + + if (info->attrs[NL80211_ATTR_WIPHY]) { + bool self_managed; + + rdev = cfg80211_get_dev_from_info(genl_info_net(info), info); + if (IS_ERR(rdev)) { + err = PTR_ERR(rdev); + goto nla_put_failure; + } + + wiphy = &rdev->wiphy; + self_managed = wiphy->regulatory_flags & + REGULATORY_WIPHY_SELF_MANAGED; + + rcu_read_lock(); + + regdom = get_wiphy_regdom(wiphy); + + /* a self-managed-reg device must have a private regdom */ + if (WARN_ON(!regdom && self_managed)) { + err = -EINVAL; + goto nla_put_failure_rcu; + } + + if (regdom && + nla_put_u32(msg, NL80211_ATTR_WIPHY, get_wiphy_idx(wiphy))) + goto nla_put_failure_rcu; + } else { + rcu_read_lock(); + } + + if (!wiphy && reg_last_request_cell_base() && + nla_put_u32(msg, NL80211_ATTR_USER_REG_HINT_TYPE, + NL80211_USER_REG_HINT_CELL_BASE)) + goto nla_put_failure_rcu; + + if (!regdom) + regdom = rcu_dereference(cfg80211_regdomain); + + if (nl80211_put_regdom(regdom, msg)) + goto nla_put_failure_rcu; + + rcu_read_unlock(); + + genlmsg_end(msg, hdr); + rtnl_unlock(); + return genlmsg_reply(msg, info); + +nla_put_failure_rcu: + rcu_read_unlock(); +nla_put_failure: + rtnl_unlock(); +put_failure: + nlmsg_free(msg); + return err; +} + +static int nl80211_send_regdom(struct sk_buff *msg, struct netlink_callback *cb, + u32 seq, int flags, struct wiphy *wiphy, + const struct ieee80211_regdomain *regdom) +{ + void *hdr = nl80211hdr_put(msg, NETLINK_CB(cb->skb).portid, seq, flags, + NL80211_CMD_GET_REG); + + if (!hdr) + return -1; + + genl_dump_check_consistent(cb, hdr); + + if (nl80211_put_regdom(regdom, msg)) + goto nla_put_failure; + + if (!wiphy && reg_last_request_cell_base() && + nla_put_u32(msg, NL80211_ATTR_USER_REG_HINT_TYPE, + NL80211_USER_REG_HINT_CELL_BASE)) + goto nla_put_failure; + + if (wiphy && + nla_put_u32(msg, NL80211_ATTR_WIPHY, get_wiphy_idx(wiphy))) + goto nla_put_failure; + + if (wiphy && wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED && + nla_put_flag(msg, NL80211_ATTR_WIPHY_SELF_MANAGED_REG)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_get_reg_dump(struct sk_buff *skb, + struct netlink_callback *cb) +{ + const struct ieee80211_regdomain *regdom = NULL; + struct cfg80211_registered_device *rdev; + int err, reg_idx, start = cb->args[2]; + + rcu_read_lock(); + + if (cfg80211_regdomain && start == 0) { + err = nl80211_send_regdom(skb, cb, cb->nlh->nlmsg_seq, + NLM_F_MULTI, NULL, + rcu_dereference(cfg80211_regdomain)); + if (err < 0) + goto out_err; + } + + /* the global regdom is idx 0 */ + reg_idx = 1; + list_for_each_entry_rcu(rdev, &cfg80211_rdev_list, list) { + regdom = get_wiphy_regdom(&rdev->wiphy); + if (!regdom) + continue; + + if (++reg_idx <= start) + continue; + + err = nl80211_send_regdom(skb, cb, cb->nlh->nlmsg_seq, + NLM_F_MULTI, &rdev->wiphy, regdom); + if (err < 0) { + reg_idx--; + break; + } + } + + cb->args[2] = reg_idx; + err = skb->len; +out_err: + rcu_read_unlock(); + return err; +} + +#ifdef CONFIG_CFG80211_CRDA_SUPPORT +static const struct nla_policy reg_rule_policy[NL80211_REG_RULE_ATTR_MAX + 1] = { + [NL80211_ATTR_REG_RULE_FLAGS] = { .type = NLA_U32 }, + [NL80211_ATTR_FREQ_RANGE_START] = { .type = NLA_U32 }, + [NL80211_ATTR_FREQ_RANGE_END] = { .type = NLA_U32 }, + 
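	/* Illustrative note on units: these attributes are plain u32s whose
	 * units follow the conventions used elsewhere in this file -
	 * frequencies and bandwidths in kHz, EIRP in mBm (100 * dBm) and the
	 * DFS CAC time in milliseconds. A 2.4 GHz rule sent from userspace
	 * might therefore carry (example values only):
	 *
	 *	FREQ_RANGE_START = 2402000, FREQ_RANGE_END = 2482000,
	 *	FREQ_RANGE_MAX_BW = 40000, POWER_RULE_MAX_EIRP = 2000 (20 dBm)
	 *
	 * parse_reg_rule() below copies them into struct ieee80211_reg_rule
	 * without any unit conversion.
	 */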
[NL80211_ATTR_FREQ_RANGE_MAX_BW] = { .type = NLA_U32 }, + [NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN] = { .type = NLA_U32 }, + [NL80211_ATTR_POWER_RULE_MAX_EIRP] = { .type = NLA_U32 }, + [NL80211_ATTR_DFS_CAC_TIME] = { .type = NLA_U32 }, +}; + +static int parse_reg_rule(struct nlattr *tb[], + struct ieee80211_reg_rule *reg_rule) +{ + struct ieee80211_freq_range *freq_range = ®_rule->freq_range; + struct ieee80211_power_rule *power_rule = ®_rule->power_rule; + + if (!tb[NL80211_ATTR_REG_RULE_FLAGS]) + return -EINVAL; + if (!tb[NL80211_ATTR_FREQ_RANGE_START]) + return -EINVAL; + if (!tb[NL80211_ATTR_FREQ_RANGE_END]) + return -EINVAL; + if (!tb[NL80211_ATTR_FREQ_RANGE_MAX_BW]) + return -EINVAL; + if (!tb[NL80211_ATTR_POWER_RULE_MAX_EIRP]) + return -EINVAL; + + reg_rule->flags = nla_get_u32(tb[NL80211_ATTR_REG_RULE_FLAGS]); + + freq_range->start_freq_khz = + nla_get_u32(tb[NL80211_ATTR_FREQ_RANGE_START]); + freq_range->end_freq_khz = + nla_get_u32(tb[NL80211_ATTR_FREQ_RANGE_END]); + freq_range->max_bandwidth_khz = + nla_get_u32(tb[NL80211_ATTR_FREQ_RANGE_MAX_BW]); + + power_rule->max_eirp = + nla_get_u32(tb[NL80211_ATTR_POWER_RULE_MAX_EIRP]); + + if (tb[NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN]) + power_rule->max_antenna_gain = + nla_get_u32(tb[NL80211_ATTR_POWER_RULE_MAX_ANT_GAIN]); + + if (tb[NL80211_ATTR_DFS_CAC_TIME]) + reg_rule->dfs_cac_ms = + nla_get_u32(tb[NL80211_ATTR_DFS_CAC_TIME]); + + return 0; +} + +static int nl80211_set_reg(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *tb[NL80211_REG_RULE_ATTR_MAX + 1]; + struct nlattr *nl_reg_rule; + char *alpha2; + int rem_reg_rules, r; + u32 num_rules = 0, rule_idx = 0; + enum nl80211_dfs_regions dfs_region = NL80211_DFS_UNSET; + struct ieee80211_regdomain *rd; + + if (!info->attrs[NL80211_ATTR_REG_ALPHA2]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_REG_RULES]) + return -EINVAL; + + alpha2 = nla_data(info->attrs[NL80211_ATTR_REG_ALPHA2]); + + if (info->attrs[NL80211_ATTR_DFS_REGION]) + dfs_region = nla_get_u8(info->attrs[NL80211_ATTR_DFS_REGION]); + + nla_for_each_nested(nl_reg_rule, info->attrs[NL80211_ATTR_REG_RULES], + rem_reg_rules) { + num_rules++; + if (num_rules > NL80211_MAX_SUPP_REG_RULES) + return -EINVAL; + } + + rtnl_lock(); + if (!reg_is_valid_request(alpha2)) { + r = -EINVAL; + goto out; + } + + rd = kzalloc(struct_size(rd, reg_rules, num_rules), GFP_KERNEL); + if (!rd) { + r = -ENOMEM; + goto out; + } + + rd->n_reg_rules = num_rules; + rd->alpha2[0] = alpha2[0]; + rd->alpha2[1] = alpha2[1]; + + /* + * Disable DFS master mode if the DFS region was + * not supported or known on this kernel. 
+ */ + if (reg_supported_dfs_region(dfs_region)) + rd->dfs_region = dfs_region; + + nla_for_each_nested(nl_reg_rule, info->attrs[NL80211_ATTR_REG_RULES], + rem_reg_rules) { + r = nla_parse_nested_deprecated(tb, NL80211_REG_RULE_ATTR_MAX, + nl_reg_rule, reg_rule_policy, + info->extack); + if (r) + goto bad_reg; + r = parse_reg_rule(tb, &rd->reg_rules[rule_idx]); + if (r) + goto bad_reg; + + rule_idx++; + + if (rule_idx > NL80211_MAX_SUPP_REG_RULES) { + r = -EINVAL; + goto bad_reg; + } + } + + r = set_regdom(rd, REGD_SOURCE_CRDA); + /* set_regdom takes ownership of rd */ + rd = NULL; + bad_reg: + kfree(rd); + out: + rtnl_unlock(); + return r; +} +#endif /* CONFIG_CFG80211_CRDA_SUPPORT */ + +static int validate_scan_freqs(struct nlattr *freqs) +{ + struct nlattr *attr1, *attr2; + int n_channels = 0, tmp1, tmp2; + + nla_for_each_nested(attr1, freqs, tmp1) + if (nla_len(attr1) != sizeof(u32)) + return 0; + + nla_for_each_nested(attr1, freqs, tmp1) { + n_channels++; + /* + * Some hardware has a limited channel list for + * scanning, and it is pretty much nonsensical + * to scan for a channel twice, so disallow that + * and don't require drivers to check that the + * channel list they get isn't longer than what + * they can scan, as long as they can scan all + * the channels they registered at once. + */ + nla_for_each_nested(attr2, freqs, tmp2) + if (attr1 != attr2 && + nla_get_u32(attr1) == nla_get_u32(attr2)) + return 0; + } + + return n_channels; +} + +static bool is_band_valid(struct wiphy *wiphy, enum nl80211_band b) +{ + return b < NUM_NL80211_BANDS && wiphy->bands[b]; +} + +static int parse_bss_select(struct nlattr *nla, struct wiphy *wiphy, + struct cfg80211_bss_selection *bss_select) +{ + struct nlattr *attr[NL80211_BSS_SELECT_ATTR_MAX + 1]; + struct nlattr *nest; + int err; + bool found = false; + int i; + + /* only process one nested attribute */ + nest = nla_data(nla); + if (!nla_ok(nest, nla_len(nest))) + return -EINVAL; + + err = nla_parse_nested_deprecated(attr, NL80211_BSS_SELECT_ATTR_MAX, + nest, nl80211_bss_select_policy, + NULL); + if (err) + return err; + + /* only one attribute may be given */ + for (i = 0; i <= NL80211_BSS_SELECT_ATTR_MAX; i++) { + if (attr[i]) { + if (found) + return -EINVAL; + found = true; + } + } + + bss_select->behaviour = __NL80211_BSS_SELECT_ATTR_INVALID; + + if (attr[NL80211_BSS_SELECT_ATTR_RSSI]) + bss_select->behaviour = NL80211_BSS_SELECT_ATTR_RSSI; + + if (attr[NL80211_BSS_SELECT_ATTR_BAND_PREF]) { + bss_select->behaviour = NL80211_BSS_SELECT_ATTR_BAND_PREF; + bss_select->param.band_pref = + nla_get_u32(attr[NL80211_BSS_SELECT_ATTR_BAND_PREF]); + if (!is_band_valid(wiphy, bss_select->param.band_pref)) + return -EINVAL; + } + + if (attr[NL80211_BSS_SELECT_ATTR_RSSI_ADJUST]) { + struct nl80211_bss_select_rssi_adjust *adj_param; + + adj_param = nla_data(attr[NL80211_BSS_SELECT_ATTR_RSSI_ADJUST]); + bss_select->behaviour = NL80211_BSS_SELECT_ATTR_RSSI_ADJUST; + bss_select->param.adjust.band = adj_param->band; + bss_select->param.adjust.delta = adj_param->delta; + if (!is_band_valid(wiphy, bss_select->param.adjust.band)) + return -EINVAL; + } + + /* user-space did not provide behaviour attribute */ + if (bss_select->behaviour == __NL80211_BSS_SELECT_ATTR_INVALID) + return -EINVAL; + + if (!(wiphy->bss_select_support & BIT(bss_select->behaviour))) + return -EINVAL; + + return 0; +} + +int nl80211_parse_random_mac(struct nlattr **attrs, + u8 *mac_addr, u8 *mac_addr_mask) +{ + int i; + + if (!attrs[NL80211_ATTR_MAC] && !attrs[NL80211_ATTR_MAC_MASK]) { + 
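		/* Neither attribute given: fall back to the conventional
		 * default of a locally administered, unicast address with
		 * every other bit left to the driver to randomize. Spelled
		 * out, the values set below are
		 *
		 *	addr = 02:00:00:00:00:00, mask = 03:00:00:00:00:00
		 *
		 * i.e. only the local-admin bit (0x02) and the multicast bit
		 * (0x01) of the first octet are fixed, so any generated
		 * address a satisfies (a[0] & 0x03) == 0x02.
		 */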
eth_zero_addr(mac_addr); + eth_zero_addr(mac_addr_mask); + mac_addr[0] = 0x2; + mac_addr_mask[0] = 0x3; + + return 0; + } + + /* need both or none */ + if (!attrs[NL80211_ATTR_MAC] || !attrs[NL80211_ATTR_MAC_MASK]) + return -EINVAL; + + memcpy(mac_addr, nla_data(attrs[NL80211_ATTR_MAC]), ETH_ALEN); + memcpy(mac_addr_mask, nla_data(attrs[NL80211_ATTR_MAC_MASK]), ETH_ALEN); + + /* don't allow or configure an mcast address */ + if (!is_multicast_ether_addr(mac_addr_mask) || + is_multicast_ether_addr(mac_addr)) + return -EINVAL; + + /* + * allow users to pass a MAC address that has bits set outside + * of the mask, but don't bother drivers with having to deal + * with such bits + */ + for (i = 0; i < ETH_ALEN; i++) + mac_addr[i] &= mac_addr_mask[i]; + + return 0; +} + +static bool cfg80211_off_channel_oper_allowed(struct wireless_dev *wdev, + struct ieee80211_channel *chan) +{ + unsigned int link_id; + bool all_ok = true; + + ASSERT_WDEV_LOCK(wdev); + + if (!cfg80211_beaconing_iface_active(wdev)) + return true; + + /* + * FIXME: check if we have a free HW resource/link for chan + * + * This, as well as the FIXME below, requires knowing the link + * capabilities of the hardware. + */ + + /* we cannot leave radar channels */ + for_each_valid_link(wdev, link_id) { + struct cfg80211_chan_def *chandef; + + chandef = wdev_chandef(wdev, link_id); + if (!chandef || !chandef->chan) + continue; + + /* + * FIXME: don't require all_ok, but rather check only the + * correct HW resource/link onto which 'chan' falls, + * as only that link leaves the channel for doing + * the off-channel operation. + */ + + if (chandef->chan->flags & IEEE80211_CHAN_RADAR) + all_ok = false; + } + + if (all_ok) + return true; + + return regulatory_pre_cac_allowed(wdev->wiphy); +} + +static bool nl80211_check_scan_feat(struct wiphy *wiphy, u32 flags, u32 flag, + enum nl80211_ext_feature_index feat) +{ + if (!(flags & flag)) + return true; + if (wiphy_ext_feature_isset(wiphy, feat)) + return true; + return false; +} + +static int +nl80211_check_scan_flags(struct wiphy *wiphy, struct wireless_dev *wdev, + void *request, struct nlattr **attrs, + bool is_sched_scan) +{ + u8 *mac_addr, *mac_addr_mask; + u32 *flags; + enum nl80211_feature_flags randomness_flag; + + if (!attrs[NL80211_ATTR_SCAN_FLAGS]) + return 0; + + if (is_sched_scan) { + struct cfg80211_sched_scan_request *req = request; + + randomness_flag = wdev ? 
+ NL80211_FEATURE_SCHED_SCAN_RANDOM_MAC_ADDR : + NL80211_FEATURE_ND_RANDOM_MAC_ADDR; + flags = &req->flags; + mac_addr = req->mac_addr; + mac_addr_mask = req->mac_addr_mask; + } else { + struct cfg80211_scan_request *req = request; + + randomness_flag = NL80211_FEATURE_SCAN_RANDOM_MAC_ADDR; + flags = &req->flags; + mac_addr = req->mac_addr; + mac_addr_mask = req->mac_addr_mask; + } + + *flags = nla_get_u32(attrs[NL80211_ATTR_SCAN_FLAGS]); + + if (((*flags & NL80211_SCAN_FLAG_LOW_PRIORITY) && + !(wiphy->features & NL80211_FEATURE_LOW_PRIORITY_SCAN)) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_LOW_SPAN, + NL80211_EXT_FEATURE_LOW_SPAN_SCAN) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_LOW_POWER, + NL80211_EXT_FEATURE_LOW_POWER_SCAN) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_HIGH_ACCURACY, + NL80211_EXT_FEATURE_HIGH_ACCURACY_SCAN) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_FILS_MAX_CHANNEL_TIME, + NL80211_EXT_FEATURE_FILS_MAX_CHANNEL_TIME) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_ACCEPT_BCAST_PROBE_RESP, + NL80211_EXT_FEATURE_ACCEPT_BCAST_PROBE_RESP) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION, + NL80211_EXT_FEATURE_OCE_PROBE_REQ_DEFERRAL_SUPPRESSION) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_OCE_PROBE_REQ_HIGH_TX_RATE, + NL80211_EXT_FEATURE_OCE_PROBE_REQ_HIGH_TX_RATE) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_RANDOM_SN, + NL80211_EXT_FEATURE_SCAN_RANDOM_SN) || + !nl80211_check_scan_feat(wiphy, *flags, + NL80211_SCAN_FLAG_MIN_PREQ_CONTENT, + NL80211_EXT_FEATURE_SCAN_MIN_PREQ_CONTENT)) + return -EOPNOTSUPP; + + if (*flags & NL80211_SCAN_FLAG_RANDOM_ADDR) { + int err; + + if (!(wiphy->features & randomness_flag) || + (wdev && wdev->connected)) + return -EOPNOTSUPP; + + err = nl80211_parse_random_mac(attrs, mac_addr, mac_addr_mask); + if (err) + return err; + } + + return 0; +} + +static int nl80211_trigger_scan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_scan_request *request; + struct nlattr *scan_freqs = NULL; + bool scan_freqs_khz = false; + struct nlattr *attr; + struct wiphy *wiphy; + int err, tmp, n_ssids = 0, n_channels, i; + size_t ie_len; + + wiphy = &rdev->wiphy; + + if (wdev->iftype == NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!rdev->ops->scan) + return -EOPNOTSUPP; + + if (rdev->scan_req || rdev->scan_msg) + return -EBUSY; + + if (info->attrs[NL80211_ATTR_SCAN_FREQ_KHZ]) { + if (!wiphy_ext_feature_isset(wiphy, + NL80211_EXT_FEATURE_SCAN_FREQ_KHZ)) + return -EOPNOTSUPP; + scan_freqs = info->attrs[NL80211_ATTR_SCAN_FREQ_KHZ]; + scan_freqs_khz = true; + } else if (info->attrs[NL80211_ATTR_SCAN_FREQUENCIES]) + scan_freqs = info->attrs[NL80211_ATTR_SCAN_FREQUENCIES]; + + if (scan_freqs) { + n_channels = validate_scan_freqs(scan_freqs); + if (!n_channels) + return -EINVAL; + } else { + n_channels = ieee80211_get_num_supported_channels(wiphy); + } + + if (info->attrs[NL80211_ATTR_SCAN_SSIDS]) + nla_for_each_nested(attr, info->attrs[NL80211_ATTR_SCAN_SSIDS], tmp) + n_ssids++; + + if (n_ssids > wiphy->max_scan_ssids) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_IE]) + ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + else + ie_len = 0; + + if (ie_len > wiphy->max_scan_ie_len) + return -EINVAL; + + request = kzalloc(sizeof(*request) + + 
sizeof(*request->ssids) * n_ssids + + sizeof(*request->channels) * n_channels + + ie_len, GFP_KERNEL); + if (!request) + return -ENOMEM; + + if (n_ssids) + request->ssids = (void *)&request->channels[n_channels]; + request->n_ssids = n_ssids; + if (ie_len) { + if (n_ssids) + request->ie = (void *)(request->ssids + n_ssids); + else + request->ie = (void *)(request->channels + n_channels); + } + + i = 0; + if (scan_freqs) { + /* user specified, bail out if channel not found */ + nla_for_each_nested(attr, scan_freqs, tmp) { + struct ieee80211_channel *chan; + int freq = nla_get_u32(attr); + + if (!scan_freqs_khz) + freq = MHZ_TO_KHZ(freq); + + chan = ieee80211_get_channel_khz(wiphy, freq); + if (!chan) { + err = -EINVAL; + goto out_free; + } + + /* ignore disabled channels */ + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + request->channels[i] = chan; + i++; + } + } else { + enum nl80211_band band; + + /* all channels */ + for (band = 0; band < NUM_NL80211_BANDS; band++) { + int j; + + if (!wiphy->bands[band]) + continue; + for (j = 0; j < wiphy->bands[band]->n_channels; j++) { + struct ieee80211_channel *chan; + + chan = &wiphy->bands[band]->channels[j]; + + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + request->channels[i] = chan; + i++; + } + } + } + + if (!i) { + err = -EINVAL; + goto out_free; + } + + request->n_channels = i; + + wdev_lock(wdev); + for (i = 0; i < request->n_channels; i++) { + struct ieee80211_channel *chan = request->channels[i]; + + /* if we can go off-channel to the target channel we're good */ + if (cfg80211_off_channel_oper_allowed(wdev, chan)) + continue; + + if (!cfg80211_wdev_on_sub_chan(wdev, chan, true)) { + wdev_unlock(wdev); + err = -EBUSY; + goto out_free; + } + } + wdev_unlock(wdev); + + i = 0; + if (n_ssids) { + nla_for_each_nested(attr, info->attrs[NL80211_ATTR_SCAN_SSIDS], tmp) { + if (nla_len(attr) > IEEE80211_MAX_SSID_LEN) { + err = -EINVAL; + goto out_free; + } + request->ssids[i].ssid_len = nla_len(attr); + memcpy(request->ssids[i].ssid, nla_data(attr), nla_len(attr)); + i++; + } + } + + if (info->attrs[NL80211_ATTR_IE]) { + request->ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + memcpy((void *)request->ie, + nla_data(info->attrs[NL80211_ATTR_IE]), + request->ie_len); + } + + for (i = 0; i < NUM_NL80211_BANDS; i++) + if (wiphy->bands[i]) + request->rates[i] = + (1 << wiphy->bands[i]->n_bitrates) - 1; + + if (info->attrs[NL80211_ATTR_SCAN_SUPP_RATES]) { + nla_for_each_nested(attr, + info->attrs[NL80211_ATTR_SCAN_SUPP_RATES], + tmp) { + enum nl80211_band band = nla_type(attr); + + if (band < 0 || band >= NUM_NL80211_BANDS) { + err = -EINVAL; + goto out_free; + } + + if (!wiphy->bands[band]) + continue; + + err = ieee80211_get_ratemask(wiphy->bands[band], + nla_data(attr), + nla_len(attr), + &request->rates[band]); + if (err) + goto out_free; + } + } + + if (info->attrs[NL80211_ATTR_MEASUREMENT_DURATION]) { + request->duration = + nla_get_u16(info->attrs[NL80211_ATTR_MEASUREMENT_DURATION]); + request->duration_mandatory = + nla_get_flag(info->attrs[NL80211_ATTR_MEASUREMENT_DURATION_MANDATORY]); + } + + err = nl80211_check_scan_flags(wiphy, wdev, request, info->attrs, + false); + if (err) + goto out_free; + + request->no_cck = + nla_get_flag(info->attrs[NL80211_ATTR_TX_NO_CCK_RATE]); + + /* Initial implementation used NL80211_ATTR_MAC to set the specific + * BSSID to scan for. This was problematic because that same attribute + * was already used for another purpose (local random MAC address). 
The + * NL80211_ATTR_BSSID attribute was added to fix this. For backwards + * compatibility with older userspace components, also use the + * NL80211_ATTR_MAC value here if it can be determined to be used for + * the specific BSSID use case instead of the random MAC address + * (NL80211_ATTR_SCAN_FLAGS is used to enable random MAC address use). + */ + if (info->attrs[NL80211_ATTR_BSSID]) + memcpy(request->bssid, + nla_data(info->attrs[NL80211_ATTR_BSSID]), ETH_ALEN); + else if (!(request->flags & NL80211_SCAN_FLAG_RANDOM_ADDR) && + info->attrs[NL80211_ATTR_MAC]) + memcpy(request->bssid, nla_data(info->attrs[NL80211_ATTR_MAC]), + ETH_ALEN); + else + eth_broadcast_addr(request->bssid); + + request->wdev = wdev; + request->wiphy = &rdev->wiphy; + request->scan_start = jiffies; + + rdev->scan_req = request; + err = cfg80211_scan(rdev); + + if (err) + goto out_free; + + nl80211_send_scan_start(rdev, wdev); + dev_hold(wdev->netdev); + + return 0; + + out_free: + rdev->scan_req = NULL; + kfree(request); + + return err; +} + +static int nl80211_abort_scan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + if (!rdev->ops->abort_scan) + return -EOPNOTSUPP; + + if (rdev->scan_msg) + return 0; + + if (!rdev->scan_req) + return -ENOENT; + + rdev_abort_scan(rdev, wdev); + return 0; +} + +static int +nl80211_parse_sched_scan_plans(struct wiphy *wiphy, int n_plans, + struct cfg80211_sched_scan_request *request, + struct nlattr **attrs) +{ + int tmp, err, i = 0; + struct nlattr *attr; + + if (!attrs[NL80211_ATTR_SCHED_SCAN_PLANS]) { + u32 interval; + + /* + * If scan plans are not specified, + * %NL80211_ATTR_SCHED_SCAN_INTERVAL will be specified. In this + * case one scan plan will be set with the specified scan + * interval and infinite number of iterations. 
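		 * As a worked example: NL80211_ATTR_SCHED_SCAN_INTERVAL is
		 * given in milliseconds while scan_plans[].interval is kept
		 * in seconds, so an interval of 10000 ms becomes
		 * DIV_ROUND_UP(10000, MSEC_PER_SEC) = 10 s, which is then
		 * clamped to wiphy->max_sched_scan_plan_interval below.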
+ */ + interval = nla_get_u32(attrs[NL80211_ATTR_SCHED_SCAN_INTERVAL]); + if (!interval) + return -EINVAL; + + request->scan_plans[0].interval = + DIV_ROUND_UP(interval, MSEC_PER_SEC); + if (!request->scan_plans[0].interval) + return -EINVAL; + + if (request->scan_plans[0].interval > + wiphy->max_sched_scan_plan_interval) + request->scan_plans[0].interval = + wiphy->max_sched_scan_plan_interval; + + return 0; + } + + nla_for_each_nested(attr, attrs[NL80211_ATTR_SCHED_SCAN_PLANS], tmp) { + struct nlattr *plan[NL80211_SCHED_SCAN_PLAN_MAX + 1]; + + if (WARN_ON(i >= n_plans)) + return -EINVAL; + + err = nla_parse_nested_deprecated(plan, + NL80211_SCHED_SCAN_PLAN_MAX, + attr, nl80211_plan_policy, + NULL); + if (err) + return err; + + if (!plan[NL80211_SCHED_SCAN_PLAN_INTERVAL]) + return -EINVAL; + + request->scan_plans[i].interval = + nla_get_u32(plan[NL80211_SCHED_SCAN_PLAN_INTERVAL]); + if (!request->scan_plans[i].interval || + request->scan_plans[i].interval > + wiphy->max_sched_scan_plan_interval) + return -EINVAL; + + if (plan[NL80211_SCHED_SCAN_PLAN_ITERATIONS]) { + request->scan_plans[i].iterations = + nla_get_u32(plan[NL80211_SCHED_SCAN_PLAN_ITERATIONS]); + if (!request->scan_plans[i].iterations || + (request->scan_plans[i].iterations > + wiphy->max_sched_scan_plan_iterations)) + return -EINVAL; + } else if (i < n_plans - 1) { + /* + * All scan plans but the last one must specify + * a finite number of iterations + */ + return -EINVAL; + } + + i++; + } + + /* + * The last scan plan must not specify the number of + * iterations, it is supposed to run infinitely + */ + if (request->scan_plans[n_plans - 1].iterations) + return -EINVAL; + + return 0; +} + +static int +nl80211_parse_sched_scan_per_band_rssi(struct wiphy *wiphy, + struct cfg80211_match_set *match_sets, + struct nlattr *tb_band_rssi, + s32 rssi_thold) +{ + struct nlattr *attr; + int i, tmp, ret = 0; + + if (!wiphy_ext_feature_isset(wiphy, + NL80211_EXT_FEATURE_SCHED_SCAN_BAND_SPECIFIC_RSSI_THOLD)) { + if (tb_band_rssi) + ret = -EOPNOTSUPP; + else + for (i = 0; i < NUM_NL80211_BANDS; i++) + match_sets->per_band_rssi_thold[i] = + NL80211_SCAN_RSSI_THOLD_OFF; + return ret; + } + + for (i = 0; i < NUM_NL80211_BANDS; i++) + match_sets->per_band_rssi_thold[i] = rssi_thold; + + nla_for_each_nested(attr, tb_band_rssi, tmp) { + enum nl80211_band band = nla_type(attr); + + if (band < 0 || band >= NUM_NL80211_BANDS) + return -EINVAL; + + match_sets->per_band_rssi_thold[band] = nla_get_s32(attr); + } + + return 0; +} + +static struct cfg80211_sched_scan_request * +nl80211_parse_sched_scan(struct wiphy *wiphy, struct wireless_dev *wdev, + struct nlattr **attrs, int max_match_sets) +{ + struct cfg80211_sched_scan_request *request; + struct nlattr *attr; + int err, tmp, n_ssids = 0, n_match_sets = 0, n_channels, i, n_plans = 0; + enum nl80211_band band; + size_t ie_len; + struct nlattr *tb[NL80211_SCHED_SCAN_MATCH_ATTR_MAX + 1]; + s32 default_match_rssi = NL80211_SCAN_RSSI_THOLD_OFF; + + if (attrs[NL80211_ATTR_SCAN_FREQUENCIES]) { + n_channels = validate_scan_freqs( + attrs[NL80211_ATTR_SCAN_FREQUENCIES]); + if (!n_channels) + return ERR_PTR(-EINVAL); + } else { + n_channels = ieee80211_get_num_supported_channels(wiphy); + } + + if (attrs[NL80211_ATTR_SCAN_SSIDS]) + nla_for_each_nested(attr, attrs[NL80211_ATTR_SCAN_SSIDS], + tmp) + n_ssids++; + + if (n_ssids > wiphy->max_sched_scan_ssids) + return ERR_PTR(-EINVAL); + + /* + * First, count the number of 'real' matchsets. 
Due to an issue with + * the old implementation, matchsets containing only the RSSI attribute + * (NL80211_SCHED_SCAN_MATCH_ATTR_RSSI) are considered as the 'default' + * RSSI for all matchsets, rather than their own matchset for reporting + * all APs with a strong RSSI. This is needed to be compatible with + * older userspace that treated a matchset with only the RSSI as the + * global RSSI for all other matchsets - if there are other matchsets. + */ + if (attrs[NL80211_ATTR_SCHED_SCAN_MATCH]) { + nla_for_each_nested(attr, + attrs[NL80211_ATTR_SCHED_SCAN_MATCH], + tmp) { + struct nlattr *rssi; + + err = nla_parse_nested_deprecated(tb, + NL80211_SCHED_SCAN_MATCH_ATTR_MAX, + attr, + nl80211_match_policy, + NULL); + if (err) + return ERR_PTR(err); + + /* SSID and BSSID are mutually exclusive */ + if (tb[NL80211_SCHED_SCAN_MATCH_ATTR_SSID] && + tb[NL80211_SCHED_SCAN_MATCH_ATTR_BSSID]) + return ERR_PTR(-EINVAL); + + /* add other standalone attributes here */ + if (tb[NL80211_SCHED_SCAN_MATCH_ATTR_SSID] || + tb[NL80211_SCHED_SCAN_MATCH_ATTR_BSSID]) { + n_match_sets++; + continue; + } + rssi = tb[NL80211_SCHED_SCAN_MATCH_ATTR_RSSI]; + if (rssi) + default_match_rssi = nla_get_s32(rssi); + } + } + + /* However, if there's no other matchset, add the RSSI one */ + if (!n_match_sets && default_match_rssi != NL80211_SCAN_RSSI_THOLD_OFF) + n_match_sets = 1; + + if (n_match_sets > max_match_sets) + return ERR_PTR(-EINVAL); + + if (attrs[NL80211_ATTR_IE]) + ie_len = nla_len(attrs[NL80211_ATTR_IE]); + else + ie_len = 0; + + if (ie_len > wiphy->max_sched_scan_ie_len) + return ERR_PTR(-EINVAL); + + if (attrs[NL80211_ATTR_SCHED_SCAN_PLANS]) { + /* + * NL80211_ATTR_SCHED_SCAN_INTERVAL must not be specified since + * each scan plan already specifies its own interval + */ + if (attrs[NL80211_ATTR_SCHED_SCAN_INTERVAL]) + return ERR_PTR(-EINVAL); + + nla_for_each_nested(attr, + attrs[NL80211_ATTR_SCHED_SCAN_PLANS], tmp) + n_plans++; + } else { + /* + * The scan interval attribute is kept for backward + * compatibility. If no scan plans are specified and sched scan + * interval is specified, one scan plan will be set with this + * scan interval and infinite number of iterations. 
+ */ + if (!attrs[NL80211_ATTR_SCHED_SCAN_INTERVAL]) + return ERR_PTR(-EINVAL); + + n_plans = 1; + } + + if (!n_plans || n_plans > wiphy->max_sched_scan_plans) + return ERR_PTR(-EINVAL); + + if (!wiphy_ext_feature_isset( + wiphy, NL80211_EXT_FEATURE_SCHED_SCAN_RELATIVE_RSSI) && + (attrs[NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI] || + attrs[NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST])) + return ERR_PTR(-EINVAL); + + request = kzalloc(sizeof(*request) + + sizeof(*request->ssids) * n_ssids + + sizeof(*request->match_sets) * n_match_sets + + sizeof(*request->scan_plans) * n_plans + + sizeof(*request->channels) * n_channels + + ie_len, GFP_KERNEL); + if (!request) + return ERR_PTR(-ENOMEM); + + if (n_ssids) + request->ssids = (void *)&request->channels[n_channels]; + request->n_ssids = n_ssids; + if (ie_len) { + if (n_ssids) + request->ie = (void *)(request->ssids + n_ssids); + else + request->ie = (void *)(request->channels + n_channels); + } + + if (n_match_sets) { + if (request->ie) + request->match_sets = (void *)(request->ie + ie_len); + else if (n_ssids) + request->match_sets = + (void *)(request->ssids + n_ssids); + else + request->match_sets = + (void *)(request->channels + n_channels); + } + request->n_match_sets = n_match_sets; + + if (n_match_sets) + request->scan_plans = (void *)(request->match_sets + + n_match_sets); + else if (request->ie) + request->scan_plans = (void *)(request->ie + ie_len); + else if (n_ssids) + request->scan_plans = (void *)(request->ssids + n_ssids); + else + request->scan_plans = (void *)(request->channels + n_channels); + + request->n_scan_plans = n_plans; + + i = 0; + if (attrs[NL80211_ATTR_SCAN_FREQUENCIES]) { + /* user specified, bail out if channel not found */ + nla_for_each_nested(attr, + attrs[NL80211_ATTR_SCAN_FREQUENCIES], + tmp) { + struct ieee80211_channel *chan; + + chan = ieee80211_get_channel(wiphy, nla_get_u32(attr)); + + if (!chan) { + err = -EINVAL; + goto out_free; + } + + /* ignore disabled channels */ + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + request->channels[i] = chan; + i++; + } + } else { + /* all channels */ + for (band = 0; band < NUM_NL80211_BANDS; band++) { + int j; + + if (!wiphy->bands[band]) + continue; + for (j = 0; j < wiphy->bands[band]->n_channels; j++) { + struct ieee80211_channel *chan; + + chan = &wiphy->bands[band]->channels[j]; + + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + request->channels[i] = chan; + i++; + } + } + } + + if (!i) { + err = -EINVAL; + goto out_free; + } + + request->n_channels = i; + + i = 0; + if (n_ssids) { + nla_for_each_nested(attr, attrs[NL80211_ATTR_SCAN_SSIDS], + tmp) { + if (nla_len(attr) > IEEE80211_MAX_SSID_LEN) { + err = -EINVAL; + goto out_free; + } + request->ssids[i].ssid_len = nla_len(attr); + memcpy(request->ssids[i].ssid, nla_data(attr), + nla_len(attr)); + i++; + } + } + + i = 0; + if (attrs[NL80211_ATTR_SCHED_SCAN_MATCH]) { + nla_for_each_nested(attr, + attrs[NL80211_ATTR_SCHED_SCAN_MATCH], + tmp) { + struct nlattr *ssid, *bssid, *rssi; + + err = nla_parse_nested_deprecated(tb, + NL80211_SCHED_SCAN_MATCH_ATTR_MAX, + attr, + nl80211_match_policy, + NULL); + if (err) + goto out_free; + ssid = tb[NL80211_SCHED_SCAN_MATCH_ATTR_SSID]; + bssid = tb[NL80211_SCHED_SCAN_MATCH_ATTR_BSSID]; + + if (!ssid && !bssid) { + i++; + continue; + } + + if (WARN_ON(i >= n_match_sets)) { + /* this indicates a programming error, + * the loop above should have verified + * things properly + */ + err = -EINVAL; + goto out_free; + } + + if (ssid) { + 
memcpy(request->match_sets[i].ssid.ssid, + nla_data(ssid), nla_len(ssid)); + request->match_sets[i].ssid.ssid_len = + nla_len(ssid); + } + if (bssid) + memcpy(request->match_sets[i].bssid, + nla_data(bssid), ETH_ALEN); + + /* special attribute - old implementation w/a */ + request->match_sets[i].rssi_thold = default_match_rssi; + rssi = tb[NL80211_SCHED_SCAN_MATCH_ATTR_RSSI]; + if (rssi) + request->match_sets[i].rssi_thold = + nla_get_s32(rssi); + + /* Parse per band RSSI attribute */ + err = nl80211_parse_sched_scan_per_band_rssi(wiphy, + &request->match_sets[i], + tb[NL80211_SCHED_SCAN_MATCH_PER_BAND_RSSI], + request->match_sets[i].rssi_thold); + if (err) + goto out_free; + + i++; + } + + /* there was no other matchset, so the RSSI one is alone */ + if (i == 0 && n_match_sets) + request->match_sets[0].rssi_thold = default_match_rssi; + + request->min_rssi_thold = INT_MAX; + for (i = 0; i < n_match_sets; i++) + request->min_rssi_thold = + min(request->match_sets[i].rssi_thold, + request->min_rssi_thold); + } else { + request->min_rssi_thold = NL80211_SCAN_RSSI_THOLD_OFF; + } + + if (ie_len) { + request->ie_len = ie_len; + memcpy((void *)request->ie, + nla_data(attrs[NL80211_ATTR_IE]), + request->ie_len); + } + + err = nl80211_check_scan_flags(wiphy, wdev, request, attrs, true); + if (err) + goto out_free; + + if (attrs[NL80211_ATTR_SCHED_SCAN_DELAY]) + request->delay = + nla_get_u32(attrs[NL80211_ATTR_SCHED_SCAN_DELAY]); + + if (attrs[NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI]) { + request->relative_rssi = nla_get_s8( + attrs[NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI]); + request->relative_rssi_set = true; + } + + if (request->relative_rssi_set && + attrs[NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST]) { + struct nl80211_bss_select_rssi_adjust *rssi_adjust; + + rssi_adjust = nla_data( + attrs[NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST]); + request->rssi_adjust.band = rssi_adjust->band; + request->rssi_adjust.delta = rssi_adjust->delta; + if (!is_band_valid(wiphy, request->rssi_adjust.band)) { + err = -EINVAL; + goto out_free; + } + } + + err = nl80211_parse_sched_scan_plans(wiphy, n_plans, request, attrs); + if (err) + goto out_free; + + request->scan_start = jiffies; + + return request; + +out_free: + kfree(request); + return ERR_PTR(err); +} + +static int nl80211_start_sched_scan(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_sched_scan_request *sched_scan_req; + bool want_multi; + int err; + + if (!rdev->wiphy.max_sched_scan_reqs || !rdev->ops->sched_scan_start) + return -EOPNOTSUPP; + + want_multi = info->attrs[NL80211_ATTR_SCHED_SCAN_MULTI]; + err = cfg80211_sched_scan_req_possible(rdev, want_multi); + if (err) + return err; + + sched_scan_req = nl80211_parse_sched_scan(&rdev->wiphy, wdev, + info->attrs, + rdev->wiphy.max_match_sets); + + err = PTR_ERR_OR_ZERO(sched_scan_req); + if (err) + goto out_err; + + /* leave request id zero for legacy request + * or if driver does not support multi-scheduled scan + */ + if (want_multi && rdev->wiphy.max_sched_scan_reqs > 1) + sched_scan_req->reqid = cfg80211_assign_cookie(rdev); + + err = rdev_sched_scan_start(rdev, dev, sched_scan_req); + if (err) + goto out_free; + + sched_scan_req->dev = dev; + sched_scan_req->wiphy = &rdev->wiphy; + + if (info->attrs[NL80211_ATTR_SOCKET_OWNER]) + sched_scan_req->owner_nlportid = info->snd_portid; + + cfg80211_add_sched_scan_req(rdev, sched_scan_req); + 
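	/* The request is now tracked on the rdev's scheduled-scan list;
	 * notify userspace listeners that the scheduled scan has started.
	 */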
+ nl80211_send_sched_scan(sched_scan_req, NL80211_CMD_START_SCHED_SCAN); + return 0; + +out_free: + kfree(sched_scan_req); +out_err: + return err; +} + +static int nl80211_stop_sched_scan(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_sched_scan_request *req; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + u64 cookie; + + if (!rdev->wiphy.max_sched_scan_reqs || !rdev->ops->sched_scan_stop) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_COOKIE]) { + cookie = nla_get_u64(info->attrs[NL80211_ATTR_COOKIE]); + return __cfg80211_stop_sched_scan(rdev, cookie, false); + } + + req = list_first_or_null_rcu(&rdev->sched_scan_req_list, + struct cfg80211_sched_scan_request, + list); + if (!req || req->reqid || + (req->owner_nlportid && + req->owner_nlportid != info->snd_portid)) + return -ENOENT; + + return cfg80211_stop_sched_scan_req(rdev, req, false); +} + +static int nl80211_start_radar_detection(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_chan_def chandef; + enum nl80211_dfs_regions dfs_region; + unsigned int cac_time_ms; + int err = -EINVAL; + + flush_delayed_work(&rdev->dfs_update_channels_wk); + + wiphy_lock(wiphy); + + dfs_region = reg_get_dfs_region(wiphy); + if (dfs_region == NL80211_DFS_UNSET) + goto unlock; + + err = nl80211_parse_chandef(rdev, info, &chandef); + if (err) + goto unlock; + + err = cfg80211_chandef_dfs_required(wiphy, &chandef, wdev->iftype); + if (err < 0) + goto unlock; + + if (err == 0) { + err = -EINVAL; + goto unlock; + } + + if (!cfg80211_chandef_dfs_usable(wiphy, &chandef)) { + err = -EINVAL; + goto unlock; + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_RADAR_BACKGROUND])) { + err = cfg80211_start_background_radar_detection(rdev, wdev, + &chandef); + goto unlock; + } + + if (netif_carrier_ok(dev)) { + err = -EBUSY; + goto unlock; + } + + if (wdev->cac_started) { + err = -EBUSY; + goto unlock; + } + + /* CAC start is offloaded to HW and can't be started manually */ + if (wiphy_ext_feature_isset(wiphy, NL80211_EXT_FEATURE_DFS_OFFLOAD)) { + err = -EOPNOTSUPP; + goto unlock; + } + + if (!rdev->ops->start_radar_detection) { + err = -EOPNOTSUPP; + goto unlock; + } + + cac_time_ms = cfg80211_chandef_dfs_cac_time(&rdev->wiphy, &chandef); + if (WARN_ON(!cac_time_ms)) + cac_time_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + + err = rdev_start_radar_detection(rdev, dev, &chandef, cac_time_ms); + if (!err) { + wdev->links[0].ap.chandef = chandef; + wdev->cac_started = true; + wdev->cac_start_time = jiffies; + wdev->cac_time_ms = cac_time_ms; + } +unlock: + wiphy_unlock(wiphy); + + return err; +} + +static int nl80211_notify_radar_detection(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_chan_def chandef; + enum nl80211_dfs_regions dfs_region; + int err; + + dfs_region = reg_get_dfs_region(wiphy); + if (dfs_region == NL80211_DFS_UNSET) { + GENL_SET_ERR_MSG(info, + "DFS Region is not set. 
Unexpected Radar indication"); + return -EINVAL; + } + + err = nl80211_parse_chandef(rdev, info, &chandef); + if (err) { + GENL_SET_ERR_MSG(info, "Unable to extract chandef info"); + return err; + } + + err = cfg80211_chandef_dfs_required(wiphy, &chandef, wdev->iftype); + if (err < 0) { + GENL_SET_ERR_MSG(info, "chandef is invalid"); + return err; + } + + if (err == 0) { + GENL_SET_ERR_MSG(info, + "Unexpected Radar indication for chandef/iftype"); + return -EINVAL; + } + + /* Do not process this notification if radar is already detected + * by kernel on this channel, and return success. + */ + if (chandef.chan->dfs_state == NL80211_DFS_UNAVAILABLE) + return 0; + + cfg80211_set_dfs_state(wiphy, &chandef, NL80211_DFS_UNAVAILABLE); + + cfg80211_sched_dfs_chan_update(rdev); + + rdev->radar_chandef = chandef; + + /* Propagate this notification to other radios as well */ + queue_work(cfg80211_wq, &rdev->propagate_radar_detect_wk); + + return 0; +} + +static int nl80211_channel_switch(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + unsigned int link_id = nl80211_link_id(info->attrs); + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_csa_settings params; + struct nlattr **csa_attrs = NULL; + int err; + bool need_new_beacon = false; + bool need_handle_dfs_flag = true; + int len, i; + u32 cs_count; + + if (!rdev->ops->channel_switch || + !(rdev->wiphy.flags & WIPHY_FLAG_HAS_CHANNEL_SWITCH)) + return -EOPNOTSUPP; + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + need_new_beacon = true; + /* For all modes except AP the handle_dfs flag needs to be + * supplied to tell the kernel that userspace will handle radar + * events when they happen. Otherwise a switch to a channel + * requiring DFS will be rejected. + */ + need_handle_dfs_flag = false; + + /* useless if AP is not running */ + if (!wdev->links[link_id].ap.beacon_interval) + return -ENOTCONN; + break; + case NL80211_IFTYPE_ADHOC: + if (!wdev->u.ibss.ssid_len) + return -ENOTCONN; + break; + case NL80211_IFTYPE_MESH_POINT: + if (!wdev->u.mesh.id_len) + return -ENOTCONN; + break; + default: + return -EOPNOTSUPP; + } + + memset(¶ms, 0, sizeof(params)); + params.beacon_csa.ftm_responder = -1; + + if (!info->attrs[NL80211_ATTR_WIPHY_FREQ] || + !info->attrs[NL80211_ATTR_CH_SWITCH_COUNT]) + return -EINVAL; + + /* only important for AP, IBSS and mesh create IEs internally */ + if (need_new_beacon && !info->attrs[NL80211_ATTR_CSA_IES]) + return -EINVAL; + + /* Even though the attribute is u32, the specification says + * u8, so let's make sure we don't overflow. 
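	 * For example, a cs_count of 300 could not be represented in the
	 * one-octet Channel Switch Count field of the CSA element, so it is
	 * rejected here rather than silently truncated.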
+ */ + cs_count = nla_get_u32(info->attrs[NL80211_ATTR_CH_SWITCH_COUNT]); + if (cs_count > 255) + return -EINVAL; + + params.count = cs_count; + + if (!need_new_beacon) + goto skip_beacons; + + err = nl80211_parse_beacon(rdev, info->attrs, ¶ms.beacon_after); + if (err) + goto free; + + csa_attrs = kcalloc(NL80211_ATTR_MAX + 1, sizeof(*csa_attrs), + GFP_KERNEL); + if (!csa_attrs) { + err = -ENOMEM; + goto free; + } + + err = nla_parse_nested_deprecated(csa_attrs, NL80211_ATTR_MAX, + info->attrs[NL80211_ATTR_CSA_IES], + nl80211_policy, info->extack); + if (err) + goto free; + + err = nl80211_parse_beacon(rdev, csa_attrs, ¶ms.beacon_csa); + if (err) + goto free; + + if (!csa_attrs[NL80211_ATTR_CNTDWN_OFFS_BEACON]) { + err = -EINVAL; + goto free; + } + + len = nla_len(csa_attrs[NL80211_ATTR_CNTDWN_OFFS_BEACON]); + if (!len || (len % sizeof(u16))) { + err = -EINVAL; + goto free; + } + + params.n_counter_offsets_beacon = len / sizeof(u16); + if (rdev->wiphy.max_num_csa_counters && + (params.n_counter_offsets_beacon > + rdev->wiphy.max_num_csa_counters)) { + err = -EINVAL; + goto free; + } + + params.counter_offsets_beacon = + nla_data(csa_attrs[NL80211_ATTR_CNTDWN_OFFS_BEACON]); + + /* sanity checks - counters should fit and be the same */ + for (i = 0; i < params.n_counter_offsets_beacon; i++) { + u16 offset = params.counter_offsets_beacon[i]; + + if (offset >= params.beacon_csa.tail_len) { + err = -EINVAL; + goto free; + } + + if (params.beacon_csa.tail[offset] != params.count) { + err = -EINVAL; + goto free; + } + } + + if (csa_attrs[NL80211_ATTR_CNTDWN_OFFS_PRESP]) { + len = nla_len(csa_attrs[NL80211_ATTR_CNTDWN_OFFS_PRESP]); + if (!len || (len % sizeof(u16))) { + err = -EINVAL; + goto free; + } + + params.n_counter_offsets_presp = len / sizeof(u16); + if (rdev->wiphy.max_num_csa_counters && + (params.n_counter_offsets_presp > + rdev->wiphy.max_num_csa_counters)) { + err = -EINVAL; + goto free; + } + + params.counter_offsets_presp = + nla_data(csa_attrs[NL80211_ATTR_CNTDWN_OFFS_PRESP]); + + /* sanity checks - counters should fit and be the same */ + for (i = 0; i < params.n_counter_offsets_presp; i++) { + u16 offset = params.counter_offsets_presp[i]; + + if (offset >= params.beacon_csa.probe_resp_len) { + err = -EINVAL; + goto free; + } + + if (params.beacon_csa.probe_resp[offset] != + params.count) { + err = -EINVAL; + goto free; + } + } + } + +skip_beacons: + err = nl80211_parse_chandef(rdev, info, ¶ms.chandef); + if (err) + goto free; + + if (!cfg80211_reg_can_beacon_relax(&rdev->wiphy, ¶ms.chandef, + wdev->iftype)) { + err = -EINVAL; + goto free; + } + + err = cfg80211_chandef_dfs_required(wdev->wiphy, + ¶ms.chandef, + wdev->iftype); + if (err < 0) + goto free; + + if (err > 0) { + params.radar_required = true; + if (need_handle_dfs_flag && + !nla_get_flag(info->attrs[NL80211_ATTR_HANDLE_DFS])) { + err = -EINVAL; + goto free; + } + } + + if (info->attrs[NL80211_ATTR_CH_SWITCH_BLOCK_TX]) + params.block_tx = true; + + wdev_lock(wdev); + err = rdev_channel_switch(rdev, dev, ¶ms); + wdev_unlock(wdev); + +free: + kfree(params.beacon_after.mbssid_ies); + kfree(params.beacon_csa.mbssid_ies); + kfree(csa_attrs); + return err; +} + +static int nl80211_send_bss(struct sk_buff *msg, struct netlink_callback *cb, + u32 seq, int flags, + struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_internal_bss *intbss) +{ + struct cfg80211_bss *res = &intbss->pub; + const struct cfg80211_bss_ies *ies; + unsigned int link_id; + void *hdr; + struct nlattr *bss; + + 
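	/* Sketch of the reply layout built below for each scan entry
	 * (illustrative, not exhaustive): one NL80211_CMD_NEW_SCAN_RESULTS
	 * message carrying NL80211_ATTR_GENERATION, NL80211_ATTR_IFINDEX and
	 * NL80211_ATTR_WDEV at the top level, plus a nested NL80211_ATTR_BSS
	 * with NL80211_BSS_BSSID, NL80211_BSS_TSF,
	 * NL80211_BSS_INFORMATION_ELEMENTS, signal and status attributes.
	 */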
ASSERT_WDEV_LOCK(wdev); + + hdr = nl80211hdr_put(msg, NETLINK_CB(cb->skb).portid, seq, flags, + NL80211_CMD_NEW_SCAN_RESULTS); + if (!hdr) + return -1; + + genl_dump_check_consistent(cb, hdr); + + if (nla_put_u32(msg, NL80211_ATTR_GENERATION, rdev->bss_generation)) + goto nla_put_failure; + if (wdev->netdev && + nla_put_u32(msg, NL80211_ATTR_IFINDEX, wdev->netdev->ifindex)) + goto nla_put_failure; + if (nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + + bss = nla_nest_start_noflag(msg, NL80211_ATTR_BSS); + if (!bss) + goto nla_put_failure; + if ((!is_zero_ether_addr(res->bssid) && + nla_put(msg, NL80211_BSS_BSSID, ETH_ALEN, res->bssid))) + goto nla_put_failure; + + rcu_read_lock(); + /* indicate whether we have probe response data or not */ + if (rcu_access_pointer(res->proberesp_ies) && + nla_put_flag(msg, NL80211_BSS_PRESP_DATA)) + goto fail_unlock_rcu; + + /* this pointer prefers to be pointed to probe response data + * but is always valid + */ + ies = rcu_dereference(res->ies); + if (ies) { + if (nla_put_u64_64bit(msg, NL80211_BSS_TSF, ies->tsf, + NL80211_BSS_PAD)) + goto fail_unlock_rcu; + if (ies->len && nla_put(msg, NL80211_BSS_INFORMATION_ELEMENTS, + ies->len, ies->data)) + goto fail_unlock_rcu; + } + + /* and this pointer is always (unless driver didn't know) beacon data */ + ies = rcu_dereference(res->beacon_ies); + if (ies && ies->from_beacon) { + if (nla_put_u64_64bit(msg, NL80211_BSS_BEACON_TSF, ies->tsf, + NL80211_BSS_PAD)) + goto fail_unlock_rcu; + if (ies->len && nla_put(msg, NL80211_BSS_BEACON_IES, + ies->len, ies->data)) + goto fail_unlock_rcu; + } + rcu_read_unlock(); + + if (res->beacon_interval && + nla_put_u16(msg, NL80211_BSS_BEACON_INTERVAL, res->beacon_interval)) + goto nla_put_failure; + if (nla_put_u16(msg, NL80211_BSS_CAPABILITY, res->capability) || + nla_put_u32(msg, NL80211_BSS_FREQUENCY, res->channel->center_freq) || + nla_put_u32(msg, NL80211_BSS_FREQUENCY_OFFSET, + res->channel->freq_offset) || + nla_put_u32(msg, NL80211_BSS_CHAN_WIDTH, res->scan_width) || + nla_put_u32(msg, NL80211_BSS_SEEN_MS_AGO, + jiffies_to_msecs(jiffies - intbss->ts))) + goto nla_put_failure; + + if (intbss->parent_tsf && + (nla_put_u64_64bit(msg, NL80211_BSS_PARENT_TSF, + intbss->parent_tsf, NL80211_BSS_PAD) || + nla_put(msg, NL80211_BSS_PARENT_BSSID, ETH_ALEN, + intbss->parent_bssid))) + goto nla_put_failure; + + if (intbss->ts_boottime && + nla_put_u64_64bit(msg, NL80211_BSS_LAST_SEEN_BOOTTIME, + intbss->ts_boottime, NL80211_BSS_PAD)) + goto nla_put_failure; + + if (!nl80211_put_signal(msg, intbss->pub.chains, + intbss->pub.chain_signal, + NL80211_BSS_CHAIN_SIGNAL)) + goto nla_put_failure; + + switch (rdev->wiphy.signal_type) { + case CFG80211_SIGNAL_TYPE_MBM: + if (nla_put_u32(msg, NL80211_BSS_SIGNAL_MBM, res->signal)) + goto nla_put_failure; + break; + case CFG80211_SIGNAL_TYPE_UNSPEC: + if (nla_put_u8(msg, NL80211_BSS_SIGNAL_UNSPEC, res->signal)) + goto nla_put_failure; + break; + default: + break; + } + + switch (wdev->iftype) { + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_STATION: + for_each_valid_link(wdev, link_id) { + if (intbss == wdev->links[link_id].client.current_bss && + (nla_put_u32(msg, NL80211_BSS_STATUS, + NL80211_BSS_STATUS_ASSOCIATED) || + (wdev->valid_links && + (nla_put_u8(msg, NL80211_BSS_MLO_LINK_ID, + link_id) || + nla_put(msg, NL80211_BSS_MLD_ADDR, ETH_ALEN, + wdev->u.client.connected_addr))))) + goto nla_put_failure; + } + break; + case NL80211_IFTYPE_ADHOC: + if (intbss == 
wdev->u.ibss.current_bss && + nla_put_u32(msg, NL80211_BSS_STATUS, + NL80211_BSS_STATUS_IBSS_JOINED)) + goto nla_put_failure; + break; + default: + break; + } + + nla_nest_end(msg, bss); + + genlmsg_end(msg, hdr); + return 0; + + fail_unlock_rcu: + rcu_read_unlock(); + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_scan(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct cfg80211_registered_device *rdev; + struct cfg80211_internal_bss *scan; + struct wireless_dev *wdev; + int start = cb->args[2], idx = 0; + int err; + + err = nl80211_prepare_wdev_dump(cb, &rdev, &wdev, NULL); + if (err) + return err; + /* nl80211_prepare_wdev_dump acquired it in the successful case */ + __acquire(&rdev->wiphy.mtx); + + wdev_lock(wdev); + spin_lock_bh(&rdev->bss_lock); + + /* + * dump_scan will be called multiple times to break up the scan results + * into multiple messages. It is unlikely that any more bss-es will be + * expired after the first call, so only call only call this on the + * first dump_scan invocation. + */ + if (start == 0) + cfg80211_bss_expire(rdev); + + cb->seq = rdev->bss_generation; + + list_for_each_entry(scan, &rdev->bss_list, list) { + if (++idx <= start) + continue; + if (nl80211_send_bss(skb, cb, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + rdev, wdev, scan) < 0) { + idx--; + break; + } + } + + spin_unlock_bh(&rdev->bss_lock); + wdev_unlock(wdev); + + cb->args[2] = idx; + wiphy_unlock(&rdev->wiphy); + + return skb->len; +} + +static int nl80211_send_survey(struct sk_buff *msg, u32 portid, u32 seq, + int flags, struct net_device *dev, + bool allow_radio_stats, + struct survey_info *survey) +{ + void *hdr; + struct nlattr *infoattr; + + /* skip radio stats if userspace didn't request them */ + if (!survey->channel && !allow_radio_stats) + return 0; + + hdr = nl80211hdr_put(msg, portid, seq, flags, + NL80211_CMD_NEW_SURVEY_RESULTS); + if (!hdr) + return -ENOMEM; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + infoattr = nla_nest_start_noflag(msg, NL80211_ATTR_SURVEY_INFO); + if (!infoattr) + goto nla_put_failure; + + if (survey->channel && + nla_put_u32(msg, NL80211_SURVEY_INFO_FREQUENCY, + survey->channel->center_freq)) + goto nla_put_failure; + + if (survey->channel && survey->channel->freq_offset && + nla_put_u32(msg, NL80211_SURVEY_INFO_FREQUENCY_OFFSET, + survey->channel->freq_offset)) + goto nla_put_failure; + + if ((survey->filled & SURVEY_INFO_NOISE_DBM) && + nla_put_u8(msg, NL80211_SURVEY_INFO_NOISE, survey->noise)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_IN_USE) && + nla_put_flag(msg, NL80211_SURVEY_INFO_IN_USE)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME, + survey->time, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_BUSY) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_BUSY, + survey->time_busy, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_EXT_BUSY) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_EXT_BUSY, + survey->time_ext_busy, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_RX) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_RX, + survey->time_rx, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_TX) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_TX, + survey->time_tx, 
NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_SCAN) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_SCAN, + survey->time_scan, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + if ((survey->filled & SURVEY_INFO_TIME_BSS_RX) && + nla_put_u64_64bit(msg, NL80211_SURVEY_INFO_TIME_BSS_RX, + survey->time_bss_rx, NL80211_SURVEY_INFO_PAD)) + goto nla_put_failure; + + nla_nest_end(msg, infoattr); + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int nl80211_dump_survey(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct nlattr **attrbuf; + struct survey_info survey; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + int survey_idx = cb->args[2]; + int res; + bool radio_stats; + + attrbuf = kcalloc(NUM_NL80211_ATTR, sizeof(*attrbuf), GFP_KERNEL); + if (!attrbuf) + return -ENOMEM; + + res = nl80211_prepare_wdev_dump(cb, &rdev, &wdev, attrbuf); + if (res) { + kfree(attrbuf); + return res; + } + /* nl80211_prepare_wdev_dump acquired it in the successful case */ + __acquire(&rdev->wiphy.mtx); + + /* prepare_wdev_dump parsed the attributes */ + radio_stats = attrbuf[NL80211_ATTR_SURVEY_RADIO_STATS]; + + if (!wdev->netdev) { + res = -EINVAL; + goto out_err; + } + + if (!rdev->ops->dump_survey) { + res = -EOPNOTSUPP; + goto out_err; + } + + while (1) { + wdev_lock(wdev); + res = rdev_dump_survey(rdev, wdev->netdev, survey_idx, &survey); + wdev_unlock(wdev); + if (res == -ENOENT) + break; + if (res) + goto out_err; + + /* don't send disabled channels, but do send non-channel data */ + if (survey.channel && + survey.channel->flags & IEEE80211_CHAN_DISABLED) { + survey_idx++; + continue; + } + + if (nl80211_send_survey(skb, + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + wdev->netdev, radio_stats, &survey) < 0) + goto out; + survey_idx++; + } + + out: + cb->args[2] = survey_idx; + res = skb->len; + out_err: + kfree(attrbuf); + wiphy_unlock(&rdev->wiphy); + return res; +} + +static bool nl80211_valid_wpa_versions(u32 wpa_versions) +{ + return !(wpa_versions & ~(NL80211_WPA_VERSION_1 | + NL80211_WPA_VERSION_2 | + NL80211_WPA_VERSION_3)); +} + +static int nl80211_authenticate(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct ieee80211_channel *chan; + const u8 *bssid, *ssid; + int err, ssid_len; + enum nl80211_auth_type auth_type; + struct key_parse key; + bool local_state_change; + struct cfg80211_auth_request req = {}; + u32 freq; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_AUTH_TYPE]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_SSID]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_WIPHY_FREQ]) + return -EINVAL; + + err = nl80211_parse_key(info, &key); + if (err) + return err; + + if (key.idx >= 0) { + if (key.type != -1 && key.type != NL80211_KEYTYPE_GROUP) + return -EINVAL; + if (!key.p.key || !key.p.key_len) + return -EINVAL; + if ((key.p.cipher != WLAN_CIPHER_SUITE_WEP40 || + key.p.key_len != WLAN_KEY_LEN_WEP40) && + (key.p.cipher != WLAN_CIPHER_SUITE_WEP104 || + key.p.key_len != WLAN_KEY_LEN_WEP104)) + return -EINVAL; + if (key.idx > 3) + return -EINVAL; + } else { + key.p.key_len = 0; + key.p.key = NULL; + } + + if (key.idx >= 0) { + int i; + bool ok = false; + + for (i = 0; i < rdev->wiphy.n_cipher_suites; i++) { + if (key.p.cipher == 
rdev->wiphy.cipher_suites[i]) { + ok = true; + break; + } + } + if (!ok) + return -EINVAL; + } + + if (!rdev->ops->auth) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + freq = MHZ_TO_KHZ(nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ])); + if (info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]) + freq += + nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]); + + chan = nl80211_get_valid_chan(&rdev->wiphy, freq); + if (!chan) + return -EINVAL; + + ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + + if (info->attrs[NL80211_ATTR_IE]) { + req.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + req.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + auth_type = nla_get_u32(info->attrs[NL80211_ATTR_AUTH_TYPE]); + if (!nl80211_valid_auth_type(rdev, auth_type, NL80211_CMD_AUTHENTICATE)) + return -EINVAL; + + if ((auth_type == NL80211_AUTHTYPE_SAE || + auth_type == NL80211_AUTHTYPE_FILS_SK || + auth_type == NL80211_AUTHTYPE_FILS_SK_PFS || + auth_type == NL80211_AUTHTYPE_FILS_PK) && + !info->attrs[NL80211_ATTR_AUTH_DATA]) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_AUTH_DATA]) { + if (auth_type != NL80211_AUTHTYPE_SAE && + auth_type != NL80211_AUTHTYPE_FILS_SK && + auth_type != NL80211_AUTHTYPE_FILS_SK_PFS && + auth_type != NL80211_AUTHTYPE_FILS_PK) + return -EINVAL; + req.auth_data = nla_data(info->attrs[NL80211_ATTR_AUTH_DATA]); + req.auth_data_len = nla_len(info->attrs[NL80211_ATTR_AUTH_DATA]); + } + + local_state_change = !!info->attrs[NL80211_ATTR_LOCAL_STATE_CHANGE]; + + /* + * Since we no longer track auth state, ignore + * requests to only change local state. 
+ */ + if (local_state_change) + return 0; + + req.auth_type = auth_type; + req.key = key.p.key; + req.key_len = key.p.key_len; + req.key_idx = key.idx; + req.link_id = nl80211_link_id_or_invalid(info->attrs); + if (req.link_id >= 0) { + if (!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_MLO)) + return -EINVAL; + if (!info->attrs[NL80211_ATTR_MLD_ADDR]) + return -EINVAL; + req.ap_mld_addr = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]); + if (!is_valid_ether_addr(req.ap_mld_addr)) + return -EINVAL; + } + + req.bss = cfg80211_get_bss(&rdev->wiphy, chan, bssid, ssid, ssid_len, + IEEE80211_BSS_TYPE_ESS, + IEEE80211_PRIVACY_ANY); + if (!req.bss) + return -ENOENT; + + wdev_lock(dev->ieee80211_ptr); + err = cfg80211_mlme_auth(rdev, dev, &req); + wdev_unlock(dev->ieee80211_ptr); + + cfg80211_put_bss(&rdev->wiphy, req.bss); + + return err; +} + +static int validate_pae_over_nl80211(struct cfg80211_registered_device *rdev, + struct genl_info *info) +{ + if (!info->attrs[NL80211_ATTR_SOCKET_OWNER]) { + GENL_SET_ERR_MSG(info, "SOCKET_OWNER not set"); + return -EINVAL; + } + + if (!rdev->ops->tx_control_port || + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211)) + return -EOPNOTSUPP; + + return 0; +} + +static int nl80211_crypto_settings(struct cfg80211_registered_device *rdev, + struct genl_info *info, + struct cfg80211_crypto_settings *settings, + int cipher_limit) +{ + memset(settings, 0, sizeof(*settings)); + + settings->control_port = info->attrs[NL80211_ATTR_CONTROL_PORT]; + + if (info->attrs[NL80211_ATTR_CONTROL_PORT_ETHERTYPE]) { + u16 proto; + + proto = nla_get_u16( + info->attrs[NL80211_ATTR_CONTROL_PORT_ETHERTYPE]); + settings->control_port_ethertype = cpu_to_be16(proto); + if (!(rdev->wiphy.flags & WIPHY_FLAG_CONTROL_PORT_PROTOCOL) && + proto != ETH_P_PAE) + return -EINVAL; + if (info->attrs[NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT]) + settings->control_port_no_encrypt = true; + } else + settings->control_port_ethertype = cpu_to_be16(ETH_P_PAE); + + if (info->attrs[NL80211_ATTR_CONTROL_PORT_OVER_NL80211]) { + int r = validate_pae_over_nl80211(rdev, info); + + if (r < 0) + return r; + + settings->control_port_over_nl80211 = true; + + if (info->attrs[NL80211_ATTR_CONTROL_PORT_NO_PREAUTH]) + settings->control_port_no_preauth = true; + } + + if (info->attrs[NL80211_ATTR_CIPHER_SUITES_PAIRWISE]) { + void *data; + int len, i; + + data = nla_data(info->attrs[NL80211_ATTR_CIPHER_SUITES_PAIRWISE]); + len = nla_len(info->attrs[NL80211_ATTR_CIPHER_SUITES_PAIRWISE]); + settings->n_ciphers_pairwise = len / sizeof(u32); + + if (len % sizeof(u32)) + return -EINVAL; + + if (settings->n_ciphers_pairwise > cipher_limit) + return -EINVAL; + + memcpy(settings->ciphers_pairwise, data, len); + + for (i = 0; i < settings->n_ciphers_pairwise; i++) + if (!cfg80211_supported_cipher_suite( + &rdev->wiphy, + settings->ciphers_pairwise[i])) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_CIPHER_SUITE_GROUP]) { + settings->cipher_group = + nla_get_u32(info->attrs[NL80211_ATTR_CIPHER_SUITE_GROUP]); + if (!cfg80211_supported_cipher_suite(&rdev->wiphy, + settings->cipher_group)) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_WPA_VERSIONS]) { + settings->wpa_versions = + nla_get_u32(info->attrs[NL80211_ATTR_WPA_VERSIONS]); + if (!nl80211_valid_wpa_versions(settings->wpa_versions)) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_AKM_SUITES]) { + void *data; + int len; + + data = nla_data(info->attrs[NL80211_ATTR_AKM_SUITES]); + len = 
nla_len(info->attrs[NL80211_ATTR_AKM_SUITES]); + settings->n_akm_suites = len / sizeof(u32); + + if (len % sizeof(u32)) + return -EINVAL; + + if (settings->n_akm_suites > rdev->wiphy.max_num_akm_suites) + return -EINVAL; + + memcpy(settings->akm_suites, data, len); + } + + if (info->attrs[NL80211_ATTR_PMK]) { + if (nla_len(info->attrs[NL80211_ATTR_PMK]) != WLAN_PMK_LEN) + return -EINVAL; + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_PSK) && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_4WAY_HANDSHAKE_AP_PSK)) + return -EINVAL; + settings->psk = nla_data(info->attrs[NL80211_ATTR_PMK]); + } + + if (info->attrs[NL80211_ATTR_SAE_PASSWORD]) { + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_SAE_OFFLOAD) && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_SAE_OFFLOAD_AP)) + return -EINVAL; + settings->sae_pwd = + nla_data(info->attrs[NL80211_ATTR_SAE_PASSWORD]); + settings->sae_pwd_len = + nla_len(info->attrs[NL80211_ATTR_SAE_PASSWORD]); + } + + if (info->attrs[NL80211_ATTR_SAE_PWE]) + settings->sae_pwe = + nla_get_u8(info->attrs[NL80211_ATTR_SAE_PWE]); + else + settings->sae_pwe = NL80211_SAE_PWE_UNSPECIFIED; + + return 0; +} + +static struct cfg80211_bss *nl80211_assoc_bss(struct cfg80211_registered_device *rdev, + const u8 *ssid, int ssid_len, + struct nlattr **attrs) +{ + struct ieee80211_channel *chan; + struct cfg80211_bss *bss; + const u8 *bssid; + u32 freq; + + if (!attrs[NL80211_ATTR_MAC] || !attrs[NL80211_ATTR_WIPHY_FREQ]) + return ERR_PTR(-EINVAL); + + bssid = nla_data(attrs[NL80211_ATTR_MAC]); + + freq = MHZ_TO_KHZ(nla_get_u32(attrs[NL80211_ATTR_WIPHY_FREQ])); + if (attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]) + freq += nla_get_u32(attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]); + + chan = nl80211_get_valid_chan(&rdev->wiphy, freq); + if (!chan) + return ERR_PTR(-EINVAL); + + bss = cfg80211_get_bss(&rdev->wiphy, chan, bssid, + ssid, ssid_len, + IEEE80211_BSS_TYPE_ESS, + IEEE80211_PRIVACY_ANY); + if (!bss) + return ERR_PTR(-ENOENT); + + return bss; +} + +static int nl80211_associate(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct cfg80211_assoc_request req = {}; + struct nlattr **attrs = NULL; + const u8 *ap_addr, *ssid; + unsigned int link_id; + int err, ssid_len; + + if (dev->ieee80211_ptr->conn_owner_nlportid && + dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) + return -EPERM; + + if (!info->attrs[NL80211_ATTR_SSID]) + return -EINVAL; + + if (!rdev->ops->assoc) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + + if (info->attrs[NL80211_ATTR_IE]) { + req.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + req.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + + if (cfg80211_find_ext_elem(WLAN_EID_EXT_NON_INHERITANCE, + req.ie, req.ie_len)) { + GENL_SET_ERR_MSG(info, + "non-inheritance makes no sense"); + return -EINVAL; + } + } + + if (info->attrs[NL80211_ATTR_USE_MFP]) { + enum nl80211_mfp mfp = + nla_get_u32(info->attrs[NL80211_ATTR_USE_MFP]); + if (mfp == NL80211_MFP_REQUIRED) + req.use_mfp = true; + else if (mfp != NL80211_MFP_NO) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_PREV_BSSID]) + req.prev_bssid = 
nla_data(info->attrs[NL80211_ATTR_PREV_BSSID]); + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_HT])) + req.flags |= ASSOC_REQ_DISABLE_HT; + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) + memcpy(&req.ht_capa_mask, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]), + sizeof(req.ht_capa_mask)); + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) + return -EINVAL; + memcpy(&req.ht_capa, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]), + sizeof(req.ht_capa)); + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_VHT])) + req.flags |= ASSOC_REQ_DISABLE_VHT; + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_HE])) + req.flags |= ASSOC_REQ_DISABLE_HE; + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_EHT])) + req.flags |= ASSOC_REQ_DISABLE_EHT; + + if (info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]) + memcpy(&req.vht_capa_mask, + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]), + sizeof(req.vht_capa_mask)); + + if (info->attrs[NL80211_ATTR_VHT_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]) + return -EINVAL; + memcpy(&req.vht_capa, + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY]), + sizeof(req.vht_capa)); + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_USE_RRM])) { + if (!((rdev->wiphy.features & + NL80211_FEATURE_DS_PARAM_SET_IE_IN_PROBES) && + (rdev->wiphy.features & NL80211_FEATURE_QUIET)) && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_RRM)) + return -EINVAL; + req.flags |= ASSOC_REQ_USE_RRM; + } + + if (info->attrs[NL80211_ATTR_FILS_KEK]) { + req.fils_kek = nla_data(info->attrs[NL80211_ATTR_FILS_KEK]); + req.fils_kek_len = nla_len(info->attrs[NL80211_ATTR_FILS_KEK]); + if (!info->attrs[NL80211_ATTR_FILS_NONCES]) + return -EINVAL; + req.fils_nonces = + nla_data(info->attrs[NL80211_ATTR_FILS_NONCES]); + } + + if (info->attrs[NL80211_ATTR_S1G_CAPABILITY_MASK]) { + if (!info->attrs[NL80211_ATTR_S1G_CAPABILITY]) + return -EINVAL; + memcpy(&req.s1g_capa_mask, + nla_data(info->attrs[NL80211_ATTR_S1G_CAPABILITY_MASK]), + sizeof(req.s1g_capa_mask)); + } + + if (info->attrs[NL80211_ATTR_S1G_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_S1G_CAPABILITY_MASK]) + return -EINVAL; + memcpy(&req.s1g_capa, + nla_data(info->attrs[NL80211_ATTR_S1G_CAPABILITY]), + sizeof(req.s1g_capa)); + } + + req.link_id = nl80211_link_id_or_invalid(info->attrs); + + if (info->attrs[NL80211_ATTR_MLO_LINKS]) { + unsigned int attrsize = NUM_NL80211_ATTR * sizeof(*attrs); + struct nlattr *link; + int rem = 0; + + if (req.link_id < 0) + return -EINVAL; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_MLO)) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_MAC] || + info->attrs[NL80211_ATTR_WIPHY_FREQ] || + !info->attrs[NL80211_ATTR_MLD_ADDR]) + return -EINVAL; + + req.ap_mld_addr = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]); + ap_addr = req.ap_mld_addr; + + attrs = kzalloc(attrsize, GFP_KERNEL); + if (!attrs) + return -ENOMEM; + + nla_for_each_nested(link, + info->attrs[NL80211_ATTR_MLO_LINKS], + rem) { + memset(attrs, 0, attrsize); + + nla_parse_nested(attrs, NL80211_ATTR_MAX, + link, NULL, NULL); + + if (!attrs[NL80211_ATTR_MLO_LINK_ID]) { + err = -EINVAL; + goto free; + } + + link_id = nla_get_u8(attrs[NL80211_ATTR_MLO_LINK_ID]); + /* cannot use the same link ID again */ + if (req.links[link_id].bss) { + err = -EINVAL; + goto free; + } + req.links[link_id].bss = + nl80211_assoc_bss(rdev, ssid, ssid_len, attrs); + if (IS_ERR(req.links[link_id].bss)) { + err = 
PTR_ERR(req.links[link_id].bss); + req.links[link_id].bss = NULL; + goto free; + } + + if (attrs[NL80211_ATTR_IE]) { + req.links[link_id].elems = + nla_data(attrs[NL80211_ATTR_IE]); + req.links[link_id].elems_len = + nla_len(attrs[NL80211_ATTR_IE]); + + if (cfg80211_find_elem(WLAN_EID_FRAGMENT, + req.links[link_id].elems, + req.links[link_id].elems_len)) { + GENL_SET_ERR_MSG(info, + "cannot deal with fragmentation"); + err = -EINVAL; + goto free; + } + + if (cfg80211_find_ext_elem(WLAN_EID_EXT_NON_INHERITANCE, + req.links[link_id].elems, + req.links[link_id].elems_len)) { + GENL_SET_ERR_MSG(info, + "cannot deal with non-inheritance"); + err = -EINVAL; + goto free; + } + } + } + + if (!req.links[req.link_id].bss) { + err = -EINVAL; + goto free; + } + + if (req.links[req.link_id].elems_len) { + GENL_SET_ERR_MSG(info, + "cannot have per-link elems on assoc link"); + err = -EINVAL; + goto free; + } + + kfree(attrs); + attrs = NULL; + } else { + if (req.link_id >= 0) + return -EINVAL; + + req.bss = nl80211_assoc_bss(rdev, ssid, ssid_len, info->attrs); + if (IS_ERR(req.bss)) + return PTR_ERR(req.bss); + ap_addr = req.bss->bssid; + } + + err = nl80211_crypto_settings(rdev, info, &req.crypto, 1); + if (!err) { + wdev_lock(dev->ieee80211_ptr); + + err = cfg80211_mlme_assoc(rdev, dev, &req); + + if (!err && info->attrs[NL80211_ATTR_SOCKET_OWNER]) { + dev->ieee80211_ptr->conn_owner_nlportid = + info->snd_portid; + memcpy(dev->ieee80211_ptr->disconnect_bssid, + ap_addr, ETH_ALEN); + } + + wdev_unlock(dev->ieee80211_ptr); + } + +free: + for (link_id = 0; link_id < ARRAY_SIZE(req.links); link_id++) + cfg80211_put_bss(&rdev->wiphy, req.links[link_id].bss); + cfg80211_put_bss(&rdev->wiphy, req.bss); + kfree(attrs); + + return err; +} + +static int nl80211_deauthenticate(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + const u8 *ie = NULL, *bssid; + int ie_len = 0, err; + u16 reason_code; + bool local_state_change; + + if (dev->ieee80211_ptr->conn_owner_nlportid && + dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) + return -EPERM; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_REASON_CODE]) + return -EINVAL; + + if (!rdev->ops->deauth) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + + reason_code = nla_get_u16(info->attrs[NL80211_ATTR_REASON_CODE]); + if (reason_code == 0) { + /* Reason Code 0 is reserved */ + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_IE]) { + ie = nla_data(info->attrs[NL80211_ATTR_IE]); + ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + local_state_change = !!info->attrs[NL80211_ATTR_LOCAL_STATE_CHANGE]; + + wdev_lock(dev->ieee80211_ptr); + err = cfg80211_mlme_deauth(rdev, dev, bssid, ie, ie_len, reason_code, + local_state_change); + wdev_unlock(dev->ieee80211_ptr); + return err; +} + +static int nl80211_disassociate(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + const u8 *ie = NULL, *bssid; + int ie_len = 0, err; + u16 reason_code; + bool local_state_change; + + if (dev->ieee80211_ptr->conn_owner_nlportid && + dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) + return -EPERM; + + if 
(!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!info->attrs[NL80211_ATTR_REASON_CODE]) + return -EINVAL; + + if (!rdev->ops->disassoc) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + + reason_code = nla_get_u16(info->attrs[NL80211_ATTR_REASON_CODE]); + if (reason_code == 0) { + /* Reason Code 0 is reserved */ + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_IE]) { + ie = nla_data(info->attrs[NL80211_ATTR_IE]); + ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + local_state_change = !!info->attrs[NL80211_ATTR_LOCAL_STATE_CHANGE]; + + wdev_lock(dev->ieee80211_ptr); + err = cfg80211_mlme_disassoc(rdev, dev, bssid, ie, ie_len, reason_code, + local_state_change); + wdev_unlock(dev->ieee80211_ptr); + return err; +} + +static bool +nl80211_parse_mcast_rate(struct cfg80211_registered_device *rdev, + int mcast_rate[NUM_NL80211_BANDS], + int rateval) +{ + struct wiphy *wiphy = &rdev->wiphy; + bool found = false; + int band, i; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband; + + sband = wiphy->bands[band]; + if (!sband) + continue; + + for (i = 0; i < sband->n_bitrates; i++) { + if (sband->bitrates[i].bitrate == rateval) { + mcast_rate[band] = i + 1; + found = true; + break; + } + } + } + + return found; +} + +static int nl80211_join_ibss(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct cfg80211_ibss_params ibss; + struct wiphy *wiphy; + struct cfg80211_cached_keys *connkeys = NULL; + int err; + + memset(&ibss, 0, sizeof(ibss)); + + if (!info->attrs[NL80211_ATTR_SSID] || + !nla_len(info->attrs[NL80211_ATTR_SSID])) + return -EINVAL; + + ibss.beacon_interval = 100; + + if (info->attrs[NL80211_ATTR_BEACON_INTERVAL]) + ibss.beacon_interval = + nla_get_u32(info->attrs[NL80211_ATTR_BEACON_INTERVAL]); + + err = cfg80211_validate_beacon_int(rdev, NL80211_IFTYPE_ADHOC, + ibss.beacon_interval); + if (err) + return err; + + if (!rdev->ops->join_ibss) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_ADHOC) + return -EOPNOTSUPP; + + wiphy = &rdev->wiphy; + + if (info->attrs[NL80211_ATTR_MAC]) { + ibss.bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (!is_valid_ether_addr(ibss.bssid)) + return -EINVAL; + } + ibss.ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + ibss.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + + if (info->attrs[NL80211_ATTR_IE]) { + ibss.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + ibss.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + err = nl80211_parse_chandef(rdev, info, &ibss.chandef); + if (err) + return err; + + if (!cfg80211_reg_can_beacon(&rdev->wiphy, &ibss.chandef, + NL80211_IFTYPE_ADHOC)) + return -EINVAL; + + switch (ibss.chandef.width) { + case NL80211_CHAN_WIDTH_5: + case NL80211_CHAN_WIDTH_10: + case NL80211_CHAN_WIDTH_20_NOHT: + break; + case NL80211_CHAN_WIDTH_20: + case NL80211_CHAN_WIDTH_40: + if (!(rdev->wiphy.features & NL80211_FEATURE_HT_IBSS)) + return -EINVAL; + break; + case NL80211_CHAN_WIDTH_80: + case NL80211_CHAN_WIDTH_80P80: + case NL80211_CHAN_WIDTH_160: + if (!(rdev->wiphy.features & NL80211_FEATURE_HT_IBSS)) + return -EINVAL; + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_VHT_IBSS)) + return -EINVAL; + break; + case 
NL80211_CHAN_WIDTH_320: + return -EINVAL; + default: + return -EINVAL; + } + + ibss.channel_fixed = !!info->attrs[NL80211_ATTR_FREQ_FIXED]; + ibss.privacy = !!info->attrs[NL80211_ATTR_PRIVACY]; + + if (info->attrs[NL80211_ATTR_BSS_BASIC_RATES]) { + u8 *rates = + nla_data(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + int n_rates = + nla_len(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + struct ieee80211_supported_band *sband = + wiphy->bands[ibss.chandef.chan->band]; + + err = ieee80211_get_ratemask(sband, rates, n_rates, + &ibss.basic_rates); + if (err) + return err; + } + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) + memcpy(&ibss.ht_capa_mask, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]), + sizeof(ibss.ht_capa_mask)); + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) + return -EINVAL; + memcpy(&ibss.ht_capa, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]), + sizeof(ibss.ht_capa)); + } + + if (info->attrs[NL80211_ATTR_MCAST_RATE] && + !nl80211_parse_mcast_rate(rdev, ibss.mcast_rate, + nla_get_u32(info->attrs[NL80211_ATTR_MCAST_RATE]))) + return -EINVAL; + + if (ibss.privacy && info->attrs[NL80211_ATTR_KEYS]) { + bool no_ht = false; + + connkeys = nl80211_parse_connkeys(rdev, info, &no_ht); + if (IS_ERR(connkeys)) + return PTR_ERR(connkeys); + + if ((ibss.chandef.width != NL80211_CHAN_WIDTH_20_NOHT) && + no_ht) { + kfree_sensitive(connkeys); + return -EINVAL; + } + } + + ibss.control_port = + nla_get_flag(info->attrs[NL80211_ATTR_CONTROL_PORT]); + + if (info->attrs[NL80211_ATTR_CONTROL_PORT_OVER_NL80211]) { + int r = validate_pae_over_nl80211(rdev, info); + + if (r < 0) { + kfree_sensitive(connkeys); + return r; + } + + ibss.control_port_over_nl80211 = true; + } + + ibss.userspace_handles_dfs = + nla_get_flag(info->attrs[NL80211_ATTR_HANDLE_DFS]); + + wdev_lock(dev->ieee80211_ptr); + err = __cfg80211_join_ibss(rdev, dev, &ibss, connkeys); + if (err) + kfree_sensitive(connkeys); + else if (info->attrs[NL80211_ATTR_SOCKET_OWNER]) + dev->ieee80211_ptr->conn_owner_nlportid = info->snd_portid; + wdev_unlock(dev->ieee80211_ptr); + + return err; +} + +static int nl80211_leave_ibss(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + + if (!rdev->ops->leave_ibss) + return -EOPNOTSUPP; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_ADHOC) + return -EOPNOTSUPP; + + return cfg80211_leave_ibss(rdev, dev, false); +} + +static int nl80211_set_mcast_rate(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + int mcast_rate[NUM_NL80211_BANDS]; + u32 nla_rate; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_ADHOC && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_MESH_POINT && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_OCB) + return -EOPNOTSUPP; + + if (!rdev->ops->set_mcast_rate) + return -EOPNOTSUPP; + + memset(mcast_rate, 0, sizeof(mcast_rate)); + + if (!info->attrs[NL80211_ATTR_MCAST_RATE]) + return -EINVAL; + + nla_rate = nla_get_u32(info->attrs[NL80211_ATTR_MCAST_RATE]); + if (!nl80211_parse_mcast_rate(rdev, mcast_rate, nla_rate)) + return -EINVAL; + + return rdev_set_mcast_rate(rdev, dev, mcast_rate); +} + +static struct sk_buff * +__cfg80211_alloc_vendor_skb(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, int approxlen, + u32 portid, u32 seq, enum nl80211_commands cmd, + enum 
nl80211_attrs attr, + const struct nl80211_vendor_cmd_info *info, + gfp_t gfp) +{ + struct sk_buff *skb; + void *hdr; + struct nlattr *data; + + skb = nlmsg_new(approxlen + 100, gfp); + if (!skb) + return NULL; + + hdr = nl80211hdr_put(skb, portid, seq, 0, cmd); + if (!hdr) { + kfree_skb(skb); + return NULL; + } + + if (nla_put_u32(skb, NL80211_ATTR_WIPHY, rdev->wiphy_idx)) + goto nla_put_failure; + + if (info) { + if (nla_put_u32(skb, NL80211_ATTR_VENDOR_ID, + info->vendor_id)) + goto nla_put_failure; + if (nla_put_u32(skb, NL80211_ATTR_VENDOR_SUBCMD, + info->subcmd)) + goto nla_put_failure; + } + + if (wdev) { + if (nla_put_u64_64bit(skb, NL80211_ATTR_WDEV, + wdev_id(wdev), NL80211_ATTR_PAD)) + goto nla_put_failure; + if (wdev->netdev && + nla_put_u32(skb, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) + goto nla_put_failure; + } + + data = nla_nest_start_noflag(skb, attr); + if (!data) + goto nla_put_failure; + + ((void **)skb->cb)[0] = rdev; + ((void **)skb->cb)[1] = hdr; + ((void **)skb->cb)[2] = data; + + return skb; + + nla_put_failure: + kfree_skb(skb); + return NULL; +} + +struct sk_buff *__cfg80211_alloc_event_skb(struct wiphy *wiphy, + struct wireless_dev *wdev, + enum nl80211_commands cmd, + enum nl80211_attrs attr, + unsigned int portid, + int vendor_event_idx, + int approxlen, gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + const struct nl80211_vendor_cmd_info *info; + + switch (cmd) { + case NL80211_CMD_TESTMODE: + if (WARN_ON(vendor_event_idx != -1)) + return NULL; + info = NULL; + break; + case NL80211_CMD_VENDOR: + if (WARN_ON(vendor_event_idx < 0 || + vendor_event_idx >= wiphy->n_vendor_events)) + return NULL; + info = &wiphy->vendor_events[vendor_event_idx]; + break; + default: + WARN_ON(1); + return NULL; + } + + return __cfg80211_alloc_vendor_skb(rdev, wdev, approxlen, portid, 0, + cmd, attr, info, gfp); +} +EXPORT_SYMBOL(__cfg80211_alloc_event_skb); + +void __cfg80211_send_event_skb(struct sk_buff *skb, gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = ((void **)skb->cb)[0]; + void *hdr = ((void **)skb->cb)[1]; + struct nlmsghdr *nlhdr = nlmsg_hdr(skb); + struct nlattr *data = ((void **)skb->cb)[2]; + enum nl80211_multicast_groups mcgrp = NL80211_MCGRP_TESTMODE; + + /* clear CB data for netlink core to own from now on */ + memset(skb->cb, 0, sizeof(skb->cb)); + + nla_nest_end(skb, data); + genlmsg_end(skb, hdr); + + if (nlhdr->nlmsg_pid) { + genlmsg_unicast(wiphy_net(&rdev->wiphy), skb, + nlhdr->nlmsg_pid); + } else { + if (data->nla_type == NL80211_ATTR_VENDOR_DATA) + mcgrp = NL80211_MCGRP_VENDOR; + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), + skb, 0, mcgrp, gfp); + } +} +EXPORT_SYMBOL(__cfg80211_send_event_skb); + +#ifdef CONFIG_NL80211_TESTMODE +static int nl80211_testmode_do(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev; + int err; + + lockdep_assert_held(&rdev->wiphy.mtx); + + wdev = __cfg80211_wdev_from_attrs(rdev, genl_info_net(info), + info->attrs); + + if (!rdev->ops->testmode_cmd) + return -EOPNOTSUPP; + + if (IS_ERR(wdev)) { + err = PTR_ERR(wdev); + if (err != -EINVAL) + return err; + wdev = NULL; + } else if (wdev->wiphy != &rdev->wiphy) { + return -EINVAL; + } + + if (!info->attrs[NL80211_ATTR_TESTDATA]) + return -EINVAL; + + rdev->cur_cmd_info = info; + err = rdev_testmode_cmd(rdev, wdev, + nla_data(info->attrs[NL80211_ATTR_TESTDATA]), + nla_len(info->attrs[NL80211_ATTR_TESTDATA])); + 
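/* Editorial note, added as a descriptive comment: rdev->cur_cmd_info lets + * cfg80211_testmode_reply() locate this request while the driver's + * testmode_cmd callback runs; it is cleared again before returning. + */ +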
rdev->cur_cmd_info = NULL; + + return err; +} + +static int nl80211_testmode_dump(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct cfg80211_registered_device *rdev; + struct nlattr **attrbuf = NULL; + int err; + long phy_idx; + void *data = NULL; + int data_len = 0; + + rtnl_lock(); + + if (cb->args[0]) { + /* + * 0 is a valid index, but not valid for args[0], + * so we need to offset by 1. + */ + phy_idx = cb->args[0] - 1; + + rdev = cfg80211_rdev_by_wiphy_idx(phy_idx); + if (!rdev) { + err = -ENOENT; + goto out_err; + } + } else { + attrbuf = kcalloc(NUM_NL80211_ATTR, sizeof(*attrbuf), + GFP_KERNEL); + if (!attrbuf) { + err = -ENOMEM; + goto out_err; + } + + err = nlmsg_parse_deprecated(cb->nlh, + GENL_HDRLEN + nl80211_fam.hdrsize, + attrbuf, nl80211_fam.maxattr, + nl80211_policy, NULL); + if (err) + goto out_err; + + rdev = __cfg80211_rdev_from_attrs(sock_net(skb->sk), attrbuf); + if (IS_ERR(rdev)) { + err = PTR_ERR(rdev); + goto out_err; + } + phy_idx = rdev->wiphy_idx; + + if (attrbuf[NL80211_ATTR_TESTDATA]) + cb->args[1] = (long)attrbuf[NL80211_ATTR_TESTDATA]; + } + + if (cb->args[1]) { + data = nla_data((void *)cb->args[1]); + data_len = nla_len((void *)cb->args[1]); + } + + if (!rdev->ops->testmode_dump) { + err = -EOPNOTSUPP; + goto out_err; + } + + while (1) { + void *hdr = nl80211hdr_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + NL80211_CMD_TESTMODE); + struct nlattr *tmdata; + + if (!hdr) + break; + + if (nla_put_u32(skb, NL80211_ATTR_WIPHY, phy_idx)) { + genlmsg_cancel(skb, hdr); + break; + } + + tmdata = nla_nest_start_noflag(skb, NL80211_ATTR_TESTDATA); + if (!tmdata) { + genlmsg_cancel(skb, hdr); + break; + } + err = rdev_testmode_dump(rdev, skb, cb, data, data_len); + nla_nest_end(skb, tmdata); + + if (err == -ENOBUFS || err == -ENOENT) { + genlmsg_cancel(skb, hdr); + break; + } else if (err) { + genlmsg_cancel(skb, hdr); + goto out_err; + } + + genlmsg_end(skb, hdr); + } + + err = skb->len; + /* see above */ + cb->args[0] = phy_idx + 1; + out_err: + kfree(attrbuf); + rtnl_unlock(); + return err; +} +#endif + +static int nl80211_connect(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct cfg80211_connect_params connect; + struct wiphy *wiphy; + struct cfg80211_cached_keys *connkeys = NULL; + u32 freq = 0; + int err; + + memset(&connect, 0, sizeof(connect)); + + if (!info->attrs[NL80211_ATTR_SSID] || + !nla_len(info->attrs[NL80211_ATTR_SSID])) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_AUTH_TYPE]) { + connect.auth_type = + nla_get_u32(info->attrs[NL80211_ATTR_AUTH_TYPE]); + if (!nl80211_valid_auth_type(rdev, connect.auth_type, + NL80211_CMD_CONNECT)) + return -EINVAL; + } else + connect.auth_type = NL80211_AUTHTYPE_AUTOMATIC; + + connect.privacy = info->attrs[NL80211_ATTR_PRIVACY]; + + if (info->attrs[NL80211_ATTR_WANT_1X_4WAY_HS] && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X)) + return -EINVAL; + connect.want_1x = info->attrs[NL80211_ATTR_WANT_1X_4WAY_HS]; + + err = nl80211_crypto_settings(rdev, info, &connect.crypto, + NL80211_MAX_NR_CIPHER_SUITES); + if (err) + return err; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + wiphy = &rdev->wiphy; + + connect.bg_scan_period = -1; + if (info->attrs[NL80211_ATTR_BG_SCAN_PERIOD] && + (wiphy->flags & 
WIPHY_FLAG_SUPPORTS_FW_ROAM)) { + connect.bg_scan_period = + nla_get_u16(info->attrs[NL80211_ATTR_BG_SCAN_PERIOD]); + } + + if (info->attrs[NL80211_ATTR_MAC]) + connect.bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + else if (info->attrs[NL80211_ATTR_MAC_HINT]) + connect.bssid_hint = + nla_data(info->attrs[NL80211_ATTR_MAC_HINT]); + connect.ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + connect.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + + if (info->attrs[NL80211_ATTR_IE]) { + connect.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + connect.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + if (info->attrs[NL80211_ATTR_USE_MFP]) { + connect.mfp = nla_get_u32(info->attrs[NL80211_ATTR_USE_MFP]); + if (connect.mfp == NL80211_MFP_OPTIONAL && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_MFP_OPTIONAL)) + return -EOPNOTSUPP; + } else { + connect.mfp = NL80211_MFP_NO; + } + + if (info->attrs[NL80211_ATTR_PREV_BSSID]) + connect.prev_bssid = + nla_data(info->attrs[NL80211_ATTR_PREV_BSSID]); + + if (info->attrs[NL80211_ATTR_WIPHY_FREQ]) + freq = MHZ_TO_KHZ(nla_get_u32( + info->attrs[NL80211_ATTR_WIPHY_FREQ])); + if (info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]) + freq += + nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ_OFFSET]); + + if (freq) { + connect.channel = nl80211_get_valid_chan(wiphy, freq); + if (!connect.channel) + return -EINVAL; + } else if (info->attrs[NL80211_ATTR_WIPHY_FREQ_HINT]) { + freq = nla_get_u32(info->attrs[NL80211_ATTR_WIPHY_FREQ_HINT]); + freq = MHZ_TO_KHZ(freq); + connect.channel_hint = nl80211_get_valid_chan(wiphy, freq); + if (!connect.channel_hint) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_WIPHY_EDMG_CHANNELS]) { + connect.edmg.channels = + nla_get_u8(info->attrs[NL80211_ATTR_WIPHY_EDMG_CHANNELS]); + + if (info->attrs[NL80211_ATTR_WIPHY_EDMG_BW_CONFIG]) + connect.edmg.bw_config = + nla_get_u8(info->attrs[NL80211_ATTR_WIPHY_EDMG_BW_CONFIG]); + } + + if (connect.privacy && info->attrs[NL80211_ATTR_KEYS]) { + connkeys = nl80211_parse_connkeys(rdev, info, NULL); + if (IS_ERR(connkeys)) + return PTR_ERR(connkeys); + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_HT])) + connect.flags |= ASSOC_REQ_DISABLE_HT; + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) + memcpy(&connect.ht_capa_mask, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]), + sizeof(connect.ht_capa_mask)); + + if (info->attrs[NL80211_ATTR_HT_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_HT_CAPABILITY_MASK]) { + kfree_sensitive(connkeys); + return -EINVAL; + } + memcpy(&connect.ht_capa, + nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]), + sizeof(connect.ht_capa)); + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_VHT])) + connect.flags |= ASSOC_REQ_DISABLE_VHT; + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_HE])) + connect.flags |= ASSOC_REQ_DISABLE_HE; + + if (nla_get_flag(info->attrs[NL80211_ATTR_DISABLE_EHT])) + connect.flags |= ASSOC_REQ_DISABLE_EHT; + + if (info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]) + memcpy(&connect.vht_capa_mask, + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]), + sizeof(connect.vht_capa_mask)); + + if (info->attrs[NL80211_ATTR_VHT_CAPABILITY]) { + if (!info->attrs[NL80211_ATTR_VHT_CAPABILITY_MASK]) { + kfree_sensitive(connkeys); + return -EINVAL; + } + memcpy(&connect.vht_capa, + nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY]), + sizeof(connect.vht_capa)); + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_USE_RRM])) { + if (!((rdev->wiphy.features & + 
NL80211_FEATURE_DS_PARAM_SET_IE_IN_PROBES) && + (rdev->wiphy.features & NL80211_FEATURE_QUIET)) && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_RRM)) { + kfree_sensitive(connkeys); + return -EINVAL; + } + connect.flags |= ASSOC_REQ_USE_RRM; + } + + connect.pbss = nla_get_flag(info->attrs[NL80211_ATTR_PBSS]); + if (connect.pbss && !rdev->wiphy.bands[NL80211_BAND_60GHZ]) { + kfree_sensitive(connkeys); + return -EOPNOTSUPP; + } + + if (info->attrs[NL80211_ATTR_BSS_SELECT]) { + /* bss selection makes no sense if bssid is set */ + if (connect.bssid) { + kfree_sensitive(connkeys); + return -EINVAL; + } + + err = parse_bss_select(info->attrs[NL80211_ATTR_BSS_SELECT], + wiphy, &connect.bss_select); + if (err) { + kfree_sensitive(connkeys); + return err; + } + } + + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_FILS_SK_OFFLOAD) && + info->attrs[NL80211_ATTR_FILS_ERP_USERNAME] && + info->attrs[NL80211_ATTR_FILS_ERP_REALM] && + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM] && + info->attrs[NL80211_ATTR_FILS_ERP_RRK]) { + connect.fils_erp_username = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_USERNAME]); + connect.fils_erp_username_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_USERNAME]); + connect.fils_erp_realm = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_REALM]); + connect.fils_erp_realm_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_REALM]); + connect.fils_erp_next_seq_num = + nla_get_u16( + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM]); + connect.fils_erp_rrk = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_RRK]); + connect.fils_erp_rrk_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_RRK]); + } else if (info->attrs[NL80211_ATTR_FILS_ERP_USERNAME] || + info->attrs[NL80211_ATTR_FILS_ERP_REALM] || + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM] || + info->attrs[NL80211_ATTR_FILS_ERP_RRK]) { + kfree_sensitive(connkeys); + return -EINVAL; + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_EXTERNAL_AUTH_SUPPORT])) { + if (!info->attrs[NL80211_ATTR_SOCKET_OWNER]) { + kfree_sensitive(connkeys); + GENL_SET_ERR_MSG(info, + "external auth requires connection ownership"); + return -EINVAL; + } + connect.flags |= CONNECT_REQ_EXTERNAL_AUTH_SUPPORT; + } + + if (nla_get_flag(info->attrs[NL80211_ATTR_MLO_SUPPORT])) + connect.flags |= CONNECT_REQ_MLO_SUPPORT; + + wdev_lock(dev->ieee80211_ptr); + + err = cfg80211_connect(rdev, dev, &connect, connkeys, + connect.prev_bssid); + if (err) + kfree_sensitive(connkeys); + + if (!err && info->attrs[NL80211_ATTR_SOCKET_OWNER]) { + dev->ieee80211_ptr->conn_owner_nlportid = info->snd_portid; + if (connect.bssid) + memcpy(dev->ieee80211_ptr->disconnect_bssid, + connect.bssid, ETH_ALEN); + else + eth_zero_addr(dev->ieee80211_ptr->disconnect_bssid); + } + + wdev_unlock(dev->ieee80211_ptr); + + return err; +} + +static int nl80211_update_connect_params(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_connect_params connect = {}; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + bool fils_sk_offload; + u32 auth_type; + u32 changed = 0; + int ret; + + if (!rdev->ops->update_connect_params) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_IE]) { + connect.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + connect.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + changed |= UPDATE_ASSOC_IES; + } + + fils_sk_offload = wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_FILS_SK_OFFLOAD); + + /* + 
* when driver supports fils-sk offload all attributes must be + * provided. So the else covers "fils-sk-not-all" and + * "no-fils-sk-any". + */ + if (fils_sk_offload && + info->attrs[NL80211_ATTR_FILS_ERP_USERNAME] && + info->attrs[NL80211_ATTR_FILS_ERP_REALM] && + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM] && + info->attrs[NL80211_ATTR_FILS_ERP_RRK]) { + connect.fils_erp_username = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_USERNAME]); + connect.fils_erp_username_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_USERNAME]); + connect.fils_erp_realm = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_REALM]); + connect.fils_erp_realm_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_REALM]); + connect.fils_erp_next_seq_num = + nla_get_u16( + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM]); + connect.fils_erp_rrk = + nla_data(info->attrs[NL80211_ATTR_FILS_ERP_RRK]); + connect.fils_erp_rrk_len = + nla_len(info->attrs[NL80211_ATTR_FILS_ERP_RRK]); + changed |= UPDATE_FILS_ERP_INFO; + } else if (info->attrs[NL80211_ATTR_FILS_ERP_USERNAME] || + info->attrs[NL80211_ATTR_FILS_ERP_REALM] || + info->attrs[NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM] || + info->attrs[NL80211_ATTR_FILS_ERP_RRK]) { + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_AUTH_TYPE]) { + auth_type = nla_get_u32(info->attrs[NL80211_ATTR_AUTH_TYPE]); + if (!nl80211_valid_auth_type(rdev, auth_type, + NL80211_CMD_CONNECT)) + return -EINVAL; + + if (auth_type == NL80211_AUTHTYPE_FILS_SK && + fils_sk_offload && !(changed & UPDATE_FILS_ERP_INFO)) + return -EINVAL; + + connect.auth_type = auth_type; + changed |= UPDATE_AUTH_TYPE; + } + + wdev_lock(dev->ieee80211_ptr); + if (!wdev->connected) + ret = -ENOLINK; + else + ret = rdev_update_connect_params(rdev, dev, &connect, changed); + wdev_unlock(dev->ieee80211_ptr); + + return ret; +} + +static int nl80211_disconnect(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u16 reason; + int ret; + + if (dev->ieee80211_ptr->conn_owner_nlportid && + dev->ieee80211_ptr->conn_owner_nlportid != info->snd_portid) + return -EPERM; + + if (!info->attrs[NL80211_ATTR_REASON_CODE]) + reason = WLAN_REASON_DEAUTH_LEAVING; + else + reason = nla_get_u16(info->attrs[NL80211_ATTR_REASON_CODE]); + + if (reason == 0) + return -EINVAL; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + wdev_lock(dev->ieee80211_ptr); + ret = cfg80211_disconnect(rdev, dev, reason, true); + wdev_unlock(dev->ieee80211_ptr); + return ret; +} + +static int nl80211_wiphy_netns(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net *net; + int err; + + if (info->attrs[NL80211_ATTR_PID]) { + u32 pid = nla_get_u32(info->attrs[NL80211_ATTR_PID]); + + net = get_net_ns_by_pid(pid); + } else if (info->attrs[NL80211_ATTR_NETNS_FD]) { + u32 fd = nla_get_u32(info->attrs[NL80211_ATTR_NETNS_FD]); + + net = get_net_ns_by_fd(fd); + } else { + return -EINVAL; + } + + if (IS_ERR(net)) + return PTR_ERR(net); + + err = 0; + + /* check if anything to do */ + if (!net_eq(wiphy_net(&rdev->wiphy), net)) + err = cfg80211_switch_netns(rdev, net); + + put_net(net); + return err; +} + +static int nl80211_setdel_pmksa(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + int (*rdev_ops)(struct wiphy *wiphy, struct net_device *dev, + 
struct cfg80211_pmksa *pmksa) = NULL; + struct net_device *dev = info->user_ptr[1]; + struct cfg80211_pmksa pmksa; + + memset(&pmksa, 0, sizeof(struct cfg80211_pmksa)); + + if (!info->attrs[NL80211_ATTR_PMKID]) + return -EINVAL; + + pmksa.pmkid = nla_data(info->attrs[NL80211_ATTR_PMKID]); + + if (info->attrs[NL80211_ATTR_MAC]) { + pmksa.bssid = nla_data(info->attrs[NL80211_ATTR_MAC]); + } else if (info->attrs[NL80211_ATTR_SSID] && + info->attrs[NL80211_ATTR_FILS_CACHE_ID] && + (info->genlhdr->cmd == NL80211_CMD_DEL_PMKSA || + info->attrs[NL80211_ATTR_PMK])) { + pmksa.ssid = nla_data(info->attrs[NL80211_ATTR_SSID]); + pmksa.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]); + pmksa.cache_id = + nla_data(info->attrs[NL80211_ATTR_FILS_CACHE_ID]); + } else { + return -EINVAL; + } + if (info->attrs[NL80211_ATTR_PMK]) { + pmksa.pmk = nla_data(info->attrs[NL80211_ATTR_PMK]); + pmksa.pmk_len = nla_len(info->attrs[NL80211_ATTR_PMK]); + } + + if (info->attrs[NL80211_ATTR_PMK_LIFETIME]) + pmksa.pmk_lifetime = + nla_get_u32(info->attrs[NL80211_ATTR_PMK_LIFETIME]); + + if (info->attrs[NL80211_ATTR_PMK_REAUTH_THRESHOLD]) + pmksa.pmk_reauth_threshold = + nla_get_u8( + info->attrs[NL80211_ATTR_PMK_REAUTH_THRESHOLD]); + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT && + !(dev->ieee80211_ptr->iftype == NL80211_IFTYPE_AP && + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_AP_PMKSA_CACHING))) + return -EOPNOTSUPP; + + switch (info->genlhdr->cmd) { + case NL80211_CMD_SET_PMKSA: + rdev_ops = rdev->ops->set_pmksa; + break; + case NL80211_CMD_DEL_PMKSA: + rdev_ops = rdev->ops->del_pmksa; + break; + default: + WARN_ON(1); + break; + } + + if (!rdev_ops) + return -EOPNOTSUPP; + + return rdev_ops(&rdev->wiphy, dev, &pmksa); +} + +static int nl80211_flush_pmksa(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION && + dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + if (!rdev->ops->flush_pmksa) + return -EOPNOTSUPP; + + return rdev_flush_pmksa(rdev, dev); +} + +static int nl80211_tdls_mgmt(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + u8 action_code, dialog_token; + u32 peer_capability = 0; + u16 status_code; + u8 *peer; + bool initiator; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_TDLS) || + !rdev->ops->tdls_mgmt) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_TDLS_ACTION] || + !info->attrs[NL80211_ATTR_STATUS_CODE] || + !info->attrs[NL80211_ATTR_TDLS_DIALOG_TOKEN] || + !info->attrs[NL80211_ATTR_IE] || + !info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + peer = nla_data(info->attrs[NL80211_ATTR_MAC]); + action_code = nla_get_u8(info->attrs[NL80211_ATTR_TDLS_ACTION]); + status_code = nla_get_u16(info->attrs[NL80211_ATTR_STATUS_CODE]); + dialog_token = nla_get_u8(info->attrs[NL80211_ATTR_TDLS_DIALOG_TOKEN]); + initiator = nla_get_flag(info->attrs[NL80211_ATTR_TDLS_INITIATOR]); + if (info->attrs[NL80211_ATTR_TDLS_PEER_CAPABILITY]) + peer_capability = + nla_get_u32(info->attrs[NL80211_ATTR_TDLS_PEER_CAPABILITY]); + + return rdev_tdls_mgmt(rdev, dev, peer, action_code, + dialog_token, status_code, peer_capability, + initiator, + nla_data(info->attrs[NL80211_ATTR_IE]), + nla_len(info->attrs[NL80211_ATTR_IE])); 
+} + +static int nl80211_tdls_oper(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + enum nl80211_tdls_operation operation; + u8 *peer; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_TDLS) || + !rdev->ops->tdls_oper) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_TDLS_OPERATION] || + !info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + operation = nla_get_u8(info->attrs[NL80211_ATTR_TDLS_OPERATION]); + peer = nla_data(info->attrs[NL80211_ATTR_MAC]); + + return rdev_tdls_oper(rdev, dev, peer, operation); +} + +static int nl80211_remain_on_channel(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + unsigned int link_id = nl80211_link_id(info->attrs); + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_chan_def chandef; + struct sk_buff *msg; + void *hdr; + u64 cookie; + u32 duration; + int err; + + if (!info->attrs[NL80211_ATTR_WIPHY_FREQ] || + !info->attrs[NL80211_ATTR_DURATION]) + return -EINVAL; + + duration = nla_get_u32(info->attrs[NL80211_ATTR_DURATION]); + + if (!rdev->ops->remain_on_channel || + !(rdev->wiphy.flags & WIPHY_FLAG_HAS_REMAIN_ON_CHANNEL)) + return -EOPNOTSUPP; + + /* + * We should be on that channel for at least a minimum amount of + * time (10ms) but no longer than the driver supports. + */ + if (duration < NL80211_MIN_REMAIN_ON_CHANNEL_TIME || + duration > rdev->wiphy.max_remain_on_channel_duration) + return -EINVAL; + + err = nl80211_parse_chandef(rdev, info, &chandef); + if (err) + return err; + + wdev_lock(wdev); + if (!cfg80211_off_channel_oper_allowed(wdev, chandef.chan)) { + const struct cfg80211_chan_def *oper_chandef, *compat_chandef; + + oper_chandef = wdev_chandef(wdev, link_id); + + if (WARN_ON(!oper_chandef)) { + /* cannot happen since we must beacon to get here */ + WARN_ON(1); + wdev_unlock(wdev); + return -EBUSY; + } + + /* note: returns first one if identical chandefs */ + compat_chandef = cfg80211_chandef_compatible(&chandef, + oper_chandef); + + if (compat_chandef != &chandef) { + wdev_unlock(wdev); + return -EBUSY; + } + } + wdev_unlock(wdev); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_REMAIN_ON_CHANNEL); + if (!hdr) { + err = -ENOBUFS; + goto free_msg; + } + + err = rdev_remain_on_channel(rdev, wdev, chandef.chan, + duration, &cookie); + + if (err) + goto free_msg; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return genlmsg_reply(msg, info); + + nla_put_failure: + err = -ENOBUFS; + free_msg: + nlmsg_free(msg); + return err; +} + +static int nl80211_cancel_remain_on_channel(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + u64 cookie; + + if (!info->attrs[NL80211_ATTR_COOKIE]) + return -EINVAL; + + if (!rdev->ops->cancel_remain_on_channel) + return -EOPNOTSUPP; + + cookie = nla_get_u64(info->attrs[NL80211_ATTR_COOKIE]); + + return rdev_cancel_remain_on_channel(rdev, wdev, cookie); +} + +static int nl80211_set_tx_bitrate_mask(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_bitrate_mask mask; + unsigned int link_id = nl80211_link_id(info->attrs); + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct 
net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + if (!rdev->ops->set_bitrate_mask) + return -EOPNOTSUPP; + + wdev_lock(wdev); + err = nl80211_parse_tx_bitrate_mask(info, info->attrs, + NL80211_ATTR_TX_RATES, &mask, + dev, true, link_id); + if (err) + goto out; + + err = rdev_set_bitrate_mask(rdev, dev, link_id, NULL, &mask); +out: + wdev_unlock(wdev); + return err; +} + +static int nl80211_register_mgmt(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + u16 frame_type = IEEE80211_FTYPE_MGMT | IEEE80211_STYPE_ACTION; + + if (!info->attrs[NL80211_ATTR_FRAME_MATCH]) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_FRAME_TYPE]) + frame_type = nla_get_u16(info->attrs[NL80211_ATTR_FRAME_TYPE]); + + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_P2P_DEVICE: + break; + case NL80211_IFTYPE_NAN: + default: + return -EOPNOTSUPP; + } + + /* not much point in registering if we can't reply */ + if (!rdev->ops->mgmt_tx) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_RECEIVE_MULTICAST] && + !wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_MULTICAST_REGISTRATIONS)) { + GENL_SET_ERR_MSG(info, + "multicast RX registrations are not supported"); + return -EOPNOTSUPP; + } + + return cfg80211_mlme_register_mgmt(wdev, info->snd_portid, frame_type, + nla_data(info->attrs[NL80211_ATTR_FRAME_MATCH]), + nla_len(info->attrs[NL80211_ATTR_FRAME_MATCH]), + info->attrs[NL80211_ATTR_RECEIVE_MULTICAST], + info->extack); +} + +static int nl80211_tx_mgmt(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_chan_def chandef; + int err; + void *hdr = NULL; + u64 cookie; + struct sk_buff *msg = NULL; + struct cfg80211_mgmt_tx_params params = { + .dont_wait_for_ack = + info->attrs[NL80211_ATTR_DONT_WAIT_FOR_ACK], + }; + + if (!info->attrs[NL80211_ATTR_FRAME]) + return -EINVAL; + + if (!rdev->ops->mgmt_tx) + return -EOPNOTSUPP; + + switch (wdev->iftype) { + case NL80211_IFTYPE_P2P_DEVICE: + if (!info->attrs[NL80211_ATTR_WIPHY_FREQ]) + return -EINVAL; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MESH_POINT: + case NL80211_IFTYPE_P2P_GO: + break; + case NL80211_IFTYPE_NAN: + default: + return -EOPNOTSUPP; + } + + if (info->attrs[NL80211_ATTR_DURATION]) { + if (!(rdev->wiphy.flags & WIPHY_FLAG_OFFCHAN_TX)) + return -EINVAL; + params.wait = nla_get_u32(info->attrs[NL80211_ATTR_DURATION]); + + /* + * We should wait on the channel for at least a minimum amount + * of time (10ms) but no longer than the driver supports. + */ + if (params.wait < NL80211_MIN_REMAIN_ON_CHANNEL_TIME || + params.wait > rdev->wiphy.max_remain_on_channel_duration) + return -EINVAL; + } + + params.offchan = info->attrs[NL80211_ATTR_OFFCHANNEL_TX_OK]; + + if (params.offchan && !(rdev->wiphy.flags & WIPHY_FLAG_OFFCHAN_TX)) + return -EINVAL; + + params.no_cck = nla_get_flag(info->attrs[NL80211_ATTR_TX_NO_CCK_RATE]); + + /* get the channel if any has been specified, otherwise pass NULL to + * the driver. 
The latter will use the current one + */ + chandef.chan = NULL; + if (info->attrs[NL80211_ATTR_WIPHY_FREQ]) { + err = nl80211_parse_chandef(rdev, info, &chandef); + if (err) + return err; + } + + if (!chandef.chan && params.offchan) + return -EINVAL; + + wdev_lock(wdev); + if (params.offchan && + !cfg80211_off_channel_oper_allowed(wdev, chandef.chan)) { + wdev_unlock(wdev); + return -EBUSY; + } + + params.link_id = nl80211_link_id_or_invalid(info->attrs); + /* + * This now races due to the unlock, but we cannot check + * the valid links for the _station_ anyway, so that's up + * to the driver. + */ + if (params.link_id >= 0 && + !(wdev->valid_links & BIT(params.link_id))) { + wdev_unlock(wdev); + return -EINVAL; + } + wdev_unlock(wdev); + + params.buf = nla_data(info->attrs[NL80211_ATTR_FRAME]); + params.len = nla_len(info->attrs[NL80211_ATTR_FRAME]); + + if (info->attrs[NL80211_ATTR_CSA_C_OFFSETS_TX]) { + int len = nla_len(info->attrs[NL80211_ATTR_CSA_C_OFFSETS_TX]); + int i; + + if (len % sizeof(u16)) + return -EINVAL; + + params.n_csa_offsets = len / sizeof(u16); + params.csa_offsets = + nla_data(info->attrs[NL80211_ATTR_CSA_C_OFFSETS_TX]); + + /* check that all the offsets fit the frame */ + for (i = 0; i < params.n_csa_offsets; i++) { + if (params.csa_offsets[i] >= params.len) + return -EINVAL; + } + } + + if (!params.dont_wait_for_ack) { + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_FRAME); + if (!hdr) { + err = -ENOBUFS; + goto free_msg; + } + } + + params.chan = chandef.chan; + err = cfg80211_mlme_mgmt_tx(rdev, wdev, ¶ms, &cookie); + if (err) + goto free_msg; + + if (msg) { + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + } + + return 0; + + nla_put_failure: + err = -ENOBUFS; + free_msg: + nlmsg_free(msg); + return err; +} + +static int nl80211_tx_mgmt_cancel_wait(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + u64 cookie; + + if (!info->attrs[NL80211_ATTR_COOKIE]) + return -EINVAL; + + if (!rdev->ops->mgmt_tx_cancel_wait) + return -EOPNOTSUPP; + + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_P2P_DEVICE: + break; + case NL80211_IFTYPE_NAN: + default: + return -EOPNOTSUPP; + } + + cookie = nla_get_u64(info->attrs[NL80211_ATTR_COOKIE]); + + return rdev_mgmt_tx_cancel_wait(rdev, wdev, cookie); +} + +static int nl80211_set_power_save(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev; + struct net_device *dev = info->user_ptr[1]; + u8 ps_state; + bool state; + int err; + + if (!info->attrs[NL80211_ATTR_PS_STATE]) + return -EINVAL; + + ps_state = nla_get_u32(info->attrs[NL80211_ATTR_PS_STATE]); + + wdev = dev->ieee80211_ptr; + + if (!rdev->ops->set_power_mgmt) + return -EOPNOTSUPP; + + state = (ps_state == NL80211_PS_ENABLED) ? 
true : false; + + if (state == wdev->ps) + return 0; + + err = rdev_set_power_mgmt(rdev, dev, state, wdev->ps_timeout); + if (!err) + wdev->ps = state; + return err; +} + +static int nl80211_get_power_save(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + enum nl80211_ps_state ps_state; + struct wireless_dev *wdev; + struct net_device *dev = info->user_ptr[1]; + struct sk_buff *msg; + void *hdr; + int err; + + wdev = dev->ieee80211_ptr; + + if (!rdev->ops->set_power_mgmt) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_POWER_SAVE); + if (!hdr) { + err = -ENOBUFS; + goto free_msg; + } + + if (wdev->ps) + ps_state = NL80211_PS_ENABLED; + else + ps_state = NL80211_PS_DISABLED; + + if (nla_put_u32(msg, NL80211_ATTR_PS_STATE, ps_state)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + + nla_put_failure: + err = -ENOBUFS; + free_msg: + nlmsg_free(msg); + return err; +} + +static const struct nla_policy +nl80211_attr_cqm_policy[NL80211_ATTR_CQM_MAX + 1] = { + [NL80211_ATTR_CQM_RSSI_THOLD] = { .type = NLA_BINARY }, + [NL80211_ATTR_CQM_RSSI_HYST] = { .type = NLA_U32 }, + [NL80211_ATTR_CQM_RSSI_THRESHOLD_EVENT] = { .type = NLA_U32 }, + [NL80211_ATTR_CQM_TXE_RATE] = { .type = NLA_U32 }, + [NL80211_ATTR_CQM_TXE_PKTS] = { .type = NLA_U32 }, + [NL80211_ATTR_CQM_TXE_INTVL] = { .type = NLA_U32 }, + [NL80211_ATTR_CQM_RSSI_LEVEL] = { .type = NLA_S32 }, +}; + +static int nl80211_set_cqm_txe(struct genl_info *info, + u32 rate, u32 pkts, u32 intvl) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if (rate > 100 || intvl > NL80211_CQM_TXE_MAX_INTVL) + return -EINVAL; + + if (!rdev->ops->set_cqm_txe_config) + return -EOPNOTSUPP; + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + return rdev_set_cqm_txe_config(rdev, dev, rate, pkts, intvl); +} + +static int cfg80211_cqm_rssi_update(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_cqm_config *cqm_config) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + s32 last, low, high; + u32 hyst; + int i, n, low_index; + int err; + + /* + * Obtain current RSSI value if possible, if not and no RSSI threshold + * event has been received yet, we should receive an event after a + * connection is established and enough beacons received to calculate + * the average. 
+ */ + if (!cqm_config->last_rssi_event_value && + wdev->links[0].client.current_bss && + rdev->ops->get_station) { + struct station_info sinfo = {}; + u8 *mac_addr; + + mac_addr = wdev->links[0].client.current_bss->pub.bssid; + + err = rdev_get_station(rdev, dev, mac_addr, &sinfo); + if (err) + return err; + + cfg80211_sinfo_release_content(&sinfo); + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_BEACON_SIGNAL_AVG)) + cqm_config->last_rssi_event_value = + (s8) sinfo.rx_beacon_signal_avg; + } + + last = cqm_config->last_rssi_event_value; + hyst = cqm_config->rssi_hyst; + n = cqm_config->n_rssi_thresholds; + + for (i = 0; i < n; i++) { + i = array_index_nospec(i, n); + if (last < cqm_config->rssi_thresholds[i]) + break; + } + + low_index = i - 1; + if (low_index >= 0) { + low_index = array_index_nospec(low_index, n); + low = cqm_config->rssi_thresholds[low_index] - hyst; + } else { + low = S32_MIN; + } + if (i < n) { + i = array_index_nospec(i, n); + high = cqm_config->rssi_thresholds[i] + hyst - 1; + } else { + high = S32_MAX; + } + + return rdev_set_cqm_rssi_range_config(rdev, dev, low, high); +} + +static int nl80211_set_cqm_rssi(struct genl_info *info, + const s32 *thresholds, int n_thresholds, + u32 hysteresis) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct cfg80211_cqm_config *cqm_config = NULL, *old; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + int i, err; + s32 prev = S32_MIN; + + /* Check all values negative and sorted */ + for (i = 0; i < n_thresholds; i++) { + if (thresholds[i] > 0 || thresholds[i] <= prev) + return -EINVAL; + + prev = thresholds[i]; + } + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + if (n_thresholds == 1 && thresholds[0] == 0) /* Disabling */ + n_thresholds = 0; + + wdev_lock(wdev); + old = rcu_dereference_protected(wdev->cqm_config, + lockdep_is_held(&wdev->mtx)); + + /* if already disabled just succeed */ + if (!n_thresholds && !old) { + err = 0; + goto unlock; + } + + if (n_thresholds > 1) { + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_CQM_RSSI_LIST) || + !rdev->ops->set_cqm_rssi_range_config) { + err = -EOPNOTSUPP; + goto unlock; + } + } else { + if (!rdev->ops->set_cqm_rssi_config) { + err = -EOPNOTSUPP; + goto unlock; + } + } + + if (n_thresholds) { + cqm_config = kzalloc(struct_size(cqm_config, rssi_thresholds, + n_thresholds), + GFP_KERNEL); + if (!cqm_config) { + err = -ENOMEM; + goto unlock; + } + + cqm_config->rssi_hyst = hysteresis; + cqm_config->n_rssi_thresholds = n_thresholds; + memcpy(cqm_config->rssi_thresholds, thresholds, + flex_array_size(cqm_config, rssi_thresholds, + n_thresholds)); + cqm_config->use_range_api = n_thresholds > 1 || + !rdev->ops->set_cqm_rssi_config; + + rcu_assign_pointer(wdev->cqm_config, cqm_config); + + if (cqm_config->use_range_api) + err = cfg80211_cqm_rssi_update(rdev, dev, cqm_config); + else + err = rdev_set_cqm_rssi_config(rdev, dev, + thresholds[0], + hysteresis); + } else { + RCU_INIT_POINTER(wdev->cqm_config, NULL); + /* if enabled as range also disable via range */ + if (old->use_range_api) + err = rdev_set_cqm_rssi_range_config(rdev, dev, 0, 0); + else + err = rdev_set_cqm_rssi_config(rdev, dev, 0, 0); + } + + if (err) { + rcu_assign_pointer(wdev->cqm_config, old); + kfree_rcu(cqm_config, rcu_head); + } else { + kfree_rcu(old, rcu_head); + } +unlock: + wdev_unlock(wdev); + + return err; +} + +static int nl80211_set_cqm(struct sk_buff 
*skb, struct genl_info *info) +{ + struct nlattr *attrs[NL80211_ATTR_CQM_MAX + 1]; + struct nlattr *cqm; + int err; + + cqm = info->attrs[NL80211_ATTR_CQM]; + if (!cqm) + return -EINVAL; + + err = nla_parse_nested_deprecated(attrs, NL80211_ATTR_CQM_MAX, cqm, + nl80211_attr_cqm_policy, + info->extack); + if (err) + return err; + + if (attrs[NL80211_ATTR_CQM_RSSI_THOLD] && + attrs[NL80211_ATTR_CQM_RSSI_HYST]) { + const s32 *thresholds = + nla_data(attrs[NL80211_ATTR_CQM_RSSI_THOLD]); + int len = nla_len(attrs[NL80211_ATTR_CQM_RSSI_THOLD]); + u32 hysteresis = nla_get_u32(attrs[NL80211_ATTR_CQM_RSSI_HYST]); + + if (len % 4) + return -EINVAL; + + return nl80211_set_cqm_rssi(info, thresholds, len / 4, + hysteresis); + } + + if (attrs[NL80211_ATTR_CQM_TXE_RATE] && + attrs[NL80211_ATTR_CQM_TXE_PKTS] && + attrs[NL80211_ATTR_CQM_TXE_INTVL]) { + u32 rate = nla_get_u32(attrs[NL80211_ATTR_CQM_TXE_RATE]); + u32 pkts = nla_get_u32(attrs[NL80211_ATTR_CQM_TXE_PKTS]); + u32 intvl = nla_get_u32(attrs[NL80211_ATTR_CQM_TXE_INTVL]); + + return nl80211_set_cqm_txe(info, rate, pkts, intvl); + } + + return -EINVAL; +} + +static int nl80211_join_ocb(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct ocb_setup setup = {}; + int err; + + err = nl80211_parse_chandef(rdev, info, &setup.chandef); + if (err) + return err; + + return cfg80211_join_ocb(rdev, dev, &setup); +} + +static int nl80211_leave_ocb(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + + return cfg80211_leave_ocb(rdev, dev); +} + +static int nl80211_join_mesh(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct mesh_config cfg; + struct mesh_setup setup; + int err; + + /* start with default */ + memcpy(&cfg, &default_mesh_config, sizeof(cfg)); + memcpy(&setup, &default_mesh_setup, sizeof(setup)); + + if (info->attrs[NL80211_ATTR_MESH_CONFIG]) { + /* and parse parameters if given */ + err = nl80211_parse_mesh_config(info, &cfg, NULL); + if (err) + return err; + } + + if (!info->attrs[NL80211_ATTR_MESH_ID] || + !nla_len(info->attrs[NL80211_ATTR_MESH_ID])) + return -EINVAL; + + setup.mesh_id = nla_data(info->attrs[NL80211_ATTR_MESH_ID]); + setup.mesh_id_len = nla_len(info->attrs[NL80211_ATTR_MESH_ID]); + + if (info->attrs[NL80211_ATTR_MCAST_RATE] && + !nl80211_parse_mcast_rate(rdev, setup.mcast_rate, + nla_get_u32(info->attrs[NL80211_ATTR_MCAST_RATE]))) + return -EINVAL; + + if (info->attrs[NL80211_ATTR_BEACON_INTERVAL]) { + setup.beacon_interval = + nla_get_u32(info->attrs[NL80211_ATTR_BEACON_INTERVAL]); + + err = cfg80211_validate_beacon_int(rdev, + NL80211_IFTYPE_MESH_POINT, + setup.beacon_interval); + if (err) + return err; + } + + if (info->attrs[NL80211_ATTR_DTIM_PERIOD]) { + setup.dtim_period = + nla_get_u32(info->attrs[NL80211_ATTR_DTIM_PERIOD]); + if (setup.dtim_period < 1 || setup.dtim_period > 100) + return -EINVAL; + } + + if (info->attrs[NL80211_ATTR_MESH_SETUP]) { + /* parse additional setup parameters if given */ + err = nl80211_parse_mesh_setup(info, &setup); + if (err) + return err; + } + + if (setup.user_mpm) + cfg.auto_open_plinks = false; + + if (info->attrs[NL80211_ATTR_WIPHY_FREQ]) { + err = nl80211_parse_chandef(rdev, info, &setup.chandef); + if (err) + return err; + } else { + /* __cfg80211_join_mesh() 
will sort it out */ + setup.chandef.chan = NULL; + } + + if (info->attrs[NL80211_ATTR_BSS_BASIC_RATES]) { + u8 *rates = nla_data(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + int n_rates = + nla_len(info->attrs[NL80211_ATTR_BSS_BASIC_RATES]); + struct ieee80211_supported_band *sband; + + if (!setup.chandef.chan) + return -EINVAL; + + sband = rdev->wiphy.bands[setup.chandef.chan->band]; + + err = ieee80211_get_ratemask(sband, rates, n_rates, + &setup.basic_rates); + if (err) + return err; + } + + if (info->attrs[NL80211_ATTR_TX_RATES]) { + err = nl80211_parse_tx_bitrate_mask(info, info->attrs, + NL80211_ATTR_TX_RATES, + &setup.beacon_rate, + dev, false, 0); + if (err) + return err; + + if (!setup.chandef.chan) + return -EINVAL; + + err = validate_beacon_tx_rate(rdev, setup.chandef.chan->band, + &setup.beacon_rate); + if (err) + return err; + } + + setup.userspace_handles_dfs = + nla_get_flag(info->attrs[NL80211_ATTR_HANDLE_DFS]); + + if (info->attrs[NL80211_ATTR_CONTROL_PORT_OVER_NL80211]) { + int r = validate_pae_over_nl80211(rdev, info); + + if (r < 0) + return r; + + setup.control_port_over_nl80211 = true; + } + + wdev_lock(dev->ieee80211_ptr); + err = __cfg80211_join_mesh(rdev, dev, &setup, &cfg); + if (!err && info->attrs[NL80211_ATTR_SOCKET_OWNER]) + dev->ieee80211_ptr->conn_owner_nlportid = info->snd_portid; + wdev_unlock(dev->ieee80211_ptr); + + return err; +} + +static int nl80211_leave_mesh(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + + return cfg80211_leave_mesh(rdev, dev); +} + +#ifdef CONFIG_PM +static int nl80211_send_wowlan_patterns(struct sk_buff *msg, + struct cfg80211_registered_device *rdev) +{ + struct cfg80211_wowlan *wowlan = rdev->wiphy.wowlan_config; + struct nlattr *nl_pats, *nl_pat; + int i, pat_len; + + if (!wowlan->n_patterns) + return 0; + + nl_pats = nla_nest_start_noflag(msg, NL80211_WOWLAN_TRIG_PKT_PATTERN); + if (!nl_pats) + return -ENOBUFS; + + for (i = 0; i < wowlan->n_patterns; i++) { + nl_pat = nla_nest_start_noflag(msg, i + 1); + if (!nl_pat) + return -ENOBUFS; + pat_len = wowlan->patterns[i].pattern_len; + if (nla_put(msg, NL80211_PKTPAT_MASK, DIV_ROUND_UP(pat_len, 8), + wowlan->patterns[i].mask) || + nla_put(msg, NL80211_PKTPAT_PATTERN, pat_len, + wowlan->patterns[i].pattern) || + nla_put_u32(msg, NL80211_PKTPAT_OFFSET, + wowlan->patterns[i].pkt_offset)) + return -ENOBUFS; + nla_nest_end(msg, nl_pat); + } + nla_nest_end(msg, nl_pats); + + return 0; +} + +static int nl80211_send_wowlan_tcp(struct sk_buff *msg, + struct cfg80211_wowlan_tcp *tcp) +{ + struct nlattr *nl_tcp; + + if (!tcp) + return 0; + + nl_tcp = nla_nest_start_noflag(msg, + NL80211_WOWLAN_TRIG_TCP_CONNECTION); + if (!nl_tcp) + return -ENOBUFS; + + if (nla_put_in_addr(msg, NL80211_WOWLAN_TCP_SRC_IPV4, tcp->src) || + nla_put_in_addr(msg, NL80211_WOWLAN_TCP_DST_IPV4, tcp->dst) || + nla_put(msg, NL80211_WOWLAN_TCP_DST_MAC, ETH_ALEN, tcp->dst_mac) || + nla_put_u16(msg, NL80211_WOWLAN_TCP_SRC_PORT, tcp->src_port) || + nla_put_u16(msg, NL80211_WOWLAN_TCP_DST_PORT, tcp->dst_port) || + nla_put(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD, + tcp->payload_len, tcp->payload) || + nla_put_u32(msg, NL80211_WOWLAN_TCP_DATA_INTERVAL, + tcp->data_interval) || + nla_put(msg, NL80211_WOWLAN_TCP_WAKE_PAYLOAD, + tcp->wake_len, tcp->wake_data) || + nla_put(msg, NL80211_WOWLAN_TCP_WAKE_MASK, + DIV_ROUND_UP(tcp->wake_len, 8), tcp->wake_mask)) + return -ENOBUFS; + + if (tcp->payload_seq.len && + nla_put(msg, 
NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ, + sizeof(tcp->payload_seq), &tcp->payload_seq)) + return -ENOBUFS; + + if (tcp->payload_tok.len && + nla_put(msg, NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN, + sizeof(tcp->payload_tok) + tcp->tokens_size, + &tcp->payload_tok)) + return -ENOBUFS; + + nla_nest_end(msg, nl_tcp); + + return 0; +} + +static int nl80211_send_wowlan_nd(struct sk_buff *msg, + struct cfg80211_sched_scan_request *req) +{ + struct nlattr *nd, *freqs, *matches, *match, *scan_plans, *scan_plan; + int i; + + if (!req) + return 0; + + nd = nla_nest_start_noflag(msg, NL80211_WOWLAN_TRIG_NET_DETECT); + if (!nd) + return -ENOBUFS; + + if (req->n_scan_plans == 1 && + nla_put_u32(msg, NL80211_ATTR_SCHED_SCAN_INTERVAL, + req->scan_plans[0].interval * 1000)) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_ATTR_SCHED_SCAN_DELAY, req->delay)) + return -ENOBUFS; + + if (req->relative_rssi_set) { + struct nl80211_bss_select_rssi_adjust rssi_adjust; + + if (nla_put_s8(msg, NL80211_ATTR_SCHED_SCAN_RELATIVE_RSSI, + req->relative_rssi)) + return -ENOBUFS; + + rssi_adjust.band = req->rssi_adjust.band; + rssi_adjust.delta = req->rssi_adjust.delta; + if (nla_put(msg, NL80211_ATTR_SCHED_SCAN_RSSI_ADJUST, + sizeof(rssi_adjust), &rssi_adjust)) + return -ENOBUFS; + } + + freqs = nla_nest_start_noflag(msg, NL80211_ATTR_SCAN_FREQUENCIES); + if (!freqs) + return -ENOBUFS; + + for (i = 0; i < req->n_channels; i++) { + if (nla_put_u32(msg, i, req->channels[i]->center_freq)) + return -ENOBUFS; + } + + nla_nest_end(msg, freqs); + + if (req->n_match_sets) { + matches = nla_nest_start_noflag(msg, + NL80211_ATTR_SCHED_SCAN_MATCH); + if (!matches) + return -ENOBUFS; + + for (i = 0; i < req->n_match_sets; i++) { + match = nla_nest_start_noflag(msg, i); + if (!match) + return -ENOBUFS; + + if (nla_put(msg, NL80211_SCHED_SCAN_MATCH_ATTR_SSID, + req->match_sets[i].ssid.ssid_len, + req->match_sets[i].ssid.ssid)) + return -ENOBUFS; + nla_nest_end(msg, match); + } + nla_nest_end(msg, matches); + } + + scan_plans = nla_nest_start_noflag(msg, NL80211_ATTR_SCHED_SCAN_PLANS); + if (!scan_plans) + return -ENOBUFS; + + for (i = 0; i < req->n_scan_plans; i++) { + scan_plan = nla_nest_start_noflag(msg, i + 1); + if (!scan_plan) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_SCHED_SCAN_PLAN_INTERVAL, + req->scan_plans[i].interval) || + (req->scan_plans[i].iterations && + nla_put_u32(msg, NL80211_SCHED_SCAN_PLAN_ITERATIONS, + req->scan_plans[i].iterations))) + return -ENOBUFS; + nla_nest_end(msg, scan_plan); + } + nla_nest_end(msg, scan_plans); + + nla_nest_end(msg, nd); + + return 0; +} + +static int nl80211_get_wowlan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct sk_buff *msg; + void *hdr; + u32 size = NLMSG_DEFAULT_SIZE; + + if (!rdev->wiphy.wowlan) + return -EOPNOTSUPP; + + if (rdev->wiphy.wowlan_config && rdev->wiphy.wowlan_config->tcp) { + /* adjust size to have room for all the data */ + size += rdev->wiphy.wowlan_config->tcp->tokens_size + + rdev->wiphy.wowlan_config->tcp->payload_len + + rdev->wiphy.wowlan_config->tcp->wake_len + + rdev->wiphy.wowlan_config->tcp->wake_len / 8; + } + + msg = nlmsg_new(size, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_WOWLAN); + if (!hdr) + goto nla_put_failure; + + if (rdev->wiphy.wowlan_config) { + struct nlattr *nl_wowlan; + + nl_wowlan = nla_nest_start_noflag(msg, + NL80211_ATTR_WOWLAN_TRIGGERS); + if (!nl_wowlan) + goto 
nla_put_failure; + + if ((rdev->wiphy.wowlan_config->any && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_ANY)) || + (rdev->wiphy.wowlan_config->disconnect && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_DISCONNECT)) || + (rdev->wiphy.wowlan_config->magic_pkt && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_MAGIC_PKT)) || + (rdev->wiphy.wowlan_config->gtk_rekey_failure && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE)) || + (rdev->wiphy.wowlan_config->eap_identity_req && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST)) || + (rdev->wiphy.wowlan_config->four_way_handshake && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE)) || + (rdev->wiphy.wowlan_config->rfkill_release && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_RFKILL_RELEASE))) + goto nla_put_failure; + + if (nl80211_send_wowlan_patterns(msg, rdev)) + goto nla_put_failure; + + if (nl80211_send_wowlan_tcp(msg, + rdev->wiphy.wowlan_config->tcp)) + goto nla_put_failure; + + if (nl80211_send_wowlan_nd( + msg, + rdev->wiphy.wowlan_config->nd_config)) + goto nla_put_failure; + + nla_nest_end(msg, nl_wowlan); + } + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +static int nl80211_parse_wowlan_tcp(struct cfg80211_registered_device *rdev, + struct nlattr *attr, + struct cfg80211_wowlan *trig) +{ + struct nlattr *tb[NUM_NL80211_WOWLAN_TCP]; + struct cfg80211_wowlan_tcp *cfg; + struct nl80211_wowlan_tcp_data_token *tok = NULL; + struct nl80211_wowlan_tcp_data_seq *seq = NULL; + u32 size; + u32 data_size, wake_size, tokens_size = 0, wake_mask_size; + int err, port; + + if (!rdev->wiphy.wowlan->tcp) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, MAX_NL80211_WOWLAN_TCP, attr, + nl80211_wowlan_tcp_policy, NULL); + if (err) + return err; + + if (!tb[NL80211_WOWLAN_TCP_SRC_IPV4] || + !tb[NL80211_WOWLAN_TCP_DST_IPV4] || + !tb[NL80211_WOWLAN_TCP_DST_MAC] || + !tb[NL80211_WOWLAN_TCP_DST_PORT] || + !tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD] || + !tb[NL80211_WOWLAN_TCP_DATA_INTERVAL] || + !tb[NL80211_WOWLAN_TCP_WAKE_PAYLOAD] || + !tb[NL80211_WOWLAN_TCP_WAKE_MASK]) + return -EINVAL; + + data_size = nla_len(tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD]); + if (data_size > rdev->wiphy.wowlan->tcp->data_payload_max) + return -EINVAL; + + if (nla_get_u32(tb[NL80211_WOWLAN_TCP_DATA_INTERVAL]) > + rdev->wiphy.wowlan->tcp->data_interval_max || + nla_get_u32(tb[NL80211_WOWLAN_TCP_DATA_INTERVAL]) == 0) + return -EINVAL; + + wake_size = nla_len(tb[NL80211_WOWLAN_TCP_WAKE_PAYLOAD]); + if (wake_size > rdev->wiphy.wowlan->tcp->wake_payload_max) + return -EINVAL; + + wake_mask_size = nla_len(tb[NL80211_WOWLAN_TCP_WAKE_MASK]); + if (wake_mask_size != DIV_ROUND_UP(wake_size, 8)) + return -EINVAL; + + if (tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN]) { + u32 tokln = nla_len(tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN]); + + tok = nla_data(tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD_TOKEN]); + tokens_size = tokln - sizeof(*tok); + + if (!tok->len || tokens_size % tok->len) + return -EINVAL; + if (!rdev->wiphy.wowlan->tcp->tok) + return -EINVAL; + if (tok->len > rdev->wiphy.wowlan->tcp->tok->max_len) + return -EINVAL; + if (tok->len < rdev->wiphy.wowlan->tcp->tok->min_len) + return -EINVAL; + if (tokens_size > rdev->wiphy.wowlan->tcp->tok->bufsize) + return -EINVAL; + if (tok->offset + tok->len > data_size) + return -EINVAL; + } + + if (tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ]) { + seq = nla_data(tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD_SEQ]); + if (!rdev->wiphy.wowlan->tcp->seq) + return -EINVAL; + if (seq->len == 
0 || seq->len > 4) + return -EINVAL; + if (seq->len + seq->offset > data_size) + return -EINVAL; + } + + size = sizeof(*cfg); + size += data_size; + size += wake_size + wake_mask_size; + size += tokens_size; + + cfg = kzalloc(size, GFP_KERNEL); + if (!cfg) + return -ENOMEM; + cfg->src = nla_get_in_addr(tb[NL80211_WOWLAN_TCP_SRC_IPV4]); + cfg->dst = nla_get_in_addr(tb[NL80211_WOWLAN_TCP_DST_IPV4]); + memcpy(cfg->dst_mac, nla_data(tb[NL80211_WOWLAN_TCP_DST_MAC]), + ETH_ALEN); + if (tb[NL80211_WOWLAN_TCP_SRC_PORT]) + port = nla_get_u16(tb[NL80211_WOWLAN_TCP_SRC_PORT]); + else + port = 0; +#ifdef CONFIG_INET + /* allocate a socket and port for it and use it */ + err = __sock_create(wiphy_net(&rdev->wiphy), PF_INET, SOCK_STREAM, + IPPROTO_TCP, &cfg->sock, 1); + if (err) { + kfree(cfg); + return err; + } + if (inet_csk_get_port(cfg->sock->sk, port)) { + sock_release(cfg->sock); + kfree(cfg); + return -EADDRINUSE; + } + cfg->src_port = inet_sk(cfg->sock->sk)->inet_num; +#else + if (!port) { + kfree(cfg); + return -EINVAL; + } + cfg->src_port = port; +#endif + + cfg->dst_port = nla_get_u16(tb[NL80211_WOWLAN_TCP_DST_PORT]); + cfg->payload_len = data_size; + cfg->payload = (u8 *)cfg + sizeof(*cfg) + tokens_size; + memcpy((void *)cfg->payload, + nla_data(tb[NL80211_WOWLAN_TCP_DATA_PAYLOAD]), + data_size); + if (seq) + cfg->payload_seq = *seq; + cfg->data_interval = nla_get_u32(tb[NL80211_WOWLAN_TCP_DATA_INTERVAL]); + cfg->wake_len = wake_size; + cfg->wake_data = (u8 *)cfg + sizeof(*cfg) + tokens_size + data_size; + memcpy((void *)cfg->wake_data, + nla_data(tb[NL80211_WOWLAN_TCP_WAKE_PAYLOAD]), + wake_size); + cfg->wake_mask = (u8 *)cfg + sizeof(*cfg) + tokens_size + + data_size + wake_size; + memcpy((void *)cfg->wake_mask, + nla_data(tb[NL80211_WOWLAN_TCP_WAKE_MASK]), + wake_mask_size); + if (tok) { + cfg->tokens_size = tokens_size; + cfg->payload_tok = *tok; + memcpy(cfg->payload_tok.token_stream, tok->token_stream, + tokens_size); + } + + trig->tcp = cfg; + + return 0; +} + +static int nl80211_parse_wowlan_nd(struct cfg80211_registered_device *rdev, + const struct wiphy_wowlan_support *wowlan, + struct nlattr *attr, + struct cfg80211_wowlan *trig) +{ + struct nlattr **tb; + int err; + + tb = kcalloc(NUM_NL80211_ATTR, sizeof(*tb), GFP_KERNEL); + if (!tb) + return -ENOMEM; + + if (!(wowlan->flags & WIPHY_WOWLAN_NET_DETECT)) { + err = -EOPNOTSUPP; + goto out; + } + + err = nla_parse_nested_deprecated(tb, NL80211_ATTR_MAX, attr, + nl80211_policy, NULL); + if (err) + goto out; + + trig->nd_config = nl80211_parse_sched_scan(&rdev->wiphy, NULL, tb, + wowlan->max_nd_match_sets); + err = PTR_ERR_OR_ZERO(trig->nd_config); + if (err) + trig->nd_config = NULL; + +out: + kfree(tb); + return err; +} + +static int nl80211_set_wowlan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct nlattr *tb[NUM_NL80211_WOWLAN_TRIG]; + struct cfg80211_wowlan new_triggers = {}; + struct cfg80211_wowlan *ntrig; + const struct wiphy_wowlan_support *wowlan = rdev->wiphy.wowlan; + int err, i; + bool prev_enabled = rdev->wiphy.wowlan_config; + bool regular = false; + + if (!wowlan) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_WOWLAN_TRIGGERS]) { + cfg80211_rdev_free_wowlan(rdev); + rdev->wiphy.wowlan_config = NULL; + goto set_wakeup; + } + + err = nla_parse_nested_deprecated(tb, MAX_NL80211_WOWLAN_TRIG, + info->attrs[NL80211_ATTR_WOWLAN_TRIGGERS], + nl80211_wowlan_policy, info->extack); + if (err) + return err; + + if (tb[NL80211_WOWLAN_TRIG_ANY]) { + if 
(!(wowlan->flags & WIPHY_WOWLAN_ANY)) + return -EINVAL; + new_triggers.any = true; + } + + if (tb[NL80211_WOWLAN_TRIG_DISCONNECT]) { + if (!(wowlan->flags & WIPHY_WOWLAN_DISCONNECT)) + return -EINVAL; + new_triggers.disconnect = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_MAGIC_PKT]) { + if (!(wowlan->flags & WIPHY_WOWLAN_MAGIC_PKT)) + return -EINVAL; + new_triggers.magic_pkt = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_GTK_REKEY_SUPPORTED]) + return -EINVAL; + + if (tb[NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE]) { + if (!(wowlan->flags & WIPHY_WOWLAN_GTK_REKEY_FAILURE)) + return -EINVAL; + new_triggers.gtk_rekey_failure = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST]) { + if (!(wowlan->flags & WIPHY_WOWLAN_EAP_IDENTITY_REQ)) + return -EINVAL; + new_triggers.eap_identity_req = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE]) { + if (!(wowlan->flags & WIPHY_WOWLAN_4WAY_HANDSHAKE)) + return -EINVAL; + new_triggers.four_way_handshake = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_RFKILL_RELEASE]) { + if (!(wowlan->flags & WIPHY_WOWLAN_RFKILL_RELEASE)) + return -EINVAL; + new_triggers.rfkill_release = true; + regular = true; + } + + if (tb[NL80211_WOWLAN_TRIG_PKT_PATTERN]) { + struct nlattr *pat; + int n_patterns = 0; + int rem, pat_len, mask_len, pkt_offset; + struct nlattr *pat_tb[NUM_NL80211_PKTPAT]; + + regular = true; + + nla_for_each_nested(pat, tb[NL80211_WOWLAN_TRIG_PKT_PATTERN], + rem) + n_patterns++; + if (n_patterns > wowlan->n_patterns) + return -EINVAL; + + new_triggers.patterns = kcalloc(n_patterns, + sizeof(new_triggers.patterns[0]), + GFP_KERNEL); + if (!new_triggers.patterns) + return -ENOMEM; + + new_triggers.n_patterns = n_patterns; + i = 0; + + nla_for_each_nested(pat, tb[NL80211_WOWLAN_TRIG_PKT_PATTERN], + rem) { + u8 *mask_pat; + + err = nla_parse_nested_deprecated(pat_tb, + MAX_NL80211_PKTPAT, + pat, + nl80211_packet_pattern_policy, + info->extack); + if (err) + goto error; + + err = -EINVAL; + if (!pat_tb[NL80211_PKTPAT_MASK] || + !pat_tb[NL80211_PKTPAT_PATTERN]) + goto error; + pat_len = nla_len(pat_tb[NL80211_PKTPAT_PATTERN]); + mask_len = DIV_ROUND_UP(pat_len, 8); + if (nla_len(pat_tb[NL80211_PKTPAT_MASK]) != mask_len) + goto error; + if (pat_len > wowlan->pattern_max_len || + pat_len < wowlan->pattern_min_len) + goto error; + + if (!pat_tb[NL80211_PKTPAT_OFFSET]) + pkt_offset = 0; + else + pkt_offset = nla_get_u32( + pat_tb[NL80211_PKTPAT_OFFSET]); + if (pkt_offset > wowlan->max_pkt_offset) + goto error; + new_triggers.patterns[i].pkt_offset = pkt_offset; + + mask_pat = kmalloc(mask_len + pat_len, GFP_KERNEL); + if (!mask_pat) { + err = -ENOMEM; + goto error; + } + new_triggers.patterns[i].mask = mask_pat; + memcpy(mask_pat, nla_data(pat_tb[NL80211_PKTPAT_MASK]), + mask_len); + mask_pat += mask_len; + new_triggers.patterns[i].pattern = mask_pat; + new_triggers.patterns[i].pattern_len = pat_len; + memcpy(mask_pat, + nla_data(pat_tb[NL80211_PKTPAT_PATTERN]), + pat_len); + i++; + } + } + + if (tb[NL80211_WOWLAN_TRIG_TCP_CONNECTION]) { + regular = true; + err = nl80211_parse_wowlan_tcp( + rdev, tb[NL80211_WOWLAN_TRIG_TCP_CONNECTION], + &new_triggers); + if (err) + goto error; + } + + if (tb[NL80211_WOWLAN_TRIG_NET_DETECT]) { + regular = true; + err = nl80211_parse_wowlan_nd( + rdev, wowlan, tb[NL80211_WOWLAN_TRIG_NET_DETECT], + &new_triggers); + if (err) + goto error; + } + + /* The 'any' trigger means the device continues operating more or less + * as in its normal 
operation mode and wakes up the host on most of the + * normal interrupts (like packet RX, ...) + * It therefore makes little sense to combine with the more constrained + * wakeup trigger modes. + */ + if (new_triggers.any && regular) { + err = -EINVAL; + goto error; + } + + ntrig = kmemdup(&new_triggers, sizeof(new_triggers), GFP_KERNEL); + if (!ntrig) { + err = -ENOMEM; + goto error; + } + cfg80211_rdev_free_wowlan(rdev); + rdev->wiphy.wowlan_config = ntrig; + + set_wakeup: + if (rdev->ops->set_wakeup && + prev_enabled != !!rdev->wiphy.wowlan_config) + rdev_set_wakeup(rdev, rdev->wiphy.wowlan_config); + + return 0; + error: + for (i = 0; i < new_triggers.n_patterns; i++) + kfree(new_triggers.patterns[i].mask); + kfree(new_triggers.patterns); + if (new_triggers.tcp && new_triggers.tcp->sock) + sock_release(new_triggers.tcp->sock); + kfree(new_triggers.tcp); + kfree(new_triggers.nd_config); + return err; +} +#endif + +static int nl80211_send_coalesce_rules(struct sk_buff *msg, + struct cfg80211_registered_device *rdev) +{ + struct nlattr *nl_pats, *nl_pat, *nl_rule, *nl_rules; + int i, j, pat_len; + struct cfg80211_coalesce_rules *rule; + + if (!rdev->coalesce->n_rules) + return 0; + + nl_rules = nla_nest_start_noflag(msg, NL80211_ATTR_COALESCE_RULE); + if (!nl_rules) + return -ENOBUFS; + + for (i = 0; i < rdev->coalesce->n_rules; i++) { + nl_rule = nla_nest_start_noflag(msg, i + 1); + if (!nl_rule) + return -ENOBUFS; + + rule = &rdev->coalesce->rules[i]; + if (nla_put_u32(msg, NL80211_ATTR_COALESCE_RULE_DELAY, + rule->delay)) + return -ENOBUFS; + + if (nla_put_u32(msg, NL80211_ATTR_COALESCE_RULE_CONDITION, + rule->condition)) + return -ENOBUFS; + + nl_pats = nla_nest_start_noflag(msg, + NL80211_ATTR_COALESCE_RULE_PKT_PATTERN); + if (!nl_pats) + return -ENOBUFS; + + for (j = 0; j < rule->n_patterns; j++) { + nl_pat = nla_nest_start_noflag(msg, j + 1); + if (!nl_pat) + return -ENOBUFS; + pat_len = rule->patterns[j].pattern_len; + if (nla_put(msg, NL80211_PKTPAT_MASK, + DIV_ROUND_UP(pat_len, 8), + rule->patterns[j].mask) || + nla_put(msg, NL80211_PKTPAT_PATTERN, pat_len, + rule->patterns[j].pattern) || + nla_put_u32(msg, NL80211_PKTPAT_OFFSET, + rule->patterns[j].pkt_offset)) + return -ENOBUFS; + nla_nest_end(msg, nl_pat); + } + nla_nest_end(msg, nl_pats); + nla_nest_end(msg, nl_rule); + } + nla_nest_end(msg, nl_rules); + + return 0; +} + +static int nl80211_get_coalesce(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct sk_buff *msg; + void *hdr; + + if (!rdev->wiphy.coalesce) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_COALESCE); + if (!hdr) + goto nla_put_failure; + + if (rdev->coalesce && nl80211_send_coalesce_rules(msg, rdev)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +void cfg80211_rdev_free_coalesce(struct cfg80211_registered_device *rdev) +{ + struct cfg80211_coalesce *coalesce = rdev->coalesce; + int i, j; + struct cfg80211_coalesce_rules *rule; + + if (!coalesce) + return; + + for (i = 0; i < coalesce->n_rules; i++) { + rule = &coalesce->rules[i]; + for (j = 0; j < rule->n_patterns; j++) + kfree(rule->patterns[j].mask); + kfree(rule->patterns); + } + kfree(coalesce->rules); + kfree(coalesce); + rdev->coalesce = NULL; +} + +static int 
nl80211_parse_coalesce_rule(struct cfg80211_registered_device *rdev, + struct nlattr *rule, + struct cfg80211_coalesce_rules *new_rule) +{ + int err, i; + const struct wiphy_coalesce_support *coalesce = rdev->wiphy.coalesce; + struct nlattr *tb[NUM_NL80211_ATTR_COALESCE_RULE], *pat; + int rem, pat_len, mask_len, pkt_offset, n_patterns = 0; + struct nlattr *pat_tb[NUM_NL80211_PKTPAT]; + + err = nla_parse_nested_deprecated(tb, NL80211_ATTR_COALESCE_RULE_MAX, + rule, nl80211_coalesce_policy, NULL); + if (err) + return err; + + if (tb[NL80211_ATTR_COALESCE_RULE_DELAY]) + new_rule->delay = + nla_get_u32(tb[NL80211_ATTR_COALESCE_RULE_DELAY]); + if (new_rule->delay > coalesce->max_delay) + return -EINVAL; + + if (tb[NL80211_ATTR_COALESCE_RULE_CONDITION]) + new_rule->condition = + nla_get_u32(tb[NL80211_ATTR_COALESCE_RULE_CONDITION]); + + if (!tb[NL80211_ATTR_COALESCE_RULE_PKT_PATTERN]) + return -EINVAL; + + nla_for_each_nested(pat, tb[NL80211_ATTR_COALESCE_RULE_PKT_PATTERN], + rem) + n_patterns++; + if (n_patterns > coalesce->n_patterns) + return -EINVAL; + + new_rule->patterns = kcalloc(n_patterns, sizeof(new_rule->patterns[0]), + GFP_KERNEL); + if (!new_rule->patterns) + return -ENOMEM; + + new_rule->n_patterns = n_patterns; + i = 0; + + nla_for_each_nested(pat, tb[NL80211_ATTR_COALESCE_RULE_PKT_PATTERN], + rem) { + u8 *mask_pat; + + err = nla_parse_nested_deprecated(pat_tb, MAX_NL80211_PKTPAT, + pat, + nl80211_packet_pattern_policy, + NULL); + if (err) + return err; + + if (!pat_tb[NL80211_PKTPAT_MASK] || + !pat_tb[NL80211_PKTPAT_PATTERN]) + return -EINVAL; + pat_len = nla_len(pat_tb[NL80211_PKTPAT_PATTERN]); + mask_len = DIV_ROUND_UP(pat_len, 8); + if (nla_len(pat_tb[NL80211_PKTPAT_MASK]) != mask_len) + return -EINVAL; + if (pat_len > coalesce->pattern_max_len || + pat_len < coalesce->pattern_min_len) + return -EINVAL; + + if (!pat_tb[NL80211_PKTPAT_OFFSET]) + pkt_offset = 0; + else + pkt_offset = nla_get_u32(pat_tb[NL80211_PKTPAT_OFFSET]); + if (pkt_offset > coalesce->max_pkt_offset) + return -EINVAL; + new_rule->patterns[i].pkt_offset = pkt_offset; + + mask_pat = kmalloc(mask_len + pat_len, GFP_KERNEL); + if (!mask_pat) + return -ENOMEM; + + new_rule->patterns[i].mask = mask_pat; + memcpy(mask_pat, nla_data(pat_tb[NL80211_PKTPAT_MASK]), + mask_len); + + mask_pat += mask_len; + new_rule->patterns[i].pattern = mask_pat; + new_rule->patterns[i].pattern_len = pat_len; + memcpy(mask_pat, nla_data(pat_tb[NL80211_PKTPAT_PATTERN]), + pat_len); + i++; + } + + return 0; +} + +static int nl80211_set_coalesce(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + const struct wiphy_coalesce_support *coalesce = rdev->wiphy.coalesce; + struct cfg80211_coalesce new_coalesce = {}; + struct cfg80211_coalesce *n_coalesce; + int err, rem_rule, n_rules = 0, i, j; + struct nlattr *rule; + struct cfg80211_coalesce_rules *tmp_rule; + + if (!rdev->wiphy.coalesce || !rdev->ops->set_coalesce) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_COALESCE_RULE]) { + cfg80211_rdev_free_coalesce(rdev); + rdev_set_coalesce(rdev, NULL); + return 0; + } + + nla_for_each_nested(rule, info->attrs[NL80211_ATTR_COALESCE_RULE], + rem_rule) + n_rules++; + if (n_rules > coalesce->n_rules) + return -EINVAL; + + new_coalesce.rules = kcalloc(n_rules, sizeof(new_coalesce.rules[0]), + GFP_KERNEL); + if (!new_coalesce.rules) + return -ENOMEM; + + new_coalesce.n_rules = n_rules; + i = 0; + + nla_for_each_nested(rule, info->attrs[NL80211_ATTR_COALESCE_RULE], + rem_rule) { + err 
= nl80211_parse_coalesce_rule(rdev, rule, + &new_coalesce.rules[i]); + if (err) + goto error; + + i++; + } + + err = rdev_set_coalesce(rdev, &new_coalesce); + if (err) + goto error; + + n_coalesce = kmemdup(&new_coalesce, sizeof(new_coalesce), GFP_KERNEL); + if (!n_coalesce) { + err = -ENOMEM; + goto error; + } + cfg80211_rdev_free_coalesce(rdev); + rdev->coalesce = n_coalesce; + + return 0; +error: + for (i = 0; i < new_coalesce.n_rules; i++) { + tmp_rule = &new_coalesce.rules[i]; + for (j = 0; j < tmp_rule->n_patterns; j++) + kfree(tmp_rule->patterns[j].mask); + kfree(tmp_rule->patterns); + } + kfree(new_coalesce.rules); + + return err; +} + +static int nl80211_set_rekey_data(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct nlattr *tb[NUM_NL80211_REKEY_DATA]; + struct cfg80211_gtk_rekey_data rekey_data = {}; + int err; + + if (!info->attrs[NL80211_ATTR_REKEY_DATA]) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, MAX_NL80211_REKEY_DATA, + info->attrs[NL80211_ATTR_REKEY_DATA], + nl80211_rekey_policy, info->extack); + if (err) + return err; + + if (!tb[NL80211_REKEY_DATA_REPLAY_CTR] || !tb[NL80211_REKEY_DATA_KEK] || + !tb[NL80211_REKEY_DATA_KCK]) + return -EINVAL; + if (nla_len(tb[NL80211_REKEY_DATA_KEK]) != NL80211_KEK_LEN && + !(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_EXT_KEK_KCK && + nla_len(tb[NL80211_REKEY_DATA_KEK]) == NL80211_KEK_EXT_LEN)) + return -ERANGE; + if (nla_len(tb[NL80211_REKEY_DATA_KCK]) != NL80211_KCK_LEN && + !(rdev->wiphy.flags & WIPHY_FLAG_SUPPORTS_EXT_KEK_KCK && + nla_len(tb[NL80211_REKEY_DATA_KCK]) == NL80211_KCK_EXT_LEN)) + return -ERANGE; + + rekey_data.kek = nla_data(tb[NL80211_REKEY_DATA_KEK]); + rekey_data.kck = nla_data(tb[NL80211_REKEY_DATA_KCK]); + rekey_data.replay_ctr = nla_data(tb[NL80211_REKEY_DATA_REPLAY_CTR]); + rekey_data.kek_len = nla_len(tb[NL80211_REKEY_DATA_KEK]); + rekey_data.kck_len = nla_len(tb[NL80211_REKEY_DATA_KCK]); + if (tb[NL80211_REKEY_DATA_AKM]) + rekey_data.akm = nla_get_u32(tb[NL80211_REKEY_DATA_AKM]); + + wdev_lock(wdev); + if (!wdev->connected) { + err = -ENOTCONN; + goto out; + } + + if (!rdev->ops->set_rekey_data) { + err = -EOPNOTSUPP; + goto out; + } + + err = rdev_set_rekey_data(rdev, dev, &rekey_data); + out: + wdev_unlock(wdev); + return err; +} + +static int nl80211_register_unexpected_frame(struct sk_buff *skb, + struct genl_info *info) +{ + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if (wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO) + return -EINVAL; + + if (wdev->ap_unexpected_nlportid) + return -EBUSY; + + wdev->ap_unexpected_nlportid = info->snd_portid; + return 0; +} + +static int nl80211_probe_client(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct sk_buff *msg; + void *hdr; + const u8 *addr; + u64 cookie; + int err; + + if (wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + if (!rdev->ops->probe_client) + return -EOPNOTSUPP; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + 
NL80211_CMD_PROBE_CLIENT); + if (!hdr) { + err = -ENOBUFS; + goto free_msg; + } + + addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + err = rdev_probe_client(rdev, dev, addr, &cookie); + if (err) + goto free_msg; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return genlmsg_reply(msg, info); + + nla_put_failure: + err = -ENOBUFS; + free_msg: + nlmsg_free(msg); + return err; +} + +static int nl80211_register_beacons(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct cfg80211_beacon_registration *reg, *nreg; + int rv; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_REPORTS_OBSS)) + return -EOPNOTSUPP; + + nreg = kzalloc(sizeof(*nreg), GFP_KERNEL); + if (!nreg) + return -ENOMEM; + + /* First, check if already registered. */ + spin_lock_bh(&rdev->beacon_registrations_lock); + list_for_each_entry(reg, &rdev->beacon_registrations, list) { + if (reg->nlportid == info->snd_portid) { + rv = -EALREADY; + goto out_err; + } + } + /* Add it to the list */ + nreg->nlportid = info->snd_portid; + list_add(&nreg->list, &rdev->beacon_registrations); + + spin_unlock_bh(&rdev->beacon_registrations_lock); + + return 0; +out_err: + spin_unlock_bh(&rdev->beacon_registrations_lock); + kfree(nreg); + return rv; +} + +static int nl80211_start_p2p_device(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + int err; + + if (!rdev->ops->start_p2p_device) + return -EOPNOTSUPP; + + if (wdev->iftype != NL80211_IFTYPE_P2P_DEVICE) + return -EOPNOTSUPP; + + if (wdev_running(wdev)) + return 0; + + if (rfkill_blocked(rdev->wiphy.rfkill)) + return -ERFKILL; + + err = rdev_start_p2p_device(rdev, wdev); + if (err) + return err; + + wdev->is_running = true; + rdev->opencount++; + + return 0; +} + +static int nl80211_stop_p2p_device(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + if (wdev->iftype != NL80211_IFTYPE_P2P_DEVICE) + return -EOPNOTSUPP; + + if (!rdev->ops->stop_p2p_device) + return -EOPNOTSUPP; + + cfg80211_stop_p2p_device(rdev, wdev); + + return 0; +} + +static int nl80211_start_nan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_nan_conf conf = {}; + int err; + + if (wdev->iftype != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (wdev_running(wdev)) + return -EEXIST; + + if (rfkill_blocked(rdev->wiphy.rfkill)) + return -ERFKILL; + + if (!info->attrs[NL80211_ATTR_NAN_MASTER_PREF]) + return -EINVAL; + + conf.master_pref = + nla_get_u8(info->attrs[NL80211_ATTR_NAN_MASTER_PREF]); + + if (info->attrs[NL80211_ATTR_BANDS]) { + u32 bands = nla_get_u32(info->attrs[NL80211_ATTR_BANDS]); + + if (bands & ~(u32)wdev->wiphy->nan_supported_bands) + return -EOPNOTSUPP; + + if (bands && !(bands & BIT(NL80211_BAND_2GHZ))) + return -EINVAL; + + conf.bands = bands; + } + + err = rdev_start_nan(rdev, wdev, &conf); + if (err) + return err; + + wdev->is_running = true; + rdev->opencount++; + + return 0; +} + +static int nl80211_stop_nan(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + if (wdev->iftype != NL80211_IFTYPE_NAN) + 
return -EOPNOTSUPP; + + cfg80211_stop_nan(rdev, wdev); + + return 0; +} + +static int validate_nan_filter(struct nlattr *filter_attr) +{ + struct nlattr *attr; + int len = 0, n_entries = 0, rem; + + nla_for_each_nested(attr, filter_attr, rem) { + len += nla_len(attr); + n_entries++; + } + + if (len >= U8_MAX) + return -EINVAL; + + return n_entries; +} + +static int handle_nan_filter(struct nlattr *attr_filter, + struct cfg80211_nan_func *func, + bool tx) +{ + struct nlattr *attr; + int n_entries, rem, i; + struct cfg80211_nan_func_filter *filter; + + n_entries = validate_nan_filter(attr_filter); + if (n_entries < 0) + return n_entries; + + BUILD_BUG_ON(sizeof(*func->rx_filters) != sizeof(*func->tx_filters)); + + filter = kcalloc(n_entries, sizeof(*func->rx_filters), GFP_KERNEL); + if (!filter) + return -ENOMEM; + + i = 0; + nla_for_each_nested(attr, attr_filter, rem) { + filter[i].filter = nla_memdup(attr, GFP_KERNEL); + if (!filter[i].filter) + goto err; + + filter[i].len = nla_len(attr); + i++; + } + if (tx) { + func->num_tx_filters = n_entries; + func->tx_filters = filter; + } else { + func->num_rx_filters = n_entries; + func->rx_filters = filter; + } + + return 0; + +err: + i = 0; + nla_for_each_nested(attr, attr_filter, rem) { + kfree(filter[i].filter); + i++; + } + kfree(filter); + return -ENOMEM; +} + +static int nl80211_nan_add_func(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct nlattr *tb[NUM_NL80211_NAN_FUNC_ATTR], *func_attr; + struct cfg80211_nan_func *func; + struct sk_buff *msg = NULL; + void *hdr = NULL; + int err = 0; + + if (wdev->iftype != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!wdev_running(wdev)) + return -ENOTCONN; + + if (!info->attrs[NL80211_ATTR_NAN_FUNC]) + return -EINVAL; + + err = nla_parse_nested_deprecated(tb, NL80211_NAN_FUNC_ATTR_MAX, + info->attrs[NL80211_ATTR_NAN_FUNC], + nl80211_nan_func_policy, + info->extack); + if (err) + return err; + + func = kzalloc(sizeof(*func), GFP_KERNEL); + if (!func) + return -ENOMEM; + + func->cookie = cfg80211_assign_cookie(rdev); + + if (!tb[NL80211_NAN_FUNC_TYPE]) { + err = -EINVAL; + goto out; + } + + + func->type = nla_get_u8(tb[NL80211_NAN_FUNC_TYPE]); + + if (!tb[NL80211_NAN_FUNC_SERVICE_ID]) { + err = -EINVAL; + goto out; + } + + memcpy(func->service_id, nla_data(tb[NL80211_NAN_FUNC_SERVICE_ID]), + sizeof(func->service_id)); + + func->close_range = + nla_get_flag(tb[NL80211_NAN_FUNC_CLOSE_RANGE]); + + if (tb[NL80211_NAN_FUNC_SERVICE_INFO]) { + func->serv_spec_info_len = + nla_len(tb[NL80211_NAN_FUNC_SERVICE_INFO]); + func->serv_spec_info = + kmemdup(nla_data(tb[NL80211_NAN_FUNC_SERVICE_INFO]), + func->serv_spec_info_len, + GFP_KERNEL); + if (!func->serv_spec_info) { + err = -ENOMEM; + goto out; + } + } + + if (tb[NL80211_NAN_FUNC_TTL]) + func->ttl = nla_get_u32(tb[NL80211_NAN_FUNC_TTL]); + + switch (func->type) { + case NL80211_NAN_FUNC_PUBLISH: + if (!tb[NL80211_NAN_FUNC_PUBLISH_TYPE]) { + err = -EINVAL; + goto out; + } + + func->publish_type = + nla_get_u8(tb[NL80211_NAN_FUNC_PUBLISH_TYPE]); + func->publish_bcast = + nla_get_flag(tb[NL80211_NAN_FUNC_PUBLISH_BCAST]); + + if ((!(func->publish_type & NL80211_NAN_SOLICITED_PUBLISH)) && + func->publish_bcast) { + err = -EINVAL; + goto out; + } + break; + case NL80211_NAN_FUNC_SUBSCRIBE: + func->subscribe_active = + nla_get_flag(tb[NL80211_NAN_FUNC_SUBSCRIBE_ACTIVE]); + break; + case NL80211_NAN_FUNC_FOLLOW_UP: + if 
(!tb[NL80211_NAN_FUNC_FOLLOW_UP_ID] || + !tb[NL80211_NAN_FUNC_FOLLOW_UP_REQ_ID] || + !tb[NL80211_NAN_FUNC_FOLLOW_UP_DEST]) { + err = -EINVAL; + goto out; + } + + func->followup_id = + nla_get_u8(tb[NL80211_NAN_FUNC_FOLLOW_UP_ID]); + func->followup_reqid = + nla_get_u8(tb[NL80211_NAN_FUNC_FOLLOW_UP_REQ_ID]); + memcpy(func->followup_dest.addr, + nla_data(tb[NL80211_NAN_FUNC_FOLLOW_UP_DEST]), + sizeof(func->followup_dest.addr)); + if (func->ttl) { + err = -EINVAL; + goto out; + } + break; + default: + err = -EINVAL; + goto out; + } + + if (tb[NL80211_NAN_FUNC_SRF]) { + struct nlattr *srf_tb[NUM_NL80211_NAN_SRF_ATTR]; + + err = nla_parse_nested_deprecated(srf_tb, + NL80211_NAN_SRF_ATTR_MAX, + tb[NL80211_NAN_FUNC_SRF], + nl80211_nan_srf_policy, + info->extack); + if (err) + goto out; + + func->srf_include = + nla_get_flag(srf_tb[NL80211_NAN_SRF_INCLUDE]); + + if (srf_tb[NL80211_NAN_SRF_BF]) { + if (srf_tb[NL80211_NAN_SRF_MAC_ADDRS] || + !srf_tb[NL80211_NAN_SRF_BF_IDX]) { + err = -EINVAL; + goto out; + } + + func->srf_bf_len = + nla_len(srf_tb[NL80211_NAN_SRF_BF]); + func->srf_bf = + kmemdup(nla_data(srf_tb[NL80211_NAN_SRF_BF]), + func->srf_bf_len, GFP_KERNEL); + if (!func->srf_bf) { + err = -ENOMEM; + goto out; + } + + func->srf_bf_idx = + nla_get_u8(srf_tb[NL80211_NAN_SRF_BF_IDX]); + } else { + struct nlattr *attr, *mac_attr = + srf_tb[NL80211_NAN_SRF_MAC_ADDRS]; + int n_entries, rem, i = 0; + + if (!mac_attr) { + err = -EINVAL; + goto out; + } + + n_entries = validate_acl_mac_addrs(mac_attr); + if (n_entries <= 0) { + err = -EINVAL; + goto out; + } + + func->srf_num_macs = n_entries; + func->srf_macs = + kcalloc(n_entries, sizeof(*func->srf_macs), + GFP_KERNEL); + if (!func->srf_macs) { + err = -ENOMEM; + goto out; + } + + nla_for_each_nested(attr, mac_attr, rem) + memcpy(func->srf_macs[i++].addr, nla_data(attr), + sizeof(*func->srf_macs)); + } + } + + if (tb[NL80211_NAN_FUNC_TX_MATCH_FILTER]) { + err = handle_nan_filter(tb[NL80211_NAN_FUNC_TX_MATCH_FILTER], + func, true); + if (err) + goto out; + } + + if (tb[NL80211_NAN_FUNC_RX_MATCH_FILTER]) { + err = handle_nan_filter(tb[NL80211_NAN_FUNC_RX_MATCH_FILTER], + func, false); + if (err) + goto out; + } + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) { + err = -ENOMEM; + goto out; + } + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_ADD_NAN_FUNCTION); + /* This can't really happen - we just allocated 4KB */ + if (WARN_ON(!hdr)) { + err = -ENOMEM; + goto out; + } + + err = rdev_add_nan_func(rdev, wdev, func); +out: + if (err < 0) { + cfg80211_free_nan_func(func); + nlmsg_free(msg); + return err; + } + + /* propagate the instance id and cookie to userspace */ + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, func->cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + func_attr = nla_nest_start_noflag(msg, NL80211_ATTR_NAN_FUNC); + if (!func_attr) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_NAN_FUNC_INSTANCE_ID, + func->instance_id)) + goto nla_put_failure; + + nla_nest_end(msg, func_attr); + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +static int nl80211_nan_del_func(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + u64 cookie; + + if (wdev->iftype != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!wdev_running(wdev)) + return -ENOTCONN; + + if (!info->attrs[NL80211_ATTR_COOKIE]) + return 
-EINVAL; + + cookie = nla_get_u64(info->attrs[NL80211_ATTR_COOKIE]); + + rdev_del_nan_func(rdev, wdev, cookie); + + return 0; +} + +static int nl80211_nan_change_config(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_nan_conf conf = {}; + u32 changed = 0; + + if (wdev->iftype != NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!wdev_running(wdev)) + return -ENOTCONN; + + if (info->attrs[NL80211_ATTR_NAN_MASTER_PREF]) { + conf.master_pref = + nla_get_u8(info->attrs[NL80211_ATTR_NAN_MASTER_PREF]); + if (conf.master_pref <= 1 || conf.master_pref == 255) + return -EINVAL; + + changed |= CFG80211_NAN_CONF_CHANGED_PREF; + } + + if (info->attrs[NL80211_ATTR_BANDS]) { + u32 bands = nla_get_u32(info->attrs[NL80211_ATTR_BANDS]); + + if (bands & ~(u32)wdev->wiphy->nan_supported_bands) + return -EOPNOTSUPP; + + if (bands && !(bands & BIT(NL80211_BAND_2GHZ))) + return -EINVAL; + + conf.bands = bands; + changed |= CFG80211_NAN_CONF_CHANGED_BANDS; + } + + if (!changed) + return -EINVAL; + + return rdev_nan_change_conf(rdev, wdev, &conf, changed); +} + +void cfg80211_nan_match(struct wireless_dev *wdev, + struct cfg80211_nan_match_params *match, gfp_t gfp) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct nlattr *match_attr, *local_func_attr, *peer_func_attr; + struct sk_buff *msg; + void *hdr; + + if (WARN_ON(!match->inst_id || !match->peer_inst_id || !match->addr)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_NAN_MATCH); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (wdev->netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, match->cookie, + NL80211_ATTR_PAD) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, match->addr)) + goto nla_put_failure; + + match_attr = nla_nest_start_noflag(msg, NL80211_ATTR_NAN_MATCH); + if (!match_attr) + goto nla_put_failure; + + local_func_attr = nla_nest_start_noflag(msg, + NL80211_NAN_MATCH_FUNC_LOCAL); + if (!local_func_attr) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_NAN_FUNC_INSTANCE_ID, match->inst_id)) + goto nla_put_failure; + + nla_nest_end(msg, local_func_attr); + + peer_func_attr = nla_nest_start_noflag(msg, + NL80211_NAN_MATCH_FUNC_PEER); + if (!peer_func_attr) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_NAN_FUNC_TYPE, match->type) || + nla_put_u8(msg, NL80211_NAN_FUNC_INSTANCE_ID, match->peer_inst_id)) + goto nla_put_failure; + + if (match->info && match->info_len && + nla_put(msg, NL80211_NAN_FUNC_SERVICE_INFO, match->info_len, + match->info)) + goto nla_put_failure; + + nla_nest_end(msg, peer_func_attr); + nla_nest_end(msg, match_attr); + genlmsg_end(msg, hdr); + + if (!wdev->owner_nlportid) + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), + msg, 0, NL80211_MCGRP_NAN, gfp); + else + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, + wdev->owner_nlportid); + + return; + +nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_nan_match); + +void cfg80211_nan_func_terminated(struct wireless_dev *wdev, + u8 inst_id, + enum nl80211_nan_func_term_reason reason, + u64 cookie, gfp_t gfp) +{ + struct 
wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + struct nlattr *func_attr; + void *hdr; + + if (WARN_ON(!inst_id)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_DEL_NAN_FUNCTION); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (wdev->netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + func_attr = nla_nest_start_noflag(msg, NL80211_ATTR_NAN_FUNC); + if (!func_attr) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_NAN_FUNC_INSTANCE_ID, inst_id) || + nla_put_u8(msg, NL80211_NAN_FUNC_TERM_REASON, reason)) + goto nla_put_failure; + + nla_nest_end(msg, func_attr); + genlmsg_end(msg, hdr); + + if (!wdev->owner_nlportid) + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), + msg, 0, NL80211_MCGRP_NAN, gfp); + else + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, + wdev->owner_nlportid); + + return; + +nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_nan_func_terminated); + +static int nl80211_get_protocol_features(struct sk_buff *skb, + struct genl_info *info) +{ + void *hdr; + struct sk_buff *msg; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_PROTOCOL_FEATURES); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_PROTOCOL_FEATURES, + NL80211_PROTOCOL_FEATURE_SPLIT_WIPHY_DUMP)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + + nla_put_failure: + kfree_skb(msg); + return -ENOBUFS; +} + +static int nl80211_update_ft_ies(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct cfg80211_update_ft_ies_params ft_params; + struct net_device *dev = info->user_ptr[1]; + + if (!rdev->ops->update_ft_ies) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_MDID] || + !info->attrs[NL80211_ATTR_IE]) + return -EINVAL; + + memset(&ft_params, 0, sizeof(ft_params)); + ft_params.md = nla_get_u16(info->attrs[NL80211_ATTR_MDID]); + ft_params.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + ft_params.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + + return rdev_update_ft_ies(rdev, dev, &ft_params); +} + +static int nl80211_crit_protocol_start(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + enum nl80211_crit_proto_id proto = NL80211_CRIT_PROTO_UNSPEC; + u16 duration; + int ret; + + if (!rdev->ops->crit_proto_start) + return -EOPNOTSUPP; + + if (WARN_ON(!rdev->ops->crit_proto_stop)) + return -EINVAL; + + if (rdev->crit_proto_nlportid) + return -EBUSY; + + /* determine protocol if provided */ + if (info->attrs[NL80211_ATTR_CRIT_PROT_ID]) + proto = nla_get_u16(info->attrs[NL80211_ATTR_CRIT_PROT_ID]); + + if (proto >= NUM_NL80211_CRIT_PROTO) + return -EINVAL; + + /* timeout must be provided */ + if (!info->attrs[NL80211_ATTR_MAX_CRIT_PROT_DURATION]) + return -EINVAL; + + duration = + nla_get_u16(info->attrs[NL80211_ATTR_MAX_CRIT_PROT_DURATION]); + + ret = rdev_crit_proto_start(rdev, 
wdev, proto, duration); + if (!ret) + rdev->crit_proto_nlportid = info->snd_portid; + + return ret; +} + +static int nl80211_crit_protocol_stop(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + + if (!rdev->ops->crit_proto_stop) + return -EOPNOTSUPP; + + if (rdev->crit_proto_nlportid) { + rdev->crit_proto_nlportid = 0; + rdev_crit_proto_stop(rdev, wdev); + } + return 0; +} + +static int nl80211_vendor_check_policy(const struct wiphy_vendor_command *vcmd, + struct nlattr *attr, + struct netlink_ext_ack *extack) +{ + if (vcmd->policy == VENDOR_CMD_RAW_DATA) { + if (attr->nla_type & NLA_F_NESTED) { + NL_SET_ERR_MSG_ATTR(extack, attr, + "unexpected nested data"); + return -EINVAL; + } + + return 0; + } + + if (!(attr->nla_type & NLA_F_NESTED)) { + NL_SET_ERR_MSG_ATTR(extack, attr, "expected nested data"); + return -EINVAL; + } + + return nla_validate_nested(attr, vcmd->maxattr, vcmd->policy, extack); +} + +static int nl80211_vendor_cmd(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = + __cfg80211_wdev_from_attrs(rdev, genl_info_net(info), + info->attrs); + int i, err; + u32 vid, subcmd; + + if (!rdev->wiphy.vendor_commands) + return -EOPNOTSUPP; + + if (IS_ERR(wdev)) { + err = PTR_ERR(wdev); + if (err != -EINVAL) + return err; + wdev = NULL; + } else if (wdev->wiphy != &rdev->wiphy) { + return -EINVAL; + } + + if (!info->attrs[NL80211_ATTR_VENDOR_ID] || + !info->attrs[NL80211_ATTR_VENDOR_SUBCMD]) + return -EINVAL; + + vid = nla_get_u32(info->attrs[NL80211_ATTR_VENDOR_ID]); + subcmd = nla_get_u32(info->attrs[NL80211_ATTR_VENDOR_SUBCMD]); + for (i = 0; i < rdev->wiphy.n_vendor_commands; i++) { + const struct wiphy_vendor_command *vcmd; + void *data = NULL; + int len = 0; + + vcmd = &rdev->wiphy.vendor_commands[i]; + + if (vcmd->info.vendor_id != vid || vcmd->info.subcmd != subcmd) + continue; + + if (vcmd->flags & (WIPHY_VENDOR_CMD_NEED_WDEV | + WIPHY_VENDOR_CMD_NEED_NETDEV)) { + if (!wdev) + return -EINVAL; + if (vcmd->flags & WIPHY_VENDOR_CMD_NEED_NETDEV && + !wdev->netdev) + return -EINVAL; + + if (vcmd->flags & WIPHY_VENDOR_CMD_NEED_RUNNING) { + if (!wdev_running(wdev)) + return -ENETDOWN; + } + } else { + wdev = NULL; + } + + if (!vcmd->doit) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_VENDOR_DATA]) { + data = nla_data(info->attrs[NL80211_ATTR_VENDOR_DATA]); + len = nla_len(info->attrs[NL80211_ATTR_VENDOR_DATA]); + + err = nl80211_vendor_check_policy(vcmd, + info->attrs[NL80211_ATTR_VENDOR_DATA], + info->extack); + if (err) + return err; + } + + rdev->cur_cmd_info = info; + err = vcmd->doit(&rdev->wiphy, wdev, data, len); + rdev->cur_cmd_info = NULL; + return err; + } + + return -EOPNOTSUPP; +} + +static int nl80211_prepare_vendor_dump(struct sk_buff *skb, + struct netlink_callback *cb, + struct cfg80211_registered_device **rdev, + struct wireless_dev **wdev) +{ + struct nlattr **attrbuf; + u32 vid, subcmd; + unsigned int i; + int vcmd_idx = -1; + int err; + void *data = NULL; + unsigned int data_len = 0; + + if (cb->args[0]) { + /* subtract the 1 again here */ + struct wiphy *wiphy = wiphy_idx_to_wiphy(cb->args[0] - 1); + struct wireless_dev *tmp; + + if (!wiphy) + return -ENODEV; + *rdev = wiphy_to_rdev(wiphy); + *wdev = NULL; + + if (cb->args[1]) { + list_for_each_entry(tmp, &wiphy->wdev_list, list) { + if (tmp->identifier == cb->args[1] - 1) { + *wdev = tmp; + break; + } + 
} + } + + /* keep rtnl locked in successful case */ + return 0; + } + + attrbuf = kcalloc(NUM_NL80211_ATTR, sizeof(*attrbuf), GFP_KERNEL); + if (!attrbuf) + return -ENOMEM; + + err = nlmsg_parse_deprecated(cb->nlh, + GENL_HDRLEN + nl80211_fam.hdrsize, + attrbuf, nl80211_fam.maxattr, + nl80211_policy, NULL); + if (err) + goto out; + + if (!attrbuf[NL80211_ATTR_VENDOR_ID] || + !attrbuf[NL80211_ATTR_VENDOR_SUBCMD]) { + err = -EINVAL; + goto out; + } + + *wdev = __cfg80211_wdev_from_attrs(NULL, sock_net(skb->sk), attrbuf); + if (IS_ERR(*wdev)) + *wdev = NULL; + + *rdev = __cfg80211_rdev_from_attrs(sock_net(skb->sk), attrbuf); + if (IS_ERR(*rdev)) { + err = PTR_ERR(*rdev); + goto out; + } + + vid = nla_get_u32(attrbuf[NL80211_ATTR_VENDOR_ID]); + subcmd = nla_get_u32(attrbuf[NL80211_ATTR_VENDOR_SUBCMD]); + + for (i = 0; i < (*rdev)->wiphy.n_vendor_commands; i++) { + const struct wiphy_vendor_command *vcmd; + + vcmd = &(*rdev)->wiphy.vendor_commands[i]; + + if (vcmd->info.vendor_id != vid || vcmd->info.subcmd != subcmd) + continue; + + if (!vcmd->dumpit) { + err = -EOPNOTSUPP; + goto out; + } + + vcmd_idx = i; + break; + } + + if (vcmd_idx < 0) { + err = -EOPNOTSUPP; + goto out; + } + + if (attrbuf[NL80211_ATTR_VENDOR_DATA]) { + data = nla_data(attrbuf[NL80211_ATTR_VENDOR_DATA]); + data_len = nla_len(attrbuf[NL80211_ATTR_VENDOR_DATA]); + + err = nl80211_vendor_check_policy( + &(*rdev)->wiphy.vendor_commands[vcmd_idx], + attrbuf[NL80211_ATTR_VENDOR_DATA], + cb->extack); + if (err) + goto out; + } + + /* 0 is the first index - add 1 to parse only once */ + cb->args[0] = (*rdev)->wiphy_idx + 1; + /* add 1 to know if it was NULL */ + cb->args[1] = *wdev ? (*wdev)->identifier + 1 : 0; + cb->args[2] = vcmd_idx; + cb->args[3] = (unsigned long)data; + cb->args[4] = data_len; + + /* keep rtnl locked in successful case */ + err = 0; +out: + kfree(attrbuf); + return err; +} + +static int nl80211_vendor_cmd_dump(struct sk_buff *skb, + struct netlink_callback *cb) +{ + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + unsigned int vcmd_idx; + const struct wiphy_vendor_command *vcmd; + void *data; + int data_len; + int err; + struct nlattr *vendor_data; + + rtnl_lock(); + err = nl80211_prepare_vendor_dump(skb, cb, &rdev, &wdev); + if (err) + goto out; + + vcmd_idx = cb->args[2]; + data = (void *)cb->args[3]; + data_len = cb->args[4]; + vcmd = &rdev->wiphy.vendor_commands[vcmd_idx]; + + if (vcmd->flags & (WIPHY_VENDOR_CMD_NEED_WDEV | + WIPHY_VENDOR_CMD_NEED_NETDEV)) { + if (!wdev) { + err = -EINVAL; + goto out; + } + if (vcmd->flags & WIPHY_VENDOR_CMD_NEED_NETDEV && + !wdev->netdev) { + err = -EINVAL; + goto out; + } + + if (vcmd->flags & WIPHY_VENDOR_CMD_NEED_RUNNING) { + if (!wdev_running(wdev)) { + err = -ENETDOWN; + goto out; + } + } + } + + while (1) { + void *hdr = nl80211hdr_put(skb, NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + NL80211_CMD_VENDOR); + if (!hdr) + break; + + if (nla_put_u32(skb, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (wdev && nla_put_u64_64bit(skb, NL80211_ATTR_WDEV, + wdev_id(wdev), + NL80211_ATTR_PAD))) { + genlmsg_cancel(skb, hdr); + break; + } + + vendor_data = nla_nest_start_noflag(skb, + NL80211_ATTR_VENDOR_DATA); + if (!vendor_data) { + genlmsg_cancel(skb, hdr); + break; + } + + err = vcmd->dumpit(&rdev->wiphy, wdev, skb, data, data_len, + (unsigned long *)&cb->args[5]); + nla_nest_end(skb, vendor_data); + + if (err == -ENOBUFS || err == -ENOENT) { + genlmsg_cancel(skb, hdr); + break; + } else if (err <= 0) { + genlmsg_cancel(skb, hdr); + 
goto out; + } + + genlmsg_end(skb, hdr); + } + + err = skb->len; + out: + rtnl_unlock(); + return err; +} + +struct sk_buff *__cfg80211_alloc_reply_skb(struct wiphy *wiphy, + enum nl80211_commands cmd, + enum nl80211_attrs attr, + int approxlen) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + if (WARN_ON(!rdev->cur_cmd_info)) + return NULL; + + return __cfg80211_alloc_vendor_skb(rdev, NULL, approxlen, + rdev->cur_cmd_info->snd_portid, + rdev->cur_cmd_info->snd_seq, + cmd, attr, NULL, GFP_KERNEL); +} +EXPORT_SYMBOL(__cfg80211_alloc_reply_skb); + +int cfg80211_vendor_cmd_reply(struct sk_buff *skb) +{ + struct cfg80211_registered_device *rdev = ((void **)skb->cb)[0]; + void *hdr = ((void **)skb->cb)[1]; + struct nlattr *data = ((void **)skb->cb)[2]; + + /* clear CB data for netlink core to own from now on */ + memset(skb->cb, 0, sizeof(skb->cb)); + + if (WARN_ON(!rdev->cur_cmd_info)) { + kfree_skb(skb); + return -EINVAL; + } + + nla_nest_end(skb, data); + genlmsg_end(skb, hdr); + return genlmsg_reply(skb, rdev->cur_cmd_info); +} +EXPORT_SYMBOL_GPL(cfg80211_vendor_cmd_reply); + +unsigned int cfg80211_vendor_cmd_get_sender(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + if (WARN_ON(!rdev->cur_cmd_info)) + return 0; + + return rdev->cur_cmd_info->snd_portid; +} +EXPORT_SYMBOL_GPL(cfg80211_vendor_cmd_get_sender); + +static int nl80211_set_qos_map(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct cfg80211_qos_map *qos_map = NULL; + struct net_device *dev = info->user_ptr[1]; + u8 *pos, len, num_des, des_len, des; + int ret; + + if (!rdev->ops->set_qos_map) + return -EOPNOTSUPP; + + if (info->attrs[NL80211_ATTR_QOS_MAP]) { + pos = nla_data(info->attrs[NL80211_ATTR_QOS_MAP]); + len = nla_len(info->attrs[NL80211_ATTR_QOS_MAP]); + + if (len % 2) + return -EINVAL; + + qos_map = kzalloc(sizeof(struct cfg80211_qos_map), GFP_KERNEL); + if (!qos_map) + return -ENOMEM; + + num_des = (len - IEEE80211_QOS_MAP_LEN_MIN) >> 1; + if (num_des) { + des_len = num_des * + sizeof(struct cfg80211_dscp_exception); + memcpy(qos_map->dscp_exception, pos, des_len); + qos_map->num_des = num_des; + for (des = 0; des < num_des; des++) { + if (qos_map->dscp_exception[des].up > 7) { + kfree(qos_map); + return -EINVAL; + } + } + pos += des_len; + } + memcpy(qos_map->up, pos, IEEE80211_QOS_MAP_LEN_MIN); + } + + wdev_lock(dev->ieee80211_ptr); + ret = nl80211_key_allowed(dev->ieee80211_ptr); + if (!ret) + ret = rdev_set_qos_map(rdev, dev, qos_map); + wdev_unlock(dev->ieee80211_ptr); + + kfree(qos_map); + return ret; +} + +static int nl80211_add_tx_ts(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + const u8 *peer; + u8 tsid, up; + u16 admitted_time = 0; + int err; + + if (!(rdev->wiphy.features & NL80211_FEATURE_SUPPORTS_WMM_ADMISSION)) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_TSID] || !info->attrs[NL80211_ATTR_MAC] || + !info->attrs[NL80211_ATTR_USER_PRIO]) + return -EINVAL; + + tsid = nla_get_u8(info->attrs[NL80211_ATTR_TSID]); + up = nla_get_u8(info->attrs[NL80211_ATTR_USER_PRIO]); + + /* WMM uses TIDs 0-7 even for TSPEC */ + if (tsid >= IEEE80211_FIRST_TSPEC_TSID) { + /* TODO: handle 802.11 TSPEC/admission control + * need more attributes for that (e.g. 
BA session requirement); + * change the WMM adminssion test above to allow both then + */ + return -EINVAL; + } + + peer = nla_data(info->attrs[NL80211_ATTR_MAC]); + + if (info->attrs[NL80211_ATTR_ADMITTED_TIME]) { + admitted_time = + nla_get_u16(info->attrs[NL80211_ATTR_ADMITTED_TIME]); + if (!admitted_time) + return -EINVAL; + } + + wdev_lock(wdev); + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + if (wdev->connected) + break; + err = -ENOTCONN; + goto out; + default: + err = -EOPNOTSUPP; + goto out; + } + + err = rdev_add_tx_ts(rdev, dev, tsid, peer, up, admitted_time); + + out: + wdev_unlock(wdev); + return err; +} + +static int nl80211_del_tx_ts(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + const u8 *peer; + u8 tsid; + int err; + + if (!info->attrs[NL80211_ATTR_TSID] || !info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + tsid = nla_get_u8(info->attrs[NL80211_ATTR_TSID]); + peer = nla_data(info->attrs[NL80211_ATTR_MAC]); + + wdev_lock(wdev); + err = rdev_del_tx_ts(rdev, dev, tsid, peer); + wdev_unlock(wdev); + + return err; +} + +static int nl80211_tdls_channel_switch(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_chan_def chandef = {}; + const u8 *addr; + u8 oper_class; + int err; + + if (!rdev->ops->tdls_channel_switch || + !(rdev->wiphy.features & NL80211_FEATURE_TDLS_CHANNEL_SWITCH)) + return -EOPNOTSUPP; + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + break; + default: + return -EOPNOTSUPP; + } + + if (!info->attrs[NL80211_ATTR_MAC] || + !info->attrs[NL80211_ATTR_OPER_CLASS]) + return -EINVAL; + + err = nl80211_parse_chandef(rdev, info, &chandef); + if (err) + return err; + + /* + * Don't allow wide channels on the 2.4Ghz band, as per IEEE802.11-2012 + * section 10.22.6.2.1. Disallow 5/10Mhz channels as well for now, the + * specification is not defined for them. 
+ */ + if (chandef.chan->band == NL80211_BAND_2GHZ && + chandef.width != NL80211_CHAN_WIDTH_20_NOHT && + chandef.width != NL80211_CHAN_WIDTH_20) + return -EINVAL; + + /* we will be active on the TDLS link */ + if (!cfg80211_reg_can_beacon_relax(&rdev->wiphy, &chandef, + wdev->iftype)) + return -EINVAL; + + /* don't allow switching to DFS channels */ + if (cfg80211_chandef_dfs_required(wdev->wiphy, &chandef, wdev->iftype)) + return -EINVAL; + + addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + oper_class = nla_get_u8(info->attrs[NL80211_ATTR_OPER_CLASS]); + + wdev_lock(wdev); + err = rdev_tdls_channel_switch(rdev, dev, addr, oper_class, &chandef); + wdev_unlock(wdev); + + return err; +} + +static int nl80211_tdls_cancel_channel_switch(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + const u8 *addr; + + if (!rdev->ops->tdls_channel_switch || + !rdev->ops->tdls_cancel_channel_switch || + !(rdev->wiphy.features & NL80211_FEATURE_TDLS_CHANNEL_SWITCH)) + return -EOPNOTSUPP; + + switch (dev->ieee80211_ptr->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + break; + default: + return -EOPNOTSUPP; + } + + if (!info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + addr = nla_data(info->attrs[NL80211_ATTR_MAC]); + + wdev_lock(wdev); + rdev_tdls_cancel_channel_switch(rdev, dev, addr); + wdev_unlock(wdev); + + return 0; +} + +static int nl80211_set_multicast_to_unicast(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + const struct nlattr *nla; + bool enabled; + + if (!rdev->ops->set_multicast_to_unicast) + return -EOPNOTSUPP; + + if (wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO) + return -EOPNOTSUPP; + + nla = info->attrs[NL80211_ATTR_MULTICAST_TO_UNICAST_ENABLED]; + enabled = nla_get_flag(nla); + + return rdev_set_multicast_to_unicast(rdev, dev, enabled); +} + +static int nl80211_set_pmk(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_pmk_conf pmk_conf = {}; + int ret; + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT) + return -EOPNOTSUPP; + + if (!wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X)) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_MAC] || !info->attrs[NL80211_ATTR_PMK]) + return -EINVAL; + + wdev_lock(wdev); + if (!wdev->connected) { + ret = -ENOTCONN; + goto out; + } + + pmk_conf.aa = nla_data(info->attrs[NL80211_ATTR_MAC]); + if (memcmp(pmk_conf.aa, wdev->u.client.connected_addr, ETH_ALEN)) { + ret = -EINVAL; + goto out; + } + + pmk_conf.pmk = nla_data(info->attrs[NL80211_ATTR_PMK]); + pmk_conf.pmk_len = nla_len(info->attrs[NL80211_ATTR_PMK]); + if (pmk_conf.pmk_len != WLAN_PMK_LEN && + pmk_conf.pmk_len != WLAN_PMK_LEN_SUITE_B_192) { + ret = -EINVAL; + goto out; + } + + if (info->attrs[NL80211_ATTR_PMKR0_NAME]) + pmk_conf.pmk_r0_name = + nla_data(info->attrs[NL80211_ATTR_PMKR0_NAME]); + + ret = rdev_set_pmk(rdev, dev, &pmk_conf); +out: + wdev_unlock(wdev); + return ret; +} + +static int nl80211_del_pmk(struct sk_buff *skb, struct genl_info *info) +{ + 
struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+	const u8 *aa;
+	int ret;
+
+	if (wdev->iftype != NL80211_IFTYPE_STATION &&
+	    wdev->iftype != NL80211_IFTYPE_P2P_CLIENT)
+		return -EOPNOTSUPP;
+
+	if (!wiphy_ext_feature_isset(&rdev->wiphy,
+				     NL80211_EXT_FEATURE_4WAY_HANDSHAKE_STA_1X))
+		return -EOPNOTSUPP;
+
+	if (!info->attrs[NL80211_ATTR_MAC])
+		return -EINVAL;
+
+	wdev_lock(wdev);
+	aa = nla_data(info->attrs[NL80211_ATTR_MAC]);
+	ret = rdev_del_pmk(rdev, dev, aa);
+	wdev_unlock(wdev);
+
+	return ret;
+}
+
+static int nl80211_external_auth(struct sk_buff *skb, struct genl_info *info)
+{
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	struct cfg80211_external_auth_params params;
+
+	if (!rdev->ops->external_auth)
+		return -EOPNOTSUPP;
+
+	if (!info->attrs[NL80211_ATTR_SSID] &&
+	    dev->ieee80211_ptr->iftype != NL80211_IFTYPE_AP &&
+	    dev->ieee80211_ptr->iftype != NL80211_IFTYPE_P2P_GO)
+		return -EINVAL;
+
+	if (!info->attrs[NL80211_ATTR_BSSID])
+		return -EINVAL;
+
+	if (!info->attrs[NL80211_ATTR_STATUS_CODE])
+		return -EINVAL;
+
+	memset(&params, 0, sizeof(params));
+
+	if (info->attrs[NL80211_ATTR_SSID]) {
+		params.ssid.ssid_len = nla_len(info->attrs[NL80211_ATTR_SSID]);
+		if (params.ssid.ssid_len == 0)
+			return -EINVAL;
+		memcpy(params.ssid.ssid,
+		       nla_data(info->attrs[NL80211_ATTR_SSID]),
+		       params.ssid.ssid_len);
+	}
+
+	memcpy(params.bssid, nla_data(info->attrs[NL80211_ATTR_BSSID]),
+	       ETH_ALEN);
+
+	params.status = nla_get_u16(info->attrs[NL80211_ATTR_STATUS_CODE]);
+
+	if (info->attrs[NL80211_ATTR_PMKID])
+		params.pmkid = nla_data(info->attrs[NL80211_ATTR_PMKID]);
+
+	return rdev_external_auth(rdev, dev, &params);
+}
+
+static int nl80211_tx_control_port(struct sk_buff *skb, struct genl_info *info)
+{
+	bool dont_wait_for_ack = info->attrs[NL80211_ATTR_DONT_WAIT_FOR_ACK];
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+	const u8 *buf;
+	size_t len;
+	u8 *dest;
+	u16 proto;
+	bool noencrypt;
+	u64 cookie = 0;
+	int link_id;
+	int err;
+
+	if (!wiphy_ext_feature_isset(&rdev->wiphy,
+				     NL80211_EXT_FEATURE_CONTROL_PORT_OVER_NL80211))
+		return -EOPNOTSUPP;
+
+	if (!rdev->ops->tx_control_port)
+		return -EOPNOTSUPP;
+
+	if (!info->attrs[NL80211_ATTR_FRAME] ||
+	    !info->attrs[NL80211_ATTR_MAC] ||
+	    !info->attrs[NL80211_ATTR_CONTROL_PORT_ETHERTYPE]) {
+		GENL_SET_ERR_MSG(info, "Frame, MAC or ethertype missing");
+		return -EINVAL;
+	}
+
+	wdev_lock(wdev);
+
+	switch (wdev->iftype) {
+	case NL80211_IFTYPE_AP:
+	case NL80211_IFTYPE_P2P_GO:
+	case NL80211_IFTYPE_MESH_POINT:
+		break;
+	case NL80211_IFTYPE_ADHOC:
+		if (wdev->u.ibss.current_bss)
+			break;
+		err = -ENOTCONN;
+		goto out;
+	case NL80211_IFTYPE_STATION:
+	case NL80211_IFTYPE_P2P_CLIENT:
+		if (wdev->connected)
+			break;
+		err = -ENOTCONN;
+		goto out;
+	default:
+		err = -EOPNOTSUPP;
+		goto out;
+	}
+
+	wdev_unlock(wdev);
+
+	buf = nla_data(info->attrs[NL80211_ATTR_FRAME]);
+	len = nla_len(info->attrs[NL80211_ATTR_FRAME]);
+	dest = nla_data(info->attrs[NL80211_ATTR_MAC]);
+	proto = nla_get_u16(info->attrs[NL80211_ATTR_CONTROL_PORT_ETHERTYPE]);
+	noencrypt =
+		nla_get_flag(info->attrs[NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT]);
+
+	link_id = nl80211_link_id_or_invalid(info->attrs);
+
+	err = rdev_tx_control_port(rdev, dev, buf, len,
+				   dest, cpu_to_be16(proto), noencrypt, link_id,
+
dont_wait_for_ack ? NULL : &cookie); + if (!err && !dont_wait_for_ack) + nl_set_extack_cookie_u64(info->extack, cookie); + return err; + out: + wdev_unlock(wdev); + return err; +} + +static int nl80211_get_ftm_responder_stats(struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_ftm_responder_stats ftm_stats = {}; + unsigned int link_id = nl80211_link_id(info->attrs); + struct sk_buff *msg; + void *hdr; + struct nlattr *ftm_stats_attr; + int err; + + if (wdev->iftype != NL80211_IFTYPE_AP || + !wdev->links[link_id].ap.beacon_interval) + return -EOPNOTSUPP; + + err = rdev_get_ftm_responder_stats(rdev, dev, &ftm_stats); + if (err) + return err; + + if (!ftm_stats.filled) + return -ENODATA; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, info->snd_portid, info->snd_seq, 0, + NL80211_CMD_GET_FTM_RESPONDER_STATS); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + ftm_stats_attr = nla_nest_start_noflag(msg, + NL80211_ATTR_FTM_RESPONDER_STATS); + if (!ftm_stats_attr) + goto nla_put_failure; + +#define SET_FTM(field, name, type) \ + do { if ((ftm_stats.filled & BIT(NL80211_FTM_STATS_ ## name)) && \ + nla_put_ ## type(msg, NL80211_FTM_STATS_ ## name, \ + ftm_stats.field)) \ + goto nla_put_failure; } while (0) +#define SET_FTM_U64(field, name) \ + do { if ((ftm_stats.filled & BIT(NL80211_FTM_STATS_ ## name)) && \ + nla_put_u64_64bit(msg, NL80211_FTM_STATS_ ## name, \ + ftm_stats.field, NL80211_FTM_STATS_PAD)) \ + goto nla_put_failure; } while (0) + + SET_FTM(success_num, SUCCESS_NUM, u32); + SET_FTM(partial_num, PARTIAL_NUM, u32); + SET_FTM(failed_num, FAILED_NUM, u32); + SET_FTM(asap_num, ASAP_NUM, u32); + SET_FTM(non_asap_num, NON_ASAP_NUM, u32); + SET_FTM_U64(total_duration_ms, TOTAL_DURATION_MSEC); + SET_FTM(unknown_triggers_num, UNKNOWN_TRIGGERS_NUM, u32); + SET_FTM(reschedule_requests_num, RESCHEDULE_REQUESTS_NUM, u32); + SET_FTM(out_of_window_triggers_num, OUT_OF_WINDOW_TRIGGERS_NUM, u32); +#undef SET_FTM + + nla_nest_end(msg, ftm_stats_attr); + + genlmsg_end(msg, hdr); + return genlmsg_reply(msg, info); + +nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +static int nl80211_update_owe_info(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct cfg80211_update_owe_info owe_info; + struct net_device *dev = info->user_ptr[1]; + + if (!rdev->ops->update_owe_info) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_STATUS_CODE] || + !info->attrs[NL80211_ATTR_MAC]) + return -EINVAL; + + memset(&owe_info, 0, sizeof(owe_info)); + owe_info.status = nla_get_u16(info->attrs[NL80211_ATTR_STATUS_CODE]); + nla_memcpy(owe_info.peer, info->attrs[NL80211_ATTR_MAC], ETH_ALEN); + + if (info->attrs[NL80211_ATTR_IE]) { + owe_info.ie = nla_data(info->attrs[NL80211_ATTR_IE]); + owe_info.ie_len = nla_len(info->attrs[NL80211_ATTR_IE]); + } + + return rdev_update_owe_info(rdev, dev, &owe_info); +} + +static int nl80211_probe_mesh_link(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct net_device *dev = info->user_ptr[1]; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct station_info sinfo = {}; + const u8 *buf; + size_t len; + u8 *dest; + int err; + + if 
(!rdev->ops->probe_mesh_link || !rdev->ops->get_station) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_MAC] || + !info->attrs[NL80211_ATTR_FRAME]) { + GENL_SET_ERR_MSG(info, "Frame or MAC missing"); + return -EINVAL; + } + + if (wdev->iftype != NL80211_IFTYPE_MESH_POINT) + return -EOPNOTSUPP; + + dest = nla_data(info->attrs[NL80211_ATTR_MAC]); + buf = nla_data(info->attrs[NL80211_ATTR_FRAME]); + len = nla_len(info->attrs[NL80211_ATTR_FRAME]); + + if (len < sizeof(struct ethhdr)) + return -EINVAL; + + if (!ether_addr_equal(buf, dest) || is_multicast_ether_addr(buf) || + !ether_addr_equal(buf + ETH_ALEN, dev->dev_addr)) + return -EINVAL; + + err = rdev_get_station(rdev, dev, dest, &sinfo); + if (err) + return err; + + cfg80211_sinfo_release_content(&sinfo); + + return rdev_probe_mesh_link(rdev, dev, dest, buf, len); +} + +static int parse_tid_conf(struct cfg80211_registered_device *rdev, + struct nlattr *attrs[], struct net_device *dev, + struct cfg80211_tid_cfg *tid_conf, + struct genl_info *info, const u8 *peer, + unsigned int link_id) +{ + struct netlink_ext_ack *extack = info->extack; + u64 mask; + int err; + + if (!attrs[NL80211_TID_CONFIG_ATTR_TIDS]) + return -EINVAL; + + tid_conf->config_override = + nla_get_flag(attrs[NL80211_TID_CONFIG_ATTR_OVERRIDE]); + tid_conf->tids = nla_get_u16(attrs[NL80211_TID_CONFIG_ATTR_TIDS]); + + if (tid_conf->config_override) { + if (rdev->ops->reset_tid_config) { + err = rdev_reset_tid_config(rdev, dev, peer, + tid_conf->tids); + if (err) + return err; + } else { + return -EINVAL; + } + } + + if (attrs[NL80211_TID_CONFIG_ATTR_NOACK]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_NOACK); + tid_conf->noack = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_NOACK]); + } + + if (attrs[NL80211_TID_CONFIG_ATTR_RETRY_SHORT]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_RETRY_SHORT); + tid_conf->retry_short = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_RETRY_SHORT]); + + if (tid_conf->retry_short > rdev->wiphy.max_data_retry_count) + return -EINVAL; + } + + if (attrs[NL80211_TID_CONFIG_ATTR_RETRY_LONG]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_RETRY_LONG); + tid_conf->retry_long = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_RETRY_LONG]); + + if (tid_conf->retry_long > rdev->wiphy.max_data_retry_count) + return -EINVAL; + } + + if (attrs[NL80211_TID_CONFIG_ATTR_AMPDU_CTRL]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_AMPDU_CTRL); + tid_conf->ampdu = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_AMPDU_CTRL]); + } + + if (attrs[NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL); + tid_conf->rtscts = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_RTSCTS_CTRL]); + } + + if (attrs[NL80211_TID_CONFIG_ATTR_AMSDU_CTRL]) { + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_AMSDU_CTRL); + tid_conf->amsdu = + nla_get_u8(attrs[NL80211_TID_CONFIG_ATTR_AMSDU_CTRL]); + } + + if (attrs[NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE]) { + u32 idx = NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE, attr; + + tid_conf->txrate_type = nla_get_u8(attrs[idx]); + + if (tid_conf->txrate_type != NL80211_TX_RATE_AUTOMATIC) { + attr = NL80211_TID_CONFIG_ATTR_TX_RATE; + err = nl80211_parse_tx_bitrate_mask(info, attrs, attr, + &tid_conf->txrate_mask, dev, + true, link_id); + if (err) + return err; + + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_TX_RATE); + } + tid_conf->mask |= BIT(NL80211_TID_CONFIG_ATTR_TX_RATE_TYPE); + } + + if (peer) + mask = rdev->wiphy.tid_config_support.peer; + else + mask = rdev->wiphy.tid_config_support.vif; + + if 
(tid_conf->mask & ~mask) {
+		NL_SET_ERR_MSG(extack, "unsupported TID configuration");
+		return -ENOTSUPP;
+	}
+
+	return 0;
+}
+
+static int nl80211_set_tid_config(struct sk_buff *skb,
+				  struct genl_info *info)
+{
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct nlattr *attrs[NL80211_TID_CONFIG_ATTR_MAX + 1];
+	unsigned int link_id = nl80211_link_id(info->attrs);
+	struct net_device *dev = info->user_ptr[1];
+	struct cfg80211_tid_config *tid_config;
+	struct nlattr *tid;
+	int conf_idx = 0, rem_conf;
+	int ret = -EINVAL;
+	u32 num_conf = 0;
+
+	if (!info->attrs[NL80211_ATTR_TID_CONFIG])
+		return -EINVAL;
+
+	if (!rdev->ops->set_tid_config)
+		return -EOPNOTSUPP;
+
+	nla_for_each_nested(tid, info->attrs[NL80211_ATTR_TID_CONFIG],
+			    rem_conf)
+		num_conf++;
+
+	tid_config = kzalloc(struct_size(tid_config, tid_conf, num_conf),
+			     GFP_KERNEL);
+	if (!tid_config)
+		return -ENOMEM;
+
+	tid_config->n_tid_conf = num_conf;
+
+	if (info->attrs[NL80211_ATTR_MAC])
+		tid_config->peer = nla_data(info->attrs[NL80211_ATTR_MAC]);
+
+	wdev_lock(dev->ieee80211_ptr);
+
+	nla_for_each_nested(tid, info->attrs[NL80211_ATTR_TID_CONFIG],
+			    rem_conf) {
+		ret = nla_parse_nested(attrs, NL80211_TID_CONFIG_ATTR_MAX,
+				       tid, NULL, NULL);
+
+		if (ret)
+			goto bad_tid_conf;
+
+		ret = parse_tid_conf(rdev, attrs, dev,
+				     &tid_config->tid_conf[conf_idx],
+				     info, tid_config->peer, link_id);
+		if (ret)
+			goto bad_tid_conf;
+
+		conf_idx++;
+	}
+
+	ret = rdev_set_tid_config(rdev, dev, tid_config);
+
+bad_tid_conf:
+	kfree(tid_config);
+	wdev_unlock(dev->ieee80211_ptr);
+	return ret;
+}
+
+static int nl80211_color_change(struct sk_buff *skb, struct genl_info *info)
+{
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct cfg80211_color_change_settings params = {};
+	struct net_device *dev = info->user_ptr[1];
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+	struct nlattr **tb;
+	u16 offset;
+	int err;
+
+	if (!rdev->ops->color_change)
+		return -EOPNOTSUPP;
+
+	if (!wiphy_ext_feature_isset(&rdev->wiphy,
+				     NL80211_EXT_FEATURE_BSS_COLOR))
+		return -EOPNOTSUPP;
+
+	if (wdev->iftype != NL80211_IFTYPE_AP)
+		return -EOPNOTSUPP;
+
+	if (!info->attrs[NL80211_ATTR_COLOR_CHANGE_COUNT] ||
+	    !info->attrs[NL80211_ATTR_COLOR_CHANGE_COLOR] ||
+	    !info->attrs[NL80211_ATTR_COLOR_CHANGE_ELEMS])
+		return -EINVAL;
+
+	params.count = nla_get_u8(info->attrs[NL80211_ATTR_COLOR_CHANGE_COUNT]);
+	params.color = nla_get_u8(info->attrs[NL80211_ATTR_COLOR_CHANGE_COLOR]);
+
+	err = nl80211_parse_beacon(rdev, info->attrs, &params.beacon_next);
+	if (err)
+		return err;
+
+	tb = kcalloc(NL80211_ATTR_MAX + 1, sizeof(*tb), GFP_KERNEL);
+	if (!tb)
+		return -ENOMEM;
+
+	err = nla_parse_nested(tb, NL80211_ATTR_MAX,
+			       info->attrs[NL80211_ATTR_COLOR_CHANGE_ELEMS],
+			       nl80211_policy, info->extack);
+	if (err)
+		goto out;
+
+	err = nl80211_parse_beacon(rdev, tb, &params.beacon_color_change);
+	if (err)
+		goto out;
+
+	if (!tb[NL80211_ATTR_CNTDWN_OFFS_BEACON]) {
+		err = -EINVAL;
+		goto out;
+	}
+
+	if (nla_len(tb[NL80211_ATTR_CNTDWN_OFFS_BEACON]) != sizeof(u16)) {
+		err = -EINVAL;
+		goto out;
+	}
+
+	offset = nla_get_u16(tb[NL80211_ATTR_CNTDWN_OFFS_BEACON]);
+	if (offset >= params.beacon_color_change.tail_len) {
+		err = -EINVAL;
+		goto out;
+	}
+
+	if (params.beacon_color_change.tail[offset] != params.count) {
+		err = -EINVAL;
+		goto out;
+	}
+
+	params.counter_offset_beacon = offset;
+
+	if (tb[NL80211_ATTR_CNTDWN_OFFS_PRESP]) {
+		if (nla_len(tb[NL80211_ATTR_CNTDWN_OFFS_PRESP]) !=
+		    sizeof(u16)) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		offset = nla_get_u16(tb[NL80211_ATTR_CNTDWN_OFFS_PRESP]);
+		if (offset >= params.beacon_color_change.probe_resp_len) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		if (params.beacon_color_change.probe_resp[offset] !=
+		    params.count) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		params.counter_offset_presp = offset;
+	}
+
+	wdev_lock(wdev);
+	err = rdev_color_change(rdev, dev, &params);
+	wdev_unlock(wdev);
+
+out:
+	kfree(params.beacon_next.mbssid_ies);
+	kfree(params.beacon_color_change.mbssid_ies);
+	kfree(tb);
+	return err;
+}
+
+static int nl80211_set_fils_aad(struct sk_buff *skb,
+				struct genl_info *info)
+{
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	struct cfg80211_fils_aad fils_aad = {};
+	u8 *nonces;
+
+	if (!info->attrs[NL80211_ATTR_MAC] ||
+	    !info->attrs[NL80211_ATTR_FILS_KEK] ||
+	    !info->attrs[NL80211_ATTR_FILS_NONCES])
+		return -EINVAL;
+
+	fils_aad.macaddr = nla_data(info->attrs[NL80211_ATTR_MAC]);
+	fils_aad.kek_len = nla_len(info->attrs[NL80211_ATTR_FILS_KEK]);
+	fils_aad.kek = nla_data(info->attrs[NL80211_ATTR_FILS_KEK]);
+	nonces = nla_data(info->attrs[NL80211_ATTR_FILS_NONCES]);
+	fils_aad.snonce = nonces;
+	fils_aad.anonce = nonces + FILS_NONCE_LEN;
+
+	return rdev_set_fils_aad(rdev, dev, &fils_aad);
+}
+
+static int nl80211_add_link(struct sk_buff *skb, struct genl_info *info)
+{
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	unsigned int link_id = nl80211_link_id(info->attrs);
+	struct net_device *dev = info->user_ptr[1];
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+	int ret;
+
+	if (!(wdev->wiphy->flags & WIPHY_FLAG_SUPPORTS_MLO))
+		return -EINVAL;
+
+	switch (wdev->iftype) {
+	case NL80211_IFTYPE_AP:
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	if (!info->attrs[NL80211_ATTR_MAC] ||
+	    !is_valid_ether_addr(nla_data(info->attrs[NL80211_ATTR_MAC])))
+		return -EINVAL;
+
+	wdev_lock(wdev);
+	wdev->valid_links |= BIT(link_id);
+	ether_addr_copy(wdev->links[link_id].addr,
+			nla_data(info->attrs[NL80211_ATTR_MAC]));
+
+	ret = rdev_add_intf_link(rdev, wdev, link_id);
+	if (ret) {
+		wdev->valid_links &= ~BIT(link_id);
+		eth_zero_addr(wdev->links[link_id].addr);
+	}
+	wdev_unlock(wdev);
+
+	return ret;
+}
+
+static int nl80211_remove_link(struct sk_buff *skb, struct genl_info *info)
+{
+	unsigned int link_id = nl80211_link_id(info->attrs);
+	struct net_device *dev = info->user_ptr[1];
+	struct wireless_dev *wdev = dev->ieee80211_ptr;
+
+	/* cannot remove if there's no link */
+	if (!info->attrs[NL80211_ATTR_MLO_LINK_ID])
+		return -EINVAL;
+
+	switch (wdev->iftype) {
+	case NL80211_IFTYPE_AP:
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	wdev_lock(wdev);
+	cfg80211_remove_link(wdev, link_id);
+	wdev_unlock(wdev);
+
+	return 0;
+}
+
+static int
+nl80211_add_mod_link_station(struct sk_buff *skb, struct genl_info *info,
+			     bool add)
+{
+	struct link_station_parameters params = {};
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	int err;
+
+	if ((add && !rdev->ops->add_link_station) ||
+	    (!add && !rdev->ops->mod_link_station))
+		return -EOPNOTSUPP;
+
+	if (add && !info->attrs[NL80211_ATTR_MAC])
+		return -EINVAL;
+
+	if (!info->attrs[NL80211_ATTR_MLD_ADDR])
+		return -EINVAL;
+
+	if (add && !info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES])
+		return -EINVAL;
+
+	params.mld_mac = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]);
+
+	if (info->attrs[NL80211_ATTR_MAC]) {
+		params.link_mac = nla_data(info->attrs[NL80211_ATTR_MAC]);
+		if (!is_valid_ether_addr(params.link_mac))
+			return -EINVAL;
+	}
+
+	if (!info->attrs[NL80211_ATTR_MLO_LINK_ID])
+		return -EINVAL;
+
+	params.link_id = nla_get_u8(info->attrs[NL80211_ATTR_MLO_LINK_ID]);
+
+	if (info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]) {
+		params.supported_rates =
+			nla_data(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]);
+		params.supported_rates_len =
+			nla_len(info->attrs[NL80211_ATTR_STA_SUPPORTED_RATES]);
+	}
+
+	if (info->attrs[NL80211_ATTR_HT_CAPABILITY])
+		params.ht_capa =
+			nla_data(info->attrs[NL80211_ATTR_HT_CAPABILITY]);
+
+	if (info->attrs[NL80211_ATTR_VHT_CAPABILITY])
+		params.vht_capa =
+			nla_data(info->attrs[NL80211_ATTR_VHT_CAPABILITY]);
+
+	if (info->attrs[NL80211_ATTR_HE_CAPABILITY]) {
+		params.he_capa =
+			nla_data(info->attrs[NL80211_ATTR_HE_CAPABILITY]);
+		params.he_capa_len =
+			nla_len(info->attrs[NL80211_ATTR_HE_CAPABILITY]);
+
+		if (info->attrs[NL80211_ATTR_EHT_CAPABILITY]) {
+			params.eht_capa =
+				nla_data(info->attrs[NL80211_ATTR_EHT_CAPABILITY]);
+			params.eht_capa_len =
+				nla_len(info->attrs[NL80211_ATTR_EHT_CAPABILITY]);
+
+			if (!ieee80211_eht_capa_size_ok((const u8 *)params.he_capa,
+							(const u8 *)params.eht_capa,
+							params.eht_capa_len,
+							false))
+				return -EINVAL;
+		}
+	}
+
+	if (info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY])
+		params.he_6ghz_capa =
+			nla_data(info->attrs[NL80211_ATTR_HE_6GHZ_CAPABILITY]);
+
+	if (info->attrs[NL80211_ATTR_OPMODE_NOTIF]) {
+		params.opmode_notif_used = true;
+		params.opmode_notif =
+			nla_get_u8(info->attrs[NL80211_ATTR_OPMODE_NOTIF]);
+	}
+
+	err = nl80211_parse_sta_txpower_setting(info, &params.txpwr,
+						&params.txpwr_set);
+	if (err)
+		return err;
+
+	wdev_lock(dev->ieee80211_ptr);
+	if (add)
+		err = rdev_add_link_station(rdev, dev, &params);
+	else
+		err = rdev_mod_link_station(rdev, dev, &params);
+	wdev_unlock(dev->ieee80211_ptr);
+
+	return err;
+}
+
+static int
+nl80211_add_link_station(struct sk_buff *skb, struct genl_info *info)
+{
+	return nl80211_add_mod_link_station(skb, info, true);
+}
+
+static int
+nl80211_modify_link_station(struct sk_buff *skb, struct genl_info *info)
+{
+	return nl80211_add_mod_link_station(skb, info, false);
+}
+
+static int
+nl80211_remove_link_station(struct sk_buff *skb, struct genl_info *info)
+{
+	struct link_station_del_parameters params = {};
+	struct cfg80211_registered_device *rdev = info->user_ptr[0];
+	struct net_device *dev = info->user_ptr[1];
+	int ret;
+
+	if (!rdev->ops->del_link_station)
+		return -EOPNOTSUPP;
+
+	if (!info->attrs[NL80211_ATTR_MLD_ADDR] ||
+	    !info->attrs[NL80211_ATTR_MLO_LINK_ID])
+		return -EINVAL;
+
+	params.mld_mac = nla_data(info->attrs[NL80211_ATTR_MLD_ADDR]);
+	params.link_id = nla_get_u8(info->attrs[NL80211_ATTR_MLO_LINK_ID]);
+
+	wdev_lock(dev->ieee80211_ptr);
+	ret = rdev_del_link_station(rdev, dev, &params);
+	wdev_unlock(dev->ieee80211_ptr);
+
+	return ret;
+}
+
+#define NL80211_FLAG_NEED_WIPHY		0x01
+#define NL80211_FLAG_NEED_NETDEV	0x02
+#define NL80211_FLAG_NEED_RTNL		0x04
+#define NL80211_FLAG_CHECK_NETDEV_UP	0x08
+#define NL80211_FLAG_NEED_NETDEV_UP	(NL80211_FLAG_NEED_NETDEV |\
+					 NL80211_FLAG_CHECK_NETDEV_UP)
+#define NL80211_FLAG_NEED_WDEV		0x10
+/* If a netdev is associated, it must be UP, P2P must be started */
+#define NL80211_FLAG_NEED_WDEV_UP	(NL80211_FLAG_NEED_WDEV |\
+					 NL80211_FLAG_CHECK_NETDEV_UP)
+#define NL80211_FLAG_CLEAR_SKB		0x20
+#define NL80211_FLAG_NO_WIPHY_MTX	0x40
+#define NL80211_FLAG_MLO_VALID_LINK_ID	0x80
+#define NL80211_FLAG_MLO_UNSUPPORTED	0x100
+
+#define INTERNAL_FLAG_SELECTORS(__sel)			\
+	SELECTOR(__sel, NONE, 0) /* must be first
*/ \ + SELECTOR(__sel, WIPHY, \ + NL80211_FLAG_NEED_WIPHY) \ + SELECTOR(__sel, WDEV, \ + NL80211_FLAG_NEED_WDEV) \ + SELECTOR(__sel, NETDEV, \ + NL80211_FLAG_NEED_NETDEV) \ + SELECTOR(__sel, NETDEV_LINK, \ + NL80211_FLAG_NEED_NETDEV | \ + NL80211_FLAG_MLO_VALID_LINK_ID) \ + SELECTOR(__sel, NETDEV_NO_MLO, \ + NL80211_FLAG_NEED_NETDEV | \ + NL80211_FLAG_MLO_UNSUPPORTED) \ + SELECTOR(__sel, WIPHY_RTNL, \ + NL80211_FLAG_NEED_WIPHY | \ + NL80211_FLAG_NEED_RTNL) \ + SELECTOR(__sel, WIPHY_RTNL_NOMTX, \ + NL80211_FLAG_NEED_WIPHY | \ + NL80211_FLAG_NEED_RTNL | \ + NL80211_FLAG_NO_WIPHY_MTX) \ + SELECTOR(__sel, WDEV_RTNL, \ + NL80211_FLAG_NEED_WDEV | \ + NL80211_FLAG_NEED_RTNL) \ + SELECTOR(__sel, NETDEV_RTNL, \ + NL80211_FLAG_NEED_NETDEV | \ + NL80211_FLAG_NEED_RTNL) \ + SELECTOR(__sel, NETDEV_UP, \ + NL80211_FLAG_NEED_NETDEV_UP) \ + SELECTOR(__sel, NETDEV_UP_LINK, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_MLO_VALID_LINK_ID) \ + SELECTOR(__sel, NETDEV_UP_NO_MLO, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_MLO_UNSUPPORTED) \ + SELECTOR(__sel, NETDEV_UP_NO_MLO_CLEAR, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_CLEAR_SKB | \ + NL80211_FLAG_MLO_UNSUPPORTED) \ + SELECTOR(__sel, NETDEV_UP_NOTMX, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_NO_WIPHY_MTX) \ + SELECTOR(__sel, NETDEV_UP_NOTMX_NOMLO, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_NO_WIPHY_MTX | \ + NL80211_FLAG_MLO_UNSUPPORTED) \ + SELECTOR(__sel, NETDEV_UP_CLEAR, \ + NL80211_FLAG_NEED_NETDEV_UP | \ + NL80211_FLAG_CLEAR_SKB) \ + SELECTOR(__sel, WDEV_UP, \ + NL80211_FLAG_NEED_WDEV_UP) \ + SELECTOR(__sel, WDEV_UP_LINK, \ + NL80211_FLAG_NEED_WDEV_UP | \ + NL80211_FLAG_MLO_VALID_LINK_ID) \ + SELECTOR(__sel, WDEV_UP_RTNL, \ + NL80211_FLAG_NEED_WDEV_UP | \ + NL80211_FLAG_NEED_RTNL) \ + SELECTOR(__sel, WIPHY_CLEAR, \ + NL80211_FLAG_NEED_WIPHY | \ + NL80211_FLAG_CLEAR_SKB) + +enum nl80211_internal_flags_selector { +#define SELECTOR(_, name, value) NL80211_IFL_SEL_##name, + INTERNAL_FLAG_SELECTORS(_) +#undef SELECTOR +}; + +static u32 nl80211_internal_flags[] = { +#define SELECTOR(_, name, value) [NL80211_IFL_SEL_##name] = value, + INTERNAL_FLAG_SELECTORS(_) +#undef SELECTOR +}; + +static int nl80211_pre_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = NULL; + struct wireless_dev *wdev = NULL; + struct net_device *dev = NULL; + u32 internal_flags; + int err; + + if (WARN_ON(ops->internal_flags >= ARRAY_SIZE(nl80211_internal_flags))) + return -EINVAL; + + internal_flags = nl80211_internal_flags[ops->internal_flags]; + + rtnl_lock(); + if (internal_flags & NL80211_FLAG_NEED_WIPHY) { + rdev = cfg80211_get_dev_from_info(genl_info_net(info), info); + if (IS_ERR(rdev)) { + err = PTR_ERR(rdev); + goto out_unlock; + } + info->user_ptr[0] = rdev; + } else if (internal_flags & NL80211_FLAG_NEED_NETDEV || + internal_flags & NL80211_FLAG_NEED_WDEV) { + wdev = __cfg80211_wdev_from_attrs(NULL, genl_info_net(info), + info->attrs); + if (IS_ERR(wdev)) { + err = PTR_ERR(wdev); + goto out_unlock; + } + + dev = wdev->netdev; + dev_hold(dev); + rdev = wiphy_to_rdev(wdev->wiphy); + + if (internal_flags & NL80211_FLAG_NEED_NETDEV) { + if (!dev) { + err = -EINVAL; + goto out_unlock; + } + + info->user_ptr[1] = dev; + } else { + info->user_ptr[1] = wdev; + } + + if (internal_flags & NL80211_FLAG_CHECK_NETDEV_UP && + !wdev_running(wdev)) { + err = -ENETDOWN; + goto out_unlock; + } + + info->user_ptr[0] = rdev; + } + + if (internal_flags & 
NL80211_FLAG_MLO_VALID_LINK_ID) { + struct nlattr *link_id = info->attrs[NL80211_ATTR_MLO_LINK_ID]; + + if (!wdev) { + err = -EINVAL; + goto out_unlock; + } + + /* MLO -> require valid link ID */ + if (wdev->valid_links && + (!link_id || + !(wdev->valid_links & BIT(nla_get_u8(link_id))))) { + err = -EINVAL; + goto out_unlock; + } + + /* non-MLO -> no link ID attribute accepted */ + if (!wdev->valid_links && link_id) { + err = -EINVAL; + goto out_unlock; + } + } + + if (internal_flags & NL80211_FLAG_MLO_UNSUPPORTED) { + if (info->attrs[NL80211_ATTR_MLO_LINK_ID] || + (wdev && wdev->valid_links)) { + err = -EINVAL; + goto out_unlock; + } + } + + if (rdev && !(internal_flags & NL80211_FLAG_NO_WIPHY_MTX)) { + wiphy_lock(&rdev->wiphy); + /* we keep the mutex locked until post_doit */ + __release(&rdev->wiphy.mtx); + } + if (!(internal_flags & NL80211_FLAG_NEED_RTNL)) + rtnl_unlock(); + + return 0; +out_unlock: + rtnl_unlock(); + dev_put(dev); + return err; +} + +static void nl80211_post_doit(const struct genl_ops *ops, struct sk_buff *skb, + struct genl_info *info) +{ + u32 internal_flags = nl80211_internal_flags[ops->internal_flags]; + + if (info->user_ptr[1]) { + if (internal_flags & NL80211_FLAG_NEED_WDEV) { + struct wireless_dev *wdev = info->user_ptr[1]; + + dev_put(wdev->netdev); + } else { + dev_put(info->user_ptr[1]); + } + } + + if (info->user_ptr[0] && + !(internal_flags & NL80211_FLAG_NO_WIPHY_MTX)) { + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + + /* we kept the mutex locked since pre_doit */ + __acquire(&rdev->wiphy.mtx); + wiphy_unlock(&rdev->wiphy); + } + + if (internal_flags & NL80211_FLAG_NEED_RTNL) + rtnl_unlock(); + + /* If needed, clear the netlink message payload from the SKB + * as it might contain key data that shouldn't stick around on + * the heap after the SKB is freed. The netlink message header + * is still needed for further processing, so leave it intact. 
+ */ + if (internal_flags & NL80211_FLAG_CLEAR_SKB) { + struct nlmsghdr *nlh = nlmsg_hdr(skb); + + memset(nlmsg_data(nlh), 0, nlmsg_len(nlh)); + } +} + +static int nl80211_set_sar_sub_specs(struct cfg80211_registered_device *rdev, + struct cfg80211_sar_specs *sar_specs, + struct nlattr *spec[], int index) +{ + u32 range_index, i; + + if (!sar_specs || !spec) + return -EINVAL; + + if (!spec[NL80211_SAR_ATTR_SPECS_POWER] || + !spec[NL80211_SAR_ATTR_SPECS_RANGE_INDEX]) + return -EINVAL; + + range_index = nla_get_u32(spec[NL80211_SAR_ATTR_SPECS_RANGE_INDEX]); + + /* check if range_index exceeds num_freq_ranges */ + if (range_index >= rdev->wiphy.sar_capa->num_freq_ranges) + return -EINVAL; + + /* check if range_index duplicates */ + for (i = 0; i < index; i++) { + if (sar_specs->sub_specs[i].freq_range_index == range_index) + return -EINVAL; + } + + sar_specs->sub_specs[index].power = + nla_get_s32(spec[NL80211_SAR_ATTR_SPECS_POWER]); + + sar_specs->sub_specs[index].freq_range_index = range_index; + + return 0; +} + +static int nl80211_set_sar_specs(struct sk_buff *skb, struct genl_info *info) +{ + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct nlattr *spec[NL80211_SAR_ATTR_SPECS_MAX + 1]; + struct nlattr *tb[NL80211_SAR_ATTR_MAX + 1]; + struct cfg80211_sar_specs *sar_spec; + enum nl80211_sar_type type; + struct nlattr *spec_list; + u32 specs; + int rem, err; + + if (!rdev->wiphy.sar_capa || !rdev->ops->set_sar_specs) + return -EOPNOTSUPP; + + if (!info->attrs[NL80211_ATTR_SAR_SPEC]) + return -EINVAL; + + nla_parse_nested(tb, NL80211_SAR_ATTR_MAX, + info->attrs[NL80211_ATTR_SAR_SPEC], + NULL, NULL); + + if (!tb[NL80211_SAR_ATTR_TYPE] || !tb[NL80211_SAR_ATTR_SPECS]) + return -EINVAL; + + type = nla_get_u32(tb[NL80211_SAR_ATTR_TYPE]); + if (type != rdev->wiphy.sar_capa->type) + return -EINVAL; + + specs = 0; + nla_for_each_nested(spec_list, tb[NL80211_SAR_ATTR_SPECS], rem) + specs++; + + if (specs > rdev->wiphy.sar_capa->num_freq_ranges) + return -EINVAL; + + sar_spec = kzalloc(struct_size(sar_spec, sub_specs, specs), GFP_KERNEL); + if (!sar_spec) + return -ENOMEM; + + sar_spec->type = type; + specs = 0; + nla_for_each_nested(spec_list, tb[NL80211_SAR_ATTR_SPECS], rem) { + nla_parse_nested(spec, NL80211_SAR_ATTR_SPECS_MAX, + spec_list, NULL, NULL); + + switch (type) { + case NL80211_SAR_TYPE_POWER: + if (nl80211_set_sar_sub_specs(rdev, sar_spec, + spec, specs)) { + err = -EINVAL; + goto error; + } + break; + default: + err = -EINVAL; + goto error; + } + specs++; + } + + sar_spec->num_sub_specs = specs; + + rdev->cur_cmd_info = info; + err = rdev_set_sar_specs(rdev, sar_spec); + rdev->cur_cmd_info = NULL; +error: + kfree(sar_spec); + return err; +} + +#define SELECTOR(__sel, name, value) \ + ((__sel) == (value)) ? 
NL80211_IFL_SEL_##name : +int __missing_selector(void); +#define IFLAGS(__val) INTERNAL_FLAG_SELECTORS(__val) __missing_selector() + +static const struct genl_ops nl80211_ops[] = { + { + .cmd = NL80211_CMD_GET_WIPHY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_wiphy, + .dumpit = nl80211_dump_wiphy, + .done = nl80211_dump_wiphy_done, + /* can be retrieved by unprivileged users */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, +}; + +static const struct genl_small_ops nl80211_small_ops[] = { + { + .cmd = NL80211_CMD_SET_WIPHY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_wiphy, + .flags = GENL_UNS_ADMIN_PERM, + }, + { + .cmd = NL80211_CMD_GET_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_interface, + .dumpit = nl80211_dump_interface, + /* can be retrieved by unprivileged users */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV), + }, + { + .cmd = NL80211_CMD_SET_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_interface, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_NEW_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_new_interface, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = + IFLAGS(NL80211_FLAG_NEED_WIPHY | + NL80211_FLAG_NEED_RTNL | + /* we take the wiphy mutex later ourselves */ + NL80211_FLAG_NO_WIPHY_MTX), + }, + { + .cmd = NL80211_CMD_DEL_INTERFACE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_interface, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_GET_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_key, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_key, + .flags = GENL_UNS_ADMIN_PERM, + /* cannot use NL80211_FLAG_MLO_VALID_LINK_ID, depends on key */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_NEW_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_new_key, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_DEL_KEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_key, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_BEACON, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = nl80211_set_beacon, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_START_AP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = nl80211_start_ap, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_STOP_AP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .flags = GENL_UNS_ADMIN_PERM, + .doit = nl80211_stop_ap, + 
.internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_GET_STATION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_station, + .dumpit = nl80211_dump_station, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_SET_STATION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_NEW_STATION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_new_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_DEL_STATION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_MPATH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_mpath, + .dumpit = nl80211_dump_mpath, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_MPP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_mpp, + .dumpit = nl80211_dump_mpp, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_MPATH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_mpath, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_NEW_MPATH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_new_mpath, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_DEL_MPATH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_mpath, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_BSS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_bss, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_REG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_reg_do, + .dumpit = nl80211_get_reg_dump, + /* can be retrieved by unprivileged users */ + }, +#ifdef CONFIG_CFG80211_CRDA_SUPPORT + { + .cmd = NL80211_CMD_SET_REG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_reg, + .flags = GENL_ADMIN_PERM, + }, +#endif + { + .cmd = NL80211_CMD_REQ_SET_REG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_req_set_reg, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NL80211_CMD_RELOAD_REGDB, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_reload_regdb, + .flags = GENL_ADMIN_PERM, + }, + { + .cmd = NL80211_CMD_GET_MESH_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_mesh_config, + /* can be retrieved by unprivileged users */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_MESH_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + 
.doit = nl80211_update_mesh_config, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_TRIGGER_SCAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_trigger_scan, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_ABORT_SCAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_abort_scan, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_SCAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = nl80211_dump_scan, + }, + { + .cmd = NL80211_CMD_START_SCHED_SCAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_start_sched_scan, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_STOP_SCHED_SCAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_stop_sched_scan, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_AUTHENTICATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_authenticate, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_ASSOCIATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_associate, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_DEAUTHENTICATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_deauthenticate, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_DISASSOCIATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_disassociate, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_JOIN_IBSS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_join_ibss, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_LEAVE_IBSS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_leave_ibss, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, +#ifdef CONFIG_NL80211_TESTMODE + { + .cmd = NL80211_CMD_TESTMODE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_testmode_do, + .dumpit = nl80211_testmode_dump, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, +#endif + { + .cmd = NL80211_CMD_CONNECT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_connect, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_UPDATE_CONNECT_PARAMS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_update_connect_params, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_DISCONNECT, + .validate = GENL_DONT_VALIDATE_STRICT | 
GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_disconnect, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_WIPHY_NETNS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_wiphy_netns, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY | + NL80211_FLAG_NEED_RTNL | + NL80211_FLAG_NO_WIPHY_MTX), + }, + { + .cmd = NL80211_CMD_GET_SURVEY, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .dumpit = nl80211_dump_survey, + }, + { + .cmd = NL80211_CMD_SET_PMKSA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_setdel_pmksa, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_DEL_PMKSA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_setdel_pmksa, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_FLUSH_PMKSA, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_flush_pmksa, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_REMAIN_ON_CHANNEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_remain_on_channel, + .flags = GENL_UNS_ADMIN_PERM, + /* FIXME: requiring a link ID here is probably not good */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_CANCEL_REMAIN_ON_CHANNEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_cancel_remain_on_channel, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_TX_BITRATE_MASK, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_tx_bitrate_mask, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_REGISTER_FRAME, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_register_mgmt, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV), + }, + { + .cmd = NL80211_CMD_FRAME, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tx_mgmt, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_FRAME_WAIT_CANCEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tx_mgmt_cancel_wait, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_POWER_SAVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_power_save, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_GET_POWER_SAVE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_power_save, + /* can be retrieved by unprivileged users */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_SET_CQM, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_cqm, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { 
+ .cmd = NL80211_CMD_SET_CHANNEL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_channel, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_JOIN_MESH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_join_mesh, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_LEAVE_MESH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_leave_mesh, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_JOIN_OCB, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_join_ocb, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_LEAVE_OCB, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_leave_ocb, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, +#ifdef CONFIG_PM + { + .cmd = NL80211_CMD_GET_WOWLAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_wowlan, + /* can be retrieved by unprivileged users */ + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, + { + .cmd = NL80211_CMD_SET_WOWLAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_wowlan, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, +#endif + { + .cmd = NL80211_CMD_SET_REKEY_OFFLOAD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_rekey_data, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_TDLS_MGMT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tdls_mgmt, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_TDLS_OPER, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tdls_oper, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_UNEXPECTED_FRAME, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_register_unexpected_frame, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_PROBE_CLIENT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_probe_client, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_REGISTER_BEACONS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_register_beacons, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, + { + .cmd = NL80211_CMD_SET_NOACK_MAP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_noack_map, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_START_P2P_DEVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_start_p2p_device, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV | + 
NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_STOP_P2P_DEVICE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_stop_p2p_device, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_START_NAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_start_nan, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_STOP_NAN, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_stop_nan, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_ADD_NAN_FUNCTION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_nan_add_func, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_DEL_NAN_FUNCTION, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_nan_del_func, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_CHANGE_NAN_CONFIG, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_nan_change_config, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_MCAST_RATE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_mcast_rate, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_SET_MAC_ACL, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_mac_acl, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_MLO_UNSUPPORTED), + }, + { + .cmd = NL80211_CMD_RADAR_DETECT, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_start_radar_detection, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_NO_WIPHY_MTX | + NL80211_FLAG_MLO_UNSUPPORTED), + }, + { + .cmd = NL80211_CMD_GET_PROTOCOL_FEATURES, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_protocol_features, + }, + { + .cmd = NL80211_CMD_UPDATE_FT_IES, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_update_ft_ies, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_CRIT_PROTOCOL_START, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_crit_protocol_start, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_CRIT_PROTOCOL_STOP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_crit_protocol_stop, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_COALESCE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_coalesce, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, + { + .cmd = NL80211_CMD_SET_COALESCE, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_coalesce, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = 
IFLAGS(NL80211_FLAG_NEED_WIPHY), + }, + { + .cmd = NL80211_CMD_CHANNEL_SWITCH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_channel_switch, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_VENDOR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_vendor_cmd, + .dumpit = nl80211_vendor_cmd_dump, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_SET_QOS_MAP, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_qos_map, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_ADD_TX_TS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_add_tx_ts, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_UNSUPPORTED), + }, + { + .cmd = NL80211_CMD_DEL_TX_TS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_tx_ts, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_TDLS_CHANNEL_SWITCH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tdls_channel_switch, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_TDLS_CANCEL_CHANNEL_SWITCH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tdls_cancel_channel_switch, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_MULTICAST_TO_UNICAST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_multicast_to_unicast, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV), + }, + { + .cmd = NL80211_CMD_SET_PMK, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_pmk, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_CLEAR_SKB), + }, + { + .cmd = NL80211_CMD_DEL_PMK, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_del_pmk, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_EXTERNAL_AUTH, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_external_auth, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_CONTROL_PORT_FRAME, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_tx_control_port, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_GET_FTM_RESPONDER_STATS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_get_ftm_responder_stats, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_PEER_MEASUREMENT_START, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_pmsr_start, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WDEV_UP), + }, + { + .cmd = NL80211_CMD_NOTIFY_RADAR, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = 
nl80211_notify_radar_detection, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_UPDATE_OWE_INFO, + .doit = nl80211_update_owe_info, + .flags = GENL_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_PROBE_MESH_LINK, + .doit = nl80211_probe_mesh_link, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_TID_CONFIG, + .doit = nl80211_set_tid_config, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_SET_SAR_SPECS, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_sar_specs, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_WIPHY | + NL80211_FLAG_NEED_RTNL), + }, + { + .cmd = NL80211_CMD_COLOR_CHANGE_REQUEST, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_color_change, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_SET_FILS_AAD, + .validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP, + .doit = nl80211_set_fils_aad, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_ADD_LINK, + .doit = nl80211_add_link, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP), + }, + { + .cmd = NL80211_CMD_REMOVE_LINK, + .doit = nl80211_remove_link, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_ADD_LINK_STA, + .doit = nl80211_add_link_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_MODIFY_LINK_STA, + .doit = nl80211_modify_link_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, + { + .cmd = NL80211_CMD_REMOVE_LINK_STA, + .doit = nl80211_remove_link_station, + .flags = GENL_UNS_ADMIN_PERM, + .internal_flags = IFLAGS(NL80211_FLAG_NEED_NETDEV_UP | + NL80211_FLAG_MLO_VALID_LINK_ID), + }, +}; + +static struct genl_family nl80211_fam __ro_after_init = { + .name = NL80211_GENL_NAME, /* have users key off the name instead */ + .hdrsize = 0, /* no private header */ + .version = 1, /* no particular meaning now */ + .maxattr = NL80211_ATTR_MAX, + .policy = nl80211_policy, + .netnsok = true, + .pre_doit = nl80211_pre_doit, + .post_doit = nl80211_post_doit, + .module = THIS_MODULE, + .ops = nl80211_ops, + .n_ops = ARRAY_SIZE(nl80211_ops), + .small_ops = nl80211_small_ops, + .n_small_ops = ARRAY_SIZE(nl80211_small_ops), + .resv_start_op = NL80211_CMD_REMOVE_LINK_STA + 1, + .mcgrps = nl80211_mcgrps, + .n_mcgrps = ARRAY_SIZE(nl80211_mcgrps), + .parallel_ops = true, +}; + +/* notification functions */ + +void nl80211_notify_wiphy(struct cfg80211_registered_device *rdev, + enum nl80211_commands cmd) +{ + struct sk_buff *msg; + struct nl80211_dump_wiphy_state state = {}; + + WARN_ON(cmd != NL80211_CMD_NEW_WIPHY && + cmd != NL80211_CMD_DEL_WIPHY); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + if (nl80211_send_wiphy(rdev, cmd, msg, 0, 0, 0, &state) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, 
wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_CONFIG, GFP_KERNEL); +} + +void nl80211_notify_iface(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + enum nl80211_commands cmd) +{ + struct sk_buff *msg; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + if (nl80211_send_iface(msg, 0, 0, 0, rdev, wdev, cmd) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_CONFIG, GFP_KERNEL); +} + +static int nl80211_add_scan_req(struct sk_buff *msg, + struct cfg80211_registered_device *rdev) +{ + struct cfg80211_scan_request *req = rdev->scan_req; + struct nlattr *nest; + int i; + struct cfg80211_scan_info *info; + + if (WARN_ON(!req)) + return 0; + + nest = nla_nest_start_noflag(msg, NL80211_ATTR_SCAN_SSIDS); + if (!nest) + goto nla_put_failure; + for (i = 0; i < req->n_ssids; i++) { + if (nla_put(msg, i, req->ssids[i].ssid_len, req->ssids[i].ssid)) + goto nla_put_failure; + } + nla_nest_end(msg, nest); + + if (req->flags & NL80211_SCAN_FLAG_FREQ_KHZ) { + nest = nla_nest_start(msg, NL80211_ATTR_SCAN_FREQ_KHZ); + if (!nest) + goto nla_put_failure; + for (i = 0; i < req->n_channels; i++) { + if (nla_put_u32(msg, i, + ieee80211_channel_to_khz(req->channels[i]))) + goto nla_put_failure; + } + nla_nest_end(msg, nest); + } else { + nest = nla_nest_start_noflag(msg, + NL80211_ATTR_SCAN_FREQUENCIES); + if (!nest) + goto nla_put_failure; + for (i = 0; i < req->n_channels; i++) { + if (nla_put_u32(msg, i, req->channels[i]->center_freq)) + goto nla_put_failure; + } + nla_nest_end(msg, nest); + } + + if (req->ie && + nla_put(msg, NL80211_ATTR_IE, req->ie_len, req->ie)) + goto nla_put_failure; + + if (req->flags && + nla_put_u32(msg, NL80211_ATTR_SCAN_FLAGS, req->flags)) + goto nla_put_failure; + + info = rdev->int_scan_req ? 
&rdev->int_scan_req->info : + &rdev->scan_req->info; + if (info->scan_start_tsf && + (nla_put_u64_64bit(msg, NL80211_ATTR_SCAN_START_TIME_TSF, + info->scan_start_tsf, NL80211_BSS_PAD) || + nla_put(msg, NL80211_ATTR_SCAN_START_TIME_TSF_BSSID, ETH_ALEN, + info->tsf_bssid))) + goto nla_put_failure; + + return 0; + nla_put_failure: + return -ENOBUFS; +} + +static int nl80211_prep_scan_msg(struct sk_buff *msg, + struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + u32 portid, u32 seq, int flags, + u32 cmd) +{ + void *hdr; + + hdr = nl80211hdr_put(msg, portid, seq, flags, cmd); + if (!hdr) + return -1; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (wdev->netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + + /* ignore errors and send incomplete event anyway */ + nl80211_add_scan_req(msg, rdev); + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +static int +nl80211_prep_sched_scan_msg(struct sk_buff *msg, + struct cfg80211_sched_scan_request *req, u32 cmd) +{ + void *hdr; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) + return -1; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, + wiphy_to_rdev(req->wiphy)->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, req->dev->ifindex) || + nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, req->reqid, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + return 0; + + nla_put_failure: + genlmsg_cancel(msg, hdr); + return -EMSGSIZE; +} + +void nl80211_send_scan_start(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + struct sk_buff *msg; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + if (nl80211_prep_scan_msg(msg, rdev, wdev, 0, 0, 0, + NL80211_CMD_TRIGGER_SCAN) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_SCAN, GFP_KERNEL); +} + +struct sk_buff *nl80211_build_scan_msg(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, bool aborted) +{ + struct sk_buff *msg; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return NULL; + + if (nl80211_prep_scan_msg(msg, rdev, wdev, 0, 0, 0, + aborted ? 
NL80211_CMD_SCAN_ABORTED : + NL80211_CMD_NEW_SCAN_RESULTS) < 0) { + nlmsg_free(msg); + return NULL; + } + + return msg; +} + +/* send message created by nl80211_build_scan_msg() */ +void nl80211_send_scan_msg(struct cfg80211_registered_device *rdev, + struct sk_buff *msg) +{ + if (!msg) + return; + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_SCAN, GFP_KERNEL); +} + +void nl80211_send_sched_scan(struct cfg80211_sched_scan_request *req, u32 cmd) +{ + struct sk_buff *msg; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + if (nl80211_prep_sched_scan_msg(msg, req, cmd) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(req->wiphy), msg, 0, + NL80211_MCGRP_SCAN, GFP_KERNEL); +} + +static bool nl80211_reg_change_event_fill(struct sk_buff *msg, + struct regulatory_request *request) +{ + /* Userspace can always count this one always being set */ + if (nla_put_u8(msg, NL80211_ATTR_REG_INITIATOR, request->initiator)) + goto nla_put_failure; + + if (request->alpha2[0] == '0' && request->alpha2[1] == '0') { + if (nla_put_u8(msg, NL80211_ATTR_REG_TYPE, + NL80211_REGDOM_TYPE_WORLD)) + goto nla_put_failure; + } else if (request->alpha2[0] == '9' && request->alpha2[1] == '9') { + if (nla_put_u8(msg, NL80211_ATTR_REG_TYPE, + NL80211_REGDOM_TYPE_CUSTOM_WORLD)) + goto nla_put_failure; + } else if ((request->alpha2[0] == '9' && request->alpha2[1] == '8') || + request->intersect) { + if (nla_put_u8(msg, NL80211_ATTR_REG_TYPE, + NL80211_REGDOM_TYPE_INTERSECTION)) + goto nla_put_failure; + } else { + if (nla_put_u8(msg, NL80211_ATTR_REG_TYPE, + NL80211_REGDOM_TYPE_COUNTRY) || + nla_put_string(msg, NL80211_ATTR_REG_ALPHA2, + request->alpha2)) + goto nla_put_failure; + } + + if (request->wiphy_idx != WIPHY_IDX_INVALID) { + struct wiphy *wiphy = wiphy_idx_to_wiphy(request->wiphy_idx); + + if (wiphy && + nla_put_u32(msg, NL80211_ATTR_WIPHY, request->wiphy_idx)) + goto nla_put_failure; + + if (wiphy && + wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED && + nla_put_flag(msg, NL80211_ATTR_WIPHY_SELF_MANAGED_REG)) + goto nla_put_failure; + } + + return true; + +nla_put_failure: + return false; +} + +/* + * This can happen on global regulatory changes or device specific settings + * based on custom regulatory domains. 
+ */ +void nl80211_common_reg_change_event(enum nl80211_commands cmd_id, + struct regulatory_request *request) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd_id); + if (!hdr) + goto nla_put_failure; + + if (!nl80211_reg_change_event_fill(msg, request)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + rcu_read_lock(); + genlmsg_multicast_allns(&nl80211_fam, msg, 0, + NL80211_MCGRP_REGULATORY, GFP_ATOMIC); + rcu_read_unlock(); + + return; + +nla_put_failure: + nlmsg_free(msg); +} + +static void nl80211_send_mlme_event(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *buf, size_t len, + enum nl80211_commands cmd, gfp_t gfp, + int uapsd_queues, const u8 *req_ies, + size_t req_ies_len, bool reconnect) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(100 + len + req_ies_len, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_FRAME, len, buf) || + (req_ies && + nla_put(msg, NL80211_ATTR_REQ_IE, req_ies_len, req_ies))) + goto nla_put_failure; + + if (reconnect && nla_put_flag(msg, NL80211_ATTR_RECONNECT_REQUESTED)) + goto nla_put_failure; + + if (uapsd_queues >= 0) { + struct nlattr *nla_wmm = + nla_nest_start_noflag(msg, NL80211_ATTR_STA_WME); + if (!nla_wmm) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_STA_WME_UAPSD_QUEUES, + uapsd_queues)) + goto nla_put_failure; + + nla_nest_end(msg, nla_wmm); + } + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_rx_auth(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *buf, + size_t len, gfp_t gfp) +{ + nl80211_send_mlme_event(rdev, netdev, buf, len, + NL80211_CMD_AUTHENTICATE, gfp, -1, NULL, 0, + false); +} + +void nl80211_send_rx_assoc(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_rx_assoc_resp *data) +{ + nl80211_send_mlme_event(rdev, netdev, data->buf, data->len, + NL80211_CMD_ASSOCIATE, GFP_KERNEL, + data->uapsd_queues, + data->req_ies, data->req_ies_len, false); +} + +void nl80211_send_deauth(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *buf, + size_t len, bool reconnect, gfp_t gfp) +{ + nl80211_send_mlme_event(rdev, netdev, buf, len, + NL80211_CMD_DEAUTHENTICATE, gfp, -1, NULL, 0, + reconnect); +} + +void nl80211_send_disassoc(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *buf, + size_t len, bool reconnect, gfp_t gfp) +{ + nl80211_send_mlme_event(rdev, netdev, buf, len, + NL80211_CMD_DISASSOCIATE, gfp, -1, NULL, 0, + reconnect); +} + +void cfg80211_rx_unprot_mlme_mgmt(struct net_device *dev, const u8 *buf, + size_t len) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + const struct ieee80211_mgmt *mgmt = (void *)buf; + u32 cmd; + + if (WARN_ON(len < 2)) + return; + + if (ieee80211_is_deauth(mgmt->frame_control)) { + cmd = NL80211_CMD_UNPROT_DEAUTHENTICATE; + } else if (ieee80211_is_disassoc(mgmt->frame_control)) { + cmd = NL80211_CMD_UNPROT_DISASSOCIATE; + } else if 
(ieee80211_is_beacon(mgmt->frame_control)) { + if (wdev->unprot_beacon_reported && + elapsed_jiffies_msecs(wdev->unprot_beacon_reported) < 10000) + return; + cmd = NL80211_CMD_UNPROT_BEACON; + wdev->unprot_beacon_reported = jiffies; + } else { + return; + } + + trace_cfg80211_rx_unprot_mlme_mgmt(dev, buf, len); + nl80211_send_mlme_event(rdev, dev, buf, len, cmd, GFP_ATOMIC, -1, + NULL, 0, false); +} +EXPORT_SYMBOL(cfg80211_rx_unprot_mlme_mgmt); + +static void nl80211_send_mlme_timeout(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int cmd, + const u8 *addr, gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put_flag(msg, NL80211_ATTR_TIMED_OUT) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_auth_timeout(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *addr, + gfp_t gfp) +{ + nl80211_send_mlme_timeout(rdev, netdev, NL80211_CMD_AUTHENTICATE, + addr, gfp); +} + +void nl80211_send_assoc_timeout(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *addr, + gfp_t gfp) +{ + nl80211_send_mlme_timeout(rdev, netdev, NL80211_CMD_ASSOCIATE, + addr, gfp); +} + +void nl80211_send_connect_result(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_connect_resp_params *cr, + gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + unsigned int link; + size_t link_info_size = 0; + const u8 *connected_addr = cr->valid_links ? + cr->ap_mld_addr : cr->links[0].bssid; + + if (cr->valid_links) { + for_each_valid_link(cr, link) { + /* Nested attribute header */ + link_info_size += NLA_HDRLEN; + /* Link ID */ + link_info_size += nla_total_size(sizeof(u8)); + link_info_size += cr->links[link].addr ? + nla_total_size(ETH_ALEN) : 0; + link_info_size += (cr->links[link].bssid || + cr->links[link].bss) ? + nla_total_size(ETH_ALEN) : 0; + } + } + + msg = nlmsg_new(100 + cr->req_ie_len + cr->resp_ie_len + + cr->fils.kek_len + cr->fils.pmk_len + + (cr->fils.pmkid ? WLAN_PMKID_LEN : 0) + link_info_size, + gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_CONNECT); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + (connected_addr && + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, connected_addr)) || + nla_put_u16(msg, NL80211_ATTR_STATUS_CODE, + cr->status < 0 ? 
WLAN_STATUS_UNSPECIFIED_FAILURE : + cr->status) || + (cr->status < 0 && + (nla_put_flag(msg, NL80211_ATTR_TIMED_OUT) || + nla_put_u32(msg, NL80211_ATTR_TIMEOUT_REASON, + cr->timeout_reason))) || + (cr->req_ie && + nla_put(msg, NL80211_ATTR_REQ_IE, cr->req_ie_len, cr->req_ie)) || + (cr->resp_ie && + nla_put(msg, NL80211_ATTR_RESP_IE, cr->resp_ie_len, + cr->resp_ie)) || + (cr->fils.update_erp_next_seq_num && + nla_put_u16(msg, NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM, + cr->fils.erp_next_seq_num)) || + (cr->status == WLAN_STATUS_SUCCESS && + ((cr->fils.kek && + nla_put(msg, NL80211_ATTR_FILS_KEK, cr->fils.kek_len, + cr->fils.kek)) || + (cr->fils.pmk && + nla_put(msg, NL80211_ATTR_PMK, cr->fils.pmk_len, cr->fils.pmk)) || + (cr->fils.pmkid && + nla_put(msg, NL80211_ATTR_PMKID, WLAN_PMKID_LEN, cr->fils.pmkid))))) + goto nla_put_failure; + + if (cr->valid_links) { + int i = 1; + struct nlattr *nested; + + nested = nla_nest_start(msg, NL80211_ATTR_MLO_LINKS); + if (!nested) + goto nla_put_failure; + + for_each_valid_link(cr, link) { + struct nlattr *nested_mlo_links; + const u8 *bssid = cr->links[link].bss ? + cr->links[link].bss->bssid : + cr->links[link].bssid; + + nested_mlo_links = nla_nest_start(msg, i); + if (!nested_mlo_links) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, link) || + (bssid && + nla_put(msg, NL80211_ATTR_BSSID, ETH_ALEN, bssid)) || + (cr->links[link].addr && + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, + cr->links[link].addr))) + goto nla_put_failure; + + nla_nest_end(msg, nested_mlo_links); + i++; + } + nla_nest_end(msg, nested); + } + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_roamed(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_roam_info *info, gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + size_t link_info_size = 0; + unsigned int link; + const u8 *connected_addr = info->ap_mld_addr ? + info->ap_mld_addr : + (info->links[0].bss ? + info->links[0].bss->bssid : + info->links[0].bssid); + + if (info->valid_links) { + for_each_valid_link(info, link) { + /* Nested attribute header */ + link_info_size += NLA_HDRLEN; + /* Link ID */ + link_info_size += nla_total_size(sizeof(u8)); + link_info_size += info->links[link].addr ? + nla_total_size(ETH_ALEN) : 0; + link_info_size += (info->links[link].bssid || + info->links[link].bss) ? + nla_total_size(ETH_ALEN) : 0; + } + } + + msg = nlmsg_new(100 + info->req_ie_len + info->resp_ie_len + + info->fils.kek_len + info->fils.pmk_len + + (info->fils.pmkid ? 
WLAN_PMKID_LEN : 0) + + link_info_size, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_ROAM); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, connected_addr) || + (info->req_ie && + nla_put(msg, NL80211_ATTR_REQ_IE, info->req_ie_len, + info->req_ie)) || + (info->resp_ie && + nla_put(msg, NL80211_ATTR_RESP_IE, info->resp_ie_len, + info->resp_ie)) || + (info->fils.update_erp_next_seq_num && + nla_put_u16(msg, NL80211_ATTR_FILS_ERP_NEXT_SEQ_NUM, + info->fils.erp_next_seq_num)) || + (info->fils.kek && + nla_put(msg, NL80211_ATTR_FILS_KEK, info->fils.kek_len, + info->fils.kek)) || + (info->fils.pmk && + nla_put(msg, NL80211_ATTR_PMK, info->fils.pmk_len, info->fils.pmk)) || + (info->fils.pmkid && + nla_put(msg, NL80211_ATTR_PMKID, WLAN_PMKID_LEN, info->fils.pmkid))) + goto nla_put_failure; + + if (info->valid_links) { + int i = 1; + struct nlattr *nested; + + nested = nla_nest_start(msg, NL80211_ATTR_MLO_LINKS); + if (!nested) + goto nla_put_failure; + + for_each_valid_link(info, link) { + struct nlattr *nested_mlo_links; + const u8 *bssid = info->links[link].bss ? + info->links[link].bss->bssid : + info->links[link].bssid; + + nested_mlo_links = nla_nest_start(msg, i); + if (!nested_mlo_links) + goto nla_put_failure; + + if (nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, link) || + (bssid && + nla_put(msg, NL80211_ATTR_BSSID, ETH_ALEN, bssid)) || + (info->links[link].addr && + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, + info->links[link].addr))) + goto nla_put_failure; + + nla_nest_end(msg, nested_mlo_links); + i++; + } + nla_nest_end(msg, nested); + } + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_port_authorized(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *bssid) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_PORT_AUTHORIZED); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, bssid)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, GFP_KERNEL); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_disconnected(struct cfg80211_registered_device *rdev, + struct net_device *netdev, u16 reason, + const u8 *ie, size_t ie_len, bool from_ap) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(100 + ie_len, GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_DISCONNECT); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + (reason && + nla_put_u16(msg, NL80211_ATTR_REASON_CODE, reason)) || + (from_ap && + nla_put_flag(msg, NL80211_ATTR_DISCONNECTED_BY_AP)) || + (ie && nla_put(msg, NL80211_ATTR_IE, ie_len, ie))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, 
GFP_KERNEL); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_ibss_bssid(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *bssid, + gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_JOIN_IBSS); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, bssid)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_notify_new_peer_candidate(struct net_device *dev, const u8 *addr, + const u8 *ie, u8 ie_len, + int sig_dbm, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_MESH_POINT)) + return; + + trace_cfg80211_notify_new_peer_candidate(dev, addr); + + msg = nlmsg_new(100 + ie_len, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_NEW_PEER_CANDIDATE); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr) || + (ie_len && ie && + nla_put(msg, NL80211_ATTR_IE, ie_len, ie)) || + (sig_dbm && + nla_put_u32(msg, NL80211_ATTR_RX_SIGNAL_DBM, sig_dbm))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_notify_new_peer_candidate); + +void nl80211_michael_mic_failure(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *addr, + enum nl80211_key_type key_type, int key_id, + const u8 *tsc, gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_MICHAEL_MIC_FAILURE); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + (addr && nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr)) || + nla_put_u32(msg, NL80211_ATTR_KEY_TYPE, key_type) || + (key_id != -1 && + nla_put_u8(msg, NL80211_ATTR_KEY_IDX, key_id)) || + (tsc && nla_put(msg, NL80211_ATTR_KEY_SEQ, 6, tsc))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void nl80211_send_beacon_hint_event(struct wiphy *wiphy, + struct ieee80211_channel *channel_before, + struct ieee80211_channel *channel_after) +{ + struct sk_buff *msg; + void *hdr; + struct nlattr *nl_freq; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_REG_BEACON_HINT); + if (!hdr) { + nlmsg_free(msg); + return; + } + + /* + * Since we are applying the beacon hint to a wiphy we know its + * wiphy_idx is valid + */ + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, get_wiphy_idx(wiphy))) + goto 
nla_put_failure; + + /* Before */ + nl_freq = nla_nest_start_noflag(msg, NL80211_ATTR_FREQ_BEFORE); + if (!nl_freq) + goto nla_put_failure; + + if (nl80211_msg_put_channel(msg, wiphy, channel_before, false)) + goto nla_put_failure; + nla_nest_end(msg, nl_freq); + + /* After */ + nl_freq = nla_nest_start_noflag(msg, NL80211_ATTR_FREQ_AFTER); + if (!nl_freq) + goto nla_put_failure; + + if (nl80211_msg_put_channel(msg, wiphy, channel_after, false)) + goto nla_put_failure; + nla_nest_end(msg, nl_freq); + + genlmsg_end(msg, hdr); + + rcu_read_lock(); + genlmsg_multicast_allns(&nl80211_fam, msg, 0, + NL80211_MCGRP_REGULATORY, GFP_ATOMIC); + rcu_read_unlock(); + + return; + +nla_put_failure: + nlmsg_free(msg); +} + +static void nl80211_send_remain_on_chan_event( + int cmd, struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan, + unsigned int duration, gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (wdev->netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ, chan->center_freq) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_CHANNEL_TYPE, + NL80211_CHAN_NO_HT) || + nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD)) + goto nla_put_failure; + + if (cmd == NL80211_CMD_REMAIN_ON_CHANNEL && + nla_put_u32(msg, NL80211_ATTR_DURATION, duration)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_assoc_comeback(struct net_device *netdev, + const u8 *ap_addr, u32 timeout) +{ + struct wireless_dev *wdev = netdev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_assoc_comeback(wdev, ap_addr, timeout); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_ASSOC_COMEBACK); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, ap_addr) || + nla_put_u32(msg, NL80211_ATTR_TIMEOUT, timeout)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, GFP_KERNEL); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_assoc_comeback); + +void cfg80211_ready_on_channel(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan, + unsigned int duration, gfp_t gfp) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_ready_on_channel(wdev, cookie, chan, duration); + nl80211_send_remain_on_chan_event(NL80211_CMD_REMAIN_ON_CHANNEL, + rdev, wdev, cookie, chan, + duration, gfp); +} +EXPORT_SYMBOL(cfg80211_ready_on_channel); + +void cfg80211_remain_on_channel_expired(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan, + gfp_t gfp) +{ + struct 
wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_ready_on_channel_expired(wdev, cookie, chan); + nl80211_send_remain_on_chan_event(NL80211_CMD_CANCEL_REMAIN_ON_CHANNEL, + rdev, wdev, cookie, chan, 0, gfp); +} +EXPORT_SYMBOL(cfg80211_remain_on_channel_expired); + +void cfg80211_tx_mgmt_expired(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan, + gfp_t gfp) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_tx_mgmt_expired(wdev, cookie, chan); + nl80211_send_remain_on_chan_event(NL80211_CMD_FRAME_WAIT_CANCEL, + rdev, wdev, cookie, chan, 0, gfp); +} +EXPORT_SYMBOL(cfg80211_tx_mgmt_expired); + +void cfg80211_new_sta(struct net_device *dev, const u8 *mac_addr, + struct station_info *sinfo, gfp_t gfp) +{ + struct wiphy *wiphy = dev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + + trace_cfg80211_new_sta(dev, mac_addr, sinfo); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + if (nl80211_send_station(msg, NL80211_CMD_NEW_STATION, 0, 0, 0, + rdev, dev, mac_addr, sinfo) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); +} +EXPORT_SYMBOL(cfg80211_new_sta); + +void cfg80211_del_sta_sinfo(struct net_device *dev, const u8 *mac_addr, + struct station_info *sinfo, gfp_t gfp) +{ + struct wiphy *wiphy = dev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + struct station_info empty_sinfo = {}; + + if (!sinfo) + sinfo = &empty_sinfo; + + trace_cfg80211_del_sta(dev, mac_addr); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) { + cfg80211_sinfo_release_content(sinfo); + return; + } + + if (nl80211_send_station(msg, NL80211_CMD_DEL_STATION, 0, 0, 0, + rdev, dev, mac_addr, sinfo) < 0) { + nlmsg_free(msg); + return; + } + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); +} +EXPORT_SYMBOL(cfg80211_del_sta_sinfo); + +void cfg80211_conn_failed(struct net_device *dev, const u8 *mac_addr, + enum nl80211_connect_failed_reason reason, + gfp_t gfp) +{ + struct wiphy *wiphy = dev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_GOODSIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_CONN_FAILED); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, mac_addr) || + nla_put_u32(msg, NL80211_ATTR_CONN_FAILED_REASON, reason)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_conn_failed); + +static bool __nl80211_unexpected_frame(struct net_device *dev, u8 cmd, + const u8 *addr, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + u32 nlportid = READ_ONCE(wdev->ap_unexpected_nlportid); + + if (!nlportid) + return false; + + msg = nlmsg_new(100, gfp); + if (!msg) + return true; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) { + 
nlmsg_free(msg); + return true; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, nlportid); + return true; + + nla_put_failure: + nlmsg_free(msg); + return true; +} + +bool cfg80211_rx_spurious_frame(struct net_device *dev, + const u8 *addr, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + bool ret; + + trace_cfg80211_rx_spurious_frame(dev, addr); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO)) { + trace_cfg80211_return_bool(false); + return false; + } + ret = __nl80211_unexpected_frame(dev, NL80211_CMD_UNEXPECTED_FRAME, + addr, gfp); + trace_cfg80211_return_bool(ret); + return ret; +} +EXPORT_SYMBOL(cfg80211_rx_spurious_frame); + +bool cfg80211_rx_unexpected_4addr_frame(struct net_device *dev, + const u8 *addr, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + bool ret; + + trace_cfg80211_rx_unexpected_4addr_frame(dev, addr); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_AP && + wdev->iftype != NL80211_IFTYPE_P2P_GO && + wdev->iftype != NL80211_IFTYPE_AP_VLAN)) { + trace_cfg80211_return_bool(false); + return false; + } + ret = __nl80211_unexpected_frame(dev, + NL80211_CMD_UNEXPECTED_4ADDR_FRAME, + addr, gfp); + trace_cfg80211_return_bool(ret); + return ret; +} +EXPORT_SYMBOL(cfg80211_rx_unexpected_4addr_frame); + +int nl80211_send_mgmt(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u32 nlportid, + struct cfg80211_rx_info *info, gfp_t gfp) +{ + struct net_device *netdev = wdev->netdev; + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(100 + info->len, gfp); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_FRAME); + if (!hdr) { + nlmsg_free(msg); + return -ENOMEM; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD) || + (info->have_link_id && + nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, info->link_id)) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ, KHZ_TO_MHZ(info->freq)) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ_OFFSET, info->freq % 1000) || + (info->sig_dbm && + nla_put_u32(msg, NL80211_ATTR_RX_SIGNAL_DBM, info->sig_dbm)) || + nla_put(msg, NL80211_ATTR_FRAME, info->len, info->buf) || + (info->flags && + nla_put_u32(msg, NL80211_ATTR_RXMGMT_FLAGS, info->flags)) || + (info->rx_tstamp && nla_put_u64_64bit(msg, + NL80211_ATTR_RX_HW_TIMESTAMP, + info->rx_tstamp, + NL80211_ATTR_PAD)) || + (info->ack_tstamp && nla_put_u64_64bit(msg, + NL80211_ATTR_TX_HW_TIMESTAMP, + info->ack_tstamp, + NL80211_ATTR_PAD))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, nlportid); + + nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +static void nl80211_frame_tx_status(struct wireless_dev *wdev, + struct cfg80211_tx_status *status, + gfp_t gfp, enum nl80211_commands command) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct net_device *netdev = wdev->netdev; + struct sk_buff *msg; + void *hdr; + + if (command == NL80211_CMD_FRAME_TX_STATUS) + trace_cfg80211_mgmt_tx_status(wdev, status->cookie, + status->ack); + else + 
trace_cfg80211_control_port_tx_status(wdev, status->cookie, + status->ack); + + msg = nlmsg_new(100 + status->len, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, command); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + netdev->ifindex)) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD) || + nla_put(msg, NL80211_ATTR_FRAME, status->len, status->buf) || + nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, status->cookie, + NL80211_ATTR_PAD) || + (status->ack && nla_put_flag(msg, NL80211_ATTR_ACK)) || + (status->tx_tstamp && + nla_put_u64_64bit(msg, NL80211_ATTR_TX_HW_TIMESTAMP, + status->tx_tstamp, NL80211_ATTR_PAD)) || + (status->ack_tstamp && + nla_put_u64_64bit(msg, NL80211_ATTR_RX_HW_TIMESTAMP, + status->ack_tstamp, NL80211_ATTR_PAD))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + +nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_control_port_tx_status(struct wireless_dev *wdev, u64 cookie, + const u8 *buf, size_t len, bool ack, + gfp_t gfp) +{ + struct cfg80211_tx_status status = { + .cookie = cookie, + .buf = buf, + .len = len, + .ack = ack + }; + + nl80211_frame_tx_status(wdev, &status, gfp, + NL80211_CMD_CONTROL_PORT_FRAME_TX_STATUS); +} +EXPORT_SYMBOL(cfg80211_control_port_tx_status); + +void cfg80211_mgmt_tx_status_ext(struct wireless_dev *wdev, + struct cfg80211_tx_status *status, gfp_t gfp) +{ + nl80211_frame_tx_status(wdev, status, gfp, NL80211_CMD_FRAME_TX_STATUS); +} +EXPORT_SYMBOL(cfg80211_mgmt_tx_status_ext); + +static int __nl80211_rx_control_port(struct net_device *dev, + struct sk_buff *skb, + bool unencrypted, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct ethhdr *ehdr = eth_hdr(skb); + const u8 *addr = ehdr->h_source; + u16 proto = be16_to_cpu(skb->protocol); + struct sk_buff *msg; + void *hdr; + struct nlattr *frame; + + u32 nlportid = READ_ONCE(wdev->conn_owner_nlportid); + + if (!nlportid) + return -ENOENT; + + msg = nlmsg_new(100 + skb->len, gfp); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_CONTROL_PORT_FRAME); + if (!hdr) { + nlmsg_free(msg); + return -ENOBUFS; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr) || + nla_put_u16(msg, NL80211_ATTR_CONTROL_PORT_ETHERTYPE, proto) || + (unencrypted && nla_put_flag(msg, + NL80211_ATTR_CONTROL_PORT_NO_ENCRYPT))) + goto nla_put_failure; + + frame = nla_reserve(msg, NL80211_ATTR_FRAME, skb->len); + if (!frame) + goto nla_put_failure; + + skb_copy_bits(skb, 0, nla_data(frame), skb->len); + genlmsg_end(msg, hdr); + + return genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, nlportid); + + nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} + +bool cfg80211_rx_control_port(struct net_device *dev, + struct sk_buff *skb, bool unencrypted) +{ + int ret; + + trace_cfg80211_rx_control_port(dev, skb, unencrypted); + ret = __nl80211_rx_control_port(dev, skb, unencrypted, GFP_ATOMIC); + trace_cfg80211_return_bool(ret == 0); + return ret == 0; +} +EXPORT_SYMBOL(cfg80211_rx_control_port); + +static struct sk_buff 
*cfg80211_prepare_cqm(struct net_device *dev, + const char *mac, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + void **cb; + + if (!msg) + return NULL; + + cb = (void **)msg->cb; + + cb[0] = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_NOTIFY_CQM); + if (!cb[0]) { + nlmsg_free(msg); + return NULL; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + if (mac && nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, mac)) + goto nla_put_failure; + + cb[1] = nla_nest_start_noflag(msg, NL80211_ATTR_CQM); + if (!cb[1]) + goto nla_put_failure; + + cb[2] = rdev; + + return msg; + nla_put_failure: + nlmsg_free(msg); + return NULL; +} + +static void cfg80211_send_cqm(struct sk_buff *msg, gfp_t gfp) +{ + void **cb = (void **)msg->cb; + struct cfg80211_registered_device *rdev = cb[2]; + + nla_nest_end(msg, cb[1]); + genlmsg_end(msg, cb[0]); + + memset(msg->cb, 0, sizeof(msg->cb)); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); +} + +void cfg80211_cqm_rssi_notify(struct net_device *dev, + enum nl80211_cqm_rssi_threshold_event rssi_event, + s32 rssi_level, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_cqm_config *cqm_config; + + trace_cfg80211_cqm_rssi_notify(dev, rssi_event, rssi_level); + + if (WARN_ON(rssi_event != NL80211_CQM_RSSI_THRESHOLD_EVENT_LOW && + rssi_event != NL80211_CQM_RSSI_THRESHOLD_EVENT_HIGH)) + return; + + rcu_read_lock(); + cqm_config = rcu_dereference(wdev->cqm_config); + if (cqm_config) { + cqm_config->last_rssi_event_value = rssi_level; + cqm_config->last_rssi_event_type = rssi_event; + wiphy_work_queue(wdev->wiphy, &wdev->cqm_rssi_work); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(cfg80211_cqm_rssi_notify); + +void cfg80211_cqm_rssi_notify_work(struct wiphy *wiphy, struct wiphy_work *work) +{ + struct wireless_dev *wdev = container_of(work, struct wireless_dev, + cqm_rssi_work); + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + enum nl80211_cqm_rssi_threshold_event rssi_event; + struct cfg80211_cqm_config *cqm_config; + struct sk_buff *msg; + s32 rssi_level; + + wdev_lock(wdev); + cqm_config = rcu_dereference_protected(wdev->cqm_config, + lockdep_is_held(&wdev->mtx)); + if (!cqm_config) + goto unlock; + + if (cqm_config->use_range_api) + cfg80211_cqm_rssi_update(rdev, wdev->netdev, cqm_config); + + rssi_level = cqm_config->last_rssi_event_value; + rssi_event = cqm_config->last_rssi_event_type; + + msg = cfg80211_prepare_cqm(wdev->netdev, NULL, GFP_KERNEL); + if (!msg) + goto unlock; + + if (nla_put_u32(msg, NL80211_ATTR_CQM_RSSI_THRESHOLD_EVENT, + rssi_event)) + goto nla_put_failure; + + if (rssi_level && nla_put_s32(msg, NL80211_ATTR_CQM_RSSI_LEVEL, + rssi_level)) + goto nla_put_failure; + + cfg80211_send_cqm(msg, GFP_KERNEL); + + goto unlock; + + nla_put_failure: + nlmsg_free(msg); + unlock: + wdev_unlock(wdev); +} + +void cfg80211_cqm_txe_notify(struct net_device *dev, + const u8 *peer, u32 num_packets, + u32 rate, u32 intvl, gfp_t gfp) +{ + struct sk_buff *msg; + + msg = cfg80211_prepare_cqm(dev, peer, gfp); + if (!msg) + return; + + if (nla_put_u32(msg, NL80211_ATTR_CQM_TXE_PKTS, num_packets)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_CQM_TXE_RATE, rate)) + goto nla_put_failure; + + if (nla_put_u32(msg, 
NL80211_ATTR_CQM_TXE_INTVL, intvl)) + goto nla_put_failure; + + cfg80211_send_cqm(msg, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_cqm_txe_notify); + +void cfg80211_cqm_pktloss_notify(struct net_device *dev, + const u8 *peer, u32 num_packets, gfp_t gfp) +{ + struct sk_buff *msg; + + trace_cfg80211_cqm_pktloss_notify(dev, peer, num_packets); + + msg = cfg80211_prepare_cqm(dev, peer, gfp); + if (!msg) + return; + + if (nla_put_u32(msg, NL80211_ATTR_CQM_PKT_LOSS_EVENT, num_packets)) + goto nla_put_failure; + + cfg80211_send_cqm(msg, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_cqm_pktloss_notify); + +void cfg80211_cqm_beacon_loss_notify(struct net_device *dev, gfp_t gfp) +{ + struct sk_buff *msg; + + msg = cfg80211_prepare_cqm(dev, NULL, gfp); + if (!msg) + return; + + if (nla_put_flag(msg, NL80211_ATTR_CQM_BEACON_LOSS_EVENT)) + goto nla_put_failure; + + cfg80211_send_cqm(msg, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_cqm_beacon_loss_notify); + +static void nl80211_gtk_rekey_notify(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *bssid, + const u8 *replay_ctr, gfp_t gfp) +{ + struct sk_buff *msg; + struct nlattr *rekey_attr; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_SET_REKEY_OFFLOAD); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, bssid)) + goto nla_put_failure; + + rekey_attr = nla_nest_start_noflag(msg, NL80211_ATTR_REKEY_DATA); + if (!rekey_attr) + goto nla_put_failure; + + if (nla_put(msg, NL80211_REKEY_DATA_REPLAY_CTR, + NL80211_REPLAY_CTR_LEN, replay_ctr)) + goto nla_put_failure; + + nla_nest_end(msg, rekey_attr); + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_gtk_rekey_notify(struct net_device *dev, const u8 *bssid, + const u8 *replay_ctr, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_gtk_rekey_notify(dev, bssid); + nl80211_gtk_rekey_notify(rdev, dev, bssid, replay_ctr, gfp); +} +EXPORT_SYMBOL(cfg80211_gtk_rekey_notify); + +static void +nl80211_pmksa_candidate_notify(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int index, + const u8 *bssid, bool preauth, gfp_t gfp) +{ + struct sk_buff *msg; + struct nlattr *attr; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_PMKSA_CANDIDATE); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex)) + goto nla_put_failure; + + attr = nla_nest_start_noflag(msg, NL80211_ATTR_PMKSA_CANDIDATE); + if (!attr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_PMKSA_CANDIDATE_INDEX, index) || + nla_put(msg, NL80211_PMKSA_CANDIDATE_BSSID, ETH_ALEN, bssid) || + (preauth && + nla_put_flag(msg, NL80211_PMKSA_CANDIDATE_PREAUTH))) + goto nla_put_failure; + + nla_nest_end(msg, attr); + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, 
wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_pmksa_candidate_notify(struct net_device *dev, int index, + const u8 *bssid, bool preauth, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + trace_cfg80211_pmksa_candidate_notify(dev, index, bssid, preauth); + nl80211_pmksa_candidate_notify(rdev, dev, index, bssid, preauth, gfp); +} +EXPORT_SYMBOL(cfg80211_pmksa_candidate_notify); + +static void nl80211_ch_switch_notify(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + unsigned int link_id, + struct cfg80211_chan_def *chandef, + gfp_t gfp, + enum nl80211_commands notif, + u8 count, bool quiet) +{ + struct wireless_dev *wdev = netdev->ieee80211_ptr; + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, notif); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex)) + goto nla_put_failure; + + if (wdev->valid_links && + nla_put_u8(msg, NL80211_ATTR_MLO_LINK_ID, link_id)) + goto nla_put_failure; + + if (nl80211_send_chandef(msg, chandef)) + goto nla_put_failure; + + if (notif == NL80211_CMD_CH_SWITCH_STARTED_NOTIFY) { + if (nla_put_u32(msg, NL80211_ATTR_CH_SWITCH_COUNT, count)) + goto nla_put_failure; + if (quiet && + nla_put_flag(msg, NL80211_ATTR_CH_SWITCH_BLOCK_TX)) + goto nla_put_failure; + } + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_ch_switch_notify(struct net_device *dev, + struct cfg80211_chan_def *chandef, + unsigned int link_id) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + ASSERT_WDEV_LOCK(wdev); + WARN_INVALID_LINK_ID(wdev, link_id); + + trace_cfg80211_ch_switch_notify(dev, chandef, link_id); + + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + if (!WARN_ON(!wdev->links[link_id].client.current_bss)) + cfg80211_update_assoc_bss_entry(wdev, link_id, + chandef->chan); + break; + case NL80211_IFTYPE_MESH_POINT: + wdev->u.mesh.chandef = *chandef; + wdev->u.mesh.preset_chandef = *chandef; + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + wdev->links[link_id].ap.chandef = *chandef; + break; + case NL80211_IFTYPE_ADHOC: + wdev->u.ibss.chandef = *chandef; + break; + default: + WARN_ON(1); + break; + } + + cfg80211_sched_dfs_chan_update(rdev); + + nl80211_ch_switch_notify(rdev, dev, link_id, chandef, GFP_KERNEL, + NL80211_CMD_CH_SWITCH_NOTIFY, 0, false); +} +EXPORT_SYMBOL(cfg80211_ch_switch_notify); + +void cfg80211_ch_switch_started_notify(struct net_device *dev, + struct cfg80211_chan_def *chandef, + unsigned int link_id, u8 count, + bool quiet) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + ASSERT_WDEV_LOCK(wdev); + WARN_INVALID_LINK_ID(wdev, link_id); + + trace_cfg80211_ch_switch_started_notify(dev, chandef, link_id); + + nl80211_ch_switch_notify(rdev, dev, link_id, chandef, GFP_KERNEL, + NL80211_CMD_CH_SWITCH_STARTED_NOTIFY, + count, quiet); +} 
+EXPORT_SYMBOL(cfg80211_ch_switch_started_notify); + +int cfg80211_bss_color_notify(struct net_device *dev, gfp_t gfp, + enum nl80211_commands cmd, u8 count, + u64 color_bitmap) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + ASSERT_WDEV_LOCK(wdev); + + trace_cfg80211_bss_color_notify(dev, cmd, count, color_bitmap); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, 0, 0, 0, cmd); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + if (cmd == NL80211_CMD_COLOR_CHANGE_STARTED && + nla_put_u32(msg, NL80211_ATTR_COLOR_CHANGE_COUNT, count)) + goto nla_put_failure; + + if (cmd == NL80211_CMD_OBSS_COLOR_COLLISION && + nla_put_u64_64bit(msg, NL80211_ATTR_OBSS_COLOR_BITMAP, + color_bitmap, NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + return genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), + msg, 0, NL80211_MCGRP_MLME, gfp); + +nla_put_failure: + nlmsg_free(msg); + return -EINVAL; +} +EXPORT_SYMBOL(cfg80211_bss_color_notify); + +void +nl80211_radar_notify(struct cfg80211_registered_device *rdev, + const struct cfg80211_chan_def *chandef, + enum nl80211_radar_event event, + struct net_device *netdev, gfp_t gfp) +{ + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_RADAR_DETECT); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx)) + goto nla_put_failure; + + /* NOP and radar events don't need a netdev parameter */ + if (netdev) { + struct wireless_dev *wdev = netdev->ieee80211_ptr; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + } + + if (nla_put_u32(msg, NL80211_ATTR_RADAR_EVENT, event)) + goto nla_put_failure; + + if (nl80211_send_chandef(msg, chandef)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} + +void cfg80211_sta_opmode_change_notify(struct net_device *dev, const u8 *mac, + struct sta_opmode_info *sta_opmode, + gfp_t gfp) +{ + struct sk_buff *msg; + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + void *hdr; + + if (WARN_ON(!mac)) + return; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_STA_OPMODE_CHANGED); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx)) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex)) + goto nla_put_failure; + + if (nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, mac)) + goto nla_put_failure; + + if ((sta_opmode->changed & STA_OPMODE_SMPS_MODE_CHANGED) && + nla_put_u8(msg, NL80211_ATTR_SMPS_MODE, sta_opmode->smps_mode)) + goto nla_put_failure; + + if ((sta_opmode->changed & STA_OPMODE_MAX_BW_CHANGED) && + nla_put_u32(msg, NL80211_ATTR_CHANNEL_WIDTH, sta_opmode->bw)) + goto nla_put_failure; + + if ((sta_opmode->changed & STA_OPMODE_N_SS_CHANGED) && + nla_put_u8(msg, NL80211_ATTR_NSS, 
sta_opmode->rx_nss)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + + return; + +nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_sta_opmode_change_notify); + +void cfg80211_probe_status(struct net_device *dev, const u8 *addr, + u64 cookie, bool acked, s32 ack_signal, + bool is_valid_ack_signal, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_probe_status(dev, addr, cookie, acked); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_PROBE_CLIENT); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, addr) || + nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, cookie, + NL80211_ATTR_PAD) || + (acked && nla_put_flag(msg, NL80211_ATTR_ACK)) || + (is_valid_ack_signal && nla_put_s32(msg, NL80211_ATTR_ACK_SIGNAL, + ack_signal))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_probe_status); + +void cfg80211_report_obss_beacon_khz(struct wiphy *wiphy, const u8 *frame, + size_t len, int freq, int sig_dbm) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + struct cfg80211_beacon_registration *reg; + + trace_cfg80211_report_obss_beacon(wiphy, frame, len, freq, sig_dbm); + + spin_lock_bh(&rdev->beacon_registrations_lock); + list_for_each_entry(reg, &rdev->beacon_registrations, list) { + msg = nlmsg_new(len + 100, GFP_ATOMIC); + if (!msg) { + spin_unlock_bh(&rdev->beacon_registrations_lock); + return; + } + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_FRAME); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + (freq && + (nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ, + KHZ_TO_MHZ(freq)) || + nla_put_u32(msg, NL80211_ATTR_WIPHY_FREQ_OFFSET, + freq % 1000))) || + (sig_dbm && + nla_put_u32(msg, NL80211_ATTR_RX_SIGNAL_DBM, sig_dbm)) || + nla_put(msg, NL80211_ATTR_FRAME, len, frame)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, reg->nlportid); + } + spin_unlock_bh(&rdev->beacon_registrations_lock); + return; + + nla_put_failure: + spin_unlock_bh(&rdev->beacon_registrations_lock); + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_report_obss_beacon_khz); + +#ifdef CONFIG_PM +static int cfg80211_net_detect_results(struct sk_buff *msg, + struct cfg80211_wowlan_wakeup *wakeup) +{ + struct cfg80211_wowlan_nd_info *nd = wakeup->net_detect; + struct nlattr *nl_results, *nl_match, *nl_freqs; + int i, j; + + nl_results = nla_nest_start_noflag(msg, + NL80211_WOWLAN_TRIG_NET_DETECT_RESULTS); + if (!nl_results) + return -EMSGSIZE; + + for (i = 0; i < nd->n_matches; i++) { + struct cfg80211_wowlan_nd_match *match = nd->matches[i]; + + nl_match = nla_nest_start_noflag(msg, i); + if (!nl_match) + break; + + /* The SSID attribute is optional in nl80211, but for + * simplicity reasons it's always present in the + * cfg80211 structure. If a driver can't pass the + * SSID, that needs to be changed. 
A zero length SSID + * is still a valid SSID (wildcard), so it cannot be + * used for this purpose. + */ + if (nla_put(msg, NL80211_ATTR_SSID, match->ssid.ssid_len, + match->ssid.ssid)) { + nla_nest_cancel(msg, nl_match); + goto out; + } + + if (match->n_channels) { + nl_freqs = nla_nest_start_noflag(msg, + NL80211_ATTR_SCAN_FREQUENCIES); + if (!nl_freqs) { + nla_nest_cancel(msg, nl_match); + goto out; + } + + for (j = 0; j < match->n_channels; j++) { + if (nla_put_u32(msg, j, match->channels[j])) { + nla_nest_cancel(msg, nl_freqs); + nla_nest_cancel(msg, nl_match); + goto out; + } + } + + nla_nest_end(msg, nl_freqs); + } + + nla_nest_end(msg, nl_match); + } + +out: + nla_nest_end(msg, nl_results); + return 0; +} + +void cfg80211_report_wowlan_wakeup(struct wireless_dev *wdev, + struct cfg80211_wowlan_wakeup *wakeup, + gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + int size = 200; + + trace_cfg80211_report_wowlan_wakeup(wdev->wiphy, wdev, wakeup); + + if (wakeup) + size += wakeup->packet_present_len; + + msg = nlmsg_new(size, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_SET_WOWLAN); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto free_msg; + + if (wdev->netdev && nla_put_u32(msg, NL80211_ATTR_IFINDEX, + wdev->netdev->ifindex)) + goto free_msg; + + if (wakeup) { + struct nlattr *reasons; + + reasons = nla_nest_start_noflag(msg, + NL80211_ATTR_WOWLAN_TRIGGERS); + if (!reasons) + goto free_msg; + + if (wakeup->disconnect && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_DISCONNECT)) + goto free_msg; + if (wakeup->magic_pkt && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_MAGIC_PKT)) + goto free_msg; + if (wakeup->gtk_rekey_failure && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_GTK_REKEY_FAILURE)) + goto free_msg; + if (wakeup->eap_identity_req && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_EAP_IDENT_REQUEST)) + goto free_msg; + if (wakeup->four_way_handshake && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_4WAY_HANDSHAKE)) + goto free_msg; + if (wakeup->rfkill_release && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_RFKILL_RELEASE)) + goto free_msg; + + if (wakeup->pattern_idx >= 0 && + nla_put_u32(msg, NL80211_WOWLAN_TRIG_PKT_PATTERN, + wakeup->pattern_idx)) + goto free_msg; + + if (wakeup->tcp_match && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_WAKEUP_TCP_MATCH)) + goto free_msg; + + if (wakeup->tcp_connlost && + nla_put_flag(msg, NL80211_WOWLAN_TRIG_WAKEUP_TCP_CONNLOST)) + goto free_msg; + + if (wakeup->tcp_nomoretokens && + nla_put_flag(msg, + NL80211_WOWLAN_TRIG_WAKEUP_TCP_NOMORETOKENS)) + goto free_msg; + + if (wakeup->packet) { + u32 pkt_attr = NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211; + u32 len_attr = NL80211_WOWLAN_TRIG_WAKEUP_PKT_80211_LEN; + + if (!wakeup->packet_80211) { + pkt_attr = + NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023; + len_attr = + NL80211_WOWLAN_TRIG_WAKEUP_PKT_8023_LEN; + } + + if (wakeup->packet_len && + nla_put_u32(msg, len_attr, wakeup->packet_len)) + goto free_msg; + + if (nla_put(msg, pkt_attr, wakeup->packet_present_len, + wakeup->packet)) + goto free_msg; + } + + if (wakeup->net_detect && + cfg80211_net_detect_results(msg, wakeup)) + goto free_msg; + + nla_nest_end(msg, reasons); + } + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + free_msg: + nlmsg_free(msg); +} 
+EXPORT_SYMBOL(cfg80211_report_wowlan_wakeup); +#endif + +void cfg80211_tdls_oper_request(struct net_device *dev, const u8 *peer, + enum nl80211_tdls_operation oper, + u16 reason_code, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_tdls_oper_request(wdev->wiphy, dev, peer, oper, + reason_code); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_TDLS_OPER); + if (!hdr) { + nlmsg_free(msg); + return; + } + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_u8(msg, NL80211_ATTR_TDLS_OPERATION, oper) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, peer) || + (reason_code > 0 && + nla_put_u16(msg, NL80211_ATTR_REASON_CODE, reason_code))) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_tdls_oper_request); + +static int nl80211_netlink_notify(struct notifier_block * nb, + unsigned long state, + void *_notify) +{ + struct netlink_notify *notify = _notify; + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + struct cfg80211_beacon_registration *reg, *tmp; + + if (state != NETLINK_URELEASE || notify->protocol != NETLINK_GENERIC) + return NOTIFY_DONE; + + rcu_read_lock(); + + list_for_each_entry_rcu(rdev, &cfg80211_rdev_list, list) { + struct cfg80211_sched_scan_request *sched_scan_req; + + list_for_each_entry_rcu(sched_scan_req, + &rdev->sched_scan_req_list, + list) { + if (sched_scan_req->owner_nlportid == notify->portid) { + sched_scan_req->nl_owner_dead = true; + schedule_work(&rdev->sched_scan_stop_wk); + } + } + + list_for_each_entry_rcu(wdev, &rdev->wiphy.wdev_list, list) { + cfg80211_mlme_unregister_socket(wdev, notify->portid); + + if (wdev->owner_nlportid == notify->portid) { + wdev->nl_owner_dead = true; + schedule_work(&rdev->destroy_work); + } else if (wdev->conn_owner_nlportid == notify->portid) { + schedule_work(&wdev->disconnect_wk); + } + + cfg80211_release_pmsr(wdev, notify->portid); + } + + spin_lock_bh(&rdev->beacon_registrations_lock); + list_for_each_entry_safe(reg, tmp, &rdev->beacon_registrations, + list) { + if (reg->nlportid == notify->portid) { + list_del(&reg->list); + kfree(reg); + break; + } + } + spin_unlock_bh(&rdev->beacon_registrations_lock); + } + + rcu_read_unlock(); + + /* + * It is possible that the user space process that is controlling the + * indoor setting disappeared, so notify the regulatory core.
+ */ + regulatory_netlink_notify(notify->portid); + return NOTIFY_OK; +} + +static struct notifier_block nl80211_netlink_notifier = { + .notifier_call = nl80211_netlink_notify, +}; + +void cfg80211_ft_event(struct net_device *netdev, + struct cfg80211_ft_event_params *ft_event) +{ + struct wiphy *wiphy = netdev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_ft_event(wiphy, netdev, ft_event); + + if (!ft_event->target_ap) + return; + + msg = nlmsg_new(100 + ft_event->ies_len + ft_event->ric_ies_len, + GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_FT_EVENT); + if (!hdr) + goto out; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, ft_event->target_ap)) + goto out; + + if (ft_event->ies && + nla_put(msg, NL80211_ATTR_IE, ft_event->ies_len, ft_event->ies)) + goto out; + if (ft_event->ric_ies && + nla_put(msg, NL80211_ATTR_IE_RIC, ft_event->ric_ies_len, + ft_event->ric_ies)) + goto out; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, GFP_KERNEL); + return; + out: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_ft_event); + +void cfg80211_crit_proto_stopped(struct wireless_dev *wdev, gfp_t gfp) +{ + struct cfg80211_registered_device *rdev; + struct sk_buff *msg; + void *hdr; + u32 nlportid; + + rdev = wiphy_to_rdev(wdev->wiphy); + if (!rdev->crit_proto_nlportid) + return; + + nlportid = rdev->crit_proto_nlportid; + rdev->crit_proto_nlportid = 0; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_CRIT_PROTOCOL_STOP); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, nlportid); + return; + + nla_put_failure: + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_crit_proto_stopped); + +void nl80211_send_ap_stopped(struct wireless_dev *wdev) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_STOP_AP); + if (!hdr) + goto out; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, wdev->netdev->ifindex) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto out; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(wiphy), msg, 0, + NL80211_MCGRP_MLME, GFP_KERNEL); + return; + out: + nlmsg_free(msg); +} + +int cfg80211_external_auth_request(struct net_device *dev, + struct cfg80211_external_auth_params *params, + gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + + if (!wdev->conn_owner_nlportid) + return -EINVAL; + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return -ENOMEM; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_EXTERNAL_AUTH); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, 
rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, dev->ifindex) || + nla_put_u32(msg, NL80211_ATTR_AKM_SUITES, params->key_mgmt_suite) || + nla_put_u32(msg, NL80211_ATTR_EXTERNAL_AUTH_ACTION, + params->action) || + nla_put(msg, NL80211_ATTR_BSSID, ETH_ALEN, params->bssid) || + nla_put(msg, NL80211_ATTR_SSID, params->ssid.ssid_len, + params->ssid.ssid)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + genlmsg_unicast(wiphy_net(&rdev->wiphy), msg, + wdev->conn_owner_nlportid); + return 0; + + nla_put_failure: + nlmsg_free(msg); + return -ENOBUFS; +} +EXPORT_SYMBOL(cfg80211_external_auth_request); + +void cfg80211_update_owe_info_event(struct net_device *netdev, + struct cfg80211_update_owe_info *owe_info, + gfp_t gfp) +{ + struct wiphy *wiphy = netdev->ieee80211_ptr->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_update_owe_info_event(wiphy, netdev, owe_info); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_UPDATE_OWE_INFO); + if (!hdr) + goto nla_put_failure; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u32(msg, NL80211_ATTR_IFINDEX, netdev->ifindex) || + nla_put(msg, NL80211_ATTR_MAC, ETH_ALEN, owe_info->peer)) + goto nla_put_failure; + + if (!owe_info->ie_len || + nla_put(msg, NL80211_ATTR_IE, owe_info->ie_len, owe_info->ie)) + goto nla_put_failure; + + genlmsg_end(msg, hdr); + + genlmsg_multicast_netns(&nl80211_fam, wiphy_net(&rdev->wiphy), msg, 0, + NL80211_MCGRP_MLME, gfp); + return; + +nla_put_failure: + genlmsg_cancel(msg, hdr); + nlmsg_free(msg); +} +EXPORT_SYMBOL(cfg80211_update_owe_info_event); + +/* initialisation/exit functions */ + +int __init nl80211_init(void) +{ + int err; + + err = genl_register_family(&nl80211_fam); + if (err) + return err; + + err = netlink_register_notifier(&nl80211_netlink_notifier); + if (err) + goto err_out; + + return 0; + err_out: + genl_unregister_family(&nl80211_fam); + return err; +} + +void nl80211_exit(void) +{ + netlink_unregister_notifier(&nl80211_netlink_notifier); + genl_unregister_family(&nl80211_fam); +} diff --git a/net/wireless/nl80211.h b/net/wireless/nl80211.h new file mode 100644 index 000000000..855d540dd --- /dev/null +++ b/net/wireless/nl80211.h @@ -0,0 +1,124 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Portions of this file + * Copyright (C) 2018, 2020-2022 Intel Corporation + */ +#ifndef __NET_WIRELESS_NL80211_H +#define __NET_WIRELESS_NL80211_H + +#include "core.h" + +int nl80211_init(void); +void nl80211_exit(void); + +void *nl80211hdr_put(struct sk_buff *skb, u32 portid, u32 seq, + int flags, u8 cmd); +bool nl80211_put_sta_rate(struct sk_buff *msg, struct rate_info *info, + int attr); + +static inline u64 wdev_id(struct wireless_dev *wdev) +{ + return (u64)wdev->identifier | + ((u64)wiphy_to_rdev(wdev->wiphy)->wiphy_idx << 32); +} + +int nl80211_parse_chandef(struct cfg80211_registered_device *rdev, + struct genl_info *info, + struct cfg80211_chan_def *chandef); +int nl80211_parse_random_mac(struct nlattr **attrs, + u8 *mac_addr, u8 *mac_addr_mask); + +void nl80211_notify_wiphy(struct cfg80211_registered_device *rdev, + enum nl80211_commands cmd); +void nl80211_notify_iface(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + enum nl80211_commands cmd); +void nl80211_send_scan_start(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev); +struct sk_buff *nl80211_build_scan_msg(struct 
cfg80211_registered_device *rdev, + struct wireless_dev *wdev, bool aborted); +void nl80211_send_scan_msg(struct cfg80211_registered_device *rdev, + struct sk_buff *msg); +void nl80211_send_sched_scan(struct cfg80211_sched_scan_request *req, u32 cmd); +void nl80211_common_reg_change_event(enum nl80211_commands cmd_id, + struct regulatory_request *request); + +static inline void +nl80211_send_reg_change_event(struct regulatory_request *request) +{ + nl80211_common_reg_change_event(NL80211_CMD_REG_CHANGE, request); +} + +static inline void +nl80211_send_wiphy_reg_change_event(struct regulatory_request *request) +{ + nl80211_common_reg_change_event(NL80211_CMD_WIPHY_REG_CHANGE, request); +} + +void nl80211_send_rx_auth(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *buf, size_t len, gfp_t gfp); +void nl80211_send_rx_assoc(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_rx_assoc_resp *data); +void nl80211_send_deauth(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *buf, size_t len, + bool reconnect, gfp_t gfp); +void nl80211_send_disassoc(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *buf, size_t len, + bool reconnect, gfp_t gfp); +void nl80211_send_auth_timeout(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *addr, gfp_t gfp); +void nl80211_send_assoc_timeout(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + const u8 *addr, gfp_t gfp); +void nl80211_send_connect_result(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_connect_resp_params *params, + gfp_t gfp); +void nl80211_send_roamed(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_roam_info *info, gfp_t gfp); +void nl80211_send_port_authorized(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *bssid); +void nl80211_send_disconnected(struct cfg80211_registered_device *rdev, + struct net_device *netdev, u16 reason, + const u8 *ie, size_t ie_len, bool from_ap); + +void +nl80211_michael_mic_failure(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *addr, + enum nl80211_key_type key_type, + int key_id, const u8 *tsc, gfp_t gfp); + +void +nl80211_send_beacon_hint_event(struct wiphy *wiphy, + struct ieee80211_channel *channel_before, + struct ieee80211_channel *channel_after); + +void nl80211_send_ibss_bssid(struct cfg80211_registered_device *rdev, + struct net_device *netdev, const u8 *bssid, + gfp_t gfp); + +int nl80211_send_mgmt(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u32 nlpid, + struct cfg80211_rx_info *info, gfp_t gfp); + +void +nl80211_radar_notify(struct cfg80211_registered_device *rdev, + const struct cfg80211_chan_def *chandef, + enum nl80211_radar_event event, + struct net_device *netdev, gfp_t gfp); + +void nl80211_send_ap_stopped(struct wireless_dev *wdev); + +void cfg80211_rdev_free_coalesce(struct cfg80211_registered_device *rdev); + +/* peer measurement */ +int nl80211_pmsr_start(struct sk_buff *skb, struct genl_info *info); +int nl80211_pmsr_dump_results(struct sk_buff *skb, struct netlink_callback *cb); + +#endif /* __NET_WIRELESS_NL80211_H */ diff --git a/net/wireless/ocb.c b/net/wireless/ocb.c new file mode 100644 index 000000000..29afaf3da --- /dev/null +++ b/net/wireless/ocb.c @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * OCB mode implementation + * 
+ * Copyright: (c) 2014 Czech Technical University in Prague + * (c) 2014 Volkswagen Group Research + * Copyright (C) 2022 Intel Corporation + * Author: Rostislav Lisovy + * Funded by: Volkswagen Group Research + */ + +#include +#include +#include "nl80211.h" +#include "core.h" +#include "rdev-ops.h" + +int __cfg80211_join_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ocb_setup *setup) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_OCB) + return -EOPNOTSUPP; + + if (!rdev->ops->join_ocb) + return -EOPNOTSUPP; + + if (WARN_ON(!setup->chandef.chan)) + return -EINVAL; + + err = rdev_join_ocb(rdev, dev, setup); + if (!err) + wdev->u.ocb.chandef = setup->chandef; + + return err; +} + +int cfg80211_join_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ocb_setup *setup) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + wdev_lock(wdev); + err = __cfg80211_join_ocb(rdev, dev, setup); + wdev_unlock(wdev); + + return err; +} + +int __cfg80211_leave_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_OCB) + return -EOPNOTSUPP; + + if (!rdev->ops->leave_ocb) + return -EOPNOTSUPP; + + if (!wdev->u.ocb.chandef.chan) + return -ENOTCONN; + + err = rdev_leave_ocb(rdev, dev); + if (!err) + memset(&wdev->u.ocb.chandef, 0, sizeof(wdev->u.ocb.chandef)); + + return err; +} + +int cfg80211_leave_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + wdev_lock(wdev); + err = __cfg80211_leave_ocb(rdev, dev); + wdev_unlock(wdev); + + return err; +} diff --git a/net/wireless/of.c b/net/wireless/of.c new file mode 100644 index 000000000..de221f0ed --- /dev/null +++ b/net/wireless/of.c @@ -0,0 +1,138 @@ +/* + * Copyright (C) 2017 Rafał Miłecki + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
+ */ + +#include +#include +#include "core.h" + +static bool wiphy_freq_limits_valid_chan(struct wiphy *wiphy, + struct ieee80211_freq_range *freq_limits, + unsigned int n_freq_limits, + struct ieee80211_channel *chan) +{ + u32 bw = MHZ_TO_KHZ(20); + int i; + + for (i = 0; i < n_freq_limits; i++) { + struct ieee80211_freq_range *limit = &freq_limits[i]; + + if (cfg80211_does_bw_fit_range(limit, + MHZ_TO_KHZ(chan->center_freq), + bw)) + return true; + } + + return false; +} + +static void wiphy_freq_limits_apply(struct wiphy *wiphy, + struct ieee80211_freq_range *freq_limits, + unsigned int n_freq_limits) +{ + enum nl80211_band band; + int i; + + if (WARN_ON(!n_freq_limits)) + return; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + struct ieee80211_supported_band *sband = wiphy->bands[band]; + + if (!sband) + continue; + + for (i = 0; i < sband->n_channels; i++) { + struct ieee80211_channel *chan = &sband->channels[i]; + + if (chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + if (!wiphy_freq_limits_valid_chan(wiphy, freq_limits, + n_freq_limits, + chan)) { + pr_debug("Disabling freq %d MHz as it's out of OF limits\n", + chan->center_freq); + chan->flags |= IEEE80211_CHAN_DISABLED; + } + } + } +} + +void wiphy_read_of_freq_limits(struct wiphy *wiphy) +{ + struct device *dev = wiphy_dev(wiphy); + struct device_node *np; + struct property *prop; + struct ieee80211_freq_range *freq_limits; + unsigned int n_freq_limits; + const __be32 *p; + int len, i; + int err = 0; + + if (!dev) + return; + np = dev_of_node(dev); + if (!np) + return; + + prop = of_find_property(np, "ieee80211-freq-limit", &len); + if (!prop) + return; + + if (!len || len % sizeof(u32) || len / sizeof(u32) % 2) { + dev_err(dev, "ieee80211-freq-limit wrong format"); + return; + } + n_freq_limits = len / sizeof(u32) / 2; + + freq_limits = kcalloc(n_freq_limits, sizeof(*freq_limits), GFP_KERNEL); + if (!freq_limits) { + err = -ENOMEM; + goto out_kfree; + } + + p = NULL; + for (i = 0; i < n_freq_limits; i++) { + struct ieee80211_freq_range *limit = &freq_limits[i]; + + p = of_prop_next_u32(prop, p, &limit->start_freq_khz); + if (!p) { + err = -EINVAL; + goto out_kfree; + } + + p = of_prop_next_u32(prop, p, &limit->end_freq_khz); + if (!p) { + err = -EINVAL; + goto out_kfree; + } + + if (!limit->start_freq_khz || + !limit->end_freq_khz || + limit->start_freq_khz >= limit->end_freq_khz) { + err = -EINVAL; + goto out_kfree; + } + } + + wiphy_freq_limits_apply(wiphy, freq_limits, n_freq_limits); + +out_kfree: + kfree(freq_limits); + if (err) + dev_err(dev, "Failed to get limits: %d\n", err); +} +EXPORT_SYMBOL(wiphy_read_of_freq_limits); diff --git a/net/wireless/pmsr.c b/net/wireless/pmsr.c new file mode 100644 index 000000000..2bc647720 --- /dev/null +++ b/net/wireless/pmsr.c @@ -0,0 +1,661 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Copyright (C) 2018 - 2021 Intel Corporation + */ +#include +#include "core.h" +#include "nl80211.h" +#include "rdev-ops.h" + +static int pmsr_parse_ftm(struct cfg80211_registered_device *rdev, + struct nlattr *ftmreq, + struct cfg80211_pmsr_request_peer *out, + struct genl_info *info) +{ + const struct cfg80211_pmsr_capabilities *capa = rdev->wiphy.pmsr_capa; + struct nlattr *tb[NL80211_PMSR_FTM_REQ_ATTR_MAX + 1]; + u32 preamble = NL80211_PREAMBLE_DMG; /* only optional in DMG */ + + /* validate existing data */ + if (!(rdev->wiphy.pmsr_capa->ftm.bandwidths & BIT(out->chandef.width))) { + NL_SET_ERR_MSG(info->extack, "FTM: unsupported bandwidth"); + return -EINVAL; + } + + /* no 
validation needed - was already done via nested policy */ + nla_parse_nested_deprecated(tb, NL80211_PMSR_FTM_REQ_ATTR_MAX, ftmreq, + NULL, NULL); + + if (tb[NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE]) + preamble = nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE]); + + /* set up values - struct is 0-initialized */ + out->ftm.requested = true; + + switch (out->chandef.chan->band) { + case NL80211_BAND_60GHZ: + /* optional */ + break; + default: + if (!tb[NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE]) { + NL_SET_ERR_MSG(info->extack, + "FTM: must specify preamble"); + return -EINVAL; + } + } + + if (!(capa->ftm.preambles & BIT(preamble))) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE], + "FTM: invalid preamble"); + return -EINVAL; + } + + out->ftm.preamble = preamble; + + out->ftm.burst_period = 0; + if (tb[NL80211_PMSR_FTM_REQ_ATTR_BURST_PERIOD]) + out->ftm.burst_period = + nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_BURST_PERIOD]); + + out->ftm.asap = !!tb[NL80211_PMSR_FTM_REQ_ATTR_ASAP]; + if (out->ftm.asap && !capa->ftm.asap) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_ASAP], + "FTM: ASAP mode not supported"); + return -EINVAL; + } + + if (!out->ftm.asap && !capa->ftm.non_asap) { + NL_SET_ERR_MSG(info->extack, + "FTM: non-ASAP mode not supported"); + return -EINVAL; + } + + out->ftm.num_bursts_exp = 0; + if (tb[NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP]) + out->ftm.num_bursts_exp = + nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP]); + + if (capa->ftm.max_bursts_exponent >= 0 && + out->ftm.num_bursts_exp > capa->ftm.max_bursts_exponent) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_NUM_BURSTS_EXP], + "FTM: max NUM_BURSTS_EXP must be set lower than the device limit"); + return -EINVAL; + } + + out->ftm.burst_duration = 15; + if (tb[NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION]) + out->ftm.burst_duration = + nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_BURST_DURATION]); + + out->ftm.ftms_per_burst = 0; + if (tb[NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST]) + out->ftm.ftms_per_burst = + nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST]); + + if (capa->ftm.max_ftms_per_burst && + (out->ftm.ftms_per_burst > capa->ftm.max_ftms_per_burst || + out->ftm.ftms_per_burst == 0)) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_FTMS_PER_BURST], + "FTM: FTMs per burst must be set lower than the device limit but non-zero"); + return -EINVAL; + } + + out->ftm.ftmr_retries = 3; + if (tb[NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES]) + out->ftm.ftmr_retries = + nla_get_u32(tb[NL80211_PMSR_FTM_REQ_ATTR_NUM_FTMR_RETRIES]); + + out->ftm.request_lci = !!tb[NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI]; + if (out->ftm.request_lci && !capa->ftm.request_lci) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_REQUEST_LCI], + "FTM: LCI request not supported"); + } + + out->ftm.request_civicloc = + !!tb[NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC]; + if (out->ftm.request_civicloc && !capa->ftm.request_civicloc) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_REQUEST_CIVICLOC], + "FTM: civic location request not supported"); + } + + out->ftm.trigger_based = + !!tb[NL80211_PMSR_FTM_REQ_ATTR_TRIGGER_BASED]; + if (out->ftm.trigger_based && !capa->ftm.trigger_based) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_TRIGGER_BASED], + "FTM: trigger based ranging is not supported"); + return -EINVAL; + } + + out->ftm.non_trigger_based = + !!tb[NL80211_PMSR_FTM_REQ_ATTR_NON_TRIGGER_BASED]; + if 
(out->ftm.non_trigger_based && !capa->ftm.non_trigger_based) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_NON_TRIGGER_BASED], + "FTM: trigger based ranging is not supported"); + return -EINVAL; + } + + if (out->ftm.trigger_based && out->ftm.non_trigger_based) { + NL_SET_ERR_MSG(info->extack, + "FTM: can't set both trigger based and non trigger based"); + return -EINVAL; + } + + if ((out->ftm.trigger_based || out->ftm.non_trigger_based) && + out->ftm.preamble != NL80211_PREAMBLE_HE) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_PREAMBLE], + "FTM: non EDCA based ranging must use HE preamble"); + return -EINVAL; + } + + out->ftm.lmr_feedback = + !!tb[NL80211_PMSR_FTM_REQ_ATTR_LMR_FEEDBACK]; + if (!out->ftm.trigger_based && !out->ftm.non_trigger_based && + out->ftm.lmr_feedback) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_LMR_FEEDBACK], + "FTM: LMR feedback set for EDCA based ranging"); + return -EINVAL; + } + + if (tb[NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR]) { + if (!out->ftm.non_trigger_based && !out->ftm.trigger_based) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR], + "FTM: BSS color set for EDCA based ranging"); + return -EINVAL; + } + + out->ftm.bss_color = + nla_get_u8(tb[NL80211_PMSR_FTM_REQ_ATTR_BSS_COLOR]); + } + + return 0; +} + +static int pmsr_parse_peer(struct cfg80211_registered_device *rdev, + struct nlattr *peer, + struct cfg80211_pmsr_request_peer *out, + struct genl_info *info) +{ + struct nlattr *tb[NL80211_PMSR_PEER_ATTR_MAX + 1]; + struct nlattr *req[NL80211_PMSR_REQ_ATTR_MAX + 1]; + struct nlattr *treq; + int err, rem; + + /* no validation needed - was already done via nested policy */ + nla_parse_nested_deprecated(tb, NL80211_PMSR_PEER_ATTR_MAX, peer, + NULL, NULL); + + if (!tb[NL80211_PMSR_PEER_ATTR_ADDR] || + !tb[NL80211_PMSR_PEER_ATTR_CHAN] || + !tb[NL80211_PMSR_PEER_ATTR_REQ]) { + NL_SET_ERR_MSG_ATTR(info->extack, peer, + "insufficient peer data"); + return -EINVAL; + } + + memcpy(out->addr, nla_data(tb[NL80211_PMSR_PEER_ATTR_ADDR]), ETH_ALEN); + + /* reuse info->attrs */ + memset(info->attrs, 0, sizeof(*info->attrs) * (NL80211_ATTR_MAX + 1)); + err = nla_parse_nested_deprecated(info->attrs, NL80211_ATTR_MAX, + tb[NL80211_PMSR_PEER_ATTR_CHAN], + NULL, info->extack); + if (err) + return err; + + err = nl80211_parse_chandef(rdev, info, &out->chandef); + if (err) + return err; + + /* no validation needed - was already done via nested policy */ + nla_parse_nested_deprecated(req, NL80211_PMSR_REQ_ATTR_MAX, + tb[NL80211_PMSR_PEER_ATTR_REQ], NULL, + NULL); + + if (!req[NL80211_PMSR_REQ_ATTR_DATA]) { + NL_SET_ERR_MSG_ATTR(info->extack, + tb[NL80211_PMSR_PEER_ATTR_REQ], + "missing request type/data"); + return -EINVAL; + } + + if (req[NL80211_PMSR_REQ_ATTR_GET_AP_TSF]) + out->report_ap_tsf = true; + + if (out->report_ap_tsf && !rdev->wiphy.pmsr_capa->report_ap_tsf) { + NL_SET_ERR_MSG_ATTR(info->extack, + req[NL80211_PMSR_REQ_ATTR_GET_AP_TSF], + "reporting AP TSF is not supported"); + return -EINVAL; + } + + nla_for_each_nested(treq, req[NL80211_PMSR_REQ_ATTR_DATA], rem) { + switch (nla_type(treq)) { + case NL80211_PMSR_TYPE_FTM: + err = pmsr_parse_ftm(rdev, treq, out, info); + break; + default: + NL_SET_ERR_MSG_ATTR(info->extack, treq, + "unsupported measurement type"); + err = -EINVAL; + } + } + + if (err) + return err; + + return 0; +} + +int nl80211_pmsr_start(struct sk_buff *skb, struct genl_info *info) +{ + struct nlattr *reqattr = 
info->attrs[NL80211_ATTR_PEER_MEASUREMENTS]; + struct cfg80211_registered_device *rdev = info->user_ptr[0]; + struct wireless_dev *wdev = info->user_ptr[1]; + struct cfg80211_pmsr_request *req; + struct nlattr *peers, *peer; + int count, rem, err, idx; + + if (!rdev->wiphy.pmsr_capa) + return -EOPNOTSUPP; + + if (!reqattr) + return -EINVAL; + + peers = nla_find(nla_data(reqattr), nla_len(reqattr), + NL80211_PMSR_ATTR_PEERS); + if (!peers) + return -EINVAL; + + count = 0; + nla_for_each_nested(peer, peers, rem) { + count++; + + if (count > rdev->wiphy.pmsr_capa->max_peers) { + NL_SET_ERR_MSG_ATTR(info->extack, peer, + "Too many peers used"); + return -EINVAL; + } + } + + req = kzalloc(struct_size(req, peers, count), GFP_KERNEL); + if (!req) + return -ENOMEM; + + if (info->attrs[NL80211_ATTR_TIMEOUT]) + req->timeout = nla_get_u32(info->attrs[NL80211_ATTR_TIMEOUT]); + + if (info->attrs[NL80211_ATTR_MAC]) { + if (!rdev->wiphy.pmsr_capa->randomize_mac_addr) { + NL_SET_ERR_MSG_ATTR(info->extack, + info->attrs[NL80211_ATTR_MAC], + "device cannot randomize MAC address"); + err = -EINVAL; + goto out_err; + } + + err = nl80211_parse_random_mac(info->attrs, req->mac_addr, + req->mac_addr_mask); + if (err) + goto out_err; + } else { + memcpy(req->mac_addr, wdev_address(wdev), ETH_ALEN); + eth_broadcast_addr(req->mac_addr_mask); + } + + idx = 0; + nla_for_each_nested(peer, peers, rem) { + /* NB: this reuses info->attrs, but we no longer need it */ + err = pmsr_parse_peer(rdev, peer, &req->peers[idx], info); + if (err) + goto out_err; + idx++; + } + + req->n_peers = count; + req->cookie = cfg80211_assign_cookie(rdev); + req->nl_portid = info->snd_portid; + + err = rdev_start_pmsr(rdev, wdev, req); + if (err) + goto out_err; + + list_add_tail(&req->list, &wdev->pmsr_list); + + nl_set_extack_cookie_u64(info->extack, req->cookie); + return 0; +out_err: + kfree(req); + return err; +} + +void cfg80211_pmsr_complete(struct wireless_dev *wdev, + struct cfg80211_pmsr_request *req, + gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_pmsr_request *tmp, *prev, *to_free = NULL; + struct sk_buff *msg; + void *hdr; + + trace_cfg80211_pmsr_complete(wdev->wiphy, wdev, req->cookie); + + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + goto free_request; + + hdr = nl80211hdr_put(msg, 0, 0, 0, + NL80211_CMD_PEER_MEASUREMENT_COMPLETE); + if (!hdr) + goto free_msg; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto free_msg; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, req->cookie, + NL80211_ATTR_PAD)) + goto free_msg; + + genlmsg_end(msg, hdr); + genlmsg_unicast(wiphy_net(wdev->wiphy), msg, req->nl_portid); + goto free_request; +free_msg: + nlmsg_free(msg); +free_request: + spin_lock_bh(&wdev->pmsr_lock); + /* + * cfg80211_pmsr_process_abort() may have already moved this request + * to the free list, and will free it later. In this case, don't free + * it here. 
+ */ + list_for_each_entry_safe(tmp, prev, &wdev->pmsr_list, list) { + if (tmp == req) { + list_del(&req->list); + to_free = req; + break; + } + } + spin_unlock_bh(&wdev->pmsr_lock); + kfree(to_free); +} +EXPORT_SYMBOL_GPL(cfg80211_pmsr_complete); + +static int nl80211_pmsr_send_ftm_res(struct sk_buff *msg, + struct cfg80211_pmsr_result *res) +{ + if (res->status == NL80211_PMSR_STATUS_FAILURE) { + if (nla_put_u32(msg, NL80211_PMSR_FTM_RESP_ATTR_FAIL_REASON, + res->ftm.failure_reason)) + goto error; + + if (res->ftm.failure_reason == + NL80211_PMSR_FTM_FAILURE_PEER_BUSY && + res->ftm.busy_retry_time && + nla_put_u32(msg, NL80211_PMSR_FTM_RESP_ATTR_BUSY_RETRY_TIME, + res->ftm.busy_retry_time)) + goto error; + + return 0; + } + +#define PUT(tp, attr, val) \ + do { \ + if (nla_put_##tp(msg, \ + NL80211_PMSR_FTM_RESP_ATTR_##attr, \ + res->ftm.val)) \ + goto error; \ + } while (0) + +#define PUTOPT(tp, attr, val) \ + do { \ + if (res->ftm.val##_valid) \ + PUT(tp, attr, val); \ + } while (0) + +#define PUT_U64(attr, val) \ + do { \ + if (nla_put_u64_64bit(msg, \ + NL80211_PMSR_FTM_RESP_ATTR_##attr,\ + res->ftm.val, \ + NL80211_PMSR_FTM_RESP_ATTR_PAD)) \ + goto error; \ + } while (0) + +#define PUTOPT_U64(attr, val) \ + do { \ + if (res->ftm.val##_valid) \ + PUT_U64(attr, val); \ + } while (0) + + if (res->ftm.burst_index >= 0) + PUT(u32, BURST_INDEX, burst_index); + PUTOPT(u32, NUM_FTMR_ATTEMPTS, num_ftmr_attempts); + PUTOPT(u32, NUM_FTMR_SUCCESSES, num_ftmr_successes); + PUT(u8, NUM_BURSTS_EXP, num_bursts_exp); + PUT(u8, BURST_DURATION, burst_duration); + PUT(u8, FTMS_PER_BURST, ftms_per_burst); + PUTOPT(s32, RSSI_AVG, rssi_avg); + PUTOPT(s32, RSSI_SPREAD, rssi_spread); + if (res->ftm.tx_rate_valid && + !nl80211_put_sta_rate(msg, &res->ftm.tx_rate, + NL80211_PMSR_FTM_RESP_ATTR_TX_RATE)) + goto error; + if (res->ftm.rx_rate_valid && + !nl80211_put_sta_rate(msg, &res->ftm.rx_rate, + NL80211_PMSR_FTM_RESP_ATTR_RX_RATE)) + goto error; + PUTOPT_U64(RTT_AVG, rtt_avg); + PUTOPT_U64(RTT_VARIANCE, rtt_variance); + PUTOPT_U64(RTT_SPREAD, rtt_spread); + PUTOPT_U64(DIST_AVG, dist_avg); + PUTOPT_U64(DIST_VARIANCE, dist_variance); + PUTOPT_U64(DIST_SPREAD, dist_spread); + if (res->ftm.lci && res->ftm.lci_len && + nla_put(msg, NL80211_PMSR_FTM_RESP_ATTR_LCI, + res->ftm.lci_len, res->ftm.lci)) + goto error; + if (res->ftm.civicloc && res->ftm.civicloc_len && + nla_put(msg, NL80211_PMSR_FTM_RESP_ATTR_CIVICLOC, + res->ftm.civicloc_len, res->ftm.civicloc)) + goto error; +#undef PUT +#undef PUTOPT +#undef PUT_U64 +#undef PUTOPT_U64 + + return 0; +error: + return -ENOSPC; +} + +static int nl80211_pmsr_send_result(struct sk_buff *msg, + struct cfg80211_pmsr_result *res) +{ + struct nlattr *pmsr, *peers, *peer, *resp, *data, *typedata; + + pmsr = nla_nest_start_noflag(msg, NL80211_ATTR_PEER_MEASUREMENTS); + if (!pmsr) + goto error; + + peers = nla_nest_start_noflag(msg, NL80211_PMSR_ATTR_PEERS); + if (!peers) + goto error; + + peer = nla_nest_start_noflag(msg, 1); + if (!peer) + goto error; + + if (nla_put(msg, NL80211_PMSR_PEER_ATTR_ADDR, ETH_ALEN, res->addr)) + goto error; + + resp = nla_nest_start_noflag(msg, NL80211_PMSR_PEER_ATTR_RESP); + if (!resp) + goto error; + + if (nla_put_u32(msg, NL80211_PMSR_RESP_ATTR_STATUS, res->status) || + nla_put_u64_64bit(msg, NL80211_PMSR_RESP_ATTR_HOST_TIME, + res->host_time, NL80211_PMSR_RESP_ATTR_PAD)) + goto error; + + if (res->ap_tsf_valid && + nla_put_u64_64bit(msg, NL80211_PMSR_RESP_ATTR_AP_TSF, + res->ap_tsf, NL80211_PMSR_RESP_ATTR_PAD)) + goto error; + + if (res->final 
&& nla_put_flag(msg, NL80211_PMSR_RESP_ATTR_FINAL)) + goto error; + + data = nla_nest_start_noflag(msg, NL80211_PMSR_RESP_ATTR_DATA); + if (!data) + goto error; + + typedata = nla_nest_start_noflag(msg, res->type); + if (!typedata) + goto error; + + switch (res->type) { + case NL80211_PMSR_TYPE_FTM: + if (nl80211_pmsr_send_ftm_res(msg, res)) + goto error; + break; + default: + WARN_ON(1); + } + + nla_nest_end(msg, typedata); + nla_nest_end(msg, data); + nla_nest_end(msg, resp); + nla_nest_end(msg, peer); + nla_nest_end(msg, peers); + nla_nest_end(msg, pmsr); + + return 0; +error: + return -ENOSPC; +} + +void cfg80211_pmsr_report(struct wireless_dev *wdev, + struct cfg80211_pmsr_request *req, + struct cfg80211_pmsr_result *result, + gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct sk_buff *msg; + void *hdr; + int err; + + trace_cfg80211_pmsr_report(wdev->wiphy, wdev, req->cookie, + result->addr); + + /* + * Currently, only variable items are LCI and civic location, + * both of which are reasonably short so we don't need to + * worry about them here for the allocation. + */ + msg = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); + if (!msg) + return; + + hdr = nl80211hdr_put(msg, 0, 0, 0, NL80211_CMD_PEER_MEASUREMENT_RESULT); + if (!hdr) + goto free; + + if (nla_put_u32(msg, NL80211_ATTR_WIPHY, rdev->wiphy_idx) || + nla_put_u64_64bit(msg, NL80211_ATTR_WDEV, wdev_id(wdev), + NL80211_ATTR_PAD)) + goto free; + + if (nla_put_u64_64bit(msg, NL80211_ATTR_COOKIE, req->cookie, + NL80211_ATTR_PAD)) + goto free; + + err = nl80211_pmsr_send_result(msg, result); + if (err) { + pr_err_ratelimited("peer measurement result: message didn't fit!"); + goto free; + } + + genlmsg_end(msg, hdr); + genlmsg_unicast(wiphy_net(wdev->wiphy), msg, req->nl_portid); + return; +free: + nlmsg_free(msg); +} +EXPORT_SYMBOL_GPL(cfg80211_pmsr_report); + +static void cfg80211_pmsr_process_abort(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_pmsr_request *req, *tmp; + LIST_HEAD(free_list); + + lockdep_assert_held(&wdev->mtx); + + spin_lock_bh(&wdev->pmsr_lock); + list_for_each_entry_safe(req, tmp, &wdev->pmsr_list, list) { + if (req->nl_portid) + continue; + list_move_tail(&req->list, &free_list); + } + spin_unlock_bh(&wdev->pmsr_lock); + + list_for_each_entry_safe(req, tmp, &free_list, list) { + rdev_abort_pmsr(rdev, wdev, req); + + kfree(req); + } +} + +void cfg80211_pmsr_free_wk(struct work_struct *work) +{ + struct wireless_dev *wdev = container_of(work, struct wireless_dev, + pmsr_free_wk); + + wdev_lock(wdev); + cfg80211_pmsr_process_abort(wdev); + wdev_unlock(wdev); +} + +void cfg80211_pmsr_wdev_down(struct wireless_dev *wdev) +{ + struct cfg80211_pmsr_request *req; + bool found = false; + + spin_lock_bh(&wdev->pmsr_lock); + list_for_each_entry(req, &wdev->pmsr_list, list) { + found = true; + req->nl_portid = 0; + } + spin_unlock_bh(&wdev->pmsr_lock); + + if (found) + cfg80211_pmsr_process_abort(wdev); + + WARN_ON(!list_empty(&wdev->pmsr_list)); +} + +void cfg80211_release_pmsr(struct wireless_dev *wdev, u32 portid) +{ + struct cfg80211_pmsr_request *req; + + spin_lock_bh(&wdev->pmsr_lock); + list_for_each_entry(req, &wdev->pmsr_list, list) { + if (req->nl_portid == portid) { + req->nl_portid = 0; + schedule_work(&wdev->pmsr_free_wk); + } + } + spin_unlock_bh(&wdev->pmsr_lock); +} diff --git a/net/wireless/radiotap.c b/net/wireless/radiotap.c new file mode 100644 index 000000000..ae2e1a896 --- /dev/null +++ 
b/net/wireless/radiotap.c @@ -0,0 +1,370 @@ +/* + * Radiotap parser + * + * Copyright 2007 Andy Green + * Copyright 2009 Johannes Berg + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + * + * Alternatively, this software may be distributed under the terms of BSD + * license. + * + * See COPYING for more details. + */ + +#include +#include +#include +#include +#include + +/* function prototypes and related defs are in include/net/cfg80211.h */ + +static const struct radiotap_align_size rtap_namespace_sizes[] = { + [IEEE80211_RADIOTAP_TSFT] = { .align = 8, .size = 8, }, + [IEEE80211_RADIOTAP_FLAGS] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_RATE] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_CHANNEL] = { .align = 2, .size = 4, }, + [IEEE80211_RADIOTAP_FHSS] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_DBM_ANTSIGNAL] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_DBM_ANTNOISE] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_LOCK_QUALITY] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_TX_ATTENUATION] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_DB_TX_ATTENUATION] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_DBM_TX_POWER] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_ANTENNA] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_DB_ANTSIGNAL] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_DB_ANTNOISE] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_RX_FLAGS] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_TX_FLAGS] = { .align = 2, .size = 2, }, + [IEEE80211_RADIOTAP_RTS_RETRIES] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_DATA_RETRIES] = { .align = 1, .size = 1, }, + [IEEE80211_RADIOTAP_MCS] = { .align = 1, .size = 3, }, + [IEEE80211_RADIOTAP_AMPDU_STATUS] = { .align = 4, .size = 8, }, + [IEEE80211_RADIOTAP_VHT] = { .align = 2, .size = 12, }, + /* + * add more here as they are defined in radiotap.h + */ +}; + +static const struct ieee80211_radiotap_namespace radiotap_ns = { + .n_bits = ARRAY_SIZE(rtap_namespace_sizes), + .align_size = rtap_namespace_sizes, +}; + +/** + * ieee80211_radiotap_iterator_init - radiotap parser iterator initialization + * @iterator: radiotap_iterator to initialize + * @radiotap_header: radiotap header to parse + * @max_length: total length we can parse into (eg, whole packet length) + * @vns: vendor namespaces to parse + * + * Returns: 0 or a negative error code if there is a problem. + * + * This function initializes an opaque iterator struct which can then + * be passed to ieee80211_radiotap_iterator_next() to visit every radiotap + * argument which is present in the header. It knows about extended + * present headers and handles them. + * + * How to use: + * call __ieee80211_radiotap_iterator_init() to init a semi-opaque iterator + * struct ieee80211_radiotap_iterator (no need to init the struct beforehand) + * checking for a good 0 return code. Then loop calling + * __ieee80211_radiotap_iterator_next()... it returns either 0, + * -ENOENT if there are no more args to parse, or -EINVAL if there is a problem. + * The iterator's @this_arg member points to the start of the argument + * associated with the current argument index that is present, which can be + * found in the iterator's @this_arg_index member. This arg index corresponds + * to the IEEE80211_RADIOTAP_... defines. 
+ * + * Radiotap header length: + * You can find the CPU-endian total radiotap header length in + * iterator->max_length after executing ieee80211_radiotap_iterator_init() + * successfully. + * + * Alignment Gotcha: + * You must take care when dereferencing iterator.this_arg + * for multibyte types... the pointer is not aligned. Use + * get_unaligned((type *)iterator.this_arg) to dereference + * iterator.this_arg for type "type" safely on all arches. + * + * Example code: + * See Documentation/networking/radiotap-headers.rst + */ + +int ieee80211_radiotap_iterator_init( + struct ieee80211_radiotap_iterator *iterator, + struct ieee80211_radiotap_header *radiotap_header, + int max_length, const struct ieee80211_radiotap_vendor_namespaces *vns) +{ + /* check the radiotap header can actually be present */ + if (max_length < sizeof(struct ieee80211_radiotap_header)) + return -EINVAL; + + /* Linux only supports version 0 radiotap format */ + if (radiotap_header->it_version) + return -EINVAL; + + /* sanity check for allowed length and radiotap length field */ + if (max_length < get_unaligned_le16(&radiotap_header->it_len)) + return -EINVAL; + + iterator->_rtheader = radiotap_header; + iterator->_max_length = get_unaligned_le16(&radiotap_header->it_len); + iterator->_arg_index = 0; + iterator->_bitmap_shifter = get_unaligned_le32(&radiotap_header->it_present); + iterator->_arg = (uint8_t *)radiotap_header->it_optional; + iterator->_reset_on_ext = 0; + iterator->_next_bitmap = radiotap_header->it_optional; + iterator->_vns = vns; + iterator->current_namespace = &radiotap_ns; + iterator->is_radiotap_ns = 1; + + /* find payload start allowing for extended bitmap(s) */ + + if (iterator->_bitmap_shifter & (BIT(IEEE80211_RADIOTAP_EXT))) { + if ((unsigned long)iterator->_arg - + (unsigned long)iterator->_rtheader + sizeof(uint32_t) > + (unsigned long)iterator->_max_length) + return -EINVAL; + while (get_unaligned_le32(iterator->_arg) & + (BIT(IEEE80211_RADIOTAP_EXT))) { + iterator->_arg += sizeof(uint32_t); + + /* + * check for insanity where the present bitmaps + * keep claiming to extend up to or even beyond the + * stated radiotap header length + */ + + if ((unsigned long)iterator->_arg - + (unsigned long)iterator->_rtheader + + sizeof(uint32_t) > + (unsigned long)iterator->_max_length) + return -EINVAL; + } + + iterator->_arg += sizeof(uint32_t); + + /* + * no need to check again for blowing past stated radiotap + * header length, because ieee80211_radiotap_iterator_next + * checks it before it is dereferenced + */ + } + + iterator->this_arg = iterator->_arg; + + /* we are all initialized happily */ + + return 0; +} +EXPORT_SYMBOL(ieee80211_radiotap_iterator_init); + +static void find_ns(struct ieee80211_radiotap_iterator *iterator, + uint32_t oui, uint8_t subns) +{ + int i; + + iterator->current_namespace = NULL; + + if (!iterator->_vns) + return; + + for (i = 0; i < iterator->_vns->n_ns; i++) { + if (iterator->_vns->ns[i].oui != oui) + continue; + if (iterator->_vns->ns[i].subns != subns) + continue; + + iterator->current_namespace = &iterator->_vns->ns[i]; + break; + } +} + + + +/** + * ieee80211_radiotap_iterator_next - return next radiotap parser iterator arg + * @iterator: radiotap_iterator to move to next arg (if any) + * + * Returns: 0 if there is an argument to handle, + * -ENOENT if there are no more args or -EINVAL + * if there is something else wrong. 
+ * + * This function provides the next radiotap arg index (IEEE80211_RADIOTAP_*) + * in @this_arg_index and sets @this_arg to point to the + * payload for the field. It takes care of alignment handling and extended + * present fields. @this_arg can be changed by the caller (eg, + * incremented to move inside a compound argument like + * IEEE80211_RADIOTAP_CHANNEL). The args pointed to are in + * little-endian format whatever the endianess of your CPU. + * + * Alignment Gotcha: + * You must take care when dereferencing iterator.this_arg + * for multibyte types... the pointer is not aligned. Use + * get_unaligned((type *)iterator.this_arg) to dereference + * iterator.this_arg for type "type" safely on all arches. + */ + +int ieee80211_radiotap_iterator_next( + struct ieee80211_radiotap_iterator *iterator) +{ + while (1) { + int hit = 0; + int pad, align, size, subns; + uint32_t oui; + + /* if no more EXT bits, that's it */ + if ((iterator->_arg_index % 32) == IEEE80211_RADIOTAP_EXT && + !(iterator->_bitmap_shifter & 1)) + return -ENOENT; + + if (!(iterator->_bitmap_shifter & 1)) + goto next_entry; /* arg not present */ + + /* get alignment/size of data */ + switch (iterator->_arg_index % 32) { + case IEEE80211_RADIOTAP_RADIOTAP_NAMESPACE: + case IEEE80211_RADIOTAP_EXT: + align = 1; + size = 0; + break; + case IEEE80211_RADIOTAP_VENDOR_NAMESPACE: + align = 2; + size = 6; + break; + default: + if (!iterator->current_namespace || + iterator->_arg_index >= iterator->current_namespace->n_bits) { + if (iterator->current_namespace == &radiotap_ns) + return -ENOENT; + align = 0; + } else { + align = iterator->current_namespace->align_size[iterator->_arg_index].align; + size = iterator->current_namespace->align_size[iterator->_arg_index].size; + } + if (!align) { + /* skip all subsequent data */ + iterator->_arg = iterator->_next_ns_data; + /* give up on this namespace */ + iterator->current_namespace = NULL; + goto next_entry; + } + break; + } + + /* + * arg is present, account for alignment padding + * + * Note that these alignments are relative to the start + * of the radiotap header. There is no guarantee + * that the radiotap header itself is aligned on any + * kind of boundary. + * + * The above is why get_unaligned() is used to dereference + * multibyte elements from the radiotap area. + */ + + pad = ((unsigned long)iterator->_arg - + (unsigned long)iterator->_rtheader) & (align - 1); + + if (pad) + iterator->_arg += align - pad; + + if (iterator->_arg_index % 32 == IEEE80211_RADIOTAP_VENDOR_NAMESPACE) { + int vnslen; + + if ((unsigned long)iterator->_arg + size - + (unsigned long)iterator->_rtheader > + (unsigned long)iterator->_max_length) + return -EINVAL; + + oui = (*iterator->_arg << 16) | + (*(iterator->_arg + 1) << 8) | + *(iterator->_arg + 2); + subns = *(iterator->_arg + 3); + + find_ns(iterator, oui, subns); + + vnslen = get_unaligned_le16(iterator->_arg + 4); + iterator->_next_ns_data = iterator->_arg + size + vnslen; + if (!iterator->current_namespace) + size += vnslen; + } + + /* + * this is what we will return to user, but we need to + * move on first so next call has something fresh to test + */ + iterator->this_arg_index = iterator->_arg_index; + iterator->this_arg = iterator->_arg; + iterator->this_arg_size = size; + + /* internally move on the size of this arg */ + iterator->_arg += size; + + /* + * check for insanity where we are given a bitmap that + * claims to have more arg content than the length of the + * radiotap section. 
We will normally end up equalling this + * max_length on the last arg, never exceeding it. + */ + + if ((unsigned long)iterator->_arg - + (unsigned long)iterator->_rtheader > + (unsigned long)iterator->_max_length) + return -EINVAL; + + /* these special ones are valid in each bitmap word */ + switch (iterator->_arg_index % 32) { + case IEEE80211_RADIOTAP_VENDOR_NAMESPACE: + iterator->_reset_on_ext = 1; + + iterator->is_radiotap_ns = 0; + /* + * If parser didn't register this vendor + * namespace with us, allow it to show it + * as 'raw. Do do that, set argument index + * to vendor namespace. + */ + iterator->this_arg_index = + IEEE80211_RADIOTAP_VENDOR_NAMESPACE; + if (!iterator->current_namespace) + hit = 1; + goto next_entry; + case IEEE80211_RADIOTAP_RADIOTAP_NAMESPACE: + iterator->_reset_on_ext = 1; + iterator->current_namespace = &radiotap_ns; + iterator->is_radiotap_ns = 1; + goto next_entry; + case IEEE80211_RADIOTAP_EXT: + /* + * bit 31 was set, there is more + * -- move to next u32 bitmap + */ + iterator->_bitmap_shifter = + get_unaligned_le32(iterator->_next_bitmap); + iterator->_next_bitmap++; + if (iterator->_reset_on_ext) + iterator->_arg_index = 0; + else + iterator->_arg_index++; + iterator->_reset_on_ext = 0; + break; + default: + /* we've got a hit! */ + hit = 1; + next_entry: + iterator->_bitmap_shifter >>= 1; + iterator->_arg_index++; + } + + /* if we found a valid arg earlier, return it now */ + if (hit) + return 0; + } +} +EXPORT_SYMBOL(ieee80211_radiotap_iterator_next); diff --git a/net/wireless/rdev-ops.h b/net/wireless/rdev-ops.h new file mode 100644 index 000000000..ee853a14a --- /dev/null +++ b/net/wireless/rdev-ops.h @@ -0,0 +1,1497 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* + * Portions of this file + * Copyright(c) 2016-2017 Intel Deutschland GmbH + * Copyright (C) 2018, 2021-2023 Intel Corporation + */ +#ifndef __CFG80211_RDEV_OPS +#define __CFG80211_RDEV_OPS + +#include +#include +#include "core.h" +#include "trace.h" + +static inline int rdev_suspend(struct cfg80211_registered_device *rdev, + struct cfg80211_wowlan *wowlan) +{ + int ret; + trace_rdev_suspend(&rdev->wiphy, wowlan); + ret = rdev->ops->suspend(&rdev->wiphy, wowlan); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_resume(struct cfg80211_registered_device *rdev) +{ + int ret; + trace_rdev_resume(&rdev->wiphy); + ret = rdev->ops->resume(&rdev->wiphy); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_set_wakeup(struct cfg80211_registered_device *rdev, + bool enabled) +{ + trace_rdev_set_wakeup(&rdev->wiphy, enabled); + rdev->ops->set_wakeup(&rdev->wiphy, enabled); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline struct wireless_dev +*rdev_add_virtual_intf(struct cfg80211_registered_device *rdev, char *name, + unsigned char name_assign_type, + enum nl80211_iftype type, + struct vif_params *params) +{ + struct wireless_dev *ret; + trace_rdev_add_virtual_intf(&rdev->wiphy, name, type); + ret = rdev->ops->add_virtual_intf(&rdev->wiphy, name, name_assign_type, + type, params); + trace_rdev_return_wdev(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_del_virtual_intf(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + int ret; + trace_rdev_del_virtual_intf(&rdev->wiphy, wdev); + ret = rdev->ops->del_virtual_intf(&rdev->wiphy, wdev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_change_virtual_intf(struct cfg80211_registered_device 
*rdev, + struct net_device *dev, enum nl80211_iftype type, + struct vif_params *params) +{ + int ret; + trace_rdev_change_virtual_intf(&rdev->wiphy, dev, type); + ret = rdev->ops->change_virtual_intf(&rdev->wiphy, dev, type, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_add_key(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr, + struct key_params *params) +{ + int ret; + trace_rdev_add_key(&rdev->wiphy, netdev, link_id, key_index, pairwise, + mac_addr, params->mode); + ret = rdev->ops->add_key(&rdev->wiphy, netdev, link_id, key_index, + pairwise, mac_addr, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_get_key(struct cfg80211_registered_device *rdev, struct net_device *netdev, + int link_id, u8 key_index, bool pairwise, const u8 *mac_addr, + void *cookie, + void (*callback)(void *cookie, struct key_params*)) +{ + int ret; + trace_rdev_get_key(&rdev->wiphy, netdev, link_id, key_index, pairwise, + mac_addr); + ret = rdev->ops->get_key(&rdev->wiphy, netdev, link_id, key_index, + pairwise, mac_addr, cookie, callback); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_del_key(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr) +{ + int ret; + trace_rdev_del_key(&rdev->wiphy, netdev, link_id, key_index, pairwise, + mac_addr); + ret = rdev->ops->del_key(&rdev->wiphy, netdev, link_id, key_index, + pairwise, mac_addr); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_default_key(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int link_id, u8 key_index, + bool unicast, bool multicast) +{ + int ret; + trace_rdev_set_default_key(&rdev->wiphy, netdev, link_id, key_index, + unicast, multicast); + ret = rdev->ops->set_default_key(&rdev->wiphy, netdev, link_id, + key_index, unicast, multicast); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_default_mgmt_key(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int link_id, u8 key_index) +{ + int ret; + trace_rdev_set_default_mgmt_key(&rdev->wiphy, netdev, link_id, + key_index); + ret = rdev->ops->set_default_mgmt_key(&rdev->wiphy, netdev, link_id, + key_index); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_default_beacon_key(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int link_id, + u8 key_index) +{ + int ret; + + trace_rdev_set_default_beacon_key(&rdev->wiphy, netdev, link_id, + key_index); + ret = rdev->ops->set_default_beacon_key(&rdev->wiphy, netdev, link_id, + key_index); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_start_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ap_settings *settings) +{ + int ret; + trace_rdev_start_ap(&rdev->wiphy, dev, settings); + ret = rdev->ops->start_ap(&rdev->wiphy, dev, settings); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_change_beacon(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_beacon_data *info) +{ + int ret; + trace_rdev_change_beacon(&rdev->wiphy, dev, info); + ret = rdev->ops->change_beacon(&rdev->wiphy, dev, info); + trace_rdev_return_int(&rdev->wiphy, 
ret); + return ret; +} + +static inline int rdev_stop_ap(struct cfg80211_registered_device *rdev, + struct net_device *dev, unsigned int link_id) +{ + int ret; + trace_rdev_stop_ap(&rdev->wiphy, dev, link_id); + ret = rdev->ops->stop_ap(&rdev->wiphy, dev, link_id); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_add_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *mac, + struct station_parameters *params) +{ + int ret; + trace_rdev_add_station(&rdev->wiphy, dev, mac, params); + ret = rdev->ops->add_station(&rdev->wiphy, dev, mac, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_del_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct station_del_parameters *params) +{ + int ret; + trace_rdev_del_station(&rdev->wiphy, dev, params); + ret = rdev->ops->del_station(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_change_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *mac, + struct station_parameters *params) +{ + int ret; + trace_rdev_change_station(&rdev->wiphy, dev, mac, params); + ret = rdev->ops->change_station(&rdev->wiphy, dev, mac, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_get_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *mac, + struct station_info *sinfo) +{ + int ret; + trace_rdev_get_station(&rdev->wiphy, dev, mac); + ret = rdev->ops->get_station(&rdev->wiphy, dev, mac, sinfo); + trace_rdev_return_int_station_info(&rdev->wiphy, ret, sinfo); + return ret; +} + +static inline int rdev_dump_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, int idx, u8 *mac, + struct station_info *sinfo) +{ + int ret; + trace_rdev_dump_station(&rdev->wiphy, dev, idx, mac); + ret = rdev->ops->dump_station(&rdev->wiphy, dev, idx, mac, sinfo); + trace_rdev_return_int_station_info(&rdev->wiphy, ret, sinfo); + return ret; +} + +static inline int rdev_add_mpath(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *dst, u8 *next_hop) +{ + int ret; + trace_rdev_add_mpath(&rdev->wiphy, dev, dst, next_hop); + ret = rdev->ops->add_mpath(&rdev->wiphy, dev, dst, next_hop); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_del_mpath(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *dst) +{ + int ret; + trace_rdev_del_mpath(&rdev->wiphy, dev, dst); + ret = rdev->ops->del_mpath(&rdev->wiphy, dev, dst); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_change_mpath(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *dst, + u8 *next_hop) +{ + int ret; + trace_rdev_change_mpath(&rdev->wiphy, dev, dst, next_hop); + ret = rdev->ops->change_mpath(&rdev->wiphy, dev, dst, next_hop); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_get_mpath(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *dst, u8 *next_hop, + struct mpath_info *pinfo) +{ + int ret; + trace_rdev_get_mpath(&rdev->wiphy, dev, dst, next_hop); + ret = rdev->ops->get_mpath(&rdev->wiphy, dev, dst, next_hop, pinfo); + trace_rdev_return_int_mpath_info(&rdev->wiphy, ret, pinfo); + return ret; + +} + +static inline int rdev_get_mpp(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *dst, u8 *mpp, + struct 
mpath_info *pinfo) +{ + int ret; + + trace_rdev_get_mpp(&rdev->wiphy, dev, dst, mpp); + ret = rdev->ops->get_mpp(&rdev->wiphy, dev, dst, mpp, pinfo); + trace_rdev_return_int_mpath_info(&rdev->wiphy, ret, pinfo); + return ret; +} + +static inline int rdev_dump_mpath(struct cfg80211_registered_device *rdev, + struct net_device *dev, int idx, u8 *dst, + u8 *next_hop, struct mpath_info *pinfo) + +{ + int ret; + trace_rdev_dump_mpath(&rdev->wiphy, dev, idx, dst, next_hop); + ret = rdev->ops->dump_mpath(&rdev->wiphy, dev, idx, dst, next_hop, + pinfo); + trace_rdev_return_int_mpath_info(&rdev->wiphy, ret, pinfo); + return ret; +} + +static inline int rdev_dump_mpp(struct cfg80211_registered_device *rdev, + struct net_device *dev, int idx, u8 *dst, + u8 *mpp, struct mpath_info *pinfo) + +{ + int ret; + + trace_rdev_dump_mpp(&rdev->wiphy, dev, idx, dst, mpp); + ret = rdev->ops->dump_mpp(&rdev->wiphy, dev, idx, dst, mpp, pinfo); + trace_rdev_return_int_mpath_info(&rdev->wiphy, ret, pinfo); + return ret; +} + +static inline int +rdev_get_mesh_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, struct mesh_config *conf) +{ + int ret; + trace_rdev_get_mesh_config(&rdev->wiphy, dev); + ret = rdev->ops->get_mesh_config(&rdev->wiphy, dev, conf); + trace_rdev_return_int_mesh_config(&rdev->wiphy, ret, conf); + return ret; +} + +static inline int +rdev_update_mesh_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, u32 mask, + const struct mesh_config *nconf) +{ + int ret; + trace_rdev_update_mesh_config(&rdev->wiphy, dev, mask, nconf); + ret = rdev->ops->update_mesh_config(&rdev->wiphy, dev, mask, nconf); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_join_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev, + const struct mesh_config *conf, + const struct mesh_setup *setup) +{ + int ret; + trace_rdev_join_mesh(&rdev->wiphy, dev, conf, setup); + ret = rdev->ops->join_mesh(&rdev->wiphy, dev, conf, setup); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + + +static inline int rdev_leave_mesh(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + int ret; + trace_rdev_leave_mesh(&rdev->wiphy, dev); + ret = rdev->ops->leave_mesh(&rdev->wiphy, dev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_join_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ocb_setup *setup) +{ + int ret; + trace_rdev_join_ocb(&rdev->wiphy, dev, setup); + ret = rdev->ops->join_ocb(&rdev->wiphy, dev, setup); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_leave_ocb(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + int ret; + trace_rdev_leave_ocb(&rdev->wiphy, dev); + ret = rdev->ops->leave_ocb(&rdev->wiphy, dev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_change_bss(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct bss_parameters *params) + +{ + int ret; + trace_rdev_change_bss(&rdev->wiphy, dev, params); + ret = rdev->ops->change_bss(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_txq_params(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ieee80211_txq_params *params) + +{ + int ret; + trace_rdev_set_txq_params(&rdev->wiphy, dev, params); + ret = rdev->ops->set_txq_params(&rdev->wiphy, dev, params); + 
trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_libertas_set_mesh_channel(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct ieee80211_channel *chan) +{ + int ret; + trace_rdev_libertas_set_mesh_channel(&rdev->wiphy, dev, chan); + ret = rdev->ops->libertas_set_mesh_channel(&rdev->wiphy, dev, chan); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_monitor_channel(struct cfg80211_registered_device *rdev, + struct cfg80211_chan_def *chandef) +{ + int ret; + trace_rdev_set_monitor_channel(&rdev->wiphy, chandef); + ret = rdev->ops->set_monitor_channel(&rdev->wiphy, chandef); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_scan(struct cfg80211_registered_device *rdev, + struct cfg80211_scan_request *request) +{ + int ret; + trace_rdev_scan(&rdev->wiphy, request); + ret = rdev->ops->scan(&rdev->wiphy, request); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_abort_scan(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + trace_rdev_abort_scan(&rdev->wiphy, wdev); + rdev->ops->abort_scan(&rdev->wiphy, wdev); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int rdev_auth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_auth_request *req) +{ + int ret; + trace_rdev_auth(&rdev->wiphy, dev, req); + ret = rdev->ops->auth(&rdev->wiphy, dev, req); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_assoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_assoc_request *req) +{ + int ret; + + trace_rdev_assoc(&rdev->wiphy, dev, req); + ret = rdev->ops->assoc(&rdev->wiphy, dev, req); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_deauth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_deauth_request *req) +{ + int ret; + trace_rdev_deauth(&rdev->wiphy, dev, req); + ret = rdev->ops->deauth(&rdev->wiphy, dev, req); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_disassoc(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_disassoc_request *req) +{ + int ret; + trace_rdev_disassoc(&rdev->wiphy, dev, req); + ret = rdev->ops->disassoc(&rdev->wiphy, dev, req); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_connect(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_connect_params *sme) +{ + int ret; + trace_rdev_connect(&rdev->wiphy, dev, sme); + ret = rdev->ops->connect(&rdev->wiphy, dev, sme); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_update_connect_params(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_connect_params *sme, u32 changed) +{ + int ret; + trace_rdev_update_connect_params(&rdev->wiphy, dev, sme, changed); + ret = rdev->ops->update_connect_params(&rdev->wiphy, dev, sme, changed); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_disconnect(struct cfg80211_registered_device *rdev, + struct net_device *dev, u16 reason_code) +{ + int ret; + trace_rdev_disconnect(&rdev->wiphy, dev, reason_code); + ret = rdev->ops->disconnect(&rdev->wiphy, dev, reason_code); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int 
rdev_join_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ibss_params *params) +{ + int ret; + trace_rdev_join_ibss(&rdev->wiphy, dev, params); + ret = rdev->ops->join_ibss(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_leave_ibss(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + int ret; + trace_rdev_leave_ibss(&rdev->wiphy, dev); + ret = rdev->ops->leave_ibss(&rdev->wiphy, dev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_wiphy_params(struct cfg80211_registered_device *rdev, u32 changed) +{ + int ret; + + if (!rdev->ops->set_wiphy_params) + return -EOPNOTSUPP; + + trace_rdev_set_wiphy_params(&rdev->wiphy, changed); + ret = rdev->ops->set_wiphy_params(&rdev->wiphy, changed); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_tx_power(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + enum nl80211_tx_power_setting type, int mbm) +{ + int ret; + trace_rdev_set_tx_power(&rdev->wiphy, wdev, type, mbm); + ret = rdev->ops->set_tx_power(&rdev->wiphy, wdev, type, mbm); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_get_tx_power(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, int *dbm) +{ + int ret; + trace_rdev_get_tx_power(&rdev->wiphy, wdev); + ret = rdev->ops->get_tx_power(&rdev->wiphy, wdev, dbm); + trace_rdev_return_int_int(&rdev->wiphy, ret, *dbm); + return ret; +} + +static inline int +rdev_set_multicast_to_unicast(struct cfg80211_registered_device *rdev, + struct net_device *dev, + const bool enabled) +{ + int ret; + trace_rdev_set_multicast_to_unicast(&rdev->wiphy, dev, enabled); + ret = rdev->ops->set_multicast_to_unicast(&rdev->wiphy, dev, enabled); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_get_txq_stats(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_txq_stats *txqstats) +{ + int ret; + trace_rdev_get_txq_stats(&rdev->wiphy, wdev); + ret = rdev->ops->get_txq_stats(&rdev->wiphy, wdev, txqstats); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_rfkill_poll(struct cfg80211_registered_device *rdev) +{ + trace_rdev_rfkill_poll(&rdev->wiphy); + rdev->ops->rfkill_poll(&rdev->wiphy); + trace_rdev_return_void(&rdev->wiphy); +} + + +#ifdef CONFIG_NL80211_TESTMODE +static inline int rdev_testmode_cmd(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + void *data, int len) +{ + int ret; + trace_rdev_testmode_cmd(&rdev->wiphy, wdev); + ret = rdev->ops->testmode_cmd(&rdev->wiphy, wdev, data, len); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_testmode_dump(struct cfg80211_registered_device *rdev, + struct sk_buff *skb, + struct netlink_callback *cb, void *data, + int len) +{ + int ret; + trace_rdev_testmode_dump(&rdev->wiphy); + ret = rdev->ops->testmode_dump(&rdev->wiphy, skb, cb, data, len); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} +#endif + +static inline int +rdev_set_bitrate_mask(struct cfg80211_registered_device *rdev, + struct net_device *dev, unsigned int link_id, + const u8 *peer, + const struct cfg80211_bitrate_mask *mask) +{ + int ret; + trace_rdev_set_bitrate_mask(&rdev->wiphy, dev, link_id, peer, mask); + ret = rdev->ops->set_bitrate_mask(&rdev->wiphy, dev, link_id, + peer, mask); + 
trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_dump_survey(struct cfg80211_registered_device *rdev, + struct net_device *netdev, int idx, + struct survey_info *info) +{ + int ret; + trace_rdev_dump_survey(&rdev->wiphy, netdev, idx); + ret = rdev->ops->dump_survey(&rdev->wiphy, netdev, idx, info); + if (ret < 0) + trace_rdev_return_int(&rdev->wiphy, ret); + else + trace_rdev_return_int_survey_info(&rdev->wiphy, ret, info); + return ret; +} + +static inline int rdev_set_pmksa(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_pmksa *pmksa) +{ + int ret; + trace_rdev_set_pmksa(&rdev->wiphy, netdev, pmksa); + ret = rdev->ops->set_pmksa(&rdev->wiphy, netdev, pmksa); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_del_pmksa(struct cfg80211_registered_device *rdev, + struct net_device *netdev, + struct cfg80211_pmksa *pmksa) +{ + int ret; + trace_rdev_del_pmksa(&rdev->wiphy, netdev, pmksa); + ret = rdev->ops->del_pmksa(&rdev->wiphy, netdev, pmksa); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_flush_pmksa(struct cfg80211_registered_device *rdev, + struct net_device *netdev) +{ + int ret; + trace_rdev_flush_pmksa(&rdev->wiphy, netdev); + ret = rdev->ops->flush_pmksa(&rdev->wiphy, netdev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_remain_on_channel(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct ieee80211_channel *chan, + unsigned int duration, u64 *cookie) +{ + int ret; + trace_rdev_remain_on_channel(&rdev->wiphy, wdev, chan, duration); + ret = rdev->ops->remain_on_channel(&rdev->wiphy, wdev, chan, + duration, cookie); + trace_rdev_return_int_cookie(&rdev->wiphy, ret, *cookie); + return ret; +} + +static inline int +rdev_cancel_remain_on_channel(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u64 cookie) +{ + int ret; + trace_rdev_cancel_remain_on_channel(&rdev->wiphy, wdev, cookie); + ret = rdev->ops->cancel_remain_on_channel(&rdev->wiphy, wdev, cookie); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_mgmt_tx(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params, + u64 *cookie) +{ + int ret; + trace_rdev_mgmt_tx(&rdev->wiphy, wdev, params); + ret = rdev->ops->mgmt_tx(&rdev->wiphy, wdev, params, cookie); + trace_rdev_return_int_cookie(&rdev->wiphy, ret, *cookie); + return ret; +} + +static inline int rdev_tx_control_port(struct cfg80211_registered_device *rdev, + struct net_device *dev, + const void *buf, size_t len, + const u8 *dest, __be16 proto, + const bool noencrypt, int link, + u64 *cookie) +{ + int ret; + trace_rdev_tx_control_port(&rdev->wiphy, dev, buf, len, + dest, proto, noencrypt, link); + ret = rdev->ops->tx_control_port(&rdev->wiphy, dev, buf, len, + dest, proto, noencrypt, link, cookie); + if (cookie) + trace_rdev_return_int_cookie(&rdev->wiphy, ret, *cookie); + else + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_mgmt_tx_cancel_wait(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u64 cookie) +{ + int ret; + trace_rdev_mgmt_tx_cancel_wait(&rdev->wiphy, wdev, cookie); + ret = rdev->ops->mgmt_tx_cancel_wait(&rdev->wiphy, wdev, cookie); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_power_mgmt(struct cfg80211_registered_device *rdev, + 
struct net_device *dev, bool enabled, + int timeout) +{ + int ret; + trace_rdev_set_power_mgmt(&rdev->wiphy, dev, enabled, timeout); + ret = rdev->ops->set_power_mgmt(&rdev->wiphy, dev, enabled, timeout); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_cqm_rssi_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, s32 rssi_thold, u32 rssi_hyst) +{ + int ret; + trace_rdev_set_cqm_rssi_config(&rdev->wiphy, dev, rssi_thold, + rssi_hyst); + ret = rdev->ops->set_cqm_rssi_config(&rdev->wiphy, dev, rssi_thold, + rssi_hyst); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_cqm_rssi_range_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, s32 low, s32 high) +{ + int ret; + trace_rdev_set_cqm_rssi_range_config(&rdev->wiphy, dev, low, high); + ret = rdev->ops->set_cqm_rssi_range_config(&rdev->wiphy, dev, + low, high); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_cqm_txe_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, u32 rate, u32 pkts, u32 intvl) +{ + int ret; + trace_rdev_set_cqm_txe_config(&rdev->wiphy, dev, rate, pkts, intvl); + ret = rdev->ops->set_cqm_txe_config(&rdev->wiphy, dev, rate, pkts, + intvl); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void +rdev_update_mgmt_frame_registrations(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct mgmt_frame_regs *upd) +{ + might_sleep(); + + trace_rdev_update_mgmt_frame_registrations(&rdev->wiphy, wdev, upd); + if (rdev->ops->update_mgmt_frame_registrations) + rdev->ops->update_mgmt_frame_registrations(&rdev->wiphy, wdev, + upd); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int rdev_set_antenna(struct cfg80211_registered_device *rdev, + u32 tx_ant, u32 rx_ant) +{ + int ret; + trace_rdev_set_antenna(&rdev->wiphy, tx_ant, rx_ant); + ret = rdev->ops->set_antenna(&rdev->wiphy, tx_ant, rx_ant); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_get_antenna(struct cfg80211_registered_device *rdev, + u32 *tx_ant, u32 *rx_ant) +{ + int ret; + trace_rdev_get_antenna(&rdev->wiphy); + ret = rdev->ops->get_antenna(&rdev->wiphy, tx_ant, rx_ant); + if (ret) + trace_rdev_return_int(&rdev->wiphy, ret); + else + trace_rdev_return_int_tx_rx(&rdev->wiphy, ret, *tx_ant, + *rx_ant); + return ret; +} + +static inline int +rdev_sched_scan_start(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_sched_scan_request *request) +{ + int ret; + trace_rdev_sched_scan_start(&rdev->wiphy, dev, request->reqid); + ret = rdev->ops->sched_scan_start(&rdev->wiphy, dev, request); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_sched_scan_stop(struct cfg80211_registered_device *rdev, + struct net_device *dev, u64 reqid) +{ + int ret; + trace_rdev_sched_scan_stop(&rdev->wiphy, dev, reqid); + ret = rdev->ops->sched_scan_stop(&rdev->wiphy, dev, reqid); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_rekey_data(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_gtk_rekey_data *data) +{ + int ret; + trace_rdev_set_rekey_data(&rdev->wiphy, dev); + ret = rdev->ops->set_rekey_data(&rdev->wiphy, dev, data); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_tdls_mgmt(struct cfg80211_registered_device *rdev, 
+ struct net_device *dev, u8 *peer, + u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *buf, size_t len) +{ + int ret; + trace_rdev_tdls_mgmt(&rdev->wiphy, dev, peer, action_code, + dialog_token, status_code, peer_capability, + initiator, buf, len); + ret = rdev->ops->tdls_mgmt(&rdev->wiphy, dev, peer, action_code, + dialog_token, status_code, peer_capability, + initiator, buf, len); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_tdls_oper(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 *peer, + enum nl80211_tdls_operation oper) +{ + int ret; + trace_rdev_tdls_oper(&rdev->wiphy, dev, peer, oper); + ret = rdev->ops->tdls_oper(&rdev->wiphy, dev, peer, oper); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_probe_client(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *peer, + u64 *cookie) +{ + int ret; + trace_rdev_probe_client(&rdev->wiphy, dev, peer); + ret = rdev->ops->probe_client(&rdev->wiphy, dev, peer, cookie); + trace_rdev_return_int_cookie(&rdev->wiphy, ret, *cookie); + return ret; +} + +static inline int rdev_set_noack_map(struct cfg80211_registered_device *rdev, + struct net_device *dev, u16 noack_map) +{ + int ret; + trace_rdev_set_noack_map(&rdev->wiphy, dev, noack_map); + ret = rdev->ops->set_noack_map(&rdev->wiphy, dev, noack_map); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_get_channel(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + unsigned int link_id, + struct cfg80211_chan_def *chandef) +{ + int ret; + + trace_rdev_get_channel(&rdev->wiphy, wdev, link_id); + ret = rdev->ops->get_channel(&rdev->wiphy, wdev, link_id, chandef); + trace_rdev_return_chandef(&rdev->wiphy, ret, chandef); + + return ret; +} + +static inline int rdev_start_p2p_device(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + int ret; + + trace_rdev_start_p2p_device(&rdev->wiphy, wdev); + ret = rdev->ops->start_p2p_device(&rdev->wiphy, wdev); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_stop_p2p_device(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + trace_rdev_stop_p2p_device(&rdev->wiphy, wdev); + rdev->ops->stop_p2p_device(&rdev->wiphy, wdev); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int rdev_start_nan(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf) +{ + int ret; + + trace_rdev_start_nan(&rdev->wiphy, wdev, conf); + ret = rdev->ops->start_nan(&rdev->wiphy, wdev, conf); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_stop_nan(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + trace_rdev_stop_nan(&rdev->wiphy, wdev); + rdev->ops->stop_nan(&rdev->wiphy, wdev); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int +rdev_add_nan_func(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_nan_func *nan_func) +{ + int ret; + + trace_rdev_add_nan_func(&rdev->wiphy, wdev, nan_func); + ret = rdev->ops->add_nan_func(&rdev->wiphy, wdev, nan_func); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_del_nan_func(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, u64 cookie) +{ + trace_rdev_del_nan_func(&rdev->wiphy, wdev, cookie); + 
rdev->ops->del_nan_func(&rdev->wiphy, wdev, cookie); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int +rdev_nan_change_conf(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf, u32 changes) +{ + int ret; + + trace_rdev_nan_change_conf(&rdev->wiphy, wdev, conf, changes); + if (rdev->ops->nan_change_conf) + ret = rdev->ops->nan_change_conf(&rdev->wiphy, wdev, conf, + changes); + else + ret = -ENOTSUPP; + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_mac_acl(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_acl_data *params) +{ + int ret; + + trace_rdev_set_mac_acl(&rdev->wiphy, dev, params); + ret = rdev->ops->set_mac_acl(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_update_ft_ies(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_update_ft_ies_params *ftie) +{ + int ret; + + trace_rdev_update_ft_ies(&rdev->wiphy, dev, ftie); + ret = rdev->ops->update_ft_ies(&rdev->wiphy, dev, ftie); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_crit_proto_start(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + enum nl80211_crit_proto_id protocol, + u16 duration) +{ + int ret; + + trace_rdev_crit_proto_start(&rdev->wiphy, wdev, protocol, duration); + ret = rdev->ops->crit_proto_start(&rdev->wiphy, wdev, + protocol, duration); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void rdev_crit_proto_stop(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + trace_rdev_crit_proto_stop(&rdev->wiphy, wdev); + rdev->ops->crit_proto_stop(&rdev->wiphy, wdev); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int rdev_channel_switch(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_csa_settings *params) +{ + int ret; + + trace_rdev_channel_switch(&rdev->wiphy, dev, params); + ret = rdev->ops->channel_switch(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_qos_map(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_qos_map *qos_map) +{ + int ret = -EOPNOTSUPP; + + if (rdev->ops->set_qos_map) { + trace_rdev_set_qos_map(&rdev->wiphy, dev, qos_map); + ret = rdev->ops->set_qos_map(&rdev->wiphy, dev, qos_map); + trace_rdev_return_int(&rdev->wiphy, ret); + } + + return ret; +} + +static inline int +rdev_set_ap_chanwidth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + unsigned int link_id, + struct cfg80211_chan_def *chandef) +{ + int ret; + + trace_rdev_set_ap_chanwidth(&rdev->wiphy, dev, link_id, chandef); + ret = rdev->ops->set_ap_chanwidth(&rdev->wiphy, dev, link_id, chandef); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int +rdev_add_tx_ts(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 tsid, const u8 *peer, + u8 user_prio, u16 admitted_time) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_add_tx_ts(&rdev->wiphy, dev, tsid, peer, + user_prio, admitted_time); + if (rdev->ops->add_tx_ts) + ret = rdev->ops->add_tx_ts(&rdev->wiphy, dev, tsid, peer, + user_prio, admitted_time); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int +rdev_del_tx_ts(struct cfg80211_registered_device *rdev, + struct net_device *dev, u8 tsid, const 
u8 *peer) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_del_tx_ts(&rdev->wiphy, dev, tsid, peer); + if (rdev->ops->del_tx_ts) + ret = rdev->ops->del_tx_ts(&rdev->wiphy, dev, tsid, peer); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int +rdev_tdls_channel_switch(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *addr, + u8 oper_class, struct cfg80211_chan_def *chandef) +{ + int ret; + + trace_rdev_tdls_channel_switch(&rdev->wiphy, dev, addr, oper_class, + chandef); + ret = rdev->ops->tdls_channel_switch(&rdev->wiphy, dev, addr, + oper_class, chandef); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void +rdev_tdls_cancel_channel_switch(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *addr) +{ + trace_rdev_tdls_cancel_channel_switch(&rdev->wiphy, dev, addr); + rdev->ops->tdls_cancel_channel_switch(&rdev->wiphy, dev, addr); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int +rdev_start_radar_detection(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_chan_def *chandef, + u32 cac_time_ms) +{ + int ret = -ENOTSUPP; + + trace_rdev_start_radar_detection(&rdev->wiphy, dev, chandef, + cac_time_ms); + if (rdev->ops->start_radar_detection) + ret = rdev->ops->start_radar_detection(&rdev->wiphy, dev, + chandef, cac_time_ms); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void +rdev_end_cac(struct cfg80211_registered_device *rdev, + struct net_device *dev) +{ + trace_rdev_end_cac(&rdev->wiphy, dev); + if (rdev->ops->end_cac) + rdev->ops->end_cac(&rdev->wiphy, dev); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int +rdev_set_mcast_rate(struct cfg80211_registered_device *rdev, + struct net_device *dev, + int mcast_rate[NUM_NL80211_BANDS]) +{ + int ret = -ENOTSUPP; + + trace_rdev_set_mcast_rate(&rdev->wiphy, dev, mcast_rate); + if (rdev->ops->set_mcast_rate) + ret = rdev->ops->set_mcast_rate(&rdev->wiphy, dev, mcast_rate); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_set_coalesce(struct cfg80211_registered_device *rdev, + struct cfg80211_coalesce *coalesce) +{ + int ret = -ENOTSUPP; + + trace_rdev_set_coalesce(&rdev->wiphy, coalesce); + if (rdev->ops->set_coalesce) + ret = rdev->ops->set_coalesce(&rdev->wiphy, coalesce); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_pmk(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_pmk_conf *pmk_conf) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_set_pmk(&rdev->wiphy, dev, pmk_conf); + if (rdev->ops->set_pmk) + ret = rdev->ops->set_pmk(&rdev->wiphy, dev, pmk_conf); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_del_pmk(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *aa) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_del_pmk(&rdev->wiphy, dev, aa); + if (rdev->ops->del_pmk) + ret = rdev->ops->del_pmk(&rdev->wiphy, dev, aa); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_external_auth(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_external_auth_params *params) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_external_auth(&rdev->wiphy, dev, params); + if (rdev->ops->external_auth) + ret = rdev->ops->external_auth(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static 
inline int +rdev_get_ftm_responder_stats(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_ftm_responder_stats *ftm_stats) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_get_ftm_responder_stats(&rdev->wiphy, dev, ftm_stats); + if (rdev->ops->get_ftm_responder_stats) + ret = rdev->ops->get_ftm_responder_stats(&rdev->wiphy, dev, + ftm_stats); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_start_pmsr(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_pmsr_request *request) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_start_pmsr(&rdev->wiphy, wdev, request->cookie); + if (rdev->ops->start_pmsr) + ret = rdev->ops->start_pmsr(&rdev->wiphy, wdev, request); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline void +rdev_abort_pmsr(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + struct cfg80211_pmsr_request *request) +{ + trace_rdev_abort_pmsr(&rdev->wiphy, wdev, request->cookie); + if (rdev->ops->abort_pmsr) + rdev->ops->abort_pmsr(&rdev->wiphy, wdev, request); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int rdev_update_owe_info(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_update_owe_info *oweinfo) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_update_owe_info(&rdev->wiphy, dev, oweinfo); + if (rdev->ops->update_owe_info) + ret = rdev->ops->update_owe_info(&rdev->wiphy, dev, oweinfo); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_probe_mesh_link(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *dest, + const void *buf, size_t len) +{ + int ret; + + trace_rdev_probe_mesh_link(&rdev->wiphy, dev, dest, buf, len); + ret = rdev->ops->probe_mesh_link(&rdev->wiphy, dev, buf, len); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_tid_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_tid_config *tid_conf) +{ + int ret; + + trace_rdev_set_tid_config(&rdev->wiphy, dev, tid_conf); + ret = rdev->ops->set_tid_config(&rdev->wiphy, dev, tid_conf); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_reset_tid_config(struct cfg80211_registered_device *rdev, + struct net_device *dev, const u8 *peer, + u8 tids) +{ + int ret; + + trace_rdev_reset_tid_config(&rdev->wiphy, dev, peer, tids); + ret = rdev->ops->reset_tid_config(&rdev->wiphy, dev, peer, tids); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int rdev_set_sar_specs(struct cfg80211_registered_device *rdev, + struct cfg80211_sar_specs *sar) +{ + int ret; + + trace_rdev_set_sar_specs(&rdev->wiphy, sar); + ret = rdev->ops->set_sar_specs(&rdev->wiphy, sar); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int rdev_color_change(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_color_change_settings *params) +{ + int ret; + + trace_rdev_color_change(&rdev->wiphy, dev, params); + ret = rdev->ops->color_change(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int +rdev_set_fils_aad(struct cfg80211_registered_device *rdev, + struct net_device *dev, struct cfg80211_fils_aad *fils_aad) +{ + int ret = -EOPNOTSUPP; + + trace_rdev_set_fils_aad(&rdev->wiphy, dev, fils_aad); + if (rdev->ops->set_fils_aad) + ret = 
rdev->ops->set_fils_aad(&rdev->wiphy, dev, fils_aad); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline int +rdev_set_radar_background(struct cfg80211_registered_device *rdev, + struct cfg80211_chan_def *chandef) +{ + struct wiphy *wiphy = &rdev->wiphy; + int ret; + + if (!rdev->ops->set_radar_background) + return -EOPNOTSUPP; + + trace_rdev_set_radar_background(wiphy, chandef); + ret = rdev->ops->set_radar_background(wiphy, chandef); + trace_rdev_return_int(wiphy, ret); + + return ret; +} + +static inline int +rdev_add_intf_link(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + unsigned int link_id) +{ + int ret = 0; + + trace_rdev_add_intf_link(&rdev->wiphy, wdev, link_id); + if (rdev->ops->add_intf_link) + ret = rdev->ops->add_intf_link(&rdev->wiphy, wdev, link_id); + trace_rdev_return_int(&rdev->wiphy, ret); + + return ret; +} + +static inline void +rdev_del_intf_link(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev, + unsigned int link_id) +{ + trace_rdev_del_intf_link(&rdev->wiphy, wdev, link_id); + if (rdev->ops->del_intf_link) + rdev->ops->del_intf_link(&rdev->wiphy, wdev, link_id); + trace_rdev_return_void(&rdev->wiphy); +} + +static inline int +rdev_add_link_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct link_station_parameters *params) +{ + int ret; + + if (!rdev->ops->add_link_station) + return -EOPNOTSUPP; + + trace_rdev_add_link_station(&rdev->wiphy, dev, params); + ret = rdev->ops->add_link_station(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_mod_link_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct link_station_parameters *params) +{ + int ret; + + if (!rdev->ops->mod_link_station) + return -EOPNOTSUPP; + + trace_rdev_mod_link_station(&rdev->wiphy, dev, params); + ret = rdev->ops->mod_link_station(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +static inline int +rdev_del_link_station(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct link_station_del_parameters *params) +{ + int ret; + + if (!rdev->ops->del_link_station) + return -EOPNOTSUPP; + + trace_rdev_del_link_station(&rdev->wiphy, dev, params); + ret = rdev->ops->del_link_station(&rdev->wiphy, dev, params); + trace_rdev_return_int(&rdev->wiphy, ret); + return ret; +} + +#endif /* __CFG80211_RDEV_OPS */ diff --git a/net/wireless/reg.c b/net/wireless/reg.c new file mode 100644 index 000000000..5da1a641e --- /dev/null +++ b/net/wireless/reg.c @@ -0,0 +1,4410 @@ +/* + * Copyright 2002-2005, Instant802 Networks, Inc. + * Copyright 2005-2006, Devicescape Software, Inc. + * Copyright 2007 Johannes Berg + * Copyright 2008-2011 Luis R. Rodriguez + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2017 Intel Deutschland GmbH + * Copyright (C) 2018 - 2022 Intel Corporation + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + + +/** + * DOC: Wireless regulatory infrastructure + * + * The usual implementation is for a driver to read a device EEPROM to + * determine which regulatory domain it should be operating under, then + * looking up the allowable channels in a driver-local table and finally + * registering those channels in the wiphy structure. + * + * Another set of compliance enforcement is for drivers to use their + * own compliance limits which can be stored on the EEPROM. The host + * driver or firmware may ensure these are used. + * + * In addition to all this we provide an extra layer of regulatory + * conformance. For drivers which do not have any regulatory + * information CRDA provides the complete regulatory solution. + * For others it provides a community effort on further restrictions + * to enhance compliance. + * + * Note: When number of rules --> infinity we will not be able to + * index on alpha2 any more, instead we'll probably have to + * rely on some SHA1 checksum of the regdomain for example. + * + */ + +#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "reg.h" +#include "rdev-ops.h" +#include "nl80211.h" + +/* + * Grace period we give before making sure all current interfaces reside on + * channels allowed by the current regulatory domain. + */ +#define REG_ENFORCE_GRACE_MS 60000 + +/** + * enum reg_request_treatment - regulatory request treatment + * + * @REG_REQ_OK: continue processing the regulatory request + * @REG_REQ_IGNORE: ignore the regulatory request + * @REG_REQ_INTERSECT: the regulatory domain resulting from this request should + * be intersected with the current one. + * @REG_REQ_ALREADY_SET: the regulatory request will not change the current + * regulatory settings, and no further processing is required. + */ +enum reg_request_treatment { + REG_REQ_OK, + REG_REQ_IGNORE, + REG_REQ_INTERSECT, + REG_REQ_ALREADY_SET, +}; + +static struct regulatory_request core_request_world = { + .initiator = NL80211_REGDOM_SET_BY_CORE, + .alpha2[0] = '0', + .alpha2[1] = '0', + .intersect = false, + .processed = true, + .country_ie_env = ENVIRON_ANY, +}; + +/* + * Receipt of information from last regulatory request, + * protected by RTNL (and can be accessed with RCU protection) + */ +static struct regulatory_request __rcu *last_request = + (void __force __rcu *)&core_request_world; + +/* To trigger userspace events and load firmware */ +static struct platform_device *reg_pdev; + +/* + * Central wireless core regulatory domains, we only need two, + * the current one and a world regulatory domain in case we have no + * information to give us an alpha2. + * (protected by RTNL, can be read under RCU) + */ +const struct ieee80211_regdomain __rcu *cfg80211_regdomain; + +/* + * Number of devices that registered to the core + * that support cellular base station regulatory hints + * (protected by RTNL) + */ +static int reg_num_devs_support_basehint; + +/* + * State variable indicating if the platform on which the devices + * are attached is operating in an indoor environment. 
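Returning to the driver-side flow described in the DOC comment above (a driver derives its own limits, e.g. from EEPROM, and registers only the channels those limits allow): a minimal sketch of that pattern is a custom regdomain built from REG_RULE() entries and applied to the wiphy. The rule contents and the example_-prefixed names are invented for the illustration; REG_RULE(), REGULATORY_CUSTOM_REG and wiphy_apply_custom_regulatory() are the existing cfg80211 interfaces:

    #include <net/cfg80211.h>

    /* "2.4 GHz channels 1-11 only", expressed as a driver-local regdomain */
    static const struct ieee80211_regdomain example_custom_regdom = {
            .n_reg_rules = 1,
            .alpha2 = "99",	/* built by driver, no specific alpha2 (see below) */
            .reg_rules = {
                    /* 2412-2462 MHz, max 40 MHz BW, 6 dBi antenna gain, 20 dBm EIRP */
                    REG_RULE(2412 - 10, 2462 + 10, 40, 6, 20, 0),
            },
    };

    static void example_apply_custom_regd(struct wiphy *wiphy)
    {
            /* tell cfg80211 that this wiphy carries its own rules ... */
            wiphy->regulatory_flags |= REGULATORY_CUSTOM_REG;
            /* ... and apply them to the channels the driver registered */
            wiphy_apply_custom_regulatory(wiphy, &example_custom_regdom);
    }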
The state variable + * is relevant for all registered devices. + */ +static bool reg_is_indoor; +static DEFINE_SPINLOCK(reg_indoor_lock); + +/* Used to track the userspace process controlling the indoor setting */ +static u32 reg_is_indoor_portid; + +static void restore_regulatory_settings(bool reset_user, bool cached); +static void print_regdomain(const struct ieee80211_regdomain *rd); +static void reg_process_hint(struct regulatory_request *reg_request); + +static const struct ieee80211_regdomain *get_cfg80211_regdom(void) +{ + return rcu_dereference_rtnl(cfg80211_regdomain); +} + +/* + * Returns the regulatory domain associated with the wiphy. + * + * Requires any of RTNL, wiphy mutex or RCU protection. + */ +const struct ieee80211_regdomain *get_wiphy_regdom(struct wiphy *wiphy) +{ + return rcu_dereference_check(wiphy->regd, + lockdep_is_held(&wiphy->mtx) || + lockdep_rtnl_is_held()); +} +EXPORT_SYMBOL(get_wiphy_regdom); + +static const char *reg_dfs_region_str(enum nl80211_dfs_regions dfs_region) +{ + switch (dfs_region) { + case NL80211_DFS_UNSET: + return "unset"; + case NL80211_DFS_FCC: + return "FCC"; + case NL80211_DFS_ETSI: + return "ETSI"; + case NL80211_DFS_JP: + return "JP"; + } + return "Unknown"; +} + +enum nl80211_dfs_regions reg_get_dfs_region(struct wiphy *wiphy) +{ + const struct ieee80211_regdomain *regd = NULL; + const struct ieee80211_regdomain *wiphy_regd = NULL; + enum nl80211_dfs_regions dfs_region; + + rcu_read_lock(); + regd = get_cfg80211_regdom(); + dfs_region = regd->dfs_region; + + if (!wiphy) + goto out; + + wiphy_regd = get_wiphy_regdom(wiphy); + if (!wiphy_regd) + goto out; + + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) { + dfs_region = wiphy_regd->dfs_region; + goto out; + } + + if (wiphy_regd->dfs_region == regd->dfs_region) + goto out; + + pr_debug("%s: device specific dfs_region (%s) disagrees with cfg80211's central dfs_region (%s)\n", + dev_name(&wiphy->dev), + reg_dfs_region_str(wiphy_regd->dfs_region), + reg_dfs_region_str(regd->dfs_region)); + +out: + rcu_read_unlock(); + + return dfs_region; +} + +static void rcu_free_regdom(const struct ieee80211_regdomain *r) +{ + if (!r) + return; + kfree_rcu((struct ieee80211_regdomain *)r, rcu_head); +} + +static struct regulatory_request *get_last_request(void) +{ + return rcu_dereference_rtnl(last_request); +} + +/* Used to queue up regulatory hints */ +static LIST_HEAD(reg_requests_list); +static DEFINE_SPINLOCK(reg_requests_lock); + +/* Used to queue up beacon hints for review */ +static LIST_HEAD(reg_pending_beacons); +static DEFINE_SPINLOCK(reg_pending_beacons_lock); + +/* Used to keep track of processed beacon hints */ +static LIST_HEAD(reg_beacon_list); + +struct reg_beacon { + struct list_head list; + struct ieee80211_channel chan; +}; + +static void reg_check_chans_work(struct work_struct *work); +static DECLARE_DELAYED_WORK(reg_check_chans, reg_check_chans_work); + +static void reg_todo(struct work_struct *work); +static DECLARE_WORK(reg_work, reg_todo); + +/* We keep a static world regulatory domain in case of the absence of CRDA */ +static const struct ieee80211_regdomain world_regdom = { + .n_reg_rules = 8, + .alpha2 = "00", + .reg_rules = { + /* IEEE 802.11b/g, channels 1..11 */ + REG_RULE(2412-10, 2462+10, 40, 6, 20, 0), + /* IEEE 802.11b/g, channels 12..13. 
*/ + REG_RULE(2467-10, 2472+10, 20, 6, 20, + NL80211_RRF_NO_IR | NL80211_RRF_AUTO_BW), + /* IEEE 802.11 channel 14 - Only JP enables + * this and for 802.11b only */ + REG_RULE(2484-10, 2484+10, 20, 6, 20, + NL80211_RRF_NO_IR | + NL80211_RRF_NO_OFDM), + /* IEEE 802.11a, channel 36..48 */ + REG_RULE(5180-10, 5240+10, 80, 6, 20, + NL80211_RRF_NO_IR | + NL80211_RRF_AUTO_BW), + + /* IEEE 802.11a, channel 52..64 - DFS required */ + REG_RULE(5260-10, 5320+10, 80, 6, 20, + NL80211_RRF_NO_IR | + NL80211_RRF_AUTO_BW | + NL80211_RRF_DFS), + + /* IEEE 802.11a, channel 100..144 - DFS required */ + REG_RULE(5500-10, 5720+10, 160, 6, 20, + NL80211_RRF_NO_IR | + NL80211_RRF_DFS), + + /* IEEE 802.11a, channel 149..165 */ + REG_RULE(5745-10, 5825+10, 80, 6, 20, + NL80211_RRF_NO_IR), + + /* IEEE 802.11ad (60GHz), channels 1..3 */ + REG_RULE(56160+2160*1-1080, 56160+2160*3+1080, 2160, 0, 0, 0), + } +}; + +/* protected by RTNL */ +static const struct ieee80211_regdomain *cfg80211_world_regdom = + &world_regdom; + +static char *ieee80211_regdom = "00"; +static char user_alpha2[2]; +static const struct ieee80211_regdomain *cfg80211_user_regdom; + +module_param(ieee80211_regdom, charp, 0444); +MODULE_PARM_DESC(ieee80211_regdom, "IEEE 802.11 regulatory domain code"); + +static void reg_free_request(struct regulatory_request *request) +{ + if (request == &core_request_world) + return; + + if (request != get_last_request()) + kfree(request); +} + +static void reg_free_last_request(void) +{ + struct regulatory_request *lr = get_last_request(); + + if (lr != &core_request_world && lr) + kfree_rcu(lr, rcu_head); +} + +static void reg_update_last_request(struct regulatory_request *request) +{ + struct regulatory_request *lr; + + lr = get_last_request(); + if (lr == request) + return; + + reg_free_last_request(); + rcu_assign_pointer(last_request, request); +} + +static void reset_regdomains(bool full_reset, + const struct ieee80211_regdomain *new_regdom) +{ + const struct ieee80211_regdomain *r; + + ASSERT_RTNL(); + + r = get_cfg80211_regdom(); + + /* avoid freeing static information or freeing something twice */ + if (r == cfg80211_world_regdom) + r = NULL; + if (cfg80211_world_regdom == &world_regdom) + cfg80211_world_regdom = NULL; + if (r == &world_regdom) + r = NULL; + + rcu_free_regdom(r); + rcu_free_regdom(cfg80211_world_regdom); + + cfg80211_world_regdom = &world_regdom; + rcu_assign_pointer(cfg80211_regdomain, new_regdom); + + if (!full_reset) + return; + + reg_update_last_request(&core_request_world); +} + +/* + * Dynamic world regulatory domain requested by the wireless + * core upon initialization + */ +static void update_world_regdomain(const struct ieee80211_regdomain *rd) +{ + struct regulatory_request *lr; + + lr = get_last_request(); + + WARN_ON(!lr); + + reset_regdomains(false, rd); + + cfg80211_world_regdom = rd; +} + +bool is_world_regdom(const char *alpha2) +{ + if (!alpha2) + return false; + return alpha2[0] == '0' && alpha2[1] == '0'; +} + +static bool is_alpha2_set(const char *alpha2) +{ + if (!alpha2) + return false; + return alpha2[0] && alpha2[1]; +} + +static bool is_unknown_alpha2(const char *alpha2) +{ + if (!alpha2) + return false; + /* + * Special case where regulatory domain was built by driver + * but a specific alpha2 cannot be determined + */ + return alpha2[0] == '9' && alpha2[1] == '9'; +} + +static bool is_intersected_alpha2(const char *alpha2) +{ + if (!alpha2) + return false; + /* + * Special case where regulatory domain is the + * result of an intersection between two 
regulatory domain + * structures + */ + return alpha2[0] == '9' && alpha2[1] == '8'; +} + +static bool is_an_alpha2(const char *alpha2) +{ + if (!alpha2) + return false; + return isalpha(alpha2[0]) && isalpha(alpha2[1]); +} + +static bool alpha2_equal(const char *alpha2_x, const char *alpha2_y) +{ + if (!alpha2_x || !alpha2_y) + return false; + return alpha2_x[0] == alpha2_y[0] && alpha2_x[1] == alpha2_y[1]; +} + +static bool regdom_changes(const char *alpha2) +{ + const struct ieee80211_regdomain *r = get_cfg80211_regdom(); + + if (!r) + return true; + return !alpha2_equal(r->alpha2, alpha2); +} + +/* + * The NL80211_REGDOM_SET_BY_USER regdom alpha2 is cached, this lets + * you know if a valid regulatory hint with NL80211_REGDOM_SET_BY_USER + * has ever been issued. + */ +static bool is_user_regdom_saved(void) +{ + if (user_alpha2[0] == '9' && user_alpha2[1] == '7') + return false; + + /* This would indicate a mistake on the design */ + if (WARN(!is_world_regdom(user_alpha2) && !is_an_alpha2(user_alpha2), + "Unexpected user alpha2: %c%c\n", + user_alpha2[0], user_alpha2[1])) + return false; + + return true; +} + +static const struct ieee80211_regdomain * +reg_copy_regd(const struct ieee80211_regdomain *src_regd) +{ + struct ieee80211_regdomain *regd; + unsigned int i; + + regd = kzalloc(struct_size(regd, reg_rules, src_regd->n_reg_rules), + GFP_KERNEL); + if (!regd) + return ERR_PTR(-ENOMEM); + + memcpy(regd, src_regd, sizeof(struct ieee80211_regdomain)); + + for (i = 0; i < src_regd->n_reg_rules; i++) + memcpy(®d->reg_rules[i], &src_regd->reg_rules[i], + sizeof(struct ieee80211_reg_rule)); + + return regd; +} + +static void cfg80211_save_user_regdom(const struct ieee80211_regdomain *rd) +{ + ASSERT_RTNL(); + + if (!IS_ERR(cfg80211_user_regdom)) + kfree(cfg80211_user_regdom); + cfg80211_user_regdom = reg_copy_regd(rd); +} + +struct reg_regdb_apply_request { + struct list_head list; + const struct ieee80211_regdomain *regdom; +}; + +static LIST_HEAD(reg_regdb_apply_list); +static DEFINE_MUTEX(reg_regdb_apply_mutex); + +static void reg_regdb_apply(struct work_struct *work) +{ + struct reg_regdb_apply_request *request; + + rtnl_lock(); + + mutex_lock(®_regdb_apply_mutex); + while (!list_empty(®_regdb_apply_list)) { + request = list_first_entry(®_regdb_apply_list, + struct reg_regdb_apply_request, + list); + list_del(&request->list); + + set_regdom(request->regdom, REGD_SOURCE_INTERNAL_DB); + kfree(request); + } + mutex_unlock(®_regdb_apply_mutex); + + rtnl_unlock(); +} + +static DECLARE_WORK(reg_regdb_work, reg_regdb_apply); + +static int reg_schedule_apply(const struct ieee80211_regdomain *regdom) +{ + struct reg_regdb_apply_request *request; + + request = kzalloc(sizeof(struct reg_regdb_apply_request), GFP_KERNEL); + if (!request) { + kfree(regdom); + return -ENOMEM; + } + + request->regdom = regdom; + + mutex_lock(®_regdb_apply_mutex); + list_add_tail(&request->list, ®_regdb_apply_list); + mutex_unlock(®_regdb_apply_mutex); + + schedule_work(®_regdb_work); + return 0; +} + +#ifdef CONFIG_CFG80211_CRDA_SUPPORT +/* Max number of consecutive attempts to communicate with CRDA */ +#define REG_MAX_CRDA_TIMEOUTS 10 + +static u32 reg_crda_timeouts; + +static void crda_timeout_work(struct work_struct *work); +static DECLARE_DELAYED_WORK(crda_timeout, crda_timeout_work); + +static void crda_timeout_work(struct work_struct *work) +{ + pr_debug("Timeout while waiting for CRDA to reply, restoring regulatory settings\n"); + rtnl_lock(); + reg_crda_timeouts++; + restore_regulatory_settings(true, false); 
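/*
 * Illustrative walk-through of the CRDA path (the "US" value is a
 * hypothetical hint; everything else mirrors the code in this file):
 * call_crda() below builds a single-entry uevent environment,
 *
 *   char country[12];
 *   char *env[] = { country, NULL };
 *   snprintf(country, sizeof(country), "COUNTRY=%c%c", 'U', 'S');
 *   kobject_uevent_env(&reg_pdev->dev.kobj, KOBJ_CHANGE, env);
 *
 * and userspace (typically a udev rule running the crda helper) is
 * expected to answer through nl80211 with a rule set for that country.
 * If no answer arrives before the ~3.1 s crda_timeout expires, this
 * work item restores the previous settings, and once more than
 * REG_MAX_CRDA_TIMEOUTS (10) consecutive timeouts have accumulated,
 * call_crda() refuses to try again and returns -EINVAL.
 */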
+ rtnl_unlock(); +} + +static void cancel_crda_timeout(void) +{ + cancel_delayed_work(&crda_timeout); +} + +static void cancel_crda_timeout_sync(void) +{ + cancel_delayed_work_sync(&crda_timeout); +} + +static void reset_crda_timeouts(void) +{ + reg_crda_timeouts = 0; +} + +/* + * This lets us keep regulatory code which is updated on a regulatory + * basis in userspace. + */ +static int call_crda(const char *alpha2) +{ + char country[12]; + char *env[] = { country, NULL }; + int ret; + + snprintf(country, sizeof(country), "COUNTRY=%c%c", + alpha2[0], alpha2[1]); + + if (reg_crda_timeouts > REG_MAX_CRDA_TIMEOUTS) { + pr_debug("Exceeded CRDA call max attempts. Not calling CRDA\n"); + return -EINVAL; + } + + if (!is_world_regdom((char *) alpha2)) + pr_debug("Calling CRDA for country: %c%c\n", + alpha2[0], alpha2[1]); + else + pr_debug("Calling CRDA to update world regulatory domain\n"); + + ret = kobject_uevent_env(®_pdev->dev.kobj, KOBJ_CHANGE, env); + if (ret) + return ret; + + queue_delayed_work(system_power_efficient_wq, + &crda_timeout, msecs_to_jiffies(3142)); + return 0; +} +#else +static inline void cancel_crda_timeout(void) {} +static inline void cancel_crda_timeout_sync(void) {} +static inline void reset_crda_timeouts(void) {} +static inline int call_crda(const char *alpha2) +{ + return -ENODATA; +} +#endif /* CONFIG_CFG80211_CRDA_SUPPORT */ + +/* code to directly load a firmware database through request_firmware */ +static const struct fwdb_header *regdb; + +struct fwdb_country { + u8 alpha2[2]; + __be16 coll_ptr; + /* this struct cannot be extended */ +} __packed __aligned(4); + +struct fwdb_collection { + u8 len; + u8 n_rules; + u8 dfs_region; + /* no optional data yet */ + /* aligned to 2, then followed by __be16 array of rule pointers */ +} __packed __aligned(4); + +enum fwdb_flags { + FWDB_FLAG_NO_OFDM = BIT(0), + FWDB_FLAG_NO_OUTDOOR = BIT(1), + FWDB_FLAG_DFS = BIT(2), + FWDB_FLAG_NO_IR = BIT(3), + FWDB_FLAG_AUTO_BW = BIT(4), +}; + +struct fwdb_wmm_ac { + u8 ecw; + u8 aifsn; + __be16 cot; +} __packed; + +struct fwdb_wmm_rule { + struct fwdb_wmm_ac client[IEEE80211_NUM_ACS]; + struct fwdb_wmm_ac ap[IEEE80211_NUM_ACS]; +} __packed; + +struct fwdb_rule { + u8 len; + u8 flags; + __be16 max_eirp; + __be32 start, end, max_bw; + /* start of optional data */ + __be16 cac_timeout; + __be16 wmm_ptr; +} __packed __aligned(4); + +#define FWDB_MAGIC 0x52474442 +#define FWDB_VERSION 20 + +struct fwdb_header { + __be32 magic; + __be32 version; + struct fwdb_country country[]; +} __packed __aligned(4); + +static int ecw2cw(int ecw) +{ + return (1 << ecw) - 1; +} + +static bool valid_wmm(struct fwdb_wmm_rule *rule) +{ + struct fwdb_wmm_ac *ac = (struct fwdb_wmm_ac *)rule; + int i; + + for (i = 0; i < IEEE80211_NUM_ACS * 2; i++) { + u16 cw_min = ecw2cw((ac[i].ecw & 0xf0) >> 4); + u16 cw_max = ecw2cw(ac[i].ecw & 0x0f); + u8 aifsn = ac[i].aifsn; + + if (cw_min >= cw_max) + return false; + + if (aifsn < 1) + return false; + } + + return true; +} + +static bool valid_rule(const u8 *data, unsigned int size, u16 rule_ptr) +{ + struct fwdb_rule *rule = (void *)(data + (rule_ptr << 2)); + + if ((u8 *)rule + sizeof(rule->len) > data + size) + return false; + + /* mandatory fields */ + if (rule->len < offsetofend(struct fwdb_rule, max_bw)) + return false; + if (rule->len >= offsetofend(struct fwdb_rule, wmm_ptr)) { + u32 wmm_ptr = be16_to_cpu(rule->wmm_ptr) << 2; + struct fwdb_wmm_rule *wmm; + + if (wmm_ptr + sizeof(struct fwdb_wmm_rule) > size) + return false; + + wmm = (void *)(data + wmm_ptr); + + if 
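/*
 * Worked example of the fwdb encoding checked here (numbers are
 * hypothetical): every pointer in the blob (coll_ptr, wmm_ptr and the
 * per-collection rule pointer array) counts 4-byte words, so a stored
 * value of 0x0024 addresses byte 0x0024 << 2 == 0x90 of the file.
 * Contention windows are stored as exponents and expanded by ecw2cw():
 * an ecw byte of 0x4a yields cw_min == ecw2cw(4) == 15 and
 * cw_max == ecw2cw(10) == 1023, which valid_wmm() accepts provided the
 * corresponding aifsn is at least 1.
 */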
(!valid_wmm(wmm)) + return false; + } + return true; +} + +static bool valid_country(const u8 *data, unsigned int size, + const struct fwdb_country *country) +{ + unsigned int ptr = be16_to_cpu(country->coll_ptr) << 2; + struct fwdb_collection *coll = (void *)(data + ptr); + __be16 *rules_ptr; + unsigned int i; + + /* make sure we can read len/n_rules */ + if ((u8 *)coll + offsetofend(typeof(*coll), n_rules) > data + size) + return false; + + /* make sure base struct and all rules fit */ + if ((u8 *)coll + ALIGN(coll->len, 2) + + (coll->n_rules * 2) > data + size) + return false; + + /* mandatory fields must exist */ + if (coll->len < offsetofend(struct fwdb_collection, dfs_region)) + return false; + + rules_ptr = (void *)((u8 *)coll + ALIGN(coll->len, 2)); + + for (i = 0; i < coll->n_rules; i++) { + u16 rule_ptr = be16_to_cpu(rules_ptr[i]); + + if (!valid_rule(data, size, rule_ptr)) + return false; + } + + return true; +} + +#ifdef CONFIG_CFG80211_REQUIRE_SIGNED_REGDB +static struct key *builtin_regdb_keys; + +static void __init load_keys_from_buffer(const u8 *p, unsigned int buflen) +{ + const u8 *end = p + buflen; + size_t plen; + key_ref_t key; + + while (p < end) { + /* Each cert begins with an ASN.1 SEQUENCE tag and must be more + * than 256 bytes in size. + */ + if (end - p < 4) + goto dodgy_cert; + if (p[0] != 0x30 && + p[1] != 0x82) + goto dodgy_cert; + plen = (p[2] << 8) | p[3]; + plen += 4; + if (plen > end - p) + goto dodgy_cert; + + key = key_create_or_update(make_key_ref(builtin_regdb_keys, 1), + "asymmetric", NULL, p, plen, + ((KEY_POS_ALL & ~KEY_POS_SETATTR) | + KEY_USR_VIEW | KEY_USR_READ), + KEY_ALLOC_NOT_IN_QUOTA | + KEY_ALLOC_BUILT_IN | + KEY_ALLOC_BYPASS_RESTRICTION); + if (IS_ERR(key)) { + pr_err("Problem loading in-kernel X.509 certificate (%ld)\n", + PTR_ERR(key)); + } else { + pr_notice("Loaded X.509 cert '%s'\n", + key_ref_to_ptr(key)->description); + key_ref_put(key); + } + p += plen; + } + + return; + +dodgy_cert: + pr_err("Problem parsing in-kernel X.509 certificate list\n"); +} + +static int __init load_builtin_regdb_keys(void) +{ + builtin_regdb_keys = + keyring_alloc(".builtin_regdb_keys", + KUIDT_INIT(0), KGIDT_INIT(0), current_cred(), + ((KEY_POS_ALL & ~KEY_POS_SETATTR) | + KEY_USR_VIEW | KEY_USR_READ | KEY_USR_SEARCH), + KEY_ALLOC_NOT_IN_QUOTA, NULL, NULL); + if (IS_ERR(builtin_regdb_keys)) + return PTR_ERR(builtin_regdb_keys); + + pr_notice("Loading compiled-in X.509 certificates for regulatory database\n"); + +#ifdef CONFIG_CFG80211_USE_KERNEL_REGDB_KEYS + load_keys_from_buffer(shipped_regdb_certs, shipped_regdb_certs_len); +#endif +#ifdef CONFIG_CFG80211_EXTRA_REGDB_KEYDIR + if (CONFIG_CFG80211_EXTRA_REGDB_KEYDIR[0] != '\0') + load_keys_from_buffer(extra_regdb_certs, extra_regdb_certs_len); +#endif + + return 0; +} + +MODULE_FIRMWARE("regulatory.db.p7s"); + +static bool regdb_has_valid_signature(const u8 *data, unsigned int size) +{ + const struct firmware *sig; + bool result; + + if (request_firmware(&sig, "regulatory.db.p7s", ®_pdev->dev)) + return false; + + result = verify_pkcs7_signature(data, size, sig->data, sig->size, + builtin_regdb_keys, + VERIFYING_UNSPECIFIED_SIGNATURE, + NULL, NULL) == 0; + + release_firmware(sig); + + return result; +} + +static void free_regdb_keyring(void) +{ + key_put(builtin_regdb_keys); +} +#else +static int load_builtin_regdb_keys(void) +{ + return 0; +} + +static bool regdb_has_valid_signature(const u8 *data, unsigned int size) +{ + return true; +} + +static void free_regdb_keyring(void) +{ +} +#endif /* 
CONFIG_CFG80211_REQUIRE_SIGNED_REGDB */ + +static bool valid_regdb(const u8 *data, unsigned int size) +{ + const struct fwdb_header *hdr = (void *)data; + const struct fwdb_country *country; + + if (size < sizeof(*hdr)) + return false; + + if (hdr->magic != cpu_to_be32(FWDB_MAGIC)) + return false; + + if (hdr->version != cpu_to_be32(FWDB_VERSION)) + return false; + + if (!regdb_has_valid_signature(data, size)) + return false; + + country = &hdr->country[0]; + while ((u8 *)(country + 1) <= data + size) { + if (!country->coll_ptr) + break; + if (!valid_country(data, size, country)) + return false; + country++; + } + + return true; +} + +static void set_wmm_rule(const struct fwdb_header *db, + const struct fwdb_country *country, + const struct fwdb_rule *rule, + struct ieee80211_reg_rule *rrule) +{ + struct ieee80211_wmm_rule *wmm_rule = &rrule->wmm_rule; + struct fwdb_wmm_rule *wmm; + unsigned int i, wmm_ptr; + + wmm_ptr = be16_to_cpu(rule->wmm_ptr) << 2; + wmm = (void *)((u8 *)db + wmm_ptr); + + if (!valid_wmm(wmm)) { + pr_err("Invalid regulatory WMM rule %u-%u in domain %c%c\n", + be32_to_cpu(rule->start), be32_to_cpu(rule->end), + country->alpha2[0], country->alpha2[1]); + return; + } + + for (i = 0; i < IEEE80211_NUM_ACS; i++) { + wmm_rule->client[i].cw_min = + ecw2cw((wmm->client[i].ecw & 0xf0) >> 4); + wmm_rule->client[i].cw_max = ecw2cw(wmm->client[i].ecw & 0x0f); + wmm_rule->client[i].aifsn = wmm->client[i].aifsn; + wmm_rule->client[i].cot = + 1000 * be16_to_cpu(wmm->client[i].cot); + wmm_rule->ap[i].cw_min = ecw2cw((wmm->ap[i].ecw & 0xf0) >> 4); + wmm_rule->ap[i].cw_max = ecw2cw(wmm->ap[i].ecw & 0x0f); + wmm_rule->ap[i].aifsn = wmm->ap[i].aifsn; + wmm_rule->ap[i].cot = 1000 * be16_to_cpu(wmm->ap[i].cot); + } + + rrule->has_wmm = true; +} + +static int __regdb_query_wmm(const struct fwdb_header *db, + const struct fwdb_country *country, int freq, + struct ieee80211_reg_rule *rrule) +{ + unsigned int ptr = be16_to_cpu(country->coll_ptr) << 2; + struct fwdb_collection *coll = (void *)((u8 *)db + ptr); + int i; + + for (i = 0; i < coll->n_rules; i++) { + __be16 *rules_ptr = (void *)((u8 *)coll + ALIGN(coll->len, 2)); + unsigned int rule_ptr = be16_to_cpu(rules_ptr[i]) << 2; + struct fwdb_rule *rule = (void *)((u8 *)db + rule_ptr); + + if (rule->len < offsetofend(struct fwdb_rule, wmm_ptr)) + continue; + + if (freq >= KHZ_TO_MHZ(be32_to_cpu(rule->start)) && + freq <= KHZ_TO_MHZ(be32_to_cpu(rule->end))) { + set_wmm_rule(db, country, rule, rrule); + return 0; + } + } + + return -ENODATA; +} + +int reg_query_regdb_wmm(char *alpha2, int freq, struct ieee80211_reg_rule *rule) +{ + const struct fwdb_header *hdr = regdb; + const struct fwdb_country *country; + + if (!regdb) + return -ENODATA; + + if (IS_ERR(regdb)) + return PTR_ERR(regdb); + + country = &hdr->country[0]; + while (country->coll_ptr) { + if (alpha2_equal(alpha2, country->alpha2)) + return __regdb_query_wmm(regdb, country, freq, rule); + + country++; + } + + return -ENODATA; +} +EXPORT_SYMBOL(reg_query_regdb_wmm); + +static int regdb_query_country(const struct fwdb_header *db, + const struct fwdb_country *country) +{ + unsigned int ptr = be16_to_cpu(country->coll_ptr) << 2; + struct fwdb_collection *coll = (void *)((u8 *)db + ptr); + struct ieee80211_regdomain *regdom; + unsigned int i; + + regdom = kzalloc(struct_size(regdom, reg_rules, coll->n_rules), + GFP_KERNEL); + if (!regdom) + return -ENOMEM; + + regdom->n_reg_rules = coll->n_rules; + regdom->alpha2[0] = country->alpha2[0]; + regdom->alpha2[1] = country->alpha2[1]; + 
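/*
 * Worked example of the conversions performed in the loop below
 * (hypothetical rule): a fwdb_rule with start 5170000, end 5250000 and
 * max_bw 80000 is already in kHz and is copied as-is into the
 * freq_range; a max_eirp of 2000 is kept in mBm here and only becomes
 * 20 dBm later, when the channel code applies MBM_TO_DBM(); a
 * cac_timeout of 60 turns into dfs_cac_ms == 60000 because the database
 * stores seconds while cfg80211 wants milliseconds.
 */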
regdom->dfs_region = coll->dfs_region; + + for (i = 0; i < regdom->n_reg_rules; i++) { + __be16 *rules_ptr = (void *)((u8 *)coll + ALIGN(coll->len, 2)); + unsigned int rule_ptr = be16_to_cpu(rules_ptr[i]) << 2; + struct fwdb_rule *rule = (void *)((u8 *)db + rule_ptr); + struct ieee80211_reg_rule *rrule = ®dom->reg_rules[i]; + + rrule->freq_range.start_freq_khz = be32_to_cpu(rule->start); + rrule->freq_range.end_freq_khz = be32_to_cpu(rule->end); + rrule->freq_range.max_bandwidth_khz = be32_to_cpu(rule->max_bw); + + rrule->power_rule.max_antenna_gain = 0; + rrule->power_rule.max_eirp = be16_to_cpu(rule->max_eirp); + + rrule->flags = 0; + if (rule->flags & FWDB_FLAG_NO_OFDM) + rrule->flags |= NL80211_RRF_NO_OFDM; + if (rule->flags & FWDB_FLAG_NO_OUTDOOR) + rrule->flags |= NL80211_RRF_NO_OUTDOOR; + if (rule->flags & FWDB_FLAG_DFS) + rrule->flags |= NL80211_RRF_DFS; + if (rule->flags & FWDB_FLAG_NO_IR) + rrule->flags |= NL80211_RRF_NO_IR; + if (rule->flags & FWDB_FLAG_AUTO_BW) + rrule->flags |= NL80211_RRF_AUTO_BW; + + rrule->dfs_cac_ms = 0; + + /* handle optional data */ + if (rule->len >= offsetofend(struct fwdb_rule, cac_timeout)) + rrule->dfs_cac_ms = + 1000 * be16_to_cpu(rule->cac_timeout); + if (rule->len >= offsetofend(struct fwdb_rule, wmm_ptr)) + set_wmm_rule(db, country, rule, rrule); + } + + return reg_schedule_apply(regdom); +} + +static int query_regdb(const char *alpha2) +{ + const struct fwdb_header *hdr = regdb; + const struct fwdb_country *country; + + ASSERT_RTNL(); + + if (IS_ERR(regdb)) + return PTR_ERR(regdb); + + country = &hdr->country[0]; + while (country->coll_ptr) { + if (alpha2_equal(alpha2, country->alpha2)) + return regdb_query_country(regdb, country); + country++; + } + + return -ENODATA; +} + +static void regdb_fw_cb(const struct firmware *fw, void *context) +{ + int set_error = 0; + bool restore = true; + void *db; + + if (!fw) { + pr_info("failed to load regulatory.db\n"); + set_error = -ENODATA; + } else if (!valid_regdb(fw->data, fw->size)) { + pr_info("loaded regulatory.db is malformed or signature is missing/invalid\n"); + set_error = -EINVAL; + } + + rtnl_lock(); + if (regdb && !IS_ERR(regdb)) { + /* negative case - a bug + * positive case - can happen due to race in case of multiple cb's in + * queue, due to usage of asynchronous callback + * + * Either case, just restore and free new db. 
+ */ + } else if (set_error) { + regdb = ERR_PTR(set_error); + } else if (fw) { + db = kmemdup(fw->data, fw->size, GFP_KERNEL); + if (db) { + regdb = db; + restore = context && query_regdb(context); + } else { + restore = true; + } + } + + if (restore) + restore_regulatory_settings(true, false); + + rtnl_unlock(); + + kfree(context); + + release_firmware(fw); +} + +MODULE_FIRMWARE("regulatory.db"); + +static int query_regdb_file(const char *alpha2) +{ + int err; + + ASSERT_RTNL(); + + if (regdb) + return query_regdb(alpha2); + + alpha2 = kmemdup(alpha2, 2, GFP_KERNEL); + if (!alpha2) + return -ENOMEM; + + err = request_firmware_nowait(THIS_MODULE, true, "regulatory.db", + ®_pdev->dev, GFP_KERNEL, + (void *)alpha2, regdb_fw_cb); + if (err) + kfree(alpha2); + + return err; +} + +int reg_reload_regdb(void) +{ + const struct firmware *fw; + void *db; + int err; + const struct ieee80211_regdomain *current_regdomain; + struct regulatory_request *request; + + err = request_firmware(&fw, "regulatory.db", ®_pdev->dev); + if (err) + return err; + + if (!valid_regdb(fw->data, fw->size)) { + err = -ENODATA; + goto out; + } + + db = kmemdup(fw->data, fw->size, GFP_KERNEL); + if (!db) { + err = -ENOMEM; + goto out; + } + + rtnl_lock(); + if (!IS_ERR_OR_NULL(regdb)) + kfree(regdb); + regdb = db; + + /* reset regulatory domain */ + current_regdomain = get_cfg80211_regdom(); + + request = kzalloc(sizeof(*request), GFP_KERNEL); + if (!request) { + err = -ENOMEM; + goto out_unlock; + } + + request->wiphy_idx = WIPHY_IDX_INVALID; + request->alpha2[0] = current_regdomain->alpha2[0]; + request->alpha2[1] = current_regdomain->alpha2[1]; + request->initiator = NL80211_REGDOM_SET_BY_CORE; + request->user_reg_hint_type = NL80211_USER_REG_HINT_USER; + + reg_process_hint(request); + +out_unlock: + rtnl_unlock(); + out: + release_firmware(fw); + return err; +} + +static bool reg_query_database(struct regulatory_request *request) +{ + if (query_regdb_file(request->alpha2) == 0) + return true; + + if (call_crda(request->alpha2) == 0) + return true; + + return false; +} + +bool reg_is_valid_request(const char *alpha2) +{ + struct regulatory_request *lr = get_last_request(); + + if (!lr || lr->processed) + return false; + + return alpha2_equal(lr->alpha2, alpha2); +} + +static const struct ieee80211_regdomain *reg_get_regdomain(struct wiphy *wiphy) +{ + struct regulatory_request *lr = get_last_request(); + + /* + * Follow the driver's regulatory domain, if present, unless a country + * IE has been processed or a user wants to help complaince further + */ + if (lr->initiator != NL80211_REGDOM_SET_BY_COUNTRY_IE && + lr->initiator != NL80211_REGDOM_SET_BY_USER && + wiphy->regd) + return get_wiphy_regdom(wiphy); + + return get_cfg80211_regdom(); +} + +static unsigned int +reg_get_max_bandwidth_from_range(const struct ieee80211_regdomain *rd, + const struct ieee80211_reg_rule *rule) +{ + const struct ieee80211_freq_range *freq_range = &rule->freq_range; + const struct ieee80211_freq_range *freq_range_tmp; + const struct ieee80211_reg_rule *tmp; + u32 start_freq, end_freq, idx, no; + + for (idx = 0; idx < rd->n_reg_rules; idx++) + if (rule == &rd->reg_rules[idx]) + break; + + if (idx == rd->n_reg_rules) + return 0; + + /* get start_freq */ + no = idx; + + while (no) { + tmp = &rd->reg_rules[--no]; + freq_range_tmp = &tmp->freq_range; + + if (freq_range_tmp->end_freq_khz < freq_range->start_freq_khz) + break; + + freq_range = freq_range_tmp; + } + + start_freq = freq_range->start_freq_khz; + + /* get end_freq */ + freq_range = 
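/*
 * Worked example using the world domain defined earlier in this file:
 * the 5170-5250 MHz rule and the 5250-5330 MHz rule are contiguous and
 * both carry NL80211_RRF_AUTO_BW, so for either of them this scan
 * widens the combined range to 5170-5330 MHz and reg_get_max_bandwidth()
 * below reports 160 MHz, even though each rule on its own only allows
 * 80 MHz.
 */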
&rule->freq_range; + no = idx; + + while (no < rd->n_reg_rules - 1) { + tmp = &rd->reg_rules[++no]; + freq_range_tmp = &tmp->freq_range; + + if (freq_range_tmp->start_freq_khz > freq_range->end_freq_khz) + break; + + freq_range = freq_range_tmp; + } + + end_freq = freq_range->end_freq_khz; + + return end_freq - start_freq; +} + +unsigned int reg_get_max_bandwidth(const struct ieee80211_regdomain *rd, + const struct ieee80211_reg_rule *rule) +{ + unsigned int bw = reg_get_max_bandwidth_from_range(rd, rule); + + if (rule->flags & NL80211_RRF_NO_320MHZ) + bw = min_t(unsigned int, bw, MHZ_TO_KHZ(160)); + if (rule->flags & NL80211_RRF_NO_160MHZ) + bw = min_t(unsigned int, bw, MHZ_TO_KHZ(80)); + if (rule->flags & NL80211_RRF_NO_80MHZ) + bw = min_t(unsigned int, bw, MHZ_TO_KHZ(40)); + + /* + * HT40+/HT40- limits are handled per-channel. Only limit BW if both + * are not allowed. + */ + if (rule->flags & NL80211_RRF_NO_HT40MINUS && + rule->flags & NL80211_RRF_NO_HT40PLUS) + bw = min_t(unsigned int, bw, MHZ_TO_KHZ(20)); + + return bw; +} + +/* Sanity check on a regulatory rule */ +static bool is_valid_reg_rule(const struct ieee80211_reg_rule *rule) +{ + const struct ieee80211_freq_range *freq_range = &rule->freq_range; + u32 freq_diff; + + if (freq_range->start_freq_khz <= 0 || freq_range->end_freq_khz <= 0) + return false; + + if (freq_range->start_freq_khz > freq_range->end_freq_khz) + return false; + + freq_diff = freq_range->end_freq_khz - freq_range->start_freq_khz; + + if (freq_range->end_freq_khz <= freq_range->start_freq_khz || + freq_range->max_bandwidth_khz > freq_diff) + return false; + + return true; +} + +static bool is_valid_rd(const struct ieee80211_regdomain *rd) +{ + const struct ieee80211_reg_rule *reg_rule = NULL; + unsigned int i; + + if (!rd->n_reg_rules) + return false; + + if (WARN_ON(rd->n_reg_rules > NL80211_MAX_SUPP_REG_RULES)) + return false; + + for (i = 0; i < rd->n_reg_rules; i++) { + reg_rule = &rd->reg_rules[i]; + if (!is_valid_reg_rule(reg_rule)) + return false; + } + + return true; +} + +/** + * freq_in_rule_band - tells us if a frequency is in a frequency band + * @freq_range: frequency rule we want to query + * @freq_khz: frequency we are inquiring about + * + * This lets us know if a specific frequency rule is or is not relevant to + * a specific frequency's band. Bands are device specific and artificial + * definitions (the "2.4 GHz band", the "5 GHz band" and the "60GHz band"), + * however it is safe for now to assume that a frequency rule should not be + * part of a frequency's band if the start freq or end freq are off by more + * than 2 GHz for the 2.4 and 5 GHz bands, and by more than 20 GHz for the + * 60 GHz band. + * This resolution can be lowered and should be considered as we add + * regulatory rule support for other "bands". + **/ +static bool freq_in_rule_band(const struct ieee80211_freq_range *freq_range, + u32 freq_khz) +{ +#define ONE_GHZ_IN_KHZ 1000000 + /* + * From 802.11ad: directional multi-gigabit (DMG): + * Pertaining to operation in a frequency band containing a channel + * with the Channel starting frequency above 45 GHz. + */ + u32 limit = freq_khz > 45 * ONE_GHZ_IN_KHZ ? 
+ 20 * ONE_GHZ_IN_KHZ : 2 * ONE_GHZ_IN_KHZ; + if (abs(freq_khz - freq_range->start_freq_khz) <= limit) + return true; + if (abs(freq_khz - freq_range->end_freq_khz) <= limit) + return true; + return false; +#undef ONE_GHZ_IN_KHZ +} + +/* + * Later on we can perhaps use the more restrictive DFS + * region but we don't have information for that yet so + * for now simply disallow conflicts. + */ +static enum nl80211_dfs_regions +reg_intersect_dfs_region(const enum nl80211_dfs_regions dfs_region1, + const enum nl80211_dfs_regions dfs_region2) +{ + if (dfs_region1 != dfs_region2) + return NL80211_DFS_UNSET; + return dfs_region1; +} + +static void reg_wmm_rules_intersect(const struct ieee80211_wmm_ac *wmm_ac1, + const struct ieee80211_wmm_ac *wmm_ac2, + struct ieee80211_wmm_ac *intersect) +{ + intersect->cw_min = max_t(u16, wmm_ac1->cw_min, wmm_ac2->cw_min); + intersect->cw_max = max_t(u16, wmm_ac1->cw_max, wmm_ac2->cw_max); + intersect->cot = min_t(u16, wmm_ac1->cot, wmm_ac2->cot); + intersect->aifsn = max_t(u8, wmm_ac1->aifsn, wmm_ac2->aifsn); +} + +/* + * Helper for regdom_intersect(), this does the real + * mathematical intersection fun + */ +static int reg_rules_intersect(const struct ieee80211_regdomain *rd1, + const struct ieee80211_regdomain *rd2, + const struct ieee80211_reg_rule *rule1, + const struct ieee80211_reg_rule *rule2, + struct ieee80211_reg_rule *intersected_rule) +{ + const struct ieee80211_freq_range *freq_range1, *freq_range2; + struct ieee80211_freq_range *freq_range; + const struct ieee80211_power_rule *power_rule1, *power_rule2; + struct ieee80211_power_rule *power_rule; + const struct ieee80211_wmm_rule *wmm_rule1, *wmm_rule2; + struct ieee80211_wmm_rule *wmm_rule; + u32 freq_diff, max_bandwidth1, max_bandwidth2; + + freq_range1 = &rule1->freq_range; + freq_range2 = &rule2->freq_range; + freq_range = &intersected_rule->freq_range; + + power_rule1 = &rule1->power_rule; + power_rule2 = &rule2->power_rule; + power_rule = &intersected_rule->power_rule; + + wmm_rule1 = &rule1->wmm_rule; + wmm_rule2 = &rule2->wmm_rule; + wmm_rule = &intersected_rule->wmm_rule; + + freq_range->start_freq_khz = max(freq_range1->start_freq_khz, + freq_range2->start_freq_khz); + freq_range->end_freq_khz = min(freq_range1->end_freq_khz, + freq_range2->end_freq_khz); + + max_bandwidth1 = freq_range1->max_bandwidth_khz; + max_bandwidth2 = freq_range2->max_bandwidth_khz; + + if (rule1->flags & NL80211_RRF_AUTO_BW) + max_bandwidth1 = reg_get_max_bandwidth(rd1, rule1); + if (rule2->flags & NL80211_RRF_AUTO_BW) + max_bandwidth2 = reg_get_max_bandwidth(rd2, rule2); + + freq_range->max_bandwidth_khz = min(max_bandwidth1, max_bandwidth2); + + intersected_rule->flags = rule1->flags | rule2->flags; + + /* + * In case NL80211_RRF_AUTO_BW requested for both rules + * set AUTO_BW in intersected rule also. Next we will + * calculate BW correctly in handle_channel function. + * In other case remove AUTO_BW flag while we calculate + * maximum bandwidth correctly and auto calculation is + * not required. 
+ */ + if ((rule1->flags & NL80211_RRF_AUTO_BW) && + (rule2->flags & NL80211_RRF_AUTO_BW)) + intersected_rule->flags |= NL80211_RRF_AUTO_BW; + else + intersected_rule->flags &= ~NL80211_RRF_AUTO_BW; + + freq_diff = freq_range->end_freq_khz - freq_range->start_freq_khz; + if (freq_range->max_bandwidth_khz > freq_diff) + freq_range->max_bandwidth_khz = freq_diff; + + power_rule->max_eirp = min(power_rule1->max_eirp, + power_rule2->max_eirp); + power_rule->max_antenna_gain = min(power_rule1->max_antenna_gain, + power_rule2->max_antenna_gain); + + intersected_rule->dfs_cac_ms = max(rule1->dfs_cac_ms, + rule2->dfs_cac_ms); + + if (rule1->has_wmm && rule2->has_wmm) { + u8 ac; + + for (ac = 0; ac < IEEE80211_NUM_ACS; ac++) { + reg_wmm_rules_intersect(&wmm_rule1->client[ac], + &wmm_rule2->client[ac], + &wmm_rule->client[ac]); + reg_wmm_rules_intersect(&wmm_rule1->ap[ac], + &wmm_rule2->ap[ac], + &wmm_rule->ap[ac]); + } + + intersected_rule->has_wmm = true; + } else if (rule1->has_wmm) { + *wmm_rule = *wmm_rule1; + intersected_rule->has_wmm = true; + } else if (rule2->has_wmm) { + *wmm_rule = *wmm_rule2; + intersected_rule->has_wmm = true; + } else { + intersected_rule->has_wmm = false; + } + + if (!is_valid_reg_rule(intersected_rule)) + return -EINVAL; + + return 0; +} + +/* check whether old rule contains new rule */ +static bool rule_contains(struct ieee80211_reg_rule *r1, + struct ieee80211_reg_rule *r2) +{ + /* for simplicity, currently consider only same flags */ + if (r1->flags != r2->flags) + return false; + + /* verify r1 is more restrictive */ + if ((r1->power_rule.max_antenna_gain > + r2->power_rule.max_antenna_gain) || + r1->power_rule.max_eirp > r2->power_rule.max_eirp) + return false; + + /* make sure r2's range is contained within r1 */ + if (r1->freq_range.start_freq_khz > r2->freq_range.start_freq_khz || + r1->freq_range.end_freq_khz < r2->freq_range.end_freq_khz) + return false; + + /* and finally verify that r1.max_bw >= r2.max_bw */ + if (r1->freq_range.max_bandwidth_khz < + r2->freq_range.max_bandwidth_khz) + return false; + + return true; +} + +/* add or extend current rules. do nothing if rule is already contained */ +static void add_rule(struct ieee80211_reg_rule *rule, + struct ieee80211_reg_rule *reg_rules, u32 *n_rules) +{ + struct ieee80211_reg_rule *tmp_rule; + int i; + + for (i = 0; i < *n_rules; i++) { + tmp_rule = ®_rules[i]; + /* rule is already contained - do nothing */ + if (rule_contains(tmp_rule, rule)) + return; + + /* extend rule if possible */ + if (rule_contains(rule, tmp_rule)) { + memcpy(tmp_rule, rule, sizeof(*rule)); + return; + } + } + + memcpy(®_rules[*n_rules], rule, sizeof(*rule)); + (*n_rules)++; +} + +/** + * regdom_intersect - do the intersection between two regulatory domains + * @rd1: first regulatory domain + * @rd2: second regulatory domain + * + * Use this function to get the intersection between two regulatory domains. + * Once completed we will mark the alpha2 for the rd as intersected, "98", + * as no one single alpha2 can represent this regulatory domain. + * + * Returns a pointer to the regulatory domain structure which will hold the + * resulting intersection of rules between rd1 and rd2. We will + * kzalloc() this structure for you. 
+ */ +static struct ieee80211_regdomain * +regdom_intersect(const struct ieee80211_regdomain *rd1, + const struct ieee80211_regdomain *rd2) +{ + int r; + unsigned int x, y; + unsigned int num_rules = 0; + const struct ieee80211_reg_rule *rule1, *rule2; + struct ieee80211_reg_rule intersected_rule; + struct ieee80211_regdomain *rd; + + if (!rd1 || !rd2) + return NULL; + + /* + * First we get a count of the rules we'll need, then we actually + * build them. This is to so we can malloc() and free() a + * regdomain once. The reason we use reg_rules_intersect() here + * is it will return -EINVAL if the rule computed makes no sense. + * All rules that do check out OK are valid. + */ + + for (x = 0; x < rd1->n_reg_rules; x++) { + rule1 = &rd1->reg_rules[x]; + for (y = 0; y < rd2->n_reg_rules; y++) { + rule2 = &rd2->reg_rules[y]; + if (!reg_rules_intersect(rd1, rd2, rule1, rule2, + &intersected_rule)) + num_rules++; + } + } + + if (!num_rules) + return NULL; + + rd = kzalloc(struct_size(rd, reg_rules, num_rules), GFP_KERNEL); + if (!rd) + return NULL; + + for (x = 0; x < rd1->n_reg_rules; x++) { + rule1 = &rd1->reg_rules[x]; + for (y = 0; y < rd2->n_reg_rules; y++) { + rule2 = &rd2->reg_rules[y]; + r = reg_rules_intersect(rd1, rd2, rule1, rule2, + &intersected_rule); + /* + * No need to memset here the intersected rule here as + * we're not using the stack anymore + */ + if (r) + continue; + + add_rule(&intersected_rule, rd->reg_rules, + &rd->n_reg_rules); + } + } + + rd->alpha2[0] = '9'; + rd->alpha2[1] = '8'; + rd->dfs_region = reg_intersect_dfs_region(rd1->dfs_region, + rd2->dfs_region); + + return rd; +} + +/* + * XXX: add support for the rest of enum nl80211_reg_rule_flags, we may + * want to just have the channel structure use these + */ +static u32 map_regdom_flags(u32 rd_flags) +{ + u32 channel_flags = 0; + if (rd_flags & NL80211_RRF_NO_IR_ALL) + channel_flags |= IEEE80211_CHAN_NO_IR; + if (rd_flags & NL80211_RRF_DFS) + channel_flags |= IEEE80211_CHAN_RADAR; + if (rd_flags & NL80211_RRF_NO_OFDM) + channel_flags |= IEEE80211_CHAN_NO_OFDM; + if (rd_flags & NL80211_RRF_NO_OUTDOOR) + channel_flags |= IEEE80211_CHAN_INDOOR_ONLY; + if (rd_flags & NL80211_RRF_IR_CONCURRENT) + channel_flags |= IEEE80211_CHAN_IR_CONCURRENT; + if (rd_flags & NL80211_RRF_NO_HT40MINUS) + channel_flags |= IEEE80211_CHAN_NO_HT40MINUS; + if (rd_flags & NL80211_RRF_NO_HT40PLUS) + channel_flags |= IEEE80211_CHAN_NO_HT40PLUS; + if (rd_flags & NL80211_RRF_NO_80MHZ) + channel_flags |= IEEE80211_CHAN_NO_80MHZ; + if (rd_flags & NL80211_RRF_NO_160MHZ) + channel_flags |= IEEE80211_CHAN_NO_160MHZ; + if (rd_flags & NL80211_RRF_NO_HE) + channel_flags |= IEEE80211_CHAN_NO_HE; + if (rd_flags & NL80211_RRF_NO_320MHZ) + channel_flags |= IEEE80211_CHAN_NO_320MHZ; + return channel_flags; +} + +static const struct ieee80211_reg_rule * +freq_reg_info_regd(u32 center_freq, + const struct ieee80211_regdomain *regd, u32 bw) +{ + int i; + bool band_rule_found = false; + bool bw_fits = false; + + if (!regd) + return ERR_PTR(-EINVAL); + + for (i = 0; i < regd->n_reg_rules; i++) { + const struct ieee80211_reg_rule *rr; + const struct ieee80211_freq_range *fr = NULL; + + rr = ®d->reg_rules[i]; + fr = &rr->freq_range; + + /* + * We only need to know if one frequency rule was + * in center_freq's band, that's enough, so let's + * not overwrite it once found + */ + if (!band_rule_found) + band_rule_found = freq_in_rule_band(fr, center_freq); + + bw_fits = cfg80211_does_bw_fit_range(fr, center_freq, bw); + + if (band_rule_found && bw_fits) + return rr; 
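/*
 * Illustration with made-up numbers: for channel 36 the callers pass a
 * center frequency of 5180000 kHz; against a rule covering
 * 5170000-5250000 kHz, freq_in_rule_band() accepts the rule (the edges
 * are far less than 2 GHz away) and cfg80211_does_bw_fit_range()
 * succeeds for a 20 MHz width, so the rule is returned on the first,
 * widest probe of __freq_reg_info() just below. A driver-side lookup
 * would be along the lines of (hypothetical snippet):
 *
 *   const struct ieee80211_reg_rule *rule;
 *
 *   rule = freq_reg_info(wiphy, MHZ_TO_KHZ(5180));
 *   if (!IS_ERR(rule))
 *           max_eirp_mbm = rule->power_rule.max_eirp;
 */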
+ } + + if (!band_rule_found) + return ERR_PTR(-ERANGE); + + return ERR_PTR(-EINVAL); +} + +static const struct ieee80211_reg_rule * +__freq_reg_info(struct wiphy *wiphy, u32 center_freq, u32 min_bw) +{ + const struct ieee80211_regdomain *regd = reg_get_regdomain(wiphy); + static const u32 bws[] = {0, 1, 2, 4, 5, 8, 10, 16, 20}; + const struct ieee80211_reg_rule *reg_rule = ERR_PTR(-ERANGE); + int i = ARRAY_SIZE(bws) - 1; + u32 bw; + + for (bw = MHZ_TO_KHZ(bws[i]); bw >= min_bw; bw = MHZ_TO_KHZ(bws[i--])) { + reg_rule = freq_reg_info_regd(center_freq, regd, bw); + if (!IS_ERR(reg_rule)) + return reg_rule; + } + + return reg_rule; +} + +const struct ieee80211_reg_rule *freq_reg_info(struct wiphy *wiphy, + u32 center_freq) +{ + u32 min_bw = center_freq < MHZ_TO_KHZ(1000) ? 1 : 20; + + return __freq_reg_info(wiphy, center_freq, MHZ_TO_KHZ(min_bw)); +} +EXPORT_SYMBOL(freq_reg_info); + +const char *reg_initiator_name(enum nl80211_reg_initiator initiator) +{ + switch (initiator) { + case NL80211_REGDOM_SET_BY_CORE: + return "core"; + case NL80211_REGDOM_SET_BY_USER: + return "user"; + case NL80211_REGDOM_SET_BY_DRIVER: + return "driver"; + case NL80211_REGDOM_SET_BY_COUNTRY_IE: + return "country element"; + default: + WARN_ON(1); + return "bug"; + } +} +EXPORT_SYMBOL(reg_initiator_name); + +static uint32_t reg_rule_to_chan_bw_flags(const struct ieee80211_regdomain *regd, + const struct ieee80211_reg_rule *reg_rule, + const struct ieee80211_channel *chan) +{ + const struct ieee80211_freq_range *freq_range = NULL; + u32 max_bandwidth_khz, center_freq_khz, bw_flags = 0; + bool is_s1g = chan->band == NL80211_BAND_S1GHZ; + + freq_range = ®_rule->freq_range; + + max_bandwidth_khz = freq_range->max_bandwidth_khz; + center_freq_khz = ieee80211_channel_to_khz(chan); + /* Check if auto calculation requested */ + if (reg_rule->flags & NL80211_RRF_AUTO_BW) + max_bandwidth_khz = reg_get_max_bandwidth(regd, reg_rule); + + /* If we get a reg_rule we can assume that at least 5Mhz fit */ + if (!cfg80211_does_bw_fit_range(freq_range, + center_freq_khz, + MHZ_TO_KHZ(10))) + bw_flags |= IEEE80211_CHAN_NO_10MHZ; + if (!cfg80211_does_bw_fit_range(freq_range, + center_freq_khz, + MHZ_TO_KHZ(20))) + bw_flags |= IEEE80211_CHAN_NO_20MHZ; + + if (is_s1g) { + /* S1G is strict about non overlapping channels. We can + * calculate which bandwidth is allowed per channel by finding + * the largest bandwidth which cleanly divides the freq_range. + */ + int edge_offset; + int ch_bw = max_bandwidth_khz; + + while (ch_bw) { + edge_offset = (center_freq_khz - ch_bw / 2) - + freq_range->start_freq_khz; + if (edge_offset % ch_bw == 0) { + switch (KHZ_TO_MHZ(ch_bw)) { + case 1: + bw_flags |= IEEE80211_CHAN_1MHZ; + break; + case 2: + bw_flags |= IEEE80211_CHAN_2MHZ; + break; + case 4: + bw_flags |= IEEE80211_CHAN_4MHZ; + break; + case 8: + bw_flags |= IEEE80211_CHAN_8MHZ; + break; + case 16: + bw_flags |= IEEE80211_CHAN_16MHZ; + break; + default: + /* If we got here, no bandwidths fit on + * this frequency, ie. band edge. 
+ */ + bw_flags |= IEEE80211_CHAN_DISABLED; + break; + } + break; + } + ch_bw /= 2; + } + } else { + if (max_bandwidth_khz < MHZ_TO_KHZ(10)) + bw_flags |= IEEE80211_CHAN_NO_10MHZ; + if (max_bandwidth_khz < MHZ_TO_KHZ(20)) + bw_flags |= IEEE80211_CHAN_NO_20MHZ; + if (max_bandwidth_khz < MHZ_TO_KHZ(40)) + bw_flags |= IEEE80211_CHAN_NO_HT40; + if (max_bandwidth_khz < MHZ_TO_KHZ(80)) + bw_flags |= IEEE80211_CHAN_NO_80MHZ; + if (max_bandwidth_khz < MHZ_TO_KHZ(160)) + bw_flags |= IEEE80211_CHAN_NO_160MHZ; + if (max_bandwidth_khz < MHZ_TO_KHZ(320)) + bw_flags |= IEEE80211_CHAN_NO_320MHZ; + } + return bw_flags; +} + +static void handle_channel_single_rule(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator, + struct ieee80211_channel *chan, + u32 flags, + struct regulatory_request *lr, + struct wiphy *request_wiphy, + const struct ieee80211_reg_rule *reg_rule) +{ + u32 bw_flags = 0; + const struct ieee80211_power_rule *power_rule = NULL; + const struct ieee80211_regdomain *regd; + + regd = reg_get_regdomain(wiphy); + + power_rule = ®_rule->power_rule; + bw_flags = reg_rule_to_chan_bw_flags(regd, reg_rule, chan); + + if (lr->initiator == NL80211_REGDOM_SET_BY_DRIVER && + request_wiphy && request_wiphy == wiphy && + request_wiphy->regulatory_flags & REGULATORY_STRICT_REG) { + /* + * This guarantees the driver's requested regulatory domain + * will always be used as a base for further regulatory + * settings + */ + chan->flags = chan->orig_flags = + map_regdom_flags(reg_rule->flags) | bw_flags; + chan->max_antenna_gain = chan->orig_mag = + (int) MBI_TO_DBI(power_rule->max_antenna_gain); + chan->max_reg_power = chan->max_power = chan->orig_mpwr = + (int) MBM_TO_DBM(power_rule->max_eirp); + + if (chan->flags & IEEE80211_CHAN_RADAR) { + chan->dfs_cac_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + if (reg_rule->dfs_cac_ms) + chan->dfs_cac_ms = reg_rule->dfs_cac_ms; + } + + return; + } + + chan->dfs_state = NL80211_DFS_USABLE; + chan->dfs_state_entered = jiffies; + + chan->beacon_found = false; + chan->flags = flags | bw_flags | map_regdom_flags(reg_rule->flags); + chan->max_antenna_gain = + min_t(int, chan->orig_mag, + MBI_TO_DBI(power_rule->max_antenna_gain)); + chan->max_reg_power = (int) MBM_TO_DBM(power_rule->max_eirp); + + if (chan->flags & IEEE80211_CHAN_RADAR) { + if (reg_rule->dfs_cac_ms) + chan->dfs_cac_ms = reg_rule->dfs_cac_ms; + else + chan->dfs_cac_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + } + + if (chan->orig_mpwr) { + /* + * Devices that use REGULATORY_COUNTRY_IE_FOLLOW_POWER + * will always follow the passed country IE power settings. 
+ */ + if (initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE && + wiphy->regulatory_flags & REGULATORY_COUNTRY_IE_FOLLOW_POWER) + chan->max_power = chan->max_reg_power; + else + chan->max_power = min(chan->orig_mpwr, + chan->max_reg_power); + } else + chan->max_power = chan->max_reg_power; +} + +static void handle_channel_adjacent_rules(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator, + struct ieee80211_channel *chan, + u32 flags, + struct regulatory_request *lr, + struct wiphy *request_wiphy, + const struct ieee80211_reg_rule *rrule1, + const struct ieee80211_reg_rule *rrule2, + struct ieee80211_freq_range *comb_range) +{ + u32 bw_flags1 = 0; + u32 bw_flags2 = 0; + const struct ieee80211_power_rule *power_rule1 = NULL; + const struct ieee80211_power_rule *power_rule2 = NULL; + const struct ieee80211_regdomain *regd; + + regd = reg_get_regdomain(wiphy); + + power_rule1 = &rrule1->power_rule; + power_rule2 = &rrule2->power_rule; + bw_flags1 = reg_rule_to_chan_bw_flags(regd, rrule1, chan); + bw_flags2 = reg_rule_to_chan_bw_flags(regd, rrule2, chan); + + if (lr->initiator == NL80211_REGDOM_SET_BY_DRIVER && + request_wiphy && request_wiphy == wiphy && + request_wiphy->regulatory_flags & REGULATORY_STRICT_REG) { + /* This guarantees the driver's requested regulatory domain + * will always be used as a base for further regulatory + * settings + */ + chan->flags = + map_regdom_flags(rrule1->flags) | + map_regdom_flags(rrule2->flags) | + bw_flags1 | + bw_flags2; + chan->orig_flags = chan->flags; + chan->max_antenna_gain = + min_t(int, MBI_TO_DBI(power_rule1->max_antenna_gain), + MBI_TO_DBI(power_rule2->max_antenna_gain)); + chan->orig_mag = chan->max_antenna_gain; + chan->max_reg_power = + min_t(int, MBM_TO_DBM(power_rule1->max_eirp), + MBM_TO_DBM(power_rule2->max_eirp)); + chan->max_power = chan->max_reg_power; + chan->orig_mpwr = chan->max_reg_power; + + if (chan->flags & IEEE80211_CHAN_RADAR) { + chan->dfs_cac_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + if (rrule1->dfs_cac_ms || rrule2->dfs_cac_ms) + chan->dfs_cac_ms = max_t(unsigned int, + rrule1->dfs_cac_ms, + rrule2->dfs_cac_ms); + } + + return; + } + + chan->dfs_state = NL80211_DFS_USABLE; + chan->dfs_state_entered = jiffies; + + chan->beacon_found = false; + chan->flags = flags | bw_flags1 | bw_flags2 | + map_regdom_flags(rrule1->flags) | + map_regdom_flags(rrule2->flags); + + /* reg_rule_to_chan_bw_flags may forbids 10 and forbids 20 MHz + * (otherwise no adj. rule case), recheck therefore + */ + if (cfg80211_does_bw_fit_range(comb_range, + ieee80211_channel_to_khz(chan), + MHZ_TO_KHZ(10))) + chan->flags &= ~IEEE80211_CHAN_NO_10MHZ; + if (cfg80211_does_bw_fit_range(comb_range, + ieee80211_channel_to_khz(chan), + MHZ_TO_KHZ(20))) + chan->flags &= ~IEEE80211_CHAN_NO_20MHZ; + + chan->max_antenna_gain = + min_t(int, chan->orig_mag, + min_t(int, + MBI_TO_DBI(power_rule1->max_antenna_gain), + MBI_TO_DBI(power_rule2->max_antenna_gain))); + chan->max_reg_power = min_t(int, + MBM_TO_DBM(power_rule1->max_eirp), + MBM_TO_DBM(power_rule2->max_eirp)); + + if (chan->flags & IEEE80211_CHAN_RADAR) { + if (rrule1->dfs_cac_ms || rrule2->dfs_cac_ms) + chan->dfs_cac_ms = max_t(unsigned int, + rrule1->dfs_cac_ms, + rrule2->dfs_cac_ms); + else + chan->dfs_cac_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + } + + if (chan->orig_mpwr) { + /* Devices that use REGULATORY_COUNTRY_IE_FOLLOW_POWER + * will always follow the passed country IE power settings. 
+ */ + if (initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE && + wiphy->regulatory_flags & REGULATORY_COUNTRY_IE_FOLLOW_POWER) + chan->max_power = chan->max_reg_power; + else + chan->max_power = min(chan->orig_mpwr, + chan->max_reg_power); + } else { + chan->max_power = chan->max_reg_power; + } +} + +/* Note that right now we assume the desired channel bandwidth + * is always 20 MHz for each individual channel (HT40 uses 20 MHz + * per channel, the primary and the extension channel). + */ +static void handle_channel(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator, + struct ieee80211_channel *chan) +{ + const u32 orig_chan_freq = ieee80211_channel_to_khz(chan); + struct regulatory_request *lr = get_last_request(); + struct wiphy *request_wiphy = wiphy_idx_to_wiphy(lr->wiphy_idx); + const struct ieee80211_reg_rule *rrule = NULL; + const struct ieee80211_reg_rule *rrule1 = NULL; + const struct ieee80211_reg_rule *rrule2 = NULL; + + u32 flags = chan->orig_flags; + + rrule = freq_reg_info(wiphy, orig_chan_freq); + if (IS_ERR(rrule)) { + /* check for adjacent match, therefore get rules for + * chan - 20 MHz and chan + 20 MHz and test + * if reg rules are adjacent + */ + rrule1 = freq_reg_info(wiphy, + orig_chan_freq - MHZ_TO_KHZ(20)); + rrule2 = freq_reg_info(wiphy, + orig_chan_freq + MHZ_TO_KHZ(20)); + if (!IS_ERR(rrule1) && !IS_ERR(rrule2)) { + struct ieee80211_freq_range comb_range; + + if (rrule1->freq_range.end_freq_khz != + rrule2->freq_range.start_freq_khz) + goto disable_chan; + + comb_range.start_freq_khz = + rrule1->freq_range.start_freq_khz; + comb_range.end_freq_khz = + rrule2->freq_range.end_freq_khz; + comb_range.max_bandwidth_khz = + min_t(u32, + rrule1->freq_range.max_bandwidth_khz, + rrule2->freq_range.max_bandwidth_khz); + + if (!cfg80211_does_bw_fit_range(&comb_range, + orig_chan_freq, + MHZ_TO_KHZ(20))) + goto disable_chan; + + handle_channel_adjacent_rules(wiphy, initiator, chan, + flags, lr, request_wiphy, + rrule1, rrule2, + &comb_range); + return; + } + +disable_chan: + /* We will disable all channels that do not match our + * received regulatory rule unless the hint is coming + * from a Country IE and the Country IE had no information + * about a band. The IEEE 802.11 spec allows for an AP + * to send only a subset of the regulatory rules allowed, + * so an AP in the US that only supports 2.4 GHz may only send + * a country IE with information for the 2.4 GHz band + * while 5 GHz is still supported. 
+ */ + if (initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE && + PTR_ERR(rrule) == -ERANGE) + return; + + if (lr->initiator == NL80211_REGDOM_SET_BY_DRIVER && + request_wiphy && request_wiphy == wiphy && + request_wiphy->regulatory_flags & REGULATORY_STRICT_REG) { + pr_debug("Disabling freq %d.%03d MHz for good\n", + chan->center_freq, chan->freq_offset); + chan->orig_flags |= IEEE80211_CHAN_DISABLED; + chan->flags = chan->orig_flags; + } else { + pr_debug("Disabling freq %d.%03d MHz\n", + chan->center_freq, chan->freq_offset); + chan->flags |= IEEE80211_CHAN_DISABLED; + } + return; + } + + handle_channel_single_rule(wiphy, initiator, chan, flags, lr, + request_wiphy, rrule); +} + +static void handle_band(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator, + struct ieee80211_supported_band *sband) +{ + unsigned int i; + + if (!sband) + return; + + for (i = 0; i < sband->n_channels; i++) + handle_channel(wiphy, initiator, &sband->channels[i]); +} + +static bool reg_request_cell_base(struct regulatory_request *request) +{ + if (request->initiator != NL80211_REGDOM_SET_BY_USER) + return false; + return request->user_reg_hint_type == NL80211_USER_REG_HINT_CELL_BASE; +} + +bool reg_last_request_cell_base(void) +{ + return reg_request_cell_base(get_last_request()); +} + +#ifdef CONFIG_CFG80211_REG_CELLULAR_HINTS +/* Core specific check */ +static enum reg_request_treatment +reg_ignore_cell_hint(struct regulatory_request *pending_request) +{ + struct regulatory_request *lr = get_last_request(); + + if (!reg_num_devs_support_basehint) + return REG_REQ_IGNORE; + + if (reg_request_cell_base(lr) && + !regdom_changes(pending_request->alpha2)) + return REG_REQ_ALREADY_SET; + + return REG_REQ_OK; +} + +/* Device specific check */ +static bool reg_dev_ignore_cell_hint(struct wiphy *wiphy) +{ + return !(wiphy->features & NL80211_FEATURE_CELL_BASE_REG_HINTS); +} +#else +static enum reg_request_treatment +reg_ignore_cell_hint(struct regulatory_request *pending_request) +{ + return REG_REQ_IGNORE; +} + +static bool reg_dev_ignore_cell_hint(struct wiphy *wiphy) +{ + return true; +} +#endif + +static bool wiphy_strict_alpha2_regd(struct wiphy *wiphy) +{ + if (wiphy->regulatory_flags & REGULATORY_STRICT_REG && + !(wiphy->regulatory_flags & REGULATORY_CUSTOM_REG)) + return true; + return false; +} + +static bool ignore_reg_update(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator) +{ + struct regulatory_request *lr = get_last_request(); + + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) + return true; + + if (!lr) { + pr_debug("Ignoring regulatory request set by %s since last_request is not set\n", + reg_initiator_name(initiator)); + return true; + } + + if (initiator == NL80211_REGDOM_SET_BY_CORE && + wiphy->regulatory_flags & REGULATORY_CUSTOM_REG) { + pr_debug("Ignoring regulatory request set by %s since the driver uses its own custom regulatory domain\n", + reg_initiator_name(initiator)); + return true; + } + + /* + * wiphy->regd will be set once the device has its own + * desired regulatory domain set + */ + if (wiphy_strict_alpha2_regd(wiphy) && !wiphy->regd && + initiator != NL80211_REGDOM_SET_BY_COUNTRY_IE && + !is_world_regdom(lr->alpha2)) { + pr_debug("Ignoring regulatory request set by %s since the driver requires its own regulatory domain to be set first\n", + reg_initiator_name(initiator)); + return true; + } + + if (reg_request_cell_base(lr)) + return reg_dev_ignore_cell_hint(wiphy); + + return false; +} + +static bool reg_is_world_roaming(struct wiphy *wiphy) +{ + 
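/*
 * "World roaming" here means the device still has no authoritative
 * country: either cfg80211 (or the wiphy itself) is on the world
 * domain, or the wiphy uses a custom domain and the last hint did not
 * come from a country IE. Only in that state may the beacon hints below
 * relax a channel. For example, 2467 MHz (channel 12) carries
 * IEEE80211_CHAN_NO_IR under the static world domain, and
 * handle_reg_beacon() clears that flag once a beacon is seen on the
 * channel, unless the wiphy set REGULATORY_DISABLE_BEACON_HINTS.
 */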
const struct ieee80211_regdomain *cr = get_cfg80211_regdom(); + const struct ieee80211_regdomain *wr = get_wiphy_regdom(wiphy); + struct regulatory_request *lr = get_last_request(); + + if (is_world_regdom(cr->alpha2) || (wr && is_world_regdom(wr->alpha2))) + return true; + + if (lr && lr->initiator != NL80211_REGDOM_SET_BY_COUNTRY_IE && + wiphy->regulatory_flags & REGULATORY_CUSTOM_REG) + return true; + + return false; +} + +static void handle_reg_beacon(struct wiphy *wiphy, unsigned int chan_idx, + struct reg_beacon *reg_beacon) +{ + struct ieee80211_supported_band *sband; + struct ieee80211_channel *chan; + bool channel_changed = false; + struct ieee80211_channel chan_before; + + sband = wiphy->bands[reg_beacon->chan.band]; + chan = &sband->channels[chan_idx]; + + if (likely(!ieee80211_channel_equal(chan, ®_beacon->chan))) + return; + + if (chan->beacon_found) + return; + + chan->beacon_found = true; + + if (!reg_is_world_roaming(wiphy)) + return; + + if (wiphy->regulatory_flags & REGULATORY_DISABLE_BEACON_HINTS) + return; + + chan_before = *chan; + + if (chan->flags & IEEE80211_CHAN_NO_IR) { + chan->flags &= ~IEEE80211_CHAN_NO_IR; + channel_changed = true; + } + + if (channel_changed) + nl80211_send_beacon_hint_event(wiphy, &chan_before, chan); +} + +/* + * Called when a scan on a wiphy finds a beacon on + * new channel + */ +static void wiphy_update_new_beacon(struct wiphy *wiphy, + struct reg_beacon *reg_beacon) +{ + unsigned int i; + struct ieee80211_supported_band *sband; + + if (!wiphy->bands[reg_beacon->chan.band]) + return; + + sband = wiphy->bands[reg_beacon->chan.band]; + + for (i = 0; i < sband->n_channels; i++) + handle_reg_beacon(wiphy, i, reg_beacon); +} + +/* + * Called upon reg changes or a new wiphy is added + */ +static void wiphy_update_beacon_reg(struct wiphy *wiphy) +{ + unsigned int i; + struct ieee80211_supported_band *sband; + struct reg_beacon *reg_beacon; + + list_for_each_entry(reg_beacon, ®_beacon_list, list) { + if (!wiphy->bands[reg_beacon->chan.band]) + continue; + sband = wiphy->bands[reg_beacon->chan.band]; + for (i = 0; i < sband->n_channels; i++) + handle_reg_beacon(wiphy, i, reg_beacon); + } +} + +/* Reap the advantages of previously found beacons */ +static void reg_process_beacons(struct wiphy *wiphy) +{ + /* + * Means we are just firing up cfg80211, so no beacons would + * have been processed yet. 
+ */ + if (!last_request) + return; + wiphy_update_beacon_reg(wiphy); +} + +static bool is_ht40_allowed(struct ieee80211_channel *chan) +{ + if (!chan) + return false; + if (chan->flags & IEEE80211_CHAN_DISABLED) + return false; + /* This would happen when regulatory rules disallow HT40 completely */ + if ((chan->flags & IEEE80211_CHAN_NO_HT40) == IEEE80211_CHAN_NO_HT40) + return false; + return true; +} + +static void reg_process_ht_flags_channel(struct wiphy *wiphy, + struct ieee80211_channel *channel) +{ + struct ieee80211_supported_band *sband = wiphy->bands[channel->band]; + struct ieee80211_channel *channel_before = NULL, *channel_after = NULL; + const struct ieee80211_regdomain *regd; + unsigned int i; + u32 flags; + + if (!is_ht40_allowed(channel)) { + channel->flags |= IEEE80211_CHAN_NO_HT40; + return; + } + + /* + * We need to ensure the extension channels exist to + * be able to use HT40- or HT40+, this finds them (or not) + */ + for (i = 0; i < sband->n_channels; i++) { + struct ieee80211_channel *c = &sband->channels[i]; + + if (c->center_freq == (channel->center_freq - 20)) + channel_before = c; + if (c->center_freq == (channel->center_freq + 20)) + channel_after = c; + } + + flags = 0; + regd = get_wiphy_regdom(wiphy); + if (regd) { + const struct ieee80211_reg_rule *reg_rule = + freq_reg_info_regd(MHZ_TO_KHZ(channel->center_freq), + regd, MHZ_TO_KHZ(20)); + + if (!IS_ERR(reg_rule)) + flags = reg_rule->flags; + } + + /* + * Please note that this assumes target bandwidth is 20 MHz, + * if that ever changes we also need to change the below logic + * to include that as well. + */ + if (!is_ht40_allowed(channel_before) || + flags & NL80211_RRF_NO_HT40MINUS) + channel->flags |= IEEE80211_CHAN_NO_HT40MINUS; + else + channel->flags &= ~IEEE80211_CHAN_NO_HT40MINUS; + + if (!is_ht40_allowed(channel_after) || + flags & NL80211_RRF_NO_HT40PLUS) + channel->flags |= IEEE80211_CHAN_NO_HT40PLUS; + else + channel->flags &= ~IEEE80211_CHAN_NO_HT40PLUS; +} + +static void reg_process_ht_flags_band(struct wiphy *wiphy, + struct ieee80211_supported_band *sband) +{ + unsigned int i; + + if (!sband) + return; + + for (i = 0; i < sband->n_channels; i++) + reg_process_ht_flags_channel(wiphy, &sband->channels[i]); +} + +static void reg_process_ht_flags(struct wiphy *wiphy) +{ + enum nl80211_band band; + + if (!wiphy) + return; + + for (band = 0; band < NUM_NL80211_BANDS; band++) + reg_process_ht_flags_band(wiphy, wiphy->bands[band]); +} + +static void reg_call_notifier(struct wiphy *wiphy, + struct regulatory_request *request) +{ + if (wiphy->reg_notifier) + wiphy->reg_notifier(wiphy, request); +} + +static bool reg_wdev_chan_valid(struct wiphy *wiphy, struct wireless_dev *wdev) +{ + struct cfg80211_chan_def chandef = {}; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + enum nl80211_iftype iftype; + bool ret; + int link; + + wdev_lock(wdev); + iftype = wdev->iftype; + + /* make sure the interface is active */ + if (!wdev->netdev || !netif_running(wdev->netdev)) + goto wdev_inactive_unlock; + + for (link = 0; link < ARRAY_SIZE(wdev->links); link++) { + struct ieee80211_channel *chan; + + if (!wdev->valid_links && link > 0) + break; + if (wdev->valid_links && !(wdev->valid_links & BIT(link))) + continue; + switch (iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + if (!wdev->links[link].ap.beacon_interval) + continue; + chandef = wdev->links[link].ap.chandef; + break; + case NL80211_IFTYPE_MESH_POINT: + if (!wdev->u.mesh.beacon_interval) + continue; + chandef = 
wdev->u.mesh.chandef; + break; + case NL80211_IFTYPE_ADHOC: + if (!wdev->u.ibss.ssid_len) + continue; + chandef = wdev->u.ibss.chandef; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + /* Maybe we could consider disabling that link only? */ + if (!wdev->links[link].client.current_bss) + continue; + + chan = wdev->links[link].client.current_bss->pub.channel; + if (!chan) + continue; + + if (!rdev->ops->get_channel || + rdev_get_channel(rdev, wdev, link, &chandef)) + cfg80211_chandef_create(&chandef, chan, + NL80211_CHAN_NO_HT); + break; + case NL80211_IFTYPE_MONITOR: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_P2P_DEVICE: + /* no enforcement required */ + break; + case NL80211_IFTYPE_OCB: + if (!wdev->u.ocb.chandef.chan) + continue; + chandef = wdev->u.ocb.chandef; + break; + case NL80211_IFTYPE_NAN: + /* we have no info, but NAN is also pretty universal */ + continue; + default: + /* others not implemented for now */ + WARN_ON_ONCE(1); + break; + } + + wdev_unlock(wdev); + + switch (iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_ADHOC: + case NL80211_IFTYPE_MESH_POINT: + ret = cfg80211_reg_can_beacon_relax(wiphy, &chandef, + iftype); + if (!ret) + return ret; + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + ret = cfg80211_chandef_usable(wiphy, &chandef, + IEEE80211_CHAN_DISABLED); + if (!ret) + return ret; + break; + default: + break; + } + + wdev_lock(wdev); + } + + wdev_unlock(wdev); + + return true; + +wdev_inactive_unlock: + wdev_unlock(wdev); + return true; +} + +static void reg_leave_invalid_chans(struct wiphy *wiphy) +{ + struct wireless_dev *wdev; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + wiphy_lock(wiphy); + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) + if (!reg_wdev_chan_valid(wiphy, wdev)) + cfg80211_leave(rdev, wdev); + wiphy_unlock(wiphy); +} + +static void reg_check_chans_work(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + + pr_debug("Verifying active interfaces after reg change\n"); + rtnl_lock(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) + reg_leave_invalid_chans(&rdev->wiphy); + + rtnl_unlock(); +} + +static void reg_check_channels(void) +{ + /* + * Give usermode a chance to do something nicer (move to another + * channel, orderly disconnection), before forcing a disconnection. + */ + mod_delayed_work(system_power_efficient_wq, + ®_check_chans, + msecs_to_jiffies(REG_ENFORCE_GRACE_MS)); +} + +static void wiphy_update_regulatory(struct wiphy *wiphy, + enum nl80211_reg_initiator initiator) +{ + enum nl80211_band band; + struct regulatory_request *lr = get_last_request(); + + if (ignore_reg_update(wiphy, initiator)) { + /* + * Regulatory updates set by CORE are ignored for custom + * regulatory cards. Let us notify the changes to the driver, + * as some drivers used this to restore its orig_* reg domain. 
+ */ + if (initiator == NL80211_REGDOM_SET_BY_CORE && + wiphy->regulatory_flags & REGULATORY_CUSTOM_REG && + !(wiphy->regulatory_flags & + REGULATORY_WIPHY_SELF_MANAGED)) + reg_call_notifier(wiphy, lr); + return; + } + + lr->dfs_region = get_cfg80211_regdom()->dfs_region; + + for (band = 0; band < NUM_NL80211_BANDS; band++) + handle_band(wiphy, initiator, wiphy->bands[band]); + + reg_process_beacons(wiphy); + reg_process_ht_flags(wiphy); + reg_call_notifier(wiphy, lr); +} + +static void update_all_wiphy_regulatory(enum nl80211_reg_initiator initiator) +{ + struct cfg80211_registered_device *rdev; + struct wiphy *wiphy; + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + wiphy = &rdev->wiphy; + wiphy_update_regulatory(wiphy, initiator); + } + + reg_check_channels(); +} + +static void handle_channel_custom(struct wiphy *wiphy, + struct ieee80211_channel *chan, + const struct ieee80211_regdomain *regd, + u32 min_bw) +{ + u32 bw_flags = 0; + const struct ieee80211_reg_rule *reg_rule = NULL; + const struct ieee80211_power_rule *power_rule = NULL; + u32 bw, center_freq_khz; + + center_freq_khz = ieee80211_channel_to_khz(chan); + for (bw = MHZ_TO_KHZ(20); bw >= min_bw; bw = bw / 2) { + reg_rule = freq_reg_info_regd(center_freq_khz, regd, bw); + if (!IS_ERR(reg_rule)) + break; + } + + if (IS_ERR_OR_NULL(reg_rule)) { + pr_debug("Disabling freq %d.%03d MHz as custom regd has no rule that fits it\n", + chan->center_freq, chan->freq_offset); + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) { + chan->flags |= IEEE80211_CHAN_DISABLED; + } else { + chan->orig_flags |= IEEE80211_CHAN_DISABLED; + chan->flags = chan->orig_flags; + } + return; + } + + power_rule = ®_rule->power_rule; + bw_flags = reg_rule_to_chan_bw_flags(regd, reg_rule, chan); + + chan->dfs_state_entered = jiffies; + chan->dfs_state = NL80211_DFS_USABLE; + + chan->beacon_found = false; + + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) + chan->flags = chan->orig_flags | bw_flags | + map_regdom_flags(reg_rule->flags); + else + chan->flags |= map_regdom_flags(reg_rule->flags) | bw_flags; + + chan->max_antenna_gain = (int) MBI_TO_DBI(power_rule->max_antenna_gain); + chan->max_reg_power = chan->max_power = + (int) MBM_TO_DBM(power_rule->max_eirp); + + if (chan->flags & IEEE80211_CHAN_RADAR) { + if (reg_rule->dfs_cac_ms) + chan->dfs_cac_ms = reg_rule->dfs_cac_ms; + else + chan->dfs_cac_ms = IEEE80211_DFS_MIN_CAC_TIME_MS; + } + + chan->max_power = chan->max_reg_power; +} + +static void handle_band_custom(struct wiphy *wiphy, + struct ieee80211_supported_band *sband, + const struct ieee80211_regdomain *regd) +{ + unsigned int i; + + if (!sband) + return; + + /* + * We currently assume that you always want at least 20 MHz, + * otherwise channel 12 might get enabled if this rule is + * compatible to US, which permits 2402 - 2472 MHz. 
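As a concrete illustration of the comment above: channel 12 is centred on 2467 MHz, so a 20 MHz channel occupies 2457-2477 MHz and cannot fit inside a US-style 2402-2472 MHz rule, whereas a 10 MHz lookup (2462-2472 MHz) would fit exactly. Passing MHZ_TO_KHZ(20) as the minimum bandwidth is therefore what keeps channel 12 disabled under such a rule.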
+ */ + for (i = 0; i < sband->n_channels; i++) + handle_channel_custom(wiphy, &sband->channels[i], regd, + MHZ_TO_KHZ(20)); +} + +/* Used by drivers prior to wiphy registration */ +void wiphy_apply_custom_regulatory(struct wiphy *wiphy, + const struct ieee80211_regdomain *regd) +{ + const struct ieee80211_regdomain *new_regd, *tmp; + enum nl80211_band band; + unsigned int bands_set = 0; + + WARN(!(wiphy->regulatory_flags & REGULATORY_CUSTOM_REG), + "wiphy should have REGULATORY_CUSTOM_REG\n"); + wiphy->regulatory_flags |= REGULATORY_CUSTOM_REG; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + if (!wiphy->bands[band]) + continue; + handle_band_custom(wiphy, wiphy->bands[band], regd); + bands_set++; + } + + /* + * no point in calling this if it won't have any effect + * on your device's supported bands. + */ + WARN_ON(!bands_set); + new_regd = reg_copy_regd(regd); + if (IS_ERR(new_regd)) + return; + + rtnl_lock(); + wiphy_lock(wiphy); + + tmp = get_wiphy_regdom(wiphy); + rcu_assign_pointer(wiphy->regd, new_regd); + rcu_free_regdom(tmp); + + wiphy_unlock(wiphy); + rtnl_unlock(); +} +EXPORT_SYMBOL(wiphy_apply_custom_regulatory); + +static void reg_set_request_processed(void) +{ + bool need_more_processing = false; + struct regulatory_request *lr = get_last_request(); + + lr->processed = true; + + spin_lock(®_requests_lock); + if (!list_empty(®_requests_list)) + need_more_processing = true; + spin_unlock(®_requests_lock); + + cancel_crda_timeout(); + + if (need_more_processing) + schedule_work(®_work); +} + +/** + * reg_process_hint_core - process core regulatory requests + * @core_request: a pending core regulatory request + * + * The wireless subsystem can use this function to process + * a regulatory request issued by the regulatory core. + */ +static enum reg_request_treatment +reg_process_hint_core(struct regulatory_request *core_request) +{ + if (reg_query_database(core_request)) { + core_request->intersect = false; + core_request->processed = false; + reg_update_last_request(core_request); + return REG_REQ_OK; + } + + return REG_REQ_IGNORE; +} + +static enum reg_request_treatment +__reg_process_hint_user(struct regulatory_request *user_request) +{ + struct regulatory_request *lr = get_last_request(); + + if (reg_request_cell_base(user_request)) + return reg_ignore_cell_hint(user_request); + + if (reg_request_cell_base(lr)) + return REG_REQ_IGNORE; + + if (lr->initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE) + return REG_REQ_INTERSECT; + /* + * If the user knows better the user should set the regdom + * to their country before the IE is picked up + */ + if (lr->initiator == NL80211_REGDOM_SET_BY_USER && + lr->intersect) + return REG_REQ_IGNORE; + /* + * Process user requests only after previous user/driver/core + * requests have been processed + */ + if ((lr->initiator == NL80211_REGDOM_SET_BY_CORE || + lr->initiator == NL80211_REGDOM_SET_BY_DRIVER || + lr->initiator == NL80211_REGDOM_SET_BY_USER) && + regdom_changes(lr->alpha2)) + return REG_REQ_IGNORE; + + if (!regdom_changes(user_request->alpha2)) + return REG_REQ_ALREADY_SET; + + return REG_REQ_OK; +} + +/** + * reg_process_hint_user - process user regulatory requests + * @user_request: a pending user regulatory request + * + * The wireless subsystem can use this function to process + * a regulatory request initiated by userspace. 
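For reference, a driver using the wiphy_apply_custom_regulatory() export above typically declares a static rule table and applies it before wiphy_register(). The table below is purely illustrative and not part of this patch; it assumes the REG_RULE() helper from include/net/regulatory.h:

    /* Illustrative driver-side custom regdomain, not upstream code. */
    static const struct ieee80211_regdomain example_custom_regdom = {
            .n_reg_rules = 2,
            .alpha2 = "99",         /* "99" denotes a built-in/unknown country */
            .reg_rules = {
                    /* 2.4 GHz: 2402-2472 MHz (channels 1-11 at 20 MHz), 20 dBm */
                    REG_RULE(2412 - 10, 2462 + 10, 40, 0, 20, 0),
                    /* 5 GHz: 5170-5250 MHz, 20 dBm, no initiating radiation */
                    REG_RULE(5170, 5250, 80, 0, 20, NL80211_RRF_NO_IR),
            },
    };

    static void example_apply_custom_regd(struct wiphy *wiphy)
    {
            /* setting the flag first avoids the WARN() in the helper above */
            wiphy->regulatory_flags |= REGULATORY_CUSTOM_REG;
            wiphy_apply_custom_regulatory(wiphy, &example_custom_regdom);
    }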
+ */ +static enum reg_request_treatment +reg_process_hint_user(struct regulatory_request *user_request) +{ + enum reg_request_treatment treatment; + + treatment = __reg_process_hint_user(user_request); + if (treatment == REG_REQ_IGNORE || + treatment == REG_REQ_ALREADY_SET) + return REG_REQ_IGNORE; + + user_request->intersect = treatment == REG_REQ_INTERSECT; + user_request->processed = false; + + if (reg_query_database(user_request)) { + reg_update_last_request(user_request); + user_alpha2[0] = user_request->alpha2[0]; + user_alpha2[1] = user_request->alpha2[1]; + return REG_REQ_OK; + } + + return REG_REQ_IGNORE; +} + +static enum reg_request_treatment +__reg_process_hint_driver(struct regulatory_request *driver_request) +{ + struct regulatory_request *lr = get_last_request(); + + if (lr->initiator == NL80211_REGDOM_SET_BY_CORE) { + if (regdom_changes(driver_request->alpha2)) + return REG_REQ_OK; + return REG_REQ_ALREADY_SET; + } + + /* + * This would happen if you unplug and plug your card + * back in or if you add a new device for which the previously + * loaded card also agrees on the regulatory domain. + */ + if (lr->initiator == NL80211_REGDOM_SET_BY_DRIVER && + !regdom_changes(driver_request->alpha2)) + return REG_REQ_ALREADY_SET; + + return REG_REQ_INTERSECT; +} + +/** + * reg_process_hint_driver - process driver regulatory requests + * @wiphy: the wireless device for the regulatory request + * @driver_request: a pending driver regulatory request + * + * The wireless subsystem can use this function to process + * a regulatory request issued by an 802.11 driver. + * + * Returns one of the different reg request treatment values. + */ +static enum reg_request_treatment +reg_process_hint_driver(struct wiphy *wiphy, + struct regulatory_request *driver_request) +{ + const struct ieee80211_regdomain *regd, *tmp; + enum reg_request_treatment treatment; + + treatment = __reg_process_hint_driver(driver_request); + + switch (treatment) { + case REG_REQ_OK: + break; + case REG_REQ_IGNORE: + return REG_REQ_IGNORE; + case REG_REQ_INTERSECT: + case REG_REQ_ALREADY_SET: + regd = reg_copy_regd(get_cfg80211_regdom()); + if (IS_ERR(regd)) + return REG_REQ_IGNORE; + + tmp = get_wiphy_regdom(wiphy); + ASSERT_RTNL(); + wiphy_lock(wiphy); + rcu_assign_pointer(wiphy->regd, regd); + wiphy_unlock(wiphy); + rcu_free_regdom(tmp); + } + + + driver_request->intersect = treatment == REG_REQ_INTERSECT; + driver_request->processed = false; + + /* + * Since CRDA will not be called in this case as we already + * have applied the requested regulatory domain before we just + * inform userspace we have processed the request + */ + if (treatment == REG_REQ_ALREADY_SET) { + nl80211_send_reg_change_event(driver_request); + reg_update_last_request(driver_request); + reg_set_request_processed(); + return REG_REQ_ALREADY_SET; + } + + if (reg_query_database(driver_request)) { + reg_update_last_request(driver_request); + return REG_REQ_OK; + } + + return REG_REQ_IGNORE; +} + +static enum reg_request_treatment +__reg_process_hint_country_ie(struct wiphy *wiphy, + struct regulatory_request *country_ie_request) +{ + struct wiphy *last_wiphy = NULL; + struct regulatory_request *lr = get_last_request(); + + if (reg_request_cell_base(lr)) { + /* Trust a Cell base station over the AP's country IE */ + if (regdom_changes(country_ie_request->alpha2)) + return REG_REQ_IGNORE; + return REG_REQ_ALREADY_SET; + } else { + if (wiphy->regulatory_flags & REGULATORY_COUNTRY_IE_IGNORE) + return REG_REQ_IGNORE; + } + + if 
(unlikely(!is_an_alpha2(country_ie_request->alpha2))) + return -EINVAL; + + if (lr->initiator != NL80211_REGDOM_SET_BY_COUNTRY_IE) + return REG_REQ_OK; + + last_wiphy = wiphy_idx_to_wiphy(lr->wiphy_idx); + + if (last_wiphy != wiphy) { + /* + * Two cards with two APs claiming different + * Country IE alpha2s. We could + * intersect them, but that seems unlikely + * to be correct. Reject second one for now. + */ + if (regdom_changes(country_ie_request->alpha2)) + return REG_REQ_IGNORE; + return REG_REQ_ALREADY_SET; + } + + if (regdom_changes(country_ie_request->alpha2)) + return REG_REQ_OK; + return REG_REQ_ALREADY_SET; +} + +/** + * reg_process_hint_country_ie - process regulatory requests from country IEs + * @wiphy: the wireless device for the regulatory request + * @country_ie_request: a regulatory request from a country IE + * + * The wireless subsystem can use this function to process + * a regulatory request issued by a country Information Element. + * + * Returns one of the different reg request treatment values. + */ +static enum reg_request_treatment +reg_process_hint_country_ie(struct wiphy *wiphy, + struct regulatory_request *country_ie_request) +{ + enum reg_request_treatment treatment; + + treatment = __reg_process_hint_country_ie(wiphy, country_ie_request); + + switch (treatment) { + case REG_REQ_OK: + break; + case REG_REQ_IGNORE: + return REG_REQ_IGNORE; + case REG_REQ_ALREADY_SET: + reg_free_request(country_ie_request); + return REG_REQ_ALREADY_SET; + case REG_REQ_INTERSECT: + /* + * This doesn't happen yet, not sure we + * ever want to support it for this case. + */ + WARN_ONCE(1, "Unexpected intersection for country elements"); + return REG_REQ_IGNORE; + } + + country_ie_request->intersect = false; + country_ie_request->processed = false; + + if (reg_query_database(country_ie_request)) { + reg_update_last_request(country_ie_request); + return REG_REQ_OK; + } + + return REG_REQ_IGNORE; +} + +bool reg_dfs_domain_same(struct wiphy *wiphy1, struct wiphy *wiphy2) +{ + const struct ieee80211_regdomain *wiphy1_regd = NULL; + const struct ieee80211_regdomain *wiphy2_regd = NULL; + const struct ieee80211_regdomain *cfg80211_regd = NULL; + bool dfs_domain_same; + + rcu_read_lock(); + + cfg80211_regd = rcu_dereference(cfg80211_regdomain); + wiphy1_regd = rcu_dereference(wiphy1->regd); + if (!wiphy1_regd) + wiphy1_regd = cfg80211_regd; + + wiphy2_regd = rcu_dereference(wiphy2->regd); + if (!wiphy2_regd) + wiphy2_regd = cfg80211_regd; + + dfs_domain_same = wiphy1_regd->dfs_region == wiphy2_regd->dfs_region; + + rcu_read_unlock(); + + return dfs_domain_same; +} + +static void reg_copy_dfs_chan_state(struct ieee80211_channel *dst_chan, + struct ieee80211_channel *src_chan) +{ + if (!(dst_chan->flags & IEEE80211_CHAN_RADAR) || + !(src_chan->flags & IEEE80211_CHAN_RADAR)) + return; + + if (dst_chan->flags & IEEE80211_CHAN_DISABLED || + src_chan->flags & IEEE80211_CHAN_DISABLED) + return; + + if (src_chan->center_freq == dst_chan->center_freq && + dst_chan->dfs_state == NL80211_DFS_USABLE) { + dst_chan->dfs_state = src_chan->dfs_state; + dst_chan->dfs_state_entered = src_chan->dfs_state_entered; + } +} + +static void wiphy_share_dfs_chan_state(struct wiphy *dst_wiphy, + struct wiphy *src_wiphy) +{ + struct ieee80211_supported_band *src_sband, *dst_sband; + struct ieee80211_channel *src_chan, *dst_chan; + int i, j, band; + + if (!reg_dfs_domain_same(dst_wiphy, src_wiphy)) + return; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + dst_sband = dst_wiphy->bands[band]; + src_sband = 
src_wiphy->bands[band]; + if (!dst_sband || !src_sband) + continue; + + for (i = 0; i < dst_sband->n_channels; i++) { + dst_chan = &dst_sband->channels[i]; + for (j = 0; j < src_sband->n_channels; j++) { + src_chan = &src_sband->channels[j]; + reg_copy_dfs_chan_state(dst_chan, src_chan); + } + } + } +} + +static void wiphy_all_share_dfs_chan_state(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (wiphy == &rdev->wiphy) + continue; + wiphy_share_dfs_chan_state(wiphy, &rdev->wiphy); + } +} + +/* This processes *all* regulatory hints */ +static void reg_process_hint(struct regulatory_request *reg_request) +{ + struct wiphy *wiphy = NULL; + enum reg_request_treatment treatment; + enum nl80211_reg_initiator initiator = reg_request->initiator; + + if (reg_request->wiphy_idx != WIPHY_IDX_INVALID) + wiphy = wiphy_idx_to_wiphy(reg_request->wiphy_idx); + + switch (initiator) { + case NL80211_REGDOM_SET_BY_CORE: + treatment = reg_process_hint_core(reg_request); + break; + case NL80211_REGDOM_SET_BY_USER: + treatment = reg_process_hint_user(reg_request); + break; + case NL80211_REGDOM_SET_BY_DRIVER: + if (!wiphy) + goto out_free; + treatment = reg_process_hint_driver(wiphy, reg_request); + break; + case NL80211_REGDOM_SET_BY_COUNTRY_IE: + if (!wiphy) + goto out_free; + treatment = reg_process_hint_country_ie(wiphy, reg_request); + break; + default: + WARN(1, "invalid initiator %d\n", initiator); + goto out_free; + } + + if (treatment == REG_REQ_IGNORE) + goto out_free; + + WARN(treatment != REG_REQ_OK && treatment != REG_REQ_ALREADY_SET, + "unexpected treatment value %d\n", treatment); + + /* This is required so that the orig_* parameters are saved. + * NOTE: treatment must be set for any case that reaches here! + */ + if (treatment == REG_REQ_ALREADY_SET && wiphy && + wiphy->regulatory_flags & REGULATORY_STRICT_REG) { + wiphy_update_regulatory(wiphy, initiator); + wiphy_all_share_dfs_chan_state(wiphy); + reg_check_channels(); + } + + return; + +out_free: + reg_free_request(reg_request); +} + +static void notify_self_managed_wiphys(struct regulatory_request *request) +{ + struct cfg80211_registered_device *rdev; + struct wiphy *wiphy; + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + wiphy = &rdev->wiphy; + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED && + request->initiator == NL80211_REGDOM_SET_BY_USER) + reg_call_notifier(wiphy, request); + } +} + +/* + * Processes regulatory hints, this is all the NL80211_REGDOM_SET_BY_* + * Regulatory hints come on a first come first serve basis and we + * must process each one atomically. 
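Before the queue handling below, it may help to see the path a queued hint takes; this is only a summary of the code in this file, not an addition to it:

    queue_regulatory_request()
      -> schedule_work(&reg_work)                      /* runs reg_todo() */
           -> reg_process_pending_hints()
                -> reg_process_hint()                  /* dispatch on initiator */
                     -> reg_process_hint_{core,user,driver,country_ie}()
                          -> reg_query_database()      /* built-in regdb or CRDA */
                               ... set_regdom()        /* once a regdomain arrives */
                                    -> update_all_wiphy_regulatory()
                                    -> nl80211_send_reg_change_event()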
+ */ +static void reg_process_pending_hints(void) +{ + struct regulatory_request *reg_request, *lr; + + lr = get_last_request(); + + /* When last_request->processed becomes true this will be rescheduled */ + if (lr && !lr->processed) { + pr_debug("Pending regulatory request, waiting for it to be processed...\n"); + return; + } + + spin_lock(®_requests_lock); + + if (list_empty(®_requests_list)) { + spin_unlock(®_requests_lock); + return; + } + + reg_request = list_first_entry(®_requests_list, + struct regulatory_request, + list); + list_del_init(®_request->list); + + spin_unlock(®_requests_lock); + + notify_self_managed_wiphys(reg_request); + + reg_process_hint(reg_request); + + lr = get_last_request(); + + spin_lock(®_requests_lock); + if (!list_empty(®_requests_list) && lr && lr->processed) + schedule_work(®_work); + spin_unlock(®_requests_lock); +} + +/* Processes beacon hints -- this has nothing to do with country IEs */ +static void reg_process_pending_beacon_hints(void) +{ + struct cfg80211_registered_device *rdev; + struct reg_beacon *pending_beacon, *tmp; + + /* This goes through the _pending_ beacon list */ + spin_lock_bh(®_pending_beacons_lock); + + list_for_each_entry_safe(pending_beacon, tmp, + ®_pending_beacons, list) { + list_del_init(&pending_beacon->list); + + /* Applies the beacon hint to current wiphys */ + list_for_each_entry(rdev, &cfg80211_rdev_list, list) + wiphy_update_new_beacon(&rdev->wiphy, pending_beacon); + + /* Remembers the beacon hint for new wiphys or reg changes */ + list_add_tail(&pending_beacon->list, ®_beacon_list); + } + + spin_unlock_bh(®_pending_beacons_lock); +} + +static void reg_process_self_managed_hint(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + const struct ieee80211_regdomain *tmp; + const struct ieee80211_regdomain *regd; + enum nl80211_band band; + struct regulatory_request request = {}; + + ASSERT_RTNL(); + lockdep_assert_wiphy(wiphy); + + spin_lock(®_requests_lock); + regd = rdev->requested_regd; + rdev->requested_regd = NULL; + spin_unlock(®_requests_lock); + + if (!regd) + return; + + tmp = get_wiphy_regdom(wiphy); + rcu_assign_pointer(wiphy->regd, regd); + rcu_free_regdom(tmp); + + for (band = 0; band < NUM_NL80211_BANDS; band++) + handle_band_custom(wiphy, wiphy->bands[band], regd); + + reg_process_ht_flags(wiphy); + + request.wiphy_idx = get_wiphy_idx(wiphy); + request.alpha2[0] = regd->alpha2[0]; + request.alpha2[1] = regd->alpha2[1]; + request.initiator = NL80211_REGDOM_SET_BY_DRIVER; + + nl80211_send_wiphy_reg_change_event(&request); +} + +static void reg_process_self_managed_hints(void) +{ + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + wiphy_lock(&rdev->wiphy); + reg_process_self_managed_hint(&rdev->wiphy); + wiphy_unlock(&rdev->wiphy); + } + + reg_check_channels(); +} + +static void reg_todo(struct work_struct *work) +{ + rtnl_lock(); + reg_process_pending_hints(); + reg_process_pending_beacon_hints(); + reg_process_self_managed_hints(); + rtnl_unlock(); +} + +static void queue_regulatory_request(struct regulatory_request *request) +{ + request->alpha2[0] = toupper(request->alpha2[0]); + request->alpha2[1] = toupper(request->alpha2[1]); + + spin_lock(®_requests_lock); + list_add_tail(&request->list, ®_requests_list); + spin_unlock(®_requests_lock); + + schedule_work(®_work); +} + +/* + * Core regulatory hint -- happens during cfg80211_init() + * and when we restore regulatory settings. 
+ */ +static int regulatory_hint_core(const char *alpha2) +{ + struct regulatory_request *request; + + request = kzalloc(sizeof(struct regulatory_request), GFP_KERNEL); + if (!request) + return -ENOMEM; + + request->alpha2[0] = alpha2[0]; + request->alpha2[1] = alpha2[1]; + request->initiator = NL80211_REGDOM_SET_BY_CORE; + request->wiphy_idx = WIPHY_IDX_INVALID; + + queue_regulatory_request(request); + + return 0; +} + +/* User hints */ +int regulatory_hint_user(const char *alpha2, + enum nl80211_user_reg_hint_type user_reg_hint_type) +{ + struct regulatory_request *request; + + if (WARN_ON(!alpha2)) + return -EINVAL; + + if (!is_world_regdom(alpha2) && !is_an_alpha2(alpha2)) + return -EINVAL; + + request = kzalloc(sizeof(struct regulatory_request), GFP_KERNEL); + if (!request) + return -ENOMEM; + + request->wiphy_idx = WIPHY_IDX_INVALID; + request->alpha2[0] = alpha2[0]; + request->alpha2[1] = alpha2[1]; + request->initiator = NL80211_REGDOM_SET_BY_USER; + request->user_reg_hint_type = user_reg_hint_type; + + /* Allow calling CRDA again */ + reset_crda_timeouts(); + + queue_regulatory_request(request); + + return 0; +} + +int regulatory_hint_indoor(bool is_indoor, u32 portid) +{ + spin_lock(®_indoor_lock); + + /* It is possible that more than one user space process is trying to + * configure the indoor setting. To handle such cases, clear the indoor + * setting in case that some process does not think that the device + * is operating in an indoor environment. In addition, if a user space + * process indicates that it is controlling the indoor setting, save its + * portid, i.e., make it the owner. + */ + reg_is_indoor = is_indoor; + if (reg_is_indoor) { + if (!reg_is_indoor_portid) + reg_is_indoor_portid = portid; + } else { + reg_is_indoor_portid = 0; + } + + spin_unlock(®_indoor_lock); + + if (!is_indoor) + reg_check_channels(); + + return 0; +} + +void regulatory_netlink_notify(u32 portid) +{ + spin_lock(®_indoor_lock); + + if (reg_is_indoor_portid != portid) { + spin_unlock(®_indoor_lock); + return; + } + + reg_is_indoor = false; + reg_is_indoor_portid = 0; + + spin_unlock(®_indoor_lock); + + reg_check_channels(); +} + +/* Driver hints */ +int regulatory_hint(struct wiphy *wiphy, const char *alpha2) +{ + struct regulatory_request *request; + + if (WARN_ON(!alpha2 || !wiphy)) + return -EINVAL; + + wiphy->regulatory_flags &= ~REGULATORY_CUSTOM_REG; + + request = kzalloc(sizeof(struct regulatory_request), GFP_KERNEL); + if (!request) + return -ENOMEM; + + request->wiphy_idx = get_wiphy_idx(wiphy); + + request->alpha2[0] = alpha2[0]; + request->alpha2[1] = alpha2[1]; + request->initiator = NL80211_REGDOM_SET_BY_DRIVER; + + /* Allow calling CRDA again */ + reset_crda_timeouts(); + + queue_regulatory_request(request); + + return 0; +} +EXPORT_SYMBOL(regulatory_hint); + +void regulatory_hint_country_ie(struct wiphy *wiphy, enum nl80211_band band, + const u8 *country_ie, u8 country_ie_len) +{ + char alpha2[2]; + enum environment_cap env = ENVIRON_ANY; + struct regulatory_request *request = NULL, *lr; + + /* IE len must be evenly divisible by 2 */ + if (country_ie_len & 0x01) + return; + + if (country_ie_len < IEEE80211_COUNTRY_IE_MIN_LEN) + return; + + request = kzalloc(sizeof(*request), GFP_KERNEL); + if (!request) + return; + + alpha2[0] = country_ie[0]; + alpha2[1] = country_ie[1]; + + if (country_ie[2] == 'I') + env = ENVIRON_INDOOR; + else if (country_ie[2] == 'O') + env = ENVIRON_OUTDOOR; + + rcu_read_lock(); + lr = get_last_request(); + + if (unlikely(!lr)) + goto out; + + /* + * We 
will run this only upon a successful connection on cfg80211. + * We leave conflict resolution to the workqueue, where can hold + * the RTNL. + */ + if (lr->initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE && + lr->wiphy_idx != WIPHY_IDX_INVALID) + goto out; + + request->wiphy_idx = get_wiphy_idx(wiphy); + request->alpha2[0] = alpha2[0]; + request->alpha2[1] = alpha2[1]; + request->initiator = NL80211_REGDOM_SET_BY_COUNTRY_IE; + request->country_ie_env = env; + + /* Allow calling CRDA again */ + reset_crda_timeouts(); + + queue_regulatory_request(request); + request = NULL; +out: + kfree(request); + rcu_read_unlock(); +} + +static void restore_alpha2(char *alpha2, bool reset_user) +{ + /* indicates there is no alpha2 to consider for restoration */ + alpha2[0] = '9'; + alpha2[1] = '7'; + + /* The user setting has precedence over the module parameter */ + if (is_user_regdom_saved()) { + /* Unless we're asked to ignore it and reset it */ + if (reset_user) { + pr_debug("Restoring regulatory settings including user preference\n"); + user_alpha2[0] = '9'; + user_alpha2[1] = '7'; + + /* + * If we're ignoring user settings, we still need to + * check the module parameter to ensure we put things + * back as they were for a full restore. + */ + if (!is_world_regdom(ieee80211_regdom)) { + pr_debug("Keeping preference on module parameter ieee80211_regdom: %c%c\n", + ieee80211_regdom[0], ieee80211_regdom[1]); + alpha2[0] = ieee80211_regdom[0]; + alpha2[1] = ieee80211_regdom[1]; + } + } else { + pr_debug("Restoring regulatory settings while preserving user preference for: %c%c\n", + user_alpha2[0], user_alpha2[1]); + alpha2[0] = user_alpha2[0]; + alpha2[1] = user_alpha2[1]; + } + } else if (!is_world_regdom(ieee80211_regdom)) { + pr_debug("Keeping preference on module parameter ieee80211_regdom: %c%c\n", + ieee80211_regdom[0], ieee80211_regdom[1]); + alpha2[0] = ieee80211_regdom[0]; + alpha2[1] = ieee80211_regdom[1]; + } else + pr_debug("Restoring regulatory settings\n"); +} + +static void restore_custom_reg_settings(struct wiphy *wiphy) +{ + struct ieee80211_supported_band *sband; + enum nl80211_band band; + struct ieee80211_channel *chan; + int i; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + sband = wiphy->bands[band]; + if (!sband) + continue; + for (i = 0; i < sband->n_channels; i++) { + chan = &sband->channels[i]; + chan->flags = chan->orig_flags; + chan->max_antenna_gain = chan->orig_mag; + chan->max_power = chan->orig_mpwr; + chan->beacon_found = false; + } + } +} + +/* + * Restoring regulatory settings involves ignoring any + * possibly stale country IE information and user regulatory + * settings if so desired, this includes any beacon hints + * learned as we could have traveled outside to another country + * after disconnection. To restore regulatory settings we do + * exactly what we did at bootup: + * + * - send a core regulatory hint + * - send a user regulatory hint if applicable + * + * Device drivers that send a regulatory hint for a specific country + * keep their own regulatory domain on wiphy->regd so that does + * not need to be remembered. + */ +static void restore_regulatory_settings(bool reset_user, bool cached) +{ + char alpha2[2]; + char world_alpha2[2]; + struct reg_beacon *reg_beacon, *btmp; + LIST_HEAD(tmp_reg_req_list); + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + /* + * Clear the indoor setting in case that it is not controlled by user + * space, as otherwise there is no guarantee that the device is still + * operating in an indoor environment. 
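A driver consuming the regulatory_hint() export above would typically call it once, as soon as it knows which country its hardware is programmed for. The sketch below is illustrative only and not part of this patch:

    /* Illustrative driver-side sketch, not upstream code. */
    static int example_hint_from_eeprom(struct wiphy *wiphy,
                                        const char eeprom_alpha2[2])
    {
            /*
             * regulatory_hint() only queues an NL80211_REGDOM_SET_BY_DRIVER
             * request; the regdomain change happens later from the reg_work
             * workqueue, and wiphy->reg_notifier() is called once it does.
             */
            return regulatory_hint(wiphy, eeprom_alpha2);
    }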
+ */ + spin_lock(®_indoor_lock); + if (reg_is_indoor && !reg_is_indoor_portid) { + reg_is_indoor = false; + reg_check_channels(); + } + spin_unlock(®_indoor_lock); + + reset_regdomains(true, &world_regdom); + restore_alpha2(alpha2, reset_user); + + /* + * If there's any pending requests we simply + * stash them to a temporary pending queue and + * add then after we've restored regulatory + * settings. + */ + spin_lock(®_requests_lock); + list_splice_tail_init(®_requests_list, &tmp_reg_req_list); + spin_unlock(®_requests_lock); + + /* Clear beacon hints */ + spin_lock_bh(®_pending_beacons_lock); + list_for_each_entry_safe(reg_beacon, btmp, ®_pending_beacons, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + spin_unlock_bh(®_pending_beacons_lock); + + list_for_each_entry_safe(reg_beacon, btmp, ®_beacon_list, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + + /* First restore to the basic regulatory settings */ + world_alpha2[0] = cfg80211_world_regdom->alpha2[0]; + world_alpha2[1] = cfg80211_world_regdom->alpha2[1]; + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (rdev->wiphy.regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) + continue; + if (rdev->wiphy.regulatory_flags & REGULATORY_CUSTOM_REG) + restore_custom_reg_settings(&rdev->wiphy); + } + + if (cached && (!is_an_alpha2(alpha2) || + !IS_ERR_OR_NULL(cfg80211_user_regdom))) { + reset_regdomains(false, cfg80211_world_regdom); + update_all_wiphy_regulatory(NL80211_REGDOM_SET_BY_CORE); + print_regdomain(get_cfg80211_regdom()); + nl80211_send_reg_change_event(&core_request_world); + reg_set_request_processed(); + + if (is_an_alpha2(alpha2) && + !regulatory_hint_user(alpha2, NL80211_USER_REG_HINT_USER)) { + struct regulatory_request *ureq; + + spin_lock(®_requests_lock); + ureq = list_last_entry(®_requests_list, + struct regulatory_request, + list); + list_del(&ureq->list); + spin_unlock(®_requests_lock); + + notify_self_managed_wiphys(ureq); + reg_update_last_request(ureq); + set_regdom(reg_copy_regd(cfg80211_user_regdom), + REGD_SOURCE_CACHED); + } + } else { + regulatory_hint_core(world_alpha2); + + /* + * This restores the ieee80211_regdom module parameter + * preference or the last user requested regulatory + * settings, user regulatory settings takes precedence. + */ + if (is_an_alpha2(alpha2)) + regulatory_hint_user(alpha2, NL80211_USER_REG_HINT_USER); + } + + spin_lock(®_requests_lock); + list_splice_tail_init(&tmp_reg_req_list, ®_requests_list); + spin_unlock(®_requests_lock); + + pr_debug("Kicking the queue\n"); + + schedule_work(®_work); +} + +static bool is_wiphy_all_set_reg_flag(enum ieee80211_regulatory_flags flag) +{ + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + wdev_lock(wdev); + if (!(wdev->wiphy->regulatory_flags & flag)) { + wdev_unlock(wdev); + return false; + } + wdev_unlock(wdev); + } + } + + return true; +} + +void regulatory_hint_disconnect(void) +{ + /* Restore of regulatory settings is not required when wiphy(s) + * ignore IE from connected access point but clearance of beacon hints + * is required when wiphy(s) supports beacon hints. 
+ */ + if (is_wiphy_all_set_reg_flag(REGULATORY_COUNTRY_IE_IGNORE)) { + struct reg_beacon *reg_beacon, *btmp; + + if (is_wiphy_all_set_reg_flag(REGULATORY_DISABLE_BEACON_HINTS)) + return; + + spin_lock_bh(®_pending_beacons_lock); + list_for_each_entry_safe(reg_beacon, btmp, + ®_pending_beacons, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + spin_unlock_bh(®_pending_beacons_lock); + + list_for_each_entry_safe(reg_beacon, btmp, + ®_beacon_list, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + + return; + } + + pr_debug("All devices are disconnected, going to restore regulatory settings\n"); + restore_regulatory_settings(false, true); +} + +static bool freq_is_chan_12_13_14(u32 freq) +{ + if (freq == ieee80211_channel_to_frequency(12, NL80211_BAND_2GHZ) || + freq == ieee80211_channel_to_frequency(13, NL80211_BAND_2GHZ) || + freq == ieee80211_channel_to_frequency(14, NL80211_BAND_2GHZ)) + return true; + return false; +} + +static bool pending_reg_beacon(struct ieee80211_channel *beacon_chan) +{ + struct reg_beacon *pending_beacon; + + list_for_each_entry(pending_beacon, ®_pending_beacons, list) + if (ieee80211_channel_equal(beacon_chan, + &pending_beacon->chan)) + return true; + return false; +} + +int regulatory_hint_found_beacon(struct wiphy *wiphy, + struct ieee80211_channel *beacon_chan, + gfp_t gfp) +{ + struct reg_beacon *reg_beacon; + bool processing; + + if (beacon_chan->beacon_found || + beacon_chan->flags & IEEE80211_CHAN_RADAR || + (beacon_chan->band == NL80211_BAND_2GHZ && + !freq_is_chan_12_13_14(beacon_chan->center_freq))) + return 0; + + spin_lock_bh(®_pending_beacons_lock); + processing = pending_reg_beacon(beacon_chan); + spin_unlock_bh(®_pending_beacons_lock); + + if (processing) + return 0; + + reg_beacon = kzalloc(sizeof(struct reg_beacon), gfp); + if (!reg_beacon) + return -ENOMEM; + + pr_debug("Found new beacon on frequency: %d.%03d MHz (Ch %d) on %s\n", + beacon_chan->center_freq, beacon_chan->freq_offset, + ieee80211_freq_khz_to_channel( + ieee80211_channel_to_khz(beacon_chan)), + wiphy_name(wiphy)); + + memcpy(®_beacon->chan, beacon_chan, + sizeof(struct ieee80211_channel)); + + /* + * Since we can be called from BH or and non-BH context + * we must use spin_lock_bh() + */ + spin_lock_bh(®_pending_beacons_lock); + list_add_tail(®_beacon->list, ®_pending_beacons); + spin_unlock_bh(®_pending_beacons_lock); + + schedule_work(®_work); + + return 0; +} + +static void print_rd_rules(const struct ieee80211_regdomain *rd) +{ + unsigned int i; + const struct ieee80211_reg_rule *reg_rule = NULL; + const struct ieee80211_freq_range *freq_range = NULL; + const struct ieee80211_power_rule *power_rule = NULL; + char bw[32], cac_time[32]; + + pr_debug(" (start_freq - end_freq @ bandwidth), (max_antenna_gain, max_eirp), (dfs_cac_time)\n"); + + for (i = 0; i < rd->n_reg_rules; i++) { + reg_rule = &rd->reg_rules[i]; + freq_range = ®_rule->freq_range; + power_rule = ®_rule->power_rule; + + if (reg_rule->flags & NL80211_RRF_AUTO_BW) + snprintf(bw, sizeof(bw), "%d KHz, %u KHz AUTO", + freq_range->max_bandwidth_khz, + reg_get_max_bandwidth(rd, reg_rule)); + else + snprintf(bw, sizeof(bw), "%d KHz", + freq_range->max_bandwidth_khz); + + if (reg_rule->flags & NL80211_RRF_DFS) + scnprintf(cac_time, sizeof(cac_time), "%u s", + reg_rule->dfs_cac_ms/1000); + else + scnprintf(cac_time, sizeof(cac_time), "N/A"); + + + /* + * There may not be documentation for max antenna gain + * in certain regions + */ + if (power_rule->max_antenna_gain) + pr_debug(" (%d KHz - %d KHz 
@ %s), (%d mBi, %d mBm), (%s)\n", + freq_range->start_freq_khz, + freq_range->end_freq_khz, + bw, + power_rule->max_antenna_gain, + power_rule->max_eirp, + cac_time); + else + pr_debug(" (%d KHz - %d KHz @ %s), (N/A, %d mBm), (%s)\n", + freq_range->start_freq_khz, + freq_range->end_freq_khz, + bw, + power_rule->max_eirp, + cac_time); + } +} + +bool reg_supported_dfs_region(enum nl80211_dfs_regions dfs_region) +{ + switch (dfs_region) { + case NL80211_DFS_UNSET: + case NL80211_DFS_FCC: + case NL80211_DFS_ETSI: + case NL80211_DFS_JP: + return true; + default: + pr_debug("Ignoring unknown DFS master region: %d\n", dfs_region); + return false; + } +} + +static void print_regdomain(const struct ieee80211_regdomain *rd) +{ + struct regulatory_request *lr = get_last_request(); + + if (is_intersected_alpha2(rd->alpha2)) { + if (lr->initiator == NL80211_REGDOM_SET_BY_COUNTRY_IE) { + struct cfg80211_registered_device *rdev; + rdev = cfg80211_rdev_by_wiphy_idx(lr->wiphy_idx); + if (rdev) { + pr_debug("Current regulatory domain updated by AP to: %c%c\n", + rdev->country_ie_alpha2[0], + rdev->country_ie_alpha2[1]); + } else + pr_debug("Current regulatory domain intersected:\n"); + } else + pr_debug("Current regulatory domain intersected:\n"); + } else if (is_world_regdom(rd->alpha2)) { + pr_debug("World regulatory domain updated:\n"); + } else { + if (is_unknown_alpha2(rd->alpha2)) + pr_debug("Regulatory domain changed to driver built-in settings (unknown country)\n"); + else { + if (reg_request_cell_base(lr)) + pr_debug("Regulatory domain changed to country: %c%c by Cell Station\n", + rd->alpha2[0], rd->alpha2[1]); + else + pr_debug("Regulatory domain changed to country: %c%c\n", + rd->alpha2[0], rd->alpha2[1]); + } + } + + pr_debug(" DFS Master region: %s", reg_dfs_region_str(rd->dfs_region)); + print_rd_rules(rd); +} + +static void print_regdomain_info(const struct ieee80211_regdomain *rd) +{ + pr_debug("Regulatory domain: %c%c\n", rd->alpha2[0], rd->alpha2[1]); + print_rd_rules(rd); +} + +static int reg_set_rd_core(const struct ieee80211_regdomain *rd) +{ + if (!is_world_regdom(rd->alpha2)) + return -EINVAL; + update_world_regdomain(rd); + return 0; +} + +static int reg_set_rd_user(const struct ieee80211_regdomain *rd, + struct regulatory_request *user_request) +{ + const struct ieee80211_regdomain *intersected_rd = NULL; + + if (!regdom_changes(rd->alpha2)) + return -EALREADY; + + if (!is_valid_rd(rd)) { + pr_err("Invalid regulatory domain detected: %c%c\n", + rd->alpha2[0], rd->alpha2[1]); + print_regdomain_info(rd); + return -EINVAL; + } + + if (!user_request->intersect) { + reset_regdomains(false, rd); + return 0; + } + + intersected_rd = regdom_intersect(rd, get_cfg80211_regdom()); + if (!intersected_rd) + return -EINVAL; + + kfree(rd); + rd = NULL; + reset_regdomains(false, intersected_rd); + + return 0; +} + +static int reg_set_rd_driver(const struct ieee80211_regdomain *rd, + struct regulatory_request *driver_request) +{ + const struct ieee80211_regdomain *regd; + const struct ieee80211_regdomain *intersected_rd = NULL; + const struct ieee80211_regdomain *tmp; + struct wiphy *request_wiphy; + + if (is_world_regdom(rd->alpha2)) + return -EINVAL; + + if (!regdom_changes(rd->alpha2)) + return -EALREADY; + + if (!is_valid_rd(rd)) { + pr_err("Invalid regulatory domain detected: %c%c\n", + rd->alpha2[0], rd->alpha2[1]); + print_regdomain_info(rd); + return -EINVAL; + } + + request_wiphy = wiphy_idx_to_wiphy(driver_request->wiphy_idx); + if (!request_wiphy) + return -ENODEV; + + if 
(!driver_request->intersect) { + ASSERT_RTNL(); + wiphy_lock(request_wiphy); + if (request_wiphy->regd) { + wiphy_unlock(request_wiphy); + return -EALREADY; + } + + regd = reg_copy_regd(rd); + if (IS_ERR(regd)) { + wiphy_unlock(request_wiphy); + return PTR_ERR(regd); + } + + rcu_assign_pointer(request_wiphy->regd, regd); + wiphy_unlock(request_wiphy); + reset_regdomains(false, rd); + return 0; + } + + intersected_rd = regdom_intersect(rd, get_cfg80211_regdom()); + if (!intersected_rd) + return -EINVAL; + + /* + * We can trash what CRDA provided now. + * However if a driver requested this specific regulatory + * domain we keep it for its private use + */ + tmp = get_wiphy_regdom(request_wiphy); + rcu_assign_pointer(request_wiphy->regd, rd); + rcu_free_regdom(tmp); + + rd = NULL; + + reset_regdomains(false, intersected_rd); + + return 0; +} + +static int reg_set_rd_country_ie(const struct ieee80211_regdomain *rd, + struct regulatory_request *country_ie_request) +{ + struct wiphy *request_wiphy; + + if (!is_alpha2_set(rd->alpha2) && !is_an_alpha2(rd->alpha2) && + !is_unknown_alpha2(rd->alpha2)) + return -EINVAL; + + /* + * Lets only bother proceeding on the same alpha2 if the current + * rd is non static (it means CRDA was present and was used last) + * and the pending request came in from a country IE + */ + + if (!is_valid_rd(rd)) { + pr_err("Invalid regulatory domain detected: %c%c\n", + rd->alpha2[0], rd->alpha2[1]); + print_regdomain_info(rd); + return -EINVAL; + } + + request_wiphy = wiphy_idx_to_wiphy(country_ie_request->wiphy_idx); + if (!request_wiphy) + return -ENODEV; + + if (country_ie_request->intersect) + return -EINVAL; + + reset_regdomains(false, rd); + return 0; +} + +/* + * Use this call to set the current regulatory domain. Conflicts with + * multiple drivers can be ironed out later. Caller must've already + * kmalloc'd the rd structure. 
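set_regdom() just below takes ownership of a kmalloc'd regdomain and frees it itself on error paths, and its callers are expected to hold the RTNL (update_all_wiphy_regulatory() asserts it). An illustrative caller, not part of this patch and assuming struct_size() from linux/overflow.h, might look like:

    /* Illustrative caller sketch, not upstream code. */
    static int example_load_regdom(unsigned int n_rules)
    {
            struct ieee80211_regdomain *rd;

            rd = kzalloc(struct_size(rd, reg_rules, n_rules), GFP_KERNEL);
            if (!rd)
                    return -ENOMEM;

            rd->n_reg_rules = n_rules;
            rd->alpha2[0] = 'D';
            rd->alpha2[1] = 'E';
            /* ... fill rd->reg_rules[] from the database entry ... */

            /* set_regdom() owns rd from here on */
            return set_regdom(rd, REGD_SOURCE_INTERNAL_DB);
    }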
+ */ +int set_regdom(const struct ieee80211_regdomain *rd, + enum ieee80211_regd_source regd_src) +{ + struct regulatory_request *lr; + bool user_reset = false; + int r; + + if (IS_ERR_OR_NULL(rd)) + return -ENODATA; + + if (!reg_is_valid_request(rd->alpha2)) { + kfree(rd); + return -EINVAL; + } + + if (regd_src == REGD_SOURCE_CRDA) + reset_crda_timeouts(); + + lr = get_last_request(); + + /* Note that this doesn't update the wiphys, this is done below */ + switch (lr->initiator) { + case NL80211_REGDOM_SET_BY_CORE: + r = reg_set_rd_core(rd); + break; + case NL80211_REGDOM_SET_BY_USER: + cfg80211_save_user_regdom(rd); + r = reg_set_rd_user(rd, lr); + user_reset = true; + break; + case NL80211_REGDOM_SET_BY_DRIVER: + r = reg_set_rd_driver(rd, lr); + break; + case NL80211_REGDOM_SET_BY_COUNTRY_IE: + r = reg_set_rd_country_ie(rd, lr); + break; + default: + WARN(1, "invalid initiator %d\n", lr->initiator); + kfree(rd); + return -EINVAL; + } + + if (r) { + switch (r) { + case -EALREADY: + reg_set_request_processed(); + break; + default: + /* Back to world regulatory in case of errors */ + restore_regulatory_settings(user_reset, false); + } + + kfree(rd); + return r; + } + + /* This would make this whole thing pointless */ + if (WARN_ON(!lr->intersect && rd != get_cfg80211_regdom())) + return -EINVAL; + + /* update all wiphys now with the new established regulatory domain */ + update_all_wiphy_regulatory(lr->initiator); + + print_regdomain(get_cfg80211_regdom()); + + nl80211_send_reg_change_event(lr); + + reg_set_request_processed(); + + return 0; +} + +static int __regulatory_set_wiphy_regd(struct wiphy *wiphy, + struct ieee80211_regdomain *rd) +{ + const struct ieee80211_regdomain *regd; + const struct ieee80211_regdomain *prev_regd; + struct cfg80211_registered_device *rdev; + + if (WARN_ON(!wiphy || !rd)) + return -EINVAL; + + if (WARN(!(wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED), + "wiphy should have REGULATORY_WIPHY_SELF_MANAGED\n")) + return -EPERM; + + if (WARN(!is_valid_rd(rd), + "Invalid regulatory domain detected: %c%c\n", + rd->alpha2[0], rd->alpha2[1])) { + print_regdomain_info(rd); + return -EINVAL; + } + + regd = reg_copy_regd(rd); + if (IS_ERR(regd)) + return PTR_ERR(regd); + + rdev = wiphy_to_rdev(wiphy); + + spin_lock(®_requests_lock); + prev_regd = rdev->requested_regd; + rdev->requested_regd = regd; + spin_unlock(®_requests_lock); + + kfree(prev_regd); + return 0; +} + +int regulatory_set_wiphy_regd(struct wiphy *wiphy, + struct ieee80211_regdomain *rd) +{ + int ret = __regulatory_set_wiphy_regd(wiphy, rd); + + if (ret) + return ret; + + schedule_work(®_work); + return 0; +} +EXPORT_SYMBOL(regulatory_set_wiphy_regd); + +int regulatory_set_wiphy_regd_sync(struct wiphy *wiphy, + struct ieee80211_regdomain *rd) +{ + int ret; + + ASSERT_RTNL(); + + ret = __regulatory_set_wiphy_regd(wiphy, rd); + if (ret) + return ret; + + /* process the request immediately */ + reg_process_self_managed_hint(wiphy); + reg_check_channels(); + return 0; +} +EXPORT_SYMBOL(regulatory_set_wiphy_regd_sync); + +void wiphy_regulatory_register(struct wiphy *wiphy) +{ + struct regulatory_request *lr = get_last_request(); + + /* self-managed devices ignore beacon hints and country IE */ + if (wiphy->regulatory_flags & REGULATORY_WIPHY_SELF_MANAGED) { + wiphy->regulatory_flags |= REGULATORY_DISABLE_BEACON_HINTS | + REGULATORY_COUNTRY_IE_IGNORE; + + /* + * The last request may have been received before this + * registration call. Call the driver notifier if + * initiator is USER. 
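Self-managed devices use the regulatory_set_wiphy_regd() export above instead of the hint path; it copies the supplied table and applies it asynchronously. An illustrative driver-side sketch, not part of this patch:

    /* Illustrative driver-side sketch, not upstream code. */
    static int example_push_self_managed_regd(struct wiphy *wiphy,
                                              struct ieee80211_regdomain *fw_regd)
    {
            /* only valid for wiphys flagged REGULATORY_WIPHY_SELF_MANAGED,
             * otherwise -EPERM is returned */
            int ret = regulatory_set_wiphy_regd(wiphy, fw_regd);

            /* the table is copied, so the caller keeps ownership of fw_regd;
             * the update itself runs later from the reg_work workqueue */
            return ret;
    }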
+ */ + if (lr->initiator == NL80211_REGDOM_SET_BY_USER) + reg_call_notifier(wiphy, lr); + } + + if (!reg_dev_ignore_cell_hint(wiphy)) + reg_num_devs_support_basehint++; + + wiphy_update_regulatory(wiphy, lr->initiator); + wiphy_all_share_dfs_chan_state(wiphy); + reg_process_self_managed_hints(); +} + +void wiphy_regulatory_deregister(struct wiphy *wiphy) +{ + struct wiphy *request_wiphy = NULL; + struct regulatory_request *lr; + + lr = get_last_request(); + + if (!reg_dev_ignore_cell_hint(wiphy)) + reg_num_devs_support_basehint--; + + rcu_free_regdom(get_wiphy_regdom(wiphy)); + RCU_INIT_POINTER(wiphy->regd, NULL); + + if (lr) + request_wiphy = wiphy_idx_to_wiphy(lr->wiphy_idx); + + if (!request_wiphy || request_wiphy != wiphy) + return; + + lr->wiphy_idx = WIPHY_IDX_INVALID; + lr->country_ie_env = ENVIRON_ANY; +} + +/* + * See FCC notices for UNII band definitions + * 5GHz: https://www.fcc.gov/document/5-ghz-unlicensed-spectrum-unii + * 6GHz: https://www.fcc.gov/document/fcc-proposes-more-spectrum-unlicensed-use-0 + */ +int cfg80211_get_unii(int freq) +{ + /* UNII-1 */ + if (freq >= 5150 && freq <= 5250) + return 0; + + /* UNII-2A */ + if (freq > 5250 && freq <= 5350) + return 1; + + /* UNII-2B */ + if (freq > 5350 && freq <= 5470) + return 2; + + /* UNII-2C */ + if (freq > 5470 && freq <= 5725) + return 3; + + /* UNII-3 */ + if (freq > 5725 && freq <= 5825) + return 4; + + /* UNII-5 */ + if (freq > 5925 && freq <= 6425) + return 5; + + /* UNII-6 */ + if (freq > 6425 && freq <= 6525) + return 6; + + /* UNII-7 */ + if (freq > 6525 && freq <= 6875) + return 7; + + /* UNII-8 */ + if (freq > 6875 && freq <= 7125) + return 8; + + return -EINVAL; +} + +bool regulatory_indoor_allowed(void) +{ + return reg_is_indoor; +} + +bool regulatory_pre_cac_allowed(struct wiphy *wiphy) +{ + const struct ieee80211_regdomain *regd = NULL; + const struct ieee80211_regdomain *wiphy_regd = NULL; + bool pre_cac_allowed = false; + + rcu_read_lock(); + + regd = rcu_dereference(cfg80211_regdomain); + wiphy_regd = rcu_dereference(wiphy->regd); + if (!wiphy_regd) { + if (regd->dfs_region == NL80211_DFS_ETSI) + pre_cac_allowed = true; + + rcu_read_unlock(); + + return pre_cac_allowed; + } + + if (regd->dfs_region == wiphy_regd->dfs_region && + wiphy_regd->dfs_region == NL80211_DFS_ETSI) + pre_cac_allowed = true; + + rcu_read_unlock(); + + return pre_cac_allowed; +} +EXPORT_SYMBOL(regulatory_pre_cac_allowed); + +static void cfg80211_check_and_end_cac(struct cfg80211_registered_device *rdev) +{ + struct wireless_dev *wdev; + /* If we finished CAC or received radar, we should end any + * CAC running on the same channels. + * the check !cfg80211_chandef_dfs_usable contain 2 options: + * either all channels are available - those the CAC_FINISHED + * event has effected another wdev state, or there is a channel + * in unavailable state in wdev chandef - those the RADAR_DETECTED + * event has effected another wdev state. + * In both cases we should end the CAC on the wdev. 
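Working through the ranges in cfg80211_get_unii() above: cfg80211_get_unii(5180) returns 0 (U-NII-1), cfg80211_get_unii(5500) returns 3 (U-NII-2C), cfg80211_get_unii(5955) returns 5 (U-NII-5 in the 6 GHz band), and cfg80211_get_unii(2412) returns -EINVAL since 2.4 GHz channels are not part of any U-NII band.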
+ */ + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + struct cfg80211_chan_def *chandef; + + if (!wdev->cac_started) + continue; + + /* FIXME: radar detection is tied to link 0 for now */ + chandef = wdev_chandef(wdev, 0); + if (!chandef) + continue; + + if (!cfg80211_chandef_dfs_usable(&rdev->wiphy, chandef)) + rdev_end_cac(rdev, wdev->netdev); + } +} + +void regulatory_propagate_dfs_state(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + enum nl80211_dfs_state dfs_state, + enum nl80211_radar_event event) +{ + struct cfg80211_registered_device *rdev; + + ASSERT_RTNL(); + + if (WARN_ON(!cfg80211_chandef_valid(chandef))) + return; + + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + if (wiphy == &rdev->wiphy) + continue; + + if (!reg_dfs_domain_same(wiphy, &rdev->wiphy)) + continue; + + if (!ieee80211_get_channel(&rdev->wiphy, + chandef->chan->center_freq)) + continue; + + cfg80211_set_dfs_state(&rdev->wiphy, chandef, dfs_state); + + if (event == NL80211_RADAR_DETECTED || + event == NL80211_RADAR_CAC_FINISHED) { + cfg80211_sched_dfs_chan_update(rdev); + cfg80211_check_and_end_cac(rdev); + } + + nl80211_radar_notify(rdev, chandef, event, NULL, GFP_KERNEL); + } +} + +static int __init regulatory_init_db(void) +{ + int err; + + /* + * It's possible that - due to other bugs/issues - cfg80211 + * never called regulatory_init() below, or that it failed; + * in that case, don't try to do any further work here as + * it's doomed to lead to crashes. + */ + if (IS_ERR_OR_NULL(reg_pdev)) + return -EINVAL; + + err = load_builtin_regdb_keys(); + if (err) { + platform_device_unregister(reg_pdev); + return err; + } + + /* We always try to get an update for the static regdomain */ + err = regulatory_hint_core(cfg80211_world_regdom->alpha2); + if (err) { + if (err == -ENOMEM) { + platform_device_unregister(reg_pdev); + return err; + } + /* + * N.B. kobject_uevent_env() can fail mainly for when we're out + * memory which is handled and propagated appropriately above + * but it can also fail during a netlink_broadcast() or during + * early boot for call_usermodehelper(). For now treat these + * errors as non-fatal. + */ + pr_err("kobject_uevent_env() was unable to call CRDA during init\n"); + } + + /* + * Finally, if the user set the module parameter treat it + * as a user hint. 
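ieee80211_regdom here is the module parameter declared earlier in reg.c, so this fallback user hint can normally be driven from the kernel command line (for a built-in cfg80211, something like cfg80211.ieee80211_regdom=DE) or as a module option; the exact spelling depends on how cfg80211 is built, so treat the example as illustrative.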
+ */ + if (!is_world_regdom(ieee80211_regdom)) + regulatory_hint_user(ieee80211_regdom, + NL80211_USER_REG_HINT_USER); + + return 0; +} +#ifndef MODULE +late_initcall(regulatory_init_db); +#endif + +int __init regulatory_init(void) +{ + reg_pdev = platform_device_register_simple("regulatory", 0, NULL, 0); + if (IS_ERR(reg_pdev)) + return PTR_ERR(reg_pdev); + + rcu_assign_pointer(cfg80211_regdomain, cfg80211_world_regdom); + + user_alpha2[0] = '9'; + user_alpha2[1] = '7'; + +#ifdef MODULE + return regulatory_init_db(); +#else + return 0; +#endif +} + +void regulatory_exit(void) +{ + struct regulatory_request *reg_request, *tmp; + struct reg_beacon *reg_beacon, *btmp; + + cancel_work_sync(®_work); + cancel_crda_timeout_sync(); + cancel_delayed_work_sync(®_check_chans); + + /* Lock to suppress warnings */ + rtnl_lock(); + reset_regdomains(true, NULL); + rtnl_unlock(); + + dev_set_uevent_suppress(®_pdev->dev, true); + + platform_device_unregister(reg_pdev); + + list_for_each_entry_safe(reg_beacon, btmp, ®_pending_beacons, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + + list_for_each_entry_safe(reg_beacon, btmp, ®_beacon_list, list) { + list_del(®_beacon->list); + kfree(reg_beacon); + } + + list_for_each_entry_safe(reg_request, tmp, ®_requests_list, list) { + list_del(®_request->list); + kfree(reg_request); + } + + if (!IS_ERR_OR_NULL(regdb)) + kfree(regdb); + if (!IS_ERR_OR_NULL(cfg80211_user_regdom)) + kfree(cfg80211_user_regdom); + + free_regdb_keyring(); +} diff --git a/net/wireless/reg.h b/net/wireless/reg.h new file mode 100644 index 000000000..f3707f729 --- /dev/null +++ b/net/wireless/reg.h @@ -0,0 +1,189 @@ +#ifndef __NET_WIRELESS_REG_H +#define __NET_WIRELESS_REG_H + +#include + +/* + * Copyright 2008-2011 Luis R. Rodriguez + * Copyright (C) 2019 Intel Corporation + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + */ + +enum ieee80211_regd_source { + REGD_SOURCE_INTERNAL_DB, + REGD_SOURCE_CRDA, + REGD_SOURCE_CACHED, +}; + +extern const struct ieee80211_regdomain __rcu *cfg80211_regdomain; + +bool reg_is_valid_request(const char *alpha2); +bool is_world_regdom(const char *alpha2); +bool reg_supported_dfs_region(enum nl80211_dfs_regions dfs_region); +enum nl80211_dfs_regions reg_get_dfs_region(struct wiphy *wiphy); + +int regulatory_hint_user(const char *alpha2, + enum nl80211_user_reg_hint_type user_reg_hint_type); + +/** + * regulatory_hint_indoor - hint operation in indoor env. or not + * @is_indoor: if true indicates that user space thinks that the + * device is operating in an indoor environment. + * @portid: the netlink port ID on which the hint was given. 
+ */ +int regulatory_hint_indoor(bool is_indoor, u32 portid); + +/** + * regulatory_netlink_notify - notify on released netlink socket + * @portid: the netlink socket port ID + */ +void regulatory_netlink_notify(u32 portid); + +void wiphy_regulatory_register(struct wiphy *wiphy); +void wiphy_regulatory_deregister(struct wiphy *wiphy); + +int __init regulatory_init(void); +void regulatory_exit(void); + +int set_regdom(const struct ieee80211_regdomain *rd, + enum ieee80211_regd_source regd_src); + +unsigned int reg_get_max_bandwidth(const struct ieee80211_regdomain *rd, + const struct ieee80211_reg_rule *rule); + +bool reg_last_request_cell_base(void); + +/** + * regulatory_hint_found_beacon - hints a beacon was found on a channel + * @wiphy: the wireless device where the beacon was found on + * @beacon_chan: the channel on which the beacon was found on + * @gfp: context flags + * + * This informs the wireless core that a beacon from an AP was found on + * the channel provided. This allows the wireless core to make educated + * guesses on regulatory to help with world roaming. This is only used for + * world roaming -- when we do not know our current location. This is + * only useful on channels 12, 13 and 14 on the 2 GHz band as channels + * 1-11 are already enabled by the world regulatory domain; and on + * non-radar 5 GHz channels. + * + * Drivers do not need to call this, cfg80211 will do it for after a scan + * on a newly found BSS. If you cannot make use of this feature you can + * set the wiphy->disable_beacon_hints to true. + */ +int regulatory_hint_found_beacon(struct wiphy *wiphy, + struct ieee80211_channel *beacon_chan, + gfp_t gfp); + +/** + * regulatory_hint_country_ie - hints a country IE as a regulatory domain + * @wiphy: the wireless device giving the hint (used only for reporting + * conflicts) + * @band: the band on which the country IE was received on. This determines + * the band we'll process the country IE channel triplets for. + * @country_ie: pointer to the country IE + * @country_ie_len: length of the country IE + * + * We will intersect the rd with the what CRDA tells us should apply + * for the alpha2 this country IE belongs to, this prevents APs from + * sending us incorrect or outdated information against a country. + * + * The AP is expected to provide Country IE channel triplets for the + * band it is on. It is technically possible for APs to send channel + * country IE triplets even for channels outside of the band they are + * in but for that they would have to use the regulatory extension + * in combination with a triplet but this behaviour is currently + * not observed. For this reason if a triplet is seen with channel + * information for a band the BSS is not present in it will be ignored. + */ +void regulatory_hint_country_ie(struct wiphy *wiphy, + enum nl80211_band band, + const u8 *country_ie, + u8 country_ie_len); + +/** + * regulatory_hint_disconnect - informs all devices have been disconnected + * + * Regulotory rules can be enhanced further upon scanning and upon + * connection to an AP. These rules become stale if we disconnect + * and go to another country, whether or not we suspend and resume. + * If we suspend, go to another country and resume we'll automatically + * get disconnected shortly after resuming and things will be reset as well. + * This routine is a helper to restore regulatory settings to how they were + * prior to our first connect attempt. This includes ignoring country IE and + * beacon regulatory hints. 
The ieee80211_regdom module parameter will always + * be respected but if a user had set the regulatory domain that will take + * precedence. + * + * Must be called from process context. + */ +void regulatory_hint_disconnect(void); + +/** + * cfg80211_get_unii - get the U-NII band for the frequency + * @freq: the frequency for which we want to get the UNII band. + + * Get a value specifying the U-NII band frequency belongs to. + * U-NII bands are defined by the FCC in C.F.R 47 part 15. + * + * Returns -EINVAL if freq is invalid, 0 for UNII-1, 1 for UNII-2A, + * 2 for UNII-2B, 3 for UNII-2C and 4 for UNII-3. + */ +int cfg80211_get_unii(int freq); + +/** + * regulatory_indoor_allowed - is indoor operation allowed + */ +bool regulatory_indoor_allowed(void); + +/* + * Grace period to timeout pre-CAC results on the dfs channels. This timeout + * value is used for Non-ETSI domain. + * TODO: May be make this timeout available through regdb? + */ +#define REG_PRE_CAC_EXPIRY_GRACE_MS 2000 + +/** + * regulatory_propagate_dfs_state - Propagate DFS channel state to other wiphys + * @wiphy - wiphy on which radar is detected and the event will be propagated + * to other available wiphys having the same DFS domain + * @chandef - Channel definition of radar detected channel + * @dfs_state - DFS channel state to be set + * @event - Type of radar event which triggered this DFS state change + * + * This function should be called with rtnl lock held. + */ +void regulatory_propagate_dfs_state(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + enum nl80211_dfs_state dfs_state, + enum nl80211_radar_event event); + +/** + * reg_dfs_domain_same - Checks if both wiphy have same DFS domain configured + * @wiphy1 - wiphy it's dfs_region to be checked against that of wiphy2 + * @wiphy2 - wiphy it's dfs_region to be checked against that of wiphy1 + */ +bool reg_dfs_domain_same(struct wiphy *wiphy1, struct wiphy *wiphy2); + +/** + * reg_reload_regdb - reload the regulatory.db firmware file + */ +int reg_reload_regdb(void); + +extern const u8 shipped_regdb_certs[]; +extern unsigned int shipped_regdb_certs_len; +extern const u8 extra_regdb_certs[]; +extern unsigned int extra_regdb_certs_len; + +#endif /* __NET_WIRELESS_REG_H */ diff --git a/net/wireless/scan.c b/net/wireless/scan.c new file mode 100644 index 000000000..b7e1631b3 --- /dev/null +++ b/net/wireless/scan.c @@ -0,0 +1,3177 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * cfg80211 scan result handling + * + * Copyright 2008 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2016 Intel Deutschland GmbH + * Copyright (C) 2018-2023 Intel Corporation + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "nl80211.h" +#include "wext-compat.h" +#include "rdev-ops.h" + +/** + * DOC: BSS tree/list structure + * + * At the top level, the BSS list is kept in both a list in each + * registered device (@bss_list) as well as an RB-tree for faster + * lookup. In the RB-tree, entries can be looked up using their + * channel, MESHID, MESHCONF (for MBSSes) or channel, BSSID, SSID + * for other BSSes. + * + * Due to the possibility of hidden SSIDs, there's a second level + * structure, the "hidden_list" and "hidden_beacon_bss" pointer. + * The hidden_list connects all BSSes belonging to a single AP + * that has a hidden SSID, and connects beacon and probe response + * entries. 
For a probe response entry for a hidden SSID, the
+ * hidden_beacon_bss pointer points to the BSS struct holding the
+ * beacon's information.
+ *
+ * Reference counting is done for all these references except for
+ * the hidden_list, so that a beacon BSS struct that is otherwise
+ * not referenced has one reference for being on the bss_list and
+ * one for each probe response entry that points to it using the
+ * hidden_beacon_bss pointer. When a BSS struct that has such a
+ * pointer is get/put, the refcount update is also propagated to
+ * the referenced struct; this ensures that it cannot get removed
+ * while somebody is using the probe response version.
+ *
+ * Note that the hidden_beacon_bss pointer never changes, due to
+ * the reference counting. Therefore, no locking is needed for
+ * it.
+ *
+ * Also note that the hidden_beacon_bss pointer is only relevant
+ * if the driver uses something other than the IEs, e.g. private
+ * data stored in the BSS struct, since the beacon IEs are
+ * also linked into the probe response struct.
+ */
+
+/*
+ * Limit the number of BSS entries stored in mac80211. Each one is
+ * a bit over 4k at most, so this limits to roughly 4-5M of memory.
+ * If somebody wants to really attack this though, they'd likely
+ * use small beacons, and only one type of frame, limiting each of
+ * the entries to a much smaller size (in order to generate more
+ * entries in total, so overhead is bigger.)
+ */
+static int bss_entries_limit = 1000;
+module_param(bss_entries_limit, int, 0644);
+MODULE_PARM_DESC(bss_entries_limit,
+ "limit to number of scan BSS entries (per wiphy, default 1000)");
+
+#define IEEE80211_SCAN_RESULT_EXPIRE (30 * HZ)
+
+/**
+ * struct cfg80211_colocated_ap - colocated AP information
+ *
+ * @list: linked list to all colocated APs
+ * @bssid: BSSID of the reported AP
+ * @ssid: SSID of the reported AP
+ * @ssid_len: length of the SSID
+ * @center_freq: frequency the reported AP is on
+ * @unsolicited_probe: the reported AP is part of an ESS, where all the APs
+ * that operate in the same channel as the reported AP and that might be
+ * detected by a STA receiving this frame, are transmitting unsolicited
+ * Probe Response frames every 20 TUs
+ * @oct_recommended: OCT is recommended to exchange MMPDUs with the reported AP
+ * @same_ssid: the reported AP has the same SSID as the reporting AP
+ * @multi_bss: the reported AP is part of a multiple BSSID set
+ * @transmitted_bssid: the reported AP is the transmitting BSSID
+ * @colocated_ess: all the APs that share the same ESS as the reported AP are
+ * colocated and can be discovered via legacy bands.
+ * @short_ssid_valid: short_ssid is valid and can be used + * @short_ssid: the short SSID for this SSID + */ +struct cfg80211_colocated_ap { + struct list_head list; + u8 bssid[ETH_ALEN]; + u8 ssid[IEEE80211_MAX_SSID_LEN]; + size_t ssid_len; + u32 short_ssid; + u32 center_freq; + u8 unsolicited_probe:1, + oct_recommended:1, + same_ssid:1, + multi_bss:1, + transmitted_bssid:1, + colocated_ess:1, + short_ssid_valid:1; +}; + +static void bss_free(struct cfg80211_internal_bss *bss) +{ + struct cfg80211_bss_ies *ies; + + if (WARN_ON(atomic_read(&bss->hold))) + return; + + ies = (void *)rcu_access_pointer(bss->pub.beacon_ies); + if (ies && !bss->pub.hidden_beacon_bss) + kfree_rcu(ies, rcu_head); + ies = (void *)rcu_access_pointer(bss->pub.proberesp_ies); + if (ies) + kfree_rcu(ies, rcu_head); + + /* + * This happens when the module is removed, it doesn't + * really matter any more save for completeness + */ + if (!list_empty(&bss->hidden_list)) + list_del(&bss->hidden_list); + + kfree(bss); +} + +static inline void bss_ref_get(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *bss) +{ + lockdep_assert_held(&rdev->bss_lock); + + bss->refcount++; + + if (bss->pub.hidden_beacon_bss) + bss_from_pub(bss->pub.hidden_beacon_bss)->refcount++; + + if (bss->pub.transmitted_bss) + bss_from_pub(bss->pub.transmitted_bss)->refcount++; +} + +static inline void bss_ref_put(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *bss) +{ + lockdep_assert_held(&rdev->bss_lock); + + if (bss->pub.hidden_beacon_bss) { + struct cfg80211_internal_bss *hbss; + hbss = container_of(bss->pub.hidden_beacon_bss, + struct cfg80211_internal_bss, + pub); + hbss->refcount--; + if (hbss->refcount == 0) + bss_free(hbss); + } + + if (bss->pub.transmitted_bss) { + struct cfg80211_internal_bss *tbss; + + tbss = container_of(bss->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + tbss->refcount--; + if (tbss->refcount == 0) + bss_free(tbss); + } + + bss->refcount--; + if (bss->refcount == 0) + bss_free(bss); +} + +static bool __cfg80211_unlink_bss(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *bss) +{ + lockdep_assert_held(&rdev->bss_lock); + + if (!list_empty(&bss->hidden_list)) { + /* + * don't remove the beacon entry if it has + * probe responses associated with it + */ + if (!bss->pub.hidden_beacon_bss) + return false; + /* + * if it's a probe response entry break its + * link to the other entries in the group + */ + list_del_init(&bss->hidden_list); + } + + list_del_init(&bss->list); + list_del_init(&bss->pub.nontrans_list); + rb_erase(&bss->rbn, &rdev->bss_tree); + rdev->bss_entries--; + WARN_ONCE((rdev->bss_entries == 0) ^ list_empty(&rdev->bss_list), + "rdev bss entries[%d]/list[empty:%d] corruption\n", + rdev->bss_entries, list_empty(&rdev->bss_list)); + bss_ref_put(rdev, bss); + return true; +} + +bool cfg80211_is_element_inherited(const struct element *elem, + const struct element *non_inherit_elem) +{ + u8 id_len, ext_id_len, i, loop_len, id; + const u8 *list; + + if (elem->id == WLAN_EID_MULTIPLE_BSSID) + return false; + + if (!non_inherit_elem || non_inherit_elem->datalen < 2) + return true; + + /* + * non inheritance element format is: + * ext ID (56) | IDs list len | list | extension IDs list len | list + * Both lists are optional. Both lengths are mandatory. 
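+ * For example (values chosen purely for illustration), data of
+ * { 56, 2, 61, 11, 0 } lists element IDs 61 and 11 as not inherited
+ * and carries an empty extension-ID list, giving datalen = 5.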
+ * This means valid length is: + * elem_len = 1 (extension ID) + 2 (list len fields) + list lengths + */ + id_len = non_inherit_elem->data[1]; + if (non_inherit_elem->datalen < 3 + id_len) + return true; + + ext_id_len = non_inherit_elem->data[2 + id_len]; + if (non_inherit_elem->datalen < 3 + id_len + ext_id_len) + return true; + + if (elem->id == WLAN_EID_EXTENSION) { + if (!ext_id_len) + return true; + loop_len = ext_id_len; + list = &non_inherit_elem->data[3 + id_len]; + id = elem->data[0]; + } else { + if (!id_len) + return true; + loop_len = id_len; + list = &non_inherit_elem->data[2]; + id = elem->id; + } + + for (i = 0; i < loop_len; i++) { + if (list[i] == id) + return false; + } + + return true; +} +EXPORT_SYMBOL(cfg80211_is_element_inherited); + +static size_t cfg80211_copy_elem_with_frags(const struct element *elem, + const u8 *ie, size_t ie_len, + u8 **pos, u8 *buf, size_t buf_len) +{ + if (WARN_ON((u8 *)elem < ie || elem->data > ie + ie_len || + elem->data + elem->datalen > ie + ie_len)) + return 0; + + if (elem->datalen + 2 > buf + buf_len - *pos) + return 0; + + memcpy(*pos, elem, elem->datalen + 2); + *pos += elem->datalen + 2; + + /* Finish if it is not fragmented */ + if (elem->datalen != 255) + return *pos - buf; + + ie_len = ie + ie_len - elem->data - elem->datalen; + ie = (const u8 *)elem->data + elem->datalen; + + for_each_element(elem, ie, ie_len) { + if (elem->id != WLAN_EID_FRAGMENT) + break; + + if (elem->datalen + 2 > buf + buf_len - *pos) + return 0; + + memcpy(*pos, elem, elem->datalen + 2); + *pos += elem->datalen + 2; + + if (elem->datalen != 255) + break; + } + + return *pos - buf; +} + +static size_t cfg80211_gen_new_ie(const u8 *ie, size_t ielen, + const u8 *subie, size_t subie_len, + u8 *new_ie, size_t new_ie_len) +{ + const struct element *non_inherit_elem, *parent, *sub; + u8 *pos = new_ie; + u8 id, ext_id; + unsigned int match_len; + + non_inherit_elem = cfg80211_find_ext_elem(WLAN_EID_EXT_NON_INHERITANCE, + subie, subie_len); + + /* We copy the elements one by one from the parent to the generated + * elements. + * If they are not inherited (included in subie or in the non + * inheritance element), then we copy all occurrences the first time + * we see this element type. 
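+ * Elements that do appear in subie replace the corresponding parent
+ * elements, while parent elements named in the Non-Inheritance element
+ * (and absent from subie) are left out of the generated profile
+ * altogether.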
+ */ + for_each_element(parent, ie, ielen) { + if (parent->id == WLAN_EID_FRAGMENT) + continue; + + if (parent->id == WLAN_EID_EXTENSION) { + if (parent->datalen < 1) + continue; + + id = WLAN_EID_EXTENSION; + ext_id = parent->data[0]; + match_len = 1; + } else { + id = parent->id; + match_len = 0; + } + + /* Find first occurrence in subie */ + sub = cfg80211_find_elem_match(id, subie, subie_len, + &ext_id, match_len, 0); + + /* Copy from parent if not in subie and inherited */ + if (!sub && + cfg80211_is_element_inherited(parent, non_inherit_elem)) { + if (!cfg80211_copy_elem_with_frags(parent, + ie, ielen, + &pos, new_ie, + new_ie_len)) + return 0; + + continue; + } + + /* Already copied if an earlier element had the same type */ + if (cfg80211_find_elem_match(id, ie, (u8 *)parent - ie, + &ext_id, match_len, 0)) + continue; + + /* Not inheriting, copy all similar elements from subie */ + while (sub) { + if (!cfg80211_copy_elem_with_frags(sub, + subie, subie_len, + &pos, new_ie, + new_ie_len)) + return 0; + + sub = cfg80211_find_elem_match(id, + sub->data + sub->datalen, + subie_len + subie - + (sub->data + + sub->datalen), + &ext_id, match_len, 0); + } + } + + /* The above misses elements that are included in subie but not in the + * parent, so do a pass over subie and append those. + * Skip the non-tx BSSID caps and non-inheritance element. + */ + for_each_element(sub, subie, subie_len) { + if (sub->id == WLAN_EID_NON_TX_BSSID_CAP) + continue; + + if (sub->id == WLAN_EID_FRAGMENT) + continue; + + if (sub->id == WLAN_EID_EXTENSION) { + if (sub->datalen < 1) + continue; + + id = WLAN_EID_EXTENSION; + ext_id = sub->data[0]; + match_len = 1; + + if (ext_id == WLAN_EID_EXT_NON_INHERITANCE) + continue; + } else { + id = sub->id; + match_len = 0; + } + + /* Processed if one was included in the parent */ + if (cfg80211_find_elem_match(id, ie, ielen, + &ext_id, match_len, 0)) + continue; + + if (!cfg80211_copy_elem_with_frags(sub, subie, subie_len, + &pos, new_ie, new_ie_len)) + return 0; + } + + return pos - new_ie; +} + +static bool is_bss(struct cfg80211_bss *a, const u8 *bssid, + const u8 *ssid, size_t ssid_len) +{ + const struct cfg80211_bss_ies *ies; + const struct element *ssid_elem; + + if (bssid && !ether_addr_equal(a->bssid, bssid)) + return false; + + if (!ssid) + return true; + + ies = rcu_access_pointer(a->ies); + if (!ies) + return false; + ssid_elem = cfg80211_find_elem(WLAN_EID_SSID, ies->data, ies->len); + if (!ssid_elem) + return false; + if (ssid_elem->datalen != ssid_len) + return false; + return memcmp(ssid_elem->data, ssid, ssid_len) == 0; +} + +static int +cfg80211_add_nontrans_list(struct cfg80211_bss *trans_bss, + struct cfg80211_bss *nontrans_bss) +{ + const struct element *ssid_elem; + struct cfg80211_bss *bss = NULL; + + rcu_read_lock(); + ssid_elem = ieee80211_bss_get_elem(nontrans_bss, WLAN_EID_SSID); + if (!ssid_elem) { + rcu_read_unlock(); + return -EINVAL; + } + + /* check if nontrans_bss is in the list */ + list_for_each_entry(bss, &trans_bss->nontrans_list, nontrans_list) { + if (is_bss(bss, nontrans_bss->bssid, ssid_elem->data, + ssid_elem->datalen)) { + rcu_read_unlock(); + return 0; + } + } + + rcu_read_unlock(); + + /* + * This is a bit weird - it's not on the list, but already on another + * one! The only way that could happen is if there's some BSSID/SSID + * shared by multiple APs in their multi-BSSID profiles, potentially + * with hidden SSID mixed in ... ignore it. 
+ */ + if (!list_empty(&nontrans_bss->nontrans_list)) + return -EINVAL; + + /* add to the list */ + list_add_tail(&nontrans_bss->nontrans_list, &trans_bss->nontrans_list); + return 0; +} + +static void __cfg80211_bss_expire(struct cfg80211_registered_device *rdev, + unsigned long expire_time) +{ + struct cfg80211_internal_bss *bss, *tmp; + bool expired = false; + + lockdep_assert_held(&rdev->bss_lock); + + list_for_each_entry_safe(bss, tmp, &rdev->bss_list, list) { + if (atomic_read(&bss->hold)) + continue; + if (!time_after(expire_time, bss->ts)) + continue; + + if (__cfg80211_unlink_bss(rdev, bss)) + expired = true; + } + + if (expired) + rdev->bss_generation++; +} + +static bool cfg80211_bss_expire_oldest(struct cfg80211_registered_device *rdev) +{ + struct cfg80211_internal_bss *bss, *oldest = NULL; + bool ret; + + lockdep_assert_held(&rdev->bss_lock); + + list_for_each_entry(bss, &rdev->bss_list, list) { + if (atomic_read(&bss->hold)) + continue; + + if (!list_empty(&bss->hidden_list) && + !bss->pub.hidden_beacon_bss) + continue; + + if (oldest && time_before(oldest->ts, bss->ts)) + continue; + oldest = bss; + } + + if (WARN_ON(!oldest)) + return false; + + /* + * The callers make sure to increase rdev->bss_generation if anything + * gets removed (and a new entry added), so there's no need to also do + * it here. + */ + + ret = __cfg80211_unlink_bss(rdev, oldest); + WARN_ON(!ret); + return ret; +} + +static u8 cfg80211_parse_bss_param(u8 data, + struct cfg80211_colocated_ap *coloc_ap) +{ + coloc_ap->oct_recommended = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_OCT_RECOMMENDED); + coloc_ap->same_ssid = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_SAME_SSID); + coloc_ap->multi_bss = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_MULTI_BSSID); + coloc_ap->transmitted_bssid = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_TRANSMITTED_BSSID); + coloc_ap->unsolicited_probe = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_PROBE_ACTIVE); + coloc_ap->colocated_ess = + u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_COLOC_ESS); + + return u8_get_bits(data, IEEE80211_RNR_TBTT_PARAMS_COLOC_AP); +} + +static int cfg80211_calc_short_ssid(const struct cfg80211_bss_ies *ies, + const struct element **elem, u32 *s_ssid) +{ + + *elem = cfg80211_find_elem(WLAN_EID_SSID, ies->data, ies->len); + if (!*elem || (*elem)->datalen > IEEE80211_MAX_SSID_LEN) + return -EINVAL; + + *s_ssid = ~crc32_le(~0, (*elem)->data, (*elem)->datalen); + return 0; +} + +static void cfg80211_free_coloc_ap_list(struct list_head *coloc_ap_list) +{ + struct cfg80211_colocated_ap *ap, *tmp_ap; + + list_for_each_entry_safe(ap, tmp_ap, coloc_ap_list, list) { + list_del(&ap->list); + kfree(ap); + } +} + +static int cfg80211_parse_ap_info(struct cfg80211_colocated_ap *entry, + const u8 *pos, u8 length, + const struct element *ssid_elem, + int s_ssid_tmp) +{ + /* skip the TBTT offset */ + pos++; + + /* ignore entries with invalid BSSID */ + if (!is_valid_ether_addr(pos)) + return -EINVAL; + + memcpy(entry->bssid, pos, ETH_ALEN); + pos += ETH_ALEN; + + if (length >= IEEE80211_TBTT_INFO_OFFSET_BSSID_SSSID_BSS_PARAM) { + memcpy(&entry->short_ssid, pos, + sizeof(entry->short_ssid)); + entry->short_ssid_valid = true; + pos += 4; + } + + /* skip non colocated APs */ + if (!cfg80211_parse_bss_param(*pos, entry)) + return -EINVAL; + pos++; + + if (length == IEEE80211_TBTT_INFO_OFFSET_BSSID_BSS_PARAM) { + /* + * no information about the short ssid. Consider the entry valid + * for now. 
It would later be dropped in case there are explicit + * SSIDs that need to be matched + */ + if (!entry->same_ssid) + return 0; + } + + if (entry->same_ssid) { + entry->short_ssid = s_ssid_tmp; + entry->short_ssid_valid = true; + + /* + * This is safe because we validate datalen in + * cfg80211_parse_colocated_ap(), before calling this + * function. + */ + memcpy(&entry->ssid, &ssid_elem->data, + ssid_elem->datalen); + entry->ssid_len = ssid_elem->datalen; + } + return 0; +} + +static int cfg80211_parse_colocated_ap(const struct cfg80211_bss_ies *ies, + struct list_head *list) +{ + struct ieee80211_neighbor_ap_info *ap_info; + const struct element *elem, *ssid_elem; + const u8 *pos, *end; + u32 s_ssid_tmp; + int n_coloc = 0, ret; + LIST_HEAD(ap_list); + + elem = cfg80211_find_elem(WLAN_EID_REDUCED_NEIGHBOR_REPORT, ies->data, + ies->len); + if (!elem) + return 0; + + pos = elem->data; + end = pos + elem->datalen; + + ret = cfg80211_calc_short_ssid(ies, &ssid_elem, &s_ssid_tmp); + if (ret) + return 0; + + /* RNR IE may contain more than one NEIGHBOR_AP_INFO */ + while (pos + sizeof(*ap_info) <= end) { + enum nl80211_band band; + int freq; + u8 length, i, count; + + ap_info = (void *)pos; + count = u8_get_bits(ap_info->tbtt_info_hdr, + IEEE80211_AP_INFO_TBTT_HDR_COUNT) + 1; + length = ap_info->tbtt_info_len; + + pos += sizeof(*ap_info); + + if (!ieee80211_operating_class_to_band(ap_info->op_class, + &band)) + break; + + freq = ieee80211_channel_to_frequency(ap_info->channel, band); + + if (end - pos < count * length) + break; + + /* + * TBTT info must include bss param + BSSID + + * (short SSID or same_ssid bit to be set). + * ignore other options, and move to the + * next AP info + */ + if (band != NL80211_BAND_6GHZ || + (length != IEEE80211_TBTT_INFO_OFFSET_BSSID_BSS_PARAM && + length < IEEE80211_TBTT_INFO_OFFSET_BSSID_SSSID_BSS_PARAM)) { + pos += count * length; + continue; + } + + for (i = 0; i < count; i++) { + struct cfg80211_colocated_ap *entry; + + entry = kzalloc(sizeof(*entry) + IEEE80211_MAX_SSID_LEN, + GFP_ATOMIC); + + if (!entry) + break; + + entry->center_freq = freq; + + if (!cfg80211_parse_ap_info(entry, pos, length, + ssid_elem, s_ssid_tmp)) { + n_coloc++; + list_add_tail(&entry->list, &ap_list); + } else { + kfree(entry); + } + + pos += length; + } + } + + if (pos != end) { + cfg80211_free_coloc_ap_list(&ap_list); + return 0; + } + + list_splice_tail(&ap_list, list); + return n_coloc; +} + +static void cfg80211_scan_req_add_chan(struct cfg80211_scan_request *request, + struct ieee80211_channel *chan, + bool add_to_6ghz) +{ + int i; + u32 n_channels = request->n_channels; + struct cfg80211_scan_6ghz_params *params = + &request->scan_6ghz_params[request->n_6ghz_params]; + + for (i = 0; i < n_channels; i++) { + if (request->channels[i] == chan) { + if (add_to_6ghz) + params->channel_idx = i; + return; + } + } + + request->channels[n_channels] = chan; + if (add_to_6ghz) + request->scan_6ghz_params[request->n_6ghz_params].channel_idx = + n_channels; + + request->n_channels++; +} + +static bool cfg80211_find_ssid_match(struct cfg80211_colocated_ap *ap, + struct cfg80211_scan_request *request) +{ + int i; + u32 s_ssid; + + for (i = 0; i < request->n_ssids; i++) { + /* wildcard ssid in the scan request */ + if (!request->ssids[i].ssid_len) { + if (ap->multi_bss && !ap->transmitted_bssid) + continue; + + return true; + } + + if (ap->ssid_len && + ap->ssid_len == request->ssids[i].ssid_len) { + if (!memcmp(request->ssids[i].ssid, ap->ssid, + ap->ssid_len)) + return true; + } else if 
(ap->short_ssid_valid) { + s_ssid = ~crc32_le(~0, request->ssids[i].ssid, + request->ssids[i].ssid_len); + + if (ap->short_ssid == s_ssid) + return true; + } + } + + return false; +} + +static int cfg80211_scan_6ghz(struct cfg80211_registered_device *rdev) +{ + u8 i; + struct cfg80211_colocated_ap *ap; + int n_channels, count = 0, err; + struct cfg80211_scan_request *request, *rdev_req = rdev->scan_req; + LIST_HEAD(coloc_ap_list); + bool need_scan_psc = true; + const struct ieee80211_sband_iftype_data *iftd; + + rdev_req->scan_6ghz = true; + + if (!rdev->wiphy.bands[NL80211_BAND_6GHZ]) + return -EOPNOTSUPP; + + iftd = ieee80211_get_sband_iftype_data(rdev->wiphy.bands[NL80211_BAND_6GHZ], + rdev_req->wdev->iftype); + if (!iftd || !iftd->he_cap.has_he) + return -EOPNOTSUPP; + + n_channels = rdev->wiphy.bands[NL80211_BAND_6GHZ]->n_channels; + + if (rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ) { + struct cfg80211_internal_bss *intbss; + + spin_lock_bh(&rdev->bss_lock); + list_for_each_entry(intbss, &rdev->bss_list, list) { + struct cfg80211_bss *res = &intbss->pub; + const struct cfg80211_bss_ies *ies; + + ies = rcu_access_pointer(res->ies); + count += cfg80211_parse_colocated_ap(ies, + &coloc_ap_list); + } + spin_unlock_bh(&rdev->bss_lock); + } + + request = kzalloc(struct_size(request, channels, n_channels) + + sizeof(*request->scan_6ghz_params) * count + + sizeof(*request->ssids) * rdev_req->n_ssids, + GFP_KERNEL); + if (!request) { + cfg80211_free_coloc_ap_list(&coloc_ap_list); + return -ENOMEM; + } + + *request = *rdev_req; + request->n_channels = 0; + request->scan_6ghz_params = + (void *)&request->channels[n_channels]; + + /* + * PSC channels should not be scanned in case of direct scan with 1 SSID + * and at least one of the reported co-located APs with same SSID + * indicating that all APs in the same ESS are co-located + */ + if (count && request->n_ssids == 1 && request->ssids[0].ssid_len) { + list_for_each_entry(ap, &coloc_ap_list, list) { + if (ap->colocated_ess && + cfg80211_find_ssid_match(ap, request)) { + need_scan_psc = false; + break; + } + } + } + + /* + * add to the scan request the channels that need to be scanned + * regardless of the collocated APs (PSC channels or all channels + * in case that NL80211_SCAN_FLAG_COLOCATED_6GHZ is not set) + */ + for (i = 0; i < rdev_req->n_channels; i++) { + if (rdev_req->channels[i]->band == NL80211_BAND_6GHZ && + ((need_scan_psc && + cfg80211_channel_is_psc(rdev_req->channels[i])) || + !(rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ))) { + cfg80211_scan_req_add_chan(request, + rdev_req->channels[i], + false); + } + } + + if (!(rdev_req->flags & NL80211_SCAN_FLAG_COLOCATED_6GHZ)) + goto skip; + + list_for_each_entry(ap, &coloc_ap_list, list) { + bool found = false; + struct cfg80211_scan_6ghz_params *scan_6ghz_params = + &request->scan_6ghz_params[request->n_6ghz_params]; + struct ieee80211_channel *chan = + ieee80211_get_channel(&rdev->wiphy, ap->center_freq); + + if (!chan || chan->flags & IEEE80211_CHAN_DISABLED) + continue; + + for (i = 0; i < rdev_req->n_channels; i++) { + if (rdev_req->channels[i] == chan) + found = true; + } + + if (!found) + continue; + + if (request->n_ssids > 0 && + !cfg80211_find_ssid_match(ap, request)) + continue; + + if (!is_broadcast_ether_addr(request->bssid) && + !ether_addr_equal(request->bssid, ap->bssid)) + continue; + + if (!request->n_ssids && ap->multi_bss && !ap->transmitted_bssid) + continue; + + cfg80211_scan_req_add_chan(request, chan, true); + memcpy(scan_6ghz_params->bssid, 
ap->bssid, ETH_ALEN); + scan_6ghz_params->short_ssid = ap->short_ssid; + scan_6ghz_params->short_ssid_valid = ap->short_ssid_valid; + scan_6ghz_params->unsolicited_probe = ap->unsolicited_probe; + + /* + * If a PSC channel is added to the scan and 'need_scan_psc' is + * set to false, then all the APs that the scan logic is + * interested with on the channel are collocated and thus there + * is no need to perform the initial PSC channel listen. + */ + if (cfg80211_channel_is_psc(chan) && !need_scan_psc) + scan_6ghz_params->psc_no_listen = true; + + request->n_6ghz_params++; + } + +skip: + cfg80211_free_coloc_ap_list(&coloc_ap_list); + + if (request->n_channels) { + struct cfg80211_scan_request *old = rdev->int_scan_req; + rdev->int_scan_req = request; + + /* + * Add the ssids from the parent scan request to the new scan + * request, so the driver would be able to use them in its + * probe requests to discover hidden APs on PSC channels. + */ + request->ssids = (void *)&request->channels[request->n_channels]; + request->n_ssids = rdev_req->n_ssids; + memcpy(request->ssids, rdev_req->ssids, sizeof(*request->ssids) * + request->n_ssids); + + /* + * If this scan follows a previous scan, save the scan start + * info from the first part of the scan + */ + if (old) + rdev->int_scan_req->info = old->info; + + err = rdev_scan(rdev, request); + if (err) { + rdev->int_scan_req = old; + kfree(request); + } else { + kfree(old); + } + + return err; + } + + kfree(request); + return -EINVAL; +} + +int cfg80211_scan(struct cfg80211_registered_device *rdev) +{ + struct cfg80211_scan_request *request; + struct cfg80211_scan_request *rdev_req = rdev->scan_req; + u32 n_channels = 0, idx, i; + + if (!(rdev->wiphy.flags & WIPHY_FLAG_SPLIT_SCAN_6GHZ)) + return rdev_scan(rdev, rdev_req); + + for (i = 0; i < rdev_req->n_channels; i++) { + if (rdev_req->channels[i]->band != NL80211_BAND_6GHZ) + n_channels++; + } + + if (!n_channels) + return cfg80211_scan_6ghz(rdev); + + request = kzalloc(struct_size(request, channels, n_channels), + GFP_KERNEL); + if (!request) + return -ENOMEM; + + *request = *rdev_req; + request->n_channels = n_channels; + + for (i = idx = 0; i < rdev_req->n_channels; i++) { + if (rdev_req->channels[i]->band != NL80211_BAND_6GHZ) + request->channels[idx++] = rdev_req->channels[i]; + } + + rdev_req->scan_6ghz = false; + rdev->int_scan_req = request; + return rdev_scan(rdev, request); +} + +void ___cfg80211_scan_done(struct cfg80211_registered_device *rdev, + bool send_message) +{ + struct cfg80211_scan_request *request, *rdev_req; + struct wireless_dev *wdev; + struct sk_buff *msg; +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; +#endif + + lockdep_assert_held(&rdev->wiphy.mtx); + + if (rdev->scan_msg) { + nl80211_send_scan_msg(rdev, rdev->scan_msg); + rdev->scan_msg = NULL; + return; + } + + rdev_req = rdev->scan_req; + if (!rdev_req) + return; + + wdev = rdev_req->wdev; + request = rdev->int_scan_req ? rdev->int_scan_req : rdev_req; + + if (wdev_running(wdev) && + (rdev->wiphy.flags & WIPHY_FLAG_SPLIT_SCAN_6GHZ) && + !rdev_req->scan_6ghz && !request->info.aborted && + !cfg80211_scan_6ghz(rdev)) + return; + + /* + * This must be before sending the other events! + * Otherwise, wpa_supplicant gets completely confused with + * wext events. 
+ */ + if (wdev->netdev) + cfg80211_sme_scan_done(wdev->netdev); + + if (!request->info.aborted && + request->flags & NL80211_SCAN_FLAG_FLUSH) { + /* flush entries from previous scans */ + spin_lock_bh(&rdev->bss_lock); + __cfg80211_bss_expire(rdev, request->scan_start); + spin_unlock_bh(&rdev->bss_lock); + } + + msg = nl80211_build_scan_msg(rdev, wdev, request->info.aborted); + +#ifdef CONFIG_CFG80211_WEXT + if (wdev->netdev && !request->info.aborted) { + memset(&wrqu, 0, sizeof(wrqu)); + + wireless_send_event(wdev->netdev, SIOCGIWSCAN, &wrqu, NULL); + } +#endif + + dev_put(wdev->netdev); + + kfree(rdev->int_scan_req); + rdev->int_scan_req = NULL; + + kfree(rdev->scan_req); + rdev->scan_req = NULL; + + if (!send_message) + rdev->scan_msg = msg; + else + nl80211_send_scan_msg(rdev, msg); +} + +void __cfg80211_scan_done(struct work_struct *wk) +{ + struct cfg80211_registered_device *rdev; + + rdev = container_of(wk, struct cfg80211_registered_device, + scan_done_wk); + + wiphy_lock(&rdev->wiphy); + ___cfg80211_scan_done(rdev, true); + wiphy_unlock(&rdev->wiphy); +} + +void cfg80211_scan_done(struct cfg80211_scan_request *request, + struct cfg80211_scan_info *info) +{ + struct cfg80211_scan_info old_info = request->info; + + trace_cfg80211_scan_done(request, info); + WARN_ON(request != wiphy_to_rdev(request->wiphy)->scan_req && + request != wiphy_to_rdev(request->wiphy)->int_scan_req); + + request->info = *info; + + /* + * In case the scan is split, the scan_start_tsf and tsf_bssid should + * be of the first part. In such a case old_info.scan_start_tsf should + * be non zero. + */ + if (request->scan_6ghz && old_info.scan_start_tsf) { + request->info.scan_start_tsf = old_info.scan_start_tsf; + memcpy(request->info.tsf_bssid, old_info.tsf_bssid, + sizeof(request->info.tsf_bssid)); + } + + request->notified = true; + queue_work(cfg80211_wq, &wiphy_to_rdev(request->wiphy)->scan_done_wk); +} +EXPORT_SYMBOL(cfg80211_scan_done); + +void cfg80211_add_sched_scan_req(struct cfg80211_registered_device *rdev, + struct cfg80211_sched_scan_request *req) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + list_add_rcu(&req->list, &rdev->sched_scan_req_list); +} + +static void cfg80211_del_sched_scan_req(struct cfg80211_registered_device *rdev, + struct cfg80211_sched_scan_request *req) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + list_del_rcu(&req->list); + kfree_rcu(req, rcu_head); +} + +static struct cfg80211_sched_scan_request * +cfg80211_find_sched_scan_req(struct cfg80211_registered_device *rdev, u64 reqid) +{ + struct cfg80211_sched_scan_request *pos; + + list_for_each_entry_rcu(pos, &rdev->sched_scan_req_list, list, + lockdep_is_held(&rdev->wiphy.mtx)) { + if (pos->reqid == reqid) + return pos; + } + return NULL; +} + +/* + * Determines if a scheduled scan request can be handled. When a legacy + * scheduled scan is running no other scheduled scan is allowed regardless + * whether the request is for legacy or multi-support scan. When a multi-support + * scheduled scan is running a request for legacy scan is not allowed. In this + * case a request for multi-support scan can be handled if resources are + * available, ie. struct wiphy::max_sched_scan_reqs limit is not yet reached. 
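+ * For example (limit chosen purely for illustration), with
+ * max_sched_scan_reqs set to 3 and two multi-support requests already
+ * running, another multi-support request is accepted, a fourth would be
+ * rejected with -ENOSPC, and any legacy request is rejected with
+ * -EINPROGRESS.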
+ */ +int cfg80211_sched_scan_req_possible(struct cfg80211_registered_device *rdev, + bool want_multi) +{ + struct cfg80211_sched_scan_request *pos; + int i = 0; + + list_for_each_entry(pos, &rdev->sched_scan_req_list, list) { + /* request id zero means legacy in progress */ + if (!i && !pos->reqid) + return -EINPROGRESS; + i++; + } + + if (i) { + /* no legacy allowed when multi request(s) are active */ + if (!want_multi) + return -EINPROGRESS; + + /* resource limit reached */ + if (i == rdev->wiphy.max_sched_scan_reqs) + return -ENOSPC; + } + return 0; +} + +void cfg80211_sched_scan_results_wk(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev; + struct cfg80211_sched_scan_request *req, *tmp; + + rdev = container_of(work, struct cfg80211_registered_device, + sched_scan_res_wk); + + wiphy_lock(&rdev->wiphy); + list_for_each_entry_safe(req, tmp, &rdev->sched_scan_req_list, list) { + if (req->report_results) { + req->report_results = false; + if (req->flags & NL80211_SCAN_FLAG_FLUSH) { + /* flush entries from previous scans */ + spin_lock_bh(&rdev->bss_lock); + __cfg80211_bss_expire(rdev, req->scan_start); + spin_unlock_bh(&rdev->bss_lock); + req->scan_start = jiffies; + } + nl80211_send_sched_scan(req, + NL80211_CMD_SCHED_SCAN_RESULTS); + } + } + wiphy_unlock(&rdev->wiphy); +} + +void cfg80211_sched_scan_results(struct wiphy *wiphy, u64 reqid) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_sched_scan_request *request; + + trace_cfg80211_sched_scan_results(wiphy, reqid); + /* ignore if we're not scanning */ + + rcu_read_lock(); + request = cfg80211_find_sched_scan_req(rdev, reqid); + if (request) { + request->report_results = true; + queue_work(cfg80211_wq, &rdev->sched_scan_res_wk); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(cfg80211_sched_scan_results); + +void cfg80211_sched_scan_stopped_locked(struct wiphy *wiphy, u64 reqid) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + lockdep_assert_held(&wiphy->mtx); + + trace_cfg80211_sched_scan_stopped(wiphy, reqid); + + __cfg80211_stop_sched_scan(rdev, reqid, true); +} +EXPORT_SYMBOL(cfg80211_sched_scan_stopped_locked); + +void cfg80211_sched_scan_stopped(struct wiphy *wiphy, u64 reqid) +{ + wiphy_lock(wiphy); + cfg80211_sched_scan_stopped_locked(wiphy, reqid); + wiphy_unlock(wiphy); +} +EXPORT_SYMBOL(cfg80211_sched_scan_stopped); + +int cfg80211_stop_sched_scan_req(struct cfg80211_registered_device *rdev, + struct cfg80211_sched_scan_request *req, + bool driver_initiated) +{ + lockdep_assert_held(&rdev->wiphy.mtx); + + if (!driver_initiated) { + int err = rdev_sched_scan_stop(rdev, req->dev, req->reqid); + if (err) + return err; + } + + nl80211_send_sched_scan(req, NL80211_CMD_SCHED_SCAN_STOPPED); + + cfg80211_del_sched_scan_req(rdev, req); + + return 0; +} + +int __cfg80211_stop_sched_scan(struct cfg80211_registered_device *rdev, + u64 reqid, bool driver_initiated) +{ + struct cfg80211_sched_scan_request *sched_scan_req; + + lockdep_assert_held(&rdev->wiphy.mtx); + + sched_scan_req = cfg80211_find_sched_scan_req(rdev, reqid); + if (!sched_scan_req) + return -ENOENT; + + return cfg80211_stop_sched_scan_req(rdev, sched_scan_req, + driver_initiated); +} + +void cfg80211_bss_age(struct cfg80211_registered_device *rdev, + unsigned long age_secs) +{ + struct cfg80211_internal_bss *bss; + unsigned long age_jiffies = msecs_to_jiffies(age_secs * MSEC_PER_SEC); + + spin_lock_bh(&rdev->bss_lock); + list_for_each_entry(bss, &rdev->bss_list, list) + bss->ts -= age_jiffies; 
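+ /*
+ * A worked example (HZ value assumed purely for illustration): with
+ * HZ == 1000 and age_secs == 31, age_jiffies is 31000, so an entry
+ * updated just before a 31-second suspend now looks older than
+ * IEEE80211_SCAN_RESULT_EXPIRE (30 * HZ) and the next
+ * cfg80211_bss_expire() call will drop it unless the entry is held.
+ */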
+ spin_unlock_bh(&rdev->bss_lock); +} + +void cfg80211_bss_expire(struct cfg80211_registered_device *rdev) +{ + __cfg80211_bss_expire(rdev, jiffies - IEEE80211_SCAN_RESULT_EXPIRE); +} + +void cfg80211_bss_flush(struct wiphy *wiphy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + + spin_lock_bh(&rdev->bss_lock); + __cfg80211_bss_expire(rdev, jiffies); + spin_unlock_bh(&rdev->bss_lock); +} +EXPORT_SYMBOL(cfg80211_bss_flush); + +const struct element * +cfg80211_find_elem_match(u8 eid, const u8 *ies, unsigned int len, + const u8 *match, unsigned int match_len, + unsigned int match_offset) +{ + const struct element *elem; + + for_each_element_id(elem, eid, ies, len) { + if (elem->datalen >= match_offset + match_len && + !memcmp(elem->data + match_offset, match, match_len)) + return elem; + } + + return NULL; +} +EXPORT_SYMBOL(cfg80211_find_elem_match); + +const struct element *cfg80211_find_vendor_elem(unsigned int oui, int oui_type, + const u8 *ies, + unsigned int len) +{ + const struct element *elem; + u8 match[] = { oui >> 16, oui >> 8, oui, oui_type }; + int match_len = (oui_type < 0) ? 3 : sizeof(match); + + if (WARN_ON(oui_type > 0xff)) + return NULL; + + elem = cfg80211_find_elem_match(WLAN_EID_VENDOR_SPECIFIC, ies, len, + match, match_len, 0); + + if (!elem || elem->datalen < 4) + return NULL; + + return elem; +} +EXPORT_SYMBOL(cfg80211_find_vendor_elem); + +/** + * enum bss_compare_mode - BSS compare mode + * @BSS_CMP_REGULAR: regular compare mode (for insertion and normal find) + * @BSS_CMP_HIDE_ZLEN: find hidden SSID with zero-length mode + * @BSS_CMP_HIDE_NUL: find hidden SSID with NUL-ed out mode + */ +enum bss_compare_mode { + BSS_CMP_REGULAR, + BSS_CMP_HIDE_ZLEN, + BSS_CMP_HIDE_NUL, +}; + +static int cmp_bss(struct cfg80211_bss *a, + struct cfg80211_bss *b, + enum bss_compare_mode mode) +{ + const struct cfg80211_bss_ies *a_ies, *b_ies; + const u8 *ie1 = NULL; + const u8 *ie2 = NULL; + int i, r; + + if (a->channel != b->channel) + return b->channel->center_freq - a->channel->center_freq; + + a_ies = rcu_access_pointer(a->ies); + if (!a_ies) + return -1; + b_ies = rcu_access_pointer(b->ies); + if (!b_ies) + return 1; + + if (WLAN_CAPABILITY_IS_STA_BSS(a->capability)) + ie1 = cfg80211_find_ie(WLAN_EID_MESH_ID, + a_ies->data, a_ies->len); + if (WLAN_CAPABILITY_IS_STA_BSS(b->capability)) + ie2 = cfg80211_find_ie(WLAN_EID_MESH_ID, + b_ies->data, b_ies->len); + if (ie1 && ie2) { + int mesh_id_cmp; + + if (ie1[1] == ie2[1]) + mesh_id_cmp = memcmp(ie1 + 2, ie2 + 2, ie1[1]); + else + mesh_id_cmp = ie2[1] - ie1[1]; + + ie1 = cfg80211_find_ie(WLAN_EID_MESH_CONFIG, + a_ies->data, a_ies->len); + ie2 = cfg80211_find_ie(WLAN_EID_MESH_CONFIG, + b_ies->data, b_ies->len); + if (ie1 && ie2) { + if (mesh_id_cmp) + return mesh_id_cmp; + if (ie1[1] != ie2[1]) + return ie2[1] - ie1[1]; + return memcmp(ie1 + 2, ie2 + 2, ie1[1]); + } + } + + r = memcmp(a->bssid, b->bssid, sizeof(a->bssid)); + if (r) + return r; + + ie1 = cfg80211_find_ie(WLAN_EID_SSID, a_ies->data, a_ies->len); + ie2 = cfg80211_find_ie(WLAN_EID_SSID, b_ies->data, b_ies->len); + + if (!ie1 && !ie2) + return 0; + + /* + * Note that with "hide_ssid", the function returns a match if + * the already-present BSS ("b") is a hidden SSID beacon for + * the new BSS ("a"). + */ + + /* sort missing IE before (left of) present IE */ + if (!ie1) + return -1; + if (!ie2) + return 1; + + switch (mode) { + case BSS_CMP_HIDE_ZLEN: + /* + * In ZLEN mode we assume the BSS entry we're + * looking for has a zero-length SSID. 
So if + * the one we're looking at right now has that, + * return 0. Otherwise, return the difference + * in length, but since we're looking for the + * 0-length it's really equivalent to returning + * the length of the one we're looking at. + * + * No content comparison is needed as we assume + * the content length is zero. + */ + return ie2[1]; + case BSS_CMP_REGULAR: + default: + /* sort by length first, then by contents */ + if (ie1[1] != ie2[1]) + return ie2[1] - ie1[1]; + return memcmp(ie1 + 2, ie2 + 2, ie1[1]); + case BSS_CMP_HIDE_NUL: + if (ie1[1] != ie2[1]) + return ie2[1] - ie1[1]; + /* this is equivalent to memcmp(zeroes, ie2 + 2, len) */ + for (i = 0; i < ie2[1]; i++) + if (ie2[i + 2]) + return -1; + return 0; + } +} + +static bool cfg80211_bss_type_match(u16 capability, + enum nl80211_band band, + enum ieee80211_bss_type bss_type) +{ + bool ret = true; + u16 mask, val; + + if (bss_type == IEEE80211_BSS_TYPE_ANY) + return ret; + + if (band == NL80211_BAND_60GHZ) { + mask = WLAN_CAPABILITY_DMG_TYPE_MASK; + switch (bss_type) { + case IEEE80211_BSS_TYPE_ESS: + val = WLAN_CAPABILITY_DMG_TYPE_AP; + break; + case IEEE80211_BSS_TYPE_PBSS: + val = WLAN_CAPABILITY_DMG_TYPE_PBSS; + break; + case IEEE80211_BSS_TYPE_IBSS: + val = WLAN_CAPABILITY_DMG_TYPE_IBSS; + break; + default: + return false; + } + } else { + mask = WLAN_CAPABILITY_ESS | WLAN_CAPABILITY_IBSS; + switch (bss_type) { + case IEEE80211_BSS_TYPE_ESS: + val = WLAN_CAPABILITY_ESS; + break; + case IEEE80211_BSS_TYPE_IBSS: + val = WLAN_CAPABILITY_IBSS; + break; + case IEEE80211_BSS_TYPE_MBSS: + val = 0; + break; + default: + return false; + } + } + + ret = ((capability & mask) == val); + return ret; +} + +/* Returned bss is reference counted and must be cleaned up appropriately. */ +struct cfg80211_bss *cfg80211_get_bss(struct wiphy *wiphy, + struct ieee80211_channel *channel, + const u8 *bssid, + const u8 *ssid, size_t ssid_len, + enum ieee80211_bss_type bss_type, + enum ieee80211_privacy privacy) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *bss, *res = NULL; + unsigned long now = jiffies; + int bss_privacy; + + trace_cfg80211_get_bss(wiphy, channel, bssid, ssid, ssid_len, bss_type, + privacy); + + spin_lock_bh(&rdev->bss_lock); + + list_for_each_entry(bss, &rdev->bss_list, list) { + if (!cfg80211_bss_type_match(bss->pub.capability, + bss->pub.channel->band, bss_type)) + continue; + + bss_privacy = (bss->pub.capability & WLAN_CAPABILITY_PRIVACY); + if ((privacy == IEEE80211_PRIVACY_ON && !bss_privacy) || + (privacy == IEEE80211_PRIVACY_OFF && bss_privacy)) + continue; + if (channel && bss->pub.channel != channel) + continue; + if (!is_valid_ether_addr(bss->pub.bssid)) + continue; + /* Don't get expired BSS structs */ + if (time_after(now, bss->ts + IEEE80211_SCAN_RESULT_EXPIRE) && + !atomic_read(&bss->hold)) + continue; + if (is_bss(&bss->pub, bssid, ssid, ssid_len)) { + res = bss; + bss_ref_get(rdev, res); + break; + } + } + + spin_unlock_bh(&rdev->bss_lock); + if (!res) + return NULL; + trace_cfg80211_return_bss(&res->pub); + return &res->pub; +} +EXPORT_SYMBOL(cfg80211_get_bss); + +static void rb_insert_bss(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *bss) +{ + struct rb_node **p = &rdev->bss_tree.rb_node; + struct rb_node *parent = NULL; + struct cfg80211_internal_bss *tbss; + int cmp; + + while (*p) { + parent = *p; + tbss = rb_entry(parent, struct cfg80211_internal_bss, rbn); + + cmp = cmp_bss(&bss->pub, &tbss->pub, BSS_CMP_REGULAR); + + if 
(WARN_ON(!cmp)) { + /* will sort of leak this BSS */ + return; + } + + if (cmp < 0) + p = &(*p)->rb_left; + else + p = &(*p)->rb_right; + } + + rb_link_node(&bss->rbn, parent, p); + rb_insert_color(&bss->rbn, &rdev->bss_tree); +} + +static struct cfg80211_internal_bss * +rb_find_bss(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *res, + enum bss_compare_mode mode) +{ + struct rb_node *n = rdev->bss_tree.rb_node; + struct cfg80211_internal_bss *bss; + int r; + + while (n) { + bss = rb_entry(n, struct cfg80211_internal_bss, rbn); + r = cmp_bss(&res->pub, &bss->pub, mode); + + if (r == 0) + return bss; + else if (r < 0) + n = n->rb_left; + else + n = n->rb_right; + } + + return NULL; +} + +static bool cfg80211_combine_bsses(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *new) +{ + const struct cfg80211_bss_ies *ies; + struct cfg80211_internal_bss *bss; + const u8 *ie; + int i, ssidlen; + u8 fold = 0; + u32 n_entries = 0; + + ies = rcu_access_pointer(new->pub.beacon_ies); + if (WARN_ON(!ies)) + return false; + + ie = cfg80211_find_ie(WLAN_EID_SSID, ies->data, ies->len); + if (!ie) { + /* nothing to do */ + return true; + } + + ssidlen = ie[1]; + for (i = 0; i < ssidlen; i++) + fold |= ie[2 + i]; + + if (fold) { + /* not a hidden SSID */ + return true; + } + + /* This is the bad part ... */ + + list_for_each_entry(bss, &rdev->bss_list, list) { + /* + * we're iterating all the entries anyway, so take the + * opportunity to validate the list length accounting + */ + n_entries++; + + if (!ether_addr_equal(bss->pub.bssid, new->pub.bssid)) + continue; + if (bss->pub.channel != new->pub.channel) + continue; + if (bss->pub.scan_width != new->pub.scan_width) + continue; + if (rcu_access_pointer(bss->pub.beacon_ies)) + continue; + ies = rcu_access_pointer(bss->pub.ies); + if (!ies) + continue; + ie = cfg80211_find_ie(WLAN_EID_SSID, ies->data, ies->len); + if (!ie) + continue; + if (ssidlen && ie[1] != ssidlen) + continue; + if (WARN_ON_ONCE(bss->pub.hidden_beacon_bss)) + continue; + if (WARN_ON_ONCE(!list_empty(&bss->hidden_list))) + list_del(&bss->hidden_list); + /* combine them */ + list_add(&bss->hidden_list, &new->hidden_list); + bss->pub.hidden_beacon_bss = &new->pub; + new->refcount += bss->refcount; + rcu_assign_pointer(bss->pub.beacon_ies, + new->pub.beacon_ies); + } + + WARN_ONCE(n_entries != rdev->bss_entries, + "rdev bss entries[%d]/list[len:%d] corruption\n", + rdev->bss_entries, n_entries); + + return true; +} + +struct cfg80211_non_tx_bss { + struct cfg80211_bss *tx_bss; + u8 max_bssid_indicator; + u8 bssid_index; +}; + +static void cfg80211_update_hidden_bsses(struct cfg80211_internal_bss *known, + const struct cfg80211_bss_ies *new_ies, + const struct cfg80211_bss_ies *old_ies) +{ + struct cfg80211_internal_bss *bss; + + /* Assign beacon IEs to all sub entries */ + list_for_each_entry(bss, &known->hidden_list, hidden_list) { + const struct cfg80211_bss_ies *ies; + + ies = rcu_access_pointer(bss->pub.beacon_ies); + WARN_ON(ies != old_ies); + + rcu_assign_pointer(bss->pub.beacon_ies, new_ies); + } +} + +static bool +cfg80211_update_known_bss(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *known, + struct cfg80211_internal_bss *new, + bool signal_valid) +{ + lockdep_assert_held(&rdev->bss_lock); + + /* Update IEs */ + if (rcu_access_pointer(new->pub.proberesp_ies)) { + const struct cfg80211_bss_ies *old; + + old = rcu_access_pointer(known->pub.proberesp_ies); + + rcu_assign_pointer(known->pub.proberesp_ies, + 
new->pub.proberesp_ies); + /* Override possible earlier Beacon frame IEs */ + rcu_assign_pointer(known->pub.ies, + new->pub.proberesp_ies); + if (old) + kfree_rcu((struct cfg80211_bss_ies *)old, rcu_head); + } else if (rcu_access_pointer(new->pub.beacon_ies)) { + const struct cfg80211_bss_ies *old; + + if (known->pub.hidden_beacon_bss && + !list_empty(&known->hidden_list)) { + const struct cfg80211_bss_ies *f; + + /* The known BSS struct is one of the probe + * response members of a group, but we're + * receiving a beacon (beacon_ies in the new + * bss is used). This can only mean that the + * AP changed its beacon from not having an + * SSID to showing it, which is confusing so + * drop this information. + */ + + f = rcu_access_pointer(new->pub.beacon_ies); + kfree_rcu((struct cfg80211_bss_ies *)f, rcu_head); + return false; + } + + old = rcu_access_pointer(known->pub.beacon_ies); + + rcu_assign_pointer(known->pub.beacon_ies, new->pub.beacon_ies); + + /* Override IEs if they were from a beacon before */ + if (old == rcu_access_pointer(known->pub.ies)) + rcu_assign_pointer(known->pub.ies, new->pub.beacon_ies); + + cfg80211_update_hidden_bsses(known, + rcu_access_pointer(new->pub.beacon_ies), + old); + + if (old) + kfree_rcu((struct cfg80211_bss_ies *)old, rcu_head); + } + + known->pub.beacon_interval = new->pub.beacon_interval; + + /* don't update the signal if beacon was heard on + * adjacent channel. + */ + if (signal_valid) + known->pub.signal = new->pub.signal; + known->pub.capability = new->pub.capability; + known->ts = new->ts; + known->ts_boottime = new->ts_boottime; + known->parent_tsf = new->parent_tsf; + known->pub.chains = new->pub.chains; + memcpy(known->pub.chain_signal, new->pub.chain_signal, + IEEE80211_MAX_CHAINS); + ether_addr_copy(known->parent_bssid, new->parent_bssid); + known->pub.max_bssid_indicator = new->pub.max_bssid_indicator; + known->pub.bssid_index = new->pub.bssid_index; + + return true; +} + +/* Returned bss is reference counted and must be cleaned up appropriately. 
*/ +struct cfg80211_internal_bss * +cfg80211_bss_update(struct cfg80211_registered_device *rdev, + struct cfg80211_internal_bss *tmp, + bool signal_valid, unsigned long ts) +{ + struct cfg80211_internal_bss *found = NULL; + + if (WARN_ON(!tmp->pub.channel)) + return NULL; + + tmp->ts = ts; + + spin_lock_bh(&rdev->bss_lock); + + if (WARN_ON(!rcu_access_pointer(tmp->pub.ies))) { + spin_unlock_bh(&rdev->bss_lock); + return NULL; + } + + found = rb_find_bss(rdev, tmp, BSS_CMP_REGULAR); + + if (found) { + if (!cfg80211_update_known_bss(rdev, found, tmp, signal_valid)) + goto drop; + } else { + struct cfg80211_internal_bss *new; + struct cfg80211_internal_bss *hidden; + struct cfg80211_bss_ies *ies; + + /* + * create a copy -- the "res" variable that is passed in + * is allocated on the stack since it's not needed in the + * more common case of an update + */ + new = kzalloc(sizeof(*new) + rdev->wiphy.bss_priv_size, + GFP_ATOMIC); + if (!new) { + ies = (void *)rcu_dereference(tmp->pub.beacon_ies); + if (ies) + kfree_rcu(ies, rcu_head); + ies = (void *)rcu_dereference(tmp->pub.proberesp_ies); + if (ies) + kfree_rcu(ies, rcu_head); + goto drop; + } + memcpy(new, tmp, sizeof(*new)); + new->refcount = 1; + INIT_LIST_HEAD(&new->hidden_list); + INIT_LIST_HEAD(&new->pub.nontrans_list); + /* we'll set this later if it was non-NULL */ + new->pub.transmitted_bss = NULL; + + if (rcu_access_pointer(tmp->pub.proberesp_ies)) { + hidden = rb_find_bss(rdev, tmp, BSS_CMP_HIDE_ZLEN); + if (!hidden) + hidden = rb_find_bss(rdev, tmp, + BSS_CMP_HIDE_NUL); + if (hidden) { + new->pub.hidden_beacon_bss = &hidden->pub; + list_add(&new->hidden_list, + &hidden->hidden_list); + hidden->refcount++; + rcu_assign_pointer(new->pub.beacon_ies, + hidden->pub.beacon_ies); + } + } else { + /* + * Ok so we found a beacon, and don't have an entry. If + * it's a beacon with hidden SSID, we might be in for an + * expensive search for any probe responses that should + * be grouped with this beacon for updates ... 
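+ * That is what cfg80211_combine_bsses() below does: it walks the whole
+ * BSS list looking for matching probe-response-only entries and links
+ * them to this beacon through their hidden_beacon_bss pointers.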
+ */ + if (!cfg80211_combine_bsses(rdev, new)) { + bss_ref_put(rdev, new); + goto drop; + } + } + + if (rdev->bss_entries >= bss_entries_limit && + !cfg80211_bss_expire_oldest(rdev)) { + bss_ref_put(rdev, new); + goto drop; + } + + /* This must be before the call to bss_ref_get */ + if (tmp->pub.transmitted_bss) { + struct cfg80211_internal_bss *pbss = + container_of(tmp->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + + new->pub.transmitted_bss = tmp->pub.transmitted_bss; + bss_ref_get(rdev, pbss); + } + + list_add_tail(&new->list, &rdev->bss_list); + rdev->bss_entries++; + rb_insert_bss(rdev, new); + found = new; + } + + rdev->bss_generation++; + bss_ref_get(rdev, found); + spin_unlock_bh(&rdev->bss_lock); + + return found; + drop: + spin_unlock_bh(&rdev->bss_lock); + return NULL; +} + +int cfg80211_get_ies_channel_number(const u8 *ie, size_t ielen, + enum nl80211_band band, + enum cfg80211_bss_frame_type ftype) +{ + const struct element *tmp; + + if (band == NL80211_BAND_6GHZ) { + struct ieee80211_he_operation *he_oper; + + tmp = cfg80211_find_ext_elem(WLAN_EID_EXT_HE_OPERATION, ie, + ielen); + if (tmp && tmp->datalen >= sizeof(*he_oper) && + tmp->datalen >= ieee80211_he_oper_size(&tmp->data[1])) { + const struct ieee80211_he_6ghz_oper *he_6ghz_oper; + + he_oper = (void *)&tmp->data[1]; + + he_6ghz_oper = ieee80211_he_6ghz_oper(he_oper); + if (!he_6ghz_oper) + return -1; + + if (ftype != CFG80211_BSS_FTYPE_BEACON || + he_6ghz_oper->control & IEEE80211_HE_6GHZ_OPER_CTRL_DUP_BEACON) + return he_6ghz_oper->primary; + } + } else if (band == NL80211_BAND_S1GHZ) { + tmp = cfg80211_find_elem(WLAN_EID_S1G_OPERATION, ie, ielen); + if (tmp && tmp->datalen >= sizeof(struct ieee80211_s1g_oper_ie)) { + struct ieee80211_s1g_oper_ie *s1gop = (void *)tmp->data; + + return s1gop->oper_ch; + } + } else { + tmp = cfg80211_find_elem(WLAN_EID_DS_PARAMS, ie, ielen); + if (tmp && tmp->datalen == 1) + return tmp->data[0]; + + tmp = cfg80211_find_elem(WLAN_EID_HT_OPERATION, ie, ielen); + if (tmp && + tmp->datalen >= sizeof(struct ieee80211_ht_operation)) { + struct ieee80211_ht_operation *htop = (void *)tmp->data; + + return htop->primary_chan; + } + } + + return -1; +} +EXPORT_SYMBOL(cfg80211_get_ies_channel_number); + +/* + * Update RX channel information based on the available frame payload + * information. This is mainly for the 2.4 GHz band where frames can be received + * from neighboring channels and the Beacon frames use the DSSS Parameter Set + * element to indicate the current (transmitting) channel, but this might also + * be needed on other bands if RX frequency does not match with the actual + * operating channel of a BSS, or if the AP reports a different primary channel. + */ +static struct ieee80211_channel * +cfg80211_get_bss_channel(struct wiphy *wiphy, const u8 *ie, size_t ielen, + struct ieee80211_channel *channel, + enum nl80211_bss_scan_width scan_width, + enum cfg80211_bss_frame_type ftype) +{ + u32 freq; + int channel_number; + struct ieee80211_channel *alt_channel; + + channel_number = cfg80211_get_ies_channel_number(ie, ielen, + channel->band, ftype); + + if (channel_number < 0) { + /* No channel information in frame payload */ + return channel; + } + + freq = ieee80211_channel_to_freq_khz(channel_number, channel->band); + + /* + * In 6GHz, duplicated beacon indication is relevant for + * beacons only. 
+ */ + if (channel->band == NL80211_BAND_6GHZ && + (freq == channel->center_freq || + abs(freq - channel->center_freq) > 80)) + return channel; + + alt_channel = ieee80211_get_channel_khz(wiphy, freq); + if (!alt_channel) { + if (channel->band == NL80211_BAND_2GHZ) { + /* + * Better not allow unexpected channels when that could + * be going beyond the 1-11 range (e.g., discovering + * BSS on channel 12 when radio is configured for + * channel 11. + */ + return NULL; + } + + /* No match for the payload channel number - ignore it */ + return channel; + } + + if (scan_width == NL80211_BSS_CHAN_WIDTH_10 || + scan_width == NL80211_BSS_CHAN_WIDTH_5) { + /* + * Ignore channel number in 5 and 10 MHz channels where there + * may not be an n:1 or 1:n mapping between frequencies and + * channel numbers. + */ + return channel; + } + + /* + * Use the channel determined through the payload channel number + * instead of the RX channel reported by the driver. + */ + if (alt_channel->flags & IEEE80211_CHAN_DISABLED) + return NULL; + return alt_channel; +} + +/* Returned bss is reference counted and must be cleaned up appropriately. */ +static struct cfg80211_bss * +cfg80211_inform_single_bss_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, u16 capability, + u16 beacon_interval, const u8 *ie, size_t ielen, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_bss_ies *ies; + struct ieee80211_channel *channel; + struct cfg80211_internal_bss tmp = {}, *res; + int bss_type; + bool signal_valid; + unsigned long ts; + + if (WARN_ON(!wiphy)) + return NULL; + + if (WARN_ON(wiphy->signal_type == CFG80211_SIGNAL_TYPE_UNSPEC && + (data->signal < 0 || data->signal > 100))) + return NULL; + + channel = cfg80211_get_bss_channel(wiphy, ie, ielen, data->chan, + data->scan_width, ftype); + if (!channel) + return NULL; + + memcpy(tmp.pub.bssid, bssid, ETH_ALEN); + tmp.pub.channel = channel; + tmp.pub.scan_width = data->scan_width; + tmp.pub.signal = data->signal; + tmp.pub.beacon_interval = beacon_interval; + tmp.pub.capability = capability; + tmp.ts_boottime = data->boottime_ns; + tmp.parent_tsf = data->parent_tsf; + ether_addr_copy(tmp.parent_bssid, data->parent_bssid); + + if (non_tx_data) { + tmp.pub.transmitted_bss = non_tx_data->tx_bss; + ts = bss_from_pub(non_tx_data->tx_bss)->ts; + tmp.pub.bssid_index = non_tx_data->bssid_index; + tmp.pub.max_bssid_indicator = non_tx_data->max_bssid_indicator; + } else { + ts = jiffies; + } + + /* + * If we do not know here whether the IEs are from a Beacon or Probe + * Response frame, we need to pick one of the options and only use it + * with the driver that does not provide the full Beacon/Probe Response + * frame. Use Beacon frame pointer to avoid indicating that this should + * override the IEs pointer should we have received an earlier + * indication of Probe Response data. 
+ */ + ies = kzalloc(sizeof(*ies) + ielen, gfp); + if (!ies) + return NULL; + ies->len = ielen; + ies->tsf = tsf; + ies->from_beacon = false; + memcpy(ies->data, ie, ielen); + + switch (ftype) { + case CFG80211_BSS_FTYPE_BEACON: + ies->from_beacon = true; + fallthrough; + case CFG80211_BSS_FTYPE_UNKNOWN: + rcu_assign_pointer(tmp.pub.beacon_ies, ies); + break; + case CFG80211_BSS_FTYPE_PRESP: + rcu_assign_pointer(tmp.pub.proberesp_ies, ies); + break; + } + rcu_assign_pointer(tmp.pub.ies, ies); + + signal_valid = data->chan == channel; + res = cfg80211_bss_update(wiphy_to_rdev(wiphy), &tmp, signal_valid, ts); + if (!res) + return NULL; + + if (channel->band == NL80211_BAND_60GHZ) { + bss_type = res->pub.capability & WLAN_CAPABILITY_DMG_TYPE_MASK; + if (bss_type == WLAN_CAPABILITY_DMG_TYPE_AP || + bss_type == WLAN_CAPABILITY_DMG_TYPE_PBSS) + regulatory_hint_found_beacon(wiphy, channel, gfp); + } else { + if (res->pub.capability & WLAN_CAPABILITY_ESS) + regulatory_hint_found_beacon(wiphy, channel, gfp); + } + + if (non_tx_data) { + /* this is a nontransmitting bss, we need to add it to + * transmitting bss' list if it is not there + */ + spin_lock_bh(&rdev->bss_lock); + if (cfg80211_add_nontrans_list(non_tx_data->tx_bss, + &res->pub)) { + if (__cfg80211_unlink_bss(rdev, res)) { + rdev->bss_generation++; + res = NULL; + } + } + spin_unlock_bh(&rdev->bss_lock); + + if (!res) + return NULL; + } + + trace_cfg80211_return_bss(&res->pub); + /* cfg80211_bss_update gives us a referenced result */ + return &res->pub; +} + +static const struct element +*cfg80211_get_profile_continuation(const u8 *ie, size_t ielen, + const struct element *mbssid_elem, + const struct element *sub_elem) +{ + const u8 *mbssid_end = mbssid_elem->data + mbssid_elem->datalen; + const struct element *next_mbssid; + const struct element *next_sub; + + next_mbssid = cfg80211_find_elem(WLAN_EID_MULTIPLE_BSSID, + mbssid_end, + ielen - (mbssid_end - ie)); + + /* + * If it is not the last subelement in current MBSSID IE or there isn't + * a next MBSSID IE - profile is complete. + */ + if ((sub_elem->data + sub_elem->datalen < mbssid_end - 1) || + !next_mbssid) + return NULL; + + /* For any length error, just return NULL */ + + if (next_mbssid->datalen < 4) + return NULL; + + next_sub = (void *)&next_mbssid->data[1]; + + if (next_mbssid->data + next_mbssid->datalen < + next_sub->data + next_sub->datalen) + return NULL; + + if (next_sub->id != 0 || next_sub->datalen < 2) + return NULL; + + /* + * Check if the first element in the next sub element is a start + * of a new profile + */ + return next_sub->data[0] == WLAN_EID_NON_TX_BSSID_CAP ? 
+ NULL : next_mbssid; +} + +size_t cfg80211_merge_profile(const u8 *ie, size_t ielen, + const struct element *mbssid_elem, + const struct element *sub_elem, + u8 *merged_ie, size_t max_copy_len) +{ + size_t copied_len = sub_elem->datalen; + const struct element *next_mbssid; + + if (sub_elem->datalen > max_copy_len) + return 0; + + memcpy(merged_ie, sub_elem->data, sub_elem->datalen); + + while ((next_mbssid = cfg80211_get_profile_continuation(ie, ielen, + mbssid_elem, + sub_elem))) { + const struct element *next_sub = (void *)&next_mbssid->data[1]; + + if (copied_len + next_sub->datalen > max_copy_len) + break; + memcpy(merged_ie + copied_len, next_sub->data, + next_sub->datalen); + copied_len += next_sub->datalen; + } + + return copied_len; +} +EXPORT_SYMBOL(cfg80211_merge_profile); + +static void cfg80211_parse_mbssid_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, + u16 beacon_interval, const u8 *ie, + size_t ielen, + struct cfg80211_non_tx_bss *non_tx_data, + gfp_t gfp) +{ + const u8 *mbssid_index_ie; + const struct element *elem, *sub; + size_t new_ie_len; + u8 new_bssid[ETH_ALEN]; + u8 *new_ie, *profile; + u64 seen_indices = 0; + u16 capability; + struct cfg80211_bss *bss; + + if (!non_tx_data) + return; + if (!cfg80211_find_elem(WLAN_EID_MULTIPLE_BSSID, ie, ielen)) + return; + if (!wiphy->support_mbssid) + return; + if (wiphy->support_only_he_mbssid && + !cfg80211_find_ext_elem(WLAN_EID_EXT_HE_CAPABILITY, ie, ielen)) + return; + + new_ie = kmalloc(IEEE80211_MAX_DATA_LEN, gfp); + if (!new_ie) + return; + + profile = kmalloc(ielen, gfp); + if (!profile) + goto out; + + for_each_element_id(elem, WLAN_EID_MULTIPLE_BSSID, ie, ielen) { + if (elem->datalen < 4) + continue; + if (elem->data[0] < 1 || (int)elem->data[0] > 8) + continue; + for_each_element(sub, elem->data + 1, elem->datalen - 1) { + u8 profile_len; + + if (sub->id != 0 || sub->datalen < 4) { + /* not a valid BSS profile */ + continue; + } + + if (sub->data[0] != WLAN_EID_NON_TX_BSSID_CAP || + sub->data[1] != 2) { + /* The first element within the Nontransmitted + * BSSID Profile is not the Nontransmitted + * BSSID Capability element. 
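+ *
+ * (For reference, a usable profile subelement - id 0 - is expected
+ * to start with a WLAN_EID_NON_TX_BSSID_CAP element carrying a
+ * 2-octet capability field, and the merged profile must contain a
+ * WLAN_EID_MULTI_BSSID_IDX element; both are checked below.)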
+ */ + continue; + } + + memset(profile, 0, ielen); + profile_len = cfg80211_merge_profile(ie, ielen, + elem, + sub, + profile, + ielen); + + /* found a Nontransmitted BSSID Profile */ + mbssid_index_ie = cfg80211_find_ie + (WLAN_EID_MULTI_BSSID_IDX, + profile, profile_len); + if (!mbssid_index_ie || mbssid_index_ie[1] < 1 || + mbssid_index_ie[2] == 0 || + mbssid_index_ie[2] > 46) { + /* No valid Multiple BSSID-Index element */ + continue; + } + + if (seen_indices & BIT_ULL(mbssid_index_ie[2])) + /* We don't support legacy split of a profile */ + net_dbg_ratelimited("Partial info for BSSID index %d\n", + mbssid_index_ie[2]); + + seen_indices |= BIT_ULL(mbssid_index_ie[2]); + + non_tx_data->bssid_index = mbssid_index_ie[2]; + non_tx_data->max_bssid_indicator = elem->data[0]; + + cfg80211_gen_new_bssid(bssid, + non_tx_data->max_bssid_indicator, + non_tx_data->bssid_index, + new_bssid); + memset(new_ie, 0, IEEE80211_MAX_DATA_LEN); + new_ie_len = cfg80211_gen_new_ie(ie, ielen, + profile, + profile_len, new_ie, + IEEE80211_MAX_DATA_LEN); + if (!new_ie_len) + continue; + + capability = get_unaligned_le16(profile + 2); + bss = cfg80211_inform_single_bss_data(wiphy, data, + ftype, + new_bssid, tsf, + capability, + beacon_interval, + new_ie, + new_ie_len, + non_tx_data, + gfp); + if (!bss) + break; + cfg80211_put_bss(wiphy, bss); + } + } + +out: + kfree(new_ie); + kfree(profile); +} + +struct cfg80211_bss * +cfg80211_inform_bss_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + enum cfg80211_bss_frame_type ftype, + const u8 *bssid, u64 tsf, u16 capability, + u16 beacon_interval, const u8 *ie, size_t ielen, + gfp_t gfp) +{ + struct cfg80211_bss *res; + struct cfg80211_non_tx_bss non_tx_data; + + res = cfg80211_inform_single_bss_data(wiphy, data, ftype, bssid, tsf, + capability, beacon_interval, ie, + ielen, NULL, gfp); + if (!res) + return NULL; + non_tx_data.tx_bss = res; + cfg80211_parse_mbssid_data(wiphy, data, ftype, bssid, tsf, + beacon_interval, ie, ielen, &non_tx_data, + gfp); + return res; +} +EXPORT_SYMBOL(cfg80211_inform_bss_data); + +/* cfg80211_inform_bss_width_frame helper */ +static struct cfg80211_bss * +cfg80211_inform_single_bss_frame_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len, + gfp_t gfp) +{ + struct cfg80211_internal_bss tmp = {}, *res; + struct cfg80211_bss_ies *ies; + struct ieee80211_channel *channel; + bool signal_valid; + struct ieee80211_ext *ext = NULL; + u8 *bssid, *variable; + u16 capability, beacon_int; + size_t ielen, min_hdr_len = offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + int bss_type; + enum cfg80211_bss_frame_type ftype; + + BUILD_BUG_ON(offsetof(struct ieee80211_mgmt, u.probe_resp.variable) != + offsetof(struct ieee80211_mgmt, u.beacon.variable)); + + trace_cfg80211_inform_bss_frame(wiphy, data, mgmt, len); + + if (WARN_ON(!mgmt)) + return NULL; + + if (WARN_ON(!wiphy)) + return NULL; + + if (WARN_ON(wiphy->signal_type == CFG80211_SIGNAL_TYPE_UNSPEC && + (data->signal < 0 || data->signal > 100))) + return NULL; + + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) { + ext = (void *) mgmt; + min_hdr_len = offsetof(struct ieee80211_ext, u.s1g_beacon); + if (ieee80211_is_s1g_short_beacon(mgmt->frame_control)) + min_hdr_len = offsetof(struct ieee80211_ext, + u.s1g_short_beacon.variable); + } + + if (WARN_ON(len < min_hdr_len)) + return NULL; + + ielen = len - min_hdr_len; + variable = mgmt->u.probe_resp.variable; + if (ext) { + if 
(ieee80211_is_s1g_short_beacon(mgmt->frame_control)) + variable = ext->u.s1g_short_beacon.variable; + else + variable = ext->u.s1g_beacon.variable; + } + + if (ieee80211_is_beacon(mgmt->frame_control)) + ftype = CFG80211_BSS_FTYPE_BEACON; + else if (ieee80211_is_probe_resp(mgmt->frame_control)) + ftype = CFG80211_BSS_FTYPE_PRESP; + else + ftype = CFG80211_BSS_FTYPE_UNKNOWN; + + channel = cfg80211_get_bss_channel(wiphy, variable, + ielen, data->chan, data->scan_width, + ftype); + if (!channel) + return NULL; + + if (ext) { + const struct ieee80211_s1g_bcn_compat_ie *compat; + const struct element *elem; + + elem = cfg80211_find_elem(WLAN_EID_S1G_BCN_COMPAT, + variable, ielen); + if (!elem) + return NULL; + if (elem->datalen < sizeof(*compat)) + return NULL; + compat = (void *)elem->data; + bssid = ext->u.s1g_beacon.sa; + capability = le16_to_cpu(compat->compat_info); + beacon_int = le16_to_cpu(compat->beacon_int); + } else { + bssid = mgmt->bssid; + beacon_int = le16_to_cpu(mgmt->u.probe_resp.beacon_int); + capability = le16_to_cpu(mgmt->u.probe_resp.capab_info); + } + + ies = kzalloc(sizeof(*ies) + ielen, gfp); + if (!ies) + return NULL; + ies->len = ielen; + ies->tsf = le64_to_cpu(mgmt->u.probe_resp.timestamp); + ies->from_beacon = ieee80211_is_beacon(mgmt->frame_control) || + ieee80211_is_s1g_beacon(mgmt->frame_control); + memcpy(ies->data, variable, ielen); + + if (ieee80211_is_probe_resp(mgmt->frame_control)) + rcu_assign_pointer(tmp.pub.proberesp_ies, ies); + else + rcu_assign_pointer(tmp.pub.beacon_ies, ies); + rcu_assign_pointer(tmp.pub.ies, ies); + + memcpy(tmp.pub.bssid, bssid, ETH_ALEN); + tmp.pub.beacon_interval = beacon_int; + tmp.pub.capability = capability; + tmp.pub.channel = channel; + tmp.pub.scan_width = data->scan_width; + tmp.pub.signal = data->signal; + tmp.ts_boottime = data->boottime_ns; + tmp.parent_tsf = data->parent_tsf; + tmp.pub.chains = data->chains; + memcpy(tmp.pub.chain_signal, data->chain_signal, IEEE80211_MAX_CHAINS); + ether_addr_copy(tmp.parent_bssid, data->parent_bssid); + + signal_valid = data->chan == channel; + res = cfg80211_bss_update(wiphy_to_rdev(wiphy), &tmp, signal_valid, + jiffies); + if (!res) + return NULL; + + if (channel->band == NL80211_BAND_60GHZ) { + bss_type = res->pub.capability & WLAN_CAPABILITY_DMG_TYPE_MASK; + if (bss_type == WLAN_CAPABILITY_DMG_TYPE_AP || + bss_type == WLAN_CAPABILITY_DMG_TYPE_PBSS) + regulatory_hint_found_beacon(wiphy, channel, gfp); + } else { + if (res->pub.capability & WLAN_CAPABILITY_ESS) + regulatory_hint_found_beacon(wiphy, channel, gfp); + } + + trace_cfg80211_return_bss(&res->pub); + /* cfg80211_bss_update gives us a referenced result */ + return &res->pub; +} + +struct cfg80211_bss * +cfg80211_inform_bss_frame_data(struct wiphy *wiphy, + struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len, + gfp_t gfp) +{ + struct cfg80211_bss *res; + const u8 *ie = mgmt->u.probe_resp.variable; + size_t ielen = len - offsetof(struct ieee80211_mgmt, + u.probe_resp.variable); + enum cfg80211_bss_frame_type ftype; + struct cfg80211_non_tx_bss non_tx_data = {}; + + res = cfg80211_inform_single_bss_frame_data(wiphy, data, mgmt, + len, gfp); + if (!res) + return NULL; + + /* don't do any further MBSSID handling for S1G */ + if (ieee80211_is_s1g_beacon(mgmt->frame_control)) + return res; + + ftype = ieee80211_is_beacon(mgmt->frame_control) ? 
+ CFG80211_BSS_FTYPE_BEACON : CFG80211_BSS_FTYPE_PRESP; + non_tx_data.tx_bss = res; + + /* process each non-transmitting bss */ + cfg80211_parse_mbssid_data(wiphy, data, ftype, mgmt->bssid, + le64_to_cpu(mgmt->u.probe_resp.timestamp), + le16_to_cpu(mgmt->u.probe_resp.beacon_int), + ie, ielen, &non_tx_data, gfp); + + return res; +} +EXPORT_SYMBOL(cfg80211_inform_bss_frame_data); + +void cfg80211_ref_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *bss; + + if (!pub) + return; + + bss = container_of(pub, struct cfg80211_internal_bss, pub); + + spin_lock_bh(&rdev->bss_lock); + bss_ref_get(rdev, bss); + spin_unlock_bh(&rdev->bss_lock); +} +EXPORT_SYMBOL(cfg80211_ref_bss); + +void cfg80211_put_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *bss; + + if (!pub) + return; + + bss = container_of(pub, struct cfg80211_internal_bss, pub); + + spin_lock_bh(&rdev->bss_lock); + bss_ref_put(rdev, bss); + spin_unlock_bh(&rdev->bss_lock); +} +EXPORT_SYMBOL(cfg80211_put_bss); + +void cfg80211_unlink_bss(struct wiphy *wiphy, struct cfg80211_bss *pub) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *bss, *tmp1; + struct cfg80211_bss *nontrans_bss, *tmp; + + if (WARN_ON(!pub)) + return; + + bss = container_of(pub, struct cfg80211_internal_bss, pub); + + spin_lock_bh(&rdev->bss_lock); + if (list_empty(&bss->list)) + goto out; + + list_for_each_entry_safe(nontrans_bss, tmp, + &pub->nontrans_list, + nontrans_list) { + tmp1 = container_of(nontrans_bss, + struct cfg80211_internal_bss, pub); + if (__cfg80211_unlink_bss(rdev, tmp1)) + rdev->bss_generation++; + } + + if (__cfg80211_unlink_bss(rdev, bss)) + rdev->bss_generation++; +out: + spin_unlock_bh(&rdev->bss_lock); +} +EXPORT_SYMBOL(cfg80211_unlink_bss); + +void cfg80211_bss_iter(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef, + void (*iter)(struct wiphy *wiphy, + struct cfg80211_bss *bss, + void *data), + void *iter_data) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *bss; + + spin_lock_bh(&rdev->bss_lock); + + list_for_each_entry(bss, &rdev->bss_list, list) { + if (!chandef || cfg80211_is_sub_chan(chandef, bss->pub.channel, + false)) + iter(wiphy, &bss->pub, iter_data); + } + + spin_unlock_bh(&rdev->bss_lock); +} +EXPORT_SYMBOL(cfg80211_bss_iter); + +void cfg80211_update_assoc_bss_entry(struct wireless_dev *wdev, + unsigned int link_id, + struct ieee80211_channel *chan) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct cfg80211_internal_bss *cbss = wdev->links[link_id].client.current_bss; + struct cfg80211_internal_bss *new = NULL; + struct cfg80211_internal_bss *bss; + struct cfg80211_bss *nontrans_bss; + struct cfg80211_bss *tmp; + + spin_lock_bh(&rdev->bss_lock); + + /* + * Some APs use CSA also for bandwidth changes, i.e., without actually + * changing the control channel, so no need to update in such a case. 
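+ * (For example, an AP that merely widens its operating bandwidth
+ * announces a CSA with an unchanged control channel; the check below
+ * then sees the current BSS already on @chan and returns early.)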
+ */ + if (cbss->pub.channel == chan) + goto done; + + /* use transmitting bss */ + if (cbss->pub.transmitted_bss) + cbss = container_of(cbss->pub.transmitted_bss, + struct cfg80211_internal_bss, + pub); + + cbss->pub.channel = chan; + + list_for_each_entry(bss, &rdev->bss_list, list) { + if (!cfg80211_bss_type_match(bss->pub.capability, + bss->pub.channel->band, + wdev->conn_bss_type)) + continue; + + if (bss == cbss) + continue; + + if (!cmp_bss(&bss->pub, &cbss->pub, BSS_CMP_REGULAR)) { + new = bss; + break; + } + } + + if (new) { + /* to save time, update IEs for transmitting bss only */ + if (cfg80211_update_known_bss(rdev, cbss, new, false)) { + new->pub.proberesp_ies = NULL; + new->pub.beacon_ies = NULL; + } + + list_for_each_entry_safe(nontrans_bss, tmp, + &new->pub.nontrans_list, + nontrans_list) { + bss = container_of(nontrans_bss, + struct cfg80211_internal_bss, pub); + if (__cfg80211_unlink_bss(rdev, bss)) + rdev->bss_generation++; + } + + WARN_ON(atomic_read(&new->hold)); + if (!WARN_ON(!__cfg80211_unlink_bss(rdev, new))) + rdev->bss_generation++; + } + + rb_erase(&cbss->rbn, &rdev->bss_tree); + rb_insert_bss(rdev, cbss); + rdev->bss_generation++; + + list_for_each_entry_safe(nontrans_bss, tmp, + &cbss->pub.nontrans_list, + nontrans_list) { + bss = container_of(nontrans_bss, + struct cfg80211_internal_bss, pub); + bss->pub.channel = chan; + rb_erase(&bss->rbn, &rdev->bss_tree); + rb_insert_bss(rdev, bss); + rdev->bss_generation++; + } + +done: + spin_unlock_bh(&rdev->bss_lock); +} + +#ifdef CONFIG_CFG80211_WEXT +static struct cfg80211_registered_device * +cfg80211_get_dev_from_ifindex(struct net *net, int ifindex) +{ + struct cfg80211_registered_device *rdev; + struct net_device *dev; + + ASSERT_RTNL(); + + dev = dev_get_by_index(net, ifindex); + if (!dev) + return ERR_PTR(-ENODEV); + if (dev->ieee80211_ptr) + rdev = wiphy_to_rdev(dev->ieee80211_ptr->wiphy); + else + rdev = ERR_PTR(-ENODEV); + dev_put(dev); + return rdev; +} + +int cfg80211_wext_siwscan(struct net_device *dev, + struct iw_request_info *info, + union iwreq_data *wrqu, char *extra) +{ + struct cfg80211_registered_device *rdev; + struct wiphy *wiphy; + struct iw_scan_req *wreq = NULL; + struct cfg80211_scan_request *creq; + int i, err, n_channels = 0; + enum nl80211_band band; + + if (!netif_running(dev)) + return -ENETDOWN; + + if (wrqu->data.length == sizeof(struct iw_scan_req)) + wreq = (struct iw_scan_req *)extra; + + rdev = cfg80211_get_dev_from_ifindex(dev_net(dev), dev->ifindex); + + if (IS_ERR(rdev)) + return PTR_ERR(rdev); + + if (rdev->scan_req || rdev->scan_msg) + return -EBUSY; + + wiphy = &rdev->wiphy; + + /* Determine number of channels, needed to allocate creq */ + if (wreq && wreq->num_channels) + n_channels = wreq->num_channels; + else + n_channels = ieee80211_get_num_supported_channels(wiphy); + + creq = kzalloc(sizeof(*creq) + sizeof(struct cfg80211_ssid) + + n_channels * sizeof(void *), + GFP_ATOMIC); + if (!creq) + return -ENOMEM; + + creq->wiphy = wiphy; + creq->wdev = dev->ieee80211_ptr; + /* SSIDs come after channels */ + creq->ssids = (void *)&creq->channels[n_channels]; + creq->n_channels = n_channels; + creq->n_ssids = 1; + creq->scan_start = jiffies; + + /* translate "Scan on frequencies" request */ + i = 0; + for (band = 0; band < NUM_NL80211_BANDS; band++) { + int j; + + if (!wiphy->bands[band]) + continue; + + for (j = 0; j < wiphy->bands[band]->n_channels; j++) { + /* ignore disabled channels */ + if (wiphy->bands[band]->channels[j].flags & + IEEE80211_CHAN_DISABLED) + continue; + 
+ /* If we have a wireless request structure and the + * wireless request specifies frequencies, then search + * for the matching hardware channel. + */ + if (wreq && wreq->num_channels) { + int k; + int wiphy_freq = wiphy->bands[band]->channels[j].center_freq; + for (k = 0; k < wreq->num_channels; k++) { + struct iw_freq *freq = + &wreq->channel_list[k]; + int wext_freq = + cfg80211_wext_freq(freq); + + if (wext_freq == wiphy_freq) + goto wext_freq_found; + } + goto wext_freq_not_found; + } + + wext_freq_found: + creq->channels[i] = &wiphy->bands[band]->channels[j]; + i++; + wext_freq_not_found: ; + } + } + /* No channels found? */ + if (!i) { + err = -EINVAL; + goto out; + } + + /* Set real number of channels specified in creq->channels[] */ + creq->n_channels = i; + + /* translate "Scan for SSID" request */ + if (wreq) { + if (wrqu->data.flags & IW_SCAN_THIS_ESSID) { + if (wreq->essid_len > IEEE80211_MAX_SSID_LEN) { + err = -EINVAL; + goto out; + } + memcpy(creq->ssids[0].ssid, wreq->essid, wreq->essid_len); + creq->ssids[0].ssid_len = wreq->essid_len; + } + if (wreq->scan_type == IW_SCAN_TYPE_PASSIVE) + creq->n_ssids = 0; + } + + for (i = 0; i < NUM_NL80211_BANDS; i++) + if (wiphy->bands[i]) + creq->rates[i] = (1 << wiphy->bands[i]->n_bitrates) - 1; + + eth_broadcast_addr(creq->bssid); + + wiphy_lock(&rdev->wiphy); + + rdev->scan_req = creq; + err = rdev_scan(rdev, creq); + if (err) { + rdev->scan_req = NULL; + /* creq will be freed below */ + } else { + nl80211_send_scan_start(rdev, dev->ieee80211_ptr); + /* creq now owned by driver */ + creq = NULL; + dev_hold(dev); + } + wiphy_unlock(&rdev->wiphy); + out: + kfree(creq); + return err; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_siwscan); + +static char *ieee80211_scan_add_ies(struct iw_request_info *info, + const struct cfg80211_bss_ies *ies, + char *current_ev, char *end_buf) +{ + const u8 *pos, *end, *next; + struct iw_event iwe; + + if (!ies) + return current_ev; + + /* + * If needed, fragment the IEs buffer (at IE boundaries) into short + * enough fragments to fit into IW_GENERIC_IE_MAX octet messages. 
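+ * Fragmentation is greedy: each IWEVGENIE event is filled with as
+ * many complete IEs (2 + pos[1] octets each) as still fit below
+ * IW_GENERIC_IE_MAX, and whatever is left over at the end is sent
+ * as one final event.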
+ */ + pos = ies->data; + end = pos + ies->len; + + while (end - pos > IW_GENERIC_IE_MAX) { + next = pos + 2 + pos[1]; + while (next + 2 + next[1] - pos < IW_GENERIC_IE_MAX) + next = next + 2 + next[1]; + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVGENIE; + iwe.u.data.length = next - pos; + current_ev = iwe_stream_add_point_check(info, current_ev, + end_buf, &iwe, + (void *)pos); + if (IS_ERR(current_ev)) + return current_ev; + pos = next; + } + + if (end > pos) { + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVGENIE; + iwe.u.data.length = end - pos; + current_ev = iwe_stream_add_point_check(info, current_ev, + end_buf, &iwe, + (void *)pos); + if (IS_ERR(current_ev)) + return current_ev; + } + + return current_ev; +} + +static char * +ieee80211_bss(struct wiphy *wiphy, struct iw_request_info *info, + struct cfg80211_internal_bss *bss, char *current_ev, + char *end_buf) +{ + const struct cfg80211_bss_ies *ies; + struct iw_event iwe; + const u8 *ie; + u8 buf[50]; + u8 *cfg, *p, *tmp; + int rem, i, sig; + bool ismesh = false; + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWAP; + iwe.u.ap_addr.sa_family = ARPHRD_ETHER; + memcpy(iwe.u.ap_addr.sa_data, bss->pub.bssid, ETH_ALEN); + current_ev = iwe_stream_add_event_check(info, current_ev, end_buf, &iwe, + IW_EV_ADDR_LEN); + if (IS_ERR(current_ev)) + return current_ev; + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWFREQ; + iwe.u.freq.m = ieee80211_frequency_to_channel(bss->pub.channel->center_freq); + iwe.u.freq.e = 0; + current_ev = iwe_stream_add_event_check(info, current_ev, end_buf, &iwe, + IW_EV_FREQ_LEN); + if (IS_ERR(current_ev)) + return current_ev; + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWFREQ; + iwe.u.freq.m = bss->pub.channel->center_freq; + iwe.u.freq.e = 6; + current_ev = iwe_stream_add_event_check(info, current_ev, end_buf, &iwe, + IW_EV_FREQ_LEN); + if (IS_ERR(current_ev)) + return current_ev; + + if (wiphy->signal_type != CFG80211_SIGNAL_TYPE_NONE) { + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVQUAL; + iwe.u.qual.updated = IW_QUAL_LEVEL_UPDATED | + IW_QUAL_NOISE_INVALID | + IW_QUAL_QUAL_UPDATED; + switch (wiphy->signal_type) { + case CFG80211_SIGNAL_TYPE_MBM: + sig = bss->pub.signal / 100; + iwe.u.qual.level = sig; + iwe.u.qual.updated |= IW_QUAL_DBM; + if (sig < -110) /* rather bad */ + sig = -110; + else if (sig > -40) /* perfect */ + sig = -40; + /* will give a range of 0 .. 70 */ + iwe.u.qual.qual = sig + 110; + break; + case CFG80211_SIGNAL_TYPE_UNSPEC: + iwe.u.qual.level = bss->pub.signal; + /* will give range 0 .. 
100 */ + iwe.u.qual.qual = bss->pub.signal; + break; + default: + /* not reached */ + break; + } + current_ev = iwe_stream_add_event_check(info, current_ev, + end_buf, &iwe, + IW_EV_QUAL_LEN); + if (IS_ERR(current_ev)) + return current_ev; + } + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWENCODE; + if (bss->pub.capability & WLAN_CAPABILITY_PRIVACY) + iwe.u.data.flags = IW_ENCODE_ENABLED | IW_ENCODE_NOKEY; + else + iwe.u.data.flags = IW_ENCODE_DISABLED; + iwe.u.data.length = 0; + current_ev = iwe_stream_add_point_check(info, current_ev, end_buf, + &iwe, ""); + if (IS_ERR(current_ev)) + return current_ev; + + rcu_read_lock(); + ies = rcu_dereference(bss->pub.ies); + rem = ies->len; + ie = ies->data; + + while (rem >= 2) { + /* invalid data */ + if (ie[1] > rem - 2) + break; + + switch (ie[0]) { + case WLAN_EID_SSID: + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWESSID; + iwe.u.data.length = ie[1]; + iwe.u.data.flags = 1; + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, &iwe, + (u8 *)ie + 2); + if (IS_ERR(current_ev)) + goto unlock; + break; + case WLAN_EID_MESH_ID: + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWESSID; + iwe.u.data.length = ie[1]; + iwe.u.data.flags = 1; + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, &iwe, + (u8 *)ie + 2); + if (IS_ERR(current_ev)) + goto unlock; + break; + case WLAN_EID_MESH_CONFIG: + ismesh = true; + if (ie[1] != sizeof(struct ieee80211_meshconf_ie)) + break; + cfg = (u8 *)ie + 2; + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVCUSTOM; + sprintf(buf, "Mesh Network Path Selection Protocol ID: " + "0x%02X", cfg[0]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Path Selection Metric ID: 0x%02X", + cfg[1]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Congestion Control Mode ID: 0x%02X", + cfg[2]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Synchronization ID: 0x%02X", cfg[3]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Authentication ID: 0x%02X", cfg[4]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Formation Info: 0x%02X", cfg[5]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + sprintf(buf, "Capabilities: 0x%02X", cfg[6]); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, + current_ev, + end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + break; + case WLAN_EID_SUPP_RATES: + case WLAN_EID_EXT_SUPP_RATES: + /* display all supported rates in readable format */ + p = current_ev + iwe_stream_lcp_len(info); + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWRATE; + /* Those two flags are ignored... 
*/ + iwe.u.bitrate.fixed = iwe.u.bitrate.disabled = 0; + + for (i = 0; i < ie[1]; i++) { + iwe.u.bitrate.value = + ((ie[i + 2] & 0x7f) * 500000); + tmp = p; + p = iwe_stream_add_value(info, current_ev, p, + end_buf, &iwe, + IW_EV_PARAM_LEN); + if (p == tmp) { + current_ev = ERR_PTR(-E2BIG); + goto unlock; + } + } + current_ev = p; + break; + } + rem -= ie[1] + 2; + ie += ie[1] + 2; + } + + if (bss->pub.capability & (WLAN_CAPABILITY_ESS | WLAN_CAPABILITY_IBSS) || + ismesh) { + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = SIOCGIWMODE; + if (ismesh) + iwe.u.mode = IW_MODE_MESH; + else if (bss->pub.capability & WLAN_CAPABILITY_ESS) + iwe.u.mode = IW_MODE_MASTER; + else + iwe.u.mode = IW_MODE_ADHOC; + current_ev = iwe_stream_add_event_check(info, current_ev, + end_buf, &iwe, + IW_EV_UINT_LEN); + if (IS_ERR(current_ev)) + goto unlock; + } + + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVCUSTOM; + sprintf(buf, "tsf=%016llx", (unsigned long long)(ies->tsf)); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, current_ev, end_buf, + &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + memset(&iwe, 0, sizeof(iwe)); + iwe.cmd = IWEVCUSTOM; + sprintf(buf, " Last beacon: %ums ago", + elapsed_jiffies_msecs(bss->ts)); + iwe.u.data.length = strlen(buf); + current_ev = iwe_stream_add_point_check(info, current_ev, + end_buf, &iwe, buf); + if (IS_ERR(current_ev)) + goto unlock; + + current_ev = ieee80211_scan_add_ies(info, ies, current_ev, end_buf); + + unlock: + rcu_read_unlock(); + return current_ev; +} + + +static int ieee80211_scan_results(struct cfg80211_registered_device *rdev, + struct iw_request_info *info, + char *buf, size_t len) +{ + char *current_ev = buf; + char *end_buf = buf + len; + struct cfg80211_internal_bss *bss; + int err = 0; + + spin_lock_bh(&rdev->bss_lock); + cfg80211_bss_expire(rdev); + + list_for_each_entry(bss, &rdev->bss_list, list) { + if (buf + len - current_ev <= IW_EV_ADDR_LEN) { + err = -E2BIG; + break; + } + current_ev = ieee80211_bss(&rdev->wiphy, info, bss, + current_ev, end_buf); + if (IS_ERR(current_ev)) { + err = PTR_ERR(current_ev); + break; + } + } + spin_unlock_bh(&rdev->bss_lock); + + if (err) + return err; + return current_ev - buf; +} + + +int cfg80211_wext_giwscan(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra) +{ + struct cfg80211_registered_device *rdev; + int res; + + if (!netif_running(dev)) + return -ENETDOWN; + + rdev = cfg80211_get_dev_from_ifindex(dev_net(dev), dev->ifindex); + + if (IS_ERR(rdev)) + return PTR_ERR(rdev); + + if (rdev->scan_req || rdev->scan_msg) + return -EAGAIN; + + res = ieee80211_scan_results(rdev, info, extra, data->length); + data->length = 0; + if (res >= 0) { + data->length = res; + res = 0; + } + + return res; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwscan); +#endif diff --git a/net/wireless/sme.c b/net/wireless/sme.c new file mode 100644 index 000000000..b97834284 --- /dev/null +++ b/net/wireless/sme.c @@ -0,0 +1,1598 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * SME code for cfg80211 + * both driver SME event handling and the SME implementation + * (for nl80211's connect() and wext) + * + * Copyright 2009 Johannes Berg + * Copyright (C) 2009, 2020, 2022-2023 Intel Corporation. All rights reserved. 
+ * Copyright 2017 Intel Deutschland GmbH + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "nl80211.h" +#include "reg.h" +#include "rdev-ops.h" + +/* + * Software SME in cfg80211, using auth/assoc/deauth calls to the + * driver. This is for implementing nl80211's connect/disconnect + * and wireless extensions (if configured.) + */ + +struct cfg80211_conn { + struct cfg80211_connect_params params; + /* these are sub-states of the _CONNECTING sme_state */ + enum { + CFG80211_CONN_SCANNING, + CFG80211_CONN_SCAN_AGAIN, + CFG80211_CONN_AUTHENTICATE_NEXT, + CFG80211_CONN_AUTHENTICATING, + CFG80211_CONN_AUTH_FAILED_TIMEOUT, + CFG80211_CONN_ASSOCIATE_NEXT, + CFG80211_CONN_ASSOCIATING, + CFG80211_CONN_ASSOC_FAILED, + CFG80211_CONN_ASSOC_FAILED_TIMEOUT, + CFG80211_CONN_DEAUTH, + CFG80211_CONN_ABANDON, + CFG80211_CONN_CONNECTED, + } state; + u8 bssid[ETH_ALEN], prev_bssid[ETH_ALEN]; + const u8 *ie; + size_t ie_len; + bool auto_auth, prev_bssid_valid; +}; + +static void cfg80211_sme_free(struct wireless_dev *wdev) +{ + if (!wdev->conn) + return; + + kfree(wdev->conn->ie); + kfree(wdev->conn); + wdev->conn = NULL; +} + +static int cfg80211_conn_scan(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_scan_request *request; + int n_channels, err; + + ASSERT_WDEV_LOCK(wdev); + + if (rdev->scan_req || rdev->scan_msg) + return -EBUSY; + + if (wdev->conn->params.channel) + n_channels = 1; + else + n_channels = ieee80211_get_num_supported_channels(wdev->wiphy); + + request = kzalloc(sizeof(*request) + sizeof(request->ssids[0]) + + sizeof(request->channels[0]) * n_channels, + GFP_KERNEL); + if (!request) + return -ENOMEM; + + if (wdev->conn->params.channel) { + enum nl80211_band band = wdev->conn->params.channel->band; + struct ieee80211_supported_band *sband = + wdev->wiphy->bands[band]; + + if (!sband) { + kfree(request); + return -EINVAL; + } + request->channels[0] = wdev->conn->params.channel; + request->rates[band] = (1 << sband->n_bitrates) - 1; + } else { + int i = 0, j; + enum nl80211_band band; + struct ieee80211_supported_band *bands; + struct ieee80211_channel *channel; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + bands = wdev->wiphy->bands[band]; + if (!bands) + continue; + for (j = 0; j < bands->n_channels; j++) { + channel = &bands->channels[j]; + if (channel->flags & IEEE80211_CHAN_DISABLED) + continue; + request->channels[i++] = channel; + } + request->rates[band] = (1 << bands->n_bitrates) - 1; + } + n_channels = i; + } + request->n_channels = n_channels; + request->ssids = (void *)&request->channels[n_channels]; + request->n_ssids = 1; + + memcpy(request->ssids[0].ssid, wdev->conn->params.ssid, + wdev->conn->params.ssid_len); + request->ssids[0].ssid_len = wdev->conn->params.ssid_len; + + eth_broadcast_addr(request->bssid); + + request->wdev = wdev; + request->wiphy = &rdev->wiphy; + request->scan_start = jiffies; + + rdev->scan_req = request; + + err = rdev_scan(rdev, request); + if (!err) { + wdev->conn->state = CFG80211_CONN_SCANNING; + nl80211_send_scan_start(rdev, wdev); + dev_hold(wdev->netdev); + } else { + rdev->scan_req = NULL; + kfree(request); + } + return err; +} + +static int cfg80211_conn_do_work(struct wireless_dev *wdev, + enum nl80211_timeout_reason *treason) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_connect_params *params; + struct cfg80211_auth_request auth_req = {}; + struct 
cfg80211_assoc_request req = {}; + int err; + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->conn) + return 0; + + params = &wdev->conn->params; + + switch (wdev->conn->state) { + case CFG80211_CONN_SCANNING: + /* didn't find it during scan ... */ + return -ENOENT; + case CFG80211_CONN_SCAN_AGAIN: + return cfg80211_conn_scan(wdev); + case CFG80211_CONN_AUTHENTICATE_NEXT: + if (WARN_ON(!rdev->ops->auth)) + return -EOPNOTSUPP; + wdev->conn->state = CFG80211_CONN_AUTHENTICATING; + auth_req.key = params->key; + auth_req.key_len = params->key_len; + auth_req.key_idx = params->key_idx; + auth_req.auth_type = params->auth_type; + auth_req.bss = cfg80211_get_bss(&rdev->wiphy, params->channel, + params->bssid, + params->ssid, params->ssid_len, + IEEE80211_BSS_TYPE_ESS, + IEEE80211_PRIVACY_ANY); + auth_req.link_id = -1; + err = cfg80211_mlme_auth(rdev, wdev->netdev, &auth_req); + cfg80211_put_bss(&rdev->wiphy, auth_req.bss); + return err; + case CFG80211_CONN_AUTH_FAILED_TIMEOUT: + *treason = NL80211_TIMEOUT_AUTH; + return -ENOTCONN; + case CFG80211_CONN_ASSOCIATE_NEXT: + if (WARN_ON(!rdev->ops->assoc)) + return -EOPNOTSUPP; + wdev->conn->state = CFG80211_CONN_ASSOCIATING; + if (wdev->conn->prev_bssid_valid) + req.prev_bssid = wdev->conn->prev_bssid; + req.ie = params->ie; + req.ie_len = params->ie_len; + req.use_mfp = params->mfp != NL80211_MFP_NO; + req.crypto = params->crypto; + req.flags = params->flags; + req.ht_capa = params->ht_capa; + req.ht_capa_mask = params->ht_capa_mask; + req.vht_capa = params->vht_capa; + req.vht_capa_mask = params->vht_capa_mask; + req.link_id = -1; + + req.bss = cfg80211_get_bss(&rdev->wiphy, params->channel, + params->bssid, + params->ssid, params->ssid_len, + IEEE80211_BSS_TYPE_ESS, + IEEE80211_PRIVACY_ANY); + if (!req.bss) { + err = -ENOENT; + } else { + err = cfg80211_mlme_assoc(rdev, wdev->netdev, &req); + cfg80211_put_bss(&rdev->wiphy, req.bss); + } + + if (err) + cfg80211_mlme_deauth(rdev, wdev->netdev, params->bssid, + NULL, 0, + WLAN_REASON_DEAUTH_LEAVING, + false); + return err; + case CFG80211_CONN_ASSOC_FAILED_TIMEOUT: + *treason = NL80211_TIMEOUT_ASSOC; + fallthrough; + case CFG80211_CONN_ASSOC_FAILED: + cfg80211_mlme_deauth(rdev, wdev->netdev, params->bssid, + NULL, 0, + WLAN_REASON_DEAUTH_LEAVING, false); + return -ENOTCONN; + case CFG80211_CONN_DEAUTH: + cfg80211_mlme_deauth(rdev, wdev->netdev, params->bssid, + NULL, 0, + WLAN_REASON_DEAUTH_LEAVING, false); + fallthrough; + case CFG80211_CONN_ABANDON: + /* free directly, disconnected event already sent */ + cfg80211_sme_free(wdev); + return 0; + default: + return 0; + } +} + +void cfg80211_conn_work(struct work_struct *work) +{ + struct cfg80211_registered_device *rdev = + container_of(work, struct cfg80211_registered_device, conn_work); + struct wireless_dev *wdev; + u8 bssid_buf[ETH_ALEN], *bssid = NULL; + enum nl80211_timeout_reason treason; + + wiphy_lock(&rdev->wiphy); + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + if (!wdev->netdev) + continue; + + wdev_lock(wdev); + if (!netif_running(wdev->netdev)) { + wdev_unlock(wdev); + continue; + } + if (!wdev->conn || + wdev->conn->state == CFG80211_CONN_CONNECTED) { + wdev_unlock(wdev); + continue; + } + if (wdev->conn->params.bssid) { + memcpy(bssid_buf, wdev->conn->params.bssid, ETH_ALEN); + bssid = bssid_buf; + } + treason = NL80211_TIMEOUT_UNSPECIFIED; + if (cfg80211_conn_do_work(wdev, &treason)) { + struct cfg80211_connect_resp_params cr; + + memset(&cr, 0, sizeof(cr)); + cr.status = -1; + cr.links[0].bssid = bssid; + cr.timeout_reason = 
treason; + __cfg80211_connect_result(wdev->netdev, &cr, false); + } + wdev_unlock(wdev); + } + + wiphy_unlock(&rdev->wiphy); +} + +static void cfg80211_step_auth_next(struct cfg80211_conn *conn, + struct cfg80211_bss *bss) +{ + memcpy(conn->bssid, bss->bssid, ETH_ALEN); + conn->params.bssid = conn->bssid; + conn->params.channel = bss->channel; + conn->state = CFG80211_CONN_AUTHENTICATE_NEXT; +} + +/* Returned bss is reference counted and must be cleaned up appropriately. */ +static struct cfg80211_bss *cfg80211_get_conn_bss(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_bss *bss; + + ASSERT_WDEV_LOCK(wdev); + + bss = cfg80211_get_bss(wdev->wiphy, wdev->conn->params.channel, + wdev->conn->params.bssid, + wdev->conn->params.ssid, + wdev->conn->params.ssid_len, + wdev->conn_bss_type, + IEEE80211_PRIVACY(wdev->conn->params.privacy)); + if (!bss) + return NULL; + + cfg80211_step_auth_next(wdev->conn, bss); + schedule_work(&rdev->conn_work); + + return bss; +} + +static void __cfg80211_sme_scan_done(struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_bss *bss; + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->conn) + return; + + if (wdev->conn->state != CFG80211_CONN_SCANNING && + wdev->conn->state != CFG80211_CONN_SCAN_AGAIN) + return; + + bss = cfg80211_get_conn_bss(wdev); + if (bss) + cfg80211_put_bss(&rdev->wiphy, bss); + else + schedule_work(&rdev->conn_work); +} + +void cfg80211_sme_scan_done(struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + wdev_lock(wdev); + __cfg80211_sme_scan_done(dev); + wdev_unlock(wdev); +} + +void cfg80211_sme_rx_auth(struct wireless_dev *wdev, const u8 *buf, size_t len) +{ + struct wiphy *wiphy = wdev->wiphy; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wiphy); + struct ieee80211_mgmt *mgmt = (struct ieee80211_mgmt *)buf; + u16 status_code = le16_to_cpu(mgmt->u.auth.status_code); + + ASSERT_WDEV_LOCK(wdev); + + if (!wdev->conn || wdev->conn->state == CFG80211_CONN_CONNECTED) + return; + + if (status_code == WLAN_STATUS_NOT_SUPPORTED_AUTH_ALG && + wdev->conn->auto_auth && + wdev->conn->params.auth_type != NL80211_AUTHTYPE_NETWORK_EAP) { + /* select automatically between only open, shared, leap */ + switch (wdev->conn->params.auth_type) { + case NL80211_AUTHTYPE_OPEN_SYSTEM: + if (wdev->connect_keys) + wdev->conn->params.auth_type = + NL80211_AUTHTYPE_SHARED_KEY; + else + wdev->conn->params.auth_type = + NL80211_AUTHTYPE_NETWORK_EAP; + break; + case NL80211_AUTHTYPE_SHARED_KEY: + wdev->conn->params.auth_type = + NL80211_AUTHTYPE_NETWORK_EAP; + break; + default: + /* huh? 
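+ * unexpected auth_type - start the
+ * auto-auth cycle over from open system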
*/ + wdev->conn->params.auth_type = + NL80211_AUTHTYPE_OPEN_SYSTEM; + break; + } + wdev->conn->state = CFG80211_CONN_AUTHENTICATE_NEXT; + schedule_work(&rdev->conn_work); + } else if (status_code != WLAN_STATUS_SUCCESS) { + struct cfg80211_connect_resp_params cr; + + memset(&cr, 0, sizeof(cr)); + cr.status = status_code; + cr.links[0].bssid = mgmt->bssid; + cr.timeout_reason = NL80211_TIMEOUT_UNSPECIFIED; + __cfg80211_connect_result(wdev->netdev, &cr, false); + } else if (wdev->conn->state == CFG80211_CONN_AUTHENTICATING) { + wdev->conn->state = CFG80211_CONN_ASSOCIATE_NEXT; + schedule_work(&rdev->conn_work); + } +} + +bool cfg80211_sme_rx_assoc_resp(struct wireless_dev *wdev, u16 status) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + if (!wdev->conn) + return false; + + if (status == WLAN_STATUS_SUCCESS) { + wdev->conn->state = CFG80211_CONN_CONNECTED; + return false; + } + + if (wdev->conn->prev_bssid_valid) { + /* + * Some stupid APs don't accept reassoc, so we + * need to fall back to trying regular assoc; + * return true so no event is sent to userspace. + */ + wdev->conn->prev_bssid_valid = false; + wdev->conn->state = CFG80211_CONN_ASSOCIATE_NEXT; + schedule_work(&rdev->conn_work); + return true; + } + + wdev->conn->state = CFG80211_CONN_ASSOC_FAILED; + schedule_work(&rdev->conn_work); + return false; +} + +void cfg80211_sme_deauth(struct wireless_dev *wdev) +{ + cfg80211_sme_free(wdev); +} + +void cfg80211_sme_auth_timeout(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + if (!wdev->conn) + return; + + wdev->conn->state = CFG80211_CONN_AUTH_FAILED_TIMEOUT; + schedule_work(&rdev->conn_work); +} + +void cfg80211_sme_disassoc(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + if (!wdev->conn) + return; + + wdev->conn->state = CFG80211_CONN_DEAUTH; + schedule_work(&rdev->conn_work); +} + +void cfg80211_sme_assoc_timeout(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + if (!wdev->conn) + return; + + wdev->conn->state = CFG80211_CONN_ASSOC_FAILED_TIMEOUT; + schedule_work(&rdev->conn_work); +} + +void cfg80211_sme_abandon_assoc(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + if (!wdev->conn) + return; + + wdev->conn->state = CFG80211_CONN_ABANDON; + schedule_work(&rdev->conn_work); +} + +static void cfg80211_wdev_release_bsses(struct wireless_dev *wdev) +{ + unsigned int link; + + for_each_valid_link(wdev, link) { + if (!wdev->links[link].client.current_bss) + continue; + cfg80211_unhold_bss(wdev->links[link].client.current_bss); + cfg80211_put_bss(wdev->wiphy, + &wdev->links[link].client.current_bss->pub); + wdev->links[link].client.current_bss = NULL; + } +} + +static int cfg80211_sme_get_conn_ies(struct wireless_dev *wdev, + const u8 *ies, size_t ies_len, + const u8 **out_ies, size_t *out_ies_len) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u8 *buf; + size_t offs; + + if (!rdev->wiphy.extended_capabilities_len || + (ies && cfg80211_find_ie(WLAN_EID_EXT_CAPABILITY, ies, ies_len))) { + *out_ies = kmemdup(ies, ies_len, GFP_KERNEL); + if (!*out_ies) + return -ENOMEM; + *out_ies_len = ies_len; + return 0; + } + + buf = kmalloc(ies_len + rdev->wiphy.extended_capabilities_len + 2, + GFP_KERNEL); + if (!buf) + return -ENOMEM; + + if (ies_len) { + static const u8 before_extcapa[] = { + /* not listing IEs 
expected to be created by driver */ + WLAN_EID_RSN, + WLAN_EID_QOS_CAPA, + WLAN_EID_RRM_ENABLED_CAPABILITIES, + WLAN_EID_MOBILITY_DOMAIN, + WLAN_EID_SUPPORTED_REGULATORY_CLASSES, + WLAN_EID_BSS_COEX_2040, + }; + + offs = ieee80211_ie_split(ies, ies_len, before_extcapa, + ARRAY_SIZE(before_extcapa), 0); + memcpy(buf, ies, offs); + /* leave a whole for extended capabilities IE */ + memcpy(buf + offs + rdev->wiphy.extended_capabilities_len + 2, + ies + offs, ies_len - offs); + } else { + offs = 0; + } + + /* place extended capabilities IE (with only driver capabilities) */ + buf[offs] = WLAN_EID_EXT_CAPABILITY; + buf[offs + 1] = rdev->wiphy.extended_capabilities_len; + memcpy(buf + offs + 2, + rdev->wiphy.extended_capabilities, + rdev->wiphy.extended_capabilities_len); + + *out_ies = buf; + *out_ies_len = ies_len + rdev->wiphy.extended_capabilities_len + 2; + + return 0; +} + +static int cfg80211_sme_connect(struct wireless_dev *wdev, + struct cfg80211_connect_params *connect, + const u8 *prev_bssid) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_bss *bss; + int err; + + if (!rdev->ops->auth || !rdev->ops->assoc) + return -EOPNOTSUPP; + + cfg80211_wdev_release_bsses(wdev); + + if (wdev->connected) { + cfg80211_sme_free(wdev); + wdev->connected = false; + } + + if (wdev->conn) + return -EINPROGRESS; + + wdev->conn = kzalloc(sizeof(*wdev->conn), GFP_KERNEL); + if (!wdev->conn) + return -ENOMEM; + + /* + * Copy all parameters, and treat explicitly IEs, BSSID, SSID. + */ + memcpy(&wdev->conn->params, connect, sizeof(*connect)); + if (connect->bssid) { + wdev->conn->params.bssid = wdev->conn->bssid; + memcpy(wdev->conn->bssid, connect->bssid, ETH_ALEN); + } + + if (cfg80211_sme_get_conn_ies(wdev, connect->ie, connect->ie_len, + &wdev->conn->ie, + &wdev->conn->params.ie_len)) { + kfree(wdev->conn); + wdev->conn = NULL; + return -ENOMEM; + } + wdev->conn->params.ie = wdev->conn->ie; + + if (connect->auth_type == NL80211_AUTHTYPE_AUTOMATIC) { + wdev->conn->auto_auth = true; + /* start with open system ... should mostly work */ + wdev->conn->params.auth_type = + NL80211_AUTHTYPE_OPEN_SYSTEM; + } else { + wdev->conn->auto_auth = false; + } + + wdev->conn->params.ssid = wdev->u.client.ssid; + wdev->conn->params.ssid_len = wdev->u.client.ssid_len; + + /* see if we have the bss already */ + bss = cfg80211_get_bss(wdev->wiphy, wdev->conn->params.channel, + wdev->conn->params.bssid, + wdev->conn->params.ssid, + wdev->conn->params.ssid_len, + wdev->conn_bss_type, + IEEE80211_PRIVACY(wdev->conn->params.privacy)); + + if (prev_bssid) { + memcpy(wdev->conn->prev_bssid, prev_bssid, ETH_ALEN); + wdev->conn->prev_bssid_valid = true; + } + + /* we're good if we have a matching bss struct */ + if (bss) { + enum nl80211_timeout_reason treason; + + cfg80211_step_auth_next(wdev->conn, bss); + err = cfg80211_conn_do_work(wdev, &treason); + cfg80211_put_bss(wdev->wiphy, bss); + } else { + /* otherwise we'll need to scan for the AP first */ + err = cfg80211_conn_scan(wdev); + + /* + * If we can't scan right now, then we need to scan again + * after the current scan finished, since the parameters + * changed (unless we find a good AP anyway). 
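+ *
+ * In the non-error case the connection then walks through the
+ * sub-states roughly as follows (sketch derived from the handlers
+ * above, not an exhaustive diagram):
+ *
+ *	SCANNING / SCAN_AGAIN
+ *	  -> AUTHENTICATE_NEXT -> AUTHENTICATING
+ *	  -> ASSOCIATE_NEXT    -> ASSOCIATING
+ *	  -> CONNECTED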
+ */ + if (err == -EBUSY) { + err = 0; + wdev->conn->state = CFG80211_CONN_SCAN_AGAIN; + } + } + + if (err) + cfg80211_sme_free(wdev); + + return err; +} + +static int cfg80211_sme_disconnect(struct wireless_dev *wdev, u16 reason) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int err; + + if (!wdev->conn) + return 0; + + if (!rdev->ops->deauth) + return -EOPNOTSUPP; + + if (wdev->conn->state == CFG80211_CONN_SCANNING || + wdev->conn->state == CFG80211_CONN_SCAN_AGAIN) { + err = 0; + goto out; + } + + /* wdev->conn->params.bssid must be set if > SCANNING */ + err = cfg80211_mlme_deauth(rdev, wdev->netdev, + wdev->conn->params.bssid, + NULL, 0, reason, false); + out: + cfg80211_sme_free(wdev); + return err; +} + +/* + * code shared for in-device and software SME + */ + +static bool cfg80211_is_all_idle(void) +{ + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + bool is_all_idle = true; + + /* + * All devices must be idle as otherwise if you are actively + * scanning some new beacon hints could be learned and would + * count as new regulatory hints. + * Also if there is any other active beaconing interface we + * need not issue a disconnect hint and reset any info such + * as chan dfs state, etc. + */ + list_for_each_entry(rdev, &cfg80211_rdev_list, list) { + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) { + wdev_lock(wdev); + if (wdev->conn || wdev->connected || + cfg80211_beaconing_iface_active(wdev)) + is_all_idle = false; + wdev_unlock(wdev); + } + } + + return is_all_idle; +} + +static void disconnect_work(struct work_struct *work) +{ + rtnl_lock(); + if (cfg80211_is_all_idle()) + regulatory_hint_disconnect(); + rtnl_unlock(); +} + +DECLARE_WORK(cfg80211_disconnect_work, disconnect_work); + +static void +cfg80211_connect_result_release_bsses(struct wireless_dev *wdev, + struct cfg80211_connect_resp_params *cr) +{ + unsigned int link; + + for_each_valid_link(cr, link) { + if (!cr->links[link].bss) + continue; + cfg80211_unhold_bss(bss_from_pub(cr->links[link].bss)); + cfg80211_put_bss(wdev->wiphy, cr->links[link].bss); + } +} + +/* + * API calls for drivers implementing connect/disconnect and + * SME event handling + */ + +/* This method must consume bss one way or another */ +void __cfg80211_connect_result(struct net_device *dev, + struct cfg80211_connect_resp_params *cr, + bool wextev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + const struct element *country_elem = NULL; + const struct element *ssid; + const u8 *country_data; + u8 country_datalen; +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; +#endif + unsigned int link; + const u8 *connected_addr; + bool bss_not_found = false; + + ASSERT_WDEV_LOCK(wdev); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT)) + goto out; + + if (cr->valid_links) { + if (WARN_ON(!cr->ap_mld_addr)) + goto out; + + for_each_valid_link(cr, link) { + if (WARN_ON(!cr->links[link].addr)) + goto out; + } + + if (WARN_ON(wdev->connect_keys)) + goto out; + } + + wdev->unprot_beacon_reported = 0; + nl80211_send_connect_result(wiphy_to_rdev(wdev->wiphy), dev, cr, + GFP_KERNEL); + connected_addr = cr->valid_links ? 
cr->ap_mld_addr : cr->links[0].bssid; + +#ifdef CONFIG_CFG80211_WEXT + if (wextev && !cr->valid_links) { + if (cr->req_ie && cr->status == WLAN_STATUS_SUCCESS) { + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = cr->req_ie_len; + wireless_send_event(dev, IWEVASSOCREQIE, &wrqu, + cr->req_ie); + } + + if (cr->resp_ie && cr->status == WLAN_STATUS_SUCCESS) { + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = cr->resp_ie_len; + wireless_send_event(dev, IWEVASSOCRESPIE, &wrqu, + cr->resp_ie); + } + + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.ap_addr.sa_family = ARPHRD_ETHER; + if (connected_addr && cr->status == WLAN_STATUS_SUCCESS) { + memcpy(wrqu.ap_addr.sa_data, connected_addr, ETH_ALEN); + memcpy(wdev->wext.prev_bssid, connected_addr, ETH_ALEN); + wdev->wext.prev_bssid_valid = true; + } + wireless_send_event(dev, SIOCGIWAP, &wrqu, NULL); + } +#endif + + if (cr->status == WLAN_STATUS_SUCCESS) { + if (!wiphy_to_rdev(wdev->wiphy)->ops->connect) { + for_each_valid_link(cr, link) { + if (WARN_ON_ONCE(!cr->links[link].bss)) + break; + } + } + + for_each_valid_link(cr, link) { + if (cr->links[link].bss) + continue; + + cr->links[link].bss = + cfg80211_get_bss(wdev->wiphy, NULL, + cr->links[link].bssid, + wdev->u.client.ssid, + wdev->u.client.ssid_len, + wdev->conn_bss_type, + IEEE80211_PRIVACY_ANY); + if (!cr->links[link].bss) { + bss_not_found = true; + break; + } + cfg80211_hold_bss(bss_from_pub(cr->links[link].bss)); + } + } + + cfg80211_wdev_release_bsses(wdev); + + if (cr->status != WLAN_STATUS_SUCCESS) { + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = NULL; + wdev->u.client.ssid_len = 0; + wdev->conn_owner_nlportid = 0; + cfg80211_connect_result_release_bsses(wdev, cr); + cfg80211_sme_free(wdev); + return; + } + + if (WARN_ON(bss_not_found)) { + cfg80211_connect_result_release_bsses(wdev, cr); + return; + } + + memset(wdev->links, 0, sizeof(wdev->links)); + wdev->valid_links = cr->valid_links; + for_each_valid_link(cr, link) + wdev->links[link].client.current_bss = + bss_from_pub(cr->links[link].bss); + wdev->connected = true; + ether_addr_copy(wdev->u.client.connected_addr, connected_addr); + if (cr->valid_links) { + for_each_valid_link(cr, link) + memcpy(wdev->links[link].addr, cr->links[link].addr, + ETH_ALEN); + } + + if (!(wdev->wiphy->flags & WIPHY_FLAG_HAS_STATIC_WEP)) + cfg80211_upload_connect_keys(wdev); + + rcu_read_lock(); + for_each_valid_link(cr, link) { + country_elem = + ieee80211_bss_get_elem(cr->links[link].bss, + WLAN_EID_COUNTRY); + if (country_elem) + break; + } + if (!country_elem) { + rcu_read_unlock(); + return; + } + + country_datalen = country_elem->datalen; + country_data = kmemdup(country_elem->data, country_datalen, GFP_ATOMIC); + rcu_read_unlock(); + + if (!country_data) + return; + + regulatory_hint_country_ie(wdev->wiphy, + cr->links[link].bss->channel->band, + country_data, country_datalen); + kfree(country_data); + + if (!wdev->u.client.ssid_len) { + rcu_read_lock(); + for_each_valid_link(cr, link) { + ssid = ieee80211_bss_get_elem(cr->links[link].bss, + WLAN_EID_SSID); + + if (!ssid || !ssid->datalen) + continue; + + memcpy(wdev->u.client.ssid, ssid->data, ssid->datalen); + wdev->u.client.ssid_len = ssid->datalen; + break; + } + rcu_read_unlock(); + } + + return; +out: + for_each_valid_link(cr, link) + cfg80211_put_bss(wdev->wiphy, cr->links[link].bss); +} + +static void cfg80211_update_link_bss(struct wireless_dev *wdev, + struct cfg80211_bss **bss) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct 
cfg80211_internal_bss *ibss; + + if (!*bss) + return; + + ibss = bss_from_pub(*bss); + if (list_empty(&ibss->list)) { + struct cfg80211_bss *found = NULL, *tmp = *bss; + + found = cfg80211_get_bss(wdev->wiphy, NULL, + (*bss)->bssid, + wdev->u.client.ssid, + wdev->u.client.ssid_len, + wdev->conn_bss_type, + IEEE80211_PRIVACY_ANY); + if (found) { + /* The same BSS is already updated so use it + * instead, as it has latest info. + */ + *bss = found; + } else { + /* Update with BSS provided by driver, it will + * be freshly added and ref cnted, we can free + * the old one. + * + * signal_valid can be false, as we are not + * expecting the BSS to be found. + * + * keep the old timestamp to avoid confusion + */ + cfg80211_bss_update(rdev, ibss, false, + ibss->ts); + } + + cfg80211_put_bss(wdev->wiphy, tmp); + } +} + +/* Consumes bss object(s) one way or another */ +void cfg80211_connect_done(struct net_device *dev, + struct cfg80211_connect_resp_params *params, + gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_event *ev; + unsigned long flags; + u8 *next; + size_t link_info_size = 0; + unsigned int link; + + for_each_valid_link(params, link) { + cfg80211_update_link_bss(wdev, ¶ms->links[link].bss); + link_info_size += params->links[link].bssid ? ETH_ALEN : 0; + link_info_size += params->links[link].addr ? ETH_ALEN : 0; + } + + ev = kzalloc(sizeof(*ev) + (params->ap_mld_addr ? ETH_ALEN : 0) + + params->req_ie_len + params->resp_ie_len + + params->fils.kek_len + params->fils.pmk_len + + (params->fils.pmkid ? WLAN_PMKID_LEN : 0) + link_info_size, + gfp); + + if (!ev) { + for_each_valid_link(params, link) + cfg80211_put_bss(wdev->wiphy, + params->links[link].bss); + return; + } + + ev->type = EVENT_CONNECT_RESULT; + next = ((u8 *)ev) + sizeof(*ev); + if (params->ap_mld_addr) { + ev->cr.ap_mld_addr = next; + memcpy((void *)ev->cr.ap_mld_addr, params->ap_mld_addr, + ETH_ALEN); + next += ETH_ALEN; + } + if (params->req_ie_len) { + ev->cr.req_ie = next; + ev->cr.req_ie_len = params->req_ie_len; + memcpy((void *)ev->cr.req_ie, params->req_ie, + params->req_ie_len); + next += params->req_ie_len; + } + if (params->resp_ie_len) { + ev->cr.resp_ie = next; + ev->cr.resp_ie_len = params->resp_ie_len; + memcpy((void *)ev->cr.resp_ie, params->resp_ie, + params->resp_ie_len); + next += params->resp_ie_len; + } + if (params->fils.kek_len) { + ev->cr.fils.kek = next; + ev->cr.fils.kek_len = params->fils.kek_len; + memcpy((void *)ev->cr.fils.kek, params->fils.kek, + params->fils.kek_len); + next += params->fils.kek_len; + } + if (params->fils.pmk_len) { + ev->cr.fils.pmk = next; + ev->cr.fils.pmk_len = params->fils.pmk_len; + memcpy((void *)ev->cr.fils.pmk, params->fils.pmk, + params->fils.pmk_len); + next += params->fils.pmk_len; + } + if (params->fils.pmkid) { + ev->cr.fils.pmkid = next; + memcpy((void *)ev->cr.fils.pmkid, params->fils.pmkid, + WLAN_PMKID_LEN); + next += WLAN_PMKID_LEN; + } + ev->cr.fils.update_erp_next_seq_num = params->fils.update_erp_next_seq_num; + if (params->fils.update_erp_next_seq_num) + ev->cr.fils.erp_next_seq_num = params->fils.erp_next_seq_num; + ev->cr.valid_links = params->valid_links; + for_each_valid_link(params, link) { + if (params->links[link].bss) + cfg80211_hold_bss( + bss_from_pub(params->links[link].bss)); + ev->cr.links[link].bss = params->links[link].bss; + + if (params->links[link].addr) { + ev->cr.links[link].addr = next; + memcpy((void *)ev->cr.links[link].addr, + 
params->links[link].addr, + ETH_ALEN); + next += ETH_ALEN; + } + if (params->links[link].bssid) { + ev->cr.links[link].bssid = next; + memcpy((void *)ev->cr.links[link].bssid, + params->links[link].bssid, + ETH_ALEN); + next += ETH_ALEN; + } + } + ev->cr.status = params->status; + ev->cr.timeout_reason = params->timeout_reason; + + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); +} +EXPORT_SYMBOL(cfg80211_connect_done); + +/* Consumes bss object one way or another */ +void __cfg80211_roamed(struct wireless_dev *wdev, + struct cfg80211_roam_info *info) +{ +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; +#endif + unsigned int link; + const u8 *connected_addr; + + ASSERT_WDEV_LOCK(wdev); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT)) + goto out; + + if (WARN_ON(!wdev->connected)) + goto out; + + if (info->valid_links) { + if (WARN_ON(!info->ap_mld_addr)) + goto out; + + for_each_valid_link(info, link) { + if (WARN_ON(!info->links[link].addr)) + goto out; + } + } + + cfg80211_wdev_release_bsses(wdev); + + for_each_valid_link(info, link) { + if (WARN_ON(!info->links[link].bss)) + goto out; + } + + memset(wdev->links, 0, sizeof(wdev->links)); + wdev->valid_links = info->valid_links; + for_each_valid_link(info, link) { + cfg80211_hold_bss(bss_from_pub(info->links[link].bss)); + wdev->links[link].client.current_bss = + bss_from_pub(info->links[link].bss); + } + + connected_addr = info->valid_links ? + info->ap_mld_addr : + info->links[0].bss->bssid; + ether_addr_copy(wdev->u.client.connected_addr, connected_addr); + if (info->valid_links) { + for_each_valid_link(info, link) + memcpy(wdev->links[link].addr, info->links[link].addr, + ETH_ALEN); + } + wdev->unprot_beacon_reported = 0; + nl80211_send_roamed(wiphy_to_rdev(wdev->wiphy), + wdev->netdev, info, GFP_KERNEL); + +#ifdef CONFIG_CFG80211_WEXT + if (!info->valid_links) { + if (info->req_ie) { + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = info->req_ie_len; + wireless_send_event(wdev->netdev, IWEVASSOCREQIE, + &wrqu, info->req_ie); + } + + if (info->resp_ie) { + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.data.length = info->resp_ie_len; + wireless_send_event(wdev->netdev, IWEVASSOCRESPIE, + &wrqu, info->resp_ie); + } + + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.ap_addr.sa_family = ARPHRD_ETHER; + memcpy(wrqu.ap_addr.sa_data, connected_addr, ETH_ALEN); + memcpy(wdev->wext.prev_bssid, connected_addr, ETH_ALEN); + wdev->wext.prev_bssid_valid = true; + wireless_send_event(wdev->netdev, SIOCGIWAP, &wrqu, NULL); + } +#endif + + return; +out: + for_each_valid_link(info, link) + cfg80211_put_bss(wdev->wiphy, info->links[link].bss); +} + +/* Consumes info->links.bss object(s) one way or another */ +void cfg80211_roamed(struct net_device *dev, struct cfg80211_roam_info *info, + gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_event *ev; + unsigned long flags; + u8 *next; + unsigned int link; + size_t link_info_size = 0; + bool bss_not_found = false; + + for_each_valid_link(info, link) { + link_info_size += info->links[link].addr ? ETH_ALEN : 0; + link_info_size += info->links[link].bssid ? 
ETH_ALEN : 0; + + if (info->links[link].bss) + continue; + + info->links[link].bss = + cfg80211_get_bss(wdev->wiphy, + info->links[link].channel, + info->links[link].bssid, + wdev->u.client.ssid, + wdev->u.client.ssid_len, + wdev->conn_bss_type, + IEEE80211_PRIVACY_ANY); + + if (!info->links[link].bss) { + bss_not_found = true; + break; + } + } + + if (WARN_ON(bss_not_found)) + goto out; + + ev = kzalloc(sizeof(*ev) + info->req_ie_len + info->resp_ie_len + + info->fils.kek_len + info->fils.pmk_len + + (info->fils.pmkid ? WLAN_PMKID_LEN : 0) + + (info->ap_mld_addr ? ETH_ALEN : 0) + link_info_size, gfp); + if (!ev) + goto out; + + ev->type = EVENT_ROAMED; + next = ((u8 *)ev) + sizeof(*ev); + if (info->req_ie_len) { + ev->rm.req_ie = next; + ev->rm.req_ie_len = info->req_ie_len; + memcpy((void *)ev->rm.req_ie, info->req_ie, info->req_ie_len); + next += info->req_ie_len; + } + if (info->resp_ie_len) { + ev->rm.resp_ie = next; + ev->rm.resp_ie_len = info->resp_ie_len; + memcpy((void *)ev->rm.resp_ie, info->resp_ie, + info->resp_ie_len); + next += info->resp_ie_len; + } + if (info->fils.kek_len) { + ev->rm.fils.kek = next; + ev->rm.fils.kek_len = info->fils.kek_len; + memcpy((void *)ev->rm.fils.kek, info->fils.kek, + info->fils.kek_len); + next += info->fils.kek_len; + } + if (info->fils.pmk_len) { + ev->rm.fils.pmk = next; + ev->rm.fils.pmk_len = info->fils.pmk_len; + memcpy((void *)ev->rm.fils.pmk, info->fils.pmk, + info->fils.pmk_len); + next += info->fils.pmk_len; + } + if (info->fils.pmkid) { + ev->rm.fils.pmkid = next; + memcpy((void *)ev->rm.fils.pmkid, info->fils.pmkid, + WLAN_PMKID_LEN); + next += WLAN_PMKID_LEN; + } + ev->rm.fils.update_erp_next_seq_num = info->fils.update_erp_next_seq_num; + if (info->fils.update_erp_next_seq_num) + ev->rm.fils.erp_next_seq_num = info->fils.erp_next_seq_num; + if (info->ap_mld_addr) { + ev->rm.ap_mld_addr = next; + memcpy((void *)ev->rm.ap_mld_addr, info->ap_mld_addr, + ETH_ALEN); + next += ETH_ALEN; + } + ev->rm.valid_links = info->valid_links; + for_each_valid_link(info, link) { + ev->rm.links[link].bss = info->links[link].bss; + + if (info->links[link].addr) { + ev->rm.links[link].addr = next; + memcpy((void *)ev->rm.links[link].addr, + info->links[link].addr, + ETH_ALEN); + next += ETH_ALEN; + } + + if (info->links[link].bssid) { + ev->rm.links[link].bssid = next; + memcpy((void *)ev->rm.links[link].bssid, + info->links[link].bssid, + ETH_ALEN); + next += ETH_ALEN; + } + } + + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); + + return; +out: + for_each_valid_link(info, link) + cfg80211_put_bss(wdev->wiphy, info->links[link].bss); + +} +EXPORT_SYMBOL(cfg80211_roamed); + +void __cfg80211_port_authorized(struct wireless_dev *wdev, const u8 *bssid) +{ + ASSERT_WDEV_LOCK(wdev); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT)) + return; + + if (WARN_ON(!wdev->connected) || + WARN_ON(!ether_addr_equal(wdev->u.client.connected_addr, bssid))) + return; + + nl80211_send_port_authorized(wiphy_to_rdev(wdev->wiphy), wdev->netdev, + bssid); +} + +void cfg80211_port_authorized(struct net_device *dev, const u8 *bssid, + gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_event *ev; + unsigned long flags; + + if (WARN_ON(!bssid)) + return; + + ev = kzalloc(sizeof(*ev), 
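+	/*
+	 * Usage sketch (hypothetical driver with 4-way handshake offload):
+	 * once the offloaded handshake completes and the 802.1X port may pass
+	 * traffic, the driver reports it with the AP's BSSID:
+	 *
+	 *	cfg80211_port_authorized(netdev, ap_bssid, GFP_KERNEL);
+	 *
+	 * This must follow a successful cfg80211_connect_done() or
+	 * cfg80211_roamed() for the same BSSID, as the connected-address
+	 * check in __cfg80211_port_authorized() above enforces.
+	 */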
gfp); + if (!ev) + return; + + ev->type = EVENT_PORT_AUTHORIZED; + memcpy(ev->pa.bssid, bssid, ETH_ALEN); + + /* + * Use the wdev event list so that if there are pending + * connected/roamed events, they will be reported first. + */ + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); +} +EXPORT_SYMBOL(cfg80211_port_authorized); + +void __cfg80211_disconnected(struct net_device *dev, const u8 *ie, + size_t ie_len, u16 reason, bool from_ap) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int i; +#ifdef CONFIG_CFG80211_WEXT + union iwreq_data wrqu; +#endif + + ASSERT_WDEV_LOCK(wdev); + + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_P2P_CLIENT)) + return; + + cfg80211_wdev_release_bsses(wdev); + wdev->connected = false; + wdev->u.client.ssid_len = 0; + wdev->conn_owner_nlportid = 0; + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = NULL; + + nl80211_send_disconnected(rdev, dev, reason, ie, ie_len, from_ap); + + /* stop critical protocol if supported */ + if (rdev->ops->crit_proto_stop && rdev->crit_proto_nlportid) { + rdev->crit_proto_nlportid = 0; + rdev_crit_proto_stop(rdev, wdev); + } + + /* + * Delete all the keys ... pairwise keys can't really + * exist any more anyway, but default keys might. + */ + if (rdev->ops->del_key) { + int max_key_idx = 5; + + if (wiphy_ext_feature_isset( + wdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION) || + wiphy_ext_feature_isset( + wdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION_CLIENT)) + max_key_idx = 7; + for (i = 0; i <= max_key_idx; i++) + rdev_del_key(rdev, dev, -1, i, false, NULL); + } + + rdev_set_qos_map(rdev, dev, NULL); + +#ifdef CONFIG_CFG80211_WEXT + memset(&wrqu, 0, sizeof(wrqu)); + wrqu.ap_addr.sa_family = ARPHRD_ETHER; + wireless_send_event(dev, SIOCGIWAP, &wrqu, NULL); + wdev->wext.connect.ssid_len = 0; +#endif + + schedule_work(&cfg80211_disconnect_work); +} + +void cfg80211_disconnected(struct net_device *dev, u16 reason, + const u8 *ie, size_t ie_len, + bool locally_generated, gfp_t gfp) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_event *ev; + unsigned long flags; + + ev = kzalloc(sizeof(*ev) + ie_len, gfp); + if (!ev) + return; + + ev->type = EVENT_DISCONNECTED; + ev->dc.ie = ((u8 *)ev) + sizeof(*ev); + ev->dc.ie_len = ie_len; + memcpy((void *)ev->dc.ie, ie, ie_len); + ev->dc.reason = reason; + ev->dc.locally_generated = locally_generated; + + spin_lock_irqsave(&wdev->event_lock, flags); + list_add_tail(&ev->list, &wdev->event_list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + queue_work(cfg80211_wq, &rdev->event_work); +} +EXPORT_SYMBOL(cfg80211_disconnected); + +/* + * API calls for nl80211/wext compatibility code + */ +int cfg80211_connect(struct cfg80211_registered_device *rdev, + struct net_device *dev, + struct cfg80211_connect_params *connect, + struct cfg80211_cached_keys *connkeys, + const u8 *prev_bssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err; + + ASSERT_WDEV_LOCK(wdev); + + /* + * If we have an ssid_len, we're trying to connect or are + * already connected, so reject a new SSID unless it's the + * same (which is the case for re-association.) 
+ */ + if (wdev->u.client.ssid_len && + (wdev->u.client.ssid_len != connect->ssid_len || + memcmp(wdev->u.client.ssid, connect->ssid, wdev->u.client.ssid_len))) + return -EALREADY; + + /* + * If connected, reject (re-)association unless prev_bssid + * matches the current BSSID. + */ + if (wdev->connected) { + if (!prev_bssid) + return -EALREADY; + if (!ether_addr_equal(prev_bssid, + wdev->u.client.connected_addr)) + return -ENOTCONN; + } + + /* + * Reject if we're in the process of connecting with WEP, + * this case isn't very interesting and trying to handle + * it would make the code much more complex. + */ + if (wdev->connect_keys) + return -EINPROGRESS; + + cfg80211_oper_and_ht_capa(&connect->ht_capa_mask, + rdev->wiphy.ht_capa_mod_mask); + cfg80211_oper_and_vht_capa(&connect->vht_capa_mask, + rdev->wiphy.vht_capa_mod_mask); + + if (connkeys && connkeys->def >= 0) { + int idx; + u32 cipher; + + idx = connkeys->def; + cipher = connkeys->params[idx].cipher; + /* If given a WEP key we may need it for shared key auth */ + if (cipher == WLAN_CIPHER_SUITE_WEP40 || + cipher == WLAN_CIPHER_SUITE_WEP104) { + connect->key_idx = idx; + connect->key = connkeys->params[idx].key; + connect->key_len = connkeys->params[idx].key_len; + + /* + * If ciphers are not set (e.g. when going through + * iwconfig), we have to set them appropriately here. + */ + if (connect->crypto.cipher_group == 0) + connect->crypto.cipher_group = cipher; + + if (connect->crypto.n_ciphers_pairwise == 0) { + connect->crypto.n_ciphers_pairwise = 1; + connect->crypto.ciphers_pairwise[0] = cipher; + } + } + + connect->crypto.wep_keys = connkeys->params; + connect->crypto.wep_tx_key = connkeys->def; + } else { + if (WARN_ON(connkeys)) + return -EINVAL; + + /* connect can point to wdev->wext.connect which + * can hold key data from a previous connection + */ + connect->key = NULL; + connect->key_len = 0; + connect->key_idx = 0; + } + + wdev->connect_keys = connkeys; + memcpy(wdev->u.client.ssid, connect->ssid, connect->ssid_len); + wdev->u.client.ssid_len = connect->ssid_len; + + wdev->conn_bss_type = connect->pbss ? IEEE80211_BSS_TYPE_PBSS : + IEEE80211_BSS_TYPE_ESS; + + if (!rdev->ops->connect) + err = cfg80211_sme_connect(wdev, connect, prev_bssid); + else + err = rdev_connect(rdev, dev, connect); + + if (err) { + wdev->connect_keys = NULL; + /* + * This could be reassoc getting refused, don't clear + * ssid_len in that case. + */ + if (!wdev->connected) + wdev->u.client.ssid_len = 0; + return err; + } + + return 0; +} + +int cfg80211_disconnect(struct cfg80211_registered_device *rdev, + struct net_device *dev, u16 reason, bool wextev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err = 0; + + ASSERT_WDEV_LOCK(wdev); + + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = NULL; + + wdev->conn_owner_nlportid = 0; + + if (wdev->conn) + err = cfg80211_sme_disconnect(wdev, reason); + else if (!rdev->ops->disconnect) + cfg80211_mlme_down(rdev, dev); + else if (wdev->u.client.ssid_len) + err = rdev_disconnect(rdev, dev, reason); + + /* + * Clear ssid_len unless we actually were fully connected, + * in which case cfg80211_disconnected() will take care of + * this later. 
+ */ + if (!wdev->connected) + wdev->u.client.ssid_len = 0; + + return err; +} + +/* + * Used to clean up after the connection / connection attempt owner socket + * disconnects + */ +void cfg80211_autodisconnect_wk(struct work_struct *work) +{ + struct wireless_dev *wdev = + container_of(work, struct wireless_dev, disconnect_wk); + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + wiphy_lock(wdev->wiphy); + wdev_lock(wdev); + + if (wdev->conn_owner_nlportid) { + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + __cfg80211_leave_ibss(rdev, wdev->netdev, false); + break; + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + __cfg80211_stop_ap(rdev, wdev->netdev, -1, false); + break; + case NL80211_IFTYPE_MESH_POINT: + __cfg80211_leave_mesh(rdev, wdev->netdev); + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + /* + * Use disconnect_bssid if still connecting and + * ops->disconnect not implemented. Otherwise we can + * use cfg80211_disconnect. + */ + if (rdev->ops->disconnect || wdev->connected) + cfg80211_disconnect(rdev, wdev->netdev, + WLAN_REASON_DEAUTH_LEAVING, + true); + else + cfg80211_mlme_deauth(rdev, wdev->netdev, + wdev->disconnect_bssid, + NULL, 0, + WLAN_REASON_DEAUTH_LEAVING, + false); + break; + default: + break; + } + } + + wdev_unlock(wdev); + wiphy_unlock(wdev->wiphy); +} diff --git a/net/wireless/sysfs.c b/net/wireless/sysfs.c new file mode 100644 index 000000000..a88f338c6 --- /dev/null +++ b/net/wireless/sysfs.c @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * This file provides /sys/class/ieee80211// + * and some default attributes. + * + * Copyright 2005-2006 Jiri Benc + * Copyright 2006 Johannes Berg + * Copyright (C) 2020-2021, 2023 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include "sysfs.h" +#include "core.h" +#include "rdev-ops.h" + +static inline struct cfg80211_registered_device *dev_to_rdev( + struct device *dev) +{ + return container_of(dev, struct cfg80211_registered_device, wiphy.dev); +} + +#define SHOW_FMT(name, fmt, member) \ +static ssize_t name ## _show(struct device *dev, \ + struct device_attribute *attr, \ + char *buf) \ +{ \ + return sprintf(buf, fmt "\n", dev_to_rdev(dev)->member); \ +} \ +static DEVICE_ATTR_RO(name) + +SHOW_FMT(index, "%d", wiphy_idx); +SHOW_FMT(macaddress, "%pM", wiphy.perm_addr); +SHOW_FMT(address_mask, "%pM", wiphy.addr_mask); + +static ssize_t name_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct wiphy *wiphy = &dev_to_rdev(dev)->wiphy; + + return sprintf(buf, "%s\n", wiphy_name(wiphy)); +} +static DEVICE_ATTR_RO(name); + +static ssize_t addresses_show(struct device *dev, + struct device_attribute *attr, + char *buf) +{ + struct wiphy *wiphy = &dev_to_rdev(dev)->wiphy; + char *start = buf; + int i; + + if (!wiphy->addresses) + return sprintf(buf, "%pM\n", wiphy->perm_addr); + + for (i = 0; i < wiphy->n_addresses; i++) + buf += sprintf(buf, "%pM\n", wiphy->addresses[i].addr); + + return buf - start; +} +static DEVICE_ATTR_RO(addresses); + +static struct attribute *ieee80211_attrs[] = { + &dev_attr_index.attr, + &dev_attr_macaddress.attr, + &dev_attr_address_mask.attr, + &dev_attr_addresses.attr, + &dev_attr_name.attr, + NULL, +}; +ATTRIBUTE_GROUPS(ieee80211); + +static void wiphy_dev_release(struct device *dev) +{ + struct cfg80211_registered_device *rdev = dev_to_rdev(dev); + + cfg80211_dev_free(rdev); +} + +#ifdef CONFIG_PM_SLEEP +static void cfg80211_leave_all(struct 
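+/*
+ * For reference, SHOW_FMT(index, "%d", wiphy_idx) above expands to roughly:
+ *
+ *	static ssize_t index_show(struct device *dev,
+ *				  struct device_attribute *attr, char *buf)
+ *	{
+ *		return sprintf(buf, "%d\n", dev_to_rdev(dev)->wiphy_idx);
+ *	}
+ *	static DEVICE_ATTR_RO(index);
+ *
+ * i.e. one read-only attribute per SHOW_FMT() user, visible as
+ * /sys/class/ieee80211/<phy>/index and friends via ieee80211_groups.
+ */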
cfg80211_registered_device *rdev) +{ + struct wireless_dev *wdev; + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) + cfg80211_leave(rdev, wdev); +} + +static int wiphy_suspend(struct device *dev) +{ + struct cfg80211_registered_device *rdev = dev_to_rdev(dev); + int ret = 0; + + rdev->suspend_at = ktime_get_boottime_seconds(); + + rtnl_lock(); + wiphy_lock(&rdev->wiphy); + if (rdev->wiphy.registered) { + if (!rdev->wiphy.wowlan_config) { + cfg80211_leave_all(rdev); + cfg80211_process_rdev_events(rdev); + } + cfg80211_process_wiphy_works(rdev, NULL); + if (rdev->ops->suspend) + ret = rdev_suspend(rdev, rdev->wiphy.wowlan_config); + if (ret == 1) { + /* Driver refuse to configure wowlan */ + cfg80211_leave_all(rdev); + cfg80211_process_rdev_events(rdev); + cfg80211_process_wiphy_works(rdev, NULL); + ret = rdev_suspend(rdev, NULL); + } + if (ret == 0) + rdev->suspended = true; + } + wiphy_unlock(&rdev->wiphy); + rtnl_unlock(); + + return ret; +} + +static int wiphy_resume(struct device *dev) +{ + struct cfg80211_registered_device *rdev = dev_to_rdev(dev); + int ret = 0; + + /* Age scan results with time spent in suspend */ + cfg80211_bss_age(rdev, ktime_get_boottime_seconds() - rdev->suspend_at); + + rtnl_lock(); + wiphy_lock(&rdev->wiphy); + if (rdev->wiphy.registered && rdev->ops->resume) + ret = rdev_resume(rdev); + rdev->suspended = false; + schedule_work(&rdev->wiphy_work); + wiphy_unlock(&rdev->wiphy); + + if (ret) + cfg80211_shutdown_all_interfaces(&rdev->wiphy); + + rtnl_unlock(); + + return ret; +} + +static SIMPLE_DEV_PM_OPS(wiphy_pm_ops, wiphy_suspend, wiphy_resume); +#define WIPHY_PM_OPS (&wiphy_pm_ops) +#else +#define WIPHY_PM_OPS NULL +#endif + +static const void *wiphy_namespace(struct device *d) +{ + struct wiphy *wiphy = container_of(d, struct wiphy, dev); + + return wiphy_net(wiphy); +} + +struct class ieee80211_class = { + .name = "ieee80211", + .owner = THIS_MODULE, + .dev_release = wiphy_dev_release, + .dev_groups = ieee80211_groups, + .pm = WIPHY_PM_OPS, + .ns_type = &net_ns_type_operations, + .namespace = wiphy_namespace, +}; + +int wiphy_sysfs_init(void) +{ + return class_register(&ieee80211_class); +} + +void wiphy_sysfs_exit(void) +{ + class_unregister(&ieee80211_class); +} diff --git a/net/wireless/sysfs.h b/net/wireless/sysfs.h new file mode 100644 index 000000000..7b454c2de --- /dev/null +++ b/net/wireless/sysfs.h @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef __WIRELESS_SYSFS_H +#define __WIRELESS_SYSFS_H + +int wiphy_sysfs_init(void); +void wiphy_sysfs_exit(void); + +extern struct class ieee80211_class; + +#endif /* __WIRELESS_SYSFS_H */ diff --git a/net/wireless/trace.c b/net/wireless/trace.c new file mode 100644 index 000000000..95f997fad --- /dev/null +++ b/net/wireless/trace.c @@ -0,0 +1,7 @@ +#include + +#ifndef __CHECKER__ +#define CREATE_TRACE_POINTS +#include "trace.h" + +#endif diff --git a/net/wireless/trace.h b/net/wireless/trace.h new file mode 100644 index 000000000..a405c3edb --- /dev/null +++ b/net/wireless/trace.h @@ -0,0 +1,3910 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#undef TRACE_SYSTEM +#define TRACE_SYSTEM cfg80211 + +#if !defined(__RDEV_OPS_TRACE) || defined(TRACE_HEADER_MULTI_READ) +#define __RDEV_OPS_TRACE + +#include + +#include +#include +#include +#include "core.h" + +#define MAC_ENTRY(entry_mac) __array(u8, entry_mac, ETH_ALEN) +#define MAC_ASSIGN(entry_mac, given_mac) do { \ + if (given_mac) \ + memcpy(__entry->entry_mac, given_mac, ETH_ALEN); \ + else \ + eth_zero_addr(__entry->entry_mac); \ + } 
while (0) +#define MAC_PR_FMT "%pM" +#define MAC_PR_ARG(entry_mac) (__entry->entry_mac) + +#define MAXNAME 32 +#define WIPHY_ENTRY __array(char, wiphy_name, 32) +#define WIPHY_ASSIGN strlcpy(__entry->wiphy_name, wiphy_name(wiphy), MAXNAME) +#define WIPHY_PR_FMT "%s" +#define WIPHY_PR_ARG __entry->wiphy_name + +#define WDEV_ENTRY __field(u32, id) +#define WDEV_ASSIGN (__entry->id) = (!IS_ERR_OR_NULL(wdev) \ + ? wdev->identifier : 0) +#define WDEV_PR_FMT "wdev(%u)" +#define WDEV_PR_ARG (__entry->id) + +#define NETDEV_ENTRY __array(char, name, IFNAMSIZ) \ + __field(int, ifindex) +#define NETDEV_ASSIGN \ + do { \ + memcpy(__entry->name, netdev->name, IFNAMSIZ); \ + (__entry->ifindex) = (netdev->ifindex); \ + } while (0) +#define NETDEV_PR_FMT "netdev:%s(%d)" +#define NETDEV_PR_ARG __entry->name, __entry->ifindex + +#define MESH_CFG_ENTRY __field(u16, dot11MeshRetryTimeout) \ + __field(u16, dot11MeshConfirmTimeout) \ + __field(u16, dot11MeshHoldingTimeout) \ + __field(u16, dot11MeshMaxPeerLinks) \ + __field(u8, dot11MeshMaxRetries) \ + __field(u8, dot11MeshTTL) \ + __field(u8, element_ttl) \ + __field(bool, auto_open_plinks) \ + __field(u32, dot11MeshNbrOffsetMaxNeighbor) \ + __field(u8, dot11MeshHWMPmaxPREQretries) \ + __field(u32, path_refresh_time) \ + __field(u32, dot11MeshHWMPactivePathTimeout) \ + __field(u16, min_discovery_timeout) \ + __field(u16, dot11MeshHWMPpreqMinInterval) \ + __field(u16, dot11MeshHWMPperrMinInterval) \ + __field(u16, dot11MeshHWMPnetDiameterTraversalTime) \ + __field(u8, dot11MeshHWMPRootMode) \ + __field(u16, dot11MeshHWMPRannInterval) \ + __field(bool, dot11MeshGateAnnouncementProtocol) \ + __field(bool, dot11MeshForwarding) \ + __field(s32, rssi_threshold) \ + __field(u16, ht_opmode) \ + __field(u32, dot11MeshHWMPactivePathToRootTimeout) \ + __field(u16, dot11MeshHWMProotInterval) \ + __field(u16, dot11MeshHWMPconfirmationInterval) \ + __field(bool, dot11MeshNolearn) +#define MESH_CFG_ASSIGN \ + do { \ + __entry->dot11MeshRetryTimeout = conf->dot11MeshRetryTimeout; \ + __entry->dot11MeshConfirmTimeout = \ + conf->dot11MeshConfirmTimeout; \ + __entry->dot11MeshHoldingTimeout = \ + conf->dot11MeshHoldingTimeout; \ + __entry->dot11MeshMaxPeerLinks = conf->dot11MeshMaxPeerLinks; \ + __entry->dot11MeshMaxRetries = conf->dot11MeshMaxRetries; \ + __entry->dot11MeshTTL = conf->dot11MeshTTL; \ + __entry->element_ttl = conf->element_ttl; \ + __entry->auto_open_plinks = conf->auto_open_plinks; \ + __entry->dot11MeshNbrOffsetMaxNeighbor = \ + conf->dot11MeshNbrOffsetMaxNeighbor; \ + __entry->dot11MeshHWMPmaxPREQretries = \ + conf->dot11MeshHWMPmaxPREQretries; \ + __entry->path_refresh_time = conf->path_refresh_time; \ + __entry->dot11MeshHWMPactivePathTimeout = \ + conf->dot11MeshHWMPactivePathTimeout; \ + __entry->min_discovery_timeout = conf->min_discovery_timeout; \ + __entry->dot11MeshHWMPpreqMinInterval = \ + conf->dot11MeshHWMPpreqMinInterval; \ + __entry->dot11MeshHWMPperrMinInterval = \ + conf->dot11MeshHWMPperrMinInterval; \ + __entry->dot11MeshHWMPnetDiameterTraversalTime = \ + conf->dot11MeshHWMPnetDiameterTraversalTime; \ + __entry->dot11MeshHWMPRootMode = conf->dot11MeshHWMPRootMode; \ + __entry->dot11MeshHWMPRannInterval = \ + conf->dot11MeshHWMPRannInterval; \ + __entry->dot11MeshGateAnnouncementProtocol = \ + conf->dot11MeshGateAnnouncementProtocol; \ + __entry->dot11MeshForwarding = conf->dot11MeshForwarding; \ + __entry->rssi_threshold = conf->rssi_threshold; \ + __entry->ht_opmode = conf->ht_opmode; \ + __entry->dot11MeshHWMPactivePathToRootTimeout = \ + 
conf->dot11MeshHWMPactivePathToRootTimeout; \ + __entry->dot11MeshHWMProotInterval = \ + conf->dot11MeshHWMProotInterval; \ + __entry->dot11MeshHWMPconfirmationInterval = \ + conf->dot11MeshHWMPconfirmationInterval; \ + __entry->dot11MeshNolearn = conf->dot11MeshNolearn; \ + } while (0) + +#define CHAN_ENTRY __field(enum nl80211_band, band) \ + __field(u32, center_freq) \ + __field(u16, freq_offset) +#define CHAN_ASSIGN(chan) \ + do { \ + if (chan) { \ + __entry->band = chan->band; \ + __entry->center_freq = chan->center_freq; \ + __entry->freq_offset = chan->freq_offset; \ + } else { \ + __entry->band = 0; \ + __entry->center_freq = 0; \ + __entry->freq_offset = 0; \ + } \ + } while (0) +#define CHAN_PR_FMT "band: %d, freq: %u.%03u" +#define CHAN_PR_ARG __entry->band, __entry->center_freq, __entry->freq_offset + +#define CHAN_DEF_ENTRY __field(enum nl80211_band, band) \ + __field(u32, control_freq) \ + __field(u32, freq_offset) \ + __field(u32, width) \ + __field(u32, center_freq1) \ + __field(u32, freq1_offset) \ + __field(u32, center_freq2) +#define CHAN_DEF_ASSIGN(chandef) \ + do { \ + if ((chandef) && (chandef)->chan) { \ + __entry->band = (chandef)->chan->band; \ + __entry->control_freq = \ + (chandef)->chan->center_freq; \ + __entry->freq_offset = \ + (chandef)->chan->freq_offset; \ + __entry->width = (chandef)->width; \ + __entry->center_freq1 = (chandef)->center_freq1;\ + __entry->freq1_offset = (chandef)->freq1_offset;\ + __entry->center_freq2 = (chandef)->center_freq2;\ + } else { \ + __entry->band = 0; \ + __entry->control_freq = 0; \ + __entry->freq_offset = 0; \ + __entry->width = 0; \ + __entry->center_freq1 = 0; \ + __entry->freq1_offset = 0; \ + __entry->center_freq2 = 0; \ + } \ + } while (0) +#define CHAN_DEF_PR_FMT \ + "band: %d, control freq: %u.%03u, width: %d, cf1: %u.%03u, cf2: %u" +#define CHAN_DEF_PR_ARG __entry->band, __entry->control_freq, \ + __entry->freq_offset, __entry->width, \ + __entry->center_freq1, __entry->freq1_offset, \ + __entry->center_freq2 + +#define FILS_AAD_ASSIGN(fa) \ + do { \ + if (fa) { \ + ether_addr_copy(__entry->macaddr, fa->macaddr); \ + __entry->kek_len = fa->kek_len; \ + } else { \ + eth_zero_addr(__entry->macaddr); \ + __entry->kek_len = 0; \ + } \ + } while (0) +#define FILS_AAD_PR_FMT \ + "macaddr: %pM, kek_len: %d" + +#define SINFO_ENTRY __field(int, generation) \ + __field(u32, connected_time) \ + __field(u32, inactive_time) \ + __field(u32, rx_bytes) \ + __field(u32, tx_bytes) \ + __field(u32, rx_packets) \ + __field(u32, tx_packets) \ + __field(u32, tx_retries) \ + __field(u32, tx_failed) \ + __field(u32, rx_dropped_misc) \ + __field(u32, beacon_loss_count) \ + __field(u16, llid) \ + __field(u16, plid) \ + __field(u8, plink_state) +#define SINFO_ASSIGN \ + do { \ + __entry->generation = sinfo->generation; \ + __entry->connected_time = sinfo->connected_time; \ + __entry->inactive_time = sinfo->inactive_time; \ + __entry->rx_bytes = sinfo->rx_bytes; \ + __entry->tx_bytes = sinfo->tx_bytes; \ + __entry->rx_packets = sinfo->rx_packets; \ + __entry->tx_packets = sinfo->tx_packets; \ + __entry->tx_retries = sinfo->tx_retries; \ + __entry->tx_failed = sinfo->tx_failed; \ + __entry->rx_dropped_misc = sinfo->rx_dropped_misc; \ + __entry->beacon_loss_count = sinfo->beacon_loss_count; \ + __entry->llid = sinfo->llid; \ + __entry->plid = sinfo->plid; \ + __entry->plink_state = sinfo->plink_state; \ + } while (0) + +#define BOOL_TO_STR(bo) (bo) ? 
"true" : "false" + +#define QOS_MAP_ENTRY __field(u8, num_des) \ + __array(u8, dscp_exception, \ + 2 * IEEE80211_QOS_MAP_MAX_EX) \ + __array(u8, up, IEEE80211_QOS_MAP_LEN_MIN) +#define QOS_MAP_ASSIGN(qos_map) \ + do { \ + if ((qos_map)) { \ + __entry->num_des = (qos_map)->num_des; \ + memcpy(__entry->dscp_exception, \ + &(qos_map)->dscp_exception, \ + 2 * IEEE80211_QOS_MAP_MAX_EX); \ + memcpy(__entry->up, &(qos_map)->up, \ + IEEE80211_QOS_MAP_LEN_MIN); \ + } else { \ + __entry->num_des = 0; \ + memset(__entry->dscp_exception, 0, \ + 2 * IEEE80211_QOS_MAP_MAX_EX); \ + memset(__entry->up, 0, \ + IEEE80211_QOS_MAP_LEN_MIN); \ + } \ + } while (0) + +/************************************************************* + * rdev->ops traces * + *************************************************************/ + +TRACE_EVENT(rdev_suspend, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_wowlan *wow), + TP_ARGS(wiphy, wow), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(bool, any) + __field(bool, disconnect) + __field(bool, magic_pkt) + __field(bool, gtk_rekey_failure) + __field(bool, eap_identity_req) + __field(bool, four_way_handshake) + __field(bool, rfkill_release) + __field(bool, valid_wow) + ), + TP_fast_assign( + WIPHY_ASSIGN; + if (wow) { + __entry->any = wow->any; + __entry->disconnect = wow->disconnect; + __entry->magic_pkt = wow->magic_pkt; + __entry->gtk_rekey_failure = wow->gtk_rekey_failure; + __entry->eap_identity_req = wow->eap_identity_req; + __entry->four_way_handshake = wow->four_way_handshake; + __entry->rfkill_release = wow->rfkill_release; + __entry->valid_wow = true; + } else { + __entry->valid_wow = false; + } + ), + TP_printk(WIPHY_PR_FMT ", wow%s - any: %d, disconnect: %d, " + "magic pkt: %d, gtk rekey failure: %d, eap identify req: %d, " + "four way handshake: %d, rfkill release: %d.", + WIPHY_PR_ARG, __entry->valid_wow ? "" : "(Not configured!)", + __entry->any, __entry->disconnect, __entry->magic_pkt, + __entry->gtk_rekey_failure, __entry->eap_identity_req, + __entry->four_way_handshake, __entry->rfkill_release) +); + +TRACE_EVENT(rdev_return_int, + TP_PROTO(struct wiphy *wiphy, int ret), + TP_ARGS(wiphy, ret), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->ret = ret; + ), + TP_printk(WIPHY_PR_FMT ", returned: %d", WIPHY_PR_ARG, __entry->ret) +); + +TRACE_EVENT(rdev_scan, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_scan_request *request), + TP_ARGS(wiphy, request), + TP_STRUCT__entry( + WIPHY_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT, WIPHY_PR_ARG) +); + +DECLARE_EVENT_CLASS(wiphy_only_evt, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy), + TP_STRUCT__entry( + WIPHY_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT, WIPHY_PR_ARG) +); + +DEFINE_EVENT(wiphy_only_evt, rdev_resume, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy) +); + +DEFINE_EVENT(wiphy_only_evt, rdev_return_void, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy) +); + +DEFINE_EVENT(wiphy_only_evt, rdev_get_antenna, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy) +); + +DEFINE_EVENT(wiphy_only_evt, rdev_rfkill_poll, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy) +); + +DECLARE_EVENT_CLASS(wiphy_enabled_evt, + TP_PROTO(struct wiphy *wiphy, bool enabled), + TP_ARGS(wiphy, enabled), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(bool, enabled) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->enabled = enabled; + ), + TP_printk(WIPHY_PR_FMT ", %senabled ", + WIPHY_PR_ARG, __entry->enabled ? 
"" : "not ") +); + +DEFINE_EVENT(wiphy_enabled_evt, rdev_set_wakeup, + TP_PROTO(struct wiphy *wiphy, bool enabled), + TP_ARGS(wiphy, enabled) +); + +TRACE_EVENT(rdev_add_virtual_intf, + TP_PROTO(struct wiphy *wiphy, char *name, enum nl80211_iftype type), + TP_ARGS(wiphy, name, type), + TP_STRUCT__entry( + WIPHY_ENTRY + __string(vir_intf_name, name ? name : "") + __field(enum nl80211_iftype, type) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __assign_str(vir_intf_name, name ? name : ""); + __entry->type = type; + ), + TP_printk(WIPHY_PR_FMT ", virtual intf name: %s, type: %d", + WIPHY_PR_ARG, __get_str(vir_intf_name), __entry->type) +); + +DECLARE_EVENT_CLASS(wiphy_wdev_evt, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, WIPHY_PR_ARG, WDEV_PR_ARG) +); + +DECLARE_EVENT_CLASS(wiphy_wdev_cookie_evt, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie: %lld", + WIPHY_PR_ARG, WDEV_PR_ARG, + (unsigned long long)__entry->cookie) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_return_wdev, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_del_virtual_intf, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_change_virtual_intf, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + enum nl80211_iftype type), + TP_ARGS(wiphy, netdev, type), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(enum nl80211_iftype, type) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->type = type; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", type: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->type) +); + +DECLARE_EVENT_CLASS(key_handle, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr), + TP_ARGS(wiphy, netdev, link_id, key_index, pairwise, mac_addr), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(mac_addr) + __field(int, link_id) + __field(u8, key_index) + __field(bool, pairwise) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(mac_addr, mac_addr); + __entry->link_id = link_id; + __entry->key_index = key_index; + __entry->pairwise = pairwise; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, " + "key_index: %u, pairwise: %s, mac addr: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id, + __entry->key_index, BOOL_TO_STR(__entry->pairwise), + MAC_PR_ARG(mac_addr)) +); + +DEFINE_EVENT(key_handle, rdev_get_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr), + TP_ARGS(wiphy, netdev, link_id, key_index, pairwise, mac_addr) +); + +DEFINE_EVENT(key_handle, rdev_del_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr), + TP_ARGS(wiphy, netdev, link_id, key_index, pairwise, mac_addr) +); + +TRACE_EVENT(rdev_add_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index, bool pairwise, const u8 *mac_addr, u8 mode), + 
TP_ARGS(wiphy, netdev, link_id, key_index, pairwise, mac_addr, mode), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(mac_addr) + __field(int, link_id) + __field(u8, key_index) + __field(bool, pairwise) + __field(u8, mode) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(mac_addr, mac_addr); + __entry->link_id = link_id; + __entry->key_index = key_index; + __entry->pairwise = pairwise; + __entry->mode = mode; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, " + "key_index: %u, mode: %u, pairwise: %s, " + "mac addr: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id, + __entry->key_index, __entry->mode, + BOOL_TO_STR(__entry->pairwise), MAC_PR_ARG(mac_addr)) +); + +TRACE_EVENT(rdev_set_default_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index, bool unicast, bool multicast), + TP_ARGS(wiphy, netdev, link_id, key_index, unicast, multicast), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, link_id) + __field(u8, key_index) + __field(bool, unicast) + __field(bool, multicast) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->link_id = link_id; + __entry->key_index = key_index; + __entry->unicast = unicast; + __entry->multicast = multicast; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, " + "key index: %u, unicast: %s, multicast: %s", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id, + __entry->key_index, BOOL_TO_STR(__entry->unicast), + BOOL_TO_STR(__entry->multicast)) +); + +TRACE_EVENT(rdev_set_default_mgmt_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index), + TP_ARGS(wiphy, netdev, link_id, key_index), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, link_id) + __field(u8, key_index) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->link_id = link_id; + __entry->key_index = key_index; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, " + "key index: %u", WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->link_id, __entry->key_index) +); + +TRACE_EVENT(rdev_set_default_beacon_key, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int link_id, + u8 key_index), + TP_ARGS(wiphy, netdev, link_id, key_index), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, link_id) + __field(u8, key_index) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->link_id = link_id; + __entry->key_index = key_index; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, " + "key index: %u", WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->link_id, __entry->key_index) +); + +TRACE_EVENT(rdev_start_ap, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_ap_settings *settings), + TP_ARGS(wiphy, netdev, settings), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(int, beacon_interval) + __field(int, dtim_period) + __array(char, ssid, IEEE80211_MAX_SSID_LEN + 1) + __field(enum nl80211_hidden_ssid, hidden_ssid) + __field(u32, wpa_ver) + __field(bool, privacy) + __field(enum nl80211_auth_type, auth_type) + __field(int, inactivity_timeout) + __field(unsigned int, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(&settings->chandef); + __entry->beacon_interval = settings->beacon_interval; + __entry->dtim_period = settings->dtim_period; + __entry->hidden_ssid = settings->hidden_ssid; + __entry->wpa_ver = settings->crypto.wpa_versions; + __entry->privacy = settings->privacy; + 
__entry->auth_type = settings->auth_type; + __entry->inactivity_timeout = settings->inactivity_timeout; + memset(__entry->ssid, 0, IEEE80211_MAX_SSID_LEN + 1); + memcpy(__entry->ssid, settings->ssid, settings->ssid_len); + __entry->link_id = settings->beacon.link_id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", AP settings - ssid: %s, " + CHAN_DEF_PR_FMT ", beacon interval: %d, dtim period: %d, " + "hidden ssid: %d, wpa versions: %u, privacy: %s, " + "auth type: %d, inactivity timeout: %d, link_id: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->ssid, CHAN_DEF_PR_ARG, + __entry->beacon_interval, __entry->dtim_period, + __entry->hidden_ssid, __entry->wpa_ver, + BOOL_TO_STR(__entry->privacy), __entry->auth_type, + __entry->inactivity_timeout, __entry->link_id) +); + +TRACE_EVENT(rdev_change_beacon, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_beacon_data *info), + TP_ARGS(wiphy, netdev, info), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, link_id) + __dynamic_array(u8, head, info ? info->head_len : 0) + __dynamic_array(u8, tail, info ? info->tail_len : 0) + __dynamic_array(u8, beacon_ies, info ? info->beacon_ies_len : 0) + __dynamic_array(u8, proberesp_ies, + info ? info->proberesp_ies_len : 0) + __dynamic_array(u8, assocresp_ies, + info ? info->assocresp_ies_len : 0) + __dynamic_array(u8, probe_resp, info ? info->probe_resp_len : 0) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + if (info) { + __entry->link_id = info->link_id; + if (info->head) + memcpy(__get_dynamic_array(head), info->head, + info->head_len); + if (info->tail) + memcpy(__get_dynamic_array(tail), info->tail, + info->tail_len); + if (info->beacon_ies) + memcpy(__get_dynamic_array(beacon_ies), + info->beacon_ies, info->beacon_ies_len); + if (info->proberesp_ies) + memcpy(__get_dynamic_array(proberesp_ies), + info->proberesp_ies, + info->proberesp_ies_len); + if (info->assocresp_ies) + memcpy(__get_dynamic_array(assocresp_ies), + info->assocresp_ies, + info->assocresp_ies_len); + if (info->probe_resp) + memcpy(__get_dynamic_array(probe_resp), + info->probe_resp, info->probe_resp_len); + } else { + __entry->link_id = -1; + } + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id:%d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id) +); + +TRACE_EVENT(rdev_stop_ap, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + unsigned int link_id), + TP_ARGS(wiphy, netdev, link_id), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(unsigned int, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->link_id = link_id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id) +); + +DECLARE_EVENT_CLASS(wiphy_netdev_evt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT, WIPHY_PR_ARG, NETDEV_PR_ARG) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_set_rekey_data, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_get_mesh_config, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_leave_mesh, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_leave_ibss, + TP_PROTO(struct 
wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_leave_ocb, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_flush_pmksa, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DEFINE_EVENT(wiphy_netdev_evt, rdev_end_cac, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev), + TP_ARGS(wiphy, netdev) +); + +DECLARE_EVENT_CLASS(station_add_change, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *mac, + struct station_parameters *params), + TP_ARGS(wiphy, netdev, mac, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(sta_mac) + __field(u32, sta_flags_mask) + __field(u32, sta_flags_set) + __field(u32, sta_modify_mask) + __field(int, listen_interval) + __field(u16, capability) + __field(u16, aid) + __field(u8, plink_action) + __field(u8, plink_state) + __field(u8, uapsd_queues) + __field(u8, max_sp) + __field(u8, opmode_notif) + __field(bool, opmode_notif_used) + __array(u8, ht_capa, (int)sizeof(struct ieee80211_ht_cap)) + __array(u8, vht_capa, (int)sizeof(struct ieee80211_vht_cap)) + __array(char, vlan, IFNAMSIZ) + __dynamic_array(u8, supported_rates, + params->link_sta_params.supported_rates_len) + __dynamic_array(u8, ext_capab, params->ext_capab_len) + __dynamic_array(u8, supported_channels, + params->supported_channels_len) + __dynamic_array(u8, supported_oper_classes, + params->supported_oper_classes_len) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(sta_mac, mac); + __entry->sta_flags_mask = params->sta_flags_mask; + __entry->sta_flags_set = params->sta_flags_set; + __entry->sta_modify_mask = params->sta_modify_mask; + __entry->listen_interval = params->listen_interval; + __entry->aid = params->aid; + __entry->plink_action = params->plink_action; + __entry->plink_state = params->plink_state; + __entry->uapsd_queues = params->uapsd_queues; + memset(__entry->ht_capa, 0, sizeof(struct ieee80211_ht_cap)); + if (params->link_sta_params.ht_capa) + memcpy(__entry->ht_capa, + params->link_sta_params.ht_capa, + sizeof(struct ieee80211_ht_cap)); + memset(__entry->vht_capa, 0, sizeof(struct ieee80211_vht_cap)); + if (params->link_sta_params.vht_capa) + memcpy(__entry->vht_capa, + params->link_sta_params.vht_capa, + sizeof(struct ieee80211_vht_cap)); + memset(__entry->vlan, 0, sizeof(__entry->vlan)); + if (params->vlan) + memcpy(__entry->vlan, params->vlan->name, IFNAMSIZ); + if (params->link_sta_params.supported_rates && + params->link_sta_params.supported_rates_len) + memcpy(__get_dynamic_array(supported_rates), + params->link_sta_params.supported_rates, + params->link_sta_params.supported_rates_len); + if (params->ext_capab && params->ext_capab_len) + memcpy(__get_dynamic_array(ext_capab), + params->ext_capab, + params->ext_capab_len); + if (params->supported_channels && + params->supported_channels_len) + memcpy(__get_dynamic_array(supported_channels), + params->supported_channels, + params->supported_channels_len); + if (params->supported_oper_classes && + params->supported_oper_classes_len) + memcpy(__get_dynamic_array(supported_oper_classes), + params->supported_oper_classes, + params->supported_oper_classes_len); + __entry->max_sp = params->max_sp; + __entry->capability = params->capability; + __entry->opmode_notif = params->link_sta_params.opmode_notif; + __entry->opmode_notif_used = + params->link_sta_params.opmode_notif_used; + ), + 
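+	/*
+	 * The __dynamic_array()/__get_dynamic_array() pairing used above sizes
+	 * the variable-length fields per event: the length expression given in
+	 * TP_STRUCT__entry() reserves space in the ring-buffer record, and
+	 * TP_fast_assign() copies the payload only when the source pointer is
+	 * non-NULL. The same idiom, reduced to its core:
+	 *
+	 *	TP_STRUCT__entry(
+	 *		__dynamic_array(u8, ies, req->ie_len)
+	 *	),
+	 *	TP_fast_assign(
+	 *		if (req->ie)
+	 *			memcpy(__get_dynamic_array(ies),
+	 *			       req->ie, req->ie_len);
+	 *	),
+	 */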
TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", station mac: " MAC_PR_FMT + ", station flags mask: %u, station flags set: %u, " + "station modify mask: %u, listen interval: %d, aid: %u, " + "plink action: %u, plink state: %u, uapsd queues: %u, vlan:%s", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(sta_mac), + __entry->sta_flags_mask, __entry->sta_flags_set, + __entry->sta_modify_mask, __entry->listen_interval, + __entry->aid, __entry->plink_action, __entry->plink_state, + __entry->uapsd_queues, __entry->vlan) +); + +DEFINE_EVENT(station_add_change, rdev_add_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *mac, + struct station_parameters *params), + TP_ARGS(wiphy, netdev, mac, params) +); + +DEFINE_EVENT(station_add_change, rdev_change_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *mac, + struct station_parameters *params), + TP_ARGS(wiphy, netdev, mac, params) +); + +DECLARE_EVENT_CLASS(wiphy_netdev_mac_evt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, const u8 *mac), + TP_ARGS(wiphy, netdev, mac), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(sta_mac) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(sta_mac, mac); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", mac: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(sta_mac)) +); + +DECLARE_EVENT_CLASS(station_del, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct station_del_parameters *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(sta_mac) + __field(u8, subtype) + __field(u16, reason_code) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(sta_mac, params->mac); + __entry->subtype = params->subtype; + __entry->reason_code = params->reason_code; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", station mac: " MAC_PR_FMT + ", subtype: %u, reason_code: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(sta_mac), + __entry->subtype, __entry->reason_code) +); + +DEFINE_EVENT(station_del, rdev_del_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct station_del_parameters *params), + TP_ARGS(wiphy, netdev, params) +); + +DEFINE_EVENT(wiphy_netdev_mac_evt, rdev_get_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, const u8 *mac), + TP_ARGS(wiphy, netdev, mac) +); + +DEFINE_EVENT(wiphy_netdev_mac_evt, rdev_del_mpath, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, const u8 *mac), + TP_ARGS(wiphy, netdev, mac) +); + +TRACE_EVENT(rdev_dump_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int _idx, + u8 *mac), + TP_ARGS(wiphy, netdev, _idx, mac), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(sta_mac) + __field(int, idx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(sta_mac, mac); + __entry->idx = _idx; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", station mac: " MAC_PR_FMT ", idx: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(sta_mac), + __entry->idx) +); + +TRACE_EVENT(rdev_return_int_station_info, + TP_PROTO(struct wiphy *wiphy, int ret, struct station_info *sinfo), + TP_ARGS(wiphy, ret, sinfo), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + SINFO_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->ret = ret; + SINFO_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", returned %d" , + WIPHY_PR_ARG, __entry->ret) +); + +DECLARE_EVENT_CLASS(mpath_evt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *dst, + u8 *next_hop), + 
TP_ARGS(wiphy, netdev, dst, next_hop), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dst) + MAC_ENTRY(next_hop) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dst, dst); + MAC_ASSIGN(next_hop, next_hop); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", destination: " MAC_PR_FMT ", next hop: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(dst), + MAC_PR_ARG(next_hop)) +); + +DEFINE_EVENT(mpath_evt, rdev_add_mpath, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *dst, + u8 *next_hop), + TP_ARGS(wiphy, netdev, dst, next_hop) +); + +DEFINE_EVENT(mpath_evt, rdev_change_mpath, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *dst, + u8 *next_hop), + TP_ARGS(wiphy, netdev, dst, next_hop) +); + +DEFINE_EVENT(mpath_evt, rdev_get_mpath, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u8 *dst, + u8 *next_hop), + TP_ARGS(wiphy, netdev, dst, next_hop) +); + +TRACE_EVENT(rdev_dump_mpath, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int _idx, + u8 *dst, u8 *next_hop), + TP_ARGS(wiphy, netdev, _idx, dst, next_hop), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dst) + MAC_ENTRY(next_hop) + __field(int, idx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dst, dst); + MAC_ASSIGN(next_hop, next_hop); + __entry->idx = _idx; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", index: %d, destination: " + MAC_PR_FMT ", next hop: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->idx, MAC_PR_ARG(dst), + MAC_PR_ARG(next_hop)) +); + +TRACE_EVENT(rdev_get_mpp, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u8 *dst, u8 *mpp), + TP_ARGS(wiphy, netdev, dst, mpp), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dst) + MAC_ENTRY(mpp) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dst, dst); + MAC_ASSIGN(mpp, mpp); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", destination: " MAC_PR_FMT + ", mpp: " MAC_PR_FMT, WIPHY_PR_ARG, NETDEV_PR_ARG, + MAC_PR_ARG(dst), MAC_PR_ARG(mpp)) +); + +TRACE_EVENT(rdev_dump_mpp, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int _idx, + u8 *dst, u8 *mpp), + TP_ARGS(wiphy, netdev, _idx, mpp, dst), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dst) + MAC_ENTRY(mpp) + __field(int, idx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dst, dst); + MAC_ASSIGN(mpp, mpp); + __entry->idx = _idx; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", index: %d, destination: " + MAC_PR_FMT ", mpp: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->idx, MAC_PR_ARG(dst), + MAC_PR_ARG(mpp)) +); + +TRACE_EVENT(rdev_return_int_mpath_info, + TP_PROTO(struct wiphy *wiphy, int ret, struct mpath_info *pinfo), + TP_ARGS(wiphy, ret, pinfo), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + __field(int, generation) + __field(u32, filled) + __field(u32, frame_qlen) + __field(u32, sn) + __field(u32, metric) + __field(u32, exptime) + __field(u32, discovery_timeout) + __field(u8, discovery_retries) + __field(u8, flags) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->ret = ret; + __entry->generation = pinfo->generation; + __entry->filled = pinfo->filled; + __entry->frame_qlen = pinfo->frame_qlen; + __entry->sn = pinfo->sn; + __entry->metric = pinfo->metric; + __entry->exptime = pinfo->exptime; + __entry->discovery_timeout = pinfo->discovery_timeout; + __entry->discovery_retries = pinfo->discovery_retries; + __entry->flags = pinfo->flags; + ), + TP_printk(WIPHY_PR_FMT ", 
returned %d. mpath info - generation: %d, " + "filled: %u, frame qlen: %u, sn: %u, metric: %u, exptime: %u," + " discovery timeout: %u, discovery retries: %u, flags: %u", + WIPHY_PR_ARG, __entry->ret, __entry->generation, + __entry->filled, __entry->frame_qlen, __entry->sn, + __entry->metric, __entry->exptime, __entry->discovery_timeout, + __entry->discovery_retries, __entry->flags) +); + +TRACE_EVENT(rdev_return_int_mesh_config, + TP_PROTO(struct wiphy *wiphy, int ret, struct mesh_config *conf), + TP_ARGS(wiphy, ret, conf), + TP_STRUCT__entry( + WIPHY_ENTRY + MESH_CFG_ENTRY + __field(int, ret) + ), + TP_fast_assign( + WIPHY_ASSIGN; + MESH_CFG_ASSIGN; + __entry->ret = ret; + ), + TP_printk(WIPHY_PR_FMT ", returned: %d", + WIPHY_PR_ARG, __entry->ret) +); + +TRACE_EVENT(rdev_update_mesh_config, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u32 mask, + const struct mesh_config *conf), + TP_ARGS(wiphy, netdev, mask, conf), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MESH_CFG_ENTRY + __field(u32, mask) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MESH_CFG_ASSIGN; + __entry->mask = mask; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", mask: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->mask) +); + +TRACE_EVENT(rdev_join_mesh, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const struct mesh_config *conf, + const struct mesh_setup *setup), + TP_ARGS(wiphy, netdev, conf, setup), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MESH_CFG_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MESH_CFG_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG) +); + +TRACE_EVENT(rdev_change_bss, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct bss_parameters *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, use_cts_prot) + __field(int, use_short_preamble) + __field(int, use_short_slot_time) + __field(int, ap_isolate) + __field(int, ht_opmode) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->use_cts_prot = params->use_cts_prot; + __entry->use_short_preamble = params->use_short_preamble; + __entry->use_short_slot_time = params->use_short_slot_time; + __entry->ap_isolate = params->ap_isolate; + __entry->ht_opmode = params->ht_opmode; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", use cts prot: %d, " + "use short preamble: %d, use short slot time: %d, " + "ap isolate: %d, ht opmode: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->use_cts_prot, + __entry->use_short_preamble, __entry->use_short_slot_time, + __entry->ap_isolate, __entry->ht_opmode) +); + +TRACE_EVENT(rdev_set_txq_params, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct ieee80211_txq_params *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(enum nl80211_ac, ac) + __field(u16, txop) + __field(u16, cwmin) + __field(u16, cwmax) + __field(u8, aifs) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->ac = params->ac; + __entry->txop = params->txop; + __entry->cwmin = params->cwmin; + __entry->cwmax = params->cwmax; + __entry->aifs = params->aifs; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", ac: %d, txop: %u, cwmin: %u, cwmax: %u, aifs: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->ac, __entry->txop, + __entry->cwmin, __entry->cwmax, __entry->aifs) +); + +TRACE_EVENT(rdev_libertas_set_mesh_channel, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct 
ieee80211_channel *chan), + TP_ARGS(wiphy, netdev, chan), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_ASSIGN(chan); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_PR_FMT, WIPHY_PR_ARG, + NETDEV_PR_ARG, CHAN_PR_ARG) +); + +TRACE_EVENT(rdev_set_monitor_channel, + TP_PROTO(struct wiphy *wiphy, + struct cfg80211_chan_def *chandef), + TP_ARGS(wiphy, chandef), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_DEF_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT, + WIPHY_PR_ARG, CHAN_DEF_PR_ARG) +); + +TRACE_EVENT(rdev_auth, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_auth_request *req), + TP_ARGS(wiphy, netdev, req), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __field(enum nl80211_auth_type, auth_type) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + if (req->bss) + MAC_ASSIGN(bssid, req->bss->bssid); + else + eth_zero_addr(__entry->bssid); + __entry->auth_type = req->auth_type; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", auth type: %d, bssid: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->auth_type, + MAC_PR_ARG(bssid)) +); + +TRACE_EVENT(rdev_assoc, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_assoc_request *req), + TP_ARGS(wiphy, netdev, req), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + MAC_ENTRY(prev_bssid) + __field(bool, use_mfp) + __field(u32, flags) + __dynamic_array(u8, elements, req->ie_len) + __array(u8, ht_capa, sizeof(struct ieee80211_ht_cap)) + __array(u8, ht_capa_mask, sizeof(struct ieee80211_ht_cap)) + __array(u8, vht_capa, sizeof(struct ieee80211_vht_cap)) + __array(u8, vht_capa_mask, sizeof(struct ieee80211_vht_cap)) + __dynamic_array(u8, fils_kek, req->fils_kek_len) + __dynamic_array(u8, fils_nonces, + req->fils_nonces ? 
2 * FILS_NONCE_LEN : 0) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + if (req->bss) + MAC_ASSIGN(bssid, req->bss->bssid); + else + eth_zero_addr(__entry->bssid); + MAC_ASSIGN(prev_bssid, req->prev_bssid); + __entry->use_mfp = req->use_mfp; + __entry->flags = req->flags; + if (req->ie) + memcpy(__get_dynamic_array(elements), + req->ie, req->ie_len); + memcpy(__entry->ht_capa, &req->ht_capa, sizeof(req->ht_capa)); + memcpy(__entry->ht_capa_mask, &req->ht_capa_mask, + sizeof(req->ht_capa_mask)); + memcpy(__entry->vht_capa, &req->vht_capa, sizeof(req->vht_capa)); + memcpy(__entry->vht_capa_mask, &req->vht_capa_mask, + sizeof(req->vht_capa_mask)); + if (req->fils_kek) + memcpy(__get_dynamic_array(fils_kek), + req->fils_kek, req->fils_kek_len); + if (req->fils_nonces) + memcpy(__get_dynamic_array(fils_nonces), + req->fils_nonces, 2 * FILS_NONCE_LEN); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT + ", previous bssid: " MAC_PR_FMT ", use mfp: %s, flags: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid), + MAC_PR_ARG(prev_bssid), BOOL_TO_STR(__entry->use_mfp), + __entry->flags) +); + +TRACE_EVENT(rdev_deauth, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_deauth_request *req), + TP_ARGS(wiphy, netdev, req), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __field(u16, reason_code) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, req->bssid); + __entry->reason_code = req->reason_code; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT ", reason: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid), + __entry->reason_code) +); + +TRACE_EVENT(rdev_disassoc, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_disassoc_request *req), + TP_ARGS(wiphy, netdev, req), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __field(u16, reason_code) + __field(bool, local_state_change) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, req->ap_addr); + __entry->reason_code = req->reason_code; + __entry->local_state_change = req->local_state_change; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT + ", reason: %u, local state change: %s", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid), + __entry->reason_code, + BOOL_TO_STR(__entry->local_state_change)) +); + +TRACE_EVENT(rdev_mgmt_tx_cancel_wait, + TP_PROTO(struct wiphy *wiphy, + struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie: %llu ", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->cookie) +); + +TRACE_EVENT(rdev_set_power_mgmt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + bool enabled, int timeout), + TP_ARGS(wiphy, netdev, enabled, timeout), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(bool, enabled) + __field(int, timeout) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->enabled = enabled; + __entry->timeout = timeout; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", %senabled, timeout: %d ", + WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->enabled ? 
"" : "not ", __entry->timeout) +); + +TRACE_EVENT(rdev_connect, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_connect_params *sme), + TP_ARGS(wiphy, netdev, sme), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __array(char, ssid, IEEE80211_MAX_SSID_LEN + 1) + __field(enum nl80211_auth_type, auth_type) + __field(bool, privacy) + __field(u32, wpa_versions) + __field(u32, flags) + MAC_ENTRY(prev_bssid) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, sme->bssid); + memset(__entry->ssid, 0, IEEE80211_MAX_SSID_LEN + 1); + memcpy(__entry->ssid, sme->ssid, sme->ssid_len); + __entry->auth_type = sme->auth_type; + __entry->privacy = sme->privacy; + __entry->wpa_versions = sme->crypto.wpa_versions; + __entry->flags = sme->flags; + MAC_ASSIGN(prev_bssid, sme->prev_bssid); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT + ", ssid: %s, auth type: %d, privacy: %s, wpa versions: %u, " + "flags: %u, previous bssid: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid), __entry->ssid, + __entry->auth_type, BOOL_TO_STR(__entry->privacy), + __entry->wpa_versions, __entry->flags, MAC_PR_ARG(prev_bssid)) +); + +TRACE_EVENT(rdev_update_connect_params, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_connect_params *sme, u32 changed), + TP_ARGS(wiphy, netdev, sme, changed), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u32, changed) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->changed = changed; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", parameters changed: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->changed) +); + +TRACE_EVENT(rdev_set_cqm_rssi_config, + TP_PROTO(struct wiphy *wiphy, + struct net_device *netdev, s32 rssi_thold, + u32 rssi_hyst), + TP_ARGS(wiphy, netdev, rssi_thold, rssi_hyst), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(s32, rssi_thold) + __field(u32, rssi_hyst) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->rssi_thold = rssi_thold; + __entry->rssi_hyst = rssi_hyst; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT + ", rssi_thold: %d, rssi_hyst: %u ", + WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->rssi_thold, __entry->rssi_hyst) +); + +TRACE_EVENT(rdev_set_cqm_rssi_range_config, + TP_PROTO(struct wiphy *wiphy, + struct net_device *netdev, s32 low, s32 high), + TP_ARGS(wiphy, netdev, low, high), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(s32, rssi_low) + __field(s32, rssi_high) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->rssi_low = low; + __entry->rssi_high = high; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT + ", range: %d - %d ", + WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->rssi_low, __entry->rssi_high) +); + +TRACE_EVENT(rdev_set_cqm_txe_config, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u32 rate, + u32 pkts, u32 intvl), + TP_ARGS(wiphy, netdev, rate, pkts, intvl), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u32, rate) + __field(u32, pkts) + __field(u32, intvl) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->rate = rate; + __entry->pkts = pkts; + __entry->intvl = intvl; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", rate: %u, packets: %u, interval: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->rate, __entry->pkts, + __entry->intvl) +); + +TRACE_EVENT(rdev_disconnect, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u16 reason_code), + TP_ARGS(wiphy, netdev, reason_code), + 
TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u16, reason_code) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->reason_code = reason_code; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", reason code: %u", WIPHY_PR_ARG, + NETDEV_PR_ARG, __entry->reason_code) +); + +TRACE_EVENT(rdev_join_ibss, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_ibss_params *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __array(char, ssid, IEEE80211_MAX_SSID_LEN + 1) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, params->bssid); + memset(__entry->ssid, 0, IEEE80211_MAX_SSID_LEN + 1); + memcpy(__entry->ssid, params->ssid, params->ssid_len); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT ", ssid: %s", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid), __entry->ssid) +); + +TRACE_EVENT(rdev_join_ocb, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const struct ocb_setup *setup), + TP_ARGS(wiphy, netdev, setup), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG) +); + +TRACE_EVENT(rdev_set_wiphy_params, + TP_PROTO(struct wiphy *wiphy, u32 changed), + TP_ARGS(wiphy, changed), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(u32, changed) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->changed = changed; + ), + TP_printk(WIPHY_PR_FMT ", changed: %u", + WIPHY_PR_ARG, __entry->changed) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_get_tx_power, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_set_tx_power, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + enum nl80211_tx_power_setting type, int mbm), + TP_ARGS(wiphy, wdev, type, mbm), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(enum nl80211_tx_power_setting, type) + __field(int, mbm) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->type = type; + __entry->mbm = mbm; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", type: %u, mbm: %d", + WIPHY_PR_ARG, WDEV_PR_ARG,__entry->type, __entry->mbm) +); + +TRACE_EVENT(rdev_return_int_int, + TP_PROTO(struct wiphy *wiphy, int func_ret, int func_fill), + TP_ARGS(wiphy, func_ret, func_fill), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, func_ret) + __field(int, func_fill) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->func_ret = func_ret; + __entry->func_fill = func_fill; + ), + TP_printk(WIPHY_PR_FMT ", function returns: %d, function filled: %d", + WIPHY_PR_ARG, __entry->func_ret, __entry->func_fill) +); + +#ifdef CONFIG_NL80211_TESTMODE +TRACE_EVENT(rdev_testmode_cmd, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT WDEV_PR_FMT, WIPHY_PR_ARG, WDEV_PR_ARG) +); + +TRACE_EVENT(rdev_testmode_dump, + TP_PROTO(struct wiphy *wiphy), + TP_ARGS(wiphy), + TP_STRUCT__entry( + WIPHY_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT, WIPHY_PR_ARG) +); +#endif /* CONFIG_NL80211_TESTMODE */ + +TRACE_EVENT(rdev_set_bitrate_mask, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + unsigned int link_id, + const u8 *peer, const struct cfg80211_bitrate_mask *mask), + TP_ARGS(wiphy, netdev, link_id, peer, mask), + TP_STRUCT__entry( + 
WIPHY_ENTRY + NETDEV_ENTRY + __field(unsigned int, link_id) + MAC_ENTRY(peer) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->link_id = link_id; + MAC_ASSIGN(peer, peer); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", link_id: %d, peer: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->link_id, + MAC_PR_ARG(peer)) +); + +TRACE_EVENT(rdev_update_mgmt_frame_registrations, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct mgmt_frame_regs *upd), + TP_ARGS(wiphy, wdev, upd), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u16, global_stypes) + __field(u16, interface_stypes) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->global_stypes = upd->global_stypes; + __entry->interface_stypes = upd->interface_stypes; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", global: 0x%.2x, intf: 0x%.2x", + WIPHY_PR_ARG, WDEV_PR_ARG, + __entry->global_stypes, __entry->interface_stypes) +); + +TRACE_EVENT(rdev_return_int_tx_rx, + TP_PROTO(struct wiphy *wiphy, int ret, u32 tx, u32 rx), + TP_ARGS(wiphy, ret, tx, rx), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + __field(u32, tx) + __field(u32, rx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->ret = ret; + __entry->tx = tx; + __entry->rx = rx; + ), + TP_printk(WIPHY_PR_FMT ", returned %d, tx: %u, rx: %u", + WIPHY_PR_ARG, __entry->ret, __entry->tx, __entry->rx) +); + +TRACE_EVENT(rdev_return_void_tx_rx, + TP_PROTO(struct wiphy *wiphy, u32 tx, u32 tx_max, + u32 rx, u32 rx_max), + TP_ARGS(wiphy, tx, tx_max, rx, rx_max), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(u32, tx) + __field(u32, tx_max) + __field(u32, rx) + __field(u32, rx_max) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->tx = tx; + __entry->tx_max = tx_max; + __entry->rx = rx; + __entry->rx_max = rx_max; + ), + TP_printk(WIPHY_PR_FMT ", tx: %u, tx_max: %u, rx: %u, rx_max: %u ", + WIPHY_PR_ARG, __entry->tx, __entry->tx_max, __entry->rx, + __entry->rx_max) +); + +DECLARE_EVENT_CLASS(tx_rx_evt, + TP_PROTO(struct wiphy *wiphy, u32 tx, u32 rx), + TP_ARGS(wiphy, rx, tx), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(u32, tx) + __field(u32, rx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->tx = tx; + __entry->rx = rx; + ), + TP_printk(WIPHY_PR_FMT ", tx: %u, rx: %u ", + WIPHY_PR_ARG, __entry->tx, __entry->rx) +); + +DEFINE_EVENT(tx_rx_evt, rdev_set_antenna, + TP_PROTO(struct wiphy *wiphy, u32 tx, u32 rx), + TP_ARGS(wiphy, rx, tx) +); + +DECLARE_EVENT_CLASS(wiphy_netdev_id_evt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u64 id), + TP_ARGS(wiphy, netdev, id), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u64, id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->id = id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", id: %llu", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->id) +); + +DEFINE_EVENT(wiphy_netdev_id_evt, rdev_sched_scan_start, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u64 id), + TP_ARGS(wiphy, netdev, id) +); + +DEFINE_EVENT(wiphy_netdev_id_evt, rdev_sched_scan_stop, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, u64 id), + TP_ARGS(wiphy, netdev, id) +); + +TRACE_EVENT(rdev_tdls_mgmt, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u8 *peer, u8 action_code, u8 dialog_token, + u16 status_code, u32 peer_capability, + bool initiator, const u8 *buf, size_t len), + TP_ARGS(wiphy, netdev, peer, action_code, dialog_token, status_code, + peer_capability, initiator, buf, len), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + 
MAC_ENTRY(peer) + __field(u8, action_code) + __field(u8, dialog_token) + __field(u16, status_code) + __field(u32, peer_capability) + __field(bool, initiator) + __dynamic_array(u8, buf, len) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->action_code = action_code; + __entry->dialog_token = dialog_token; + __entry->status_code = status_code; + __entry->peer_capability = peer_capability; + __entry->initiator = initiator; + memcpy(__get_dynamic_array(buf), buf, len); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT ", action_code: %u, " + "dialog_token: %u, status_code: %u, peer_capability: %u " + "initiator: %s buf: %#.2x ", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), + __entry->action_code, __entry->dialog_token, + __entry->status_code, __entry->peer_capability, + BOOL_TO_STR(__entry->initiator), + ((u8 *)__get_dynamic_array(buf))[0]) +); + +TRACE_EVENT(rdev_dump_survey, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, int _idx), + TP_ARGS(wiphy, netdev, _idx), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(int, idx) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->idx = _idx; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", index: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->idx) +); + +TRACE_EVENT(rdev_return_int_survey_info, + TP_PROTO(struct wiphy *wiphy, int ret, struct survey_info *info), + TP_ARGS(wiphy, ret, info), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_ENTRY + __field(int, ret) + __field(u64, time) + __field(u64, time_busy) + __field(u64, time_ext_busy) + __field(u64, time_rx) + __field(u64, time_tx) + __field(u64, time_scan) + __field(u32, filled) + __field(s8, noise) + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_ASSIGN(info->channel); + __entry->ret = ret; + __entry->time = info->time; + __entry->time_busy = info->time_busy; + __entry->time_ext_busy = info->time_ext_busy; + __entry->time_rx = info->time_rx; + __entry->time_tx = info->time_tx; + __entry->time_scan = info->time_scan; + __entry->filled = info->filled; + __entry->noise = info->noise; + ), + TP_printk(WIPHY_PR_FMT ", returned: %d, " CHAN_PR_FMT + ", channel time: %llu, channel time busy: %llu, " + "channel time extension busy: %llu, channel time rx: %llu, " + "channel time tx: %llu, scan time: %llu, filled: %u, noise: %d", + WIPHY_PR_ARG, __entry->ret, CHAN_PR_ARG, + __entry->time, __entry->time_busy, + __entry->time_ext_busy, __entry->time_rx, + __entry->time_tx, __entry->time_scan, + __entry->filled, __entry->noise) +); + +TRACE_EVENT(rdev_tdls_oper, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u8 *peer, enum nl80211_tdls_operation oper), + TP_ARGS(wiphy, netdev, peer, oper), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(enum nl80211_tdls_operation, oper) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->oper = oper; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT ", oper: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), __entry->oper) +); + +DECLARE_EVENT_CLASS(rdev_pmksa, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_pmksa *pmksa), + TP_ARGS(wiphy, netdev, pmksa), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, pmksa->bssid); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(bssid)) +); + 
+TRACE_EVENT(rdev_probe_client, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *peer), + TP_ARGS(wiphy, netdev, peer), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer)) +); + +DEFINE_EVENT(rdev_pmksa, rdev_set_pmksa, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_pmksa *pmksa), + TP_ARGS(wiphy, netdev, pmksa) +); + +DEFINE_EVENT(rdev_pmksa, rdev_del_pmksa, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_pmksa *pmksa), + TP_ARGS(wiphy, netdev, pmksa) +); + +TRACE_EVENT(rdev_remain_on_channel, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct ieee80211_channel *chan, + unsigned int duration), + TP_ARGS(wiphy, wdev, chan, duration), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + CHAN_ENTRY + __field(unsigned int, duration) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + CHAN_ASSIGN(chan); + __entry->duration = duration; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", " CHAN_PR_FMT ", duration: %u", + WIPHY_PR_ARG, WDEV_PR_ARG, CHAN_PR_ARG, __entry->duration) +); + +TRACE_EVENT(rdev_return_int_cookie, + TP_PROTO(struct wiphy *wiphy, int ret, u64 cookie), + TP_ARGS(wiphy, ret, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->ret = ret; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", returned %d, cookie: %llu", + WIPHY_PR_ARG, __entry->ret, __entry->cookie) +); + +TRACE_EVENT(rdev_cancel_remain_on_channel, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie: %llu", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->cookie) +); + +TRACE_EVENT(rdev_mgmt_tx, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_mgmt_tx_params *params), + TP_ARGS(wiphy, wdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + CHAN_ENTRY + __field(bool, offchan) + __field(unsigned int, wait) + __field(bool, no_cck) + __field(bool, dont_wait_for_ack) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + CHAN_ASSIGN(params->chan); + __entry->offchan = params->offchan; + __entry->wait = params->wait; + __entry->no_cck = params->no_cck; + __entry->dont_wait_for_ack = params->dont_wait_for_ack; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", " CHAN_PR_FMT ", offchan: %s," + " wait: %u, no cck: %s, dont wait for ack: %s", + WIPHY_PR_ARG, WDEV_PR_ARG, CHAN_PR_ARG, + BOOL_TO_STR(__entry->offchan), __entry->wait, + BOOL_TO_STR(__entry->no_cck), + BOOL_TO_STR(__entry->dont_wait_for_ack)) +); + +TRACE_EVENT(rdev_tx_control_port, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *buf, size_t len, const u8 *dest, __be16 proto, + bool unencrypted, int link_id), + TP_ARGS(wiphy, netdev, buf, len, dest, proto, unencrypted, link_id), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dest) + __field(__be16, proto) + __field(bool, unencrypted) + __field(int, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dest, dest); + __entry->proto = proto; + __entry->unencrypted = unencrypted; 
+ __entry->link_id = link_id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT "," + " proto: 0x%x, unencrypted: %s, link: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(dest), + be16_to_cpu(__entry->proto), + BOOL_TO_STR(__entry->unencrypted), + __entry->link_id) +); + +TRACE_EVENT(rdev_set_noack_map, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u16 noack_map), + TP_ARGS(wiphy, netdev, noack_map), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u16, noack_map) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->noack_map = noack_map; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", noack_map: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->noack_map) +); + +DECLARE_EVENT_CLASS(wiphy_wdev_link_evt, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + unsigned int link_id), + TP_ARGS(wiphy, wdev, link_id), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(unsigned int, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->link_id = link_id; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", link_id: %u", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->link_id) +); + +DEFINE_EVENT(wiphy_wdev_link_evt, rdev_get_channel, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + unsigned int link_id), + TP_ARGS(wiphy, wdev, link_id) +); + +TRACE_EVENT(rdev_return_chandef, + TP_PROTO(struct wiphy *wiphy, int ret, + struct cfg80211_chan_def *chandef), + TP_ARGS(wiphy, ret, chandef), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, ret) + CHAN_DEF_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + if (ret == 0) + CHAN_DEF_ASSIGN(chandef); + else + CHAN_DEF_ASSIGN((struct cfg80211_chan_def *)NULL); + __entry->ret = ret; + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT ", ret: %d", + WIPHY_PR_ARG, CHAN_DEF_PR_ARG, __entry->ret) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_start_p2p_device, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_stop_p2p_device, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_start_nan, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf), + TP_ARGS(wiphy, wdev, conf), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u8, master_pref) + __field(u8, bands) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->master_pref = conf->master_pref; + __entry->bands = conf->bands; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT + ", master preference: %u, bands: 0x%0x", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->master_pref, + __entry->bands) +); + +TRACE_EVENT(rdev_nan_change_conf, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_nan_conf *conf, u32 changes), + TP_ARGS(wiphy, wdev, conf, changes), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u8, master_pref) + __field(u8, bands) + __field(u32, changes) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->master_pref = conf->master_pref; + __entry->bands = conf->bands; + __entry->changes = changes; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT + ", master preference: %u, bands: 0x%0x, changes: %x", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->master_pref, + __entry->bands, __entry->changes) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_stop_nan, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_add_nan_func, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + const 
struct cfg80211_nan_func *func), + TP_ARGS(wiphy, wdev, func), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u8, func_type) + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->func_type = func->type; + __entry->cookie = func->cookie + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", type=%u, cookie=%llu", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->func_type, + __entry->cookie) +); + +TRACE_EVENT(rdev_del_nan_func, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + u64 cookie), + TP_ARGS(wiphy, wdev, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie=%llu", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->cookie) +); + +TRACE_EVENT(rdev_set_mac_acl, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_acl_data *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u32, acl_policy) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->acl_policy = params->acl_policy; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", acl policy: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->acl_policy) +); + +TRACE_EVENT(rdev_update_ft_ies, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_update_ft_ies_params *ftie), + TP_ARGS(wiphy, netdev, ftie), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u16, md) + __dynamic_array(u8, ie, ftie->ie_len) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->md = ftie->md; + memcpy(__get_dynamic_array(ie), ftie->ie, ftie->ie_len); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", md: 0x%x", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->md) +); + +TRACE_EVENT(rdev_crit_proto_start, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + enum nl80211_crit_proto_id protocol, u16 duration), + TP_ARGS(wiphy, wdev, protocol, duration), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u16, proto) + __field(u16, duration) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->proto = protocol; + __entry->duration = duration; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", proto=%x, duration=%u", + WIPHY_PR_ARG, WDEV_PR_ARG, __entry->proto, __entry->duration) +); + +TRACE_EVENT(rdev_crit_proto_stop, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, + WIPHY_PR_ARG, WDEV_PR_ARG) +); + +TRACE_EVENT(rdev_channel_switch, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_csa_settings *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(bool, radar_required) + __field(bool, block_tx) + __field(u8, count) + __dynamic_array(u16, bcn_ofs, params->n_counter_offsets_beacon) + __dynamic_array(u16, pres_ofs, params->n_counter_offsets_presp) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(&params->chandef); + __entry->radar_required = params->radar_required; + __entry->block_tx = params->block_tx; + __entry->count = params->count; + memcpy(__get_dynamic_array(bcn_ofs), + params->counter_offsets_beacon, + params->n_counter_offsets_beacon * sizeof(u16)); + + /* probe response offsets are optional */ + if (params->n_counter_offsets_presp) + 
memcpy(__get_dynamic_array(pres_ofs), + params->counter_offsets_presp, + params->n_counter_offsets_presp * sizeof(u16)); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT + ", block_tx: %d, count: %u, radar_required: %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, CHAN_DEF_PR_ARG, + __entry->block_tx, __entry->count, __entry->radar_required) +); + +TRACE_EVENT(rdev_set_qos_map, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_qos_map *qos_map), + TP_ARGS(wiphy, netdev, qos_map), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + QOS_MAP_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + QOS_MAP_ASSIGN(qos_map); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", num_des: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->num_des) +); + +TRACE_EVENT(rdev_set_ap_chanwidth, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + unsigned int link_id, + struct cfg80211_chan_def *chandef), + TP_ARGS(wiphy, netdev, link_id, chandef), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(unsigned int, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->link_id = link_id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT ", link:%d", + WIPHY_PR_ARG, NETDEV_PR_ARG, CHAN_DEF_PR_ARG, + __entry->link_id) +); + +TRACE_EVENT(rdev_add_tx_ts, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u8 tsid, const u8 *peer, u8 user_prio, u16 admitted_time), + TP_ARGS(wiphy, netdev, tsid, peer, user_prio, admitted_time), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(u8, tsid) + __field(u8, user_prio) + __field(u16, admitted_time) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->tsid = tsid; + __entry->user_prio = user_prio; + __entry->admitted_time = admitted_time; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT ", TSID %d, UP %d, time %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), + __entry->tsid, __entry->user_prio, __entry->admitted_time) +); + +TRACE_EVENT(rdev_del_tx_ts, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + u8 tsid, const u8 *peer), + TP_ARGS(wiphy, netdev, tsid, peer), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(u8, tsid) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->tsid = tsid; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT ", TSID %d", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), __entry->tsid) +); + +TRACE_EVENT(rdev_tdls_channel_switch, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *addr, u8 oper_class, + struct cfg80211_chan_def *chandef), + TP_ARGS(wiphy, netdev, addr, oper_class, chandef), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(addr) + __field(u8, oper_class) + CHAN_DEF_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(addr, addr); + CHAN_DEF_ASSIGN(chandef); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT + " oper class %d, " CHAN_DEF_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(addr), + __entry->oper_class, CHAN_DEF_PR_ARG) +); + +TRACE_EVENT(rdev_tdls_cancel_channel_switch, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *addr), + TP_ARGS(wiphy, netdev, addr), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(addr) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(addr, addr); + ), + 
TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(addr)) +); + +TRACE_EVENT(rdev_set_pmk, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_pmk_conf *pmk_conf), + + TP_ARGS(wiphy, netdev, pmk_conf), + + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(aa) + __field(u8, pmk_len) + __field(u8, pmk_r0_name_len) + __dynamic_array(u8, pmk, pmk_conf->pmk_len) + __dynamic_array(u8, pmk_r0_name, WLAN_PMK_NAME_LEN) + ), + + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(aa, pmk_conf->aa); + __entry->pmk_len = pmk_conf->pmk_len; + __entry->pmk_r0_name_len = + pmk_conf->pmk_r0_name ? WLAN_PMK_NAME_LEN : 0; + memcpy(__get_dynamic_array(pmk), pmk_conf->pmk, + pmk_conf->pmk_len); + memcpy(__get_dynamic_array(pmk_r0_name), pmk_conf->pmk_r0_name, + pmk_conf->pmk_r0_name ? WLAN_PMK_NAME_LEN : 0); + ), + + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT + "pmk_len=%u, pmk: %s pmk_r0_name: %s", WIPHY_PR_ARG, + NETDEV_PR_ARG, MAC_PR_ARG(aa), __entry->pmk_len, + __print_array(__get_dynamic_array(pmk), + __get_dynamic_array_len(pmk), 1), + __entry->pmk_r0_name_len ? + __print_array(__get_dynamic_array(pmk_r0_name), + __get_dynamic_array_len(pmk_r0_name), 1) : "") +); + +TRACE_EVENT(rdev_del_pmk, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, const u8 *aa), + + TP_ARGS(wiphy, netdev, aa), + + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(aa) + ), + + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(aa, aa); + ), + + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(aa)) +); + +TRACE_EVENT(rdev_external_auth, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_external_auth_params *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry(WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(bssid) + __array(u8, ssid, IEEE80211_MAX_SSID_LEN + 1) + __field(u16, status) + ), + TP_fast_assign(WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, params->bssid); + memset(__entry->ssid, 0, IEEE80211_MAX_SSID_LEN + 1); + memcpy(__entry->ssid, params->ssid.ssid, + params->ssid.ssid_len); + __entry->status = params->status; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", bssid: " MAC_PR_FMT + ", ssid: %s, status: %u", WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->bssid, __entry->ssid, __entry->status) +); + +TRACE_EVENT(rdev_start_radar_detection, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_chan_def *chandef, + u32 cac_time_ms), + TP_ARGS(wiphy, netdev, chandef, cac_time_ms), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(u32, cac_time_ms) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->cac_time_ms = cac_time_ms; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT + ", cac_time_ms=%u", + WIPHY_PR_ARG, NETDEV_PR_ARG, CHAN_DEF_PR_ARG, + __entry->cac_time_ms) +); + +TRACE_EVENT(rdev_set_mcast_rate, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + int *mcast_rate), + TP_ARGS(wiphy, netdev, mcast_rate), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __array(int, mcast_rate, NUM_NL80211_BANDS) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + memcpy(__entry->mcast_rate, mcast_rate, + sizeof(int) * NUM_NL80211_BANDS); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " + "mcast_rates [2.4GHz=0x%x, 5.2GHz=0x%x, 6GHz=0x%x, 60GHz=0x%x]", + WIPHY_PR_ARG, NETDEV_PR_ARG, + 
__entry->mcast_rate[NL80211_BAND_2GHZ], + __entry->mcast_rate[NL80211_BAND_5GHZ], + __entry->mcast_rate[NL80211_BAND_6GHZ], + __entry->mcast_rate[NL80211_BAND_60GHZ]) +); + +TRACE_EVENT(rdev_set_coalesce, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_coalesce *coalesce), + TP_ARGS(wiphy, coalesce), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, n_rules) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->n_rules = coalesce ? coalesce->n_rules : 0; + ), + TP_printk(WIPHY_PR_FMT ", n_rules=%d", + WIPHY_PR_ARG, __entry->n_rules) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_abort_scan, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_set_multicast_to_unicast, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const bool enabled), + TP_ARGS(wiphy, netdev, enabled), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(bool, enabled) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->enabled = enabled; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", unicast: %s", + WIPHY_PR_ARG, NETDEV_PR_ARG, + BOOL_TO_STR(__entry->enabled)) +); + +DEFINE_EVENT(wiphy_wdev_evt, rdev_get_txq_stats, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev) +); + +TRACE_EVENT(rdev_get_ftm_responder_stats, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_ftm_responder_stats *ftm_stats), + + TP_ARGS(wiphy, netdev, ftm_stats), + + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u64, timestamp) + __field(u32, success_num) + __field(u32, partial_num) + __field(u32, failed_num) + __field(u32, asap_num) + __field(u32, non_asap_num) + __field(u64, duration) + __field(u32, unknown_triggers) + __field(u32, reschedule) + __field(u32, out_of_window) + ), + + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->success_num = ftm_stats->success_num; + __entry->partial_num = ftm_stats->partial_num; + __entry->failed_num = ftm_stats->failed_num; + __entry->asap_num = ftm_stats->asap_num; + __entry->non_asap_num = ftm_stats->non_asap_num; + __entry->duration = ftm_stats->total_duration_ms; + __entry->unknown_triggers = ftm_stats->unknown_triggers_num; + __entry->reschedule = ftm_stats->reschedule_requests_num; + __entry->out_of_window = ftm_stats->out_of_window_triggers_num; + ), + + TP_printk(WIPHY_PR_FMT "Ftm responder stats: success %u, partial %u, " + "failed %u, asap %u, non asap %u, total duration %llu, unknown " + "triggers %u, rescheduled %u, out of window %u", WIPHY_PR_ARG, + __entry->success_num, __entry->partial_num, __entry->failed_num, + __entry->asap_num, __entry->non_asap_num, __entry->duration, + __entry->unknown_triggers, __entry->reschedule, + __entry->out_of_window) +); + +DEFINE_EVENT(wiphy_wdev_cookie_evt, rdev_start_pmsr, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie) +); + +DEFINE_EVENT(wiphy_wdev_cookie_evt, rdev_abort_pmsr, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie) +); + +TRACE_EVENT(rdev_set_fils_aad, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_fils_aad *fils_aad), + TP_ARGS(wiphy, netdev, fils_aad), + TP_STRUCT__entry(WIPHY_ENTRY + NETDEV_ENTRY + __array(u8, macaddr, ETH_ALEN) + __field(u8, kek_len) + ), + TP_fast_assign(WIPHY_ASSIGN; + NETDEV_ASSIGN; + FILS_AAD_ASSIGN(fils_aad); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " FILS_AAD_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, __entry->macaddr, + 
__entry->kek_len) +); + +TRACE_EVENT(rdev_update_owe_info, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_update_owe_info *owe_info), + TP_ARGS(wiphy, netdev, owe_info), + TP_STRUCT__entry(WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(u16, status) + __dynamic_array(u8, ie, owe_info->ie_len)), + TP_fast_assign(WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, owe_info->peer); + __entry->status = owe_info->status; + memcpy(__get_dynamic_array(ie), + owe_info->ie, owe_info->ie_len);), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", peer: " MAC_PR_FMT + " status %d", WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), + __entry->status) +); + +TRACE_EVENT(rdev_probe_mesh_link, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *dest, const u8 *buf, size_t len), + TP_ARGS(wiphy, netdev, dest, buf, len), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(dest) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(dest, dest); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(dest)) +); + +TRACE_EVENT(rdev_set_tid_config, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_tid_config *tid_conf), + TP_ARGS(wiphy, netdev, tid_conf), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, tid_conf->peer); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", peer: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer)) +); + +TRACE_EVENT(rdev_reset_tid_config, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + const u8 *peer, u8 tids), + TP_ARGS(wiphy, netdev, peer, tids), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(u8, tids) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->tids = tids; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", peer: " MAC_PR_FMT ", tids: 0x%x", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), __entry->tids) +); + +TRACE_EVENT(rdev_set_sar_specs, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_sar_specs *sar), + TP_ARGS(wiphy, sar), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(u16, type) + __field(u16, num) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->type = sar->type; + __entry->num = sar->num_sub_specs; + + ), + TP_printk(WIPHY_PR_FMT ", Set type:%d, num_specs:%d", + WIPHY_PR_ARG, __entry->type, __entry->num) +); + +TRACE_EVENT(rdev_color_change, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_color_change_settings *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __field(u8, count) + __field(u16, bcn_ofs) + __field(u16, pres_ofs) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + __entry->count = params->count; + __entry->bcn_ofs = params->counter_offset_beacon; + __entry->pres_ofs = params->counter_offset_presp; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT + ", count: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, + __entry->count) +); + +TRACE_EVENT(rdev_set_radar_background, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_chan_def *chandef), + + TP_ARGS(wiphy, chandef), + + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_DEF_ENTRY + ), + + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_DEF_ASSIGN(chandef) + ), + + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT, + WIPHY_PR_ARG, CHAN_DEF_PR_ARG) +); + +DEFINE_EVENT(wiphy_wdev_link_evt, rdev_add_intf_link, + TP_PROTO(struct wiphy 
*wiphy, struct wireless_dev *wdev, + unsigned int link_id), + TP_ARGS(wiphy, wdev, link_id) +); + +DEFINE_EVENT(wiphy_wdev_link_evt, rdev_del_intf_link, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + unsigned int link_id), + TP_ARGS(wiphy, wdev, link_id) +); + +/************************************************************* + * cfg80211 exported functions traces * + *************************************************************/ + +TRACE_EVENT(cfg80211_return_bool, + TP_PROTO(bool ret), + TP_ARGS(ret), + TP_STRUCT__entry( + __field(bool, ret) + ), + TP_fast_assign( + __entry->ret = ret; + ), + TP_printk("returned %s", BOOL_TO_STR(__entry->ret)) +); + +DECLARE_EVENT_CLASS(cfg80211_netdev_mac_evt, + TP_PROTO(struct net_device *netdev, const u8 *macaddr), + TP_ARGS(netdev, macaddr), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(macaddr) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(macaddr, macaddr); + ), + TP_printk(NETDEV_PR_FMT ", mac: " MAC_PR_FMT, + NETDEV_PR_ARG, MAC_PR_ARG(macaddr)) +); + +DEFINE_EVENT(cfg80211_netdev_mac_evt, cfg80211_notify_new_peer_candidate, + TP_PROTO(struct net_device *netdev, const u8 *macaddr), + TP_ARGS(netdev, macaddr) +); + +DECLARE_EVENT_CLASS(netdev_evt_only, + TP_PROTO(struct net_device *netdev), + TP_ARGS(netdev), + TP_STRUCT__entry( + NETDEV_ENTRY + ), + TP_fast_assign( + NETDEV_ASSIGN; + ), + TP_printk(NETDEV_PR_FMT , NETDEV_PR_ARG) +); + +DEFINE_EVENT(netdev_evt_only, cfg80211_send_rx_auth, + TP_PROTO(struct net_device *netdev), + TP_ARGS(netdev) +); + +TRACE_EVENT(cfg80211_send_rx_assoc, + TP_PROTO(struct net_device *netdev, + struct cfg80211_rx_assoc_resp *data), + TP_ARGS(netdev, data), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(ap_addr) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(ap_addr, + data->ap_mld_addr ?: data->links[0].bss->bssid); + ), + TP_printk(NETDEV_PR_FMT ", " MAC_PR_FMT, + NETDEV_PR_ARG, MAC_PR_ARG(ap_addr)) +); + +DECLARE_EVENT_CLASS(netdev_frame_event, + TP_PROTO(struct net_device *netdev, const u8 *buf, int len), + TP_ARGS(netdev, buf, len), + TP_STRUCT__entry( + NETDEV_ENTRY + __dynamic_array(u8, frame, len) + ), + TP_fast_assign( + NETDEV_ASSIGN; + memcpy(__get_dynamic_array(frame), buf, len); + ), + TP_printk(NETDEV_PR_FMT ", ftype:0x%.2x", + NETDEV_PR_ARG, + le16_to_cpup((__le16 *)__get_dynamic_array(frame))) +); + +DEFINE_EVENT(netdev_frame_event, cfg80211_rx_unprot_mlme_mgmt, + TP_PROTO(struct net_device *netdev, const u8 *buf, int len), + TP_ARGS(netdev, buf, len) +); + +DEFINE_EVENT(netdev_frame_event, cfg80211_rx_mlme_mgmt, + TP_PROTO(struct net_device *netdev, const u8 *buf, int len), + TP_ARGS(netdev, buf, len) +); + +TRACE_EVENT(cfg80211_tx_mlme_mgmt, + TP_PROTO(struct net_device *netdev, const u8 *buf, int len, + bool reconnect), + TP_ARGS(netdev, buf, len, reconnect), + TP_STRUCT__entry( + NETDEV_ENTRY + __dynamic_array(u8, frame, len) + __field(int, reconnect) + ), + TP_fast_assign( + NETDEV_ASSIGN; + memcpy(__get_dynamic_array(frame), buf, len); + __entry->reconnect = reconnect; + ), + TP_printk(NETDEV_PR_FMT ", ftype:0x%.2x reconnect:%d", + NETDEV_PR_ARG, + le16_to_cpup((__le16 *)__get_dynamic_array(frame)), + __entry->reconnect) +); + +DECLARE_EVENT_CLASS(netdev_mac_evt, + TP_PROTO(struct net_device *netdev, const u8 *mac), + TP_ARGS(netdev, mac), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(mac) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(mac, mac) + ), + TP_printk(NETDEV_PR_FMT ", mac: " MAC_PR_FMT, + NETDEV_PR_ARG, MAC_PR_ARG(mac)) +); + 
+DEFINE_EVENT(netdev_mac_evt, cfg80211_send_auth_timeout, + TP_PROTO(struct net_device *netdev, const u8 *mac), + TP_ARGS(netdev, mac) +); + +TRACE_EVENT(cfg80211_send_assoc_failure, + TP_PROTO(struct net_device *netdev, + struct cfg80211_assoc_failure *data), + TP_ARGS(netdev, data), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(ap_addr) + __field(bool, timeout) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(ap_addr, data->ap_mld_addr ?: data->bss[0]->bssid); + __entry->timeout = data->timeout; + ), + TP_printk(NETDEV_PR_FMT ", mac: " MAC_PR_FMT ", timeout: %d", + NETDEV_PR_ARG, MAC_PR_ARG(ap_addr), __entry->timeout) +); + +TRACE_EVENT(cfg80211_michael_mic_failure, + TP_PROTO(struct net_device *netdev, const u8 *addr, + enum nl80211_key_type key_type, int key_id, const u8 *tsc), + TP_ARGS(netdev, addr, key_type, key_id, tsc), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(addr) + __field(enum nl80211_key_type, key_type) + __field(int, key_id) + __array(u8, tsc, 6) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(addr, addr); + __entry->key_type = key_type; + __entry->key_id = key_id; + if (tsc) + memcpy(__entry->tsc, tsc, 6); + ), + TP_printk(NETDEV_PR_FMT ", " MAC_PR_FMT ", key type: %d, key id: %d, tsc: %pm", + NETDEV_PR_ARG, MAC_PR_ARG(addr), __entry->key_type, + __entry->key_id, __entry->tsc) +); + +TRACE_EVENT(cfg80211_ready_on_channel, + TP_PROTO(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan, + unsigned int duration), + TP_ARGS(wdev, cookie, chan, duration), + TP_STRUCT__entry( + WDEV_ENTRY + __field(u64, cookie) + CHAN_ENTRY + __field(unsigned int, duration) + ), + TP_fast_assign( + WDEV_ASSIGN; + __entry->cookie = cookie; + CHAN_ASSIGN(chan); + __entry->duration = duration; + ), + TP_printk(WDEV_PR_FMT ", cookie: %llu, " CHAN_PR_FMT ", duration: %u", + WDEV_PR_ARG, __entry->cookie, CHAN_PR_ARG, + __entry->duration) +); + +TRACE_EVENT(cfg80211_ready_on_channel_expired, + TP_PROTO(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan), + TP_ARGS(wdev, cookie, chan), + TP_STRUCT__entry( + WDEV_ENTRY + __field(u64, cookie) + CHAN_ENTRY + ), + TP_fast_assign( + WDEV_ASSIGN; + __entry->cookie = cookie; + CHAN_ASSIGN(chan); + ), + TP_printk(WDEV_PR_FMT ", cookie: %llu, " CHAN_PR_FMT, + WDEV_PR_ARG, __entry->cookie, CHAN_PR_ARG) +); + +TRACE_EVENT(cfg80211_tx_mgmt_expired, + TP_PROTO(struct wireless_dev *wdev, u64 cookie, + struct ieee80211_channel *chan), + TP_ARGS(wdev, cookie, chan), + TP_STRUCT__entry( + WDEV_ENTRY + __field(u64, cookie) + CHAN_ENTRY + ), + TP_fast_assign( + WDEV_ASSIGN; + __entry->cookie = cookie; + CHAN_ASSIGN(chan); + ), + TP_printk(WDEV_PR_FMT ", cookie: %llu, " CHAN_PR_FMT, + WDEV_PR_ARG, __entry->cookie, CHAN_PR_ARG) +); + +TRACE_EVENT(cfg80211_new_sta, + TP_PROTO(struct net_device *netdev, const u8 *mac_addr, + struct station_info *sinfo), + TP_ARGS(netdev, mac_addr, sinfo), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(mac_addr) + SINFO_ENTRY + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(mac_addr, mac_addr); + SINFO_ASSIGN; + ), + TP_printk(NETDEV_PR_FMT ", " MAC_PR_FMT, + NETDEV_PR_ARG, MAC_PR_ARG(mac_addr)) +); + +DEFINE_EVENT(cfg80211_netdev_mac_evt, cfg80211_del_sta, + TP_PROTO(struct net_device *netdev, const u8 *macaddr), + TP_ARGS(netdev, macaddr) +); + +TRACE_EVENT(cfg80211_rx_mgmt, + TP_PROTO(struct wireless_dev *wdev, struct cfg80211_rx_info *info), + TP_ARGS(wdev, info), + TP_STRUCT__entry( + WDEV_ENTRY + __field(int, freq) + __field(int, sig_dbm) + ), + TP_fast_assign( + WDEV_ASSIGN; + 
__entry->freq = info->freq; + __entry->sig_dbm = info->sig_dbm; + ), + TP_printk(WDEV_PR_FMT ", freq: "KHZ_F", sig dbm: %d", + WDEV_PR_ARG, PR_KHZ(__entry->freq), __entry->sig_dbm) +); + +TRACE_EVENT(cfg80211_mgmt_tx_status, + TP_PROTO(struct wireless_dev *wdev, u64 cookie, bool ack), + TP_ARGS(wdev, cookie, ack), + TP_STRUCT__entry( + WDEV_ENTRY + __field(u64, cookie) + __field(bool, ack) + ), + TP_fast_assign( + WDEV_ASSIGN; + __entry->cookie = cookie; + __entry->ack = ack; + ), + TP_printk(WDEV_PR_FMT", cookie: %llu, ack: %s", + WDEV_PR_ARG, __entry->cookie, BOOL_TO_STR(__entry->ack)) +); + +TRACE_EVENT(cfg80211_control_port_tx_status, + TP_PROTO(struct wireless_dev *wdev, u64 cookie, bool ack), + TP_ARGS(wdev, cookie, ack), + TP_STRUCT__entry( + WDEV_ENTRY + __field(u64, cookie) + __field(bool, ack) + ), + TP_fast_assign( + WDEV_ASSIGN; + __entry->cookie = cookie; + __entry->ack = ack; + ), + TP_printk(WDEV_PR_FMT", cookie: %llu, ack: %s", + WDEV_PR_ARG, __entry->cookie, BOOL_TO_STR(__entry->ack)) +); + +TRACE_EVENT(cfg80211_rx_control_port, + TP_PROTO(struct net_device *netdev, struct sk_buff *skb, + bool unencrypted), + TP_ARGS(netdev, skb, unencrypted), + TP_STRUCT__entry( + NETDEV_ENTRY + __field(int, len) + MAC_ENTRY(from) + __field(u16, proto) + __field(bool, unencrypted) + ), + TP_fast_assign( + NETDEV_ASSIGN; + __entry->len = skb->len; + MAC_ASSIGN(from, eth_hdr(skb)->h_source); + __entry->proto = be16_to_cpu(skb->protocol); + __entry->unencrypted = unencrypted; + ), + TP_printk(NETDEV_PR_FMT ", len=%d, " MAC_PR_FMT ", proto: 0x%x, unencrypted: %s", + NETDEV_PR_ARG, __entry->len, MAC_PR_ARG(from), + __entry->proto, BOOL_TO_STR(__entry->unencrypted)) +); + +TRACE_EVENT(cfg80211_cqm_rssi_notify, + TP_PROTO(struct net_device *netdev, + enum nl80211_cqm_rssi_threshold_event rssi_event, + s32 rssi_level), + TP_ARGS(netdev, rssi_event, rssi_level), + TP_STRUCT__entry( + NETDEV_ENTRY + __field(enum nl80211_cqm_rssi_threshold_event, rssi_event) + __field(s32, rssi_level) + ), + TP_fast_assign( + NETDEV_ASSIGN; + __entry->rssi_event = rssi_event; + __entry->rssi_level = rssi_level; + ), + TP_printk(NETDEV_PR_FMT ", rssi event: %d, level: %d", + NETDEV_PR_ARG, __entry->rssi_event, __entry->rssi_level) +); + +TRACE_EVENT(cfg80211_reg_can_beacon, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_chan_def *chandef, + enum nl80211_iftype iftype, bool check_no_ir), + TP_ARGS(wiphy, chandef, iftype, check_no_ir), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_DEF_ENTRY + __field(enum nl80211_iftype, iftype) + __field(bool, check_no_ir) + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->iftype = iftype; + __entry->check_no_ir = check_no_ir; + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT ", iftype=%d check_no_ir=%s", + WIPHY_PR_ARG, CHAN_DEF_PR_ARG, __entry->iftype, + BOOL_TO_STR(__entry->check_no_ir)) +); + +TRACE_EVENT(cfg80211_chandef_dfs_required, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_chan_def *chandef), + TP_ARGS(wiphy, chandef), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_DEF_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT, + WIPHY_PR_ARG, CHAN_DEF_PR_ARG) +); + +TRACE_EVENT(cfg80211_ch_switch_notify, + TP_PROTO(struct net_device *netdev, + struct cfg80211_chan_def *chandef, + unsigned int link_id), + TP_ARGS(netdev, chandef, link_id), + TP_STRUCT__entry( + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(unsigned int, link_id) + ), + TP_fast_assign( + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(chandef); 
+ __entry->link_id = link_id; + ), + TP_printk(NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT ", link:%d", + NETDEV_PR_ARG, CHAN_DEF_PR_ARG, __entry->link_id) +); + +TRACE_EVENT(cfg80211_ch_switch_started_notify, + TP_PROTO(struct net_device *netdev, + struct cfg80211_chan_def *chandef, + unsigned int link_id), + TP_ARGS(netdev, chandef, link_id), + TP_STRUCT__entry( + NETDEV_ENTRY + CHAN_DEF_ENTRY + __field(unsigned int, link_id) + ), + TP_fast_assign( + NETDEV_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->link_id = link_id; + ), + TP_printk(NETDEV_PR_FMT ", " CHAN_DEF_PR_FMT ", link:%d", + NETDEV_PR_ARG, CHAN_DEF_PR_ARG, __entry->link_id) +); + +TRACE_EVENT(cfg80211_radar_event, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_chan_def *chandef, + bool offchan), + TP_ARGS(wiphy, chandef, offchan), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_DEF_ENTRY + __field(bool, offchan) + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_DEF_ASSIGN(chandef); + __entry->offchan = offchan; + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_DEF_PR_FMT ", offchan %d", + WIPHY_PR_ARG, CHAN_DEF_PR_ARG, __entry->offchan) +); + +TRACE_EVENT(cfg80211_cac_event, + TP_PROTO(struct net_device *netdev, enum nl80211_radar_event evt), + TP_ARGS(netdev, evt), + TP_STRUCT__entry( + NETDEV_ENTRY + __field(enum nl80211_radar_event, evt) + ), + TP_fast_assign( + NETDEV_ASSIGN; + __entry->evt = evt; + ), + TP_printk(NETDEV_PR_FMT ", event: %d", + NETDEV_PR_ARG, __entry->evt) +); + +DECLARE_EVENT_CLASS(cfg80211_rx_evt, + TP_PROTO(struct net_device *netdev, const u8 *addr), + TP_ARGS(netdev, addr), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(addr) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(addr, addr); + ), + TP_printk(NETDEV_PR_FMT ", " MAC_PR_FMT, NETDEV_PR_ARG, MAC_PR_ARG(addr)) +); + +DEFINE_EVENT(cfg80211_rx_evt, cfg80211_rx_spurious_frame, + TP_PROTO(struct net_device *netdev, const u8 *addr), + TP_ARGS(netdev, addr) +); + +DEFINE_EVENT(cfg80211_rx_evt, cfg80211_rx_unexpected_4addr_frame, + TP_PROTO(struct net_device *netdev, const u8 *addr), + TP_ARGS(netdev, addr) +); + +TRACE_EVENT(cfg80211_ibss_joined, + TP_PROTO(struct net_device *netdev, const u8 *bssid, + struct ieee80211_channel *channel), + TP_ARGS(netdev, bssid, channel), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(bssid) + CHAN_ENTRY + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(bssid, bssid); + CHAN_ASSIGN(channel); + ), + TP_printk(NETDEV_PR_FMT ", bssid: " MAC_PR_FMT ", " CHAN_PR_FMT, + NETDEV_PR_ARG, MAC_PR_ARG(bssid), CHAN_PR_ARG) +); + +TRACE_EVENT(cfg80211_probe_status, + TP_PROTO(struct net_device *netdev, const u8 *addr, u64 cookie, + bool acked), + TP_ARGS(netdev, addr, cookie, acked), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(addr) + __field(u64, cookie) + __field(bool, acked) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(addr, addr); + __entry->cookie = cookie; + __entry->acked = acked; + ), + TP_printk(NETDEV_PR_FMT " addr:" MAC_PR_FMT ", cookie: %llu, acked: %s", + NETDEV_PR_ARG, MAC_PR_ARG(addr), __entry->cookie, + BOOL_TO_STR(__entry->acked)) +); + +TRACE_EVENT(cfg80211_cqm_pktloss_notify, + TP_PROTO(struct net_device *netdev, const u8 *peer, u32 num_packets), + TP_ARGS(netdev, peer, num_packets), + TP_STRUCT__entry( + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(u32, num_packets) + ), + TP_fast_assign( + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->num_packets = num_packets; + ), + TP_printk(NETDEV_PR_FMT ", peer: " MAC_PR_FMT ", num of lost packets: %u", + NETDEV_PR_ARG, MAC_PR_ARG(peer), __entry->num_packets) +); + 
+DEFINE_EVENT(cfg80211_netdev_mac_evt, cfg80211_gtk_rekey_notify, + TP_PROTO(struct net_device *netdev, const u8 *macaddr), + TP_ARGS(netdev, macaddr) +); + +TRACE_EVENT(cfg80211_pmksa_candidate_notify, + TP_PROTO(struct net_device *netdev, int index, const u8 *bssid, + bool preauth), + TP_ARGS(netdev, index, bssid, preauth), + TP_STRUCT__entry( + NETDEV_ENTRY + __field(int, index) + MAC_ENTRY(bssid) + __field(bool, preauth) + ), + TP_fast_assign( + NETDEV_ASSIGN; + __entry->index = index; + MAC_ASSIGN(bssid, bssid); + __entry->preauth = preauth; + ), + TP_printk(NETDEV_PR_FMT ", index:%d, bssid: " MAC_PR_FMT ", pre auth: %s", + NETDEV_PR_ARG, __entry->index, MAC_PR_ARG(bssid), + BOOL_TO_STR(__entry->preauth)) +); + +TRACE_EVENT(cfg80211_report_obss_beacon, + TP_PROTO(struct wiphy *wiphy, const u8 *frame, size_t len, + int freq, int sig_dbm), + TP_ARGS(wiphy, frame, len, freq, sig_dbm), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(int, freq) + __field(int, sig_dbm) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->freq = freq; + __entry->sig_dbm = sig_dbm; + ), + TP_printk(WIPHY_PR_FMT ", freq: "KHZ_F", sig_dbm: %d", + WIPHY_PR_ARG, PR_KHZ(__entry->freq), __entry->sig_dbm) +); + +TRACE_EVENT(cfg80211_tdls_oper_request, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, const u8 *peer, + enum nl80211_tdls_operation oper, u16 reason_code), + TP_ARGS(wiphy, netdev, peer, oper, reason_code), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __field(enum nl80211_tdls_operation, oper) + __field(u16, reason_code) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, peer); + __entry->oper = oper; + __entry->reason_code = reason_code; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", peer: " MAC_PR_FMT ", oper: %d, reason_code %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer), __entry->oper, + __entry->reason_code) + ); + +TRACE_EVENT(cfg80211_scan_done, + TP_PROTO(struct cfg80211_scan_request *request, + struct cfg80211_scan_info *info), + TP_ARGS(request, info), + TP_STRUCT__entry( + __field(u32, n_channels) + __dynamic_array(u8, ie, request ? request->ie_len : 0) + __array(u32, rates, NUM_NL80211_BANDS) + __field(u32, wdev_id) + MAC_ENTRY(wiphy_mac) + __field(bool, no_cck) + __field(bool, aborted) + __field(u64, scan_start_tsf) + MAC_ENTRY(tsf_bssid) + ), + TP_fast_assign( + if (request) { + memcpy(__get_dynamic_array(ie), request->ie, + request->ie_len); + memcpy(__entry->rates, request->rates, + NUM_NL80211_BANDS); + __entry->wdev_id = request->wdev ? 
+ request->wdev->identifier : 0; + if (request->wiphy) + MAC_ASSIGN(wiphy_mac, + request->wiphy->perm_addr); + __entry->no_cck = request->no_cck; + } + if (info) { + __entry->aborted = info->aborted; + __entry->scan_start_tsf = info->scan_start_tsf; + MAC_ASSIGN(tsf_bssid, info->tsf_bssid); + } + ), + TP_printk("aborted: %s, scan start (TSF): %llu, tsf_bssid: " MAC_PR_FMT, + BOOL_TO_STR(__entry->aborted), + (unsigned long long)__entry->scan_start_tsf, + MAC_PR_ARG(tsf_bssid)) +); + +DECLARE_EVENT_CLASS(wiphy_id_evt, + TP_PROTO(struct wiphy *wiphy, u64 id), + TP_ARGS(wiphy, id), + TP_STRUCT__entry( + WIPHY_ENTRY + __field(u64, id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + __entry->id = id; + ), + TP_printk(WIPHY_PR_FMT ", id: %llu", WIPHY_PR_ARG, __entry->id) +); + +DEFINE_EVENT(wiphy_id_evt, cfg80211_sched_scan_stopped, + TP_PROTO(struct wiphy *wiphy, u64 id), + TP_ARGS(wiphy, id) +); + +DEFINE_EVENT(wiphy_id_evt, cfg80211_sched_scan_results, + TP_PROTO(struct wiphy *wiphy, u64 id), + TP_ARGS(wiphy, id) +); + +TRACE_EVENT(cfg80211_get_bss, + TP_PROTO(struct wiphy *wiphy, struct ieee80211_channel *channel, + const u8 *bssid, const u8 *ssid, size_t ssid_len, + enum ieee80211_bss_type bss_type, + enum ieee80211_privacy privacy), + TP_ARGS(wiphy, channel, bssid, ssid, ssid_len, bss_type, privacy), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_ENTRY + MAC_ENTRY(bssid) + __dynamic_array(u8, ssid, ssid_len) + __field(enum ieee80211_bss_type, bss_type) + __field(enum ieee80211_privacy, privacy) + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_ASSIGN(channel); + MAC_ASSIGN(bssid, bssid); + memcpy(__get_dynamic_array(ssid), ssid, ssid_len); + __entry->bss_type = bss_type; + __entry->privacy = privacy; + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_PR_FMT ", " MAC_PR_FMT + ", buf: %#.2x, bss_type: %d, privacy: %d", + WIPHY_PR_ARG, CHAN_PR_ARG, MAC_PR_ARG(bssid), + ((u8 *)__get_dynamic_array(ssid))[0], __entry->bss_type, + __entry->privacy) +); + +TRACE_EVENT(cfg80211_inform_bss_frame, + TP_PROTO(struct wiphy *wiphy, struct cfg80211_inform_bss *data, + struct ieee80211_mgmt *mgmt, size_t len), + TP_ARGS(wiphy, data, mgmt, len), + TP_STRUCT__entry( + WIPHY_ENTRY + CHAN_ENTRY + __field(enum nl80211_bss_scan_width, scan_width) + __dynamic_array(u8, mgmt, len) + __field(s32, signal) + __field(u64, ts_boottime) + __field(u64, parent_tsf) + MAC_ENTRY(parent_bssid) + ), + TP_fast_assign( + WIPHY_ASSIGN; + CHAN_ASSIGN(data->chan); + __entry->scan_width = data->scan_width; + if (mgmt) + memcpy(__get_dynamic_array(mgmt), mgmt, len); + __entry->signal = data->signal; + __entry->ts_boottime = data->boottime_ns; + __entry->parent_tsf = data->parent_tsf; + MAC_ASSIGN(parent_bssid, data->parent_bssid); + ), + TP_printk(WIPHY_PR_FMT ", " CHAN_PR_FMT + "(scan_width: %d) signal: %d, tsb:%llu, detect_tsf:%llu, tsf_bssid: " + MAC_PR_FMT, WIPHY_PR_ARG, CHAN_PR_ARG, __entry->scan_width, + __entry->signal, (unsigned long long)__entry->ts_boottime, + (unsigned long long)__entry->parent_tsf, + MAC_PR_ARG(parent_bssid)) +); + +DECLARE_EVENT_CLASS(cfg80211_bss_evt, + TP_PROTO(struct cfg80211_bss *pub), + TP_ARGS(pub), + TP_STRUCT__entry( + MAC_ENTRY(bssid) + CHAN_ENTRY + ), + TP_fast_assign( + MAC_ASSIGN(bssid, pub->bssid); + CHAN_ASSIGN(pub->channel); + ), + TP_printk(MAC_PR_FMT ", " CHAN_PR_FMT, MAC_PR_ARG(bssid), CHAN_PR_ARG) +); + +DEFINE_EVENT(cfg80211_bss_evt, cfg80211_return_bss, + TP_PROTO(struct cfg80211_bss *pub), + TP_ARGS(pub) +); + +TRACE_EVENT(cfg80211_return_uint, + TP_PROTO(unsigned int ret), + TP_ARGS(ret), + TP_STRUCT__entry( 
+ __field(unsigned int, ret) + ), + TP_fast_assign( + __entry->ret = ret; + ), + TP_printk("ret: %d", __entry->ret) +); + +TRACE_EVENT(cfg80211_return_u32, + TP_PROTO(u32 ret), + TP_ARGS(ret), + TP_STRUCT__entry( + __field(u32, ret) + ), + TP_fast_assign( + __entry->ret = ret; + ), + TP_printk("ret: %u", __entry->ret) +); + +TRACE_EVENT(cfg80211_report_wowlan_wakeup, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + struct cfg80211_wowlan_wakeup *wakeup), + TP_ARGS(wiphy, wdev, wakeup), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(bool, non_wireless) + __field(bool, disconnect) + __field(bool, magic_pkt) + __field(bool, gtk_rekey_failure) + __field(bool, eap_identity_req) + __field(bool, four_way_handshake) + __field(bool, rfkill_release) + __field(s32, pattern_idx) + __field(u32, packet_len) + __dynamic_array(u8, packet, + wakeup ? wakeup->packet_present_len : 0) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->non_wireless = !wakeup; + __entry->disconnect = wakeup ? wakeup->disconnect : false; + __entry->magic_pkt = wakeup ? wakeup->magic_pkt : false; + __entry->gtk_rekey_failure = wakeup ? wakeup->gtk_rekey_failure : false; + __entry->eap_identity_req = wakeup ? wakeup->eap_identity_req : false; + __entry->four_way_handshake = wakeup ? wakeup->four_way_handshake : false; + __entry->rfkill_release = wakeup ? wakeup->rfkill_release : false; + __entry->pattern_idx = wakeup ? wakeup->pattern_idx : false; + __entry->packet_len = wakeup ? wakeup->packet_len : false; + if (wakeup && wakeup->packet && wakeup->packet_present_len) + memcpy(__get_dynamic_array(packet), wakeup->packet, + wakeup->packet_present_len); + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, WIPHY_PR_ARG, WDEV_PR_ARG) +); + +TRACE_EVENT(cfg80211_ft_event, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_ft_event_params *ft_event), + TP_ARGS(wiphy, netdev, ft_event), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __dynamic_array(u8, ies, ft_event->ies_len) + MAC_ENTRY(target_ap) + __dynamic_array(u8, ric_ies, ft_event->ric_ies_len) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + if (ft_event->ies) + memcpy(__get_dynamic_array(ies), ft_event->ies, + ft_event->ies_len); + MAC_ASSIGN(target_ap, ft_event->target_ap); + if (ft_event->ric_ies) + memcpy(__get_dynamic_array(ric_ies), ft_event->ric_ies, + ft_event->ric_ies_len); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", target_ap: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(target_ap)) +); + +TRACE_EVENT(cfg80211_stop_iface, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev), + TP_ARGS(wiphy, wdev), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT, + WIPHY_PR_ARG, WDEV_PR_ARG) +); + +TRACE_EVENT(cfg80211_pmsr_report, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, + u64 cookie, const u8 *addr), + TP_ARGS(wiphy, wdev, cookie, addr), + TP_STRUCT__entry( + WIPHY_ENTRY + WDEV_ENTRY + __field(u64, cookie) + MAC_ENTRY(addr) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + MAC_ASSIGN(addr, addr); + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie:%lld, " MAC_PR_FMT, + WIPHY_PR_ARG, WDEV_PR_ARG, + (unsigned long long)__entry->cookie, + MAC_PR_ARG(addr)) +); + +TRACE_EVENT(cfg80211_pmsr_complete, + TP_PROTO(struct wiphy *wiphy, struct wireless_dev *wdev, u64 cookie), + TP_ARGS(wiphy, wdev, cookie), + TP_STRUCT__entry( + WIPHY_ENTRY + 
WDEV_ENTRY + __field(u64, cookie) + ), + TP_fast_assign( + WIPHY_ASSIGN; + WDEV_ASSIGN; + __entry->cookie = cookie; + ), + TP_printk(WIPHY_PR_FMT ", " WDEV_PR_FMT ", cookie:%lld", + WIPHY_PR_ARG, WDEV_PR_ARG, + (unsigned long long)__entry->cookie) +); + +TRACE_EVENT(cfg80211_update_owe_info_event, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct cfg80211_update_owe_info *owe_info), + TP_ARGS(wiphy, netdev, owe_info), + TP_STRUCT__entry(WIPHY_ENTRY + NETDEV_ENTRY + MAC_ENTRY(peer) + __dynamic_array(u8, ie, owe_info->ie_len)), + TP_fast_assign(WIPHY_ASSIGN; + NETDEV_ASSIGN; + MAC_ASSIGN(peer, owe_info->peer); + memcpy(__get_dynamic_array(ie), owe_info->ie, + owe_info->ie_len);), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", peer: " MAC_PR_FMT, + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(peer)) +); + +TRACE_EVENT(cfg80211_bss_color_notify, + TP_PROTO(struct net_device *netdev, + enum nl80211_commands cmd, + u8 count, u64 color_bitmap), + TP_ARGS(netdev, cmd, count, color_bitmap), + TP_STRUCT__entry( + NETDEV_ENTRY + __field(u32, cmd) + __field(u8, count) + __field(u64, color_bitmap) + ), + TP_fast_assign( + NETDEV_ASSIGN; + __entry->cmd = cmd; + __entry->count = count; + __entry->color_bitmap = color_bitmap; + ), + TP_printk(NETDEV_PR_FMT ", cmd: %x, count: %u, bitmap: %llx", + NETDEV_PR_ARG, __entry->cmd, __entry->count, + __entry->color_bitmap) +); + +TRACE_EVENT(cfg80211_assoc_comeback, + TP_PROTO(struct wireless_dev *wdev, const u8 *ap_addr, u32 timeout), + TP_ARGS(wdev, ap_addr, timeout), + TP_STRUCT__entry( + WDEV_ENTRY + MAC_ENTRY(ap_addr) + __field(u32, timeout) + ), + TP_fast_assign( + WDEV_ASSIGN; + MAC_ASSIGN(ap_addr, ap_addr); + __entry->timeout = timeout; + ), + TP_printk(WDEV_PR_FMT ", " MAC_PR_FMT ", timeout: %u TUs", + WDEV_PR_ARG, MAC_PR_ARG(ap_addr), __entry->timeout) +); + +DECLARE_EVENT_CLASS(link_station_add_mod, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct link_station_parameters *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __array(u8, mld_mac, 6) + __array(u8, link_mac, 6) + __field(u32, link_id) + __dynamic_array(u8, supported_rates, + params->supported_rates_len) + __array(u8, ht_capa, (int)sizeof(struct ieee80211_ht_cap)) + __array(u8, vht_capa, (int)sizeof(struct ieee80211_vht_cap)) + __field(u8, opmode_notif) + __field(bool, opmode_notif_used) + __dynamic_array(u8, he_capa, params->he_capa_len) + __array(u8, he_6ghz_capa, (int)sizeof(struct ieee80211_he_6ghz_capa)) + __dynamic_array(u8, eht_capa, params->eht_capa_len) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + memset(__entry->mld_mac, 0, 6); + memset(__entry->link_mac, 0, 6); + if (params->mld_mac) + memcpy(__entry->mld_mac, params->mld_mac, 6); + if (params->link_mac) + memcpy(__entry->link_mac, params->link_mac, 6); + __entry->link_id = params->link_id; + if (params->supported_rates && params->supported_rates_len) + memcpy(__get_dynamic_array(supported_rates), + params->supported_rates, + params->supported_rates_len); + memset(__entry->ht_capa, 0, sizeof(struct ieee80211_ht_cap)); + if (params->ht_capa) + memcpy(__entry->ht_capa, params->ht_capa, + sizeof(struct ieee80211_ht_cap)); + memset(__entry->vht_capa, 0, sizeof(struct ieee80211_vht_cap)); + if (params->vht_capa) + memcpy(__entry->vht_capa, params->vht_capa, + sizeof(struct ieee80211_vht_cap)); + __entry->opmode_notif = params->opmode_notif; + __entry->opmode_notif_used = params->opmode_notif_used; + if (params->he_capa && params->he_capa_len) + 
memcpy(__get_dynamic_array(he_capa), params->he_capa, + params->he_capa_len); + memset(__entry->he_6ghz_capa, 0, sizeof(struct ieee80211_he_6ghz_capa)); + if (params->he_6ghz_capa) + memcpy(__entry->he_6ghz_capa, params->he_6ghz_capa, + sizeof(struct ieee80211_he_6ghz_capa)); + if (params->eht_capa && params->eht_capa_len) + memcpy(__get_dynamic_array(eht_capa), params->eht_capa, + params->eht_capa_len); + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", station mac: " MAC_PR_FMT + ", link mac: " MAC_PR_FMT ", link id: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(mld_mac), + MAC_PR_ARG(link_mac), __entry->link_id) +); + +DEFINE_EVENT(link_station_add_mod, rdev_add_link_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct link_station_parameters *params), + TP_ARGS(wiphy, netdev, params) +); + +DEFINE_EVENT(link_station_add_mod, rdev_mod_link_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct link_station_parameters *params), + TP_ARGS(wiphy, netdev, params) +); + +TRACE_EVENT(rdev_del_link_station, + TP_PROTO(struct wiphy *wiphy, struct net_device *netdev, + struct link_station_del_parameters *params), + TP_ARGS(wiphy, netdev, params), + TP_STRUCT__entry( + WIPHY_ENTRY + NETDEV_ENTRY + __array(u8, mld_mac, 6) + __field(u32, link_id) + ), + TP_fast_assign( + WIPHY_ASSIGN; + NETDEV_ASSIGN; + memset(__entry->mld_mac, 0, 6); + if (params->mld_mac) + memcpy(__entry->mld_mac, params->mld_mac, 6); + __entry->link_id = params->link_id; + ), + TP_printk(WIPHY_PR_FMT ", " NETDEV_PR_FMT ", station mac: " MAC_PR_FMT + ", link id: %u", + WIPHY_PR_ARG, NETDEV_PR_ARG, MAC_PR_ARG(mld_mac), + __entry->link_id) +); + +#endif /* !__RDEV_OPS_TRACE || TRACE_HEADER_MULTI_READ */ + +#undef TRACE_INCLUDE_PATH +#define TRACE_INCLUDE_PATH . 
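+ +/* + * TRACE_INCLUDE_PATH and TRACE_INCLUDE_FILE point the tracepoint machinery + * at this header (./trace) so that the final include below can expand the + * event definitions above into the actual trace code. + */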
+#undef TRACE_INCLUDE_FILE +#define TRACE_INCLUDE_FILE trace +#include diff --git a/net/wireless/util.c b/net/wireless/util.c new file mode 100644 index 000000000..f433f3fdd --- /dev/null +++ b/net/wireless/util.c @@ -0,0 +1,2517 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * Wireless utility functions + * + * Copyright 2007-2009 Johannes Berg + * Copyright 2013-2014 Intel Mobile Communications GmbH + * Copyright 2017 Intel Deutschland GmbH + * Copyright (C) 2018-2023 Intel Corporation + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "core.h" +#include "rdev-ops.h" + + +const struct ieee80211_rate * +ieee80211_get_response_rate(struct ieee80211_supported_band *sband, + u32 basic_rates, int bitrate) +{ + struct ieee80211_rate *result = &sband->bitrates[0]; + int i; + + for (i = 0; i < sband->n_bitrates; i++) { + if (!(basic_rates & BIT(i))) + continue; + if (sband->bitrates[i].bitrate > bitrate) + continue; + result = &sband->bitrates[i]; + } + + return result; +} +EXPORT_SYMBOL(ieee80211_get_response_rate); + +u32 ieee80211_mandatory_rates(struct ieee80211_supported_band *sband, + enum nl80211_bss_scan_width scan_width) +{ + struct ieee80211_rate *bitrates; + u32 mandatory_rates = 0; + enum ieee80211_rate_flags mandatory_flag; + int i; + + if (WARN_ON(!sband)) + return 1; + + if (sband->band == NL80211_BAND_2GHZ) { + if (scan_width == NL80211_BSS_CHAN_WIDTH_5 || + scan_width == NL80211_BSS_CHAN_WIDTH_10) + mandatory_flag = IEEE80211_RATE_MANDATORY_G; + else + mandatory_flag = IEEE80211_RATE_MANDATORY_B; + } else { + mandatory_flag = IEEE80211_RATE_MANDATORY_A; + } + + bitrates = sband->bitrates; + for (i = 0; i < sband->n_bitrates; i++) + if (bitrates[i].flags & mandatory_flag) + mandatory_rates |= BIT(i); + return mandatory_rates; +} +EXPORT_SYMBOL(ieee80211_mandatory_rates); + +u32 ieee80211_channel_to_freq_khz(int chan, enum nl80211_band band) +{ + /* see 802.11 17.3.8.3.2 and Annex J + * there are overlapping channel numbers in 5GHz and 2GHz bands */ + if (chan <= 0) + return 0; /* not supported */ + switch (band) { + case NL80211_BAND_2GHZ: + case NL80211_BAND_LC: + if (chan == 14) + return MHZ_TO_KHZ(2484); + else if (chan < 14) + return MHZ_TO_KHZ(2407 + chan * 5); + break; + case NL80211_BAND_5GHZ: + if (chan >= 182 && chan <= 196) + return MHZ_TO_KHZ(4000 + chan * 5); + else + return MHZ_TO_KHZ(5000 + chan * 5); + break; + case NL80211_BAND_6GHZ: + /* see 802.11ax D6.1 27.3.23.2 */ + if (chan == 2) + return MHZ_TO_KHZ(5935); + if (chan <= 233) + return MHZ_TO_KHZ(5950 + chan * 5); + break; + case NL80211_BAND_60GHZ: + if (chan < 7) + return MHZ_TO_KHZ(56160 + chan * 2160); + break; + case NL80211_BAND_S1GHZ: + return 902000 + chan * 500; + default: + ; + } + return 0; /* not supported */ +} +EXPORT_SYMBOL(ieee80211_channel_to_freq_khz); + +enum nl80211_chan_width +ieee80211_s1g_channel_width(const struct ieee80211_channel *chan) +{ + if (WARN_ON(!chan || chan->band != NL80211_BAND_S1GHZ)) + return NL80211_CHAN_WIDTH_20_NOHT; + + /*S1G defines a single allowed channel width per channel. + * Extract that width here. 
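+ * The checks below simply pick whichever of the 1/2/4/8/16 MHz channel + * flags is set; if none is set we complain and fall back to 1 MHz.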
+ */ + if (chan->flags & IEEE80211_CHAN_1MHZ) + return NL80211_CHAN_WIDTH_1; + else if (chan->flags & IEEE80211_CHAN_2MHZ) + return NL80211_CHAN_WIDTH_2; + else if (chan->flags & IEEE80211_CHAN_4MHZ) + return NL80211_CHAN_WIDTH_4; + else if (chan->flags & IEEE80211_CHAN_8MHZ) + return NL80211_CHAN_WIDTH_8; + else if (chan->flags & IEEE80211_CHAN_16MHZ) + return NL80211_CHAN_WIDTH_16; + + pr_err("unknown channel width for channel at %dKHz?\n", + ieee80211_channel_to_khz(chan)); + + return NL80211_CHAN_WIDTH_1; +} +EXPORT_SYMBOL(ieee80211_s1g_channel_width); + +int ieee80211_freq_khz_to_channel(u32 freq) +{ + /* TODO: just handle MHz for now */ + freq = KHZ_TO_MHZ(freq); + + /* see 802.11 17.3.8.3.2 and Annex J */ + if (freq == 2484) + return 14; + else if (freq < 2484) + return (freq - 2407) / 5; + else if (freq >= 4910 && freq <= 4980) + return (freq - 4000) / 5; + else if (freq < 5925) + return (freq - 5000) / 5; + else if (freq == 5935) + return 2; + else if (freq <= 45000) /* DMG band lower limit */ + /* see 802.11ax D6.1 27.3.22.2 */ + return (freq - 5950) / 5; + else if (freq >= 58320 && freq <= 70200) + return (freq - 56160) / 2160; + else + return 0; +} +EXPORT_SYMBOL(ieee80211_freq_khz_to_channel); + +struct ieee80211_channel *ieee80211_get_channel_khz(struct wiphy *wiphy, + u32 freq) +{ + enum nl80211_band band; + struct ieee80211_supported_band *sband; + int i; + + for (band = 0; band < NUM_NL80211_BANDS; band++) { + sband = wiphy->bands[band]; + + if (!sband) + continue; + + for (i = 0; i < sband->n_channels; i++) { + struct ieee80211_channel *chan = &sband->channels[i]; + + if (ieee80211_channel_to_khz(chan) == freq) + return chan; + } + } + + return NULL; +} +EXPORT_SYMBOL(ieee80211_get_channel_khz); + +static void set_mandatory_flags_band(struct ieee80211_supported_band *sband) +{ + int i, want; + + switch (sband->band) { + case NL80211_BAND_5GHZ: + case NL80211_BAND_6GHZ: + want = 3; + for (i = 0; i < sband->n_bitrates; i++) { + if (sband->bitrates[i].bitrate == 60 || + sband->bitrates[i].bitrate == 120 || + sband->bitrates[i].bitrate == 240) { + sband->bitrates[i].flags |= + IEEE80211_RATE_MANDATORY_A; + want--; + } + } + WARN_ON(want); + break; + case NL80211_BAND_2GHZ: + case NL80211_BAND_LC: + want = 7; + for (i = 0; i < sband->n_bitrates; i++) { + switch (sband->bitrates[i].bitrate) { + case 10: + case 20: + case 55: + case 110: + sband->bitrates[i].flags |= + IEEE80211_RATE_MANDATORY_B | + IEEE80211_RATE_MANDATORY_G; + want--; + break; + case 60: + case 120: + case 240: + sband->bitrates[i].flags |= + IEEE80211_RATE_MANDATORY_G; + want--; + fallthrough; + default: + sband->bitrates[i].flags |= + IEEE80211_RATE_ERP_G; + break; + } + } + WARN_ON(want != 0 && want != 3); + break; + case NL80211_BAND_60GHZ: + /* check for mandatory HT MCS 1..4 */ + WARN_ON(!sband->ht_cap.ht_supported); + WARN_ON((sband->ht_cap.mcs.rx_mask[0] & 0x1e) != 0x1e); + break; + case NL80211_BAND_S1GHZ: + /* Figure 9-589bd: 3 means unsupported, so != 3 means at least + * mandatory is ok. 
+ */ + WARN_ON((sband->s1g_cap.nss_mcs[0] & 0x3) == 0x3); + break; + case NUM_NL80211_BANDS: + default: + WARN_ON(1); + break; + } +} + +void ieee80211_set_bitrate_flags(struct wiphy *wiphy) +{ + enum nl80211_band band; + + for (band = 0; band < NUM_NL80211_BANDS; band++) + if (wiphy->bands[band]) + set_mandatory_flags_band(wiphy->bands[band]); +} + +bool cfg80211_supported_cipher_suite(struct wiphy *wiphy, u32 cipher) +{ + int i; + for (i = 0; i < wiphy->n_cipher_suites; i++) + if (cipher == wiphy->cipher_suites[i]) + return true; + return false; +} + +static bool +cfg80211_igtk_cipher_supported(struct cfg80211_registered_device *rdev) +{ + struct wiphy *wiphy = &rdev->wiphy; + int i; + + for (i = 0; i < wiphy->n_cipher_suites; i++) { + switch (wiphy->cipher_suites[i]) { + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + return true; + } + } + + return false; +} + +bool cfg80211_valid_key_idx(struct cfg80211_registered_device *rdev, + int key_idx, bool pairwise) +{ + int max_key_idx; + + if (pairwise) + max_key_idx = 3; + else if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION) || + wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_BEACON_PROTECTION_CLIENT)) + max_key_idx = 7; + else if (cfg80211_igtk_cipher_supported(rdev)) + max_key_idx = 5; + else + max_key_idx = 3; + + if (key_idx < 0 || key_idx > max_key_idx) + return false; + + return true; +} + +int cfg80211_validate_key_settings(struct cfg80211_registered_device *rdev, + struct key_params *params, int key_idx, + bool pairwise, const u8 *mac_addr) +{ + if (!cfg80211_valid_key_idx(rdev, key_idx, pairwise)) + return -EINVAL; + + if (!pairwise && mac_addr && !(rdev->wiphy.flags & WIPHY_FLAG_IBSS_RSN)) + return -EINVAL; + + if (pairwise && !mac_addr) + return -EINVAL; + + switch (params->cipher) { + case WLAN_CIPHER_SUITE_TKIP: + /* Extended Key ID can only be used with CCMP/GCMP ciphers */ + if ((pairwise && key_idx) || + params->mode != NL80211_KEY_RX_TX) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + /* IEEE802.11-2016 allows only 0 and - when supporting + * Extended Key ID - 1 as index for pairwise keys. + * @NL80211_KEY_NO_TX is only allowed for pairwise keys when + * the driver supports Extended Key ID. + * @NL80211_KEY_SET_TX can't be set when installing and + * validating a key. 
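+ * In other words: pairwise keys normally only use key index 0, and may + * also use index 1 when the driver supports Extended Key ID.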
+ */ + if ((params->mode == NL80211_KEY_NO_TX && !pairwise) || + params->mode == NL80211_KEY_SET_TX) + return -EINVAL; + if (wiphy_ext_feature_isset(&rdev->wiphy, + NL80211_EXT_FEATURE_EXT_KEY_ID)) { + if (pairwise && (key_idx < 0 || key_idx > 1)) + return -EINVAL; + } else if (pairwise && key_idx) { + return -EINVAL; + } + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + /* Disallow BIP (group-only) cipher as pairwise cipher */ + if (pairwise) + return -EINVAL; + if (key_idx < 4) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + if (key_idx > 3) + return -EINVAL; + break; + default: + break; + } + + switch (params->cipher) { + case WLAN_CIPHER_SUITE_WEP40: + if (params->key_len != WLAN_KEY_LEN_WEP40) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_TKIP: + if (params->key_len != WLAN_KEY_LEN_TKIP) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_CCMP: + if (params->key_len != WLAN_KEY_LEN_CCMP) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_CCMP_256: + if (params->key_len != WLAN_KEY_LEN_CCMP_256) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_GCMP: + if (params->key_len != WLAN_KEY_LEN_GCMP) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_GCMP_256: + if (params->key_len != WLAN_KEY_LEN_GCMP_256) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_WEP104: + if (params->key_len != WLAN_KEY_LEN_WEP104) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_AES_CMAC: + if (params->key_len != WLAN_KEY_LEN_AES_CMAC) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + if (params->key_len != WLAN_KEY_LEN_BIP_CMAC_256) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + if (params->key_len != WLAN_KEY_LEN_BIP_GMAC_128) + return -EINVAL; + break; + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + if (params->key_len != WLAN_KEY_LEN_BIP_GMAC_256) + return -EINVAL; + break; + default: + /* + * We don't know anything about this algorithm, + * allow using it -- but the driver must check + * all parameters! We still check below whether + * or not the driver supports this algorithm, + * of course. + */ + break; + } + + if (params->seq) { + switch (params->cipher) { + case WLAN_CIPHER_SUITE_WEP40: + case WLAN_CIPHER_SUITE_WEP104: + /* These ciphers do not use key sequence */ + return -EINVAL; + case WLAN_CIPHER_SUITE_TKIP: + case WLAN_CIPHER_SUITE_CCMP: + case WLAN_CIPHER_SUITE_CCMP_256: + case WLAN_CIPHER_SUITE_GCMP: + case WLAN_CIPHER_SUITE_GCMP_256: + case WLAN_CIPHER_SUITE_AES_CMAC: + case WLAN_CIPHER_SUITE_BIP_CMAC_256: + case WLAN_CIPHER_SUITE_BIP_GMAC_128: + case WLAN_CIPHER_SUITE_BIP_GMAC_256: + if (params->seq_len != 6) + return -EINVAL; + break; + } + } + + if (!cfg80211_supported_cipher_suite(&rdev->wiphy, params->cipher)) + return -EINVAL; + + return 0; +} + +unsigned int __attribute_const__ ieee80211_hdrlen(__le16 fc) +{ + unsigned int hdrlen = 24; + + if (ieee80211_is_ext(fc)) { + hdrlen = 4; + goto out; + } + + if (ieee80211_is_data(fc)) { + if (ieee80211_has_a4(fc)) + hdrlen = 30; + if (ieee80211_is_data_qos(fc)) { + hdrlen += IEEE80211_QOS_CTL_LEN; + if (ieee80211_has_order(fc)) + hdrlen += IEEE80211_HT_CTL_LEN; + } + goto out; + } + + if (ieee80211_is_mgmt(fc)) { + if (ieee80211_has_order(fc)) + hdrlen += IEEE80211_HT_CTL_LEN; + goto out; + } + + if (ieee80211_is_ctl(fc)) { + /* + * ACK and CTS are 10 bytes, all others 16. 
To see how + * to get this condition consider + * subtype mask: 0b0000000011110000 (0x00F0) + * ACK subtype: 0b0000000011010000 (0x00D0) + * CTS subtype: 0b0000000011000000 (0x00C0) + * bits that matter: ^^^ (0x00E0) + * value of those: 0b0000000011000000 (0x00C0) + */ + if ((fc & cpu_to_le16(0x00E0)) == cpu_to_le16(0x00C0)) + hdrlen = 10; + else + hdrlen = 16; + } +out: + return hdrlen; +} +EXPORT_SYMBOL(ieee80211_hdrlen); + +unsigned int ieee80211_get_hdrlen_from_skb(const struct sk_buff *skb) +{ + const struct ieee80211_hdr *hdr = + (const struct ieee80211_hdr *)skb->data; + unsigned int hdrlen; + + if (unlikely(skb->len < 10)) + return 0; + hdrlen = ieee80211_hdrlen(hdr->frame_control); + if (unlikely(hdrlen > skb->len)) + return 0; + return hdrlen; +} +EXPORT_SYMBOL(ieee80211_get_hdrlen_from_skb); + +static unsigned int __ieee80211_get_mesh_hdrlen(u8 flags) +{ + int ae = flags & MESH_FLAGS_AE; + /* 802.11-2012, 8.2.4.7.3 */ + switch (ae) { + default: + case 0: + return 6; + case MESH_FLAGS_AE_A4: + return 12; + case MESH_FLAGS_AE_A5_A6: + return 18; + } +} + +unsigned int ieee80211_get_mesh_hdrlen(struct ieee80211s_hdr *meshhdr) +{ + return __ieee80211_get_mesh_hdrlen(meshhdr->flags); +} +EXPORT_SYMBOL(ieee80211_get_mesh_hdrlen); + +int ieee80211_data_to_8023_exthdr(struct sk_buff *skb, struct ethhdr *ehdr, + const u8 *addr, enum nl80211_iftype iftype, + u8 data_offset, bool is_amsdu) +{ + struct ieee80211_hdr *hdr = (struct ieee80211_hdr *) skb->data; + struct { + u8 hdr[ETH_ALEN] __aligned(2); + __be16 proto; + } payload; + struct ethhdr tmp; + u16 hdrlen; + u8 mesh_flags = 0; + + if (unlikely(!ieee80211_is_data_present(hdr->frame_control))) + return -1; + + hdrlen = ieee80211_hdrlen(hdr->frame_control) + data_offset; + if (skb->len < hdrlen) + return -1; + + /* convert IEEE 802.11 header + possible LLC headers into Ethernet + * header + * IEEE 802.11 address fields: + * ToDS FromDS Addr1 Addr2 Addr3 Addr4 + * 0 0 DA SA BSSID n/a + * 0 1 DA BSSID SA n/a + * 1 0 BSSID SA DA n/a + * 1 1 RA TA DA SA + */ + memcpy(tmp.h_dest, ieee80211_get_DA(hdr), ETH_ALEN); + memcpy(tmp.h_source, ieee80211_get_SA(hdr), ETH_ALEN); + + if (iftype == NL80211_IFTYPE_MESH_POINT && + skb_copy_bits(skb, hdrlen, &mesh_flags, 1) < 0) + return -1; + + mesh_flags &= MESH_FLAGS_AE; + + switch (hdr->frame_control & + cpu_to_le16(IEEE80211_FCTL_TODS | IEEE80211_FCTL_FROMDS)) { + case cpu_to_le16(IEEE80211_FCTL_TODS): + if (unlikely(iftype != NL80211_IFTYPE_AP && + iftype != NL80211_IFTYPE_AP_VLAN && + iftype != NL80211_IFTYPE_P2P_GO)) + return -1; + break; + case cpu_to_le16(IEEE80211_FCTL_TODS | IEEE80211_FCTL_FROMDS): + if (unlikely(iftype != NL80211_IFTYPE_MESH_POINT && + iftype != NL80211_IFTYPE_AP_VLAN && + iftype != NL80211_IFTYPE_STATION)) + return -1; + if (iftype == NL80211_IFTYPE_MESH_POINT) { + if (mesh_flags == MESH_FLAGS_AE_A4) + return -1; + if (mesh_flags == MESH_FLAGS_AE_A5_A6 && + skb_copy_bits(skb, hdrlen + + offsetof(struct ieee80211s_hdr, eaddr1), + tmp.h_dest, 2 * ETH_ALEN) < 0) + return -1; + + hdrlen += __ieee80211_get_mesh_hdrlen(mesh_flags); + } + break; + case cpu_to_le16(IEEE80211_FCTL_FROMDS): + if ((iftype != NL80211_IFTYPE_STATION && + iftype != NL80211_IFTYPE_P2P_CLIENT && + iftype != NL80211_IFTYPE_MESH_POINT) || + (is_multicast_ether_addr(tmp.h_dest) && + ether_addr_equal(tmp.h_source, addr))) + return -1; + if (iftype == NL80211_IFTYPE_MESH_POINT) { + if (mesh_flags == MESH_FLAGS_AE_A5_A6) + return -1; + if (mesh_flags == MESH_FLAGS_AE_A4 && + skb_copy_bits(skb, hdrlen + + 
offsetof(struct ieee80211s_hdr, eaddr1), + tmp.h_source, ETH_ALEN) < 0) + return -1; + hdrlen += __ieee80211_get_mesh_hdrlen(mesh_flags); + } + break; + case cpu_to_le16(0): + if (iftype != NL80211_IFTYPE_ADHOC && + iftype != NL80211_IFTYPE_STATION && + iftype != NL80211_IFTYPE_OCB) + return -1; + break; + } + + if (likely(skb_copy_bits(skb, hdrlen, &payload, sizeof(payload)) == 0 && + ((!is_amsdu && ether_addr_equal(payload.hdr, rfc1042_header) && + payload.proto != htons(ETH_P_AARP) && + payload.proto != htons(ETH_P_IPX)) || + ether_addr_equal(payload.hdr, bridge_tunnel_header)))) { + /* remove RFC1042 or Bridge-Tunnel encapsulation and + * replace EtherType */ + hdrlen += ETH_ALEN + 2; + tmp.h_proto = payload.proto; + skb_postpull_rcsum(skb, &payload, ETH_ALEN + 2); + } else { + tmp.h_proto = htons(skb->len - hdrlen); + } + + pskb_pull(skb, hdrlen); + + if (!ehdr) + ehdr = skb_push(skb, sizeof(struct ethhdr)); + memcpy(ehdr, &tmp, sizeof(tmp)); + + return 0; +} +EXPORT_SYMBOL(ieee80211_data_to_8023_exthdr); + +static void +__frame_add_frag(struct sk_buff *skb, struct page *page, + void *ptr, int len, int size) +{ + struct skb_shared_info *sh = skb_shinfo(skb); + int page_offset; + + get_page(page); + page_offset = ptr - page_address(page); + skb_add_rx_frag(skb, sh->nr_frags, page, page_offset, len, size); +} + +static void +__ieee80211_amsdu_copy_frag(struct sk_buff *skb, struct sk_buff *frame, + int offset, int len) +{ + struct skb_shared_info *sh = skb_shinfo(skb); + const skb_frag_t *frag = &sh->frags[0]; + struct page *frag_page; + void *frag_ptr; + int frag_len, frag_size; + int head_size = skb->len - skb->data_len; + int cur_len; + + frag_page = virt_to_head_page(skb->head); + frag_ptr = skb->data; + frag_size = head_size; + + while (offset >= frag_size) { + offset -= frag_size; + frag_page = skb_frag_page(frag); + frag_ptr = skb_frag_address(frag); + frag_size = skb_frag_size(frag); + frag++; + } + + frag_ptr += offset; + frag_len = frag_size - offset; + + cur_len = min(len, frag_len); + + __frame_add_frag(frame, frag_page, frag_ptr, cur_len, frag_size); + len -= cur_len; + + while (len > 0) { + frag_len = skb_frag_size(frag); + cur_len = min(len, frag_len); + __frame_add_frag(frame, skb_frag_page(frag), + skb_frag_address(frag), cur_len, frag_len); + len -= cur_len; + frag++; + } +} + +static struct sk_buff * +__ieee80211_amsdu_copy(struct sk_buff *skb, unsigned int hlen, + int offset, int len, bool reuse_frag) +{ + struct sk_buff *frame; + int cur_len = len; + + if (skb->len - offset < len) + return NULL; + + /* + * When reusing framents, copy some data to the head to simplify + * ethernet header handling and speed up protocol header processing + * in the stack later. + */ + if (reuse_frag) + cur_len = min_t(int, len, 32); + + /* + * Allocate and reserve two bytes more for payload + * alignment since sizeof(struct ethhdr) is 14. 
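+ * (14 + 2 = 16, so the payload behind the ethernet header pushed later + * ends up 4-byte aligned)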
+ */ + frame = dev_alloc_skb(hlen + sizeof(struct ethhdr) + 2 + cur_len); + if (!frame) + return NULL; + + skb_reserve(frame, hlen + sizeof(struct ethhdr) + 2); + skb_copy_bits(skb, offset, skb_put(frame, cur_len), cur_len); + + len -= cur_len; + if (!len) + return frame; + + offset += cur_len; + __ieee80211_amsdu_copy_frag(skb, frame, offset, len); + + return frame; +} + +void ieee80211_amsdu_to_8023s(struct sk_buff *skb, struct sk_buff_head *list, + const u8 *addr, enum nl80211_iftype iftype, + const unsigned int extra_headroom, + const u8 *check_da, const u8 *check_sa) +{ + unsigned int hlen = ALIGN(extra_headroom, 4); + struct sk_buff *frame = NULL; + u16 ethertype; + u8 *payload; + int offset = 0, remaining; + struct ethhdr eth; + bool reuse_frag = skb->head_frag && !skb_has_frag_list(skb); + bool reuse_skb = false; + bool last = false; + + while (!last) { + unsigned int subframe_len; + int len; + u8 padding; + + skb_copy_bits(skb, offset, &eth, sizeof(eth)); + len = ntohs(eth.h_proto); + subframe_len = sizeof(struct ethhdr) + len; + padding = (4 - subframe_len) & 0x3; + + /* the last MSDU has no padding */ + remaining = skb->len - offset; + if (subframe_len > remaining) + goto purge; + /* mitigate A-MSDU aggregation injection attacks */ + if (ether_addr_equal(eth.h_dest, rfc1042_header)) + goto purge; + + offset += sizeof(struct ethhdr); + last = remaining <= subframe_len + padding; + + /* FIXME: should we really accept multicast DA? */ + if ((check_da && !is_multicast_ether_addr(eth.h_dest) && + !ether_addr_equal(check_da, eth.h_dest)) || + (check_sa && !ether_addr_equal(check_sa, eth.h_source))) { + offset += len + padding; + continue; + } + + /* reuse skb for the last subframe */ + if (!skb_is_nonlinear(skb) && !reuse_frag && last) { + skb_pull(skb, offset); + frame = skb; + reuse_skb = true; + } else { + frame = __ieee80211_amsdu_copy(skb, hlen, offset, len, + reuse_frag); + if (!frame) + goto purge; + + offset += len + padding; + } + + skb_reset_network_header(frame); + frame->dev = skb->dev; + frame->priority = skb->priority; + + payload = frame->data; + ethertype = (payload[6] << 8) | payload[7]; + if (likely((ether_addr_equal(payload, rfc1042_header) && + ethertype != ETH_P_AARP && ethertype != ETH_P_IPX) || + ether_addr_equal(payload, bridge_tunnel_header))) { + eth.h_proto = htons(ethertype); + skb_pull(frame, ETH_ALEN + 2); + } + + memcpy(skb_push(frame, sizeof(eth)), &eth, sizeof(eth)); + __skb_queue_tail(list, frame); + } + + if (!reuse_skb) + dev_kfree_skb(skb); + + return; + + purge: + __skb_queue_purge(list); + dev_kfree_skb(skb); +} +EXPORT_SYMBOL(ieee80211_amsdu_to_8023s); + +/* Given a data frame determine the 802.1p/1d tag to use. */ +unsigned int cfg80211_classify8021d(struct sk_buff *skb, + struct cfg80211_qos_map *qos_map) +{ + unsigned int dscp; + unsigned char vlan_priority; + unsigned int ret; + + /* skb->priority values from 256->263 are magic values to + * directly indicate a specific 802.1d priority. This is used + * to allow 802.1d priority to be passed directly in from VLAN + * tags, etc. 
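+ * For example, skb->priority == 261 is passed through below as 802.1d + * UP 5 (261 - 256).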
+ */ + if (skb->priority >= 256 && skb->priority <= 263) { + ret = skb->priority - 256; + goto out; + } + + if (skb_vlan_tag_present(skb)) { + vlan_priority = (skb_vlan_tag_get(skb) & VLAN_PRIO_MASK) + >> VLAN_PRIO_SHIFT; + if (vlan_priority > 0) { + ret = vlan_priority; + goto out; + } + } + + switch (skb->protocol) { + case htons(ETH_P_IP): + dscp = ipv4_get_dsfield(ip_hdr(skb)) & 0xfc; + break; + case htons(ETH_P_IPV6): + dscp = ipv6_get_dsfield(ipv6_hdr(skb)) & 0xfc; + break; + case htons(ETH_P_MPLS_UC): + case htons(ETH_P_MPLS_MC): { + struct mpls_label mpls_tmp, *mpls; + + mpls = skb_header_pointer(skb, sizeof(struct ethhdr), + sizeof(*mpls), &mpls_tmp); + if (!mpls) + return 0; + + ret = (ntohl(mpls->entry) & MPLS_LS_TC_MASK) + >> MPLS_LS_TC_SHIFT; + goto out; + } + case htons(ETH_P_80221): + /* 802.21 is always network control traffic */ + return 7; + default: + return 0; + } + + if (qos_map) { + unsigned int i, tmp_dscp = dscp >> 2; + + for (i = 0; i < qos_map->num_des; i++) { + if (tmp_dscp == qos_map->dscp_exception[i].dscp) { + ret = qos_map->dscp_exception[i].up; + goto out; + } + } + + for (i = 0; i < 8; i++) { + if (tmp_dscp >= qos_map->up[i].low && + tmp_dscp <= qos_map->up[i].high) { + ret = i; + goto out; + } + } + } + + ret = dscp >> 5; +out: + return array_index_nospec(ret, IEEE80211_NUM_TIDS); +} +EXPORT_SYMBOL(cfg80211_classify8021d); + +const struct element *ieee80211_bss_get_elem(struct cfg80211_bss *bss, u8 id) +{ + const struct cfg80211_bss_ies *ies; + + ies = rcu_dereference(bss->ies); + if (!ies) + return NULL; + + return cfg80211_find_elem(id, ies->data, ies->len); +} +EXPORT_SYMBOL(ieee80211_bss_get_elem); + +void cfg80211_upload_connect_keys(struct wireless_dev *wdev) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct net_device *dev = wdev->netdev; + int i; + + if (!wdev->connect_keys) + return; + + for (i = 0; i < CFG80211_MAX_WEP_KEYS; i++) { + if (!wdev->connect_keys->params[i].cipher) + continue; + if (rdev_add_key(rdev, dev, -1, i, false, NULL, + &wdev->connect_keys->params[i])) { + netdev_err(dev, "failed to set key %d\n", i); + continue; + } + if (wdev->connect_keys->def == i && + rdev_set_default_key(rdev, dev, -1, i, true, true)) { + netdev_err(dev, "failed to set defkey %d\n", i); + continue; + } + } + + kfree_sensitive(wdev->connect_keys); + wdev->connect_keys = NULL; +} + +void cfg80211_process_wdev_events(struct wireless_dev *wdev) +{ + struct cfg80211_event *ev; + unsigned long flags; + + spin_lock_irqsave(&wdev->event_lock, flags); + while (!list_empty(&wdev->event_list)) { + ev = list_first_entry(&wdev->event_list, + struct cfg80211_event, list); + list_del(&ev->list); + spin_unlock_irqrestore(&wdev->event_lock, flags); + + wdev_lock(wdev); + switch (ev->type) { + case EVENT_CONNECT_RESULT: + __cfg80211_connect_result( + wdev->netdev, + &ev->cr, + ev->cr.status == WLAN_STATUS_SUCCESS); + break; + case EVENT_ROAMED: + __cfg80211_roamed(wdev, &ev->rm); + break; + case EVENT_DISCONNECTED: + __cfg80211_disconnected(wdev->netdev, + ev->dc.ie, ev->dc.ie_len, + ev->dc.reason, + !ev->dc.locally_generated); + break; + case EVENT_IBSS_JOINED: + __cfg80211_ibss_joined(wdev->netdev, ev->ij.bssid, + ev->ij.channel); + break; + case EVENT_STOPPED: + __cfg80211_leave(wiphy_to_rdev(wdev->wiphy), wdev); + break; + case EVENT_PORT_AUTHORIZED: + __cfg80211_port_authorized(wdev, ev->pa.bssid); + break; + } + wdev_unlock(wdev); + + kfree(ev); + + spin_lock_irqsave(&wdev->event_lock, flags); + } + 
spin_unlock_irqrestore(&wdev->event_lock, flags); +} + +void cfg80211_process_rdev_events(struct cfg80211_registered_device *rdev) +{ + struct wireless_dev *wdev; + + lockdep_assert_held(&rdev->wiphy.mtx); + + list_for_each_entry(wdev, &rdev->wiphy.wdev_list, list) + cfg80211_process_wdev_events(wdev); +} + +int cfg80211_change_iface(struct cfg80211_registered_device *rdev, + struct net_device *dev, enum nl80211_iftype ntype, + struct vif_params *params) +{ + int err; + enum nl80211_iftype otype = dev->ieee80211_ptr->iftype; + + lockdep_assert_held(&rdev->wiphy.mtx); + + /* don't support changing VLANs, you just re-create them */ + if (otype == NL80211_IFTYPE_AP_VLAN) + return -EOPNOTSUPP; + + /* cannot change into P2P device or NAN */ + if (ntype == NL80211_IFTYPE_P2P_DEVICE || + ntype == NL80211_IFTYPE_NAN) + return -EOPNOTSUPP; + + if (!rdev->ops->change_virtual_intf || + !(rdev->wiphy.interface_modes & (1 << ntype))) + return -EOPNOTSUPP; + + if (ntype != otype) { + /* if it's part of a bridge, reject changing type to station/ibss */ + if (netif_is_bridge_port(dev) && + (ntype == NL80211_IFTYPE_ADHOC || + ntype == NL80211_IFTYPE_STATION || + ntype == NL80211_IFTYPE_P2P_CLIENT)) + return -EBUSY; + + dev->ieee80211_ptr->use_4addr = false; + wdev_lock(dev->ieee80211_ptr); + rdev_set_qos_map(rdev, dev, NULL); + wdev_unlock(dev->ieee80211_ptr); + + switch (otype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + cfg80211_stop_ap(rdev, dev, -1, true); + break; + case NL80211_IFTYPE_ADHOC: + cfg80211_leave_ibss(rdev, dev, false); + break; + case NL80211_IFTYPE_STATION: + case NL80211_IFTYPE_P2P_CLIENT: + wdev_lock(dev->ieee80211_ptr); + cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, true); + wdev_unlock(dev->ieee80211_ptr); + break; + case NL80211_IFTYPE_MESH_POINT: + /* mesh should be handled? 
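+ * presumably by leaving the mesh, as the IBSS and OCB cases do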
*/ + break; + case NL80211_IFTYPE_OCB: + cfg80211_leave_ocb(rdev, dev); + break; + default: + break; + } + + cfg80211_process_rdev_events(rdev); + cfg80211_mlme_purge_registrations(dev->ieee80211_ptr); + + memset(&dev->ieee80211_ptr->u, 0, + sizeof(dev->ieee80211_ptr->u)); + memset(&dev->ieee80211_ptr->links, 0, + sizeof(dev->ieee80211_ptr->links)); + } + + err = rdev_change_virtual_intf(rdev, dev, ntype, params); + + WARN_ON(!err && dev->ieee80211_ptr->iftype != ntype); + + if (!err && params && params->use_4addr != -1) + dev->ieee80211_ptr->use_4addr = params->use_4addr; + + if (!err) { + dev->priv_flags &= ~IFF_DONT_BRIDGE; + switch (ntype) { + case NL80211_IFTYPE_STATION: + if (dev->ieee80211_ptr->use_4addr) + break; + fallthrough; + case NL80211_IFTYPE_OCB: + case NL80211_IFTYPE_P2P_CLIENT: + case NL80211_IFTYPE_ADHOC: + dev->priv_flags |= IFF_DONT_BRIDGE; + break; + case NL80211_IFTYPE_P2P_GO: + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_AP_VLAN: + case NL80211_IFTYPE_MESH_POINT: + /* bridging OK */ + break; + case NL80211_IFTYPE_MONITOR: + /* monitor can't bridge anyway */ + break; + case NL80211_IFTYPE_UNSPECIFIED: + case NUM_NL80211_IFTYPES: + /* not happening */ + break; + case NL80211_IFTYPE_P2P_DEVICE: + case NL80211_IFTYPE_WDS: + case NL80211_IFTYPE_NAN: + WARN_ON(1); + break; + } + } + + if (!err && ntype != otype && netif_running(dev)) { + cfg80211_update_iface_num(rdev, ntype, 1); + cfg80211_update_iface_num(rdev, otype, -1); + } + + return err; +} + +static u32 cfg80211_calculate_bitrate_ht(struct rate_info *rate) +{ + int modulation, streams, bitrate; + + /* the formula below does only work for MCS values smaller than 32 */ + if (WARN_ON_ONCE(rate->mcs >= 32)) + return 0; + + modulation = rate->mcs & 7; + streams = (rate->mcs >> 3) + 1; + + bitrate = (rate->bw == RATE_INFO_BW_40) ? 
13500000 : 6500000; + + if (modulation < 4) + bitrate *= (modulation + 1); + else if (modulation == 4) + bitrate *= (modulation + 2); + else + bitrate *= (modulation + 3); + + bitrate *= streams; + + if (rate->flags & RATE_INFO_FLAGS_SHORT_GI) + bitrate = (bitrate / 9) * 10; + + /* do NOT round down here */ + return (bitrate + 50000) / 100000; +} + +static u32 cfg80211_calculate_bitrate_dmg(struct rate_info *rate) +{ + static const u32 __mcs2bitrate[] = { + /* control PHY */ + [0] = 275, + /* SC PHY */ + [1] = 3850, + [2] = 7700, + [3] = 9625, + [4] = 11550, + [5] = 12512, /* 1251.25 mbps */ + [6] = 15400, + [7] = 19250, + [8] = 23100, + [9] = 25025, + [10] = 30800, + [11] = 38500, + [12] = 46200, + /* OFDM PHY */ + [13] = 6930, + [14] = 8662, /* 866.25 mbps */ + [15] = 13860, + [16] = 17325, + [17] = 20790, + [18] = 27720, + [19] = 34650, + [20] = 41580, + [21] = 45045, + [22] = 51975, + [23] = 62370, + [24] = 67568, /* 6756.75 mbps */ + /* LP-SC PHY */ + [25] = 6260, + [26] = 8340, + [27] = 11120, + [28] = 12510, + [29] = 16680, + [30] = 22240, + [31] = 25030, + }; + + if (WARN_ON_ONCE(rate->mcs >= ARRAY_SIZE(__mcs2bitrate))) + return 0; + + return __mcs2bitrate[rate->mcs]; +} + +static u32 cfg80211_calculate_bitrate_extended_sc_dmg(struct rate_info *rate) +{ + static const u32 __mcs2bitrate[] = { + [6 - 6] = 26950, /* MCS 9.1 : 2695.0 mbps */ + [7 - 6] = 50050, /* MCS 12.1 */ + [8 - 6] = 53900, + [9 - 6] = 57750, + [10 - 6] = 63900, + [11 - 6] = 75075, + [12 - 6] = 80850, + }; + + /* Extended SC MCS not defined for base MCS below 6 or above 12 */ + if (WARN_ON_ONCE(rate->mcs < 6 || rate->mcs > 12)) + return 0; + + return __mcs2bitrate[rate->mcs - 6]; +} + +static u32 cfg80211_calculate_bitrate_edmg(struct rate_info *rate) +{ + static const u32 __mcs2bitrate[] = { + /* control PHY */ + [0] = 275, + /* SC PHY */ + [1] = 3850, + [2] = 7700, + [3] = 9625, + [4] = 11550, + [5] = 12512, /* 1251.25 mbps */ + [6] = 13475, + [7] = 15400, + [8] = 19250, + [9] = 23100, + [10] = 25025, + [11] = 26950, + [12] = 30800, + [13] = 38500, + [14] = 46200, + [15] = 50050, + [16] = 53900, + [17] = 57750, + [18] = 69300, + [19] = 75075, + [20] = 80850, + }; + + if (WARN_ON_ONCE(rate->mcs >= ARRAY_SIZE(__mcs2bitrate))) + return 0; + + return __mcs2bitrate[rate->mcs] * rate->n_bonded_ch; +} + +static u32 cfg80211_calculate_bitrate_vht(struct rate_info *rate) +{ + static const u32 base[4][12] = { + { 6500000, + 13000000, + 19500000, + 26000000, + 39000000, + 52000000, + 58500000, + 65000000, + 78000000, + /* not in the spec, but some devices use this: */ + 86700000, + 97500000, + 108300000, + }, + { 13500000, + 27000000, + 40500000, + 54000000, + 81000000, + 108000000, + 121500000, + 135000000, + 162000000, + 180000000, + 202500000, + 225000000, + }, + { 29300000, + 58500000, + 87800000, + 117000000, + 175500000, + 234000000, + 263300000, + 292500000, + 351000000, + 390000000, + 438800000, + 487500000, + }, + { 58500000, + 117000000, + 175500000, + 234000000, + 351000000, + 468000000, + 526500000, + 585000000, + 702000000, + 780000000, + 877500000, + 975000000, + }, + }; + u32 bitrate; + int idx; + + if (rate->mcs > 11) + goto warn; + + switch (rate->bw) { + case RATE_INFO_BW_160: + idx = 3; + break; + case RATE_INFO_BW_80: + idx = 2; + break; + case RATE_INFO_BW_40: + idx = 1; + break; + case RATE_INFO_BW_5: + case RATE_INFO_BW_10: + default: + goto warn; + case RATE_INFO_BW_20: + idx = 0; + } + + bitrate = base[idx][rate->mcs]; + bitrate *= rate->nss; + + if (rate->flags & RATE_INFO_FLAGS_SHORT_GI) + bitrate 
= (bitrate / 9) * 10; + + /* do NOT round down here */ + return (bitrate + 50000) / 100000; + warn: + WARN_ONCE(1, "invalid rate bw=%d, mcs=%d, nss=%d\n", + rate->bw, rate->mcs, rate->nss); + return 0; +} + +static u32 cfg80211_calculate_bitrate_he(struct rate_info *rate) +{ +#define SCALE 6144 + u32 mcs_divisors[14] = { + 102399, /* 16.666666... */ + 51201, /* 8.333333... */ + 34134, /* 5.555555... */ + 25599, /* 4.166666... */ + 17067, /* 2.777777... */ + 12801, /* 2.083333... */ + 11377, /* 1.851725... */ + 10239, /* 1.666666... */ + 8532, /* 1.388888... */ + 7680, /* 1.250000... */ + 6828, /* 1.111111... */ + 6144, /* 1.000000... */ + 5690, /* 0.926106... */ + 5120, /* 0.833333... */ + }; + u32 rates_160M[3] = { 960777777, 907400000, 816666666 }; + u32 rates_969[3] = { 480388888, 453700000, 408333333 }; + u32 rates_484[3] = { 229411111, 216666666, 195000000 }; + u32 rates_242[3] = { 114711111, 108333333, 97500000 }; + u32 rates_106[3] = { 40000000, 37777777, 34000000 }; + u32 rates_52[3] = { 18820000, 17777777, 16000000 }; + u32 rates_26[3] = { 9411111, 8888888, 8000000 }; + u64 tmp; + u32 result; + + if (WARN_ON_ONCE(rate->mcs > 13)) + return 0; + + if (WARN_ON_ONCE(rate->he_gi > NL80211_RATE_INFO_HE_GI_3_2)) + return 0; + if (WARN_ON_ONCE(rate->he_ru_alloc > + NL80211_RATE_INFO_HE_RU_ALLOC_2x996)) + return 0; + if (WARN_ON_ONCE(rate->nss < 1 || rate->nss > 8)) + return 0; + + if (rate->bw == RATE_INFO_BW_160) + result = rates_160M[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_80 || + (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_996)) + result = rates_969[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_40 || + (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_484)) + result = rates_484[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_20 || + (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_242)) + result = rates_242[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_106) + result = rates_106[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_52) + result = rates_52[rate->he_gi]; + else if (rate->bw == RATE_INFO_BW_HE_RU && + rate->he_ru_alloc == NL80211_RATE_INFO_HE_RU_ALLOC_26) + result = rates_26[rate->he_gi]; + else { + WARN(1, "invalid HE MCS: bw:%d, ru:%d\n", + rate->bw, rate->he_ru_alloc); + return 0; + } + + /* now scale to the appropriate MCS */ + tmp = result; + tmp *= SCALE; + do_div(tmp, mcs_divisors[rate->mcs]); + result = tmp; + + /* and take NSS, DCM into account */ + result = (result * rate->nss) / 8; + if (rate->he_dcm) + result /= 2; + + return result / 10000; +} + +static u32 cfg80211_calculate_bitrate_eht(struct rate_info *rate) +{ +#define SCALE 6144 + static const u32 mcs_divisors[16] = { + 102399, /* 16.666666... */ + 51201, /* 8.333333... */ + 34134, /* 5.555555... */ + 25599, /* 4.166666... */ + 17067, /* 2.777777... */ + 12801, /* 2.083333... */ + 11377, /* 1.851725... */ + 10239, /* 1.666666... */ + 8532, /* 1.388888... */ + 7680, /* 1.250000... */ + 6828, /* 1.111111... */ + 6144, /* 1.000000... */ + 5690, /* 0.926106... */ + 5120, /* 0.833333... */ + 409600, /* 66.666666... */ + 204800, /* 33.333333... 
*/ + }; + static const u32 rates_996[3] = { 480388888, 453700000, 408333333 }; + static const u32 rates_484[3] = { 229411111, 216666666, 195000000 }; + static const u32 rates_242[3] = { 114711111, 108333333, 97500000 }; + static const u32 rates_106[3] = { 40000000, 37777777, 34000000 }; + static const u32 rates_52[3] = { 18820000, 17777777, 16000000 }; + static const u32 rates_26[3] = { 9411111, 8888888, 8000000 }; + u64 tmp; + u32 result; + + if (WARN_ON_ONCE(rate->mcs > 15)) + return 0; + if (WARN_ON_ONCE(rate->eht_gi > NL80211_RATE_INFO_EHT_GI_3_2)) + return 0; + if (WARN_ON_ONCE(rate->eht_ru_alloc > + NL80211_RATE_INFO_EHT_RU_ALLOC_4x996)) + return 0; + if (WARN_ON_ONCE(rate->nss < 1 || rate->nss > 8)) + return 0; + + /* Bandwidth checks for MCS 14 */ + if (rate->mcs == 14) { + if ((rate->bw != RATE_INFO_BW_EHT_RU && + rate->bw != RATE_INFO_BW_80 && + rate->bw != RATE_INFO_BW_160 && + rate->bw != RATE_INFO_BW_320) || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc != NL80211_RATE_INFO_EHT_RU_ALLOC_996 && + rate->eht_ru_alloc != NL80211_RATE_INFO_EHT_RU_ALLOC_2x996 && + rate->eht_ru_alloc != NL80211_RATE_INFO_EHT_RU_ALLOC_4x996)) { + WARN(1, "invalid EHT BW for MCS 14: bw:%d, ru:%d\n", + rate->bw, rate->eht_ru_alloc); + return 0; + } + } + + if (rate->bw == RATE_INFO_BW_320 || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_4x996)) + result = 4 * rates_996[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_3x996P484) + result = 3 * rates_996[rate->eht_gi] + rates_484[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_3x996) + result = 3 * rates_996[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_2x996P484) + result = 2 * rates_996[rate->eht_gi] + rates_484[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_160 || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_2x996)) + result = 2 * rates_996[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == + NL80211_RATE_INFO_EHT_RU_ALLOC_996P484P242) + result = rates_996[rate->eht_gi] + rates_484[rate->eht_gi] + + rates_242[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_996P484) + result = rates_996[rate->eht_gi] + rates_484[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_80 || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_996)) + result = rates_996[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_484P242) + result = rates_484[rate->eht_gi] + rates_242[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_40 || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_484)) + result = rates_484[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_20 || + (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_242)) + result = rates_242[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_106P26) + result = rates_106[rate->eht_gi] + rates_26[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_106) + result = rates_106[rate->eht_gi]; + else if (rate->bw 
== RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_52P26) + result = rates_52[rate->eht_gi] + rates_26[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_52) + result = rates_52[rate->eht_gi]; + else if (rate->bw == RATE_INFO_BW_EHT_RU && + rate->eht_ru_alloc == NL80211_RATE_INFO_EHT_RU_ALLOC_26) + result = rates_26[rate->eht_gi]; + else { + WARN(1, "invalid EHT MCS: bw:%d, ru:%d\n", + rate->bw, rate->eht_ru_alloc); + return 0; + } + + /* now scale to the appropriate MCS */ + tmp = result; + tmp *= SCALE; + do_div(tmp, mcs_divisors[rate->mcs]); + + /* and take NSS */ + tmp *= rate->nss; + do_div(tmp, 8); + + result = tmp; + + return result / 10000; +} + +u32 cfg80211_calculate_bitrate(struct rate_info *rate) +{ + if (rate->flags & RATE_INFO_FLAGS_MCS) + return cfg80211_calculate_bitrate_ht(rate); + if (rate->flags & RATE_INFO_FLAGS_DMG) + return cfg80211_calculate_bitrate_dmg(rate); + if (rate->flags & RATE_INFO_FLAGS_EXTENDED_SC_DMG) + return cfg80211_calculate_bitrate_extended_sc_dmg(rate); + if (rate->flags & RATE_INFO_FLAGS_EDMG) + return cfg80211_calculate_bitrate_edmg(rate); + if (rate->flags & RATE_INFO_FLAGS_VHT_MCS) + return cfg80211_calculate_bitrate_vht(rate); + if (rate->flags & RATE_INFO_FLAGS_HE_MCS) + return cfg80211_calculate_bitrate_he(rate); + if (rate->flags & RATE_INFO_FLAGS_EHT_MCS) + return cfg80211_calculate_bitrate_eht(rate); + + return rate->legacy; +} +EXPORT_SYMBOL(cfg80211_calculate_bitrate); + +int cfg80211_get_p2p_attr(const u8 *ies, unsigned int len, + enum ieee80211_p2p_attr_id attr, + u8 *buf, unsigned int bufsize) +{ + u8 *out = buf; + u16 attr_remaining = 0; + bool desired_attr = false; + u16 desired_len = 0; + + while (len > 0) { + unsigned int iedatalen; + unsigned int copy; + const u8 *iedata; + + if (len < 2) + return -EILSEQ; + iedatalen = ies[1]; + if (iedatalen + 2 > len) + return -EILSEQ; + + if (ies[0] != WLAN_EID_VENDOR_SPECIFIC) + goto cont; + + if (iedatalen < 4) + goto cont; + + iedata = ies + 2; + + /* check WFA OUI, P2P subtype */ + if (iedata[0] != 0x50 || iedata[1] != 0x6f || + iedata[2] != 0x9a || iedata[3] != 0x09) + goto cont; + + iedatalen -= 4; + iedata += 4; + + /* check attribute continuation into this IE */ + copy = min_t(unsigned int, attr_remaining, iedatalen); + if (copy && desired_attr) { + desired_len += copy; + if (out) { + memcpy(out, iedata, min(bufsize, copy)); + out += min(bufsize, copy); + bufsize -= min(bufsize, copy); + } + + + if (copy == attr_remaining) + return desired_len; + } + + attr_remaining -= copy; + if (attr_remaining) + goto cont; + + iedatalen -= copy; + iedata += copy; + + while (iedatalen > 0) { + u16 attr_len; + + /* P2P attribute ID & size must fit */ + if (iedatalen < 3) + return -EILSEQ; + desired_attr = iedata[0] == attr; + attr_len = get_unaligned_le16(iedata + 1); + iedatalen -= 3; + iedata += 3; + + copy = min_t(unsigned int, attr_len, iedatalen); + + if (desired_attr) { + desired_len += copy; + if (out) { + memcpy(out, iedata, min(bufsize, copy)); + out += min(bufsize, copy); + bufsize -= min(bufsize, copy); + } + + if (copy == attr_len) + return desired_len; + } + + iedata += copy; + iedatalen -= copy; + attr_remaining = attr_len - copy; + } + + cont: + len -= ies[1] + 2; + ies += ies[1] + 2; + } + + if (attr_remaining && desired_attr) + return -EILSEQ; + + return -ENOENT; +} +EXPORT_SYMBOL(cfg80211_get_p2p_attr); + +static bool ieee80211_id_in_list(const u8 *ids, int n_ids, u8 id, bool id_ext) +{ + 
int i; + + /* Make sure array values are legal */ + if (WARN_ON(ids[n_ids - 1] == WLAN_EID_EXTENSION)) + return false; + + i = 0; + while (i < n_ids) { + if (ids[i] == WLAN_EID_EXTENSION) { + if (id_ext && (ids[i + 1] == id)) + return true; + + i += 2; + continue; + } + + if (ids[i] == id && !id_ext) + return true; + + i++; + } + return false; +} + +static size_t skip_ie(const u8 *ies, size_t ielen, size_t pos) +{ + /* we assume a validly formed IEs buffer */ + u8 len = ies[pos + 1]; + + pos += 2 + len; + + /* the IE itself must have 255 bytes for fragments to follow */ + if (len < 255) + return pos; + + while (pos < ielen && ies[pos] == WLAN_EID_FRAGMENT) { + len = ies[pos + 1]; + pos += 2 + len; + } + + return pos; +} + +size_t ieee80211_ie_split_ric(const u8 *ies, size_t ielen, + const u8 *ids, int n_ids, + const u8 *after_ric, int n_after_ric, + size_t offset) +{ + size_t pos = offset; + + while (pos < ielen) { + u8 ext = 0; + + if (ies[pos] == WLAN_EID_EXTENSION) + ext = 2; + if ((pos + ext) >= ielen) + break; + + if (!ieee80211_id_in_list(ids, n_ids, ies[pos + ext], + ies[pos] == WLAN_EID_EXTENSION)) + break; + + if (ies[pos] == WLAN_EID_RIC_DATA && n_after_ric) { + pos = skip_ie(ies, ielen, pos); + + while (pos < ielen) { + if (ies[pos] == WLAN_EID_EXTENSION) + ext = 2; + else + ext = 0; + + if ((pos + ext) >= ielen) + break; + + if (!ieee80211_id_in_list(after_ric, + n_after_ric, + ies[pos + ext], + ext == 2)) + pos = skip_ie(ies, ielen, pos); + else + break; + } + } else { + pos = skip_ie(ies, ielen, pos); + } + } + + return pos; +} +EXPORT_SYMBOL(ieee80211_ie_split_ric); + +bool ieee80211_operating_class_to_band(u8 operating_class, + enum nl80211_band *band) +{ + switch (operating_class) { + case 112: + case 115 ... 127: + case 128 ... 130: + *band = NL80211_BAND_5GHZ; + return true; + case 131 ... 
135: + *band = NL80211_BAND_6GHZ; + return true; + case 81: + case 82: + case 83: + case 84: + *band = NL80211_BAND_2GHZ; + return true; + case 180: + *band = NL80211_BAND_60GHZ; + return true; + } + + return false; +} +EXPORT_SYMBOL(ieee80211_operating_class_to_band); + +bool ieee80211_chandef_to_operating_class(struct cfg80211_chan_def *chandef, + u8 *op_class) +{ + u8 vht_opclass; + u32 freq = chandef->center_freq1; + + if (freq >= 2412 && freq <= 2472) { + if (chandef->width > NL80211_CHAN_WIDTH_40) + return false; + + /* 2.407 GHz, channels 1..13 */ + if (chandef->width == NL80211_CHAN_WIDTH_40) { + if (freq > chandef->chan->center_freq) + *op_class = 83; /* HT40+ */ + else + *op_class = 84; /* HT40- */ + } else { + *op_class = 81; + } + + return true; + } + + if (freq == 2484) { + /* channel 14 is only for IEEE 802.11b */ + if (chandef->width != NL80211_CHAN_WIDTH_20_NOHT) + return false; + + *op_class = 82; /* channel 14 */ + return true; + } + + switch (chandef->width) { + case NL80211_CHAN_WIDTH_80: + vht_opclass = 128; + break; + case NL80211_CHAN_WIDTH_160: + vht_opclass = 129; + break; + case NL80211_CHAN_WIDTH_80P80: + vht_opclass = 130; + break; + case NL80211_CHAN_WIDTH_10: + case NL80211_CHAN_WIDTH_5: + return false; /* unsupported for now */ + default: + vht_opclass = 0; + break; + } + + /* 5 GHz, channels 36..48 */ + if (freq >= 5180 && freq <= 5240) { + if (vht_opclass) { + *op_class = vht_opclass; + } else if (chandef->width == NL80211_CHAN_WIDTH_40) { + if (freq > chandef->chan->center_freq) + *op_class = 116; + else + *op_class = 117; + } else { + *op_class = 115; + } + + return true; + } + + /* 5 GHz, channels 52..64 */ + if (freq >= 5260 && freq <= 5320) { + if (vht_opclass) { + *op_class = vht_opclass; + } else if (chandef->width == NL80211_CHAN_WIDTH_40) { + if (freq > chandef->chan->center_freq) + *op_class = 119; + else + *op_class = 120; + } else { + *op_class = 118; + } + + return true; + } + + /* 5 GHz, channels 100..144 */ + if (freq >= 5500 && freq <= 5720) { + if (vht_opclass) { + *op_class = vht_opclass; + } else if (chandef->width == NL80211_CHAN_WIDTH_40) { + if (freq > chandef->chan->center_freq) + *op_class = 122; + else + *op_class = 123; + } else { + *op_class = 121; + } + + return true; + } + + /* 5 GHz, channels 149..169 */ + if (freq >= 5745 && freq <= 5845) { + if (vht_opclass) { + *op_class = vht_opclass; + } else if (chandef->width == NL80211_CHAN_WIDTH_40) { + if (freq > chandef->chan->center_freq) + *op_class = 126; + else + *op_class = 127; + } else if (freq <= 5805) { + *op_class = 124; + } else { + *op_class = 125; + } + + return true; + } + + /* 56.16 GHz, channel 1..4 */ + if (freq >= 56160 + 2160 * 1 && freq <= 56160 + 2160 * 6) { + if (chandef->width >= NL80211_CHAN_WIDTH_40) + return false; + + *op_class = 180; + return true; + } + + /* not supported yet */ + return false; +} +EXPORT_SYMBOL(ieee80211_chandef_to_operating_class); + +static int cfg80211_wdev_bi(struct wireless_dev *wdev) +{ + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + WARN_ON(wdev->valid_links); + return wdev->links[0].ap.beacon_interval; + case NL80211_IFTYPE_MESH_POINT: + return wdev->u.mesh.beacon_interval; + case NL80211_IFTYPE_ADHOC: + return wdev->u.ibss.beacon_interval; + default: + break; + } + + return 0; +} + +static void cfg80211_calculate_bi_data(struct wiphy *wiphy, u32 new_beacon_int, + u32 *beacon_int_gcd, + bool *beacon_int_different) +{ + struct wireless_dev *wdev; + + *beacon_int_gcd = 0; + *beacon_int_different 
= false; + + list_for_each_entry(wdev, &wiphy->wdev_list, list) { + int wdev_bi; + + /* this feature isn't supported with MLO */ + if (wdev->valid_links) + continue; + + wdev_bi = cfg80211_wdev_bi(wdev); + + if (!wdev_bi) + continue; + + if (!*beacon_int_gcd) { + *beacon_int_gcd = wdev_bi; + continue; + } + + if (wdev_bi == *beacon_int_gcd) + continue; + + *beacon_int_different = true; + *beacon_int_gcd = gcd(*beacon_int_gcd, wdev_bi); + } + + if (new_beacon_int && *beacon_int_gcd != new_beacon_int) { + if (*beacon_int_gcd) + *beacon_int_different = true; + *beacon_int_gcd = gcd(*beacon_int_gcd, new_beacon_int); + } +} + +int cfg80211_validate_beacon_int(struct cfg80211_registered_device *rdev, + enum nl80211_iftype iftype, u32 beacon_int) +{ + /* + * This is just a basic pre-condition check; if interface combinations + * are possible the driver must already be checking those with a call + * to cfg80211_check_combinations(), in which case we'll validate more + * through the cfg80211_calculate_bi_data() call and code in + * cfg80211_iter_combinations(). + */ + + if (beacon_int < 10 || beacon_int > 10000) + return -EINVAL; + + return 0; +} + +int cfg80211_iter_combinations(struct wiphy *wiphy, + struct iface_combination_params *params, + void (*iter)(const struct ieee80211_iface_combination *c, + void *data), + void *data) +{ + const struct ieee80211_regdomain *regdom; + enum nl80211_dfs_regions region = 0; + int i, j, iftype; + int num_interfaces = 0; + u32 used_iftypes = 0; + u32 beacon_int_gcd; + bool beacon_int_different; + + /* + * This is a bit strange, since the iteration used to rely only on + * the data given by the driver, but here it now relies on context, + * in form of the currently operating interfaces. + * This is OK for all current users, and saves us from having to + * push the GCD calculations into all the drivers. + * In the future, this should probably rely more on data that's in + * cfg80211 already - the only thing not would appear to be any new + * interfaces (while being brought up) and channel/radar data. 
+ */ + cfg80211_calculate_bi_data(wiphy, params->new_beacon_int, + &beacon_int_gcd, &beacon_int_different); + + if (params->radar_detect) { + rcu_read_lock(); + regdom = rcu_dereference(cfg80211_regdomain); + if (regdom) + region = regdom->dfs_region; + rcu_read_unlock(); + } + + for (iftype = 0; iftype < NUM_NL80211_IFTYPES; iftype++) { + num_interfaces += params->iftype_num[iftype]; + if (params->iftype_num[iftype] > 0 && + !cfg80211_iftype_allowed(wiphy, iftype, 0, 1)) + used_iftypes |= BIT(iftype); + } + + for (i = 0; i < wiphy->n_iface_combinations; i++) { + const struct ieee80211_iface_combination *c; + struct ieee80211_iface_limit *limits; + u32 all_iftypes = 0; + + c = &wiphy->iface_combinations[i]; + + if (num_interfaces > c->max_interfaces) + continue; + if (params->num_different_channels > c->num_different_channels) + continue; + + limits = kmemdup(c->limits, sizeof(limits[0]) * c->n_limits, + GFP_KERNEL); + if (!limits) + return -ENOMEM; + + for (iftype = 0; iftype < NUM_NL80211_IFTYPES; iftype++) { + if (cfg80211_iftype_allowed(wiphy, iftype, 0, 1)) + continue; + for (j = 0; j < c->n_limits; j++) { + all_iftypes |= limits[j].types; + if (!(limits[j].types & BIT(iftype))) + continue; + if (limits[j].max < params->iftype_num[iftype]) + goto cont; + limits[j].max -= params->iftype_num[iftype]; + } + } + + if (params->radar_detect != + (c->radar_detect_widths & params->radar_detect)) + goto cont; + + if (params->radar_detect && c->radar_detect_regions && + !(c->radar_detect_regions & BIT(region))) + goto cont; + + /* Finally check that all iftypes that we're currently + * using are actually part of this combination. If they + * aren't then we can't use this combination and have + * to continue to the next. + */ + if ((all_iftypes & used_iftypes) != used_iftypes) + goto cont; + + if (beacon_int_gcd) { + if (c->beacon_int_min_gcd && + beacon_int_gcd < c->beacon_int_min_gcd) + goto cont; + if (!c->beacon_int_min_gcd && beacon_int_different) + goto cont; + } + + /* This combination covered all interface types and + * supported the requested numbers, so we're good. 
+ */ + + (*iter)(c, data); + cont: + kfree(limits); + } + + return 0; +} +EXPORT_SYMBOL(cfg80211_iter_combinations); + +static void +cfg80211_iter_sum_ifcombs(const struct ieee80211_iface_combination *c, + void *data) +{ + int *num = data; + (*num)++; +} + +int cfg80211_check_combinations(struct wiphy *wiphy, + struct iface_combination_params *params) +{ + int err, num = 0; + + err = cfg80211_iter_combinations(wiphy, params, + cfg80211_iter_sum_ifcombs, &num); + if (err) + return err; + if (num == 0) + return -EBUSY; + + return 0; +} +EXPORT_SYMBOL(cfg80211_check_combinations); + +int ieee80211_get_ratemask(struct ieee80211_supported_band *sband, + const u8 *rates, unsigned int n_rates, + u32 *mask) +{ + int i, j; + + if (!sband) + return -EINVAL; + + if (n_rates == 0 || n_rates > NL80211_MAX_SUPP_RATES) + return -EINVAL; + + *mask = 0; + + for (i = 0; i < n_rates; i++) { + int rate = (rates[i] & 0x7f) * 5; + bool found = false; + + for (j = 0; j < sband->n_bitrates; j++) { + if (sband->bitrates[j].bitrate == rate) { + found = true; + *mask |= BIT(j); + break; + } + } + if (!found) + return -EINVAL; + } + + /* + * mask must have at least one bit set here since we + * didn't accept a 0-length rates array nor allowed + * entries in the array that didn't exist + */ + + return 0; +} + +unsigned int ieee80211_get_num_supported_channels(struct wiphy *wiphy) +{ + enum nl80211_band band; + unsigned int n_channels = 0; + + for (band = 0; band < NUM_NL80211_BANDS; band++) + if (wiphy->bands[band]) + n_channels += wiphy->bands[band]->n_channels; + + return n_channels; +} +EXPORT_SYMBOL(ieee80211_get_num_supported_channels); + +int cfg80211_get_station(struct net_device *dev, const u8 *mac_addr, + struct station_info *sinfo) +{ + struct cfg80211_registered_device *rdev; + struct wireless_dev *wdev; + + wdev = dev->ieee80211_ptr; + if (!wdev) + return -EOPNOTSUPP; + + rdev = wiphy_to_rdev(wdev->wiphy); + if (!rdev->ops->get_station) + return -EOPNOTSUPP; + + memset(sinfo, 0, sizeof(*sinfo)); + + return rdev_get_station(rdev, dev, mac_addr, sinfo); +} +EXPORT_SYMBOL(cfg80211_get_station); + +void cfg80211_free_nan_func(struct cfg80211_nan_func *f) +{ + int i; + + if (!f) + return; + + kfree(f->serv_spec_info); + kfree(f->srf_bf); + kfree(f->srf_macs); + for (i = 0; i < f->num_rx_filters; i++) + kfree(f->rx_filters[i].filter); + + for (i = 0; i < f->num_tx_filters; i++) + kfree(f->tx_filters[i].filter); + + kfree(f->rx_filters); + kfree(f->tx_filters); + kfree(f); +} +EXPORT_SYMBOL(cfg80211_free_nan_func); + +bool cfg80211_does_bw_fit_range(const struct ieee80211_freq_range *freq_range, + u32 center_freq_khz, u32 bw_khz) +{ + u32 start_freq_khz, end_freq_khz; + + start_freq_khz = center_freq_khz - (bw_khz / 2); + end_freq_khz = center_freq_khz + (bw_khz / 2); + + if (start_freq_khz >= freq_range->start_freq_khz && + end_freq_khz <= freq_range->end_freq_khz) + return true; + + return false; +} + +int cfg80211_sinfo_alloc_tid_stats(struct station_info *sinfo, gfp_t gfp) +{ + sinfo->pertid = kcalloc(IEEE80211_NUM_TIDS + 1, + sizeof(*(sinfo->pertid)), + gfp); + if (!sinfo->pertid) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL(cfg80211_sinfo_alloc_tid_stats); + +/* See IEEE 802.1H for LLC/SNAP encapsulation/decapsulation */ +/* Ethernet-II snap header (RFC1042 for most EtherTypes) */ +const unsigned char rfc1042_header[] __aligned(2) = + { 0xaa, 0xaa, 0x03, 0x00, 0x00, 0x00 }; +EXPORT_SYMBOL(rfc1042_header); + +/* Bridge-Tunnel header (for EtherTypes ETH_P_AARP and ETH_P_IPX) */ +const unsigned char 
bridge_tunnel_header[] __aligned(2) = + { 0xaa, 0xaa, 0x03, 0x00, 0x00, 0xf8 }; +EXPORT_SYMBOL(bridge_tunnel_header); + +/* Layer 2 Update frame (802.2 Type 1 LLC XID Update response) */ +struct iapp_layer2_update { + u8 da[ETH_ALEN]; /* broadcast */ + u8 sa[ETH_ALEN]; /* STA addr */ + __be16 len; /* 6 */ + u8 dsap; /* 0 */ + u8 ssap; /* 0 */ + u8 control; + u8 xid_info[3]; +} __packed; + +void cfg80211_send_layer2_update(struct net_device *dev, const u8 *addr) +{ + struct iapp_layer2_update *msg; + struct sk_buff *skb; + + /* Send Level 2 Update Frame to update forwarding tables in layer 2 + * bridge devices */ + + skb = dev_alloc_skb(sizeof(*msg)); + if (!skb) + return; + msg = skb_put(skb, sizeof(*msg)); + + /* 802.2 Type 1 Logical Link Control (LLC) Exchange Identifier (XID) + * Update response frame; IEEE Std 802.2-1998, 5.4.1.2.1 */ + + eth_broadcast_addr(msg->da); + ether_addr_copy(msg->sa, addr); + msg->len = htons(6); + msg->dsap = 0; + msg->ssap = 0x01; /* NULL LSAP, CR Bit: Response */ + msg->control = 0xaf; /* XID response lsb.1111F101. + * F=0 (no poll command; unsolicited frame) */ + msg->xid_info[0] = 0x81; /* XID format identifier */ + msg->xid_info[1] = 1; /* LLC types/classes: Type 1 LLC */ + msg->xid_info[2] = 0; /* XID sender's receive window size (RW) */ + + skb->dev = dev; + skb->protocol = eth_type_trans(skb, dev); + memset(skb->cb, 0, sizeof(skb->cb)); + netif_rx(skb); +} +EXPORT_SYMBOL(cfg80211_send_layer2_update); + +int ieee80211_get_vht_max_nss(struct ieee80211_vht_cap *cap, + enum ieee80211_vht_chanwidth bw, + int mcs, bool ext_nss_bw_capable, + unsigned int max_vht_nss) +{ + u16 map = le16_to_cpu(cap->supp_mcs.rx_mcs_map); + int ext_nss_bw; + int supp_width; + int i, mcs_encoding; + + if (map == 0xffff) + return 0; + + if (WARN_ON(mcs > 9 || max_vht_nss > 8)) + return 0; + if (mcs <= 7) + mcs_encoding = 0; + else if (mcs == 8) + mcs_encoding = 1; + else + mcs_encoding = 2; + + if (!max_vht_nss) { + /* find max_vht_nss for the given MCS */ + for (i = 7; i >= 0; i--) { + int supp = (map >> (2 * i)) & 3; + + if (supp == 3) + continue; + + if (supp >= mcs_encoding) { + max_vht_nss = i + 1; + break; + } + } + } + + if (!(cap->supp_mcs.tx_mcs_map & + cpu_to_le16(IEEE80211_VHT_EXT_NSS_BW_CAPABLE))) + return max_vht_nss; + + ext_nss_bw = le32_get_bits(cap->vht_cap_info, + IEEE80211_VHT_CAP_EXT_NSS_BW_MASK); + supp_width = le32_get_bits(cap->vht_cap_info, + IEEE80211_VHT_CAP_SUPP_CHAN_WIDTH_MASK); + + /* if not capable, treat ext_nss_bw as 0 */ + if (!ext_nss_bw_capable) + ext_nss_bw = 0; + + /* This is invalid */ + if (supp_width == 3) + return 0; + + /* This is an invalid combination so pretend nothing is supported */ + if (supp_width == 2 && (ext_nss_bw == 1 || ext_nss_bw == 2)) + return 0; + + /* + * Cover all the special cases according to IEEE 802.11-2016 + * Table 9-250. All other cases are either factor of 1 or not + * valid/supported. 
+ */ + switch (bw) { + case IEEE80211_VHT_CHANWIDTH_USE_HT: + case IEEE80211_VHT_CHANWIDTH_80MHZ: + if ((supp_width == 1 || supp_width == 2) && + ext_nss_bw == 3) + return 2 * max_vht_nss; + break; + case IEEE80211_VHT_CHANWIDTH_160MHZ: + if (supp_width == 0 && + (ext_nss_bw == 1 || ext_nss_bw == 2)) + return max_vht_nss / 2; + if (supp_width == 0 && + ext_nss_bw == 3) + return (3 * max_vht_nss) / 4; + if (supp_width == 1 && + ext_nss_bw == 3) + return 2 * max_vht_nss; + break; + case IEEE80211_VHT_CHANWIDTH_80P80MHZ: + if (supp_width == 0 && ext_nss_bw == 1) + return 0; /* not possible */ + if (supp_width == 0 && + ext_nss_bw == 2) + return max_vht_nss / 2; + if (supp_width == 0 && + ext_nss_bw == 3) + return (3 * max_vht_nss) / 4; + if (supp_width == 1 && + ext_nss_bw == 0) + return 0; /* not possible */ + if (supp_width == 1 && + ext_nss_bw == 1) + return max_vht_nss / 2; + if (supp_width == 1 && + ext_nss_bw == 2) + return (3 * max_vht_nss) / 4; + break; + } + + /* not covered or invalid combination received */ + return max_vht_nss; +} +EXPORT_SYMBOL(ieee80211_get_vht_max_nss); + +bool cfg80211_iftype_allowed(struct wiphy *wiphy, enum nl80211_iftype iftype, + bool is_4addr, u8 check_swif) + +{ + bool is_vlan = iftype == NL80211_IFTYPE_AP_VLAN; + + switch (check_swif) { + case 0: + if (is_vlan && is_4addr) + return wiphy->flags & WIPHY_FLAG_4ADDR_AP; + return wiphy->interface_modes & BIT(iftype); + case 1: + if (!(wiphy->software_iftypes & BIT(iftype)) && is_vlan) + return wiphy->flags & WIPHY_FLAG_4ADDR_AP; + return wiphy->software_iftypes & BIT(iftype); + default: + break; + } + + return false; +} +EXPORT_SYMBOL(cfg80211_iftype_allowed); + +void cfg80211_remove_link(struct wireless_dev *wdev, unsigned int link_id) +{ + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + + ASSERT_WDEV_LOCK(wdev); + + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + case NL80211_IFTYPE_P2P_GO: + __cfg80211_stop_ap(rdev, wdev->netdev, link_id, true); + break; + default: + /* per-link not relevant */ + break; + } + + wdev->valid_links &= ~BIT(link_id); + + rdev_del_intf_link(rdev, wdev, link_id); + + eth_zero_addr(wdev->links[link_id].addr); +} + +void cfg80211_remove_links(struct wireless_dev *wdev) +{ + unsigned int link_id; + + /* + * links are controlled by upper layers (userspace/cfg) + * only for AP mode, so only remove them here for AP + */ + if (wdev->iftype != NL80211_IFTYPE_AP) + return; + + wdev_lock(wdev); + if (wdev->valid_links) { + for_each_valid_link(wdev, link_id) + cfg80211_remove_link(wdev, link_id); + } + wdev_unlock(wdev); +} + +int cfg80211_remove_virtual_intf(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + cfg80211_remove_links(wdev); + + return rdev_del_virtual_intf(rdev, wdev); +} + +const struct wiphy_iftype_ext_capab * +cfg80211_get_iftype_ext_capa(struct wiphy *wiphy, enum nl80211_iftype type) +{ + int i; + + for (i = 0; i < wiphy->num_iftype_ext_capab; i++) { + if (wiphy->iftype_ext_capab[i].iftype == type) + return &wiphy->iftype_ext_capab[i]; + } + + return NULL; +} +EXPORT_SYMBOL(cfg80211_get_iftype_ext_capa); diff --git a/net/wireless/wext-compat.c b/net/wireless/wext-compat.c new file mode 100644 index 000000000..ddf340bfa --- /dev/null +++ b/net/wireless/wext-compat.c @@ -0,0 +1,1665 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * cfg80211 - wext compat code + * + * This is temporary code until all wireless functionality is migrated + * into cfg80211, when that happens all the exports here go away and + * we directly 
assign the wireless handlers of wireless interfaces. + * + * Copyright 2008-2009 Johannes Berg + * Copyright (C) 2019-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "wext-compat.h" +#include "core.h" +#include "rdev-ops.h" + +int cfg80211_wext_giwname(struct net_device *dev, + struct iw_request_info *info, + char *name, char *extra) +{ + strcpy(name, "IEEE 802.11"); + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwname); + +int cfg80211_wext_siwmode(struct net_device *dev, struct iw_request_info *info, + u32 *mode, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev; + struct vif_params vifparams; + enum nl80211_iftype type; + int ret; + + rdev = wiphy_to_rdev(wdev->wiphy); + + switch (*mode) { + case IW_MODE_INFRA: + type = NL80211_IFTYPE_STATION; + break; + case IW_MODE_ADHOC: + type = NL80211_IFTYPE_ADHOC; + break; + case IW_MODE_MONITOR: + type = NL80211_IFTYPE_MONITOR; + break; + default: + return -EINVAL; + } + + if (type == wdev->iftype) + return 0; + + memset(&vifparams, 0, sizeof(vifparams)); + + wiphy_lock(wdev->wiphy); + ret = cfg80211_change_iface(rdev, dev, type, &vifparams); + wiphy_unlock(wdev->wiphy); + + return ret; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_siwmode); + +int cfg80211_wext_giwmode(struct net_device *dev, struct iw_request_info *info, + u32 *mode, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if (!wdev) + return -EOPNOTSUPP; + + switch (wdev->iftype) { + case NL80211_IFTYPE_AP: + *mode = IW_MODE_MASTER; + break; + case NL80211_IFTYPE_STATION: + *mode = IW_MODE_INFRA; + break; + case NL80211_IFTYPE_ADHOC: + *mode = IW_MODE_ADHOC; + break; + case NL80211_IFTYPE_MONITOR: + *mode = IW_MODE_MONITOR; + break; + case NL80211_IFTYPE_WDS: + *mode = IW_MODE_REPEAT; + break; + case NL80211_IFTYPE_AP_VLAN: + *mode = IW_MODE_SECOND; /* FIXME */ + break; + default: + *mode = IW_MODE_AUTO; + break; + } + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwmode); + + +int cfg80211_wext_giwrange(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct iw_range *range = (struct iw_range *) extra; + enum nl80211_band band; + int i, c = 0; + + if (!wdev) + return -EOPNOTSUPP; + + data->length = sizeof(struct iw_range); + memset(range, 0, sizeof(struct iw_range)); + + range->we_version_compiled = WIRELESS_EXT; + range->we_version_source = 21; + range->retry_capa = IW_RETRY_LIMIT; + range->retry_flags = IW_RETRY_LIMIT; + range->min_retry = 0; + range->max_retry = 255; + range->min_rts = 0; + range->max_rts = 2347; + range->min_frag = 256; + range->max_frag = 2346; + + range->max_encoding_tokens = 4; + + range->max_qual.updated = IW_QUAL_NOISE_INVALID; + + switch (wdev->wiphy->signal_type) { + case CFG80211_SIGNAL_TYPE_NONE: + break; + case CFG80211_SIGNAL_TYPE_MBM: + range->max_qual.level = (u8)-110; + range->max_qual.qual = 70; + range->avg_qual.qual = 35; + range->max_qual.updated |= IW_QUAL_DBM; + range->max_qual.updated |= IW_QUAL_QUAL_UPDATED; + range->max_qual.updated |= IW_QUAL_LEVEL_UPDATED; + break; + case CFG80211_SIGNAL_TYPE_UNSPEC: + range->max_qual.level = 100; + range->max_qual.qual = 100; + range->avg_qual.qual = 50; + range->max_qual.updated |= IW_QUAL_QUAL_UPDATED; + range->max_qual.updated |= IW_QUAL_LEVEL_UPDATED; + break; + } + + range->avg_qual.level = range->max_qual.level / 2; + range->avg_qual.noise = 
range->max_qual.noise / 2; + range->avg_qual.updated = range->max_qual.updated; + + for (i = 0; i < wdev->wiphy->n_cipher_suites; i++) { + switch (wdev->wiphy->cipher_suites[i]) { + case WLAN_CIPHER_SUITE_TKIP: + range->enc_capa |= (IW_ENC_CAPA_CIPHER_TKIP | + IW_ENC_CAPA_WPA); + break; + + case WLAN_CIPHER_SUITE_CCMP: + range->enc_capa |= (IW_ENC_CAPA_CIPHER_CCMP | + IW_ENC_CAPA_WPA2); + break; + + case WLAN_CIPHER_SUITE_WEP40: + range->encoding_size[range->num_encoding_sizes++] = + WLAN_KEY_LEN_WEP40; + break; + + case WLAN_CIPHER_SUITE_WEP104: + range->encoding_size[range->num_encoding_sizes++] = + WLAN_KEY_LEN_WEP104; + break; + } + } + + for (band = 0; band < NUM_NL80211_BANDS; band ++) { + struct ieee80211_supported_band *sband; + + sband = wdev->wiphy->bands[band]; + + if (!sband) + continue; + + for (i = 0; i < sband->n_channels && c < IW_MAX_FREQUENCIES; i++) { + struct ieee80211_channel *chan = &sband->channels[i]; + + if (!(chan->flags & IEEE80211_CHAN_DISABLED)) { + range->freq[c].i = + ieee80211_frequency_to_channel( + chan->center_freq); + range->freq[c].m = chan->center_freq; + range->freq[c].e = 6; + c++; + } + } + } + range->num_channels = c; + range->num_frequency = c; + + IW_EVENT_CAPA_SET_KERNEL(range->event_capa); + IW_EVENT_CAPA_SET(range->event_capa, SIOCGIWAP); + IW_EVENT_CAPA_SET(range->event_capa, SIOCGIWSCAN); + + if (wdev->wiphy->max_scan_ssids > 0) + range->scan_capa |= IW_SCAN_CAPA_ESSID; + + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwrange); + + +/** + * cfg80211_wext_freq - get wext frequency for non-"auto" + * @freq: the wext freq encoding + * + * Returns a frequency, or a negative error code, or 0 for auto. + */ +int cfg80211_wext_freq(struct iw_freq *freq) +{ + /* + * Parse frequency - return 0 for auto and + * -EINVAL for impossible things. 
+ */ + if (freq->e == 0) { + enum nl80211_band band = NL80211_BAND_2GHZ; + if (freq->m < 0) + return 0; + if (freq->m > 14) + band = NL80211_BAND_5GHZ; + return ieee80211_channel_to_frequency(freq->m, band); + } else { + int i, div = 1000000; + for (i = 0; i < freq->e; i++) + div /= 10; + if (div <= 0) + return -EINVAL; + return freq->m / div; + } +} + +int cfg80211_wext_siwrts(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *rts, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u32 orts = wdev->wiphy->rts_threshold; + int err; + + wiphy_lock(&rdev->wiphy); + if (rts->disabled || !rts->fixed) { + wdev->wiphy->rts_threshold = (u32) -1; + } else if (rts->value < 0) { + err = -EINVAL; + goto out; + } else { + wdev->wiphy->rts_threshold = rts->value; + } + + err = rdev_set_wiphy_params(rdev, WIPHY_PARAM_RTS_THRESHOLD); + + if (err) + wdev->wiphy->rts_threshold = orts; + +out: + wiphy_unlock(&rdev->wiphy); + return err; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_siwrts); + +int cfg80211_wext_giwrts(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *rts, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + rts->value = wdev->wiphy->rts_threshold; + rts->disabled = rts->value == (u32) -1; + rts->fixed = 1; + + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwrts); + +int cfg80211_wext_siwfrag(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *frag, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u32 ofrag = wdev->wiphy->frag_threshold; + int err; + + wiphy_lock(&rdev->wiphy); + if (frag->disabled || !frag->fixed) { + wdev->wiphy->frag_threshold = (u32) -1; + } else if (frag->value < 256) { + err = -EINVAL; + goto out; + } else { + /* Fragment length must be even, so strip LSB. 
*/ + wdev->wiphy->frag_threshold = frag->value & ~0x1; + } + + err = rdev_set_wiphy_params(rdev, WIPHY_PARAM_FRAG_THRESHOLD); + if (err) + wdev->wiphy->frag_threshold = ofrag; +out: + wiphy_unlock(&rdev->wiphy); + + return err; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_siwfrag); + +int cfg80211_wext_giwfrag(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *frag, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + frag->value = wdev->wiphy->frag_threshold; + frag->disabled = frag->value == (u32) -1; + frag->fixed = 1; + + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwfrag); + +static int cfg80211_wext_siwretry(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *retry, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u32 changed = 0; + u8 olong = wdev->wiphy->retry_long; + u8 oshort = wdev->wiphy->retry_short; + int err; + + if (retry->disabled || retry->value < 1 || retry->value > 255 || + (retry->flags & IW_RETRY_TYPE) != IW_RETRY_LIMIT) + return -EINVAL; + + wiphy_lock(&rdev->wiphy); + if (retry->flags & IW_RETRY_LONG) { + wdev->wiphy->retry_long = retry->value; + changed |= WIPHY_PARAM_RETRY_LONG; + } else if (retry->flags & IW_RETRY_SHORT) { + wdev->wiphy->retry_short = retry->value; + changed |= WIPHY_PARAM_RETRY_SHORT; + } else { + wdev->wiphy->retry_short = retry->value; + wdev->wiphy->retry_long = retry->value; + changed |= WIPHY_PARAM_RETRY_LONG; + changed |= WIPHY_PARAM_RETRY_SHORT; + } + + err = rdev_set_wiphy_params(rdev, changed); + if (err) { + wdev->wiphy->retry_short = oshort; + wdev->wiphy->retry_long = olong; + } + wiphy_unlock(&rdev->wiphy); + + return err; +} + +int cfg80211_wext_giwretry(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *retry, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + retry->disabled = 0; + + if (retry->flags == 0 || (retry->flags & IW_RETRY_SHORT)) { + /* + * First return short value, iwconfig will ask long value + * later if needed + */ + retry->flags |= IW_RETRY_LIMIT | IW_RETRY_SHORT; + retry->value = wdev->wiphy->retry_short; + if (wdev->wiphy->retry_long == wdev->wiphy->retry_short) + retry->flags |= IW_RETRY_LONG; + + return 0; + } + + if (retry->flags & IW_RETRY_LONG) { + retry->flags = IW_RETRY_LIMIT | IW_RETRY_LONG; + retry->value = wdev->wiphy->retry_long; + } + + return 0; +} +EXPORT_WEXT_HANDLER(cfg80211_wext_giwretry); + +static int __cfg80211_set_encryption(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool pairwise, + const u8 *addr, bool remove, bool tx_key, + int idx, struct key_params *params) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int err, i; + bool rejoin = false; + + if (wdev->valid_links) + return -EINVAL; + + if (pairwise && !addr) + return -EINVAL; + + /* + * In many cases we won't actually need this, but it's better + * to do it first in case the allocation fails. Don't use wext. 
+ */ + if (!wdev->wext.keys) { + wdev->wext.keys = kzalloc(sizeof(*wdev->wext.keys), + GFP_KERNEL); + if (!wdev->wext.keys) + return -ENOMEM; + for (i = 0; i < CFG80211_MAX_WEP_KEYS; i++) + wdev->wext.keys->params[i].key = + wdev->wext.keys->data[i]; + } + + if (wdev->iftype != NL80211_IFTYPE_ADHOC && + wdev->iftype != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + if (params->cipher == WLAN_CIPHER_SUITE_AES_CMAC) { + if (!wdev->connected) + return -ENOLINK; + + if (!rdev->ops->set_default_mgmt_key) + return -EOPNOTSUPP; + + if (idx < 4 || idx > 5) + return -EINVAL; + } else if (idx < 0 || idx > 3) + return -EINVAL; + + if (remove) { + err = 0; + if (wdev->connected || + (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->u.ibss.current_bss)) { + /* + * If removing the current TX key, we will need to + * join a new IBSS without the privacy bit clear. + */ + if (idx == wdev->wext.default_key && + wdev->iftype == NL80211_IFTYPE_ADHOC) { + __cfg80211_leave_ibss(rdev, wdev->netdev, true); + rejoin = true; + } + + if (!pairwise && addr && + !(rdev->wiphy.flags & WIPHY_FLAG_IBSS_RSN)) + err = -ENOENT; + else + err = rdev_del_key(rdev, dev, -1, idx, pairwise, + addr); + } + wdev->wext.connect.privacy = false; + /* + * Applications using wireless extensions expect to be + * able to delete keys that don't exist, so allow that. + */ + if (err == -ENOENT) + err = 0; + if (!err) { + if (!addr && idx < 4) { + memset(wdev->wext.keys->data[idx], 0, + sizeof(wdev->wext.keys->data[idx])); + wdev->wext.keys->params[idx].key_len = 0; + wdev->wext.keys->params[idx].cipher = 0; + } + if (idx == wdev->wext.default_key) + wdev->wext.default_key = -1; + else if (idx == wdev->wext.default_mgmt_key) + wdev->wext.default_mgmt_key = -1; + } + + if (!err && rejoin) + err = cfg80211_ibss_wext_join(rdev, wdev); + + return err; + } + + if (addr) + tx_key = false; + + if (cfg80211_validate_key_settings(rdev, params, idx, pairwise, addr)) + return -EINVAL; + + err = 0; + if (wdev->connected || + (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->u.ibss.current_bss)) + err = rdev_add_key(rdev, dev, -1, idx, pairwise, addr, params); + else if (params->cipher != WLAN_CIPHER_SUITE_WEP40 && + params->cipher != WLAN_CIPHER_SUITE_WEP104) + return -EINVAL; + if (err) + return err; + + /* + * We only need to store WEP keys, since they're the only keys that + * can be set before a connection is established and persist after + * disconnecting. + */ + if (!addr && (params->cipher == WLAN_CIPHER_SUITE_WEP40 || + params->cipher == WLAN_CIPHER_SUITE_WEP104)) { + wdev->wext.keys->params[idx] = *params; + memcpy(wdev->wext.keys->data[idx], + params->key, params->key_len); + wdev->wext.keys->params[idx].key = + wdev->wext.keys->data[idx]; + } + + if ((params->cipher == WLAN_CIPHER_SUITE_WEP40 || + params->cipher == WLAN_CIPHER_SUITE_WEP104) && + (tx_key || (!addr && wdev->wext.default_key == -1))) { + if (wdev->connected || + (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->u.ibss.current_bss)) { + /* + * If we are getting a new TX key from not having + * had one before we need to join a new IBSS with + * the privacy bit set. 
+ */ + if (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->wext.default_key == -1) { + __cfg80211_leave_ibss(rdev, wdev->netdev, true); + rejoin = true; + } + err = rdev_set_default_key(rdev, dev, -1, idx, true, + true); + } + if (!err) { + wdev->wext.default_key = idx; + if (rejoin) + err = cfg80211_ibss_wext_join(rdev, wdev); + } + return err; + } + + if (params->cipher == WLAN_CIPHER_SUITE_AES_CMAC && + (tx_key || (!addr && wdev->wext.default_mgmt_key == -1))) { + if (wdev->connected || + (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->u.ibss.current_bss)) + err = rdev_set_default_mgmt_key(rdev, dev, -1, idx); + if (!err) + wdev->wext.default_mgmt_key = idx; + return err; + } + + return 0; +} + +static int cfg80211_set_encryption(struct cfg80211_registered_device *rdev, + struct net_device *dev, bool pairwise, + const u8 *addr, bool remove, bool tx_key, + int idx, struct key_params *params) +{ + int err; + + wdev_lock(dev->ieee80211_ptr); + err = __cfg80211_set_encryption(rdev, dev, pairwise, addr, + remove, tx_key, idx, params); + wdev_unlock(dev->ieee80211_ptr); + + return err; +} + +static int cfg80211_wext_siwencode(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *erq, char *keybuf) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int idx, err; + bool remove = false; + struct key_params params; + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_ADHOC) + return -EOPNOTSUPP; + + /* no use -- only MFP (set_default_mgmt_key) is optional */ + if (!rdev->ops->del_key || + !rdev->ops->add_key || + !rdev->ops->set_default_key) + return -EOPNOTSUPP; + + wiphy_lock(&rdev->wiphy); + if (wdev->valid_links) { + err = -EOPNOTSUPP; + goto out; + } + + idx = erq->flags & IW_ENCODE_INDEX; + if (idx == 0) { + idx = wdev->wext.default_key; + if (idx < 0) + idx = 0; + } else if (idx < 1 || idx > 4) { + err = -EINVAL; + goto out; + } else { + idx--; + } + + if (erq->flags & IW_ENCODE_DISABLED) + remove = true; + else if (erq->length == 0) { + /* No key data - just set the default TX key index */ + err = 0; + wdev_lock(wdev); + if (wdev->connected || + (wdev->iftype == NL80211_IFTYPE_ADHOC && + wdev->u.ibss.current_bss)) + err = rdev_set_default_key(rdev, dev, -1, idx, true, + true); + if (!err) + wdev->wext.default_key = idx; + wdev_unlock(wdev); + goto out; + } + + memset(¶ms, 0, sizeof(params)); + params.key = keybuf; + params.key_len = erq->length; + if (erq->length == 5) { + params.cipher = WLAN_CIPHER_SUITE_WEP40; + } else if (erq->length == 13) { + params.cipher = WLAN_CIPHER_SUITE_WEP104; + } else if (!remove) { + err = -EINVAL; + goto out; + } + + err = cfg80211_set_encryption(rdev, dev, false, NULL, remove, + wdev->wext.default_key == -1, + idx, ¶ms); +out: + wiphy_unlock(&rdev->wiphy); + + return err; +} + +static int cfg80211_wext_siwencodeext(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *erq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct iw_encode_ext *ext = (struct iw_encode_ext *) extra; + const u8 *addr; + int idx; + bool remove = false; + struct key_params params; + u32 cipher; + int ret; + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_ADHOC) + return -EOPNOTSUPP; + + /* no use -- only MFP (set_default_mgmt_key) is optional */ + if (!rdev->ops->del_key || + !rdev->ops->add_key || + 
!rdev->ops->set_default_key) + return -EOPNOTSUPP; + + wdev_lock(wdev); + if (wdev->valid_links) { + wdev_unlock(wdev); + return -EOPNOTSUPP; + } + wdev_unlock(wdev); + + switch (ext->alg) { + case IW_ENCODE_ALG_NONE: + remove = true; + cipher = 0; + break; + case IW_ENCODE_ALG_WEP: + if (ext->key_len == 5) + cipher = WLAN_CIPHER_SUITE_WEP40; + else if (ext->key_len == 13) + cipher = WLAN_CIPHER_SUITE_WEP104; + else + return -EINVAL; + break; + case IW_ENCODE_ALG_TKIP: + cipher = WLAN_CIPHER_SUITE_TKIP; + break; + case IW_ENCODE_ALG_CCMP: + cipher = WLAN_CIPHER_SUITE_CCMP; + break; + case IW_ENCODE_ALG_AES_CMAC: + cipher = WLAN_CIPHER_SUITE_AES_CMAC; + break; + default: + return -EOPNOTSUPP; + } + + if (erq->flags & IW_ENCODE_DISABLED) + remove = true; + + idx = erq->flags & IW_ENCODE_INDEX; + if (cipher == WLAN_CIPHER_SUITE_AES_CMAC) { + if (idx < 4 || idx > 5) { + idx = wdev->wext.default_mgmt_key; + if (idx < 0) + return -EINVAL; + } else + idx--; + } else { + if (idx < 1 || idx > 4) { + idx = wdev->wext.default_key; + if (idx < 0) + return -EINVAL; + } else + idx--; + } + + addr = ext->addr.sa_data; + if (is_broadcast_ether_addr(addr)) + addr = NULL; + + memset(¶ms, 0, sizeof(params)); + params.key = ext->key; + params.key_len = ext->key_len; + params.cipher = cipher; + + if (ext->ext_flags & IW_ENCODE_EXT_RX_SEQ_VALID) { + params.seq = ext->rx_seq; + params.seq_len = 6; + } + + wiphy_lock(wdev->wiphy); + ret = cfg80211_set_encryption( + rdev, dev, + !(ext->ext_flags & IW_ENCODE_EXT_GROUP_KEY), + addr, remove, + ext->ext_flags & IW_ENCODE_EXT_SET_TX_KEY, + idx, ¶ms); + wiphy_unlock(wdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwencode(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *erq, char *keybuf) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int idx; + + if (wdev->iftype != NL80211_IFTYPE_STATION && + wdev->iftype != NL80211_IFTYPE_ADHOC) + return -EOPNOTSUPP; + + idx = erq->flags & IW_ENCODE_INDEX; + if (idx == 0) { + idx = wdev->wext.default_key; + if (idx < 0) + idx = 0; + } else if (idx < 1 || idx > 4) + return -EINVAL; + else + idx--; + + erq->flags = idx + 1; + + if (!wdev->wext.keys || !wdev->wext.keys->params[idx].cipher) { + erq->flags |= IW_ENCODE_DISABLED; + erq->length = 0; + return 0; + } + + erq->length = min_t(size_t, erq->length, + wdev->wext.keys->params[idx].key_len); + memcpy(keybuf, wdev->wext.keys->params[idx].key, erq->length); + erq->flags |= IW_ENCODE_ENABLED; + + return 0; +} + +static int cfg80211_wext_siwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *wextfreq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_chan_def chandef = { + .width = NL80211_CHAN_WIDTH_20_NOHT, + }; + int freq, ret; + + wiphy_lock(&rdev->wiphy); + + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_siwfreq(dev, info, wextfreq, extra); + break; + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_siwfreq(dev, info, wextfreq, extra); + break; + case NL80211_IFTYPE_MONITOR: + freq = cfg80211_wext_freq(wextfreq); + if (freq < 0) { + ret = freq; + break; + } + if (freq == 0) { + ret = -EINVAL; + break; + } + chandef.center_freq1 = freq; + chandef.chan = ieee80211_get_channel(&rdev->wiphy, freq); + if (!chandef.chan) { + ret = -EINVAL; + break; + } + ret = cfg80211_set_monitor_channel(rdev, &chandef); + break; + case NL80211_IFTYPE_MESH_POINT: + freq = 
cfg80211_wext_freq(wextfreq); + if (freq < 0) { + ret = freq; + break; + } + if (freq == 0) { + ret = -EINVAL; + break; + } + chandef.center_freq1 = freq; + chandef.chan = ieee80211_get_channel(&rdev->wiphy, freq); + if (!chandef.chan) { + ret = -EINVAL; + break; + } + ret = cfg80211_set_mesh_channel(rdev, wdev, &chandef); + break; + default: + ret = -EOPNOTSUPP; + break; + } + + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_chan_def chandef = {}; + int ret; + + wiphy_lock(&rdev->wiphy); + switch (wdev->iftype) { + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_giwfreq(dev, info, freq, extra); + break; + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_giwfreq(dev, info, freq, extra); + break; + case NL80211_IFTYPE_MONITOR: + if (!rdev->ops->get_channel) { + ret = -EINVAL; + break; + } + + ret = rdev_get_channel(rdev, wdev, 0, &chandef); + if (ret) + break; + freq->m = chandef.chan->center_freq; + freq->e = 6; + ret = 0; + break; + default: + ret = -EINVAL; + break; + } + + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_siwtxpower(struct net_device *dev, + struct iw_request_info *info, + union iwreq_data *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + enum nl80211_tx_power_setting type; + int dbm = 0; + int ret; + + if ((data->txpower.flags & IW_TXPOW_TYPE) != IW_TXPOW_DBM) + return -EINVAL; + if (data->txpower.flags & IW_TXPOW_RANGE) + return -EINVAL; + + if (!rdev->ops->set_tx_power) + return -EOPNOTSUPP; + + /* only change when not disabling */ + if (!data->txpower.disabled) { + rfkill_set_sw_state(rdev->wiphy.rfkill, false); + + if (data->txpower.fixed) { + /* + * wext doesn't support negative values, see + * below where it's for automatic + */ + if (data->txpower.value < 0) + return -EINVAL; + dbm = data->txpower.value; + type = NL80211_TX_POWER_FIXED; + /* TODO: do regulatory check! */ + } else { + /* + * Automatic power level setting, max being the value + * passed in from userland. + */ + if (data->txpower.value < 0) { + type = NL80211_TX_POWER_AUTOMATIC; + } else { + dbm = data->txpower.value; + type = NL80211_TX_POWER_LIMITED; + } + } + } else { + if (rfkill_set_sw_state(rdev->wiphy.rfkill, true)) + schedule_work(&rdev->rfkill_block); + return 0; + } + + wiphy_lock(&rdev->wiphy); + ret = rdev_set_tx_power(rdev, wdev, type, DBM_TO_MBM(dbm)); + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwtxpower(struct net_device *dev, + struct iw_request_info *info, + union iwreq_data *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int err, val; + + if ((data->txpower.flags & IW_TXPOW_TYPE) != IW_TXPOW_DBM) + return -EINVAL; + if (data->txpower.flags & IW_TXPOW_RANGE) + return -EINVAL; + + if (!rdev->ops->get_tx_power) + return -EOPNOTSUPP; + + wiphy_lock(&rdev->wiphy); + err = rdev_get_tx_power(rdev, wdev, &val); + wiphy_unlock(&rdev->wiphy); + if (err) + return err; + + /* well... 
oh well */ + data->txpower.fixed = 1; + data->txpower.disabled = rfkill_blocked(rdev->wiphy.rfkill); + data->txpower.value = val; + data->txpower.flags = IW_TXPOW_DBM; + + return 0; +} + +static int cfg80211_set_auth_alg(struct wireless_dev *wdev, + s32 auth_alg) +{ + int nr_alg = 0; + + if (!auth_alg) + return -EINVAL; + + if (auth_alg & ~(IW_AUTH_ALG_OPEN_SYSTEM | + IW_AUTH_ALG_SHARED_KEY | + IW_AUTH_ALG_LEAP)) + return -EINVAL; + + if (auth_alg & IW_AUTH_ALG_OPEN_SYSTEM) { + nr_alg++; + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_OPEN_SYSTEM; + } + + if (auth_alg & IW_AUTH_ALG_SHARED_KEY) { + nr_alg++; + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_SHARED_KEY; + } + + if (auth_alg & IW_AUTH_ALG_LEAP) { + nr_alg++; + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_NETWORK_EAP; + } + + if (nr_alg > 1) + wdev->wext.connect.auth_type = NL80211_AUTHTYPE_AUTOMATIC; + + return 0; +} + +static int cfg80211_set_wpa_version(struct wireless_dev *wdev, u32 wpa_versions) +{ + if (wpa_versions & ~(IW_AUTH_WPA_VERSION_WPA | + IW_AUTH_WPA_VERSION_WPA2| + IW_AUTH_WPA_VERSION_DISABLED)) + return -EINVAL; + + if ((wpa_versions & IW_AUTH_WPA_VERSION_DISABLED) && + (wpa_versions & (IW_AUTH_WPA_VERSION_WPA| + IW_AUTH_WPA_VERSION_WPA2))) + return -EINVAL; + + if (wpa_versions & IW_AUTH_WPA_VERSION_DISABLED) + wdev->wext.connect.crypto.wpa_versions &= + ~(NL80211_WPA_VERSION_1|NL80211_WPA_VERSION_2); + + if (wpa_versions & IW_AUTH_WPA_VERSION_WPA) + wdev->wext.connect.crypto.wpa_versions |= + NL80211_WPA_VERSION_1; + + if (wpa_versions & IW_AUTH_WPA_VERSION_WPA2) + wdev->wext.connect.crypto.wpa_versions |= + NL80211_WPA_VERSION_2; + + return 0; +} + +static int cfg80211_set_cipher_group(struct wireless_dev *wdev, u32 cipher) +{ + if (cipher & IW_AUTH_CIPHER_WEP40) + wdev->wext.connect.crypto.cipher_group = + WLAN_CIPHER_SUITE_WEP40; + else if (cipher & IW_AUTH_CIPHER_WEP104) + wdev->wext.connect.crypto.cipher_group = + WLAN_CIPHER_SUITE_WEP104; + else if (cipher & IW_AUTH_CIPHER_TKIP) + wdev->wext.connect.crypto.cipher_group = + WLAN_CIPHER_SUITE_TKIP; + else if (cipher & IW_AUTH_CIPHER_CCMP) + wdev->wext.connect.crypto.cipher_group = + WLAN_CIPHER_SUITE_CCMP; + else if (cipher & IW_AUTH_CIPHER_AES_CMAC) + wdev->wext.connect.crypto.cipher_group = + WLAN_CIPHER_SUITE_AES_CMAC; + else if (cipher & IW_AUTH_CIPHER_NONE) + wdev->wext.connect.crypto.cipher_group = 0; + else + return -EINVAL; + + return 0; +} + +static int cfg80211_set_cipher_pairwise(struct wireless_dev *wdev, u32 cipher) +{ + int nr_ciphers = 0; + u32 *ciphers_pairwise = wdev->wext.connect.crypto.ciphers_pairwise; + + if (cipher & IW_AUTH_CIPHER_WEP40) { + ciphers_pairwise[nr_ciphers] = WLAN_CIPHER_SUITE_WEP40; + nr_ciphers++; + } + + if (cipher & IW_AUTH_CIPHER_WEP104) { + ciphers_pairwise[nr_ciphers] = WLAN_CIPHER_SUITE_WEP104; + nr_ciphers++; + } + + if (cipher & IW_AUTH_CIPHER_TKIP) { + ciphers_pairwise[nr_ciphers] = WLAN_CIPHER_SUITE_TKIP; + nr_ciphers++; + } + + if (cipher & IW_AUTH_CIPHER_CCMP) { + ciphers_pairwise[nr_ciphers] = WLAN_CIPHER_SUITE_CCMP; + nr_ciphers++; + } + + if (cipher & IW_AUTH_CIPHER_AES_CMAC) { + ciphers_pairwise[nr_ciphers] = WLAN_CIPHER_SUITE_AES_CMAC; + nr_ciphers++; + } + + BUILD_BUG_ON(NL80211_MAX_NR_CIPHER_SUITES < 5); + + wdev->wext.connect.crypto.n_ciphers_pairwise = nr_ciphers; + + return 0; +} + + +static int cfg80211_set_key_mgt(struct wireless_dev *wdev, u32 key_mgt) +{ + int nr_akm_suites = 0; + + if (key_mgt & ~(IW_AUTH_KEY_MGMT_802_1X | + IW_AUTH_KEY_MGMT_PSK)) + return -EINVAL; + + if (key_mgt & 
IW_AUTH_KEY_MGMT_802_1X) { + wdev->wext.connect.crypto.akm_suites[nr_akm_suites] = + WLAN_AKM_SUITE_8021X; + nr_akm_suites++; + } + + if (key_mgt & IW_AUTH_KEY_MGMT_PSK) { + wdev->wext.connect.crypto.akm_suites[nr_akm_suites] = + WLAN_AKM_SUITE_PSK; + nr_akm_suites++; + } + + wdev->wext.connect.crypto.n_akm_suites = nr_akm_suites; + + return 0; +} + +static int cfg80211_wext_siwauth(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + switch (data->flags & IW_AUTH_INDEX) { + case IW_AUTH_PRIVACY_INVOKED: + wdev->wext.connect.privacy = data->value; + return 0; + case IW_AUTH_WPA_VERSION: + return cfg80211_set_wpa_version(wdev, data->value); + case IW_AUTH_CIPHER_GROUP: + return cfg80211_set_cipher_group(wdev, data->value); + case IW_AUTH_KEY_MGMT: + return cfg80211_set_key_mgt(wdev, data->value); + case IW_AUTH_CIPHER_PAIRWISE: + return cfg80211_set_cipher_pairwise(wdev, data->value); + case IW_AUTH_80211_AUTH_ALG: + return cfg80211_set_auth_alg(wdev, data->value); + case IW_AUTH_WPA_ENABLED: + case IW_AUTH_RX_UNENCRYPTED_EAPOL: + case IW_AUTH_DROP_UNENCRYPTED: + case IW_AUTH_MFP: + return 0; + default: + return -EOPNOTSUPP; + } +} + +static int cfg80211_wext_giwauth(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *data, char *extra) +{ + /* XXX: what do we need? */ + + return -EOPNOTSUPP; +} + +static int cfg80211_wext_siwpower(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *wrq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + bool ps; + int timeout = wdev->ps_timeout; + int err; + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EINVAL; + + if (!rdev->ops->set_power_mgmt) + return -EOPNOTSUPP; + + if (wrq->disabled) { + ps = false; + } else { + switch (wrq->flags & IW_POWER_MODE) { + case IW_POWER_ON: /* If not specified */ + case IW_POWER_MODE: /* If set all mask */ + case IW_POWER_ALL_R: /* If explicitely state all */ + ps = true; + break; + default: /* Otherwise we ignore */ + return -EINVAL; + } + + if (wrq->flags & ~(IW_POWER_MODE | IW_POWER_TIMEOUT)) + return -EINVAL; + + if (wrq->flags & IW_POWER_TIMEOUT) + timeout = wrq->value / 1000; + } + + wiphy_lock(&rdev->wiphy); + err = rdev_set_power_mgmt(rdev, dev, ps, timeout); + wiphy_unlock(&rdev->wiphy); + if (err) + return err; + + wdev->ps = ps; + wdev->ps_timeout = timeout; + + return 0; + +} + +static int cfg80211_wext_giwpower(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *wrq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + wrq->disabled = !wdev->ps; + + return 0; +} + +static int cfg80211_wext_siwrate(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *rate, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_bitrate_mask mask; + u32 fixed, maxrate; + struct ieee80211_supported_band *sband; + int band, ridx, ret; + bool match = false; + + if (!rdev->ops->set_bitrate_mask) + return -EOPNOTSUPP; + + memset(&mask, 0, sizeof(mask)); + fixed = 0; + maxrate = (u32)-1; + + if (rate->value < 0) { + /* nothing */ + } else if (rate->fixed) { + fixed = rate->value / 100000; + } else { + maxrate = rate->value / 100000; + } + + for (band = 0; band < 
NUM_NL80211_BANDS; band++) { + sband = wdev->wiphy->bands[band]; + if (sband == NULL) + continue; + for (ridx = 0; ridx < sband->n_bitrates; ridx++) { + struct ieee80211_rate *srate = &sband->bitrates[ridx]; + if (fixed == srate->bitrate) { + mask.control[band].legacy = 1 << ridx; + match = true; + break; + } + if (srate->bitrate <= maxrate) { + mask.control[band].legacy |= 1 << ridx; + match = true; + } + } + } + + if (!match) + return -EINVAL; + + wiphy_lock(&rdev->wiphy); + if (dev->ieee80211_ptr->valid_links) + ret = -EOPNOTSUPP; + else + ret = rdev_set_bitrate_mask(rdev, dev, 0, NULL, &mask); + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwrate(struct net_device *dev, + struct iw_request_info *info, + struct iw_param *rate, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct station_info sinfo = {}; + u8 addr[ETH_ALEN]; + int err; + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + if (!rdev->ops->get_station) + return -EOPNOTSUPP; + + err = 0; + wdev_lock(wdev); + if (!wdev->valid_links && wdev->links[0].client.current_bss) + memcpy(addr, wdev->links[0].client.current_bss->pub.bssid, + ETH_ALEN); + else + err = -EOPNOTSUPP; + wdev_unlock(wdev); + if (err) + return err; + + wiphy_lock(&rdev->wiphy); + err = rdev_get_station(rdev, dev, addr, &sinfo); + wiphy_unlock(&rdev->wiphy); + if (err) + return err; + + if (!(sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_BITRATE))) { + err = -EOPNOTSUPP; + goto free; + } + + rate->value = 100000 * cfg80211_calculate_bitrate(&sinfo.txrate); + +free: + cfg80211_sinfo_release_content(&sinfo); + return err; +} + +/* Get wireless statistics. Called by /proc/net/wireless and by SIOCGIWSTATS */ +static struct iw_statistics *cfg80211_wireless_stats(struct net_device *dev) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + /* we are under RTNL - globally locked - so can use static structs */ + static struct iw_statistics wstats; + static struct station_info sinfo = {}; + u8 bssid[ETH_ALEN]; + int ret; + + if (dev->ieee80211_ptr->iftype != NL80211_IFTYPE_STATION) + return NULL; + + if (!rdev->ops->get_station) + return NULL; + + /* Grab BSSID of current BSS, if any */ + wdev_lock(wdev); + if (wdev->valid_links || !wdev->links[0].client.current_bss) { + wdev_unlock(wdev); + return NULL; + } + memcpy(bssid, wdev->links[0].client.current_bss->pub.bssid, ETH_ALEN); + wdev_unlock(wdev); + + memset(&sinfo, 0, sizeof(sinfo)); + + wiphy_lock(&rdev->wiphy); + ret = rdev_get_station(rdev, dev, bssid, &sinfo); + wiphy_unlock(&rdev->wiphy); + + if (ret) + return NULL; + + memset(&wstats, 0, sizeof(wstats)); + + switch (rdev->wiphy.signal_type) { + case CFG80211_SIGNAL_TYPE_MBM: + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_SIGNAL)) { + int sig = sinfo.signal; + wstats.qual.updated |= IW_QUAL_LEVEL_UPDATED; + wstats.qual.updated |= IW_QUAL_QUAL_UPDATED; + wstats.qual.updated |= IW_QUAL_DBM; + wstats.qual.level = sig; + if (sig < -110) + sig = -110; + else if (sig > -40) + sig = -40; + wstats.qual.qual = sig + 110; + break; + } + fallthrough; + case CFG80211_SIGNAL_TYPE_UNSPEC: + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_SIGNAL)) { + wstats.qual.updated |= IW_QUAL_LEVEL_UPDATED; + wstats.qual.updated |= IW_QUAL_QUAL_UPDATED; + wstats.qual.level = sinfo.signal; + wstats.qual.qual = sinfo.signal; + break; + } + fallthrough; + default: + 
wstats.qual.updated |= IW_QUAL_LEVEL_INVALID; + wstats.qual.updated |= IW_QUAL_QUAL_INVALID; + } + + wstats.qual.updated |= IW_QUAL_NOISE_INVALID; + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_RX_DROP_MISC)) + wstats.discard.misc = sinfo.rx_dropped_misc; + if (sinfo.filled & BIT_ULL(NL80211_STA_INFO_TX_FAILED)) + wstats.discard.retries = sinfo.tx_failed; + + cfg80211_sinfo_release_content(&sinfo); + + return &wstats; +} + +static int cfg80211_wext_siwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int ret; + + wiphy_lock(&rdev->wiphy); + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_siwap(dev, info, ap_addr, extra); + break; + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_siwap(dev, info, ap_addr, extra); + break; + default: + ret = -EOPNOTSUPP; + break; + } + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int ret; + + wiphy_lock(&rdev->wiphy); + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_giwap(dev, info, ap_addr, extra); + break; + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_giwap(dev, info, ap_addr, extra); + break; + default: + ret = -EOPNOTSUPP; + break; + } + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_siwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int ret; + + wiphy_lock(&rdev->wiphy); + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_siwessid(dev, info, data, ssid); + break; + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_siwessid(dev, info, data, ssid); + break; + default: + ret = -EOPNOTSUPP; + break; + } + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_giwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + int ret; + + data->flags = 0; + data->length = 0; + + wiphy_lock(&rdev->wiphy); + switch (wdev->iftype) { + case NL80211_IFTYPE_ADHOC: + ret = cfg80211_ibss_wext_giwessid(dev, info, data, ssid); + break; + case NL80211_IFTYPE_STATION: + ret = cfg80211_mgd_wext_giwessid(dev, info, data, ssid); + break; + default: + ret = -EOPNOTSUPP; + break; + } + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +static int cfg80211_wext_siwpmksa(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct cfg80211_pmksa cfg_pmksa; + struct iw_pmksa *pmksa = (struct iw_pmksa *)extra; + int ret; + + memset(&cfg_pmksa, 0, sizeof(struct cfg80211_pmksa)); + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EINVAL; + + cfg_pmksa.bssid = pmksa->bssid.sa_data; + cfg_pmksa.pmkid = pmksa->pmkid; + + wiphy_lock(&rdev->wiphy); + switch (pmksa->cmd) { + case IW_PMKSA_ADD: 
+ if (!rdev->ops->set_pmksa) { + ret = -EOPNOTSUPP; + break; + } + + ret = rdev_set_pmksa(rdev, dev, &cfg_pmksa); + break; + case IW_PMKSA_REMOVE: + if (!rdev->ops->del_pmksa) { + ret = -EOPNOTSUPP; + break; + } + + ret = rdev_del_pmksa(rdev, dev, &cfg_pmksa); + break; + case IW_PMKSA_FLUSH: + if (!rdev->ops->flush_pmksa) { + ret = -EOPNOTSUPP; + break; + } + + ret = rdev_flush_pmksa(rdev, dev); + break; + default: + ret = -EOPNOTSUPP; + break; + } + wiphy_unlock(&rdev->wiphy); + + return ret; +} + +#define DEFINE_WEXT_COMPAT_STUB(func, type) \ + static int __ ## func(struct net_device *dev, \ + struct iw_request_info *info, \ + union iwreq_data *wrqu, \ + char *extra) \ + { \ + return func(dev, info, (type *)wrqu, extra); \ + } + +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwname, char) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwfreq, struct iw_freq) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwfreq, struct iw_freq) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwmode, u32) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwmode, u32) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrange, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwap, struct sockaddr) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwap, struct sockaddr) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwmlme, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwscan, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwessid, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwessid, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwrate, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrate, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwrts, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwrts, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwfrag, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwfrag, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwretry, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwretry, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwencode, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwencode, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwpower, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwpower, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwgenie, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_giwauth, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwauth, struct iw_param) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwencodeext, struct iw_point) +DEFINE_WEXT_COMPAT_STUB(cfg80211_wext_siwpmksa, struct iw_point) + +static const iw_handler cfg80211_handlers[] = { + [IW_IOCTL_IDX(SIOCGIWNAME)] = __cfg80211_wext_giwname, + [IW_IOCTL_IDX(SIOCSIWFREQ)] = __cfg80211_wext_siwfreq, + [IW_IOCTL_IDX(SIOCGIWFREQ)] = __cfg80211_wext_giwfreq, + [IW_IOCTL_IDX(SIOCSIWMODE)] = __cfg80211_wext_siwmode, + [IW_IOCTL_IDX(SIOCGIWMODE)] = __cfg80211_wext_giwmode, + [IW_IOCTL_IDX(SIOCGIWRANGE)] = __cfg80211_wext_giwrange, + [IW_IOCTL_IDX(SIOCSIWAP)] = __cfg80211_wext_siwap, + [IW_IOCTL_IDX(SIOCGIWAP)] = __cfg80211_wext_giwap, + [IW_IOCTL_IDX(SIOCSIWMLME)] = __cfg80211_wext_siwmlme, + [IW_IOCTL_IDX(SIOCSIWSCAN)] = cfg80211_wext_siwscan, + [IW_IOCTL_IDX(SIOCGIWSCAN)] = __cfg80211_wext_giwscan, + [IW_IOCTL_IDX(SIOCSIWESSID)] = __cfg80211_wext_siwessid, + [IW_IOCTL_IDX(SIOCGIWESSID)] = __cfg80211_wext_giwessid, + [IW_IOCTL_IDX(SIOCSIWRATE)] = __cfg80211_wext_siwrate, + [IW_IOCTL_IDX(SIOCGIWRATE)] = __cfg80211_wext_giwrate, + [IW_IOCTL_IDX(SIOCSIWRTS)] = 
__cfg80211_wext_siwrts, + [IW_IOCTL_IDX(SIOCGIWRTS)] = __cfg80211_wext_giwrts, + [IW_IOCTL_IDX(SIOCSIWFRAG)] = __cfg80211_wext_siwfrag, + [IW_IOCTL_IDX(SIOCGIWFRAG)] = __cfg80211_wext_giwfrag, + [IW_IOCTL_IDX(SIOCSIWTXPOW)] = cfg80211_wext_siwtxpower, + [IW_IOCTL_IDX(SIOCGIWTXPOW)] = cfg80211_wext_giwtxpower, + [IW_IOCTL_IDX(SIOCSIWRETRY)] = __cfg80211_wext_siwretry, + [IW_IOCTL_IDX(SIOCGIWRETRY)] = __cfg80211_wext_giwretry, + [IW_IOCTL_IDX(SIOCSIWENCODE)] = __cfg80211_wext_siwencode, + [IW_IOCTL_IDX(SIOCGIWENCODE)] = __cfg80211_wext_giwencode, + [IW_IOCTL_IDX(SIOCSIWPOWER)] = __cfg80211_wext_siwpower, + [IW_IOCTL_IDX(SIOCGIWPOWER)] = __cfg80211_wext_giwpower, + [IW_IOCTL_IDX(SIOCSIWGENIE)] = __cfg80211_wext_siwgenie, + [IW_IOCTL_IDX(SIOCSIWAUTH)] = __cfg80211_wext_siwauth, + [IW_IOCTL_IDX(SIOCGIWAUTH)] = __cfg80211_wext_giwauth, + [IW_IOCTL_IDX(SIOCSIWENCODEEXT)]= __cfg80211_wext_siwencodeext, + [IW_IOCTL_IDX(SIOCSIWPMKSA)] = __cfg80211_wext_siwpmksa, +}; + +const struct iw_handler_def cfg80211_wext_handler = { + .num_standard = ARRAY_SIZE(cfg80211_handlers), + .standard = cfg80211_handlers, + .get_wireless_stats = cfg80211_wireless_stats, +}; diff --git a/net/wireless/wext-compat.h b/net/wireless/wext-compat.h new file mode 100644 index 000000000..8d3cc1552 --- /dev/null +++ b/net/wireless/wext-compat.h @@ -0,0 +1,64 @@ +/* SPDX-License-Identifier: GPL-2.0-only */ +#ifndef __WEXT_COMPAT +#define __WEXT_COMPAT + +#include +#include + +#ifdef CONFIG_CFG80211_WEXT_EXPORT +#define EXPORT_WEXT_HANDLER(h) EXPORT_SYMBOL_GPL(h) +#else +#define EXPORT_WEXT_HANDLER(h) +#endif /* CONFIG_CFG80211_WEXT_EXPORT */ + +int cfg80211_ibss_wext_siwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra); +int cfg80211_ibss_wext_giwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra); +int cfg80211_ibss_wext_siwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra); +int cfg80211_ibss_wext_giwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra); +int cfg80211_ibss_wext_siwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid); +int cfg80211_ibss_wext_giwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid); + +int cfg80211_mgd_wext_siwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra); +int cfg80211_mgd_wext_giwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra); +int cfg80211_mgd_wext_siwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra); +int cfg80211_mgd_wext_giwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra); +int cfg80211_mgd_wext_siwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid); +int cfg80211_mgd_wext_giwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid); + +int cfg80211_wext_siwmlme(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra); +int cfg80211_wext_siwgenie(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra); + + +int cfg80211_wext_freq(struct iw_freq *freq); + + +extern const struct iw_handler_def cfg80211_wext_handler; +#endif /* __WEXT_COMPAT */ diff --git 
a/net/wireless/wext-core.c b/net/wireless/wext-core.c new file mode 100644 index 000000000..8a4b85f96 --- /dev/null +++ b/net/wireless/wext-core.c @@ -0,0 +1,1198 @@ +/* + * This file implement the Wireless Extensions core API. + * + * Authors : Jean Tourrilhes - HPL - + * Copyright (c) 1997-2007 Jean Tourrilhes, All Rights Reserved. + * Copyright 2009 Johannes Berg + * + * (As all part of the Linux kernel, this file is GPL) + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +typedef int (*wext_ioctl_func)(struct net_device *, struct iwreq *, + unsigned int, struct iw_request_info *, + iw_handler); + + +/* + * Meta-data about all the standard Wireless Extension request we + * know about. + */ +static const struct iw_ioctl_description standard_ioctl[] = { + [IW_IOCTL_IDX(SIOCSIWCOMMIT)] = { + .header_type = IW_HEADER_TYPE_NULL, + }, + [IW_IOCTL_IDX(SIOCGIWNAME)] = { + .header_type = IW_HEADER_TYPE_CHAR, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWNWID)] = { + .header_type = IW_HEADER_TYPE_PARAM, + .flags = IW_DESCR_FLAG_EVENT, + }, + [IW_IOCTL_IDX(SIOCGIWNWID)] = { + .header_type = IW_HEADER_TYPE_PARAM, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWFREQ)] = { + .header_type = IW_HEADER_TYPE_FREQ, + .flags = IW_DESCR_FLAG_EVENT, + }, + [IW_IOCTL_IDX(SIOCGIWFREQ)] = { + .header_type = IW_HEADER_TYPE_FREQ, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWMODE)] = { + .header_type = IW_HEADER_TYPE_UINT, + .flags = IW_DESCR_FLAG_EVENT, + }, + [IW_IOCTL_IDX(SIOCGIWMODE)] = { + .header_type = IW_HEADER_TYPE_UINT, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWSENS)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWSENS)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWRANGE)] = { + .header_type = IW_HEADER_TYPE_NULL, + }, + [IW_IOCTL_IDX(SIOCGIWRANGE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = sizeof(struct iw_range), + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWPRIV)] = { + .header_type = IW_HEADER_TYPE_NULL, + }, + [IW_IOCTL_IDX(SIOCGIWPRIV)] = { /* (handled directly by us) */ + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct iw_priv_args), + .max_tokens = 16, + .flags = IW_DESCR_FLAG_NOMAX, + }, + [IW_IOCTL_IDX(SIOCSIWSTATS)] = { + .header_type = IW_HEADER_TYPE_NULL, + }, + [IW_IOCTL_IDX(SIOCGIWSTATS)] = { /* (handled directly by us) */ + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = sizeof(struct iw_statistics), + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWSPY)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct sockaddr), + .max_tokens = IW_MAX_SPY, + }, + [IW_IOCTL_IDX(SIOCGIWSPY)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct sockaddr) + + sizeof(struct iw_quality), + .max_tokens = IW_MAX_SPY, + }, + [IW_IOCTL_IDX(SIOCSIWTHRSPY)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct iw_thrspy), + .min_tokens = 1, + .max_tokens = 1, + }, + [IW_IOCTL_IDX(SIOCGIWTHRSPY)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct iw_thrspy), + .min_tokens = 1, + .max_tokens = 1, + }, + [IW_IOCTL_IDX(SIOCSIWAP)] = { + .header_type = IW_HEADER_TYPE_ADDR, + }, + [IW_IOCTL_IDX(SIOCGIWAP)] = { + .header_type = IW_HEADER_TYPE_ADDR, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWMLME)] = { + .header_type = IW_HEADER_TYPE_POINT, + 
.token_size = 1, + .min_tokens = sizeof(struct iw_mlme), + .max_tokens = sizeof(struct iw_mlme), + }, + [IW_IOCTL_IDX(SIOCGIWAPLIST)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = sizeof(struct sockaddr) + + sizeof(struct iw_quality), + .max_tokens = IW_MAX_AP, + .flags = IW_DESCR_FLAG_NOMAX, + }, + [IW_IOCTL_IDX(SIOCSIWSCAN)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .min_tokens = 0, + .max_tokens = sizeof(struct iw_scan_req), + }, + [IW_IOCTL_IDX(SIOCGIWSCAN)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_SCAN_MAX_DATA, + .flags = IW_DESCR_FLAG_NOMAX, + }, + [IW_IOCTL_IDX(SIOCSIWESSID)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ESSID_MAX_SIZE, + .flags = IW_DESCR_FLAG_EVENT, + }, + [IW_IOCTL_IDX(SIOCGIWESSID)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ESSID_MAX_SIZE, + .flags = IW_DESCR_FLAG_DUMP, + }, + [IW_IOCTL_IDX(SIOCSIWNICKN)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ESSID_MAX_SIZE, + }, + [IW_IOCTL_IDX(SIOCGIWNICKN)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ESSID_MAX_SIZE, + }, + [IW_IOCTL_IDX(SIOCSIWRATE)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWRATE)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWRTS)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWRTS)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWFRAG)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWFRAG)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWTXPOW)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWTXPOW)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWRETRY)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWRETRY)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWENCODE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ENCODING_TOKEN_MAX, + .flags = IW_DESCR_FLAG_EVENT | IW_DESCR_FLAG_RESTRICT, + }, + [IW_IOCTL_IDX(SIOCGIWENCODE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_ENCODING_TOKEN_MAX, + .flags = IW_DESCR_FLAG_DUMP | IW_DESCR_FLAG_RESTRICT, + }, + [IW_IOCTL_IDX(SIOCSIWPOWER)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWPOWER)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWGENIE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_GENERIC_IE_MAX, + }, + [IW_IOCTL_IDX(SIOCGIWGENIE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_GENERIC_IE_MAX, + }, + [IW_IOCTL_IDX(SIOCSIWAUTH)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCGIWAUTH)] = { + .header_type = IW_HEADER_TYPE_PARAM, + }, + [IW_IOCTL_IDX(SIOCSIWENCODEEXT)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .min_tokens = sizeof(struct iw_encode_ext), + .max_tokens = sizeof(struct iw_encode_ext) + + IW_ENCODING_TOKEN_MAX, + }, + [IW_IOCTL_IDX(SIOCGIWENCODEEXT)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .min_tokens = sizeof(struct iw_encode_ext), + .max_tokens = sizeof(struct iw_encode_ext) + + IW_ENCODING_TOKEN_MAX, + }, + [IW_IOCTL_IDX(SIOCSIWPMKSA)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .min_tokens = sizeof(struct iw_pmksa), 
+ .max_tokens = sizeof(struct iw_pmksa), + }, +}; +static const unsigned int standard_ioctl_num = ARRAY_SIZE(standard_ioctl); + +/* + * Meta-data about all the additional standard Wireless Extension events + * we know about. + */ +static const struct iw_ioctl_description standard_event[] = { + [IW_EVENT_IDX(IWEVTXDROP)] = { + .header_type = IW_HEADER_TYPE_ADDR, + }, + [IW_EVENT_IDX(IWEVQUAL)] = { + .header_type = IW_HEADER_TYPE_QUAL, + }, + [IW_EVENT_IDX(IWEVCUSTOM)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_CUSTOM_MAX, + }, + [IW_EVENT_IDX(IWEVREGISTERED)] = { + .header_type = IW_HEADER_TYPE_ADDR, + }, + [IW_EVENT_IDX(IWEVEXPIRED)] = { + .header_type = IW_HEADER_TYPE_ADDR, + }, + [IW_EVENT_IDX(IWEVGENIE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_GENERIC_IE_MAX, + }, + [IW_EVENT_IDX(IWEVMICHAELMICFAILURE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = sizeof(struct iw_michaelmicfailure), + }, + [IW_EVENT_IDX(IWEVASSOCREQIE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_GENERIC_IE_MAX, + }, + [IW_EVENT_IDX(IWEVASSOCRESPIE)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = IW_GENERIC_IE_MAX, + }, + [IW_EVENT_IDX(IWEVPMKIDCAND)] = { + .header_type = IW_HEADER_TYPE_POINT, + .token_size = 1, + .max_tokens = sizeof(struct iw_pmkid_cand), + }, +}; +static const unsigned int standard_event_num = ARRAY_SIZE(standard_event); + +/* Size (in bytes) of various events */ +static const int event_type_size[] = { + IW_EV_LCP_LEN, /* IW_HEADER_TYPE_NULL */ + 0, + IW_EV_CHAR_LEN, /* IW_HEADER_TYPE_CHAR */ + 0, + IW_EV_UINT_LEN, /* IW_HEADER_TYPE_UINT */ + IW_EV_FREQ_LEN, /* IW_HEADER_TYPE_FREQ */ + IW_EV_ADDR_LEN, /* IW_HEADER_TYPE_ADDR */ + 0, + IW_EV_POINT_LEN, /* Without variable payload */ + IW_EV_PARAM_LEN, /* IW_HEADER_TYPE_PARAM */ + IW_EV_QUAL_LEN, /* IW_HEADER_TYPE_QUAL */ +}; + +#ifdef CONFIG_COMPAT +static const int compat_event_type_size[] = { + IW_EV_COMPAT_LCP_LEN, /* IW_HEADER_TYPE_NULL */ + 0, + IW_EV_COMPAT_CHAR_LEN, /* IW_HEADER_TYPE_CHAR */ + 0, + IW_EV_COMPAT_UINT_LEN, /* IW_HEADER_TYPE_UINT */ + IW_EV_COMPAT_FREQ_LEN, /* IW_HEADER_TYPE_FREQ */ + IW_EV_COMPAT_ADDR_LEN, /* IW_HEADER_TYPE_ADDR */ + 0, + IW_EV_COMPAT_POINT_LEN, /* Without variable payload */ + IW_EV_COMPAT_PARAM_LEN, /* IW_HEADER_TYPE_PARAM */ + IW_EV_COMPAT_QUAL_LEN, /* IW_HEADER_TYPE_QUAL */ +}; +#endif + + +/* IW event code */ + +void wireless_nlevent_flush(void) +{ + struct sk_buff *skb; + struct net *net; + + down_read(&net_rwsem); + for_each_net(net) { + while ((skb = skb_dequeue(&net->wext_nlevents))) + rtnl_notify(skb, net, 0, RTNLGRP_LINK, NULL, + GFP_KERNEL); + } + up_read(&net_rwsem); +} +EXPORT_SYMBOL_GPL(wireless_nlevent_flush); + +static int wext_netdev_notifier_call(struct notifier_block *nb, + unsigned long state, void *ptr) +{ + /* + * When a netdev changes state in any way, flush all pending messages + * to avoid them going out in a strange order, e.g. RTM_NEWLINK after + * RTM_DELLINK, or with IFF_UP after without IFF_UP during dev_close() + * or similar - all of which could otherwise happen due to delays from + * schedule_work(). 
+ */ + wireless_nlevent_flush(); + + return NOTIFY_OK; +} + +static struct notifier_block wext_netdev_notifier = { + .notifier_call = wext_netdev_notifier_call, +}; + +static int __net_init wext_pernet_init(struct net *net) +{ + skb_queue_head_init(&net->wext_nlevents); + return 0; +} + +static void __net_exit wext_pernet_exit(struct net *net) +{ + skb_queue_purge(&net->wext_nlevents); +} + +static struct pernet_operations wext_pernet_ops = { + .init = wext_pernet_init, + .exit = wext_pernet_exit, +}; + +static int __init wireless_nlevent_init(void) +{ + int err = register_pernet_subsys(&wext_pernet_ops); + + if (err) + return err; + + err = register_netdevice_notifier(&wext_netdev_notifier); + if (err) + unregister_pernet_subsys(&wext_pernet_ops); + return err; +} + +subsys_initcall(wireless_nlevent_init); + +/* Process events generated by the wireless layer or the driver. */ +static void wireless_nlevent_process(struct work_struct *work) +{ + wireless_nlevent_flush(); +} + +static DECLARE_WORK(wireless_nlevent_work, wireless_nlevent_process); + +static struct nlmsghdr *rtnetlink_ifinfo_prep(struct net_device *dev, + struct sk_buff *skb) +{ + struct ifinfomsg *r; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, 0, 0, RTM_NEWLINK, sizeof(*r), 0); + if (!nlh) + return NULL; + + r = nlmsg_data(nlh); + r->ifi_family = AF_UNSPEC; + r->__ifi_pad = 0; + r->ifi_type = dev->type; + r->ifi_index = dev->ifindex; + r->ifi_flags = dev_get_flags(dev); + r->ifi_change = 0; /* Wireless changes don't affect those flags */ + + if (nla_put_string(skb, IFLA_IFNAME, dev->name)) + goto nla_put_failure; + + return nlh; + nla_put_failure: + nlmsg_cancel(skb, nlh); + return NULL; +} + + +/* + * Main event dispatcher. Called from other parts and drivers. + * Send the event on the appropriate channels. + * May be called from interrupt context. + */ +void wireless_send_event(struct net_device * dev, + unsigned int cmd, + union iwreq_data * wrqu, + const char * extra) +{ + const struct iw_ioctl_description * descr = NULL; + int extra_len = 0; + struct iw_event *event; /* Mallocated whole event */ + int event_len; /* Its size */ + int hdr_len; /* Size of the event header */ + int wrqu_off = 0; /* Offset in wrqu */ + /* Don't "optimise" the following variable, it will crash */ + unsigned int cmd_index; /* *MUST* be unsigned */ + struct sk_buff *skb; + struct nlmsghdr *nlh; + struct nlattr *nla; +#ifdef CONFIG_COMPAT + struct __compat_iw_event *compat_event; + struct compat_iw_point compat_wrqu; + struct sk_buff *compskb; + int ptr_len; +#endif + + /* + * Nothing in the kernel sends scan events with data, be safe. + * This is necessary because we cannot fix up scan event data + * for compat, due to being contained in 'extra', but normally + * applications are required to retrieve the scan data anyway + * and no data is included in the event, this codifies that + * practice. + */ + if (WARN_ON(cmd == SIOCGIWSCAN && extra)) + extra = NULL; + + /* Get the description of the Event */ + if (cmd <= SIOCIWLAST) { + cmd_index = IW_IOCTL_IDX(cmd); + if (cmd_index < standard_ioctl_num) + descr = &(standard_ioctl[cmd_index]); + } else { + cmd_index = IW_EVENT_IDX(cmd); + if (cmd_index < standard_event_num) + descr = &(standard_event[cmd_index]); + } + /* Don't accept unknown events */ + if (descr == NULL) { + /* Note : we don't return an error to the driver, because + * the driver would not know what to do about it. It can't + * return an error to the user, because the event is not + * initiated by a user request. 
+ * The best the driver could do is to log an error message. + * We will do it ourselves instead... + */ + netdev_err(dev, "(WE) : Invalid/Unknown Wireless Event (0x%04X)\n", + cmd); + return; + } + + /* Check extra parameters and set extra_len */ + if (descr->header_type == IW_HEADER_TYPE_POINT) { + /* Check if number of token fits within bounds */ + if (wrqu->data.length > descr->max_tokens) { + netdev_err(dev, "(WE) : Wireless Event (cmd=0x%04X) too big (%d)\n", + cmd, wrqu->data.length); + return; + } + if (wrqu->data.length < descr->min_tokens) { + netdev_err(dev, "(WE) : Wireless Event (cmd=0x%04X) too small (%d)\n", + cmd, wrqu->data.length); + return; + } + /* Calculate extra_len - extra is NULL for restricted events */ + if (extra != NULL) + extra_len = wrqu->data.length * descr->token_size; + /* Always at an offset in wrqu */ + wrqu_off = IW_EV_POINT_OFF; + } + + /* Total length of the event */ + hdr_len = event_type_size[descr->header_type]; + event_len = hdr_len + extra_len; + + /* + * The problem for 64/32 bit. + * + * On 64-bit, a regular event is laid out as follows: + * | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * | event.len | event.cmd | p a d d i n g | + * | wrqu data ... (with the correct size) | + * + * This padding exists because we manipulate event->u, + * and 'event' is not packed. + * + * An iw_point event is laid out like this instead: + * | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | + * | event.len | event.cmd | p a d d i n g | + * | iwpnt.len | iwpnt.flg | p a d d i n g | + * | extra data ... + * + * The second padding exists because struct iw_point is extended, + * but this depends on the platform... + * + * On 32-bit, all the padding shouldn't be there. + */ + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return; + + /* Send via the RtNetlink event channel */ + nlh = rtnetlink_ifinfo_prep(dev, skb); + if (WARN_ON(!nlh)) { + kfree_skb(skb); + return; + } + + /* Add the wireless events in the netlink packet */ + nla = nla_reserve(skb, IFLA_WIRELESS, event_len); + if (!nla) { + kfree_skb(skb); + return; + } + event = nla_data(nla); + + /* Fill event - first clear to avoid data leaking */ + memset(event, 0, hdr_len); + event->len = event_len; + event->cmd = cmd; + memcpy(&event->u, ((char *) wrqu) + wrqu_off, hdr_len - IW_EV_LCP_LEN); + if (extra_len) + memcpy(((char *) event) + hdr_len, extra, extra_len); + + nlmsg_end(skb, nlh); +#ifdef CONFIG_COMPAT + hdr_len = compat_event_type_size[descr->header_type]; + + /* ptr_len is remaining size in event header apart from LCP */ + ptr_len = hdr_len - IW_EV_COMPAT_LCP_LEN; + event_len = hdr_len + extra_len; + + compskb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!compskb) { + kfree_skb(skb); + return; + } + + /* Send via the RtNetlink event channel */ + nlh = rtnetlink_ifinfo_prep(dev, compskb); + if (WARN_ON(!nlh)) { + kfree_skb(skb); + kfree_skb(compskb); + return; + } + + /* Add the wireless events in the netlink packet */ + nla = nla_reserve(compskb, IFLA_WIRELESS, event_len); + if (!nla) { + kfree_skb(skb); + kfree_skb(compskb); + return; + } + compat_event = nla_data(nla); + + compat_event->len = event_len; + compat_event->cmd = cmd; + if (descr->header_type == IW_HEADER_TYPE_POINT) { + compat_wrqu.length = wrqu->data.length; + compat_wrqu.flags = wrqu->data.flags; + memcpy(compat_event->ptr_bytes, + ((char *)&compat_wrqu) + IW_EV_COMPAT_POINT_OFF, + ptr_len); + if (extra_len) + memcpy(&compat_event->ptr_bytes[ptr_len], + extra, extra_len); + } else { + /* extra_len must be zero, so no if (extra) needed */ 
+ memcpy(compat_event->ptr_bytes, wrqu, ptr_len); + } + + nlmsg_end(compskb, nlh); + + skb_shinfo(skb)->frag_list = compskb; +#endif + skb_queue_tail(&dev_net(dev)->wext_nlevents, skb); + schedule_work(&wireless_nlevent_work); +} +EXPORT_SYMBOL(wireless_send_event); + + + +/* IW handlers */ + +struct iw_statistics *get_wireless_stats(struct net_device *dev) +{ +#ifdef CONFIG_WIRELESS_EXT + if ((dev->wireless_handlers != NULL) && + (dev->wireless_handlers->get_wireless_stats != NULL)) + return dev->wireless_handlers->get_wireless_stats(dev); +#endif + +#ifdef CONFIG_CFG80211_WEXT + if (dev->ieee80211_ptr && + dev->ieee80211_ptr->wiphy && + dev->ieee80211_ptr->wiphy->wext && + dev->ieee80211_ptr->wiphy->wext->get_wireless_stats) + return dev->ieee80211_ptr->wiphy->wext->get_wireless_stats(dev); +#endif + + /* not found */ + return NULL; +} + +/* noinline to avoid a bogus warning with -O3 */ +static noinline int iw_handler_get_iwstats(struct net_device * dev, + struct iw_request_info * info, + union iwreq_data * wrqu, + char * extra) +{ + /* Get stats from the driver */ + struct iw_statistics *stats; + + stats = get_wireless_stats(dev); + if (stats) { + /* Copy statistics to extra */ + memcpy(extra, stats, sizeof(struct iw_statistics)); + wrqu->data.length = sizeof(struct iw_statistics); + + /* Check if we need to clear the updated flag */ + if (wrqu->data.flags != 0) + stats->qual.updated &= ~IW_QUAL_ALL_UPDATED; + return 0; + } else + return -EOPNOTSUPP; +} + +static iw_handler get_handler(struct net_device *dev, unsigned int cmd) +{ + /* Don't "optimise" the following variable, it will crash */ + unsigned int index; /* *MUST* be unsigned */ + const struct iw_handler_def *handlers = NULL; + +#ifdef CONFIG_CFG80211_WEXT + if (dev->ieee80211_ptr && dev->ieee80211_ptr->wiphy) + handlers = dev->ieee80211_ptr->wiphy->wext; +#endif +#ifdef CONFIG_WIRELESS_EXT + if (dev->wireless_handlers) + handlers = dev->wireless_handlers; +#endif + + if (!handlers) + return NULL; + + /* Try as a standard command */ + index = IW_IOCTL_IDX(cmd); + if (index < handlers->num_standard) + return handlers->standard[index]; + +#ifdef CONFIG_WEXT_PRIV + /* Try as a private command */ + index = cmd - SIOCIWFIRSTPRIV; + if (index < handlers->num_private) + return handlers->private[index]; +#endif + + /* Not found */ + return NULL; +} + +static int ioctl_standard_iw_point(struct iw_point *iwp, unsigned int cmd, + const struct iw_ioctl_description *descr, + iw_handler handler, struct net_device *dev, + struct iw_request_info *info) +{ + int err, extra_size, user_length = 0, essid_compat = 0; + char *extra; + + /* Calculate space needed by arguments. Always allocate + * for max space. 
+ */ + extra_size = descr->max_tokens * descr->token_size; + + /* Check need for ESSID compatibility for WE < 21 */ + switch (cmd) { + case SIOCSIWESSID: + case SIOCGIWESSID: + case SIOCSIWNICKN: + case SIOCGIWNICKN: + if (iwp->length == descr->max_tokens + 1) + essid_compat = 1; + else if (IW_IS_SET(cmd) && (iwp->length != 0)) { + char essid[IW_ESSID_MAX_SIZE + 1]; + unsigned int len; + len = iwp->length * descr->token_size; + + if (len > IW_ESSID_MAX_SIZE) + return -EFAULT; + + err = copy_from_user(essid, iwp->pointer, len); + if (err) + return -EFAULT; + + if (essid[iwp->length - 1] == '\0') + essid_compat = 1; + } + break; + default: + break; + } + + iwp->length -= essid_compat; + + /* Check what user space is giving us */ + if (IW_IS_SET(cmd)) { + /* Check NULL pointer */ + if (!iwp->pointer && iwp->length != 0) + return -EFAULT; + /* Check if number of token fits within bounds */ + if (iwp->length > descr->max_tokens) + return -E2BIG; + if (iwp->length < descr->min_tokens) + return -EINVAL; + } else { + /* Check NULL pointer */ + if (!iwp->pointer) + return -EFAULT; + /* Save user space buffer size for checking */ + user_length = iwp->length; + + /* Don't check if user_length > max to allow forward + * compatibility. The test user_length < min is + * implied by the test at the end. + */ + + /* Support for very large requests */ + if ((descr->flags & IW_DESCR_FLAG_NOMAX) && + (user_length > descr->max_tokens)) { + /* Allow userspace to GET more than max so + * we can support any size GET requests. + * There is still a limit : -ENOMEM. + */ + extra_size = user_length * descr->token_size; + + /* Note : user_length is originally a __u16, + * and token_size is controlled by us, + * so extra_size won't get negative and + * won't overflow... + */ + } + } + + /* Sanity-check to ensure we never end up _allocating_ zero + * bytes of data for extra. + */ + if (extra_size <= 0) + return -EFAULT; + + /* kzalloc() ensures NULL-termination for essid_compat. */ + extra = kzalloc(extra_size, GFP_KERNEL); + if (!extra) + return -ENOMEM; + + /* If it is a SET, get all the extra data in here */ + if (IW_IS_SET(cmd) && (iwp->length != 0)) { + if (copy_from_user(extra, iwp->pointer, + iwp->length * + descr->token_size)) { + err = -EFAULT; + goto out; + } + + if (cmd == SIOCSIWENCODEEXT) { + struct iw_encode_ext *ee = (void *) extra; + + if (iwp->length < sizeof(*ee) + ee->key_len) { + err = -EFAULT; + goto out; + } + } + } + + if (IW_IS_GET(cmd) && !(descr->flags & IW_DESCR_FLAG_NOMAX)) { + /* + * If this is a GET, but not NOMAX, it means that the extra + * data is not bounded by userspace, but by max_tokens. Thus + * set the length to max_tokens. This matches the extra data + * allocation. + * The driver should fill it with the number of tokens it + * provided, and it may check iwp->length rather than having + * knowledge of max_tokens. If the driver doesn't change the + * iwp->length, this ioctl just copies back max_token tokens + * filled with zeroes. Hopefully the driver isn't claiming + * them to be valid data. 
+ */ + iwp->length = descr->max_tokens; + } + + err = handler(dev, info, (union iwreq_data *) iwp, extra); + + iwp->length += essid_compat; + + /* If we have something to return to the user */ + if (!err && IW_IS_GET(cmd)) { + /* Check if there is enough buffer up there */ + if (user_length < iwp->length) { + err = -E2BIG; + goto out; + } + + if (copy_to_user(iwp->pointer, extra, + iwp->length * + descr->token_size)) { + err = -EFAULT; + goto out; + } + } + + /* Generate an event to notify listeners of the change */ + if ((descr->flags & IW_DESCR_FLAG_EVENT) && + ((err == 0) || (err == -EIWCOMMIT))) { + union iwreq_data *data = (union iwreq_data *) iwp; + + if (descr->flags & IW_DESCR_FLAG_RESTRICT) + /* If the event is restricted, don't + * export the payload. + */ + wireless_send_event(dev, cmd, data, NULL); + else + wireless_send_event(dev, cmd, data, extra); + } + +out: + kfree(extra); + return err; +} + +/* + * Call the commit handler in the driver + * (if exist and if conditions are right) + * + * Note : our current commit strategy is currently pretty dumb, + * but we will be able to improve on that... + * The goal is to try to agreagate as many changes as possible + * before doing the commit. Drivers that will define a commit handler + * are usually those that need a reset after changing parameters, so + * we want to minimise the number of reset. + * A cool idea is to use a timer : at each "set" command, we re-set the + * timer, when the timer eventually fires, we call the driver. + * Hopefully, more on that later. + * + * Also, I'm waiting to see how many people will complain about the + * netif_running(dev) test. I'm open on that one... + * Hopefully, the driver will remember to do a commit in "open()" ;-) + */ +int call_commit_handler(struct net_device *dev) +{ +#ifdef CONFIG_WIRELESS_EXT + if (netif_running(dev) && + dev->wireless_handlers && + dev->wireless_handlers->standard[0]) + /* Call the commit handler on the driver */ + return dev->wireless_handlers->standard[0](dev, NULL, + NULL, NULL); + else + return 0; /* Command completed successfully */ +#else + /* cfg80211 has no commit */ + return 0; +#endif +} + +/* + * Main IOCTl dispatcher. + * Check the type of IOCTL and call the appropriate wrapper... + */ +static int wireless_process_ioctl(struct net *net, struct iwreq *iwr, + unsigned int cmd, + struct iw_request_info *info, + wext_ioctl_func standard, + wext_ioctl_func private) +{ + struct net_device *dev; + iw_handler handler; + + /* Permissions are already checked in dev_ioctl() before calling us. + * The copy_to/from_user() of ifr is also dealt with in there */ + + /* Make sure the device exist */ + if ((dev = __dev_get_by_name(net, iwr->ifr_name)) == NULL) + return -ENODEV; + + /* A bunch of special cases, then the generic case... 
+ * Note that 'cmd' is already filtered in dev_ioctl() with + * (cmd >= SIOCIWFIRST && cmd <= SIOCIWLAST) */ + if (cmd == SIOCGIWSTATS) + return standard(dev, iwr, cmd, info, + &iw_handler_get_iwstats); + +#ifdef CONFIG_WEXT_PRIV + if (cmd == SIOCGIWPRIV && dev->wireless_handlers) + return standard(dev, iwr, cmd, info, + iw_handler_get_private); +#endif + + /* Basic check */ + if (!netif_device_present(dev)) + return -ENODEV; + + /* New driver API : try to find the handler */ + handler = get_handler(dev, cmd); + if (handler) { + /* Standard and private are not the same */ + if (cmd < SIOCIWFIRSTPRIV) + return standard(dev, iwr, cmd, info, handler); + else if (private) + return private(dev, iwr, cmd, info, handler); + } + return -EOPNOTSUPP; +} + +/* If command is `set a parameter', or `get the encoding parameters', + * check if the user has the right to do it. + */ +static int wext_permission_check(unsigned int cmd) +{ + if ((IW_IS_SET(cmd) || cmd == SIOCGIWENCODE || + cmd == SIOCGIWENCODEEXT) && + !capable(CAP_NET_ADMIN)) + return -EPERM; + + return 0; +} + +/* entry point from dev ioctl */ +static int wext_ioctl_dispatch(struct net *net, struct iwreq *iwr, + unsigned int cmd, struct iw_request_info *info, + wext_ioctl_func standard, + wext_ioctl_func private) +{ + int ret = wext_permission_check(cmd); + + if (ret) + return ret; + + dev_load(net, iwr->ifr_name); + rtnl_lock(); + ret = wireless_process_ioctl(net, iwr, cmd, info, standard, private); + rtnl_unlock(); + + return ret; +} + +/* + * Wrapper to call a standard Wireless Extension handler. + * We do various checks and also take care of moving data between + * user space and kernel space. + */ +static int ioctl_standard_call(struct net_device * dev, + struct iwreq *iwr, + unsigned int cmd, + struct iw_request_info *info, + iw_handler handler) +{ + const struct iw_ioctl_description * descr; + int ret = -EINVAL; + + /* Get the description of the IOCTL */ + if (IW_IOCTL_IDX(cmd) >= standard_ioctl_num) + return -EOPNOTSUPP; + descr = &(standard_ioctl[IW_IOCTL_IDX(cmd)]); + + /* Check if we have a pointer to user space data or not */ + if (descr->header_type != IW_HEADER_TYPE_POINT) { + + /* No extra arguments. 
Trivial to handle */ + ret = handler(dev, info, &(iwr->u), NULL); + + /* Generate an event to notify listeners of the change */ + if ((descr->flags & IW_DESCR_FLAG_EVENT) && + ((ret == 0) || (ret == -EIWCOMMIT))) + wireless_send_event(dev, cmd, &(iwr->u), NULL); + } else { + ret = ioctl_standard_iw_point(&iwr->u.data, cmd, descr, + handler, dev, info); + } + + /* Call commit handler if needed and defined */ + if (ret == -EIWCOMMIT) + ret = call_commit_handler(dev); + + /* Here, we will generate the appropriate event if needed */ + + return ret; +} + + +int wext_handle_ioctl(struct net *net, unsigned int cmd, void __user *arg) +{ + struct iw_request_info info = { .cmd = cmd, .flags = 0 }; + struct iwreq iwr; + int ret; + + if (copy_from_user(&iwr, arg, sizeof(iwr))) + return -EFAULT; + + iwr.ifr_name[sizeof(iwr.ifr_name) - 1] = 0; + + ret = wext_ioctl_dispatch(net, &iwr, cmd, &info, + ioctl_standard_call, + ioctl_private_call); + if (ret >= 0 && + IW_IS_GET(cmd) && + copy_to_user(arg, &iwr, sizeof(struct iwreq))) + return -EFAULT; + + return ret; +} + +#ifdef CONFIG_COMPAT +static int compat_standard_call(struct net_device *dev, + struct iwreq *iwr, + unsigned int cmd, + struct iw_request_info *info, + iw_handler handler) +{ + const struct iw_ioctl_description *descr; + struct compat_iw_point *iwp_compat; + struct iw_point iwp; + int err; + + descr = standard_ioctl + IW_IOCTL_IDX(cmd); + + if (descr->header_type != IW_HEADER_TYPE_POINT) + return ioctl_standard_call(dev, iwr, cmd, info, handler); + + iwp_compat = (struct compat_iw_point *) &iwr->u.data; + iwp.pointer = compat_ptr(iwp_compat->pointer); + iwp.length = iwp_compat->length; + iwp.flags = iwp_compat->flags; + + err = ioctl_standard_iw_point(&iwp, cmd, descr, handler, dev, info); + + iwp_compat->pointer = ptr_to_compat(iwp.pointer); + iwp_compat->length = iwp.length; + iwp_compat->flags = iwp.flags; + + return err; +} + +int compat_wext_handle_ioctl(struct net *net, unsigned int cmd, + unsigned long arg) +{ + void __user *argp = (void __user *)arg; + struct iw_request_info info; + struct iwreq iwr; + char *colon; + int ret; + + if (copy_from_user(&iwr, argp, sizeof(struct iwreq))) + return -EFAULT; + + iwr.ifr_name[IFNAMSIZ-1] = 0; + colon = strchr(iwr.ifr_name, ':'); + if (colon) + *colon = 0; + + info.cmd = cmd; + info.flags = IW_REQUEST_FLAG_COMPAT; + + ret = wext_ioctl_dispatch(net, &iwr, cmd, &info, + compat_standard_call, + compat_private_call); + + if (ret >= 0 && + IW_IS_GET(cmd) && + copy_to_user(argp, &iwr, sizeof(struct iwreq))) + return -EFAULT; + + return ret; +} +#endif + +char *iwe_stream_add_event(struct iw_request_info *info, char *stream, + char *ends, struct iw_event *iwe, int event_len) +{ + int lcp_len = iwe_stream_lcp_len(info); + + event_len = iwe_stream_event_len_adjust(info, event_len); + + /* Check if it's possible */ + if (likely((stream + event_len) < ends)) { + iwe->len = event_len; + /* Beware of alignement issues on 64 bits */ + memcpy(stream, (char *) iwe, IW_EV_LCP_PK_LEN); + memcpy(stream + lcp_len, &iwe->u, + event_len - lcp_len); + stream += event_len; + } + + return stream; +} +EXPORT_SYMBOL(iwe_stream_add_event); + +char *iwe_stream_add_point(struct iw_request_info *info, char *stream, + char *ends, struct iw_event *iwe, char *extra) +{ + int event_len = iwe_stream_point_len(info) + iwe->u.data.length; + int point_len = iwe_stream_point_len(info); + int lcp_len = iwe_stream_lcp_len(info); + + /* Check if it's possible */ + if (likely((stream + event_len) < ends)) { + iwe->len = event_len; + 
memcpy(stream, (char *) iwe, IW_EV_LCP_PK_LEN); + memcpy(stream + lcp_len, + ((char *) &iwe->u) + IW_EV_POINT_OFF, + IW_EV_POINT_PK_LEN - IW_EV_LCP_PK_LEN); + if (iwe->u.data.length && extra) + memcpy(stream + point_len, extra, iwe->u.data.length); + stream += event_len; + } + + return stream; +} +EXPORT_SYMBOL(iwe_stream_add_point); + +char *iwe_stream_add_value(struct iw_request_info *info, char *event, + char *value, char *ends, struct iw_event *iwe, + int event_len) +{ + int lcp_len = iwe_stream_lcp_len(info); + + /* Don't duplicate LCP */ + event_len -= IW_EV_LCP_LEN; + + /* Check if it's possible */ + if (likely((value + event_len) < ends)) { + /* Add new value */ + memcpy(value, &iwe->u, event_len); + value += event_len; + /* Patch LCP */ + iwe->len = value - event; + memcpy(event, (char *) iwe, lcp_len); + } + + return value; +} +EXPORT_SYMBOL(iwe_stream_add_value); diff --git a/net/wireless/wext-priv.c b/net/wireless/wext-priv.c new file mode 100644 index 000000000..674d426a9 --- /dev/null +++ b/net/wireless/wext-priv.c @@ -0,0 +1,249 @@ +/* + * This file implement the Wireless Extensions priv API. + * + * Authors : Jean Tourrilhes - HPL - + * Copyright (c) 1997-2007 Jean Tourrilhes, All Rights Reserved. + * Copyright 2009 Johannes Berg + * + * (As all part of the Linux kernel, this file is GPL) + */ +#include +#include +#include +#include +#include + +int iw_handler_get_private(struct net_device * dev, + struct iw_request_info * info, + union iwreq_data * wrqu, + char * extra) +{ + /* Check if the driver has something to export */ + if ((dev->wireless_handlers->num_private_args == 0) || + (dev->wireless_handlers->private_args == NULL)) + return -EOPNOTSUPP; + + /* Check if there is enough buffer up there */ + if (wrqu->data.length < dev->wireless_handlers->num_private_args) { + /* User space can't know in advance how large the buffer + * needs to be. Give it a hint, so that we can support + * any size buffer we want somewhat efficiently... */ + wrqu->data.length = dev->wireless_handlers->num_private_args; + return -E2BIG; + } + + /* Set the number of available ioctls. */ + wrqu->data.length = dev->wireless_handlers->num_private_args; + + /* Copy structure to the user buffer. */ + memcpy(extra, dev->wireless_handlers->private_args, + sizeof(struct iw_priv_args) * wrqu->data.length); + + return 0; +} + +/* Size (in bytes) of the various private data types */ +static const char iw_priv_type_size[] = { + 0, /* IW_PRIV_TYPE_NONE */ + 1, /* IW_PRIV_TYPE_BYTE */ + 1, /* IW_PRIV_TYPE_CHAR */ + 0, /* Not defined */ + sizeof(__u32), /* IW_PRIV_TYPE_INT */ + sizeof(struct iw_freq), /* IW_PRIV_TYPE_FLOAT */ + sizeof(struct sockaddr), /* IW_PRIV_TYPE_ADDR */ + 0, /* Not defined */ +}; + +static int get_priv_size(__u16 args) +{ + int num = args & IW_PRIV_SIZE_MASK; + int type = (args & IW_PRIV_TYPE_MASK) >> 12; + + return num * iw_priv_type_size[type]; +} + +static int adjust_priv_size(__u16 args, struct iw_point *iwp) +{ + int num = iwp->length; + int max = args & IW_PRIV_SIZE_MASK; + int type = (args & IW_PRIV_TYPE_MASK) >> 12; + + /* Make sure the driver doesn't goof up */ + if (max < num) + num = max; + + return num * iw_priv_type_size[type]; +} + +/* + * Wrapper to call a private Wireless Extension handler. + * We do various checks and also take care of moving data between + * user space and kernel space. + * It's not as nice and slimline as the standard wrapper. The cause + * is struct iw_priv_args, which was not really designed for the + * job we are going here. 
+ * + * IMPORTANT : This function prevent to set and get data on the same + * IOCTL and enforce the SET/GET convention. Not doing it would be + * far too hairy... + * If you need to set and get data at the same time, please don't use + * a iw_handler but process it in your ioctl handler (i.e. use the + * old driver API). + */ +static int get_priv_descr_and_size(struct net_device *dev, unsigned int cmd, + const struct iw_priv_args **descrp) +{ + const struct iw_priv_args *descr; + int i, extra_size; + + descr = NULL; + for (i = 0; i < dev->wireless_handlers->num_private_args; i++) { + if (cmd == dev->wireless_handlers->private_args[i].cmd) { + descr = &dev->wireless_handlers->private_args[i]; + break; + } + } + + extra_size = 0; + if (descr) { + if (IW_IS_SET(cmd)) { + int offset = 0; /* For sub-ioctls */ + /* Check for sub-ioctl handler */ + if (descr->name[0] == '\0') + /* Reserve one int for sub-ioctl index */ + offset = sizeof(__u32); + + /* Size of set arguments */ + extra_size = get_priv_size(descr->set_args); + + /* Does it fits in iwr ? */ + if ((descr->set_args & IW_PRIV_SIZE_FIXED) && + ((extra_size + offset) <= IFNAMSIZ)) + extra_size = 0; + } else { + /* Size of get arguments */ + extra_size = get_priv_size(descr->get_args); + + /* Does it fits in iwr ? */ + if ((descr->get_args & IW_PRIV_SIZE_FIXED) && + (extra_size <= IFNAMSIZ)) + extra_size = 0; + } + } + *descrp = descr; + return extra_size; +} + +static int ioctl_private_iw_point(struct iw_point *iwp, unsigned int cmd, + const struct iw_priv_args *descr, + iw_handler handler, struct net_device *dev, + struct iw_request_info *info, int extra_size) +{ + char *extra; + int err; + + /* Check what user space is giving us */ + if (IW_IS_SET(cmd)) { + if (!iwp->pointer && iwp->length != 0) + return -EFAULT; + + if (iwp->length > (descr->set_args & IW_PRIV_SIZE_MASK)) + return -E2BIG; + } else if (!iwp->pointer) + return -EFAULT; + + extra = kzalloc(extra_size, GFP_KERNEL); + if (!extra) + return -ENOMEM; + + /* If it is a SET, get all the extra data in here */ + if (IW_IS_SET(cmd) && (iwp->length != 0)) { + if (copy_from_user(extra, iwp->pointer, extra_size)) { + err = -EFAULT; + goto out; + } + } + + /* Call the handler */ + err = handler(dev, info, (union iwreq_data *) iwp, extra); + + /* If we have something to return to the user */ + if (!err && IW_IS_GET(cmd)) { + /* Adjust for the actual length if it's variable, + * avoid leaking kernel bits outside. + */ + if (!(descr->get_args & IW_PRIV_SIZE_FIXED)) + extra_size = adjust_priv_size(descr->get_args, iwp); + + if (copy_to_user(iwp->pointer, extra, extra_size)) + err = -EFAULT; + } + +out: + kfree(extra); + return err; +} + +int ioctl_private_call(struct net_device *dev, struct iwreq *iwr, + unsigned int cmd, struct iw_request_info *info, + iw_handler handler) +{ + int extra_size = 0, ret = -EINVAL; + const struct iw_priv_args *descr; + + extra_size = get_priv_descr_and_size(dev, cmd, &descr); + + /* Check if we have a pointer to user space data or not. */ + if (extra_size == 0) { + /* No extra arguments. 
Trivial to handle */ + ret = handler(dev, info, &(iwr->u), (char *) &(iwr->u)); + } else { + ret = ioctl_private_iw_point(&iwr->u.data, cmd, descr, + handler, dev, info, extra_size); + } + + /* Call commit handler if needed and defined */ + if (ret == -EIWCOMMIT) + ret = call_commit_handler(dev); + + return ret; +} + +#ifdef CONFIG_COMPAT +int compat_private_call(struct net_device *dev, struct iwreq *iwr, + unsigned int cmd, struct iw_request_info *info, + iw_handler handler) +{ + const struct iw_priv_args *descr; + int ret, extra_size; + + extra_size = get_priv_descr_and_size(dev, cmd, &descr); + + /* Check if we have a pointer to user space data or not. */ + if (extra_size == 0) { + /* No extra arguments. Trivial to handle */ + ret = handler(dev, info, &(iwr->u), (char *) &(iwr->u)); + } else { + struct compat_iw_point *iwp_compat; + struct iw_point iwp; + + iwp_compat = (struct compat_iw_point *) &iwr->u.data; + iwp.pointer = compat_ptr(iwp_compat->pointer); + iwp.length = iwp_compat->length; + iwp.flags = iwp_compat->flags; + + ret = ioctl_private_iw_point(&iwp, cmd, descr, + handler, dev, info, extra_size); + + iwp_compat->pointer = ptr_to_compat(iwp.pointer); + iwp_compat->length = iwp.length; + iwp_compat->flags = iwp.flags; + } + + /* Call commit handler if needed and defined */ + if (ret == -EIWCOMMIT) + ret = call_commit_handler(dev); + + return ret; +} +#endif diff --git a/net/wireless/wext-proc.c b/net/wireless/wext-proc.c new file mode 100644 index 000000000..cadcf8613 --- /dev/null +++ b/net/wireless/wext-proc.c @@ -0,0 +1,142 @@ +/* + * This file implement the Wireless Extensions proc API. + * + * Authors : Jean Tourrilhes - HPL - + * Copyright (c) 1997-2007 Jean Tourrilhes, All Rights Reserved. + * + * (As all part of the Linux kernel, this file is GPL) + */ + +/* + * The /proc/net/wireless file is a human readable user-space interface + * exporting various wireless specific statistics from the wireless devices. + * This is the most popular part of the Wireless Extensions ;-) + * + * This interface is a pure clone of /proc/net/dev (in net/core/dev.c). + * The content of the file is basically the content of "struct iw_statistics". + */ + +#include +#include +#include +#include +#include +#include +#include +#include + + +static void wireless_seq_printf_stats(struct seq_file *seq, + struct net_device *dev) +{ + /* Get stats from the driver */ + struct iw_statistics *stats = get_wireless_stats(dev); + static struct iw_statistics nullstats = {}; + + /* show device if it's wireless regardless of current stats */ + if (!stats) { +#ifdef CONFIG_WIRELESS_EXT + if (dev->wireless_handlers) + stats = &nullstats; +#endif +#ifdef CONFIG_CFG80211 + if (dev->ieee80211_ptr) + stats = &nullstats; +#endif + } + + if (stats) { + seq_printf(seq, "%6s: %04x %3d%c %3d%c %3d%c %6d %6d %6d " + "%6d %6d %6d\n", + dev->name, stats->status, stats->qual.qual, + stats->qual.updated & IW_QUAL_QUAL_UPDATED + ? '.' : ' ', + ((__s32) stats->qual.level) - + ((stats->qual.updated & IW_QUAL_DBM) ? 0x100 : 0), + stats->qual.updated & IW_QUAL_LEVEL_UPDATED + ? '.' : ' ', + ((__s32) stats->qual.noise) - + ((stats->qual.updated & IW_QUAL_DBM) ? 0x100 : 0), + stats->qual.updated & IW_QUAL_NOISE_UPDATED + ? '.' 
: ' ', + stats->discard.nwid, stats->discard.code, + stats->discard.fragment, stats->discard.retries, + stats->discard.misc, stats->miss.beacon); + + if (stats != &nullstats) + stats->qual.updated &= ~IW_QUAL_ALL_UPDATED; + } +} + +/* ---------------------------------------------------------------- */ +/* + * Print info for /proc/net/wireless (print all entries) + */ +static int wireless_dev_seq_show(struct seq_file *seq, void *v) +{ + might_sleep(); + + if (v == SEQ_START_TOKEN) + seq_printf(seq, "Inter-| sta-| Quality | Discarded " + "packets | Missed | WE\n" + " face | tus | link level noise | nwid " + "crypt frag retry misc | beacon | %d\n", + WIRELESS_EXT); + else + wireless_seq_printf_stats(seq, v); + return 0; +} + +static void *wireless_dev_seq_start(struct seq_file *seq, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + loff_t off; + struct net_device *dev; + + rtnl_lock(); + if (!*pos) + return SEQ_START_TOKEN; + + off = 1; + for_each_netdev(net, dev) + if (off++ == *pos) + return dev; + return NULL; +} + +static void *wireless_dev_seq_next(struct seq_file *seq, void *v, loff_t *pos) +{ + struct net *net = seq_file_net(seq); + + ++*pos; + + return v == SEQ_START_TOKEN ? + first_net_device(net) : next_net_device(v); +} + +static void wireless_dev_seq_stop(struct seq_file *seq, void *v) +{ + rtnl_unlock(); +} + +static const struct seq_operations wireless_seq_ops = { + .start = wireless_dev_seq_start, + .next = wireless_dev_seq_next, + .stop = wireless_dev_seq_stop, + .show = wireless_dev_seq_show, +}; + +int __net_init wext_proc_init(struct net *net) +{ + /* Create /proc/net/wireless entry */ + if (!proc_create_net("wireless", 0444, net->proc_net, + &wireless_seq_ops, sizeof(struct seq_net_private))) + return -ENOMEM; + + return 0; +} + +void __net_exit wext_proc_exit(struct net *net) +{ + remove_proc_entry("wireless", net->proc_net); +} diff --git a/net/wireless/wext-sme.c b/net/wireless/wext-sme.c new file mode 100644 index 000000000..68f45afc3 --- /dev/null +++ b/net/wireless/wext-sme.c @@ -0,0 +1,410 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * cfg80211 wext compat for managed mode. 
+ * + * Copyright 2009 Johannes Berg + * Copyright (C) 2009, 2020-2022 Intel Corporation + */ + +#include +#include +#include +#include +#include +#include +#include "wext-compat.h" +#include "nl80211.h" + +int cfg80211_mgd_wext_connect(struct cfg80211_registered_device *rdev, + struct wireless_dev *wdev) +{ + struct cfg80211_cached_keys *ck = NULL; + const u8 *prev_bssid = NULL; + int err, i; + + ASSERT_RTNL(); + ASSERT_WDEV_LOCK(wdev); + + if (!netif_running(wdev->netdev)) + return 0; + + wdev->wext.connect.ie = wdev->wext.ie; + wdev->wext.connect.ie_len = wdev->wext.ie_len; + + /* Use default background scan period */ + wdev->wext.connect.bg_scan_period = -1; + + if (wdev->wext.keys) { + wdev->wext.keys->def = wdev->wext.default_key; + if (wdev->wext.default_key != -1) + wdev->wext.connect.privacy = true; + } + + if (!wdev->wext.connect.ssid_len) + return 0; + + if (wdev->wext.keys && wdev->wext.keys->def != -1) { + ck = kmemdup(wdev->wext.keys, sizeof(*ck), GFP_KERNEL); + if (!ck) + return -ENOMEM; + for (i = 0; i < CFG80211_MAX_WEP_KEYS; i++) + ck->params[i].key = ck->data[i]; + } + + if (wdev->wext.prev_bssid_valid) + prev_bssid = wdev->wext.prev_bssid; + + err = cfg80211_connect(rdev, wdev->netdev, + &wdev->wext.connect, ck, prev_bssid); + if (err) + kfree_sensitive(ck); + + return err; +} + +int cfg80211_mgd_wext_siwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *wextfreq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + struct ieee80211_channel *chan = NULL; + int err, freq; + + /* call only for station! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + freq = cfg80211_wext_freq(wextfreq); + if (freq < 0) + return freq; + + if (freq) { + chan = ieee80211_get_channel(wdev->wiphy, freq); + if (!chan) + return -EINVAL; + if (chan->flags & IEEE80211_CHAN_DISABLED) + return -EINVAL; + } + + wdev_lock(wdev); + + if (wdev->conn) { + bool event = true; + + if (wdev->wext.connect.channel == chan) { + err = 0; + goto out; + } + + /* if SSID set, we'll try right again, avoid event */ + if (wdev->wext.connect.ssid_len) + event = false; + err = cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, event); + if (err) + goto out; + } + + wdev->wext.connect.channel = chan; + err = cfg80211_mgd_wext_connect(rdev, wdev); + out: + wdev_unlock(wdev); + return err; +} + +int cfg80211_mgd_wext_giwfreq(struct net_device *dev, + struct iw_request_info *info, + struct iw_freq *freq, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct ieee80211_channel *chan = NULL; + + /* call only for station! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + if (wdev->valid_links) + return -EOPNOTSUPP; + + wdev_lock(wdev); + if (wdev->links[0].client.current_bss) + chan = wdev->links[0].client.current_bss->pub.channel; + else if (wdev->wext.connect.channel) + chan = wdev->wext.connect.channel; + wdev_unlock(wdev); + + if (chan) { + freq->m = chan->center_freq; + freq->e = 6; + return 0; + } + + /* no channel if not joining */ + return -EINVAL; +} + +int cfg80211_mgd_wext_siwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + size_t len = data->length; + int err; + + /* call only for station! 
*/ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + if (!data->flags) + len = 0; + + /* iwconfig uses nul termination in SSID.. */ + if (len > 0 && ssid[len - 1] == '\0') + len--; + + wdev_lock(wdev); + + err = 0; + + if (wdev->conn) { + bool event = true; + + if (wdev->wext.connect.ssid && len && + len == wdev->wext.connect.ssid_len && + memcmp(wdev->wext.connect.ssid, ssid, len) == 0) + goto out; + + /* if SSID set now, we'll try to connect, avoid event */ + if (len) + event = false; + err = cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, event); + if (err) + goto out; + } + + wdev->wext.prev_bssid_valid = false; + wdev->wext.connect.ssid = wdev->wext.ssid; + memcpy(wdev->wext.ssid, ssid, len); + wdev->wext.connect.ssid_len = len; + + wdev->wext.connect.crypto.control_port = false; + wdev->wext.connect.crypto.control_port_ethertype = + cpu_to_be16(ETH_P_PAE); + + err = cfg80211_mgd_wext_connect(rdev, wdev); + out: + wdev_unlock(wdev); + return err; +} + +int cfg80211_mgd_wext_giwessid(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *ssid) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + int ret = 0; + + /* call only for station! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + if (wdev->valid_links) + return -EINVAL; + + data->flags = 0; + + wdev_lock(wdev); + if (wdev->links[0].client.current_bss) { + const struct element *ssid_elem; + + rcu_read_lock(); + ssid_elem = ieee80211_bss_get_elem( + &wdev->links[0].client.current_bss->pub, + WLAN_EID_SSID); + if (ssid_elem) { + data->flags = 1; + data->length = ssid_elem->datalen; + if (data->length > IW_ESSID_MAX_SIZE) + ret = -EINVAL; + else + memcpy(ssid, ssid_elem->data, data->length); + } + rcu_read_unlock(); + } else if (wdev->wext.connect.ssid && wdev->wext.connect.ssid_len) { + data->flags = 1; + data->length = wdev->wext.connect.ssid_len; + memcpy(ssid, wdev->wext.connect.ssid, data->length); + } + wdev_unlock(wdev); + + return ret; +} + +int cfg80211_mgd_wext_siwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u8 *bssid = ap_addr->sa_data; + int err; + + /* call only for station! */ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + if (ap_addr->sa_family != ARPHRD_ETHER) + return -EINVAL; + + /* automatic mode */ + if (is_zero_ether_addr(bssid) || is_broadcast_ether_addr(bssid)) + bssid = NULL; + + wdev_lock(wdev); + + if (wdev->conn) { + err = 0; + /* both automatic */ + if (!bssid && !wdev->wext.connect.bssid) + goto out; + + /* fixed already - and no change */ + if (wdev->wext.connect.bssid && bssid && + ether_addr_equal(bssid, wdev->wext.connect.bssid)) + goto out; + + err = cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, false); + if (err) + goto out; + } + + if (bssid) { + memcpy(wdev->wext.bssid, bssid, ETH_ALEN); + wdev->wext.connect.bssid = wdev->wext.bssid; + } else + wdev->wext.connect.bssid = NULL; + + err = cfg80211_mgd_wext_connect(rdev, wdev); + out: + wdev_unlock(wdev); + return err; +} + +int cfg80211_mgd_wext_giwap(struct net_device *dev, + struct iw_request_info *info, + struct sockaddr *ap_addr, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + + /* call only for station! 
*/ + if (WARN_ON(wdev->iftype != NL80211_IFTYPE_STATION)) + return -EINVAL; + + ap_addr->sa_family = ARPHRD_ETHER; + + wdev_lock(wdev); + if (wdev->valid_links) { + wdev_unlock(wdev); + return -EOPNOTSUPP; + } + if (wdev->links[0].client.current_bss) + memcpy(ap_addr->sa_data, + wdev->links[0].client.current_bss->pub.bssid, + ETH_ALEN); + else + eth_zero_addr(ap_addr->sa_data); + wdev_unlock(wdev); + + return 0; +} + +int cfg80211_wext_siwgenie(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct cfg80211_registered_device *rdev = wiphy_to_rdev(wdev->wiphy); + u8 *ie = extra; + int ie_len = data->length, err; + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EOPNOTSUPP; + + if (!ie_len) + ie = NULL; + + wdev_lock(wdev); + + /* no change */ + err = 0; + if (wdev->wext.ie_len == ie_len && + memcmp(wdev->wext.ie, ie, ie_len) == 0) + goto out; + + if (ie_len) { + ie = kmemdup(extra, ie_len, GFP_KERNEL); + if (!ie) { + err = -ENOMEM; + goto out; + } + } else + ie = NULL; + + kfree(wdev->wext.ie); + wdev->wext.ie = ie; + wdev->wext.ie_len = ie_len; + + if (wdev->conn) { + err = cfg80211_disconnect(rdev, dev, + WLAN_REASON_DEAUTH_LEAVING, false); + if (err) + goto out; + } + + /* userspace better not think we'll reconnect */ + err = 0; + out: + wdev_unlock(wdev); + return err; +} + +int cfg80211_wext_siwmlme(struct net_device *dev, + struct iw_request_info *info, + struct iw_point *data, char *extra) +{ + struct wireless_dev *wdev = dev->ieee80211_ptr; + struct iw_mlme *mlme = (struct iw_mlme *)extra; + struct cfg80211_registered_device *rdev; + int err; + + if (!wdev) + return -EOPNOTSUPP; + + rdev = wiphy_to_rdev(wdev->wiphy); + + if (wdev->iftype != NL80211_IFTYPE_STATION) + return -EINVAL; + + if (mlme->addr.sa_family != ARPHRD_ETHER) + return -EINVAL; + + wiphy_lock(&rdev->wiphy); + wdev_lock(wdev); + switch (mlme->cmd) { + case IW_MLME_DEAUTH: + case IW_MLME_DISASSOC: + err = cfg80211_disconnect(rdev, dev, mlme->reason_code, true); + break; + default: + err = -EOPNOTSUPP; + break; + } + wdev_unlock(wdev); + wiphy_unlock(&rdev->wiphy); + + return err; +} diff --git a/net/wireless/wext-spy.c b/net/wireless/wext-spy.c new file mode 100644 index 000000000..b379a0371 --- /dev/null +++ b/net/wireless/wext-spy.c @@ -0,0 +1,232 @@ +/* + * This file implement the Wireless Extensions spy API. + * + * Authors : Jean Tourrilhes - HPL - + * Copyright (c) 1997-2007 Jean Tourrilhes, All Rights Reserved. + * + * (As all part of the Linux kernel, this file is GPL) + */ + +#include +#include +#include +#include +#include +#include +#include + +static inline struct iw_spy_data *get_spydata(struct net_device *dev) +{ + /* This is the new way */ + if (dev->wireless_data) + return dev->wireless_data->spy_data; + return NULL; +} + +int iw_handler_set_spy(struct net_device * dev, + struct iw_request_info * info, + union iwreq_data * wrqu, + char * extra) +{ + struct iw_spy_data * spydata = get_spydata(dev); + struct sockaddr * address = (struct sockaddr *) extra; + + /* Make sure driver is not buggy or using the old API */ + if (!spydata) + return -EOPNOTSUPP; + + /* Disable spy collection while we copy the addresses. + * While we copy addresses, any call to wireless_spy_update() + * will NOP. This is OK, as anyway the addresses are changing. 
*/ + spydata->spy_number = 0; + + /* We want to operate without locking, because wireless_spy_update() + * most likely will happen in the interrupt handler, and therefore + * have its own locking constraints and needs performance. + * The rtnl_lock() make sure we don't race with the other iw_handlers. + * This make sure wireless_spy_update() "see" that the spy list + * is temporarily disabled. */ + smp_wmb(); + + /* Are there are addresses to copy? */ + if (wrqu->data.length > 0) { + int i; + + /* Copy addresses */ + for (i = 0; i < wrqu->data.length; i++) + memcpy(spydata->spy_address[i], address[i].sa_data, + ETH_ALEN); + /* Reset stats */ + memset(spydata->spy_stat, 0, + sizeof(struct iw_quality) * IW_MAX_SPY); + } + + /* Make sure above is updated before re-enabling */ + smp_wmb(); + + /* Enable addresses */ + spydata->spy_number = wrqu->data.length; + + return 0; +} +EXPORT_SYMBOL(iw_handler_set_spy); + +int iw_handler_get_spy(struct net_device * dev, + struct iw_request_info * info, + union iwreq_data * wrqu, + char * extra) +{ + struct iw_spy_data * spydata = get_spydata(dev); + struct sockaddr * address = (struct sockaddr *) extra; + int i; + + /* Make sure driver is not buggy or using the old API */ + if (!spydata) + return -EOPNOTSUPP; + + wrqu->data.length = spydata->spy_number; + + /* Copy addresses. */ + for (i = 0; i < spydata->spy_number; i++) { + memcpy(address[i].sa_data, spydata->spy_address[i], ETH_ALEN); + address[i].sa_family = AF_UNIX; + } + /* Copy stats to the user buffer (just after). */ + if (spydata->spy_number > 0) + memcpy(extra + (sizeof(struct sockaddr) *spydata->spy_number), + spydata->spy_stat, + sizeof(struct iw_quality) * spydata->spy_number); + /* Reset updated flags. */ + for (i = 0; i < spydata->spy_number; i++) + spydata->spy_stat[i].updated &= ~IW_QUAL_ALL_UPDATED; + return 0; +} +EXPORT_SYMBOL(iw_handler_get_spy); + +/*------------------------------------------------------------------*/ +/* + * Standard Wireless Handler : set spy threshold + */ +int iw_handler_set_thrspy(struct net_device * dev, + struct iw_request_info *info, + union iwreq_data * wrqu, + char * extra) +{ + struct iw_spy_data * spydata = get_spydata(dev); + struct iw_thrspy * threshold = (struct iw_thrspy *) extra; + + /* Make sure driver is not buggy or using the old API */ + if (!spydata) + return -EOPNOTSUPP; + + /* Just do it */ + spydata->spy_thr_low = threshold->low; + spydata->spy_thr_high = threshold->high; + + /* Clear flag */ + memset(spydata->spy_thr_under, '\0', sizeof(spydata->spy_thr_under)); + + return 0; +} +EXPORT_SYMBOL(iw_handler_set_thrspy); + +/*------------------------------------------------------------------*/ +/* + * Standard Wireless Handler : get spy threshold + */ +int iw_handler_get_thrspy(struct net_device * dev, + struct iw_request_info *info, + union iwreq_data * wrqu, + char * extra) +{ + struct iw_spy_data * spydata = get_spydata(dev); + struct iw_thrspy * threshold = (struct iw_thrspy *) extra; + + /* Make sure driver is not buggy or using the old API */ + if (!spydata) + return -EOPNOTSUPP; + + /* Just do it */ + threshold->low = spydata->spy_thr_low; + threshold->high = spydata->spy_thr_high; + + return 0; +} +EXPORT_SYMBOL(iw_handler_get_thrspy); + +/*------------------------------------------------------------------*/ +/* + * Prepare and send a Spy Threshold event + */ +static void iw_send_thrspy_event(struct net_device * dev, + struct iw_spy_data * spydata, + unsigned char * address, + struct iw_quality * wstats) +{ + union iwreq_data wrqu; + 
struct iw_thrspy threshold; + + /* Init */ + wrqu.data.length = 1; + wrqu.data.flags = 0; + /* Copy address */ + memcpy(threshold.addr.sa_data, address, ETH_ALEN); + threshold.addr.sa_family = ARPHRD_ETHER; + /* Copy stats */ + threshold.qual = *wstats; + /* Copy also thresholds */ + threshold.low = spydata->spy_thr_low; + threshold.high = spydata->spy_thr_high; + + /* Send event to user space */ + wireless_send_event(dev, SIOCGIWTHRSPY, &wrqu, (char *) &threshold); +} + +/* ---------------------------------------------------------------- */ +/* + * Call for the driver to update the spy data. + * For now, the spy data is a simple array. As the size of the array is + * small, this is good enough. If we wanted to support larger number of + * spy addresses, we should use something more efficient... + */ +void wireless_spy_update(struct net_device * dev, + unsigned char * address, + struct iw_quality * wstats) +{ + struct iw_spy_data * spydata = get_spydata(dev); + int i; + int match = -1; + + /* Make sure driver is not buggy or using the old API */ + if (!spydata) + return; + + /* Update all records that match */ + for (i = 0; i < spydata->spy_number; i++) + if (ether_addr_equal(address, spydata->spy_address[i])) { + memcpy(&(spydata->spy_stat[i]), wstats, + sizeof(struct iw_quality)); + match = i; + } + + /* Generate an event if we cross the spy threshold. + * To avoid event storms, we have a simple hysteresis : we generate + * event only when we go under the low threshold or above the + * high threshold. */ + if (match >= 0) { + if (spydata->spy_thr_under[match]) { + if (wstats->level > spydata->spy_thr_high.level) { + spydata->spy_thr_under[match] = 0; + iw_send_thrspy_event(dev, spydata, + address, wstats); + } + } else { + if (wstats->level < spydata->spy_thr_low.level) { + spydata->spy_thr_under[match] = 1; + iw_send_thrspy_event(dev, spydata, + address, wstats); + } + } + } +} +EXPORT_SYMBOL(wireless_spy_update); diff --git a/net/x25/Kconfig b/net/x25/Kconfig new file mode 100644 index 000000000..68729aa3a --- /dev/null +++ b/net/x25/Kconfig @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# CCITT X.25 Packet Layer +# + +config X25 + tristate "CCITT X.25 Packet Layer" + help + X.25 is a set of standardized network protocols, similar in scope to + frame relay; the one physical line from your box to the X.25 network + entry point can carry several logical point-to-point connections + (called "virtual circuits") to other computers connected to the X.25 + network. Governments, banks, and other organizations tend to use it + to connect to each other or to form Wide Area Networks (WANs). Many + countries have public X.25 networks. X.25 consists of two + protocols: the higher level Packet Layer Protocol (PLP) (say Y here + if you want that) and the lower level data link layer protocol LAPB + (say Y to "LAPB Data Link Driver" below if you want that). + + You can read more about X.25 at and + . + Information about X.25 for Linux is contained in the files + and + . + + One connects to an X.25 network either with a dedicated network card + using the X.21 protocol (not yet supported by Linux) or one can do + X.25 over a standard telephone line using an ordinary modem (say Y + to "X.25 async driver" below) or over Ethernet using an ordinary + Ethernet card and the LAPB over Ethernet (say Y to "LAPB Data Link + Driver" and "LAPB over Ethernet driver" below). + + To compile this driver as a module, choose M here: the module + will be called x25. If unsure, say N. 
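For readers new to the packet-layer interface this option enables, the sketch below shows how a user-space program might open an AF_X25 sequenced-packet socket, bind it to a local X.121 address and place a call. It is illustrative only, is not part of the sources added by this patch, and both addresses are invented for the example; a kernel built with CONFIG_X25 plus a configured X.25 route and link are assumed.

    /* Hypothetical user-space example; the X.121 addresses are made up. */
    #include <stdio.h>
    #include <string.h>
    #include <unistd.h>
    #include <sys/socket.h>
    #include <linux/x25.h>

    int main(void)
    {
            struct sockaddr_x25 local, remote;
            int s = socket(AF_X25, SOCK_SEQPACKET, 0);

            if (s < 0) {
                    perror("socket(AF_X25)");   /* e.g. CONFIG_X25 not enabled */
                    return 1;
            }

            /* Bind first; the X.25 layer does not autobind on connect. */
            memset(&local, 0, sizeof(local));
            local.sx25_family = AF_X25;
            strcpy(local.sx25_addr.x25_addr, "2041234567");  /* invented local address */
            if (bind(s, (struct sockaddr *)&local, sizeof(local)) < 0)
                    perror("bind");

            /* Place a call to a remote DTE over the configured route. */
            memset(&remote, 0, sizeof(remote));
            remote.sx25_family = AF_X25;
            strcpy(remote.sx25_addr.x25_addr, "2047654321"); /* invented remote address */
            if (connect(s, (struct sockaddr *)&remote, sizeof(remote)) < 0)
                    perror("connect");          /* needs an X.25 route and link */

            close(s);
            return 0;
    }
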
diff --git a/net/x25/Makefile b/net/x25/Makefile new file mode 100644 index 000000000..5dd544a23 --- /dev/null +++ b/net/x25/Makefile @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the Linux X.25 Packet layer. +# + +obj-$(CONFIG_X25) += x25.o + +x25-y := af_x25.o x25_dev.o x25_facilities.o x25_in.o \ + x25_link.o x25_out.o x25_route.o x25_subr.o \ + x25_timer.o x25_proc.o x25_forward.o +x25-$(CONFIG_SYSCTL) += sysctl_net_x25.o diff --git a/net/x25/af_x25.c b/net/x25/af_x25.c new file mode 100644 index 000000000..5c7ad301d --- /dev/null +++ b/net/x25/af_x25.c @@ -0,0 +1,1856 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor Centralised disconnect handling. + * New timer architecture. + * 2000-03-11 Henner Eisen MSG_EOR handling more POSIX compliant. + * 2000-03-22 Daniela Squassoni Allowed disabling/enabling of + * facilities negotiation and increased + * the throughput upper limit. + * 2000-08-27 Arnaldo C. Melo s/suser/capable/ + micro cleanups + * 2000-09-04 Henner Eisen Set sock->state in x25_accept(). + * Fixed x25_output() related skb leakage. + * 2000-10-02 Henner Eisen Made x25_kick() single threaded per socket. + * 2000-10-27 Henner Eisen MSG_DONTWAIT for fragment allocation. + * 2000-11-14 Henner Eisen Closing datalink from NETDEV_GOING_DOWN + * 2002-10-06 Arnaldo C. Melo Get rid of cli/sti, move proc stuff to + * x25_proc.c, using seq_file + * 2005-04-02 Shaun Pereira Selective sub address matching + * with call user data + * 2005-04-15 Shaun Pereira Fast select with no restriction on + * response + */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include /* For TIOCINQ/OUTQ */ +#include +#include +#include +#include + +#include +#include + +int sysctl_x25_restart_request_timeout = X25_DEFAULT_T20; +int sysctl_x25_call_request_timeout = X25_DEFAULT_T21; +int sysctl_x25_reset_request_timeout = X25_DEFAULT_T22; +int sysctl_x25_clear_request_timeout = X25_DEFAULT_T23; +int sysctl_x25_ack_holdback_timeout = X25_DEFAULT_T2; +int sysctl_x25_forward = 0; + +HLIST_HEAD(x25_list); +DEFINE_RWLOCK(x25_list_lock); + +static const struct proto_ops x25_proto_ops; + +static const struct x25_address null_x25_address = {" "}; + +#ifdef CONFIG_COMPAT +struct compat_x25_subscrip_struct { + char device[200-sizeof(compat_ulong_t)]; + compat_ulong_t global_facil_mask; + compat_uint_t extended; +}; +#endif + + +int x25_parse_address_block(struct sk_buff *skb, + struct x25_address *called_addr, + struct x25_address *calling_addr) +{ + unsigned char len; + int needed; + int rc; + + if (!pskb_may_pull(skb, 1)) { + /* packet has no address block */ + rc = 0; + goto empty; + } + + len = *skb->data; + needed = 1 + ((len >> 4) + (len & 0x0f) + 1) / 2; + + if (!pskb_may_pull(skb, needed)) { + /* packet is too short to hold the addresses it claims + to hold */ + rc = -1; + goto empty; + } + + return x25_addr_ntoa(skb->data, called_addr, calling_addr); + +empty: + *called_addr->x25_addr = 0; + *calling_addr->x25_addr = 0; + + return rc; +} + + +int x25_addr_ntoa(unsigned char *p, struct 
x25_address *called_addr, + struct x25_address *calling_addr) +{ + unsigned int called_len, calling_len; + char *called, *calling; + unsigned int i; + + called_len = (*p >> 0) & 0x0F; + calling_len = (*p >> 4) & 0x0F; + + called = called_addr->x25_addr; + calling = calling_addr->x25_addr; + p++; + + for (i = 0; i < (called_len + calling_len); i++) { + if (i < called_len) { + if (i % 2 != 0) { + *called++ = ((*p >> 0) & 0x0F) + '0'; + p++; + } else { + *called++ = ((*p >> 4) & 0x0F) + '0'; + } + } else { + if (i % 2 != 0) { + *calling++ = ((*p >> 0) & 0x0F) + '0'; + p++; + } else { + *calling++ = ((*p >> 4) & 0x0F) + '0'; + } + } + } + + *called = *calling = '\0'; + + return 1 + (called_len + calling_len + 1) / 2; +} + +int x25_addr_aton(unsigned char *p, struct x25_address *called_addr, + struct x25_address *calling_addr) +{ + unsigned int called_len, calling_len; + char *called, *calling; + int i; + + called = called_addr->x25_addr; + calling = calling_addr->x25_addr; + + called_len = strlen(called); + calling_len = strlen(calling); + + *p++ = (calling_len << 4) | (called_len << 0); + + for (i = 0; i < (called_len + calling_len); i++) { + if (i < called_len) { + if (i % 2 != 0) { + *p |= (*called++ - '0') << 0; + p++; + } else { + *p = 0x00; + *p |= (*called++ - '0') << 4; + } + } else { + if (i % 2 != 0) { + *p |= (*calling++ - '0') << 0; + p++; + } else { + *p = 0x00; + *p |= (*calling++ - '0') << 4; + } + } + } + + return 1 + (called_len + calling_len + 1) / 2; +} + +/* + * Socket removal during an interrupt is now safe. + */ +static void x25_remove_socket(struct sock *sk) +{ + write_lock_bh(&x25_list_lock); + sk_del_node_init(sk); + write_unlock_bh(&x25_list_lock); +} + +/* + * Handle device status changes. + */ +static int x25_device_event(struct notifier_block *this, unsigned long event, + void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct x25_neigh *nb; + + if (!net_eq(dev_net(dev), &init_net)) + return NOTIFY_DONE; + + if (dev->type == ARPHRD_X25) { + switch (event) { + case NETDEV_REGISTER: + case NETDEV_POST_TYPE_CHANGE: + x25_link_device_up(dev); + break; + case NETDEV_DOWN: + nb = x25_get_neigh(dev); + if (nb) { + x25_link_terminated(nb); + x25_neigh_put(nb); + } + x25_route_device_down(dev); + break; + case NETDEV_PRE_TYPE_CHANGE: + case NETDEV_UNREGISTER: + x25_link_device_down(dev); + break; + case NETDEV_CHANGE: + if (!netif_carrier_ok(dev)) { + nb = x25_get_neigh(dev); + if (nb) { + x25_link_terminated(nb); + x25_neigh_put(nb); + } + } + break; + } + } + + return NOTIFY_DONE; +} + +/* + * Add a socket to the bound sockets list. + */ +static void x25_insert_socket(struct sock *sk) +{ + write_lock_bh(&x25_list_lock); + sk_add_node(sk, &x25_list); + write_unlock_bh(&x25_list_lock); +} + +/* + * Find a socket that wants to accept the Call Request we just + * received. Check the full list for an address/cud match. + * If no cuds match return the next_best thing, an address match. + * Note: if a listening socket has cud set it must only get calls + * with matching cud. 
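+ * A listening socket bound to the null X.25 address is treated as an
+ * address match for any incoming call.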
+ */ +static struct sock *x25_find_listener(struct x25_address *addr, + struct sk_buff *skb) +{ + struct sock *s; + struct sock *next_best; + + read_lock_bh(&x25_list_lock); + next_best = NULL; + + sk_for_each(s, &x25_list) + if ((!strcmp(addr->x25_addr, + x25_sk(s)->source_addr.x25_addr) || + !strcmp(x25_sk(s)->source_addr.x25_addr, + null_x25_address.x25_addr)) && + s->sk_state == TCP_LISTEN) { + /* + * Found a listening socket, now check the incoming + * call user data vs this sockets call user data + */ + if (x25_sk(s)->cudmatchlength > 0 && + skb->len >= x25_sk(s)->cudmatchlength) { + if((memcmp(x25_sk(s)->calluserdata.cuddata, + skb->data, + x25_sk(s)->cudmatchlength)) == 0) { + sock_hold(s); + goto found; + } + } else + next_best = s; + } + if (next_best) { + s = next_best; + sock_hold(s); + goto found; + } + s = NULL; +found: + read_unlock_bh(&x25_list_lock); + return s; +} + +/* + * Find a connected X.25 socket given my LCI and neighbour. + */ +static struct sock *__x25_find_socket(unsigned int lci, struct x25_neigh *nb) +{ + struct sock *s; + + sk_for_each(s, &x25_list) + if (x25_sk(s)->lci == lci && x25_sk(s)->neighbour == nb) { + sock_hold(s); + goto found; + } + s = NULL; +found: + return s; +} + +struct sock *x25_find_socket(unsigned int lci, struct x25_neigh *nb) +{ + struct sock *s; + + read_lock_bh(&x25_list_lock); + s = __x25_find_socket(lci, nb); + read_unlock_bh(&x25_list_lock); + return s; +} + +/* + * Find a unique LCI for a given device. + */ +static unsigned int x25_new_lci(struct x25_neigh *nb) +{ + unsigned int lci = 1; + struct sock *sk; + + while ((sk = x25_find_socket(lci, nb)) != NULL) { + sock_put(sk); + if (++lci == 4096) { + lci = 0; + break; + } + cond_resched(); + } + + return lci; +} + +/* + * Deferred destroy. + */ +static void __x25_destroy_socket(struct sock *); + +/* + * handler for deferred kills. + */ +static void x25_destroy_timer(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + x25_destroy_socket_from_timer(sk); +} + +/* + * This is called from user mode and the timers. Thus it protects itself + * against interrupting users but doesn't worry about being called during + * work. Once it is removed from the queue no interrupt or bottom half + * will touch it and we are (fairly 8-) ) safe. + * Not static as it's used by the timer + */ +static void __x25_destroy_socket(struct sock *sk) +{ + struct sk_buff *skb; + + x25_stop_heartbeat(sk); + x25_stop_timer(sk); + + x25_remove_socket(sk); + x25_clear_queues(sk); /* Flush the queues */ + + while ((skb = skb_dequeue(&sk->sk_receive_queue)) != NULL) { + if (skb->sk != sk) { /* A pending connection */ + /* + * Queue the unaccepted socket for death + */ + skb->sk->sk_state = TCP_LISTEN; + sock_set_flag(skb->sk, SOCK_DEAD); + x25_start_heartbeat(skb->sk); + x25_sk(skb->sk)->state = X25_STATE_0; + } + + kfree_skb(skb); + } + + if (sk_has_allocations(sk)) { + /* Defer: outstanding buffers */ + sk->sk_timer.expires = jiffies + 10 * HZ; + sk->sk_timer.function = x25_destroy_timer; + add_timer(&sk->sk_timer); + } else { + /* drop last reference so sock_put will free */ + __sock_put(sk); + } +} + +void x25_destroy_socket_from_timer(struct sock *sk) +{ + sock_hold(sk); + bh_lock_sock(sk); + __x25_destroy_socket(sk); + bh_unlock_sock(sk); + sock_put(sk); +} + +/* + * Handling for system calls applied via the various interfaces to a + * X.25 socket object. 
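+ * Note that setsockopt()/getsockopt() below only handle the SOL_X25
+ * level with the X25_QBITINCL option.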
+ */ + +static int x25_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + int opt; + struct sock *sk = sock->sk; + int rc = -ENOPROTOOPT; + + if (level != SOL_X25 || optname != X25_QBITINCL) + goto out; + + rc = -EINVAL; + if (optlen < sizeof(int)) + goto out; + + rc = -EFAULT; + if (copy_from_sockptr(&opt, optval, sizeof(int))) + goto out; + + if (opt) + set_bit(X25_Q_BIT_FLAG, &x25_sk(sk)->flags); + else + clear_bit(X25_Q_BIT_FLAG, &x25_sk(sk)->flags); + rc = 0; +out: + return rc; +} + +static int x25_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + int val, len, rc = -ENOPROTOOPT; + + if (level != SOL_X25 || optname != X25_QBITINCL) + goto out; + + rc = -EFAULT; + if (get_user(len, optlen)) + goto out; + + len = min_t(unsigned int, len, sizeof(int)); + + rc = -EINVAL; + if (len < 0) + goto out; + + rc = -EFAULT; + if (put_user(len, optlen)) + goto out; + + val = test_bit(X25_Q_BIT_FLAG, &x25_sk(sk)->flags); + rc = copy_to_user(optval, &val, len) ? -EFAULT : 0; +out: + return rc; +} + +static int x25_listen(struct socket *sock, int backlog) +{ + struct sock *sk = sock->sk; + int rc = -EOPNOTSUPP; + + lock_sock(sk); + if (sock->state != SS_UNCONNECTED) { + rc = -EINVAL; + release_sock(sk); + return rc; + } + + if (sk->sk_state != TCP_LISTEN) { + memset(&x25_sk(sk)->dest_addr, 0, X25_ADDR_LEN); + sk->sk_max_ack_backlog = backlog; + sk->sk_state = TCP_LISTEN; + rc = 0; + } + release_sock(sk); + + return rc; +} + +static struct proto x25_proto = { + .name = "X25", + .owner = THIS_MODULE, + .obj_size = sizeof(struct x25_sock), +}; + +static struct sock *x25_alloc_socket(struct net *net, int kern) +{ + struct x25_sock *x25; + struct sock *sk = sk_alloc(net, AF_X25, GFP_ATOMIC, &x25_proto, kern); + + if (!sk) + goto out; + + sock_init_data(NULL, sk); + + x25 = x25_sk(sk); + skb_queue_head_init(&x25->ack_queue); + skb_queue_head_init(&x25->fragment_queue); + skb_queue_head_init(&x25->interrupt_in_queue); + skb_queue_head_init(&x25->interrupt_out_queue); +out: + return sk; +} + +static int x25_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct sock *sk; + struct x25_sock *x25; + int rc = -EAFNOSUPPORT; + + if (!net_eq(net, &init_net)) + goto out; + + rc = -ESOCKTNOSUPPORT; + if (sock->type != SOCK_SEQPACKET) + goto out; + + rc = -EINVAL; + if (protocol) + goto out; + + rc = -ENOMEM; + if ((sk = x25_alloc_socket(net, kern)) == NULL) + goto out; + + x25 = x25_sk(sk); + + sock_init_data(sock, sk); + + x25_init_timers(sk); + + sock->ops = &x25_proto_ops; + sk->sk_protocol = protocol; + sk->sk_backlog_rcv = x25_backlog_rcv; + + x25->t21 = sysctl_x25_call_request_timeout; + x25->t22 = sysctl_x25_reset_request_timeout; + x25->t23 = sysctl_x25_clear_request_timeout; + x25->t2 = sysctl_x25_ack_holdback_timeout; + x25->state = X25_STATE_0; + x25->cudmatchlength = 0; + set_bit(X25_ACCPT_APPRV_FLAG, &x25->flags); /* normally no cud */ + /* on call accept */ + + x25->facilities.winsize_in = X25_DEFAULT_WINDOW_SIZE; + x25->facilities.winsize_out = X25_DEFAULT_WINDOW_SIZE; + x25->facilities.pacsize_in = X25_DEFAULT_PACKET_SIZE; + x25->facilities.pacsize_out = X25_DEFAULT_PACKET_SIZE; + x25->facilities.throughput = 0; /* by default don't negotiate + throughput */ + x25->facilities.reverse = X25_DEFAULT_REVERSE; + x25->dte_facilities.calling_len = 0; + x25->dte_facilities.called_len = 0; + memset(x25->dte_facilities.called_ae, '\0', + 
sizeof(x25->dte_facilities.called_ae)); + memset(x25->dte_facilities.calling_ae, '\0', + sizeof(x25->dte_facilities.calling_ae)); + + rc = 0; +out: + return rc; +} + +static struct sock *x25_make_new(struct sock *osk) +{ + struct sock *sk = NULL; + struct x25_sock *x25, *ox25; + + if (osk->sk_type != SOCK_SEQPACKET) + goto out; + + if ((sk = x25_alloc_socket(sock_net(osk), 0)) == NULL) + goto out; + + x25 = x25_sk(sk); + + sk->sk_type = osk->sk_type; + sk->sk_priority = osk->sk_priority; + sk->sk_protocol = osk->sk_protocol; + sk->sk_rcvbuf = osk->sk_rcvbuf; + sk->sk_sndbuf = osk->sk_sndbuf; + sk->sk_state = TCP_ESTABLISHED; + sk->sk_backlog_rcv = osk->sk_backlog_rcv; + sock_copy_flags(sk, osk); + + ox25 = x25_sk(osk); + x25->t21 = ox25->t21; + x25->t22 = ox25->t22; + x25->t23 = ox25->t23; + x25->t2 = ox25->t2; + x25->flags = ox25->flags; + x25->facilities = ox25->facilities; + x25->dte_facilities = ox25->dte_facilities; + x25->cudmatchlength = ox25->cudmatchlength; + + clear_bit(X25_INTERRUPT_FLAG, &x25->flags); + x25_init_timers(sk); +out: + return sk; +} + +static int x25_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct x25_sock *x25; + + if (!sk) + return 0; + + x25 = x25_sk(sk); + + sock_hold(sk); + lock_sock(sk); + switch (x25->state) { + + case X25_STATE_0: + case X25_STATE_2: + x25_disconnect(sk, 0, 0, 0); + __x25_destroy_socket(sk); + goto out; + + case X25_STATE_1: + case X25_STATE_3: + case X25_STATE_4: + x25_clear_queues(sk); + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25_start_t23timer(sk); + x25->state = X25_STATE_2; + sk->sk_state = TCP_CLOSE; + sk->sk_shutdown |= SEND_SHUTDOWN; + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + sock_set_flag(sk, SOCK_DESTROY); + break; + + case X25_STATE_5: + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25_disconnect(sk, 0, 0, 0); + __x25_destroy_socket(sk); + goto out; + } + + sock_orphan(sk); +out: + release_sock(sk); + sock_put(sk); + return 0; +} + +static int x25_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) +{ + struct sock *sk = sock->sk; + struct sockaddr_x25 *addr = (struct sockaddr_x25 *)uaddr; + int len, i, rc = 0; + + if (addr_len != sizeof(struct sockaddr_x25) || + addr->sx25_family != AF_X25 || + strnlen(addr->sx25_addr.x25_addr, X25_ADDR_LEN) == X25_ADDR_LEN) { + rc = -EINVAL; + goto out; + } + + /* check for the null_x25_address */ + if (strcmp(addr->sx25_addr.x25_addr, null_x25_address.x25_addr)) { + + len = strlen(addr->sx25_addr.x25_addr); + for (i = 0; i < len; i++) { + if (!isdigit(addr->sx25_addr.x25_addr[i])) { + rc = -EINVAL; + goto out; + } + } + } + + lock_sock(sk); + if (sock_flag(sk, SOCK_ZAPPED)) { + x25_sk(sk)->source_addr = addr->sx25_addr; + x25_insert_socket(sk); + sock_reset_flag(sk, SOCK_ZAPPED); + } else { + rc = -EINVAL; + } + release_sock(sk); + SOCK_DEBUG(sk, "x25_bind: socket is bound\n"); +out: + return rc; +} + +static int x25_wait_for_connection_establishment(struct sock *sk) +{ + DECLARE_WAITQUEUE(wait, current); + int rc; + + add_wait_queue_exclusive(sk_sleep(sk), &wait); + for (;;) { + __set_current_state(TASK_INTERRUPTIBLE); + rc = -ERESTARTSYS; + if (signal_pending(current)) + break; + rc = sock_error(sk); + if (rc) { + sk->sk_socket->state = SS_UNCONNECTED; + break; + } + rc = -ENOTCONN; + if (sk->sk_state == TCP_CLOSE) { + sk->sk_socket->state = SS_UNCONNECTED; + break; + } + rc = 0; + if (sk->sk_state != TCP_ESTABLISHED) { + release_sock(sk); + schedule(); + lock_sock(sk); + } else + break; + } + __set_current_state(TASK_RUNNING); + 
remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +static int x25_connect(struct socket *sock, struct sockaddr *uaddr, + int addr_len, int flags) +{ + struct sock *sk = sock->sk; + struct x25_sock *x25 = x25_sk(sk); + struct sockaddr_x25 *addr = (struct sockaddr_x25 *)uaddr; + struct x25_route *rt; + int rc = 0; + + lock_sock(sk); + if (sk->sk_state == TCP_ESTABLISHED && sock->state == SS_CONNECTING) { + sock->state = SS_CONNECTED; + goto out; /* Connect completed during a ERESTARTSYS event */ + } + + rc = -ECONNREFUSED; + if (sk->sk_state == TCP_CLOSE && sock->state == SS_CONNECTING) { + sock->state = SS_UNCONNECTED; + goto out; + } + + rc = -EISCONN; /* No reconnect on a seqpacket socket */ + if (sk->sk_state == TCP_ESTABLISHED) + goto out; + + rc = -EALREADY; /* Do nothing if call is already in progress */ + if (sk->sk_state == TCP_SYN_SENT) + goto out; + + sk->sk_state = TCP_CLOSE; + sock->state = SS_UNCONNECTED; + + rc = -EINVAL; + if (addr_len != sizeof(struct sockaddr_x25) || + addr->sx25_family != AF_X25 || + strnlen(addr->sx25_addr.x25_addr, X25_ADDR_LEN) == X25_ADDR_LEN) + goto out; + + rc = -ENETUNREACH; + rt = x25_get_route(&addr->sx25_addr); + if (!rt) + goto out; + + x25->neighbour = x25_get_neigh(rt->dev); + if (!x25->neighbour) + goto out_put_route; + + x25_limit_facilities(&x25->facilities, x25->neighbour); + + x25->lci = x25_new_lci(x25->neighbour); + if (!x25->lci) + goto out_put_neigh; + + rc = -EINVAL; + if (sock_flag(sk, SOCK_ZAPPED)) /* Must bind first - autobinding does not work */ + goto out_put_neigh; + + if (!strcmp(x25->source_addr.x25_addr, null_x25_address.x25_addr)) + memset(&x25->source_addr, '\0', X25_ADDR_LEN); + + x25->dest_addr = addr->sx25_addr; + + /* Move to connecting socket, start sending Connect Requests */ + sock->state = SS_CONNECTING; + sk->sk_state = TCP_SYN_SENT; + + x25->state = X25_STATE_1; + + x25_write_internal(sk, X25_CALL_REQUEST); + + x25_start_heartbeat(sk); + x25_start_t21timer(sk); + + /* Now the loop */ + rc = -EINPROGRESS; + if (sk->sk_state != TCP_ESTABLISHED && (flags & O_NONBLOCK)) + goto out; + + rc = x25_wait_for_connection_establishment(sk); + if (rc) + goto out_put_neigh; + + sock->state = SS_CONNECTED; + rc = 0; +out_put_neigh: + if (rc && x25->neighbour) { + read_lock_bh(&x25_list_lock); + x25_neigh_put(x25->neighbour); + x25->neighbour = NULL; + read_unlock_bh(&x25_list_lock); + x25->state = X25_STATE_0; + } +out_put_route: + x25_route_put(rt); +out: + release_sock(sk); + return rc; +} + +static int x25_wait_for_data(struct sock *sk, long timeout) +{ + DECLARE_WAITQUEUE(wait, current); + int rc = 0; + + add_wait_queue_exclusive(sk_sleep(sk), &wait); + for (;;) { + __set_current_state(TASK_INTERRUPTIBLE); + if (sk->sk_shutdown & RCV_SHUTDOWN) + break; + rc = -ERESTARTSYS; + if (signal_pending(current)) + break; + rc = -EAGAIN; + if (!timeout) + break; + rc = 0; + if (skb_queue_empty(&sk->sk_receive_queue)) { + release_sock(sk); + timeout = schedule_timeout(timeout); + lock_sock(sk); + } else + break; + } + __set_current_state(TASK_RUNNING); + remove_wait_queue(sk_sleep(sk), &wait); + return rc; +} + +static int x25_accept(struct socket *sock, struct socket *newsock, int flags, + bool kern) +{ + struct sock *sk = sock->sk; + struct sock *newsk; + struct sk_buff *skb; + int rc = -EINVAL; + + if (!sk) + goto out; + + rc = -EOPNOTSUPP; + if (sk->sk_type != SOCK_SEQPACKET) + goto out; + + lock_sock(sk); + rc = -EINVAL; + if (sk->sk_state != TCP_LISTEN) + goto out2; + + rc = x25_wait_for_data(sk, sk->sk_rcvtimeo); + if 
(rc) + goto out2; + skb = skb_dequeue(&sk->sk_receive_queue); + rc = -EINVAL; + if (!skb->sk) + goto out2; + newsk = skb->sk; + sock_graft(newsk, newsock); + + /* Now attach up the new socket */ + skb->sk = NULL; + kfree_skb(skb); + sk_acceptq_removed(sk); + newsock->state = SS_CONNECTED; + rc = 0; +out2: + release_sock(sk); +out: + return rc; +} + +static int x25_getname(struct socket *sock, struct sockaddr *uaddr, + int peer) +{ + struct sockaddr_x25 *sx25 = (struct sockaddr_x25 *)uaddr; + struct sock *sk = sock->sk; + struct x25_sock *x25 = x25_sk(sk); + int rc = 0; + + if (peer) { + if (sk->sk_state != TCP_ESTABLISHED) { + rc = -ENOTCONN; + goto out; + } + sx25->sx25_addr = x25->dest_addr; + } else + sx25->sx25_addr = x25->source_addr; + + sx25->sx25_family = AF_X25; + rc = sizeof(*sx25); + +out: + return rc; +} + +int x25_rx_call_request(struct sk_buff *skb, struct x25_neigh *nb, + unsigned int lci) +{ + struct sock *sk; + struct sock *make; + struct x25_sock *makex25; + struct x25_address source_addr, dest_addr; + struct x25_facilities facilities; + struct x25_dte_facilities dte_facilities; + int len, addr_len, rc; + + /* + * Remove the LCI and frame type. + */ + skb_pull(skb, X25_STD_MIN_LEN); + + /* + * Extract the X.25 addresses and convert them to ASCII strings, + * and remove them. + * + * Address block is mandatory in call request packets + */ + addr_len = x25_parse_address_block(skb, &source_addr, &dest_addr); + if (addr_len <= 0) + goto out_clear_request; + skb_pull(skb, addr_len); + + /* + * Get the length of the facilities, skip past them for the moment + * get the call user data because this is needed to determine + * the correct listener + * + * Facilities length is mandatory in call request packets + */ + if (!pskb_may_pull(skb, 1)) + goto out_clear_request; + len = skb->data[0] + 1; + if (!pskb_may_pull(skb, len)) + goto out_clear_request; + skb_pull(skb,len); + + /* + * Ensure that the amount of call user data is valid. + */ + if (skb->len > X25_MAX_CUD_LEN) + goto out_clear_request; + + /* + * Get all the call user data so it can be used in + * x25_find_listener and skb_copy_from_linear_data up ahead. + */ + if (!pskb_may_pull(skb, skb->len)) + goto out_clear_request; + + /* + * Find a listener for the particular address/cud pair. + */ + sk = x25_find_listener(&source_addr,skb); + skb_push(skb,len); + + if (sk != NULL && sk_acceptq_is_full(sk)) { + goto out_sock_put; + } + + /* + * We dont have any listeners for this incoming call. + * Try forwarding it. + */ + if (sk == NULL) { + skb_push(skb, addr_len + X25_STD_MIN_LEN); + if (sysctl_x25_forward && + x25_forward_call(&dest_addr, nb, skb, lci) > 0) + { + /* Call was forwarded, dont process it any more */ + kfree_skb(skb); + rc = 1; + goto out; + } else { + /* No listeners, can't forward, clear the call */ + goto out_clear_request; + } + } + + /* + * Try to reach a compromise on the requested facilities. + */ + len = x25_negotiate_facilities(skb, sk, &facilities, &dte_facilities); + if (len == -1) + goto out_sock_put; + + /* + * current neighbour/link might impose additional limits + * on certain facilities + */ + + x25_limit_facilities(&facilities, nb); + + /* + * Try to create a new socket. 
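+ * The new socket inherits its timers, facilities, flags and cud match
+ * length from the listening socket.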
+ */ + make = x25_make_new(sk); + if (!make) + goto out_sock_put; + + /* + * Remove the facilities + */ + skb_pull(skb, len); + + skb->sk = make; + make->sk_state = TCP_ESTABLISHED; + + makex25 = x25_sk(make); + makex25->lci = lci; + makex25->dest_addr = dest_addr; + makex25->source_addr = source_addr; + x25_neigh_hold(nb); + makex25->neighbour = nb; + makex25->facilities = facilities; + makex25->dte_facilities= dte_facilities; + makex25->vc_facil_mask = x25_sk(sk)->vc_facil_mask; + /* ensure no reverse facil on accept */ + makex25->vc_facil_mask &= ~X25_MASK_REVERSE; + /* ensure no calling address extension on accept */ + makex25->vc_facil_mask &= ~X25_MASK_CALLING_AE; + makex25->cudmatchlength = x25_sk(sk)->cudmatchlength; + + /* Normally all calls are accepted immediately */ + if (test_bit(X25_ACCPT_APPRV_FLAG, &makex25->flags)) { + x25_write_internal(make, X25_CALL_ACCEPTED); + makex25->state = X25_STATE_3; + } else { + makex25->state = X25_STATE_5; + } + + /* + * Incoming Call User Data. + */ + skb_copy_from_linear_data(skb, makex25->calluserdata.cuddata, skb->len); + makex25->calluserdata.cudlength = skb->len; + + sk_acceptq_added(sk); + + x25_insert_socket(make); + + skb_queue_head(&sk->sk_receive_queue, skb); + + x25_start_heartbeat(make); + + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + rc = 1; + sock_put(sk); +out: + return rc; +out_sock_put: + sock_put(sk); +out_clear_request: + rc = 0; + x25_transmit_clear_request(nb, lci, 0x01); + goto out; +} + +static int x25_sendmsg(struct socket *sock, struct msghdr *msg, size_t len) +{ + struct sock *sk = sock->sk; + struct x25_sock *x25 = x25_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_x25 *, usx25, msg->msg_name); + struct sockaddr_x25 sx25; + struct sk_buff *skb; + unsigned char *asmptr; + int noblock = msg->msg_flags & MSG_DONTWAIT; + size_t size; + int qbit = 0, rc = -EINVAL; + + lock_sock(sk); + if (msg->msg_flags & ~(MSG_DONTWAIT|MSG_OOB|MSG_EOR|MSG_CMSG_COMPAT)) + goto out; + + /* we currently don't support segmented records at the user interface */ + if (!(msg->msg_flags & (MSG_EOR|MSG_OOB))) + goto out; + + rc = -EADDRNOTAVAIL; + if (sock_flag(sk, SOCK_ZAPPED)) + goto out; + + rc = -EPIPE; + if (sk->sk_shutdown & SEND_SHUTDOWN) { + send_sig(SIGPIPE, current, 0); + goto out; + } + + rc = -ENETUNREACH; + if (!x25->neighbour) + goto out; + + if (usx25) { + rc = -EINVAL; + if (msg->msg_namelen < sizeof(sx25)) + goto out; + memcpy(&sx25, usx25, sizeof(sx25)); + rc = -EISCONN; + if (strcmp(x25->dest_addr.x25_addr, sx25.sx25_addr.x25_addr)) + goto out; + rc = -EINVAL; + if (sx25.sx25_family != AF_X25) + goto out; + } else { + /* + * FIXME 1003.1g - if the socket is like this because + * it has become closed (not started closed) we ought + * to SIGPIPE, EPIPE; + */ + rc = -ENOTCONN; + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + + sx25.sx25_family = AF_X25; + sx25.sx25_addr = x25->dest_addr; + } + + /* Sanity check the packet size */ + if (len > 65535) { + rc = -EMSGSIZE; + goto out; + } + + SOCK_DEBUG(sk, "x25_sendmsg: sendto: Addresses built.\n"); + + /* Build a packet */ + SOCK_DEBUG(sk, "x25_sendmsg: sendto: building packet.\n"); + + if ((msg->msg_flags & MSG_OOB) && len > 32) + len = 32; + + size = len + X25_MAX_L2_LEN + X25_EXT_MIN_LEN; + + release_sock(sk); + skb = sock_alloc_send_skb(sk, size, noblock, &rc); + lock_sock(sk); + if (!skb) + goto out; + X25_SKB_CB(skb)->flags = msg->msg_flags; + + skb_reserve(skb, X25_MAX_L2_LEN + X25_EXT_MIN_LEN); + + /* + * Put the data on the end + */ + SOCK_DEBUG(sk, 
"x25_sendmsg: Copying user data\n"); + + skb_reset_transport_header(skb); + skb_put(skb, len); + + rc = memcpy_from_msg(skb_transport_header(skb), msg, len); + if (rc) + goto out_kfree_skb; + + /* + * If the Q BIT Include socket option is in force, the first + * byte of the user data is the logical value of the Q Bit. + */ + if (test_bit(X25_Q_BIT_FLAG, &x25->flags)) { + if (!pskb_may_pull(skb, 1)) + goto out_kfree_skb; + + qbit = skb->data[0]; + skb_pull(skb, 1); + } + + /* + * Push down the X.25 header + */ + SOCK_DEBUG(sk, "x25_sendmsg: Building X.25 Header.\n"); + + if (msg->msg_flags & MSG_OOB) { + if (x25->neighbour->extended) { + asmptr = skb_push(skb, X25_STD_MIN_LEN); + *asmptr++ = ((x25->lci >> 8) & 0x0F) | X25_GFI_EXTSEQ; + *asmptr++ = (x25->lci >> 0) & 0xFF; + *asmptr++ = X25_INTERRUPT; + } else { + asmptr = skb_push(skb, X25_STD_MIN_LEN); + *asmptr++ = ((x25->lci >> 8) & 0x0F) | X25_GFI_STDSEQ; + *asmptr++ = (x25->lci >> 0) & 0xFF; + *asmptr++ = X25_INTERRUPT; + } + } else { + if (x25->neighbour->extended) { + /* Build an Extended X.25 header */ + asmptr = skb_push(skb, X25_EXT_MIN_LEN); + *asmptr++ = ((x25->lci >> 8) & 0x0F) | X25_GFI_EXTSEQ; + *asmptr++ = (x25->lci >> 0) & 0xFF; + *asmptr++ = X25_DATA; + *asmptr++ = X25_DATA; + } else { + /* Build an Standard X.25 header */ + asmptr = skb_push(skb, X25_STD_MIN_LEN); + *asmptr++ = ((x25->lci >> 8) & 0x0F) | X25_GFI_STDSEQ; + *asmptr++ = (x25->lci >> 0) & 0xFF; + *asmptr++ = X25_DATA; + } + + if (qbit) + skb->data[0] |= X25_Q_BIT; + } + + SOCK_DEBUG(sk, "x25_sendmsg: Built header.\n"); + SOCK_DEBUG(sk, "x25_sendmsg: Transmitting buffer\n"); + + rc = -ENOTCONN; + if (sk->sk_state != TCP_ESTABLISHED) + goto out_kfree_skb; + + if (msg->msg_flags & MSG_OOB) + skb_queue_tail(&x25->interrupt_out_queue, skb); + else { + rc = x25_output(sk, skb); + len = rc; + if (rc < 0) + kfree_skb(skb); + else if (test_bit(X25_Q_BIT_FLAG, &x25->flags)) + len++; + } + + x25_kick(sk); + rc = len; +out: + release_sock(sk); + return rc; +out_kfree_skb: + kfree_skb(skb); + goto out; +} + + +static int x25_recvmsg(struct socket *sock, struct msghdr *msg, size_t size, + int flags) +{ + struct sock *sk = sock->sk; + struct x25_sock *x25 = x25_sk(sk); + DECLARE_SOCKADDR(struct sockaddr_x25 *, sx25, msg->msg_name); + size_t copied; + int qbit, header_len; + struct sk_buff *skb; + unsigned char *asmptr; + int rc = -ENOTCONN; + + lock_sock(sk); + + if (x25->neighbour == NULL) + goto out; + + header_len = x25->neighbour->extended ? + X25_EXT_MIN_LEN : X25_STD_MIN_LEN; + + /* + * This works for seqpacket too. The receiver has ordered the queue for + * us! We do one quick check first though + */ + if (sk->sk_state != TCP_ESTABLISHED) + goto out; + + if (flags & MSG_OOB) { + rc = -EINVAL; + if (sock_flag(sk, SOCK_URGINLINE) || + !skb_peek(&x25->interrupt_in_queue)) + goto out; + + skb = skb_dequeue(&x25->interrupt_in_queue); + + if (!pskb_may_pull(skb, X25_STD_MIN_LEN)) + goto out_free_dgram; + + skb_pull(skb, X25_STD_MIN_LEN); + + /* + * No Q bit information on Interrupt data. 
+ */ + if (test_bit(X25_Q_BIT_FLAG, &x25->flags)) { + asmptr = skb_push(skb, 1); + *asmptr = 0x00; + } + + msg->msg_flags |= MSG_OOB; + } else { + /* Now we can treat all alike */ + release_sock(sk); + skb = skb_recv_datagram(sk, flags, &rc); + lock_sock(sk); + if (!skb) + goto out; + + if (!pskb_may_pull(skb, header_len)) + goto out_free_dgram; + + qbit = (skb->data[0] & X25_Q_BIT) == X25_Q_BIT; + + skb_pull(skb, header_len); + + if (test_bit(X25_Q_BIT_FLAG, &x25->flags)) { + asmptr = skb_push(skb, 1); + *asmptr = qbit; + } + } + + skb_reset_transport_header(skb); + copied = skb->len; + + if (copied > size) { + copied = size; + msg->msg_flags |= MSG_TRUNC; + } + + /* Currently, each datagram always contains a complete record */ + msg->msg_flags |= MSG_EOR; + + rc = skb_copy_datagram_msg(skb, 0, msg, copied); + if (rc) + goto out_free_dgram; + + if (sx25) { + sx25->sx25_family = AF_X25; + sx25->sx25_addr = x25->dest_addr; + msg->msg_namelen = sizeof(*sx25); + } + + x25_check_rbuf(sk); + rc = copied; +out_free_dgram: + skb_free_datagram(sk, skb); +out: + release_sock(sk); + return rc; +} + + +static int x25_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg) +{ + struct sock *sk = sock->sk; + struct x25_sock *x25 = x25_sk(sk); + void __user *argp = (void __user *)arg; + int rc; + + switch (cmd) { + case TIOCOUTQ: { + int amount; + + amount = sk->sk_sndbuf - sk_wmem_alloc_get(sk); + if (amount < 0) + amount = 0; + rc = put_user(amount, (unsigned int __user *)argp); + break; + } + + case TIOCINQ: { + struct sk_buff *skb; + int amount = 0; + /* + * These two are safe on a single CPU system as + * only user tasks fiddle here + */ + lock_sock(sk); + if ((skb = skb_peek(&sk->sk_receive_queue)) != NULL) + amount = skb->len; + release_sock(sk); + rc = put_user(amount, (unsigned int __user *)argp); + break; + } + + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + rc = -EINVAL; + break; + case SIOCADDRT: + case SIOCDELRT: + rc = -EPERM; + if (!capable(CAP_NET_ADMIN)) + break; + rc = x25_route_ioctl(cmd, argp); + break; + case SIOCX25GSUBSCRIP: + rc = x25_subscr_ioctl(cmd, argp); + break; + case SIOCX25SSUBSCRIP: + rc = -EPERM; + if (!capable(CAP_NET_ADMIN)) + break; + rc = x25_subscr_ioctl(cmd, argp); + break; + case SIOCX25GFACILITIES: { + lock_sock(sk); + rc = copy_to_user(argp, &x25->facilities, + sizeof(x25->facilities)) + ? 
-EFAULT : 0; + release_sock(sk); + break; + } + + case SIOCX25SFACILITIES: { + struct x25_facilities facilities; + rc = -EFAULT; + if (copy_from_user(&facilities, argp, sizeof(facilities))) + break; + rc = -EINVAL; + lock_sock(sk); + if (sk->sk_state != TCP_LISTEN && + sk->sk_state != TCP_CLOSE) + goto out_fac_release; + if (facilities.pacsize_in < X25_PS16 || + facilities.pacsize_in > X25_PS4096) + goto out_fac_release; + if (facilities.pacsize_out < X25_PS16 || + facilities.pacsize_out > X25_PS4096) + goto out_fac_release; + if (facilities.winsize_in < 1 || + facilities.winsize_in > 127) + goto out_fac_release; + if (facilities.throughput) { + int out = facilities.throughput & 0xf0; + int in = facilities.throughput & 0x0f; + if (!out) + facilities.throughput |= + X25_DEFAULT_THROUGHPUT << 4; + else if (out < 0x30 || out > 0xD0) + goto out_fac_release; + if (!in) + facilities.throughput |= + X25_DEFAULT_THROUGHPUT; + else if (in < 0x03 || in > 0x0D) + goto out_fac_release; + } + if (facilities.reverse && + (facilities.reverse & 0x81) != 0x81) + goto out_fac_release; + x25->facilities = facilities; + rc = 0; +out_fac_release: + release_sock(sk); + break; + } + + case SIOCX25GDTEFACILITIES: { + lock_sock(sk); + rc = copy_to_user(argp, &x25->dte_facilities, + sizeof(x25->dte_facilities)); + release_sock(sk); + if (rc) + rc = -EFAULT; + break; + } + + case SIOCX25SDTEFACILITIES: { + struct x25_dte_facilities dtefacs; + rc = -EFAULT; + if (copy_from_user(&dtefacs, argp, sizeof(dtefacs))) + break; + rc = -EINVAL; + lock_sock(sk); + if (sk->sk_state != TCP_LISTEN && + sk->sk_state != TCP_CLOSE) + goto out_dtefac_release; + if (dtefacs.calling_len > X25_MAX_AE_LEN) + goto out_dtefac_release; + if (dtefacs.called_len > X25_MAX_AE_LEN) + goto out_dtefac_release; + x25->dte_facilities = dtefacs; + rc = 0; +out_dtefac_release: + release_sock(sk); + break; + } + + case SIOCX25GCALLUSERDATA: { + lock_sock(sk); + rc = copy_to_user(argp, &x25->calluserdata, + sizeof(x25->calluserdata)) + ? -EFAULT : 0; + release_sock(sk); + break; + } + + case SIOCX25SCALLUSERDATA: { + struct x25_calluserdata calluserdata; + + rc = -EFAULT; + if (copy_from_user(&calluserdata, argp, sizeof(calluserdata))) + break; + rc = -EINVAL; + if (calluserdata.cudlength > X25_MAX_CUD_LEN) + break; + lock_sock(sk); + x25->calluserdata = calluserdata; + release_sock(sk); + rc = 0; + break; + } + + case SIOCX25GCAUSEDIAG: { + lock_sock(sk); + rc = copy_to_user(argp, &x25->causediag, sizeof(x25->causediag)) + ? 
-EFAULT : 0; + release_sock(sk); + break; + } + + case SIOCX25SCAUSEDIAG: { + struct x25_causediag causediag; + rc = -EFAULT; + if (copy_from_user(&causediag, argp, sizeof(causediag))) + break; + lock_sock(sk); + x25->causediag = causediag; + release_sock(sk); + rc = 0; + break; + + } + + case SIOCX25SCUDMATCHLEN: { + struct x25_subaddr sub_addr; + rc = -EINVAL; + lock_sock(sk); + if(sk->sk_state != TCP_CLOSE) + goto out_cud_release; + rc = -EFAULT; + if (copy_from_user(&sub_addr, argp, + sizeof(sub_addr))) + goto out_cud_release; + rc = -EINVAL; + if (sub_addr.cudmatchlength > X25_MAX_CUD_LEN) + goto out_cud_release; + x25->cudmatchlength = sub_addr.cudmatchlength; + rc = 0; +out_cud_release: + release_sock(sk); + break; + } + + case SIOCX25CALLACCPTAPPRV: { + rc = -EINVAL; + lock_sock(sk); + if (sk->sk_state == TCP_CLOSE) { + clear_bit(X25_ACCPT_APPRV_FLAG, &x25->flags); + rc = 0; + } + release_sock(sk); + break; + } + + case SIOCX25SENDCALLACCPT: { + rc = -EINVAL; + lock_sock(sk); + if (sk->sk_state != TCP_ESTABLISHED) + goto out_sendcallaccpt_release; + /* must call accptapprv above */ + if (test_bit(X25_ACCPT_APPRV_FLAG, &x25->flags)) + goto out_sendcallaccpt_release; + x25_write_internal(sk, X25_CALL_ACCEPTED); + x25->state = X25_STATE_3; + rc = 0; +out_sendcallaccpt_release: + release_sock(sk); + break; + } + + default: + rc = -ENOIOCTLCMD; + break; + } + + return rc; +} + +static const struct net_proto_family x25_family_ops = { + .family = AF_X25, + .create = x25_create, + .owner = THIS_MODULE, +}; + +#ifdef CONFIG_COMPAT +static int compat_x25_subscr_ioctl(unsigned int cmd, + struct compat_x25_subscrip_struct __user *x25_subscr32) +{ + struct compat_x25_subscrip_struct x25_subscr; + struct x25_neigh *nb; + struct net_device *dev; + int rc = -EINVAL; + + rc = -EFAULT; + if (copy_from_user(&x25_subscr, x25_subscr32, sizeof(*x25_subscr32))) + goto out; + + rc = -EINVAL; + dev = x25_dev_get(x25_subscr.device); + if (dev == NULL) + goto out; + + nb = x25_get_neigh(dev); + if (nb == NULL) + goto out_dev_put; + + dev_put(dev); + + if (cmd == SIOCX25GSUBSCRIP) { + read_lock_bh(&x25_neigh_list_lock); + x25_subscr.extended = nb->extended; + x25_subscr.global_facil_mask = nb->global_facil_mask; + read_unlock_bh(&x25_neigh_list_lock); + rc = copy_to_user(x25_subscr32, &x25_subscr, + sizeof(*x25_subscr32)) ? 
-EFAULT : 0; + } else { + rc = -EINVAL; + if (x25_subscr.extended == 0 || x25_subscr.extended == 1) { + rc = 0; + write_lock_bh(&x25_neigh_list_lock); + nb->extended = x25_subscr.extended; + nb->global_facil_mask = x25_subscr.global_facil_mask; + write_unlock_bh(&x25_neigh_list_lock); + } + } + x25_neigh_put(nb); +out: + return rc; +out_dev_put: + dev_put(dev); + goto out; +} + +static int compat_x25_ioctl(struct socket *sock, unsigned int cmd, + unsigned long arg) +{ + void __user *argp = compat_ptr(arg); + int rc = -ENOIOCTLCMD; + + switch(cmd) { + case TIOCOUTQ: + case TIOCINQ: + rc = x25_ioctl(sock, cmd, (unsigned long)argp); + break; + case SIOCGIFADDR: + case SIOCSIFADDR: + case SIOCGIFDSTADDR: + case SIOCSIFDSTADDR: + case SIOCGIFBRDADDR: + case SIOCSIFBRDADDR: + case SIOCGIFNETMASK: + case SIOCSIFNETMASK: + case SIOCGIFMETRIC: + case SIOCSIFMETRIC: + rc = -EINVAL; + break; + case SIOCADDRT: + case SIOCDELRT: + rc = -EPERM; + if (!capable(CAP_NET_ADMIN)) + break; + rc = x25_route_ioctl(cmd, argp); + break; + case SIOCX25GSUBSCRIP: + rc = compat_x25_subscr_ioctl(cmd, argp); + break; + case SIOCX25SSUBSCRIP: + rc = -EPERM; + if (!capable(CAP_NET_ADMIN)) + break; + rc = compat_x25_subscr_ioctl(cmd, argp); + break; + case SIOCX25GFACILITIES: + case SIOCX25SFACILITIES: + case SIOCX25GDTEFACILITIES: + case SIOCX25SDTEFACILITIES: + case SIOCX25GCALLUSERDATA: + case SIOCX25SCALLUSERDATA: + case SIOCX25GCAUSEDIAG: + case SIOCX25SCAUSEDIAG: + case SIOCX25SCUDMATCHLEN: + case SIOCX25CALLACCPTAPPRV: + case SIOCX25SENDCALLACCPT: + rc = x25_ioctl(sock, cmd, (unsigned long)argp); + break; + default: + rc = -ENOIOCTLCMD; + break; + } + return rc; +} +#endif + +static const struct proto_ops x25_proto_ops = { + .family = AF_X25, + .owner = THIS_MODULE, + .release = x25_release, + .bind = x25_bind, + .connect = x25_connect, + .socketpair = sock_no_socketpair, + .accept = x25_accept, + .getname = x25_getname, + .poll = datagram_poll, + .ioctl = x25_ioctl, +#ifdef CONFIG_COMPAT + .compat_ioctl = compat_x25_ioctl, +#endif + .gettstamp = sock_gettstamp, + .listen = x25_listen, + .shutdown = sock_no_shutdown, + .setsockopt = x25_setsockopt, + .getsockopt = x25_getsockopt, + .sendmsg = x25_sendmsg, + .recvmsg = x25_recvmsg, + .mmap = sock_no_mmap, + .sendpage = sock_no_sendpage, +}; + +static struct packet_type x25_packet_type __read_mostly = { + .type = cpu_to_be16(ETH_P_X25), + .func = x25_lapb_receive_frame, +}; + +static struct notifier_block x25_dev_notifier = { + .notifier_call = x25_device_event, +}; + +void x25_kill_by_neigh(struct x25_neigh *nb) +{ + struct sock *s; + + write_lock_bh(&x25_list_lock); + + sk_for_each(s, &x25_list) { + if (x25_sk(s)->neighbour == nb) { + write_unlock_bh(&x25_list_lock); + lock_sock(s); + x25_disconnect(s, ENETUNREACH, 0, 0); + release_sock(s); + write_lock_bh(&x25_list_lock); + } + } + write_unlock_bh(&x25_list_lock); + + /* Remove any related forwards */ + x25_clear_forward_by_dev(nb->dev); +} + +static int __init x25_init(void) +{ + int rc; + + rc = proto_register(&x25_proto, 0); + if (rc) + goto out; + + rc = sock_register(&x25_family_ops); + if (rc) + goto out_proto; + + dev_add_pack(&x25_packet_type); + + rc = register_netdevice_notifier(&x25_dev_notifier); + if (rc) + goto out_sock; + + rc = x25_register_sysctl(); + if (rc) + goto out_dev; + + rc = x25_proc_init(); + if (rc) + goto out_sysctl; + + pr_info("Linux Version 0.2\n"); + +out: + return rc; +out_sysctl: + x25_unregister_sysctl(); +out_dev: + unregister_netdevice_notifier(&x25_dev_notifier); +out_sock: 
+ dev_remove_pack(&x25_packet_type); + sock_unregister(AF_X25); +out_proto: + proto_unregister(&x25_proto); + goto out; +} +module_init(x25_init); + +static void __exit x25_exit(void) +{ + x25_proc_exit(); + x25_link_free(); + x25_route_free(); + + x25_unregister_sysctl(); + + unregister_netdevice_notifier(&x25_dev_notifier); + + dev_remove_pack(&x25_packet_type); + + sock_unregister(AF_X25); + proto_unregister(&x25_proto); +} +module_exit(x25_exit); + +MODULE_AUTHOR("Jonathan Naylor "); +MODULE_DESCRIPTION("The X.25 Packet Layer network layer protocol"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NETPROTO(PF_X25); diff --git a/net/x25/sysctl_net_x25.c b/net/x25/sysctl_net_x25.c new file mode 100644 index 000000000..e9802afa4 --- /dev/null +++ b/net/x25/sysctl_net_x25.c @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: GPL-2.0 +/* -*- linux-c -*- + * sysctl_net_x25.c: sysctl interface to net X.25 subsystem. + * + * Begun April 1, 1996, Mike Shaver. + * Added /proc/sys/net/x25 directory entry (empty =) ). [MS] + */ + +#include +#include +#include +#include +#include +#include + +static int min_timer[] = { 1 * HZ }; +static int max_timer[] = { 300 * HZ }; + +static struct ctl_table_header *x25_table_header; + +static struct ctl_table x25_table[] = { + { + .procname = "restart_request_timeout", + .data = &sysctl_x25_restart_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer, + }, + { + .procname = "call_request_timeout", + .data = &sysctl_x25_call_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer, + }, + { + .procname = "reset_request_timeout", + .data = &sysctl_x25_reset_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer, + }, + { + .procname = "clear_request_timeout", + .data = &sysctl_x25_clear_request_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer, + }, + { + .procname = "acknowledgement_hold_back_timeout", + .data = &sysctl_x25_ack_holdback_timeout, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_minmax, + .extra1 = &min_timer, + .extra2 = &max_timer, + }, + { + .procname = "x25_forward", + .data = &sysctl_x25_forward, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { }, +}; + +int __init x25_register_sysctl(void) +{ + x25_table_header = register_net_sysctl(&init_net, "net/x25", x25_table); + if (!x25_table_header) + return -ENOMEM; + return 0; +} + +void x25_unregister_sysctl(void) +{ + unregister_net_sysctl_table(x25_table_header); +} diff --git a/net/x25/x25_dev.c b/net/x25/x25_dev.c new file mode 100644 index 000000000..748d8630a --- /dev/null +++ b/net/x25/x25_dev.c @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, randomly fail to work with new + * releases, misbehave and/or generally screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * 2000-09-04 Henner Eisen Prevent freeing a dangling skb. 
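+ *
+ * This file implements the interface between the X.25 packet layer and
+ * the LAPB data link devices (frame reception, link establishment and
+ * termination, and frame transmission).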
+ */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +static int x25_receive_data(struct sk_buff *skb, struct x25_neigh *nb) +{ + struct sock *sk; + unsigned short frametype; + unsigned int lci; + + if (!pskb_may_pull(skb, X25_STD_MIN_LEN)) + return 0; + + frametype = skb->data[2]; + lci = ((skb->data[0] << 8) & 0xF00) + ((skb->data[1] << 0) & 0x0FF); + + /* + * LCI of zero is always for us, and its always a link control + * frame. + */ + if (lci == 0) { + x25_link_control(skb, nb, frametype); + return 0; + } + + /* + * Find an existing socket. + */ + if ((sk = x25_find_socket(lci, nb)) != NULL) { + int queued = 1; + + skb_reset_transport_header(skb); + bh_lock_sock(sk); + if (!sock_owned_by_user(sk)) { + queued = x25_process_rx_frame(sk, skb); + } else { + queued = !sk_add_backlog(sk, skb, READ_ONCE(sk->sk_rcvbuf)); + } + bh_unlock_sock(sk); + sock_put(sk); + return queued; + } + + /* + * Is is a Call Request ? if so process it. + */ + if (frametype == X25_CALL_REQUEST) + return x25_rx_call_request(skb, nb, lci); + + /* + * Its not a Call Request, nor is it a control frame. + * Can we forward it? + */ + + if (x25_forward_data(lci, nb, skb)) { + if (frametype == X25_CLEAR_CONFIRMATION) { + x25_clear_forward_by_lci(lci); + } + kfree_skb(skb); + return 1; + } + +/* + x25_transmit_clear_request(nb, lci, 0x0D); +*/ + + if (frametype != X25_CLEAR_CONFIRMATION) + pr_debug("x25_receive_data(): unknown frame type %2x\n",frametype); + + return 0; +} + +int x25_lapb_receive_frame(struct sk_buff *skb, struct net_device *dev, + struct packet_type *ptype, struct net_device *orig_dev) +{ + struct sk_buff *nskb; + struct x25_neigh *nb; + + if (!net_eq(dev_net(dev), &init_net)) + goto drop; + + nskb = skb_copy(skb, GFP_ATOMIC); + if (!nskb) + goto drop; + kfree_skb(skb); + skb = nskb; + + /* + * Packet received from unrecognised device, throw it away. 
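+ * An unrecognised device is one for which no X.25 neighbour has been
+ * registered.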
+ */ + nb = x25_get_neigh(dev); + if (!nb) { + pr_debug("unknown neighbour - %s\n", dev->name); + goto drop; + } + + if (!pskb_may_pull(skb, 1)) { + x25_neigh_put(nb); + goto drop; + } + + switch (skb->data[0]) { + + case X25_IFACE_DATA: + skb_pull(skb, 1); + if (x25_receive_data(skb, nb)) { + x25_neigh_put(nb); + goto out; + } + break; + + case X25_IFACE_CONNECT: + x25_link_established(nb); + break; + + case X25_IFACE_DISCONNECT: + x25_link_terminated(nb); + break; + } + x25_neigh_put(nb); +drop: + kfree_skb(skb); +out: + return 0; +} + +void x25_establish_link(struct x25_neigh *nb) +{ + struct sk_buff *skb; + unsigned char *ptr; + + switch (nb->dev->type) { + case ARPHRD_X25: + if ((skb = alloc_skb(1, GFP_ATOMIC)) == NULL) { + pr_err("x25_dev: out of memory\n"); + return; + } + ptr = skb_put(skb, 1); + *ptr = X25_IFACE_CONNECT; + break; + + default: + return; + } + + skb->protocol = htons(ETH_P_X25); + skb->dev = nb->dev; + + dev_queue_xmit(skb); +} + +void x25_terminate_link(struct x25_neigh *nb) +{ + struct sk_buff *skb; + unsigned char *ptr; + + if (nb->dev->type != ARPHRD_X25) + return; + + skb = alloc_skb(1, GFP_ATOMIC); + if (!skb) { + pr_err("x25_dev: out of memory\n"); + return; + } + + ptr = skb_put(skb, 1); + *ptr = X25_IFACE_DISCONNECT; + + skb->protocol = htons(ETH_P_X25); + skb->dev = nb->dev; + dev_queue_xmit(skb); +} + +void x25_send_frame(struct sk_buff *skb, struct x25_neigh *nb) +{ + unsigned char *dptr; + + skb_reset_network_header(skb); + + switch (nb->dev->type) { + case ARPHRD_X25: + dptr = skb_push(skb, 1); + *dptr = X25_IFACE_DATA; + break; + + default: + kfree_skb(skb); + return; + } + + skb->protocol = htons(ETH_P_X25); + skb->dev = nb->dev; + + dev_queue_xmit(skb); +} diff --git a/net/x25/x25_facilities.c b/net/x25/x25_facilities.c new file mode 100644 index 000000000..8e1a49b0c --- /dev/null +++ b/net/x25/x25_facilities.c @@ -0,0 +1,350 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Split from x25_subr.c + * mar/20/00 Daniela Squassoni Disabling/enabling of facilities + * negotiation. + * apr/14/05 Shaun Pereira - Allow fast select with no restriction + * on response. + */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include + +/** + * x25_parse_facilities - Parse facilities from skb into the facilities structs + * + * @skb: sk_buff to parse + * @facilities: Regular facilities, updated as facilities are found + * @dte_facs: ITU DTE facilities, updated as DTE facilities are found + * @vc_fac_mask: mask is updated with all facilities found + * + * Return codes: + * -1 - Parsing error, caller should drop call and clean up + * 0 - Parse OK, this skb has no facilities + * >0 - Parse OK, returns the length of the facilities header + * + */ +int x25_parse_facilities(struct sk_buff *skb, struct x25_facilities *facilities, + struct x25_dte_facilities *dte_facs, unsigned long *vc_fac_mask) +{ + unsigned char *p; + unsigned int len; + + *vc_fac_mask = 0; + + /* + * The kernel knows which facilities were set on an incoming call but + * currently this information is not available to userspace. Here we + * give userspace who read incoming call facilities 0 length to indicate + * it wasn't set. 
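+ *
+ * Facilities are decoded per class: class A entries occupy two octets,
+ * class B three, class C four, and class D a variable number given by
+ * their length octet.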
+ */ + dte_facs->calling_len = 0; + dte_facs->called_len = 0; + memset(dte_facs->called_ae, '\0', sizeof(dte_facs->called_ae)); + memset(dte_facs->calling_ae, '\0', sizeof(dte_facs->calling_ae)); + + if (!pskb_may_pull(skb, 1)) + return 0; + + len = skb->data[0]; + + if (!pskb_may_pull(skb, 1 + len)) + return -1; + + p = skb->data + 1; + + while (len > 0) { + switch (*p & X25_FAC_CLASS_MASK) { + case X25_FAC_CLASS_A: + if (len < 2) + return -1; + switch (*p) { + case X25_FAC_REVERSE: + if((p[1] & 0x81) == 0x81) { + facilities->reverse = p[1] & 0x81; + *vc_fac_mask |= X25_MASK_REVERSE; + break; + } + + if((p[1] & 0x01) == 0x01) { + facilities->reverse = p[1] & 0x01; + *vc_fac_mask |= X25_MASK_REVERSE; + break; + } + + if((p[1] & 0x80) == 0x80) { + facilities->reverse = p[1] & 0x80; + *vc_fac_mask |= X25_MASK_REVERSE; + break; + } + + if(p[1] == 0x00) { + facilities->reverse + = X25_DEFAULT_REVERSE; + *vc_fac_mask |= X25_MASK_REVERSE; + break; + } + fallthrough; + case X25_FAC_THROUGHPUT: + facilities->throughput = p[1]; + *vc_fac_mask |= X25_MASK_THROUGHPUT; + break; + case X25_MARKER: + break; + default: + pr_debug("unknown facility " + "%02X, value %02X\n", + p[0], p[1]); + break; + } + p += 2; + len -= 2; + break; + case X25_FAC_CLASS_B: + if (len < 3) + return -1; + switch (*p) { + case X25_FAC_PACKET_SIZE: + facilities->pacsize_in = p[1]; + facilities->pacsize_out = p[2]; + *vc_fac_mask |= X25_MASK_PACKET_SIZE; + break; + case X25_FAC_WINDOW_SIZE: + facilities->winsize_in = p[1]; + facilities->winsize_out = p[2]; + *vc_fac_mask |= X25_MASK_WINDOW_SIZE; + break; + default: + pr_debug("unknown facility " + "%02X, values %02X, %02X\n", + p[0], p[1], p[2]); + break; + } + p += 3; + len -= 3; + break; + case X25_FAC_CLASS_C: + if (len < 4) + return -1; + pr_debug("unknown facility %02X, " + "values %02X, %02X, %02X\n", + p[0], p[1], p[2], p[3]); + p += 4; + len -= 4; + break; + case X25_FAC_CLASS_D: + if (len < p[1] + 2) + return -1; + switch (*p) { + case X25_FAC_CALLING_AE: + if (p[1] > X25_MAX_DTE_FACIL_LEN || p[1] <= 1) + return -1; + if (p[2] > X25_MAX_AE_LEN) + return -1; + dte_facs->calling_len = p[2]; + memcpy(dte_facs->calling_ae, &p[3], p[1] - 1); + *vc_fac_mask |= X25_MASK_CALLING_AE; + break; + case X25_FAC_CALLED_AE: + if (p[1] > X25_MAX_DTE_FACIL_LEN || p[1] <= 1) + return -1; + if (p[2] > X25_MAX_AE_LEN) + return -1; + dte_facs->called_len = p[2]; + memcpy(dte_facs->called_ae, &p[3], p[1] - 1); + *vc_fac_mask |= X25_MASK_CALLED_AE; + break; + default: + pr_debug("unknown facility %02X," + "length %d\n", p[0], p[1]); + break; + } + len -= p[1] + 2; + p += p[1] + 2; + break; + } + } + + return p - skb->data; +} + +/* + * Create a set of facilities. + */ +int x25_create_facilities(unsigned char *buffer, + struct x25_facilities *facilities, + struct x25_dte_facilities *dte_facs, unsigned long facil_mask) +{ + unsigned char *p = buffer + 1; + int len; + + if (!facil_mask) { + /* + * Length of the facilities field in call_req or + * call_accept packets + */ + buffer[0] = 0; + len = 1; /* 1 byte for the length field */ + return len; + } + + if (facilities->reverse && (facil_mask & X25_MASK_REVERSE)) { + *p++ = X25_FAC_REVERSE; + *p++ = facilities->reverse; + } + + if (facilities->throughput && (facil_mask & X25_MASK_THROUGHPUT)) { + *p++ = X25_FAC_THROUGHPUT; + *p++ = facilities->throughput; + } + + if ((facilities->pacsize_in || facilities->pacsize_out) && + (facil_mask & X25_MASK_PACKET_SIZE)) { + *p++ = X25_FAC_PACKET_SIZE; + *p++ = facilities->pacsize_in ? 
: facilities->pacsize_out; + *p++ = facilities->pacsize_out ? : facilities->pacsize_in; + } + + if ((facilities->winsize_in || facilities->winsize_out) && + (facil_mask & X25_MASK_WINDOW_SIZE)) { + *p++ = X25_FAC_WINDOW_SIZE; + *p++ = facilities->winsize_in ? : facilities->winsize_out; + *p++ = facilities->winsize_out ? : facilities->winsize_in; + } + + if (facil_mask & (X25_MASK_CALLING_AE|X25_MASK_CALLED_AE)) { + *p++ = X25_MARKER; + *p++ = X25_DTE_SERVICES; + } + + if (dte_facs->calling_len && (facil_mask & X25_MASK_CALLING_AE)) { + unsigned int bytecount = (dte_facs->calling_len + 1) >> 1; + *p++ = X25_FAC_CALLING_AE; + *p++ = 1 + bytecount; + *p++ = dte_facs->calling_len; + memcpy(p, dte_facs->calling_ae, bytecount); + p += bytecount; + } + + if (dte_facs->called_len && (facil_mask & X25_MASK_CALLED_AE)) { + unsigned int bytecount = (dte_facs->called_len % 2) ? + dte_facs->called_len / 2 + 1 : + dte_facs->called_len / 2; + *p++ = X25_FAC_CALLED_AE; + *p++ = 1 + bytecount; + *p++ = dte_facs->called_len; + memcpy(p, dte_facs->called_ae, bytecount); + p+=bytecount; + } + + len = p - buffer; + buffer[0] = len - 1; + + return len; +} + +/* + * Try to reach a compromise on a set of facilities. + * + * The only real problem is with reverse charging. + */ +int x25_negotiate_facilities(struct sk_buff *skb, struct sock *sk, + struct x25_facilities *new, struct x25_dte_facilities *dte) +{ + struct x25_sock *x25 = x25_sk(sk); + struct x25_facilities *ours = &x25->facilities; + struct x25_facilities theirs; + int len; + + memset(&theirs, 0, sizeof(theirs)); + memcpy(new, ours, sizeof(*new)); + memset(dte, 0, sizeof(*dte)); + + len = x25_parse_facilities(skb, &theirs, dte, &x25->vc_facil_mask); + if (len < 0) + return len; + + /* + * They want reverse charging, we won't accept it. + */ + if ((theirs.reverse & 0x01 ) && (ours->reverse & 0x01)) { + SOCK_DEBUG(sk, "X.25: rejecting reverse charging request\n"); + return -1; + } + + new->reverse = theirs.reverse; + + if (theirs.throughput) { + int theirs_in = theirs.throughput & 0x0f; + int theirs_out = theirs.throughput & 0xf0; + int ours_in = ours->throughput & 0x0f; + int ours_out = ours->throughput & 0xf0; + if (!ours_in || theirs_in < ours_in) { + SOCK_DEBUG(sk, "X.25: inbound throughput negotiated\n"); + new->throughput = (new->throughput & 0xf0) | theirs_in; + } + if (!ours_out || theirs_out < ours_out) { + SOCK_DEBUG(sk, + "X.25: outbound throughput negotiated\n"); + new->throughput = (new->throughput & 0x0f) | theirs_out; + } + } + + if (theirs.pacsize_in && theirs.pacsize_out) { + if (theirs.pacsize_in < ours->pacsize_in) { + SOCK_DEBUG(sk, "X.25: packet size inwards negotiated down\n"); + new->pacsize_in = theirs.pacsize_in; + } + if (theirs.pacsize_out < ours->pacsize_out) { + SOCK_DEBUG(sk, "X.25: packet size outwards negotiated down\n"); + new->pacsize_out = theirs.pacsize_out; + } + } + + if (theirs.winsize_in && theirs.winsize_out) { + if (theirs.winsize_in < ours->winsize_in) { + SOCK_DEBUG(sk, "X.25: window size inwards negotiated down\n"); + new->winsize_in = theirs.winsize_in; + } + if (theirs.winsize_out < ours->winsize_out) { + SOCK_DEBUG(sk, "X.25: window size outwards negotiated down\n"); + new->winsize_out = theirs.winsize_out; + } + } + + return len; +} + +/* + * Limit values of certain facilities according to the capability of the + * currently attached x25 link. 
+ */ +void x25_limit_facilities(struct x25_facilities *facilities, + struct x25_neigh *nb) +{ + + if (!nb->extended) { + if (facilities->winsize_in > 7) { + pr_debug("incoming winsize limited to 7\n"); + facilities->winsize_in = 7; + } + if (facilities->winsize_out > 7) { + facilities->winsize_out = 7; + pr_debug("outgoing winsize limited to 7\n"); + } + } +} diff --git a/net/x25/x25_forward.c b/net/x25/x25_forward.c new file mode 100644 index 000000000..21b30b56e --- /dev/null +++ b/net/x25/x25_forward.c @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * History + * 03-01-2007 Added forwarding for x.25 Andrew Hendry + */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include + +LIST_HEAD(x25_forward_list); +DEFINE_RWLOCK(x25_forward_list_lock); + +int x25_forward_call(struct x25_address *dest_addr, struct x25_neigh *from, + struct sk_buff *skb, int lci) +{ + struct x25_route *rt; + struct x25_neigh *neigh_new = NULL; + struct x25_forward *x25_frwd, *new_frwd; + struct sk_buff *skbn; + short same_lci = 0; + int rc = 0; + + if ((rt = x25_get_route(dest_addr)) == NULL) + goto out_no_route; + + if ((neigh_new = x25_get_neigh(rt->dev)) == NULL) { + /* This shouldn't happen, if it occurs somehow + * do something sensible + */ + goto out_put_route; + } + + /* Avoid a loop. This is the normal exit path for a + * system with only one x.25 iface and default route + */ + if (rt->dev == from->dev) { + goto out_put_nb; + } + + /* Remote end sending a call request on an already + * established LCI? It shouldn't happen, just in case.. + */ + read_lock_bh(&x25_forward_list_lock); + list_for_each_entry(x25_frwd, &x25_forward_list, node) { + if (x25_frwd->lci == lci) { + pr_warn("call request for lci which is already registered!, transmitting but not registering new pair\n"); + same_lci = 1; + } + } + read_unlock_bh(&x25_forward_list_lock); + + /* Save the forwarding details for future traffic */ + if (!same_lci){ + if ((new_frwd = kmalloc(sizeof(struct x25_forward), + GFP_ATOMIC)) == NULL){ + rc = -ENOMEM; + goto out_put_nb; + } + new_frwd->lci = lci; + new_frwd->dev1 = rt->dev; + new_frwd->dev2 = from->dev; + write_lock_bh(&x25_forward_list_lock); + list_add(&new_frwd->node, &x25_forward_list); + write_unlock_bh(&x25_forward_list_lock); + } + + /* Forward the call request */ + if ( (skbn = skb_clone(skb, GFP_ATOMIC)) == NULL){ + goto out_put_nb; + } + x25_transmit_link(skbn, neigh_new); + rc = 1; + + +out_put_nb: + x25_neigh_put(neigh_new); + +out_put_route: + x25_route_put(rt); + +out_no_route: + return rc; +} + + +int x25_forward_data(int lci, struct x25_neigh *from, struct sk_buff *skb) { + + struct x25_forward *frwd; + struct net_device *peer = NULL; + struct x25_neigh *nb; + struct sk_buff *skbn; + int rc = 0; + + read_lock_bh(&x25_forward_list_lock); + list_for_each_entry(frwd, &x25_forward_list, node) { + if (frwd->lci == lci) { + /* The call is established, either side can send */ + if (from->dev == frwd->dev1) { + peer = frwd->dev2; + } else { + peer = frwd->dev1; + } + break; + } + } + read_unlock_bh(&x25_forward_list_lock); + + if ( (nb = x25_get_neigh(peer)) == NULL) + goto out; + + if ( (skbn = pskb_copy(skb, GFP_ATOMIC)) == NULL){ + goto output; + + } + x25_transmit_link(skbn, nb); + + rc = 1; +output: + x25_neigh_put(nb); +out: + return rc; +} + +void x25_clear_forward_by_lci(unsigned int lci) +{ + struct x25_forward *fwd, *tmp; + + write_lock_bh(&x25_forward_list_lock); + + list_for_each_entry_safe(fwd, tmp, &x25_forward_list, node) { + if 
(fwd->lci == lci) { + list_del(&fwd->node); + kfree(fwd); + } + } + write_unlock_bh(&x25_forward_list_lock); +} + + +void x25_clear_forward_by_dev(struct net_device *dev) +{ + struct x25_forward *fwd, *tmp; + + write_lock_bh(&x25_forward_list_lock); + + list_for_each_entry_safe(fwd, tmp, &x25_forward_list, node) { + if ((fwd->dev1 == dev) || (fwd->dev2 == dev)){ + list_del(&fwd->node); + kfree(fwd); + } + } + write_unlock_bh(&x25_forward_list_lock); +} diff --git a/net/x25/x25_in.c b/net/x25/x25_in.c new file mode 100644 index 000000000..b981a4828 --- /dev/null +++ b/net/x25/x25_in.c @@ -0,0 +1,456 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor Centralised disconnection code. + * New timer architecture. + * 2000-03-20 Daniela Squassoni Disabling/enabling of facilities + * negotiation. + * 2000-11-10 Henner Eisen Check and reset for out-of-sequence + * i-frames. + */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include + +static int x25_queue_rx_frame(struct sock *sk, struct sk_buff *skb, int more) +{ + struct sk_buff *skbo, *skbn = skb; + struct x25_sock *x25 = x25_sk(sk); + + if (more) { + x25->fraglen += skb->len; + skb_queue_tail(&x25->fragment_queue, skb); + skb_set_owner_r(skb, sk); + return 0; + } + + if (x25->fraglen > 0) { /* End of fragment */ + int len = x25->fraglen + skb->len; + + if ((skbn = alloc_skb(len, GFP_ATOMIC)) == NULL){ + kfree_skb(skb); + return 1; + } + + skb_queue_tail(&x25->fragment_queue, skb); + + skb_reset_transport_header(skbn); + + skbo = skb_dequeue(&x25->fragment_queue); + skb_copy_from_linear_data(skbo, skb_put(skbn, skbo->len), + skbo->len); + kfree_skb(skbo); + + while ((skbo = + skb_dequeue(&x25->fragment_queue)) != NULL) { + skb_pull(skbo, (x25->neighbour->extended) ? + X25_EXT_MIN_LEN : X25_STD_MIN_LEN); + skb_copy_from_linear_data(skbo, + skb_put(skbn, skbo->len), + skbo->len); + kfree_skb(skbo); + } + + x25->fraglen = 0; + } + + skb_set_owner_r(skbn, sk); + skb_queue_tail(&sk->sk_receive_queue, skbn); + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_data_ready(sk); + + return 0; +} + +/* + * State machine for state 1, Awaiting Call Accepted State. + * The handling of the timer(s) is in file x25_timer.c. + * Handling of state 0 and connection release is in af_x25.c. + */ +static int x25_state1_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct x25_address source_addr, dest_addr; + int len; + struct x25_sock *x25 = x25_sk(sk); + + switch (frametype) { + case X25_CALL_ACCEPTED: { + + x25_stop_timer(sk); + x25->condition = 0x00; + x25->vs = 0; + x25->va = 0; + x25->vr = 0; + x25->vl = 0; + x25->state = X25_STATE_3; + sk->sk_state = TCP_ESTABLISHED; + /* + * Parse the data in the frame. + */ + if (!pskb_may_pull(skb, X25_STD_MIN_LEN)) + goto out_clear; + skb_pull(skb, X25_STD_MIN_LEN); + + len = x25_parse_address_block(skb, &source_addr, + &dest_addr); + if (len > 0) + skb_pull(skb, len); + else if (len < 0) + goto out_clear; + + len = x25_parse_facilities(skb, &x25->facilities, + &x25->dte_facilities, + &x25->vc_facil_mask); + if (len > 0) + skb_pull(skb, len); + else if (len < 0) + goto out_clear; + /* + * Copy any Call User Data. 
+ */ + if (skb->len > 0) { + if (skb->len > X25_MAX_CUD_LEN) + goto out_clear; + + skb_copy_bits(skb, 0, x25->calluserdata.cuddata, + skb->len); + x25->calluserdata.cudlength = skb->len; + } + if (!sock_flag(sk, SOCK_DEAD)) + sk->sk_state_change(sk); + break; + } + case X25_CALL_REQUEST: + /* call collision */ + x25->causediag.cause = 0x01; + x25->causediag.diagnostic = 0x48; + + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25_disconnect(sk, EISCONN, 0x01, 0x48); + break; + + case X25_CLEAR_REQUEST: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 2)) + goto out_clear; + + x25_write_internal(sk, X25_CLEAR_CONFIRMATION); + x25_disconnect(sk, ECONNREFUSED, skb->data[3], skb->data[4]); + break; + + default: + break; + } + + return 0; + +out_clear: + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25->state = X25_STATE_2; + x25_start_t23timer(sk); + return 0; +} + +/* + * State machine for state 2, Awaiting Clear Confirmation State. + * The handling of the timer(s) is in file x25_timer.c + * Handling of state 0 and connection release is in af_x25.c. + */ +static int x25_state2_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + switch (frametype) { + + case X25_CLEAR_REQUEST: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 2)) + goto out_clear; + + x25_write_internal(sk, X25_CLEAR_CONFIRMATION); + x25_disconnect(sk, 0, skb->data[3], skb->data[4]); + break; + + case X25_CLEAR_CONFIRMATION: + x25_disconnect(sk, 0, 0, 0); + break; + + default: + break; + } + + return 0; + +out_clear: + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25_start_t23timer(sk); + return 0; +} + +/* + * State machine for state 3, Connected State. + * The handling of the timer(s) is in file x25_timer.c + * Handling of state 0 and connection release is in af_x25.c. + */ +static int x25_state3_machine(struct sock *sk, struct sk_buff *skb, int frametype, int ns, int nr, int q, int d, int m) +{ + int queued = 0; + int modulus; + struct x25_sock *x25 = x25_sk(sk); + + modulus = (x25->neighbour->extended) ? 
X25_EMODULUS : X25_SMODULUS; + + switch (frametype) { + + case X25_RESET_REQUEST: + x25_write_internal(sk, X25_RESET_CONFIRMATION); + x25_stop_timer(sk); + x25->condition = 0x00; + x25->vs = 0; + x25->vr = 0; + x25->va = 0; + x25->vl = 0; + x25_requeue_frames(sk); + break; + + case X25_CLEAR_REQUEST: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 2)) + goto out_clear; + + x25_write_internal(sk, X25_CLEAR_CONFIRMATION); + x25_disconnect(sk, 0, skb->data[3], skb->data[4]); + break; + + case X25_RR: + case X25_RNR: + if (!x25_validate_nr(sk, nr)) { + x25_clear_queues(sk); + x25_write_internal(sk, X25_RESET_REQUEST); + x25_start_t22timer(sk); + x25->condition = 0x00; + x25->vs = 0; + x25->vr = 0; + x25->va = 0; + x25->vl = 0; + x25->state = X25_STATE_4; + } else { + x25_frames_acked(sk, nr); + if (frametype == X25_RNR) { + x25->condition |= X25_COND_PEER_RX_BUSY; + } else { + x25->condition &= ~X25_COND_PEER_RX_BUSY; + } + } + break; + + case X25_DATA: /* XXX */ + x25->condition &= ~X25_COND_PEER_RX_BUSY; + if ((ns != x25->vr) || !x25_validate_nr(sk, nr)) { + x25_clear_queues(sk); + x25_write_internal(sk, X25_RESET_REQUEST); + x25_start_t22timer(sk); + x25->condition = 0x00; + x25->vs = 0; + x25->vr = 0; + x25->va = 0; + x25->vl = 0; + x25->state = X25_STATE_4; + break; + } + x25_frames_acked(sk, nr); + if (ns == x25->vr) { + if (x25_queue_rx_frame(sk, skb, m) == 0) { + x25->vr = (x25->vr + 1) % modulus; + queued = 1; + } else { + /* Should never happen */ + x25_clear_queues(sk); + x25_write_internal(sk, X25_RESET_REQUEST); + x25_start_t22timer(sk); + x25->condition = 0x00; + x25->vs = 0; + x25->vr = 0; + x25->va = 0; + x25->vl = 0; + x25->state = X25_STATE_4; + break; + } + if (atomic_read(&sk->sk_rmem_alloc) > + (sk->sk_rcvbuf >> 1)) + x25->condition |= X25_COND_OWN_RX_BUSY; + } + /* + * If the window is full Ack it immediately, else + * start the holdback timer. + */ + if (((x25->vl + x25->facilities.winsize_in) % modulus) == x25->vr) { + x25->condition &= ~X25_COND_ACK_PENDING; + x25_stop_timer(sk); + x25_enquiry_response(sk); + } else { + x25->condition |= X25_COND_ACK_PENDING; + x25_start_t2timer(sk); + } + break; + + case X25_INTERRUPT_CONFIRMATION: + clear_bit(X25_INTERRUPT_FLAG, &x25->flags); + break; + + case X25_INTERRUPT: + if (sock_flag(sk, SOCK_URGINLINE)) + queued = !sock_queue_rcv_skb(sk, skb); + else { + skb_set_owner_r(skb, sk); + skb_queue_tail(&x25->interrupt_in_queue, skb); + queued = 1; + } + sk_send_sigurg(sk); + x25_write_internal(sk, X25_INTERRUPT_CONFIRMATION); + break; + + default: + pr_warn("unknown %02X in state 3\n", frametype); + break; + } + + return queued; + +out_clear: + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25->state = X25_STATE_2; + x25_start_t23timer(sk); + return 0; +} + +/* + * State machine for state 4, Awaiting Reset Confirmation State. + * The handling of the timer(s) is in file x25_timer.c + * Handling of state 0 and connection release is in af_x25.c. 
+ */ +static int x25_state4_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct x25_sock *x25 = x25_sk(sk); + + switch (frametype) { + + case X25_RESET_REQUEST: + x25_write_internal(sk, X25_RESET_CONFIRMATION); + fallthrough; + case X25_RESET_CONFIRMATION: { + x25_stop_timer(sk); + x25->condition = 0x00; + x25->va = 0; + x25->vr = 0; + x25->vs = 0; + x25->vl = 0; + x25->state = X25_STATE_3; + x25_requeue_frames(sk); + break; + } + case X25_CLEAR_REQUEST: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 2)) + goto out_clear; + + x25_write_internal(sk, X25_CLEAR_CONFIRMATION); + x25_disconnect(sk, 0, skb->data[3], skb->data[4]); + break; + + default: + break; + } + + return 0; + +out_clear: + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25->state = X25_STATE_2; + x25_start_t23timer(sk); + return 0; +} + +/* + * State machine for state 5, Call Accepted / Call Connected pending (X25_ACCPT_APPRV_FLAG). + * The handling of the timer(s) is in file x25_timer.c + * Handling of state 0 and connection release is in af_x25.c. + */ +static int x25_state5_machine(struct sock *sk, struct sk_buff *skb, int frametype) +{ + struct x25_sock *x25 = x25_sk(sk); + + switch (frametype) { + case X25_CLEAR_REQUEST: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 2)) { + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25->state = X25_STATE_2; + x25_start_t23timer(sk); + return 0; + } + + x25_write_internal(sk, X25_CLEAR_CONFIRMATION); + x25_disconnect(sk, 0, skb->data[3], skb->data[4]); + break; + + default: + break; + } + + return 0; +} + +/* Higher level upcall for a LAPB frame */ +int x25_process_rx_frame(struct sock *sk, struct sk_buff *skb) +{ + struct x25_sock *x25 = x25_sk(sk); + int queued = 0, frametype, ns, nr, q, d, m; + + if (x25->state == X25_STATE_0) + return 0; + + frametype = x25_decode(sk, skb, &ns, &nr, &q, &d, &m); + + switch (x25->state) { + case X25_STATE_1: + queued = x25_state1_machine(sk, skb, frametype); + break; + case X25_STATE_2: + queued = x25_state2_machine(sk, skb, frametype); + break; + case X25_STATE_3: + queued = x25_state3_machine(sk, skb, frametype, ns, nr, q, d, m); + break; + case X25_STATE_4: + queued = x25_state4_machine(sk, skb, frametype); + break; + case X25_STATE_5: + queued = x25_state5_machine(sk, skb, frametype); + break; + } + + x25_kick(sk); + + return queued; +} + +int x25_backlog_rcv(struct sock *sk, struct sk_buff *skb) +{ + int queued = x25_process_rx_frame(sk, skb); + + if (!queued) + kfree_skb(skb); + + return 0; +} diff --git a/net/x25/x25_link.c b/net/x25/x25_link.c new file mode 100644 index 000000000..5460b9146 --- /dev/null +++ b/net/x25/x25_link.c @@ -0,0 +1,421 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor New timer architecture. + * mar/20/00 Daniela Squassoni Disabling/enabling of facilities + * negotiation. + * 2000-09-04 Henner Eisen dev_hold() / dev_put() for x25_neigh. 
+ */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +LIST_HEAD(x25_neigh_list); +DEFINE_RWLOCK(x25_neigh_list_lock); + +static void x25_t20timer_expiry(struct timer_list *); + +static void x25_transmit_restart_confirmation(struct x25_neigh *nb); +static void x25_transmit_restart_request(struct x25_neigh *nb); + +/* + * Linux set/reset timer routines + */ +static inline void x25_start_t20timer(struct x25_neigh *nb) +{ + mod_timer(&nb->t20timer, jiffies + nb->t20); +} + +static void x25_t20timer_expiry(struct timer_list *t) +{ + struct x25_neigh *nb = from_timer(nb, t, t20timer); + + x25_transmit_restart_request(nb); + + x25_start_t20timer(nb); +} + +static inline void x25_stop_t20timer(struct x25_neigh *nb) +{ + del_timer(&nb->t20timer); +} + +/* + * This handles all restart and diagnostic frames. + */ +void x25_link_control(struct sk_buff *skb, struct x25_neigh *nb, + unsigned short frametype) +{ + struct sk_buff *skbn; + + switch (frametype) { + case X25_RESTART_REQUEST: + switch (nb->state) { + case X25_LINK_STATE_0: + /* This can happen when the x25 module just gets loaded + * and doesn't know layer 2 has already connected + */ + nb->state = X25_LINK_STATE_3; + x25_transmit_restart_confirmation(nb); + break; + case X25_LINK_STATE_2: + x25_stop_t20timer(nb); + nb->state = X25_LINK_STATE_3; + break; + case X25_LINK_STATE_3: + /* clear existing virtual calls */ + x25_kill_by_neigh(nb); + + x25_transmit_restart_confirmation(nb); + break; + } + break; + + case X25_RESTART_CONFIRMATION: + switch (nb->state) { + case X25_LINK_STATE_2: + x25_stop_t20timer(nb); + nb->state = X25_LINK_STATE_3; + break; + case X25_LINK_STATE_3: + /* clear existing virtual calls */ + x25_kill_by_neigh(nb); + + x25_transmit_restart_request(nb); + nb->state = X25_LINK_STATE_2; + x25_start_t20timer(nb); + break; + } + break; + + case X25_DIAGNOSTIC: + if (!pskb_may_pull(skb, X25_STD_MIN_LEN + 4)) + break; + + pr_warn("diagnostic #%d - %02X %02X %02X\n", + skb->data[3], skb->data[4], + skb->data[5], skb->data[6]); + break; + + default: + pr_warn("received unknown %02X with LCI 000\n", + frametype); + break; + } + + if (nb->state == X25_LINK_STATE_3) + while ((skbn = skb_dequeue(&nb->queue)) != NULL) + x25_send_frame(skbn, nb); +} + +/* + * This routine is called when a Restart Request is needed + */ +static void x25_transmit_restart_request(struct x25_neigh *nb) +{ + unsigned char *dptr; + int len = X25_MAX_L2_LEN + X25_STD_MIN_LEN + 2; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + + if (!skb) + return; + + skb_reserve(skb, X25_MAX_L2_LEN); + + dptr = skb_put(skb, X25_STD_MIN_LEN + 2); + + *dptr++ = nb->extended ? X25_GFI_EXTSEQ : X25_GFI_STDSEQ; + *dptr++ = 0x00; + *dptr++ = X25_RESTART_REQUEST; + *dptr++ = 0x00; + *dptr++ = 0; + + skb->sk = NULL; + + x25_send_frame(skb, nb); +} + +/* + * This routine is called when a Restart Confirmation is needed + */ +static void x25_transmit_restart_confirmation(struct x25_neigh *nb) +{ + unsigned char *dptr; + int len = X25_MAX_L2_LEN + X25_STD_MIN_LEN; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + + if (!skb) + return; + + skb_reserve(skb, X25_MAX_L2_LEN); + + dptr = skb_put(skb, X25_STD_MIN_LEN); + + *dptr++ = nb->extended ? X25_GFI_EXTSEQ : X25_GFI_STDSEQ; + *dptr++ = 0x00; + *dptr++ = X25_RESTART_CONFIRMATION; + + skb->sk = NULL; + + x25_send_frame(skb, nb); +} + +/* + * This routine is called when a Clear Request is needed outside of the context + * of a connected socket. 
+ */ +void x25_transmit_clear_request(struct x25_neigh *nb, unsigned int lci, + unsigned char cause) +{ + unsigned char *dptr; + int len = X25_MAX_L2_LEN + X25_STD_MIN_LEN + 2; + struct sk_buff *skb = alloc_skb(len, GFP_ATOMIC); + + if (!skb) + return; + + skb_reserve(skb, X25_MAX_L2_LEN); + + dptr = skb_put(skb, X25_STD_MIN_LEN + 2); + + *dptr++ = ((lci >> 8) & 0x0F) | (nb->extended ? + X25_GFI_EXTSEQ : + X25_GFI_STDSEQ); + *dptr++ = (lci >> 0) & 0xFF; + *dptr++ = X25_CLEAR_REQUEST; + *dptr++ = cause; + *dptr++ = 0x00; + + skb->sk = NULL; + + x25_send_frame(skb, nb); +} + +void x25_transmit_link(struct sk_buff *skb, struct x25_neigh *nb) +{ + switch (nb->state) { + case X25_LINK_STATE_0: + skb_queue_tail(&nb->queue, skb); + nb->state = X25_LINK_STATE_1; + x25_establish_link(nb); + break; + case X25_LINK_STATE_1: + case X25_LINK_STATE_2: + skb_queue_tail(&nb->queue, skb); + break; + case X25_LINK_STATE_3: + x25_send_frame(skb, nb); + break; + } +} + +/* + * Called when the link layer has become established. + */ +void x25_link_established(struct x25_neigh *nb) +{ + switch (nb->state) { + case X25_LINK_STATE_0: + case X25_LINK_STATE_1: + x25_transmit_restart_request(nb); + nb->state = X25_LINK_STATE_2; + x25_start_t20timer(nb); + break; + } +} + +/* + * Called when the link layer has terminated, or an establishment + * request has failed. + */ + +void x25_link_terminated(struct x25_neigh *nb) +{ + nb->state = X25_LINK_STATE_0; + skb_queue_purge(&nb->queue); + x25_stop_t20timer(nb); + + /* Out of order: clear existing virtual calls (X.25 03/93 4.6.3) */ + x25_kill_by_neigh(nb); +} + +/* + * Add a new device. + */ +void x25_link_device_up(struct net_device *dev) +{ + struct x25_neigh *nb = kmalloc(sizeof(*nb), GFP_ATOMIC); + + if (!nb) + return; + + skb_queue_head_init(&nb->queue); + timer_setup(&nb->t20timer, x25_t20timer_expiry, 0); + + dev_hold(dev); + nb->dev = dev; + nb->state = X25_LINK_STATE_0; + nb->extended = 0; + /* + * Enables negotiation + */ + nb->global_facil_mask = X25_MASK_REVERSE | + X25_MASK_THROUGHPUT | + X25_MASK_PACKET_SIZE | + X25_MASK_WINDOW_SIZE; + nb->t20 = sysctl_x25_restart_request_timeout; + refcount_set(&nb->refcnt, 1); + + write_lock_bh(&x25_neigh_list_lock); + list_add(&nb->node, &x25_neigh_list); + write_unlock_bh(&x25_neigh_list_lock); +} + +/** + * __x25_remove_neigh - remove neighbour from x25_neigh_list + * @nb: - neigh to remove + * + * Remove neighbour from x25_neigh_list. If it was there. + * Caller must hold x25_neigh_list_lock. + */ +static void __x25_remove_neigh(struct x25_neigh *nb) +{ + if (nb->node.next) { + list_del(&nb->node); + x25_neigh_put(nb); + } +} + +/* + * A device has been removed, remove its links. + */ +void x25_link_device_down(struct net_device *dev) +{ + struct x25_neigh *nb; + struct list_head *entry, *tmp; + + write_lock_bh(&x25_neigh_list_lock); + + list_for_each_safe(entry, tmp, &x25_neigh_list) { + nb = list_entry(entry, struct x25_neigh, node); + + if (nb->dev == dev) { + __x25_remove_neigh(nb); + dev_put(dev); + } + } + + write_unlock_bh(&x25_neigh_list_lock); +} + +/* + * Given a device, return the neighbour address. + */ +struct x25_neigh *x25_get_neigh(struct net_device *dev) +{ + struct x25_neigh *nb, *use = NULL; + + read_lock_bh(&x25_neigh_list_lock); + list_for_each_entry(nb, &x25_neigh_list, node) { + if (nb->dev == dev) { + use = nb; + break; + } + } + + if (use) + x25_neigh_hold(use); + read_unlock_bh(&x25_neigh_list_lock); + return use; +} + +/* + * Handle the ioctls that control the subscription functions. 
+ */ +int x25_subscr_ioctl(unsigned int cmd, void __user *arg) +{ + struct x25_subscrip_struct x25_subscr; + struct x25_neigh *nb; + struct net_device *dev; + int rc = -EINVAL; + + if (cmd != SIOCX25GSUBSCRIP && cmd != SIOCX25SSUBSCRIP) + goto out; + + rc = -EFAULT; + if (copy_from_user(&x25_subscr, arg, sizeof(x25_subscr))) + goto out; + + rc = -EINVAL; + if ((dev = x25_dev_get(x25_subscr.device)) == NULL) + goto out; + + if ((nb = x25_get_neigh(dev)) == NULL) + goto out_dev_put; + + dev_put(dev); + + if (cmd == SIOCX25GSUBSCRIP) { + read_lock_bh(&x25_neigh_list_lock); + x25_subscr.extended = nb->extended; + x25_subscr.global_facil_mask = nb->global_facil_mask; + read_unlock_bh(&x25_neigh_list_lock); + rc = copy_to_user(arg, &x25_subscr, + sizeof(x25_subscr)) ? -EFAULT : 0; + } else { + rc = -EINVAL; + if (!(x25_subscr.extended && x25_subscr.extended != 1)) { + rc = 0; + write_lock_bh(&x25_neigh_list_lock); + nb->extended = x25_subscr.extended; + nb->global_facil_mask = x25_subscr.global_facil_mask; + write_unlock_bh(&x25_neigh_list_lock); + } + } + x25_neigh_put(nb); +out: + return rc; +out_dev_put: + dev_put(dev); + goto out; +} + + +/* + * Release all memory associated with X.25 neighbour structures. + */ +void __exit x25_link_free(void) +{ + struct x25_neigh *nb; + struct list_head *entry, *tmp; + + write_lock_bh(&x25_neigh_list_lock); + + list_for_each_safe(entry, tmp, &x25_neigh_list) { + struct net_device *dev; + + nb = list_entry(entry, struct x25_neigh, node); + dev = nb->dev; + __x25_remove_neigh(nb); + dev_put(dev); + } + write_unlock_bh(&x25_neigh_list_lock); +} diff --git a/net/x25/x25_out.c b/net/x25/x25_out.c new file mode 100644 index 000000000..dbc0940bf --- /dev/null +++ b/net/x25/x25_out.c @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor New timer architecture. + * 2000-09-04 Henner Eisen Prevented x25_output() skb leakage. + * 2000-10-27 Henner Eisen MSG_DONTWAIT for fragment allocation. + * 2000-11-10 Henner Eisen x25_send_iframe(): re-queued frames + * needed cleaned seq-number fields. + */ + +#include +#include +#include +#include +#include +#include +#include + +static int x25_pacsize_to_bytes(unsigned int pacsize) +{ + int bytes = 1; + + if (!pacsize) + return 128; + + while (pacsize-- > 0) + bytes *= 2; + + return bytes; +} + +/* + * This is where all X.25 information frames pass. + * + * Returns the amount of user data bytes sent on success + * or a negative error code on failure. + */ +int x25_output(struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *skbn; + unsigned char header[X25_EXT_MIN_LEN]; + int err, frontlen, len; + int sent=0, noblock = X25_SKB_CB(skb)->flags & MSG_DONTWAIT; + struct x25_sock *x25 = x25_sk(sk); + int header_len = x25->neighbour->extended ? 
X25_EXT_MIN_LEN : + X25_STD_MIN_LEN; + int max_len = x25_pacsize_to_bytes(x25->facilities.pacsize_out); + + if (skb->len - header_len > max_len) { + /* Save a copy of the Header */ + skb_copy_from_linear_data(skb, header, header_len); + skb_pull(skb, header_len); + + frontlen = skb_headroom(skb); + + while (skb->len > 0) { + release_sock(sk); + skbn = sock_alloc_send_skb(sk, frontlen + max_len, + noblock, &err); + lock_sock(sk); + if (!skbn) { + if (err == -EWOULDBLOCK && noblock){ + kfree_skb(skb); + return sent; + } + SOCK_DEBUG(sk, "x25_output: fragment alloc" + " failed, err=%d, %d bytes " + "sent\n", err, sent); + return err; + } + + skb_reserve(skbn, frontlen); + + len = max_len > skb->len ? skb->len : max_len; + + /* Copy the user data */ + skb_copy_from_linear_data(skb, skb_put(skbn, len), len); + skb_pull(skb, len); + + /* Duplicate the Header */ + skb_push(skbn, header_len); + skb_copy_to_linear_data(skbn, header, header_len); + + if (skb->len > 0) { + if (x25->neighbour->extended) + skbn->data[3] |= X25_EXT_M_BIT; + else + skbn->data[2] |= X25_STD_M_BIT; + } + + skb_queue_tail(&sk->sk_write_queue, skbn); + sent += len; + } + + kfree_skb(skb); + } else { + skb_queue_tail(&sk->sk_write_queue, skb); + sent = skb->len - header_len; + } + return sent; +} + +/* + * This procedure is passed a buffer descriptor for an iframe. It builds + * the rest of the control part of the frame and then writes it out. + */ +static void x25_send_iframe(struct sock *sk, struct sk_buff *skb) +{ + struct x25_sock *x25 = x25_sk(sk); + + if (!skb) + return; + + if (x25->neighbour->extended) { + skb->data[2] = (x25->vs << 1) & 0xFE; + skb->data[3] &= X25_EXT_M_BIT; + skb->data[3] |= (x25->vr << 1) & 0xFE; + } else { + skb->data[2] &= X25_STD_M_BIT; + skb->data[2] |= (x25->vs << 1) & 0x0E; + skb->data[2] |= (x25->vr << 5) & 0xE0; + } + + x25_transmit_link(skb, x25->neighbour); +} + +void x25_kick(struct sock *sk) +{ + struct sk_buff *skb, *skbn; + unsigned short start, end; + int modulus; + struct x25_sock *x25 = x25_sk(sk); + + if (x25->state != X25_STATE_3) + return; + + /* + * Transmit interrupt data. + */ + if (skb_peek(&x25->interrupt_out_queue) != NULL && + !test_and_set_bit(X25_INTERRUPT_FLAG, &x25->flags)) { + + skb = skb_dequeue(&x25->interrupt_out_queue); + x25_transmit_link(skb, x25->neighbour); + } + + if (x25->condition & X25_COND_PEER_RX_BUSY) + return; + + if (!skb_peek(&sk->sk_write_queue)) + return; + + modulus = x25->neighbour->extended ? X25_EMODULUS : X25_SMODULUS; + + start = skb_peek(&x25->ack_queue) ? x25->vs : x25->va; + end = (x25->va + x25->facilities.winsize_out) % modulus; + + if (start == end) + return; + + x25->vs = start; + + /* + * Transmit data until either we're out of data to send or + * the window is full. + */ + + skb = skb_dequeue(&sk->sk_write_queue); + + do { + if ((skbn = skb_clone(skb, GFP_ATOMIC)) == NULL) { + skb_queue_head(&sk->sk_write_queue, skb); + break; + } + + skb_set_owner_w(skbn, sk); + + /* + * Transmit the frame copy. + */ + x25_send_iframe(sk, skbn); + + x25->vs = (x25->vs + 1) % modulus; + + /* + * Requeue the original data frame. + */ + skb_queue_tail(&x25->ack_queue, skb); + + } while (x25->vs != end && + (skb = skb_dequeue(&sk->sk_write_queue)) != NULL); + + x25->vl = x25->vr; + x25->condition &= ~X25_COND_ACK_PENDING; + + x25_stop_timer(sk); +} + +/* + * The following routines are taken from page 170 of the 7th ARRL Computer + * Networking Conference paper, as is the whole state machine. 
+ */ + +void x25_enquiry_response(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + if (x25->condition & X25_COND_OWN_RX_BUSY) + x25_write_internal(sk, X25_RNR); + else + x25_write_internal(sk, X25_RR); + + x25->vl = x25->vr; + x25->condition &= ~X25_COND_ACK_PENDING; + + x25_stop_timer(sk); +} diff --git a/net/x25/x25_proc.c b/net/x25/x25_proc.c new file mode 100644 index 000000000..0412814a2 --- /dev/null +++ b/net/x25/x25_proc.c @@ -0,0 +1,206 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.4 with seq_file support + * + * History + * 2002/10/06 Arnaldo Carvalho de Melo seq_file support + */ + +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_PROC_FS + +static void *x25_seq_route_start(struct seq_file *seq, loff_t *pos) + __acquires(x25_route_list_lock) +{ + read_lock_bh(&x25_route_list_lock); + return seq_list_start_head(&x25_route_list, *pos); +} + +static void *x25_seq_route_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_list_next(v, &x25_route_list, pos); +} + +static void x25_seq_route_stop(struct seq_file *seq, void *v) + __releases(x25_route_list_lock) +{ + read_unlock_bh(&x25_route_list_lock); +} + +static int x25_seq_route_show(struct seq_file *seq, void *v) +{ + struct x25_route *rt = list_entry(v, struct x25_route, node); + + if (v == &x25_route_list) { + seq_puts(seq, "Address Digits Device\n"); + goto out; + } + + rt = v; + seq_printf(seq, "%-15s %-6d %-5s\n", + rt->address.x25_addr, rt->sigdigits, + rt->dev ? rt->dev->name : "???"); +out: + return 0; +} + +static void *x25_seq_socket_start(struct seq_file *seq, loff_t *pos) + __acquires(x25_list_lock) +{ + read_lock_bh(&x25_list_lock); + return seq_hlist_start_head(&x25_list, *pos); +} + +static void *x25_seq_socket_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_hlist_next(v, &x25_list, pos); +} + +static void x25_seq_socket_stop(struct seq_file *seq, void *v) + __releases(x25_list_lock) +{ + read_unlock_bh(&x25_list_lock); +} + +static int x25_seq_socket_show(struct seq_file *seq, void *v) +{ + struct sock *s; + struct x25_sock *x25; + const char *devname; + + if (v == SEQ_START_TOKEN) { + seq_printf(seq, "dest_addr src_addr dev lci st vs vr " + "va t t2 t21 t22 t23 Snd-Q Rcv-Q inode\n"); + goto out; + } + + s = sk_entry(v); + x25 = x25_sk(s); + + if (!x25->neighbour || !x25->neighbour->dev) + devname = "???"; + else + devname = x25->neighbour->dev->name; + + seq_printf(seq, "%-10s %-10s %-5s %3.3X %d %d %d %d %3lu %3lu " + "%3lu %3lu %3lu %5d %5d %ld\n", + !x25->dest_addr.x25_addr[0] ? "*" : x25->dest_addr.x25_addr, + !x25->source_addr.x25_addr[0] ? "*" : x25->source_addr.x25_addr, + devname, x25->lci & 0x0FFF, x25->state, x25->vs, x25->vr, + x25->va, x25_display_timer(s) / HZ, x25->t2 / HZ, + x25->t21 / HZ, x25->t22 / HZ, x25->t23 / HZ, + sk_wmem_alloc_get(s), + sk_rmem_alloc_get(s), + s->sk_socket ? 
SOCK_INODE(s->sk_socket)->i_ino : 0L); +out: + return 0; +} + +static void *x25_seq_forward_start(struct seq_file *seq, loff_t *pos) + __acquires(x25_forward_list_lock) +{ + read_lock_bh(&x25_forward_list_lock); + return seq_list_start_head(&x25_forward_list, *pos); +} + +static void *x25_seq_forward_next(struct seq_file *seq, void *v, loff_t *pos) +{ + return seq_list_next(v, &x25_forward_list, pos); +} + +static void x25_seq_forward_stop(struct seq_file *seq, void *v) + __releases(x25_forward_list_lock) +{ + read_unlock_bh(&x25_forward_list_lock); +} + +static int x25_seq_forward_show(struct seq_file *seq, void *v) +{ + struct x25_forward *f = list_entry(v, struct x25_forward, node); + + if (v == &x25_forward_list) { + seq_printf(seq, "lci dev1 dev2\n"); + goto out; + } + + f = v; + + seq_printf(seq, "%d %-10s %-10s\n", + f->lci, f->dev1->name, f->dev2->name); +out: + return 0; +} + +static const struct seq_operations x25_seq_route_ops = { + .start = x25_seq_route_start, + .next = x25_seq_route_next, + .stop = x25_seq_route_stop, + .show = x25_seq_route_show, +}; + +static const struct seq_operations x25_seq_socket_ops = { + .start = x25_seq_socket_start, + .next = x25_seq_socket_next, + .stop = x25_seq_socket_stop, + .show = x25_seq_socket_show, +}; + +static const struct seq_operations x25_seq_forward_ops = { + .start = x25_seq_forward_start, + .next = x25_seq_forward_next, + .stop = x25_seq_forward_stop, + .show = x25_seq_forward_show, +}; + +int __init x25_proc_init(void) +{ + if (!proc_mkdir("x25", init_net.proc_net)) + return -ENOMEM; + + if (!proc_create_seq("x25/route", 0444, init_net.proc_net, + &x25_seq_route_ops)) + goto out; + + if (!proc_create_seq("x25/socket", 0444, init_net.proc_net, + &x25_seq_socket_ops)) + goto out; + + if (!proc_create_seq("x25/forward", 0444, init_net.proc_net, + &x25_seq_forward_ops)) + goto out; + return 0; + +out: + remove_proc_subtree("x25", init_net.proc_net); + return -ENOMEM; +} + +void __exit x25_proc_exit(void) +{ + remove_proc_subtree("x25", init_net.proc_net); +} + +#else /* CONFIG_PROC_FS */ + +int __init x25_proc_init(void) +{ + return 0; +} + +void __exit x25_proc_exit(void) +{ +} +#endif /* CONFIG_PROC_FS */ diff --git a/net/x25/x25_route.c b/net/x25/x25_route.c new file mode 100644 index 000000000..647f325ed --- /dev/null +++ b/net/x25/x25_route.c @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + */ + +#include +#include +#include +#include + +LIST_HEAD(x25_route_list); +DEFINE_RWLOCK(x25_route_list_lock); + +/* + * Add a new route. 
+ */ +static int x25_add_route(struct x25_address *address, unsigned int sigdigits, + struct net_device *dev) +{ + struct x25_route *rt; + int rc = -EINVAL; + + write_lock_bh(&x25_route_list_lock); + + list_for_each_entry(rt, &x25_route_list, node) { + if (!memcmp(&rt->address, address, sigdigits) && + rt->sigdigits == sigdigits) + goto out; + } + + rt = kmalloc(sizeof(*rt), GFP_ATOMIC); + rc = -ENOMEM; + if (!rt) + goto out; + + strcpy(rt->address.x25_addr, "000000000000000"); + memcpy(rt->address.x25_addr, address->x25_addr, sigdigits); + + rt->sigdigits = sigdigits; + rt->dev = dev; + refcount_set(&rt->refcnt, 1); + + list_add(&rt->node, &x25_route_list); + rc = 0; +out: + write_unlock_bh(&x25_route_list_lock); + return rc; +} + +/** + * __x25_remove_route - remove route from x25_route_list + * @rt: route to remove + * + * Remove route from x25_route_list. If it was there. + * Caller must hold x25_route_list_lock. + */ +static void __x25_remove_route(struct x25_route *rt) +{ + if (rt->node.next) { + list_del(&rt->node); + x25_route_put(rt); + } +} + +static int x25_del_route(struct x25_address *address, unsigned int sigdigits, + struct net_device *dev) +{ + struct x25_route *rt; + int rc = -EINVAL; + + write_lock_bh(&x25_route_list_lock); + + list_for_each_entry(rt, &x25_route_list, node) { + if (!memcmp(&rt->address, address, sigdigits) && + rt->sigdigits == sigdigits && rt->dev == dev) { + __x25_remove_route(rt); + rc = 0; + break; + } + } + + write_unlock_bh(&x25_route_list_lock); + return rc; +} + +/* + * A device has been removed, remove its routes. + */ +void x25_route_device_down(struct net_device *dev) +{ + struct x25_route *rt; + struct list_head *entry, *tmp; + + write_lock_bh(&x25_route_list_lock); + + list_for_each_safe(entry, tmp, &x25_route_list) { + rt = list_entry(entry, struct x25_route, node); + + if (rt->dev == dev) + __x25_remove_route(rt); + } + write_unlock_bh(&x25_route_list_lock); +} + +/* + * Check that the device given is a valid X.25 interface that is "up". + */ +struct net_device *x25_dev_get(char *devname) +{ + struct net_device *dev = dev_get_by_name(&init_net, devname); + + if (dev && (!(dev->flags & IFF_UP) || dev->type != ARPHRD_X25)) { + dev_put(dev); + dev = NULL; + } + + return dev; +} + +/** + * x25_get_route - Find a route given an X.25 address. + * @addr: - address to find a route for + * + * Find a route given an X.25 address. + */ +struct x25_route *x25_get_route(struct x25_address *addr) +{ + struct x25_route *rt, *use = NULL; + + read_lock_bh(&x25_route_list_lock); + + list_for_each_entry(rt, &x25_route_list, node) { + if (!memcmp(&rt->address, addr, rt->sigdigits)) { + if (!use) + use = rt; + else if (rt->sigdigits > use->sigdigits) + use = rt; + } + } + + if (use) + x25_route_hold(use); + + read_unlock_bh(&x25_route_list_lock); + return use; +} + +/* + * Handle the ioctls that control the routing functions. + */ +int x25_route_ioctl(unsigned int cmd, void __user *arg) +{ + struct x25_route_struct rt; + struct net_device *dev; + int rc = -EINVAL; + + if (cmd != SIOCADDRT && cmd != SIOCDELRT) + goto out; + + rc = -EFAULT; + if (copy_from_user(&rt, arg, sizeof(rt))) + goto out; + + rc = -EINVAL; + if (rt.sigdigits > 15) + goto out; + + dev = x25_dev_get(rt.device); + if (!dev) + goto out; + + if (cmd == SIOCADDRT) + rc = x25_add_route(&rt.address, rt.sigdigits, dev); + else + rc = x25_del_route(&rt.address, rt.sigdigits, dev); + dev_put(dev); +out: + return rc; +} + +/* + * Release all memory associated with X.25 routing structures. 
+ */ +void __exit x25_route_free(void) +{ + struct x25_route *rt; + struct list_head *entry, *tmp; + + write_lock_bh(&x25_route_list_lock); + list_for_each_safe(entry, tmp, &x25_route_list) { + rt = list_entry(entry, struct x25_route, node); + __x25_remove_route(rt); + } + write_unlock_bh(&x25_route_list_lock); +} diff --git a/net/x25/x25_subr.c b/net/x25/x25_subr.c new file mode 100644 index 000000000..0285aaa1e --- /dev/null +++ b/net/x25/x25_subr.c @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. + * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor Centralised disconnection processing. + * mar/20/00 Daniela Squassoni Disabling/enabling of facilities + * negotiation. + * jun/24/01 Arnaldo C. Melo use skb_queue_purge, cleanups + * apr/04/15 Shaun Pereira Fast select with no + * restriction on response. + */ + +#define pr_fmt(fmt) "X25: " fmt + +#include +#include +#include +#include +#include +#include +#include + +/* + * This routine purges all of the queues of frames. + */ +void x25_clear_queues(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + skb_queue_purge(&sk->sk_write_queue); + skb_queue_purge(&x25->ack_queue); + skb_queue_purge(&x25->interrupt_in_queue); + skb_queue_purge(&x25->interrupt_out_queue); + skb_queue_purge(&x25->fragment_queue); +} + + +/* + * This routine purges the input queue of those frames that have been + * acknowledged. This replaces the boxes labelled "V(a) <- N(r)" on the + * SDL diagram. +*/ +void x25_frames_acked(struct sock *sk, unsigned short nr) +{ + struct sk_buff *skb; + struct x25_sock *x25 = x25_sk(sk); + int modulus = x25->neighbour->extended ? X25_EMODULUS : X25_SMODULUS; + + /* + * Remove all the ack-ed frames from the ack queue. + */ + if (x25->va != nr) + while (skb_peek(&x25->ack_queue) && x25->va != nr) { + skb = skb_dequeue(&x25->ack_queue); + kfree_skb(skb); + x25->va = (x25->va + 1) % modulus; + } +} + +void x25_requeue_frames(struct sock *sk) +{ + struct sk_buff *skb, *skb_prev = NULL; + + /* + * Requeue all the un-ack-ed frames on the output queue to be picked + * up by x25_kick. This arrangement handles the possibility of an empty + * output queue. + */ + while ((skb = skb_dequeue(&x25_sk(sk)->ack_queue)) != NULL) { + if (!skb_prev) + skb_queue_head(&sk->sk_write_queue, skb); + else + skb_append(skb_prev, skb, &sk->sk_write_queue); + skb_prev = skb; + } +} + +/* + * Validate that the value of nr is between va and vs. Return true or + * false for testing. + */ +int x25_validate_nr(struct sock *sk, unsigned short nr) +{ + struct x25_sock *x25 = x25_sk(sk); + unsigned short vc = x25->va; + int modulus = x25->neighbour->extended ? X25_EMODULUS : X25_SMODULUS; + + while (vc != x25->vs) { + if (nr == vc) + return 1; + vc = (vc + 1) % modulus; + } + + return nr == x25->vs ? 1 : 0; +} + +/* + * This routine is called when the packet layer internally generates a + * control frame. + */ +void x25_write_internal(struct sock *sk, int frametype) +{ + struct x25_sock *x25 = x25_sk(sk); + struct sk_buff *skb; + unsigned char *dptr; + unsigned char facilities[X25_MAX_FAC_LEN]; + unsigned char addresses[1 + X25_ADDR_LEN]; + unsigned char lci1, lci2; + /* + * Default safe frame size. 
+ */ + int len = X25_MAX_L2_LEN + X25_EXT_MIN_LEN; + + /* + * Adjust frame size. + */ + switch (frametype) { + case X25_CALL_REQUEST: + len += 1 + X25_ADDR_LEN + X25_MAX_FAC_LEN + X25_MAX_CUD_LEN; + break; + case X25_CALL_ACCEPTED: /* fast sel with no restr on resp */ + if (x25->facilities.reverse & 0x80) { + len += 1 + X25_MAX_FAC_LEN + X25_MAX_CUD_LEN; + } else { + len += 1 + X25_MAX_FAC_LEN; + } + break; + case X25_CLEAR_REQUEST: + case X25_RESET_REQUEST: + len += 2; + break; + case X25_RR: + case X25_RNR: + case X25_REJ: + case X25_CLEAR_CONFIRMATION: + case X25_INTERRUPT_CONFIRMATION: + case X25_RESET_CONFIRMATION: + break; + default: + pr_err("invalid frame type %02X\n", frametype); + return; + } + + if ((skb = alloc_skb(len, GFP_ATOMIC)) == NULL) + return; + + /* + * Space for Ethernet and 802.2 LLC headers. + */ + skb_reserve(skb, X25_MAX_L2_LEN); + + /* + * Make space for the GFI and LCI, and fill them in. + */ + dptr = skb_put(skb, 2); + + lci1 = (x25->lci >> 8) & 0x0F; + lci2 = (x25->lci >> 0) & 0xFF; + + if (x25->neighbour->extended) { + *dptr++ = lci1 | X25_GFI_EXTSEQ; + *dptr++ = lci2; + } else { + *dptr++ = lci1 | X25_GFI_STDSEQ; + *dptr++ = lci2; + } + + /* + * Now fill in the frame type specific information. + */ + switch (frametype) { + + case X25_CALL_REQUEST: + dptr = skb_put(skb, 1); + *dptr++ = X25_CALL_REQUEST; + len = x25_addr_aton(addresses, &x25->dest_addr, + &x25->source_addr); + skb_put_data(skb, addresses, len); + len = x25_create_facilities(facilities, + &x25->facilities, + &x25->dte_facilities, + x25->neighbour->global_facil_mask); + skb_put_data(skb, facilities, len); + skb_put_data(skb, x25->calluserdata.cuddata, + x25->calluserdata.cudlength); + x25->calluserdata.cudlength = 0; + break; + + case X25_CALL_ACCEPTED: + dptr = skb_put(skb, 2); + *dptr++ = X25_CALL_ACCEPTED; + *dptr++ = 0x00; /* Address lengths */ + len = x25_create_facilities(facilities, + &x25->facilities, + &x25->dte_facilities, + x25->vc_facil_mask); + skb_put_data(skb, facilities, len); + + /* fast select with no restriction on response + allows call user data. Userland must + ensure it is ours and not theirs */ + if(x25->facilities.reverse & 0x80) { + skb_put_data(skb, + x25->calluserdata.cuddata, + x25->calluserdata.cudlength); + } + x25->calluserdata.cudlength = 0; + break; + + case X25_CLEAR_REQUEST: + dptr = skb_put(skb, 3); + *dptr++ = frametype; + *dptr++ = x25->causediag.cause; + *dptr++ = x25->causediag.diagnostic; + break; + + case X25_RESET_REQUEST: + dptr = skb_put(skb, 3); + *dptr++ = frametype; + *dptr++ = 0x00; /* XXX */ + *dptr++ = 0x00; /* XXX */ + break; + + case X25_RR: + case X25_RNR: + case X25_REJ: + if (x25->neighbour->extended) { + dptr = skb_put(skb, 2); + *dptr++ = frametype; + *dptr++ = (x25->vr << 1) & 0xFE; + } else { + dptr = skb_put(skb, 1); + *dptr = frametype; + *dptr++ |= (x25->vr << 5) & 0xE0; + } + break; + + case X25_CLEAR_CONFIRMATION: + case X25_INTERRUPT_CONFIRMATION: + case X25_RESET_CONFIRMATION: + dptr = skb_put(skb, 1); + *dptr = frametype; + break; + } + + x25_transmit_link(skb, x25->neighbour); +} + +/* + * Unpick the contents of the passed X.25 Packet Layer frame. 
+ */ +int x25_decode(struct sock *sk, struct sk_buff *skb, int *ns, int *nr, int *q, + int *d, int *m) +{ + struct x25_sock *x25 = x25_sk(sk); + unsigned char *frame; + + if (!pskb_may_pull(skb, X25_STD_MIN_LEN)) + return X25_ILLEGAL; + frame = skb->data; + + *ns = *nr = *q = *d = *m = 0; + + switch (frame[2]) { + case X25_CALL_REQUEST: + case X25_CALL_ACCEPTED: + case X25_CLEAR_REQUEST: + case X25_CLEAR_CONFIRMATION: + case X25_INTERRUPT: + case X25_INTERRUPT_CONFIRMATION: + case X25_RESET_REQUEST: + case X25_RESET_CONFIRMATION: + case X25_RESTART_REQUEST: + case X25_RESTART_CONFIRMATION: + case X25_REGISTRATION_REQUEST: + case X25_REGISTRATION_CONFIRMATION: + case X25_DIAGNOSTIC: + return frame[2]; + } + + if (x25->neighbour->extended) { + if (frame[2] == X25_RR || + frame[2] == X25_RNR || + frame[2] == X25_REJ) { + if (!pskb_may_pull(skb, X25_EXT_MIN_LEN)) + return X25_ILLEGAL; + frame = skb->data; + + *nr = (frame[3] >> 1) & 0x7F; + return frame[2]; + } + } else { + if ((frame[2] & 0x1F) == X25_RR || + (frame[2] & 0x1F) == X25_RNR || + (frame[2] & 0x1F) == X25_REJ) { + *nr = (frame[2] >> 5) & 0x07; + return frame[2] & 0x1F; + } + } + + if (x25->neighbour->extended) { + if ((frame[2] & 0x01) == X25_DATA) { + if (!pskb_may_pull(skb, X25_EXT_MIN_LEN)) + return X25_ILLEGAL; + frame = skb->data; + + *q = (frame[0] & X25_Q_BIT) == X25_Q_BIT; + *d = (frame[0] & X25_D_BIT) == X25_D_BIT; + *m = (frame[3] & X25_EXT_M_BIT) == X25_EXT_M_BIT; + *nr = (frame[3] >> 1) & 0x7F; + *ns = (frame[2] >> 1) & 0x7F; + return X25_DATA; + } + } else { + if ((frame[2] & 0x01) == X25_DATA) { + *q = (frame[0] & X25_Q_BIT) == X25_Q_BIT; + *d = (frame[0] & X25_D_BIT) == X25_D_BIT; + *m = (frame[2] & X25_STD_M_BIT) == X25_STD_M_BIT; + *nr = (frame[2] >> 5) & 0x07; + *ns = (frame[2] >> 1) & 0x07; + return X25_DATA; + } + } + + pr_debug("invalid PLP frame %3ph\n", frame); + + return X25_ILLEGAL; +} + +void x25_disconnect(struct sock *sk, int reason, unsigned char cause, + unsigned char diagnostic) +{ + struct x25_sock *x25 = x25_sk(sk); + + x25_clear_queues(sk); + x25_stop_timer(sk); + + x25->lci = 0; + x25->state = X25_STATE_0; + + x25->causediag.cause = cause; + x25->causediag.diagnostic = diagnostic; + + sk->sk_state = TCP_CLOSE; + sk->sk_err = reason; + sk->sk_shutdown |= SEND_SHUTDOWN; + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_state_change(sk); + sock_set_flag(sk, SOCK_DEAD); + } + if (x25->neighbour) { + read_lock_bh(&x25_list_lock); + x25_neigh_put(x25->neighbour); + x25->neighbour = NULL; + read_unlock_bh(&x25_list_lock); + } +} + +/* + * Clear an own-rx-busy condition and tell the peer about this, provided + * that there is a significant amount of free receive buffer space available. + */ +void x25_check_rbuf(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + if (atomic_read(&sk->sk_rmem_alloc) < (sk->sk_rcvbuf >> 1) && + (x25->condition & X25_COND_OWN_RX_BUSY)) { + x25->condition &= ~X25_COND_OWN_RX_BUSY; + x25->condition &= ~X25_COND_ACK_PENDING; + x25->vl = x25->vr; + x25_write_internal(sk, X25_RR); + x25_stop_timer(sk); + } +} diff --git a/net/x25/x25_timer.c b/net/x25/x25_timer.c new file mode 100644 index 000000000..9376365cd --- /dev/null +++ b/net/x25/x25_timer.c @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * X.25 Packet Layer release 002 + * + * This is ALPHA test software. This code may break your machine, + * randomly fail to work with new releases, misbehave and/or generally + * screw up. It might even work. 
+ * + * This code REQUIRES 2.1.15 or higher + * + * History + * X.25 001 Jonathan Naylor Started coding. + * X.25 002 Jonathan Naylor New timer architecture. + * Centralised disconnection processing. + */ + +#include +#include +#include +#include +#include +#include + +static void x25_heartbeat_expiry(struct timer_list *t); +static void x25_timer_expiry(struct timer_list *t); + +void x25_init_timers(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + timer_setup(&x25->timer, x25_timer_expiry, 0); + + /* initialized by sock_init_data */ + sk->sk_timer.function = x25_heartbeat_expiry; +} + +void x25_start_heartbeat(struct sock *sk) +{ + mod_timer(&sk->sk_timer, jiffies + 5 * HZ); +} + +void x25_stop_heartbeat(struct sock *sk) +{ + del_timer(&sk->sk_timer); +} + +void x25_start_t2timer(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + mod_timer(&x25->timer, jiffies + x25->t2); +} + +void x25_start_t21timer(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + mod_timer(&x25->timer, jiffies + x25->t21); +} + +void x25_start_t22timer(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + mod_timer(&x25->timer, jiffies + x25->t22); +} + +void x25_start_t23timer(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + mod_timer(&x25->timer, jiffies + x25->t23); +} + +void x25_stop_timer(struct sock *sk) +{ + del_timer(&x25_sk(sk)->timer); +} + +unsigned long x25_display_timer(struct sock *sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + if (!timer_pending(&x25->timer)) + return 0; + + return x25->timer.expires - jiffies; +} + +static void x25_heartbeat_expiry(struct timer_list *t) +{ + struct sock *sk = from_timer(sk, t, sk_timer); + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) /* can currently only occur in state 3 */ + goto restart_heartbeat; + + switch (x25_sk(sk)->state) { + + case X25_STATE_0: + /* + * Magic here: If we listen() and a new link dies + * before it is accepted() it isn't 'dead' so doesn't + * get removed. + */ + if (sock_flag(sk, SOCK_DESTROY) || + (sk->sk_state == TCP_LISTEN && + sock_flag(sk, SOCK_DEAD))) { + bh_unlock_sock(sk); + x25_destroy_socket_from_timer(sk); + return; + } + break; + + case X25_STATE_3: + /* + * Check for the state of the receive buffer. + */ + x25_check_rbuf(sk); + break; + } +restart_heartbeat: + x25_start_heartbeat(sk); + bh_unlock_sock(sk); +} + +/* + * Timer has expired, it may have been T2, T21, T22, or T23. We can tell + * by the state machine state. 
+ */ +static inline void x25_do_timer_expiry(struct sock * sk) +{ + struct x25_sock *x25 = x25_sk(sk); + + switch (x25->state) { + + case X25_STATE_3: /* T2 */ + if (x25->condition & X25_COND_ACK_PENDING) { + x25->condition &= ~X25_COND_ACK_PENDING; + x25_enquiry_response(sk); + } + break; + + case X25_STATE_1: /* T21 */ + case X25_STATE_4: /* T22 */ + x25_write_internal(sk, X25_CLEAR_REQUEST); + x25->state = X25_STATE_2; + x25_start_t23timer(sk); + break; + + case X25_STATE_2: /* T23 */ + x25_disconnect(sk, ETIMEDOUT, 0, 0); + break; + } +} + +static void x25_timer_expiry(struct timer_list *t) +{ + struct x25_sock *x25 = from_timer(x25, t, timer); + struct sock *sk = &x25->sk; + + bh_lock_sock(sk); + if (sock_owned_by_user(sk)) { /* can currently only occur in state 3 */ + if (x25_sk(sk)->state == X25_STATE_3) + x25_start_t2timer(sk); + } else + x25_do_timer_expiry(sk); + bh_unlock_sock(sk); +} diff --git a/net/xdp/Kconfig b/net/xdp/Kconfig new file mode 100644 index 000000000..71af2febe --- /dev/null +++ b/net/xdp/Kconfig @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: GPL-2.0-only +config XDP_SOCKETS + bool "XDP sockets" + depends on BPF_SYSCALL + default n + help + XDP sockets allows a channel between XDP programs and + userspace applications. + +config XDP_SOCKETS_DIAG + tristate "XDP sockets: monitoring interface" + depends on XDP_SOCKETS + default n + help + Support for PF_XDP sockets monitoring interface used by the ss tool. + If unsure, say Y. diff --git a/net/xdp/Makefile b/net/xdp/Makefile new file mode 100644 index 000000000..30cdc4315 --- /dev/null +++ b/net/xdp/Makefile @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: GPL-2.0-only +obj-$(CONFIG_XDP_SOCKETS) += xsk.o xdp_umem.o xsk_queue.o xskmap.o +obj-$(CONFIG_XDP_SOCKETS) += xsk_buff_pool.o +obj-$(CONFIG_XDP_SOCKETS_DIAG) += xsk_diag.o diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c new file mode 100644 index 000000000..02207e852 --- /dev/null +++ b/net/xdp/xdp_umem.c @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XDP user-space packet buffer + * Copyright(c) 2018 Intel Corporation. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xdp_umem.h" +#include "xsk_queue.h" + +static DEFINE_IDA(umem_ida); + +static void xdp_umem_unpin_pages(struct xdp_umem *umem) +{ + unpin_user_pages_dirty_lock(umem->pgs, umem->npgs, true); + + kvfree(umem->pgs); + umem->pgs = NULL; +} + +static void xdp_umem_unaccount_pages(struct xdp_umem *umem) +{ + if (umem->user) { + atomic_long_sub(umem->npgs, &umem->user->locked_vm); + free_uid(umem->user); + } +} + +static void xdp_umem_addr_unmap(struct xdp_umem *umem) +{ + vunmap(umem->addrs); + umem->addrs = NULL; +} + +static int xdp_umem_addr_map(struct xdp_umem *umem, struct page **pages, + u32 nr_pages) +{ + umem->addrs = vmap(pages, nr_pages, VM_MAP, PAGE_KERNEL); + if (!umem->addrs) + return -ENOMEM; + return 0; +} + +static void xdp_umem_release(struct xdp_umem *umem) +{ + umem->zc = false; + ida_free(&umem_ida, umem->id); + + xdp_umem_addr_unmap(umem); + xdp_umem_unpin_pages(umem); + + xdp_umem_unaccount_pages(umem); + kfree(umem); +} + +static void xdp_umem_release_deferred(struct work_struct *work) +{ + struct xdp_umem *umem = container_of(work, struct xdp_umem, work); + + xdp_umem_release(umem); +} + +void xdp_get_umem(struct xdp_umem *umem) +{ + refcount_inc(&umem->users); +} + +void xdp_put_umem(struct xdp_umem *umem, bool defer_cleanup) +{ + if (!umem) + return; + + if (refcount_dec_and_test(&umem->users)) { + if (defer_cleanup) { + INIT_WORK(&umem->work, xdp_umem_release_deferred); + schedule_work(&umem->work); + } else { + xdp_umem_release(umem); + } + } +} + +static int xdp_umem_pin_pages(struct xdp_umem *umem, unsigned long address) +{ + unsigned int gup_flags = FOLL_WRITE; + long npgs; + int err; + + umem->pgs = kvcalloc(umem->npgs, sizeof(*umem->pgs), GFP_KERNEL | __GFP_NOWARN); + if (!umem->pgs) + return -ENOMEM; + + mmap_read_lock(current->mm); + npgs = pin_user_pages(address, umem->npgs, + gup_flags | FOLL_LONGTERM, &umem->pgs[0], NULL); + mmap_read_unlock(current->mm); + + if (npgs != umem->npgs) { + if (npgs >= 0) { + umem->npgs = npgs; + err = -ENOMEM; + goto out_pin; + } + err = npgs; + goto out_pgs; + } + return 0; + +out_pin: + xdp_umem_unpin_pages(umem); +out_pgs: + kvfree(umem->pgs); + umem->pgs = NULL; + return err; +} + +static int xdp_umem_account_pages(struct xdp_umem *umem) +{ + unsigned long lock_limit, new_npgs, old_npgs; + + if (capable(CAP_IPC_LOCK)) + return 0; + + lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; + umem->user = get_uid(current_user()); + + do { + old_npgs = atomic_long_read(&umem->user->locked_vm); + new_npgs = old_npgs + umem->npgs; + if (new_npgs > lock_limit) { + free_uid(umem->user); + umem->user = NULL; + return -ENOBUFS; + } + } while (atomic_long_cmpxchg(&umem->user->locked_vm, old_npgs, + new_npgs) != old_npgs); + return 0; +} + +static int xdp_umem_reg(struct xdp_umem *umem, struct xdp_umem_reg *mr) +{ + bool unaligned_chunks = mr->flags & XDP_UMEM_UNALIGNED_CHUNK_FLAG; + u32 chunk_size = mr->chunk_size, headroom = mr->headroom; + u64 addr = mr->addr, size = mr->len; + u32 chunks_rem, npgs_rem; + u64 chunks, npgs; + int err; + + if (chunk_size < XDP_UMEM_MIN_CHUNK_SIZE || chunk_size > PAGE_SIZE) { + /* Strictly speaking we could support this, if: + * - huge pages, or* + * - using an IOMMU, or + * - making sure the memory area is consecutive + * but for now, we simply say "computer says no". 
+ */ + return -EINVAL; + } + + if (mr->flags & ~XDP_UMEM_UNALIGNED_CHUNK_FLAG) + return -EINVAL; + + if (!unaligned_chunks && !is_power_of_2(chunk_size)) + return -EINVAL; + + if (!PAGE_ALIGNED(addr)) { + /* Memory area has to be page size aligned. For + * simplicity, this might change. + */ + return -EINVAL; + } + + if ((addr + size) < addr) + return -EINVAL; + + npgs = div_u64_rem(size, PAGE_SIZE, &npgs_rem); + if (npgs_rem) + npgs++; + if (npgs > U32_MAX) + return -EINVAL; + + chunks = div_u64_rem(size, chunk_size, &chunks_rem); + if (!chunks || chunks > U32_MAX) + return -EINVAL; + + if (!unaligned_chunks && chunks_rem) + return -EINVAL; + + if (headroom >= chunk_size - XDP_PACKET_HEADROOM) + return -EINVAL; + + umem->size = size; + umem->headroom = headroom; + umem->chunk_size = chunk_size; + umem->chunks = chunks; + umem->npgs = npgs; + umem->pgs = NULL; + umem->user = NULL; + umem->flags = mr->flags; + + INIT_LIST_HEAD(&umem->xsk_dma_list); + refcount_set(&umem->users, 1); + + err = xdp_umem_account_pages(umem); + if (err) + return err; + + err = xdp_umem_pin_pages(umem, (unsigned long)addr); + if (err) + goto out_account; + + err = xdp_umem_addr_map(umem, umem->pgs, umem->npgs); + if (err) + goto out_unpin; + + return 0; + +out_unpin: + xdp_umem_unpin_pages(umem); +out_account: + xdp_umem_unaccount_pages(umem); + return err; +} + +struct xdp_umem *xdp_umem_create(struct xdp_umem_reg *mr) +{ + struct xdp_umem *umem; + int err; + + umem = kzalloc(sizeof(*umem), GFP_KERNEL); + if (!umem) + return ERR_PTR(-ENOMEM); + + err = ida_alloc(&umem_ida, GFP_KERNEL); + if (err < 0) { + kfree(umem); + return ERR_PTR(err); + } + umem->id = err; + + err = xdp_umem_reg(umem, mr); + if (err) { + ida_free(&umem_ida, umem->id); + kfree(umem); + return ERR_PTR(err); + } + + return umem; +} diff --git a/net/xdp/xdp_umem.h b/net/xdp/xdp_umem.h new file mode 100644 index 000000000..aa9fe2780 --- /dev/null +++ b/net/xdp/xdp_umem.h @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* XDP user-space packet buffer + * Copyright(c) 2018 Intel Corporation. + */ + +#ifndef XDP_UMEM_H_ +#define XDP_UMEM_H_ + +#include + +void xdp_get_umem(struct xdp_umem *umem); +void xdp_put_umem(struct xdp_umem *umem, bool defer_cleanup); +struct xdp_umem *xdp_umem_create(struct xdp_umem_reg *mr); + +#endif /* XDP_UMEM_H_ */ diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c new file mode 100644 index 000000000..5c8e02d56 --- /dev/null +++ b/net/xdp/xsk.c @@ -0,0 +1,1524 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XDP sockets + * + * AF_XDP sockets allows a channel between XDP programs and userspace + * applications. + * Copyright(c) 2018 Intel Corporation. 
+ * + * Author(s): Björn Töpel + * Magnus Karlsson + */ + +#define pr_fmt(fmt) "AF_XDP: %s: " fmt, __func__ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xsk_queue.h" +#include "xdp_umem.h" +#include "xsk.h" + +#define TX_BATCH_SIZE 32 + +static DEFINE_PER_CPU(struct list_head, xskmap_flush_list); + +void xsk_set_rx_need_wakeup(struct xsk_buff_pool *pool) +{ + if (pool->cached_need_wakeup & XDP_WAKEUP_RX) + return; + + pool->fq->ring->flags |= XDP_RING_NEED_WAKEUP; + pool->cached_need_wakeup |= XDP_WAKEUP_RX; +} +EXPORT_SYMBOL(xsk_set_rx_need_wakeup); + +void xsk_set_tx_need_wakeup(struct xsk_buff_pool *pool) +{ + struct xdp_sock *xs; + + if (pool->cached_need_wakeup & XDP_WAKEUP_TX) + return; + + rcu_read_lock(); + list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) { + xs->tx->ring->flags |= XDP_RING_NEED_WAKEUP; + } + rcu_read_unlock(); + + pool->cached_need_wakeup |= XDP_WAKEUP_TX; +} +EXPORT_SYMBOL(xsk_set_tx_need_wakeup); + +void xsk_clear_rx_need_wakeup(struct xsk_buff_pool *pool) +{ + if (!(pool->cached_need_wakeup & XDP_WAKEUP_RX)) + return; + + pool->fq->ring->flags &= ~XDP_RING_NEED_WAKEUP; + pool->cached_need_wakeup &= ~XDP_WAKEUP_RX; +} +EXPORT_SYMBOL(xsk_clear_rx_need_wakeup); + +void xsk_clear_tx_need_wakeup(struct xsk_buff_pool *pool) +{ + struct xdp_sock *xs; + + if (!(pool->cached_need_wakeup & XDP_WAKEUP_TX)) + return; + + rcu_read_lock(); + list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) { + xs->tx->ring->flags &= ~XDP_RING_NEED_WAKEUP; + } + rcu_read_unlock(); + + pool->cached_need_wakeup &= ~XDP_WAKEUP_TX; +} +EXPORT_SYMBOL(xsk_clear_tx_need_wakeup); + +bool xsk_uses_need_wakeup(struct xsk_buff_pool *pool) +{ + return pool->uses_need_wakeup; +} +EXPORT_SYMBOL(xsk_uses_need_wakeup); + +struct xsk_buff_pool *xsk_get_pool_from_qid(struct net_device *dev, + u16 queue_id) +{ + if (queue_id < dev->real_num_rx_queues) + return dev->_rx[queue_id].pool; + if (queue_id < dev->real_num_tx_queues) + return dev->_tx[queue_id].pool; + + return NULL; +} +EXPORT_SYMBOL(xsk_get_pool_from_qid); + +void xsk_clear_pool_at_qid(struct net_device *dev, u16 queue_id) +{ + if (queue_id < dev->num_rx_queues) + dev->_rx[queue_id].pool = NULL; + if (queue_id < dev->num_tx_queues) + dev->_tx[queue_id].pool = NULL; +} + +/* The buffer pool is stored both in the _rx struct and the _tx struct as we do + * not know if the device has more tx queues than rx, or the opposite. + * This might also change during run time. 
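+ *
+ * For example (illustrative only): on a device with eight real rx queues
+ * but only four real tx queues, a socket bound to queue_id 6 gets its pool
+ * stored in the _rx slot only, and xsk_get_pool_from_qid() above finds it
+ * there; for queue_id 2 both slots end up pointing at the same pool.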
+ */ +int xsk_reg_pool_at_qid(struct net_device *dev, struct xsk_buff_pool *pool, + u16 queue_id) +{ + if (queue_id >= max_t(unsigned int, + dev->real_num_rx_queues, + dev->real_num_tx_queues)) + return -EINVAL; + + if (queue_id < dev->real_num_rx_queues) + dev->_rx[queue_id].pool = pool; + if (queue_id < dev->real_num_tx_queues) + dev->_tx[queue_id].pool = pool; + + return 0; +} + +static int __xsk_rcv_zc(struct xdp_sock *xs, struct xdp_buff *xdp, u32 len) +{ + struct xdp_buff_xsk *xskb = container_of(xdp, struct xdp_buff_xsk, xdp); + u64 addr; + int err; + + addr = xp_get_handle(xskb); + err = xskq_prod_reserve_desc(xs->rx, addr, len); + if (err) { + xs->rx_queue_full++; + return err; + } + + xp_release(xskb); + return 0; +} + +static void xsk_copy_xdp(struct xdp_buff *to, struct xdp_buff *from, u32 len) +{ + void *from_buf, *to_buf; + u32 metalen; + + if (unlikely(xdp_data_meta_unsupported(from))) { + from_buf = from->data; + to_buf = to->data; + metalen = 0; + } else { + from_buf = from->data_meta; + metalen = from->data - from->data_meta; + to_buf = to->data - metalen; + } + + memcpy(to_buf, from_buf, len + metalen); +} + +static int __xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp) +{ + struct xdp_buff *xsk_xdp; + int err; + u32 len; + + len = xdp->data_end - xdp->data; + if (len > xsk_pool_get_rx_frame_size(xs->pool)) { + xs->rx_dropped++; + return -ENOSPC; + } + + xsk_xdp = xsk_buff_alloc(xs->pool); + if (!xsk_xdp) { + xs->rx_dropped++; + return -ENOMEM; + } + + xsk_copy_xdp(xsk_xdp, xdp, len); + err = __xsk_rcv_zc(xs, xsk_xdp, len); + if (err) { + xsk_buff_free(xsk_xdp); + return err; + } + return 0; +} + +static bool xsk_tx_writeable(struct xdp_sock *xs) +{ + if (xskq_cons_present_entries(xs->tx) > xs->tx->nentries / 2) + return false; + + return true; +} + +static bool xsk_is_bound(struct xdp_sock *xs) +{ + if (READ_ONCE(xs->state) == XSK_BOUND) { + /* Matches smp_wmb() in bind(). 
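+		 * The pairing smp_wmb() is issued in xsk_bind() just before
+		 * xs->state is set to XSK_BOUND, so once the bound state is
+		 * observed here the dev, queue_id and pool fields written
+		 * during bind are visible as well.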
*/ + smp_rmb(); + return true; + } + return false; +} + +static int xsk_rcv_check(struct xdp_sock *xs, struct xdp_buff *xdp) +{ + if (!xsk_is_bound(xs)) + return -ENXIO; + + if (xs->dev != xdp->rxq->dev || xs->queue_id != xdp->rxq->queue_index) + return -EINVAL; + + sk_mark_napi_id_once_xdp(&xs->sk, xdp); + return 0; +} + +static void xsk_flush(struct xdp_sock *xs) +{ + xskq_prod_submit(xs->rx); + __xskq_cons_release(xs->pool->fq); + sock_def_readable(&xs->sk); +} + +int xsk_generic_rcv(struct xdp_sock *xs, struct xdp_buff *xdp) +{ + int err; + + spin_lock_bh(&xs->rx_lock); + err = xsk_rcv_check(xs, xdp); + if (!err) { + err = __xsk_rcv(xs, xdp); + xsk_flush(xs); + } + spin_unlock_bh(&xs->rx_lock); + return err; +} + +static int xsk_rcv(struct xdp_sock *xs, struct xdp_buff *xdp) +{ + int err; + u32 len; + + err = xsk_rcv_check(xs, xdp); + if (err) + return err; + + if (xdp->rxq->mem.type == MEM_TYPE_XSK_BUFF_POOL) { + len = xdp->data_end - xdp->data; + return __xsk_rcv_zc(xs, xdp, len); + } + + err = __xsk_rcv(xs, xdp); + if (!err) + xdp_return_buff(xdp); + return err; +} + +int __xsk_map_redirect(struct xdp_sock *xs, struct xdp_buff *xdp) +{ + struct list_head *flush_list = this_cpu_ptr(&xskmap_flush_list); + int err; + + err = xsk_rcv(xs, xdp); + if (err) + return err; + + if (!xs->flush_node.prev) + list_add(&xs->flush_node, flush_list); + + return 0; +} + +void __xsk_map_flush(void) +{ + struct list_head *flush_list = this_cpu_ptr(&xskmap_flush_list); + struct xdp_sock *xs, *tmp; + + list_for_each_entry_safe(xs, tmp, flush_list, flush_node) { + xsk_flush(xs); + __list_del_clearprev(&xs->flush_node); + } +} + +void xsk_tx_completed(struct xsk_buff_pool *pool, u32 nb_entries) +{ + xskq_prod_submit_n(pool->cq, nb_entries); +} +EXPORT_SYMBOL(xsk_tx_completed); + +void xsk_tx_release(struct xsk_buff_pool *pool) +{ + struct xdp_sock *xs; + + rcu_read_lock(); + list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) { + __xskq_cons_release(xs->tx); + if (xsk_tx_writeable(xs)) + xs->sk.sk_write_space(&xs->sk); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(xsk_tx_release); + +bool xsk_tx_peek_desc(struct xsk_buff_pool *pool, struct xdp_desc *desc) +{ + struct xdp_sock *xs; + + rcu_read_lock(); + list_for_each_entry_rcu(xs, &pool->xsk_tx_list, tx_list) { + if (!xskq_cons_peek_desc(xs->tx, desc, pool)) { + xs->tx->queue_empty_descs++; + continue; + } + + /* This is the backpressure mechanism for the Tx path. + * Reserve space in the completion queue and only proceed + * if there is space in it. This avoids having to implement + * any buffering in the Tx path. 
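+		 *
+		 * In other words, a descriptor is only released from the Tx
+		 * ring (xskq_cons_release() below) after a completion entry
+		 * has been reserved for its address; if the reservation
+		 * fails, the descriptor stays on the Tx ring and is picked
+		 * up again on a later attempt.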
+ */ + if (xskq_prod_reserve_addr(pool->cq, desc->addr)) + goto out; + + xskq_cons_release(xs->tx); + rcu_read_unlock(); + return true; + } + +out: + rcu_read_unlock(); + return false; +} +EXPORT_SYMBOL(xsk_tx_peek_desc); + +static u32 xsk_tx_peek_release_fallback(struct xsk_buff_pool *pool, u32 max_entries) +{ + struct xdp_desc *descs = pool->tx_descs; + u32 nb_pkts = 0; + + while (nb_pkts < max_entries && xsk_tx_peek_desc(pool, &descs[nb_pkts])) + nb_pkts++; + + xsk_tx_release(pool); + return nb_pkts; +} + +u32 xsk_tx_peek_release_desc_batch(struct xsk_buff_pool *pool, u32 nb_pkts) +{ + struct xdp_sock *xs; + + rcu_read_lock(); + if (!list_is_singular(&pool->xsk_tx_list)) { + /* Fallback to the non-batched version */ + rcu_read_unlock(); + return xsk_tx_peek_release_fallback(pool, nb_pkts); + } + + xs = list_first_or_null_rcu(&pool->xsk_tx_list, struct xdp_sock, tx_list); + if (!xs) { + nb_pkts = 0; + goto out; + } + + nb_pkts = xskq_cons_nb_entries(xs->tx, nb_pkts); + + /* This is the backpressure mechanism for the Tx path. Try to + * reserve space in the completion queue for all packets, but + * if there are fewer slots available, just process that many + * packets. This avoids having to implement any buffering in + * the Tx path. + */ + nb_pkts = xskq_prod_nb_free(pool->cq, nb_pkts); + if (!nb_pkts) + goto out; + + nb_pkts = xskq_cons_read_desc_batch(xs->tx, pool, nb_pkts); + if (!nb_pkts) { + xs->tx->queue_empty_descs++; + goto out; + } + + __xskq_cons_release(xs->tx); + xskq_prod_write_addr_batch(pool->cq, pool->tx_descs, nb_pkts); + xs->sk.sk_write_space(&xs->sk); + +out: + rcu_read_unlock(); + return nb_pkts; +} +EXPORT_SYMBOL(xsk_tx_peek_release_desc_batch); + +static int xsk_wakeup(struct xdp_sock *xs, u8 flags) +{ + struct net_device *dev = xs->dev; + + return dev->netdev_ops->ndo_xsk_wakeup(dev, xs->queue_id, flags); +} + +static void xsk_destruct_skb(struct sk_buff *skb) +{ + u64 addr = (u64)(long)skb_shinfo(skb)->destructor_arg; + struct xdp_sock *xs = xdp_sk(skb->sk); + unsigned long flags; + + spin_lock_irqsave(&xs->pool->cq_lock, flags); + xskq_prod_submit_addr(xs->pool->cq, addr); + spin_unlock_irqrestore(&xs->pool->cq_lock, flags); + + sock_wfree(skb); +} + +static struct sk_buff *xsk_build_skb_zerocopy(struct xdp_sock *xs, + struct xdp_desc *desc) +{ + struct xsk_buff_pool *pool = xs->pool; + u32 hr, len, ts, offset, copy, copied; + struct sk_buff *skb; + struct page *page; + void *buffer; + int err, i; + u64 addr; + + hr = max(NET_SKB_PAD, L1_CACHE_ALIGN(xs->dev->needed_headroom)); + + skb = sock_alloc_send_skb(&xs->sk, hr, 1, &err); + if (unlikely(!skb)) + return ERR_PTR(err); + + skb_reserve(skb, hr); + + addr = desc->addr; + len = desc->len; + ts = pool->unaligned ? 
len : pool->chunk_size; + + buffer = xsk_buff_raw_get_data(pool, addr); + offset = offset_in_page(buffer); + addr = buffer - pool->addrs; + + for (copied = 0, i = 0; copied < len; i++) { + page = pool->umem->pgs[addr >> PAGE_SHIFT]; + get_page(page); + + copy = min_t(u32, PAGE_SIZE - offset, len - copied); + skb_fill_page_desc(skb, i, page, offset, copy); + + copied += copy; + addr += copy; + offset = 0; + } + + skb->len += len; + skb->data_len += len; + skb->truesize += ts; + + refcount_add(ts, &xs->sk.sk_wmem_alloc); + + return skb; +} + +static struct sk_buff *xsk_build_skb(struct xdp_sock *xs, + struct xdp_desc *desc) +{ + struct net_device *dev = xs->dev; + struct sk_buff *skb; + + if (dev->priv_flags & IFF_TX_SKB_NO_LINEAR) { + skb = xsk_build_skb_zerocopy(xs, desc); + if (IS_ERR(skb)) + return skb; + } else { + u32 hr, tr, len; + void *buffer; + int err; + + hr = max(NET_SKB_PAD, L1_CACHE_ALIGN(dev->needed_headroom)); + tr = dev->needed_tailroom; + len = desc->len; + + skb = sock_alloc_send_skb(&xs->sk, hr + len + tr, 1, &err); + if (unlikely(!skb)) + return ERR_PTR(err); + + skb_reserve(skb, hr); + skb_put(skb, len); + + buffer = xsk_buff_raw_get_data(xs->pool, desc->addr); + err = skb_store_bits(skb, 0, buffer, len); + if (unlikely(err)) { + kfree_skb(skb); + return ERR_PTR(err); + } + } + + skb->dev = dev; + skb->priority = xs->sk.sk_priority; + skb->mark = READ_ONCE(xs->sk.sk_mark); + skb_shinfo(skb)->destructor_arg = (void *)(long)desc->addr; + skb->destructor = xsk_destruct_skb; + + return skb; +} + +static int __xsk_generic_xmit(struct sock *sk) +{ + struct xdp_sock *xs = xdp_sk(sk); + u32 max_batch = TX_BATCH_SIZE; + bool sent_frame = false; + struct xdp_desc desc; + struct sk_buff *skb; + unsigned long flags; + int err = 0; + + mutex_lock(&xs->mutex); + + /* Since we dropped the RCU read lock, the socket state might have changed. */ + if (unlikely(!xsk_is_bound(xs))) { + err = -ENXIO; + goto out; + } + + if (xs->queue_id >= xs->dev->real_num_tx_queues) + goto out; + + while (xskq_cons_peek_desc(xs->tx, &desc, xs->pool)) { + if (max_batch-- == 0) { + err = -EAGAIN; + goto out; + } + + /* This is the backpressure mechanism for the Tx path. + * Reserve space in the completion queue and only proceed + * if there is space in it. This avoids having to implement + * any buffering in the Tx path. 
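+		 *
+		 * The slot reserved below is cancelled again with
+		 * xskq_prod_cancel() if the skb cannot be built or the
+		 * driver returns NETDEV_TX_BUSY, and is completed from
+		 * xsk_destruct_skb() once the skb has actually been
+		 * consumed.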
+ */ + spin_lock_irqsave(&xs->pool->cq_lock, flags); + if (xskq_prod_reserve(xs->pool->cq)) { + spin_unlock_irqrestore(&xs->pool->cq_lock, flags); + goto out; + } + spin_unlock_irqrestore(&xs->pool->cq_lock, flags); + + skb = xsk_build_skb(xs, &desc); + if (IS_ERR(skb)) { + err = PTR_ERR(skb); + spin_lock_irqsave(&xs->pool->cq_lock, flags); + xskq_prod_cancel(xs->pool->cq); + spin_unlock_irqrestore(&xs->pool->cq_lock, flags); + goto out; + } + + err = __dev_direct_xmit(skb, xs->queue_id); + if (err == NETDEV_TX_BUSY) { + /* Tell user-space to retry the send */ + skb->destructor = sock_wfree; + spin_lock_irqsave(&xs->pool->cq_lock, flags); + xskq_prod_cancel(xs->pool->cq); + spin_unlock_irqrestore(&xs->pool->cq_lock, flags); + /* Free skb without triggering the perf drop trace */ + consume_skb(skb); + err = -EAGAIN; + goto out; + } + + xskq_cons_release(xs->tx); + /* Ignore NET_XMIT_CN as packet might have been sent */ + if (err == NET_XMIT_DROP) { + /* SKB completed but not sent */ + err = -EBUSY; + goto out; + } + + sent_frame = true; + } + + xs->tx->queue_empty_descs++; + +out: + if (sent_frame) + if (xsk_tx_writeable(xs)) + sk->sk_write_space(sk); + + mutex_unlock(&xs->mutex); + return err; +} + +static int xsk_generic_xmit(struct sock *sk) +{ + int ret; + + /* Drop the RCU lock since the SKB path might sleep. */ + rcu_read_unlock(); + ret = __xsk_generic_xmit(sk); + /* Reaquire RCU lock before going into common code. */ + rcu_read_lock(); + + return ret; +} + +static bool xsk_no_wakeup(struct sock *sk) +{ +#ifdef CONFIG_NET_RX_BUSY_POLL + /* Prefer busy-polling, skip the wakeup. */ + return READ_ONCE(sk->sk_prefer_busy_poll) && READ_ONCE(sk->sk_ll_usec) && + READ_ONCE(sk->sk_napi_id) >= MIN_NAPI_ID; +#else + return false; +#endif +} + +static int xsk_check_common(struct xdp_sock *xs) +{ + if (unlikely(!xsk_is_bound(xs))) + return -ENXIO; + if (unlikely(!(xs->dev->flags & IFF_UP))) + return -ENETDOWN; + + return 0; +} + +static int __xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) +{ + bool need_wait = !(m->msg_flags & MSG_DONTWAIT); + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + struct xsk_buff_pool *pool; + int err; + + err = xsk_check_common(xs); + if (err) + return err; + if (unlikely(need_wait)) + return -EOPNOTSUPP; + if (unlikely(!xs->tx)) + return -ENOBUFS; + + if (sk_can_busy_loop(sk)) { + if (xs->zc) + __sk_mark_napi_id_once(sk, xsk_pool_get_napi_id(xs->pool)); + sk_busy_loop(sk, 1); /* only support non-blocking sockets */ + } + + if (xs->zc && xsk_no_wakeup(sk)) + return 0; + + pool = xs->pool; + if (pool->cached_need_wakeup & XDP_WAKEUP_TX) { + if (xs->zc) + return xsk_wakeup(xs, XDP_WAKEUP_TX); + return xsk_generic_xmit(sk); + } + return 0; +} + +static int xsk_sendmsg(struct socket *sock, struct msghdr *m, size_t total_len) +{ + int ret; + + rcu_read_lock(); + ret = __xsk_sendmsg(sock, m, total_len); + rcu_read_unlock(); + + return ret; +} + +static int __xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int flags) +{ + bool need_wait = !(flags & MSG_DONTWAIT); + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + int err; + + err = xsk_check_common(xs); + if (err) + return err; + if (unlikely(!xs->rx)) + return -ENOBUFS; + if (unlikely(need_wait)) + return -EOPNOTSUPP; + + if (sk_can_busy_loop(sk)) + sk_busy_loop(sk, 1); /* only support non-blocking sockets */ + + if (xsk_no_wakeup(sk)) + return 0; + + if (xs->pool->cached_need_wakeup & XDP_WAKEUP_RX && xs->zc) + return xsk_wakeup(xs, XDP_WAKEUP_RX); + 
return 0; +} + +static int xsk_recvmsg(struct socket *sock, struct msghdr *m, size_t len, int flags) +{ + int ret; + + rcu_read_lock(); + ret = __xsk_recvmsg(sock, m, len, flags); + rcu_read_unlock(); + + return ret; +} + +static __poll_t xsk_poll(struct file *file, struct socket *sock, + struct poll_table_struct *wait) +{ + __poll_t mask = 0; + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + struct xsk_buff_pool *pool; + + sock_poll_wait(file, sock, wait); + + rcu_read_lock(); + if (xsk_check_common(xs)) + goto out; + + pool = xs->pool; + + if (pool->cached_need_wakeup) { + if (xs->zc) + xsk_wakeup(xs, pool->cached_need_wakeup); + else if (xs->tx) + /* Poll needs to drive Tx also in copy mode */ + xsk_generic_xmit(sk); + } + + if (xs->rx && !xskq_prod_is_empty(xs->rx)) + mask |= EPOLLIN | EPOLLRDNORM; + if (xs->tx && xsk_tx_writeable(xs)) + mask |= EPOLLOUT | EPOLLWRNORM; +out: + rcu_read_unlock(); + return mask; +} + +static int xsk_init_queue(u32 entries, struct xsk_queue **queue, + bool umem_queue) +{ + struct xsk_queue *q; + + if (entries == 0 || *queue || !is_power_of_2(entries)) + return -EINVAL; + + q = xskq_create(entries, umem_queue); + if (!q) + return -ENOMEM; + + /* Make sure queue is ready before it can be seen by others */ + smp_wmb(); + WRITE_ONCE(*queue, q); + return 0; +} + +static void xsk_unbind_dev(struct xdp_sock *xs) +{ + struct net_device *dev = xs->dev; + + if (xs->state != XSK_BOUND) + return; + WRITE_ONCE(xs->state, XSK_UNBOUND); + + /* Wait for driver to stop using the xdp socket. */ + xp_del_xsk(xs->pool, xs); + synchronize_net(); + dev_put(dev); +} + +static struct xsk_map *xsk_get_map_list_entry(struct xdp_sock *xs, + struct xdp_sock __rcu ***map_entry) +{ + struct xsk_map *map = NULL; + struct xsk_map_node *node; + + *map_entry = NULL; + + spin_lock_bh(&xs->map_list_lock); + node = list_first_entry_or_null(&xs->map_list, struct xsk_map_node, + node); + if (node) { + bpf_map_inc(&node->map->map); + map = node->map; + *map_entry = node->map_entry; + } + spin_unlock_bh(&xs->map_list_lock); + return map; +} + +static void xsk_delete_from_maps(struct xdp_sock *xs) +{ + /* This function removes the current XDP socket from all the + * maps it resides in. We need to take extra care here, due to + * the two locks involved. Each map has a lock synchronizing + * updates to the entries, and each socket has a lock that + * synchronizes access to the list of maps (map_list). For + * deadlock avoidance the locks need to be taken in the order + * "map lock"->"socket map list lock". We start off by + * accessing the socket map list, and take a reference to the + * map to guarantee existence between the + * xsk_get_map_list_entry() and xsk_map_try_sock_delete() + * calls. Then we ask the map to remove the socket, which + * tries to remove the socket from the map. Note that there + * might be updates to the map between + * xsk_get_map_list_entry() and xsk_map_try_sock_delete(). 
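+	 *
+	 * The loop below therefore repeats the sequence
+	 *   map = xsk_get_map_list_entry(xs, &map_entry);   (map_list lock)
+	 *   xsk_map_try_sock_delete(map, xs, map_entry);    (map lock)
+	 *   bpf_map_put(&map->map);
+	 * until no map entry referencing this socket is left.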
+ */ + struct xdp_sock __rcu **map_entry = NULL; + struct xsk_map *map; + + while ((map = xsk_get_map_list_entry(xs, &map_entry))) { + xsk_map_try_sock_delete(map, xs, map_entry); + bpf_map_put(&map->map); + } +} + +static int xsk_release(struct socket *sock) +{ + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + struct net *net; + + if (!sk) + return 0; + + net = sock_net(sk); + + mutex_lock(&net->xdp.lock); + sk_del_node_init_rcu(sk); + mutex_unlock(&net->xdp.lock); + + sock_prot_inuse_add(net, sk->sk_prot, -1); + + xsk_delete_from_maps(xs); + mutex_lock(&xs->mutex); + xsk_unbind_dev(xs); + mutex_unlock(&xs->mutex); + + xskq_destroy(xs->rx); + xskq_destroy(xs->tx); + xskq_destroy(xs->fq_tmp); + xskq_destroy(xs->cq_tmp); + + sock_orphan(sk); + sock->sk = NULL; + + sk_refcnt_debug_release(sk); + sock_put(sk); + + return 0; +} + +static struct socket *xsk_lookup_xsk_from_fd(int fd) +{ + struct socket *sock; + int err; + + sock = sockfd_lookup(fd, &err); + if (!sock) + return ERR_PTR(-ENOTSOCK); + + if (sock->sk->sk_family != PF_XDP) { + sockfd_put(sock); + return ERR_PTR(-ENOPROTOOPT); + } + + return sock; +} + +static bool xsk_validate_queues(struct xdp_sock *xs) +{ + return xs->fq_tmp && xs->cq_tmp; +} + +static int xsk_bind(struct socket *sock, struct sockaddr *addr, int addr_len) +{ + struct sockaddr_xdp *sxdp = (struct sockaddr_xdp *)addr; + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + struct net_device *dev; + int bound_dev_if; + u32 flags, qid; + int err = 0; + + if (addr_len < sizeof(struct sockaddr_xdp)) + return -EINVAL; + if (sxdp->sxdp_family != AF_XDP) + return -EINVAL; + + flags = sxdp->sxdp_flags; + if (flags & ~(XDP_SHARED_UMEM | XDP_COPY | XDP_ZEROCOPY | + XDP_USE_NEED_WAKEUP)) + return -EINVAL; + + bound_dev_if = READ_ONCE(sk->sk_bound_dev_if); + if (bound_dev_if && bound_dev_if != sxdp->sxdp_ifindex) + return -EINVAL; + + rtnl_lock(); + mutex_lock(&xs->mutex); + if (xs->state != XSK_READY) { + err = -EBUSY; + goto out_release; + } + + dev = dev_get_by_index(sock_net(sk), sxdp->sxdp_ifindex); + if (!dev) { + err = -ENODEV; + goto out_release; + } + + if (!xs->rx && !xs->tx) { + err = -EINVAL; + goto out_unlock; + } + + qid = sxdp->sxdp_queue_id; + + if (flags & XDP_SHARED_UMEM) { + struct xdp_sock *umem_xs; + struct socket *sock; + + if ((flags & XDP_COPY) || (flags & XDP_ZEROCOPY) || + (flags & XDP_USE_NEED_WAKEUP)) { + /* Cannot specify flags for shared sockets. */ + err = -EINVAL; + goto out_unlock; + } + + if (xs->umem) { + /* We have already our own. */ + err = -EINVAL; + goto out_unlock; + } + + sock = xsk_lookup_xsk_from_fd(sxdp->sxdp_shared_umem_fd); + if (IS_ERR(sock)) { + err = PTR_ERR(sock); + goto out_unlock; + } + + umem_xs = xdp_sk(sock->sk); + if (!xsk_is_bound(umem_xs)) { + err = -EBADF; + sockfd_put(sock); + goto out_unlock; + } + + if (umem_xs->queue_id != qid || umem_xs->dev != dev) { + /* Share the umem with another socket on another qid + * and/or device. + */ + xs->pool = xp_create_and_assign_umem(xs, + umem_xs->umem); + if (!xs->pool) { + err = -ENOMEM; + sockfd_put(sock); + goto out_unlock; + } + + err = xp_assign_dev_shared(xs->pool, umem_xs, dev, + qid); + if (err) { + xp_destroy(xs->pool); + xs->pool = NULL; + sockfd_put(sock); + goto out_unlock; + } + } else { + /* Share the buffer pool with the other socket. */ + if (xs->fq_tmp || xs->cq_tmp) { + /* Do not allow setting your own fq or cq. 
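+				 * The shared pool already owns the fill and
+				 * completion rings of the first socket, so
+				 * rings created on this socket via
+				 * XDP_UMEM_FILL_RING or
+				 * XDP_UMEM_COMPLETION_RING would never be
+				 * used; the bind is rejected instead.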
*/ + err = -EINVAL; + sockfd_put(sock); + goto out_unlock; + } + + xp_get_pool(umem_xs->pool); + xs->pool = umem_xs->pool; + + /* If underlying shared umem was created without Tx + * ring, allocate Tx descs array that Tx batching API + * utilizes + */ + if (xs->tx && !xs->pool->tx_descs) { + err = xp_alloc_tx_descs(xs->pool, xs); + if (err) { + xp_put_pool(xs->pool); + xs->pool = NULL; + sockfd_put(sock); + goto out_unlock; + } + } + } + + xdp_get_umem(umem_xs->umem); + WRITE_ONCE(xs->umem, umem_xs->umem); + sockfd_put(sock); + } else if (!xs->umem || !xsk_validate_queues(xs)) { + err = -EINVAL; + goto out_unlock; + } else { + /* This xsk has its own umem. */ + xs->pool = xp_create_and_assign_umem(xs, xs->umem); + if (!xs->pool) { + err = -ENOMEM; + goto out_unlock; + } + + err = xp_assign_dev(xs->pool, dev, qid, flags); + if (err) { + xp_destroy(xs->pool); + xs->pool = NULL; + goto out_unlock; + } + } + + /* FQ and CQ are now owned by the buffer pool and cleaned up with it. */ + xs->fq_tmp = NULL; + xs->cq_tmp = NULL; + + xs->dev = dev; + xs->zc = xs->umem->zc; + xs->queue_id = qid; + xp_add_xsk(xs->pool, xs); + +out_unlock: + if (err) { + dev_put(dev); + } else { + /* Matches smp_rmb() in bind() for shared umem + * sockets, and xsk_is_bound(). + */ + smp_wmb(); + WRITE_ONCE(xs->state, XSK_BOUND); + } +out_release: + mutex_unlock(&xs->mutex); + rtnl_unlock(); + return err; +} + +struct xdp_umem_reg_v1 { + __u64 addr; /* Start of packet data area */ + __u64 len; /* Length of packet data area */ + __u32 chunk_size; + __u32 headroom; +}; + +static int xsk_setsockopt(struct socket *sock, int level, int optname, + sockptr_t optval, unsigned int optlen) +{ + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + int err; + + if (level != SOL_XDP) + return -ENOPROTOOPT; + + switch (optname) { + case XDP_RX_RING: + case XDP_TX_RING: + { + struct xsk_queue **q; + int entries; + + if (optlen < sizeof(entries)) + return -EINVAL; + if (copy_from_sockptr(&entries, optval, sizeof(entries))) + return -EFAULT; + + mutex_lock(&xs->mutex); + if (xs->state != XSK_READY) { + mutex_unlock(&xs->mutex); + return -EBUSY; + } + q = (optname == XDP_TX_RING) ? &xs->tx : &xs->rx; + err = xsk_init_queue(entries, q, false); + if (!err && optname == XDP_TX_RING) + /* Tx needs to be explicitly woken up the first time */ + xs->tx->ring->flags |= XDP_RING_NEED_WAKEUP; + mutex_unlock(&xs->mutex); + return err; + } + case XDP_UMEM_REG: + { + size_t mr_size = sizeof(struct xdp_umem_reg); + struct xdp_umem_reg mr = {}; + struct xdp_umem *umem; + + if (optlen < sizeof(struct xdp_umem_reg_v1)) + return -EINVAL; + else if (optlen < sizeof(mr)) + mr_size = sizeof(struct xdp_umem_reg_v1); + + if (copy_from_sockptr(&mr, optval, mr_size)) + return -EFAULT; + + mutex_lock(&xs->mutex); + if (xs->state != XSK_READY || xs->umem) { + mutex_unlock(&xs->mutex); + return -EBUSY; + } + + umem = xdp_umem_create(&mr); + if (IS_ERR(umem)) { + mutex_unlock(&xs->mutex); + return PTR_ERR(umem); + } + + /* Make sure umem is ready before it can be seen by others */ + smp_wmb(); + WRITE_ONCE(xs->umem, umem); + mutex_unlock(&xs->mutex); + return 0; + } + case XDP_UMEM_FILL_RING: + case XDP_UMEM_COMPLETION_RING: + { + struct xsk_queue **q; + int entries; + + if (copy_from_sockptr(&entries, optval, sizeof(entries))) + return -EFAULT; + + mutex_lock(&xs->mutex); + if (xs->state != XSK_READY) { + mutex_unlock(&xs->mutex); + return -EBUSY; + } + + q = (optname == XDP_UMEM_FILL_RING) ? 
&xs->fq_tmp : + &xs->cq_tmp; + err = xsk_init_queue(entries, q, true); + mutex_unlock(&xs->mutex); + return err; + } + default: + break; + } + + return -ENOPROTOOPT; +} + +static void xsk_enter_rxtx_offsets(struct xdp_ring_offset_v1 *ring) +{ + ring->producer = offsetof(struct xdp_rxtx_ring, ptrs.producer); + ring->consumer = offsetof(struct xdp_rxtx_ring, ptrs.consumer); + ring->desc = offsetof(struct xdp_rxtx_ring, desc); +} + +static void xsk_enter_umem_offsets(struct xdp_ring_offset_v1 *ring) +{ + ring->producer = offsetof(struct xdp_umem_ring, ptrs.producer); + ring->consumer = offsetof(struct xdp_umem_ring, ptrs.consumer); + ring->desc = offsetof(struct xdp_umem_ring, desc); +} + +struct xdp_statistics_v1 { + __u64 rx_dropped; + __u64 rx_invalid_descs; + __u64 tx_invalid_descs; +}; + +static int xsk_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct xdp_sock *xs = xdp_sk(sk); + int len; + + if (level != SOL_XDP) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + switch (optname) { + case XDP_STATISTICS: + { + struct xdp_statistics stats = {}; + bool extra_stats = true; + size_t stats_size; + + if (len < sizeof(struct xdp_statistics_v1)) { + return -EINVAL; + } else if (len < sizeof(stats)) { + extra_stats = false; + stats_size = sizeof(struct xdp_statistics_v1); + } else { + stats_size = sizeof(stats); + } + + mutex_lock(&xs->mutex); + stats.rx_dropped = xs->rx_dropped; + if (extra_stats) { + stats.rx_ring_full = xs->rx_queue_full; + stats.rx_fill_ring_empty_descs = + xs->pool ? xskq_nb_queue_empty_descs(xs->pool->fq) : 0; + stats.tx_ring_empty_descs = xskq_nb_queue_empty_descs(xs->tx); + } else { + stats.rx_dropped += xs->rx_queue_full; + } + stats.rx_invalid_descs = xskq_nb_invalid_descs(xs->rx); + stats.tx_invalid_descs = xskq_nb_invalid_descs(xs->tx); + mutex_unlock(&xs->mutex); + + if (copy_to_user(optval, &stats, stats_size)) + return -EFAULT; + if (put_user(stats_size, optlen)) + return -EFAULT; + + return 0; + } + case XDP_MMAP_OFFSETS: + { + struct xdp_mmap_offsets off; + struct xdp_mmap_offsets_v1 off_v1; + bool flags_supported = true; + void *to_copy; + + if (len < sizeof(off_v1)) + return -EINVAL; + else if (len < sizeof(off)) + flags_supported = false; + + if (flags_supported) { + /* xdp_ring_offset is identical to xdp_ring_offset_v1 + * except for the flags field added to the end. 
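+			 * That is why the v1 helpers above can be reused on a
+			 * cast pointer for the current layout as well; only
+			 * the flags offsets have to be filled in separately
+			 * below. Which of the two layouts is copied back to
+			 * user space is decided by the optlen the caller
+			 * passed in.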
+ */ + xsk_enter_rxtx_offsets((struct xdp_ring_offset_v1 *) + &off.rx); + xsk_enter_rxtx_offsets((struct xdp_ring_offset_v1 *) + &off.tx); + xsk_enter_umem_offsets((struct xdp_ring_offset_v1 *) + &off.fr); + xsk_enter_umem_offsets((struct xdp_ring_offset_v1 *) + &off.cr); + off.rx.flags = offsetof(struct xdp_rxtx_ring, + ptrs.flags); + off.tx.flags = offsetof(struct xdp_rxtx_ring, + ptrs.flags); + off.fr.flags = offsetof(struct xdp_umem_ring, + ptrs.flags); + off.cr.flags = offsetof(struct xdp_umem_ring, + ptrs.flags); + + len = sizeof(off); + to_copy = &off; + } else { + xsk_enter_rxtx_offsets(&off_v1.rx); + xsk_enter_rxtx_offsets(&off_v1.tx); + xsk_enter_umem_offsets(&off_v1.fr); + xsk_enter_umem_offsets(&off_v1.cr); + + len = sizeof(off_v1); + to_copy = &off_v1; + } + + if (copy_to_user(optval, to_copy, len)) + return -EFAULT; + if (put_user(len, optlen)) + return -EFAULT; + + return 0; + } + case XDP_OPTIONS: + { + struct xdp_options opts = {}; + + if (len < sizeof(opts)) + return -EINVAL; + + mutex_lock(&xs->mutex); + if (xs->zc) + opts.flags |= XDP_OPTIONS_ZEROCOPY; + mutex_unlock(&xs->mutex); + + len = sizeof(opts); + if (copy_to_user(optval, &opts, len)) + return -EFAULT; + if (put_user(len, optlen)) + return -EFAULT; + + return 0; + } + default: + break; + } + + return -EOPNOTSUPP; +} + +static int xsk_mmap(struct file *file, struct socket *sock, + struct vm_area_struct *vma) +{ + loff_t offset = (loff_t)vma->vm_pgoff << PAGE_SHIFT; + unsigned long size = vma->vm_end - vma->vm_start; + struct xdp_sock *xs = xdp_sk(sock->sk); + struct xsk_queue *q = NULL; + unsigned long pfn; + struct page *qpg; + + if (READ_ONCE(xs->state) != XSK_READY) + return -EBUSY; + + if (offset == XDP_PGOFF_RX_RING) { + q = READ_ONCE(xs->rx); + } else if (offset == XDP_PGOFF_TX_RING) { + q = READ_ONCE(xs->tx); + } else { + /* Matches the smp_wmb() in XDP_UMEM_REG */ + smp_rmb(); + if (offset == XDP_UMEM_PGOFF_FILL_RING) + q = READ_ONCE(xs->fq_tmp); + else if (offset == XDP_UMEM_PGOFF_COMPLETION_RING) + q = READ_ONCE(xs->cq_tmp); + } + + if (!q) + return -EINVAL; + + /* Matches the smp_wmb() in xsk_init_queue */ + smp_rmb(); + qpg = virt_to_head_page(q->ring); + if (size > page_size(qpg)) + return -EINVAL; + + pfn = virt_to_phys(q->ring) >> PAGE_SHIFT; + return remap_pfn_range(vma, vma->vm_start, pfn, + size, vma->vm_page_prot); +} + +static int xsk_notifier(struct notifier_block *this, + unsigned long msg, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + struct net *net = dev_net(dev); + struct sock *sk; + + switch (msg) { + case NETDEV_UNREGISTER: + mutex_lock(&net->xdp.lock); + sk_for_each(sk, &net->xdp.list) { + struct xdp_sock *xs = xdp_sk(sk); + + mutex_lock(&xs->mutex); + if (xs->dev == dev) { + sk->sk_err = ENETDOWN; + if (!sock_flag(sk, SOCK_DEAD)) + sk_error_report(sk); + + xsk_unbind_dev(xs); + + /* Clear device references. 
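+				 * xp_clear_dev() tears down any zero-copy
+				 * state in the driver, removes the pool from
+				 * the queue-id lookup arrays and drops the
+				 * netdev reference, so nothing keeps pointing
+				 * at the disappearing device.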
*/ + xp_clear_dev(xs->pool); + } + mutex_unlock(&xs->mutex); + } + mutex_unlock(&net->xdp.lock); + break; + } + return NOTIFY_DONE; +} + +static struct proto xsk_proto = { + .name = "XDP", + .owner = THIS_MODULE, + .obj_size = sizeof(struct xdp_sock), +}; + +static const struct proto_ops xsk_proto_ops = { + .family = PF_XDP, + .owner = THIS_MODULE, + .release = xsk_release, + .bind = xsk_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = xsk_poll, + .ioctl = sock_no_ioctl, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = xsk_setsockopt, + .getsockopt = xsk_getsockopt, + .sendmsg = xsk_sendmsg, + .recvmsg = xsk_recvmsg, + .mmap = xsk_mmap, + .sendpage = sock_no_sendpage, +}; + +static void xsk_destruct(struct sock *sk) +{ + struct xdp_sock *xs = xdp_sk(sk); + + if (!sock_flag(sk, SOCK_DEAD)) + return; + + if (!xp_put_pool(xs->pool)) + xdp_put_umem(xs->umem, !xs->pool); + + sk_refcnt_debug_dec(sk); +} + +static int xsk_create(struct net *net, struct socket *sock, int protocol, + int kern) +{ + struct xdp_sock *xs; + struct sock *sk; + + if (!ns_capable(net->user_ns, CAP_NET_RAW)) + return -EPERM; + if (sock->type != SOCK_RAW) + return -ESOCKTNOSUPPORT; + + if (protocol) + return -EPROTONOSUPPORT; + + sock->state = SS_UNCONNECTED; + + sk = sk_alloc(net, PF_XDP, GFP_KERNEL, &xsk_proto, kern); + if (!sk) + return -ENOBUFS; + + sock->ops = &xsk_proto_ops; + + sock_init_data(sock, sk); + + sk->sk_family = PF_XDP; + + sk->sk_destruct = xsk_destruct; + sk_refcnt_debug_inc(sk); + + sock_set_flag(sk, SOCK_RCU_FREE); + + xs = xdp_sk(sk); + xs->state = XSK_READY; + mutex_init(&xs->mutex); + spin_lock_init(&xs->rx_lock); + + INIT_LIST_HEAD(&xs->map_list); + spin_lock_init(&xs->map_list_lock); + + mutex_lock(&net->xdp.lock); + sk_add_node_rcu(sk, &net->xdp.list); + mutex_unlock(&net->xdp.lock); + + sock_prot_inuse_add(net, &xsk_proto, 1); + + return 0; +} + +static const struct net_proto_family xsk_family_ops = { + .family = PF_XDP, + .create = xsk_create, + .owner = THIS_MODULE, +}; + +static struct notifier_block xsk_netdev_notifier = { + .notifier_call = xsk_notifier, +}; + +static int __net_init xsk_net_init(struct net *net) +{ + mutex_init(&net->xdp.lock); + INIT_HLIST_HEAD(&net->xdp.list); + return 0; +} + +static void __net_exit xsk_net_exit(struct net *net) +{ + WARN_ON_ONCE(!hlist_empty(&net->xdp.list)); +} + +static struct pernet_operations xsk_net_ops = { + .init = xsk_net_init, + .exit = xsk_net_exit, +}; + +static int __init xsk_init(void) +{ + int err, cpu; + + err = proto_register(&xsk_proto, 0 /* no slab */); + if (err) + goto out; + + err = sock_register(&xsk_family_ops); + if (err) + goto out_proto; + + err = register_pernet_subsys(&xsk_net_ops); + if (err) + goto out_sk; + + err = register_netdevice_notifier(&xsk_netdev_notifier); + if (err) + goto out_pernet; + + for_each_possible_cpu(cpu) + INIT_LIST_HEAD(&per_cpu(xskmap_flush_list, cpu)); + return 0; + +out_pernet: + unregister_pernet_subsys(&xsk_net_ops); +out_sk: + sock_unregister(PF_XDP); +out_proto: + proto_unregister(&xsk_proto); +out: + return err; +} + +fs_initcall(xsk_init); diff --git a/net/xdp/xsk.h b/net/xdp/xsk.h new file mode 100644 index 000000000..a4bc4749f --- /dev/null +++ b/net/xdp/xsk.h @@ -0,0 +1,48 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* Copyright(c) 2019 Intel Corporation. */ + +#ifndef XSK_H_ +#define XSK_H_ + +/* Masks for xdp_umem_page flags. 
+ * The low 12-bits of the addr will be 0 since this is the page address, so we + * can use them for flags. + */ +#define XSK_NEXT_PG_CONTIG_SHIFT 0 +#define XSK_NEXT_PG_CONTIG_MASK BIT_ULL(XSK_NEXT_PG_CONTIG_SHIFT) + +struct xdp_ring_offset_v1 { + __u64 producer; + __u64 consumer; + __u64 desc; +}; + +struct xdp_mmap_offsets_v1 { + struct xdp_ring_offset_v1 rx; + struct xdp_ring_offset_v1 tx; + struct xdp_ring_offset_v1 fr; + struct xdp_ring_offset_v1 cr; +}; + +/* Nodes are linked in the struct xdp_sock map_list field, and used to + * track which maps a certain socket reside in. + */ + +struct xsk_map_node { + struct list_head node; + struct xsk_map *map; + struct xdp_sock __rcu **map_entry; +}; + +static inline struct xdp_sock *xdp_sk(struct sock *sk) +{ + return (struct xdp_sock *)sk; +} + +void xsk_map_try_sock_delete(struct xsk_map *map, struct xdp_sock *xs, + struct xdp_sock __rcu **map_entry); +void xsk_clear_pool_at_qid(struct net_device *dev, u16 queue_id); +int xsk_reg_pool_at_qid(struct net_device *dev, struct xsk_buff_pool *pool, + u16 queue_id); + +#endif /* XSK_H_ */ diff --git a/net/xdp/xsk_buff_pool.c b/net/xdp/xsk_buff_pool.c new file mode 100644 index 000000000..ed6c71826 --- /dev/null +++ b/net/xdp/xsk_buff_pool.c @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include +#include +#include + +#include "xsk_queue.h" +#include "xdp_umem.h" +#include "xsk.h" + +void xp_add_xsk(struct xsk_buff_pool *pool, struct xdp_sock *xs) +{ + unsigned long flags; + + if (!xs->tx) + return; + + spin_lock_irqsave(&pool->xsk_tx_list_lock, flags); + list_add_rcu(&xs->tx_list, &pool->xsk_tx_list); + spin_unlock_irqrestore(&pool->xsk_tx_list_lock, flags); +} + +void xp_del_xsk(struct xsk_buff_pool *pool, struct xdp_sock *xs) +{ + unsigned long flags; + + if (!xs->tx) + return; + + spin_lock_irqsave(&pool->xsk_tx_list_lock, flags); + list_del_rcu(&xs->tx_list); + spin_unlock_irqrestore(&pool->xsk_tx_list_lock, flags); +} + +void xp_destroy(struct xsk_buff_pool *pool) +{ + if (!pool) + return; + + kvfree(pool->tx_descs); + kvfree(pool->heads); + kvfree(pool); +} + +int xp_alloc_tx_descs(struct xsk_buff_pool *pool, struct xdp_sock *xs) +{ + pool->tx_descs = kvcalloc(xs->tx->nentries, sizeof(*pool->tx_descs), + GFP_KERNEL); + if (!pool->tx_descs) + return -ENOMEM; + + return 0; +} + +struct xsk_buff_pool *xp_create_and_assign_umem(struct xdp_sock *xs, + struct xdp_umem *umem) +{ + bool unaligned = umem->flags & XDP_UMEM_UNALIGNED_CHUNK_FLAG; + struct xsk_buff_pool *pool; + struct xdp_buff_xsk *xskb; + u32 i, entries; + + entries = unaligned ? 
umem->chunks : 0; + pool = kvzalloc(struct_size(pool, free_heads, entries), GFP_KERNEL); + if (!pool) + goto out; + + pool->heads = kvcalloc(umem->chunks, sizeof(*pool->heads), GFP_KERNEL); + if (!pool->heads) + goto out; + + if (xs->tx) + if (xp_alloc_tx_descs(pool, xs)) + goto out; + + pool->chunk_mask = ~((u64)umem->chunk_size - 1); + pool->addrs_cnt = umem->size; + pool->heads_cnt = umem->chunks; + pool->free_heads_cnt = umem->chunks; + pool->headroom = umem->headroom; + pool->chunk_size = umem->chunk_size; + pool->chunk_shift = ffs(umem->chunk_size) - 1; + pool->unaligned = unaligned; + pool->frame_len = umem->chunk_size - umem->headroom - + XDP_PACKET_HEADROOM; + pool->umem = umem; + pool->addrs = umem->addrs; + INIT_LIST_HEAD(&pool->free_list); + INIT_LIST_HEAD(&pool->xsk_tx_list); + spin_lock_init(&pool->xsk_tx_list_lock); + spin_lock_init(&pool->cq_lock); + refcount_set(&pool->users, 1); + + pool->fq = xs->fq_tmp; + pool->cq = xs->cq_tmp; + + for (i = 0; i < pool->free_heads_cnt; i++) { + xskb = &pool->heads[i]; + xskb->pool = pool; + xskb->xdp.frame_sz = umem->chunk_size - umem->headroom; + INIT_LIST_HEAD(&xskb->free_list_node); + if (pool->unaligned) + pool->free_heads[i] = xskb; + else + xp_init_xskb_addr(xskb, pool, i * pool->chunk_size); + } + + return pool; + +out: + xp_destroy(pool); + return NULL; +} + +void xp_set_rxq_info(struct xsk_buff_pool *pool, struct xdp_rxq_info *rxq) +{ + u32 i; + + for (i = 0; i < pool->heads_cnt; i++) + pool->heads[i].xdp.rxq = rxq; +} +EXPORT_SYMBOL(xp_set_rxq_info); + +static void xp_disable_drv_zc(struct xsk_buff_pool *pool) +{ + struct netdev_bpf bpf; + int err; + + ASSERT_RTNL(); + + if (pool->umem->zc) { + bpf.command = XDP_SETUP_XSK_POOL; + bpf.xsk.pool = NULL; + bpf.xsk.queue_id = pool->queue_id; + + err = pool->netdev->netdev_ops->ndo_bpf(pool->netdev, &bpf); + + if (err) + WARN(1, "Failed to disable zero-copy!\n"); + } +} + +int xp_assign_dev(struct xsk_buff_pool *pool, + struct net_device *netdev, u16 queue_id, u16 flags) +{ + bool force_zc, force_copy; + struct netdev_bpf bpf; + int err = 0; + + ASSERT_RTNL(); + + force_zc = flags & XDP_ZEROCOPY; + force_copy = flags & XDP_COPY; + + if (force_zc && force_copy) + return -EINVAL; + + if (xsk_get_pool_from_qid(netdev, queue_id)) + return -EBUSY; + + pool->netdev = netdev; + pool->queue_id = queue_id; + err = xsk_reg_pool_at_qid(netdev, pool, queue_id); + if (err) + return err; + + if (flags & XDP_USE_NEED_WAKEUP) + pool->uses_need_wakeup = true; + /* Tx needs to be explicitly woken up the first time. Also + * for supporting drivers that do not implement this + * feature. They will always have to call sendto() or poll(). + */ + pool->cached_need_wakeup = XDP_WAKEUP_TX; + + dev_hold(netdev); + + if (force_copy) + /* For copy-mode, we are done. 
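+		 * Everything below only matters for zero-copy: it needs
+		 * ndo_bpf and ndo_xsk_wakeup support in the driver, and if
+		 * that setup fails without XDP_ZEROCOPY having been forced,
+		 * err is reset to 0 further down and the socket silently
+		 * falls back to copy mode.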
*/ + return 0; + + if (!netdev->netdev_ops->ndo_bpf || + !netdev->netdev_ops->ndo_xsk_wakeup) { + err = -EOPNOTSUPP; + goto err_unreg_pool; + } + + bpf.command = XDP_SETUP_XSK_POOL; + bpf.xsk.pool = pool; + bpf.xsk.queue_id = queue_id; + + err = netdev->netdev_ops->ndo_bpf(netdev, &bpf); + if (err) + goto err_unreg_pool; + + if (!pool->dma_pages) { + WARN(1, "Driver did not DMA map zero-copy buffers"); + err = -EINVAL; + goto err_unreg_xsk; + } + pool->umem->zc = true; + return 0; + +err_unreg_xsk: + xp_disable_drv_zc(pool); +err_unreg_pool: + if (!force_zc) + err = 0; /* fallback to copy mode */ + if (err) { + xsk_clear_pool_at_qid(netdev, queue_id); + dev_put(netdev); + } + return err; +} + +int xp_assign_dev_shared(struct xsk_buff_pool *pool, struct xdp_sock *umem_xs, + struct net_device *dev, u16 queue_id) +{ + u16 flags; + struct xdp_umem *umem = umem_xs->umem; + + /* One fill and completion ring required for each queue id. */ + if (!pool->fq || !pool->cq) + return -EINVAL; + + flags = umem->zc ? XDP_ZEROCOPY : XDP_COPY; + if (umem_xs->pool->uses_need_wakeup) + flags |= XDP_USE_NEED_WAKEUP; + + return xp_assign_dev(pool, dev, queue_id, flags); +} + +void xp_clear_dev(struct xsk_buff_pool *pool) +{ + if (!pool->netdev) + return; + + xp_disable_drv_zc(pool); + xsk_clear_pool_at_qid(pool->netdev, pool->queue_id); + dev_put(pool->netdev); + pool->netdev = NULL; +} + +static void xp_release_deferred(struct work_struct *work) +{ + struct xsk_buff_pool *pool = container_of(work, struct xsk_buff_pool, + work); + + rtnl_lock(); + xp_clear_dev(pool); + rtnl_unlock(); + + if (pool->fq) { + xskq_destroy(pool->fq); + pool->fq = NULL; + } + + if (pool->cq) { + xskq_destroy(pool->cq); + pool->cq = NULL; + } + + xdp_put_umem(pool->umem, false); + xp_destroy(pool); +} + +void xp_get_pool(struct xsk_buff_pool *pool) +{ + refcount_inc(&pool->users); +} + +bool xp_put_pool(struct xsk_buff_pool *pool) +{ + if (!pool) + return false; + + if (refcount_dec_and_test(&pool->users)) { + INIT_WORK(&pool->work, xp_release_deferred); + schedule_work(&pool->work); + return true; + } + + return false; +} + +static struct xsk_dma_map *xp_find_dma_map(struct xsk_buff_pool *pool) +{ + struct xsk_dma_map *dma_map; + + list_for_each_entry(dma_map, &pool->umem->xsk_dma_list, list) { + if (dma_map->netdev == pool->netdev) + return dma_map; + } + + return NULL; +} + +static struct xsk_dma_map *xp_create_dma_map(struct device *dev, struct net_device *netdev, + u32 nr_pages, struct xdp_umem *umem) +{ + struct xsk_dma_map *dma_map; + + dma_map = kzalloc(sizeof(*dma_map), GFP_KERNEL); + if (!dma_map) + return NULL; + + dma_map->dma_pages = kvcalloc(nr_pages, sizeof(*dma_map->dma_pages), GFP_KERNEL); + if (!dma_map->dma_pages) { + kfree(dma_map); + return NULL; + } + + dma_map->netdev = netdev; + dma_map->dev = dev; + dma_map->dma_need_sync = false; + dma_map->dma_pages_cnt = nr_pages; + refcount_set(&dma_map->users, 1); + list_add(&dma_map->list, &umem->xsk_dma_list); + return dma_map; +} + +static void xp_destroy_dma_map(struct xsk_dma_map *dma_map) +{ + list_del(&dma_map->list); + kvfree(dma_map->dma_pages); + kfree(dma_map); +} + +static void __xp_dma_unmap(struct xsk_dma_map *dma_map, unsigned long attrs) +{ + dma_addr_t *dma; + u32 i; + + for (i = 0; i < dma_map->dma_pages_cnt; i++) { + dma = &dma_map->dma_pages[i]; + if (*dma) { + *dma &= ~XSK_NEXT_PG_CONTIG_MASK; + dma_unmap_page_attrs(dma_map->dev, *dma, PAGE_SIZE, + DMA_BIDIRECTIONAL, attrs); + *dma = 0; + } + } + + xp_destroy_dma_map(dma_map); +} + +void 
xp_dma_unmap(struct xsk_buff_pool *pool, unsigned long attrs) +{ + struct xsk_dma_map *dma_map; + + if (pool->dma_pages_cnt == 0) + return; + + dma_map = xp_find_dma_map(pool); + if (!dma_map) { + WARN(1, "Could not find dma_map for device"); + return; + } + + if (!refcount_dec_and_test(&dma_map->users)) + return; + + __xp_dma_unmap(dma_map, attrs); + kvfree(pool->dma_pages); + pool->dma_pages_cnt = 0; + pool->dev = NULL; +} +EXPORT_SYMBOL(xp_dma_unmap); + +static void xp_check_dma_contiguity(struct xsk_dma_map *dma_map) +{ + u32 i; + + for (i = 0; i < dma_map->dma_pages_cnt - 1; i++) { + if (dma_map->dma_pages[i] + PAGE_SIZE == dma_map->dma_pages[i + 1]) + dma_map->dma_pages[i] |= XSK_NEXT_PG_CONTIG_MASK; + else + dma_map->dma_pages[i] &= ~XSK_NEXT_PG_CONTIG_MASK; + } +} + +static int xp_init_dma_info(struct xsk_buff_pool *pool, struct xsk_dma_map *dma_map) +{ + if (!pool->unaligned) { + u32 i; + + for (i = 0; i < pool->heads_cnt; i++) { + struct xdp_buff_xsk *xskb = &pool->heads[i]; + + xp_init_xskb_dma(xskb, pool, dma_map->dma_pages, xskb->orig_addr); + } + } + + pool->dma_pages = kvcalloc(dma_map->dma_pages_cnt, sizeof(*pool->dma_pages), GFP_KERNEL); + if (!pool->dma_pages) + return -ENOMEM; + + pool->dev = dma_map->dev; + pool->dma_pages_cnt = dma_map->dma_pages_cnt; + pool->dma_need_sync = dma_map->dma_need_sync; + memcpy(pool->dma_pages, dma_map->dma_pages, + pool->dma_pages_cnt * sizeof(*pool->dma_pages)); + + return 0; +} + +int xp_dma_map(struct xsk_buff_pool *pool, struct device *dev, + unsigned long attrs, struct page **pages, u32 nr_pages) +{ + struct xsk_dma_map *dma_map; + dma_addr_t dma; + int err; + u32 i; + + dma_map = xp_find_dma_map(pool); + if (dma_map) { + err = xp_init_dma_info(pool, dma_map); + if (err) + return err; + + refcount_inc(&dma_map->users); + return 0; + } + + dma_map = xp_create_dma_map(dev, pool->netdev, nr_pages, pool->umem); + if (!dma_map) + return -ENOMEM; + + for (i = 0; i < dma_map->dma_pages_cnt; i++) { + dma = dma_map_page_attrs(dev, pages[i], 0, PAGE_SIZE, + DMA_BIDIRECTIONAL, attrs); + if (dma_mapping_error(dev, dma)) { + __xp_dma_unmap(dma_map, attrs); + return -ENOMEM; + } + if (dma_need_sync(dev, dma)) + dma_map->dma_need_sync = true; + dma_map->dma_pages[i] = dma; + } + + if (pool->unaligned) + xp_check_dma_contiguity(dma_map); + + err = xp_init_dma_info(pool, dma_map); + if (err) { + __xp_dma_unmap(dma_map, attrs); + return err; + } + + return 0; +} +EXPORT_SYMBOL(xp_dma_map); + +static bool xp_addr_crosses_non_contig_pg(struct xsk_buff_pool *pool, + u64 addr) +{ + return xp_desc_crosses_non_contig_pg(pool, addr, pool->chunk_size); +} + +static bool xp_check_unaligned(struct xsk_buff_pool *pool, u64 *addr) +{ + *addr = xp_unaligned_extract_addr(*addr); + if (*addr >= pool->addrs_cnt || + *addr + pool->chunk_size > pool->addrs_cnt || + xp_addr_crosses_non_contig_pg(pool, *addr)) + return false; + return true; +} + +static bool xp_check_aligned(struct xsk_buff_pool *pool, u64 *addr) +{ + *addr = xp_aligned_extract_addr(pool, *addr); + return *addr < pool->addrs_cnt; +} + +static struct xdp_buff_xsk *__xp_alloc(struct xsk_buff_pool *pool) +{ + struct xdp_buff_xsk *xskb; + u64 addr; + bool ok; + + if (pool->free_heads_cnt == 0) + return NULL; + + for (;;) { + if (!xskq_cons_peek_addr_unchecked(pool->fq, &addr)) { + pool->fq->queue_empty_descs++; + return NULL; + } + + ok = pool->unaligned ? 
xp_check_unaligned(pool, &addr) : + xp_check_aligned(pool, &addr); + if (!ok) { + pool->fq->invalid_descs++; + xskq_cons_release(pool->fq); + continue; + } + break; + } + + if (pool->unaligned) { + xskb = pool->free_heads[--pool->free_heads_cnt]; + xp_init_xskb_addr(xskb, pool, addr); + if (pool->dma_pages_cnt) + xp_init_xskb_dma(xskb, pool, pool->dma_pages, addr); + } else { + xskb = &pool->heads[xp_aligned_extract_idx(pool, addr)]; + } + + xskq_cons_release(pool->fq); + return xskb; +} + +struct xdp_buff *xp_alloc(struct xsk_buff_pool *pool) +{ + struct xdp_buff_xsk *xskb; + + if (!pool->free_list_cnt) { + xskb = __xp_alloc(pool); + if (!xskb) + return NULL; + } else { + pool->free_list_cnt--; + xskb = list_first_entry(&pool->free_list, struct xdp_buff_xsk, + free_list_node); + list_del_init(&xskb->free_list_node); + } + + xskb->xdp.data = xskb->xdp.data_hard_start + XDP_PACKET_HEADROOM; + xskb->xdp.data_meta = xskb->xdp.data; + + if (pool->dma_need_sync) { + dma_sync_single_range_for_device(pool->dev, xskb->dma, 0, + pool->frame_len, + DMA_BIDIRECTIONAL); + } + return &xskb->xdp; +} +EXPORT_SYMBOL(xp_alloc); + +static u32 xp_alloc_new_from_fq(struct xsk_buff_pool *pool, struct xdp_buff **xdp, u32 max) +{ + u32 i, cached_cons, nb_entries; + + if (max > pool->free_heads_cnt) + max = pool->free_heads_cnt; + max = xskq_cons_nb_entries(pool->fq, max); + + cached_cons = pool->fq->cached_cons; + nb_entries = max; + i = max; + while (i--) { + struct xdp_buff_xsk *xskb; + u64 addr; + bool ok; + + __xskq_cons_read_addr_unchecked(pool->fq, cached_cons++, &addr); + + ok = pool->unaligned ? xp_check_unaligned(pool, &addr) : + xp_check_aligned(pool, &addr); + if (unlikely(!ok)) { + pool->fq->invalid_descs++; + nb_entries--; + continue; + } + + if (pool->unaligned) { + xskb = pool->free_heads[--pool->free_heads_cnt]; + xp_init_xskb_addr(xskb, pool, addr); + if (pool->dma_pages_cnt) + xp_init_xskb_dma(xskb, pool, pool->dma_pages, addr); + } else { + xskb = &pool->heads[xp_aligned_extract_idx(pool, addr)]; + } + + *xdp = &xskb->xdp; + xdp++; + } + + xskq_cons_release_n(pool->fq, max); + return nb_entries; +} + +static u32 xp_alloc_reused(struct xsk_buff_pool *pool, struct xdp_buff **xdp, u32 nb_entries) +{ + struct xdp_buff_xsk *xskb; + u32 i; + + nb_entries = min_t(u32, nb_entries, pool->free_list_cnt); + + i = nb_entries; + while (i--) { + xskb = list_first_entry(&pool->free_list, struct xdp_buff_xsk, free_list_node); + list_del_init(&xskb->free_list_node); + + *xdp = &xskb->xdp; + xdp++; + } + pool->free_list_cnt -= nb_entries; + + return nb_entries; +} + +u32 xp_alloc_batch(struct xsk_buff_pool *pool, struct xdp_buff **xdp, u32 max) +{ + u32 nb_entries1 = 0, nb_entries2; + + if (unlikely(pool->dma_need_sync)) { + struct xdp_buff *buff; + + /* Slow path */ + buff = xp_alloc(pool); + if (buff) + *xdp = buff; + return !!buff; + } + + if (unlikely(pool->free_list_cnt)) { + nb_entries1 = xp_alloc_reused(pool, xdp, max); + if (nb_entries1 == max) + return nb_entries1; + + max -= nb_entries1; + xdp += nb_entries1; + } + + nb_entries2 = xp_alloc_new_from_fq(pool, xdp, max); + if (!nb_entries2) + pool->fq->queue_empty_descs++; + + return nb_entries1 + nb_entries2; +} +EXPORT_SYMBOL(xp_alloc_batch); + +bool xp_can_alloc(struct xsk_buff_pool *pool, u32 count) +{ + if (pool->free_list_cnt >= count) + return true; + return xskq_cons_has_entries(pool->fq, count - pool->free_list_cnt); +} +EXPORT_SYMBOL(xp_can_alloc); + +void xp_free(struct xdp_buff_xsk *xskb) +{ + if (!list_empty(&xskb->free_list_node)) + return; 
+ + xskb->pool->free_list_cnt++; + list_add(&xskb->free_list_node, &xskb->pool->free_list); +} +EXPORT_SYMBOL(xp_free); + +void *xp_raw_get_data(struct xsk_buff_pool *pool, u64 addr) +{ + addr = pool->unaligned ? xp_unaligned_add_offset_to_addr(addr) : addr; + return pool->addrs + addr; +} +EXPORT_SYMBOL(xp_raw_get_data); + +dma_addr_t xp_raw_get_dma(struct xsk_buff_pool *pool, u64 addr) +{ + addr = pool->unaligned ? xp_unaligned_add_offset_to_addr(addr) : addr; + return (pool->dma_pages[addr >> PAGE_SHIFT] & + ~XSK_NEXT_PG_CONTIG_MASK) + + (addr & ~PAGE_MASK); +} +EXPORT_SYMBOL(xp_raw_get_dma); + +void xp_dma_sync_for_cpu_slow(struct xdp_buff_xsk *xskb) +{ + dma_sync_single_range_for_cpu(xskb->pool->dev, xskb->dma, 0, + xskb->pool->frame_len, DMA_BIDIRECTIONAL); +} +EXPORT_SYMBOL(xp_dma_sync_for_cpu_slow); + +void xp_dma_sync_for_device_slow(struct xsk_buff_pool *pool, dma_addr_t dma, + size_t size) +{ + dma_sync_single_range_for_device(pool->dev, dma, 0, + size, DMA_BIDIRECTIONAL); +} +EXPORT_SYMBOL(xp_dma_sync_for_device_slow); diff --git a/net/xdp/xsk_diag.c b/net/xdp/xsk_diag.c new file mode 100644 index 000000000..22b36c814 --- /dev/null +++ b/net/xdp/xsk_diag.c @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XDP sockets monitoring support + * + * Copyright(c) 2019 Intel Corporation. + * + * Author: Björn Töpel + */ + +#include +#include +#include +#include + +#include "xsk_queue.h" +#include "xsk.h" + +static int xsk_diag_put_info(const struct xdp_sock *xs, struct sk_buff *nlskb) +{ + struct xdp_diag_info di = {}; + + di.ifindex = xs->dev ? xs->dev->ifindex : 0; + di.queue_id = xs->queue_id; + return nla_put(nlskb, XDP_DIAG_INFO, sizeof(di), &di); +} + +static int xsk_diag_put_ring(const struct xsk_queue *queue, int nl_type, + struct sk_buff *nlskb) +{ + struct xdp_diag_ring dr = {}; + + dr.entries = queue->nentries; + return nla_put(nlskb, nl_type, sizeof(dr), &dr); +} + +static int xsk_diag_put_rings_cfg(const struct xdp_sock *xs, + struct sk_buff *nlskb) +{ + int err = 0; + + if (xs->rx) + err = xsk_diag_put_ring(xs->rx, XDP_DIAG_RX_RING, nlskb); + if (!err && xs->tx) + err = xsk_diag_put_ring(xs->tx, XDP_DIAG_TX_RING, nlskb); + return err; +} + +static int xsk_diag_put_umem(const struct xdp_sock *xs, struct sk_buff *nlskb) +{ + struct xsk_buff_pool *pool = xs->pool; + struct xdp_umem *umem = xs->umem; + struct xdp_diag_umem du = {}; + int err; + + if (!umem) + return 0; + + du.id = umem->id; + du.size = umem->size; + du.num_pages = umem->npgs; + du.chunk_size = umem->chunk_size; + du.headroom = umem->headroom; + du.ifindex = (pool && pool->netdev) ? pool->netdev->ifindex : 0; + du.queue_id = pool ? pool->queue_id : 0; + du.flags = 0; + if (umem->zc) + du.flags |= XDP_DU_F_ZEROCOPY; + du.refs = refcount_read(&umem->users); + + err = nla_put(nlskb, XDP_DIAG_UMEM, sizeof(du), &du); + if (!err && pool && pool->fq) + err = xsk_diag_put_ring(pool->fq, + XDP_DIAG_UMEM_FILL_RING, nlskb); + if (!err && pool && pool->cq) + err = xsk_diag_put_ring(pool->cq, + XDP_DIAG_UMEM_COMPLETION_RING, nlskb); + return err; +} + +static int xsk_diag_put_stats(const struct xdp_sock *xs, struct sk_buff *nlskb) +{ + struct xdp_diag_stats du = {}; + + du.n_rx_dropped = xs->rx_dropped; + du.n_rx_invalid = xskq_nb_invalid_descs(xs->rx); + du.n_rx_full = xs->rx_queue_full; + du.n_fill_ring_empty = xs->pool ? 
xskq_nb_queue_empty_descs(xs->pool->fq) : 0; + du.n_tx_invalid = xskq_nb_invalid_descs(xs->tx); + du.n_tx_ring_empty = xskq_nb_queue_empty_descs(xs->tx); + return nla_put(nlskb, XDP_DIAG_STATS, sizeof(du), &du); +} + +static int xsk_diag_fill(struct sock *sk, struct sk_buff *nlskb, + struct xdp_diag_req *req, + struct user_namespace *user_ns, + u32 portid, u32 seq, u32 flags, int sk_ino) +{ + struct xdp_sock *xs = xdp_sk(sk); + struct xdp_diag_msg *msg; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(nlskb, portid, seq, SOCK_DIAG_BY_FAMILY, sizeof(*msg), + flags); + if (!nlh) + return -EMSGSIZE; + + msg = nlmsg_data(nlh); + memset(msg, 0, sizeof(*msg)); + msg->xdiag_family = AF_XDP; + msg->xdiag_type = sk->sk_type; + msg->xdiag_ino = sk_ino; + sock_diag_save_cookie(sk, msg->xdiag_cookie); + + mutex_lock(&xs->mutex); + if (READ_ONCE(xs->state) == XSK_UNBOUND) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_INFO) && xsk_diag_put_info(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_INFO) && + nla_put_u32(nlskb, XDP_DIAG_UID, + from_kuid_munged(user_ns, sock_i_uid(sk)))) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_RING_CFG) && + xsk_diag_put_rings_cfg(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_UMEM) && + xsk_diag_put_umem(xs, nlskb)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_MEMINFO) && + sock_diag_put_meminfo(sk, nlskb, XDP_DIAG_MEMINFO)) + goto out_nlmsg_trim; + + if ((req->xdiag_show & XDP_SHOW_STATS) && + xsk_diag_put_stats(xs, nlskb)) + goto out_nlmsg_trim; + + mutex_unlock(&xs->mutex); + nlmsg_end(nlskb, nlh); + return 0; + +out_nlmsg_trim: + mutex_unlock(&xs->mutex); + nlmsg_cancel(nlskb, nlh); + return -EMSGSIZE; +} + +static int xsk_diag_dump(struct sk_buff *nlskb, struct netlink_callback *cb) +{ + struct xdp_diag_req *req = nlmsg_data(cb->nlh); + struct net *net = sock_net(nlskb->sk); + int num = 0, s_num = cb->args[0]; + struct sock *sk; + + mutex_lock(&net->xdp.lock); + + sk_for_each(sk, &net->xdp.list) { + if (!net_eq(sock_net(sk), net)) + continue; + if (num++ < s_num) + continue; + + if (xsk_diag_fill(sk, nlskb, req, + sk_user_ns(NETLINK_CB(cb->skb).sk), + NETLINK_CB(cb->skb).portid, + cb->nlh->nlmsg_seq, NLM_F_MULTI, + sock_i_ino(sk)) < 0) { + num--; + break; + } + } + + mutex_unlock(&net->xdp.lock); + cb->args[0] = num; + return nlskb->len; +} + +static int xsk_diag_handler_dump(struct sk_buff *nlskb, struct nlmsghdr *hdr) +{ + struct netlink_dump_control c = { .dump = xsk_diag_dump }; + int hdrlen = sizeof(struct xdp_diag_req); + struct net *net = sock_net(nlskb->sk); + + if (nlmsg_len(hdr) < hdrlen) + return -EINVAL; + + if (!(hdr->nlmsg_flags & NLM_F_DUMP)) + return -EOPNOTSUPP; + + return netlink_dump_start(net->diag_nlsk, nlskb, hdr, &c); +} + +static const struct sock_diag_handler xsk_diag_handler = { + .family = AF_XDP, + .dump = xsk_diag_handler_dump, +}; + +static int __init xsk_diag_init(void) +{ + return sock_diag_register(&xsk_diag_handler); +} + +static void __exit xsk_diag_exit(void) +{ + sock_diag_unregister(&xsk_diag_handler); +} + +module_init(xsk_diag_init); +module_exit(xsk_diag_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO_TYPE(PF_NETLINK, NETLINK_SOCK_DIAG, AF_XDP); diff --git a/net/xdp/xsk_queue.c b/net/xdp/xsk_queue.c new file mode 100644 index 000000000..6cf9586e5 --- /dev/null +++ b/net/xdp/xsk_queue.c @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XDP user-space ring structure + * Copyright(c) 2018 Intel Corporation. 
+ */ + +#include +#include +#include +#include + +#include "xsk_queue.h" + +static size_t xskq_get_ring_size(struct xsk_queue *q, bool umem_queue) +{ + struct xdp_umem_ring *umem_ring; + struct xdp_rxtx_ring *rxtx_ring; + + if (umem_queue) + return struct_size(umem_ring, desc, q->nentries); + return struct_size(rxtx_ring, desc, q->nentries); +} + +struct xsk_queue *xskq_create(u32 nentries, bool umem_queue) +{ + struct xsk_queue *q; + gfp_t gfp_flags; + size_t size; + + q = kzalloc(sizeof(*q), GFP_KERNEL); + if (!q) + return NULL; + + q->nentries = nentries; + q->ring_mask = nentries - 1; + + gfp_flags = GFP_KERNEL | __GFP_ZERO | __GFP_NOWARN | + __GFP_COMP | __GFP_NORETRY; + size = xskq_get_ring_size(q, umem_queue); + + q->ring = (struct xdp_ring *)__get_free_pages(gfp_flags, + get_order(size)); + if (!q->ring) { + kfree(q); + return NULL; + } + + return q; +} + +void xskq_destroy(struct xsk_queue *q) +{ + if (!q) + return; + + page_frag_free(q->ring); + kfree(q); +} diff --git a/net/xdp/xsk_queue.h b/net/xdp/xsk_queue.h new file mode 100644 index 000000000..bdeba20aa --- /dev/null +++ b/net/xdp/xsk_queue.h @@ -0,0 +1,432 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +/* XDP user-space ring structure + * Copyright(c) 2018 Intel Corporation. + */ + +#ifndef _LINUX_XSK_QUEUE_H +#define _LINUX_XSK_QUEUE_H + +#include +#include +#include +#include + +#include "xsk.h" + +struct xdp_ring { + u32 producer ____cacheline_aligned_in_smp; + /* Hinder the adjacent cache prefetcher to prefetch the consumer + * pointer if the producer pointer is touched and vice versa. + */ + u32 pad1 ____cacheline_aligned_in_smp; + u32 consumer ____cacheline_aligned_in_smp; + u32 pad2 ____cacheline_aligned_in_smp; + u32 flags; + u32 pad3 ____cacheline_aligned_in_smp; +}; + +/* Used for the RX and TX queues for packets */ +struct xdp_rxtx_ring { + struct xdp_ring ptrs; + struct xdp_desc desc[] ____cacheline_aligned_in_smp; +}; + +/* Used for the fill and completion queues for buffers */ +struct xdp_umem_ring { + struct xdp_ring ptrs; + u64 desc[] ____cacheline_aligned_in_smp; +}; + +struct xsk_queue { + u32 ring_mask; + u32 nentries; + u32 cached_prod; + u32 cached_cons; + struct xdp_ring *ring; + u64 invalid_descs; + u64 queue_empty_descs; +}; + +/* The structure of the shared state of the rings are a simple + * circular buffer, as outlined in + * Documentation/core-api/circular-buffers.rst. For the Rx and + * completion ring, the kernel is the producer and user space is the + * consumer. For the Tx and fill rings, the kernel is the consumer and + * user space is the producer. + * + * producer consumer + * + * if (LOAD ->consumer) { (A) LOAD.acq ->producer (C) + * STORE $data LOAD $data + * STORE.rel ->producer (B) STORE.rel ->consumer (D) + * } + * + * (A) pairs with (D), and (B) pairs with (C). + * + * Starting with (B), it protects the data from being written after + * the producer pointer. If this barrier was missing, the consumer + * could observe the producer pointer being set and thus load the data + * before the producer has written the new data. The consumer would in + * this case load the old data. + * + * (C) protects the consumer from speculatively loading the data before + * the producer pointer actually has been read. If we do not have this + * barrier, some architectures could load old data as speculative loads + * are not discarded as the CPU does not know there is a dependency + * between ->producer and data. 
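/*
 * Illustrative user-space sketch (not part of this patch): a fill-ring
 * producer applying the (A)/(B) ordering described above with C11
 * atomics, the same pairing user-space AF_XDP code relies on. The
 * structure below only mirrors xdp_ring conceptually; the names are
 * hypothetical, not kernel or libxdp API.
 */
#include <stdatomic.h>
#include <stdbool.h>
#include <stdint.h>

struct example_ring {
	_Atomic uint32_t producer;
	_Atomic uint32_t consumer;
	uint32_t mask;		/* nentries - 1, nentries a power of two */
	uint64_t *desc;
};

static bool example_produce_addr(struct example_ring *r, uint64_t addr)
{
	uint32_t prod = atomic_load_explicit(&r->producer, memory_order_relaxed);
	uint32_t cons = atomic_load_explicit(&r->consumer, memory_order_relaxed); /* A */

	if (prod - cons > r->mask)	/* ring full, do not store $data */
		return false;

	r->desc[prod & r->mask] = addr;	/* STORE $data */
	atomic_store_explicit(&r->producer, prod + 1, memory_order_release); /* B */
	return true;
}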
+ * + * (A) is a control dependency that separates the load of ->consumer + * from the stores of $data. In case ->consumer indicates there is no + * room in the buffer to store $data we do not. The dependency will + * order both of the stores after the loads. So no barrier is needed. + * + * (D) protects the load of the data to be observed to happen after the + * store of the consumer pointer. If we did not have this memory + * barrier, the producer could observe the consumer pointer being set + * and overwrite the data with a new value before the consumer got the + * chance to read the old value. The consumer would thus miss reading + * the old entry and very likely read the new entry twice, once right + * now and again after circling through the ring. + */ + +/* The operations on the rings are the following: + * + * producer consumer + * + * RESERVE entries PEEK in the ring for entries + * WRITE data into the ring READ data from the ring + * SUBMIT entries RELEASE entries + * + * The producer reserves one or more entries in the ring. It can then + * fill in these entries and finally submit them so that they can be + * seen and read by the consumer. + * + * The consumer peeks into the ring to see if the producer has written + * any new entries. If so, the consumer can then read these entries + * and when it is done reading them release them back to the producer + * so that the producer can use these slots to fill in new entries. + * + * The function names below reflect these operations. + */ + +/* Functions that read and validate content from consumer rings. */ + +static inline void __xskq_cons_read_addr_unchecked(struct xsk_queue *q, u32 cached_cons, u64 *addr) +{ + struct xdp_umem_ring *ring = (struct xdp_umem_ring *)q->ring; + u32 idx = cached_cons & q->ring_mask; + + *addr = ring->desc[idx]; +} + +static inline bool xskq_cons_read_addr_unchecked(struct xsk_queue *q, u64 *addr) +{ + if (q->cached_cons != q->cached_prod) { + __xskq_cons_read_addr_unchecked(q, q->cached_cons, addr); + return true; + } + + return false; +} + +static inline bool xp_aligned_validate_desc(struct xsk_buff_pool *pool, + struct xdp_desc *desc) +{ + u64 chunk, chunk_end; + + chunk = xp_aligned_extract_addr(pool, desc->addr); + if (likely(desc->len)) { + chunk_end = xp_aligned_extract_addr(pool, desc->addr + desc->len - 1); + if (chunk != chunk_end) + return false; + } + + if (chunk >= pool->addrs_cnt) + return false; + + if (desc->options) + return false; + return true; +} + +static inline bool xp_unaligned_validate_desc(struct xsk_buff_pool *pool, + struct xdp_desc *desc) +{ + u64 addr, base_addr; + + base_addr = xp_unaligned_extract_addr(desc->addr); + addr = xp_unaligned_add_offset_to_addr(desc->addr); + + if (desc->len > pool->chunk_size) + return false; + + if (base_addr >= pool->addrs_cnt || addr >= pool->addrs_cnt || + addr + desc->len > pool->addrs_cnt || + xp_desc_crosses_non_contig_pg(pool, addr, desc->len)) + return false; + + if (desc->options) + return false; + return true; +} + +static inline bool xp_validate_desc(struct xsk_buff_pool *pool, + struct xdp_desc *desc) +{ + return pool->unaligned ? 
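/*
 * Illustrative user-space sketch (not part of this patch): filling a Tx
 * descriptor that will pass xp_aligned_validate_desc() above -- the
 * buffer must stay inside a single chunk and the options field must be
 * zero. CHUNK_SIZE and the frame-index layout are assumptions made for
 * the example only.
 */
#include <linux/if_xdp.h>

#define EXAMPLE_CHUNK_SIZE 2048

static void example_fill_tx_desc(struct xdp_desc *d, unsigned int frame_idx,
				 unsigned int pkt_len)
{
	d->addr = (unsigned long long)frame_idx * EXAMPLE_CHUNK_SIZE; /* chunk-aligned */
	d->len = pkt_len;	/* must not cross into the next chunk */
	d->options = 0;		/* non-zero options are rejected as invalid */
}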
xp_unaligned_validate_desc(pool, desc) : + xp_aligned_validate_desc(pool, desc); +} + +static inline bool xskq_cons_is_valid_desc(struct xsk_queue *q, + struct xdp_desc *d, + struct xsk_buff_pool *pool) +{ + if (!xp_validate_desc(pool, d)) { + q->invalid_descs++; + return false; + } + return true; +} + +static inline bool xskq_cons_read_desc(struct xsk_queue *q, + struct xdp_desc *desc, + struct xsk_buff_pool *pool) +{ + while (q->cached_cons != q->cached_prod) { + struct xdp_rxtx_ring *ring = (struct xdp_rxtx_ring *)q->ring; + u32 idx = q->cached_cons & q->ring_mask; + + *desc = ring->desc[idx]; + if (xskq_cons_is_valid_desc(q, desc, pool)) + return true; + + q->cached_cons++; + } + + return false; +} + +static inline void xskq_cons_release_n(struct xsk_queue *q, u32 cnt) +{ + q->cached_cons += cnt; +} + +static inline u32 xskq_cons_read_desc_batch(struct xsk_queue *q, struct xsk_buff_pool *pool, + u32 max) +{ + u32 cached_cons = q->cached_cons, nb_entries = 0; + struct xdp_desc *descs = pool->tx_descs; + + while (cached_cons != q->cached_prod && nb_entries < max) { + struct xdp_rxtx_ring *ring = (struct xdp_rxtx_ring *)q->ring; + u32 idx = cached_cons & q->ring_mask; + + descs[nb_entries] = ring->desc[idx]; + if (unlikely(!xskq_cons_is_valid_desc(q, &descs[nb_entries], pool))) { + /* Skip the entry */ + cached_cons++; + continue; + } + + nb_entries++; + cached_cons++; + } + + /* Release valid plus any invalid entries */ + xskq_cons_release_n(q, cached_cons - q->cached_cons); + return nb_entries; +} + +/* Functions for consumers */ + +static inline void __xskq_cons_release(struct xsk_queue *q) +{ + smp_store_release(&q->ring->consumer, q->cached_cons); /* D, matchees A */ +} + +static inline void __xskq_cons_peek(struct xsk_queue *q) +{ + /* Refresh the local pointer */ + q->cached_prod = smp_load_acquire(&q->ring->producer); /* C, matches B */ +} + +static inline void xskq_cons_get_entries(struct xsk_queue *q) +{ + __xskq_cons_release(q); + __xskq_cons_peek(q); +} + +static inline u32 xskq_cons_nb_entries(struct xsk_queue *q, u32 max) +{ + u32 entries = q->cached_prod - q->cached_cons; + + if (entries >= max) + return max; + + __xskq_cons_peek(q); + entries = q->cached_prod - q->cached_cons; + + return entries >= max ? max : entries; +} + +static inline bool xskq_cons_has_entries(struct xsk_queue *q, u32 cnt) +{ + return xskq_cons_nb_entries(q, cnt) >= cnt; +} + +static inline bool xskq_cons_peek_addr_unchecked(struct xsk_queue *q, u64 *addr) +{ + if (q->cached_prod == q->cached_cons) + xskq_cons_get_entries(q); + return xskq_cons_read_addr_unchecked(q, addr); +} + +static inline bool xskq_cons_peek_desc(struct xsk_queue *q, + struct xdp_desc *desc, + struct xsk_buff_pool *pool) +{ + if (q->cached_prod == q->cached_cons) + xskq_cons_get_entries(q); + return xskq_cons_read_desc(q, desc, pool); +} + +/* To improve performance in the xskq_cons_release functions, only update local state here. + * Reflect this to global state when we get new entries from the ring in + * xskq_cons_get_entries() and whenever Rx or Tx processing are completed in the NAPI loop. 
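/*
 * Illustrative sketch (not part of this patch): a Tx drain loop built
 * on the consumer helpers in this header. Entries are released locally
 * as they are consumed and the consumer pointer is published once at
 * the end, matching the comment above. example_xmit_one() is a
 * hypothetical placeholder for driver transmit work, not a kernel
 * function; the surrounding header's declarations are assumed.
 */
static void example_xmit_one(struct xsk_buff_pool *pool, struct xdp_desc *desc)
{
	/* hypothetical: hand the frame described by desc to the NIC */
}

static void example_tx_drain(struct xsk_queue *q, struct xsk_buff_pool *pool)
{
	struct xdp_desc desc;

	while (xskq_cons_peek_desc(q, &desc, pool)) {
		example_xmit_one(pool, &desc);
		xskq_cons_release(q);	/* local: cached_cons++ only */
	}

	__xskq_cons_release(q);		/* D: smp_store_release() of ->consumer */
}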
+ */ +static inline void xskq_cons_release(struct xsk_queue *q) +{ + q->cached_cons++; +} + +static inline u32 xskq_cons_present_entries(struct xsk_queue *q) +{ + /* No barriers needed since data is not accessed */ + return READ_ONCE(q->ring->producer) - READ_ONCE(q->ring->consumer); +} + +/* Functions for producers */ + +static inline u32 xskq_prod_nb_free(struct xsk_queue *q, u32 max) +{ + u32 free_entries = q->nentries - (q->cached_prod - q->cached_cons); + + if (free_entries >= max) + return max; + + /* Refresh the local tail pointer */ + q->cached_cons = READ_ONCE(q->ring->consumer); + free_entries = q->nentries - (q->cached_prod - q->cached_cons); + + return free_entries >= max ? max : free_entries; +} + +static inline bool xskq_prod_is_full(struct xsk_queue *q) +{ + return xskq_prod_nb_free(q, 1) ? false : true; +} + +static inline void xskq_prod_cancel(struct xsk_queue *q) +{ + q->cached_prod--; +} + +static inline int xskq_prod_reserve(struct xsk_queue *q) +{ + if (xskq_prod_is_full(q)) + return -ENOSPC; + + /* A, matches D */ + q->cached_prod++; + return 0; +} + +static inline int xskq_prod_reserve_addr(struct xsk_queue *q, u64 addr) +{ + struct xdp_umem_ring *ring = (struct xdp_umem_ring *)q->ring; + + if (xskq_prod_is_full(q)) + return -ENOSPC; + + /* A, matches D */ + ring->desc[q->cached_prod++ & q->ring_mask] = addr; + return 0; +} + +static inline void xskq_prod_write_addr_batch(struct xsk_queue *q, struct xdp_desc *descs, + u32 nb_entries) +{ + struct xdp_umem_ring *ring = (struct xdp_umem_ring *)q->ring; + u32 i, cached_prod; + + /* A, matches D */ + cached_prod = q->cached_prod; + for (i = 0; i < nb_entries; i++) + ring->desc[cached_prod++ & q->ring_mask] = descs[i].addr; + q->cached_prod = cached_prod; +} + +static inline int xskq_prod_reserve_desc(struct xsk_queue *q, + u64 addr, u32 len) +{ + struct xdp_rxtx_ring *ring = (struct xdp_rxtx_ring *)q->ring; + u32 idx; + + if (xskq_prod_is_full(q)) + return -ENOBUFS; + + /* A, matches D */ + idx = q->cached_prod++ & q->ring_mask; + ring->desc[idx].addr = addr; + ring->desc[idx].len = len; + + return 0; +} + +static inline void __xskq_prod_submit(struct xsk_queue *q, u32 idx) +{ + smp_store_release(&q->ring->producer, idx); /* B, matches C */ +} + +static inline void xskq_prod_submit(struct xsk_queue *q) +{ + __xskq_prod_submit(q, q->cached_prod); +} + +static inline void xskq_prod_submit_addr(struct xsk_queue *q, u64 addr) +{ + struct xdp_umem_ring *ring = (struct xdp_umem_ring *)q->ring; + u32 idx = q->ring->producer; + + ring->desc[idx++ & q->ring_mask] = addr; + + __xskq_prod_submit(q, idx); +} + +static inline void xskq_prod_submit_n(struct xsk_queue *q, u32 nb_entries) +{ + __xskq_prod_submit(q, q->ring->producer + nb_entries); +} + +static inline bool xskq_prod_is_empty(struct xsk_queue *q) +{ + /* No barriers needed since data is not accessed */ + return READ_ONCE(q->ring->consumer) == READ_ONCE(q->ring->producer); +} + +/* For both producers and consumers */ + +static inline u64 xskq_nb_invalid_descs(struct xsk_queue *q) +{ + return q ? q->invalid_descs : 0; +} + +static inline u64 xskq_nb_queue_empty_descs(struct xsk_queue *q) +{ + return q ? 
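/*
 * Illustrative sketch (not part of this patch): the two-phase producer
 * pattern the helpers above implement on the Rx ring -- reserve a slot,
 * fill it, then make it visible with a single release store. The
 * surrounding header's declarations are assumed.
 */
static int example_rx_publish(struct xsk_queue *rx, u64 addr, u32 len)
{
	int err;

	err = xskq_prod_reserve_desc(rx, addr, len);	/* may fail with -ENOBUFS */
	if (err)
		return err;

	xskq_prod_submit(rx);		/* B: smp_store_release() of ->producer */
	return 0;
}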
q->queue_empty_descs : 0; +} + +struct xsk_queue *xskq_create(u32 nentries, bool umem_queue); +void xskq_destroy(struct xsk_queue *q_ops); + +#endif /* _LINUX_XSK_QUEUE_H */ diff --git a/net/xdp/xskmap.c b/net/xdp/xskmap.c new file mode 100644 index 000000000..acc8e52a4 --- /dev/null +++ b/net/xdp/xskmap.c @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0 +/* XSKMAP used for AF_XDP sockets + * Copyright(c) 2018 Intel Corporation. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include "xsk.h" + +static struct xsk_map_node *xsk_map_node_alloc(struct xsk_map *map, + struct xdp_sock __rcu **map_entry) +{ + struct xsk_map_node *node; + + node = bpf_map_kzalloc(&map->map, sizeof(*node), + GFP_ATOMIC | __GFP_NOWARN); + if (!node) + return ERR_PTR(-ENOMEM); + + bpf_map_inc(&map->map); + + node->map = map; + node->map_entry = map_entry; + return node; +} + +static void xsk_map_node_free(struct xsk_map_node *node) +{ + bpf_map_put(&node->map->map); + kfree(node); +} + +static void xsk_map_sock_add(struct xdp_sock *xs, struct xsk_map_node *node) +{ + spin_lock_bh(&xs->map_list_lock); + list_add_tail(&node->node, &xs->map_list); + spin_unlock_bh(&xs->map_list_lock); +} + +static void xsk_map_sock_delete(struct xdp_sock *xs, + struct xdp_sock __rcu **map_entry) +{ + struct xsk_map_node *n, *tmp; + + spin_lock_bh(&xs->map_list_lock); + list_for_each_entry_safe(n, tmp, &xs->map_list, node) { + if (map_entry == n->map_entry) { + list_del(&n->node); + xsk_map_node_free(n); + } + } + spin_unlock_bh(&xs->map_list_lock); +} + +static struct bpf_map *xsk_map_alloc(union bpf_attr *attr) +{ + struct xsk_map *m; + int numa_node; + u64 size; + + if (!capable(CAP_NET_ADMIN)) + return ERR_PTR(-EPERM); + + if (attr->max_entries == 0 || attr->key_size != 4 || + attr->value_size != 4 || + attr->map_flags & ~(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)) + return ERR_PTR(-EINVAL); + + numa_node = bpf_map_attr_numa_node(attr); + size = struct_size(m, xsk_map, attr->max_entries); + + m = bpf_map_area_alloc(size, numa_node); + if (!m) + return ERR_PTR(-ENOMEM); + + bpf_map_init_from_attr(&m->map, attr); + spin_lock_init(&m->lock); + + return &m->map; +} + +static void xsk_map_free(struct bpf_map *map) +{ + struct xsk_map *m = container_of(map, struct xsk_map, map); + + synchronize_net(); + bpf_map_area_free(m); +} + +static int xsk_map_get_next_key(struct bpf_map *map, void *key, void *next_key) +{ + struct xsk_map *m = container_of(map, struct xsk_map, map); + u32 index = key ? *(u32 *)key : U32_MAX; + u32 *next = next_key; + + if (index >= m->map.max_entries) { + *next = 0; + return 0; + } + + if (index == m->map.max_entries - 1) + return -ENOENT; + *next = index + 1; + return 0; +} + +static int xsk_map_gen_lookup(struct bpf_map *map, struct bpf_insn *insn_buf) +{ + const int ret = BPF_REG_0, mp = BPF_REG_1, index = BPF_REG_2; + struct bpf_insn *insn = insn_buf; + + *insn++ = BPF_LDX_MEM(BPF_W, ret, index, 0); + *insn++ = BPF_JMP_IMM(BPF_JGE, ret, map->max_entries, 5); + *insn++ = BPF_ALU64_IMM(BPF_LSH, ret, ilog2(sizeof(struct xsk_sock *))); + *insn++ = BPF_ALU64_IMM(BPF_ADD, mp, offsetof(struct xsk_map, xsk_map)); + *insn++ = BPF_ALU64_REG(BPF_ADD, ret, mp); + *insn++ = BPF_LDX_MEM(BPF_SIZEOF(struct xsk_sock *), ret, ret, 0); + *insn++ = BPF_JMP_IMM(BPF_JA, 0, 0, 1); + *insn++ = BPF_MOV64_IMM(ret, 0); + return insn - insn_buf; +} + +/* Elements are kept alive by RCU; either by rcu_read_lock() (from syscall) or + * by local_bh_disable() (from XDP calls inside NAPI). 
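/*
 * Illustrative sketch (not part of this patch): the usual XDP-side user
 * of an XSKMAP -- the NAPI-context caller mentioned above -- which
 * redirects each frame to the AF_XDP socket stored at its Rx queue
 * index. Map and section names follow common libbpf convention and are
 * only an example.
 */
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_XSKMAP);
	__uint(max_entries, 64);
	__type(key, __u32);
	__type(value, __u32);
} xsks_map SEC(".maps");

SEC("xdp")
int example_xsk_redirect(struct xdp_md *ctx)
{
	/* Fall back to XDP_PASS if no socket is bound at this queue id. */
	return bpf_redirect_map(&xsks_map, ctx->rx_queue_index, XDP_PASS);
}

char _license[] SEC("license") = "GPL";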
The + * rcu_read_lock_bh_held() below makes lockdep accept both. + */ +static void *__xsk_map_lookup_elem(struct bpf_map *map, u32 key) +{ + struct xsk_map *m = container_of(map, struct xsk_map, map); + + if (key >= map->max_entries) + return NULL; + + return rcu_dereference_check(m->xsk_map[key], rcu_read_lock_bh_held()); +} + +static void *xsk_map_lookup_elem(struct bpf_map *map, void *key) +{ + return __xsk_map_lookup_elem(map, *(u32 *)key); +} + +static void *xsk_map_lookup_elem_sys_only(struct bpf_map *map, void *key) +{ + return ERR_PTR(-EOPNOTSUPP); +} + +static int xsk_map_update_elem(struct bpf_map *map, void *key, void *value, + u64 map_flags) +{ + struct xsk_map *m = container_of(map, struct xsk_map, map); + struct xdp_sock __rcu **map_entry; + struct xdp_sock *xs, *old_xs; + u32 i = *(u32 *)key, fd = *(u32 *)value; + struct xsk_map_node *node; + struct socket *sock; + int err; + + if (unlikely(map_flags > BPF_EXIST)) + return -EINVAL; + if (unlikely(i >= m->map.max_entries)) + return -E2BIG; + + sock = sockfd_lookup(fd, &err); + if (!sock) + return err; + + if (sock->sk->sk_family != PF_XDP) { + sockfd_put(sock); + return -EOPNOTSUPP; + } + + xs = (struct xdp_sock *)sock->sk; + + map_entry = &m->xsk_map[i]; + node = xsk_map_node_alloc(m, map_entry); + if (IS_ERR(node)) { + sockfd_put(sock); + return PTR_ERR(node); + } + + spin_lock_bh(&m->lock); + old_xs = rcu_dereference_protected(*map_entry, lockdep_is_held(&m->lock)); + if (old_xs == xs) { + err = 0; + goto out; + } else if (old_xs && map_flags == BPF_NOEXIST) { + err = -EEXIST; + goto out; + } else if (!old_xs && map_flags == BPF_EXIST) { + err = -ENOENT; + goto out; + } + xsk_map_sock_add(xs, node); + rcu_assign_pointer(*map_entry, xs); + if (old_xs) + xsk_map_sock_delete(old_xs, map_entry); + spin_unlock_bh(&m->lock); + sockfd_put(sock); + return 0; + +out: + spin_unlock_bh(&m->lock); + sockfd_put(sock); + xsk_map_node_free(node); + return err; +} + +static int xsk_map_delete_elem(struct bpf_map *map, void *key) +{ + struct xsk_map *m = container_of(map, struct xsk_map, map); + struct xdp_sock __rcu **map_entry; + struct xdp_sock *old_xs; + int k = *(u32 *)key; + + if (k >= map->max_entries) + return -EINVAL; + + spin_lock_bh(&m->lock); + map_entry = &m->xsk_map[k]; + old_xs = unrcu_pointer(xchg(map_entry, NULL)); + if (old_xs) + xsk_map_sock_delete(old_xs, map_entry); + spin_unlock_bh(&m->lock); + + return 0; +} + +static int xsk_map_redirect(struct bpf_map *map, u32 ifindex, u64 flags) +{ + return __bpf_xdp_redirect_map(map, ifindex, flags, 0, + __xsk_map_lookup_elem); +} + +void xsk_map_try_sock_delete(struct xsk_map *map, struct xdp_sock *xs, + struct xdp_sock __rcu **map_entry) +{ + spin_lock_bh(&map->lock); + if (rcu_access_pointer(*map_entry) == xs) { + rcu_assign_pointer(*map_entry, NULL); + xsk_map_sock_delete(xs, map_entry); + } + spin_unlock_bh(&map->lock); +} + +static bool xsk_map_meta_equal(const struct bpf_map *meta0, + const struct bpf_map *meta1) +{ + return meta0->max_entries == meta1->max_entries && + bpf_map_meta_equal(meta0, meta1); +} + +BTF_ID_LIST_SINGLE(xsk_map_btf_ids, struct, xsk_map) +const struct bpf_map_ops xsk_map_ops = { + .map_meta_equal = xsk_map_meta_equal, + .map_alloc = xsk_map_alloc, + .map_free = xsk_map_free, + .map_get_next_key = xsk_map_get_next_key, + .map_lookup_elem = xsk_map_lookup_elem, + .map_gen_lookup = xsk_map_gen_lookup, + .map_lookup_elem_sys_only = xsk_map_lookup_elem_sys_only, + .map_update_elem = xsk_map_update_elem, + .map_delete_elem = xsk_map_delete_elem, + 
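/*
 * Illustrative user-space sketch (not part of this patch): inserting an
 * AF_XDP socket into an XSKMAP slot keyed by queue id, which is the
 * request that xsk_map_update_elem() above services (the value is the
 * socket fd, resolved in the kernel via sockfd_lookup()). Error
 * handling is trimmed; map_fd, xsk_fd and queue_id come from the caller.
 */
#include <bpf/bpf.h>

static int example_attach_xsk(int map_fd, int xsk_fd, __u32 queue_id)
{
	__u32 key = queue_id;
	__u32 value = xsk_fd;

	return bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
}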
.map_check_btf = map_check_no_btf, + .map_btf_id = &xsk_map_btf_ids[0], + .map_redirect = xsk_map_redirect, +}; diff --git a/net/xfrm/Kconfig b/net/xfrm/Kconfig new file mode 100644 index 000000000..3adf31a83 --- /dev/null +++ b/net/xfrm/Kconfig @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: GPL-2.0-only +# +# XFRM configuration +# +config XFRM + bool + depends on INET + select GRO_CELLS + select SKB_EXTENSIONS + +config XFRM_OFFLOAD + bool + +config XFRM_ALGO + tristate + select XFRM + select CRYPTO + select CRYPTO_HASH + select CRYPTO_SKCIPHER + +if INET +config XFRM_USER + tristate "Transformation user configuration interface" + select XFRM_ALGO + help + Support for Transformation(XFRM) user configuration interface + like IPsec used by native Linux tools. + + If unsure, say Y. + +config XFRM_USER_COMPAT + tristate "Compatible ABI support" + depends on XFRM_USER && COMPAT_FOR_U64_ALIGNMENT && \ + HAVE_EFFICIENT_UNALIGNED_ACCESS + select WANT_COMPAT_NETLINK_MESSAGES + help + Transformation(XFRM) user configuration interface like IPsec + used by compatible Linux applications. + + If unsure, say N. + +config XFRM_INTERFACE + tristate "Transformation virtual interface" + depends on XFRM && IPV6 + help + This provides a virtual interface to route IPsec traffic. + + If unsure, say N. + +config XFRM_SUB_POLICY + bool "Transformation sub policy support" + depends on XFRM + help + Support sub policy for developers. By using sub policy with main + one, two policies can be applied to the same packet at once. + Policy which lives shorter time in kernel should be a sub. + + If unsure, say N. + +config XFRM_MIGRATE + bool "Transformation migrate database" + depends on XFRM + help + A feature to update locator(s) of a given IPsec security + association dynamically. This feature is required, for + instance, in a Mobile IPv6 environment with IPsec configuration + where mobile nodes change their attachment point to the Internet. + + If unsure, say N. + +config XFRM_STATISTICS + bool "Transformation statistics" + depends on XFRM && PROC_FS + help + This statistics is not a SNMP/MIB specification but shows + statistics about transformation error (or almost error) factor + at packet processing for developer. + + If unsure, say N. + +# This option selects XFRM_ALGO along with the AH authentication algorithms that +# RFC 8221 lists as MUST be implemented. +config XFRM_AH + tristate + select XFRM_ALGO + select CRYPTO + select CRYPTO_HMAC + select CRYPTO_SHA256 + +# This option selects XFRM_ALGO along with the ESP encryption and authentication +# algorithms that RFC 8221 lists as MUST be implemented. +config XFRM_ESP + tristate + select XFRM_ALGO + select CRYPTO + select CRYPTO_AES + select CRYPTO_AUTHENC + select CRYPTO_CBC + select CRYPTO_ECHAINIV + select CRYPTO_GCM + select CRYPTO_HMAC + select CRYPTO_SEQIV + select CRYPTO_SHA256 + +config XFRM_IPCOMP + tristate + select XFRM_ALGO + select CRYPTO + select CRYPTO_DEFLATE + +config NET_KEY + tristate "PF_KEY sockets" + select XFRM_ALGO + help + PF_KEYv2 socket family, compatible to KAME ones. + They are required if you are going to use IPsec tools ported + from KAME. + + Say Y unless you know what you are doing. + +config NET_KEY_MIGRATE + bool "PF_KEY MIGRATE" + depends on NET_KEY + select XFRM_MIGRATE + help + Add a PF_KEY MIGRATE message to PF_KEYv2 socket family. + The PF_KEY MIGRATE message is used to dynamically update + locator(s) of a given IPsec security association. 
+ This feature is required, for instance, in a Mobile IPv6 + environment with IPsec configuration where mobile nodes + change their attachment point to the Internet. Detail + information can be found in the internet-draft + . + + If unsure, say N. + +config XFRM_ESPINTCP + bool + +endif # INET diff --git a/net/xfrm/Makefile b/net/xfrm/Makefile new file mode 100644 index 000000000..08a2870fd --- /dev/null +++ b/net/xfrm/Makefile @@ -0,0 +1,17 @@ +# SPDX-License-Identifier: GPL-2.0 +# +# Makefile for the XFRM subsystem. +# + +xfrm_interface-$(CONFIG_XFRM_INTERFACE) += xfrm_interface_core.o + +obj-$(CONFIG_XFRM) := xfrm_policy.o xfrm_state.o xfrm_hash.o \ + xfrm_input.o xfrm_output.o \ + xfrm_sysctl.o xfrm_replay.o xfrm_device.o +obj-$(CONFIG_XFRM_STATISTICS) += xfrm_proc.o +obj-$(CONFIG_XFRM_ALGO) += xfrm_algo.o +obj-$(CONFIG_XFRM_USER) += xfrm_user.o +obj-$(CONFIG_XFRM_USER_COMPAT) += xfrm_compat.o +obj-$(CONFIG_XFRM_IPCOMP) += xfrm_ipcomp.o +obj-$(CONFIG_XFRM_INTERFACE) += xfrm_interface.o +obj-$(CONFIG_XFRM_ESPINTCP) += espintcp.o diff --git a/net/xfrm/espintcp.c b/net/xfrm/espintcp.c new file mode 100644 index 000000000..d6fece1ed --- /dev/null +++ b/net/xfrm/espintcp.c @@ -0,0 +1,579 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif + +static void handle_nonesp(struct espintcp_ctx *ctx, struct sk_buff *skb, + struct sock *sk) +{ + if (atomic_read(&sk->sk_rmem_alloc) >= sk->sk_rcvbuf || + !sk_rmem_schedule(sk, skb, skb->truesize)) { + XFRM_INC_STATS(sock_net(sk), LINUX_MIB_XFRMINERROR); + kfree_skb(skb); + return; + } + + skb_set_owner_r(skb, sk); + + memset(skb->cb, 0, sizeof(skb->cb)); + skb_queue_tail(&ctx->ike_queue, skb); + ctx->saved_data_ready(sk); +} + +static void handle_esp(struct sk_buff *skb, struct sock *sk) +{ + struct tcp_skb_cb *tcp_cb = (struct tcp_skb_cb *)skb->cb; + + skb_reset_transport_header(skb); + + /* restore IP CB, we need at least IP6CB->nhoff */ + memmove(skb->cb, &tcp_cb->header, sizeof(tcp_cb->header)); + + rcu_read_lock(); + skb->dev = dev_get_by_index_rcu(sock_net(sk), skb->skb_iif); + local_bh_disable(); +#if IS_ENABLED(CONFIG_IPV6) + if (sk->sk_family == AF_INET6) + ipv6_stub->xfrm6_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); + else +#endif + xfrm4_rcv_encap(skb, IPPROTO_ESP, 0, TCP_ENCAP_ESPINTCP); + local_bh_enable(); + rcu_read_unlock(); +} + +static void espintcp_rcv(struct strparser *strp, struct sk_buff *skb) +{ + struct espintcp_ctx *ctx = container_of(strp, struct espintcp_ctx, + strp); + struct strp_msg *rxm = strp_msg(skb); + int len = rxm->full_len - 2; + u32 nonesp_marker; + int err; + + /* keepalive packet? 
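/*
 * Illustrative sketch (not part of this patch): the RFC 8229 framing
 * that espintcp_parse()/espintcp_rcv() here undo. Each record on the
 * TCP stream is a 16-bit big-endian length (which counts the two
 * length bytes themselves) followed by either an ESP packet or, when
 * the first four payload bytes are zero (the non-ESP marker), an IKE
 * message that is queued for the userspace daemon.
 */
#include <arpa/inet.h>
#include <stdint.h>
#include <string.h>

/* Returns bytes written to out, or 0 if the payload does not fit. */
static size_t example_espintcp_frame(uint8_t *out, size_t out_len,
				     const uint8_t *payload, uint16_t payload_len)
{
	uint16_t blen;

	if (payload_len > 65533 || out_len < (size_t)payload_len + 2)
		return 0;

	blen = htons((uint16_t)(payload_len + 2));
	memcpy(out, &blen, 2);
	memcpy(out + 2, payload, payload_len);
	return (size_t)payload_len + 2;
}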
*/ + if (unlikely(len == 1)) { + u8 data; + + err = skb_copy_bits(skb, rxm->offset + 2, &data, 1); + if (err < 0) { + XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR); + kfree_skb(skb); + return; + } + + if (data == 0xff) { + kfree_skb(skb); + return; + } + } + + /* drop other short messages */ + if (unlikely(len <= sizeof(nonesp_marker))) { + XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR); + kfree_skb(skb); + return; + } + + err = skb_copy_bits(skb, rxm->offset + 2, &nonesp_marker, + sizeof(nonesp_marker)); + if (err < 0) { + XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINHDRERROR); + kfree_skb(skb); + return; + } + + /* remove header, leave non-ESP marker/SPI */ + if (!pskb_pull(skb, rxm->offset + 2)) { + XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR); + kfree_skb(skb); + return; + } + + if (pskb_trim(skb, rxm->full_len - 2) != 0) { + XFRM_INC_STATS(sock_net(strp->sk), LINUX_MIB_XFRMINERROR); + kfree_skb(skb); + return; + } + + if (nonesp_marker == 0) + handle_nonesp(ctx, skb, strp->sk); + else + handle_esp(skb, strp->sk); +} + +static int espintcp_parse(struct strparser *strp, struct sk_buff *skb) +{ + struct strp_msg *rxm = strp_msg(skb); + __be16 blen; + u16 len; + int err; + + if (skb->len < rxm->offset + 2) + return 0; + + err = skb_copy_bits(skb, rxm->offset, &blen, sizeof(blen)); + if (err < 0) + return err; + + len = be16_to_cpu(blen); + if (len < 2) + return -EINVAL; + + return len; +} + +static int espintcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, + int flags, int *addr_len) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct sk_buff *skb; + int err = 0; + int copied; + int off = 0; + + skb = __skb_recv_datagram(sk, &ctx->ike_queue, flags, &off, &err); + if (!skb) { + if (err == -EAGAIN && sk->sk_shutdown & RCV_SHUTDOWN) + return 0; + return err; + } + + copied = len; + if (copied > skb->len) + copied = skb->len; + else if (copied < skb->len) + msg->msg_flags |= MSG_TRUNC; + + err = skb_copy_datagram_msg(skb, 0, msg, copied); + if (unlikely(err)) { + kfree_skb(skb); + return err; + } + + if (flags & MSG_TRUNC) + copied = skb->len; + kfree_skb(skb); + return copied; +} + +int espintcp_queue_out(struct sock *sk, struct sk_buff *skb) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + + if (skb_queue_len(&ctx->out_queue) >= READ_ONCE(netdev_max_backlog)) + return -ENOBUFS; + + __skb_queue_tail(&ctx->out_queue, skb); + + return 0; +} +EXPORT_SYMBOL_GPL(espintcp_queue_out); + +/* espintcp length field is 2B and length includes the length field's size */ +#define MAX_ESPINTCP_MSG (((1 << 16) - 1) - 2) + +static int espintcp_sendskb_locked(struct sock *sk, struct espintcp_msg *emsg, + int flags) +{ + do { + int ret; + + ret = skb_send_sock_locked(sk, emsg->skb, + emsg->offset, emsg->len); + if (ret < 0) + return ret; + + emsg->len -= ret; + emsg->offset += ret; + } while (emsg->len > 0); + + kfree_skb(emsg->skb); + memset(emsg, 0, sizeof(*emsg)); + + return 0; +} + +static int espintcp_sendskmsg_locked(struct sock *sk, + struct espintcp_msg *emsg, int flags) +{ + struct sk_msg *skmsg = &emsg->skmsg; + struct scatterlist *sg; + int done = 0; + int ret; + + flags |= MSG_SENDPAGE_NOTLAST; + sg = &skmsg->sg.data[skmsg->sg.start]; + do { + size_t size = sg->length - emsg->offset; + int offset = sg->offset + emsg->offset; + struct page *p; + + emsg->offset = 0; + + if (sg_is_last(sg)) + flags &= ~MSG_SENDPAGE_NOTLAST; + + p = sg_page(sg); +retry: + ret = do_tcp_sendpages(sk, p, offset, size, flags); + if (ret < 0) { + 
emsg->offset = offset - sg->offset; + skmsg->sg.start += done; + return ret; + } + + if (ret != size) { + offset += ret; + size -= ret; + goto retry; + } + + done++; + put_page(p); + sk_mem_uncharge(sk, sg->length); + sg = sg_next(sg); + } while (sg); + + memset(emsg, 0, sizeof(*emsg)); + + return 0; +} + +static int espintcp_push_msgs(struct sock *sk, int flags) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct espintcp_msg *emsg = &ctx->partial; + int err; + + if (!emsg->len) + return 0; + + if (ctx->tx_running) + return -EAGAIN; + ctx->tx_running = 1; + + if (emsg->skb) + err = espintcp_sendskb_locked(sk, emsg, flags); + else + err = espintcp_sendskmsg_locked(sk, emsg, flags); + if (err == -EAGAIN) { + ctx->tx_running = 0; + return flags & MSG_DONTWAIT ? -EAGAIN : 0; + } + if (!err) + memset(emsg, 0, sizeof(*emsg)); + + ctx->tx_running = 0; + + return err; +} + +int espintcp_push_skb(struct sock *sk, struct sk_buff *skb) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct espintcp_msg *emsg = &ctx->partial; + unsigned int len; + int offset; + + if (sk->sk_state != TCP_ESTABLISHED) { + kfree_skb(skb); + return -ECONNRESET; + } + + offset = skb_transport_offset(skb); + len = skb->len - offset; + + espintcp_push_msgs(sk, 0); + + if (emsg->len) { + kfree_skb(skb); + return -ENOBUFS; + } + + skb_set_owner_w(skb, sk); + + emsg->offset = offset; + emsg->len = len; + emsg->skb = skb; + + espintcp_push_msgs(sk, 0); + + return 0; +} +EXPORT_SYMBOL_GPL(espintcp_push_skb); + +static int espintcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) +{ + long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct espintcp_msg *emsg = &ctx->partial; + struct iov_iter pfx_iter; + struct kvec pfx_iov = {}; + size_t msglen = size + 2; + char buf[2] = {0}; + int err, end; + + if (msg->msg_flags & ~MSG_DONTWAIT) + return -EOPNOTSUPP; + + if (size > MAX_ESPINTCP_MSG) + return -EMSGSIZE; + + if (msg->msg_controllen) + return -EOPNOTSUPP; + + lock_sock(sk); + + err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT); + if (err < 0) { + if (err != -EAGAIN || !(msg->msg_flags & MSG_DONTWAIT)) + err = -ENOBUFS; + goto unlock; + } + + sk_msg_init(&emsg->skmsg); + while (1) { + /* only -ENOMEM is possible since we don't coalesce */ + err = sk_msg_alloc(sk, &emsg->skmsg, msglen, 0); + if (!err) + break; + + err = sk_stream_wait_memory(sk, &timeo); + if (err) + goto fail; + } + + *((__be16 *)buf) = cpu_to_be16(msglen); + pfx_iov.iov_base = buf; + pfx_iov.iov_len = sizeof(buf); + iov_iter_kvec(&pfx_iter, ITER_SOURCE, &pfx_iov, 1, pfx_iov.iov_len); + + err = sk_msg_memcopy_from_iter(sk, &pfx_iter, &emsg->skmsg, + pfx_iov.iov_len); + if (err < 0) + goto fail; + + err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, &emsg->skmsg, size); + if (err < 0) + goto fail; + + end = emsg->skmsg.sg.end; + emsg->len = size; + sk_msg_iter_var_prev(end); + sg_mark_end(sk_msg_elem(&emsg->skmsg, end)); + + tcp_rate_check_app_limited(sk); + + err = espintcp_push_msgs(sk, msg->msg_flags & MSG_DONTWAIT); + /* this message could be partially sent, keep it */ + + release_sock(sk); + + return size; + +fail: + sk_msg_free(sk, &emsg->skmsg); + memset(emsg, 0, sizeof(*emsg)); +unlock: + release_sock(sk); + return err; +} + +static struct proto espintcp_prot __ro_after_init; +static struct proto_ops espintcp_ops __ro_after_init; +static struct proto espintcp6_prot; +static struct proto_ops espintcp6_ops; +static DEFINE_MUTEX(tcpv6_prot_mutex); + +static 
void espintcp_data_ready(struct sock *sk) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + + strp_data_ready(&ctx->strp); +} + +static void espintcp_tx_work(struct work_struct *work) +{ + struct espintcp_ctx *ctx = container_of(work, + struct espintcp_ctx, work); + struct sock *sk = ctx->strp.sk; + + lock_sock(sk); + if (!ctx->tx_running) + espintcp_push_msgs(sk, 0); + release_sock(sk); +} + +static void espintcp_write_space(struct sock *sk) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + + schedule_work(&ctx->work); + ctx->saved_write_space(sk); +} + +static void espintcp_destruct(struct sock *sk) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + + ctx->saved_destruct(sk); + kfree(ctx); +} + +bool tcp_is_ulp_esp(struct sock *sk) +{ + return sk->sk_prot == &espintcp_prot || sk->sk_prot == &espintcp6_prot; +} +EXPORT_SYMBOL_GPL(tcp_is_ulp_esp); + +static void build_protos(struct proto *espintcp_prot, + struct proto_ops *espintcp_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops); +static int espintcp_init_sk(struct sock *sk) +{ + struct inet_connection_sock *icsk = inet_csk(sk); + struct strp_callbacks cb = { + .rcv_msg = espintcp_rcv, + .parse_msg = espintcp_parse, + }; + struct espintcp_ctx *ctx; + int err; + + /* sockmap is not compatible with espintcp */ + if (sk->sk_user_data) + return -EBUSY; + + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL); + if (!ctx) + return -ENOMEM; + + err = strp_init(&ctx->strp, sk, &cb); + if (err) + goto free; + + __sk_dst_reset(sk); + + strp_check_rcv(&ctx->strp); + skb_queue_head_init(&ctx->ike_queue); + skb_queue_head_init(&ctx->out_queue); + + if (sk->sk_family == AF_INET) { + sk->sk_prot = &espintcp_prot; + sk->sk_socket->ops = &espintcp_ops; + } else { + mutex_lock(&tcpv6_prot_mutex); + if (!espintcp6_prot.recvmsg) + build_protos(&espintcp6_prot, &espintcp6_ops, sk->sk_prot, sk->sk_socket->ops); + mutex_unlock(&tcpv6_prot_mutex); + + sk->sk_prot = &espintcp6_prot; + sk->sk_socket->ops = &espintcp6_ops; + } + ctx->saved_data_ready = sk->sk_data_ready; + ctx->saved_write_space = sk->sk_write_space; + ctx->saved_destruct = sk->sk_destruct; + sk->sk_data_ready = espintcp_data_ready; + sk->sk_write_space = espintcp_write_space; + sk->sk_destruct = espintcp_destruct; + rcu_assign_pointer(icsk->icsk_ulp_data, ctx); + INIT_WORK(&ctx->work, espintcp_tx_work); + + /* avoid using task_frag */ + sk->sk_allocation = GFP_ATOMIC; + + return 0; + +free: + kfree(ctx); + return err; +} + +static void espintcp_release(struct sock *sk) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct sk_buff_head queue; + struct sk_buff *skb; + + __skb_queue_head_init(&queue); + skb_queue_splice_init(&ctx->out_queue, &queue); + + while ((skb = __skb_dequeue(&queue))) + espintcp_push_skb(sk, skb); + + tcp_release_cb(sk); +} + +static void espintcp_close(struct sock *sk, long timeout) +{ + struct espintcp_ctx *ctx = espintcp_getctx(sk); + struct espintcp_msg *emsg = &ctx->partial; + + strp_stop(&ctx->strp); + + sk->sk_prot = &tcp_prot; + barrier(); + + cancel_work_sync(&ctx->work); + strp_done(&ctx->strp); + + skb_queue_purge(&ctx->out_queue); + skb_queue_purge(&ctx->ike_queue); + + if (emsg->len) { + if (emsg->skb) + kfree_skb(emsg->skb); + else + sk_msg_free(sk, &emsg->skmsg); + } + + tcp_close(sk, timeout); +} + +static __poll_t espintcp_poll(struct file *file, struct socket *sock, + poll_table *wait) +{ + __poll_t mask = datagram_poll(file, sock, wait); + struct sock *sk = sock->sk; + struct espintcp_ctx *ctx = espintcp_getctx(sk); + + if 
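/*
 * Illustrative user-space sketch (not part of this patch): attaching
 * the espintcp ULP set up by espintcp_init_sk() above to an established
 * TCP connection. Whether the IKE daemon or some other component
 * performs this step is deployment specific and is stated here only as
 * an assumption; the snippet just shows the TCP_ULP hook.
 */
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/socket.h>

#ifndef TCP_ULP
#define TCP_ULP 31
#endif

static int example_enable_espintcp(int tcp_fd)
{
	static const char ulp_name[] = "espintcp";

	return setsockopt(tcp_fd, IPPROTO_TCP, TCP_ULP,
			  ulp_name, sizeof(ulp_name));
}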
(!skb_queue_empty(&ctx->ike_queue)) + mask |= EPOLLIN | EPOLLRDNORM; + + return mask; +} + +static void build_protos(struct proto *espintcp_prot, + struct proto_ops *espintcp_ops, + const struct proto *orig_prot, + const struct proto_ops *orig_ops) +{ + memcpy(espintcp_prot, orig_prot, sizeof(struct proto)); + memcpy(espintcp_ops, orig_ops, sizeof(struct proto_ops)); + espintcp_prot->sendmsg = espintcp_sendmsg; + espintcp_prot->recvmsg = espintcp_recvmsg; + espintcp_prot->close = espintcp_close; + espintcp_prot->release_cb = espintcp_release; + espintcp_ops->poll = espintcp_poll; +} + +static struct tcp_ulp_ops espintcp_ulp __read_mostly = { + .name = "espintcp", + .owner = THIS_MODULE, + .init = espintcp_init_sk, +}; + +void __init espintcp_init(void) +{ + build_protos(&espintcp_prot, &espintcp_ops, &tcp_prot, &inet_stream_ops); + + tcp_register_ulp(&espintcp_ulp); +} diff --git a/net/xfrm/xfrm_algo.c b/net/xfrm/xfrm_algo.c new file mode 100644 index 000000000..094734fbe --- /dev/null +++ b/net/xfrm/xfrm_algo.c @@ -0,0 +1,866 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm algorithm interface + * + * Copyright (c) 2002 James Morris + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_INET_ESP) || IS_ENABLED(CONFIG_INET6_ESP) +#include +#endif + +/* + * Algorithms supported by IPsec. These entries contain properties which + * are used in key negotiation and xfrm processing, and are used to verify + * that instantiated crypto transforms have correct parameters for IPsec + * purposes. + */ +static struct xfrm_algo_desc aead_list[] = { +{ + .name = "rfc4106(gcm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 64, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_GCM_ICV8, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4106(gcm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 96, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_GCM_ICV12, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4106(gcm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_GCM_ICV16, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4309(ccm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 64, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_CCM_ICV8, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4309(ccm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 96, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_CCM_ICV12, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4309(ccm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AES_CCM_ICV16, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc4543(gcm(aes))", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_NULL_AES_GMAC, + .sadb_alg_ivlen = 8, + 
.sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc7539esp(chacha20,poly1305)", + + .uinfo = { + .aead = { + .geniv = "seqiv", + .icv_truncbits = 128, + } + }, + + .pfkey_supported = 0, +}, +}; + +static struct xfrm_algo_desc aalg_list[] = { +{ + .name = "digest_null", + + .uinfo = { + .auth = { + .icv_truncbits = 0, + .icv_fullbits = 0, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_NULL, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 0, + .sadb_alg_maxbits = 0 + } +}, +{ + .name = "hmac(md5)", + .compat = "md5", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_AALG_MD5HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 128 + } +}, +{ + .name = "hmac(sha1)", + .compat = "sha1", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 160, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_AALG_SHA1HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 160, + .sadb_alg_maxbits = 160 + } +}, +{ + .name = "hmac(sha256)", + .compat = "sha256", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 256, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_SHA2_256HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 256, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "hmac(sha384)", + + .uinfo = { + .auth = { + .icv_truncbits = 192, + .icv_fullbits = 384, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_SHA2_384HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 384, + .sadb_alg_maxbits = 384 + } +}, +{ + .name = "hmac(sha512)", + + .uinfo = { + .auth = { + .icv_truncbits = 256, + .icv_fullbits = 512, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_SHA2_512HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 512, + .sadb_alg_maxbits = 512 + } +}, +{ + .name = "hmac(rmd160)", + .compat = "rmd160", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 160, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_RIPEMD160HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 160, + .sadb_alg_maxbits = 160 + } +}, +{ + .name = "xcbc(aes)", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_AES_XCBC_MAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 128 + } +}, +{ + /* rfc4494 */ + .name = "cmac(aes)", + + .uinfo = { + .auth = { + .icv_truncbits = 96, + .icv_fullbits = 128, + } + }, + + .pfkey_supported = 0, +}, +{ + .name = "hmac(sm3)", + .compat = "sm3", + + .uinfo = { + .auth = { + .icv_truncbits = 256, + .icv_fullbits = 256, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_AALG_SM3_256HMAC, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 256, + .sadb_alg_maxbits = 256 + } +}, +}; + +static struct xfrm_algo_desc ealg_list[] = { +{ + .name = "ecb(cipher_null)", + .compat = "cipher_null", + + .uinfo = { + .encr = { + .blockbits = 8, + .defkeybits = 0, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_EALG_NULL, + .sadb_alg_ivlen = 0, + .sadb_alg_minbits = 0, + .sadb_alg_maxbits = 0 + } +}, +{ + .name = "cbc(des)", + .compat = "des", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 64, + .defkeybits = 64, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_EALG_DESCBC, + 
.sadb_alg_ivlen = 8, + .sadb_alg_minbits = 64, + .sadb_alg_maxbits = 64 + } +}, +{ + .name = "cbc(des3_ede)", + .compat = "des3_ede", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 64, + .defkeybits = 192, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_EALG_3DESCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 192, + .sadb_alg_maxbits = 192 + } +}, +{ + .name = "cbc(cast5)", + .compat = "cast5", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 64, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_CASTCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 40, + .sadb_alg_maxbits = 128 + } +}, +{ + .name = "cbc(blowfish)", + .compat = "blowfish", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 64, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_BLOWFISHCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 40, + .sadb_alg_maxbits = 448 + } +}, +{ + .name = "cbc(aes)", + .compat = "aes", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 128, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AESCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "cbc(serpent)", + .compat = "serpent", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 128, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_SERPENTCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256, + } +}, +{ + .name = "cbc(camellia)", + .compat = "camellia", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 128, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_CAMELLIACBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "cbc(twofish)", + .compat = "twofish", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 128, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_TWOFISHCBC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +{ + .name = "rfc3686(ctr(aes))", + + .uinfo = { + .encr = { + .geniv = "seqiv", + .blockbits = 128, + .defkeybits = 160, /* 128-bit key + 32-bit nonce */ + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_AESCTR, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 160, + .sadb_alg_maxbits = 288 + } +}, +{ + .name = "cbc(sm4)", + .compat = "sm4", + + .uinfo = { + .encr = { + .geniv = "echainiv", + .blockbits = 128, + .defkeybits = 128, + } + }, + + .pfkey_supported = 1, + + .desc = { + .sadb_alg_id = SADB_X_EALG_SM4CBC, + .sadb_alg_ivlen = 16, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, +}; + +static struct xfrm_algo_desc calg_list[] = { +{ + .name = "deflate", + .uinfo = { + .comp = { + .threshold = 90, + } + }, + .pfkey_supported = 1, + .desc = { .sadb_alg_id = SADB_X_CALG_DEFLATE } +}, +{ + .name = "lzs", + .uinfo = { + .comp = { + .threshold = 90, + } + }, + .pfkey_supported = 1, + .desc = { .sadb_alg_id = SADB_X_CALG_LZS } +}, +{ + .name = "lzjh", + .uinfo = { + .comp = { + .threshold = 50, + } + }, + .pfkey_supported = 1, + .desc = { .sadb_alg_id = SADB_X_CALG_LZJH } +}, +}; + +static inline int aalg_entries(void) +{ + return ARRAY_SIZE(aalg_list); +} + +static inline int ealg_entries(void) +{ + return 
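/*
 * Illustrative sketch (not part of this patch): how ESP combines the
 * .geniv and .name strings from the tables above into one crypto
 * template, e.g. "echainiv(authenc(hmac(sha256),cbc(aes)))". This
 * mirrors what net/ipv4/esp4.c builds for non-AEAD transforms; the
 * helper itself is only an example, not kernel API.
 */
#include <stdio.h>

static int example_esp_template(char *buf, size_t len, const char *geniv,
				const char *auth, const char *enc)
{
	return snprintf(buf, len, "%s(authenc(%s,%s))", geniv, auth, enc);
}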
ARRAY_SIZE(ealg_list); +} + +static inline int calg_entries(void) +{ + return ARRAY_SIZE(calg_list); +} + +struct xfrm_algo_list { + struct xfrm_algo_desc *algs; + int entries; + u32 type; + u32 mask; +}; + +static const struct xfrm_algo_list xfrm_aead_list = { + .algs = aead_list, + .entries = ARRAY_SIZE(aead_list), + .type = CRYPTO_ALG_TYPE_AEAD, + .mask = CRYPTO_ALG_TYPE_MASK, +}; + +static const struct xfrm_algo_list xfrm_aalg_list = { + .algs = aalg_list, + .entries = ARRAY_SIZE(aalg_list), + .type = CRYPTO_ALG_TYPE_HASH, + .mask = CRYPTO_ALG_TYPE_HASH_MASK, +}; + +static const struct xfrm_algo_list xfrm_ealg_list = { + .algs = ealg_list, + .entries = ARRAY_SIZE(ealg_list), + .type = CRYPTO_ALG_TYPE_SKCIPHER, + .mask = CRYPTO_ALG_TYPE_MASK, +}; + +static const struct xfrm_algo_list xfrm_calg_list = { + .algs = calg_list, + .entries = ARRAY_SIZE(calg_list), + .type = CRYPTO_ALG_TYPE_COMPRESS, + .mask = CRYPTO_ALG_TYPE_MASK, +}; + +static struct xfrm_algo_desc *xfrm_find_algo( + const struct xfrm_algo_list *algo_list, + int match(const struct xfrm_algo_desc *entry, const void *data), + const void *data, int probe) +{ + struct xfrm_algo_desc *list = algo_list->algs; + int i, status; + + for (i = 0; i < algo_list->entries; i++) { + if (!match(list + i, data)) + continue; + + if (list[i].available) + return &list[i]; + + if (!probe) + break; + + status = crypto_has_alg(list[i].name, algo_list->type, + algo_list->mask); + if (!status) + break; + + list[i].available = status; + return &list[i]; + } + return NULL; +} + +static int xfrm_alg_id_match(const struct xfrm_algo_desc *entry, + const void *data) +{ + return entry->desc.sadb_alg_id == (unsigned long)data; +} + +struct xfrm_algo_desc *xfrm_aalg_get_byid(int alg_id) +{ + return xfrm_find_algo(&xfrm_aalg_list, xfrm_alg_id_match, + (void *)(unsigned long)alg_id, 1); +} +EXPORT_SYMBOL_GPL(xfrm_aalg_get_byid); + +struct xfrm_algo_desc *xfrm_ealg_get_byid(int alg_id) +{ + return xfrm_find_algo(&xfrm_ealg_list, xfrm_alg_id_match, + (void *)(unsigned long)alg_id, 1); +} +EXPORT_SYMBOL_GPL(xfrm_ealg_get_byid); + +struct xfrm_algo_desc *xfrm_calg_get_byid(int alg_id) +{ + return xfrm_find_algo(&xfrm_calg_list, xfrm_alg_id_match, + (void *)(unsigned long)alg_id, 1); +} +EXPORT_SYMBOL_GPL(xfrm_calg_get_byid); + +static int xfrm_alg_name_match(const struct xfrm_algo_desc *entry, + const void *data) +{ + const char *name = data; + + return name && (!strcmp(name, entry->name) || + (entry->compat && !strcmp(name, entry->compat))); +} + +struct xfrm_algo_desc *xfrm_aalg_get_byname(const char *name, int probe) +{ + return xfrm_find_algo(&xfrm_aalg_list, xfrm_alg_name_match, name, + probe); +} +EXPORT_SYMBOL_GPL(xfrm_aalg_get_byname); + +struct xfrm_algo_desc *xfrm_ealg_get_byname(const char *name, int probe) +{ + return xfrm_find_algo(&xfrm_ealg_list, xfrm_alg_name_match, name, + probe); +} +EXPORT_SYMBOL_GPL(xfrm_ealg_get_byname); + +struct xfrm_algo_desc *xfrm_calg_get_byname(const char *name, int probe) +{ + return xfrm_find_algo(&xfrm_calg_list, xfrm_alg_name_match, name, + probe); +} +EXPORT_SYMBOL_GPL(xfrm_calg_get_byname); + +struct xfrm_aead_name { + const char *name; + int icvbits; +}; + +static int xfrm_aead_name_match(const struct xfrm_algo_desc *entry, + const void *data) +{ + const struct xfrm_aead_name *aead = data; + const char *name = aead->name; + + return aead->icvbits == entry->uinfo.aead.icv_truncbits && name && + !strcmp(name, entry->name); +} + +struct xfrm_algo_desc *xfrm_aead_get_byname(const char *name, int icv_len, int probe) +{ 
+ struct xfrm_aead_name data = { + .name = name, + .icvbits = icv_len, + }; + + return xfrm_find_algo(&xfrm_aead_list, xfrm_aead_name_match, &data, + probe); +} +EXPORT_SYMBOL_GPL(xfrm_aead_get_byname); + +struct xfrm_algo_desc *xfrm_aalg_get_byidx(unsigned int idx) +{ + if (idx >= aalg_entries()) + return NULL; + + return &aalg_list[idx]; +} +EXPORT_SYMBOL_GPL(xfrm_aalg_get_byidx); + +struct xfrm_algo_desc *xfrm_ealg_get_byidx(unsigned int idx) +{ + if (idx >= ealg_entries()) + return NULL; + + return &ealg_list[idx]; +} +EXPORT_SYMBOL_GPL(xfrm_ealg_get_byidx); + +/* + * Probe for the availability of crypto algorithms, and set the available + * flag for any algorithms found on the system. This is typically called by + * pfkey during userspace SA add, update or register. + */ +void xfrm_probe_algs(void) +{ + int i, status; + + BUG_ON(in_softirq()); + + for (i = 0; i < aalg_entries(); i++) { + status = crypto_has_ahash(aalg_list[i].name, 0, 0); + if (aalg_list[i].available != status) + aalg_list[i].available = status; + } + + for (i = 0; i < ealg_entries(); i++) { + status = crypto_has_skcipher(ealg_list[i].name, 0, 0); + if (ealg_list[i].available != status) + ealg_list[i].available = status; + } + + for (i = 0; i < calg_entries(); i++) { + status = crypto_has_comp(calg_list[i].name, 0, + CRYPTO_ALG_ASYNC); + if (calg_list[i].available != status) + calg_list[i].available = status; + } +} +EXPORT_SYMBOL_GPL(xfrm_probe_algs); + +int xfrm_count_pfkey_auth_supported(void) +{ + int i, n; + + for (i = 0, n = 0; i < aalg_entries(); i++) + if (aalg_list[i].available && aalg_list[i].pfkey_supported) + n++; + return n; +} +EXPORT_SYMBOL_GPL(xfrm_count_pfkey_auth_supported); + +int xfrm_count_pfkey_enc_supported(void) +{ + int i, n; + + for (i = 0, n = 0; i < ealg_entries(); i++) + if (ealg_list[i].available && ealg_list[i].pfkey_supported) + n++; + return n; +} +EXPORT_SYMBOL_GPL(xfrm_count_pfkey_enc_supported); + +MODULE_LICENSE("GPL"); diff --git a/net/xfrm/xfrm_compat.c b/net/xfrm/xfrm_compat.c new file mode 100644 index 000000000..655fe4ff8 --- /dev/null +++ b/net/xfrm/xfrm_compat.c @@ -0,0 +1,675 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * XFRM compat layer + * Author: Dmitry Safonov + * Based on code and translator idea by: Florian Westphal + */ +#include +#include +#include +#include + +struct compat_xfrm_lifetime_cfg { + compat_u64 soft_byte_limit, hard_byte_limit; + compat_u64 soft_packet_limit, hard_packet_limit; + compat_u64 soft_add_expires_seconds, hard_add_expires_seconds; + compat_u64 soft_use_expires_seconds, hard_use_expires_seconds; +}; /* same size on 32bit, but only 4 byte alignment required */ + +struct compat_xfrm_lifetime_cur { + compat_u64 bytes, packets, add_time, use_time; +}; /* same size on 32bit, but only 4 byte alignment required */ + +struct compat_xfrm_userpolicy_info { + struct xfrm_selector sel; + struct compat_xfrm_lifetime_cfg lft; + struct compat_xfrm_lifetime_cur curlft; + __u32 priority, index; + u8 dir, action, flags, share; + /* 4 bytes additional padding on 64bit */ +}; + +struct compat_xfrm_usersa_info { + struct xfrm_selector sel; + struct xfrm_id id; + xfrm_address_t saddr; + struct compat_xfrm_lifetime_cfg lft; + struct compat_xfrm_lifetime_cur curlft; + struct xfrm_stats stats; + __u32 seq, reqid; + u16 family; + u8 mode, replay_window, flags; + /* 4 bytes additional padding on 64bit */ +}; + +struct compat_xfrm_user_acquire { + struct xfrm_id id; + xfrm_address_t saddr; + struct xfrm_selector sel; + struct compat_xfrm_userpolicy_info policy; + /* 4 
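/*
 * Illustrative sketch (not part of this patch) of the problem the
 * compat structures above solve: a u64 member is 8-byte aligned on
 * 64-bit builds but only 4-byte aligned for 32-bit userland, so the
 * same sequence of fields produces different sizes and offsets. The
 * printed values (16 vs 12) are what GCC on x86-64 typically yields
 * and are stated as an assumption.
 */
#include <stdint.h>
#include <stdio.h>

typedef uint64_t __attribute__((aligned(4))) example_compat_u64;

struct example_native {		/* 64-bit kernel view */
	uint32_t priority;
	uint64_t byte_limit;	/* 4 bytes of padding inserted before it */
};

struct example_compat {		/* 32-bit userland view */
	uint32_t priority;
	example_compat_u64 byte_limit;	/* packs directly after priority */
};

int main(void)
{
	printf("native=%zu compat=%zu\n",
	       sizeof(struct example_native), sizeof(struct example_compat));
	return 0;
}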
bytes additional padding on 64bit */ + __u32 aalgos, ealgos, calgos, seq; +}; + +struct compat_xfrm_userspi_info { + struct compat_xfrm_usersa_info info; + /* 4 bytes additional padding on 64bit */ + __u32 min, max; +}; + +struct compat_xfrm_user_expire { + struct compat_xfrm_usersa_info state; + /* 8 bytes additional padding on 64bit */ + u8 hard; +}; + +struct compat_xfrm_user_polexpire { + struct compat_xfrm_userpolicy_info pol; + /* 8 bytes additional padding on 64bit */ + u8 hard; +}; + +#define XMSGSIZE(type) sizeof(struct type) + +static const int compat_msg_min[XFRM_NR_MSGTYPES] = { + [XFRM_MSG_NEWSA - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_usersa_info), + [XFRM_MSG_DELSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_GETSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_NEWPOLICY - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userpolicy_info), + [XFRM_MSG_DELPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_GETPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_ALLOCSPI - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userspi_info), + [XFRM_MSG_ACQUIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_acquire), + [XFRM_MSG_EXPIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_expire), + [XFRM_MSG_UPDPOLICY - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userpolicy_info), + [XFRM_MSG_UPDSA - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_usersa_info), + [XFRM_MSG_POLEXPIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_polexpire), + [XFRM_MSG_FLUSHSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_flush), + [XFRM_MSG_FLUSHPOLICY - XFRM_MSG_BASE] = 0, + [XFRM_MSG_NEWAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_GETAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_REPORT - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_report), + [XFRM_MSG_MIGRATE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_NEWSADINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_GETSADINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_NEWSPDINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_GETSPDINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_MAPPING - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_mapping) +}; + +static const struct nla_policy compat_policy[XFRMA_MAX+1] = { + [XFRMA_SA] = { .len = XMSGSIZE(compat_xfrm_usersa_info)}, + [XFRMA_POLICY] = { .len = XMSGSIZE(compat_xfrm_userpolicy_info)}, + [XFRMA_LASTUSED] = { .type = NLA_U64}, + [XFRMA_ALG_AUTH_TRUNC] = { .len = sizeof(struct xfrm_algo_auth)}, + [XFRMA_ALG_AEAD] = { .len = sizeof(struct xfrm_algo_aead) }, + [XFRMA_ALG_AUTH] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_CRYPT] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_COMP] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ENCAP] = { .len = sizeof(struct xfrm_encap_tmpl) }, + [XFRMA_TMPL] = { .len = sizeof(struct xfrm_user_tmpl) }, + [XFRMA_SEC_CTX] = { .len = sizeof(struct xfrm_user_sec_ctx) }, + [XFRMA_LTIME_VAL] = { .len = sizeof(struct xfrm_lifetime_cur) }, + [XFRMA_REPLAY_VAL] = { .len = sizeof(struct xfrm_replay_state) }, + [XFRMA_REPLAY_THRESH] = { .type = NLA_U32 }, + [XFRMA_ETIMER_THRESH] = { .type = NLA_U32 }, + [XFRMA_SRCADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_COADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_POLICY_TYPE] = { .len = sizeof(struct xfrm_userpolicy_type)}, + [XFRMA_MIGRATE] = { .len = sizeof(struct xfrm_user_migrate) }, + [XFRMA_KMADDRESS] = { .len = sizeof(struct xfrm_user_kmaddress) }, + [XFRMA_MARK] = { .len = sizeof(struct xfrm_mark) }, + [XFRMA_TFCPAD] = { .type = NLA_U32 }, + [XFRMA_REPLAY_ESN_VAL] = { .len = sizeof(struct 
xfrm_replay_state_esn) }, + [XFRMA_SA_EXTRA_FLAGS] = { .type = NLA_U32 }, + [XFRMA_PROTO] = { .type = NLA_U8 }, + [XFRMA_ADDRESS_FILTER] = { .len = sizeof(struct xfrm_address_filter) }, + [XFRMA_OFFLOAD_DEV] = { .len = sizeof(struct xfrm_user_offload) }, + [XFRMA_SET_MARK] = { .type = NLA_U32 }, + [XFRMA_SET_MARK_MASK] = { .type = NLA_U32 }, + [XFRMA_IF_ID] = { .type = NLA_U32 }, + [XFRMA_MTIMER_THRESH] = { .type = NLA_U32 }, +}; + +static struct nlmsghdr *xfrm_nlmsg_put_compat(struct sk_buff *skb, + const struct nlmsghdr *nlh_src, u16 type) +{ + int payload = compat_msg_min[type]; + int src_len = xfrm_msg_min[type]; + struct nlmsghdr *nlh_dst; + + /* Compat messages are shorter or equal to native (+padding) */ + if (WARN_ON_ONCE(src_len < payload)) + return ERR_PTR(-EMSGSIZE); + + nlh_dst = nlmsg_put(skb, nlh_src->nlmsg_pid, nlh_src->nlmsg_seq, + nlh_src->nlmsg_type, payload, nlh_src->nlmsg_flags); + if (!nlh_dst) + return ERR_PTR(-EMSGSIZE); + + memset(nlmsg_data(nlh_dst), 0, payload); + + switch (nlh_src->nlmsg_type) { + /* Compat message has the same layout as native */ + case XFRM_MSG_DELSA: + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_FLUSHSA: + case XFRM_MSG_FLUSHPOLICY: + case XFRM_MSG_NEWAE: + case XFRM_MSG_REPORT: + case XFRM_MSG_MIGRATE: + case XFRM_MSG_NEWSADINFO: + case XFRM_MSG_NEWSPDINFO: + case XFRM_MSG_MAPPING: + WARN_ON_ONCE(src_len != payload); + memcpy(nlmsg_data(nlh_dst), nlmsg_data(nlh_src), src_len); + break; + /* 4 byte alignment for trailing u64 on native, but not on compat */ + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDSA: + case XFRM_MSG_UPDPOLICY: + WARN_ON_ONCE(src_len != payload + 4); + memcpy(nlmsg_data(nlh_dst), nlmsg_data(nlh_src), payload); + break; + case XFRM_MSG_EXPIRE: { + const struct xfrm_user_expire *src_ue = nlmsg_data(nlh_src); + struct compat_xfrm_user_expire *dst_ue = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_expire has 4-byte smaller state */ + memcpy(dst_ue, src_ue, sizeof(dst_ue->state)); + dst_ue->hard = src_ue->hard; + break; + } + case XFRM_MSG_ACQUIRE: { + const struct xfrm_user_acquire *src_ua = nlmsg_data(nlh_src); + struct compat_xfrm_user_acquire *dst_ua = nlmsg_data(nlh_dst); + + memcpy(dst_ua, src_ua, offsetof(struct compat_xfrm_user_acquire, aalgos)); + dst_ua->aalgos = src_ua->aalgos; + dst_ua->ealgos = src_ua->ealgos; + dst_ua->calgos = src_ua->calgos; + dst_ua->seq = src_ua->seq; + break; + } + case XFRM_MSG_POLEXPIRE: { + const struct xfrm_user_polexpire *src_upe = nlmsg_data(nlh_src); + struct compat_xfrm_user_polexpire *dst_upe = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_upe, src_upe, sizeof(dst_upe->pol)); + dst_upe->hard = src_upe->hard; + break; + } + case XFRM_MSG_ALLOCSPI: { + const struct xfrm_userspi_info *src_usi = nlmsg_data(nlh_src); + struct compat_xfrm_userspi_info *dst_usi = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_usi, src_usi, sizeof(src_usi->info)); + dst_usi->min = src_usi->min; + dst_usi->max = src_usi->max; + break; + } + /* Not being sent by kernel */ + case XFRM_MSG_GETSA: + case XFRM_MSG_GETPOLICY: + case XFRM_MSG_GETAE: + case XFRM_MSG_GETSADINFO: + case XFRM_MSG_GETSPDINFO: + default: + pr_warn_once("unsupported nlmsg_type %d\n", nlh_src->nlmsg_type); + return ERR_PTR(-EOPNOTSUPP); + } + + return nlh_dst; +} + +static int xfrm_nla_cpy(struct sk_buff *dst, const struct nlattr *src, int len) +{ + return nla_put(dst, src->nla_type, len, nla_data(src)); +} + +static int 
xfrm_xlate64_attr(struct sk_buff *dst, const struct nlattr *src) +{ + switch (src->nla_type) { + case XFRMA_PAD: + /* Ignore */ + return 0; + case XFRMA_UNSPEC: + case XFRMA_ALG_AUTH: + case XFRMA_ALG_CRYPT: + case XFRMA_ALG_COMP: + case XFRMA_ENCAP: + case XFRMA_TMPL: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_SA: + return xfrm_nla_cpy(dst, src, XMSGSIZE(compat_xfrm_usersa_info)); + case XFRMA_POLICY: + return xfrm_nla_cpy(dst, src, XMSGSIZE(compat_xfrm_userpolicy_info)); + case XFRMA_SEC_CTX: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_LTIME_VAL: + return nla_put_64bit(dst, src->nla_type, nla_len(src), + nla_data(src), XFRMA_PAD); + case XFRMA_REPLAY_VAL: + case XFRMA_REPLAY_THRESH: + case XFRMA_ETIMER_THRESH: + case XFRMA_SRCADDR: + case XFRMA_COADDR: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_LASTUSED: + return nla_put_64bit(dst, src->nla_type, nla_len(src), + nla_data(src), XFRMA_PAD); + case XFRMA_POLICY_TYPE: + case XFRMA_MIGRATE: + case XFRMA_ALG_AEAD: + case XFRMA_KMADDRESS: + case XFRMA_ALG_AUTH_TRUNC: + case XFRMA_MARK: + case XFRMA_TFCPAD: + case XFRMA_REPLAY_ESN_VAL: + case XFRMA_SA_EXTRA_FLAGS: + case XFRMA_PROTO: + case XFRMA_ADDRESS_FILTER: + case XFRMA_OFFLOAD_DEV: + case XFRMA_SET_MARK: + case XFRMA_SET_MARK_MASK: + case XFRMA_IF_ID: + case XFRMA_MTIMER_THRESH: + return xfrm_nla_cpy(dst, src, nla_len(src)); + default: + BUILD_BUG_ON(XFRMA_MAX != XFRMA_MTIMER_THRESH); + pr_warn_once("unsupported nla_type %d\n", src->nla_type); + return -EOPNOTSUPP; + } +} + +/* Take kernel-built (64bit layout) and create 32bit layout for userspace */ +static int xfrm_xlate64(struct sk_buff *dst, const struct nlmsghdr *nlh_src) +{ + u16 type = nlh_src->nlmsg_type - XFRM_MSG_BASE; + const struct nlattr *nla, *attrs; + struct nlmsghdr *nlh_dst; + int len, remaining; + + nlh_dst = xfrm_nlmsg_put_compat(dst, nlh_src, type); + if (IS_ERR(nlh_dst)) + return PTR_ERR(nlh_dst); + + attrs = nlmsg_attrdata(nlh_src, xfrm_msg_min[type]); + len = nlmsg_attrlen(nlh_src, xfrm_msg_min[type]); + + nla_for_each_attr(nla, attrs, len, remaining) { + int err; + + switch (nlh_src->nlmsg_type) { + case XFRM_MSG_NEWSPDINFO: + err = xfrm_nla_cpy(dst, nla, nla_len(nla)); + break; + default: + err = xfrm_xlate64_attr(dst, nla); + break; + } + if (err) + return err; + } + + nlmsg_end(dst, nlh_dst); + + return 0; +} + +static int xfrm_alloc_compat(struct sk_buff *skb, const struct nlmsghdr *nlh_src) +{ + u16 type = nlh_src->nlmsg_type - XFRM_MSG_BASE; + struct sk_buff *new = NULL; + int err; + + if (type >= ARRAY_SIZE(xfrm_msg_min)) { + pr_warn_once("unsupported nlmsg_type %d\n", nlh_src->nlmsg_type); + return -EOPNOTSUPP; + } + + if (skb_shinfo(skb)->frag_list == NULL) { + new = alloc_skb(skb->len + skb_tailroom(skb), GFP_ATOMIC); + if (!new) + return -ENOMEM; + skb_shinfo(skb)->frag_list = new; + } + + err = xfrm_xlate64(skb_shinfo(skb)->frag_list, nlh_src); + if (err) { + if (new) { + kfree_skb(new); + skb_shinfo(skb)->frag_list = NULL; + } + return err; + } + + return 0; +} + +/* Calculates len of translated 64-bit message. 
*/ +static size_t xfrm_user_rcv_calculate_len64(const struct nlmsghdr *src, + struct nlattr *attrs[XFRMA_MAX + 1], + int maxtype) +{ + size_t len = nlmsg_len(src); + + switch (src->nlmsg_type) { + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_ALLOCSPI: + case XFRM_MSG_ACQUIRE: + case XFRM_MSG_UPDPOLICY: + case XFRM_MSG_UPDSA: + len += 4; + break; + case XFRM_MSG_EXPIRE: + case XFRM_MSG_POLEXPIRE: + len += 8; + break; + case XFRM_MSG_NEWSPDINFO: + /* attributes are xfrm_spdattr_type_t, not xfrm_attr_type_t */ + return len; + default: + break; + } + + /* Unexpected for anything but XFRM_MSG_NEWSPDINFO; please + * correct both 64=>32-bit and 32=>64-bit translators to copy + * new attributes. + */ + if (WARN_ON_ONCE(maxtype)) + return len; + + if (attrs[XFRMA_SA]) + len += 4; + if (attrs[XFRMA_POLICY]) + len += 4; + + /* XXX: some attrs may need to be realigned + * if !CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + */ + + return len; +} + +static int xfrm_attr_cpy32(void *dst, size_t *pos, const struct nlattr *src, + size_t size, int copy_len, int payload) +{ + struct nlmsghdr *nlmsg = dst; + struct nlattr *nla; + + /* xfrm_user_rcv_msg_compat() relies on the fact that 32-bit messages + * have the same len as, or are shorter than, 64-bit ones. + * A 32-bit translation that is bigger than the 64-bit original is unexpected. + */ + if (WARN_ON_ONCE(copy_len > payload)) + copy_len = payload; + + if (size - *pos < nla_attr_size(payload)) + return -ENOBUFS; + + nla = dst + *pos; + + memcpy(nla, src, nla_attr_size(copy_len)); + nla->nla_len = nla_attr_size(payload); + *pos += nla_attr_size(copy_len); + nlmsg->nlmsg_len += nla->nla_len; + + memset(dst + *pos, 0, payload - copy_len); + *pos += payload - copy_len; + + return 0; +} + +static int xfrm_xlate32_attr(void *dst, const struct nlattr *nla, + size_t *pos, size_t size, + struct netlink_ext_ack *extack) +{ + int type = nla_type(nla); + u16 pol_len32, pol_len64; + int err; + + if (type > XFRMA_MAX) { + BUILD_BUG_ON(XFRMA_MAX != XFRMA_MTIMER_THRESH); + NL_SET_ERR_MSG(extack, "Bad attribute"); + return -EOPNOTSUPP; + } + type = array_index_nospec(type, XFRMA_MAX + 1); + if (nla_len(nla) < compat_policy[type].len) { + NL_SET_ERR_MSG(extack, "Attribute bad length"); + return -EOPNOTSUPP; + } + + pol_len32 = compat_policy[type].len; + pol_len64 = xfrma_policy[type].len; + + /* XFRMA_SA and XFRMA_POLICY - need to know how to translate */ + if (pol_len32 != pol_len64) { + if (nla_len(nla) != compat_policy[type].len) { + NL_SET_ERR_MSG(extack, "Attribute bad length"); + return -EOPNOTSUPP; + } + err = xfrm_attr_cpy32(dst, pos, nla, size, pol_len32, pol_len64); + if (err) + return err; + } + + return xfrm_attr_cpy32(dst, pos, nla, size, nla_len(nla), nla_len(nla)); +} + +static int xfrm_xlate32(struct nlmsghdr *dst, const struct nlmsghdr *src, + struct nlattr *attrs[XFRMA_MAX+1], + size_t size, u8 type, int maxtype, + struct netlink_ext_ack *extack) +{ + size_t pos; + int i; + + memcpy(dst, src, NLMSG_HDRLEN); + dst->nlmsg_len = NLMSG_HDRLEN + xfrm_msg_min[type]; + memset(nlmsg_data(dst), 0, xfrm_msg_min[type]); + + switch (src->nlmsg_type) { + /* Compat message has the same layout as native */ + case XFRM_MSG_DELSA: + case XFRM_MSG_GETSA: + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_GETPOLICY: + case XFRM_MSG_FLUSHSA: + case XFRM_MSG_FLUSHPOLICY: + case XFRM_MSG_NEWAE: + case XFRM_MSG_GETAE: + case XFRM_MSG_REPORT: + case XFRM_MSG_MIGRATE: + case XFRM_MSG_NEWSADINFO: + case XFRM_MSG_GETSADINFO: + case XFRM_MSG_NEWSPDINFO: + case XFRM_MSG_GETSPDINFO: + case
XFRM_MSG_MAPPING: + memcpy(nlmsg_data(dst), nlmsg_data(src), compat_msg_min[type]); + break; + /* 4 byte alignment for trailing u64 on native, but not on compat */ + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDSA: + case XFRM_MSG_UPDPOLICY: + memcpy(nlmsg_data(dst), nlmsg_data(src), compat_msg_min[type]); + break; + case XFRM_MSG_EXPIRE: { + const struct compat_xfrm_user_expire *src_ue = nlmsg_data(src); + struct xfrm_user_expire *dst_ue = nlmsg_data(dst); + + /* compat_xfrm_user_expire has 4-byte smaller state */ + memcpy(dst_ue, src_ue, sizeof(src_ue->state)); + dst_ue->hard = src_ue->hard; + break; + } + case XFRM_MSG_ACQUIRE: { + const struct compat_xfrm_user_acquire *src_ua = nlmsg_data(src); + struct xfrm_user_acquire *dst_ua = nlmsg_data(dst); + + memcpy(dst_ua, src_ua, offsetof(struct compat_xfrm_user_acquire, aalgos)); + dst_ua->aalgos = src_ua->aalgos; + dst_ua->ealgos = src_ua->ealgos; + dst_ua->calgos = src_ua->calgos; + dst_ua->seq = src_ua->seq; + break; + } + case XFRM_MSG_POLEXPIRE: { + const struct compat_xfrm_user_polexpire *src_upe = nlmsg_data(src); + struct xfrm_user_polexpire *dst_upe = nlmsg_data(dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_upe, src_upe, sizeof(src_upe->pol)); + dst_upe->hard = src_upe->hard; + break; + } + case XFRM_MSG_ALLOCSPI: { + const struct compat_xfrm_userspi_info *src_usi = nlmsg_data(src); + struct xfrm_userspi_info *dst_usi = nlmsg_data(dst); + + /* compat_xfrm_userspi_info has 4-byte smaller info */ + memcpy(dst_usi, src_usi, sizeof(src_usi->info)); + dst_usi->min = src_usi->min; + dst_usi->max = src_usi->max; + break; + } + default: + NL_SET_ERR_MSG(extack, "Unsupported message type"); + return -EOPNOTSUPP; + } + pos = dst->nlmsg_len; + + if (maxtype) { + /* attributes are xfrm_spdattr_type_t, not xfrm_attr_type_t */ + WARN_ON_ONCE(src->nlmsg_type != XFRM_MSG_NEWSPDINFO); + + for (i = 1; i <= maxtype; i++) { + int err; + + if (!attrs[i]) + continue; + + /* just copy - no need for translation */ + err = xfrm_attr_cpy32(dst, &pos, attrs[i], size, + nla_len(attrs[i]), nla_len(attrs[i])); + if (err) + return err; + } + return 0; + } + + for (i = 1; i < XFRMA_MAX + 1; i++) { + int err; + + if (i == XFRMA_PAD) + continue; + + if (!attrs[i]) + continue; + + err = xfrm_xlate32_attr(dst, attrs[i], &pos, size, extack); + if (err) + return err; + } + + return 0; +} + +static struct nlmsghdr *xfrm_user_rcv_msg_compat(const struct nlmsghdr *h32, + int maxtype, const struct nla_policy *policy, + struct netlink_ext_ack *extack) +{ + /* netlink_rcv_skb() checks if a message has a full (struct nlmsghdr) */ + u16 type = h32->nlmsg_type - XFRM_MSG_BASE; + struct nlattr *attrs[XFRMA_MAX+1]; + struct nlmsghdr *h64; + size_t len; + int err; + + BUILD_BUG_ON(ARRAY_SIZE(xfrm_msg_min) != ARRAY_SIZE(compat_msg_min)); + + if (type >= ARRAY_SIZE(xfrm_msg_min)) + return ERR_PTR(-EINVAL); + + /* Don't call parse: the message might have only the nlmsg header */ + if ((h32->nlmsg_type == XFRM_MSG_GETSA || + h32->nlmsg_type == XFRM_MSG_GETPOLICY) && + (h32->nlmsg_flags & NLM_F_DUMP)) + return NULL; + + err = nlmsg_parse_deprecated(h32, compat_msg_min[type], attrs, + maxtype ? : XFRMA_MAX, policy ?
: compat_policy, extack); + if (err < 0) + return ERR_PTR(err); + + len = xfrm_user_rcv_calculate_len64(h32, attrs, maxtype); + /* The message doesn't need translation */ + if (len == nlmsg_len(h32)) + return NULL; + + len += NLMSG_HDRLEN; + h64 = kvmalloc(len, GFP_KERNEL); + if (!h64) + return ERR_PTR(-ENOMEM); + + err = xfrm_xlate32(h64, h32, attrs, len, type, maxtype, extack); + if (err < 0) { + kvfree(h64); + return ERR_PTR(err); + } + + return h64; +} + +static int xfrm_user_policy_compat(u8 **pdata32, int optlen) +{ + struct compat_xfrm_userpolicy_info *p = (void *)*pdata32; + u8 *src_templates, *dst_templates; + u8 *data64; + + if (optlen < sizeof(*p)) + return -EINVAL; + + data64 = kmalloc_track_caller(optlen + 4, GFP_USER | __GFP_NOWARN); + if (!data64) + return -ENOMEM; + + memcpy(data64, *pdata32, sizeof(*p)); + memset(data64 + sizeof(*p), 0, 4); + + src_templates = *pdata32 + sizeof(*p); + dst_templates = data64 + sizeof(*p) + 4; + memcpy(dst_templates, src_templates, optlen - sizeof(*p)); + + kfree(*pdata32); + *pdata32 = data64; + return 0; +} + +static struct xfrm_translator xfrm_translator = { + .owner = THIS_MODULE, + .alloc_compat = xfrm_alloc_compat, + .rcv_msg_compat = xfrm_user_rcv_msg_compat, + .xlate_user_policy_sockptr = xfrm_user_policy_compat, +}; + +static int __init xfrm_compat_init(void) +{ + return xfrm_register_translator(&xfrm_translator); +} + +static void __exit xfrm_compat_exit(void) +{ + xfrm_unregister_translator(&xfrm_translator); +} + +module_init(xfrm_compat_init); +module_exit(xfrm_compat_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Dmitry Safonov"); +MODULE_DESCRIPTION("XFRM 32-bit compatibility layer"); diff --git a/net/xfrm/xfrm_device.c b/net/xfrm/xfrm_device.c new file mode 100644 index 000000000..21269e8f2 --- /dev/null +++ b/net/xfrm/xfrm_device.c @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm_device.c - IPsec device offloading code. 
+ * + * Copyright (c) 2015 secunet Security Networks AG + * + * Author: + * Steffen Klassert + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CONFIG_XFRM_OFFLOAD +static void __xfrm_transport_prep(struct xfrm_state *x, struct sk_buff *skb, + unsigned int hsize) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + + skb_reset_mac_len(skb); + if (xo->flags & XFRM_GSO_SEGMENT) + skb->transport_header -= x->props.header_len; + + pskb_pull(skb, skb_transport_offset(skb) + x->props.header_len); +} + +static void __xfrm_mode_tunnel_prep(struct xfrm_state *x, struct sk_buff *skb, + unsigned int hsize) + +{ + struct xfrm_offload *xo = xfrm_offload(skb); + + if (xo->flags & XFRM_GSO_SEGMENT) + skb->transport_header = skb->network_header + hsize; + + skb_reset_mac_len(skb); + pskb_pull(skb, skb->mac_len + x->props.header_len); +} + +static void __xfrm_mode_beet_prep(struct xfrm_state *x, struct sk_buff *skb, + unsigned int hsize) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + int phlen = 0; + + if (xo->flags & XFRM_GSO_SEGMENT) + skb->transport_header = skb->network_header + hsize; + + skb_reset_mac_len(skb); + if (x->sel.family != AF_INET6) { + phlen = IPV4_BEET_PHMAXLEN; + if (x->outer_mode.family == AF_INET6) + phlen += sizeof(struct ipv6hdr) - sizeof(struct iphdr); + } + + pskb_pull(skb, skb->mac_len + hsize + (x->props.header_len - phlen)); +} + +/* Adjust pointers into the packet when IPsec is done at layer2 */ +static void xfrm_outer_mode_prep(struct xfrm_state *x, struct sk_buff *skb) +{ + switch (x->outer_mode.encap) { + case XFRM_MODE_TUNNEL: + if (x->outer_mode.family == AF_INET) + return __xfrm_mode_tunnel_prep(x, skb, + sizeof(struct iphdr)); + if (x->outer_mode.family == AF_INET6) + return __xfrm_mode_tunnel_prep(x, skb, + sizeof(struct ipv6hdr)); + break; + case XFRM_MODE_TRANSPORT: + if (x->outer_mode.family == AF_INET) + return __xfrm_transport_prep(x, skb, + sizeof(struct iphdr)); + if (x->outer_mode.family == AF_INET6) + return __xfrm_transport_prep(x, skb, + sizeof(struct ipv6hdr)); + break; + case XFRM_MODE_BEET: + if (x->outer_mode.family == AF_INET) + return __xfrm_mode_beet_prep(x, skb, + sizeof(struct iphdr)); + if (x->outer_mode.family == AF_INET6) + return __xfrm_mode_beet_prep(x, skb, + sizeof(struct ipv6hdr)); + break; + case XFRM_MODE_ROUTEOPTIMIZATION: + case XFRM_MODE_IN_TRIGGER: + break; + } +} + +static inline bool xmit_xfrm_check_overflow(struct sk_buff *skb) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + __u32 seq = xo->seq.low; + + seq += skb_shinfo(skb)->gso_segs; + if (unlikely(seq < xo->seq.low)) + return true; + + return false; +} + +struct sk_buff *validate_xmit_xfrm(struct sk_buff *skb, netdev_features_t features, bool *again) +{ + int err; + unsigned long flags; + struct xfrm_state *x; + struct softnet_data *sd; + struct sk_buff *skb2, *nskb, *pskb = NULL; + netdev_features_t esp_features = features; + struct xfrm_offload *xo = xfrm_offload(skb); + struct net_device *dev = skb->dev; + struct sec_path *sp; + + if (!xo || (xo->flags & XFRM_XMIT)) + return skb; + + if (!(features & NETIF_F_HW_ESP)) + esp_features = features & ~(NETIF_F_SG | NETIF_F_CSUM_MASK); + + sp = skb_sec_path(skb); + x = sp->xvec[sp->len - 1]; + if (xo->flags & XFRM_GRO || x->xso.dir == XFRM_DEV_OFFLOAD_IN) + return skb; + + /* This skb was already validated on the upper/virtual dev */ + if ((x->xso.dev != dev) && (x->xso.real_dev == dev)) + return skb; + + local_irq_save(flags); + sd = this_cpu_ptr(&softnet_data); + err 
= !skb_queue_empty(&sd->xfrm_backlog); + local_irq_restore(flags); + + if (err) { + *again = true; + return skb; + } + + if (skb_is_gso(skb) && (unlikely(x->xso.dev != dev) || + unlikely(xmit_xfrm_check_overflow(skb)))) { + struct sk_buff *segs; + + /* Packet got rerouted, fixup features and segment it. */ + esp_features = esp_features & ~(NETIF_F_HW_ESP | NETIF_F_GSO_ESP); + + segs = skb_gso_segment(skb, esp_features); + if (IS_ERR(segs)) { + kfree_skb(skb); + dev_core_stats_tx_dropped_inc(dev); + return NULL; + } else { + consume_skb(skb); + skb = segs; + } + } + + if (!skb->next) { + esp_features |= skb->dev->gso_partial_features; + xfrm_outer_mode_prep(x, skb); + + xo->flags |= XFRM_DEV_RESUME; + + err = x->type_offload->xmit(x, skb, esp_features); + if (err) { + if (err == -EINPROGRESS) + return NULL; + + XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR); + kfree_skb(skb); + return NULL; + } + + skb_push(skb, skb->data - skb_mac_header(skb)); + + return skb; + } + + skb_list_walk_safe(skb, skb2, nskb) { + esp_features |= skb->dev->gso_partial_features; + skb_mark_not_on_list(skb2); + + xo = xfrm_offload(skb2); + xo->flags |= XFRM_DEV_RESUME; + + xfrm_outer_mode_prep(x, skb2); + + err = x->type_offload->xmit(x, skb2, esp_features); + if (!err) { + skb2->next = nskb; + } else if (err != -EINPROGRESS) { + XFRM_INC_STATS(xs_net(x), LINUX_MIB_XFRMOUTSTATEPROTOERROR); + skb2->next = nskb; + kfree_skb_list(skb2); + return NULL; + } else { + if (skb == skb2) + skb = nskb; + else + pskb->next = nskb; + + continue; + } + + skb_push(skb2, skb2->data - skb_mac_header(skb2)); + pskb = skb2; + } + + return skb; +} +EXPORT_SYMBOL_GPL(validate_xmit_xfrm); + +int xfrm_dev_state_add(struct net *net, struct xfrm_state *x, + struct xfrm_user_offload *xuo, + struct netlink_ext_ack *extack) +{ + int err; + struct dst_entry *dst; + struct net_device *dev; + struct xfrm_dev_offload *xso = &x->xso; + xfrm_address_t *saddr; + xfrm_address_t *daddr; + + if (!x->type_offload) { + NL_SET_ERR_MSG(extack, "Type doesn't support offload"); + return -EINVAL; + } + + /* We don't yet support UDP encapsulation and TFC padding. 
*/ + if (x->encap || x->tfcpad) { + NL_SET_ERR_MSG(extack, "Encapsulation and TFC padding can't be offloaded"); + return -EINVAL; + } + + if (xuo->flags & ~(XFRM_OFFLOAD_IPV6 | XFRM_OFFLOAD_INBOUND)) { + NL_SET_ERR_MSG(extack, "Unrecognized flags in offload request"); + return -EINVAL; + } + + dev = dev_get_by_index(net, xuo->ifindex); + if (!dev) { + if (!(xuo->flags & XFRM_OFFLOAD_INBOUND)) { + saddr = &x->props.saddr; + daddr = &x->id.daddr; + } else { + saddr = &x->id.daddr; + daddr = &x->props.saddr; + } + + dst = __xfrm_dst_lookup(net, 0, 0, saddr, daddr, + x->props.family, + xfrm_smark_get(0, x)); + if (IS_ERR(dst)) + return 0; + + dev = dst->dev; + + dev_hold(dev); + dst_release(dst); + } + + if (!dev->xfrmdev_ops || !dev->xfrmdev_ops->xdo_dev_state_add) { + xso->dev = NULL; + dev_put(dev); + return 0; + } + + if (x->props.flags & XFRM_STATE_ESN && + !dev->xfrmdev_ops->xdo_dev_state_advance_esn) { + NL_SET_ERR_MSG(extack, "Device doesn't support offload with ESN"); + xso->dev = NULL; + dev_put(dev); + return -EINVAL; + } + + xso->dev = dev; + netdev_tracker_alloc(dev, &xso->dev_tracker, GFP_ATOMIC); + xso->real_dev = dev; + + if (xuo->flags & XFRM_OFFLOAD_INBOUND) + xso->dir = XFRM_DEV_OFFLOAD_IN; + else + xso->dir = XFRM_DEV_OFFLOAD_OUT; + + err = dev->xfrmdev_ops->xdo_dev_state_add(x); + if (err) { + xso->dev = NULL; + xso->dir = 0; + xso->real_dev = NULL; + netdev_put(dev, &xso->dev_tracker); + + if (err != -EOPNOTSUPP) { + NL_SET_ERR_MSG(extack, "Device failed to offload this state"); + return err; + } + } + + return 0; +} +EXPORT_SYMBOL_GPL(xfrm_dev_state_add); + +bool xfrm_dev_offload_ok(struct sk_buff *skb, struct xfrm_state *x) +{ + int mtu; + struct dst_entry *dst = skb_dst(skb); + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + struct net_device *dev = x->xso.dev; + + if (!x->type_offload || x->encap) + return false; + + if ((!dev || (dev == xfrm_dst_path(dst)->dev)) && + (!xdst->child->xfrm)) { + mtu = xfrm_state_mtu(x, xdst->child_mtu_cached); + if (skb->len <= mtu) + goto ok; + + if (skb_is_gso(skb) && skb_gso_validate_network_len(skb, mtu)) + goto ok; + } + + return false; + +ok: + if (dev && dev->xfrmdev_ops && dev->xfrmdev_ops->xdo_dev_offload_ok) + return x->xso.dev->xfrmdev_ops->xdo_dev_offload_ok(skb, x); + + return true; +} +EXPORT_SYMBOL_GPL(xfrm_dev_offload_ok); + +void xfrm_dev_resume(struct sk_buff *skb) +{ + struct net_device *dev = skb->dev; + int ret = NETDEV_TX_BUSY; + struct netdev_queue *txq; + struct softnet_data *sd; + unsigned long flags; + + rcu_read_lock(); + txq = netdev_core_pick_tx(dev, skb, NULL); + + HARD_TX_LOCK(dev, txq, smp_processor_id()); + if (!netif_xmit_frozen_or_stopped(txq)) + skb = dev_hard_start_xmit(skb, dev, txq, &ret); + HARD_TX_UNLOCK(dev, txq); + + if (!dev_xmit_complete(ret)) { + local_irq_save(flags); + sd = this_cpu_ptr(&softnet_data); + skb_queue_tail(&sd->xfrm_backlog, skb); + raise_softirq_irqoff(NET_TX_SOFTIRQ); + local_irq_restore(flags); + } + rcu_read_unlock(); +} +EXPORT_SYMBOL_GPL(xfrm_dev_resume); + +void xfrm_dev_backlog(struct softnet_data *sd) +{ + struct sk_buff_head *xfrm_backlog = &sd->xfrm_backlog; + struct sk_buff_head list; + struct sk_buff *skb; + + if (skb_queue_empty(xfrm_backlog)) + return; + + __skb_queue_head_init(&list); + + spin_lock(&xfrm_backlog->lock); + skb_queue_splice_init(xfrm_backlog, &list); + spin_unlock(&xfrm_backlog->lock); + + while (!skb_queue_empty(&list)) { + skb = __skb_dequeue(&list); + xfrm_dev_resume(skb); + } + +} +#endif + +static int xfrm_api_check(struct net_device *dev) 
+{ +#ifdef CONFIG_XFRM_OFFLOAD + if ((dev->features & NETIF_F_HW_ESP_TX_CSUM) && + !(dev->features & NETIF_F_HW_ESP)) + return NOTIFY_BAD; + + if ((dev->features & NETIF_F_HW_ESP) && + (!(dev->xfrmdev_ops && + dev->xfrmdev_ops->xdo_dev_state_add && + dev->xfrmdev_ops->xdo_dev_state_delete))) + return NOTIFY_BAD; +#else + if (dev->features & (NETIF_F_HW_ESP | NETIF_F_HW_ESP_TX_CSUM)) + return NOTIFY_BAD; +#endif + + return NOTIFY_DONE; +} + +static int xfrm_dev_down(struct net_device *dev) +{ + if (dev->features & NETIF_F_HW_ESP) + xfrm_dev_state_flush(dev_net(dev), dev, true); + + return NOTIFY_DONE; +} + +static int xfrm_dev_event(struct notifier_block *this, unsigned long event, void *ptr) +{ + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + switch (event) { + case NETDEV_REGISTER: + return xfrm_api_check(dev); + + case NETDEV_FEAT_CHANGE: + return xfrm_api_check(dev); + + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + return xfrm_dev_down(dev); + } + return NOTIFY_DONE; +} + +static struct notifier_block xfrm_dev_notifier = { + .notifier_call = xfrm_dev_event, +}; + +void __init xfrm_dev_init(void) +{ + register_netdevice_notifier(&xfrm_dev_notifier); +} diff --git a/net/xfrm/xfrm_hash.c b/net/xfrm/xfrm_hash.c new file mode 100644 index 000000000..eca8d84d9 --- /dev/null +++ b/net/xfrm/xfrm_hash.c @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: GPL-2.0 +/* xfrm_hash.c: Common hash table code. + * + * Copyright (C) 2006 David S. Miller (davem@davemloft.net) + */ + +#include +#include +#include +#include +#include +#include + +#include "xfrm_hash.h" + +struct hlist_head *xfrm_hash_alloc(unsigned int sz) +{ + struct hlist_head *n; + + if (sz <= PAGE_SIZE) + n = kzalloc(sz, GFP_KERNEL); + else if (hashdist) + n = vzalloc(sz); + else + n = (struct hlist_head *) + __get_free_pages(GFP_KERNEL | __GFP_NOWARN | __GFP_ZERO, + get_order(sz)); + + return n; +} + +void xfrm_hash_free(struct hlist_head *n, unsigned int sz) +{ + if (sz <= PAGE_SIZE) + kfree(n); + else if (hashdist) + vfree(n); + else + free_pages((unsigned long)n, get_order(sz)); +} diff --git a/net/xfrm/xfrm_hash.h b/net/xfrm/xfrm_hash.h new file mode 100644 index 000000000..d12bb906c --- /dev/null +++ b/net/xfrm/xfrm_hash.h @@ -0,0 +1,199 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#ifndef _XFRM_HASH_H +#define _XFRM_HASH_H + +#include +#include +#include + +static inline unsigned int __xfrm4_addr_hash(const xfrm_address_t *addr) +{ + return ntohl(addr->a4); +} + +static inline unsigned int __xfrm6_addr_hash(const xfrm_address_t *addr) +{ + return jhash2((__force u32 *)addr->a6, 4, 0); +} + +static inline unsigned int __xfrm4_daddr_saddr_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr) +{ + u32 sum = (__force u32)daddr->a4 + (__force u32)saddr->a4; + return ntohl((__force __be32)sum); +} + +static inline unsigned int __xfrm6_daddr_saddr_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr) +{ + return __xfrm6_addr_hash(daddr) ^ __xfrm6_addr_hash(saddr); +} + +static inline u32 __bits2mask32(__u8 bits) +{ + u32 mask32 = 0xffffffff; + + if (bits == 0) + mask32 = 0; + else if (bits < 32) + mask32 <<= (32 - bits); + + return mask32; +} + +static inline unsigned int __xfrm4_dpref_spref_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + __u8 dbits, + __u8 sbits) +{ + return jhash_2words(ntohl(daddr->a4) & __bits2mask32(dbits), + ntohl(saddr->a4) & __bits2mask32(sbits), + 0); +} + +static inline unsigned int __xfrm6_pref_hash(const xfrm_address_t *addr, + __u8 prefixlen) +{ + unsigned 
int pdw; + unsigned int pbi; + u32 initval = 0; + + pdw = prefixlen >> 5; /* num of whole u32 in prefix */ + pbi = prefixlen & 0x1f; /* num of bits in incomplete u32 in prefix */ + + if (pbi) { + __be32 mask; + + mask = htonl((0xffffffff) << (32 - pbi)); + + initval = (__force u32)(addr->a6[pdw] & mask); + } + + return jhash2((__force u32 *)addr->a6, pdw, initval); +} + +static inline unsigned int __xfrm6_dpref_spref_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + __u8 dbits, + __u8 sbits) +{ + return __xfrm6_pref_hash(daddr, dbits) ^ + __xfrm6_pref_hash(saddr, sbits); +} + +static inline unsigned int __xfrm_dst_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + u32 reqid, unsigned short family, + unsigned int hmask) +{ + unsigned int h = family ^ reqid; + switch (family) { + case AF_INET: + h ^= __xfrm4_daddr_saddr_hash(daddr, saddr); + break; + case AF_INET6: + h ^= __xfrm6_daddr_saddr_hash(daddr, saddr); + break; + } + return (h ^ (h >> 16)) & hmask; +} + +static inline unsigned int __xfrm_src_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + unsigned short family, + unsigned int hmask) +{ + unsigned int h = family; + switch (family) { + case AF_INET: + h ^= __xfrm4_daddr_saddr_hash(daddr, saddr); + break; + case AF_INET6: + h ^= __xfrm6_daddr_saddr_hash(daddr, saddr); + break; + } + return (h ^ (h >> 16)) & hmask; +} + +static inline unsigned int +__xfrm_spi_hash(const xfrm_address_t *daddr, __be32 spi, u8 proto, + unsigned short family, unsigned int hmask) +{ + unsigned int h = (__force u32)spi ^ proto; + switch (family) { + case AF_INET: + h ^= __xfrm4_addr_hash(daddr); + break; + case AF_INET6: + h ^= __xfrm6_addr_hash(daddr); + break; + } + return (h ^ (h >> 10) ^ (h >> 20)) & hmask; +} + +static inline unsigned int +__xfrm_seq_hash(u32 seq, unsigned int hmask) +{ + unsigned int h = seq; + return (h ^ (h >> 10) ^ (h >> 20)) & hmask; +} + +static inline unsigned int __idx_hash(u32 index, unsigned int hmask) +{ + return (index ^ (index >> 8)) & hmask; +} + +static inline unsigned int __sel_hash(const struct xfrm_selector *sel, + unsigned short family, unsigned int hmask, + u8 dbits, u8 sbits) +{ + const xfrm_address_t *daddr = &sel->daddr; + const xfrm_address_t *saddr = &sel->saddr; + unsigned int h = 0; + + switch (family) { + case AF_INET: + if (sel->prefixlen_d < dbits || + sel->prefixlen_s < sbits) + return hmask + 1; + + h = __xfrm4_dpref_spref_hash(daddr, saddr, dbits, sbits); + break; + + case AF_INET6: + if (sel->prefixlen_d < dbits || + sel->prefixlen_s < sbits) + return hmask + 1; + + h = __xfrm6_dpref_spref_hash(daddr, saddr, dbits, sbits); + break; + } + h ^= (h >> 16); + return h & hmask; +} + +static inline unsigned int __addr_hash(const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + unsigned short family, + unsigned int hmask, + u8 dbits, u8 sbits) +{ + unsigned int h = 0; + + switch (family) { + case AF_INET: + h = __xfrm4_dpref_spref_hash(daddr, saddr, dbits, sbits); + break; + + case AF_INET6: + h = __xfrm6_dpref_spref_hash(daddr, saddr, dbits, sbits); + break; + } + h ^= (h >> 16); + return h & hmask; +} + +struct hlist_head *xfrm_hash_alloc(unsigned int sz); +void xfrm_hash_free(struct hlist_head *n, unsigned int sz); + +#endif /* _XFRM_HASH_H */ diff --git a/net/xfrm/xfrm_inout.h b/net/xfrm/xfrm_inout.h new file mode 100644 index 000000000..efc5e6b2e --- /dev/null +++ b/net/xfrm/xfrm_inout.h @@ -0,0 +1,70 @@ +/* SPDX-License-Identifier: GPL-2.0 */ +#include +#include +#include + +#ifndef XFRM_INOUT_H 
+#define XFRM_INOUT_H 1 + +static inline void xfrm4_extract_header(struct sk_buff *skb) +{ + const struct iphdr *iph = ip_hdr(skb); + + XFRM_MODE_SKB_CB(skb)->ihl = sizeof(*iph); + XFRM_MODE_SKB_CB(skb)->id = iph->id; + XFRM_MODE_SKB_CB(skb)->frag_off = iph->frag_off; + XFRM_MODE_SKB_CB(skb)->tos = iph->tos; + XFRM_MODE_SKB_CB(skb)->ttl = iph->ttl; + XFRM_MODE_SKB_CB(skb)->optlen = iph->ihl * 4 - sizeof(*iph); + memset(XFRM_MODE_SKB_CB(skb)->flow_lbl, 0, + sizeof(XFRM_MODE_SKB_CB(skb)->flow_lbl)); +} + +static inline void xfrm6_extract_header(struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *iph = ipv6_hdr(skb); + + XFRM_MODE_SKB_CB(skb)->ihl = sizeof(*iph); + XFRM_MODE_SKB_CB(skb)->id = 0; + XFRM_MODE_SKB_CB(skb)->frag_off = htons(IP_DF); + XFRM_MODE_SKB_CB(skb)->tos = ipv6_get_dsfield(iph); + XFRM_MODE_SKB_CB(skb)->ttl = iph->hop_limit; + XFRM_MODE_SKB_CB(skb)->optlen = 0; + memcpy(XFRM_MODE_SKB_CB(skb)->flow_lbl, iph->flow_lbl, + sizeof(XFRM_MODE_SKB_CB(skb)->flow_lbl)); +#else + WARN_ON_ONCE(1); +#endif +} + +static inline void xfrm6_beet_make_header(struct sk_buff *skb) +{ + struct ipv6hdr *iph = ipv6_hdr(skb); + + iph->version = 6; + + memcpy(iph->flow_lbl, XFRM_MODE_SKB_CB(skb)->flow_lbl, + sizeof(iph->flow_lbl)); + iph->nexthdr = XFRM_MODE_SKB_CB(skb)->protocol; + + ipv6_change_dsfield(iph, 0, XFRM_MODE_SKB_CB(skb)->tos); + iph->hop_limit = XFRM_MODE_SKB_CB(skb)->ttl; +} + +static inline void xfrm4_beet_make_header(struct sk_buff *skb) +{ + struct iphdr *iph = ip_hdr(skb); + + iph->ihl = 5; + iph->version = 4; + + iph->protocol = XFRM_MODE_SKB_CB(skb)->protocol; + iph->tos = XFRM_MODE_SKB_CB(skb)->tos; + + iph->id = XFRM_MODE_SKB_CB(skb)->id; + iph->frag_off = XFRM_MODE_SKB_CB(skb)->frag_off; + iph->ttl = XFRM_MODE_SKB_CB(skb)->ttl; +} + +#endif diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c new file mode 100644 index 000000000..ac1a645af --- /dev/null +++ b/net/xfrm/xfrm_input.c @@ -0,0 +1,834 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * xfrm_input.c + * + * Changes: + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific portion + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "xfrm_inout.h" + +struct xfrm_trans_tasklet { + struct work_struct work; + spinlock_t queue_lock; + struct sk_buff_head queue; +}; + +struct xfrm_trans_cb { + union { + struct inet_skb_parm h4; +#if IS_ENABLED(CONFIG_IPV6) + struct inet6_skb_parm h6; +#endif + } header; + int (*finish)(struct net *net, struct sock *sk, struct sk_buff *skb); + struct net *net; +}; + +#define XFRM_TRANS_SKB_CB(__skb) ((struct xfrm_trans_cb *)&((__skb)->cb[0])) + +static DEFINE_SPINLOCK(xfrm_input_afinfo_lock); +static struct xfrm_input_afinfo const __rcu *xfrm_input_afinfo[2][AF_INET6 + 1]; + +static struct gro_cells gro_cells; +static struct net_device xfrm_napi_dev; + +static DEFINE_PER_CPU(struct xfrm_trans_tasklet, xfrm_trans_tasklet); + +int xfrm_input_register_afinfo(const struct xfrm_input_afinfo *afinfo) +{ + int err = 0; + + if (WARN_ON(afinfo->family > AF_INET6)) + return -EAFNOSUPPORT; + + spin_lock_bh(&xfrm_input_afinfo_lock); + if (unlikely(xfrm_input_afinfo[afinfo->is_ipip][afinfo->family])) + err = -EEXIST; + else + rcu_assign_pointer(xfrm_input_afinfo[afinfo->is_ipip][afinfo->family], afinfo); + spin_unlock_bh(&xfrm_input_afinfo_lock); + return err; +} +EXPORT_SYMBOL(xfrm_input_register_afinfo); + +int xfrm_input_unregister_afinfo(const struct xfrm_input_afinfo *afinfo) +{ + 
int err = 0; + + spin_lock_bh(&xfrm_input_afinfo_lock); + if (likely(xfrm_input_afinfo[afinfo->is_ipip][afinfo->family])) { + if (unlikely(xfrm_input_afinfo[afinfo->is_ipip][afinfo->family] != afinfo)) + err = -EINVAL; + else + RCU_INIT_POINTER(xfrm_input_afinfo[afinfo->is_ipip][afinfo->family], NULL); + } + spin_unlock_bh(&xfrm_input_afinfo_lock); + synchronize_rcu(); + return err; +} +EXPORT_SYMBOL(xfrm_input_unregister_afinfo); + +static const struct xfrm_input_afinfo *xfrm_input_get_afinfo(u8 family, bool is_ipip) +{ + const struct xfrm_input_afinfo *afinfo; + + if (WARN_ON_ONCE(family > AF_INET6)) + return NULL; + + rcu_read_lock(); + afinfo = rcu_dereference(xfrm_input_afinfo[is_ipip][family]); + if (unlikely(!afinfo)) + rcu_read_unlock(); + return afinfo; +} + +static int xfrm_rcv_cb(struct sk_buff *skb, unsigned int family, u8 protocol, + int err) +{ + bool is_ipip = (protocol == IPPROTO_IPIP || protocol == IPPROTO_IPV6); + const struct xfrm_input_afinfo *afinfo; + int ret; + + afinfo = xfrm_input_get_afinfo(family, is_ipip); + if (!afinfo) + return -EAFNOSUPPORT; + + ret = afinfo->callback(skb, protocol, err); + rcu_read_unlock(); + + return ret; +} + +struct sec_path *secpath_set(struct sk_buff *skb) +{ + struct sec_path *sp, *tmp = skb_ext_find(skb, SKB_EXT_SEC_PATH); + + sp = skb_ext_add(skb, SKB_EXT_SEC_PATH); + if (!sp) + return NULL; + + if (tmp) /* reused existing one (was COW'd if needed) */ + return sp; + + /* allocated new secpath */ + memset(sp->ovec, 0, sizeof(sp->ovec)); + sp->olen = 0; + sp->len = 0; + sp->verified_cnt = 0; + + return sp; +} +EXPORT_SYMBOL(secpath_set); + +/* Fetch spi and seq from ipsec header */ + +int xfrm_parse_spi(struct sk_buff *skb, u8 nexthdr, __be32 *spi, __be32 *seq) +{ + int offset, offset_seq; + int hlen; + + switch (nexthdr) { + case IPPROTO_AH: + hlen = sizeof(struct ip_auth_hdr); + offset = offsetof(struct ip_auth_hdr, spi); + offset_seq = offsetof(struct ip_auth_hdr, seq_no); + break; + case IPPROTO_ESP: + hlen = sizeof(struct ip_esp_hdr); + offset = offsetof(struct ip_esp_hdr, spi); + offset_seq = offsetof(struct ip_esp_hdr, seq_no); + break; + case IPPROTO_COMP: + if (!pskb_may_pull(skb, sizeof(struct ip_comp_hdr))) + return -EINVAL; + *spi = htonl(ntohs(*(__be16 *)(skb_transport_header(skb) + 2))); + *seq = 0; + return 0; + default: + return 1; + } + + if (!pskb_may_pull(skb, hlen)) + return -EINVAL; + + *spi = *(__be32 *)(skb_transport_header(skb) + offset); + *seq = *(__be32 *)(skb_transport_header(skb) + offset_seq); + return 0; +} +EXPORT_SYMBOL(xfrm_parse_spi); + +static int xfrm4_remove_beet_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + struct iphdr *iph; + int optlen = 0; + int err = -EINVAL; + + if (unlikely(XFRM_MODE_SKB_CB(skb)->protocol == IPPROTO_BEETPH)) { + struct ip_beet_phdr *ph; + int phlen; + + if (!pskb_may_pull(skb, sizeof(*ph))) + goto out; + + ph = (struct ip_beet_phdr *)skb->data; + + phlen = sizeof(*ph) + ph->padlen; + optlen = ph->hdrlen * 8 + (IPV4_BEET_PHMAXLEN - phlen); + if (optlen < 0 || optlen & 3 || optlen > 250) + goto out; + + XFRM_MODE_SKB_CB(skb)->protocol = ph->nexthdr; + + if (!pskb_may_pull(skb, phlen)) + goto out; + __skb_pull(skb, phlen); + } + + skb_push(skb, sizeof(*iph)); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + + xfrm4_beet_make_header(skb); + + iph = ip_hdr(skb); + + iph->ihl += optlen / 4; + iph->tot_len = htons(skb->len); + iph->daddr = x->sel.daddr.a4; + iph->saddr = x->sel.saddr.a4; + iph->check = 0; + iph->check = 
ip_fast_csum(skb_network_header(skb), iph->ihl); + err = 0; +out: + return err; +} + +static void ipip_ecn_decapsulate(struct sk_buff *skb) +{ + struct iphdr *inner_iph = ipip_hdr(skb); + + if (INET_ECN_is_ce(XFRM_MODE_SKB_CB(skb)->tos)) + IP_ECN_set_ce(inner_iph); +} + +static int xfrm4_remove_tunnel_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = -EINVAL; + + if (XFRM_MODE_SKB_CB(skb)->protocol != IPPROTO_IPIP) + goto out; + + if (!pskb_may_pull(skb, sizeof(struct iphdr))) + goto out; + + err = skb_unclone(skb, GFP_ATOMIC); + if (err) + goto out; + + if (x->props.flags & XFRM_STATE_DECAP_DSCP) + ipv4_copy_dscp(XFRM_MODE_SKB_CB(skb)->tos, ipip_hdr(skb)); + if (!(x->props.flags & XFRM_STATE_NOECN)) + ipip_ecn_decapsulate(skb); + + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + if (skb->mac_len) + eth_hdr(skb)->h_proto = skb->protocol; + + err = 0; + +out: + return err; +} + +static void ipip6_ecn_decapsulate(struct sk_buff *skb) +{ + struct ipv6hdr *inner_iph = ipipv6_hdr(skb); + + if (INET_ECN_is_ce(XFRM_MODE_SKB_CB(skb)->tos)) + IP6_ECN_set_ce(skb, inner_iph); +} + +static int xfrm6_remove_tunnel_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = -EINVAL; + + if (XFRM_MODE_SKB_CB(skb)->protocol != IPPROTO_IPV6) + goto out; + if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) + goto out; + + err = skb_unclone(skb, GFP_ATOMIC); + if (err) + goto out; + + if (x->props.flags & XFRM_STATE_DECAP_DSCP) + ipv6_copy_dscp(XFRM_MODE_SKB_CB(skb)->tos, ipipv6_hdr(skb)); + if (!(x->props.flags & XFRM_STATE_NOECN)) + ipip6_ecn_decapsulate(skb); + + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + if (skb->mac_len) + eth_hdr(skb)->h_proto = skb->protocol; + + err = 0; + +out: + return err; +} + +static int xfrm6_remove_beet_encap(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipv6hdr *ip6h; + int size = sizeof(struct ipv6hdr); + int err; + + err = skb_cow_head(skb, size + skb->mac_len); + if (err) + goto out; + + __skb_push(skb, size); + skb_reset_network_header(skb); + skb_mac_header_rebuild(skb); + + xfrm6_beet_make_header(skb); + + ip6h = ipv6_hdr(skb); + ip6h->payload_len = htons(skb->len - size); + ip6h->daddr = x->sel.daddr.in6; + ip6h->saddr = x->sel.saddr.in6; + err = 0; +out: + return err; +} + +/* Remove encapsulation header. + * + * The IP header will be moved over the top of the encapsulation + * header. + * + * On entry, the transport header shall point to where the IP header + * should be and the network header shall be set to where the IP + * header currently is. skb->data shall point to the start of the + * payload. 
+ */ +static int +xfrm_inner_mode_encap_remove(struct xfrm_state *x, + const struct xfrm_mode *inner_mode, + struct sk_buff *skb) +{ + switch (inner_mode->encap) { + case XFRM_MODE_BEET: + if (inner_mode->family == AF_INET) + return xfrm4_remove_beet_encap(x, skb); + if (inner_mode->family == AF_INET6) + return xfrm6_remove_beet_encap(x, skb); + break; + case XFRM_MODE_TUNNEL: + if (inner_mode->family == AF_INET) + return xfrm4_remove_tunnel_encap(x, skb); + if (inner_mode->family == AF_INET6) + return xfrm6_remove_tunnel_encap(x, skb); + break; + } + + WARN_ON_ONCE(1); + return -EOPNOTSUPP; +} + +static int xfrm_prepare_input(struct xfrm_state *x, struct sk_buff *skb) +{ + const struct xfrm_mode *inner_mode = &x->inner_mode; + + switch (x->outer_mode.family) { + case AF_INET: + xfrm4_extract_header(skb); + break; + case AF_INET6: + xfrm6_extract_header(skb); + break; + default: + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; + } + + if (x->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol); + if (!inner_mode) + return -EAFNOSUPPORT; + } + + switch (inner_mode->family) { + case AF_INET: + skb->protocol = htons(ETH_P_IP); + break; + case AF_INET6: + skb->protocol = htons(ETH_P_IPV6); + break; + default: + WARN_ON_ONCE(1); + break; + } + + return xfrm_inner_mode_encap_remove(x, inner_mode, skb); +} + +/* Remove encapsulation header. + * + * The IP header will be moved over the top of the encapsulation header. + * + * On entry, skb_transport_header() shall point to where the IP header + * should be and skb_network_header() shall be set to where the IP header + * currently is. skb->data shall point to the start of the payload. + */ +static int xfrm4_transport_input(struct xfrm_state *x, struct sk_buff *skb) +{ + int ihl = skb->data - skb_transport_header(skb); + + if (skb->transport_header != skb->network_header) { + memmove(skb_transport_header(skb), + skb_network_header(skb), ihl); + skb->network_header = skb->transport_header; + } + ip_hdr(skb)->tot_len = htons(skb->len + ihl); + skb_reset_transport_header(skb); + return 0; +} + +static int xfrm6_transport_input(struct xfrm_state *x, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + int ihl = skb->data - skb_transport_header(skb); + + if (skb->transport_header != skb->network_header) { + memmove(skb_transport_header(skb), + skb_network_header(skb), ihl); + skb->network_header = skb->transport_header; + } + ipv6_hdr(skb)->payload_len = htons(skb->len + ihl - + sizeof(struct ipv6hdr)); + skb_reset_transport_header(skb); + return 0; +#else + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; +#endif +} + +static int xfrm_inner_mode_input(struct xfrm_state *x, + const struct xfrm_mode *inner_mode, + struct sk_buff *skb) +{ + switch (inner_mode->encap) { + case XFRM_MODE_BEET: + case XFRM_MODE_TUNNEL: + return xfrm_prepare_input(x, skb); + case XFRM_MODE_TRANSPORT: + if (inner_mode->family == AF_INET) + return xfrm4_transport_input(x, skb); + if (inner_mode->family == AF_INET6) + return xfrm6_transport_input(x, skb); + break; + case XFRM_MODE_ROUTEOPTIMIZATION: + WARN_ON_ONCE(1); + break; + default: + WARN_ON_ONCE(1); + break; + } + + return -EOPNOTSUPP; +} + +int xfrm_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type) +{ + const struct xfrm_state_afinfo *afinfo; + struct net *net = dev_net(skb->dev); + const struct xfrm_mode *inner_mode; + int err; + __be32 seq; + __be32 seq_hi; + struct xfrm_state *x = NULL; + xfrm_address_t *daddr; + u32 mark = skb->mark; + unsigned int family = AF_UNSPEC; 
+ int decaps = 0; + int async = 0; + bool xfrm_gro = false; + bool crypto_done = false; + struct xfrm_offload *xo = xfrm_offload(skb); + struct sec_path *sp; + + if (encap_type < 0) { + x = xfrm_input_state(skb); + + if (unlikely(x->km.state != XFRM_STATE_VALID)) { + if (x->km.state == XFRM_STATE_ACQ) + XFRM_INC_STATS(net, LINUX_MIB_XFRMACQUIREERROR); + else + XFRM_INC_STATS(net, + LINUX_MIB_XFRMINSTATEINVALID); + + if (encap_type == -1) + dev_put(skb->dev); + goto drop; + } + + family = x->outer_mode.family; + + /* An encap_type of -1 indicates async resumption. */ + if (encap_type == -1) { + async = 1; + seq = XFRM_SKB_CB(skb)->seq.input.low; + goto resume; + } + + /* encap_type < -1 indicates a GRO call. */ + encap_type = 0; + seq = XFRM_SPI_SKB_CB(skb)->seq; + + if (xo && (xo->flags & CRYPTO_DONE)) { + crypto_done = true; + family = XFRM_SPI_SKB_CB(skb)->family; + + if (!(xo->status & CRYPTO_SUCCESS)) { + if (xo->status & + (CRYPTO_TRANSPORT_AH_AUTH_FAILED | + CRYPTO_TRANSPORT_ESP_AUTH_FAILED | + CRYPTO_TUNNEL_AH_AUTH_FAILED | + CRYPTO_TUNNEL_ESP_AUTH_FAILED)) { + + xfrm_audit_state_icvfail(x, skb, + x->type->proto); + x->stats.integrity_failed++; + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEPROTOERROR); + goto drop; + } + + if (xo->status & CRYPTO_INVALID_PROTOCOL) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEPROTOERROR); + goto drop; + } + + XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR); + goto drop; + } + + if (xfrm_parse_spi(skb, nexthdr, &spi, &seq)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); + goto drop; + } + } + + goto lock; + } + + family = XFRM_SPI_SKB_CB(skb)->family; + + /* if tunnel is present override skb->mark value with tunnel i_key */ + switch (family) { + case AF_INET: + if (XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4) + mark = be32_to_cpu(XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4->parms.i_key); + break; + case AF_INET6: + if (XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6) + mark = be32_to_cpu(XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6->parms.i_key); + break; + } + + sp = secpath_set(skb); + if (!sp) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINERROR); + goto drop; + } + + seq = 0; + if (!spi && xfrm_parse_spi(skb, nexthdr, &spi, &seq)) { + secpath_reset(skb); + XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); + goto drop; + } + + daddr = (xfrm_address_t *)(skb_network_header(skb) + + XFRM_SPI_SKB_CB(skb)->daddroff); + do { + sp = skb_sec_path(skb); + + if (sp->len == XFRM_MAX_DEPTH) { + secpath_reset(skb); + XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR); + goto drop; + } + + x = xfrm_state_lookup(net, mark, daddr, spi, nexthdr, family); + if (x == NULL) { + secpath_reset(skb); + XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOSTATES); + xfrm_audit_state_notfound(skb, family, spi, seq); + goto drop; + } + + skb->mark = xfrm_smark_get(skb->mark, x); + + sp->xvec[sp->len++] = x; + + skb_dst_force(skb); + if (!skb_dst(skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINERROR); + goto drop; + } + +lock: + spin_lock(&x->lock); + + if (unlikely(x->km.state != XFRM_STATE_VALID)) { + if (x->km.state == XFRM_STATE_ACQ) + XFRM_INC_STATS(net, LINUX_MIB_XFRMACQUIREERROR); + else + XFRM_INC_STATS(net, + LINUX_MIB_XFRMINSTATEINVALID); + goto drop_unlock; + } + + if ((x->encap ? 
x->encap->encap_type : 0) != encap_type) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMISMATCH); + goto drop_unlock; + } + + if (xfrm_replay_check(x, skb, seq)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATESEQERROR); + goto drop_unlock; + } + + if (xfrm_state_check_expire(x)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEEXPIRED); + goto drop_unlock; + } + + spin_unlock(&x->lock); + + if (xfrm_tunnel_check(skb, x, family)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMODEERROR); + goto drop; + } + + seq_hi = htonl(xfrm_replay_seqhi(x, seq)); + + XFRM_SKB_CB(skb)->seq.input.low = seq; + XFRM_SKB_CB(skb)->seq.input.hi = seq_hi; + + dev_hold(skb->dev); + + if (crypto_done) + nexthdr = x->type_offload->input_tail(x, skb); + else + nexthdr = x->type->input(x, skb); + + if (nexthdr == -EINPROGRESS) + return 0; +resume: + dev_put(skb->dev); + + spin_lock(&x->lock); + if (nexthdr < 0) { + if (nexthdr == -EBADMSG) { + xfrm_audit_state_icvfail(x, skb, + x->type->proto); + x->stats.integrity_failed++; + } + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEPROTOERROR); + goto drop_unlock; + } + + /* only the first xfrm gets the encap type */ + encap_type = 0; + + if (xfrm_replay_recheck(x, skb, seq)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATESEQERROR); + goto drop_unlock; + } + + xfrm_replay_advance(x, seq); + + x->curlft.bytes += skb->len; + x->curlft.packets++; + + spin_unlock(&x->lock); + + XFRM_MODE_SKB_CB(skb)->protocol = nexthdr; + + inner_mode = &x->inner_mode; + + if (x->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol); + if (inner_mode == NULL) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMODEERROR); + goto drop; + } + } + + if (xfrm_inner_mode_input(x, inner_mode, skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMODEERROR); + goto drop; + } + + if (x->outer_mode.flags & XFRM_MODE_FLAG_TUNNEL) { + decaps = 1; + break; + } + + /* + * We need the inner address. However, we only get here for + * transport mode so the outer address is identical. + */ + daddr = &x->id.daddr; + family = x->outer_mode.family; + + err = xfrm_parse_spi(skb, nexthdr, &spi, &seq); + if (err < 0) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); + goto drop; + } + crypto_done = false; + } while (!err); + + err = xfrm_rcv_cb(skb, family, x->type->proto, 0); + if (err) + goto drop; + + nf_reset_ct(skb); + + if (decaps) { + sp = skb_sec_path(skb); + if (sp) + sp->olen = 0; + if (skb_valid_dst(skb)) + skb_dst_drop(skb); + gro_cells_receive(&gro_cells, skb); + return 0; + } else { + xo = xfrm_offload(skb); + if (xo) + xfrm_gro = xo->flags & XFRM_GRO; + + err = -EAFNOSUPPORT; + rcu_read_lock(); + afinfo = xfrm_state_afinfo_get_rcu(x->inner_mode.family); + if (likely(afinfo)) + err = afinfo->transport_finish(skb, xfrm_gro || async); + rcu_read_unlock(); + if (xfrm_gro) { + sp = skb_sec_path(skb); + if (sp) + sp->olen = 0; + if (skb_valid_dst(skb)) + skb_dst_drop(skb); + gro_cells_receive(&gro_cells, skb); + return err; + } + + return err; + } + +drop_unlock: + spin_unlock(&x->lock); +drop: + xfrm_rcv_cb(skb, family, x && x->type ? 
x->type->proto : nexthdr, -1); + kfree_skb(skb); + return 0; +} +EXPORT_SYMBOL(xfrm_input); + +int xfrm_input_resume(struct sk_buff *skb, int nexthdr) +{ + return xfrm_input(skb, nexthdr, 0, -1); +} +EXPORT_SYMBOL(xfrm_input_resume); + +static void xfrm_trans_reinject(struct work_struct *work) +{ + struct xfrm_trans_tasklet *trans = container_of(work, struct xfrm_trans_tasklet, work); + struct sk_buff_head queue; + struct sk_buff *skb; + + __skb_queue_head_init(&queue); + spin_lock_bh(&trans->queue_lock); + skb_queue_splice_init(&trans->queue, &queue); + spin_unlock_bh(&trans->queue_lock); + + local_bh_disable(); + while ((skb = __skb_dequeue(&queue))) + XFRM_TRANS_SKB_CB(skb)->finish(XFRM_TRANS_SKB_CB(skb)->net, + NULL, skb); + local_bh_enable(); +} + +int xfrm_trans_queue_net(struct net *net, struct sk_buff *skb, + int (*finish)(struct net *, struct sock *, + struct sk_buff *)) +{ + struct xfrm_trans_tasklet *trans; + + trans = this_cpu_ptr(&xfrm_trans_tasklet); + + if (skb_queue_len(&trans->queue) >= READ_ONCE(netdev_max_backlog)) + return -ENOBUFS; + + BUILD_BUG_ON(sizeof(struct xfrm_trans_cb) > sizeof(skb->cb)); + + XFRM_TRANS_SKB_CB(skb)->finish = finish; + XFRM_TRANS_SKB_CB(skb)->net = net; + spin_lock_bh(&trans->queue_lock); + __skb_queue_tail(&trans->queue, skb); + spin_unlock_bh(&trans->queue_lock); + schedule_work(&trans->work); + return 0; +} +EXPORT_SYMBOL(xfrm_trans_queue_net); + +int xfrm_trans_queue(struct sk_buff *skb, + int (*finish)(struct net *, struct sock *, + struct sk_buff *)) +{ + return xfrm_trans_queue_net(dev_net(skb->dev), skb, finish); +} +EXPORT_SYMBOL(xfrm_trans_queue); + +void __init xfrm_input_init(void) +{ + int err; + int i; + + init_dummy_netdev(&xfrm_napi_dev); + err = gro_cells_init(&gro_cells, &xfrm_napi_dev); + if (err) + gro_cells.cells = NULL; + + for_each_possible_cpu(i) { + struct xfrm_trans_tasklet *trans; + + trans = &per_cpu(xfrm_trans_tasklet, i); + spin_lock_init(&trans->queue_lock); + __skb_queue_head_init(&trans->queue); + INIT_WORK(&trans->work, xfrm_trans_reinject); + } +} diff --git a/net/xfrm/xfrm_interface_core.c b/net/xfrm/xfrm_interface_core.c new file mode 100644 index 000000000..85501b77f --- /dev/null +++ b/net/xfrm/xfrm_interface_core.c @@ -0,0 +1,1242 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * XFRM virtual interface + * + * Copyright (C) 2018 secunet Security Networks AG + * + * Author: + * Steffen Klassert + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +static int xfrmi_dev_init(struct net_device *dev); +static void xfrmi_dev_setup(struct net_device *dev); +static struct rtnl_link_ops xfrmi_link_ops __read_mostly; +static unsigned int xfrmi_net_id __read_mostly; +static const struct net_device_ops xfrmi_netdev_ops; + +#define XFRMI_HASH_BITS 8 +#define XFRMI_HASH_SIZE BIT(XFRMI_HASH_BITS) + +struct xfrmi_net { + /* lists for storing interfaces in use */ + struct xfrm_if __rcu *xfrmi[XFRMI_HASH_SIZE]; + struct xfrm_if __rcu *collect_md_xfrmi; +}; + +static const struct nla_policy xfrm_lwt_policy[LWT_XFRM_MAX + 1] = { + [LWT_XFRM_IF_ID] = NLA_POLICY_MIN(NLA_U32, 1), + [LWT_XFRM_LINK] = NLA_POLICY_MIN(NLA_U32, 1), +}; + +static void xfrmi_destroy_state(struct lwtunnel_state *lwt) +{ +} + +static int xfrmi_build_state(struct 
net *net, struct nlattr *nla, + unsigned int family, const void *cfg, + struct lwtunnel_state **ts, + struct netlink_ext_ack *extack) +{ + struct nlattr *tb[LWT_XFRM_MAX + 1]; + struct lwtunnel_state *new_state; + struct xfrm_md_info *info; + int ret; + + ret = nla_parse_nested(tb, LWT_XFRM_MAX, nla, xfrm_lwt_policy, extack); + if (ret < 0) + return ret; + + if (!tb[LWT_XFRM_IF_ID]) { + NL_SET_ERR_MSG(extack, "if_id must be set"); + return -EINVAL; + } + + new_state = lwtunnel_state_alloc(sizeof(*info)); + if (!new_state) { + NL_SET_ERR_MSG(extack, "failed to create encap info"); + return -ENOMEM; + } + + new_state->type = LWTUNNEL_ENCAP_XFRM; + + info = lwt_xfrm_info(new_state); + + info->if_id = nla_get_u32(tb[LWT_XFRM_IF_ID]); + + if (tb[LWT_XFRM_LINK]) + info->link = nla_get_u32(tb[LWT_XFRM_LINK]); + + *ts = new_state; + return 0; +} + +static int xfrmi_fill_encap_info(struct sk_buff *skb, + struct lwtunnel_state *lwt) +{ + struct xfrm_md_info *info = lwt_xfrm_info(lwt); + + if (nla_put_u32(skb, LWT_XFRM_IF_ID, info->if_id) || + (info->link && nla_put_u32(skb, LWT_XFRM_LINK, info->link))) + return -EMSGSIZE; + + return 0; +} + +static int xfrmi_encap_nlsize(struct lwtunnel_state *lwtstate) +{ + return nla_total_size(sizeof(u32)) + /* LWT_XFRM_IF_ID */ + nla_total_size(sizeof(u32)); /* LWT_XFRM_LINK */ +} + +static int xfrmi_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b) +{ + struct xfrm_md_info *a_info = lwt_xfrm_info(a); + struct xfrm_md_info *b_info = lwt_xfrm_info(b); + + return memcmp(a_info, b_info, sizeof(*a_info)); +} + +static const struct lwtunnel_encap_ops xfrmi_encap_ops = { + .build_state = xfrmi_build_state, + .destroy_state = xfrmi_destroy_state, + .fill_encap = xfrmi_fill_encap_info, + .get_encap_size = xfrmi_encap_nlsize, + .cmp_encap = xfrmi_encap_cmp, + .owner = THIS_MODULE, +}; + +#define for_each_xfrmi_rcu(start, xi) \ + for (xi = rcu_dereference(start); xi; xi = rcu_dereference(xi->next)) + +static u32 xfrmi_hash(u32 if_id) +{ + return hash_32(if_id, XFRMI_HASH_BITS); +} + +static struct xfrm_if *xfrmi_lookup(struct net *net, struct xfrm_state *x) +{ + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + struct xfrm_if *xi; + + for_each_xfrmi_rcu(xfrmn->xfrmi[xfrmi_hash(x->if_id)], xi) { + if (x->if_id == xi->p.if_id && + (xi->dev->flags & IFF_UP)) + return xi; + } + + xi = rcu_dereference(xfrmn->collect_md_xfrmi); + if (xi && (xi->dev->flags & IFF_UP)) + return xi; + + return NULL; +} + +static bool xfrmi_decode_session(struct sk_buff *skb, + unsigned short family, + struct xfrm_if_decode_session_result *res) +{ + struct net_device *dev; + struct xfrm_if *xi; + int ifindex = 0; + + if (!secpath_exists(skb) || !skb->dev) + return false; + + switch (family) { + case AF_INET6: + ifindex = inet6_sdif(skb); + break; + case AF_INET: + ifindex = inet_sdif(skb); + break; + } + + if (ifindex) { + struct net *net = xs_net(xfrm_input_state(skb)); + + dev = dev_get_by_index_rcu(net, ifindex); + } else { + dev = skb->dev; + } + + if (!dev || !(dev->flags & IFF_UP)) + return false; + if (dev->netdev_ops != &xfrmi_netdev_ops) + return false; + + xi = netdev_priv(dev); + res->net = xi->net; + + if (xi->p.collect_md) + res->if_id = xfrm_input_state(skb)->if_id; + else + res->if_id = xi->p.if_id; + return true; +} + +static void xfrmi_link(struct xfrmi_net *xfrmn, struct xfrm_if *xi) +{ + struct xfrm_if __rcu **xip = &xfrmn->xfrmi[xfrmi_hash(xi->p.if_id)]; + + rcu_assign_pointer(xi->next , rtnl_dereference(*xip)); + rcu_assign_pointer(*xip, xi); +} + +static 
void xfrmi_unlink(struct xfrmi_net *xfrmn, struct xfrm_if *xi) +{ + struct xfrm_if __rcu **xip; + struct xfrm_if *iter; + + for (xip = &xfrmn->xfrmi[xfrmi_hash(xi->p.if_id)]; + (iter = rtnl_dereference(*xip)) != NULL; + xip = &iter->next) { + if (xi == iter) { + rcu_assign_pointer(*xip, xi->next); + break; + } + } +} + +static void xfrmi_dev_free(struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + + gro_cells_destroy(&xi->gro_cells); + free_percpu(dev->tstats); +} + +static int xfrmi_create(struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct net *net = dev_net(dev); + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + int err; + + dev->rtnl_link_ops = &xfrmi_link_ops; + err = register_netdevice(dev); + if (err < 0) + goto out; + + if (xi->p.collect_md) + rcu_assign_pointer(xfrmn->collect_md_xfrmi, xi); + else + xfrmi_link(xfrmn, xi); + + return 0; + +out: + return err; +} + +static struct xfrm_if *xfrmi_locate(struct net *net, struct xfrm_if_parms *p) +{ + struct xfrm_if __rcu **xip; + struct xfrm_if *xi; + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + + for (xip = &xfrmn->xfrmi[xfrmi_hash(p->if_id)]; + (xi = rtnl_dereference(*xip)) != NULL; + xip = &xi->next) + if (xi->p.if_id == p->if_id) + return xi; + + return NULL; +} + +static void xfrmi_dev_uninit(struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct xfrmi_net *xfrmn = net_generic(xi->net, xfrmi_net_id); + + if (xi->p.collect_md) + RCU_INIT_POINTER(xfrmn->collect_md_xfrmi, NULL); + else + xfrmi_unlink(xfrmn, xi); +} + +static void xfrmi_scrub_packet(struct sk_buff *skb, bool xnet) +{ + skb_clear_tstamp(skb); + skb->pkt_type = PACKET_HOST; + skb->skb_iif = 0; + skb->ignore_df = 0; + skb_dst_drop(skb); + nf_reset_ct(skb); + nf_reset_trace(skb); + + if (!xnet) + return; + + ipvs_reset(skb); + secpath_reset(skb); + skb_orphan(skb); + skb->mark = 0; +} + +static int xfrmi_input(struct sk_buff *skb, int nexthdr, __be32 spi, + int encap_type, unsigned short family) +{ + struct sec_path *sp; + + sp = skb_sec_path(skb); + if (sp && (sp->len || sp->olen) && + !xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, family)) + goto discard; + + XFRM_SPI_SKB_CB(skb)->family = family; + if (family == AF_INET) { + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct iphdr, daddr); + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip4 = NULL; + } else { + XFRM_SPI_SKB_CB(skb)->daddroff = offsetof(struct ipv6hdr, daddr); + XFRM_TUNNEL_SKB_CB(skb)->tunnel.ip6 = NULL; + } + + return xfrm_input(skb, nexthdr, spi, encap_type); +discard: + kfree_skb(skb); + return 0; +} + +static int xfrmi4_rcv(struct sk_buff *skb) +{ + return xfrmi_input(skb, ip_hdr(skb)->protocol, 0, 0, AF_INET); +} + +static int xfrmi6_rcv(struct sk_buff *skb) +{ + return xfrmi_input(skb, skb_network_header(skb)[IP6CB(skb)->nhoff], + 0, 0, AF_INET6); +} + +static int xfrmi4_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type) +{ + return xfrmi_input(skb, nexthdr, spi, encap_type, AF_INET); +} + +static int xfrmi6_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type) +{ + return xfrmi_input(skb, nexthdr, spi, encap_type, AF_INET6); +} + +static int xfrmi_rcv_cb(struct sk_buff *skb, int err) +{ + const struct xfrm_mode *inner_mode; + struct net_device *dev; + struct xfrm_state *x; + struct xfrm_if *xi; + bool xnet; + int link; + + if (err && !secpath_exists(skb)) + return 0; + + x = xfrm_input_state(skb); + + xi = xfrmi_lookup(xs_net(x), x); + if (!xi) + return 1; + + link = skb->dev->ifindex; + dev 
= xi->dev; + skb->dev = dev; + + if (err) { + DEV_STATS_INC(dev, rx_errors); + DEV_STATS_INC(dev, rx_dropped); + + return 0; + } + + xnet = !net_eq(xi->net, dev_net(skb->dev)); + + if (xnet) { + inner_mode = &x->inner_mode; + + if (x->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(x, XFRM_MODE_SKB_CB(skb)->protocol); + if (inner_mode == NULL) { + XFRM_INC_STATS(dev_net(skb->dev), + LINUX_MIB_XFRMINSTATEMODEERROR); + return -EINVAL; + } + } + + if (!xfrm_policy_check(NULL, XFRM_POLICY_IN, skb, + inner_mode->family)) + return -EPERM; + } + + xfrmi_scrub_packet(skb, xnet); + if (xi->p.collect_md) { + struct metadata_dst *md_dst; + + md_dst = metadata_dst_alloc(0, METADATA_XFRM, GFP_ATOMIC); + if (!md_dst) + return -ENOMEM; + + md_dst->u.xfrm_info.if_id = x->if_id; + md_dst->u.xfrm_info.link = link; + skb_dst_set(skb, (struct dst_entry *)md_dst); + } + dev_sw_netstats_rx_add(dev, skb->len); + + return 0; +} + +static int +xfrmi_xmit2(struct sk_buff *skb, struct net_device *dev, struct flowi *fl) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct dst_entry *dst = skb_dst(skb); + unsigned int length = skb->len; + struct net_device *tdev; + struct xfrm_state *x; + int err = -1; + u32 if_id; + int mtu; + + if (xi->p.collect_md) { + struct xfrm_md_info *md_info = skb_xfrm_md_info(skb); + + if (unlikely(!md_info)) + return -EINVAL; + + if_id = md_info->if_id; + fl->flowi_oif = md_info->link; + } else { + if_id = xi->p.if_id; + } + + dst_hold(dst); + dst = xfrm_lookup_with_ifid(xi->net, dst, fl, NULL, 0, if_id); + if (IS_ERR(dst)) { + err = PTR_ERR(dst); + dst = NULL; + goto tx_err_link_failure; + } + + x = dst->xfrm; + if (!x) + goto tx_err_link_failure; + + if (x->if_id != if_id) + goto tx_err_link_failure; + + tdev = dst->dev; + + if (tdev == dev) { + DEV_STATS_INC(dev, collisions); + net_warn_ratelimited("%s: Local routing loop detected!\n", + dev->name); + goto tx_err_dst_release; + } + + mtu = dst_mtu(dst); + if ((!skb_is_gso(skb) && skb->len > mtu) || + (skb_is_gso(skb) && !skb_gso_validate_network_len(skb, mtu))) { + skb_dst_update_pmtu_no_confirm(skb, mtu); + + if (skb->protocol == htons(ETH_P_IPV6)) { + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + + if (skb->len > 1280) + icmpv6_ndo_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + else + goto xmit; + } else { + if (!(ip_hdr(skb)->frag_off & htons(IP_DF))) + goto xmit; + icmp_ndo_send(skb, ICMP_DEST_UNREACH, ICMP_FRAG_NEEDED, + htonl(mtu)); + } + + dst_release(dst); + return -EMSGSIZE; + } + +xmit: + xfrmi_scrub_packet(skb, !net_eq(xi->net, dev_net(dev))); + skb_dst_set(skb, dst); + skb->dev = tdev; + + err = dst_output(xi->net, skb->sk, skb); + if (net_xmit_eval(err) == 0) { + dev_sw_netstats_tx_add(dev, 1, length); + } else { + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_aborted_errors); + } + + return 0; +tx_err_link_failure: + DEV_STATS_INC(dev, tx_carrier_errors); + dst_link_failure(skb); +tx_err_dst_release: + dst_release(dst); + return err; +} + +static netdev_tx_t xfrmi_xmit(struct sk_buff *skb, struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct dst_entry *dst = skb_dst(skb); + struct flowi fl; + int ret; + + memset(&fl, 0, sizeof(fl)); + + switch (skb->protocol) { + case htons(ETH_P_IPV6): + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + xfrm_decode_session(skb, &fl, AF_INET6); + if (!dst) { + fl.u.ip6.flowi6_oif = dev->ifindex; + fl.u.ip6.flowi6_flags |= FLOWI_FLAG_ANYSRC; + dst = ip6_route_output(dev_net(dev), NULL, &fl.u.ip6); + if (dst->error) { + dst_release(dst); + 
DEV_STATS_INC(dev, tx_carrier_errors); + goto tx_err; + } + skb_dst_set(skb, dst); + } + break; + case htons(ETH_P_IP): + memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + xfrm_decode_session(skb, &fl, AF_INET); + if (!dst) { + struct rtable *rt; + + fl.u.ip4.flowi4_oif = dev->ifindex; + fl.u.ip4.flowi4_flags |= FLOWI_FLAG_ANYSRC; + rt = __ip_route_output_key(dev_net(dev), &fl.u.ip4); + if (IS_ERR(rt)) { + DEV_STATS_INC(dev, tx_carrier_errors); + goto tx_err; + } + skb_dst_set(skb, &rt->dst); + } + break; + default: + goto tx_err; + } + + fl.flowi_oif = xi->p.link; + + ret = xfrmi_xmit2(skb, dev, &fl); + if (ret < 0) + goto tx_err; + + return NETDEV_TX_OK; + +tx_err: + DEV_STATS_INC(dev, tx_errors); + DEV_STATS_INC(dev, tx_dropped); + kfree_skb(skb); + return NETDEV_TX_OK; +} + +static int xfrmi4_err(struct sk_buff *skb, u32 info) +{ + const struct iphdr *iph = (const struct iphdr *)skb->data; + struct net *net = dev_net(skb->dev); + int protocol = iph->protocol; + struct ip_comp_hdr *ipch; + struct ip_esp_hdr *esph; + struct ip_auth_hdr *ah ; + struct xfrm_state *x; + struct xfrm_if *xi; + __be32 spi; + + switch (protocol) { + case IPPROTO_ESP: + esph = (struct ip_esp_hdr *)(skb->data+(iph->ihl<<2)); + spi = esph->spi; + break; + case IPPROTO_AH: + ah = (struct ip_auth_hdr *)(skb->data+(iph->ihl<<2)); + spi = ah->spi; + break; + case IPPROTO_COMP: + ipch = (struct ip_comp_hdr *)(skb->data+(iph->ihl<<2)); + spi = htonl(ntohs(ipch->cpi)); + break; + default: + return 0; + } + + switch (icmp_hdr(skb)->type) { + case ICMP_DEST_UNREACH: + if (icmp_hdr(skb)->code != ICMP_FRAG_NEEDED) + return 0; + break; + case ICMP_REDIRECT: + break; + default: + return 0; + } + + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + spi, protocol, AF_INET); + if (!x) + return 0; + + xi = xfrmi_lookup(net, x); + if (!xi) { + xfrm_state_put(x); + return -1; + } + + if (icmp_hdr(skb)->type == ICMP_DEST_UNREACH) + ipv4_update_pmtu(skb, net, info, 0, protocol); + else + ipv4_redirect(skb, net, 0, protocol); + xfrm_state_put(x); + + return 0; +} + +static int xfrmi6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, + u8 type, u8 code, int offset, __be32 info) +{ + const struct ipv6hdr *iph = (const struct ipv6hdr *)skb->data; + struct net *net = dev_net(skb->dev); + int protocol = iph->nexthdr; + struct ip_comp_hdr *ipch; + struct ip_esp_hdr *esph; + struct ip_auth_hdr *ah; + struct xfrm_state *x; + struct xfrm_if *xi; + __be32 spi; + + switch (protocol) { + case IPPROTO_ESP: + esph = (struct ip_esp_hdr *)(skb->data + offset); + spi = esph->spi; + break; + case IPPROTO_AH: + ah = (struct ip_auth_hdr *)(skb->data + offset); + spi = ah->spi; + break; + case IPPROTO_COMP: + ipch = (struct ip_comp_hdr *)(skb->data + offset); + spi = htonl(ntohs(ipch->cpi)); + break; + default: + return 0; + } + + if (type != ICMPV6_PKT_TOOBIG && + type != NDISC_REDIRECT) + return 0; + + x = xfrm_state_lookup(net, skb->mark, (const xfrm_address_t *)&iph->daddr, + spi, protocol, AF_INET6); + if (!x) + return 0; + + xi = xfrmi_lookup(net, x); + if (!xi) { + xfrm_state_put(x); + return -1; + } + + if (type == NDISC_REDIRECT) + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); + else + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); + xfrm_state_put(x); + + return 0; +} + +static int xfrmi_change(struct xfrm_if *xi, const struct xfrm_if_parms *p) +{ + if (xi->p.link != p->link) + return -EINVAL; + + xi->p.if_id = p->if_id; + + return 0; +} + +static int xfrmi_update(struct xfrm_if 
*xi, struct xfrm_if_parms *p) +{ + struct net *net = xi->net; + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + int err; + + xfrmi_unlink(xfrmn, xi); + synchronize_net(); + err = xfrmi_change(xi, p); + xfrmi_link(xfrmn, xi); + netdev_state_change(xi->dev); + return err; +} + +static int xfrmi_get_iflink(const struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + + return xi->p.link; +} + +static const struct net_device_ops xfrmi_netdev_ops = { + .ndo_init = xfrmi_dev_init, + .ndo_uninit = xfrmi_dev_uninit, + .ndo_start_xmit = xfrmi_xmit, + .ndo_get_stats64 = dev_get_tstats64, + .ndo_get_iflink = xfrmi_get_iflink, +}; + +static void xfrmi_dev_setup(struct net_device *dev) +{ + dev->netdev_ops = &xfrmi_netdev_ops; + dev->header_ops = &ip_tunnel_header_ops; + dev->type = ARPHRD_NONE; + dev->mtu = ETH_DATA_LEN; + dev->min_mtu = ETH_MIN_MTU; + dev->max_mtu = IP_MAX_MTU; + dev->flags = IFF_NOARP; + dev->needs_free_netdev = true; + dev->priv_destructor = xfrmi_dev_free; + netif_keep_dst(dev); + + eth_broadcast_addr(dev->broadcast); +} + +#define XFRMI_FEATURES (NETIF_F_SG | \ + NETIF_F_FRAGLIST | \ + NETIF_F_GSO_SOFTWARE | \ + NETIF_F_HW_CSUM) + +static int xfrmi_dev_init(struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct net_device *phydev = __dev_get_by_index(xi->net, xi->p.link); + int err; + + dev->tstats = netdev_alloc_pcpu_stats(struct pcpu_sw_netstats); + if (!dev->tstats) + return -ENOMEM; + + err = gro_cells_init(&xi->gro_cells, dev); + if (err) { + free_percpu(dev->tstats); + return err; + } + + dev->features |= NETIF_F_LLTX; + dev->features |= XFRMI_FEATURES; + dev->hw_features |= XFRMI_FEATURES; + + if (phydev) { + dev->needed_headroom = phydev->needed_headroom; + dev->needed_tailroom = phydev->needed_tailroom; + + if (is_zero_ether_addr(dev->dev_addr)) + eth_hw_addr_inherit(dev, phydev); + if (is_zero_ether_addr(dev->broadcast)) + memcpy(dev->broadcast, phydev->broadcast, + dev->addr_len); + } else { + eth_hw_addr_random(dev); + eth_broadcast_addr(dev->broadcast); + } + + return 0; +} + +static int xfrmi_validate(struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + return 0; +} + +static void xfrmi_netlink_parms(struct nlattr *data[], + struct xfrm_if_parms *parms) +{ + memset(parms, 0, sizeof(*parms)); + + if (!data) + return; + + if (data[IFLA_XFRM_LINK]) + parms->link = nla_get_u32(data[IFLA_XFRM_LINK]); + + if (data[IFLA_XFRM_IF_ID]) + parms->if_id = nla_get_u32(data[IFLA_XFRM_IF_ID]); + + if (data[IFLA_XFRM_COLLECT_METADATA]) + parms->collect_md = true; +} + +static int xfrmi_newlink(struct net *src_net, struct net_device *dev, + struct nlattr *tb[], struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct net *net = dev_net(dev); + struct xfrm_if_parms p = {}; + struct xfrm_if *xi; + int err; + + xfrmi_netlink_parms(data, &p); + if (p.collect_md) { + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + + if (p.link || p.if_id) { + NL_SET_ERR_MSG(extack, "link and if_id must be zero"); + return -EINVAL; + } + + if (rtnl_dereference(xfrmn->collect_md_xfrmi)) + return -EEXIST; + + } else { + if (!p.if_id) { + NL_SET_ERR_MSG(extack, "if_id must be non zero"); + return -EINVAL; + } + + xi = xfrmi_locate(net, &p); + if (xi) + return -EEXIST; + } + + xi = netdev_priv(dev); + xi->p = p; + xi->net = net; + xi->dev = dev; + + err = xfrmi_create(dev); + return err; +} + +static void xfrmi_dellink(struct net_device *dev, struct list_head *head) +{ + unregister_netdevice_queue(dev, head); 
+} + +static int xfrmi_changelink(struct net_device *dev, struct nlattr *tb[], + struct nlattr *data[], + struct netlink_ext_ack *extack) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct net *net = xi->net; + struct xfrm_if_parms p = {}; + + xfrmi_netlink_parms(data, &p); + if (!p.if_id) { + NL_SET_ERR_MSG(extack, "if_id must be non zero"); + return -EINVAL; + } + + if (p.collect_md) { + NL_SET_ERR_MSG(extack, "collect_md can't be changed"); + return -EINVAL; + } + + xi = xfrmi_locate(net, &p); + if (!xi) { + xi = netdev_priv(dev); + } else { + if (xi->dev != dev) + return -EEXIST; + if (xi->p.collect_md) { + NL_SET_ERR_MSG(extack, + "device can't be changed to collect_md"); + return -EINVAL; + } + } + + return xfrmi_update(xi, &p); +} + +static size_t xfrmi_get_size(const struct net_device *dev) +{ + return + /* IFLA_XFRM_LINK */ + nla_total_size(4) + + /* IFLA_XFRM_IF_ID */ + nla_total_size(4) + + /* IFLA_XFRM_COLLECT_METADATA */ + nla_total_size(0) + + 0; +} + +static int xfrmi_fill_info(struct sk_buff *skb, const struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + struct xfrm_if_parms *parm = &xi->p; + + if (nla_put_u32(skb, IFLA_XFRM_LINK, parm->link) || + nla_put_u32(skb, IFLA_XFRM_IF_ID, parm->if_id) || + (xi->p.collect_md && nla_put_flag(skb, IFLA_XFRM_COLLECT_METADATA))) + goto nla_put_failure; + return 0; + +nla_put_failure: + return -EMSGSIZE; +} + +static struct net *xfrmi_get_link_net(const struct net_device *dev) +{ + struct xfrm_if *xi = netdev_priv(dev); + + return xi->net; +} + +static const struct nla_policy xfrmi_policy[IFLA_XFRM_MAX + 1] = { + [IFLA_XFRM_UNSPEC] = { .strict_start_type = IFLA_XFRM_COLLECT_METADATA }, + [IFLA_XFRM_LINK] = { .type = NLA_U32 }, + [IFLA_XFRM_IF_ID] = { .type = NLA_U32 }, + [IFLA_XFRM_COLLECT_METADATA] = { .type = NLA_FLAG }, +}; + +static struct rtnl_link_ops xfrmi_link_ops __read_mostly = { + .kind = "xfrm", + .maxtype = IFLA_XFRM_MAX, + .policy = xfrmi_policy, + .priv_size = sizeof(struct xfrm_if), + .setup = xfrmi_dev_setup, + .validate = xfrmi_validate, + .newlink = xfrmi_newlink, + .dellink = xfrmi_dellink, + .changelink = xfrmi_changelink, + .get_size = xfrmi_get_size, + .fill_info = xfrmi_fill_info, + .get_link_net = xfrmi_get_link_net, +}; + +static void __net_exit xfrmi_exit_batch_net(struct list_head *net_exit_list) +{ + struct net *net; + LIST_HEAD(list); + + rtnl_lock(); + list_for_each_entry(net, net_exit_list, exit_list) { + struct xfrmi_net *xfrmn = net_generic(net, xfrmi_net_id); + struct xfrm_if __rcu **xip; + struct xfrm_if *xi; + int i; + + for (i = 0; i < XFRMI_HASH_SIZE; i++) { + for (xip = &xfrmn->xfrmi[i]; + (xi = rtnl_dereference(*xip)) != NULL; + xip = &xi->next) + unregister_netdevice_queue(xi->dev, &list); + } + xi = rtnl_dereference(xfrmn->collect_md_xfrmi); + if (xi) + unregister_netdevice_queue(xi->dev, &list); + } + unregister_netdevice_many(&list); + rtnl_unlock(); +} + +static struct pernet_operations xfrmi_net_ops = { + .exit_batch = xfrmi_exit_batch_net, + .id = &xfrmi_net_id, + .size = sizeof(struct xfrmi_net), +}; + +static struct xfrm6_protocol xfrmi_esp6_protocol __read_mostly = { + .handler = xfrmi6_rcv, + .input_handler = xfrmi6_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi6_err, + .priority = 10, +}; + +static struct xfrm6_protocol xfrmi_ah6_protocol __read_mostly = { + .handler = xfrm6_rcv, + .input_handler = xfrm_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi6_err, + .priority = 10, +}; + +static struct xfrm6_protocol xfrmi_ipcomp6_protocol 
__read_mostly = { + .handler = xfrm6_rcv, + .input_handler = xfrm_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi6_err, + .priority = 10, +}; + +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) +static int xfrmi6_rcv_tunnel(struct sk_buff *skb) +{ + const xfrm_address_t *saddr; + __be32 spi; + + saddr = (const xfrm_address_t *)&ipv6_hdr(skb)->saddr; + spi = xfrm6_tunnel_spi_lookup(dev_net(skb->dev), saddr); + + return xfrm6_rcv_spi(skb, IPPROTO_IPV6, spi, NULL); +} + +static struct xfrm6_tunnel xfrmi_ipv6_handler __read_mostly = { + .handler = xfrmi6_rcv_tunnel, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi6_err, + .priority = 2, +}; + +static struct xfrm6_tunnel xfrmi_ip6ip_handler __read_mostly = { + .handler = xfrmi6_rcv_tunnel, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi6_err, + .priority = 2, +}; +#endif + +static struct xfrm4_protocol xfrmi_esp4_protocol __read_mostly = { + .handler = xfrmi4_rcv, + .input_handler = xfrmi4_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi4_err, + .priority = 10, +}; + +static struct xfrm4_protocol xfrmi_ah4_protocol __read_mostly = { + .handler = xfrm4_rcv, + .input_handler = xfrm_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi4_err, + .priority = 10, +}; + +static struct xfrm4_protocol xfrmi_ipcomp4_protocol __read_mostly = { + .handler = xfrm4_rcv, + .input_handler = xfrm_input, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi4_err, + .priority = 10, +}; + +#if IS_REACHABLE(CONFIG_INET_XFRM_TUNNEL) +static int xfrmi4_rcv_tunnel(struct sk_buff *skb) +{ + return xfrm4_rcv_spi(skb, IPPROTO_IPIP, ip_hdr(skb)->saddr); +} + +static struct xfrm_tunnel xfrmi_ipip_handler __read_mostly = { + .handler = xfrmi4_rcv_tunnel, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi4_err, + .priority = 3, +}; + +static struct xfrm_tunnel xfrmi_ipip6_handler __read_mostly = { + .handler = xfrmi4_rcv_tunnel, + .cb_handler = xfrmi_rcv_cb, + .err_handler = xfrmi4_err, + .priority = 2, +}; +#endif + +static int __init xfrmi4_init(void) +{ + int err; + + err = xfrm4_protocol_register(&xfrmi_esp4_protocol, IPPROTO_ESP); + if (err < 0) + goto xfrm_proto_esp_failed; + err = xfrm4_protocol_register(&xfrmi_ah4_protocol, IPPROTO_AH); + if (err < 0) + goto xfrm_proto_ah_failed; + err = xfrm4_protocol_register(&xfrmi_ipcomp4_protocol, IPPROTO_COMP); + if (err < 0) + goto xfrm_proto_comp_failed; +#if IS_REACHABLE(CONFIG_INET_XFRM_TUNNEL) + err = xfrm4_tunnel_register(&xfrmi_ipip_handler, AF_INET); + if (err < 0) + goto xfrm_tunnel_ipip_failed; + err = xfrm4_tunnel_register(&xfrmi_ipip6_handler, AF_INET6); + if (err < 0) + goto xfrm_tunnel_ipip6_failed; +#endif + + return 0; + +#if IS_REACHABLE(CONFIG_INET_XFRM_TUNNEL) +xfrm_tunnel_ipip6_failed: + xfrm4_tunnel_deregister(&xfrmi_ipip_handler, AF_INET); +xfrm_tunnel_ipip_failed: + xfrm4_protocol_deregister(&xfrmi_ipcomp4_protocol, IPPROTO_COMP); +#endif +xfrm_proto_comp_failed: + xfrm4_protocol_deregister(&xfrmi_ah4_protocol, IPPROTO_AH); +xfrm_proto_ah_failed: + xfrm4_protocol_deregister(&xfrmi_esp4_protocol, IPPROTO_ESP); +xfrm_proto_esp_failed: + return err; +} + +static void xfrmi4_fini(void) +{ +#if IS_REACHABLE(CONFIG_INET_XFRM_TUNNEL) + xfrm4_tunnel_deregister(&xfrmi_ipip6_handler, AF_INET6); + xfrm4_tunnel_deregister(&xfrmi_ipip_handler, AF_INET); +#endif + xfrm4_protocol_deregister(&xfrmi_ipcomp4_protocol, IPPROTO_COMP); + xfrm4_protocol_deregister(&xfrmi_ah4_protocol, IPPROTO_AH); + xfrm4_protocol_deregister(&xfrmi_esp4_protocol, IPPROTO_ESP); +} + +static int __init 
xfrmi6_init(void) +{ + int err; + + err = xfrm6_protocol_register(&xfrmi_esp6_protocol, IPPROTO_ESP); + if (err < 0) + goto xfrm_proto_esp_failed; + err = xfrm6_protocol_register(&xfrmi_ah6_protocol, IPPROTO_AH); + if (err < 0) + goto xfrm_proto_ah_failed; + err = xfrm6_protocol_register(&xfrmi_ipcomp6_protocol, IPPROTO_COMP); + if (err < 0) + goto xfrm_proto_comp_failed; +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) + err = xfrm6_tunnel_register(&xfrmi_ipv6_handler, AF_INET6); + if (err < 0) + goto xfrm_tunnel_ipv6_failed; + err = xfrm6_tunnel_register(&xfrmi_ip6ip_handler, AF_INET); + if (err < 0) + goto xfrm_tunnel_ip6ip_failed; +#endif + + return 0; + +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) +xfrm_tunnel_ip6ip_failed: + xfrm6_tunnel_deregister(&xfrmi_ipv6_handler, AF_INET6); +xfrm_tunnel_ipv6_failed: + xfrm6_protocol_deregister(&xfrmi_ipcomp6_protocol, IPPROTO_COMP); +#endif +xfrm_proto_comp_failed: + xfrm6_protocol_deregister(&xfrmi_ah6_protocol, IPPROTO_AH); +xfrm_proto_ah_failed: + xfrm6_protocol_deregister(&xfrmi_esp6_protocol, IPPROTO_ESP); +xfrm_proto_esp_failed: + return err; +} + +static void xfrmi6_fini(void) +{ +#if IS_REACHABLE(CONFIG_INET6_XFRM_TUNNEL) + xfrm6_tunnel_deregister(&xfrmi_ip6ip_handler, AF_INET); + xfrm6_tunnel_deregister(&xfrmi_ipv6_handler, AF_INET6); +#endif + xfrm6_protocol_deregister(&xfrmi_ipcomp6_protocol, IPPROTO_COMP); + xfrm6_protocol_deregister(&xfrmi_ah6_protocol, IPPROTO_AH); + xfrm6_protocol_deregister(&xfrmi_esp6_protocol, IPPROTO_ESP); +} + +static const struct xfrm_if_cb xfrm_if_cb = { + .decode_session = xfrmi_decode_session, +}; + +static int __init xfrmi_init(void) +{ + const char *msg; + int err; + + pr_info("IPsec XFRM device driver\n"); + + msg = "tunnel device"; + err = register_pernet_device(&xfrmi_net_ops); + if (err < 0) + goto pernet_dev_failed; + + msg = "xfrm4 protocols"; + err = xfrmi4_init(); + if (err < 0) + goto xfrmi4_failed; + + msg = "xfrm6 protocols"; + err = xfrmi6_init(); + if (err < 0) + goto xfrmi6_failed; + + + msg = "netlink interface"; + err = rtnl_link_register(&xfrmi_link_ops); + if (err < 0) + goto rtnl_link_failed; + + lwtunnel_encap_add_ops(&xfrmi_encap_ops, LWTUNNEL_ENCAP_XFRM); + + xfrm_if_register_cb(&xfrm_if_cb); + + return err; + +rtnl_link_failed: + xfrmi6_fini(); +xfrmi6_failed: + xfrmi4_fini(); +xfrmi4_failed: + unregister_pernet_device(&xfrmi_net_ops); +pernet_dev_failed: + pr_err("xfrmi init: failed to register %s\n", msg); + return err; +} + +static void __exit xfrmi_fini(void) +{ + xfrm_if_unregister_cb(); + lwtunnel_encap_del_ops(&xfrmi_encap_ops, LWTUNNEL_ENCAP_XFRM); + rtnl_link_unregister(&xfrmi_link_ops); + xfrmi4_fini(); + xfrmi6_fini(); + unregister_pernet_device(&xfrmi_net_ops); +} + +module_init(xfrmi_init); +module_exit(xfrmi_fini); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_RTNL_LINK("xfrm"); +MODULE_ALIAS_NETDEV("xfrm0"); +MODULE_AUTHOR("Steffen Klassert"); +MODULE_DESCRIPTION("XFRM virtual interface"); diff --git a/net/xfrm/xfrm_ipcomp.c b/net/xfrm/xfrm_ipcomp.c new file mode 100644 index 000000000..80143360b --- /dev/null +++ b/net/xfrm/xfrm_ipcomp.c @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * IP Payload Compression Protocol (IPComp) - RFC3173. + * + * Copyright (c) 2003 James Morris + * Copyright (c) 2003-2008 Herbert Xu + * + * Todo: + * - Tunable compression parameters. + * - Compression stats. + * - Adaptive compression. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ipcomp_tfms { + struct list_head list; + struct crypto_comp * __percpu *tfms; + int users; +}; + +static DEFINE_MUTEX(ipcomp_resource_mutex); +static void * __percpu *ipcomp_scratches; +static int ipcomp_scratch_users; +static LIST_HEAD(ipcomp_tfms_list); + +static int ipcomp_decompress(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipcomp_data *ipcd = x->data; + const int plen = skb->len; + int dlen = IPCOMP_SCRATCH_SIZE; + const u8 *start = skb->data; + u8 *scratch = *this_cpu_ptr(ipcomp_scratches); + struct crypto_comp *tfm = *this_cpu_ptr(ipcd->tfms); + int err = crypto_comp_decompress(tfm, start, plen, scratch, &dlen); + int len; + + if (err) + return err; + + if (dlen < (plen + sizeof(struct ip_comp_hdr))) + return -EINVAL; + + len = dlen - plen; + if (len > skb_tailroom(skb)) + len = skb_tailroom(skb); + + __skb_put(skb, len); + + len += plen; + skb_copy_to_linear_data(skb, scratch, len); + + while ((scratch += len, dlen -= len) > 0) { + skb_frag_t *frag; + struct page *page; + + if (WARN_ON(skb_shinfo(skb)->nr_frags >= MAX_SKB_FRAGS)) + return -EMSGSIZE; + + frag = skb_shinfo(skb)->frags + skb_shinfo(skb)->nr_frags; + page = alloc_page(GFP_ATOMIC); + + if (!page) + return -ENOMEM; + + __skb_frag_set_page(frag, page); + + len = PAGE_SIZE; + if (dlen < len) + len = dlen; + + skb_frag_off_set(frag, 0); + skb_frag_size_set(frag, len); + memcpy(skb_frag_address(frag), scratch, len); + + skb->truesize += len; + skb->data_len += len; + skb->len += len; + + skb_shinfo(skb)->nr_frags++; + } + + return 0; +} + +int ipcomp_input(struct xfrm_state *x, struct sk_buff *skb) +{ + int nexthdr; + int err = -ENOMEM; + struct ip_comp_hdr *ipch; + + if (skb_linearize_cow(skb)) + goto out; + + skb->ip_summed = CHECKSUM_NONE; + + /* Remove ipcomp header and decompress original payload */ + ipch = (void *)skb->data; + nexthdr = ipch->nexthdr; + + skb->transport_header = skb->network_header + sizeof(*ipch); + __skb_pull(skb, sizeof(*ipch)); + err = ipcomp_decompress(x, skb); + if (err) + goto out; + + err = nexthdr; + +out: + return err; +} +EXPORT_SYMBOL_GPL(ipcomp_input); + +static int ipcomp_compress(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipcomp_data *ipcd = x->data; + const int plen = skb->len; + int dlen = IPCOMP_SCRATCH_SIZE; + u8 *start = skb->data; + struct crypto_comp *tfm; + u8 *scratch; + int err; + + local_bh_disable(); + scratch = *this_cpu_ptr(ipcomp_scratches); + tfm = *this_cpu_ptr(ipcd->tfms); + err = crypto_comp_compress(tfm, start, plen, scratch, &dlen); + if (err) + goto out; + + if ((dlen + sizeof(struct ip_comp_hdr)) >= plen) { + err = -EMSGSIZE; + goto out; + } + + memcpy(start + sizeof(struct ip_comp_hdr), scratch, dlen); + local_bh_enable(); + + pskb_trim(skb, dlen + sizeof(struct ip_comp_hdr)); + return 0; + +out: + local_bh_enable(); + return err; +} + +int ipcomp_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + struct ip_comp_hdr *ipch; + struct ipcomp_data *ipcd = x->data; + + if (skb->len < ipcd->threshold) { + /* Don't bother compressing */ + goto out_ok; + } + + if (skb_linearize_cow(skb)) + goto out_ok; + + err = ipcomp_compress(x, skb); + + if (err) { + goto out_ok; + } + + /* Install ipcomp header, convert into ipcomp datagram. 
*/ + ipch = ip_comp_hdr(skb); + ipch->nexthdr = *skb_mac_header(skb); + ipch->flags = 0; + ipch->cpi = htons((u16 )ntohl(x->id.spi)); + *skb_mac_header(skb) = IPPROTO_COMP; +out_ok: + skb_push(skb, -skb_network_offset(skb)); + return 0; +} +EXPORT_SYMBOL_GPL(ipcomp_output); + +static void ipcomp_free_scratches(void) +{ + int i; + void * __percpu *scratches; + + if (--ipcomp_scratch_users) + return; + + scratches = ipcomp_scratches; + if (!scratches) + return; + + for_each_possible_cpu(i) + vfree(*per_cpu_ptr(scratches, i)); + + free_percpu(scratches); + ipcomp_scratches = NULL; +} + +static void * __percpu *ipcomp_alloc_scratches(void) +{ + void * __percpu *scratches; + int i; + + if (ipcomp_scratch_users++) + return ipcomp_scratches; + + scratches = alloc_percpu(void *); + if (!scratches) + return NULL; + + ipcomp_scratches = scratches; + + for_each_possible_cpu(i) { + void *scratch; + + scratch = vmalloc_node(IPCOMP_SCRATCH_SIZE, cpu_to_node(i)); + if (!scratch) + return NULL; + *per_cpu_ptr(scratches, i) = scratch; + } + + return scratches; +} + +static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms) +{ + struct ipcomp_tfms *pos; + int cpu; + + list_for_each_entry(pos, &ipcomp_tfms_list, list) { + if (pos->tfms == tfms) + break; + } + + WARN_ON(list_entry_is_head(pos, &ipcomp_tfms_list, list)); + + if (--pos->users) + return; + + list_del(&pos->list); + kfree(pos); + + if (!tfms) + return; + + for_each_possible_cpu(cpu) { + struct crypto_comp *tfm = *per_cpu_ptr(tfms, cpu); + crypto_free_comp(tfm); + } + free_percpu(tfms); +} + +static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name) +{ + struct ipcomp_tfms *pos; + struct crypto_comp * __percpu *tfms; + int cpu; + + + list_for_each_entry(pos, &ipcomp_tfms_list, list) { + struct crypto_comp *tfm; + + /* This can be any valid CPU ID so we don't need locking. 
*/ + tfm = this_cpu_read(*pos->tfms); + + if (!strcmp(crypto_comp_name(tfm), alg_name)) { + pos->users++; + return pos->tfms; + } + } + + pos = kmalloc(sizeof(*pos), GFP_KERNEL); + if (!pos) + return NULL; + + pos->users = 1; + INIT_LIST_HEAD(&pos->list); + list_add(&pos->list, &ipcomp_tfms_list); + + pos->tfms = tfms = alloc_percpu(struct crypto_comp *); + if (!tfms) + goto error; + + for_each_possible_cpu(cpu) { + struct crypto_comp *tfm = crypto_alloc_comp(alg_name, 0, + CRYPTO_ALG_ASYNC); + if (IS_ERR(tfm)) + goto error; + *per_cpu_ptr(tfms, cpu) = tfm; + } + + return tfms; + +error: + ipcomp_free_tfms(tfms); + return NULL; +} + +static void ipcomp_free_data(struct ipcomp_data *ipcd) +{ + if (ipcd->tfms) + ipcomp_free_tfms(ipcd->tfms); + ipcomp_free_scratches(); +} + +void ipcomp_destroy(struct xfrm_state *x) +{ + struct ipcomp_data *ipcd = x->data; + if (!ipcd) + return; + xfrm_state_delete_tunnel(x); + mutex_lock(&ipcomp_resource_mutex); + ipcomp_free_data(ipcd); + mutex_unlock(&ipcomp_resource_mutex); + kfree(ipcd); +} +EXPORT_SYMBOL_GPL(ipcomp_destroy); + +int ipcomp_init_state(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + int err; + struct ipcomp_data *ipcd; + struct xfrm_algo_desc *calg_desc; + + err = -EINVAL; + if (!x->calg) { + NL_SET_ERR_MSG(extack, "Missing required compression algorithm"); + goto out; + } + + if (x->encap) { + NL_SET_ERR_MSG(extack, "IPComp is not compatible with encapsulation"); + goto out; + } + + err = -ENOMEM; + ipcd = kzalloc(sizeof(*ipcd), GFP_KERNEL); + if (!ipcd) + goto out; + + mutex_lock(&ipcomp_resource_mutex); + if (!ipcomp_alloc_scratches()) + goto error; + + ipcd->tfms = ipcomp_alloc_tfms(x->calg->alg_name); + if (!ipcd->tfms) + goto error; + mutex_unlock(&ipcomp_resource_mutex); + + calg_desc = xfrm_calg_get_byname(x->calg->alg_name, 0); + BUG_ON(!calg_desc); + ipcd->threshold = calg_desc->uinfo.comp.threshold; + x->data = ipcd; + err = 0; +out: + return err; + +error: + ipcomp_free_data(ipcd); + mutex_unlock(&ipcomp_resource_mutex); + kfree(ipcd); + goto out; +} +EXPORT_SYMBOL_GPL(ipcomp_init_state); + +MODULE_LICENSE("GPL"); +MODULE_DESCRIPTION("IP Payload Compression Protocol (IPComp) - RFC3173"); +MODULE_AUTHOR("James Morris "); diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c new file mode 100644 index 000000000..9a5e79a38 --- /dev/null +++ b/net/xfrm/xfrm_output.c @@ -0,0 +1,909 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm_output.c - Common IPsec encapsulation code. + * + * Copyright (c) 2007 Herbert Xu + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if IS_ENABLED(CONFIG_IPV6) +#include +#include +#endif + +#include "xfrm_inout.h" + +static int xfrm_output2(struct net *net, struct sock *sk, struct sk_buff *skb); +static int xfrm_inner_extract_output(struct xfrm_state *x, struct sk_buff *skb); + +static int xfrm_skb_check_space(struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + int nhead = dst->header_len + LL_RESERVED_SPACE(dst->dev) + - skb_headroom(skb); + int ntail = dst->dev->needed_tailroom - skb_tailroom(skb); + + if (nhead <= 0) { + if (ntail <= 0) + return 0; + nhead = 0; + } else if (ntail < 0) + ntail = 0; + + return pskb_expand_head(skb, nhead, ntail, GFP_ATOMIC); +} + +/* Children define the path of the packet through the + * Linux networking. Thus, destinations are stackable. 
+ */ + +static struct dst_entry *skb_dst_pop(struct sk_buff *skb) +{ + struct dst_entry *child = dst_clone(xfrm_dst_child(skb_dst(skb))); + + skb_dst_drop(skb); + return child; +} + +/* Add encapsulation header. + * + * The IP header will be moved forward to make space for the encapsulation + * header. + */ +static int xfrm4_transport_output(struct xfrm_state *x, struct sk_buff *skb) +{ + struct iphdr *iph = ip_hdr(skb); + int ihl = iph->ihl * 4; + + skb_set_inner_transport_header(skb, skb_transport_offset(skb)); + + skb_set_network_header(skb, -x->props.header_len); + skb->mac_header = skb->network_header + + offsetof(struct iphdr, protocol); + skb->transport_header = skb->network_header + ihl; + __skb_pull(skb, ihl); + memmove(skb_network_header(skb), iph, ihl); + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6_MIP6) +static int mip6_rthdr_offset(struct sk_buff *skb, u8 **nexthdr, int type) +{ + const unsigned char *nh = skb_network_header(skb); + unsigned int offset = sizeof(struct ipv6hdr); + unsigned int packet_len; + int found_rhdr = 0; + + packet_len = skb_tail_pointer(skb) - nh; + *nexthdr = &ipv6_hdr(skb)->nexthdr; + + while (offset <= packet_len) { + struct ipv6_opt_hdr *exthdr; + + switch (**nexthdr) { + case NEXTHDR_HOP: + break; + case NEXTHDR_ROUTING: + if (type == IPPROTO_ROUTING && offset + 3 <= packet_len) { + struct ipv6_rt_hdr *rt; + + rt = (struct ipv6_rt_hdr *)(nh + offset); + if (rt->type != 0) + return offset; + } + found_rhdr = 1; + break; + case NEXTHDR_DEST: + /* HAO MUST NOT appear more than once. + * XXX: It is better to try to find by the end of + * XXX: packet if HAO exists. + */ + if (ipv6_find_tlv(skb, offset, IPV6_TLV_HAO) >= 0) { + net_dbg_ratelimited("mip6: hao exists already, override\n"); + return offset; + } + + if (found_rhdr) + return offset; + + break; + default: + return offset; + } + + if (offset + sizeof(struct ipv6_opt_hdr) > packet_len) + return -EINVAL; + + exthdr = (struct ipv6_opt_hdr *)(skb_network_header(skb) + + offset); + offset += ipv6_optlen(exthdr); + if (offset > IPV6_MAXPLEN) + return -EINVAL; + *nexthdr = &exthdr->nexthdr; + } + + return -EINVAL; +} +#endif + +#if IS_ENABLED(CONFIG_IPV6) +static int xfrm6_hdr_offset(struct xfrm_state *x, struct sk_buff *skb, u8 **prevhdr) +{ + switch (x->type->proto) { +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + return mip6_rthdr_offset(skb, prevhdr, x->type->proto); +#endif + default: + break; + } + + return ip6_find_1stfragopt(skb, prevhdr); +} +#endif + +/* Add encapsulation header. + * + * The IP header and mutable extension headers will be moved forward to make + * space for the encapsulation header. + */ +static int xfrm6_transport_output(struct xfrm_state *x, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *iph; + u8 *prevhdr; + int hdr_len; + + iph = ipv6_hdr(skb); + skb_set_inner_transport_header(skb, skb_transport_offset(skb)); + + hdr_len = xfrm6_hdr_offset(x, skb, &prevhdr); + if (hdr_len < 0) + return hdr_len; + skb_set_mac_header(skb, + (prevhdr - x->props.header_len) - skb->data); + skb_set_network_header(skb, -x->props.header_len); + skb->transport_header = skb->network_header + hdr_len; + __skb_pull(skb, hdr_len); + memmove(ipv6_hdr(skb), iph, hdr_len); + return 0; +#else + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; +#endif +} + +/* Add route optimization header space. + * + * The IP header and mutable extension headers will be moved forward to make + * space for the route optimization header. 
+ */ +static int xfrm6_ro_output(struct xfrm_state *x, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + struct ipv6hdr *iph; + u8 *prevhdr; + int hdr_len; + + iph = ipv6_hdr(skb); + + hdr_len = xfrm6_hdr_offset(x, skb, &prevhdr); + if (hdr_len < 0) + return hdr_len; + skb_set_mac_header(skb, + (prevhdr - x->props.header_len) - skb->data); + skb_set_network_header(skb, -x->props.header_len); + skb->transport_header = skb->network_header + hdr_len; + __skb_pull(skb, hdr_len); + memmove(ipv6_hdr(skb), iph, hdr_len); + + x->lastused = ktime_get_real_seconds(); + + return 0; +#else + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; +#endif +} + +/* Add encapsulation header. + * + * The top IP header will be constructed per draft-nikander-esp-beet-mode-06.txt. + */ +static int xfrm4_beet_encap_add(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ip_beet_phdr *ph; + struct iphdr *top_iph; + int hdrlen, optlen; + + hdrlen = 0; + optlen = XFRM_MODE_SKB_CB(skb)->optlen; + if (unlikely(optlen)) + hdrlen += IPV4_BEET_PHMAXLEN - (optlen & 4); + + skb_set_network_header(skb, -x->props.header_len - hdrlen + + (XFRM_MODE_SKB_CB(skb)->ihl - sizeof(*top_iph))); + if (x->sel.family != AF_INET6) + skb->network_header += IPV4_BEET_PHMAXLEN; + skb->mac_header = skb->network_header + + offsetof(struct iphdr, protocol); + skb->transport_header = skb->network_header + sizeof(*top_iph); + + xfrm4_beet_make_header(skb); + + ph = __skb_pull(skb, XFRM_MODE_SKB_CB(skb)->ihl - hdrlen); + + top_iph = ip_hdr(skb); + + if (unlikely(optlen)) { + if (WARN_ON(optlen < 0)) + return -EINVAL; + + ph->padlen = 4 - (optlen & 4); + ph->hdrlen = optlen / 8; + ph->nexthdr = top_iph->protocol; + if (ph->padlen) + memset(ph + 1, IPOPT_NOP, ph->padlen); + + top_iph->protocol = IPPROTO_BEETPH; + top_iph->ihl = sizeof(struct iphdr) / 4; + } + + top_iph->saddr = x->props.saddr.a4; + top_iph->daddr = x->id.daddr.a4; + + return 0; +} + +/* Add encapsulation header. + * + * The top IP header will be constructed per RFC 2401. + */ +static int xfrm4_tunnel_encap_add(struct xfrm_state *x, struct sk_buff *skb) +{ + bool small_ipv6 = (skb->protocol == htons(ETH_P_IPV6)) && (skb->len <= IPV6_MIN_MTU); + struct dst_entry *dst = skb_dst(skb); + struct iphdr *top_iph; + int flags; + + skb_set_inner_network_header(skb, skb_network_offset(skb)); + skb_set_inner_transport_header(skb, skb_transport_offset(skb)); + + skb_set_network_header(skb, -x->props.header_len); + skb->mac_header = skb->network_header + + offsetof(struct iphdr, protocol); + skb->transport_header = skb->network_header + sizeof(*top_iph); + top_iph = ip_hdr(skb); + + top_iph->ihl = 5; + top_iph->version = 4; + + top_iph->protocol = xfrm_af2proto(skb_dst(skb)->ops->family); + + /* DS disclosing depends on XFRM_SA_XFLAG_DONT_ENCAP_DSCP */ + if (x->props.extra_flags & XFRM_SA_XFLAG_DONT_ENCAP_DSCP) + top_iph->tos = 0; + else + top_iph->tos = XFRM_MODE_SKB_CB(skb)->tos; + top_iph->tos = INET_ECN_encapsulate(top_iph->tos, + XFRM_MODE_SKB_CB(skb)->tos); + + flags = x->props.flags; + if (flags & XFRM_STATE_NOECN) + IP_ECN_clear(top_iph); + + top_iph->frag_off = (flags & XFRM_STATE_NOPMTUDISC) || small_ipv6 ? 
+ 0 : (XFRM_MODE_SKB_CB(skb)->frag_off & htons(IP_DF)); + + top_iph->ttl = ip4_dst_hoplimit(xfrm_dst_child(dst)); + + top_iph->saddr = x->props.saddr.a4; + top_iph->daddr = x->id.daddr.a4; + ip_select_ident(dev_net(dst->dev), skb, NULL); + + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int xfrm6_tunnel_encap_add(struct xfrm_state *x, struct sk_buff *skb) +{ + struct dst_entry *dst = skb_dst(skb); + struct ipv6hdr *top_iph; + int dsfield; + + skb_set_inner_network_header(skb, skb_network_offset(skb)); + skb_set_inner_transport_header(skb, skb_transport_offset(skb)); + + skb_set_network_header(skb, -x->props.header_len); + skb->mac_header = skb->network_header + + offsetof(struct ipv6hdr, nexthdr); + skb->transport_header = skb->network_header + sizeof(*top_iph); + top_iph = ipv6_hdr(skb); + + top_iph->version = 6; + + memcpy(top_iph->flow_lbl, XFRM_MODE_SKB_CB(skb)->flow_lbl, + sizeof(top_iph->flow_lbl)); + top_iph->nexthdr = xfrm_af2proto(skb_dst(skb)->ops->family); + + if (x->props.extra_flags & XFRM_SA_XFLAG_DONT_ENCAP_DSCP) + dsfield = 0; + else + dsfield = XFRM_MODE_SKB_CB(skb)->tos; + dsfield = INET_ECN_encapsulate(dsfield, XFRM_MODE_SKB_CB(skb)->tos); + if (x->props.flags & XFRM_STATE_NOECN) + dsfield &= ~INET_ECN_MASK; + ipv6_change_dsfield(top_iph, 0, dsfield); + top_iph->hop_limit = ip6_dst_hoplimit(xfrm_dst_child(dst)); + top_iph->saddr = *(struct in6_addr *)&x->props.saddr; + top_iph->daddr = *(struct in6_addr *)&x->id.daddr; + return 0; +} + +static int xfrm6_beet_encap_add(struct xfrm_state *x, struct sk_buff *skb) +{ + struct ipv6hdr *top_iph; + struct ip_beet_phdr *ph; + int optlen, hdr_len; + + hdr_len = 0; + optlen = XFRM_MODE_SKB_CB(skb)->optlen; + if (unlikely(optlen)) + hdr_len += IPV4_BEET_PHMAXLEN - (optlen & 4); + + skb_set_network_header(skb, -x->props.header_len - hdr_len); + if (x->sel.family != AF_INET6) + skb->network_header += IPV4_BEET_PHMAXLEN; + skb->mac_header = skb->network_header + + offsetof(struct ipv6hdr, nexthdr); + skb->transport_header = skb->network_header + sizeof(*top_iph); + ph = __skb_pull(skb, XFRM_MODE_SKB_CB(skb)->ihl - hdr_len); + + xfrm6_beet_make_header(skb); + + top_iph = ipv6_hdr(skb); + if (unlikely(optlen)) { + if (WARN_ON(optlen < 0)) + return -EINVAL; + + ph->padlen = 4 - (optlen & 4); + ph->hdrlen = optlen / 8; + ph->nexthdr = top_iph->nexthdr; + if (ph->padlen) + memset(ph + 1, IPOPT_NOP, ph->padlen); + + top_iph->nexthdr = IPPROTO_BEETPH; + } + + top_iph->saddr = *(struct in6_addr *)&x->props.saddr; + top_iph->daddr = *(struct in6_addr *)&x->id.daddr; + return 0; +} +#endif + +/* Add encapsulation header. + * + * On exit, the transport header will be set to the start of the + * encapsulation header to be filled in by x->type->output and the mac + * header will be set to the nextheader (protocol for IPv4) field of the + * extension header directly preceding the encapsulation header, or in + * its absence, that of the top IP header. + * The value of the network header will always point to the top IP header + * while skb->data will point to the payload. 
+ */ +static int xfrm4_prepare_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + + err = xfrm_inner_extract_output(x, skb); + if (err) + return err; + + IPCB(skb)->flags |= IPSKB_XFRM_TUNNEL_SIZE; + skb->protocol = htons(ETH_P_IP); + + switch (x->outer_mode.encap) { + case XFRM_MODE_BEET: + return xfrm4_beet_encap_add(x, skb); + case XFRM_MODE_TUNNEL: + return xfrm4_tunnel_encap_add(x, skb); + } + + WARN_ON_ONCE(1); + return -EOPNOTSUPP; +} + +static int xfrm6_prepare_output(struct xfrm_state *x, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + int err; + + err = xfrm_inner_extract_output(x, skb); + if (err) + return err; + + skb->ignore_df = 1; + skb->protocol = htons(ETH_P_IPV6); + + switch (x->outer_mode.encap) { + case XFRM_MODE_BEET: + return xfrm6_beet_encap_add(x, skb); + case XFRM_MODE_TUNNEL: + return xfrm6_tunnel_encap_add(x, skb); + default: + WARN_ON_ONCE(1); + return -EOPNOTSUPP; + } +#endif + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; +} + +static int xfrm_outer_mode_output(struct xfrm_state *x, struct sk_buff *skb) +{ + switch (x->outer_mode.encap) { + case XFRM_MODE_BEET: + case XFRM_MODE_TUNNEL: + if (x->outer_mode.family == AF_INET) + return xfrm4_prepare_output(x, skb); + if (x->outer_mode.family == AF_INET6) + return xfrm6_prepare_output(x, skb); + break; + case XFRM_MODE_TRANSPORT: + if (x->outer_mode.family == AF_INET) + return xfrm4_transport_output(x, skb); + if (x->outer_mode.family == AF_INET6) + return xfrm6_transport_output(x, skb); + break; + case XFRM_MODE_ROUTEOPTIMIZATION: + if (x->outer_mode.family == AF_INET6) + return xfrm6_ro_output(x, skb); + WARN_ON_ONCE(1); + break; + default: + WARN_ON_ONCE(1); + break; + } + + return -EOPNOTSUPP; +} + +#if IS_ENABLED(CONFIG_NET_PKTGEN) +int pktgen_xfrm_outer_mode_output(struct xfrm_state *x, struct sk_buff *skb) +{ + return xfrm_outer_mode_output(x, skb); +} +EXPORT_SYMBOL_GPL(pktgen_xfrm_outer_mode_output); +#endif + +static int xfrm_output_one(struct sk_buff *skb, int err) +{ + struct dst_entry *dst = skb_dst(skb); + struct xfrm_state *x = dst->xfrm; + struct net *net = xs_net(x); + + if (err <= 0) + goto resume; + + do { + err = xfrm_skb_check_space(skb); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + goto error_nolock; + } + + skb->mark = xfrm_smark_get(skb->mark, x); + + err = xfrm_outer_mode_output(x, skb); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEMODEERROR); + goto error_nolock; + } + + spin_lock_bh(&x->lock); + + if (unlikely(x->km.state != XFRM_STATE_VALID)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEINVALID); + err = -EINVAL; + goto error; + } + + err = xfrm_state_check_expire(x); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEEXPIRED); + goto error; + } + + err = xfrm_replay_overflow(x, skb); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATESEQERROR); + goto error; + } + + x->curlft.bytes += skb->len; + x->curlft.packets++; + + spin_unlock_bh(&x->lock); + + skb_dst_force(skb); + if (!skb_dst(skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + err = -EHOSTUNREACH; + goto error_nolock; + } + + if (xfrm_offload(skb)) { + x->type_offload->encap(x, skb); + } else { + /* Inner headers are invalid now. 
*/ + skb->encapsulation = 0; + + err = x->type->output(x, skb); + if (err == -EINPROGRESS) + goto out; + } + +resume: + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEPROTOERROR); + goto error_nolock; + } + + dst = skb_dst_pop(skb); + if (!dst) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + err = -EHOSTUNREACH; + goto error_nolock; + } + skb_dst_set(skb, dst); + x = dst->xfrm; + } while (x && !(x->outer_mode.flags & XFRM_MODE_FLAG_TUNNEL)); + + return 0; + +error: + spin_unlock_bh(&x->lock); +error_nolock: + kfree_skb(skb); +out: + return err; +} + +int xfrm_output_resume(struct sock *sk, struct sk_buff *skb, int err) +{ + struct net *net = xs_net(skb_dst(skb)->xfrm); + + while (likely((err = xfrm_output_one(skb, err)) == 0)) { + nf_reset_ct(skb); + + err = skb_dst(skb)->ops->local_out(net, sk, skb); + if (unlikely(err != 1)) + goto out; + + if (!skb_dst(skb)->xfrm) + return dst_output(net, sk, skb); + + err = nf_hook(skb_dst(skb)->ops->family, + NF_INET_POST_ROUTING, net, sk, skb, + NULL, skb_dst(skb)->dev, xfrm_output2); + if (unlikely(err != 1)) + goto out; + } + + if (err == -EINPROGRESS) + err = 0; + +out: + return err; +} +EXPORT_SYMBOL_GPL(xfrm_output_resume); + +static int xfrm_output2(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + return xfrm_output_resume(sk, skb, 1); +} + +static int xfrm_output_gso(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + struct sk_buff *segs, *nskb; + + BUILD_BUG_ON(sizeof(*IPCB(skb)) > SKB_GSO_CB_OFFSET); + BUILD_BUG_ON(sizeof(*IP6CB(skb)) > SKB_GSO_CB_OFFSET); + segs = skb_gso_segment(skb, 0); + kfree_skb(skb); + if (IS_ERR(segs)) + return PTR_ERR(segs); + if (segs == NULL) + return -EINVAL; + + skb_list_walk_safe(segs, segs, nskb) { + int err; + + skb_mark_not_on_list(segs); + err = xfrm_output2(net, sk, segs); + + if (unlikely(err)) { + kfree_skb_list(nskb); + return err; + } + } + + return 0; +} + +/* For partial checksum offload, the outer header checksum is calculated + * by software and the inner header checksum is calculated by hardware. + * This requires hardware to know the inner packet type to calculate + * the inner header checksum. Save inner ip protocol here to avoid + * traversing the packet in the vendor's xmit code. + * For IPsec tunnel mode save the ip protocol from the IP header of the + * plain text packet. Otherwise If the encap type is IPIP, just save + * skb->inner_ipproto in any other case get the ip protocol from the IP + * header. 
+ */ +static void xfrm_get_inner_ipproto(struct sk_buff *skb, struct xfrm_state *x) +{ + struct xfrm_offload *xo = xfrm_offload(skb); + const struct ethhdr *eth; + + if (!xo) + return; + + if (x->outer_mode.encap == XFRM_MODE_TUNNEL) { + switch (x->outer_mode.family) { + case AF_INET: + xo->inner_ipproto = ip_hdr(skb)->protocol; + break; + case AF_INET6: + xo->inner_ipproto = ipv6_hdr(skb)->nexthdr; + break; + default: + break; + } + + return; + } + + /* non-Tunnel Mode */ + if (!skb->encapsulation) + return; + + if (skb->inner_protocol_type == ENCAP_TYPE_IPPROTO) { + xo->inner_ipproto = skb->inner_ipproto; + return; + } + + if (skb->inner_protocol_type != ENCAP_TYPE_ETHER) + return; + + eth = (struct ethhdr *)skb_inner_mac_header(skb); + + switch (ntohs(eth->h_proto)) { + case ETH_P_IPV6: + xo->inner_ipproto = inner_ipv6_hdr(skb)->nexthdr; + break; + case ETH_P_IP: + xo->inner_ipproto = inner_ip_hdr(skb)->protocol; + break; + } +} + +int xfrm_output(struct sock *sk, struct sk_buff *skb) +{ + struct net *net = dev_net(skb_dst(skb)->dev); + struct xfrm_state *x = skb_dst(skb)->xfrm; + int err; + + switch (x->outer_mode.family) { + case AF_INET: + memset(IPCB(skb), 0, sizeof(*IPCB(skb))); + IPCB(skb)->flags |= IPSKB_XFRM_TRANSFORMED; + break; + case AF_INET6: + memset(IP6CB(skb), 0, sizeof(*IP6CB(skb))); + + IP6CB(skb)->flags |= IP6SKB_XFRM_TRANSFORMED; + break; + } + + secpath_reset(skb); + + if (xfrm_dev_offload_ok(skb, x)) { + struct sec_path *sp; + + sp = secpath_set(skb); + if (!sp) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + kfree_skb(skb); + return -ENOMEM; + } + + sp->olen++; + sp->xvec[sp->len++] = x; + xfrm_state_hold(x); + + xfrm_get_inner_ipproto(skb, x); + skb->encapsulation = 1; + + if (skb_is_gso(skb)) { + if (skb->inner_protocol) + return xfrm_output_gso(net, sk, skb); + + skb_shinfo(skb)->gso_type |= SKB_GSO_ESP; + goto out; + } + + if (x->xso.dev && x->xso.dev->features & NETIF_F_HW_ESP_TX_CSUM) + goto out; + } else { + if (skb_is_gso(skb)) + return xfrm_output_gso(net, sk, skb); + } + + if (skb->ip_summed == CHECKSUM_PARTIAL) { + err = skb_checksum_help(skb); + if (err) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTERROR); + kfree_skb(skb); + return err; + } + } + +out: + return xfrm_output2(net, sk, skb); +} +EXPORT_SYMBOL_GPL(xfrm_output); + +static int xfrm4_tunnel_check_size(struct sk_buff *skb) +{ + int mtu, ret = 0; + + if (IPCB(skb)->flags & IPSKB_XFRM_TUNNEL_SIZE) + goto out; + + if (!(ip_hdr(skb)->frag_off & htons(IP_DF)) || skb->ignore_df) + goto out; + + mtu = dst_mtu(skb_dst(skb)); + if ((!skb_is_gso(skb) && skb->len > mtu) || + (skb_is_gso(skb) && + !skb_gso_validate_network_len(skb, ip_skb_dst_mtu(skb->sk, skb)))) { + skb->protocol = htons(ETH_P_IP); + + if (skb->sk) + xfrm_local_error(skb, mtu); + else + icmp_send(skb, ICMP_DEST_UNREACH, + ICMP_FRAG_NEEDED, htonl(mtu)); + ret = -EMSGSIZE; + } +out: + return ret; +} + +static int xfrm4_extract_output(struct xfrm_state *x, struct sk_buff *skb) +{ + int err; + + if (x->outer_mode.encap == XFRM_MODE_BEET && + ip_is_fragment(ip_hdr(skb))) { + net_warn_ratelimited("BEET mode doesn't support inner IPv4 fragments\n"); + return -EAFNOSUPPORT; + } + + err = xfrm4_tunnel_check_size(skb); + if (err) + return err; + + XFRM_MODE_SKB_CB(skb)->protocol = ip_hdr(skb)->protocol; + + xfrm4_extract_header(skb); + return 0; +} + +#if IS_ENABLED(CONFIG_IPV6) +static int xfrm6_tunnel_check_size(struct sk_buff *skb) +{ + int mtu, ret = 0; + struct dst_entry *dst = skb_dst(skb); + + if (skb->ignore_df) + goto out; + + mtu = 
dst_mtu(dst); + if (mtu < IPV6_MIN_MTU) + mtu = IPV6_MIN_MTU; + + if ((!skb_is_gso(skb) && skb->len > mtu) || + (skb_is_gso(skb) && + !skb_gso_validate_network_len(skb, ip6_skb_dst_mtu(skb)))) { + skb->dev = dst->dev; + skb->protocol = htons(ETH_P_IPV6); + + if (xfrm6_local_dontfrag(skb->sk)) + ipv6_stub->xfrm6_local_rxpmtu(skb, mtu); + else if (skb->sk) + xfrm_local_error(skb, mtu); + else + icmpv6_send(skb, ICMPV6_PKT_TOOBIG, 0, mtu); + ret = -EMSGSIZE; + } +out: + return ret; +} +#endif + +static int xfrm6_extract_output(struct xfrm_state *x, struct sk_buff *skb) +{ +#if IS_ENABLED(CONFIG_IPV6) + int err; + + err = xfrm6_tunnel_check_size(skb); + if (err) + return err; + + XFRM_MODE_SKB_CB(skb)->protocol = ipv6_hdr(skb)->nexthdr; + + xfrm6_extract_header(skb); + return 0; +#else + WARN_ON_ONCE(1); + return -EAFNOSUPPORT; +#endif +} + +static int xfrm_inner_extract_output(struct xfrm_state *x, struct sk_buff *skb) +{ + const struct xfrm_mode *inner_mode; + + if (x->sel.family == AF_UNSPEC) + inner_mode = xfrm_ip2inner_mode(x, + xfrm_af2proto(skb_dst(skb)->ops->family)); + else + inner_mode = &x->inner_mode; + + if (inner_mode == NULL) + return -EAFNOSUPPORT; + + switch (inner_mode->family) { + case AF_INET: + return xfrm4_extract_output(x, skb); + case AF_INET6: + return xfrm6_extract_output(x, skb); + } + + return -EAFNOSUPPORT; +} + +void xfrm_local_error(struct sk_buff *skb, int mtu) +{ + unsigned int proto; + struct xfrm_state_afinfo *afinfo; + + if (skb->protocol == htons(ETH_P_IP)) + proto = AF_INET; + else if (skb->protocol == htons(ETH_P_IPV6) && + skb->sk->sk_family == AF_INET6) + proto = AF_INET6; + else + return; + + afinfo = xfrm_state_get_afinfo(proto); + if (afinfo) { + afinfo->local_error(skb, mtu); + rcu_read_unlock(); + } +} +EXPORT_SYMBOL_GPL(xfrm_local_error); diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c new file mode 100644 index 000000000..e47c670c7 --- /dev/null +++ b/net/xfrm/xfrm_policy.c @@ -0,0 +1,4490 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xfrm_policy.c + * + * Changes: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * IPv6 support + * Kazunori MIYAZAWA @USAGI + * YOSHIFUJI Hideaki + * Split up af-specific portion + * Derek Atkins Add the post_input processor + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6_MIP6) +#include +#endif +#ifdef CONFIG_XFRM_STATISTICS +#include +#endif +#ifdef CONFIG_XFRM_ESPINTCP +#include +#endif + +#include "xfrm_hash.h" + +#define XFRM_QUEUE_TMO_MIN ((unsigned)(HZ/10)) +#define XFRM_QUEUE_TMO_MAX ((unsigned)(60*HZ)) +#define XFRM_MAX_QUEUE_LEN 100 + +struct xfrm_flo { + struct dst_entry *dst_orig; + u8 flags; +}; + +/* prefixes smaller than this are stored in lists, not trees. 
*/ +#define INEXACT_PREFIXLEN_IPV4 16 +#define INEXACT_PREFIXLEN_IPV6 48 + +struct xfrm_pol_inexact_node { + struct rb_node node; + union { + xfrm_address_t addr; + struct rcu_head rcu; + }; + u8 prefixlen; + + struct rb_root root; + + /* the policies matching this node, can be empty list */ + struct hlist_head hhead; +}; + +/* xfrm inexact policy search tree: + * xfrm_pol_inexact_bin = hash(dir,type,family,if_id); + * | + * +---- root_d: sorted by daddr:prefix + * | | + * | xfrm_pol_inexact_node + * | | + * | +- root: sorted by saddr/prefix + * | | | + * | | xfrm_pol_inexact_node + * | | | + * | | + root: unused + * | | | + * | | + hhead: saddr:daddr policies + * | | + * | +- coarse policies and all any:daddr policies + * | + * +---- root_s: sorted by saddr:prefix + * | | + * | xfrm_pol_inexact_node + * | | + * | + root: unused + * | | + * | + hhead: saddr:any policies + * | + * +---- coarse policies and all any:any policies + * + * Lookups return four candidate lists: + * 1. any:any list from top-level xfrm_pol_inexact_bin + * 2. any:daddr list from daddr tree + * 3. saddr:daddr list from 2nd level daddr tree + * 4. saddr:any list from saddr tree + * + * This result set then needs to be searched for the policy with + * the lowest priority. If two results have same prio, youngest one wins. + */ + +struct xfrm_pol_inexact_key { + possible_net_t net; + u32 if_id; + u16 family; + u8 dir, type; +}; + +struct xfrm_pol_inexact_bin { + struct xfrm_pol_inexact_key k; + struct rhash_head head; + /* list containing '*:*' policies */ + struct hlist_head hhead; + + seqcount_spinlock_t count; + /* tree sorted by daddr/prefix */ + struct rb_root root_d; + + /* tree sorted by saddr/prefix */ + struct rb_root root_s; + + /* slow path below */ + struct list_head inexact_bins; + struct rcu_head rcu; +}; + +enum xfrm_pol_inexact_candidate_type { + XFRM_POL_CAND_BOTH, + XFRM_POL_CAND_SADDR, + XFRM_POL_CAND_DADDR, + XFRM_POL_CAND_ANY, + + XFRM_POL_CAND_MAX, +}; + +struct xfrm_pol_inexact_candidates { + struct hlist_head *res[XFRM_POL_CAND_MAX]; +}; + +static DEFINE_SPINLOCK(xfrm_if_cb_lock); +static struct xfrm_if_cb const __rcu *xfrm_if_cb __read_mostly; + +static DEFINE_SPINLOCK(xfrm_policy_afinfo_lock); +static struct xfrm_policy_afinfo const __rcu *xfrm_policy_afinfo[AF_INET6 + 1] + __read_mostly; + +static struct kmem_cache *xfrm_dst_cache __ro_after_init; + +static struct rhashtable xfrm_policy_inexact_table; +static const struct rhashtable_params xfrm_pol_inexact_params; + +static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr); +static int stale_bundle(struct dst_entry *dst); +static int xfrm_bundle_ok(struct xfrm_dst *xdst); +static void xfrm_policy_queue_process(struct timer_list *t); + +static void __xfrm_policy_link(struct xfrm_policy *pol, int dir); +static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol, + int dir); + +static struct xfrm_pol_inexact_bin * +xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family, u8 dir, + u32 if_id); + +static struct xfrm_pol_inexact_bin * +xfrm_policy_inexact_lookup_rcu(struct net *net, + u8 type, u16 family, u8 dir, u32 if_id); +static struct xfrm_policy * +xfrm_policy_insert_list(struct hlist_head *chain, struct xfrm_policy *policy, + bool excl); +static void xfrm_policy_insert_inexact_list(struct hlist_head *chain, + struct xfrm_policy *policy); + +static bool +xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand, + struct xfrm_pol_inexact_bin *b, + const xfrm_address_t *saddr, + const xfrm_address_t 
*daddr); + +static inline bool xfrm_pol_hold_rcu(struct xfrm_policy *policy) +{ + return refcount_inc_not_zero(&policy->refcnt); +} + +static inline bool +__xfrm4_selector_match(const struct xfrm_selector *sel, const struct flowi *fl) +{ + const struct flowi4 *fl4 = &fl->u.ip4; + + return addr4_match(fl4->daddr, sel->daddr.a4, sel->prefixlen_d) && + addr4_match(fl4->saddr, sel->saddr.a4, sel->prefixlen_s) && + !((xfrm_flowi_dport(fl, &fl4->uli) ^ sel->dport) & sel->dport_mask) && + !((xfrm_flowi_sport(fl, &fl4->uli) ^ sel->sport) & sel->sport_mask) && + (fl4->flowi4_proto == sel->proto || !sel->proto) && + (fl4->flowi4_oif == sel->ifindex || !sel->ifindex); +} + +static inline bool +__xfrm6_selector_match(const struct xfrm_selector *sel, const struct flowi *fl) +{ + const struct flowi6 *fl6 = &fl->u.ip6; + + return addr_match(&fl6->daddr, &sel->daddr, sel->prefixlen_d) && + addr_match(&fl6->saddr, &sel->saddr, sel->prefixlen_s) && + !((xfrm_flowi_dport(fl, &fl6->uli) ^ sel->dport) & sel->dport_mask) && + !((xfrm_flowi_sport(fl, &fl6->uli) ^ sel->sport) & sel->sport_mask) && + (fl6->flowi6_proto == sel->proto || !sel->proto) && + (fl6->flowi6_oif == sel->ifindex || !sel->ifindex); +} + +bool xfrm_selector_match(const struct xfrm_selector *sel, const struct flowi *fl, + unsigned short family) +{ + switch (family) { + case AF_INET: + return __xfrm4_selector_match(sel, fl); + case AF_INET6: + return __xfrm6_selector_match(sel, fl); + } + return false; +} + +static const struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family) +{ + const struct xfrm_policy_afinfo *afinfo; + + if (unlikely(family >= ARRAY_SIZE(xfrm_policy_afinfo))) + return NULL; + rcu_read_lock(); + afinfo = rcu_dereference(xfrm_policy_afinfo[family]); + if (unlikely(!afinfo)) + rcu_read_unlock(); + return afinfo; +} + +/* Called with rcu_read_lock(). 
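/* Illustrative stand-alone sketch (not part of the imported file): the
 * selector tests above boil down to masked compares.  addr4_match() keeps
 * only the leading prefixlen bits, and the port test XORs the flow port with
 * the selector port under the selector's port mask (a mask of 0 matches any
 * port).  User-space toy with hypothetical helper names.
 */
#include <arpa/inet.h>
#include <sys/socket.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool addr4_match(uint32_t a, uint32_t b, uint8_t prefixlen)
{
	if (prefixlen == 0)
		return true;
	/* addresses are in network byte order, so the mask is converted too */
	return !((a ^ b) & htonl(0xffffffffU << (32 - prefixlen)));
}

static bool port_match(uint16_t flow_port, uint16_t sel_port, uint16_t mask)
{
	return !((flow_port ^ sel_port) & mask);	/* all in network order */
}

int main(void)
{
	uint32_t daddr, net;

	inet_pton(AF_INET, "192.0.2.42", &daddr);
	inet_pton(AF_INET, "192.0.2.0", &net);

	printf("addr: %d\n", addr4_match(daddr, net, 24));		/* 1 */
	printf("port: %d\n", port_match(htons(443), htons(443),
					htons(0xffff)));		/* 1 */
	printf("any port: %d\n", port_match(htons(80), 0, 0));		/* 1 */
	return 0;
}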
*/ +static const struct xfrm_if_cb *xfrm_if_get_cb(void) +{ + return rcu_dereference(xfrm_if_cb); +} + +struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos, int oif, + const xfrm_address_t *saddr, + const xfrm_address_t *daddr, + int family, u32 mark) +{ + const struct xfrm_policy_afinfo *afinfo; + struct dst_entry *dst; + + afinfo = xfrm_policy_get_afinfo(family); + if (unlikely(afinfo == NULL)) + return ERR_PTR(-EAFNOSUPPORT); + + dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr, mark); + + rcu_read_unlock(); + + return dst; +} +EXPORT_SYMBOL(__xfrm_dst_lookup); + +static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x, + int tos, int oif, + xfrm_address_t *prev_saddr, + xfrm_address_t *prev_daddr, + int family, u32 mark) +{ + struct net *net = xs_net(x); + xfrm_address_t *saddr = &x->props.saddr; + xfrm_address_t *daddr = &x->id.daddr; + struct dst_entry *dst; + + if (x->type->flags & XFRM_TYPE_LOCAL_COADDR) { + saddr = x->coaddr; + daddr = prev_daddr; + } + if (x->type->flags & XFRM_TYPE_REMOTE_COADDR) { + saddr = prev_saddr; + daddr = x->coaddr; + } + + dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family, mark); + + if (!IS_ERR(dst)) { + if (prev_saddr != saddr) + memcpy(prev_saddr, saddr, sizeof(*prev_saddr)); + if (prev_daddr != daddr) + memcpy(prev_daddr, daddr, sizeof(*prev_daddr)); + } + + return dst; +} + +static inline unsigned long make_jiffies(long secs) +{ + if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ) + return MAX_SCHEDULE_TIMEOUT-1; + else + return secs*HZ; +} + +static void xfrm_policy_timer(struct timer_list *t) +{ + struct xfrm_policy *xp = from_timer(xp, t, timer); + time64_t now = ktime_get_real_seconds(); + time64_t next = TIME64_MAX; + int warn = 0; + int dir; + + read_lock(&xp->lock); + + if (unlikely(xp->walk.dead)) + goto out; + + dir = xfrm_policy_id2dir(xp->index); + + if (xp->lft.hard_add_expires_seconds) { + time64_t tmo = xp->lft.hard_add_expires_seconds + + xp->curlft.add_time - now; + if (tmo <= 0) + goto expired; + if (tmo < next) + next = tmo; + } + if (xp->lft.hard_use_expires_seconds) { + time64_t tmo = xp->lft.hard_use_expires_seconds + + (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now; + if (tmo <= 0) + goto expired; + if (tmo < next) + next = tmo; + } + if (xp->lft.soft_add_expires_seconds) { + time64_t tmo = xp->lft.soft_add_expires_seconds + + xp->curlft.add_time - now; + if (tmo <= 0) { + warn = 1; + tmo = XFRM_KM_TIMEOUT; + } + if (tmo < next) + next = tmo; + } + if (xp->lft.soft_use_expires_seconds) { + time64_t tmo = xp->lft.soft_use_expires_seconds + + (READ_ONCE(xp->curlft.use_time) ? : xp->curlft.add_time) - now; + if (tmo <= 0) { + warn = 1; + tmo = XFRM_KM_TIMEOUT; + } + if (tmo < next) + next = tmo; + } + + if (warn) + km_policy_expired(xp, dir, 0, 0); + if (next != TIME64_MAX && + !mod_timer(&xp->timer, jiffies + make_jiffies(next))) + xfrm_pol_hold(xp); + +out: + read_unlock(&xp->lock); + xfrm_pol_put(xp); + return; + +expired: + read_unlock(&xp->lock); + if (!xfrm_policy_delete(xp, dir)) + km_policy_expired(xp, dir, 1, 0); + xfrm_pol_put(xp); +} + +/* Allocate xfrm_policy. Not used here, it is supposed to be used by pfkeyv2 + * SPD calls. 
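/* Illustrative stand-alone sketch (not part of the imported file):
 * xfrm_policy_timer() above re-arms itself for the nearest of the configured
 * lifetime deadlines; a hard limit that has already passed expires the
 * policy, a soft limit that has passed only triggers a warning and a short
 * retry (XFRM_KM_TIMEOUT).  Only the add-time limits are modelled here and
 * all names are hypothetical.
 */
#include <stdint.h>
#include <stdio.h>

#define NO_DEADLINE INT64_MAX

/* seconds until the next timer event, 0 if the hard limit already expired,
 * NO_DEADLINE if no lifetime is configured */
static int64_t next_deadline(int64_t now, int64_t add_time,
			     int64_t hard_add, int64_t soft_add)
{
	int64_t next = NO_DEADLINE;

	if (hard_add) {
		int64_t tmo = add_time + hard_add - now;
		if (tmo <= 0)
			return 0;		/* hard limit already hit */
		if (tmo < next)
			next = tmo;
	}
	if (soft_add) {
		int64_t tmo = add_time + soft_add - now;
		if (tmo <= 0)
			tmo = 30;		/* soft limit hit: warn, retry shortly */
		if (tmo < next)
			next = tmo;
	}
	return next;
}

int main(void)
{
	/* policy added at t=100, soft limit 60s, hard limit 120s, now t=170 */
	printf("%lld\n", (long long)next_deadline(170, 100, 120, 60));	/* 30 */
	return 0;
}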
+ */ + +struct xfrm_policy *xfrm_policy_alloc(struct net *net, gfp_t gfp) +{ + struct xfrm_policy *policy; + + policy = kzalloc(sizeof(struct xfrm_policy), gfp); + + if (policy) { + write_pnet(&policy->xp_net, net); + INIT_LIST_HEAD(&policy->walk.all); + INIT_HLIST_NODE(&policy->bydst_inexact_list); + INIT_HLIST_NODE(&policy->bydst); + INIT_HLIST_NODE(&policy->byidx); + rwlock_init(&policy->lock); + refcount_set(&policy->refcnt, 1); + skb_queue_head_init(&policy->polq.hold_queue); + timer_setup(&policy->timer, xfrm_policy_timer, 0); + timer_setup(&policy->polq.hold_timer, + xfrm_policy_queue_process, 0); + } + return policy; +} +EXPORT_SYMBOL(xfrm_policy_alloc); + +static void xfrm_policy_destroy_rcu(struct rcu_head *head) +{ + struct xfrm_policy *policy = container_of(head, struct xfrm_policy, rcu); + + security_xfrm_policy_free(policy->security); + kfree(policy); +} + +/* Destroy xfrm_policy: descendant resources must be released to this moment. */ + +void xfrm_policy_destroy(struct xfrm_policy *policy) +{ + BUG_ON(!policy->walk.dead); + + if (del_timer(&policy->timer) || del_timer(&policy->polq.hold_timer)) + BUG(); + + call_rcu(&policy->rcu, xfrm_policy_destroy_rcu); +} +EXPORT_SYMBOL(xfrm_policy_destroy); + +/* Rule must be locked. Release descendant resources, announce + * entry dead. The rule must be unlinked from lists to the moment. + */ + +static void xfrm_policy_kill(struct xfrm_policy *policy) +{ + write_lock_bh(&policy->lock); + policy->walk.dead = 1; + write_unlock_bh(&policy->lock); + + atomic_inc(&policy->genid); + + if (del_timer(&policy->polq.hold_timer)) + xfrm_pol_put(policy); + skb_queue_purge(&policy->polq.hold_queue); + + if (del_timer(&policy->timer)) + xfrm_pol_put(policy); + + xfrm_pol_put(policy); +} + +static unsigned int xfrm_policy_hashmax __read_mostly = 1 * 1024 * 1024; + +static inline unsigned int idx_hash(struct net *net, u32 index) +{ + return __idx_hash(index, net->xfrm.policy_idx_hmask); +} + +/* calculate policy hash thresholds */ +static void __get_hash_thresh(struct net *net, + unsigned short family, int dir, + u8 *dbits, u8 *sbits) +{ + switch (family) { + case AF_INET: + *dbits = net->xfrm.policy_bydst[dir].dbits4; + *sbits = net->xfrm.policy_bydst[dir].sbits4; + break; + + case AF_INET6: + *dbits = net->xfrm.policy_bydst[dir].dbits6; + *sbits = net->xfrm.policy_bydst[dir].sbits6; + break; + + default: + *dbits = 0; + *sbits = 0; + } +} + +static struct hlist_head *policy_hash_bysel(struct net *net, + const struct xfrm_selector *sel, + unsigned short family, int dir) +{ + unsigned int hmask = net->xfrm.policy_bydst[dir].hmask; + unsigned int hash; + u8 dbits; + u8 sbits; + + __get_hash_thresh(net, family, dir, &dbits, &sbits); + hash = __sel_hash(sel, family, hmask, dbits, sbits); + + if (hash == hmask + 1) + return NULL; + + return rcu_dereference_check(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash; +} + +static struct hlist_head *policy_hash_direct(struct net *net, + const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + unsigned short family, int dir) +{ + unsigned int hmask = net->xfrm.policy_bydst[dir].hmask; + unsigned int hash; + u8 dbits; + u8 sbits; + + __get_hash_thresh(net, family, dir, &dbits, &sbits); + hash = __addr_hash(daddr, saddr, family, hmask, dbits, sbits); + + return rcu_dereference_check(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)) + hash; +} + +static void xfrm_dst_hash_transfer(struct net *net, + struct hlist_head *list, + struct 
hlist_head *ndsttable, + unsigned int nhashmask, + int dir) +{ + struct hlist_node *tmp, *entry0 = NULL; + struct xfrm_policy *pol; + unsigned int h0 = 0; + u8 dbits; + u8 sbits; + +redo: + hlist_for_each_entry_safe(pol, tmp, list, bydst) { + unsigned int h; + + __get_hash_thresh(net, pol->family, dir, &dbits, &sbits); + h = __addr_hash(&pol->selector.daddr, &pol->selector.saddr, + pol->family, nhashmask, dbits, sbits); + if (!entry0) { + hlist_del_rcu(&pol->bydst); + hlist_add_head_rcu(&pol->bydst, ndsttable + h); + h0 = h; + } else { + if (h != h0) + continue; + hlist_del_rcu(&pol->bydst); + hlist_add_behind_rcu(&pol->bydst, entry0); + } + entry0 = &pol->bydst; + } + if (!hlist_empty(list)) { + entry0 = NULL; + goto redo; + } +} + +static void xfrm_idx_hash_transfer(struct hlist_head *list, + struct hlist_head *nidxtable, + unsigned int nhashmask) +{ + struct hlist_node *tmp; + struct xfrm_policy *pol; + + hlist_for_each_entry_safe(pol, tmp, list, byidx) { + unsigned int h; + + h = __idx_hash(pol->index, nhashmask); + hlist_add_head(&pol->byidx, nidxtable+h); + } +} + +static unsigned long xfrm_new_hash_mask(unsigned int old_hmask) +{ + return ((old_hmask + 1) << 1) - 1; +} + +static void xfrm_bydst_resize(struct net *net, int dir) +{ + unsigned int hmask = net->xfrm.policy_bydst[dir].hmask; + unsigned int nhashmask = xfrm_new_hash_mask(hmask); + unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head); + struct hlist_head *ndst = xfrm_hash_alloc(nsize); + struct hlist_head *odst; + int i; + + if (!ndst) + return; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation); + + odst = rcu_dereference_protected(net->xfrm.policy_bydst[dir].table, + lockdep_is_held(&net->xfrm.xfrm_policy_lock)); + + for (i = hmask; i >= 0; i--) + xfrm_dst_hash_transfer(net, odst + i, ndst, nhashmask, dir); + + rcu_assign_pointer(net->xfrm.policy_bydst[dir].table, ndst); + net->xfrm.policy_bydst[dir].hmask = nhashmask; + + write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + synchronize_rcu(); + + xfrm_hash_free(odst, (hmask + 1) * sizeof(struct hlist_head)); +} + +static void xfrm_byidx_resize(struct net *net, int total) +{ + unsigned int hmask = net->xfrm.policy_idx_hmask; + unsigned int nhashmask = xfrm_new_hash_mask(hmask); + unsigned int nsize = (nhashmask + 1) * sizeof(struct hlist_head); + struct hlist_head *oidx = net->xfrm.policy_byidx; + struct hlist_head *nidx = xfrm_hash_alloc(nsize); + int i; + + if (!nidx) + return; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + + for (i = hmask; i >= 0; i--) + xfrm_idx_hash_transfer(oidx + i, nidx, nhashmask); + + net->xfrm.policy_byidx = nidx; + net->xfrm.policy_idx_hmask = nhashmask; + + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + xfrm_hash_free(oidx, (hmask + 1) * sizeof(struct hlist_head)); +} + +static inline int xfrm_bydst_should_resize(struct net *net, int dir, int *total) +{ + unsigned int cnt = net->xfrm.policy_count[dir]; + unsigned int hmask = net->xfrm.policy_bydst[dir].hmask; + + if (total) + *total += cnt; + + if ((hmask + 1) < xfrm_policy_hashmax && + cnt > hmask) + return 1; + + return 0; +} + +static inline int xfrm_byidx_should_resize(struct net *net, int total) +{ + unsigned int hmask = net->xfrm.policy_idx_hmask; + + if ((hmask + 1) < xfrm_policy_hashmax && + total > hmask) + return 1; + + return 0; +} + +void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si) +{ + si->incnt = 
net->xfrm.policy_count[XFRM_POLICY_IN]; + si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT]; + si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD]; + si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX]; + si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX]; + si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX]; + si->spdhcnt = net->xfrm.policy_idx_hmask; + si->spdhmcnt = xfrm_policy_hashmax; +} +EXPORT_SYMBOL(xfrm_spd_getinfo); + +static DEFINE_MUTEX(hash_resize_mutex); +static void xfrm_hash_resize(struct work_struct *work) +{ + struct net *net = container_of(work, struct net, xfrm.policy_hash_work); + int dir, total; + + mutex_lock(&hash_resize_mutex); + + total = 0; + for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { + if (xfrm_bydst_should_resize(net, dir, &total)) + xfrm_bydst_resize(net, dir); + } + if (xfrm_byidx_should_resize(net, total)) + xfrm_byidx_resize(net, total); + + mutex_unlock(&hash_resize_mutex); +} + +/* Make sure *pol can be inserted into fastbin. + * Useful to check that later insert requests will be successful + * (provided xfrm_policy_lock is held throughout). + */ +static struct xfrm_pol_inexact_bin * +xfrm_policy_inexact_alloc_bin(const struct xfrm_policy *pol, u8 dir) +{ + struct xfrm_pol_inexact_bin *bin, *prev; + struct xfrm_pol_inexact_key k = { + .family = pol->family, + .type = pol->type, + .dir = dir, + .if_id = pol->if_id, + }; + struct net *net = xp_net(pol); + + lockdep_assert_held(&net->xfrm.xfrm_policy_lock); + + write_pnet(&k.net, net); + bin = rhashtable_lookup_fast(&xfrm_policy_inexact_table, &k, + xfrm_pol_inexact_params); + if (bin) + return bin; + + bin = kzalloc(sizeof(*bin), GFP_ATOMIC); + if (!bin) + return NULL; + + bin->k = k; + INIT_HLIST_HEAD(&bin->hhead); + bin->root_d = RB_ROOT; + bin->root_s = RB_ROOT; + seqcount_spinlock_init(&bin->count, &net->xfrm.xfrm_policy_lock); + + prev = rhashtable_lookup_get_insert_key(&xfrm_policy_inexact_table, + &bin->k, &bin->head, + xfrm_pol_inexact_params); + if (!prev) { + list_add(&bin->inexact_bins, &net->xfrm.inexact_bins); + return bin; + } + + kfree(bin); + + return IS_ERR(prev) ? 
NULL : prev; +} + +static bool xfrm_pol_inexact_addr_use_any_list(const xfrm_address_t *addr, + int family, u8 prefixlen) +{ + if (xfrm_addr_any(addr, family)) + return true; + + if (family == AF_INET6 && prefixlen < INEXACT_PREFIXLEN_IPV6) + return true; + + if (family == AF_INET && prefixlen < INEXACT_PREFIXLEN_IPV4) + return true; + + return false; +} + +static bool +xfrm_policy_inexact_insert_use_any_list(const struct xfrm_policy *policy) +{ + const xfrm_address_t *addr; + bool saddr_any, daddr_any; + u8 prefixlen; + + addr = &policy->selector.saddr; + prefixlen = policy->selector.prefixlen_s; + + saddr_any = xfrm_pol_inexact_addr_use_any_list(addr, + policy->family, + prefixlen); + addr = &policy->selector.daddr; + prefixlen = policy->selector.prefixlen_d; + daddr_any = xfrm_pol_inexact_addr_use_any_list(addr, + policy->family, + prefixlen); + return saddr_any && daddr_any; +} + +static void xfrm_pol_inexact_node_init(struct xfrm_pol_inexact_node *node, + const xfrm_address_t *addr, u8 prefixlen) +{ + node->addr = *addr; + node->prefixlen = prefixlen; +} + +static struct xfrm_pol_inexact_node * +xfrm_pol_inexact_node_alloc(const xfrm_address_t *addr, u8 prefixlen) +{ + struct xfrm_pol_inexact_node *node; + + node = kzalloc(sizeof(*node), GFP_ATOMIC); + if (node) + xfrm_pol_inexact_node_init(node, addr, prefixlen); + + return node; +} + +static int xfrm_policy_addr_delta(const xfrm_address_t *a, + const xfrm_address_t *b, + u8 prefixlen, u16 family) +{ + u32 ma, mb, mask; + unsigned int pdw, pbi; + int delta = 0; + + switch (family) { + case AF_INET: + if (prefixlen == 0) + return 0; + mask = ~0U << (32 - prefixlen); + ma = ntohl(a->a4) & mask; + mb = ntohl(b->a4) & mask; + if (ma < mb) + delta = -1; + else if (ma > mb) + delta = 1; + break; + case AF_INET6: + pdw = prefixlen >> 5; + pbi = prefixlen & 0x1f; + + if (pdw) { + delta = memcmp(a->a6, b->a6, pdw << 2); + if (delta) + return delta; + } + if (pbi) { + mask = ~0U << (32 - pbi); + ma = ntohl(a->a6[pdw]) & mask; + mb = ntohl(b->a6[pdw]) & mask; + if (ma < mb) + delta = -1; + else if (ma > mb) + delta = 1; + } + break; + default: + break; + } + + return delta; +} + +static void xfrm_policy_inexact_list_reinsert(struct net *net, + struct xfrm_pol_inexact_node *n, + u16 family) +{ + unsigned int matched_s, matched_d; + struct xfrm_policy *policy, *p; + + matched_s = 0; + matched_d = 0; + + list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) { + struct hlist_node *newpos = NULL; + bool matches_s, matches_d; + + if (policy->walk.dead || !policy->bydst_reinsert) + continue; + + WARN_ON_ONCE(policy->family != family); + + policy->bydst_reinsert = false; + hlist_for_each_entry(p, &n->hhead, bydst) { + if (policy->priority > p->priority) + newpos = &p->bydst; + else if (policy->priority == p->priority && + policy->pos > p->pos) + newpos = &p->bydst; + else + break; + } + + if (newpos) + hlist_add_behind_rcu(&policy->bydst, newpos); + else + hlist_add_head_rcu(&policy->bydst, &n->hhead); + + /* paranoia checks follow. + * Check that the reinserted policy matches at least + * saddr or daddr for current node prefix. + * + * Matching both is fine, matching saddr in one policy + * (but not daddr) and then matching only daddr in another + * is a bug. 
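/* Illustrative stand-alone sketch (not part of the imported file):
 * xfrm_policy_addr_delta() above orders the tree nodes by comparing only the
 * leading prefixlen bits, so two addresses inside the same subnet compare as
 * equal and a lookup descends into that node.  IPv4-only toy version with a
 * hypothetical helper name.
 */
#include <arpa/inet.h>
#include <sys/socket.h>
#include <stdint.h>
#include <stdio.h>

static int addr_delta_v4(uint32_t a, uint32_t b, uint8_t prefixlen)
{
	uint32_t mask, ma, mb;

	if (prefixlen == 0)
		return 0;			/* /0 compares equal to everything */
	mask = 0xffffffffU << (32 - prefixlen);
	ma = ntohl(a) & mask;
	mb = ntohl(b) & mask;
	return ma < mb ? -1 : ma > mb ? 1 : 0;
}

int main(void)
{
	uint32_t x, y;

	inet_pton(AF_INET, "10.1.2.3", &x);
	inet_pton(AF_INET, "10.1.9.9", &y);
	printf("/16: %d\n", addr_delta_v4(x, y, 16));	/*  0: same 10.1/16 node   */
	printf("/24: %d\n", addr_delta_v4(x, y, 24));	/* -1: 10.1.2 sorts first  */
	return 0;
}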
+ */ + matches_s = xfrm_policy_addr_delta(&policy->selector.saddr, + &n->addr, + n->prefixlen, + family) == 0; + matches_d = xfrm_policy_addr_delta(&policy->selector.daddr, + &n->addr, + n->prefixlen, + family) == 0; + if (matches_s && matches_d) + continue; + + WARN_ON_ONCE(!matches_s && !matches_d); + if (matches_s) + matched_s++; + if (matches_d) + matched_d++; + WARN_ON_ONCE(matched_s && matched_d); + } +} + +static void xfrm_policy_inexact_node_reinsert(struct net *net, + struct xfrm_pol_inexact_node *n, + struct rb_root *new, + u16 family) +{ + struct xfrm_pol_inexact_node *node; + struct rb_node **p, *parent; + + /* we should not have another subtree here */ + WARN_ON_ONCE(!RB_EMPTY_ROOT(&n->root)); +restart: + parent = NULL; + p = &new->rb_node; + while (*p) { + u8 prefixlen; + int delta; + + parent = *p; + node = rb_entry(*p, struct xfrm_pol_inexact_node, node); + + prefixlen = min(node->prefixlen, n->prefixlen); + + delta = xfrm_policy_addr_delta(&n->addr, &node->addr, + prefixlen, family); + if (delta < 0) { + p = &parent->rb_left; + } else if (delta > 0) { + p = &parent->rb_right; + } else { + bool same_prefixlen = node->prefixlen == n->prefixlen; + struct xfrm_policy *tmp; + + hlist_for_each_entry(tmp, &n->hhead, bydst) { + tmp->bydst_reinsert = true; + hlist_del_rcu(&tmp->bydst); + } + + node->prefixlen = prefixlen; + + xfrm_policy_inexact_list_reinsert(net, node, family); + + if (same_prefixlen) { + kfree_rcu(n, rcu); + return; + } + + rb_erase(*p, new); + kfree_rcu(n, rcu); + n = node; + goto restart; + } + } + + rb_link_node_rcu(&n->node, parent, p); + rb_insert_color(&n->node, new); +} + +/* merge nodes v and n */ +static void xfrm_policy_inexact_node_merge(struct net *net, + struct xfrm_pol_inexact_node *v, + struct xfrm_pol_inexact_node *n, + u16 family) +{ + struct xfrm_pol_inexact_node *node; + struct xfrm_policy *tmp; + struct rb_node *rnode; + + /* To-be-merged node v has a subtree. + * + * Dismantle it and insert its nodes to n->root. + */ + while ((rnode = rb_first(&v->root)) != NULL) { + node = rb_entry(rnode, struct xfrm_pol_inexact_node, node); + rb_erase(&node->node, &v->root); + xfrm_policy_inexact_node_reinsert(net, node, &n->root, + family); + } + + hlist_for_each_entry(tmp, &v->hhead, bydst) { + tmp->bydst_reinsert = true; + hlist_del_rcu(&tmp->bydst); + } + + xfrm_policy_inexact_list_reinsert(net, n, family); +} + +static struct xfrm_pol_inexact_node * +xfrm_policy_inexact_insert_node(struct net *net, + struct rb_root *root, + xfrm_address_t *addr, + u16 family, u8 prefixlen, u8 dir) +{ + struct xfrm_pol_inexact_node *cached = NULL; + struct rb_node **p, *parent = NULL; + struct xfrm_pol_inexact_node *node; + + p = &root->rb_node; + while (*p) { + int delta; + + parent = *p; + node = rb_entry(*p, struct xfrm_pol_inexact_node, node); + + delta = xfrm_policy_addr_delta(addr, &node->addr, + node->prefixlen, + family); + if (delta == 0 && prefixlen >= node->prefixlen) { + WARN_ON_ONCE(cached); /* ipsec policies got lost */ + return node; + } + + if (delta < 0) + p = &parent->rb_left; + else + p = &parent->rb_right; + + if (prefixlen < node->prefixlen) { + delta = xfrm_policy_addr_delta(addr, &node->addr, + prefixlen, + family); + if (delta) + continue; + + /* This node is a subnet of the new prefix. It needs + * to be removed and re-inserted with the smaller + * prefix and all nodes that are now also covered + * by the reduced prefixlen. 
+ */ + rb_erase(&node->node, root); + + if (!cached) { + xfrm_pol_inexact_node_init(node, addr, + prefixlen); + cached = node; + } else { + /* This node also falls within the new + * prefixlen. Merge the to-be-reinserted + * node and this one. + */ + xfrm_policy_inexact_node_merge(net, node, + cached, family); + kfree_rcu(node, rcu); + } + + /* restart */ + p = &root->rb_node; + parent = NULL; + } + } + + node = cached; + if (!node) { + node = xfrm_pol_inexact_node_alloc(addr, prefixlen); + if (!node) + return NULL; + } + + rb_link_node_rcu(&node->node, parent, p); + rb_insert_color(&node->node, root); + + return node; +} + +static void xfrm_policy_inexact_gc_tree(struct rb_root *r, bool rm) +{ + struct xfrm_pol_inexact_node *node; + struct rb_node *rn = rb_first(r); + + while (rn) { + node = rb_entry(rn, struct xfrm_pol_inexact_node, node); + + xfrm_policy_inexact_gc_tree(&node->root, rm); + rn = rb_next(rn); + + if (!hlist_empty(&node->hhead) || !RB_EMPTY_ROOT(&node->root)) { + WARN_ON_ONCE(rm); + continue; + } + + rb_erase(&node->node, r); + kfree_rcu(node, rcu); + } +} + +static void __xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b, bool net_exit) +{ + write_seqcount_begin(&b->count); + xfrm_policy_inexact_gc_tree(&b->root_d, net_exit); + xfrm_policy_inexact_gc_tree(&b->root_s, net_exit); + write_seqcount_end(&b->count); + + if (!RB_EMPTY_ROOT(&b->root_d) || !RB_EMPTY_ROOT(&b->root_s) || + !hlist_empty(&b->hhead)) { + WARN_ON_ONCE(net_exit); + return; + } + + if (rhashtable_remove_fast(&xfrm_policy_inexact_table, &b->head, + xfrm_pol_inexact_params) == 0) { + list_del(&b->inexact_bins); + kfree_rcu(b, rcu); + } +} + +static void xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b) +{ + struct net *net = read_pnet(&b->k.net); + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + __xfrm_policy_inexact_prune_bin(b, false); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); +} + +static void __xfrm_policy_inexact_flush(struct net *net) +{ + struct xfrm_pol_inexact_bin *bin, *t; + + lockdep_assert_held(&net->xfrm.xfrm_policy_lock); + + list_for_each_entry_safe(bin, t, &net->xfrm.inexact_bins, inexact_bins) + __xfrm_policy_inexact_prune_bin(bin, false); +} + +static struct hlist_head * +xfrm_policy_inexact_alloc_chain(struct xfrm_pol_inexact_bin *bin, + struct xfrm_policy *policy, u8 dir) +{ + struct xfrm_pol_inexact_node *n; + struct net *net; + + net = xp_net(policy); + lockdep_assert_held(&net->xfrm.xfrm_policy_lock); + + if (xfrm_policy_inexact_insert_use_any_list(policy)) + return &bin->hhead; + + if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.daddr, + policy->family, + policy->selector.prefixlen_d)) { + write_seqcount_begin(&bin->count); + n = xfrm_policy_inexact_insert_node(net, + &bin->root_s, + &policy->selector.saddr, + policy->family, + policy->selector.prefixlen_s, + dir); + write_seqcount_end(&bin->count); + if (!n) + return NULL; + + return &n->hhead; + } + + /* daddr is fixed */ + write_seqcount_begin(&bin->count); + n = xfrm_policy_inexact_insert_node(net, + &bin->root_d, + &policy->selector.daddr, + policy->family, + policy->selector.prefixlen_d, dir); + write_seqcount_end(&bin->count); + if (!n) + return NULL; + + /* saddr is wildcard */ + if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.saddr, + policy->family, + policy->selector.prefixlen_s)) + return &n->hhead; + + write_seqcount_begin(&bin->count); + n = xfrm_policy_inexact_insert_node(net, + &n->root, + &policy->selector.saddr, + policy->family, + policy->selector.prefixlen_s, dir); + 
write_seqcount_end(&bin->count); + if (!n) + return NULL; + + return &n->hhead; +} + +static struct xfrm_policy * +xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl) +{ + struct xfrm_pol_inexact_bin *bin; + struct xfrm_policy *delpol; + struct hlist_head *chain; + struct net *net; + + bin = xfrm_policy_inexact_alloc_bin(policy, dir); + if (!bin) + return ERR_PTR(-ENOMEM); + + net = xp_net(policy); + lockdep_assert_held(&net->xfrm.xfrm_policy_lock); + + chain = xfrm_policy_inexact_alloc_chain(bin, policy, dir); + if (!chain) { + __xfrm_policy_inexact_prune_bin(bin, false); + return ERR_PTR(-ENOMEM); + } + + delpol = xfrm_policy_insert_list(chain, policy, excl); + if (delpol && excl) { + __xfrm_policy_inexact_prune_bin(bin, false); + return ERR_PTR(-EEXIST); + } + + chain = &net->xfrm.policy_inexact[dir]; + xfrm_policy_insert_inexact_list(chain, policy); + + if (delpol) + __xfrm_policy_inexact_prune_bin(bin, false); + + return delpol; +} + +static void xfrm_hash_rebuild(struct work_struct *work) +{ + struct net *net = container_of(work, struct net, + xfrm.policy_hthresh.work); + unsigned int hmask; + struct xfrm_policy *pol; + struct xfrm_policy *policy; + struct hlist_head *chain; + struct hlist_head *odst; + struct hlist_node *newpos; + int i; + int dir; + unsigned seq; + u8 lbits4, rbits4, lbits6, rbits6; + + mutex_lock(&hash_resize_mutex); + + /* read selector prefixlen thresholds */ + do { + seq = read_seqbegin(&net->xfrm.policy_hthresh.lock); + + lbits4 = net->xfrm.policy_hthresh.lbits4; + rbits4 = net->xfrm.policy_hthresh.rbits4; + lbits6 = net->xfrm.policy_hthresh.lbits6; + rbits6 = net->xfrm.policy_hthresh.rbits6; + } while (read_seqretry(&net->xfrm.policy_hthresh.lock, seq)); + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + write_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation); + + /* make sure that we can insert the indirect policies again before + * we start with destructive action. 
+ */ + list_for_each_entry(policy, &net->xfrm.policy_all, walk.all) { + struct xfrm_pol_inexact_bin *bin; + u8 dbits, sbits; + + if (policy->walk.dead) + continue; + + dir = xfrm_policy_id2dir(policy->index); + if (dir >= XFRM_POLICY_MAX) + continue; + + if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) { + if (policy->family == AF_INET) { + dbits = rbits4; + sbits = lbits4; + } else { + dbits = rbits6; + sbits = lbits6; + } + } else { + if (policy->family == AF_INET) { + dbits = lbits4; + sbits = rbits4; + } else { + dbits = lbits6; + sbits = rbits6; + } + } + + if (policy->selector.prefixlen_d < dbits || + policy->selector.prefixlen_s < sbits) + continue; + + bin = xfrm_policy_inexact_alloc_bin(policy, dir); + if (!bin) + goto out_unlock; + + if (!xfrm_policy_inexact_alloc_chain(bin, policy, dir)) + goto out_unlock; + } + + /* reset the bydst and inexact table in all directions */ + for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { + struct hlist_node *n; + + hlist_for_each_entry_safe(policy, n, + &net->xfrm.policy_inexact[dir], + bydst_inexact_list) { + hlist_del_rcu(&policy->bydst); + hlist_del_init(&policy->bydst_inexact_list); + } + + hmask = net->xfrm.policy_bydst[dir].hmask; + odst = net->xfrm.policy_bydst[dir].table; + for (i = hmask; i >= 0; i--) { + hlist_for_each_entry_safe(policy, n, odst + i, bydst) + hlist_del_rcu(&policy->bydst); + } + if ((dir & XFRM_POLICY_MASK) == XFRM_POLICY_OUT) { + /* dir out => dst = remote, src = local */ + net->xfrm.policy_bydst[dir].dbits4 = rbits4; + net->xfrm.policy_bydst[dir].sbits4 = lbits4; + net->xfrm.policy_bydst[dir].dbits6 = rbits6; + net->xfrm.policy_bydst[dir].sbits6 = lbits6; + } else { + /* dir in/fwd => dst = local, src = remote */ + net->xfrm.policy_bydst[dir].dbits4 = lbits4; + net->xfrm.policy_bydst[dir].sbits4 = rbits4; + net->xfrm.policy_bydst[dir].dbits6 = lbits6; + net->xfrm.policy_bydst[dir].sbits6 = rbits6; + } + } + + /* re-insert all policies by order of creation */ + list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) { + if (policy->walk.dead) + continue; + dir = xfrm_policy_id2dir(policy->index); + if (dir >= XFRM_POLICY_MAX) { + /* skip socket policies */ + continue; + } + newpos = NULL; + chain = policy_hash_bysel(net, &policy->selector, + policy->family, dir); + + if (!chain) { + void *p = xfrm_policy_inexact_insert(policy, dir, 0); + + WARN_ONCE(IS_ERR(p), "reinsert: %ld\n", PTR_ERR(p)); + continue; + } + + hlist_for_each_entry(pol, chain, bydst) { + if (policy->priority >= pol->priority) + newpos = &pol->bydst; + else + break; + } + if (newpos) + hlist_add_behind_rcu(&policy->bydst, newpos); + else + hlist_add_head_rcu(&policy->bydst, chain); + } + +out_unlock: + __xfrm_policy_inexact_flush(net); + write_seqcount_end(&net->xfrm.xfrm_policy_hash_generation); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + mutex_unlock(&hash_resize_mutex); +} + +void xfrm_policy_hash_rebuild(struct net *net) +{ + schedule_work(&net->xfrm.policy_hthresh.work); +} +EXPORT_SYMBOL(xfrm_policy_hash_rebuild); + +/* Generate new index... KAME seems to generate them ordered by cost + * of an absolute inpredictability of ordering of rules. This will not pass. 
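/* Illustrative stand-alone sketch (not part of the imported file): during the
 * rebuild above a policy is eligible for the direct hash only when both
 * selector prefixes reach the configured thresholds; for outbound policies
 * the destination side uses the "remote" threshold (rbits), for in/fwd the
 * "local" one (lbits).  Names and sample thresholds below are hypothetical.
 */
#include <stdbool.h>
#include <stdio.h>

enum dir { DIR_IN, DIR_OUT, DIR_FWD };

static bool hashable(enum dir dir, unsigned int plen_d, unsigned int plen_s,
		     unsigned int lbits, unsigned int rbits)
{
	unsigned int dbits = (dir == DIR_OUT) ? rbits : lbits;
	unsigned int sbits = (dir == DIR_OUT) ? lbits : rbits;

	return plen_d >= dbits && plen_s >= sbits;
}

int main(void)
{
	/* thresholds: local /16, remote /24 */
	printf("%d\n", hashable(DIR_OUT, 32, 32, 16, 24));	/* 1: host-to-host            */
	printf("%d\n", hashable(DIR_OUT, 16, 32, 16, 24));	/* 0: dst /16 < /24, inexact  */
	return 0;
}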
*/ +static u32 xfrm_gen_index(struct net *net, int dir, u32 index) +{ + for (;;) { + struct hlist_head *list; + struct xfrm_policy *p; + u32 idx; + int found; + + if (!index) { + idx = (net->xfrm.idx_generator | dir); + net->xfrm.idx_generator += 8; + } else { + idx = index; + index = 0; + } + + if (idx == 0) + idx = 8; + list = net->xfrm.policy_byidx + idx_hash(net, idx); + found = 0; + hlist_for_each_entry(p, list, byidx) { + if (p->index == idx) { + found = 1; + break; + } + } + if (!found) + return idx; + } +} + +static inline int selector_cmp(struct xfrm_selector *s1, struct xfrm_selector *s2) +{ + u32 *p1 = (u32 *) s1; + u32 *p2 = (u32 *) s2; + int len = sizeof(struct xfrm_selector) / sizeof(u32); + int i; + + for (i = 0; i < len; i++) { + if (p1[i] != p2[i]) + return 1; + } + + return 0; +} + +static void xfrm_policy_requeue(struct xfrm_policy *old, + struct xfrm_policy *new) +{ + struct xfrm_policy_queue *pq = &old->polq; + struct sk_buff_head list; + + if (skb_queue_empty(&pq->hold_queue)) + return; + + __skb_queue_head_init(&list); + + spin_lock_bh(&pq->hold_queue.lock); + skb_queue_splice_init(&pq->hold_queue, &list); + if (del_timer(&pq->hold_timer)) + xfrm_pol_put(old); + spin_unlock_bh(&pq->hold_queue.lock); + + pq = &new->polq; + + spin_lock_bh(&pq->hold_queue.lock); + skb_queue_splice(&list, &pq->hold_queue); + pq->timeout = XFRM_QUEUE_TMO_MIN; + if (!mod_timer(&pq->hold_timer, jiffies)) + xfrm_pol_hold(new); + spin_unlock_bh(&pq->hold_queue.lock); +} + +static inline bool xfrm_policy_mark_match(const struct xfrm_mark *mark, + struct xfrm_policy *pol) +{ + return mark->v == pol->mark.v && mark->m == pol->mark.m; +} + +static u32 xfrm_pol_bin_key(const void *data, u32 len, u32 seed) +{ + const struct xfrm_pol_inexact_key *k = data; + u32 a = k->type << 24 | k->dir << 16 | k->family; + + return jhash_3words(a, k->if_id, net_hash_mix(read_pnet(&k->net)), + seed); +} + +static u32 xfrm_pol_bin_obj(const void *data, u32 len, u32 seed) +{ + const struct xfrm_pol_inexact_bin *b = data; + + return xfrm_pol_bin_key(&b->k, 0, seed); +} + +static int xfrm_pol_bin_cmp(struct rhashtable_compare_arg *arg, + const void *ptr) +{ + const struct xfrm_pol_inexact_key *key = arg->key; + const struct xfrm_pol_inexact_bin *b = ptr; + int ret; + + if (!net_eq(read_pnet(&b->k.net), read_pnet(&key->net))) + return -1; + + ret = b->k.dir ^ key->dir; + if (ret) + return ret; + + ret = b->k.type ^ key->type; + if (ret) + return ret; + + ret = b->k.family ^ key->family; + if (ret) + return ret; + + return b->k.if_id ^ key->if_id; +} + +static const struct rhashtable_params xfrm_pol_inexact_params = { + .head_offset = offsetof(struct xfrm_pol_inexact_bin, head), + .hashfn = xfrm_pol_bin_key, + .obj_hashfn = xfrm_pol_bin_obj, + .obj_cmpfn = xfrm_pol_bin_cmp, + .automatic_shrinking = true, +}; + +static void xfrm_policy_insert_inexact_list(struct hlist_head *chain, + struct xfrm_policy *policy) +{ + struct xfrm_policy *pol, *delpol = NULL; + struct hlist_node *newpos = NULL; + int i = 0; + + hlist_for_each_entry(pol, chain, bydst_inexact_list) { + if (pol->type == policy->type && + pol->if_id == policy->if_id && + !selector_cmp(&pol->selector, &policy->selector) && + xfrm_policy_mark_match(&policy->mark, pol) && + xfrm_sec_ctx_match(pol->security, policy->security) && + !WARN_ON(delpol)) { + delpol = pol; + if (policy->priority > pol->priority) + continue; + } else if (policy->priority >= pol->priority) { + newpos = &pol->bydst_inexact_list; + continue; + } + if (delpol) + break; + } + + if (newpos) + 
hlist_add_behind_rcu(&policy->bydst_inexact_list, newpos); + else + hlist_add_head_rcu(&policy->bydst_inexact_list, chain); + + hlist_for_each_entry(pol, chain, bydst_inexact_list) { + pol->pos = i; + i++; + } +} + +static struct xfrm_policy *xfrm_policy_insert_list(struct hlist_head *chain, + struct xfrm_policy *policy, + bool excl) +{ + struct xfrm_policy *pol, *newpos = NULL, *delpol = NULL; + + hlist_for_each_entry(pol, chain, bydst) { + if (pol->type == policy->type && + pol->if_id == policy->if_id && + !selector_cmp(&pol->selector, &policy->selector) && + xfrm_policy_mark_match(&policy->mark, pol) && + xfrm_sec_ctx_match(pol->security, policy->security) && + !WARN_ON(delpol)) { + if (excl) + return ERR_PTR(-EEXIST); + delpol = pol; + if (policy->priority > pol->priority) + continue; + } else if (policy->priority >= pol->priority) { + newpos = pol; + continue; + } + if (delpol) + break; + } + + if (newpos) + hlist_add_behind_rcu(&policy->bydst, &newpos->bydst); + else + hlist_add_head_rcu(&policy->bydst, chain); + + return delpol; +} + +int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) +{ + struct net *net = xp_net(policy); + struct xfrm_policy *delpol; + struct hlist_head *chain; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + chain = policy_hash_bysel(net, &policy->selector, policy->family, dir); + if (chain) + delpol = xfrm_policy_insert_list(chain, policy, excl); + else + delpol = xfrm_policy_inexact_insert(policy, dir, excl); + + if (IS_ERR(delpol)) { + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return PTR_ERR(delpol); + } + + __xfrm_policy_link(policy, dir); + + /* After previous checking, family can either be AF_INET or AF_INET6 */ + if (policy->family == AF_INET) + rt_genid_bump_ipv4(net); + else + rt_genid_bump_ipv6(net); + + if (delpol) { + xfrm_policy_requeue(delpol, policy); + __xfrm_policy_unlink(delpol, dir); + } + policy->index = delpol ? 
delpol->index : xfrm_gen_index(net, dir, policy->index); + hlist_add_head(&policy->byidx, net->xfrm.policy_byidx+idx_hash(net, policy->index)); + policy->curlft.add_time = ktime_get_real_seconds(); + policy->curlft.use_time = 0; + if (!mod_timer(&policy->timer, jiffies + HZ)) + xfrm_pol_hold(policy); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + if (delpol) + xfrm_policy_kill(delpol); + else if (xfrm_bydst_should_resize(net, dir, NULL)) + schedule_work(&net->xfrm.policy_hash_work); + + return 0; +} +EXPORT_SYMBOL(xfrm_policy_insert); + +static struct xfrm_policy * +__xfrm_policy_bysel_ctx(struct hlist_head *chain, const struct xfrm_mark *mark, + u32 if_id, u8 type, int dir, struct xfrm_selector *sel, + struct xfrm_sec_ctx *ctx) +{ + struct xfrm_policy *pol; + + if (!chain) + return NULL; + + hlist_for_each_entry(pol, chain, bydst) { + if (pol->type == type && + pol->if_id == if_id && + xfrm_policy_mark_match(mark, pol) && + !selector_cmp(sel, &pol->selector) && + xfrm_sec_ctx_match(ctx, pol->security)) + return pol; + } + + return NULL; +} + +struct xfrm_policy * +xfrm_policy_bysel_ctx(struct net *net, const struct xfrm_mark *mark, u32 if_id, + u8 type, int dir, struct xfrm_selector *sel, + struct xfrm_sec_ctx *ctx, int delete, int *err) +{ + struct xfrm_pol_inexact_bin *bin = NULL; + struct xfrm_policy *pol, *ret = NULL; + struct hlist_head *chain; + + *err = 0; + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + chain = policy_hash_bysel(net, sel, sel->family, dir); + if (!chain) { + struct xfrm_pol_inexact_candidates cand; + int i; + + bin = xfrm_policy_inexact_lookup(net, type, + sel->family, dir, if_id); + if (!bin) { + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return NULL; + } + + if (!xfrm_policy_find_inexact_candidates(&cand, bin, + &sel->saddr, + &sel->daddr)) { + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return NULL; + } + + pol = NULL; + for (i = 0; i < ARRAY_SIZE(cand.res); i++) { + struct xfrm_policy *tmp; + + tmp = __xfrm_policy_bysel_ctx(cand.res[i], mark, + if_id, type, dir, + sel, ctx); + if (!tmp) + continue; + + if (!pol || tmp->pos < pol->pos) + pol = tmp; + } + } else { + pol = __xfrm_policy_bysel_ctx(chain, mark, if_id, type, dir, + sel, ctx); + } + + if (pol) { + xfrm_pol_hold(pol); + if (delete) { + *err = security_xfrm_policy_delete(pol->security); + if (*err) { + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return pol; + } + __xfrm_policy_unlink(pol, dir); + } + ret = pol; + } + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + if (ret && delete) + xfrm_policy_kill(ret); + if (bin && delete) + xfrm_policy_inexact_prune_bin(bin); + return ret; +} +EXPORT_SYMBOL(xfrm_policy_bysel_ctx); + +struct xfrm_policy * +xfrm_policy_byid(struct net *net, const struct xfrm_mark *mark, u32 if_id, + u8 type, int dir, u32 id, int delete, int *err) +{ + struct xfrm_policy *pol, *ret; + struct hlist_head *chain; + + *err = -ENOENT; + if (xfrm_policy_id2dir(id) != dir) + return NULL; + + *err = 0; + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + chain = net->xfrm.policy_byidx + idx_hash(net, id); + ret = NULL; + hlist_for_each_entry(pol, chain, byidx) { + if (pol->type == type && pol->index == id && + pol->if_id == if_id && xfrm_policy_mark_match(mark, pol)) { + xfrm_pol_hold(pol); + if (delete) { + *err = security_xfrm_policy_delete( + pol->security); + if (*err) { + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return pol; + } + __xfrm_policy_unlink(pol, dir); + } + ret = pol; + break; + } + } + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + if (ret && delete) + 
xfrm_policy_kill(ret); + return ret; +} +EXPORT_SYMBOL(xfrm_policy_byid); + +#ifdef CONFIG_SECURITY_NETWORK_XFRM +static inline int +xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid) +{ + struct xfrm_policy *pol; + int err = 0; + + list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) { + if (pol->walk.dead || + xfrm_policy_id2dir(pol->index) >= XFRM_POLICY_MAX || + pol->type != type) + continue; + + err = security_xfrm_policy_delete(pol->security); + if (err) { + xfrm_audit_policy_delete(pol, 0, task_valid); + return err; + } + } + return err; +} +#else +static inline int +xfrm_policy_flush_secctx_check(struct net *net, u8 type, bool task_valid) +{ + return 0; +} +#endif + +int xfrm_policy_flush(struct net *net, u8 type, bool task_valid) +{ + int dir, err = 0, cnt = 0; + struct xfrm_policy *pol; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + + err = xfrm_policy_flush_secctx_check(net, type, task_valid); + if (err) + goto out; + +again: + list_for_each_entry(pol, &net->xfrm.policy_all, walk.all) { + if (pol->walk.dead) + continue; + + dir = xfrm_policy_id2dir(pol->index); + if (dir >= XFRM_POLICY_MAX || + pol->type != type) + continue; + + __xfrm_policy_unlink(pol, dir); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + cnt++; + xfrm_audit_policy_delete(pol, 1, task_valid); + xfrm_policy_kill(pol); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + goto again; + } + if (cnt) + __xfrm_policy_inexact_flush(net); + else + err = -ESRCH; +out: + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return err; +} +EXPORT_SYMBOL(xfrm_policy_flush); + +int xfrm_policy_walk(struct net *net, struct xfrm_policy_walk *walk, + int (*func)(struct xfrm_policy *, int, int, void*), + void *data) +{ + struct xfrm_policy *pol; + struct xfrm_policy_walk_entry *x; + int error = 0; + + if (walk->type >= XFRM_POLICY_TYPE_MAX && + walk->type != XFRM_POLICY_TYPE_ANY) + return -EINVAL; + + if (list_empty(&walk->walk.all) && walk->seq != 0) + return 0; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + if (list_empty(&walk->walk.all)) + x = list_first_entry(&net->xfrm.policy_all, struct xfrm_policy_walk_entry, all); + else + x = list_first_entry(&walk->walk.all, + struct xfrm_policy_walk_entry, all); + + list_for_each_entry_from(x, &net->xfrm.policy_all, all) { + if (x->dead) + continue; + pol = container_of(x, struct xfrm_policy, walk); + if (walk->type != XFRM_POLICY_TYPE_ANY && + walk->type != pol->type) + continue; + error = func(pol, xfrm_policy_id2dir(pol->index), + walk->seq, data); + if (error) { + list_move_tail(&walk->walk.all, &x->all); + goto out; + } + walk->seq++; + } + if (walk->seq == 0) { + error = -ENOENT; + goto out; + } + list_del_init(&walk->walk.all); +out: + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + return error; +} +EXPORT_SYMBOL(xfrm_policy_walk); + +void xfrm_policy_walk_init(struct xfrm_policy_walk *walk, u8 type) +{ + INIT_LIST_HEAD(&walk->walk.all); + walk->walk.dead = 1; + walk->type = type; + walk->seq = 0; +} +EXPORT_SYMBOL(xfrm_policy_walk_init); + +void xfrm_policy_walk_done(struct xfrm_policy_walk *walk, struct net *net) +{ + if (list_empty(&walk->walk.all)) + return; + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); /*FIXME where is net? */ + list_del(&walk->walk.all); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); +} +EXPORT_SYMBOL(xfrm_policy_walk_done); + +/* + * Find policy to apply to this flow. + * + * Returns 0 if policy found, else an -errno. 
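/* Illustrative stand-alone sketch (not part of the imported file):
 * xfrm_policy_walk() above makes dumps resumable by threading a dummy
 * walk entry (marked dead) through the same list as the policies and parking
 * it after the last entry delivered.  The toy below mimics that
 * cursor-in-list idea; the real code stops when the callback reports an
 * error, a fixed batch size stands in for that here.  All names are
 * hypothetical.
 */
#include <stddef.h>
#include <stdio.h>

struct list_head { struct list_head *next, *prev; };

static void list_init(struct list_head *h) { h->next = h->prev = h; }

static void list_del(struct list_head *e)
{
	e->prev->next = e->next;
	e->next->prev = e->prev;
}

static void list_add_after(struct list_head *e, struct list_head *pos)
{
	e->next = pos->next;
	e->prev = pos;
	pos->next->prev = e;
	pos->next = e;
}

struct entry {
	struct list_head all;
	int dead;			/* 1 for the cursor, 0 for real entries */
	int val;
};

#define to_entry(p) ((struct entry *)((char *)(p) - offsetof(struct entry, all)))

/* Deliver up to batch live entries following the cursor, then park the cursor
 * behind the last one delivered so the next call resumes there. */
static int walk(struct list_head *head, struct entry *cursor, int batch)
{
	struct list_head *p = cursor->all.next;
	int n = 0;

	list_del(&cursor->all);
	for (; p != head && n < batch; p = p->next) {
		struct entry *e = to_entry(p);

		if (e->dead)
			continue;
		printf("visit %d\n", e->val);
		n++;
	}
	list_add_after(&cursor->all, p->prev);
	return n;
}

int main(void)
{
	struct entry e[5], cursor = { .dead = 1 };
	struct list_head head;
	int i;

	list_init(&head);
	list_add_after(&cursor.all, &head);		/* cursor starts at the front */
	for (i = 0; i < 5; i++) {
		e[i].dead = 0;
		e[i].val = i;
		list_add_after(&e[i].all, head.prev);	/* append */
	}

	while (walk(&head, &cursor, 2))			/* batches of 2, 2, 1 */
		;
	return 0;
}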
+ */ +static int xfrm_policy_match(const struct xfrm_policy *pol, + const struct flowi *fl, + u8 type, u16 family, u32 if_id) +{ + const struct xfrm_selector *sel = &pol->selector; + int ret = -ESRCH; + bool match; + + if (pol->family != family || + pol->if_id != if_id || + (fl->flowi_mark & pol->mark.m) != pol->mark.v || + pol->type != type) + return ret; + + match = xfrm_selector_match(sel, fl, family); + if (match) + ret = security_xfrm_policy_lookup(pol->security, fl->flowi_secid); + return ret; +} + +static struct xfrm_pol_inexact_node * +xfrm_policy_lookup_inexact_addr(const struct rb_root *r, + seqcount_spinlock_t *count, + const xfrm_address_t *addr, u16 family) +{ + const struct rb_node *parent; + int seq; + +again: + seq = read_seqcount_begin(count); + + parent = rcu_dereference_raw(r->rb_node); + while (parent) { + struct xfrm_pol_inexact_node *node; + int delta; + + node = rb_entry(parent, struct xfrm_pol_inexact_node, node); + + delta = xfrm_policy_addr_delta(addr, &node->addr, + node->prefixlen, family); + if (delta < 0) { + parent = rcu_dereference_raw(parent->rb_left); + continue; + } else if (delta > 0) { + parent = rcu_dereference_raw(parent->rb_right); + continue; + } + + return node; + } + + if (read_seqcount_retry(count, seq)) + goto again; + + return NULL; +} + +static bool +xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand, + struct xfrm_pol_inexact_bin *b, + const xfrm_address_t *saddr, + const xfrm_address_t *daddr) +{ + struct xfrm_pol_inexact_node *n; + u16 family; + + if (!b) + return false; + + family = b->k.family; + memset(cand, 0, sizeof(*cand)); + cand->res[XFRM_POL_CAND_ANY] = &b->hhead; + + n = xfrm_policy_lookup_inexact_addr(&b->root_d, &b->count, daddr, + family); + if (n) { + cand->res[XFRM_POL_CAND_DADDR] = &n->hhead; + n = xfrm_policy_lookup_inexact_addr(&n->root, &b->count, saddr, + family); + if (n) + cand->res[XFRM_POL_CAND_BOTH] = &n->hhead; + } + + n = xfrm_policy_lookup_inexact_addr(&b->root_s, &b->count, saddr, + family); + if (n) + cand->res[XFRM_POL_CAND_SADDR] = &n->hhead; + + return true; +} + +static struct xfrm_pol_inexact_bin * +xfrm_policy_inexact_lookup_rcu(struct net *net, u8 type, u16 family, + u8 dir, u32 if_id) +{ + struct xfrm_pol_inexact_key k = { + .family = family, + .type = type, + .dir = dir, + .if_id = if_id, + }; + + write_pnet(&k.net, net); + + return rhashtable_lookup(&xfrm_policy_inexact_table, &k, + xfrm_pol_inexact_params); +} + +static struct xfrm_pol_inexact_bin * +xfrm_policy_inexact_lookup(struct net *net, u8 type, u16 family, + u8 dir, u32 if_id) +{ + struct xfrm_pol_inexact_bin *bin; + + lockdep_assert_held(&net->xfrm.xfrm_policy_lock); + + rcu_read_lock(); + bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id); + rcu_read_unlock(); + + return bin; +} + +static struct xfrm_policy * +__xfrm_policy_eval_candidates(struct hlist_head *chain, + struct xfrm_policy *prefer, + const struct flowi *fl, + u8 type, u16 family, u32 if_id) +{ + u32 priority = prefer ? prefer->priority : ~0u; + struct xfrm_policy *pol; + + if (!chain) + return NULL; + + hlist_for_each_entry_rcu(pol, chain, bydst) { + int err; + + if (pol->priority > priority) + break; + + err = xfrm_policy_match(pol, fl, type, family, if_id); + if (err) { + if (err != -ESRCH) + return ERR_PTR(err); + + continue; + } + + if (prefer) { + /* matches. Is it older than *prefer? 
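/* Illustrative stand-alone sketch (not part of the imported file): once the
 * candidate lists are collected, xfrm_policy_eval_candidates() above keeps
 * the match with the numerically lowest priority and, on a tie, the one with
 * the smaller ->pos.  Toy version over plain arrays, ignoring selector
 * matching and the sorted-chain early exit; all names are hypothetical.
 */
#include <stdio.h>

struct pol { unsigned int priority, pos; const char *name; };

static const struct pol *better(const struct pol *prefer, const struct pol *cand)
{
	if (!prefer)
		return cand;
	if (cand->priority < prefer->priority)
		return cand;
	if (cand->priority == prefer->priority && cand->pos < prefer->pos)
		return cand;
	return prefer;
}

int main(void)
{
	const struct pol any  = { 100, 7, "any:any"   };
	const struct pol dadr = { 100, 3, "any:daddr" };
	const struct pol both = {  50, 9, "sa:da"     };
	const struct pol *cands[] = { &both, &dadr, &any };	/* one per list */
	const struct pol *best = NULL;
	unsigned int i;

	for (i = 0; i < sizeof(cands) / sizeof(cands[0]); i++)
		best = better(best, cands[i]);

	printf("winner: %s\n", best->name);	/* sa:da, lowest priority */
	return 0;
}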
*/ + if (pol->priority == priority && + prefer->pos < pol->pos) + return prefer; + } + + return pol; + } + + return NULL; +} + +static struct xfrm_policy * +xfrm_policy_eval_candidates(struct xfrm_pol_inexact_candidates *cand, + struct xfrm_policy *prefer, + const struct flowi *fl, + u8 type, u16 family, u32 if_id) +{ + struct xfrm_policy *tmp; + int i; + + for (i = 0; i < ARRAY_SIZE(cand->res); i++) { + tmp = __xfrm_policy_eval_candidates(cand->res[i], + prefer, + fl, type, family, if_id); + if (!tmp) + continue; + + if (IS_ERR(tmp)) + return tmp; + prefer = tmp; + } + + return prefer; +} + +static struct xfrm_policy *xfrm_policy_lookup_bytype(struct net *net, u8 type, + const struct flowi *fl, + u16 family, u8 dir, + u32 if_id) +{ + struct xfrm_pol_inexact_candidates cand; + const xfrm_address_t *daddr, *saddr; + struct xfrm_pol_inexact_bin *bin; + struct xfrm_policy *pol, *ret; + struct hlist_head *chain; + unsigned int sequence; + int err; + + daddr = xfrm_flowi_daddr(fl, family); + saddr = xfrm_flowi_saddr(fl, family); + if (unlikely(!daddr || !saddr)) + return NULL; + + rcu_read_lock(); + retry: + do { + sequence = read_seqcount_begin(&net->xfrm.xfrm_policy_hash_generation); + chain = policy_hash_direct(net, daddr, saddr, family, dir); + } while (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence)); + + ret = NULL; + hlist_for_each_entry_rcu(pol, chain, bydst) { + err = xfrm_policy_match(pol, fl, type, family, if_id); + if (err) { + if (err == -ESRCH) + continue; + else { + ret = ERR_PTR(err); + goto fail; + } + } else { + ret = pol; + break; + } + } + bin = xfrm_policy_inexact_lookup_rcu(net, type, family, dir, if_id); + if (!bin || !xfrm_policy_find_inexact_candidates(&cand, bin, saddr, + daddr)) + goto skip_inexact; + + pol = xfrm_policy_eval_candidates(&cand, ret, fl, type, + family, if_id); + if (pol) { + ret = pol; + if (IS_ERR(pol)) + goto fail; + } + +skip_inexact: + if (read_seqcount_retry(&net->xfrm.xfrm_policy_hash_generation, sequence)) + goto retry; + + if (ret && !xfrm_pol_hold_rcu(ret)) + goto retry; +fail: + rcu_read_unlock(); + + return ret; +} + +static struct xfrm_policy *xfrm_policy_lookup(struct net *net, + const struct flowi *fl, + u16 family, u8 dir, u32 if_id) +{ +#ifdef CONFIG_XFRM_SUB_POLICY + struct xfrm_policy *pol; + + pol = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_SUB, fl, family, + dir, if_id); + if (pol != NULL) + return pol; +#endif + return xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN, fl, family, + dir, if_id); +} + +static struct xfrm_policy *xfrm_sk_policy_lookup(const struct sock *sk, int dir, + const struct flowi *fl, + u16 family, u32 if_id) +{ + struct xfrm_policy *pol; + + rcu_read_lock(); + again: + pol = rcu_dereference(sk->sk_policy[dir]); + if (pol != NULL) { + bool match; + int err = 0; + + if (pol->family != family) { + pol = NULL; + goto out; + } + + match = xfrm_selector_match(&pol->selector, fl, family); + if (match) { + if ((READ_ONCE(sk->sk_mark) & pol->mark.m) != pol->mark.v || + pol->if_id != if_id) { + pol = NULL; + goto out; + } + err = security_xfrm_policy_lookup(pol->security, + fl->flowi_secid); + if (!err) { + if (!xfrm_pol_hold_rcu(pol)) + goto again; + } else if (err == -ESRCH) { + pol = NULL; + } else { + pol = ERR_PTR(err); + } + } else + pol = NULL; + } +out: + rcu_read_unlock(); + return pol; +} + +static void __xfrm_policy_link(struct xfrm_policy *pol, int dir) +{ + struct net *net = xp_net(pol); + + list_add(&pol->walk.all, &net->xfrm.policy_all); + net->xfrm.policy_count[dir]++; + 
xfrm_pol_hold(pol); +} + +static struct xfrm_policy *__xfrm_policy_unlink(struct xfrm_policy *pol, + int dir) +{ + struct net *net = xp_net(pol); + + if (list_empty(&pol->walk.all)) + return NULL; + + /* Socket policies are not hashed. */ + if (!hlist_unhashed(&pol->bydst)) { + hlist_del_rcu(&pol->bydst); + hlist_del_init(&pol->bydst_inexact_list); + hlist_del(&pol->byidx); + } + + list_del_init(&pol->walk.all); + net->xfrm.policy_count[dir]--; + + return pol; +} + +static void xfrm_sk_policy_link(struct xfrm_policy *pol, int dir) +{ + __xfrm_policy_link(pol, XFRM_POLICY_MAX + dir); +} + +static void xfrm_sk_policy_unlink(struct xfrm_policy *pol, int dir) +{ + __xfrm_policy_unlink(pol, XFRM_POLICY_MAX + dir); +} + +int xfrm_policy_delete(struct xfrm_policy *pol, int dir) +{ + struct net *net = xp_net(pol); + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + pol = __xfrm_policy_unlink(pol, dir); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + if (pol) { + xfrm_policy_kill(pol); + return 0; + } + return -ENOENT; +} +EXPORT_SYMBOL(xfrm_policy_delete); + +int xfrm_sk_policy_insert(struct sock *sk, int dir, struct xfrm_policy *pol) +{ + struct net *net = sock_net(sk); + struct xfrm_policy *old_pol; + +#ifdef CONFIG_XFRM_SUB_POLICY + if (pol && pol->type != XFRM_POLICY_TYPE_MAIN) + return -EINVAL; +#endif + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + old_pol = rcu_dereference_protected(sk->sk_policy[dir], + lockdep_is_held(&net->xfrm.xfrm_policy_lock)); + if (pol) { + pol->curlft.add_time = ktime_get_real_seconds(); + pol->index = xfrm_gen_index(net, XFRM_POLICY_MAX+dir, 0); + xfrm_sk_policy_link(pol, dir); + } + rcu_assign_pointer(sk->sk_policy[dir], pol); + if (old_pol) { + if (pol) + xfrm_policy_requeue(old_pol, pol); + + /* Unlinking succeeds always. This is the only function + * allowed to delete or replace socket policy. 
+ */ + xfrm_sk_policy_unlink(old_pol, dir); + } + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + if (old_pol) { + xfrm_policy_kill(old_pol); + } + return 0; +} + +static struct xfrm_policy *clone_policy(const struct xfrm_policy *old, int dir) +{ + struct xfrm_policy *newp = xfrm_policy_alloc(xp_net(old), GFP_ATOMIC); + struct net *net = xp_net(old); + + if (newp) { + newp->selector = old->selector; + if (security_xfrm_policy_clone(old->security, + &newp->security)) { + kfree(newp); + return NULL; /* ENOMEM */ + } + newp->lft = old->lft; + newp->curlft = old->curlft; + newp->mark = old->mark; + newp->if_id = old->if_id; + newp->action = old->action; + newp->flags = old->flags; + newp->xfrm_nr = old->xfrm_nr; + newp->index = old->index; + newp->type = old->type; + newp->family = old->family; + memcpy(newp->xfrm_vec, old->xfrm_vec, + newp->xfrm_nr*sizeof(struct xfrm_tmpl)); + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + xfrm_sk_policy_link(newp, dir); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + xfrm_pol_put(newp); + } + return newp; +} + +int __xfrm_sk_clone_policy(struct sock *sk, const struct sock *osk) +{ + const struct xfrm_policy *p; + struct xfrm_policy *np; + int i, ret = 0; + + rcu_read_lock(); + for (i = 0; i < 2; i++) { + p = rcu_dereference(osk->sk_policy[i]); + if (p) { + np = clone_policy(p, i); + if (unlikely(!np)) { + ret = -ENOMEM; + break; + } + rcu_assign_pointer(sk->sk_policy[i], np); + } + } + rcu_read_unlock(); + return ret; +} + +static int +xfrm_get_saddr(struct net *net, int oif, xfrm_address_t *local, + xfrm_address_t *remote, unsigned short family, u32 mark) +{ + int err; + const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family); + + if (unlikely(afinfo == NULL)) + return -EINVAL; + err = afinfo->get_saddr(net, oif, local, remote, mark); + rcu_read_unlock(); + return err; +} + +/* Resolve list of templates for the flow, given policy. */ + +static int +xfrm_tmpl_resolve_one(struct xfrm_policy *policy, const struct flowi *fl, + struct xfrm_state **xfrm, unsigned short family) +{ + struct net *net = xp_net(policy); + int nx; + int i, error; + xfrm_address_t *daddr = xfrm_flowi_daddr(fl, family); + xfrm_address_t *saddr = xfrm_flowi_saddr(fl, family); + xfrm_address_t tmp; + + for (nx = 0, i = 0; i < policy->xfrm_nr; i++) { + struct xfrm_state *x; + xfrm_address_t *remote = daddr; + xfrm_address_t *local = saddr; + struct xfrm_tmpl *tmpl = &policy->xfrm_vec[i]; + + if (tmpl->mode == XFRM_MODE_TUNNEL || + tmpl->mode == XFRM_MODE_BEET) { + remote = &tmpl->id.daddr; + local = &tmpl->saddr; + if (xfrm_addr_any(local, tmpl->encap_family)) { + error = xfrm_get_saddr(net, fl->flowi_oif, + &tmp, remote, + tmpl->encap_family, 0); + if (error) + goto fail; + local = &tmp; + } + } + + x = xfrm_state_find(remote, local, fl, tmpl, policy, &error, + family, policy->if_id); + + if (x && x->km.state == XFRM_STATE_VALID) { + xfrm[nx++] = x; + daddr = remote; + saddr = local; + continue; + } + if (x) { + error = (x->km.state == XFRM_STATE_ERROR ? + -EINVAL : -EAGAIN); + xfrm_state_put(x); + } else if (error == -ESRCH) { + error = -EAGAIN; + } + + if (!tmpl->optional) + goto fail; + } + return nx; + +fail: + for (nx--; nx >= 0; nx--) + xfrm_state_put(xfrm[nx]); + return error; +} + +static int +xfrm_tmpl_resolve(struct xfrm_policy **pols, int npols, const struct flowi *fl, + struct xfrm_state **xfrm, unsigned short family) +{ + struct xfrm_state *tp[XFRM_MAX_DEPTH]; + struct xfrm_state **tpp = (npols > 1) ? 
tp : xfrm; + int cnx = 0; + int error; + int ret; + int i; + + for (i = 0; i < npols; i++) { + if (cnx + pols[i]->xfrm_nr >= XFRM_MAX_DEPTH) { + error = -ENOBUFS; + goto fail; + } + + ret = xfrm_tmpl_resolve_one(pols[i], fl, &tpp[cnx], family); + if (ret < 0) { + error = ret; + goto fail; + } else + cnx += ret; + } + + /* found states are sorted for outbound processing */ + if (npols > 1) + xfrm_state_sort(xfrm, tpp, cnx, family); + + return cnx; + + fail: + for (cnx--; cnx >= 0; cnx--) + xfrm_state_put(tpp[cnx]); + return error; + +} + +static int xfrm_get_tos(const struct flowi *fl, int family) +{ + if (family == AF_INET) + return IPTOS_RT_MASK & fl->u.ip4.flowi4_tos; + + return 0; +} + +static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family) +{ + const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family); + struct dst_ops *dst_ops; + struct xfrm_dst *xdst; + + if (!afinfo) + return ERR_PTR(-EINVAL); + + switch (family) { + case AF_INET: + dst_ops = &net->xfrm.xfrm4_dst_ops; + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + dst_ops = &net->xfrm.xfrm6_dst_ops; + break; +#endif + default: + BUG(); + } + xdst = dst_alloc(dst_ops, NULL, 1, DST_OBSOLETE_NONE, 0); + + if (likely(xdst)) { + memset_after(xdst, 0, u.dst); + } else + xdst = ERR_PTR(-ENOBUFS); + + rcu_read_unlock(); + + return xdst; +} + +static void xfrm_init_path(struct xfrm_dst *path, struct dst_entry *dst, + int nfheader_len) +{ + if (dst->ops->family == AF_INET6) { + struct rt6_info *rt = (struct rt6_info *)dst; + path->path_cookie = rt6_get_cookie(rt); + path->u.rt6.rt6i_nfheader_len = nfheader_len; + } +} + +static inline int xfrm_fill_dst(struct xfrm_dst *xdst, struct net_device *dev, + const struct flowi *fl) +{ + const struct xfrm_policy_afinfo *afinfo = + xfrm_policy_get_afinfo(xdst->u.dst.ops->family); + int err; + + if (!afinfo) + return -EINVAL; + + err = afinfo->fill_dst(xdst, dev, fl); + + rcu_read_unlock(); + + return err; +} + + +/* Allocate chain of dst_entry's, attach known xfrm's, calculate + * all the metrics... Shortly, bundle a bundle. 
+ */ + +static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, + struct xfrm_state **xfrm, + struct xfrm_dst **bundle, + int nx, + const struct flowi *fl, + struct dst_entry *dst) +{ + const struct xfrm_state_afinfo *afinfo; + const struct xfrm_mode *inner_mode; + struct net *net = xp_net(policy); + unsigned long now = jiffies; + struct net_device *dev; + struct xfrm_dst *xdst_prev = NULL; + struct xfrm_dst *xdst0 = NULL; + int i = 0; + int err; + int header_len = 0; + int nfheader_len = 0; + int trailer_len = 0; + int tos; + int family = policy->selector.family; + xfrm_address_t saddr, daddr; + + xfrm_flowi_addr_get(fl, &saddr, &daddr, family); + + tos = xfrm_get_tos(fl, family); + + dst_hold(dst); + + for (; i < nx; i++) { + struct xfrm_dst *xdst = xfrm_alloc_dst(net, family); + struct dst_entry *dst1 = &xdst->u.dst; + + err = PTR_ERR(xdst); + if (IS_ERR(xdst)) { + dst_release(dst); + goto put_states; + } + + bundle[i] = xdst; + if (!xdst_prev) + xdst0 = xdst; + else + /* Ref count is taken during xfrm_alloc_dst() + * No need to do dst_clone() on dst1 + */ + xfrm_dst_set_child(xdst_prev, &xdst->u.dst); + + if (xfrm[i]->sel.family == AF_UNSPEC) { + inner_mode = xfrm_ip2inner_mode(xfrm[i], + xfrm_af2proto(family)); + if (!inner_mode) { + err = -EAFNOSUPPORT; + dst_release(dst); + goto put_states; + } + } else + inner_mode = &xfrm[i]->inner_mode; + + xdst->route = dst; + dst_copy_metrics(dst1, dst); + + if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) { + __u32 mark = 0; + int oif; + + if (xfrm[i]->props.smark.v || xfrm[i]->props.smark.m) + mark = xfrm_smark_get(fl->flowi_mark, xfrm[i]); + + family = xfrm[i]->props.family; + oif = fl->flowi_oif ? : fl->flowi_l3mdev; + dst = xfrm_dst_lookup(xfrm[i], tos, oif, + &saddr, &daddr, family, mark); + err = PTR_ERR(dst); + if (IS_ERR(dst)) + goto put_states; + } else + dst_hold(dst); + + dst1->xfrm = xfrm[i]; + xdst->xfrm_genid = xfrm[i]->genid; + + dst1->obsolete = DST_OBSOLETE_FORCE_CHK; + dst1->lastuse = now; + + dst1->input = dst_discard; + + rcu_read_lock(); + afinfo = xfrm_state_afinfo_get_rcu(inner_mode->family); + if (likely(afinfo)) + dst1->output = afinfo->output; + else + dst1->output = dst_discard_out; + rcu_read_unlock(); + + xdst_prev = xdst; + + header_len += xfrm[i]->props.header_len; + if (xfrm[i]->type->flags & XFRM_TYPE_NON_FRAGMENT) + nfheader_len += xfrm[i]->props.header_len; + trailer_len += xfrm[i]->props.trailer_len; + } + + xfrm_dst_set_child(xdst_prev, dst); + xdst0->path = dst; + + err = -ENODEV; + dev = dst->dev; + if (!dev) + goto free_dst; + + xfrm_init_path(xdst0, dst, nfheader_len); + xfrm_init_pmtu(bundle, nx); + + for (xdst_prev = xdst0; xdst_prev != (struct xfrm_dst *)dst; + xdst_prev = (struct xfrm_dst *) xfrm_dst_child(&xdst_prev->u.dst)) { + err = xfrm_fill_dst(xdst_prev, dev, fl); + if (err) + goto free_dst; + + xdst_prev->u.dst.header_len = header_len; + xdst_prev->u.dst.trailer_len = trailer_len; + header_len -= xdst_prev->u.dst.xfrm->props.header_len; + trailer_len -= xdst_prev->u.dst.xfrm->props.trailer_len; + } + + return &xdst0->u.dst; + +put_states: + for (; i < nx; i++) + xfrm_state_put(xfrm[i]); +free_dst: + if (xdst0) + dst_release_immediate(&xdst0->u.dst); + + return ERR_PTR(err); +} + +static int xfrm_expand_policies(const struct flowi *fl, u16 family, + struct xfrm_policy **pols, + int *num_pols, int *num_xfrms) +{ + int i; + + if (*num_pols == 0 || !pols[0]) { + *num_pols = 0; + *num_xfrms = 0; + return 0; + } + if (IS_ERR(pols[0])) { + *num_pols = 0; + return PTR_ERR(pols[0]); + } 
+ + *num_xfrms = pols[0]->xfrm_nr; + +#ifdef CONFIG_XFRM_SUB_POLICY + if (pols[0]->action == XFRM_POLICY_ALLOW && + pols[0]->type != XFRM_POLICY_TYPE_MAIN) { + pols[1] = xfrm_policy_lookup_bytype(xp_net(pols[0]), + XFRM_POLICY_TYPE_MAIN, + fl, family, + XFRM_POLICY_OUT, + pols[0]->if_id); + if (pols[1]) { + if (IS_ERR(pols[1])) { + xfrm_pols_put(pols, *num_pols); + *num_pols = 0; + return PTR_ERR(pols[1]); + } + (*num_pols)++; + (*num_xfrms) += pols[1]->xfrm_nr; + } + } +#endif + for (i = 0; i < *num_pols; i++) { + if (pols[i]->action != XFRM_POLICY_ALLOW) { + *num_xfrms = -1; + break; + } + } + + return 0; + +} + +static struct xfrm_dst * +xfrm_resolve_and_create_bundle(struct xfrm_policy **pols, int num_pols, + const struct flowi *fl, u16 family, + struct dst_entry *dst_orig) +{ + struct net *net = xp_net(pols[0]); + struct xfrm_state *xfrm[XFRM_MAX_DEPTH]; + struct xfrm_dst *bundle[XFRM_MAX_DEPTH]; + struct xfrm_dst *xdst; + struct dst_entry *dst; + int err; + + /* Try to instantiate a bundle */ + err = xfrm_tmpl_resolve(pols, num_pols, fl, xfrm, family); + if (err <= 0) { + if (err == 0) + return NULL; + + if (err != -EAGAIN) + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR); + return ERR_PTR(err); + } + + dst = xfrm_bundle_create(pols[0], xfrm, bundle, err, fl, dst_orig); + if (IS_ERR(dst)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTBUNDLEGENERROR); + return ERR_CAST(dst); + } + + xdst = (struct xfrm_dst *)dst; + xdst->num_xfrms = err; + xdst->num_pols = num_pols; + memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols); + xdst->policy_genid = atomic_read(&pols[0]->genid); + + return xdst; +} + +static void xfrm_policy_queue_process(struct timer_list *t) +{ + struct sk_buff *skb; + struct sock *sk; + struct dst_entry *dst; + struct xfrm_policy *pol = from_timer(pol, t, polq.hold_timer); + struct net *net = xp_net(pol); + struct xfrm_policy_queue *pq = &pol->polq; + struct flowi fl; + struct sk_buff_head list; + __u32 skb_mark; + + spin_lock(&pq->hold_queue.lock); + skb = skb_peek(&pq->hold_queue); + if (!skb) { + spin_unlock(&pq->hold_queue.lock); + goto out; + } + dst = skb_dst(skb); + sk = skb->sk; + + /* Fixup the mark to support VTI. */ + skb_mark = skb->mark; + skb->mark = pol->mark.v; + xfrm_decode_session(skb, &fl, dst->ops->family); + skb->mark = skb_mark; + spin_unlock(&pq->hold_queue.lock); + + dst_hold(xfrm_dst_path(dst)); + dst = xfrm_lookup(net, xfrm_dst_path(dst), &fl, sk, XFRM_LOOKUP_QUEUE); + if (IS_ERR(dst)) + goto purge_queue; + + if (dst->flags & DST_XFRM_QUEUE) { + dst_release(dst); + + if (pq->timeout >= XFRM_QUEUE_TMO_MAX) + goto purge_queue; + + pq->timeout = pq->timeout << 1; + if (!mod_timer(&pq->hold_timer, jiffies + pq->timeout)) + xfrm_pol_hold(pol); + goto out; + } + + dst_release(dst); + + __skb_queue_head_init(&list); + + spin_lock(&pq->hold_queue.lock); + pq->timeout = 0; + skb_queue_splice_init(&pq->hold_queue, &list); + spin_unlock(&pq->hold_queue.lock); + + while (!skb_queue_empty(&list)) { + skb = __skb_dequeue(&list); + + /* Fixup the mark to support VTI. 
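+		 * (Editorial note, not from upstream: the policy's mark is
+		 * written into skb->mark only for the duration of
+		 * xfrm_decode_session() so that the decoded flow key carries
+		 * the VTI device's mark; the original mark is restored
+		 * immediately afterwards.)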
*/ + skb_mark = skb->mark; + skb->mark = pol->mark.v; + xfrm_decode_session(skb, &fl, skb_dst(skb)->ops->family); + skb->mark = skb_mark; + + dst_hold(xfrm_dst_path(skb_dst(skb))); + dst = xfrm_lookup(net, xfrm_dst_path(skb_dst(skb)), &fl, skb->sk, 0); + if (IS_ERR(dst)) { + kfree_skb(skb); + continue; + } + + nf_reset_ct(skb); + skb_dst_drop(skb); + skb_dst_set(skb, dst); + + dst_output(net, skb->sk, skb); + } + +out: + xfrm_pol_put(pol); + return; + +purge_queue: + pq->timeout = 0; + skb_queue_purge(&pq->hold_queue); + xfrm_pol_put(pol); +} + +static int xdst_queue_output(struct net *net, struct sock *sk, struct sk_buff *skb) +{ + unsigned long sched_next; + struct dst_entry *dst = skb_dst(skb); + struct xfrm_dst *xdst = (struct xfrm_dst *) dst; + struct xfrm_policy *pol = xdst->pols[0]; + struct xfrm_policy_queue *pq = &pol->polq; + + if (unlikely(skb_fclone_busy(sk, skb))) { + kfree_skb(skb); + return 0; + } + + if (pq->hold_queue.qlen > XFRM_MAX_QUEUE_LEN) { + kfree_skb(skb); + return -EAGAIN; + } + + skb_dst_force(skb); + + spin_lock_bh(&pq->hold_queue.lock); + + if (!pq->timeout) + pq->timeout = XFRM_QUEUE_TMO_MIN; + + sched_next = jiffies + pq->timeout; + + if (del_timer(&pq->hold_timer)) { + if (time_before(pq->hold_timer.expires, sched_next)) + sched_next = pq->hold_timer.expires; + xfrm_pol_put(pol); + } + + __skb_queue_tail(&pq->hold_queue, skb); + if (!mod_timer(&pq->hold_timer, sched_next)) + xfrm_pol_hold(pol); + + spin_unlock_bh(&pq->hold_queue.lock); + + return 0; +} + +static struct xfrm_dst *xfrm_create_dummy_bundle(struct net *net, + struct xfrm_flo *xflo, + const struct flowi *fl, + int num_xfrms, + u16 family) +{ + int err; + struct net_device *dev; + struct dst_entry *dst; + struct dst_entry *dst1; + struct xfrm_dst *xdst; + + xdst = xfrm_alloc_dst(net, family); + if (IS_ERR(xdst)) + return xdst; + + if (!(xflo->flags & XFRM_LOOKUP_QUEUE) || + net->xfrm.sysctl_larval_drop || + num_xfrms <= 0) + return xdst; + + dst = xflo->dst_orig; + dst1 = &xdst->u.dst; + dst_hold(dst); + xdst->route = dst; + + dst_copy_metrics(dst1, dst); + + dst1->obsolete = DST_OBSOLETE_FORCE_CHK; + dst1->flags |= DST_XFRM_QUEUE; + dst1->lastuse = jiffies; + + dst1->input = dst_discard; + dst1->output = xdst_queue_output; + + dst_hold(dst); + xfrm_dst_set_child(xdst, dst); + xdst->path = dst; + + xfrm_init_path((struct xfrm_dst *)dst1, dst, 0); + + err = -ENODEV; + dev = dst->dev; + if (!dev) + goto free_dst; + + err = xfrm_fill_dst(xdst, dev, fl); + if (err) + goto free_dst; + +out: + return xdst; + +free_dst: + dst_release(dst1); + xdst = ERR_PTR(err); + goto out; +} + +static struct xfrm_dst *xfrm_bundle_lookup(struct net *net, + const struct flowi *fl, + u16 family, u8 dir, + struct xfrm_flo *xflo, u32 if_id) +{ + struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX]; + int num_pols = 0, num_xfrms = 0, err; + struct xfrm_dst *xdst; + + /* Resolve policies to use if we couldn't get them from + * previous cache entry */ + num_pols = 1; + pols[0] = xfrm_policy_lookup(net, fl, family, dir, if_id); + err = xfrm_expand_policies(fl, family, pols, + &num_pols, &num_xfrms); + if (err < 0) + goto inc_error; + if (num_pols == 0) + return NULL; + if (num_xfrms <= 0) + goto make_dummy_bundle; + + xdst = xfrm_resolve_and_create_bundle(pols, num_pols, fl, family, + xflo->dst_orig); + if (IS_ERR(xdst)) { + err = PTR_ERR(xdst); + if (err == -EREMOTE) { + xfrm_pols_put(pols, num_pols); + return NULL; + } + + if (err != -EAGAIN) + goto error; + goto make_dummy_bundle; + } else if (xdst == NULL) { + num_xfrms = 0; + 
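+		/*
+		 * (Editorial note, not from upstream: a NULL bundle here means
+		 * xfrm_tmpl_resolve() found no usable states but every template
+		 * was optional, so the flow may legitimately pass
+		 * untransformed; num_xfrms is cleared so the dummy bundle
+		 * below is not used to queue packets.)
+		 */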
goto make_dummy_bundle; + } + + return xdst; + +make_dummy_bundle: + /* We found policies, but there's no bundles to instantiate: + * either because the policy blocks, has no transformations or + * we could not build template (no xfrm_states).*/ + xdst = xfrm_create_dummy_bundle(net, xflo, fl, num_xfrms, family); + if (IS_ERR(xdst)) { + xfrm_pols_put(pols, num_pols); + return ERR_CAST(xdst); + } + xdst->num_pols = num_pols; + xdst->num_xfrms = num_xfrms; + memcpy(xdst->pols, pols, sizeof(struct xfrm_policy *) * num_pols); + + return xdst; + +inc_error: + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLERROR); +error: + xfrm_pols_put(pols, num_pols); + return ERR_PTR(err); +} + +static struct dst_entry *make_blackhole(struct net *net, u16 family, + struct dst_entry *dst_orig) +{ + const struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family); + struct dst_entry *ret; + + if (!afinfo) { + dst_release(dst_orig); + return ERR_PTR(-EINVAL); + } else { + ret = afinfo->blackhole_route(net, dst_orig); + } + rcu_read_unlock(); + + return ret; +} + +/* Finds/creates a bundle for given flow and if_id + * + * At the moment we eat a raw IP route. Mostly to speed up lookups + * on interfaces with disabled IPsec. + * + * xfrm_lookup uses an if_id of 0 by default, and is provided for + * compatibility + */ +struct dst_entry *xfrm_lookup_with_ifid(struct net *net, + struct dst_entry *dst_orig, + const struct flowi *fl, + const struct sock *sk, + int flags, u32 if_id) +{ + struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX]; + struct xfrm_dst *xdst; + struct dst_entry *dst, *route; + u16 family = dst_orig->ops->family; + u8 dir = XFRM_POLICY_OUT; + int i, err, num_pols, num_xfrms = 0, drop_pols = 0; + + dst = NULL; + xdst = NULL; + route = NULL; + + sk = sk_const_to_full_sk(sk); + if (sk && sk->sk_policy[XFRM_POLICY_OUT]) { + num_pols = 1; + pols[0] = xfrm_sk_policy_lookup(sk, XFRM_POLICY_OUT, fl, family, + if_id); + err = xfrm_expand_policies(fl, family, pols, + &num_pols, &num_xfrms); + if (err < 0) + goto dropdst; + + if (num_pols) { + if (num_xfrms <= 0) { + drop_pols = num_pols; + goto no_transform; + } + + xdst = xfrm_resolve_and_create_bundle( + pols, num_pols, fl, + family, dst_orig); + + if (IS_ERR(xdst)) { + xfrm_pols_put(pols, num_pols); + err = PTR_ERR(xdst); + if (err == -EREMOTE) + goto nopol; + + goto dropdst; + } else if (xdst == NULL) { + num_xfrms = 0; + drop_pols = num_pols; + goto no_transform; + } + + route = xdst->route; + } + } + + if (xdst == NULL) { + struct xfrm_flo xflo; + + xflo.dst_orig = dst_orig; + xflo.flags = flags; + + /* To accelerate a bit... */ + if (!if_id && ((dst_orig->flags & DST_NOXFRM) || + !net->xfrm.policy_count[XFRM_POLICY_OUT])) + goto nopol; + + xdst = xfrm_bundle_lookup(net, fl, family, dir, &xflo, if_id); + if (xdst == NULL) + goto nopol; + if (IS_ERR(xdst)) { + err = PTR_ERR(xdst); + goto dropdst; + } + + num_pols = xdst->num_pols; + num_xfrms = xdst->num_xfrms; + memcpy(pols, xdst->pols, sizeof(struct xfrm_policy *) * num_pols); + route = xdst->route; + } + + dst = &xdst->u.dst; + if (route == NULL && num_xfrms > 0) { + /* The only case when xfrm_bundle_lookup() returns a + * bundle with null route, is when the template could + * not be resolved. It means policies are there, but + * bundle could not be created, since we don't yet + * have the xfrm_state's. 
We need to wait for KM to + * negotiate new SA's or bail out with error.*/ + if (net->xfrm.sysctl_larval_drop) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES); + err = -EREMOTE; + goto error; + } + + err = -EAGAIN; + + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTNOSTATES); + goto error; + } + +no_transform: + if (num_pols == 0) + goto nopol; + + if ((flags & XFRM_LOOKUP_ICMP) && + !(pols[0]->flags & XFRM_POLICY_ICMP)) { + err = -ENOENT; + goto error; + } + + for (i = 0; i < num_pols; i++) + WRITE_ONCE(pols[i]->curlft.use_time, ktime_get_real_seconds()); + + if (num_xfrms < 0) { + /* Prohibit the flow */ + XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTPOLBLOCK); + err = -EPERM; + goto error; + } else if (num_xfrms > 0) { + /* Flow transformed */ + dst_release(dst_orig); + } else { + /* Flow passes untransformed */ + dst_release(dst); + dst = dst_orig; + } +ok: + xfrm_pols_put(pols, drop_pols); + if (dst && dst->xfrm && + dst->xfrm->props.mode == XFRM_MODE_TUNNEL) + dst->flags |= DST_XFRM_TUNNEL; + return dst; + +nopol: + if ((!dst_orig->dev || !(dst_orig->dev->flags & IFF_LOOPBACK)) && + net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) { + err = -EPERM; + goto error; + } + if (!(flags & XFRM_LOOKUP_ICMP)) { + dst = dst_orig; + goto ok; + } + err = -ENOENT; +error: + dst_release(dst); +dropdst: + if (!(flags & XFRM_LOOKUP_KEEP_DST_REF)) + dst_release(dst_orig); + xfrm_pols_put(pols, drop_pols); + return ERR_PTR(err); +} +EXPORT_SYMBOL(xfrm_lookup_with_ifid); + +/* Main function: finds/creates a bundle for given flow. + * + * At the moment we eat a raw IP route. Mostly to speed up lookups + * on interfaces with disabled IPsec. + */ +struct dst_entry *xfrm_lookup(struct net *net, struct dst_entry *dst_orig, + const struct flowi *fl, const struct sock *sk, + int flags) +{ + return xfrm_lookup_with_ifid(net, dst_orig, fl, sk, flags, 0); +} +EXPORT_SYMBOL(xfrm_lookup); + +/* Callers of xfrm_lookup_route() must ensure a call to dst_output(). + * Otherwise we may send out blackholed packets. + */ +struct dst_entry *xfrm_lookup_route(struct net *net, struct dst_entry *dst_orig, + const struct flowi *fl, + const struct sock *sk, int flags) +{ + struct dst_entry *dst = xfrm_lookup(net, dst_orig, fl, sk, + flags | XFRM_LOOKUP_QUEUE | + XFRM_LOOKUP_KEEP_DST_REF); + + if (PTR_ERR(dst) == -EREMOTE) + return make_blackhole(net, dst_orig->ops->family, dst_orig); + + if (IS_ERR(dst)) + dst_release(dst_orig); + + return dst; +} +EXPORT_SYMBOL(xfrm_lookup_route); + +static inline int +xfrm_secpath_reject(int idx, struct sk_buff *skb, const struct flowi *fl) +{ + struct sec_path *sp = skb_sec_path(skb); + struct xfrm_state *x; + + if (!sp || idx < 0 || idx >= sp->len) + return 0; + x = sp->xvec[idx]; + if (!x->type->reject) + return 0; + return x->type->reject(x, skb, fl); +} + +/* When skb is transformed back to its "native" form, we have to + * check policy restrictions. At the moment we make this in maximally + * stupid way. Shame on me. :-) Of course, connected sockets must + * have policy cached at them. 
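+ * (Editorial note, not from upstream: xfrm_state_ok() below matches a
+ * single secpath state against one policy template; xfrm_policy_ok()
+ * walks the secpath from a given index looking for such a match and
+ * encodes failures as -1 or "-2 - errored_index".)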
+ */
+
+static inline int
+xfrm_state_ok(const struct xfrm_tmpl *tmpl, const struct xfrm_state *x,
+	      unsigned short family, u32 if_id)
+{
+	if (xfrm_state_kern(x))
+		return tmpl->optional && !xfrm_state_addr_cmp(tmpl, x, tmpl->encap_family);
+	return	x->id.proto == tmpl->id.proto &&
+		(x->id.spi == tmpl->id.spi || !tmpl->id.spi) &&
+		(x->props.reqid == tmpl->reqid || !tmpl->reqid) &&
+		x->props.mode == tmpl->mode &&
+		(tmpl->allalgs || (tmpl->aalgos & (1<<x->props.aalgo)) ||
+		 !(xfrm_id_proto_match(tmpl->id.proto, IPSEC_PROTO_ANY))) &&
+		!(x->props.mode != XFRM_MODE_TRANSPORT &&
+		  xfrm_state_addr_cmp(tmpl, x, family)) &&
+		(if_id == 0 || if_id == x->if_id);
+}
+
+/*
+ * Returns 0 or a positive value when validation succeeds: either the check
+ * is bypassed because of an optional transport-mode template, or the value
+ * is the next index after the secpath state that matched the template.
+ * Returns -1 when no matching template is found.
+ * Otherwise "-2 - errored_index" is returned.
+ */
+static inline int
+xfrm_policy_ok(const struct xfrm_tmpl *tmpl, const struct sec_path *sp, int start,
+	       unsigned short family, u32 if_id)
+{
+	int idx = start;
+
+	if (tmpl->optional) {
+		if (tmpl->mode == XFRM_MODE_TRANSPORT)
+			return start;
+	} else
+		start = -1;
+	for (; idx < sp->len; idx++) {
+		if (xfrm_state_ok(tmpl, sp->xvec[idx], family, if_id))
+			return ++idx;
+		if (sp->xvec[idx]->props.mode != XFRM_MODE_TRANSPORT) {
+			if (idx < sp->verified_cnt) {
+				/* Secpath entry previously verified, consider optional and
+				 * continue searching
+				 */
+				continue;
+			}
+
+			if (start == -1)
+				start = -2-idx;
+			break;
+		}
+	}
+	return start;
+}
+
+static void
+decode_session4(struct sk_buff *skb, struct flowi *fl, bool reverse)
+{
+	const struct iphdr *iph = ip_hdr(skb);
+	int ihl = iph->ihl;
+	u8 *xprth = skb_network_header(skb) + ihl * 4;
+	struct flowi4 *fl4 = &fl->u.ip4;
+	int oif = 0;
+
+	if (skb_dst(skb) && skb_dst(skb)->dev)
+		oif = skb_dst(skb)->dev->ifindex;
+
+	memset(fl4, 0, sizeof(struct flowi4));
+	fl4->flowi4_mark = skb->mark;
+	fl4->flowi4_oif = reverse ? skb->skb_iif : oif;
+
+	fl4->flowi4_proto = iph->protocol;
+	fl4->daddr = reverse ? iph->saddr : iph->daddr;
+	fl4->saddr = reverse ?
iph->daddr : iph->saddr; + fl4->flowi4_tos = iph->tos & ~INET_ECN_MASK; + + if (!ip_is_fragment(iph)) { + switch (iph->protocol) { + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + case IPPROTO_TCP: + case IPPROTO_SCTP: + case IPPROTO_DCCP: + if (xprth + 4 < skb->data || + pskb_may_pull(skb, xprth + 4 - skb->data)) { + __be16 *ports; + + xprth = skb_network_header(skb) + ihl * 4; + ports = (__be16 *)xprth; + + fl4->fl4_sport = ports[!!reverse]; + fl4->fl4_dport = ports[!reverse]; + } + break; + case IPPROTO_ICMP: + if (xprth + 2 < skb->data || + pskb_may_pull(skb, xprth + 2 - skb->data)) { + u8 *icmp; + + xprth = skb_network_header(skb) + ihl * 4; + icmp = xprth; + + fl4->fl4_icmp_type = icmp[0]; + fl4->fl4_icmp_code = icmp[1]; + } + break; + case IPPROTO_GRE: + if (xprth + 12 < skb->data || + pskb_may_pull(skb, xprth + 12 - skb->data)) { + __be16 *greflags; + __be32 *gre_hdr; + + xprth = skb_network_header(skb) + ihl * 4; + greflags = (__be16 *)xprth; + gre_hdr = (__be32 *)xprth; + + if (greflags[0] & GRE_KEY) { + if (greflags[0] & GRE_CSUM) + gre_hdr++; + fl4->fl4_gre_key = gre_hdr[1]; + } + } + break; + default: + break; + } + } +} + +#if IS_ENABLED(CONFIG_IPV6) +static void +decode_session6(struct sk_buff *skb, struct flowi *fl, bool reverse) +{ + struct flowi6 *fl6 = &fl->u.ip6; + int onlyproto = 0; + const struct ipv6hdr *hdr = ipv6_hdr(skb); + u32 offset = sizeof(*hdr); + struct ipv6_opt_hdr *exthdr; + const unsigned char *nh = skb_network_header(skb); + u16 nhoff = IP6CB(skb)->nhoff; + int oif = 0; + u8 nexthdr; + + if (!nhoff) + nhoff = offsetof(struct ipv6hdr, nexthdr); + + nexthdr = nh[nhoff]; + + if (skb_dst(skb) && skb_dst(skb)->dev) + oif = skb_dst(skb)->dev->ifindex; + + memset(fl6, 0, sizeof(struct flowi6)); + fl6->flowi6_mark = skb->mark; + fl6->flowi6_oif = reverse ? skb->skb_iif : oif; + + fl6->daddr = reverse ? hdr->saddr : hdr->daddr; + fl6->saddr = reverse ? 
hdr->daddr : hdr->saddr; + + while (nh + offset + sizeof(*exthdr) < skb->data || + pskb_may_pull(skb, nh + offset + sizeof(*exthdr) - skb->data)) { + nh = skb_network_header(skb); + exthdr = (struct ipv6_opt_hdr *)(nh + offset); + + switch (nexthdr) { + case NEXTHDR_FRAGMENT: + onlyproto = 1; + fallthrough; + case NEXTHDR_ROUTING: + case NEXTHDR_HOP: + case NEXTHDR_DEST: + offset += ipv6_optlen(exthdr); + nexthdr = exthdr->nexthdr; + break; + case IPPROTO_UDP: + case IPPROTO_UDPLITE: + case IPPROTO_TCP: + case IPPROTO_SCTP: + case IPPROTO_DCCP: + if (!onlyproto && (nh + offset + 4 < skb->data || + pskb_may_pull(skb, nh + offset + 4 - skb->data))) { + __be16 *ports; + + nh = skb_network_header(skb); + ports = (__be16 *)(nh + offset); + fl6->fl6_sport = ports[!!reverse]; + fl6->fl6_dport = ports[!reverse]; + } + fl6->flowi6_proto = nexthdr; + return; + case IPPROTO_ICMPV6: + if (!onlyproto && (nh + offset + 2 < skb->data || + pskb_may_pull(skb, nh + offset + 2 - skb->data))) { + u8 *icmp; + + nh = skb_network_header(skb); + icmp = (u8 *)(nh + offset); + fl6->fl6_icmp_type = icmp[0]; + fl6->fl6_icmp_code = icmp[1]; + } + fl6->flowi6_proto = nexthdr; + return; + case IPPROTO_GRE: + if (!onlyproto && + (nh + offset + 12 < skb->data || + pskb_may_pull(skb, nh + offset + 12 - skb->data))) { + struct gre_base_hdr *gre_hdr; + __be32 *gre_key; + + nh = skb_network_header(skb); + gre_hdr = (struct gre_base_hdr *)(nh + offset); + gre_key = (__be32 *)(gre_hdr + 1); + + if (gre_hdr->flags & GRE_KEY) { + if (gre_hdr->flags & GRE_CSUM) + gre_key++; + fl6->fl6_gre_key = *gre_key; + } + } + fl6->flowi6_proto = nexthdr; + return; + +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case IPPROTO_MH: + offset += ipv6_optlen(exthdr); + if (!onlyproto && (nh + offset + 3 < skb->data || + pskb_may_pull(skb, nh + offset + 3 - skb->data))) { + struct ip6_mh *mh; + + nh = skb_network_header(skb); + mh = (struct ip6_mh *)(nh + offset); + fl6->fl6_mh_type = mh->ip6mh_type; + } + fl6->flowi6_proto = nexthdr; + return; +#endif + default: + fl6->flowi6_proto = nexthdr; + return; + } + } +} +#endif + +int __xfrm_decode_session(struct sk_buff *skb, struct flowi *fl, + unsigned int family, int reverse) +{ + switch (family) { + case AF_INET: + decode_session4(skb, fl, reverse); + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + decode_session6(skb, fl, reverse); + break; +#endif + default: + return -EAFNOSUPPORT; + } + + return security_xfrm_decode_session(skb, &fl->flowi_secid); +} +EXPORT_SYMBOL(__xfrm_decode_session); + +static inline int secpath_has_nontransport(const struct sec_path *sp, int k, int *idxp) +{ + for (; k < sp->len; k++) { + if (sp->xvec[k]->props.mode != XFRM_MODE_TRANSPORT) { + *idxp = k; + return 1; + } + } + + return 0; +} + +int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb, + unsigned short family) +{ + struct net *net = dev_net(skb->dev); + struct xfrm_policy *pol; + struct xfrm_policy *pols[XFRM_POLICY_TYPE_MAX]; + int npols = 0; + int xfrm_nr; + int pi; + int reverse; + struct flowi fl; + int xerr_idx = -1; + const struct xfrm_if_cb *ifcb; + struct sec_path *sp; + u32 if_id = 0; + + rcu_read_lock(); + ifcb = xfrm_if_get_cb(); + + if (ifcb) { + struct xfrm_if_decode_session_result r; + + if (ifcb->decode_session(skb, family, &r)) { + if_id = r.if_id; + net = r.net; + } + } + rcu_read_unlock(); + + reverse = dir & ~XFRM_POLICY_MASK; + dir &= XFRM_POLICY_MASK; + + if (__xfrm_decode_session(skb, &fl, family, reverse) < 0) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); + return 0; + } + 
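+	/*
+	 * (Editorial note, not from upstream: the flow key decoded above
+	 * drives the rest of the input check: it is first adjusted for NAT,
+	 * then every SA in the skb's secpath is validated against its own
+	 * selector, and finally the socket or global policy is looked up and
+	 * its templates are matched against the secpath.)
+	 */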
+ nf_nat_decode_session(skb, &fl, family); + + /* First, check used SA against their selectors. */ + sp = skb_sec_path(skb); + if (sp) { + int i; + + for (i = sp->len - 1; i >= 0; i--) { + struct xfrm_state *x = sp->xvec[i]; + if (!xfrm_selector_match(&x->sel, &fl, family)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINSTATEMISMATCH); + return 0; + } + } + } + + pol = NULL; + sk = sk_to_full_sk(sk); + if (sk && sk->sk_policy[dir]) { + pol = xfrm_sk_policy_lookup(sk, dir, &fl, family, if_id); + if (IS_ERR(pol)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR); + return 0; + } + } + + if (!pol) + pol = xfrm_policy_lookup(net, &fl, family, dir, if_id); + + if (IS_ERR(pol)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR); + return 0; + } + + if (!pol) { + if (net->xfrm.policy_default[dir] == XFRM_USERPOLICY_BLOCK) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS); + return 0; + } + + if (sp && secpath_has_nontransport(sp, 0, &xerr_idx)) { + xfrm_secpath_reject(xerr_idx, skb, &fl); + XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOPOLS); + return 0; + } + return 1; + } + + /* This lockless write can happen from different cpus. */ + WRITE_ONCE(pol->curlft.use_time, ktime_get_real_seconds()); + + pols[0] = pol; + npols++; +#ifdef CONFIG_XFRM_SUB_POLICY + if (pols[0]->type != XFRM_POLICY_TYPE_MAIN) { + pols[1] = xfrm_policy_lookup_bytype(net, XFRM_POLICY_TYPE_MAIN, + &fl, family, + XFRM_POLICY_IN, if_id); + if (pols[1]) { + if (IS_ERR(pols[1])) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLERROR); + xfrm_pol_put(pols[0]); + return 0; + } + /* This write can happen from different cpus. */ + WRITE_ONCE(pols[1]->curlft.use_time, + ktime_get_real_seconds()); + npols++; + } + } +#endif + + if (pol->action == XFRM_POLICY_ALLOW) { + static struct sec_path dummy; + struct xfrm_tmpl *tp[XFRM_MAX_DEPTH]; + struct xfrm_tmpl *stp[XFRM_MAX_DEPTH]; + struct xfrm_tmpl **tpp = tp; + int ti = 0; + int i, k; + + sp = skb_sec_path(skb); + if (!sp) + sp = &dummy; + + for (pi = 0; pi < npols; pi++) { + if (pols[pi] != pol && + pols[pi]->action != XFRM_POLICY_ALLOW) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK); + goto reject; + } + if (ti + pols[pi]->xfrm_nr >= XFRM_MAX_DEPTH) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINBUFFERERROR); + goto reject_error; + } + for (i = 0; i < pols[pi]->xfrm_nr; i++) + tpp[ti++] = &pols[pi]->xfrm_vec[i]; + } + xfrm_nr = ti; + + if (npols > 1) { + xfrm_tmpl_sort(stp, tpp, xfrm_nr, family); + tpp = stp; + } + + /* For each tunnel xfrm, find the first matching tmpl. + * For each tmpl before that, find corresponding xfrm. + * Order is _important_. Later we will implement + * some barriers, but at the moment barriers + * are implied between each two transformations. + * Upon success, marks secpath entries as having been + * verified to allow them to be skipped in future policy + * checks (e.g. nested tunnels). 
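+		 * (Editorial note, not from upstream: a negative return from
+		 * xfrm_policy_ok() aborts the check; values below -1 encode
+		 * the offending secpath index as "-2 - errored_index", which
+		 * is unpacked into xerr_idx for the reject path below.)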
+ */ + for (i = xfrm_nr-1, k = 0; i >= 0; i--) { + k = xfrm_policy_ok(tpp[i], sp, k, family, if_id); + if (k < 0) { + if (k < -1) + /* "-2 - errored_index" returned */ + xerr_idx = -(2+k); + XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH); + goto reject; + } + } + + if (secpath_has_nontransport(sp, k, &xerr_idx)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMINTMPLMISMATCH); + goto reject; + } + + xfrm_pols_put(pols, npols); + sp->verified_cnt = k; + + return 1; + } + XFRM_INC_STATS(net, LINUX_MIB_XFRMINPOLBLOCK); + +reject: + xfrm_secpath_reject(xerr_idx, skb, &fl); +reject_error: + xfrm_pols_put(pols, npols); + return 0; +} +EXPORT_SYMBOL(__xfrm_policy_check); + +int __xfrm_route_forward(struct sk_buff *skb, unsigned short family) +{ + struct net *net = dev_net(skb->dev); + struct flowi fl; + struct dst_entry *dst; + int res = 1; + + if (xfrm_decode_session(skb, &fl, family) < 0) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR); + return 0; + } + + skb_dst_force(skb); + if (!skb_dst(skb)) { + XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR); + return 0; + } + + dst = xfrm_lookup(net, skb_dst(skb), &fl, NULL, XFRM_LOOKUP_QUEUE); + if (IS_ERR(dst)) { + res = 0; + dst = NULL; + } + skb_dst_set(skb, dst); + return res; +} +EXPORT_SYMBOL(__xfrm_route_forward); + +/* Optimize later using cookies and generation ids. */ + +static struct dst_entry *xfrm_dst_check(struct dst_entry *dst, u32 cookie) +{ + /* Code (such as __xfrm4_bundle_create()) sets dst->obsolete + * to DST_OBSOLETE_FORCE_CHK to force all XFRM destinations to + * get validated by dst_ops->check on every use. We do this + * because when a normal route referenced by an XFRM dst is + * obsoleted we do not go looking around for all parent + * referencing XFRM dsts so that we can invalidate them. It + * is just too much work. Instead we make the checks here on + * every use. For example: + * + * XFRM dst A --> IPv4 dst X + * + * X is the "xdst->route" of A (X is also the "dst->path" of A + * in this example). If X is marked obsolete, "A" will not + * notice. That's what we are validating here via the + * stale_bundle() check. + * + * When a dst is removed from the fib tree, DST_OBSOLETE_DEAD will + * be marked on it. + * This will force stale_bundle() to fail on any xdst bundle with + * this dst linked in it. + */ + if (dst->obsolete < 0 && !stale_bundle(dst)) + return dst; + + return NULL; +} + +static int stale_bundle(struct dst_entry *dst) +{ + return !xfrm_bundle_ok((struct xfrm_dst *)dst); +} + +void xfrm_dst_ifdown(struct dst_entry *dst, struct net_device *dev) +{ + while ((dst = xfrm_dst_child(dst)) && dst->xfrm && dst->dev == dev) { + dst->dev = blackhole_netdev; + dev_hold(dst->dev); + dev_put(dev); + } +} +EXPORT_SYMBOL(xfrm_dst_ifdown); + +static void xfrm_link_failure(struct sk_buff *skb) +{ + /* Impossible. Such dst must be popped before reaches point of failure. 
*/ +} + +static struct dst_entry *xfrm_negative_advice(struct dst_entry *dst) +{ + if (dst) { + if (dst->obsolete) { + dst_release(dst); + dst = NULL; + } + } + return dst; +} + +static void xfrm_init_pmtu(struct xfrm_dst **bundle, int nr) +{ + while (nr--) { + struct xfrm_dst *xdst = bundle[nr]; + u32 pmtu, route_mtu_cached; + struct dst_entry *dst; + + dst = &xdst->u.dst; + pmtu = dst_mtu(xfrm_dst_child(dst)); + xdst->child_mtu_cached = pmtu; + + pmtu = xfrm_state_mtu(dst->xfrm, pmtu); + + route_mtu_cached = dst_mtu(xdst->route); + xdst->route_mtu_cached = route_mtu_cached; + + if (pmtu > route_mtu_cached) + pmtu = route_mtu_cached; + + dst_metric_set(dst, RTAX_MTU, pmtu); + } +} + +/* Check that the bundle accepts the flow and its components are + * still valid. + */ + +static int xfrm_bundle_ok(struct xfrm_dst *first) +{ + struct xfrm_dst *bundle[XFRM_MAX_DEPTH]; + struct dst_entry *dst = &first->u.dst; + struct xfrm_dst *xdst; + int start_from, nr; + u32 mtu; + + if (!dst_check(xfrm_dst_path(dst), ((struct xfrm_dst *)dst)->path_cookie) || + (dst->dev && !netif_running(dst->dev))) + return 0; + + if (dst->flags & DST_XFRM_QUEUE) + return 1; + + start_from = nr = 0; + do { + struct xfrm_dst *xdst = (struct xfrm_dst *)dst; + + if (dst->xfrm->km.state != XFRM_STATE_VALID) + return 0; + if (xdst->xfrm_genid != dst->xfrm->genid) + return 0; + if (xdst->num_pols > 0 && + xdst->policy_genid != atomic_read(&xdst->pols[0]->genid)) + return 0; + + bundle[nr++] = xdst; + + mtu = dst_mtu(xfrm_dst_child(dst)); + if (xdst->child_mtu_cached != mtu) { + start_from = nr; + xdst->child_mtu_cached = mtu; + } + + if (!dst_check(xdst->route, xdst->route_cookie)) + return 0; + mtu = dst_mtu(xdst->route); + if (xdst->route_mtu_cached != mtu) { + start_from = nr; + xdst->route_mtu_cached = mtu; + } + + dst = xfrm_dst_child(dst); + } while (dst->xfrm); + + if (likely(!start_from)) + return 1; + + xdst = bundle[start_from - 1]; + mtu = xdst->child_mtu_cached; + while (start_from--) { + dst = &xdst->u.dst; + + mtu = xfrm_state_mtu(dst->xfrm, mtu); + if (mtu > xdst->route_mtu_cached) + mtu = xdst->route_mtu_cached; + dst_metric_set(dst, RTAX_MTU, mtu); + if (!start_from) + break; + + xdst = bundle[start_from - 1]; + xdst->child_mtu_cached = mtu; + } + + return 1; +} + +static unsigned int xfrm_default_advmss(const struct dst_entry *dst) +{ + return dst_metric_advmss(xfrm_dst_path(dst)); +} + +static unsigned int xfrm_mtu(const struct dst_entry *dst) +{ + unsigned int mtu = dst_metric_raw(dst, RTAX_MTU); + + return mtu ? 
: dst_mtu(xfrm_dst_path(dst)); +} + +static const void *xfrm_get_dst_nexthop(const struct dst_entry *dst, + const void *daddr) +{ + while (dst->xfrm) { + const struct xfrm_state *xfrm = dst->xfrm; + + dst = xfrm_dst_child(dst); + + if (xfrm->props.mode == XFRM_MODE_TRANSPORT) + continue; + if (xfrm->type->flags & XFRM_TYPE_REMOTE_COADDR) + daddr = xfrm->coaddr; + else if (!(xfrm->type->flags & XFRM_TYPE_LOCAL_COADDR)) + daddr = &xfrm->id.daddr; + } + return daddr; +} + +static struct neighbour *xfrm_neigh_lookup(const struct dst_entry *dst, + struct sk_buff *skb, + const void *daddr) +{ + const struct dst_entry *path = xfrm_dst_path(dst); + + if (!skb) + daddr = xfrm_get_dst_nexthop(dst, daddr); + return path->ops->neigh_lookup(path, skb, daddr); +} + +static void xfrm_confirm_neigh(const struct dst_entry *dst, const void *daddr) +{ + const struct dst_entry *path = xfrm_dst_path(dst); + + daddr = xfrm_get_dst_nexthop(dst, daddr); + path->ops->confirm_neigh(path, daddr); +} + +int xfrm_policy_register_afinfo(const struct xfrm_policy_afinfo *afinfo, int family) +{ + int err = 0; + + if (WARN_ON(family >= ARRAY_SIZE(xfrm_policy_afinfo))) + return -EAFNOSUPPORT; + + spin_lock(&xfrm_policy_afinfo_lock); + if (unlikely(xfrm_policy_afinfo[family] != NULL)) + err = -EEXIST; + else { + struct dst_ops *dst_ops = afinfo->dst_ops; + if (likely(dst_ops->kmem_cachep == NULL)) + dst_ops->kmem_cachep = xfrm_dst_cache; + if (likely(dst_ops->check == NULL)) + dst_ops->check = xfrm_dst_check; + if (likely(dst_ops->default_advmss == NULL)) + dst_ops->default_advmss = xfrm_default_advmss; + if (likely(dst_ops->mtu == NULL)) + dst_ops->mtu = xfrm_mtu; + if (likely(dst_ops->negative_advice == NULL)) + dst_ops->negative_advice = xfrm_negative_advice; + if (likely(dst_ops->link_failure == NULL)) + dst_ops->link_failure = xfrm_link_failure; + if (likely(dst_ops->neigh_lookup == NULL)) + dst_ops->neigh_lookup = xfrm_neigh_lookup; + if (likely(!dst_ops->confirm_neigh)) + dst_ops->confirm_neigh = xfrm_confirm_neigh; + rcu_assign_pointer(xfrm_policy_afinfo[family], afinfo); + } + spin_unlock(&xfrm_policy_afinfo_lock); + + return err; +} +EXPORT_SYMBOL(xfrm_policy_register_afinfo); + +void xfrm_policy_unregister_afinfo(const struct xfrm_policy_afinfo *afinfo) +{ + struct dst_ops *dst_ops = afinfo->dst_ops; + int i; + + for (i = 0; i < ARRAY_SIZE(xfrm_policy_afinfo); i++) { + if (xfrm_policy_afinfo[i] != afinfo) + continue; + RCU_INIT_POINTER(xfrm_policy_afinfo[i], NULL); + break; + } + + synchronize_rcu(); + + dst_ops->kmem_cachep = NULL; + dst_ops->check = NULL; + dst_ops->negative_advice = NULL; + dst_ops->link_failure = NULL; +} +EXPORT_SYMBOL(xfrm_policy_unregister_afinfo); + +void xfrm_if_register_cb(const struct xfrm_if_cb *ifcb) +{ + spin_lock(&xfrm_if_cb_lock); + rcu_assign_pointer(xfrm_if_cb, ifcb); + spin_unlock(&xfrm_if_cb_lock); +} +EXPORT_SYMBOL(xfrm_if_register_cb); + +void xfrm_if_unregister_cb(void) +{ + RCU_INIT_POINTER(xfrm_if_cb, NULL); + synchronize_rcu(); +} +EXPORT_SYMBOL(xfrm_if_unregister_cb); + +#ifdef CONFIG_XFRM_STATISTICS +static int __net_init xfrm_statistics_init(struct net *net) +{ + int rv; + net->mib.xfrm_statistics = alloc_percpu(struct linux_xfrm_mib); + if (!net->mib.xfrm_statistics) + return -ENOMEM; + rv = xfrm_proc_init(net); + if (rv < 0) + free_percpu(net->mib.xfrm_statistics); + return rv; +} + +static void xfrm_statistics_fini(struct net *net) +{ + xfrm_proc_fini(net); + free_percpu(net->mib.xfrm_statistics); +} +#else +static int __net_init xfrm_statistics_init(struct net 
*net) +{ + return 0; +} + +static void xfrm_statistics_fini(struct net *net) +{ +} +#endif + +static int __net_init xfrm_policy_init(struct net *net) +{ + unsigned int hmask, sz; + int dir, err; + + if (net_eq(net, &init_net)) { + xfrm_dst_cache = kmem_cache_create("xfrm_dst_cache", + sizeof(struct xfrm_dst), + 0, SLAB_HWCACHE_ALIGN|SLAB_PANIC, + NULL); + err = rhashtable_init(&xfrm_policy_inexact_table, + &xfrm_pol_inexact_params); + BUG_ON(err); + } + + hmask = 8 - 1; + sz = (hmask+1) * sizeof(struct hlist_head); + + net->xfrm.policy_byidx = xfrm_hash_alloc(sz); + if (!net->xfrm.policy_byidx) + goto out_byidx; + net->xfrm.policy_idx_hmask = hmask; + + for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { + struct xfrm_policy_hash *htab; + + net->xfrm.policy_count[dir] = 0; + net->xfrm.policy_count[XFRM_POLICY_MAX + dir] = 0; + INIT_HLIST_HEAD(&net->xfrm.policy_inexact[dir]); + + htab = &net->xfrm.policy_bydst[dir]; + htab->table = xfrm_hash_alloc(sz); + if (!htab->table) + goto out_bydst; + htab->hmask = hmask; + htab->dbits4 = 32; + htab->sbits4 = 32; + htab->dbits6 = 128; + htab->sbits6 = 128; + } + net->xfrm.policy_hthresh.lbits4 = 32; + net->xfrm.policy_hthresh.rbits4 = 32; + net->xfrm.policy_hthresh.lbits6 = 128; + net->xfrm.policy_hthresh.rbits6 = 128; + + seqlock_init(&net->xfrm.policy_hthresh.lock); + + INIT_LIST_HEAD(&net->xfrm.policy_all); + INIT_LIST_HEAD(&net->xfrm.inexact_bins); + INIT_WORK(&net->xfrm.policy_hash_work, xfrm_hash_resize); + INIT_WORK(&net->xfrm.policy_hthresh.work, xfrm_hash_rebuild); + return 0; + +out_bydst: + for (dir--; dir >= 0; dir--) { + struct xfrm_policy_hash *htab; + + htab = &net->xfrm.policy_bydst[dir]; + xfrm_hash_free(htab->table, sz); + } + xfrm_hash_free(net->xfrm.policy_byidx, sz); +out_byidx: + return -ENOMEM; +} + +static void xfrm_policy_fini(struct net *net) +{ + struct xfrm_pol_inexact_bin *b, *t; + unsigned int sz; + int dir; + + flush_work(&net->xfrm.policy_hash_work); +#ifdef CONFIG_XFRM_SUB_POLICY + xfrm_policy_flush(net, XFRM_POLICY_TYPE_SUB, false); +#endif + xfrm_policy_flush(net, XFRM_POLICY_TYPE_MAIN, false); + + WARN_ON(!list_empty(&net->xfrm.policy_all)); + + for (dir = 0; dir < XFRM_POLICY_MAX; dir++) { + struct xfrm_policy_hash *htab; + + WARN_ON(!hlist_empty(&net->xfrm.policy_inexact[dir])); + + htab = &net->xfrm.policy_bydst[dir]; + sz = (htab->hmask + 1) * sizeof(struct hlist_head); + WARN_ON(!hlist_empty(htab->table)); + xfrm_hash_free(htab->table, sz); + } + + sz = (net->xfrm.policy_idx_hmask + 1) * sizeof(struct hlist_head); + WARN_ON(!hlist_empty(net->xfrm.policy_byidx)); + xfrm_hash_free(net->xfrm.policy_byidx, sz); + + spin_lock_bh(&net->xfrm.xfrm_policy_lock); + list_for_each_entry_safe(b, t, &net->xfrm.inexact_bins, inexact_bins) + __xfrm_policy_inexact_prune_bin(b, true); + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); +} + +static int __net_init xfrm_net_init(struct net *net) +{ + int rv; + + /* Initialize the per-net locks here */ + spin_lock_init(&net->xfrm.xfrm_state_lock); + spin_lock_init(&net->xfrm.xfrm_policy_lock); + seqcount_spinlock_init(&net->xfrm.xfrm_policy_hash_generation, &net->xfrm.xfrm_policy_lock); + mutex_init(&net->xfrm.xfrm_cfg_mutex); + net->xfrm.policy_default[XFRM_POLICY_IN] = XFRM_USERPOLICY_ACCEPT; + net->xfrm.policy_default[XFRM_POLICY_FWD] = XFRM_USERPOLICY_ACCEPT; + net->xfrm.policy_default[XFRM_POLICY_OUT] = XFRM_USERPOLICY_ACCEPT; + + rv = xfrm_statistics_init(net); + if (rv < 0) + goto out_statistics; + rv = xfrm_state_init(net); + if (rv < 0) + goto out_state; + rv = 
xfrm_policy_init(net); + if (rv < 0) + goto out_policy; + rv = xfrm_sysctl_init(net); + if (rv < 0) + goto out_sysctl; + + return 0; + +out_sysctl: + xfrm_policy_fini(net); +out_policy: + xfrm_state_fini(net); +out_state: + xfrm_statistics_fini(net); +out_statistics: + return rv; +} + +static void __net_exit xfrm_net_exit(struct net *net) +{ + xfrm_sysctl_fini(net); + xfrm_policy_fini(net); + xfrm_state_fini(net); + xfrm_statistics_fini(net); +} + +static struct pernet_operations __net_initdata xfrm_net_ops = { + .init = xfrm_net_init, + .exit = xfrm_net_exit, +}; + +void __init xfrm_init(void) +{ + register_pernet_subsys(&xfrm_net_ops); + xfrm_dev_init(); + xfrm_input_init(); + +#ifdef CONFIG_XFRM_ESPINTCP + espintcp_init(); +#endif +} + +#ifdef CONFIG_AUDITSYSCALL +static void xfrm_audit_common_policyinfo(struct xfrm_policy *xp, + struct audit_buffer *audit_buf) +{ + struct xfrm_sec_ctx *ctx = xp->security; + struct xfrm_selector *sel = &xp->selector; + + if (ctx) + audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s", + ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str); + + switch (sel->family) { + case AF_INET: + audit_log_format(audit_buf, " src=%pI4", &sel->saddr.a4); + if (sel->prefixlen_s != 32) + audit_log_format(audit_buf, " src_prefixlen=%d", + sel->prefixlen_s); + audit_log_format(audit_buf, " dst=%pI4", &sel->daddr.a4); + if (sel->prefixlen_d != 32) + audit_log_format(audit_buf, " dst_prefixlen=%d", + sel->prefixlen_d); + break; + case AF_INET6: + audit_log_format(audit_buf, " src=%pI6", sel->saddr.a6); + if (sel->prefixlen_s != 128) + audit_log_format(audit_buf, " src_prefixlen=%d", + sel->prefixlen_s); + audit_log_format(audit_buf, " dst=%pI6", sel->daddr.a6); + if (sel->prefixlen_d != 128) + audit_log_format(audit_buf, " dst_prefixlen=%d", + sel->prefixlen_d); + break; + } +} + +void xfrm_audit_policy_add(struct xfrm_policy *xp, int result, bool task_valid) +{ + struct audit_buffer *audit_buf; + + audit_buf = xfrm_audit_start("SPD-add"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_usrinfo(task_valid, audit_buf); + audit_log_format(audit_buf, " res=%u", result); + xfrm_audit_common_policyinfo(xp, audit_buf); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_policy_add); + +void xfrm_audit_policy_delete(struct xfrm_policy *xp, int result, + bool task_valid) +{ + struct audit_buffer *audit_buf; + + audit_buf = xfrm_audit_start("SPD-delete"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_usrinfo(task_valid, audit_buf); + audit_log_format(audit_buf, " res=%u", result); + xfrm_audit_common_policyinfo(xp, audit_buf); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_policy_delete); +#endif + +#ifdef CONFIG_XFRM_MIGRATE +static bool xfrm_migrate_selector_match(const struct xfrm_selector *sel_cmp, + const struct xfrm_selector *sel_tgt) +{ + if (sel_cmp->proto == IPSEC_ULPROTO_ANY) { + if (sel_tgt->family == sel_cmp->family && + xfrm_addr_equal(&sel_tgt->daddr, &sel_cmp->daddr, + sel_cmp->family) && + xfrm_addr_equal(&sel_tgt->saddr, &sel_cmp->saddr, + sel_cmp->family) && + sel_tgt->prefixlen_d == sel_cmp->prefixlen_d && + sel_tgt->prefixlen_s == sel_cmp->prefixlen_s) { + return true; + } + } else { + if (memcmp(sel_tgt, sel_cmp, sizeof(*sel_tgt)) == 0) { + return true; + } + } + return false; +} + +static struct xfrm_policy *xfrm_migrate_policy_find(const struct xfrm_selector *sel, + u8 dir, u8 type, struct net *net, u32 if_id) +{ + struct xfrm_policy *pol, *ret = NULL; + struct hlist_head *chain; + u32 priority = ~0U; + + 
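+	/*
+	 * (Editorial note, not from upstream: the search below is two-stage;
+	 * the exact-selector hash chain is scanned first, then the inexact
+	 * list, and a matching inexact policy with a lower ->priority value
+	 * (higher precedence) may override the hashed match.)
+	 */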
spin_lock_bh(&net->xfrm.xfrm_policy_lock); + chain = policy_hash_direct(net, &sel->daddr, &sel->saddr, sel->family, dir); + hlist_for_each_entry(pol, chain, bydst) { + if ((if_id == 0 || pol->if_id == if_id) && + xfrm_migrate_selector_match(sel, &pol->selector) && + pol->type == type) { + ret = pol; + priority = ret->priority; + break; + } + } + chain = &net->xfrm.policy_inexact[dir]; + hlist_for_each_entry(pol, chain, bydst_inexact_list) { + if ((pol->priority >= priority) && ret) + break; + + if ((if_id == 0 || pol->if_id == if_id) && + xfrm_migrate_selector_match(sel, &pol->selector) && + pol->type == type) { + ret = pol; + break; + } + } + + xfrm_pol_hold(ret); + + spin_unlock_bh(&net->xfrm.xfrm_policy_lock); + + return ret; +} + +static int migrate_tmpl_match(const struct xfrm_migrate *m, const struct xfrm_tmpl *t) +{ + int match = 0; + + if (t->mode == m->mode && t->id.proto == m->proto && + (m->reqid == 0 || t->reqid == m->reqid)) { + switch (t->mode) { + case XFRM_MODE_TUNNEL: + case XFRM_MODE_BEET: + if (xfrm_addr_equal(&t->id.daddr, &m->old_daddr, + m->old_family) && + xfrm_addr_equal(&t->saddr, &m->old_saddr, + m->old_family)) { + match = 1; + } + break; + case XFRM_MODE_TRANSPORT: + /* in case of transport mode, template does not store + any IP addresses, hence we just compare mode and + protocol */ + match = 1; + break; + default: + break; + } + } + return match; +} + +/* update endpoint address(es) of template(s) */ +static int xfrm_policy_migrate(struct xfrm_policy *pol, + struct xfrm_migrate *m, int num_migrate) +{ + struct xfrm_migrate *mp; + int i, j, n = 0; + + write_lock_bh(&pol->lock); + if (unlikely(pol->walk.dead)) { + /* target policy has been deleted */ + write_unlock_bh(&pol->lock); + return -ENOENT; + } + + for (i = 0; i < pol->xfrm_nr; i++) { + for (j = 0, mp = m; j < num_migrate; j++, mp++) { + if (!migrate_tmpl_match(mp, &pol->xfrm_vec[i])) + continue; + n++; + if (pol->xfrm_vec[i].mode != XFRM_MODE_TUNNEL && + pol->xfrm_vec[i].mode != XFRM_MODE_BEET) + continue; + /* update endpoints */ + memcpy(&pol->xfrm_vec[i].id.daddr, &mp->new_daddr, + sizeof(pol->xfrm_vec[i].id.daddr)); + memcpy(&pol->xfrm_vec[i].saddr, &mp->new_saddr, + sizeof(pol->xfrm_vec[i].saddr)); + pol->xfrm_vec[i].encap_family = mp->new_family; + /* flush bundles */ + atomic_inc(&pol->genid); + } + } + + write_unlock_bh(&pol->lock); + + if (!n) + return -ENODATA; + + return 0; +} + +static int xfrm_migrate_check(const struct xfrm_migrate *m, int num_migrate) +{ + int i, j; + + if (num_migrate < 1 || num_migrate > XFRM_MAX_DEPTH) + return -EINVAL; + + for (i = 0; i < num_migrate; i++) { + if (xfrm_addr_any(&m[i].new_daddr, m[i].new_family) || + xfrm_addr_any(&m[i].new_saddr, m[i].new_family)) + return -EINVAL; + + /* check if there is any duplicated entry */ + for (j = i + 1; j < num_migrate; j++) { + if (!memcmp(&m[i].old_daddr, &m[j].old_daddr, + sizeof(m[i].old_daddr)) && + !memcmp(&m[i].old_saddr, &m[j].old_saddr, + sizeof(m[i].old_saddr)) && + m[i].proto == m[j].proto && + m[i].mode == m[j].mode && + m[i].reqid == m[j].reqid && + m[i].old_family == m[j].old_family) + return -EINVAL; + } + } + + return 0; +} + +int xfrm_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + struct xfrm_migrate *m, int num_migrate, + struct xfrm_kmaddress *k, struct net *net, + struct xfrm_encap_tmpl *encap, u32 if_id) +{ + int i, err, nx_cur = 0, nx_new = 0; + struct xfrm_policy *pol = NULL; + struct xfrm_state *x, *xc; + struct xfrm_state *x_cur[XFRM_MAX_DEPTH]; + struct xfrm_state 
*x_new[XFRM_MAX_DEPTH]; + struct xfrm_migrate *mp; + + /* Stage 0 - sanity checks */ + if ((err = xfrm_migrate_check(m, num_migrate)) < 0) + goto out; + + if (dir >= XFRM_POLICY_MAX) { + err = -EINVAL; + goto out; + } + + /* Stage 1 - find policy */ + if ((pol = xfrm_migrate_policy_find(sel, dir, type, net, if_id)) == NULL) { + err = -ENOENT; + goto out; + } + + /* Stage 2 - find and update state(s) */ + for (i = 0, mp = m; i < num_migrate; i++, mp++) { + if ((x = xfrm_migrate_state_find(mp, net, if_id))) { + x_cur[nx_cur] = x; + nx_cur++; + xc = xfrm_state_migrate(x, mp, encap); + if (xc) { + x_new[nx_new] = xc; + nx_new++; + } else { + err = -ENODATA; + goto restore_state; + } + } + } + + /* Stage 3 - update policy */ + if ((err = xfrm_policy_migrate(pol, m, num_migrate)) < 0) + goto restore_state; + + /* Stage 4 - delete old state(s) */ + if (nx_cur) { + xfrm_states_put(x_cur, nx_cur); + xfrm_states_delete(x_cur, nx_cur); + } + + /* Stage 5 - announce */ + km_migrate(sel, dir, type, m, num_migrate, k, encap); + + xfrm_pol_put(pol); + + return 0; +out: + return err; + +restore_state: + if (pol) + xfrm_pol_put(pol); + if (nx_cur) + xfrm_states_put(x_cur, nx_cur); + if (nx_new) + xfrm_states_delete(x_new, nx_new); + + return err; +} +EXPORT_SYMBOL(xfrm_migrate); +#endif diff --git a/net/xfrm/xfrm_proc.c b/net/xfrm/xfrm_proc.c new file mode 100644 index 000000000..fee9b5cf3 --- /dev/null +++ b/net/xfrm/xfrm_proc.c @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +/* + * xfrm_proc.c + * + * Copyright (C)2006-2007 USAGI/WIDE Project + * + * Authors: Masahide NAKAMURA + */ +#include +#include +#include +#include +#include + +static const struct snmp_mib xfrm_mib_list[] = { + SNMP_MIB_ITEM("XfrmInError", LINUX_MIB_XFRMINERROR), + SNMP_MIB_ITEM("XfrmInBufferError", LINUX_MIB_XFRMINBUFFERERROR), + SNMP_MIB_ITEM("XfrmInHdrError", LINUX_MIB_XFRMINHDRERROR), + SNMP_MIB_ITEM("XfrmInNoStates", LINUX_MIB_XFRMINNOSTATES), + SNMP_MIB_ITEM("XfrmInStateProtoError", LINUX_MIB_XFRMINSTATEPROTOERROR), + SNMP_MIB_ITEM("XfrmInStateModeError", LINUX_MIB_XFRMINSTATEMODEERROR), + SNMP_MIB_ITEM("XfrmInStateSeqError", LINUX_MIB_XFRMINSTATESEQERROR), + SNMP_MIB_ITEM("XfrmInStateExpired", LINUX_MIB_XFRMINSTATEEXPIRED), + SNMP_MIB_ITEM("XfrmInStateMismatch", LINUX_MIB_XFRMINSTATEMISMATCH), + SNMP_MIB_ITEM("XfrmInStateInvalid", LINUX_MIB_XFRMINSTATEINVALID), + SNMP_MIB_ITEM("XfrmInTmplMismatch", LINUX_MIB_XFRMINTMPLMISMATCH), + SNMP_MIB_ITEM("XfrmInNoPols", LINUX_MIB_XFRMINNOPOLS), + SNMP_MIB_ITEM("XfrmInPolBlock", LINUX_MIB_XFRMINPOLBLOCK), + SNMP_MIB_ITEM("XfrmInPolError", LINUX_MIB_XFRMINPOLERROR), + SNMP_MIB_ITEM("XfrmOutError", LINUX_MIB_XFRMOUTERROR), + SNMP_MIB_ITEM("XfrmOutBundleGenError", LINUX_MIB_XFRMOUTBUNDLEGENERROR), + SNMP_MIB_ITEM("XfrmOutBundleCheckError", LINUX_MIB_XFRMOUTBUNDLECHECKERROR), + SNMP_MIB_ITEM("XfrmOutNoStates", LINUX_MIB_XFRMOUTNOSTATES), + SNMP_MIB_ITEM("XfrmOutStateProtoError", LINUX_MIB_XFRMOUTSTATEPROTOERROR), + SNMP_MIB_ITEM("XfrmOutStateModeError", LINUX_MIB_XFRMOUTSTATEMODEERROR), + SNMP_MIB_ITEM("XfrmOutStateSeqError", LINUX_MIB_XFRMOUTSTATESEQERROR), + SNMP_MIB_ITEM("XfrmOutStateExpired", LINUX_MIB_XFRMOUTSTATEEXPIRED), + SNMP_MIB_ITEM("XfrmOutPolBlock", LINUX_MIB_XFRMOUTPOLBLOCK), + SNMP_MIB_ITEM("XfrmOutPolDead", LINUX_MIB_XFRMOUTPOLDEAD), + SNMP_MIB_ITEM("XfrmOutPolError", LINUX_MIB_XFRMOUTPOLERROR), + SNMP_MIB_ITEM("XfrmFwdHdrError", LINUX_MIB_XFRMFWDHDRERROR), + SNMP_MIB_ITEM("XfrmOutStateInvalid", LINUX_MIB_XFRMOUTSTATEINVALID), + 
SNMP_MIB_ITEM("XfrmAcquireError", LINUX_MIB_XFRMACQUIREERROR), + SNMP_MIB_SENTINEL +}; + +static int xfrm_statistics_seq_show(struct seq_file *seq, void *v) +{ + unsigned long buff[LINUX_MIB_XFRMMAX]; + struct net *net = seq->private; + int i; + + memset(buff, 0, sizeof(unsigned long) * LINUX_MIB_XFRMMAX); + + snmp_get_cpu_field_batch(buff, xfrm_mib_list, + net->mib.xfrm_statistics); + for (i = 0; xfrm_mib_list[i].name; i++) + seq_printf(seq, "%-24s\t%lu\n", xfrm_mib_list[i].name, + buff[i]); + + return 0; +} + +int __net_init xfrm_proc_init(struct net *net) +{ + if (!proc_create_net_single("xfrm_stat", 0444, net->proc_net, + xfrm_statistics_seq_show, NULL)) + return -ENOMEM; + return 0; +} + +void xfrm_proc_fini(struct net *net) +{ + remove_proc_entry("xfrm_stat", net->proc_net); +} diff --git a/net/xfrm/xfrm_replay.c b/net/xfrm/xfrm_replay.c new file mode 100644 index 000000000..ce56d659c --- /dev/null +++ b/net/xfrm/xfrm_replay.c @@ -0,0 +1,795 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xfrm_replay.c - xfrm replay detection, derived from xfrm_state.c. + * + * Copyright (C) 2010 secunet Security Networks AG + * Copyright (C) 2010 Steffen Klassert + */ + +#include +#include + +u32 xfrm_replay_seqhi(struct xfrm_state *x, __be32 net_seq) +{ + u32 seq, seq_hi, bottom; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + + if (!(x->props.flags & XFRM_STATE_ESN)) + return 0; + + seq = ntohl(net_seq); + seq_hi = replay_esn->seq_hi; + bottom = replay_esn->seq - replay_esn->replay_window + 1; + + if (likely(replay_esn->seq >= replay_esn->replay_window - 1)) { + /* A. same subspace */ + if (unlikely(seq < bottom)) + seq_hi++; + } else { + /* B. window spans two subspaces */ + if (unlikely(seq >= bottom)) + seq_hi--; + } + + return seq_hi; +} +EXPORT_SYMBOL(xfrm_replay_seqhi); + +static void xfrm_replay_notify_bmp(struct xfrm_state *x, int event); +static void xfrm_replay_notify_esn(struct xfrm_state *x, int event); + +void xfrm_replay_notify(struct xfrm_state *x, int event) +{ + struct km_event c; + /* we send notify messages in case + * 1. we updated on of the sequence numbers, and the seqno difference + * is at least x->replay_maxdiff, in this case we also update the + * timeout of our timer function + * 2. if x->replay_maxage has elapsed since last update, + * and there were changes + * + * The state structure must be locked! 
+ */ + + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + xfrm_replay_notify_bmp(x, event); + return; + case XFRM_REPLAY_MODE_ESN: + xfrm_replay_notify_esn(x, event); + return; + } + + switch (event) { + case XFRM_REPLAY_UPDATE: + if (!x->replay_maxdiff || + ((x->replay.seq - x->preplay.seq < x->replay_maxdiff) && + (x->replay.oseq - x->preplay.oseq < x->replay_maxdiff))) { + if (x->xflags & XFRM_TIME_DEFER) + event = XFRM_REPLAY_TIMEOUT; + else + return; + } + + break; + + case XFRM_REPLAY_TIMEOUT: + if (memcmp(&x->replay, &x->preplay, + sizeof(struct xfrm_replay_state)) == 0) { + x->xflags |= XFRM_TIME_DEFER; + return; + } + + break; + } + + memcpy(&x->preplay, &x->replay, sizeof(struct xfrm_replay_state)); + c.event = XFRM_MSG_NEWAE; + c.data.aevent = event; + km_state_notify(x, &c); + + if (x->replay_maxage && + !mod_timer(&x->rtimer, jiffies + x->replay_maxage)) + x->xflags &= ~XFRM_TIME_DEFER; +} + +static int __xfrm_replay_overflow(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct net *net = xs_net(x); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + XFRM_SKB_CB(skb)->seq.output.low = ++x->replay.oseq; + XFRM_SKB_CB(skb)->seq.output.hi = 0; + if (unlikely(x->replay.oseq == 0) && + !(x->props.extra_flags & XFRM_SA_XFLAG_OSEQ_MAY_WRAP)) { + x->replay.oseq--; + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } + if (xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +static int xfrm_replay_check_legacy(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + u32 diff; + u32 seq = ntohl(net_seq); + + if (!x->props.replay_window) + return 0; + + if (unlikely(seq == 0)) + goto err; + + if (likely(seq > x->replay.seq)) + return 0; + + diff = x->replay.seq - seq; + if (diff >= x->props.replay_window) { + x->stats.replay_window++; + goto err; + } + + if (x->replay.bitmap & (1U << diff)) { + x->stats.replay++; + goto err; + } + return 0; + +err: + xfrm_audit_state_replay(x, skb, net_seq); + return -EINVAL; +} + +static void xfrm_replay_advance_bmp(struct xfrm_state *x, __be32 net_seq); +static void xfrm_replay_advance_esn(struct xfrm_state *x, __be32 net_seq); + +void xfrm_replay_advance(struct xfrm_state *x, __be32 net_seq) +{ + u32 diff, seq; + + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + return xfrm_replay_advance_bmp(x, net_seq); + case XFRM_REPLAY_MODE_ESN: + return xfrm_replay_advance_esn(x, net_seq); + } + + if (!x->props.replay_window) + return; + + seq = ntohl(net_seq); + if (seq > x->replay.seq) { + diff = seq - x->replay.seq; + if (diff < x->props.replay_window) + x->replay.bitmap = ((x->replay.bitmap) << diff) | 1; + else + x->replay.bitmap = 1; + x->replay.seq = seq; + } else { + diff = x->replay.seq - seq; + x->replay.bitmap |= (1U << diff); + } + + if (xfrm_aevent_is_on(xs_net(x))) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); +} + +static int xfrm_replay_overflow_bmp(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct net *net = xs_net(x); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + XFRM_SKB_CB(skb)->seq.output.low = ++replay_esn->oseq; + XFRM_SKB_CB(skb)->seq.output.hi = 0; + if (unlikely(replay_esn->oseq == 0) && + !(x->props.extra_flags & XFRM_SA_XFLAG_OSEQ_MAY_WRAP)) { + replay_esn->oseq--; + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } + if 
(xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +static int xfrm_replay_check_bmp(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + unsigned int bitnr, nr; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + u32 pos; + u32 seq = ntohl(net_seq); + u32 diff = replay_esn->seq - seq; + + if (!replay_esn->replay_window) + return 0; + + if (unlikely(seq == 0)) + goto err; + + if (likely(seq > replay_esn->seq)) + return 0; + + if (diff >= replay_esn->replay_window) { + x->stats.replay_window++; + goto err; + } + + pos = (replay_esn->seq - 1) % replay_esn->replay_window; + + if (pos >= diff) + bitnr = (pos - diff) % replay_esn->replay_window; + else + bitnr = replay_esn->replay_window - (diff - pos); + + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + if (replay_esn->bmp[nr] & (1U << bitnr)) + goto err_replay; + + return 0; + +err_replay: + x->stats.replay++; +err: + xfrm_audit_state_replay(x, skb, net_seq); + return -EINVAL; +} + +static void xfrm_replay_advance_bmp(struct xfrm_state *x, __be32 net_seq) +{ + unsigned int bitnr, nr, i; + u32 diff; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + u32 seq = ntohl(net_seq); + u32 pos; + + if (!replay_esn->replay_window) + return; + + pos = (replay_esn->seq - 1) % replay_esn->replay_window; + + if (seq > replay_esn->seq) { + diff = seq - replay_esn->seq; + + if (diff < replay_esn->replay_window) { + for (i = 1; i < diff; i++) { + bitnr = (pos + i) % replay_esn->replay_window; + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + replay_esn->bmp[nr] &= ~(1U << bitnr); + } + } else { + nr = (replay_esn->replay_window - 1) >> 5; + for (i = 0; i <= nr; i++) + replay_esn->bmp[i] = 0; + } + + bitnr = (pos + diff) % replay_esn->replay_window; + replay_esn->seq = seq; + } else { + diff = replay_esn->seq - seq; + + if (pos >= diff) + bitnr = (pos - diff) % replay_esn->replay_window; + else + bitnr = replay_esn->replay_window - (diff - pos); + } + + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + replay_esn->bmp[nr] |= (1U << bitnr); + + if (xfrm_aevent_is_on(xs_net(x))) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); +} + +static void xfrm_replay_notify_bmp(struct xfrm_state *x, int event) +{ + struct km_event c; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct xfrm_replay_state_esn *preplay_esn = x->preplay_esn; + + /* we send notify messages in case + * 1. we updated on of the sequence numbers, and the seqno difference + * is at least x->replay_maxdiff, in this case we also update the + * timeout of our timer function + * 2. if x->replay_maxage has elapsed since last update, + * and there were changes + * + * The state structure must be locked! 
+ */ + + switch (event) { + case XFRM_REPLAY_UPDATE: + if (!x->replay_maxdiff || + ((replay_esn->seq - preplay_esn->seq < x->replay_maxdiff) && + (replay_esn->oseq - preplay_esn->oseq + < x->replay_maxdiff))) { + if (x->xflags & XFRM_TIME_DEFER) + event = XFRM_REPLAY_TIMEOUT; + else + return; + } + + break; + + case XFRM_REPLAY_TIMEOUT: + if (memcmp(x->replay_esn, x->preplay_esn, + xfrm_replay_state_esn_len(replay_esn)) == 0) { + x->xflags |= XFRM_TIME_DEFER; + return; + } + + break; + } + + memcpy(x->preplay_esn, x->replay_esn, + xfrm_replay_state_esn_len(replay_esn)); + c.event = XFRM_MSG_NEWAE; + c.data.aevent = event; + km_state_notify(x, &c); + + if (x->replay_maxage && + !mod_timer(&x->rtimer, jiffies + x->replay_maxage)) + x->xflags &= ~XFRM_TIME_DEFER; +} + +static void xfrm_replay_notify_esn(struct xfrm_state *x, int event) +{ + u32 seq_diff, oseq_diff; + struct km_event c; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct xfrm_replay_state_esn *preplay_esn = x->preplay_esn; + + /* we send notify messages in case + * 1. we updated on of the sequence numbers, and the seqno difference + * is at least x->replay_maxdiff, in this case we also update the + * timeout of our timer function + * 2. if x->replay_maxage has elapsed since last update, + * and there were changes + * + * The state structure must be locked! + */ + + switch (event) { + case XFRM_REPLAY_UPDATE: + if (x->replay_maxdiff) { + if (replay_esn->seq_hi == preplay_esn->seq_hi) + seq_diff = replay_esn->seq - preplay_esn->seq; + else + seq_diff = ~preplay_esn->seq + replay_esn->seq + + 1; + + if (replay_esn->oseq_hi == preplay_esn->oseq_hi) + oseq_diff = replay_esn->oseq + - preplay_esn->oseq; + else + oseq_diff = ~preplay_esn->oseq + + replay_esn->oseq + 1; + + if (seq_diff >= x->replay_maxdiff || + oseq_diff >= x->replay_maxdiff) + break; + } + + if (x->xflags & XFRM_TIME_DEFER) + event = XFRM_REPLAY_TIMEOUT; + else + return; + + break; + + case XFRM_REPLAY_TIMEOUT: + if (memcmp(x->replay_esn, x->preplay_esn, + xfrm_replay_state_esn_len(replay_esn)) == 0) { + x->xflags |= XFRM_TIME_DEFER; + return; + } + + break; + } + + memcpy(x->preplay_esn, x->replay_esn, + xfrm_replay_state_esn_len(replay_esn)); + c.event = XFRM_MSG_NEWAE; + c.data.aevent = event; + km_state_notify(x, &c); + + if (x->replay_maxage && + !mod_timer(&x->rtimer, jiffies + x->replay_maxage)) + x->xflags &= ~XFRM_TIME_DEFER; +} + +static int xfrm_replay_overflow_esn(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct net *net = xs_net(x); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + XFRM_SKB_CB(skb)->seq.output.low = ++replay_esn->oseq; + XFRM_SKB_CB(skb)->seq.output.hi = replay_esn->oseq_hi; + + if (unlikely(replay_esn->oseq == 0)) { + XFRM_SKB_CB(skb)->seq.output.hi = ++replay_esn->oseq_hi; + + if (replay_esn->oseq_hi == 0) { + replay_esn->oseq--; + replay_esn->oseq_hi--; + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } + } + if (xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +static int xfrm_replay_check_esn(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + unsigned int bitnr, nr; + u32 diff; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + u32 pos; + u32 seq = ntohl(net_seq); + u32 wsize = replay_esn->replay_window; + u32 top = replay_esn->seq; + u32 bottom = top - wsize + 1; + + if (!wsize) + return 0; + + if (unlikely(seq == 0 && 
replay_esn->seq_hi == 0 && + (replay_esn->seq < replay_esn->replay_window - 1))) + goto err; + + diff = top - seq; + + if (likely(top >= wsize - 1)) { + /* A. same subspace */ + if (likely(seq > top) || seq < bottom) + return 0; + } else { + /* B. window spans two subspaces */ + if (likely(seq > top && seq < bottom)) + return 0; + if (seq >= bottom) + diff = ~seq + top + 1; + } + + if (diff >= replay_esn->replay_window) { + x->stats.replay_window++; + goto err; + } + + pos = (replay_esn->seq - 1) % replay_esn->replay_window; + + if (pos >= diff) + bitnr = (pos - diff) % replay_esn->replay_window; + else + bitnr = replay_esn->replay_window - (diff - pos); + + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + if (replay_esn->bmp[nr] & (1U << bitnr)) + goto err_replay; + + return 0; + +err_replay: + x->stats.replay++; +err: + xfrm_audit_state_replay(x, skb, net_seq); + return -EINVAL; +} + +int xfrm_replay_check(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + return xfrm_replay_check_bmp(x, skb, net_seq); + case XFRM_REPLAY_MODE_ESN: + return xfrm_replay_check_esn(x, skb, net_seq); + } + + return xfrm_replay_check_legacy(x, skb, net_seq); +} + +static int xfrm_replay_recheck_esn(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + if (unlikely(XFRM_SKB_CB(skb)->seq.input.hi != + htonl(xfrm_replay_seqhi(x, net_seq)))) { + x->stats.replay_window++; + return -EINVAL; + } + + return xfrm_replay_check_esn(x, skb, net_seq); +} + +int xfrm_replay_recheck(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + /* no special recheck treatment */ + return xfrm_replay_check_bmp(x, skb, net_seq); + case XFRM_REPLAY_MODE_ESN: + return xfrm_replay_recheck_esn(x, skb, net_seq); + } + + return xfrm_replay_check_legacy(x, skb, net_seq); +} + +static void xfrm_replay_advance_esn(struct xfrm_state *x, __be32 net_seq) +{ + unsigned int bitnr, nr, i; + int wrap; + u32 diff, pos, seq, seq_hi; + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + + if (!replay_esn->replay_window) + return; + + seq = ntohl(net_seq); + pos = (replay_esn->seq - 1) % replay_esn->replay_window; + seq_hi = xfrm_replay_seqhi(x, net_seq); + wrap = seq_hi - replay_esn->seq_hi; + + if ((!wrap && seq > replay_esn->seq) || wrap > 0) { + if (likely(!wrap)) + diff = seq - replay_esn->seq; + else + diff = ~replay_esn->seq + seq + 1; + + if (diff < replay_esn->replay_window) { + for (i = 1; i < diff; i++) { + bitnr = (pos + i) % replay_esn->replay_window; + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + replay_esn->bmp[nr] &= ~(1U << bitnr); + } + } else { + nr = (replay_esn->replay_window - 1) >> 5; + for (i = 0; i <= nr; i++) + replay_esn->bmp[i] = 0; + } + + bitnr = (pos + diff) % replay_esn->replay_window; + replay_esn->seq = seq; + + if (unlikely(wrap > 0)) + replay_esn->seq_hi++; + } else { + diff = replay_esn->seq - seq; + + if (pos >= diff) + bitnr = (pos - diff) % replay_esn->replay_window; + else + bitnr = replay_esn->replay_window - (diff - pos); + } + + xfrm_dev_state_advance_esn(x); + + nr = bitnr >> 5; + bitnr = bitnr & 0x1F; + replay_esn->bmp[nr] |= (1U << bitnr); + + if (xfrm_aevent_is_on(xs_net(x))) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); +} + +#ifdef CONFIG_XFRM_OFFLOAD +static int xfrm_replay_overflow_offload(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct net *net = 
xs_net(x); + struct xfrm_offload *xo = xfrm_offload(skb); + __u32 oseq = x->replay.oseq; + + if (!xo) + return __xfrm_replay_overflow(x, skb); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + if (!skb_is_gso(skb)) { + XFRM_SKB_CB(skb)->seq.output.low = ++oseq; + xo->seq.low = oseq; + } else { + XFRM_SKB_CB(skb)->seq.output.low = oseq + 1; + xo->seq.low = oseq + 1; + oseq += skb_shinfo(skb)->gso_segs; + } + + XFRM_SKB_CB(skb)->seq.output.hi = 0; + xo->seq.hi = 0; + if (unlikely(oseq < x->replay.oseq) && + !(x->props.extra_flags & XFRM_SA_XFLAG_OSEQ_MAY_WRAP)) { + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } + + x->replay.oseq = oseq; + + if (xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +static int xfrm_replay_overflow_offload_bmp(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct xfrm_offload *xo = xfrm_offload(skb); + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct net *net = xs_net(x); + __u32 oseq = replay_esn->oseq; + + if (!xo) + return xfrm_replay_overflow_bmp(x, skb); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + if (!skb_is_gso(skb)) { + XFRM_SKB_CB(skb)->seq.output.low = ++oseq; + xo->seq.low = oseq; + } else { + XFRM_SKB_CB(skb)->seq.output.low = oseq + 1; + xo->seq.low = oseq + 1; + oseq += skb_shinfo(skb)->gso_segs; + } + + XFRM_SKB_CB(skb)->seq.output.hi = 0; + xo->seq.hi = 0; + if (unlikely(oseq < replay_esn->oseq) && + !(x->props.extra_flags & XFRM_SA_XFLAG_OSEQ_MAY_WRAP)) { + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } else { + replay_esn->oseq = oseq; + } + + if (xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +static int xfrm_replay_overflow_offload_esn(struct xfrm_state *x, struct sk_buff *skb) +{ + int err = 0; + struct xfrm_offload *xo = xfrm_offload(skb); + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + struct net *net = xs_net(x); + __u32 oseq = replay_esn->oseq; + __u32 oseq_hi = replay_esn->oseq_hi; + + if (!xo) + return xfrm_replay_overflow_esn(x, skb); + + if (x->type->flags & XFRM_TYPE_REPLAY_PROT) { + if (!skb_is_gso(skb)) { + XFRM_SKB_CB(skb)->seq.output.low = ++oseq; + XFRM_SKB_CB(skb)->seq.output.hi = oseq_hi; + xo->seq.low = oseq; + xo->seq.hi = oseq_hi; + } else { + XFRM_SKB_CB(skb)->seq.output.low = oseq + 1; + XFRM_SKB_CB(skb)->seq.output.hi = oseq_hi; + xo->seq.low = oseq + 1; + xo->seq.hi = oseq_hi; + oseq += skb_shinfo(skb)->gso_segs; + } + + if (unlikely(xo->seq.low < replay_esn->oseq)) { + XFRM_SKB_CB(skb)->seq.output.hi = ++oseq_hi; + xo->seq.hi = oseq_hi; + replay_esn->oseq_hi = oseq_hi; + if (replay_esn->oseq_hi == 0) { + replay_esn->oseq--; + replay_esn->oseq_hi--; + xfrm_audit_state_replay_overflow(x, skb); + err = -EOVERFLOW; + + return err; + } + } + + replay_esn->oseq = oseq; + + if (xfrm_aevent_is_on(net)) + xfrm_replay_notify(x, XFRM_REPLAY_UPDATE); + } + + return err; +} + +int xfrm_replay_overflow(struct xfrm_state *x, struct sk_buff *skb) +{ + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + return xfrm_replay_overflow_offload_bmp(x, skb); + case XFRM_REPLAY_MODE_ESN: + return xfrm_replay_overflow_offload_esn(x, skb); + } + + return xfrm_replay_overflow_offload(x, skb); +} +#else +int xfrm_replay_overflow(struct xfrm_state *x, struct sk_buff *skb) +{ + switch (x->repl_mode) { + case XFRM_REPLAY_MODE_LEGACY: + break; + case XFRM_REPLAY_MODE_BMP: + return 
xfrm_replay_overflow_bmp(x, skb); + case XFRM_REPLAY_MODE_ESN: + return xfrm_replay_overflow_esn(x, skb); + } + + return __xfrm_replay_overflow(x, skb); +} +#endif + +int xfrm_init_replay(struct xfrm_state *x, struct netlink_ext_ack *extack) +{ + struct xfrm_replay_state_esn *replay_esn = x->replay_esn; + + if (replay_esn) { + if (replay_esn->replay_window > + replay_esn->bmp_len * sizeof(__u32) * 8) { + NL_SET_ERR_MSG(extack, "ESN replay window is too large for the chosen bitmap size"); + return -EINVAL; + } + + if (x->props.flags & XFRM_STATE_ESN) { + if (replay_esn->replay_window == 0) { + NL_SET_ERR_MSG(extack, "ESN replay window must be > 0"); + return -EINVAL; + } + x->repl_mode = XFRM_REPLAY_MODE_ESN; + } else { + x->repl_mode = XFRM_REPLAY_MODE_BMP; + } + } else { + x->repl_mode = XFRM_REPLAY_MODE_LEGACY; + } + + return 0; +} +EXPORT_SYMBOL(xfrm_init_replay); diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c new file mode 100644 index 000000000..2f4cf976b --- /dev/null +++ b/net/xfrm/xfrm_state.c @@ -0,0 +1,2933 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* + * xfrm_state.c + * + * Changes: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * IPv6 support + * YOSHIFUJI Hideaki @USAGI + * Split up af-specific functions + * Derek Atkins + * Add UDP Encapsulation + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "xfrm_hash.h" + +#define xfrm_state_deref_prot(table, net) \ + rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock)) + +static void xfrm_state_gc_task(struct work_struct *work); + +/* Each xfrm_state may be linked to two tables: + + 1. Hash table by (spi,daddr,ah/esp) to find SA by SPI. (input,ctl) + 2. Hash table by (daddr,family,reqid) to find what SAs exist for given + destination/tunnel endpoint. 
(output) + */ + +static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024; +static struct kmem_cache *xfrm_state_cache __ro_after_init; + +static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task); +static HLIST_HEAD(xfrm_state_gc_list); + +static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x) +{ + return refcount_inc_not_zero(&x->refcnt); +} + +static inline unsigned int xfrm_dst_hash(struct net *net, + const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + u32 reqid, + unsigned short family) +{ + return __xfrm_dst_hash(daddr, saddr, reqid, family, net->xfrm.state_hmask); +} + +static inline unsigned int xfrm_src_hash(struct net *net, + const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + unsigned short family) +{ + return __xfrm_src_hash(daddr, saddr, family, net->xfrm.state_hmask); +} + +static inline unsigned int +xfrm_spi_hash(struct net *net, const xfrm_address_t *daddr, + __be32 spi, u8 proto, unsigned short family) +{ + return __xfrm_spi_hash(daddr, spi, proto, family, net->xfrm.state_hmask); +} + +static unsigned int xfrm_seq_hash(struct net *net, u32 seq) +{ + return __xfrm_seq_hash(seq, net->xfrm.state_hmask); +} + +static void xfrm_hash_transfer(struct hlist_head *list, + struct hlist_head *ndsttable, + struct hlist_head *nsrctable, + struct hlist_head *nspitable, + struct hlist_head *nseqtable, + unsigned int nhashmask) +{ + struct hlist_node *tmp; + struct xfrm_state *x; + + hlist_for_each_entry_safe(x, tmp, list, bydst) { + unsigned int h; + + h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr, + x->props.reqid, x->props.family, + nhashmask); + hlist_add_head_rcu(&x->bydst, ndsttable + h); + + h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr, + x->props.family, + nhashmask); + hlist_add_head_rcu(&x->bysrc, nsrctable + h); + + if (x->id.spi) { + h = __xfrm_spi_hash(&x->id.daddr, x->id.spi, + x->id.proto, x->props.family, + nhashmask); + hlist_add_head_rcu(&x->byspi, nspitable + h); + } + + if (x->km.seq) { + h = __xfrm_seq_hash(x->km.seq, nhashmask); + hlist_add_head_rcu(&x->byseq, nseqtable + h); + } + } +} + +static unsigned long xfrm_hash_new_size(unsigned int state_hmask) +{ + return ((state_hmask + 1) << 1) * sizeof(struct hlist_head); +} + +static void xfrm_hash_resize(struct work_struct *work) +{ + struct net *net = container_of(work, struct net, xfrm.state_hash_work); + struct hlist_head *ndst, *nsrc, *nspi, *nseq, *odst, *osrc, *ospi, *oseq; + unsigned long nsize, osize; + unsigned int nhashmask, ohashmask; + int i; + + nsize = xfrm_hash_new_size(net->xfrm.state_hmask); + ndst = xfrm_hash_alloc(nsize); + if (!ndst) + return; + nsrc = xfrm_hash_alloc(nsize); + if (!nsrc) { + xfrm_hash_free(ndst, nsize); + return; + } + nspi = xfrm_hash_alloc(nsize); + if (!nspi) { + xfrm_hash_free(ndst, nsize); + xfrm_hash_free(nsrc, nsize); + return; + } + nseq = xfrm_hash_alloc(nsize); + if (!nseq) { + xfrm_hash_free(ndst, nsize); + xfrm_hash_free(nsrc, nsize); + xfrm_hash_free(nspi, nsize); + return; + } + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + write_seqcount_begin(&net->xfrm.xfrm_state_hash_generation); + + nhashmask = (nsize / sizeof(struct hlist_head)) - 1U; + odst = xfrm_state_deref_prot(net->xfrm.state_bydst, net); + for (i = net->xfrm.state_hmask; i >= 0; i--) + xfrm_hash_transfer(odst + i, ndst, nsrc, nspi, nseq, nhashmask); + + osrc = xfrm_state_deref_prot(net->xfrm.state_bysrc, net); + ospi = xfrm_state_deref_prot(net->xfrm.state_byspi, net); + oseq = xfrm_state_deref_prot(net->xfrm.state_byseq, net); + ohashmask = 
net->xfrm.state_hmask; + + rcu_assign_pointer(net->xfrm.state_bydst, ndst); + rcu_assign_pointer(net->xfrm.state_bysrc, nsrc); + rcu_assign_pointer(net->xfrm.state_byspi, nspi); + rcu_assign_pointer(net->xfrm.state_byseq, nseq); + net->xfrm.state_hmask = nhashmask; + + write_seqcount_end(&net->xfrm.xfrm_state_hash_generation); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + osize = (ohashmask + 1) * sizeof(struct hlist_head); + + synchronize_rcu(); + + xfrm_hash_free(odst, osize); + xfrm_hash_free(osrc, osize); + xfrm_hash_free(ospi, osize); + xfrm_hash_free(oseq, osize); +} + +static DEFINE_SPINLOCK(xfrm_state_afinfo_lock); +static struct xfrm_state_afinfo __rcu *xfrm_state_afinfo[NPROTO]; + +static DEFINE_SPINLOCK(xfrm_state_gc_lock); + +int __xfrm_state_delete(struct xfrm_state *x); + +int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol); +static bool km_is_alive(const struct km_event *c); +void km_state_expired(struct xfrm_state *x, int hard, u32 portid); + +int xfrm_register_type(const struct xfrm_type *type, unsigned short family) +{ + struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family); + int err = 0; + + if (!afinfo) + return -EAFNOSUPPORT; + +#define X(afi, T, name) do { \ + WARN_ON((afi)->type_ ## name); \ + (afi)->type_ ## name = (T); \ + } while (0) + + switch (type->proto) { + case IPPROTO_COMP: + X(afinfo, type, comp); + break; + case IPPROTO_AH: + X(afinfo, type, ah); + break; + case IPPROTO_ESP: + X(afinfo, type, esp); + break; + case IPPROTO_IPIP: + X(afinfo, type, ipip); + break; + case IPPROTO_DSTOPTS: + X(afinfo, type, dstopts); + break; + case IPPROTO_ROUTING: + X(afinfo, type, routing); + break; + case IPPROTO_IPV6: + X(afinfo, type, ipip6); + break; + default: + WARN_ON(1); + err = -EPROTONOSUPPORT; + break; + } +#undef X + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(xfrm_register_type); + +void xfrm_unregister_type(const struct xfrm_type *type, unsigned short family) +{ + struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family); + + if (unlikely(afinfo == NULL)) + return; + +#define X(afi, T, name) do { \ + WARN_ON((afi)->type_ ## name != (T)); \ + (afi)->type_ ## name = NULL; \ + } while (0) + + switch (type->proto) { + case IPPROTO_COMP: + X(afinfo, type, comp); + break; + case IPPROTO_AH: + X(afinfo, type, ah); + break; + case IPPROTO_ESP: + X(afinfo, type, esp); + break; + case IPPROTO_IPIP: + X(afinfo, type, ipip); + break; + case IPPROTO_DSTOPTS: + X(afinfo, type, dstopts); + break; + case IPPROTO_ROUTING: + X(afinfo, type, routing); + break; + case IPPROTO_IPV6: + X(afinfo, type, ipip6); + break; + default: + WARN_ON(1); + break; + } +#undef X + rcu_read_unlock(); +} +EXPORT_SYMBOL(xfrm_unregister_type); + +static const struct xfrm_type *xfrm_get_type(u8 proto, unsigned short family) +{ + const struct xfrm_type *type = NULL; + struct xfrm_state_afinfo *afinfo; + int modload_attempted = 0; + +retry: + afinfo = xfrm_state_get_afinfo(family); + if (unlikely(afinfo == NULL)) + return NULL; + + switch (proto) { + case IPPROTO_COMP: + type = afinfo->type_comp; + break; + case IPPROTO_AH: + type = afinfo->type_ah; + break; + case IPPROTO_ESP: + type = afinfo->type_esp; + break; + case IPPROTO_IPIP: + type = afinfo->type_ipip; + break; + case IPPROTO_DSTOPTS: + type = afinfo->type_dstopts; + break; + case IPPROTO_ROUTING: + type = afinfo->type_routing; + break; + case IPPROTO_IPV6: + type = afinfo->type_ipip6; + break; + default: + break; + } + + if (unlikely(type && !try_module_get(type->owner))) + type = 
NULL; + + rcu_read_unlock(); + + if (!type && !modload_attempted) { + request_module("xfrm-type-%d-%d", family, proto); + modload_attempted = 1; + goto retry; + } + + return type; +} + +static void xfrm_put_type(const struct xfrm_type *type) +{ + module_put(type->owner); +} + +int xfrm_register_type_offload(const struct xfrm_type_offload *type, + unsigned short family) +{ + struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family); + int err = 0; + + if (unlikely(afinfo == NULL)) + return -EAFNOSUPPORT; + + switch (type->proto) { + case IPPROTO_ESP: + WARN_ON(afinfo->type_offload_esp); + afinfo->type_offload_esp = type; + break; + default: + WARN_ON(1); + err = -EPROTONOSUPPORT; + break; + } + + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(xfrm_register_type_offload); + +void xfrm_unregister_type_offload(const struct xfrm_type_offload *type, + unsigned short family) +{ + struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family); + + if (unlikely(afinfo == NULL)) + return; + + switch (type->proto) { + case IPPROTO_ESP: + WARN_ON(afinfo->type_offload_esp != type); + afinfo->type_offload_esp = NULL; + break; + default: + WARN_ON(1); + break; + } + rcu_read_unlock(); +} +EXPORT_SYMBOL(xfrm_unregister_type_offload); + +static const struct xfrm_type_offload * +xfrm_get_type_offload(u8 proto, unsigned short family, bool try_load) +{ + const struct xfrm_type_offload *type = NULL; + struct xfrm_state_afinfo *afinfo; + +retry: + afinfo = xfrm_state_get_afinfo(family); + if (unlikely(afinfo == NULL)) + return NULL; + + switch (proto) { + case IPPROTO_ESP: + type = afinfo->type_offload_esp; + break; + default: + break; + } + + if ((type && !try_module_get(type->owner))) + type = NULL; + + rcu_read_unlock(); + + if (!type && try_load) { + request_module("xfrm-offload-%d-%d", family, proto); + try_load = false; + goto retry; + } + + return type; +} + +static void xfrm_put_type_offload(const struct xfrm_type_offload *type) +{ + module_put(type->owner); +} + +static const struct xfrm_mode xfrm4_mode_map[XFRM_MODE_MAX] = { + [XFRM_MODE_BEET] = { + .encap = XFRM_MODE_BEET, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET, + }, + [XFRM_MODE_TRANSPORT] = { + .encap = XFRM_MODE_TRANSPORT, + .family = AF_INET, + }, + [XFRM_MODE_TUNNEL] = { + .encap = XFRM_MODE_TUNNEL, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET, + }, +}; + +static const struct xfrm_mode xfrm6_mode_map[XFRM_MODE_MAX] = { + [XFRM_MODE_BEET] = { + .encap = XFRM_MODE_BEET, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET6, + }, + [XFRM_MODE_ROUTEOPTIMIZATION] = { + .encap = XFRM_MODE_ROUTEOPTIMIZATION, + .family = AF_INET6, + }, + [XFRM_MODE_TRANSPORT] = { + .encap = XFRM_MODE_TRANSPORT, + .family = AF_INET6, + }, + [XFRM_MODE_TUNNEL] = { + .encap = XFRM_MODE_TUNNEL, + .flags = XFRM_MODE_FLAG_TUNNEL, + .family = AF_INET6, + }, +}; + +static const struct xfrm_mode *xfrm_get_mode(unsigned int encap, int family) +{ + const struct xfrm_mode *mode; + + if (unlikely(encap >= XFRM_MODE_MAX)) + return NULL; + + switch (family) { + case AF_INET: + mode = &xfrm4_mode_map[encap]; + if (mode->family == family) + return mode; + break; + case AF_INET6: + mode = &xfrm6_mode_map[encap]; + if (mode->family == family) + return mode; + break; + default: + break; + } + + return NULL; +} + +void xfrm_state_free(struct xfrm_state *x) +{ + kmem_cache_free(xfrm_state_cache, x); +} +EXPORT_SYMBOL(xfrm_state_free); + +static void ___xfrm_state_destroy(struct xfrm_state *x) +{ + hrtimer_cancel(&x->mtimer); + del_timer_sync(&x->rtimer); 
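+	/* Both timers are stopped at this point; release the negotiated
+	 * algorithm data, encapsulation/replay state and any type/offload
+	 * references before the state memory goes back to the slab cache.
+	 */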
+ kfree(x->aead); + kfree(x->aalg); + kfree(x->ealg); + kfree(x->calg); + kfree(x->encap); + kfree(x->coaddr); + kfree(x->replay_esn); + kfree(x->preplay_esn); + if (x->type_offload) + xfrm_put_type_offload(x->type_offload); + if (x->type) { + x->type->destructor(x); + xfrm_put_type(x->type); + } + if (x->xfrag.page) + put_page(x->xfrag.page); + xfrm_dev_state_free(x); + security_xfrm_state_free(x); + xfrm_state_free(x); +} + +static void xfrm_state_gc_task(struct work_struct *work) +{ + struct xfrm_state *x; + struct hlist_node *tmp; + struct hlist_head gc_list; + + spin_lock_bh(&xfrm_state_gc_lock); + hlist_move_list(&xfrm_state_gc_list, &gc_list); + spin_unlock_bh(&xfrm_state_gc_lock); + + synchronize_rcu(); + + hlist_for_each_entry_safe(x, tmp, &gc_list, gclist) + ___xfrm_state_destroy(x); +} + +static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me) +{ + struct xfrm_state *x = container_of(me, struct xfrm_state, mtimer); + enum hrtimer_restart ret = HRTIMER_NORESTART; + time64_t now = ktime_get_real_seconds(); + time64_t next = TIME64_MAX; + int warn = 0; + int err = 0; + + spin_lock(&x->lock); + if (x->km.state == XFRM_STATE_DEAD) + goto out; + if (x->km.state == XFRM_STATE_EXPIRED) + goto expired; + if (x->lft.hard_add_expires_seconds) { + long tmo = x->lft.hard_add_expires_seconds + + x->curlft.add_time - now; + if (tmo <= 0) { + if (x->xflags & XFRM_SOFT_EXPIRE) { + /* enter hard expire without soft expire first?! + * setting a new date could trigger this. + * workaround: fix x->curflt.add_time by below: + */ + x->curlft.add_time = now - x->saved_tmo - 1; + tmo = x->lft.hard_add_expires_seconds - x->saved_tmo; + } else + goto expired; + } + if (tmo < next) + next = tmo; + } + if (x->lft.hard_use_expires_seconds) { + long tmo = x->lft.hard_use_expires_seconds + + (READ_ONCE(x->curlft.use_time) ? : now) - now; + if (tmo <= 0) + goto expired; + if (tmo < next) + next = tmo; + } + if (x->km.dying) + goto resched; + if (x->lft.soft_add_expires_seconds) { + long tmo = x->lft.soft_add_expires_seconds + + x->curlft.add_time - now; + if (tmo <= 0) { + warn = 1; + x->xflags &= ~XFRM_SOFT_EXPIRE; + } else if (tmo < next) { + next = tmo; + x->xflags |= XFRM_SOFT_EXPIRE; + x->saved_tmo = tmo; + } + } + if (x->lft.soft_use_expires_seconds) { + long tmo = x->lft.soft_use_expires_seconds + + (READ_ONCE(x->curlft.use_time) ? : now) - now; + if (tmo <= 0) + warn = 1; + else if (tmo < next) + next = tmo; + } + + x->km.dying = warn; + if (warn) + km_state_expired(x, 0, 0); +resched: + if (next != TIME64_MAX) { + hrtimer_forward_now(&x->mtimer, ktime_set(next, 0)); + ret = HRTIMER_RESTART; + } + + goto out; + +expired: + if (x->km.state == XFRM_STATE_ACQ && x->id.spi == 0) + x->km.state = XFRM_STATE_EXPIRED; + + err = __xfrm_state_delete(x); + if (!err) + km_state_expired(x, 1, 0); + + xfrm_audit_state_delete(x, err ? 
0 : 1, true); + +out: + spin_unlock(&x->lock); + return ret; +} + +static void xfrm_replay_timer_handler(struct timer_list *t); + +struct xfrm_state *xfrm_state_alloc(struct net *net) +{ + struct xfrm_state *x; + + x = kmem_cache_zalloc(xfrm_state_cache, GFP_ATOMIC); + + if (x) { + write_pnet(&x->xs_net, net); + refcount_set(&x->refcnt, 1); + atomic_set(&x->tunnel_users, 0); + INIT_LIST_HEAD(&x->km.all); + INIT_HLIST_NODE(&x->bydst); + INIT_HLIST_NODE(&x->bysrc); + INIT_HLIST_NODE(&x->byspi); + INIT_HLIST_NODE(&x->byseq); + hrtimer_init(&x->mtimer, CLOCK_BOOTTIME, HRTIMER_MODE_ABS_SOFT); + x->mtimer.function = xfrm_timer_handler; + timer_setup(&x->rtimer, xfrm_replay_timer_handler, 0); + x->curlft.add_time = ktime_get_real_seconds(); + x->lft.soft_byte_limit = XFRM_INF; + x->lft.soft_packet_limit = XFRM_INF; + x->lft.hard_byte_limit = XFRM_INF; + x->lft.hard_packet_limit = XFRM_INF; + x->replay_maxage = 0; + x->replay_maxdiff = 0; + spin_lock_init(&x->lock); + } + return x; +} +EXPORT_SYMBOL(xfrm_state_alloc); + +void __xfrm_state_destroy(struct xfrm_state *x, bool sync) +{ + WARN_ON(x->km.state != XFRM_STATE_DEAD); + + if (sync) { + synchronize_rcu(); + ___xfrm_state_destroy(x); + } else { + spin_lock_bh(&xfrm_state_gc_lock); + hlist_add_head(&x->gclist, &xfrm_state_gc_list); + spin_unlock_bh(&xfrm_state_gc_lock); + schedule_work(&xfrm_state_gc_work); + } +} +EXPORT_SYMBOL(__xfrm_state_destroy); + +int __xfrm_state_delete(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + int err = -ESRCH; + + if (x->km.state != XFRM_STATE_DEAD) { + x->km.state = XFRM_STATE_DEAD; + spin_lock(&net->xfrm.xfrm_state_lock); + list_del(&x->km.all); + hlist_del_rcu(&x->bydst); + hlist_del_rcu(&x->bysrc); + if (x->km.seq) + hlist_del_rcu(&x->byseq); + if (x->id.spi) + hlist_del_rcu(&x->byspi); + net->xfrm.state_num--; + spin_unlock(&net->xfrm.xfrm_state_lock); + + if (x->encap_sk) + sock_put(rcu_dereference_raw(x->encap_sk)); + + xfrm_dev_state_delete(x); + + /* All xfrm_state objects are created by xfrm_state_alloc. + * The xfrm_state_alloc call gives a reference, and that + * is what we are dropping here. 
+ */ + xfrm_state_put(x); + err = 0; + } + + return err; +} +EXPORT_SYMBOL(__xfrm_state_delete); + +int xfrm_state_delete(struct xfrm_state *x) +{ + int err; + + spin_lock_bh(&x->lock); + err = __xfrm_state_delete(x); + spin_unlock_bh(&x->lock); + + return err; +} +EXPORT_SYMBOL(xfrm_state_delete); + +#ifdef CONFIG_SECURITY_NETWORK_XFRM +static inline int +xfrm_state_flush_secctx_check(struct net *net, u8 proto, bool task_valid) +{ + int i, err = 0; + + for (i = 0; i <= net->xfrm.state_hmask; i++) { + struct xfrm_state *x; + + hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + if (xfrm_id_proto_match(x->id.proto, proto) && + (err = security_xfrm_state_delete(x)) != 0) { + xfrm_audit_state_delete(x, 0, task_valid); + return err; + } + } + } + + return err; +} + +static inline int +xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool task_valid) +{ + int i, err = 0; + + for (i = 0; i <= net->xfrm.state_hmask; i++) { + struct xfrm_state *x; + struct xfrm_dev_offload *xso; + + hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + xso = &x->xso; + + if (xso->dev == dev && + (err = security_xfrm_state_delete(x)) != 0) { + xfrm_audit_state_delete(x, 0, task_valid); + return err; + } + } + } + + return err; +} +#else +static inline int +xfrm_state_flush_secctx_check(struct net *net, u8 proto, bool task_valid) +{ + return 0; +} + +static inline int +xfrm_dev_state_flush_secctx_check(struct net *net, struct net_device *dev, bool task_valid) +{ + return 0; +} +#endif + +int xfrm_state_flush(struct net *net, u8 proto, bool task_valid, bool sync) +{ + int i, err = 0, cnt = 0; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + err = xfrm_state_flush_secctx_check(net, proto, task_valid); + if (err) + goto out; + + err = -ESRCH; + for (i = 0; i <= net->xfrm.state_hmask; i++) { + struct xfrm_state *x; +restart: + hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + if (!xfrm_state_kern(x) && + xfrm_id_proto_match(x->id.proto, proto)) { + xfrm_state_hold(x); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + err = xfrm_state_delete(x); + xfrm_audit_state_delete(x, err ? 0 : 1, + task_valid); + if (sync) + xfrm_state_put_sync(x); + else + xfrm_state_put(x); + if (!err) + cnt++; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + goto restart; + } + } + } +out: + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + if (cnt) + err = 0; + + return err; +} +EXPORT_SYMBOL(xfrm_state_flush); + +int xfrm_dev_state_flush(struct net *net, struct net_device *dev, bool task_valid) +{ + int i, err = 0, cnt = 0; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + err = xfrm_dev_state_flush_secctx_check(net, dev, task_valid); + if (err) + goto out; + + err = -ESRCH; + for (i = 0; i <= net->xfrm.state_hmask; i++) { + struct xfrm_state *x; + struct xfrm_dev_offload *xso; +restart: + hlist_for_each_entry(x, net->xfrm.state_bydst+i, bydst) { + xso = &x->xso; + + if (!xfrm_state_kern(x) && xso->dev == dev) { + xfrm_state_hold(x); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + err = xfrm_state_delete(x); + xfrm_audit_state_delete(x, err ? 
0 : 1, + task_valid); + xfrm_state_put(x); + if (!err) + cnt++; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + goto restart; + } + } + } + if (cnt) + err = 0; + +out: + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return err; +} +EXPORT_SYMBOL(xfrm_dev_state_flush); + +void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si) +{ + spin_lock_bh(&net->xfrm.xfrm_state_lock); + si->sadcnt = net->xfrm.state_num; + si->sadhcnt = net->xfrm.state_hmask + 1; + si->sadhmcnt = xfrm_state_hashmax; + spin_unlock_bh(&net->xfrm.xfrm_state_lock); +} +EXPORT_SYMBOL(xfrm_sad_getinfo); + +static void +__xfrm4_init_tempsel(struct xfrm_selector *sel, const struct flowi *fl) +{ + const struct flowi4 *fl4 = &fl->u.ip4; + + sel->daddr.a4 = fl4->daddr; + sel->saddr.a4 = fl4->saddr; + sel->dport = xfrm_flowi_dport(fl, &fl4->uli); + sel->dport_mask = htons(0xffff); + sel->sport = xfrm_flowi_sport(fl, &fl4->uli); + sel->sport_mask = htons(0xffff); + sel->family = AF_INET; + sel->prefixlen_d = 32; + sel->prefixlen_s = 32; + sel->proto = fl4->flowi4_proto; + sel->ifindex = fl4->flowi4_oif; +} + +static void +__xfrm6_init_tempsel(struct xfrm_selector *sel, const struct flowi *fl) +{ + const struct flowi6 *fl6 = &fl->u.ip6; + + /* Initialize temporary selector matching only to current session. */ + *(struct in6_addr *)&sel->daddr = fl6->daddr; + *(struct in6_addr *)&sel->saddr = fl6->saddr; + sel->dport = xfrm_flowi_dport(fl, &fl6->uli); + sel->dport_mask = htons(0xffff); + sel->sport = xfrm_flowi_sport(fl, &fl6->uli); + sel->sport_mask = htons(0xffff); + sel->family = AF_INET6; + sel->prefixlen_d = 128; + sel->prefixlen_s = 128; + sel->proto = fl6->flowi6_proto; + sel->ifindex = fl6->flowi6_oif; +} + +static void +xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl, + const struct xfrm_tmpl *tmpl, + const xfrm_address_t *daddr, const xfrm_address_t *saddr, + unsigned short family) +{ + switch (family) { + case AF_INET: + __xfrm4_init_tempsel(&x->sel, fl); + break; + case AF_INET6: + __xfrm6_init_tempsel(&x->sel, fl); + break; + } + + x->id = tmpl->id; + + switch (tmpl->encap_family) { + case AF_INET: + if (x->id.daddr.a4 == 0) + x->id.daddr.a4 = daddr->a4; + x->props.saddr = tmpl->saddr; + if (x->props.saddr.a4 == 0) + x->props.saddr.a4 = saddr->a4; + break; + case AF_INET6: + if (ipv6_addr_any((struct in6_addr *)&x->id.daddr)) + memcpy(&x->id.daddr, daddr, sizeof(x->sel.daddr)); + memcpy(&x->props.saddr, &tmpl->saddr, sizeof(x->props.saddr)); + if (ipv6_addr_any((struct in6_addr *)&x->props.saddr)) + memcpy(&x->props.saddr, saddr, sizeof(x->props.saddr)); + break; + } + + x->props.mode = tmpl->mode; + x->props.reqid = tmpl->reqid; + x->props.family = tmpl->encap_family; +} + +static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, + const xfrm_address_t *daddr, + __be32 spi, u8 proto, + unsigned short family) +{ + unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family); + struct xfrm_state *x; + + hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) { + if (x->props.family != family || + x->id.spi != spi || + x->id.proto != proto || + !xfrm_addr_equal(&x->id.daddr, daddr, family)) + continue; + + if ((mark & x->mark.m) != x->mark.v) + continue; + if (!xfrm_state_hold_rcu(x)) + continue; + return x; + } + + return NULL; +} + +static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, + const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + u8 proto, unsigned short family) +{ + unsigned int h = xfrm_src_hash(net, daddr, saddr, family); + 
struct xfrm_state *x; + + hlist_for_each_entry_rcu(x, net->xfrm.state_bysrc + h, bysrc) { + if (x->props.family != family || + x->id.proto != proto || + !xfrm_addr_equal(&x->id.daddr, daddr, family) || + !xfrm_addr_equal(&x->props.saddr, saddr, family)) + continue; + + if ((mark & x->mark.m) != x->mark.v) + continue; + if (!xfrm_state_hold_rcu(x)) + continue; + return x; + } + + return NULL; +} + +static inline struct xfrm_state * +__xfrm_state_locate(struct xfrm_state *x, int use_spi, int family) +{ + struct net *net = xs_net(x); + u32 mark = x->mark.v & x->mark.m; + + if (use_spi) + return __xfrm_state_lookup(net, mark, &x->id.daddr, + x->id.spi, x->id.proto, family); + else + return __xfrm_state_lookup_byaddr(net, mark, + &x->id.daddr, + &x->props.saddr, + x->id.proto, family); +} + +static void xfrm_hash_grow_check(struct net *net, int have_hash_collision) +{ + if (have_hash_collision && + (net->xfrm.state_hmask + 1) < xfrm_state_hashmax && + net->xfrm.state_num > net->xfrm.state_hmask) + schedule_work(&net->xfrm.state_hash_work); +} + +static void xfrm_state_look_at(struct xfrm_policy *pol, struct xfrm_state *x, + const struct flowi *fl, unsigned short family, + struct xfrm_state **best, int *acq_in_progress, + int *error) +{ + /* Resolution logic: + * 1. There is a valid state with matching selector. Done. + * 2. Valid state with inappropriate selector. Skip. + * + * Entering area of "sysdeps". + * + * 3. If state is not valid, selector is temporary, it selects + * only session which triggered previous resolution. Key + * manager will do something to install a state with proper + * selector. + */ + if (x->km.state == XFRM_STATE_VALID) { + if ((x->sel.family && + (x->sel.family != family || + !xfrm_selector_match(&x->sel, fl, family))) || + !security_xfrm_state_pol_flow_match(x, pol, + &fl->u.__fl_common)) + return; + + if (!*best || + (*best)->km.dying > x->km.dying || + ((*best)->km.dying == x->km.dying && + (*best)->curlft.add_time < x->curlft.add_time)) + *best = x; + } else if (x->km.state == XFRM_STATE_ACQ) { + *acq_in_progress = 1; + } else if (x->km.state == XFRM_STATE_ERROR || + x->km.state == XFRM_STATE_EXPIRED) { + if ((!x->sel.family || + (x->sel.family == family && + xfrm_selector_match(&x->sel, fl, family))) && + security_xfrm_state_pol_flow_match(x, pol, + &fl->u.__fl_common)) + *error = -ESRCH; + } +} + +struct xfrm_state * +xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr, + const struct flowi *fl, struct xfrm_tmpl *tmpl, + struct xfrm_policy *pol, int *err, + unsigned short family, u32 if_id) +{ + static xfrm_address_t saddr_wildcard = { }; + struct net *net = xp_net(pol); + unsigned int h, h_wildcard; + struct xfrm_state *x, *x0, *to_put; + int acquire_in_progress = 0; + int error = 0; + struct xfrm_state *best = NULL; + u32 mark = pol->mark.v & pol->mark.m; + unsigned short encap_family = tmpl->encap_family; + unsigned int sequence; + struct km_event c; + + to_put = NULL; + + sequence = read_seqcount_begin(&net->xfrm.xfrm_state_hash_generation); + + rcu_read_lock(); + h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family); + hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) { + if (x->props.family == encap_family && + x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && + x->if_id == if_id && + !(x->props.flags & XFRM_STATE_WILDRECV) && + xfrm_state_addr_check(x, daddr, saddr, encap_family) && + tmpl->mode == x->props.mode && + tmpl->id.proto == x->id.proto && + (tmpl->id.spi == x->id.spi || 
!tmpl->id.spi)) + xfrm_state_look_at(pol, x, fl, family, + &best, &acquire_in_progress, &error); + } + if (best || acquire_in_progress) + goto found; + + h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family); + hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) { + if (x->props.family == encap_family && + x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && + x->if_id == if_id && + !(x->props.flags & XFRM_STATE_WILDRECV) && + xfrm_addr_equal(&x->id.daddr, daddr, encap_family) && + tmpl->mode == x->props.mode && + tmpl->id.proto == x->id.proto && + (tmpl->id.spi == x->id.spi || !tmpl->id.spi)) + xfrm_state_look_at(pol, x, fl, family, + &best, &acquire_in_progress, &error); + } + +found: + x = best; + if (!x && !error && !acquire_in_progress) { + if (tmpl->id.spi && + (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi, + tmpl->id.proto, encap_family)) != NULL) { + to_put = x0; + error = -EEXIST; + goto out; + } + + c.net = net; + /* If the KMs have no listeners (yet...), avoid allocating an SA + * for each and every packet - garbage collection might not + * handle the flood. + */ + if (!km_is_alive(&c)) { + error = -ESRCH; + goto out; + } + + x = xfrm_state_alloc(net); + if (x == NULL) { + error = -ENOMEM; + goto out; + } + /* Initialize temporary state matching only + * to current session. */ + xfrm_init_tempstate(x, fl, tmpl, daddr, saddr, family); + memcpy(&x->mark, &pol->mark, sizeof(x->mark)); + x->if_id = if_id; + + error = security_xfrm_state_alloc_acquire(x, pol->security, fl->flowi_secid); + if (error) { + x->km.state = XFRM_STATE_DEAD; + to_put = x; + x = NULL; + goto out; + } + + if (km_query(x, tmpl, pol) == 0) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x->km.state = XFRM_STATE_ACQ; + list_add(&x->km.all, &net->xfrm.state_all); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + h = xfrm_src_hash(net, daddr, saddr, encap_family); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + if (x->id.spi) { + h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family); + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + } + if (x->km.seq) { + h = xfrm_seq_hash(net, x->km.seq); + hlist_add_head_rcu(&x->byseq, net->xfrm.state_byseq + h); + } + x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; + hrtimer_start(&x->mtimer, + ktime_set(net->xfrm.sysctl_acq_expires, 0), + HRTIMER_MODE_REL_SOFT); + net->xfrm.state_num++; + xfrm_hash_grow_check(net, x->bydst.next != NULL); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } else { + x->km.state = XFRM_STATE_DEAD; + to_put = x; + x = NULL; + error = -ESRCH; + } + } +out: + if (x) { + if (!xfrm_state_hold_rcu(x)) { + *err = -EAGAIN; + x = NULL; + } + } else { + *err = acquire_in_progress ? 
-EAGAIN : error; + } + rcu_read_unlock(); + if (to_put) + xfrm_state_put(to_put); + + if (read_seqcount_retry(&net->xfrm.xfrm_state_hash_generation, sequence)) { + *err = -EAGAIN; + if (x) { + xfrm_state_put(x); + x = NULL; + } + } + + return x; +} + +struct xfrm_state * +xfrm_stateonly_find(struct net *net, u32 mark, u32 if_id, + xfrm_address_t *daddr, xfrm_address_t *saddr, + unsigned short family, u8 mode, u8 proto, u32 reqid) +{ + unsigned int h; + struct xfrm_state *rx = NULL, *x = NULL; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + h = xfrm_dst_hash(net, daddr, saddr, reqid, family); + hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + if (x->props.family == family && + x->props.reqid == reqid && + (mark & x->mark.m) == x->mark.v && + x->if_id == if_id && + !(x->props.flags & XFRM_STATE_WILDRECV) && + xfrm_state_addr_check(x, daddr, saddr, family) && + mode == x->props.mode && + proto == x->id.proto && + x->km.state == XFRM_STATE_VALID) { + rx = x; + break; + } + } + + if (rx) + xfrm_state_hold(rx); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + + return rx; +} +EXPORT_SYMBOL(xfrm_stateonly_find); + +struct xfrm_state *xfrm_state_lookup_byspi(struct net *net, __be32 spi, + unsigned short family) +{ + struct xfrm_state *x; + struct xfrm_state_walk *w; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + list_for_each_entry(w, &net->xfrm.state_all, all) { + x = container_of(w, struct xfrm_state, km); + if (x->props.family != family || + x->id.spi != spi) + continue; + + xfrm_state_hold(x); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return x; + } + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return NULL; +} +EXPORT_SYMBOL(xfrm_state_lookup_byspi); + +static void __xfrm_state_insert(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + unsigned int h; + + list_add(&x->km.all, &net->xfrm.state_all); + + h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr, + x->props.reqid, x->props.family); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + + h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + + if (x->id.spi) { + h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, + x->props.family); + + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + } + + if (x->km.seq) { + h = xfrm_seq_hash(net, x->km.seq); + + hlist_add_head_rcu(&x->byseq, net->xfrm.state_byseq + h); + } + + hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL_SOFT); + if (x->replay_maxage) + mod_timer(&x->rtimer, jiffies + x->replay_maxage); + + net->xfrm.state_num++; + + xfrm_hash_grow_check(net, x->bydst.next != NULL); +} + +/* net->xfrm.xfrm_state_lock is held */ +static void __xfrm_state_bump_genids(struct xfrm_state *xnew) +{ + struct net *net = xs_net(xnew); + unsigned short family = xnew->props.family; + u32 reqid = xnew->props.reqid; + struct xfrm_state *x; + unsigned int h; + u32 mark = xnew->mark.v & xnew->mark.m; + u32 if_id = xnew->if_id; + + h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family); + hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + if (x->props.family == family && + x->props.reqid == reqid && + x->if_id == if_id && + (mark & x->mark.m) == x->mark.v && + xfrm_addr_equal(&x->id.daddr, &xnew->id.daddr, family) && + xfrm_addr_equal(&x->props.saddr, &xnew->props.saddr, family)) + x->genid++; + } +} + +void xfrm_state_insert(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + 
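+	/* Bump the genid of every state sharing this state's addresses and
+	 * reqid, then link the new state into state_all and the hash tables.
+	 */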
__xfrm_state_bump_genids(x); + __xfrm_state_insert(x); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); +} +EXPORT_SYMBOL(xfrm_state_insert); + +/* net->xfrm.xfrm_state_lock is held */ +static struct xfrm_state *__find_acq_core(struct net *net, + const struct xfrm_mark *m, + unsigned short family, u8 mode, + u32 reqid, u32 if_id, u8 proto, + const xfrm_address_t *daddr, + const xfrm_address_t *saddr, + int create) +{ + unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family); + struct xfrm_state *x; + u32 mark = m->v & m->m; + + hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + if (x->props.reqid != reqid || + x->props.mode != mode || + x->props.family != family || + x->km.state != XFRM_STATE_ACQ || + x->id.spi != 0 || + x->id.proto != proto || + (mark & x->mark.m) != x->mark.v || + !xfrm_addr_equal(&x->id.daddr, daddr, family) || + !xfrm_addr_equal(&x->props.saddr, saddr, family)) + continue; + + xfrm_state_hold(x); + return x; + } + + if (!create) + return NULL; + + x = xfrm_state_alloc(net); + if (likely(x)) { + switch (family) { + case AF_INET: + x->sel.daddr.a4 = daddr->a4; + x->sel.saddr.a4 = saddr->a4; + x->sel.prefixlen_d = 32; + x->sel.prefixlen_s = 32; + x->props.saddr.a4 = saddr->a4; + x->id.daddr.a4 = daddr->a4; + break; + + case AF_INET6: + x->sel.daddr.in6 = daddr->in6; + x->sel.saddr.in6 = saddr->in6; + x->sel.prefixlen_d = 128; + x->sel.prefixlen_s = 128; + x->props.saddr.in6 = saddr->in6; + x->id.daddr.in6 = daddr->in6; + break; + } + + x->km.state = XFRM_STATE_ACQ; + x->id.proto = proto; + x->props.family = family; + x->props.mode = mode; + x->props.reqid = reqid; + x->if_id = if_id; + x->mark.v = m->v; + x->mark.m = m->m; + x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; + xfrm_state_hold(x); + hrtimer_start(&x->mtimer, + ktime_set(net->xfrm.sysctl_acq_expires, 0), + HRTIMER_MODE_REL_SOFT); + list_add(&x->km.all, &net->xfrm.state_all); + hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h); + h = xfrm_src_hash(net, daddr, saddr, family); + hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h); + + net->xfrm.state_num++; + + xfrm_hash_grow_check(net, x->bydst.next != NULL); + } + + return x; +} + +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq); + +int xfrm_state_add(struct xfrm_state *x) +{ + struct net *net = xs_net(x); + struct xfrm_state *x1, *to_put; + int family; + int err; + u32 mark = x->mark.v & x->mark.m; + int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY); + + family = x->props.family; + + to_put = NULL; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + + x1 = __xfrm_state_locate(x, use_spi, family); + if (x1) { + to_put = x1; + x1 = NULL; + err = -EEXIST; + goto out; + } + + if (use_spi && x->km.seq) { + x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq); + if (x1 && ((x1->id.proto != x->id.proto) || + !xfrm_addr_equal(&x1->id.daddr, &x->id.daddr, family))) { + to_put = x1; + x1 = NULL; + } + } + + if (use_spi && !x1) + x1 = __find_acq_core(net, &x->mark, family, x->props.mode, + x->props.reqid, x->if_id, x->id.proto, + &x->id.daddr, &x->props.saddr, 0); + + __xfrm_state_bump_genids(x); + __xfrm_state_insert(x); + err = 0; + +out: + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + if (x1) { + xfrm_state_delete(x1); + xfrm_state_put(x1); + } + + if (to_put) + xfrm_state_put(to_put); + + return err; +} +EXPORT_SYMBOL(xfrm_state_add); + +#ifdef CONFIG_XFRM_MIGRATE +static inline int clone_security(struct xfrm_state *x, struct xfrm_sec_ctx *security) +{ + struct xfrm_user_sec_ctx 
*uctx; + int size = sizeof(*uctx) + security->ctx_len; + int err; + + uctx = kmalloc(size, GFP_KERNEL); + if (!uctx) + return -ENOMEM; + + uctx->exttype = XFRMA_SEC_CTX; + uctx->len = size; + uctx->ctx_doi = security->ctx_doi; + uctx->ctx_alg = security->ctx_alg; + uctx->ctx_len = security->ctx_len; + memcpy(uctx + 1, security->ctx_str, security->ctx_len); + err = security_xfrm_state_alloc(x, uctx); + kfree(uctx); + if (err) + return err; + + return 0; +} + +static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, + struct xfrm_encap_tmpl *encap) +{ + struct net *net = xs_net(orig); + struct xfrm_state *x = xfrm_state_alloc(net); + if (!x) + goto out; + + memcpy(&x->id, &orig->id, sizeof(x->id)); + memcpy(&x->sel, &orig->sel, sizeof(x->sel)); + memcpy(&x->lft, &orig->lft, sizeof(x->lft)); + x->props.mode = orig->props.mode; + x->props.replay_window = orig->props.replay_window; + x->props.reqid = orig->props.reqid; + x->props.family = orig->props.family; + x->props.saddr = orig->props.saddr; + + if (orig->aalg) { + x->aalg = xfrm_algo_auth_clone(orig->aalg); + if (!x->aalg) + goto error; + } + x->props.aalgo = orig->props.aalgo; + + if (orig->aead) { + x->aead = xfrm_algo_aead_clone(orig->aead); + x->geniv = orig->geniv; + if (!x->aead) + goto error; + } + if (orig->ealg) { + x->ealg = xfrm_algo_clone(orig->ealg); + if (!x->ealg) + goto error; + } + x->props.ealgo = orig->props.ealgo; + + if (orig->calg) { + x->calg = xfrm_algo_clone(orig->calg); + if (!x->calg) + goto error; + } + x->props.calgo = orig->props.calgo; + + if (encap || orig->encap) { + if (encap) + x->encap = kmemdup(encap, sizeof(*x->encap), + GFP_KERNEL); + else + x->encap = kmemdup(orig->encap, sizeof(*x->encap), + GFP_KERNEL); + + if (!x->encap) + goto error; + } + + if (orig->security) + if (clone_security(x, orig->security)) + goto error; + + if (orig->coaddr) { + x->coaddr = kmemdup(orig->coaddr, sizeof(*x->coaddr), + GFP_KERNEL); + if (!x->coaddr) + goto error; + } + + if (orig->replay_esn) { + if (xfrm_replay_clone(x, orig)) + goto error; + } + + memcpy(&x->mark, &orig->mark, sizeof(x->mark)); + memcpy(&x->props.smark, &orig->props.smark, sizeof(x->props.smark)); + + x->props.flags = orig->props.flags; + x->props.extra_flags = orig->props.extra_flags; + + x->if_id = orig->if_id; + x->tfcpad = orig->tfcpad; + x->replay_maxdiff = orig->replay_maxdiff; + x->replay_maxage = orig->replay_maxage; + memcpy(&x->curlft, &orig->curlft, sizeof(x->curlft)); + x->km.state = orig->km.state; + x->km.seq = orig->km.seq; + x->replay = orig->replay; + x->preplay = orig->preplay; + x->mapping_maxage = orig->mapping_maxage; + x->lastused = orig->lastused; + x->new_mapping = 0; + x->new_mapping_sport = 0; + + return x; + + error: + xfrm_state_put(x); +out: + return NULL; +} + +struct xfrm_state *xfrm_migrate_state_find(struct xfrm_migrate *m, struct net *net, + u32 if_id) +{ + unsigned int h; + struct xfrm_state *x = NULL; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + + if (m->reqid) { + h = xfrm_dst_hash(net, &m->old_daddr, &m->old_saddr, + m->reqid, m->old_family); + hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) { + if (x->props.mode != m->mode || + x->id.proto != m->proto) + continue; + if (m->reqid && x->props.reqid != m->reqid) + continue; + if (if_id != 0 && x->if_id != if_id) + continue; + if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr, + m->old_family) || + !xfrm_addr_equal(&x->props.saddr, &m->old_saddr, + m->old_family)) + continue; + xfrm_state_hold(x); + break; + } + } else { + h = xfrm_src_hash(net, 
&m->old_daddr, &m->old_saddr, + m->old_family); + hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) { + if (x->props.mode != m->mode || + x->id.proto != m->proto) + continue; + if (if_id != 0 && x->if_id != if_id) + continue; + if (!xfrm_addr_equal(&x->id.daddr, &m->old_daddr, + m->old_family) || + !xfrm_addr_equal(&x->props.saddr, &m->old_saddr, + m->old_family)) + continue; + xfrm_state_hold(x); + break; + } + } + + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + return x; +} +EXPORT_SYMBOL(xfrm_migrate_state_find); + +struct xfrm_state *xfrm_state_migrate(struct xfrm_state *x, + struct xfrm_migrate *m, + struct xfrm_encap_tmpl *encap) +{ + struct xfrm_state *xc; + + xc = xfrm_state_clone(x, encap); + if (!xc) + return NULL; + + xc->props.family = m->new_family; + + if (xfrm_init_state(xc) < 0) + goto error; + + memcpy(&xc->id.daddr, &m->new_daddr, sizeof(xc->id.daddr)); + memcpy(&xc->props.saddr, &m->new_saddr, sizeof(xc->props.saddr)); + + /* add state */ + if (xfrm_addr_equal(&x->id.daddr, &m->new_daddr, m->new_family)) { + /* a care is needed when the destination address of the + state is to be updated as it is a part of triplet */ + xfrm_state_insert(xc); + } else { + if (xfrm_state_add(xc) < 0) + goto error; + } + + return xc; +error: + xfrm_state_put(xc); + return NULL; +} +EXPORT_SYMBOL(xfrm_state_migrate); +#endif + +int xfrm_state_update(struct xfrm_state *x) +{ + struct xfrm_state *x1, *to_put; + int err; + int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY); + struct net *net = xs_net(x); + + to_put = NULL; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x1 = __xfrm_state_locate(x, use_spi, x->props.family); + + err = -ESRCH; + if (!x1) + goto out; + + if (xfrm_state_kern(x1)) { + to_put = x1; + err = -EEXIST; + goto out; + } + + if (x1->km.state == XFRM_STATE_ACQ) { + __xfrm_state_insert(x); + x = NULL; + } + err = 0; + +out: + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + if (to_put) + xfrm_state_put(to_put); + + if (err) + return err; + + if (!x) { + xfrm_state_delete(x1); + xfrm_state_put(x1); + return 0; + } + + err = -EINVAL; + spin_lock_bh(&x1->lock); + if (likely(x1->km.state == XFRM_STATE_VALID)) { + if (x->encap && x1->encap && + x->encap->encap_type == x1->encap->encap_type) + memcpy(x1->encap, x->encap, sizeof(*x1->encap)); + else if (x->encap || x1->encap) + goto fail; + + if (x->coaddr && x1->coaddr) { + memcpy(x1->coaddr, x->coaddr, sizeof(*x1->coaddr)); + } + if (!use_spi && memcmp(&x1->sel, &x->sel, sizeof(x1->sel))) + memcpy(&x1->sel, &x->sel, sizeof(x1->sel)); + memcpy(&x1->lft, &x->lft, sizeof(x1->lft)); + x1->km.dying = 0; + + hrtimer_start(&x1->mtimer, ktime_set(1, 0), + HRTIMER_MODE_REL_SOFT); + if (READ_ONCE(x1->curlft.use_time)) + xfrm_state_check_expire(x1); + + if (x->props.smark.m || x->props.smark.v || x->if_id) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + + if (x->props.smark.m || x->props.smark.v) + x1->props.smark = x->props.smark; + + if (x->if_id) + x1->if_id = x->if_id; + + __xfrm_state_bump_genids(x1); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } + + err = 0; + x->km.state = XFRM_STATE_DEAD; + __xfrm_state_put(x); + } + +fail: + spin_unlock_bh(&x1->lock); + + xfrm_state_put(x1); + + return err; +} +EXPORT_SYMBOL(xfrm_state_update); + +int xfrm_state_check_expire(struct xfrm_state *x) +{ + if (!READ_ONCE(x->curlft.use_time)) + WRITE_ONCE(x->curlft.use_time, ktime_get_real_seconds()); + + if (x->curlft.bytes >= x->lft.hard_byte_limit || + x->curlft.packets >= x->lft.hard_packet_limit) { + x->km.state = 
XFRM_STATE_EXPIRED; + hrtimer_start(&x->mtimer, 0, HRTIMER_MODE_REL_SOFT); + return -EINVAL; + } + + if (!x->km.dying && + (x->curlft.bytes >= x->lft.soft_byte_limit || + x->curlft.packets >= x->lft.soft_packet_limit)) { + x->km.dying = 1; + km_state_expired(x, 0, 0); + } + return 0; +} +EXPORT_SYMBOL(xfrm_state_check_expire); + +struct xfrm_state * +xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32 spi, + u8 proto, unsigned short family) +{ + struct xfrm_state *x; + + rcu_read_lock(); + x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family); + rcu_read_unlock(); + return x; +} +EXPORT_SYMBOL(xfrm_state_lookup); + +struct xfrm_state * +xfrm_state_lookup_byaddr(struct net *net, u32 mark, + const xfrm_address_t *daddr, const xfrm_address_t *saddr, + u8 proto, unsigned short family) +{ + struct xfrm_state *x; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return x; +} +EXPORT_SYMBOL(xfrm_state_lookup_byaddr); + +struct xfrm_state * +xfrm_find_acq(struct net *net, const struct xfrm_mark *mark, u8 mode, u32 reqid, + u32 if_id, u8 proto, const xfrm_address_t *daddr, + const xfrm_address_t *saddr, int create, unsigned short family) +{ + struct xfrm_state *x; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x = __find_acq_core(net, mark, family, mode, reqid, if_id, proto, daddr, saddr, create); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + return x; +} +EXPORT_SYMBOL(xfrm_find_acq); + +#ifdef CONFIG_XFRM_SUB_POLICY +#if IS_ENABLED(CONFIG_IPV6) +/* distribution counting sort function for xfrm_state and xfrm_tmpl */ +static void +__xfrm6_sort(void **dst, void **src, int n, + int (*cmp)(const void *p), int maxclass) +{ + int count[XFRM_MAX_DEPTH] = { }; + int class[XFRM_MAX_DEPTH]; + int i; + + for (i = 0; i < n; i++) { + int c = cmp(src[i]); + + class[i] = c; + count[c]++; + } + + for (i = 2; i < maxclass; i++) + count[i] += count[i - 1]; + + for (i = 0; i < n; i++) { + dst[count[class[i] - 1]++] = src[i]; + src[i] = NULL; + } +} + +/* Rule for xfrm_state: + * + * rule 1: select IPsec transport except AH + * rule 2: select MIPv6 RO or inbound trigger + * rule 3: select IPsec transport AH + * rule 4: select IPsec tunnel + * rule 5: others + */ +static int __xfrm6_state_sort_cmp(const void *p) +{ + const struct xfrm_state *v = p; + + switch (v->props.mode) { + case XFRM_MODE_TRANSPORT: + if (v->id.proto != IPPROTO_AH) + return 1; + else + return 3; +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case XFRM_MODE_ROUTEOPTIMIZATION: + case XFRM_MODE_IN_TRIGGER: + return 2; +#endif + case XFRM_MODE_TUNNEL: + case XFRM_MODE_BEET: + return 4; + } + return 5; +} + +/* Rule for xfrm_tmpl: + * + * rule 1: select IPsec transport + * rule 2: select MIPv6 RO or inbound trigger + * rule 3: select IPsec tunnel + * rule 4: others + */ +static int __xfrm6_tmpl_sort_cmp(const void *p) +{ + const struct xfrm_tmpl *v = p; + + switch (v->mode) { + case XFRM_MODE_TRANSPORT: + return 1; +#if IS_ENABLED(CONFIG_IPV6_MIP6) + case XFRM_MODE_ROUTEOPTIMIZATION: + case XFRM_MODE_IN_TRIGGER: + return 2; +#endif + case XFRM_MODE_TUNNEL: + case XFRM_MODE_BEET: + return 3; + } + return 4; +} +#else +static inline int __xfrm6_state_sort_cmp(const void *p) { return 5; } +static inline int __xfrm6_tmpl_sort_cmp(const void *p) { return 4; } + +static inline void +__xfrm6_sort(void **dst, void **src, int n, + int (*cmp)(const void *p), int maxclass) +{ + int i; + + for (i = 0; i < n; i++) + 
dst[i] = src[i]; +} +#endif /* CONFIG_IPV6 */ + +void +xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n, + unsigned short family) +{ + int i; + + if (family == AF_INET6) + __xfrm6_sort((void **)dst, (void **)src, n, + __xfrm6_tmpl_sort_cmp, 5); + else + for (i = 0; i < n; i++) + dst[i] = src[i]; +} + +void +xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n, + unsigned short family) +{ + int i; + + if (family == AF_INET6) + __xfrm6_sort((void **)dst, (void **)src, n, + __xfrm6_state_sort_cmp, 6); + else + for (i = 0; i < n; i++) + dst[i] = src[i]; +} +#endif + +/* Silly enough, but I'm lazy to build resolution list */ + +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +{ + unsigned int h = xfrm_seq_hash(net, seq); + struct xfrm_state *x; + + hlist_for_each_entry_rcu(x, net->xfrm.state_byseq + h, byseq) { + if (x->km.seq == seq && + (mark & x->mark.m) == x->mark.v && + x->km.state == XFRM_STATE_ACQ) { + xfrm_state_hold(x); + return x; + } + } + + return NULL; +} + +struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) +{ + struct xfrm_state *x; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x = __xfrm_find_acq_byseq(net, mark, seq); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return x; +} +EXPORT_SYMBOL(xfrm_find_acq_byseq); + +u32 xfrm_get_acqseq(void) +{ + u32 res; + static atomic_t acqseq; + + do { + res = atomic_inc_return(&acqseq); + } while (!res); + + return res; +} +EXPORT_SYMBOL(xfrm_get_acqseq); + +int verify_spi_info(u8 proto, u32 min, u32 max) +{ + switch (proto) { + case IPPROTO_AH: + case IPPROTO_ESP: + break; + + case IPPROTO_COMP: + /* IPCOMP spi is 16-bits. */ + if (max >= 0x10000) + return -EINVAL; + break; + + default: + return -EINVAL; + } + + if (min > max) + return -EINVAL; + + return 0; +} +EXPORT_SYMBOL(verify_spi_info); + +int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) +{ + struct net *net = xs_net(x); + unsigned int h; + struct xfrm_state *x0; + int err = -ENOENT; + __be32 minspi = htonl(low); + __be32 maxspi = htonl(high); + __be32 newspi = 0; + u32 mark = x->mark.v & x->mark.m; + + spin_lock_bh(&x->lock); + if (x->km.state == XFRM_STATE_DEAD) + goto unlock; + + err = 0; + if (x->id.spi) + goto unlock; + + err = -ENOENT; + + if (minspi == maxspi) { + x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family); + if (x0) { + xfrm_state_put(x0); + goto unlock; + } + newspi = minspi; + } else { + u32 spi = 0; + for (h = 0; h < high-low+1; h++) { + spi = low + prandom_u32_max(high - low + 1); + x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family); + if (x0 == NULL) { + newspi = htonl(spi); + break; + } + xfrm_state_put(x0); + } + } + if (newspi) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + x->id.spi = newspi; + h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family); + hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + + err = 0; + } + +unlock: + spin_unlock_bh(&x->lock); + + return err; +} +EXPORT_SYMBOL(xfrm_alloc_spi); + +static bool __xfrm_state_filter_match(struct xfrm_state *x, + struct xfrm_address_filter *filter) +{ + if (filter) { + if ((filter->family == AF_INET || + filter->family == AF_INET6) && + x->props.family != filter->family) + return false; + + return addr_match(&x->props.saddr, &filter->saddr, + filter->splen) && + addr_match(&x->id.daddr, &filter->daddr, + filter->dplen); + } + return true; +} 
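+/* Iterate the per-netns state_all list under xfrm_state_lock, skipping
+ * dead states and entries that do not match walk->proto or walk->filter,
+ * and call func() for each remaining state; the position is kept in
+ * walk->all so an interrupted walk can be resumed.
+ */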
+ +int xfrm_state_walk(struct net *net, struct xfrm_state_walk *walk, + int (*func)(struct xfrm_state *, int, void*), + void *data) +{ + struct xfrm_state *state; + struct xfrm_state_walk *x; + int err = 0; + + if (walk->seq != 0 && list_empty(&walk->all)) + return 0; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + if (list_empty(&walk->all)) + x = list_first_entry(&net->xfrm.state_all, struct xfrm_state_walk, all); + else + x = list_first_entry(&walk->all, struct xfrm_state_walk, all); + list_for_each_entry_from(x, &net->xfrm.state_all, all) { + if (x->state == XFRM_STATE_DEAD) + continue; + state = container_of(x, struct xfrm_state, km); + if (!xfrm_id_proto_match(state->id.proto, walk->proto)) + continue; + if (!__xfrm_state_filter_match(state, walk->filter)) + continue; + err = func(state, walk->seq, data); + if (err) { + list_move_tail(&walk->all, &x->all); + goto out; + } + walk->seq++; + } + if (walk->seq == 0) { + err = -ENOENT; + goto out; + } + list_del_init(&walk->all); +out: + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + return err; +} +EXPORT_SYMBOL(xfrm_state_walk); + +void xfrm_state_walk_init(struct xfrm_state_walk *walk, u8 proto, + struct xfrm_address_filter *filter) +{ + INIT_LIST_HEAD(&walk->all); + walk->proto = proto; + walk->state = XFRM_STATE_DEAD; + walk->seq = 0; + walk->filter = filter; +} +EXPORT_SYMBOL(xfrm_state_walk_init); + +void xfrm_state_walk_done(struct xfrm_state_walk *walk, struct net *net) +{ + kfree(walk->filter); + + if (list_empty(&walk->all)) + return; + + spin_lock_bh(&net->xfrm.xfrm_state_lock); + list_del(&walk->all); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); +} +EXPORT_SYMBOL(xfrm_state_walk_done); + +static void xfrm_replay_timer_handler(struct timer_list *t) +{ + struct xfrm_state *x = from_timer(x, t, rtimer); + + spin_lock(&x->lock); + + if (x->km.state == XFRM_STATE_VALID) { + if (xfrm_aevent_is_on(xs_net(x))) + xfrm_replay_notify(x, XFRM_REPLAY_TIMEOUT); + else + x->xflags |= XFRM_TIME_DEFER; + } + + spin_unlock(&x->lock); +} + +static LIST_HEAD(xfrm_km_list); + +void km_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + struct xfrm_mgr *km; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) + if (km->notify_policy) + km->notify_policy(xp, dir, c); + rcu_read_unlock(); +} + +void km_state_notify(struct xfrm_state *x, const struct km_event *c) +{ + struct xfrm_mgr *km; + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) + if (km->notify) + km->notify(x, c); + rcu_read_unlock(); +} + +EXPORT_SYMBOL(km_policy_notify); +EXPORT_SYMBOL(km_state_notify); + +void km_state_expired(struct xfrm_state *x, int hard, u32 portid) +{ + struct km_event c; + + c.data.hard = hard; + c.portid = portid; + c.event = XFRM_MSG_EXPIRE; + km_state_notify(x, &c); +} + +EXPORT_SYMBOL(km_state_expired); +/* + * We send to all registered managers regardless of failure + * We are happy with one success +*/ +int km_query(struct xfrm_state *x, struct xfrm_tmpl *t, struct xfrm_policy *pol) +{ + int err = -EINVAL, acqret; + struct xfrm_mgr *km; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + acqret = km->acquire(x, t, pol); + if (!acqret) + err = acqret; + } + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(km_query); + +static int __km_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport) +{ + int err = -EINVAL; + struct xfrm_mgr *km; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + if (km->new_mapping) + err = 
km->new_mapping(x, ipaddr, sport); + if (!err) + break; + } + rcu_read_unlock(); + return err; +} + +int km_new_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, __be16 sport) +{ + int ret = 0; + + if (x->mapping_maxage) { + if ((jiffies / HZ - x->new_mapping) > x->mapping_maxage || + x->new_mapping_sport != sport) { + x->new_mapping_sport = sport; + x->new_mapping = jiffies / HZ; + ret = __km_new_mapping(x, ipaddr, sport); + } + } else { + ret = __km_new_mapping(x, ipaddr, sport); + } + + return ret; +} +EXPORT_SYMBOL(km_new_mapping); + +void km_policy_expired(struct xfrm_policy *pol, int dir, int hard, u32 portid) +{ + struct km_event c; + + c.data.hard = hard; + c.portid = portid; + c.event = XFRM_MSG_POLEXPIRE; + km_policy_notify(pol, dir, &c); +} +EXPORT_SYMBOL(km_policy_expired); + +#ifdef CONFIG_XFRM_MIGRATE +int km_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + const struct xfrm_migrate *m, int num_migrate, + const struct xfrm_kmaddress *k, + const struct xfrm_encap_tmpl *encap) +{ + int err = -EINVAL; + int ret; + struct xfrm_mgr *km; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + if (km->migrate) { + ret = km->migrate(sel, dir, type, m, num_migrate, k, + encap); + if (!ret) + err = ret; + } + } + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(km_migrate); +#endif + +int km_report(struct net *net, u8 proto, struct xfrm_selector *sel, xfrm_address_t *addr) +{ + int err = -EINVAL; + int ret; + struct xfrm_mgr *km; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + if (km->report) { + ret = km->report(net, proto, sel, addr); + if (!ret) + err = ret; + } + } + rcu_read_unlock(); + return err; +} +EXPORT_SYMBOL(km_report); + +static bool km_is_alive(const struct km_event *c) +{ + struct xfrm_mgr *km; + bool is_alive = false; + + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + if (km->is_alive && km->is_alive(c)) { + is_alive = true; + break; + } + } + rcu_read_unlock(); + + return is_alive; +} + +#if IS_ENABLED(CONFIG_XFRM_USER_COMPAT) +static DEFINE_SPINLOCK(xfrm_translator_lock); +static struct xfrm_translator __rcu *xfrm_translator; + +struct xfrm_translator *xfrm_get_translator(void) +{ + struct xfrm_translator *xtr; + + rcu_read_lock(); + xtr = rcu_dereference(xfrm_translator); + if (unlikely(!xtr)) + goto out; + if (!try_module_get(xtr->owner)) + xtr = NULL; +out: + rcu_read_unlock(); + return xtr; +} +EXPORT_SYMBOL_GPL(xfrm_get_translator); + +void xfrm_put_translator(struct xfrm_translator *xtr) +{ + module_put(xtr->owner); +} +EXPORT_SYMBOL_GPL(xfrm_put_translator); + +int xfrm_register_translator(struct xfrm_translator *xtr) +{ + int err = 0; + + spin_lock_bh(&xfrm_translator_lock); + if (unlikely(xfrm_translator != NULL)) + err = -EEXIST; + else + rcu_assign_pointer(xfrm_translator, xtr); + spin_unlock_bh(&xfrm_translator_lock); + + return err; +} +EXPORT_SYMBOL_GPL(xfrm_register_translator); + +int xfrm_unregister_translator(struct xfrm_translator *xtr) +{ + int err = 0; + + spin_lock_bh(&xfrm_translator_lock); + if (likely(xfrm_translator != NULL)) { + if (rcu_access_pointer(xfrm_translator) != xtr) + err = -EINVAL; + else + RCU_INIT_POINTER(xfrm_translator, NULL); + } + spin_unlock_bh(&xfrm_translator_lock); + synchronize_rcu(); + + return err; +} +EXPORT_SYMBOL_GPL(xfrm_unregister_translator); +#endif + +int xfrm_user_policy(struct sock *sk, int optname, sockptr_t optval, int optlen) +{ + int err; + u8 *data; + struct xfrm_mgr *km; + struct xfrm_policy *pol = NULL; + + if 
(sockptr_is_null(optval) && !optlen) { + xfrm_sk_policy_insert(sk, XFRM_POLICY_IN, NULL); + xfrm_sk_policy_insert(sk, XFRM_POLICY_OUT, NULL); + __sk_dst_reset(sk); + return 0; + } + + if (optlen <= 0 || optlen > PAGE_SIZE) + return -EMSGSIZE; + + data = memdup_sockptr(optval, optlen); + if (IS_ERR(data)) + return PTR_ERR(data); + + if (in_compat_syscall()) { + struct xfrm_translator *xtr = xfrm_get_translator(); + + if (!xtr) { + kfree(data); + return -EOPNOTSUPP; + } + + err = xtr->xlate_user_policy_sockptr(&data, optlen); + xfrm_put_translator(xtr); + if (err) { + kfree(data); + return err; + } + } + + err = -EINVAL; + rcu_read_lock(); + list_for_each_entry_rcu(km, &xfrm_km_list, list) { + pol = km->compile_policy(sk, optname, data, + optlen, &err); + if (err >= 0) + break; + } + rcu_read_unlock(); + + if (err >= 0) { + xfrm_sk_policy_insert(sk, err, pol); + xfrm_pol_put(pol); + __sk_dst_reset(sk); + err = 0; + } + + kfree(data); + return err; +} +EXPORT_SYMBOL(xfrm_user_policy); + +static DEFINE_SPINLOCK(xfrm_km_lock); + +void xfrm_register_km(struct xfrm_mgr *km) +{ + spin_lock_bh(&xfrm_km_lock); + list_add_tail_rcu(&km->list, &xfrm_km_list); + spin_unlock_bh(&xfrm_km_lock); +} +EXPORT_SYMBOL(xfrm_register_km); + +void xfrm_unregister_km(struct xfrm_mgr *km) +{ + spin_lock_bh(&xfrm_km_lock); + list_del_rcu(&km->list); + spin_unlock_bh(&xfrm_km_lock); + synchronize_rcu(); +} +EXPORT_SYMBOL(xfrm_unregister_km); + +int xfrm_state_register_afinfo(struct xfrm_state_afinfo *afinfo) +{ + int err = 0; + + if (WARN_ON(afinfo->family >= NPROTO)) + return -EAFNOSUPPORT; + + spin_lock_bh(&xfrm_state_afinfo_lock); + if (unlikely(xfrm_state_afinfo[afinfo->family] != NULL)) + err = -EEXIST; + else + rcu_assign_pointer(xfrm_state_afinfo[afinfo->family], afinfo); + spin_unlock_bh(&xfrm_state_afinfo_lock); + return err; +} +EXPORT_SYMBOL(xfrm_state_register_afinfo); + +int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo) +{ + int err = 0, family = afinfo->family; + + if (WARN_ON(family >= NPROTO)) + return -EAFNOSUPPORT; + + spin_lock_bh(&xfrm_state_afinfo_lock); + if (likely(xfrm_state_afinfo[afinfo->family] != NULL)) { + if (rcu_access_pointer(xfrm_state_afinfo[family]) != afinfo) + err = -EINVAL; + else + RCU_INIT_POINTER(xfrm_state_afinfo[afinfo->family], NULL); + } + spin_unlock_bh(&xfrm_state_afinfo_lock); + synchronize_rcu(); + return err; +} +EXPORT_SYMBOL(xfrm_state_unregister_afinfo); + +struct xfrm_state_afinfo *xfrm_state_afinfo_get_rcu(unsigned int family) +{ + if (unlikely(family >= NPROTO)) + return NULL; + + return rcu_dereference(xfrm_state_afinfo[family]); +} +EXPORT_SYMBOL_GPL(xfrm_state_afinfo_get_rcu); + +struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family) +{ + struct xfrm_state_afinfo *afinfo; + if (unlikely(family >= NPROTO)) + return NULL; + rcu_read_lock(); + afinfo = rcu_dereference(xfrm_state_afinfo[family]); + if (unlikely(!afinfo)) + rcu_read_unlock(); + return afinfo; +} + +void xfrm_flush_gc(void) +{ + flush_work(&xfrm_state_gc_work); +} +EXPORT_SYMBOL(xfrm_flush_gc); + +/* Temporarily located here until net/xfrm/xfrm_tunnel.c is created */ +void xfrm_state_delete_tunnel(struct xfrm_state *x) +{ + if (x->tunnel) { + struct xfrm_state *t = x->tunnel; + + if (atomic_read(&t->tunnel_users) == 2) + xfrm_state_delete(t); + atomic_dec(&t->tunnel_users); + xfrm_state_put_sync(t); + x->tunnel = NULL; + } +} +EXPORT_SYMBOL(xfrm_state_delete_tunnel); + +u32 xfrm_state_mtu(struct xfrm_state *x, int mtu) +{ + const struct xfrm_type *type = 
READ_ONCE(x->type); + struct crypto_aead *aead; + u32 blksize, net_adj = 0; + + if (x->km.state != XFRM_STATE_VALID || + !type || type->proto != IPPROTO_ESP) + return mtu - x->props.header_len; + + aead = x->data; + blksize = ALIGN(crypto_aead_blocksize(aead), 4); + + switch (x->props.mode) { + case XFRM_MODE_TRANSPORT: + case XFRM_MODE_BEET: + if (x->props.family == AF_INET) + net_adj = sizeof(struct iphdr); + else if (x->props.family == AF_INET6) + net_adj = sizeof(struct ipv6hdr); + break; + case XFRM_MODE_TUNNEL: + break; + default: + WARN_ON_ONCE(1); + break; + } + + return ((mtu - x->props.header_len - crypto_aead_authsize(aead) - + net_adj) & ~(blksize - 1)) + net_adj - 2; +} +EXPORT_SYMBOL_GPL(xfrm_state_mtu); + +int __xfrm_init_state(struct xfrm_state *x, bool init_replay, bool offload, + struct netlink_ext_ack *extack) +{ + const struct xfrm_mode *inner_mode; + const struct xfrm_mode *outer_mode; + int family = x->props.family; + int err; + + if (family == AF_INET && + READ_ONCE(xs_net(x)->ipv4.sysctl_ip_no_pmtu_disc)) + x->props.flags |= XFRM_STATE_NOPMTUDISC; + + err = -EPROTONOSUPPORT; + + if (x->sel.family != AF_UNSPEC) { + inner_mode = xfrm_get_mode(x->props.mode, x->sel.family); + if (inner_mode == NULL) { + NL_SET_ERR_MSG(extack, "Requested mode not found"); + goto error; + } + + if (!(inner_mode->flags & XFRM_MODE_FLAG_TUNNEL) && + family != x->sel.family) { + NL_SET_ERR_MSG(extack, "Only tunnel modes can accommodate a change of family"); + goto error; + } + + x->inner_mode = *inner_mode; + } else { + const struct xfrm_mode *inner_mode_iaf; + int iafamily = AF_INET; + + inner_mode = xfrm_get_mode(x->props.mode, x->props.family); + if (inner_mode == NULL) { + NL_SET_ERR_MSG(extack, "Requested mode not found"); + goto error; + } + + x->inner_mode = *inner_mode; + + if (x->props.family == AF_INET) + iafamily = AF_INET6; + + inner_mode_iaf = xfrm_get_mode(x->props.mode, iafamily); + if (inner_mode_iaf) { + if (inner_mode_iaf->flags & XFRM_MODE_FLAG_TUNNEL) + x->inner_mode_iaf = *inner_mode_iaf; + } + } + + x->type = xfrm_get_type(x->id.proto, family); + if (x->type == NULL) { + NL_SET_ERR_MSG(extack, "Requested type not found"); + goto error; + } + + x->type_offload = xfrm_get_type_offload(x->id.proto, family, offload); + + err = x->type->init_state(x, extack); + if (err) + goto error; + + outer_mode = xfrm_get_mode(x->props.mode, family); + if (!outer_mode) { + NL_SET_ERR_MSG(extack, "Requested mode not found"); + err = -EPROTONOSUPPORT; + goto error; + } + + x->outer_mode = *outer_mode; + if (init_replay) { + err = xfrm_init_replay(x, extack); + if (err) + goto error; + } + +error: + return err; +} + +EXPORT_SYMBOL(__xfrm_init_state); + +int xfrm_init_state(struct xfrm_state *x) +{ + int err; + + err = __xfrm_init_state(x, true, false, NULL); + if (!err) + x->km.state = XFRM_STATE_VALID; + + return err; +} + +EXPORT_SYMBOL(xfrm_init_state); + +int __net_init xfrm_state_init(struct net *net) +{ + unsigned int sz; + + if (net_eq(net, &init_net)) + xfrm_state_cache = KMEM_CACHE(xfrm_state, + SLAB_HWCACHE_ALIGN | SLAB_PANIC); + + INIT_LIST_HEAD(&net->xfrm.state_all); + + sz = sizeof(struct hlist_head) * 8; + + net->xfrm.state_bydst = xfrm_hash_alloc(sz); + if (!net->xfrm.state_bydst) + goto out_bydst; + net->xfrm.state_bysrc = xfrm_hash_alloc(sz); + if (!net->xfrm.state_bysrc) + goto out_bysrc; + net->xfrm.state_byspi = xfrm_hash_alloc(sz); + if (!net->xfrm.state_byspi) + goto out_byspi; + net->xfrm.state_byseq = xfrm_hash_alloc(sz); + if (!net->xfrm.state_byseq) + goto 
out_byseq; + net->xfrm.state_hmask = ((sz / sizeof(struct hlist_head)) - 1); + + net->xfrm.state_num = 0; + INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize); + spin_lock_init(&net->xfrm.xfrm_state_lock); + seqcount_spinlock_init(&net->xfrm.xfrm_state_hash_generation, + &net->xfrm.xfrm_state_lock); + return 0; + +out_byseq: + xfrm_hash_free(net->xfrm.state_byspi, sz); +out_byspi: + xfrm_hash_free(net->xfrm.state_bysrc, sz); +out_bysrc: + xfrm_hash_free(net->xfrm.state_bydst, sz); +out_bydst: + return -ENOMEM; +} + +void xfrm_state_fini(struct net *net) +{ + unsigned int sz; + + flush_work(&net->xfrm.state_hash_work); + flush_work(&xfrm_state_gc_work); + xfrm_state_flush(net, 0, false, true); + + WARN_ON(!list_empty(&net->xfrm.state_all)); + + sz = (net->xfrm.state_hmask + 1) * sizeof(struct hlist_head); + WARN_ON(!hlist_empty(net->xfrm.state_byseq)); + xfrm_hash_free(net->xfrm.state_byseq, sz); + WARN_ON(!hlist_empty(net->xfrm.state_byspi)); + xfrm_hash_free(net->xfrm.state_byspi, sz); + WARN_ON(!hlist_empty(net->xfrm.state_bysrc)); + xfrm_hash_free(net->xfrm.state_bysrc, sz); + WARN_ON(!hlist_empty(net->xfrm.state_bydst)); + xfrm_hash_free(net->xfrm.state_bydst, sz); +} + +#ifdef CONFIG_AUDITSYSCALL +static void xfrm_audit_helper_sainfo(struct xfrm_state *x, + struct audit_buffer *audit_buf) +{ + struct xfrm_sec_ctx *ctx = x->security; + u32 spi = ntohl(x->id.spi); + + if (ctx) + audit_log_format(audit_buf, " sec_alg=%u sec_doi=%u sec_obj=%s", + ctx->ctx_alg, ctx->ctx_doi, ctx->ctx_str); + + switch (x->props.family) { + case AF_INET: + audit_log_format(audit_buf, " src=%pI4 dst=%pI4", + &x->props.saddr.a4, &x->id.daddr.a4); + break; + case AF_INET6: + audit_log_format(audit_buf, " src=%pI6 dst=%pI6", + x->props.saddr.a6, x->id.daddr.a6); + break; + } + + audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi); +} + +static void xfrm_audit_helper_pktinfo(struct sk_buff *skb, u16 family, + struct audit_buffer *audit_buf) +{ + const struct iphdr *iph4; + const struct ipv6hdr *iph6; + + switch (family) { + case AF_INET: + iph4 = ip_hdr(skb); + audit_log_format(audit_buf, " src=%pI4 dst=%pI4", + &iph4->saddr, &iph4->daddr); + break; + case AF_INET6: + iph6 = ipv6_hdr(skb); + audit_log_format(audit_buf, + " src=%pI6 dst=%pI6 flowlbl=0x%x%02x%02x", + &iph6->saddr, &iph6->daddr, + iph6->flow_lbl[0] & 0x0f, + iph6->flow_lbl[1], + iph6->flow_lbl[2]); + break; + } +} + +void xfrm_audit_state_add(struct xfrm_state *x, int result, bool task_valid) +{ + struct audit_buffer *audit_buf; + + audit_buf = xfrm_audit_start("SAD-add"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_usrinfo(task_valid, audit_buf); + xfrm_audit_helper_sainfo(x, audit_buf); + audit_log_format(audit_buf, " res=%u", result); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_add); + +void xfrm_audit_state_delete(struct xfrm_state *x, int result, bool task_valid) +{ + struct audit_buffer *audit_buf; + + audit_buf = xfrm_audit_start("SAD-delete"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_usrinfo(task_valid, audit_buf); + xfrm_audit_helper_sainfo(x, audit_buf); + audit_log_format(audit_buf, " res=%u", result); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_delete); + +void xfrm_audit_state_replay_overflow(struct xfrm_state *x, + struct sk_buff *skb) +{ + struct audit_buffer *audit_buf; + u32 spi; + + audit_buf = xfrm_audit_start("SA-replay-overflow"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf); + /* don't record the 
sequence number because it's inherent in this kind + * of audit message */ + spi = ntohl(x->id.spi); + audit_log_format(audit_buf, " spi=%u(0x%x)", spi, spi); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_replay_overflow); + +void xfrm_audit_state_replay(struct xfrm_state *x, + struct sk_buff *skb, __be32 net_seq) +{ + struct audit_buffer *audit_buf; + u32 spi; + + audit_buf = xfrm_audit_start("SA-replayed-pkt"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf); + spi = ntohl(x->id.spi); + audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u", + spi, spi, ntohl(net_seq)); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_replay); + +void xfrm_audit_state_notfound_simple(struct sk_buff *skb, u16 family) +{ + struct audit_buffer *audit_buf; + + audit_buf = xfrm_audit_start("SA-notfound"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_pktinfo(skb, family, audit_buf); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound_simple); + +void xfrm_audit_state_notfound(struct sk_buff *skb, u16 family, + __be32 net_spi, __be32 net_seq) +{ + struct audit_buffer *audit_buf; + u32 spi; + + audit_buf = xfrm_audit_start("SA-notfound"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_pktinfo(skb, family, audit_buf); + spi = ntohl(net_spi); + audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u", + spi, spi, ntohl(net_seq)); + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_notfound); + +void xfrm_audit_state_icvfail(struct xfrm_state *x, + struct sk_buff *skb, u8 proto) +{ + struct audit_buffer *audit_buf; + __be32 net_spi; + __be32 net_seq; + + audit_buf = xfrm_audit_start("SA-icv-failure"); + if (audit_buf == NULL) + return; + xfrm_audit_helper_pktinfo(skb, x->props.family, audit_buf); + if (xfrm_parse_spi(skb, proto, &net_spi, &net_seq) == 0) { + u32 spi = ntohl(net_spi); + audit_log_format(audit_buf, " spi=%u(0x%x) seqno=%u", + spi, spi, ntohl(net_seq)); + } + audit_log_end(audit_buf); +} +EXPORT_SYMBOL_GPL(xfrm_audit_state_icvfail); +#endif /* CONFIG_AUDITSYSCALL */ diff --git a/net/xfrm/xfrm_sysctl.c b/net/xfrm/xfrm_sysctl.c new file mode 100644 index 000000000..0c6c5ef65 --- /dev/null +++ b/net/xfrm/xfrm_sysctl.c @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: GPL-2.0 +#include +#include +#include +#include + +static void __net_init __xfrm_sysctl_init(struct net *net) +{ + net->xfrm.sysctl_aevent_etime = XFRM_AE_ETIME; + net->xfrm.sysctl_aevent_rseqth = XFRM_AE_SEQT_SIZE; + net->xfrm.sysctl_larval_drop = 1; + net->xfrm.sysctl_acq_expires = 30; +} + +#ifdef CONFIG_SYSCTL +static struct ctl_table xfrm_table[] = { + { + .procname = "xfrm_aevent_etime", + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec + }, + { + .procname = "xfrm_aevent_rseqth", + .maxlen = sizeof(u32), + .mode = 0644, + .proc_handler = proc_douintvec + }, + { + .procname = "xfrm_larval_drop", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + { + .procname = "xfrm_acq_expires", + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec + }, + {} +}; + +int __net_init xfrm_sysctl_init(struct net *net) +{ + struct ctl_table *table; + + __xfrm_sysctl_init(net); + + table = kmemdup(xfrm_table, sizeof(xfrm_table), GFP_KERNEL); + if (!table) + goto out_kmemdup; + table[0].data = &net->xfrm.sysctl_aevent_etime; + table[1].data = &net->xfrm.sysctl_aevent_rseqth; + table[2].data = &net->xfrm.sysctl_larval_drop; + table[3].data = 
&net->xfrm.sysctl_acq_expires; + + /* Don't export sysctls to unprivileged users */ + if (net->user_ns != &init_user_ns) + table[0].procname = NULL; + + net->xfrm.sysctl_hdr = register_net_sysctl(net, "net/core", table); + if (!net->xfrm.sysctl_hdr) + goto out_register; + return 0; + +out_register: + kfree(table); +out_kmemdup: + return -ENOMEM; +} + +void __net_exit xfrm_sysctl_fini(struct net *net) +{ + struct ctl_table *table; + + table = net->xfrm.sysctl_hdr->ctl_table_arg; + unregister_net_sysctl_table(net->xfrm.sysctl_hdr); + kfree(table); +} +#else +int __net_init xfrm_sysctl_init(struct net *net) +{ + __xfrm_sysctl_init(net); + return 0; +} +#endif diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c new file mode 100644 index 000000000..d042ca012 --- /dev/null +++ b/net/xfrm/xfrm_user.c @@ -0,0 +1,3830 @@ +// SPDX-License-Identifier: GPL-2.0-only +/* xfrm_user.c: User interface to configure xfrm engine. + * + * Copyright (C) 2002 David S. Miller (davem@redhat.com) + * + * Changes: + * Mitsuru KANDA @USAGI + * Kazunori MIYAZAWA @USAGI + * Kunihiro Ishiguro + * IPv6 support + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if IS_ENABLED(CONFIG_IPV6) +#include +#endif +#include + +static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type, + struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[type]; + struct xfrm_algo *algp; + + if (!rt) + return 0; + + algp = nla_data(rt); + if (nla_len(rt) < (int)xfrm_alg_len(algp)) { + NL_SET_ERR_MSG(extack, "Invalid AUTH/CRYPT/COMP attribute length"); + return -EINVAL; + } + + switch (type) { + case XFRMA_ALG_AUTH: + case XFRMA_ALG_CRYPT: + case XFRMA_ALG_COMP: + break; + + default: + NL_SET_ERR_MSG(extack, "Invalid algorithm attribute type"); + return -EINVAL; + } + + algp->alg_name[sizeof(algp->alg_name) - 1] = '\0'; + return 0; +} + +static int verify_auth_trunc(struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_ALG_AUTH_TRUNC]; + struct xfrm_algo_auth *algp; + + if (!rt) + return 0; + + algp = nla_data(rt); + if (nla_len(rt) < (int)xfrm_alg_auth_len(algp)) { + NL_SET_ERR_MSG(extack, "Invalid AUTH_TRUNC attribute length"); + return -EINVAL; + } + + algp->alg_name[sizeof(algp->alg_name) - 1] = '\0'; + return 0; +} + +static int verify_aead(struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_ALG_AEAD]; + struct xfrm_algo_aead *algp; + + if (!rt) + return 0; + + algp = nla_data(rt); + if (nla_len(rt) < (int)aead_len(algp)) { + NL_SET_ERR_MSG(extack, "Invalid AEAD attribute length"); + return -EINVAL; + } + + algp->alg_name[sizeof(algp->alg_name) - 1] = '\0'; + return 0; +} + +static void verify_one_addr(struct nlattr **attrs, enum xfrm_attr_type_t type, + xfrm_address_t **addrp) +{ + struct nlattr *rt = attrs[type]; + + if (rt && addrp) + *addrp = nla_data(rt); +} + +static inline int verify_sec_ctx_len(struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_SEC_CTX]; + struct xfrm_user_sec_ctx *uctx; + + if (!rt) + return 0; + + uctx = nla_data(rt); + if (uctx->len > nla_len(rt) || + uctx->len != (sizeof(struct xfrm_user_sec_ctx) + uctx->ctx_len)) { + NL_SET_ERR_MSG(extack, "Invalid security context length"); + return -EINVAL; + } + + return 0; +} + +static inline int verify_replay(struct xfrm_usersa_info *p, + struct nlattr **attrs, + struct 
netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_REPLAY_ESN_VAL]; + struct xfrm_replay_state_esn *rs; + + if (!rt) { + if (p->flags & XFRM_STATE_ESN) { + NL_SET_ERR_MSG(extack, "Missing required attribute for ESN"); + return -EINVAL; + } + return 0; + } + + rs = nla_data(rt); + + if (rs->bmp_len > XFRMA_REPLAY_ESN_MAX / sizeof(rs->bmp[0]) / 8) { + NL_SET_ERR_MSG(extack, "ESN bitmap length must be <= 128"); + return -EINVAL; + } + + if (nla_len(rt) < (int)xfrm_replay_state_esn_len(rs) && + nla_len(rt) != sizeof(*rs)) { + NL_SET_ERR_MSG(extack, "ESN attribute is too short to fit the full bitmap length"); + return -EINVAL; + } + + /* As only ESP and AH support ESN feature. */ + if ((p->id.proto != IPPROTO_ESP) && (p->id.proto != IPPROTO_AH)) { + NL_SET_ERR_MSG(extack, "ESN only supported for ESP and AH"); + return -EINVAL; + } + + if (p->replay_window != 0) { + NL_SET_ERR_MSG(extack, "ESN not compatible with legacy replay_window"); + return -EINVAL; + } + + return 0; +} + +static int verify_newsa_info(struct xfrm_usersa_info *p, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + int err; + + err = -EINVAL; + switch (p->family) { + case AF_INET: + break; + + case AF_INET6: +#if IS_ENABLED(CONFIG_IPV6) + break; +#else + err = -EAFNOSUPPORT; + NL_SET_ERR_MSG(extack, "IPv6 support disabled"); + goto out; +#endif + + default: + NL_SET_ERR_MSG(extack, "Invalid address family"); + goto out; + } + + switch (p->sel.family) { + case AF_UNSPEC: + break; + + case AF_INET: + if (p->sel.prefixlen_d > 32 || p->sel.prefixlen_s > 32) { + NL_SET_ERR_MSG(extack, "Invalid prefix length in selector (must be <= 32 for IPv4)"); + goto out; + } + + break; + + case AF_INET6: +#if IS_ENABLED(CONFIG_IPV6) + if (p->sel.prefixlen_d > 128 || p->sel.prefixlen_s > 128) { + NL_SET_ERR_MSG(extack, "Invalid prefix length in selector (must be <= 128 for IPv6)"); + goto out; + } + + break; +#else + NL_SET_ERR_MSG(extack, "IPv6 support disabled"); + err = -EAFNOSUPPORT; + goto out; +#endif + + default: + NL_SET_ERR_MSG(extack, "Invalid address family in selector"); + goto out; + } + + err = -EINVAL; + switch (p->id.proto) { + case IPPROTO_AH: + if (!attrs[XFRMA_ALG_AUTH] && + !attrs[XFRMA_ALG_AUTH_TRUNC]) { + NL_SET_ERR_MSG(extack, "Missing required attribute for AH: AUTH_TRUNC or AUTH"); + goto out; + } + + if (attrs[XFRMA_ALG_AEAD] || + attrs[XFRMA_ALG_CRYPT] || + attrs[XFRMA_ALG_COMP] || + attrs[XFRMA_TFCPAD]) { + NL_SET_ERR_MSG(extack, "Invalid attributes for AH: AEAD, CRYPT, COMP, TFCPAD"); + goto out; + } + break; + + case IPPROTO_ESP: + if (attrs[XFRMA_ALG_COMP]) { + NL_SET_ERR_MSG(extack, "Invalid attribute for ESP: COMP"); + goto out; + } + + if (!attrs[XFRMA_ALG_AUTH] && + !attrs[XFRMA_ALG_AUTH_TRUNC] && + !attrs[XFRMA_ALG_CRYPT] && + !attrs[XFRMA_ALG_AEAD]) { + NL_SET_ERR_MSG(extack, "Missing required attribute for ESP: at least one of AUTH, AUTH_TRUNC, CRYPT, AEAD"); + goto out; + } + + if ((attrs[XFRMA_ALG_AUTH] || + attrs[XFRMA_ALG_AUTH_TRUNC] || + attrs[XFRMA_ALG_CRYPT]) && + attrs[XFRMA_ALG_AEAD]) { + NL_SET_ERR_MSG(extack, "Invalid attribute combination for ESP: AEAD can't be used with AUTH, AUTH_TRUNC, CRYPT"); + goto out; + } + + if (attrs[XFRMA_TFCPAD] && + p->mode != XFRM_MODE_TUNNEL) { + NL_SET_ERR_MSG(extack, "TFC padding can only be used in tunnel mode"); + goto out; + } + break; + + case IPPROTO_COMP: + if (!attrs[XFRMA_ALG_COMP]) { + NL_SET_ERR_MSG(extack, "Missing required attribute for COMP: COMP"); + goto out; + } + + if (attrs[XFRMA_ALG_AEAD] || + attrs[XFRMA_ALG_AUTH] || + 
attrs[XFRMA_ALG_AUTH_TRUNC] || + attrs[XFRMA_ALG_CRYPT] || + attrs[XFRMA_TFCPAD]) { + NL_SET_ERR_MSG(extack, "Invalid attributes for COMP: AEAD, AUTH, AUTH_TRUNC, CRYPT, TFCPAD"); + goto out; + } + + if (ntohl(p->id.spi) >= 0x10000) { + NL_SET_ERR_MSG(extack, "SPI is too large for COMP (must be < 0x10000)"); + goto out; + } + break; + +#if IS_ENABLED(CONFIG_IPV6) + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + if (attrs[XFRMA_ALG_COMP] || + attrs[XFRMA_ALG_AUTH] || + attrs[XFRMA_ALG_AUTH_TRUNC] || + attrs[XFRMA_ALG_AEAD] || + attrs[XFRMA_ALG_CRYPT] || + attrs[XFRMA_ENCAP] || + attrs[XFRMA_SEC_CTX] || + attrs[XFRMA_TFCPAD]) { + NL_SET_ERR_MSG(extack, "Invalid attributes for DSTOPTS/ROUTING"); + goto out; + } + + if (!attrs[XFRMA_COADDR]) { + NL_SET_ERR_MSG(extack, "Missing required COADDR attribute for DSTOPTS/ROUTING"); + goto out; + } + break; +#endif + + default: + NL_SET_ERR_MSG(extack, "Unsupported protocol"); + goto out; + } + + if ((err = verify_aead(attrs, extack))) + goto out; + if ((err = verify_auth_trunc(attrs, extack))) + goto out; + if ((err = verify_one_alg(attrs, XFRMA_ALG_AUTH, extack))) + goto out; + if ((err = verify_one_alg(attrs, XFRMA_ALG_CRYPT, extack))) + goto out; + if ((err = verify_one_alg(attrs, XFRMA_ALG_COMP, extack))) + goto out; + if ((err = verify_sec_ctx_len(attrs, extack))) + goto out; + if ((err = verify_replay(p, attrs, extack))) + goto out; + + err = -EINVAL; + switch (p->mode) { + case XFRM_MODE_TRANSPORT: + case XFRM_MODE_TUNNEL: + case XFRM_MODE_ROUTEOPTIMIZATION: + case XFRM_MODE_BEET: + break; + + default: + NL_SET_ERR_MSG(extack, "Unsupported mode"); + goto out; + } + + err = 0; + + if (attrs[XFRMA_MTIMER_THRESH]) { + if (!attrs[XFRMA_ENCAP]) { + NL_SET_ERR_MSG(extack, "MTIMER_THRESH attribute can only be set on ENCAP states"); + err = -EINVAL; + goto out; + } + } + +out: + return err; +} + +static int attach_one_algo(struct xfrm_algo **algpp, u8 *props, + struct xfrm_algo_desc *(*get_byname)(const char *, int), + struct nlattr *rta, struct netlink_ext_ack *extack) +{ + struct xfrm_algo *p, *ualg; + struct xfrm_algo_desc *algo; + + if (!rta) + return 0; + + ualg = nla_data(rta); + + algo = get_byname(ualg->alg_name, 1); + if (!algo) { + NL_SET_ERR_MSG(extack, "Requested COMP algorithm not found"); + return -ENOSYS; + } + *props = algo->desc.sadb_alg_id; + + p = kmemdup(ualg, xfrm_alg_len(ualg), GFP_KERNEL); + if (!p) + return -ENOMEM; + + strcpy(p->alg_name, algo->name); + *algpp = p; + return 0; +} + +static int attach_crypt(struct xfrm_state *x, struct nlattr *rta, + struct netlink_ext_ack *extack) +{ + struct xfrm_algo *p, *ualg; + struct xfrm_algo_desc *algo; + + if (!rta) + return 0; + + ualg = nla_data(rta); + + algo = xfrm_ealg_get_byname(ualg->alg_name, 1); + if (!algo) { + NL_SET_ERR_MSG(extack, "Requested CRYPT algorithm not found"); + return -ENOSYS; + } + x->props.ealgo = algo->desc.sadb_alg_id; + + p = kmemdup(ualg, xfrm_alg_len(ualg), GFP_KERNEL); + if (!p) + return -ENOMEM; + + strcpy(p->alg_name, algo->name); + x->ealg = p; + x->geniv = algo->uinfo.encr.geniv; + return 0; +} + +static int attach_auth(struct xfrm_algo_auth **algpp, u8 *props, + struct nlattr *rta, struct netlink_ext_ack *extack) +{ + struct xfrm_algo *ualg; + struct xfrm_algo_auth *p; + struct xfrm_algo_desc *algo; + + if (!rta) + return 0; + + ualg = nla_data(rta); + + algo = xfrm_aalg_get_byname(ualg->alg_name, 1); + if (!algo) { + NL_SET_ERR_MSG(extack, "Requested AUTH algorithm not found"); + return -ENOSYS; + } + *props = algo->desc.sadb_alg_id; + + p = 
kmalloc(sizeof(*p) + (ualg->alg_key_len + 7) / 8, GFP_KERNEL); + if (!p) + return -ENOMEM; + + strcpy(p->alg_name, algo->name); + p->alg_key_len = ualg->alg_key_len; + p->alg_trunc_len = algo->uinfo.auth.icv_truncbits; + memcpy(p->alg_key, ualg->alg_key, (ualg->alg_key_len + 7) / 8); + + *algpp = p; + return 0; +} + +static int attach_auth_trunc(struct xfrm_algo_auth **algpp, u8 *props, + struct nlattr *rta, struct netlink_ext_ack *extack) +{ + struct xfrm_algo_auth *p, *ualg; + struct xfrm_algo_desc *algo; + + if (!rta) + return 0; + + ualg = nla_data(rta); + + algo = xfrm_aalg_get_byname(ualg->alg_name, 1); + if (!algo) { + NL_SET_ERR_MSG(extack, "Requested AUTH_TRUNC algorithm not found"); + return -ENOSYS; + } + if (ualg->alg_trunc_len > algo->uinfo.auth.icv_fullbits) { + NL_SET_ERR_MSG(extack, "Invalid length requested for truncated ICV"); + return -EINVAL; + } + *props = algo->desc.sadb_alg_id; + + p = kmemdup(ualg, xfrm_alg_auth_len(ualg), GFP_KERNEL); + if (!p) + return -ENOMEM; + + strcpy(p->alg_name, algo->name); + if (!p->alg_trunc_len) + p->alg_trunc_len = algo->uinfo.auth.icv_truncbits; + + *algpp = p; + return 0; +} + +static int attach_aead(struct xfrm_state *x, struct nlattr *rta, + struct netlink_ext_ack *extack) +{ + struct xfrm_algo_aead *p, *ualg; + struct xfrm_algo_desc *algo; + + if (!rta) + return 0; + + ualg = nla_data(rta); + + algo = xfrm_aead_get_byname(ualg->alg_name, ualg->alg_icv_len, 1); + if (!algo) { + NL_SET_ERR_MSG(extack, "Requested AEAD algorithm not found"); + return -ENOSYS; + } + x->props.ealgo = algo->desc.sadb_alg_id; + + p = kmemdup(ualg, aead_len(ualg), GFP_KERNEL); + if (!p) + return -ENOMEM; + + strcpy(p->alg_name, algo->name); + x->aead = p; + x->geniv = algo->uinfo.aead.geniv; + return 0; +} + +static inline int xfrm_replay_verify_len(struct xfrm_replay_state_esn *replay_esn, + struct nlattr *rp) +{ + struct xfrm_replay_state_esn *up; + unsigned int ulen; + + if (!replay_esn || !rp) + return 0; + + up = nla_data(rp); + ulen = xfrm_replay_state_esn_len(up); + + /* Check the overall length and the internal bitmap length to avoid + * potential overflow. */ + if (nla_len(rp) < (int)ulen || + xfrm_replay_state_esn_len(replay_esn) != ulen || + replay_esn->bmp_len != up->bmp_len) + return -EINVAL; + + if (up->replay_window > up->bmp_len * sizeof(__u32) * 8) + return -EINVAL; + + return 0; +} + +static int xfrm_alloc_replay_state_esn(struct xfrm_replay_state_esn **replay_esn, + struct xfrm_replay_state_esn **preplay_esn, + struct nlattr *rta) +{ + struct xfrm_replay_state_esn *p, *pp, *up; + unsigned int klen, ulen; + + if (!rta) + return 0; + + up = nla_data(rta); + klen = xfrm_replay_state_esn_len(up); + ulen = nla_len(rta) >= (int)klen ? 
klen : sizeof(*up); + + p = kzalloc(klen, GFP_KERNEL); + if (!p) + return -ENOMEM; + + pp = kzalloc(klen, GFP_KERNEL); + if (!pp) { + kfree(p); + return -ENOMEM; + } + + memcpy(p, up, ulen); + memcpy(pp, up, ulen); + + *replay_esn = p; + *preplay_esn = pp; + + return 0; +} + +static inline unsigned int xfrm_user_sec_ctx_size(struct xfrm_sec_ctx *xfrm_ctx) +{ + unsigned int len = 0; + + if (xfrm_ctx) { + len += sizeof(struct xfrm_user_sec_ctx); + len += xfrm_ctx->ctx_len; + } + return len; +} + +static void copy_from_user_state(struct xfrm_state *x, struct xfrm_usersa_info *p) +{ + memcpy(&x->id, &p->id, sizeof(x->id)); + memcpy(&x->sel, &p->sel, sizeof(x->sel)); + memcpy(&x->lft, &p->lft, sizeof(x->lft)); + x->props.mode = p->mode; + x->props.replay_window = min_t(unsigned int, p->replay_window, + sizeof(x->replay.bitmap) * 8); + x->props.reqid = p->reqid; + x->props.family = p->family; + memcpy(&x->props.saddr, &p->saddr, sizeof(x->props.saddr)); + x->props.flags = p->flags; + + if (!x->sel.family && !(p->flags & XFRM_STATE_AF_UNSPEC)) + x->sel.family = p->family; +} + +/* + * someday when pfkey also has support, we could have the code + * somehow made shareable and move it to xfrm_state.c - JHS + * +*/ +static void xfrm_update_ae_params(struct xfrm_state *x, struct nlattr **attrs, + int update_esn) +{ + struct nlattr *rp = attrs[XFRMA_REPLAY_VAL]; + struct nlattr *re = update_esn ? attrs[XFRMA_REPLAY_ESN_VAL] : NULL; + struct nlattr *lt = attrs[XFRMA_LTIME_VAL]; + struct nlattr *et = attrs[XFRMA_ETIMER_THRESH]; + struct nlattr *rt = attrs[XFRMA_REPLAY_THRESH]; + struct nlattr *mt = attrs[XFRMA_MTIMER_THRESH]; + + if (re && x->replay_esn && x->preplay_esn) { + struct xfrm_replay_state_esn *replay_esn; + replay_esn = nla_data(re); + memcpy(x->replay_esn, replay_esn, + xfrm_replay_state_esn_len(replay_esn)); + memcpy(x->preplay_esn, replay_esn, + xfrm_replay_state_esn_len(replay_esn)); + } + + if (rp) { + struct xfrm_replay_state *replay; + replay = nla_data(rp); + memcpy(&x->replay, replay, sizeof(*replay)); + memcpy(&x->preplay, replay, sizeof(*replay)); + } + + if (lt) { + struct xfrm_lifetime_cur *ltime; + ltime = nla_data(lt); + x->curlft.bytes = ltime->bytes; + x->curlft.packets = ltime->packets; + x->curlft.add_time = ltime->add_time; + x->curlft.use_time = ltime->use_time; + } + + if (et) + x->replay_maxage = nla_get_u32(et); + + if (rt) + x->replay_maxdiff = nla_get_u32(rt); + + if (mt) + x->mapping_maxage = nla_get_u32(mt); +} + +static void xfrm_smark_init(struct nlattr **attrs, struct xfrm_mark *m) +{ + if (attrs[XFRMA_SET_MARK]) { + m->v = nla_get_u32(attrs[XFRMA_SET_MARK]); + if (attrs[XFRMA_SET_MARK_MASK]) + m->m = nla_get_u32(attrs[XFRMA_SET_MARK_MASK]); + else + m->m = 0xffffffff; + } else { + m->v = m->m = 0; + } +} + +static struct xfrm_state *xfrm_state_construct(struct net *net, + struct xfrm_usersa_info *p, + struct nlattr **attrs, + int *errp, + struct netlink_ext_ack *extack) +{ + struct xfrm_state *x = xfrm_state_alloc(net); + int err = -ENOMEM; + + if (!x) + goto error_no_put; + + copy_from_user_state(x, p); + + if (attrs[XFRMA_ENCAP]) { + x->encap = kmemdup(nla_data(attrs[XFRMA_ENCAP]), + sizeof(*x->encap), GFP_KERNEL); + if (x->encap == NULL) + goto error; + } + + if (attrs[XFRMA_COADDR]) { + x->coaddr = kmemdup(nla_data(attrs[XFRMA_COADDR]), + sizeof(*x->coaddr), GFP_KERNEL); + if (x->coaddr == NULL) + goto error; + } + + if (attrs[XFRMA_SA_EXTRA_FLAGS]) + x->props.extra_flags = nla_get_u32(attrs[XFRMA_SA_EXTRA_FLAGS]); + + if ((err = attach_aead(x, 
attrs[XFRMA_ALG_AEAD], extack))) + goto error; + if ((err = attach_auth_trunc(&x->aalg, &x->props.aalgo, + attrs[XFRMA_ALG_AUTH_TRUNC], extack))) + goto error; + if (!x->props.aalgo) { + if ((err = attach_auth(&x->aalg, &x->props.aalgo, + attrs[XFRMA_ALG_AUTH], extack))) + goto error; + } + if ((err = attach_crypt(x, attrs[XFRMA_ALG_CRYPT], extack))) + goto error; + if ((err = attach_one_algo(&x->calg, &x->props.calgo, + xfrm_calg_get_byname, + attrs[XFRMA_ALG_COMP], extack))) + goto error; + + if (attrs[XFRMA_TFCPAD]) + x->tfcpad = nla_get_u32(attrs[XFRMA_TFCPAD]); + + xfrm_mark_get(attrs, &x->mark); + + xfrm_smark_init(attrs, &x->props.smark); + + if (attrs[XFRMA_IF_ID]) + x->if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + err = __xfrm_init_state(x, false, attrs[XFRMA_OFFLOAD_DEV], extack); + if (err) + goto error; + + if (attrs[XFRMA_SEC_CTX]) { + err = security_xfrm_state_alloc(x, + nla_data(attrs[XFRMA_SEC_CTX])); + if (err) + goto error; + } + + if ((err = xfrm_alloc_replay_state_esn(&x->replay_esn, &x->preplay_esn, + attrs[XFRMA_REPLAY_ESN_VAL]))) + goto error; + + x->km.seq = p->seq; + x->replay_maxdiff = net->xfrm.sysctl_aevent_rseqth; + /* sysctl_xfrm_aevent_etime is in 100ms units */ + x->replay_maxage = (net->xfrm.sysctl_aevent_etime*HZ)/XFRM_AE_ETH_M; + + if ((err = xfrm_init_replay(x, extack))) + goto error; + + /* override default values from above */ + xfrm_update_ae_params(x, attrs, 0); + + /* configure the hardware if offload is requested */ + if (attrs[XFRMA_OFFLOAD_DEV]) { + err = xfrm_dev_state_add(net, x, + nla_data(attrs[XFRMA_OFFLOAD_DEV]), + extack); + if (err) + goto error; + } + + return x; + +error: + x->km.state = XFRM_STATE_DEAD; + xfrm_state_put(x); +error_no_put: + *errp = err; + return NULL; +} + +static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_usersa_info *p = nlmsg_data(nlh); + struct xfrm_state *x; + int err; + struct km_event c; + + err = verify_newsa_info(p, attrs, extack); + if (err) + return err; + + x = xfrm_state_construct(net, p, attrs, &err, extack); + if (!x) + return err; + + xfrm_state_hold(x); + if (nlh->nlmsg_type == XFRM_MSG_NEWSA) + err = xfrm_state_add(x); + else + err = xfrm_state_update(x); + + xfrm_audit_state_add(x, err ? 
0 : 1, true); + + if (err < 0) { + x->km.state = XFRM_STATE_DEAD; + xfrm_dev_state_delete(x); + __xfrm_state_put(x); + goto out; + } + + if (x->km.state == XFRM_STATE_VOID) + x->km.state = XFRM_STATE_VALID; + + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + c.event = nlh->nlmsg_type; + + km_state_notify(x, &c); +out: + xfrm_state_put(x); + return err; +} + +static struct xfrm_state *xfrm_user_state_lookup(struct net *net, + struct xfrm_usersa_id *p, + struct nlattr **attrs, + int *errp) +{ + struct xfrm_state *x = NULL; + struct xfrm_mark m; + int err; + u32 mark = xfrm_mark_get(attrs, &m); + + if (xfrm_id_proto_match(p->proto, IPSEC_PROTO_ANY)) { + err = -ESRCH; + x = xfrm_state_lookup(net, mark, &p->daddr, p->spi, p->proto, p->family); + } else { + xfrm_address_t *saddr = NULL; + + verify_one_addr(attrs, XFRMA_SRCADDR, &saddr); + if (!saddr) { + err = -EINVAL; + goto out; + } + + err = -ESRCH; + x = xfrm_state_lookup_byaddr(net, mark, + &p->daddr, saddr, + p->proto, p->family); + } + + out: + if (!x && errp) + *errp = err; + return x; +} + +static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state *x; + int err = -ESRCH; + struct km_event c; + struct xfrm_usersa_id *p = nlmsg_data(nlh); + + x = xfrm_user_state_lookup(net, p, attrs, &err); + if (x == NULL) + return err; + + if ((err = security_xfrm_state_delete(x)) != 0) + goto out; + + if (xfrm_state_kern(x)) { + err = -EPERM; + goto out; + } + + err = xfrm_state_delete(x); + + if (err < 0) + goto out; + + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + c.event = nlh->nlmsg_type; + km_state_notify(x, &c); + +out: + xfrm_audit_state_delete(x, err ? 0 : 1, true); + xfrm_state_put(x); + return err; +} + +static void copy_to_user_state(struct xfrm_state *x, struct xfrm_usersa_info *p) +{ + memset(p, 0, sizeof(*p)); + memcpy(&p->id, &x->id, sizeof(p->id)); + memcpy(&p->sel, &x->sel, sizeof(p->sel)); + memcpy(&p->lft, &x->lft, sizeof(p->lft)); + memcpy(&p->curlft, &x->curlft, sizeof(p->curlft)); + put_unaligned(x->stats.replay_window, &p->stats.replay_window); + put_unaligned(x->stats.replay, &p->stats.replay); + put_unaligned(x->stats.integrity_failed, &p->stats.integrity_failed); + memcpy(&p->saddr, &x->props.saddr, sizeof(p->saddr)); + p->mode = x->props.mode; + p->replay_window = x->props.replay_window; + p->reqid = x->props.reqid; + p->family = x->props.family; + p->flags = x->props.flags; + p->seq = x->km.seq; +} + +struct xfrm_dump_info { + struct sk_buff *in_skb; + struct sk_buff *out_skb; + u32 nlmsg_seq; + u16 nlmsg_flags; +}; + +static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb) +{ + struct xfrm_user_sec_ctx *uctx; + struct nlattr *attr; + int ctx_size = sizeof(*uctx) + s->ctx_len; + + attr = nla_reserve(skb, XFRMA_SEC_CTX, ctx_size); + if (attr == NULL) + return -EMSGSIZE; + + uctx = nla_data(attr); + uctx->exttype = XFRMA_SEC_CTX; + uctx->len = ctx_size; + uctx->ctx_doi = s->ctx_doi; + uctx->ctx_alg = s->ctx_alg; + uctx->ctx_len = s->ctx_len; + memcpy(uctx + 1, s->ctx_str, s->ctx_len); + + return 0; +} + +static int copy_user_offload(struct xfrm_dev_offload *xso, struct sk_buff *skb) +{ + struct xfrm_user_offload *xuo; + struct nlattr *attr; + + attr = nla_reserve(skb, XFRMA_OFFLOAD_DEV, sizeof(*xuo)); + if (attr == NULL) + return -EMSGSIZE; + + xuo = nla_data(attr); + memset(xuo, 0, sizeof(*xuo)); + xuo->ifindex = xso->dev->ifindex; + if (xso->dir == XFRM_DEV_OFFLOAD_IN) + 
xuo->flags = XFRM_OFFLOAD_INBOUND; + + return 0; +} + +static bool xfrm_redact(void) +{ + return IS_ENABLED(CONFIG_SECURITY) && + security_locked_down(LOCKDOWN_XFRM_SECRET); +} + +static int copy_to_user_auth(struct xfrm_algo_auth *auth, struct sk_buff *skb) +{ + struct xfrm_algo *algo; + struct xfrm_algo_auth *ap; + struct nlattr *nla; + bool redact_secret = xfrm_redact(); + + nla = nla_reserve(skb, XFRMA_ALG_AUTH, + sizeof(*algo) + (auth->alg_key_len + 7) / 8); + if (!nla) + return -EMSGSIZE; + algo = nla_data(nla); + strncpy(algo->alg_name, auth->alg_name, sizeof(algo->alg_name)); + + if (redact_secret && auth->alg_key_len) + memset(algo->alg_key, 0, (auth->alg_key_len + 7) / 8); + else + memcpy(algo->alg_key, auth->alg_key, + (auth->alg_key_len + 7) / 8); + algo->alg_key_len = auth->alg_key_len; + + nla = nla_reserve(skb, XFRMA_ALG_AUTH_TRUNC, xfrm_alg_auth_len(auth)); + if (!nla) + return -EMSGSIZE; + ap = nla_data(nla); + memcpy(ap, auth, sizeof(struct xfrm_algo_auth)); + if (redact_secret && auth->alg_key_len) + memset(ap->alg_key, 0, (auth->alg_key_len + 7) / 8); + else + memcpy(ap->alg_key, auth->alg_key, + (auth->alg_key_len + 7) / 8); + return 0; +} + +static int copy_to_user_aead(struct xfrm_algo_aead *aead, struct sk_buff *skb) +{ + struct nlattr *nla = nla_reserve(skb, XFRMA_ALG_AEAD, aead_len(aead)); + struct xfrm_algo_aead *ap; + bool redact_secret = xfrm_redact(); + + if (!nla) + return -EMSGSIZE; + + ap = nla_data(nla); + strscpy_pad(ap->alg_name, aead->alg_name, sizeof(ap->alg_name)); + ap->alg_key_len = aead->alg_key_len; + ap->alg_icv_len = aead->alg_icv_len; + + if (redact_secret && aead->alg_key_len) + memset(ap->alg_key, 0, (aead->alg_key_len + 7) / 8); + else + memcpy(ap->alg_key, aead->alg_key, + (aead->alg_key_len + 7) / 8); + return 0; +} + +static int copy_to_user_ealg(struct xfrm_algo *ealg, struct sk_buff *skb) +{ + struct xfrm_algo *ap; + bool redact_secret = xfrm_redact(); + struct nlattr *nla = nla_reserve(skb, XFRMA_ALG_CRYPT, + xfrm_alg_len(ealg)); + if (!nla) + return -EMSGSIZE; + + ap = nla_data(nla); + strscpy_pad(ap->alg_name, ealg->alg_name, sizeof(ap->alg_name)); + ap->alg_key_len = ealg->alg_key_len; + + if (redact_secret && ealg->alg_key_len) + memset(ap->alg_key, 0, (ealg->alg_key_len + 7) / 8); + else + memcpy(ap->alg_key, ealg->alg_key, + (ealg->alg_key_len + 7) / 8); + + return 0; +} + +static int copy_to_user_calg(struct xfrm_algo *calg, struct sk_buff *skb) +{ + struct nlattr *nla = nla_reserve(skb, XFRMA_ALG_COMP, sizeof(*calg)); + struct xfrm_algo *ap; + + if (!nla) + return -EMSGSIZE; + + ap = nla_data(nla); + strscpy_pad(ap->alg_name, calg->alg_name, sizeof(ap->alg_name)); + ap->alg_key_len = 0; + + return 0; +} + +static int copy_to_user_encap(struct xfrm_encap_tmpl *ep, struct sk_buff *skb) +{ + struct nlattr *nla = nla_reserve(skb, XFRMA_ENCAP, sizeof(*ep)); + struct xfrm_encap_tmpl *uep; + + if (!nla) + return -EMSGSIZE; + + uep = nla_data(nla); + memset(uep, 0, sizeof(*uep)); + + uep->encap_type = ep->encap_type; + uep->encap_sport = ep->encap_sport; + uep->encap_dport = ep->encap_dport; + uep->encap_oa = ep->encap_oa; + + return 0; +} + +static int xfrm_smark_put(struct sk_buff *skb, struct xfrm_mark *m) +{ + int ret = 0; + + if (m->v | m->m) { + ret = nla_put_u32(skb, XFRMA_SET_MARK, m->v); + if (!ret) + ret = nla_put_u32(skb, XFRMA_SET_MARK_MASK, m->m); + } + return ret; +} + +/* Don't change this without updating xfrm_sa_len! 
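+ * Every optional attribute emitted by copy_to_user_state_extra() needs a
+ * matching size contribution in xfrm_sa_len(), otherwise the notification
+ * skb can come up short.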
*/ +static int copy_to_user_state_extra(struct xfrm_state *x, + struct xfrm_usersa_info *p, + struct sk_buff *skb) +{ + int ret = 0; + + copy_to_user_state(x, p); + + if (x->props.extra_flags) { + ret = nla_put_u32(skb, XFRMA_SA_EXTRA_FLAGS, + x->props.extra_flags); + if (ret) + goto out; + } + + if (x->coaddr) { + ret = nla_put(skb, XFRMA_COADDR, sizeof(*x->coaddr), x->coaddr); + if (ret) + goto out; + } + if (x->lastused) { + ret = nla_put_u64_64bit(skb, XFRMA_LASTUSED, x->lastused, + XFRMA_PAD); + if (ret) + goto out; + } + if (x->aead) { + ret = copy_to_user_aead(x->aead, skb); + if (ret) + goto out; + } + if (x->aalg) { + ret = copy_to_user_auth(x->aalg, skb); + if (ret) + goto out; + } + if (x->ealg) { + ret = copy_to_user_ealg(x->ealg, skb); + if (ret) + goto out; + } + if (x->calg) { + ret = copy_to_user_calg(x->calg, skb); + if (ret) + goto out; + } + if (x->encap) { + ret = copy_to_user_encap(x->encap, skb); + if (ret) + goto out; + } + if (x->tfcpad) { + ret = nla_put_u32(skb, XFRMA_TFCPAD, x->tfcpad); + if (ret) + goto out; + } + ret = xfrm_mark_put(skb, &x->mark); + if (ret) + goto out; + + ret = xfrm_smark_put(skb, &x->props.smark); + if (ret) + goto out; + + if (x->replay_esn) + ret = nla_put(skb, XFRMA_REPLAY_ESN_VAL, + xfrm_replay_state_esn_len(x->replay_esn), + x->replay_esn); + else + ret = nla_put(skb, XFRMA_REPLAY_VAL, sizeof(x->replay), + &x->replay); + if (ret) + goto out; + if(x->xso.dev) + ret = copy_user_offload(&x->xso, skb); + if (ret) + goto out; + if (x->if_id) { + ret = nla_put_u32(skb, XFRMA_IF_ID, x->if_id); + if (ret) + goto out; + } + if (x->security) { + ret = copy_sec_ctx(x->security, skb); + if (ret) + goto out; + } + if (x->mapping_maxage) + ret = nla_put_u32(skb, XFRMA_MTIMER_THRESH, x->mapping_maxage); +out: + return ret; +} + +static int dump_one_state(struct xfrm_state *x, int count, void *ptr) +{ + struct xfrm_dump_info *sp = ptr; + struct sk_buff *in_skb = sp->in_skb; + struct sk_buff *skb = sp->out_skb; + struct xfrm_translator *xtr; + struct xfrm_usersa_info *p; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, NETLINK_CB(in_skb).portid, sp->nlmsg_seq, + XFRM_MSG_NEWSA, sizeof(*p), sp->nlmsg_flags); + if (nlh == NULL) + return -EMSGSIZE; + + p = nlmsg_data(nlh); + + err = copy_to_user_state_extra(x, p, skb); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + nlmsg_end(skb, nlh); + + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlh); + + xfrm_put_translator(xtr); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + } + + return 0; +} + +static int xfrm_dump_sa_done(struct netlink_callback *cb) +{ + struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1]; + struct sock *sk = cb->skb->sk; + struct net *net = sock_net(sk); + + if (cb->args[0]) + xfrm_state_walk_done(walk, net); + return 0; +} + +static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1]; + struct xfrm_dump_info info; + + BUILD_BUG_ON(sizeof(struct xfrm_state_walk) > + sizeof(cb->args) - sizeof(cb->args[0])); + + info.in_skb = cb->skb; + info.out_skb = skb; + info.nlmsg_seq = cb->nlh->nlmsg_seq; + info.nlmsg_flags = NLM_F_MULTI; + + if (!cb->args[0]) { + struct nlattr *attrs[XFRMA_MAX+1]; + struct xfrm_address_filter *filter = NULL; + u8 proto = 0; + int err; + + err = nlmsg_parse_deprecated(cb->nlh, 0, attrs, XFRMA_MAX, + xfrma_policy, cb->extack); + if (err < 0) + return err; + + if 
(attrs[XFRMA_ADDRESS_FILTER]) { + filter = kmemdup(nla_data(attrs[XFRMA_ADDRESS_FILTER]), + sizeof(*filter), GFP_KERNEL); + if (filter == NULL) + return -ENOMEM; + + /* see addr_match(), (prefix length >> 5) << 2 + * will be used to compare xfrm_address_t + */ + if (filter->splen > (sizeof(xfrm_address_t) << 3) || + filter->dplen > (sizeof(xfrm_address_t) << 3)) { + kfree(filter); + return -EINVAL; + } + } + + if (attrs[XFRMA_PROTO]) + proto = nla_get_u8(attrs[XFRMA_PROTO]); + + xfrm_state_walk_init(walk, proto, filter); + cb->args[0] = 1; + } + + (void) xfrm_state_walk(net, walk, dump_one_state, &info); + + return skb->len; +} + +static struct sk_buff *xfrm_state_netlink(struct sk_buff *in_skb, + struct xfrm_state *x, u32 seq) +{ + struct xfrm_dump_info info; + struct sk_buff *skb; + int err; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); + if (!skb) + return ERR_PTR(-ENOMEM); + + info.in_skb = in_skb; + info.out_skb = skb; + info.nlmsg_seq = seq; + info.nlmsg_flags = 0; + + err = dump_one_state(x, 0, &info); + if (err) { + kfree_skb(skb); + return ERR_PTR(err); + } + + return skb; +} + +/* A wrapper for nlmsg_multicast() checking that nlsk is still available. + * Must be called with RCU read lock. + */ +static inline int xfrm_nlmsg_multicast(struct net *net, struct sk_buff *skb, + u32 pid, unsigned int group) +{ + struct sock *nlsk = rcu_dereference(net->xfrm.nlsk); + struct xfrm_translator *xtr; + + if (!nlsk) { + kfree_skb(skb); + return -EPIPE; + } + + xtr = xfrm_get_translator(); + if (xtr) { + int err = xtr->alloc_compat(skb, nlmsg_hdr(skb)); + + xfrm_put_translator(xtr); + if (err) { + kfree_skb(skb); + return err; + } + } + + return nlmsg_multicast(nlsk, skb, pid, group, GFP_ATOMIC); +} + +static inline unsigned int xfrm_spdinfo_msgsize(void) +{ + return NLMSG_ALIGN(4) + + nla_total_size(sizeof(struct xfrmu_spdinfo)) + + nla_total_size(sizeof(struct xfrmu_spdhinfo)) + + nla_total_size(sizeof(struct xfrmu_spdhthresh)) + + nla_total_size(sizeof(struct xfrmu_spdhthresh)); +} + +static int build_spdinfo(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, u32 flags) +{ + struct xfrmk_spdinfo si; + struct xfrmu_spdinfo spc; + struct xfrmu_spdhinfo sph; + struct xfrmu_spdhthresh spt4, spt6; + struct nlmsghdr *nlh; + int err; + u32 *f; + unsigned lseq; + + nlh = nlmsg_put(skb, portid, seq, XFRM_MSG_NEWSPDINFO, sizeof(u32), 0); + if (nlh == NULL) /* shouldn't really happen ... 
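+ * (the reply skb was sized by the caller with xfrm_spdinfo_msgsize(),
+ * so the netlink header always fits)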
*/ + return -EMSGSIZE; + + f = nlmsg_data(nlh); + *f = flags; + xfrm_spd_getinfo(net, &si); + spc.incnt = si.incnt; + spc.outcnt = si.outcnt; + spc.fwdcnt = si.fwdcnt; + spc.inscnt = si.inscnt; + spc.outscnt = si.outscnt; + spc.fwdscnt = si.fwdscnt; + sph.spdhcnt = si.spdhcnt; + sph.spdhmcnt = si.spdhmcnt; + + do { + lseq = read_seqbegin(&net->xfrm.policy_hthresh.lock); + + spt4.lbits = net->xfrm.policy_hthresh.lbits4; + spt4.rbits = net->xfrm.policy_hthresh.rbits4; + spt6.lbits = net->xfrm.policy_hthresh.lbits6; + spt6.rbits = net->xfrm.policy_hthresh.rbits6; + } while (read_seqretry(&net->xfrm.policy_hthresh.lock, lseq)); + + err = nla_put(skb, XFRMA_SPD_INFO, sizeof(spc), &spc); + if (!err) + err = nla_put(skb, XFRMA_SPD_HINFO, sizeof(sph), &sph); + if (!err) + err = nla_put(skb, XFRMA_SPD_IPV4_HTHRESH, sizeof(spt4), &spt4); + if (!err) + err = nla_put(skb, XFRMA_SPD_IPV6_HTHRESH, sizeof(spt6), &spt6); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_set_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrmu_spdhthresh *thresh4 = NULL; + struct xfrmu_spdhthresh *thresh6 = NULL; + + /* selector prefixlen thresholds to hash policies */ + if (attrs[XFRMA_SPD_IPV4_HTHRESH]) { + struct nlattr *rta = attrs[XFRMA_SPD_IPV4_HTHRESH]; + + if (nla_len(rta) < sizeof(*thresh4)) + return -EINVAL; + thresh4 = nla_data(rta); + if (thresh4->lbits > 32 || thresh4->rbits > 32) + return -EINVAL; + } + if (attrs[XFRMA_SPD_IPV6_HTHRESH]) { + struct nlattr *rta = attrs[XFRMA_SPD_IPV6_HTHRESH]; + + if (nla_len(rta) < sizeof(*thresh6)) + return -EINVAL; + thresh6 = nla_data(rta); + if (thresh6->lbits > 128 || thresh6->rbits > 128) + return -EINVAL; + } + + if (thresh4 || thresh6) { + write_seqlock(&net->xfrm.policy_hthresh.lock); + if (thresh4) { + net->xfrm.policy_hthresh.lbits4 = thresh4->lbits; + net->xfrm.policy_hthresh.rbits4 = thresh4->rbits; + } + if (thresh6) { + net->xfrm.policy_hthresh.lbits6 = thresh6->lbits; + net->xfrm.policy_hthresh.rbits6 = thresh6->rbits; + } + write_sequnlock(&net->xfrm.policy_hthresh.lock); + + xfrm_policy_hash_rebuild(net); + } + + return 0; +} + +static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct sk_buff *r_skb; + u32 *flags = nlmsg_data(nlh); + u32 sportid = NETLINK_CB(skb).portid; + u32 seq = nlh->nlmsg_seq; + int err; + + r_skb = nlmsg_new(xfrm_spdinfo_msgsize(), GFP_ATOMIC); + if (r_skb == NULL) + return -ENOMEM; + + err = build_spdinfo(r_skb, net, sportid, seq, *flags); + BUG_ON(err < 0); + + return nlmsg_unicast(net->xfrm.nlsk, r_skb, sportid); +} + +static inline unsigned int xfrm_sadinfo_msgsize(void) +{ + return NLMSG_ALIGN(4) + + nla_total_size(sizeof(struct xfrmu_sadhinfo)) + + nla_total_size(4); /* XFRMA_SAD_CNT */ +} + +static int build_sadinfo(struct sk_buff *skb, struct net *net, + u32 portid, u32 seq, u32 flags) +{ + struct xfrmk_sadinfo si; + struct xfrmu_sadhinfo sh; + struct nlmsghdr *nlh; + int err; + u32 *f; + + nlh = nlmsg_put(skb, portid, seq, XFRM_MSG_NEWSADINFO, sizeof(u32), 0); + if (nlh == NULL) /* shouldn't really happen ... 
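+ * (the reply skb was sized by the caller with xfrm_sadinfo_msgsize(),
+ * so the netlink header always fits)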
*/ + return -EMSGSIZE; + + f = nlmsg_data(nlh); + *f = flags; + xfrm_sad_getinfo(net, &si); + + sh.sadhmcnt = si.sadhmcnt; + sh.sadhcnt = si.sadhcnt; + + err = nla_put_u32(skb, XFRMA_SAD_CNT, si.sadcnt); + if (!err) + err = nla_put(skb, XFRMA_SAD_HINFO, sizeof(sh), &sh); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct sk_buff *r_skb; + u32 *flags = nlmsg_data(nlh); + u32 sportid = NETLINK_CB(skb).portid; + u32 seq = nlh->nlmsg_seq; + int err; + + r_skb = nlmsg_new(xfrm_sadinfo_msgsize(), GFP_ATOMIC); + if (r_skb == NULL) + return -ENOMEM; + + err = build_sadinfo(r_skb, net, sportid, seq, *flags); + BUG_ON(err < 0); + + return nlmsg_unicast(net->xfrm.nlsk, r_skb, sportid); +} + +static int xfrm_get_sa(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_usersa_id *p = nlmsg_data(nlh); + struct xfrm_state *x; + struct sk_buff *resp_skb; + int err = -ESRCH; + + x = xfrm_user_state_lookup(net, p, attrs, &err); + if (x == NULL) + goto out_noput; + + resp_skb = xfrm_state_netlink(skb, x, nlh->nlmsg_seq); + if (IS_ERR(resp_skb)) { + err = PTR_ERR(resp_skb); + } else { + err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, NETLINK_CB(skb).portid); + } + xfrm_state_put(x); +out_noput: + return err; +} + +static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state *x; + struct xfrm_userspi_info *p; + struct xfrm_translator *xtr; + struct sk_buff *resp_skb; + xfrm_address_t *daddr; + int family; + int err; + u32 mark; + struct xfrm_mark m; + u32 if_id = 0; + + p = nlmsg_data(nlh); + err = verify_spi_info(p->info.id.proto, p->min, p->max); + if (err) + goto out_noput; + + family = p->info.family; + daddr = &p->info.id.daddr; + + x = NULL; + + mark = xfrm_mark_get(attrs, &m); + + if (attrs[XFRMA_IF_ID]) + if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + if (p->info.seq) { + x = xfrm_find_acq_byseq(net, mark, p->info.seq); + if (x && !xfrm_addr_equal(&x->id.daddr, daddr, family)) { + xfrm_state_put(x); + x = NULL; + } + } + + if (!x) + x = xfrm_find_acq(net, &m, p->info.mode, p->info.reqid, + if_id, p->info.id.proto, daddr, + &p->info.saddr, 1, + family); + err = -ENOENT; + if (x == NULL) + goto out_noput; + + err = xfrm_alloc_spi(x, p->min, p->max); + if (err) + goto out; + + resp_skb = xfrm_state_netlink(skb, x, nlh->nlmsg_seq); + if (IS_ERR(resp_skb)) { + err = PTR_ERR(resp_skb); + goto out; + } + + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlmsg_hdr(skb)); + + xfrm_put_translator(xtr); + if (err) { + kfree_skb(resp_skb); + goto out; + } + } + + err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, NETLINK_CB(skb).portid); + +out: + xfrm_state_put(x); +out_noput: + return err; +} + +static int verify_policy_dir(u8 dir, struct netlink_ext_ack *extack) +{ + switch (dir) { + case XFRM_POLICY_IN: + case XFRM_POLICY_OUT: + case XFRM_POLICY_FWD: + break; + + default: + NL_SET_ERR_MSG(extack, "Invalid policy direction"); + return -EINVAL; + } + + return 0; +} + +static int verify_policy_type(u8 type, struct netlink_ext_ack *extack) +{ + switch (type) { + case XFRM_POLICY_TYPE_MAIN: +#ifdef CONFIG_XFRM_SUB_POLICY + case 
XFRM_POLICY_TYPE_SUB: +#endif + break; + + default: + NL_SET_ERR_MSG(extack, "Invalid policy type"); + return -EINVAL; + } + + return 0; +} + +static int verify_newpolicy_info(struct xfrm_userpolicy_info *p, + struct netlink_ext_ack *extack) +{ + int ret; + + switch (p->share) { + case XFRM_SHARE_ANY: + case XFRM_SHARE_SESSION: + case XFRM_SHARE_USER: + case XFRM_SHARE_UNIQUE: + break; + + default: + NL_SET_ERR_MSG(extack, "Invalid policy share"); + return -EINVAL; + } + + switch (p->action) { + case XFRM_POLICY_ALLOW: + case XFRM_POLICY_BLOCK: + break; + + default: + NL_SET_ERR_MSG(extack, "Invalid policy action"); + return -EINVAL; + } + + switch (p->sel.family) { + case AF_INET: + if (p->sel.prefixlen_d > 32 || p->sel.prefixlen_s > 32) { + NL_SET_ERR_MSG(extack, "Invalid prefix length in selector (must be <= 32 for IPv4)"); + return -EINVAL; + } + + break; + + case AF_INET6: +#if IS_ENABLED(CONFIG_IPV6) + if (p->sel.prefixlen_d > 128 || p->sel.prefixlen_s > 128) { + NL_SET_ERR_MSG(extack, "Invalid prefix length in selector (must be <= 128 for IPv6)"); + return -EINVAL; + } + + break; +#else + NL_SET_ERR_MSG(extack, "IPv6 support disabled"); + return -EAFNOSUPPORT; +#endif + + default: + NL_SET_ERR_MSG(extack, "Invalid selector family"); + return -EINVAL; + } + + ret = verify_policy_dir(p->dir, extack); + if (ret) + return ret; + if (p->index && (xfrm_policy_id2dir(p->index) != p->dir)) { + NL_SET_ERR_MSG(extack, "Policy index doesn't match direction"); + return -EINVAL; + } + + return 0; +} + +static int copy_from_user_sec_ctx(struct xfrm_policy *pol, struct nlattr **attrs) +{ + struct nlattr *rt = attrs[XFRMA_SEC_CTX]; + struct xfrm_user_sec_ctx *uctx; + + if (!rt) + return 0; + + uctx = nla_data(rt); + return security_xfrm_policy_alloc(&pol->security, uctx, GFP_KERNEL); +} + +static void copy_templates(struct xfrm_policy *xp, struct xfrm_user_tmpl *ut, + int nr) +{ + int i; + + xp->xfrm_nr = nr; + for (i = 0; i < nr; i++, ut++) { + struct xfrm_tmpl *t = &xp->xfrm_vec[i]; + + memcpy(&t->id, &ut->id, sizeof(struct xfrm_id)); + memcpy(&t->saddr, &ut->saddr, + sizeof(xfrm_address_t)); + t->reqid = ut->reqid; + t->mode = ut->mode; + t->share = ut->share; + t->optional = ut->optional; + t->aalgos = ut->aalgos; + t->ealgos = ut->ealgos; + t->calgos = ut->calgos; + /* If all masks are ~0, then we allow all algorithms. */ + t->allalgs = !~(t->aalgos & t->ealgos & t->calgos); + t->encap_family = ut->family; + } +} + +static int validate_tmpl(int nr, struct xfrm_user_tmpl *ut, u16 family, + int dir, struct netlink_ext_ack *extack) +{ + u16 prev_family; + int i; + + if (nr > XFRM_MAX_DEPTH) { + NL_SET_ERR_MSG(extack, "Template count must be <= XFRM_MAX_DEPTH (" __stringify(XFRM_MAX_DEPTH) ")"); + return -EINVAL; + } + + prev_family = family; + + for (i = 0; i < nr; i++) { + /* We never validated the ut->family value, so many + * applications simply leave it at zero. The check was + * never made and ut->family was ignored because all + * templates could be assumed to have the same family as + * the policy itself. Now that we will have ipv4-in-ipv6 + * and ipv6-in-ipv4 tunnels, this is no longer true. 
+ */ + if (!ut[i].family) + ut[i].family = family; + + switch (ut[i].mode) { + case XFRM_MODE_TUNNEL: + case XFRM_MODE_BEET: + if (ut[i].optional && dir == XFRM_POLICY_OUT) { + NL_SET_ERR_MSG(extack, "Mode in optional template not allowed in outbound policy"); + return -EINVAL; + } + break; + default: + if (ut[i].family != prev_family) { + NL_SET_ERR_MSG(extack, "Mode in template doesn't support a family change"); + return -EINVAL; + } + break; + } + if (ut[i].mode >= XFRM_MODE_MAX) { + NL_SET_ERR_MSG(extack, "Mode in template must be < XFRM_MODE_MAX (" __stringify(XFRM_MODE_MAX) ")"); + return -EINVAL; + } + + prev_family = ut[i].family; + + switch (ut[i].family) { + case AF_INET: + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + break; +#endif + default: + NL_SET_ERR_MSG(extack, "Invalid family in template"); + return -EINVAL; + } + + if (!xfrm_id_proto_valid(ut[i].id.proto)) { + NL_SET_ERR_MSG(extack, "Invalid XFRM protocol in template"); + return -EINVAL; + } + } + + return 0; +} + +static int copy_from_user_tmpl(struct xfrm_policy *pol, struct nlattr **attrs, + int dir, struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_TMPL]; + + if (!rt) { + pol->xfrm_nr = 0; + } else { + struct xfrm_user_tmpl *utmpl = nla_data(rt); + int nr = nla_len(rt) / sizeof(*utmpl); + int err; + + err = validate_tmpl(nr, utmpl, pol->family, dir, extack); + if (err) + return err; + + copy_templates(pol, utmpl, nr); + } + return 0; +} + +static int copy_from_user_policy_type(u8 *tp, struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct nlattr *rt = attrs[XFRMA_POLICY_TYPE]; + struct xfrm_userpolicy_type *upt; + u8 type = XFRM_POLICY_TYPE_MAIN; + int err; + + if (rt) { + upt = nla_data(rt); + type = upt->type; + } + + err = verify_policy_type(type, extack); + if (err) + return err; + + *tp = type; + return 0; +} + +static void copy_from_user_policy(struct xfrm_policy *xp, struct xfrm_userpolicy_info *p) +{ + xp->priority = p->priority; + xp->index = p->index; + memcpy(&xp->selector, &p->sel, sizeof(xp->selector)); + memcpy(&xp->lft, &p->lft, sizeof(xp->lft)); + xp->action = p->action; + xp->flags = p->flags; + xp->family = p->sel.family; + /* XXX xp->share = p->share; */ +} + +static void copy_to_user_policy(struct xfrm_policy *xp, struct xfrm_userpolicy_info *p, int dir) +{ + memset(p, 0, sizeof(*p)); + memcpy(&p->sel, &xp->selector, sizeof(p->sel)); + memcpy(&p->lft, &xp->lft, sizeof(p->lft)); + memcpy(&p->curlft, &xp->curlft, sizeof(p->curlft)); + p->priority = xp->priority; + p->index = xp->index; + p->sel.family = xp->family; + p->dir = dir; + p->action = xp->action; + p->flags = xp->flags; + p->share = XFRM_SHARE_ANY; /* XXX xp->share */ +} + +static struct xfrm_policy *xfrm_policy_construct(struct net *net, + struct xfrm_userpolicy_info *p, + struct nlattr **attrs, + int *errp, + struct netlink_ext_ack *extack) +{ + struct xfrm_policy *xp = xfrm_policy_alloc(net, GFP_KERNEL); + int err; + + if (!xp) { + *errp = -ENOMEM; + return NULL; + } + + copy_from_user_policy(xp, p); + + err = copy_from_user_policy_type(&xp->type, attrs, extack); + if (err) + goto error; + + if (!(err = copy_from_user_tmpl(xp, attrs, p->dir, extack))) + err = copy_from_user_sec_ctx(xp, attrs); + if (err) + goto error; + + xfrm_mark_get(attrs, &xp->mark); + + if (attrs[XFRMA_IF_ID]) + xp->if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + return xp; + error: + *errp = err; + xp->walk.dead = 1; + xfrm_policy_destroy(xp); + return NULL; +} + +static int xfrm_add_policy(struct sk_buff *skb, struct 
nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_userpolicy_info *p = nlmsg_data(nlh); + struct xfrm_policy *xp; + struct km_event c; + int err; + int excl; + + err = verify_newpolicy_info(p, extack); + if (err) + return err; + err = verify_sec_ctx_len(attrs, extack); + if (err) + return err; + + xp = xfrm_policy_construct(net, p, attrs, &err, extack); + if (!xp) + return err; + + /* shouldn't excl be based on nlh flags?? + * Aha! this is anti-netlink really i.e more pfkey derived + * in netlink excl is a flag and you wouldn't need + * a type XFRM_MSG_UPDPOLICY - JHS */ + excl = nlh->nlmsg_type == XFRM_MSG_NEWPOLICY; + err = xfrm_policy_insert(p->dir, xp, excl); + xfrm_audit_policy_add(xp, err ? 0 : 1, true); + + if (err) { + security_xfrm_policy_free(xp->security); + kfree(xp); + return err; + } + + c.event = nlh->nlmsg_type; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + km_policy_notify(xp, p->dir, &c); + + xfrm_pol_put(xp); + + return 0; +} + +static int copy_to_user_tmpl(struct xfrm_policy *xp, struct sk_buff *skb) +{ + struct xfrm_user_tmpl vec[XFRM_MAX_DEPTH]; + int i; + + if (xp->xfrm_nr == 0) + return 0; + + for (i = 0; i < xp->xfrm_nr; i++) { + struct xfrm_user_tmpl *up = &vec[i]; + struct xfrm_tmpl *kp = &xp->xfrm_vec[i]; + + memset(up, 0, sizeof(*up)); + memcpy(&up->id, &kp->id, sizeof(up->id)); + up->family = kp->encap_family; + memcpy(&up->saddr, &kp->saddr, sizeof(up->saddr)); + up->reqid = kp->reqid; + up->mode = kp->mode; + up->share = kp->share; + up->optional = kp->optional; + up->aalgos = kp->aalgos; + up->ealgos = kp->ealgos; + up->calgos = kp->calgos; + } + + return nla_put(skb, XFRMA_TMPL, + sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr, vec); +} + +static inline int copy_to_user_state_sec_ctx(struct xfrm_state *x, struct sk_buff *skb) +{ + if (x->security) { + return copy_sec_ctx(x->security, skb); + } + return 0; +} + +static inline int copy_to_user_sec_ctx(struct xfrm_policy *xp, struct sk_buff *skb) +{ + if (xp->security) + return copy_sec_ctx(xp->security, skb); + return 0; +} +static inline unsigned int userpolicy_type_attrsize(void) +{ +#ifdef CONFIG_XFRM_SUB_POLICY + return nla_total_size(sizeof(struct xfrm_userpolicy_type)); +#else + return 0; +#endif +} + +#ifdef CONFIG_XFRM_SUB_POLICY +static int copy_to_user_policy_type(u8 type, struct sk_buff *skb) +{ + struct xfrm_userpolicy_type upt; + + /* Sadly there are two holes in struct xfrm_userpolicy_type */ + memset(&upt, 0, sizeof(upt)); + upt.type = type; + + return nla_put(skb, XFRMA_POLICY_TYPE, sizeof(upt), &upt); +} + +#else +static inline int copy_to_user_policy_type(u8 type, struct sk_buff *skb) +{ + return 0; +} +#endif + +static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr) +{ + struct xfrm_dump_info *sp = ptr; + struct xfrm_userpolicy_info *p; + struct sk_buff *in_skb = sp->in_skb; + struct sk_buff *skb = sp->out_skb; + struct xfrm_translator *xtr; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, NETLINK_CB(in_skb).portid, sp->nlmsg_seq, + XFRM_MSG_NEWPOLICY, sizeof(*p), sp->nlmsg_flags); + if (nlh == NULL) + return -EMSGSIZE; + + p = nlmsg_data(nlh); + copy_to_user_policy(xp, p, dir); + err = copy_to_user_tmpl(xp, skb); + if (!err) + err = copy_to_user_sec_ctx(xp, skb); + if (!err) + err = copy_to_user_policy_type(xp->type, skb); + if (!err) + err = xfrm_mark_put(skb, &xp->mark); + if (!err) + err = xfrm_if_id_put(skb, xp->if_id); + if (err) { + nlmsg_cancel(skb, nlh); + 
return err; + } + nlmsg_end(skb, nlh); + + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlh); + + xfrm_put_translator(xtr); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + } + + return 0; +} + +static int xfrm_dump_policy_done(struct netlink_callback *cb) +{ + struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *)cb->args; + struct net *net = sock_net(cb->skb->sk); + + xfrm_policy_walk_done(walk, net); + return 0; +} + +static int xfrm_dump_policy_start(struct netlink_callback *cb) +{ + struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *)cb->args; + + BUILD_BUG_ON(sizeof(*walk) > sizeof(cb->args)); + + xfrm_policy_walk_init(walk, XFRM_POLICY_TYPE_ANY); + return 0; +} + +static int xfrm_dump_policy(struct sk_buff *skb, struct netlink_callback *cb) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *)cb->args; + struct xfrm_dump_info info; + + info.in_skb = cb->skb; + info.out_skb = skb; + info.nlmsg_seq = cb->nlh->nlmsg_seq; + info.nlmsg_flags = NLM_F_MULTI; + + (void) xfrm_policy_walk(net, walk, dump_one_policy, &info); + + return skb->len; +} + +static struct sk_buff *xfrm_policy_netlink(struct sk_buff *in_skb, + struct xfrm_policy *xp, + int dir, u32 seq) +{ + struct xfrm_dump_info info; + struct sk_buff *skb; + int err; + + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); + if (!skb) + return ERR_PTR(-ENOMEM); + + info.in_skb = in_skb; + info.out_skb = skb; + info.nlmsg_seq = seq; + info.nlmsg_flags = 0; + + err = dump_one_policy(xp, dir, 0, &info); + if (err) { + kfree_skb(skb); + return ERR_PTR(err); + } + + return skb; +} + +static int xfrm_notify_userpolicy(struct net *net) +{ + struct xfrm_userpolicy_default *up; + int len = NLMSG_ALIGN(sizeof(*up)); + struct nlmsghdr *nlh; + struct sk_buff *skb; + int err; + + skb = nlmsg_new(len, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + nlh = nlmsg_put(skb, 0, 0, XFRM_MSG_GETDEFAULT, sizeof(*up), 0); + if (nlh == NULL) { + kfree_skb(skb); + return -EMSGSIZE; + } + + up = nlmsg_data(nlh); + up->in = net->xfrm.policy_default[XFRM_POLICY_IN]; + up->fwd = net->xfrm.policy_default[XFRM_POLICY_FWD]; + up->out = net->xfrm.policy_default[XFRM_POLICY_OUT]; + + nlmsg_end(skb, nlh); + + rcu_read_lock(); + err = xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_POLICY); + rcu_read_unlock(); + + return err; +} + +static bool xfrm_userpolicy_is_valid(__u8 policy) +{ + return policy == XFRM_USERPOLICY_BLOCK || + policy == XFRM_USERPOLICY_ACCEPT; +} + +static int xfrm_set_default(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_userpolicy_default *up = nlmsg_data(nlh); + + if (xfrm_userpolicy_is_valid(up->in)) + net->xfrm.policy_default[XFRM_POLICY_IN] = up->in; + + if (xfrm_userpolicy_is_valid(up->fwd)) + net->xfrm.policy_default[XFRM_POLICY_FWD] = up->fwd; + + if (xfrm_userpolicy_is_valid(up->out)) + net->xfrm.policy_default[XFRM_POLICY_OUT] = up->out; + + rt_genid_bump_all(net); + + xfrm_notify_userpolicy(net); + return 0; +} + +static int xfrm_get_default(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct sk_buff *r_skb; + struct nlmsghdr *r_nlh; + struct net *net = sock_net(skb->sk); + struct xfrm_userpolicy_default *r_up; + int len = NLMSG_ALIGN(sizeof(struct xfrm_userpolicy_default)); + u32 portid = NETLINK_CB(skb).portid; + u32 seq = nlh->nlmsg_seq; + + r_skb = nlmsg_new(len, 
GFP_ATOMIC); + if (!r_skb) + return -ENOMEM; + + r_nlh = nlmsg_put(r_skb, portid, seq, XFRM_MSG_GETDEFAULT, sizeof(*r_up), 0); + if (!r_nlh) { + kfree_skb(r_skb); + return -EMSGSIZE; + } + + r_up = nlmsg_data(r_nlh); + r_up->in = net->xfrm.policy_default[XFRM_POLICY_IN]; + r_up->fwd = net->xfrm.policy_default[XFRM_POLICY_FWD]; + r_up->out = net->xfrm.policy_default[XFRM_POLICY_OUT]; + nlmsg_end(r_skb, r_nlh); + + return nlmsg_unicast(net->xfrm.nlsk, r_skb, portid); +} + +static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_policy *xp; + struct xfrm_userpolicy_id *p; + u8 type = XFRM_POLICY_TYPE_MAIN; + int err; + struct km_event c; + int delete; + struct xfrm_mark m; + u32 if_id = 0; + + p = nlmsg_data(nlh); + delete = nlh->nlmsg_type == XFRM_MSG_DELPOLICY; + + err = copy_from_user_policy_type(&type, attrs, extack); + if (err) + return err; + + err = verify_policy_dir(p->dir, extack); + if (err) + return err; + + if (attrs[XFRMA_IF_ID]) + if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + xfrm_mark_get(attrs, &m); + + if (p->index) + xp = xfrm_policy_byid(net, &m, if_id, type, p->dir, + p->index, delete, &err); + else { + struct nlattr *rt = attrs[XFRMA_SEC_CTX]; + struct xfrm_sec_ctx *ctx; + + err = verify_sec_ctx_len(attrs, extack); + if (err) + return err; + + ctx = NULL; + if (rt) { + struct xfrm_user_sec_ctx *uctx = nla_data(rt); + + err = security_xfrm_policy_alloc(&ctx, uctx, GFP_KERNEL); + if (err) + return err; + } + xp = xfrm_policy_bysel_ctx(net, &m, if_id, type, p->dir, + &p->sel, ctx, delete, &err); + security_xfrm_policy_free(ctx); + } + if (xp == NULL) + return -ENOENT; + + if (!delete) { + struct sk_buff *resp_skb; + + resp_skb = xfrm_policy_netlink(skb, xp, p->dir, nlh->nlmsg_seq); + if (IS_ERR(resp_skb)) { + err = PTR_ERR(resp_skb); + } else { + err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, + NETLINK_CB(skb).portid); + } + } else { + xfrm_audit_policy_delete(xp, err ? 0 : 1, true); + + if (err != 0) + goto out; + + c.data.byid = p->index; + c.event = nlh->nlmsg_type; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + km_policy_notify(xp, p->dir, &c); + } + +out: + xfrm_pol_put(xp); + return err; +} + +static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct km_event c; + struct xfrm_usersa_flush *p = nlmsg_data(nlh); + int err; + + err = xfrm_state_flush(net, p->proto, true, false); + if (err) { + if (err == -ESRCH) /* empty table */ + return 0; + return err; + } + c.data.proto = p->proto; + c.event = nlh->nlmsg_type; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + c.net = net; + km_state_notify(NULL, &c); + + return 0; +} + +static inline unsigned int xfrm_aevent_msgsize(struct xfrm_state *x) +{ + unsigned int replay_size = x->replay_esn ? 
+ xfrm_replay_state_esn_len(x->replay_esn) : + sizeof(struct xfrm_replay_state); + + return NLMSG_ALIGN(sizeof(struct xfrm_aevent_id)) + + nla_total_size(replay_size) + + nla_total_size_64bit(sizeof(struct xfrm_lifetime_cur)) + + nla_total_size(sizeof(struct xfrm_mark)) + + nla_total_size(4) /* XFRM_AE_RTHR */ + + nla_total_size(4); /* XFRM_AE_ETHR */ +} + +static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c) +{ + struct xfrm_aevent_id *id; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, c->portid, c->seq, XFRM_MSG_NEWAE, sizeof(*id), 0); + if (nlh == NULL) + return -EMSGSIZE; + + id = nlmsg_data(nlh); + memset(&id->sa_id, 0, sizeof(id->sa_id)); + memcpy(&id->sa_id.daddr, &x->id.daddr, sizeof(x->id.daddr)); + id->sa_id.spi = x->id.spi; + id->sa_id.family = x->props.family; + id->sa_id.proto = x->id.proto; + memcpy(&id->saddr, &x->props.saddr, sizeof(x->props.saddr)); + id->reqid = x->props.reqid; + id->flags = c->data.aevent; + + if (x->replay_esn) { + err = nla_put(skb, XFRMA_REPLAY_ESN_VAL, + xfrm_replay_state_esn_len(x->replay_esn), + x->replay_esn); + } else { + err = nla_put(skb, XFRMA_REPLAY_VAL, sizeof(x->replay), + &x->replay); + } + if (err) + goto out_cancel; + err = nla_put_64bit(skb, XFRMA_LTIME_VAL, sizeof(x->curlft), &x->curlft, + XFRMA_PAD); + if (err) + goto out_cancel; + + if (id->flags & XFRM_AE_RTHR) { + err = nla_put_u32(skb, XFRMA_REPLAY_THRESH, x->replay_maxdiff); + if (err) + goto out_cancel; + } + if (id->flags & XFRM_AE_ETHR) { + err = nla_put_u32(skb, XFRMA_ETIMER_THRESH, + x->replay_maxage * 10 / HZ); + if (err) + goto out_cancel; + } + err = xfrm_mark_put(skb, &x->mark); + if (err) + goto out_cancel; + + err = xfrm_if_id_put(skb, x->if_id); + if (err) + goto out_cancel; + + nlmsg_end(skb, nlh); + return 0; + +out_cancel: + nlmsg_cancel(skb, nlh); + return err; +} + +static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state *x; + struct sk_buff *r_skb; + int err; + struct km_event c; + u32 mark; + struct xfrm_mark m; + struct xfrm_aevent_id *p = nlmsg_data(nlh); + struct xfrm_usersa_id *id = &p->sa_id; + + mark = xfrm_mark_get(attrs, &m); + + x = xfrm_state_lookup(net, mark, &id->daddr, id->spi, id->proto, id->family); + if (x == NULL) + return -ESRCH; + + r_skb = nlmsg_new(xfrm_aevent_msgsize(x), GFP_ATOMIC); + if (r_skb == NULL) { + xfrm_state_put(x); + return -ENOMEM; + } + + /* + * XXX: is this lock really needed - none of the other + * gets lock (the concern is things getting updated + * while we are still reading) - jhs + */ + spin_lock_bh(&x->lock); + c.data.aevent = p->flags; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + + err = build_aevent(r_skb, x, &c); + BUG_ON(err < 0); + + err = nlmsg_unicast(net->xfrm.nlsk, r_skb, NETLINK_CB(skb).portid); + spin_unlock_bh(&x->lock); + xfrm_state_put(x); + return err; +} + +static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state *x; + struct km_event c; + int err = -EINVAL; + u32 mark = 0; + struct xfrm_mark m; + struct xfrm_aevent_id *p = nlmsg_data(nlh); + struct nlattr *rp = attrs[XFRMA_REPLAY_VAL]; + struct nlattr *re = attrs[XFRMA_REPLAY_ESN_VAL]; + struct nlattr *lt = attrs[XFRMA_LTIME_VAL]; + struct nlattr *et = attrs[XFRMA_ETIMER_THRESH]; + struct nlattr *rt = attrs[XFRMA_REPLAY_THRESH]; + + if 
(!lt && !rp && !re && !et && !rt) + return err; + + /* pedantic mode - thou shalt sayeth replaceth */ + if (!(nlh->nlmsg_flags&NLM_F_REPLACE)) + return err; + + mark = xfrm_mark_get(attrs, &m); + + x = xfrm_state_lookup(net, mark, &p->sa_id.daddr, p->sa_id.spi, p->sa_id.proto, p->sa_id.family); + if (x == NULL) + return -ESRCH; + + if (x->km.state != XFRM_STATE_VALID) + goto out; + + err = xfrm_replay_verify_len(x->replay_esn, re); + if (err) + goto out; + + spin_lock_bh(&x->lock); + xfrm_update_ae_params(x, attrs, 1); + spin_unlock_bh(&x->lock); + + c.event = nlh->nlmsg_type; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + c.data.aevent = XFRM_AE_CU; + km_state_notify(x, &c); + err = 0; +out: + xfrm_state_put(x); + return err; +} + +static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct km_event c; + u8 type = XFRM_POLICY_TYPE_MAIN; + int err; + + err = copy_from_user_policy_type(&type, attrs, extack); + if (err) + return err; + + err = xfrm_policy_flush(net, type, true); + if (err) { + if (err == -ESRCH) /* empty table */ + return 0; + return err; + } + + c.data.type = type; + c.event = nlh->nlmsg_type; + c.seq = nlh->nlmsg_seq; + c.portid = nlh->nlmsg_pid; + c.net = net; + km_policy_notify(NULL, 0, &c); + return 0; +} + +static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_policy *xp; + struct xfrm_user_polexpire *up = nlmsg_data(nlh); + struct xfrm_userpolicy_info *p = &up->pol; + u8 type = XFRM_POLICY_TYPE_MAIN; + int err = -ENOENT; + struct xfrm_mark m; + u32 if_id = 0; + + err = copy_from_user_policy_type(&type, attrs, extack); + if (err) + return err; + + err = verify_policy_dir(p->dir, extack); + if (err) + return err; + + if (attrs[XFRMA_IF_ID]) + if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + xfrm_mark_get(attrs, &m); + + if (p->index) + xp = xfrm_policy_byid(net, &m, if_id, type, p->dir, p->index, + 0, &err); + else { + struct nlattr *rt = attrs[XFRMA_SEC_CTX]; + struct xfrm_sec_ctx *ctx; + + err = verify_sec_ctx_len(attrs, extack); + if (err) + return err; + + ctx = NULL; + if (rt) { + struct xfrm_user_sec_ctx *uctx = nla_data(rt); + + err = security_xfrm_policy_alloc(&ctx, uctx, GFP_KERNEL); + if (err) + return err; + } + xp = xfrm_policy_bysel_ctx(net, &m, if_id, type, p->dir, + &p->sel, ctx, 0, &err); + security_xfrm_policy_free(ctx); + } + if (xp == NULL) + return -ENOENT; + + if (unlikely(xp->walk.dead)) + goto out; + + err = 0; + if (up->hard) { + xfrm_policy_delete(xp, p->dir); + xfrm_audit_policy_delete(xp, 1, true); + } + km_policy_expired(xp, p->dir, up->hard, nlh->nlmsg_pid); + +out: + xfrm_pol_put(xp); + return err; +} + +static int xfrm_add_sa_expire(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_state *x; + int err; + struct xfrm_user_expire *ue = nlmsg_data(nlh); + struct xfrm_usersa_info *p = &ue->state; + struct xfrm_mark m; + u32 mark = xfrm_mark_get(attrs, &m); + + x = xfrm_state_lookup(net, mark, &p->id.daddr, p->id.spi, p->id.proto, p->family); + + err = -ENOENT; + if (x == NULL) + return err; + + spin_lock_bh(&x->lock); + err = -EINVAL; + if (x->km.state != XFRM_STATE_VALID) + goto out; + km_state_expired(x, ue->hard, nlh->nlmsg_pid); + + if (ue->hard) { + __xfrm_state_delete(x); + 
xfrm_audit_state_delete(x, 1, true); + } + err = 0; +out: + spin_unlock_bh(&x->lock); + xfrm_state_put(x); + return err; +} + +static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct xfrm_policy *xp; + struct xfrm_user_tmpl *ut; + int i; + struct nlattr *rt = attrs[XFRMA_TMPL]; + struct xfrm_mark mark; + + struct xfrm_user_acquire *ua = nlmsg_data(nlh); + struct xfrm_state *x = xfrm_state_alloc(net); + int err = -ENOMEM; + + if (!x) + goto nomem; + + xfrm_mark_get(attrs, &mark); + + err = verify_newpolicy_info(&ua->policy, extack); + if (err) + goto free_state; + err = verify_sec_ctx_len(attrs, extack); + if (err) + goto free_state; + + /* build an XP */ + xp = xfrm_policy_construct(net, &ua->policy, attrs, &err, extack); + if (!xp) + goto free_state; + + memcpy(&x->id, &ua->id, sizeof(ua->id)); + memcpy(&x->props.saddr, &ua->saddr, sizeof(ua->saddr)); + memcpy(&x->sel, &ua->sel, sizeof(ua->sel)); + xp->mark.m = x->mark.m = mark.m; + xp->mark.v = x->mark.v = mark.v; + ut = nla_data(rt); + /* extract the templates and for each call km_key */ + for (i = 0; i < xp->xfrm_nr; i++, ut++) { + struct xfrm_tmpl *t = &xp->xfrm_vec[i]; + memcpy(&x->id, &t->id, sizeof(x->id)); + x->props.mode = t->mode; + x->props.reqid = t->reqid; + x->props.family = ut->family; + t->aalgos = ua->aalgos; + t->ealgos = ua->ealgos; + t->calgos = ua->calgos; + err = km_query(x, t, xp); + + } + + xfrm_state_free(x); + kfree(xp); + + return 0; + +free_state: + xfrm_state_free(x); +nomem: + return err; +} + +#ifdef CONFIG_XFRM_MIGRATE +static int copy_from_user_migrate(struct xfrm_migrate *ma, + struct xfrm_kmaddress *k, + struct nlattr **attrs, int *num) +{ + struct nlattr *rt = attrs[XFRMA_MIGRATE]; + struct xfrm_user_migrate *um; + int i, num_migrate; + + if (k != NULL) { + struct xfrm_user_kmaddress *uk; + + uk = nla_data(attrs[XFRMA_KMADDRESS]); + memcpy(&k->local, &uk->local, sizeof(k->local)); + memcpy(&k->remote, &uk->remote, sizeof(k->remote)); + k->family = uk->family; + k->reserved = uk->reserved; + } + + um = nla_data(rt); + num_migrate = nla_len(rt) / sizeof(*um); + + if (num_migrate <= 0 || num_migrate > XFRM_MAX_DEPTH) + return -EINVAL; + + for (i = 0; i < num_migrate; i++, um++, ma++) { + memcpy(&ma->old_daddr, &um->old_daddr, sizeof(ma->old_daddr)); + memcpy(&ma->old_saddr, &um->old_saddr, sizeof(ma->old_saddr)); + memcpy(&ma->new_daddr, &um->new_daddr, sizeof(ma->new_daddr)); + memcpy(&ma->new_saddr, &um->new_saddr, sizeof(ma->new_saddr)); + + ma->proto = um->proto; + ma->mode = um->mode; + ma->reqid = um->reqid; + + ma->old_family = um->old_family; + ma->new_family = um->new_family; + } + + *num = i; + return 0; +} + +static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + struct xfrm_userpolicy_id *pi = nlmsg_data(nlh); + struct xfrm_migrate m[XFRM_MAX_DEPTH]; + struct xfrm_kmaddress km, *kmp; + u8 type; + int err; + int n = 0; + struct net *net = sock_net(skb->sk); + struct xfrm_encap_tmpl *encap = NULL; + u32 if_id = 0; + + if (attrs[XFRMA_MIGRATE] == NULL) + return -EINVAL; + + kmp = attrs[XFRMA_KMADDRESS] ? 
&km : NULL; + + err = copy_from_user_policy_type(&type, attrs, extack); + if (err) + return err; + + err = copy_from_user_migrate((struct xfrm_migrate *)m, kmp, attrs, &n); + if (err) + return err; + + if (!n) + return 0; + + if (attrs[XFRMA_ENCAP]) { + encap = kmemdup(nla_data(attrs[XFRMA_ENCAP]), + sizeof(*encap), GFP_KERNEL); + if (!encap) + return -ENOMEM; + } + + if (attrs[XFRMA_IF_ID]) + if_id = nla_get_u32(attrs[XFRMA_IF_ID]); + + err = xfrm_migrate(&pi->sel, pi->dir, type, m, n, kmp, net, encap, if_id); + + kfree(encap); + + return err; +} +#else +static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, + struct nlattr **attrs, struct netlink_ext_ack *extack) +{ + return -ENOPROTOOPT; +} +#endif + +#ifdef CONFIG_XFRM_MIGRATE +static int copy_to_user_migrate(const struct xfrm_migrate *m, struct sk_buff *skb) +{ + struct xfrm_user_migrate um; + + memset(&um, 0, sizeof(um)); + um.proto = m->proto; + um.mode = m->mode; + um.reqid = m->reqid; + um.old_family = m->old_family; + memcpy(&um.old_daddr, &m->old_daddr, sizeof(um.old_daddr)); + memcpy(&um.old_saddr, &m->old_saddr, sizeof(um.old_saddr)); + um.new_family = m->new_family; + memcpy(&um.new_daddr, &m->new_daddr, sizeof(um.new_daddr)); + memcpy(&um.new_saddr, &m->new_saddr, sizeof(um.new_saddr)); + + return nla_put(skb, XFRMA_MIGRATE, sizeof(um), &um); +} + +static int copy_to_user_kmaddress(const struct xfrm_kmaddress *k, struct sk_buff *skb) +{ + struct xfrm_user_kmaddress uk; + + memset(&uk, 0, sizeof(uk)); + uk.family = k->family; + uk.reserved = k->reserved; + memcpy(&uk.local, &k->local, sizeof(uk.local)); + memcpy(&uk.remote, &k->remote, sizeof(uk.remote)); + + return nla_put(skb, XFRMA_KMADDRESS, sizeof(uk), &uk); +} + +static inline unsigned int xfrm_migrate_msgsize(int num_migrate, int with_kma, + int with_encp) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_userpolicy_id)) + + (with_kma ? nla_total_size(sizeof(struct xfrm_kmaddress)) : 0) + + (with_encp ? 
nla_total_size(sizeof(struct xfrm_encap_tmpl)) : 0) + + nla_total_size(sizeof(struct xfrm_user_migrate) * num_migrate) + + userpolicy_type_attrsize(); +} + +static int build_migrate(struct sk_buff *skb, const struct xfrm_migrate *m, + int num_migrate, const struct xfrm_kmaddress *k, + const struct xfrm_selector *sel, + const struct xfrm_encap_tmpl *encap, u8 dir, u8 type) +{ + const struct xfrm_migrate *mp; + struct xfrm_userpolicy_id *pol_id; + struct nlmsghdr *nlh; + int i, err; + + nlh = nlmsg_put(skb, 0, 0, XFRM_MSG_MIGRATE, sizeof(*pol_id), 0); + if (nlh == NULL) + return -EMSGSIZE; + + pol_id = nlmsg_data(nlh); + /* copy data from selector, dir, and type to the pol_id */ + memset(pol_id, 0, sizeof(*pol_id)); + memcpy(&pol_id->sel, sel, sizeof(pol_id->sel)); + pol_id->dir = dir; + + if (k != NULL) { + err = copy_to_user_kmaddress(k, skb); + if (err) + goto out_cancel; + } + if (encap) { + err = nla_put(skb, XFRMA_ENCAP, sizeof(*encap), encap); + if (err) + goto out_cancel; + } + err = copy_to_user_policy_type(type, skb); + if (err) + goto out_cancel; + for (i = 0, mp = m ; i < num_migrate; i++, mp++) { + err = copy_to_user_migrate(mp, skb); + if (err) + goto out_cancel; + } + + nlmsg_end(skb, nlh); + return 0; + +out_cancel: + nlmsg_cancel(skb, nlh); + return err; +} + +static int xfrm_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + const struct xfrm_migrate *m, int num_migrate, + const struct xfrm_kmaddress *k, + const struct xfrm_encap_tmpl *encap) +{ + struct net *net = &init_net; + struct sk_buff *skb; + int err; + + skb = nlmsg_new(xfrm_migrate_msgsize(num_migrate, !!k, !!encap), + GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + /* build migrate */ + err = build_migrate(skb, m, num_migrate, k, sel, encap, dir, type); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_MIGRATE); +} +#else +static int xfrm_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, + const struct xfrm_migrate *m, int num_migrate, + const struct xfrm_kmaddress *k, + const struct xfrm_encap_tmpl *encap) +{ + return -ENOPROTOOPT; +} +#endif + +#define XMSGSIZE(type) sizeof(struct type) + +const int xfrm_msg_min[XFRM_NR_MSGTYPES] = { + [XFRM_MSG_NEWSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_info), + [XFRM_MSG_DELSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_GETSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_NEWPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_info), + [XFRM_MSG_DELPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_GETPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_ALLOCSPI - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userspi_info), + [XFRM_MSG_ACQUIRE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_acquire), + [XFRM_MSG_EXPIRE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_expire), + [XFRM_MSG_UPDPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_info), + [XFRM_MSG_UPDSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_info), + [XFRM_MSG_POLEXPIRE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_polexpire), + [XFRM_MSG_FLUSHSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_flush), + [XFRM_MSG_FLUSHPOLICY - XFRM_MSG_BASE] = 0, + [XFRM_MSG_NEWAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_GETAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_REPORT - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_report), + [XFRM_MSG_MIGRATE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_GETSADINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_NEWSPDINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_GETSPDINFO - XFRM_MSG_BASE] = 
sizeof(u32), + [XFRM_MSG_SETDEFAULT - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_default), + [XFRM_MSG_GETDEFAULT - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_default), +}; +EXPORT_SYMBOL_GPL(xfrm_msg_min); + +#undef XMSGSIZE + +const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { + [XFRMA_SA] = { .len = sizeof(struct xfrm_usersa_info)}, + [XFRMA_POLICY] = { .len = sizeof(struct xfrm_userpolicy_info)}, + [XFRMA_LASTUSED] = { .type = NLA_U64}, + [XFRMA_ALG_AUTH_TRUNC] = { .len = sizeof(struct xfrm_algo_auth)}, + [XFRMA_ALG_AEAD] = { .len = sizeof(struct xfrm_algo_aead) }, + [XFRMA_ALG_AUTH] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_CRYPT] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_COMP] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ENCAP] = { .len = sizeof(struct xfrm_encap_tmpl) }, + [XFRMA_TMPL] = { .len = sizeof(struct xfrm_user_tmpl) }, + [XFRMA_SEC_CTX] = { .len = sizeof(struct xfrm_user_sec_ctx) }, + [XFRMA_LTIME_VAL] = { .len = sizeof(struct xfrm_lifetime_cur) }, + [XFRMA_REPLAY_VAL] = { .len = sizeof(struct xfrm_replay_state) }, + [XFRMA_REPLAY_THRESH] = { .type = NLA_U32 }, + [XFRMA_ETIMER_THRESH] = { .type = NLA_U32 }, + [XFRMA_SRCADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_COADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_POLICY_TYPE] = { .len = sizeof(struct xfrm_userpolicy_type)}, + [XFRMA_MIGRATE] = { .len = sizeof(struct xfrm_user_migrate) }, + [XFRMA_KMADDRESS] = { .len = sizeof(struct xfrm_user_kmaddress) }, + [XFRMA_MARK] = { .len = sizeof(struct xfrm_mark) }, + [XFRMA_TFCPAD] = { .type = NLA_U32 }, + [XFRMA_REPLAY_ESN_VAL] = { .len = sizeof(struct xfrm_replay_state_esn) }, + [XFRMA_SA_EXTRA_FLAGS] = { .type = NLA_U32 }, + [XFRMA_PROTO] = { .type = NLA_U8 }, + [XFRMA_ADDRESS_FILTER] = { .len = sizeof(struct xfrm_address_filter) }, + [XFRMA_OFFLOAD_DEV] = { .len = sizeof(struct xfrm_user_offload) }, + [XFRMA_SET_MARK] = { .type = NLA_U32 }, + [XFRMA_SET_MARK_MASK] = { .type = NLA_U32 }, + [XFRMA_IF_ID] = { .type = NLA_U32 }, + [XFRMA_MTIMER_THRESH] = { .type = NLA_U32 }, +}; +EXPORT_SYMBOL_GPL(xfrma_policy); + +static const struct nla_policy xfrma_spd_policy[XFRMA_SPD_MAX+1] = { + [XFRMA_SPD_IPV4_HTHRESH] = { .len = sizeof(struct xfrmu_spdhthresh) }, + [XFRMA_SPD_IPV6_HTHRESH] = { .len = sizeof(struct xfrmu_spdhthresh) }, +}; + +static const struct xfrm_link { + int (*doit)(struct sk_buff *, struct nlmsghdr *, struct nlattr **, + struct netlink_ext_ack *); + int (*start)(struct netlink_callback *); + int (*dump)(struct sk_buff *, struct netlink_callback *); + int (*done)(struct netlink_callback *); + const struct nla_policy *nla_pol; + int nla_max; +} xfrm_dispatch[XFRM_NR_MSGTYPES] = { + [XFRM_MSG_NEWSA - XFRM_MSG_BASE] = { .doit = xfrm_add_sa }, + [XFRM_MSG_DELSA - XFRM_MSG_BASE] = { .doit = xfrm_del_sa }, + [XFRM_MSG_GETSA - XFRM_MSG_BASE] = { .doit = xfrm_get_sa, + .dump = xfrm_dump_sa, + .done = xfrm_dump_sa_done }, + [XFRM_MSG_NEWPOLICY - XFRM_MSG_BASE] = { .doit = xfrm_add_policy }, + [XFRM_MSG_DELPOLICY - XFRM_MSG_BASE] = { .doit = xfrm_get_policy }, + [XFRM_MSG_GETPOLICY - XFRM_MSG_BASE] = { .doit = xfrm_get_policy, + .start = xfrm_dump_policy_start, + .dump = xfrm_dump_policy, + .done = xfrm_dump_policy_done }, + [XFRM_MSG_ALLOCSPI - XFRM_MSG_BASE] = { .doit = xfrm_alloc_userspi }, + [XFRM_MSG_ACQUIRE - XFRM_MSG_BASE] = { .doit = xfrm_add_acquire }, + [XFRM_MSG_EXPIRE - XFRM_MSG_BASE] = { .doit = xfrm_add_sa_expire }, + [XFRM_MSG_UPDPOLICY - XFRM_MSG_BASE] = { .doit = xfrm_add_policy }, + [XFRM_MSG_UPDSA - XFRM_MSG_BASE] = { .doit = 
xfrm_add_sa }, + [XFRM_MSG_POLEXPIRE - XFRM_MSG_BASE] = { .doit = xfrm_add_pol_expire}, + [XFRM_MSG_FLUSHSA - XFRM_MSG_BASE] = { .doit = xfrm_flush_sa }, + [XFRM_MSG_FLUSHPOLICY - XFRM_MSG_BASE] = { .doit = xfrm_flush_policy }, + [XFRM_MSG_NEWAE - XFRM_MSG_BASE] = { .doit = xfrm_new_ae }, + [XFRM_MSG_GETAE - XFRM_MSG_BASE] = { .doit = xfrm_get_ae }, + [XFRM_MSG_MIGRATE - XFRM_MSG_BASE] = { .doit = xfrm_do_migrate }, + [XFRM_MSG_GETSADINFO - XFRM_MSG_BASE] = { .doit = xfrm_get_sadinfo }, + [XFRM_MSG_NEWSPDINFO - XFRM_MSG_BASE] = { .doit = xfrm_set_spdinfo, + .nla_pol = xfrma_spd_policy, + .nla_max = XFRMA_SPD_MAX }, + [XFRM_MSG_GETSPDINFO - XFRM_MSG_BASE] = { .doit = xfrm_get_spdinfo }, + [XFRM_MSG_SETDEFAULT - XFRM_MSG_BASE] = { .doit = xfrm_set_default }, + [XFRM_MSG_GETDEFAULT - XFRM_MSG_BASE] = { .doit = xfrm_get_default }, +}; + +static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh, + struct netlink_ext_ack *extack) +{ + struct net *net = sock_net(skb->sk); + struct nlattr *attrs[XFRMA_MAX+1]; + const struct xfrm_link *link; + struct nlmsghdr *nlh64 = NULL; + int type, err; + + type = nlh->nlmsg_type; + if (type > XFRM_MSG_MAX) + return -EINVAL; + + type -= XFRM_MSG_BASE; + link = &xfrm_dispatch[type]; + + /* All operations require privileges, even GET */ + if (!netlink_net_capable(skb, CAP_NET_ADMIN)) + return -EPERM; + + if (in_compat_syscall()) { + struct xfrm_translator *xtr = xfrm_get_translator(); + + if (!xtr) + return -EOPNOTSUPP; + + nlh64 = xtr->rcv_msg_compat(nlh, link->nla_max, + link->nla_pol, extack); + xfrm_put_translator(xtr); + if (IS_ERR(nlh64)) + return PTR_ERR(nlh64); + if (nlh64) + nlh = nlh64; + } + + if ((type == (XFRM_MSG_GETSA - XFRM_MSG_BASE) || + type == (XFRM_MSG_GETPOLICY - XFRM_MSG_BASE)) && + (nlh->nlmsg_flags & NLM_F_DUMP)) { + struct netlink_dump_control c = { + .start = link->start, + .dump = link->dump, + .done = link->done, + }; + + if (link->dump == NULL) { + err = -EINVAL; + goto err; + } + + err = netlink_dump_start(net->xfrm.nlsk, skb, nlh, &c); + goto err; + } + + err = nlmsg_parse_deprecated(nlh, xfrm_msg_min[type], attrs, + link->nla_max ? : XFRMA_MAX, + link->nla_pol ? : xfrma_policy, extack); + if (err < 0) + goto err; + + if (link->doit == NULL) { + err = -EINVAL; + goto err; + } + + err = link->doit(skb, nlh, attrs, extack); + + /* We need to free skb allocated in xfrm_alloc_compat() before + * returning from this function, because consume_skb() won't take + * care of frag_list since netlink destructor sets + * sbk->head to NULL. (see netlink_skb_destructor()) + */ + if (skb_has_frag_list(skb)) { + kfree_skb(skb_shinfo(skb)->frag_list); + skb_shinfo(skb)->frag_list = NULL; + } + +err: + kvfree(nlh64); + return err; +} + +static void xfrm_netlink_rcv(struct sk_buff *skb) +{ + struct net *net = sock_net(skb->sk); + + mutex_lock(&net->xfrm.xfrm_cfg_mutex); + netlink_rcv_skb(skb, &xfrm_user_rcv_msg); + mutex_unlock(&net->xfrm.xfrm_cfg_mutex); +} + +static inline unsigned int xfrm_expire_msgsize(void) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_user_expire)) + + nla_total_size(sizeof(struct xfrm_mark)); +} + +static int build_expire(struct sk_buff *skb, struct xfrm_state *x, const struct km_event *c) +{ + struct xfrm_user_expire *ue; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, c->portid, 0, XFRM_MSG_EXPIRE, sizeof(*ue), 0); + if (nlh == NULL) + return -EMSGSIZE; + + ue = nlmsg_data(nlh); + copy_to_user_state(x, &ue->state); + ue->hard = (c->data.hard != 0) ? 
1 : 0; + /* clear the padding bytes */ + memset_after(ue, 0, hard); + + err = xfrm_mark_put(skb, &x->mark); + if (err) + return err; + + err = xfrm_if_id_put(skb, x->if_id); + if (err) + return err; + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_exp_state_notify(struct xfrm_state *x, const struct km_event *c) +{ + struct net *net = xs_net(x); + struct sk_buff *skb; + + skb = nlmsg_new(xfrm_expire_msgsize(), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + if (build_expire(skb, x, c) < 0) { + kfree_skb(skb); + return -EMSGSIZE; + } + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_EXPIRE); +} + +static int xfrm_aevent_state_notify(struct xfrm_state *x, const struct km_event *c) +{ + struct net *net = xs_net(x); + struct sk_buff *skb; + int err; + + skb = nlmsg_new(xfrm_aevent_msgsize(x), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + err = build_aevent(skb, x, c); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_AEVENTS); +} + +static int xfrm_notify_sa_flush(const struct km_event *c) +{ + struct net *net = c->net; + struct xfrm_usersa_flush *p; + struct nlmsghdr *nlh; + struct sk_buff *skb; + int len = NLMSG_ALIGN(sizeof(struct xfrm_usersa_flush)); + + skb = nlmsg_new(len, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + nlh = nlmsg_put(skb, c->portid, c->seq, XFRM_MSG_FLUSHSA, sizeof(*p), 0); + if (nlh == NULL) { + kfree_skb(skb); + return -EMSGSIZE; + } + + p = nlmsg_data(nlh); + p->proto = c->data.proto; + + nlmsg_end(skb, nlh); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_SA); +} + +static inline unsigned int xfrm_sa_len(struct xfrm_state *x) +{ + unsigned int l = 0; + if (x->aead) + l += nla_total_size(aead_len(x->aead)); + if (x->aalg) { + l += nla_total_size(sizeof(struct xfrm_algo) + + (x->aalg->alg_key_len + 7) / 8); + l += nla_total_size(xfrm_alg_auth_len(x->aalg)); + } + if (x->ealg) + l += nla_total_size(xfrm_alg_len(x->ealg)); + if (x->calg) + l += nla_total_size(sizeof(*x->calg)); + if (x->encap) + l += nla_total_size(sizeof(*x->encap)); + if (x->tfcpad) + l += nla_total_size(sizeof(x->tfcpad)); + if (x->replay_esn) + l += nla_total_size(xfrm_replay_state_esn_len(x->replay_esn)); + else + l += nla_total_size(sizeof(struct xfrm_replay_state)); + if (x->security) + l += nla_total_size(sizeof(struct xfrm_user_sec_ctx) + + x->security->ctx_len); + if (x->coaddr) + l += nla_total_size(sizeof(*x->coaddr)); + if (x->props.extra_flags) + l += nla_total_size(sizeof(x->props.extra_flags)); + if (x->xso.dev) + l += nla_total_size(sizeof(struct xfrm_user_offload)); + if (x->props.smark.v | x->props.smark.m) { + l += nla_total_size(sizeof(x->props.smark.v)); + l += nla_total_size(sizeof(x->props.smark.m)); + } + if (x->if_id) + l += nla_total_size(sizeof(x->if_id)); + + /* Must count x->lastused as it may become non-zero behind our back. 
*/ + l += nla_total_size_64bit(sizeof(u64)); + + if (x->mapping_maxage) + l += nla_total_size(sizeof(x->mapping_maxage)); + + return l; +} + +static int xfrm_notify_sa(struct xfrm_state *x, const struct km_event *c) +{ + struct net *net = xs_net(x); + struct xfrm_usersa_info *p; + struct xfrm_usersa_id *id; + struct nlmsghdr *nlh; + struct sk_buff *skb; + unsigned int len = xfrm_sa_len(x); + unsigned int headlen; + int err; + + headlen = sizeof(*p); + if (c->event == XFRM_MSG_DELSA) { + len += nla_total_size(headlen); + headlen = sizeof(*id); + len += nla_total_size(sizeof(struct xfrm_mark)); + } + len += NLMSG_ALIGN(headlen); + + skb = nlmsg_new(len, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + nlh = nlmsg_put(skb, c->portid, c->seq, c->event, headlen, 0); + err = -EMSGSIZE; + if (nlh == NULL) + goto out_free_skb; + + p = nlmsg_data(nlh); + if (c->event == XFRM_MSG_DELSA) { + struct nlattr *attr; + + id = nlmsg_data(nlh); + memset(id, 0, sizeof(*id)); + memcpy(&id->daddr, &x->id.daddr, sizeof(id->daddr)); + id->spi = x->id.spi; + id->family = x->props.family; + id->proto = x->id.proto; + + attr = nla_reserve(skb, XFRMA_SA, sizeof(*p)); + err = -EMSGSIZE; + if (attr == NULL) + goto out_free_skb; + + p = nla_data(attr); + } + err = copy_to_user_state_extra(x, p, skb); + if (err) + goto out_free_skb; + + nlmsg_end(skb, nlh); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_SA); + +out_free_skb: + kfree_skb(skb); + return err; +} + +static int xfrm_send_state_notify(struct xfrm_state *x, const struct km_event *c) +{ + + switch (c->event) { + case XFRM_MSG_EXPIRE: + return xfrm_exp_state_notify(x, c); + case XFRM_MSG_NEWAE: + return xfrm_aevent_state_notify(x, c); + case XFRM_MSG_DELSA: + case XFRM_MSG_UPDSA: + case XFRM_MSG_NEWSA: + return xfrm_notify_sa(x, c); + case XFRM_MSG_FLUSHSA: + return xfrm_notify_sa_flush(c); + default: + printk(KERN_NOTICE "xfrm_user: Unknown SA event %d\n", + c->event); + break; + } + + return 0; + +} + +static inline unsigned int xfrm_acquire_msgsize(struct xfrm_state *x, + struct xfrm_policy *xp) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_user_acquire)) + + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr) + + nla_total_size(sizeof(struct xfrm_mark)) + + nla_total_size(xfrm_user_sec_ctx_size(x->security)) + + userpolicy_type_attrsize(); +} + +static int build_acquire(struct sk_buff *skb, struct xfrm_state *x, + struct xfrm_tmpl *xt, struct xfrm_policy *xp) +{ + __u32 seq = xfrm_get_acqseq(); + struct xfrm_user_acquire *ua; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, 0, 0, XFRM_MSG_ACQUIRE, sizeof(*ua), 0); + if (nlh == NULL) + return -EMSGSIZE; + + ua = nlmsg_data(nlh); + memcpy(&ua->id, &x->id, sizeof(ua->id)); + memcpy(&ua->saddr, &x->props.saddr, sizeof(ua->saddr)); + memcpy(&ua->sel, &x->sel, sizeof(ua->sel)); + copy_to_user_policy(xp, &ua->policy, XFRM_POLICY_OUT); + ua->aalgos = xt->aalgos; + ua->ealgos = xt->ealgos; + ua->calgos = xt->calgos; + ua->seq = x->km.seq = seq; + + err = copy_to_user_tmpl(xp, skb); + if (!err) + err = copy_to_user_state_sec_ctx(x, skb); + if (!err) + err = copy_to_user_policy_type(xp->type, skb); + if (!err) + err = xfrm_mark_put(skb, &xp->mark); + if (!err) + err = xfrm_if_id_put(skb, xp->if_id); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *xt, + struct xfrm_policy *xp) +{ + struct net *net = xs_net(x); + struct sk_buff *skb; + int err; + + skb = 
nlmsg_new(xfrm_acquire_msgsize(x, xp), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + err = build_acquire(skb, x, xt, xp); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_ACQUIRE); +} + +/* User gives us xfrm_user_policy_info followed by an array of 0 + * or more templates. + */ +static struct xfrm_policy *xfrm_compile_policy(struct sock *sk, int opt, + u8 *data, int len, int *dir) +{ + struct net *net = sock_net(sk); + struct xfrm_userpolicy_info *p = (struct xfrm_userpolicy_info *)data; + struct xfrm_user_tmpl *ut = (struct xfrm_user_tmpl *) (p + 1); + struct xfrm_policy *xp; + int nr; + + switch (sk->sk_family) { + case AF_INET: + if (opt != IP_XFRM_POLICY) { + *dir = -EOPNOTSUPP; + return NULL; + } + break; +#if IS_ENABLED(CONFIG_IPV6) + case AF_INET6: + if (opt != IPV6_XFRM_POLICY) { + *dir = -EOPNOTSUPP; + return NULL; + } + break; +#endif + default: + *dir = -EINVAL; + return NULL; + } + + *dir = -EINVAL; + + if (len < sizeof(*p) || + verify_newpolicy_info(p, NULL)) + return NULL; + + nr = ((len - sizeof(*p)) / sizeof(*ut)); + if (validate_tmpl(nr, ut, p->sel.family, p->dir, NULL)) + return NULL; + + if (p->dir > XFRM_POLICY_OUT) + return NULL; + + xp = xfrm_policy_alloc(net, GFP_ATOMIC); + if (xp == NULL) { + *dir = -ENOBUFS; + return NULL; + } + + copy_from_user_policy(xp, p); + xp->type = XFRM_POLICY_TYPE_MAIN; + copy_templates(xp, ut, nr); + + *dir = p->dir; + + return xp; +} + +static inline unsigned int xfrm_polexpire_msgsize(struct xfrm_policy *xp) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_user_polexpire)) + + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr) + + nla_total_size(xfrm_user_sec_ctx_size(xp->security)) + + nla_total_size(sizeof(struct xfrm_mark)) + + userpolicy_type_attrsize(); +} + +static int build_polexpire(struct sk_buff *skb, struct xfrm_policy *xp, + int dir, const struct km_event *c) +{ + struct xfrm_user_polexpire *upe; + int hard = c->data.hard; + struct nlmsghdr *nlh; + int err; + + nlh = nlmsg_put(skb, c->portid, 0, XFRM_MSG_POLEXPIRE, sizeof(*upe), 0); + if (nlh == NULL) + return -EMSGSIZE; + + upe = nlmsg_data(nlh); + copy_to_user_policy(xp, &upe->pol, dir); + err = copy_to_user_tmpl(xp, skb); + if (!err) + err = copy_to_user_sec_ctx(xp, skb); + if (!err) + err = copy_to_user_policy_type(xp->type, skb); + if (!err) + err = xfrm_mark_put(skb, &xp->mark); + if (!err) + err = xfrm_if_id_put(skb, xp->if_id); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + upe->hard = !!hard; + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_exp_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + struct net *net = xp_net(xp); + struct sk_buff *skb; + int err; + + skb = nlmsg_new(xfrm_polexpire_msgsize(xp), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + err = build_polexpire(skb, xp, dir, c); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_EXPIRE); +} + +static int xfrm_notify_policy(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + unsigned int len = nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr); + struct net *net = xp_net(xp); + struct xfrm_userpolicy_info *p; + struct xfrm_userpolicy_id *id; + struct nlmsghdr *nlh; + struct sk_buff *skb; + unsigned int headlen; + int err; + + headlen = sizeof(*p); + if (c->event == XFRM_MSG_DELPOLICY) { + len += nla_total_size(headlen); + headlen = sizeof(*id); + } + len += userpolicy_type_attrsize(); + len += nla_total_size(sizeof(struct xfrm_mark)); + len += NLMSG_ALIGN(headlen); + + 
skb = nlmsg_new(len, GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + nlh = nlmsg_put(skb, c->portid, c->seq, c->event, headlen, 0); + err = -EMSGSIZE; + if (nlh == NULL) + goto out_free_skb; + + p = nlmsg_data(nlh); + if (c->event == XFRM_MSG_DELPOLICY) { + struct nlattr *attr; + + id = nlmsg_data(nlh); + memset(id, 0, sizeof(*id)); + id->dir = dir; + if (c->data.byid) + id->index = xp->index; + else + memcpy(&id->sel, &xp->selector, sizeof(id->sel)); + + attr = nla_reserve(skb, XFRMA_POLICY, sizeof(*p)); + err = -EMSGSIZE; + if (attr == NULL) + goto out_free_skb; + + p = nla_data(attr); + } + + copy_to_user_policy(xp, p, dir); + err = copy_to_user_tmpl(xp, skb); + if (!err) + err = copy_to_user_policy_type(xp->type, skb); + if (!err) + err = xfrm_mark_put(skb, &xp->mark); + if (!err) + err = xfrm_if_id_put(skb, xp->if_id); + if (err) + goto out_free_skb; + + nlmsg_end(skb, nlh); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_POLICY); + +out_free_skb: + kfree_skb(skb); + return err; +} + +static int xfrm_notify_policy_flush(const struct km_event *c) +{ + struct net *net = c->net; + struct nlmsghdr *nlh; + struct sk_buff *skb; + int err; + + skb = nlmsg_new(userpolicy_type_attrsize(), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + nlh = nlmsg_put(skb, c->portid, c->seq, XFRM_MSG_FLUSHPOLICY, 0, 0); + err = -EMSGSIZE; + if (nlh == NULL) + goto out_free_skb; + err = copy_to_user_policy_type(c->data.type, skb); + if (err) + goto out_free_skb; + + nlmsg_end(skb, nlh); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_POLICY); + +out_free_skb: + kfree_skb(skb); + return err; +} + +static int xfrm_send_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c) +{ + + switch (c->event) { + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDPOLICY: + case XFRM_MSG_DELPOLICY: + return xfrm_notify_policy(xp, dir, c); + case XFRM_MSG_FLUSHPOLICY: + return xfrm_notify_policy_flush(c); + case XFRM_MSG_POLEXPIRE: + return xfrm_exp_policy_notify(xp, dir, c); + default: + printk(KERN_NOTICE "xfrm_user: Unknown Policy event %d\n", + c->event); + } + + return 0; + +} + +static inline unsigned int xfrm_report_msgsize(void) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_user_report)); +} + +static int build_report(struct sk_buff *skb, u8 proto, + struct xfrm_selector *sel, xfrm_address_t *addr) +{ + struct xfrm_user_report *ur; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, 0, 0, XFRM_MSG_REPORT, sizeof(*ur), 0); + if (nlh == NULL) + return -EMSGSIZE; + + ur = nlmsg_data(nlh); + ur->proto = proto; + memcpy(&ur->sel, sel, sizeof(ur->sel)); + + if (addr) { + int err = nla_put(skb, XFRMA_COADDR, sizeof(*addr), addr); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + } + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_send_report(struct net *net, u8 proto, + struct xfrm_selector *sel, xfrm_address_t *addr) +{ + struct sk_buff *skb; + int err; + + skb = nlmsg_new(xfrm_report_msgsize(), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + err = build_report(skb, proto, sel, addr); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_REPORT); +} + +static inline unsigned int xfrm_mapping_msgsize(void) +{ + return NLMSG_ALIGN(sizeof(struct xfrm_user_mapping)); +} + +static int build_mapping(struct sk_buff *skb, struct xfrm_state *x, + xfrm_address_t *new_saddr, __be16 new_sport) +{ + struct xfrm_user_mapping *um; + struct nlmsghdr *nlh; + + nlh = nlmsg_put(skb, 0, 0, XFRM_MSG_MAPPING, sizeof(*um), 0); + if (nlh == NULL) + return -EMSGSIZE; + + um = 
nlmsg_data(nlh); + + memcpy(&um->id.daddr, &x->id.daddr, sizeof(um->id.daddr)); + um->id.spi = x->id.spi; + um->id.family = x->props.family; + um->id.proto = x->id.proto; + memcpy(&um->new_saddr, new_saddr, sizeof(um->new_saddr)); + memcpy(&um->old_saddr, &x->props.saddr, sizeof(um->old_saddr)); + um->new_sport = new_sport; + um->old_sport = x->encap->encap_sport; + um->reqid = x->props.reqid; + + nlmsg_end(skb, nlh); + return 0; +} + +static int xfrm_send_mapping(struct xfrm_state *x, xfrm_address_t *ipaddr, + __be16 sport) +{ + struct net *net = xs_net(x); + struct sk_buff *skb; + int err; + + if (x->id.proto != IPPROTO_ESP) + return -EINVAL; + + if (!x->encap) + return -EINVAL; + + skb = nlmsg_new(xfrm_mapping_msgsize(), GFP_ATOMIC); + if (skb == NULL) + return -ENOMEM; + + err = build_mapping(skb, x, ipaddr, sport); + BUG_ON(err < 0); + + return xfrm_nlmsg_multicast(net, skb, 0, XFRMNLGRP_MAPPING); +} + +static bool xfrm_is_alive(const struct km_event *c) +{ + return (bool)xfrm_acquire_is_on(c->net); +} + +static struct xfrm_mgr netlink_mgr = { + .notify = xfrm_send_state_notify, + .acquire = xfrm_send_acquire, + .compile_policy = xfrm_compile_policy, + .notify_policy = xfrm_send_policy_notify, + .report = xfrm_send_report, + .migrate = xfrm_send_migrate, + .new_mapping = xfrm_send_mapping, + .is_alive = xfrm_is_alive, +}; + +static int __net_init xfrm_user_net_init(struct net *net) +{ + struct sock *nlsk; + struct netlink_kernel_cfg cfg = { + .groups = XFRMNLGRP_MAX, + .input = xfrm_netlink_rcv, + }; + + nlsk = netlink_kernel_create(net, NETLINK_XFRM, &cfg); + if (nlsk == NULL) + return -ENOMEM; + net->xfrm.nlsk_stash = nlsk; /* Don't set to NULL */ + rcu_assign_pointer(net->xfrm.nlsk, nlsk); + return 0; +} + +static void __net_exit xfrm_user_net_pre_exit(struct net *net) +{ + RCU_INIT_POINTER(net->xfrm.nlsk, NULL); +} + +static void __net_exit xfrm_user_net_exit(struct list_head *net_exit_list) +{ + struct net *net; + + list_for_each_entry(net, net_exit_list, exit_list) + netlink_kernel_release(net->xfrm.nlsk_stash); +} + +static struct pernet_operations xfrm_user_net_ops = { + .init = xfrm_user_net_init, + .pre_exit = xfrm_user_net_pre_exit, + .exit_batch = xfrm_user_net_exit, +}; + +static int __init xfrm_user_init(void) +{ + int rv; + + printk(KERN_INFO "Initializing XFRM netlink socket\n"); + + rv = register_pernet_subsys(&xfrm_user_net_ops); + if (rv < 0) + return rv; + xfrm_register_km(&netlink_mgr); + return 0; +} + +static void __exit xfrm_user_exit(void) +{ + xfrm_unregister_km(&netlink_mgr); + unregister_pernet_subsys(&xfrm_user_net_ops); +} + +module_init(xfrm_user_init); +module_exit(xfrm_user_exit); +MODULE_LICENSE("GPL"); +MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_XFRM); -- cgit v1.2.3